diff --git a/src/common.cpp b/src/common.cpp index 10bc5d0707f..ba2f5aebd9d 100644 --- a/src/common.cpp +++ b/src/common.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -91,10 +91,8 @@ compute_broadcasted_dyn_dims(std::vector dds0, std::vector compute_broadcasted_dyn_dims(shape s0, shape s1) { - // change both shapes to dynamic_dimension representation - s0 = s0.to_dynamic(); - s1 = s1.to_dynamic(); - return compute_broadcasted_dyn_dims(s0.dyn_dims(), s1.dyn_dims()); + auto aligned = shape::to_dynamic({s0, s1}); + return compute_broadcasted_dyn_dims(aligned[0].dyn_dims(), aligned[1].dyn_dims()); } std::vector compute_common_dyn_dims(const std::vector& shapes) @@ -235,19 +233,43 @@ instruction_ref add_common_op(module& m, return insert_common_op(m, m.end(), op, std::move(inputs), options); } -shape make_bcast_shape(const shape& input_shape, const std::vector& bcast_lens) +// Unified broadcast-shape builder: propagates input strides on matching axes and leaves +// `zero` on broadcast axes. Works for both static (Dim = std::size_t, Stride = std::size_t) +// and symbolic (Dim = shape::dynamic_dimension, Stride = sym::expr) representations. 
+template +static shape make_bcast_shape_impl(shape::type_t type, + const std::vector& input_dims, + const std::vector& input_strides, + const std::vector& bcast_dims, + const Stride& zero) { - assert(not input_shape.dynamic()); - auto offset = bcast_lens.size() - input_shape.ndim(); - std::vector bcast_strides(bcast_lens.size(), 0); - for(std::ptrdiff_t i : reverse(range(input_shape.ndim()))) + assert(bcast_dims.size() >= input_dims.size()); + auto offset = bcast_dims.size() - input_dims.size(); + std::vector bcast_strides(bcast_dims.size(), zero); + for(std::size_t i = 0; i < input_dims.size(); ++i) { - if(bcast_lens.at(i + offset) == input_shape.lens()[i]) - { - bcast_strides.at(i + offset) = input_shape.strides()[i]; - } + if(bcast_dims[i + offset] == input_dims[i]) + bcast_strides[i + offset] = input_strides[i]; } - return shape{input_shape.type(), bcast_lens, bcast_strides}; + return {type, bcast_dims, bcast_strides}; +} + +shape make_bcast_shape(const shape& input_shape, const std::vector& bcast_lens) +{ + assert(not input_shape.dynamic()); + return make_bcast_shape_impl( + input_shape.type(), input_shape.lens(), input_shape.strides(), bcast_lens, std::size_t{0}); +} + +shape make_bcast_shape(const shape& input_shape, + const std::vector& bcast_dyn_dims) +{ + assert(input_shape.symbolic()); + return make_bcast_shape_impl(input_shape.type(), + input_shape.dyn_dims(), + input_shape.dyn_strides(), + bcast_dyn_dims, + sym::lit(0)); } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/common.hpp b/src/include/migraphx/common.hpp index 44338d66987..b405c3abf77 100644 --- a/src/include/migraphx/common.hpp +++ b/src/include/migraphx/common.hpp @@ -153,6 +153,21 @@ instruction_ref add_common_op(module& m, MIGRAPHX_EXPORT shape make_bcast_shape(const shape& input_shape, const std::vector& bcast_lens); +/** + * Calculates the broadcasted shape from a symbolic broadcast target. 
The input shape MUST + * already be symbolic (`input_shape.symbolic()`); callers bridging from a static shape + * should promote via `shape::to_symbolic()` first. Mirrors the single-modality contract of + * the static overload above. Broadcast axes receive `sym::lit(0)`; matching axes propagate + * the input's symbolic stride. + * + * @param input_shape symbolic dynamic shape to broadcast + * @param bcast_dyn_dims symbolic dynamic dimensions to broadcast to + * @return broadcasted shape with symbolic dyn_strides + */ +MIGRAPHX_EXPORT +shape make_bcast_shape(const shape& input_shape, + const std::vector& bcast_dyn_dims); + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx #endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_HPP diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index 55721d6551e..161285baba0 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -817,6 +817,7 @@ MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins) } MIGRAPHX_PRED_MATCHER(dynamic_shape, instruction_ref ins) { return ins->get_shape().dynamic(); } MIGRAPHX_PRED_MATCHER(static_shape, instruction_ref ins) { return not ins->get_shape().dynamic(); } +MIGRAPHX_PRED_MATCHER(symbolic_shape, instruction_ref ins) { return ins->get_shape().symbolic(); } MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins) { return ins->get_shape().broadcasted(); diff --git a/src/include/migraphx/op/binary.hpp b/src/include/migraphx/op/binary.hpp index d5a3bf98f7b..0478c11559a 100644 --- a/src/include/migraphx/op/binary.hpp +++ b/src/include/migraphx/op/binary.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -70,7 +70,8 @@ struct binary : op_name .same_dims(); auto s0 = inputs.at(0); auto s1 = inputs.at(1); - if(s0.dynamic() or s1.dynamic()) + // Range-based dynamic (or mixed dynamic/static) inputs only support strict equality. + if((s0.dynamic() or s1.dynamic()) and not(s0.symbolic() and s1.symbolic())) { if(s0 == s1) return s0; @@ -86,10 +87,15 @@ struct binary : op_name } else if(s0.broadcasted() != s1.broadcasted()) { + if(s0.symbolic()) + return s0.broadcasted() ? s1.with_lens(s0.dyn_dims()) : s0.with_lens(s0.dyn_dims()); return s0.broadcasted() ? s1.with_lens(s0.lens()) : s0.with_lens(s0.lens()); } else { + if(s0.symbolic()) + return shape::from_permutation( + s0.type(), s0.dyn_dims(), find_permutation({s0, s1})); return shape::from_permutation(s0.type(), s0.lens(), find_permutation({s0, s1})); } } diff --git a/src/include/migraphx/op/broadcast.hpp b/src/include/migraphx/op/broadcast.hpp index f6da43940b0..ee7fe188fd7 100644 --- a/src/include/migraphx/op/broadcast.hpp +++ b/src/include/migraphx/op/broadcast.hpp @@ -45,6 +45,11 @@ namespace op { * (left-most to rightwards element-wise comparison) * ex: broadcasting shape [2, 2] -> [2, 2, 3] with axis = 0 * + * Symbolic 1 input version: opt-in via a fully-symbolic `output_dyn_dims` attribute. Input may be + * static (promoted via shape::to_symbolic()) or already symbolic. Range-based dynamic input is + * not allowed (per the "no mixing symbolic and range-based" design rule). `broadcast_lens` is + * not used in this mode. + * * 2 input version: * Broadcast the first input 1D shape into the second input shape based on the axis parameter. * Handles broadcasting a 1D static shape into a higher rank dynamic shape. 
@@ -52,96 +57,132 @@ namespace op { */ struct broadcast { - uint64_t axis = 0; - std::vector broadcast_lens = {}; + uint64_t axis = 0; + std::vector broadcast_lens = {}; + std::vector output_dyn_dims = {}; template static auto reflect(Self& self, F f) { - return pack(f(self.axis, "axis"), f(self.broadcast_lens, "out_lens")); + return pack(f(self.axis, "axis"), + f(self.broadcast_lens, "out_lens"), + f(self.output_dyn_dims, "out_dyn_dims")); } std::string name() const { return "broadcast"; } - shape compute_shape(std::vector inputs) const + + // Validate axis/dims and place `in_strides` at `axis`, filling broadcast positions with + // `zero`. Templated for the static (size_t) and symbolic (sym::expr) paths. + template + shape build_output(shape::type_t t, + const Target& target, + const Dims& in_dims, + const Strides& in_strides, + const Zero& zero) const { - check_shapes{inputs, *this, true}.has(1, 2); - auto s0 = inputs.at(0); - auto t = s0.type(); - if(inputs.size() == 1) + if(axis >= target.size()) + MIGRAPHX_THROW("BROADCAST : axis " + migraphx::to_string(axis) + " is out of range"); + if(target.size() - axis < in_dims.size()) + MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims"); + for(std::size_t i = 0; i < in_dims.size(); ++i) { - // the ONNX broadcast op is deprecated now, so not handling the negative - // value of axis anymore - if(s0.dynamic()) - MIGRAPHX_THROW( - "BROADCAST: Single dynamic input shape not supported. 
Use two inputs."); - if(axis >= broadcast_lens.size()) - { - MIGRAPHX_THROW("BROADCAST : axis " + migraphx::to_string(axis) + - " is out of range"); - } - if(broadcast_lens.size() - axis < s0.lens().size()) - { - MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims"); - } - if(not std::equal(s0.lens().begin(), s0.lens().end(), broadcast_lens.begin() + axis)) - { + if(target[axis + i] != in_dims[i]) MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match"); - } + } + std::vector bcast_strides(target.size(), zero); + std::copy(in_strides.begin(), in_strides.end(), bcast_strides.begin() + axis); + return shape{t, target, std::move(bcast_strides)}; + } - std::vector bcast_strides(broadcast_lens.size(), 0); - std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis); - shape output{t, broadcast_lens, std::move(bcast_strides)}; - if(output.elements() < s0.elements()) - { - // don't think this can occur? - MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size"); - } - return output; + shape compute_shape_1in(shape s0) const + { + // the ONNX broadcast op is deprecated now, so not handling the negative + // value of axis anymore + const bool symbolic_target = not output_dyn_dims.empty() and + std::all_of(output_dyn_dims.begin(), + output_dyn_dims.end(), + [](const auto& d) { return d.is_symbolic(); }); + if(not output_dyn_dims.empty() and not symbolic_target) + MIGRAPHX_THROW("BROADCAST: output_dyn_dims must be fully symbolic"); + + if(s0.dynamic() and not(symbolic_target and s0.symbolic())) + MIGRAPHX_THROW("BROADCAST: Single dynamic input shape not supported. 
Use two inputs."); + + if(symbolic_target) + { + auto s0_sym = s0.to_symbolic(); + return build_output( + s0.type(), output_dyn_dims, s0_sym.dyn_dims(), s0_sym.dyn_strides(), sym::lit(0)); } - else + + auto output = + build_output(s0.type(), broadcast_lens, s0.lens(), s0.strides(), std::size_t{0}); + if(output.elements() < s0.elements()) { - // two inputs - auto s1 = inputs.at(1); - if(s0.dynamic()) - { - MIGRAPHX_THROW("BROADCAST_2in: s0 is a dynamic shape, does not handle broadcasting " - "a dynamic shape"); - } - if(s0.ndim() != 1) - { - MIGRAPHX_THROW("BROADCAST_2in: s0 has ndim " + migraphx::to_string(s0.ndim()) + - ", only handle ndim = 1"); - } - if(axis >= s1.ndim()) - { - MIGRAPHX_THROW("BROADCAST_2in: axis " + migraphx::to_string(axis) + - " is out of range"); - } - if(s1.dynamic()) - { - s0 = s0.to_dynamic(); - if(s0.dyn_dims()[0] != s1.dyn_dims()[axis]) - { - MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with dynamic s1 axis " - "dimension length (" + - migraphx::to_string(s0.dyn_dims()[0]) + - " != " + migraphx::to_string(s1.dyn_dims()[axis]) + ")"); - } - return s1; - } + // don't think this can occur? + MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size"); + } + return output; + } + + shape compute_shape_2in(shape s0, shape s1) const + { + if(s0.ndim() != 1) + { + MIGRAPHX_THROW("BROADCAST_2in: s0 has ndim " + migraphx::to_string(s0.ndim()) + + ", only handle ndim = 1"); + } + + if(s0.symbolic() or s1.symbolic()) + { + auto s0_sym = s0.to_symbolic(); + auto s1_sym = s1.to_symbolic(); + return build_output( + s0.type(), s1_sym.dyn_dims(), s0_sym.dyn_dims(), s0_sym.dyn_strides(), sym::lit(0)); + } + + if(axis >= s1.ndim()) + { + MIGRAPHX_THROW("BROADCAST_2in: axis " + migraphx::to_string(axis) + " is out of range"); + } - if(s0.lens()[0] != s1.lens()[axis]) + // Range-based dynamic s0 alone is not supported. 
+ if(s0.dynamic()) + { + MIGRAPHX_THROW("BROADCAST_2in: s0 is a dynamic shape, does not handle broadcasting " + "a dynamic shape"); + } + if(s1.dynamic()) + { + s0 = s0.to_dynamic(); + if(s0.dyn_dims()[0] != s1.dyn_dims()[axis]) { - MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with static s1 axis " + MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with dynamic s1 axis " "dimension length (" + - migraphx::to_string(s0.lens()[0]) + - " != " + migraphx::to_string(s1.lens()[axis]) + ")"); + migraphx::to_string(s0.dyn_dims()[0]) + + " != " + migraphx::to_string(s1.dyn_dims()[axis]) + ")"); } - std::vector bcast_strides(s1.ndim(), 0); - std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis); - shape output{t, s1.lens(), std::move(bcast_strides)}; - return output; + return s1; + } + + if(s0.lens()[0] != s1.lens()[axis]) + { + MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with static s1 axis " + "dimension length (" + + migraphx::to_string(s0.lens()[0]) + + " != " + migraphx::to_string(s1.lens()[axis]) + ")"); } + std::vector bcast_strides(s1.ndim(), 0); + std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis); + return shape{s0.type(), s1.lens(), std::move(bcast_strides)}; + } + + shape compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this, true}.has(1, 2); + if(inputs.size() == 1) + return compute_shape_1in(inputs.at(0)); + return compute_shape_2in(inputs.at(0), inputs.at(1)); } argument compute(const dyn_output& dyn_out, std::vector args) const diff --git a/src/include/migraphx/op/multibroadcast.hpp b/src/include/migraphx/op/multibroadcast.hpp index 735a99d4b54..205502a41f2 100644 --- a/src/include/migraphx/op/multibroadcast.hpp +++ b/src/include/migraphx/op/multibroadcast.hpp @@ -37,7 +37,8 @@ namespace op { /** * Broadcast multiple dimensions between two tensors. * Two versions of this operator: 1 input and 2+ inputs. 
- * One input version uses output_lens attribute and broadcasts to it. + * One input version uses output_lens (static target) or output_dyn_dims (symbolic target); + * see compute_shape for the symbolic single-input contract. * 2+ inputs version broadcasts first input to the common shape at evaluation time. */ struct multibroadcast @@ -59,7 +60,7 @@ struct multibroadcast { check_shapes{inputs, *this, true}.has_at_least(1); - auto t = inputs.at(0).type(); + auto t = inputs.at(0).type(); const auto& s0 = inputs.at(0); if(s0.ndim() < 1) @@ -69,25 +70,42 @@ struct multibroadcast if(inputs.size() == 1) { - if(s0.dynamic()) + // Symbolic 1-input mode: opt-in via a fully-symbolic output_dyn_dims attribute. + // Input may be static (bridged via to_symbolic()) or already symbolic. + // Range-based dynamic input is not allowed. + const bool symbolic_target = not output_dyn_dims.empty() and + std::all_of(output_dyn_dims.begin(), + output_dyn_dims.end(), + [](const auto& d) { return d.is_symbolic(); }); + if(not output_dyn_dims.empty() and not symbolic_target) + MIGRAPHX_THROW("MULTIBROADCAST: output_dyn_dims must be fully symbolic"); + + if(s0.dynamic() and not(symbolic_target and s0.symbolic())) MIGRAPHX_THROW( "MULTIBROADCAST: Single dynamic input shape not supported. Use two inputs."); - if(s0.ndim() > output_lens.size()) - { - MIGRAPHX_THROW("MULTIBROADCAST: input dimensions should <= output size"); - } - auto offset = output_lens.size() - s0.ndim(); - for(std::ptrdiff_t i = s0.ndim() - 1; i >= 0; i--) - { - if(output_lens[i + offset] != s0.lens()[i] and s0.lens()[i] != 1) + // Shared validation: input dims must align with target dims, with axis-1 broadcast. 
+ auto validate = [](const auto& in_dims, const auto& out_dims) { + if(in_dims.size() > out_dims.size()) + MIGRAPHX_THROW("MULTIBROADCAST: input dimensions should <= output size"); + auto offset = out_dims.size() - in_dims.size(); + for(std::ptrdiff_t i = in_dims.size() - 1; i >= 0; --i) { - MIGRAPHX_THROW("MULTIBROADCAST: input shape {" + to_string_range(s0.lens()) + - "} cannot be broadcasted to {" + to_string_range(output_lens) + - "}!"); + if(out_dims[i + offset] != in_dims[i] and in_dims[i] != 1) + MIGRAPHX_THROW("MULTIBROADCAST: input shape {" + to_string_range(in_dims) + + "} cannot be broadcasted to {" + to_string_range(out_dims) + + "}!"); } + }; + + if(symbolic_target) + { + auto s0_sym = s0.to_symbolic(); + validate(s0_sym.dyn_dims(), output_dyn_dims); + return make_bcast_shape(s0_sym, output_dyn_dims); } + validate(s0.lens(), output_lens); return make_bcast_shape(s0, output_lens); } else @@ -105,7 +123,7 @@ struct multibroadcast else { // output_lens will not be set for 2+ input version - auto bcast_lens = compute_common_lens(inputs); + auto bcast_lens = compute_common_lens(inputs); return make_bcast_shape(s0, bcast_lens); } } diff --git a/src/simplify_dyn_ops.cpp b/src/simplify_dyn_ops.cpp index 7d459d14b6f..d140b940806 100644 --- a/src/simplify_dyn_ops.cpp +++ b/src/simplify_dyn_ops.cpp @@ -154,6 +154,31 @@ struct find_static_2in_broadcasts : match::supports_dynamic_shapes } }; +/** + * Convert 2 input symbolic shape broadcast/multibroadcast into 1 input version. Mirrors the + * static analog above; fires only when the resulting shape is fully symbolic so that the + * target dimensions are known at compile time and can be carried as the 1-input op's + * `out_dyn_dims` attribute. Range-based dynamic broadcasts keep the 2-arg form for runtime + * resolution. 
+ * From: + * broadcast_op(argument_with_symbolic_shape, argument_with_symbolic_shape) + * To: + * broadcast_op(argument); broadcast_op.out_dyn_dims = symbolic_output_dims + */ +struct find_symbolic_2in_broadcasts : match::supports_dynamic_shapes +{ + auto matcher() const { return match::broadcast(match::nargs(2), match::symbolic_shape()); } + + void apply(module& m, const match::matcher_result& mr) const + { + auto ins = mr.result; + auto out_dyn_dims = ins->get_shape().dyn_dims(); + auto broadcast_op = ins->get_operator(); + broadcast_op.from_value({{"out_dyn_dims", to_value(out_dyn_dims)}}); + m.replace_instruction(ins, broadcast_op, ins->inputs().at(0)); + } +}; + /** * Simplify slice with 2 inputs to the 1 input version if inputs[1] is constant. * From: @@ -412,7 +437,7 @@ struct find_const_alloc_fill : match::supports_dynamic_shapes * From: * broadcast_for_dot(static_shape_arg, static_shape_arg) * To: - * multibroadcast(static_shape_arg); output_lens = static_broadcast_for_doted_shape + * multibroadcast(static_shape_arg); output_lens = static_broadcast_for_dotted_shape */ struct find_static_broadcast_for_dot : match::supports_dynamic_shapes { @@ -440,6 +465,29 @@ struct find_static_broadcast_for_dot : match::supports_dynamic_shapes } }; +/** + * Simplify broadcast_for_dot instructions that produce a fully symbolic shape. Mirrors the + * static analog above: the broadcast_for_dot is rewritten to a 1-input multibroadcast carrying + * the symbolic output dims as the `out_dyn_dims` attribute. 
+ * From: + * broadcast_for_dot(symbolic_shape_arg, symbolic_shape_arg) + * To: + * multibroadcast(symbolic_shape_arg); out_dyn_dims = symbolic_broadcast_for_dotted_shape + */ +struct find_symbolic_broadcast_for_dot : match::supports_dynamic_shapes +{ + auto matcher() const { return match::name("broadcast_for_dot")(match::symbolic_shape()); } + + void apply(module& m, const match::matcher_result& mr) const + { + auto ins = mr.result; + auto out_dyn_dims = ins->get_shape().dyn_dims(); + m.replace_instruction(ins, + make_op("multibroadcast", {{"out_dyn_dims", to_value(out_dyn_dims)}}), + ins->inputs().at(0)); + } +}; + /** * Simplify onehot instructions with static shape `indices` input and * a compile-time constant `depth` attribute or input. @@ -669,11 +717,13 @@ void simplify_dyn_ops::apply(module& m) const find_static_dimensions_of{}, find_const_alloc_reshapes{}, find_static_2in_broadcasts{}, + find_symbolic_2in_broadcasts{}, find_const_2in_slice{}, find_const_3in_slice{}, find_const_4in_slice{}, find_const_alloc_fill{}, find_static_broadcast_for_dot{}, + find_symbolic_broadcast_for_dot{}, find_static_onehot{}); match::find_matches(m, simplify_select_module_output_shape{}); } diff --git a/test/op_shape_test.cpp b/test/op_shape_test.cpp index e3a4f38d64b..e5491cc4296 100644 --- a/test/op_shape_test.cpp +++ b/test/op_shape_test.cpp @@ -207,6 +207,39 @@ TEST_CASE(binary_dyn_static_error) throws_shape(migraphx::make_op("add"), a_shape, b_shape); } +TEST_CASE(binary_sym_same_packed) +{ + auto n = var("n", {2, 8}); + migraphx::shape s{migraphx::shape::float_type, {dd{lit(2)}, dd{n}, dd{lit(4)}}}; + expect_shape(s, migraphx::make_op("add"), s, s); +} + +TEST_CASE(binary_sym_packed_vs_broadcasted) +{ + auto n = var("n", {2, 8}); + std::vector
dims{dd{lit(2)}, dd{n}, dd{lit(4)}}; + migraphx::shape sx{migraphx::shape::float_type, dims}; + migraphx::shape sy{migraphx::shape::float_type, dims, {lit(0), lit(4), lit(1)}}; + expect_shape(sx, migraphx::make_op("add"), sx, sy); +} + +TEST_CASE(binary_sym_nonpacked_permutation) +{ + std::vector
dims{dd{lit(4)}, dd{lit(3)}}; + migraphx::shape sx{migraphx::shape::float_type, dims, {lit(1), lit(8)}}; + migraphx::shape sy{migraphx::shape::float_type, dims, {lit(1), lit(16)}}; + auto sout = migraphx::shape::from_permutation(migraphx::shape::float_type, dims, {1, 0}); + expect_shape(sout, migraphx::make_op("mul"), sx, sy); +} + +TEST_CASE(binary_sym_with_range_dyn_error) +{ + auto n = var("n", {2, 8}); + migraphx::shape sx{migraphx::shape::float_type, {dd{lit(2)}, dd{n}, dd{lit(4)}}}; + migraphx::shape sy{migraphx::shape::float_type, {dd{2, 2}, dd{2, 8}, dd{4, 4}}}; + throws_shape(migraphx::make_op("add"), sx, sy); +} + TEST_CASE(bit_cast_typesize_mismatch) { migraphx::shape a_shape{migraphx::shape::int8_type, {1, 4, 4}}; @@ -281,6 +314,83 @@ TEST_CASE(broadcast_1in_dyn_error) throws_shape(migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", lens}}), input); } +TEST_CASE(broadcast_1in_sym_match) +{ + auto n = var("n", {2, 8}); + migraphx::shape input{migraphx::shape::float_type, {dd{lit(3)}}}; + std::vector
out{dd{n}, dd{lit(3)}, dd{lit(4)}}; + migraphx::shape expected{migraphx::shape::float_type, out, {lit(0), lit(1), lit(0)}}; + expect_shape( + expected, + migraphx::make_op("broadcast", {{"axis", 1}, {"out_dyn_dims", migraphx::to_value(out)}}), + input); +} + +TEST_CASE(broadcast_1in_sym_static_to_sym) +{ + auto n = var("n", {2, 8}); + migraphx::shape input{migraphx::shape::float_type, {3}}; + std::vector
out{dd{n}, dd{lit(3)}, dd{lit(4)}}; + migraphx::shape expected{migraphx::shape::float_type, out, {lit(0), lit(1), lit(0)}}; + expect_shape( + expected, + migraphx::make_op("broadcast", {{"axis", 1}, {"out_dyn_dims", migraphx::to_value(out)}}), + input); +} + +TEST_CASE(broadcast_1in_sym_higher_rank) +{ + auto n = var("n", {2, 8}); + migraphx::shape input{migraphx::shape::float_type, {dd{lit(4)}, dd{lit(3)}}}; + std::vector
out{dd{lit(3)}, dd{n}, dd{lit(4)}, dd{lit(3)}}; + migraphx::shape expected{migraphx::shape::float_type, out, {lit(0), lit(0), lit(3), lit(1)}}; + expect_shape( + expected, + migraphx::make_op("broadcast", {{"axis", 2}, {"out_dyn_dims", migraphx::to_value(out)}}), + input); +} + +TEST_CASE(broadcast_1in_sym_axis_out_of_range_error) +{ + auto n = var("n", {2, 8}); + migraphx::shape input{migraphx::shape::float_type, {dd{lit(3)}}}; + std::vector
out{dd{n}, dd{lit(3)}, dd{lit(4)}}; + throws_shape( + migraphx::make_op("broadcast", {{"axis", 4}, {"out_dyn_dims", migraphx::to_value(out)}}), + input); +} + +TEST_CASE(broadcast_1in_sym_size_mismatch_error) +{ + auto n = var("n", {2, 8}); + migraphx::shape input{migraphx::shape::float_type, {dd{lit(3)}, dd{lit(4)}}}; + std::vector
out{dd{n}, dd{lit(3)}}; + throws_shape( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_dyn_dims", migraphx::to_value(out)}}), + input); +} + +TEST_CASE(broadcast_1in_sym_dim_mismatch_error) +{ + auto n = var("n", {2, 8}); + migraphx::shape input{migraphx::shape::float_type, {dd{lit(5)}}}; + std::vector
out{dd{n}, dd{lit(3)}, dd{lit(4)}}; + throws_shape( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_dyn_dims", migraphx::to_value(out)}}), + input); +} + +TEST_CASE(broadcast_1in_range_dyn_with_sym_target_error) +{ + auto n = var("n", {2, 8}); + std::vector
in_dims{dd{3, 3}}; + migraphx::shape input{migraphx::shape::float_type, in_dims}; + std::vector
out{dd{n}, dd{lit(3)}, dd{lit(4)}}; + throws_shape( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_dyn_dims", migraphx::to_value(out)}}), + input); +} + TEST_CASE(broadcast_2in_static_static) { migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}}; @@ -337,6 +447,39 @@ TEST_CASE(broadcast_2in_dyn_s0_ndim_greater_than_1_error) throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input); } +TEST_CASE(broadcast_2in_sym_match) +{ + auto n = var("n", {2, 8}); + migraphx::shape a_input{migraphx::shape::float_type, {dd{lit(4)}}}; + migraphx::shape b_input{migraphx::shape::float_type, {dd{n}, dd{lit(4)}}}; + migraphx::shape expected{migraphx::shape::float_type, {dd{n}, dd{lit(4)}}, {lit(0), lit(1)}}; + expect_shape(expected, migraphx::make_op("broadcast", {{"axis", 1}}), a_input, b_input); +} + +TEST_CASE(broadcast_2in_sym_static_to_sym) +{ + auto n = var("n", {2, 8}); + migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}}; + migraphx::shape b_input{migraphx::shape::float_type, {dd{n}, dd{lit(4)}}}; + migraphx::shape expected{migraphx::shape::float_type, {dd{n}, dd{lit(4)}}, {lit(0), lit(1)}}; + expect_shape(expected, migraphx::make_op("broadcast", {{"axis", 1}}), a_input, b_input); +} + +TEST_CASE(broadcast_2in_sym_dim_mismatch_error) +{ + auto n = var("n", {2, 8}); + migraphx::shape a_input{migraphx::shape::float_type, {dd{lit(4)}}}; + migraphx::shape b_input{migraphx::shape::float_type, {dd{n}, dd{lit(8)}}}; + throws_shape(migraphx::make_op("broadcast", {{"axis", 1}}), a_input, b_input); +} + +TEST_CASE(broadcast_2in_sym_with_range_dyn_error) +{ + migraphx::shape a_input{migraphx::shape::float_type, {dd{lit(4)}}}; + migraphx::shape b_input{migraphx::shape::float_type, {{1, 4}, {4, 4}}}; + throws_shape(migraphx::make_op("broadcast", {{"axis", 1}}), a_input, b_input); +} + TEST_CASE(conv_2d_0) { migraphx::shape output{migraphx::shape::float_type, {4, 4, 1, 1}}; @@ -2058,6 +2201,77 @@ TEST_CASE(multibroadcast_1in_dyn_error_0) 
throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input); } +TEST_CASE(multibroadcast_1in_sym_match) +{ + auto n = var("n", {2, 8}); + std::vector
out{dd{lit(1)}, dd{n}, dd{lit(3)}, dd{lit(4)}}; + migraphx::shape input{migraphx::shape::float_type, out}; + expect_shape(input, + migraphx::make_op("multibroadcast", {{"out_dyn_dims", migraphx::to_value(out)}}), + input); +} + +TEST_CASE(multibroadcast_1in_sym_broadcast_axis) +{ + auto n = var("n", {2, 8}); + migraphx::shape input{migraphx::shape::float_type, + {dd{lit(1)}, dd{lit(1)}, dd{lit(3)}, dd{lit(4)}}}; + std::vector
out{dd{lit(1)}, dd{n}, dd{lit(3)}, dd{lit(4)}}; + migraphx::shape expected{migraphx::shape::float_type, out, {lit(12), lit(0), lit(4), lit(1)}}; + expect_shape(expected, + migraphx::make_op("multibroadcast", {{"out_dyn_dims", migraphx::to_value(out)}}), + input); +} + +TEST_CASE(multibroadcast_1in_sym_rank_extend) +{ + auto n = var("n", {2, 8}); + migraphx::shape input{migraphx::shape::float_type, {dd{lit(3)}, dd{lit(4)}}}; + std::vector
out{dd{lit(2)}, dd{n}, dd{lit(3)}, dd{lit(4)}}; + migraphx::shape expected{migraphx::shape::float_type, out, {lit(0), lit(0), lit(4), lit(1)}}; + expect_shape(expected, + migraphx::make_op("multibroadcast", {{"out_dyn_dims", migraphx::to_value(out)}}), + input); +} + +TEST_CASE(multibroadcast_1in_static_to_sym) +{ + // Static input + symbolic out_dyn_dims: op bridges via to_symbolic() internally. + auto n = var("n", {2, 8}); + migraphx::shape input{migraphx::shape::float_type, {1, 1, 3, 4}}; + std::vector
out{dd{lit(1)}, dd{n}, dd{lit(3)}, dd{lit(4)}}; + migraphx::shape expected{migraphx::shape::float_type, out, {lit(12), lit(0), lit(4), lit(1)}}; + expect_shape(expected, + migraphx::make_op("multibroadcast", {{"out_dyn_dims", migraphx::to_value(out)}}), + input); +} + +TEST_CASE(multibroadcast_1in_sym_mismatch_error) +{ + auto n = var("n", {2, 8}); + migraphx::shape input{migraphx::shape::float_type, {dd{lit(2)}, dd{lit(3)}}}; + std::vector
out{dd{lit(1)}, dd{n}, dd{lit(3)}}; + throws_shape(migraphx::make_op("multibroadcast", {{"out_dyn_dims", migraphx::to_value(out)}}), + input); +} + +TEST_CASE(multibroadcast_1in_range_dyn_with_sym_target_error) +{ + auto n = var("n", {2, 8}); + migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3}}}; + std::vector
out{dd{lit(1)}, dd{n}, dd{lit(3)}};
+    throws_shape(migraphx::make_op("multibroadcast", {{"out_dyn_dims", migraphx::to_value(out)}}),
+                 input);
+}
+
+TEST_CASE(multibroadcast_1in_sym_input_with_static_target_error)
+{
+    auto n = var("n", {2, 8});
+    migraphx::shape input{migraphx::shape::float_type, {dd{lit(1)}, dd{n}}};
+    std::vector<std::size_t> lens{2, 8};
+    throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input);
+}
+
 TEST_CASE(multibroadcast_2in_static_dyn0)
 {
     migraphx::shape a_shape{migraphx::shape::float_type, {4, 4}};
diff --git a/test/shape_test.cpp b/test/shape_test.cpp
index fa286d9490c..dcd3b292f34 100644
--- a/test/shape_test.cpp
+++ b/test/shape_test.cpp
@@ -24,6 +24,7 @@
 #include
 #include
+#include <migraphx/common.hpp>
 #include
 #include
 #include
@@ -2382,4 +2383,93 @@ TEST_CASE(shape_is_compatible_lens_static_vs_symbolic)
     EXPECT(not migraphx::shape::is_compatible_lens(actual2, expected));
 }
 
+TEST_CASE(make_bcast_shape_static)
+{
+    migraphx::shape input{migraphx::shape::float_type, {1, 1, 3, 4}};
+    auto out = migraphx::make_bcast_shape(input, {2, 5, 3, 4});
+    EXPECT(not out.dynamic());
+    EXPECT(out.lens() == std::vector<std::size_t>{2, 5, 3, 4});
+    EXPECT(out.strides() == std::vector<std::size_t>{0, 0, 4, 1});
+}
+
+TEST_CASE(make_bcast_shape_symbolic_same_rank_match)
+{
+    auto n = var("n", {2, 8});
+    migraphx::shape input{migraphx::shape::float_type, {dd{lit(1)}, dd{n}, dd{lit(3)}, dd{lit(4)}}};
+    auto out = migraphx::make_bcast_shape(
+        input, std::vector<dd>{dd{lit(1)}, dd{n}, dd{lit(3)}, dd{lit(4)}});
+    EXPECT(out.symbolic());
+    EXPECT(out.dyn_dims() == input.dyn_dims());
+    EXPECT(out.dyn_strides() == input.dyn_strides());
+}
+
+TEST_CASE(make_bcast_shape_symbolic_broadcast_axis)
+{
+    auto n = var("n", {2, 8});
+    migraphx::shape input{migraphx::shape::float_type,
+                          {dd{lit(1)}, dd{lit(1)}, dd{lit(3)}, dd{lit(4)}}};
+    auto out = migraphx::make_bcast_shape(
+        input, std::vector<dd>{dd{lit(1)}, dd{n}, dd{lit(3)}, dd{lit(4)}});
+    EXPECT(out.symbolic());
+    EXPECT(out.dyn_dims() == (std::vector<dd>{dd{lit(1)}, dd{n}, dd{lit(3)}, dd{lit(4)}}));
+    EXPECT(out.dyn_strides() ==
+           (std::vector<migraphx::sym::expr>{lit(12), lit(0), lit(4), lit(1)}));
+}
+
+TEST_CASE(make_bcast_shape_symbolic_rank_extend)
+{
+    auto n = var("n", {2, 8});
+    migraphx::shape input{migraphx::shape::float_type, {dd{lit(3)}, dd{lit(4)}}};
+    auto out = migraphx::make_bcast_shape(
+        input, std::vector<dd>{dd{lit(2)}, dd{n}, dd{lit(3)}, dd{lit(4)}});
+    EXPECT(out.symbolic());
+    EXPECT(out.dyn_dims() == (std::vector<dd>{dd{lit(2)}, dd{n}, dd{lit(3)}, dd{lit(4)}}));
+    EXPECT(out.dyn_strides() == (std::vector<migraphx::sym::expr>{lit(0), lit(0), lit(4), lit(1)}));
+}
+
+TEST_CASE(make_bcast_shape_static_input_via_to_symbolic)
+{
+    auto n = var("n", {2, 8});
+    migraphx::shape input{migraphx::shape::float_type, {1, 1, 3, 4}};
+    auto out = migraphx::make_bcast_shape(
+        input.to_symbolic(), std::vector<dd>{dd{lit(1)}, dd{n}, dd{lit(3)}, dd{lit(4)}});
+    EXPECT(out.symbolic());
+    EXPECT(out.dyn_dims() == (std::vector<dd>{dd{lit(1)}, dd{n}, dd{lit(3)}, dd{lit(4)}}));
+    EXPECT(out.dyn_strides() ==
+           (std::vector<migraphx::sym::expr>{lit(12), lit(0), lit(4), lit(1)}));
+}
+
+TEST_CASE(to_symbolic_static)
+{
+    migraphx::shape s{migraphx::shape::float_type, {2, 3, 4}};
+    auto sym = s.to_symbolic();
+    EXPECT(sym.symbolic());
+    EXPECT(sym.type() == s.type());
+    EXPECT(sym.dyn_dims() == (std::vector<dd>{dd{lit(2)}, dd{lit(3)}, dd{lit(4)}}));
+    EXPECT(sym.dyn_strides() == (std::vector<migraphx::sym::expr>{lit(12), lit(4), lit(1)}));
+}
+
+TEST_CASE(to_symbolic_static_with_strides)
+{
+    migraphx::shape s{migraphx::shape::float_type, {2, 3, 4}, {0, 4, 1}};
+    auto sym = s.to_symbolic();
+    EXPECT(sym.symbolic());
+    EXPECT(sym.dyn_strides() == (std::vector<migraphx::sym::expr>{lit(0), lit(4), lit(1)}));
+}
+
+TEST_CASE(to_symbolic_already_symbolic_is_identity)
+{
+    auto n = var("n", {2, 8});
+    migraphx::shape s{migraphx::shape::float_type, {dd{lit(1)}, dd{n}, dd{lit(3)}}};
+    auto sym = s.to_symbolic();
+    EXPECT(sym.symbolic());
+    EXPECT(sym == s);
+}
+
+TEST_CASE(to_symbolic_range_based_throws)
+{
+    migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}}};
+    EXPECT(test::throws([&] { s.to_symbolic(); }));
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
diff --git a/test/simplify_dyn_ops_test.cpp b/test/simplify_dyn_ops_test.cpp
index b6a770df194..d5332a80d13 100644
--- a/test/simplify_dyn_ops_test.cpp
+++ b/test/simplify_dyn_ops_test.cpp
@@ -734,6 +734,109 @@ TEST_CASE(static_broadcast_for_dot)
     EXPECT(m0 == m1);
 }
 
+TEST_CASE(symbolic_broadcast)
+{
+    using dd = migraphx::shape::dynamic_dimension;
+    using migraphx::sym::lit;
+    using migraphx::sym::var;
+    auto n = var("n", {2, 8});
+    migraphx::shape s0_s{migraphx::shape::float_type, {dd{lit(4)}}};
+    migraphx::shape s1_s{migraphx::shape::float_type, {dd{n}, dd{lit(4)}}};
+
+    migraphx::module m0;
+    {
+        auto k = m0.add_parameter("k", s0_s);
+        auto data = m0.add_parameter("data", s1_s);
+        auto bcast = m0.add_instruction(migraphx::make_op("broadcast", {{"axis", 1}}), k, data);
+        auto add_ins = m0.add_instruction(migraphx::make_op("add"), bcast, data);
+        m0.add_return({add_ins});
+    }
+    run_pass(m0);
+
+    migraphx::module m1;
+    {
+        auto k = m1.add_parameter("k", s0_s);
+        auto data = m1.add_parameter("data", s1_s);
+        std::vector<dd> out_dyn_dims{dd{n}, dd{lit(4)}};
+        auto bcast = m1.add_instruction(
+            migraphx::make_op("broadcast",
+                              {{"axis", 1}, {"out_dyn_dims", migraphx::to_value(out_dyn_dims)}}),
+            k);
+        auto add_ins = m1.add_instruction(migraphx::make_op("add"), bcast, data);
+        m1.add_return({add_ins});
+    }
+    EXPECT(m0 == m1);
+}
+
+TEST_CASE(symbolic_multibroadcast)
+{
+    using dd = migraphx::shape::dynamic_dimension;
+    using migraphx::sym::lit;
+    using migraphx::sym::var;
+    auto n = var("n", {2, 8});
+    migraphx::shape s0_s{migraphx::shape::float_type, {dd{n}, dd{lit(1)}, dd{lit(4)}}};
+    migraphx::shape s1_s{migraphx::shape::float_type, {dd{n}, dd{lit(3)}, dd{lit(4)}}};
+
+    migraphx::module m0;
+    {
+        auto a = m0.add_parameter("a", s0_s);
+        auto b = m0.add_parameter("b", s1_s);
+        auto bcast = m0.add_instruction(migraphx::make_op("multibroadcast"), a, b);
+        auto add_ins = m0.add_instruction(migraphx::make_op("add"), bcast, b);
+        m0.add_return({add_ins});
+    }
+    run_pass(m0);
+
+    migraphx::module m1;
+    {
+        auto a = m1.add_parameter("a", s0_s);
+        auto b = m1.add_parameter("b", s1_s);
+        std::vector<dd> out_dyn_dims{dd{n}, dd{lit(3)}, dd{lit(4)}};
+        auto bcast = m1.add_instruction(
+            migraphx::make_op("multibroadcast",
+                              {{"out_dyn_dims", migraphx::to_value(out_dyn_dims)}}),
+            a);
+        auto add_ins = m1.add_instruction(migraphx::make_op("add"), bcast, b);
+        m1.add_return({add_ins});
+    }
+    EXPECT(m0 == m1);
+}
+
+TEST_CASE(symbolic_broadcast_for_dot)
+{
+    using dd = migraphx::shape::dynamic_dimension;
+    using migraphx::sym::lit;
+    using migraphx::sym::var;
+    auto n = var("n", {2, 8});
+    migraphx::shape s0_s{migraphx::shape::float_type,
+                         {dd{lit(1)}, dd{lit(4)}, dd{lit(6)}, dd{lit(8)}}};
+    migraphx::shape s1_s{migraphx::shape::float_type, {dd{n}, dd{lit(4)}, dd{lit(8)}, dd{lit(10)}}};
+
+    migraphx::module m0;
+    {
+        auto a = m0.add_parameter("a", s0_s);
+        auto b = m0.add_parameter("b", s1_s);
+        auto bcast = m0.add_instruction(migraphx::make_op("broadcast_for_dot"), a, b);
+        auto dot = m0.add_instruction(migraphx::make_op("dot"), bcast, b);
+        m0.add_return({dot});
+    }
+    run_pass(m0);
+
+    migraphx::module m1;
+    {
+        auto a = m1.add_parameter("a", s0_s);
+        auto b = m1.add_parameter("b", s1_s);
+        std::vector<dd> out_dyn_dims{dd{n}, dd{lit(4)}, dd{lit(6)}, dd{lit(8)}};
+        auto bcast = m1.add_instruction(
+            migraphx::make_op("multibroadcast",
+                              {{"out_dyn_dims", migraphx::to_value(out_dyn_dims)}}),
+            a);
+        auto dot = m1.add_instruction(migraphx::make_op("dot"), bcast, b);
+        m1.add_return({dot});
+    }
+    EXPECT(m0 == m1);
+}
+
 TEST_CASE(static_onehot)
 {
     // depth as a literal