diff --git a/src/include/migraphx/op/as_shape.hpp b/src/include/migraphx/op/as_shape.hpp index 7618451a317..9058827e6ca 100644 --- a/src/include/migraphx/op/as_shape.hpp +++ b/src/include/migraphx/op/as_shape.hpp @@ -28,6 +28,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -45,13 +46,17 @@ struct as_shape std::string name() const { return "as_shape"; } shape compute_shape(const std::vector& inputs) const { - check_shapes{inputs, *this}.has(1).standard(); - assert(inputs.front().elements() >= s.elements()); + check_shapes{inputs, *this, true}.has(1).standard(); + if(inputs.front().dynamic() and not inputs.front().symbolic()) + MIGRAPHX_THROW("AS_SHAPE: input must be static or symbolic"); + if(s.dynamic() and not s.symbolic()) + MIGRAPHX_THROW("AS_SHAPE: target shape must be static or symbolic"); + assert(inputs.front().sym_elements() >= s.sym_elements()); return s; } - argument compute(shape output_shape, std::vector args) const + argument compute(const dyn_output& dyn_out, std::vector args) const { - return args.front().reshape(output_shape); + return args.front().reshape(dyn_out.computed_shape); } std::vector output_alias(const std::vector&) const { return {0}; } }; diff --git a/src/include/migraphx/op/contiguous.hpp b/src/include/migraphx/op/contiguous.hpp index 695869aa577..b182ec77a24 100644 --- a/src/include/migraphx/op/contiguous.hpp +++ b/src/include/migraphx/op/contiguous.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -47,17 +47,11 @@ struct contiguous shape compute_shape(std::vector inputs) const { check_shapes{inputs, *this, true}.has(1); - auto s0 = inputs.front(); + const auto& s0 = inputs.front(); + if(s0.dynamic()) - { - return s0; - } - else - { - const auto& lens = s0.lens(); - auto t = s0.type(); - return {t, lens}; - } + return {s0.type(), s0.dyn_dims()}; + return {s0.type(), s0.lens()}; } argument compute(const dyn_output& dyn_out, std::vector args) const diff --git a/src/include/migraphx/op/transpose.hpp b/src/include/migraphx/op/transpose.hpp index bb508ecec76..2f2d62e069c 100644 --- a/src/include/migraphx/op/transpose.hpp +++ b/src/include/migraphx/op/transpose.hpp @@ -63,7 +63,8 @@ struct transpose MIGRAPHX_THROW("TRANSPOSE: Invalid permutation"); } - if(input.dynamic()) + // Range-only dynamic shapes do not carry strides; permute dims only. + if(input.dynamic() and not input.symbolic()) { std::vector output_dyn_dims(input.ndim()); std::transform(dims.cbegin(), dims.cend(), output_dyn_dims.begin(), [&](auto dim) { @@ -71,20 +72,17 @@ struct transpose }); return {input.type(), output_dyn_dims}; } - else - { - const auto& input_lens = input.lens(); - const auto& input_strides = input.strides(); - std::vector output_lens(input.ndim()); - std::vector output_strides(input.ndim()); + auto permute = [&](const auto& src) { + std::vector::value_type> out(input.ndim()); for(std::size_t i = 0; i < input.ndim(); i++) - { - output_lens[i] = input_lens[dims[i]]; - output_strides[i] = input_strides[dims[i]]; - } - return {input.type(), output_lens, output_strides}; - } + out[i] = src[dims[i]]; + return out; + }; + + if(input.symbolic()) + return {input.type(), permute(input.dyn_dims()), permute(input.dyn_strides())}; + return {input.type(), permute(input.lens()), permute(input.strides())}; } argument compute(const dyn_output& dyn_out, std::vector args) const diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index 1a8c1f9d53e..4715bba4135 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -309,6 +309,14 @@ struct MIGRAPHX_EXPORT shape */ std::size_t elements() const; + /*! + * Return the number of elements as a symbolic expression. Works for any + * shape kind: for static shapes returns a literal; for symbolic shapes + * returns the product of the symbolic dimension expressions. Throws for + * range-only dynamic shapes. + */ + sym::expr sym_elements() const; + /*! * Return the number of total bytes used for storage of the tensor data; includes subshapes. * For dynamic shape, returns the maximum number of bytes presuming a packed shape. diff --git a/src/include/migraphx/sym.hpp b/src/include/migraphx/sym.hpp index f5201cdea76..df8a87c4c7b 100644 --- a/src/include/migraphx/sym.hpp +++ b/src/include/migraphx/sym.hpp @@ -30,6 +30,7 @@ #include #include #include +#include #include #include diff --git a/src/shape.cpp b/src/shape.cpp index 0cf5470e5a6..a06c84164f9 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -372,6 +372,13 @@ struct shape_impl return compute_elements(m_lens); } + sym::expr sym_elements() const + { + if(not m_dyn_dims.empty() and not all_dims_symbolic()) + MIGRAPHX_THROW("SHAPE: sym_elements() called on a range-only dynamic shape"); + return compute_elements(sym_dims()); + } + std::size_t get_index(size_t i) const { std::size_t result = 0; @@ -645,6 +652,8 @@ std::size_t shape::ndim() const std::size_t shape::elements() const { return impl->elements(); } +sym::expr shape::sym_elements() const { return impl->sym_elements(); } + std::size_t shape::bytes() const { if(this->sub_shapes().empty()) diff --git a/test/op_shape_test.cpp b/test/op_shape_test.cpp index 2a902ef2935..7fd355da8a9 100644 --- a/test/op_shape_test.cpp +++ b/test/op_shape_test.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -33,6 +34,10 @@ #include "test.hpp" +using dd = migraphx::shape::dynamic_dimension; +using migraphx::sym::lit; +using migraphx::sym::var; + template static void expect_shape(const migraphx::shape& expected, const migraphx::operation& op, Ts... xs) { @@ -519,6 +524,25 @@ TEST_CASE(contiguous_dyn_shape) expect_shape(s0, migraphx::make_op("contiguous"), s0); } +TEST_CASE(contiguous_sym_standard) +{ + auto n = var("N", {1, 64}); + auto h = var("H", {1, 128}); + migraphx::shape s0{migraphx::shape::float_type, {dd{n}, dd{h}}}; + expect_shape(s0, migraphx::make_op("contiguous"), s0); +} + +TEST_CASE(contiguous_sym_transposed) +{ + auto n = var("N", {1, 64}); + auto h = var("H", {1, 128}); + // Transposed symbolic input: standard {N, H} -> transposed {H, N}. + auto input = + migraphx::shape::from_permutation(migraphx::shape::float_type, {dd{h}, dd{n}}, {1, 0}); + migraphx::shape output{migraphx::shape::float_type, {dd{h}, dd{n}}}; + expect_shape(output, migraphx::make_op("contiguous"), input); +} + TEST_CASE(contiguous_shape_scalar) { migraphx::shape output{migraphx::shape::float_type, {1}}; @@ -526,6 +550,23 @@ TEST_CASE(contiguous_shape_scalar) expect_shape(output, migraphx::make_op("contiguous"), input); } +TEST_CASE(as_shape_static) +{ + migraphx::shape input{migraphx::shape::float_type, {4, 4}}; + migraphx::shape target{migraphx::shape::float_type, {2, 8}}; + expect_shape( + target, migraphx::make_op("as_shape", {{"shape", migraphx::to_value(target)}}), input); +} + +TEST_CASE(as_shape_sym) +{ + auto n = var("N", {1, 64}); + migraphx::shape input{migraphx::shape::float_type, {dd{n}, dd{lit(16)}}}; + migraphx::shape target{migraphx::shape::float_type, {dd{n}, dd{lit(4)}, dd{lit(4)}}}; + expect_shape( + target, migraphx::make_op("as_shape", {{"shape", migraphx::to_value(target)}}), input); +} + TEST_CASE(contiguous_shape_singleton_dim) { migraphx::shape output{migraphx::shape::float_type, {5, 1, 8}, {8, 8, 1}}; @@ -5241,6 +5282,45 @@ TEST_CASE(transpose_dyn_shape1) expect_shape(output, migraphx::make_op("transpose", {{"permutation", {2, 1, 0}}}), input); } +TEST_CASE(transpose_sym_identity) +{ + auto n = var("N", {1, 64}); + auto h = var("H", {1, 128}); + migraphx::shape input{migraphx::shape::float_type, {dd{n}, dd{h}}}; + expect_shape(input, migraphx::make_op("transpose", {{"permutation", {0, 1}}}), input); +} + +TEST_CASE(transpose_sym_swap) +{ + auto n = var("N", {1, 64}); + auto h = var("H", {1, 128}); + migraphx::shape input{migraphx::shape::float_type, {dd{n}, dd{h}}}; + auto output = + migraphx::shape::from_permutation(migraphx::shape::float_type, {dd{h}, dd{n}}, {1, 0}); + expect_shape(output, migraphx::make_op("transpose", {{"permutation", {1, 0}}}), input); +} + +TEST_CASE(transpose_sym_3d) +{ + auto n = var("N", {1, 64}); + auto c = var("C", {1, 32}); + auto h = var("H", {1, 128}); + migraphx::shape input{migraphx::shape::float_type, {dd{n}, dd{c}, dd{h}}}; + auto output = migraphx::shape::from_permutation( + migraphx::shape::float_type, {dd{h}, dd{c}, dd{n}}, {2, 1, 0}); + expect_shape(input, migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), input); + expect_shape(output, migraphx::make_op("transpose", {{"permutation", {2, 1, 0}}}), input); +} + +TEST_CASE(transpose_sym_mixed_literal) +{ + auto n = var("N", {1, 64}); + migraphx::shape input{migraphx::shape::float_type, {dd{n}, dd{lit(8)}, dd{lit(4)}}}; + auto output = migraphx::shape::from_permutation( + migraphx::shape::float_type, {dd{lit(4)}, dd{lit(8)}, dd{n}}, {2, 1, 0}); + expect_shape(output, migraphx::make_op("transpose", {{"permutation", {2, 1, 0}}}), input); +} + TEST_CASE(transpose_axes_error) { migraphx::shape input{migraphx::shape::float_type, {2, 2}};