diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index e59f8f2c684..1a8c1f9d53e 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -447,12 +447,29 @@ struct MIGRAPHX_EXPORT shape shape with_type(type_t t) const; - // convert the shape to an equivalent dynamic shape with constant symbolic strides + // convert the shape to an equivalent range-based dynamic shape: each static len becomes + // dd{len, len} (strides are not carried); a symbolic shape is demoted by evaluating + // each dim's interval/optimals (symbolic strides are dropped). Idempotent on a shape + // that is already range-based dynamic. shape to_dynamic() const; + // Align a list of shapes to a single representation. If any input contains a + // range-based dynamic shape (at any nesting level), every shape is converted via + // to_dynamic() (symbolic inputs are demoted). Otherwise every shape is converted + // via to_symbolic() (static inputs are promoted to symbolic literals). Recurses + // into tuple sub-shapes. + static std::vector to_dynamic(const std::vector& shapes); + + // convert the shape to an equivalent symbolic dynamic shape: each static len becomes + // dd{sym::lit(len)} and each static stride becomes sym::lit(stride). Idempotent on a + // shape that is already symbolic. Throws on a range-based dynamic shape. + shape to_symbolic() const; + // convert the shape to a static one setting any non-fixed dynamic_dimensions to x shape to_static(std::size_t x) const; shape to_static(const std::unordered_map& symbol_map) const; + // Collapse a fully-fixed shape to a static one; throws on non-fixed dimensions. + shape to_static() const; MIGRAPHX_EXPORT friend bool operator==(const shape& x, const shape& y); MIGRAPHX_EXPORT friend bool operator!=(const shape& x, const shape& y); @@ -572,6 +589,26 @@ struct MIGRAPHX_EXPORT shape void debug_print() const; + /// Whether a dim-like value has a single, known static integer value. + static bool is_fixed_dim(std::size_t) { return true; } + static bool is_fixed_dim(const dynamic_dimension& d) { return d.is_fixed(); } + + /// Extract the static integer value from a fixed dim-like value. Caller is + /// responsible for ensuring `is_fixed_dim(x)` first. + static std::size_t static_dim_value(std::size_t x) { return x; } + static std::size_t static_dim_value(const dynamic_dimension& d) + { + if(not d.is_fixed()) + MIGRAPHX_THROW("shape::static_dim_value: dimension is not fixed"); + return d.get_interval().max; + } + + /// Whether all dims of this shape have a single, known static integer value. + /// True for static shapes, range-based shapes with all-fixed dims, symbolic + /// shapes whose dims are all literals (or vars with collapsed bounds), and + /// tuple shapes whose sub-shapes are all fixed. + bool is_fixed() const; + private: shape(std::shared_ptr pimpl); std::shared_ptr impl; diff --git a/src/shape.cpp b/src/shape.cpp index 4aa67e157b3..0cf5470e5a6 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -843,6 +843,11 @@ shape shape::with_type(type_t t) const return {c}; } +// Convert to an equivalent range-based dynamic shape: +// - static : each len becomes dd{len, len} (strides are not carried) +// - range-based dynamic : identity +// - symbolic : each dim is demoted via get_interval()/get_optimals(); symbolic +// strides are dropped (range-based shapes don't carry them) shape shape::to_dynamic() const { if(not sub_shapes().empty()) @@ -854,20 +859,66 @@ shape shape::to_dynamic() const [](auto s) { return s.to_dynamic(); }); return shape(subs); } + if(this->symbolic()) + { + std::vector dims(ndim()); + std::transform(dyn_dims().begin(), dyn_dims().end(), dims.begin(), [](const auto& d) { + auto iv = d.get_interval(); + return dynamic_dimension{iv.min, iv.max, d.get_optimals()}; + }); + return {type(), std::move(dims)}; + } if(this->dynamic()) { return *this; } - std::vector dims; - dims.reserve(ndim()); - std::transform(lens().begin(), lens().end(), std::back_inserter(dims), [](auto len) { - return dynamic_dimension{len, len}; + return {type(), lens(), lens(), {}}; +} + +static bool any_non_sym_dynamic(const shape& s) +{ + if(not s.sub_shapes().empty()) + return std::any_of(s.sub_shapes().begin(), s.sub_shapes().end(), &any_non_sym_dynamic); + return s.dynamic() and not s.symbolic(); +} + +std::vector shape::to_dynamic(const std::vector& shapes) +{ + const bool any_non_sym = std::any_of(shapes.begin(), shapes.end(), &any_non_sym_dynamic); + std::vector result(shapes.size()); + std::transform(shapes.begin(), shapes.end(), result.begin(), [&](const auto& s) { + return any_non_sym ? s.to_dynamic() : s.to_symbolic(); }); - std::vector dstrides; - dstrides.reserve(ndim()); - std::transform(strides().begin(), strides().end(), std::back_inserter(dstrides), [](auto s) { - return sym::lit(s); + return result; +} + +shape shape::to_symbolic() const +{ + if(not sub_shapes().empty()) + { + std::vector subs; + std::transform(sub_shapes().cbegin(), + sub_shapes().cend(), + std::back_inserter(subs), + [](auto s) { return s.to_symbolic(); }); + return shape(subs); + } + if(this->symbolic()) + { + return *this; + } + if(this->dynamic()) + { + // Range-based dynamic shapes have no clean symbolic representation + MIGRAPHX_THROW("SHAPE: to_symbolic() called on a range-based dynamic shape"); + } + std::vector dims(ndim()); + std::transform(lens().begin(), lens().end(), dims.begin(), [](auto len) { + return dynamic_dimension{sym::lit(len)}; }); + std::vector dstrides(ndim()); + std::transform( + strides().begin(), strides().end(), dstrides.begin(), [](auto s) { return sym::lit(s); }); return {type(), std::move(dims), std::move(dstrides)}; } @@ -929,6 +980,13 @@ shape shape::to_static(const std::unordered_map& symbol_ return {type(), static_lens, static_strides}; } +shape shape::to_static() const +{ + if(not this->is_fixed()) + MIGRAPHX_THROW("SHAPE: to_static() requires fully-fixed dimensions"); + return this->to_static(std::unordered_map{}); +} + std::size_t shape::element_space() const { return impl->element_space(); } std::string shape::type_string() const { return name(this->type()); } @@ -1274,9 +1332,25 @@ std::ostream& operator<<(std::ostream& os, const shape& x) return os; } +bool shape::is_fixed() const +{ + if(not sub_shapes().empty()) + return std::all_of( + sub_shapes().begin(), sub_shapes().end(), [](const auto& s) { return s.is_fixed(); }); + if(this->dynamic()) + return std::all_of( + dyn_dims().begin(), dyn_dims().end(), [](const auto& d) { return d.is_fixed(); }); + return true; +} + +// Fixed shapes compare by resolved static lens; otherwise compare dyn_dims directly. bool shape::same_lens(const shape& x, const shape& y) { - return x.to_dynamic().dyn_dims() == y.to_dynamic().dyn_dims(); + if(x.is_fixed() != y.is_fixed()) + return false; + if(x.is_fixed()) + return x.to_static().lens() == y.to_static().lens(); + return x.dyn_dims() == y.dyn_dims(); } shape::type_t shape::parse_type(const std::string& s) diff --git a/src/sym.cpp b/src/sym.cpp index 4aa68370812..c54b4ff532f 100644 --- a/src/sym.cpp +++ b/src/sym.cpp @@ -615,6 +615,9 @@ static int64_t eval_direct(const expr_ptr& e, const binding_map& bindings) auto it = bindings.find(node); if(it != bindings.end()) return it->second; + // Fall back to the symbol's own bounds when fixed (min == max). + if(d.min == d.max) + return d.min; MIGRAPHX_THROW("sym::expr::eval_uint: unbound symbol '" + d.name + "'"); }); } diff --git a/test/shape_test.cpp b/test/shape_test.cpp index fd67fe8aae4..f897588d04b 100644 --- a/test/shape_test.cpp +++ b/test/shape_test.cpp @@ -446,9 +446,7 @@ TEST_CASE(test_shape_static_to_dynamic) { migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4, 4}}; migraphx::shape s1 = s0.to_dynamic(); - migraphx::shape s2{migraphx::shape::float_type, - {{1, 1}, {2, 2}, {4, 4}, {4, 4}}, - {lit(32), lit(16), lit(4), lit(1)}}; + migraphx::shape s2{migraphx::shape::float_type, {{1, 1}, {2, 2}, {4, 4}, {4, 4}}}; EXPECT(s1 == s2); } @@ -468,12 +466,136 @@ TEST_CASE(test_shape_subshapes_to_dynamic) migraphx::shape s1 = s0.to_dynamic(); std::vector sub_shapes1 = {}; sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}}); - sub_shapes1.push_back(migraphx::shape{ - migraphx::shape::float_type, {{3, 3}, {4, 4}, {5, 5}}, {lit(20), lit(5), lit(1)}}); + sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, {{3, 3}, {4, 4}, {5, 5}}}); migraphx::shape s2{sub_shapes1}; EXPECT(s1 == s2); } +TEST_CASE(test_shape_static_to_symbolic) +{ + migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4, 4}}; + migraphx::shape s1 = s0.to_symbolic(); + migraphx::shape s2{migraphx::shape::float_type, + {dd{lit(1)}, dd{lit(2)}, dd{lit(4)}, dd{lit(4)}}, + {lit(32), lit(16), lit(4), lit(1)}}; + EXPECT(s1 == s2); + EXPECT(s1.symbolic()); +} + +TEST_CASE(test_shape_symbolic_to_symbolic) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + migraphx::shape s0{migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}}}; + auto s1 = s0.to_symbolic(); + EXPECT(s0 == s1); +} + +TEST_CASE(test_shape_dyn_to_symbolic_throws) +{ + migraphx::shape s0{migraphx::shape::float_type, {{1, 4}, {4, 4}}}; + EXPECT(test::throws([&] { s0.to_symbolic(); })); +} + +TEST_CASE(test_shape_subshapes_to_symbolic) +{ + std::vector sub_shapes0 = {}; + sub_shapes0.push_back(migraphx::shape{migraphx::shape::float_type, {2, 3}}); + sub_shapes0.push_back(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}}); + migraphx::shape s0{sub_shapes0}; + migraphx::shape s1 = s0.to_symbolic(); + std::vector sub_shapes1 = {}; + sub_shapes1.push_back( + migraphx::shape{migraphx::shape::float_type, {dd{lit(2)}, dd{lit(3)}}, {lit(3), lit(1)}}); + sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, + {dd{lit(3)}, dd{lit(4)}, dd{lit(5)}}, + {lit(20), lit(5), lit(1)}}); + migraphx::shape s2{sub_shapes1}; + EXPECT(s1 == s2); +} + +TEST_CASE(test_shapes_to_dynamic_empty) +{ + auto out = migraphx::shape::to_dynamic({}); + EXPECT(out.empty()); +} + +TEST_CASE(test_shapes_to_dynamic_all_static) +{ + migraphx::shape a{migraphx::shape::float_type, {2, 3}}; + migraphx::shape b{migraphx::shape::float_type, {3, 4}}; + auto out = migraphx::shape::to_dynamic({a, b}); + EXPECT(out.size() == 2); + EXPECT(out[0] == a.to_symbolic()); + EXPECT(out[1] == b.to_symbolic()); + EXPECT(out[0].symbolic()); + EXPECT(out[1].symbolic()); +} + +TEST_CASE(test_shapes_to_dynamic_all_symbolic) +{ + auto n = var("n", {1, 8}); + migraphx::shape a{migraphx::shape::float_type, {dd{n}, dd{lit(4)}}}; + migraphx::shape b{migraphx::shape::float_type, {dd{lit(2)}, dd{n}}}; + auto out = migraphx::shape::to_dynamic({a, b}); + EXPECT(out[0] == a); + EXPECT(out[1] == b); +} + +TEST_CASE(test_shapes_to_dynamic_all_range) +{ + migraphx::shape a{migraphx::shape::float_type, {{1, 4}, {4, 4}}}; + migraphx::shape b{migraphx::shape::float_type, {{2, 2}, {3, 8}}}; + auto out = migraphx::shape::to_dynamic({a, b}); + EXPECT(out[0] == a); + EXPECT(out[1] == b); +} + +TEST_CASE(test_shapes_to_dynamic_sym_and_static) +{ + auto n = var("n", {1, 8}); + migraphx::shape a{migraphx::shape::float_type, {dd{n}, dd{lit(4)}}}; + migraphx::shape b{migraphx::shape::float_type, {2, 4}}; + auto out = migraphx::shape::to_dynamic({a, b}); + EXPECT(out[0] == a); + EXPECT(out[1] == b.to_symbolic()); + EXPECT(out[1].symbolic()); +} + +TEST_CASE(test_shapes_to_dynamic_range_and_static) +{ + migraphx::shape a{migraphx::shape::float_type, {{1, 4}, {4, 4}}}; + migraphx::shape b{migraphx::shape::float_type, {2, 4}}; + auto out = migraphx::shape::to_dynamic({a, b}); + EXPECT(out[0] == a); + EXPECT(out[1] == b.to_dynamic()); + EXPECT(not out[1].symbolic()); + EXPECT(out[1].dynamic()); +} + +TEST_CASE(test_shapes_to_dynamic_sym_and_range_demotes) +{ + auto n = var("n", {1, 8}); + migraphx::shape a{migraphx::shape::float_type, {dd{n}, dd{lit(4)}}}; + migraphx::shape b{migraphx::shape::float_type, {{2, 2}, {3, 8}}}; + auto out = migraphx::shape::to_dynamic({a, b}); + EXPECT(not out[0].symbolic()); + EXPECT(out[0].dynamic()); + EXPECT(out[0] == a.to_dynamic()); + EXPECT(out[1] == b); +} + +TEST_CASE(test_shapes_to_dynamic_subshapes_recurse) +{ + migraphx::shape inner_static{migraphx::shape::float_type, {2, 3}}; + migraphx::shape inner_range{migraphx::shape::float_type, {{1, 4}, {3, 3}}}; + migraphx::shape tuple_with_range{std::vector{inner_range, inner_static}}; + migraphx::shape plain_static{migraphx::shape::float_type, {3, 4}}; + auto out = migraphx::shape::to_dynamic({tuple_with_range, plain_static}); + EXPECT(out[0] == tuple_with_range.to_dynamic()); + EXPECT(out[1] == plain_static.to_dynamic()); +} + TEST_CASE(test_shape_dyn_to_static) { migraphx::shape s0{migraphx::shape::float_type, {{1, 1}, {2, 2}, {2, 10}, {2, 10}}}; @@ -1280,6 +1402,50 @@ TEST_CASE(shape_same_lens_static_dynamic) EXPECT(not migraphx::shape::same_lens(s1, s3)); } +TEST_CASE(shape_same_lens_symbolic_fixed) +{ + auto n = var("n", {4, 4}); + migraphx::shape s_static{migraphx::shape::float_type, {1, 4, 8}}; + migraphx::shape s_sym_lit{migraphx::shape::half_type, {dd{lit(1)}, dd{lit(4)}, dd{lit(8)}}}; + migraphx::shape s_sym_fixed_var{migraphx::shape::float_type, {dd{lit(1)}, dd{n}, dd{lit(8)}}}; + migraphx::shape s_dyn_fixed{migraphx::shape::float_type, {{1, 1}, {4, 4}, {8, 8}}}; + EXPECT(migraphx::shape::same_lens(s_static, s_sym_lit)); + EXPECT(migraphx::shape::same_lens(s_static, s_sym_fixed_var)); + EXPECT(migraphx::shape::same_lens(s_sym_lit, s_dyn_fixed)); + EXPECT(migraphx::shape::same_lens(s_sym_fixed_var, s_dyn_fixed)); +} + +TEST_CASE(shape_same_lens_symbolic_nonfixed) +{ + auto n = var("n", {1, 8}); + auto m = var("m", {1, 8}); + migraphx::shape s_n{migraphx::shape::float_type, {dd{n}, dd{lit(4)}}}; + migraphx::shape s_n_again{migraphx::shape::half_type, {dd{n}, dd{lit(4)}}}; + migraphx::shape s_m{migraphx::shape::float_type, {dd{m}, dd{lit(4)}}}; + migraphx::shape s_range{migraphx::shape::float_type, {{1, 8}, {4, 4}}}; + EXPECT(migraphx::shape::same_lens(s_n, s_n_again)); + EXPECT(not migraphx::shape::same_lens(s_n, s_m)); + EXPECT(not migraphx::shape::same_lens(s_n, s_range)); +} + +TEST_CASE(shape_is_fixed) +{ + migraphx::shape s_static{migraphx::shape::float_type, {1, 2, 8}}; + migraphx::shape s_dyn_fixed{migraphx::shape::float_type, {{1, 1}, {2, 2}, {8, 8}}}; + migraphx::shape s_dyn_range{migraphx::shape::float_type, {{1, 4}, {2, 2}, {8, 8}}}; + migraphx::shape s_sym_lit{migraphx::shape::float_type, {dd{lit(1)}, dd{lit(2)}}}; + migraphx::shape s_sym_var{migraphx::shape::float_type, {dd{var("n", {1, 8})}, dd{lit(2)}}}; + EXPECT(s_static.is_fixed()); + EXPECT(s_dyn_fixed.is_fixed()); + EXPECT(not s_dyn_range.is_fixed()); + EXPECT(s_sym_lit.is_fixed()); + EXPECT(not s_sym_var.is_fixed()); + migraphx::shape s_tuple_fixed{{s_static, s_sym_lit}}; + migraphx::shape s_tuple_mixed{{s_static, s_sym_var}}; + EXPECT(s_tuple_fixed.is_fixed()); + EXPECT(not s_tuple_mixed.is_fixed()); +} + // =================================================================== // Symbolic dynamic_dimension tests // =================================================================== @@ -1578,13 +1744,16 @@ TEST_CASE(test_symbolic_transposed) EXPECT(not s.broadcasted()); } -TEST_CASE(test_symbolic_to_dynamic_identity) +TEST_CASE(test_symbolic_to_dynamic_demotes) { auto n = var("n", {1, 8}); auto c = var("c", {1, 16}); migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}}}; auto s2 = s.to_dynamic(); - EXPECT(s == s2); + EXPECT(not s2.symbolic()); + EXPECT(s2.dynamic()); + migraphx::shape expected{migraphx::shape::float_type, {{1, 8}, {1, 16}, {4, 4}}}; + EXPECT(s2 == expected); } TEST_CASE(test_symbolic_overlap) diff --git a/test/sym_test.cpp b/test/sym_test.cpp index ff5a5e77043..3134ff82eba 100644 --- a/test/sym_test.cpp +++ b/test/sym_test.cpp @@ -499,12 +499,22 @@ TEST_CASE(eval_trunc_division) TEST_CASE(eval_unbound_throws) { - auto h = var("h"); - auto w = var("w"); + auto h = var("h", {1, 8}); + auto w = var("w", {1, 8}); EXPECT(test::throws([&] { h.eval_uint({}); })); EXPECT(test::throws([&] { (h + w).eval_uint({{h, 1}}); })); } +TEST_CASE(eval_uint_falls_back_to_fixed_bounds) +{ + // Fixed-bound vars (min == max) are resolved from their own bounds. + auto n = var("n", {4, 4}); + EXPECT(n.eval_uint({}) == 4); + EXPECT((n * 8).eval_uint({}) == 32); + auto h = var("h", {1, 8}); + EXPECT((h + n).eval_uint({{h, 2}}) == 6); +} + TEST_CASE(eval_division_by_zero_throws) { auto h = var("h");