Skip to content

[AIMIGRAPHX-835] integrate symbolic expression in dynamic_dimension#4702

Merged
causten merged 88 commits intodevelopfrom
sym_dim_integration
Apr 25, 2026
Merged

[AIMIGRAPHX-835] integrate symbolic expression in dynamic_dimension#4702
causten merged 88 commits intodevelopfrom
sym_dim_integration

Conversation

@shivadbhavsar
Copy link
Copy Markdown
Contributor

@shivadbhavsar shivadbhavsar commented Mar 25, 2026

Motivation

MIGraphX's current dynamic shape system is range-based: each dimension is a dynamic_dimension with {min, max, optimals}. This loses dimension identity -- if two inputs share a batch dimension N, the system just sees [1..8] independently for each. Consequences:

Fusion passes can't prove two dynamic dims are the same, preventing valid fusions
Stride computation is impossible for dynamic shapes (strides depend on actual dimension values)
Multibroadcast must carry all inputs at runtime to determine output shapes
Matchers can't reason about shape relationships symbolically

This PR uses the symbolic library previously added to better define dynamic shapes to solve these issues. It also aims to preserve the current functionality with the pure range based implementation.

Technical Details

dynamic_dimension gains a sym_expr field

A new sym::expr sym_expr member sits alongside the existing optional<interval> range and optional<set<size_t>> optimals. A 1-arg constructor dd{expr} creates symbolic dimensions, and lazy accessors get_interval() / get_optimals() evaluate bounds from sym_expr when symbolic or read from range/optimals otherwise. The is_symbolic() predicate (not sym_expr.empty()) distinguishes the two kinds. Range-based and symbolic dimensions are kept strictly separate: dd{3,3} is range-based, dd{lit(3)} is symbolic, with no auto-promotion between them.

Symbolic arithmetic propagation

The apply_op helper in shape.cpp dispatches arithmetic based on operand types. When both operands are symbolic, the result is symbolic with expressions composed algebraically. When both are range-based, existing range logic applies unchanged. Mixed symbolic/range falls back to range via interval evaluation. Plain size_t operands are promoted to dd{sym::lit(x)}, so dd{n} + 2 stays symbolic while dd{n} + dd{3,3} falls back to range.

Intersection and other dynamic_dimension methods

Updated to handle the symbolic case. Two symbolic dimensions intersect only if they share the same expression; mixed symbolic/range falls back to interval comparison.

Serialization

The reflect() method includes sym_expr so symbolic dimensions round-trip through serialization.

Tests

Shape tests cover symbolic dimension arithmetic, mixed symbolic/range interactions, shape property computation with symbolic dimensions, and operator shape propagation (convolution, pooling, dot) with symbolic inputs.

Changelog Category

Add a CHANGELOG.md entry for any option other than Not Applicable

    • Added: New functionality.
    • Changed: Changes to existing functionality.
    • Removed: Functionality or support that has been removed. (Compared to a previous release)
    • Optimized: Component performance that has been optimized or improved.
    • Resolved Issues: Known issues from a previous version that have been resolved.
    • Not Applicable: This PR is not to be included in the changelog.

@shivadbhavsar shivadbhavsar self-assigned this Mar 25, 2026
@codecov
Copy link
Copy Markdown

codecov Bot commented Mar 26, 2026

Codecov Report

❌ Patch coverage is 95.33469% with 23 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/shape.cpp 93.54% 19 Missing ⚠️
src/sym.cpp 98.43% 2 Missing ⚠️
src/include/migraphx/shape.hpp 96.15% 1 Missing ⚠️
src/permutation.cpp 96.55% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #4702      +/-   ##
===========================================
+ Coverage    92.46%   92.52%   +0.05%     
===========================================
  Files          583      583              
  Lines        29564    29967     +403     
===========================================
+ Hits         27336    27724     +388     
- Misses        2228     2243      +15     
Files with missing lines Coverage Δ
src/gemm.cpp 100.00% <100.00%> (ø)
src/include/migraphx/sym.hpp 100.00% <100.00%> (ø)
src/include/migraphx/shape.hpp 91.36% <96.15%> (+1.53%) ⬆️
src/permutation.cpp 83.05% <96.55%> (+12.46%) ⬆️
src/sym.cpp 97.10% <98.43%> (+1.28%) ⬆️
src/shape.cpp 91.78% <93.54%> (-0.19%) ⬇️

... and 16 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@shivadbhavsar
Copy link
Copy Markdown
Contributor Author

