Skip to content
Merged
Show file tree
Hide file tree
Changes from 77 commits
Commits
Show all changes
88 commits
Select commit Hold shift + click to select a range
d2b684e
custom symbolic expression lib
shivadbhavsar Mar 24, 2026
aa55785
format
shivadbhavsar Mar 24, 2026
314f7cf
use visit
shivadbhavsar Mar 24, 2026
dcfe825
format
shivadbhavsar Mar 24, 2026
2ec0969
integrate symbolic expression in dynamic_dimension
shivadbhavsar Mar 25, 2026
18caf6b
tidy
shivadbhavsar Mar 25, 2026
483932b
Merge branch 'custom_sym_lib' into sym_dim_integration
shivadbhavsar Mar 25, 2026
200f3c4
Merge branch 'develop' into custom_sym_lib
shivadbhavsar Mar 25, 2026
7ff2045
Merge branch 'custom_sym_lib' into sym_dim_integration
shivadbhavsar Mar 25, 2026
6af3621
fix constructor ambiguity
shivadbhavsar Mar 25, 2026
b7d7c23
fix ambiguity
shivadbhavsar Mar 25, 2026
3486135
update namespace and interface design
shivadbhavsar Mar 26, 2026
edbce87
Merge branch 'develop' into custom_sym_lib
shivadbhavsar Mar 26, 2026
964f934
use int64 for literals
shivadbhavsar Mar 26, 2026
2719ae6
Merge branch 'custom_sym_lib' into sym_dim_integration
shivadbhavsar Mar 26, 2026
364bd23
fix merge
shivadbhavsar Mar 26, 2026
33614e0
change eval func name
shivadbhavsar Mar 26, 2026
2ba3b74
Merge branch 'custom_sym_lib' into sym_dim_integration
shivadbhavsar Mar 26, 2026
830594f
use int64 for internal eval
shivadbhavsar Mar 26, 2026
bd70d84
Merge branch 'custom_sym_lib' into sym_dim_integration
shivadbhavsar Mar 26, 2026
def3038
fix eval call
shivadbhavsar Mar 26, 2026
359070d
copilot comments
shivadbhavsar Mar 26, 2026
9ad996f
copilot review fix
shivadbhavsar Mar 26, 2026
b61b6d8
Merge branch 'custom_sym_lib' into sym_dim_integration
shivadbhavsar Mar 26, 2026
bda9f91
format and tidy
shivadbhavsar Mar 26, 2026
5027376
Merge branch 'custom_sym_lib' into sym_dim_integration
shivadbhavsar Mar 26, 2026
1209717
Merge branch 'sym_dim_integration' of https://github.com/ROCm/AMDMIGr…
shivadbhavsar Mar 26, 2026
003c9d3
tidy fix
shivadbhavsar Mar 30, 2026
3759299
tidy
shivadbhavsar Mar 30, 2026
50944fb
Merge remote-tracking branch 'origin/custom_sym_lib' into sym_dim_int…
shivadbhavsar Mar 30, 2026
fe649ff
update the only call sites using the braced-init-list that cannot be …
shivadbhavsar Mar 30, 2026
7be6e7f
Merge remote-tracking branch 'origin/develop' into custom_sym_lib
shivadbhavsar Mar 30, 2026
1274d3a
address review comments
shivadbhavsar Mar 30, 2026
5b28774
Merge remote-tracking branch 'origin/custom_sym_lib' into sym_dim_int…
shivadbhavsar Mar 30, 2026
83da044
license
shivadbhavsar Mar 30, 2026
d680d55
reduce complexity
shivadbhavsar Mar 30, 2026
8d5629b
Merge remote-tracking branch 'origin/custom_sym_lib' into sym_dim_int…
shivadbhavsar Mar 30, 2026
e7ca1d6
update calls to eval_uint
shivadbhavsar Mar 30, 2026
54debb5
clean up test file
shivadbhavsar Mar 31, 2026
c7f698c
review comments
shivadbhavsar Mar 31, 2026
149b661
Merge remote-tracking branch 'origin/custom_sym_lib' into sym_dim_int…
shivadbhavsar Mar 31, 2026
293b5d8
merge and tidy
shivadbhavsar Mar 31, 2026
cc521e4
license
shivadbhavsar Mar 31, 2026
3b7077a
Merge remote-tracking branch 'origin/custom_sym_lib' into sym_dim_int…
shivadbhavsar Mar 31, 2026
47ec30a
fix style
shivadbhavsar Apr 1, 2026
59a7ef4
normalize fixed dynamic dim representation
shivadbhavsar Apr 1, 2026
95894db
fmt
shivadbhavsar Apr 1, 2026
be87f71
fix serialization and normalization
shivadbhavsar Apr 2, 2026
78f06c7
license
shivadbhavsar Apr 2, 2026
71d7ae7
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 6, 2026
1a68619
address review comments
shivadbhavsar Apr 7, 2026
2044073
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 7, 2026
d6c8d49
update tests for cleaned up intersection logic
shivadbhavsar Apr 7, 2026
4556366
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 8, 2026
378e3d7
review feedback updates
shivadbhavsar Apr 9, 2026
5412e60
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 9, 2026
0bf5680
fix tidy
shivadbhavsar Apr 10, 2026
4fa771a
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 10, 2026
2545a8e
fix callsite to remove disambiguity
shivadbhavsar Apr 10, 2026
e904332
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 11, 2026
c8b8df4
fix merge
shivadbhavsar Apr 11, 2026
4e0da9c
remove optional from sym_expr
shivadbhavsar Apr 13, 2026
ca454b2
refactor how dyn dim intervals are stored and accessed
shivadbhavsar Apr 14, 2026
29ca4d5
add defaults
shivadbhavsar Apr 14, 2026
a0026ba
add getter for optimals and update callsites
shivadbhavsar Apr 14, 2026
1db31d6
update to use get_interval() and remove min() and max()
shivadbhavsar Apr 14, 2026
c128adc
fix cppcheck
shivadbhavsar Apr 15, 2026
3b2c259
return optimals by value
shivadbhavsar Apr 15, 2026
49d7aa6
Merge remote-tracking branch 'origin/develop' into dyn_interval_refactor
shivadbhavsar Apr 15, 2026
cc74c4a
update has_optimal
shivadbhavsar Apr 15, 2026
44cd175
symbolic dimension integration (squashed)
shivadbhavsar Apr 15, 2026
ae322fc
update implementation to work on top of inverval refactor
shivadbhavsar Apr 15, 2026
7b5484e
Merge branch 'sym_dim_integration' of https://github.com/ROCm/AMDMIGr…
shivadbhavsar Apr 15, 2026
39e0442
Merge remote-tracking branch 'origin/develop' into sym_dim_integration
shivadbhavsar Apr 15, 2026
5be3d69
fix old constructor
shivadbhavsar Apr 15, 2026
093964c
fix ambiguous call
shivadbhavsar Apr 16, 2026
72f00ea
fix cppcheck
shivadbhavsar Apr 16, 2026
812a1b1
add missing comment blocks
shivadbhavsar Apr 17, 2026
8573cfb
clearly state assumptions used when dealing with stride permutations …
shivadbhavsar Apr 21, 2026
718e599
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 21, 2026
d94f367
add missing tests
shivadbhavsar Apr 23, 2026
1747639
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 23, 2026
3e0e2c2
make var bounds non-optional and add deprecation TODO to clarify the …
shivadbhavsar Apr 24, 2026
4bbb2e6
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 24, 2026
94c7941
add scalar variant to ease merging 4782
shivadbhavsar Apr 24, 2026
a64a0d1
fix brace-init ambiguity for sles
shivadbhavsar Apr 24, 2026
c4a33e1
wrap scalar variant to handle ambuigity
shivadbhavsar Apr 25, 2026
3887e3f
use migraphx_requires and fix tidy
shivadbhavsar Apr 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,18 @@ struct batch_slicer
batch_slicer(const shape& mat_shape)
{
auto n_batch_dims = mat_shape.ndim() - 2;
inner_shape = shape{mat_shape.type(),
{mat_shape.lens().end() - 2, mat_shape.lens().end()},
{mat_shape.strides().end() - 2, mat_shape.strides().end()}};
inner_shape = shape{
mat_shape.type(),
std::vector<std::size_t>{mat_shape.lens().end() - 2, mat_shape.lens().end()},
std::vector<std::size_t>{mat_shape.strides().end() - 2, mat_shape.strides().end()}};
if(n_batch_dims > 0)
{
outer_shape =
shape{mat_shape.type(),
{mat_shape.lens().begin(), mat_shape.lens().begin() + n_batch_dims},
{mat_shape.strides().begin(), mat_shape.strides().begin() + n_batch_dims}};
std::vector<std::size_t>{mat_shape.lens().begin(),
mat_shape.lens().begin() + n_batch_dims},
std::vector<std::size_t>{mat_shape.strides().begin(),
mat_shape.strides().begin() + n_batch_dims}};
}
}

