diff --git a/.clang-tidy b/.clang-tidy index 980f5063085..5d191e2b756 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -6,7 +6,7 @@ CheckOptions: - key: bugprone-unused-return-value.AllowCastToVoid value: true - key: cppcoreguidelines-macro-usage.AllowedRegexp - value: 'DEBUG|ASSERT|ASSUME|UNREACHABLE|FALLTHROUGH|DEPRECATED|STRINGIZE|_HAS_|_THROW|_REQUIRES|_DECLARE_|_VISIT_|_REGISTER_|_GENERATE_|_DETAIL_|_TIDY_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED|_WORKAROUND_|_PP_' + value: 'DEBUG|ASSERT|ASSUME|UNREACHABLE|FALLTHROUGH|DEPRECATED|STRINGIZE|_HAS_|_THROW|_REQUIRES|_DECLARE_|_DEFINE_|_VISIT_|_REGISTER_|_GENERATE_|_DETAIL_|_TIDY_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED|_WORKAROUND_|_PP_' - key: modernize-loop-convert.MinConfidence value: risky - key: modernize-loop-convert.NamingStyle diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7a1d00a4faa..2bef914f72a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -137,11 +137,11 @@ add_library(migraphx serialize.cpp shape.cpp shape_transform_descriptor.cpp - sym.cpp simplify_algebra.cpp simplify_dyn_ops.cpp simplify_reshapes.cpp split_single_dyn_dim.cpp + sym.cpp target.cpp tmp_dir.cpp truncate_float.cpp diff --git a/src/include/migraphx/hash.hpp b/src/include/migraphx/hash.hpp index 44e9b155df8..85e62dd79f9 100644 --- a/src/include/migraphx/hash.hpp +++ b/src/include/migraphx/hash.hpp @@ -25,23 +25,42 @@ #define MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP #include +#include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { template -auto hash_value(const T& v) -> decltype(std::hash{}(v)) +auto hash_value(rank<2>, const T& v) -> decltype(std::hash{}(v)) { return std::hash{}(v); } +template +auto hash_value(rank<1>, const T& v) -> decltype(v.hash()) +{ + return v.hash(); +} + +template +auto hash_value(const T& v) -> decltype(hash_value(rank<2>{}, v)) +{ + return hash_value(rank<2>{}, v); +} + template void hash_combine(std::size_t& seed, const T& v) { seed ^= hash_value(v) + 0x9e3779b9 + (seed << 6u) + (seed >> 2u); } +template +void hash_range(std::size_t& seed, Iterator first, Iterator last) +{ + std::for_each(first, last, [&](const auto& x) { hash_combine(seed, x); }); +} + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx #endif // MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP diff --git a/src/include/migraphx/ranges.hpp b/src/include/migraphx/ranges.hpp index 0a50c0eeb34..0c7efbe0768 100644 --- a/src/include/migraphx/ranges.hpp +++ b/src/include/migraphx/ranges.hpp @@ -269,6 +269,10 @@ struct iterator_range Iterator start; Iterator last; + bool empty() const { return start == last; } + + std::size_t size() const { return std::distance(start, last); } + Iterator begin() const { return start; } Iterator end() const { return last; } diff --git a/src/include/migraphx/simple_parser.hpp b/src/include/migraphx/simple_parser.hpp new file mode 100644 index 00000000000..c6b1b3ce715 --- /dev/null +++ b/src/include/migraphx/simple_parser.hpp @@ -0,0 +1,229 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_SIMPLE_PARSER_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_SIMPLE_PARSER_HPP + +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace parser { + +template > +struct simple_parser +{ + View buffer; + Iterator pos = buffer.begin(); + + static View make_view(Iterator begin, Iterator end) + { + if constexpr(std::is_constructible{}) + { + auto n = std::distance(begin, end); + if(n == 0) + return {}; + return {std::addressof(*begin), static_cast(n)}; + } + else + { + return {begin, end}; + } + } + + View peek() const + { + if(pos >= buffer.end()) + return {}; + return make_view(pos, buffer.end()); + } + + void advance(std::size_t n) + { + pos += n; + if(pos > buffer.end()) + MIGRAPHX_THROW("Parser advanced past end of buffer"); + if constexpr(AutoSkipWhitespace) + { + pos = std::find_if( + pos, buffer.end(), [](unsigned char c) { return not std::isspace(c); }); + } + } + + template + View parse_while(Pred p) + { + auto start = pos; + auto it = std::find_if(pos, buffer.end(), [&](auto c) { return not p(c); }); + auto n = std::distance(pos, it); + advance(n); + return make_view(start, it); + } + + bool starts_with(const View& prefix) const + { + auto tail = peek(); + if(prefix.size() > tail.size()) + return false; + else + return std::equal(prefix.begin(), prefix.end(), tail.begin()); + } + + bool done() const { return pos >= buffer.end(); } + + bool match(const View& prefix) + { + if(not starts_with(prefix)) + return false; + advance(prefix.size()); + return true; + } + + void expect(const View& str) + { + if(not starts_with(str)) + MIGRAPHX_THROW(error_message("'" + std::string{str} + "'")); + advance(str.size()); + } + + char peek_char() const + { + if(done()) + return '\0'; + return *pos; + } + + std::string error_message(std::string_view expected) const + { + auto offset = std::distance(buffer.begin(), pos); + return "Expected " + std::string(expected) + " at position " + std::to_string(offset) + + " in '" + std::string(buffer) + "'"; + } + + template + bool try_parse(F f) + { + auto copy = *this; + f(*this); + if(copy.pos != pos) + return true; + *this = copy; + return false; + } + + View first_of(View view) + { + if(match(view)) + return view; + return {}; + } + + template + View first_of(View view, Views... views) + { + if(match(view)) + return view; + return first_of(views...); + } + + template + auto first_of(F f) -> decltype(f(*this)) + { + return f(*this); + } + + template + auto first_of(F f, G g, Fs... fs) -> decltype(f(*this)) + { + auto copy = *this; + auto result = f(*this); + if(copy.pos != pos) + return result; + *this = copy; + return first_of(g, fs...); + } + + template + auto repeat(F f) + { + using result_type = decltype(f(*this)); + std::vector results; + for(;;) + { + auto copy = *this; + results.push_back(f(*this)); + if(copy.pos == pos) + { + results.pop_back(); + break; + } + } + return results; + } +}; + +template +struct parser_action +{ + F fn; + + template + auto operator()(Parser& p) const -> decltype(fn(p)) + { + return fn(p); + } +}; + +template +parser_action> action(F&& f) +{ + return {std::forward(f)}; +} + +template +auto operator|(parser_action a, parser_action b) +{ + return action([a = std::move(a), b = std::move(b)](auto& p) -> decltype(a(p)) { + return p.first_of(a, b); + }); +} + +template +auto operator*(parser_action a) +{ + return action([a = std::move(a)](auto& p) { return p.repeat(a); }); +} + +using simple_string_view_skip_parser = + simple_parser; + +} // namespace parser +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_SIMPLE_PARSER_HPP diff --git a/src/include/migraphx/sym.hpp b/src/include/migraphx/sym.hpp index 59d3f842884..f79316e2726 100644 --- a/src/include/migraphx/sym.hpp +++ b/src/include/migraphx/sym.hpp @@ -21,17 +21,25 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ -#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_SYM_HPP -#define MIGRAPHX_GUARD_MIGRAPHLIB_SYM_HPP +#ifndef MIGRAPHX_GUARD_SYM_HPP +#define MIGRAPHX_GUARD_SYM_HPP -#include -#include +#include +#include +#include +#include #include #include +#include #include #include - +#include +#include +#include #include +#include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -40,53 +48,271 @@ struct value; namespace sym { -struct expr; -MIGRAPHX_EXPORT expr var(const std::string& name); -MIGRAPHX_EXPORT expr lit(int64_t n); -MIGRAPHX_EXPORT expr parse(const std::string& s); +using scalar = std::variant; + +template {})> +scalar make_scalar(T v) +{ + if constexpr(std::is_unsigned{} and sizeof(T) >= sizeof(int64_t)) + return int64_t(std::min(v, std::numeric_limits::max())); + else if constexpr(std::is_integral{}) + return int64_t(v); + else + return double(v); +} + +template +To to(const scalar& v) +{ + return std::visit([](auto x) -> To { return x; }, v); +} + +template +scalar scalar_invoke(F f, const Ts&... vs) +{ + return std::visit([&](auto... xs) -> scalar { return f(xs...); }, vs...); +} -struct MIGRAPHX_EXPORT expr +template +scalar scalar_invoke_common(F f, const Ts&... xs) { - expr(); + if((std::holds_alternative(xs) and ...)) + return f(std::get(xs)...); + return f(to(xs)...); +} + +template +R scalar_invoke_common(F f, const Ts&... xs) +{ + if((std::holds_alternative(xs) and ...)) + return f(std::get(xs)...); + return f(to(xs)...); +} + +scalar scalar_min(const scalar& a, const scalar& b); +scalar scalar_max(const scalar& a, const scalar& b); + +template +auto unpack_container(F f) +{ + return [=](auto&& c) { + if(c.size() != N) + MIGRAPHX_THROW("Mismatch number of inputs"); + return sequence_c([&](auto... is) { return f(c[is]...); }); + }; +} + +struct MIGRAPHX_EXPORT interval +{ + scalar min = int64_t{0}; + scalar max = int64_t{0}; + + interval& operator+=(interval b) { return *this = *this + b; } + interval& operator-=(interval b) { return *this = *this - b; } + interval& operator*=(interval b) { return *this = *this * b; } + interval& operator/=(interval b) { return *this = *this / b; } + interval& operator%=(interval b) { return *this = *this % b; } + friend interval operator+(interval a, interval b); + friend interval operator-(interval a, interval b); + friend interval operator*(interval a, interval b); + friend interval operator/(interval a, interval b); + friend interval operator%(interval a, interval b); + friend interval operator-(interval a); + friend bool operator<(interval a, interval b); + friend bool operator<=(interval a, interval b); + friend bool operator>(interval a, interval b); + friend bool operator>=(interval a, interval b); + friend bool operator==(const interval& a, const interval& b); + friend bool operator!=(const interval& a, const interval& b); + friend interval sin(interval x); + friend interval cos(interval x); + friend interval tan(interval x); + friend interval exp(interval x); + friend interval log(interval x); + friend interval sqrt(interval x); + friend interval abs(interval x); + friend interval floor(interval x); + friend interval ceil(interval x); + friend interval pow(interval x, interval y); + friend interval min(interval x, interval y); + friend interval max(interval x, interval y); + friend std::ostream& operator<<(std::ostream& os, const interval& i); +}; + +struct op_def +{ + std::string name; + std::function&)> eval; + std::function&)> eval_interval; + bool associative = false; +}; + +class expr; +MIGRAPHX_EXPORT expr lit(scalar v); + +class MIGRAPHX_EXPORT expr +{ + struct impl; + std::shared_ptr pimpl; + + template + static std::shared_ptr make_impl(Node node, std::vector children); + + public: + expr() = default; + + template + explicit expr(Node node, std::vector children = {}) + : pimpl(make_impl(std::move(node), std::move(children))) + { + } + + std::string name() const; + bool is_raw() const; + const impl* get_pimpl() const; + const std::vector& children() const; + scalar eval(const std::unordered_map& vars) const; + interval eval_interval(const std::unordered_map& vars) const; + std::string to_string() const; bool empty() const; std::size_t hash() const; - std::string to_string() const; - value to_value() const; - void from_value(const value& v); std::size_t eval_uint(const std::unordered_map& symbol_map) const; expr subs(const std::unordered_map& symbol_map) const; + std::set eval_optimals() const; - MIGRAPHX_EXPORT friend expr operator+(const expr& a, const expr& b); - MIGRAPHX_EXPORT friend expr operator-(const expr& a, const expr& b); - MIGRAPHX_EXPORT friend expr operator*(const expr& a, const expr& b); - MIGRAPHX_EXPORT friend expr operator/(const expr& a, const expr& b); - MIGRAPHX_EXPORT friend bool operator==(const expr& a, const expr& b); - MIGRAPHX_EXPORT friend bool operator!=(const expr& a, const expr& b); - MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const expr& e); - - friend expr operator+(const expr& a, int64_t b) { return a + lit(b); } - friend expr operator+(int64_t a, const expr& b) { return lit(a) + b; } - friend expr operator-(const expr& a, int64_t b) { return a - lit(b); } - friend expr operator-(int64_t a, const expr& b) { return lit(a) - b; } - friend expr operator*(const expr& a, int64_t b) { return a * lit(b); } - friend expr operator*(int64_t a, const expr& b) { return lit(a) * b; } - friend expr operator/(const expr& a, int64_t b) { return a / lit(b); } - friend expr operator/(int64_t a, const expr& b) { return lit(a) / b; } + friend expr operator+(expr ex, expr ey); + friend expr operator-(expr ex, expr ey); + friend expr operator*(expr ex, expr ey); + friend expr operator/(expr ex, expr ey); + friend expr operator%(expr ex, expr ey); + friend expr operator-(expr e); + friend bool operator==(const expr& a, const expr& b); + friend bool operator!=(const expr& a, const expr& b); + friend std::ostream& operator<<(std::ostream& os, const expr& e); - struct impl; +#define MIGRAPHX_SYM_DEFINE_OP(binary, assign) \ + expr& operator assign(expr ey) { return *this = *this binary std::move(ey); } \ + template {})> \ + expr& operator assign(T x) \ + { \ + return *this = *this binary lit(make_scalar(x)); \ + } \ + template {})> \ + friend expr operator binary(expr ex, T y) \ + { \ + return std::move(ex) binary lit(make_scalar(y)); \ + } \ + template {})> \ + friend expr operator binary(T x, expr ey) \ + { \ + return lit(make_scalar(x)) binary std::move(ey); \ + } + + MIGRAPHX_SYM_DEFINE_OP(+, +=) + MIGRAPHX_SYM_DEFINE_OP(-, -=) + MIGRAPHX_SYM_DEFINE_OP(*, *=) + MIGRAPHX_SYM_DEFINE_OP(/, /=) + MIGRAPHX_SYM_DEFINE_OP(%, %=) +}; + +template {})> +expr lit(T v) +{ + return lit(make_scalar(v)); +} + +MIGRAPHX_EXPORT expr var(std::string name); +MIGRAPHX_EXPORT expr var(std::string name, interval constraint, std::set optimals = {}); + +MIGRAPHX_EXPORT expr arg(expr x); + +template {})> +expr arg(T x) +{ + if constexpr(std::is_integral{}) + return lit(static_cast(x)); + else + return lit(static_cast(x)); +} + +expr call_op(const op_def* op, std::vector args); + +template +expr call_op(std::string name, + Eval eval, + EvalInterval eval_interval, + std::vector args, + bool is_associative = false) +{ + static const op_def op{ + std::move(name), std::move(eval), std::move(eval_interval), is_associative}; + return call_op(&op, std::move(args)); +} + +template +auto call(std::string name, Eval eval, EvalInterval eval_interval) +{ + return [=](auto... es) { + auto eval1 = unpack_container( + [=](auto... xs) { return scalar_invoke_common(eval, xs...); }); + auto eval_interval1 = + unpack_container([=](auto... xs) { return eval_interval(xs...); }); + return call_op(name, eval1, eval_interval1, {arg(std::move(es))...}); + }; +} - MIGRAPHX_EXPORT friend expr var(const std::string& name); - MIGRAPHX_EXPORT friend expr lit(int64_t n); - MIGRAPHX_EXPORT friend expr parse(const std::string& s); +template +auto call(std::string name, Eval eval) +{ + return call(std::move(name), eval, eval); +} + +MIGRAPHX_EXPORT std::string to_string(const expr& e); + +MIGRAPHX_EXPORT expr parse(const std::string& str); + +MIGRAPHX_EXPORT expr sin(expr e); +MIGRAPHX_EXPORT expr cos(expr e); +MIGRAPHX_EXPORT expr tan(expr e); +MIGRAPHX_EXPORT expr exp(expr e); +MIGRAPHX_EXPORT expr log(expr e); +MIGRAPHX_EXPORT expr sqrt(expr e); +MIGRAPHX_EXPORT expr abs(expr e); +MIGRAPHX_EXPORT expr floor(expr e); +MIGRAPHX_EXPORT expr ceil(expr e); +MIGRAPHX_EXPORT expr pow(expr x, expr y); +MIGRAPHX_EXPORT expr min(expr x, expr y); +MIGRAPHX_EXPORT expr max(expr x, expr y); - private: - expr(std::shared_ptr pi); - std::shared_ptr p; +// Pattern matching rewrite DSL +expr pvar(int id); + +struct rewrite_rule +{ + expr pattern; + expr replacement; }; -} // namespace sym +inline rewrite_rule operator>>(expr pattern, expr replacement) +{ + return {std::move(pattern), std::move(replacement)}; +} + +template {})> +rewrite_rule operator>>(expr pattern, T replacement) +{ + return {std::move(pattern), lit(replacement)}; +} + +MIGRAPHX_EXPORT expr simplify(const expr& e, const std::vector& rules); +MIGRAPHX_EXPORT void migraphx_to_value(value& v, const sym::interval& i); +MIGRAPHX_EXPORT void migraphx_from_value(const value& v, sym::interval& i); +MIGRAPHX_EXPORT void migraphx_to_value(value& v, const sym::expr& e); +MIGRAPHX_EXPORT void migraphx_from_value(const value& v, sym::expr& e); + +} // namespace sym } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx @@ -94,10 +320,9 @@ namespace std { template <> struct hash { - using argument_type = migraphx::sym::expr; - using result_type = std::size_t; - result_type operator()(const migraphx::sym::expr& e) const { return e.hash(); } + std::size_t operator()(const migraphx::sym::expr& e) const { return e.hash(); } }; + } // namespace std #endif diff --git a/src/include/migraphx/utility_operators.hpp b/src/include/migraphx/utility_operators.hpp index f56f61e0452..135f6794d27 100644 --- a/src/include/migraphx/utility_operators.hpp +++ b/src/include/migraphx/utility_operators.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,6 +28,7 @@ #include #include #include +#include #include namespace migraphx { @@ -129,6 +130,26 @@ struct totally_ordered : equality_comparable, less_than_comparable { }; +template +struct ordered_as : totally_ordered>, equivalence> +{ + T value; + Compare compare; + + ordered_as(T v, Compare c) : value(std::move(v)), compare(std::move(c)) {} + + friend bool operator<(const ordered_as& a, const ordered_as& b) + { + return a.compare(a.value, b.value); + } +}; + +template +ordered_as make_ordered_as(T value, Compare compare) +{ + return {std::move(value), std::move(compare)}; +} + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx #endif // MIGRAPHX_GUARD_MIGRAPHX_UTILITY_OPERATORS_HPP diff --git a/src/sym.cpp b/src/sym.cpp index 00d9480e810..7f46bfd5c04 100644 --- a/src/sym.cpp +++ b/src/sym.cpp @@ -21,1044 +21,1650 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ - #include -#include -#include #include - +#include +#include +#include +#include +#include +#include +#include #include -#include -#include +#include #include -#include -#include +#include #include +#include #include -#include -#include -#include -#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +namespace sym { -// =================================================================== -// Section 1: Expression node types -// =================================================================== - -struct expr_node; -using expr_ptr = std::shared_ptr; - -struct expr_compare +scalar scalar_min(const scalar& a, const scalar& b) { - bool operator()(const expr_ptr& a, const expr_ptr& b) const; -}; - -using term_map = std::map; -using factor_map = std::map; + return scalar_invoke_common([](auto x, auto y) { return x < y ? x : y; }, a, b); +} -struct integer_data -{ - int64_t value; -}; -struct symbol_data -{ - std::string name; -}; -struct add_data +scalar scalar_max(const scalar& a, const scalar& b) { - int64_t constant; - term_map terms; -}; -struct mul_data -{ - int64_t coefficient; - factor_map factors; -}; -struct tdiv_data -{ - expr_ptr numerator; - expr_ptr denominator; -}; - -using expr_data = std::variant; + return scalar_invoke_common([](auto x, auto y) { return x > y ? x : y; }, a, b); +} -struct expr_node +interval operator+(interval a, interval b) { - expr_data data; - std::size_t cached_hash = 0; -}; + auto f = [](auto x, auto y) { return x + y; }; + return {scalar_invoke_common(f, a.min, b.min), scalar_invoke_common(f, a.max, b.max)}; +} -template -static bool holds(const expr_ptr& e) +interval operator-(interval a, interval b) { - return std::holds_alternative(e->data); + auto f = [](auto x, auto y) { return x - y; }; + return {scalar_invoke_common(f, a.min, b.max), scalar_invoke_common(f, a.max, b.min)}; } -static int64_t get_integer(const expr_ptr& e) { return std::get(e->data).value; } -static const add_data& get_add(const expr_ptr& e) { return std::get(e->data); } -static const mul_data& get_mul(const expr_ptr& e) { return std::get(e->data); } - -// =================================================================== -// Section 2: Hash computation -// =================================================================== - -static std::size_t hash_combine(std::size_t seed, std::size_t v) +interval operator*(interval a, interval b) { - return seed ^ (v + 0x9e3779b9 + (seed << 6u) + (seed >> 2u)); + auto f = [](auto x, auto y) { return x * y; }; + auto p1 = scalar_invoke_common(f, a.min, b.min); + auto p2 = scalar_invoke_common(f, a.min, b.max); + auto p3 = scalar_invoke_common(f, a.max, b.min); + auto p4 = scalar_invoke_common(f, a.max, b.max); + return {scalar_min(scalar_min(p1, p2), scalar_min(p3, p4)), + scalar_max(scalar_max(p1, p2), scalar_max(p3, p4))}; } -template -static std::size_t hash_ordered_map(const Map& m) +interval operator/(interval a, interval b) { - std::size_t h = 0; - for(const auto& [key, val] : m) + auto b_lo = to(b.min); + auto b_hi = to(b.max); + // If the divisor brackets zero, the 4-corner formula is wrong (and may hit + // integer division-by-zero). Handle it explicitly using infinities so the + // unbounded regions are representable. + if(b_lo <= 0.0 and b_hi >= 0.0) { - h = hash_combine(h, key->cached_hash); - h = hash_combine(h, std::hash{}(val)); + if(b_lo == 0.0 and b_hi == 0.0) + MIGRAPHX_THROW("Interval division by zero"); + constexpr double inf = std::numeric_limits::infinity(); + // Strictly crosses zero: 1/b sweeps the full real line. + if(b_lo < 0.0 and b_hi > 0.0) + return {-inf, inf}; + auto a_lo = to(a.min); + auto a_hi = to(a.max); + // b == [0, b_hi], b_hi > 0: 1/b in [1/b_hi, +inf). + if(b_lo == 0.0) + { + if(a_lo >= 0.0) + return {a_lo / b_hi, inf}; + if(a_hi <= 0.0) + return {-inf, a_hi / b_hi}; + return {-inf, inf}; + } + // b == [b_lo, 0], b_lo < 0: 1/b in (-inf, 1/b_lo]. + if(a_lo >= 0.0) + return {-inf, a_lo / b_lo}; + if(a_hi <= 0.0) + return {a_hi / b_lo, inf}; + return {-inf, inf}; } - return h; + + auto f = [](auto x, auto y) { return x / y; }; + auto p1 = scalar_invoke_common(f, a.min, b.min); + auto p2 = scalar_invoke_common(f, a.min, b.max); + auto p3 = scalar_invoke_common(f, a.max, b.min); + auto p4 = scalar_invoke_common(f, a.max, b.max); + return {scalar_min(scalar_min(p1, p2), scalar_min(p3, p4)), + scalar_max(scalar_max(p1, p2), scalar_max(p3, p4))}; } -static std::size_t compute_hash(const expr_data& d) +interval operator%(interval a, interval b) { - std::size_t h = std::hash{}(d.index()); - return std::visit( - overloaded{ - [&](const integer_data& p) { return hash_combine(h, std::hash{}(p.value)); }, - [&](const symbol_data& p) { return hash_combine(h, std::hash{}(p.name)); }, - [&](const add_data& p) { - return hash_combine(hash_combine(h, std::hash{}(p.constant)), - hash_ordered_map(p.terms)); - }, - [&](const mul_data& p) { - return hash_combine(hash_combine(h, std::hash{}(p.coefficient)), - hash_ordered_map(p.factors)); - }, - [&](const tdiv_data& p) { - return hash_combine(hash_combine(h, p.numerator->cached_hash), - p.denominator->cached_hash); - }}, - d); -} - -// =================================================================== -// Section 3: Canonical ordering (expr_compare) -// =================================================================== - -static int compare_expr(const expr_ptr& a, const expr_ptr& b); - -template -static int compare_maps(const Map& a, const Map& b) -{ - auto it_a = a.begin(); - auto it_b = b.begin(); - for(; it_a != a.end() and it_b != b.end(); ++it_a, ++it_b) + // The 4-corner min/max formula is wrong for mod (e.g. [1,5] % [3,3] should + // include {0,1,2}, not just the corner values). Use a loose but correct + // bound: |a % b| < max(|b_lo|, |b_hi|). + auto b_lo = to(b.min); + auto b_hi = to(b.max); + if(b_lo == 0.0 and b_hi == 0.0) + MIGRAPHX_THROW("Interval mod by zero"); + auto max_abs = std::max(std::abs(b_lo), std::abs(b_hi)); + if(std::holds_alternative(b.min) and std::holds_alternative(b.max)) { - int c = compare_expr(it_a->first, it_b->first); - if(c != 0) - return c; - if(it_a->second < it_b->second) - return -1; - if(it_a->second > it_b->second) - return 1; + auto m = static_cast(max_abs); + return {int64_t{-m}, m}; } - if(it_a != a.end()) - return 1; - if(it_b != b.end()) - return -1; - return 0; + return {-max_abs, max_abs}; } -static int compare_expr(const expr_ptr& a, const expr_ptr& b) +interval operator-(interval a) { - if(a->data.index() != b->data.index()) - return a->data.index() < b->data.index() ? -1 : 1; + auto f = [](auto x) { return -x; }; + return {scalar_invoke_common(f, a.max), scalar_invoke_common(f, a.min)}; +} - return std::visit( - overloaded{[&](const integer_data& da) { - const auto& db = std::get(b->data); - return (da.value < db.value) ? -1 : (da.value > db.value) ? 1 : 0; - }, - [&](const symbol_data& da) { - const auto& db = std::get(b->data); - return da.name.compare(db.name); - }, - [&](const add_data& da) { - const auto& db = std::get(b->data); - if(da.constant != db.constant) - return da.constant < db.constant ? -1 : 1; - return compare_maps(da.terms, db.terms); - }, - [&](const mul_data& da) { - const auto& db = std::get(b->data); - if(da.coefficient != db.coefficient) - return da.coefficient < db.coefficient ? -1 : 1; - return compare_maps(da.factors, db.factors); - }, - [&](const tdiv_data& da) { - const auto& db = std::get(b->data); - int c = compare_expr(da.numerator, db.numerator); - if(c != 0) - return c; - return compare_expr(da.denominator, db.denominator); - }}, - a->data); -} - -bool expr_compare::operator()(const expr_ptr& a, const expr_ptr& b) const -{ - return compare_expr(a, b) < 0; -} - -// =================================================================== -// Section 4: Structural equality -// =================================================================== - -static bool expr_equal(const expr_ptr& a, const expr_ptr& b) -{ - if(a.get() == b.get()) - return true; - if(a->cached_hash != b->cached_hash) - return false; - return compare_expr(a, b) == 0; +bool operator==(const interval& a, const interval& b) { return a.min == b.min and a.max == b.max; } + +bool operator!=(const interval& a, const interval& b) { return not(a == b); } + +std::ostream& operator<<(std::ostream& os, const interval& i) +{ + os << "["; + std::visit([&](auto x) { os << x; }, i.min); + os << ", "; + std::visit([&](auto x) { os << x; }, i.max); + os << "]"; + return os; +} + +static bool scalar_less(const scalar& a, const scalar& b) +{ + auto f = [](auto x, auto y) -> int64_t { return x < y ? 1 : 0; }; + return std::get(scalar_invoke_common(f, a, b)) != 0; } -// =================================================================== -// Section 5: Factory functions (canonical constructors) -// =================================================================== +bool operator<(interval a, interval b) { return scalar_less(a.max, b.min); } + +bool operator<=(interval a, interval b) { return not scalar_less(b.min, a.max); } + +bool operator>(interval a, interval b) { return scalar_less(b.max, a.min); } -static expr_ptr make_node(expr_data d) +bool operator>=(interval a, interval b) { return not scalar_less(a.min, b.max); } + +interval sin(interval x) { - auto n = std::make_shared(); - n->data = std::move(d); - n->cached_hash = compute_hash(n->data); - return n; + double lo = to(x.min); + double hi = to(x.max); + const double pi = std::acos(-1.0); + if(hi - lo >= 2.0 * pi) + return {-1.0, 1.0}; + double slo = std::sin(lo); + double shi = std::sin(hi); + double rmin = std::min(slo, shi); + double rmax = std::max(slo, shi); + double k = std::ceil((lo - pi / 2.0) / (2.0 * pi)); + if(pi / 2.0 + k * 2.0 * pi <= hi) + rmax = 1.0; + k = std::ceil((lo + pi / 2.0) / (2.0 * pi)); + if(-pi / 2.0 + k * 2.0 * pi <= hi) + rmin = -1.0; + return {rmin, rmax}; } -static const expr_ptr& const_zero() +interval cos(interval x) { - static auto p = make_node(integer_data{0}); - return p; + double lo = to(x.min); + double hi = to(x.max); + const double pi = std::acos(-1.0); + if(hi - lo >= 2.0 * pi) + return {-1.0, 1.0}; + double clo = std::cos(lo); + double chi = std::cos(hi); + double rmin = std::min(clo, chi); + double rmax = std::max(clo, chi); + double k = std::ceil(lo / (2.0 * pi)); + if(k * 2.0 * pi <= hi) + rmax = 1.0; + k = std::ceil((lo - pi) / (2.0 * pi)); + if(pi + k * 2.0 * pi <= hi) + rmin = -1.0; + return {rmin, rmax}; } -static const expr_ptr& const_one() + +interval tan(interval x) { return {std::tan(to(x.min)), std::tan(to(x.max))}; } + +interval exp(interval x) { return {std::exp(to(x.min)), std::exp(to(x.max))}; } + +interval log(interval x) { return {std::log(to(x.min)), std::log(to(x.max))}; } + +interval sqrt(interval x) { - static auto p = make_node(integer_data{1}); - return p; + auto lo = std::sqrt(std::max(0.0, to(x.min))); + auto hi = std::sqrt(std::max(0.0, to(x.max))); + return {lo, hi}; } -static const expr_ptr& const_neg_one() + +interval abs(interval x) { - static auto p = make_node(integer_data{-1}); - return p; + double lo = to(x.min); + double hi = to(x.max); + if(lo >= 0.0) + return x; + if(hi <= 0.0) + return -x; + auto neg_min = scalar_invoke_common([](auto v) { return -v; }, x.min); + return {int64_t{0}, scalar_max(neg_min, x.max)}; } -static expr_ptr make_integer(int64_t n) +interval floor(interval x) { - if(n == 0) - return const_zero(); - if(n == 1) - return const_one(); - if(n == -1) - return const_neg_one(); - return make_node(integer_data{n}); + return {std::floor(to(x.min)), std::floor(to(x.max))}; } -static expr_ptr make_symbol(const std::string& name) { return make_node(symbol_data{name}); } +interval ceil(interval x) { return {std::ceil(to(x.min)), std::ceil(to(x.max))}; } -static expr_ptr make_add(const expr_ptr& a, const expr_ptr& b); -static expr_ptr make_sub(const expr_ptr& a, const expr_ptr& b); -static expr_ptr make_neg(const expr_ptr& a); -static expr_ptr make_mul(const expr_ptr& a, const expr_ptr& b); -static expr_ptr make_trunc_div(const expr_ptr& a, const expr_ptr& b); -static expr_ptr build_mul(int64_t coefficient, factor_map factors); +interval pow(interval x, interval y) +{ + auto f = MIGRAPHX_LIFT(std::pow); + auto p1 = scalar_invoke_common(f, x.min, y.min); + auto p2 = scalar_invoke_common(f, x.min, y.max); + auto p3 = scalar_invoke_common(f, x.max, y.min); + auto p4 = scalar_invoke_common(f, x.max, y.max); + return {scalar_min(scalar_min(p1, p2), scalar_min(p3, p4)), + scalar_max(scalar_max(p1, p2), scalar_max(p3, p4))}; +} -struct add_parts +interval min(interval x, interval y) { - int64_t constant = 0; - term_map terms; -}; + return {scalar_min(x.min, y.min), scalar_min(x.max, y.max)}; +} + +interval max(interval x, interval y) +{ + return {scalar_max(x.min, y.min), scalar_max(x.max, y.max)}; +} -static add_parts extract_add(const expr_ptr& e) +static std::size_t hash_scalar(scalar s) { return std::visit( - overloaded{[](const integer_data& d) -> add_parts { return {d.value, {}}; }, - [](const add_data& d) -> add_parts { return {d.constant, d.terms}; }, - [&](const mul_data& d) -> add_parts { - auto base = build_mul(1, d.factors); - return {0, {{base, d.coefficient}}}; - }, - [&](const auto&) -> add_parts { return {0, {{e, 1}}}; }}, - e->data); + [](auto x) -> std::size_t { + using T = std::decay_t; + if constexpr(std::is_floating_point{}) + { + int64_t i = x; + if(float_equal(x, i)) + return hash_value(i); + } + return hash_value(x); + }, + s); } -static expr_ptr build_add(int64_t constant, term_map terms) +struct literal_node { - // Remove zero-coefficient terms - for(auto it = terms.begin(); it != terms.end();) + scalar val; + std::size_t hash() const { return hash_scalar(val); } + friend bool operator==(const literal_node& a, const literal_node& b) { - if(it->second == 0) - it = terms.erase(it); - else - ++it; + return scalar_invoke_common( + [](auto a, auto b) { return float_equal(a, b); }, a.val, b.val); } - if(terms.empty()) - return make_integer(constant); - if(constant == 0 and terms.size() == 1) + friend bool operator!=(const literal_node& a, const literal_node& b) { return not(a == b); } +}; + +struct variable_node +{ + std::string name; + std::vector constraints; + std::set optimals; + + std::size_t hash() const { return hash_value(name); } + friend bool operator==(const variable_node& a, const variable_node& b) { - auto& [term, coeff] = *terms.begin(); - if(coeff == 1) - return term; - return make_mul(make_integer(coeff), term); + return a.name == b.name; } - return make_node(add_data{constant, std::move(terms)}); + friend bool operator!=(const variable_node& a, const variable_node& b) { return not(a == b); } +}; + +struct op_node +{ + const op_def* op; + friend bool operator==(const op_node& a, const op_node& b) { return a.op == b.op; } + friend bool operator!=(const op_node& a, const op_node& b) { return not(a == b); } + + std::size_t hash() const { return hash_value(op->name); } +}; + +using node_variant = std::variant; + +static std::size_t hash_node(const node_variant& nv) +{ + return std::visit([](const auto& x) { return x.hash(); }, nv); } -static expr_ptr make_add(const expr_ptr& a, const expr_ptr& b) +struct expr::impl { - auto pa = extract_add(a); - auto pb = extract_add(b); + node_variant node; + std::vector children; + bool raw_flag = false; + std::size_t cached_hash = 0; +}; - int64_t constant = pa.constant + pb.constant; - term_map terms = std::move(pa.terms); - for(const auto& [term, coeff] : pb.terms) - terms[term] += coeff; +const expr::impl* expr::get_pimpl() const { return pimpl.get(); } - return build_add(constant, std::move(terms)); +static const node_variant& get_node(const expr& e) +{ + assert(e.get_pimpl() != nullptr); + return e.get_pimpl()->node; } -static expr_ptr make_neg(const expr_ptr& a) +static std::string get_sym_name(const node_variant& nv) +{ + return std::visit(overloaded{[](const variable_node& n) { return n.name; }, + [](const op_node& n) -> std::string { return n.op->name; }, + [](const literal_node&) -> std::string { return ""; }}, + nv); +} + +static std::string get_node_name(const node_variant& nv) +{ + return std::visit(overloaded{[](const literal_node&) -> std::string { return "literal"; }, + [](const variable_node&) -> std::string { return "variable"; }, + [](const op_node& n) -> std::string { return n.op->name; }}, + nv); +} + +static scalar get_scalar_or(const node_variant& nv, scalar s) { return std::visit( - overloaded{ - [](const integer_data& d) -> expr_ptr { return make_integer(-d.value); }, - [](const add_data& d) -> expr_ptr { - term_map negated; - for(const auto& [term, coeff] : d.terms) - negated[term] = -coeff; - return build_add(-d.constant, std::move(negated)); - }, - [](const mul_data& d) -> expr_ptr { return build_mul(-d.coefficient, d.factors); }, - [&](const auto&) -> expr_ptr { return make_mul(make_integer(-1), a); }}, - a->data); + overloaded{[](const literal_node& n) { return n.val; }, [&](const auto&) { return s; }}, + nv); +} + +template +std::shared_ptr expr::make_impl(Node node, std::vector children) +{ + bool raw = + std::any_of(children.begin(), children.end(), [](const expr& e) { return e.is_raw(); }); + if constexpr(std::is_same{}) + raw = raw or (not node.name.empty() and node.name[0] == '_'); + auto h = hash_node(node); + hash_range(h, children.begin(), children.end()); + return std::make_shared( + impl{node_variant{std::move(node)}, std::move(children), raw, h}); } -static expr_ptr make_sub(const expr_ptr& a, const expr_ptr& b) { return make_add(a, make_neg(b)); } +template std::shared_ptr expr::make_impl(literal_node, std::vector); +template std::shared_ptr expr::make_impl(variable_node, std::vector); +template std::shared_ptr expr::make_impl(op_node, std::vector); + +expr lit(scalar v) { return expr(literal_node{v}); } -struct mul_parts +expr var(std::string name) { - int64_t coefficient = 1; - factor_map factors; -}; + if(name.empty()) + MIGRAPHX_THROW("Variable name must not be empty"); + return expr(variable_node{std::move(name), {}, {}}); +} -static mul_parts extract_mul(const expr_ptr& e) +expr var(std::string name, interval constraint, std::set optimals) { - return std::visit( - overloaded{[](const integer_data& d) -> mul_parts { return {d.value, {}}; }, - [](const mul_data& d) -> mul_parts { return {d.coefficient, d.factors}; }, - [&](const auto&) -> mul_parts { return {1, {{e, 1}}}; }}, - e->data); + if(name.empty()) + MIGRAPHX_THROW("Variable name must not be empty"); + return expr(variable_node{std::move(name), {constraint}, std::move(optimals)}); } -static expr_ptr build_mul(int64_t coefficient, factor_map factors) +expr arg(expr x) { return x; } + +static bool expr_children_less(const std::vector& a, const std::vector& b); + +static auto expr_compare_key(const expr& e) { - if(coefficient == 0) - return make_integer(0); - for(auto it = factors.begin(); it != factors.end();) - { - if(it->second == 0) - it = factors.erase(it); - else - ++it; - } - if(factors.empty()) - return make_integer(coefficient); - if(coefficient == 1 and factors.size() == 1) - { - auto& [base, exp] = *factors.begin(); - if(exp == 1) - return base; - } - return make_node(mul_data{coefficient, std::move(factors)}); + const auto& n = get_node(e); + auto children = make_ordered_as(std::cref(e.children()), &expr_children_less); + return std::make_tuple( + n.index(), get_scalar_or(n, scalar{int64_t{0}}), get_sym_name(n), children); +} + +static bool expr_children_less(const std::vector& a, const std::vector& b) +{ + return std::lexicographical_compare( + a.begin(), a.end(), b.begin(), b.end(), by(std::less<>{}, &expr_compare_key)); } -static expr_ptr make_mul(const expr_ptr& a, const expr_ptr& b) +static bool is_pvar(const expr& e) { - if(holds(a) and holds(b)) + const auto* v = std::get_if(&get_node(e)); + return v != nullptr and not v->name.empty() and v->name[0] == '_'; +} + +static bool match_expr(const expr& pattern, const expr& e, std::unordered_map& bindings) +{ + if(is_pvar(pattern)) { - int64_t n = get_integer(a); - if(n == 0) - return make_integer(0); - if(n == 1) - return b; - const auto& d = get_add(b); - term_map scaled; - for(const auto& [term, coeff] : d.terms) - scaled[term] = coeff * n; - return build_add(d.constant * n, std::move(scaled)); + auto it = bindings.find(pattern); + if(it != bindings.end()) + return it->second == e; + bindings.emplace(pattern, e); + return true; } - if(holds(b) and holds(a)) - return make_mul(b, a); + if(get_node(pattern).index() != get_node(e).index()) + return false; + return std::visit(overloaded{[&](const literal_node& pl) { + return pl.val == std::get(get_node(e)).val; + }, + [&](const variable_node& pv) { + const auto& ev = std::get(get_node(e)); + return pv.name == ev.name and pv.constraints == ev.constraints; + }, + [&](const op_node& po) { + const auto& eo = std::get(get_node(e)); + if(po.op->name != eo.op->name) + return false; + if(pattern.children().size() != e.children().size()) + return false; + return std::equal(pattern.children().begin(), + pattern.children().end(), + e.children().begin(), + [&](const expr& p, const expr& c) { + return match_expr(p, c, bindings); + }); + }}, + get_node(pattern)); +} - auto pa = extract_mul(a); - auto pb = extract_mul(b); +static bool is_zero(const scalar& v) { return v == scalar{int64_t{0}} or v == scalar{0.0}; } - int64_t coefficient = pa.coefficient * pb.coefficient; - if(coefficient == 0) - return make_integer(0); +static bool is_one(const scalar& v) { return v == scalar{int64_t{1}} or v == scalar{1.0}; } - factor_map factors = std::move(pa.factors); - for(const auto& [base, exp] : pb.factors) - factors[base] += exp; +struct term +{ + scalar coeff; + std::vector bases; +}; - return build_mul(coefficient, std::move(factors)); +static term extract_term(const expr& e) +{ + if(e.name() == "literal") + { + const auto* n = std::get_if(&get_node(e)); + return {n->val, {}}; + } + if(e.name() == "*") + { + return std::accumulate(e.children().begin(), + e.children().end(), + term{scalar{int64_t{1}}, {}}, + [](term t, const expr& child) { + if(child.name() == "literal") + { + const auto* n = std::get_if(&get_node(child)); + t.coeff = scalar_invoke_common( + [](auto x, auto y) { return x * y; }, t.coeff, n->val); + } + else + { + t.bases.push_back(child); + } + return t; + }); + } + return {scalar{int64_t{1}}, {e}}; } -static expr_ptr try_cancel_single(const mul_data& da, const expr_ptr& b) +static expr build_term(const term& t) { - auto it = da.factors.find(b); - if(it == da.factors.end()) - return nullptr; - factor_map reduced = da.factors; - if(it->second == 1) - reduced.erase(it->first); - else - reduced[it->first] = it->second - 1; - return build_mul(da.coefficient, std::move(reduced)); + if(t.bases.empty()) + return lit(t.coeff); + auto base_product = std::accumulate(t.bases.begin() + 1, + t.bases.end(), + t.bases.front(), + [](expr acc, const expr& b) { return std::move(acc) * b; }); + if(is_one(t.coeff)) + return base_product; + return lit(t.coeff) * base_product; } -static expr_ptr try_div_int_over_add(const add_data& d, int64_t den) +static expr normalize_add(const op_def* op, std::vector args) { - if(d.constant % den != 0) - return nullptr; - bool all_divisible = std::all_of( - d.terms.begin(), d.terms.end(), [&](const auto& p) { return p.second % den == 0; }); - if(not all_divisible) - return nullptr; - term_map divided = d.terms; - for(auto& [base, coeff] : divided) - coeff /= den; - return build_add(d.constant / den, std::move(divided)); + std::vector terms; + terms.reserve(args.size()); + std::transform(args.begin(), args.end(), std::back_inserter(terms), extract_term); + + std::stable_sort(terms.begin(), terms.end(), [](const term& a, const term& b) { + return expr_children_less(a.bases, b.bases); + }); + + // Merge adjacent terms with matching bases + std::vector merged; + group_unique( + terms.begin(), + terms.end(), + [&](auto first, auto last) { + merged.push_back( + std::accumulate(std::next(first), last, *first, [](term acc, const term& t) { + acc.coeff = scalar_invoke_common( + [](auto x, auto y) { return x + y; }, acc.coeff, t.coeff); + return acc; + })); + }, + [](const term& a, const term& b) { return a.bases == b.bases; }); + + merged.erase(std::remove_if( + merged.begin(), merged.end(), [](const term& t) { return is_zero(t.coeff); }), + merged.end()); + + if(merged.empty()) + return lit(int64_t{0}); + if(merged.size() == 1) + return build_term(merged[0]); + + std::vector result_children; + result_children.reserve(merged.size()); + std::transform(merged.begin(), merged.end(), std::back_inserter(result_children), build_term); + std::stable_sort( + result_children.begin(), result_children.end(), by(std::greater<>{}, &expr_compare_key)); + return expr(op_node{op}, std::move(result_children)); } -static expr_ptr -try_cancel_factors(const mul_data& da, const mul_data& db, const expr_ptr& a, const expr_ptr& b) +static expr normalize_mul(const op_def* op, std::vector args) { - factor_map reduced_num = da.factors; - factor_map reduced_den = db.factors; - for(auto it_den = reduced_den.begin(); it_den != reduced_den.end();) + auto partition_it = std::stable_partition( + args.begin(), args.end(), [](const expr& a) { return a.name() != "literal"; }); + auto coeff = transform_accumulate( + partition_it, + args.end(), + scalar{int64_t{1}}, + [](scalar acc, scalar v) { + return scalar_invoke_common([](auto x, auto y) { return x * y; }, acc, v); + }, + [](const expr& a) { return std::get_if(&get_node(a))->val; }); + + if(is_zero(coeff)) + return lit(coeff); + + std::vector factors; + if(not is_one(coeff)) + factors.push_back(lit(coeff)); + factors.insert(factors.end(), + std::make_move_iterator(args.begin()), + std::make_move_iterator(partition_it)); + + auto it = + std::find_if(factors.begin(), factors.end(), [](const expr& e) { return e.name() == "+"; }); + if(it != factors.end()) { - auto it_num = reduced_num.find(it_den->first); - if(it_num == reduced_num.end()) - { - ++it_den; - continue; - } - int64_t cancel = std::min(it_num->second, it_den->second); - it_num->second -= cancel; - it_den->second -= cancel; - if(it_num->second == 0) - reduced_num.erase(it_num); - if(it_den->second == 0) - it_den = reduced_den.erase(it_den); - else - ++it_den; + auto plus_children = it->children(); + std::vector other_factors; + std::copy_if(factors.begin(), + factors.end(), + std::back_inserter(other_factors), + [&](const expr& f) { return &f != &*it; }); + std::vector distributed; + distributed.reserve(plus_children.size()); + std::transform(plus_children.begin(), + plus_children.end(), + std::back_inserter(distributed), + [&](const expr& pc) { + return std::accumulate( + other_factors.begin(), + other_factors.end(), + pc, + [](expr product, const expr& f) { return std::move(product) * f; }); + }); + return std::accumulate(distributed.begin() + 1, + distributed.end(), + distributed.front(), + [](expr acc, const expr& e) { return std::move(acc) + e; }); } - auto new_num = build_mul(da.coefficient, std::move(reduced_num)); - auto new_den = build_mul(db.coefficient, std::move(reduced_den)); - if(not expr_equal(new_num, a) or not expr_equal(new_den, b)) - return make_trunc_div(new_num, new_den); - return nullptr; + + if(factors.empty()) + return lit(coeff); + if(factors.size() == 1) + return factors[0]; + std::stable_sort(factors.begin(), factors.end(), by(std::less<>{}, &expr_compare_key)); + return expr(op_node{op}, std::move(factors)); } -static expr_ptr make_trunc_div(const expr_ptr& a, const expr_ptr& b) +static expr normalize_div(const op_def* op, std::vector args) { - if(holds(a) and get_integer(a) == 0) - return a; + const auto& num = args[0]; + const auto& den = args[1]; - if(expr_equal(a, b)) - return make_integer(1); + // 0 / x == 0 + if(num.name() == "literal") + { + const auto* n = std::get_if(&get_node(num)); + if(is_zero(n->val)) + return lit(n->val); + } - if(holds(b)) + // x / 1 == x + if(den.name() == "literal") { - int64_t den = get_integer(b); - if(den == 0) - MIGRAPHX_THROW("symbolic: division by zero"); - if(den == 1) - return a; - if(holds(a)) - return make_integer(get_integer(a) / den); - if(holds(a)) - { - const auto& d = get_mul(a); - if(d.coefficient % den == 0) - return build_mul(d.coefficient / den, d.factors); - } - if(holds(a)) + const auto* n = std::get_if(&get_node(den)); + if(is_one(n->val)) + return num; + } + + // x / x == 1 + if(num == den) + return lit(int64_t{1}); + + // Factor cancellation between products + auto num_term = extract_term(num); + auto den_term = extract_term(den); + + // Cancel common symbolic bases using set_difference on sorted ranges + auto num_bases = num_term.bases; + auto den_bases = den_term.bases; + auto cmp = by(std::less<>{}, &expr_compare_key); + std::stable_sort(num_bases.begin(), num_bases.end(), cmp); + std::stable_sort(den_bases.begin(), den_bases.end(), cmp); + + std::vector remaining_num_bases; + std::set_difference(num_bases.begin(), + num_bases.end(), + den_bases.begin(), + den_bases.end(), + std::back_inserter(remaining_num_bases), + cmp); + std::vector remaining_den_bases; + std::set_difference(den_bases.begin(), + den_bases.end(), + num_bases.begin(), + num_bases.end(), + std::back_inserter(remaining_den_bases), + cmp); + + bool bases_changed = remaining_num_bases.size() != num_term.bases.size() or + remaining_den_bases.size() != den_term.bases.size(); + + // Cancel GCD of integer coefficients + auto num_coeff = num_term.coeff; + auto den_coeff = den_term.coeff; + scalar new_num_coeff = num_coeff; + scalar new_den_coeff = den_coeff; + + if(std::holds_alternative(num_coeff) and std::holds_alternative(den_coeff)) + { + auto nc = std::get(num_coeff); + auto dc = std::get(den_coeff); + if(dc != 0) { - auto r = try_div_int_over_add(get_add(a), den); - if(r != nullptr) - return r; + auto g = std::gcd(std::abs(nc), std::abs(dc)); + if(g > 1) + { + new_num_coeff = int64_t{nc / g}; + new_den_coeff = int64_t{dc / g}; + bases_changed = true; + } } } - if(holds(a)) + if(bases_changed) { - const auto& da = get_mul(a); - if(holds(b)) + expr new_num = build_term({new_num_coeff, remaining_num_bases}); + expr new_den = build_term({new_den_coeff, remaining_den_bases}); + + if(new_den.name() == "literal") { - auto r = try_cancel_factors(da, get_mul(b), a, b); - if(r != nullptr) - return r; + const auto* n = std::get_if(&get_node(new_den)); + if(is_one(n->val)) + return new_num; } - else + return new_num / new_den; + } + + // Distribute over sum: (a*k + b*k) / k when all terms are divisible + if(num.name() == "+" and den.name() == "literal") + { + const auto* d = std::get_if(&get_node(den)); + if(std::holds_alternative(d->val)) { - auto r = try_cancel_single(da, b); - if(r != nullptr) - return r; + auto dv = std::get(d->val); + bool all_divisible = + std::all_of(num.children().begin(), num.children().end(), [&](const expr& child) { + auto t = extract_term(child); + if(not std::holds_alternative(t.coeff)) + return false; + return std::get(t.coeff) % dv == 0; + }); + if(all_divisible) + { + std::vector divided; + divided.reserve(num.children().size()); + std::transform(num.children().begin(), + num.children().end(), + std::back_inserter(divided), + [&](const expr& child) { return child / den; }); + return std::accumulate(divided.begin() + 1, + divided.end(), + divided.front(), + [](expr acc, const expr& e) { return std::move(acc) + e; }); + } } } - return make_node(tdiv_data{a, b}); + return expr(op_node{op}, std::move(args)); } -// =================================================================== -// Section 6: Substitution and evaluation -// =================================================================== +static expr normalize_impl(const op_def* op, std::vector args) +{ + if(std::any_of(args.begin(), args.end(), [](const expr& e) { return e.empty(); })) + { + return {}; + } + if(std::all_of(args.begin(), args.end(), [](const expr& e) { return e.name() == "literal"; })) + { + auto e = expr(op_node{op}, std::move(args)); + return lit(e.eval({})); + } + if(contains({"/", "%"}, op->name) and args.at(1) == lit(0)) + MIGRAPHX_THROW("Division by zero"); + if(op->name == "+") + return normalize_add(op, std::move(args)); + if(op->name == "*") + return normalize_mul(op, std::move(args)); + if(op->name == "/") + return normalize_div(op, std::move(args)); + return expr(op_node{op}, std::move(args)); +} -using binding_map = std::map; -using subs_map = std::map; +static const std::vector& get_rewrite_rules() +{ + static const std::vector rules = [] { + auto _1 = pvar(1); // NOLINT(readability-identifier-naming) + auto _2 = pvar(2); // NOLINT(readability-identifier-naming) + return std::vector{ + sqrt(_1 * _2) >> sqrt(_1) * sqrt(_2), + sqrt(_1 / _2) >> sqrt(_1) / sqrt(_2), + log(exp(_1)) >> _1, + exp(log(_1)) >> _1, + }; + }(); + return rules; +} -static expr_ptr substitute(const expr_ptr& e, const subs_map& bindings) +static expr apply_rewrite_rules(const expr& e) { - return std::visit(overloaded{[&](const integer_data&) -> expr_ptr { return e; }, - [&](const symbol_data&) -> expr_ptr { - auto it = bindings.find(e); - if(it != bindings.end()) - return it->second; - return e; - }, - [&](const add_data& d) -> expr_ptr { - expr_ptr result = make_integer(d.constant); - for(const auto& [term, coeff] : d.terms) - { - auto st = substitute(term, bindings); - result = - make_add(result, make_mul(make_integer(coeff), st)); - } - return result; - }, - [&](const mul_data& d) -> expr_ptr { - expr_ptr result = make_integer(d.coefficient); - for(const auto& [base, exp] : d.factors) - { - auto sb = substitute(base, bindings); - for(int64_t i = 0; i < exp; ++i) - result = make_mul(result, sb); - } - return result; - }, - [&](const tdiv_data& d) -> expr_ptr { - auto sn = substitute(d.numerator, bindings); - auto sd = substitute(d.denominator, bindings); - return make_trunc_div(sn, sd); - }}, - e->data); + if(e.empty()) + return e; + for(const auto& rule : get_rewrite_rules()) + { + std::unordered_map bindings; + if(match_expr(rule.pattern, e, bindings)) + return rule.replacement.subs(bindings); + } + return e; } -static int64_t eval_direct(const expr_ptr& e, const binding_map& bindings) +static expr normalize_expr(const op_def* op, std::vector args) { - return std::visit(overloaded{[](const integer_data& d) -> int64_t { return d.value; }, - [&](const symbol_data& d) -> int64_t { - auto it = bindings.find(e); - if(it != bindings.end()) - return it->second; - MIGRAPHX_THROW("sym::expr::eval_uint: unbound symbol '" + - d.name + "'"); - }, - [&](const add_data& d) -> int64_t { - int64_t sum = d.constant; - for(const auto& [term, coeff] : d.terms) - sum += coeff * eval_direct(term, bindings); - return sum; - }, - [&](const mul_data& d) -> int64_t { - int64_t prod = d.coefficient; - for(const auto& [base, exp] : d.factors) - { - int64_t val = eval_direct(base, bindings); - for(int64_t i = 0; i < exp; ++i) - prod *= val; - } - return prod; - }, - [&](const tdiv_data& d) -> int64_t { - auto denom = eval_direct(d.denominator, bindings); - if(denom == 0) - MIGRAPHX_THROW("sym::expr::eval_uint: division by zero"); - return eval_direct(d.numerator, bindings) / denom; - }}, - e->data); + return apply_rewrite_rules(normalize_impl(op, std::move(args))); } -// =================================================================== -// Section 7: Pretty-printer -// =================================================================== +static std::vector flatten_args(const std::string& op_name, std::vector args) +{ + std::vector flat_args; + std::transform(args.begin(), args.end(), join_back_inserter(flat_args), [&](const expr& a) { + if(a.name() == op_name) + return a.children(); + return std::vector{a}; + }); + return flat_args; +} -enum +static expr fold_associative_args(expr e) { - prec_atom = 100, - prec_mul = 50, - prec_add = 40 -}; + if(e.empty()) + return e; + if(not std::holds_alternative(get_node(e))) + return e; + if(e.children().size() <= 2) + return e; + const auto& op_n = std::get(get_node(e)); + auto children = std::accumulate(e.children().begin() + 1, + e.children().end(), + std::vector{e.children().front()}, + [&](std::vector c, expr x) { + if(std::holds_alternative(get_node(x)) and + std::holds_alternative(get_node(c.back()))) + { + auto d = expr(op_n, {c.back(), x}); + c.back() = lit(d.eval({})); + } + else + { + c.push_back(std::move(x)); + } + return c; + }); + return expr(op_n, std::move(children)); +} -static std::string print_expr(const expr_ptr& e, int parent_prec = 0); +expr call_op(const op_def* op, std::vector args) +{ + if(std::any_of(args.begin(), args.end(), [](const expr& e) { return e.is_raw(); })) + return expr(op_node{op}, std::move(args)); + if(op->associative) + args = flatten_args(op->name, std::move(args)); + auto result = normalize_expr(op, std::move(args)); + if(op->associative) + result = fold_associative_args(std::move(result)); + return result; +} -static std::string print_add(const add_data& d, int parent_prec) +template +static auto call_associative(std::string name, Eval eval, EvalInterval eval_interval) { - std::ostringstream os; - bool first = true; - for(const auto& [term, coeff] : d.terms) - { - if(first) - { - if(coeff == -1) - os << "-" << print_expr(term, prec_add); - else if(coeff == 1) - os << print_expr(term, prec_add); - else - os << coeff << "*" << print_expr(term, prec_mul + 1); - first = false; - } - else - { - if(coeff == 1) - os << " + " << print_expr(term, prec_add); - else if(coeff == -1) - os << " - " << print_expr(term, prec_add); - else if(coeff > 0) - os << " + " << coeff << "*" << print_expr(term, prec_mul + 1); + return [=](auto... es) { + auto eval1 = [=](const std::vector& args) { + return std::accumulate(args.begin() + 1, + args.end(), + args.front(), + [=](const scalar& acc, const scalar& arg) { + return scalar_invoke_common(eval, acc, arg); + }); + }; + auto eval_interval1 = [=](const std::vector& args) { + return std::accumulate( + args.begin() + 1, + args.end(), + args.front(), + [=](const interval& acc, const interval& arg) { return eval_interval(acc, arg); }); + }; + return call_op(name, eval1, eval_interval1, {arg(std::move(es))...}, true); + }; +} + +template +static auto call_associative(std::string name, Eval eval) +{ + return call_associative(std::move(name), eval, eval); +} + +expr operator+(expr ex, expr ey) +{ + return call_associative("+", [](auto x, auto y) { return x + y; })(std::move(ex), + std::move(ey)); +} + +expr operator-(expr ex, expr ey) { return std::move(ex) + (-std::move(ey)); } + +expr operator*(expr ex, expr ey) +{ + return call_associative("*", [](auto x, auto y) { return x * y; })(std::move(ex), + std::move(ey)); +} + +expr operator/(expr ex, expr ey) +{ + return call( + "/", + [](auto x, auto y) { + if(float_equal(y, 0)) + MIGRAPHX_THROW("Division by zero"); + return x / y; + }, + [](interval x, interval y) { return x / y; })(std::move(ex), std::move(ey)); +} + +expr operator%(expr ex, expr ey) +{ + return call( + "%", + [](auto x, auto y) { + if(float_equal(y, 0)) + MIGRAPHX_THROW("Division by zero"); + if constexpr(std::is_integral{} and std::is_integral{}) + return x % y; else - os << " - " << (-coeff) << "*" << print_expr(term, prec_mul + 1); - } - } - if(d.constant > 0) - os << " + " << d.constant; - else if(d.constant < 0) - os << " - " << (-d.constant); - std::string s = os.str(); - if(parent_prec > prec_add) - return "(" + s + ")"; - return s; + return std::fmod(static_cast(x), static_cast(y)); + }, + [](interval x, interval y) { return x % y; })(std::move(ex), std::move(ey)); } -static std::string print_mul(const mul_data& d, int parent_prec) +expr operator-(expr e) { return lit(-1) * std::move(e); } + +bool operator==(const expr& a, const expr& b) { - std::ostringstream os; - bool first = true; - if(d.coefficient == -1) - { - os << "-"; - } - else if(d.coefficient != 1) - { - os << d.coefficient; - first = false; - } - for(const auto& [base, exp] : d.factors) - { - for(int64_t i = 0; i < exp; ++i) - { - if(not first) - os << "*"; - os << print_expr(base, prec_mul + 1); - first = false; - } - } - std::string raw = os.str(); - if(parent_prec > prec_mul) - return "(" + raw + ")"; - return raw; + if(a.pimpl == b.pimpl) + return true; + if(not a.pimpl or not b.pimpl) + return false; + if(a.pimpl->cached_hash != b.pimpl->cached_hash) + return false; + return get_node(a) == get_node(b) and a.children() == b.children(); } -static std::string print_expr(const expr_ptr& e, int parent_prec) +bool operator!=(const expr& a, const expr& b) { return not(a == b); } + +std::ostream& operator<<(std::ostream& os, const expr& e) { return os << e.to_string(); } + +bool expr::empty() const { return not pimpl; } + +std::size_t expr::hash() const { - return std::visit( - overloaded{[](const integer_data& d) -> std::string { return std::to_string(d.value); }, - [](const symbol_data& d) -> std::string { return d.name; }, - [&](const add_data& d) -> std::string { return print_add(d, parent_prec); }, - [&](const mul_data& d) -> std::string { return print_mul(d, parent_prec); }, - [&](const tdiv_data& d) -> std::string { - std::string s = print_expr(d.numerator, prec_mul + 1) + "/" + - print_expr(d.denominator, prec_mul + 1); - if(parent_prec > prec_mul) - return "(" + s + ")"; - return s; - }}, - e->data); + if(not pimpl) + return 0; + return pimpl->cached_hash; } -// =================================================================== -// Section 8: Recursive descent parser -// =================================================================== +static scalar generic_eval_auto_apply(const op_node& op, const std::vector& args) +{ + return op.op->eval(args); +} -static void skip_ws(const char*& p) +static interval generic_eval_auto_apply(const op_node& op, const std::vector& args) { - while(*p != '\0' and std::isspace(static_cast(*p)) != 0) - ++p; + return op.op->eval_interval(args); } -static expr_ptr parse_expr(const char*& p); -static expr_ptr parse_term(const char*& p); -static expr_ptr parse_unary(const char*& p); -static expr_ptr parse_primary(const char*& p); +static expr generic_eval_auto_apply(const op_node& op, const std::vector& args) +{ + return call_op(op.op, args); +} -static expr_ptr parse_primary(const char*& p) +template +static R generic_eval(const expr& e, const Replace& replace, const Apply& apply) { - skip_ws(p); - if(std::isdigit(static_cast(*p)) != 0) - { - int64_t n = 0; - while(std::isdigit(static_cast(*p)) != 0) - { - n = n * 10 + (*p - '0'); - ++p; - } - return make_integer(n); - } - if(std::isalpha(static_cast(*p)) != 0 or *p == '_') - { - std::string name; - while(std::isalnum(static_cast(*p)) != 0 or *p == '_') - { - name += *p; - ++p; - } - return make_symbol(name); - } - if(*p == '(') - { - ++p; - auto inner = parse_expr(p); - skip_ws(p); - if(*p != ')') - MIGRAPHX_THROW("symbolic parser: expected ')'"); - ++p; - return inner; - } - MIGRAPHX_THROW("symbolic parser: unexpected character '" + std::string(1, *p) + "'"); + auto r = replace(e); + if(r) + return *r; + const auto& children = e.children(); + std::vector args; + args.reserve(children.size()); + std::transform(children.begin(), + children.end(), + std::back_inserter(args), + [&](const expr& child) { return generic_eval(child, replace, apply); }); + return apply(std::get(get_node(e)), std::move(args)); } -static expr_ptr parse_unary(const char*& p) +template +static R generic_eval(const expr& e, const Replace& replace) { - skip_ws(p); - if(*p == '-') - { - ++p; - return make_neg(parse_unary(p)); - } - return parse_primary(p); + return generic_eval(e, replace, MIGRAPHX_LIFT(generic_eval_auto_apply)); } -static expr_ptr parse_term(const char*& p) +std::size_t expr::eval_uint(const std::unordered_map& symbol_map) const { - auto left = parse_unary(p); - for(;;) - { - skip_ws(p); - if(*p == '*') - { - ++p; - left = make_mul(left, parse_unary(p)); - } - else if(*p == '/') - { - ++p; - left = make_trunc_div(left, parse_unary(p)); - } - else - break; - } - return left; + return to(generic_eval(*this, [&](const expr& e) -> std::optional { + auto it = symbol_map.find(e); + if(it != symbol_map.end()) + return make_scalar(it->second); + return std::visit( + overloaded{[](const literal_node& n) -> std::optional { return n.val; }, + [](const auto&) -> std::optional { return std::nullopt; }}, + get_node(e)); + })); } -static expr_ptr parse_expr(const char*& p) +expr expr::subs(const std::unordered_map& symbol_map) const { - auto left = parse_term(p); - for(;;) - { - skip_ws(p); - if(*p == '+') - { - ++p; - left = make_add(left, parse_term(p)); - } - else if(*p == '-') - { - ++p; - left = make_sub(left, parse_term(p)); - } - else - break; - } - return left; + return generic_eval(*this, [&](const expr& e) -> std::optional { + auto it = symbol_map.find(e); + if(it != symbol_map.end()) + return it->second; + if(e.empty()) + return e; + return std::visit( + overloaded{[&](const literal_node&) -> std::optional { return e; }, + [&](const variable_node&) -> std::optional { return e; }, + [](const op_node&) -> std::optional { return std::nullopt; }}, + get_node(e)); + }); } -static expr_ptr parse_string(const std::string& s) +expr sin(expr e) { - const char* p = s.c_str(); - auto result = parse_expr(p); - skip_ws(p); - if(*p != '\0') - MIGRAPHX_THROW("symbolic parser: unexpected trailing characters: '" + std::string(p) + "'"); - return result; + return call("sin", MIGRAPHX_LIFT(std::sin), [](interval x) { return sin(x); })(std::move(e)); } -// =================================================================== -// Section 9: sym::expr public API wrapper -// =================================================================== +expr cos(expr e) +{ + return call("cos", MIGRAPHX_LIFT(std::cos), [](interval x) { return cos(x); })(std::move(e)); +} -namespace sym { +expr tan(expr e) +{ + return call("tan", MIGRAPHX_LIFT(std::tan), [](interval x) { return tan(x); })(std::move(e)); +} -struct expr::impl +expr exp(expr e) { - expr_ptr node; + return call("exp", MIGRAPHX_LIFT(std::exp), [](interval x) { return exp(x); })(std::move(e)); +} - impl() : node(make_integer(0)) {} - explicit impl(expr_ptr e) : node(std::move(e)) {} -}; +expr log(expr e) +{ + return call("log", MIGRAPHX_LIFT(std::log), [](interval x) { return log(x); })(std::move(e)); +} + +expr sqrt(expr e) +{ + return call("sqrt", MIGRAPHX_LIFT(std::sqrt), [](interval x) { return sqrt(x); })(std::move(e)); +} -expr::expr() = default; +expr abs(expr e) +{ + return call( + "abs", [](auto x) { return x < 0 ? -x : x; }, [](interval x) { return abs(x); })( + std::move(e)); +} -expr::expr(std::shared_ptr pi) : p(std::move(pi)) {} +expr floor(expr e) +{ + return call("floor", MIGRAPHX_LIFT(std::floor), [](interval x) { return floor(x); })( + std::move(e)); +} -bool expr::empty() const { return p == nullptr; } +expr ceil(expr e) +{ + return call("ceil", MIGRAPHX_LIFT(std::ceil), [](interval x) { return ceil(x); })(std::move(e)); +} -std::size_t expr::hash() const +expr pow(expr x, expr y) { - if(empty()) - return 0; - return p->node->cached_hash; + return call("pow", MIGRAPHX_LIFT(std::pow), [](interval a, interval b) { return pow(a, b); })( + std::move(x), std::move(y)); } -std::string expr::to_string() const +expr min(expr x, expr y) { - if(empty()) - return {}; - return print_expr(p->node); + return call( + "min", + [](auto a, auto b) { return a < b ? a : b; }, + [](interval a, interval b) { return min(a, b); })(std::move(x), std::move(y)); } -std::size_t expr::eval_uint(const std::unordered_map& symbol_map) const +expr max(expr x, expr y) +{ + return call( + "max", + [](auto a, auto b) { return a > b ? a : b; }, + [](interval a, interval b) { return max(a, b); })(std::move(x), std::move(y)); +} + +std::string expr::name() const { if(empty()) - return 0; - binding_map bindings; - for(const auto& [k, v] : symbol_map) - { - if(k.empty() or not holds(k.p->node)) - MIGRAPHX_THROW("sym::expr::eval_uint: map key '" + k.to_string() + "' is not a symbol"); - bindings[k.p->node] = v; - } - auto v = eval_direct(p->node, bindings); - if(v < 0) - MIGRAPHX_THROW("sym::expr::eval_uint: expression evaluated to negative value"); - return v; + return ""; + return get_node_name(get_node(*this)); } -expr expr::subs(const std::unordered_map& symbol_map) const +bool expr::is_raw() const { return pimpl and pimpl->raw_flag; } + +const std::vector& expr::children() const { + static const std::vector empty_children = {}; if(empty()) - return {}; - subs_map bindings; - for(const auto& [k, v] : symbol_map) - { - if(k.empty() or not holds(k.p->node)) - MIGRAPHX_THROW("sym::expr::subs: map key '" + k.to_string() + "' is not a symbol"); - if(v.empty()) - MIGRAPHX_THROW("sym::expr::subs: substitution value must not be empty"); - bindings[k.p->node] = v.p->node; - } - return {std::make_shared(substitute(p->node, bindings))}; + return empty_children; + return pimpl->children; } -expr operator+(const expr& a, const expr& b) +scalar expr::eval(const std::unordered_map& vars) const { - if(a.empty() or b.empty()) - return {}; - return {std::make_shared(make_add(a.p->node, b.p->node))}; + return generic_eval(*this, [&](const expr& e) -> std::optional { + auto it = vars.find(e); + if(it != vars.end()) + return it->second; + return std::visit( + overloaded{[](const literal_node& n) -> std::optional { return n.val; }, + [](const auto&) -> std::optional { return std::nullopt; }}, + get_node(e)); + }); } -expr operator-(const expr& a, const expr& b) +interval expr::eval_interval(const std::unordered_map& vars) const { - if(a.empty() or b.empty()) - return {}; - return {std::make_shared(make_sub(a.p->node, b.p->node))}; + return generic_eval(*this, [&](const expr& e) -> std::optional { + auto it = vars.find(e); + if(it != vars.end()) + return it->second; + return std::visit( + overloaded{[](const literal_node& n) -> std::optional { + return interval{n.val, n.val}; + }, + [](const variable_node& n) -> std::optional { + if(not n.constraints.empty()) + return n.constraints.front(); + MIGRAPHX_THROW("Variable '" + n.name + "' not found in interval map"); + }, + [](const op_node&) -> std::optional { return std::nullopt; }}, + get_node(e)); + }); } -expr operator*(const expr& a, const expr& b) +struct optimal_sample { - if(a.empty() or b.empty()) - return {}; - return {std::make_shared(make_mul(a.p->node, b.p->node))}; + std::unordered_map bindings; + scalar value; +}; + +// Combine optimal samples from each child of an op_node. +// +// Each sample carries the variable bindings that produced its value. Children +// are folded in one at a time: for every existing (base, value-list) pair and +// every sample from the next child, the combination is kept only when their +// bindings agree on every shared variable. This makes repeated occurrences of +// the same variable "pair up" (e.g. h*h with h in {2,3} yields {4, 9} rather +// than the cross-product {4, 6, 9}), while subtrees that depend on disjoint +// variables take the full cartesian product. Once all children are folded in, +// the op's eval is applied to each surviving value list. +static std::vector combine_optimals(const op_node& op, + std::vector> args) +{ + if(args.empty()) + return {{{}, op.op->eval({})}}; + + std::vector, std::vector>> partial; + partial.reserve(args.front().size()); + std::transform(args.front().begin(), + args.front().end(), + std::back_inserter(partial), + [](const optimal_sample& s) { + return std::make_pair(s.bindings, std::vector{s.value}); + }); + + for(std::size_t i = 1; i < args.size(); ++i) + { + std::vector, std::vector>> next; + for(const auto& base : partial) + { + for(const auto& s : args[i]) + { + bool compat = + std::all_of(s.bindings.begin(), s.bindings.end(), [&](const auto& kv) { + auto it = base.first.find(kv.first); + return it == base.first.end() or it->second == kv.second; + }); + if(not compat) + continue; + auto new_bindings = base.first; + new_bindings.insert(s.bindings.begin(), s.bindings.end()); + auto new_values = base.second; + new_values.push_back(s.value); + next.emplace_back(std::move(new_bindings), std::move(new_values)); + } + } + partial = std::move(next); + } + + std::vector result; + result.reserve(partial.size()); + std::transform(partial.begin(), partial.end(), std::back_inserter(result), [&](auto& p) { + return optimal_sample{std::move(p.first), op.op->eval(p.second)}; + }); + return result; } -expr operator/(const expr& a, const expr& b) +std::set expr::eval_optimals() const { - if(a.empty() or b.empty()) + if(empty()) return {}; - return {std::make_shared(make_trunc_div(a.p->node, b.p->node))}; + auto samples = generic_eval>( + *this, + [](const expr& e) -> std::optional> { + return std::visit( + overloaded{ + [](const literal_node& n) -> std::optional> { + return std::vector{{{}, n.val}}; + }, + [&](const variable_node& n) -> std::optional> { + if(n.optimals.empty()) + MIGRAPHX_THROW("Variable '" + n.name + "' has no optimals to evaluate"); + std::vector samples; + samples.reserve(n.optimals.size()); + std::transform( + n.optimals.begin(), + n.optimals.end(), + std::back_inserter(samples), + [&](const scalar& v) { return optimal_sample{{{e, v}}, v}; }); + return samples; + }, + [](const op_node&) -> std::optional> { + return std::nullopt; + }}, + get_node(e)); + }, + [](const op_node& op, std::vector> args) { + return combine_optimals(op, std::move(args)); + }); + std::set result; + std::transform(samples.begin(), + samples.end(), + std::inserter(result, result.end()), + [](const auto& s) { return s.value; }); + return result; } -bool operator==(const expr& a, const expr& b) +static std::string scalar_to_string(const scalar& v) { - if(a.empty() and b.empty()) - return true; - if(a.empty() != b.empty()) - return false; - return expr_equal(a.p->node, b.p->node); + return std::visit( + [](auto x) -> std::string { + std::ostringstream ss; + ss << x; + return ss.str(); + }, + v); } -bool operator!=(const expr& a, const expr& b) { return not(a == b); } +struct string_prec +{ + std::string str; + int prec = 0; +}; -std::ostream& operator<<(std::ostream& os, const expr& e) +static int op_precedence(const std::string& name) { - if(not e.empty()) - os << print_expr(e.p->node); - return os; + if(name == "+" or name == "-") + return 1; + if(name == "*" or name == "/" or name == "%") + return 2; + return 0; } -expr var(const std::string& name) +static bool is_infix_op(const std::string& name) { return op_precedence(name) > 0; } + +static std::string wrap_if(const string_prec& sp, int parent_prec) { - if(name.empty()) - MIGRAPHX_THROW("sym::var: variable name must not be empty"); - return {std::make_shared(make_symbol(name))}; + if(sp.prec > 0 and sp.prec < parent_prec) + return "(" + sp.str + ")"; + return sp.str; +} + +std::string expr::to_string() const +{ + return generic_eval( + *this, + [](const expr& e) -> std::optional { + if(e.empty()) + return string_prec{}; + return std::visit( + overloaded{[](const literal_node& n) -> std::optional { + return string_prec{scalar_to_string(n.val)}; + }, + [](const variable_node& n) -> std::optional { + return string_prec{n.name}; + }, + [](const op_node&) -> std::optional { + return std::nullopt; + }}, + get_node(e)); + }, + [](const op_node& op, std::vector args) -> string_prec { + int prec = op_precedence(op.op->name); + if(is_infix_op(op.op->name) and args.size() >= 2) + { + // -1*x -> -x + if(op.op->name == "*" and args[0].str == "-1") + { + std::vector strs; + strs.reserve(args.size() - 1); + std::transform(args.begin() + 1, + args.end(), + std::back_inserter(strs), + [&](const string_prec& sp) { return wrap_if(sp, prec); }); + return {"-" + join_strings(strs, "*"), prec}; + } + // x + (-y) -> x - y + if(op.op->name == "+") + { + std::string result = wrap_if(args[0], prec); + std::for_each(args.begin() + 1, args.end(), [&](const string_prec& sp) { + auto s = wrap_if(sp, prec); + if(not s.empty() and s.front() == '-') + result += " - " + s.substr(1); + else + result += " + " + s; + }); + return {result, prec}; + } + std::string delim = prec >= 2 ? op.op->name : " " + op.op->name + " "; + std::vector strs; + strs.reserve(args.size()); + std::transform(args.begin(), + args.end(), + std::back_inserter(strs), + [&](const string_prec& sp) { return wrap_if(sp, prec); }); + return {join_strings(strs, delim), prec}; + } + std::vector strs; + strs.reserve(args.size()); + std::transform(args.begin(), + args.end(), + std::back_inserter(strs), + [](const string_prec& sp) { return sp.str; }); + return {op.op->name + "(" + join_strings(strs, ", ") + ")"}; + }) + .str; } -expr lit(int64_t n) { return {std::make_shared(make_integer(n))}; } +std::string to_string(const expr& e) { return e.to_string(); } -expr parse(const std::string& s) +expr pvar(int id) { return var("_" + std::to_string(id)); } + +static expr simplify_impl(const expr& e, const std::vector& rules); + +static expr apply_rules(const expr& e, const std::vector& rules) { - if(s.find_first_not_of(" \t\n\r") == std::string::npos) - return {}; - return {std::make_shared(parse_string(s))}; + for(const auto& rule : rules) + { + std::unordered_map bindings; + if(match_expr(rule.pattern, e, bindings)) + return simplify_impl(rule.replacement.subs(bindings), rules); + } + return e; } -static value node_to_value(const expr_ptr& e) +static expr simplify_impl(const expr& e, const std::vector& rules) { - return std::visit(overloaded{[](const integer_data& d) -> value { - value r; - r["type"] = "int"; - r["value"] = d.value; - return r; - }, - [](const symbol_data& d) -> value { - value r; - r["type"] = "sym"; - r["name"] = d.name; - return r; - }, - [](const add_data& d) -> value { - value r; - r["type"] = "add"; - r["constant"] = d.constant; - value terms = value::array{}; - for(const auto& [term, coeff] : d.terms) - { - value t; - t["expr"] = node_to_value(term); - t["coeff"] = coeff; - terms.push_back(t); - } - r["terms"] = terms; - return r; - }, - [](const mul_data& d) -> value { - value r; - r["type"] = "mul"; - r["coeff"] = d.coefficient; - value factors = value::array{}; - for(const auto& [base, exp] : d.factors) - { - value f; - f["expr"] = node_to_value(base); - f["exp"] = exp; - factors.push_back(f); - } - r["factors"] = factors; - return r; - }, - [](const tdiv_data& d) -> value { - value r; - r["type"] = "tdiv"; - r["num"] = node_to_value(d.numerator); - r["den"] = node_to_value(d.denominator); - return r; - }}, - e->data); + if(e.children().empty()) + return apply_rules(e, rules); + const auto* op_n = std::get_if(&get_node(e)); + std::vector new_children; + new_children.reserve(e.children().size()); + std::transform(e.children().begin(), + e.children().end(), + std::back_inserter(new_children), + [&](const expr& child) { return simplify_impl(child, rules); }); + return apply_rules(call_op(op_n->op, std::move(new_children)), rules); +} + +expr simplify(const expr& e, const std::vector& rules) +{ + return simplify_impl(e, rules); } -static expr_ptr node_from_value(const value& v) +using sym_parser = parser::simple_string_view_skip_parser; + +static expr parse_expr(sym_parser& p); + +template +struct call_wrapper { - const auto& type = v.at("type").get_string(); - if(type == "int") + F f; + template + auto try_call(rank<1>, Args&&... args) const -> decltype(f(std::forward(args)...)) { - return make_integer(v.at("value").get_int64()); + return f(std::forward(args)...); } - else if(type == "sym") + + template + expr try_call(rank<0>, Args&&... args) const { - return make_symbol(v.at("name").get_string()); + MIGRAPHX_THROW( + (std::string("Function is not callable: ") + ... + (to_string(args) + ", "))); } - else if(type == "add") + + template + static expr visit_size(std::size_t n, G g) { - auto constant = v.at("constant").get_int64(); - term_map terms; - for(const auto& t : v.at("terms")) + switch(n) { - auto term = node_from_value(t.at("expr")); - auto coeff = t.at("coeff").get_int64(); - terms[term] = coeff; + case 0: return g(std::integral_constant{}); + case 1: return g(std::integral_constant{}); + case 2: return g(std::integral_constant{}); + case 3: return g(std::integral_constant{}); + default: MIGRAPHX_THROW("Invalid size: " + std::to_string(n)); } - return build_add(constant, std::move(terms)); } - else if(type == "mul") + + expr operator()(const std::vector& args) const { - auto coefficient = v.at("coeff").get_int64(); - factor_map factors; - for(const auto& f : v.at("factors")) - { - auto base = node_from_value(f.at("expr")); - auto exp = f.at("exp").get_int64(); - factors[base] = exp; - } - return build_mul(coefficient, std::move(factors)); + return visit_size(args.size(), [&](auto n) { + return sequence_c([&](auto... is) { return try_call(rank<1>{}, args[is]...); }); + }); + } +}; + +template +call_wrapper(F) -> call_wrapper; + +template +static auto associative_call_wrapper(F f) +{ + return [=](const std::vector& args) { + if(args.empty()) + MIGRAPHX_THROW("Associative function requires at least one argument"); + return std::accumulate(args.begin() + 1, args.end(), args.front(), f); + }; +} + +static expr call_function(const std::string& name, const std::vector& args) +{ +#define MIGRAPHX_CALL_FUNC(name) \ + { \ + #name, call_wrapper { MIGRAPHX_LIFT(name) } \ } - else if(type == "tdiv") + static const std::unordered_map& args)>> + functions = { + {"+", associative_call_wrapper(std::plus<>{})}, + {"*", associative_call_wrapper(std::multiplies<>{})}, + {"-", call_wrapper{std::minus<>{}}}, + {"/", call_wrapper{std::divides<>{}}}, + {"%", call_wrapper{std::modulus<>{}}}, + MIGRAPHX_CALL_FUNC(pow), + MIGRAPHX_CALL_FUNC(min), + MIGRAPHX_CALL_FUNC(max), + MIGRAPHX_CALL_FUNC(sin), + MIGRAPHX_CALL_FUNC(cos), + MIGRAPHX_CALL_FUNC(tan), + MIGRAPHX_CALL_FUNC(exp), + MIGRAPHX_CALL_FUNC(log), + MIGRAPHX_CALL_FUNC(sqrt), + MIGRAPHX_CALL_FUNC(abs), + MIGRAPHX_CALL_FUNC(floor), + MIGRAPHX_CALL_FUNC(ceil), + }; +#undef MIGRAPHX_CALL_FUNC + return functions.at(name)(args); +} + +static expr parse_number(sym_parser& p) +{ + if((std::isdigit(p.peek_char()) == 0) and p.peek_char() != '.') + return {}; + auto token = p.parse_while([](unsigned char c) { return std::isdigit(c) or c == '.'; }); + bool is_float = token.find('.') != std::string_view::npos; + if(is_float) + return lit(std::stod(std::string(token))); + return lit(std::stoll(std::string(token))); +} + +static expr parse_func_or_var(sym_parser& p) +{ + char c = p.peek_char(); + if((std::isalpha(c) == 0) and c != '_') + return {}; + auto name = p.parse_while([](unsigned char ch) { return std::isalnum(ch) or ch == '_'; }); + std::string sname(name); + if(p.peek_char() != '(') + return var(sname); + p.advance(1); + std::vector args; + if(p.peek_char() != ')') { - auto num = node_from_value(v.at("num")); - auto den = node_from_value(v.at("den")); - return make_trunc_div(num, den); + args.push_back(parse_expr(p)); + while(p.match(std::string_view(","))) + args.push_back(parse_expr(p)); } - MIGRAPHX_THROW("Unknown sym::expr node type: " + type); + p.expect(std::string_view(")")); + return call_function(sname, args); } -value expr::to_value() const +static expr parse_paren_expr(sym_parser& p) { - if(empty()) + if(not p.match(std::string_view("("))) + return {}; + auto e = parse_expr(p); + p.expect(std::string_view(")")); + return e; +} + +static expr parse_primary(sym_parser& p) +{ + return p.first_of(&parse_paren_expr, + &parse_func_or_var, + &parse_number, + [](sym_parser& q) -> expr { MIGRAPHX_THROW(q.error_message("expression")); }); +} + +static expr parse_unary(sym_parser& p) +{ + if(p.match(std::string_view("-"))) + return -parse_unary(p); + return parse_primary(p); +} + +static expr parse_mul_expr(sym_parser& p) +{ + auto left = parse_unary(p); + auto ops = p.repeat([](sym_parser& q) -> std::pair { + auto op = q.first_of(std::string_view("*"), std::string_view("/"), std::string_view("%")); + if(op.empty()) + return {}; + return {op, parse_unary(q)}; + }); + for(auto& [op, rhs] : ops) + left = call_function(std::string(op), {std::move(left), std::move(rhs)}); + return left; +} + +static expr parse_expr(sym_parser& p) +{ + auto left = parse_mul_expr(p); + auto ops = p.repeat([](sym_parser& q) -> std::pair { + auto op = q.first_of(std::string_view("+"), std::string_view("-")); + if(op.empty()) + return {}; + return {op, parse_mul_expr(q)}; + }); + for(auto& [op, rhs] : ops) + left = call_function(std::string(op), {std::move(left), std::move(rhs)}); + return left; +} + +expr parse(const std::string& str) +{ + std::string_view sv(str); + sym_parser p{sv}; + // skip leading whitespace + p.advance(0); + if(p.done()) return {}; - return node_to_value(p->node); + auto result = parse_expr(p); + if(not p.done()) + MIGRAPHX_THROW(p.error_message("end of input")); + return result; +} + +static migraphx::value sym_scalar_to_value(const sym::scalar& sv) +{ + return std::visit([](auto x) -> migraphx::value { return migraphx::to_value(x); }, sv); +} + +static sym::scalar value_to_sym_scalar(const migraphx::value& v) +{ + if(v.is_float()) + return sym::scalar{v.get_float()}; + return sym::scalar{v.get_int64()}; +} + +void migraphx_to_value(migraphx::value& v, const sym::interval& i) +{ + migraphx::value result; + result["min"] = sym_scalar_to_value(i.min); + result["max"] = sym_scalar_to_value(i.max); + v = result; } -void expr::from_value(const value& v) +void migraphx_from_value(const migraphx::value& v, sym::interval& i) +{ + i.min = value_to_sym_scalar(v.at("min")); + i.max = value_to_sym_scalar(v.at("max")); +} + +static migraphx::value expr_to_value(const sym::expr& e) +{ + if(e.empty()) + return {}; + migraphx::value result; + std::visit( + [&](const auto& n) { + using t = std::decay_t; + if constexpr(std::is_same{}) + { + result["type"] = "literal"; + result["value"] = sym_scalar_to_value(n.val); + } + else if constexpr(std::is_same{}) + { + result["type"] = "variable"; + result["name"] = n.name; + if(not n.constraints.empty()) + result["constraints"] = migraphx::to_value(n.constraints); + if(not n.optimals.empty()) + { + migraphx::value opt_vals; + std::transform(n.optimals.begin(), + n.optimals.end(), + std::back_inserter(opt_vals), + [](const scalar& s) { return sym_scalar_to_value(s); }); + result["optimals"] = opt_vals; + } + } + else + { + result["type"] = "op"; + result["name"] = n.op->name; + } + }, + get_node(e)); + const auto& children = e.children(); + if(not children.empty()) + { + std::vector child_vals; + child_vals.reserve(children.size()); + std::transform(children.begin(), + children.end(), + std::back_inserter(child_vals), + [](const sym::expr& c) { return expr_to_value(c); }); + result["children"] = child_vals; + } + return result; +} + +void migraphx_to_value(migraphx::value& v, const sym::expr& e) { v = expr_to_value(e); } + +void migraphx_from_value(const migraphx::value& v, sym::expr& e) { if(v.is_null()) { - *this = expr{}; + e = sym::expr{}; return; } - *this = expr{std::make_shared(node_from_value(v))}; + auto type = v.at("type").get_string(); + if(type == "literal") + { + e = sym::lit(value_to_sym_scalar(v.at("value"))); + } + else if(type == "variable") + { + auto name = v.at("name").get_string(); + std::vector constraints; + if(v.contains("constraints")) + constraints = migraphx::from_value>(v.at("constraints")); + std::set optimals; + if(v.contains("optimals")) + { + const auto& opt_vals = v.at("optimals"); + std::transform(opt_vals.begin(), + opt_vals.end(), + std::inserter(optimals, optimals.end()), + [](const migraphx::value& ov) { return value_to_sym_scalar(ov); }); + } + e = expr(variable_node{std::move(name), std::move(constraints), std::move(optimals)}); + } + else + { + auto name = v.at("name").get_string(); + std::vector children; + if(v.contains("children")) + { + const auto& cv = v.at("children"); + children.reserve(cv.size()); + std::transform( + cv.begin(), cv.end(), std::back_inserter(children), [](const migraphx::value& c) { + return migraphx::from_value(c); + }); + } + e = sym::call_function(name, children); + } } } // namespace sym - } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/test/simple_parser.cpp b/test/simple_parser.cpp new file mode 100644 index 00000000000..d96ae972bd9 --- /dev/null +++ b/test/simple_parser.cpp @@ -0,0 +1,242 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include + +using migraphx::parser::action; +using migraphx::parser::simple_string_view_skip_parser; + +TEST_CASE(parser_peek_and_advance) +{ + std::string_view sv("ab cd"); + simple_string_view_skip_parser p{sv}; + EXPECT(p.peek_char() == 'a'); + p.advance(1); + EXPECT(p.peek_char() == 'b'); + p.advance(1); + EXPECT(p.peek_char() == 'c'); +} + +TEST_CASE(parser_done) +{ + std::string_view sv("x"); + simple_string_view_skip_parser p{sv}; + EXPECT(not p.done()); + p.advance(1); + EXPECT(p.done()); +} + +TEST_CASE(parser_match) +{ + std::string_view sv("hello world"); + simple_string_view_skip_parser p{sv}; + EXPECT(p.match(std::string_view("hello"))); + EXPECT(p.peek_char() == 'w'); + EXPECT(not p.match(std::string_view("xyz"))); + EXPECT(p.peek_char() == 'w'); +} + +TEST_CASE(parser_parse_while) +{ + std::string_view sv("abc 123"); + simple_string_view_skip_parser p{sv}; + auto letters = p.parse_while([](char c) { return std::isalpha(c); }); + EXPECT(letters == "abc"); + auto digits = p.parse_while([](char c) { return std::isdigit(c); }); + EXPECT(digits == "123"); +} + +TEST_CASE(parser_try_parse) +{ + std::string_view sv("hello world"); + simple_string_view_skip_parser p{sv}; + bool matched = p.try_parse([](auto& q) { q.match(std::string_view("hello")); }); + EXPECT(matched); + EXPECT(p.peek_char() == 'w'); + + bool missed = p.try_parse([](auto& q) { q.match(std::string_view("xyz")); }); + EXPECT(not missed); + EXPECT(p.peek_char() == 'w'); +} + +TEST_CASE(parser_first_of) +{ + std::string_view sv("42"); + simple_string_view_skip_parser p{sv}; + auto result = p.first_of( + [](auto& q) -> std::string { + if(not std::isalpha(q.peek_char())) + return {}; + return std::string(q.parse_while([](char c) { return std::isalpha(c); })); + }, + [](auto& q) -> std::string { + if(not std::isdigit(q.peek_char())) + return {}; + return std::string(q.parse_while([](char c) { return std::isdigit(c); })); + }); + EXPECT(result == "42"); +} + +TEST_CASE(parser_first_of_first_match) +{ + std::string_view sv("abc"); + simple_string_view_skip_parser p{sv}; + auto result = p.first_of( + [](auto& q) -> std::string { + if(not std::isalpha(q.peek_char())) + return {}; + return std::string(q.parse_while([](char c) { return std::isalpha(c); })); + }, + [](auto& q) -> std::string { + if(not std::isdigit(q.peek_char())) + return {}; + return std::string(q.parse_while([](char c) { return std::isdigit(c); })); + }); + EXPECT(result == "abc"); +} + +TEST_CASE(parser_repeat) +{ + std::string_view sv("a b c ."); + simple_string_view_skip_parser p{sv}; + auto results = p.repeat([](auto& q) -> std::string { + if(not std::isalpha(q.peek_char())) + return {}; + return std::string(q.parse_while([](char c) { return std::isalpha(c); })); + }); + EXPECT(results.size() == 3); + EXPECT(results[0] == "a"); + EXPECT(results[1] == "b"); + EXPECT(results[2] == "c"); + EXPECT(p.peek_char() == '.'); +} + +TEST_CASE(parser_repeat_empty) +{ + std::string_view sv("123"); + simple_string_view_skip_parser p{sv}; + auto results = p.repeat([](auto& q) -> std::string { + if(not std::isalpha(q.peek_char())) + return {}; + return std::string(q.parse_while([](char c) { return std::isalpha(c); })); + }); + EXPECT(results.empty()); + EXPECT(p.peek_char() == '1'); +} + +TEST_CASE(first_of_views) +{ + std::string_view sv("+ rest"); + simple_string_view_skip_parser p{sv}; + auto result = p.first_of(std::string_view("*"), std::string_view("+"), std::string_view("/")); + EXPECT(result == "+"); + EXPECT(p.peek_char() == 'r'); +} + +TEST_CASE(first_of_views_no_match) +{ + std::string_view sv("xyz"); + simple_string_view_skip_parser p{sv}; + auto result = p.first_of(std::string_view("+"), std::string_view("-")); + EXPECT(result.empty()); + EXPECT(p.peek_char() == 'x'); +} + +TEST_CASE(first_of_views_multichar) +{ + std::string_view sv("== rest"); + simple_string_view_skip_parser p{sv}; + auto result = p.first_of(std::string_view("!="), std::string_view("==")); + EXPECT(result == "=="); +} + +TEST_CASE(parser_action_pipe) +{ + auto parse_alpha = action([](auto& q) -> std::string { + if(not std::isalpha(q.peek_char())) + return {}; + return std::string(q.parse_while([](char c) { return std::isalpha(c); })); + }); + auto parse_digits = action([](auto& q) -> std::string { + if(not std::isdigit(q.peek_char())) + return {}; + return std::string(q.parse_while([](char c) { return std::isdigit(c); })); + }); + auto parse_token = parse_alpha | parse_digits; + + std::string_view sv("123"); + simple_string_view_skip_parser p{sv}; + auto result = parse_token(p); + EXPECT(result == "123"); + + std::string_view sv2("abc"); + simple_string_view_skip_parser p2{sv2}; + auto result2 = parse_token(p2); + EXPECT(result2 == "abc"); +} + +TEST_CASE(parser_action_star) +{ + auto parse_word = action([](auto& q) -> std::string { + if(not std::isalpha(q.peek_char())) + return {}; + return std::string(q.parse_while([](char c) { return std::isalpha(c); })); + }); + auto parse_words = *parse_word; + + std::string_view sv("foo bar baz 123"); + simple_string_view_skip_parser p{sv}; + auto results = parse_words(p); + EXPECT(results.size() == 3); + EXPECT(results[0] == "foo"); + EXPECT(results[1] == "bar"); + EXPECT(results[2] == "baz"); +} + +TEST_CASE(parser_action_star_pipe) +{ + auto parse_ident = action([](auto& q) -> std::string { + if(not std::isalpha(q.peek_char())) + return {}; + return std::string(q.parse_while([](char c) { return std::isalnum(c); })); + }); + auto parse_number = action([](auto& q) -> std::string { + if(not std::isdigit(q.peek_char())) + return {}; + return std::string(q.parse_while([](char c) { return std::isdigit(c); })); + }); + auto parse_token = parse_number | parse_ident; + auto parse_tokens = *parse_token; + + std::string_view sv("123 abc 456 def"); + simple_string_view_skip_parser p{sv}; + auto results = parse_tokens(p); + EXPECT(results.size() == 4); + EXPECT(results[0] == "123"); + EXPECT(results[1] == "abc"); + EXPECT(results[2] == "456"); + EXPECT(results[3] == "def"); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/sym.cpp b/test/sym.cpp new file mode 100644 index 00000000000..afbd210f6f2 --- /dev/null +++ b/test/sym.cpp @@ -0,0 +1,3196 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include + +using migraphx::sym::abs; +using migraphx::sym::call; +using migraphx::sym::ceil; +using migraphx::sym::cos; +using migraphx::sym::exp; +using migraphx::sym::expr; +using migraphx::sym::floor; +using migraphx::sym::interval; +using migraphx::sym::lit; +using migraphx::sym::log; +using migraphx::sym::max; +using migraphx::sym::min; +using migraphx::sym::parse; +using migraphx::sym::pow; +using migraphx::sym::pvar; +using migraphx::sym::scalar; +using migraphx::sym::simplify; +using migraphx::sym::sin; +using migraphx::sym::sqrt; +using migraphx::sym::tan; +using migraphx::sym::to_string; +using migraphx::sym::var; + +// ---- make_scalar tests ---- + +using migraphx::sym::make_scalar; + +TEST_CASE(make_scalar_int) +{ + EXPECT(make_scalar(42) == scalar{int64_t{42}}); + EXPECT(make_scalar(-7) == scalar{int64_t{-7}}); + EXPECT(make_scalar(0) == scalar{int64_t{0}}); +} + +TEST_CASE(make_scalar_int64) +{ + EXPECT(make_scalar(int64_t{100}) == scalar{int64_t{100}}); + EXPECT(make_scalar(std::numeric_limits::max()) == + scalar{std::numeric_limits::max()}); + EXPECT(make_scalar(std::numeric_limits::min()) == + scalar{std::numeric_limits::min()}); +} + +TEST_CASE(make_scalar_unsigned) +{ + EXPECT(make_scalar(0u) == scalar{int64_t{0}}); + EXPECT(make_scalar(42u) == scalar{int64_t{42}}); +} + +TEST_CASE(make_scalar_size_t) +{ + EXPECT(make_scalar(std::size_t{0}) == scalar{int64_t{0}}); + EXPECT(make_scalar(std::size_t{100}) == scalar{int64_t{100}}); +} + +TEST_CASE(make_scalar_unsigned_clips_max) +{ + auto max_int64 = std::numeric_limits::max(); + EXPECT(make_scalar(uint64_t(max_int64)) == scalar{max_int64}); + EXPECT(make_scalar(std::numeric_limits::max()) == scalar{max_int64}); + EXPECT(make_scalar(uint64_t(max_int64) + 1) == scalar{max_int64}); +} + +TEST_CASE(make_scalar_double) +{ + EXPECT(make_scalar(3.14) == scalar{3.14}); + EXPECT(make_scalar(0.0) == scalar{0.0}); + EXPECT(make_scalar(-2.5) == scalar{-2.5}); +} + +TEST_CASE(make_scalar_float) { EXPECT(make_scalar(1.0f) == scalar{double{1.0f}}); } + +// ---- Value evaluation tests ---- + +TEST_CASE(literal_zero) { EXPECT(lit(0) == lit(0.0)); } + +TEST_CASE(literal_int_eval) +{ + auto e = lit(42); + auto result = e.eval({}); + EXPECT(result == scalar{int64_t{42}}); +} + +TEST_CASE(literal_double_eval) +{ + auto e = lit(3.14); + auto result = e.eval({}); + EXPECT(result == scalar{3.14}); +} + +TEST_CASE(variable_eval) +{ + auto e = var("x"); + auto result = e.eval({{var("x"), int64_t{10}}}); + EXPECT(result == scalar{int64_t{10}}); +} + +TEST_CASE(add_int_eval) +{ + auto e = lit(3) + lit(4); + auto result = e.eval({}); + EXPECT(result == scalar{int64_t{7}}); +} + +TEST_CASE(add_mixed_eval) +{ + auto e = lit(3) + lit(1.5); + auto result = e.eval({}); + EXPECT(result == scalar{4.5}); +} + +TEST_CASE(sub_eval) +{ + auto e = lit(10) - lit(3); + auto result = e.eval({}); + EXPECT(result == scalar{int64_t{7}}); +} + +TEST_CASE(mul_eval) +{ + auto e = lit(6) * lit(7); + auto result = e.eval({}); + EXPECT(result == scalar{int64_t{42}}); +} + +TEST_CASE(div_double_eval) +{ + auto e = lit(10.0) / lit(4.0); + auto result = e.eval({}); + EXPECT(result == scalar{2.5}); +} + +TEST_CASE(mod_int_eval) +{ + auto e = lit(10) % lit(3); + auto result = e.eval({}); + EXPECT(result == scalar{int64_t{1}}); +} + +TEST_CASE(mod_double_eval) +{ + auto e = lit(10.5) % lit(3.0); + auto result = e.eval({}); + EXPECT(result == scalar{std::fmod(10.5, 3.0)}); +} + +TEST_CASE(mod_variable_eval) +{ + auto e = var("x") % lit(3); + auto result = e.eval({{var("x"), int64_t{10}}}); + EXPECT(result == scalar{int64_t{1}}); +} + +TEST_CASE(neg_eval) +{ + auto e = -lit(5); + auto result = e.eval({}); + EXPECT(result == scalar{int64_t{-5}}); +} + +TEST_CASE(compound_expr_eval) +{ + auto x = var("x"); + auto e = (x + lit(3)) * lit(2); + auto result = e.eval({{var("x"), int64_t{5}}}); + EXPECT(result == scalar{int64_t{16}}); +} + +TEST_CASE(multi_variable_eval) +{ + auto x = var("x"); + auto y = var("y"); + auto e = x * y + lit(1); + + auto result = e.eval({{var("x"), int64_t{3}}, {var("y"), int64_t{4}}}); + EXPECT(result == scalar{int64_t{13}}); +} + +TEST_CASE(sqrt_eval) +{ + auto e = sqrt(lit(4.0)); + auto result = e.eval({}); + EXPECT(result == scalar{2.0}); +} + +TEST_CASE(sqrt_int_eval) +{ + auto e = sqrt(lit(9)); + auto result = e.eval({}); + EXPECT(result == scalar{3.0}); +} + +TEST_CASE(nested_sqrt_eval) +{ + auto e = sqrt(lit(16.0)) + lit(1.0); + auto result = e.eval({}); + EXPECT(result == scalar{5.0}); +} + +TEST_CASE(arg_int_literal) +{ + auto x = var("x"); + auto e = call("+", [](auto a, auto b) { return a + b; })(x, 3); + auto result = e.eval({{var("x"), int64_t{5}}}); + EXPECT(result == scalar{int64_t{8}}); +} + +TEST_CASE(arg_double_literal) +{ + auto x = var("x"); + auto e = call("*", [](auto a, auto b) { return a * b; })(x, 2.0); + auto result = e.eval({{var("x"), 3.0}}); + EXPECT(result == scalar{6.0}); +} + +TEST_CASE(shared_subexpr) +{ + auto x = var("x"); + auto sub = x + lit(1); + auto e = sub * sub; + auto result = e.eval({{var("x"), int64_t{4}}}); + EXPECT(result == scalar{int64_t{25}}); +} + +// ---- Interval evaluation tests ---- + +TEST_CASE(literal_interval) +{ + auto e = lit(5); + auto result = e.eval_interval({}); + EXPECT(result == interval{int64_t{5}, int64_t{5}}); +} + +TEST_CASE(literal_double_interval) +{ + auto e = lit(2.5); + auto result = e.eval_interval({}); + EXPECT(result == interval{2.5, 2.5}); +} + +TEST_CASE(variable_interval) +{ + auto x = var("x"); + auto result = x.eval_interval({{var("x"), interval{int64_t{1}, int64_t{10}}}}); + EXPECT(result == (interval{int64_t{1}, int64_t{10}})); +} + +TEST_CASE(add_interval) +{ + auto x = var("x"); + auto y = var("y"); + auto e = x + y; + // [1,3] + [2,4] = [3,7] + auto result = e.eval_interval({{var("x"), interval{int64_t{1}, int64_t{3}}}, + {var("y"), interval{int64_t{2}, int64_t{4}}}}); + EXPECT(result == (interval{int64_t{3}, int64_t{7}})); +} + +TEST_CASE(sub_interval) +{ + auto x = var("x"); + auto y = var("y"); + auto e = x - y; + // [5,10] - [1,3] = [5-3, 10-1] = [2, 9] + auto result = e.eval_interval({{var("x"), interval{int64_t{5}, int64_t{10}}}, + {var("y"), interval{int64_t{1}, int64_t{3}}}}); + EXPECT(result == (interval{int64_t{2}, int64_t{9}})); +} + +TEST_CASE(mul_interval_positive) +{ + auto x = var("x"); + auto y = var("y"); + auto e = x * y; + // [2,3] * [4,5]: products = {8,10,12,15}, min=8, max=15 + auto result = e.eval_interval({{var("x"), interval{int64_t{2}, int64_t{3}}}, + {var("y"), interval{int64_t{4}, int64_t{5}}}}); + EXPECT(result == (interval{int64_t{8}, int64_t{15}})); +} + +TEST_CASE(mul_interval_mixed_sign) +{ + auto x = var("x"); + auto y = var("y"); + auto e = x * y; + // [-2,3] * [1,4]: products = {-2,-8,3,12}, min=-8, max=12 + auto result = e.eval_interval({{var("x"), interval{int64_t{-2}, int64_t{3}}}, + {var("y"), interval{int64_t{1}, int64_t{4}}}}); + EXPECT(result == (interval{int64_t{-8}, int64_t{12}})); +} + +TEST_CASE(mod_interval) +{ + auto x = var("x"); + auto e = x % lit(3); + // Loose conservative bound: [7,10] % 3 -> [-3, 3] + auto result = e.eval_interval({{var("x"), interval{int64_t{7}, int64_t{10}}}}); + EXPECT(result == (interval{int64_t{-3}, int64_t{3}})); +} + +TEST_CASE(mod_interval_range) +{ + auto x = var("x"); + auto e = x % lit(5); + // Loose conservative bound: [3,8] % 5 -> [-5, 5] + auto result = e.eval_interval({{var("x"), interval{int64_t{3}, int64_t{8}}}}); + EXPECT(result == (interval{int64_t{-5}, int64_t{5}})); +} + +TEST_CASE(neg_interval) +{ + auto x = var("x"); + auto e = -x; + // -[3,7] = [-7,-3] + auto result = e.eval_interval({{var("x"), interval{int64_t{3}, int64_t{7}}}}); + EXPECT(result == (interval{int64_t{-7}, int64_t{-3}})); +} + +TEST_CASE(compound_interval) +{ + auto x = var("x"); + auto e = (x + lit(3)) * lit(2); + // x in [1,5], x+3 in [4,8], *2 in [8,16] + auto result = e.eval_interval({{var("x"), interval{int64_t{1}, int64_t{5}}}}); + EXPECT(result == (interval{int64_t{8}, int64_t{16}})); +} + +TEST_CASE(sqrt_interval) +{ + auto e = sqrt(var("x")); + auto result = e.eval_interval({{var("x"), interval{4.0, 9.0}}}); + EXPECT(result == (interval{2.0, 3.0})); +} + +TEST_CASE(variable_constraint_interval) +{ + auto x = var("x", interval{int64_t{0}, int64_t{100}}); + auto result = x.eval_interval({}); + EXPECT(result == (interval{int64_t{0}, int64_t{100}})); +} + +TEST_CASE(constraint_overridden_by_map) +{ + auto x = var("x", interval{int64_t{0}, int64_t{100}}); + auto result = x.eval_interval({{x, interval{int64_t{5}, int64_t{10}}}}); + EXPECT(result == (interval{int64_t{5}, int64_t{10}})); +} + +// ---- Interval comparison tests ---- + +TEST_CASE(interval_less_true) +{ + // [1,3] < [5,7] → true (a.max < b.min) + interval a{int64_t{1}, int64_t{3}}; + interval b{int64_t{5}, int64_t{7}}; + EXPECT(a < b); +} + +TEST_CASE(interval_less_false) +{ + // [5,7] < [1,3] → false + interval a{int64_t{5}, int64_t{7}}; + interval b{int64_t{1}, int64_t{3}}; + EXPECT(not(a < b)); +} + +TEST_CASE(interval_less_overlapping) +{ + // [1,5] < [3,7] → false (not all values satisfy) + interval a{int64_t{1}, int64_t{5}}; + interval b{int64_t{3}, int64_t{7}}; + EXPECT(not(a < b)); +} + +TEST_CASE(interval_less_equal_endpoint) +{ + // [1,5] < [5,10] → false (a.max == b.min, not strictly less) + interval a{int64_t{1}, int64_t{5}}; + interval b{int64_t{5}, int64_t{10}}; + EXPECT(not(a < b)); +} + +TEST_CASE(interval_leq_true) +{ + // [1,3] <= [3,5] → true (a.max <= b.min) + interval a{int64_t{1}, int64_t{3}}; + interval b{int64_t{3}, int64_t{5}}; + EXPECT(a <= b); +} + +TEST_CASE(interval_leq_false) +{ + // [5,7] <= [1,3] → false + interval a{int64_t{5}, int64_t{7}}; + interval b{int64_t{1}, int64_t{3}}; + EXPECT(not(a <= b)); +} + +TEST_CASE(interval_leq_overlapping) +{ + // [1,5] <= [3,4] → false (not all values satisfy) + interval a{int64_t{1}, int64_t{5}}; + interval b{int64_t{3}, int64_t{4}}; + EXPECT(not(a <= b)); +} + +TEST_CASE(interval_greater_true) +{ + // [5,7] > [1,3] → true (a.min > b.max) + interval a{int64_t{5}, int64_t{7}}; + interval b{int64_t{1}, int64_t{3}}; + EXPECT(a > b); +} + +TEST_CASE(interval_greater_false) +{ + // [1,3] > [5,7] → false + interval a{int64_t{1}, int64_t{3}}; + interval b{int64_t{5}, int64_t{7}}; + EXPECT(not(a > b)); +} + +TEST_CASE(interval_geq_true) +{ + // [3,5] >= [1,3] → true (a.min >= b.max) + interval a{int64_t{3}, int64_t{5}}; + interval b{int64_t{1}, int64_t{3}}; + EXPECT(a >= b); +} + +TEST_CASE(interval_geq_false) +{ + // [1,3] >= [5,7] → false + interval a{int64_t{1}, int64_t{3}}; + interval b{int64_t{5}, int64_t{7}}; + EXPECT(not(a >= b)); +} + +TEST_CASE(interval_geq_overlapping) +{ + // [1,5] >= [3,7] → false (not all values satisfy) + interval a{int64_t{1}, int64_t{5}}; + interval b{int64_t{3}, int64_t{7}}; + EXPECT(not(a >= b)); +} + +// ---- Interval compound assignment tests ---- + +TEST_CASE(interval_plus_assign) +{ + // [1,3] += [2,4] = [3,7] + interval a{int64_t{1}, int64_t{3}}; + a += interval{int64_t{2}, int64_t{4}}; + EXPECT(a == (interval{int64_t{3}, int64_t{7}})); +} + +TEST_CASE(interval_minus_assign) +{ + // [5,10] -= [1,3] = [2,9] + interval a{int64_t{5}, int64_t{10}}; + a -= interval{int64_t{1}, int64_t{3}}; + EXPECT(a == (interval{int64_t{2}, int64_t{9}})); +} + +TEST_CASE(interval_times_assign) +{ + // [2,3] *= [4,5] = [8,15] + interval a{int64_t{2}, int64_t{3}}; + a *= interval{int64_t{4}, int64_t{5}}; + EXPECT(a == (interval{int64_t{8}, int64_t{15}})); +} + +TEST_CASE(interval_div_assign) +{ + // [2.0,10.0] /= [1.0,5.0] = [0.4,10.0] + interval a{2.0, 10.0}; + a /= interval{1.0, 5.0}; + EXPECT(a == (interval{0.4, 10.0})); +} + +TEST_CASE(interval_div_divisor_strictly_crosses_zero) +{ + // [1, 5] / [-2, 4] -> (-inf, +inf) + constexpr double inf = std::numeric_limits::infinity(); + auto r = interval{int64_t{1}, int64_t{5}} / interval{int64_t{-2}, int64_t{4}}; + EXPECT(r == (interval{-inf, inf})); +} + +TEST_CASE(interval_div_num_crosses_and_divisor_crosses) +{ + // [-5, 5] / [-2, 4] -> (-inf, +inf) + constexpr double inf = std::numeric_limits::infinity(); + auto r = interval{int64_t{-5}, int64_t{5}} / interval{int64_t{-2}, int64_t{4}}; + EXPECT(r == (interval{-inf, inf})); +} + +TEST_CASE(interval_div_positive_by_nonneg_divisor) +{ + // [2, 10] / [0, 5] -> [2/5, +inf) = [0.4, +inf) + constexpr double inf = std::numeric_limits::infinity(); + auto r = interval{int64_t{2}, int64_t{10}} / interval{int64_t{0}, int64_t{5}}; + EXPECT(r == (interval{0.4, inf})); +} + +TEST_CASE(interval_div_negative_by_nonneg_divisor) +{ + // [-10, -2] / [0, 5] -> (-inf, -2/5] = (-inf, -0.4] + constexpr double inf = std::numeric_limits::infinity(); + auto r = interval{int64_t{-10}, int64_t{-2}} / interval{int64_t{0}, int64_t{5}}; + EXPECT(r == (interval{-inf, -0.4})); +} + +TEST_CASE(interval_div_positive_by_nonpos_divisor) +{ + // [2, 10] / [-5, 0] -> (-inf, 2/-5] = (-inf, -0.4] + constexpr double inf = std::numeric_limits::infinity(); + auto r = interval{int64_t{2}, int64_t{10}} / interval{int64_t{-5}, int64_t{0}}; + EXPECT(r == (interval{-inf, -0.4})); +} + +TEST_CASE(interval_div_negative_by_nonpos_divisor) +{ + // [-10, -2] / [-5, 0] -> [-2/-5, +inf) = [0.4, +inf) + constexpr double inf = std::numeric_limits::infinity(); + auto r = interval{int64_t{-10}, int64_t{-2}} / interval{int64_t{-5}, int64_t{0}}; + EXPECT(r == (interval{0.4, inf})); +} + +TEST_CASE(interval_div_num_crosses_by_nonneg_divisor) +{ + // [-5, 5] / [0, 5] -> (-inf, +inf) since numerator spans zero too + constexpr double inf = std::numeric_limits::infinity(); + auto r = interval{int64_t{-5}, int64_t{5}} / interval{int64_t{0}, int64_t{5}}; + EXPECT(r == (interval{-inf, inf})); +} + +TEST_CASE(interval_div_num_crosses_by_nonpos_divisor) +{ + // [-5, 5] / [-5, 0] -> (-inf, +inf) + constexpr double inf = std::numeric_limits::infinity(); + auto r = interval{int64_t{-5}, int64_t{5}} / interval{int64_t{-5}, int64_t{0}}; + EXPECT(r == (interval{-inf, inf})); +} + +TEST_CASE(interval_div_by_zero_point_throws) +{ + // [1, 5] / [0, 0] -> throw + test::throws( + [] { (void)(interval{int64_t{1}, int64_t{5}} / interval{int64_t{0}, int64_t{0}}); }); +} + +TEST_CASE(interval_div_positive_no_zero_cross) +{ + // Regression: [2, 4] / [1, 2] -> [1, 4] (normal 4-corner path) + auto r = interval{int64_t{2}, int64_t{4}} / interval{int64_t{1}, int64_t{2}}; + EXPECT(r == (interval{int64_t{1}, int64_t{4}})); +} + +TEST_CASE(interval_div_negative_divisor_no_zero_cross) +{ + // [2, 4] / [-2, -1] -> [-4, -1] (normal 4-corner path) + auto r = interval{int64_t{2}, int64_t{4}} / interval{int64_t{-2}, int64_t{-1}}; + EXPECT(r == (interval{int64_t{-4}, int64_t{-1}})); +} + +TEST_CASE(interval_mod_assign) +{ + // Loose conservative bound: |a%b| < max(|b|), so [7,10] % [3,3] -> [-3, 3] + interval a{int64_t{7}, int64_t{10}}; + a %= interval{int64_t{3}, int64_t{3}}; + EXPECT(a == (interval{int64_t{-3}, int64_t{3}})); +} + +TEST_CASE(interval_mod_positive_divisor) +{ + // [5, 100] % [2, 7] -> [-7, 7] + auto r = interval{int64_t{5}, int64_t{100}} % interval{int64_t{2}, int64_t{7}}; + EXPECT(r == (interval{int64_t{-7}, int64_t{7}})); +} + +TEST_CASE(interval_mod_negative_divisor) +{ + // [5, 100] % [-7, -2] -> [-7, 7] + auto r = interval{int64_t{5}, int64_t{100}} % interval{int64_t{-7}, int64_t{-2}}; + EXPECT(r == (interval{int64_t{-7}, int64_t{7}})); +} + +TEST_CASE(interval_mod_divisor_crosses_zero) +{ + // [1, 5] % [-2, 4] -> [-4, 4] (max_abs = 4) + auto r = interval{int64_t{1}, int64_t{5}} % interval{int64_t{-2}, int64_t{4}}; + EXPECT(r == (interval{int64_t{-4}, int64_t{4}})); +} + +TEST_CASE(interval_mod_divisor_zero_at_low) +{ + // [1, 5] % [0, 3] -> [-3, 3] + auto r = interval{int64_t{1}, int64_t{5}} % interval{int64_t{0}, int64_t{3}}; + EXPECT(r == (interval{int64_t{-3}, int64_t{3}})); +} + +TEST_CASE(interval_mod_divisor_zero_at_high) +{ + // [1, 5] % [-3, 0] -> [-3, 3] + auto r = interval{int64_t{1}, int64_t{5}} % interval{int64_t{-3}, int64_t{0}}; + EXPECT(r == (interval{int64_t{-3}, int64_t{3}})); +} + +TEST_CASE(interval_mod_float_divisor) +{ + // [1.0, 5.0] % [2.0, 2.5] -> [-2.5, 2.5] + auto r = interval{1.0, 5.0} % interval{2.0, 2.5}; + EXPECT(r == (interval{-2.5, 2.5})); +} + +TEST_CASE(interval_mod_by_zero_point_throws) +{ + test::throws( + [] { (void)(interval{int64_t{1}, int64_t{5}} % interval{int64_t{0}, int64_t{0}}); }); +} + +TEST_CASE(interval_compound_assign_no_alias) +{ + interval a{int64_t{1}, int64_t{3}}; + interval b = a; + a += interval{int64_t{10}, int64_t{10}}; + // b unchanged + EXPECT(b == (interval{int64_t{1}, int64_t{3}})); + EXPECT(a == (interval{int64_t{11}, int64_t{13}})); +} + +// ---- Expr structural equality tests ---- + +TEST_CASE(expr_equal_literals) +{ + EXPECT(lit(42) == lit(42)); + EXPECT(lit(3.14) == lit(3.14)); + EXPECT(lit(1) == lit(1.0)); +} + +TEST_CASE(expr_different_literals) { EXPECT(lit(1) != lit(2)); } + +TEST_CASE(expr_equal_variables) { EXPECT(var("x") == var("x")); } + +TEST_CASE(expr_different_variables) { EXPECT(var("x") != var("y")); } + +TEST_CASE(expr_variable_constraint_equality) +{ + auto c = interval{int64_t{0}, int64_t{10}}; + EXPECT(var("x", c) == var("x", c)); + EXPECT(var("x") == var("x", c)); +} + +TEST_CASE(expr_equal_compound) +{ + auto x = var("x"); + EXPECT(x + lit(1) == x + lit(1)); + EXPECT(x * lit(2) == x * lit(2)); +} + +TEST_CASE(expr_different_compound) +{ + auto x = var("x"); + EXPECT(x + lit(1) != x + lit(2)); + EXPECT(x + lit(1) != x * lit(1)); +} + +TEST_CASE(expr_shared_subexpr_identity) +{ + auto x = var("x"); + auto sub = x + lit(1); + EXPECT(sub == sub); +} + +TEST_CASE(expr_default_constructed_equal) +{ + expr a; + expr b; + EXPECT(a == b); +} + +TEST_CASE(expr_default_not_equal_to_lit) +{ + expr a; + EXPECT(a != lit(0)); +} + +// ---- empty tests ---- + +TEST_CASE(empty_default) +{ + expr e; + EXPECT(e.empty()); +} + +TEST_CASE(empty_literal) +{ + auto e = lit(42); + EXPECT(not e.empty()); +} + +TEST_CASE(empty_variable) +{ + auto e = var("x"); + EXPECT(not e.empty()); +} + +TEST_CASE(empty_compound) +{ + auto e = var("x") + lit(1); + EXPECT(not e.empty()); +} + +// ---- hash tests ---- + +TEST_CASE(hash_equal_exprs) +{ + auto a = var("x") + lit(1); + auto b = var("x") + lit(1); + EXPECT(a.hash() == b.hash()); +} + +TEST_CASE(hash_different_exprs) +{ + auto a = var("x") + lit(1); + auto b = var("x") + lit(2); + EXPECT(a.hash() != b.hash()); +} + +TEST_CASE(hash_default_expr) +{ + expr e; + EXPECT(e.hash() == 0); +} + +TEST_CASE(hash_literal) +{ + auto a = lit(42); + auto b = lit(42); + EXPECT(a.hash() == b.hash()); +} + +TEST_CASE(hash_literal_same_value) +{ + auto a = lit(42); + auto b = lit(42.0); + EXPECT(a == b); + EXPECT(a.hash() == b.hash()); +} + +TEST_CASE(hash_different_literals) { EXPECT(lit(1).hash() != lit(2).hash()); } + +TEST_CASE(hash_different_variables) { EXPECT(var("x").hash() != var("y").hash()); } + +TEST_CASE(hash_unordered_map_key) +{ + std::unordered_map m; + auto x = var("x"); + auto y = var("y"); + m[x] = 10; + m[y] = 20; + EXPECT(m.at(x) == 10); + EXPECT(m.at(y) == 20); +} + +// ---- eval_uint tests ---- + +TEST_CASE(eval_uint_literal) +{ + auto e = lit(42); + EXPECT(e.eval_uint({}) == 42); +} + +TEST_CASE(eval_uint_compound) +{ + auto e = lit(3) + lit(4); + EXPECT(e.eval_uint({}) == 7); +} + +TEST_CASE(eval_uint_symbol_map) +{ + auto x = var("x"); + EXPECT(x.eval_uint({{x, 10}}) == 10); +} + +TEST_CASE(eval_uint_symbol_map_compound) +{ + auto x = var("x"); + auto e = x + lit(5); + EXPECT(e.eval_uint({{e, 42}}) == 42); +} + +TEST_CASE(eval_uint_symbol_map_partial) +{ + auto x = var("x"); + auto e = x * lit(2); + // Map x to 7, so x*2 = 14 + auto inner = lit(7) * lit(2); + EXPECT(inner.eval_uint({}) == 14); +} + +// ---- subs tests ---- + +TEST_CASE(subs_variable) +{ + auto x = var("x"); + auto e = x.subs({{x, lit(42)}}); + EXPECT(e == lit(42)); +} + +TEST_CASE(subs_compound) +{ + auto x = var("x"); + auto e = (x + lit(1)).subs({{x, lit(5)}}); + EXPECT(e.eval({}) == scalar{int64_t{6}}); +} + +TEST_CASE(subs_no_match) +{ + auto x = var("x"); + auto y = var("y"); + auto e = x.subs({{y, lit(5)}}); + EXPECT(e == x); +} + +TEST_CASE(subs_nested) +{ + auto x = var("x"); + auto y = var("y"); + auto e = (x + y).subs({{x, lit(3)}, {y, lit(4)}}); + EXPECT(e.eval({}) == scalar{int64_t{7}}); +} + +TEST_CASE(subs_subexpr) +{ + auto x = var("x"); + auto sub = x + lit(1); + auto e = sin(sub); + auto result = e.subs({{sub, lit(0)}}); + // sin(0) = 0.0 + EXPECT(result.eval({}) == scalar{0.0}); +} + +TEST_CASE(subs_literal_unchanged) +{ + auto e = lit(42).subs({{var("x"), lit(5)}}); + EXPECT(e == lit(42)); +} + +TEST_CASE(subs_empty_map) +{ + auto x = var("x"); + auto e = x.subs({}); + EXPECT(e == x); +} + +TEST_CASE(subs_default_expr) +{ + expr e; + auto result = e.subs({{var("x"), lit(1)}}); + EXPECT(result.empty()); +} + +// ---- Compound assignment tests ---- + +TEST_CASE(plus_assign_eval) +{ + auto e = lit(3); + e += lit(4); + EXPECT(e.eval({}) == scalar{int64_t{7}}); +} + +TEST_CASE(minus_assign_eval) +{ + auto e = lit(10); + e -= lit(3); + EXPECT(e.eval({}) == scalar{int64_t{7}}); +} + +TEST_CASE(times_assign_eval) +{ + auto e = lit(6); + e *= lit(7); + EXPECT(e.eval({}) == scalar{int64_t{42}}); +} + +TEST_CASE(div_assign_eval) +{ + auto e = lit(10.0); + e /= lit(4.0); + EXPECT(e.eval({}) == scalar{2.5}); +} + +TEST_CASE(plus_assign_variable) +{ + auto e = var("x"); + e += lit(5); + EXPECT(e.eval({{var("x"), int64_t{3}}}) == scalar{int64_t{8}}); +} + +TEST_CASE(compound_assign_chain) +{ + auto e = var("x"); + e += lit(1); + e *= lit(2); + // (x + 1) * 2 with x=4 → 10 + EXPECT(e.eval({{var("x"), int64_t{4}}}) == scalar{int64_t{10}}); +} + +TEST_CASE(plus_assign_cow) +{ + auto a = lit(3); + auto b = a; + EXPECT(a == b); + b += lit(1); + // b is now (3 + 1), a is still 3 + EXPECT(a != b); + EXPECT(a.eval({}) == scalar{int64_t{3}}); + EXPECT(b.eval({}) == scalar{int64_t{4}}); +} + +TEST_CASE(times_assign_cow) +{ + auto x = var("x"); + auto a = x + lit(1); + auto b = a; + EXPECT(a == b); + b *= lit(2); + // b is now (x+1)*2, a is still x+1 + EXPECT(a != b); + EXPECT(a.eval({{var("x"), int64_t{5}}}) == scalar{int64_t{6}}); + EXPECT(b.eval({{var("x"), int64_t{5}}}) == scalar{int64_t{12}}); +} + +TEST_CASE(compound_assign_cow_shared) +{ + auto x = var("x"); + auto sub = x + lit(1); + auto a = sub; + auto b = sub; + a += lit(10); + b *= lit(10); + // a = (x+1)+10, b = (x+1)*10 + EXPECT(a != b); + EXPECT(a.eval({{var("x"), int64_t{2}}}) == scalar{int64_t{13}}); + EXPECT(b.eval({{var("x"), int64_t{2}}}) == scalar{int64_t{30}}); + // original sub unchanged + EXPECT(sub.eval({{var("x"), int64_t{2}}}) == scalar{int64_t{3}}); +} + +// ---- Non-expr operator tests (int64_t / double mixed with expr) ---- + +TEST_CASE(add_expr_int64) +{ + auto e = var("x") + int64_t{5}; + EXPECT(e.eval({{var("x"), int64_t{3}}}) == scalar{int64_t{8}}); +} + +TEST_CASE(add_int64_expr) +{ + auto e = int64_t{5} + var("x"); + EXPECT(e.eval({{var("x"), int64_t{3}}}) == scalar{int64_t{8}}); +} + +TEST_CASE(add_expr_double) +{ + auto e = var("x") + 1.5; + EXPECT(e.eval({{var("x"), int64_t{2}}}) == scalar{3.5}); +} + +TEST_CASE(add_double_expr) +{ + auto e = 1.5 + var("x"); + EXPECT(e.eval({{var("x"), int64_t{2}}}) == scalar{3.5}); +} + +TEST_CASE(sub_expr_int64) +{ + auto e = var("x") - int64_t{3}; + EXPECT(e.eval({{var("x"), int64_t{10}}}) == scalar{int64_t{7}}); +} + +TEST_CASE(sub_int64_expr) +{ + auto e = int64_t{10} - var("x"); + EXPECT(e.eval({{var("x"), int64_t{3}}}) == scalar{int64_t{7}}); +} + +TEST_CASE(sub_expr_double) +{ + auto e = var("x") - 0.5; + EXPECT(e.eval({{var("x"), 2.0}}) == scalar{1.5}); +} + +TEST_CASE(sub_double_expr) +{ + auto e = 10.0 - var("x"); + EXPECT(e.eval({{var("x"), 3.0}}) == scalar{7.0}); +} + +TEST_CASE(mul_expr_int64) +{ + auto e = var("x") * int64_t{6}; + EXPECT(e.eval({{var("x"), int64_t{7}}}) == scalar{int64_t{42}}); +} + +TEST_CASE(mul_int64_expr) +{ + auto e = int64_t{6} * var("x"); + EXPECT(e.eval({{var("x"), int64_t{7}}}) == scalar{int64_t{42}}); +} + +TEST_CASE(mul_expr_double) +{ + auto e = var("x") * 2.5; + EXPECT(e.eval({{var("x"), 4.0}}) == scalar{10.0}); +} + +TEST_CASE(mul_double_expr) +{ + auto e = 2.5 * var("x"); + EXPECT(e.eval({{var("x"), 4.0}}) == scalar{10.0}); +} + +TEST_CASE(div_expr_int64) +{ + auto e = var("x") / int64_t{2}; + EXPECT(e.eval({{var("x"), 10.0}}) == scalar{5.0}); +} + +TEST_CASE(div_int64_expr) +{ + auto e = int64_t{10} / var("x"); + EXPECT(e.eval({{var("x"), 2.0}}) == scalar{5.0}); +} + +TEST_CASE(div_expr_double) +{ + auto e = var("x") / 4.0; + EXPECT(e.eval({{var("x"), 10.0}}) == scalar{2.5}); +} + +TEST_CASE(div_double_expr) +{ + auto e = 10.0 / var("x"); + EXPECT(e.eval({{var("x"), 4.0}}) == scalar{2.5}); +} + +TEST_CASE(plus_assign_int64) +{ + auto e = var("x"); + e += int64_t{5}; + EXPECT(e.eval({{var("x"), int64_t{3}}}) == scalar{int64_t{8}}); +} + +TEST_CASE(plus_assign_double) +{ + auto e = var("x"); + e += 1.5; + EXPECT(e.eval({{var("x"), 2.0}}) == scalar{3.5}); +} + +TEST_CASE(minus_assign_int64) +{ + auto e = var("x"); + e -= int64_t{3}; + EXPECT(e.eval({{var("x"), int64_t{10}}}) == scalar{int64_t{7}}); +} + +TEST_CASE(minus_assign_double) +{ + auto e = var("x"); + e -= 0.5; + EXPECT(e.eval({{var("x"), 2.0}}) == scalar{1.5}); +} + +TEST_CASE(times_assign_int64) +{ + auto e = var("x"); + e *= int64_t{6}; + EXPECT(e.eval({{var("x"), int64_t{7}}}) == scalar{int64_t{42}}); +} + +TEST_CASE(times_assign_double) +{ + auto e = var("x"); + e *= 2.5; + EXPECT(e.eval({{var("x"), 4.0}}) == scalar{10.0}); +} + +TEST_CASE(div_assign_int64) +{ + auto e = var("x"); + e /= int64_t{2}; + EXPECT(e.eval({{var("x"), 10.0}}) == scalar{5.0}); +} + +TEST_CASE(div_assign_double) +{ + auto e = var("x"); + e /= 4.0; + EXPECT(e.eval({{var("x"), 10.0}}) == scalar{2.5}); +} + +TEST_CASE(mod_expr_int64) +{ + auto e = var("x") % int64_t{3}; + EXPECT(e.eval({{var("x"), int64_t{10}}}) == scalar{int64_t{1}}); +} + +TEST_CASE(mod_int64_expr) +{ + auto e = int64_t{10} % var("x"); + EXPECT(e.eval({{var("x"), int64_t{3}}}) == scalar{int64_t{1}}); +} + +TEST_CASE(mod_expr_double) +{ + auto e = var("x") % 3.0; + EXPECT(e.eval({{var("x"), 10.5}}) == scalar{std::fmod(10.5, 3.0)}); +} + +TEST_CASE(mod_double_expr) +{ + auto e = 10.5 % var("x"); + EXPECT(e.eval({{var("x"), 3.0}}) == scalar{std::fmod(10.5, 3.0)}); +} + +TEST_CASE(mod_assign_expr) +{ + auto e = var("x"); + e %= lit(3); + EXPECT(e.eval({{var("x"), int64_t{10}}}) == scalar{int64_t{1}}); +} + +TEST_CASE(mod_assign_int64) +{ + auto e = var("x"); + e %= int64_t{3}; + EXPECT(e.eval({{var("x"), int64_t{10}}}) == scalar{int64_t{1}}); +} + +TEST_CASE(mod_assign_double) +{ + auto e = var("x"); + e %= 3.0; + EXPECT(e.eval({{var("x"), 10.5}}) == scalar{std::fmod(10.5, 3.0)}); +} + +TEST_CASE(non_expr_compound) +{ + // (x + 3) * 2.0 - 1 with x=4 → (7) * 2.0 - 1 = 13.0 + auto e = (var("x") + int64_t{3}) * 2.0 - int64_t{1}; + EXPECT(e.eval({{var("x"), int64_t{4}}}) == scalar{13.0}); +} + +TEST_CASE(non_expr_both_sides) +{ + // 2 * x + 1.5 with x=3 → 7.5 + auto e = int64_t{2} * var("x") + 1.5; + EXPECT(e.eval({{var("x"), int64_t{3}}}) == scalar{7.5}); +} + +TEST_CASE(non_expr_chain_assign) +{ + auto e = var("x"); + e += int64_t{1}; + e *= 2.0; + e -= int64_t{1}; + // (x + 1) * 2.0 - 1 with x=4 → 9.0 + EXPECT(e.eval({{var("x"), int64_t{4}}}) == scalar{9.0}); +} + +TEST_CASE(custom_call_eval) +{ + auto square = call("square", [](auto x) { return x * x; }); + auto x = var("x"); + auto e = square(x); + auto result = e.eval({{var("x"), int64_t{7}}}); + EXPECT(result == scalar{int64_t{49}}); +} + +TEST_CASE(custom_call_interval) +{ + auto square = call("square", [](auto x) { return x * x; }); + auto x = var("x"); + auto e = square(x); + // [2,3] squared: interval*interval = [2,3]*[2,3], products={4,6,6,9}, min=4, max=9 + auto result = e.eval_interval({{var("x"), interval{int64_t{2}, int64_t{3}}}}); + EXPECT(result == (interval{int64_t{4}, int64_t{9}})); +} + +// ---- Math function eval tests ---- + +TEST_CASE(sin_eval) { EXPECT(sin(lit(0.0)).eval({}) == scalar{0.0}); } + +TEST_CASE(cos_eval) { EXPECT(cos(lit(0.0)).eval({}) == scalar{1.0}); } + +TEST_CASE(tan_eval) { EXPECT(tan(lit(0.0)).eval({}) == scalar{0.0}); } + +TEST_CASE(exp_eval) { EXPECT(exp(lit(0.0)).eval({}) == scalar{1.0}); } + +TEST_CASE(exp_eval_one) { EXPECT(exp(lit(1.0)).eval({}) == scalar{std::exp(1.0)}); } + +TEST_CASE(log_eval) { EXPECT(log(lit(1.0)).eval({}) == scalar{0.0}); } + +TEST_CASE(sqrt_eval_refactored) { EXPECT(sqrt(lit(4.0)).eval({}) == scalar{2.0}); } + +TEST_CASE(abs_int_eval) +{ + EXPECT(abs(lit(-5)).eval({}) == scalar{int64_t{5}}); + EXPECT(abs(lit(3)).eval({}) == scalar{int64_t{3}}); +} + +TEST_CASE(abs_double_eval) { EXPECT(abs(lit(-2.5)).eval({}) == scalar{2.5}); } + +TEST_CASE(floor_eval) +{ + EXPECT(floor(lit(2.7)).eval({}) == scalar{2.0}); + EXPECT(floor(lit(-2.3)).eval({}) == scalar{-3.0}); +} + +TEST_CASE(ceil_eval) +{ + EXPECT(ceil(lit(2.3)).eval({}) == scalar{3.0}); + EXPECT(ceil(lit(-2.7)).eval({}) == scalar{-2.0}); +} + +TEST_CASE(pow_eval) { EXPECT(pow(lit(2.0), lit(3.0)).eval({}) == scalar{8.0}); } + +TEST_CASE(min_eval) +{ + EXPECT(min(lit(3), lit(5)).eval({}) == scalar{int64_t{3}}); + EXPECT(min(lit(7), lit(2)).eval({}) == scalar{int64_t{2}}); +} + +TEST_CASE(max_eval) +{ + EXPECT(max(lit(3), lit(5)).eval({}) == scalar{int64_t{5}}); + EXPECT(max(lit(7), lit(2)).eval({}) == scalar{int64_t{7}}); +} + +TEST_CASE(math_with_variable) +{ + auto x = var("x"); + EXPECT(sin(x).eval({{var("x"), 0.0}}) == scalar{0.0}); + EXPECT(abs(x).eval({{var("x"), int64_t{-7}}}) == scalar{int64_t{7}}); +} + +// ---- Interval math function tests ---- + +TEST_CASE(sin_interval_contains_max) +{ + // sin over [0, π]: reaches max 1.0 at π/2 + const double pi = std::acos(-1.0); + auto result = sin(interval{0.0, pi}); + EXPECT(result == (interval{0.0, 1.0})); +} + +TEST_CASE(sin_interval_full_period) +{ + const double pi = std::acos(-1.0); + auto result = sin(interval{0.0, 2.0 * pi}); + EXPECT(result == (interval{-1.0, 1.0})); +} + +TEST_CASE(cos_interval_contains_min) +{ + // cos over [0, π]: reaches min -1.0 at π + const double pi = std::acos(-1.0); + auto result = cos(interval{0.0, pi}); + EXPECT(result == (interval{-1.0, 1.0})); +} + +TEST_CASE(cos_interval_monotone) +{ + // cos over [0, 1]: monotonically decreasing + auto result = cos(interval{0.0, 1.0}); + EXPECT(result == (interval{std::cos(1.0), 1.0})); +} + +TEST_CASE(tan_interval_point) +{ + auto result = tan(interval{0.0, 0.0}); + EXPECT(result == (interval{0.0, 0.0})); +} + +TEST_CASE(exp_interval) +{ + auto result = exp(interval{0.0, 1.0}); + EXPECT(result == (interval{1.0, std::exp(1.0)})); +} + +TEST_CASE(log_interval) +{ + auto result = log(interval{1.0, std::exp(1.0)}); + EXPECT(result == (interval{0.0, 1.0})); +} + +TEST_CASE(sqrt_interval_refactored) +{ + auto result = sqrt(interval{4.0, 9.0}); + EXPECT(result == (interval{2.0, 3.0})); +} + +TEST_CASE(abs_interval_positive) +{ + auto result = abs(interval{int64_t{2}, int64_t{5}}); + EXPECT(result == (interval{int64_t{2}, int64_t{5}})); +} + +TEST_CASE(abs_interval_negative) +{ + auto result = abs(interval{int64_t{-5}, int64_t{-2}}); + EXPECT(result == (interval{int64_t{2}, int64_t{5}})); +} + +TEST_CASE(abs_interval_mixed) +{ + auto result = abs(interval{int64_t{-3}, int64_t{5}}); + EXPECT(result == (interval{int64_t{0}, int64_t{5}})); +} + +TEST_CASE(abs_interval_mixed_larger_neg) +{ + auto result = abs(interval{int64_t{-7}, int64_t{2}}); + EXPECT(result == (interval{int64_t{0}, int64_t{7}})); +} + +TEST_CASE(floor_interval) +{ + auto result = floor(interval{1.2, 3.8}); + EXPECT(result == (interval{1.0, 3.0})); +} + +TEST_CASE(ceil_interval) +{ + auto result = ceil(interval{1.2, 3.8}); + EXPECT(result == (interval{2.0, 4.0})); +} + +TEST_CASE(pow_interval) +{ + // [2,3]^[2,2] = [4,9] + auto result = pow(interval{2.0, 3.0}, interval{2.0, 2.0}); + EXPECT(result == (interval{4.0, 9.0})); +} + +TEST_CASE(min_interval) +{ + auto result = min(interval{int64_t{1}, int64_t{5}}, interval{int64_t{3}, int64_t{7}}); + EXPECT(result == (interval{int64_t{1}, int64_t{5}})); +} + +TEST_CASE(max_interval) +{ + auto result = max(interval{int64_t{1}, int64_t{5}}, interval{int64_t{3}, int64_t{7}}); + EXPECT(result == (interval{int64_t{3}, int64_t{7}})); +} + +// ---- Expr math function interval eval tests ---- + +TEST_CASE(expr_abs_interval) +{ + auto x = var("x"); + auto result = abs(x).eval_interval({{var("x"), interval{int64_t{-3}, int64_t{5}}}}); + EXPECT(result == (interval{int64_t{0}, int64_t{5}})); +} + +TEST_CASE(expr_min_interval) +{ + auto x = var("x"); + auto y = var("y"); + auto result = min(x, y).eval_interval({{var("x"), interval{int64_t{1}, int64_t{5}}}, + {var("y"), interval{int64_t{3}, int64_t{7}}}}); + EXPECT(result == (interval{int64_t{1}, int64_t{5}})); +} + +TEST_CASE(expr_max_interval) +{ + auto x = var("x"); + auto y = var("y"); + auto result = max(x, y).eval_interval({{var("x"), interval{int64_t{1}, int64_t{5}}}, + {var("y"), interval{int64_t{3}, int64_t{7}}}}); + EXPECT(result == (interval{int64_t{3}, int64_t{7}})); +} + +TEST_CASE(expr_exp_interval) +{ + auto x = var("x"); + auto result = exp(x).eval_interval({{var("x"), interval{0.0, 1.0}}}); + EXPECT(result == (interval{1.0, std::exp(1.0)})); +} + +// ---- to_string tests ---- + +TEST_CASE(to_string_literal_int) +{ + EXPECT(lit(42).to_string() == "42"); + EXPECT(lit(-7).to_string() == "-7"); +} + +TEST_CASE(to_string_literal_double) +{ + EXPECT(lit(3.14).to_string() == "3.14"); + EXPECT(lit(0.0).to_string() == "0"); +} + +TEST_CASE(to_string_variable) { EXPECT(var("x").to_string() == "x"); } + +TEST_CASE(to_string_add) +{ + auto x = var("x"); + // variables sort before literals + EXPECT((x + lit(3)).to_string() == "x + 3"); +} + +TEST_CASE(to_string_sub) +{ + auto x = var("x"); + // x - 1 is rewritten as x + (-1), displayed as subtraction + EXPECT((x - lit(1)).to_string() == "x - 1"); +} + +TEST_CASE(to_string_mul) +{ + auto x = var("x"); + // literals sort before variables in multiplication + EXPECT((x * lit(2)).to_string() == "2*x"); +} + +TEST_CASE(to_string_div) +{ + auto x = var("x"); + EXPECT((x / lit(4)).to_string() == "x/4"); +} + +TEST_CASE(to_string_mod) +{ + auto x = var("x"); + EXPECT((x % lit(3)).to_string() == "x%3"); +} + +TEST_CASE(to_string_neg) +{ + auto x = var("x"); + // -x is rewritten as -1 * x, displayed as negation + EXPECT((-x).to_string() == "-x"); +} + +TEST_CASE(to_string_nested) +{ + auto x = var("x"); + auto y = var("y"); + auto e = (x + lit(1)) * (y - lit(2)); + // fully expanded: (x+1)*(y-2) = xy - 2x + y - 2 + EXPECT(e.to_string() == "x*y - 2*x + y - 2"); +} + +TEST_CASE(to_string_function) +{ + auto x = var("x"); + EXPECT(sin(x).to_string() == "sin(x)"); + EXPECT(sqrt(x).to_string() == "sqrt(x)"); + EXPECT(abs(x).to_string() == "abs(x)"); +} + +TEST_CASE(to_string_function_two_arg) +{ + auto x = var("x"); + auto y = var("y"); + EXPECT(pow(x, y).to_string() == "pow(x, y)"); + EXPECT(min(x, y).to_string() == "min(x, y)"); + EXPECT(max(x, y).to_string() == "max(x, y)"); +} + +TEST_CASE(to_string_composed) +{ + auto x = var("x"); + auto e = sin(x * lit(2)) + lit(1); + // lit(1) sorts before sin(...) + EXPECT(e.to_string() == "sin(2*x) + 1"); +} + +TEST_CASE(free_to_string) +{ + auto x = var("x"); + EXPECT(to_string(x + lit(1)) == "x + 1"); + EXPECT(to_string(sin(x)) == "sin(x)"); +} + +// ---- ostream operator<< tests ---- + +TEST_CASE(ostream_expr) +{ + std::ostringstream ss; + ss << (var("x") + lit(1)); + EXPECT(ss.str() == "x + 1"); +} + +TEST_CASE(ostream_expr_function) +{ + std::ostringstream ss; + ss << sin(var("x")); + EXPECT(ss.str() == "sin(x)"); +} + +TEST_CASE(ostream_interval) +{ + std::ostringstream ss; + ss << interval{int64_t{1}, int64_t{10}}; + EXPECT(ss.str() == "[1, 10]"); +} + +TEST_CASE(ostream_interval_double) +{ + std::ostringstream ss; + ss << interval{1.5, 3.5}; + EXPECT(ss.str() == "[1.5, 3.5]"); +} + +// ---- Associative flattening tests ---- + +TEST_CASE(flatten_add_right) +{ + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + // (a + b) + c should flatten to +(a, b, c) + auto e = (a + b) + c; + auto result = e.eval({{var("a"), int64_t{1}}, {var("b"), int64_t{2}}, {var("c"), int64_t{3}}}); + EXPECT(result == scalar{int64_t{6}}); + EXPECT(e.children().size() == 3); +} + +TEST_CASE(flatten_add_left) +{ + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + // a + (b + c) should flatten to +(a, b, c) + auto e = a + (b + c); + auto result = e.eval({{var("a"), int64_t{1}}, {var("b"), int64_t{2}}, {var("c"), int64_t{3}}}); + EXPECT(result == scalar{int64_t{6}}); + EXPECT(e.children().size() == 3); +} + +TEST_CASE(flatten_add_both) +{ + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + auto d = var("d"); + // (a + b) + (c + d) should flatten to +(a, b, c, d) + auto e = (a + b) + (c + d); + auto result = e.eval({{var("a"), int64_t{1}}, + {var("b"), int64_t{2}}, + {var("c"), int64_t{3}}, + {var("d"), int64_t{4}}}); + EXPECT(result == scalar{int64_t{10}}); + EXPECT(e.children().size() == 4); +} + +TEST_CASE(flatten_mul_right) +{ + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + // (a * b) * c should flatten to *(a, b, c) + auto e = (a * b) * c; + auto result = e.eval({{var("a"), int64_t{2}}, {var("b"), int64_t{3}}, {var("c"), int64_t{4}}}); + EXPECT(result == scalar{int64_t{24}}); + EXPECT(e.children().size() == 3); +} + +TEST_CASE(flatten_mul_both) +{ + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + auto d = var("d"); + // (a * b) * (c * d) should flatten to *(a, b, c, d) + auto e = (a * b) * (c * d); + auto result = e.eval({{var("a"), int64_t{2}}, + {var("b"), int64_t{3}}, + {var("c"), int64_t{4}}, + {var("d"), int64_t{5}}}); + EXPECT(result == scalar{int64_t{120}}); + EXPECT(e.children().size() == 4); +} + +TEST_CASE(flatten_nested_add) +{ + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + auto d = var("d"); + // ((a + b) + c) + d should flatten to +(a, b, c, d) + auto e = ((a + b) + c) + d; + auto result = e.eval({{var("a"), int64_t{1}}, + {var("b"), int64_t{2}}, + {var("c"), int64_t{3}}, + {var("d"), int64_t{4}}}); + EXPECT(result == scalar{int64_t{10}}); + EXPECT(e.children().size() == 4); +} + +TEST_CASE(sub_flattens_into_add) +{ + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + // (a - b) - c becomes a + (-b) + (-c), flattened to 3 children + auto e = (a - b) - c; + auto result = e.eval({{var("a"), int64_t{10}}, {var("b"), int64_t{3}}, {var("c"), int64_t{2}}}); + EXPECT(result == scalar{int64_t{5}}); + EXPECT(e.children().size() == 3); +} + +TEST_CASE(no_flatten_div) +{ + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + // (a / b) / c should NOT flatten + auto e = (a / b) / c; + auto result = e.eval({{var("a"), 12.0}, {var("b"), 2.0}, {var("c"), 3.0}}); + EXPECT(result == scalar{2.0}); + EXPECT(e.children().size() == 2); +} + +TEST_CASE(no_flatten_mixed_ops) +{ + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + // (a * b) + c should NOT flatten mul into add + auto e = (a * b) + c; + EXPECT(e.children().size() == 2); + auto result = e.eval({{var("a"), int64_t{3}}, {var("b"), int64_t{4}}, {var("c"), int64_t{5}}}); + EXPECT(result == scalar{int64_t{17}}); +} + +TEST_CASE(flatten_add_interval) +{ + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + auto e = (a + b) + c; + // [1,2] + [3,4] + [5,6] = [9,12] + auto result = e.eval_interval({{var("a"), interval{int64_t{1}, int64_t{2}}}, + {var("b"), interval{int64_t{3}, int64_t{4}}}, + {var("c"), interval{int64_t{5}, int64_t{6}}}}); + EXPECT(result == (interval{int64_t{9}, int64_t{12}})); +} + +TEST_CASE(flatten_mul_interval) +{ + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + auto e = (a * b) * c; + // [1,2] * [3,4] * [1,1] = [3,8] * [1,1] = [3,8] + auto result = e.eval_interval({{var("a"), interval{int64_t{1}, int64_t{2}}}, + {var("b"), interval{int64_t{3}, int64_t{4}}}, + {var("c"), interval{int64_t{1}, int64_t{1}}}}); + EXPECT(result == (interval{int64_t{3}, int64_t{8}})); +} + +TEST_CASE(flatten_to_string_add) +{ + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + EXPECT(((a + b) + c).to_string() == "c + b + a"); +} + +TEST_CASE(flatten_to_string_add_both) +{ + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + auto d = var("d"); + EXPECT(((a + b) + (c + d)).to_string() == "d + c + b + a"); +} + +TEST_CASE(flatten_to_string_mul) +{ + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + EXPECT(((a * b) * c).to_string() == "a*b*c"); +} + +TEST_CASE(flatten_to_string_nested) +{ + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + auto d = var("d"); + EXPECT((((a + b) + c) + d).to_string() == "d + c + b + a"); +} + +TEST_CASE(flatten_to_string_mixed) +{ + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + // (a * b) + c: mul is a child of add, should not flatten across ops + // ops sort before variables + EXPECT(((a * b) + c).to_string() == "a*b + c"); +} + +// ---- Constant folding tests ---- + +TEST_CASE(const_fold_add) +{ + auto e = lit(3) + lit(4); + EXPECT(e.name() == "literal"); + EXPECT(e.eval({}) == scalar{int64_t{7}}); +} + +TEST_CASE(const_fold_sub) +{ + auto e = lit(10) - lit(3); + EXPECT(e.name() == "literal"); + EXPECT(e.eval({}) == scalar{int64_t{7}}); +} + +TEST_CASE(const_fold_mul) +{ + auto e = lit(6) * lit(7); + EXPECT(e.name() == "literal"); + EXPECT(e.eval({}) == scalar{int64_t{42}}); +} + +TEST_CASE(const_fold_div) +{ + auto e = lit(10.0) / lit(4.0); + EXPECT(e.name() == "literal"); + EXPECT(e.eval({}) == scalar{2.5}); +} + +TEST_CASE(const_fold_neg) +{ + auto e = -lit(5); + EXPECT(e.name() == "literal"); + EXPECT(e.eval({}) == scalar{int64_t{-5}}); +} + +TEST_CASE(const_fold_nested) +{ + // (3 + 4) * 2 should fold completely to 14 + auto e = (lit(3) + lit(4)) * lit(2); + EXPECT(e.name() == "literal"); + EXPECT(e.eval({}) == scalar{int64_t{14}}); +} + +TEST_CASE(const_fold_math_functions) +{ + EXPECT(sin(lit(0.0)).name() == "literal"); + EXPECT(cos(lit(0.0)).name() == "literal"); + EXPECT(sqrt(lit(4.0)).name() == "literal"); + EXPECT(abs(lit(-5)).name() == "literal"); + EXPECT(floor(lit(2.7)).name() == "literal"); + EXPECT(ceil(lit(2.3)).name() == "literal"); + EXPECT(pow(lit(2.0), lit(3.0)).name() == "literal"); + EXPECT(min(lit(3), lit(5)).name() == "literal"); + EXPECT(max(lit(3), lit(5)).name() == "literal"); +} + +TEST_CASE(no_const_fold_with_variable) +{ + auto x = var("x"); + auto e = x + lit(3); + EXPECT(e.name() != "literal"); +} + +TEST_CASE(const_fold_partial) +{ + auto x = var("x"); + // x + (3 + 4): the (3+4) subexpr folds to 7, but x+7 does not fold + auto e = x + (lit(3) + lit(4)); + EXPECT(e.name() != "literal"); + auto result = e.eval({{var("x"), int64_t{1}}}); + EXPECT(result == scalar{int64_t{8}}); +} + +TEST_CASE(const_fold_chain) +{ + // lit(1) + lit(2) + lit(3) should flatten then fold to 6 + auto e = lit(1) + lit(2) + lit(3); + EXPECT(e.name() == "literal"); + EXPECT(e.eval({}) == scalar{int64_t{6}}); +} + +TEST_CASE(const_fold_to_string) +{ + auto e = lit(3) + lit(4); + EXPECT(e.to_string() == "7"); +} + +// ---- Associative constant folding tests ---- + +TEST_CASE(assoc_fold_add_trailing_literals) +{ + auto x = var("x"); + // x + 2 + 3: flattened to +(x, 2, 3), adjacent literals 2 and 3 fold to 5 + auto e = x + lit(2) + lit(3); + EXPECT(e.eval({{var("x"), int64_t{0}}}) == scalar{int64_t{5}}); + EXPECT(e == x + lit(5)); +} + +TEST_CASE(assoc_fold_mul_trailing_literals) +{ + auto x = var("x"); + // x * 2 * 3: flattened to *(x, 2, 3), adjacent literals 2 and 3 fold to 6 + auto e = x * lit(2) * lit(3); + EXPECT(e.eval({{var("x"), int64_t{1}}}) == scalar{int64_t{6}}); + EXPECT(e == x * lit(6)); +} + +TEST_CASE(assoc_fold_add_leading_literals) +{ + auto x = var("x"); + // 2 + 3 + x: literals adjacent at the front fold to 5 + auto e = lit(2) + lit(3) + x; + EXPECT(e == x + lit(5)); +} + +TEST_CASE(assoc_fold_mul_leading_literals) +{ + auto x = var("x"); + // 2 * 3 * x: literals adjacent at the front fold to 6 + auto e = lit(2) * lit(3) * x; + EXPECT(e == lit(6) * x); +} + +TEST_CASE(assoc_fold_add_three_literals) +{ + // 1 + 2 + 3: all literals fold completely + auto e = lit(1) + lit(2) + lit(3); + EXPECT(e == lit(6)); +} + +TEST_CASE(assoc_fold_mul_three_literals) +{ + // 2 * 3 * 4: all literals fold completely + auto e = lit(2) * lit(3) * lit(4); + EXPECT(e == lit(24)); +} + +TEST_CASE(assoc_fold_add_mixed_chain) +{ + auto x = var("x"); + auto y = var("y"); + // x + 1 + y + 2: after sorting, literals end up adjacent and fold + auto e = x + lit(1) + y + lit(2); + EXPECT(e.eval({{var("x"), int64_t{10}}, {var("y"), int64_t{20}}}) == scalar{int64_t{33}}); +} + +TEST_CASE(assoc_fold_mul_mixed_chain) +{ + auto x = var("x"); + auto y = var("y"); + // x * 2 * y * 3: after sorting, literals end up adjacent and fold + auto e = x * lit(2) * y * lit(3); + EXPECT(e.eval({{var("x"), int64_t{5}}, {var("y"), int64_t{7}}}) == scalar{int64_t{210}}); +} + +TEST_CASE(assoc_fold_preserves_eval) +{ + auto x = var("x"); + // Folding must not change evaluation results + auto e1 = x + lit(10) + lit(20) + lit(30); + auto e2 = x + lit(60); + EXPECT(e1.eval({{var("x"), int64_t{5}}}) == e2.eval({{var("x"), int64_t{5}}})); + EXPECT(e1 == e2); +} + +TEST_CASE(assoc_fold_double_literals) +{ + auto x = var("x"); + // Folding works with double literals too + auto e = x + lit(1.5) + lit(2.5); + EXPECT(e.eval({{var("x"), 0.0}}) == scalar{4.0}); +} + +TEST_CASE(assoc_fold_no_fold_single_literal) +{ + auto x = var("x"); + // Only one literal, nothing to fold + auto e = x + lit(5); + EXPECT(e.eval({{var("x"), int64_t{3}}}) == scalar{int64_t{8}}); +} + +// ---- Canonicalization tests ---- + +TEST_CASE(canonical_add_commutative) +{ + auto x = var("x"); + auto y = var("y"); + // x + y and y + x should be the same expression + EXPECT(x + y == y + x); +} + +TEST_CASE(canonical_mul_commutative) +{ + auto x = var("x"); + auto y = var("y"); + // x * y and y * x should be the same expression + EXPECT(x * y == y * x); +} + +TEST_CASE(canonical_add_three_vars) +{ + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + // all orderings should produce the same expression + EXPECT((a + b) + c == (c + a) + b); + EXPECT((a + b) + c == (b + c) + a); +} + +TEST_CASE(canonical_mul_three_vars) +{ + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + EXPECT((a * b) * c == (c * a) * b); + EXPECT((a * b) * c == (b * c) * a); +} + +TEST_CASE(canonical_add_lit_var_order) +{ + auto x = var("x"); + // lit + var and var + lit should be the same + EXPECT(lit(5) + x == x + lit(5)); +} + +TEST_CASE(canonical_mul_lit_var_order) +{ + auto x = var("x"); + EXPECT(lit(3) * x == x * lit(3)); +} + +TEST_CASE(canonical_div_not_commutative) +{ + auto x = var("x"); + auto y = var("y"); + // division is not commutative, order should be preserved + EXPECT(x / y != y / x); +} + +TEST_CASE(canonical_compound_commutative) +{ + auto x = var("x"); + auto y = var("y"); + // (x+1) * (y+2) and (y+2) * (x+1) should be the same + EXPECT((x + lit(1)) * (y + lit(2)) == (y + lit(2)) * (x + lit(1))); +} + +TEST_CASE(canonical_nested_commutative) +{ + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + // a + b*c and b*c + a should be the same + EXPECT(a + b * c == b * c + a); +} + +TEST_CASE(canonical_eval_preserved) +{ + auto x = var("x"); + auto y = var("y"); + // canonicalization should not change evaluation results + auto e1 = x + y; + auto e2 = y + x; + EXPECT(e1.eval({{var("x"), int64_t{3}}, {var("y"), int64_t{7}}}) == scalar{int64_t{10}}); + EXPECT(e2.eval({{var("x"), int64_t{3}}, {var("y"), int64_t{7}}}) == scalar{int64_t{10}}); + + auto e3 = x * y; + auto e4 = y * x; + EXPECT(e3.eval({{var("x"), int64_t{3}}, {var("y"), int64_t{7}}}) == scalar{int64_t{21}}); + EXPECT(e4.eval({{var("x"), int64_t{3}}, {var("y"), int64_t{7}}}) == scalar{int64_t{21}}); +} + +TEST_CASE(canonical_interval_preserved) +{ + auto x = var("x"); + auto y = var("y"); + auto vars = std::unordered_map{{var("x"), interval{int64_t{1}, int64_t{3}}}, + {var("y"), interval{int64_t{4}, int64_t{6}}}}; + EXPECT((x + y).eval_interval(vars) == (y + x).eval_interval(vars)); + EXPECT((x * y).eval_interval(vars) == (y * x).eval_interval(vars)); +} + +// ---- Algebraic normalization tests ---- + +TEST_CASE(norm_x_plus_x) +{ + auto x = var("x"); + EXPECT(x + x == lit(2) * x); + EXPECT(x + x == x * lit(2)); +} + +TEST_CASE(norm_x_plus_2x) +{ + auto x = var("x"); + EXPECT(x + lit(2) * x == lit(3) * x); + EXPECT(x + lit(2) * x == x + x + x); +} + +TEST_CASE(norm_3x_minus_x) +{ + auto x = var("x"); + EXPECT(lit(3) * x - x == lit(2) * x); +} + +TEST_CASE(norm_x_minus_x) +{ + auto x = var("x"); + EXPECT(x - x == lit(0)); +} + +TEST_CASE(norm_x_times_0) +{ + auto x = var("x"); + EXPECT(x * lit(0) == lit(0)); + EXPECT(lit(0) * x == lit(0)); +} + +TEST_CASE(norm_x_times_1) +{ + auto x = var("x"); + EXPECT(x * lit(1) == x); + EXPECT(lit(1) * x == x); +} + +TEST_CASE(norm_x_plus_0) +{ + auto x = var("x"); + EXPECT(x + lit(0) == x); + EXPECT(lit(0) + x == x); +} + +TEST_CASE(norm_distribute_simple) +{ + auto x = var("x"); + auto y = var("y"); + // 2*(x+y) == 2*x + 2*y + EXPECT(lit(2) * (x + y) == lit(2) * x + lit(2) * y); +} + +TEST_CASE(norm_foil) +{ + auto x = var("x"); + auto y = var("y"); + // (x+y)*(x+y) == x*x + 2*x*y + y*y + EXPECT((x + y) * (x + y) == x * x + lit(2) * x * y + y * y); +} + +TEST_CASE(norm_foil_eval) +{ + auto x = var("x"); + auto y = var("y"); + auto lhs = (x + y) * (x + y); + auto rhs = x * x + lit(2) * x * y + y * y; + auto vars = std::unordered_map{{var("x"), int64_t{3}}, {var("y"), int64_t{4}}}; + EXPECT(lhs.eval(vars) == scalar{int64_t{49}}); + EXPECT(rhs.eval(vars) == scalar{int64_t{49}}); +} + +TEST_CASE(norm_difference_of_squares) +{ + auto x = var("x"); + auto y = var("y"); + // (x+y)*(x-y) == x*x - y*y + EXPECT((x + y) * (x - y) == x * x - y * y); +} + +TEST_CASE(norm_difference_of_squares_eval) +{ + auto x = var("x"); + auto y = var("y"); + auto lhs = (x + y) * (x - y); + auto rhs = x * x - y * y; + auto vars = std::unordered_map{{var("x"), int64_t{5}}, {var("y"), int64_t{3}}}; + EXPECT(lhs.eval(vars) == scalar{int64_t{16}}); + EXPECT(rhs.eval(vars) == scalar{int64_t{16}}); +} + +TEST_CASE(norm_triple_product) +{ + auto x = var("x"); + // (x+1)*(x+1)*(x+1) expanded is x^3 + 3x^2 + 3x + 1 + auto cubed = (x + lit(1)) * (x + lit(1)) * (x + lit(1)); + auto expanded = x * x * x + lit(3) * x * x + lit(3) * x + lit(1); + EXPECT(cubed == expanded); +} + +TEST_CASE(norm_triple_product_eval) +{ + auto x = var("x"); + auto cubed = (x + lit(1)) * (x + lit(1)) * (x + lit(1)); + auto result = cubed.eval({{var("x"), int64_t{2}}}); + EXPECT(result == scalar{int64_t{27}}); +} + +TEST_CASE(norm_collect_multi_var) +{ + auto x = var("x"); + auto y = var("y"); + // 2*x*y + 3*x*y == 5*x*y + EXPECT(lit(2) * x * y + lit(3) * x * y == lit(5) * x * y); +} + +TEST_CASE(norm_collect_mixed) +{ + auto x = var("x"); + auto y = var("y"); + // x + y + 2*x + 3*y == 3*x + 4*y + EXPECT(x + y + lit(2) * x + lit(3) * y == lit(3) * x + lit(4) * y); +} + +TEST_CASE(norm_nested_distribute) +{ + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + // a*(b+c) == a*b + a*c + EXPECT(a * (b + c) == a * b + a * c); +} + +TEST_CASE(norm_three_binomial) +{ + auto x = var("x"); + auto y = var("y"); + auto z = var("z"); + // (x+y)*(y+z) == x*y + x*z + y*y + y*z + EXPECT((x + y) * (y + z) == x * y + x * z + y * y + y * z); +} + +TEST_CASE(norm_subtract_expanded) +{ + auto x = var("x"); + auto y = var("y"); + // (x+y)*(x+y) - (x*x + y*y) == 2*x*y + EXPECT((x + y) * (x + y) - (x * x + y * y) == lit(2) * x * y); +} + +TEST_CASE(norm_negate_sum) +{ + auto x = var("x"); + auto y = var("y"); + // -(x + y) == -x + -y == -x - y + EXPECT(-(x + y) == -x - y); +} + +TEST_CASE(norm_double_negate) +{ + auto x = var("x"); + // -(-x) == x + EXPECT(-(-x) == x); // cppcheck-suppress migraphx-MultipleUnaryOperator +} + +TEST_CASE(norm_coefficient_fold) +{ + auto x = var("x"); + // 2 * 3 * x == 6 * x + EXPECT(lit(2) * lit(3) * x == lit(6) * x); +} + +TEST_CASE(norm_constant_add_in_sum) +{ + auto x = var("x"); + // (x + 3) + 5 == x + 8 + EXPECT(x + lit(3) + lit(5) == x + lit(8)); +} + +TEST_CASE(norm_zero_product_sum) +{ + auto x = var("x"); + auto y = var("y"); + // x*y - x*y == 0 + EXPECT(x * y - x * y == lit(0)); +} + +// ---- Division normalization tests ---- + +TEST_CASE(norm_div_identity) +{ + auto x = var("x"); + EXPECT(x / lit(1) == x); +} + +TEST_CASE(norm_div_zero_numerator) +{ + auto x = var("x"); + EXPECT(lit(0) / x == lit(0)); + EXPECT(lit(0) / (x + lit(1)) == lit(0)); +} + +TEST_CASE(norm_div_self) +{ + auto x = var("x"); + EXPECT(x / x == lit(1)); + EXPECT((x + lit(1)) / (x + lit(1)) == lit(1)); +} + +TEST_CASE(norm_div_cancel_symbolic_factor) +{ + auto h = var("h"); + auto w = var("w"); + EXPECT(lit(2) * h / h == lit(2)); + EXPECT(h * w / h == w); + EXPECT(h * w / w == h); + EXPECT(lit(3) * h * w / h == lit(3) * w); + EXPECT(lit(3) * h * w / (h * w) == lit(3)); +} + +TEST_CASE(norm_div_cancel_coefficient) +{ + auto n = var("n"); + EXPECT((lit(6) * n) / lit(3) == lit(2) * n); + EXPECT((lit(6) * n) / lit(2) == lit(3) * n); +} + +TEST_CASE(norm_div_cancel_mixed) +{ + auto h = var("h"); + auto w = var("w"); + EXPECT(h * lit(6) * w / (lit(3) * w) == lit(2) * h); + EXPECT(h * h * w / (h * w) == h); +} + +TEST_CASE(norm_div_cancel_partial) +{ + auto h = var("h"); + auto w = var("w"); + EXPECT(lit(5) * h * w / (lit(2) * h) == lit(5) * w / lit(2)); + EXPECT(h * h * w / (lit(2) * h) == h * w / lit(2)); +} + +TEST_CASE(norm_div_cancel_cross_factor) +{ + auto h = var("h"); + auto w = var("w"); + auto c = var("c"); + EXPECT(h * w / (h * c) == w / c); + EXPECT(h * w / (h * h) == w / h); +} + +TEST_CASE(norm_div_distribute_over_sum) +{ + auto h = var("h"); + auto w = var("w"); + EXPECT((lit(2) * h + lit(4)) / lit(2) == h + lit(2)); + EXPECT((lit(6) * h + lit(3) * w + lit(9)) / lit(3) == lit(2) * h + w + lit(3)); + EXPECT((lit(4) * h + lit(2)) / lit(2) == lit(2) * h + lit(1)); +} + +TEST_CASE(norm_div_no_distribute_not_all_divisible) +{ + auto h = var("h"); + // (2*h + 3) / 2: 3 is not divisible by 2, so no distribution + auto r = (lit(2) * h + lit(3)) / lit(2); + EXPECT(r != h); +} + +TEST_CASE(norm_div_constant_folding) +{ + EXPECT(lit(7) / lit(2) == lit(3)); + EXPECT(lit(6) / lit(3) == lit(2)); + EXPECT(lit(0) / lit(5) == lit(0)); +} + +// ---- Rewrite DSL tests ---- + +TEST_CASE(dsl_pvar_match) +{ + auto _1 = pvar(1); // NOLINT(readability-identifier-naming) + auto x = var("x"); + // sqrt(x) matches sqrt(_1) + auto result = simplify(sqrt(x), {sqrt(_1) >> _1}); + EXPECT(result == x); +} + +TEST_CASE(dsl_consistent_binding) +{ + auto _1 = pvar(1); // NOLINT(readability-identifier-naming) + auto x = var("x"); + auto y = var("y"); + // _1 * _1 matches x*x but not x*y + auto rule = _1 * _1 >> _1; + EXPECT(simplify(x * x, {rule}) == x); + EXPECT(simplify(x * y, {rule}) == x * y); +} + +TEST_CASE(dsl_log_exp) +{ + auto _1 = pvar(1); // NOLINT(readability-identifier-naming) + auto x = var("x"); + auto result = simplify(log(exp(x)), {log(exp(_1)) >> _1}); + EXPECT(result == x); +} + +TEST_CASE(dsl_exp_log) +{ + auto _1 = pvar(1); // NOLINT(readability-identifier-naming) + auto x = var("x"); + auto result = simplify(exp(log(x)), {exp(log(_1)) >> _1}); + EXPECT(result == x); +} + +TEST_CASE(dsl_sqrt_product) +{ + auto _1 = pvar(1); // NOLINT(readability-identifier-naming) + auto _2 = pvar(2); // NOLINT(readability-identifier-naming) + auto a = var("a"); + auto b = var("b"); + auto result = simplify(sqrt(a * b), {sqrt(_1 * _2) >> sqrt(_1) * sqrt(_2)}); + EXPECT(result == sqrt(a) * sqrt(b)); +} + +TEST_CASE(dsl_sqrt_division) +{ + auto _1 = pvar(1); // NOLINT(readability-identifier-naming) + auto _2 = pvar(2); // NOLINT(readability-identifier-naming) + auto a = var("a"); + auto b = var("b"); + auto result = simplify(sqrt(a / b), {sqrt(_1 / _2) >> sqrt(_1) / sqrt(_2)}); + EXPECT(result == sqrt(a) / sqrt(b)); +} + +TEST_CASE(dsl_recursive) +{ + auto _1 = pvar(1); // NOLINT(readability-identifier-naming) + auto x = var("x"); + auto y = var("y"); + // Rule applied to subexpressions: log(exp(x)) + log(exp(y)) + auto e = log(exp(x)) + log(exp(y)); + auto result = simplify(e, {log(exp(_1)) >> _1}); + EXPECT(result == x + y); +} + +TEST_CASE(dsl_multiple_rules) +{ + auto _1 = pvar(1); // NOLINT(readability-identifier-naming) + auto _2 = pvar(2); // NOLINT(readability-identifier-naming) + auto x = var("x"); + auto y = var("y"); + // Chain: pow(x,2) → x*x, then abs(x*x) → x*x (already positive) + auto result = + simplify(abs(pow(x, y)), {pow(_1, _2) >> _1 * _2, abs(_1 * _2) >> abs(_1) * abs(_2)}); + EXPECT(result == abs(x) * abs(y)); +} + +TEST_CASE(dsl_trig_identity) +{ + auto _1 = pvar(1); // NOLINT(readability-identifier-naming) + auto x = var("x"); + // sin(x)^2 + cos(x)^2 == 1 + auto e = sin(x) * sin(x) + cos(x) * cos(x); + auto result = simplify(e, {(sin(_1) * sin(_1) + cos(_1) * cos(_1)) >> lit(1)}); + EXPECT(result == lit(1)); +} + +TEST_CASE(dsl_no_match) +{ + auto _1 = pvar(1); // NOLINT(readability-identifier-naming) + auto x = var("x"); + // Rule doesn't match, expression unchanged + auto result = simplify(sin(x), {log(exp(_1)) >> _1}); + EXPECT(result == sin(x)); +} + +TEST_CASE(dsl_literal_pattern) +{ + auto _1 = pvar(1); // NOLINT(readability-identifier-naming) + auto x = var("x"); + auto result = simplify(pow(x, lit(2)), {pow(_1, lit(2)) >> _1 * _1}); + EXPECT(result == x * x); +} + +TEST_CASE(dsl_eval_after_simplify) +{ + auto _1 = pvar(1); // NOLINT(readability-identifier-naming) + auto _2 = pvar(2); // NOLINT(readability-identifier-naming) + auto x = var("x"); + auto y = var("y"); + auto e = sqrt(x * y); + auto result = simplify(e, {sqrt(_1 * _2) >> sqrt(_1) * sqrt(_2)}); + // sqrt(4) * sqrt(9) = 2 * 3 = 6 + EXPECT(result.eval({{var("x"), 4.0}, {var("y"), 9.0}}) == scalar{6.0}); +} + +TEST_CASE(dsl_chained_simplify) +{ + auto _1 = pvar(1); // NOLINT(readability-identifier-naming) + auto x = var("x"); + // exp(log(exp(log(x)))) with repeated rule application + auto e = exp(log(exp(log(x)))); + auto result = simplify(e, {exp(log(_1)) >> _1, log(exp(_1)) >> _1}); + EXPECT(result == x); +} + +TEST_CASE(dsl_nested_subexpr) +{ + auto _1 = pvar(1); // NOLINT(readability-identifier-naming) + auto _2 = pvar(2); // NOLINT(readability-identifier-naming) + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + auto d = var("d"); + // sqrt(a*b) + sqrt(c*d): rule applied to both subexprs + auto e = sqrt(a * b) + sqrt(c * d); + auto result = simplify(e, {sqrt(_1 * _2) >> sqrt(_1) * sqrt(_2)}); + EXPECT(result == sqrt(a) * sqrt(b) + sqrt(c) * sqrt(d)); +} + +// ---- Built-in rewrite rule tests ---- + +TEST_CASE(builtin_sqrt_product) +{ + auto a = var("a"); + auto b = var("b"); + // sqrt(a*b) automatically rewrites to sqrt(a)*sqrt(b) + EXPECT(sqrt(a * b) == sqrt(a) * sqrt(b)); +} + +TEST_CASE(builtin_sqrt_division) +{ + auto a = var("a"); + auto b = var("b"); + // sqrt(a/b) automatically rewrites to sqrt(a)/sqrt(b) + EXPECT(sqrt(a / b) == sqrt(a) / sqrt(b)); +} + +TEST_CASE(builtin_log_exp) +{ + auto x = var("x"); + // log(exp(x)) automatically simplifies to x + EXPECT(log(exp(x)) == x); +} + +TEST_CASE(builtin_exp_log) +{ + auto x = var("x"); + // exp(log(x)) automatically simplifies to x + EXPECT(exp(log(x)) == x); +} + +TEST_CASE(builtin_sqrt_product_eval) +{ + auto a = var("a"); + auto b = var("b"); + // sqrt(a*b) == sqrt(a)*sqrt(b), verify eval + auto e = sqrt(a * b); + EXPECT(e.eval({{var("a"), 4.0}, {var("b"), 9.0}}) == scalar{6.0}); +} + +TEST_CASE(builtin_log_exp_nested) +{ + auto x = var("x"); + auto y = var("y"); + // log(exp(x)) + log(exp(y)) automatically simplifies to x + y + EXPECT(log(exp(x)) + log(exp(y)) == x + y); +} + +TEST_CASE(builtin_raw_no_leak) +{ + auto x = var("x"); + // Ensure raw flag doesn't leak into normal expressions + EXPECT(not x.is_raw()); + EXPECT(not(x + lit(1)).is_raw()); + EXPECT(not sqrt(x).is_raw()); +} + +TEST_CASE(builtin_pvar_is_raw) +{ + auto _1 = pvar(1); // NOLINT(readability-identifier-naming) + EXPECT(_1.is_raw()); + // Expressions built from pvars are raw + EXPECT((_1 * pvar(2)).is_raw()); + EXPECT(sqrt(_1).is_raw()); +} + +// ---- Parse tests ---- + +TEST_CASE(parse_integer) +{ + auto e = parse("42"); + EXPECT(e == lit(42)); +} + +TEST_CASE(parse_double) +{ + auto e = parse("3.14"); + EXPECT(e == lit(3.14)); +} + +TEST_CASE(parse_variable) +{ + auto e = parse("x"); + EXPECT(e == var("x")); +} + +TEST_CASE(parse_add) +{ + auto e = parse("x + y"); + EXPECT(e == var("x") + var("y")); +} + +TEST_CASE(parse_sub) +{ + auto e = parse("x - y"); + EXPECT(e == var("x") - var("y")); +} + +TEST_CASE(parse_mul) +{ + auto e = parse("x * y"); + EXPECT(e == var("x") * var("y")); +} + +TEST_CASE(parse_div) +{ + auto e = parse("x / y"); + EXPECT(e == var("x") / var("y")); +} + +TEST_CASE(parse_mod) +{ + auto e = parse("x % y"); + EXPECT(e == var("x") % var("y")); +} + +TEST_CASE(parse_precedence) +{ + auto e = parse("x + y * z"); + EXPECT(e == var("x") + var("y") * var("z")); +} + +TEST_CASE(parse_precedence_left) +{ + auto e = parse("x * y + z"); + EXPECT(e == var("x") * var("y") + var("z")); +} + +TEST_CASE(parse_parens) +{ + auto e = parse("(x + y) * z"); + EXPECT(e == (var("x") + var("y")) * var("z")); +} + +TEST_CASE(parse_nested_parens) +{ + auto e = parse("((x))"); + EXPECT(e == var("x")); +} + +TEST_CASE(parse_unary_neg) +{ + auto e = parse("-x"); + EXPECT(e == -var("x")); +} + +TEST_CASE(parse_neg_in_expr) +{ + auto e = parse("x + -y"); + EXPECT(e == var("x") + (-var("y"))); +} + +TEST_CASE(parse_function_sin) +{ + auto e = parse("sin(x)"); + EXPECT(e == sin(var("x"))); +} + +TEST_CASE(parse_function_cos) +{ + auto e = parse("cos(x)"); + EXPECT(e == cos(var("x"))); +} + +TEST_CASE(parse_function_sqrt) +{ + auto e = parse("sqrt(x)"); + EXPECT(e == sqrt(var("x"))); +} + +TEST_CASE(parse_function_exp) +{ + auto e = parse("exp(x)"); + EXPECT(e == exp(var("x"))); +} + +TEST_CASE(parse_function_log) +{ + auto e = parse("log(x)"); + EXPECT(e == log(var("x"))); +} + +TEST_CASE(parse_function_abs) +{ + auto e = parse("abs(x)"); + EXPECT(e == abs(var("x"))); +} + +TEST_CASE(parse_function_floor) +{ + auto e = parse("floor(x)"); + EXPECT(e == floor(var("x"))); +} + +TEST_CASE(parse_function_ceil) +{ + auto e = parse("ceil(x)"); + EXPECT(e == ceil(var("x"))); +} + +TEST_CASE(parse_function_tan) +{ + auto e = parse("tan(x)"); + EXPECT(e == tan(var("x"))); +} + +TEST_CASE(parse_function_pow) +{ + auto e = parse("pow(x, y)"); + EXPECT(e == pow(var("x"), var("y"))); +} + +TEST_CASE(parse_function_min) +{ + auto e = parse("min(x, y)"); + EXPECT(e == min(var("x"), var("y"))); +} + +TEST_CASE(parse_function_max) +{ + auto e = parse("max(x, y)"); + EXPECT(e == max(var("x"), var("y"))); +} + +TEST_CASE(parse_nested_functions) +{ + auto e = parse("sqrt(x * x + y * y)"); + EXPECT(e == sqrt(var("x") * var("x") + var("y") * var("y"))); +} + +TEST_CASE(parse_complex_expr) +{ + auto e = parse("sin(x) * cos(y) + 1"); + EXPECT(e == sin(var("x")) * cos(var("y")) + lit(1)); +} + +TEST_CASE(parse_whitespace_handling) +{ + auto e = parse(" x + y "); + EXPECT(e == var("x") + var("y")); +} + +TEST_CASE(parse_no_whitespace) +{ + auto e = parse("x+y*z"); + EXPECT(e == var("x") + var("y") * var("z")); +} + +TEST_CASE(parse_literal_arithmetic) +{ + auto e = parse("2 + 3"); + EXPECT(e == lit(5)); +} + +TEST_CASE(parse_roundtrip) +{ + auto x = var("x"); + auto y = var("y"); + auto original = sin(x) + cos(y) * lit(2); + auto str = to_string(original); + auto parsed = parse(str); + EXPECT(parsed == original); +} + +// ---- Serialization tests (to_value / from_value) ---- + +TEST_CASE(serialize_interval_int) +{ + interval i{int64_t{3}, int64_t{10}}; + auto v = migraphx::to_value(i); + auto i2 = migraphx::from_value(v); + EXPECT(i == i2); +} + +TEST_CASE(serialize_interval_double) +{ + interval i{1.5, 3.5}; + auto v = migraphx::to_value(i); + auto i2 = migraphx::from_value(v); + EXPECT(i == i2); +} + +TEST_CASE(serialize_interval_mixed) +{ + interval i{int64_t{0}, 5.5}; + auto v = migraphx::to_value(i); + auto i2 = migraphx::from_value(v); + EXPECT(i == i2); +} + +TEST_CASE(serialize_expr_literal_int) +{ + auto e = lit(42); + auto v = migraphx::to_value(e); + auto e2 = migraphx::from_value(v); + EXPECT(e == e2); +} + +TEST_CASE(serialize_expr_literal_double) +{ + auto e = lit(3.14); + auto v = migraphx::to_value(e); + auto e2 = migraphx::from_value(v); + EXPECT(e == e2); +} + +TEST_CASE(serialize_expr_variable) +{ + auto e = var("x"); + auto v = migraphx::to_value(e); + auto e2 = migraphx::from_value(v); + EXPECT(e == e2); +} + +TEST_CASE(serialize_expr_add) +{ + auto e = var("x") + lit(1); + auto v = migraphx::to_value(e); + auto e2 = migraphx::from_value(v); + EXPECT(e == e2); +} + +TEST_CASE(serialize_expr_compound) +{ + auto e = (var("x") + lit(1)) * var("y"); + auto v = migraphx::to_value(e); + auto e2 = migraphx::from_value(v); + EXPECT(e == e2); +} + +TEST_CASE(serialize_expr_function) +{ + auto e = sin(var("x")); + auto v = migraphx::to_value(e); + auto e2 = migraphx::from_value(v); + EXPECT(e == e2); +} + +TEST_CASE(serialize_expr_nested_function) +{ + auto e = sqrt(var("x") + lit(1)); + auto v = migraphx::to_value(e); + auto e2 = migraphx::from_value(v); + EXPECT(e == e2); +} + +TEST_CASE(serialize_expr_empty) +{ + expr e; + auto v = migraphx::to_value(e); + auto e2 = migraphx::from_value(v); + EXPECT(e2.empty()); +} + +TEST_CASE(serialize_expr_eval_preserved) +{ + auto e = var("x") * lit(2) + lit(3); + auto v = migraphx::to_value(e); + auto e2 = migraphx::from_value(v); + EXPECT(e2.eval({{var("x"), int64_t{5}}}) == scalar{int64_t{13}}); +} + +TEST_CASE(serialize_expr_mod) +{ + auto e = var("x") % lit(3); + auto v = migraphx::to_value(e); + auto e2 = migraphx::from_value(v); + EXPECT(e == e2); +} + +TEST_CASE(serialize_expr_variable_constraint) +{ + auto e = var("x", interval{int64_t{0}, int64_t{10}}); + auto v = migraphx::to_value(e); + auto e2 = migraphx::from_value(v); + EXPECT(e == e2); +} + +TEST_CASE(serialize_expr_variable_multiple_constraints) +{ + auto c1 = interval{int64_t{0}, int64_t{10}}; + auto c2 = interval{int64_t{20}, int64_t{30}}; + // Build expression with two constraints via serialization round-trip + auto e1 = var("x", c1); + auto v1 = migraphx::to_value(e1); + // Manually add a second constraint to the serialized form + auto constraints = migraphx::from_value>(v1.at("constraints")); + constraints.push_back(c2); + v1.at("constraints") = migraphx::to_value(constraints); + auto e2 = migraphx::from_value(v1); + // The deserialized expr should have both constraints, not just the last one + auto v2 = migraphx::to_value(e2); + auto result_constraints = migraphx::from_value>(v2.at("constraints")); + EXPECT(result_constraints.size() == 2); + EXPECT(result_constraints[0] == c1); + EXPECT(result_constraints[1] == c2); +} + +TEST_CASE(serialize_expr_variable_optimals) +{ + auto e = var("x", interval{int64_t{0}, int64_t{10}}, std::set{int64_t{1}, int64_t{5}}); + auto v = migraphx::to_value(e); + auto e2 = migraphx::from_value(v); + EXPECT(e == e2); +} + +// ---- eval_optimals tests ---- + +TEST_CASE(var_with_optimals) +{ + auto x = var("x", interval{}, std::set{int64_t{1}, int64_t{2}, int64_t{3}}); + EXPECT(x.name() == "variable"); + EXPECT(x.eval({{x, int64_t{5}}}) == scalar{int64_t{5}}); +} + +TEST_CASE(var_with_constraint_and_optimals) +{ + auto x = var("x", interval{int64_t{0}, int64_t{10}}, std::set{int64_t{1}, int64_t{5}}); + EXPECT(x.name() == "variable"); + EXPECT(x.eval({{x, int64_t{3}}}) == scalar{int64_t{3}}); +} + +TEST_CASE(eval_optimals_literal) +{ + auto e = lit(42); + auto result = e.eval_optimals(); + EXPECT(result == std::set{int64_t{42}}); +} + +TEST_CASE(eval_optimals_single_var) +{ + auto x = var("x", interval{}, std::set{int64_t{1}, int64_t{2}, int64_t{3}}); + auto result = x.eval_optimals(); + EXPECT(result == std::set{int64_t{1}, int64_t{2}, int64_t{3}}); +} + +TEST_CASE(eval_optimals_var_plus_lit) +{ + auto x = var("x", interval{}, std::set{int64_t{1}, int64_t{2}}); + auto e = x + lit(10); + auto result = e.eval_optimals(); + EXPECT(result == std::set{int64_t{11}, int64_t{12}}); +} + +TEST_CASE(eval_optimals_var_times_lit) +{ + auto x = var("x", interval{}, std::set{int64_t{2}, int64_t{3}, int64_t{5}}); + auto e = x * lit(2); + auto result = e.eval_optimals(); + EXPECT(result == std::set{int64_t{4}, int64_t{6}, int64_t{10}}); +} + +TEST_CASE(eval_optimals_two_vars) +{ + auto x = var("x", interval{}, std::set{int64_t{1}, int64_t{2}}); + auto y = var("y", interval{}, std::set{int64_t{10}, int64_t{20}}); + auto e = x + y; + auto result = e.eval_optimals(); + // Cartesian product: (1,10), (1,20), (2,10), (2,20) -> 11, 21, 12, 22 + EXPECT(result == std::set{int64_t{11}, int64_t{12}, int64_t{21}, int64_t{22}}); +} + +TEST_CASE(eval_optimals_two_vars_multiply) +{ + auto x = var("x", interval{}, std::set{int64_t{2}, int64_t{3}}); + auto y = var("y", interval{}, std::set{int64_t{5}, int64_t{7}}); + auto e = x * y; + auto result = e.eval_optimals(); + // 2*5=10, 2*7=14, 3*5=15, 3*7=21 + EXPECT(result == std::set{int64_t{10}, int64_t{14}, int64_t{15}, int64_t{21}}); +} + +TEST_CASE(eval_optimals_compound) +{ + auto x = var("x", interval{}, std::set{int64_t{1}, int64_t{2}}); + auto e = x * x + lit(1); + auto result = e.eval_optimals(); + // x=1: 1*1+1=2, x=2: 2*2+1=5 + EXPECT(result == std::set{int64_t{2}, int64_t{5}}); +} + +TEST_CASE(eval_optimals_duplicate_results) +{ + auto x = var("x", interval{}, std::set{int64_t{1}, int64_t{-1}}); + auto e = x * x; + auto result = e.eval_optimals(); + // Both 1*1 and (-1)*(-1) = 1, so result should have single element + EXPECT(result == std::set{int64_t{1}}); +} + +TEST_CASE(eval_optimals_var_no_optimals) +{ + auto x = var("x"); + auto e = x + lit(1); + test::throws([&] { e.eval_optimals(); }); +} + +TEST_CASE(eval_optimals_double_values) +{ + auto x = var("x", interval{}, std::set{1.0, 2.0, 3.0}); + auto e = x * lit(0.5); + auto result = e.eval_optimals(); + EXPECT(result == std::set{0.5, 1.0, 1.5}); +} + +TEST_CASE(eval_optimals_three_vars) +{ + auto x = var("x", interval{}, std::set{int64_t{1}, int64_t{2}}); + auto y = var("y", interval{}, std::set{int64_t{10}}); + auto z = var("z", interval{}, std::set{int64_t{100}, int64_t{200}}); + auto e = x + y + z; + auto result = e.eval_optimals(); + // x=1,y=10,z=100: 111, x=1,y=10,z=200: 211, x=2,y=10,z=100: 112, x=2,y=10,z=200: 212 + EXPECT(result == std::set{int64_t{111}, int64_t{112}, int64_t{211}, int64_t{212}}); +} + +TEST_CASE(eval_optimals_h_squared) +{ + auto h = var("h", interval{}, std::set{int64_t{2}, int64_t{3}}); + auto e = h * h; + auto result = e.eval_optimals(); + // Same variable: paired, not cross-producted: 2*2=4, 3*3=9 + EXPECT(result == std::set{int64_t{4}, int64_t{9}}); +} + +TEST_CASE(eval_optimals_h_cubed) +{ + auto h = var("h", interval{}, std::set{int64_t{2}, int64_t{3}}); + auto e = h * h * h; + auto result = e.eval_optimals(); + // Same variable: 2^3=8, 3^3=27 + EXPECT(result == std::set{int64_t{8}, int64_t{27}}); +} + +TEST_CASE(eval_optimals_h_plus_one_times_h) +{ + auto h = var("h", interval{}, std::set{int64_t{2}, int64_t{3}}); + auto e = (h + lit(1)) * h; + auto result = e.eval_optimals(); + // Same variable across subtrees: h=2: (2+1)*2=6, h=3: (3+1)*3=12 + EXPECT(result == std::set{int64_t{6}, int64_t{12}}); +} + +TEST_CASE(eval_optimals_h_squared_plus_w) +{ + auto h = var("h", interval{}, std::set{int64_t{2}, int64_t{3}}); + auto w = var("w", interval{}, std::set{int64_t{5}, int64_t{7}}); + auto e = h * h + w; + auto result = e.eval_optimals(); + // h*h paired: {4, 9}; cross with w: 4+5=9, 4+7=11, 9+5=14, 9+7=16 + EXPECT(result == std::set{int64_t{9}, int64_t{11}, int64_t{14}, int64_t{16}}); +} + +TEST_CASE(eval_optimals_h_times_w_with_lit) +{ + auto h = var("h", interval{}, std::set{int64_t{2}, int64_t{3}}); + auto w = var("w", interval{}, std::set{int64_t{5}, int64_t{7}}); + auto e = h * w * lit(10); + auto result = e.eval_optimals(); + // Different vars cross-product, then *10: (2*5, 2*7, 3*5, 3*7)*10 = (100,140,150,210) + EXPECT(result == std::set{int64_t{100}, int64_t{140}, int64_t{150}, int64_t{210}}); +} + +TEST_CASE(eval_optimals_independent_subtrees) +{ + auto x = var("x", interval{}, std::set{int64_t{2}, int64_t{3}}); + auto y = var("y", interval{}, std::set{int64_t{4}, int64_t{5}}); + auto e = (x + lit(1)) * (y + lit(1)); + auto result = e.eval_optimals(); + // (x+1) ranges over {3,4}, (y+1) ranges over {5,6}; cartesian product: 15,18,20,24 + EXPECT(result == std::set{int64_t{15}, int64_t{18}, int64_t{20}, int64_t{24}}); +} + +TEST_CASE(eval_optimals_shared_subexpr) +{ + auto x = var("x", interval{}, std::set{int64_t{2}, int64_t{3}}); + auto sub = x + lit(1); + auto e = sub * sub; + auto result = e.eval_optimals(); + // Shared subexpression with same x: x=2: 3*3=9, x=3: 4*4=16 + EXPECT(result == std::set{int64_t{9}, int64_t{16}}); +} + +TEST_CASE(eval_optimals_x_plus_x_times_x) +{ + auto x = var("x", interval{}, std::set{int64_t{1}, int64_t{2}, int64_t{3}}); + auto e = (x + x) * x; + auto result = e.eval_optimals(); + // Same variable everywhere: x=1: 2*1=2, x=2: 4*2=8, x=3: 6*3=18 + EXPECT(result == std::set{int64_t{2}, int64_t{8}, int64_t{18}}); +} + +TEST_CASE(eval_optimals_two_squares_summed) +{ + auto x = var("x", interval{}, std::set{int64_t{2}, int64_t{3}}); + auto y = var("y", interval{}, std::set{int64_t{5}, int64_t{7}}); + auto e = x * x + y * y; + auto result = e.eval_optimals(); + // x*x in {4,9} crossed with y*y in {25,49}: 29, 53, 34, 58 + EXPECT(result == std::set{int64_t{29}, int64_t{34}, int64_t{53}, int64_t{58}}); +} + +TEST_CASE(eval_optimals_lit_times_var) +{ + auto x = var("x", interval{}, std::set{int64_t{2}, int64_t{3}, int64_t{5}}); + auto e = lit(3) * x; + auto result = e.eval_optimals(); + // Single variable scaled by literal: 6, 9, 15 + EXPECT(result == std::set{int64_t{6}, int64_t{9}, int64_t{15}}); +} + +TEST_CASE(eval_optimals_three_vars_disjoint) +{ + auto x = var("x", interval{}, std::set{int64_t{1}, int64_t{2}}); + auto y = var("y", interval{}, std::set{int64_t{3}, int64_t{4}}); + auto z = var("z", interval{}, std::set{int64_t{10}, int64_t{20}}); + auto e = x * y + z; + auto result = e.eval_optimals(); + // Disjoint vars -> full cartesian: x*y in {3,4,6,8} crossed with z in {10,20} + EXPECT(result == std::set{int64_t{13}, + int64_t{14}, + int64_t{16}, + int64_t{18}, + int64_t{23}, + int64_t{24}, + int64_t{26}, + int64_t{28}}); +} + +TEST_CASE(eval_optimals_three_vars_x_shared) +{ + auto x = var("x", interval{}, std::set{int64_t{1}, int64_t{2}}); + auto y = var("y", interval{}, std::set{int64_t{3}, int64_t{4}}); + auto z = var("z", interval{}, std::set{int64_t{5}, int64_t{6}}); + auto e = x * y + x * z; + auto result = e.eval_optimals(); + // x is shared between the two terms, so y and z range independently per x. + // x=1: (1*y)+(1*z) for y in {3,4}, z in {5,6} -> {3+5, 3+6, 4+5, 4+6} = {8, 9, 9, 10} + // x=2: (2*y)+(2*z) for y in {3,4}, z in {5,6} -> {6+10, 6+12, 8+10, 8+12} = {16,18,18,20} + EXPECT(result == + std::set{ + int64_t{8}, int64_t{9}, int64_t{10}, int64_t{16}, int64_t{18}, int64_t{20}}); +} + +TEST_CASE(eval_optimals_four_vars_disjoint) +{ + auto x = var("x", interval{}, std::set{int64_t{1}, int64_t{2}}); + auto y = var("y", interval{}, std::set{int64_t{3}}); + auto z = var("z", interval{}, std::set{int64_t{5}}); + auto w = var("w", interval{}, std::set{int64_t{7}, int64_t{8}}); + auto e = x * y + z * w; + auto result = e.eval_optimals(); + // x*y in {3, 6} crossed with z*w in {35, 40}: 38, 43, 41, 46 + EXPECT(result == std::set{int64_t{38}, int64_t{41}, int64_t{43}, int64_t{46}}); +} + +TEST_CASE(eval_optimals_four_vars_x_shared_three_ways) +{ + auto x = var("x", interval{}, std::set{int64_t{1}, int64_t{2}}); + auto y = var("y", interval{}, std::set{int64_t{3}}); + auto z = var("z", interval{}, std::set{int64_t{5}}); + auto w = var("w", interval{}, std::set{int64_t{7}}); + auto e = x * y + x * z + x * w; + auto result = e.eval_optimals(); + // x appears in every term, so all terms must agree: x*(3+5+7) = x*15 -> {15, 30} + EXPECT(result == std::set{int64_t{15}, int64_t{30}}); +} + +TEST_CASE(eval_optimals_four_vars_two_pairs_shared) +{ + auto a = var("a", interval{}, std::set{int64_t{1}, int64_t{2}}); + auto b = var("b", interval{}, std::set{int64_t{3}, int64_t{4}}); + auto c = var("c", interval{}, std::set{int64_t{10}, int64_t{20}}); + auto d = var("d", interval{}, std::set{int64_t{100}}); + auto e = a * b + a * c + b * d; + auto result = e.eval_optimals(); + // a is shared in first two terms; b is shared with the third; c and d are independent. + // Effective: a*(b+c) + b*d, with a in {1,2}, b in {3,4}, c in {10,20}, d=100. + // a=1,b=3,c=10: 1*(3+10) + 3*100 = 13 + 300 = 313 + // a=1,b=3,c=20: 1*(3+20) + 3*100 = 23 + 300 = 323 + // a=1,b=4,c=10: 1*(4+10) + 4*100 = 14 + 400 = 414 + // a=1,b=4,c=20: 1*(4+20) + 4*100 = 24 + 400 = 424 + // a=2,b=3,c=10: 2*(3+10) + 3*100 = 26 + 300 = 326 + // a=2,b=3,c=20: 2*(3+20) + 3*100 = 46 + 300 = 346 + // a=2,b=4,c=10: 2*(4+10) + 4*100 = 28 + 400 = 428 + // a=2,b=4,c=20: 2*(4+20) + 4*100 = 48 + 400 = 448 + EXPECT(result == std::set{int64_t{313}, + int64_t{323}, + int64_t{326}, + int64_t{346}, + int64_t{414}, + int64_t{424}, + int64_t{428}, + int64_t{448}}); +} + +// ---- division by zero tests ---- + +TEST_CASE(div_by_zero_int_throws) +{ + test::throws([&] { var("x") / 0; }); +} + +TEST_CASE(div_by_zero_double_throws) +{ + test::throws([&] { var("x") / 0.0; }); +} + +TEST_CASE(div_by_zero_lit_throws) +{ + test::throws([&] { lit(5) / lit(0); }); +} + +// ---- ceiling division tests ---- + +TEST_CASE(ceildiv_eval) +{ + auto x = var("x"); + auto y = var("y"); + auto e = (x + y - 1) / y; + // ceil(7/3) = 3: (7+3-1)/3 = 9/3 = 3 + EXPECT(e.eval({{x, int64_t{7}}, {y, int64_t{3}}}) == scalar{int64_t{3}}); + // ceil(8/4) = 2: (8+4-1)/4 = 11/4 = 2 + EXPECT(e.eval({{x, int64_t{8}}, {y, int64_t{4}}}) == scalar{int64_t{2}}); + // ceil(9/4) = 3: (9+4-1)/4 = 12/4 = 3 + EXPECT(e.eval({{x, int64_t{9}}, {y, int64_t{4}}}) == scalar{int64_t{3}}); + // ceil(1/1) = 1: (1+1-1)/1 = 1 + EXPECT(e.eval({{x, int64_t{1}}, {y, int64_t{1}}}) == scalar{int64_t{1}}); + // ceil(10/5) = 2: (10+5-1)/5 = 14/5 = 2 + EXPECT(e.eval({{x, int64_t{10}}, {y, int64_t{5}}}) == scalar{int64_t{2}}); + // ceil(11/5) = 3: (11+5-1)/5 = 15/5 = 3 + EXPECT(e.eval({{x, int64_t{11}}, {y, int64_t{5}}}) == scalar{int64_t{3}}); +} + +TEST_CASE(ceildiv_to_string) +{ + auto x = var("x"); + auto y = var("y"); + auto e = (x + y - 1) / y; + auto s = e.to_string(); + // Just verify it produces something reasonable and roundtrips eval + EXPECT(not s.empty()); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/sym_test.cpp b/test/sym_test.cpp index f822b0e773f..1be575cc11d 100644 --- a/test/sym_test.cpp +++ b/test/sym_test.cpp @@ -516,11 +516,12 @@ TEST_CASE(eval_non_symbol_key_throws) EXPECT(test::throws([&] { h.eval_uint({{h + 1, 10}}); })); } -TEST_CASE(subs_non_symbol_key_throws) +TEST_CASE(subs_non_symbol_key_unchanged) { auto h = var("h"); - EXPECT(test::throws([&] { h.subs({{h + 1, lit(5)}}); })); - EXPECT(test::throws([&] { h.subs({{lit(3), lit(5)}}); })); + // Non-matching keys leave the expression unchanged + EXPECT(h.subs({{h + 1, lit(5)}}) == h); + EXPECT(h.subs({{lit(3), lit(5)}}) == h); } TEST_CASE(subs_partial)