Skip to content
Merged
Show file tree
Hide file tree
Changes from 87 commits
Commits
Show all changes
88 commits
Select commit Hold shift + click to select a range
d2b684e
custom symbolic expression lib
shivadbhavsar Mar 24, 2026
aa55785
format
shivadbhavsar Mar 24, 2026
314f7cf
use visit
shivadbhavsar Mar 24, 2026
dcfe825
format
shivadbhavsar Mar 24, 2026
2ec0969
integrate symbolic expression in dynamic_dimension
shivadbhavsar Mar 25, 2026
18caf6b
tidy
shivadbhavsar Mar 25, 2026
483932b
Merge branch 'custom_sym_lib' into sym_dim_integration
shivadbhavsar Mar 25, 2026
200f3c4
Merge branch 'develop' into custom_sym_lib
shivadbhavsar Mar 25, 2026
7ff2045
Merge branch 'custom_sym_lib' into sym_dim_integration
shivadbhavsar Mar 25, 2026
6af3621
fix constructor ambiguity
shivadbhavsar Mar 25, 2026
b7d7c23
fix ambiguity
shivadbhavsar Mar 25, 2026
3486135
update namespace and interface design
shivadbhavsar Mar 26, 2026
edbce87
Merge branch 'develop' into custom_sym_lib
shivadbhavsar Mar 26, 2026
964f934
use int64 for literals
shivadbhavsar Mar 26, 2026
2719ae6
Merge branch 'custom_sym_lib' into sym_dim_integration
shivadbhavsar Mar 26, 2026
364bd23
fix merge
shivadbhavsar Mar 26, 2026
33614e0
change eval func name
shivadbhavsar Mar 26, 2026
2ba3b74
Merge branch 'custom_sym_lib' into sym_dim_integration
shivadbhavsar Mar 26, 2026
830594f
use int64 for internal eval
shivadbhavsar Mar 26, 2026
bd70d84
Merge branch 'custom_sym_lib' into sym_dim_integration
shivadbhavsar Mar 26, 2026
def3038
fix eval call
shivadbhavsar Mar 26, 2026
359070d
copilot comments
shivadbhavsar Mar 26, 2026
9ad996f
copilot review fix
shivadbhavsar Mar 26, 2026
b61b6d8
Merge branch 'custom_sym_lib' into sym_dim_integration
shivadbhavsar Mar 26, 2026
bda9f91
format and tidy
shivadbhavsar Mar 26, 2026
5027376
Merge branch 'custom_sym_lib' into sym_dim_integration
shivadbhavsar Mar 26, 2026
1209717
Merge branch 'sym_dim_integration' of https://github.com/ROCm/AMDMIGr…
shivadbhavsar Mar 26, 2026
003c9d3
tidy fix
shivadbhavsar Mar 30, 2026
3759299
tidy
shivadbhavsar Mar 30, 2026
50944fb
Merge remote-tracking branch 'origin/custom_sym_lib' into sym_dim_int…
shivadbhavsar Mar 30, 2026
fe649ff
update the only call sites using the braced-init-list that cannot be …
shivadbhavsar Mar 30, 2026
7be6e7f
Merge remote-tracking branch 'origin/develop' into custom_sym_lib
shivadbhavsar Mar 30, 2026
1274d3a
address review comments
shivadbhavsar Mar 30, 2026
5b28774
Merge remote-tracking branch 'origin/custom_sym_lib' into sym_dim_int…
shivadbhavsar Mar 30, 2026
83da044
license
shivadbhavsar Mar 30, 2026
d680d55
reduce complexity
shivadbhavsar Mar 30, 2026
8d5629b
Merge remote-tracking branch 'origin/custom_sym_lib' into sym_dim_int…
shivadbhavsar Mar 30, 2026
e7ca1d6
update calls to eval_uint
shivadbhavsar Mar 30, 2026
54debb5
clean up test file
shivadbhavsar Mar 31, 2026
c7f698c
review comments
shivadbhavsar Mar 31, 2026
149b661
Merge remote-tracking branch 'origin/custom_sym_lib' into sym_dim_int…
shivadbhavsar Mar 31, 2026
293b5d8
merge and tidy
shivadbhavsar Mar 31, 2026
cc521e4
license
shivadbhavsar Mar 31, 2026
3b7077a
Merge remote-tracking branch 'origin/custom_sym_lib' into sym_dim_int…
shivadbhavsar Mar 31, 2026
47ec30a
fix style
shivadbhavsar Apr 1, 2026
59a7ef4
normalize fixed dynamic dim representation
shivadbhavsar Apr 1, 2026
95894db
fmt
shivadbhavsar Apr 1, 2026
be87f71
fix serialization and normalization
shivadbhavsar Apr 2, 2026
78f06c7
license
shivadbhavsar Apr 2, 2026
71d7ae7
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 6, 2026
1a68619
address review comments
shivadbhavsar Apr 7, 2026
2044073
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 7, 2026
d6c8d49
update tests for cleaned up intersection logic
shivadbhavsar Apr 7, 2026
4556366
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 8, 2026
378e3d7
review feedback updates
shivadbhavsar Apr 9, 2026
5412e60
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 9, 2026
0bf5680
fix tidy
shivadbhavsar Apr 10, 2026
4fa771a
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 10, 2026
2545a8e
fix callsite to remove disambiguity
shivadbhavsar Apr 10, 2026
e904332
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 11, 2026
c8b8df4
fix merge
shivadbhavsar Apr 11, 2026
4e0da9c
remove optional from sym_expr
shivadbhavsar Apr 13, 2026
ca454b2
refactor how dyn dim intervals are stored and accessed
shivadbhavsar Apr 14, 2026
29ca4d5
add defaults
shivadbhavsar Apr 14, 2026
a0026ba
add getter for optimals and update callsites
shivadbhavsar Apr 14, 2026
1db31d6
update to use get_interval() and remove min() and max()
shivadbhavsar Apr 14, 2026
c128adc
fix cppcheck
shivadbhavsar Apr 15, 2026
3b2c259
return optimals by value
shivadbhavsar Apr 15, 2026
49d7aa6
Merge remote-tracking branch 'origin/develop' into dyn_interval_refactor
shivadbhavsar Apr 15, 2026
cc74c4a
update has_optimal
shivadbhavsar Apr 15, 2026
44cd175
symbolic dimension integration (squashed)
shivadbhavsar Apr 15, 2026
ae322fc
update implementation to work on top of inverval refactor
shivadbhavsar Apr 15, 2026
7b5484e
Merge branch 'sym_dim_integration' of https://github.com/ROCm/AMDMIGr…
shivadbhavsar Apr 15, 2026
39e0442
Merge remote-tracking branch 'origin/develop' into sym_dim_integration
shivadbhavsar Apr 15, 2026
5be3d69
fix old constructor
shivadbhavsar Apr 15, 2026
093964c
fix ambiguous call
shivadbhavsar Apr 16, 2026
72f00ea
fix cppcheck
shivadbhavsar Apr 16, 2026
812a1b1
add missing comment blocks
shivadbhavsar Apr 17, 2026
8573cfb
clearly state assumptions used when dealing with stride permutations …
shivadbhavsar Apr 21, 2026
718e599
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 21, 2026
d94f367
add missing tests
shivadbhavsar Apr 23, 2026
1747639
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 23, 2026
3e0e2c2
make var bounds non-optional and add deprecation TODO to clarify the …
shivadbhavsar Apr 24, 2026
4bbb2e6
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 24, 2026
94c7941
add scalar variant to ease merging 4782
shivadbhavsar Apr 24, 2026
a64a0d1
fix brace-init ambiguity for sles
shivadbhavsar Apr 24, 2026
c4a33e1
wrap scalar variant to handle ambuigity
shivadbhavsar Apr 25, 2026
3887e3f
use migraphx_requires and fix tidy
shivadbhavsar Apr 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,18 @@ struct batch_slicer
batch_slicer(const shape& mat_shape)
{
auto n_batch_dims = mat_shape.ndim() - 2;
inner_shape = shape{mat_shape.type(),
{mat_shape.lens().end() - 2, mat_shape.lens().end()},
{mat_shape.strides().end() - 2, mat_shape.strides().end()}};
inner_shape = shape{
mat_shape.type(),
std::vector<std::size_t>{mat_shape.lens().end() - 2, mat_shape.lens().end()},
std::vector<std::size_t>{mat_shape.strides().end() - 2, mat_shape.strides().end()}};
if(n_batch_dims > 0)
{
outer_shape =
shape{mat_shape.type(),
{mat_shape.lens().begin(), mat_shape.lens().begin() + n_batch_dims},
{mat_shape.strides().begin(), mat_shape.strides().begin() + n_batch_dims}};
std::vector<std::size_t>{mat_shape.lens().begin(),
mat_shape.lens().begin() + n_batch_dims},
std::vector<std::size_t>{mat_shape.strides().begin(),
mat_shape.strides().begin() + n_batch_dims}};
}
}