Expand Down
101 changes: 76 additions & 25 deletions src/include/migraphx/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include <migraphx/bf16.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/sym.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
Expand Down Expand Up @@ -112,42 +113,77 @@ struct MIGRAPHX_EXPORT shape
friend bool operator!=(const interval& a, const interval& b) { return not(a == b); }
};

interval range = {0, 0};
std::set<std::size_t> optimals{};
std::optional<interval> range;
std::optional<std::set<std::size_t>> optimals;
sym::expr sym_expr;

dynamic_dimension() = default;
dynamic_dimension(std::size_t min_v, std::size_t max_v) : range{min_v, max_v} {}
dynamic_dimension(std::size_t min_v, std::size_t max_v)
: range{interval{min_v, max_v}}, optimals{std::set<std::size_t>{}}
{
}
dynamic_dimension(std::size_t min_v, std::size_t max_v, std::set<std::size_t> opt)
: range{min_v, max_v}, optimals(std::move(opt))
: range{interval{min_v, max_v}},
optimals(min_v == max_v ? std::set<std::size_t>{} : std::move(opt))
{
}
dynamic_dimension(sym::expr s) : sym_expr(std::move(s))
{
if(sym_expr.empty())
MIGRAPHX_THROW(
"dynamic_dimension: cannot construct from an empty symbolic expression");
}