Added features that were originally planned for a second PR (#4760) all in here since splitting it cleanly would be cumbersome. Ignore the approval and please re-review when possible @pfultz2 @CharlieL7 . Also updated the new sym APIs so that integrating with #4782 should be a bit cleaner

Comment on lines +97 to +99
shape batch_shape(s.type(),
std::vector<std::size_t>(s.lens().begin(), s.lens().end() - 2),
std::vector<std::size_t>(s.strides().begin(), s.strides().end() - 2));
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did these need changes with the new initializer_list constructor?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

old libstd in sles kept causing ci issues

Comment thread src/include/migraphx/sym.hpp Outdated
Comment on lines +56 to +58
MIGRAPHX_EXPORT expr var(const std::string& name,
interval bounds = {1, 1},
std::set<int64_t> optimals = {});
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having the bounds and optimals in the symbolic variable feels wrong to me. The bounds should be supplied from the interval?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm seeing the bigger picture now. This is expecting that an expression can have variables that each have different intervals. Making it different from the top-level interval in the dynamic_dimension. I'm thinking there should be a map between variables and their intervals stored in the program or module itself.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought I had put a comment somewhere regarding this (maybe in the shape implementation). The big picture is that symbolic shapes will not even have a top level interval in dynamic_dimension. Its implemented as optional because it should only exist for range-based shapes for backward compatibility (and any other reason we might want a pure range-based dynamic shape). For symbolic shapes, when the interval is needed, it will be computed on the fly based on the symbols in the symbolic expression for that shape

Comment thread src/permutation.cpp Outdated
Comment on lines +60 to +61
return std::make_tuple(strides[x].eval_interval().max,
dds[x].sym_expr.eval_interval().max);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this is finding the permutation by evaluating out the intervals and then ordering the symbolic expressions using those. Is this why the symbolic variable needs the interval? This feels incorrect to me and instead if we have the same assumptions of monotonic increasing functions, evaluating any random interval should work?

Copy link
Copy Markdown
Contributor Author

@shivadbhavsar shivadbhavsar Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i had a large comment block on is_sorted_strides to address this very thing but somehow it got removed when i was consolidating changes after the interval refactor. It should be there now.
Root of the issue is that ideally we should be sorting the expressions (using intervals), but when you have variables with min=1, eg. n{1, 4}, c{1,8}, you cant technically say n < nc, because n == nc at their min values. That causes all sorts of problems when trying to sort these using intervals and so the simplest pragmatic approach i could find was to simply use the interval max for stride sorting.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure the idea of permutations makes sense for symbolic shapes. Take this contrived example:
strides = {n , (n*c) / d}, n = {1, 2}, c = {1, 2}, d = {1,2}
for n = 1, c = 2, d = 1: strides = {1, 2} transposed
for n = 2, c = 1, d = 2: strides = {2, 1} standard

Maybe instead the order for the static shape must be determined at runtime unless something like this holds : stride_0 - stride_1 < 0 where one symbolic stride is strictly less than other other for positive, non-zero values.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why would a stride expression have a division by a non constant symbol? Is there any op that would actually create such a stride?
The reason I think we want to support the permutation functions for sym shapes is because some optimizations and shape ops will not work without this (or will need to implement other logic in a different way to get it to work).

Regardless, I got claude to try and create breaking examples for this using stride manipulating shape ops (like transpose, slice, step, broadcasts, reshape), and the only way it was able to come up with an ambiguous case like this is if the input bounds are just really badly defined.

Eg.:

Step + transpose with degenerate bounds:

Input [a, b], b ∈ [2, 16], standard strides [b, 1].

After step(axis=1, step=8): dims [a, ceil(b/8)], strides [b, 8].

After transpose [1, 0]: strides [8, b].

At b=2: strides [8, 2] → perm [0, 1]
At b=16: strides [8, 16] → perm [1, 0]

Here the input bounds dont really make sense considering the step value, and it would even give an incorrect result for the static shape case if compiling the model for something like [2,2] input shape.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay, updated to be more clear about assumptions being made, and explicitly throw when intervals cause contradictory results at boundaries

@shivadbhavsar shivadbhavsar requested a review from CharlieL7 April 20, 2026 15:36
Copy link
Copy Markdown
Collaborator

@CharlieL7 CharlieL7 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm seeing the

Comment thread src/permutation.cpp Outdated
Comment on lines +60 to +61
return std::make_tuple(strides[x].eval_interval().max,
dds[x].sym_expr.eval_interval().max);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure the idea of permutations makes sense for symbolic shapes. Take this contrived example:
strides = {n , (n*c) / d}, n = {1, 2}, c = {1, 2}, d = {1,2}
for n = 1, c = 2, d = 1: strides = {1, 2} transposed
for n = 2, c = 1, d = 2: strides = {2, 1} standard

Maybe instead the order for the static shape must be determined at runtime unless something like this holds : stride_0 - stride_1 < 0 where one symbolic stride is strictly less than other other for positive, non-zero values.

@shivadbhavsar shivadbhavsar requested a review from CharlieL7 April 21, 2026 22:37
Copy link
Copy Markdown
Collaborator

@CharlieL7 CharlieL7 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, needs some small additional changes based on discussion on Teams

Comment thread src/include/migraphx/shape.hpp Outdated
Comment thread src/include/migraphx/shape.hpp Outdated
Comment thread src/permutation.cpp Outdated
Comment thread src/permutation.cpp Outdated
Comment thread src/include/migraphx/sym.hpp Outdated
@shivadbhavsar shivadbhavsar requested a review from pfultz2 April 24, 2026 22:23
Comment thread src/include/migraphx/sym.hpp Outdated
// `interval{1, 8}` or `var("n", {1, 8})`) resolves unambiguously rather
// than triggering the variant converting-ctor's int-vs-double tie that
// some libstdc++ versions (notably on SLES) reject.
interval(int64_t mn, int64_t mx) : min{mn}, max{mx} {}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this really needed? The problem is that we cant write var("n", {.min = 1, .max = 8}) if we add these instructions.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

even without this i think var("n", {.min = 1, .max = 8}) breaks on sles for the same reason. I'd have to modify tons of calls in all the tests (this PRs and subsequent PRs that are already drafted). I'm not completely against doing that but its going to be more verbose and I dont think our sles CI will ever let us write designated inits anyway

Comment thread src/include/migraphx/sym.hpp Outdated

constexpr scalar() = default;

template <class T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use MIGRAPHX_REQUIRES and write the traits as std::is_integral<T>{} instead of using _v.

@causten causten merged commit 4861148 into develop Apr 25, 2026
34 checks passed
@causten causten deleted the sym_dim_integration branch April 25, 2026 21:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants