Skip to content
Open
Show file tree
Hide file tree
Changes from 92 commits
Commits
Show all changes
94 commits
Select commit Hold shift + click to select a range
d2b684e
custom symbolic expression lib
shivadbhavsar Mar 24, 2026
aa55785
format
shivadbhavsar Mar 24, 2026
314f7cf
use visit
shivadbhavsar Mar 24, 2026
dcfe825
format
shivadbhavsar Mar 24, 2026
2ec0969
integrate symbolic expression in dynamic_dimension
shivadbhavsar Mar 25, 2026
18caf6b
tidy
shivadbhavsar Mar 25, 2026
483932b
Merge branch 'custom_sym_lib' into sym_dim_integration
shivadbhavsar Mar 25, 2026
200f3c4
Merge branch 'develop' into custom_sym_lib
shivadbhavsar Mar 25, 2026
7ff2045
Merge branch 'custom_sym_lib' into sym_dim_integration
shivadbhavsar Mar 25, 2026
6af3621
fix constructor ambiguity
shivadbhavsar Mar 25, 2026
b7d7c23
fix ambiguity
shivadbhavsar Mar 25, 2026
3486135
update namespace and interface design
shivadbhavsar Mar 26, 2026
edbce87
Merge branch 'develop' into custom_sym_lib
shivadbhavsar Mar 26, 2026
964f934
use int64 for literals
shivadbhavsar Mar 26, 2026
2719ae6
Merge branch 'custom_sym_lib' into sym_dim_integration
shivadbhavsar Mar 26, 2026
364bd23
fix merge
shivadbhavsar Mar 26, 2026
33614e0
change eval func name
shivadbhavsar Mar 26, 2026
2ba3b74
Merge branch 'custom_sym_lib' into sym_dim_integration
shivadbhavsar Mar 26, 2026
830594f
use int64 for internal eval
shivadbhavsar Mar 26, 2026
bd70d84
Merge branch 'custom_sym_lib' into sym_dim_integration
shivadbhavsar Mar 26, 2026
def3038
fix eval call
shivadbhavsar Mar 26, 2026
359070d
copilot comments
shivadbhavsar Mar 26, 2026
9ad996f
copilot review fix
shivadbhavsar Mar 26, 2026
b61b6d8
Merge branch 'custom_sym_lib' into sym_dim_integration
shivadbhavsar Mar 26, 2026
bda9f91
format and tidy
shivadbhavsar Mar 26, 2026
5027376
Merge branch 'custom_sym_lib' into sym_dim_integration
shivadbhavsar Mar 26, 2026
1209717
Merge branch 'sym_dim_integration' of https://github.com/ROCm/AMDMIGr…
shivadbhavsar Mar 26, 2026
003c9d3
tidy fix
shivadbhavsar Mar 30, 2026
3759299
tidy
shivadbhavsar Mar 30, 2026
50944fb
Merge remote-tracking branch 'origin/custom_sym_lib' into sym_dim_int…
shivadbhavsar Mar 30, 2026
fe649ff
update the only call sites using the braced-init-list that cannot be …
shivadbhavsar Mar 30, 2026
7be6e7f
Merge remote-tracking branch 'origin/develop' into custom_sym_lib
shivadbhavsar Mar 30, 2026
1274d3a
address review comments
shivadbhavsar Mar 30, 2026
5b28774
Merge remote-tracking branch 'origin/custom_sym_lib' into sym_dim_int…
shivadbhavsar Mar 30, 2026
83da044
license
shivadbhavsar Mar 30, 2026
d680d55
reduce complexity
shivadbhavsar Mar 30, 2026
8d5629b
Merge remote-tracking branch 'origin/custom_sym_lib' into sym_dim_int…
shivadbhavsar Mar 30, 2026
e7ca1d6
update calls to eval_uint
shivadbhavsar Mar 30, 2026
54debb5
clean up test file
shivadbhavsar Mar 31, 2026
c7f698c
review comments
shivadbhavsar Mar 31, 2026
149b661
Merge remote-tracking branch 'origin/custom_sym_lib' into sym_dim_int…
shivadbhavsar Mar 31, 2026
293b5d8
merge and tidy
shivadbhavsar Mar 31, 2026
cc521e4
license
shivadbhavsar Mar 31, 2026
3b7077a
Merge remote-tracking branch 'origin/custom_sym_lib' into sym_dim_int…
shivadbhavsar Mar 31, 2026
47ec30a
fix style
shivadbhavsar Apr 1, 2026
59a7ef4
normalize fixed dynamic dim representation
shivadbhavsar Apr 1, 2026
95894db
fmt
shivadbhavsar Apr 1, 2026
be87f71
fix serialization and normalization
shivadbhavsar Apr 2, 2026
78f06c7
license
shivadbhavsar Apr 2, 2026
71d7ae7
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 6, 2026
1a68619
address review comments
shivadbhavsar Apr 7, 2026
2044073
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 7, 2026
d6c8d49
update tests for cleaned up intersection logic
shivadbhavsar Apr 7, 2026
4556366
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 8, 2026
378e3d7
review feedback updates
shivadbhavsar Apr 9, 2026
5412e60
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 9, 2026
0bf5680
fix tidy
shivadbhavsar Apr 10, 2026
4fa771a
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 10, 2026
2545a8e
fix callsite to remove ambiguity
shivadbhavsar Apr 10, 2026
e904332
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 11, 2026
c8b8df4
fix merge
shivadbhavsar Apr 11, 2026
4e0da9c
remove optional from sym_expr
shivadbhavsar Apr 13, 2026
ca454b2
refactor how dyn dim intervals are stored and accessed
shivadbhavsar Apr 14, 2026
29ca4d5
add defaults
shivadbhavsar Apr 14, 2026
a0026ba
add getter for optimals and update callsites
shivadbhavsar Apr 14, 2026
1db31d6
update to use get_interval() and remove min() and max()
shivadbhavsar Apr 14, 2026
c128adc
fix cppcheck
shivadbhavsar Apr 15, 2026
3b2c259
return optimals by value
shivadbhavsar Apr 15, 2026
49d7aa6
Merge remote-tracking branch 'origin/develop' into dyn_interval_refactor
shivadbhavsar Apr 15, 2026
cc74c4a
update has_optimal
shivadbhavsar Apr 15, 2026
44cd175
symbolic dimension integration (squashed)
shivadbhavsar Apr 15, 2026
ae322fc
update implementation to work on top of interval refactor
shivadbhavsar Apr 15, 2026
7b5484e
Merge branch 'sym_dim_integration' of https://github.com/ROCm/AMDMIGr…
shivadbhavsar Apr 15, 2026
39e0442
Merge remote-tracking branch 'origin/develop' into sym_dim_integration
shivadbhavsar Apr 15, 2026
5be3d69
fix old constructor
shivadbhavsar Apr 15, 2026
093964c
fix ambiguous call
shivadbhavsar Apr 16, 2026
72f00ea
fix cppcheck
shivadbhavsar Apr 16, 2026
812a1b1
add missing comment blocks
shivadbhavsar Apr 17, 2026
8573cfb
clearly state assumptions used when dealing with stride permutations …
shivadbhavsar Apr 21, 2026
718e599
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 21, 2026
d94f367
add missing tests
shivadbhavsar Apr 23, 2026
1747639
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 23, 2026
3e0e2c2
make var bounds non-optional and add deprecation TODO to clarify the …
shivadbhavsar Apr 24, 2026
4bbb2e6
Merge branch 'develop' into sym_dim_integration
shivadbhavsar Apr 24, 2026
bb4ae35
support transpose, contiguous, as_shape
shivadbhavsar Apr 24, 2026
94c7941
add scalar variant to ease merging 4782
shivadbhavsar Apr 24, 2026
4305db3
Merge branch 'sym_dim_integration' into sym_layout_ops
shivadbhavsar Apr 24, 2026
a64a0d1
fix brace-init ambiguity for sles
shivadbhavsar Apr 24, 2026
305a7cc
Merge branch 'sym_dim_integration' into sym_layout_ops
shivadbhavsar Apr 24, 2026
c4a33e1
wrap scalar variant to handle ambiguity
shivadbhavsar Apr 25, 2026
4da9542
Merge branch 'sym_dim_integration' into sym_layout_ops
shivadbhavsar Apr 25, 2026
89703e9
Merge branch 'develop' into sym_layout_ops
shivadbhavsar Apr 26, 2026
60a2388
Merge remote-tracking branch 'origin/to_symbolic_helper' into sym_lay…
shivadbhavsar Apr 29, 2026
a49d69a
ci and reviews
shivadbhavsar Apr 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/include/migraphx/op/as_shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/dyn_output.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand All @@ -45,13 +46,17 @@ struct as_shape
std::string name() const { return "as_shape"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
assert(inputs.front().elements() >= s.elements());
if(inputs.front().dynamic() and not inputs.front().symbolic())
MIGRAPHX_THROW("AS_SHAPE: input must be static or symbolic");
if(s.dynamic() and not s.symbolic())
MIGRAPHX_THROW("AS_SHAPE: target shape must be static or symbolic");
check_shapes{inputs, *this, true}.has(1).standard();
Comment thread
shivadbhavsar marked this conversation as resolved.
Outdated
assert(inputs.front().sym_elements() >= s.sym_elements());
Comment thread
shivadbhavsar marked this conversation as resolved.
return s;
}
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
return args.front().reshape(output_shape);
return args.front().reshape(dyn_out.computed_shape);
}
std::vector<std::size_t> output_alias(const std::vector<shape>&) const { return {0}; }
};
Expand Down
14 changes: 4 additions & 10 deletions src/include/migraphx/op/contiguous.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -47,17 +47,11 @@
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
auto s0 = inputs.front();