template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.range, "range"), f(self.optimals, "optimals"));
return pack(
f(self.range, "range"), f(self.optimals, "optimals"), f(self.sym_expr, "sym"));
}

interval get_interval() const { return range; }
std::set<std::size_t> get_optimals() const { return optimals; }
interval get_interval() const
{
if(is_symbolic())
{
auto ival = sym_expr.eval_interval();
if(ival.min < 0 or ival.max < 0)
Comment thread
shivadbhavsar marked this conversation as resolved.
Outdated
MIGRAPHX_THROW("dynamic_dimension: symbolic expression has negative bounds");
return {static_cast<std::size_t>(ival.min), static_cast<std::size_t>(ival.max)};
Comment thread
shivadbhavsar marked this conversation as resolved.
Outdated
}
return *range;
}
std::set<std::size_t> get_optimals() const
{
if(is_symbolic())
return sym_expr.eval_optimals();
if(optimals.has_value())
return *optimals;
return {};
}

bool is_fixed() const;
bool is_symbolic() const { return not sym_expr.empty(); }
bool has_optimal() const;

/**
* Return a dynamic_dimension with the intersection of two dynamic_dimension ranges if
* possible.
* possible. When both dimensions are symbolic, they are compatible only if they
* share the same symbolic expression.
*/
std::optional<dynamic_dimension> intersection(const dynamic_dimension& other) const
{
if(this->is_symbolic() and other.is_symbolic())
{
if(this->sym_expr == other.sym_expr)
return *this;
return nullopt;
}
auto this_interval = this->get_interval();
auto other_interval = other.get_interval();
auto left = std::max(this_interval.min, other_interval.min);
auto right = std::min(this_interval.max, other_interval.max);
if(left <= right)
{
return dynamic_dimension{left, right};
}
return nullopt;
}

Expand All @@ -164,20 +200,24 @@ struct MIGRAPHX_EXPORT shape
MIGRAPHX_EXPORT friend bool operator!=(const dynamic_dimension& x, const std::size_t& y);
MIGRAPHX_EXPORT friend bool operator!=(const std::size_t& x, const dynamic_dimension& y);