Expand Down
107 changes: 82 additions & 25 deletions src/include/migraphx/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include <migraphx/bf16.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/sym.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
Expand Down Expand Up @@ -94,6 +95,13 @@ struct MIGRAPHX_EXPORT shape
{
};

// TODO: Deprecate the pure range-based form of dynamic_dimension in favor
// of the symbolic form (sym_expr). The current design carries two parallel
// notions of bounds -- dynamic_dimension::interval (std::size_t min/max,
// here) and sym::interval (int64_t min/max, attached to each sym::var) --
// which is a source of confusion. Once all shape-producing paths go through
// symbolic expressions, `range`/`optimals` and this nested `interval` can
// be removed and bounds will live solely on sym::var.
struct MIGRAPHX_EXPORT dynamic_dimension
{
struct interval
Expand All @@ -112,42 +120,76 @@ struct MIGRAPHX_EXPORT shape
friend bool operator!=(const interval& a, const interval& b) { return not(a == b); }
};

interval range = {0, 0};
std::set<std::size_t> optimals{};
std::optional<interval> range;
std::optional<std::set<std::size_t>> optimals;
sym::expr sym_expr;

dynamic_dimension() = default;
dynamic_dimension(std::size_t min_v, std::size_t max_v) : range{min_v, max_v} {}
dynamic_dimension(std::size_t min_v, std::size_t max_v)
: range{interval{min_v, max_v}}, optimals{std::set<std::size_t>{}}
{
}
dynamic_dimension(std::size_t min_v, std::size_t max_v, std::set<std::size_t> opt)
: range{min_v, max_v}, optimals(std::move(opt))
: range{interval{min_v, max_v}},
optimals(min_v == max_v ? std::set<std::size_t>{} : std::move(opt))
{
}
dynamic_dimension(sym::expr s) : sym_expr(std::move(s))
{
if(sym_expr.empty())
MIGRAPHX_THROW(
"dynamic_dimension: cannot construct from an empty symbolic expression");
}