Check warning on line 50 in src/include/migraphx/op/contiguous.hpp

View workflow job for this annotation

GitHub Actions / tidy

the variable 's0' is copy-constructed from a const reference but is only used as const reference; consider making it a const reference [performance-unnecessary-copy-initialization,-warnings-as-errors]

Check warning on line 50 in src/include/migraphx/op/contiguous.hpp

View workflow job for this annotation

GitHub Actions / tidy

the variable 's0' is copy-constructed from a const reference but is only used as const reference; consider making it a const reference [performance-unnecessary-copy-initialization,-warnings-as-errors]

Check warning on line 50 in src/include/migraphx/op/contiguous.hpp

View workflow job for this annotation

GitHub Actions / tidy

the variable 's0' is copy-constructed from a const reference but is only used as const reference; consider making it a const reference [performance-unnecessary-copy-initialization,-warnings-as-errors]

Check warning on line 50 in src/include/migraphx/op/contiguous.hpp

View workflow job for this annotation

GitHub Actions / tidy

the variable 's0' is copy-constructed from a const reference but is only used as const reference; consider making it a const reference [performance-unnecessary-copy-initialization,-warnings-as-errors]

Check warning on line 50 in src/include/migraphx/op/contiguous.hpp

View workflow job for this annotation

GitHub Actions / tidy

