Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion src/include/migraphx/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,9 +447,24 @@ struct MIGRAPHX_EXPORT shape

shape with_type(type_t t) const;

// convert the shape to an equivalent dynamic shape with constant symbolic strides
// convert the shape to an equivalent range-based dynamic shape: each static len becomes
// dd{len, len} (strides are not carried); a symbolic shape is demoted by evaluating
// each dim's interval/optimals (symbolic strides are dropped). Idempotent on a shape
// that is already range-based dynamic.
shape to_dynamic() const;

// Align a list of shapes to a single representation. If any input contains a
// range-based dynamic shape (at any nesting level), every shape is converted via
// to_dynamic() (symbolic inputs are demoted). Otherwise every shape is converted
// via to_symbolic() (static inputs are promoted to symbolic literals). Recurses
// into tuple sub-shapes.
static std::vector<shape> to_dynamic(const std::vector<shape>& shapes);

// convert the shape to an equivalent symbolic dynamic shape: each static len becomes
// dd{sym::lit(len)} and each static stride becomes sym::lit(stride). Idempotent on a
// shape that is already symbolic. Throws on a range-based dynamic shape.
shape to_symbolic() const;

// convert the shape to a static one setting any non-fixed dynamic_dimensions to x
shape to_static(std::size_t x) const;
shape to_static(const std::unordered_map<sym::expr, std::size_t>& symbol_map) const;
Expand Down Expand Up @@ -572,6 +587,26 @@ struct MIGRAPHX_EXPORT shape

void debug_print() const;

/// Whether a dim-like value has a single, known static integer value.
static bool is_fixed_dim(std::size_t) { return true; }
static bool is_fixed_dim(const dynamic_dimension& d) { return d.is_fixed(); }

/// Extract the static integer value from a fixed dim-like value. Caller is
/// responsible for ensuring `is_fixed_dim(x)` first.
static std::size_t static_dim_value(std::size_t x) { return x; }
static std::size_t static_dim_value(const dynamic_dimension& d)
{
if(not d.is_fixed())
MIGRAPHX_THROW("shape::static_dim_value: dimension is not fixed");
return d.get_interval().max;
}

/// Whether all dims of this shape have a single, known static integer value.
/// True for static shapes, range-based shapes with all-fixed dims, symbolic
/// shapes whose dims are all literals (or vars with collapsed bounds), and
/// tuple shapes whose sub-shapes are all fixed.
bool is_fixed() const;

private:
shape(std::shared_ptr<shape_impl> pimpl);
std::shared_ptr<const shape_impl> impl;
Expand Down
85 changes: 76 additions & 9 deletions src/shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,11 @@ shape shape::with_type(type_t t) const
return {c};
}