template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.range, "range"), f(self.optimals, "optimals"));
return pack(
f(self.range, "range"), f(self.optimals, "optimals"), f(self.sym_expr, "sym"));
}

interval get_interval() const { return range; }
std::set<std::size_t> get_optimals() const { return optimals; }
interval get_interval() const
{
if(is_symbolic())
{
auto ival = sym_expr.eval_interval();
assert(sym::to<int64_t>(ival.min) >= 0 and sym::to<int64_t>(ival.max) >= 0);
return {sym::to<std::size_t>(ival.min), sym::to<std::size_t>(ival.max)};
}
return *range;
}
std::set<std::size_t> get_optimals() const
{
if(is_symbolic())
return sym_expr.eval_optimals();
if(optimals.has_value())
return *optimals;
return {};
}

bool is_fixed() const;
bool is_symbolic() const { return not sym_expr.empty(); }
bool has_optimal() const;

/**
* Return a dynamic_dimension with the intersection of two dynamic_dimension ranges if
* possible.
* possible. When both dimensions are symbolic, they are compatible only if they
* share the same symbolic expression.
*/
std::optional<dynamic_dimension> intersection(const dynamic_dimension& other) const
{
if(this->is_symbolic() and other.is_symbolic())
{
if(this->sym_expr == other.sym_expr)
return *this;
return nullopt;
}
auto this_interval = this->get_interval();
auto other_interval = other.get_interval();
auto left = std::max(this_interval.min, other_interval.min);
auto right = std::min(this_interval.max, other_interval.max);
if(left <= right)
{
return dynamic_dimension{left, right};
}
return nullopt;
}