the variable 's0' is copy-constructed from a const reference but is only used as const reference; consider making it a const reference [performance-unnecessary-copy-initialization,-warnings-as-errors]

if(s0.dynamic())
{
return s0;
}
else
{
const auto& lens = s0.lens();
auto t = s0.type();
return {t, lens};
}
return {s0.type(), s0.dyn_dims()};
return {s0.type(), s0.lens()};
}

argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
Expand Down
24 changes: 11 additions & 13 deletions src/include/migraphx/op/transpose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,28 +63,26 @@ struct transpose
MIGRAPHX_THROW("TRANSPOSE: Invalid permutation");
}

if(input.dynamic())
// Range-only dynamic shapes do not carry strides; permute dims only.
if(input.dynamic() and not input.symbolic())
{
std::vector<shape::dynamic_dimension> output_dyn_dims(input.ndim());
std::transform(dims.cbegin(), dims.cend(), output_dyn_dims.begin(), [&](auto dim) {
return input.dyn_dims()[dim];
});
return {input.type(), output_dyn_dims};
}
else
{
const auto& input_lens = input.lens();
const auto& input_strides = input.strides();

std::vector<size_t> output_lens(input.ndim());
std::vector<size_t> output_strides(input.ndim());
auto permute = [&](const auto& src) {
std::vector<typename std::decay_t<decltype(src)>::value_type> out(input.ndim());
for(std::size_t i = 0; i < input.ndim(); i++)
{
output_lens[i] = input_lens[dims[i]];
output_strides[i] = input_strides[dims[i]];
}
return {input.type(), output_lens, output_strides};
}
out[i] = src[dims[i]];
return out;
};