// Convert to an equivalent range-based dynamic shape:
// - static : each len becomes dd{len, len} (strides are not carried)
// - range-based dynamic : identity
// - symbolic : each dim is demoted via get_interval()/get_optimals(); symbolic
// strides are dropped (range-based shapes don't carry them)
shape shape::to_dynamic() const
{
if(not sub_shapes().empty())
Expand All @@ -854,20 +859,66 @@ shape shape::to_dynamic() const
[](auto s) { return s.to_dynamic(); });
return shape(subs);
}
if(this->symbolic())
{
std::vector<dynamic_dimension> dims(ndim());
std::transform(dyn_dims().begin(), dyn_dims().end(), dims.begin(), [](const auto& d) {
auto iv = d.get_interval();
return dynamic_dimension{iv.min, iv.max, d.get_optimals()};
});
return {type(), std::move(dims)};
}
if(this->dynamic())
{
return *this;
}
std::vector<dynamic_dimension> dims;
dims.reserve(ndim());
std::transform(lens().begin(), lens().end(), std::back_inserter(dims), [](auto len) {
return dynamic_dimension{len, len};
return {type(), lens(), lens(), {}};
Comment thread
shivadbhavsar marked this conversation as resolved.
}

static bool any_non_sym_dynamic(const shape& s)
{
if(not s.sub_shapes().empty())
return std::any_of(s.sub_shapes().begin(), s.sub_shapes().end(), &any_non_sym_dynamic);
return s.dynamic() and not s.symbolic();
}

std::vector<shape> shape::to_dynamic(const std::vector<shape>& shapes)
{
const bool any_non_sym = std::any_of(shapes.begin(), shapes.end(), &any_non_sym_dynamic);
std::vector<shape> result(shapes.size());
std::transform(shapes.begin(), shapes.end(), result.begin(), [&](const auto& s) {
return any_non_sym ? s.to_dynamic() : s.to_symbolic();
});
std::vector<sym::expr> dstrides;
dstrides.reserve(ndim());
std::transform(strides().begin(), strides().end(), std::back_inserter(dstrides), [](auto s) {
return sym::lit(s);
return result;
}

shape shape::to_symbolic() const
{
if(not sub_shapes().empty())
{
std::vector<shape> subs;
std::transform(sub_shapes().cbegin(),
sub_shapes().cend(),
std::back_inserter(subs),
[](auto s) { return s.to_symbolic(); });
return shape(subs);
}
if(this->symbolic())
{
return *this;
}
if(this->dynamic())
{
// Range-based dynamic shapes have no clean symbolic representation
MIGRAPHX_THROW("SHAPE: to_symbolic() called on a range-based dynamic shape");
}
std::vector<dynamic_dimension> dims(ndim());
std::transform(lens().begin(), lens().end(), dims.begin(), [](auto len) {
return dynamic_dimension{sym::lit(len)};
});
std::vector<sym::expr> dstrides(ndim());
std::transform(
strides().begin(), strides().end(), dstrides.begin(), [](auto s) { return sym::lit(s); });
return {type(), std::move(dims), std::move(dstrides)};
}

Expand Down Expand Up @@ -1274,9 +1325,25 @@ std::ostream& operator<<(std::ostream& os, const shape& x)
return os;
}

bool shape::is_fixed() const
{
if(not sub_shapes().empty())
return std::all_of(
sub_shapes().begin(), sub_shapes().end(), [](const auto& s) { return s.is_fixed(); });
if(this->dynamic())
return std::all_of(
dyn_dims().begin(), dyn_dims().end(), [](const auto& d) { return d.is_fixed(); });
return true;
}

// Fixed shapes compare by resolved static lens; otherwise compare dyn_dims directly.
bool shape::same_lens(const shape& x, const shape& y)
{
return x.to_dynamic().dyn_dims() == y.to_dynamic().dyn_dims();
if(x.is_fixed() != y.is_fixed())
return false;
if(x.is_fixed())
return x.to_static({}).lens() == y.to_static({}).lens();
Comment thread
shivadbhavsar marked this conversation as resolved.
Outdated
return x.dyn_dims() == y.dyn_dims();
}

shape::type_t shape::parse_type(const std::string& s)
Expand Down
183 changes: 176 additions & 7 deletions test/shape_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,9 +446,7 @@ TEST_CASE(test_shape_static_to_dynamic)
{
migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4, 4}};
migraphx::shape s1 = s0.to_dynamic();
migraphx::shape s2{migraphx::shape::float_type,
{{1, 1}, {2, 2}, {4, 4}, {4, 4}},
{lit(32), lit(16), lit(4), lit(1)}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 1}, {2, 2}, {4, 4}, {4, 4}}};
EXPECT(s1 == s2);
}

Expand All @@ -468,12 +466,136 @@ TEST_CASE(test_shape_subshapes_to_dynamic)
migraphx::shape s1 = s0.to_dynamic();
std::vector<migraphx::shape> sub_shapes1 = {};
sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
sub_shapes1.push_back(migraphx::shape{
migraphx::shape::float_type, {{3, 3}, {4, 4}, {5, 5}}, {lit(20), lit(5), lit(1)}});
sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, {{3, 3}, {4, 4}, {5, 5}}});
migraphx::shape s2{sub_shapes1};
EXPECT(s1 == s2);
}

TEST_CASE(test_shape_static_to_symbolic)
{
migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4, 4}};
migraphx::shape s1 = s0.to_symbolic();
migraphx::shape s2{migraphx::shape::float_type,
{dd{lit(1)}, dd{lit(2)}, dd{lit(4)}, dd{lit(4)}},
{lit(32), lit(16), lit(4), lit(1)}};
EXPECT(s1 == s2);
EXPECT(s1.symbolic());
}

TEST_CASE(test_shape_symbolic_to_symbolic)
{
auto n = var("n", {1, 8});
auto c = var("c", {1, 16});
migraphx::shape s0{migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}}};
auto s1 = s0.to_symbolic();
EXPECT(s0 == s1);
}