Expand All @@ -164,20 +206,24 @@ struct MIGRAPHX_EXPORT shape
MIGRAPHX_EXPORT friend bool operator!=(const dynamic_dimension& x, const std::size_t& y);
MIGRAPHX_EXPORT friend bool operator!=(const std::size_t& x, const dynamic_dimension& y);

// add, subtract, multiply fixed std::size_t dimension
dynamic_dimension& operator+=(const std::size_t& x);
dynamic_dimension& operator-=(const std::size_t& x);
dynamic_dimension& operator*=(const std::size_t& x);
MIGRAPHX_EXPORT friend dynamic_dimension operator+(const dynamic_dimension& x,
const std::size_t& y);
MIGRAPHX_EXPORT friend dynamic_dimension operator+(const std::size_t& x,
const dynamic_dimension& y);
MIGRAPHX_EXPORT friend dynamic_dimension operator-(const dynamic_dimension& x,
const std::size_t& y);
MIGRAPHX_EXPORT friend dynamic_dimension operator*(const dynamic_dimension& x,
const std::size_t& y);
MIGRAPHX_EXPORT friend dynamic_dimension operator*(const std::size_t& x,
const dynamic_dimension& y);
// clang-format off
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(binary_op, assign_op) \
dynamic_dimension& operator assign_op(const dynamic_dimension& x); \
dynamic_dimension& operator assign_op(const std::size_t& x); \
MIGRAPHX_EXPORT friend dynamic_dimension operator binary_op( \
const dynamic_dimension& x, const dynamic_dimension& y); \
MIGRAPHX_EXPORT friend dynamic_dimension operator binary_op( \
const dynamic_dimension& x, const std::size_t& y); \
MIGRAPHX_EXPORT friend dynamic_dimension operator binary_op( \
const std::size_t& x, const dynamic_dimension& y);

MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(+, +=)
MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(-, -=)
MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(*, *=)
MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP(/, /=)
#undef MIGRAPHX_SHAPE_DYN_DIM_DEFINE_OP
// clang-format on
};

static std::string to_sizes_string(const std::vector<shape>& shapes);
Expand All @@ -202,8 +248,10 @@ struct MIGRAPHX_EXPORT shape
// Force all calls of the format `shape( type_t, { size_t compatibles } )` to map to
// shape(type_t, std::vector<std::size_t> l)
shape(type_t t, std::initializer_list<std::size_t> d);
shape(type_t t, std::initializer_list<std::size_t> l, std::initializer_list<std::size_t> s);

shape(type_t t, std::vector<dynamic_dimension> dims);
shape(type_t t, std::vector<dynamic_dimension> dims, std::vector<sym::expr> dstrides);

// Construct a dynamic shape from vectors of mins, maxes, and optimals.
// optimals_list is a vector of optimals that corresponds to each min and max.
Expand Down Expand Up @@ -242,6 +290,9 @@ struct MIGRAPHX_EXPORT shape
*/
static shape
from_permutation(type_t t, const std::vector<std::size_t>& l, const std::vector<int64_t>& perm);
static shape from_permutation(type_t t,
const std::vector<dynamic_dimension>& dds,
const std::vector<int64_t>& perm);

type_t type() const;
const std::vector<std::size_t>& lens() const;
Expand Down Expand Up @@ -272,6 +323,9 @@ struct MIGRAPHX_EXPORT shape

const std::vector<dynamic_dimension>& dyn_dims() const;

bool symbolic() const;
Comment thread
CharlieL7 marked this conversation as resolved.
const std::vector<sym::expr>& dyn_strides() const;