if(input.symbolic())
return {input.type(), permute(input.dyn_dims()), permute(input.dyn_strides())};
return {input.type(), permute(input.lens()), permute(input.strides())};
}

argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
Expand Down
8 changes: 8 additions & 0 deletions src/include/migraphx/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,14 @@ struct MIGRAPHX_EXPORT shape
*/
std::size_t elements() const;

/*!
* Return the number of elements as a symbolic expression. Works for any
* shape kind: for static shapes returns a literal; for symbolic shapes
* returns the product of the symbolic dimension expressions. Throws for
* range-only dynamic shapes.
*/
sym::expr sym_elements() const;

/*!
* Return the number of total bytes used for storage of the tensor data; includes subshapes.
* For dynamic shape, returns the maximum number of bytes presuming a packed shape.
Expand Down
1 change: 1 addition & 0 deletions src/include/migraphx/sym.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <ostream>
#include <set>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <variant>

Expand Down
9 changes: 9 additions & 0 deletions src/shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,13 @@ struct shape_impl
return compute_elements<std::size_t>(m_lens);
}

sym::expr sym_elements() const
{
    // A symbolic element count is only well-defined when the shape is static
    // (no dynamic dims) or every dynamic dimension carries a symbolic
    // expression; a range-only dynamic shape has no single expression.
    const bool range_only_dynamic = not(m_dyn_dims.empty() or all_dims_symbolic());
    if(range_only_dynamic)
        MIGRAPHX_THROW("SHAPE: sym_elements() called on a range-only dynamic shape");
    // Product of the per-dimension symbolic expressions (literals for static dims).
    return compute_elements<sym::expr>(sym_dims());
}

std::size_t get_index(size_t i) const
{
std::size_t result = 0;
Expand Down Expand Up @@ -645,6 +652,8 @@ std::size_t shape::ndim() const

std::size_t shape::elements() const { return impl->elements(); }

// Element count as a symbolic expression; forwards to the pimpl, which throws
// for range-only dynamic shapes (see the declaration's doc comment).
sym::expr shape::sym_elements() const { return impl->sym_elements(); }

std::size_t shape::bytes() const
{
if(this->sub_shapes().empty())
Expand Down
80 changes: 80 additions & 0 deletions test/op_shape_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,18 @@
#include <migraphx/instruction.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/sym.hpp>
#include <sstream>
#include <migraphx/make_op.hpp>

#include <migraphx/serialize.hpp>

#include "test.hpp"

using dd = migraphx::shape::dynamic_dimension;
using migraphx::sym::lit;
using migraphx::sym::var;

template <class... Ts>
static void expect_shape(const migraphx::shape& expected, const migraphx::operation& op, Ts... xs)
{
Expand Down Expand Up @@ -519,13 +524,49 @@ TEST_CASE(contiguous_dyn_shape)
expect_shape(s0, migraphx::make_op("contiguous"), s0);
}

TEST_CASE(contiguous_sym_standard)
{
    // An already-standard (packed) symbolic shape passes through contiguous unchanged.
    auto batch  = var("N", {1, 64});
    auto height = var("H", {1, 128});
    migraphx::shape packed{migraphx::shape::float_type, {dd{batch}, dd{height}}};
    expect_shape(packed, migraphx::make_op("contiguous"), packed);
}

TEST_CASE(contiguous_sym_transposed)
{
    // contiguous must turn a transposed symbolic layout back into the packed
    // (standard) layout while keeping the dimension order {H, N}.
    auto batch  = var("N", {1, 64});
    auto height = var("H", {1, 128});
    // Transposed symbolic input: standard {N, H} viewed through permutation {1, 0}.
    auto transposed_in = migraphx::shape::from_permutation(
        migraphx::shape::float_type, {dd{height}, dd{batch}}, {1, 0});
    migraphx::shape packed_out{migraphx::shape::float_type, {dd{height}, dd{batch}}};
    expect_shape(packed_out, migraphx::make_op("contiguous"), transposed_in);
}

TEST_CASE(contiguous_shape_scalar)
{
    // A scalar (rank-0) input is standardized to a one-element 1-D shape.
    migraphx::shape scalar_in{migraphx::shape::float_type};
    migraphx::shape expected{migraphx::shape::float_type, {1}};
    expect_shape(expected, migraphx::make_op("contiguous"), scalar_in);
}

TEST_CASE(as_shape_static)
{
    // as_shape reinterprets a standard 4x4 input as 2x8 (same element count).
    migraphx::shape in_shape{migraphx::shape::float_type, {4, 4}};
    migraphx::shape goal{migraphx::shape::float_type, {2, 8}};
    expect_shape(goal, migraphx::make_op("as_shape", {{"shape", migraphx::to_value(goal)}}), in_shape);
}

TEST_CASE(as_shape_sym)
{
    // Symbolic reshape {N, 16} -> {N, 4, 4}: symbolic element counts agree.
    auto batch = var("N", {1, 64});
    migraphx::shape in_shape{migraphx::shape::float_type, {dd{batch}, dd{lit(16)}}};
    migraphx::shape goal{migraphx::shape::float_type, {dd{batch}, dd{lit(4)}, dd{lit(4)}}};
    expect_shape(goal, migraphx::make_op("as_shape", {{"shape", migraphx::to_value(goal)}}), in_shape);
}

Comment thread
shivadbhavsar marked this conversation as resolved.
TEST_CASE(contiguous_shape_singleton_dim)
{
migraphx::shape output{migraphx::shape::float_type, {5, 1, 8}, {8, 8, 1}};
Expand Down Expand Up @@ -5241,6 +5282,45 @@ TEST_CASE(transpose_dyn_shape1)
expect_shape(output, migraphx::make_op("transpose", {{"permutation", {2, 1, 0}}}), input);
}

TEST_CASE(transpose_sym_identity)
{
    // The identity permutation leaves a symbolic shape untouched.
    auto batch  = var("N", {1, 64});
    auto height = var("H", {1, 128});
    migraphx::shape sym_shape{migraphx::shape::float_type, {dd{batch}, dd{height}}};
    expect_shape(sym_shape, migraphx::make_op("transpose", {{"permutation", {0, 1}}}), sym_shape);
}

TEST_CASE(transpose_sym_swap)
{
    // Swapping the two axes of {N, H} yields {H, N} with permuted strides,
    // i.e. the shape produced by from_permutation with {1, 0}.
    auto batch  = var("N", {1, 64});
    auto height = var("H", {1, 128});
    migraphx::shape sym_shape{migraphx::shape::float_type, {dd{batch}, dd{height}}};
    auto expected = migraphx::shape::from_permutation(
        migraphx::shape::float_type, {dd{height}, dd{batch}}, {1, 0});
    expect_shape(expected, migraphx::make_op("transpose", {{"permutation", {1, 0}}}), sym_shape);
}

TEST_CASE(transpose_sym_3d)
{
    // 3-D symbolic transpose: identity permutation and full axis reversal.
    auto batch  = var("N", {1, 64});
    auto chans  = var("C", {1, 32});
    auto height = var("H", {1, 128});
    migraphx::shape sym_shape{migraphx::shape::float_type, {dd{batch}, dd{chans}, dd{height}}};
    auto reversed = migraphx::shape::from_permutation(
        migraphx::shape::float_type, {dd{height}, dd{chans}, dd{batch}}, {2, 1, 0});
    expect_shape(sym_shape, migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), sym_shape);
    expect_shape(reversed, migraphx::make_op("transpose", {{"permutation", {2, 1, 0}}}), sym_shape);
}

TEST_CASE(transpose_sym_mixed_literal)
{
    // Mixed symbolic/literal dims: reversing {N, 8, 4} gives {4, 8, N}.
    auto batch = var("N", {1, 64});
    migraphx::shape sym_shape{migraphx::shape::float_type, {dd{batch}, dd{lit(8)}, dd{lit(4)}}};
    auto expected = migraphx::shape::from_permutation(
        migraphx::shape::float_type, {dd{lit(4)}, dd{lit(8)}, dd{batch}}, {2, 1, 0});
    expect_shape(expected, migraphx::make_op("transpose", {{"permutation", {2, 1, 0}}}), sym_shape);
}

TEST_CASE(transpose_axes_error)
{
migraphx::shape input{migraphx::shape::float_type, {2, 2}};
Expand Down
Loading