TEST_CASE(test_shape_dyn_to_symbolic_throws)
{
migraphx::shape s0{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
EXPECT(test::throws([&] { s0.to_symbolic(); }));
}

TEST_CASE(test_shape_subshapes_to_symbolic)
{
std::vector<migraphx::shape> sub_shapes0 = {};
sub_shapes0.push_back(migraphx::shape{migraphx::shape::float_type, {2, 3}});
sub_shapes0.push_back(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}});
migraphx::shape s0{sub_shapes0};
migraphx::shape s1 = s0.to_symbolic();
std::vector<migraphx::shape> sub_shapes1 = {};
sub_shapes1.push_back(
migraphx::shape{migraphx::shape::float_type, {dd{lit(2)}, dd{lit(3)}}, {lit(3), lit(1)}});
sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type,
{dd{lit(3)}, dd{lit(4)}, dd{lit(5)}},
{lit(20), lit(5), lit(1)}});
migraphx::shape s2{sub_shapes1};
EXPECT(s1 == s2);
}

TEST_CASE(test_shapes_to_dynamic_empty)
{
auto out = migraphx::shape::to_dynamic({});
EXPECT(out.empty());
}

TEST_CASE(test_shapes_to_dynamic_all_static)
{
migraphx::shape a{migraphx::shape::float_type, {2, 3}};
migraphx::shape b{migraphx::shape::float_type, {3, 4}};
auto out = migraphx::shape::to_dynamic({a, b});
EXPECT(out.size() == 2);
EXPECT(out[0] == a.to_symbolic());
EXPECT(out[1] == b.to_symbolic());
EXPECT(out[0].symbolic());
EXPECT(out[1].symbolic());
}

TEST_CASE(test_shapes_to_dynamic_all_symbolic)
{
auto n = var("n", {1, 8});
migraphx::shape a{migraphx::shape::float_type, {dd{n}, dd{lit(4)}}};
migraphx::shape b{migraphx::shape::float_type, {dd{lit(2)}, dd{n}}};
auto out = migraphx::shape::to_dynamic({a, b});
EXPECT(out[0] == a);
EXPECT(out[1] == b);
}

