From d2b684e1546e4092031db9bf5745c73881cf4f67 Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 24 Mar 2026 11:18:38 -0700 Subject: [PATCH 01/60] custom symbolic expression lib --- src/CMakeLists.txt | 1 + src/include/migraphx/symbolic.hpp | 85 +++ src/symbolic.cpp | 925 ++++++++++++++++++++++++++++++ test/symbolic_test.cpp | 652 +++++++++++++++++++++ 4 files changed, 1663 insertions(+) create mode 100644 src/include/migraphx/symbolic.hpp create mode 100644 src/symbolic.cpp create mode 100644 test/symbolic_test.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index bd0083dfe6f..9905c198719 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -136,6 +136,7 @@ add_library(migraphx serialize.cpp shape.cpp shape_transform_descriptor.cpp + symbolic.cpp simplify_algebra.cpp simplify_dyn_ops.cpp simplify_reshapes.cpp diff --git a/src/include/migraphx/symbolic.hpp b/src/include/migraphx/symbolic.hpp new file mode 100644 index 00000000000..c7edac2aee1 --- /dev/null +++ b/src/include/migraphx/symbolic.hpp @@ -0,0 +1,85 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_SYMBOLIC_HPP +#define MIGRAPHX_GUARD_MIGRAPHLIB_SYMBOLIC_HPP + +#include +#include +#include +#include +#include + +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct value; + +struct MIGRAPHX_EXPORT symbolic_expr +{ + symbolic_expr(); + explicit symbolic_expr(std::size_t n); + explicit symbolic_expr(const std::string& s); + + bool empty() const; + std::string to_string() const; + std::size_t eval(const std::map& symbol_map) const; + symbolic_expr subs(const std::map& symbol_map) const; + + MIGRAPHX_EXPORT friend symbolic_expr operator+(const symbolic_expr& a, + const symbolic_expr& b); + MIGRAPHX_EXPORT friend symbolic_expr operator-(const symbolic_expr& a, + const symbolic_expr& b); + MIGRAPHX_EXPORT friend symbolic_expr operator*(const symbolic_expr& a, + const symbolic_expr& b); + MIGRAPHX_EXPORT friend symbolic_expr operator/(const symbolic_expr& a, + const symbolic_expr& b); + MIGRAPHX_EXPORT friend bool operator==(const symbolic_expr& a, const symbolic_expr& b); + MIGRAPHX_EXPORT friend bool operator!=(const symbolic_expr& a, const symbolic_expr& b); + MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const symbolic_expr& e); + + struct impl; + + private: + symbolic_expr(std::shared_ptr pi); + std::shared_ptr p; +}; + +inline symbolic_expr operator+(const symbolic_expr& a, std::size_t b) { return a + symbolic_expr(b); } +inline symbolic_expr operator+(std::size_t a, const symbolic_expr& b) { return symbolic_expr(a) + b; } +inline symbolic_expr operator-(const symbolic_expr& a, std::size_t b) { return a - symbolic_expr(b); } +inline symbolic_expr operator-(std::size_t a, const symbolic_expr& b) { return symbolic_expr(a) - b; } +inline symbolic_expr operator*(const symbolic_expr& a, std::size_t b) { return a * symbolic_expr(b); } +inline symbolic_expr operator*(std::size_t a, const symbolic_expr& b) { return symbolic_expr(a) * b; } +inline symbolic_expr operator/(const symbolic_expr& a, std::size_t b) { return a / symbolic_expr(b); } +inline symbolic_expr operator/(std::size_t a, const symbolic_expr& b) { return symbolic_expr(a) / b; } + +MIGRAPHX_EXPORT void migraphx_to_value(value& v, const symbolic_expr& e); +MIGRAPHX_EXPORT void migraphx_from_value(const value& v, symbolic_expr& e); + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/symbolic.cpp b/src/symbolic.cpp new file mode 100644 index 00000000000..98480a0d92f --- /dev/null +++ b/src/symbolic.cpp @@ -0,0 +1,925 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +// =================================================================== +// Section 1: Expression node types +// =================================================================== + +struct expr_node; +using expr_ptr = std::shared_ptr; + +struct expr_compare +{ + bool operator()(const expr_ptr& a, const expr_ptr& b) const; +}; + +using term_map = std::map; +using factor_map = std::map; + +struct integer_data +{ + int64_t value; +}; +struct symbol_data +{ + std::string name; +}; +struct add_data +{ + int64_t constant; + term_map terms; +}; +struct mul_data +{ + int64_t coefficient; + factor_map factors; +}; +struct fdiv_data +{ + expr_ptr numerator; + expr_ptr denominator; +}; + +using expr_data = std::variant; + +constexpr int kind_integer = 0; +constexpr int kind_symbol = 1; +constexpr int kind_add = 2; +constexpr int kind_mul = 3; +constexpr int kind_fdiv = 4; + +struct expr_node +{ + expr_data data; + std::size_t cached_hash = 0; + + int kind() const { return static_cast(data.index()); } +}; + +static bool is_integer(const expr_ptr& e) { return e->kind() == kind_integer; } +static bool is_symbol(const expr_ptr& e) { return e->kind() == kind_symbol; } +static bool is_add(const expr_ptr& e) { return e->kind() == kind_add; } +static bool is_mul(const expr_ptr& e) { return e->kind() == kind_mul; } +static bool is_fdiv(const expr_ptr& e) { return e->kind() == kind_fdiv; } + +static int64_t get_integer(const expr_ptr& e) { return std::get(e->data).value; } +static const std::string& get_symbol_name(const expr_ptr& e) +{ + return std::get(e->data).name; +} +static const add_data& get_add(const expr_ptr& e) { return std::get(e->data); } +static const mul_data& get_mul(const expr_ptr& e) { return std::get(e->data); } +static const fdiv_data& get_fdiv(const expr_ptr& e) { return std::get(e->data); } + +// =================================================================== +// Section 2: Hash computation +// =================================================================== + +static std::size_t hash_combine(std::size_t seed, std::size_t v) +{ + return seed ^ (v + 0x9e3779b9 + (seed << 6) + (seed >> 2)); +} + +template +static std::size_t hash_ordered_map(const Map& m) +{ + std::size_t h = 0; + for(const auto& [key, val] : m) + { + h = hash_combine(h, key->cached_hash); + h = hash_combine(h, std::hash{}(val)); + } + return h; +} + +static std::size_t compute_hash(const expr_data& d) +{ + std::size_t h = std::hash{}(static_cast(d.index())); + if(auto* p = std::get_if(&d)) + return hash_combine(h, std::hash{}(p->value)); + if(auto* p = std::get_if(&d)) + return hash_combine(h, std::hash{}(p->name)); + if(auto* p = std::get_if(&d)) + { + h = hash_combine(h, std::hash{}(p->constant)); + return hash_combine(h, hash_ordered_map(p->terms)); + } + if(auto* p = std::get_if(&d)) + { + h = hash_combine(h, std::hash{}(p->coefficient)); + return hash_combine(h, hash_ordered_map(p->factors)); + } + if(auto* p = std::get_if(&d)) + { + h = hash_combine(h, p->numerator->cached_hash); + return hash_combine(h, p->denominator->cached_hash); + } + return h; +} + +// =================================================================== +// Section 3: Canonical ordering (expr_compare) +// =================================================================== + +static int compare_expr(const expr_ptr& a, const expr_ptr& b); + +template +static int compare_maps(const Map& a, const Map& b) +{ + auto it_a = a.begin(); + auto it_b = b.begin(); + for(; it_a != a.end() and it_b != b.end(); ++it_a, ++it_b) + { + int c = compare_expr(it_a->first, it_b->first); + if(c != 0) + return c; + if(it_a->second < it_b->second) + return -1; + if(it_a->second > it_b->second) + return 1; + } + if(it_a != a.end()) + return 1; + if(it_b != b.end()) + return -1; + return 0; +} + +static int compare_expr(const expr_ptr& a, const expr_ptr& b) +{ + if(a->kind() != b->kind()) + return a->kind() < b->kind() ? -1 : 1; + + switch(a->kind()) + { + case kind_integer: { + auto va = get_integer(a); + auto vb = get_integer(b); + return (va < vb) ? -1 : (va > vb) ? 1 : 0; + } + case kind_symbol: { + const auto& na = get_symbol_name(a); + const auto& nb = get_symbol_name(b); + return na.compare(nb); + } + case kind_add: { + const auto& da = get_add(a); + const auto& db = get_add(b); + if(da.constant != db.constant) + return da.constant < db.constant ? -1 : 1; + return compare_maps(da.terms, db.terms); + } + case kind_mul: { + const auto& da = get_mul(a); + const auto& db = get_mul(b); + if(da.coefficient != db.coefficient) + return da.coefficient < db.coefficient ? -1 : 1; + return compare_maps(da.factors, db.factors); + } + case kind_fdiv: { + const auto& da = get_fdiv(a); + const auto& db = get_fdiv(b); + int c = compare_expr(da.numerator, db.numerator); + if(c != 0) + return c; + return compare_expr(da.denominator, db.denominator); + } + default: return 0; + } +} + +bool expr_compare::operator()(const expr_ptr& a, const expr_ptr& b) const +{ + return compare_expr(a, b) < 0; +} + +// =================================================================== +// Section 4: Structural equality +// =================================================================== + +static bool expr_equal(const expr_ptr& a, const expr_ptr& b) +{ + if(a.get() == b.get()) + return true; + if(a->cached_hash != b->cached_hash) + return false; + return compare_expr(a, b) == 0; +} + +// =================================================================== +// Section 5: Factory functions (canonical constructors) +// =================================================================== + +static expr_ptr make_node(expr_data d) +{ + auto n = std::make_shared(); + n->data = std::move(d); + n->cached_hash = compute_hash(n->data); + return n; +} + +static const expr_ptr& const_zero() +{ + static auto p = make_node(integer_data{0}); + return p; +} +static const expr_ptr& const_one() +{ + static auto p = make_node(integer_data{1}); + return p; +} +static const expr_ptr& const_neg_one() +{ + static auto p = make_node(integer_data{-1}); + return p; +} + +static expr_ptr make_integer(int64_t n) +{ + if(n == 0) + return const_zero(); + if(n == 1) + return const_one(); + if(n == -1) + return const_neg_one(); + return make_node(integer_data{n}); +} + +static expr_ptr make_symbol(const std::string& name) { return make_node(symbol_data{name}); } + +static expr_ptr make_add(const expr_ptr& a, const expr_ptr& b); +static expr_ptr make_sub(const expr_ptr& a, const expr_ptr& b); +static expr_ptr make_neg(const expr_ptr& a); +static expr_ptr make_mul(const expr_ptr& a, const expr_ptr& b); +static expr_ptr make_floor_div(const expr_ptr& a, const expr_ptr& b); +static expr_ptr build_mul(int64_t coefficient, factor_map factors); + +struct add_parts +{ + int64_t constant = 0; + term_map terms; +}; + +static add_parts extract_add(const expr_ptr& e) +{ + if(is_integer(e)) + return {get_integer(e), {}}; + if(is_add(e)) + { + const auto& d = get_add(e); + return {d.constant, d.terms}; + } + if(is_mul(e)) + { + const auto& d = get_mul(e); + auto base = build_mul(1, d.factors); + return {0, {{base, d.coefficient}}}; + } + return {0, {{e, 1}}}; +} + +static expr_ptr build_add(int64_t constant, term_map terms) +{ + // Remove zero-coefficient terms + for(auto it = terms.begin(); it != terms.end();) + { + if(it->second == 0) + it = terms.erase(it); + else + ++it; + } + if(terms.empty()) + return make_integer(constant); + if(constant == 0 and terms.size() == 1) + { + auto& [term, coeff] = *terms.begin(); + if(coeff == 1) + return term; + return make_mul(make_integer(coeff), term); + } + return make_node(add_data{constant, std::move(terms)}); +} + +static expr_ptr make_add(const expr_ptr& a, const expr_ptr& b) +{ + auto pa = extract_add(a); + auto pb = extract_add(b); + + int64_t constant = pa.constant + pb.constant; + term_map terms = std::move(pa.terms); + for(const auto& [term, coeff] : pb.terms) + terms[term] += coeff; + + return build_add(constant, std::move(terms)); +} + +static expr_ptr make_neg(const expr_ptr& a) +{ + if(is_integer(a)) + return make_integer(-get_integer(a)); + if(is_add(a)) + { + const auto& d = get_add(a); + term_map negated; + for(const auto& [term, coeff] : d.terms) + negated[term] = -coeff; + return build_add(-d.constant, std::move(negated)); + } + if(is_mul(a)) + { + const auto& d = get_mul(a); + return make_node(mul_data{-d.coefficient, d.factors}); + } + return make_mul(make_integer(-1), a); +} + +static expr_ptr make_sub(const expr_ptr& a, const expr_ptr& b) +{ + return make_add(a, make_neg(b)); +} + +struct mul_parts +{ + int64_t coefficient = 1; + factor_map factors; +}; + +static mul_parts extract_mul(const expr_ptr& e) +{ + if(is_integer(e)) + return {get_integer(e), {}}; + if(is_mul(e)) + { + const auto& d = get_mul(e); + return {d.coefficient, d.factors}; + } + return {1, {{e, 1}}}; +} + +static expr_ptr build_mul(int64_t coefficient, factor_map factors) +{ + if(coefficient == 0) + return make_integer(0); + for(auto it = factors.begin(); it != factors.end();) + { + if(it->second == 0) + it = factors.erase(it); + else + ++it; + } + if(factors.empty()) + return make_integer(coefficient); + if(coefficient == 1 and factors.size() == 1) + { + auto& [base, exp] = *factors.begin(); + if(exp == 1) + return base; + } + return make_node(mul_data{coefficient, std::move(factors)}); +} + +static expr_ptr make_mul(const expr_ptr& a, const expr_ptr& b) +{ + // Special case: multiplying an integer by an Add → scale the Add + if(is_integer(a) and is_add(b)) + { + int64_t n = get_integer(a); + if(n == 0) + return make_integer(0); + if(n == 1) + return b; + const auto& d = get_add(b); + term_map scaled; + for(const auto& [term, coeff] : d.terms) + scaled[term] = coeff * n; + return build_add(d.constant * n, std::move(scaled)); + } + if(is_integer(b) and is_add(a)) + return make_mul(b, a); + + auto pa = extract_mul(a); + auto pb = extract_mul(b); + + int64_t coefficient = pa.coefficient * pb.coefficient; + if(coefficient == 0) + return make_integer(0); + + factor_map factors = std::move(pa.factors); + for(const auto& [base, exp] : pb.factors) + factors[base] += exp; + + return build_mul(coefficient, std::move(factors)); +} + +static expr_ptr make_floor_div(const expr_ptr& a, const expr_ptr& b) +{ + if(is_integer(b)) + { + int64_t den = get_integer(b); + if(den == 0) + MIGRAPHX_THROW("symbolic: division by zero"); + if(den == 1) + return a; + if(is_integer(a)) + return make_integer(get_integer(a) / den); + // Cancel exact coefficient in Mul: (c*factors)/den where c%den==0 + if(is_mul(a)) + { + const auto& d = get_mul(a); + if(d.coefficient % den == 0) + return build_mul(d.coefficient / den, d.factors); + } + } + + return make_node(fdiv_data{a, b}); +} + +// =================================================================== +// Section 6: Substitution and evaluation +// =================================================================== + +// Partial substitution: replaces bound symbols with integers and re-canonicalizes. +// Unbound symbols are left as-is, producing a simplified symbolic expression. +static expr_ptr substitute(const expr_ptr& e, + const std::map& bindings) +{ + switch(e->kind()) + { + case kind_integer: return e; + case kind_symbol: { + auto it = bindings.find(get_symbol_name(e)); + if(it != bindings.end()) + return make_integer(it->second); + return e; + } + case kind_add: { + const auto& d = get_add(e); + expr_ptr result = make_integer(d.constant); + for(const auto& [term, coeff] : d.terms) + { + auto st = substitute(term, bindings); + result = make_add(result, make_mul(make_integer(coeff), st)); + } + return result; + } + case kind_mul: { + const auto& d = get_mul(e); + expr_ptr result = make_integer(d.coefficient); + for(const auto& [base, exp] : d.factors) + { + auto sb = substitute(base, bindings); + for(int64_t i = 0; i < exp; ++i) + result = make_mul(result, sb); + } + return result; + } + case kind_fdiv: { + const auto& d = get_fdiv(e); + auto sn = substitute(d.numerator, bindings); + auto sd = substitute(d.denominator, bindings); + return make_floor_div(sn, sd); + } + default: return e; + } +} + +// Full evaluation: computes integer result directly without allocations. +// All symbols must be bound; throws if any symbol is unbound. +static int64_t eval_direct(const expr_ptr& e, + const std::map& symbol_map) +{ + switch(e->kind()) + { + case kind_integer: return get_integer(e); + case kind_symbol: { + auto it = symbol_map.find(get_symbol_name(e)); + if(it != symbol_map.end()) + return static_cast(it->second); + MIGRAPHX_THROW("symbolic_expr::eval: unbound symbol '" + get_symbol_name(e) + "'"); + } + case kind_add: { + const auto& d = get_add(e); + int64_t sum = d.constant; + for(const auto& [term, coeff] : d.terms) + sum += coeff * eval_direct(term, symbol_map); + return sum; + } + case kind_mul: { + const auto& d = get_mul(e); + int64_t prod = d.coefficient; + for(const auto& [base, exp] : d.factors) + { + int64_t val = eval_direct(base, symbol_map); + for(int64_t i = 0; i < exp; ++i) + prod *= val; + } + return prod; + } + case kind_fdiv: { + const auto& d = get_fdiv(e); + return eval_direct(d.numerator, symbol_map) / eval_direct(d.denominator, symbol_map); + } + default: return 0; + } +} + +static int64_t evaluate(const expr_ptr& e, + const std::map& symbol_map) +{ + return eval_direct(e, symbol_map); +} + +// =================================================================== +// Section 7: Pretty-printer +// =================================================================== + +enum +{ + prec_atom = 100, + prec_mul = 50, + prec_add = 40 +}; + +static std::string print_expr(const expr_ptr& e, int parent_prec = 0); + +static std::string print_expr(const expr_ptr& e, int parent_prec) +{ + switch(e->kind()) + { + case kind_integer: return std::to_string(get_integer(e)); + case kind_symbol: return get_symbol_name(e); + case kind_add: { + const auto& d = get_add(e); + std::ostringstream os; + bool first = true; + for(const auto& [term, coeff] : d.terms) + { + if(first) + { + if(coeff == -1) + os << "-" << print_expr(term, prec_add); + else if(coeff == 1) + os << print_expr(term, prec_add); + else + os << coeff << "*" << print_expr(term, prec_mul + 1); + first = false; + } + else + { + if(coeff == 1) + os << " + " << print_expr(term, prec_add); + else if(coeff == -1) + os << " - " << print_expr(term, prec_add); + else if(coeff > 0) + os << " + " << coeff << "*" << print_expr(term, prec_mul + 1); + else + os << " - " << (-coeff) << "*" << print_expr(term, prec_mul + 1); + } + } + if(d.constant > 0) + os << " + " << d.constant; + else if(d.constant < 0) + os << " - " << (-d.constant); + std::string s = os.str(); + if(parent_prec > prec_add) + return "(" + s + ")"; + return s; + } + case kind_mul: { + const auto& d = get_mul(e); + std::ostringstream os; + bool first = true; + if(d.coefficient == -1) + { + os << "-"; + } + else if(d.coefficient != 1) + { + os << d.coefficient; + first = false; + } + for(const auto& [base, exp] : d.factors) + { + if(not first) + os << "*"; + if(exp == 1) + os << print_expr(base, prec_mul + 1); + else + os << print_expr(base, prec_mul + 1) << "**" << exp; + first = false; + } + std::string raw = os.str(); + if(parent_prec > prec_mul) + return "(" + raw + ")"; + return raw; + } + case kind_fdiv: { + const auto& d = get_fdiv(e); + std::string s = print_expr(d.numerator, prec_mul + 1) + "/" + + print_expr(d.denominator, prec_mul + 1); + if(parent_prec > prec_mul) + return "(" + s + ")"; + return s; + } + default: return "?"; + } +} + +// =================================================================== +// Section 8: Recursive descent parser +// =================================================================== + +static void skip_ws(const char*& p) +{ + while(*p and std::isspace(static_cast(*p))) + ++p; +} + +static expr_ptr parse_expr(const char*& p); +static expr_ptr parse_term(const char*& p); +static expr_ptr parse_unary(const char*& p); +static expr_ptr parse_primary(const char*& p); + +static expr_ptr parse_primary(const char*& p) +{ + skip_ws(p); + if(std::isdigit(static_cast(*p))) + { + int64_t n = 0; + while(std::isdigit(static_cast(*p))) + { + n = n * 10 + (*p - '0'); + ++p; + } + return make_integer(n); + } + if(std::isalpha(static_cast(*p)) or *p == '_') + { + std::string name; + while(std::isalnum(static_cast(*p)) or *p == '_') + { + name += *p; + ++p; + } + if(name == "floor") + { + skip_ws(p); + if(*p != '(') + MIGRAPHX_THROW("symbolic parser: expected '(' after 'floor'"); + ++p; + auto inner = parse_expr(p); + skip_ws(p); + if(*p != ')') + MIGRAPHX_THROW("symbolic parser: expected ')' after floor argument"); + ++p; + return inner; + } + return make_symbol(name); + } + if(*p == '(') + { + ++p; + auto inner = parse_expr(p); + skip_ws(p); + if(*p != ')') + MIGRAPHX_THROW("symbolic parser: expected ')'"); + ++p; + return inner; + } + MIGRAPHX_THROW("symbolic parser: unexpected character '" + std::string(1, *p) + "'"); +} + +static expr_ptr parse_unary(const char*& p) +{ + skip_ws(p); + if(*p == '-') + { + ++p; + return make_neg(parse_unary(p)); + } + return parse_primary(p); +} + +static expr_ptr parse_term(const char*& p) +{ + auto left = parse_unary(p); + for(;;) + { + skip_ws(p); + if(*p == '*') + { + ++p; + left = make_mul(left, parse_unary(p)); + } + else if(*p == '/') + { + ++p; + left = make_floor_div(left, parse_unary(p)); + } + else + break; + } + return left; +} + +static expr_ptr parse_expr(const char*& p) +{ + auto left = parse_term(p); + for(;;) + { + skip_ws(p); + if(*p == '+') + { + ++p; + left = make_add(left, parse_term(p)); + } + else if(*p == '-') + { + ++p; + left = make_sub(left, parse_term(p)); + } + else + break; + } + return left; +} + +static expr_ptr parse_string(const std::string& s) +{ + const char* p = s.c_str(); + auto result = parse_expr(p); + skip_ws(p); + if(*p != '\0') + MIGRAPHX_THROW("symbolic parser: unexpected trailing characters: '" + std::string(p) + "'"); + return result; +} + +// =================================================================== +// Section 9: symbolic_expr public API wrapper +// =================================================================== + +struct symbolic_expr::impl +{ + expr_ptr node; + + impl() : node(make_integer(0)) {} + explicit impl(expr_ptr e) : node(std::move(e)) {} +}; + +symbolic_expr::symbolic_expr() = default; + +symbolic_expr::symbolic_expr(std::shared_ptr pi) : p(std::move(pi)) {} + +symbolic_expr::symbolic_expr(std::size_t n) + : p(std::make_shared(make_integer(static_cast(n)))) +{ +} + +symbolic_expr::symbolic_expr(const std::string& s) +{ + if(s.empty()) + return; + p = std::make_shared(parse_string(s)); +} + +bool symbolic_expr::empty() const { return p == nullptr; } + +std::string symbolic_expr::to_string() const +{ + if(empty()) + return {}; + return print_expr(p->node); +} + +std::size_t symbolic_expr::eval(const std::map& symbol_map) const +{ + if(empty()) + return 0; + auto v = evaluate(p->node, symbol_map); + assert(v >= 0 && "symbolic dimension evaluated to negative value"); + return static_cast(v); +} + +symbolic_expr symbolic_expr::subs(const std::map& symbol_map) const +{ + if(empty()) + return {}; + std::map bindings; + for(const auto& [k, v] : symbol_map) + bindings[k] = static_cast(v); + auto result = substitute(p->node, bindings); + return {std::make_shared(std::move(result))}; +} + +symbolic_expr operator+(const symbolic_expr& a, const symbolic_expr& b) +{ + if(a.empty() and b.empty()) + return {}; + auto ea = a.p ? a.p->node : make_integer(0); + auto eb = b.p ? b.p->node : make_integer(0); + return {std::make_shared(make_add(ea, eb))}; +} + +symbolic_expr operator-(const symbolic_expr& a, const symbolic_expr& b) +{ + if(a.empty() and b.empty()) + return {}; + auto ea = a.p ? a.p->node : make_integer(0); + auto eb = b.p ? b.p->node : make_integer(0); + return {std::make_shared(make_sub(ea, eb))}; +} + +symbolic_expr operator*(const symbolic_expr& a, const symbolic_expr& b) +{ + if(a.empty() and b.empty()) + return {}; + auto ea = a.p ? a.p->node : make_integer(0); + auto eb = b.p ? b.p->node : make_integer(0); + return {std::make_shared(make_mul(ea, eb))}; +} + +symbolic_expr operator/(const symbolic_expr& a, const symbolic_expr& b) +{ + if(a.empty() and b.empty()) + return {}; + auto ea = a.p ? a.p->node : make_integer(0); + auto eb = b.p ? b.p->node : make_integer(0); + return {std::make_shared(make_floor_div(ea, eb))}; +} + +bool operator==(const symbolic_expr& a, const symbolic_expr& b) +{ + if(a.empty() and b.empty()) + return true; + if(a.empty() != b.empty()) + return false; + return expr_equal(a.p->node, b.p->node); +} + +bool operator!=(const symbolic_expr& a, const symbolic_expr& b) { return not(a == b); } + +std::ostream& operator<<(std::ostream& os, const symbolic_expr& e) +{ + if(not e.empty()) + os << print_expr(e.p->node); + return os; +} + +void migraphx_to_value(value& v, const symbolic_expr& e) +{ + v = migraphx::to_value(e.to_string()); +} + +void migraphx_from_value(const value& v, symbolic_expr& e) +{ + auto s = v.get_string(); + if(s.empty()) + e = symbolic_expr{}; + else + e = symbolic_expr(s); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/test/symbolic_test.cpp b/test/symbolic_test.cpp new file mode 100644 index 00000000000..df0528b5192 --- /dev/null +++ b/test/symbolic_test.cpp @@ -0,0 +1,652 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include "test.hpp" + +using se = migraphx::symbolic_expr; + +// =================================================================== +// Tier 1: Expression construction and canonicalization +// =================================================================== + +TEST_CASE(construct_integer) +{ + EXPECT(se(0).to_string() == "0"); + EXPECT(se(1).to_string() == "1"); + EXPECT(se(42).to_string() == "42"); +} + +TEST_CASE(construct_symbol) +{ + EXPECT(se("H").to_string() == "H"); + EXPECT(se("batch_size").to_string() == "batch_size"); +} + +TEST_CASE(construct_empty) +{ + se e; + EXPECT(e.empty()); + EXPECT(e.to_string().empty()); +} + +TEST_CASE(add_identity) +{ + EXPECT(se("H") + 0 == se("H")); + EXPECT(0 + se("H") == se("H")); +} + +TEST_CASE(add_commutativity) +{ + EXPECT(se("H") + se("W") == se("W") + se("H")); +} + +TEST_CASE(add_like_term_folding) +{ + auto r = se("H") + se("H"); + EXPECT(r.to_string() == "2*H"); +} + +TEST_CASE(add_constant_folding) +{ + EXPECT(se(3) + se(5) == se(8)); + EXPECT(se(0) + se(0) == se(0)); +} + +TEST_CASE(add_flattening) +{ + auto a = (se("H") + se("W")) + se("C"); + auto b = se("H") + (se("W") + se("C")); + auto c = (se("C") + se("H")) + se("W"); + EXPECT(a == b); + EXPECT(b == c); +} + +TEST_CASE(add_mixed) +{ + auto r = se("H") + 3 + se("H") + 2; + EXPECT(r == 2 * se("H") + 5); + auto r2 = se("H") + se("H"); + EXPECT(r2 + 5 == 2 * se("H") + 5); +} + +TEST_CASE(add_cancellation) +{ + auto neg_h = se("-H"); + EXPECT(se("H") + neg_h == se(0)); +} + +TEST_CASE(sub_identity) +{ + EXPECT(se("H") - 0 == se("H")); +} + +TEST_CASE(sub_self) +{ + EXPECT(se("H") - se("H") == se(0)); +} + +TEST_CASE(sub_constant_folding) +{ + EXPECT(se(10) - se(3) == se(7)); +} + +TEST_CASE(sub_produces_negation) +{ + EXPECT(se("-H").to_string() == "-H"); + EXPECT((3 - se("H")).to_string() == "-H + 3"); +} + +TEST_CASE(neg_integer) +{ + auto r = se(0) - 5; + EXPECT(r == se(-5)); +} + +TEST_CASE(neg_double_negation) +{ + auto r = 0 - se("-H"); + EXPECT(r == se("H")); +} + +TEST_CASE(mul_identity) +{ + EXPECT(se("H") * 1 == se("H")); + EXPECT(1 * se("H") == se("H")); +} + +TEST_CASE(mul_zero) +{ + EXPECT(se("H") * 0 == se(0)); + EXPECT(0 * se("H") == se(0)); +} + +TEST_CASE(mul_constant_folding) +{ + EXPECT(se(3) * se(7) == se(21)); +} + +TEST_CASE(mul_commutativity) +{ + EXPECT(se("B") * se("A") == se("A") * se("B")); +} + +TEST_CASE(mul_coefficient_accumulation) +{ + auto r = 2 * se("H") * 3; + EXPECT(r.to_string() == "6*H"); +} + +TEST_CASE(mul_flattening) +{ + auto a = (se("H") * se("W")) * se("C"); + auto b = se("H") * (se("W") * se("C")); + auto c = (se("C") * se("H")) * se("W"); + EXPECT(a == b); + EXPECT(b == c); +} + +TEST_CASE(mul_distributive) +{ + auto r = 2 * (se("H") + 1); + EXPECT(r == 2 * se("H") + 2); +} + +TEST_CASE(mul_symbolic_times_add_no_distribution) +{ + auto r = se("N") * (se("H") + 1); + EXPECT(r != se("N") * se("H") + se("N")); +} + +TEST_CASE(fdiv_identity) +{ + EXPECT(se("H") / 1 == se("H")); +} + +TEST_CASE(fdiv_constant_folding) +{ + EXPECT(se(7) / se(2) == se(3)); + EXPECT(se(6) / se(3) == se(2)); + EXPECT(se(0) / se(5) == se(0)); +} + +TEST_CASE(fdiv_exact_coefficient_cancel) +{ + auto r = (6 * se("N")) / 3; + EXPECT(r.to_string() == "2*N"); +} + +TEST_CASE(fdiv_non_simplifiable) +{ + auto r = (se("H") - 1) / 2; + EXPECT(r.to_string() == "(H - 1)/2"); +} + +TEST_CASE(fdiv_division_by_zero) +{ + EXPECT(test::throws([&] { se("H") / 0; })); +} + +TEST_CASE(add_scaled_subtraction) +{ + EXPECT(2 * se("H") - se("H") == se("H")); + EXPECT(3 * se("H") - 2 * se("H") == se("H")); + EXPECT(se("H") + se("H") + se("H") == 3 * se("H")); +} + +TEST_CASE(add_of_two_adds) +{ + auto r = (se("H") + 1) + (se("H") + 2); + EXPECT(r == 2 * se("H") + 3); +} + +TEST_CASE(sub_strip_constant) +{ + EXPECT((se("H") + 1) - se("H") == se(1)); +} + +TEST_CASE(sub_of_two_adds) +{ + auto r = (se("H") + 1) - (se("H") + 2); + EXPECT(r == se(-1)); +} + +TEST_CASE(mul_zero_propagation) +{ + auto z = se("H") - se("H"); + EXPECT(50 * z == se(0)); +} + +TEST_CASE(add_chain_constant_cancel) +{ + auto r = se(2) - se("H") - se(2); + EXPECT(r == se("-H")); +} + +TEST_CASE(neg_of_sum_distributes) +{ + auto r = se(-1) * (se("H") + 1); + EXPECT(r == se("-H") - 1); +} + +TEST_CASE(neg_of_product_double) +{ + auto hw = se("H") * se("W"); + auto neg = 0 - hw; + EXPECT(neg.to_string() == "-H*W"); + auto dbl = 0 - neg; + EXPECT(dbl == hw); +} + +TEST_CASE(add_compound_product_like_terms) +{ + auto hw = se("H") * se("W"); + auto wh = se("W") * se("H"); + EXPECT(hw + hw == 2 * hw); + EXPECT(hw + wh == 2 * hw); + EXPECT(hw + 2 * hw == 3 * hw); +} + +TEST_CASE(add_compound_product_cancellation) +{ + auto hw = se("H") * se("W"); + EXPECT(hw - hw == se(0)); +} + +// X*Y and X cancel pairwise: (X*Y - X) - (X*Y - X) == 0 +TEST_CASE(sub_compound_product_mixed) +{ + auto xy = se("X") * se("Y"); + auto r = xy - se("X") - xy + se("X"); + EXPECT(r == se(0)); +} + +// Duplicate A*B terms fold even when separated by another term +TEST_CASE(add_multi_term_accumulation) +{ + auto r = se("A") * se("B") + se("C") + se("A") * se("B"); + auto expected = 2 * (se("A") * se("B")) + se("C"); + EXPECT(r == expected); +} + +TEST_CASE(fdiv_negative_constant_folding) +{ + EXPECT(se(-7) / se(2) == se(-7 / 2)); + EXPECT(se(-6) / se(3) == se(-2)); + EXPECT(se(7) / se(-2) == se(7 / -2)); +} + +TEST_CASE(fdiv_large_constants) +{ + EXPECT(se(1000000) / se(1000) == se(1000)); + EXPECT(se(999999) / se(1000) == se(999)); +} + +// =================================================================== +// Tier 2: Equality and hashing +// =================================================================== + +TEST_CASE(eq_different_values) +{ + EXPECT(se("H") + 1 != se("H") + 2); + EXPECT(se("H") != se("W")); + EXPECT(se(3) != se(4)); +} + +TEST_CASE(eq_empty) +{ + EXPECT(se{} == se{}); + EXPECT(se{} != se(0)); + EXPECT(se(0) != se{}); +} + + +// =================================================================== +// Tier 3: Evaluation and substitution +// =================================================================== + +TEST_CASE(eval_simple) +{ + EXPECT(se("H").eval({{"H", 26}}) == 26); + EXPECT(se(42).eval({}) == 42); +} + +TEST_CASE(eval_arithmetic) +{ + EXPECT((se("H") - 3).eval({{"H", 26}}) == 23); + EXPECT((se("H") + 5).eval({{"H", 10}}) == 15); + EXPECT((2 * se("H")).eval({{"H", 13}}) == 26); +} + +TEST_CASE(eval_compound) +{ + auto e = (se("H") - 3) / 2 + 1; + EXPECT(e.eval({{"H", 26}}) == 12); + EXPECT(e.eval({{"H", 27}}) == 13); +} + +TEST_CASE(eval_multiple_symbols) +{ + auto e = se("N") * se("H"); + EXPECT(e.eval({{"N", 4}, {"H", 26}}) == 104); +} + +TEST_CASE(eval_floor_division) +{ + auto e = (se("H") - 1) / 2; + EXPECT(e.eval({{"H", 7}}) == 3); + EXPECT(e.eval({{"H", 8}}) == 3); + EXPECT(e.eval({{"H", 9}}) == 4); +} + +TEST_CASE(eval_unbound_throws) +{ + EXPECT(test::throws([&] { se("H").eval({}); })); + EXPECT(test::throws([&] { (se("H") + se("W")).eval({{"H", 1}}); })); +} + +TEST_CASE(eval_integer_expr) +{ + EXPECT(se(0).eval({}) == 0); + EXPECT(se(100).eval({}) == 100); +} + +TEST_CASE(subs_partial) +{ + auto e = se("N") * se("H") + 1; + auto r = e.subs({{"N", 4}}); + EXPECT(r == 4 * se("H") + 1); + EXPECT(r.eval({{"H", 10}}) == 41); +} + +TEST_CASE(subs_full) +{ + auto e = se("H") + 1; + auto r = e.subs({{"H", 5}}); + EXPECT(r == se(6)); +} + +TEST_CASE(subs_none) +{ + auto e = se("H"); + EXPECT(e.subs({}) == se("H")); +} + +TEST_CASE(subs_floor_div) +{ + auto e = (se("H") - 1) / 2; + auto r = e.subs({{"H", 7}}); + EXPECT(r == se(3)); +} + +// eval() and subs()+eval() must agree on a compound expression +TEST_CASE(subs_eval_cross_validation) +{ + auto e = (se("N") * se("H") - 3) / 2 + 1; + std::map m = {{"N", 4}, {"H", 26}}; + auto via_eval = e.eval(m); + auto via_subs = e.subs(m).eval({}); + EXPECT(via_eval == via_subs); +} + +TEST_CASE(subs_empty) +{ + se e; + auto r = e.subs({{"H", 5}}); + EXPECT(r.empty()); +} + +TEST_CASE(subs_creates_like_terms) +{ + auto e = se("H") + se("W"); + auto r = e.subs({{"W", 0}}); + EXPECT(r == se("H")); +} + +TEST_CASE(eval_compound_product) +{ + auto e = se("H") * se("W") + 1; + EXPECT(e.eval({{"H", 3}, {"W", 4}}) == 13); +} + +TEST_CASE(eval_negative_intermediate) +{ + auto e = (se("H") - 10) * 2 + 20; + EXPECT(e.eval({{"H", 3}}) == 6); +} + +// =================================================================== +// Tier 4: Printing and parsing +// =================================================================== + +TEST_CASE(print_atoms) +{ + EXPECT(se(42).to_string() == "42"); + EXPECT(se("H").to_string() == "H"); + EXPECT(se(0).to_string() == "0"); + EXPECT(se(-3).to_string() == "-3"); +} + +TEST_CASE(print_add) +{ + EXPECT((se("H") + 1).to_string() == "H + 1"); + EXPECT((se("H") - 3).to_string() == "H - 3"); +} + +TEST_CASE(print_mul) +{ + EXPECT((2 * se("H")).to_string() == "2*H"); + auto r = se("A") * se("B"); + EXPECT(r.to_string() == "A*B"); +} + +TEST_CASE(print_fdiv_parens) +{ + auto r = (se("H") - 1) / 2; + EXPECT(r.to_string() == "(H - 1)/2"); +} + +TEST_CASE(print_compound) +{ + auto r = (se("H") - 3) / 2 + 1; + EXPECT(r.to_string() == "(H - 3)/2 + 1"); +} + +TEST_CASE(parse_atoms) +{ + EXPECT(se("42") == se(42)); + EXPECT(se("H") == se("H")); +} + +TEST_CASE(parse_arithmetic) +{ + auto r = se("H + 1"); + EXPECT(r == se("H") + 1); + + auto r2 = se("H - 3"); + EXPECT(r2 == se("H") - 3); + + auto r3 = se("2*H"); + EXPECT(r3 == 2 * se("H")); +} + +TEST_CASE(parse_precedence) +{ + auto r = se("H + 1 * 2"); + EXPECT(r == se("H") + 2); +} + +TEST_CASE(parse_parentheses) +{ + auto r = se("(H + 1) * 2"); + EXPECT(r == 2 * (se("H") + 1)); +} + +TEST_CASE(parse_division) +{ + auto r = se("(H - 1)/2"); + EXPECT(r == (se("H") - 1) / 2); +} + +TEST_CASE(parse_unary_minus) +{ + EXPECT(se("-H") == se("-H")); + EXPECT(se("-H").to_string() == "-H"); + EXPECT(se("-(H + 1)") == se("-H") - 1); +} + +// Legacy floor() wrapper is accepted by parser and treated as no-op +TEST_CASE(parse_floor_backward_compat) +{ + auto a = se("floor((H-1)/2)"); + auto b = se("(H-1)/2"); + EXPECT(a == b); + + auto c = se("floor((H-1)/2) + 1"); + auto d = (se("H") - 1) / 2 + 1; + EXPECT(c == d); +} + +TEST_CASE(parse_whitespace_tolerance) +{ + EXPECT(se(" H + 1 ") == se("H + 1")); + EXPECT(se("H+1") == se("H + 1")); +} + +TEST_CASE(print_negative_mul_coefficient) +{ + auto r = 0 - 3 * se("H"); + EXPECT(r.to_string() == "-3*H"); +} + +TEST_CASE(print_multi_symbol_product) +{ + auto r = se("H") * se("W"); + auto s = r.to_string(); + EXPECT(s == "H*W" or s == "W*H"); + EXPECT(se("H*W") == se("W*H")); +} + +TEST_CASE(print_compound_expression) +{ + auto r = 2 * (se("H") * se("W")) + se("C") - 1; + auto s = r.to_string(); + EXPECT(se(s) == r); +} + +TEST_CASE(parse_compound_mul) +{ + auto r = se("2*H*W"); + EXPECT(r == 2 * se("H") * se("W")); +} + +TEST_CASE(print_parse_round_trip) +{ + std::vector exprs = { + se("H"), + se("H") + 1, + 2 * se("H") - 3, + (se("H") - 3) / 2 + 1, + se("N") * se("C") * se("H") * se("W"), + (se("H") - 1) / 2, + }; + for(const auto& e : exprs) + { + auto s = e.to_string(); + auto reparsed = se(s); + EXPECT(reparsed == e); + } +} + +// =================================================================== +// Tier 6: Edge cases and robustness +// =================================================================== + +// 5 levels of (e-1)/2: simulates repeated pooling/conv stride reduction +TEST_CASE(edge_deeply_nested) +{ + auto e = se("H"); + for(int i = 0; i < 5; ++i) + e = (e - 1) / 2; + EXPECT(e.eval({{"H", 255}}) == 7); +} + +TEST_CASE(edge_many_symbols) +{ + auto e = se("A") + se("B") + se("C") + se("D") + se("E"); + EXPECT(e.eval({{"A", 1}, {"B", 2}, {"C", 3}, {"D", 4}, {"E", 5}}) == 15); +} + +TEST_CASE(edge_neg_one_coefficient) +{ + EXPECT(se("-H").to_string() == "-H"); + EXPECT(se("-H") + se("H") == se(0)); +} + + +TEST_CASE(edge_empty_operations) +{ + se empty; + EXPECT((empty + empty).empty()); + EXPECT((empty - empty).empty()); + EXPECT((empty * empty).empty()); + EXPECT((empty / empty).empty()); +} + +TEST_CASE(edge_empty_with_nonempty) +{ + se empty; + auto r1 = se("H") + empty; + EXPECT(not r1.empty()); + + auto r2 = empty + se("H"); + EXPECT(not r2.empty()); +} + +TEST_CASE(edge_large_coefficients) +{ + auto r = 1000000 * se("H"); + EXPECT(r.eval({{"H", 1000000}}) == 1000000000000ULL); +} + +// Incrementally adding H ten times must fold to 11*H +TEST_CASE(edge_chained_operations) +{ + auto e = se("H"); + for(int i = 0; i < 10; ++i) + e = e + se("H"); + EXPECT(e == 11 * se("H")); +} + + +TEST_CASE(edge_repeated_parse) +{ + for(int i = 0; i < 10; ++i) + { + auto r = se("(H - 3)/2 + 1"); + EXPECT(r == (se("H") - 3) / 2 + 1); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } From aa557852c54f58f0c255a87226d866c936001137 Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 24 Mar 2026 12:04:26 -0700 Subject: [PATCH 02/60] format --- src/include/migraphx/symbolic.hpp | 52 +++++++++++++++++++++---------- src/symbolic.cpp | 35 ++++++++------------- 2 files changed, 49 insertions(+), 38 deletions(-) diff --git a/src/include/migraphx/symbolic.hpp b/src/include/migraphx/symbolic.hpp index c7edac2aee1..0cb0aa4c2b2 100644 --- a/src/include/migraphx/symbolic.hpp +++ b/src/include/migraphx/symbolic.hpp @@ -48,14 +48,10 @@ struct MIGRAPHX_EXPORT symbolic_expr std::size_t eval(const std::map& symbol_map) const; symbolic_expr subs(const std::map& symbol_map) const; - MIGRAPHX_EXPORT friend symbolic_expr operator+(const symbolic_expr& a, - const symbolic_expr& b); - MIGRAPHX_EXPORT friend symbolic_expr operator-(const symbolic_expr& a, - const symbolic_expr& b); - MIGRAPHX_EXPORT friend symbolic_expr operator*(const symbolic_expr& a, - const symbolic_expr& b); - MIGRAPHX_EXPORT friend symbolic_expr operator/(const symbolic_expr& a, - const symbolic_expr& b); + MIGRAPHX_EXPORT friend symbolic_expr operator+(const symbolic_expr& a, const symbolic_expr& b); + MIGRAPHX_EXPORT friend symbolic_expr operator-(const symbolic_expr& a, const symbolic_expr& b); + MIGRAPHX_EXPORT friend symbolic_expr operator*(const symbolic_expr& a, const symbolic_expr& b); + MIGRAPHX_EXPORT friend symbolic_expr operator/(const symbolic_expr& a, const symbolic_expr& b); MIGRAPHX_EXPORT friend bool operator==(const symbolic_expr& a, const symbolic_expr& b); MIGRAPHX_EXPORT friend bool operator!=(const symbolic_expr& a, const symbolic_expr& b); MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const symbolic_expr& e); @@ -67,14 +63,38 @@ struct MIGRAPHX_EXPORT symbolic_expr std::shared_ptr p; }; -inline symbolic_expr operator+(const symbolic_expr& a, std::size_t b) { return a + symbolic_expr(b); } -inline symbolic_expr operator+(std::size_t a, const symbolic_expr& b) { return symbolic_expr(a) + b; } -inline symbolic_expr operator-(const symbolic_expr& a, std::size_t b) { return a - symbolic_expr(b); } -inline symbolic_expr operator-(std::size_t a, const symbolic_expr& b) { return symbolic_expr(a) - b; } -inline symbolic_expr operator*(const symbolic_expr& a, std::size_t b) { return a * symbolic_expr(b); } -inline symbolic_expr operator*(std::size_t a, const symbolic_expr& b) { return symbolic_expr(a) * b; } -inline symbolic_expr operator/(const symbolic_expr& a, std::size_t b) { return a / symbolic_expr(b); } -inline symbolic_expr operator/(std::size_t a, const symbolic_expr& b) { return symbolic_expr(a) / b; } +inline symbolic_expr operator+(const symbolic_expr& a, std::size_t b) +{ + return a + symbolic_expr(b); +} +inline symbolic_expr operator+(std::size_t a, const symbolic_expr& b) +{ + return symbolic_expr(a) + b; +} +inline symbolic_expr operator-(const symbolic_expr& a, std::size_t b) +{ + return a - symbolic_expr(b); +} +inline symbolic_expr operator-(std::size_t a, const symbolic_expr& b) +{ + return symbolic_expr(a) - b; +} +inline symbolic_expr operator*(const symbolic_expr& a, std::size_t b) +{ + return a * symbolic_expr(b); +} +inline symbolic_expr operator*(std::size_t a, const symbolic_expr& b) +{ + return symbolic_expr(a) * b; +} +inline symbolic_expr operator/(const symbolic_expr& a, std::size_t b) +{ + return a / symbolic_expr(b); +} +inline symbolic_expr operator/(std::size_t a, const symbolic_expr& b) +{ + return symbolic_expr(a) / b; +} MIGRAPHX_EXPORT void migraphx_to_value(value& v, const symbolic_expr& e); MIGRAPHX_EXPORT void migraphx_from_value(const value& v, symbolic_expr& e); diff --git a/src/symbolic.cpp b/src/symbolic.cpp index 98480a0d92f..6ccac242cb6 100644 --- a/src/symbolic.cpp +++ b/src/symbolic.cpp @@ -221,7 +221,7 @@ static int compare_expr(const expr_ptr& a, const expr_ptr& b) case kind_fdiv: { const auto& da = get_fdiv(a); const auto& db = get_fdiv(b); - int c = compare_expr(da.numerator, db.numerator); + int c = compare_expr(da.numerator, db.numerator); if(c != 0) return c; return compare_expr(da.denominator, db.denominator); @@ -254,7 +254,7 @@ static bool expr_equal(const expr_ptr& a, const expr_ptr& b) static expr_ptr make_node(expr_data d) { - auto n = std::make_shared(); + auto n = std::make_shared(); n->data = std::move(d); n->cached_hash = compute_hash(n->data); return n; @@ -314,7 +314,7 @@ static add_parts extract_add(const expr_ptr& e) if(is_mul(e)) { const auto& d = get_mul(e); - auto base = build_mul(1, d.factors); + auto base = build_mul(1, d.factors); return {0, {{base, d.coefficient}}}; } return {0, {{e, 1}}}; @@ -375,10 +375,7 @@ static expr_ptr make_neg(const expr_ptr& a) return make_mul(make_integer(-1), a); } -static expr_ptr make_sub(const expr_ptr& a, const expr_ptr& b) -{ - return make_add(a, make_neg(b)); -} +static expr_ptr make_sub(const expr_ptr& a, const expr_ptr& b) { return make_add(a, make_neg(b)); } struct mul_parts { @@ -425,7 +422,7 @@ static expr_ptr make_mul(const expr_ptr& a, const expr_ptr& b) // Special case: multiplying an integer by an Add → scale the Add if(is_integer(a) and is_add(b)) { - int64_t n = get_integer(a); + int64_t n = get_integer(a); if(n == 0) return make_integer(0); if(n == 1) @@ -482,8 +479,7 @@ static expr_ptr make_floor_div(const expr_ptr& a, const expr_ptr& b) // Partial substitution: replaces bound symbols with integers and re-canonicalizes. // Unbound symbols are left as-is, producing a simplified symbolic expression. -static expr_ptr substitute(const expr_ptr& e, - const std::map& bindings) +static expr_ptr substitute(const expr_ptr& e, const std::map& bindings) { switch(e->kind()) { @@ -495,7 +491,7 @@ static expr_ptr substitute(const expr_ptr& e, return e; } case kind_add: { - const auto& d = get_add(e); + const auto& d = get_add(e); expr_ptr result = make_integer(d.constant); for(const auto& [term, coeff] : d.terms) { @@ -505,7 +501,7 @@ static expr_ptr substitute(const expr_ptr& e, return result; } case kind_mul: { - const auto& d = get_mul(e); + const auto& d = get_mul(e); expr_ptr result = make_integer(d.coefficient); for(const auto& [base, exp] : d.factors) { @@ -527,8 +523,7 @@ static expr_ptr substitute(const expr_ptr& e, // Full evaluation: computes integer result directly without allocations. // All symbols must be bound; throws if any symbol is unbound. -static int64_t eval_direct(const expr_ptr& e, - const std::map& symbol_map) +static int64_t eval_direct(const expr_ptr& e, const std::map& symbol_map) { switch(e->kind()) { @@ -565,8 +560,7 @@ static int64_t eval_direct(const expr_ptr& e, } } -static int64_t evaluate(const expr_ptr& e, - const std::map& symbol_map) +static int64_t evaluate(const expr_ptr& e, const std::map& symbol_map) { return eval_direct(e, symbol_map); } @@ -657,8 +651,8 @@ static std::string print_expr(const expr_ptr& e, int parent_prec) } case kind_fdiv: { const auto& d = get_fdiv(e); - std::string s = print_expr(d.numerator, prec_mul + 1) + "/" + - print_expr(d.denominator, prec_mul + 1); + std::string s = + print_expr(d.numerator, prec_mul + 1) + "/" + print_expr(d.denominator, prec_mul + 1); if(parent_prec > prec_mul) return "(" + s + ")"; return s; @@ -907,10 +901,7 @@ std::ostream& operator<<(std::ostream& os, const symbolic_expr& e) return os; } -void migraphx_to_value(value& v, const symbolic_expr& e) -{ - v = migraphx::to_value(e.to_string()); -} +void migraphx_to_value(value& v, const symbolic_expr& e) { v = migraphx::to_value(e.to_string()); } void migraphx_from_value(const value& v, symbolic_expr& e) { From 314f7cfdb74174f8490bd6ef3961bc1188e7d40f Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 24 Mar 2026 14:55:25 -0700 Subject: [PATCH 03/60] use visit --- src/symbolic.cpp | 500 ++++++++++++++++++++++------------------------- 1 file changed, 229 insertions(+), 271 deletions(-) diff --git a/src/symbolic.cpp b/src/symbolic.cpp index 6ccac242cb6..36e7d3b1090 100644 --- a/src/symbolic.cpp +++ b/src/symbolic.cpp @@ -84,34 +84,29 @@ struct fdiv_data using expr_data = std::variant; -constexpr int kind_integer = 0; -constexpr int kind_symbol = 1; -constexpr int kind_add = 2; -constexpr int kind_mul = 3; -constexpr int kind_fdiv = 4; +template +struct overloaded : Ts... +{ + using Ts::operator()...; +}; +template +overloaded(Ts...) -> overloaded; struct expr_node { expr_data data; std::size_t cached_hash = 0; - - int kind() const { return static_cast(data.index()); } }; -static bool is_integer(const expr_ptr& e) { return e->kind() == kind_integer; } -static bool is_symbol(const expr_ptr& e) { return e->kind() == kind_symbol; } -static bool is_add(const expr_ptr& e) { return e->kind() == kind_add; } -static bool is_mul(const expr_ptr& e) { return e->kind() == kind_mul; } -static bool is_fdiv(const expr_ptr& e) { return e->kind() == kind_fdiv; } - -static int64_t get_integer(const expr_ptr& e) { return std::get(e->data).value; } -static const std::string& get_symbol_name(const expr_ptr& e) +template +static bool holds(const expr_ptr& e) { - return std::get(e->data).name; + return std::holds_alternative(e->data); } + +static int64_t get_integer(const expr_ptr& e) { return std::get(e->data).value; } static const add_data& get_add(const expr_ptr& e) { return std::get(e->data); } static const mul_data& get_mul(const expr_ptr& e) { return std::get(e->data); } -static const fdiv_data& get_fdiv(const expr_ptr& e) { return std::get(e->data); } // =================================================================== // Section 2: Hash computation @@ -137,26 +132,23 @@ static std::size_t hash_ordered_map(const Map& m) static std::size_t compute_hash(const expr_data& d) { std::size_t h = std::hash{}(static_cast(d.index())); - if(auto* p = std::get_if(&d)) - return hash_combine(h, std::hash{}(p->value)); - if(auto* p = std::get_if(&d)) - return hash_combine(h, std::hash{}(p->name)); - if(auto* p = std::get_if(&d)) - { - h = hash_combine(h, std::hash{}(p->constant)); - return hash_combine(h, hash_ordered_map(p->terms)); - } - if(auto* p = std::get_if(&d)) - { - h = hash_combine(h, std::hash{}(p->coefficient)); - return hash_combine(h, hash_ordered_map(p->factors)); - } - if(auto* p = std::get_if(&d)) - { - h = hash_combine(h, p->numerator->cached_hash); - return hash_combine(h, p->denominator->cached_hash); - } - return h; + return std::visit( + overloaded{ + [&](const integer_data& p) { return hash_combine(h, std::hash{}(p.value)); }, + [&](const symbol_data& p) { return hash_combine(h, std::hash{}(p.name)); }, + [&](const add_data& p) { + return hash_combine(hash_combine(h, std::hash{}(p.constant)), + hash_ordered_map(p.terms)); + }, + [&](const mul_data& p) { + return hash_combine(hash_combine(h, std::hash{}(p.coefficient)), + hash_ordered_map(p.factors)); + }, + [&](const fdiv_data& p) { + return hash_combine(hash_combine(h, p.numerator->cached_hash), + p.denominator->cached_hash); + }}, + d); } // =================================================================== @@ -189,45 +181,38 @@ static int compare_maps(const Map& a, const Map& b) static int compare_expr(const expr_ptr& a, const expr_ptr& b) { - if(a->kind() != b->kind()) - return a->kind() < b->kind() ? -1 : 1; - - switch(a->kind()) - { - case kind_integer: { - auto va = get_integer(a); - auto vb = get_integer(b); - return (va < vb) ? -1 : (va > vb) ? 1 : 0; - } - case kind_symbol: { - const auto& na = get_symbol_name(a); - const auto& nb = get_symbol_name(b); - return na.compare(nb); - } - case kind_add: { - const auto& da = get_add(a); - const auto& db = get_add(b); - if(da.constant != db.constant) - return da.constant < db.constant ? -1 : 1; - return compare_maps(da.terms, db.terms); - } - case kind_mul: { - const auto& da = get_mul(a); - const auto& db = get_mul(b); - if(da.coefficient != db.coefficient) - return da.coefficient < db.coefficient ? -1 : 1; - return compare_maps(da.factors, db.factors); - } - case kind_fdiv: { - const auto& da = get_fdiv(a); - const auto& db = get_fdiv(b); - int c = compare_expr(da.numerator, db.numerator); - if(c != 0) - return c; - return compare_expr(da.denominator, db.denominator); - } - default: return 0; - } + if(a->data.index() != b->data.index()) + return a->data.index() < b->data.index() ? -1 : 1; + + return std::visit( + overloaded{[&](const integer_data& da) { + const auto& db = std::get(b->data); + return (da.value < db.value) ? -1 : (da.value > db.value) ? 1 : 0; + }, + [&](const symbol_data& da) { + const auto& db = std::get(b->data); + return da.name.compare(db.name); + }, + [&](const add_data& da) { + const auto& db = std::get(b->data); + if(da.constant != db.constant) + return da.constant < db.constant ? -1 : 1; + return compare_maps(da.terms, db.terms); + }, + [&](const mul_data& da) { + const auto& db = std::get(b->data); + if(da.coefficient != db.coefficient) + return da.coefficient < db.coefficient ? -1 : 1; + return compare_maps(da.factors, db.factors); + }, + [&](const fdiv_data& da) { + const auto& db = std::get(b->data); + int c = compare_expr(da.numerator, db.numerator); + if(c != 0) + return c; + return compare_expr(da.denominator, db.denominator); + }}, + a->data); } bool expr_compare::operator()(const expr_ptr& a, const expr_ptr& b) const @@ -304,20 +289,15 @@ struct add_parts static add_parts extract_add(const expr_ptr& e) { - if(is_integer(e)) - return {get_integer(e), {}}; - if(is_add(e)) - { - const auto& d = get_add(e); - return {d.constant, d.terms}; - } - if(is_mul(e)) - { - const auto& d = get_mul(e); - auto base = build_mul(1, d.factors); - return {0, {{base, d.coefficient}}}; - } - return {0, {{e, 1}}}; + return std::visit( + overloaded{[](const integer_data& d) -> add_parts { return {d.value, {}}; }, + [](const add_data& d) -> add_parts { return {d.constant, d.terms}; }, + [&](const mul_data& d) -> add_parts { + auto base = build_mul(1, d.factors); + return {0, {{base, d.coefficient}}}; + }, + [&](const auto&) -> add_parts { return {0, {{e, 1}}}; }}, + e->data); } static expr_ptr build_add(int64_t constant, term_map terms) @@ -357,22 +337,19 @@ static expr_ptr make_add(const expr_ptr& a, const expr_ptr& b) static expr_ptr make_neg(const expr_ptr& a) { - if(is_integer(a)) - return make_integer(-get_integer(a)); - if(is_add(a)) - { - const auto& d = get_add(a); - term_map negated; - for(const auto& [term, coeff] : d.terms) - negated[term] = -coeff; - return build_add(-d.constant, std::move(negated)); - } - if(is_mul(a)) - { - const auto& d = get_mul(a); - return make_node(mul_data{-d.coefficient, d.factors}); - } - return make_mul(make_integer(-1), a); + return std::visit( + overloaded{[](const integer_data& d) -> expr_ptr { return make_integer(-d.value); }, + [](const add_data& d) -> expr_ptr { + term_map negated; + for(const auto& [term, coeff] : d.terms) + negated[term] = -coeff; + return build_add(-d.constant, std::move(negated)); + }, + [](const mul_data& d) -> expr_ptr { + return make_node(mul_data{-d.coefficient, d.factors}); + }, + [&](const auto&) -> expr_ptr { return make_mul(make_integer(-1), a); }}, + a->data); } static expr_ptr make_sub(const expr_ptr& a, const expr_ptr& b) { return make_add(a, make_neg(b)); } @@ -385,14 +362,11 @@ struct mul_parts static mul_parts extract_mul(const expr_ptr& e) { - if(is_integer(e)) - return {get_integer(e), {}}; - if(is_mul(e)) - { - const auto& d = get_mul(e); - return {d.coefficient, d.factors}; - } - return {1, {{e, 1}}}; + return std::visit( + overloaded{[](const integer_data& d) -> mul_parts { return {d.value, {}}; }, + [](const mul_data& d) -> mul_parts { return {d.coefficient, d.factors}; }, + [&](const auto&) -> mul_parts { return {1, {{e, 1}}}; }}, + e->data); } static expr_ptr build_mul(int64_t coefficient, factor_map factors) @@ -419,8 +393,7 @@ static expr_ptr build_mul(int64_t coefficient, factor_map factors) static expr_ptr make_mul(const expr_ptr& a, const expr_ptr& b) { - // Special case: multiplying an integer by an Add → scale the Add - if(is_integer(a) and is_add(b)) + if(holds(a) and holds(b)) { int64_t n = get_integer(a); if(n == 0) @@ -433,7 +406,7 @@ static expr_ptr make_mul(const expr_ptr& a, const expr_ptr& b) scaled[term] = coeff * n; return build_add(d.constant * n, std::move(scaled)); } - if(is_integer(b) and is_add(a)) + if(holds(b) and holds(a)) return make_mul(b, a); auto pa = extract_mul(a); @@ -452,17 +425,16 @@ static expr_ptr make_mul(const expr_ptr& a, const expr_ptr& b) static expr_ptr make_floor_div(const expr_ptr& a, const expr_ptr& b) { - if(is_integer(b)) + if(holds(b)) { int64_t den = get_integer(b); if(den == 0) MIGRAPHX_THROW("symbolic: division by zero"); if(den == 1) return a; - if(is_integer(a)) + if(holds(a)) return make_integer(get_integer(a) / den); - // Cancel exact coefficient in Mul: (c*factors)/den where c%den==0 - if(is_mul(a)) + if(holds(a)) { const auto& d = get_mul(a); if(d.coefficient % den == 0) @@ -481,83 +453,74 @@ static expr_ptr make_floor_div(const expr_ptr& a, const expr_ptr& b) // Unbound symbols are left as-is, producing a simplified symbolic expression. static expr_ptr substitute(const expr_ptr& e, const std::map& bindings) { - switch(e->kind()) - { - case kind_integer: return e; - case kind_symbol: { - auto it = bindings.find(get_symbol_name(e)); - if(it != bindings.end()) - return make_integer(it->second); - return e; - } - case kind_add: { - const auto& d = get_add(e); - expr_ptr result = make_integer(d.constant); - for(const auto& [term, coeff] : d.terms) - { - auto st = substitute(term, bindings); - result = make_add(result, make_mul(make_integer(coeff), st)); - } - return result; - } - case kind_mul: { - const auto& d = get_mul(e); - expr_ptr result = make_integer(d.coefficient); - for(const auto& [base, exp] : d.factors) - { - auto sb = substitute(base, bindings); - for(int64_t i = 0; i < exp; ++i) - result = make_mul(result, sb); - } - return result; - } - case kind_fdiv: { - const auto& d = get_fdiv(e); - auto sn = substitute(d.numerator, bindings); - auto sd = substitute(d.denominator, bindings); - return make_floor_div(sn, sd); - } - default: return e; - } + return std::visit(overloaded{[&](const integer_data&) -> expr_ptr { return e; }, + [&](const symbol_data& d) -> expr_ptr { + auto it = bindings.find(d.name); + if(it != bindings.end()) + return make_integer(it->second); + return e; + }, + [&](const add_data& d) -> expr_ptr { + expr_ptr result = make_integer(d.constant); + for(const auto& [term, coeff] : d.terms) + { + auto st = substitute(term, bindings); + result = + make_add(result, make_mul(make_integer(coeff), st)); + } + return result; + }, + [&](const mul_data& d) -> expr_ptr { + expr_ptr result = make_integer(d.coefficient); + for(const auto& [base, exp] : d.factors) + { + auto sb = substitute(base, bindings); + for(int64_t i = 0; i < exp; ++i) + result = make_mul(result, sb); + } + return result; + }, + [&](const fdiv_data& d) -> expr_ptr { + auto sn = substitute(d.numerator, bindings); + auto sd = substitute(d.denominator, bindings); + return make_floor_div(sn, sd); + }}, + e->data); } // Full evaluation: computes integer result directly without allocations. // All symbols must be bound; throws if any symbol is unbound. static int64_t eval_direct(const expr_ptr& e, const std::map& symbol_map) { - switch(e->kind()) - { - case kind_integer: return get_integer(e); - case kind_symbol: { - auto it = symbol_map.find(get_symbol_name(e)); - if(it != symbol_map.end()) - return static_cast(it->second); - MIGRAPHX_THROW("symbolic_expr::eval: unbound symbol '" + get_symbol_name(e) + "'"); - } - case kind_add: { - const auto& d = get_add(e); - int64_t sum = d.constant; - for(const auto& [term, coeff] : d.terms) - sum += coeff * eval_direct(term, symbol_map); - return sum; - } - case kind_mul: { - const auto& d = get_mul(e); - int64_t prod = d.coefficient; - for(const auto& [base, exp] : d.factors) - { - int64_t val = eval_direct(base, symbol_map); - for(int64_t i = 0; i < exp; ++i) - prod *= val; - } - return prod; - } - case kind_fdiv: { - const auto& d = get_fdiv(e); - return eval_direct(d.numerator, symbol_map) / eval_direct(d.denominator, symbol_map); - } - default: return 0; - } + return std::visit(overloaded{[](const integer_data& d) -> int64_t { return d.value; }, + [&](const symbol_data& d) -> int64_t { + auto it = symbol_map.find(d.name); + if(it != symbol_map.end()) + return static_cast(it->second); + MIGRAPHX_THROW("symbolic_expr::eval: unbound symbol '" + + d.name + "'"); + }, + [&](const add_data& d) -> int64_t { + int64_t sum = d.constant; + for(const auto& [term, coeff] : d.terms) + sum += coeff * eval_direct(term, symbol_map); + return sum; + }, + [&](const mul_data& d) -> int64_t { + int64_t prod = d.coefficient; + for(const auto& [base, exp] : d.factors) + { + int64_t val = eval_direct(base, symbol_map); + for(int64_t i = 0; i < exp; ++i) + prod *= val; + } + return prod; + }, + [&](const fdiv_data& d) -> int64_t { + return eval_direct(d.numerator, symbol_map) / + eval_direct(d.denominator, symbol_map); + }}, + e->data); } static int64_t evaluate(const expr_ptr& e, const std::map& symbol_map) @@ -580,85 +543,80 @@ static std::string print_expr(const expr_ptr& e, int parent_prec = 0); static std::string print_expr(const expr_ptr& e, int parent_prec) { - switch(e->kind()) - { - case kind_integer: return std::to_string(get_integer(e)); - case kind_symbol: return get_symbol_name(e); - case kind_add: { - const auto& d = get_add(e); - std::ostringstream os; - bool first = true; - for(const auto& [term, coeff] : d.terms) - { - if(first) - { - if(coeff == -1) - os << "-" << print_expr(term, prec_add); - else if(coeff == 1) - os << print_expr(term, prec_add); - else - os << coeff << "*" << print_expr(term, prec_mul + 1); - first = false; - } - else - { - if(coeff == 1) - os << " + " << print_expr(term, prec_add); - else if(coeff == -1) - os << " - " << print_expr(term, prec_add); - else if(coeff > 0) - os << " + " << coeff << "*" << print_expr(term, prec_mul + 1); - else - os << " - " << (-coeff) << "*" << print_expr(term, prec_mul + 1); - } - } - if(d.constant > 0) - os << " + " << d.constant; - else if(d.constant < 0) - os << " - " << (-d.constant); - std::string s = os.str(); - if(parent_prec > prec_add) - return "(" + s + ")"; - return s; - } - case kind_mul: { - const auto& d = get_mul(e); - std::ostringstream os; - bool first = true; - if(d.coefficient == -1) - { - os << "-"; - } - else if(d.coefficient != 1) - { - os << d.coefficient; - first = false; - } - for(const auto& [base, exp] : d.factors) - { - if(not first) - os << "*"; - if(exp == 1) - os << print_expr(base, prec_mul + 1); - else - os << print_expr(base, prec_mul + 1) << "**" << exp; - first = false; - } - std::string raw = os.str(); - if(parent_prec > prec_mul) - return "(" + raw + ")"; - return raw; - } - case kind_fdiv: { - const auto& d = get_fdiv(e); - std::string s = - print_expr(d.numerator, prec_mul + 1) + "/" + print_expr(d.denominator, prec_mul + 1); - if(parent_prec > prec_mul) - return "(" + s + ")"; - return s; - } - default: return "?"; - } + return std::visit( + overloaded{[](const integer_data& d) -> std::string { return std::to_string(d.value); }, + [](const symbol_data& d) -> std::string { return d.name; }, + [&](const add_data& d) -> std::string { + std::ostringstream os; + bool first = true; + for(const auto& [term, coeff] : d.terms) + { + if(first) + { + if(coeff == -1) + os << "-" << print_expr(term, prec_add); + else if(coeff == 1) + os << print_expr(term, prec_add); + else + os << coeff << "*" << print_expr(term, prec_mul + 1); + first = false; + } + else + { + if(coeff == 1) + os << " + " << print_expr(term, prec_add); + else if(coeff == -1) + os << " - " << print_expr(term, prec_add); + else if(coeff > 0) + os << " + " << coeff << "*" << print_expr(term, prec_mul + 1); + else + os << " - " << (-coeff) << "*" << print_expr(term, prec_mul + 1); + } + } + if(d.constant > 0) + os << " + " << d.constant; + else if(d.constant < 0) + os << " - " << (-d.constant); + std::string s = os.str(); + if(parent_prec > prec_add) + return "(" + s + ")"; + return s; + }, + [&](const mul_data& d) -> std::string { + std::ostringstream os; + bool first = true; + if(d.coefficient == -1) + { + os << "-"; + } + else if(d.coefficient != 1) + { + os << d.coefficient; + first = false; + } + for(const auto& [base, exp] : d.factors) + { + if(not first) + os << "*"; + if(exp == 1) + os << print_expr(base, prec_mul + 1); + else + os << print_expr(base, prec_mul + 1) << "**" << exp; + first = false; + } + std::string raw = os.str(); + if(parent_prec > prec_mul) + return "(" + raw + ")"; + return raw; + }, + [&](const fdiv_data& d) -> std::string { + std::string s = print_expr(d.numerator, prec_mul + 1) + "/" + + print_expr(d.denominator, prec_mul + 1); + if(parent_prec > prec_mul) + return "(" + s + ")"; + return s; + }}, + e->data); } // =================================================================== From dcfe82540e26a0332d2ebc06c5b6472abcda4663 Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 24 Mar 2026 15:23:53 -0700 Subject: [PATCH 04/60] format --- test/symbolic_test.cpp | 51 ++++++++++-------------------------------- 1 file changed, 12 insertions(+), 39 deletions(-) diff --git a/test/symbolic_test.cpp b/test/symbolic_test.cpp index df0528b5192..27bd3d3fccf 100644 --- a/test/symbolic_test.cpp +++ b/test/symbolic_test.cpp @@ -57,10 +57,7 @@ TEST_CASE(add_identity) EXPECT(0 + se("H") == se("H")); } -TEST_CASE(add_commutativity) -{ - EXPECT(se("H") + se("W") == se("W") + se("H")); -} +TEST_CASE(add_commutativity) { EXPECT(se("H") + se("W") == se("W") + se("H")); } TEST_CASE(add_like_term_folding) { @@ -97,20 +94,11 @@ TEST_CASE(add_cancellation) EXPECT(se("H") + neg_h == se(0)); } -TEST_CASE(sub_identity) -{ - EXPECT(se("H") - 0 == se("H")); -} +TEST_CASE(sub_identity) { EXPECT(se("H") - 0 == se("H")); } -TEST_CASE(sub_self) -{ - EXPECT(se("H") - se("H") == se(0)); -} +TEST_CASE(sub_self) { EXPECT(se("H") - se("H") == se(0)); } -TEST_CASE(sub_constant_folding) -{ - EXPECT(se(10) - se(3) == se(7)); -} +TEST_CASE(sub_constant_folding) { EXPECT(se(10) - se(3) == se(7)); } TEST_CASE(sub_produces_negation) { @@ -142,15 +130,9 @@ TEST_CASE(mul_zero) EXPECT(0 * se("H") == se(0)); } -TEST_CASE(mul_constant_folding) -{ - EXPECT(se(3) * se(7) == se(21)); -} +TEST_CASE(mul_constant_folding) { EXPECT(se(3) * se(7) == se(21)); } -TEST_CASE(mul_commutativity) -{ - EXPECT(se("B") * se("A") == se("A") * se("B")); -} +TEST_CASE(mul_commutativity) { EXPECT(se("B") * se("A") == se("A") * se("B")); } TEST_CASE(mul_coefficient_accumulation) { @@ -179,10 +161,7 @@ TEST_CASE(mul_symbolic_times_add_no_distribution) EXPECT(r != se("N") * se("H") + se("N")); } -TEST_CASE(fdiv_identity) -{ - EXPECT(se("H") / 1 == se("H")); -} +TEST_CASE(fdiv_identity) { EXPECT(se("H") / 1 == se("H")); } TEST_CASE(fdiv_constant_folding) { @@ -221,10 +200,7 @@ TEST_CASE(add_of_two_adds) EXPECT(r == 2 * se("H") + 3); } -TEST_CASE(sub_strip_constant) -{ - EXPECT((se("H") + 1) - se("H") == se(1)); -} +TEST_CASE(sub_strip_constant) { EXPECT((se("H") + 1) - se("H") == se(1)); } TEST_CASE(sub_of_two_adds) { @@ -285,7 +261,7 @@ TEST_CASE(sub_compound_product_mixed) // Duplicate A*B terms fold even when separated by another term TEST_CASE(add_multi_term_accumulation) { - auto r = se("A") * se("B") + se("C") + se("A") * se("B"); + auto r = se("A") * se("B") + se("C") + se("A") * se("B"); auto expected = 2 * (se("A") * se("B")) + se("C"); EXPECT(r == expected); } @@ -321,7 +297,6 @@ TEST_CASE(eq_empty) EXPECT(se(0) != se{}); } - // =================================================================== // Tier 3: Evaluation and substitution // =================================================================== @@ -403,10 +378,10 @@ TEST_CASE(subs_floor_div) // eval() and subs()+eval() must agree on a compound expression TEST_CASE(subs_eval_cross_validation) { - auto e = (se("N") * se("H") - 3) / 2 + 1; + auto e = (se("N") * se("H") - 3) / 2 + 1; std::map m = {{"N", 4}, {"H", 26}}; - auto via_eval = e.eval(m); - auto via_subs = e.subs(m).eval({}); + auto via_eval = e.eval(m); + auto via_subs = e.subs(m).eval({}); EXPECT(via_eval == via_subs); } @@ -604,7 +579,6 @@ TEST_CASE(edge_neg_one_coefficient) EXPECT(se("-H") + se("H") == se(0)); } - TEST_CASE(edge_empty_operations) { se empty; @@ -639,7 +613,6 @@ TEST_CASE(edge_chained_operations) EXPECT(e == 11 * se("H")); } - TEST_CASE(edge_repeated_parse) { for(int i = 0; i < 10; ++i) From 2ec09692e853d4d7a4cea8e6966e11b642b62475 Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 24 Mar 2026 17:21:20 -0700 Subject: [PATCH 05/60] integrate symbolic expression in dynamic_dimension --- src/include/migraphx/shape.hpp | 58 ++++++++- src/shape.cpp | 174 +++++++++++++++++++++++-- src/symbolic.cpp | 140 ++++++++++++++++++++ test/shape_test.cpp | 227 ++++++++++++++++++++++++++++++++- 4 files changed, 583 insertions(+), 16 deletions(-) diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index c9b343cd01c..7a19a8b5449 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -39,6 +39,7 @@ #include #include #include +#include #include namespace migraphx { @@ -99,19 +100,38 @@ struct MIGRAPHX_EXPORT shape std::size_t min = 0; std::size_t max = 0; std::set optimals{}; + optional sym; + + dynamic_dimension() = default; + dynamic_dimension(std::size_t min_v, std::size_t max_v) : min(min_v), max(max_v) {} + dynamic_dimension(std::size_t min_v, std::size_t max_v, std::set opt) + : min(min_v), max(max_v), optimals(std::move(opt)) + { + } + dynamic_dimension(std::size_t min_v, + std::size_t max_v, + std::set opt, + optional s) + : min(min_v), max(max_v), optimals(std::move(opt)), sym(std::move(s)) + { + } template static auto reflect(Self& self, F f) { - return pack(f(self.min, "min"), f(self.max, "max"), f(self.optimals, "optimals")); + return pack(f(self.min, "min"), + f(self.max, "max"), + f(self.optimals, "optimals"), + f(self.sym, "sym")); } bool is_fixed() const; + bool is_symbolic() const { return sym.has_value(); } bool has_optimal() const; /** * Return a dynamic_dimension with the intersection of two dynamic_dimension ranges if - * possible. + * possible. Preserves the symbolic expression only when the result is still dynamic. */ std::optional intersection(const dynamic_dimension& other) const { @@ -119,7 +139,12 @@ struct MIGRAPHX_EXPORT shape auto right = std::min(this->max, other.max); if(left <= right) { - return dynamic_dimension{left, right}; + optional s; + if(left != right) + { + s = (this->sym.has_value() and not this->is_fixed()) ? this->sym : other.sym; + } + return dynamic_dimension{left, right, {}, s}; } return nullopt; } @@ -137,10 +162,11 @@ struct MIGRAPHX_EXPORT shape MIGRAPHX_EXPORT friend bool operator!=(const dynamic_dimension& x, const std::size_t& y); MIGRAPHX_EXPORT friend bool operator!=(const std::size_t& x, const dynamic_dimension& y); - // add, subtract, multiply fixed std::size_t dimension + // add, subtract, multiply, divide fixed std::size_t dimension dynamic_dimension& operator+=(const std::size_t& x); dynamic_dimension& operator-=(const std::size_t& x); dynamic_dimension& operator*=(const std::size_t& x); + dynamic_dimension& operator/=(const std::size_t& x); MIGRAPHX_EXPORT friend dynamic_dimension operator+(const dynamic_dimension& x, const std::size_t& y); MIGRAPHX_EXPORT friend dynamic_dimension operator+(const std::size_t& x, @@ -151,6 +177,22 @@ struct MIGRAPHX_EXPORT shape const std::size_t& y); MIGRAPHX_EXPORT friend dynamic_dimension operator*(const std::size_t& x, const dynamic_dimension& y); + MIGRAPHX_EXPORT friend dynamic_dimension operator/(const dynamic_dimension& x, + const std::size_t& y); + + // dd-to-dd arithmetic (defined in symbolic.cpp) + dynamic_dimension& operator+=(const dynamic_dimension& x); + dynamic_dimension& operator-=(const dynamic_dimension& x); + dynamic_dimension& operator*=(const dynamic_dimension& x); + dynamic_dimension& operator/=(const dynamic_dimension& x); + MIGRAPHX_EXPORT friend dynamic_dimension operator+(const dynamic_dimension& x, + const dynamic_dimension& y); + MIGRAPHX_EXPORT friend dynamic_dimension operator-(const dynamic_dimension& x, + const dynamic_dimension& y); + MIGRAPHX_EXPORT friend dynamic_dimension operator*(const dynamic_dimension& x, + const dynamic_dimension& y); + MIGRAPHX_EXPORT friend dynamic_dimension operator/(const dynamic_dimension& x, + const dynamic_dimension& y); }; static std::string to_sizes_string(const std::vector& shapes); @@ -178,6 +220,8 @@ struct MIGRAPHX_EXPORT shape shape(type_t t, std::vector dims); + shape(type_t t, std::vector dims, std::vector dstrides); + // Construct a dynamic shape from vectors of mins, maxes, and optimals. // optimals_list is a vector of optimals that corresponds to each min and max. shape(type_t t, @@ -245,6 +289,9 @@ struct MIGRAPHX_EXPORT shape const std::vector& dyn_dims() const; + bool symbolic() const; + const std::vector& dyn_strides() const; + /*! * Minimum lengths for dynamic shape. * lens() for static shape. @@ -364,11 +411,12 @@ struct MIGRAPHX_EXPORT shape shape with_type(type_t t) const; - // convert the shape to an equivalent dynamic shape with empty optimals + // convert the shape to an equivalent dynamic shape with constant symbolic strides shape to_dynamic() const; // convert the shape to a static one setting any non-fixed dynamic_dimensions to x shape to_static(std::size_t x) const; + shape to_static(const std::map& symbol_map) const; MIGRAPHX_EXPORT friend bool operator==(const shape& x, const shape& y); MIGRAPHX_EXPORT friend bool operator!=(const shape& x, const shape& y); diff --git a/src/shape.cpp b/src/shape.cpp index f977237240b..667426a1086 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -23,6 +23,7 @@ */ #include +#include #include #include #include @@ -77,6 +78,16 @@ struct shape_impl shape_impl(shape::type_t t, std::vector dims) : m_type(t), m_dyn_dims(std::move(dims)) + { + if(std::any_of( + m_dyn_dims.begin(), m_dyn_dims.end(), [](const auto& d) { return d.is_symbolic(); })) + calculate_dyn_strides(); + } + + shape_impl(shape::type_t t, + std::vector dims, + std::vector dstrides) + : m_type(t), m_dyn_dims(std::move(dims)), m_dyn_strides(std::move(dstrides)) { } @@ -112,6 +123,23 @@ struct shape_impl bool m_standard = false; std::vector m_dyn_dims = {}; + std::vector m_dyn_strides = {}; + + void calculate_dyn_strides() + { + m_dyn_strides.clear(); + if(m_dyn_dims.empty()) + return; + m_dyn_strides.resize(m_dyn_dims.size()); + m_dyn_strides.back() = symbolic_expr(std::size_t{1}); + std::transform(m_dyn_dims.rbegin(), + m_dyn_dims.rend() - 1, + m_dyn_strides.rbegin(), + m_dyn_strides.rbegin() + 1, + [](const auto& dd, const auto& stride) { + return dd.sym.value_or(symbolic_expr(dd.min)) * stride; + }); + } void calculate_strides() { @@ -359,6 +387,11 @@ shape::shape(type_t t, std::vector dims) { } +shape::shape(type_t t, std::vector dims, std::vector dstrides) + : impl(std::make_shared(t, std::move(dims), std::move(dstrides))) +{ +} + shape::shape(type_t t, std::vector mins, std::vector maxes, @@ -638,7 +671,15 @@ shape shape::to_dynamic() const { return *this; } - return {type(), lens(), lens(), {}}; + std::vector dims; + dims.reserve(ndim()); + for(auto len : lens()) + dims.push_back(dynamic_dimension{len, len}); + std::vector dstrides; + dstrides.reserve(ndim()); + for(auto s : strides()) + dstrides.emplace_back(s); + return shape(type(), std::move(dims), std::move(dstrides)); } shape shape::to_static(std::size_t x) const @@ -665,6 +706,40 @@ shape shape::to_static(std::size_t x) const return {type(), static_lens}; } +shape shape::to_static(const std::map& symbol_map) const +{ + if(not sub_shapes().empty()) + { + std::vector subs; + std::transform(sub_shapes().cbegin(), + sub_shapes().cend(), + std::back_inserter(subs), + [&](auto s) { return s.to_static(symbol_map); }); + return shape(subs); + } + if(not this->dynamic()) + return *this; + std::vector static_lens(this->ndim()); + std::transform(this->dyn_dims().cbegin(), + this->dyn_dims().cend(), + static_lens.begin(), + [&](const auto& dd) -> std::size_t { + if(dd.is_fixed()) + return dd.min; + if(dd.sym) + return dd.sym->eval(symbol_map); + MIGRAPHX_THROW("to_static: non-fixed dimension has no symbolic expression"); + }); + const auto& ds = this->dyn_strides(); + if(ds.empty()) + return {type(), static_lens}; + std::vector static_strides(ds.size()); + std::transform(ds.cbegin(), ds.cend(), static_strides.begin(), [&](const auto& s) { + return s.eval(symbol_map); + }); + return {type(), static_lens, static_strides}; +} + std::size_t shape::element_space() const { return impl->element_space(); } std::string shape::type_string() const { return name(this->type()); } @@ -693,6 +768,15 @@ const std::vector& shape::dyn_dims() const return impl->m_dyn_dims; } +bool shape::symbolic() const +{ + return std::any_of(impl->m_dyn_dims.begin(), impl->m_dyn_dims.end(), [](const auto& d) { + return d.is_symbolic(); + }); +} + +const std::vector& shape::dyn_strides() const { return impl->m_dyn_strides; } + std::vector shape::min_lens() const { return this->dynamic() ? impl->min_lens() : this->lens(); @@ -711,6 +795,8 @@ bool shape::dynamic_dimension::has_optimal() const { return not optimals.empty() shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const std::size_t& x) { + if(is_symbolic()) + sym = *sym + symbolic_expr(x); this->min += x; this->max += x; std::set new_optimals; @@ -724,6 +810,8 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const std::size_t shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t& x) { + if(is_symbolic()) + sym = *sym - symbolic_expr(x); assert(this->min >= x); assert(this->max >= x); this->min -= x; @@ -742,6 +830,8 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const std::size_t& x) { + if(is_symbolic()) + sym = *sym * symbolic_expr(x); this->min *= x; this->max *= x; std::set new_optimals; @@ -753,9 +843,25 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const std::size_t return *this; } +shape::dynamic_dimension& shape::dynamic_dimension::operator/=(const std::size_t& x) +{ + if(is_symbolic()) + sym = *sym / symbolic_expr(x); + this->min = (x == 0) ? 0 : this->min / x; + this->max = (x == 0) ? std::numeric_limits::max() : this->max / x; + std::set new_optimals; + std::transform(this->optimals.begin(), + this->optimals.end(), + std::inserter(new_optimals, new_optimals.begin()), + [&x](const auto& opt) { return (x == 0) ? std::size_t{0} : opt / x; }); + this->optimals = new_optimals; + return *this; +} + bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y) { - // don't check optimals if both are fixed + if(not(x.sym == y.sym)) + return false; return (x.min == y.min and x.max == y.max and ((x.is_fixed() and y.is_fixed()) or (x.optimals == y.optimals))); } @@ -766,7 +872,15 @@ bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimensio } std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x) { - os << "[ " << x.min << ", " << x.max << ", {" << migraphx::to_string_range(x.optimals) << "} ]"; + if(x.is_symbolic()) + os << x.sym->to_string(); + if(x.is_fixed()) + { + if(not x.is_symbolic()) + os << x.min; + return os; + } + os << "[" << x.min << ".." << x.max << "]"; return os; } @@ -806,12 +920,19 @@ shape::dynamic_dimension operator*(const std::size_t& x, const shape::dynamic_di return y * x; } +shape::dynamic_dimension operator/(const shape::dynamic_dimension& x, const std::size_t& y) +{ + auto dd = x; + return dd /= y; +} + bool operator==(const shape& x, const shape& y) { if(x.dynamic() and y.dynamic()) { - return x.impl == y.impl or (x.type() == y.type() and x.dyn_dims() == y.dyn_dims() and - x.sub_shapes() == y.sub_shapes()); + return x.impl == y.impl or + (x.type() == y.type() and x.dyn_dims() == y.dyn_dims() and + x.dyn_strides() == y.dyn_strides() and x.sub_shapes() == y.sub_shapes()); } return x.impl == y.impl or (x.dynamic() == y.dynamic() and x.type() == y.type() and x.lens() == y.lens() and @@ -824,7 +945,23 @@ std::ostream& operator<<(std::ostream& os, const shape& x) { if(x.sub_shapes().empty()) { - if(x.dynamic()) + if(x.symbolic()) + { + os << x.type_string() << ", {"; + const auto& dd = x.dyn_dims(); + for(std::size_t i = 0; i < dd.size(); ++i) + { + if(i > 0) + os << ", "; + if(dd[i].is_symbolic()) + os << dd[i]; + else + os << dd[i].min; + } + os << "}, "; + os << "{" << to_string_range(x.dyn_strides()) << "}"; + } + else if(x.dynamic()) { os << "dynamic, "; os << x.type_string() << ", "; @@ -891,7 +1028,6 @@ void migraphx_to_value(value& v, const shape& s) value result; result["type"] = migraphx::to_value(s.type_string()); result["sub_shapes"] = migraphx::to_value(s.sub_shapes()); - // avoid calling functions that will throw if(s.dynamic()) { result["lens"] = {}; @@ -904,6 +1040,14 @@ void migraphx_to_value(value& v, const shape& s) result["strides"] = migraphx::to_value(s.strides()); result["dynamic_dimensions"] = {}; } + if(s.symbolic()) + { + result["dyn_strides"] = migraphx::to_value(s.dyn_strides()); + } + else + { + result["dyn_strides"] = {}; + } v = result; } @@ -925,13 +1069,25 @@ void migraphx_from_value(const value& v, shape& s) else { auto v_dd = v.at("dynamic_dimensions"); - std::vector dyn_dims(v.at("dynamic_dimensions").size()); + std::vector dyn_dims(v_dd.size()); std::transform( v_dd.begin(), v_dd.end(), dyn_dims.begin(), [](const migraphx::value& x) { return from_value(x); }); - s = shape{shape::parse_type(t), dyn_dims}; + if(v.contains("dyn_strides") and not v.at("dyn_strides").empty()) + { + auto v_ds = v.at("dyn_strides"); + std::vector dstrides; + dstrides.reserve(v_ds.size()); + for(const auto& x : v_ds) + dstrides.emplace_back(x.get_string()); + s = shape(shape::parse_type(t), std::move(dyn_dims), std::move(dstrides)); + } + else + { + s = shape{shape::parse_type(t), dyn_dims}; + } } } } diff --git a/src/symbolic.cpp b/src/symbolic.cpp index 36e7d3b1090..3fc2f54c4e8 100644 --- a/src/symbolic.cpp +++ b/src/symbolic.cpp @@ -23,6 +23,7 @@ */ #include +#include #include #include @@ -870,5 +871,144 @@ void migraphx_from_value(const value& v, symbolic_expr& e) e = symbolic_expr(s); } +// =================================================================== +// Section 10: dynamic_dimension arithmetic +// =================================================================== + +using dd = shape::dynamic_dimension; + +static optional to_expr(const dd& d) +{ + if(d.sym.has_value()) + return d.sym; + if(d.is_fixed()) + return symbolic_expr(d.min); + return nullopt; +} + +dd& dd::operator+=(const dd& x) +{ + auto lhs = to_expr(*this); + auto rhs = to_expr(x); + min = min + x.min; + max = (max > std::numeric_limits::max() - x.max) + ? std::numeric_limits::max() + : max + x.max; + if(x.is_fixed()) + { + std::set new_optimals; + std::transform(optimals.begin(), + optimals.end(), + std::inserter(new_optimals, new_optimals.begin()), + [&](auto o) { return o + x.min; }); + optimals = new_optimals; + } + else + { + optimals.clear(); + } + sym = (lhs and rhs) ? optional(*lhs + *rhs) : nullopt; + return *this; +} + +dd& dd::operator-=(const dd& x) +{ + auto lhs = to_expr(*this); + auto rhs = to_expr(x); + min = (min > x.max) ? min - x.max : 0; + max = (max > x.min) ? max - x.min : 0; + if(x.is_fixed()) + { + std::set new_optimals; + std::transform(optimals.begin(), + optimals.end(), + std::inserter(new_optimals, new_optimals.begin()), + [&](auto o) { return (o > x.min) ? o - x.min : 0; }); + optimals = new_optimals; + } + else + { + optimals.clear(); + } + sym = (lhs and rhs) ? optional(*lhs - *rhs) : nullopt; + return *this; +} + +dd& dd::operator*=(const dd& x) +{ + auto lhs = to_expr(*this); + auto rhs = to_expr(x); + min = min * x.min; + max = (max > std::numeric_limits::max() / (x.max == 0 ? 1 : x.max)) + ? std::numeric_limits::max() + : max * x.max; + if(x.is_fixed()) + { + std::set new_optimals; + std::transform(optimals.begin(), + optimals.end(), + std::inserter(new_optimals, new_optimals.begin()), + [&](auto o) { return o * x.min; }); + optimals = new_optimals; + } + else + { + optimals.clear(); + } + sym = (lhs and rhs) ? optional(*lhs * *rhs) : nullopt; + return *this; +} + +dd& dd::operator/=(const dd& x) +{ + auto lhs = to_expr(*this); + auto rhs = to_expr(x); + min = (x.max == 0) ? 0 : min / x.max; + max = (x.min == 0) ? std::numeric_limits::max() : max / x.min; + if(x.is_fixed()) + { + std::set new_optimals; + std::transform(optimals.begin(), + optimals.end(), + std::inserter(new_optimals, new_optimals.begin()), + [&](auto o) { return (x.min == 0) ? std::size_t{0} : o / x.min; }); + optimals = new_optimals; + } + else + { + optimals.clear(); + } + sym = (lhs and rhs) ? optional(*lhs / *rhs) : nullopt; + return *this; +} + +dd operator+(const dd& x, const dd& y) +{ + auto result = x; + result += y; + return result; +} + +dd operator-(const dd& x, const dd& y) +{ + auto result = x; + result -= y; + return result; +} + +dd operator*(const dd& x, const dd& y) +{ + auto result = x; + result *= y; + return result; +} + +dd operator/(const dd& x, const dd& y) +{ + auto result = x; + result /= y; + return result; +} + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/test/shape_test.cpp b/test/shape_test.cpp index 386c9058aeb..8346cc4b84b 100644 --- a/test/shape_test.cpp +++ b/test/shape_test.cpp @@ -23,6 +23,7 @@ */ #include +#include #include #include #include @@ -439,9 +440,12 @@ TEST_CASE(test_shape_transposed2) TEST_CASE(test_shape_static_to_dynamic) { + using se = migraphx::symbolic_expr; migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4, 4}}; migraphx::shape s1 = s0.to_dynamic(); - migraphx::shape s2{migraphx::shape::float_type, {{1, 1}, {2, 2}, {4, 4}, {4, 4}}}; + migraphx::shape s2{migraphx::shape::float_type, + {{1, 1}, {2, 2}, {4, 4}, {4, 4}}, + {se(32), se(16), se(4), se(1)}}; EXPECT(s1 == s2); } @@ -454,6 +458,7 @@ TEST_CASE(test_shape_dyn_to_dynamic) TEST_CASE(test_shape_subshapes_to_dynamic) { + using se = migraphx::symbolic_expr; std::vector sub_shapes0 = {}; sub_shapes0.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}}); sub_shapes0.push_back(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}}); @@ -461,7 +466,8 @@ TEST_CASE(test_shape_subshapes_to_dynamic) migraphx::shape s1 = s0.to_dynamic(); std::vector sub_shapes1 = {}; sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}}); - sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, {{3, 3}, {4, 4}, {5, 5}}}); + sub_shapes1.push_back(migraphx::shape{ + migraphx::shape::float_type, {{3, 3}, {4, 4}, {5, 5}}, {se(20), se(5), se(1)}}); migraphx::shape s2{sub_shapes1}; EXPECT(s1 == s2); } @@ -1272,4 +1278,221 @@ TEST_CASE(shape_same_lens_static_dynamic) EXPECT(not migraphx::shape::same_lens(s1, s3)); } +// =================================================================== +// Symbolic dynamic_dimension tests +// =================================================================== + +TEST_CASE(test_dd_symbolic_add_size_t) +{ + using se = migraphx::symbolic_expr; + migraphx::shape::dynamic_dimension dd{1, 8, {4}, se("N")}; + dd += 2; + EXPECT(dd.min == 3); + EXPECT(dd.max == 10); + EXPECT(*dd.sym == se("N") + 2); +} + +TEST_CASE(test_dd_symbolic_sub_size_t) +{ + using se = migraphx::symbolic_expr; + migraphx::shape::dynamic_dimension dd{3, 8, {4}, se("N")}; + dd -= 1; + EXPECT(dd.min == 2); + EXPECT(dd.max == 7); + EXPECT(*dd.sym == se("N") - 1); +} + +TEST_CASE(test_dd_symbolic_mul_size_t) +{ + using se = migraphx::symbolic_expr; + migraphx::shape::dynamic_dimension dd{1, 8, {4}, se("N")}; + dd *= 3; + EXPECT(dd.min == 3); + EXPECT(dd.max == 24); + EXPECT(*dd.sym == se("N") * 3); +} + +TEST_CASE(test_dd_symbolic_div_size_t) +{ + using se = migraphx::symbolic_expr; + migraphx::shape::dynamic_dimension dd{4, 16, {8}, se("N")}; + dd /= 2; + EXPECT(dd.min == 2); + EXPECT(dd.max == 8); + EXPECT(*dd.sym == se("N") / 2); +} + +TEST_CASE(test_dd_symbolic_add_dd) +{ + using se = migraphx::symbolic_expr; + migraphx::shape::dynamic_dimension a{1, 8, {}, se("N")}; + migraphx::shape::dynamic_dimension b{2, 4, {}, se("C")}; + auto c = a + b; + EXPECT(c.min == 3); + EXPECT(c.max == 12); + EXPECT(*c.sym == se("N") + se("C")); +} + +TEST_CASE(test_dd_symbolic_sub_dd) +{ + using se = migraphx::symbolic_expr; + migraphx::shape::dynamic_dimension a{4, 16, {}, se("N")}; + migraphx::shape::dynamic_dimension b{1, 4, {}, se("K")}; + auto c = a - b; + EXPECT(c.min == 0); + EXPECT(c.max == 15); + EXPECT(*c.sym == se("N") - se("K")); +} + +TEST_CASE(test_dd_symbolic_mul_dd) +{ + using se = migraphx::symbolic_expr; + migraphx::shape::dynamic_dimension a{1, 8, {}, se("N")}; + migraphx::shape::dynamic_dimension b{2, 4, {}, se("C")}; + auto c = a * b; + EXPECT(c.min == 2); + EXPECT(c.max == 32); + EXPECT(*c.sym == se("N") * se("C")); +} + +TEST_CASE(test_dd_symbolic_div_dd) +{ + using se = migraphx::symbolic_expr; + migraphx::shape::dynamic_dimension a{4, 16, {}, se("N")}; + migraphx::shape::dynamic_dimension b{2, 4, {}, se("K")}; + auto c = a / b; + EXPECT(c.min == 1); + EXPECT(c.max == 8); + EXPECT(*c.sym == se("N") / se("K")); +} + +TEST_CASE(test_dd_symbolic_plus_fixed) +{ + using se = migraphx::symbolic_expr; + migraphx::shape::dynamic_dimension a{1, 8, {}, se("N")}; + migraphx::shape::dynamic_dimension b{3, 3}; + auto c = a + b; + EXPECT(c.sym.has_value()); + EXPECT(*c.sym == se("N") + 3); + EXPECT(c.min == 4); + EXPECT(c.max == 11); +} + +TEST_CASE(test_dd_nonfixed_nonsymbolic_plus_symbolic_drops_sym) +{ + using se = migraphx::symbolic_expr; + migraphx::shape::dynamic_dimension a{1, 8, {}}; + migraphx::shape::dynamic_dimension b{2, 4, {}, se("C")}; + auto c = a + b; + EXPECT(not c.sym.has_value()); + EXPECT(c.min == 3); + EXPECT(c.max == 12); +} + +TEST_CASE(test_dd_nonsymbolic_remains_nonsymbolic) +{ + migraphx::shape::dynamic_dimension a{1, 8, {}}; + migraphx::shape::dynamic_dimension b{2, 4, {}}; + auto c = a + b; + EXPECT(not c.sym.has_value()); +} + +TEST_CASE(test_dd_equality_with_sym) +{ + using se = migraphx::symbolic_expr; + migraphx::shape::dynamic_dimension a{1, 8, {}, se("N")}; + migraphx::shape::dynamic_dimension b{1, 8, {}, se("N")}; + migraphx::shape::dynamic_dimension c{1, 8, {}, se("C")}; + migraphx::shape::dynamic_dimension d{1, 8, {}}; + EXPECT(a == b); + EXPECT(a != c); + EXPECT(a != d); +} + +TEST_CASE(test_symbolic_shape_construction) +{ + using se = migraphx::symbolic_expr; + migraphx::shape s{migraphx::shape::float_type, + {{1, 8, {}, se("N")}, {3, 3}, {224, 224}}, + {se("N") * se(3) * se(224), se(224), se(1)}}; + EXPECT(s.dynamic()); + EXPECT(s.symbolic()); + EXPECT(s.dyn_dims().size() == 3); + EXPECT(s.dyn_strides().size() == 3); +} + +TEST_CASE(test_symbolic_stride_auto_compute) +{ + using se = migraphx::symbolic_expr; + migraphx::shape s{migraphx::shape::float_type, + {{1, 8, {}, se("N")}, {1, 16, {}, se("S")}, {4, 4}}}; + EXPECT(s.symbolic()); + EXPECT(s.dyn_strides().size() == 3); + EXPECT(s.dyn_strides()[2] == se(1)); + EXPECT(s.dyn_strides()[1] == se(4)); + EXPECT(s.dyn_strides()[0] == se("S") * 4); +} + +TEST_CASE(test_symbolic_to_static) +{ + using se = migraphx::symbolic_expr; + migraphx::shape s{migraphx::shape::float_type, + {{1, 8, {}, se("N")}, {1, 16, {}, se("S")}, {4, 4}}}; + std::map symbol_map = {{"N", 2}, {"S", 8}}; + auto s_static = s.to_static(symbol_map); + EXPECT(not s_static.dynamic()); + EXPECT(s_static.lens() == std::vector{2, 8, 4}); + EXPECT(s_static.strides() == std::vector{32, 4, 1}); +} + +TEST_CASE(test_symbolic_shape_equality) +{ + using se = migraphx::symbolic_expr; + migraphx::shape s1{migraphx::shape::float_type, {{1, 8, {}, se("N")}, {3, 3}}}; + migraphx::shape s2{migraphx::shape::float_type, {{1, 8, {}, se("N")}, {3, 3}}}; + migraphx::shape s3{migraphx::shape::float_type, {{1, 8, {}, se("C")}, {3, 3}}}; + EXPECT(s1 == s2); + EXPECT(s1 != s3); +} + +TEST_CASE(test_symbolic_shape_print) +{ + using se = migraphx::symbolic_expr; + auto to_str = [](const migraphx::shape& s) { + std::stringstream ss; + ss << s; + return ss.str(); + }; + migraphx::shape s1{migraphx::shape::float_type, {{1, 8, {}, se("N")}, {3, 3}, {4, 4}}}; + migraphx::shape s2{migraphx::shape::float_type, {{1, 8, {}, se("N")}, {3, 3}, {4, 4}}}; + migraphx::shape s3{migraphx::shape::float_type, {{1, 8, {}, se("C")}, {3, 3}, {4, 4}}}; + EXPECT(to_str(s1) == to_str(s2)); + EXPECT(to_str(s1) != to_str(s3)); +} + +TEST_CASE(test_dd_intersection_symbolic) +{ + using se = migraphx::symbolic_expr; + migraphx::shape::dynamic_dimension a{1, 8, {}, se("N")}; + migraphx::shape::dynamic_dimension b{2, 6}; + auto result = a.intersection(b); + EXPECT(result.has_value()); + EXPECT(result->min == 2); + EXPECT(result->max == 6); + EXPECT(result->sym.has_value()); + EXPECT(*result->sym == se("N")); +} + +TEST_CASE(test_dd_intersection_fixed_drops_sym) +{ + using se = migraphx::symbolic_expr; + migraphx::shape::dynamic_dimension a{1, 8, {}, se("N")}; + migraphx::shape::dynamic_dimension b{4, 4}; + auto result = a.intersection(b); + EXPECT(result.has_value()); + EXPECT(result->min == 4); + EXPECT(result->max == 4); + EXPECT(not result->sym.has_value()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From 18caf6b3b157d1974798dcc84a7fee1782015055 Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 24 Mar 2026 17:39:15 -0700 Subject: [PATCH 06/60] tidy --- src/symbolic.cpp | 148 ++++++++++++++++++++++++----------------------- 1 file changed, 77 insertions(+), 71 deletions(-) diff --git a/src/symbolic.cpp b/src/symbolic.cpp index 36e7d3b1090..6d19771f0a2 100644 --- a/src/symbolic.cpp +++ b/src/symbolic.cpp @@ -114,7 +114,7 @@ static const mul_data& get_mul(const expr_ptr& e) { return std::get(e- static std::size_t hash_combine(std::size_t seed, std::size_t v) { - return seed ^ (v + 0x9e3779b9 + (seed << 6) + (seed >> 2)); + return seed ^ (v + 0x9e3779b9u + (seed << 6u) + (seed >> 2u)); } template @@ -541,74 +541,80 @@ enum static std::string print_expr(const expr_ptr& e, int parent_prec = 0); +static std::string print_add(const add_data& d, int parent_prec) +{ + std::ostringstream os; + bool first = true; + for(const auto& [term, coeff] : d.terms) + { + if(first) + { + if(coeff == -1) + os << "-" << print_expr(term, prec_add); + else if(coeff == 1) + os << print_expr(term, prec_add); + else + os << coeff << "*" << print_expr(term, prec_mul + 1); + first = false; + } + else + { + if(coeff == 1) + os << " + " << print_expr(term, prec_add); + else if(coeff == -1) + os << " - " << print_expr(term, prec_add); + else if(coeff > 0) + os << " + " << coeff << "*" << print_expr(term, prec_mul + 1); + else + os << " - " << (-coeff) << "*" << print_expr(term, prec_mul + 1); + } + } + if(d.constant > 0) + os << " + " << d.constant; + else if(d.constant < 0) + os << " - " << (-d.constant); + std::string s = os.str(); + if(parent_prec > prec_add) + return "(" + s + ")"; + return s; +} + +static std::string print_mul(const mul_data& d, int parent_prec) +{ + std::ostringstream os; + bool first = true; + if(d.coefficient == -1) + { + os << "-"; + } + else if(d.coefficient != 1) + { + os << d.coefficient; + first = false; + } + for(const auto& [base, exp] : d.factors) + { + if(not first) + os << "*"; + if(exp == 1) + os << print_expr(base, prec_mul + 1); + else + os << print_expr(base, prec_mul + 1) << "**" << exp; + first = false; + } + std::string raw = os.str(); + if(parent_prec > prec_mul) + return "(" + raw + ")"; + return raw; +} + static std::string print_expr(const expr_ptr& e, int parent_prec) { return std::visit( overloaded{[](const integer_data& d) -> std::string { return std::to_string(d.value); }, [](const symbol_data& d) -> std::string { return d.name; }, - [&](const add_data& d) -> std::string { - std::ostringstream os; - bool first = true; - for(const auto& [term, coeff] : d.terms) - { - if(first) - { - if(coeff == -1) - os << "-" << print_expr(term, prec_add); - else if(coeff == 1) - os << print_expr(term, prec_add); - else - os << coeff << "*" << print_expr(term, prec_mul + 1); - first = false; - } - else - { - if(coeff == 1) - os << " + " << print_expr(term, prec_add); - else if(coeff == -1) - os << " - " << print_expr(term, prec_add); - else if(coeff > 0) - os << " + " << coeff << "*" << print_expr(term, prec_mul + 1); - else - os << " - " << (-coeff) << "*" << print_expr(term, prec_mul + 1); - } - } - if(d.constant > 0) - os << " + " << d.constant; - else if(d.constant < 0) - os << " - " << (-d.constant); - std::string s = os.str(); - if(parent_prec > prec_add) - return "(" + s + ")"; - return s; - }, - [&](const mul_data& d) -> std::string { - std::ostringstream os; - bool first = true; - if(d.coefficient == -1) - { - os << "-"; - } - else if(d.coefficient != 1) - { - os << d.coefficient; - first = false; - } - for(const auto& [base, exp] : d.factors) - { - if(not first) - os << "*"; - if(exp == 1) - os << print_expr(base, prec_mul + 1); - else - os << print_expr(base, prec_mul + 1) << "**" << exp; - first = false; - } - std::string raw = os.str(); - if(parent_prec > prec_mul) - return "(" + raw + ")"; - return raw; - }, + [&](const add_data& d) -> std::string { return print_add(d, parent_prec); }, + [&](const mul_data& d) -> std::string { return print_mul(d, parent_prec); }, [&](const fdiv_data& d) -> std::string { std::string s = print_expr(d.numerator, prec_mul + 1) + "/" + print_expr(d.denominator, prec_mul + 1); @@ -625,7 +631,7 @@ static std::string print_expr(const expr_ptr& e, int parent_prec) static void skip_ws(const char*& p) { - while(*p and std::isspace(static_cast(*p))) + while(*p != 0 and std::isspace(static_cast(*p)) != 0) ++p; } @@ -637,20 +643,20 @@ static expr_ptr parse_primary(const char*& p); static expr_ptr parse_primary(const char*& p) { skip_ws(p); - if(std::isdigit(static_cast(*p))) + if(std::isdigit(static_cast(*p)) != 0) { int64_t n = 0; - while(std::isdigit(static_cast(*p))) + while(std::isdigit(static_cast(*p)) != 0) { n = n * 10 + (*p - '0'); ++p; } return make_integer(n); } - if(std::isalpha(static_cast(*p)) or *p == '_') + if(std::isalpha(static_cast(*p)) != 0 or *p == '_') { std::string name; - while(std::isalnum(static_cast(*p)) or *p == '_') + while(std::isalnum(static_cast(*p)) != 0 or *p == '_') { name += *p; ++p; @@ -790,7 +796,7 @@ std::size_t symbolic_expr::eval(const std::map& symbol if(empty()) return 0; auto v = evaluate(p->node, symbol_map); - assert(v >= 0 && "symbolic dimension evaluated to negative value"); + assert(v >= 0 and "symbolic dimension evaluated to negative value"); return static_cast(v); } @@ -863,7 +869,7 @@ void migraphx_to_value(value& v, const symbolic_expr& e) { v = migraphx::to_valu void migraphx_from_value(const value& v, symbolic_expr& e) { - auto s = v.get_string(); + const auto& s = v.get_string(); if(s.empty()) e = symbolic_expr{}; else From 6af3621713f9d69b1ab93e0920dbb91742096ab6 Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 25 Mar 2026 11:16:47 -0700 Subject: [PATCH 07/60] fix constructor ambiguity --- src/include/migraphx/shape.hpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index 7a19a8b5449..698b11c79ef 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -220,7 +220,9 @@ struct MIGRAPHX_EXPORT shape shape(type_t t, std::vector dims); - shape(type_t t, std::vector dims, std::vector dstrides); + explicit shape(type_t t, + std::vector dims, + std::vector dstrides); // Construct a dynamic shape from vectors of mins, maxes, and optimals. // optimals_list is a vector of optimals that corresponds to each min and max. From b7d7c237ea70f49c18fa47abba3b8974777bc920 Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 25 Mar 2026 14:08:29 -0700 Subject: [PATCH 08/60] fix ambiguity --- src/include/migraphx/shape.hpp | 5 ++--- src/shape.cpp | 7 +++++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index 698b11c79ef..392a36be457 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -217,12 +217,11 @@ struct MIGRAPHX_EXPORT shape // Force all calls of the format `shape( type_t, { size_t compatibles } )` to map to // shape(type_t, std::vector l) shape(type_t t, std::initializer_list d); + shape(type_t t, std::initializer_list l, std::initializer_list s); shape(type_t t, std::vector dims); - explicit shape(type_t t, - std::vector dims, - std::vector dstrides); + shape(type_t t, std::vector dims, std::vector dstrides); // Construct a dynamic shape from vectors of mins, maxes, and optimals. // optimals_list is a vector of optimals that corresponds to each min and max. diff --git a/src/shape.cpp b/src/shape.cpp index 667426a1086..3f632229d57 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -382,6 +382,13 @@ shape::shape(type_t t, std::initializer_list d) { } +shape::shape(type_t t, std::initializer_list l, std::initializer_list s) + : shape::shape(t, + std::vector{l.begin(), l.end()}, + std::vector{s.begin(), s.end()}) +{ +} + shape::shape(type_t t, std::vector dims) : impl(std::make_shared(t, std::move(dims))) { From 34861354e1c542b62144ebddb05a7927b23bf94d Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 25 Mar 2026 18:39:40 -0700 Subject: [PATCH 09/60] update namespace and interface design --- src/CMakeLists.txt | 2 +- src/include/migraphx/sym.hpp | 101 ++++ src/include/migraphx/symbolic.hpp | 105 ---- src/{symbolic.cpp => sym.cpp} | 259 +++++++--- test/sym_test.cpp | 780 ++++++++++++++++++++++++++++++ test/symbolic_test.cpp | 625 ------------------------ 6 files changed, 1072 insertions(+), 800 deletions(-) create mode 100644 src/include/migraphx/sym.hpp delete mode 100644 src/include/migraphx/symbolic.hpp rename src/{symbolic.cpp => sym.cpp} (75%) create mode 100644 test/sym_test.cpp delete mode 100644 test/symbolic_test.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 9905c198719..ea10f64fd34 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -136,7 +136,7 @@ add_library(migraphx serialize.cpp shape.cpp shape_transform_descriptor.cpp - symbolic.cpp + sym.cpp simplify_algebra.cpp simplify_dyn_ops.cpp simplify_reshapes.cpp diff --git a/src/include/migraphx/sym.hpp b/src/include/migraphx/sym.hpp new file mode 100644 index 00000000000..99b86093f92 --- /dev/null +++ b/src/include/migraphx/sym.hpp @@ -0,0 +1,101 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_SYM_HPP +#define MIGRAPHX_GUARD_MIGRAPHLIB_SYM_HPP + +#include +#include +#include +#include +#include + +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct value; + +namespace sym { + +struct MIGRAPHX_EXPORT expr +{ + expr(); + + bool empty() const; + std::size_t hash() const; + std::string to_string() const; + value to_value() const; + void from_value(const value& v); + std::size_t eval(const std::unordered_map& symbol_map) const; + expr subs(const std::unordered_map& symbol_map) const; + + MIGRAPHX_EXPORT friend expr operator+(const expr& a, const expr& b); + MIGRAPHX_EXPORT friend expr operator-(const expr& a, const expr& b); + MIGRAPHX_EXPORT friend expr operator*(const expr& a, const expr& b); + MIGRAPHX_EXPORT friend expr operator/(const expr& a, const expr& b); + MIGRAPHX_EXPORT friend bool operator==(const expr& a, const expr& b); + MIGRAPHX_EXPORT friend bool operator!=(const expr& a, const expr& b); + MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const expr& e); + + struct impl; + + friend expr var(const std::string& name); + friend expr lit(std::size_t n); + friend expr parse(const std::string& s); + + private: + expr(std::shared_ptr pi); + std::shared_ptr p; +}; + +MIGRAPHX_EXPORT expr var(const std::string& name); +MIGRAPHX_EXPORT expr lit(std::size_t n); +MIGRAPHX_EXPORT expr parse(const std::string& s); + +inline expr operator+(const expr& a, std::size_t b) { return a + lit(b); } +inline expr operator+(std::size_t a, const expr& b) { return lit(a) + b; } +inline expr operator-(const expr& a, std::size_t b) { return a - lit(b); } +inline expr operator-(std::size_t a, const expr& b) { return lit(a) - b; } +inline expr operator*(const expr& a, std::size_t b) { return a * lit(b); } +inline expr operator*(std::size_t a, const expr& b) { return lit(a) * b; } +inline expr operator/(const expr& a, std::size_t b) { return a / lit(b); } +inline expr operator/(std::size_t a, const expr& b) { return lit(a) / b; } + +} // namespace sym + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +namespace std { +template <> +struct hash +{ + using argument_type = migraphx::sym::expr; + using result_type = std::size_t; + result_type operator()(const migraphx::sym::expr& e) const { return e.hash(); } +}; +} // namespace std + +#endif diff --git a/src/include/migraphx/symbolic.hpp b/src/include/migraphx/symbolic.hpp deleted file mode 100644 index 0cb0aa4c2b2..00000000000 --- a/src/include/migraphx/symbolic.hpp +++ /dev/null @@ -1,105 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ -#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_SYMBOLIC_HPP -#define MIGRAPHX_GUARD_MIGRAPHLIB_SYMBOLIC_HPP - -#include -#include -#include -#include -#include - -#include - -namespace migraphx { -inline namespace MIGRAPHX_INLINE_NS { - -struct value; - -struct MIGRAPHX_EXPORT symbolic_expr -{ - symbolic_expr(); - explicit symbolic_expr(std::size_t n); - explicit symbolic_expr(const std::string& s); - - bool empty() const; - std::string to_string() const; - std::size_t eval(const std::map& symbol_map) const; - symbolic_expr subs(const std::map& symbol_map) const; - - MIGRAPHX_EXPORT friend symbolic_expr operator+(const symbolic_expr& a, const symbolic_expr& b); - MIGRAPHX_EXPORT friend symbolic_expr operator-(const symbolic_expr& a, const symbolic_expr& b); - MIGRAPHX_EXPORT friend symbolic_expr operator*(const symbolic_expr& a, const symbolic_expr& b); - MIGRAPHX_EXPORT friend symbolic_expr operator/(const symbolic_expr& a, const symbolic_expr& b); - MIGRAPHX_EXPORT friend bool operator==(const symbolic_expr& a, const symbolic_expr& b); - MIGRAPHX_EXPORT friend bool operator!=(const symbolic_expr& a, const symbolic_expr& b); - MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const symbolic_expr& e); - - struct impl; - - private: - symbolic_expr(std::shared_ptr pi); - std::shared_ptr p; -}; - -inline symbolic_expr operator+(const symbolic_expr& a, std::size_t b) -{ - return a + symbolic_expr(b); -} -inline symbolic_expr operator+(std::size_t a, const symbolic_expr& b) -{ - return symbolic_expr(a) + b; -} -inline symbolic_expr operator-(const symbolic_expr& a, std::size_t b) -{ - return a - symbolic_expr(b); -} -inline symbolic_expr operator-(std::size_t a, const symbolic_expr& b) -{ - return symbolic_expr(a) - b; -} -inline symbolic_expr operator*(const symbolic_expr& a, std::size_t b) -{ - return a * symbolic_expr(b); -} -inline symbolic_expr operator*(std::size_t a, const symbolic_expr& b) -{ - return symbolic_expr(a) * b; -} -inline symbolic_expr operator/(const symbolic_expr& a, std::size_t b) -{ - return a / symbolic_expr(b); -} -inline symbolic_expr operator/(std::size_t a, const symbolic_expr& b) -{ - return symbolic_expr(a) / b; -} - -MIGRAPHX_EXPORT void migraphx_to_value(value& v, const symbolic_expr& e); -MIGRAPHX_EXPORT void migraphx_from_value(const value& v, symbolic_expr& e); - -} // namespace MIGRAPHX_INLINE_NS -} // namespace migraphx - -#endif diff --git a/src/symbolic.cpp b/src/sym.cpp similarity index 75% rename from src/symbolic.cpp rename to src/sym.cpp index 6d19771f0a2..d8d035db477 100644 --- a/src/symbolic.cpp +++ b/src/sym.cpp @@ -22,7 +22,7 @@ * THE SOFTWARE. */ -#include +#include #include #include @@ -114,7 +114,7 @@ static const mul_data& get_mul(const expr_ptr& e) { return std::get(e- static std::size_t hash_combine(std::size_t seed, std::size_t v) { - return seed ^ (v + 0x9e3779b9u + (seed << 6u) + (seed >> 2u)); + return seed ^ (v + 0x9e3779b9 + (seed << 6) + (seed >> 2)); } template @@ -449,15 +449,16 @@ static expr_ptr make_floor_div(const expr_ptr& a, const expr_ptr& b) // Section 6: Substitution and evaluation // =================================================================== -// Partial substitution: replaces bound symbols with integers and re-canonicalizes. -// Unbound symbols are left as-is, producing a simplified symbolic expression. -static expr_ptr substitute(const expr_ptr& e, const std::map& bindings) +using binding_map = std::map; +using subs_map = std::map; + +static expr_ptr substitute(const expr_ptr& e, const subs_map& bindings) { return std::visit(overloaded{[&](const integer_data&) -> expr_ptr { return e; }, - [&](const symbol_data& d) -> expr_ptr { - auto it = bindings.find(d.name); + [&](const symbol_data&) -> expr_ptr { + auto it = bindings.find(e); if(it != bindings.end()) - return make_integer(it->second); + return it->second; return e; }, [&](const add_data& d) -> expr_ptr { @@ -488,46 +489,39 @@ static expr_ptr substitute(const expr_ptr& e, const std::mapdata); } -// Full evaluation: computes integer result directly without allocations. -// All symbols must be bound; throws if any symbol is unbound. -static int64_t eval_direct(const expr_ptr& e, const std::map& symbol_map) +static int64_t eval_direct(const expr_ptr& e, const binding_map& bindings) { return std::visit(overloaded{[](const integer_data& d) -> int64_t { return d.value; }, [&](const symbol_data& d) -> int64_t { - auto it = symbol_map.find(d.name); - if(it != symbol_map.end()) + auto it = bindings.find(e); + if(it != bindings.end()) return static_cast(it->second); - MIGRAPHX_THROW("symbolic_expr::eval: unbound symbol '" + - d.name + "'"); + MIGRAPHX_THROW("sym::expr::eval: unbound symbol '" + d.name + + "'"); }, [&](const add_data& d) -> int64_t { int64_t sum = d.constant; for(const auto& [term, coeff] : d.terms) - sum += coeff * eval_direct(term, symbol_map); + sum += coeff * eval_direct(term, bindings); return sum; }, [&](const mul_data& d) -> int64_t { int64_t prod = d.coefficient; for(const auto& [base, exp] : d.factors) { - int64_t val = eval_direct(base, symbol_map); + int64_t val = eval_direct(base, bindings); for(int64_t i = 0; i < exp; ++i) prod *= val; } return prod; }, [&](const fdiv_data& d) -> int64_t { - return eval_direct(d.numerator, symbol_map) / - eval_direct(d.denominator, symbol_map); + return eval_direct(d.numerator, bindings) / + eval_direct(d.denominator, bindings); }}, e->data); } -static int64_t evaluate(const expr_ptr& e, const std::map& symbol_map) -{ - return eval_direct(e, symbol_map); -} - // =================================================================== // Section 7: Pretty-printer // =================================================================== @@ -631,7 +625,7 @@ static std::string print_expr(const expr_ptr& e, int parent_prec) static void skip_ws(const char*& p) { - while(*p != 0 and std::isspace(static_cast(*p)) != 0) + while(*p and std::isspace(static_cast(*p))) ++p; } @@ -643,20 +637,20 @@ static expr_ptr parse_primary(const char*& p); static expr_ptr parse_primary(const char*& p) { skip_ws(p); - if(std::isdigit(static_cast(*p)) != 0) + if(std::isdigit(static_cast(*p))) { int64_t n = 0; - while(std::isdigit(static_cast(*p)) != 0) + while(std::isdigit(static_cast(*p))) { n = n * 10 + (*p - '0'); ++p; } return make_integer(n); } - if(std::isalpha(static_cast(*p)) != 0 or *p == '_') + if(std::isalpha(static_cast(*p)) or *p == '_') { std::string name; - while(std::isalnum(static_cast(*p)) != 0 or *p == '_') + while(std::isalnum(static_cast(*p)) or *p == '_') { name += *p; ++p; @@ -755,10 +749,12 @@ static expr_ptr parse_string(const std::string& s) } // =================================================================== -// Section 9: symbolic_expr public API wrapper +// Section 9: sym::expr public API wrapper // =================================================================== -struct symbolic_expr::impl +namespace sym { + +struct expr::impl { expr_ptr node; @@ -766,88 +762,93 @@ struct symbolic_expr::impl explicit impl(expr_ptr e) : node(std::move(e)) {} }; -symbolic_expr::symbolic_expr() = default; +expr::expr() = default; -symbolic_expr::symbolic_expr(std::shared_ptr pi) : p(std::move(pi)) {} +expr::expr(std::shared_ptr pi) : p(std::move(pi)) {} -symbolic_expr::symbolic_expr(std::size_t n) - : p(std::make_shared(make_integer(static_cast(n)))) -{ -} +bool expr::empty() const { return p == nullptr; } -symbolic_expr::symbolic_expr(const std::string& s) +std::size_t expr::hash() const { - if(s.empty()) - return; - p = std::make_shared(parse_string(s)); + if(empty()) + return 0; + return p->node->cached_hash; } -bool symbolic_expr::empty() const { return p == nullptr; } - -std::string symbolic_expr::to_string() const +std::string expr::to_string() const { if(empty()) return {}; return print_expr(p->node); } -std::size_t symbolic_expr::eval(const std::map& symbol_map) const +std::size_t expr::eval(const std::unordered_map& symbol_map) const { if(empty()) return 0; - auto v = evaluate(p->node, symbol_map); - assert(v >= 0 and "symbolic dimension evaluated to negative value"); + binding_map bindings; + for(const auto& [k, v] : symbol_map) + { + if(k.empty() or not holds(k.p->node)) + MIGRAPHX_THROW("sym::expr::eval: map key '" + k.to_string() + "' is not a symbol"); + bindings[k.p->node] = v; + } + auto v = eval_direct(p->node, bindings); + assert(v >= 0 && "symbolic dimension evaluated to negative value"); return static_cast(v); } -symbolic_expr symbolic_expr::subs(const std::map& symbol_map) const +expr expr::subs(const std::unordered_map& symbol_map) const { if(empty()) return {}; - std::map bindings; + subs_map bindings; for(const auto& [k, v] : symbol_map) - bindings[k] = static_cast(v); - auto result = substitute(p->node, bindings); - return {std::make_shared(std::move(result))}; + { + if(k.empty() or not holds(k.p->node)) + MIGRAPHX_THROW("sym::expr::subs: map key '" + k.to_string() + "' is not a symbol"); + bindings[k.p->node] = v.p ? v.p->node : make_integer(0); + } + return {std::make_shared(substitute(p->node, bindings))}; } -symbolic_expr operator+(const symbolic_expr& a, const symbolic_expr& b) +expr operator+(const expr& a, const expr& b) { if(a.empty() and b.empty()) return {}; auto ea = a.p ? a.p->node : make_integer(0); auto eb = b.p ? b.p->node : make_integer(0); - return {std::make_shared(make_add(ea, eb))}; + return {std::make_shared(make_add(ea, eb))}; } -symbolic_expr operator-(const symbolic_expr& a, const symbolic_expr& b) +expr operator-(const expr& a, const expr& b) { if(a.empty() and b.empty()) return {}; auto ea = a.p ? a.p->node : make_integer(0); auto eb = b.p ? b.p->node : make_integer(0); - return {std::make_shared(make_sub(ea, eb))}; + return {std::make_shared(make_sub(ea, eb))}; } -symbolic_expr operator*(const symbolic_expr& a, const symbolic_expr& b) +expr operator*(const expr& a, const expr& b) { if(a.empty() and b.empty()) return {}; auto ea = a.p ? a.p->node : make_integer(0); auto eb = b.p ? b.p->node : make_integer(0); - return {std::make_shared(make_mul(ea, eb))}; + return {std::make_shared(make_mul(ea, eb))}; } -symbolic_expr operator/(const symbolic_expr& a, const symbolic_expr& b) +expr operator/(const expr& a, const expr& b) { if(a.empty() and b.empty()) return {}; auto ea = a.p ? a.p->node : make_integer(0); auto eb = b.p ? b.p->node : make_integer(0); - return {std::make_shared(make_floor_div(ea, eb))}; + return {std::make_shared(make_floor_div(ea, eb))}; } -bool operator==(const symbolic_expr& a, const symbolic_expr& b) +bool operator==(const expr& a, const expr& b) { if(a.empty() and b.empty()) return true; @@ -856,25 +857,145 @@ bool operator==(const symbolic_expr& a, const symbolic_expr& b) return expr_equal(a.p->node, b.p->node); } -bool operator!=(const symbolic_expr& a, const symbolic_expr& b) { return not(a == b); } +bool operator!=(const expr& a, const expr& b) { return not(a == b); } -std::ostream& operator<<(std::ostream& os, const symbolic_expr& e) +std::ostream& operator<<(std::ostream& os, const expr& e) { if(not e.empty()) os << print_expr(e.p->node); return os; } -void migraphx_to_value(value& v, const symbolic_expr& e) { v = migraphx::to_value(e.to_string()); } +expr var(const std::string& name) { return {std::make_shared(make_symbol(name))}; } + +expr lit(std::size_t n) +{ + return {std::make_shared(make_integer(static_cast(n)))}; +} -void migraphx_from_value(const value& v, symbolic_expr& e) +expr parse(const std::string& s) { - const auto& s = v.get_string(); if(s.empty()) - e = symbolic_expr{}; - else - e = symbolic_expr(s); + return {}; + return {std::make_shared(parse_string(s))}; +} + +static value node_to_value(const expr_ptr& e) +{ + return std::visit(overloaded{[](const integer_data& d) -> value { + value r; + r["type"] = "int"; + r["value"] = d.value; + return r; + }, + [](const symbol_data& d) -> value { + value r; + r["type"] = "sym"; + r["name"] = d.name; + return r; + }, + [](const add_data& d) -> value { + value r; + r["type"] = "add"; + r["constant"] = d.constant; + value terms = value::array{}; + for(const auto& [term, coeff] : d.terms) + { + value t; + t["expr"] = node_to_value(term); + t["coeff"] = coeff; + terms.push_back(t); + } + r["terms"] = terms; + return r; + }, + [](const mul_data& d) -> value { + value r; + r["type"] = "mul"; + r["coeff"] = d.coefficient; + value factors = value::array{}; + for(const auto& [base, exp] : d.factors) + { + value f; + f["expr"] = node_to_value(base); + f["exp"] = exp; + factors.push_back(f); + } + r["factors"] = factors; + return r; + }, + [](const fdiv_data& d) -> value { + value r; + r["type"] = "fdiv"; + r["num"] = node_to_value(d.numerator); + r["den"] = node_to_value(d.denominator); + return r; + }}, + e->data); } +static expr_ptr node_from_value(const value& v) +{ + const auto& type = v.at("type").get_string(); + if(type == "int") + { + return make_integer(v.at("value").get_int64()); + } + else if(type == "sym") + { + return make_symbol(v.at("name").get_string()); + } + else if(type == "add") + { + auto constant = v.at("constant").get_int64(); + term_map terms; + for(const auto& t : v.at("terms")) + { + auto term = node_from_value(t.at("expr")); + auto coeff = t.at("coeff").get_int64(); + terms[term] = coeff; + } + return build_add(constant, std::move(terms)); + } + else if(type == "mul") + { + auto coefficient = v.at("coeff").get_int64(); + factor_map factors; + for(const auto& f : v.at("factors")) + { + auto base = node_from_value(f.at("expr")); + auto exp = f.at("exp").get_int64(); + factors[base] = exp; + } + return build_mul(coefficient, std::move(factors)); + } + else if(type == "fdiv") + { + auto num = node_from_value(v.at("num")); + auto den = node_from_value(v.at("den")); + return make_floor_div(num, den); + } + MIGRAPHX_THROW("Unknown sym::expr node type: " + type); +} + +value expr::to_value() const +{ + if(empty()) + return {}; + return node_to_value(p->node); +} + +void expr::from_value(const value& v) +{ + if(v.is_null()) + { + *this = expr{}; + return; + } + *this = expr{std::make_shared(node_from_value(v))}; +} + +} // namespace sym + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/test/sym_test.cpp b/test/sym_test.cpp new file mode 100644 index 00000000000..a14010ef754 --- /dev/null +++ b/test/sym_test.cpp @@ -0,0 +1,780 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include "test.hpp" + +using se = migraphx::sym::expr; +using migraphx::sym::lit; +using migraphx::sym::parse; +using migraphx::sym::var; + +// =================================================================== +// Tier 1: Expression construction and canonicalization +// =================================================================== + +TEST_CASE(construct_integer) +{ + EXPECT(lit(0).to_string() == "0"); + EXPECT(lit(1).to_string() == "1"); + EXPECT(lit(42).to_string() == "42"); +} + +TEST_CASE(construct_symbol) +{ + EXPECT(var("H").to_string() == "H"); + EXPECT(var("batch_size").to_string() == "batch_size"); +} + +TEST_CASE(construct_empty) +{ + se e; + EXPECT(e.empty()); + EXPECT(e.to_string().empty()); +} + +TEST_CASE(add_identity) +{ + auto H = var("H"); + EXPECT(H + 0 == H); + EXPECT(0 + H == H); +} + +TEST_CASE(add_commutativity) +{ + auto H = var("H"), W = var("W"); + EXPECT(H + W == W + H); +} + +TEST_CASE(add_like_term_folding) +{ + auto H = var("H"); + auto r = H + H; + EXPECT(r == 2 * H); +} + +TEST_CASE(add_constant_folding) +{ + EXPECT(lit(3) + lit(5) == lit(8)); + EXPECT(lit(0) + lit(0) == lit(0)); +} + +TEST_CASE(add_flattening) +{ + auto H = var("H"), W = var("W"), C = var("C"); + auto a = (H + W) + C; + auto b = H + (W + C); + auto c = (C + H) + W; + EXPECT(a == b); + EXPECT(b == c); +} + +TEST_CASE(add_mixed) +{ + auto H = var("H"); + auto r = H + 3 + H + 2; + EXPECT(r == 2 * H + 5); + auto r2 = H + H; + EXPECT(r2 + 5 == 2 * H + 5); +} + +TEST_CASE(add_cancellation) +{ + auto H = var("H"); + EXPECT(H + (-1 * H) == lit(0)); +} + +TEST_CASE(sub_identity) +{ + auto H = var("H"); + EXPECT(H - 0 == H); +} + +TEST_CASE(sub_self) +{ + auto H = var("H"); + EXPECT(H - H == lit(0)); +} + +TEST_CASE(sub_constant_folding) { EXPECT(lit(10) - lit(3) == lit(7)); } + +TEST_CASE(sub_produces_negation) +{ + auto H = var("H"); + EXPECT(-1 * H == lit(0) - H); + EXPECT(3 - H == lit(0) - H + 3); +} + +TEST_CASE(neg_integer) +{ + auto r = lit(0) - 5; + EXPECT(r == lit(-5)); +} + +TEST_CASE(neg_double_negation) +{ + auto H = var("H"); + auto r = 0 - (-1 * H); + EXPECT(r == H); +} + +TEST_CASE(mul_identity) +{ + auto H = var("H"); + EXPECT(H * 1 == H); + EXPECT(1 * H == H); +} + +TEST_CASE(mul_zero) +{ + auto H = var("H"); + EXPECT(H * 0 == lit(0)); + EXPECT(0 * H == lit(0)); +} + +TEST_CASE(mul_constant_folding) { EXPECT(lit(3) * lit(7) == lit(21)); } + +TEST_CASE(mul_commutativity) +{ + auto A = var("A"), B = var("B"); + EXPECT(B * A == A * B); +} + +TEST_CASE(mul_coefficient_accumulation) +{ + auto H = var("H"); + auto r = 2 * H * 3; + EXPECT(r == 6 * H); +} + +TEST_CASE(mul_flattening) +{ + auto H = var("H"), W = var("W"), C = var("C"); + auto a = (H * W) * C; + auto b = H * (W * C); + auto c = (C * H) * W; + EXPECT(a == b); + EXPECT(b == c); +} + +TEST_CASE(mul_distributive) +{ + auto H = var("H"); + auto r = 2 * (H + 1); + EXPECT(r == 2 * H + 2); +} + +TEST_CASE(fdiv_identity) +{ + auto H = var("H"); + EXPECT(H / 1 == H); +} + +TEST_CASE(fdiv_constant_folding) +{ + EXPECT(lit(7) / lit(2) == lit(3)); + EXPECT(lit(6) / lit(3) == lit(2)); + EXPECT(lit(0) / lit(5) == lit(0)); +} + +TEST_CASE(fdiv_exact_coefficient_cancel) +{ + auto N = var("N"); + auto r = (6 * N) / 3; + EXPECT(r == 2 * N); +} + +TEST_CASE(fdiv_non_simplifiable) +{ + auto H = var("H"); + auto r = (H - 1) / 2; + EXPECT(r == (H - 1) / 2); +} + +TEST_CASE(fdiv_division_by_zero) +{ + EXPECT(test::throws([&] { var("H") / 0; })); +} + +TEST_CASE(add_scaled_subtraction) +{ + auto H = var("H"); + EXPECT(2 * H - H == H); + EXPECT(3 * H - 2 * H == H); + EXPECT(H + H + H == 3 * H); +} + +TEST_CASE(add_of_two_adds) +{ + auto H = var("H"); + auto r = (H + 1) + (H + 2); + EXPECT(r == 2 * H + 3); +} + +TEST_CASE(sub_strip_constant) +{ + auto H = var("H"); + EXPECT((H + 1) - H == lit(1)); +} + +TEST_CASE(sub_of_two_adds) +{ + auto H = var("H"); + auto r = (H + 1) - (H + 2); + EXPECT(r == lit(-1)); +} + +TEST_CASE(mul_zero_propagation) +{ + auto H = var("H"); + auto z = H - H; + EXPECT(50 * z == lit(0)); +} + +TEST_CASE(add_chain_constant_cancel) +{ + auto H = var("H"); + auto r = lit(2) - H - lit(2); + EXPECT(r == -1 * H); +} + +TEST_CASE(neg_of_sum_distributes) +{ + auto H = var("H"); + auto r = lit(-1) * (H + 1); + EXPECT(r == -1 * H - 1); +} + +TEST_CASE(neg_of_product_double) +{ + auto hw = var("H") * var("W"); + auto neg = 0 - hw; + EXPECT(neg == lit(0) - hw); + auto dbl = 0 - neg; + EXPECT(dbl == hw); +} + +TEST_CASE(add_compound_product_like_terms) +{ + auto H = var("H"), W = var("W"); + auto hw = H * W; + auto wh = W * H; + EXPECT(hw + hw == 2 * hw); + EXPECT(hw + wh == 2 * hw); + EXPECT(hw + 2 * hw == 3 * hw); +} + +TEST_CASE(add_compound_product_cancellation) +{ + auto hw = var("H") * var("W"); + EXPECT(hw - hw == lit(0)); +} + +// X*Y and X cancel pairwise: (X*Y - X) - (X*Y - X) == 0 +TEST_CASE(sub_compound_product_mixed) +{ + auto X = var("X"), Y = var("Y"); + auto xy = X * Y; + auto r = xy - X - xy + X; + EXPECT(r == lit(0)); +} + +// Duplicate A*B terms fold even when separated by another term +TEST_CASE(add_multi_term_accumulation) +{ + auto A = var("A"), B = var("B"), C = var("C"); + auto r = A * B + C + A * B; + auto expected = 2 * (A * B) + C; + EXPECT(r == expected); +} + +TEST_CASE(fdiv_negative_constant_folding) +{ + EXPECT(lit(-7) / lit(2) == lit(-7 / 2)); + EXPECT(lit(-6) / lit(3) == lit(-2)); + EXPECT(lit(7) / lit(-2) == lit(7 / -2)); +} + +TEST_CASE(fdiv_large_constants) +{ + EXPECT(lit(1000000) / lit(1000) == lit(1000)); + EXPECT(lit(999999) / lit(1000) == lit(999)); +} + +// =================================================================== +// Tier 2: Equality and hashing +// =================================================================== + +TEST_CASE(eq_different_values) +{ + auto H = var("H"), W = var("W"); + EXPECT(H + 1 != H + 2); + EXPECT(H != W); + EXPECT(lit(3) != lit(4)); +} + +TEST_CASE(eq_empty) +{ + EXPECT(se{} == se{}); + EXPECT(se{} != lit(0)); + EXPECT(lit(0) != se{}); +} + +// =================================================================== +// Tier 3: Evaluation and substitution +// =================================================================== + +TEST_CASE(eval_simple) +{ + auto H = var("H"); + EXPECT(H.eval({{H, 26}}) == 26); + EXPECT(lit(42).eval({}) == 42); +} + +TEST_CASE(eval_arithmetic) +{ + auto H = var("H"); + EXPECT((H - 3).eval({{H, 26}}) == 23); + EXPECT((H + 5).eval({{H, 10}}) == 15); + EXPECT((2 * H).eval({{H, 13}}) == 26); +} + +TEST_CASE(eval_compound) +{ + auto H = var("H"); + auto e = (H - 3) / 2 + 1; + EXPECT(e.eval({{H, 26}}) == 12); + EXPECT(e.eval({{H, 27}}) == 13); +} + +TEST_CASE(eval_multiple_symbols) +{ + auto N = var("N"), H = var("H"); + auto e = N * H; + EXPECT(e.eval({{N, 4}, {H, 26}}) == 104); +} + +TEST_CASE(eval_floor_division) +{ + auto H = var("H"); + auto e = (H - 1) / 2; + EXPECT(e.eval({{H, 7}}) == 3); + EXPECT(e.eval({{H, 8}}) == 3); + EXPECT(e.eval({{H, 9}}) == 4); +} + +TEST_CASE(eval_unbound_throws) +{ + auto H = var("H"), W = var("W"); + EXPECT(test::throws([&] { H.eval({}); })); + EXPECT(test::throws([&] { (H + W).eval({{H, 1}}); })); +} + +TEST_CASE(eval_integer_expr) +{ + EXPECT(lit(0).eval({}) == 0); + EXPECT(lit(100).eval({}) == 100); +} + +TEST_CASE(subs_partial) +{ + auto N = var("N"), H = var("H"); + auto e = N * H + 1; + auto r = e.subs({{N, lit(4)}}); + EXPECT(r == 4 * H + 1); + EXPECT(r.eval({{H, 10}}) == 41); +} + +TEST_CASE(subs_full) +{ + auto H = var("H"); + auto e = H + 1; + auto r = e.subs({{H, lit(5)}}); + EXPECT(r == lit(6)); +} + +TEST_CASE(subs_none) +{ + auto H = var("H"); + EXPECT(H.subs({}) == H); +} + +TEST_CASE(subs_floor_div) +{ + auto H = var("H"); + auto e = (H - 1) / 2; + auto r = e.subs({{H, lit(7)}}); + EXPECT(r == lit(3)); +} + +// eval() and subs()+eval() must agree on a compound expression +TEST_CASE(subs_eval_cross_validation) +{ + auto N = var("N"), H = var("H"); + auto e = (N * H - 3) / 2 + 1; + std::unordered_map eval_map = {{N, 4}, {H, 26}}; + std::unordered_map subs_map = {{N, lit(4)}, {H, lit(26)}}; + auto via_eval = e.eval(eval_map); + auto via_subs = e.subs(subs_map).eval({}); + EXPECT(via_eval == via_subs); +} + +TEST_CASE(subs_empty) +{ + se e; + auto r = e.subs({{var("H"), lit(5)}}); + EXPECT(r.empty()); +} + +TEST_CASE(subs_creates_like_terms) +{ + auto H = var("H"), W = var("W"); + auto e = H + W; + auto r = e.subs({{W, lit(0)}}); + EXPECT(r == H); +} + +TEST_CASE(subs_with_expression) +{ + auto H = var("H"), W = var("W"); + auto e = 2 * H + 1; + auto r = e.subs({{H, W + 3}}); + EXPECT(r == 2 * W + 7); +} + +TEST_CASE(subs_symbol_for_symbol) +{ + auto H = var("H"), W = var("W"); + auto e = H * H + 1; + auto r = e.subs({{H, W}}); + EXPECT(r == W * W + 1); +} + +TEST_CASE(subs_compound_expression) +{ + auto N = var("N"), H = var("H"), W = var("W"); + auto e = (N * H + W - 3) / 2; + auto r = e.subs({{H, 2 * W + 1}, {N, W - 1}}); + // N*H => (W-1)*(2*W+1) = 2*W^2 - W - 1 + // N*H + W - 3 => 2*W^2 - 2*W - 4 + W = 2*W^2 - 2 + // Verify by evaluating with W=5: (W-1)*(2*W+1) + W - 3 = 4*11 + 5 - 3 = 46, 46/2 = 23 + EXPECT(r.eval({{W, 5}}) == 23); + // Also verify the original expression with direct values agrees + EXPECT(e.eval({{N, 4}, {H, 11}, {W, 5}}) == 23); +} + +TEST_CASE(eval_compound_product) +{ + auto H = var("H"), W = var("W"); + auto e = H * W + 1; + EXPECT(e.eval({{H, 3}, {W, 4}}) == 13); +} + +TEST_CASE(eval_negative_intermediate) +{ + auto H = var("H"); + auto e = (H - 10) * 2 + 20; + EXPECT(e.eval({{H, 3}}) == 6); +} + +// =================================================================== +// Tier 4: Printing and parsing +// =================================================================== + +TEST_CASE(print_atoms) +{ + EXPECT(lit(42).to_string() == "42"); + EXPECT(var("H").to_string() == "H"); + EXPECT(lit(0).to_string() == "0"); + EXPECT(lit(-3).to_string() == "-3"); +} + +TEST_CASE(print_add) +{ + auto H = var("H"); + EXPECT((H + 1).to_string() == "H + 1"); + EXPECT((H - 3).to_string() == "H - 3"); +} + +TEST_CASE(print_mul) +{ + EXPECT((2 * var("H")).to_string() == "2*H"); + auto r = var("A") * var("B"); + EXPECT(r.to_string() == "A*B"); +} + +TEST_CASE(print_fdiv_parens) +{ + auto r = (var("H") - 1) / 2; + EXPECT(r.to_string() == "(H - 1)/2"); +} + +TEST_CASE(print_compound) +{ + auto r = (var("H") - 3) / 2 + 1; + EXPECT(r.to_string() == "(H - 3)/2 + 1"); +} + +TEST_CASE(parse_atoms) +{ + EXPECT(parse("42") == lit(42)); + EXPECT(parse("H") == var("H")); +} + +TEST_CASE(parse_arithmetic) +{ + auto H = var("H"); + auto r = parse("H + 1"); + EXPECT(r == H + 1); + + auto r2 = parse("H - 3"); + EXPECT(r2 == H - 3); + + auto r3 = parse("2*H"); + EXPECT(r3 == 2 * H); +} + +TEST_CASE(parse_precedence) +{ + auto r = parse("H + 1 * 2"); + EXPECT(r == var("H") + 2); +} + +TEST_CASE(parse_parentheses) +{ + auto r = parse("(H + 1) * 2"); + EXPECT(r == 2 * (var("H") + 1)); +} + +TEST_CASE(parse_division) +{ + auto r = parse("(H - 1)/2"); + EXPECT(r == (var("H") - 1) / 2); +} + +TEST_CASE(parse_unary_minus) +{ + auto H = var("H"); + EXPECT(parse("-H") == -1 * H); + EXPECT(parse("-H").to_string() == "-H"); + EXPECT(parse("-(H + 1)") == -1 * H - 1); +} + +// Legacy floor() wrapper is accepted by parser and treated as no-op +TEST_CASE(parse_floor_backward_compat) +{ + auto a = parse("floor((H-1)/2)"); + auto b = parse("(H-1)/2"); + EXPECT(a == b); + + auto c = parse("floor((H-1)/2) + 1"); + auto d = (var("H") - 1) / 2 + 1; + EXPECT(c == d); +} + +TEST_CASE(parse_whitespace_tolerance) +{ + EXPECT(parse(" H + 1 ") == parse("H + 1")); + EXPECT(parse("H+1") == parse("H + 1")); +} + +TEST_CASE(print_negative_mul_coefficient) +{ + auto r = 0 - 3 * var("H"); + EXPECT(r.to_string() == "-3*H"); +} + +TEST_CASE(print_multi_symbol_product) +{ + auto r = var("H") * var("W"); + auto s = r.to_string(); + EXPECT(s == "H*W" or s == "W*H"); + EXPECT(parse("H*W") == parse("W*H")); +} + +TEST_CASE(print_compound_expression) +{ + auto r = 2 * (var("H") * var("W")) + var("C") - 1; + auto s = r.to_string(); + EXPECT(parse(s) == r); +} + +TEST_CASE(parse_compound_mul) +{ + auto r = parse("2*H*W"); + EXPECT(r == 2 * var("H") * var("W")); +} + +TEST_CASE(print_parse_round_trip) +{ + auto H = var("H"), N = var("N"), C = var("C"), W = var("W"); + std::vector exprs = { + H, + H + 1, + 2 * H - 3, + (H - 3) / 2 + 1, + N * C * H * W, + (H - 1) / 2, + }; + for(const auto& e : exprs) + { + auto s = e.to_string(); + auto reparsed = parse(s); + EXPECT(reparsed == e); + } +} + +// =================================================================== +// Tier 6: Edge cases and robustness +// =================================================================== + +// 5 levels of (e-1)/2: simulates repeated pooling/conv stride reduction +TEST_CASE(edge_deeply_nested) +{ + auto H = var("H"); + se e = H; + for(int i = 0; i < 5; ++i) + e = (e - 1) / 2; + EXPECT(e.eval({{H, 255}}) == 7); +} + +TEST_CASE(edge_many_symbols) +{ + auto A = var("A"), B = var("B"), C = var("C"), D = var("D"), E = var("E"); + auto e = A + B + C + D + E; + EXPECT(e.eval({{A, 1}, {B, 2}, {C, 3}, {D, 4}, {E, 5}}) == 15); +} + +TEST_CASE(edge_neg_one_coefficient) +{ + auto H = var("H"); + EXPECT(-1 * H == lit(0) - H); + EXPECT(-1 * H + H == lit(0)); +} + +TEST_CASE(edge_empty_operations) +{ + se empty; + EXPECT((empty + empty).empty()); + EXPECT((empty - empty).empty()); + EXPECT((empty * empty).empty()); + EXPECT((empty / empty).empty()); +} + +TEST_CASE(edge_empty_with_nonempty) +{ + se empty; + auto H = var("H"); + auto r1 = H + empty; + EXPECT(not r1.empty()); + + auto r2 = empty + H; + EXPECT(not r2.empty()); +} + +TEST_CASE(edge_large_coefficients) +{ + auto H = var("H"); + auto r = 1000000 * H; + EXPECT(r.eval({{H, 1000000}}) == 1000000000000ULL); +} + +// Incrementally adding H ten times must fold to 11*H +TEST_CASE(edge_chained_operations) +{ + auto H = var("H"); + auto e = H; + for(int i = 0; i < 10; ++i) + e = e + H; + EXPECT(e == 11 * H); +} + +TEST_CASE(edge_repeated_parse) +{ + auto H = var("H"); + for(int i = 0; i < 10; ++i) + { + auto r = parse("(H - 3)/2 + 1"); + EXPECT(r == (H - 3) / 2 + 1); + } +} + +// =================================================================== +// Serialization round-trip +// =================================================================== + +static se round_trip(const se& e) +{ + auto v = migraphx::to_value(e); + return migraphx::from_value(v); +} + +TEST_CASE(serialize_empty) +{ + se e; + EXPECT(round_trip(e).empty()); +} + +TEST_CASE(serialize_integer) +{ + EXPECT(round_trip(lit(0)) == lit(0)); + EXPECT(round_trip(lit(42)) == lit(42)); +} + +TEST_CASE(serialize_symbol) +{ + auto H = var("H"); + EXPECT(round_trip(H) == H); +} + +TEST_CASE(serialize_add) +{ + auto H = var("H"); + auto e = 2 * H + 3; + EXPECT(round_trip(e) == e); +} + +TEST_CASE(serialize_mul) +{ + auto H = var("H"), W = var("W"); + auto e = H * W; + EXPECT(round_trip(e) == e); +} + +TEST_CASE(serialize_fdiv) +{ + auto H = var("H"); + auto e = (H - 1) / 2; + EXPECT(round_trip(e) == e); +} + +TEST_CASE(serialize_compound) +{ + auto N = var("N"), H = var("H"), W = var("W"); + auto e = (N * H * W + 3) / 2 - 1; + EXPECT(round_trip(e) == e); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/symbolic_test.cpp b/test/symbolic_test.cpp deleted file mode 100644 index 27bd3d3fccf..00000000000 --- a/test/symbolic_test.cpp +++ /dev/null @@ -1,625 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ - -#include -#include "test.hpp" - -using se = migraphx::symbolic_expr; - -// =================================================================== -// Tier 1: Expression construction and canonicalization -// =================================================================== - -TEST_CASE(construct_integer) -{ - EXPECT(se(0).to_string() == "0"); - EXPECT(se(1).to_string() == "1"); - EXPECT(se(42).to_string() == "42"); -} - -TEST_CASE(construct_symbol) -{ - EXPECT(se("H").to_string() == "H"); - EXPECT(se("batch_size").to_string() == "batch_size"); -} - -TEST_CASE(construct_empty) -{ - se e; - EXPECT(e.empty()); - EXPECT(e.to_string().empty()); -} - -TEST_CASE(add_identity) -{ - EXPECT(se("H") + 0 == se("H")); - EXPECT(0 + se("H") == se("H")); -} - -TEST_CASE(add_commutativity) { EXPECT(se("H") + se("W") == se("W") + se("H")); } - -TEST_CASE(add_like_term_folding) -{ - auto r = se("H") + se("H"); - EXPECT(r.to_string() == "2*H"); -} - -TEST_CASE(add_constant_folding) -{ - EXPECT(se(3) + se(5) == se(8)); - EXPECT(se(0) + se(0) == se(0)); -} - -TEST_CASE(add_flattening) -{ - auto a = (se("H") + se("W")) + se("C"); - auto b = se("H") + (se("W") + se("C")); - auto c = (se("C") + se("H")) + se("W"); - EXPECT(a == b); - EXPECT(b == c); -} - -TEST_CASE(add_mixed) -{ - auto r = se("H") + 3 + se("H") + 2; - EXPECT(r == 2 * se("H") + 5); - auto r2 = se("H") + se("H"); - EXPECT(r2 + 5 == 2 * se("H") + 5); -} - -TEST_CASE(add_cancellation) -{ - auto neg_h = se("-H"); - EXPECT(se("H") + neg_h == se(0)); -} - -TEST_CASE(sub_identity) { EXPECT(se("H") - 0 == se("H")); } - -TEST_CASE(sub_self) { EXPECT(se("H") - se("H") == se(0)); } - -TEST_CASE(sub_constant_folding) { EXPECT(se(10) - se(3) == se(7)); } - -TEST_CASE(sub_produces_negation) -{ - EXPECT(se("-H").to_string() == "-H"); - EXPECT((3 - se("H")).to_string() == "-H + 3"); -} - -TEST_CASE(neg_integer) -{ - auto r = se(0) - 5; - EXPECT(r == se(-5)); -} - -TEST_CASE(neg_double_negation) -{ - auto r = 0 - se("-H"); - EXPECT(r == se("H")); -} - -TEST_CASE(mul_identity) -{ - EXPECT(se("H") * 1 == se("H")); - EXPECT(1 * se("H") == se("H")); -} - -TEST_CASE(mul_zero) -{ - EXPECT(se("H") * 0 == se(0)); - EXPECT(0 * se("H") == se(0)); -} - -TEST_CASE(mul_constant_folding) { EXPECT(se(3) * se(7) == se(21)); } - -TEST_CASE(mul_commutativity) { EXPECT(se("B") * se("A") == se("A") * se("B")); } - -TEST_CASE(mul_coefficient_accumulation) -{ - auto r = 2 * se("H") * 3; - EXPECT(r.to_string() == "6*H"); -} - -TEST_CASE(mul_flattening) -{ - auto a = (se("H") * se("W")) * se("C"); - auto b = se("H") * (se("W") * se("C")); - auto c = (se("C") * se("H")) * se("W"); - EXPECT(a == b); - EXPECT(b == c); -} - -TEST_CASE(mul_distributive) -{ - auto r = 2 * (se("H") + 1); - EXPECT(r == 2 * se("H") + 2); -} - -TEST_CASE(mul_symbolic_times_add_no_distribution) -{ - auto r = se("N") * (se("H") + 1); - EXPECT(r != se("N") * se("H") + se("N")); -} - -TEST_CASE(fdiv_identity) { EXPECT(se("H") / 1 == se("H")); } - -TEST_CASE(fdiv_constant_folding) -{ - EXPECT(se(7) / se(2) == se(3)); - EXPECT(se(6) / se(3) == se(2)); - EXPECT(se(0) / se(5) == se(0)); -} - -TEST_CASE(fdiv_exact_coefficient_cancel) -{ - auto r = (6 * se("N")) / 3; - EXPECT(r.to_string() == "2*N"); -} - -TEST_CASE(fdiv_non_simplifiable) -{ - auto r = (se("H") - 1) / 2; - EXPECT(r.to_string() == "(H - 1)/2"); -} - -TEST_CASE(fdiv_division_by_zero) -{ - EXPECT(test::throws([&] { se("H") / 0; })); -} - -TEST_CASE(add_scaled_subtraction) -{ - EXPECT(2 * se("H") - se("H") == se("H")); - EXPECT(3 * se("H") - 2 * se("H") == se("H")); - EXPECT(se("H") + se("H") + se("H") == 3 * se("H")); -} - -TEST_CASE(add_of_two_adds) -{ - auto r = (se("H") + 1) + (se("H") + 2); - EXPECT(r == 2 * se("H") + 3); -} - -TEST_CASE(sub_strip_constant) { EXPECT((se("H") + 1) - se("H") == se(1)); } - -TEST_CASE(sub_of_two_adds) -{ - auto r = (se("H") + 1) - (se("H") + 2); - EXPECT(r == se(-1)); -} - -TEST_CASE(mul_zero_propagation) -{ - auto z = se("H") - se("H"); - EXPECT(50 * z == se(0)); -} - -TEST_CASE(add_chain_constant_cancel) -{ - auto r = se(2) - se("H") - se(2); - EXPECT(r == se("-H")); -} - -TEST_CASE(neg_of_sum_distributes) -{ - auto r = se(-1) * (se("H") + 1); - EXPECT(r == se("-H") - 1); -} - -TEST_CASE(neg_of_product_double) -{ - auto hw = se("H") * se("W"); - auto neg = 0 - hw; - EXPECT(neg.to_string() == "-H*W"); - auto dbl = 0 - neg; - EXPECT(dbl == hw); -} - -TEST_CASE(add_compound_product_like_terms) -{ - auto hw = se("H") * se("W"); - auto wh = se("W") * se("H"); - EXPECT(hw + hw == 2 * hw); - EXPECT(hw + wh == 2 * hw); - EXPECT(hw + 2 * hw == 3 * hw); -} - -TEST_CASE(add_compound_product_cancellation) -{ - auto hw = se("H") * se("W"); - EXPECT(hw - hw == se(0)); -} - -// X*Y and X cancel pairwise: (X*Y - X) - (X*Y - X) == 0 -TEST_CASE(sub_compound_product_mixed) -{ - auto xy = se("X") * se("Y"); - auto r = xy - se("X") - xy + se("X"); - EXPECT(r == se(0)); -} - -// Duplicate A*B terms fold even when separated by another term -TEST_CASE(add_multi_term_accumulation) -{ - auto r = se("A") * se("B") + se("C") + se("A") * se("B"); - auto expected = 2 * (se("A") * se("B")) + se("C"); - EXPECT(r == expected); -} - -TEST_CASE(fdiv_negative_constant_folding) -{ - EXPECT(se(-7) / se(2) == se(-7 / 2)); - EXPECT(se(-6) / se(3) == se(-2)); - EXPECT(se(7) / se(-2) == se(7 / -2)); -} - -TEST_CASE(fdiv_large_constants) -{ - EXPECT(se(1000000) / se(1000) == se(1000)); - EXPECT(se(999999) / se(1000) == se(999)); -} - -// =================================================================== -// Tier 2: Equality and hashing -// =================================================================== - -TEST_CASE(eq_different_values) -{ - EXPECT(se("H") + 1 != se("H") + 2); - EXPECT(se("H") != se("W")); - EXPECT(se(3) != se(4)); -} - -TEST_CASE(eq_empty) -{ - EXPECT(se{} == se{}); - EXPECT(se{} != se(0)); - EXPECT(se(0) != se{}); -} - -// =================================================================== -// Tier 3: Evaluation and substitution -// =================================================================== - -TEST_CASE(eval_simple) -{ - EXPECT(se("H").eval({{"H", 26}}) == 26); - EXPECT(se(42).eval({}) == 42); -} - -TEST_CASE(eval_arithmetic) -{ - EXPECT((se("H") - 3).eval({{"H", 26}}) == 23); - EXPECT((se("H") + 5).eval({{"H", 10}}) == 15); - EXPECT((2 * se("H")).eval({{"H", 13}}) == 26); -} - -TEST_CASE(eval_compound) -{ - auto e = (se("H") - 3) / 2 + 1; - EXPECT(e.eval({{"H", 26}}) == 12); - EXPECT(e.eval({{"H", 27}}) == 13); -} - -TEST_CASE(eval_multiple_symbols) -{ - auto e = se("N") * se("H"); - EXPECT(e.eval({{"N", 4}, {"H", 26}}) == 104); -} - -TEST_CASE(eval_floor_division) -{ - auto e = (se("H") - 1) / 2; - EXPECT(e.eval({{"H", 7}}) == 3); - EXPECT(e.eval({{"H", 8}}) == 3); - EXPECT(e.eval({{"H", 9}}) == 4); -} - -TEST_CASE(eval_unbound_throws) -{ - EXPECT(test::throws([&] { se("H").eval({}); })); - EXPECT(test::throws([&] { (se("H") + se("W")).eval({{"H", 1}}); })); -} - -TEST_CASE(eval_integer_expr) -{ - EXPECT(se(0).eval({}) == 0); - EXPECT(se(100).eval({}) == 100); -} - -TEST_CASE(subs_partial) -{ - auto e = se("N") * se("H") + 1; - auto r = e.subs({{"N", 4}}); - EXPECT(r == 4 * se("H") + 1); - EXPECT(r.eval({{"H", 10}}) == 41); -} - -TEST_CASE(subs_full) -{ - auto e = se("H") + 1; - auto r = e.subs({{"H", 5}}); - EXPECT(r == se(6)); -} - -TEST_CASE(subs_none) -{ - auto e = se("H"); - EXPECT(e.subs({}) == se("H")); -} - -TEST_CASE(subs_floor_div) -{ - auto e = (se("H") - 1) / 2; - auto r = e.subs({{"H", 7}}); - EXPECT(r == se(3)); -} - -// eval() and subs()+eval() must agree on a compound expression -TEST_CASE(subs_eval_cross_validation) -{ - auto e = (se("N") * se("H") - 3) / 2 + 1; - std::map m = {{"N", 4}, {"H", 26}}; - auto via_eval = e.eval(m); - auto via_subs = e.subs(m).eval({}); - EXPECT(via_eval == via_subs); -} - -TEST_CASE(subs_empty) -{ - se e; - auto r = e.subs({{"H", 5}}); - EXPECT(r.empty()); -} - -TEST_CASE(subs_creates_like_terms) -{ - auto e = se("H") + se("W"); - auto r = e.subs({{"W", 0}}); - EXPECT(r == se("H")); -} - -TEST_CASE(eval_compound_product) -{ - auto e = se("H") * se("W") + 1; - EXPECT(e.eval({{"H", 3}, {"W", 4}}) == 13); -} - -TEST_CASE(eval_negative_intermediate) -{ - auto e = (se("H") - 10) * 2 + 20; - EXPECT(e.eval({{"H", 3}}) == 6); -} - -// =================================================================== -// Tier 4: Printing and parsing -// =================================================================== - -TEST_CASE(print_atoms) -{ - EXPECT(se(42).to_string() == "42"); - EXPECT(se("H").to_string() == "H"); - EXPECT(se(0).to_string() == "0"); - EXPECT(se(-3).to_string() == "-3"); -} - -TEST_CASE(print_add) -{ - EXPECT((se("H") + 1).to_string() == "H + 1"); - EXPECT((se("H") - 3).to_string() == "H - 3"); -} - -TEST_CASE(print_mul) -{ - EXPECT((2 * se("H")).to_string() == "2*H"); - auto r = se("A") * se("B"); - EXPECT(r.to_string() == "A*B"); -} - -TEST_CASE(print_fdiv_parens) -{ - auto r = (se("H") - 1) / 2; - EXPECT(r.to_string() == "(H - 1)/2"); -} - -TEST_CASE(print_compound) -{ - auto r = (se("H") - 3) / 2 + 1; - EXPECT(r.to_string() == "(H - 3)/2 + 1"); -} - -TEST_CASE(parse_atoms) -{ - EXPECT(se("42") == se(42)); - EXPECT(se("H") == se("H")); -} - -TEST_CASE(parse_arithmetic) -{ - auto r = se("H + 1"); - EXPECT(r == se("H") + 1); - - auto r2 = se("H - 3"); - EXPECT(r2 == se("H") - 3); - - auto r3 = se("2*H"); - EXPECT(r3 == 2 * se("H")); -} - -TEST_CASE(parse_precedence) -{ - auto r = se("H + 1 * 2"); - EXPECT(r == se("H") + 2); -} - -TEST_CASE(parse_parentheses) -{ - auto r = se("(H + 1) * 2"); - EXPECT(r == 2 * (se("H") + 1)); -} - -TEST_CASE(parse_division) -{ - auto r = se("(H - 1)/2"); - EXPECT(r == (se("H") - 1) / 2); -} - -TEST_CASE(parse_unary_minus) -{ - EXPECT(se("-H") == se("-H")); - EXPECT(se("-H").to_string() == "-H"); - EXPECT(se("-(H + 1)") == se("-H") - 1); -} - -// Legacy floor() wrapper is accepted by parser and treated as no-op -TEST_CASE(parse_floor_backward_compat) -{ - auto a = se("floor((H-1)/2)"); - auto b = se("(H-1)/2"); - EXPECT(a == b); - - auto c = se("floor((H-1)/2) + 1"); - auto d = (se("H") - 1) / 2 + 1; - EXPECT(c == d); -} - -TEST_CASE(parse_whitespace_tolerance) -{ - EXPECT(se(" H + 1 ") == se("H + 1")); - EXPECT(se("H+1") == se("H + 1")); -} - -TEST_CASE(print_negative_mul_coefficient) -{ - auto r = 0 - 3 * se("H"); - EXPECT(r.to_string() == "-3*H"); -} - -TEST_CASE(print_multi_symbol_product) -{ - auto r = se("H") * se("W"); - auto s = r.to_string(); - EXPECT(s == "H*W" or s == "W*H"); - EXPECT(se("H*W") == se("W*H")); -} - -TEST_CASE(print_compound_expression) -{ - auto r = 2 * (se("H") * se("W")) + se("C") - 1; - auto s = r.to_string(); - EXPECT(se(s) == r); -} - -TEST_CASE(parse_compound_mul) -{ - auto r = se("2*H*W"); - EXPECT(r == 2 * se("H") * se("W")); -} - -TEST_CASE(print_parse_round_trip) -{ - std::vector exprs = { - se("H"), - se("H") + 1, - 2 * se("H") - 3, - (se("H") - 3) / 2 + 1, - se("N") * se("C") * se("H") * se("W"), - (se("H") - 1) / 2, - }; - for(const auto& e : exprs) - { - auto s = e.to_string(); - auto reparsed = se(s); - EXPECT(reparsed == e); - } -} - -// =================================================================== -// Tier 6: Edge cases and robustness -// =================================================================== - -// 5 levels of (e-1)/2: simulates repeated pooling/conv stride reduction -TEST_CASE(edge_deeply_nested) -{ - auto e = se("H"); - for(int i = 0; i < 5; ++i) - e = (e - 1) / 2; - EXPECT(e.eval({{"H", 255}}) == 7); -} - -TEST_CASE(edge_many_symbols) -{ - auto e = se("A") + se("B") + se("C") + se("D") + se("E"); - EXPECT(e.eval({{"A", 1}, {"B", 2}, {"C", 3}, {"D", 4}, {"E", 5}}) == 15); -} - -TEST_CASE(edge_neg_one_coefficient) -{ - EXPECT(se("-H").to_string() == "-H"); - EXPECT(se("-H") + se("H") == se(0)); -} - -TEST_CASE(edge_empty_operations) -{ - se empty; - EXPECT((empty + empty).empty()); - EXPECT((empty - empty).empty()); - EXPECT((empty * empty).empty()); - EXPECT((empty / empty).empty()); -} - -TEST_CASE(edge_empty_with_nonempty) -{ - se empty; - auto r1 = se("H") + empty; - EXPECT(not r1.empty()); - - auto r2 = empty + se("H"); - EXPECT(not r2.empty()); -} - -TEST_CASE(edge_large_coefficients) -{ - auto r = 1000000 * se("H"); - EXPECT(r.eval({{"H", 1000000}}) == 1000000000000ULL); -} - -// Incrementally adding H ten times must fold to 11*H -TEST_CASE(edge_chained_operations) -{ - auto e = se("H"); - for(int i = 0; i < 10; ++i) - e = e + se("H"); - EXPECT(e == 11 * se("H")); -} - -TEST_CASE(edge_repeated_parse) -{ - for(int i = 0; i < 10; ++i) - { - auto r = se("(H - 3)/2 + 1"); - EXPECT(r == (se("H") - 3) / 2 + 1); - } -} - -int main(int argc, const char* argv[]) { test::run(argc, argv); } From 964f9341f623798260efc1188d4968738daa1fc2 Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 25 Mar 2026 19:44:34 -0700 Subject: [PATCH 10/60] use int64 for literals --- src/include/migraphx/sym.hpp | 21 +++++++++++---------- src/sym.cpp | 4 ++-- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/include/migraphx/sym.hpp b/src/include/migraphx/sym.hpp index 99b86093f92..91bbe85957b 100644 --- a/src/include/migraphx/sym.hpp +++ b/src/include/migraphx/sym.hpp @@ -25,6 +25,7 @@ #define MIGRAPHX_GUARD_MIGRAPHLIB_SYM_HPP #include +#include #include #include #include @@ -62,7 +63,7 @@ struct MIGRAPHX_EXPORT expr struct impl; friend expr var(const std::string& name); - friend expr lit(std::size_t n); + friend expr lit(int64_t n); friend expr parse(const std::string& s); private: @@ -71,17 +72,17 @@ struct MIGRAPHX_EXPORT expr }; MIGRAPHX_EXPORT expr var(const std::string& name); -MIGRAPHX_EXPORT expr lit(std::size_t n); +MIGRAPHX_EXPORT expr lit(int64_t n); MIGRAPHX_EXPORT expr parse(const std::string& s); -inline expr operator+(const expr& a, std::size_t b) { return a + lit(b); } -inline expr operator+(std::size_t a, const expr& b) { return lit(a) + b; } -inline expr operator-(const expr& a, std::size_t b) { return a - lit(b); } -inline expr operator-(std::size_t a, const expr& b) { return lit(a) - b; } -inline expr operator*(const expr& a, std::size_t b) { return a * lit(b); } -inline expr operator*(std::size_t a, const expr& b) { return lit(a) * b; } -inline expr operator/(const expr& a, std::size_t b) { return a / lit(b); } -inline expr operator/(std::size_t a, const expr& b) { return lit(a) / b; } +inline expr operator+(const expr& a, int64_t b) { return a + lit(b); } +inline expr operator+(int64_t a, const expr& b) { return lit(a) + b; } +inline expr operator-(const expr& a, int64_t b) { return a - lit(b); } +inline expr operator-(int64_t a, const expr& b) { return lit(a) - b; } +inline expr operator*(const expr& a, int64_t b) { return a * lit(b); } +inline expr operator*(int64_t a, const expr& b) { return lit(a) * b; } +inline expr operator/(const expr& a, int64_t b) { return a / lit(b); } +inline expr operator/(int64_t a, const expr& b) { return lit(a) / b; } } // namespace sym diff --git a/src/sym.cpp b/src/sym.cpp index d8d035db477..cb0260b46c7 100644 --- a/src/sym.cpp +++ b/src/sym.cpp @@ -868,9 +868,9 @@ std::ostream& operator<<(std::ostream& os, const expr& e) expr var(const std::string& name) { return {std::make_shared(make_symbol(name))}; } -expr lit(std::size_t n) +expr lit(int64_t n) { - return {std::make_shared(make_integer(static_cast(n)))}; + return {std::make_shared(make_integer(n))}; } expr parse(const std::string& s) From 364bd231122002e8750240bb476a32a24d31af9b Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 25 Mar 2026 21:15:40 -0700 Subject: [PATCH 11/60] fix merge --- src/shape.cpp | 53 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/src/shape.cpp b/src/shape.cpp index 295e5a32393..82bb121e867 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -933,13 +933,23 @@ shape::dynamic_dimension operator/(const shape::dynamic_dimension& x, const std: return dd /= y; } +static optional get_sym(const shape::dynamic_dimension& dd) +{ + if(dd.sym_expr) + return dd.sym_expr; + if(dd.is_fixed()) + return sym::lit(dd.min); + return nullopt; +} + shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const shape::dynamic_dimension& x) { - auto e = sym_expr.value_or(sym::lit(min)) + x.sym_expr.value_or(sym::lit(x.min)); - min = min + x.min; - max = (max > std::numeric_limits::max() - x.max) - ? std::numeric_limits::max() - : max + x.max; + auto lhs_sym = get_sym(*this); + auto rhs_sym = get_sym(x); + min = min + x.min; + max = (max > std::numeric_limits::max() - x.max) + ? std::numeric_limits::max() + : max + x.max; if(x.is_fixed()) { std::set new_optimals; @@ -953,15 +963,16 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const shape::dyna { optimals.clear(); } - sym_expr = (sym_expr or x.sym_expr) ? optional(e) : nullopt; + sym_expr = (lhs_sym and rhs_sym) ? optional(*lhs_sym + *rhs_sym) : nullopt; return *this; } shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const shape::dynamic_dimension& x) { - auto e = sym_expr.value_or(sym::lit(min)) - x.sym_expr.value_or(sym::lit(x.min)); - min = (min > x.max) ? min - x.max : 0; - max = (max > x.min) ? max - x.min : 0; + auto lhs_sym = get_sym(*this); + auto rhs_sym = get_sym(x); + min = (min > x.max) ? min - x.max : 0; + max = (max > x.min) ? max - x.min : 0; if(x.is_fixed()) { std::set new_optimals; @@ -975,17 +986,18 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const shape::dyna { optimals.clear(); } - sym_expr = (sym_expr or x.sym_expr) ? optional(e) : nullopt; + sym_expr = (lhs_sym and rhs_sym) ? optional(*lhs_sym - *rhs_sym) : nullopt; return *this; } shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const shape::dynamic_dimension& x) { - auto e = sym_expr.value_or(sym::lit(min)) * x.sym_expr.value_or(sym::lit(x.min)); - min = min * x.min; - max = (max > std::numeric_limits::max() / (x.max == 0 ? 1 : x.max)) - ? std::numeric_limits::max() - : max * x.max; + auto lhs_sym = get_sym(*this); + auto rhs_sym = get_sym(x); + min = min * x.min; + max = (max > std::numeric_limits::max() / (x.max == 0 ? 1 : x.max)) + ? std::numeric_limits::max() + : max * x.max; if(x.is_fixed()) { std::set new_optimals; @@ -999,15 +1011,16 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const shape::dyna { optimals.clear(); } - sym_expr = (sym_expr or x.sym_expr) ? optional(e) : nullopt; + sym_expr = (lhs_sym and rhs_sym) ? optional(*lhs_sym * *rhs_sym) : nullopt; return *this; } shape::dynamic_dimension& shape::dynamic_dimension::operator/=(const shape::dynamic_dimension& x) { - auto e = sym_expr.value_or(sym::lit(min)) / x.sym_expr.value_or(sym::lit(x.min)); - min = (x.max == 0) ? 0 : min / x.max; - max = (x.min == 0) ? std::numeric_limits::max() : max / x.min; + auto lhs_sym = get_sym(*this); + auto rhs_sym = get_sym(x); + min = (x.max == 0) ? 0 : min / x.max; + max = (x.min == 0) ? std::numeric_limits::max() : max / x.min; if(x.is_fixed()) { std::set new_optimals; @@ -1021,7 +1034,7 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator/=(const shape::dyna { optimals.clear(); } - sym_expr = (sym_expr or x.sym_expr) ? optional(e) : nullopt; + sym_expr = (lhs_sym and rhs_sym) ? optional(*lhs_sym / *rhs_sym) : nullopt; return *this; } From 33614e043562ca063299a4885bc26aeb3b2bc06f Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 25 Mar 2026 21:27:41 -0700 Subject: [PATCH 12/60] change eval func name --- src/include/migraphx/sym.hpp | 2 +- src/sym.cpp | 13 ++++------ test/sym_test.cpp | 50 ++++++++++++++++++------------------ 3 files changed, 31 insertions(+), 34 deletions(-) diff --git a/src/include/migraphx/sym.hpp b/src/include/migraphx/sym.hpp index 91bbe85957b..eaf677e2270 100644 --- a/src/include/migraphx/sym.hpp +++ b/src/include/migraphx/sym.hpp @@ -49,7 +49,7 @@ struct MIGRAPHX_EXPORT expr std::string to_string() const; value to_value() const; void from_value(const value& v); - std::size_t eval(const std::unordered_map& symbol_map) const; + std::size_t eval_dim(const std::unordered_map& symbol_map) const; expr subs(const std::unordered_map& symbol_map) const; MIGRAPHX_EXPORT friend expr operator+(const expr& a, const expr& b); diff --git a/src/sym.cpp b/src/sym.cpp index cb0260b46c7..14d3df42c9e 100644 --- a/src/sym.cpp +++ b/src/sym.cpp @@ -496,8 +496,8 @@ static int64_t eval_direct(const expr_ptr& e, const binding_map& bindings) auto it = bindings.find(e); if(it != bindings.end()) return static_cast(it->second); - MIGRAPHX_THROW("sym::expr::eval: unbound symbol '" + d.name + - "'"); + MIGRAPHX_THROW("sym::expr::eval_dim: unbound symbol '" + + d.name + "'"); }, [&](const add_data& d) -> int64_t { int64_t sum = d.constant; @@ -782,7 +782,7 @@ std::string expr::to_string() const return print_expr(p->node); } -std::size_t expr::eval(const std::unordered_map& symbol_map) const +std::size_t expr::eval_dim(const std::unordered_map& symbol_map) const { if(empty()) return 0; @@ -790,7 +790,7 @@ std::size_t expr::eval(const std::unordered_map& symbol_map) for(const auto& [k, v] : symbol_map) { if(k.empty() or not holds(k.p->node)) - MIGRAPHX_THROW("sym::expr::eval: map key '" + k.to_string() + "' is not a symbol"); + MIGRAPHX_THROW("sym::expr::eval_dim: map key '" + k.to_string() + "' is not a symbol"); bindings[k.p->node] = v; } auto v = eval_direct(p->node, bindings); @@ -868,10 +868,7 @@ std::ostream& operator<<(std::ostream& os, const expr& e) expr var(const std::string& name) { return {std::make_shared(make_symbol(name))}; } -expr lit(int64_t n) -{ - return {std::make_shared(make_integer(n))}; -} +expr lit(int64_t n) { return {std::make_shared(make_integer(n))}; } expr parse(const std::string& s) { diff --git a/test/sym_test.cpp b/test/sym_test.cpp index a14010ef754..ebcd31eadc3 100644 --- a/test/sym_test.cpp +++ b/test/sym_test.cpp @@ -349,53 +349,53 @@ TEST_CASE(eq_empty) TEST_CASE(eval_simple) { auto H = var("H"); - EXPECT(H.eval({{H, 26}}) == 26); - EXPECT(lit(42).eval({}) == 42); + EXPECT(H.eval_dim({{H, 26}}) == 26); + EXPECT(lit(42).eval_dim({}) == 42); } TEST_CASE(eval_arithmetic) { auto H = var("H"); - EXPECT((H - 3).eval({{H, 26}}) == 23); - EXPECT((H + 5).eval({{H, 10}}) == 15); - EXPECT((2 * H).eval({{H, 13}}) == 26); + EXPECT((H - 3).eval_dim({{H, 26}}) == 23); + EXPECT((H + 5).eval_dim({{H, 10}}) == 15); + EXPECT((2 * H).eval_dim({{H, 13}}) == 26); } TEST_CASE(eval_compound) { auto H = var("H"); auto e = (H - 3) / 2 + 1; - EXPECT(e.eval({{H, 26}}) == 12); - EXPECT(e.eval({{H, 27}}) == 13); + EXPECT(e.eval_dim({{H, 26}}) == 12); + EXPECT(e.eval_dim({{H, 27}}) == 13); } TEST_CASE(eval_multiple_symbols) { auto N = var("N"), H = var("H"); auto e = N * H; - EXPECT(e.eval({{N, 4}, {H, 26}}) == 104); + EXPECT(e.eval_dim({{N, 4}, {H, 26}}) == 104); } TEST_CASE(eval_floor_division) { auto H = var("H"); auto e = (H - 1) / 2; - EXPECT(e.eval({{H, 7}}) == 3); - EXPECT(e.eval({{H, 8}}) == 3); - EXPECT(e.eval({{H, 9}}) == 4); + EXPECT(e.eval_dim({{H, 7}}) == 3); + EXPECT(e.eval_dim({{H, 8}}) == 3); + EXPECT(e.eval_dim({{H, 9}}) == 4); } TEST_CASE(eval_unbound_throws) { auto H = var("H"), W = var("W"); - EXPECT(test::throws([&] { H.eval({}); })); - EXPECT(test::throws([&] { (H + W).eval({{H, 1}}); })); + EXPECT(test::throws([&] { H.eval_dim({}); })); + EXPECT(test::throws([&] { (H + W).eval_dim({{H, 1}}); })); } TEST_CASE(eval_integer_expr) { - EXPECT(lit(0).eval({}) == 0); - EXPECT(lit(100).eval({}) == 100); + EXPECT(lit(0).eval_dim({}) == 0); + EXPECT(lit(100).eval_dim({}) == 100); } TEST_CASE(subs_partial) @@ -404,7 +404,7 @@ TEST_CASE(subs_partial) auto e = N * H + 1; auto r = e.subs({{N, lit(4)}}); EXPECT(r == 4 * H + 1); - EXPECT(r.eval({{H, 10}}) == 41); + EXPECT(r.eval_dim({{H, 10}}) == 41); } TEST_CASE(subs_full) @@ -436,8 +436,8 @@ TEST_CASE(subs_eval_cross_validation) auto e = (N * H - 3) / 2 + 1; std::unordered_map eval_map = {{N, 4}, {H, 26}}; std::unordered_map subs_map = {{N, lit(4)}, {H, lit(26)}}; - auto via_eval = e.eval(eval_map); - auto via_subs = e.subs(subs_map).eval({}); + auto via_eval = e.eval_dim(eval_map); + auto via_subs = e.subs(subs_map).eval_dim({}); EXPECT(via_eval == via_subs); } @@ -480,23 +480,23 @@ TEST_CASE(subs_compound_expression) // N*H => (W-1)*(2*W+1) = 2*W^2 - W - 1 // N*H + W - 3 => 2*W^2 - 2*W - 4 + W = 2*W^2 - 2 // Verify by evaluating with W=5: (W-1)*(2*W+1) + W - 3 = 4*11 + 5 - 3 = 46, 46/2 = 23 - EXPECT(r.eval({{W, 5}}) == 23); + EXPECT(r.eval_dim({{W, 5}}) == 23); // Also verify the original expression with direct values agrees - EXPECT(e.eval({{N, 4}, {H, 11}, {W, 5}}) == 23); + EXPECT(e.eval_dim({{N, 4}, {H, 11}, {W, 5}}) == 23); } TEST_CASE(eval_compound_product) { auto H = var("H"), W = var("W"); auto e = H * W + 1; - EXPECT(e.eval({{H, 3}, {W, 4}}) == 13); + EXPECT(e.eval_dim({{H, 3}, {W, 4}}) == 13); } TEST_CASE(eval_negative_intermediate) { auto H = var("H"); auto e = (H - 10) * 2 + 20; - EXPECT(e.eval({{H, 3}}) == 6); + EXPECT(e.eval_dim({{H, 3}}) == 6); } // =================================================================== @@ -657,14 +657,14 @@ TEST_CASE(edge_deeply_nested) se e = H; for(int i = 0; i < 5; ++i) e = (e - 1) / 2; - EXPECT(e.eval({{H, 255}}) == 7); + EXPECT(e.eval_dim({{H, 255}}) == 7); } TEST_CASE(edge_many_symbols) { auto A = var("A"), B = var("B"), C = var("C"), D = var("D"), E = var("E"); auto e = A + B + C + D + E; - EXPECT(e.eval({{A, 1}, {B, 2}, {C, 3}, {D, 4}, {E, 5}}) == 15); + EXPECT(e.eval_dim({{A, 1}, {B, 2}, {C, 3}, {D, 4}, {E, 5}}) == 15); } TEST_CASE(edge_neg_one_coefficient) @@ -698,7 +698,7 @@ TEST_CASE(edge_large_coefficients) { auto H = var("H"); auto r = 1000000 * H; - EXPECT(r.eval({{H, 1000000}}) == 1000000000000ULL); + EXPECT(r.eval_dim({{H, 1000000}}) == 1000000000000ULL); } // Incrementally adding H ten times must fold to 11*H From 830594f4531941920aca6e4a763a62fbbab210cc Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 25 Mar 2026 21:57:10 -0700 Subject: [PATCH 13/60] use int64 for internal eval --- src/sym.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sym.cpp b/src/sym.cpp index 14d3df42c9e..cab231a9c5b 100644 --- a/src/sym.cpp +++ b/src/sym.cpp @@ -449,7 +449,7 @@ static expr_ptr make_floor_div(const expr_ptr& a, const expr_ptr& b) // Section 6: Substitution and evaluation // =================================================================== -using binding_map = std::map; +using binding_map = std::map; using subs_map = std::map; static expr_ptr substitute(const expr_ptr& e, const subs_map& bindings) @@ -495,7 +495,7 @@ static int64_t eval_direct(const expr_ptr& e, const binding_map& bindings) [&](const symbol_data& d) -> int64_t { auto it = bindings.find(e); if(it != bindings.end()) - return static_cast(it->second); + return it->second; MIGRAPHX_THROW("sym::expr::eval_dim: unbound symbol '" + d.name + "'"); }, @@ -791,7 +791,7 @@ std::size_t expr::eval_dim(const std::unordered_map& symbol_m { if(k.empty() or not holds(k.p->node)) MIGRAPHX_THROW("sym::expr::eval_dim: map key '" + k.to_string() + "' is not a symbol"); - bindings[k.p->node] = v; + bindings[k.p->node] = static_cast(v); } auto v = eval_direct(p->node, bindings); assert(v >= 0 && "symbolic dimension evaluated to negative value"); From def3038552d76cf4c58df25e433fba7b37b09593 Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 25 Mar 2026 22:00:48 -0700 Subject: [PATCH 14/60] fix eval call --- src/shape.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/shape.cpp b/src/shape.cpp index 82bb121e867..1a52ec01041 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -734,7 +734,7 @@ shape shape::to_static(const std::unordered_map& symbol_ if(dd.is_fixed()) return dd.min; if(dd.sym_expr) - return dd.sym_expr->eval(symbol_map); + return dd.sym_expr->eval_dim(symbol_map); MIGRAPHX_THROW("to_static: non-fixed dimension has no symbolic expression"); }); const auto& ds = this->dyn_strides(); @@ -742,7 +742,7 @@ shape shape::to_static(const std::unordered_map& symbol_ return {type(), static_lens}; std::vector static_strides(ds.size()); std::transform(ds.cbegin(), ds.cend(), static_strides.begin(), [&](const auto& s) { - return s.eval(symbol_map); + return s.eval_dim(symbol_map); }); return {type(), static_lens, static_strides}; } From 359070d2d180317db65356095e0a284e77f89d5b Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 25 Mar 2026 22:43:52 -0700 Subject: [PATCH 15/60] copilot comments --- src/sym.cpp | 33 ++++++++++++++++++++++++++++----- test/sym_test.cpp | 19 +++++++++++++++++++ 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/src/sym.cpp b/src/sym.cpp index cab231a9c5b..3b3601677f4 100644 --- a/src/sym.cpp +++ b/src/sym.cpp @@ -516,8 +516,10 @@ static int64_t eval_direct(const expr_ptr& e, const binding_map& bindings) return prod; }, [&](const fdiv_data& d) -> int64_t { - return eval_direct(d.numerator, bindings) / - eval_direct(d.denominator, bindings); + auto denom = eval_direct(d.denominator, bindings); + if(denom == 0) + MIGRAPHX_THROW("sym::expr::eval_dim: division by zero"); + return eval_direct(d.numerator, bindings) / denom; }}, e->data); } @@ -694,21 +696,42 @@ static expr_ptr parse_unary(const char*& p) return parse_primary(p); } +static expr_ptr parse_power(const char*& p) +{ + auto base = parse_unary(p); + skip_ws(p); + if(*p == '*' and *(p + 1) == '*') + { + p += 2; + auto exp_node = parse_unary(p); + if(not holds(exp_node)) + MIGRAPHX_THROW("symbolic parser: ** exponent must be an integer literal"); + auto exp = get_integer(exp_node); + if(exp < 0) + MIGRAPHX_THROW("symbolic parser: ** exponent must be non-negative"); + expr_ptr result = make_integer(1); + for(int64_t i = 0; i < exp; ++i) + result = make_mul(result, base); + return result; + } + return base; +} + static expr_ptr parse_term(const char*& p) { - auto left = parse_unary(p); + auto left = parse_power(p); for(;;) { skip_ws(p); if(*p == '*') { ++p; - left = make_mul(left, parse_unary(p)); + left = make_mul(left, parse_power(p)); } else if(*p == '/') { ++p; - left = make_floor_div(left, parse_unary(p)); + left = make_floor_div(left, parse_power(p)); } else break; diff --git a/test/sym_test.cpp b/test/sym_test.cpp index ebcd31eadc3..72744048bf9 100644 --- a/test/sym_test.cpp +++ b/test/sym_test.cpp @@ -392,6 +392,12 @@ TEST_CASE(eval_unbound_throws) EXPECT(test::throws([&] { (H + W).eval_dim({{H, 1}}); })); } +TEST_CASE(eval_division_by_zero_throws) +{ + auto H = var("H"), D = var("D"); + EXPECT(test::throws([&] { (H / D).eval_dim({{H, 10}, {D, 0}}); })); +} + TEST_CASE(eval_integer_expr) { EXPECT(lit(0).eval_dim({}) == 0); @@ -600,6 +606,17 @@ TEST_CASE(parse_whitespace_tolerance) EXPECT(parse("H+1") == parse("H + 1")); } +TEST_CASE(parse_power_operator) +{ + auto H = var("H"); + EXPECT(parse("H**2") == H * H); + EXPECT(parse("H**3") == H * H * H); + EXPECT(parse("H**1") == H); + EXPECT(parse("H**0") == lit(1)); + EXPECT(parse("2*H**2 + 1") == 2 * H * H + 1); + EXPECT(parse("(2*H)**3 + 5") == 8 * H * H * H + 5); +} + TEST_CASE(print_negative_mul_coefficient) { auto r = 0 - 3 * var("H"); @@ -637,6 +654,8 @@ TEST_CASE(print_parse_round_trip) (H - 3) / 2 + 1, N * C * H * W, (H - 1) / 2, + H * H, + H * H * H, }; for(const auto& e : exprs) { From 9ad996f170adab68e9da67e4354cd6571d99ffd9 Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 25 Mar 2026 22:48:03 -0700 Subject: [PATCH 16/60] copilot review fix --- src/sym.cpp | 2 +- test/sym_test.cpp | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/sym.cpp b/src/sym.cpp index 3b3601677f4..605853fb262 100644 --- a/src/sym.cpp +++ b/src/sym.cpp @@ -346,7 +346,7 @@ static expr_ptr make_neg(const expr_ptr& a) return build_add(-d.constant, std::move(negated)); }, [](const mul_data& d) -> expr_ptr { - return make_node(mul_data{-d.coefficient, d.factors}); + return build_mul(-d.coefficient, d.factors); }, [&](const auto&) -> expr_ptr { return make_mul(make_integer(-1), a); }}, a->data); diff --git a/test/sym_test.cpp b/test/sym_test.cpp index 72744048bf9..25994c72896 100644 --- a/test/sym_test.cpp +++ b/test/sym_test.cpp @@ -276,6 +276,15 @@ TEST_CASE(neg_of_product_double) EXPECT(dbl == hw); } +TEST_CASE(neg_of_neg_mul_canonicalizes) +{ + auto H = var("H"); + auto neg = 0 - H; + EXPECT(neg == lit(-1) * H); + auto pos = 0 - neg; + EXPECT(pos == H); +} + TEST_CASE(add_compound_product_like_terms) { auto H = var("H"), W = var("W"); From bda9f9154fdb9ccaf5def6b8d0775b611171e536 Mon Sep 17 00:00:00 2001 From: Shiv Date: Thu, 26 Mar 2026 10:05:41 -0700 Subject: [PATCH 17/60] format and tidy --- src/sym.cpp | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/src/sym.cpp b/src/sym.cpp index 605853fb262..f76d14449c5 100644 --- a/src/sym.cpp +++ b/src/sym.cpp @@ -114,7 +114,7 @@ static const mul_data& get_mul(const expr_ptr& e) { return std::get(e- static std::size_t hash_combine(std::size_t seed, std::size_t v) { - return seed ^ (v + 0x9e3779b9 + (seed << 6) + (seed >> 2)); + return seed ^ (v + 0x9e3779b9 + (seed << 6u) + (seed >> 2u)); } template @@ -338,17 +338,16 @@ static expr_ptr make_add(const expr_ptr& a, const expr_ptr& b) static expr_ptr make_neg(const expr_ptr& a) { return std::visit( - overloaded{[](const integer_data& d) -> expr_ptr { return make_integer(-d.value); }, - [](const add_data& d) -> expr_ptr { - term_map negated; - for(const auto& [term, coeff] : d.terms) - negated[term] = -coeff; - return build_add(-d.constant, std::move(negated)); - }, - [](const mul_data& d) -> expr_ptr { - return build_mul(-d.coefficient, d.factors); - }, - [&](const auto&) -> expr_ptr { return make_mul(make_integer(-1), a); }}, + overloaded{ + [](const integer_data& d) -> expr_ptr { return make_integer(-d.value); }, + [](const add_data& d) -> expr_ptr { + term_map negated; + for(const auto& [term, coeff] : d.terms) + negated[term] = -coeff; + return build_add(-d.constant, std::move(negated)); + }, + [](const mul_data& d) -> expr_ptr { return build_mul(-d.coefficient, d.factors); }, + [&](const auto&) -> expr_ptr { return make_mul(make_integer(-1), a); }}, a->data); } @@ -627,7 +626,7 @@ static std::string print_expr(const expr_ptr& e, int parent_prec) static void skip_ws(const char*& p) { - while(*p and std::isspace(static_cast(*p))) + while(*p != '\0' and std::isspace(static_cast(*p)) != 0) ++p; } @@ -639,20 +638,20 @@ static expr_ptr parse_primary(const char*& p); static expr_ptr parse_primary(const char*& p) { skip_ws(p); - if(std::isdigit(static_cast(*p))) + if(std::isdigit(static_cast(*p)) != 0) { int64_t n = 0; - while(std::isdigit(static_cast(*p))) + while(std::isdigit(static_cast(*p)) != 0) { n = n * 10 + (*p - '0'); ++p; } return make_integer(n); } - if(std::isalpha(static_cast(*p)) or *p == '_') + if(std::isalpha(static_cast(*p)) != 0 or *p == '_') { std::string name; - while(std::isalnum(static_cast(*p)) or *p == '_') + while(std::isalnum(static_cast(*p)) != 0 or *p == '_') { name += *p; ++p; @@ -817,7 +816,7 @@ std::size_t expr::eval_dim(const std::unordered_map& symbol_m bindings[k.p->node] = static_cast(v); } auto v = eval_direct(p->node, bindings); - assert(v >= 0 && "symbolic dimension evaluated to negative value"); + assert(v >= 0 and "symbolic dimension evaluated to negative value"); return static_cast(v); } From 003c9d3d6420444af85aa729da9ee637c0c2da09 Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 30 Mar 2026 10:01:53 -0700 Subject: [PATCH 18/60] tidy fix --- test/sym_test.cpp | 539 ++++++++++++++++++++++++---------------------- 1 file changed, 285 insertions(+), 254 deletions(-) diff --git a/test/sym_test.cpp b/test/sym_test.cpp index 25994c72896..bbee34da0e0 100644 --- a/test/sym_test.cpp +++ b/test/sym_test.cpp @@ -44,7 +44,7 @@ TEST_CASE(construct_integer) TEST_CASE(construct_symbol) { - EXPECT(var("H").to_string() == "H"); + EXPECT(var("h").to_string() == "h"); EXPECT(var("batch_size").to_string() == "batch_size"); } @@ -57,22 +57,23 @@ TEST_CASE(construct_empty) TEST_CASE(add_identity) { - auto H = var("H"); - EXPECT(H + 0 == H); - EXPECT(0 + H == H); + auto h = var("h"); + EXPECT(h + 0 == h); + EXPECT(0 + h == h); } TEST_CASE(add_commutativity) { - auto H = var("H"), W = var("W"); - EXPECT(H + W == W + H); + auto h = var("h"); + auto w = var("w"); + EXPECT(h + w == w + h); } TEST_CASE(add_like_term_folding) { - auto H = var("H"); - auto r = H + H; - EXPECT(r == 2 * H); + auto h = var("h"); + auto r = h + h; + EXPECT(r == 2 * h); } TEST_CASE(add_constant_folding) @@ -83,48 +84,50 @@ TEST_CASE(add_constant_folding) TEST_CASE(add_flattening) { - auto H = var("H"), W = var("W"), C = var("C"); - auto a = (H + W) + C; - auto b = H + (W + C); - auto c = (C + H) + W; - EXPECT(a == b); - EXPECT(b == c); + auto h = var("h"); + auto w = var("w"); + auto c = var("c"); + auto r = (h + w) + c; + auto s = h + (w + c); + auto t = (c + h) + w; + EXPECT(r == s); + EXPECT(s == t); } TEST_CASE(add_mixed) { - auto H = var("H"); - auto r = H + 3 + H + 2; - EXPECT(r == 2 * H + 5); - auto r2 = H + H; - EXPECT(r2 + 5 == 2 * H + 5); + auto h = var("h"); + auto r = h + 3 + h + 2; + EXPECT(r == 2 * h + 5); + auto r2 = h + h; + EXPECT(r2 + 5 == 2 * h + 5); } TEST_CASE(add_cancellation) { - auto H = var("H"); - EXPECT(H + (-1 * H) == lit(0)); + auto h = var("h"); + EXPECT(h + (-1 * h) == lit(0)); } TEST_CASE(sub_identity) { - auto H = var("H"); - EXPECT(H - 0 == H); + auto h = var("h"); + EXPECT(h - 0 == h); } TEST_CASE(sub_self) { - auto H = var("H"); - EXPECT(H - H == lit(0)); + auto h = var("h"); + EXPECT(h - h == lit(0)); } TEST_CASE(sub_constant_folding) { EXPECT(lit(10) - lit(3) == lit(7)); } TEST_CASE(sub_produces_negation) { - auto H = var("H"); - EXPECT(-1 * H == lit(0) - H); - EXPECT(3 - H == lit(0) - H + 3); + auto h = var("h"); + EXPECT(-1 * h == lit(0) - h); + EXPECT(3 - h == lit(0) - h + 3); } TEST_CASE(neg_integer) @@ -135,61 +138,64 @@ TEST_CASE(neg_integer) TEST_CASE(neg_double_negation) { - auto H = var("H"); - auto r = 0 - (-1 * H); - EXPECT(r == H); + auto h = var("h"); + auto r = 0 - (-1 * h); + EXPECT(r == h); } TEST_CASE(mul_identity) { - auto H = var("H"); - EXPECT(H * 1 == H); - EXPECT(1 * H == H); + auto h = var("h"); + EXPECT(h * 1 == h); + EXPECT(1 * h == h); } TEST_CASE(mul_zero) { - auto H = var("H"); - EXPECT(H * 0 == lit(0)); - EXPECT(0 * H == lit(0)); + auto h = var("h"); + EXPECT(h * 0 == lit(0)); + EXPECT(0 * h == lit(0)); } TEST_CASE(mul_constant_folding) { EXPECT(lit(3) * lit(7) == lit(21)); } TEST_CASE(mul_commutativity) { - auto A = var("A"), B = var("B"); - EXPECT(B * A == A * B); + auto a = var("a"); + auto b = var("b"); + EXPECT(b * a == a * b); } TEST_CASE(mul_coefficient_accumulation) { - auto H = var("H"); - auto r = 2 * H * 3; - EXPECT(r == 6 * H); + auto h = var("h"); + auto r = 2 * h * 3; + EXPECT(r == 6 * h); } TEST_CASE(mul_flattening) { - auto H = var("H"), W = var("W"), C = var("C"); - auto a = (H * W) * C; - auto b = H * (W * C); - auto c = (C * H) * W; - EXPECT(a == b); - EXPECT(b == c); + auto h = var("h"); + auto w = var("w"); + auto c = var("c"); + auto r = (h * w) * c; + auto s = h * (w * c); + auto t = (c * h) * w; + EXPECT(r == s); + EXPECT(s == t); } TEST_CASE(mul_distributive) { - auto H = var("H"); - auto r = 2 * (H + 1); - EXPECT(r == 2 * H + 2); + auto h = var("h"); + auto r = 2 * (h + 1); + EXPECT(r == 2 * h + 2); } TEST_CASE(fdiv_identity) { - auto H = var("H"); - EXPECT(H / 1 == H); + auto h = var("h"); + EXPECT(h / 1 == h); } TEST_CASE(fdiv_constant_folding) @@ -201,75 +207,75 @@ TEST_CASE(fdiv_constant_folding) TEST_CASE(fdiv_exact_coefficient_cancel) { - auto N = var("N"); - auto r = (6 * N) / 3; - EXPECT(r == 2 * N); + auto n = var("n"); + auto r = (6 * n) / 3; + EXPECT(r == 2 * n); } TEST_CASE(fdiv_non_simplifiable) { - auto H = var("H"); - auto r = (H - 1) / 2; - EXPECT(r == (H - 1) / 2); + auto h = var("h"); + auto r = (h - 1) / 2; + EXPECT(r == (h - 1) / 2); } TEST_CASE(fdiv_division_by_zero) { - EXPECT(test::throws([&] { var("H") / 0; })); + EXPECT(test::throws([&] { var("h") / 0; })); } TEST_CASE(add_scaled_subtraction) { - auto H = var("H"); - EXPECT(2 * H - H == H); - EXPECT(3 * H - 2 * H == H); - EXPECT(H + H + H == 3 * H); + auto h = var("h"); + EXPECT(2 * h - h == h); + EXPECT(3 * h - 2 * h == h); + EXPECT(h + h + h == 3 * h); } TEST_CASE(add_of_two_adds) { - auto H = var("H"); - auto r = (H + 1) + (H + 2); - EXPECT(r == 2 * H + 3); + auto h = var("h"); + auto r = (h + 1) + (h + 2); + EXPECT(r == 2 * h + 3); } TEST_CASE(sub_strip_constant) { - auto H = var("H"); - EXPECT((H + 1) - H == lit(1)); + auto h = var("h"); + EXPECT((h + 1) - h == lit(1)); } TEST_CASE(sub_of_two_adds) { - auto H = var("H"); - auto r = (H + 1) - (H + 2); + auto h = var("h"); + auto r = (h + 1) - (h + 2); EXPECT(r == lit(-1)); } TEST_CASE(mul_zero_propagation) { - auto H = var("H"); - auto z = H - H; + auto h = var("h"); + auto z = h - h; EXPECT(50 * z == lit(0)); } TEST_CASE(add_chain_constant_cancel) { - auto H = var("H"); - auto r = lit(2) - H - lit(2); - EXPECT(r == -1 * H); + auto h = var("h"); + auto r = lit(2) - h - lit(2); + EXPECT(r == -1 * h); } TEST_CASE(neg_of_sum_distributes) { - auto H = var("H"); - auto r = lit(-1) * (H + 1); - EXPECT(r == -1 * H - 1); + auto h = var("h"); + auto r = lit(-1) * (h + 1); + EXPECT(r == -1 * h - 1); } TEST_CASE(neg_of_product_double) { - auto hw = var("H") * var("W"); + auto hw = var("h") * var("w"); auto neg = 0 - hw; EXPECT(neg == lit(0) - hw); auto dbl = 0 - neg; @@ -278,18 +284,19 @@ TEST_CASE(neg_of_product_double) TEST_CASE(neg_of_neg_mul_canonicalizes) { - auto H = var("H"); - auto neg = 0 - H; - EXPECT(neg == lit(-1) * H); + auto h = var("h"); + auto neg = 0 - h; + EXPECT(neg == lit(-1) * h); auto pos = 0 - neg; - EXPECT(pos == H); + EXPECT(pos == h); } TEST_CASE(add_compound_product_like_terms) { - auto H = var("H"), W = var("W"); - auto hw = H * W; - auto wh = W * H; + auto h = var("h"); + auto w = var("w"); + auto hw = h * w; + auto wh = w * h; EXPECT(hw + hw == 2 * hw); EXPECT(hw + wh == 2 * hw); EXPECT(hw + 2 * hw == 3 * hw); @@ -297,25 +304,28 @@ TEST_CASE(add_compound_product_like_terms) TEST_CASE(add_compound_product_cancellation) { - auto hw = var("H") * var("W"); + auto hw = var("h") * var("w"); EXPECT(hw - hw == lit(0)); } // X*Y and X cancel pairwise: (X*Y - X) - (X*Y - X) == 0 TEST_CASE(sub_compound_product_mixed) { - auto X = var("X"), Y = var("Y"); - auto xy = X * Y; - auto r = xy - X - xy + X; + auto x = var("x"); + auto y = var("y"); + auto xy = x * y; + auto r = xy - x - xy + x; EXPECT(r == lit(0)); } // Duplicate A*B terms fold even when separated by another term TEST_CASE(add_multi_term_accumulation) { - auto A = var("A"), B = var("B"), C = var("C"); - auto r = A * B + C + A * B; - auto expected = 2 * (A * B) + C; + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + auto r = a * b + c + a * b; + auto expected = 2 * (a * b) + c; EXPECT(r == expected); } @@ -338,9 +348,10 @@ TEST_CASE(fdiv_large_constants) TEST_CASE(eq_different_values) { - auto H = var("H"), W = var("W"); - EXPECT(H + 1 != H + 2); - EXPECT(H != W); + auto h = var("h"); + auto w = var("w"); + EXPECT(h + 1 != h + 2); + EXPECT(h != w); EXPECT(lit(3) != lit(4)); } @@ -357,54 +368,57 @@ TEST_CASE(eq_empty) TEST_CASE(eval_simple) { - auto H = var("H"); - EXPECT(H.eval_dim({{H, 26}}) == 26); + auto h = var("h"); + EXPECT(h.eval_dim({{h, 26}}) == 26); EXPECT(lit(42).eval_dim({}) == 42); } TEST_CASE(eval_arithmetic) { - auto H = var("H"); - EXPECT((H - 3).eval_dim({{H, 26}}) == 23); - EXPECT((H + 5).eval_dim({{H, 10}}) == 15); - EXPECT((2 * H).eval_dim({{H, 13}}) == 26); + auto h = var("h"); + EXPECT((h - 3).eval_dim({{h, 26}}) == 23); + EXPECT((h + 5).eval_dim({{h, 10}}) == 15); + EXPECT((2 * h).eval_dim({{h, 13}}) == 26); } TEST_CASE(eval_compound) { - auto H = var("H"); - auto e = (H - 3) / 2 + 1; - EXPECT(e.eval_dim({{H, 26}}) == 12); - EXPECT(e.eval_dim({{H, 27}}) == 13); + auto h = var("h"); + auto e = (h - 3) / 2 + 1; + EXPECT(e.eval_dim({{h, 26}}) == 12); + EXPECT(e.eval_dim({{h, 27}}) == 13); } TEST_CASE(eval_multiple_symbols) { - auto N = var("N"), H = var("H"); - auto e = N * H; - EXPECT(e.eval_dim({{N, 4}, {H, 26}}) == 104); + auto n = var("n"); + auto h = var("h"); + auto e = n * h; + EXPECT(e.eval_dim({{n, 4}, {h, 26}}) == 104); } TEST_CASE(eval_floor_division) { - auto H = var("H"); - auto e = (H - 1) / 2; - EXPECT(e.eval_dim({{H, 7}}) == 3); - EXPECT(e.eval_dim({{H, 8}}) == 3); - EXPECT(e.eval_dim({{H, 9}}) == 4); + auto h = var("h"); + auto e = (h - 1) / 2; + EXPECT(e.eval_dim({{h, 7}}) == 3); + EXPECT(e.eval_dim({{h, 8}}) == 3); + EXPECT(e.eval_dim({{h, 9}}) == 4); } TEST_CASE(eval_unbound_throws) { - auto H = var("H"), W = var("W"); - EXPECT(test::throws([&] { H.eval_dim({}); })); - EXPECT(test::throws([&] { (H + W).eval_dim({{H, 1}}); })); + auto h = var("h"); + auto w = var("w"); + EXPECT(test::throws([&] { h.eval_dim({}); })); + EXPECT(test::throws([&] { (h + w).eval_dim({{h, 1}}); })); } TEST_CASE(eval_division_by_zero_throws) { - auto H = var("H"), D = var("D"); - EXPECT(test::throws([&] { (H / D).eval_dim({{H, 10}, {D, 0}}); })); + auto h = var("h"); + auto d = var("d"); + EXPECT(test::throws([&] { (h / d).eval_dim({{h, 10}, {d, 0}}); })); } TEST_CASE(eval_integer_expr) @@ -415,42 +429,44 @@ TEST_CASE(eval_integer_expr) TEST_CASE(subs_partial) { - auto N = var("N"), H = var("H"); - auto e = N * H + 1; - auto r = e.subs({{N, lit(4)}}); - EXPECT(r == 4 * H + 1); - EXPECT(r.eval_dim({{H, 10}}) == 41); + auto n = var("n"); + auto h = var("h"); + auto e = n * h + 1; + auto r = e.subs({{n, lit(4)}}); + EXPECT(r == 4 * h + 1); + EXPECT(r.eval_dim({{h, 10}}) == 41); } TEST_CASE(subs_full) { - auto H = var("H"); - auto e = H + 1; - auto r = e.subs({{H, lit(5)}}); + auto h = var("h"); + auto e = h + 1; + auto r = e.subs({{h, lit(5)}}); EXPECT(r == lit(6)); } TEST_CASE(subs_none) { - auto H = var("H"); - EXPECT(H.subs({}) == H); + auto h = var("h"); + EXPECT(h.subs({}) == h); } TEST_CASE(subs_floor_div) { - auto H = var("H"); - auto e = (H - 1) / 2; - auto r = e.subs({{H, lit(7)}}); + auto h = var("h"); + auto e = (h - 1) / 2; + auto r = e.subs({{h, lit(7)}}); EXPECT(r == lit(3)); } // eval() and subs()+eval() must agree on a compound expression TEST_CASE(subs_eval_cross_validation) { - auto N = var("N"), H = var("H"); - auto e = (N * H - 3) / 2 + 1; - std::unordered_map eval_map = {{N, 4}, {H, 26}}; - std::unordered_map subs_map = {{N, lit(4)}, {H, lit(26)}}; + auto n = var("n"); + auto h = var("h"); + auto e = (n * h - 3) / 2 + 1; + std::unordered_map eval_map = {{n, 4}, {h, 26}}; + std::unordered_map subs_map = {{n, lit(4)}, {h, lit(26)}}; auto via_eval = e.eval_dim(eval_map); auto via_subs = e.subs(subs_map).eval_dim({}); EXPECT(via_eval == via_subs); @@ -459,59 +475,65 @@ TEST_CASE(subs_eval_cross_validation) TEST_CASE(subs_empty) { se e; - auto r = e.subs({{var("H"), lit(5)}}); + auto r = e.subs({{var("h"), lit(5)}}); EXPECT(r.empty()); } TEST_CASE(subs_creates_like_terms) { - auto H = var("H"), W = var("W"); - auto e = H + W; - auto r = e.subs({{W, lit(0)}}); - EXPECT(r == H); + auto h = var("h"); + auto w = var("w"); + auto e = h + w; + auto r = e.subs({{w, lit(0)}}); + EXPECT(r == h); } TEST_CASE(subs_with_expression) { - auto H = var("H"), W = var("W"); - auto e = 2 * H + 1; - auto r = e.subs({{H, W + 3}}); - EXPECT(r == 2 * W + 7); + auto h = var("h"); + auto w = var("w"); + auto e = 2 * h + 1; + auto r = e.subs({{h, w + 3}}); + EXPECT(r == 2 * w + 7); } TEST_CASE(subs_symbol_for_symbol) { - auto H = var("H"), W = var("W"); - auto e = H * H + 1; - auto r = e.subs({{H, W}}); - EXPECT(r == W * W + 1); + auto h = var("h"); + auto w = var("w"); + auto e = h * h + 1; + auto r = e.subs({{h, w}}); + EXPECT(r == w * w + 1); } TEST_CASE(subs_compound_expression) { - auto N = var("N"), H = var("H"), W = var("W"); - auto e = (N * H + W - 3) / 2; - auto r = e.subs({{H, 2 * W + 1}, {N, W - 1}}); + auto n = var("n"); + auto h = var("h"); + auto w = var("w"); + auto e = (n * h + w - 3) / 2; + auto r = e.subs({{h, 2 * w + 1}, {n, w - 1}}); // N*H => (W-1)*(2*W+1) = 2*W^2 - W - 1 // N*H + W - 3 => 2*W^2 - 2*W - 4 + W = 2*W^2 - 2 // Verify by evaluating with W=5: (W-1)*(2*W+1) + W - 3 = 4*11 + 5 - 3 = 46, 46/2 = 23 - EXPECT(r.eval_dim({{W, 5}}) == 23); + EXPECT(r.eval_dim({{w, 5}}) == 23); // Also verify the original expression with direct values agrees - EXPECT(e.eval_dim({{N, 4}, {H, 11}, {W, 5}}) == 23); + EXPECT(e.eval_dim({{n, 4}, {h, 11}, {w, 5}}) == 23); } TEST_CASE(eval_compound_product) { - auto H = var("H"), W = var("W"); - auto e = H * W + 1; - EXPECT(e.eval_dim({{H, 3}, {W, 4}}) == 13); + auto h = var("h"); + auto w = var("w"); + auto e = h * w + 1; + EXPECT(e.eval_dim({{h, 3}, {w, 4}}) == 13); } TEST_CASE(eval_negative_intermediate) { - auto H = var("H"); - auto e = (H - 10) * 2 + 20; - EXPECT(e.eval_dim({{H, 3}}) == 6); + auto h = var("h"); + auto e = (h - 10) * 2 + 20; + EXPECT(e.eval_dim({{h, 3}}) == 6); } // =================================================================== @@ -521,150 +543,152 @@ TEST_CASE(eval_negative_intermediate) TEST_CASE(print_atoms) { EXPECT(lit(42).to_string() == "42"); - EXPECT(var("H").to_string() == "H"); + EXPECT(var("h").to_string() == "h"); EXPECT(lit(0).to_string() == "0"); EXPECT(lit(-3).to_string() == "-3"); } TEST_CASE(print_add) { - auto H = var("H"); - EXPECT((H + 1).to_string() == "H + 1"); - EXPECT((H - 3).to_string() == "H - 3"); + auto h = var("h"); + EXPECT((h + 1).to_string() == "h + 1"); + EXPECT((h - 3).to_string() == "h - 3"); } TEST_CASE(print_mul) { - EXPECT((2 * var("H")).to_string() == "2*H"); - auto r = var("A") * var("B"); - EXPECT(r.to_string() == "A*B"); + EXPECT((2 * var("h")).to_string() == "2*h"); + auto r = var("a") * var("b"); + EXPECT(r.to_string() == "a*b"); } TEST_CASE(print_fdiv_parens) { - auto r = (var("H") - 1) / 2; - EXPECT(r.to_string() == "(H - 1)/2"); + auto r = (var("h") - 1) / 2; + EXPECT(r.to_string() == "(h - 1)/2"); } TEST_CASE(print_compound) { - auto r = (var("H") - 3) / 2 + 1; - EXPECT(r.to_string() == "(H - 3)/2 + 1"); + auto r = (var("h") - 3) / 2 + 1; + EXPECT(r.to_string() == "(h - 3)/2 + 1"); } TEST_CASE(parse_atoms) { EXPECT(parse("42") == lit(42)); - EXPECT(parse("H") == var("H")); + EXPECT(parse("h") == var("h")); } TEST_CASE(parse_arithmetic) { - auto H = var("H"); - auto r = parse("H + 1"); - EXPECT(r == H + 1); + auto h = var("h"); + auto r = parse("h + 1"); + EXPECT(r == h + 1); - auto r2 = parse("H - 3"); - EXPECT(r2 == H - 3); + auto r2 = parse("h - 3"); + EXPECT(r2 == h - 3); - auto r3 = parse("2*H"); - EXPECT(r3 == 2 * H); + auto r3 = parse("2*h"); + EXPECT(r3 == 2 * h); } TEST_CASE(parse_precedence) { - auto r = parse("H + 1 * 2"); - EXPECT(r == var("H") + 2); + auto r = parse("h + 1 * 2"); + EXPECT(r == var("h") + 2); } TEST_CASE(parse_parentheses) { - auto r = parse("(H + 1) * 2"); - EXPECT(r == 2 * (var("H") + 1)); + auto r = parse("(h + 1) * 2"); + EXPECT(r == 2 * (var("h") + 1)); } TEST_CASE(parse_division) { - auto r = parse("(H - 1)/2"); - EXPECT(r == (var("H") - 1) / 2); + auto r = parse("(h - 1)/2"); + EXPECT(r == (var("h") - 1) / 2); } TEST_CASE(parse_unary_minus) { - auto H = var("H"); - EXPECT(parse("-H") == -1 * H); - EXPECT(parse("-H").to_string() == "-H"); - EXPECT(parse("-(H + 1)") == -1 * H - 1); + auto h = var("h"); + EXPECT(parse("-h") == -1 * h); + EXPECT(parse("-h").to_string() == "-h"); + EXPECT(parse("-(h + 1)") == -1 * h - 1); } -// Legacy floor() wrapper is accepted by parser and treated as no-op TEST_CASE(parse_floor_backward_compat) { - auto a = parse("floor((H-1)/2)"); - auto b = parse("(H-1)/2"); + auto a = parse("floor((h-1)/2)"); + auto b = parse("(h-1)/2"); EXPECT(a == b); - auto c = parse("floor((H-1)/2) + 1"); - auto d = (var("H") - 1) / 2 + 1; + auto c = parse("floor((h-1)/2) + 1"); + auto d = (var("h") - 1) / 2 + 1; EXPECT(c == d); } TEST_CASE(parse_whitespace_tolerance) { - EXPECT(parse(" H + 1 ") == parse("H + 1")); - EXPECT(parse("H+1") == parse("H + 1")); + EXPECT(parse(" h + 1 ") == parse("h + 1")); + EXPECT(parse("h+1") == parse("h + 1")); } TEST_CASE(parse_power_operator) { - auto H = var("H"); - EXPECT(parse("H**2") == H * H); - EXPECT(parse("H**3") == H * H * H); - EXPECT(parse("H**1") == H); - EXPECT(parse("H**0") == lit(1)); - EXPECT(parse("2*H**2 + 1") == 2 * H * H + 1); - EXPECT(parse("(2*H)**3 + 5") == 8 * H * H * H + 5); + auto h = var("h"); + EXPECT(parse("h**2") == h * h); + EXPECT(parse("h**3") == h * h * h); + EXPECT(parse("h**1") == h); + EXPECT(parse("h**0") == lit(1)); + EXPECT(parse("2*h**2 + 1") == 2 * h * h + 1); + EXPECT(parse("(2*h)**3 + 5") == 8 * h * h * h + 5); } TEST_CASE(print_negative_mul_coefficient) { - auto r = 0 - 3 * var("H"); - EXPECT(r.to_string() == "-3*H"); + auto r = 0 - 3 * var("h"); + EXPECT(r.to_string() == "-3*h"); } TEST_CASE(print_multi_symbol_product) { - auto r = var("H") * var("W"); + auto r = var("h") * var("w"); auto s = r.to_string(); - EXPECT(s == "H*W" or s == "W*H"); - EXPECT(parse("H*W") == parse("W*H")); + EXPECT(s == "h*w" or s == "w*h"); + EXPECT(parse("h*w") == parse("w*h")); } TEST_CASE(print_compound_expression) { - auto r = 2 * (var("H") * var("W")) + var("C") - 1; + auto r = 2 * (var("h") * var("w")) + var("c") - 1; auto s = r.to_string(); EXPECT(parse(s) == r); } TEST_CASE(parse_compound_mul) { - auto r = parse("2*H*W"); - EXPECT(r == 2 * var("H") * var("W")); + auto r = parse("2*h*w"); + EXPECT(r == 2 * var("h") * var("w")); } TEST_CASE(print_parse_round_trip) { - auto H = var("H"), N = var("N"), C = var("C"), W = var("W"); + auto h = var("h"); + auto n = var("n"); + auto c = var("c"); + auto w = var("w"); std::vector exprs = { - H, - H + 1, - 2 * H - 3, - (H - 3) / 2 + 1, - N * C * H * W, - (H - 1) / 2, - H * H, - H * H * H, + h, + h + 1, + 2 * h - 3, + (h - 3) / 2 + 1, + n * c * h * w, + (h - 1) / 2, + h * h, + h * h * h, }; for(const auto& e : exprs) { @@ -681,25 +705,29 @@ TEST_CASE(print_parse_round_trip) // 5 levels of (e-1)/2: simulates repeated pooling/conv stride reduction TEST_CASE(edge_deeply_nested) { - auto H = var("H"); - se e = H; + auto h = var("h"); + se e = h; for(int i = 0; i < 5; ++i) e = (e - 1) / 2; - EXPECT(e.eval_dim({{H, 255}}) == 7); + EXPECT(e.eval_dim({{h, 255}}) == 7); } TEST_CASE(edge_many_symbols) { - auto A = var("A"), B = var("B"), C = var("C"), D = var("D"), E = var("E"); - auto e = A + B + C + D + E; - EXPECT(e.eval_dim({{A, 1}, {B, 2}, {C, 3}, {D, 4}, {E, 5}}) == 15); + auto a = var("a"); + auto b = var("b"); + auto c = var("c"); + auto d = var("d"); + auto e = var("e"); + auto r = a + b + c + d + e; + EXPECT(r.eval_dim({{a, 1}, {b, 2}, {c, 3}, {d, 4}, {e, 5}}) == 15); } TEST_CASE(edge_neg_one_coefficient) { - auto H = var("H"); - EXPECT(-1 * H == lit(0) - H); - EXPECT(-1 * H + H == lit(0)); + auto h = var("h"); + EXPECT(-1 * h == lit(0) - h); + EXPECT(-1 * h + h == lit(0)); } TEST_CASE(edge_empty_operations) @@ -714,38 +742,38 @@ TEST_CASE(edge_empty_operations) TEST_CASE(edge_empty_with_nonempty) { se empty; - auto H = var("H"); - auto r1 = H + empty; + auto h = var("h"); + auto r1 = h + empty; EXPECT(not r1.empty()); - auto r2 = empty + H; + auto r2 = empty + h; EXPECT(not r2.empty()); } TEST_CASE(edge_large_coefficients) { - auto H = var("H"); - auto r = 1000000 * H; - EXPECT(r.eval_dim({{H, 1000000}}) == 1000000000000ULL); + auto h = var("h"); + auto r = 1000000 * h; + EXPECT(r.eval_dim({{h, 1000000}}) == 1000000000000ULL); } // Incrementally adding H ten times must fold to 11*H TEST_CASE(edge_chained_operations) { - auto H = var("H"); - auto e = H; + auto h = var("h"); + auto e = h; for(int i = 0; i < 10; ++i) - e = e + H; - EXPECT(e == 11 * H); + e = e + h; + EXPECT(e == 11 * h); } TEST_CASE(edge_repeated_parse) { - auto H = var("H"); + auto h = var("h"); for(int i = 0; i < 10; ++i) { - auto r = parse("(H - 3)/2 + 1"); - EXPECT(r == (H - 3) / 2 + 1); + auto r = parse("(h - 3)/2 + 1"); + EXPECT(r == (h - 3) / 2 + 1); } } @@ -773,35 +801,38 @@ TEST_CASE(serialize_integer) TEST_CASE(serialize_symbol) { - auto H = var("H"); - EXPECT(round_trip(H) == H); + auto h = var("h"); + EXPECT(round_trip(h) == h); } TEST_CASE(serialize_add) { - auto H = var("H"); - auto e = 2 * H + 3; + auto h = var("h"); + auto e = 2 * h + 3; EXPECT(round_trip(e) == e); } TEST_CASE(serialize_mul) { - auto H = var("H"), W = var("W"); - auto e = H * W; + auto h = var("h"); + auto w = var("w"); + auto e = h * w; EXPECT(round_trip(e) == e); } TEST_CASE(serialize_fdiv) { - auto H = var("H"); - auto e = (H - 1) / 2; + auto h = var("h"); + auto e = (h - 1) / 2; EXPECT(round_trip(e) == e); } TEST_CASE(serialize_compound) { - auto N = var("N"), H = var("H"), W = var("W"); - auto e = (N * H * W + 3) / 2 - 1; + auto n = var("n"); + auto h = var("h"); + auto w = var("w"); + auto e = (n * h * w + 3) / 2 - 1; EXPECT(round_trip(e) == e); } From 3759299448883ae3d58c3657539a85a4b97b1e76 Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 30 Mar 2026 13:10:50 -0700 Subject: [PATCH 19/60] tidy --- test/sym_test.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/sym_test.cpp b/test/sym_test.cpp index bbee34da0e0..7da6644b1bf 100644 --- a/test/sym_test.cpp +++ b/test/sym_test.cpp @@ -255,8 +255,7 @@ TEST_CASE(sub_of_two_adds) TEST_CASE(mul_zero_propagation) { auto h = var("h"); - auto z = h - h; - EXPECT(50 * z == lit(0)); + EXPECT(50 * (h - h) == lit(0)); } TEST_CASE(add_chain_constant_cancel) From fe649ff8f604d9f0554a39cd4882e3899209095c Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 30 Mar 2026 14:24:07 -0700 Subject: [PATCH 20/60] update the only call sites using the braced-init-list that cannot be distinguished by old libstdc++ --- src/include/migraphx/shape.hpp | 1 - src/targets/gpu/gemm_impl.cpp | 6 +++--- src/targets/gpu/hip_gemm_impl.cpp | 6 +++--- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index bf900e861b7..216a0ad1973 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -221,7 +221,6 @@ struct MIGRAPHX_EXPORT shape shape(type_t t, std::initializer_list l, std::initializer_list s); shape(type_t t, std::vector dims); - shape(type_t t, std::vector dims, std::vector dstrides); // Construct a dynamic shape from vectors of mins, maxes, and optimals. diff --git a/src/targets/gpu/gemm_impl.cpp b/src/targets/gpu/gemm_impl.cpp index be83c0c1a4f..1b68bd418e3 100644 --- a/src/targets/gpu/gemm_impl.cpp +++ b/src/targets/gpu/gemm_impl.cpp @@ -93,9 +93,9 @@ void blas_shape(const shape& in_shape) MIGRAPHX_THROW("GPU_GEMM: matrix dimensions can't be broadcasted"); if(s.lens().size() < 3) return; - shape batch_shape{s.type(), - {s.lens().begin(), s.lens().end() - 2}, - {s.strides().begin(), s.strides().end() - 2}}; + shape batch_shape(s.type(), + std::vector(s.lens().begin(), s.lens().end() - 2), + std::vector(s.strides().begin(), s.strides().end() - 2)); auto batch_shapes = reduce_dims({batch_shape}); if(batch_shapes.front().lens().size() != 1) MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible"); diff --git a/src/targets/gpu/hip_gemm_impl.cpp b/src/targets/gpu/hip_gemm_impl.cpp index 7dda04c4e9a..0c2452c3708 100644 --- a/src/targets/gpu/hip_gemm_impl.cpp +++ b/src/targets/gpu/hip_gemm_impl.cpp @@ -89,9 +89,9 @@ void blas_shape_hip(const shape& in_shape) MIGRAPHX_THROW("GPU_GEMM: matrix dimensions can't be broadcasted"); if(s.lens().size() < 3) return; - shape batch_shape{s.type(), - {s.lens().begin(), s.lens().end() - 2}, - {s.strides().begin(), s.strides().end() - 2}}; + shape batch_shape(s.type(), + std::vector(s.lens().begin(), s.lens().end() - 2), + std::vector(s.strides().begin(), s.strides().end() - 2)); auto batch_shapes = reduce_dims({batch_shape}); if(batch_shapes.front().lens().size() != 1) MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible"); From 1274d3a0c6909915c92cadb6abbe4dd2babd109c Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 30 Mar 2026 15:33:33 -0700 Subject: [PATCH 21/60] address review comments --- src/include/migraphx/sym.hpp | 2 +- src/sym.cpp | 128 ++++++++++++++---- test/sym_test.cpp | 250 +++++++++++++++++++++++++++++------ 3 files changed, 316 insertions(+), 64 deletions(-) diff --git a/src/include/migraphx/sym.hpp b/src/include/migraphx/sym.hpp index eaf677e2270..efa87d8cd21 100644 --- a/src/include/migraphx/sym.hpp +++ b/src/include/migraphx/sym.hpp @@ -49,7 +49,7 @@ struct MIGRAPHX_EXPORT expr std::string to_string() const; value to_value() const; void from_value(const value& v); - std::size_t eval_dim(const std::unordered_map& symbol_map) const; + std::size_t eval_uint(const std::unordered_map& symbol_map) const; expr subs(const std::unordered_map& symbol_map) const; MIGRAPHX_EXPORT friend expr operator+(const expr& a, const expr& b); diff --git a/src/sym.cpp b/src/sym.cpp index f76d14449c5..e066324c0f0 100644 --- a/src/sym.cpp +++ b/src/sym.cpp @@ -27,7 +27,6 @@ #include #include -#include #include #include #include @@ -76,13 +75,13 @@ struct mul_data int64_t coefficient; factor_map factors; }; -struct fdiv_data +struct tdiv_data { expr_ptr numerator; expr_ptr denominator; }; -using expr_data = std::variant; +using expr_data = std::variant; template struct overloaded : Ts... @@ -144,7 +143,7 @@ static std::size_t compute_hash(const expr_data& d) return hash_combine(hash_combine(h, std::hash{}(p.coefficient)), hash_ordered_map(p.factors)); }, - [&](const fdiv_data& p) { + [&](const tdiv_data& p) { return hash_combine(hash_combine(h, p.numerator->cached_hash), p.denominator->cached_hash); }}, @@ -205,8 +204,8 @@ static int compare_expr(const expr_ptr& a, const expr_ptr& b) return da.coefficient < db.coefficient ? -1 : 1; return compare_maps(da.factors, db.factors); }, - [&](const fdiv_data& da) { - const auto& db = std::get(b->data); + [&](const tdiv_data& da) { + const auto& db = std::get(b->data); int c = compare_expr(da.numerator, db.numerator); if(c != 0) return c; @@ -278,7 +277,7 @@ static expr_ptr make_add(const expr_ptr& a, const expr_ptr& b); static expr_ptr make_sub(const expr_ptr& a, const expr_ptr& b); static expr_ptr make_neg(const expr_ptr& a); static expr_ptr make_mul(const expr_ptr& a, const expr_ptr& b); -static expr_ptr make_floor_div(const expr_ptr& a, const expr_ptr& b); +static expr_ptr make_trunc_div(const expr_ptr& a, const expr_ptr& b); static expr_ptr build_mul(int64_t coefficient, factor_map factors); struct add_parts @@ -422,8 +421,14 @@ static expr_ptr make_mul(const expr_ptr& a, const expr_ptr& b) return build_mul(coefficient, std::move(factors)); } -static expr_ptr make_floor_div(const expr_ptr& a, const expr_ptr& b) +static expr_ptr make_trunc_div(const expr_ptr& a, const expr_ptr& b) { + if(holds(a) and get_integer(a) == 0) + return a; + + if(expr_equal(a, b)) + return make_integer(1); + if(holds(b)) { int64_t den = get_integer(b); @@ -439,9 +444,74 @@ static expr_ptr make_floor_div(const expr_ptr& a, const expr_ptr& b) if(d.coefficient % den == 0) return build_mul(d.coefficient / den, d.factors); } + if(holds(a)) + { + const auto& d = get_add(a); + bool all_divisible = (d.constant % den == 0); + if(all_divisible) + { + all_divisible = std::all_of(d.terms.begin(), d.terms.end(), [&](const auto& p) { + return p.second % den == 0; + }); + } + if(all_divisible) + { + term_map divided = d.terms; + for(auto& [base, coeff] : divided) + coeff /= den; + return build_add(d.constant / den, std::move(divided)); + } + } } - return make_node(fdiv_data{a, b}); + if(holds(a)) + { + const auto& da = get_mul(a); + + if(holds(b)) + { + const auto& db = get_mul(b); + factor_map reduced_num = da.factors; + factor_map reduced_den = db.factors; + for(auto it_den = reduced_den.begin(); it_den != reduced_den.end();) + { + auto it_num = reduced_num.find(it_den->first); + if(it_num == reduced_num.end()) + { + ++it_den; + continue; + } + int64_t cancel = std::min(it_num->second, it_den->second); + it_num->second -= cancel; + it_den->second -= cancel; + if(it_num->second == 0) + reduced_num.erase(it_num); + if(it_den->second == 0) + it_den = reduced_den.erase(it_den); + else + ++it_den; + } + auto new_num = build_mul(da.coefficient, std::move(reduced_num)); + auto new_den = build_mul(db.coefficient, std::move(reduced_den)); + if(not expr_equal(new_num, a) or not expr_equal(new_den, b)) + return make_trunc_div(new_num, new_den); + } + else + { + auto it = da.factors.find(b); + if(it != da.factors.end()) + { + factor_map reduced = da.factors; + if(it->second == 1) + reduced.erase(it->first); + else + reduced[it->first] = it->second - 1; + return build_mul(da.coefficient, std::move(reduced)); + } + } + } + + return make_node(tdiv_data{a, b}); } // =================================================================== @@ -480,10 +550,10 @@ static expr_ptr substitute(const expr_ptr& e, const subs_map& bindings) } return result; }, - [&](const fdiv_data& d) -> expr_ptr { + [&](const tdiv_data& d) -> expr_ptr { auto sn = substitute(d.numerator, bindings); auto sd = substitute(d.denominator, bindings); - return make_floor_div(sn, sd); + return make_trunc_div(sn, sd); }}, e->data); } @@ -495,7 +565,7 @@ static int64_t eval_direct(const expr_ptr& e, const binding_map& bindings) auto it = bindings.find(e); if(it != bindings.end()) return it->second; - MIGRAPHX_THROW("sym::expr::eval_dim: unbound symbol '" + + MIGRAPHX_THROW("sym::expr::eval_uint: unbound symbol '" + d.name + "'"); }, [&](const add_data& d) -> int64_t { @@ -514,10 +584,10 @@ static int64_t eval_direct(const expr_ptr& e, const binding_map& bindings) } return prod; }, - [&](const fdiv_data& d) -> int64_t { + [&](const tdiv_data& d) -> int64_t { auto denom = eval_direct(d.denominator, bindings); if(denom == 0) - MIGRAPHX_THROW("sym::expr::eval_dim: division by zero"); + MIGRAPHX_THROW("sym::expr::eval_uint: division by zero"); return eval_direct(d.numerator, bindings) / denom; }}, e->data); @@ -610,7 +680,7 @@ static std::string print_expr(const expr_ptr& e, int parent_prec) [](const symbol_data& d) -> std::string { return d.name; }, [&](const add_data& d) -> std::string { return print_add(d, parent_prec); }, [&](const mul_data& d) -> std::string { return print_mul(d, parent_prec); }, - [&](const fdiv_data& d) -> std::string { + [&](const tdiv_data& d) -> std::string { std::string s = print_expr(d.numerator, prec_mul + 1) + "/" + print_expr(d.denominator, prec_mul + 1); if(parent_prec > prec_mul) @@ -730,7 +800,7 @@ static expr_ptr parse_term(const char*& p) else if(*p == '/') { ++p; - left = make_floor_div(left, parse_power(p)); + left = make_trunc_div(left, parse_power(p)); } else break; @@ -804,7 +874,7 @@ std::string expr::to_string() const return print_expr(p->node); } -std::size_t expr::eval_dim(const std::unordered_map& symbol_map) const +std::size_t expr::eval_uint(const std::unordered_map& symbol_map) const { if(empty()) return 0; @@ -812,11 +882,12 @@ std::size_t expr::eval_dim(const std::unordered_map& symbol_m for(const auto& [k, v] : symbol_map) { if(k.empty() or not holds(k.p->node)) - MIGRAPHX_THROW("sym::expr::eval_dim: map key '" + k.to_string() + "' is not a symbol"); + MIGRAPHX_THROW("sym::expr::eval_uint: map key '" + k.to_string() + "' is not a symbol"); bindings[k.p->node] = static_cast(v); } auto v = eval_direct(p->node, bindings); - assert(v >= 0 and "symbolic dimension evaluated to negative value"); + if(v < 0) + MIGRAPHX_THROW("sym::expr::eval_uint: expression evaluated to negative value"); return static_cast(v); } @@ -867,7 +938,7 @@ expr operator/(const expr& a, const expr& b) return {}; auto ea = a.p ? a.p->node : make_integer(0); auto eb = b.p ? b.p->node : make_integer(0); - return {std::make_shared(make_floor_div(ea, eb))}; + return {std::make_shared(make_trunc_div(ea, eb))}; } bool operator==(const expr& a, const expr& b) @@ -888,13 +959,18 @@ std::ostream& operator<<(std::ostream& os, const expr& e) return os; } -expr var(const std::string& name) { return {std::make_shared(make_symbol(name))}; } +expr var(const std::string& name) +{ + if(name.empty()) + MIGRAPHX_THROW("sym::var: variable name must not be empty"); + return {std::make_shared(make_symbol(name))}; +} expr lit(int64_t n) { return {std::make_shared(make_integer(n))}; } expr parse(const std::string& s) { - if(s.empty()) + if(s.find_first_not_of(" \t\n\r") == std::string::npos) return {}; return {std::make_shared(parse_string(s))}; } @@ -943,9 +1019,9 @@ static value node_to_value(const expr_ptr& e) r["factors"] = factors; return r; }, - [](const fdiv_data& d) -> value { + [](const tdiv_data& d) -> value { value r; - r["type"] = "fdiv"; + r["type"] = "tdiv"; r["num"] = node_to_value(d.numerator); r["den"] = node_to_value(d.denominator); return r; @@ -988,11 +1064,11 @@ static expr_ptr node_from_value(const value& v) } return build_mul(coefficient, std::move(factors)); } - else if(type == "fdiv") + else if(type == "tdiv") { auto num = node_from_value(v.at("num")); auto den = node_from_value(v.at("den")); - return make_floor_div(num, den); + return make_trunc_div(num, den); } MIGRAPHX_THROW("Unknown sym::expr node type: " + type); } diff --git a/test/sym_test.cpp b/test/sym_test.cpp index 7da6644b1bf..dc7361caa7a 100644 --- a/test/sym_test.cpp +++ b/test/sym_test.cpp @@ -55,6 +55,11 @@ TEST_CASE(construct_empty) EXPECT(e.to_string().empty()); } +TEST_CASE(construct_empty_var_name_throws) +{ + EXPECT(test::throws([&] { var(""); })); +} + TEST_CASE(add_identity) { auto h = var("h"); @@ -192,34 +197,34 @@ TEST_CASE(mul_distributive) EXPECT(r == 2 * h + 2); } -TEST_CASE(fdiv_identity) +TEST_CASE(tdiv_identity) { auto h = var("h"); EXPECT(h / 1 == h); } -TEST_CASE(fdiv_constant_folding) +TEST_CASE(tdiv_constant_folding) { EXPECT(lit(7) / lit(2) == lit(3)); EXPECT(lit(6) / lit(3) == lit(2)); EXPECT(lit(0) / lit(5) == lit(0)); } -TEST_CASE(fdiv_exact_coefficient_cancel) +TEST_CASE(tdiv_exact_coefficient_cancel) { auto n = var("n"); auto r = (6 * n) / 3; EXPECT(r == 2 * n); } -TEST_CASE(fdiv_non_simplifiable) +TEST_CASE(tdiv_non_simplifiable) { auto h = var("h"); auto r = (h - 1) / 2; EXPECT(r == (h - 1) / 2); } -TEST_CASE(fdiv_division_by_zero) +TEST_CASE(tdiv_division_by_zero) { EXPECT(test::throws([&] { var("h") / 0; })); } @@ -328,19 +333,75 @@ TEST_CASE(add_multi_term_accumulation) EXPECT(r == expected); } -TEST_CASE(fdiv_negative_constant_folding) +TEST_CASE(tdiv_negative_constant_folding) { EXPECT(lit(-7) / lit(2) == lit(-7 / 2)); EXPECT(lit(-6) / lit(3) == lit(-2)); EXPECT(lit(7) / lit(-2) == lit(7 / -2)); } -TEST_CASE(fdiv_large_constants) +TEST_CASE(tdiv_large_constants) { EXPECT(lit(1000000) / lit(1000) == lit(1000)); EXPECT(lit(999999) / lit(1000) == lit(999)); } +TEST_CASE(tdiv_zero_numerator) +{ + auto h = var("h"); + EXPECT(lit(0) / h == lit(0)); + EXPECT(lit(0) / (h + 1) == lit(0)); +} + +TEST_CASE(tdiv_self) +{ + auto h = var("h"); + auto w = var("w"); + EXPECT(h / h == lit(1)); + EXPECT((h + 1) / (h + 1) == lit(1)); + EXPECT((h * w) / (h * w) == lit(1)); +} + +TEST_CASE(tdiv_cancel_symbolic_factor) +{ + auto h = var("h"); + auto w = var("w"); + EXPECT(2 * h / h == lit(2)); + EXPECT(h * w / h == w); + EXPECT(h * w / w == h); + EXPECT(3 * h * w / h == 3 * w); + EXPECT(3 * h * w / (h * w) == lit(3)); + EXPECT(h * 6 * w / (3 * w) == 2 * h); + EXPECT(h * h * w / (h * w) == h); +} + +TEST_CASE(tdiv_cancel_partial) +{ + auto h = var("h"); + auto w = var("w"); + EXPECT(5 * h * w / (2 * h) == 5 * w / lit(2)); + EXPECT(h * h * w / (2 * h) == h * w / lit(2)); +} + +TEST_CASE(tdiv_cancel_cross_factor) +{ + auto h = var("h"); + auto w = var("w"); + auto c = var("c"); + EXPECT(h * w / (h * c) == w / c); + EXPECT(h * w / (h * h) == w / h); +} + +TEST_CASE(tdiv_distribute_over_sum) +{ + auto h = var("h"); + auto w = var("w"); + EXPECT((2 * h + 4) / 2 == h + 2); + EXPECT((6 * h + 3 * w + 9) / 3 == 2 * h + w + 3); + EXPECT((4 * h + 2) / 2 == 2 * h + 1); + EXPECT((2 * h + 3) / 2 != h); +} + // =================================================================== // Tier 2: Equality and hashing // =================================================================== @@ -361,6 +422,28 @@ TEST_CASE(eq_empty) EXPECT(lit(0) != se{}); } +TEST_CASE(hash_consistency) +{ + auto h = var("h"); + auto w = var("w"); + auto n = var("n"); + + auto check = [](const se& a, const se& b) { + EXPECT(a == b); + EXPECT(a.hash() == b.hash()); + }; + + check(h + w, w + h); + check(h * w, w * h); + check(2 * h + 3, 3 + 2 * h); + check(h * w * n, n * h * w); + check((h + 1) * 3, 3 * (h + 1)); + check((h - 1) / 2, (h - 1) / 2); + check(h + 0, h); + check(h * 1, h); + check(lit(5), lit(5)); +} + // =================================================================== // Tier 3: Evaluation and substitution // =================================================================== @@ -368,24 +451,24 @@ TEST_CASE(eq_empty) TEST_CASE(eval_simple) { auto h = var("h"); - EXPECT(h.eval_dim({{h, 26}}) == 26); - EXPECT(lit(42).eval_dim({}) == 42); + EXPECT(h.eval_uint({{h, 26}}) == 26); + EXPECT(lit(42).eval_uint({}) == 42); } TEST_CASE(eval_arithmetic) { auto h = var("h"); - EXPECT((h - 3).eval_dim({{h, 26}}) == 23); - EXPECT((h + 5).eval_dim({{h, 10}}) == 15); - EXPECT((2 * h).eval_dim({{h, 13}}) == 26); + EXPECT((h - 3).eval_uint({{h, 26}}) == 23); + EXPECT((h + 5).eval_uint({{h, 10}}) == 15); + EXPECT((2 * h).eval_uint({{h, 13}}) == 26); } TEST_CASE(eval_compound) { auto h = var("h"); auto e = (h - 3) / 2 + 1; - EXPECT(e.eval_dim({{h, 26}}) == 12); - EXPECT(e.eval_dim({{h, 27}}) == 13); + EXPECT(e.eval_uint({{h, 26}}) == 12); + EXPECT(e.eval_uint({{h, 27}}) == 13); } TEST_CASE(eval_multiple_symbols) @@ -393,37 +476,51 @@ TEST_CASE(eval_multiple_symbols) auto n = var("n"); auto h = var("h"); auto e = n * h; - EXPECT(e.eval_dim({{n, 4}, {h, 26}}) == 104); + EXPECT(e.eval_uint({{n, 4}, {h, 26}}) == 104); } -TEST_CASE(eval_floor_division) +TEST_CASE(eval_trunc_division) { auto h = var("h"); auto e = (h - 1) / 2; - EXPECT(e.eval_dim({{h, 7}}) == 3); - EXPECT(e.eval_dim({{h, 8}}) == 3); - EXPECT(e.eval_dim({{h, 9}}) == 4); + EXPECT(e.eval_uint({{h, 7}}) == 3); + EXPECT(e.eval_uint({{h, 8}}) == 3); + EXPECT(e.eval_uint({{h, 9}}) == 4); } TEST_CASE(eval_unbound_throws) { auto h = var("h"); auto w = var("w"); - EXPECT(test::throws([&] { h.eval_dim({}); })); - EXPECT(test::throws([&] { (h + w).eval_dim({{h, 1}}); })); + EXPECT(test::throws([&] { h.eval_uint({}); })); + EXPECT(test::throws([&] { (h + w).eval_uint({{h, 1}}); })); } TEST_CASE(eval_division_by_zero_throws) { auto h = var("h"); auto d = var("d"); - EXPECT(test::throws([&] { (h / d).eval_dim({{h, 10}, {d, 0}}); })); + EXPECT(test::throws([&] { (h / d).eval_uint({{h, 10}, {d, 0}}); })); } TEST_CASE(eval_integer_expr) { - EXPECT(lit(0).eval_dim({}) == 0); - EXPECT(lit(100).eval_dim({}) == 100); + EXPECT(lit(0).eval_uint({}) == 0); + EXPECT(lit(100).eval_uint({}) == 100); +} + +TEST_CASE(eval_non_symbol_key_throws) +{ + auto h = var("h"); + EXPECT(test::throws([&] { h.eval_uint({{lit(5), 10}}); })); + EXPECT(test::throws([&] { h.eval_uint({{h + 1, 10}}); })); +} + +TEST_CASE(subs_non_symbol_key_throws) +{ + auto h = var("h"); + EXPECT(test::throws([&] { h.subs({{h + 1, lit(5)}}); })); + EXPECT(test::throws([&] { h.subs({{lit(3), lit(5)}}); })); } TEST_CASE(subs_partial) @@ -433,7 +530,7 @@ TEST_CASE(subs_partial) auto e = n * h + 1; auto r = e.subs({{n, lit(4)}}); EXPECT(r == 4 * h + 1); - EXPECT(r.eval_dim({{h, 10}}) == 41); + EXPECT(r.eval_uint({{h, 10}}) == 41); } TEST_CASE(subs_full) @@ -450,7 +547,7 @@ TEST_CASE(subs_none) EXPECT(h.subs({}) == h); } -TEST_CASE(subs_floor_div) +TEST_CASE(subs_trunc_div) { auto h = var("h"); auto e = (h - 1) / 2; @@ -458,6 +555,14 @@ TEST_CASE(subs_floor_div) EXPECT(r == lit(3)); } +TEST_CASE(subs_division_by_zero) +{ + auto h = var("h"); + auto d = var("d"); + auto e = h / d; + EXPECT(test::throws([&] { e.subs({{d, lit(0)}}); })); +} + // eval() and subs()+eval() must agree on a compound expression TEST_CASE(subs_eval_cross_validation) { @@ -466,8 +571,8 @@ TEST_CASE(subs_eval_cross_validation) auto e = (n * h - 3) / 2 + 1; std::unordered_map eval_map = {{n, 4}, {h, 26}}; std::unordered_map subs_map = {{n, lit(4)}, {h, lit(26)}}; - auto via_eval = e.eval_dim(eval_map); - auto via_subs = e.subs(subs_map).eval_dim({}); + auto via_eval = e.eval_uint(eval_map); + auto via_subs = e.subs(subs_map).eval_uint({}); EXPECT(via_eval == via_subs); } @@ -515,9 +620,9 @@ TEST_CASE(subs_compound_expression) // N*H => (W-1)*(2*W+1) = 2*W^2 - W - 1 // N*H + W - 3 => 2*W^2 - 2*W - 4 + W = 2*W^2 - 2 // Verify by evaluating with W=5: (W-1)*(2*W+1) + W - 3 = 4*11 + 5 - 3 = 46, 46/2 = 23 - EXPECT(r.eval_dim({{w, 5}}) == 23); + EXPECT(r.eval_uint({{w, 5}}) == 23); // Also verify the original expression with direct values agrees - EXPECT(e.eval_dim({{n, 4}, {h, 11}, {w, 5}}) == 23); + EXPECT(e.eval_uint({{n, 4}, {h, 11}, {w, 5}}) == 23); } TEST_CASE(eval_compound_product) @@ -525,14 +630,14 @@ TEST_CASE(eval_compound_product) auto h = var("h"); auto w = var("w"); auto e = h * w + 1; - EXPECT(e.eval_dim({{h, 3}, {w, 4}}) == 13); + EXPECT(e.eval_uint({{h, 3}, {w, 4}}) == 13); } TEST_CASE(eval_negative_intermediate) { auto h = var("h"); auto e = (h - 10) * 2 + 20; - EXPECT(e.eval_dim({{h, 3}}) == 6); + EXPECT(e.eval_uint({{h, 3}}) == 6); } // =================================================================== @@ -561,7 +666,7 @@ TEST_CASE(print_mul) EXPECT(r.to_string() == "a*b"); } -TEST_CASE(print_fdiv_parens) +TEST_CASE(print_tdiv_parens) { auto r = (var("h") - 1) / 2; EXPECT(r.to_string() == "(h - 1)/2"); @@ -646,6 +751,49 @@ TEST_CASE(parse_power_operator) EXPECT(parse("(2*h)**3 + 5") == 8 * h * h * h + 5); } +TEST_CASE(parse_empty_string) +{ + EXPECT(parse("").empty()); + EXPECT(parse(" ").empty()); + EXPECT(parse("\t\n").empty()); +} + +TEST_CASE(parse_error_unexpected_char) +{ + EXPECT(test::throws([&] { parse(")"); })); + EXPECT(test::throws([&] { parse("@"); })); +} + +TEST_CASE(parse_error_trailing_chars) +{ + EXPECT(test::throws([&] { parse("1 2"); })); + EXPECT(test::throws([&] { parse("h w"); })); +} + +TEST_CASE(parse_error_unexpected_end) +{ + EXPECT(test::throws([&] { parse("h +"); })); + EXPECT(test::throws([&] { parse("h *"); })); + EXPECT(test::throws([&] { parse("-"); })); +} + +TEST_CASE(parse_error_power_non_integer_exponent) +{ + EXPECT(test::throws([&] { parse("h**h"); })); +} + +TEST_CASE(parse_error_power_negative_exponent) +{ + EXPECT(test::throws([&] { parse("h**-1"); })); +} + +TEST_CASE(parse_double_unary_minus) +{ + auto h = var("h"); + EXPECT(parse("--h") == h); + EXPECT(parse("--5") == lit(5)); +} + TEST_CASE(print_negative_mul_coefficient) { auto r = 0 - 3 * var("h"); @@ -708,7 +856,7 @@ TEST_CASE(edge_deeply_nested) se e = h; for(int i = 0; i < 5; ++i) e = (e - 1) / 2; - EXPECT(e.eval_dim({{h, 255}}) == 7); + EXPECT(e.eval_uint({{h, 255}}) == 7); } TEST_CASE(edge_many_symbols) @@ -719,7 +867,7 @@ TEST_CASE(edge_many_symbols) auto d = var("d"); auto e = var("e"); auto r = a + b + c + d + e; - EXPECT(r.eval_dim({{a, 1}, {b, 2}, {c, 3}, {d, 4}, {e, 5}}) == 15); + EXPECT(r.eval_uint({{a, 1}, {b, 2}, {c, 3}, {d, 4}, {e, 5}}) == 15); } TEST_CASE(edge_neg_one_coefficient) @@ -753,7 +901,7 @@ TEST_CASE(edge_large_coefficients) { auto h = var("h"); auto r = 1000000 * h; - EXPECT(r.eval_dim({{h, 1000000}}) == 1000000000000ULL); + EXPECT(r.eval_uint({{h, 1000000}}) == 1000000000000ULL); } // Incrementally adding H ten times must fold to 11*H @@ -819,13 +967,41 @@ TEST_CASE(serialize_mul) EXPECT(round_trip(e) == e); } -TEST_CASE(serialize_fdiv) +TEST_CASE(serialize_tdiv) { auto h = var("h"); auto e = (h - 1) / 2; EXPECT(round_trip(e) == e); } +TEST_CASE(serialize_negative_integer) +{ + EXPECT(round_trip(lit(-5)) == lit(-5)); + EXPECT(round_trip(lit(-1)) == lit(-1)); +} + +TEST_CASE(serialize_power) +{ + auto h = var("h"); + EXPECT(round_trip(h * h) == h * h); + EXPECT(round_trip(h * h * h) == h * h * h); +} + +TEST_CASE(serialize_negative_coefficient) +{ + auto h = var("h"); + EXPECT(round_trip(0 - 3 * h) == 0 - 3 * h); + EXPECT(round_trip(0 - h) == 0 - h); +} + +TEST_CASE(serialize_tdiv_symbolic_denominator) +{ + auto h = var("h"); + auto w = var("w"); + EXPECT(round_trip(h / w) == h / w); + EXPECT(round_trip((h + 1) / w) == (h + 1) / w); +} + TEST_CASE(serialize_compound) { auto n = var("n"); From 83da044b138379e87fbb7d66a97ffdd161da57da Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 30 Mar 2026 15:55:20 -0700 Subject: [PATCH 22/60] license --- src/targets/gpu/hip_gemm_impl.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/hip_gemm_impl.cpp b/src/targets/gpu/hip_gemm_impl.cpp index 0c2452c3708..33e2b4e5e7f 100644 --- a/src/targets/gpu/hip_gemm_impl.cpp +++ b/src/targets/gpu/hip_gemm_impl.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From d680d558d5ea4f75d522d7b35b8174fcb6367043 Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 30 Mar 2026 16:30:03 -0700 Subject: [PATCH 23/60] reduce complexity --- src/sym.cpp | 117 +++++++++++++++++++++++++++++----------------------- 1 file changed, 66 insertions(+), 51 deletions(-) diff --git a/src/sym.cpp b/src/sym.cpp index e066324c0f0..1a5ff05fa60 100644 --- a/src/sym.cpp +++ b/src/sym.cpp @@ -421,6 +421,63 @@ static expr_ptr make_mul(const expr_ptr& a, const expr_ptr& b) return build_mul(coefficient, std::move(factors)); } +static expr_ptr try_cancel_single(const mul_data& da, const expr_ptr& b) +{ + auto it = da.factors.find(b); + if(it == da.factors.end()) + return nullptr; + factor_map reduced = da.factors; + if(it->second == 1) + reduced.erase(it->first); + else + reduced[it->first] = it->second - 1; + return build_mul(da.coefficient, std::move(reduced)); +} + +static expr_ptr try_div_int_over_add(const add_data& d, int64_t den) +{ + if(d.constant % den != 0) + return nullptr; + bool all_divisible = std::all_of( + d.terms.begin(), d.terms.end(), [&](const auto& p) { return p.second % den == 0; }); + if(not all_divisible) + return nullptr; + term_map divided = d.terms; + for(auto& [base, coeff] : divided) + coeff /= den; + return build_add(d.constant / den, std::move(divided)); +} + +static expr_ptr +try_cancel_factors(const mul_data& da, const mul_data& db, const expr_ptr& a, const expr_ptr& b) +{ + factor_map reduced_num = da.factors; + factor_map reduced_den = db.factors; + for(auto it_den = reduced_den.begin(); it_den != reduced_den.end();) + { + auto it_num = reduced_num.find(it_den->first); + if(it_num == reduced_num.end()) + { + ++it_den; + continue; + } + int64_t cancel = std::min(it_num->second, it_den->second); + it_num->second -= cancel; + it_den->second -= cancel; + if(it_num->second == 0) + reduced_num.erase(it_num); + if(it_den->second == 0) + it_den = reduced_den.erase(it_den); + else + ++it_den; + } + auto new_num = build_mul(da.coefficient, std::move(reduced_num)); + auto new_den = build_mul(db.coefficient, std::move(reduced_den)); + if(not expr_equal(new_num, a) or not expr_equal(new_den, b)) + return make_trunc_div(new_num, new_den); + return nullptr; +} + static expr_ptr make_trunc_div(const expr_ptr& a, const expr_ptr& b) { if(holds(a) and get_integer(a) == 0) @@ -446,68 +503,26 @@ static expr_ptr make_trunc_div(const expr_ptr& a, const expr_ptr& b) } if(holds(a)) { - const auto& d = get_add(a); - bool all_divisible = (d.constant % den == 0); - if(all_divisible) - { - all_divisible = std::all_of(d.terms.begin(), d.terms.end(), [&](const auto& p) { - return p.second % den == 0; - }); - } - if(all_divisible) - { - term_map divided = d.terms; - for(auto& [base, coeff] : divided) - coeff /= den; - return build_add(d.constant / den, std::move(divided)); - } + auto r = try_div_int_over_add(get_add(a), den); + if(r != nullptr) + return r; } } if(holds(a)) { const auto& da = get_mul(a); - if(holds(b)) { - const auto& db = get_mul(b); - factor_map reduced_num = da.factors; - factor_map reduced_den = db.factors; - for(auto it_den = reduced_den.begin(); it_den != reduced_den.end();) - { - auto it_num = reduced_num.find(it_den->first); - if(it_num == reduced_num.end()) - { - ++it_den; - continue; - } - int64_t cancel = std::min(it_num->second, it_den->second); - it_num->second -= cancel; - it_den->second -= cancel; - if(it_num->second == 0) - reduced_num.erase(it_num); - if(it_den->second == 0) - it_den = reduced_den.erase(it_den); - else - ++it_den; - } - auto new_num = build_mul(da.coefficient, std::move(reduced_num)); - auto new_den = build_mul(db.coefficient, std::move(reduced_den)); - if(not expr_equal(new_num, a) or not expr_equal(new_den, b)) - return make_trunc_div(new_num, new_den); + auto r = try_cancel_factors(da, get_mul(b), a, b); + if(r != nullptr) + return r; } else { - auto it = da.factors.find(b); - if(it != da.factors.end()) - { - factor_map reduced = da.factors; - if(it->second == 1) - reduced.erase(it->first); - else - reduced[it->first] = it->second - 1; - return build_mul(da.coefficient, std::move(reduced)); - } + auto r = try_cancel_single(da, b); + if(r != nullptr) + return r; } } From e7ca1d602b5e243ad9a55f92caf79dc73c0cb686 Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 30 Mar 2026 16:32:53 -0700 Subject: [PATCH 24/60] update calls to eval_uint --- src/shape.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/shape.cpp b/src/shape.cpp index 1a52ec01041..e92cc1d483c 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -734,7 +734,7 @@ shape shape::to_static(const std::unordered_map& symbol_ if(dd.is_fixed()) return dd.min; if(dd.sym_expr) - return dd.sym_expr->eval_dim(symbol_map); + return dd.sym_expr->eval_uint(symbol_map); MIGRAPHX_THROW("to_static: non-fixed dimension has no symbolic expression"); }); const auto& ds = this->dyn_strides(); @@ -742,7 +742,7 @@ shape shape::to_static(const std::unordered_map& symbol_ return {type(), static_lens}; std::vector static_strides(ds.size()); std::transform(ds.cbegin(), ds.cend(), static_strides.begin(), [&](const auto& s) { - return s.eval_dim(symbol_map); + return s.eval_uint(symbol_map); }); return {type(), static_lens, static_strides}; } From 54debb56254bcf458ac39e73d46e6b187ebf36b0 Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 30 Mar 2026 17:31:06 -0700 Subject: [PATCH 25/60] clean up test file --- test/shape_test.cpp | 204 +++++++++++++++++++++++--------------------- 1 file changed, 107 insertions(+), 97 deletions(-) diff --git a/test/shape_test.cpp b/test/shape_test.cpp index d6b7e6b3b7d..98b155e8a84 100644 --- a/test/shape_test.cpp +++ b/test/shape_test.cpp @@ -1285,160 +1285,167 @@ TEST_CASE(shape_same_lens_static_dynamic) TEST_CASE(test_dd_symbolic_add_size_t) { - auto N = var("N"); - migraphx::shape::dynamic_dimension dd{1, 8, {4}, N}; + auto n = var("n"); + migraphx::shape::dynamic_dimension dd{1, 8, {4}, n}; dd += 2; EXPECT(dd.min == 3); EXPECT(dd.max == 10); - EXPECT(*dd.sym_expr == N + 2); + EXPECT(*dd.sym_expr == n + 2); } TEST_CASE(test_dd_symbolic_sub_size_t) { - auto N = var("N"); - migraphx::shape::dynamic_dimension dd{3, 8, {4}, N}; + auto n = var("n"); + migraphx::shape::dynamic_dimension dd{3, 8, {4}, n}; dd -= 1; EXPECT(dd.min == 2); EXPECT(dd.max == 7); - EXPECT(*dd.sym_expr == N - 1); + EXPECT(*dd.sym_expr == n - 1); } TEST_CASE(test_dd_symbolic_mul_size_t) { - auto N = var("N"); - migraphx::shape::dynamic_dimension dd{1, 8, {4}, N}; + auto n = var("n"); + migraphx::shape::dynamic_dimension dd{1, 8, {4}, n}; dd *= 3; EXPECT(dd.min == 3); EXPECT(dd.max == 24); - EXPECT(*dd.sym_expr == N * 3); + EXPECT(*dd.sym_expr == n * 3); } TEST_CASE(test_dd_symbolic_div_size_t) { - auto N = var("N"); - migraphx::shape::dynamic_dimension dd{4, 16, {8}, N}; + auto n = var("n"); + migraphx::shape::dynamic_dimension dd{4, 16, {8}, n}; dd /= 2; EXPECT(dd.min == 2); EXPECT(dd.max == 8); - EXPECT(*dd.sym_expr == N / 2); + EXPECT(*dd.sym_expr == n / 2); } TEST_CASE(test_dd_symbolic_add_dd) { - auto N = var("N"), C = var("C"); - migraphx::shape::dynamic_dimension a{1, 8, {}, N}; - migraphx::shape::dynamic_dimension b{2, 4, {}, C}; - auto c = a + b; - EXPECT(c.min == 3); - EXPECT(c.max == 12); - EXPECT(*c.sym_expr == N + C); + auto n = var("n"); + auto c = var("c"); + migraphx::shape::dynamic_dimension a{1, 8, {}, n}; + migraphx::shape::dynamic_dimension b{2, 4, {}, c}; + auto r = a + b; + EXPECT(r.min == 3); + EXPECT(r.max == 12); + EXPECT(*r.sym_expr == n + c); } TEST_CASE(test_dd_symbolic_sub_dd) { - auto N = var("N"), K = var("K"); - migraphx::shape::dynamic_dimension a{4, 16, {}, N}; - migraphx::shape::dynamic_dimension b{1, 4, {}, K}; - auto c = a - b; - EXPECT(c.min == 0); - EXPECT(c.max == 15); - EXPECT(*c.sym_expr == N - K); + auto n = var("n"); + auto k = var("k"); + migraphx::shape::dynamic_dimension a{4, 16, {}, n}; + migraphx::shape::dynamic_dimension b{1, 4, {}, k}; + auto r = a - b; + EXPECT(r.min == 0); + EXPECT(r.max == 15); + EXPECT(*r.sym_expr == n - k); } TEST_CASE(test_dd_symbolic_mul_dd) { - auto N = var("N"), C = var("C"); - migraphx::shape::dynamic_dimension a{1, 8, {}, N}; - migraphx::shape::dynamic_dimension b{2, 4, {}, C}; - auto c = a * b; - EXPECT(c.min == 2); - EXPECT(c.max == 32); - EXPECT(*c.sym_expr == N * C); + auto n = var("n"); + auto c = var("c"); + migraphx::shape::dynamic_dimension a{1, 8, {}, n}; + migraphx::shape::dynamic_dimension b{2, 4, {}, c}; + auto r = a * b; + EXPECT(r.min == 2); + EXPECT(r.max == 32); + EXPECT(*r.sym_expr == n * c); } TEST_CASE(test_dd_symbolic_div_dd) { - auto N = var("N"), K = var("K"); - migraphx::shape::dynamic_dimension a{4, 16, {}, N}; - migraphx::shape::dynamic_dimension b{2, 4, {}, K}; - auto c = a / b; - EXPECT(c.min == 1); - EXPECT(c.max == 8); - EXPECT(*c.sym_expr == N / K); + auto n = var("n"); + auto k = var("k"); + migraphx::shape::dynamic_dimension a{4, 16, {}, n}; + migraphx::shape::dynamic_dimension b{2, 4, {}, k}; + auto r = a / b; + EXPECT(r.min == 1); + EXPECT(r.max == 8); + EXPECT(*r.sym_expr == n / k); } TEST_CASE(test_dd_symbolic_plus_fixed) { - auto N = var("N"); - migraphx::shape::dynamic_dimension a{1, 8, {}, N}; + auto n = var("n"); + migraphx::shape::dynamic_dimension a{1, 8, {}, n}; migraphx::shape::dynamic_dimension b{3, 3}; - auto c = a + b; - EXPECT(c.sym_expr.has_value()); - EXPECT(*c.sym_expr == N + 3); - EXPECT(c.min == 4); - EXPECT(c.max == 11); + auto r = a + b; + EXPECT(r.sym_expr.has_value()); + EXPECT(*r.sym_expr == n + 3); + EXPECT(r.min == 4); + EXPECT(r.max == 11); } TEST_CASE(test_dd_nonfixed_nonsymbolic_plus_symbolic_drops_sym) { - auto C = var("C"); + auto c = var("c"); migraphx::shape::dynamic_dimension a{1, 8, {}}; - migraphx::shape::dynamic_dimension b{2, 4, {}, C}; - auto c = a + b; - EXPECT(not c.sym_expr.has_value()); - EXPECT(c.min == 3); - EXPECT(c.max == 12); + migraphx::shape::dynamic_dimension b{2, 4, {}, c}; + auto r = a + b; + EXPECT(not r.sym_expr.has_value()); + EXPECT(r.min == 3); + EXPECT(r.max == 12); } TEST_CASE(test_dd_nonsymbolic_remains_nonsymbolic) { migraphx::shape::dynamic_dimension a{1, 8, {}}; migraphx::shape::dynamic_dimension b{2, 4, {}}; - auto c = a + b; - EXPECT(not c.sym_expr.has_value()); + auto r = a + b; + EXPECT(not r.sym_expr.has_value()); } TEST_CASE(test_dd_equality_with_sym) { - auto N = var("N"), C = var("C"); - migraphx::shape::dynamic_dimension a{1, 8, {}, N}; - migraphx::shape::dynamic_dimension b{1, 8, {}, N}; - migraphx::shape::dynamic_dimension c{1, 8, {}, C}; + auto n = var("n"); + auto c = var("c"); + migraphx::shape::dynamic_dimension a{1, 8, {}, n}; + migraphx::shape::dynamic_dimension b{1, 8, {}, n}; + migraphx::shape::dynamic_dimension d2{1, 8, {}, c}; migraphx::shape::dynamic_dimension d{1, 8, {}}; EXPECT(a == b); - EXPECT(a != c); + EXPECT(a != d2); EXPECT(a != d); } TEST_CASE(test_symbolic_shape_construction) { - auto N = var("N"); - migraphx::shape s{migraphx::shape::float_type, - {{1, 8, {}, N}, {3, 3}, {224, 224}}, - {N * lit(3) * lit(224), lit(224), lit(1)}}; - EXPECT(s.dynamic()); - EXPECT(s.symbolic()); - EXPECT(s.dyn_dims().size() == 3); - EXPECT(s.dyn_strides().size() == 3); + auto n = var("n"); + migraphx::shape sh{migraphx::shape::float_type, + {{1, 8, {}, n}, {3, 3}, {224, 224}}, + {n * lit(3) * lit(224), lit(224), lit(1)}}; + EXPECT(sh.dynamic()); + EXPECT(sh.symbolic()); + EXPECT(sh.dyn_dims().size() == 3); + EXPECT(sh.dyn_strides().size() == 3); } TEST_CASE(test_symbolic_stride_auto_compute) { - auto N = var("N"), S = var("S"); - migraphx::shape s{migraphx::shape::float_type, {{1, 8, {}, N}, {1, 16, {}, S}, {4, 4}}}; - EXPECT(s.symbolic()); - EXPECT(s.dyn_strides().size() == 3); - EXPECT(s.dyn_strides()[2] == lit(1)); - EXPECT(s.dyn_strides()[1] == lit(4)); - EXPECT(s.dyn_strides()[0] == S * 4); + auto n = var("n"); + auto s = var("s"); + migraphx::shape sh{migraphx::shape::float_type, {{1, 8, {}, n}, {1, 16, {}, s}, {4, 4}}}; + EXPECT(sh.symbolic()); + EXPECT(sh.dyn_strides().size() == 3); + EXPECT(sh.dyn_strides()[2] == lit(1)); + EXPECT(sh.dyn_strides()[1] == lit(4)); + EXPECT(sh.dyn_strides()[0] == s * 4); } TEST_CASE(test_symbolic_to_static) { - auto N = var("N"), S = var("S"); - migraphx::shape s{migraphx::shape::float_type, {{1, 8, {}, N}, {1, 16, {}, S}, {4, 4}}}; - std::unordered_map symbol_map = {{N, 2}, {S, 8}}; - auto s_static = s.to_static(symbol_map); + auto n = var("n"); + auto s = var("s"); + migraphx::shape sh{migraphx::shape::float_type, {{1, 8, {}, n}, {1, 16, {}, s}, {4, 4}}}; + std::unordered_map symbol_map = {{n, 2}, {s, 8}}; + auto s_static = sh.to_static(symbol_map); EXPECT(not s_static.dynamic()); EXPECT(s_static.lens() == std::vector{2, 8, 4}); EXPECT(s_static.strides() == std::vector{32, 4, 1}); @@ -1446,59 +1453,62 @@ TEST_CASE(test_symbolic_to_static) TEST_CASE(test_symbolic_shape_serialize) { - auto N = var("N"), S = var("S"); - migraphx::shape s1{migraphx::shape::float_type, {{1, 8, {}, N}, {1, 16, {}, S}, {4, 4}}}; + auto n = var("n"); + auto s = var("s"); + migraphx::shape s1{migraphx::shape::float_type, {{1, 8, {}, n}, {1, 16, {}, s}, {4, 4}}}; auto v = migraphx::to_value(s1); auto s2 = migraphx::from_value(v); EXPECT(s1 == s2); EXPECT(s2.symbolic()); EXPECT(s2.dyn_strides().size() == 3); - EXPECT(s2.dyn_strides()[0] == S * 4); + EXPECT(s2.dyn_strides()[0] == s * 4); EXPECT(s2.dyn_strides()[2] == lit(1)); } TEST_CASE(test_symbolic_shape_equality) { - auto N = var("N"), C = var("C"); - migraphx::shape s1{migraphx::shape::float_type, {{1, 8, {}, N}, {3, 3}}}; - migraphx::shape s2{migraphx::shape::float_type, {{1, 8, {}, N}, {3, 3}}}; - migraphx::shape s3{migraphx::shape::float_type, {{1, 8, {}, C}, {3, 3}}}; + auto n = var("n"); + auto c = var("c"); + migraphx::shape s1{migraphx::shape::float_type, {{1, 8, {}, n}, {3, 3}}}; + migraphx::shape s2{migraphx::shape::float_type, {{1, 8, {}, n}, {3, 3}}}; + migraphx::shape s3{migraphx::shape::float_type, {{1, 8, {}, c}, {3, 3}}}; EXPECT(s1 == s2); EXPECT(s1 != s3); } TEST_CASE(test_symbolic_shape_print) { - auto N = var("N"), C = var("C"); - auto to_str = [](const migraphx::shape& s) { + auto n = var("n"); + auto c = var("c"); + auto to_str = [](const migraphx::shape& sh) { std::stringstream ss; - ss << s; + ss << sh; return ss.str(); }; - migraphx::shape s1{migraphx::shape::float_type, {{1, 8, {}, N}, {3, 3}, {4, 4}}}; - migraphx::shape s2{migraphx::shape::float_type, {{1, 8, {}, N}, {3, 3}, {4, 4}}}; - migraphx::shape s3{migraphx::shape::float_type, {{1, 8, {}, C}, {3, 3}, {4, 4}}}; + migraphx::shape s1{migraphx::shape::float_type, {{1, 8, {}, n}, {3, 3}, {4, 4}}}; + migraphx::shape s2{migraphx::shape::float_type, {{1, 8, {}, n}, {3, 3}, {4, 4}}}; + migraphx::shape s3{migraphx::shape::float_type, {{1, 8, {}, c}, {3, 3}, {4, 4}}}; EXPECT(to_str(s1) == to_str(s2)); EXPECT(to_str(s1) != to_str(s3)); } TEST_CASE(test_dd_intersection_symbolic) { - auto N = var("N"); - migraphx::shape::dynamic_dimension a{1, 8, {}, N}; + auto n = var("n"); + migraphx::shape::dynamic_dimension a{1, 8, {}, n}; migraphx::shape::dynamic_dimension b{2, 6}; auto result = a.intersection(b); EXPECT(result.has_value()); EXPECT(result->min == 2); EXPECT(result->max == 6); EXPECT(result->sym_expr.has_value()); - EXPECT(*result->sym_expr == N); + EXPECT(*result->sym_expr == n); } TEST_CASE(test_dd_intersection_fixed_drops_sym) { - auto N = var("N"); - migraphx::shape::dynamic_dimension a{1, 8, {}, N}; + auto n = var("n"); + migraphx::shape::dynamic_dimension a{1, 8, {}, n}; migraphx::shape::dynamic_dimension b{4, 4}; auto result = a.intersection(b); EXPECT(result.has_value()); From c7f698c91af210f4958bc41942bad222756fd30b Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 31 Mar 2026 12:40:05 -0700 Subject: [PATCH 26/60] review comments --- src/include/migraphx/functional.hpp | 8 +++ src/include/migraphx/sym.hpp | 27 ++++---- src/sym.cpp | 95 ++++++++--------------------- test/sym_test.cpp | 44 ++----------- 4 files changed, 52 insertions(+), 122 deletions(-) diff --git a/src/include/migraphx/functional.hpp b/src/include/migraphx/functional.hpp index 4050bdf1f06..50b8e36ec61 100644 --- a/src/include/migraphx/functional.hpp +++ b/src/include/migraphx/functional.hpp @@ -261,6 +261,14 @@ void nop(Ts&&...) { } +template +struct overloaded : Ts... +{ + using Ts::operator()...; +}; +template +overloaded(Ts...) -> overloaded; + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/include/migraphx/sym.hpp b/src/include/migraphx/sym.hpp index efa87d8cd21..bdae9426ab1 100644 --- a/src/include/migraphx/sym.hpp +++ b/src/include/migraphx/sym.hpp @@ -40,6 +40,11 @@ struct value; namespace sym { +struct expr; +MIGRAPHX_EXPORT expr var(const std::string& name); +MIGRAPHX_EXPORT expr lit(int64_t n); +MIGRAPHX_EXPORT expr parse(const std::string& s); + struct MIGRAPHX_EXPORT expr { expr(); @@ -60,6 +65,15 @@ struct MIGRAPHX_EXPORT expr MIGRAPHX_EXPORT friend bool operator!=(const expr& a, const expr& b); MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const expr& e); + friend expr operator+(const expr& a, int64_t b) { return a + lit(b); } + friend expr operator+(int64_t a, const expr& b) { return lit(a) + b; } + friend expr operator-(const expr& a, int64_t b) { return a - lit(b); } + friend expr operator-(int64_t a, const expr& b) { return lit(a) - b; } + friend expr operator*(const expr& a, int64_t b) { return a * lit(b); } + friend expr operator*(int64_t a, const expr& b) { return lit(a) * b; } + friend expr operator/(const expr& a, int64_t b) { return a / lit(b); } + friend expr operator/(int64_t a, const expr& b) { return lit(a) / b; } + struct impl; friend expr var(const std::string& name); @@ -71,19 +85,6 @@ struct MIGRAPHX_EXPORT expr std::shared_ptr p; }; -MIGRAPHX_EXPORT expr var(const std::string& name); -MIGRAPHX_EXPORT expr lit(int64_t n); -MIGRAPHX_EXPORT expr parse(const std::string& s); - -inline expr operator+(const expr& a, int64_t b) { return a + lit(b); } -inline expr operator+(int64_t a, const expr& b) { return lit(a) + b; } -inline expr operator-(const expr& a, int64_t b) { return a - lit(b); } -inline expr operator-(int64_t a, const expr& b) { return lit(a) - b; } -inline expr operator*(const expr& a, int64_t b) { return a * lit(b); } -inline expr operator*(int64_t a, const expr& b) { return lit(a) * b; } -inline expr operator/(const expr& a, int64_t b) { return a / lit(b); } -inline expr operator/(int64_t a, const expr& b) { return lit(a) / b; } - } // namespace sym } // namespace MIGRAPHX_INLINE_NS diff --git a/src/sym.cpp b/src/sym.cpp index 1a5ff05fa60..00d9480e810 100644 --- a/src/sym.cpp +++ b/src/sym.cpp @@ -23,6 +23,7 @@ */ #include +#include #include #include @@ -83,14 +84,6 @@ struct tdiv_data using expr_data = std::variant; -template -struct overloaded : Ts... -{ - using Ts::operator()...; -}; -template -overloaded(Ts...) -> overloaded; - struct expr_node { expr_data data; @@ -130,7 +123,7 @@ static std::size_t hash_ordered_map(const Map& m) static std::size_t compute_hash(const expr_data& d) { - std::size_t h = std::hash{}(static_cast(d.index())); + std::size_t h = std::hash{}(d.index()); return std::visit( overloaded{ [&](const integer_data& p) { return hash_combine(h, std::hash{}(p.value)); }, @@ -674,13 +667,13 @@ static std::string print_mul(const mul_data& d, int parent_prec) } for(const auto& [base, exp] : d.factors) { - if(not first) - os << "*"; - if(exp == 1) + for(int64_t i = 0; i < exp; ++i) + { + if(not first) + os << "*"; os << print_expr(base, prec_mul + 1); - else - os << print_expr(base, prec_mul + 1) << "**" << exp; - first = false; + first = false; + } } std::string raw = os.str(); if(parent_prec > prec_mul) @@ -741,19 +734,6 @@ static expr_ptr parse_primary(const char*& p) name += *p; ++p; } - if(name == "floor") - { - skip_ws(p); - if(*p != '(') - MIGRAPHX_THROW("symbolic parser: expected '(' after 'floor'"); - ++p; - auto inner = parse_expr(p); - skip_ws(p); - if(*p != ')') - MIGRAPHX_THROW("symbolic parser: expected ')' after floor argument"); - ++p; - return inner; - } return make_symbol(name); } if(*p == '(') @@ -780,42 +760,21 @@ static expr_ptr parse_unary(const char*& p) return parse_primary(p); } -static expr_ptr parse_power(const char*& p) -{ - auto base = parse_unary(p); - skip_ws(p); - if(*p == '*' and *(p + 1) == '*') - { - p += 2; - auto exp_node = parse_unary(p); - if(not holds(exp_node)) - MIGRAPHX_THROW("symbolic parser: ** exponent must be an integer literal"); - auto exp = get_integer(exp_node); - if(exp < 0) - MIGRAPHX_THROW("symbolic parser: ** exponent must be non-negative"); - expr_ptr result = make_integer(1); - for(int64_t i = 0; i < exp; ++i) - result = make_mul(result, base); - return result; - } - return base; -} - static expr_ptr parse_term(const char*& p) { - auto left = parse_power(p); + auto left = parse_unary(p); for(;;) { skip_ws(p); if(*p == '*') { ++p; - left = make_mul(left, parse_power(p)); + left = make_mul(left, parse_unary(p)); } else if(*p == '/') { ++p; - left = make_trunc_div(left, parse_power(p)); + left = make_trunc_div(left, parse_unary(p)); } else break; @@ -898,12 +857,12 @@ std::size_t expr::eval_uint(const std::unordered_map& symbol_ { if(k.empty() or not holds(k.p->node)) MIGRAPHX_THROW("sym::expr::eval_uint: map key '" + k.to_string() + "' is not a symbol"); - bindings[k.p->node] = static_cast(v); + bindings[k.p->node] = v; } auto v = eval_direct(p->node, bindings); if(v < 0) MIGRAPHX_THROW("sym::expr::eval_uint: expression evaluated to negative value"); - return static_cast(v); + return v; } expr expr::subs(const std::unordered_map& symbol_map) const @@ -915,45 +874,39 @@ expr expr::subs(const std::unordered_map& symbol_map) const { if(k.empty() or not holds(k.p->node)) MIGRAPHX_THROW("sym::expr::subs: map key '" + k.to_string() + "' is not a symbol"); - bindings[k.p->node] = v.p ? v.p->node : make_integer(0); + if(v.empty()) + MIGRAPHX_THROW("sym::expr::subs: substitution value must not be empty"); + bindings[k.p->node] = v.p->node; } return {std::make_shared(substitute(p->node, bindings))}; } expr operator+(const expr& a, const expr& b) { - if(a.empty() and b.empty()) + if(a.empty() or b.empty()) return {}; - auto ea = a.p ? a.p->node : make_integer(0); - auto eb = b.p ? b.p->node : make_integer(0); - return {std::make_shared(make_add(ea, eb))}; + return {std::make_shared(make_add(a.p->node, b.p->node))}; } expr operator-(const expr& a, const expr& b) { - if(a.empty() and b.empty()) + if(a.empty() or b.empty()) return {}; - auto ea = a.p ? a.p->node : make_integer(0); - auto eb = b.p ? b.p->node : make_integer(0); - return {std::make_shared(make_sub(ea, eb))}; + return {std::make_shared(make_sub(a.p->node, b.p->node))}; } expr operator*(const expr& a, const expr& b) { - if(a.empty() and b.empty()) + if(a.empty() or b.empty()) return {}; - auto ea = a.p ? a.p->node : make_integer(0); - auto eb = b.p ? b.p->node : make_integer(0); - return {std::make_shared(make_mul(ea, eb))}; + return {std::make_shared(make_mul(a.p->node, b.p->node))}; } expr operator/(const expr& a, const expr& b) { - if(a.empty() and b.empty()) + if(a.empty() or b.empty()) return {}; - auto ea = a.p ? a.p->node : make_integer(0); - auto eb = b.p ? b.p->node : make_integer(0); - return {std::make_shared(make_trunc_div(ea, eb))}; + return {std::make_shared(make_trunc_div(a.p->node, b.p->node))}; } bool operator==(const expr& a, const expr& b) diff --git a/test/sym_test.cpp b/test/sym_test.cpp index dc7361caa7a..f822b0e773f 100644 --- a/test/sym_test.cpp +++ b/test/sym_test.cpp @@ -723,34 +723,12 @@ TEST_CASE(parse_unary_minus) EXPECT(parse("-(h + 1)") == -1 * h - 1); } -TEST_CASE(parse_floor_backward_compat) -{ - auto a = parse("floor((h-1)/2)"); - auto b = parse("(h-1)/2"); - EXPECT(a == b); - - auto c = parse("floor((h-1)/2) + 1"); - auto d = (var("h") - 1) / 2 + 1; - EXPECT(c == d); -} - TEST_CASE(parse_whitespace_tolerance) { EXPECT(parse(" h + 1 ") == parse("h + 1")); EXPECT(parse("h+1") == parse("h + 1")); } -TEST_CASE(parse_power_operator) -{ - auto h = var("h"); - EXPECT(parse("h**2") == h * h); - EXPECT(parse("h**3") == h * h * h); - EXPECT(parse("h**1") == h); - EXPECT(parse("h**0") == lit(1)); - EXPECT(parse("2*h**2 + 1") == 2 * h * h + 1); - EXPECT(parse("(2*h)**3 + 5") == 8 * h * h * h + 5); -} - TEST_CASE(parse_empty_string) { EXPECT(parse("").empty()); @@ -777,16 +755,6 @@ TEST_CASE(parse_error_unexpected_end) EXPECT(test::throws([&] { parse("-"); })); } -TEST_CASE(parse_error_power_non_integer_exponent) -{ - EXPECT(test::throws([&] { parse("h**h"); })); -} - -TEST_CASE(parse_error_power_negative_exponent) -{ - EXPECT(test::throws([&] { parse("h**-1"); })); -} - TEST_CASE(parse_double_unary_minus) { auto h = var("h"); @@ -889,12 +857,12 @@ TEST_CASE(edge_empty_operations) TEST_CASE(edge_empty_with_nonempty) { se empty; - auto h = var("h"); - auto r1 = h + empty; - EXPECT(not r1.empty()); - - auto r2 = empty + h; - EXPECT(not r2.empty()); + auto h = var("h"); + EXPECT((h + empty).empty()); + EXPECT((empty + h).empty()); + EXPECT((h - empty).empty()); + EXPECT((empty * h).empty()); + EXPECT((h / empty).empty()); } TEST_CASE(edge_large_coefficients) From 293b5d87199a4ed2e21918759dc64133d0b09cf2 Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 31 Mar 2026 12:45:18 -0700 Subject: [PATCH 27/60] merge and tidy --- src/shape.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/shape.cpp b/src/shape.cpp index e92cc1d483c..8c6a25daf28 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -686,7 +686,7 @@ shape shape::to_dynamic() const dstrides.reserve(ndim()); for(auto s : strides()) dstrides.push_back(sym::lit(s)); - return shape(type(), std::move(dims), std::move(dstrides)); + return {type(), std::move(dims), std::move(dstrides)}; } shape shape::to_static(std::size_t x) const From cc521e4a1016551c8934bccf820bb925bd22965b Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 31 Mar 2026 12:46:05 -0700 Subject: [PATCH 28/60] license --- src/include/migraphx/functional.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/migraphx/functional.hpp b/src/include/migraphx/functional.hpp index 50b8e36ec61..a8ed0a34b15 100644 --- a/src/include/migraphx/functional.hpp +++ b/src/include/migraphx/functional.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From 47ec30af6dd410c5a44a85bad804cec815463b2a Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 1 Apr 2026 09:35:34 -0700 Subject: [PATCH 29/60] fix style --- src/shape.cpp | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/shape.cpp b/src/shape.cpp index 8c6a25daf28..90f2f59952f 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -680,12 +680,14 @@ shape shape::to_dynamic() const } std::vector dims; dims.reserve(ndim()); - for(auto len : lens()) - dims.push_back(dynamic_dimension{len, len}); + std::transform(lens().begin(), lens().end(), std::back_inserter(dims), [](auto len) { + return dynamic_dimension{len, len}; + }); std::vector dstrides; dstrides.reserve(ndim()); - for(auto s : strides()) - dstrides.push_back(sym::lit(s)); + std::transform(strides().begin(), strides().end(), std::back_inserter(dstrides), [](auto s) { + return sym::lit(s); + }); return {type(), std::move(dims), std::move(dstrides)}; } @@ -1224,8 +1226,10 @@ void migraphx_from_value(const value& v, shape& s) auto v_ds = v.at("dyn_strides"); std::vector dstrides; dstrides.reserve(v_ds.size()); - for(const auto& x : v_ds) - dstrides.push_back(from_value(x)); + std::transform(v_ds.begin(), + v_ds.end(), + std::back_inserter(dstrides), + [](const auto& x) { return from_value(x); }); s = shape(shape::parse_type(t), std::move(dyn_dims), std::move(dstrides)); } else From 59a7ef419ed046dc8ff462a48d6c511e486a042a Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 1 Apr 2026 11:42:01 -0700 Subject: [PATCH 30/60] normalize fixed dynamic dim representation --- src/include/migraphx/shape.hpp | 12 ++++- src/shape.cpp | 94 ++++++++-------------------------- test/shape_test.cpp | 10 ++-- 3 files changed, 37 insertions(+), 79 deletions(-) diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index 216a0ad1973..19ef85e1da3 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -103,10 +103,14 @@ struct MIGRAPHX_EXPORT shape optional sym_expr; dynamic_dimension() = default; - dynamic_dimension(std::size_t min_v, std::size_t max_v) : min(min_v), max(max_v) {} + dynamic_dimension(std::size_t min_v, std::size_t max_v) : min(min_v), max(max_v) + { + normalize_sym(); + } dynamic_dimension(std::size_t min_v, std::size_t max_v, std::set opt) : min(min_v), max(max_v), optimals(std::move(opt)) { + normalize_sym(); } dynamic_dimension(std::size_t min_v, std::size_t max_v, @@ -114,6 +118,7 @@ struct MIGRAPHX_EXPORT shape optional s) : min(min_v), max(max_v), optimals(std::move(opt)), sym_expr(std::move(s)) { + normalize_sym(); } template @@ -127,6 +132,11 @@ struct MIGRAPHX_EXPORT shape bool is_fixed() const; bool is_symbolic() const { return sym_expr.has_value(); } + void normalize_sym() + { + if(is_fixed() and not is_symbolic()) + sym_expr = sym::lit(min); + } bool has_optimal() const; /** diff --git a/src/shape.cpp b/src/shape.cpp index 90f2f59952f..4cc791ef909 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -79,8 +79,9 @@ struct shape_impl shape_impl(shape::type_t t, std::vector dims) : m_type(t), m_dyn_dims(std::move(dims)) { - if(std::any_of( - m_dyn_dims.begin(), m_dyn_dims.end(), [](const auto& d) { return d.is_symbolic(); })) + if(not m_dyn_dims.empty() and std::all_of(m_dyn_dims.begin(), + m_dyn_dims.end(), + [](const auto& d) { return d.is_symbolic(); })) calculate_dyn_strides(); } @@ -777,12 +778,7 @@ const std::vector& shape::dyn_dims() const return impl->m_dyn_dims; } -bool shape::symbolic() const -{ - return std::any_of(impl->m_dyn_dims.begin(), impl->m_dyn_dims.end(), [](const auto& d) { - return d.is_symbolic(); - }); -} +bool shape::symbolic() const { return not impl->m_dyn_strides.empty(); } const std::vector& shape::dyn_strides() const { return impl->m_dyn_strides; } @@ -804,67 +800,22 @@ bool shape::dynamic_dimension::has_optimal() const { return not optimals.empty() shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const std::size_t& x) { - if(is_symbolic()) - sym_expr = *sym_expr + sym::lit(x); - this->min += x; - this->max += x; - std::set new_optimals; - std::transform(this->optimals.begin(), - this->optimals.end(), - std::inserter(new_optimals, new_optimals.begin()), - [&x](const auto& opt) { return (opt + x); }); - this->optimals = new_optimals; - return *this; + return *this += dynamic_dimension{x, x}; } shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t& x) { - if(is_symbolic()) - sym_expr = *sym_expr - sym::lit(x); - assert(this->min >= x); - assert(this->max >= x); - this->min -= x; - this->max -= x; - std::set new_optimals; - std::transform(this->optimals.begin(), - this->optimals.end(), - std::inserter(new_optimals, new_optimals.begin()), - [&x](const auto& opt) { - assert(opt >= x); - return (opt - x); - }); - this->optimals = new_optimals; - return *this; + return *this -= dynamic_dimension{x, x}; } shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const std::size_t& x) { - if(is_symbolic()) - sym_expr = *sym_expr * sym::lit(x); - this->min *= x; - this->max *= x; - std::set new_optimals; - std::transform(this->optimals.begin(), - this->optimals.end(), - std::inserter(new_optimals, new_optimals.begin()), - [&x](const auto& opt) { return (opt * x); }); - this->optimals = new_optimals; - return *this; + return *this *= dynamic_dimension{x, x}; } shape::dynamic_dimension& shape::dynamic_dimension::operator/=(const std::size_t& x) { - if(is_symbolic()) - sym_expr = *sym_expr / sym::lit(x); - this->min = (x == 0) ? 0 : this->min / x; - this->max = (x == 0) ? std::numeric_limits::max() : this->max / x; - std::set new_optimals; - std::transform(this->optimals.begin(), - this->optimals.end(), - std::inserter(new_optimals, new_optimals.begin()), - [&x](const auto& opt) { return (x == 0) ? std::size_t{0} : opt / x; }); - this->optimals = new_optimals; - return *this; + return *this /= dynamic_dimension{x, x}; } bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y) @@ -935,19 +886,10 @@ shape::dynamic_dimension operator/(const shape::dynamic_dimension& x, const std: return dd /= y; } -static optional get_sym(const shape::dynamic_dimension& dd) -{ - if(dd.sym_expr) - return dd.sym_expr; - if(dd.is_fixed()) - return sym::lit(dd.min); - return nullopt; -} - shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const shape::dynamic_dimension& x) { - auto lhs_sym = get_sym(*this); - auto rhs_sym = get_sym(x); + auto lhs_sym = sym_expr; + auto rhs_sym = x.sym_expr; min = min + x.min; max = (max > std::numeric_limits::max() - x.max) ? std::numeric_limits::max() @@ -966,13 +908,14 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const shape::dyna optimals.clear(); } sym_expr = (lhs_sym and rhs_sym) ? optional(*lhs_sym + *rhs_sym) : nullopt; + normalize_sym(); return *this; } shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const shape::dynamic_dimension& x) { - auto lhs_sym = get_sym(*this); - auto rhs_sym = get_sym(x); + auto lhs_sym = sym_expr; + auto rhs_sym = x.sym_expr; min = (min > x.max) ? min - x.max : 0; max = (max > x.min) ? max - x.min : 0; if(x.is_fixed()) @@ -989,13 +932,14 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const shape::dyna optimals.clear(); } sym_expr = (lhs_sym and rhs_sym) ? optional(*lhs_sym - *rhs_sym) : nullopt; + normalize_sym(); return *this; } shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const shape::dynamic_dimension& x) { - auto lhs_sym = get_sym(*this); - auto rhs_sym = get_sym(x); + auto lhs_sym = sym_expr; + auto rhs_sym = x.sym_expr; min = min * x.min; max = (max > std::numeric_limits::max() / (x.max == 0 ? 1 : x.max)) ? std::numeric_limits::max() @@ -1014,13 +958,14 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const shape::dyna optimals.clear(); } sym_expr = (lhs_sym and rhs_sym) ? optional(*lhs_sym * *rhs_sym) : nullopt; + normalize_sym(); return *this; } shape::dynamic_dimension& shape::dynamic_dimension::operator/=(const shape::dynamic_dimension& x) { - auto lhs_sym = get_sym(*this); - auto rhs_sym = get_sym(x); + auto lhs_sym = sym_expr; + auto rhs_sym = x.sym_expr; min = (x.max == 0) ? 0 : min / x.max; max = (x.min == 0) ? std::numeric_limits::max() : max / x.min; if(x.is_fixed()) @@ -1037,6 +982,7 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator/=(const shape::dyna optimals.clear(); } sym_expr = (lhs_sym and rhs_sym) ? optional(*lhs_sym / *rhs_sym) : nullopt; + normalize_sym(); return *this; } diff --git a/test/shape_test.cpp b/test/shape_test.cpp index 98b155e8a84..782a7dccba7 100644 --- a/test/shape_test.cpp +++ b/test/shape_test.cpp @@ -467,8 +467,9 @@ TEST_CASE(test_shape_subshapes_to_dynamic) migraphx::shape s1 = s0.to_dynamic(); std::vector sub_shapes1 = {}; sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}}); - sub_shapes1.push_back(migraphx::shape{ - migraphx::shape::float_type, {{3, 3}, {4, 4}, {5, 5}}, {lit(20), lit(5), lit(1)}}); + sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, + {{3, 3}, {4, 4}, {5, 5}}, + {lit(20), lit(5), lit(1)}}); migraphx::shape s2{sub_shapes1}; EXPECT(s1 == s2); } @@ -1505,7 +1506,7 @@ TEST_CASE(test_dd_intersection_symbolic) EXPECT(*result->sym_expr == n); } -TEST_CASE(test_dd_intersection_fixed_drops_sym) +TEST_CASE(test_dd_intersection_fixed_gets_lit) { auto n = var("n"); migraphx::shape::dynamic_dimension a{1, 8, {}, n}; @@ -1514,7 +1515,8 @@ TEST_CASE(test_dd_intersection_fixed_drops_sym) EXPECT(result.has_value()); EXPECT(result->min == 4); EXPECT(result->max == 4); - EXPECT(not result->sym_expr.has_value()); + EXPECT(result->sym_expr.has_value()); + EXPECT(*result->sym_expr == lit(4)); } int main(int argc, const char* argv[]) { test::run(argc, argv); } From 95894db1b2f20eced803f5467954edeac324f6a6 Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 1 Apr 2026 11:51:12 -0700 Subject: [PATCH 31/60] fmt --- test/shape_test.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/shape_test.cpp b/test/shape_test.cpp index 782a7dccba7..be5983dc3f7 100644 --- a/test/shape_test.cpp +++ b/test/shape_test.cpp @@ -467,9 +467,8 @@ TEST_CASE(test_shape_subshapes_to_dynamic) migraphx::shape s1 = s0.to_dynamic(); std::vector sub_shapes1 = {}; sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}}); - sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, - {{3, 3}, {4, 4}, {5, 5}}, - {lit(20), lit(5), lit(1)}}); + sub_shapes1.push_back(migraphx::shape{ + migraphx::shape::float_type, {{3, 3}, {4, 4}, {5, 5}}, {lit(20), lit(5), lit(1)}}); migraphx::shape s2{sub_shapes1}; EXPECT(s1 == s2); } From be87f713e1da6dfe33ab8f8bbf6132bb8e4d56a1 Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 1 Apr 2026 17:08:34 -0700 Subject: [PATCH 32/60] fix serialization and normalization --- src/shape.cpp | 4 ++++ src/sym.cpp | 10 +++++----- test/serialize_program.cpp | 21 +++++++++++++++++++++ 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/shape.cpp b/src/shape.cpp index 4cc791ef909..1063b6b98fd 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -113,6 +113,10 @@ struct shape_impl m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i], optimals_list[i]}); } } + if(not m_dyn_dims.empty() and std::all_of(m_dyn_dims.begin(), + m_dyn_dims.end(), + [](const auto& d) { return d.is_symbolic(); })) + calculate_dyn_strides(); } shape_impl(const std::vector& subs) : m_type(shape::tuple_type), m_shapes(subs) {} diff --git a/src/sym.cpp b/src/sym.cpp index 00d9480e810..fd552980727 100644 --- a/src/sym.cpp +++ b/src/sym.cpp @@ -1002,7 +1002,7 @@ static expr_ptr node_from_value(const value& v) const auto& type = v.at("type").get_string(); if(type == "int") { - return make_integer(v.at("value").get_int64()); + return make_integer(v.at("value").to()); } else if(type == "sym") { @@ -1010,24 +1010,24 @@ static expr_ptr node_from_value(const value& v) } else if(type == "add") { - auto constant = v.at("constant").get_int64(); + auto constant = v.at("constant").to(); term_map terms; for(const auto& t : v.at("terms")) { auto term = node_from_value(t.at("expr")); - auto coeff = t.at("coeff").get_int64(); + auto coeff = t.at("coeff").to(); terms[term] = coeff; } return build_add(constant, std::move(terms)); } else if(type == "mul") { - auto coefficient = v.at("coeff").get_int64(); + auto coefficient = v.at("coeff").to(); factor_map factors; for(const auto& f : v.at("factors")) { auto base = node_from_value(f.at("expr")); - auto exp = f.at("exp").get_int64(); + auto exp = f.at("exp").to(); factors[base] = exp; } return build_mul(coefficient, std::move(factors)); diff --git a/test/serialize_program.cpp b/test/serialize_program.cpp index 951cd40c790..d0a5311d417 100644 --- a/test/serialize_program.cpp +++ b/test/serialize_program.cpp @@ -26,6 +26,7 @@ #include #include "test.hpp" #include +#include #include @@ -138,4 +139,24 @@ TEST_CASE(program_with_module) EXPECT(p1.sort() == p2.sort()); } +TEST_CASE(symbolic_shape_msgpack_roundtrip) +{ + using migraphx::shape; + using dd = shape::dynamic_dimension; + auto n = migraphx::sym::var("n"); + + migraphx::program p; + auto* mm = p.get_main_module(); + shape s{shape::float_type, {dd{1, 8, {}, n}, {3, 3}, {4, 4}}}; + auto x = mm->add_parameter("x", s); + auto r = mm->add_instruction(migraphx::make_op("relu"), x); + mm->add_return({r}); + + migraphx::file_options options; + options.format = "msgpack"; + std::vector buffer = migraphx::save_buffer(p, options); + migraphx::program p2 = migraphx::load_buffer(buffer, options); + EXPECT(p.sort() == p2.sort()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From 78f06c715650bc59541e3d9f9f28dca77aca4b87 Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 1 Apr 2026 17:14:57 -0700 Subject: [PATCH 33/60] license --- test/serialize_program.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/serialize_program.cpp b/test/serialize_program.cpp index d0a5311d417..50071b028a7 100644 --- a/test/serialize_program.cpp +++ b/test/serialize_program.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From 1a68619254427490ee35b993e6398585d14d4055 Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 7 Apr 2026 14:03:12 -0700 Subject: [PATCH 34/60] address review comments --- src/include/migraphx/shape.hpp | 19 ++++---- src/shape.cpp | 80 +++++++++++++++++++++------------- 2 files changed, 59 insertions(+), 40 deletions(-) diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index 19ef85e1da3..0500ca3786f 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -141,22 +141,21 @@ struct MIGRAPHX_EXPORT shape /** * Return a dynamic_dimension with the intersection of two dynamic_dimension ranges if - * possible. Preserves the symbolic expression only when the result is still dynamic. + * possible. When both dimensions are symbolic, they are compatible only if they + * share the same symbolic expression. */ std::optional intersection(const dynamic_dimension& other) const { + if(this->is_symbolic() and other.is_symbolic()) + { + if(this->sym_expr == other.sym_expr) + return *this; + return nullopt; + } auto left = std::max(this->min, other.min); auto right = std::min(this->max, other.max); if(left <= right) - { - optional s; - if(left != right) - { - s = (this->sym_expr.has_value() and not this->is_fixed()) ? this->sym_expr - : other.sym_expr; - } - return dynamic_dimension{left, right, {}, s}; - } + return dynamic_dimension{left, right}; return nullopt; } diff --git a/src/shape.cpp b/src/shape.cpp index 1063b6b98fd..202fd18dcb7 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -90,6 +90,7 @@ struct shape_impl std::vector dstrides) : m_type(t), m_dyn_dims(std::move(dims)), m_dyn_strides(std::move(dstrides)) { + assert(m_dyn_strides.size() == m_dyn_dims.size()); } shape_impl(shape::type_t t, @@ -130,35 +131,42 @@ struct shape_impl std::vector m_dyn_dims = {}; std::vector m_dyn_strides = {}; - void calculate_dyn_strides() + std::vector sym_dim_exprs() const { - m_dyn_strides.clear(); - if(m_dyn_dims.empty()) - return; - m_dyn_strides.resize(m_dyn_dims.size()); - m_dyn_strides.back() = sym::lit(1); - std::transform(m_dyn_dims.rbegin(), - m_dyn_dims.rend() - 1, - m_dyn_strides.rbegin(), - m_dyn_strides.rbegin() + 1, - [](const auto& dd, const auto& stride) { - return dd.sym_expr.value_or(sym::lit(dd.min)) * stride; - }); + std::vector result(m_dyn_dims.size()); + std::transform(m_dyn_dims.begin(), m_dyn_dims.end(), result.begin(), [](const auto& dd) { + return dd.sym_expr.value_or(sym::expr{}); + }); + return result; } - void calculate_strides() + template + static T make_identity(int64_t n) { - m_strides.clear(); - m_strides.resize(m_lens.size(), 0); - if(m_strides.empty()) - return; - m_strides.back() = 1; - std::partial_sum(m_lens.rbegin(), - m_lens.rend() - 1, - m_strides.rbegin() + 1, - std::multiplies()); + if constexpr(std::is_same_v) + return sym::lit(n); + else + return T(n); } + template + static std::vector compute_strides(const std::vector& dims) + { + std::vector strides(dims.size()); + if(strides.empty()) + return strides; + strides.back() = make_identity(1); + std::partial_sum(dims.rbegin(), + dims.rend() - 1, + strides.rbegin() + 1, + [](const auto& a, const auto& b) { return b * a; }); + return strides; + } + + void calculate_dyn_strides() { m_dyn_strides = compute_strides(sym_dim_exprs()); } + + void calculate_strides() { m_strides = compute_strides(m_lens); } + std::size_t element_space() const { if(not m_dyn_dims.empty()) @@ -782,7 +790,13 @@ const std::vector& shape::dyn_dims() const return impl->m_dyn_dims; } -bool shape::symbolic() const { return not impl->m_dyn_strides.empty(); } +bool shape::symbolic() const +{ + return not impl->m_dyn_dims.empty() and + std::all_of(impl->m_dyn_dims.begin(), impl->m_dyn_dims.end(), [](const auto& dd) { + return dd.is_symbolic(); + }); +} const std::vector& shape::dyn_strides() const { return impl->m_dyn_strides; } @@ -819,6 +833,7 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const std::size_t shape::dynamic_dimension& shape::dynamic_dimension::operator/=(const std::size_t& x) { + assert(x != 0); return *this /= dynamic_dimension{x, x}; } @@ -942,12 +957,17 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const shape::dyna shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const shape::dynamic_dimension& x) { - auto lhs_sym = sym_expr; - auto rhs_sym = x.sym_expr; - min = min * x.min; - max = (max > std::numeric_limits::max() / (x.max == 0 ? 1 : x.max)) - ? std::numeric_limits::max() - : max * x.max; + auto lhs_sym = sym_expr; + auto rhs_sym = x.sym_expr; + min = min * x.min; + auto safe_mul = [](std::size_t a, std::size_t b) -> std::size_t { + if(b == 0) + return 0; + if(a > std::numeric_limits::max() / b) + return std::numeric_limits::max(); + return a * b; + }; + max = safe_mul(max, x.max); if(x.is_fixed()) { std::set new_optimals; From d6c8d49e73b25206735031e18fc068f12aed6f6d Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 7 Apr 2026 15:36:44 -0700 Subject: [PATCH 35/60] update tests for cleaned up intersection logic --- test/shape_test.cpp | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/test/shape_test.cpp b/test/shape_test.cpp index be5983dc3f7..abed97a30db 100644 --- a/test/shape_test.cpp +++ b/test/shape_test.cpp @@ -1492,30 +1492,36 @@ TEST_CASE(test_symbolic_shape_print) EXPECT(to_str(s1) != to_str(s3)); } -TEST_CASE(test_dd_intersection_symbolic) +TEST_CASE(dd_intersection_symbolic_with_range) { auto n = var("n"); - migraphx::shape::dynamic_dimension a{1, 8, {}, n}; + migraphx::shape::dynamic_dimension a{1, 32, {}, n}; migraphx::shape::dynamic_dimension b{2, 6}; auto result = a.intersection(b); EXPECT(result.has_value()); EXPECT(result->min == 2); EXPECT(result->max == 6); - EXPECT(result->sym_expr.has_value()); - EXPECT(*result->sym_expr == n); + EXPECT(not result->sym_expr.has_value()); } -TEST_CASE(test_dd_intersection_fixed_gets_lit) +TEST_CASE(dd_intersection_symbolic_same_symbol) { auto n = var("n"); - migraphx::shape::dynamic_dimension a{1, 8, {}, n}; - migraphx::shape::dynamic_dimension b{4, 4}; + migraphx::shape::dynamic_dimension a{1, 32, {}, n}; + migraphx::shape::dynamic_dimension b{1, 32, {}, n}; auto result = a.intersection(b); EXPECT(result.has_value()); - EXPECT(result->min == 4); - EXPECT(result->max == 4); - EXPECT(result->sym_expr.has_value()); - EXPECT(*result->sym_expr == lit(4)); + EXPECT(*result == a); +} + +TEST_CASE(dd_intersection_symbolic_different_symbol) +{ + auto n = var("n"); + auto m = var("m"); + migraphx::shape::dynamic_dimension a{1, 32, {}, n}; + migraphx::shape::dynamic_dimension b{1, 16, {}, m}; + auto result = a.intersection(b); + EXPECT(not result.has_value()); } int main(int argc, const char* argv[]) { test::run(argc, argv); } From 378e3d7be73d15d1ffa3214d8dff35e0b1d1099d Mon Sep 17 00:00:00 2001 From: Shiv Date: Thu, 9 Apr 2026 10:18:58 -0700 Subject: [PATCH 36/60] review feedback updates --- src/include/migraphx/shape.hpp | 48 ++---- src/shape.cpp | 293 +++++++++++++++------------------ 2 files changed, 149 insertions(+), 192 deletions(-) diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index 0500ca3786f..d2aa9b330c5 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -172,37 +172,23 @@ struct MIGRAPHX_EXPORT shape MIGRAPHX_EXPORT friend bool operator!=(const dynamic_dimension& x, const std::size_t& y); MIGRAPHX_EXPORT friend bool operator!=(const std::size_t& x, const dynamic_dimension& y); - // add, subtract, multiply, divide fixed std::size_t dimension - dynamic_dimension& operator+=(const std::size_t& x); - dynamic_dimension& operator-=(const std::size_t& x); - dynamic_dimension& operator*=(const std::size_t& x); - dynamic_dimension& operator/=(const std::size_t& x); - MIGRAPHX_EXPORT friend dynamic_dimension operator+(const dynamic_dimension& x, - const std::size_t& y); - MIGRAPHX_EXPORT friend dynamic_dimension operator+(const std::size_t& x, - const dynamic_dimension& y); - MIGRAPHX_EXPORT friend dynamic_dimension operator-(const dynamic_dimension& x, - const std::size_t& y); - MIGRAPHX_EXPORT friend dynamic_dimension operator*(const dynamic_dimension& x, - const std::size_t& y); - MIGRAPHX_EXPORT friend dynamic_dimension operator*(const std::size_t& x, - const dynamic_dimension& y); - MIGRAPHX_EXPORT friend dynamic_dimension operator/(const dynamic_dimension& x, - const std::size_t& y); - - // dd-to-dd arithmetic (defined in shape.cpp) - dynamic_dimension& operator+=(const dynamic_dimension& x); - dynamic_dimension& operator-=(const dynamic_dimension& x); - dynamic_dimension& operator*=(const dynamic_dimension& x); - dynamic_dimension& operator/=(const dynamic_dimension& x); - MIGRAPHX_EXPORT friend dynamic_dimension operator+(const dynamic_dimension& x, - const dynamic_dimension& y); - MIGRAPHX_EXPORT friend dynamic_dimension operator-(const dynamic_dimension& x, - const dynamic_dimension& y); - MIGRAPHX_EXPORT friend dynamic_dimension operator*(const dynamic_dimension& x, - const dynamic_dimension& y); - MIGRAPHX_EXPORT friend dynamic_dimension operator/(const dynamic_dimension& x, - const dynamic_dimension& y); + // clang-format off +#define MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(binary_op, assign_op) \ + dynamic_dimension& operator assign_op(const dynamic_dimension& x); \ + dynamic_dimension& operator assign_op(const std::size_t& x); \ + MIGRAPHX_EXPORT friend dynamic_dimension operator binary_op( \ + const dynamic_dimension& x, const dynamic_dimension& y); \ + MIGRAPHX_EXPORT friend dynamic_dimension operator binary_op( \ + const dynamic_dimension& x, const std::size_t& y); \ + MIGRAPHX_EXPORT friend dynamic_dimension operator binary_op( \ + const std::size_t& x, const dynamic_dimension& y); + + MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(+, +=) + MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(-, -=) + MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(*, *=) + MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(/, /=) +#undef MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP + // clang-format on }; static std::string to_sizes_string(const std::vector& shapes); diff --git a/src/shape.cpp b/src/shape.cpp index 202fd18dcb7..6662867f6b0 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -131,8 +131,16 @@ struct shape_impl std::vector m_dyn_dims = {}; std::vector m_dyn_strides = {}; - std::vector sym_dim_exprs() const + std::vector sym_dims() const { + if(m_dyn_dims.empty()) + { + std::vector result(m_lens.size()); + std::transform(m_lens.begin(), m_lens.end(), result.begin(), [](auto len) { + return sym::lit(len); + }); + return result; + } std::vector result(m_dyn_dims.size()); std::transform(m_dyn_dims.begin(), m_dyn_dims.end(), result.begin(), [](const auto& dd) { return dd.sym_expr.value_or(sym::expr{}); @@ -143,7 +151,7 @@ struct shape_impl template static T make_identity(int64_t n) { - if constexpr(std::is_same_v) + if constexpr(std::is_same{}) return sym::lit(n); else return T(n); @@ -156,14 +164,11 @@ struct shape_impl if(strides.empty()) return strides; strides.back() = make_identity(1); - std::partial_sum(dims.rbegin(), - dims.rend() - 1, - strides.rbegin() + 1, - [](const auto& a, const auto& b) { return b * a; }); + std::partial_sum(dims.rbegin(), dims.rend() - 1, strides.rbegin() + 1, std::multiplies<>{}); return strides; } - void calculate_dyn_strides() { m_dyn_strides = compute_strides(sym_dim_exprs()); } + void calculate_dyn_strides() { m_dyn_strides = compute_strides(sym_dims()); } void calculate_strides() { m_strides = compute_strides(m_lens); } @@ -816,26 +821,32 @@ bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max; bool shape::dynamic_dimension::has_optimal() const { return not optimals.empty(); } -shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const std::size_t& x) -{ - return *this += dynamic_dimension{x, x}; -} - -shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t& x) -{ - return *this -= dynamic_dimension{x, x}; -} - -shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const std::size_t& x) -{ - return *this *= dynamic_dimension{x, x}; -} - -shape::dynamic_dimension& shape::dynamic_dimension::operator/=(const std::size_t& x) -{ - assert(x != 0); - return *this /= dynamic_dimension{x, x}; -} +// clang-format off +#define MIGRAPHX_SHAPE_DYN_DIM_IMPLEMENT_OP(binary_op, assign_op) \ + shape::dynamic_dimension& shape::dynamic_dimension::operator assign_op(const std::size_t& x) \ + { \ + return *this assign_op dynamic_dimension{x, x}; \ + } \ + shape::dynamic_dimension operator binary_op( \ + const shape::dynamic_dimension& x, const std::size_t& y) \ + { \ + auto result = x; \ + result assign_op y; \ + return result; \ + } \ + shape::dynamic_dimension operator binary_op( \ + const std::size_t& x, const shape::dynamic_dimension& y) \ + { \ + return shape::dynamic_dimension{x, x} binary_op y; \ + } \ + shape::dynamic_dimension operator binary_op( \ + const shape::dynamic_dimension& x, const shape::dynamic_dimension& y) \ + { \ + auto result = x; \ + result assign_op y; \ + return result; \ + } +// clang-format on bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y) { @@ -852,7 +863,7 @@ bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimensio std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x) { if(x.is_symbolic()) - os << x.sym_expr->to_string(); + os << *x.sym_expr; if(x.is_fixed()) { if(not x.is_symbolic()) @@ -871,61 +882,62 @@ bool operator==(const std::size_t& x, const shape::dynamic_dimension& y) { retur bool operator!=(const shape::dynamic_dimension& x, const std::size_t& y) { return not(x == y); } bool operator!=(const std::size_t& x, const shape::dynamic_dimension& y) { return not(x == y); } -shape::dynamic_dimension operator+(const shape::dynamic_dimension& x, const std::size_t& y) -{ - auto dd = x; - return dd += y; -} - -shape::dynamic_dimension operator+(const std::size_t& x, const shape::dynamic_dimension& y) -{ - return y + x; -} - -shape::dynamic_dimension operator-(const shape::dynamic_dimension& x, const std::size_t& y) -{ - auto dd = x; - return dd -= y; -} - -shape::dynamic_dimension operator*(const shape::dynamic_dimension& x, const std::size_t& y) -{ - auto dd = x; - return dd *= y; -} - -shape::dynamic_dimension operator*(const std::size_t& x, const shape::dynamic_dimension& y) -{ - return y * x; -} - -shape::dynamic_dimension operator/(const shape::dynamic_dimension& x, const std::size_t& y) -{ - auto dd = x; - return dd /= y; -} - -shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const shape::dynamic_dimension& x) -{ - auto lhs_sym = sym_expr; - auto rhs_sym = x.sym_expr; - min = min + x.min; - max = (max > std::numeric_limits::max() - x.max) - ? std::numeric_limits::max() - : max + x.max; - if(x.is_fixed()) - { - std::set new_optimals; - std::transform(optimals.begin(), - optimals.end(), - std::inserter(new_optimals, new_optimals.begin()), - [&](auto o) { return o + x.min; }); - optimals = new_optimals; +MIGRAPHX_SHAPE_DYN_DIM_IMPLEMENT_OP(+, +=) +MIGRAPHX_SHAPE_DYN_DIM_IMPLEMENT_OP(-, -=) +MIGRAPHX_SHAPE_DYN_DIM_IMPLEMENT_OP(*, *=) +MIGRAPHX_SHAPE_DYN_DIM_IMPLEMENT_OP(/, /=) +#undef MIGRAPHX_SHAPE_DYN_DIM_IMPLEMENT_OP + +// When one operand is fixed, shift the other's optimals by the fixed value. +// When neither is fixed, optimals are cleared. +template +static void merge_optimals(std::set& optimals, + bool lhs_fixed, + const std::set& rhs_optimals, + bool rhs_fixed, + F1 shift_lhs, + F2 shift_rhs) +{ + if(rhs_fixed) + { + std::set result; + std::transform( + optimals.begin(), optimals.end(), std::inserter(result, result.begin()), shift_lhs); + optimals = result; + } + else if(lhs_fixed) + { + std::set result; + std::transform(rhs_optimals.begin(), + rhs_optimals.end(), + std::inserter(result, result.begin()), + shift_rhs); + optimals = result; } else { optimals.clear(); } +} + +shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const shape::dynamic_dimension& x) +{ + auto lhs_sym = sym_expr; + auto rhs_sym = x.sym_expr; + auto lhs_fixed = this->is_fixed(); + auto rhs_fixed = x.is_fixed(); + auto lhs_min = min; + min = min + x.min; + max = (max > std::numeric_limits::max() - x.max) + ? std::numeric_limits::max() + : max + x.max; + merge_optimals( + optimals, + lhs_fixed, + x.optimals, + rhs_fixed, + [&](auto o) { return o + x.min; }, + [&](auto o) { return o + lhs_min; }); sym_expr = (lhs_sym and rhs_sym) ? optional(*lhs_sym + *rhs_sym) : nullopt; normalize_sym(); return *this; @@ -933,23 +945,20 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const shape::dyna shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const shape::dynamic_dimension& x) { - auto lhs_sym = sym_expr; - auto rhs_sym = x.sym_expr; - min = (min > x.max) ? min - x.max : 0; - max = (max > x.min) ? max - x.min : 0; - if(x.is_fixed()) - { - std::set new_optimals; - std::transform(optimals.begin(), - optimals.end(), - std::inserter(new_optimals, new_optimals.begin()), - [&](auto o) { return (o > x.min) ? o - x.min : 0; }); - optimals = new_optimals; - } - else - { - optimals.clear(); - } + auto lhs_sym = sym_expr; + auto rhs_sym = x.sym_expr; + auto lhs_fixed = this->is_fixed(); + auto rhs_fixed = x.is_fixed(); + auto lhs_min = min; + min = (min > x.max) ? min - x.max : 0; + max = (max > x.min) ? max - x.min : 0; + merge_optimals( + optimals, + lhs_fixed, + x.optimals, + rhs_fixed, + [&](auto o) { return (o > x.min) ? o - x.min : 0; }, + [&](auto o) { return (lhs_min > o) ? lhs_min - o : 0; }); sym_expr = (lhs_sym and rhs_sym) ? optional(*lhs_sym - *rhs_sym) : nullopt; normalize_sym(); return *this; @@ -957,10 +966,13 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const shape::dyna shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const shape::dynamic_dimension& x) { - auto lhs_sym = sym_expr; - auto rhs_sym = x.sym_expr; - min = min * x.min; - auto safe_mul = [](std::size_t a, std::size_t b) -> std::size_t { + auto lhs_sym = sym_expr; + auto rhs_sym = x.sym_expr; + auto lhs_fixed = this->is_fixed(); + auto rhs_fixed = x.is_fixed(); + auto lhs_min = min; + min = min * x.min; + auto safe_mul = [](std::size_t a, std::size_t b) -> std::size_t { if(b == 0) return 0; if(a > std::numeric_limits::max() / b) @@ -968,19 +980,13 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const shape::dyna return a * b; }; max = safe_mul(max, x.max); - if(x.is_fixed()) - { - std::set new_optimals; - std::transform(optimals.begin(), - optimals.end(), - std::inserter(new_optimals, new_optimals.begin()), - [&](auto o) { return o * x.min; }); - optimals = new_optimals; - } - else - { - optimals.clear(); - } + merge_optimals( + optimals, + lhs_fixed, + x.optimals, + rhs_fixed, + [&](auto o) { return o * x.min; }, + [&](auto o) { return o * lhs_min; }); sym_expr = (lhs_sym and rhs_sym) ? optional(*lhs_sym * *rhs_sym) : nullopt; normalize_sym(); return *this; @@ -988,60 +994,25 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const shape::dyna shape::dynamic_dimension& shape::dynamic_dimension::operator/=(const shape::dynamic_dimension& x) { - auto lhs_sym = sym_expr; - auto rhs_sym = x.sym_expr; - min = (x.max == 0) ? 0 : min / x.max; - max = (x.min == 0) ? std::numeric_limits::max() : max / x.min; - if(x.is_fixed()) - { - std::set new_optimals; - std::transform(optimals.begin(), - optimals.end(), - std::inserter(new_optimals, new_optimals.begin()), - [&](auto o) { return (x.min == 0) ? std::size_t{0} : o / x.min; }); - optimals = new_optimals; - } - else - { - optimals.clear(); - } + auto lhs_sym = sym_expr; + auto rhs_sym = x.sym_expr; + auto lhs_fixed = this->is_fixed(); + auto rhs_fixed = x.is_fixed(); + auto lhs_min = min; + min = (x.max == 0) ? 0 : min / x.max; + max = (x.min == 0) ? std::numeric_limits::max() : max / x.min; + merge_optimals( + optimals, + lhs_fixed, + x.optimals, + rhs_fixed, + [&](auto o) { return (x.min == 0) ? std::size_t{0} : o / x.min; }, + [&](auto o) { return (o == 0) ? std::size_t{0} : lhs_min / o; }); sym_expr = (lhs_sym and rhs_sym) ? optional(*lhs_sym / *rhs_sym) : nullopt; normalize_sym(); return *this; } -shape::dynamic_dimension operator+(const shape::dynamic_dimension& x, - const shape::dynamic_dimension& y) -{ - auto result = x; - result += y; - return result; -} - -shape::dynamic_dimension operator-(const shape::dynamic_dimension& x, - const shape::dynamic_dimension& y) -{ - auto result = x; - result -= y; - return result; -} - -shape::dynamic_dimension operator*(const shape::dynamic_dimension& x, - const shape::dynamic_dimension& y) -{ - auto result = x; - result *= y; - return result; -} - -shape::dynamic_dimension operator/(const shape::dynamic_dimension& x, - const shape::dynamic_dimension& y) -{ - auto result = x; - result /= y; - return result; -} - bool operator==(const shape& x, const shape& y) { if(x.dynamic() and y.dynamic()) From 0bf568036d33af7f3f00b58183eeac0f5385bd7e Mon Sep 17 00:00:00 2001 From: Shiv Date: Fri, 10 Apr 2026 08:46:33 -0700 Subject: [PATCH 37/60] fix tidy --- src/include/migraphx/shape.hpp | 1 + src/shape.cpp | 1 + 2 files changed, 2 insertions(+) diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index d2aa9b330c5..a179e3f9a6d 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -173,6 +173,7 @@ struct MIGRAPHX_EXPORT shape MIGRAPHX_EXPORT friend bool operator!=(const std::size_t& x, const dynamic_dimension& y); // clang-format off +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) #define MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(binary_op, assign_op) \ dynamic_dimension& operator assign_op(const dynamic_dimension& x); \ dynamic_dimension& operator assign_op(const std::size_t& x); \ diff --git a/src/shape.cpp b/src/shape.cpp index 9827cbc8725..57651021cf0 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -823,6 +823,7 @@ bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max; bool shape::dynamic_dimension::has_optimal() const { return not optimals.empty(); } // clang-format off +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) #define MIGRAPHX_SHAPE_DYN_DIM_IMPLEMENT_OP(binary_op, assign_op) \ shape::dynamic_dimension& shape::dynamic_dimension::operator assign_op(const std::size_t& x) \ { \ From 2545a8e485b7fc764318b8d71f43dbb2fe021c4e Mon Sep 17 00:00:00 2001 From: Shiv Date: Fri, 10 Apr 2026 09:44:24 -0700 Subject: [PATCH 38/60] fix callsite to remove disambiguity --- src/gemm.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/gemm.cpp b/src/gemm.cpp index 2deef7fb673..49c21a22fb8 100644 --- a/src/gemm.cpp +++ b/src/gemm.cpp @@ -72,15 +72,18 @@ struct batch_slicer batch_slicer(const shape& mat_shape) { auto n_batch_dims = mat_shape.ndim() - 2; - inner_shape = shape{mat_shape.type(), - {mat_shape.lens().end() - 2, mat_shape.lens().end()}, - {mat_shape.strides().end() - 2, mat_shape.strides().end()}}; + inner_shape = shape{ + mat_shape.type(), + std::vector{mat_shape.lens().end() - 2, mat_shape.lens().end()}, + std::vector{mat_shape.strides().end() - 2, mat_shape.strides().end()}}; if(n_batch_dims > 0) { outer_shape = shape{mat_shape.type(), - {mat_shape.lens().begin(), mat_shape.lens().begin() + n_batch_dims}, - {mat_shape.strides().begin(), mat_shape.strides().begin() + n_batch_dims}}; + std::vector{mat_shape.lens().begin(), + mat_shape.lens().begin() + n_batch_dims}, + std::vector{mat_shape.strides().begin(), + mat_shape.strides().begin() + n_batch_dims}}; } } From c8b8df478afffa4304687fa36a73d0d9e342db5e Mon Sep 17 00:00:00 2001 From: Shiv Date: Fri, 10 Apr 2026 21:55:49 -0700 Subject: [PATCH 39/60] fix merge --- test/serialize_program.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/test/serialize_program.cpp b/test/serialize_program.cpp index fd43ae31b8a..783b4d4dbe6 100644 --- a/test/serialize_program.cpp +++ b/test/serialize_program.cpp @@ -159,6 +159,7 @@ TEST_CASE(symbolic_shape_msgpack_roundtrip) std::vector buffer = migraphx::save_buffer(p, options); migraphx::program p2 = migraphx::load_buffer(buffer, options); EXPECT(p.sort() == p2.sort()); +} static migraphx::program create_program_with_debug_symbols() { From 4e0da9c40dab5d05bab78b4ad20db2620c8febfd Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 13 Apr 2026 16:07:57 -0700 Subject: [PATCH 40/60] remove optional from sym_expr --- src/include/migraphx/shape.hpp | 6 +++--- src/shape.cpp | 16 ++++++++-------- test/shape_test.cpp | 26 +++++++++++++------------- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index a179e3f9a6d..4076c509bd5 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -100,7 +100,7 @@ struct MIGRAPHX_EXPORT shape std::size_t min = 0; std::size_t max = 0; std::set optimals{}; - optional sym_expr; + sym::expr sym_expr; dynamic_dimension() = default; dynamic_dimension(std::size_t min_v, std::size_t max_v) : min(min_v), max(max_v) @@ -115,7 +115,7 @@ struct MIGRAPHX_EXPORT shape dynamic_dimension(std::size_t min_v, std::size_t max_v, std::set opt, - optional s) + sym::expr s) : min(min_v), max(max_v), optimals(std::move(opt)), sym_expr(std::move(s)) { normalize_sym(); @@ -131,7 +131,7 @@ struct MIGRAPHX_EXPORT shape } bool is_fixed() const; - bool is_symbolic() const { return sym_expr.has_value(); } + bool is_symbolic() const { return not sym_expr.empty(); } void normalize_sym() { if(is_fixed() and not is_symbolic()) diff --git a/src/shape.cpp b/src/shape.cpp index f00e7573ca8..100234f3c4b 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -143,7 +143,7 @@ struct shape_impl } std::vector result(m_dyn_dims.size()); std::transform(m_dyn_dims.begin(), m_dyn_dims.end(), result.begin(), [](const auto& dd) { - return dd.sym_expr.value_or(sym::expr{}); + return dd.sym_expr; }); return result; } @@ -753,8 +753,8 @@ shape shape::to_static(const std::unordered_map& symbol_ [&](const auto& dd) -> std::size_t { if(dd.is_fixed()) return dd.min; - if(dd.sym_expr) - return dd.sym_expr->eval_uint(symbol_map); + if(not dd.sym_expr.empty()) + return dd.sym_expr.eval_uint(symbol_map); MIGRAPHX_THROW("to_static: non-fixed dimension has no symbolic expression"); }); const auto& ds = this->dyn_strides(); @@ -864,7 +864,7 @@ bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimensio std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x) { if(x.is_symbolic()) - os << *x.sym_expr; + os << x.sym_expr; if(x.is_fixed()) { if(not x.is_symbolic()) @@ -939,7 +939,7 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const shape::dyna rhs_fixed, [&](auto o) { return o + x.min; }, [&](auto o) { return o + lhs_min; }); - sym_expr = (lhs_sym and rhs_sym) ? optional(*lhs_sym + *rhs_sym) : nullopt; + sym_expr = lhs_sym + rhs_sym; normalize_sym(); return *this; } @@ -960,7 +960,7 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const shape::dyna rhs_fixed, [&](auto o) { return (o > x.min) ? o - x.min : 0; }, [&](auto o) { return (lhs_min > o) ? lhs_min - o : 0; }); - sym_expr = (lhs_sym and rhs_sym) ? optional(*lhs_sym - *rhs_sym) : nullopt; + sym_expr = lhs_sym - rhs_sym; normalize_sym(); return *this; } @@ -988,7 +988,7 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const shape::dyna rhs_fixed, [&](auto o) { return o * x.min; }, [&](auto o) { return o * lhs_min; }); - sym_expr = (lhs_sym and rhs_sym) ? optional(*lhs_sym * *rhs_sym) : nullopt; + sym_expr = lhs_sym * rhs_sym; normalize_sym(); return *this; } @@ -1009,7 +1009,7 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator/=(const shape::dyna rhs_fixed, [&](auto o) { return (x.min == 0) ? std::size_t{0} : o / x.min; }, [&](auto o) { return (o == 0) ? std::size_t{0} : lhs_min / o; }); - sym_expr = (lhs_sym and rhs_sym) ? optional(*lhs_sym / *rhs_sym) : nullopt; + sym_expr = lhs_sym / rhs_sym; normalize_sym(); return *this; } diff --git a/test/shape_test.cpp b/test/shape_test.cpp index abed97a30db..6933ef85747 100644 --- a/test/shape_test.cpp +++ b/test/shape_test.cpp @@ -1290,7 +1290,7 @@ TEST_CASE(test_dd_symbolic_add_size_t) dd += 2; EXPECT(dd.min == 3); EXPECT(dd.max == 10); - EXPECT(*dd.sym_expr == n + 2); + EXPECT(dd.sym_expr == n + 2); } TEST_CASE(test_dd_symbolic_sub_size_t) @@ -1300,7 +1300,7 @@ TEST_CASE(test_dd_symbolic_sub_size_t) dd -= 1; EXPECT(dd.min == 2); EXPECT(dd.max == 7); - EXPECT(*dd.sym_expr == n - 1); + EXPECT(dd.sym_expr == n - 1); } TEST_CASE(test_dd_symbolic_mul_size_t) @@ -1310,7 +1310,7 @@ TEST_CASE(test_dd_symbolic_mul_size_t) dd *= 3; EXPECT(dd.min == 3); EXPECT(dd.max == 24); - EXPECT(*dd.sym_expr == n * 3); + EXPECT(dd.sym_expr == n * 3); } TEST_CASE(test_dd_symbolic_div_size_t) @@ -1320,7 +1320,7 @@ TEST_CASE(test_dd_symbolic_div_size_t) dd /= 2; EXPECT(dd.min == 2); EXPECT(dd.max == 8); - EXPECT(*dd.sym_expr == n / 2); + EXPECT(dd.sym_expr == n / 2); } TEST_CASE(test_dd_symbolic_add_dd) @@ -1332,7 +1332,7 @@ TEST_CASE(test_dd_symbolic_add_dd) auto r = a + b; EXPECT(r.min == 3); EXPECT(r.max == 12); - EXPECT(*r.sym_expr == n + c); + EXPECT(r.sym_expr == n + c); } TEST_CASE(test_dd_symbolic_sub_dd) @@ -1344,7 +1344,7 @@ TEST_CASE(test_dd_symbolic_sub_dd) auto r = a - b; EXPECT(r.min == 0); EXPECT(r.max == 15); - EXPECT(*r.sym_expr == n - k); + EXPECT(r.sym_expr == n - k); } TEST_CASE(test_dd_symbolic_mul_dd) @@ -1356,7 +1356,7 @@ TEST_CASE(test_dd_symbolic_mul_dd) auto r = a * b; EXPECT(r.min == 2); EXPECT(r.max == 32); - EXPECT(*r.sym_expr == n * c); + EXPECT(r.sym_expr == n * c); } TEST_CASE(test_dd_symbolic_div_dd) @@ -1368,7 +1368,7 @@ TEST_CASE(test_dd_symbolic_div_dd) auto r = a / b; EXPECT(r.min == 1); EXPECT(r.max == 8); - EXPECT(*r.sym_expr == n / k); + EXPECT(r.sym_expr == n / k); } TEST_CASE(test_dd_symbolic_plus_fixed) @@ -1377,8 +1377,8 @@ TEST_CASE(test_dd_symbolic_plus_fixed) migraphx::shape::dynamic_dimension a{1, 8, {}, n}; migraphx::shape::dynamic_dimension b{3, 3}; auto r = a + b; - EXPECT(r.sym_expr.has_value()); - EXPECT(*r.sym_expr == n + 3); + EXPECT(not r.sym_expr.empty()); + EXPECT(r.sym_expr == n + 3); EXPECT(r.min == 4); EXPECT(r.max == 11); } @@ -1389,7 +1389,7 @@ TEST_CASE(test_dd_nonfixed_nonsymbolic_plus_symbolic_drops_sym) migraphx::shape::dynamic_dimension a{1, 8, {}}; migraphx::shape::dynamic_dimension b{2, 4, {}, c}; auto r = a + b; - EXPECT(not r.sym_expr.has_value()); + EXPECT(r.sym_expr.empty()); EXPECT(r.min == 3); EXPECT(r.max == 12); } @@ -1399,7 +1399,7 @@ TEST_CASE(test_dd_nonsymbolic_remains_nonsymbolic) migraphx::shape::dynamic_dimension a{1, 8, {}}; migraphx::shape::dynamic_dimension b{2, 4, {}}; auto r = a + b; - EXPECT(not r.sym_expr.has_value()); + EXPECT(r.sym_expr.empty()); } TEST_CASE(test_dd_equality_with_sym) @@ -1501,7 +1501,7 @@ TEST_CASE(dd_intersection_symbolic_with_range) EXPECT(result.has_value()); EXPECT(result->min == 2); EXPECT(result->max == 6); - EXPECT(not result->sym_expr.has_value()); + EXPECT(result->sym_expr.empty()); } TEST_CASE(dd_intersection_symbolic_same_symbol) From ca454b266431026d4abc346275362f60451f5a6a Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 14 Apr 2026 14:13:03 -0700 Subject: [PATCH 41/60] refactor how dyn dim intervals are stored and accessed --- src/include/migraphx/op/concat.hpp | 6 ++-- src/include/migraphx/op/convolution.hpp | 6 ++-- src/include/migraphx/op/gathernd.hpp | 4 +-- src/include/migraphx/op/nonmaxsuppression.hpp | 6 ++-- src/include/migraphx/op/pooling.hpp | 6 ++-- src/include/migraphx/op/reduce_op.hpp | 2 +- src/include/migraphx/op/reshape.hpp | 10 +++--- src/include/migraphx/op/reshape_lazy.hpp | 2 +- src/include/migraphx/op/resize.hpp | 12 +++---- src/include/migraphx/op/scatternd_op.hpp | 6 ++-- src/include/migraphx/op/slice.hpp | 14 ++++---- src/include/migraphx/shape.hpp | 36 ++++++++++++++++--- src/normalize_attributes.cpp | 4 +-- src/onnx/onnx_parser.cpp | 2 +- src/onnx/parse_depthtospace.cpp | 10 +++--- src/py/migraphx_py.cpp | 6 ++-- src/shape.cpp | 32 ++++++++--------- src/split_single_dyn_dim.cpp | 12 +++---- 18 files changed, 102 insertions(+), 74 deletions(-) diff --git a/src/include/migraphx/op/concat.hpp b/src/include/migraphx/op/concat.hpp index 527ef55bc1c..89d7c40bdc3 100644 --- a/src/include/migraphx/op/concat.hpp +++ b/src/include/migraphx/op/concat.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -129,8 +129,8 @@ struct concat for(const auto& input : inputs) { auto ddim = input.dyn_dims()[axis]; - new_min += ddim.min; - new_max += ddim.max; + new_min += ddim.min(); + new_max += ddim.max(); } auto new_dims = inputs[0].dyn_dims(); diff --git a/src/include/migraphx/op/convolution.hpp b/src/include/migraphx/op/convolution.hpp index c4af5801345..a9245e4ec39 100644 --- a/src/include/migraphx/op/convolution.hpp +++ b/src/include/migraphx/op/convolution.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -164,8 +164,8 @@ struct convolution x.optimals.end(), std::inserter(optimals, optimals.begin()), [&](auto o) { return ceil_div(o, s); }); - output_dyn_dims.push_back( - shape::dynamic_dimension{ceil_div(x.min, s), ceil_div(x.max, s), optimals}); + output_dyn_dims.push_back(shape::dynamic_dimension{ + ceil_div(x.min(), s), ceil_div(x.max(), s), optimals}); } else { diff --git a/src/include/migraphx/op/gathernd.hpp b/src/include/migraphx/op/gathernd.hpp index 4c71c3cdd9e..91e9644b7a0 100644 --- a/src/include/migraphx/op/gathernd.hpp +++ b/src/include/migraphx/op/gathernd.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -63,7 +63,7 @@ struct gathernd MIGRAPHX_THROW( "GATHERND: last dimension of indices tensor must be fixed (min=max)"); } - k = i_shape.dyn_dims().back().min; + k = i_shape.dyn_dims().back().min(); } else k = i_shape.lens().back(); diff --git a/src/include/migraphx/op/nonmaxsuppression.hpp b/src/include/migraphx/op/nonmaxsuppression.hpp index a4ef6cdcb3a..442da4d5b0b 100644 --- a/src/include/migraphx/op/nonmaxsuppression.hpp +++ b/src/include/migraphx/op/nonmaxsuppression.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -156,12 +156,12 @@ struct nonmaxsuppression // check that it is only a dynamic number of classes const auto scores_dims = inputs.at(1).dyn_dims(); const auto boxes_lens = inputs.at(0).lens(); - if(not scores_dims.at(0).is_fixed() or scores_dims.at(0).max != boxes_lens.at(0)) + if(not scores_dims.at(0).is_fixed() or scores_dims.at(0).max() != boxes_lens.at(0)) { MIGRAPHX_THROW("NonMaxSuppression: scores dynamic num_classes; num_batches not " "fixed or mismatched"); } - if(not scores_dims.at(2).is_fixed() or scores_dims.at(2).max != boxes_lens.at(1)) + if(not scores_dims.at(2).is_fixed() or scores_dims.at(2).max() != boxes_lens.at(1)) { MIGRAPHX_THROW("NonMaxSuppression: scores dynamic num_classes; " "spatial_dimension not fixed or mismatches"); diff --git a/src/include/migraphx/op/pooling.hpp b/src/include/migraphx/op/pooling.hpp index 5854c7ad8b7..5d6cf08088b 100644 --- a/src/include/migraphx/op/pooling.hpp +++ b/src/include/migraphx/op/pooling.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -218,8 +218,8 @@ struct pooling x.optimals.end(), std::inserter(optimals, optimals.begin()), [&](auto o) { return ceil_div(o, s); }); - output_dyn_dims.push_back( - shape::dynamic_dimension{ceil_div(x.min, s), ceil_div(x.max, s), optimals}); + output_dyn_dims.push_back(shape::dynamic_dimension{ + ceil_div(x.min(), s), ceil_div(x.max(), s), optimals}); } return {input.type(), output_dyn_dims}; } diff --git a/src/include/migraphx/op/reduce_op.hpp b/src/include/migraphx/op/reduce_op.hpp index f8249afc7c1..ea2c0c55bbd 100644 --- a/src/include/migraphx/op/reduce_op.hpp +++ b/src/include/migraphx/op/reduce_op.hpp @@ -120,7 +120,7 @@ struct reduce_op : op_name if(axes.empty()) { std::transform(dims.begin(), dims.end(), dims.begin(), [](const auto& dim) { - return shape::dynamic_dimension{1, dim.max}; + return shape::dynamic_dimension{1, dim.max()}; }); } else diff --git a/src/include/migraphx/op/reshape.hpp b/src/include/migraphx/op/reshape.hpp index 139aceb93d8..b79919a69fc 100644 --- a/src/include/migraphx/op/reshape.hpp +++ b/src/include/migraphx/op/reshape.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -104,16 +104,16 @@ struct reshape std::size_t max_cur_elements = 1; for(const auto& dd : output_dyn_dims) { - min_cur_elements = mul_sat(min_cur_elements, dd.min); - max_cur_elements = mul_sat(max_cur_elements, dd.max); + min_cur_elements = mul_sat(min_cur_elements, dd.min()); + max_cur_elements = mul_sat(max_cur_elements, dd.max()); } // accumulate the elements in the input dimensions std::size_t min_input_elements = 1; std::size_t max_input_elements = 1; for(const auto& dd : input_dyn_dims) { - min_input_elements = mul_sat(min_input_elements, dd.min); - max_input_elements = mul_sat(max_input_elements, dd.max); + min_input_elements = mul_sat(min_input_elements, dd.min()); + max_input_elements = mul_sat(max_input_elements, dd.max()); } // maximum dimensions should never accumulate to zero diff --git a/src/include/migraphx/op/reshape_lazy.hpp b/src/include/migraphx/op/reshape_lazy.hpp index 1e01fcd8230..e1cde2810bc 100644 --- a/src/include/migraphx/op/reshape_lazy.hpp +++ b/src/include/migraphx/op/reshape_lazy.hpp @@ -66,7 +66,7 @@ struct reshape_lazy if(dyn_dims[i].is_fixed()) { num_dims_ele *= dims[i]; - num_dd_ele *= dyn_dims[i].min; + num_dd_ele *= dyn_dims[i].min(); } else { diff --git a/src/include/migraphx/op/resize.hpp b/src/include/migraphx/op/resize.hpp index 05c8f3236c8..a13af624116 100644 --- a/src/include/migraphx/op/resize.hpp +++ b/src/include/migraphx/op/resize.hpp @@ -416,12 +416,12 @@ struct resize { for(std::size_t i = 0; i < scales.size(); i++) { - dyn_dims[i].min = static_cast(input.dyn_dims()[i].min * scales[i]); - if(input.dyn_dims()[i].max != max_val) - { - dyn_dims[i].max = - static_cast(input.dyn_dims()[i].max * scales[i]); - } + auto new_min = static_cast(input.dyn_dims()[i].min() * scales[i]); + auto new_max = + input.dyn_dims()[i].max() != max_val + ? static_cast(input.dyn_dims()[i].max() * scales[i]) + : max_val; + dyn_dims[i] = shape::dynamic_dimension{new_min, new_max}; } } return {input.type(), dyn_dims}; diff --git a/src/include/migraphx/op/scatternd_op.hpp b/src/include/migraphx/op/scatternd_op.hpp index d53bec905bd..27facdeb864 100644 --- a/src/include/migraphx/op/scatternd_op.hpp +++ b/src/include/migraphx/op/scatternd_op.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -54,7 +54,7 @@ struct scatternd_op : op_name shape compute_shape(std::vector inputs) const { check_shapes{inputs, *this, true}.has(3); - auto data_shape = inputs.front(); + auto data_shape = inputs.front(); const auto& index_shape = inputs.at(1); const auto& upd_shape = inputs.back(); @@ -69,7 +69,7 @@ struct scatternd_op : op_name MIGRAPHX_THROW( "GATHERND: last dimension of indices tensor must be fixed (min=max)"); } - k = index_shape.dyn_dims().back().min; + k = index_shape.dyn_dims().back().min(); } else k = index_shape.lens().back(); diff --git a/src/include/migraphx/op/slice.hpp b/src/include/migraphx/op/slice.hpp index fa35d640849..e615b77e73d 100644 --- a/src/include/migraphx/op/slice.hpp +++ b/src/include/migraphx/op/slice.hpp @@ -166,7 +166,7 @@ struct slice MIGRAPHX_THROW("SLICE: 2 input and attributes mismatch"); } std::for_each(axes.cbegin(), axes.cend(), [&](const auto& axis) { - dds.at(axis) = {0, dds.at(axis).max}; + dds.at(axis) = {0, dds.at(axis).max()}; }); } else if(set_attributes == starts_axes) @@ -177,7 +177,7 @@ struct slice MIGRAPHX_THROW("SLICE: 2 input and attributes mismatch"); } std::for_each(axes.cbegin(), axes.cend(), [&](const auto& axis) { - dds.at(axis) = {0, dds.at(axis).max}; + dds.at(axis) = {0, dds.at(axis).max()}; }); } else if(set_attributes == starts_ends) @@ -188,7 +188,7 @@ struct slice MIGRAPHX_THROW("SLICE: 2 input and attributes mismatch"); } std::transform(dds.begin(), dds.end(), dds.begin(), [](const auto& dd) { - return shape::dynamic_dimension{0, dd.max}; + return shape::dynamic_dimension{0, dd.max()}; }); } else @@ -206,7 +206,7 @@ struct slice MIGRAPHX_THROW("SLICE: 3 input and attributes mismatch"); } std::for_each(axes.cbegin(), axes.cend(), [&](const auto& axis) { - dds.at(axis) = {0, dds.at(axis).max}; + dds.at(axis) = {0, dds.at(axis).max()}; }); } else if(set_attributes == ends_only) @@ -217,7 +217,7 @@ struct slice MIGRAPHX_THROW("SLICE: 3 input and attributes mismatch"); } std::transform(dds.begin(), dds.end(), dds.begin(), [](const auto& dd) { - return shape::dynamic_dimension{0, dd.max}; + return shape::dynamic_dimension{0, dd.max()}; }); } else if(set_attributes == starts_only) @@ -229,7 +229,7 @@ struct slice MIGRAPHX_THROW("SLICE: 3 input and attributes mismatch"); } std::transform(dds.begin(), dds.end(), dds.begin(), [](const auto& dd) { - return shape::dynamic_dimension{0, dd.max}; + return shape::dynamic_dimension{0, dd.max()}; }); } else @@ -241,7 +241,7 @@ struct slice { // all 4 inputs (data, inputs_starts, input_ends, input_axes) std::transform(dds.begin(), dds.end(), dds.begin(), [](const auto& dd) { - return shape::dynamic_dimension{0, dd.max}; + return shape::dynamic_dimension{0, dd.max()}; }); } return shape{input_shape.type(), dds}; diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index c9b343cd01c..3fd4fc0794f 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -96,16 +96,42 @@ struct MIGRAPHX_EXPORT shape struct MIGRAPHX_EXPORT dynamic_dimension { - std::size_t min = 0; - std::size_t max = 0; + struct interval + { + std::size_t min; + std::size_t max; + template + static auto reflect(Self& self, F f) + { + return pack(f(self.min, "min"), f(self.max, "max")); + } + friend bool operator==(const interval& a, const interval& b) + { + return a.min == b.min and a.max == b.max; + } + friend bool operator!=(const interval& a, const interval& b) { return not(a == b); } + }; + + interval range = {0, 0}; std::set optimals{}; + dynamic_dimension() = default; + dynamic_dimension(std::size_t min_v, std::size_t max_v) : range{min_v, max_v} {} + dynamic_dimension(std::size_t min_v, std::size_t max_v, std::set opt) + : range{min_v, max_v}, optimals(std::move(opt)) + { + } + template static auto reflect(Self& self, F f) { - return pack(f(self.min, "min"), f(self.max, "max"), f(self.optimals, "optimals")); + return pack(f(self.range, "range"), f(self.optimals, "optimals")); } + std::size_t min() const { return range.min; } + std::size_t max() const { return range.max; } + interval get_interval() const { return range; } + bool is_fixed() const; bool has_optimal() const; @@ -115,8 +141,8 @@ struct MIGRAPHX_EXPORT shape */ std::optional intersection(const dynamic_dimension& other) const { - auto left = std::max(this->min, other.min); - auto right = std::min(this->max, other.max); + auto left = std::max(this->min(), other.min()); + auto right = std::min(this->max(), other.max()); if(left <= right) { return dynamic_dimension{left, right}; diff --git a/src/normalize_attributes.cpp b/src/normalize_attributes.cpp index a5e1aae3abd..d61c57740cd 100644 --- a/src/normalize_attributes.cpp +++ b/src/normalize_attributes.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -74,7 +74,7 @@ static auto tune_attribute(const std::vector& vec, return vec; } std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) { - return input_shape.dyn_dims().at(i).max; + return input_shape.dyn_dims().at(i).max(); }); } else diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index c28f1d59fde..b0a74267ba1 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -54,7 +54,7 @@ static shape shape_from_dyn_dims(shape::type_t shape_type, std::transform(dyn_dims.cbegin(), dyn_dims.cend(), std::back_inserter(dims), - [](const auto& d) { return d.max; }); + [](const auto& d) { return d.max(); }); return {shape_type, dims}; } return {shape_type, dyn_dims}; diff --git a/src/onnx/parse_depthtospace.cpp b/src/onnx/parse_depthtospace.cpp index fe603d87cfa..3b0a503b56b 100644 --- a/src/onnx/parse_depthtospace.cpp +++ b/src/onnx/parse_depthtospace.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -74,7 +74,7 @@ struct parse_depthtospace : op_parser { MIGRAPHX_THROW("DepthToSpace: dynamic channels are not supported"); } - int64_t c = dyn_dims1[1].max; + int64_t c = dyn_dims1[1].max(); auto h = info.add_instruction(make_op("dimensions_of", {{"start", 2}, {"end", 3}}), args[0]); auto w = @@ -82,7 +82,7 @@ struct parse_depthtospace : op_parser auto c_div = info.add_literal({c / divisor}); - dyn_dims2[1] = {dyn_dims2[1].min / divisor, dyn_dims2[1].max / divisor}; + dyn_dims2[1] = {dyn_dims2[1].min() / divisor, dyn_dims2[1].max() / divisor}; dyn_dims2[2] = dyn_dims2[2] * blocksize_unsigned; dyn_dims2[3] = dyn_dims2[3] * blocksize_unsigned; // push back h and w to expand the vector to 6d @@ -96,7 +96,7 @@ struct parse_depthtospace : op_parser if(mode == "DCR") { // expanded vector = n, blocksize, blocksize, c // (blocksize**2), h, w - dyn_dims1[3] = {dyn_dims1[1].min / divisor, dyn_dims1[1].max / divisor, {}}; + dyn_dims1[3] = {dyn_dims1[1].min() / divisor, dyn_dims1[1].max() / divisor, {}}; dyn_dims1[1] = {blocksize_unsigned, blocksize_unsigned, {}}; perm = {0, 3, 4, 1, 5, 2}; new_shape1 = info.add_instruction( @@ -121,7 +121,7 @@ struct parse_depthtospace : op_parser else if(mode == "CRD") { // expanded vector = b, c // (blocksize ** 2), blocksize, blocksize, h, w - dyn_dims1[1] = {dyn_dims1[1].min / divisor, dyn_dims1[1].max / divisor, {}}; + dyn_dims1[1] = {dyn_dims1[1].min() / divisor, dyn_dims1[1].max() / divisor, {}}; dyn_dims1[3] = {blocksize_unsigned, blocksize_unsigned, {}}; perm = {0, 1, 4, 2, 5, 3}; new_shape1 = info.add_instruction( diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp index 7bd5d03b3f0..1d3dcc214b8 100644 --- a/src/py/migraphx_py.cpp +++ b/src/py/migraphx_py.cpp @@ -397,8 +397,10 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) .def(py::init<>()) .def(py::init()) .def(py::init>()) - .def_readwrite("min", &migraphx::shape::dynamic_dimension::min) - .def_readwrite("max", &migraphx::shape::dynamic_dimension::max) + .def_property_readonly("min", + [](const migraphx::shape::dynamic_dimension& d) { return d.min(); }) + .def_property_readonly("max", + [](const migraphx::shape::dynamic_dimension& d) { return d.max(); }) .def_readwrite("optimals", &migraphx::shape::dynamic_dimension::optimals) .def("is_fixed", &migraphx::shape::dynamic_dimension::is_fixed); diff --git a/src/shape.cpp b/src/shape.cpp index f977237240b..94ab5f23cf4 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -192,7 +192,7 @@ struct shape_impl std::transform(m_dyn_dims.cbegin(), m_dyn_dims.cend(), ret.begin(), - [](const shape::dynamic_dimension& x) { return x.min; }); + [](const shape::dynamic_dimension& x) { return x.min(); }); return ret; } @@ -202,7 +202,7 @@ struct shape_impl std::transform(m_dyn_dims.cbegin(), m_dyn_dims.cend(), ret.begin(), - [](const shape::dynamic_dimension& x) { return x.max; }); + [](const shape::dynamic_dimension& x) { return x.max(); }); return ret; } @@ -321,7 +321,7 @@ bool shape::is_compatible_lens(const shape& actual, const shape& expected) return std::equal(actual.lens().begin(), actual.lens().end(), expected.dyn_dims().begin(), - [&](auto a, const auto& e) { return a >= e.min and a <= e.max; }); + [&](auto a, const auto& e) { return a >= e.min() and a <= e.max(); }); } return actual.lens() == expected.lens(); } @@ -705,14 +705,14 @@ std::vector shape::max_lens() const std::vector> shape::opt_lens() const { return impl->opt_lens(); } -bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max; } +bool shape::dynamic_dimension::is_fixed() const { return this->min() == this->max(); } bool shape::dynamic_dimension::has_optimal() const { return not optimals.empty(); } shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const std::size_t& x) { - this->min += x; - this->max += x; + this->range.min += x; + this->range.max += x; std::set new_optimals; std::transform(this->optimals.begin(), this->optimals.end(), @@ -724,10 +724,10 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const std::size_t shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t& x) { - assert(this->min >= x); - assert(this->max >= x); - this->min -= x; - this->max -= x; + assert(this->range.min >= x); + assert(this->range.max >= x); + this->range.min -= x; + this->range.max -= x; std::set new_optimals; std::transform(this->optimals.begin(), this->optimals.end(), @@ -742,8 +742,8 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const std::size_t& x) { - this->min *= x; - this->max *= x; + this->range.min *= x; + this->range.max *= x; std::set new_optimals; std::transform(this->optimals.begin(), this->optimals.end(), @@ -755,8 +755,7 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const std::size_t bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y) { - // don't check optimals if both are fixed - return (x.min == y.min and x.max == y.max and + return (x.min() == y.min() and x.max() == y.max() and ((x.is_fixed() and y.is_fixed()) or (x.optimals == y.optimals))); } @@ -766,13 +765,14 @@ bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimensio } std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x) { - os << "[ " << x.min << ", " << x.max << ", {" << migraphx::to_string_range(x.optimals) << "} ]"; + os << "[ " << x.min() << ", " << x.max() << ", {" << migraphx::to_string_range(x.optimals) + << "} ]"; return os; } bool operator==(const shape::dynamic_dimension& x, const std::size_t& y) { - return x.min == y and x.max == y; + return x.min() == y and x.max() == y; } bool operator==(const std::size_t& x, const shape::dynamic_dimension& y) { return y == x; } bool operator!=(const shape::dynamic_dimension& x, const std::size_t& y) { return not(x == y); } diff --git a/src/split_single_dyn_dim.cpp b/src/split_single_dyn_dim.cpp index 66974cd8e59..d4a538cc545 100644 --- a/src/split_single_dyn_dim.cpp +++ b/src/split_single_dyn_dim.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -117,9 +117,9 @@ static bool any_sm_next(const_module_ref mm, const std::vectorget_parameter_names(); - auto param_shapes = mm->get_parameter_shapes(); + module_ref mm = &mpm.get_module(); + auto param_names = mm->get_parameter_names(); + auto param_shapes = mm->get_parameter_shapes(); optional> dd_check_vec = has_one_unique_dyn_dim(param_shapes); if(dd_check_vec.has_value() and not any_sm_next(mm, dd_check_vec.value())) @@ -128,7 +128,7 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const auto dyn_dim = dd_check_vec->at(0).dd; // create submodules for each dimension size std::vector submodules; - for(size_t dim_size : migraphx::range(dyn_dim.min, dyn_dim.max + 1)) + for(size_t dim_size : migraphx::range(dyn_dim.min(), dyn_dim.max() + 1)) { auto* submod = mpm.create_module("dim_" + std::to_string(dim_size)); // instruction map for new static shaped submodule parameters @@ -157,7 +157,7 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const migraphx::shape out_attr = migraphx::shape{output_shapes}; auto sm_ins = mm->add_instruction( migraphx::make_op("select_module", - {{"output_dyn_shapes", migraphx::to_value(out_attr)}}), + {{"output_dyn_shapes", migraphx::to_value(out_attr)}}), sm_inputs, submodules); std::vector outputs(output_shapes.size()); From 29ca4d519b05f6b82d34298a425d01f2bfd8934f Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 14 Apr 2026 14:30:05 -0700 Subject: [PATCH 42/60] add defaults --- src/include/migraphx/shape.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index 3fd4fc0794f..9c5a8fa81aa 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -98,8 +98,8 @@ struct MIGRAPHX_EXPORT shape { struct interval { - std::size_t min; - std::size_t max; + std::size_t min = 0; + std::size_t max = 0; template static auto reflect(Self& self, F f) { From a0026ba1a33ee4df1f67c81e9d560a0a7056660f Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 14 Apr 2026 15:19:19 -0700 Subject: [PATCH 43/60] add getter for optimals and update callsites --- src/include/migraphx/op/convolution.hpp | 4 ++-- src/include/migraphx/op/pooling.hpp | 4 ++-- src/include/migraphx/shape.hpp | 1 + src/py/migraphx_py.cpp | 4 +++- src/shape.cpp | 6 +++--- 5 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/include/migraphx/op/convolution.hpp b/src/include/migraphx/op/convolution.hpp index a9245e4ec39..6357d8191dd 100644 --- a/src/include/migraphx/op/convolution.hpp +++ b/src/include/migraphx/op/convolution.hpp @@ -160,8 +160,8 @@ struct convolution { auto x = x_shape.dyn_dims()[i + 2]; std::set optimals{}; - std::transform(x.optimals.begin(), - x.optimals.end(), + std::transform(x.get_optimals().begin(), + x.get_optimals().end(), std::inserter(optimals, optimals.begin()), [&](auto o) { return ceil_div(o, s); }); output_dyn_dims.push_back(shape::dynamic_dimension{ diff --git a/src/include/migraphx/op/pooling.hpp b/src/include/migraphx/op/pooling.hpp index 5d6cf08088b..d6ab4283509 100644 --- a/src/include/migraphx/op/pooling.hpp +++ b/src/include/migraphx/op/pooling.hpp @@ -214,8 +214,8 @@ struct pooling auto x = x_shape.dyn_dims()[i + 2]; std::set optimals{}; - std::transform(x.optimals.begin(), - x.optimals.end(), + std::transform(x.get_optimals().begin(), + x.get_optimals().end(), std::inserter(optimals, optimals.begin()), [&](auto o) { return ceil_div(o, s); }); output_dyn_dims.push_back(shape::dynamic_dimension{ diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index 9c5a8fa81aa..8460f4a8179 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -131,6 +131,7 @@ struct MIGRAPHX_EXPORT shape std::size_t min() const { return range.min; } std::size_t max() const { return range.max; } interval get_interval() const { return range; } + const std::set& get_optimals() const { return optimals; } bool is_fixed() const; bool has_optimal() const; diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp index 1d3dcc214b8..c95886d9143 100644 --- a/src/py/migraphx_py.cpp +++ b/src/py/migraphx_py.cpp @@ -401,7 +401,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) [](const migraphx::shape::dynamic_dimension& d) { return d.min(); }) .def_property_readonly("max", [](const migraphx::shape::dynamic_dimension& d) { return d.max(); }) - .def_readwrite("optimals", &migraphx::shape::dynamic_dimension::optimals) + .def_property_readonly( + "optimals", + [](const migraphx::shape::dynamic_dimension& d) { return d.get_optimals(); }) .def("is_fixed", &migraphx::shape::dynamic_dimension::is_fixed); py::class_(m, "argument", py::buffer_protocol()) diff --git a/src/shape.cpp b/src/shape.cpp index 94ab5f23cf4..a4eabe56b8f 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -212,7 +212,7 @@ struct shape_impl std::transform(m_dyn_dims.cbegin(), m_dyn_dims.cend(), ret.begin(), - [](const shape::dynamic_dimension& x) { return x.optimals; }); + [](const shape::dynamic_dimension& x) { return x.get_optimals(); }); return ret; } @@ -756,7 +756,7 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const std::size_t bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y) { return (x.min() == y.min() and x.max() == y.max() and - ((x.is_fixed() and y.is_fixed()) or (x.optimals == y.optimals))); + ((x.is_fixed() and y.is_fixed()) or (x.get_optimals() == y.get_optimals()))); } bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y) @@ -765,7 +765,7 @@ bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimensio } std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x) { - os << "[ " << x.min() << ", " << x.max() << ", {" << migraphx::to_string_range(x.optimals) + os << "[ " << x.min() << ", " << x.max() << ", {" << migraphx::to_string_range(x.get_optimals()) << "} ]"; return os; } From 1db31d63771fae9347c0f7aad2c7e5f5c9ef3bdd Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 14 Apr 2026 16:13:16 -0700 Subject: [PATCH 44/60] update to use get_interval() and remove min() and max() --- src/include/migraphx/op/concat.hpp | 7 ++--- src/include/migraphx/op/convolution.hpp | 3 ++- src/include/migraphx/op/gathernd.hpp | 2 +- src/include/migraphx/op/nonmaxsuppression.hpp | 6 +++-- src/include/migraphx/op/pooling.hpp | 3 ++- src/include/migraphx/op/reduce_op.hpp | 2 +- src/include/migraphx/op/reshape.hpp | 10 ++++--- src/include/migraphx/op/reshape_lazy.hpp | 2 +- src/include/migraphx/op/resize.hpp | 12 ++++----- src/include/migraphx/op/scatternd_op.hpp | 2 +- src/include/migraphx/op/slice.hpp | 14 +++++----- src/include/migraphx/shape.hpp | 8 +++--- src/normalize_attributes.cpp | 2 +- src/onnx/onnx_parser.cpp | 2 +- src/onnx/parse_depthtospace.cpp | 27 ++++++++++--------- src/py/migraphx_py.cpp | 8 +++--- src/shape.cpp | 25 +++++++++++------ src/split_single_dyn_dim.cpp | 3 ++- 18 files changed, 79 insertions(+), 59 deletions(-) diff --git a/src/include/migraphx/op/concat.hpp b/src/include/migraphx/op/concat.hpp index 89d7c40bdc3..31288a547a2 100644 --- a/src/include/migraphx/op/concat.hpp +++ b/src/include/migraphx/op/concat.hpp @@ -128,9 +128,10 @@ struct concat std::size_t new_max = 0; for(const auto& input : inputs) { - auto ddim = input.dyn_dims()[axis]; - new_min += ddim.min(); - new_max += ddim.max(); + auto ddim = input.dyn_dims()[axis]; + auto dim_interval = ddim.get_interval(); + new_min += dim_interval.min; + new_max += dim_interval.max; } auto new_dims = inputs[0].dyn_dims(); diff --git a/src/include/migraphx/op/convolution.hpp b/src/include/migraphx/op/convolution.hpp index 6357d8191dd..41c5f1fc18b 100644 --- a/src/include/migraphx/op/convolution.hpp +++ b/src/include/migraphx/op/convolution.hpp @@ -164,8 +164,9 @@ struct convolution x.get_optimals().end(), std::inserter(optimals, optimals.begin()), [&](auto o) { return ceil_div(o, s); }); + auto x_interval = x.get_interval(); output_dyn_dims.push_back(shape::dynamic_dimension{ - ceil_div(x.min(), s), ceil_div(x.max(), s), optimals}); + ceil_div(x_interval.min, s), ceil_div(x_interval.max, s), optimals}); } else { diff --git a/src/include/migraphx/op/gathernd.hpp b/src/include/migraphx/op/gathernd.hpp index 91e9644b7a0..5c9db07124a 100644 --- a/src/include/migraphx/op/gathernd.hpp +++ b/src/include/migraphx/op/gathernd.hpp @@ -63,7 +63,7 @@ struct gathernd MIGRAPHX_THROW( "GATHERND: last dimension of indices tensor must be fixed (min=max)"); } - k = i_shape.dyn_dims().back().min(); + k = i_shape.dyn_dims().back().get_interval().min; } else k = i_shape.lens().back(); diff --git a/src/include/migraphx/op/nonmaxsuppression.hpp b/src/include/migraphx/op/nonmaxsuppression.hpp index 442da4d5b0b..357e3acd562 100644 --- a/src/include/migraphx/op/nonmaxsuppression.hpp +++ b/src/include/migraphx/op/nonmaxsuppression.hpp @@ -156,12 +156,14 @@ struct nonmaxsuppression // check that it is only a dynamic number of classes const auto scores_dims = inputs.at(1).dyn_dims(); const auto boxes_lens = inputs.at(0).lens(); - if(not scores_dims.at(0).is_fixed() or scores_dims.at(0).max() != boxes_lens.at(0)) + if(not scores_dims.at(0).is_fixed() or + scores_dims.at(0).get_interval().max != boxes_lens.at(0)) { MIGRAPHX_THROW("NonMaxSuppression: scores dynamic num_classes; num_batches not " "fixed or mismatched"); } - if(not scores_dims.at(2).is_fixed() or scores_dims.at(2).max() != boxes_lens.at(1)) + if(not scores_dims.at(2).is_fixed() or + scores_dims.at(2).get_interval().max != boxes_lens.at(1)) { MIGRAPHX_THROW("NonMaxSuppression: scores dynamic num_classes; " "spatial_dimension not fixed or mismatches"); diff --git a/src/include/migraphx/op/pooling.hpp b/src/include/migraphx/op/pooling.hpp index d6ab4283509..429a9eeba04 100644 --- a/src/include/migraphx/op/pooling.hpp +++ b/src/include/migraphx/op/pooling.hpp @@ -218,8 +218,9 @@ struct pooling x.get_optimals().end(), std::inserter(optimals, optimals.begin()), [&](auto o) { return ceil_div(o, s); }); + auto x_interval = x.get_interval(); output_dyn_dims.push_back(shape::dynamic_dimension{ - ceil_div(x.min(), s), ceil_div(x.max(), s), optimals}); + ceil_div(x_interval.min, s), ceil_div(x_interval.max, s), optimals}); } return {input.type(), output_dyn_dims}; } diff --git a/src/include/migraphx/op/reduce_op.hpp b/src/include/migraphx/op/reduce_op.hpp index ea2c0c55bbd..84a5573a1e6 100644 --- a/src/include/migraphx/op/reduce_op.hpp +++ b/src/include/migraphx/op/reduce_op.hpp @@ -120,7 +120,7 @@ struct reduce_op : op_name if(axes.empty()) { std::transform(dims.begin(), dims.end(), dims.begin(), [](const auto& dim) { - return shape::dynamic_dimension{1, dim.max()}; + return shape::dynamic_dimension{1, dim.get_interval().max}; }); } else diff --git a/src/include/migraphx/op/reshape.hpp b/src/include/migraphx/op/reshape.hpp index b79919a69fc..51c68d4c924 100644 --- a/src/include/migraphx/op/reshape.hpp +++ b/src/include/migraphx/op/reshape.hpp @@ -104,16 +104,18 @@ struct reshape std::size_t max_cur_elements = 1; for(const auto& dd : output_dyn_dims) { - min_cur_elements = mul_sat(min_cur_elements, dd.min()); - max_cur_elements = mul_sat(max_cur_elements, dd.max()); + auto dd_interval = dd.get_interval(); + min_cur_elements = mul_sat(min_cur_elements, dd_interval.min); + max_cur_elements = mul_sat(max_cur_elements, dd_interval.max); } // accumulate the elements in the input dimensions std::size_t min_input_elements = 1; std::size_t max_input_elements = 1; for(const auto& dd : input_dyn_dims) { - min_input_elements = mul_sat(min_input_elements, dd.min()); - max_input_elements = mul_sat(max_input_elements, dd.max()); + auto dd_interval = dd.get_interval(); + min_input_elements = mul_sat(min_input_elements, dd_interval.min); + max_input_elements = mul_sat(max_input_elements, dd_interval.max); } // maximum dimensions should never accumulate to zero diff --git a/src/include/migraphx/op/reshape_lazy.hpp b/src/include/migraphx/op/reshape_lazy.hpp index e1cde2810bc..845fb5cd43e 100644 --- a/src/include/migraphx/op/reshape_lazy.hpp +++ b/src/include/migraphx/op/reshape_lazy.hpp @@ -66,7 +66,7 @@ struct reshape_lazy if(dyn_dims[i].is_fixed()) { num_dims_ele *= dims[i]; - num_dd_ele *= dyn_dims[i].min(); + num_dd_ele *= dyn_dims[i].get_interval().min; } else { diff --git a/src/include/migraphx/op/resize.hpp b/src/include/migraphx/op/resize.hpp index a13af624116..e3646612653 100644 --- a/src/include/migraphx/op/resize.hpp +++ b/src/include/migraphx/op/resize.hpp @@ -416,12 +416,12 @@ struct resize { for(std::size_t i = 0; i < scales.size(); i++) { - auto new_min = static_cast(input.dyn_dims()[i].min() * scales[i]); - auto new_max = - input.dyn_dims()[i].max() != max_val - ? static_cast(input.dyn_dims()[i].max() * scales[i]) - : max_val; - dyn_dims[i] = shape::dynamic_dimension{new_min, new_max}; + auto input_interval = input.dyn_dims()[i].get_interval(); + auto new_min = static_cast(input_interval.min * scales[i]); + auto new_max = input_interval.max != max_val + ? static_cast(input_interval.max * scales[i]) + : max_val; + dyn_dims[i] = shape::dynamic_dimension{new_min, new_max}; } } return {input.type(), dyn_dims}; diff --git a/src/include/migraphx/op/scatternd_op.hpp b/src/include/migraphx/op/scatternd_op.hpp index 27facdeb864..9223a52fe0d 100644 --- a/src/include/migraphx/op/scatternd_op.hpp +++ b/src/include/migraphx/op/scatternd_op.hpp @@ -69,7 +69,7 @@ struct scatternd_op : op_name MIGRAPHX_THROW( "GATHERND: last dimension of indices tensor must be fixed (min=max)"); } - k = index_shape.dyn_dims().back().min(); + k = index_shape.dyn_dims().back().get_interval().min; } else k = index_shape.lens().back(); diff --git a/src/include/migraphx/op/slice.hpp b/src/include/migraphx/op/slice.hpp index e615b77e73d..47294a70358 100644 --- a/src/include/migraphx/op/slice.hpp +++ b/src/include/migraphx/op/slice.hpp @@ -166,7 +166,7 @@ struct slice MIGRAPHX_THROW("SLICE: 2 input and attributes mismatch"); } std::for_each(axes.cbegin(), axes.cend(), [&](const auto& axis) { - dds.at(axis) = {0, dds.at(axis).max()}; + dds.at(axis) = {0, dds.at(axis).get_interval().max}; }); } else if(set_attributes == starts_axes) @@ -177,7 +177,7 @@ struct slice MIGRAPHX_THROW("SLICE: 2 input and attributes mismatch"); } std::for_each(axes.cbegin(), axes.cend(), [&](const auto& axis) { - dds.at(axis) = {0, dds.at(axis).max()}; + dds.at(axis) = {0, dds.at(axis).get_interval().max}; }); } else if(set_attributes == starts_ends) @@ -188,7 +188,7 @@ struct slice MIGRAPHX_THROW("SLICE: 2 input and attributes mismatch"); } std::transform(dds.begin(), dds.end(), dds.begin(), [](const auto& dd) { - return shape::dynamic_dimension{0, dd.max()}; + return shape::dynamic_dimension{0, dd.get_interval().max}; }); } else @@ -206,7 +206,7 @@ struct slice MIGRAPHX_THROW("SLICE: 3 input and attributes mismatch"); } std::for_each(axes.cbegin(), axes.cend(), [&](const auto& axis) { - dds.at(axis) = {0, dds.at(axis).max()}; + dds.at(axis) = {0, dds.at(axis).get_interval().max}; }); } else if(set_attributes == ends_only) @@ -217,7 +217,7 @@ struct slice MIGRAPHX_THROW("SLICE: 3 input and attributes mismatch"); } std::transform(dds.begin(), dds.end(), dds.begin(), [](const auto& dd) { - return shape::dynamic_dimension{0, dd.max()}; + return shape::dynamic_dimension{0, dd.get_interval().max}; }); } else if(set_attributes == starts_only) @@ -229,7 +229,7 @@ struct slice MIGRAPHX_THROW("SLICE: 3 input and attributes mismatch"); } std::transform(dds.begin(), dds.end(), dds.begin(), [](const auto& dd) { - return shape::dynamic_dimension{0, dd.max()}; + return shape::dynamic_dimension{0, dd.get_interval().max}; }); } else @@ -241,7 +241,7 @@ struct slice { // all 4 inputs (data, inputs_starts, input_ends, input_axes) std::transform(dds.begin(), dds.end(), dds.begin(), [](const auto& dd) { - return shape::dynamic_dimension{0, dd.max()}; + return shape::dynamic_dimension{0, dd.get_interval().max}; }); } return shape{input_shape.type(), dds}; diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index 8460f4a8179..390f1bcdc81 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -128,8 +128,6 @@ struct MIGRAPHX_EXPORT shape return pack(f(self.range, "range"), f(self.optimals, "optimals")); } - std::size_t min() const { return range.min; } - std::size_t max() const { return range.max; } interval get_interval() const { return range; } const std::set& get_optimals() const { return optimals; } @@ -142,8 +140,10 @@ struct MIGRAPHX_EXPORT shape */ std::optional intersection(const dynamic_dimension& other) const { - auto left = std::max(this->min(), other.min()); - auto right = std::min(this->max(), other.max()); + auto this_interval = this->get_interval(); + auto other_interval = other.get_interval(); + auto left = std::max(this_interval.min, other_interval.min); + auto right = std::min(this_interval.max, other_interval.max); if(left <= right) { return dynamic_dimension{left, right}; diff --git a/src/normalize_attributes.cpp b/src/normalize_attributes.cpp index d61c57740cd..48804c9034f 100644 --- a/src/normalize_attributes.cpp +++ b/src/normalize_attributes.cpp @@ -74,7 +74,7 @@ static auto tune_attribute(const std::vector& vec, return vec; } std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) { - return input_shape.dyn_dims().at(i).max(); + return input_shape.dyn_dims().at(i).get_interval().max; }); } else diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index b0a74267ba1..95a58661a6a 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -54,7 +54,7 @@ static shape shape_from_dyn_dims(shape::type_t shape_type, std::transform(dyn_dims.cbegin(), dyn_dims.cend(), std::back_inserter(dims), - [](const auto& d) { return d.max(); }); + [](const auto& d) { return d.get_interval().max; }); return {shape_type, dims}; } return {shape_type, dyn_dims}; diff --git a/src/onnx/parse_depthtospace.cpp b/src/onnx/parse_depthtospace.cpp index 3b0a503b56b..2c180b49b1f 100644 --- a/src/onnx/parse_depthtospace.cpp +++ b/src/onnx/parse_depthtospace.cpp @@ -74,7 +74,7 @@ struct parse_depthtospace : op_parser { MIGRAPHX_THROW("DepthToSpace: dynamic channels are not supported"); } - int64_t c = dyn_dims1[1].max(); + int64_t c = dyn_dims1[1].get_interval().max; auto h = info.add_instruction(make_op("dimensions_of", {{"start", 2}, {"end", 3}}), args[0]); auto w = @@ -82,9 +82,10 @@ struct parse_depthtospace : op_parser auto c_div = info.add_literal({c / divisor}); - dyn_dims2[1] = {dyn_dims2[1].min() / divisor, dyn_dims2[1].max() / divisor}; - dyn_dims2[2] = dyn_dims2[2] * blocksize_unsigned; - dyn_dims2[3] = dyn_dims2[3] * blocksize_unsigned; + auto chan_interval = dyn_dims2[1].get_interval(); + dyn_dims2[1] = {chan_interval.min / divisor, chan_interval.max / divisor}; + dyn_dims2[2] = dyn_dims2[2] * blocksize_unsigned; + dyn_dims2[3] = dyn_dims2[3] * blocksize_unsigned; // push back h and w to expand the vector to 6d dyn_dims1.push_back(dyn_dims1[2]); dyn_dims1.push_back(dyn_dims1[3]); @@ -96,10 +97,11 @@ struct parse_depthtospace : op_parser if(mode == "DCR") { // expanded vector = n, blocksize, blocksize, c // (blocksize**2), h, w - dyn_dims1[3] = {dyn_dims1[1].min() / divisor, dyn_dims1[1].max() / divisor, {}}; - dyn_dims1[1] = {blocksize_unsigned, blocksize_unsigned, {}}; - perm = {0, 3, 4, 1, 5, 2}; - new_shape1 = info.add_instruction( + auto dcr_chan = dyn_dims1[1].get_interval(); + dyn_dims1[3] = {dcr_chan.min / divisor, dcr_chan.max / divisor, {}}; + dyn_dims1[1] = {blocksize_unsigned, blocksize_unsigned, {}}; + perm = {0, 3, 4, 1, 5, 2}; + new_shape1 = info.add_instruction( make_op("concat"), n, blocksize_literal, blocksize_literal, c_div, h, w); new_shape_alloc1 = info.add_instruction( make_op("allocate", @@ -121,10 +123,11 @@ struct parse_depthtospace : op_parser else if(mode == "CRD") { // expanded vector = b, c // (blocksize ** 2), blocksize, blocksize, h, w - dyn_dims1[1] = {dyn_dims1[1].min() / divisor, dyn_dims1[1].max() / divisor, {}}; - dyn_dims1[3] = {blocksize_unsigned, blocksize_unsigned, {}}; - perm = {0, 1, 4, 2, 5, 3}; - new_shape1 = info.add_instruction( + auto crd_chan = dyn_dims1[1].get_interval(); + dyn_dims1[1] = {crd_chan.min / divisor, crd_chan.max / divisor, {}}; + dyn_dims1[3] = {blocksize_unsigned, blocksize_unsigned, {}}; + perm = {0, 1, 4, 2, 5, 3}; + new_shape1 = info.add_instruction( make_op("concat"), n, c_div, blocksize_literal, blocksize_literal, h, w); new_shape_alloc1 = info.add_instruction( make_op("allocate", diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp index c95886d9143..cebec8cd83d 100644 --- a/src/py/migraphx_py.cpp +++ b/src/py/migraphx_py.cpp @@ -397,10 +397,10 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) .def(py::init<>()) .def(py::init()) .def(py::init>()) - .def_property_readonly("min", - [](const migraphx::shape::dynamic_dimension& d) { return d.min(); }) - .def_property_readonly("max", - [](const migraphx::shape::dynamic_dimension& d) { return d.max(); }) + .def_property_readonly( + "min", [](const migraphx::shape::dynamic_dimension& d) { return d.get_interval().min; }) + .def_property_readonly( + "max", [](const migraphx::shape::dynamic_dimension& d) { return d.get_interval().max; }) .def_property_readonly( "optimals", [](const migraphx::shape::dynamic_dimension& d) { return d.get_optimals(); }) diff --git a/src/shape.cpp b/src/shape.cpp index a4eabe56b8f..3c26fe0f897 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -192,7 +192,7 @@ struct shape_impl std::transform(m_dyn_dims.cbegin(), m_dyn_dims.cend(), ret.begin(), - [](const shape::dynamic_dimension& x) { return x.min(); }); + [](const shape::dynamic_dimension& x) { return x.get_interval().min; }); return ret; } @@ -202,7 +202,7 @@ struct shape_impl std::transform(m_dyn_dims.cbegin(), m_dyn_dims.cend(), ret.begin(), - [](const shape::dynamic_dimension& x) { return x.max(); }); + [](const shape::dynamic_dimension& x) { return x.get_interval().max; }); return ret; } @@ -321,7 +321,10 @@ bool shape::is_compatible_lens(const shape& actual, const shape& expected) return std::equal(actual.lens().begin(), actual.lens().end(), expected.dyn_dims().begin(), - [&](auto a, const auto& e) { return a >= e.min() and a <= e.max(); }); + [&](auto a, const auto& e) { + auto expected_interval = e.get_interval(); + return a >= expected_interval.min and a <= expected_interval.max; + }); } return actual.lens() == expected.lens(); } @@ -705,7 +708,11 @@ std::vector shape::max_lens() const std::vector> shape::opt_lens() const { return impl->opt_lens(); } -bool shape::dynamic_dimension::is_fixed() const { return this->min() == this->max(); } +bool shape::dynamic_dimension::is_fixed() const +{ + auto i = this->get_interval(); + return i.min == i.max; +} bool shape::dynamic_dimension::has_optimal() const { return not optimals.empty(); } @@ -755,7 +762,7 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const std::size_t bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y) { - return (x.min() == y.min() and x.max() == y.max() and + return (x.get_interval() == y.get_interval() and ((x.is_fixed() and y.is_fixed()) or (x.get_optimals() == y.get_optimals()))); } @@ -765,14 +772,16 @@ bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimensio } std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x) { - os << "[ " << x.min() << ", " << x.max() << ", {" << migraphx::to_string_range(x.get_optimals()) - << "} ]"; + auto x_interval = x.get_interval(); + os << "[ " << x_interval.min << ", " << x_interval.max << ", {" + << migraphx::to_string_range(x.get_optimals()) << "} ]"; return os; } bool operator==(const shape::dynamic_dimension& x, const std::size_t& y) { - return x.min() == y and x.max() == y; + auto x_interval = x.get_interval(); + return x_interval.min == y and x_interval.max == y; } bool operator==(const std::size_t& x, const shape::dynamic_dimension& y) { return y == x; } bool operator!=(const shape::dynamic_dimension& x, const std::size_t& y) { return not(x == y); } diff --git a/src/split_single_dyn_dim.cpp b/src/split_single_dyn_dim.cpp index d4a538cc545..072fa5aefe3 100644 --- a/src/split_single_dyn_dim.cpp +++ b/src/split_single_dyn_dim.cpp @@ -128,7 +128,8 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const auto dyn_dim = dd_check_vec->at(0).dd; // create submodules for each dimension size std::vector submodules; - for(size_t dim_size : migraphx::range(dyn_dim.min(), dyn_dim.max() + 1)) + auto dim_interval = dyn_dim.get_interval(); + for(size_t dim_size : migraphx::range(dim_interval.min, dim_interval.max + 1)) { auto* submod = mpm.create_module("dim_" + std::to_string(dim_size)); // instruction map for new static shaped submodule parameters From c128adc30b2a06f92a4a21e6aaf3d745c3ecf0d7 Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 15 Apr 2026 08:57:06 -0700 Subject: [PATCH 45/60] fix cppcheck --- src/include/migraphx/op/resize.hpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/include/migraphx/op/resize.hpp b/src/include/migraphx/op/resize.hpp index e3646612653..1e582961203 100644 --- a/src/include/migraphx/op/resize.hpp +++ b/src/include/migraphx/op/resize.hpp @@ -417,11 +417,12 @@ struct resize for(std::size_t i = 0; i < scales.size(); i++) { auto input_interval = input.dyn_dims()[i].get_interval(); - auto new_min = static_cast(input_interval.min * scales[i]); - auto new_max = input_interval.max != max_val - ? static_cast(input_interval.max * scales[i]) - : max_val; - dyn_dims[i] = shape::dynamic_dimension{new_min, new_max}; + std::size_t new_min = input_interval.min * scales[i]; + std::size_t new_max = + input_interval.max == max_val + ? max_val + : static_cast(input_interval.max * scales[i]); + dyn_dims[i] = shape::dynamic_dimension{new_min, new_max}; } } return {input.type(), dyn_dims}; From 3b2c259df51dad3e76f5f71dff6252090a0bf697 Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 15 Apr 2026 09:01:12 -0700 Subject: [PATCH 46/60] return optimals by value --- src/include/migraphx/op/convolution.hpp | 7 ++++--- src/include/migraphx/op/pooling.hpp | 7 ++++--- src/include/migraphx/shape.hpp | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/include/migraphx/op/convolution.hpp b/src/include/migraphx/op/convolution.hpp index 41c5f1fc18b..e3b3058f087 100644 --- a/src/include/migraphx/op/convolution.hpp +++ b/src/include/migraphx/op/convolution.hpp @@ -158,10 +158,11 @@ struct convolution auto s = stride[i]; if(x_shape.dynamic()) { - auto x = x_shape.dyn_dims()[i + 2]; + auto x = x_shape.dyn_dims()[i + 2]; + auto x_opts = x.get_optimals(); std::set optimals{}; - std::transform(x.get_optimals().begin(), - x.get_optimals().end(), + std::transform(x_opts.begin(), + x_opts.end(), std::inserter(optimals, optimals.begin()), [&](auto o) { return ceil_div(o, s); }); auto x_interval = x.get_interval(); diff --git a/src/include/migraphx/op/pooling.hpp b/src/include/migraphx/op/pooling.hpp index 429a9eeba04..1ba528ab0c7 100644 --- a/src/include/migraphx/op/pooling.hpp +++ b/src/include/migraphx/op/pooling.hpp @@ -212,10 +212,11 @@ struct pooling auto ceil_div = [](std::size_t x, std::size_t y) { return (x + y - 1) / y; }; auto s = stride[i]; - auto x = x_shape.dyn_dims()[i + 2]; + auto x = x_shape.dyn_dims()[i + 2]; + auto x_opts = x.get_optimals(); std::set optimals{}; - std::transform(x.get_optimals().begin(), - x.get_optimals().end(), + std::transform(x_opts.begin(), + x_opts.end(), std::inserter(optimals, optimals.begin()), [&](auto o) { return ceil_div(o, s); }); auto x_interval = x.get_interval(); diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index 390f1bcdc81..d8f017c974a 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -129,7 +129,7 @@ struct MIGRAPHX_EXPORT shape } interval get_interval() const { return range; } - const std::set& get_optimals() const { return optimals; } + std::set get_optimals() const { return optimals; } bool is_fixed() const; bool has_optimal() const; From cc74c4a45680d8cde4a88723923a594079ef5163 Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 15 Apr 2026 09:29:23 -0700 Subject: [PATCH 47/60] update has_optimal --- src/shape.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/shape.cpp b/src/shape.cpp index 3c26fe0f897..7483f5cb91a 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -714,7 +714,7 @@ bool shape::dynamic_dimension::is_fixed() const return i.min == i.max; } -bool shape::dynamic_dimension::has_optimal() const { return not optimals.empty(); } +bool shape::dynamic_dimension::has_optimal() const { return not this->get_optimals().empty(); } shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const std::size_t& x) { From 44cd1753f18171d0c99a865fa7f23dbe18ab84df Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 15 Apr 2026 09:34:25 -0700 Subject: [PATCH 48/60] symbolic dimension integration (squashed) --- src/gemm.cpp | 13 +- src/include/migraphx/shape.hpp | 75 ++++-- src/shape.cpp | 396 ++++++++++++++++++++++++------ src/sym.cpp | 10 +- src/targets/gpu/gemm_impl.cpp | 6 +- src/targets/gpu/hip_gemm_impl.cpp | 6 +- test/serialize_program.cpp | 21 ++ test/shape_test.cpp | 256 ++++++++++++++++++- 8 files changed, 670 insertions(+), 113 deletions(-) diff --git a/src/gemm.cpp b/src/gemm.cpp index 2deef7fb673..49c21a22fb8 100644 --- a/src/gemm.cpp +++ b/src/gemm.cpp @@ -72,15 +72,18 @@ struct batch_slicer batch_slicer(const shape& mat_shape) { auto n_batch_dims = mat_shape.ndim() - 2; - inner_shape = shape{mat_shape.type(), - {mat_shape.lens().end() - 2, mat_shape.lens().end()}, - {mat_shape.strides().end() - 2, mat_shape.strides().end()}}; + inner_shape = shape{ + mat_shape.type(), + std::vector{mat_shape.lens().end() - 2, mat_shape.lens().end()}, + std::vector{mat_shape.strides().end() - 2, mat_shape.strides().end()}}; if(n_batch_dims > 0) { outer_shape = shape{mat_shape.type(), - {mat_shape.lens().begin(), mat_shape.lens().begin() + n_batch_dims}, - {mat_shape.strides().begin(), mat_shape.strides().begin() + n_batch_dims}}; + std::vector{mat_shape.lens().begin(), + mat_shape.lens().begin() + n_batch_dims}, + std::vector{mat_shape.strides().begin(), + mat_shape.strides().begin() + n_batch_dims}}; } } diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index d8f017c974a..a58b36e36fe 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -39,6 +39,7 @@ #include #include #include +#include #include namespace migraphx { @@ -114,40 +115,64 @@ struct MIGRAPHX_EXPORT shape interval range = {0, 0}; std::set optimals{}; + sym::expr sym_expr; dynamic_dimension() = default; - dynamic_dimension(std::size_t min_v, std::size_t max_v) : range{min_v, max_v} {} + dynamic_dimension(std::size_t min_v, std::size_t max_v) : range{min_v, max_v} + { + normalize_sym(); + } dynamic_dimension(std::size_t min_v, std::size_t max_v, std::set opt) : range{min_v, max_v}, optimals(std::move(opt)) + { + normalize_sym(); + } + dynamic_dimension(sym::expr s); + dynamic_dimension(std::size_t min_v, + std::size_t max_v, + std::set opt, + sym::expr s) + : range{min_v, max_v}, optimals(std::move(opt)), sym_expr(std::move(s)) { } template static auto reflect(Self& self, F f) { - return pack(f(self.range, "range"), f(self.optimals, "optimals")); + return pack(f(self.range, "range"), f(self.optimals, "optimals"), f(self.sym_expr, "sym")); } interval get_interval() const { return range; } std::set get_optimals() const { return optimals; } bool is_fixed() const; + bool is_symbolic() const { return not sym_expr.empty(); } + void normalize_sym() + { + if(is_fixed() and not is_symbolic()) + sym_expr = sym::lit(range.min); + } bool has_optimal() const; /** * Return a dynamic_dimension with the intersection of two dynamic_dimension ranges if - * possible. + * possible. When both dimensions are symbolic, they are compatible only if they + * share the same symbolic expression. */ std::optional intersection(const dynamic_dimension& other) const { + if(this->is_symbolic() and other.is_symbolic()) + { + if(this->sym_expr == other.sym_expr) + return *this; + return nullopt; + } auto this_interval = this->get_interval(); auto other_interval = other.get_interval(); auto left = std::max(this_interval.min, other_interval.min); auto right = std::min(this_interval.max, other_interval.max); if(left <= right) - { return dynamic_dimension{left, right}; - } return nullopt; } @@ -164,20 +189,24 @@ struct MIGRAPHX_EXPORT shape MIGRAPHX_EXPORT friend bool operator!=(const dynamic_dimension& x, const std::size_t& y); MIGRAPHX_EXPORT friend bool operator!=(const std::size_t& x, const dynamic_dimension& y); - // add, subtract, multiply fixed std::size_t dimension - dynamic_dimension& operator+=(const std::size_t& x); - dynamic_dimension& operator-=(const std::size_t& x); - dynamic_dimension& operator*=(const std::size_t& x); - MIGRAPHX_EXPORT friend dynamic_dimension operator+(const dynamic_dimension& x, - const std::size_t& y); - MIGRAPHX_EXPORT friend dynamic_dimension operator+(const std::size_t& x, - const dynamic_dimension& y); - MIGRAPHX_EXPORT friend dynamic_dimension operator-(const dynamic_dimension& x, - const std::size_t& y); - MIGRAPHX_EXPORT friend dynamic_dimension operator*(const dynamic_dimension& x, - const std::size_t& y); - MIGRAPHX_EXPORT friend dynamic_dimension operator*(const std::size_t& x, - const dynamic_dimension& y); + // clang-format off +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(binary_op, assign_op) \ + dynamic_dimension& operator assign_op(const dynamic_dimension& x); \ + dynamic_dimension& operator assign_op(const std::size_t& x); \ + MIGRAPHX_EXPORT friend dynamic_dimension operator binary_op( \ + const dynamic_dimension& x, const dynamic_dimension& y); \ + MIGRAPHX_EXPORT friend dynamic_dimension operator binary_op( \ + const dynamic_dimension& x, const std::size_t& y); \ + MIGRAPHX_EXPORT friend dynamic_dimension operator binary_op( \ + const std::size_t& x, const dynamic_dimension& y); + + MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(+, +=) + MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(-, -=) + MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(*, *=) + MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(/, /=) +#undef MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP + // clang-format on }; static std::string to_sizes_string(const std::vector& shapes); @@ -202,8 +231,10 @@ struct MIGRAPHX_EXPORT shape // Force all calls of the format `shape( type_t, { size_t compatibles } )` to map to // shape(type_t, std::vector l) shape(type_t t, std::initializer_list d); + shape(type_t t, std::initializer_list l, std::initializer_list s); shape(type_t t, std::vector dims); + shape(type_t t, std::vector dims, std::vector dstrides); // Construct a dynamic shape from vectors of mins, maxes, and optimals. // optimals_list is a vector of optimals that corresponds to each min and max. @@ -272,6 +303,9 @@ struct MIGRAPHX_EXPORT shape const std::vector& dyn_dims() const; + bool symbolic() const; + const std::vector& dyn_strides() const; + /*! * Minimum lengths for dynamic shape. * lens() for static shape. @@ -391,11 +425,12 @@ struct MIGRAPHX_EXPORT shape shape with_type(type_t t) const; - // convert the shape to an equivalent dynamic shape with empty optimals + // convert the shape to an equivalent dynamic shape with constant symbolic strides shape to_dynamic() const; // convert the shape to a static one setting any non-fixed dynamic_dimensions to x shape to_static(std::size_t x) const; + shape to_static(const std::unordered_map& symbol_map) const; MIGRAPHX_EXPORT friend bool operator==(const shape& x, const shape& y); MIGRAPHX_EXPORT friend bool operator!=(const shape& x, const shape& y); diff --git a/src/shape.cpp b/src/shape.cpp index 7483f5cb91a..510d0ab78e1 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -23,6 +23,7 @@ */ #include +#include #include #include #include @@ -78,6 +79,18 @@ struct shape_impl shape_impl(shape::type_t t, std::vector dims) : m_type(t), m_dyn_dims(std::move(dims)) { + if(not m_dyn_dims.empty() and std::all_of(m_dyn_dims.begin(), + m_dyn_dims.end(), + [](const auto& d) { return d.is_symbolic(); })) + calculate_dyn_strides(); + } + + shape_impl(shape::type_t t, + std::vector dims, + std::vector dstrides) + : m_type(t), m_dyn_dims(std::move(dims)), m_dyn_strides(std::move(dstrides)) + { + assert(m_dyn_strides.size() == m_dyn_dims.size()); } shape_impl(shape::type_t t, @@ -101,6 +114,10 @@ struct shape_impl m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i], optimals_list[i]}); } } + if(not m_dyn_dims.empty() and std::all_of(m_dyn_dims.begin(), + m_dyn_dims.end(), + [](const auto& d) { return d.is_symbolic(); })) + calculate_dyn_strides(); } shape_impl(const std::vector& subs) : m_type(shape::tuple_type), m_shapes(subs) {} @@ -112,20 +129,49 @@ struct shape_impl bool m_standard = false; std::vector m_dyn_dims = {}; + std::vector m_dyn_strides = {}; + + std::vector sym_dims() const + { + if(m_dyn_dims.empty()) + { + std::vector result(m_lens.size()); + std::transform(m_lens.begin(), m_lens.end(), result.begin(), [](auto len) { + return sym::lit(len); + }); + return result; + } + std::vector result(m_dyn_dims.size()); + std::transform(m_dyn_dims.begin(), m_dyn_dims.end(), result.begin(), [](const auto& dd) { + return dd.sym_expr; + }); + return result; + } + + template + static T make_identity(int64_t n) + { + if constexpr(std::is_same{}) + return sym::lit(n); + else + return T(n); + } - void calculate_strides() + template + static std::vector compute_strides(const std::vector& dims) { - m_strides.clear(); - m_strides.resize(m_lens.size(), 0); - if(m_strides.empty()) - return; - m_strides.back() = 1; - std::partial_sum(m_lens.rbegin(), - m_lens.rend() - 1, - m_strides.rbegin() + 1, - std::multiplies()); + std::vector strides(dims.size()); + if(strides.empty()) + return strides; + strides.back() = make_identity(1); + std::partial_sum(dims.rbegin(), dims.rend() - 1, strides.rbegin() + 1, std::multiplies<>{}); + return strides; } + void calculate_dyn_strides() { m_dyn_strides = compute_strides(sym_dims()); } + + void calculate_strides() { m_strides = compute_strides(m_lens); } + std::size_t element_space() const { if(not m_dyn_dims.empty()) @@ -357,11 +403,23 @@ shape::shape(type_t t, std::initializer_list d) { } +shape::shape(type_t t, std::initializer_list l, std::initializer_list s) + : shape::shape(t, + std::vector{l.begin(), l.end()}, + std::vector{s.begin(), s.end()}) +{ +} + shape::shape(type_t t, std::vector dims) : impl(std::make_shared(t, std::move(dims))) { } +shape::shape(type_t t, std::vector dims, std::vector dstrides) + : impl(std::make_shared(t, std::move(dims), std::move(dstrides))) +{ +} + shape::shape(type_t t, std::vector mins, std::vector maxes, @@ -641,7 +699,17 @@ shape shape::to_dynamic() const { return *this; } - return {type(), lens(), lens(), {}}; + std::vector dims; + dims.reserve(ndim()); + std::transform(lens().begin(), lens().end(), std::back_inserter(dims), [](auto len) { + return dynamic_dimension{len, len}; + }); + std::vector dstrides; + dstrides.reserve(ndim()); + std::transform(strides().begin(), strides().end(), std::back_inserter(dstrides), [](auto s) { + return sym::lit(s); + }); + return {type(), std::move(dims), std::move(dstrides)}; } shape shape::to_static(std::size_t x) const @@ -668,6 +736,40 @@ shape shape::to_static(std::size_t x) const return {type(), static_lens}; } +shape shape::to_static(const std::unordered_map& symbol_map) const +{ + if(not sub_shapes().empty()) + { + std::vector subs; + std::transform(sub_shapes().cbegin(), + sub_shapes().cend(), + std::back_inserter(subs), + [&](auto s) { return s.to_static(symbol_map); }); + return shape(subs); + } + if(not this->dynamic()) + return *this; + std::vector static_lens(this->ndim()); + std::transform(this->dyn_dims().cbegin(), + this->dyn_dims().cend(), + static_lens.begin(), + [&](const auto& dd) -> std::size_t { + if(dd.is_fixed()) + return dd.get_interval().min; + if(not dd.sym_expr.empty()) + return dd.sym_expr.eval_uint(symbol_map); + MIGRAPHX_THROW("to_static: non-fixed dimension has no symbolic expression"); + }); + const auto& ds = this->dyn_strides(); + if(ds.empty()) + return {type(), static_lens}; + std::vector static_strides(ds.size()); + std::transform(ds.cbegin(), ds.cend(), static_strides.begin(), [&](const auto& s) { + return s.eval_uint(symbol_map); + }); + return {type(), static_lens, static_strides}; +} + std::size_t shape::element_space() const { return impl->element_space(); } std::string shape::type_string() const { return name(this->type()); } @@ -696,6 +798,16 @@ const std::vector& shape::dyn_dims() const return impl->m_dyn_dims; } +bool shape::symbolic() const +{ + return not impl->m_dyn_dims.empty() and + std::all_of(impl->m_dyn_dims.begin(), impl->m_dyn_dims.end(), [](const auto& dd) { + return dd.is_symbolic(); + }); +} + +const std::vector& shape::dyn_strides() const { return impl->m_dyn_strides; } + std::vector shape::min_lens() const { return this->dynamic() ? impl->min_lens() : this->lens(); @@ -716,52 +828,38 @@ bool shape::dynamic_dimension::is_fixed() const bool shape::dynamic_dimension::has_optimal() const { return not this->get_optimals().empty(); } -shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const std::size_t& x) -{ - this->range.min += x; - this->range.max += x; - std::set new_optimals; - std::transform(this->optimals.begin(), - this->optimals.end(), - std::inserter(new_optimals, new_optimals.begin()), - [&x](const auto& opt) { return (opt + x); }); - this->optimals = new_optimals; - return *this; -} - -shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t& x) -{ - assert(this->range.min >= x); - assert(this->range.max >= x); - this->range.min -= x; - this->range.max -= x; - std::set new_optimals; - std::transform(this->optimals.begin(), - this->optimals.end(), - std::inserter(new_optimals, new_optimals.begin()), - [&x](const auto& opt) { - assert(opt >= x); - return (opt - x); - }); - this->optimals = new_optimals; - return *this; -} - -shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const std::size_t& x) -{ - this->range.min *= x; - this->range.max *= x; - std::set new_optimals; - std::transform(this->optimals.begin(), - this->optimals.end(), - std::inserter(new_optimals, new_optimals.begin()), - [&x](const auto& opt) { return (opt * x); }); - this->optimals = new_optimals; - return *this; -} +// clang-format off +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define MIGRAPHX_SHAPE_DYN_DIM_IMPLEMENT_OP(binary_op, assign_op) \ + shape::dynamic_dimension& shape::dynamic_dimension::operator assign_op(const std::size_t& x) \ + { \ + return *this assign_op dynamic_dimension{x, x}; \ + } \ + shape::dynamic_dimension operator binary_op( \ + const shape::dynamic_dimension& x, const std::size_t& y) \ + { \ + auto result = x; \ + result assign_op y; \ + return result; \ + } \ + shape::dynamic_dimension operator binary_op( \ + const std::size_t& x, const shape::dynamic_dimension& y) \ + { \ + return shape::dynamic_dimension{x, x} binary_op y; \ + } \ + shape::dynamic_dimension operator binary_op( \ + const shape::dynamic_dimension& x, const shape::dynamic_dimension& y) \ + { \ + auto result = x; \ + result assign_op y; \ + return result; \ + } +// clang-format on bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y) { + if(not(x.sym_expr == y.sym_expr)) + return false; return (x.get_interval() == y.get_interval() and ((x.is_fixed() and y.is_fixed()) or (x.get_optimals() == y.get_optimals()))); } @@ -773,8 +871,15 @@ bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimensio std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x) { auto x_interval = x.get_interval(); - os << "[ " << x_interval.min << ", " << x_interval.max << ", {" - << migraphx::to_string_range(x.get_optimals()) << "} ]"; + if(x.is_symbolic()) + os << x.sym_expr; + if(x.is_fixed()) + { + if(not x.is_symbolic()) + os << x_interval.min; + return os; + } + os << "[" << x_interval.min << ".." << x_interval.max << "]"; return os; } @@ -787,40 +892,144 @@ bool operator==(const std::size_t& x, const shape::dynamic_dimension& y) { retur bool operator!=(const shape::dynamic_dimension& x, const std::size_t& y) { return not(x == y); } bool operator!=(const std::size_t& x, const shape::dynamic_dimension& y) { return not(x == y); } -shape::dynamic_dimension operator+(const shape::dynamic_dimension& x, const std::size_t& y) +MIGRAPHX_SHAPE_DYN_DIM_IMPLEMENT_OP(+, +=) +MIGRAPHX_SHAPE_DYN_DIM_IMPLEMENT_OP(-, -=) +MIGRAPHX_SHAPE_DYN_DIM_IMPLEMENT_OP(*, *=) +MIGRAPHX_SHAPE_DYN_DIM_IMPLEMENT_OP(/, /=) +#undef MIGRAPHX_SHAPE_DYN_DIM_IMPLEMENT_OP + +// When one operand is fixed, shift the other's optimals by the fixed value. +// When neither is fixed, optimals are cleared. +template +static void merge_optimals(std::set& optimals, + bool lhs_fixed, + const std::set& rhs_optimals, + bool rhs_fixed, + F1 shift_lhs, + F2 shift_rhs) { - auto dd = x; - return dd += y; + if(rhs_fixed) + { + std::set result; + std::transform( + optimals.begin(), optimals.end(), std::inserter(result, result.begin()), shift_lhs); + optimals = result; + } + else if(lhs_fixed) + { + std::set result; + std::transform(rhs_optimals.begin(), + rhs_optimals.end(), + std::inserter(result, result.begin()), + shift_rhs); + optimals = result; + } + else + { + optimals.clear(); + } } -shape::dynamic_dimension operator+(const std::size_t& x, const shape::dynamic_dimension& y) +shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const shape::dynamic_dimension& x) { - return y + x; + auto lhs_sym = sym_expr; + auto rhs_sym = x.sym_expr; + auto lhs_fixed = this->is_fixed(); + auto rhs_fixed = x.is_fixed(); + auto lhs_min = range.min; + range.min = range.min + x.range.min; + range.max = (range.max > std::numeric_limits::max() - x.range.max) + ? std::numeric_limits::max() + : range.max + x.range.max; + merge_optimals( + optimals, + lhs_fixed, + x.optimals, + rhs_fixed, + [&](auto o) { return o + x.range.min; }, + [&](auto o) { return o + lhs_min; }); + sym_expr = lhs_sym + rhs_sym; + normalize_sym(); + return *this; } -shape::dynamic_dimension operator-(const shape::dynamic_dimension& x, const std::size_t& y) +shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const shape::dynamic_dimension& x) { - auto dd = x; - return dd -= y; + auto lhs_sym = sym_expr; + auto rhs_sym = x.sym_expr; + auto lhs_fixed = this->is_fixed(); + auto rhs_fixed = x.is_fixed(); + auto lhs_min = range.min; + range.min = (range.min > x.range.max) ? range.min - x.range.max : 0; + range.max = (range.max > x.range.min) ? range.max - x.range.min : 0; + merge_optimals( + optimals, + lhs_fixed, + x.optimals, + rhs_fixed, + [&](auto o) { return (o > x.range.min) ? o - x.range.min : 0; }, + [&](auto o) { return (lhs_min > o) ? lhs_min - o : 0; }); + sym_expr = lhs_sym - rhs_sym; + normalize_sym(); + return *this; } -shape::dynamic_dimension operator*(const shape::dynamic_dimension& x, const std::size_t& y) +shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const shape::dynamic_dimension& x) { - auto dd = x; - return dd *= y; + auto lhs_sym = sym_expr; + auto rhs_sym = x.sym_expr; + auto lhs_fixed = this->is_fixed(); + auto rhs_fixed = x.is_fixed(); + auto lhs_min = range.min; + range.min = range.min * x.range.min; + auto safe_mul = [](std::size_t a, std::size_t b) -> std::size_t { + if(b == 0) + return 0; + if(a > std::numeric_limits::max() / b) + return std::numeric_limits::max(); + return a * b; + }; + range.max = safe_mul(range.max, x.range.max); + merge_optimals( + optimals, + lhs_fixed, + x.optimals, + rhs_fixed, + [&](auto o) { return o * x.range.min; }, + [&](auto o) { return o * lhs_min; }); + sym_expr = lhs_sym * rhs_sym; + normalize_sym(); + return *this; } -shape::dynamic_dimension operator*(const std::size_t& x, const shape::dynamic_dimension& y) +shape::dynamic_dimension& shape::dynamic_dimension::operator/=(const shape::dynamic_dimension& x) { - return y * x; + auto lhs_sym = sym_expr; + auto rhs_sym = x.sym_expr; + auto lhs_fixed = this->is_fixed(); + auto rhs_fixed = x.is_fixed(); + auto lhs_min = range.min; + range.min = (x.range.max == 0) ? 0 : range.min / x.range.max; + range.max = (x.range.min == 0) ? std::numeric_limits::max() : range.max / x.range.min; + merge_optimals( + optimals, + lhs_fixed, + x.optimals, + rhs_fixed, + [&](auto o) { return (x.range.min == 0) ? std::size_t{0} : o / x.range.min; }, + [&](auto o) { return (o == 0) ? std::size_t{0} : lhs_min / o; }); + sym_expr = lhs_sym / rhs_sym; + normalize_sym(); + return *this; } bool operator==(const shape& x, const shape& y) { if(x.dynamic() and y.dynamic()) { - return x.impl == y.impl or (x.type() == y.type() and x.dyn_dims() == y.dyn_dims() and - x.sub_shapes() == y.sub_shapes()); + return x.impl == y.impl or + (x.type() == y.type() and x.dyn_dims() == y.dyn_dims() and + x.dyn_strides() == y.dyn_strides() and x.sub_shapes() == y.sub_shapes()); } return x.impl == y.impl or (x.dynamic() == y.dynamic() and x.type() == y.type() and x.lens() == y.lens() and @@ -833,7 +1042,23 @@ std::ostream& operator<<(std::ostream& os, const shape& x) { if(x.sub_shapes().empty()) { - if(x.dynamic()) + if(x.symbolic()) + { + os << x.type_string() << ", {"; + const auto& dd = x.dyn_dims(); + for(std::size_t i = 0; i < dd.size(); ++i) + { + if(i > 0) + os << ", "; + if(dd[i].is_symbolic()) + os << dd[i]; + else + os << dd[i].get_interval().min; + } + os << "}, "; + os << "{" << to_string_range(x.dyn_strides()) << "}"; + } + else if(x.dynamic()) { os << "dynamic, "; os << x.type_string() << ", "; @@ -900,7 +1125,6 @@ void migraphx_to_value(value& v, const shape& s) value result; result["type"] = migraphx::to_value(s.type_string()); result["sub_shapes"] = migraphx::to_value(s.sub_shapes()); - // avoid calling functions that will throw if(s.dynamic()) { result["lens"] = {}; @@ -913,6 +1137,14 @@ void migraphx_to_value(value& v, const shape& s) result["strides"] = migraphx::to_value(s.strides()); result["dynamic_dimensions"] = {}; } + if(s.symbolic()) + { + result["dyn_strides"] = migraphx::to_value(s.dyn_strides()); + } + else + { + result["dyn_strides"] = {}; + } v = result; } @@ -934,13 +1166,27 @@ void migraphx_from_value(const value& v, shape& s) else { auto v_dd = v.at("dynamic_dimensions"); - std::vector dyn_dims(v.at("dynamic_dimensions").size()); + std::vector dyn_dims(v_dd.size()); std::transform( v_dd.begin(), v_dd.end(), dyn_dims.begin(), [](const migraphx::value& x) { return from_value(x); }); - s = shape{shape::parse_type(t), dyn_dims}; + if(v.contains("dyn_strides") and not v.at("dyn_strides").empty()) + { + auto v_ds = v.at("dyn_strides"); + std::vector dstrides; + dstrides.reserve(v_ds.size()); + std::transform(v_ds.begin(), + v_ds.end(), + std::back_inserter(dstrides), + [](const auto& x) { return from_value(x); }); + s = shape(shape::parse_type(t), std::move(dyn_dims), std::move(dstrides)); + } + else + { + s = shape{shape::parse_type(t), dyn_dims}; + } } } } diff --git a/src/sym.cpp b/src/sym.cpp index 00d9480e810..fd552980727 100644 --- a/src/sym.cpp +++ b/src/sym.cpp @@ -1002,7 +1002,7 @@ static expr_ptr node_from_value(const value& v) const auto& type = v.at("type").get_string(); if(type == "int") { - return make_integer(v.at("value").get_int64()); + return make_integer(v.at("value").to()); } else if(type == "sym") { @@ -1010,24 +1010,24 @@ static expr_ptr node_from_value(const value& v) } else if(type == "add") { - auto constant = v.at("constant").get_int64(); + auto constant = v.at("constant").to(); term_map terms; for(const auto& t : v.at("terms")) { auto term = node_from_value(t.at("expr")); - auto coeff = t.at("coeff").get_int64(); + auto coeff = t.at("coeff").to(); terms[term] = coeff; } return build_add(constant, std::move(terms)); } else if(type == "mul") { - auto coefficient = v.at("coeff").get_int64(); + auto coefficient = v.at("coeff").to(); factor_map factors; for(const auto& f : v.at("factors")) { auto base = node_from_value(f.at("expr")); - auto exp = f.at("exp").get_int64(); + auto exp = f.at("exp").to(); factors[base] = exp; } return build_mul(coefficient, std::move(factors)); diff --git a/src/targets/gpu/gemm_impl.cpp b/src/targets/gpu/gemm_impl.cpp index 381e68dfabe..a2d39c8a656 100644 --- a/src/targets/gpu/gemm_impl.cpp +++ b/src/targets/gpu/gemm_impl.cpp @@ -94,9 +94,9 @@ void blas_shape(const shape& in_shape) MIGRAPHX_THROW("GPU_GEMM: matrix dimensions can't be broadcasted"); if(s.lens().size() < 3) return; - shape batch_shape{s.type(), - {s.lens().begin(), s.lens().end() - 2}, - {s.strides().begin(), s.strides().end() - 2}}; + shape batch_shape(s.type(), + std::vector(s.lens().begin(), s.lens().end() - 2), + std::vector(s.strides().begin(), s.strides().end() - 2)); auto batch_shapes = reduce_dims({batch_shape}); if(batch_shapes.front().lens().size() != 1) MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible"); diff --git a/src/targets/gpu/hip_gemm_impl.cpp b/src/targets/gpu/hip_gemm_impl.cpp index c3766e1cdf5..dfff40dadab 100644 --- a/src/targets/gpu/hip_gemm_impl.cpp +++ b/src/targets/gpu/hip_gemm_impl.cpp @@ -90,9 +90,9 @@ void blas_shape_hip(const shape& in_shape) MIGRAPHX_THROW("GPU_GEMM: matrix dimensions can't be broadcasted"); if(s.lens().size() < 3) return; - shape batch_shape{s.type(), - {s.lens().begin(), s.lens().end() - 2}, - {s.strides().begin(), s.strides().end() - 2}}; + shape batch_shape(s.type(), + std::vector(s.lens().begin(), s.lens().end() - 2), + std::vector(s.strides().begin(), s.strides().end() - 2)); auto batch_shapes = reduce_dims({batch_shape}); if(batch_shapes.front().lens().size() != 1) MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible"); diff --git a/test/serialize_program.cpp b/test/serialize_program.cpp index 8ba5d216f93..783b4d4dbe6 100644 --- a/test/serialize_program.cpp +++ b/test/serialize_program.cpp @@ -26,6 +26,7 @@ #include #include "test.hpp" #include +#include #include #include @@ -140,6 +141,26 @@ TEST_CASE(program_with_module) EXPECT(p1.sort() == p2.sort()); } +TEST_CASE(symbolic_shape_msgpack_roundtrip) +{ + using migraphx::shape; + using dd = shape::dynamic_dimension; + auto n = migraphx::sym::var("n"); + + migraphx::program p; + auto* mm = p.get_main_module(); + shape s{shape::float_type, {dd{1, 8, {}, n}, {3, 3}, {4, 4}}}; + auto x = mm->add_parameter("x", s); + auto r = mm->add_instruction(migraphx::make_op("relu"), x); + mm->add_return({r}); + + migraphx::file_options options; + options.format = "msgpack"; + std::vector buffer = migraphx::save_buffer(p, options); + migraphx::program p2 = migraphx::load_buffer(buffer, options); + EXPECT(p.sort() == p2.sort()); +} + static migraphx::program create_program_with_debug_symbols() { migraphx::program p; diff --git a/test/shape_test.cpp b/test/shape_test.cpp index 386c9058aeb..6933ef85747 100644 --- a/test/shape_test.cpp +++ b/test/shape_test.cpp @@ -23,6 +23,7 @@ */ #include +#include #include #include #include @@ -33,6 +34,9 @@ #include #include "test.hpp" +using migraphx::sym::lit; +using migraphx::sym::var; + TEST_CASE(test_shape_default) { migraphx::shape s{}; @@ -441,7 +445,9 @@ TEST_CASE(test_shape_static_to_dynamic) { migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4, 4}}; migraphx::shape s1 = s0.to_dynamic(); - migraphx::shape s2{migraphx::shape::float_type, {{1, 1}, {2, 2}, {4, 4}, {4, 4}}}; + migraphx::shape s2{migraphx::shape::float_type, + {{1, 1}, {2, 2}, {4, 4}, {4, 4}}, + {lit(32), lit(16), lit(4), lit(1)}}; EXPECT(s1 == s2); } @@ -461,7 +467,8 @@ TEST_CASE(test_shape_subshapes_to_dynamic) migraphx::shape s1 = s0.to_dynamic(); std::vector sub_shapes1 = {}; sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}}); - sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, {{3, 3}, {4, 4}, {5, 5}}}); + sub_shapes1.push_back(migraphx::shape{ + migraphx::shape::float_type, {{3, 3}, {4, 4}, {5, 5}}, {lit(20), lit(5), lit(1)}}); migraphx::shape s2{sub_shapes1}; EXPECT(s1 == s2); } @@ -1272,4 +1279,249 @@ TEST_CASE(shape_same_lens_static_dynamic) EXPECT(not migraphx::shape::same_lens(s1, s3)); } +// =================================================================== +// Symbolic dynamic_dimension tests +// =================================================================== + +TEST_CASE(test_dd_symbolic_add_size_t) +{ + auto n = var("n"); + migraphx::shape::dynamic_dimension dd{1, 8, {4}, n}; + dd += 2; + EXPECT(dd.min == 3); + EXPECT(dd.max == 10); + EXPECT(dd.sym_expr == n + 2); +} + +TEST_CASE(test_dd_symbolic_sub_size_t) +{ + auto n = var("n"); + migraphx::shape::dynamic_dimension dd{3, 8, {4}, n}; + dd -= 1; + EXPECT(dd.min == 2); + EXPECT(dd.max == 7); + EXPECT(dd.sym_expr == n - 1); +} + +TEST_CASE(test_dd_symbolic_mul_size_t) +{ + auto n = var("n"); + migraphx::shape::dynamic_dimension dd{1, 8, {4}, n}; + dd *= 3; + EXPECT(dd.min == 3); + EXPECT(dd.max == 24); + EXPECT(dd.sym_expr == n * 3); +} + +TEST_CASE(test_dd_symbolic_div_size_t) +{ + auto n = var("n"); + migraphx::shape::dynamic_dimension dd{4, 16, {8}, n}; + dd /= 2; + EXPECT(dd.min == 2); + EXPECT(dd.max == 8); + EXPECT(dd.sym_expr == n / 2); +} + +TEST_CASE(test_dd_symbolic_add_dd) +{ + auto n = var("n"); + auto c = var("c"); + migraphx::shape::dynamic_dimension a{1, 8, {}, n}; + migraphx::shape::dynamic_dimension b{2, 4, {}, c}; + auto r = a + b; + EXPECT(r.min == 3); + EXPECT(r.max == 12); + EXPECT(r.sym_expr == n + c); +} + +TEST_CASE(test_dd_symbolic_sub_dd) +{ + auto n = var("n"); + auto k = var("k"); + migraphx::shape::dynamic_dimension a{4, 16, {}, n}; + migraphx::shape::dynamic_dimension b{1, 4, {}, k}; + auto r = a - b; + EXPECT(r.min == 0); + EXPECT(r.max == 15); + EXPECT(r.sym_expr == n - k); +} + +TEST_CASE(test_dd_symbolic_mul_dd) +{ + auto n = var("n"); + auto c = var("c"); + migraphx::shape::dynamic_dimension a{1, 8, {}, n}; + migraphx::shape::dynamic_dimension b{2, 4, {}, c}; + auto r = a * b; + EXPECT(r.min == 2); + EXPECT(r.max == 32); + EXPECT(r.sym_expr == n * c); +} + +TEST_CASE(test_dd_symbolic_div_dd) +{ + auto n = var("n"); + auto k = var("k"); + migraphx::shape::dynamic_dimension a{4, 16, {}, n}; + migraphx::shape::dynamic_dimension b{2, 4, {}, k}; + auto r = a / b; + EXPECT(r.min == 1); + EXPECT(r.max == 8); + EXPECT(r.sym_expr == n / k); +} + +TEST_CASE(test_dd_symbolic_plus_fixed) +{ + auto n = var("n"); + migraphx::shape::dynamic_dimension a{1, 8, {}, n}; + migraphx::shape::dynamic_dimension b{3, 3}; + auto r = a + b; + EXPECT(not r.sym_expr.empty()); + EXPECT(r.sym_expr == n + 3); + EXPECT(r.min == 4); + EXPECT(r.max == 11); +} + +TEST_CASE(test_dd_nonfixed_nonsymbolic_plus_symbolic_drops_sym) +{ + auto c = var("c"); + migraphx::shape::dynamic_dimension a{1, 8, {}}; + migraphx::shape::dynamic_dimension b{2, 4, {}, c}; + auto r = a + b; + EXPECT(r.sym_expr.empty()); + EXPECT(r.min == 3); + EXPECT(r.max == 12); +} + +TEST_CASE(test_dd_nonsymbolic_remains_nonsymbolic) +{ + migraphx::shape::dynamic_dimension a{1, 8, {}}; + migraphx::shape::dynamic_dimension b{2, 4, {}}; + auto r = a + b; + EXPECT(r.sym_expr.empty()); +} + +TEST_CASE(test_dd_equality_with_sym) +{ + auto n = var("n"); + auto c = var("c"); + migraphx::shape::dynamic_dimension a{1, 8, {}, n}; + migraphx::shape::dynamic_dimension b{1, 8, {}, n}; + migraphx::shape::dynamic_dimension d2{1, 8, {}, c}; + migraphx::shape::dynamic_dimension d{1, 8, {}}; + EXPECT(a == b); + EXPECT(a != d2); + EXPECT(a != d); +} + +TEST_CASE(test_symbolic_shape_construction) +{ + auto n = var("n"); + migraphx::shape sh{migraphx::shape::float_type, + {{1, 8, {}, n}, {3, 3}, {224, 224}}, + {n * lit(3) * lit(224), lit(224), lit(1)}}; + EXPECT(sh.dynamic()); + EXPECT(sh.symbolic()); + EXPECT(sh.dyn_dims().size() == 3); + EXPECT(sh.dyn_strides().size() == 3); +} + +TEST_CASE(test_symbolic_stride_auto_compute) +{ + auto n = var("n"); + auto s = var("s"); + migraphx::shape sh{migraphx::shape::float_type, {{1, 8, {}, n}, {1, 16, {}, s}, {4, 4}}}; + EXPECT(sh.symbolic()); + EXPECT(sh.dyn_strides().size() == 3); + EXPECT(sh.dyn_strides()[2] == lit(1)); + EXPECT(sh.dyn_strides()[1] == lit(4)); + EXPECT(sh.dyn_strides()[0] == s * 4); +} + +TEST_CASE(test_symbolic_to_static) +{ + auto n = var("n"); + auto s = var("s"); + migraphx::shape sh{migraphx::shape::float_type, {{1, 8, {}, n}, {1, 16, {}, s}, {4, 4}}}; + std::unordered_map symbol_map = {{n, 2}, {s, 8}}; + auto s_static = sh.to_static(symbol_map); + EXPECT(not s_static.dynamic()); + EXPECT(s_static.lens() == std::vector{2, 8, 4}); + EXPECT(s_static.strides() == std::vector{32, 4, 1}); +} + +TEST_CASE(test_symbolic_shape_serialize) +{ + auto n = var("n"); + auto s = var("s"); + migraphx::shape s1{migraphx::shape::float_type, {{1, 8, {}, n}, {1, 16, {}, s}, {4, 4}}}; + auto v = migraphx::to_value(s1); + auto s2 = migraphx::from_value(v); + EXPECT(s1 == s2); + EXPECT(s2.symbolic()); + EXPECT(s2.dyn_strides().size() == 3); + EXPECT(s2.dyn_strides()[0] == s * 4); + EXPECT(s2.dyn_strides()[2] == lit(1)); +} + +TEST_CASE(test_symbolic_shape_equality) +{ + auto n = var("n"); + auto c = var("c"); + migraphx::shape s1{migraphx::shape::float_type, {{1, 8, {}, n}, {3, 3}}}; + migraphx::shape s2{migraphx::shape::float_type, {{1, 8, {}, n}, {3, 3}}}; + migraphx::shape s3{migraphx::shape::float_type, {{1, 8, {}, c}, {3, 3}}}; + EXPECT(s1 == s2); + EXPECT(s1 != s3); +} + +TEST_CASE(test_symbolic_shape_print) +{ + auto n = var("n"); + auto c = var("c"); + auto to_str = [](const migraphx::shape& sh) { + std::stringstream ss; + ss << sh; + return ss.str(); + }; + migraphx::shape s1{migraphx::shape::float_type, {{1, 8, {}, n}, {3, 3}, {4, 4}}}; + migraphx::shape s2{migraphx::shape::float_type, {{1, 8, {}, n}, {3, 3}, {4, 4}}}; + migraphx::shape s3{migraphx::shape::float_type, {{1, 8, {}, c}, {3, 3}, {4, 4}}}; + EXPECT(to_str(s1) == to_str(s2)); + EXPECT(to_str(s1) != to_str(s3)); +} + +TEST_CASE(dd_intersection_symbolic_with_range) +{ + auto n = var("n"); + migraphx::shape::dynamic_dimension a{1, 32, {}, n}; + migraphx::shape::dynamic_dimension b{2, 6}; + auto result = a.intersection(b); + EXPECT(result.has_value()); + EXPECT(result->min == 2); + EXPECT(result->max == 6); + EXPECT(result->sym_expr.empty()); +} + +TEST_CASE(dd_intersection_symbolic_same_symbol) +{ + auto n = var("n"); + migraphx::shape::dynamic_dimension a{1, 32, {}, n}; + migraphx::shape::dynamic_dimension b{1, 32, {}, n}; + auto result = a.intersection(b); + EXPECT(result.has_value()); + EXPECT(*result == a); +} + +TEST_CASE(dd_intersection_symbolic_different_symbol) +{ + auto n = var("n"); + auto m = var("m"); + migraphx::shape::dynamic_dimension a{1, 32, {}, n}; + migraphx::shape::dynamic_dimension b{1, 16, {}, m}; + auto result = a.intersection(b); + EXPECT(not result.has_value()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From ae322fc04ad357eb1f6d04d03d835d7e1a52e619 Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 15 Apr 2026 14:59:24 -0700 Subject: [PATCH 49/60] update implementation to work on top of inverval refactor --- src/include/migraphx/shape.hpp | 56 ++- src/include/migraphx/sym.hpp | 26 +- src/permutation.cpp | 29 +- src/shape.cpp | 483 ++++++++++++++------- src/sym.cpp | 274 ++++++++++-- test/shape_test.cpp | 750 ++++++++++++++++++++++++++++----- test/sym_test.cpp | 412 +++++++++++++++++- 7 files changed, 1711 insertions(+), 319 deletions(-) diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index a58b36e36fe..56379f5274b 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -113,45 +113,56 @@ struct MIGRAPHX_EXPORT shape friend bool operator!=(const interval& a, const interval& b) { return not(a == b); } }; - interval range = {0, 0}; - std::set optimals{}; + std::optional range; + std::optional> optimals; sym::expr sym_expr; dynamic_dimension() = default; - dynamic_dimension(std::size_t min_v, std::size_t max_v) : range{min_v, max_v} + dynamic_dimension(std::size_t min_v, std::size_t max_v) + : range{interval{min_v, max_v}}, optimals{std::set{}} { - normalize_sym(); } dynamic_dimension(std::size_t min_v, std::size_t max_v, std::set opt) - : range{min_v, max_v}, optimals(std::move(opt)) + : range{interval{min_v, max_v}}, + optimals(min_v == max_v ? std::set{} : std::move(opt)) { - normalize_sym(); } - dynamic_dimension(sym::expr s); - dynamic_dimension(std::size_t min_v, - std::size_t max_v, - std::set opt, - sym::expr s) - : range{min_v, max_v}, optimals(std::move(opt)), sym_expr(std::move(s)) + dynamic_dimension(sym::expr s) : sym_expr(std::move(s)) { + if(sym_expr.empty()) + MIGRAPHX_THROW( + "dynamic_dimension: cannot construct from an empty symbolic expression"); } template static auto reflect(Self& self, F f) { - return pack(f(self.range, "range"), f(self.optimals, "optimals"), f(self.sym_expr, "sym")); + return pack( + f(self.range, "range"), f(self.optimals, "optimals"), f(self.sym_expr, "sym")); } - interval get_interval() const { return range; } - std::set get_optimals() const { return optimals; } + interval get_interval() const + { + if(is_symbolic()) + { + auto ival = sym_expr.eval_interval(); + if(ival.min < 0 or ival.max < 0) + MIGRAPHX_THROW("dynamic_dimension: symbolic expression has negative bounds"); + return {static_cast(ival.min), static_cast(ival.max)}; + } + return *range; + } + std::set get_optimals() const + { + if(is_symbolic()) + return sym_expr.eval_optimals(); + if(optimals.has_value()) + return *optimals; + return {}; + } bool is_fixed() const; bool is_symbolic() const { return not sym_expr.empty(); } - void normalize_sym() - { - if(is_fixed() and not is_symbolic()) - sym_expr = sym::lit(range.min); - } bool has_optimal() const; /** @@ -273,6 +284,9 @@ struct MIGRAPHX_EXPORT shape */ static shape from_permutation(type_t t, const std::vector& l, const std::vector& perm); + static shape from_permutation(type_t t, + const std::vector& dds, + const std::vector& perm); type_t type() const; const std::vector& lens() const; @@ -422,6 +436,8 @@ struct MIGRAPHX_EXPORT shape shape with_lens(type_t t, const std::vector& l) const; shape with_lens(const std::vector& l) const; + shape with_lens(type_t t, const std::vector& dds) const; + shape with_lens(const std::vector& dds) const; shape with_type(type_t t) const; diff --git a/src/include/migraphx/sym.hpp b/src/include/migraphx/sym.hpp index 59d3f842884..6b28e3016da 100644 --- a/src/include/migraphx/sym.hpp +++ b/src/include/migraphx/sym.hpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -40,8 +41,21 @@ struct value; namespace sym { +struct interval +{ + int64_t min = 0; + int64_t max = 0; + friend bool operator==(const interval& a, const interval& b) + { + return a.min == b.min and a.max == b.max; + } + friend bool operator!=(const interval& a, const interval& b) { return not(a == b); } +}; + struct expr; -MIGRAPHX_EXPORT expr var(const std::string& name); +MIGRAPHX_EXPORT expr var(const std::string& name, + interval bounds = {1, 1}, + std::set optimals = {}); MIGRAPHX_EXPORT expr lit(int64_t n); MIGRAPHX_EXPORT expr parse(const std::string& s); @@ -50,11 +64,14 @@ struct MIGRAPHX_EXPORT expr expr(); bool empty() const; + bool is_literal() const; std::size_t hash() const; std::string to_string() const; value to_value() const; void from_value(const value& v); std::size_t eval_uint(const std::unordered_map& symbol_map) const; + interval eval_interval() const; + std::set eval_optimals() const; expr subs(const std::unordered_map& symbol_map) const; MIGRAPHX_EXPORT friend expr operator+(const expr& a, const expr& b); @@ -63,6 +80,10 @@ struct MIGRAPHX_EXPORT expr MIGRAPHX_EXPORT friend expr operator/(const expr& a, const expr& b); MIGRAPHX_EXPORT friend bool operator==(const expr& a, const expr& b); MIGRAPHX_EXPORT friend bool operator!=(const expr& a, const expr& b); + MIGRAPHX_EXPORT friend bool operator<(const expr& a, const expr& b); + friend bool operator>(const expr& a, const expr& b) { return b < a; } + friend bool operator<=(const expr& a, const expr& b) { return not(b < a); } + friend bool operator>=(const expr& a, const expr& b) { return not(a < b); } MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const expr& e); friend expr operator+(const expr& a, int64_t b) { return a + lit(b); } @@ -76,7 +97,8 @@ struct MIGRAPHX_EXPORT expr struct impl; - MIGRAPHX_EXPORT friend expr var(const std::string& name); + MIGRAPHX_EXPORT friend expr + var(const std::string& name, interval bounds, std::set optimals); MIGRAPHX_EXPORT friend expr lit(int64_t n); MIGRAPHX_EXPORT friend expr parse(const std::string& s); diff --git a/src/permutation.cpp b/src/permutation.cpp index f152e2c5a26..65096f37dce 100644 --- a/src/permutation.cpp +++ b/src/permutation.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -33,6 +34,10 @@ inline namespace MIGRAPHX_INLINE_NS { shape reorder_shape(const shape& s, const std::vector& permutation) { + if(s.symbolic()) + return {s.type(), + reorder_dims(s.dyn_dims(), permutation), + reorder_dims(s.dyn_strides(), permutation)}; return {s.type(), reorder_dims(s.lens(), permutation), reorder_dims(s.strides(), permutation)}; } @@ -43,11 +48,25 @@ std::vector invert_permutation(const std::vector& permutation) std::vector find_permutation(const shape& s) { - std::vector result(s.lens().size()); + if(s.dynamic() and not s.symbolic()) + MIGRAPHX_THROW("FIND_PERMUTATION: non-symbolic dynamic shapes not supported"); + std::vector result(s.ndim()); std::iota(result.begin(), result.end(), 0); - std::stable_sort(result.begin(), result.end(), by(std::greater<>{}, [&](auto x) { - return std::make_tuple(s.strides()[x], s.lens()[x]); - })); + if(s.symbolic()) + { + const auto& strides = s.dyn_strides(); + const auto& dds = s.dyn_dims(); + std::stable_sort(result.begin(), result.end(), by(std::greater<>{}, [&](auto x) { + return std::make_tuple(strides[x].eval_interval().max, + dds[x].sym_expr.eval_interval().max); + })); + } + else + { + std::stable_sort(result.begin(), result.end(), by(std::greater<>{}, [&](auto x) { + return std::make_tuple(s.strides()[x], s.lens()[x]); + })); + } return result; } @@ -64,7 +83,7 @@ std::vector find_permutation(const std::vector& shapes) } if(count.empty()) { - std::vector r(shapes.front().lens().size()); + std::vector r(shapes.front().ndim()); std::iota(r.begin(), r.end(), 0); return r; } diff --git a/src/shape.cpp b/src/shape.cpp index 510d0ab78e1..92a6fe7c1c6 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -79,10 +79,11 @@ struct shape_impl shape_impl(shape::type_t t, std::vector dims) : m_type(t), m_dyn_dims(std::move(dims)) { - if(not m_dyn_dims.empty() and std::all_of(m_dyn_dims.begin(), - m_dyn_dims.end(), - [](const auto& d) { return d.is_symbolic(); })) + if(all_dims_symbolic()) + { calculate_dyn_strides(); + m_standard = true; + } } shape_impl(shape::type_t t, @@ -91,6 +92,13 @@ struct shape_impl : m_type(t), m_dyn_dims(std::move(dims)), m_dyn_strides(std::move(dstrides)) { assert(m_dyn_strides.size() == m_dyn_dims.size()); + auto dim_exprs = sym_dims(); + std::vector filtered_strides; + for(std::size_t i = 0; i < m_dyn_strides.size(); i++) + if(m_dyn_dims[i] != 1) + filtered_strides.push_back(m_dyn_strides[i]); + m_standard = compute_packed(dim_exprs, m_dyn_strides) and + is_sorted_strides(filtered_strides); } shape_impl(shape::type_t t, @@ -114,10 +122,11 @@ struct shape_impl m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i], optimals_list[i]}); } } - if(not m_dyn_dims.empty() and std::all_of(m_dyn_dims.begin(), - m_dyn_dims.end(), - [](const auto& d) { return d.is_symbolic(); })) + if(all_dims_symbolic()) + { calculate_dyn_strides(); + m_standard = true; + } } shape_impl(const std::vector& subs) : m_type(shape::tuple_type), m_shapes(subs) {} @@ -131,6 +140,14 @@ struct shape_impl std::vector m_dyn_dims = {}; std::vector m_dyn_strides = {}; + bool all_dims_symbolic() const + { + return not m_dyn_dims.empty() and + std::all_of(m_dyn_dims.begin(), m_dyn_dims.end(), [](const auto& d) { + return d.is_symbolic(); + }); + } + std::vector sym_dims() const { if(m_dyn_dims.empty()) @@ -172,6 +189,138 @@ struct shape_impl void calculate_strides() { m_strides = compute_strides(m_lens); } + template + static T compute_elements(const std::vector& dims) + { + if(dims.empty()) + return make_identity(0); + return std::accumulate(dims.begin(), dims.end(), make_identity(1), std::multiplies<>{}); + } + + template + static T compute_element_space(const std::vector& dims, const std::vector& strides) + { + if(dims.empty()) + return make_identity(0); + auto one = make_identity(1); + return std::inner_product(dims.begin(), + dims.end(), + strides.begin(), + make_identity(0), + std::plus<>{}, + [&](const T& l, const T& s) { return (l - one) * s; }) + + one; + } + + template + static bool compute_skips(const std::vector& dims, const std::vector& strides) + { + if(compute_elements(dims) == make_identity(1)) + return false; + auto one = make_identity(1); + return std::none_of( + strides.begin(), strides.end(), [&](const auto& x) { return x == one; }); + } + + template + static bool compute_packed(const std::vector& dims, const std::vector& strides) + { + return not compute_skips(dims, strides) and + compute_elements(dims) == compute_element_space(dims, strides); + } + + template + static bool compute_broadcasted(const std::vector& strides) + { + auto zero = make_identity(0); + return std::any_of( + strides.begin(), strides.end(), [&](const auto& x) { return x == zero; }); + } + + template + static bool compute_scalar(const std::vector& strides) + { + auto zero = make_identity(0); + return std::accumulate(strides.begin(), strides.end(), zero) == zero; + } + + template + static bool is_sorted_strides(const std::vector& strides) + { + if constexpr(std::is_same{}) + { + std::vector concrete(strides.size()); + std::transform(strides.begin(), strides.end(), concrete.begin(), [](const auto& s) { + return static_cast(s.eval_interval().max); + }); + return std::is_sorted(concrete.rbegin(), concrete.rend()); + } + else + { + return std::is_sorted(strides.rbegin(), strides.rend()); + } + } + + template + static bool compute_transposed(const std::vector& strides) + { + if(compute_broadcasted(strides)) + { + std::vector s; + s.reserve(strides.size()); + auto zero = make_identity(0); + std::copy_if(strides.begin(), strides.end(), std::back_inserter(s), [&](const auto& x) { + return x != zero; + }); + return not is_sorted_strides(s); + } + return not is_sorted_strides(strides); + } + + bool is_packed() const + { + if(not m_dyn_dims.empty()) + { + if(m_dyn_strides.empty()) + return false; + return compute_packed(sym_dims(), m_dyn_strides); + } + return compute_packed(m_lens, m_strides); + } + + bool is_broadcasted() const + { + if(not m_dyn_dims.empty()) + { + if(m_dyn_strides.empty()) + return false; + return compute_broadcasted(m_dyn_strides); + } + return compute_broadcasted(m_strides); + } + + bool is_transposed() const + { + if(not m_dyn_dims.empty()) + { + if(m_dyn_strides.empty()) + return false; + return compute_transposed(m_dyn_strides); + } + return compute_transposed(m_strides); + } + + bool is_scalar() const + { + if(not m_dyn_dims.empty()) + { + if(m_dyn_strides.empty()) + return false; + return compute_scalar(m_dyn_strides); + } + return compute_scalar(m_strides); + } + std::size_t element_space() const { if(not m_dyn_dims.empty()) @@ -181,7 +330,6 @@ struct shape_impl return std::accumulate( maxes.begin(), maxes.end(), std::size_t{1}, [&](std::size_t x, std::size_t y) { - // overflow check and clip if(x != 0 and y > max_val / x) { return max_val; @@ -191,15 +339,7 @@ struct shape_impl } assert(m_lens.size() == m_strides.size()); - if(m_lens.empty()) - return 0; - return std::inner_product(m_lens.begin(), - m_lens.end(), - m_strides.begin(), - std::size_t{0}, - std::plus{}, - [](std::size_t l, std::size_t s) { return (l - 1) * s; }) + - 1; + return compute_element_space(m_lens, m_strides); } std::size_t elements() const @@ -210,10 +350,7 @@ struct shape_impl } assert(m_lens.size() == m_strides.size()); - if(m_lens.empty()) - return 0; - return std::accumulate( - m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies()); + return compute_elements(m_lens); } std::size_t get_index(size_t i) const @@ -262,13 +399,10 @@ struct shape_impl return ret; } - // Does the shape skip over elements? bool skips() const { assert(m_lens.size() == m_strides.size()); - if(elements() == 1) - return false; - return std::none_of(m_strides.begin(), m_strides.end(), [](auto x) { return x == 1; }); + return compute_skips(m_lens, m_strides); } std::shared_ptr copy() const { return std::make_shared(*this); } @@ -433,16 +567,34 @@ shape::shape(const std::vector& subs) : impl(std::make_shared shape::shape(std::shared_ptr pimpl) : impl(std::move(pimpl)) {} +template +static shape +from_permutation_impl(shape::type_t t, const Dims& dims, const std::vector& perm) +{ + auto reordered = reorder_dims(dims, perm); + return reorder_shape({t, reordered}, invert_permutation(perm)); +} + shape shape::from_permutation(type_t t, const std::vector& l, const std::vector& perm) { - auto new_lens = reorder_dims(l, perm); - shape result = reorder_shape({t, new_lens}, invert_permutation(perm)); + shape result = from_permutation_impl(t, l, perm); assert(result.lens() == l); return result; } +shape shape::from_permutation(type_t t, + const std::vector& dds, + const std::vector& perm) +{ + if(std::any_of(dds.begin(), dds.end(), [](const auto& dd) { return not dd.is_symbolic(); })) + MIGRAPHX_THROW("FROM_PERMUTATION: non-symbolic dynamic dimensions not supported"); + shape result = from_permutation_impl(t, dds, perm); + assert(result.dyn_dims() == dds); + return result; +} + shape::type_t shape::type() const { return impl->m_type; } const std::vector& shape::lens() const @@ -585,76 +737,50 @@ std::size_t shape::single(const std::vector& idx) const bool shape::packed() const { - if(this->dynamic()) - { + if(this->dynamic() and not this->symbolic()) return false; - } - return this->sub_shapes().empty() and not impl->skips() and - this->elements() == this->element_space(); + return this->sub_shapes().empty() and impl->is_packed(); } bool shape::transposed() const { - if(this->dynamic()) - { + if(this->dynamic() and not this->symbolic()) return false; - } - if(this->broadcasted()) - { - // TODO: Use a filter_iterator instead - std::vector s; - s.reserve(this->strides().size()); - std::copy_if(this->strides().begin(), - this->strides().end(), - std::back_inserter(s), - [](std::size_t x) { return x != 0; }); - return not std::is_sorted(s.rbegin(), s.rend()); - } - else - { - return not std::is_sorted(this->strides().rbegin(), this->strides().rend()); - } + return impl->is_transposed(); } bool shape::broadcasted() const { - if(this->dynamic()) - { + if(this->dynamic() and not this->symbolic()) return false; - } - assert(this->lens().size() == this->strides().size()); - return std::any_of( - this->strides().begin(), this->strides().end(), [](auto x) { return x == 0; }); + return impl->is_broadcasted(); } bool shape::scalar() const { - if(this->dynamic()) - { + if(this->dynamic() and not this->symbolic()) return false; - } - assert(this->lens().size() == this->strides().size()); - // if any stride > 0, then accumulate will return false - return this->sub_shapes().empty() and - std::accumulate(this->strides().begin(), this->strides().end(), std::size_t(0)) == 0; + return this->sub_shapes().empty() and impl->is_scalar(); } bool shape::standard() const { return impl->m_standard; } shape shape::normalize_standard() const { - if(this->standard()) - return {this->type(), this->lens()}; - else + if(not this->standard()) return *this; + if(this->symbolic()) + return {this->type(), this->dyn_dims()}; + return {this->type(), this->lens()}; } shape shape::as_standard() const { + if(this->symbolic()) + return {this->type(), this->dyn_dims()}; if(not this->dynamic()) return {this->type(), this->lens()}; - else - return *this; + return *this; } shape shape::with_lens(type_t t, const std::vector& l) const @@ -677,6 +803,20 @@ shape shape::with_lens(const std::vector& l) const return this->with_lens(this->type(), l); } +shape shape::with_lens(type_t t, const std::vector& dds) const +{ + if(this->dynamic() and not this->symbolic()) + MIGRAPHX_THROW("SHAPE: with_lens() called on non-symbolic dynamic shape"); + assert(dds.size() == this->ndim()); + auto perm = find_permutation(*this); + return shape::from_permutation(t, dds, perm); +} + +shape shape::with_lens(const std::vector& dds) const +{ + return this->with_lens(this->type(), dds); +} + shape shape::with_type(type_t t) const { auto c = impl->copy(); @@ -798,13 +938,7 @@ const std::vector& shape::dyn_dims() const return impl->m_dyn_dims; } -bool shape::symbolic() const -{ - return not impl->m_dyn_dims.empty() and - std::all_of(impl->m_dyn_dims.begin(), impl->m_dyn_dims.end(), [](const auto& dd) { - return dd.is_symbolic(); - }); -} +bool shape::symbolic() const { return impl->all_dims_symbolic(); } const std::vector& shape::dyn_strides() const { return impl->m_dyn_strides; } @@ -822,6 +956,8 @@ std::vector> shape::opt_lens() const { return impl->opt_le bool shape::dynamic_dimension::is_fixed() const { + if(sym_expr.is_literal()) + return true; auto i = this->get_interval(); return i.min == i.max; } @@ -833,7 +969,7 @@ bool shape::dynamic_dimension::has_optimal() const { return not this->get_optima #define MIGRAPHX_SHAPE_DYN_DIM_IMPLEMENT_OP(binary_op, assign_op) \ shape::dynamic_dimension& shape::dynamic_dimension::operator assign_op(const std::size_t& x) \ { \ - return *this assign_op dynamic_dimension{x, x}; \ + return *this assign_op dynamic_dimension{sym::lit(x)}; \ } \ shape::dynamic_dimension operator binary_op( \ const shape::dynamic_dimension& x, const std::size_t& y) \ @@ -845,7 +981,7 @@ bool shape::dynamic_dimension::has_optimal() const { return not this->get_optima shape::dynamic_dimension operator binary_op( \ const std::size_t& x, const shape::dynamic_dimension& y) \ { \ - return shape::dynamic_dimension{x, x} binary_op y; \ + return shape::dynamic_dimension{sym::lit(x)} binary_op y; \ } \ shape::dynamic_dimension operator binary_op( \ const shape::dynamic_dimension& x, const shape::dynamic_dimension& y) \ @@ -930,97 +1066,138 @@ static void merge_optimals(std::set& optimals, } } +// Arithmetic semantics: symbolic + symbolic = symbolic, +// range + range = range, range + symbolic = range. +template +static shape::dynamic_dimension& apply_op(shape::dynamic_dimension& lhs, + const shape::dynamic_dimension& rhs, + SymOp sym_op, + RangeOp range_op) +{ + auto lhs_sym = lhs.sym_expr; + auto rhs_sym = rhs.sym_expr; + auto result_sym = sym_op(lhs_sym, rhs_sym); + if(not result_sym.empty()) + { + lhs.sym_expr = result_sym; + lhs.range = std::nullopt; + lhs.optimals = std::nullopt; + } + else + { + // Materialize symbolic operands as range-based shapes so that + // arithmetic between symbolic and range-based dimensions works. + auto to_range = [](const shape::dynamic_dimension& d) { + auto iv = d.get_interval(); + return shape::dynamic_dimension{iv.min, iv.max, d.get_optimals()}; + }; + auto lhs_range = lhs.is_symbolic() ? to_range(lhs) : lhs; + auto rhs_range = rhs.is_symbolic() ? to_range(rhs) : rhs; + range_op(lhs_range, rhs_range); + lhs = lhs_range; + } + return lhs; +} + shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const shape::dynamic_dimension& x) { - auto lhs_sym = sym_expr; - auto rhs_sym = x.sym_expr; - auto lhs_fixed = this->is_fixed(); - auto rhs_fixed = x.is_fixed(); - auto lhs_min = range.min; - range.min = range.min + x.range.min; - range.max = (range.max > std::numeric_limits::max() - x.range.max) - ? std::numeric_limits::max() - : range.max + x.range.max; - merge_optimals( - optimals, - lhs_fixed, - x.optimals, - rhs_fixed, - [&](auto o) { return o + x.range.min; }, - [&](auto o) { return o + lhs_min; }); - sym_expr = lhs_sym + rhs_sym; - normalize_sym(); - return *this; + return apply_op( + *this, + x, + [](auto& a, auto& b) { return a + b; }, + [](auto& lhs, const auto& rhs) { + auto lhs_fixed = lhs.is_fixed(); + auto rhs_fixed = rhs.is_fixed(); + auto lhs_min = lhs.range->min; + lhs.range->min += rhs.range->min; + lhs.range->max = + (lhs.range->max > std::numeric_limits::max() - rhs.range->max) + ? std::numeric_limits::max() + : lhs.range->max + rhs.range->max; + merge_optimals( + *lhs.optimals, + lhs_fixed, + *rhs.optimals, + rhs_fixed, + [&](auto o) { return o + rhs.range->min; }, + [&](auto o) { return o + lhs_min; }); + }); } shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const shape::dynamic_dimension& x) { - auto lhs_sym = sym_expr; - auto rhs_sym = x.sym_expr; - auto lhs_fixed = this->is_fixed(); - auto rhs_fixed = x.is_fixed(); - auto lhs_min = range.min; - range.min = (range.min > x.range.max) ? range.min - x.range.max : 0; - range.max = (range.max > x.range.min) ? range.max - x.range.min : 0; - merge_optimals( - optimals, - lhs_fixed, - x.optimals, - rhs_fixed, - [&](auto o) { return (o > x.range.min) ? o - x.range.min : 0; }, - [&](auto o) { return (lhs_min > o) ? lhs_min - o : 0; }); - sym_expr = lhs_sym - rhs_sym; - normalize_sym(); - return *this; + return apply_op( + *this, + x, + [](auto& a, auto& b) { return a - b; }, + [](auto& lhs, const auto& rhs) { + auto lhs_fixed = lhs.is_fixed(); + auto rhs_fixed = rhs.is_fixed(); + auto lhs_min = lhs.range->min; + lhs.range->min = + (lhs.range->min > rhs.range->max) ? lhs.range->min - rhs.range->max : 0; + lhs.range->max = + (lhs.range->max > rhs.range->min) ? lhs.range->max - rhs.range->min : 0; + merge_optimals( + *lhs.optimals, + lhs_fixed, + *rhs.optimals, + rhs_fixed, + [&](auto o) { return (o > rhs.range->min) ? o - rhs.range->min : std::size_t{0}; }, + [&](auto o) { return (lhs_min > o) ? lhs_min - o : std::size_t{0}; }); + }); } shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const shape::dynamic_dimension& x) { - auto lhs_sym = sym_expr; - auto rhs_sym = x.sym_expr; - auto lhs_fixed = this->is_fixed(); - auto rhs_fixed = x.is_fixed(); - auto lhs_min = range.min; - range.min = range.min * x.range.min; - auto safe_mul = [](std::size_t a, std::size_t b) -> std::size_t { - if(b == 0) - return 0; - if(a > std::numeric_limits::max() / b) - return std::numeric_limits::max(); - return a * b; - }; - range.max = safe_mul(range.max, x.range.max); - merge_optimals( - optimals, - lhs_fixed, - x.optimals, - rhs_fixed, - [&](auto o) { return o * x.range.min; }, - [&](auto o) { return o * lhs_min; }); - sym_expr = lhs_sym * rhs_sym; - normalize_sym(); - return *this; + return apply_op( + *this, + x, + [](auto& a, auto& b) { return a * b; }, + [](auto& lhs, const auto& rhs) { + auto lhs_fixed = lhs.is_fixed(); + auto rhs_fixed = rhs.is_fixed(); + auto lhs_min = lhs.range->min; + auto safe_mul = [](std::size_t a, std::size_t b) -> std::size_t { + if(b == 0) + return 0; + if(a > std::numeric_limits::max() / b) + return std::numeric_limits::max(); + return a * b; + }; + lhs.range->min = lhs.range->min * rhs.range->min; + lhs.range->max = safe_mul(lhs.range->max, rhs.range->max); + merge_optimals( + *lhs.optimals, + lhs_fixed, + *rhs.optimals, + rhs_fixed, + [&](auto o) { return o * rhs.range->min; }, + [&](auto o) { return o * lhs_min; }); + }); } shape::dynamic_dimension& shape::dynamic_dimension::operator/=(const shape::dynamic_dimension& x) { - auto lhs_sym = sym_expr; - auto rhs_sym = x.sym_expr; - auto lhs_fixed = this->is_fixed(); - auto rhs_fixed = x.is_fixed(); - auto lhs_min = range.min; - range.min = (x.range.max == 0) ? 0 : range.min / x.range.max; - range.max = (x.range.min == 0) ? std::numeric_limits::max() : range.max / x.range.min; - merge_optimals( - optimals, - lhs_fixed, - x.optimals, - rhs_fixed, - [&](auto o) { return (x.range.min == 0) ? std::size_t{0} : o / x.range.min; }, - [&](auto o) { return (o == 0) ? std::size_t{0} : lhs_min / o; }); - sym_expr = lhs_sym / rhs_sym; - normalize_sym(); - return *this; + return apply_op( + *this, + x, + [](auto& a, auto& b) { return a / b; }, + [](auto& lhs, const auto& rhs) { + auto lhs_fixed = lhs.is_fixed(); + auto rhs_fixed = rhs.is_fixed(); + auto lhs_min = lhs.range->min; + lhs.range->min = (rhs.range->max == 0) ? 0 : lhs.range->min / rhs.range->max; + lhs.range->max = (rhs.range->min == 0) ? std::numeric_limits::max() + : lhs.range->max / rhs.range->min; + merge_optimals( + *lhs.optimals, + lhs_fixed, + *rhs.optimals, + rhs_fixed, + [&](auto o) { return (rhs.range->min == 0) ? std::size_t{0} : o / rhs.range->min; }, + [&](auto o) { return (o == 0) ? std::size_t{0} : lhs_min / o; }); + }); } bool operator==(const shape& x, const shape& y) diff --git a/src/sym.cpp b/src/sym.cpp index fd552980727..6a376dded23 100644 --- a/src/sym.cpp +++ b/src/sym.cpp @@ -65,6 +65,9 @@ struct integer_data struct symbol_data { std::string name; + int64_t min; + int64_t max; + std::set optimals; }; struct add_data { @@ -127,7 +130,11 @@ static std::size_t compute_hash(const expr_data& d) return std::visit( overloaded{ [&](const integer_data& p) { return hash_combine(h, std::hash{}(p.value)); }, - [&](const symbol_data& p) { return hash_combine(h, std::hash{}(p.name)); }, + [&](const symbol_data& p) { + auto h2 = hash_combine(h, std::hash{}(p.name)); + h2 = hash_combine(h2, std::hash{}(p.min)); + return hash_combine(h2, std::hash{}(p.max)); + }, [&](const add_data& p) { return hash_combine(hash_combine(h, std::hash{}(p.constant)), hash_ordered_map(p.terms)); @@ -183,7 +190,14 @@ static int compare_expr(const expr_ptr& a, const expr_ptr& b) }, [&](const symbol_data& da) { const auto& db = std::get(b->data); - return da.name.compare(db.name); + int c = da.name.compare(db.name); + if(c != 0) + return c; + if(da.min != db.min) + return da.min < db.min ? -1 : 1; + if(da.max != db.max) + return da.max < db.max ? -1 : 1; + return 0; }, [&](const add_data& da) { const auto& db = std::get(b->data); @@ -264,7 +278,11 @@ static expr_ptr make_integer(int64_t n) return make_node(integer_data{n}); } -static expr_ptr make_symbol(const std::string& name) { return make_node(symbol_data{name}); } +static expr_ptr +make_symbol(const std::string& name, int64_t min, int64_t max, std::set optimals = {}) +{ + return make_node(symbol_data{name, min, max, std::move(optimals)}); +} static expr_ptr make_add(const expr_ptr& a, const expr_ptr& b); static expr_ptr make_sub(const expr_ptr& a, const expr_ptr& b); @@ -384,20 +402,15 @@ static expr_ptr build_mul(int64_t coefficient, factor_map factors) static expr_ptr make_mul(const expr_ptr& a, const expr_ptr& b) { - if(holds(a) and holds(b)) + if(holds(b)) { - int64_t n = get_integer(a); - if(n == 0) - return make_integer(0); - if(n == 1) - return b; - const auto& d = get_add(b); - term_map scaled; + const auto& d = get_add(b); + expr_ptr result = make_mul(a, make_integer(d.constant)); for(const auto& [term, coeff] : d.terms) - scaled[term] = coeff * n; - return build_add(d.constant * n, std::move(scaled)); + result = make_add(result, make_mul(a, make_mul(make_integer(coeff), term))); + return result; } - if(holds(b) and holds(a)) + if(holds(a)) return make_mul(b, a); auto pa = extract_mul(a); @@ -566,41 +579,94 @@ static expr_ptr substitute(const expr_ptr& e, const subs_map& bindings) e->data); } -static int64_t eval_direct(const expr_ptr& e, const binding_map& bindings) +template +static int64_t eval_impl(const expr_ptr& e, const SymbolResolver& resolve_sym) { return std::visit(overloaded{[](const integer_data& d) -> int64_t { return d.value; }, - [&](const symbol_data& d) -> int64_t { - auto it = bindings.find(e); - if(it != bindings.end()) - return it->second; - MIGRAPHX_THROW("sym::expr::eval_uint: unbound symbol '" + - d.name + "'"); - }, + [&](const symbol_data& d) -> int64_t { return resolve_sym(e, d); }, [&](const add_data& d) -> int64_t { int64_t sum = d.constant; for(const auto& [term, coeff] : d.terms) - sum += coeff * eval_direct(term, bindings); + sum += coeff * eval_impl(term, resolve_sym); return sum; }, [&](const mul_data& d) -> int64_t { int64_t prod = d.coefficient; for(const auto& [base, exp] : d.factors) { - int64_t val = eval_direct(base, bindings); + int64_t val = eval_impl(base, resolve_sym); for(int64_t i = 0; i < exp; ++i) prod *= val; } return prod; }, [&](const tdiv_data& d) -> int64_t { - auto denom = eval_direct(d.denominator, bindings); + auto denom = eval_impl(d.denominator, resolve_sym); if(denom == 0) - MIGRAPHX_THROW("sym::expr::eval_uint: division by zero"); - return eval_direct(d.numerator, bindings) / denom; + MIGRAPHX_THROW("sym::expr: division by zero during eval"); + return eval_impl(d.numerator, resolve_sym) / denom; }}, e->data); } +static int64_t eval_direct(const expr_ptr& e, const binding_map& bindings) +{ + return eval_impl(e, [&](const expr_ptr& node, const symbol_data& d) -> int64_t { + auto it = bindings.find(node); + if(it != bindings.end()) + return it->second; + MIGRAPHX_THROW("sym::expr::eval_uint: unbound symbol '" + d.name + "'"); + }); +} + +// Walk the expression tree and collect each unique symbol node and its data. +static void collect_symbols(const expr_ptr& e, + std::vector>& result, + std::set& seen) +{ + std::visit(overloaded{[](const integer_data&) {}, + [&](const symbol_data& d) { + if(seen.insert(d.name).second) + result.push_back({e, d}); + }, + [&](const add_data& d) { + for(const auto& [term, coeff] : d.terms) + collect_symbols(term, result, seen); + }, + [&](const mul_data& d) { + for(const auto& [base, exp] : d.factors) + collect_symbols(base, result, seen); + }, + [&](const tdiv_data& d) { + collect_symbols(d.numerator, result, seen); + collect_symbols(d.denominator, result, seen); + }}, + e->data); +} + +// Recursively enumerate all 2^k combinations of symbol {min, max} values, +// evaluating the expression at each and tracking the global min and max. +static void eval_bounds_impl(const expr_ptr& node, + const std::vector>& syms, + std::size_t idx, + binding_map& bindings, + int64_t& lo, + int64_t& hi) +{ + if(idx == syms.size()) + { + auto v = eval_direct(node, bindings); + lo = std::min(lo, v); + hi = std::max(hi, v); + return; + } + const auto& [sym_node, sd] = syms[idx]; + bindings[sym_node] = sd.min; + eval_bounds_impl(node, syms, idx + 1, bindings, lo, hi); + bindings[sym_node] = sd.max; + eval_bounds_impl(node, syms, idx + 1, bindings, lo, hi); +} + // =================================================================== // Section 7: Pretty-printer // =================================================================== @@ -734,7 +800,7 @@ static expr_ptr parse_primary(const char*& p) name += *p; ++p; } - return make_symbol(name); + return make_symbol(name, 1, 1); } if(*p == '(') { @@ -834,6 +900,8 @@ expr::expr(std::shared_ptr pi) : p(std::move(pi)) {} bool expr::empty() const { return p == nullptr; } +bool expr::is_literal() const { return p != nullptr and holds(p->node); } + std::size_t expr::hash() const { if(empty()) @@ -865,6 +933,87 @@ std::size_t expr::eval_uint(const std::unordered_map& symbol_ return v; } +// Compute both the minimum and maximum value of an expression by +// evaluating at all 2^k vertices of the symbol bound ranges. +// +// Assumptions: +// 1. Expressions are monotonic in each variable independently, so the +// global extrema always occur at vertices of the variable ranges. +// 2. Expressions represent dimension sizes or strides: sums, products, +// and integer divisions of positive-valued symbols. Non-monotonic +// expressions (e.g. polynomials with interior extrema) are not +// expected. +// 3. The number of unique symbols per expression is small (typically +// 1-3), making the 2^k evaluation cost negligible. +interval expr::eval_interval() const +{ + if(empty()) + MIGRAPHX_THROW("sym::expr::eval_interval: empty expression"); + std::vector> syms; + std::set seen; + collect_symbols(p->node, syms, seen); + if(syms.empty()) + { + auto v = eval_direct(p->node, {}); + return {v, v}; + } + int64_t lo = INT64_MAX; + int64_t hi = INT64_MIN; + binding_map bindings; + eval_bounds_impl(p->node, syms, 0, bindings, lo, hi); + return {lo, hi}; +} + +// Recursively enumerate the Cartesian product of symbol optimals, +// evaluating the expression at each combination without materializing +// intermediate binding maps. +static void eval_optimals_impl(const expr_ptr& node, + const std::vector>& syms, + std::size_t idx, + binding_map& bindings, + std::set& result) +{ + if(idx == syms.size()) + { + result.insert(eval_direct(node, bindings)); + return; + } + const auto& [sym_node, sd] = syms[idx]; + for(auto oval : sd.optimals) + { + bindings[sym_node] = oval; + eval_optimals_impl(node, syms, idx + 1, bindings, result); + } +} + +// Compute the set of optimal values for the expression by evaluating it +// at every combination of each symbol's optimal values (Cartesian product). +// +// For a single variable: var("n", {1, 8}, {2, 4}) => optimals = {2, 4} +// For a compound expr: 2*n + 1 where n has optimals {2, 4} => {5, 9} +// For multiple variables: n + m where n={2,4}, m={3,6} => {5, 8, 7, 10} +// +// Returns empty if any symbol in the expression has no optimals. +std::set expr::eval_optimals() const +{ + if(empty()) + return {}; + std::vector> syms; + std::set seen; + collect_symbols(p->node, syms, seen); + auto has_optimals = std::all_of( + syms.begin(), syms.end(), [](const auto& s) { return not s.second.optimals.empty(); }); + if(syms.empty() or not has_optimals) + return {}; + + std::set signed_result; + binding_map bindings; + eval_optimals_impl(p->node, syms, 0, bindings, signed_result); + if(std::any_of(signed_result.begin(), signed_result.end(), [](int64_t v) { return v < 0; })) + MIGRAPHX_THROW("sym::expr::eval_optimals: negative optimal value"); + return {signed_result.begin(), signed_result.end()}; +} + expr expr::subs(const std::unordered_map& symbol_map) const { if(empty()) @@ -920,6 +1069,45 @@ bool operator==(const expr& a, const expr& b) bool operator!=(const expr& a, const expr& b) { return not(a == b); } +// Semantic strict less-than for symbolic expressions using interval arithmetic. +// +// Assumptions: +// - All symbols have positive intervals [min, max] where 1 <= min <= max. +// - Expressions are monotonically non-decreasing in each variable, which +// holds for dimension/stride arithmetic (sums and products of positive +// terms). This lets us bound the range of (b - a) by evaluating at the +// interval endpoints. +// +// Algorithm: +// Compute diff = b - a, then evaluate diff at the lower and upper bounds +// of every symbol to obtain [lo, hi]. If the entire interval is strictly +// positive (lo > 0) then a < b for all possible symbol values. If the +// interval is non-positive (hi <= 0) then a >= b. Otherwise the comparison +// is undetermined and we throw. +// +// Examples (all symbols default to [1, 1]): +// n < 2*n => diff = n, lo = 1, hi = 1 => true (strictly positive) +// 2*n < n => diff = -n, lo = -1, hi = -1 => false (non-positive) +// k < m*k => diff = k(m-1), lo = 0, hi = 0 => false (not strictly positive) +// +// With explicit bounds, e.g. n in [2, 10]: +// n < 3 => diff = 3 - n, lo = -7, hi = 1 => undetermined (throws) +// n < 11 => diff = 11 - n, lo = 1, hi = 9 => true +bool operator<(const expr& a, const expr& b) +{ + if(a.empty() and b.empty()) + return false; + if(a.empty() or b.empty()) + MIGRAPHX_THROW("sym::expr: cannot compare empty expression"); + auto ival = (b - a).eval_interval(); + if(ival.min > 0) + return true; + if(ival.max <= 0) + return false; + MIGRAPHX_THROW("sym::expr: comparison undetermined for: " + print_expr(a.p->node) + " < " + + print_expr(b.p->node)); +} + std::ostream& operator<<(std::ostream& os, const expr& e) { if(not e.empty()) @@ -927,11 +1115,16 @@ std::ostream& operator<<(std::ostream& os, const expr& e) return os; } -expr var(const std::string& name) +expr var(const std::string& name, interval bounds, std::set optimals) { if(name.empty()) MIGRAPHX_THROW("sym::var: variable name must not be empty"); - return {std::make_shared(make_symbol(name))}; + if(bounds.min > bounds.max) + MIGRAPHX_THROW("sym::var: variable interval must satisfy min <= max"); + if(bounds.min < 1) + MIGRAPHX_THROW("sym::var: variable interval must satisfy min >= 1"); + return {std::make_shared( + make_symbol(name, bounds.min, bounds.max, std::move(optimals)))}; } expr lit(int64_t n) { return {std::make_shared(make_integer(n))}; } @@ -955,6 +1148,17 @@ static value node_to_value(const expr_ptr& e) value r; r["type"] = "sym"; r["name"] = d.name; + r["min"] = d.min; + r["max"] = d.max; + if(not d.optimals.empty()) + { + value opts = value::array{}; + std::transform(d.optimals.begin(), + d.optimals.end(), + std::back_inserter(opts), + [](auto o) -> value { return o; }); + r["optimals"] = opts; + } return r; }, [](const add_data& d) -> value { @@ -1006,7 +1210,15 @@ static expr_ptr node_from_value(const value& v) } else if(type == "sym") { - return make_symbol(v.at("name").get_string()); + auto sym_min = v.contains("min") ? v.at("min").to() : int64_t{1}; + auto sym_max = v.contains("max") ? v.at("max").to() : int64_t{1}; + std::set sym_opts; + if(v.contains("optimals")) + std::transform(v.at("optimals").begin(), + v.at("optimals").end(), + std::inserter(sym_opts, sym_opts.end()), + [](const auto& o) { return o.template to(); }); + return make_symbol(v.at("name").get_string(), sym_min, sym_max, std::move(sym_opts)); } else if(type == "add") { diff --git a/test/shape_test.cpp b/test/shape_test.cpp index 6933ef85747..d408872b880 100644 --- a/test/shape_test.cpp +++ b/test/shape_test.cpp @@ -34,6 +34,7 @@ #include #include "test.hpp" +using dd = migraphx::shape::dynamic_dimension; using migraphx::sym::lit; using migraphx::sym::var; @@ -1285,131 +1286,122 @@ TEST_CASE(shape_same_lens_static_dynamic) TEST_CASE(test_dd_symbolic_add_size_t) { - auto n = var("n"); - migraphx::shape::dynamic_dimension dd{1, 8, {4}, n}; - dd += 2; - EXPECT(dd.min == 3); - EXPECT(dd.max == 10); - EXPECT(dd.sym_expr == n + 2); + auto n = var("n", {1, 8}); + dd d{n}; + d += 2; + EXPECT(d.get_interval().min == 3); + EXPECT(d.get_interval().max == 10); + EXPECT(d.sym_expr == n + 2); } TEST_CASE(test_dd_symbolic_sub_size_t) { - auto n = var("n"); - migraphx::shape::dynamic_dimension dd{3, 8, {4}, n}; - dd -= 1; - EXPECT(dd.min == 2); - EXPECT(dd.max == 7); - EXPECT(dd.sym_expr == n - 1); + auto n = var("n", {3, 8}); + dd d{n}; + d -= 1; + EXPECT(d.get_interval().min == 2); + EXPECT(d.get_interval().max == 7); + EXPECT(d.sym_expr == n - 1); } TEST_CASE(test_dd_symbolic_mul_size_t) { - auto n = var("n"); - migraphx::shape::dynamic_dimension dd{1, 8, {4}, n}; - dd *= 3; - EXPECT(dd.min == 3); - EXPECT(dd.max == 24); - EXPECT(dd.sym_expr == n * 3); + auto n = var("n", {1, 8}); + dd d{n}; + d *= 3; + EXPECT(d.get_interval().min == 3); + EXPECT(d.get_interval().max == 24); + EXPECT(d.sym_expr == n * 3); } TEST_CASE(test_dd_symbolic_div_size_t) { - auto n = var("n"); - migraphx::shape::dynamic_dimension dd{4, 16, {8}, n}; - dd /= 2; - EXPECT(dd.min == 2); - EXPECT(dd.max == 8); - EXPECT(dd.sym_expr == n / 2); + auto n = var("n", {4, 16}); + dd d{n}; + d /= 2; + EXPECT(d.get_interval().min == 2); + EXPECT(d.get_interval().max == 8); + EXPECT(d.sym_expr == n / 2); } TEST_CASE(test_dd_symbolic_add_dd) { - auto n = var("n"); - auto c = var("c"); - migraphx::shape::dynamic_dimension a{1, 8, {}, n}; - migraphx::shape::dynamic_dimension b{2, 4, {}, c}; - auto r = a + b; - EXPECT(r.min == 3); - EXPECT(r.max == 12); + auto n = var("n", {1, 8}); + auto c = var("c", {2, 4}); + auto r = dd{n} + dd{c}; + EXPECT(r.get_interval().min == 3); + EXPECT(r.get_interval().max == 12); EXPECT(r.sym_expr == n + c); } TEST_CASE(test_dd_symbolic_sub_dd) { - auto n = var("n"); - auto k = var("k"); - migraphx::shape::dynamic_dimension a{4, 16, {}, n}; - migraphx::shape::dynamic_dimension b{1, 4, {}, k}; - auto r = a - b; - EXPECT(r.min == 0); - EXPECT(r.max == 15); + auto n = var("n", {4, 16}); + auto k = var("k", {1, 4}); + auto r = dd{n} - dd{k}; + EXPECT(r.get_interval().min == 0); + EXPECT(r.get_interval().max == 15); EXPECT(r.sym_expr == n - k); } TEST_CASE(test_dd_symbolic_mul_dd) { - auto n = var("n"); - auto c = var("c"); - migraphx::shape::dynamic_dimension a{1, 8, {}, n}; - migraphx::shape::dynamic_dimension b{2, 4, {}, c}; - auto r = a * b; - EXPECT(r.min == 2); - EXPECT(r.max == 32); + auto n = var("n", {1, 8}); + auto c = var("c", {2, 4}); + auto r = dd{n} * dd{c}; + EXPECT(r.get_interval().min == 2); + EXPECT(r.get_interval().max == 32); EXPECT(r.sym_expr == n * c); } TEST_CASE(test_dd_symbolic_div_dd) { - auto n = var("n"); - auto k = var("k"); - migraphx::shape::dynamic_dimension a{4, 16, {}, n}; - migraphx::shape::dynamic_dimension b{2, 4, {}, k}; - auto r = a / b; - EXPECT(r.min == 1); - EXPECT(r.max == 8); + auto n = var("n", {4, 16}); + auto k = var("k", {2, 4}); + auto r = dd{n} / dd{k}; + EXPECT(r.get_interval().min == 1); + EXPECT(r.get_interval().max == 8); EXPECT(r.sym_expr == n / k); } -TEST_CASE(test_dd_symbolic_plus_fixed) +TEST_CASE(test_dd_symbolic_plus_range_fixed) { - auto n = var("n"); - migraphx::shape::dynamic_dimension a{1, 8, {}, n}; - migraphx::shape::dynamic_dimension b{3, 3}; + auto n = var("n", {1, 8}); + dd a{n}; + dd b{3, 3}; auto r = a + b; - EXPECT(not r.sym_expr.empty()); - EXPECT(r.sym_expr == n + 3); - EXPECT(r.min == 4); - EXPECT(r.max == 11); + EXPECT(r.sym_expr.empty()); + EXPECT(r.get_interval().min == 4); + EXPECT(r.get_interval().max == 11); } TEST_CASE(test_dd_nonfixed_nonsymbolic_plus_symbolic_drops_sym) { - auto c = var("c"); - migraphx::shape::dynamic_dimension a{1, 8, {}}; - migraphx::shape::dynamic_dimension b{2, 4, {}, c}; + auto c = var("c", {2, 4}); + dd a{1, 8}; + dd b{c}; auto r = a + b; EXPECT(r.sym_expr.empty()); - EXPECT(r.min == 3); - EXPECT(r.max == 12); + EXPECT(r.get_interval().min == 3); + EXPECT(r.get_interval().max == 12); } TEST_CASE(test_dd_nonsymbolic_remains_nonsymbolic) { - migraphx::shape::dynamic_dimension a{1, 8, {}}; - migraphx::shape::dynamic_dimension b{2, 4, {}}; + dd a{1, 8}; + dd b{2, 4}; auto r = a + b; EXPECT(r.sym_expr.empty()); } TEST_CASE(test_dd_equality_with_sym) { - auto n = var("n"); - auto c = var("c"); - migraphx::shape::dynamic_dimension a{1, 8, {}, n}; - migraphx::shape::dynamic_dimension b{1, 8, {}, n}; - migraphx::shape::dynamic_dimension d2{1, 8, {}, c}; - migraphx::shape::dynamic_dimension d{1, 8, {}}; + auto n = var("n", {1, 8}); + auto c = var("c", {1, 8}); + dd a{n}; + dd b{n}; + dd d2{c}; + dd d{1, 8}; EXPECT(a == b); EXPECT(a != d2); EXPECT(a != d); @@ -1417,10 +1409,10 @@ TEST_CASE(test_dd_equality_with_sym) TEST_CASE(test_symbolic_shape_construction) { - auto n = var("n"); + auto n = var("n", {1, 8}); migraphx::shape sh{migraphx::shape::float_type, - {{1, 8, {}, n}, {3, 3}, {224, 224}}, - {n * lit(3) * lit(224), lit(224), lit(1)}}; + {dd{n}, dd{lit(3)}, dd{lit(224)}}, + {n * 3 * 224, lit(224), lit(1)}}; EXPECT(sh.dynamic()); EXPECT(sh.symbolic()); EXPECT(sh.dyn_dims().size() == 3); @@ -1429,9 +1421,9 @@ TEST_CASE(test_symbolic_shape_construction) TEST_CASE(test_symbolic_stride_auto_compute) { - auto n = var("n"); - auto s = var("s"); - migraphx::shape sh{migraphx::shape::float_type, {{1, 8, {}, n}, {1, 16, {}, s}, {4, 4}}}; + auto n = var("n", {1, 8}); + auto s = var("s", {1, 16}); + migraphx::shape sh{migraphx::shape::float_type, {dd{n}, dd{s}, dd{lit(4)}}}; EXPECT(sh.symbolic()); EXPECT(sh.dyn_strides().size() == 3); EXPECT(sh.dyn_strides()[2] == lit(1)); @@ -1441,9 +1433,9 @@ TEST_CASE(test_symbolic_stride_auto_compute) TEST_CASE(test_symbolic_to_static) { - auto n = var("n"); - auto s = var("s"); - migraphx::shape sh{migraphx::shape::float_type, {{1, 8, {}, n}, {1, 16, {}, s}, {4, 4}}}; + auto n = var("n", {1, 8}); + auto s = var("s", {1, 16}); + migraphx::shape sh{migraphx::shape::float_type, {dd{n}, dd{s}, dd{lit(4)}}}; std::unordered_map symbol_map = {{n, 2}, {s, 8}}; auto s_static = sh.to_static(symbol_map); EXPECT(not s_static.dynamic()); @@ -1453,9 +1445,9 @@ TEST_CASE(test_symbolic_to_static) TEST_CASE(test_symbolic_shape_serialize) { - auto n = var("n"); - auto s = var("s"); - migraphx::shape s1{migraphx::shape::float_type, {{1, 8, {}, n}, {1, 16, {}, s}, {4, 4}}}; + auto n = var("n", {1, 8}); + auto s = var("s", {1, 16}); + migraphx::shape s1{migraphx::shape::float_type, {dd{n}, dd{s}, dd{lit(4)}}}; auto v = migraphx::to_value(s1); auto s2 = migraphx::from_value(v); EXPECT(s1 == s2); @@ -1467,48 +1459,48 @@ TEST_CASE(test_symbolic_shape_serialize) TEST_CASE(test_symbolic_shape_equality) { - auto n = var("n"); - auto c = var("c"); - migraphx::shape s1{migraphx::shape::float_type, {{1, 8, {}, n}, {3, 3}}}; - migraphx::shape s2{migraphx::shape::float_type, {{1, 8, {}, n}, {3, 3}}}; - migraphx::shape s3{migraphx::shape::float_type, {{1, 8, {}, c}, {3, 3}}}; + auto n = var("n", {1, 8}); + auto c = var("c", {1, 8}); + migraphx::shape s1{migraphx::shape::float_type, {dd{n}, dd{lit(3)}}}; + migraphx::shape s2{migraphx::shape::float_type, {dd{n}, dd{lit(3)}}}; + migraphx::shape s3{migraphx::shape::float_type, {dd{c}, dd{lit(3)}}}; EXPECT(s1 == s2); EXPECT(s1 != s3); } TEST_CASE(test_symbolic_shape_print) { - auto n = var("n"); - auto c = var("c"); + auto n = var("n", {1, 8}); + auto c = var("c", {1, 8}); auto to_str = [](const migraphx::shape& sh) { std::stringstream ss; ss << sh; return ss.str(); }; - migraphx::shape s1{migraphx::shape::float_type, {{1, 8, {}, n}, {3, 3}, {4, 4}}}; - migraphx::shape s2{migraphx::shape::float_type, {{1, 8, {}, n}, {3, 3}, {4, 4}}}; - migraphx::shape s3{migraphx::shape::float_type, {{1, 8, {}, c}, {3, 3}, {4, 4}}}; + migraphx::shape s1{migraphx::shape::float_type, {dd{n}, dd{lit(3)}, dd{lit(4)}}}; + migraphx::shape s2{migraphx::shape::float_type, {dd{n}, dd{lit(3)}, dd{lit(4)}}}; + migraphx::shape s3{migraphx::shape::float_type, {dd{c}, dd{lit(3)}, dd{lit(4)}}}; EXPECT(to_str(s1) == to_str(s2)); EXPECT(to_str(s1) != to_str(s3)); } TEST_CASE(dd_intersection_symbolic_with_range) { - auto n = var("n"); - migraphx::shape::dynamic_dimension a{1, 32, {}, n}; - migraphx::shape::dynamic_dimension b{2, 6}; + auto n = var("n", {1, 32}); + dd a{n}; + dd b{2, 6}; auto result = a.intersection(b); EXPECT(result.has_value()); - EXPECT(result->min == 2); - EXPECT(result->max == 6); + EXPECT(result->get_interval().min == 2); + EXPECT(result->get_interval().max == 6); EXPECT(result->sym_expr.empty()); } TEST_CASE(dd_intersection_symbolic_same_symbol) { - auto n = var("n"); - migraphx::shape::dynamic_dimension a{1, 32, {}, n}; - migraphx::shape::dynamic_dimension b{1, 32, {}, n}; + auto n = var("n", {1, 32}); + dd a{n}; + dd b{n}; auto result = a.intersection(b); EXPECT(result.has_value()); EXPECT(*result == a); @@ -1516,12 +1508,556 @@ TEST_CASE(dd_intersection_symbolic_same_symbol) TEST_CASE(dd_intersection_symbolic_different_symbol) { - auto n = var("n"); - auto m = var("m"); - migraphx::shape::dynamic_dimension a{1, 32, {}, n}; - migraphx::shape::dynamic_dimension b{1, 16, {}, m}; + auto n = var("n", {1, 32}); + auto m = var("m", {1, 16}); + dd a{n}; + dd b{m}; auto result = a.intersection(b); EXPECT(not result.has_value()); } +// ------------------------------------------------------------------- +// Symbolic shapes: packed/standard/transposed/broadcasted/scalar +// ------------------------------------------------------------------- + +TEST_CASE(test_symbolic_packed_default) +{ + auto n = var("n", {1, 8}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{lit(3)}}}; + EXPECT(s.symbolic()); + EXPECT(s.standard()); + EXPECT(s.packed()); + EXPECT(not s.transposed()); + EXPECT(not s.broadcasted()); +} + +TEST_CASE(test_symbolic_standard) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}}}; + EXPECT(s.symbolic()); + EXPECT(s.standard()); + EXPECT(s.packed()); + EXPECT(not s.transposed()); + EXPECT(not s.broadcasted()); +} + +TEST_CASE(test_symbolic_standard_singleton_dim) +{ + auto n = var("n", {1, 8}); + migraphx::shape s{ + migraphx::shape::float_type, {dd{n}, dd{lit(1)}, dd{lit(8)}}, {lit(8), lit(4), lit(1)}}; + EXPECT(s.symbolic()); + EXPECT(s.standard()); + EXPECT(s.packed()); + EXPECT(not s.transposed()); + EXPECT(not s.broadcasted()); +} + +TEST_CASE(test_shape_ndim_symbolic) +{ + auto n = var("n", {1, 8}); + migraphx::shape s0{migraphx::shape::float_type, {dd{n}, dd{lit(3)}}}; + EXPECT(s0.ndim() == 2); + + auto c = var("c", {1, 16}); + migraphx::shape s1{migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}, dd{lit(4)}}}; + EXPECT(s1.ndim() == 4); +} + +TEST_CASE(test_symbolic_transposed) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}}, {lit(1), n, n * c}}; + EXPECT(s.symbolic()); + EXPECT(not s.standard()); + EXPECT(s.packed()); + EXPECT(s.transposed()); + EXPECT(not s.broadcasted()); +} + +TEST_CASE(test_symbolic_to_dynamic_identity) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}}}; + auto s2 = s.to_dynamic(); + EXPECT(s == s2); +} + +TEST_CASE(test_symbolic_overlap) +{ + auto n = var("n", {1, 8}); + migraphx::shape s{ + migraphx::shape::float_type, {dd{n}, dd{lit(3)}, dd{lit(4)}}, {lit(6), lit(3), lit(2)}}; + EXPECT(s.symbolic()); + EXPECT(not s.standard()); + EXPECT(not s.packed()); + EXPECT(not s.transposed()); + EXPECT(not s.broadcasted()); +} + +TEST_CASE(test_symbolic_scalar) +{ + migraphx::shape s{migraphx::shape::float_type, {dd{lit(1)}}, {lit(0)}}; + EXPECT(s.symbolic()); + EXPECT(s.standard()); + EXPECT(s.packed()); + EXPECT(not s.transposed()); + EXPECT(s.broadcasted()); +} + +TEST_CASE(test_symbolic_scalar_broadcast) +{ + auto n = var("n", {1, 8}); + migraphx::shape s{ + migraphx::shape::float_type, {dd{n}, dd{lit(3)}, dd{lit(4)}}, {lit(0), lit(0), lit(0)}}; + EXPECT(s.symbolic()); + EXPECT(not s.standard()); + EXPECT(not s.packed()); + EXPECT(not s.transposed()); + EXPECT(s.broadcasted()); +} + +TEST_CASE(test_symbolic_broadcasted) +{ + auto n = var("n", {1, 8}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{lit(3)}}, {lit(1), lit(0)}}; + EXPECT(s.symbolic()); + EXPECT(not s.standard()); + EXPECT(not s.packed()); + EXPECT(not s.transposed()); + EXPECT(s.broadcasted()); +} + +TEST_CASE(test_symbolic_broadcasted2) +{ + auto c = var("c", {1, 16}); + migraphx::shape s{migraphx::shape::float_type, {dd{lit(1)}, dd{c}}, {lit(0), lit(1)}}; + EXPECT(s.symbolic()); + EXPECT(s.standard()); + EXPECT(s.packed()); + EXPECT(not s.transposed()); + EXPECT(s.broadcasted()); +} + +TEST_CASE(test_symbolic_normalize_standard) +{ + auto n = var("n", {1, 4}); + auto c = var("c", {1, 64}); + migraphx::shape s{migraphx::shape::float_type, + {dd{n}, dd{c}, dd{lit(35)}, dd{lit(35)}}, + {c * 1225, lit(1225), lit(35), lit(1)}}; + EXPECT(s.standard()); + auto ns = s.normalize_standard(); + EXPECT(ns.standard()); + EXPECT(ns.symbolic()); + EXPECT(ns.dyn_dims() == s.dyn_dims()); + EXPECT(ns.type() == s.type()); +} + +TEST_CASE(test_symbolic_normalize_standard_transposed) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + migraphx::shape s{ + migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}}, {lit(1), lit(4), c * 4}}; + EXPECT(not s.standard()); + EXPECT(s.transposed()); + auto ns = s.normalize_standard(); + EXPECT(ns == s); +} + +// ------------------------------------------------------------------- +// Symbolic with_lens / from_permutation / find_permutation +// ------------------------------------------------------------------- + +TEST_CASE(test_symbolic_with_lens_standard) +{ + auto n = var("n", {1, 8}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{lit(3)}}}; + auto b = var("b", {1, 16}); + std::vector
new_dims = {dd{b}, dd{lit(4)}}; + auto s2 = s.with_lens(new_dims); + EXPECT(s2.symbolic()); + EXPECT(s2.standard()); + EXPECT(not s2.transposed()); + EXPECT(s2.dyn_dims() == new_dims); +} + +TEST_CASE(test_symbolic_with_lens_transposed) +{ + auto n = var("n", {1, 8}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{lit(3)}}, {lit(1), n}}; + EXPECT(s.transposed()); + auto b = var("b", {1, 16}); + std::vector
new_dims = {dd{b}, dd{lit(4)}}; + auto s2 = s.with_lens(new_dims); + EXPECT(s2.symbolic()); + EXPECT(s2.transposed()); + EXPECT(s2.dyn_dims() == new_dims); +} + +TEST_CASE(test_symbolic_with_lens_4d) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}, dd{lit(4)}}}; + auto b = var("b", {1, 32}); + auto ch = var("ch", {1, 64}); + std::vector
new_dims = {dd{b}, dd{ch}, dd{lit(8)}, dd{lit(8)}}; + auto s2 = s.with_lens(new_dims); + EXPECT(s2.symbolic()); + EXPECT(s2.standard()); + EXPECT(not s2.transposed()); + EXPECT(s2.dyn_dims() == new_dims); +} + +TEST_CASE(find_permutation_symbolic_2d_standard) +{ + auto n = var("n", {1, 8}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{lit(3)}}}; + std::vector permutation = {0, 1}; + EXPECT(migraphx::find_permutation(s) == permutation); +} + +TEST_CASE(find_permutation_symbolic_2d_transpose) +{ + auto n = var("n", {1, 8}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{lit(3)}}, {lit(1), n}}; + std::vector permutation = {1, 0}; + EXPECT(migraphx::find_permutation(s) == permutation); +} + +TEST_CASE(from_symbolic_2d_permutation) +{ + auto n = var("n", {1, 8}); + std::vector
out_dims = {dd{n}, dd{lit(3)}}; + std::vector permutation = {1, 0}; + migraphx::shape out_shape = + migraphx::shape::from_permutation(migraphx::shape::float_type, out_dims, permutation); + EXPECT(out_shape.dyn_dims() == out_dims); + EXPECT(migraphx::find_permutation(out_shape) == permutation); +} + +TEST_CASE(from_symbolic_3d_permutation) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + std::vector
out_dims = {dd{n}, dd{c}, dd{lit(4)}}; + std::vector permutation = {1, 2, 0}; + migraphx::shape out_shape = + migraphx::shape::from_permutation(migraphx::shape::float_type, out_dims, permutation); + EXPECT(out_shape.dyn_dims() == out_dims); + EXPECT(migraphx::find_permutation(out_shape) == permutation); +} + +TEST_CASE(from_symbolic_4d_permutation) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 64}); + auto h = var("h", {2, 32}); + auto w = var("w", {2, 32}); + std::vector
out_dims = {dd{n}, dd{c}, dd{h}, dd{w}}; + std::vector permutation = {3, 2, 0, 1}; + migraphx::shape out_shape = + migraphx::shape::from_permutation(migraphx::shape::float_type, out_dims, permutation); + EXPECT(out_shape.dyn_dims() == out_dims); + EXPECT(migraphx::find_permutation(out_shape) == permutation); +} + +TEST_CASE(reorder_shape_symbolic) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}}}; + std::vector perm = {2, 0, 1}; + auto reordered = migraphx::reorder_shape(s, perm); + EXPECT(reordered.symbolic()); + EXPECT(reordered.dyn_dims().size() == s.dyn_dims().size()); +} + +TEST_CASE(test_symbolic_elements_via_to_static) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}}}; + std::unordered_map symbol_map = {{n, 2}, {c, 8}}; + auto ss = s.to_static(symbol_map); + EXPECT(ss.elements() == 2 * 8 * 4); + EXPECT(ss.strides() == std::vector{32, 4, 1}); +} + +// ------------------------------------------------------------------- +// Dynamic dimension: div, add/sub/mul/div with two dd's +// ------------------------------------------------------------------- + +TEST_CASE(dynamic_dimension_div_fixed) +{ + dd a{10, 30, {12, 24}}; + a /= 3; + EXPECT(a.get_interval().min == 3); + EXPECT(a.get_interval().max == 10); + EXPECT(a.get_optimals() == std::set{4, 8}); +} + +TEST_CASE(dynamic_dimension_add_dd) +{ + dd a{2, 8, {4, 6}}; + dd b{3, 5, {3, 5}}; + auto r = a + b; + EXPECT(r.get_interval().min == 5); + EXPECT(r.get_interval().max == 13); + EXPECT(r.get_optimals().empty()); + EXPECT(r.sym_expr.empty()); +} + +TEST_CASE(dynamic_dimension_sub_dd) +{ + dd a{10, 30, {15, 25}}; + dd b{2, 5, {3}}; + auto r = a - b; + EXPECT(r.get_interval().min == 5); + EXPECT(r.get_interval().max == 28); + EXPECT(r.get_optimals().empty()); + EXPECT(r.sym_expr.empty()); +} + +TEST_CASE(dynamic_dimension_mul_dd) +{ + dd a{2, 8, {4}}; + dd b{3, 5, {3, 5}}; + auto r = a * b; + EXPECT(r.get_interval().min == 6); + EXPECT(r.get_interval().max == 40); + EXPECT(r.get_optimals().empty()); + EXPECT(r.sym_expr.empty()); +} + +TEST_CASE(dynamic_dimension_div_dd) +{ + dd a{10, 40, {20, 30}}; + dd b{2, 5, {2, 4}}; + auto r = a / b; + EXPECT(r.get_interval().min == 2); + EXPECT(r.get_interval().max == 20); + EXPECT(r.get_optimals().empty()); + EXPECT(r.sym_expr.empty()); +} + +TEST_CASE(dynamic_dimension_sub_clamp_zero) +{ + dd a{2, 5}; + dd b{4, 8}; + auto r = a - b; + EXPECT(r.get_interval().min == 0); + EXPECT(r.get_interval().max == 1); +} + +TEST_CASE(dynamic_dimension_add_one_fixed) +{ + dd a{4, 4, {4}}; + dd b{2, 8, {3, 6}}; + auto r = a + b; + EXPECT(r.get_interval().min == 6); + EXPECT(r.get_interval().max == 12); + EXPECT(r.get_optimals() == std::set({7, 10})); + EXPECT(r.sym_expr.empty()); +} + +TEST_CASE(dynamic_dimension_mul_one_fixed) +{ + dd a{3, 3}; + dd b{2, 8, {4, 6}}; + auto r = a * b; + EXPECT(r.get_interval().min == 6); + EXPECT(r.get_interval().max == 24); + EXPECT(r.get_optimals() == std::set({12, 18})); + EXPECT(r.sym_expr.empty()); +} + +// ------------------------------------------------------------------- +// Dynamic dimension: symbolic construction and arithmetic +// ------------------------------------------------------------------- + +TEST_CASE(test_dd_from_empty_expr_throws) +{ + migraphx::sym::expr empty_expr; + EXPECT(test::throws([&] { dd{empty_expr}; })); +} + +TEST_CASE(test_dd_accessors_range_based) +{ + dd a{3, 10, {4, 7}}; + auto iv = a.get_interval(); + EXPECT(iv.min == 3); + EXPECT(iv.max == 10); + EXPECT(a.get_optimals() == std::set({4, 7})); +} + +TEST_CASE(test_dd_accessors_symbolic) +{ + auto n = var("n", {2, 16}, {4, 8}); + dd d{n}; + EXPECT(d.get_interval().min == 2); + EXPECT(d.get_interval().max == 16); + EXPECT(d.get_optimals() == std::set({4, 8})); +} + +TEST_CASE(test_dd_symbolic_no_optimals) +{ + auto n = var("n", {3, 12}); + dd d{n}; + EXPECT(d.get_interval().min == 3); + EXPECT(d.get_interval().max == 12); + EXPECT(d.get_optimals().empty()); +} + +TEST_CASE(test_dd_symbolic_add_dd_optimals) +{ + auto h = var("h", {5, 20}, {10, 15}); + auto w = var("w", {5, 20}, {10, 15}); + auto r = dd{h} + dd{w}; + EXPECT(r.sym_expr == h + w); + EXPECT(r.get_interval().min == 10); + EXPECT(r.get_interval().max == 40); + EXPECT(r.get_optimals() == std::set({20, 25, 30})); +} + +TEST_CASE(test_dd_symbolic_sub_dd_optimals) +{ + auto n = var("n", {10, 50}, {20, 30}); + auto k = var("k", {1, 5}, {2, 4}); + auto r = dd{n} - dd{k}; + EXPECT(r.sym_expr == n - k); + EXPECT(r.get_interval().min == 5); + EXPECT(r.get_interval().max == 49); + EXPECT(r.get_optimals() == std::set({16, 18, 26, 28})); +} + +TEST_CASE(test_dd_symbolic_mul_dd_optimals) +{ + auto n = var("n", {1, 8}, {2, 4}); + auto c = var("c", {1, 4}, {2, 3}); + auto r = dd{n} * dd{c}; + EXPECT(r.sym_expr == n * c); + EXPECT(r.get_interval().min == 1); + EXPECT(r.get_interval().max == 32); + EXPECT(r.get_optimals() == std::set({4, 6, 8, 12})); +} + +TEST_CASE(test_dd_symbolic_div_dd_optimals) +{ + auto n = var("n", {10, 50}, {20, 40}); + auto k = var("k", {2, 5}, {2, 5}); + auto r = dd{n} / dd{k}; + EXPECT(r.sym_expr == n / k); + EXPECT(r.get_interval().min == 2); + EXPECT(r.get_interval().max == 25); + EXPECT(r.get_optimals() == std::set({4, 8, 10, 20})); +} + +TEST_CASE(test_dd_symbolic_add_size_t_optimals) +{ + auto n = var("n", {1, 8}, {4, 6}); + dd d{n}; + d += 2; + EXPECT(d.sym_expr == n + 2); + EXPECT(d.get_interval().min == 3); + EXPECT(d.get_interval().max == 10); + EXPECT(d.get_optimals() == std::set({6, 8})); +} + +TEST_CASE(test_dd_symbolic_mul_size_t_optimals) +{ + auto n = var("n", {1, 8}, {2, 4}); + dd d{n}; + d *= 3; + EXPECT(d.sym_expr == n * 3); + EXPECT(d.get_interval().min == 3); + EXPECT(d.get_interval().max == 24); + EXPECT(d.get_optimals() == std::set({6, 12})); +} + +TEST_CASE(test_dd_symbolic_chained_arithmetic_optimals) +{ + auto h = var("h", {10, 50}, {20, 30}); + dd d{h}; + d -= 3; + d /= 2; + d += 1; + EXPECT(d.sym_expr == (h - 3) / 2 + 1); + EXPECT(d.get_interval().min == 4); + EXPECT(d.get_interval().max == 24); + EXPECT(d.get_optimals() == std::set({9, 14})); +} + +TEST_CASE(test_dd_symbolic_arithmetic_invalidates_cache) +{ + auto n = var("n", {2, 8}, {4}); + dd d{n}; + EXPECT(d.get_interval().min == 2); + EXPECT(d.get_interval().max == 8); + d += 1; + EXPECT(d.sym_expr == n + 1); + EXPECT(d.get_interval().min == 3); + EXPECT(d.get_interval().max == 9); + EXPECT(d.get_optimals() == std::set({5})); +} + +TEST_CASE(test_dd_range_arithmetic_keeps_cache) +{ + dd a{2, 8, {4}}; + dd b{1, 3}; + auto r = a + b; + EXPECT(r.sym_expr.empty()); + EXPECT(r.get_interval().min == 3); + EXPECT(r.get_interval().max == 11); +} + +TEST_CASE(test_dd_serialize_range_based) +{ + dd a{3, 10, {5, 7}}; + auto v = migraphx::to_value(a); + auto b = migraphx::from_value
(v); + EXPECT(a == b); + EXPECT(b.get_interval().min == 3); + EXPECT(b.get_interval().max == 10); + EXPECT(b.get_optimals() == std::set({5, 7})); +} + +TEST_CASE(test_dd_serialize_symbolic) +{ + auto n = var("n", {2, 16}, {4, 8}); + dd d{n}; + auto v = migraphx::to_value(d); + auto d2 = migraphx::from_value
(v); + EXPECT(d == d2); + EXPECT(d2.get_interval().min == 2); + EXPECT(d2.get_interval().max == 16); + EXPECT(d2.get_optimals() == std::set({4, 8})); +} + +// ------------------------------------------------------------------- +// is_compatible / is_compatible_lens for symbolic shapes +// ------------------------------------------------------------------- + +TEST_CASE(shape_is_compatible_symbolic_same) +{ + auto n = var("n", {1, 8}); + migraphx::shape actual{migraphx::shape::float_type, {dd{n}, dd{lit(3)}, dd{lit(4)}}}; + migraphx::shape expected{migraphx::shape::float_type, {dd{n}, dd{lit(3)}, dd{lit(4)}}}; + EXPECT(migraphx::shape::is_compatible(actual, expected)); +} + +TEST_CASE(shape_is_compatible_lens_symbolic_same) +{ + auto n = var("n", {1, 8}); + migraphx::shape s1{migraphx::shape::float_type, {dd{lit(1)}, dd{n}, dd{lit(3)}, dd{lit(4)}}}; + migraphx::shape s2{migraphx::shape::float_type, {dd{lit(1)}, dd{n}, dd{lit(3)}, dd{lit(4)}}}; + EXPECT(migraphx::shape::is_compatible_lens(s1, s2)); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/sym_test.cpp b/test/sym_test.cpp index f822b0e773f..73d5ff23684 100644 --- a/test/sym_test.cpp +++ b/test/sym_test.cpp @@ -26,7 +26,8 @@ #include #include "test.hpp" -using se = migraphx::sym::expr; +using se = migraphx::sym::expr; +using interval = migraphx::sym::interval; using migraphx::sym::lit; using migraphx::sym::parse; using migraphx::sym::var; @@ -979,4 +980,413 @@ TEST_CASE(serialize_compound) EXPECT(round_trip(e) == e); } +// ------------------------------------------------------------------- +// Bounded vars: constructor / eq / hash +// ------------------------------------------------------------------- + +TEST_CASE(construct_var_min_greater_than_max_throws) +{ + EXPECT(test::throws([&] { var("n", {10, 5}); })); +} + +TEST_CASE(construct_var_min_less_than_one_throws) +{ + EXPECT(test::throws([&] { var("n", {0, 5}); })); + EXPECT(test::throws([&] { var("n", {-1, 5}); })); +} + +TEST_CASE(eq_same_name_different_intervals) +{ + auto h1 = var("h", {1, 128}); + auto h2 = var("h", {1, 256}); + auto h3 = var("h", {2, 128}); + auto h4 = var("h", {1, 128}); + EXPECT(h1 != h2); + EXPECT(h1 != h3); + EXPECT(h1 == h4); +} + +TEST_CASE(hash_same_name_different_intervals) +{ + auto h1 = var("h", {1, 128}); + auto h2 = var("h", {1, 256}); + auto h3 = var("h", {1, 128}); + EXPECT(h1.hash() != h2.hash()); + EXPECT(h1.hash() == h3.hash()); +} + +// ------------------------------------------------------------------- +// Bounds: eval_interval() +// ------------------------------------------------------------------- + +TEST_CASE(eval_interval_single_var) +{ + auto n = var("n", {2, 16}); + EXPECT(n.eval_interval() == interval{2, 16}); +} + +TEST_CASE(eval_interval_literal) { EXPECT(lit(42).eval_interval() == interval{42, 42}); } + +TEST_CASE(eval_interval_compound) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + auto e = n * c * 4; + EXPECT(e.eval_interval() == interval{4, 512}); +} + +TEST_CASE(eval_interval_stride_diff) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + auto diff = n * c - n; + EXPECT(diff.eval_interval() == interval{0, 120}); +} + +TEST_CASE(eval_interval_division) +{ + auto n = var("n", {2, 10}); + auto d = var("d", {1, 5}); + auto e = n / d; + EXPECT(e.eval_interval() == interval{0, 10}); +} + +TEST_CASE(eval_interval_div_literal_denom) +{ + auto n = var("n", {4, 16}); + auto e = n / lit(4); + EXPECT(e.eval_interval() == interval{1, 4}); +} + +TEST_CASE(eval_interval_subtraction_independent) +{ + auto a = var("a", {1, 10}); + auto b = var("b", {1, 5}); + auto e = a - b; + EXPECT(e.eval_interval() == interval{-4, 9}); +} + +TEST_CASE(eval_interval_empty_throws) +{ + se empty; + EXPECT(test::throws([&] { (void)empty.eval_interval(); })); +} + +TEST_CASE(eval_interval_uint) +{ + auto n = var("n", {2, 16}); + auto e = 3 * n + 1; + EXPECT(e.eval_interval() == interval{7, 49}); +} + +// ------------------------------------------------------------------- +// Comparison operators +// ------------------------------------------------------------------- + +TEST_CASE(cmp_lit_constants) +{ + EXPECT(lit(1) < lit(2)); + EXPECT(not(lit(2) < lit(1))); + EXPECT(not(lit(3) < lit(3))); + EXPECT(lit(2) > lit(1)); + EXPECT(lit(3) <= lit(3)); + EXPECT(lit(3) >= lit(3)); + EXPECT(lit(1) <= lit(2)); + EXPECT(lit(2) >= lit(1)); +} + +TEST_CASE(cmp_equal_expr_not_less) +{ + auto n = var("n"); + EXPECT(not(n < n)); + EXPECT(not(n > n)); + EXPECT(n <= n); + EXPECT(n >= n); +} + +TEST_CASE(cmp_empty_not_less) +{ + se a; + se b; + EXPECT(not(a < b)); +} + +TEST_CASE(cmp_empty_with_nonempty_throws) +{ + EXPECT(test::throws([&]() -> bool { return se{} < var("n"); })); + EXPECT(test::throws([&]() -> bool { return var("n") < se{}; })); +} + +TEST_CASE(cmp_stride_ordering_4d) +{ + auto c = var("c", {1, 512}); + auto h = var("h", {1, 256}); + auto w = var("w", {1, 256}); + auto s0 = c * h * w; + auto s1 = h * w; + auto s2 = w; + auto s3 = lit(1); + EXPECT(s1 <= s0); + EXPECT(s2 <= s1); + EXPECT(s3 <= s2); + EXPECT(s3 <= s0); +} + +TEST_CASE(cmp_scaled_symbol) +{ + auto n = var("n"); + EXPECT(n < 2 * n); + EXPECT(n < 3 * n); + EXPECT(not(2 * n < n)); +} + +TEST_CASE(cmp_product_explicit_bounds) +{ + auto k = var("k", {1, 8}); + auto m = var("m", {2, 4}); + EXPECT(k < m * k); +} + +TEST_CASE(cmp_conv_output_smaller_than_input) +{ + auto h = var("h", {3, 256}); + auto out = (h - 3) / 2 + 1; + EXPECT(out < h); + EXPECT(not(h < out)); +} + +TEST_CASE(cmp_repeated_pooling) +{ + auto h = var("h", {7, 256}); + auto out1 = (h - 3) / 2 + 1; + auto out2 = (out1 - 3) / 2 + 1; + EXPECT(out1 < h); + EXPECT(out2 < out1); + EXPECT(out2 < h); +} + +TEST_CASE(cmp_strides_after_conv) +{ + auto h = var("h", {7, 128}); + auto w = var("w", {2, 128}); + auto new_h = (h - 3) / 2 + 1; + auto s0 = new_h * w; + auto s1 = w; + auto s2 = lit(1); + EXPECT(s1 < s0); + EXPECT(s2 < s1); +} + +TEST_CASE(cmp_broadcast_stride_zero) +{ + auto w = var("w"); + EXPECT(lit(0) < w); + EXPECT(not(w < lit(0))); +} + +TEST_CASE(cmp_offset_expressions) +{ + auto h = var("h", {2, 256}); + EXPECT(h - 1 < h); + EXPECT(h < h + 1); + EXPECT(not(h + 1 < h)); +} + +TEST_CASE(cmp_undetermined_throws) +{ + auto n = var("n", {2, 10}); + EXPECT(test::throws([&]() -> bool { return n < lit(5); })); +} + +TEST_CASE(cmp_element_count_slice) +{ + auto n = var("n", {1, 32}); + auto c = var("c", {1, 512}); + auto h = var("h", {1, 256}); + auto w = var("w", {2, 256}); + EXPECT(n * c * h < n * c * h * w); +} + +TEST_CASE(cmp_deep_pooling_chain) +{ + auto h = var("h", {31, 512}); + se stage = h; + se prev; + for(int i = 0; i < 5; ++i) + { + prev = stage; + stage = (stage - 1) / 2; + } + EXPECT(stage < prev); + EXPECT(stage < h); +} + +TEST_CASE(cmp_commuted_product) +{ + auto a = var("a"); + auto b = var("b"); + EXPECT(not(a * b < b * a)); + EXPECT(a * b <= b * a); + EXPECT(a * b >= b * a); +} + +TEST_CASE(cmp_negative_literals) +{ + EXPECT(lit(-5) < lit(-1)); + EXPECT(lit(-1) < lit(0)); + EXPECT(lit(-10) < lit(10)); + EXPECT(not(lit(0) < lit(-1))); +} + +TEST_CASE(cmp_symmetry_lt_gt) +{ + auto h = var("h", {3, 256}); + auto out = (h - 3) / 2 + 1; + EXPECT(out < h); + EXPECT(h > out); + EXPECT(not(h < out)); + EXPECT(not(out > h)); +} + +TEST_CASE(cmp_transitivity_strides) +{ + auto c = var("c", {2, 512}); + auto h = var("h", {2, 256}); + auto w = var("w", {2, 256}); + auto s0 = c * h * w; + auto s1 = h * w; + auto s2 = w; + auto s3 = lit(1); + EXPECT(s1 < s0); + EXPECT(s2 < s1); + EXPECT(s3 < s2); + EXPECT(s3 < s0); + EXPECT(s2 < s0); + EXPECT(s3 < s1); +} + +TEST_CASE(cmp_division_ordering) +{ + auto h = var("h", {5, 256}); + auto pool2 = (h - 1) / 2; + auto pool4 = (h - 1) / 4; + EXPECT(pool4 < pool2); + EXPECT(pool2 < h); + EXPECT(pool4 < h); +} + +TEST_CASE(cmp_sum_less_than_product) +{ + auto n = var("n", {2, 32}); + auto c = var("c", {3, 512}); + EXPECT(n + c < n * c); +} + +TEST_CASE(cmp_algebraically_equal_expressions) +{ + auto h = var("h"); + auto a = h + h; + auto b = 2 * h; + EXPECT(a == b); + EXPECT(not(a < b)); + EXPECT(not(b < a)); + EXPECT(a <= b); + EXPECT(a >= b); +} + +TEST_CASE(cmp_zero_stride_less_than_symbolic_stride) +{ + auto h = var("h"); + auto w = var("w"); + EXPECT(lit(0) < h); + EXPECT(lit(0) < h * w); + EXPECT(lit(0) < h + w); +} + +// ------------------------------------------------------------------- +// Optimals: eval_optimals() +// ------------------------------------------------------------------- + +TEST_CASE(eval_optimals_single_var) +{ + auto n = var("n", {1, 8}, {2, 4}); + EXPECT(n.eval_optimals() == std::set{2, 4}); +} + +TEST_CASE(eval_optimals_compound_expr) +{ + auto n = var("n", {1, 8}, {2, 4}); + auto e = 2 * n + 1; + EXPECT(e.eval_optimals() == std::set{5, 9}); +} + +TEST_CASE(eval_optimals_multi_var) +{ + auto n = var("n", {1, 8}, {2, 4}); + auto m = var("m", {1, 8}, {3, 6}); + auto e = n + m; + EXPECT(e.eval_optimals() == std::set{5, 7, 8, 10}); +} + +TEST_CASE(eval_optimals_negative_throws) +{ + auto n = var("n", {1, 4}, {2}); + auto m = var("m", {1, 8}, {5}); + auto e = n - m; + EXPECT(test::throws([&] { (void)e.eval_optimals(); })); +} + +TEST_CASE(eval_optimals_no_optimals) +{ + auto n = var("n", {1, 8}); + EXPECT(n.eval_optimals().empty()); +} + +TEST_CASE(eval_optimals_empty_expr) +{ + se e; + EXPECT(e.eval_optimals().empty()); +} + +// ------------------------------------------------------------------- +// Serialization: bounded vars +// ------------------------------------------------------------------- + +TEST_CASE(serialize_bounded_var) +{ + auto h = var("h", {1, 128}); + auto r = round_trip(h); + EXPECT(r == h); + EXPECT(r != var("h", {1, 256})); + EXPECT(r != var("h")); +} + +TEST_CASE(serialize_bounded_var_in_expr) +{ + auto h = var("h", {1, 128}); + auto w = var("w", {1, 256}); + auto e = 2 * h + w - 3; + auto r = round_trip(e); + EXPECT(r == e); + EXPECT(r.eval_uint({{h, 64}, {w, 32}}) == 157); +} + +TEST_CASE(serialize_conv_output_with_bounds) +{ + auto h = var("h", {3, 256}); + auto out = (h - 3) / 2 + 1; + auto r = round_trip(out); + EXPECT(r == out); + EXPECT(r.eval_uint({{h, 255}}) == 127); +} + +TEST_CASE(serialize_comparison_survives_round_trip) +{ + auto h = var("h", {3, 256}); + auto out = (h - 3) / 2 + 1; + auto h2 = round_trip(h); + auto out2 = round_trip(out); + EXPECT(out2 < h2); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From 5be3d69cc99b9a33ec300159a04a0770d2643afc Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 15 Apr 2026 15:48:08 -0700 Subject: [PATCH 50/60] fix old constructor --- test/serialize_program.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/serialize_program.cpp b/test/serialize_program.cpp index 783b4d4dbe6..6e6121e8697 100644 --- a/test/serialize_program.cpp +++ b/test/serialize_program.cpp @@ -145,11 +145,12 @@ TEST_CASE(symbolic_shape_msgpack_roundtrip) { using migraphx::shape; using dd = shape::dynamic_dimension; - auto n = migraphx::sym::var("n"); + using migraphx::sym::lit; + auto n = migraphx::sym::var("n", {1, 8}); migraphx::program p; auto* mm = p.get_main_module(); - shape s{shape::float_type, {dd{1, 8, {}, n}, {3, 3}, {4, 4}}}; + shape s{shape::float_type, {dd{n}, dd{lit(3)}, dd{lit(4)}}}; auto x = mm->add_parameter("x", s); auto r = mm->add_instruction(migraphx::make_op("relu"), x); mm->add_return({r}); From 093964c2cf8dee87f11f8e885f313f37fc7e70fe Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 15 Apr 2026 17:11:33 -0700 Subject: [PATCH 51/60] fix ambiguous call --- test/eliminate_concat_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/eliminate_concat_test.cpp b/test/eliminate_concat_test.cpp index ccc6aa2fe5d..773c82619f9 100644 --- a/test/eliminate_concat_test.cpp +++ b/test/eliminate_concat_test.cpp @@ -203,7 +203,7 @@ static migraphx::shape create_shape(Ts... xs) return migraphx::shape{migraphx::shape::float_type, {std::size_t(xs)...}}; else return migraphx::shape::from_permutation( - migraphx::shape::float_type, {std::size_t(xs)...}, {Is...}); + migraphx::shape::float_type, std::vector{std::size_t(xs)...}, {Is...}); } template From 72f00eaea0fb229a8ab1afdaf11a0a0318b1db30 Mon Sep 17 00:00:00 2001 From: Shiv Date: Thu, 16 Apr 2026 10:10:14 -0700 Subject: [PATCH 52/60] fix cppcheck --- src/shape.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/shape.cpp b/src/shape.cpp index 92a6fe7c1c6..799c2472b85 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -1104,7 +1104,7 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const shape::dyna return apply_op( *this, x, - [](auto& a, auto& b) { return a + b; }, + [](const auto& a, const auto& b) { return a + b; }, [](auto& lhs, const auto& rhs) { auto lhs_fixed = lhs.is_fixed(); auto rhs_fixed = rhs.is_fixed(); @@ -1129,7 +1129,7 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const shape::dyna return apply_op( *this, x, - [](auto& a, auto& b) { return a - b; }, + [](const auto& a, const auto& b) { return a - b; }, [](auto& lhs, const auto& rhs) { auto lhs_fixed = lhs.is_fixed(); auto rhs_fixed = rhs.is_fixed(); @@ -1153,7 +1153,7 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const shape::dyna return apply_op( *this, x, - [](auto& a, auto& b) { return a * b; }, + [](const auto& a, const auto& b) { return a * b; }, [](auto& lhs, const auto& rhs) { auto lhs_fixed = lhs.is_fixed(); auto rhs_fixed = rhs.is_fixed(); @@ -1182,7 +1182,7 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator/=(const shape::dyna return apply_op( *this, x, - [](auto& a, auto& b) { return a / b; }, + [](const auto& a, const auto& b) { return a / b; }, [](auto& lhs, const auto& rhs) { auto lhs_fixed = lhs.is_fixed(); auto rhs_fixed = rhs.is_fixed(); From 812a1b1dfe43ac3ce8363fe77db20ee79a577629 Mon Sep 17 00:00:00 2001 From: Shiv Date: Fri, 17 Apr 2026 12:46:07 -0700 Subject: [PATCH 53/60] add missing comment blocks --- src/permutation.cpp | 5 +++++ src/shape.cpp | 18 ++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/src/permutation.cpp b/src/permutation.cpp index 65096f37dce..0b653974174 100644 --- a/src/permutation.cpp +++ b/src/permutation.cpp @@ -54,6 +54,11 @@ std::vector find_permutation(const shape& s) std::iota(result.begin(), result.end(), 0); if(s.symbolic()) { + // Evaluate symbolic strides/dims at their interval max to get concrete + // values for sorting. We use max rather than min because when min=1, + // stride products collapse (e.g. n*c with n=1,c=1 gives 1) making + // distinct strides appear equal. Max values preserve the structural + // ordering. See is_sorted_strides comment in shape.cpp. const auto& strides = s.dyn_strides(); const auto& dds = s.dyn_dims(); std::stable_sort(result.begin(), result.end(), by(std::greater<>{}, [&](auto x) { diff --git a/src/shape.cpp b/src/shape.cpp index 799c2472b85..7d0b2be0e05 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -244,6 +244,24 @@ struct shape_impl return std::accumulate(strides.begin(), strides.end(), zero) == zero; } + // Check if strides are in descending order, which is a + // requirement for a shape to be considered standard layout. + // + // For symbolic strides (sym::expr), std::is_sorted cannot be used directly + // because sym::expr::operator< performs interval-based comparison: it + // evaluates the difference at the min and max of all variables' ranges. When + // a symbolic dimension variable has min=1, it can take unit value, collapsing + // stride products and making the comparison undetermined or wrong. + // + // Example: strides {1, n, n*c} with n in [1,8], c in [1,16]. + // Comparing n vs n*c: diff = n*c - n = n*(c-1). + // At all-min (n=1,c=1): diff=0. At all-max (n=8,c=16): diff=120. + // operator< sees range [0,120] -- neither strictly positive nor + // non-positive -- and throws "undetermined". But the shape is clearly + // transposed for all non-degenerate dimension values. + // + // To avoid these issues, symbolic strides are evaluated at their max variable + // values where no dimension is degenerate (unit), then sorted concretely. template static bool is_sorted_strides(const std::vector& strides) { From 8573cfb4f6a2cea2a07ff72b8b6f4ab970f7b992 Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 21 Apr 2026 15:36:07 -0700 Subject: [PATCH 54/60] clearly state assumptions used when dealing with stride permutations and explicitly throw when they are violated --- src/permutation.cpp | 26 ++++++++++++++++++++------ src/shape.cpp | 31 ++++++++++++++++--------------- test/shape_test.cpp | 11 +++++++++++ 3 files changed, 47 insertions(+), 21 deletions(-) diff --git a/src/permutation.cpp b/src/permutation.cpp index 0b653974174..8b0d93af9f6 100644 --- a/src/permutation.cpp +++ b/src/permutation.cpp @@ -54,17 +54,31 @@ std::vector find_permutation(const shape& s) std::iota(result.begin(), result.end(), 0); if(s.symbolic()) { - // Evaluate symbolic strides/dims at their interval max to get concrete - // values for sorting. We use max rather than min because when min=1, - // stride products collapse (e.g. n*c with n=1,c=1 gives 1) making - // distinct strides appear equal. Max values preserve the structural - // ordering. See is_sorted_strides comment in shape.cpp. + // Sort symbolic strides by evaluating at max variable values. + // Assumptions (see is_sorted_strides in shape.cpp for details): + // 1. Strides are products of dim variables * constant factors (no symbolic divisors) + // 2. Strides come from compute_strides() or permutations thereof + // 3. Max-eval ordering is consistent with all non-degenerate runtime orderings const auto& strides = s.dyn_strides(); const auto& dds = s.dyn_dims(); + std::vector stride_intervals(strides.size()); + std::transform(strides.begin(), strides.end(), stride_intervals.begin(), [](const auto& e) { + return e.eval_interval(); + }); std::stable_sort(result.begin(), result.end(), by(std::greater<>{}, [&](auto x) { - return std::make_tuple(strides[x].eval_interval().max, + return std::make_tuple(stride_intervals[x].max, dds[x].sym_expr.eval_interval().max); })); + // Assumption 3 guard: when max-eval gives a strict ordering between two + // adjacent strides, min-eval must not reverse it. Collapse to equality at + // min is expected (e.g. when a dim has min=1), but a sign flip indicates + // a symbolic divisor violating assumption 1. + if(std::adjacent_find(result.begin(), result.end(), [&](auto a, auto b) { + return stride_intervals[a].max > stride_intervals[b].max and + stride_intervals[a].min < stride_intervals[b].min; + }) != result.end()) + MIGRAPHX_THROW("FIND_PERMUTATION: symbolic stride ordering reversal between " + "max-eval and min-eval. Violation of symbolic stride assumptions."); } else { diff --git a/src/shape.cpp b/src/shape.cpp index 7d0b2be0e05..b2d27857ebd 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -92,6 +92,9 @@ struct shape_impl : m_type(t), m_dyn_dims(std::move(dims)), m_dyn_strides(std::move(dstrides)) { assert(m_dyn_strides.size() == m_dyn_dims.size()); + assert(std::all_of(m_dyn_strides.begin(), m_dyn_strides.end(), [](const auto& s) { + return s.eval_interval().min >= 0; + })); auto dim_exprs = sym_dims(); std::vector filtered_strides; for(std::size_t i = 0; i < m_dyn_strides.size(); i++) @@ -244,24 +247,22 @@ struct shape_impl return std::accumulate(strides.begin(), strides.end(), zero) == zero; } - // Check if strides are in descending order, which is a - // requirement for a shape to be considered standard layout. + // Check if strides are in descending order (standard layout). // - // For symbolic strides (sym::expr), std::is_sorted cannot be used directly - // because sym::expr::operator< performs interval-based comparison: it - // evaluates the difference at the min and max of all variables' ranges. When - // a symbolic dimension variable has min=1, it can take unit value, collapsing - // stride products and making the comparison undetermined or wrong. + // For symbolic strides we evaluate at max variable values rather than using + // sym::expr::operator<. This relies on three assumptions: // - // Example: strides {1, n, n*c} with n in [1,8], c in [1,16]. - // Comparing n vs n*c: diff = n*c - n = n*(c-1). - // At all-min (n=1,c=1): diff=0. At all-max (n=8,c=16): diff=120. - // operator< sees range [0,120] -- neither strictly positive nor - // non-positive -- and throws "undetermined". But the shape is clearly - // transposed for all non-degenerate dimension values. + // 1. Symbolic strides are products of dimension variables times constant + // factors — no symbolic divisors. All stride-producing paths (compute_strides, + // step, reshape_lazy) enforce this. + // 2. Strides originate from compute_strides() or permutations thereof + // (reorder_shape / from_permutation), not arbitrary user construction. + // 3. Because strides are products of dims (all >= 1), the ordering at max + // evaluation is consistent with all non-degenerate runtime evaluations. // - // To avoid these issues, symbolic strides are evaluated at their max variable - // values where no dimension is degenerate (unit), then sorted concretely. + // Strict symbolic comparison (operator<) is insufficient: when any dim has + // min=1 (e.g. seq_len in LLM decoding), stride products collapse and the + // comparison throws "undetermined". template static bool is_sorted_strides(const std::vector& strides) { diff --git a/test/shape_test.cpp b/test/shape_test.cpp index d408872b880..dab3146330e 100644 --- a/test/shape_test.cpp +++ b/test/shape_test.cpp @@ -1779,6 +1779,17 @@ TEST_CASE(reorder_shape_symbolic) EXPECT(reordered.dyn_dims().size() == s.dyn_dims().size()); } +TEST_CASE(find_permutation_symbolic_stride_ordering_reversal) +{ + auto a = var("a", {1, 16}); + auto b = var("b", {1, 4}); + auto c = var("c", {1, 8}); + // a/b has interval [0, 16], c has interval [1, 8]. + // At max: 16 > 8 (a/b sorted first), at min: 0 < 1 (reversal). + migraphx::shape s{migraphx::shape::float_type, {dd{a}, dd{c}}, {a / b, c}}; + EXPECT(test::throws([&] { migraphx::find_permutation(s); })); +} + TEST_CASE(test_symbolic_elements_via_to_static) { auto n = var("n", {1, 8}); From d94f3671e6382fe352f2fd6f5b5d5da2ce88f32e Mon Sep 17 00:00:00 2001 From: Shiv Date: Thu, 23 Apr 2026 11:22:45 -0700 Subject: [PATCH 55/60] add missing tests --- test/shape_test.cpp | 121 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) diff --git a/test/shape_test.cpp b/test/shape_test.cpp index dab3146330e..fd67fe8aae4 100644 --- a/test/shape_test.cpp +++ b/test/shape_test.cpp @@ -1643,6 +1643,54 @@ TEST_CASE(test_symbolic_broadcasted2) EXPECT(s.broadcasted()); } +TEST_CASE(test_symbolic_broadcasted3) +{ + auto n = var("n", {1, 8}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{lit(3)}}, {lit(0), lit(1)}}; + EXPECT(s.symbolic()); + EXPECT(not s.standard()); + EXPECT(not s.packed()); + EXPECT(not s.transposed()); + EXPECT(s.broadcasted()); +} + +TEST_CASE(test_symbolic_broadcasted4) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + migraphx::shape s{ + migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}}, {c * lit(4), lit(0), lit(1)}}; + EXPECT(s.symbolic()); + EXPECT(not s.standard()); + EXPECT(not s.packed()); + EXPECT(not s.transposed()); + EXPECT(s.broadcasted()); +} + +TEST_CASE(test_symbolic_broadcasted5) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + migraphx::shape s{ + migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}}, {lit(1), lit(0), n * c}}; + EXPECT(s.symbolic()); + EXPECT(not s.standard()); + EXPECT(not s.packed()); + EXPECT(s.transposed()); + EXPECT(s.broadcasted()); +} + +TEST_CASE(test_symbolic_step_broadcasted) +{ + auto n = var("n", {1, 8}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{lit(3)}}, {lit(0), n}}; + EXPECT(s.symbolic()); + EXPECT(not s.standard()); + EXPECT(not s.packed()); + EXPECT(not s.transposed()); + EXPECT(s.broadcasted()); +} + TEST_CASE(test_symbolic_normalize_standard) { auto n = var("n", {1, 4}); @@ -1715,6 +1763,59 @@ TEST_CASE(test_symbolic_with_lens_4d) EXPECT(s2.dyn_dims() == new_dims); } +TEST_CASE(test_symbolic_with_lens_ambiguous_singleton_nchw) +{ + auto n = var("n", {1, 64}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{lit(1)}, dd{lit(24)}, dd{lit(24)}}}; + auto c = var("c", {1, 16}); + std::vector
new_dims = {dd{n}, dd{c}, dd{lit(24)}, dd{lit(24)}}; + auto s2 = s.with_lens(new_dims); + EXPECT(s2.symbolic()); + EXPECT(not s2.transposed()); + EXPECT(s2.dyn_dims() == new_dims); +} + +TEST_CASE(test_symbolic_with_lens_ambiguous_singleton_nhwc) +{ + auto n = var("n", {1, 64}); + auto s1 = migraphx::reorder_shape( + migraphx::shape{migraphx::shape::float_type, {dd{n}, dd{lit(24)}, dd{lit(24)}, dd{lit(1)}}}, + {0, 3, 1, 2}); + EXPECT(s1.transposed()); + auto c = var("c", {1, 16}); + std::vector
new_dims = {dd{n}, dd{c}, dd{lit(24)}, dd{lit(24)}}; + auto s2 = s1.with_lens(new_dims); + EXPECT(s2.symbolic()); + EXPECT(s2.transposed()); + EXPECT(s2.dyn_dims() == new_dims); +} + +TEST_CASE(test_symbolic_with_lens_ambiguous_all_singleton) +{ + auto n = var("n", {1, 64}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{lit(1)}, dd{lit(1)}, dd{lit(1)}}}; + auto c = var("c", {1, 16}); + std::vector
new_dims = {dd{n}, dd{c}, dd{lit(24)}, dd{lit(24)}}; + auto s2 = s.with_lens(new_dims); + EXPECT(s2.symbolic()); + EXPECT(s2.standard()); + EXPECT(s2.dyn_dims() == new_dims); +} + +TEST_CASE(test_symbolic_with_lens_ambiguous_nhwc_all_singleton) +{ + auto n = var("n", {1, 64}); + auto s1 = migraphx::reorder_shape( + migraphx::shape{migraphx::shape::float_type, {dd{n}, dd{lit(1)}, dd{lit(1)}, dd{lit(3)}}}, + {0, 3, 1, 2}); + auto c = var("c", {1, 16}); + std::vector
new_dims = {dd{n}, dd{c}, dd{lit(24)}, dd{lit(24)}}; + auto s2 = s1.with_lens(new_dims); + EXPECT(s2.symbolic()); + EXPECT(s2.transposed()); + EXPECT(s2.dyn_dims() == new_dims); +} + TEST_CASE(find_permutation_symbolic_2d_standard) { auto n = var("n", {1, 8}); @@ -1731,6 +1832,16 @@ TEST_CASE(find_permutation_symbolic_2d_transpose) EXPECT(migraphx::find_permutation(s) == permutation); } +TEST_CASE(find_permutation_symbolic_3d) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + auto h = var("h", {2, 32}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{c}, dd{h}}, {lit(1), c * h, n}}; + std::vector permutation = {1, 2, 0}; + EXPECT(migraphx::find_permutation(s) == permutation); +} + TEST_CASE(from_symbolic_2d_permutation) { auto n = var("n", {1, 8}); @@ -2071,4 +2182,14 @@ TEST_CASE(shape_is_compatible_lens_symbolic_same) EXPECT(migraphx::shape::is_compatible_lens(s1, s2)); } +TEST_CASE(shape_is_compatible_lens_static_vs_symbolic) +{ + auto n = var("n", {2, 8}); + migraphx::shape actual1{migraphx::shape::float_type, {1, 4, 3}}; + migraphx::shape actual2{migraphx::shape::float_type, {1, 16, 3}}; + migraphx::shape expected{migraphx::shape::float_type, {dd{lit(1)}, dd{n}, dd{lit(3)}}}; + EXPECT(migraphx::shape::is_compatible_lens(actual1, expected)); + EXPECT(not migraphx::shape::is_compatible_lens(actual2, expected)); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From 3e0e2c214c3e9640c124cda505172a79fb75d788 Mon Sep 17 00:00:00 2001 From: Shiv Date: Fri, 24 Apr 2026 11:55:20 -0700 Subject: [PATCH 56/60] make var bounds non-optional and add deprecation TODO to clarify the duplication of the interval struct in sym and dynamic_dimension --- src/include/migraphx/shape.hpp | 7 +++++++ src/include/migraphx/sym.hpp | 4 +--- test/sym_test.cpp | 9 ++++++++- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index 56379f5274b..b2826e408f7 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -95,6 +95,13 @@ struct MIGRAPHX_EXPORT shape { }; + // TODO: Deprecate the pure range-based form of dynamic_dimension in favor + // of the symbolic form (sym_expr). The current design carries two parallel + // notions of bounds -- dynamic_dimension::interval (std::size_t min/max, + // here) and sym::interval (int64_t min/max, attached to each sym::var) -- + // which is a source of confusion. Once all shape-producing paths go through + // symbolic expressions, `range`/`optimals` and this nested `interval` can + // be removed and bounds will live solely on sym::var. struct MIGRAPHX_EXPORT dynamic_dimension { struct interval diff --git a/src/include/migraphx/sym.hpp b/src/include/migraphx/sym.hpp index 6b28e3016da..151a3038840 100644 --- a/src/include/migraphx/sym.hpp +++ b/src/include/migraphx/sym.hpp @@ -53,9 +53,7 @@ struct interval }; struct expr; -MIGRAPHX_EXPORT expr var(const std::string& name, - interval bounds = {1, 1}, - std::set optimals = {}); +MIGRAPHX_EXPORT expr var(const std::string& name, interval bounds, std::set optimals = {}); MIGRAPHX_EXPORT expr lit(int64_t n); MIGRAPHX_EXPORT expr parse(const std::string& s); diff --git a/test/sym_test.cpp b/test/sym_test.cpp index 73d5ff23684..a777caf129c 100644 --- a/test/sym_test.cpp +++ b/test/sym_test.cpp @@ -30,7 +30,14 @@ using se = migraphx::sym::expr; using interval = migraphx::sym::interval; using migraphx::sym::lit; using migraphx::sym::parse; -using migraphx::sym::var; + +// Local wrappers so sym-library arithmetic/canonicalization tests don't have +// to spell out bounds they don't care about +static se var(const std::string& name) { return migraphx::sym::var(name, {1, 1}); } +static se var(const std::string& name, interval bounds, std::set optimals = {}) +{ + return migraphx::sym::var(name, bounds, optimals); +} // =================================================================== // Tier 1: Expression construction and canonicalization From 94c7941471e35fe955cded12aee8c3dd211c323d Mon Sep 17 00:00:00 2001 From: Shiv Date: Fri, 24 Apr 2026 15:23:28 -0700 Subject: [PATCH 57/60] add scalar variant to ease merging 4782 --- src/include/migraphx/shape.hpp | 5 ++--- src/include/migraphx/sym.hpp | 13 +++++++++++-- src/permutation.cpp | 14 ++++++++++---- src/shape.cpp | 4 ++-- src/sym.cpp | 13 +++++++------ 5 files changed, 32 insertions(+), 17 deletions(-) diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index b2826e408f7..e59f8f2c684 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -153,9 +153,8 @@ struct MIGRAPHX_EXPORT shape if(is_symbolic()) { auto ival = sym_expr.eval_interval(); - if(ival.min < 0 or ival.max < 0) - MIGRAPHX_THROW("dynamic_dimension: symbolic expression has negative bounds"); - return {static_cast(ival.min), static_cast(ival.max)}; + assert(sym::to(ival.min) >= 0 and sym::to(ival.max) >= 0); + return {sym::to(ival.min), sym::to(ival.max)}; } return *range; } diff --git a/src/include/migraphx/sym.hpp b/src/include/migraphx/sym.hpp index 151a3038840..63dde97e37b 100644 --- a/src/include/migraphx/sym.hpp +++ b/src/include/migraphx/sym.hpp @@ -31,6 +31,7 @@ #include #include #include +#include #include @@ -41,10 +42,18 @@ struct value; namespace sym { +using scalar = std::variant; + +template +To to(const scalar& v) +{ + return std::visit([](auto x) -> To { return x; }, v); +} + struct interval { - int64_t min = 0; - int64_t max = 0; + scalar min = int64_t{0}; + scalar max = int64_t{0}; friend bool operator==(const interval& a, const interval& b) { return a.min == b.min and a.max == b.max; diff --git a/src/permutation.cpp b/src/permutation.cpp index 8b0d93af9f6..39bf647dc0d 100644 --- a/src/permutation.cpp +++ b/src/permutation.cpp @@ -65,17 +65,23 @@ std::vector find_permutation(const shape& s) std::transform(strides.begin(), strides.end(), stride_intervals.begin(), [](const auto& e) { return e.eval_interval(); }); + std::vector dim_max(dds.size()); + std::transform(dds.begin(), dds.end(), dim_max.begin(), [](const auto& dd) { + return sym::to(dd.sym_expr.eval_interval().max); + }); std::stable_sort(result.begin(), result.end(), by(std::greater<>{}, [&](auto x) { - return std::make_tuple(stride_intervals[x].max, - dds[x].sym_expr.eval_interval().max); + return std::make_tuple(sym::to(stride_intervals[x].max), + dim_max[x]); })); // Assumption 3 guard: when max-eval gives a strict ordering between two // adjacent strides, min-eval must not reverse it. Collapse to equality at // min is expected (e.g. when a dim has min=1), but a sign flip indicates // a symbolic divisor violating assumption 1. if(std::adjacent_find(result.begin(), result.end(), [&](auto a, auto b) { - return stride_intervals[a].max > stride_intervals[b].max and - stride_intervals[a].min < stride_intervals[b].min; + return sym::to(stride_intervals[a].max) > + sym::to(stride_intervals[b].max) and + sym::to(stride_intervals[a].min) < + sym::to(stride_intervals[b].min); }) != result.end()) MIGRAPHX_THROW("FIND_PERMUTATION: symbolic stride ordering reversal between " "max-eval and min-eval. Violation of symbolic stride assumptions."); diff --git a/src/shape.cpp b/src/shape.cpp index b2d27857ebd..4aa67e157b3 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -93,7 +93,7 @@ struct shape_impl { assert(m_dyn_strides.size() == m_dyn_dims.size()); assert(std::all_of(m_dyn_strides.begin(), m_dyn_strides.end(), [](const auto& s) { - return s.eval_interval().min >= 0; + return sym::to(s.eval_interval().min) >= 0; })); auto dim_exprs = sym_dims(); std::vector filtered_strides; @@ -270,7 +270,7 @@ struct shape_impl { std::vector concrete(strides.size()); std::transform(strides.begin(), strides.end(), concrete.begin(), [](const auto& s) { - return static_cast(s.eval_interval().max); + return sym::to(s.eval_interval().max); }); return std::is_sorted(concrete.rbegin(), concrete.rend()); } diff --git a/src/sym.cpp b/src/sym.cpp index 6a376dded23..4aa68370812 100644 --- a/src/sym.cpp +++ b/src/sym.cpp @@ -1100,9 +1100,9 @@ bool operator<(const expr& a, const expr& b) if(a.empty() or b.empty()) MIGRAPHX_THROW("sym::expr: cannot compare empty expression"); auto ival = (b - a).eval_interval(); - if(ival.min > 0) + if(to(ival.min) > 0) return true; - if(ival.max <= 0) + if(to(ival.max) <= 0) return false; MIGRAPHX_THROW("sym::expr: comparison undetermined for: " + print_expr(a.p->node) + " < " + print_expr(b.p->node)); @@ -1119,12 +1119,13 @@ expr var(const std::string& name, interval bounds, std::set optimals) { if(name.empty()) MIGRAPHX_THROW("sym::var: variable name must not be empty"); - if(bounds.min > bounds.max) + auto bmin = to(bounds.min); + auto bmax = to(bounds.max); + if(bmin > bmax) MIGRAPHX_THROW("sym::var: variable interval must satisfy min <= max"); - if(bounds.min < 1) + if(bmin < 1) MIGRAPHX_THROW("sym::var: variable interval must satisfy min >= 1"); - return {std::make_shared( - make_symbol(name, bounds.min, bounds.max, std::move(optimals)))}; + return {std::make_shared(make_symbol(name, bmin, bmax, std::move(optimals)))}; } expr lit(int64_t n) { return {std::make_shared(make_integer(n))}; } From a64a0d1d42b6590387dea1625adbd570095e12e0 Mon Sep 17 00:00:00 2001 From: Shiv Date: Fri, 24 Apr 2026 16:29:26 -0700 Subject: [PATCH 58/60] fix brace-init ambiguity for sles --- src/include/migraphx/sym.hpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/include/migraphx/sym.hpp b/src/include/migraphx/sym.hpp index 63dde97e37b..453f063e408 100644 --- a/src/include/migraphx/sym.hpp +++ b/src/include/migraphx/sym.hpp @@ -31,6 +31,7 @@ #include #include #include +#include #include #include @@ -54,6 +55,15 @@ struct interval { scalar min = int64_t{0}; scalar max = int64_t{0}; + + interval() = default; + interval(scalar mn, scalar mx) : min{std::move(mn)}, max{std::move(mx)} {} + // Convenience overload so brace-init with bare integer literals (e.g. + // `interval{1, 8}` or `var("n", {1, 8})`) resolves unambiguously rather + // than triggering the variant converting-ctor's int-vs-double tie that + // some libstdc++ versions (notably on SLES) reject. + interval(int64_t mn, int64_t mx) : min{mn}, max{mx} {} + friend bool operator==(const interval& a, const interval& b) { return a.min == b.min and a.max == b.max; From c4a33e1fe13a8aefbf2b02f145b8f8dddbd6c0a2 Mon Sep 17 00:00:00 2001 From: Shiv Date: Fri, 24 Apr 2026 20:08:29 -0700 Subject: [PATCH 59/60] wrap scalar variant to handle ambuigity --- src/include/migraphx/sym.hpp | 36 ++++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/src/include/migraphx/sym.hpp b/src/include/migraphx/sym.hpp index 453f063e408..45b31ff0e78 100644 --- a/src/include/migraphx/sym.hpp +++ b/src/include/migraphx/sym.hpp @@ -30,8 +30,8 @@ #include #include #include +#include #include -#include #include #include @@ -43,27 +43,39 @@ struct value; namespace sym { -using scalar = std::variant; +// Scalar value held by literal expressions and interval bounds. Wraps a +// variant so that integer-literal initialization is unambiguous on stricter +// libstdc++ versions. +struct scalar +{ + std::variant value; + + constexpr scalar() = default; + + template , int> = 0> + constexpr scalar(T v) : value{int64_t{v}} // NOLINT(google-explicit-constructor) + { + } + + template , int> = 0> + constexpr scalar(T v) : value{double{v}} // NOLINT(google-explicit-constructor) + { + } + + friend bool operator==(const scalar& a, const scalar& b) { return a.value == b.value; } + friend bool operator!=(const scalar& a, const scalar& b) { return not(a == b); } +}; template To to(const scalar& v) { - return std::visit([](auto x) -> To { return x; }, v); + return std::visit([](auto x) -> To { return x; }, v.value); } struct interval { scalar min = int64_t{0}; scalar max = int64_t{0}; - - interval() = default; - interval(scalar mn, scalar mx) : min{std::move(mn)}, max{std::move(mx)} {} - // Convenience overload so brace-init with bare integer literals (e.g. - // `interval{1, 8}` or `var("n", {1, 8})`) resolves unambiguously rather - // than triggering the variant converting-ctor's int-vs-double tie that - // some libstdc++ versions (notably on SLES) reject. - interval(int64_t mn, int64_t mx) : min{mn}, max{mx} {} - friend bool operator==(const interval& a, const interval& b) { return a.min == b.min and a.max == b.max; From 3887e3fed9e87e2ae4d2bb3637820aa234bf78af Mon Sep 17 00:00:00 2001 From: Shiv Date: Fri, 24 Apr 2026 22:34:58 -0700 Subject: [PATCH 60/60] use migraphx_requires and fix tidy --- src/include/migraphx/sym.hpp | 6 +++--- test/sym_test.cpp | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/include/migraphx/sym.hpp b/src/include/migraphx/sym.hpp index 45b31ff0e78..f5201cdea76 100644 --- a/src/include/migraphx/sym.hpp +++ b/src/include/migraphx/sym.hpp @@ -30,11 +30,11 @@ #include #include #include -#include #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -52,12 +52,12 @@ struct scalar constexpr scalar() = default; - template , int> = 0> + template {})> constexpr scalar(T v) : value{int64_t{v}} // NOLINT(google-explicit-constructor) { } - template , int> = 0> + template {})> constexpr scalar(T v) : value{double{v}} // NOLINT(google-explicit-constructor) { } diff --git a/test/sym_test.cpp b/test/sym_test.cpp index a777caf129c..ff5a5e77043 100644 --- a/test/sym_test.cpp +++ b/test/sym_test.cpp @@ -24,6 +24,7 @@ #include #include +#include #include "test.hpp" using se = migraphx::sym::expr; @@ -36,7 +37,7 @@ using migraphx::sym::parse; static se var(const std::string& name) { return migraphx::sym::var(name, {1, 1}); } static se var(const std::string& name, interval bounds, std::set optimals = {}) { - return migraphx::sym::var(name, bounds, optimals); + return migraphx::sym::var(name, bounds, std::move(optimals)); } // ===================================================================