// add, subtract, multiply fixed std::size_t dimension
dynamic_dimension& operator+=(const std::size_t& x);
dynamic_dimension& operator-=(const std::size_t& x);
dynamic_dimension& operator*=(const std::size_t& x);
MIGRAPHX_EXPORT friend dynamic_dimension operator+(const dynamic_dimension& x,
const std::size_t& y);
MIGRAPHX_EXPORT friend dynamic_dimension operator+(const std::size_t& x,
const dynamic_dimension& y);
MIGRAPHX_EXPORT friend dynamic_dimension operator-(const dynamic_dimension& x,
const std::size_t& y);
MIGRAPHX_EXPORT friend dynamic_dimension operator*(const dynamic_dimension& x,
const std::size_t& y);
MIGRAPHX_EXPORT friend dynamic_dimension operator*(const std::size_t& x,
const dynamic_dimension& y);
// clang-format off
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(binary_op, assign_op) \
dynamic_dimension& operator assign_op(const dynamic_dimension& x); \
dynamic_dimension& operator assign_op(const std::size_t& x); \
MIGRAPHX_EXPORT friend dynamic_dimension operator binary_op( \
const dynamic_dimension& x, const dynamic_dimension& y); \
MIGRAPHX_EXPORT friend dynamic_dimension operator binary_op( \
const dynamic_dimension& x, const std::size_t& y); \
MIGRAPHX_EXPORT friend dynamic_dimension operator binary_op( \
const std::size_t& x, const dynamic_dimension& y);

MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(+, +=)
MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(-, -=)
MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(*, *=)
MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(/, /=)
#undef MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP
// clang-format on
};

static std::string to_sizes_string(const std::vector<shape>& shapes);
Expand All @@ -202,8 +242,10 @@ struct MIGRAPHX_EXPORT shape
// Force all calls of the format `shape( type_t, { size_t compatibles } )` to map to
// shape(type_t, std::vector<std::size_t> l)
shape(type_t t, std::initializer_list<std::size_t> d);
shape(type_t t, std::initializer_list<std::size_t> l, std::initializer_list<std::size_t> s);

shape(type_t t, std::vector<dynamic_dimension> dims);
shape(type_t t, std::vector<dynamic_dimension> dims, std::vector<sym::expr> dstrides);

// Construct a dynamic shape from vectors of mins, maxes, and optimals.
// optimals_list is a vector of optimals that corresponds to each min and max.
Expand Down Expand Up @@ -242,6 +284,9 @@ struct MIGRAPHX_EXPORT shape
*/
static shape
from_permutation(type_t t, const std::vector<std::size_t>& l, const std::vector<int64_t>& perm);
static shape from_permutation(type_t t,
const std::vector<dynamic_dimension>& dds,
const std::vector<int64_t>& perm);

type_t type() const;
const std::vector<std::size_t>& lens() const;
Expand Down Expand Up @@ -272,6 +317,9 @@ struct MIGRAPHX_EXPORT shape

const std::vector<dynamic_dimension>& dyn_dims() const;

bool symbolic() const;
Comment thread
CharlieL7 marked this conversation as resolved.
const std::vector<sym::expr>& dyn_strides() const;

