diff --git a/src/gemm.cpp b/src/gemm.cpp index 2deef7fb673..49c21a22fb8 100644 --- a/src/gemm.cpp +++ b/src/gemm.cpp @@ -72,15 +72,18 @@ struct batch_slicer batch_slicer(const shape& mat_shape) { auto n_batch_dims = mat_shape.ndim() - 2; - inner_shape = shape{mat_shape.type(), - {mat_shape.lens().end() - 2, mat_shape.lens().end()}, - {mat_shape.strides().end() - 2, mat_shape.strides().end()}}; + inner_shape = shape{ + mat_shape.type(), + std::vector{mat_shape.lens().end() - 2, mat_shape.lens().end()}, + std::vector{mat_shape.strides().end() - 2, mat_shape.strides().end()}}; if(n_batch_dims > 0) { outer_shape = shape{mat_shape.type(), - {mat_shape.lens().begin(), mat_shape.lens().begin() + n_batch_dims}, - {mat_shape.strides().begin(), mat_shape.strides().begin() + n_batch_dims}}; + std::vector{mat_shape.lens().begin(), + mat_shape.lens().begin() + n_batch_dims}, + std::vector{mat_shape.strides().begin(), + mat_shape.strides().begin() + n_batch_dims}}; } } diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index d8f017c974a..e59f8f2c684 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -39,6 +39,7 @@ #include #include #include +#include #include namespace migraphx { @@ -94,6 +95,13 @@ struct MIGRAPHX_EXPORT shape { }; + // TODO: Deprecate the pure range-based form of dynamic_dimension in favor + // of the symbolic form (sym_expr). The current design carries two parallel + // notions of bounds -- dynamic_dimension::interval (std::size_t min/max, + // here) and sym::interval (int64_t min/max, attached to each sym::var) -- + // which is a source of confusion. Once all shape-producing paths go through + // symbolic expressions, `range`/`optimals` and this nested `interval` can + // be removed and bounds will live solely on sym::var. struct MIGRAPHX_EXPORT dynamic_dimension { struct interval @@ -112,42 +120,76 @@ struct MIGRAPHX_EXPORT shape friend bool operator!=(const interval& a, const interval& b) { return not(a == b); } }; - interval range = {0, 0}; - std::set optimals{}; + std::optional range; + std::optional> optimals; + sym::expr sym_expr; dynamic_dimension() = default; - dynamic_dimension(std::size_t min_v, std::size_t max_v) : range{min_v, max_v} {} + dynamic_dimension(std::size_t min_v, std::size_t max_v) + : range{interval{min_v, max_v}}, optimals{std::set{}} + { + } dynamic_dimension(std::size_t min_v, std::size_t max_v, std::set opt) - : range{min_v, max_v}, optimals(std::move(opt)) + : range{interval{min_v, max_v}}, + optimals(min_v == max_v ? std::set{} : std::move(opt)) + { + } + dynamic_dimension(sym::expr s) : sym_expr(std::move(s)) { + if(sym_expr.empty()) + MIGRAPHX_THROW( + "dynamic_dimension: cannot construct from an empty symbolic expression"); } template static auto reflect(Self& self, F f) { - return pack(f(self.range, "range"), f(self.optimals, "optimals")); + return pack( + f(self.range, "range"), f(self.optimals, "optimals"), f(self.sym_expr, "sym")); } - interval get_interval() const { return range; } - std::set get_optimals() const { return optimals; } + interval get_interval() const + { + if(is_symbolic()) + { + auto ival = sym_expr.eval_interval(); + assert(sym::to(ival.min) >= 0 and sym::to(ival.max) >= 0); + return {sym::to(ival.min), sym::to(ival.max)}; + } + return *range; + } + std::set get_optimals() const + { + if(is_symbolic()) + return sym_expr.eval_optimals(); + if(optimals.has_value()) + return *optimals; + return {}; + } bool is_fixed() const; + bool is_symbolic() const { return not sym_expr.empty(); } bool has_optimal() const; /** * Return a dynamic_dimension with the intersection of two dynamic_dimension ranges if - * possible. + * possible. When both dimensions are symbolic, they are compatible only if they + * share the same symbolic expression. */ std::optional intersection(const dynamic_dimension& other) const { + if(this->is_symbolic() and other.is_symbolic()) + { + if(this->sym_expr == other.sym_expr) + return *this; + return nullopt; + } auto this_interval = this->get_interval(); auto other_interval = other.get_interval(); auto left = std::max(this_interval.min, other_interval.min); auto right = std::min(this_interval.max, other_interval.max); if(left <= right) - { return dynamic_dimension{left, right}; - } return nullopt; } @@ -164,20 +206,24 @@ struct MIGRAPHX_EXPORT shape MIGRAPHX_EXPORT friend bool operator!=(const dynamic_dimension& x, const std::size_t& y); MIGRAPHX_EXPORT friend bool operator!=(const std::size_t& x, const dynamic_dimension& y); - // add, subtract, multiply fixed std::size_t dimension - dynamic_dimension& operator+=(const std::size_t& x); - dynamic_dimension& operator-=(const std::size_t& x); - dynamic_dimension& operator*=(const std::size_t& x); - MIGRAPHX_EXPORT friend dynamic_dimension operator+(const dynamic_dimension& x, - const std::size_t& y); - MIGRAPHX_EXPORT friend dynamic_dimension operator+(const std::size_t& x, - const dynamic_dimension& y); - MIGRAPHX_EXPORT friend dynamic_dimension operator-(const dynamic_dimension& x, - const std::size_t& y); - MIGRAPHX_EXPORT friend dynamic_dimension operator*(const dynamic_dimension& x, - const std::size_t& y); - MIGRAPHX_EXPORT friend dynamic_dimension operator*(const std::size_t& x, - const dynamic_dimension& y); + // clang-format off +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(binary_op, assign_op) \ + dynamic_dimension& operator assign_op(const dynamic_dimension& x); \ + dynamic_dimension& operator assign_op(const std::size_t& x); \ + MIGRAPHX_EXPORT friend dynamic_dimension operator binary_op( \ + const dynamic_dimension& x, const dynamic_dimension& y); \ + MIGRAPHX_EXPORT friend dynamic_dimension operator binary_op( \ + const dynamic_dimension& x, const std::size_t& y); \ + MIGRAPHX_EXPORT friend dynamic_dimension operator binary_op( \ + const std::size_t& x, const dynamic_dimension& y); + + MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(+, +=) + MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(-, -=) + MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(*, *=) + MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(/, /=) +#undef MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP + // clang-format on }; static std::string to_sizes_string(const std::vector& shapes); @@ -202,8 +248,10 @@ struct MIGRAPHX_EXPORT shape // Force all calls of the format `shape( type_t, { size_t compatibles } )` to map to // shape(type_t, std::vector l) shape(type_t t, std::initializer_list d); + shape(type_t t, std::initializer_list l, std::initializer_list s); shape(type_t t, std::vector dims); + shape(type_t t, std::vector dims, std::vector dstrides); // Construct a dynamic shape from vectors of mins, maxes, and optimals. // optimals_list is a vector of optimals that corresponds to each min and max. @@ -242,6 +290,9 @@ struct MIGRAPHX_EXPORT shape */ static shape from_permutation(type_t t, const std::vector& l, const std::vector& perm); + static shape from_permutation(type_t t, + const std::vector& dds, + const std::vector& perm); type_t type() const; const std::vector& lens() const; @@ -272,6 +323,9 @@ struct MIGRAPHX_EXPORT shape const std::vector& dyn_dims() const; + bool symbolic() const; + const std::vector& dyn_strides() const; + /*! * Minimum lengths for dynamic shape. * lens() for static shape. @@ -388,14 +442,17 @@ struct MIGRAPHX_EXPORT shape shape with_lens(type_t t, const std::vector& l) const; shape with_lens(const std::vector& l) const; + shape with_lens(type_t t, const std::vector& dds) const; + shape with_lens(const std::vector& dds) const; shape with_type(type_t t) const; - // convert the shape to an equivalent dynamic shape with empty optimals + // convert the shape to an equivalent dynamic shape with constant symbolic strides shape to_dynamic() const; // convert the shape to a static one setting any non-fixed dynamic_dimensions to x shape to_static(std::size_t x) const; + shape to_static(const std::unordered_map& symbol_map) const; MIGRAPHX_EXPORT friend bool operator==(const shape& x, const shape& y); MIGRAPHX_EXPORT friend bool operator!=(const shape& x, const shape& y); diff --git a/src/include/migraphx/sym.hpp b/src/include/migraphx/sym.hpp index 59d3f842884..f5201cdea76 100644 --- a/src/include/migraphx/sym.hpp +++ b/src/include/migraphx/sym.hpp @@ -28,10 +28,13 @@ #include #include #include +#include #include #include +#include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -40,8 +43,48 @@ struct value; namespace sym { +// Scalar value held by literal expressions and interval bounds. Wraps a +// variant so that integer-literal initialization is unambiguous on stricter +// libstdc++ versions. +struct scalar +{ + std::variant value; + + constexpr scalar() = default; + + template {})> + constexpr scalar(T v) : value{int64_t{v}} // NOLINT(google-explicit-constructor) + { + } + + template {})> + constexpr scalar(T v) : value{double{v}} // NOLINT(google-explicit-constructor) + { + } + + friend bool operator==(const scalar& a, const scalar& b) { return a.value == b.value; } + friend bool operator!=(const scalar& a, const scalar& b) { return not(a == b); } +}; + +template +To to(const scalar& v) +{ + return std::visit([](auto x) -> To { return x; }, v.value); +} + +struct interval +{ + scalar min = int64_t{0}; + scalar max = int64_t{0}; + friend bool operator==(const interval& a, const interval& b) + { + return a.min == b.min and a.max == b.max; + } + friend bool operator!=(const interval& a, const interval& b) { return not(a == b); } +}; + struct expr; -MIGRAPHX_EXPORT expr var(const std::string& name); +MIGRAPHX_EXPORT expr var(const std::string& name, interval bounds, std::set optimals = {}); MIGRAPHX_EXPORT expr lit(int64_t n); MIGRAPHX_EXPORT expr parse(const std::string& s); @@ -50,11 +93,14 @@ struct MIGRAPHX_EXPORT expr expr(); bool empty() const; + bool is_literal() const; std::size_t hash() const; std::string to_string() const; value to_value() const; void from_value(const value& v); std::size_t eval_uint(const std::unordered_map& symbol_map) const; + interval eval_interval() const; + std::set eval_optimals() const; expr subs(const std::unordered_map& symbol_map) const; MIGRAPHX_EXPORT friend expr operator+(const expr& a, const expr& b); @@ -63,6 +109,10 @@ struct MIGRAPHX_EXPORT expr MIGRAPHX_EXPORT friend expr operator/(const expr& a, const expr& b); MIGRAPHX_EXPORT friend bool operator==(const expr& a, const expr& b); MIGRAPHX_EXPORT friend bool operator!=(const expr& a, const expr& b); + MIGRAPHX_EXPORT friend bool operator<(const expr& a, const expr& b); + friend bool operator>(const expr& a, const expr& b) { return b < a; } + friend bool operator<=(const expr& a, const expr& b) { return not(b < a); } + friend bool operator>=(const expr& a, const expr& b) { return not(a < b); } MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const expr& e); friend expr operator+(const expr& a, int64_t b) { return a + lit(b); } @@ -76,7 +126,8 @@ struct MIGRAPHX_EXPORT expr struct impl; - MIGRAPHX_EXPORT friend expr var(const std::string& name); + MIGRAPHX_EXPORT friend expr + var(const std::string& name, interval bounds, std::set optimals); MIGRAPHX_EXPORT friend expr lit(int64_t n); MIGRAPHX_EXPORT friend expr parse(const std::string& s); diff --git a/src/permutation.cpp b/src/permutation.cpp index f152e2c5a26..39bf647dc0d 100644 --- a/src/permutation.cpp +++ b/src/permutation.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -33,6 +34,10 @@ inline namespace MIGRAPHX_INLINE_NS { shape reorder_shape(const shape& s, const std::vector& permutation) { + if(s.symbolic()) + return {s.type(), + reorder_dims(s.dyn_dims(), permutation), + reorder_dims(s.dyn_strides(), permutation)}; return {s.type(), reorder_dims(s.lens(), permutation), reorder_dims(s.strides(), permutation)}; } @@ -43,11 +48,50 @@ std::vector invert_permutation(const std::vector& permutation) std::vector find_permutation(const shape& s) { - std::vector result(s.lens().size()); + if(s.dynamic() and not s.symbolic()) + MIGRAPHX_THROW("FIND_PERMUTATION: non-symbolic dynamic shapes not supported"); + std::vector result(s.ndim()); std::iota(result.begin(), result.end(), 0); - std::stable_sort(result.begin(), result.end(), by(std::greater<>{}, [&](auto x) { - return std::make_tuple(s.strides()[x], s.lens()[x]); - })); + if(s.symbolic()) + { + // Sort symbolic strides by evaluating at max variable values. + // Assumptions (see is_sorted_strides in shape.cpp for details): + // 1. Strides are products of dim variables * constant factors (no symbolic divisors) + // 2. Strides come from compute_strides() or permutations thereof + // 3. Max-eval ordering is consistent with all non-degenerate runtime orderings + const auto& strides = s.dyn_strides(); + const auto& dds = s.dyn_dims(); + std::vector stride_intervals(strides.size()); + std::transform(strides.begin(), strides.end(), stride_intervals.begin(), [](const auto& e) { + return e.eval_interval(); + }); + std::vector dim_max(dds.size()); + std::transform(dds.begin(), dds.end(), dim_max.begin(), [](const auto& dd) { + return sym::to(dd.sym_expr.eval_interval().max); + }); + std::stable_sort(result.begin(), result.end(), by(std::greater<>{}, [&](auto x) { + return std::make_tuple(sym::to(stride_intervals[x].max), + dim_max[x]); + })); + // Assumption 3 guard: when max-eval gives a strict ordering between two + // adjacent strides, min-eval must not reverse it. Collapse to equality at + // min is expected (e.g. when a dim has min=1), but a sign flip indicates + // a symbolic divisor violating assumption 1. + if(std::adjacent_find(result.begin(), result.end(), [&](auto a, auto b) { + return sym::to(stride_intervals[a].max) > + sym::to(stride_intervals[b].max) and + sym::to(stride_intervals[a].min) < + sym::to(stride_intervals[b].min); + }) != result.end()) + MIGRAPHX_THROW("FIND_PERMUTATION: symbolic stride ordering reversal between " + "max-eval and min-eval. Violation of symbolic stride assumptions."); + } + else + { + std::stable_sort(result.begin(), result.end(), by(std::greater<>{}, [&](auto x) { + return std::make_tuple(s.strides()[x], s.lens()[x]); + })); + } return result; } @@ -64,7 +108,7 @@ std::vector find_permutation(const std::vector& shapes) } if(count.empty()) { - std::vector r(shapes.front().lens().size()); + std::vector r(shapes.front().ndim()); std::iota(r.begin(), r.end(), 0); return r; } diff --git a/src/shape.cpp b/src/shape.cpp index 7483f5cb91a..4aa67e157b3 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -23,6 +23,7 @@ */ #include +#include #include #include #include @@ -78,6 +79,29 @@ struct shape_impl shape_impl(shape::type_t t, std::vector dims) : m_type(t), m_dyn_dims(std::move(dims)) { + if(all_dims_symbolic()) + { + calculate_dyn_strides(); + m_standard = true; + } + } + + shape_impl(shape::type_t t, + std::vector dims, + std::vector dstrides) + : m_type(t), m_dyn_dims(std::move(dims)), m_dyn_strides(std::move(dstrides)) + { + assert(m_dyn_strides.size() == m_dyn_dims.size()); + assert(std::all_of(m_dyn_strides.begin(), m_dyn_strides.end(), [](const auto& s) { + return sym::to(s.eval_interval().min) >= 0; + })); + auto dim_exprs = sym_dims(); + std::vector filtered_strides; + for(std::size_t i = 0; i < m_dyn_strides.size(); i++) + if(m_dyn_dims[i] != 1) + filtered_strides.push_back(m_dyn_strides[i]); + m_standard = compute_packed(dim_exprs, m_dyn_strides) and + is_sorted_strides(filtered_strides); } shape_impl(shape::type_t t, @@ -101,6 +125,11 @@ struct shape_impl m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i], optimals_list[i]}); } } + if(all_dims_symbolic()) + { + calculate_dyn_strides(); + m_standard = true; + } } shape_impl(const std::vector& subs) : m_type(shape::tuple_type), m_shapes(subs) {} @@ -112,18 +141,203 @@ struct shape_impl bool m_standard = false; std::vector m_dyn_dims = {}; + std::vector m_dyn_strides = {}; + + bool all_dims_symbolic() const + { + return not m_dyn_dims.empty() and + std::all_of(m_dyn_dims.begin(), m_dyn_dims.end(), [](const auto& d) { + return d.is_symbolic(); + }); + } + + std::vector sym_dims() const + { + if(m_dyn_dims.empty()) + { + std::vector result(m_lens.size()); + std::transform(m_lens.begin(), m_lens.end(), result.begin(), [](auto len) { + return sym::lit(len); + }); + return result; + } + std::vector result(m_dyn_dims.size()); + std::transform(m_dyn_dims.begin(), m_dyn_dims.end(), result.begin(), [](const auto& dd) { + return dd.sym_expr; + }); + return result; + } + + template + static T make_identity(int64_t n) + { + if constexpr(std::is_same{}) + return sym::lit(n); + else + return T(n); + } + + template + static std::vector compute_strides(const std::vector& dims) + { + std::vector strides(dims.size()); + if(strides.empty()) + return strides; + strides.back() = make_identity(1); + std::partial_sum(dims.rbegin(), dims.rend() - 1, strides.rbegin() + 1, std::multiplies<>{}); + return strides; + } + + void calculate_dyn_strides() { m_dyn_strides = compute_strides(sym_dims()); } + + void calculate_strides() { m_strides = compute_strides(m_lens); } + + template + static T compute_elements(const std::vector& dims) + { + if(dims.empty()) + return make_identity(0); + return std::accumulate(dims.begin(), dims.end(), make_identity(1), std::multiplies<>{}); + } + + template + static T compute_element_space(const std::vector& dims, const std::vector& strides) + { + if(dims.empty()) + return make_identity(0); + auto one = make_identity(1); + return std::inner_product(dims.begin(), + dims.end(), + strides.begin(), + make_identity(0), + std::plus<>{}, + [&](const T& l, const T& s) { return (l - one) * s; }) + + one; + } + + template + static bool compute_skips(const std::vector& dims, const std::vector& strides) + { + if(compute_elements(dims) == make_identity(1)) + return false; + auto one = make_identity(1); + return std::none_of( + strides.begin(), strides.end(), [&](const auto& x) { return x == one; }); + } + + template + static bool compute_packed(const std::vector& dims, const std::vector& strides) + { + return not compute_skips(dims, strides) and + compute_elements(dims) == compute_element_space(dims, strides); + } + + template + static bool compute_broadcasted(const std::vector& strides) + { + auto zero = make_identity(0); + return std::any_of( + strides.begin(), strides.end(), [&](const auto& x) { return x == zero; }); + } + + template + static bool compute_scalar(const std::vector& strides) + { + auto zero = make_identity(0); + return std::accumulate(strides.begin(), strides.end(), zero) == zero; + } + + // Check if strides are in descending order (standard layout). + // + // For symbolic strides we evaluate at max variable values rather than using + // sym::expr::operator<. This relies on three assumptions: + // + // 1. Symbolic strides are products of dimension variables times constant + // factors — no symbolic divisors. All stride-producing paths (compute_strides, + // step, reshape_lazy) enforce this. + // 2. Strides originate from compute_strides() or permutations thereof + // (reorder_shape / from_permutation), not arbitrary user construction. + // 3. Because strides are products of dims (all >= 1), the ordering at max + // evaluation is consistent with all non-degenerate runtime evaluations. + // + // Strict symbolic comparison (operator<) is insufficient: when any dim has + // min=1 (e.g. seq_len in LLM decoding), stride products collapse and the + // comparison throws "undetermined". + template + static bool is_sorted_strides(const std::vector& strides) + { + if constexpr(std::is_same{}) + { + std::vector concrete(strides.size()); + std::transform(strides.begin(), strides.end(), concrete.begin(), [](const auto& s) { + return sym::to(s.eval_interval().max); + }); + return std::is_sorted(concrete.rbegin(), concrete.rend()); + } + else + { + return std::is_sorted(strides.rbegin(), strides.rend()); + } + } + + template + static bool compute_transposed(const std::vector& strides) + { + if(compute_broadcasted(strides)) + { + std::vector s; + s.reserve(strides.size()); + auto zero = make_identity(0); + std::copy_if(strides.begin(), strides.end(), std::back_inserter(s), [&](const auto& x) { + return x != zero; + }); + return not is_sorted_strides(s); + } + return not is_sorted_strides(strides); + } + + bool is_packed() const + { + if(not m_dyn_dims.empty()) + { + if(m_dyn_strides.empty()) + return false; + return compute_packed(sym_dims(), m_dyn_strides); + } + return compute_packed(m_lens, m_strides); + } + + bool is_broadcasted() const + { + if(not m_dyn_dims.empty()) + { + if(m_dyn_strides.empty()) + return false; + return compute_broadcasted(m_dyn_strides); + } + return compute_broadcasted(m_strides); + } + + bool is_transposed() const + { + if(not m_dyn_dims.empty()) + { + if(m_dyn_strides.empty()) + return false; + return compute_transposed(m_dyn_strides); + } + return compute_transposed(m_strides); + } - void calculate_strides() + bool is_scalar() const { - m_strides.clear(); - m_strides.resize(m_lens.size(), 0); - if(m_strides.empty()) - return; - m_strides.back() = 1; - std::partial_sum(m_lens.rbegin(), - m_lens.rend() - 1, - m_strides.rbegin() + 1, - std::multiplies()); + if(not m_dyn_dims.empty()) + { + if(m_dyn_strides.empty()) + return false; + return compute_scalar(m_dyn_strides); + } + return compute_scalar(m_strides); } std::size_t element_space() const @@ -135,7 +349,6 @@ struct shape_impl return std::accumulate( maxes.begin(), maxes.end(), std::size_t{1}, [&](std::size_t x, std::size_t y) { - // overflow check and clip if(x != 0 and y > max_val / x) { return max_val; @@ -145,15 +358,7 @@ struct shape_impl } assert(m_lens.size() == m_strides.size()); - if(m_lens.empty()) - return 0; - return std::inner_product(m_lens.begin(), - m_lens.end(), - m_strides.begin(), - std::size_t{0}, - std::plus{}, - [](std::size_t l, std::size_t s) { return (l - 1) * s; }) + - 1; + return compute_element_space(m_lens, m_strides); } std::size_t elements() const @@ -164,10 +369,7 @@ struct shape_impl } assert(m_lens.size() == m_strides.size()); - if(m_lens.empty()) - return 0; - return std::accumulate( - m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies()); + return compute_elements(m_lens); } std::size_t get_index(size_t i) const @@ -216,13 +418,10 @@ struct shape_impl return ret; } - // Does the shape skip over elements? bool skips() const { assert(m_lens.size() == m_strides.size()); - if(elements() == 1) - return false; - return std::none_of(m_strides.begin(), m_strides.end(), [](auto x) { return x == 1; }); + return compute_skips(m_lens, m_strides); } std::shared_ptr copy() const { return std::make_shared(*this); } @@ -357,11 +556,23 @@ shape::shape(type_t t, std::initializer_list d) { } +shape::shape(type_t t, std::initializer_list l, std::initializer_list s) + : shape::shape(t, + std::vector{l.begin(), l.end()}, + std::vector{s.begin(), s.end()}) +{ +} + shape::shape(type_t t, std::vector dims) : impl(std::make_shared(t, std::move(dims))) { } +shape::shape(type_t t, std::vector dims, std::vector dstrides) + : impl(std::make_shared(t, std::move(dims), std::move(dstrides))) +{ +} + shape::shape(type_t t, std::vector mins, std::vector maxes, @@ -375,16 +586,34 @@ shape::shape(const std::vector& subs) : impl(std::make_shared shape::shape(std::shared_ptr pimpl) : impl(std::move(pimpl)) {} +template +static shape +from_permutation_impl(shape::type_t t, const Dims& dims, const std::vector& perm) +{ + auto reordered = reorder_dims(dims, perm); + return reorder_shape({t, reordered}, invert_permutation(perm)); +} + shape shape::from_permutation(type_t t, const std::vector& l, const std::vector& perm) { - auto new_lens = reorder_dims(l, perm); - shape result = reorder_shape({t, new_lens}, invert_permutation(perm)); + shape result = from_permutation_impl(t, l, perm); assert(result.lens() == l); return result; } +shape shape::from_permutation(type_t t, + const std::vector& dds, + const std::vector& perm) +{ + if(std::any_of(dds.begin(), dds.end(), [](const auto& dd) { return not dd.is_symbolic(); })) + MIGRAPHX_THROW("FROM_PERMUTATION: non-symbolic dynamic dimensions not supported"); + shape result = from_permutation_impl(t, dds, perm); + assert(result.dyn_dims() == dds); + return result; +} + shape::type_t shape::type() const { return impl->m_type; } const std::vector& shape::lens() const @@ -527,76 +756,50 @@ std::size_t shape::single(const std::vector& idx) const bool shape::packed() const { - if(this->dynamic()) - { + if(this->dynamic() and not this->symbolic()) return false; - } - return this->sub_shapes().empty() and not impl->skips() and - this->elements() == this->element_space(); + return this->sub_shapes().empty() and impl->is_packed(); } bool shape::transposed() const { - if(this->dynamic()) - { + if(this->dynamic() and not this->symbolic()) return false; - } - if(this->broadcasted()) - { - // TODO: Use a filter_iterator instead - std::vector s; - s.reserve(this->strides().size()); - std::copy_if(this->strides().begin(), - this->strides().end(), - std::back_inserter(s), - [](std::size_t x) { return x != 0; }); - return not std::is_sorted(s.rbegin(), s.rend()); - } - else - { - return not std::is_sorted(this->strides().rbegin(), this->strides().rend()); - } + return impl->is_transposed(); } bool shape::broadcasted() const { - if(this->dynamic()) - { + if(this->dynamic() and not this->symbolic()) return false; - } - assert(this->lens().size() == this->strides().size()); - return std::any_of( - this->strides().begin(), this->strides().end(), [](auto x) { return x == 0; }); + return impl->is_broadcasted(); } bool shape::scalar() const { - if(this->dynamic()) - { + if(this->dynamic() and not this->symbolic()) return false; - } - assert(this->lens().size() == this->strides().size()); - // if any stride > 0, then accumulate will return false - return this->sub_shapes().empty() and - std::accumulate(this->strides().begin(), this->strides().end(), std::size_t(0)) == 0; + return this->sub_shapes().empty() and impl->is_scalar(); } bool shape::standard() const { return impl->m_standard; } shape shape::normalize_standard() const { - if(this->standard()) - return {this->type(), this->lens()}; - else + if(not this->standard()) return *this; + if(this->symbolic()) + return {this->type(), this->dyn_dims()}; + return {this->type(), this->lens()}; } shape shape::as_standard() const { + if(this->symbolic()) + return {this->type(), this->dyn_dims()}; if(not this->dynamic()) return {this->type(), this->lens()}; - else - return *this; + return *this; } shape shape::with_lens(type_t t, const std::vector& l) const @@ -619,6 +822,20 @@ shape shape::with_lens(const std::vector& l) const return this->with_lens(this->type(), l); } +shape shape::with_lens(type_t t, const std::vector& dds) const +{ + if(this->dynamic() and not this->symbolic()) + MIGRAPHX_THROW("SHAPE: with_lens() called on non-symbolic dynamic shape"); + assert(dds.size() == this->ndim()); + auto perm = find_permutation(*this); + return shape::from_permutation(t, dds, perm); +} + +shape shape::with_lens(const std::vector& dds) const +{ + return this->with_lens(this->type(), dds); +} + shape shape::with_type(type_t t) const { auto c = impl->copy(); @@ -641,7 +858,17 @@ shape shape::to_dynamic() const { return *this; } - return {type(), lens(), lens(), {}}; + std::vector dims; + dims.reserve(ndim()); + std::transform(lens().begin(), lens().end(), std::back_inserter(dims), [](auto len) { + return dynamic_dimension{len, len}; + }); + std::vector dstrides; + dstrides.reserve(ndim()); + std::transform(strides().begin(), strides().end(), std::back_inserter(dstrides), [](auto s) { + return sym::lit(s); + }); + return {type(), std::move(dims), std::move(dstrides)}; } shape shape::to_static(std::size_t x) const @@ -668,6 +895,40 @@ shape shape::to_static(std::size_t x) const return {type(), static_lens}; } +shape shape::to_static(const std::unordered_map& symbol_map) const +{ + if(not sub_shapes().empty()) + { + std::vector subs; + std::transform(sub_shapes().cbegin(), + sub_shapes().cend(), + std::back_inserter(subs), + [&](auto s) { return s.to_static(symbol_map); }); + return shape(subs); + } + if(not this->dynamic()) + return *this; + std::vector static_lens(this->ndim()); + std::transform(this->dyn_dims().cbegin(), + this->dyn_dims().cend(), + static_lens.begin(), + [&](const auto& dd) -> std::size_t { + if(dd.is_fixed()) + return dd.get_interval().min; + if(not dd.sym_expr.empty()) + return dd.sym_expr.eval_uint(symbol_map); + MIGRAPHX_THROW("to_static: non-fixed dimension has no symbolic expression"); + }); + const auto& ds = this->dyn_strides(); + if(ds.empty()) + return {type(), static_lens}; + std::vector static_strides(ds.size()); + std::transform(ds.cbegin(), ds.cend(), static_strides.begin(), [&](const auto& s) { + return s.eval_uint(symbol_map); + }); + return {type(), static_lens, static_strides}; +} + std::size_t shape::element_space() const { return impl->element_space(); } std::string shape::type_string() const { return name(this->type()); } @@ -696,6 +957,10 @@ const std::vector& shape::dyn_dims() const return impl->m_dyn_dims; } +bool shape::symbolic() const { return impl->all_dims_symbolic(); } + +const std::vector& shape::dyn_strides() const { return impl->m_dyn_strides; } + std::vector shape::min_lens() const { return this->dynamic() ? impl->min_lens() : this->lens(); @@ -710,58 +975,46 @@ std::vector> shape::opt_lens() const { return impl->opt_le bool shape::dynamic_dimension::is_fixed() const { + if(sym_expr.is_literal()) + return true; auto i = this->get_interval(); return i.min == i.max; } bool shape::dynamic_dimension::has_optimal() const { return not this->get_optimals().empty(); } -shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const std::size_t& x) -{ - this->range.min += x; - this->range.max += x; - std::set new_optimals; - std::transform(this->optimals.begin(), - this->optimals.end(), - std::inserter(new_optimals, new_optimals.begin()), - [&x](const auto& opt) { return (opt + x); }); - this->optimals = new_optimals; - return *this; -} - -shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t& x) -{ - assert(this->range.min >= x); - assert(this->range.max >= x); - this->range.min -= x; - this->range.max -= x; - std::set new_optimals; - std::transform(this->optimals.begin(), - this->optimals.end(), - std::inserter(new_optimals, new_optimals.begin()), - [&x](const auto& opt) { - assert(opt >= x); - return (opt - x); - }); - this->optimals = new_optimals; - return *this; -} - -shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const std::size_t& x) -{ - this->range.min *= x; - this->range.max *= x; - std::set new_optimals; - std::transform(this->optimals.begin(), - this->optimals.end(), - std::inserter(new_optimals, new_optimals.begin()), - [&x](const auto& opt) { return (opt * x); }); - this->optimals = new_optimals; - return *this; -} +// clang-format off +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define MIGRAPHX_SHAPE_DYN_DIM_IMPLEMENT_OP(binary_op, assign_op) \ + shape::dynamic_dimension& shape::dynamic_dimension::operator assign_op(const std::size_t& x) \ + { \ + return *this assign_op dynamic_dimension{sym::lit(x)}; \ + } \ + shape::dynamic_dimension operator binary_op( \ + const shape::dynamic_dimension& x, const std::size_t& y) \ + { \ + auto result = x; \ + result assign_op y; \ + return result; \ + } \ + shape::dynamic_dimension operator binary_op( \ + const std::size_t& x, const shape::dynamic_dimension& y) \ + { \ + return shape::dynamic_dimension{sym::lit(x)} binary_op y; \ + } \ + shape::dynamic_dimension operator binary_op( \ + const shape::dynamic_dimension& x, const shape::dynamic_dimension& y) \ + { \ + auto result = x; \ + result assign_op y; \ + return result; \ + } +// clang-format on bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y) { + if(not(x.sym_expr == y.sym_expr)) + return false; return (x.get_interval() == y.get_interval() and ((x.is_fixed() and y.is_fixed()) or (x.get_optimals() == y.get_optimals()))); } @@ -773,8 +1026,15 @@ bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimensio std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x) { auto x_interval = x.get_interval(); - os << "[ " << x_interval.min << ", " << x_interval.max << ", {" - << migraphx::to_string_range(x.get_optimals()) << "} ]"; + if(x.is_symbolic()) + os << x.sym_expr; + if(x.is_fixed()) + { + if(not x.is_symbolic()) + os << x_interval.min; + return os; + } + os << "[" << x_interval.min << ".." << x_interval.max << "]"; return os; } @@ -787,40 +1047,185 @@ bool operator==(const std::size_t& x, const shape::dynamic_dimension& y) { retur bool operator!=(const shape::dynamic_dimension& x, const std::size_t& y) { return not(x == y); } bool operator!=(const std::size_t& x, const shape::dynamic_dimension& y) { return not(x == y); } -shape::dynamic_dimension operator+(const shape::dynamic_dimension& x, const std::size_t& y) +MIGRAPHX_SHAPE_DYN_DIM_IMPLEMENT_OP(+, +=) +MIGRAPHX_SHAPE_DYN_DIM_IMPLEMENT_OP(-, -=) +MIGRAPHX_SHAPE_DYN_DIM_IMPLEMENT_OP(*, *=) +MIGRAPHX_SHAPE_DYN_DIM_IMPLEMENT_OP(/, /=) +#undef MIGRAPHX_SHAPE_DYN_DIM_IMPLEMENT_OP + +// When one operand is fixed, shift the other's optimals by the fixed value. +// When neither is fixed, optimals are cleared. +template +static void merge_optimals(std::set& optimals, + bool lhs_fixed, + const std::set& rhs_optimals, + bool rhs_fixed, + F1 shift_lhs, + F2 shift_rhs) +{ + if(rhs_fixed) + { + std::set result; + std::transform( + optimals.begin(), optimals.end(), std::inserter(result, result.begin()), shift_lhs); + optimals = result; + } + else if(lhs_fixed) + { + std::set result; + std::transform(rhs_optimals.begin(), + rhs_optimals.end(), + std::inserter(result, result.begin()), + shift_rhs); + optimals = result; + } + else + { + optimals.clear(); + } +} + +// Arithmetic semantics: symbolic + symbolic = symbolic, +// range + range = range, range + symbolic = range. +template +static shape::dynamic_dimension& apply_op(shape::dynamic_dimension& lhs, + const shape::dynamic_dimension& rhs, + SymOp sym_op, + RangeOp range_op) { - auto dd = x; - return dd += y; + auto lhs_sym = lhs.sym_expr; + auto rhs_sym = rhs.sym_expr; + auto result_sym = sym_op(lhs_sym, rhs_sym); + if(not result_sym.empty()) + { + lhs.sym_expr = result_sym; + lhs.range = std::nullopt; + lhs.optimals = std::nullopt; + } + else + { + // Materialize symbolic operands as range-based shapes so that + // arithmetic between symbolic and range-based dimensions works. + auto to_range = [](const shape::dynamic_dimension& d) { + auto iv = d.get_interval(); + return shape::dynamic_dimension{iv.min, iv.max, d.get_optimals()}; + }; + auto lhs_range = lhs.is_symbolic() ? to_range(lhs) : lhs; + auto rhs_range = rhs.is_symbolic() ? to_range(rhs) : rhs; + range_op(lhs_range, rhs_range); + lhs = lhs_range; + } + return lhs; } -shape::dynamic_dimension operator+(const std::size_t& x, const shape::dynamic_dimension& y) +shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const shape::dynamic_dimension& x) { - return y + x; + return apply_op( + *this, + x, + [](const auto& a, const auto& b) { return a + b; }, + [](auto& lhs, const auto& rhs) { + auto lhs_fixed = lhs.is_fixed(); + auto rhs_fixed = rhs.is_fixed(); + auto lhs_min = lhs.range->min; + lhs.range->min += rhs.range->min; + lhs.range->max = + (lhs.range->max > std::numeric_limits::max() - rhs.range->max) + ? std::numeric_limits::max() + : lhs.range->max + rhs.range->max; + merge_optimals( + *lhs.optimals, + lhs_fixed, + *rhs.optimals, + rhs_fixed, + [&](auto o) { return o + rhs.range->min; }, + [&](auto o) { return o + lhs_min; }); + }); } -shape::dynamic_dimension operator-(const shape::dynamic_dimension& x, const std::size_t& y) +shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const shape::dynamic_dimension& x) { - auto dd = x; - return dd -= y; + return apply_op( + *this, + x, + [](const auto& a, const auto& b) { return a - b; }, + [](auto& lhs, const auto& rhs) { + auto lhs_fixed = lhs.is_fixed(); + auto rhs_fixed = rhs.is_fixed(); + auto lhs_min = lhs.range->min; + lhs.range->min = + (lhs.range->min > rhs.range->max) ? lhs.range->min - rhs.range->max : 0; + lhs.range->max = + (lhs.range->max > rhs.range->min) ? lhs.range->max - rhs.range->min : 0; + merge_optimals( + *lhs.optimals, + lhs_fixed, + *rhs.optimals, + rhs_fixed, + [&](auto o) { return (o > rhs.range->min) ? o - rhs.range->min : std::size_t{0}; }, + [&](auto o) { return (lhs_min > o) ? lhs_min - o : std::size_t{0}; }); + }); } -shape::dynamic_dimension operator*(const shape::dynamic_dimension& x, const std::size_t& y) +shape::dynamic_dimension& shape::dynamic_dimension::operator*=(const shape::dynamic_dimension& x) { - auto dd = x; - return dd *= y; + return apply_op( + *this, + x, + [](const auto& a, const auto& b) { return a * b; }, + [](auto& lhs, const auto& rhs) { + auto lhs_fixed = lhs.is_fixed(); + auto rhs_fixed = rhs.is_fixed(); + auto lhs_min = lhs.range->min; + auto safe_mul = [](std::size_t a, std::size_t b) -> std::size_t { + if(b == 0) + return 0; + if(a > std::numeric_limits::max() / b) + return std::numeric_limits::max(); + return a * b; + }; + lhs.range->min = lhs.range->min * rhs.range->min; + lhs.range->max = safe_mul(lhs.range->max, rhs.range->max); + merge_optimals( + *lhs.optimals, + lhs_fixed, + *rhs.optimals, + rhs_fixed, + [&](auto o) { return o * rhs.range->min; }, + [&](auto o) { return o * lhs_min; }); + }); } -shape::dynamic_dimension operator*(const std::size_t& x, const shape::dynamic_dimension& y) +shape::dynamic_dimension& shape::dynamic_dimension::operator/=(const shape::dynamic_dimension& x) { - return y * x; + return apply_op( + *this, + x, + [](const auto& a, const auto& b) { return a / b; }, + [](auto& lhs, const auto& rhs) { + auto lhs_fixed = lhs.is_fixed(); + auto rhs_fixed = rhs.is_fixed(); + auto lhs_min = lhs.range->min; + lhs.range->min = (rhs.range->max == 0) ? 0 : lhs.range->min / rhs.range->max; + lhs.range->max = (rhs.range->min == 0) ? std::numeric_limits::max() + : lhs.range->max / rhs.range->min; + merge_optimals( + *lhs.optimals, + lhs_fixed, + *rhs.optimals, + rhs_fixed, + [&](auto o) { return (rhs.range->min == 0) ? std::size_t{0} : o / rhs.range->min; }, + [&](auto o) { return (o == 0) ? std::size_t{0} : lhs_min / o; }); + }); } bool operator==(const shape& x, const shape& y) { if(x.dynamic() and y.dynamic()) { - return x.impl == y.impl or (x.type() == y.type() and x.dyn_dims() == y.dyn_dims() and - x.sub_shapes() == y.sub_shapes()); + return x.impl == y.impl or + (x.type() == y.type() and x.dyn_dims() == y.dyn_dims() and + x.dyn_strides() == y.dyn_strides() and x.sub_shapes() == y.sub_shapes()); } return x.impl == y.impl or (x.dynamic() == y.dynamic() and x.type() == y.type() and x.lens() == y.lens() and @@ -833,7 +1238,23 @@ std::ostream& operator<<(std::ostream& os, const shape& x) { if(x.sub_shapes().empty()) { - if(x.dynamic()) + if(x.symbolic()) + { + os << x.type_string() << ", {"; + const auto& dd = x.dyn_dims(); + for(std::size_t i = 0; i < dd.size(); ++i) + { + if(i > 0) + os << ", "; + if(dd[i].is_symbolic()) + os << dd[i]; + else + os << dd[i].get_interval().min; + } + os << "}, "; + os << "{" << to_string_range(x.dyn_strides()) << "}"; + } + else if(x.dynamic()) { os << "dynamic, "; os << x.type_string() << ", "; @@ -900,7 +1321,6 @@ void migraphx_to_value(value& v, const shape& s) value result; result["type"] = migraphx::to_value(s.type_string()); result["sub_shapes"] = migraphx::to_value(s.sub_shapes()); - // avoid calling functions that will throw if(s.dynamic()) { result["lens"] = {}; @@ -913,6 +1333,14 @@ void migraphx_to_value(value& v, const shape& s) result["strides"] = migraphx::to_value(s.strides()); result["dynamic_dimensions"] = {}; } + if(s.symbolic()) + { + result["dyn_strides"] = migraphx::to_value(s.dyn_strides()); + } + else + { + result["dyn_strides"] = {}; + } v = result; } @@ -934,13 +1362,27 @@ void migraphx_from_value(const value& v, shape& s) else { auto v_dd = v.at("dynamic_dimensions"); - std::vector dyn_dims(v.at("dynamic_dimensions").size()); + std::vector dyn_dims(v_dd.size()); std::transform( v_dd.begin(), v_dd.end(), dyn_dims.begin(), [](const migraphx::value& x) { return from_value(x); }); - s = shape{shape::parse_type(t), dyn_dims}; + if(v.contains("dyn_strides") and not v.at("dyn_strides").empty()) + { + auto v_ds = v.at("dyn_strides"); + std::vector dstrides; + dstrides.reserve(v_ds.size()); + std::transform(v_ds.begin(), + v_ds.end(), + std::back_inserter(dstrides), + [](const auto& x) { return from_value(x); }); + s = shape(shape::parse_type(t), std::move(dyn_dims), std::move(dstrides)); + } + else + { + s = shape{shape::parse_type(t), dyn_dims}; + } } } } diff --git a/src/sym.cpp b/src/sym.cpp index 00d9480e810..4aa68370812 100644 --- a/src/sym.cpp +++ b/src/sym.cpp @@ -65,6 +65,9 @@ struct integer_data struct symbol_data { std::string name; + int64_t min; + int64_t max; + std::set optimals; }; struct add_data { @@ -127,7 +130,11 @@ static std::size_t compute_hash(const expr_data& d) return std::visit( overloaded{ [&](const integer_data& p) { return hash_combine(h, std::hash{}(p.value)); }, - [&](const symbol_data& p) { return hash_combine(h, std::hash{}(p.name)); }, + [&](const symbol_data& p) { + auto h2 = hash_combine(h, std::hash{}(p.name)); + h2 = hash_combine(h2, std::hash{}(p.min)); + return hash_combine(h2, std::hash{}(p.max)); + }, [&](const add_data& p) { return hash_combine(hash_combine(h, std::hash{}(p.constant)), hash_ordered_map(p.terms)); @@ -183,7 +190,14 @@ static int compare_expr(const expr_ptr& a, const expr_ptr& b) }, [&](const symbol_data& da) { const auto& db = std::get(b->data); - return da.name.compare(db.name); + int c = da.name.compare(db.name); + if(c != 0) + return c; + if(da.min != db.min) + return da.min < db.min ? -1 : 1; + if(da.max != db.max) + return da.max < db.max ? -1 : 1; + return 0; }, [&](const add_data& da) { const auto& db = std::get(b->data); @@ -264,7 +278,11 @@ static expr_ptr make_integer(int64_t n) return make_node(integer_data{n}); } -static expr_ptr make_symbol(const std::string& name) { return make_node(symbol_data{name}); } +static expr_ptr +make_symbol(const std::string& name, int64_t min, int64_t max, std::set optimals = {}) +{ + return make_node(symbol_data{name, min, max, std::move(optimals)}); +} static expr_ptr make_add(const expr_ptr& a, const expr_ptr& b); static expr_ptr make_sub(const expr_ptr& a, const expr_ptr& b); @@ -384,20 +402,15 @@ static expr_ptr build_mul(int64_t coefficient, factor_map factors) static expr_ptr make_mul(const expr_ptr& a, const expr_ptr& b) { - if(holds(a) and holds(b)) + if(holds(b)) { - int64_t n = get_integer(a); - if(n == 0) - return make_integer(0); - if(n == 1) - return b; - const auto& d = get_add(b); - term_map scaled; + const auto& d = get_add(b); + expr_ptr result = make_mul(a, make_integer(d.constant)); for(const auto& [term, coeff] : d.terms) - scaled[term] = coeff * n; - return build_add(d.constant * n, std::move(scaled)); + result = make_add(result, make_mul(a, make_mul(make_integer(coeff), term))); + return result; } - if(holds(b) and holds(a)) + if(holds(a)) return make_mul(b, a); auto pa = extract_mul(a); @@ -566,41 +579,94 @@ static expr_ptr substitute(const expr_ptr& e, const subs_map& bindings) e->data); } -static int64_t eval_direct(const expr_ptr& e, const binding_map& bindings) +template +static int64_t eval_impl(const expr_ptr& e, const SymbolResolver& resolve_sym) { return std::visit(overloaded{[](const integer_data& d) -> int64_t { return d.value; }, - [&](const symbol_data& d) -> int64_t { - auto it = bindings.find(e); - if(it != bindings.end()) - return it->second; - MIGRAPHX_THROW("sym::expr::eval_uint: unbound symbol '" + - d.name + "'"); - }, + [&](const symbol_data& d) -> int64_t { return resolve_sym(e, d); }, [&](const add_data& d) -> int64_t { int64_t sum = d.constant; for(const auto& [term, coeff] : d.terms) - sum += coeff * eval_direct(term, bindings); + sum += coeff * eval_impl(term, resolve_sym); return sum; }, [&](const mul_data& d) -> int64_t { int64_t prod = d.coefficient; for(const auto& [base, exp] : d.factors) { - int64_t val = eval_direct(base, bindings); + int64_t val = eval_impl(base, resolve_sym); for(int64_t i = 0; i < exp; ++i) prod *= val; } return prod; }, [&](const tdiv_data& d) -> int64_t { - auto denom = eval_direct(d.denominator, bindings); + auto denom = eval_impl(d.denominator, resolve_sym); if(denom == 0) - MIGRAPHX_THROW("sym::expr::eval_uint: division by zero"); - return eval_direct(d.numerator, bindings) / denom; + MIGRAPHX_THROW("sym::expr: division by zero during eval"); + return eval_impl(d.numerator, resolve_sym) / denom; }}, e->data); } +static int64_t eval_direct(const expr_ptr& e, const binding_map& bindings) +{ + return eval_impl(e, [&](const expr_ptr& node, const symbol_data& d) -> int64_t { + auto it = bindings.find(node); + if(it != bindings.end()) + return it->second; + MIGRAPHX_THROW("sym::expr::eval_uint: unbound symbol '" + d.name + "'"); + }); +} + +// Walk the expression tree and collect each unique symbol node and its data. +static void collect_symbols(const expr_ptr& e, + std::vector>& result, + std::set& seen) +{ + std::visit(overloaded{[](const integer_data&) {}, + [&](const symbol_data& d) { + if(seen.insert(d.name).second) + result.push_back({e, d}); + }, + [&](const add_data& d) { + for(const auto& [term, coeff] : d.terms) + collect_symbols(term, result, seen); + }, + [&](const mul_data& d) { + for(const auto& [base, exp] : d.factors) + collect_symbols(base, result, seen); + }, + [&](const tdiv_data& d) { + collect_symbols(d.numerator, result, seen); + collect_symbols(d.denominator, result, seen); + }}, + e->data); +} + +// Recursively enumerate all 2^k combinations of symbol {min, max} values, +// evaluating the expression at each and tracking the global min and max. +static void eval_bounds_impl(const expr_ptr& node, + const std::vector>& syms, + std::size_t idx, + binding_map& bindings, + int64_t& lo, + int64_t& hi) +{ + if(idx == syms.size()) + { + auto v = eval_direct(node, bindings); + lo = std::min(lo, v); + hi = std::max(hi, v); + return; + } + const auto& [sym_node, sd] = syms[idx]; + bindings[sym_node] = sd.min; + eval_bounds_impl(node, syms, idx + 1, bindings, lo, hi); + bindings[sym_node] = sd.max; + eval_bounds_impl(node, syms, idx + 1, bindings, lo, hi); +} + // =================================================================== // Section 7: Pretty-printer // =================================================================== @@ -734,7 +800,7 @@ static expr_ptr parse_primary(const char*& p) name += *p; ++p; } - return make_symbol(name); + return make_symbol(name, 1, 1); } if(*p == '(') { @@ -834,6 +900,8 @@ expr::expr(std::shared_ptr pi) : p(std::move(pi)) {} bool expr::empty() const { return p == nullptr; } +bool expr::is_literal() const { return p != nullptr and holds(p->node); } + std::size_t expr::hash() const { if(empty()) @@ -865,6 +933,87 @@ std::size_t expr::eval_uint(const std::unordered_map& symbol_ return v; } +// Compute both the minimum and maximum value of an expression by +// evaluating at all 2^k vertices of the symbol bound ranges. +// +// Assumptions: +// 1. Expressions are monotonic in each variable independently, so the +// global extrema always occur at vertices of the variable ranges. +// 2. Expressions represent dimension sizes or strides: sums, products, +// and integer divisions of positive-valued symbols. Non-monotonic +// expressions (e.g. polynomials with interior extrema) are not +// expected. +// 3. The number of unique symbols per expression is small (typically +// 1-3), making the 2^k evaluation cost negligible. +interval expr::eval_interval() const +{ + if(empty()) + MIGRAPHX_THROW("sym::expr::eval_interval: empty expression"); + std::vector> syms; + std::set seen; + collect_symbols(p->node, syms, seen); + if(syms.empty()) + { + auto v = eval_direct(p->node, {}); + return {v, v}; + } + int64_t lo = INT64_MAX; + int64_t hi = INT64_MIN; + binding_map bindings; + eval_bounds_impl(p->node, syms, 0, bindings, lo, hi); + return {lo, hi}; +} + +// Recursively enumerate the Cartesian product of symbol optimals, +// evaluating the expression at each combination without materializing +// intermediate binding maps. +static void eval_optimals_impl(const expr_ptr& node, + const std::vector>& syms, + std::size_t idx, + binding_map& bindings, + std::set& result) +{ + if(idx == syms.size()) + { + result.insert(eval_direct(node, bindings)); + return; + } + const auto& [sym_node, sd] = syms[idx]; + for(auto oval : sd.optimals) + { + bindings[sym_node] = oval; + eval_optimals_impl(node, syms, idx + 1, bindings, result); + } +} + +// Compute the set of optimal values for the expression by evaluating it +// at every combination of each symbol's optimal values (Cartesian product). +// +// For a single variable: var("n", {1, 8}, {2, 4}) => optimals = {2, 4} +// For a compound expr: 2*n + 1 where n has optimals {2, 4} => {5, 9} +// For multiple variables: n + m where n={2,4}, m={3,6} => {5, 8, 7, 10} +// +// Returns empty if any symbol in the expression has no optimals. +std::set expr::eval_optimals() const +{ + if(empty()) + return {}; + std::vector> syms; + std::set seen; + collect_symbols(p->node, syms, seen); + auto has_optimals = std::all_of( + syms.begin(), syms.end(), [](const auto& s) { return not s.second.optimals.empty(); }); + if(syms.empty() or not has_optimals) + return {}; + + std::set signed_result; + binding_map bindings; + eval_optimals_impl(p->node, syms, 0, bindings, signed_result); + if(std::any_of(signed_result.begin(), signed_result.end(), [](int64_t v) { return v < 0; })) + MIGRAPHX_THROW("sym::expr::eval_optimals: negative optimal value"); + return {signed_result.begin(), signed_result.end()}; +} + expr expr::subs(const std::unordered_map& symbol_map) const { if(empty()) @@ -920,6 +1069,45 @@ bool operator==(const expr& a, const expr& b) bool operator!=(const expr& a, const expr& b) { return not(a == b); } +// Semantic strict less-than for symbolic expressions using interval arithmetic. +// +// Assumptions: +// - All symbols have positive intervals [min, max] where 1 <= min <= max. +// - Expressions are monotonically non-decreasing in each variable, which +// holds for dimension/stride arithmetic (sums and products of positive +// terms). This lets us bound the range of (b - a) by evaluating at the +// interval endpoints. +// +// Algorithm: +// Compute diff = b - a, then evaluate diff at the lower and upper bounds +// of every symbol to obtain [lo, hi]. If the entire interval is strictly +// positive (lo > 0) then a < b for all possible symbol values. If the +// interval is non-positive (hi <= 0) then a >= b. Otherwise the comparison +// is undetermined and we throw. +// +// Examples (all symbols default to [1, 1]): +// n < 2*n => diff = n, lo = 1, hi = 1 => true (strictly positive) +// 2*n < n => diff = -n, lo = -1, hi = -1 => false (non-positive) +// k < m*k => diff = k(m-1), lo = 0, hi = 0 => false (not strictly positive) +// +// With explicit bounds, e.g. n in [2, 10]: +// n < 3 => diff = 3 - n, lo = -7, hi = 1 => undetermined (throws) +// n < 11 => diff = 11 - n, lo = 1, hi = 9 => true +bool operator<(const expr& a, const expr& b) +{ + if(a.empty() and b.empty()) + return false; + if(a.empty() or b.empty()) + MIGRAPHX_THROW("sym::expr: cannot compare empty expression"); + auto ival = (b - a).eval_interval(); + if(to(ival.min) > 0) + return true; + if(to(ival.max) <= 0) + return false; + MIGRAPHX_THROW("sym::expr: comparison undetermined for: " + print_expr(a.p->node) + " < " + + print_expr(b.p->node)); +} + std::ostream& operator<<(std::ostream& os, const expr& e) { if(not e.empty()) @@ -927,11 +1115,17 @@ std::ostream& operator<<(std::ostream& os, const expr& e) return os; } -expr var(const std::string& name) +expr var(const std::string& name, interval bounds, std::set optimals) { if(name.empty()) MIGRAPHX_THROW("sym::var: variable name must not be empty"); - return {std::make_shared(make_symbol(name))}; + auto bmin = to(bounds.min); + auto bmax = to(bounds.max); + if(bmin > bmax) + MIGRAPHX_THROW("sym::var: variable interval must satisfy min <= max"); + if(bmin < 1) + MIGRAPHX_THROW("sym::var: variable interval must satisfy min >= 1"); + return {std::make_shared(make_symbol(name, bmin, bmax, std::move(optimals)))}; } expr lit(int64_t n) { return {std::make_shared(make_integer(n))}; } @@ -955,6 +1149,17 @@ static value node_to_value(const expr_ptr& e) value r; r["type"] = "sym"; r["name"] = d.name; + r["min"] = d.min; + r["max"] = d.max; + if(not d.optimals.empty()) + { + value opts = value::array{}; + std::transform(d.optimals.begin(), + d.optimals.end(), + std::back_inserter(opts), + [](auto o) -> value { return o; }); + r["optimals"] = opts; + } return r; }, [](const add_data& d) -> value { @@ -1002,32 +1207,40 @@ static expr_ptr node_from_value(const value& v) const auto& type = v.at("type").get_string(); if(type == "int") { - return make_integer(v.at("value").get_int64()); + return make_integer(v.at("value").to()); } else if(type == "sym") { - return make_symbol(v.at("name").get_string()); + auto sym_min = v.contains("min") ? v.at("min").to() : int64_t{1}; + auto sym_max = v.contains("max") ? v.at("max").to() : int64_t{1}; + std::set sym_opts; + if(v.contains("optimals")) + std::transform(v.at("optimals").begin(), + v.at("optimals").end(), + std::inserter(sym_opts, sym_opts.end()), + [](const auto& o) { return o.template to(); }); + return make_symbol(v.at("name").get_string(), sym_min, sym_max, std::move(sym_opts)); } else if(type == "add") { - auto constant = v.at("constant").get_int64(); + auto constant = v.at("constant").to(); term_map terms; for(const auto& t : v.at("terms")) { auto term = node_from_value(t.at("expr")); - auto coeff = t.at("coeff").get_int64(); + auto coeff = t.at("coeff").to(); terms[term] = coeff; } return build_add(constant, std::move(terms)); } else if(type == "mul") { - auto coefficient = v.at("coeff").get_int64(); + auto coefficient = v.at("coeff").to(); factor_map factors; for(const auto& f : v.at("factors")) { auto base = node_from_value(f.at("expr")); - auto exp = f.at("exp").get_int64(); + auto exp = f.at("exp").to(); factors[base] = exp; } return build_mul(coefficient, std::move(factors)); diff --git a/src/targets/gpu/gemm_impl.cpp b/src/targets/gpu/gemm_impl.cpp index 381e68dfabe..a2d39c8a656 100644 --- a/src/targets/gpu/gemm_impl.cpp +++ b/src/targets/gpu/gemm_impl.cpp @@ -94,9 +94,9 @@ void blas_shape(const shape& in_shape) MIGRAPHX_THROW("GPU_GEMM: matrix dimensions can't be broadcasted"); if(s.lens().size() < 3) return; - shape batch_shape{s.type(), - {s.lens().begin(), s.lens().end() - 2}, - {s.strides().begin(), s.strides().end() - 2}}; + shape batch_shape(s.type(), + std::vector(s.lens().begin(), s.lens().end() - 2), + std::vector(s.strides().begin(), s.strides().end() - 2)); auto batch_shapes = reduce_dims({batch_shape}); if(batch_shapes.front().lens().size() != 1) MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible"); diff --git a/src/targets/gpu/hip_gemm_impl.cpp b/src/targets/gpu/hip_gemm_impl.cpp index c3766e1cdf5..dfff40dadab 100644 --- a/src/targets/gpu/hip_gemm_impl.cpp +++ b/src/targets/gpu/hip_gemm_impl.cpp @@ -90,9 +90,9 @@ void blas_shape_hip(const shape& in_shape) MIGRAPHX_THROW("GPU_GEMM: matrix dimensions can't be broadcasted"); if(s.lens().size() < 3) return; - shape batch_shape{s.type(), - {s.lens().begin(), s.lens().end() - 2}, - {s.strides().begin(), s.strides().end() - 2}}; + shape batch_shape(s.type(), + std::vector(s.lens().begin(), s.lens().end() - 2), + std::vector(s.strides().begin(), s.strides().end() - 2)); auto batch_shapes = reduce_dims({batch_shape}); if(batch_shapes.front().lens().size() != 1) MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible"); diff --git a/test/eliminate_concat_test.cpp b/test/eliminate_concat_test.cpp index ccc6aa2fe5d..773c82619f9 100644 --- a/test/eliminate_concat_test.cpp +++ b/test/eliminate_concat_test.cpp @@ -203,7 +203,7 @@ static migraphx::shape create_shape(Ts... xs) return migraphx::shape{migraphx::shape::float_type, {std::size_t(xs)...}}; else return migraphx::shape::from_permutation( - migraphx::shape::float_type, {std::size_t(xs)...}, {Is...}); + migraphx::shape::float_type, std::vector{std::size_t(xs)...}, {Is...}); } template diff --git a/test/serialize_program.cpp b/test/serialize_program.cpp index 8ba5d216f93..6e6121e8697 100644 --- a/test/serialize_program.cpp +++ b/test/serialize_program.cpp @@ -26,6 +26,7 @@ #include #include "test.hpp" #include +#include #include #include @@ -140,6 +141,27 @@ TEST_CASE(program_with_module) EXPECT(p1.sort() == p2.sort()); } +TEST_CASE(symbolic_shape_msgpack_roundtrip) +{ + using migraphx::shape; + using dd = shape::dynamic_dimension; + using migraphx::sym::lit; + auto n = migraphx::sym::var("n", {1, 8}); + + migraphx::program p; + auto* mm = p.get_main_module(); + shape s{shape::float_type, {dd{n}, dd{lit(3)}, dd{lit(4)}}}; + auto x = mm->add_parameter("x", s); + auto r = mm->add_instruction(migraphx::make_op("relu"), x); + mm->add_return({r}); + + migraphx::file_options options; + options.format = "msgpack"; + std::vector buffer = migraphx::save_buffer(p, options); + migraphx::program p2 = migraphx::load_buffer(buffer, options); + EXPECT(p.sort() == p2.sort()); +} + static migraphx::program create_program_with_debug_symbols() { migraphx::program p; diff --git a/test/shape_test.cpp b/test/shape_test.cpp index 386c9058aeb..fd67fe8aae4 100644 --- a/test/shape_test.cpp +++ b/test/shape_test.cpp @@ -23,6 +23,7 @@ */ #include +#include #include #include #include @@ -33,6 +34,10 @@ #include #include "test.hpp" +using dd = migraphx::shape::dynamic_dimension; +using migraphx::sym::lit; +using migraphx::sym::var; + TEST_CASE(test_shape_default) { migraphx::shape s{}; @@ -441,7 +446,9 @@ TEST_CASE(test_shape_static_to_dynamic) { migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4, 4}}; migraphx::shape s1 = s0.to_dynamic(); - migraphx::shape s2{migraphx::shape::float_type, {{1, 1}, {2, 2}, {4, 4}, {4, 4}}}; + migraphx::shape s2{migraphx::shape::float_type, + {{1, 1}, {2, 2}, {4, 4}, {4, 4}}, + {lit(32), lit(16), lit(4), lit(1)}}; EXPECT(s1 == s2); } @@ -461,7 +468,8 @@ TEST_CASE(test_shape_subshapes_to_dynamic) migraphx::shape s1 = s0.to_dynamic(); std::vector sub_shapes1 = {}; sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}}); - sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, {{3, 3}, {4, 4}, {5, 5}}}); + sub_shapes1.push_back(migraphx::shape{ + migraphx::shape::float_type, {{3, 3}, {4, 4}, {5, 5}}, {lit(20), lit(5), lit(1)}}); migraphx::shape s2{sub_shapes1}; EXPECT(s1 == s2); } @@ -1272,4 +1280,916 @@ TEST_CASE(shape_same_lens_static_dynamic) EXPECT(not migraphx::shape::same_lens(s1, s3)); } +// =================================================================== +// Symbolic dynamic_dimension tests +// =================================================================== + +TEST_CASE(test_dd_symbolic_add_size_t) +{ + auto n = var("n", {1, 8}); + dd d{n}; + d += 2; + EXPECT(d.get_interval().min == 3); + EXPECT(d.get_interval().max == 10); + EXPECT(d.sym_expr == n + 2); +} + +TEST_CASE(test_dd_symbolic_sub_size_t) +{ + auto n = var("n", {3, 8}); + dd d{n}; + d -= 1; + EXPECT(d.get_interval().min == 2); + EXPECT(d.get_interval().max == 7); + EXPECT(d.sym_expr == n - 1); +} + +TEST_CASE(test_dd_symbolic_mul_size_t) +{ + auto n = var("n", {1, 8}); + dd d{n}; + d *= 3; + EXPECT(d.get_interval().min == 3); + EXPECT(d.get_interval().max == 24); + EXPECT(d.sym_expr == n * 3); +} + +TEST_CASE(test_dd_symbolic_div_size_t) +{ + auto n = var("n", {4, 16}); + dd d{n}; + d /= 2; + EXPECT(d.get_interval().min == 2); + EXPECT(d.get_interval().max == 8); + EXPECT(d.sym_expr == n / 2); +} + +TEST_CASE(test_dd_symbolic_add_dd) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {2, 4}); + auto r = dd{n} + dd{c}; + EXPECT(r.get_interval().min == 3); + EXPECT(r.get_interval().max == 12); + EXPECT(r.sym_expr == n + c); +} + +TEST_CASE(test_dd_symbolic_sub_dd) +{ + auto n = var("n", {4, 16}); + auto k = var("k", {1, 4}); + auto r = dd{n} - dd{k}; + EXPECT(r.get_interval().min == 0); + EXPECT(r.get_interval().max == 15); + EXPECT(r.sym_expr == n - k); +} + +TEST_CASE(test_dd_symbolic_mul_dd) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {2, 4}); + auto r = dd{n} * dd{c}; + EXPECT(r.get_interval().min == 2); + EXPECT(r.get_interval().max == 32); + EXPECT(r.sym_expr == n * c); +} + +TEST_CASE(test_dd_symbolic_div_dd) +{ + auto n = var("n", {4, 16}); + auto k = var("k", {2, 4}); + auto r = dd{n} / dd{k}; + EXPECT(r.get_interval().min == 1); + EXPECT(r.get_interval().max == 8); + EXPECT(r.sym_expr == n / k); +} + +TEST_CASE(test_dd_symbolic_plus_range_fixed) +{ + auto n = var("n", {1, 8}); + dd a{n}; + dd b{3, 3}; + auto r = a + b; + EXPECT(r.sym_expr.empty()); + EXPECT(r.get_interval().min == 4); + EXPECT(r.get_interval().max == 11); +} + +TEST_CASE(test_dd_nonfixed_nonsymbolic_plus_symbolic_drops_sym) +{ + auto c = var("c", {2, 4}); + dd a{1, 8}; + dd b{c}; + auto r = a + b; + EXPECT(r.sym_expr.empty()); + EXPECT(r.get_interval().min == 3); + EXPECT(r.get_interval().max == 12); +} + +TEST_CASE(test_dd_nonsymbolic_remains_nonsymbolic) +{ + dd a{1, 8}; + dd b{2, 4}; + auto r = a + b; + EXPECT(r.sym_expr.empty()); +} + +TEST_CASE(test_dd_equality_with_sym) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 8}); + dd a{n}; + dd b{n}; + dd d2{c}; + dd d{1, 8}; + EXPECT(a == b); + EXPECT(a != d2); + EXPECT(a != d); +} + +TEST_CASE(test_symbolic_shape_construction) +{ + auto n = var("n", {1, 8}); + migraphx::shape sh{migraphx::shape::float_type, + {dd{n}, dd{lit(3)}, dd{lit(224)}}, + {n * 3 * 224, lit(224), lit(1)}}; + EXPECT(sh.dynamic()); + EXPECT(sh.symbolic()); + EXPECT(sh.dyn_dims().size() == 3); + EXPECT(sh.dyn_strides().size() == 3); +} + +TEST_CASE(test_symbolic_stride_auto_compute) +{ + auto n = var("n", {1, 8}); + auto s = var("s", {1, 16}); + migraphx::shape sh{migraphx::shape::float_type, {dd{n}, dd{s}, dd{lit(4)}}}; + EXPECT(sh.symbolic()); + EXPECT(sh.dyn_strides().size() == 3); + EXPECT(sh.dyn_strides()[2] == lit(1)); + EXPECT(sh.dyn_strides()[1] == lit(4)); + EXPECT(sh.dyn_strides()[0] == s * 4); +} + +TEST_CASE(test_symbolic_to_static) +{ + auto n = var("n", {1, 8}); + auto s = var("s", {1, 16}); + migraphx::shape sh{migraphx::shape::float_type, {dd{n}, dd{s}, dd{lit(4)}}}; + std::unordered_map symbol_map = {{n, 2}, {s, 8}}; + auto s_static = sh.to_static(symbol_map); + EXPECT(not s_static.dynamic()); + EXPECT(s_static.lens() == std::vector{2, 8, 4}); + EXPECT(s_static.strides() == std::vector{32, 4, 1}); +} + +TEST_CASE(test_symbolic_shape_serialize) +{ + auto n = var("n", {1, 8}); + auto s = var("s", {1, 16}); + migraphx::shape s1{migraphx::shape::float_type, {dd{n}, dd{s}, dd{lit(4)}}}; + auto v = migraphx::to_value(s1); + auto s2 = migraphx::from_value(v); + EXPECT(s1 == s2); + EXPECT(s2.symbolic()); + EXPECT(s2.dyn_strides().size() == 3); + EXPECT(s2.dyn_strides()[0] == s * 4); + EXPECT(s2.dyn_strides()[2] == lit(1)); +} + +TEST_CASE(test_symbolic_shape_equality) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 8}); + migraphx::shape s1{migraphx::shape::float_type, {dd{n}, dd{lit(3)}}}; + migraphx::shape s2{migraphx::shape::float_type, {dd{n}, dd{lit(3)}}}; + migraphx::shape s3{migraphx::shape::float_type, {dd{c}, dd{lit(3)}}}; + EXPECT(s1 == s2); + EXPECT(s1 != s3); +} + +TEST_CASE(test_symbolic_shape_print) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 8}); + auto to_str = [](const migraphx::shape& sh) { + std::stringstream ss; + ss << sh; + return ss.str(); + }; + migraphx::shape s1{migraphx::shape::float_type, {dd{n}, dd{lit(3)}, dd{lit(4)}}}; + migraphx::shape s2{migraphx::shape::float_type, {dd{n}, dd{lit(3)}, dd{lit(4)}}}; + migraphx::shape s3{migraphx::shape::float_type, {dd{c}, dd{lit(3)}, dd{lit(4)}}}; + EXPECT(to_str(s1) == to_str(s2)); + EXPECT(to_str(s1) != to_str(s3)); +} + +TEST_CASE(dd_intersection_symbolic_with_range) +{ + auto n = var("n", {1, 32}); + dd a{n}; + dd b{2, 6}; + auto result = a.intersection(b); + EXPECT(result.has_value()); + EXPECT(result->get_interval().min == 2); + EXPECT(result->get_interval().max == 6); + EXPECT(result->sym_expr.empty()); +} + +TEST_CASE(dd_intersection_symbolic_same_symbol) +{ + auto n = var("n", {1, 32}); + dd a{n}; + dd b{n}; + auto result = a.intersection(b); + EXPECT(result.has_value()); + EXPECT(*result == a); +} + +TEST_CASE(dd_intersection_symbolic_different_symbol) +{ + auto n = var("n", {1, 32}); + auto m = var("m", {1, 16}); + dd a{n}; + dd b{m}; + auto result = a.intersection(b); + EXPECT(not result.has_value()); +} + +// ------------------------------------------------------------------- +// Symbolic shapes: packed/standard/transposed/broadcasted/scalar +// ------------------------------------------------------------------- + +TEST_CASE(test_symbolic_packed_default) +{ + auto n = var("n", {1, 8}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{lit(3)}}}; + EXPECT(s.symbolic()); + EXPECT(s.standard()); + EXPECT(s.packed()); + EXPECT(not s.transposed()); + EXPECT(not s.broadcasted()); +} + +TEST_CASE(test_symbolic_standard) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}}}; + EXPECT(s.symbolic()); + EXPECT(s.standard()); + EXPECT(s.packed()); + EXPECT(not s.transposed()); + EXPECT(not s.broadcasted()); +} + +TEST_CASE(test_symbolic_standard_singleton_dim) +{ + auto n = var("n", {1, 8}); + migraphx::shape s{ + migraphx::shape::float_type, {dd{n}, dd{lit(1)}, dd{lit(8)}}, {lit(8), lit(4), lit(1)}}; + EXPECT(s.symbolic()); + EXPECT(s.standard()); + EXPECT(s.packed()); + EXPECT(not s.transposed()); + EXPECT(not s.broadcasted()); +} + +TEST_CASE(test_shape_ndim_symbolic) +{ + auto n = var("n", {1, 8}); + migraphx::shape s0{migraphx::shape::float_type, {dd{n}, dd{lit(3)}}}; + EXPECT(s0.ndim() == 2); + + auto c = var("c", {1, 16}); + migraphx::shape s1{migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}, dd{lit(4)}}}; + EXPECT(s1.ndim() == 4); +} + +TEST_CASE(test_symbolic_transposed) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}}, {lit(1), n, n * c}}; + EXPECT(s.symbolic()); + EXPECT(not s.standard()); + EXPECT(s.packed()); + EXPECT(s.transposed()); + EXPECT(not s.broadcasted()); +} + +TEST_CASE(test_symbolic_to_dynamic_identity) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}}}; + auto s2 = s.to_dynamic(); + EXPECT(s == s2); +} + +TEST_CASE(test_symbolic_overlap) +{ + auto n = var("n", {1, 8}); + migraphx::shape s{ + migraphx::shape::float_type, {dd{n}, dd{lit(3)}, dd{lit(4)}}, {lit(6), lit(3), lit(2)}}; + EXPECT(s.symbolic()); + EXPECT(not s.standard()); + EXPECT(not s.packed()); + EXPECT(not s.transposed()); + EXPECT(not s.broadcasted()); +} + +TEST_CASE(test_symbolic_scalar) +{ + migraphx::shape s{migraphx::shape::float_type, {dd{lit(1)}}, {lit(0)}}; + EXPECT(s.symbolic()); + EXPECT(s.standard()); + EXPECT(s.packed()); + EXPECT(not s.transposed()); + EXPECT(s.broadcasted()); +} + +TEST_CASE(test_symbolic_scalar_broadcast) +{ + auto n = var("n", {1, 8}); + migraphx::shape s{ + migraphx::shape::float_type, {dd{n}, dd{lit(3)}, dd{lit(4)}}, {lit(0), lit(0), lit(0)}}; + EXPECT(s.symbolic()); + EXPECT(not s.standard()); + EXPECT(not s.packed()); + EXPECT(not s.transposed()); + EXPECT(s.broadcasted()); +} + +TEST_CASE(test_symbolic_broadcasted) +{ + auto n = var("n", {1, 8}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{lit(3)}}, {lit(1), lit(0)}}; + EXPECT(s.symbolic()); + EXPECT(not s.standard()); + EXPECT(not s.packed()); + EXPECT(not s.transposed()); + EXPECT(s.broadcasted()); +} + +TEST_CASE(test_symbolic_broadcasted2) +{ + auto c = var("c", {1, 16}); + migraphx::shape s{migraphx::shape::float_type, {dd{lit(1)}, dd{c}}, {lit(0), lit(1)}}; + EXPECT(s.symbolic()); + EXPECT(s.standard()); + EXPECT(s.packed()); + EXPECT(not s.transposed()); + EXPECT(s.broadcasted()); +} + +TEST_CASE(test_symbolic_broadcasted3) +{ + auto n = var("n", {1, 8}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{lit(3)}}, {lit(0), lit(1)}}; + EXPECT(s.symbolic()); + EXPECT(not s.standard()); + EXPECT(not s.packed()); + EXPECT(not s.transposed()); + EXPECT(s.broadcasted()); +} + +TEST_CASE(test_symbolic_broadcasted4) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + migraphx::shape s{ + migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}}, {c * lit(4), lit(0), lit(1)}}; + EXPECT(s.symbolic()); + EXPECT(not s.standard()); + EXPECT(not s.packed()); + EXPECT(not s.transposed()); + EXPECT(s.broadcasted()); +} + +TEST_CASE(test_symbolic_broadcasted5) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + migraphx::shape s{ + migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}}, {lit(1), lit(0), n * c}}; + EXPECT(s.symbolic()); + EXPECT(not s.standard()); + EXPECT(not s.packed()); + EXPECT(s.transposed()); + EXPECT(s.broadcasted()); +} + +TEST_CASE(test_symbolic_step_broadcasted) +{ + auto n = var("n", {1, 8}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{lit(3)}}, {lit(0), n}}; + EXPECT(s.symbolic()); + EXPECT(not s.standard()); + EXPECT(not s.packed()); + EXPECT(not s.transposed()); + EXPECT(s.broadcasted()); +} + +TEST_CASE(test_symbolic_normalize_standard) +{ + auto n = var("n", {1, 4}); + auto c = var("c", {1, 64}); + migraphx::shape s{migraphx::shape::float_type, + {dd{n}, dd{c}, dd{lit(35)}, dd{lit(35)}}, + {c * 1225, lit(1225), lit(35), lit(1)}}; + EXPECT(s.standard()); + auto ns = s.normalize_standard(); + EXPECT(ns.standard()); + EXPECT(ns.symbolic()); + EXPECT(ns.dyn_dims() == s.dyn_dims()); + EXPECT(ns.type() == s.type()); +} + +TEST_CASE(test_symbolic_normalize_standard_transposed) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + migraphx::shape s{ + migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}}, {lit(1), lit(4), c * 4}}; + EXPECT(not s.standard()); + EXPECT(s.transposed()); + auto ns = s.normalize_standard(); + EXPECT(ns == s); +} + +// ------------------------------------------------------------------- +// Symbolic with_lens / from_permutation / find_permutation +// ------------------------------------------------------------------- + +TEST_CASE(test_symbolic_with_lens_standard) +{ + auto n = var("n", {1, 8}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{lit(3)}}}; + auto b = var("b", {1, 16}); + std::vector
new_dims = {dd{b}, dd{lit(4)}}; + auto s2 = s.with_lens(new_dims); + EXPECT(s2.symbolic()); + EXPECT(s2.standard()); + EXPECT(not s2.transposed()); + EXPECT(s2.dyn_dims() == new_dims); +} + +TEST_CASE(test_symbolic_with_lens_transposed) +{ + auto n = var("n", {1, 8}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{lit(3)}}, {lit(1), n}}; + EXPECT(s.transposed()); + auto b = var("b", {1, 16}); + std::vector
new_dims = {dd{b}, dd{lit(4)}}; + auto s2 = s.with_lens(new_dims); + EXPECT(s2.symbolic()); + EXPECT(s2.transposed()); + EXPECT(s2.dyn_dims() == new_dims); +} + +TEST_CASE(test_symbolic_with_lens_4d) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}, dd{lit(4)}}}; + auto b = var("b", {1, 32}); + auto ch = var("ch", {1, 64}); + std::vector
new_dims = {dd{b}, dd{ch}, dd{lit(8)}, dd{lit(8)}}; + auto s2 = s.with_lens(new_dims); + EXPECT(s2.symbolic()); + EXPECT(s2.standard()); + EXPECT(not s2.transposed()); + EXPECT(s2.dyn_dims() == new_dims); +} + +TEST_CASE(test_symbolic_with_lens_ambiguous_singleton_nchw) +{ + auto n = var("n", {1, 64}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{lit(1)}, dd{lit(24)}, dd{lit(24)}}}; + auto c = var("c", {1, 16}); + std::vector
new_dims = {dd{n}, dd{c}, dd{lit(24)}, dd{lit(24)}}; + auto s2 = s.with_lens(new_dims); + EXPECT(s2.symbolic()); + EXPECT(not s2.transposed()); + EXPECT(s2.dyn_dims() == new_dims); +} + +TEST_CASE(test_symbolic_with_lens_ambiguous_singleton_nhwc) +{ + auto n = var("n", {1, 64}); + auto s1 = migraphx::reorder_shape( + migraphx::shape{migraphx::shape::float_type, {dd{n}, dd{lit(24)}, dd{lit(24)}, dd{lit(1)}}}, + {0, 3, 1, 2}); + EXPECT(s1.transposed()); + auto c = var("c", {1, 16}); + std::vector
new_dims = {dd{n}, dd{c}, dd{lit(24)}, dd{lit(24)}}; + auto s2 = s1.with_lens(new_dims); + EXPECT(s2.symbolic()); + EXPECT(s2.transposed()); + EXPECT(s2.dyn_dims() == new_dims); +} + +TEST_CASE(test_symbolic_with_lens_ambiguous_all_singleton) +{ + auto n = var("n", {1, 64}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{lit(1)}, dd{lit(1)}, dd{lit(1)}}}; + auto c = var("c", {1, 16}); + std::vector
new_dims = {dd{n}, dd{c}, dd{lit(24)}, dd{lit(24)}}; + auto s2 = s.with_lens(new_dims); + EXPECT(s2.symbolic()); + EXPECT(s2.standard()); + EXPECT(s2.dyn_dims() == new_dims); +} + +TEST_CASE(test_symbolic_with_lens_ambiguous_nhwc_all_singleton) +{ + auto n = var("n", {1, 64}); + auto s1 = migraphx::reorder_shape( + migraphx::shape{migraphx::shape::float_type, {dd{n}, dd{lit(1)}, dd{lit(1)}, dd{lit(3)}}}, + {0, 3, 1, 2}); + auto c = var("c", {1, 16}); + std::vector
new_dims = {dd{n}, dd{c}, dd{lit(24)}, dd{lit(24)}}; + auto s2 = s1.with_lens(new_dims); + EXPECT(s2.symbolic()); + EXPECT(s2.transposed()); + EXPECT(s2.dyn_dims() == new_dims); +} + +TEST_CASE(find_permutation_symbolic_2d_standard) +{ + auto n = var("n", {1, 8}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{lit(3)}}}; + std::vector permutation = {0, 1}; + EXPECT(migraphx::find_permutation(s) == permutation); +} + +TEST_CASE(find_permutation_symbolic_2d_transpose) +{ + auto n = var("n", {1, 8}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{lit(3)}}, {lit(1), n}}; + std::vector permutation = {1, 0}; + EXPECT(migraphx::find_permutation(s) == permutation); +} + +TEST_CASE(find_permutation_symbolic_3d) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + auto h = var("h", {2, 32}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{c}, dd{h}}, {lit(1), c * h, n}}; + std::vector permutation = {1, 2, 0}; + EXPECT(migraphx::find_permutation(s) == permutation); +} + +TEST_CASE(from_symbolic_2d_permutation) +{ + auto n = var("n", {1, 8}); + std::vector
out_dims = {dd{n}, dd{lit(3)}}; + std::vector permutation = {1, 0}; + migraphx::shape out_shape = + migraphx::shape::from_permutation(migraphx::shape::float_type, out_dims, permutation); + EXPECT(out_shape.dyn_dims() == out_dims); + EXPECT(migraphx::find_permutation(out_shape) == permutation); +} + +TEST_CASE(from_symbolic_3d_permutation) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + std::vector
out_dims = {dd{n}, dd{c}, dd{lit(4)}}; + std::vector permutation = {1, 2, 0}; + migraphx::shape out_shape = + migraphx::shape::from_permutation(migraphx::shape::float_type, out_dims, permutation); + EXPECT(out_shape.dyn_dims() == out_dims); + EXPECT(migraphx::find_permutation(out_shape) == permutation); +} + +TEST_CASE(from_symbolic_4d_permutation) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 64}); + auto h = var("h", {2, 32}); + auto w = var("w", {2, 32}); + std::vector
out_dims = {dd{n}, dd{c}, dd{h}, dd{w}}; + std::vector permutation = {3, 2, 0, 1}; + migraphx::shape out_shape = + migraphx::shape::from_permutation(migraphx::shape::float_type, out_dims, permutation); + EXPECT(out_shape.dyn_dims() == out_dims); + EXPECT(migraphx::find_permutation(out_shape) == permutation); +} + +TEST_CASE(reorder_shape_symbolic) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}}}; + std::vector perm = {2, 0, 1}; + auto reordered = migraphx::reorder_shape(s, perm); + EXPECT(reordered.symbolic()); + EXPECT(reordered.dyn_dims().size() == s.dyn_dims().size()); +} + +TEST_CASE(find_permutation_symbolic_stride_ordering_reversal) +{ + auto a = var("a", {1, 16}); + auto b = var("b", {1, 4}); + auto c = var("c", {1, 8}); + // a/b has interval [0, 16], c has interval [1, 8]. + // At max: 16 > 8 (a/b sorted first), at min: 0 < 1 (reversal). + migraphx::shape s{migraphx::shape::float_type, {dd{a}, dd{c}}, {a / b, c}}; + EXPECT(test::throws([&] { migraphx::find_permutation(s); })); +} + +TEST_CASE(test_symbolic_elements_via_to_static) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}}}; + std::unordered_map symbol_map = {{n, 2}, {c, 8}}; + auto ss = s.to_static(symbol_map); + EXPECT(ss.elements() == 2 * 8 * 4); + EXPECT(ss.strides() == std::vector{32, 4, 1}); +} + +// ------------------------------------------------------------------- +// Dynamic dimension: div, add/sub/mul/div with two dd's +// ------------------------------------------------------------------- + +TEST_CASE(dynamic_dimension_div_fixed) +{ + dd a{10, 30, {12, 24}}; + a /= 3; + EXPECT(a.get_interval().min == 3); + EXPECT(a.get_interval().max == 10); + EXPECT(a.get_optimals() == std::set{4, 8}); +} + +TEST_CASE(dynamic_dimension_add_dd) +{ + dd a{2, 8, {4, 6}}; + dd b{3, 5, {3, 5}}; + auto r = a + b; + EXPECT(r.get_interval().min == 5); + EXPECT(r.get_interval().max == 13); + EXPECT(r.get_optimals().empty()); + EXPECT(r.sym_expr.empty()); +} + +TEST_CASE(dynamic_dimension_sub_dd) +{ + dd a{10, 30, {15, 25}}; + dd b{2, 5, {3}}; + auto r = a - b; + EXPECT(r.get_interval().min == 5); + EXPECT(r.get_interval().max == 28); + EXPECT(r.get_optimals().empty()); + EXPECT(r.sym_expr.empty()); +} + +TEST_CASE(dynamic_dimension_mul_dd) +{ + dd a{2, 8, {4}}; + dd b{3, 5, {3, 5}}; + auto r = a * b; + EXPECT(r.get_interval().min == 6); + EXPECT(r.get_interval().max == 40); + EXPECT(r.get_optimals().empty()); + EXPECT(r.sym_expr.empty()); +} + +TEST_CASE(dynamic_dimension_div_dd) +{ + dd a{10, 40, {20, 30}}; + dd b{2, 5, {2, 4}}; + auto r = a / b; + EXPECT(r.get_interval().min == 2); + EXPECT(r.get_interval().max == 20); + EXPECT(r.get_optimals().empty()); + EXPECT(r.sym_expr.empty()); +} + +TEST_CASE(dynamic_dimension_sub_clamp_zero) +{ + dd a{2, 5}; + dd b{4, 8}; + auto r = a - b; + EXPECT(r.get_interval().min == 0); + EXPECT(r.get_interval().max == 1); +} + +TEST_CASE(dynamic_dimension_add_one_fixed) +{ + dd a{4, 4, {4}}; + dd b{2, 8, {3, 6}}; + auto r = a + b; + EXPECT(r.get_interval().min == 6); + EXPECT(r.get_interval().max == 12); + EXPECT(r.get_optimals() == std::set({7, 10})); + EXPECT(r.sym_expr.empty()); +} + +TEST_CASE(dynamic_dimension_mul_one_fixed) +{ + dd a{3, 3}; + dd b{2, 8, {4, 6}}; + auto r = a * b; + EXPECT(r.get_interval().min == 6); + EXPECT(r.get_interval().max == 24); + EXPECT(r.get_optimals() == std::set({12, 18})); + EXPECT(r.sym_expr.empty()); +} + +// ------------------------------------------------------------------- +// Dynamic dimension: symbolic construction and arithmetic +// ------------------------------------------------------------------- + +TEST_CASE(test_dd_from_empty_expr_throws) +{ + migraphx::sym::expr empty_expr; + EXPECT(test::throws([&] { dd{empty_expr}; })); +} + +TEST_CASE(test_dd_accessors_range_based) +{ + dd a{3, 10, {4, 7}}; + auto iv = a.get_interval(); + EXPECT(iv.min == 3); + EXPECT(iv.max == 10); + EXPECT(a.get_optimals() == std::set({4, 7})); +} + +TEST_CASE(test_dd_accessors_symbolic) +{ + auto n = var("n", {2, 16}, {4, 8}); + dd d{n}; + EXPECT(d.get_interval().min == 2); + EXPECT(d.get_interval().max == 16); + EXPECT(d.get_optimals() == std::set({4, 8})); +} + +TEST_CASE(test_dd_symbolic_no_optimals) +{ + auto n = var("n", {3, 12}); + dd d{n}; + EXPECT(d.get_interval().min == 3); + EXPECT(d.get_interval().max == 12); + EXPECT(d.get_optimals().empty()); +} + +TEST_CASE(test_dd_symbolic_add_dd_optimals) +{ + auto h = var("h", {5, 20}, {10, 15}); + auto w = var("w", {5, 20}, {10, 15}); + auto r = dd{h} + dd{w}; + EXPECT(r.sym_expr == h + w); + EXPECT(r.get_interval().min == 10); + EXPECT(r.get_interval().max == 40); + EXPECT(r.get_optimals() == std::set({20, 25, 30})); +} + +TEST_CASE(test_dd_symbolic_sub_dd_optimals) +{ + auto n = var("n", {10, 50}, {20, 30}); + auto k = var("k", {1, 5}, {2, 4}); + auto r = dd{n} - dd{k}; + EXPECT(r.sym_expr == n - k); + EXPECT(r.get_interval().min == 5); + EXPECT(r.get_interval().max == 49); + EXPECT(r.get_optimals() == std::set({16, 18, 26, 28})); +} + +TEST_CASE(test_dd_symbolic_mul_dd_optimals) +{ + auto n = var("n", {1, 8}, {2, 4}); + auto c = var("c", {1, 4}, {2, 3}); + auto r = dd{n} * dd{c}; + EXPECT(r.sym_expr == n * c); + EXPECT(r.get_interval().min == 1); + EXPECT(r.get_interval().max == 32); + EXPECT(r.get_optimals() == std::set({4, 6, 8, 12})); +} + +TEST_CASE(test_dd_symbolic_div_dd_optimals) +{ + auto n = var("n", {10, 50}, {20, 40}); + auto k = var("k", {2, 5}, {2, 5}); + auto r = dd{n} / dd{k}; + EXPECT(r.sym_expr == n / k); + EXPECT(r.get_interval().min == 2); + EXPECT(r.get_interval().max == 25); + EXPECT(r.get_optimals() == std::set({4, 8, 10, 20})); +} + +TEST_CASE(test_dd_symbolic_add_size_t_optimals) +{ + auto n = var("n", {1, 8}, {4, 6}); + dd d{n}; + d += 2; + EXPECT(d.sym_expr == n + 2); + EXPECT(d.get_interval().min == 3); + EXPECT(d.get_interval().max == 10); + EXPECT(d.get_optimals() == std::set({6, 8})); +} + +TEST_CASE(test_dd_symbolic_mul_size_t_optimals) +{ + auto n = var("n", {1, 8}, {2, 4}); + dd d{n}; + d *= 3; + EXPECT(d.sym_expr == n * 3); + EXPECT(d.get_interval().min == 3); + EXPECT(d.get_interval().max == 24); + EXPECT(d.get_optimals() == std::set({6, 12})); +} + +TEST_CASE(test_dd_symbolic_chained_arithmetic_optimals) +{ + auto h = var("h", {10, 50}, {20, 30}); + dd d{h}; + d -= 3; + d /= 2; + d += 1; + EXPECT(d.sym_expr == (h - 3) / 2 + 1); + EXPECT(d.get_interval().min == 4); + EXPECT(d.get_interval().max == 24); + EXPECT(d.get_optimals() == std::set({9, 14})); +} + +TEST_CASE(test_dd_symbolic_arithmetic_invalidates_cache) +{ + auto n = var("n", {2, 8}, {4}); + dd d{n}; + EXPECT(d.get_interval().min == 2); + EXPECT(d.get_interval().max == 8); + d += 1; + EXPECT(d.sym_expr == n + 1); + EXPECT(d.get_interval().min == 3); + EXPECT(d.get_interval().max == 9); + EXPECT(d.get_optimals() == std::set({5})); +} + +TEST_CASE(test_dd_range_arithmetic_keeps_cache) +{ + dd a{2, 8, {4}}; + dd b{1, 3}; + auto r = a + b; + EXPECT(r.sym_expr.empty()); + EXPECT(r.get_interval().min == 3); + EXPECT(r.get_interval().max == 11); +} + +TEST_CASE(test_dd_serialize_range_based) +{ + dd a{3, 10, {5, 7}}; + auto v = migraphx::to_value(a); + auto b = migraphx::from_value
(v); + EXPECT(a == b); + EXPECT(b.get_interval().min == 3); + EXPECT(b.get_interval().max == 10); + EXPECT(b.get_optimals() == std::set({5, 7})); +} + +TEST_CASE(test_dd_serialize_symbolic) +{ + auto n = var("n", {2, 16}, {4, 8}); + dd d{n}; + auto v = migraphx::to_value(d); + auto d2 = migraphx::from_value
(v); + EXPECT(d == d2); + EXPECT(d2.get_interval().min == 2); + EXPECT(d2.get_interval().max == 16); + EXPECT(d2.get_optimals() == std::set({4, 8})); +} + +// ------------------------------------------------------------------- +// is_compatible / is_compatible_lens for symbolic shapes +// ------------------------------------------------------------------- + +TEST_CASE(shape_is_compatible_symbolic_same) +{ + auto n = var("n", {1, 8}); + migraphx::shape actual{migraphx::shape::float_type, {dd{n}, dd{lit(3)}, dd{lit(4)}}}; + migraphx::shape expected{migraphx::shape::float_type, {dd{n}, dd{lit(3)}, dd{lit(4)}}}; + EXPECT(migraphx::shape::is_compatible(actual, expected)); +} + +TEST_CASE(shape_is_compatible_lens_symbolic_same) +{ + auto n = var("n", {1, 8}); + migraphx::shape s1{migraphx::shape::float_type, {dd{lit(1)}, dd{n}, dd{lit(3)}, dd{lit(4)}}}; + migraphx::shape s2{migraphx::shape::float_type, {dd{lit(1)}, dd{n}, dd{lit(3)}, dd{lit(4)}}}; + EXPECT(migraphx::shape::is_compatible_lens(s1, s2)); +} + +TEST_CASE(shape_is_compatible_lens_static_vs_symbolic) +{ + auto n = var("n", {2, 8}); + migraphx::shape actual1{migraphx::shape::float_type, {1, 4, 3}}; + migraphx::shape actual2{migraphx::shape::float_type, {1, 16, 3}}; + migraphx::shape expected{migraphx::shape::float_type, {dd{lit(1)}, dd{n}, dd{lit(3)}}}; + EXPECT(migraphx::shape::is_compatible_lens(actual1, expected)); + EXPECT(not migraphx::shape::is_compatible_lens(actual2, expected)); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/sym_test.cpp b/test/sym_test.cpp index f822b0e773f..ff5a5e77043 100644 --- a/test/sym_test.cpp +++ b/test/sym_test.cpp @@ -24,12 +24,21 @@ #include #include +#include #include "test.hpp" -using se = migraphx::sym::expr; +using se = migraphx::sym::expr; +using interval = migraphx::sym::interval; using migraphx::sym::lit; using migraphx::sym::parse; -using migraphx::sym::var; + +// Local wrappers so sym-library arithmetic/canonicalization tests don't have +// to spell out bounds they don't care about +static se var(const std::string& name) { return migraphx::sym::var(name, {1, 1}); } +static se var(const std::string& name, interval bounds, std::set optimals = {}) +{ + return migraphx::sym::var(name, bounds, std::move(optimals)); +} // =================================================================== // Tier 1: Expression construction and canonicalization @@ -979,4 +988,413 @@ TEST_CASE(serialize_compound) EXPECT(round_trip(e) == e); } +// ------------------------------------------------------------------- +// Bounded vars: constructor / eq / hash +// ------------------------------------------------------------------- + +TEST_CASE(construct_var_min_greater_than_max_throws) +{ + EXPECT(test::throws([&] { var("n", {10, 5}); })); +} + +TEST_CASE(construct_var_min_less_than_one_throws) +{ + EXPECT(test::throws([&] { var("n", {0, 5}); })); + EXPECT(test::throws([&] { var("n", {-1, 5}); })); +} + +TEST_CASE(eq_same_name_different_intervals) +{ + auto h1 = var("h", {1, 128}); + auto h2 = var("h", {1, 256}); + auto h3 = var("h", {2, 128}); + auto h4 = var("h", {1, 128}); + EXPECT(h1 != h2); + EXPECT(h1 != h3); + EXPECT(h1 == h4); +} + +TEST_CASE(hash_same_name_different_intervals) +{ + auto h1 = var("h", {1, 128}); + auto h2 = var("h", {1, 256}); + auto h3 = var("h", {1, 128}); + EXPECT(h1.hash() != h2.hash()); + EXPECT(h1.hash() == h3.hash()); +} + +// ------------------------------------------------------------------- +// Bounds: eval_interval() +// ------------------------------------------------------------------- + +TEST_CASE(eval_interval_single_var) +{ + auto n = var("n", {2, 16}); + EXPECT(n.eval_interval() == interval{2, 16}); +} + +TEST_CASE(eval_interval_literal) { EXPECT(lit(42).eval_interval() == interval{42, 42}); } + +TEST_CASE(eval_interval_compound) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + auto e = n * c * 4; + EXPECT(e.eval_interval() == interval{4, 512}); +} + +TEST_CASE(eval_interval_stride_diff) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + auto diff = n * c - n; + EXPECT(diff.eval_interval() == interval{0, 120}); +} + +TEST_CASE(eval_interval_division) +{ + auto n = var("n", {2, 10}); + auto d = var("d", {1, 5}); + auto e = n / d; + EXPECT(e.eval_interval() == interval{0, 10}); +} + +TEST_CASE(eval_interval_div_literal_denom) +{ + auto n = var("n", {4, 16}); + auto e = n / lit(4); + EXPECT(e.eval_interval() == interval{1, 4}); +} + +TEST_CASE(eval_interval_subtraction_independent) +{ + auto a = var("a", {1, 10}); + auto b = var("b", {1, 5}); + auto e = a - b; + EXPECT(e.eval_interval() == interval{-4, 9}); +} + +TEST_CASE(eval_interval_empty_throws) +{ + se empty; + EXPECT(test::throws([&] { (void)empty.eval_interval(); })); +} + +TEST_CASE(eval_interval_uint) +{ + auto n = var("n", {2, 16}); + auto e = 3 * n + 1; + EXPECT(e.eval_interval() == interval{7, 49}); +} + +// ------------------------------------------------------------------- +// Comparison operators +// ------------------------------------------------------------------- + +TEST_CASE(cmp_lit_constants) +{ + EXPECT(lit(1) < lit(2)); + EXPECT(not(lit(2) < lit(1))); + EXPECT(not(lit(3) < lit(3))); + EXPECT(lit(2) > lit(1)); + EXPECT(lit(3) <= lit(3)); + EXPECT(lit(3) >= lit(3)); + EXPECT(lit(1) <= lit(2)); + EXPECT(lit(2) >= lit(1)); +} + +TEST_CASE(cmp_equal_expr_not_less) +{ + auto n = var("n"); + EXPECT(not(n < n)); + EXPECT(not(n > n)); + EXPECT(n <= n); + EXPECT(n >= n); +} + +TEST_CASE(cmp_empty_not_less) +{ + se a; + se b; + EXPECT(not(a < b)); +} + +TEST_CASE(cmp_empty_with_nonempty_throws) +{ + EXPECT(test::throws([&]() -> bool { return se{} < var("n"); })); + EXPECT(test::throws([&]() -> bool { return var("n") < se{}; })); +} + +TEST_CASE(cmp_stride_ordering_4d) +{ + auto c = var("c", {1, 512}); + auto h = var("h", {1, 256}); + auto w = var("w", {1, 256}); + auto s0 = c * h * w; + auto s1 = h * w; + auto s2 = w; + auto s3 = lit(1); + EXPECT(s1 <= s0); + EXPECT(s2 <= s1); + EXPECT(s3 <= s2); + EXPECT(s3 <= s0); +} + +TEST_CASE(cmp_scaled_symbol) +{ + auto n = var("n"); + EXPECT(n < 2 * n); + EXPECT(n < 3 * n); + EXPECT(not(2 * n < n)); +} + +TEST_CASE(cmp_product_explicit_bounds) +{ + auto k = var("k", {1, 8}); + auto m = var("m", {2, 4}); + EXPECT(k < m * k); +} + +TEST_CASE(cmp_conv_output_smaller_than_input) +{ + auto h = var("h", {3, 256}); + auto out = (h - 3) / 2 + 1; + EXPECT(out < h); + EXPECT(not(h < out)); +} + +TEST_CASE(cmp_repeated_pooling) +{ + auto h = var("h", {7, 256}); + auto out1 = (h - 3) / 2 + 1; + auto out2 = (out1 - 3) / 2 + 1; + EXPECT(out1 < h); + EXPECT(out2 < out1); + EXPECT(out2 < h); +} + +TEST_CASE(cmp_strides_after_conv) +{ + auto h = var("h", {7, 128}); + auto w = var("w", {2, 128}); + auto new_h = (h - 3) / 2 + 1; + auto s0 = new_h * w; + auto s1 = w; + auto s2 = lit(1); + EXPECT(s1 < s0); + EXPECT(s2 < s1); +} + +TEST_CASE(cmp_broadcast_stride_zero) +{ + auto w = var("w"); + EXPECT(lit(0) < w); + EXPECT(not(w < lit(0))); +} + +TEST_CASE(cmp_offset_expressions) +{ + auto h = var("h", {2, 256}); + EXPECT(h - 1 < h); + EXPECT(h < h + 1); + EXPECT(not(h + 1 < h)); +} + +TEST_CASE(cmp_undetermined_throws) +{ + auto n = var("n", {2, 10}); + EXPECT(test::throws([&]() -> bool { return n < lit(5); })); +} + +TEST_CASE(cmp_element_count_slice) +{ + auto n = var("n", {1, 32}); + auto c = var("c", {1, 512}); + auto h = var("h", {1, 256}); + auto w = var("w", {2, 256}); + EXPECT(n * c * h < n * c * h * w); +} + +TEST_CASE(cmp_deep_pooling_chain) +{ + auto h = var("h", {31, 512}); + se stage = h; + se prev; + for(int i = 0; i < 5; ++i) + { + prev = stage; + stage = (stage - 1) / 2; + } + EXPECT(stage < prev); + EXPECT(stage < h); +} + +TEST_CASE(cmp_commuted_product) +{ + auto a = var("a"); + auto b = var("b"); + EXPECT(not(a * b < b * a)); + EXPECT(a * b <= b * a); + EXPECT(a * b >= b * a); +} + +TEST_CASE(cmp_negative_literals) +{ + EXPECT(lit(-5) < lit(-1)); + EXPECT(lit(-1) < lit(0)); + EXPECT(lit(-10) < lit(10)); + EXPECT(not(lit(0) < lit(-1))); +} + +TEST_CASE(cmp_symmetry_lt_gt) +{ + auto h = var("h", {3, 256}); + auto out = (h - 3) / 2 + 1; + EXPECT(out < h); + EXPECT(h > out); + EXPECT(not(h < out)); + EXPECT(not(out > h)); +} + +TEST_CASE(cmp_transitivity_strides) +{ + auto c = var("c", {2, 512}); + auto h = var("h", {2, 256}); + auto w = var("w", {2, 256}); + auto s0 = c * h * w; + auto s1 = h * w; + auto s2 = w; + auto s3 = lit(1); + EXPECT(s1 < s0); + EXPECT(s2 < s1); + EXPECT(s3 < s2); + EXPECT(s3 < s0); + EXPECT(s2 < s0); + EXPECT(s3 < s1); +} + +TEST_CASE(cmp_division_ordering) +{ + auto h = var("h", {5, 256}); + auto pool2 = (h - 1) / 2; + auto pool4 = (h - 1) / 4; + EXPECT(pool4 < pool2); + EXPECT(pool2 < h); + EXPECT(pool4 < h); +} + +TEST_CASE(cmp_sum_less_than_product) +{ + auto n = var("n", {2, 32}); + auto c = var("c", {3, 512}); + EXPECT(n + c < n * c); +} + +TEST_CASE(cmp_algebraically_equal_expressions) +{ + auto h = var("h"); + auto a = h + h; + auto b = 2 * h; + EXPECT(a == b); + EXPECT(not(a < b)); + EXPECT(not(b < a)); + EXPECT(a <= b); + EXPECT(a >= b); +} + +TEST_CASE(cmp_zero_stride_less_than_symbolic_stride) +{ + auto h = var("h"); + auto w = var("w"); + EXPECT(lit(0) < h); + EXPECT(lit(0) < h * w); + EXPECT(lit(0) < h + w); +} + +// ------------------------------------------------------------------- +// Optimals: eval_optimals() +// ------------------------------------------------------------------- + +TEST_CASE(eval_optimals_single_var) +{ + auto n = var("n", {1, 8}, {2, 4}); + EXPECT(n.eval_optimals() == std::set{2, 4}); +} + +TEST_CASE(eval_optimals_compound_expr) +{ + auto n = var("n", {1, 8}, {2, 4}); + auto e = 2 * n + 1; + EXPECT(e.eval_optimals() == std::set{5, 9}); +} + +TEST_CASE(eval_optimals_multi_var) +{ + auto n = var("n", {1, 8}, {2, 4}); + auto m = var("m", {1, 8}, {3, 6}); + auto e = n + m; + EXPECT(e.eval_optimals() == std::set{5, 7, 8, 10}); +} + +TEST_CASE(eval_optimals_negative_throws) +{ + auto n = var("n", {1, 4}, {2}); + auto m = var("m", {1, 8}, {5}); + auto e = n - m; + EXPECT(test::throws([&] { (void)e.eval_optimals(); })); +} + +TEST_CASE(eval_optimals_no_optimals) +{ + auto n = var("n", {1, 8}); + EXPECT(n.eval_optimals().empty()); +} + +TEST_CASE(eval_optimals_empty_expr) +{ + se e; + EXPECT(e.eval_optimals().empty()); +} + +// ------------------------------------------------------------------- +// Serialization: bounded vars +// ------------------------------------------------------------------- + +TEST_CASE(serialize_bounded_var) +{ + auto h = var("h", {1, 128}); + auto r = round_trip(h); + EXPECT(r == h); + EXPECT(r != var("h", {1, 256})); + EXPECT(r != var("h")); +} + +TEST_CASE(serialize_bounded_var_in_expr) +{ + auto h = var("h", {1, 128}); + auto w = var("w", {1, 256}); + auto e = 2 * h + w - 3; + auto r = round_trip(e); + EXPECT(r == e); + EXPECT(r.eval_uint({{h, 64}, {w, 32}}) == 157); +} + +TEST_CASE(serialize_conv_output_with_bounds) +{ + auto h = var("h", {3, 256}); + auto out = (h - 3) / 2 + 1; + auto r = round_trip(out); + EXPECT(r == out); + EXPECT(r.eval_uint({{h, 255}}) == 127); +} + +TEST_CASE(serialize_comparison_survives_round_trip) +{ + auto h = var("h", {3, 256}); + auto out = (h - 3) / 2 + 1; + auto h2 = round_trip(h); + auto out2 = round_trip(out); + EXPECT(out2 < h2); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); }