Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion src/include/migraphx/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,12 +447,29 @@ struct MIGRAPHX_EXPORT shape

shape with_type(type_t t) const;

// convert the shape to an equivalent dynamic shape with constant symbolic strides
// convert the shape to an equivalent range-based dynamic shape: each static len becomes
// dd{len, len} (strides are not carried); a symbolic shape is demoted by evaluating
// each dim's interval/optimals (symbolic strides are dropped). Idempotent on a shape
// that is already range-based dynamic.
shape to_dynamic() const;

// Align a list of shapes to a single representation. If any input contains a
// range-based dynamic shape (at any nesting level), every shape is converted via
// to_dynamic() (symbolic inputs are demoted). Otherwise every shape is converted
// via to_symbolic() (static inputs are promoted to symbolic literals). Recurses
// into tuple sub-shapes.
static std::vector<shape> to_dynamic(const std::vector<shape>& shapes);

// convert the shape to an equivalent symbolic dynamic shape: each static len becomes
// dd{sym::lit(len)} and each static stride becomes sym::lit(stride). Idempotent on a
// shape that is already symbolic. Throws on a range-based dynamic shape.
shape to_symbolic() const;

// convert the shape to a static one setting any non-fixed dynamic_dimensions to x
shape to_static(std::size_t x) const;
shape to_static(const std::unordered_map<sym::expr, std::size_t>& symbol_map) const;
// Collapse a fully-fixed shape to a static one; throws on non-fixed dimensions.
shape to_static() const;

MIGRAPHX_EXPORT friend bool operator==(const shape& x, const shape& y);
MIGRAPHX_EXPORT friend bool operator!=(const shape& x, const shape& y);
Expand Down Expand Up @@ -572,6 +589,26 @@ struct MIGRAPHX_EXPORT shape

void debug_print() const;

/// Whether a dim-like value has a single, known static integer value.
static bool is_fixed_dim(std::size_t) { return true; }
static bool is_fixed_dim(const dynamic_dimension& d) { return d.is_fixed(); }

/// Extract the static integer value from a fixed dim-like value. Caller is
/// responsible for ensuring `is_fixed_dim(x)` first.
static std::size_t static_dim_value(std::size_t x) { return x; }
static std::size_t static_dim_value(const dynamic_dimension& d)
{
if(not d.is_fixed())
MIGRAPHX_THROW("shape::static_dim_value: dimension is not fixed");
return d.get_interval().max;
}

/// Whether all dims of this shape have a single, known static integer value.
/// True for static shapes, range-based shapes with all-fixed dims, symbolic
/// shapes whose dims are all literals (or vars with collapsed bounds), and
/// tuple shapes whose sub-shapes are all fixed.
bool is_fixed() const;

private:
shape(std::shared_ptr<shape_impl> pimpl);
std::shared_ptr<const shape_impl> impl;
Expand Down
92 changes: 83 additions & 9 deletions src/shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,11 @@ shape shape::with_type(type_t t) const
return {c};
}

