diff --git a/requirements.txt b/requirements.txt index 47cf56c7dd4..694b30b6e8b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,4 +29,4 @@ pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off -DMSGPACK_BUILD_EXAMPLES=Off -DCMAKE_POLICY_VERSION_MINIMUM=3.5 sqlite3@3.50.4 -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/rocMLIR@0f4a7090de7ed8a6235d943acd2872c925669548 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off +ROCm/rocMLIR@6671a1ee31d44e1f2f7743a26a71ffc4ce9169bf -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7624c97d3af..ef85880352c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,7 +1,7 @@ ##################################################################################### # The MIT License (MIT) # -# Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -77,6 +77,7 @@ add_library(migraphx fp8_ocp_to_fnuz.cpp fuse_attention.cpp fuse_concat.cpp + fuse_horizontal.cpp fuse_pointwise.cpp fuse_pointwise_reduce.cpp fuse_reduce.cpp diff --git a/src/adjust_allocation.cpp b/src/adjust_allocation.cpp index 3dad3b7fade..eaa3689537c 100644 --- a/src/adjust_allocation.cpp +++ b/src/adjust_allocation.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -43,7 +43,10 @@ void adjust_allocation::apply(module& m) const if(ins->get_operator().is_context_free()) continue; - auto alias_ins = instruction::get_output_alias(ins, true); + auto aliases = instruction::get_output_alias(ins, true); + if(aliases.size() != 1) + continue; + auto alias_ins = aliases.front(); if(alias_ins->name() != model.name() and alias_ins->name() != "@param") continue; // shape allocated is different from actual shape diff --git a/src/api/api.cpp b/src/api/api.cpp index 5c37462641d..901222482d1 100644 --- a/src/api/api.cpp +++ b/src/api/api.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -377,19 +377,9 @@ struct custom_operation return op.compute(std::move(ctx), std::move(output_shape), std::move(inputs)); } - std::ptrdiff_t output_alias(std::vector inputs) const + std::vector output_alias(std::vector inputs) const { - auto alias_vec = op.output_alias(std::move(inputs)); - // TODO: For now, only support one output alias - if(alias_vec.empty()) - { - return -1; - } - if(alias_vec.size() > 1) - { - MIGRAPHX_THROW("Currently, CustomOps in MIGraphX only supports one output_alias"); - } - return alias_vec.front(); + return op.output_alias(std::move(inputs)); } bool runs_on_offload_target() const { return op.runs_on_offload_target(); } diff --git a/src/common_dims.cpp b/src/common_dims.cpp index 1afe92087fd..ee38d21e9cd 100644 --- a/src/common_dims.cpp +++ b/src/common_dims.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, 
Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -97,25 +97,66 @@ static bool compute_common_dim(std::vector& cd_dims, assert(state1.get() < state2.get()); auto d2 = state2.get(); auto dims = state1.dims_for(d2); - auto n = elements(dims); auto naxes = distance(dims); + if(naxes == 0) return false; + + // Check if state1 has a remainder from previous split + bool has_remainder = (state1.rem != 1); + + // Compute the product of dimensions, adjusting for remainder if needed + auto n = elements(dims); + if(has_remainder and naxes > 0) + { + n = n / *dims.begin() * (*dims.begin() / state1.rem); + } + // If not divisible then we can't compute a common dim if((d2 % n) != 0) return false; + auto rem = d2 / n; - state1.add_multi_axes(naxes, cd_dims.size()); - state2.add_axes(rem == 1 ? naxes : naxes + 1, cd_dims.size()); + auto start_pos = cd_dims.size(); + // Add axes mappings + if(has_remainder) + { + // state1: dimension was split, keep axes together + state1.add_axes(naxes, start_pos); + // state2: axes should include the previous remainder dimension + state2.add_axes(rem == 1 ? naxes : naxes + 1, start_pos - 1); + } + else + { + // state1: separate axes for each dimension + state1.add_multi_axes(naxes, start_pos); + // state2: normal axes mapping + state2.add_axes(rem == 1 ? 
naxes : naxes + 1, start_pos); + } + + // Add dimensions to cd_dims + if(has_remainder and naxes > 0) + { + // Adjust the first dimension by dividing by the remainder + cd_dims.push_back(*dims.begin() / state1.rem); + cd_dims.insert(cd_dims.end(), std::next(dims.begin()), dims.end()); + } + else + { + cd_dims.insert(cd_dims.end(), dims.begin(), dims.end()); + } + + // Add remainder dimension if needed + if(rem != 1) + cd_dims.push_back(rem); + + // Update states state1.rem = rem; state2.rem = 1; - - cd_dims.insert(cd_dims.end(), dims.begin(), dims.end()); - if(state1.rem != 1) - cd_dims.push_back(state1.rem); - state1.next(distance(dims)); + state1.next(naxes); state2.next(); + return true; } @@ -152,6 +193,22 @@ common_dims common_dims::compute(const std::vector& dims1, return {}; } } + + // Handle case where one state has a remainder that equals the next dimension + // In this case, the dimension was already added as a remainder, we just need the axes mapping + auto handle_remaining_dimension = [&cd](common_dim_state& state) { + if(not state.is_end() and state.rem != 1 and state.get() == 1) + { + // The remainder already added to cd_dims matches this dimension + // Add a single axes mapping + state.axes_map->push_back({cd.dims.size() - 1}); + state.next(); + } + }; + + handle_remaining_dimension(state1); + handle_remaining_dimension(state2); + assert(elements(dims1) == elements(cd.dims)); return cd; } diff --git a/src/dead_code_elimination.cpp b/src/dead_code_elimination.cpp index a33db6ae608..4d5aa279218 100644 --- a/src/dead_code_elimination.cpp +++ b/src/dead_code_elimination.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -48,13 +48,11 @@ void dead_code_elimination::apply(module& m) const // Skip the last instruction if(i == last) break; - // Skip instruction with empty shape as output unless its [dynamic, builtin, undefined, + // Skip instruction with empty shape as output unless its [builtin, undefined, // identity, allocate, or tuple_type] - if((not i->get_shape().dynamic() and - (i->get_shape().elements() == 0 and - i->get_shape().type() != migraphx::shape::tuple_type)) and - not(i->name().front() == '@') and not contains({"identity", "allocate"}, i->name()) and - not i->is_undefined()) + if(i->get_shape().ndim() == 0 and not i->is_undefined() and + i->get_shape().type() != migraphx::shape::tuple_type and i->name().front() != '@' and + not contains({"identity", "allocate"}, i->name())) continue; assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last)); std::unordered_set visited; diff --git a/src/driver/perf.cpp b/src/driver/perf.cpp index c5f4f33c1fb..fdd20331163 100644 --- a/src/driver/perf.cpp +++ b/src/driver/perf.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -118,19 +118,23 @@ bool is_offload_copy_set(const program& p) { if(i.name() == "hip::copy_to_gpu") { - auto copy_arg = instruction::get_output_alias(i.inputs().front(), true); - param_ins.erase(copy_arg); + auto copy_args = instruction::get_output_alias(i.inputs().front(), true); + for(auto copy_arg : copy_args) + param_ins.erase(copy_arg); } else if(i.name() == "@return") { auto return_args = i.inputs(); - for(const auto& j : return_args) - { - auto alias_ins = instruction::get_output_alias(j, true); - if((alias_ins->name() == "@param" and param_ins.erase(alias_ins) == 0) or - (alias_ins->name() != "hip::copy_from_gpu")) + return std::all_of(return_args.begin(), return_args.end(), [&](const auto& j) { + auto aliases = instruction::get_output_alias(j, true); + return std::all_of(aliases.begin(), aliases.end(), [&](instruction_ref alias_ins) { + if(alias_ins->name() == "hip::copy_from_gpu") + return true; + if(alias_ins->name() == "@param") + return not contains(param_ins, alias_ins); return false; - } + }); + }); } } return param_ins.empty(); diff --git a/src/driver/trim.cpp b/src/driver/trim.cpp index 8158a04f846..4286e5a44ff 100644 --- a/src/driver/trim.cpp +++ b/src/driver/trim.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -37,11 +37,11 @@ inline namespace MIGRAPHX_INLINE_NS { static instruction_ref capture_arg(std::unordered_set& s, instruction_ref ins) { - auto alias = instruction::get_output_alias(ins, true); - if(alias != ins) + auto aliases = instruction::get_output_alias(ins, true); + if(aliases.size() == 1 and aliases.front() != ins) { s.insert(ins); - return capture_arg(s, alias); + return capture_arg(s, aliases.front()); } if(contains({"reshape", "contiguous"}, ins->name())) { diff --git a/src/eliminate_concat.cpp b/src/eliminate_concat.cpp index 47a095659cc..7f9399777b3 100644 --- a/src/eliminate_concat.cpp +++ b/src/eliminate_concat.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -70,11 +70,14 @@ void eliminate_concat::apply(module& m) const // Where are the allocations for the tensors to be concatenated? 
std::vector allocations; - std::transform( - ins->inputs().begin(), - std::prev(ins->inputs().end()), - std::back_inserter(allocations), - [&](instruction_ref x) { return instruction::get_output_alias(x, true); }); + std::transform(ins->inputs().begin(), + std::prev(ins->inputs().end()), + std::back_inserter(allocations), + [&](instruction_ref x) { + auto aliases = instruction::get_output_alias(x, true); + // cppcheck-suppress returnDanglingLifetime + return aliases.front(); + }); if(std::any_of(allocations.begin(), allocations.end(), [&](auto x) { return x->name() != concat_opt.allocate(); diff --git a/src/fuse_horizontal.cpp b/src/fuse_horizontal.cpp new file mode 100644 index 00000000000..4f5b934f6b3 --- /dev/null +++ b/src/fuse_horizontal.cpp @@ -0,0 +1,291 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +// --------------------------------------------------------------------------- +// Horizontal fusion framework +// +// To add a new horizontal fusion, define a plain struct that implements: +// +// std::size_t min_group_size() const — minimum group size for fusion +// bool is_candidate(instruction_ref) const — does this instruction qualify? +// auto group_key(instruction_ref) const — grouping key (equality-comparable) +// std::vector +// fuse(module&, const std::vector&, +// instruction_ref insert_pt) const +// — fuse a group, return one replacement instruction per original op +// +// Then pass an instance to fuse_horizontal_ops(). +// The framework handles scanning, grouping independent instructions by key, +// filtering inter-dependent instructions, dispatching to fuse(), and replacing +// originals with results. 
+// --------------------------------------------------------------------------- + +template +static void apply_horizontal_finder(module& m, const Finder& finder) +{ + // Collect all candidate instructions and build position map + std::vector candidates; + copy_if(iterator_for(m), std::back_inserter(candidates), [&](auto ins) { + return finder.is_candidate(ins); + }); + std::unordered_map pos; + std::size_t p = 0; + for(auto ins : iterator_for(m)) + { + pos[ins] = p++; + } + + auto pred = [&](instruction_ref x, instruction_ref y) { + if(x == y) + return true; + if(finder.group_key(x) != finder.group_key(y)) + return false; + if(pos.at(x) < pos.at(y)) + return not reaches(x, y); + return not reaches(y, x); + }; + + auto each = [&](auto start, auto last) { + auto n = std::distance(start, last); + if(n < finder.min_group_size()) + return; + + std::vector group(start, last); + // Sort by position for consistent ordering + std::sort( + group.begin(), group.end(), [&](auto a, auto b) { return pos.at(a) < pos.at(b); }); + + auto insert_pt = std::next(group.back()); + auto replacements = finder.fuse(m, group, insert_pt); + if(replacements.empty()) + return; + + assert(replacements.size() == group.size()); + + // Move outputs of the original instructions to after the new instructions + // so that replace_instruction's validity assertions hold. + std::for_each(group.begin(), group.end(), [&](auto g) { + m.move_output_instructions_after(g, replacements.back()); + }); + + migraphx::for_each(group.begin(), group.end(), replacements.begin(), [&](auto g, auto r) { + m.replace_instruction(g, r); + }); + }; + + group_by(candidates.begin(), candidates.end(), each, pred); +} + +template +static void fuse_horizontal_ops(module& m, Finders&&... 
finders) +{ + each_args([&](auto&& finder) { apply_horizontal_finder(m, finder); }, finders...); +} + +// --------------------------------------------------------------------------- +// Cross-embedding gather horizontal fusion +// +// Candidates: gather(axis=0) with 2D constant embedding table, static shapes, +// non-scalar index +// Grouping: by (embedding dimension, index type, index trailing dims) +// Fusion: concatenate embedding tables, adjust indices with offsets, +// single batched gather, slice results back +// --------------------------------------------------------------------------- + +struct gather_horizontal_fusion +{ + std::size_t min_group_size() const { return 4; } + + bool is_candidate(instruction_ref ins) const + { + if(ins->name() != "gather") + return false; + + if(ins->get_operator().to_value()["axis"].to() != 0) + return false; + + auto data = ins->inputs().at(0); + auto idx = ins->inputs().at(1); + + // Embedding must be 2D: {num_rows, embedding_dim} + if(data->get_shape().lens().size() != 2) + return false; + + // Embedding must be constant (evaluable) + if(not data->can_eval()) + return false; + + // Index must not be scalar + if(idx->get_shape().scalar() or idx->get_shape().lens().empty()) + return false; + + return true; + } + + auto group_key(instruction_ref ins) const + { + auto emb_dim = ins->inputs().at(0)->get_shape().lens().back(); + auto idx = ins->inputs().at(1); + auto idx_type = idx->get_shape().type(); + const auto& lens = idx->get_shape().lens(); + // Trailing index dims (all except first) — must match for concat on axis 0 + std::vector trailing(lens.begin() + 1, lens.end()); + return std::make_tuple(emb_dim, idx_type, std::move(trailing)); + } + + std::vector + fuse(module& m, const std::vector& gathers, instruction_ref insert_pt) const + { + auto idx_type = gathers.front()->inputs().at(1)->get_shape().type(); + + // Concatenate all embedding tables + std::vector emb_inputs(gathers.size()); + std::transform(gathers.begin(), 
gathers.end(), emb_inputs.begin(), [](auto g) { + return g->inputs().at(0); + }); + auto concat_emb = + m.insert_instruction(insert_pt, make_op("concat", {{"axis", 0}}), emb_inputs); + + // Compute cumulative embedding offsets using transform_partial_sum. + // Inclusive partial sum gives end offsets; shift right and prepend 0 + // to get start (exclusive) offsets. + std::vector cum_sizes(gathers.size()); + transform_partial_sum( + gathers.begin(), gathers.end(), cum_sizes.begin(), std::plus<>{}, [](auto g) { + return g->inputs().at(0)->get_shape().lens().front(); + }); + + // Exclusive offsets: [0, cum_sizes[0], cum_sizes[1], ...] + std::vector emb_offsets(gathers.size()); + emb_offsets[0] = 0; + std::copy(cum_sizes.begin(), std::prev(cum_sizes.end()), emb_offsets.begin() + 1); + + // Build adjusted indices (add offset to shift into concatenated table) + std::vector adjusted_idx_inputs; + adjusted_idx_inputs.reserve(gathers.size()); + + migraphx::for_each( + gathers.begin(), gathers.end(), emb_offsets.begin(), [&](auto g, auto offset) { + auto idx = g->inputs().at(1); + if(offset == 0) + { + adjusted_idx_inputs.push_back(idx); + } + else + { + auto offset_scalar = m.add_literal(literal{shape{idx_type}, {offset}}); + auto offset_broadcast = m.insert_instruction( + insert_pt, + make_op("multibroadcast", {{"out_lens", idx->get_shape().lens()}}), + offset_scalar); + auto adjusted_idx = + m.insert_instruction(insert_pt, make_op("add"), idx, offset_broadcast); + adjusted_idx_inputs.push_back(adjusted_idx); + } + }); + + // Concatenate adjusted indices + auto concat_idx = + m.insert_instruction(insert_pt, make_op("concat", {{"axis", 0}}), adjusted_idx_inputs); + + // Single batched gather + auto batched_gather = m.insert_instruction( + insert_pt, make_op("gather", {{"axis", 0}}), concat_emb, concat_idx); + + // Compute slice boundaries using partial_sum of index sizes + std::vector idx_sizes(gathers.size()); + std::transform(gathers.begin(), gathers.end(), 
idx_sizes.begin(), [](auto g) { + return g->inputs().at(1)->get_shape().lens().front(); + }); + + std::vector slice_ends(gathers.size()); + std::partial_sum(idx_sizes.begin(), idx_sizes.end(), slice_ends.begin()); + + // slice_starts = [0, slice_ends[0], slice_ends[1], ...] + std::vector slice_starts(gathers.size()); + slice_starts[0] = 0; + std::copy(slice_ends.begin(), std::prev(slice_ends.end()), slice_starts.begin() + 1); + + // Slice results back — one per original gather + std::vector results; + results.reserve(gathers.size()); + + migraphx::for_each( + slice_starts.begin(), + slice_starts.end(), + slice_ends.begin(), + [&](auto start, auto end) { + results.push_back(m.insert_instruction( + insert_pt, + make_op("slice", + {{"axes", std::vector{0}}, + {"starts", std::vector{static_cast(start)}}, + {"ends", std::vector{static_cast(end)}}}), + batched_gather)); + }); + + return results; + } +}; + +// --------------------------------------------------------------------------- +// Future: add more horizontal fusion finders here, e.g. +// +// struct pointwise_horizontal_fusion +// { +// std::size_t min_group_size() const { return 2; } +// bool is_candidate(instruction_ref ins) const { ... } +// std::string group_key(instruction_ref ins) const { ... } +// std::vector +// fuse(module& m, const std::vector& ops, +// instruction_ref insert_pt) const { ... } +// }; +// --------------------------------------------------------------------------- + +void fuse_horizontal::apply(module_pass_manager& mpm) const +{ + auto& m = mpm.get_module(); + + fuse_horizontal_ops(m, gather_horizontal_fusion{}); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/fuse_pointwise.cpp b/src/fuse_pointwise.cpp index 42c95514555..89f9a16ddfa 100644 --- a/src/fuse_pointwise.cpp +++ b/src/fuse_pointwise.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. 
+ * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -162,31 +162,6 @@ static module::with_inputs append_pointwise_module(instruction_ref ins, instruct return {std::move(pm), inputs}; } -static void move_output_instructions_after(module& m, instruction_ref src, instruction_ref dst) -{ - auto d = std::distance(src, dst); - std::vector> instructions; - fix([&](auto self, instruction_ref ins) { - for(auto output : ins->outputs()) - { - assert(m.has_instruction(output)); - if(any_of(instructions, [&](const auto& p) { return p.second == output; })) - continue; - auto i = std::distance(src, output); - if(i >= d) - continue; - instructions.emplace_back(i, output); - self(output); - } - })(src); - std::sort(instructions.begin(), instructions.end(), by(std::less<>{}, [](auto&& p) { - return p.first; - })); - auto loc = std::next(dst); - for(auto [i, ins] : instructions) - m.move_instruction(ins, loc); -} - static void replace_with_tuple(module& m, instruction_ref ins, instruction_ref rep, bool first) { @@ -232,7 +207,7 @@ merge_instruction(module_pass_manager& mpm, instruction_ref input, instruction_r mpm.get_module().insert_instruction(output, input->get_operator(), fused.inputs, {new_pm}); if(fins->get_shape().tuple_size() != output->get_shape().tuple_size()) { - move_output_instructions_after(mpm.get_module(), input, fins); + mpm.get_module().move_output_instructions_after(input, fins); replace_with_tuple(mpm.get_module(), input, fins, false); } replace_with_tuple(mpm.get_module(), output, fins, true); @@ -282,11 +257,7 @@ find_output_pointwise(const module& m, instruction_ref ins, bool multi_out) return false; if(is_dead(output)) return false; - // TODO: move_output_instructions_after doesnt handle outputs from different - // modules so only fuse from the same module - return 
std::all_of(output->outputs().begin(), - output->outputs().end(), - [&](auto out) { return m.has_instruction(out); }); + return true; }); if(outputs.size() < 2) return result; diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp index 8d1a7ff39d6..2ef2645a035 100644 --- a/src/fuse_reduce.cpp +++ b/src/fuse_reduce.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -174,16 +174,19 @@ static auto any_input(Ms... ms) return match::any_of[match::inputs()](match::any(ms...).bind("input")); } -static bool is_valid_broadcast(const instruction_ref b, const std::vector& reduce_axes) +static bool is_valid_broadcast(const instruction_ref b, std::vector reduce_axes) { - std::vector broadcast_axes; - auto bstrides = b->get_shape().strides(); + const auto& blens = b->get_shape().lens(); + const auto& bstrides = b->get_shape().strides(); + reduce_axes.erase(std::remove_if(reduce_axes.begin(), + reduce_axes.end(), + [&](size_t axis) { return blens.at(axis) == 1; }), + reduce_axes.end()); - for(size_t i = 0; i < bstrides.size(); ++i) - { - if(bstrides.at(i) == 0) - broadcast_axes.push_back(i); - } + std::vector broadcast_axes; + copy_if(range(bstrides.size()), std::back_inserter(broadcast_axes), [&](size_t i) { + return bstrides.at(i) == 0 and blens.at(i) != 1; + }); return broadcast_axes == reduce_axes; } diff --git a/src/graphviz.cpp b/src/graphviz.cpp index e2d50ac720e..3726846a994 100644 --- a/src/graphviz.cpp +++ b/src/graphviz.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -114,7 +114,7 @@ std::string get_graph_color(const instruction_ref& ins) const auto& attr = op.attributes(); bool context_free = is_context_free(op); - bool alias = op.output_alias(to_shapes(ins->inputs())) >= 0; + bool alias = not op.output_alias(to_shapes(ins->inputs())).empty(); if(ins->can_eval()) { diff --git a/src/include/migraphx/fuse_horizontal.hpp b/src/include/migraphx/fuse_horizontal.hpp new file mode 100644 index 00000000000..fcac7b3151b --- /dev/null +++ b/src/include/migraphx/fuse_horizontal.hpp @@ -0,0 +1,52 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_FUSE_HORIZONTAL_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_FUSE_HORIZONTAL_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +/** + * @brief Horizontal fusion pass for independent gather operations. + * + * Currently supports the following fusion: + * Identifies groups of independent gather(axis=0) ops that share the same + * embedding dimension and index layout, then fuses them into a single gather + * over a concatenated embedding table with offset-adjusted indices. + * The batched result is sliced back to produce the original outputs. + */ +struct MIGRAPHX_EXPORT fuse_horizontal +{ + std::string name() const { return "fuse_horizontal"; } + void apply(module_pass_manager& mpm) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif // MIGRAPHX_GUARD_MIGRAPHX_FUSE_HORIZONTAL_HPP diff --git a/src/include/migraphx/instruction.hpp b/src/include/migraphx/instruction.hpp index 0369bb6ab64..a12f6508c81 100644 --- a/src/include/migraphx/instruction.hpp +++ b/src/include/migraphx/instruction.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -140,7 +140,7 @@ struct MIGRAPHX_EXPORT instruction void finalize(context& ctx); - static instruction_ref get_output_alias(instruction_ref ins, bool shallow = false); + static std::vector get_output_alias(instruction_ref ins, bool shallow = false); void set_normalized(bool value = true); bool is_normalized() const; diff --git a/src/include/migraphx/liveness.hpp b/src/include/migraphx/liveness.hpp index 6d9715a8a10..9a3418c8307 100644 --- a/src/include/migraphx/liveness.hpp +++ b/src/include/migraphx/liveness.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -53,11 +53,12 @@ void liveness(const module& m, F f) auto add_live_variables = [&](const auto& inputs) { for(auto input : inputs) { - auto i = instruction::get_output_alias(input); + auto aliases = instruction::get_output_alias(input); // Skip if variable comes from parent - if(not m.has_instruction(i)) - continue; - live_set.insert(i); + std::copy_if(aliases.begin(), + aliases.end(), + std::inserter(live_set, live_set.end()), + [&](auto i) { return m.has_instruction(i); }); } }; add_live_variables(ins->inputs()); diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index 2f0c95da4df..36ad2fd9052 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -835,6 +835,12 @@ inline auto ndim(std::size_t n) [=](instruction_ref ins) { return ins->get_shape().ndim() == n; }); } +inline auto nelements(std::size_t n) +{ + return make_basic_pred_matcher( + [=](instruction_ref ins) { return ins->get_shape().elements() == n; }); +} + MIGRAPHX_PRED_MATCHER(not_tuple, instruction_ref ins) { return ins->get_shape().type() != shape::tuple_type; diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index d68b2683e65..470772c3fae 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -128,6 +128,8 @@ struct MIGRAPHX_EXPORT module instruction_ref move_instruction(instruction_ref src, instruction_ref dst); instruction_ref move_instructions(instruction_ref src, instruction_ref dst); + void move_output_instructions_after(instruction_ref src, instruction_ref dst); + std::vector add_instructions(const std::vector& instructions, std::unordered_map* map_ins = nullptr, diff --git a/src/include/migraphx/op/as_shape.hpp b/src/include/migraphx/op/as_shape.hpp index 162526de16a..7618451a317 100644 --- a/src/include/migraphx/op/as_shape.hpp +++ b/src/include/migraphx/op/as_shape.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -53,7 +53,7 @@ struct as_shape { return args.front().reshape(output_shape); } - std::ptrdiff_t output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } }; } // namespace op diff --git a/src/include/migraphx/op/broadcast.hpp b/src/include/migraphx/op/broadcast.hpp index 9587fedd6d2..f6da43940b0 100644 --- a/src/include/migraphx/op/broadcast.hpp +++ b/src/include/migraphx/op/broadcast.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -148,7 +148,7 @@ struct broadcast { return args[0].reshape(dyn_out.computed_shape); } - std::ptrdiff_t output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } value attributes() const { return {{"fillcolor", "#9ACD32" /* yellowgreen */}}; } }; diff --git a/src/include/migraphx/op/broadcast_for_dot.hpp b/src/include/migraphx/op/broadcast_for_dot.hpp index e3432ac0dac..96897b991a5 100644 --- a/src/include/migraphx/op/broadcast_for_dot.hpp +++ b/src/include/migraphx/op/broadcast_for_dot.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -83,7 +83,7 @@ struct broadcast_for_dot return args[0].reshape(dyn_out.computed_shape); } - std::ptrdiff_t output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } value attributes() const { return {{"fillcolor", "#9ACD32" /* yellowgreen */}}; } }; diff --git a/src/include/migraphx/op/capture.hpp b/src/include/migraphx/op/capture.hpp index efb71c5c349..fe036b304af 100644 --- a/src/include/migraphx/op/capture.hpp +++ b/src/include/migraphx/op/capture.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -66,7 +66,7 @@ struct capture return args.front(); } - std::ptrdiff_t output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } }; } // namespace op diff --git a/src/include/migraphx/op/fill.hpp b/src/include/migraphx/op/fill.hpp index 38d608b9183..ae3d8e784b4 100644 --- a/src/include/migraphx/op/fill.hpp +++ b/src/include/migraphx/op/fill.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -60,7 +60,7 @@ struct fill return args[1]; } - std::ptrdiff_t output_alias(const std::vector&) const { return 1; } + std::vector output_alias(const std::vector&) const { return {1}; } }; } // namespace op diff --git a/src/include/migraphx/op/gather.hpp b/src/include/migraphx/op/gather.hpp index 739dc06be84..4dfe051e762 100644 --- a/src/include/migraphx/op/gather.hpp +++ b/src/include/migraphx/op/gather.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -71,12 +71,14 @@ struct gather { data = data.to_dynamic(); } + const bool scalar_indices = + indices.ndim() == 1 and indices.scalar() and indices.elements() == 1; if(data.dynamic()) { auto dims = data.dyn_dims(); dims.erase(dims.begin() + axis); - if(not indices.scalar()) + if(not scalar_indices) { auto index_dims = indices.to_dynamic().dyn_dims(); dims.insert(dims.begin() + axis, index_dims.begin(), index_dims.end()); @@ -89,7 +91,7 @@ struct gather auto lens = data.lens(); lens.erase(lens.begin() + axis); - if(not indices.scalar()) + if(not scalar_indices) { auto ind_lens = indices.lens(); lens.insert(lens.begin() + axis, ind_lens.begin(), ind_lens.end()); diff --git a/src/include/migraphx/op/get_tuple_elem.hpp b/src/include/migraphx/op/get_tuple_elem.hpp index e8e24d11c53..30d5da2b7d7 100644 --- a/src/include/migraphx/op/get_tuple_elem.hpp +++ b/src/include/migraphx/op/get_tuple_elem.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. 
+ * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -67,7 +67,7 @@ struct get_tuple_elem return vec_args.at(index); } - std::ptrdiff_t output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } }; } // namespace op diff --git a/src/include/migraphx/op/identity.hpp b/src/include/migraphx/op/identity.hpp index 9d80c5bfc9c..02a53f1d308 100644 --- a/src/include/migraphx/op/identity.hpp +++ b/src/include/migraphx/op/identity.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -39,7 +39,7 @@ struct identity value attributes() const { return {{"pointwise", true}, {"point_op", "${0}"}}; } - std::ptrdiff_t output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } }; } // namespace op diff --git a/src/include/migraphx/op/load.hpp b/src/include/migraphx/op/load.hpp index 0050e715089..a37e9448c9a 100644 --- a/src/include/migraphx/op/load.hpp +++ b/src/include/migraphx/op/load.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -58,7 +58,7 @@ struct load return argument{s, args[0].data() + offset}; } lifetime get_lifetime() const { return lifetime::borrow; } - std::ptrdiff_t output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } friend std::ostream& operator<<(std::ostream& os, const load& op) { diff --git a/src/include/migraphx/op/multibroadcast.hpp b/src/include/migraphx/op/multibroadcast.hpp index 8a9d15c7d99..735a99d4b54 100644 --- a/src/include/migraphx/op/multibroadcast.hpp +++ b/src/include/migraphx/op/multibroadcast.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -115,7 +115,7 @@ struct multibroadcast { return args[0].reshape(dyn_out.computed_shape); } - std::ptrdiff_t output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } }; } // namespace op diff --git a/src/include/migraphx/op/random_uniform.hpp b/src/include/migraphx/op/random_uniform.hpp index f873ae9d313..73ae1232b65 100644 --- a/src/include/migraphx/op/random_uniform.hpp +++ b/src/include/migraphx/op/random_uniform.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -107,7 +107,7 @@ struct random_uniform return result; } - std::ptrdiff_t output_alias(const std::vector&) const { return 1; } + std::vector output_alias(const std::vector&) const { return {1}; } }; } // namespace op diff --git a/src/include/migraphx/op/reshape_lazy.hpp b/src/include/migraphx/op/reshape_lazy.hpp index 4746015053e..773f6c345b5 100644 --- a/src/include/migraphx/op/reshape_lazy.hpp +++ b/src/include/migraphx/op/reshape_lazy.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -316,7 +316,7 @@ struct reshape_lazy return args[0].reshape(dyn_out.computed_shape); } - std::ptrdiff_t output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } }; } // namespace op diff --git a/src/include/migraphx/op/scalar.hpp b/src/include/migraphx/op/scalar.hpp index 9102e476d49..005712bbe69 100644 --- a/src/include/migraphx/op/scalar.hpp +++ b/src/include/migraphx/op/scalar.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -56,7 +56,7 @@ struct scalar { return args[0].reshape(output_shape); } - std::ptrdiff_t output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } }; } // namespace op diff --git a/src/include/migraphx/op/select_module.hpp b/src/include/migraphx/op/select_module.hpp index a8064e80cfa..a87fb089941 100644 --- a/src/include/migraphx/op/select_module.hpp +++ b/src/include/migraphx/op/select_module.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -138,9 +138,9 @@ struct select_module return argument{results}; } - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; diff --git a/src/include/migraphx/op/slice.hpp b/src/include/migraphx/op/slice.hpp index 88c05269af9..fa35d640849 100644 --- a/src/include/migraphx/op/slice.hpp +++ b/src/include/migraphx/op/slice.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -502,7 +502,7 @@ struct slice } } - std::ptrdiff_t output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } }; } // namespace op diff --git a/src/include/migraphx/op/squeeze.hpp b/src/include/migraphx/op/squeeze.hpp index 6802e352feb..9ea891c2b96 100644 --- a/src/include/migraphx/op/squeeze.hpp +++ b/src/include/migraphx/op/squeeze.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -141,7 +141,7 @@ struct squeeze { return args[0].reshape(dyn_out.computed_shape); } - std::ptrdiff_t output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } }; } // namespace op diff --git a/src/include/migraphx/op/step.hpp b/src/include/migraphx/op/step.hpp index d547ac9cffe..660c0b2744f 100644 --- a/src/include/migraphx/op/step.hpp +++ b/src/include/migraphx/op/step.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -90,7 +90,7 @@ struct step return args[0].reshape(output_shape); } - std::ptrdiff_t output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } }; } // namespace op diff --git a/src/include/migraphx/op/transpose.hpp b/src/include/migraphx/op/transpose.hpp index 039bea6e4c2..bb508ecec76 100644 --- a/src/include/migraphx/op/transpose.hpp +++ b/src/include/migraphx/op/transpose.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -92,7 +92,7 @@ struct transpose return args[0].reshape(dyn_out.computed_shape); } - std::ptrdiff_t output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } }; } // namespace op diff --git a/src/include/migraphx/op/unsqueeze.hpp b/src/include/migraphx/op/unsqueeze.hpp index d62c3250b0b..06446b23ffb 100644 --- a/src/include/migraphx/op/unsqueeze.hpp +++ b/src/include/migraphx/op/unsqueeze.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -148,7 +148,7 @@ struct unsqueeze { return args[0].reshape(dyn_out.computed_shape); } - std::ptrdiff_t output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } }; } // namespace op diff --git a/src/include/migraphx/operation.hpp b/src/include/migraphx/operation.hpp index 86fbe6ecd98..82cb9f9cdd1 100644 --- a/src/include/migraphx/operation.hpp +++ b/src/include/migraphx/operation.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -79,9 +79,9 @@ struct operation * the same the `output` shape. */ argument compute(context& ctx, const shape& output, const std::vector& input) const; - /// An optional method to return which argument the output will alias. If - /// there is no aliased output then -1 can be returned. - std::ptrdiff_t output_alias(const std::vector& input) const; + /// An optional method to return which arguments the output will alias. If + /// there is no aliased output then an empty vector can be returned. + std::vector output_alias(const std::vector& input) const; /// An optional stream operator to print the operation. When this is not /// implemented, it will just print the operation's name. 
friend std::ostream& operator<<(std::ostream& os, const operation& op); @@ -407,9 +407,9 @@ auto need_normalization_op(const T& x) } template -std::ptrdiff_t output_alias_op(const T&, const std::vector&) +std::vector output_alias_op(const T&, const std::vector&) { - return -1; + return {}; } template @@ -517,7 +517,7 @@ struct MIGRAPHX_EXPORT operation // (optional) lifetime get_lifetime() const; // (optional) - std::ptrdiff_t output_alias(const std::vector& input) const; + std::vector output_alias(const std::vector& input) const; // (optional) value compile(context& ctx, const shape& output, const std::vector& input); // (optional) @@ -623,9 +623,8 @@ struct operation } template - static std::ptrdiff_t private_detail_te_default_output_alias(float, - T&& private_detail_te_self, - const std::vector& input) + static std::vector private_detail_te_default_output_alias( + float, T&& private_detail_te_self, const std::vector& input) { return detail::output_alias_op(private_detail_te_self, input); } @@ -1030,7 +1029,7 @@ struct operation return (*this).private_detail_te_get_handle().get_lifetime(); } - std::ptrdiff_t output_alias(const std::vector& input) const + std::vector output_alias(const std::vector& input) const { assert((*this).private_detail_te_handle_mem_var); return (*this).private_detail_te_get_handle().output_alias(input); @@ -1139,12 +1138,12 @@ struct operation virtual std::shared_ptr clone() const = 0; virtual const std::type_info& type() const = 0; - virtual std::string name() const = 0; - virtual bool is_context_free() const = 0; - virtual bool need_normalization() const = 0; - virtual bool has_finalize() const = 0; - virtual lifetime get_lifetime() const = 0; - virtual std::ptrdiff_t output_alias(const std::vector& input) const = 0; + virtual std::string name() const = 0; + virtual bool is_context_free() const = 0; + virtual bool need_normalization() const = 0; + virtual bool has_finalize() const = 0; + virtual lifetime get_lifetime() const = 0; + 
virtual std::vector output_alias(const std::vector& input) const = 0; virtual value compile(context& ctx, const shape& output, const std::vector& input) = 0; virtual void @@ -1229,7 +1228,7 @@ struct operation return private_detail_te_default_get_lifetime(char(0), private_detail_te_value); } - std::ptrdiff_t output_alias(const std::vector& input) const override + std::vector output_alias(const std::vector& input) const override { return private_detail_te_default_output_alias(char(0), private_detail_te_value, input); diff --git a/src/include/migraphx/ranges.hpp b/src/include/migraphx/ranges.hpp index 00b2d92dc69..c99e800ef31 100644 --- a/src/include/migraphx/ranges.hpp +++ b/src/include/migraphx/ranges.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -234,6 +234,18 @@ bool equal(R1&& r1, R2&& r2, Predicate... pred) return std::equal(r1.begin(), r1.end(), r2.begin(), r2.end(), pred...); } +template +bool equal(const std::initializer_list& r1, R2&& r2, Predicate... pred) +{ + return std::equal(r1.begin(), r1.end(), r2.begin(), r2.end(), pred...); +} + +template +bool equal(R1&& r1, const std::initializer_list& r2, Predicate... pred) +{ + return std::equal(r1.begin(), r1.end(), r2.begin(), r2.end(), pred...); +} + template auto distance(Range&& r) { diff --git a/src/include/migraphx/rewrite_reshapes.hpp b/src/include/migraphx/rewrite_reshapes.hpp index f506addc3a7..858e496c8a8 100644 --- a/src/include/migraphx/rewrite_reshapes.hpp +++ b/src/include/migraphx/rewrite_reshapes.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. 
All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -163,6 +163,9 @@ struct rewrite_reshapes if(desc.empty()) return; + if(desc.elements() != elements(dims2)) + return; + auto cdims = desc.common_dims(); auto reshape_input = [&](const auto& ins_to_insert, const auto& gdesc) { return [&](auto input) { diff --git a/src/include/migraphx/shape_transform_descriptor.hpp b/src/include/migraphx/shape_transform_descriptor.hpp index c8d42119b98..f19bd9d94ef 100644 --- a/src/include/migraphx/shape_transform_descriptor.hpp +++ b/src/include/migraphx/shape_transform_descriptor.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -27,6 +27,7 @@ #include #include +#include #include #include #include @@ -89,7 +90,8 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor optional axis = nullopt); void simplify(); std::size_t elements() const; - std::vector generate(const std::vector& input_dims = {}) const; + std::vector generate(const std::vector& input_dims = {}, + bool no_broadcast = false) const; std::set find_broadcasted_axes() const; bool has_broadcast() const; @@ -106,6 +108,8 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor std::vector> common_axes_map_from_src() const; std::vector> common_axes_map_from_dst() const; + std::vector get_dst_axes_from_src(std::size_t axis) const; + bool empty() const; std::vector lens() const; @@ -158,6 +162,10 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor MIGRAPHX_EXPORT std::vector optimize_shape_transforms(const std::vector& dims, const std::vector& ops); +// Generate the shape transforms for 
strided view +MIGRAPHX_EXPORT optional> +generate_shape_transforms_for(shape s, const std::vector& idims, std::int64_t offset); + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx #endif // MIGRAPHX_GUARD_MIGRAPHX_SHAPE_TRANSFORM_DESCRIPTOR_HPP diff --git a/src/include/migraphx/simplify_reshapes.hpp b/src/include/migraphx/simplify_reshapes.hpp index 9c02dc9c00d..3ecfada3f36 100644 --- a/src/include/migraphx/simplify_reshapes.hpp +++ b/src/include/migraphx/simplify_reshapes.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -40,6 +40,7 @@ struct MIGRAPHX_EXPORT simplify_reshapes { size_t depth = 4; bool enable_op_shape_transform_op = false; + bool enable_gather_rewrite = false; std::string name() const { return "simplify_reshapes"; } void apply(module& m) const; }; diff --git a/src/instruction.cpp b/src/instruction.cpp index b8c68a42ce9..07124db4d99 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -331,7 +331,8 @@ void instruction::replace_mod_argument(module_ref old, module_ref new_mod) bool instruction::is_undefined() const { - if(op.name() == "undefined" or (op.name() == "@literal" and this->get_literal().empty())) + if(op.name() == "undefined" or + (op.name() == "@literal" and this->get_literal().get_shape().elements() == 0)) { return true; } @@ -342,7 +343,8 @@ bool instruction::is_undefined() const else { return std::all_of(this->inputs().begin(), this->inputs().end(), [](auto arg) { - return arg->is_undefined(); + return all_of(instruction::get_output_alias(arg), + [](auto alias) { return alias->is_undefined(); }); }); } } @@ -470,14 +472,25 @@ void instruction::debug_print() const std::cout << " -> " << this->get_shape() << std::endl; } -instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow) +std::vector instruction::get_output_alias(instruction_ref ins, bool shallow) { - auto i = ins->get_operator().output_alias(to_shapes(ins->inputs())); - if(i < 0) - return ins; - if(shallow) - return ins->inputs().at(i); - return get_output_alias(ins->inputs().at(i)); + auto alias_indices = ins->get_operator().output_alias(to_shapes(ins->inputs())); + if(alias_indices.empty()) + return {ins}; + std::vector result; + for(auto i : alias_indices) + { + if(shallow) + { + result.push_back(ins->inputs().at(i)); + } + else + { + auto nested = get_output_alias(ins->inputs().at(i)); + result.insert(result.end(), nested.begin(), nested.end()); + } + } + return result; } void instruction::set_normalized(bool value) { normalized = value; } diff --git a/src/module.cpp b/src/module.cpp index 4838d241904..b2a61c0c86d 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. 
+ * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -434,6 +434,61 @@ instruction_ref module::move_instructions(instruction_ref src, instruction_ref d return src; } +void module::move_output_instructions_after(instruction_ref src, instruction_ref dst) +{ + auto d = std::distance(src, dst); + std::vector> instructions; + // When an output is in a submodule (cross-module reference), resolve it to + // the instruction in this module that owns the submodule containing the output. + // The map is lazily built on the first cross-module output encountered. + std::optional> mod_owner_map; + auto resolve_output = [&](instruction_ref output) -> instruction_ref { + if(this->has_instruction(output)) + return output; + if(not mod_owner_map) + { + mod_owner_map.emplace(); + auto r = range(std::next(src), dst); + for(auto ins : iterator_for(r)) + { + for(auto* mod : ins->module_inputs()) + { + (*mod_owner_map)[mod] = ins; + for(auto* smod : mod->get_sub_modules()) + (*mod_owner_map)[smod] = ins; + } + } + } + auto it = std::find_if(mod_owner_map->begin(), mod_owner_map->end(), [&](const auto& p) { + return p.first->has_instruction(output); + }); + if(it != mod_owner_map->end()) + return it->second; + return this->end(); + }; + fix([&](auto self, instruction_ref ins) { + for(auto output : ins->outputs()) + { + output = resolve_output(output); + if(is_end(output, this->end())) + continue; + if(any_of(instructions, [&](const auto& p) { return p.second == output; })) + continue; + auto i = std::distance(src, output); + if(i >= d) + continue; + instructions.emplace_back(i, output); + self(output); + } + })(src); + std::sort(instructions.begin(), instructions.end(), by(std::less<>{}, [](auto&& p) { + return p.first; + })); + auto loc = std::next(dst); + for(auto [i, ins] : instructions) + 
this->move_instruction(ins, loc); +} + std::vector module::add_instructions(const std::vector& instructions, std::unordered_map* map_ins, @@ -779,19 +834,24 @@ instruction_ref module::validate() const static bool is_borrowed(instruction_ref ins) { - auto alias = instruction::get_output_alias(ins, true); - if(alias == ins) + auto aliases = instruction::get_output_alias(ins, true); + if(aliases.size() == 1 and aliases.front() == ins) return false; - lifetime l = alias->get_operator().get_lifetime(); - if(l == lifetime::borrow) - return true; - return is_borrowed(alias); + return std::any_of(aliases.begin(), aliases.end(), [](instruction_ref alias) { + lifetime l = alias->get_operator().get_lifetime(); + if(l == lifetime::borrow) + return true; + return is_borrowed(alias); + }); } static bool is_global(instruction_ref ins) { - const auto& op = instruction::get_output_alias(ins)->get_operator(); - return op.name() == "@param" or op.get_lifetime() == lifetime::global; + auto aliases = instruction::get_output_alias(ins); + return std::any_of(aliases.begin(), aliases.end(), [](instruction_ref alias) { + const auto& op = alias->get_operator(); + return op.name() == "@param" or op.get_lifetime() == lifetime::global; + }); } static bool is_dangling(instruction_ref ins) { return not is_global(ins) and is_borrowed(ins); } @@ -1320,7 +1380,7 @@ module::print_py(std::ostream& os, if(ins->name() == "@literal") { os << mname << ".add_literal("; - if(ins->get_shape().elements() < 10) + if(ins->get_shape().elements() < 1024) { os << "migraphx.create_argument("; print_py_shape(os, ins->get_shape()); diff --git a/src/onnx/parse_constant_of_shape.cpp b/src/onnx/parse_constant_of_shape.cpp index 1038a11ae0d..2013f9ce859 100644 --- a/src/onnx/parse_constant_of_shape.cpp +++ b/src/onnx/parse_constant_of_shape.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. 
+ * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -80,6 +80,10 @@ struct parse_constant_of_shape : op_parser input.visit([&](auto ia) { dims.assign(ia.begin(), ia.end()); }); s = migraphx::shape{type, dims}; } + if(s.elements() == 0) + { + return info.add_instruction(make_op("undefined")); + } literal l_out{}; l_val.visit([&](auto val) { using val_type = std::remove_cv_t; diff --git a/src/onnx/parse_generic_op.cpp b/src/onnx/parse_generic_op.cpp index 023a3a41f3d..1b98caad37f 100644 --- a/src/onnx/parse_generic_op.cpp +++ b/src/onnx/parse_generic_op.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -24,6 +24,8 @@ #include #include #include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -83,13 +85,33 @@ struct parse_generic_op : op_parser std::vector args) const { auto op = parser.load(opd.op_name, info); + if(needs_contiguous(opd.op_name)) { std::transform(args.begin(), args.end(), args.begin(), [&](auto arg) { return info.make_contiguous(arg); }); } - return info.add_instruction(op, args); + + if(any_of(args, [&](const auto& arg) { return arg->get_shape().dynamic(); })) + { + return info.add_instruction(op, args); + } + + // Filter out args that have 0 elements + std::vector new_args{}; + std::copy_if(args.begin(), + args.end(), + std::back_inserter(new_args), + [&](const instruction_ref& arg) { return arg->get_shape().elements() > 0; }); + + // If all args have 0 elements, return an undefined instruction + if(new_args.empty()) 
+ { + return info.add_instruction(make_op("undefined")); + } + + return info.add_instruction(op, new_args); } }; diff --git a/src/pass_manager.cpp b/src/pass_manager.cpp index 6bbeb5e73db..02ecba772d5 100644 --- a/src/pass_manager.cpp +++ b/src/pass_manager.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -182,7 +182,7 @@ struct module_pm : module_pass_manager catch(const std::exception& e) { std::cerr << "Error " << p.name() << ": " << e.what() << std::endl; - auto clk = std::chrono::steady_clock::now().time_since_epoch().count(); + auto clk = std::chrono::steady_clock::now().time_since_epoch().count(); fs::path dirname = fs::temp_directory_path() / "migraphx"; fs::create_directories(dirname); std::string base = p.name() + std::to_string(clk) + ".mxr"; diff --git a/src/propagate_constant.cpp b/src/propagate_constant.cpp index 8b44064fa2f..a09f4e44945 100644 --- a/src/propagate_constant.cpp +++ b/src/propagate_constant.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -45,9 +45,9 @@ static bool skip_propagate(instruction_ref ins) auto&& s = ins->get_shape(); if(s.broadcasted() and s.element_space() < s.elements()) return true; - auto alias = instruction::get_output_alias(ins, true); - if(alias != ins) - return skip_propagate(alias); + auto aliases = instruction::get_output_alias(ins, true); + if(aliases.size() == 1 and aliases.front() != ins) + return skip_propagate(aliases.front()); if(ins->is_undefined()) return true; return false; diff --git a/src/replace_allocate.cpp b/src/replace_allocate.cpp index e2f01e1dc91..e4d454ac09b 100644 --- a/src/replace_allocate.cpp +++ b/src/replace_allocate.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -41,32 +42,39 @@ std::unordered_map create_output_names(const modul std::unordered_map mod_output_names; auto returns = mod.get_returns(); - std::vector outputs_alias(returns.size()); - - std::transform(returns.begin(), returns.end(), outputs_alias.begin(), [](const auto& i) { - return instruction::get_output_alias(i); - }); + // Collect all allocation aliases from each return value + std::vector alloc_aliases; + // Use a join but perhaps a tuple output parameter might be better? 
+ std::transform(returns.begin(), + returns.end(), + join_back_inserter(alloc_aliases), + [](const auto& i) { return instruction::get_output_alias(i); }); std::size_t index = 0; - if(outputs_alias.size() == 1 and mod.name().empty()) + if(mod.name().empty()) { - mod_output_names[outputs_alias.front()] = "output"; - } - // Preserve main module output buffer naming across migraphx versions - else if(mod.name() == "main") - { - for(auto ins : outputs_alias) + // Single return with empty module name: all aliases get "output" or "output_N" + if(alloc_aliases.size() == 1) { - mod_output_names[ins] = mod.name() + ":#output_" + std::to_string(index++); + mod_output_names[alloc_aliases.front()] = "output"; + } + else + { + for(auto ins : alloc_aliases) + { + mod_output_names[ins] = "output_" + std::to_string(index++); + } } } + // Preserve main module output buffer naming across migraphx versions else { - for(auto ins : outputs_alias) + for(auto ins : alloc_aliases) { mod_output_names[ins] = param_name(index++, mod.name() + ":#output_"); } } + return mod_output_names; } @@ -78,8 +86,10 @@ void insert_copy(module& m, const allocation_model& model) { if(ins->get_shape().any_of_dynamic()) continue; - auto alias = instruction::get_output_alias(ins); - if(alias->get_shape() == ins->get_shape()) + auto aliases = instruction::get_output_alias(ins); + if(std::any_of(aliases.begin(), aliases.end(), [&](instruction_ref alias) { + return alias->get_shape() == ins->get_shape(); + })) continue; auto insert_ins = std::next(ins); auto alloc = m.insert_instruction( diff --git a/src/shape.cpp b/src/shape.cpp index 678bf5b53f0..df4e0419f2f 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -390,9 +390,9 @@ std::size_t shape::ndim() const { if(this->dynamic()) { - return dyn_dims().size(); + return impl->m_dyn_dims.size(); } - return lens().size(); + return impl->m_lens.size(); } std::size_t shape::elements() const { return impl->elements(); } @@ -670,7 +670,7 @@ bool shape::computable() const { return is_computable(this->type()); } const std::vector& shape::dyn_dims() const { - if(not this->dynamic()) + if(ndim() > 0 and not this->dynamic()) { MIGRAPHX_THROW("SHAPE: dyn_dims() called on a static shape"); } diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index aac7170cb4d..13dc8c86cb1 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -330,7 +330,6 @@ struct rebase_ambiguity_resolver // Returns the axes mapping that can be used for rebase auto resolve() { - std::vector> subs_to_insert; { axes_map_t axes_map = group_axes(desc->dimensions); @@ -340,6 +339,9 @@ struct rebase_ambiguity_resolver if(shortage_axes.empty()) return axes_map; + if(try_trivial_direct_mapping()) + return regroup_axes(); + process_axis_groups(axes_map, subs_to_insert); if(shortage_axes.size() == initial_shortage_count) @@ -352,10 +354,7 @@ struct rebase_ambiguity_resolver sort_hidden_axes_groups(); sort_moved_axes_groups(); - axes_map_t regroup_axes = group_axes(desc->dimensions); - renumber_axes(regroup_axes); - - return regroup_axes; + return regroup_axes(); } private: @@ -369,6 +368,13 @@ struct rebase_ambiguity_resolver return x / y; } + axes_map_t regroup_axes() + { + axes_map_t result = group_axes(desc->dimensions); + renumber_axes(result); + return result; + } + // Identifies axes where the target dimension is larger than current subdimensions // These are "shortage" axes that need subdimensions due to ambiguous axis assignment void find_shortage_axes(const axes_map_t& axes_map) @@ -385,6 +391,72 @@ struct rebase_ambiguity_resolver initial_shortage_count = shortage_axes.size(); } + bool try_trivial_direct_mapping() + { + if(desc->lens() != *dims) + return false; + if(not std::all_of( + desc->dimensions.begin(), desc->dimensions.end(), [&](const dimension& d) { + if(d.subdimensions.empty()) + return false; + if(d.len() == 1) + return true; + if(std::any_of(d.subdimensions.begin(), + d.subdimensions.end(), + [&](const dimension::sub& s) { + if(s.origin_axis().empty()) + return false; + if(s.origin_axis().size() != 1) + return true; + if(s.len == 1) + return false; + if(s.has_hidden_axis()) + return false; + return ((*dims)[s.origin_axis().front()] != s.len); 
+ })) + return false; + if(d.subdimensions.size() == 1) + return true; + auto n1dims = std::count_if(d.subdimensions.begin(), + d.subdimensions.end(), + [](const dimension::sub& s) { return s.len == 1; }); + return n1dims + 1 == d.subdimensions.size(); + })) + return false; + std::vector axes; + for_each_subdimension(desc->dimensions, [&](auto& s) { + if(s.origin_axis().empty()) + return; + axes.push_back(s.origin_axis().front()); + }); + // TODO: Handle permutations + if(not std::is_sorted(axes.begin(), axes.end())) + return false; + for(std::size_t i : range(desc->dimensions.size())) + { + auto& dim = desc->dimensions[i]; + if(dim.subdimensions.empty()) + continue; + auto sub = std::find_if(dim.subdimensions.begin(), + dim.subdimensions.end(), + [&](const dimension::sub& s) { return s.len != 1; }); + if(sub == dim.subdimensions.end()) + sub = dim.subdimensions.begin(); + sub->expose(); + sub->axis = {i}; + + auto remove_axis = [](dimension::sub& s) { + s.axis.clear(); + s.hidden_axis.clear(); + s.len = 1; + }; + std::for_each(dim.subdimensions.begin(), sub, remove_axis); + std::for_each(std::next(sub), dim.subdimensions.end(), remove_axis); + } + shortage_axes.clear(); + return true; + } + // Processes each axis group to resolve ambiguous axis assignments // This is the core logic that fixes mismatches from reshape ambiguity // @@ -1538,48 +1610,56 @@ static std::vector find_permutation(const std::vector& // are generated from the subdimensions and steps 4-5 are generated with the // dimensions. std::vector -shape_transform_descriptor::generate(const std::vector& input_dims) const +shape_transform_descriptor::generate(const std::vector& input_dims, + bool no_broadcast) const { operation_list result; std::vector new_dims = input_dims.empty() ? 
dimensions : this->rebase(input_dims).dimensions; assert(input_dims.empty() or not new_dims.empty()); - // Need broadcast - if(std::any_of(new_dims.begin(), new_dims.end(), &is_broadcast_dim)) + if(no_broadcast) { - std::vector out_lens; - std::transform(new_dims.begin(), - new_dims.end(), - std::back_inserter(out_lens), - [](const dimension& d) { return d.len(); }); - auto startb = std::find_if_not(new_dims.begin(), new_dims.end(), &has_no_axes); - auto trailb = std::find_if_not(startb, new_dims.end(), &has_axes); - auto axis = std::distance(new_dims.begin(), startb); - auto extra_dims = axis + std::distance(trailb, new_dims.end()); - // Use broadcast instead of multibroadcast - if(std::all_of(trailb, new_dims.end(), &has_no_axes) and extra_dims > 0 and - axis < new_dims.size()) + for_each_subdimension(new_dims, &flatten_broadcasted_dim); + } + else + { + // Need broadcast + if(std::any_of(new_dims.begin(), new_dims.end(), &is_broadcast_dim)) { - result.push_back(make_op("broadcast", {{"axis", axis}, {"out_lens", out_lens}})); - new_dims.erase(trailb, new_dims.end()); - new_dims.erase(new_dims.begin(), new_dims.begin() + axis); + std::vector out_lens; + std::transform(new_dims.begin(), + new_dims.end(), + std::back_inserter(out_lens), + [](const dimension& d) { return d.len(); }); + auto startb = std::find_if_not(new_dims.begin(), new_dims.end(), &has_no_axes); + auto trailb = std::find_if_not(startb, new_dims.end(), &has_axes); + auto axis = std::distance(new_dims.begin(), startb); + auto extra_dims = axis + std::distance(trailb, new_dims.end()); + // Use broadcast instead of multibroadcast + if(std::all_of(trailb, new_dims.end(), &has_no_axes) and extra_dims > 0 and + axis < new_dims.size()) + { + result.push_back(make_op("broadcast", {{"axis", axis}, {"out_lens", out_lens}})); + new_dims.erase(trailb, new_dims.end()); + new_dims.erase(new_dims.begin(), new_dims.begin() + axis); + } + else + { + result.push_back(make_op("multibroadcast", {{"out_lens", 
out_lens}})); + } } - else + // If all the dimensions have no axes then there isnt anthing else to do + // so just clear the new_dims + if(std::all_of(new_dims.begin(), new_dims.end(), &has_no_axes)) + new_dims.clear(); + // Flatten broadcasted dimensions + for(auto& d : new_dims) { - result.push_back(make_op("multibroadcast", {{"out_lens", out_lens}})); + if(d.subdimensions.size() != 1) + continue; + flatten_broadcasted_dim(d.subdimensions.front()); } } - // If all the dimensions have no axes then there isnt anthing else to do - // so just clear the new_dims - if(std::all_of(new_dims.begin(), new_dims.end(), &has_no_axes)) - new_dims.clear(); - // Flatten broadcasted dimensions - for(auto& d : new_dims) - { - if(d.subdimensions.size() != 1) - continue; - flatten_broadcasted_dim(d.subdimensions.front()); - } // Need squeeze reshape if(std::any_of(new_dims.begin(), new_dims.end(), [](const dimension& d) { if(d.subdimensions.size() != 1) @@ -1798,6 +1878,28 @@ std::vector> shape_transform_descriptor::common_axes_ma return result; } +std::vector shape_transform_descriptor::get_dst_axes_from_src(std::size_t axis) const +{ + std::vector result; + for(auto i : range(dimensions.size())) + { + const auto& d = dimensions[i]; + auto it = std::find_if(d.subdimensions.begin(), d.subdimensions.end(), [&](auto& s) { + if(s.axis.empty()) + return false; + return s.axis.front() == axis; + }); + if(it == d.subdimensions.end()) + continue; + // If it maps to a subdimension then exit as there isn't a clear mapping + if(d.len() != it->len) + return {}; + result.push_back(i); + } + // TODO: Put it in the correct order if there is multiple axes + return result; +} + bool shape_transform_descriptor::empty() const { return dimensions.empty(); } std::vector shape_transform_descriptor::lens() const @@ -1939,5 +2041,199 @@ std::vector optimize_shape_transforms(const std::vector& return sd.generate(); } +// Replace broadcasted dimensions with size 1, and set the stride to the previous stride 
+static shape unbroadcast(const shape& s) +{ + std::vector lens = s.lens(); + std::vector strides = s.strides(); + auto stride_it = std::find_if( + s.strides().begin(), s.strides().end(), [](auto stride) { return stride != 0; }); + std::size_t prev_stride = stride_it == s.strides().end() ? 1 : *stride_it; + for(std::size_t i = 0; i < lens.size(); ++i) + { + if(strides[i] == 0) + { + lens[i] = 1; + strides[i] = prev_stride; + } + else + { + prev_stride = strides[i]; + } + } + return {s.type(), lens, strides}; +} + +static std::size_t adjust_strided_shape(shape& s, std::size_t n) +{ + auto lens = s.lens(); + auto strides = s.strides(); + + // Insert a dim of 1 so it can be used to handle steps + if(std::none_of(strides.begin(), strides.end(), [](auto stride) { return stride == 1; }) and + std::any_of(strides.begin(), strides.end(), [](auto stride) { return stride != 0; })) + { + lens.push_back(1); + strides.push_back(1); + } + + auto last_axis = std::max_element(strides.begin(), strides.end()) - strides.begin(); + auto total_elements = std::max(1, strides[last_axis] * lens[last_axis]); + // Add a dim of 1 to the front so it can handle extra elements + auto extra = n / total_elements; + if(extra > 1) + { + strides.insert(strides.begin(), total_elements); + lens.insert(lens.begin(), 1); + } + s = shape(s.type(), lens, strides); + return std::max(1, extra); +} + +template +static std::vector select_mask(const std::vector& slice_mask, + const Range& r) +{ + std::vector result; + std::transform(slice_mask.begin(), + slice_mask.end(), + r.begin(), + join_back_inserter(result), + [](std::size_t mask, std::size_t n) -> std::vector { + if(mask == 0) + return {}; + return {n}; + }); + return result; +} + +// Generate the shape transforms for strided view +optional> +generate_shape_transforms_for(shape s, const std::vector& idims, std::int64_t offset) +{ + std::vector result; + if(s.lens().empty()) + return std::nullopt; + std::size_t ielements = + 
std::accumulate(idims.begin(), idims.end(), std::size_t(1), std::multiplies<>()); + auto extra = adjust_strided_shape(s, ielements); + // TODO: Improve handling of multiple dimensions, for now just reshape to 1 dimension + if(idims.size() != 1) + { + result.push_back(make_op("reshape", {{"dims", {ielements}}})); + auto ops = generate_shape_transforms_for(s, {ielements}, offset); + if(not ops) + return std::nullopt; + result.insert(result.end(), ops->begin(), ops->end()); + return result; + } + auto pre_broadcast = unbroadcast(s); + auto perm = find_permutation(pre_broadcast); + auto iperm = invert_permutation(perm); + auto pre_transpose = reorder_shape(pre_broadcast, perm); + + std::vector start_lens; + std::adjacent_difference(pre_transpose.strides().begin(), + pre_transpose.strides().end(), + std::back_inserter(start_lens), + [](auto y, auto x) -> std::size_t { + assert(x >= y); + assert(y != 0); + if((x % y) != 0) + return 0; + return x / y; + }); + if(std::any_of(start_lens.begin(), start_lens.end(), [](auto len) { return len == 0; })) + return std::nullopt; + start_lens.front() = extra > 1 ? 
extra : pre_transpose.lens().front(); + + std::size_t nelements = + std::accumulate(start_lens.begin(), start_lens.end(), std::size_t(1), std::multiplies<>()); + + if(nelements < pre_transpose.elements() * extra) + return std::nullopt; + + std::vector start_mask(start_lens.size(), 0); + if(offset != 0) + { + shape start_shape{shape::float_type, start_lens}; + auto idx = start_shape.multi(offset); + + std::vector overhead; + std::transform(start_lens.begin(), + start_lens.end(), + pre_transpose.lens().begin(), + std::back_inserter(overhead), + [](auto start_len, auto len) { return start_len - len; }); + if(std::equal( + idx.begin(), idx.end(), overhead.begin(), overhead.end(), [](auto i, auto over) { + return i <= over; + })) + { + start_mask = reorder_dims(idx, iperm); + offset = 0; + } + } + + std::vector pre_slice_mask; + std::transform(start_lens.begin(), + start_lens.end(), + pre_transpose.lens().begin(), + std::back_inserter(pre_slice_mask), + [](auto start_len, auto len) -> std::size_t { + if(start_len == len) + return 0; + return len; + }); + auto slice_mask = reorder_dims(pre_slice_mask, iperm); + + std::vector blens = reorder_dims(start_lens, iperm); + std::transform(s.lens().begin(), + s.lens().end(), + blens.begin(), + blens.begin(), + [](auto len, auto blen) -> std::size_t { + if(blen == 1) + return len; + return blen; + }); + + std::vector ops; + ops.push_back(make_op("multibroadcast", {{"out_lens", blens}})); + ops.push_back(make_op("transpose", {{"permutation", invert_permutation(perm)}})); + ops.push_back(make_op("reshape", {{"dims", start_lens}})); + std::reverse(ops.begin(), ops.end()); + + auto desc = shape_transform_descriptor::create({nelements}, ops); + + auto end = offset + nelements; + if(offset != 0 or nelements != ielements) + { + + // If the end is out of bounds broadcast it to pad it + if(end > ielements) + { + result.push_back(make_op("broadcast", {{"axis", 1}, {"out_lens", {2, ielements}}})); + result.push_back(make_op("reshape", 
{{"dims", {2 * ielements}}})); + } + result.push_back(make_op("slice", {{"axes", {0}}, {"starts", {offset}}, {"ends", {end}}})); + } + + auto opt_ops = desc.generate(); + result.insert(result.end(), opt_ops.begin(), opt_ops.end()); + + std::vector axes = select_mask(slice_mask, range(slice_mask.size())); + + if(not axes.empty()) + { + std::vector starts = select_mask(slice_mask, start_mask); + std::vector ends = select_mask(slice_mask, s.lens()); + std::transform(ends.begin(), ends.end(), starts.begin(), ends.begin(), std::plus<>{}); + + result.push_back(make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}})); + } + return result; +} + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 5199b850f4f..3bdd80e362b 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -68,12 +68,14 @@ static auto from_int4() { return match::make_predicate_matcher([](instruction_ref start) { return fix([&](auto self, instruction_ref ins) { - auto alias = instruction::get_output_alias(ins); - if(contains({"reshape", "dequantizelinear"}, alias->name())) - return self(alias->inputs().front()); - if(alias->name() == "concat") - return all_of(alias->inputs(), self); - return alias->name() == "unpack_int4"; + auto aliases = instruction::get_output_alias(ins); + return std::any_of(aliases.begin(), aliases.end(), [&](instruction_ref alias) { + if(contains({"reshape", "dequantizelinear"}, alias->name())) + return self(alias->inputs().front()); + if(alias->name() == "concat") + return all_of(alias->inputs(), self); + return alias->name() == "unpack_int4"; + }); })(start); }); } @@ -948,11 +950,12 @@ struct find_concat_op { auto concat_lens = ins.front()->get_shape().lens(); concat_lens.erase(concat_lens.begin() + axis); + auto front_type = ins.front()->get_shape().type(); return std::all_of(ins.begin(), ins.end(), [&](auto i) { auto lens = i->get_shape().lens(); lens.erase(lens.begin() + axis); - return lens == concat_lens; + return lens == concat_lens and i->get_shape().type() == front_type; }); } @@ -1031,8 +1034,8 @@ struct find_concat_op }; auto pred = [](auto i, auto j) { return i->get_operator() == j->get_operator() and - i->inputs().size() == i->inputs().size() and - i->outputs().size() == i->outputs().size(); + i->inputs().size() == j->inputs().size() and + i->outputs().size() == j->outputs().size(); }; group_unique(ins->inputs().begin(), ins->inputs().end(), update_args, pred); if(args.size() == 1) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 51b5e199db5..2c233214ffd 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -1,7 +1,7 @@ /* * 
The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -21,15 +21,20 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ +#include +#include #include #include #include #include #include +#include +#include #include #include #include #include +#include #include #include #include @@ -40,25 +45,112 @@ #include #include #include +#include +#include +#include #include +#include +#include +#include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace { -const auto& reshaper_names() + +template +instruction_ref +insert_auto_reshape(module& m, instruction_ref ins, const Dims& dims, instruction_ref input) { - // clang-format off - static const std::unordered_set names = { - "flatten", - "reshape", - "contiguous", - "squeeze", - "unsqueeze" - }; - // clang-format on - return names; + assert(std::all_of(dims.begin(), dims.end(), [](auto i) { return i > 0; })); + if(std::equal(dims.begin(), + dims.end(), + input->get_shape().lens().begin(), + input->get_shape().lens().end())) + { + return input; + } + + auto curr_lens = input->get_shape().lens(); + // Check if we can use squeeze (removing dimensions of size 1) + if(curr_lens.size() > dims.size()) + { + // Potential squeeze - check if we're just removing 1s + std::vector axes_to_squeeze; + std::size_t target_idx = 0; + for(std::size_t curr_idx = 0; curr_idx < curr_lens.size(); ++curr_idx) + { + if(curr_lens[curr_idx] == 1) + { + axes_to_squeeze.push_back(curr_idx); + } + else + { + if(target_idx >= dims.size() or curr_lens[curr_idx] != dims[target_idx]) + { + axes_to_squeeze.clear(); + break; + } + ++target_idx; + } + } + if(not axes_to_squeeze.empty() and 
target_idx == dims.size()) + { + return m.insert_instruction( + ins, make_op("squeeze", {{"axes", axes_to_squeeze}}), input); + } + } + // Check if we can use unsqueeze (adding dimensions of size 1) + else if(curr_lens.size() < dims.size()) + { + // Potential unsqueeze - check if we're just adding 1s + std::vector axes_to_unsqueeze; + std::size_t curr_idx = 0; + for(std::size_t target_idx = 0; target_idx < dims.size(); ++target_idx) + { + if(dims[target_idx] == 1) + { + axes_to_unsqueeze.push_back(target_idx); + } + else + { + if(curr_idx >= curr_lens.size() or dims[target_idx] != curr_lens[curr_idx]) + { + axes_to_unsqueeze.clear(); + break; + } + ++curr_idx; + } + } + if(not axes_to_unsqueeze.empty() and curr_idx == curr_lens.size()) + { + return m.insert_instruction( + ins, make_op("unsqueeze", {{"axes", axes_to_unsqueeze}}), input); + } + } + + return m.insert_instruction(ins, make_op("reshape", {{"dims", dims}}), input); +} + +template +instruction_ref insert_auto_reshape(module& m, + instruction_ref ins, + const std::initializer_list& dims, + instruction_ref input) +{ + return insert_auto_reshape(m, ins, std::vector(dims), input); +} + +instruction_ref +insert_ops(module& m, instruction_ref ins, const std::vector& ops, instruction_ref input) +{ + return std::accumulate( + ops.begin(), ops.end(), input, [&](instruction_ref x, const operation& op) { + return m.insert_instruction(ins, op, x); + }); } struct find_nested_shape_transforms @@ -114,9 +206,7 @@ struct find_nested_shape_transforms auto opt_ops = optimize_shape_transforms(x->get_shape().lens(), ops); if(ops == opt_ops) return; - auto y = x; - for(const auto& op : opt_ops) - y = m.insert_instruction(ins, op, y); + auto y = insert_ops(m, ins, opt_ops, x); m.replace_instruction(ins, y); } } @@ -159,10 +249,9 @@ struct find_op_shape_transform_op auto matcher() const { - auto reshapes = match::name(shape_transform_ops()); - auto match_op = match::any_of(match::reduce(), match::pointwise()); - auto x_op = - 
match_op(match::none_of(fusable_split())); + auto reshapes = match::name(shape_transform_ops()); + auto match_op = match::any_of(match::reduce(), match::pointwise()); + auto x_op = match_op(match::none_of(fusable_split())); auto reshapes_x_op = reshapes(match::arg(0)(match::skip(reshapes())(x_op.bind("x")))); return match_op(match::any_of[match::inputs()](reshapes_x_op.bind("input"))); } @@ -243,16 +332,19 @@ struct find_op_shape_transform_op return desc.elements() == ins->get_shape().elements(); } - static std::vector generate(const shape_transform_descriptor& desc, - const shape& input_shape) + static std::vector + generate(const shape_transform_descriptor& desc, const shape& input_shape, bool no_broadcast) { if(input_shape.scalar() and input_shape.elements() == 1 and input_shape.ndim() == 1) { - return {make_op("multibroadcast", {{"out_lens", desc.lens()}})}; + auto out_lens = desc.lens(); + if(no_broadcast) + std::fill(out_lens.begin(), out_lens.end(), 1); + return {make_op("multibroadcast", {{"out_lens", out_lens}})}; } else { - return desc.generate(input_shape.lens()); + return desc.generate(input_shape.lens(), no_broadcast); } } @@ -340,9 +432,11 @@ struct find_op_shape_transform_op return; } - auto reshape_input = [&](const auto& ins_to_insert, const auto& gdesc) { - return [&](auto input) { - auto gops = generate(gdesc, input->get_shape()); + auto reshape_input = [&](const auto& ins_to_insert, + const auto& gdesc, + bool no_broadcast = false) { + return [&, no_broadcast](auto input) { + auto gops = generate(gdesc, input->get_shape(), no_broadcast); return std::accumulate( gops.begin(), gops.end(), input, [&](auto start, const auto& op) { return m.insert_instruction(ins_to_insert, op, start); @@ -366,7 +460,7 @@ struct find_op_shape_transform_op std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { if(input == input_ins) return new_input_ins; - return reshape_input(ins, desc.to_common_from_dst())(input); + return reshape_input(ins, 
desc.to_common_from_dst(), true)(input); }); // Replace old x_ins just in case it is used more than once assert(x_ins->get_shape().lens() == new_x_ins->get_shape().lens()); @@ -379,26 +473,116 @@ struct find_op_shape_transform_op } }; +struct find_slice_shape_transforms +{ + static const auto& shape_transform_ops() + { + static const std::unordered_set names = { + "reshape", + "squeeze", + "unsqueeze", + "flatten", + "transpose", + "contiguous", + "multibroadcast", + "broadcast", + }; + return names; + } + + auto matcher() const + { + auto reshapes = match::name(shape_transform_ops()); + auto slice_op = match::name("slice")(match::arg(0)(match::used_once()), + match::none_of(match::is_constant())); + return reshapes(reshapes(match::none_of[match::outputs()](reshapes())), + match::arg(0)(match::skip(reshapes())(slice_op.bind("slice")))); + } + + void apply(module& m, const match::matcher_result& mr) const + { + auto ins = mr.result; + auto slice = mr.instructions["slice"]; + auto slice_op = slice->get_operator().to_value(); + auto axes = slice_op.at("axes").to_vector(); + + if(ins->get_shape().scalar()) + return; + + std::vector ops; + auto x = ins; + while(contains(shape_transform_ops(), x->get_operator().name())) + { + ops.push_back(x->get_operator()); + x = x->inputs().front(); + } + if(x != slice) + return; + x = x->inputs().front(); + std::reverse(ops.begin(), ops.end()); + auto desc = shape_transform_descriptor::create(slice->get_shape().lens(), ops); + + std::vector new_axes; + std::transform(axes.begin(), + axes.end(), + join_back_inserter(new_axes), + [&](auto axis) -> std::vector { + auto result = desc.get_dst_axes_from_src(axis); + if(result.size() != 1) + return {}; + return result; + }); + + // Optimizes shape transforms if the slice cant be optimized + if(axes.size() != new_axes.size()) + { + auto opt_ops = desc.generate(); + auto y = insert_ops(m, ins, opt_ops, slice); + m.replace_instruction(ins, y); + return; + } + slice_op["axes"] = new_axes; + + 
auto new_desc = desc.rebase(slice->inputs().front()->get_shape().lens()); + if(new_desc.empty()) + return; + new_desc.simplify(); + + auto opt_ops = new_desc.generate(); + auto y = insert_ops(m, ins, opt_ops, x); + y = m.insert_instruction(ins, make_op("slice", slice_op), y); + m.replace_instruction(ins, y); + } +}; + struct find_nop_reshapes { auto matcher() const { - auto reshapes = reshaper_names(); - reshapes.insert("as_shape"); - reshapes.insert("broadcast"); - reshapes.insert("concat"); - reshapes.insert("convert"); - reshapes.insert("multibroadcast"); - reshapes.insert("pad"); - reshapes.insert("slice"); - reshapes.insert("step"); - reshapes.insert("transpose"); - reshapes.insert("reduce_mean"); - reshapes.insert("reduce_max"); - reshapes.insert("reduce_min"); - reshapes.insert("reduce_sum"); - reshapes.insert("reduce_prod"); - return match::name(reshapes)(match::same_shape(match::arg(0))); + // clang-format off + static const std::unordered_set names = { + "flatten", + "reshape", + "contiguous", + "squeeze", + "unsqueeze", + "as_shape", + "broadcast", + "concat", + "convert", + "multibroadcast", + "pad", + "slice", + "step", + "transpose", + "reduce_mean", + "reduce_max", + "reduce_min", + "reduce_sum", + "reduce_prod", + }; + + return match::name(names)(match::same_shape(match::arg(0))); } void apply(module& m, const match::matcher_result& mr) const @@ -790,165 +974,344 @@ struct find_nested_concat } }; -struct find_resize +struct find_gather { - auto matcher() const - { - return match::name("gather")( - match::args(match::name("reshape").bind("data"), match::is_constant().bind("ind"))); - } - - void apply(module& m, const match::matcher_result& r) const + struct arithmetic_segment { - auto ins = r.result; - auto ins_rsp = r.instructions["data"]; - auto ins_ind = r.instructions["ind"]; + int64_t base = 0; + int64_t stride = 0; + std::size_t count = 0; - // resize input shape - if(ins_rsp->get_shape().lens().size() != 1) + template + static std::vector 
from_ints(Iterator begin, Iterator end) { - return; + std::vector result(std::distance(begin, end)); + par_transform( + begin, end, result.begin(), [](auto x) { return arithmetic_segment{x, 1, 1}; }); + return result; } - // resize output shape - const auto& in_shape = ins_rsp->inputs().front()->get_shape(); - const auto& out_shape = ins->get_shape(); - // check if output shape is multiple of input shape - const auto& in_lens = in_shape.lens(); - const auto& out_lens = out_shape.lens(); - if(in_lens.size() != out_lens.size()) + template + static Iterator find_largest(Iterator start, Iterator last, OutputIterator out) { - return; + for(auto it = start; it != last;) + { + auto [seg, next_it] = find(it, last); + it = next_it; + *out = seg; + out++; + } + return last; } - // output shape must be multiple of input shape - std::vector is_multi(in_lens.size()); - std::transform( - in_lens.begin(), in_lens.end(), out_lens.begin(), is_multi.begin(), [](auto x, auto y) { - return (y % x == 0); - }); - if(not std::all_of(is_multi.begin(), is_multi.end(), [](auto b) { return b; })) + template + static Iterator find_n(Iterator start, Iterator last, std::size_t n, OutputIterator out) { - return; + for(auto it = start; it != last;) + { + if(std::distance(it, last) < n) + return it; + auto [seg, next_it] = find(it, it + n); + if(next_it != it + n) + return next_it; + it = next_it; + *out = seg; + out++; + } + return last; } - // output must be multiple of inputs - std::vector scales(in_lens.size()); - std::transform( - in_lens.begin(), in_lens.end(), out_lens.begin(), scales.begin(), [](auto x, auto y) { - return y / x; - }); + static std::vector + make_segments(const std::vector& segments, bool uniform = true) + { + std::vector result; + auto [first_seg, first_it] = find(segments.begin(), segments.end()); + result.push_back(first_seg); + // Try to find segments that are the same size + auto it = find_n(first_it, segments.end(), first_seg.count, std::back_inserter(result)); + 
if(it != segments.end()) + { + if(uniform) + return {}; + result.resize(1); + find_largest(first_it, segments.end(), std::back_inserter(result)); + } + return result; + } - // if ind is not constant, cannot optimize - std::vector vec_ind; - auto arg_ind = ins_ind->eval(); - if(arg_ind.empty()) + static std::vector shift(std::vector segments, + std::int64_t shift) { - return; + par_transform( + segments.begin(), segments.end(), segments.begin(), [&](arithmetic_segment x) { + x.base += shift; + return x; + }); + return segments; } - arg_ind.visit([&](auto v) { vec_ind.assign(v.begin(), v.end()); }); - if(not all_of(range(out_shape.elements()), [&](auto i) { - auto out_idx = out_shape.multi(i); - auto in_idx = out_idx; - std::transform(out_idx.begin(), - out_idx.end(), - scales.begin(), - in_idx.begin(), - [&](auto io, auto scale) { return io - (io % scale); }); - return vec_ind[i] == vec_ind[out_shape.index(in_idx)]; - })) + + /// Detect arithmetic segment pattern + template + static std::pair find(Iterator begin, Iterator end) { - return; + std::size_t length = std::distance(begin, end); + if(length == 0) + return std::make_pair(arithmetic_segment{}, begin); + if(length == 1) + return std::make_pair(*begin, std::next(begin)); + auto start = *begin; + auto lstride = std::next(begin)->base - start.base; + if(lstride < 0) + return std::make_pair(*begin, std::next(begin)); + auto diff = + std::adjacent_find(begin, end, [&](arithmetic_segment x, arithmetic_segment y) { + return y.base - x.base != lstride; + }); + if(diff != end) + diff++; + return std::make_pair( + arithmetic_segment{start.base, lstride, std::size_t(std::distance(begin, diff))}, + diff); } - // wrap up shapes for multibroadcast - std::vector> dim_scales; - std::transform(in_lens.begin(), - in_lens.end(), - out_lens.begin(), - std::back_inserter(dim_scales), - [](auto x, auto y) { return std::make_pair(x, y / x); }); - - std::vector in_dims; - std::vector out_dims; - for(auto& isp : dim_scales) + static 
shape make_strided_view(std::vector segments) { - in_dims.push_back(isp.first); - out_dims.push_back(isp.first * isp.second); - if(isp.first == 1 or isp.second == 1) + std::vector lens; + std::vector strides; + + do { - continue; + segments = make_segments(segments); + if(segments.empty()) + return {}; + auto seg = segments.front(); + if(seg.stride < 0) + return {}; + if(std::any_of(segments.begin(), segments.end(), [](const arithmetic_segment& seg) { + return seg.base < 0; + })) + return {}; + if(not std::all_of( + segments.begin(), segments.end(), [&](const arithmetic_segment& seg) { + return seg.stride == segments.front().stride and + seg.count == segments.front().count; + })) + return {}; + lens.push_back(seg.count); + strides.push_back(seg.stride); + } while(segments.size() > 1); + + std::reverse(lens.begin(), lens.end()); + std::reverse(strides.begin(), strides.end()); + + if(std::none_of( + strides.begin(), strides.end(), [](auto pstride) { return pstride == 1; })) + { + lens.push_back(1); + strides.push_back(1); } - out_dims.back() = isp.first; - in_dims.push_back(1); - out_dims.push_back(isp.second); + return {shape::float_type, lens, strides}; } - auto in_rsp = ins_rsp->inputs().front(); - auto rsp_data = m.insert_instruction( - ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp); - auto mb_rsp = m.insert_instruction( - ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data); - std::vector rsp_dims(out_lens.begin(), out_lens.end()); - m.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), mb_rsp); - } -}; + template + static std::optional + transform_indices(const Indices& indices, module& m, instruction_ref start) + { + auto isegments = from_ints(indices.begin(), indices.end()); + std::int64_t offset = isegments.front().base; + auto s = make_strided_view(shift(std::move(isegments), -offset)); + if(s.lens().empty() or s.elements() != indices.size()) + return std::nullopt; + auto ops = 
generate_shape_transforms_for(s, {start->get_shape().elements()}, offset); + if(not ops.has_value()) + return std::nullopt; + return insert_ops(m, std::next(start), *ops, start); + } + }; -struct find_where_op -{ + static std::vector build_flat_gather_indices(instruction_ref gather_ins, + const argument& indices_arg, + std::size_t axis_index) + { + auto data_ins = gather_ins->inputs()[0]; + auto output_dims = gather_ins->get_shape().lens(); + const auto r_in = data_ins->get_shape().lens().size(); + const auto r_idx = indices_arg.get_shape().lens().size(); + auto data_shape = data_ins->get_shape().as_standard(); + auto indices_shape = indices_arg.get_shape().as_standard(); + assert(axis_index < r_in); + + shape output_s{shape::float_type, output_dims}; // element type doesn't matter here + const auto out_n = output_s.elements(); + std::vector flat(out_n); + std::iota(flat.begin(), flat.end(), 0); + + auto indices = indices_arg.to_vector(); + + transform(flat, flat.begin(), [&](std::size_t out_lin) -> std::int64_t { + // 1) output linear -> output multi-index + auto out_multi = output_s.multi(out_lin); + + // 2) isolate the "indices" coordinates from the output coords (inserted at `axis`) + std::vector idx_multi(r_idx); + std::copy(out_multi.begin() + axis_index, + out_multi.begin() + axis_index + r_idx, + idx_multi.begin()); + + // 3) look up the actual index value (may be negative) + const std::int64_t idx_lin = indices_shape.index(idx_multi); + const std::int64_t axis_len = data_shape.lens().at(axis_index); + auto idx_val = indices.at(idx_lin); + + // Normalize negative indices into [0, axis_len) + if(idx_val < 0) + idx_val += axis_len; + + assert(idx_val >= 0 and idx_val < axis_len); + + // 4) construct corresponding INPUT multi-index + std::vector in_multi(r_in); + + // copy dims before axis + std::copy(out_multi.begin(), out_multi.begin() + axis_index, in_multi.begin()); + + // axis coordinate from indices + in_multi.at(axis_index) = idx_val; + + // copy dims 
after axis; they are shifted by r_idx in output + std::copy(out_multi.begin() + axis_index + r_idx, + out_multi.end(), + in_multi.begin() + axis_index + 1); + + // 5) map input multi-index -> flat offset in contiguous buffer + const auto in_lin = data_shape.index(in_multi); + return in_lin; + }); + + return flat; + } auto matcher() const { return match::name("gather")( - match::args(match::name("reshape")(match::arg(0)(match::name("concat").bind("data"))), - match::is_constant().bind("ind"))); + match::none_of(match::is_constant()), + match::args(match::any().bind("data"), match::is_constant().bind("indices"))); } void apply(module& m, const match::matcher_result& r) const { - auto ins = r.result; - auto concat = r.instructions["data"]; - auto ins_ind = r.instructions["ind"]; - std::vector vec_ind; - auto arg_ind = ins_ind->eval(); - arg_ind.visit([&](auto v) { vec_ind.assign(v.begin(), v.end()); }); - // ind has to be the same value - auto val = vec_ind.front(); - if(not std::all_of(vec_ind.begin(), vec_ind.end(), [&](auto v) { return (v == val); })) - { + auto ins = r.result; + auto indices_ins = r.instructions["indices"]; + auto data_ins = r.instructions["data"]; + auto gather_op = any_cast(ins->get_operator()); + const auto& dlens = data_ins->get_shape().lens(); + if(dlens.empty()) return; - } - // concat axis must be 0 - auto op = any_cast(concat->get_operator()); - if(op.axis != 0) - { + const auto axis_index = static_cast( + tune_axis(static_cast(dlens.size()), gather_op.axis, gather_op.name())); + const auto axis_len = dlens.at(axis_index); + if(axis_len == 0) return; - } - // check concat inputs, it has to be 2 and have the same shape - const auto& inputs = concat->inputs(); - if(inputs.size() != 2) - { + auto arg_ind = indices_ins->eval(); + if(arg_ind.empty()) return; - } - if(inputs.at(0)->get_shape() != inputs.at(1)->get_shape()) - { + + std::vector indices_values; + arg_ind.visit([&](auto v) { + indices_values.resize(v.size()); + 
std::transform(v.begin(), v.end(), indices_values.begin(), [](auto x) { + return static_cast(x); + }); + }); + if(indices_values.empty()) return; - } - if(inputs.at(0)->get_shape().lens() != ins_ind->get_shape().lens()) - { + + const auto& indices_shape = indices_ins->get_shape(); + if(indices_shape.elements() != indices_values.size()) return; - } - if(val) + // Scalar indices should be rewritten to a normal gather + assert(not indices_shape.scalar() or indices_shape.ndim() != 1 or + indices_shape.elements() != 1); + + // Normalize negative indices using transform + std::transform(indices_values.begin(), + indices_values.end(), + indices_values.begin(), + [axis_len](auto idx) { + if(idx < 0) + idx += static_cast(axis_len); + return idx; + }); + + // Validate all indices are in bounds + bool all_valid = + std::all_of(indices_values.begin(), indices_values.end(), [axis_len](auto idx) { + return idx >= 0 and idx < static_cast(axis_len); + }); + if(not all_valid) + return; + + // Create indices argument with normalized values + shape normalized_indices_shape{shape::int64_type, indices_shape.lens()}; + literal indices_lit(normalized_indices_shape, indices_values.begin(), indices_values.end()); + argument indices_arg = indices_lit.get_argument(); + + // Sanity check: ensure the argument shape matches + assert(indices_arg.get_shape().lens() == indices_shape.lens()); + assert(indices_arg.get_shape().elements() == indices_values.size()); + + std::optional new_ins = std::nullopt; + + if(data_ins->get_shape().ndim() == 1 and indices_ins->get_shape().ndim() == 1) { - m.replace_instruction(ins, inputs.at(0)); + new_ins = arithmetic_segment::transform_indices(indices_values, m, data_ins); } else { - m.replace_instruction(ins, inputs.at(1)); + auto data_1d = + insert_auto_reshape(m, ins, {data_ins->get_shape().elements()}, data_ins); + auto new_indices = build_flat_gather_indices(ins, indices_arg, axis_index); + new_ins = arithmetic_segment::transform_indices(new_indices, m, 
data_1d); + } + + if(not new_ins.has_value()) + return; + + auto reshaped = insert_auto_reshape(m, ins, ins->get_shape().lens(), *new_ins); + + m.replace_instruction(ins, reshaped); + } +}; + +struct find_gather_scalar +{ + auto matcher() const + { + auto scalar_indices = + match::all_of(match::scalar_shape(), match::ndim(1), match::nelements(1)); + return match::name("gather")( + match::none_of(match::is_constant()), + match::args(match::any().bind("data"), scalar_indices.bind("indices"))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto indices_ins = r.instructions["indices"]; + auto data_ins = r.instructions["data"]; + + auto new_indices = + m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0}}}), indices_ins); + auto new_gather = m.insert_instruction(ins, ins->get_operator(), data_ins, new_indices); + auto reshaped = insert_auto_reshape(m, ins, ins->get_shape().lens(), new_gather); + if(ins->get_shape().scalar() and ins->get_shape().ndim() == 1) + { + reshaped = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), reshaped); } + m.replace_instruction(ins, reshaped); } }; @@ -965,11 +1328,11 @@ struct find_reshape_cont void apply(module& m, const match::matcher_result& r) const { - auto ins = r.result; + auto ins = r.result; auto cont_input = r.instructions["input"]; - auto in_ins = r.instructions["rsp"]; + auto in_ins = r.instructions["rsp"]; - auto lens = cont_input->get_shape().lens(); + auto lens = cont_input->get_shape().lens(); std::vector dims(lens.begin(), lens.end()); if(in_ins->get_shape() != ins->get_shape()) @@ -1420,13 +1783,16 @@ struct find_flatten void simplify_reshapes::apply(module& m) const { + match::find_matches(m, find_gather_scalar{}); + dead_code_elimination{}.apply(m); + if(enable_gather_rewrite) + match::find_matches(m, find_gather{}); m.repeat_while_changes(depth, [&] { match::find_matches(m, - find_where_op{}, - find_resize{}, find_nop_reshapes{}, find_flatten{}, 
find_reshape_cont{}, + find_slice_shape_transforms{}, find_nested_shape_transforms{}, find_concat_slice{}, find_concat_transpose{}, diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index 3188b00563f..d1d8930b956 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -119,8 +119,10 @@ struct splitter if(rins == rm->begin()) return; // We want to know what instructions are live after the split instruction - auto ins = instruction::get_output_alias(std::prev(rins)); - if(not contains(splits, ins)) + auto aliases = instruction::get_output_alias(std::prev(rins)); + if(not std::any_of(aliases.begin(), aliases.end(), [&](instruction_ref ins) { + return contains(splits, ins); + })) return; std::copy_if(live_set.begin(), live_set.end(), diff --git a/src/targets/cpu/copy.cpp b/src/targets/cpu/copy.cpp index 4c4af2b7164..9d917ca532c 100644 --- a/src/targets/cpu/copy.cpp +++ b/src/targets/cpu/copy.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -54,9 +54,9 @@ struct cpu_copy : reduce_dims_base, auto_register_op return result.reshape(output_shape); } - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; diff --git a/src/targets/cpu/gather.cpp b/src/targets/cpu/gather.cpp index 40bc556b961..877874ba934 100644 --- a/src/targets/cpu/gather.cpp +++ b/src/targets/cpu/gather.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -77,9 +77,9 @@ struct cpu_gather : auto_register_op return args.back(); } - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; diff --git a/src/targets/cpu/include/migraphx/cpu/dnnl.hpp b/src/targets/cpu/include/migraphx/cpu/dnnl.hpp index b05cad85246..e650e5084a1 100644 --- a/src/targets/cpu/include/migraphx/cpu/dnnl.hpp +++ b/src/targets/cpu/include/migraphx/cpu/dnnl.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -296,9 +296,9 @@ struct dnnl_op : auto_register_op return execute(ctx, args); } - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } value compile(context&, const shape& output_shape, std::vector inputs) { diff --git a/src/targets/cpu/include/migraphx/cpu/pointwise.hpp b/src/targets/cpu/include/migraphx/cpu/pointwise.hpp index ece5498c839..e0ca9c9771f 100644 --- a/src/targets/cpu/include/migraphx/cpu/pointwise.hpp +++ b/src/targets/cpu/include/migraphx/cpu/pointwise.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -360,9 +360,9 @@ struct cpu_unary : reduce_dims_base, auto_register_op> return result.reshape(output_shape); } - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; @@ -401,9 +401,9 @@ struct cpu_binary : reduce_dims_base, auto_register_op> return result.reshape(output_shape); } - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; diff --git a/src/targets/gpu/compile_miopen.cpp b/src/targets/gpu/compile_miopen.cpp index 583601bdda1..c726f8be686 100644 --- a/src/targets/gpu/compile_miopen.cpp +++ b/src/targets/gpu/compile_miopen.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * 
Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -52,9 +52,9 @@ struct miopen_op return op.compute_shape(inputs); } - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; MIGRAPHX_REGISTER_OP(miopen_op); diff --git a/src/targets/gpu/compile_ops.cpp b/src/targets/gpu/compile_ops.cpp index 50c252ac08f..7882b5f14f7 100644 --- a/src/targets/gpu/compile_ops.cpp +++ b/src/targets/gpu/compile_ops.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -29,7 +29,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -72,9 +74,9 @@ struct precompile_op return op.compute_shape(inputs, mods); } - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; @@ -255,12 +257,22 @@ struct compile_plan }); auto bench_ins = bench_mm->add_instruction( cr->ins->get_operator(), bench_ins_inputs, cr->ins->module_inputs()); + bench_mm->add_return({bench_ins}); cr->replace.replace(*bench_mm, bench_ins); // do dead code elimination - run_passes(*bench_mm, {dead_code_elimination{}}); - // by default, measure runtime with bundle of 1 benchmark config, - // repeat 20 times - auto t = time_program(*ctx, bench_prog, 
cr->replace.fill_map, 1, 20); + run_passes(*bench_mm, + { + eliminate_identity{}, + dead_code_elimination{}, + memory_coloring{"hip::allocate"}, + }); + if(trace_level > 2) + std::cout << bench_prog << std::endl; + auto t = time_program(*ctx, + bench_prog, + cr->replace.fill_map, + /* bundle */ 10, + /* nrun */ 20); if(trace_level > 1) std::cout << t << "ms" << std::endl; return t; diff --git a/src/targets/gpu/fuse_ops.cpp b/src/targets/gpu/fuse_ops.cpp index 4d04394ce69..ac925269fbc 100644 --- a/src/targets/gpu/fuse_ops.cpp +++ b/src/targets/gpu/fuse_ops.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -251,9 +251,9 @@ struct miopen_fusion return pack(f(self.ops, "ops")); } - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } value compile(context& ctx, const shape&, std::vector inputs) @@ -383,9 +383,9 @@ struct miopen_conv_bias } shape get_workspace(context& ctx) { return fp.get_workspace(ctx); } - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; MIGRAPHX_REGISTER_OP(miopen_conv_bias) @@ -431,9 +431,9 @@ struct miopen_conv_bias_relu } shape get_workspace(context& ctx) { return fp.get_workspace(ctx); } - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; MIGRAPHX_REGISTER_OP(miopen_conv_bias_relu) diff --git 
a/src/targets/gpu/include/migraphx/gpu/abs.hpp b/src/targets/gpu/include/migraphx/gpu/abs.hpp index 1a9f4b87800..edeedaa6641 100644 --- a/src/targets/gpu/include/migraphx/gpu/abs.hpp +++ b/src/targets/gpu/include/migraphx/gpu/abs.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -52,9 +52,9 @@ struct miopen_abs argument compute(context& ctx, const shape& output_shape, const std::vector& args) const; void finalize(context&, const shape&, const std::vector&); - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; #endif diff --git a/src/targets/gpu/include/migraphx/gpu/argmax.hpp b/src/targets/gpu/include/migraphx/gpu/argmax.hpp index e05678fa873..6870832f76d 100644 --- a/src/targets/gpu/include/migraphx/gpu/argmax.hpp +++ b/src/targets/gpu/include/migraphx/gpu/argmax.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -48,9 +48,9 @@ struct hip_argmax std::string name() const { return "gpu::argmax"; } shape compute_shape(const std::vector& inputs) const; argument compute(context& ctx, const shape&, const std::vector& args) const; - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; diff --git a/src/targets/gpu/include/migraphx/gpu/argmin.hpp b/src/targets/gpu/include/migraphx/gpu/argmin.hpp index 071eb525e7f..dcb9f62b4b0 100644 --- a/src/targets/gpu/include/migraphx/gpu/argmin.hpp +++ b/src/targets/gpu/include/migraphx/gpu/argmin.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -48,9 +48,9 @@ struct hip_argmin std::string name() const { return "gpu::argmin"; } shape compute_shape(const std::vector& inputs) const; argument compute(context& ctx, const shape&, const std::vector& args) const; - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; diff --git a/src/targets/gpu/include/migraphx/gpu/code_object_op.hpp b/src/targets/gpu/include/migraphx/gpu/code_object_op.hpp index 8186767289f..11d86dee5ff 100644 --- a/src/targets/gpu/include/migraphx/gpu/code_object_op.hpp +++ b/src/targets/gpu/include/migraphx/gpu/code_object_op.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. 
All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -72,9 +72,9 @@ struct code_object_op { return output_arg < 0 ? n + output_arg : output_arg; } - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return get_output_arg(shapes.size()); + return {static_cast(get_output_arg(shapes.size()))}; } friend std::ostream& operator<<(std::ostream& os, const code_object_op& op) diff --git a/src/targets/gpu/include/migraphx/gpu/compile_hipblaslt.hpp b/src/targets/gpu/include/migraphx/gpu/compile_hipblaslt.hpp index 380fafa44ba..05bb19f9102 100644 --- a/src/targets/gpu/include/migraphx/gpu/compile_hipblaslt.hpp +++ b/src/targets/gpu/include/migraphx/gpu/compile_hipblaslt.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -57,9 +57,9 @@ struct hipblaslt_op return op.compute_shape(inputs); } - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; MIGRAPHX_REGISTER_OP(hipblaslt_op); diff --git a/src/targets/gpu/include/migraphx/gpu/context.hpp b/src/targets/gpu/include/migraphx/gpu/context.hpp index db5bb7373f0..a816d1d3b38 100644 --- a/src/targets/gpu/include/migraphx/gpu/context.hpp +++ b/src/targets/gpu/include/migraphx/gpu/context.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. 
+ * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -89,6 +89,8 @@ struct hip_device hipStream_t get() { + if(external_stream != nullptr) + return external_stream; if(not enabled(MIGRAPHX_ENABLE_NULL_STREAM{})) { setup(); @@ -143,8 +145,33 @@ struct hip_device } #endif + void set_external_stream(hipStream_t ext_stream) + { + if(external_stream == ext_stream) + return; + external_stream = ext_stream; +#if MIGRAPHX_USE_MIOPEN + if(mihandle != nullptr) + miopenSetStream(mihandle.get(), ext_stream); +#endif +#if MIGRAPHX_USE_ROCBLAS + if(rbhandle != nullptr) + rocblas_set_stream(rbhandle.get(), ext_stream); +#endif + } + + bool has_external_stream() const { return external_stream != nullptr; } + void wait() const { + if(external_stream != nullptr) + { + setup(); + auto status = hipStreamSynchronize(external_stream); + if(status != hipSuccess) + MIGRAPHX_THROW("Failed to wait: " + hip_error(status)); + return; + } if(s == nullptr) return; setup(); @@ -172,6 +199,7 @@ struct hip_device private: std::size_t id = 0; shared s = nullptr; + hipStream_t external_stream = nullptr; #if MIGRAPHX_USE_MIOPEN shared mihandle = nullptr; #endif @@ -333,17 +361,34 @@ struct context void wait_for(any_ptr queue) { - auto status = hipEventRecord(begin_event.get(), queue.get()); - if(status != hipSuccess) - MIGRAPHX_THROW("Failed to record: " + hip_error(status)); - - get_stream().wait(begin_event.get()); + if(get_stream().has_external_stream()) + return; + if(queue.unsafe_get() == nullptr) + return; + auto* ext = queue.get(); + if(ext == nullptr) + { + auto status = hipEventRecord(begin_event.get(), ext); + if(status != hipSuccess) + MIGRAPHX_THROW("Failed to record: " + hip_error(status)); + get_stream().wait(begin_event.get()); + } + else + { + get_stream().set_external_stream(ext); + } } void 
finish_on(any_ptr queue) { + if(get_stream().has_external_stream()) + { + get_stream().set_external_stream(nullptr); + return; + } + if(queue.unsafe_get() == nullptr) + return; get_stream().record(finish_event.get()); - auto status = hipStreamWaitEvent(queue.get(), finish_event.get(), 0); if(status != hipSuccess) MIGRAPHX_THROW("Failed to wait on event: " + hip_error(status)); diff --git a/src/targets/gpu/include/migraphx/gpu/convolution.hpp b/src/targets/gpu/include/migraphx/gpu/convolution.hpp index 1bca5573f71..a25d6052812 100644 --- a/src/targets/gpu/include/migraphx/gpu/convolution.hpp +++ b/src/targets/gpu/include/migraphx/gpu/convolution.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -339,9 +339,9 @@ struct miopen_convolution #endif } - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; #endif diff --git a/src/targets/gpu/include/migraphx/gpu/fixed_pad.hpp b/src/targets/gpu/include/migraphx/gpu/fixed_pad.hpp index ad78787c062..70e1b89c4b8 100644 --- a/src/targets/gpu/include/migraphx/gpu/fixed_pad.hpp +++ b/src/targets/gpu/include/migraphx/gpu/fixed_pad.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -48,9 +48,9 @@ struct hip_fixed_pad std::string name() const { return "gpu::fixed_pad"; } shape compute_shape(std::vector inputs) const; argument compute(context& ctx, const shape&, const std::vector& args) const; - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; diff --git a/src/targets/gpu/include/migraphx/gpu/gemm.hpp b/src/targets/gpu/include/migraphx/gpu/gemm.hpp index 332de8bf090..7ada7dd05bc 100644 --- a/src/targets/gpu/include/migraphx/gpu/gemm.hpp +++ b/src/targets/gpu/include/migraphx/gpu/gemm.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -121,9 +121,9 @@ struct rocblas_gemm return args.back(); } - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } void finalize(context& ctx, const shape& output_shape, const std::vector& input_shapes) diff --git a/src/targets/gpu/include/migraphx/gpu/hip.hpp b/src/targets/gpu/include/migraphx/gpu/hip.hpp index 055594252c2..d04e81e218c 100644 --- a/src/targets/gpu/include/migraphx/gpu/hip.hpp +++ b/src/targets/gpu/include/migraphx/gpu/hip.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -113,7 +113,7 @@ struct hip_fill gpu_fill(ctx, args.front(), value); return args.front(); } - std::ptrdiff_t output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } }; struct hip_sync_stream @@ -135,11 +135,11 @@ struct hip_sync_stream return args.front(); } - std::ptrdiff_t output_alias(const std::vector& args) const + std::vector output_alias(const std::vector& args) const { if(args.empty()) - return -1; - return 0; + return {}; + return {0}; } }; @@ -165,11 +165,11 @@ struct hip_copy_to_gpu // Associate the input since it was registered with hip return {result.get_shape(), [input, result]() mutable { return result.data(); }}; } - std::ptrdiff_t output_alias(const std::vector& args) const + std::vector output_alias(const std::vector& args) const { if(args.size() == 1) - return -1; - return 1; + return {}; + return {1}; } }; @@ -198,11 +198,11 @@ struct hip_copy_from_gpu copy_from_gpu(ctx, input, args[1]); return args[1]; } - std::ptrdiff_t output_alias(const std::vector& args) const + std::vector output_alias(const std::vector& args) const { if(args.size() == 1) - return -1; - return 1; + return {}; + return {1}; } }; @@ -224,7 +224,7 @@ struct hip_copy gpu_copy(ctx, args[0], result); return args[1]; } - std::ptrdiff_t output_alias(const std::vector&) const { return 1; } + std::vector output_alias(const std::vector&) const { return {1}; } }; MIGRAPHX_GPU_EXPORT void diff --git a/src/targets/gpu/include/migraphx/gpu/hip_gemm.hpp b/src/targets/gpu/include/migraphx/gpu/hip_gemm.hpp index e87ce9edcee..cfe5275d7f6 100644 --- a/src/targets/gpu/include/migraphx/gpu/hip_gemm.hpp +++ b/src/targets/gpu/include/migraphx/gpu/hip_gemm.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. 
All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -113,9 +113,9 @@ struct MIGRAPHX_GPU_EXPORT hip_gemm return args.back(); } - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } void finalize(context& ctx, const shape& output_shape, const std::vector& input_shapes) diff --git a/src/targets/gpu/include/migraphx/gpu/logsoftmax.hpp b/src/targets/gpu/include/migraphx/gpu/logsoftmax.hpp index 5ea23ee2724..919d3afec34 100644 --- a/src/targets/gpu/include/migraphx/gpu/logsoftmax.hpp +++ b/src/targets/gpu/include/migraphx/gpu/logsoftmax.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -47,9 +47,9 @@ struct hip_logsoftmax shape compute_shape(const std::vector& inputs) const; argument compute(context& ctx, const shape& output_shape, const std::vector& args) const; - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; diff --git a/src/targets/gpu/include/migraphx/gpu/loop.hpp b/src/targets/gpu/include/migraphx/gpu/loop.hpp index 792c84b74f8..ea6eb75e769 100644 --- a/src/targets/gpu/include/migraphx/gpu/loop.hpp +++ b/src/targets/gpu/include/migraphx/gpu/loop.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. 
+ * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -53,9 +53,9 @@ struct hip_loop const std::vector& mods, const std::function( module_ref&, const std::unordered_map&)>& run) const; - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; diff --git a/src/targets/gpu/include/migraphx/gpu/lrn.hpp b/src/targets/gpu/include/migraphx/gpu/lrn.hpp index 8ccda7bba6a..e075fdfceb5 100644 --- a/src/targets/gpu/include/migraphx/gpu/lrn.hpp +++ b/src/targets/gpu/include/migraphx/gpu/lrn.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -50,9 +50,9 @@ struct miopen_lrn argument compute(context& ctx, const shape& output_shape, const std::vector& args) const; void finalize(context&, const shape&, const std::vector&); - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; #endif diff --git a/src/targets/gpu/include/migraphx/gpu/mlir.hpp b/src/targets/gpu/include/migraphx/gpu/mlir.hpp index d1f19c1e8ef..aeb2e71f0ca 100644 --- a/src/targets/gpu/include/migraphx/gpu/mlir.hpp +++ b/src/targets/gpu/include/migraphx/gpu/mlir.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -54,6 +54,7 @@ struct MIGRAPHX_GPU_EXPORT mlir_code_object }; MIGRAPHX_GPU_EXPORT bool is_reduce(const instruction& ins); +MIGRAPHX_GPU_EXPORT void adjust_param_shapes(module& m, const std::vector& inputs); MIGRAPHX_GPU_EXPORT mlir_code_object compile_mlir(const context& migraphx_ctx, module m, diff --git a/src/targets/gpu/include/migraphx/gpu/multinomial.hpp b/src/targets/gpu/include/migraphx/gpu/multinomial.hpp index c44d4808233..0946ffb075b 100644 --- a/src/targets/gpu/include/migraphx/gpu/multinomial.hpp +++ b/src/targets/gpu/include/migraphx/gpu/multinomial.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -46,9 +46,9 @@ struct hip_multinomial shape compute_shape(std::vector inputs) const; argument compute(context& ctx, const shape& output_shape, const std::vector& args) const; - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; diff --git a/src/targets/gpu/include/migraphx/gpu/nonzero.hpp b/src/targets/gpu/include/migraphx/gpu/nonzero.hpp index cfc7e78dbe3..a47eb138c75 100644 --- a/src/targets/gpu/include/migraphx/gpu/nonzero.hpp +++ b/src/targets/gpu/include/migraphx/gpu/nonzero.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -49,9 +49,9 @@ struct hip_nonzero shape compute_shape(std::vector inputs) const; argument compute(context& ctx, const shape& output_shape, const std::vector& args) const; - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; diff --git a/src/targets/gpu/include/migraphx/gpu/oper.hpp b/src/targets/gpu/include/migraphx/gpu/oper.hpp index 13ac11a3d48..6ad559e4c0f 100644 --- a/src/targets/gpu/include/migraphx/gpu/oper.hpp +++ b/src/targets/gpu/include/migraphx/gpu/oper.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -72,9 +72,9 @@ struct device_base : oper return {s0.type(), s0.lens()}; } - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; diff --git a/src/targets/gpu/include/migraphx/gpu/pooling.hpp b/src/targets/gpu/include/migraphx/gpu/pooling.hpp index 7f6722b1130..0ea49b44b28 100644 --- a/src/targets/gpu/include/migraphx/gpu/pooling.hpp +++ b/src/targets/gpu/include/migraphx/gpu/pooling.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -50,9 +50,9 @@ struct miopen_pooling void finalize(context&, const shape&, const std::vector&); argument compute(context& ctx, const shape& output_shape, const std::vector& args) const; - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; #endif diff --git a/src/targets/gpu/include/migraphx/gpu/prefix_scan_sum.hpp b/src/targets/gpu/include/migraphx/gpu/prefix_scan_sum.hpp index cca8efd6057..12c7c2e1957 100644 --- a/src/targets/gpu/include/migraphx/gpu/prefix_scan_sum.hpp +++ b/src/targets/gpu/include/migraphx/gpu/prefix_scan_sum.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -67,9 +67,9 @@ struct hip_prefix_scan_sum : oper return args[1]; } - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; diff --git a/src/targets/gpu/include/migraphx/gpu/reduce_op.hpp b/src/targets/gpu/include/migraphx/gpu/reduce_op.hpp index 10f3dcf8445..835cbe03da8 100644 --- a/src/targets/gpu/include/migraphx/gpu/reduce_op.hpp +++ b/src/targets/gpu/include/migraphx/gpu/reduce_op.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -65,9 +65,9 @@ struct reduce_op : oper return args[1]; } - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } reduce_op() {} diff --git a/src/targets/gpu/include/migraphx/gpu/reverse.hpp b/src/targets/gpu/include/migraphx/gpu/reverse.hpp index 8ef82523526..7c5e3aab5f5 100644 --- a/src/targets/gpu/include/migraphx/gpu/reverse.hpp +++ b/src/targets/gpu/include/migraphx/gpu/reverse.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -49,9 +49,9 @@ struct hip_reverse shape compute_shape(std::vector inputs) const; argument compute(context& ctx, const shape& output_shape, const std::vector& args) const; - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; diff --git a/src/targets/gpu/include/migraphx/gpu/rnn_variable_seq_lens.hpp b/src/targets/gpu/include/migraphx/gpu/rnn_variable_seq_lens.hpp index 7d811192da0..e29859287c3 100644 --- a/src/targets/gpu/include/migraphx/gpu/rnn_variable_seq_lens.hpp +++ b/src/targets/gpu/include/migraphx/gpu/rnn_variable_seq_lens.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -49,9 +49,9 @@ struct hip_rnn_var_sl_shift_sequence shape compute_shape(std::vector inputs) const; argument compute(context& ctx, const shape& output_shape, const std::vector& args) const; - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; @@ -69,9 +69,9 @@ struct hip_rnn_var_sl_shift_output shape compute_shape(std::vector inputs) const; argument compute(context& ctx, const shape& output_shape, const std::vector& args) const; - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; @@ -88,9 +88,9 @@ struct hip_rnn_var_sl_last_output std::string name() const { return "gpu::" + op.name(); } shape compute_shape(std::vector inputs) const; argument compute(context& ctx, const shape&, const std::vector& args) const; - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; diff --git a/src/targets/gpu/include/migraphx/gpu/topk.hpp b/src/targets/gpu/include/migraphx/gpu/topk.hpp index f1df9d469e5..e07a7c4c24e 100644 --- a/src/targets/gpu/include/migraphx/gpu/topk.hpp +++ b/src/targets/gpu/include/migraphx/gpu/topk.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -49,9 +49,9 @@ struct hip_topk shape compute_shape(std::vector inputs) const; argument compute(context& ctx, const shape& output_shape, const std::vector& args) const; - std::ptrdiff_t output_alias(const std::vector& shapes) const + std::vector output_alias(const std::vector& shapes) const { - return shapes.size() - 1; + return {shapes.size() - 1}; } }; diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp index 7fc49f1ca50..69d38e6c58e 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -51,39 +52,108 @@ static module create_pointwise_module(module_ref in_mod) pw_mod.add_parameter(any_cast(param->get_operator()).parameter, shape{param->get_shape().type()}); } - auto return_args = pw_mod.add_instructions( - in_mod, - &map_ins, - [](module& m, - instruction_ref ins, - const operation& op, - const std::vector& inputs, - const std::vector& mod_args) -> instruction_ref { - if(op.name() == "multibroadcast" and inputs.front()->name() == "@literal") - return inputs.front(); - else - return m.insert_instruction(ins, op, inputs, mod_args); - }); + auto return_args = + pw_mod.add_instructions(in_mod, + &map_ins, + [](module& m, + instruction_ref ins, + const operation& op, + const std::vector& inputs, + const std::vector& mod_args) -> instruction_ref { + auto out_aliases = op.output_alias(to_shapes(inputs)); + if(out_aliases.size() 
== 1) + return inputs.at(out_aliases[0]); + else + return m.insert_instruction(ins, op, inputs, mod_args); + }); pw_mod.add_return(return_args); return pw_mod; } +static code_object_op +compile_pointwise_module(context& ctx, const std::vector& inputs, module_ref mod) +{ + operation cop; + auto pw_mod = create_pointwise_module(mod); + if(any_of(mod->get_parameters(), [&](instruction_ref param) { + if(param->outputs().size() != 1) + return false; + return equal(instruction::get_output_alias(param->outputs().front(), /* shallow */ true), + {param}); + })) + { + auto mod2 = *mod; + adjust_param_shapes(mod2, inputs); + auto names = mod2.get_parameter_names(); + std::sort(names.begin(), names.end()); + std::vector new_shapes; + std::transform(names.begin(), + names.end(), + std::back_inserter(new_shapes), + [&](const std::string& name) { + auto param = mod2.get_parameter(name); + auto output_path = get_output_path(param); + auto it = std::adjacent_find( + output_path.begin(), + output_path.end(), + [&](instruction_ref, instruction_ref output) { + return not equal(instruction::get_output_alias(output), {param}); + }); + return (*it)->get_shape(); + }); + std::copy(inputs.begin() + new_shapes.size(), inputs.end(), std::back_inserter(new_shapes)); + cop = compile_pointwise(ctx, new_shapes, &pw_mod); + } + else + { + cop = compile_pointwise(ctx, inputs, &pw_mod); + } + auto co = any_cast(cop); + co.expected_inputs = inputs; + return co; +} + +static instruction_ref find_final_split(instruction_ref split_ins) +{ + auto output_path = get_output_path(split_ins); + auto it = std::adjacent_find( + output_path.begin(), output_path.end(), [&](instruction_ref input, instruction_ref output) { + if(contains({"reshape", "squeeze", "unsqueeze", "transpose"}, output->name())) + return false; + if(contains({"add", "mul"}, output->name())) + { + auto aux = *std::find_if(output->inputs().begin(), + output->inputs().end(), + [&](instruction_ref i) { return i != input; }); + 
if(aux->can_eval()) + return false; + auto aliases = instruction::get_output_alias(aux); + return aliases.size() == 1 and aliases[0]->name() != "@param"; + } + return true; + }); + return *it; +} + struct mlir_compiler : compiler { std::vector names() const { return {"gpu::mlir_op"}; } operation compile_op(context&, const std::vector&, const value&) const { return {}; } - std::optional input_is_param(const instruction_ref& ins) const { - auto cur = instruction::get_output_alias(ins); - while(contains({"reshape", "contiguous"}, cur->name())) + auto aliases = instruction::get_output_alias(ins); + for(auto cur : aliases) { - cur = instruction::get_output_alias(cur->inputs().at(0)); - } - if(cur->name() == "@param") - { - return cur; + while(contains({"reshape", "contiguous"}, cur->name())) + { + auto nested_aliases = instruction::get_output_alias(cur->inputs().at(0)); + cur = nested_aliases.front(); + } + if(cur->name() == "@param") + { + return cur; + } } return nullopt; } @@ -149,9 +219,8 @@ struct mlir_compiler : compiler auto input_args = ins->inputs(); // remove alloc buffer input_args.pop_back(); - auto split_ins = std::prev(pointwise_ins); - std::array mod_splits; - mod_splits = smod->split(input_args, {split_ins}); + auto split_ins = find_final_split(gemm_like_ins); + std::array mod_splits = smod->split(input_args, {split_ins}); auto dot_mlir_inputs = to_shapes(mod_splits[0].inputs); // add alloc for the gemm output dot_mlir_inputs.push_back(mod_splits[0].mod.get_output_shapes().front()); @@ -166,10 +235,8 @@ struct mlir_compiler : compiler pw_shapes.push_back(shape{mod_splits[1].mod.get_output_shapes()}); } assert(pw_shapes.back() == ins->get_shape()); - auto pw_mod = create_pointwise_module(&mod_splits[1].mod); - auto cop2 = compile_pointwise(ctx, pw_shapes, &pw_mod); - std::vector cops = {cop1, - mlir_code_object{any_cast(cop2)}}; + auto cop2 = compile_pointwise_module(ctx, pw_shapes, &mod_splits[1].mod); + std::vector cops = {cop1, mlir_code_object{cop2}}; 
return insert(cops, mod_splits, ins, split_ins); } auto cr = insert(compile_mlir(ctx, *smod, to_shapes(ins->inputs()), solution)); diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp index 0e24c3cc936..9b660beb407 100644 --- a/src/targets/gpu/lowering.cpp +++ b/src/targets/gpu/lowering.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -197,11 +197,14 @@ struct miopen_apply void insert_fill(instruction_ref ins, value v) const { - instruction_ref alloc = instruction::get_output_alias(ins, true); - if(alloc == ins) + auto aliases = instruction::get_output_alias(ins, true); + if(aliases.size() == 1 and aliases.front() == ins) return; - auto fill = mod->insert_instruction(ins, make_op("hip::fill", {{"value", v}}), alloc); - instruction::replace_argument(ins, alloc, fill); + for(instruction_ref alloc : aliases) + { + auto fill = mod->insert_instruction(ins, make_op("hip::fill", {{"value", v}}), alloc); + instruction::replace_argument(ins, alloc, fill); + } } instruction_ref insert_custom_op(instruction_ref ins, const value& attrs) const diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 9ca5c35330c..16df2e56ebe 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -699,16 +699,18 @@ struct mlir_program static bool input_is_unpack_fp4(instruction_ref ins) { - ins = instruction::get_output_alias(ins); - if(ins->name() == "reshape") - { - return input_is_unpack_fp4(ins->inputs().front()); - } - if(ins->name() == "unpack_fp4") - { - return true; - } - return false; + auto aliases = instruction::get_output_alias(ins); + return std::any_of(aliases.begin(), aliases.end(), [](instruction_ref alias) { + if(alias->name() == "reshape") + { + return input_is_unpack_fp4(alias->inputs().front()); + } + if(alias->name() == "unpack_fp4") + { + return true; + } + return false; + }); } static shape make_fp4_unpacked_shape(shape s) @@ -1129,7 +1131,7 @@ bool is_module_fusible(const module& m, const context& migraphx_ctx, const value return mlirIsModuleFusible(mp.mmodule.get(), make_mlir_string_ref(*solution.if_string())); } -static void adjust_param_shapes(module& m, const std::vector& inputs) +void adjust_param_shapes(module& m, const std::vector& inputs) { auto names = m.get_parameter_names(); std::sort(names.begin(), names.end()); diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index de2da57ebc4..2cbcb1e4db8 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -198,7 +199,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti // workaround for rocBLAS unsupported error when using uint8 in quant_dot, quant_convolution & pooling eliminate_data_type{{migraphx::shape::uint8_type}, shape::float_type, {"quant_convolution", "quant_dot", "pooling"}}, eliminate_data_type{unsupported_types, shape::type_t::float_type}, - simplify_reshapes{}, + simplify_reshapes{.enable_gather_rewrite = true}, eliminate_identity{}, eliminate_pad{}, dead_code_elimination{}, @@ -211,6 +212,8 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti optimize_module{}, layout_convolution{.channels_last = enabled(MIGRAPHX_ENABLE_NHWC{})}, dead_code_elimination{}, + fuse_horizontal{}, + dead_code_elimination{}, prefuse_ops{}, dead_code_elimination{}, eliminate_data_type{{migraphx::shape::fp8e4m3fnuz_type, migraphx::shape::fp8e5m2fnuz_type}, shape::float_type, unsupported_fp8fnuz_ops}, diff --git a/test/algorithm.cpp b/test/algorithm.cpp index 47d671b4763..d7ec1924136 100644 --- a/test/algorithm.cpp +++ b/test/algorithm.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/api/test_custom_op.cpp b/test/api/test_custom_op.cpp index b4d4c0da0c3..e6224ae6ca2 100644 --- a/test/api/test_custom_op.cpp +++ b/test/api/test_custom_op.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. 
All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -133,10 +133,10 @@ struct identity_custom_op final : migraphx::experimental_custom_op_base return inputs.back(); } - virtual std::vector output_alias(migraphx::shapes) const override { return {0, 1}; } + virtual std::vector output_alias(migraphx::shapes) const override { return {0}; } }; -TEST_CASE(run_custom_op_with_invalid_output_alias) +TEST_CASE(run_custom_op_with_output_alias) { identity_custom_op i_op; migraphx::register_experimental_custom_op(i_op); @@ -147,11 +147,8 @@ TEST_CASE(run_custom_op_with_invalid_output_alias) migraphx::shape s{migraphx_shape_float_type, {12}}; migraphx::module m = p.get_main_module(); auto x = m.add_parameter("x", s); - auto i_ins = m.add_instruction(migraphx::operation("identity_custom_op"), {x}); - migraphx_test_private_disable_exception_catch(true); - EXPECT(test::throws( - [&] { p.compile(migraphx::target("ref")); }, - "Currently, CustomOps in MIGraphX only supports one output_alias")); + m.add_instruction(migraphx::operation("identity_custom_op"), {x}); + p.compile(migraphx::target("ref")); } int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/common_dims.cpp b/test/common_dims.cpp index 416e7495255..2436db340bc 100644 --- a/test/common_dims.cpp +++ b/test/common_dims.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -83,6 +83,22 @@ TEST_CASE(common4) EXPECT(cd.axes_map2 == axes_map{{0}, {1, 2}, {3, 4}}); } +TEST_CASE(common5) +{ + auto cd = migraphx::common_dims::compute({3, 8, 5}, {12, 10}); + EXPECT(cd.dims == std::vector{3, 4, 2, 5}); + EXPECT(cd.axes_map1 == axes_map{{0}, {1, 2}, {3}}); + EXPECT(cd.axes_map2 == axes_map{{0, 1}, {1, 2}}); +} + +TEST_CASE(common6) +{ + auto cd = migraphx::common_dims::compute({12, 10}, {3, 8, 5}); + EXPECT(cd.dims == std::vector{3, 4, 2, 5}); + EXPECT(cd.axes_map1 == axes_map{{0, 1}, {1, 2}}); + EXPECT(cd.axes_map2 == axes_map{{0}, {1, 2}, {3}}); +} + TEST_CASE(common_same_dims) { auto cd = migraphx::common_dims::compute({{2, 32, 4}}, {64, 2, 2}); diff --git a/test/dead_code_elimination_test.cpp b/test/dead_code_elimination_test.cpp index 23073aae47d..501045f90fc 100644 --- a/test/dead_code_elimination_test.cpp +++ b/test/dead_code_elimination_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -298,4 +298,17 @@ TEST_CASE(empty_literal) EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1)); } +TEST_CASE(zero_element_shape_eliminated) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape zero_elem_shape{migraphx::shape::float_type, {2, 0, 4}}; + auto x = mm->add_parameter("x", zero_elem_shape); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}); + mm->add_instruction(migraphx::make_op("neg"), x); + mm->add_return({y}); + run_pass(p); + EXPECT(std::none_of(mm->begin(), mm->end(), [](auto&& ins) { return ins.name() == "neg"; })); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/eliminate_concat_test.cpp b/test/eliminate_concat_test.cpp index 4fe59c0bac7..c7fcf7859da 100644 --- a/test/eliminate_concat_test.cpp +++ b/test/eliminate_concat_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -123,7 +123,7 @@ struct simple_op { return args.at(0); } - int output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } }; template diff --git a/test/eval_test.cpp b/test/eval_test.cpp index 0435cf61634..018e601ef9a 100644 --- a/test/eval_test.cpp +++ b/test/eval_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -65,7 +65,7 @@ struct id_ctx_op return {}; return inputs.front(); } - int output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } }; struct id_ctx_final_op @@ -88,7 +88,7 @@ struct id_ctx_final_op return {}; return inputs.front(); } - int output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } }; struct reverse_pass diff --git a/test/fuse_horizontal_test.cpp b/test/fuse_horizontal_test.cpp new file mode 100644 index 00000000000..fb2f59ac1cf --- /dev/null +++ b/test/fuse_horizontal_test.cpp @@ -0,0 +1,516 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static void run_pass(migraphx::module& m) +{ + migraphx::run_passes(m, {migraphx::fuse_horizontal{}, migraphx::dead_code_elimination{}}); +} + +// 4 gathers with same embedding dim → should fuse into 1 batched gather +TEST_CASE(gather_horiz_fusion_basic) +{ + migraphx::module m1; + { + auto emb1 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {3, 2}}, 0)); + auto emb2 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 2}}, 1)); + auto emb3 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 2}}, 2)); + auto emb4 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {5, 2}}, 3)); + + auto idx1 = m1.add_parameter("idx1", {migraphx::shape::int32_type, {2}}); + auto idx2 = m1.add_parameter("idx2", {migraphx::shape::int32_type, {3}}); + auto idx3 = m1.add_parameter("idx3", {migraphx::shape::int32_type, {1}}); + auto idx4 = m1.add_parameter("idx4", {migraphx::shape::int32_type, {2}}); + + auto g1 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb1, idx1); + auto g2 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb2, idx2); + auto g3 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb3, idx3); + auto g4 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb4, idx4); + + // Combine all outputs so every gather stays live through DCE + m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), + std::vector{g1, g2, g3, g4}); + } + run_pass(m1); + + migraphx::module m2; + { + // Embedding literals (added first → pushed to front → end 
up at the back of no-dep list) + auto emb1 = + m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {3, 2}}, 0)); + auto emb2 = + m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 2}}, 1)); + auto emb3 = + m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 2}}, 2)); + auto emb4 = + m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {5, 2}}, 3)); + + // Parameters (added second → in middle of no-dep list) + auto idx1 = m2.add_parameter("idx1", {migraphx::shape::int32_type, {2}}); + auto idx2 = m2.add_parameter("idx2", {migraphx::shape::int32_type, {3}}); + auto idx3 = m2.add_parameter("idx3", {migraphx::shape::int32_type, {1}}); + auto idx4 = m2.add_parameter("idx4", {migraphx::shape::int32_type, {2}}); + + // Offset literals (added last → pushed to very front of no-dep list, + // matching order of add_literal calls inside the pass's fuse loop) + auto offset2 = m2.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {std::size_t(3)}}); + auto offset3 = m2.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {std::size_t(7)}}); + auto offset4 = m2.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {std::size_t(9)}}); + + // Concatenated embedding table: [3+4+2+5, 2] = [14, 2] + auto concat_emb = + m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), + std::vector{emb1, emb2, emb3, emb4}); + + // Adjust indices with cumulative offsets + auto bc2 = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), offset2); + auto adj_idx2 = m2.add_instruction(migraphx::make_op("add"), idx2, bc2); + + auto bc3 = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1}}}), offset3); + auto adj_idx3 = m2.add_instruction(migraphx::make_op("add"), idx3, bc3); + + auto bc4 = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2}}}), offset4); + 
auto adj_idx4 = m2.add_instruction(migraphx::make_op("add"), idx4, bc4); + + // Concatenated adjusted indices: [2+3+1+2] = [8] + auto concat_idx = m2.add_instruction( + migraphx::make_op("concat", {{"axis", 0}}), + std::vector{idx1, adj_idx2, adj_idx3, adj_idx4}); + + // Single batched gather + auto bg = + m2.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), concat_emb, concat_idx); + + // Slice results back + auto s1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), bg); + auto s2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {5}}}), bg); + auto s3 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {5}}, {"ends", {6}}}), bg); + auto s4 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {6}}, {"ends", {8}}}), bg); + + // Same concat combiner as m1 (now referencing slices instead of gathers) + m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), + std::vector{s1, s2, s3, s4}); + } + EXPECT(m1 == m2); +} + +// Only 3 gathers (below min_batch_size=4) → no fusion +TEST_CASE(gather_horiz_no_fusion_below_threshold) +{ + migraphx::module m1; + { + auto emb1 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {3, 2}}, 0)); + auto emb2 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 2}}, 1)); + auto emb3 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 2}}, 2)); + + auto idx1 = m1.add_parameter("idx1", {migraphx::shape::int32_type, {2}}); + auto idx2 = m1.add_parameter("idx2", {migraphx::shape::int32_type, {3}}); + auto idx3 = m1.add_parameter("idx3", {migraphx::shape::int32_type, {1}}); + + auto g1 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb1, idx1); + auto g2 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb2, idx2); + auto g3 = m1.add_instruction(migraphx::make_op("gather", 
{{"axis", 0}}), emb3, idx3); + + m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), + std::vector{g1, g2, g3}); + } + auto m2 = m1; + run_pass(m1); + EXPECT(m1 == m2); +} + +// Embeddings are parameters (not constants) → no fusion +TEST_CASE(gather_horiz_no_fusion_non_constant_embedding) +{ + migraphx::module m1; + { + auto emb1 = m1.add_parameter("emb1", {migraphx::shape::float_type, {3, 2}}); + auto emb2 = m1.add_parameter("emb2", {migraphx::shape::float_type, {4, 2}}); + auto emb3 = m1.add_parameter("emb3", {migraphx::shape::float_type, {2, 2}}); + auto emb4 = m1.add_parameter("emb4", {migraphx::shape::float_type, {5, 2}}); + + auto idx1 = m1.add_parameter("idx1", {migraphx::shape::int32_type, {2}}); + auto idx2 = m1.add_parameter("idx2", {migraphx::shape::int32_type, {3}}); + auto idx3 = m1.add_parameter("idx3", {migraphx::shape::int32_type, {1}}); + auto idx4 = m1.add_parameter("idx4", {migraphx::shape::int32_type, {2}}); + + auto g1 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb1, idx1); + auto g2 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb2, idx2); + auto g3 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb3, idx3); + auto g4 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb4, idx4); + + m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), + std::vector{g1, g2, g3, g4}); + } + auto m2 = m1; + run_pass(m1); + EXPECT(m1 == m2); +} + +// Gather axis=1 instead of axis=0 → no fusion +TEST_CASE(gather_horiz_no_fusion_wrong_axis) +{ + migraphx::module m1; + { + auto emb1 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {3, 4}}, 0)); + auto emb2 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {3, 5}}, 1)); + auto emb3 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {3, 6}}, 2)); + auto emb4 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {3, 7}}, 3)); 
+ + auto idx1 = m1.add_parameter("idx1", {migraphx::shape::int32_type, {2}}); + auto idx2 = m1.add_parameter("idx2", {migraphx::shape::int32_type, {2}}); + auto idx3 = m1.add_parameter("idx3", {migraphx::shape::int32_type, {2}}); + auto idx4 = m1.add_parameter("idx4", {migraphx::shape::int32_type, {2}}); + + // axis=1 gathers → all outputs are [3, 2] + auto g1 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), emb1, idx1); + auto g2 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), emb2, idx2); + auto g3 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), emb3, idx3); + auto g4 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), emb4, idx4); + + m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), + std::vector{g1, g2, g3, g4}); + } + auto m2 = m1; + run_pass(m1); + EXPECT(m1 == m2); +} + +// Each embedding has a different embedding dim → separate groups of 1, no fusion +TEST_CASE(gather_horiz_no_fusion_different_emb_dims) +{ + migraphx::module m1; + { + auto emb1 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {3, 2}}, 0)); + auto emb2 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 4}}, 1)); + auto emb3 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 8}}, 2)); + auto emb4 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {5, 16}}, 3)); + + // All indices same size so outputs are compatible for concat on axis=1 + auto idx1 = m1.add_parameter("idx1", {migraphx::shape::int32_type, {2}}); + auto idx2 = m1.add_parameter("idx2", {migraphx::shape::int32_type, {2}}); + auto idx3 = m1.add_parameter("idx3", {migraphx::shape::int32_type, {2}}); + auto idx4 = m1.add_parameter("idx4", {migraphx::shape::int32_type, {2}}); + + // outputs: [2,2], [2,4], [2,8], [2,16] + auto g1 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb1, idx1); + auto g2 = 
m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb2, idx2); + auto g3 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb3, idx3); + auto g4 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb4, idx4); + + // concat on axis=1 since first dims match (2) but second dims differ + m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), + std::vector{g1, g2, g3, g4}); + } + auto m2 = m1; + run_pass(m1); + EXPECT(m1 == m2); +} + +// 3D embedding tables (not 2D) → no fusion +TEST_CASE(gather_horiz_no_fusion_3d_embedding) +{ + migraphx::module m1; + { + auto emb1 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 3, 4}}, 0)); + auto emb2 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 3, 4}}, 1)); + auto emb3 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 3, 4}}, 2)); + auto emb4 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 3, 4}}, 3)); + + auto idx1 = m1.add_parameter("idx1", {migraphx::shape::int32_type, {2}}); + auto idx2 = m1.add_parameter("idx2", {migraphx::shape::int32_type, {3}}); + auto idx3 = m1.add_parameter("idx3", {migraphx::shape::int32_type, {1}}); + auto idx4 = m1.add_parameter("idx4", {migraphx::shape::int32_type, {2}}); + + // outputs: [2,3,4], [3,3,4], [1,3,4], [2,3,4] → concat axis=0 → [8,3,4] + auto g1 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb1, idx1); + auto g2 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb2, idx2); + auto g3 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb3, idx3); + auto g4 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb4, idx4); + + m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), + std::vector{g1, g2, g3, g4}); + } + auto m2 = m1; + run_pass(m1); + EXPECT(m1 == m2); +} + +// First gather's output is used before the second gather — consumers are 
interleaved +// The pass should still fuse and move_output_instructions_after handles reordering +TEST_CASE(gather_horiz_fusion_interleaved_consumers) +{ + migraphx::module m1; + { + auto emb1 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {3, 2}}, 0)); + auto emb2 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 2}}, 1)); + auto emb3 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 2}}, 2)); + auto emb4 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {5, 2}}, 3)); + + auto idx1 = m1.add_parameter("idx1", {migraphx::shape::int32_type, {2}}); + auto idx2 = m1.add_parameter("idx2", {migraphx::shape::int32_type, {2}}); + auto idx3 = m1.add_parameter("idx3", {migraphx::shape::int32_type, {2}}); + auto idx4 = m1.add_parameter("idx4", {migraphx::shape::int32_type, {2}}); + + auto g1 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb1, idx1); + + // g1's output is consumed here — between g1 and g2 + auto relu1 = m1.add_instruction(migraphx::make_op("relu"), g1); + + auto g2 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb2, idx2); + auto g3 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb3, idx3); + auto g4 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb4, idx4); + + m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), + std::vector{relu1, g2, g3, g4}); + } + run_pass(m1); + + migraphx::module m2; + { + auto emb1 = + m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {3, 2}}, 0)); + auto emb2 = + m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 2}}, 1)); + auto emb3 = + m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 2}}, 2)); + auto emb4 = + m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {5, 2}}, 3)); + + auto idx1 = m2.add_parameter("idx1", 
{migraphx::shape::int32_type, {2}}); + auto idx2 = m2.add_parameter("idx2", {migraphx::shape::int32_type, {2}}); + auto idx3 = m2.add_parameter("idx3", {migraphx::shape::int32_type, {2}}); + auto idx4 = m2.add_parameter("idx4", {migraphx::shape::int32_type, {2}}); + + auto offset2 = m2.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {std::size_t(3)}}); + auto offset3 = m2.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {std::size_t(7)}}); + auto offset4 = m2.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {std::size_t(9)}}); + + auto concat_emb = + m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), + std::vector{emb1, emb2, emb3, emb4}); + + auto bc2 = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2}}}), offset2); + auto adj_idx2 = m2.add_instruction(migraphx::make_op("add"), idx2, bc2); + + auto bc3 = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2}}}), offset3); + auto adj_idx3 = m2.add_instruction(migraphx::make_op("add"), idx3, bc3); + + auto bc4 = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2}}}), offset4); + auto adj_idx4 = m2.add_instruction(migraphx::make_op("add"), idx4, bc4); + + auto concat_idx = m2.add_instruction( + migraphx::make_op("concat", {{"axis", 0}}), + std::vector{idx1, adj_idx2, adj_idx3, adj_idx4}); + + auto bg = + m2.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), concat_emb, concat_idx); + + auto s1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), bg); + auto s2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {4}}}), bg); + auto s3 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {6}}}), bg); + auto s4 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {6}}, {"ends", {8}}}), bg); + 
+ // relu was on g1, now on s1 — moved after slices + auto relu1 = m2.add_instruction(migraphx::make_op("relu"), s1); + + m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), + std::vector{relu1, s2, s3, s4}); + } + EXPECT(m1 == m2); +} + +// Shared index: all 4 gathers use the same index parameter +TEST_CASE(gather_horiz_fusion_shared_index) +{ + migraphx::module m1; + { + auto emb1 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {3, 2}}, 0)); + auto emb2 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 2}}, 1)); + auto emb3 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 2}}, 2)); + auto emb4 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {5, 2}}, 3)); + + auto idx = m1.add_parameter("idx", {migraphx::shape::int32_type, {2}}); + + auto g1 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb1, idx); + auto g2 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb2, idx); + auto g3 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb3, idx); + auto g4 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb4, idx); + + m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), + std::vector{g1, g2, g3, g4}); + } + run_pass(m1); + + migraphx::module m2; + { + auto emb1 = + m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {3, 2}}, 0)); + auto emb2 = + m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 2}}, 1)); + auto emb3 = + m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 2}}, 2)); + auto emb4 = + m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {5, 2}}, 3)); + + auto idx = m2.add_parameter("idx", {migraphx::shape::int32_type, {2}}); + + auto offset2 = m2.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {std::size_t(3)}}); + auto offset3 = 
m2.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {std::size_t(7)}}); + auto offset4 = m2.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {std::size_t(9)}}); + + auto concat_emb = + m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), + std::vector{emb1, emb2, emb3, emb4}); + + auto bc2 = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2}}}), offset2); + auto adj_idx2 = m2.add_instruction(migraphx::make_op("add"), idx, bc2); + + auto bc3 = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2}}}), offset3); + auto adj_idx3 = m2.add_instruction(migraphx::make_op("add"), idx, bc3); + + auto bc4 = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2}}}), offset4); + auto adj_idx4 = m2.add_instruction(migraphx::make_op("add"), idx, bc4); + + auto concat_idx = m2.add_instruction( + migraphx::make_op("concat", {{"axis", 0}}), + std::vector{idx, adj_idx2, adj_idx3, adj_idx4}); + + auto bg = + m2.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), concat_emb, concat_idx); + + auto s1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), bg); + auto s2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {4}}}), bg); + auto s3 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {6}}}), bg); + auto s4 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {6}}, {"ends", {8}}}), bg); + + m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), + std::vector{s1, s2, s3, s4}); + } + EXPECT(m1 == m2); +} + +// Dependent gathers: g2 depends on g1's output → only independent ones fuse +// Since g1→g2 dependency exists, group_by won't group them together. +// With only 3 remaining independent gathers, below min_group_size=4, no fusion. 
+TEST_CASE(gather_horiz_no_fusion_dependent) +{ + migraphx::module m1; + { + auto emb1 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {3, 2}}, 0)); + auto emb2 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 2}}, 1)); + auto emb3 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 2}}, 2)); + auto emb4 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {5, 2}}, 3)); + + auto idx1 = m1.add_parameter("idx1", {migraphx::shape::int32_type, {2}}); + auto idx3 = m1.add_parameter("idx3", {migraphx::shape::int32_type, {2}}); + auto idx4 = m1.add_parameter("idx4", {migraphx::shape::int32_type, {2}}); + + auto g1 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb1, idx1); + + // g2 uses g1's output shape to derive its index (dependency) + auto reshape_g1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), g1); + auto g2 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb2, reshape_g1); + + auto g3 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb3, idx3); + auto g4 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb4, idx4); + + m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), + std::vector{g1, g2, g3, g4}); + } + auto m2 = m1; + run_pass(m1); + EXPECT(m1 == m2); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/fuse_pointwise.cpp b/test/fuse_pointwise.cpp index 44e08e63c65..0bf9446d4af 100644 --- a/test/fuse_pointwise.cpp +++ b/test/fuse_pointwise.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -1165,24 +1165,34 @@ TEST_CASE(if_cross_module_multi_out_find_output) auto x = mm->add_parameter("x", s1); auto y = mm->add_parameter("y", s1); auto cond = mm->add_parameter("cond", migraphx::shape{migraphx::shape::bool_type, {1}}); - auto add1 = add_pointwise(p2, "main:pointwise0", {x, y}, single_pointwise("add")); + auto fused = add_pointwise( + p2, + "main:pointwise0", + {x, y}, + [=](auto* pm, const auto& inputs) -> std::vector { + auto sqrt_r = pm->add_instruction(migraphx::make_op("sqrt"), inputs[0]); + auto add_r = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); + return {sqrt_r, add_r}; + }); + auto sqrt_out = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto add_out = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); auto* then_mod = p2.create_module("If_then_1"); { auto relu = add_pointwise( - p2, then_mod, "If_then_1:pointwise0", {add1}, single_pointwise("relu")); + p2, then_mod, "If_then_1:pointwise0", {add_out}, single_pointwise("relu")); then_mod->add_return({relu}); } auto* else_mod = p2.create_module("If_else_1"); { - else_mod->add_return({add1}); + else_mod->add_return({add_out}); } auto if_ins = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); - - auto sqrt = add_pointwise(p2, "main:pointwise1", {x}, single_pointwise("sqrt")); - mm->add_return({sqrt, if_ins}); + mm->add_return({sqrt_out, if_ins}); } EXPECT(p1.sort() == p2.sort()); } diff --git a/test/fuse_reduce.cpp b/test/fuse_reduce.cpp index 8bf01bc7789..c75b141c321 100644 --- a/test/fuse_reduce.cpp +++ b/test/fuse_reduce.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. 
All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -939,6 +939,49 @@ TEST_CASE(reduce_contiguous_reshape_pointwise) EXPECT(p1.sort() == p2.sort()); } +TEST_CASE(reduce_squeeze_unsqueeze_pointwise1) +{ + migraphx::shape s1{migraphx::shape::float_type, {1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); + auto rsum = + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {7, 8, 9, 10, 11}}}), x); + auto squeeze = mm->add_instruction( + migraphx::make_op("squeeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 9, 10}}}), rsum); + auto unsqueeze = mm->add_instruction( + migraphx::make_op("unsqueeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 10, 11}}}), squeeze); + auto rsumb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), unsqueeze); + auto add = add_pointwise(p1, "main:pointwise0", {rsumb, y}, single_pointwise("add")); + mm->add_return({add}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); + auto add = add_reduce( + p2, + "main:reduce_sum0_reshape:main:pointwise0", + {x, y}, + {7, 8, 9, 10, 11}, + [&](auto* rm, const auto& inputs, const auto& axes) { + auto rsum = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), + inputs[0]); + auto rsumb = rm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), rsum); + return add_pointwise( + p2, rm, "main:pointwise0", {rsumb, inputs[1]}, single_pointwise("add")); + }); + mm->add_return({add}); + } + EXPECT(p1 == p2); +} + TEST_CASE(reduce_reshape_reduce) { migraphx::shape s1{migraphx::shape::float_type, {2, 32, 4096}}; diff --git a/test/gpu/external_stream.cpp 
b/test/gpu/external_stream.cpp new file mode 100644 index 00000000000..8c9f1087dc6 --- /dev/null +++ b/test/gpu/external_stream.cpp @@ -0,0 +1,407 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "test.hpp" + +using hip_stream_ptr = MIGRAPHX_MANAGE_PTR(hipStream_t, hipStreamDestroy); + +static hip_stream_ptr create_external_stream() +{ + hipStream_t stream; + auto status = hipStreamCreate(&stream); + if(status != hipSuccess) + MIGRAPHX_THROW("Failed to create stream"); + return hip_stream_ptr{stream}; +} + +static void verify_data(const migraphx::argument& result, const migraphx::shape& s, float expected) +{ + std::vector expected_data(s.elements(), expected); + auto expected_arg = migraphx::argument{s, expected_data.data()}; + EXPECT(result == expected_arg); +} + +TEST_CASE(test_stream_override_get) +{ + migraphx::gpu::context ctx{}; + auto& stream = ctx.get_stream(); + + hipStream_t internal = stream.get(); + EXPECT(internal != nullptr); + + auto ext = create_external_stream(); + stream.set_external_stream(ext.get()); + + EXPECT(stream.get() == ext.get()); + EXPECT(stream.get() != internal); + EXPECT(stream.has_external_stream()); + + stream.set_external_stream(nullptr); + + EXPECT(stream.get() == internal); + EXPECT(not stream.has_external_stream()); +} + +TEST_CASE(test_stream_override_get_queue) +{ + migraphx::gpu::context ctx{}; + auto ext = create_external_stream(); + + hipStream_t original_queue = ctx.get_queue().get(); + EXPECT(original_queue != nullptr); + + ctx.get_stream().set_external_stream(ext.get()); + EXPECT(ctx.get_queue().get() == ext.get()); + + ctx.get_stream().set_external_stream(nullptr); + + EXPECT(ctx.get_queue().get() == original_queue); +} + +TEST_CASE(test_context_wait_for_sets_external_stream) +{ + migraphx::gpu::context ctx{}; + auto ext = create_external_stream(); + + migraphx::any_ptr queue(ext.get()); + + hipStream_t before = ctx.get_queue().get(); + ctx.wait_for(queue); + EXPECT(ctx.get_queue().get() == ext.get()); + EXPECT(ctx.get_queue().get() != before); + + ctx.finish_on(queue); + 
EXPECT(ctx.get_queue().get() == before); +} + +TEST_CASE(test_external_stream_eval_uses_caller_stream) +{ + const unsigned int m = 64; + const unsigned int k = 128; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {m, k}}); + auto y = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {k, m}})); + mm->add_instruction(migraphx::make_op("dot"), x, y); + + p.compile(migraphx::make_target("gpu")); + + migraphx::shape input_shape{migraphx::shape::float_type, {m, k}}; + migraphx::shape output_shape{migraphx::shape::float_type, {m, m}}; + auto input = migraphx::fill_argument(input_shape, 1); + auto ginput = migraphx::gpu::to_gpu(input); + + auto output = migraphx::fill_argument(output_shape, 0); + auto goutput = migraphx::gpu::to_gpu(output); + + auto ext = create_external_stream(); + + auto results = p.eval({{"x", ginput}, {"main:#output_0", goutput}}, {ext.get(), true}); + + EXPECT(not results.empty()); + + EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); + auto host_output = migraphx::gpu::from_gpu(goutput); + EXPECT(host_output != output); +} + +TEST_CASE(test_external_stream_serialized_on_caller_stream) +{ + const unsigned int n = 256; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 1.0f); + std::vector ydata(n, 2.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape 
out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + auto ext = create_external_stream(); + + auto results = p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {ext.get(), true}); + + EXPECT(not results.empty()); + + EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); + auto host_result = migraphx::gpu::from_gpu(gout); + verify_data(host_result, out_shape, 3.0f); +} + +TEST_CASE(test_multiple_async_evals_same_stream) +{ + const unsigned int n = 128; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 1.0f); + std::vector ydata(n, 2.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + auto ext = create_external_stream(); + + for(int iter = 0; iter < 5; ++iter) + { + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {ext.get(), true}); + } + + EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); + auto host_result = migraphx::gpu::from_gpu(gout); + verify_data(host_result, out_shape, 3.0f); +} + +TEST_CASE(test_external_stream_cleared_after_eval) +{ + const unsigned int n = 64; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", 
migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 1.0f); + std::vector ydata(n, 2.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + auto ext = create_external_stream(); + + migraphx::context& ctx_ref = p.get_context(); + auto* gpu_ctx = ctx_ref.any_cast(); + EXPECT(gpu_ctx != nullptr); + + hipStream_t internal_stream = gpu_ctx->get_queue().get(); + + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {ext.get(), true}); + + EXPECT(gpu_ctx->get_queue().get() == internal_stream); + EXPECT(not gpu_ctx->get_stream().has_external_stream()); +} + +TEST_CASE(test_wait_for_null_stream_uses_event_fallback) +{ + migraphx::gpu::context ctx{}; + + migraphx::any_ptr queue{}; + + hipStream_t internal_before = ctx.get_queue().get(); + + ctx.wait_for(queue); + + EXPECT(not ctx.get_stream().has_external_stream()); + EXPECT(ctx.get_queue().get() == internal_before); + + ctx.finish_on(queue); + + EXPECT(not ctx.get_stream().has_external_stream()); + EXPECT(ctx.get_queue().get() == internal_before); +} + +TEST_CASE(test_fallback_event_path_produces_correct_results) +{ + const unsigned int n = 128; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 5.0f); + 
std::vector ydata(n, 7.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + auto results = + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {migraphx::any_ptr{}, true}); + + EXPECT(not results.empty()); + + EXPECT(hipDeviceSynchronize() == hipSuccess); + auto host_result = migraphx::gpu::from_gpu(gout); + verify_data(host_result, out_shape, 12.0f); +} + +TEST_CASE(test_non_async_eval_uses_internal_stream) +{ + const unsigned int n = 128; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 4.0f); + std::vector ydata(n, 6.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + migraphx::context& ctx_ref = p.get_context(); + auto* gpu_ctx = ctx_ref.any_cast(); + EXPECT(gpu_ctx != nullptr); + + auto results = p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}); + + EXPECT(not results.empty()); + EXPECT(not gpu_ctx->get_stream().has_external_stream()); + + p.finish(); + auto 
host_result = migraphx::gpu::from_gpu(gout); + verify_data(host_result, out_shape, 10.0f); +} + +TEST_CASE(test_mixed_async_and_sync_evals) +{ + const unsigned int n = 128; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 1.0f); + std::vector ydata(n, 2.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + migraphx::context& ctx_ref = p.get_context(); + auto* gpu_ctx = ctx_ref.any_cast(); + EXPECT(gpu_ctx != nullptr); + + auto ext = create_external_stream(); + + // Async eval with external stream + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {ext.get(), true}); + EXPECT(not gpu_ctx->get_stream().has_external_stream()); + EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); + + auto host_result = migraphx::gpu::from_gpu(gout); + verify_data(host_result, out_shape, 3.0f); + + // Sync eval with internal stream + auto gout2 = migraphx::gpu::to_gpu(out); + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout2}}); + EXPECT(not gpu_ctx->get_stream().has_external_stream()); + p.finish(); + + auto host_result2 = migraphx::gpu::from_gpu(gout2); + verify_data(host_result2, out_shape, 3.0f); + + // Async eval again to confirm no stale state + auto gout3 = migraphx::gpu::to_gpu(out); + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout3}}, {ext.get(), true}); + 
EXPECT(not gpu_ctx->get_stream().has_external_stream()); + EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); + + auto host_result3 = migraphx::gpu::from_gpu(gout3); + verify_data(host_result3, out_shape, 3.0f); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/include/basic_ops.hpp b/test/include/basic_ops.hpp index fa57ec8b1ad..4f634247d7b 100644 --- a/test/include/basic_ops.hpp +++ b/test/include/basic_ops.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -131,7 +131,10 @@ struct pass_op return {}; return inputs.front(); } - int output_alias(const std::vector& s) const { return s.empty() ? -1 : 0; } + std::vector output_alias(const std::vector& s) const + { + return s.empty() ? std::vector{} : std::vector{0}; + } }; struct non_const_pass_op @@ -151,7 +154,10 @@ struct non_const_pass_op return {}; return inputs.front(); } - int output_alias(const std::vector& s) const { return s.empty() ? -1 : 0; } + std::vector output_alias(const std::vector& s) const + { + return s.empty() ? 
std::vector{} : std::vector{0}; + } }; struct mod_pass_op @@ -176,7 +182,7 @@ struct mod_pass_op return {}; } - int output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } }; struct unary_pass_op @@ -196,7 +202,7 @@ struct unary_pass_op MIGRAPHX_THROW("Wrong inputs"); return inputs.front(); } - int output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } }; struct pass_standard_op @@ -221,7 +227,7 @@ struct pass_standard_op return {}; return inputs.front(); } - int output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } }; struct nop @@ -250,6 +256,24 @@ struct tuple_op } }; +// Operation that aliases multiple inputs (all inputs) +struct multi_alias_op +{ + std::string name() const { return "multi_alias"; } + migraphx::shape compute_shape(std::vector inputs) const + { + if(inputs.empty()) + MIGRAPHX_THROW("Need at least 1 input"); + return inputs.front(); + } + std::vector output_alias(const std::vector& s) const + { + std::vector result(s.size()); + std::iota(result.begin(), result.end(), 0); + return result; + } +}; + inline migraphx::literal get_2x2(int base = 0) { return migraphx::literal{{migraphx::shape::float_type, {2, 2}}, diff --git a/test/include/test.hpp b/test/include/test.hpp index ace4aae4f42..77d95fbd6f6 100644 --- a/test/include/test.hpp +++ b/test/include/test.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -190,10 +190,24 @@ Stream& print_stream_impl(rank<4>, Stream& s, std::nullptr_t) return s; } +template +auto print_stream_impl(rank<5>, Stream& s, const Optional& x) + -> decltype(bool(Optional{*x}), x.has_value(), x.value(), void()) +{ + if(x.has_value()) + { + print_stream(s, x.value()); + } + else + { + s << "nullopt"; + } +} + template void print_stream(Stream& s, const T& x) { - print_stream_impl(rank<5>{}, s, x); + print_stream_impl(rank<6>{}, s, x); } template diff --git a/test/liveness_test.cpp b/test/liveness_test.cpp new file mode 100644 index 00000000000..da206d239dc --- /dev/null +++ b/test/liveness_test.cpp @@ -0,0 +1,132 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include + +TEST_CASE(liveness_single_alias) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2}}, {1.0f, 2.0f}}); + auto p1 = mm->add_instruction(pass_op{}, x); + auto p2 = mm->add_instruction(pass_op{}, p1); + mm->add_return({p2}); + + std::vector consumed; + migraphx::liveness(*mm, [&](auto ins, const auto&) { consumed.push_back(ins); }); + + // With single alias ops, pass_ops alias to x, so liveness tracks x + // The callback is called when x is "consumed" (last usage) + EXPECT(migraphx::contains(consumed, x)); +} + +TEST_CASE(liveness_multi_alias) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2}}, {1.0f, 2.0f}}); + auto y = mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2}}, {3.0f, 4.0f}}); + // multi_alias_op aliases both x and y + auto ma = mm->add_instruction(multi_alias_op{}, x, y); + auto p1 = mm->add_instruction(pass_op{}, ma); + mm->add_return({p1}); + + std::vector consumed; + migraphx::liveness(*mm, [&](auto ins, const auto&) { consumed.push_back(ins); }); + + // Both x and y should be tracked and consumed + // because multi_alias_op aliases both inputs + EXPECT(migraphx::contains(consumed, x)); + EXPECT(migraphx::contains(consumed, y)); +} + +TEST_CASE(liveness_multi_alias_cascade) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2}}, {1.0f, 2.0f}}); + auto y = mm->add_literal( 
+ migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2}}, {3.0f, 4.0f}}); + auto z = mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2}}, {5.0f, 6.0f}}); + // First multi_alias aliases x and y + auto ma1 = mm->add_instruction(multi_alias_op{}, x, y); + // Second multi_alias aliases ma1 (which aliases x,y) and z + auto ma2 = mm->add_instruction(multi_alias_op{}, ma1, z); + mm->add_return({ma2}); + + std::vector consumed; + migraphx::liveness(*mm, [&](auto ins, const auto&) { consumed.push_back(ins); }); + + // All three literals should be tracked and consumed + // ma2 transitively aliases x, y, z + EXPECT(migraphx::contains(consumed, x)); + EXPECT(migraphx::contains(consumed, y)); + EXPECT(migraphx::contains(consumed, z)); +} + +TEST_CASE(liveness_multi_alias_both_tracked) +{ + // This test verifies that when multi_alias_op aliases multiple inputs, + // ALL aliased instructions are properly tracked in liveness analysis. + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2}}, {1.0f, 2.0f}}); + auto y = mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2}}, {3.0f, 4.0f}}); + // multi_alias_op returns output_alias {0, 1} - it aliases both inputs + auto ma = mm->add_instruction(multi_alias_op{}, x, y); + mm->add_return({ma}); + + // Count how many times each literal appears in any live_set across all callbacks + std::size_t x_live_count = 0; + std::size_t y_live_count = 0; + migraphx::liveness(*mm, [&](auto, const auto& live_set) { + if(migraphx::contains(live_set, x)) + x_live_count++; + if(migraphx::contains(live_set, y)) + y_live_count++; + }); + + // Both x and y should appear as live at some point during liveness analysis + // (because multi_alias_op properly exposes both as aliases) + // Note: they might have count 0 if they're only processed when the live_set is already emptied + // 
The key test is that BOTH get consumed (callback called for both) + std::vector consumed; + migraphx::liveness(*mm, [&](auto ins, const auto&) { consumed.push_back(ins); }); + + EXPECT(migraphx::contains(consumed, x)); + EXPECT(migraphx::contains(consumed, y)); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/module_test.cpp b/test/module_test.cpp index 87ab9019e13..d14a5a5a312 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -1579,4 +1579,433 @@ TEST_CASE(hoist_external_inputs_adjacent_instructions) EXPECT(std::distance(mm->begin(), dot1) < std::distance(mm->begin(), dot2)); } +TEST_CASE(move_output_instructions_after_single_output) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto src = m1.add_instruction(migraphx::make_op("abs"), x); + auto out1 = m1.add_instruction(migraphx::make_op("neg"), src); + auto dst = m1.add_instruction(migraphx::make_op("relu"), x); + m1.add_return({out1, dst}); + m1.move_output_instructions_after(src, dst); + } + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto src = m2.add_instruction(migraphx::make_op("abs"), x); + auto dst = m2.add_instruction(migraphx::make_op("relu"), x); + auto out1 = m2.add_instruction(migraphx::make_op("neg"), src); + m2.add_return({out1, dst}); + } + + EXPECT(m1 == m2); +} + +TEST_CASE(move_output_instructions_after_transitive_outputs) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto src = 
m1.add_instruction(migraphx::make_op("abs"), x); + auto a = m1.add_instruction(migraphx::make_op("neg"), src); + auto b = m1.add_instruction(migraphx::make_op("relu"), a); + auto dst = m1.add_instruction(migraphx::make_op("sqrt"), x); + m1.add_return({b, dst}); + m1.move_output_instructions_after(src, dst); + } + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto src = m2.add_instruction(migraphx::make_op("abs"), x); + auto dst = m2.add_instruction(migraphx::make_op("sqrt"), x); + auto a = m2.add_instruction(migraphx::make_op("neg"), src); + auto b = m2.add_instruction(migraphx::make_op("relu"), a); + m2.add_return({b, dst}); + } + + EXPECT(m1 == m2); +} + +TEST_CASE(move_output_instructions_after_multiple_direct_outputs) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto src = m1.add_instruction(migraphx::make_op("abs"), x); + auto a = m1.add_instruction(migraphx::make_op("neg"), src); + auto b = m1.add_instruction(migraphx::make_op("relu"), src); + auto dst = m1.add_instruction(migraphx::make_op("sqrt"), x); + m1.add_return({a, b, dst}); + m1.move_output_instructions_after(src, dst); + } + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto src = m2.add_instruction(migraphx::make_op("abs"), x); + auto dst = m2.add_instruction(migraphx::make_op("sqrt"), x); + auto a = m2.add_instruction(migraphx::make_op("neg"), src); + auto b = m2.add_instruction(migraphx::make_op("relu"), src); + m2.add_return({a, b, dst}); + } + + EXPECT(m1 == m2); +} + +TEST_CASE(move_output_instructions_after_no_outputs_between) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto src = m1.add_instruction(migraphx::make_op("abs"), x); + auto dst = m1.add_instruction(migraphx::make_op("sqrt"), x); + auto out1 = m1.add_instruction(migraphx::make_op("neg"), src); + m1.add_return({out1, dst}); + 
m1.move_output_instructions_after(src, dst); + } + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto src = m2.add_instruction(migraphx::make_op("abs"), x); + auto dst = m2.add_instruction(migraphx::make_op("sqrt"), x); + auto out1 = m2.add_instruction(migraphx::make_op("neg"), src); + m2.add_return({out1, dst}); + } + + EXPECT(m1 == m2); +} + +TEST_CASE(move_output_instructions_after_diamond) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto src = m1.add_instruction(migraphx::make_op("abs"), x); + auto a = m1.add_instruction(migraphx::make_op("neg"), src); + auto b = m1.add_instruction(migraphx::make_op("relu"), src); + auto c = m1.add_instruction(migraphx::make_op("add"), a, b); + auto dst = m1.add_instruction(migraphx::make_op("sqrt"), x); + m1.add_return({c, dst}); + m1.move_output_instructions_after(src, dst); + } + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto src = m2.add_instruction(migraphx::make_op("abs"), x); + auto dst = m2.add_instruction(migraphx::make_op("sqrt"), x); + auto a = m2.add_instruction(migraphx::make_op("neg"), src); + auto b = m2.add_instruction(migraphx::make_op("relu"), src); + auto c = m2.add_instruction(migraphx::make_op("add"), a, b); + m2.add_return({c, dst}); + } + + EXPECT(m1 == m2); +} + +TEST_CASE(move_output_instructions_after_mixed) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto src = m1.add_instruction(migraphx::make_op("abs"), x); + auto mid = m1.add_instruction(migraphx::make_op("neg"), y); + auto out1 = m1.add_instruction(migraphx::make_op("relu"), src); + auto dst = m1.add_instruction(migraphx::make_op("sqrt"), y); + auto out2 = m1.add_instruction(migraphx::make_op("tanh"), src); + m1.add_return({out1, out2, dst, mid}); + m1.move_output_instructions_after(src, dst); + } + + 
migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto src = m2.add_instruction(migraphx::make_op("abs"), x); + auto mid = m2.add_instruction(migraphx::make_op("neg"), y); + auto dst = m2.add_instruction(migraphx::make_op("sqrt"), y); + auto out1 = m2.add_instruction(migraphx::make_op("relu"), src); + auto out2 = m2.add_instruction(migraphx::make_op("tanh"), src); + m2.add_return({out1, out2, dst, mid}); + } + + EXPECT(m1 == m2); +} + +TEST_CASE(move_output_instructions_after_dst_depends_on_src) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto src = m1.add_instruction(migraphx::make_op("abs"), x); + auto out1 = m1.add_instruction(migraphx::make_op("neg"), src); + auto dst = m1.add_instruction(migraphx::make_op("relu"), src); + m1.add_return({out1, dst}); + m1.move_output_instructions_after(src, dst); + } + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto src = m2.add_instruction(migraphx::make_op("abs"), x); + auto dst = m2.add_instruction(migraphx::make_op("relu"), src); + auto out1 = m2.add_instruction(migraphx::make_op("neg"), src); + m2.add_return({out1, dst}); + } + + EXPECT(m1 == m2); +} + +TEST_CASE(move_output_instructions_after_module_output) +{ + // When src is referenced by instructions inside a submodule, src->outputs() + // includes those cross-module instructions. The function resolves them to + // the instruction in the current module that owns the submodule and moves + // that instruction instead. 
+ migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::shape cond_s{migraphx::shape::bool_type, {1}}; + + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto cond = mm->add_parameter("cond", cond_s); + auto src = mm->add_instruction(migraphx::make_op("abs"), x); + + auto* then_mod = p1.create_module("then_mod"); + auto sub_neg = then_mod->add_instruction(migraphx::make_op("neg"), src); + then_mod->add_return({sub_neg}); + + auto* else_mod = p1.create_module("else_mod"); + auto sub_relu = else_mod->add_instruction(migraphx::make_op("relu"), src); + else_mod->add_return({sub_relu}); + + auto out1 = mm->add_instruction(migraphx::make_op("tanh"), src); + auto if_ins = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto dst = mm->add_instruction(migraphx::make_op("sqrt"), x); + mm->add_return({out1, if_ins, dst}); + mm->move_output_instructions_after(src, dst); + } + + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto cond = mm->add_parameter("cond", cond_s); + auto src = mm->add_instruction(migraphx::make_op("abs"), x); + + auto* then_mod = p2.create_module("then_mod"); + auto sub_neg = then_mod->add_instruction(migraphx::make_op("neg"), src); + then_mod->add_return({sub_neg}); + + auto* else_mod = p2.create_module("else_mod"); + auto sub_relu = else_mod->add_instruction(migraphx::make_op("relu"), src); + else_mod->add_return({sub_relu}); + + auto dst = mm->add_instruction(migraphx::make_op("sqrt"), x); + auto out1 = mm->add_instruction(migraphx::make_op("tanh"), src); + auto if_ins = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + mm->add_return({out1, if_ins, dst}); + } + + EXPECT(p1 == p2); +} + +TEST_CASE(move_output_instructions_after_only_cross_module_output) +{ + // src has no direct outputs in the current module between src and dst, + // only cross-module outputs. 
The owning instruction (if) is between src + // and dst, so it gets moved after dst. + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::shape cond_s{migraphx::shape::bool_type, {1}}; + + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto cond = mm->add_parameter("cond", cond_s); + auto src = mm->add_instruction(migraphx::make_op("abs"), x); + + auto* then_mod = p1.create_module("then_mod"); + auto sub_neg = then_mod->add_instruction(migraphx::make_op("neg"), src); + then_mod->add_return({sub_neg}); + + auto* else_mod = p1.create_module("else_mod"); + auto sub_relu = else_mod->add_instruction(migraphx::make_op("relu"), src); + else_mod->add_return({sub_relu}); + + auto if_ins = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto dst = mm->add_instruction(migraphx::make_op("sqrt"), x); + mm->add_return({if_ins, dst}); + mm->move_output_instructions_after(src, dst); + } + + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto cond = mm->add_parameter("cond", cond_s); + auto src = mm->add_instruction(migraphx::make_op("abs"), x); + + auto* then_mod = p2.create_module("then_mod"); + auto sub_neg = then_mod->add_instruction(migraphx::make_op("neg"), src); + then_mod->add_return({sub_neg}); + + auto* else_mod = p2.create_module("else_mod"); + auto sub_relu = else_mod->add_instruction(migraphx::make_op("relu"), src); + else_mod->add_return({sub_relu}); + + auto dst = mm->add_instruction(migraphx::make_op("sqrt"), x); + auto if_ins = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + mm->add_return({if_ins, dst}); + } + + EXPECT(p1 == p2); +} + +TEST_CASE(move_output_instructions_after_cross_module_not_between) +{ + // src has cross-module outputs, but the owning instruction (if) is already + // after dst. Nothing should move. 
+ migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::shape cond_s{migraphx::shape::bool_type, {1}}; + + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto cond = mm->add_parameter("cond", cond_s); + auto src = mm->add_instruction(migraphx::make_op("abs"), x); + + auto* then_mod = p1.create_module("then_mod"); + auto sub_neg = then_mod->add_instruction(migraphx::make_op("neg"), src); + then_mod->add_return({sub_neg}); + + auto* else_mod = p1.create_module("else_mod"); + auto sub_relu = else_mod->add_instruction(migraphx::make_op("relu"), src); + else_mod->add_return({sub_relu}); + + auto dst = mm->add_instruction(migraphx::make_op("sqrt"), x); + auto if_ins = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + mm->add_return({if_ins, dst}); + mm->move_output_instructions_after(src, dst); + } + + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto cond = mm->add_parameter("cond", cond_s); + auto src = mm->add_instruction(migraphx::make_op("abs"), x); + + auto* then_mod = p2.create_module("then_mod"); + auto sub_neg = then_mod->add_instruction(migraphx::make_op("neg"), src); + then_mod->add_return({sub_neg}); + + auto* else_mod = p2.create_module("else_mod"); + auto sub_relu = else_mod->add_instruction(migraphx::make_op("relu"), src); + else_mod->add_return({sub_relu}); + + auto dst = mm->add_instruction(migraphx::make_op("sqrt"), x); + auto if_ins = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + mm->add_return({if_ins, dst}); + } + + EXPECT(p1 == p2); +} + +TEST_CASE(move_output_instructions_after_cross_module_mixed) +{ + // src has cross-module outputs to submodules of two different instructions: + // one between src and dst (should move), one after dst (should NOT move). 
+ migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::shape cond_s{migraphx::shape::bool_type, {1}}; + + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto cond1 = mm->add_parameter("cond1", cond_s); + auto cond2 = mm->add_parameter("cond2", cond_s); + auto src = mm->add_instruction(migraphx::make_op("abs"), x); + + auto* then_mod1 = p1.create_module("then_mod1"); + auto sub1 = then_mod1->add_instruction(migraphx::make_op("neg"), src); + then_mod1->add_return({sub1}); + + auto* else_mod1 = p1.create_module("else_mod1"); + auto sub2 = else_mod1->add_instruction(migraphx::make_op("relu"), src); + else_mod1->add_return({sub2}); + + // if1 is between src and dst — should be moved + auto if1 = mm->add_instruction(migraphx::make_op("if"), {cond1}, {then_mod1, else_mod1}); + auto dst = mm->add_instruction(migraphx::make_op("sqrt"), x); + + auto* then_mod2 = p1.create_module("then_mod2"); + auto sub3 = then_mod2->add_instruction(migraphx::make_op("tanh"), src); + then_mod2->add_return({sub3}); + + auto* else_mod2 = p1.create_module("else_mod2"); + auto sub4 = else_mod2->add_instruction(migraphx::make_op("sin"), src); + else_mod2->add_return({sub4}); + + // if2 is after dst — should NOT be moved + auto if2 = mm->add_instruction(migraphx::make_op("if"), {cond2}, {then_mod2, else_mod2}); + mm->add_return({if1, if2, dst}); + mm->move_output_instructions_after(src, dst); + } + + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto cond1 = mm->add_parameter("cond1", cond_s); + auto cond2 = mm->add_parameter("cond2", cond_s); + auto src = mm->add_instruction(migraphx::make_op("abs"), x); + + auto* then_mod1 = p2.create_module("then_mod1"); + auto sub1 = then_mod1->add_instruction(migraphx::make_op("neg"), src); + then_mod1->add_return({sub1}); + + auto* else_mod1 = p2.create_module("else_mod1"); + auto sub2 = 
else_mod1->add_instruction(migraphx::make_op("relu"), src); + else_mod1->add_return({sub2}); + + auto* then_mod2 = p2.create_module("then_mod2"); + auto sub3 = then_mod2->add_instruction(migraphx::make_op("tanh"), src); + then_mod2->add_return({sub3}); + + auto* else_mod2 = p2.create_module("else_mod2"); + auto sub4 = else_mod2->add_instruction(migraphx::make_op("sin"), src); + else_mod2->add_return({sub4}); + + // Expected: if1 moved after dst, if2 stays after dst + auto dst = mm->add_instruction(migraphx::make_op("sqrt"), x); + auto if1 = mm->add_instruction(migraphx::make_op("if"), {cond1}, {then_mod1, else_mod1}); + auto if2 = mm->add_instruction(migraphx::make_op("if"), {cond2}, {then_mod2, else_mod2}); + mm->add_return({if1, if2, dst}); + } + + EXPECT(p1 == p2); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/onnx/const_of_shape_zero_dim_test.onnx b/test/onnx/const_of_shape_zero_dim_test.onnx new file mode 100644 index 00000000000..6a568eaf874 Binary files /dev/null and b/test/onnx/const_of_shape_zero_dim_test.onnx differ diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 348a46b3258..1cca354df7f 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -1,7 +1,7 @@ ##################################################################################### # The MIT License (MIT) # -# Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -1748,6 +1748,34 @@ def const_of_shape_dyn_int64_test(): return ([node], [output_dims], [y]) +@onnx_test() +def const_of_shape_zero_dim_test(): + tensor_val = onnx.helper.make_tensor('value', onnx.TensorProto.INT64, [1], + [10]) + # Shape with a zero dimension - results in 0 elements output + shape_val = np.array([2, 0, 4]).astype(np.int64) + shape_ts = helper.make_tensor(name='shape_tensor', + data_type=TensorProto.INT64, + dims=shape_val.shape, + vals=shape_val.flatten().astype(np.int64)) + shape_const = helper.make_node( + 'Constant', + inputs=[], + outputs=['shape'], + value=shape_ts, + ) + y = helper.make_tensor_value_info('y', TensorProto.INT64, [2, 0, 4]) + + node = onnx.helper.make_node( + 'ConstantOfShape', + inputs=['shape'], + outputs=['y'], + value=tensor_val, + ) + + return ([shape_const, node], [], [y]) + + @onnx_test() def conv_1d_test(): x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 5]) diff --git a/test/onnx/parse/const_of_shape_zero_dim_test.cpp b/test/onnx/parse/const_of_shape_zero_dim_test.cpp new file mode 100644 index 00000000000..94d0cf4037c --- /dev/null +++ b/test/onnx/parse/const_of_shape_zero_dim_test.cpp @@ -0,0 +1,38 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(const_of_shape_zero_dim_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape ss(migraphx::shape::int64_type, {3}); + mm->add_literal(migraphx::literal(ss, {2, 0, 4})); + auto ret = mm->add_instruction(migraphx::make_op("undefined")); + mm->add_return({ret}); + + auto prog = read_onnx("const_of_shape_zero_dim_test.onnx"); + EXPECT(p == prog); +} diff --git a/test/onnx/parse/undefined_test.cpp b/test/onnx/parse/undefined_test.cpp index 09c383a4dfa..97cd1e5dab9 100644 --- a/test/onnx/parse/undefined_test.cpp +++ b/test/onnx/parse/undefined_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -29,8 +29,8 @@ TEST_CASE(undefined_test) migraphx::program p; auto* mm = p.get_main_module(); mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); - auto l1 = mm->add_instruction(migraphx::make_op("undefined")); - auto l2 = mm->add_instruction(migraphx::make_op("identity"), l1); + mm->add_instruction(migraphx::make_op("undefined")); + auto l2 = mm->add_instruction(migraphx::make_op("undefined")); mm->add_return({l2}); auto prog = read_onnx("undefined_test.onnx"); diff --git a/test/operation.cpp b/test/operation.cpp index 28b4f31f32c..ce79e3e4ac7 100644 --- a/test/operation.cpp +++ b/test/operation.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -87,7 +87,7 @@ struct compilable_op return inputs.front(); } - int output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } migraphx::value compile(migraphx::context&, const migraphx::shape&, const std::vector&) diff --git a/test/output_alias.cpp b/test/output_alias.cpp index be6b1eac102..a52927d93db 100644 --- a/test/output_alias.cpp +++ b/test/output_alias.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -23,17 +23,20 @@ */ #include #include +#include #include #include +using instruction_refs = std::vector; + TEST_CASE(simple_alias) { migraphx::program p; auto* mm = p.get_main_module(); auto l = mm->add_literal(1); auto p1 = mm->add_instruction(pass_op{}, l); - EXPECT(migraphx::instruction::get_output_alias(l) == l); - EXPECT(migraphx::instruction::get_output_alias(p1) == l); + EXPECT(migraphx::instruction::get_output_alias(l) == instruction_refs{l}); + EXPECT(migraphx::instruction::get_output_alias(p1) == instruction_refs{l}); } TEST_CASE(cascade_alias) @@ -44,10 +47,10 @@ TEST_CASE(cascade_alias) auto p1 = mm->add_instruction(pass_op{}, l); auto p2 = mm->add_instruction(pass_op{}, p1); auto p3 = mm->add_instruction(pass_op{}, p2); - EXPECT(migraphx::instruction::get_output_alias(l) == l); - EXPECT(migraphx::instruction::get_output_alias(p1) == l); - EXPECT(migraphx::instruction::get_output_alias(p2) == l); - EXPECT(migraphx::instruction::get_output_alias(p3) == l); + EXPECT(migraphx::instruction::get_output_alias(l) == instruction_refs{l}); + EXPECT(migraphx::instruction::get_output_alias(p1) == instruction_refs{l}); + EXPECT(migraphx::instruction::get_output_alias(p2) == instruction_refs{l}); + EXPECT(migraphx::instruction::get_output_alias(p3) == instruction_refs{l}); } TEST_CASE(no_alias) @@ -57,7 +60,80 @@ TEST_CASE(no_alias) auto x = mm->add_literal(1); auto y = mm->add_literal(2); auto sum = mm->add_instruction(sum_op{}, x, y); - EXPECT(migraphx::instruction::get_output_alias(sum) == sum); + EXPECT(migraphx::instruction::get_output_alias(sum) == instruction_refs{sum}); +} + +TEST_CASE(multiple_aliases) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_literal(1); + auto y = mm->add_literal(2); + auto ma = mm->add_instruction(multi_alias_op{}, x, y); + auto aliases = 
migraphx::instruction::get_output_alias(ma); + // multi_alias_op aliases both inputs, so we should get both literals back + EXPECT(aliases.size() == 2); + EXPECT(migraphx::contains(aliases, x)); + EXPECT(migraphx::contains(aliases, y)); +} + +TEST_CASE(multiple_aliases_shallow) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_literal(1); + auto y = mm->add_literal(2); + auto p1 = mm->add_instruction(pass_op{}, x); + auto p2 = mm->add_instruction(pass_op{}, y); + auto ma = mm->add_instruction(multi_alias_op{}, p1, p2); + // shallow=true returns immediate inputs (p1, p2), not root aliases + auto shallow_aliases = migraphx::instruction::get_output_alias(ma, true); + EXPECT(shallow_aliases.size() == 2); + EXPECT(migraphx::contains(shallow_aliases, p1)); + EXPECT(migraphx::contains(shallow_aliases, p2)); + // shallow=false (default) returns root aliases (x, y) + auto deep_aliases = migraphx::instruction::get_output_alias(ma); + EXPECT(deep_aliases.size() == 2); + EXPECT(migraphx::contains(deep_aliases, x)); + EXPECT(migraphx::contains(deep_aliases, y)); +} + +TEST_CASE(multiple_aliases_cascade) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_literal(1); + auto y = mm->add_literal(2); + auto z = mm->add_literal(3); + // First multi_alias aliases x and y + auto ma1 = mm->add_instruction(multi_alias_op{}, x, y); + // Second multi_alias aliases ma1 and z + auto ma2 = mm->add_instruction(multi_alias_op{}, ma1, z); + // Should recursively expand to get x, y, z + auto aliases = migraphx::instruction::get_output_alias(ma2); + EXPECT(aliases.size() == 3); + EXPECT(migraphx::contains(aliases, x)); + EXPECT(migraphx::contains(aliases, y)); + EXPECT(migraphx::contains(aliases, z)); +} + +TEST_CASE(alias_vector_size) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l = mm->add_literal(1); + // No alias - returns vector with self + auto aliases_self = migraphx::instruction::get_output_alias(l); + 
EXPECT(aliases_self.size() == 1); + // Single alias - returns vector with one element + auto p1 = mm->add_instruction(pass_op{}, l); + auto aliases_single = migraphx::instruction::get_output_alias(p1); + EXPECT(aliases_single.size() == 1); + // Multiple aliases - returns vector with multiple elements + auto x = mm->add_literal(2); + auto ma = mm->add_instruction(multi_alias_op{}, l, x); + auto aliases_multi = migraphx::instruction::get_output_alias(ma); + EXPECT(aliases_multi.size() == 2); } int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/propagate_constant_test.cpp b/test/propagate_constant_test.cpp index 72011fd8d83..01c37b80b29 100644 --- a/test/propagate_constant_test.cpp +++ b/test/propagate_constant_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -562,4 +562,34 @@ TEST_CASE(pack_unpack_fp4) EXPECT(m1 == m2); } +// Test that propagate_constant correctly handles multi-alias operations +// when one of the aliases should be skipped (e.g., broadcasted) +TEST_CASE(skip_propagate_multi_alias) +{ + // When an instruction has multiple aliases and one should be skipped, + // propagation should be skipped for the entire instruction + migraphx::module m1; + { + // Create a broadcasted literal (should skip propagation) + auto broadcasted = m1.add_literal( + migraphx::literal{{migraphx::shape::float_type, {2, 1}, {1, 0}}, {1.0f, 2.0f}}); + // Create a normal literal + auto normal = + m1.add_literal(migraphx::literal{{migraphx::shape::float_type, {2}}, {3.0f, 4.0f}}); + // multi_alias_op aliases both inputs + auto ma = m1.add_instruction(multi_alias_op{}, broadcasted, normal); + // Add an operation that uses ma + auto 
neg = m1.add_instruction(migraphx::make_op("neg"), ma); + m1.add_return({neg}); + } + + migraphx::module m2 = m1; + run_pass(m1); + + // Since one alias (broadcasted) should skip propagation, + // the multi_alias instruction should not be propagated + // The modules should be equivalent (no propagation happened) + EXPECT(m1 == m2); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/replace_allocate.cpp b/test/replace_allocate.cpp index 937920b9a5e..1c7b5ca4882 100644 --- a/test/replace_allocate.cpp +++ b/test/replace_allocate.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -52,7 +52,7 @@ struct test_copy : migraphx::auto_register_op return inputs.back(); } - std::ptrdiff_t output_alias(const std::vector&) const { return 1; } + std::vector output_alias(const std::vector&) const { return {1}; } }; struct allocate_no_out : migraphx::auto_register_op @@ -279,4 +279,257 @@ TEST_CASE(allocate_copy_with_no_out) EXPECT(m1.sort() == m2.sort()); } +TEST_CASE(allocate_out_multi_return_partial_alloc) +{ + migraphx::shape s{migraphx::shape::float_type, {5}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto alloc = + m1.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}})); + auto p1 = m1.add_instruction(pass_op{}, alloc); + m1.add_return({x, p1}); + } + run_pass(m1, allocation_with_out_model{}); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto output = m2.add_parameter("output_1", s); + auto p1 = m2.add_instruction(pass_op{}, output); + m2.add_return({x, p1}); + } + EXPECT(m1.sort() == m2.sort()); +} + +// Test that replace_allocate handles multi-alias 
operations correctly +// when checking for shape matches (insert_copy code path) +TEST_CASE(multi_alias_shape_check) +{ + migraphx::shape s{migraphx::shape::float_type, {5}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + // multi_alias_op aliases both x and y (both have same shape) + auto ma = m1.add_instruction(multi_alias_op{}, x, y); + m1.add_return({ma}); + } + + // After pass, since the multi_alias aliases inputs with matching shapes, + // no copy should be inserted + migraphx::module m2 = m1; + run_pass(m1, allocation_no_out_model{}); + + // The module should remain unchanged since aliases have matching shapes + EXPECT(m1.sort() == m2.sort()); +} + +// Test multi-alias with allocation replaced by output parameter +// When first alias is an allocation, it gets replaced with output parameter +TEST_CASE(multi_alias_alloc_out_param) +{ + migraphx::shape s{migraphx::shape::float_type, {5}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + // Create allocation that is first in multi_alias - this will be used for output naming + auto alloc = + m1.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}})); + auto p1 = m1.add_instruction(pass_op{}, alloc); + // Put allocation-based alias first so it gets used for output naming + auto ma = m1.add_instruction(multi_alias_op{}, p1, x); + m1.add_return({ma}); + } + run_pass(m1, allocation_with_out_model{}); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + // Only the allocation becomes an output parameter (named "output_0") + auto output = m2.add_parameter("output_0", s); + auto p1 = m2.add_instruction(pass_op{}, output); + auto ma = m2.add_instruction(multi_alias_op{}, p1, x); + m2.add_return({ma}); + } + EXPECT(m1.sort() == m2.sort()); +} + +// Test multi-alias where both inputs are allocations - both become output parameters +TEST_CASE(multi_alias_two_allocs_out_param) +{ +
migraphx::shape s{migraphx::shape::float_type, {5}}; + migraphx::module m1; + { + // First allocation will be replaced with output parameter + auto alloc1 = + m1.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}})); + auto p1 = m1.add_instruction(pass_op{}, alloc1); + // Second allocation will also be replaced with output parameter + auto alloc2 = + m1.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}})); + auto p2 = m1.add_instruction(pass_op{}, alloc2); + auto ma = m1.add_instruction(multi_alias_op{}, p1, p2); + m1.add_return({ma}); + } + run_pass(m1, allocation_with_out_model{}); + + migraphx::module m2; + { + // Both aliases become output parameters + auto output0 = m2.add_parameter("output_0", s); + auto p1 = m2.add_instruction(pass_op{}, output0); + auto output1 = m2.add_parameter("output_1", s); + auto p2 = m2.add_instruction(pass_op{}, output1); + auto ma = m2.add_instruction(multi_alias_op{}, p1, p2); + m2.add_return({ma}); + } + EXPECT(m1.sort() == m2.sort()); +} + +// Test multi-alias with matching shapes - no copy needed when any alias matches +TEST_CASE(multi_alias_no_copy_when_any_matches) +{ + migraphx::shape s{migraphx::shape::float_type, {5}}; + migraphx::shape s2{migraphx::shape::float_type, {10}}; + migraphx::module m1; + { + // x has shape {5}, y has shape {10} + // multi_alias output has shape {5} (from first input) + // Since x's shape matches output shape, no copy is needed + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s2); + auto ma = m1.add_instruction(multi_alias_op{}, x, y); + m1.add_return({ma}); + } + + // Module should be unchanged - no copy inserted because first alias shape matches + migraphx::module m2 = m1; + run_pass(m1, allocation_no_out_model{}); + EXPECT(m1.sort() == m2.sort()); +} + +// Test multiple return values where each has a multi-alias with allocation as first alias +// Both allocations should be replaced with separate output parameters 
+TEST_CASE(multi_alias_multiple_outputs_out_params) +{ + migraphx::shape s{migraphx::shape::float_type, {5}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + // First allocation -> pass -> multi_alias with x + auto alloc1 = + m1.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}})); + auto p1 = m1.add_instruction(pass_op{}, alloc1); + auto ma1 = m1.add_instruction(multi_alias_op{}, p1, x); + // Second allocation -> pass -> multi_alias with y + auto alloc2 = + m1.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}})); + auto p2 = m1.add_instruction(pass_op{}, alloc2); + auto ma2 = m1.add_instruction(multi_alias_op{}, p2, y); + // Return both multi_alias results - each should get its own output parameter + m1.add_return({ma1, ma2}); + } + run_pass(m1, allocation_with_out_model{}); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + // First output parameter replaces first allocation (named output_0) + auto output0 = m2.add_parameter("output_0", s); + auto p1 = m2.add_instruction(pass_op{}, output0); + auto ma1 = m2.add_instruction(multi_alias_op{}, p1, x); + // Second output parameter replaces second allocation (named output_2) + auto output1 = m2.add_parameter("output_2", s); + auto p2 = m2.add_instruction(pass_op{}, output1); + auto ma2 = m2.add_instruction(multi_alias_op{}, p2, y); + m2.add_return({ma1, ma2}); + } + EXPECT(m1.sort() == m2.sort()); +} + +// Test multi-alias with 3 allocations - all become output parameters +TEST_CASE(multi_alias_three_allocs) +{ + migraphx::shape s{migraphx::shape::float_type, {5}}; + migraphx::module m1; + { + // Three allocations wrapped in pass_op + auto alloc1 = + m1.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}})); + auto p1 = m1.add_instruction(pass_op{}, alloc1); + auto alloc2 = + m1.add_instruction(migraphx::make_op("allocate",
{{"shape", migraphx::to_value(s)}})); + auto p2 = m1.add_instruction(pass_op{}, alloc2); + auto alloc3 = + m1.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}})); + auto p3 = m1.add_instruction(pass_op{}, alloc3); + // multi_alias aliases all three allocations + auto ma = m1.add_instruction(multi_alias_op{}, p1, p2, p3); + m1.add_return({ma}); + } + run_pass(m1, allocation_with_out_model{}); + + migraphx::module m2; + { + // All three aliases become output parameters + auto output0 = m2.add_parameter("output_0", s); + auto p1 = m2.add_instruction(pass_op{}, output0); + auto output1 = m2.add_parameter("output_1", s); + auto p2 = m2.add_instruction(pass_op{}, output1); + auto output2 = m2.add_parameter("output_2", s); + auto p3 = m2.add_instruction(pass_op{}, output2); + auto ma = m2.add_instruction(multi_alias_op{}, p1, p2, p3); + m2.add_return({ma}); + } + EXPECT(m1.sort() == m2.sort()); +} + +// Test where multiple multi_alias ops each contribute an allocation to multiple returns +// Each allocation becomes a separate output parameter +TEST_CASE(multi_alias_chain_multiple_out_params) +{ + migraphx::shape s{migraphx::shape::float_type, {5}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + // Three allocations + auto alloc1 = + m1.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}})); + auto p1 = m1.add_instruction(pass_op{}, alloc1); + auto alloc2 = + m1.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}})); + auto p2 = m1.add_instruction(pass_op{}, alloc2); + auto alloc3 = + m1.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}})); + auto p3 = m1.add_instruction(pass_op{}, alloc3); + // Each multi_alias puts an allocation first + auto ma1 = m1.add_instruction(multi_alias_op{}, p1, x); + auto ma2 = m1.add_instruction(multi_alias_op{}, p2, x); + auto ma3 = m1.add_instruction(multi_alias_op{}, p3, x); + // Return all three - each 
should get its own output parameter + m1.add_return({ma1, ma2, ma3}); + } + run_pass(m1, allocation_with_out_model{}); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + // Three output parameters - one for each return + auto output0 = m2.add_parameter("output_0", s); + auto p1 = m2.add_instruction(pass_op{}, output0); + auto output1 = m2.add_parameter("output_2", s); + auto p2 = m2.add_instruction(pass_op{}, output1); + auto output2 = m2.add_parameter("output_4", s); + auto p3 = m2.add_instruction(pass_op{}, output2); + auto ma1 = m2.add_instruction(multi_alias_op{}, p1, x); + auto ma2 = m2.add_instruction(multi_alias_op{}, p2, x); + auto ma3 = m2.add_instruction(multi_alias_op{}, p3, x); + m2.add_return({ma1, ma2, ma3}); + } + EXPECT(m1.sort() == m2.sort()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/run_loop_test.cpp b/test/run_loop_test.cpp index c371ef59c63..9b94c45effa 100644 --- a/test/run_loop_test.cpp +++ b/test/run_loop_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -57,7 +57,7 @@ struct copy_op return args[1]; } - int output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } }; struct test_loop_op diff --git a/test/schedule_test.cpp b/test/schedule_test.cpp index 252d0e1c3cc..addcabb80f3 100644 --- a/test/schedule_test.cpp +++ b/test/schedule_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -50,7 +50,7 @@ struct unary_op return {}; return inputs.front(); } - int output_alias(const std::vector&) const { return 0; } + std::vector output_alias(const std::vector&) const { return {0}; } }; struct nary_op diff --git a/test/shape_test.cpp b/test/shape_test.cpp index 1f028b63798..3c4002d3d60 100644 --- a/test/shape_test.cpp +++ b/test/shape_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -51,6 +51,17 @@ TEST_CASE(test_dyn_4arg_constructor) EXPECT(s1.dyn_dims() == expected_dyn_dims); } +TEST_CASE(test_dyn_4arg_constructor_empty) +{ + std::vector mins; + std::vector maxes; + std::vector> opts; + migraphx::shape empty_dims{migraphx::shape::int32_type, mins, maxes, opts}; + + std::vector expected_dyn_dims = {}; + EXPECT(empty_dims.dyn_dims() == expected_dyn_dims); +} + TEST_CASE(test_shape_assign) { migraphx::shape s1{migraphx::shape::float_type, {100, 32, 8, 8}}; diff --git a/test/shape_transform_descriptor.cpp b/test/shape_transform_descriptor.cpp index 1afc2f84d01..4b1a78f8324 100644 --- a/test/shape_transform_descriptor.cpp +++ b/test/shape_transform_descriptor.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -97,6 +97,15 @@ static std::vector run_shape_transforms(const std::vector& return result.to_vector(); } +static std::vector run_strided_view(const migraphx::shape& s, std::int64_t offset) +{ + auto n = s.element_space(); + std::vector data(n); + std::iota(data.begin(), data.end(), offset); + migraphx::literal l(migraphx::shape{migraphx::shape::int64_type, {n}}, data); + return l.get_argument().reshape(s).to_vector(); +} + static std::vector check_optimize_shape_transforms(const std::vector& dims, const std::vector& ops) @@ -125,6 +134,21 @@ static shape_transform_descriptor make_simple_descriptor(const std::vector> static generate_for( + const std::vector& dims, + const std::vector& strides, + const std::vector& idims, + std::int64_t offset = 0) +{ + migraphx::shape s{migraphx::shape::int64_type, dims, strides}; + auto result = migraphx::generate_shape_transforms_for(s, idims, offset); + if(result) + { + CHECK(run_strided_view(s, offset) == run_shape_transforms(idims, result.value())); + } + return result; +} + TEST_CASE(dimension_len) { dimension dim; @@ -948,6 +972,15 @@ TEST_CASE(rebase_unsqueeze_broadcast) make_op("reshape", {{"dims", {1, 3, 256, 2, 256, 2}}}), }); } + + { + auto desc = base_desc.rebase({1, 16, 512, 512}); + EXPECT(get_final_lens(desc) == final_lens{1, 16, 256, 2, 256, 2}); + EXPECT(get_all_lens(desc) == all_lens{{1}, {16}, {256}, {2}, {256}, {2}}); + EXPECT(desc.generate() == ops{ + make_op("reshape", {{"dims", {1, 16, 256, 2, 256, 2}}}), + }); + } } TEST_CASE(rebase_unsqueeze_broadcast_transpose) @@ -1223,4 +1256,139 @@ TEST_CASE(rebase_adjust_axes_many_moved_groups) } } +TEST_CASE(rebase_adjust_squeeze_unsqueeze_broadcast) +{ + auto base_desc = make_simple_descriptor( + {1, 1, 1, 1, 1, 1, 32, 1, 1, 1, 1, 1}, + make_op("squeeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 9, 10}}}), + 
make_op("unsqueeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 10, 11}}}), + make_op("multibroadcast", {{"out_lens", {1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}}})); + + { + auto desc = base_desc.rebase({1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}); + EXPECT(not desc.empty()); + EXPECT(get_final_lens(desc) == final_lens{1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}); + EXPECT(get_all_lens(desc) == + all_lens{{1}, {1}, {1}, {1}, {1}, {1}, {32}, {10}, {16}, {1}, {90}, {160}}); + EXPECT(desc.generate() == ops{}); + } +} + +TEST_CASE(generate_shape_transforms_for) +{ + EXPECT(generate_for({3}, {1}, {3}) == ops{}); + EXPECT(generate_for({3}, {0}, {1}) == ops{make_op("multibroadcast", {{"out_lens", {3}}})}); + EXPECT(generate_for({3}, {3}, {9}) == + ops{ + make_op("reshape", {{"dims", {3, 3}}}), + make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), + }); + + EXPECT(generate_for({3, 4, 5, 2}, {2, 0, 0, 1}, {6}) == + ops{ + make_op("reshape", {{"dims", {3, 1, 1, 2}}}), + make_op("multibroadcast", {{"out_lens", {3, 4, 5, 2}}}), + }); + EXPECT(generate_for({3, 2}, {3, 0}, {9}) == + ops{ + make_op("reshape", {{"dims", {3, 1, 3}}}), + make_op("multibroadcast", {{"out_lens", {3, 2, 3}}}), + make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {1}}}), + }); + + EXPECT(generate_for({3, 2}, {2, 1}, {6}) == ops{ + make_op("reshape", {{"dims", {3, 2}}}), + }); + + EXPECT(generate_for({3, 2}, {1, 3}, {6}) == ops{ + make_op("reshape", {{"dims", {2, 3}}}), + make_op("transpose", {{"permutation", {1, 0}}}), + }); + + EXPECT(generate_for({2, 2, 2, 2, 3}, {0, 2, 0, 1, 0}, {4}) == + ops{ + make_op("reshape", {{"dims", {1, 2, 1, 2, 1}}}), + make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2, 3}}}), + }); + + EXPECT(generate_for({2, 2, 3}, {4, 1, 0}, {8}) == + ops{ + make_op("reshape", {{"dims", {2, 4}}}), + make_op("broadcast", {{"axis", 0}, {"out_lens", {2, 4, 3}}}), + make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {2}}}), + }); + + EXPECT(generate_for({2, 3, 4, 1}, {4, 
16, 1, 1}, {48}) == + ops{ + make_op("reshape", {{"dims", {3, 4, 4, 1}}}), + make_op("transpose", {{"permutation", {1, 0, 2, 3}}}), + make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), + }); +} + +TEST_CASE(generate_shape_transforms_for_overlap) +{ + // TODO: Overlaping strides not supported yet, need to support something like torch.unfold. + + // Case 1: {2, 3} with strides {1, 1} - overlapping rows + // Row 0 accesses [0, 1, 2], Row 1 accesses [1, 2, 3] + // Total elements needed: 4 (exactly matches input size) + EXPECT(generate_for({2, 3}, {1, 1}, {4}) == std::nullopt); + // EXPECT(generate_for({2, 3}, {1, 1}, {4}) == + // ops{ + // make_op("broadcast", {{"axis", 0}, {"out_lens", {2, 4}}}), + // make_op("reshape", {{"dims", {8}}}), + // make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4}}}), + // make_op("reshape", {{"dims", {4}}}), + // make_op("reshape", {{"dims", {2, 2}}}), + // make_op("multibroadcast", {{"out_lens", {2, 3}}}), + // make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {3}}}), + // }); + + // Case 2: {3, 2, 1} with strides {3, 2, 1} + // Element at (i,j,k) is at index i*3 + j*2 + k*1 + // Max index is (2,1,0) = 2*3 + 1*2 + 0*1 = 8 + // So we need 9 elements total (indices 0-8) + EXPECT(generate_for({3, 2, 1}, {3, 2, 1}, {9}) == std::nullopt); + // EXPECT(generate_for({3, 2, 1}, {3, 2, 1}, {9}) == + // ops{ + // make_op("reshape", {{"dims", {9}}}), + // // Extract the specific pattern of elements based on strides + // make_op("reshape", {{"dims", {3, 3}}}), + // make_op("transpose", {{"permutation", {1, 0}}}), + // make_op("reshape", {{"dims", {3, 3}}}), + // make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {2}}}), + // make_op("reshape", {{"dims", {3, 2}}}), + // make_op("multibroadcast", {{"out_lens", {3, 2, 1}}}), + // }); +} + +TEST_CASE(generate_shape_transforms_for_offset) +{ + EXPECT(generate_for({3, 1}, {4, 1}, {30}, 1) == + ops{ + make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", 
{24}}}), + make_op("reshape", {{"dims", {2, 3, 4}}}), + make_op("slice", {{"axes", {0, 2}}, {"starts", {0, 1}}, {"ends", {1, 2}}}), + }); + + EXPECT(generate_for({3, 1}, {5, 1}, {30}, 1) == + ops{ + make_op("reshape", {{"dims", {2, 3, 5}}}), + make_op("slice", {{"axes", {0, 2}}, {"starts", {0, 1}}, {"ends", {1, 2}}}), + }); + + EXPECT(generate_for({3, 2}, {10, 1}, {60}, 1) == + ops{ + make_op("reshape", {{"dims", {2, 3, 10}}}), + make_op("slice", {{"axes", {0, 2}}, {"starts", {0, 1}}, {"ends", {1, 3}}}), + }); + + EXPECT(generate_for({4, 3, 2}, {24, 4, 1}, {96}, 5) == + ops{ + make_op("reshape", {{"dims", {4, 6, 4}}}), + make_op("slice", {{"axes", {1, 2}}, {"starts", {1, 1}}, {"ends", {4, 3}}}), + }); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/simplify_algebra_test.cpp b/test/simplify_algebra_test.cpp index c8722193156..9c6d8d86973 100644 --- a/test/simplify_algebra_test.cpp +++ b/test/simplify_algebra_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -1336,6 +1336,26 @@ TEST_CASE(concat_convert_fusion) EXPECT(m1 == m2); } +TEST_CASE(concat_convert_mismatched_input_types) +{ + auto sx = migraphx::shape{migraphx::shape::float_type, {1, 128}}; + auto sy = migraphx::shape{migraphx::shape::int32_type, {1, 64}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", sx); + auto y = m1.add_parameter("y", sy); + auto xc = m1.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::bf16_type}}), x); + auto yc = m1.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::bf16_type}}), y); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), xc, yc); + m1.add_instruction(pass_op{}, concat); + } + auto m2 = m1; + run_pass(m1); + EXPECT(m1 == m2); +} + TEST_CASE(simplify_div_const) { migraphx::module m1; diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index ce8f1f50eec..acaba04f8a0 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -38,7 +38,8 @@ static void run_pass(migraphx::module& m) { migraphx::run_passes(m, { - migraphx::simplify_reshapes{.enable_op_shape_transform_op = true}, + migraphx::simplify_reshapes{.enable_op_shape_transform_op = true, + .enable_gather_rewrite = true}, migraphx::eliminate_common_subexpression{}, migraphx::dead_code_elimination{}, }); @@ -88,8 +89,8 @@ TEST_CASE(broadcast_transpose) run_pass(m1); migraphx::module m2; { - auto l = m2.add_parameter("x", {migraphx::shape::float_type, {5}}); - auto b = m2.add_instruction( + auto l = m2.add_parameter("x", {migraphx::shape::float_type, {5}}); + auto b = m2.add_instruction( migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {5, 2, 3}}}), l); m2.add_return({b}); } @@ -112,8 +113,8 @@ TEST_CASE(broadcast_transpose_opt) run_pass(m1); migraphx::module m2; { - auto l = m2.add_parameter("x", {migraphx::shape::float_type, {5}}); - auto b = m2.add_instruction( + auto l = m2.add_parameter("x", {migraphx::shape::float_type, {5}}); + auto b = m2.add_instruction( migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", {3, 2, 5}}}), l); m2.add_return({b}); } @@ -908,7 +909,7 @@ TEST_CASE(concat_multibroadcasts1) auto new_concat = std::find_if(m.begin(), m.end(), [](const auto& ins) { return ins.name() == "concat"; }); EXPECT(new_concat != m.end()); - auto cd = std::distance(m.begin(), new_concat); + auto cd = std::distance(m.begin(), new_concat); auto new_mb = std::find_if( m.begin(), m.end(), [](const auto& ins) { return ins.name() == "multibroadcast"; }); auto md = std::distance(m.begin(), new_mb); @@ -931,7 +932,7 @@ TEST_CASE(concat_multibroadcasts2) auto new_concat = std::find_if(m.begin(), m.end(), [](const auto& ins) { return ins.name() == "concat"; }); EXPECT(new_concat != m.end()); - auto cd = std::distance(m.begin(), new_concat); + auto cd = 
std::distance(m.begin(), new_concat); auto new_mb = std::find_if( m.begin(), m.end(), [](const auto& ins) { return ins.name() == "multibroadcast"; }); auto md = std::distance(m.begin(), new_mb); @@ -954,7 +955,7 @@ TEST_CASE(concat_multibroadcasts3) auto new_concat = std::find_if(m.begin(), m.end(), [](const auto& ins) { return ins.name() == "concat"; }); EXPECT(new_concat != m.end()); - auto cd = std::distance(m.begin(), new_concat); + auto cd = std::distance(m.begin(), new_concat); auto new_mb = std::find_if( m.begin(), m.end(), [](const auto& ins) { return ins.name() == "multibroadcast"; }); auto md = std::distance(m.begin(), new_mb); @@ -1516,8 +1517,7 @@ TEST_CASE(optimize_resize) auto create_optimized_module = [&] { migraphx::module m; - auto inx = m.add_parameter("X", sx); - std::vector dims = {1, 1, 2, 1, 2, 1}; + auto inx = m.add_parameter("X", sx); auto rspx = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {3, 5}}}), inx); auto mbx = m.add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 2, 2, 2, 3}}}), rspx); @@ -1534,6 +1534,47 @@ TEST_CASE(optimize_resize) EXPECT(m1 == create_optimized_module()); } +TEST_CASE(optimize_resize_flatten) +{ + migraphx::shape sx{migraphx::shape::float_type, {4}}; + auto create_resize_module = [&] { + migraphx::module m; + auto inx = m.add_parameter("X", sx); + + migraphx::shape si{migraphx::shape::int32_type, {48}}; + std::vector ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, + 3, 3, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 0, 0, + 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3}; + auto li = m.add_literal(migraphx::literal(si, ind)); + + auto gr = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), inx, li); + auto r = m.add_instruction(migraphx::make_op("softmax", {{"axis", 0}}), gr); + m.add_return({r}); + + return m; + }; + + auto m1 = create_resize_module(); + run_pass(m1); + + auto create_optimized_module = [&] { + migraphx::module m; + auto inx = m.add_parameter("X", sx); 
+ auto rspx = + m.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 1, 2, 1}}}), inx); + auto mbx = m.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2, 3}}}), rspx); + + auto rmb = m.add_instruction(migraphx::make_op("reshape", {{"dims", {48}}}), mbx); + auto r = m.add_instruction(migraphx::make_op("softmax", {{"axis", 0}}), rmb); + m.add_return({r}); + + return m; + }; + + EXPECT(m1 == create_optimized_module()); +} + TEST_CASE(optimize_resize_ind_not_apply) { migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; @@ -1588,54 +1629,86 @@ TEST_CASE(optimize_resize_ndims_unequal) { migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; migraphx::shape sy{migraphx::shape::float_type, {1, 1, 4, 3, 2}}; - auto create_resize_module = [&] { - migraphx::module m; - auto inx = m.add_parameter("X", sx); - auto iny = m.add_parameter("Y", sy); + + migraphx::module m1; + { + auto inx = m1.add_parameter("X", sx); + auto iny = m1.add_parameter("Y", sy); migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 3, 2}}; std::vector ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3}; - auto li = m.add_literal(migraphx::literal(si, ind)); + auto li = m1.add_literal(migraphx::literal(si, ind)); - auto lrsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), inx); - auto gr = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); - auto r = m.add_instruction(migraphx::make_op("sub"), iny, gr); - m.add_return({r}); + auto lrsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), inx); + auto gr = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); + auto r = m1.add_instruction(migraphx::make_op("sub"), iny, gr); + m1.add_return({r}); + } + run_pass(m1); - return m; - }; + migraphx::module m2; + { + auto inx = m2.add_parameter("X", sx); + auto iny = m2.add_parameter("Y", sy); + + auto rsp_y = + 
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 3}}}), iny); + auto trans_x = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1}}}), inx); + auto mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 3}}}), trans_x); + auto sub = m2.add_instruction(migraphx::make_op("sub"), rsp_y, mb); + auto rsp_out = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 1, 4, 3, 2}}}), sub); + m2.add_return({rsp_out}); + } - auto m = create_resize_module(); - run_pass(m); - EXPECT(m == create_resize_module()); + EXPECT(m1.sort() == m2.sort()); } TEST_CASE(optimize_resize_ind_non_brcst) { migraphx::shape sx{migraphx::shape::float_type, {1, 1, 3, 2}}; migraphx::shape sy{migraphx::shape::float_type, {1, 1, 4, 6}}; - auto create_resize_module = [&] { - migraphx::module m; - auto inx = m.add_parameter("X", sx); - auto iny = m.add_parameter("Y", sy); + + migraphx::module m1; + { + auto inx = m1.add_parameter("X", sx); + auto iny = m1.add_parameter("Y", sy); migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}}; std::vector ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3}; - auto li = m.add_literal(migraphx::literal(si, ind)); + auto li = m1.add_literal(migraphx::literal(si, ind)); - auto lrsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {6}}}), inx); - auto gr = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); - auto r = m.add_instruction(migraphx::make_op("sub"), iny, gr); - m.add_return({r}); + auto lrsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {6}}}), inx); + auto gr = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); + auto r = m1.add_instruction(migraphx::make_op("sub"), iny, gr); + m1.add_return({r}); + } + run_pass(m1); - return m; - }; + migraphx::module m2; + { + auto inx = m2.add_parameter("X", sx); + auto iny = m2.add_parameter("Y", sy); + + auto rsp1 = 
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {6}}}), inx); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4}}}), rsp1); + auto rsp2 = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 1, 2, 1, 2, 1}}}), slc); + auto mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 2, 2, 2, 3}}}), rsp2); + auto rsp_y = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 1, 2, 2, 2, 3}}}), iny); + auto sub = m2.add_instruction(migraphx::make_op("sub"), rsp_y, mb); + auto rsp3 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 1, 4, 6}}}), sub); + m2.add_return({rsp3}); + } - auto m = create_resize_module(); - run_pass(m); - EXPECT(m == create_resize_module()); + EXPECT(m1.sort() == m2.sort()); } TEST_CASE(optimize_resize_ind_non_const) @@ -1680,21 +1753,32 @@ TEST_CASE(optimize_where_true) return m; }; - auto return_xy = [&](bool cond) { + auto create_expected = [&](bool cond) { migraphx::module m; - auto x = m.add_parameter("X", s); - auto y = m.add_parameter("Y", s); - cond ? m.add_return({x}) : m.add_return({y}); + auto inx = m.add_parameter("X", s); + auto iny = m.add_parameter("Y", s); + + auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny); + auto rsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); + auto bc = m.add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {12, 1, 3, 2}}}), rsp); + int64_t start = cond ? 1 : 0; + int64_t end = cond ? 
2 : 1; + auto slc = m.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {start}}, {"ends", {end}}}), bc); + m.add_return({slc}); return m; }; auto m = create_where_module(true); run_pass(m); - EXPECT(m == return_xy(true)); + auto expected = create_expected(true); + EXPECT(m.sort() == expected.sort()); auto m1 = create_where_module(false); run_pass(m1); - EXPECT(m1 == return_xy(false)); + auto expected1 = create_expected(false); + EXPECT(m1.sort() == expected1.sort()); } TEST_CASE(where_different_cond_values) @@ -1722,96 +1806,1083 @@ TEST_CASE(where_different_cond_values) TEST_CASE(where_axis_nonzero) { - auto create_where_module = [] { - migraphx::module m; - migraphx::shape s{migraphx::shape::float_type, {1, 1, 3, 2}}; - auto inx = m.add_parameter("X", s); - auto iny = m.add_parameter("Y", s); + migraphx::shape s{migraphx::shape::float_type, {1, 1, 3, 2}}; + + migraphx::module m1; + { + auto inx = m1.add_parameter("X", s); + auto iny = m1.add_parameter("Y", s); migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}}; std::vector idata(6, 1); - auto li = m.add_literal(migraphx::literal(si, idata)); - auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), inx, iny); - auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); - auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); - m.add_return({r}); - return m; - }; + auto li = m1.add_literal(migraphx::literal(si, idata)); + auto data = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), inx, iny); + auto data_1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); + auto r = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); + m1.add_return({r}); + } + run_pass(m1); - auto m = create_where_module(); - run_pass(m); - EXPECT(m == create_where_module()); + migraphx::module m2; + { + auto inx = m2.add_parameter("X", s); + auto iny = m2.add_parameter("Y", 
s); + + auto data = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), inx, iny); + auto tr = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {1, 2, 3, 0}}}), data); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {12, 1}}}), tr); + auto bc = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {12, 1, 3, 2}}}), rsp); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), bc); + m2.add_return({slc}); + } + + EXPECT(m1.sort() == m2.sort()); } TEST_CASE(where_three_concat_inputs) { - auto create_where_module = [] { - migraphx::module m; - migraphx::shape s{migraphx::shape::float_type, {1, 1, 3, 2}}; - auto inx = m.add_parameter("X", s); - auto iny = m.add_parameter("Y", s); + migraphx::shape s{migraphx::shape::float_type, {1, 1, 3, 2}}; + + migraphx::module m1; + { + auto inx = m1.add_parameter("X", s); + auto iny = m1.add_parameter("Y", s); migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}}; std::vector idata(6, 1); - auto li = m.add_literal(migraphx::literal(si, idata)); - auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny, inx); - auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {18}}}), data); - auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); - m.add_return({r}); - return m; - }; + auto li = m1.add_literal(migraphx::literal(si, idata)); + auto data = m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny, inx); + auto data_1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {18}}}), data); + auto r = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); + m1.add_return({r}); + } + run_pass(m1); + + migraphx::module m2; + { + auto inx = m2.add_parameter("X", s); + auto iny = m2.add_parameter("Y", s); + + auto data = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), 
inx, iny, inx); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {18}}}), data); + auto bc = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {18, 1, 3, 2}}}), rsp); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), bc); + m2.add_return({slc}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(where_three_inputs_diff_shapes) +{ + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 3, 2}}; + migraphx::shape sy{migraphx::shape::float_type, {2, 1, 3, 2}}; + + migraphx::module m1; + { + auto inx = m1.add_parameter("X", sx); + auto iny = m1.add_parameter("Y", sy); + + migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}}; + std::vector idata(6, 1); + auto li = m1.add_literal(migraphx::literal(si, idata)); + auto data = m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny); + auto data_1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {18}}}), data); + auto r = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); + m1.add_return({r}); + } + run_pass(m1); + + migraphx::module m2; + { + auto inx = m2.add_parameter("X", sx); + auto iny = m2.add_parameter("Y", sy); + + auto data = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {18}}}), data); + auto bc = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {18, 1, 3, 2}}}), rsp); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), bc); + m2.add_return({slc}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(where_three_lens_diff) +{ + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 3, 2}}; + migraphx::shape sy{migraphx::shape::float_type, {1, 1, 3, 2}}; + + migraphx::module m1; + { + auto inx = m1.add_parameter("X", sx); + auto iny = 
m1.add_parameter("Y", sy); + + migraphx::shape si{migraphx::shape::bool_type, {1, 1, 6}}; + std::vector idata(6, 1); + auto li = m1.add_literal(migraphx::literal(si, idata)); + auto data = m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny); + auto data_1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); + auto r = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); + m1.add_return({r}); + } + run_pass(m1); + + migraphx::module m2; + { + auto inx = m2.add_parameter("X", sx); + auto iny = m2.add_parameter("Y", sy); + + auto data = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); + auto bc = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {12, 1, 6}}}), rsp); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), bc); + m2.add_return({slc}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_1d_nd_indices) +{ + migraphx::module m; + auto x = m.add_parameter("x", {migraphx::shape::float_type, {6}}); + migraphx::shape si{migraphx::shape::int32_type, {2, 3}}; + std::vector indices = {0, 1, 2, 3, 4, 5}; + auto li = m.add_literal(migraphx::literal(si, indices)); + auto g = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), x, li); + m.add_return({g}); + + run_pass(m); + + migraphx::module expected; + auto xe = expected.add_parameter("x", {migraphx::shape::float_type, {6}}); + auto reshaped = expected.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3}}}), xe); + expected.add_return({reshaped}); + + EXPECT(m == expected); +} + +TEST_CASE(gather_axis_slice_broadcast) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 4}}); + migraphx::shape si{migraphx::shape::int32_type, {2, 3}}; + std::vector indices = {1, 1, 1, 2, 2, 2}; + 
auto li = m1.add_literal(migraphx::literal(si, indices)); + auto g = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), x, li); + m1.add_return({g}); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 4}}); + auto br = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {2, 4, 3}}}), x); + auto sliced = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {3}}}), br); + m2.add_return({sliced}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_constant_single_index) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {3, 4, 5}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {1}}; + auto indices = m1.add_literal(migraphx::literal{si, {2}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {3, 4, 5}}; + auto data = m2.add_parameter("data", s); + auto reshaped = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 20}}}), data); + auto sliced = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {10}}, {"ends", {15}}}), + reshaped); + auto unsqueezed = + m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), sliced); + m2.add_return({unsqueezed}); + } + + EXPECT(m1 == m2); +} + +TEST_CASE(gather_multi_axis_stride) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("X", {migraphx::shape::float_type, {1, 3, 4, 4}}); + auto flatten = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {48}}}), x); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {2, 3, 1, 4}}; + std::vector indices = {0, 1, 2, 3, 16, 17, 18, 19, 32, 33, 34, 35, + 4, 5, 6, 7, 20, 21, 22, 23, 36, 37, 38, 39}; + auto li = 
m1.add_literal(migraphx::literal{indices_shape, indices}); + auto gather = m1.add_instruction(migraphx::make_op("gather"), flatten, li); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("X", {migraphx::shape::float_type, {1, 3, 4, 4}}); + auto unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {3}}}), x); + auto tr = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3, 4}}}), unsq); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), tr); + auto sliced = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), sq); + m2.add_return({sliced}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_flatten_multi_axis_stride) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("X", {migraphx::shape::float_type, {48}}); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {24}}; + std::vector indices = {0, 1, 2, 3, 16, 17, 18, 19, 32, 33, 34, 35, + 4, 5, 6, 7, 20, 21, 22, 23, 36, 37, 38, 39}; + auto li = m1.add_literal(migraphx::literal{indices_shape, indices}); + auto gather = m1.add_instruction(migraphx::make_op("gather"), x, li); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("X", {migraphx::shape::float_type, {48}}); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 4}}}), x); + auto tr = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), rsp); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), tr); + auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {24}}}), slc); + m2.add_return({rsp2}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_constant_same_indices) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {3, 4, 5}}; + auto data 
= m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {3}}; + auto indices = m1.add_literal(migraphx::literal{si, {1, 1, 1}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {3, 4, 5}}; + auto data = m2.add_parameter("data", s); + auto unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), data); + auto rsp1 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 1, 20}}}), unsq); + auto mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {3, 3, 20}}}), rsp1); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), mb); + auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 3, 4, 5}}}), slc); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), rsp2); + m2.add_return({sq}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_constant_same_indices_1d) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {12}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {3}}; + auto indices = m1.add_literal(migraphx::literal{si, {1, 1, 1}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + auto unsqueeze = + m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), gather); + m1.add_return({unsqueeze}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {12}}; + auto data = m2.add_parameter("data", s); + auto bc = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {12, 3}}}), data); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), bc); 
+ m2.add_return({slc}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_constant_sequential_indices) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {5, 6}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {3}}; + auto indices = m1.add_literal(migraphx::literal{si, {1, 2, 3}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {5, 6}}; + auto data = m2.add_parameter("data", s); + auto rsp1 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {30}}}), data); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {6}}, {"ends", {24}}}), rsp1); + auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 6}}}), slc); + m2.add_return({rsp2}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_constant_sequential_indices_1d) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {30}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {3}}; + auto indices = m1.add_literal(migraphx::literal{si, {1, 2, 3}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {30}}; + auto data = m2.add_parameter("data", s); + auto unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), data); + auto mb = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 30}}}), unsq); + auto rsp1 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {60}}}), mb); + auto slc1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {31}}}), 
rsp1); + auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {10, 3}}}), slc1); + auto slc2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), rsp2); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slc2); + m2.add_return({sq}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_constant_stride_indices_1d) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {30}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {3}}; + auto indices = m1.add_literal(migraphx::literal{si, {1, 5, 9}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {30}}; + auto data = m2.add_parameter("data", s); + auto slc1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {24}}}), data); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3, 4}}}), slc1); + auto slc2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0, 2}}, {"starts", {0, 1}}, {"ends", {1, 2}}}), + rsp); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 2}}}), slc2); + m2.add_return({sq}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_constant_stride_divisible_indices_1d) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {30}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {3}}; + auto indices = m1.add_literal(migraphx::literal{si, {0, 5, 10}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, 
{30}}; + auto data = m2.add_parameter("data", s); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3, 5}}}), data); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0, 2}}, {"starts", {0, 0}}, {"ends", {1, 1}}}), + rsp); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 2}}}), slc); + m2.add_return({sq}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_constant_stride_divisible_indices_window_1d) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {30}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {3}}; + auto indices = m1.add_literal(migraphx::literal{si, {5, 10, 15}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {30}}; + auto data = m2.add_parameter("data", s); + auto unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), data); + auto mb = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 30}}}), unsq); + auto rsp1 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {60}}}), mb); + auto slc1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {5}}, {"ends", {35}}}), rsp1); + auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3, 5}}}), slc1); + auto slc2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0, 2}}, {"starts", {0, 0}}, {"ends", {1, 1}}}), + rsp2); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 2}}}), slc2); + m2.add_return({sq}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_constant_stride_divisible_both_indices_1d) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {15}}; + auto data = m1.add_parameter("data", 
s); + migraphx::shape si{migraphx::shape::int32_type, {3}}; + auto indices = m1.add_literal(migraphx::literal{si, {0, 5, 10}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {15}}; + auto data = m2.add_parameter("data", s); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 5}}}), data); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), rsp); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), slc); + m2.add_return({sq}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_sequential_stride_rtr_1d) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {8}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {8}}; + auto indices = m1.add_literal(migraphx::literal{si, {0, 4, 1, 5, 2, 6, 3, 7}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {8}}; + auto data = m2.add_parameter("data", s); + auto reshape1 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4}}}), data); + auto transpose = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), reshape1); + auto reshape2 = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), transpose); + m2.add_return({reshape2}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_sequential_stride_rtr_window_1d) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {12}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {8}}; + auto indices = 
m1.add_literal(migraphx::literal{si, {1, 4, 7, 10, 2, 5, 8, 11}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {12}}; + auto data = m2.add_parameter("data", s); + auto rsp1 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {4, 3}}}), data); + auto tr = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), rsp1); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {3}}}), tr); + auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), slc); + m2.add_return({rsp2}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_axis0_half_split_concat) +{ + // This pattern is not optimized - gather remains + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {4, 3}}); + migraphx::shape si{migraphx::shape::int32_type, {4}}; + std::vector indices = {2, 3, 0, 1}; + auto li = m1.add_literal(migraphx::literal(si, indices)); + auto g = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), x, li); + m1.add_return({g}); + } + auto m2 = m1; + run_pass(m1); + + // Verify output shape is correct: {4, 3} + auto result = + std::find_if(m1.begin(), m1.end(), [](const auto& ins) { return ins.name() == "@return"; }); + EXPECT(result != m1.end()); + EXPECT(result->inputs().front()->get_shape().lens() == std::vector{4, 3}); + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_axis1_same_stride_diff_base) +{ + // This pattern is not optimized - gather remains + migraphx::module m1; + { + migraphx::shape si{migraphx::shape::int32_type, {2, 2}}; + std::vector indices = {1, 1, 0, 2}; + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {3, 3}}); + auto tx = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), x); + 
auto ind = m1.add_literal(migraphx::literal{si, indices}); + auto tind = + m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), ind); + auto g = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), tx, tind); + m1.add_return({g}); + } + auto m2 = m1; + // Verify there is no hang + run_pass(m1); + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_axis1_transposed) +{ + migraphx::module m1; + { + migraphx::shape si{migraphx::shape::int32_type, {1}}; + std::vector indices = {1}; + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 3}}); + auto tx = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), x); + auto ind = m1.add_literal(migraphx::literal{si, indices}); + auto g = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), tx, ind); + m1.add_return({g}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 3}}); + auto unsqueeze = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), x); + auto transpose = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), unsqueeze); + auto slice = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), transpose); + auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice); + m2.add_return({squeeze}); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_flatten_stride_slice) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("X", {migraphx::shape::float_type, {8}}); + migraphx::shape si{migraphx::shape::int32_type, {4}}; + std::vector indices = {1, 5, 2, 6}; + auto li = m1.add_literal(migraphx::literal{si, indices}); + auto g = m1.add_instruction(migraphx::make_op("gather"), x, li); + m1.add_return({g}); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("X", {migraphx::shape::float_type, {8}}); + auto rsp = 
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4}}}), x); + auto tr = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), rsp); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {3}}}), tr); + auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), slc); + m2.add_return({rsp2}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_flatten_stride_first) +{ + migraphx::module m; + auto x = m.add_parameter("X", {migraphx::shape::float_type, {8}}); + migraphx::shape si{migraphx::shape::int32_type, {4}}; + std::vector indices = {0, 2, 4, 6}; + auto li = m.add_literal(migraphx::literal{si, indices}); + auto g = m.add_instruction(migraphx::make_op("gather"), x, li); + m.add_return({g}); + + run_pass(m); + + migraphx::module expected; + auto xe = expected.add_parameter("X", {migraphx::shape::float_type, {8}}); + auto reshape_block = + expected.add_instruction(migraphx::make_op("reshape", {{"dims", {4, 2}}}), xe); + auto slice = expected.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), reshape_block); + auto result = expected.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), slice); + expected.add_return({result}); + + EXPECT(m == expected); +} + +TEST_CASE(gather_flatten_stride_offset) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("X", {migraphx::shape::float_type, {16}}); + migraphx::shape si{migraphx::shape::int32_type, {1, 4}}; + std::vector indices = {1, 5, 9, 13}; + auto li = m1.add_literal(migraphx::literal{si, indices}); + auto g = m1.add_instruction(migraphx::make_op("gather"), x, li); + m1.add_return({g}); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("X", {migraphx::shape::float_type, {16}}); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {4, 4}}}), x); + auto slc = m2.add_instruction( + 
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), rsp); + auto unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), slc); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), unsq); + m2.add_return({sq}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_flatten_stride_grid) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("X", {migraphx::shape::float_type, {768}}); + migraphx::shape si{migraphx::shape::int32_type, {48}}; + std::vector indices = {17, 21, 25, 29, 81, 85, 89, 93, 145, 149, 153, 157, + 209, 213, 217, 221, 273, 277, 281, 285, 337, 341, 345, 349, + 401, 405, 409, 413, 465, 469, 473, 477, 529, 533, 537, 541, + 593, 597, 601, 605, 657, 661, 665, 669, 721, 725, 729, 733}; + auto li = m1.add_literal(migraphx::literal{si, indices}); + auto g = m1.add_instruction(migraphx::make_op("gather"), x, li); + m1.add_return({g}); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("X", {migraphx::shape::float_type, {768}}); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {12, 16, 4}}}), x); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1, 2}}, {"starts", {4, 1}}, {"ends", {8, 2}}}), + rsp); + auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {48}}}), slc); + m2.add_return({rsp2}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_flatten_permutation) +{ + migraphx::module m; + auto x = m.add_parameter("X", {migraphx::shape::float_type, {16}}); + migraphx::shape si{migraphx::shape::int32_type, {16}}; + std::vector indices = {0, 2, 8, 10, 4, 6, 12, 14, 1, 3, 9, 11, 5, 7, 13, 15}; + auto li = m.add_literal(migraphx::literal{si, indices}); + auto g = m.add_instruction(migraphx::make_op("gather"), x, li); + m.add_return({g}); + + run_pass(m); + + migraphx::module expected; + auto xe = expected.add_parameter("X", {migraphx::shape::float_type, {16}}); + auto reshape_perm = + 
expected.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2}}}), xe); + auto transpose = expected.add_instruction( + migraphx::make_op("transpose", {{"permutation", {3, 1, 0, 2}}}), reshape_perm); + auto reshape_out = + expected.add_instruction(migraphx::make_op("reshape", {{"dims", {16}}}), transpose); + expected.add_return({reshape_out}); + + EXPECT(m == expected); +} + +TEST_CASE(gather_flatten_channel_patch) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("X", {migraphx::shape::float_type, {48}}); + migraphx::shape si{migraphx::shape::int32_type, {12}}; + std::vector indices = {5, 21, 37, 9, 25, 41, 6, 22, 38, 10, 26, 42}; + auto li = m1.add_literal(migraphx::literal{si, indices}); + auto g = m1.add_instruction(migraphx::make_op("gather"), x, li); + m1.add_return({g}); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("X", {migraphx::shape::float_type, {48}}); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 4}}}), x); + auto tr = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 1, 0}}}), rsp); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {1, 1}}, {"ends", {3, 3}}}), + tr); + auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), slc); + m2.add_return({rsp2}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_flatten_channel_parity_permutation) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("X", {migraphx::shape::float_type, {48}}); + migraphx::shape si{migraphx::shape::int32_type, {48}}; + std::vector indices = {0, 2, 8, 10, 16, 18, 24, 26, 32, 34, 40, 42, + 4, 6, 12, 14, 20, 22, 28, 30, 36, 38, 44, 46, + 1, 3, 9, 11, 17, 19, 25, 27, 33, 35, 41, 43, + 5, 7, 13, 15, 21, 23, 29, 31, 37, 39, 45, 47}; + auto li = m1.add_literal(migraphx::literal{si, indices}); + auto g = m1.add_instruction(migraphx::make_op("gather"), x, li); + m1.add_return({g}); + } + run_pass(m1); 
+ + migraphx::module m2; + { + auto x = m2.add_parameter("X", {migraphx::shape::float_type, {48}}); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {6, 2, 2, 2}}}), x); + auto tr = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {3, 1, 0, 2}}}), rsp); + auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {48}}}), tr); + m2.add_return({rsp2}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_axis1_factorized_grid_const) +{ + migraphx::module m1; + { + auto data = m1.add_parameter("data", {migraphx::shape::float_type, {3, 8, 5}}); + migraphx::shape si{migraphx::shape::int32_type, {2, 2, 1}}; + std::vector indices = {1, 3, 5, 7}; + auto li = m1.add_literal(migraphx::literal{si, indices}); + auto g = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), data, li); + m1.add_return({g}); + } + run_pass(m1); + + migraphx::module m2; + { + auto data = m2.add_parameter("data", {migraphx::shape::float_type, {3, 8, 5}}); + auto rsp1 = m2.add_instruction( + migraphx::make_op("reshape", {{"dims", std::vector{3, 4, 2, 5}}}), data); + auto rsp2 = m2.add_instruction( + migraphx::make_op("reshape", {{"dims", std::vector{12, 10}}}), rsp1); + auto slc = m2.add_instruction(migraphx::make_op("slice", + {{"axes", std::vector{1}}, + {"starts", std::vector{5}}, + {"ends", std::vector{10}}}), + rsp2); + auto rsp3 = m2.add_instruction( + migraphx::make_op("reshape", {{"dims", std::vector{3, 2, 2, 1, 5}}}), slc); + m2.add_return({rsp3}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_axis1_factorized_grid_multi_const) +{ + migraphx::module m1; + { + auto data = m1.add_parameter("data", {migraphx::shape::float_type, {2, 27, 4}}); + migraphx::shape si{migraphx::shape::int32_type, {3, 1}}; + std::vector indices = {5, 14, 23}; + auto li = m1.add_literal(migraphx::literal{si, indices}); + auto g = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), data, li); + m1.add_return({g}); + } 
+ run_pass(m1); + + migraphx::module m2; + { + auto data = m2.add_parameter("data", {migraphx::shape::float_type, {2, 27, 4}}); + auto rsp1 = m2.add_instruction( + migraphx::make_op("reshape", {{"dims", std::vector{2, 3, 9, 4}}}), data); + auto rsp2 = m2.add_instruction( + migraphx::make_op("reshape", {{"dims", std::vector{6, 36}}}), rsp1); + auto slc = m2.add_instruction(migraphx::make_op("slice", + {{"axes", std::vector{1}}, + {"starts", std::vector{20}}, + {"ends", std::vector{24}}}), + rsp2); + auto rsp3 = m2.add_instruction( + migraphx::make_op("reshape", {{"dims", std::vector{2, 3, 1, 4}}}), slc); + m2.add_return({rsp3}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_constant_scalar_index) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {3, 4}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type}; + auto indices = m1.add_literal(migraphx::literal{si, {2}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {3, 4}}; + auto data = m2.add_parameter("data", s); + auto slice = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), data); + auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice); + m2.add_return({squeeze}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_constant_scalar_index_axis2) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {8, 32, 19}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type}; + auto indices = m1.add_literal(migraphx::literal{si, {0}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + 
migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {8, 32, 19}}; + auto data = m2.add_parameter("data", s); + auto reshape1 = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {256, 19}}}), data); + auto slice = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), reshape1); + auto reshape2 = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {8, 32, 1}}}), slice); + auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), reshape2); + m2.add_return({squeeze}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_constant_scalar_index_single_dim) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {4}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type}; + auto indices = m1.add_literal(migraphx::literal{si, {2}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {4}}; + auto data = m2.add_parameter("data", s); + auto unsqueeze = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), data); + auto slice = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), unsqueeze); + auto reshape = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1}}}), slice); + auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), reshape); + m2.add_return({squeeze}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_constant_negative_index) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {3, 4}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {1}}; + auto indices = m1.add_literal(migraphx::literal{si, {-1}}); + auto gather = 
m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {3, 4}}; + auto data = m2.add_parameter("data", s); + auto slice = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), data); + m2.add_return({slice}); + } + + EXPECT(m1.sort() == m2.sort()); +} - auto m = create_where_module(); - run_pass(m); - EXPECT(m == create_where_module()); +TEST_CASE(gather_non_constant_indices) +{ + // Should not be transformed + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {3, 4}}; + auto si = migraphx::shape{migraphx::shape::int32_type, {2}}; + auto data = m1.add_parameter("data", s); + auto indices = m1.add_parameter("indices", si); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + auto m2 = m1; + run_pass(m1); + EXPECT(m1 == m2); } -TEST_CASE(where_three_inputs_diff_shapes) +TEST_CASE(gather_axis_1) { - auto create_where_module = [] { - migraphx::module m; - migraphx::shape sx{migraphx::shape::float_type, {1, 1, 3, 2}}; - migraphx::shape sy{migraphx::shape::float_type, {2, 1, 3, 2}}; - auto inx = m.add_parameter("X", sx); - auto iny = m.add_parameter("Y", sy); + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {2, 5, 3}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {2}}; + auto indices = m1.add_literal(migraphx::literal{si, {0, 1}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); - migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}}; - std::vector idata(6, 1); - auto li = m.add_literal(migraphx::literal(si, idata)); - auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 
0}}), inx, iny); - auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {18}}}), data); - auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); - m.add_return({r}); - return m; - }; + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {2, 5, 3}}; + auto data = m2.add_parameter("data", s); + auto rsp1 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 15}}}), data); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), rsp1); + auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 3}}}), slc); + m2.add_return({rsp2}); + } - auto m = create_where_module(); - run_pass(m); - EXPECT(m == create_where_module()); + EXPECT(m1.sort() == m2.sort()); } -TEST_CASE(where_three_lens_diff) +TEST_CASE(gather_onnx_axis_one_ex) { - auto create_where_module = [] { - migraphx::module m; - migraphx::shape sx{migraphx::shape::float_type, {1, 1, 3, 2}}; - migraphx::shape sy{migraphx::shape::float_type, {1, 1, 3, 2}}; - auto inx = m.add_parameter("X", sx); - auto iny = m.add_parameter("Y", sy); - - migraphx::shape si{migraphx::shape::bool_type, {1, 1, 6}}; - std::vector idata(6, 1); - auto li = m.add_literal(migraphx::literal(si, idata)); - auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny); - auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); - auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); - m.add_return({r}); - return m; - }; + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {3, 3}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {2, 1}}; + auto indices = m1.add_literal(migraphx::literal{si, {0, 2}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), data, indices); + m1.add_return({gather}); + } + 
migraphx::module m2 = m1; + run_pass(m1); - auto m = create_where_module(); - run_pass(m); - EXPECT(m == create_where_module()); + EXPECT(m1.sort() == m2.sort()); } TEST_CASE(reshape_cont) @@ -2080,8 +3151,8 @@ TEST_CASE(literal_reshape_unary_transpose_pointwise) run_pass(m1); migraphx::module m2; { - auto x = m2.add_parameter("x", s2); - auto one = m2.add_literal(migraphx::generate_literal(s1)); + auto x = m2.add_parameter("x", s2); + auto one = m2.add_literal(migraphx::generate_literal(s1)); auto reshape_ins = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), one); auto transpose = m2.add_instruction( @@ -2561,6 +3632,67 @@ TEST_CASE(reduce_squeeze_pointwise3) EXPECT(m1.sort() == m2.sort()); } +TEST_CASE(reduce_squeeze_pointwise4) +{ + auto s1 = migraphx::shape{migraphx::shape::float_type, {1, 3}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {1}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto reduce_sum = m1.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x); + auto squeeze = + m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), reduce_sum); + auto add = m1.add_instruction(migraphx::make_op("add"), squeeze, y); + auto relu = m1.add_instruction(migraphx::make_op("relu"), add); + m1.add_return({relu}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", s1); + auto y = m2.add_parameter("y", s2); + auto reduce_sum = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x); + auto unsqueeze = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), y); + auto add = m2.add_instruction(migraphx::make_op("add"), reduce_sum, unsqueeze); + auto relu = m2.add_instruction(migraphx::make_op("relu"), add); + auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), relu); + m2.add_return({squeeze}); + } + EXPECT(m1.sort() == m2.sort()); +} + 
+TEST_CASE(reduce_squeeze_pointwise5) +{ + auto s1 = migraphx::shape{migraphx::shape::float_type, {1, 3}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {1}, {0}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto reduce_sum = m1.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x); + auto squeeze = + m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), reduce_sum); + auto add = m1.add_instruction(migraphx::make_op("add"), squeeze, y); + auto relu = m1.add_instruction(migraphx::make_op("relu"), add); + m1.add_return({relu}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", s1); + auto y = m2.add_parameter("y", s2); + auto reduce_sum = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x); + auto unsqueeze = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1}}}), y); + auto add = m2.add_instruction(migraphx::make_op("add"), reduce_sum, unsqueeze); + auto relu = m2.add_instruction(migraphx::make_op("relu"), add); + auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), relu); + m2.add_return({squeeze}); + } + EXPECT(m1.sort() == m2.sort()); +} + TEST_CASE(reduce_squeeze_broadcast_pointwise) { auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 32, 10, 64, 64}}; @@ -2930,13 +4062,13 @@ TEST_CASE(transpose_unsqueeze_concat) std::vector unsqueezed_args; int64_t axis = 3; - std::transform( - args.begin(), - args.end(), - std::back_inserter(unsqueezed_args), - [&](migraphx::instruction_ref arg) { - return m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {axis}}}), arg); - }); + std::transform(args.begin(), + args.end(), + std::back_inserter(unsqueezed_args), + [&](migraphx::instruction_ref arg) { + return m1.add_instruction( + migraphx::make_op("unsqueeze", {{"axes", {axis}}}), arg); + }); auto concat = m1.add_instruction(migraphx::make_op("concat", 
{{"axis", axis}}), unsqueezed_args); m1.add_return({concat}); @@ -2951,13 +4083,13 @@ TEST_CASE(transpose_unsqueeze_concat) std::vector unsqueezed_args; int64_t axis = 1; - std::transform( - args.begin(), - args.end(), - std::back_inserter(unsqueezed_args), - [&](migraphx::instruction_ref arg) { - return m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {axis}}}), arg); - }); + std::transform(args.begin(), + args.end(), + std::back_inserter(unsqueezed_args), + [&](migraphx::instruction_ref arg) { + return m2.add_instruction( + migraphx::make_op("unsqueeze", {{"axes", {axis}}}), arg); + }); auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", axis}}), unsqueezed_args); auto transpose = m2.add_instruction( @@ -3474,6 +4606,63 @@ TEST_CASE(add_transpose) EXPECT(m1.sort() == m2.sort()); } +TEST_CASE(slice_squeeze_unsqueeze) +{ + migraphx::shape s{migraphx::shape::float_type, {12, 6, 1}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto slice = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), x); + auto squeeze = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 2}}}), slice); + auto unsqueeze = + m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), squeeze); + m1.add_return({unsqueeze}); + }; + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto unsqueeze = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), x); + auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {3}}}), unsqueeze); + auto slice = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), squeeze); + m2.add_return({slice}); + }; + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(slice_reshape_reshape) +{ + migraphx::shape s{migraphx::shape::float_type, {3, 3, 20}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto slice = m1.add_instruction( 
+ migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), x); + auto reshape1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {60}}}), slice); + auto reshape2 = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 5}}}), reshape1); + m1.add_return({reshape2}); + }; + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto slice = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), x); + auto reshape = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 3, 4, 5}}}), slice); + auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), reshape); + m2.add_return({squeeze}); + }; + + EXPECT(m1.sort() == m2.sort()); +} + TEST_CASE(flatten) { migraphx::shape s{migraphx::shape::float_type, {4608, 8, 2}}; @@ -3661,4 +4850,195 @@ TEST_CASE(conv_add_layernorm_conv) EXPECT(m1.sort() == m2.sort()); } +TEST_CASE(argmin_reshape_unsqueeze) +{ + // Test that argmin followed by reshape to add dimension gets simplified to unsqueeze + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 12}}; + auto s2 = migraphx::shape{migraphx::shape::int64_type, {2, 3, 4}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto argmin = m1.add_instruction(migraphx::make_op("argmin", {{"axis", 1}}), x); + auto reshape = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 1}}}), argmin); + auto bc = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s2.lens()}}), + reshape); + auto add = m1.add_instruction(migraphx::make_op("add"), bc, y); + m1.add_return({add}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", s1); + auto y = m2.add_parameter("y", s2); + auto argmin = m2.add_instruction(migraphx::make_op("argmin", {{"axis", 1}}), x); + auto unsqueeze = + m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), 
argmin); + auto bc = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s2.lens()}}), + unsqueeze); + auto add = m2.add_instruction(migraphx::make_op("add"), bc, y); + m2.add_return({add}); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(argmax_reshape_unsqueeze) +{ + // Test that argmax followed by reshape to add dimension gets simplified to unsqueeze + auto s1 = migraphx::shape{migraphx::shape::float_type, {3, 8}}; + auto s2 = migraphx::shape{migraphx::shape::int64_type, {3, 2, 4}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto argmax = m1.add_instruction(migraphx::make_op("argmax", {{"axis", 1}}), x); + auto reshape = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 1, 1}}}), argmax); + auto bc = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s2.lens()}}), + reshape); + auto add = m1.add_instruction(migraphx::make_op("add"), bc, y); + m1.add_return({add}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", s1); + auto y = m2.add_parameter("y", s2); + auto argmax = m2.add_instruction(migraphx::make_op("argmax", {{"axis", 1}}), x); + auto unsqueeze = + m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), argmax); + auto bc = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s2.lens()}}), + unsqueeze); + auto add = m2.add_instruction(migraphx::make_op("add"), bc, y); + m2.add_return({add}); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(argmin_negative_axis_reshape) +{ + // Test that argmin with negative axis followed by reshape works correctly + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s2 = migraphx::shape{migraphx::shape::int64_type, {2, 3, 2, 2}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto argmin = m1.add_instruction(migraphx::make_op("argmin", {{"axis", -1}}), x); + auto 
reshape = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3, 1, 1}}}), argmin); + auto bc = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s2.lens()}}), + reshape); + auto add = m1.add_instruction(migraphx::make_op("add"), bc, y); + m1.add_return({add}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", s1); + auto y = m2.add_parameter("y", s2); + auto argmin = m2.add_instruction(migraphx::make_op("argmin", {{"axis", -1}}), x); + auto unsqueeze = + m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {3}}}), argmin); + auto bc = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s2.lens()}}), + unsqueeze); + auto add = m2.add_instruction(migraphx::make_op("add"), bc, y); + m2.add_return({add}); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(argmin_squeeze_pointwise) +{ + auto s1 = migraphx::shape{migraphx::shape::float_type, {1, 8, 1024, 1280}}; + auto s2 = migraphx::shape{migraphx::shape::int64_type, {1, 1024, 1280}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto argmin = m1.add_instruction(migraphx::make_op("argmin", {{"axis", 1}}), x); + auto squeeze = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), argmin); + auto add = m1.add_instruction(migraphx::make_op("add"), squeeze, y); + auto relu = m1.add_instruction(migraphx::make_op("relu"), add); + m1.add_return({relu}); + } + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1.get_output_shapes() == m2.get_output_shapes()); +} + +TEST_CASE(argmax_squeeze_pointwise) +{ + auto s1 = migraphx::shape{migraphx::shape::float_type, {1, 8, 512, 640}}; + auto s2 = migraphx::shape{migraphx::shape::int64_type, {1, 512, 640}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto argmax = m1.add_instruction(migraphx::make_op("argmax", {{"axis", 1}}), x); + auto squeeze = 
m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), argmax); + auto add = m1.add_instruction(migraphx::make_op("add"), squeeze, y); + auto relu = m1.add_instruction(migraphx::make_op("relu"), add); + m1.add_return({relu}); + } + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1.get_output_shapes() == m2.get_output_shapes()); +} + +TEST_CASE(argmin_negative_axis_squeeze_pointwise) +{ + auto s1 = migraphx::shape{migraphx::shape::float_type, {1, 8, 256, 320}}; + auto s2 = migraphx::shape{migraphx::shape::int64_type, {1, 256, 320}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto argmin = m1.add_instruction(migraphx::make_op("argmin", {{"axis", -3}}), x); + auto squeeze = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), argmin); + auto add = m1.add_instruction(migraphx::make_op("add"), squeeze, y); + auto relu = m1.add_instruction(migraphx::make_op("relu"), add); + m1.add_return({relu}); + } + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1.get_output_shapes() == m2.get_output_shapes()); +} + +TEST_CASE(argmax_negative_axis_squeeze_pointwise) +{ + auto s1 = migraphx::shape{migraphx::shape::float_type, {1, 4, 128, 160}}; + auto s2 = migraphx::shape{migraphx::shape::int64_type, {1, 128, 160}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto argmax = m1.add_instruction(migraphx::make_op("argmax", {{"axis", -3}}), x); + auto squeeze = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), argmax); + auto add = m1.add_instruction(migraphx::make_op("add"), squeeze, y); + auto relu = m1.add_instruction(migraphx::make_op("relu"), add); + m1.add_return({relu}); + } + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1.get_output_shapes() == m2.get_output_shapes()); +} + +TEST_CASE(gather_strided_view_elements_mismatch) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", 
{migraphx::shape::float_type, {12}}); + migraphx::shape si{migraphx::shape::int32_type, {6}}; + std::vector indices = {0, 2, 5, 7, 9, 11}; + auto li = m1.add_literal(migraphx::literal{si, indices}); + auto g = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), x, li); + m1.add_return({g}); + } + auto m2 = m1; + run_pass(m1); + EXPECT(m1.get_output_shapes() == m2.get_output_shapes()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp index 4679f265753..cabdffb3ff3 100644 --- a/test/split_reduce.cpp +++ b/test/split_reduce.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -294,4 +295,112 @@ TEST_CASE(double_split_live) EXPECT(p1.sort() == p2.sort()); } +// Test multi-alias in parallel reduce scenario - both reduce outputs are aliased by multi_alias_op +// The pass should split both reduces and extract the multi_alias to the main module +TEST_CASE(parallel_reduce_multi_alias) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3, 327680}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto rsum = add_reduce( + p1, "fuse_reduce0", {x}, {2}, [&](auto* rm, const auto& inputs, const auto& axes) { + auto xx = add_pointwise(p1, rm, "main:pointwise0", {inputs[0]}, squared()); + auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), + inputs[0]); + auto rsum2 = + rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), xx); + // multi_alias_op aliases both reduce outputs + return 
rm->add_instruction(multi_alias_op{}, rsum1, rsum2); + }); + mm->add_return({rsum}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + // The pointwise (squared) is extracted to main module + auto xx = add_pointwise(p2, mm, "main:pointwise0", {x}, squared()); + // Split module takes both xx and x as inputs + auto rsum = + add_reduce(p2, + "fuse_reduce0_split", + {xx, x}, + {2}, + "assign_add", + [&](auto* rm, + const auto& inputs, + const auto& axes) -> std::vector { + // inputs[0] is xx (squared), inputs[1] is x + // The pass returns (rsum1, rsum2) order based on the original fused + // module order + auto rsum1 = rm->add_instruction( + migraphx::make_op("reduce_sum", {{"axes", axes}}), inputs[1]); + auto rsum2 = rm->add_instruction( + migraphx::make_op("reduce_sum", {{"axes", axes}}), inputs[0]); + return {rsum1, rsum2}; + }); + auto rsum1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), rsum); + auto rsum2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), rsum); + // multi_alias_op is moved to main module after split + auto ma = mm->add_instruction(multi_alias_op{}, rsum1, rsum2); + mm->add_return({ma}); + } + EXPECT(p1.sort() == p2.sort()); +} + +// Test that find_alive correctly identifies live instructions through multi-alias chain +// sqrt is computed before reduce, used after reduce through multi_alias - should be split out +TEST_CASE(split_with_multi_alias_alive) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3, 327680}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto rsum = add_reduce( + p1, "fuse_reduce0", {x}, {2}, [&](auto* rm, const auto& inputs, const auto& axes) { + // Create a computation before the reduce + auto sqrt = + add_pointwise(p1, rm, "main:pointwise0", {inputs[0]}, single_pointwise("sqrt")); + auto rsum1 = + 
rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), sqrt); + // multi_alias aliases sqrt and rsum1 - sqrt should be identified as alive + auto ma = rm->add_instruction(multi_alias_op{}, sqrt, rsum1); + auto rsumb = rm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), ma); + return add_pointwise( + p1, rm, "main:pointwise1", {rsumb, sqrt}, single_pointwise("mul")); + }); + mm->add_return({rsum}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + // sqrt is computed first, then passed to split module + auto sqrt = add_pointwise(p2, mm, "main:pointwise0", {x}, single_pointwise("sqrt")); + auto rsums = + add_reduce(p2, + "fuse_reduce0_split", + {sqrt}, + {2}, + "assign_add", + [&](auto* rm, const auto& inputs, const auto& axes) { + return rm->add_instruction( + migraphx::make_op("reduce_sum", {{"axes", axes}}), inputs[0]); + }); + // After split: multi_alias(sqrt, rsums) - shape is {2,3,327680} from sqrt + // multibroadcast is eliminated since multi_alias already has the right shape + auto ma = mm->add_instruction(multi_alias_op{}, sqrt, rsums); + // multiply multi_alias result with sqrt + auto result = add_pointwise(p2, mm, "main:pointwise1", {ma, sqrt}, single_pointwise("mul")); + mm->add_return({result}); + } + EXPECT(p1.sort() == p2.sort()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/tf/tests/stridedslice_test.cpp b/test/tf/tests/stridedslice_test.cpp index 514ea53fb7a..f05cc49b5b9 100644 --- a/test/tf/tests/stridedslice_test.cpp +++ b/test/tf/tests/stridedslice_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -37,7 +37,7 @@ TEST_CASE(stridedslice_test) migraphx::make_op( "slice", {{"starts", {0, 0, 0, 0}}, {"ends", {1, 1, 1, 5}}, {"axes", {0, 1, 2, 3}}}), l1); - auto shrink_axis = 1; + auto shrink_axis = 2; mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {shrink_axis}}}), l2); auto prog = optimize_tf("stridedslice_test.pb", true); diff --git a/test/verify/test_gather_axis0_half_split_concat.cpp b/test/verify/test_gather_axis0_half_split_concat.cpp new file mode 100644 index 00000000000..b16cfed672e --- /dev/null +++ b/test/verify/test_gather_axis0_half_split_concat.cpp @@ -0,0 +1,49 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_axis0_half_split_concat : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto data = mm->add_parameter("X", {migraphx::shape::float_type, {4, 3}}); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {4}}; + std::vector indices = {2, 3, 0, 1}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = mm->add_instruction( + migraphx::make_op("gather", {{"axis", int64_t{0}}}), data, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_axis0_slice_broadcast.cpp b/test/verify/test_gather_axis0_slice_broadcast.cpp new file mode 100644 index 00000000000..f5598035cd0 --- /dev/null +++ b/test/verify/test_gather_axis0_slice_broadcast.cpp @@ -0,0 +1,49 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_axis0_slice_broadcast : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("X", {migraphx::shape::float_type, {1, 4}}); + auto reshape = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), x); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {2, 8}}; + std::vector indices = {0, 0, 0, 1, 1, 2, 2, 3, 0, 1, 1, 2, 2, 3, 3, 3}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = mm->add_instruction(migraphx::make_op("gather"), reshape, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_axis1_factorized_grid_const.cpp b/test/verify/test_gather_axis1_factorized_grid_const.cpp new file mode 100644 index 00000000000..de69e38bc0b --- /dev/null +++ b/test/verify/test_gather_axis1_factorized_grid_const.cpp @@ -0,0 +1,47 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_axis1_factorized_grid_const + : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto data = mm->add_parameter("data", {migraphx::shape::float_type, {3, 8, 5}}); + migraphx::shape si{migraphx::shape::int32_type, {2, 2, 1}}; + std::vector indices = {1, 3, 5, 7}; + auto li = mm->add_literal(migraphx::literal{si, indices}); + auto g = mm->add_instruction(migraphx::make_op("gather", {{"axis", 1}}), data, li); + mm->add_return({g}); + + return p; + } +}; diff --git a/test/verify/test_gather_axis1_factorized_grid_multi_const.cpp b/test/verify/test_gather_axis1_factorized_grid_multi_const.cpp new file mode 100644 index 00000000000..6c49952d49c --- /dev/null +++ b/test/verify/test_gather_axis1_factorized_grid_multi_const.cpp @@ -0,0 +1,47 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_axis1_factorized_grid_multi_const + : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto data = mm->add_parameter("data", {migraphx::shape::float_type, {2, 27, 4}}); + migraphx::shape si{migraphx::shape::int32_type, {3, 1}}; + std::vector indices = {5, 14, 23}; + auto li = mm->add_literal(migraphx::literal{si, indices}); + auto g = mm->add_instruction(migraphx::make_op("gather", {{"axis", 1}}), data, li); + mm->add_return({g}); + + return p; + } +}; diff --git a/test/verify/test_gather_flatten_channel_parity.cpp b/test/verify/test_gather_flatten_channel_parity.cpp new file mode 100644 index 00000000000..074ba43369f --- /dev/null +++ b/test/verify/test_gather_flatten_channel_parity.cpp @@ -0,0 +1,53 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_flatten_channel_parity : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("X", {migraphx::shape::float_type, {1, 3, 4, 4}}); + auto reshape_flat = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {48}}}), x); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {4, 3, 2, 2}}; + std::vector indices = {0, 2, 8, 10, 16, 18, 24, 26, 32, 34, 40, 42, + 4, 6, 12, 14, 20, 22, 28, 30, 36, 38, 44, 46, + 1, 3, 9, 11, 17, 19, 25, 27, 33, 35, 41, 43, + 5, 7, 13, 15, 21, 23, 29, 31, 37, 39, 45, 47}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = + mm->add_instruction(migraphx::make_op("gather"), reshape_flat, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_flatten_channel_patch.cpp b/test/verify/test_gather_flatten_channel_patch.cpp new file mode 100644 index 00000000000..97d0b409e5f --- /dev/null +++ b/test/verify/test_gather_flatten_channel_patch.cpp @@ -0,0 +1,50 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_flatten_channel_patch : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("X", {migraphx::shape::float_type, {1, 3, 4, 4}}); + auto reshape_flat = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {48}}}), x); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {4, 3, 1, 1}}; + std::vector indices = {5, 21, 37, 9, 25, 41, 6, 22, 38, 10, 26, 42}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = + mm->add_instruction(migraphx::make_op("gather"), reshape_flat, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_flatten_multi_axis_stride.cpp b/test/verify/test_gather_flatten_multi_axis_stride.cpp new file mode 100644 index 00000000000..c6d3e4af253 --- /dev/null +++ b/test/verify/test_gather_flatten_multi_axis_stride.cpp @@ -0,0 +1,51 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_flatten_multi_axis_stride : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("X", {migraphx::shape::float_type, {1, 3, 4, 4}}); + auto reshape_flat = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {48}}}), x); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {2, 3, 1, 4}}; + std::vector indices = {0, 1, 2, 3, 16, 17, 18, 19, 32, 33, 34, 35, + 4, 5, 6, 7, 20, 21, 22, 23, 36, 37, 38, 39}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = + mm->add_instruction(migraphx::make_op("gather"), reshape_flat, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_flatten_permutation.cpp b/test/verify/test_gather_flatten_permutation.cpp new file mode 100644 index 00000000000..0091d82ac65 --- /dev/null +++ b/test/verify/test_gather_flatten_permutation.cpp @@ -0,0 +1,50 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_flatten_permutation : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("X", {migraphx::shape::float_type, {1, 1, 4, 4}}); + auto reshape_flat = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {16}}}), x); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {4, 1, 2, 2}}; + std::vector indices = {0, 2, 8, 10, 4, 6, 12, 14, 1, 3, 9, 11, 5, 7, 13, 15}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = + mm->add_instruction(migraphx::make_op("gather"), reshape_flat, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_flatten_rectangular_three_axes.cpp b/test/verify/test_gather_flatten_rectangular_three_axes.cpp new file mode 100644 index 00000000000..1ebc18d0c5b --- /dev/null +++ b/test/verify/test_gather_flatten_rectangular_three_axes.cpp @@ -0,0 +1,50 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_flatten_rectangular_three_axes + : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto data = mm->add_parameter("X", {migraphx::shape::float_type, {2, 24, 5}}); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {2, 2, 3}}; + std::vector indices = {4, 5, 6, 8, 9, 10, 16, 17, 18, 20, 21, 22}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = mm->add_instruction( + migraphx::make_op("gather", {{"axis", int64_t{1}}}), data, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_flatten_rectangular_two_axes.cpp b/test/verify/test_gather_flatten_rectangular_two_axes.cpp new file mode 100644 index 00000000000..ce63af0cd28 --- /dev/null +++ b/test/verify/test_gather_flatten_rectangular_two_axes.cpp @@ -0,0 +1,50 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_flatten_rectangular_two_axes + : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto data = mm->add_parameter("X", {migraphx::shape::float_type, {1, 12}}); + auto flatten = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {2, 3}}; + std::vector indices = {4, 5, 6, 8, 9, 10}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = mm->add_instruction(migraphx::make_op("gather"), flatten, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_flatten_stride_first.cpp b/test/verify/test_gather_flatten_stride_first.cpp new file mode 100644 index 00000000000..0956bfdb0f3 --- /dev/null +++ b/test/verify/test_gather_flatten_stride_first.cpp @@ -0,0 +1,49 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + /* Verify case: parameter "X" of float shape {1, 8} is reshaped to {8}, then gathered (no axis attribute is passed to make_op) with a constant int32 {1, 4} index literal {0, 2, 4, 6}, i.e. every second position starting at 0; the gather result is the program's return value. NOTE(review): the #include targets and the template arguments of verify_program and std::vector are not visible in this chunk - confirm against the real file. */ +struct test_gather_flatten_stride_first : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("X", {migraphx::shape::float_type, {1, 8}}); + auto reshape = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), x); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {1, 4}}; + std::vector indices = {0, 2, 4, 6}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = mm->add_instruction(migraphx::make_op("gather"), reshape, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_flatten_stride_grid.cpp b/test/verify/test_gather_flatten_stride_grid.cpp new file mode 100644 index 00000000000..cadbfda9768 --- /dev/null +++ b/test/verify/test_gather_flatten_stride_grid.cpp @@ -0,0 +1,53 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software.
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include "verify_program.hpp" +#include +#include +#include + /* Verify case: parameter "X" of float shape {1, 3, 16, 16} is reshaped to {768}, then gathered (no axis attribute is passed to make_op) with a constant int32 {1, 3, 4, 4} index literal; the gather result is the program's return value. The 48 listed indices form 12 runs of four values spaced 4 apart, with run starts 17, 81, ..., 721 spaced 64 apart. NOTE(review): the #include targets and the template arguments of verify_program and std::vector are not visible in this chunk - confirm against the real file. */ +struct test_gather_flatten_stride_grid : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto data = mm->add_parameter("X", {migraphx::shape::float_type, {1, 3, 16, 16}}); + auto flatten = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {768}}}), data); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {1, 3, 4, 4}}; + std::vector indices = {17, 21, 25, 29, 81, 85, 89, 93, 145, 149, 153, 157, + 209, 213, 217, 221, 273, 277, 281, 285, 337, 341, 345, 349, + 401, 405, 409, 413, 465, 469, 473, 477, 529, 533, 537, 541, + 593, 597, 601, 605, 657, 661, 665, 669, 721, 725, 729, 733}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = mm->add_instruction(migraphx::make_op("gather"), flatten, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_flatten_stride_offset.cpp b/test/verify/test_gather_flatten_stride_offset.cpp new file mode 100644 index 00000000000..21fa55383d2 --- /dev/null +++ b/test/verify/test_gather_flatten_stride_offset.cpp @@ -0,0 +1,50 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ * + */ + +#include "verify_program.hpp" +#include +#include +#include + /* Verify case: parameter "X" of float shape {1, 16} is reshaped to {16}, then gathered (no axis attribute is passed to make_op) with a constant int32 {1, 4} index literal {1, 5, 9, 13}, i.e. every fourth position starting at 1; the gather result is the program's return value. NOTE(review): the #include targets and the template arguments of verify_program and std::vector are not visible in this chunk - confirm against the real file. */ +struct test_gather_flatten_stride_offset : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto data = mm->add_parameter("X", {migraphx::shape::float_type, {1, 16}}); + auto flatten = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {16}}}), data); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {1, 4}}; + std::vector indices = {1, 5, 9, 13}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = mm->add_instruction(migraphx::make_op("gather"), flatten, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_flatten_stride_slice.cpp b/test/verify/test_gather_flatten_stride_slice.cpp new file mode 100644 index 00000000000..3f45a0fc8e1 --- /dev/null +++ b/test/verify/test_gather_flatten_stride_slice.cpp @@ -0,0 +1,50 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + /* Verify case: parameter "X" of float shape {1, 8} is reshaped to {8}, then gathered (no axis attribute is passed to make_op) with a constant int32 {2, 2} index literal {1, 5, 2, 6}; the gather result is the program's return value. NOTE(review): the #include targets and the template arguments of verify_program and std::vector are not visible in this chunk - confirm against the real file. */ +struct test_gather_flatten_stride_slice : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("X", {migraphx::shape::float_type, {1, 8}}); + auto reshape_flat = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), x); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {2, 2}}; + std::vector indices = {1, 5, 2, 6}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = + mm->add_instruction(migraphx::make_op("gather"), reshape_flat, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_nhwc.cpp b/test/verify/test_gather_nhwc.cpp new file mode 100644 index 00000000000..df9dc3a0285 --- /dev/null +++ b/test/verify/test_gather_nhwc.cpp @@ -0,0 +1,48 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software.
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + /* Verify case: parameter "data" of float shape {1, 2, 2, 3} is transposed with permutation {0, 2, 3, 1}, then gathered along axis 1 with a constant int32 index literal of shape {1} holding the value 1; the gather result is the program's return value. NOTE(review): the #include targets and the template arguments of verify_program and std::vector are not visible in this chunk - confirm against the real file. */ +struct test_gather_nhwc : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape data_shape{migraphx::shape::float_type, {1, 2, 2, 3}}; + migraphx::shape indices_shape{migraphx::shape::int32_type, {1}}; + std::vector indices = {1}; + auto data = mm->add_parameter("data", data_shape); + auto idx_lit = mm->add_literal(migraphx::literal{indices_shape, indices}); + auto transpose = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), data); + auto gather = + mm->add_instruction(migraphx::make_op("gather", {{"axis", 1}}), transpose, idx_lit); + mm->add_return({gather}); + return p; + } +}; diff --git a/test/verify/test_gather_simplify.cpp b/test/verify/test_gather_simplify.cpp new file mode 100644 index 00000000000..85115af442e --- /dev/null +++ b/test/verify/test_gather_simplify.cpp @@ -0,0 +1,46 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "verify_program.hpp" +#include +#include +#include + /* Verify case: parameter "data" of float shape {2, 4} is gathered along axis 1 with a constant int32 {2, 3} index literal {1, 1, 1, 2, 2, 2}; the gather result is the program's return value. NOTE(review): the #include targets and the template arguments of verify_program and std::vector are not visible in this chunk - confirm against the real file. */ +struct test_gather_simplify : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape data_shape{migraphx::shape::float_type, {2, 4}}; + migraphx::shape indices_shape{migraphx::shape::int32_type, {2, 3}}; + std::vector indices = {1, 1, 1, 2, 2, 2}; + auto data = mm->add_parameter("data", data_shape); + auto idx_lit = mm->add_literal(migraphx::literal{indices_shape, indices}); + auto gather = + mm->add_instruction(migraphx::make_op("gather", {{"axis", 1}}), data, idx_lit); + mm->add_return({gather}); + return p; + } +}; diff --git a/tools/api/api.cpp b/tools/api/api.cpp index 0526175ddf3..88763e882f8 100644 --- a/tools/api/api.cpp +++ b/tools/api/api.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -377,19 +377,9 @@ struct custom_operation return op.compute(std::move(ctx), std::move(output_shape), std::move(inputs)); } /* custom_operation::output_alias hunk: the added (+) body returns the result of op.output_alias(std::move(inputs)) unchanged. The removed (-) body reduced that result to a single index: -1 when it was empty, a MIGRAPHX_THROW when it held more than one entry, otherwise its front element. */ - std::ptrdiff_t output_alias(std::vector inputs) const + std::vector output_alias(std::vector inputs) const { - auto alias_vec = op.output_alias(std::move(inputs)); - // TODO: For now, only support one output alias - if(alias_vec.empty()) - { - return -1; - } - if(alias_vec.size() > 1) - { - MIGRAPHX_THROW("Currently, CustomOps in MIGraphX only supports one output_alias"); - } - return alias_vec.front(); + return op.output_alias(std::move(inputs)); } bool runs_on_offload_target() const { return op.runs_on_offload_target(); } diff --git a/tools/include/operation.hpp b/tools/include/operation.hpp index 769cc8f9be7..a2e73c79940 100644 --- a/tools/include/operation.hpp +++ b/tools/include/operation.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -79,9 +79,9 @@ struct operation * the same the `output` shape. */ argument compute(context& ctx, const shape& output, const std::vector& input) const; - /// An optional method to return which argument the output will alias. If - /// there is no aliased output then -1 can be returned. - std::ptrdiff_t output_alias(const std::vector& input) const; + /// An optional method to return which arguments the output will alias. If + /// there is no aliased output then an empty vector can be returned. + std::vector output_alias(const std::vector& input) const; /// An optional stream operator to print the operation.
When this is not /// implemented, it will just print the operation's name. friend std::ostream& operator<<(std::ostream& os, const operation& op); @@ -410,9 +410,9 @@ auto need_normalization_op(const T& x) } /* Fallback output_alias_op: the interface(...) declaration further down names detail::output_alias_op as the default for the output_alias virtual. The added (+) version returns an empty vector ({}), where the removed (-) version returned -1; the same hunk pair changes the declared return type of that virtual from std::ptrdiff_t to std::vector (element type not visible in this chunk - confirm against the real file). */ template -std::ptrdiff_t output_alias_op(const T&, const std::vector&) +std::vector output_alias_op(const T&, const std::vector&) { - return -1; + return {}; } template @@ -505,92 +505,98 @@ lifetime get_lifetime_op(const T&) } // namespace detail <% - interface( - 'operation', - virtual('name', returns = 'std::string', const = True), - virtual( - 'is_context_free', returns = 'bool', const = True, default = 'detail::is_context_free_op'), - virtual('need_normalization', - returns = 'bool', - const = True, - default = 'detail::need_normalization_op'), - virtual('has_finalize', returns = 'bool', const = True, default = 'detail::has_finalize_op'), - virtual( - 'get_lifetime', returns = 'lifetime', const = True, default = 'detail::get_lifetime_op'), - virtual('output_alias', - returns = 'std::ptrdiff_t', - input = 'const std::vector&', - const = True, - default = 'detail::output_alias_op'), - virtual('compile', - returns = 'value', - ctx = 'context&', - output = 'const shape&', - input = 'const std::vector&', - default = 'detail::compile_op'), - virtual('finalize', - ctx = 'context&', - output = 'const shape&', - input = 'const std::vector&', - default = 'detail::finalize_op'), - virtual('compute_shape', - returns = 'shape', - input = 'const std::vector&', - const = True, - default = 'detail::compute_shape_op'), - virtual('compute_shape', - returns = 'shape', - inputs = 'const std::vector&', - mod_args = 'const std::vector&', - const = True, - default = 'detail::mod_compute_shape_op'), - virtual('compute', - returns = 'argument', - ctx = 'context&', - output = 'const shape&', - input = 'const std::vector&', - const = True, - default = 'detail::compute_op'), - virtual('compute', - returns = 'argument', - output = 'const shape&', - input = 'const std::vector&', -
const = True, - default = 'detail::compute_op'), - virtual( - 'compute', - returns = 'argument', - output = 'const shape&', - input = 'const std::vector&', - module_args = 'const std::vector&', - run = - 'std::function(module_ref&, const std::unordered_map&)>', - const = True, - default = 'detail::compute_op'), - virtual( - 'compute', - returns = 'argument', - ctx = 'context&', - output = 'const shape&', - input = 'const std::vector&', - module_args = 'const std::vector&', - run = - 'std::function(module_ref&, const std::unordered_map&)>', - const = True, - default = 'detail::compute_op'), - virtual('to_value', returns = 'value', const = True, default = 'detail::to_value_op'), - virtual('from_value', v = 'const value&', default = 'detail::from_value_op'), - virtual('attributes', returns = 'value', const = True, default = 'detail::attributes_op'), - friend('operator<<', - returns = 'std::ostream &', - os = 'std::ostream &', - op = 'const operation &', - using = 'migraphx::detail::operation_operators::operator<<'), - friend('operator==', - returns = 'bool', - x = 'const operation &', - y = 'const operation &', - using = 'migraphx::detail::operation_operators::operator==')) %> + interface( + 'operation', + virtual('name', returns = 'std::string', const = True), + virtual('is_context_free', + returns = 'bool', + const = True, + default = 'detail::is_context_free_op'), + virtual('need_normalization', + returns = 'bool', + const = True, + default = 'detail::need_normalization_op'), + virtual( + 'has_finalize', returns = 'bool', const = True, default = 'detail::has_finalize_op'), + virtual('get_lifetime', + returns = 'lifetime', + const = True, + default = 'detail::get_lifetime_op'), + virtual('output_alias', + returns = 'std::vector', + input = 'const std::vector&', + const = True, + default = 'detail::output_alias_op'), + virtual('compile', + returns = 'value', + ctx = 'context&', + output = 'const shape&', + input = 'const std::vector&', + default =
'detail::compile_op'), + virtual('finalize', + ctx = 'context&', + output = 'const shape&', + input = 'const std::vector&', + default = 'detail::finalize_op'), + virtual('compute_shape', + returns = 'shape', + input = 'const std::vector&', + const = True, + default = 'detail::compute_shape_op'), + virtual('compute_shape', + returns = 'shape', + inputs = 'const std::vector&', + mod_args = 'const std::vector&', + const = True, + default = 'detail::mod_compute_shape_op'), + virtual('compute', + returns = 'argument', + ctx = 'context&', + output = 'const shape&', + input = 'const std::vector&', + const = True, + default = 'detail::compute_op'), + virtual('compute', + returns = 'argument', + output = 'const shape&', + input = 'const std::vector&', + const = True, + default = 'detail::compute_op'), + virtual( + 'compute', + returns = 'argument', + output = 'const shape&', + input = 'const std::vector&', + module_args = 'const std::vector&', + run = + 'std::function(module_ref&, const std::unordered_map&)>', + const = True, + default = 'detail::compute_op'), + virtual( + 'compute', + returns = 'argument', + ctx = 'context&', + output = 'const shape&', + input = 'const std::vector&', + module_args = 'const std::vector&', + run = + 'std::function(module_ref&, const std::unordered_map&)>', + const = True, + default = 'detail::compute_op'), + virtual('to_value', returns = 'value', const = True, default = 'detail::to_value_op'), + virtual('from_value', v = 'const value&', default = 'detail::from_value_op'), + virtual('attributes', returns = 'value', const = True, default = 'detail::attributes_op'), + friend('operator<<', + returns = 'std::ostream &', + os = 'std::ostream &', + op = 'const operation &', + using = 'migraphx::detail::operation_operators::operator<<'), + friend('operator==', + returns = 'bool', + x = 'const operation &', + y = 'const operation &', + using = 'migraphx::detail::operation_operators::operator==')) +%> inline bool operator!=(const operation& x, const
operation& y) {