TEST_CASE(test_shapes_to_dynamic_all_range)
{
migraphx::shape a{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
migraphx::shape b{migraphx::shape::float_type, {{2, 2}, {3, 8}}};
auto out = migraphx::shape::to_dynamic({a, b});
EXPECT(out[0] == a);
EXPECT(out[1] == b);
}

TEST_CASE(test_shapes_to_dynamic_sym_and_static)
{
auto n = var("n", {1, 8});
migraphx::shape a{migraphx::shape::float_type, {dd{n}, dd{lit(4)}}};
migraphx::shape b{migraphx::shape::float_type, {2, 4}};
auto out = migraphx::shape::to_dynamic({a, b});
EXPECT(out[0] == a);
EXPECT(out[1] == b.to_symbolic());
EXPECT(out[1].symbolic());
}

TEST_CASE(test_shapes_to_dynamic_range_and_static)
{
migraphx::shape a{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
migraphx::shape b{migraphx::shape::float_type, {2, 4}};
auto out = migraphx::shape::to_dynamic({a, b});
EXPECT(out[0] == a);
EXPECT(out[1] == b.to_dynamic());
EXPECT(not out[1].symbolic());
EXPECT(out[1].dynamic());
}

TEST_CASE(test_shapes_to_dynamic_sym_and_range_demotes)
{
auto n = var("n", {1, 8});
migraphx::shape a{migraphx::shape::float_type, {dd{n}, dd{lit(4)}}};
migraphx::shape b{migraphx::shape::float_type, {{2, 2}, {3, 8}}};
auto out = migraphx::shape::to_dynamic({a, b});
EXPECT(not out[0].symbolic());
EXPECT(out[0].dynamic());
EXPECT(out[0] == a.to_dynamic());
EXPECT(out[1] == b);
}

TEST_CASE(test_shapes_to_dynamic_subshapes_recurse)
{
migraphx::shape inner_static{migraphx::shape::float_type, {2, 3}};
migraphx::shape inner_range{migraphx::shape::float_type, {{1, 4}, {3, 3}}};
migraphx::shape tuple_with_range{std::vector<migraphx::shape>{inner_range, inner_static}};
migraphx::shape plain_static{migraphx::shape::float_type, {3, 4}};
auto out = migraphx::shape::to_dynamic({tuple_with_range, plain_static});
EXPECT(out[0] == tuple_with_range.to_dynamic());
EXPECT(out[1] == plain_static.to_dynamic());
}

TEST_CASE(test_shape_dyn_to_static)
{
migraphx::shape s0{migraphx::shape::float_type, {{1, 1}, {2, 2}, {2, 10}, {2, 10}}};
Expand Down Expand Up @@ -1280,6 +1402,50 @@ TEST_CASE(shape_same_lens_static_dynamic)
EXPECT(not migraphx::shape::same_lens(s1, s3));
}

TEST_CASE(shape_same_lens_symbolic_fixed)
{
auto n = var("n", {4, 4});
migraphx::shape s_static{migraphx::shape::float_type, {1, 4, 8}};
migraphx::shape s_sym_lit{migraphx::shape::half_type, {dd{lit(1)}, dd{lit(4)}, dd{lit(8)}}};
migraphx::shape s_sym_fixed_var{migraphx::shape::float_type, {dd{lit(1)}, dd{n}, dd{lit(8)}}};
migraphx::shape s_dyn_fixed{migraphx::shape::float_type, {{1, 1}, {4, 4}, {8, 8}}};
EXPECT(migraphx::shape::same_lens(s_static, s_sym_lit));
EXPECT(migraphx::shape::same_lens(s_static, s_sym_fixed_var));
EXPECT(migraphx::shape::same_lens(s_sym_lit, s_dyn_fixed));
EXPECT(migraphx::shape::same_lens(s_sym_fixed_var, s_dyn_fixed));
}

TEST_CASE(shape_same_lens_symbolic_nonfixed)
{
auto n = var("n", {1, 8});
auto m = var("m", {1, 8});
migraphx::shape s_n{migraphx::shape::float_type, {dd{n}, dd{lit(4)}}};
migraphx::shape s_n_again{migraphx::shape::half_type, {dd{n}, dd{lit(4)}}};
migraphx::shape s_m{migraphx::shape::float_type, {dd{m}, dd{lit(4)}}};
migraphx::shape s_range{migraphx::shape::float_type, {{1, 8}, {4, 4}}};
EXPECT(migraphx::shape::same_lens(s_n, s_n_again));
EXPECT(not migraphx::shape::same_lens(s_n, s_m));
EXPECT(not migraphx::shape::same_lens(s_n, s_range));
}

TEST_CASE(shape_is_fixed)
{
migraphx::shape s_static{migraphx::shape::float_type, {1, 2, 8}};
migraphx::shape s_dyn_fixed{migraphx::shape::float_type, {{1, 1}, {2, 2}, {8, 8}}};
migraphx::shape s_dyn_range{migraphx::shape::float_type, {{1, 4}, {2, 2}, {8, 8}}};
migraphx::shape s_sym_lit{migraphx::shape::float_type, {dd{lit(1)}, dd{lit(2)}}};
migraphx::shape s_sym_var{migraphx::shape::float_type, {dd{var("n", {1, 8})}, dd{lit(2)}}};
EXPECT(s_static.is_fixed());
EXPECT(s_dyn_fixed.is_fixed());
EXPECT(not s_dyn_range.is_fixed());
EXPECT(s_sym_lit.is_fixed());
EXPECT(not s_sym_var.is_fixed());
migraphx::shape s_tuple_fixed{{s_static, s_sym_lit}};
migraphx::shape s_tuple_mixed{{s_static, s_sym_var}};
EXPECT(s_tuple_fixed.is_fixed());
EXPECT(not s_tuple_mixed.is_fixed());
}

// ===================================================================
// Symbolic dynamic_dimension tests
// ===================================================================
Expand Down Expand Up @@ -1578,13 +1744,16 @@ TEST_CASE(test_symbolic_transposed)
EXPECT(not s.broadcasted());
}

TEST_CASE(test_symbolic_to_dynamic_identity)
TEST_CASE(test_symbolic_to_dynamic_demotes)
{
auto n = var("n", {1, 8});
auto c = var("c", {1, 16});
migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}}};
auto s2 = s.to_dynamic();
EXPECT(s == s2);
EXPECT(not s2.symbolic());
EXPECT(s2.dynamic());
migraphx::shape expected{migraphx::shape::float_type, {{1, 8}, {1, 16}, {4, 4}}};
EXPECT(s2 == expected);
}

TEST_CASE(test_symbolic_overlap)
Expand Down
Loading