// Convert to an equivalent range-based dynamic shape:
// - static : each len becomes dd{len, len} (strides are not carried)
// - range-based dynamic : identity
// - symbolic : each dim is demoted via get_interval()/get_optimals(); symbolic
// strides are dropped (range-based shapes don't carry them)
shape shape::to_dynamic() const
{
if(not sub_shapes().empty())
Expand All @@ -854,20 +859,66 @@ shape shape::to_dynamic() const
[](auto s) { return s.to_dynamic(); });
return shape(subs);
}
if(this->symbolic())
{
std::vector<dynamic_dimension> dims(ndim());
std::transform(dyn_dims().begin(), dyn_dims().end(), dims.begin(), [](const auto& d) {
auto iv = d.get_interval();
return dynamic_dimension{iv.min, iv.max, d.get_optimals()};
});
return {type(), std::move(dims)};
}
if(this->dynamic())
{
return *this;
}
std::vector<dynamic_dimension> dims;
dims.reserve(ndim());
std::transform(lens().begin(), lens().end(), std::back_inserter(dims), [](auto len) {
return dynamic_dimension{len, len};
return {type(), lens(), lens(), {}};
Comment thread
shivadbhavsar marked this conversation as resolved.
}

static bool any_non_sym_dynamic(const shape& s)
{
if(not s.sub_shapes().empty())
return std::any_of(s.sub_shapes().begin(), s.sub_shapes().end(), &any_non_sym_dynamic);
return s.dynamic() and not s.symbolic();
}

std::vector<shape> shape::to_dynamic(const std::vector<shape>& shapes)
{
const bool any_non_sym = std::any_of(shapes.begin(), shapes.end(), &any_non_sym_dynamic);
std::vector<shape> result(shapes.size());
std::transform(shapes.begin(), shapes.end(), result.begin(), [&](const auto& s) {
return any_non_sym ? s.to_dynamic() : s.to_symbolic();
});
std::vector<sym::expr> dstrides;
dstrides.reserve(ndim());
std::transform(strides().begin(), strides().end(), std::back_inserter(dstrides), [](auto s) {
return sym::lit(s);
return result;
}

shape shape::to_symbolic() const
{
if(not sub_shapes().empty())
{
std::vector<shape> subs;
std::transform(sub_shapes().cbegin(),
sub_shapes().cend(),
std::back_inserter(subs),
[](auto s) { return s.to_symbolic(); });
return shape(subs);
}
if(this->symbolic())
{
return *this;
}
if(this->dynamic())
{
// Range-based dynamic shapes have no clean symbolic representation
MIGRAPHX_THROW("SHAPE: to_symbolic() called on a range-based dynamic shape");
}
std::vector<dynamic_dimension> dims(ndim());
std::transform(lens().begin(), lens().end(), dims.begin(), [](auto len) {
return dynamic_dimension{sym::lit(len)};
});
std::vector<sym::expr> dstrides(ndim());
std::transform(
strides().begin(), strides().end(), dstrides.begin(), [](auto s) { return sym::lit(s); });
return {type(), std::move(dims), std::move(dstrides)};
}

Expand Down Expand Up @@ -929,6 +980,13 @@ shape shape::to_static(const std::unordered_map<sym::expr, std::size_t>& symbol_
return {type(), static_lens, static_strides};
}

shape shape::to_static() const
{
if(not this->is_fixed())
MIGRAPHX_THROW("SHAPE: to_static() requires fully-fixed dimensions");
return this->to_static(std::unordered_map<sym::expr, std::size_t>{});
}

std::size_t shape::element_space() const { return impl->element_space(); }

std::string shape::type_string() const { return name(this->type()); }
Expand Down Expand Up @@ -1274,9 +1332,25 @@ std::ostream& operator<<(std::ostream& os, const shape& x)
return os;
}

bool shape::is_fixed() const
{
if(not sub_shapes().empty())
return std::all_of(
sub_shapes().begin(), sub_shapes().end(), [](const auto& s) { return s.is_fixed(); });
if(this->dynamic())
return std::all_of(
dyn_dims().begin(), dyn_dims().end(), [](const auto& d) { return d.is_fixed(); });
return true;
}

// Fixed shapes compare by resolved static lens; otherwise compare dyn_dims directly.
bool shape::same_lens(const shape& x, const shape& y)
{
return x.to_dynamic().dyn_dims() == y.to_dynamic().dyn_dims();
if(x.is_fixed() != y.is_fixed())
return false;
if(x.is_fixed())
return x.to_static().lens() == y.to_static().lens();
return x.dyn_dims() == y.dyn_dims();
}

shape::type_t shape::parse_type(const std::string& s)
Expand Down
3 changes: 3 additions & 0 deletions src/sym.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,9 @@ static int64_t eval_direct(const expr_ptr& e, const binding_map& bindings)
auto it = bindings.find(node);
if(it != bindings.end())
return it->second;
// Fall back to the symbol's own bounds when fixed (min == max).
if(d.min == d.max)
return d.min;
MIGRAPHX_THROW("sym::expr::eval_uint: unbound symbol '" + d.name + "'");
});
}
Expand Down
Loading
Loading