/*!
* Minimum lengths for dynamic shape.
* lens() for static shape.
Expand Down Expand Up @@ -388,14 +442,17 @@ struct MIGRAPHX_EXPORT shape

shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
shape with_lens(const std::vector<std::size_t>& l) const;
shape with_lens(type_t t, const std::vector<dynamic_dimension>& dds) const;
shape with_lens(const std::vector<dynamic_dimension>& dds) const;

shape with_type(type_t t) const;

// convert the shape to an equivalent dynamic shape with empty optimals
// convert the shape to an equivalent dynamic shape with constant symbolic strides
shape to_dynamic() const;

// convert the shape to a static one setting any non-fixed dynamic_dimensions to x
shape to_static(std::size_t x) const;
shape to_static(const std::unordered_map<sym::expr, std::size_t>& symbol_map) const;

MIGRAPHX_EXPORT friend bool operator==(const shape& x, const shape& y);
MIGRAPHX_EXPORT friend bool operator!=(const shape& x, const shape& y);
Expand Down
55 changes: 53 additions & 2 deletions src/include/migraphx/sym.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,11 @@
#include <cstdint>
#include <memory>
#include <ostream>
#include <set>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <variant>

#include <migraphx/config.hpp>

Expand All @@ -40,8 +43,48 @@ struct value;

namespace sym {

// Scalar value held by literal expressions and interval bounds. Wraps a
// variant so that integer-literal initialization is unambiguous on stricter
// libstdc++ versions.
struct scalar
{
std::variant<int64_t, double> value;

constexpr scalar() = default;

template <class T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use MIGRAPHX_REQUIRES and write the traits as std::is_integral<T>{} instead of using _v.

constexpr scalar(T v) : value{int64_t{v}} // NOLINT(google-explicit-constructor)
{
}

template <class T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
constexpr scalar(T v) : value{double{v}} // NOLINT(google-explicit-constructor)
{
}

friend bool operator==(const scalar& a, const scalar& b) { return a.value == b.value; }
friend bool operator!=(const scalar& a, const scalar& b) { return not(a == b); }
};

template <class To>
To to(const scalar& v)
{
return std::visit([](auto x) -> To { return x; }, v.value);
}

struct interval
{
scalar min = int64_t{0};
scalar max = int64_t{0};
friend bool operator==(const interval& a, const interval& b)
{
return a.min == b.min and a.max == b.max;
}
friend bool operator!=(const interval& a, const interval& b) { return not(a == b); }
};

struct expr;
MIGRAPHX_EXPORT expr var(const std::string& name);
MIGRAPHX_EXPORT expr var(const std::string& name, interval bounds, std::set<int64_t> optimals = {});
MIGRAPHX_EXPORT expr lit(int64_t n);
MIGRAPHX_EXPORT expr parse(const std::string& s);

Expand All @@ -50,11 +93,14 @@ struct MIGRAPHX_EXPORT expr
expr();

bool empty() const;
bool is_literal() const;
std::size_t hash() const;
std::string to_string() const;
value to_value() const;
void from_value(const value& v);
std::size_t eval_uint(const std::unordered_map<expr, std::size_t>& symbol_map) const;
interval eval_interval() const;
std::set<std::size_t> eval_optimals() const;
expr subs(const std::unordered_map<expr, expr>& symbol_map) const;

MIGRAPHX_EXPORT friend expr operator+(const expr& a, const expr& b);
Expand All @@ -63,6 +109,10 @@ struct MIGRAPHX_EXPORT expr
MIGRAPHX_EXPORT friend expr operator/(const expr& a, const expr& b);
MIGRAPHX_EXPORT friend bool operator==(const expr& a, const expr& b);
MIGRAPHX_EXPORT friend bool operator!=(const expr& a, const expr& b);
MIGRAPHX_EXPORT friend bool operator<(const expr& a, const expr& b);
friend bool operator>(const expr& a, const expr& b) { return b < a; }
friend bool operator<=(const expr& a, const expr& b) { return not(b < a); }
friend bool operator>=(const expr& a, const expr& b) { return not(a < b); }
MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const expr& e);

friend expr operator+(const expr& a, int64_t b) { return a + lit(b); }
Expand All @@ -76,7 +126,8 @@ struct MIGRAPHX_EXPORT expr

struct impl;

MIGRAPHX_EXPORT friend expr var(const std::string& name);
MIGRAPHX_EXPORT friend expr
var(const std::string& name, interval bounds, std::set<int64_t> optimals);
MIGRAPHX_EXPORT friend expr lit(int64_t n);
MIGRAPHX_EXPORT friend expr parse(const std::string& s);

Expand Down
Loading
Loading