/*!
* Minimum lengths for dynamic shape.
* lens() for static shape.
Expand Down Expand Up @@ -388,14 +436,17 @@ struct MIGRAPHX_EXPORT shape

shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
shape with_lens(const std::vector<std::size_t>& l) const;
shape with_lens(type_t t, const std::vector<dynamic_dimension>& dds) const;
shape with_lens(const std::vector<dynamic_dimension>& dds) const;

shape with_type(type_t t) const;

// convert the shape to an equivalent dynamic shape with empty optimals
// convert the shape to an equivalent dynamic shape with constant symbolic strides
shape to_dynamic() const;

// convert the shape to a static one setting any non-fixed dynamic_dimensions to x
shape to_static(std::size_t x) const;
shape to_static(const std::unordered_map<sym::expr, std::size_t>& symbol_map) const;

MIGRAPHX_EXPORT friend bool operator==(const shape& x, const shape& y);
MIGRAPHX_EXPORT friend bool operator!=(const shape& x, const shape& y);
Expand Down
26 changes: 24 additions & 2 deletions src/include/migraphx/sym.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <cstdint>
#include <memory>
#include <ostream>
#include <set>
#include <string>
#include <unordered_map>

Expand All @@ -40,8 +41,21 @@ struct value;

namespace sym {

struct interval
{
int64_t min = 0;
int64_t max = 0;
Comment thread
shivadbhavsar marked this conversation as resolved.
Outdated
friend bool operator==(const interval& a, const interval& b)
{
return a.min == b.min and a.max == b.max;
}
friend bool operator!=(const interval& a, const interval& b) { return not(a == b); }
};

struct expr;
MIGRAPHX_EXPORT expr var(const std::string& name);
MIGRAPHX_EXPORT expr var(const std::string& name,
interval bounds = {1, 1},
std::set<int64_t> optimals = {});
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having the bounds and optimals in the symbolic variable feels wrong to me. The bounds should be supplied from the interval?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm seeing the bigger picture now. This is expecting that an expression can have variables that each have different intervals. Making it different from the top-level interval in the dynamic_dimension. I'm thinking there should be a map between variables and their intervals stored in the program or module itself.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought I had put a comment somewhere regarding this (maybe in the shape implementation). The big picture is that symbolic shapes will not even have a top level interval in dynamic_dimension. Its implemented as optional because it should only exist for range-based shapes for backward compatibility (and any other reason we might want a pure range-based dynamic shape). For symbolic shapes, when the interval is needed, it will be computed on the fly based on the symbols in the symbolic expression for that shape

MIGRAPHX_EXPORT expr lit(int64_t n);
MIGRAPHX_EXPORT expr parse(const std::string& s);

Expand All @@ -50,11 +64,14 @@ struct MIGRAPHX_EXPORT expr
expr();

bool empty() const;
bool is_literal() const;
std::size_t hash() const;
std::string to_string() const;
value to_value() const;
void from_value(const value& v);
std::size_t eval_uint(const std::unordered_map<expr, std::size_t>& symbol_map) const;
interval eval_interval() const;
std::set<std::size_t> eval_optimals() const;
expr subs(const std::unordered_map<expr, expr>& symbol_map) const;

MIGRAPHX_EXPORT friend expr operator+(const expr& a, const expr& b);
Expand All @@ -63,6 +80,10 @@ struct MIGRAPHX_EXPORT expr
MIGRAPHX_EXPORT friend expr operator/(const expr& a, const expr& b);
MIGRAPHX_EXPORT friend bool operator==(const expr& a, const expr& b);
MIGRAPHX_EXPORT friend bool operator!=(const expr& a, const expr& b);
MIGRAPHX_EXPORT friend bool operator<(const expr& a, const expr& b);
friend bool operator>(const expr& a, const expr& b) { return b < a; }
friend bool operator<=(const expr& a, const expr& b) { return not(b < a); }
friend bool operator>=(const expr& a, const expr& b) { return not(a < b); }
MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const expr& e);

friend expr operator+(const expr& a, int64_t b) { return a + lit(b); }
Expand All @@ -76,7 +97,8 @@ struct MIGRAPHX_EXPORT expr

struct impl;

MIGRAPHX_EXPORT friend expr var(const std::string& name);
MIGRAPHX_EXPORT friend expr
var(const std::string& name, interval bounds, std::set<int64_t> optimals);
MIGRAPHX_EXPORT friend expr lit(int64_t n);
MIGRAPHX_EXPORT friend expr parse(const std::string& s);

Expand Down
29 changes: 24 additions & 5 deletions src/permutation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <migraphx/permutation.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/sym.hpp>
#include <map>
#include <functional>

Expand All @@ -33,6 +34,10 @@ inline namespace MIGRAPHX_INLINE_NS {

shape reorder_shape(const shape& s, const std::vector<int64_t>& permutation)
{
if(s.symbolic())
return {s.type(),
reorder_dims(s.dyn_dims(), permutation),
reorder_dims(s.dyn_strides(), permutation)};
return {s.type(), reorder_dims(s.lens(), permutation), reorder_dims(s.strides(), permutation)};
}

Expand All @@ -43,11 +48,25 @@ std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)

std::vector<int64_t> find_permutation(const shape& s)
{
std::vector<std::int64_t> result(s.lens().size());
if(s.dynamic() and not s.symbolic())
MIGRAPHX_THROW("FIND_PERMUTATION: non-symbolic dynamic shapes not supported");
std::vector<std::int64_t> result(s.ndim());
std::iota(result.begin(), result.end(), 0);
std::stable_sort(result.begin(), result.end(), by(std::greater<>{}, [&](auto x) {
return std::make_tuple(s.strides()[x], s.lens()[x]);
}));
if(s.symbolic())
{
const auto& strides = s.dyn_strides();
const auto& dds = s.dyn_dims();
std::stable_sort(result.begin(), result.end(), by(std::greater<>{}, [&](auto x) {
return std::make_tuple(strides[x].eval_interval().max,
dds[x].sym_expr.eval_interval().max);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this is finding the permutation by evaluating out the intervals and then ordering the symbolic expressions using those. Is this why the symbolic variable needs the interval? This feels incorrect to me and instead if we have the same assumptions of monotonic increasing functions, evaluating any random interval should work?

Copy link
Copy Markdown
Contributor Author

@shivadbhavsar shivadbhavsar Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i had a large comment block on is_sorted_strides to address this very thing but somehow it got removed when i was consolidating changes after the interval refactor. It should be there now.
Root of the issue is that ideally we should be sorting the expressions (using intervals), but when you have variables with min=1, eg. n{1, 4}, c{1,8}, you cant technically say n < nc, because n == nc at their min values. That causes all sorts of problems when trying to sort these using intervals and so the simplest pragmatic approach i could find was to simply use the interval max for stride sorting.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure the idea of permutations makes sense for symbolic shapes. Take this contrived example:
strides = {n , (n*c) / d}, n = {1, 2}, c = {1, 2}, d = {1,2}
for n = 1, c = 2, d = 1: strides = {1, 2} transposed
for n = 2, c = 1, d = 2: strides = {2, 1} standard

Maybe instead the order for the static shape must be determined at runtime unless something like this holds : stride_0 - stride_1 < 0 where one symbolic stride is strictly less than other other for positive, non-zero values.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why would a stride expression have a division by a non constant symbol? Is there any op that would actually create such a stride?
The reason I think we want to support the permutation functions for sym shapes is because some optimizations and shape ops will not work without this (or will need to implement other logic in a different way to get it to work).

Regardless, I got claude to try and create breaking examples for this using stride manipulating shape ops (like transpose, slice, step, broadcasts, reshape), and the only way it was able to come up with an ambiguous case like this is if the input bounds are just really badly defined.

Eg.:

Step + transpose with degenerate bounds:

Input [a, b], b ∈ [2, 16], standard strides [b, 1].

After step(axis=1, step=8): dims [a, ceil(b/8)], strides [b, 8].

After transpose [1, 0]: strides [8, b].

At b=2: strides [8, 2] → perm [0, 1]
At b=16: strides [8, 16] → perm [1, 0]

Here the input bounds dont really make sense considering the step value, and it would even give an incorrect result for the static shape case if compiling the model for something like [2,2] input shape.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay, updated to be more clear about assumptions being made, and explicitly throw when intervals cause contradictory results at boundaries

Comment thread
shivadbhavsar marked this conversation as resolved.
Outdated
}));
}
else
{
std::stable_sort(result.begin(), result.end(), by(std::greater<>{}, [&](auto x) {
return std::make_tuple(s.strides()[x], s.lens()[x]);
}));
}
return result;
}

Expand All @@ -64,7 +83,7 @@ std::vector<int64_t> find_permutation(const std::vector<shape>& shapes)
}
if(count.empty())
{
std::vector<int64_t> r(shapes.front().lens().size());
std::vector<int64_t> r(shapes.front().ndim());
std::iota(r.begin(), r.end(), 0);
return r;
}
Expand Down
Loading
Loading