From 2e4d205f48bb312846b858070d4da6549dfd04ba Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 13 Jan 2026 16:41:11 -0600 Subject: [PATCH 001/107] progress --- src/include/migraphx/instruction.hpp | 6 +++ src/instruction.cpp | 14 +++++++ .../include/migraphx/onnx/onnx_parser.hpp | 5 +++ src/onnx/onnx_parser.cpp | 42 +++++++++++++++---- 4 files changed, 59 insertions(+), 8 deletions(-) diff --git a/src/include/migraphx/instruction.hpp b/src/include/migraphx/instruction.hpp index 0369bb6ab64..536072890aa 100644 --- a/src/include/migraphx/instruction.hpp +++ b/src/include/migraphx/instruction.hpp @@ -32,6 +32,7 @@ #include #include #include +#include #include #include @@ -95,6 +96,10 @@ struct MIGRAPHX_EXPORT instruction /// Where this instruction is used as an input to another instruction const std::vector& outputs() const; + const std::set& get_debug_symbols() const; + + void add_debug_symbol(const std::string& symbol); + MIGRAPHX_EXPORT friend bool operator==(const instruction& x, const instruction& y); MIGRAPHX_EXPORT friend bool operator!=(const instruction& x, const instruction& y); @@ -188,6 +193,7 @@ struct MIGRAPHX_EXPORT instruction std::vector output; std::vector arguments; std::vector module_args; + std::set debug_symbols; literal lit; bool normalized = false; std::size_t target_id = 0; diff --git a/src/instruction.cpp b/src/instruction.cpp index b8c68a42ce9..69eeda6c3f1 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -28,12 +28,16 @@ #include #include #include +#include +#include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_SHOW_DEBUG_SYMBOLS) + template static auto equal_to(const T& x) { @@ -188,6 +192,10 @@ const std::vector& instruction::module_inputs() const { return modul const std::vector& instruction::outputs() const { return output; } +const std::set& instruction::get_debug_symbols() const { return debug_symbols; } + +void instruction::add_debug_symbol(const std::string& symbol) { debug_symbols.insert(symbol); } + bool operator==(const instruction& x, const instruction& y) { if(not std::equal(x.arguments.begin(), @@ -437,6 +445,12 @@ void instruction::print(std::ostream& os, // print tid if(ins->target_id != 0) os << ", target_id=" << ins->target_id; + + // print debug symbols if enabled + if(enabled(MIGRAPHX_SHOW_DEBUG_SYMBOLS{}) and not ins->debug_symbols.empty()) + { + os << " /* " << join_strings(ins->debug_symbols, ", ") << " */"; + } } static void debug_name(std::ostream& os, const instruction& ins) diff --git a/src/onnx/include/migraphx/onnx/onnx_parser.hpp b/src/onnx/include/migraphx/onnx/onnx_parser.hpp index e0d8d498a6a..2ae868cfe24 100644 --- a/src/onnx/include/migraphx/onnx/onnx_parser.hpp +++ b/src/onnx/include/migraphx/onnx/onnx_parser.hpp @@ -53,6 +53,11 @@ struct onnx_parser std::size_t num_outputs = 1; std::string name = ""; module* mod = nullptr; + std::string onnx_node_name{}; + std::string onnx_op_type{}; + + std::string get_debug_symbol() const; + instruction_ref make_contiguous(instruction_ref ins) const; instruction_ref add_bias(const std::vector& args, instruction_ref curr_ins, diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index 2668be70aed..37afd52457e 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -122,6 +122,13 @@ instruction_ref onnx_parser::node_info::make_contiguous(instruction_ref ins) con return ins; } +std::string onnx_parser::node_info::get_debug_symbol() const +{ + if(onnx_node_name.empty() and onnx_op_type.empty()) + return {}; + return "onnx::" + onnx_op_type + "::" + onnx_node_name; +} + instruction_ref onnx_parser::node_info::add_bias(const std::vector& args, instruction_ref curr_ins, uint64_t axis) const @@ -132,16 +139,15 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vectorget_shape().dynamic()) { - bias_bcast = - mod->add_instruction(make_op("broadcast", {{"axis", axis}}), args[2], curr_ins); + bias_bcast = add_instruction(make_op("broadcast", {{"axis", axis}}), args[2], curr_ins); } else { - bias_bcast = mod->add_instruction( + bias_bcast = add_instruction( make_op("broadcast", {{"axis", axis}, {"out_lens", curr_ins->get_shape().lens()}}), args[2]); } - return mod->add_instruction(make_op("add"), curr_ins, bias_bcast); + return add_instruction(make_op("add"), curr_ins, bias_bcast); } return curr_ins; } @@ -175,21 +181,39 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name, std::vector inputs) const { - return migraphx::add_common_op(*mod, make_op(op_name), std::move(inputs)); + auto ins = migraphx::add_common_op(*mod, make_op(op_name), std::move(inputs)); + auto debug_symbol = get_debug_symbol(); + if(not debug_symbol.empty()) + { + ins->add_debug_symbol(debug_symbol); + } + return ins; } instruction_ref onnx_parser::node_info::add_instruction(const operation& op, const std::vector& args) const { - return mod->add_instruction(op, args); + auto ins = mod->add_instruction(op, args); + auto debug_symbol = get_debug_symbol(); + if(not debug_symbol.empty()) + { + ins->add_debug_symbol(debug_symbol); + } + return ins; } instruction_ref onnx_parser::node_info::add_instruction(const operation& op, const std::vector& args, const std::vector& mods) const { - return mod->add_instruction(op, args, mods); + auto ins = mod->add_instruction(op, args, mods); + auto debug_symbol = get_debug_symbol(); + if(not debug_symbol.empty()) + { + ins->add_debug_symbol(debug_symbol); + } + return ins; } instruction_ref onnx_parser::node_info::add_literal(literal l) const @@ -585,7 +609,9 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini { std::string node_name = node.op_type() + "_" + std::to_string(mod->size()); result = ops[node.op_type()]( - *this, {get_attributes(node), output_num, node_name, mod}, args); + *this, + {get_attributes(node), output_num, node_name, mod, node.name(), node.op_type()}, + args); } output_num = std::min(output_num, result.size()); From bb622af3639b619324e5ae7f4ceb8dafdb70170c Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 18 Feb 2026 12:44:36 -0600 Subject: [PATCH 002/107] dd --- src/onnx/include/migraphx/onnx/onnx_parser.hpp | 6 +----- src/onnx/onnx_parser.cpp | 11 ++--------- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/src/onnx/include/migraphx/onnx/onnx_parser.hpp b/src/onnx/include/migraphx/onnx/onnx_parser.hpp index 2ae868cfe24..dc8417609fc 100644 --- a/src/onnx/include/migraphx/onnx/onnx_parser.hpp +++ b/src/onnx/include/migraphx/onnx/onnx_parser.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -53,10 +53,6 @@ struct onnx_parser std::size_t num_outputs = 1; std::string name = ""; module* mod = nullptr; - std::string onnx_node_name{}; - std::string onnx_op_type{}; - - std::string get_debug_symbol() const; instruction_ref make_contiguous(instruction_ref ins) const; instruction_ref add_bias(const std::vector& args, diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index 37afd52457e..90f4cee06bd 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -122,13 +122,6 @@ instruction_ref onnx_parser::node_info::make_contiguous(instruction_ref ins) con return ins; } -std::string onnx_parser::node_info::get_debug_symbol() const -{ - if(onnx_node_name.empty() and onnx_op_type.empty()) - return {}; - return "onnx::" + onnx_op_type + "::" + onnx_node_name; -} - instruction_ref onnx_parser::node_info::add_bias(const std::vector& args, instruction_ref curr_ins, uint64_t axis) const @@ -178,8 +171,8 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s * operation. * */ -instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name, - std::vector inputs) const +instruction_ref onnx_parser::nme ode_info::add_common_op(const std::string& op_name, + std::vector inputs) const { auto ins = migraphx::add_common_op(*mod, make_op(op_name), std::move(inputs)); auto debug_symbol = get_debug_symbol(); From 432ed8307fd773f684499319a61c99c0cf3d2461 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 18 Feb 2026 15:51:33 -0600 Subject: [PATCH 003/107] Add onnx_parser and module changes --- src/common.cpp | 60 ++++++++++++++++--- src/include/migraphx/common.hpp | 20 +++++++ src/include/migraphx/module.hpp | 36 +++++++++++ src/instruction.cpp | 8 ++- src/module.cpp | 57 ++++++++++++++++-- .../include/migraphx/onnx/onnx_parser.hpp | 3 + src/onnx/onnx_parser.cpp | 45 +++----------- 7 files changed, 179 insertions(+), 50 deletions(-) diff --git a/src/common.cpp b/src/common.cpp index 10bc5d0707f..72abcd3d173 100644 --- a/src/common.cpp +++ b/src/common.cpp @@ -151,10 +151,11 @@ shape common_shape(const std::vector& shapes) return {compute_common_types(shapes), compute_common_lens(shapes)}; } -std::vector insert_common_args(module& m, - instruction_ref ins, - std::vector inputs, - common_options options) +std::vector insert_common_args_impl(module& m, + instruction_ref ins, + const std::string& debug_symbol, + std::vector inputs, + common_options options) { if(std::any_of( inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); })) @@ -167,7 +168,10 @@ std::vector insert_common_args(module& m, auto s0 = inputs[0]->get_shape(); // always add both multibroadcast instructions for dynamic shapes inputs[0] = m.insert_instruction( - ins, make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), inputs); + ins, + make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), + debug_symbol, + inputs); std::transform(inputs.begin() + 1, inputs.end(), inputs.begin() + 1, [&](auto input) { // uses previous input to avoid recalculating the common shape from the // full set of input shapes at runtime @@ -175,6 +179,7 @@ std::vector insert_common_args(module& m, return m.insert_instruction( ins, make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), + debug_symbol, input, inputs[0]); }); @@ -186,7 +191,7 @@ std::vector insert_common_args(module& m, if(input->get_shape().type() != c_type) { input = m.insert_instruction( - ins, make_op("convert", {{"target_type", c_type}}), input); + ins, make_op("convert", {{"target_type", c_type}}), debug_symbol, input); } return input; }); @@ -198,13 +203,16 @@ std::vector insert_common_args(module& m, std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { if(options.common_lens and input->get_shape().lens() != common.lens()) { - input = m.insert_instruction( - ins, make_op("multibroadcast", {{"out_lens", common.lens()}}), input); + input = + m.insert_instruction(ins, + make_op("multibroadcast", {{"out_lens", common.lens()}}), + debug_symbol, + input); } if(options.common_type and input->get_shape().type() != common.type()) { input = m.insert_instruction( - ins, make_op("convert", {{"target_type", common.type()}}), input); + ins, make_op("convert", {{"target_type", common.type()}}), debug_symbol, input); } return input; }); @@ -212,6 +220,19 @@ std::vector insert_common_args(module& m, return inputs; } +std::vector insert_common_args(module& m, + instruction_ref ins, + std::vector inputs, + common_options options) +{ return insert_common_args_impl(m, ins, {}, std::move(inputs), options); } + +std::vector insert_common_args(module& m, + instruction_ref ins, + const std::string& debug_symbol, + std::vector inputs, + common_options options) +{ return insert_common_args_impl(m, ins, debug_symbol, std::move(inputs), options); } + std::vector add_common_args(module& m, std::vector inputs, common_options options) { @@ -227,6 +248,20 @@ instruction_ref insert_common_op(module& m, return m.insert_instruction(ins, op, insert_common_args(m, ins, std::move(inputs), options)); } +instruction_ref insert_common_op(module& m, + instruction_ref ins, + const operation& op, + const std::string& debug_symbol, + std::vector inputs, + common_options options) +{ + return m.insert_instruction( + ins, + op, + debug_symbol, + insert_common_args(m, ins, debug_symbol, std::move(inputs), options)); +} + instruction_ref add_common_op(module& m, const operation& op, std::vector inputs, @@ -235,6 +270,13 @@ instruction_ref add_common_op(module& m, return insert_common_op(m, m.end(), op, std::move(inputs), options); } +instruction_ref add_common_op(module& m, + const operation& op, + const std::string& debug_symbol, + std::vector inputs, + common_options options) +{ return insert_common_op(m, m.end(), op, debug_symbol, std::move(inputs), options); } + shape make_bcast_shape(const shape& input_shape, const std::vector& bcast_lens) { assert(not input_shape.dynamic()); diff --git a/src/include/migraphx/common.hpp b/src/include/migraphx/common.hpp index 63a2dc43541..9ac4dad1b9c 100644 --- a/src/include/migraphx/common.hpp +++ b/src/include/migraphx/common.hpp @@ -123,6 +123,12 @@ MIGRAPHX_EXPORT std::vector insert_common_args(module& m, std::vector inputs, common_options options = {}); +MIGRAPHX_EXPORT std::vector insert_common_args(module& m, + instruction_ref ins, + const std::string& debug_symbol, + std::vector inputs, + common_options options = {}); + MIGRAPHX_EXPORT std::vector add_common_args(module& m, std::vector inputs, common_options options = {}); @@ -134,6 +140,14 @@ instruction_ref insert_common_op(module& m, std::vector inputs, common_options options = {}); +MIGRAPHX_EXPORT +instruction_ref insert_common_op(module& m, + instruction_ref ins, + const operation& op, + const std::string& debug_symbol, + std::vector inputs, + common_options options = {}); + /** * @brief Wrapper for insert_common_args() which inserts operation at the end of the module. */ @@ -142,6 +156,12 @@ instruction_ref add_common_op(module& m, const operation& op, std::vector inputs, common_options options = {}); +MIGRAPHX_EXPORT +instruction_ref add_common_op(module& m, + const operation& op, + const std::string& debug_symbol, + std::vector inputs, + common_options options = {}); /** * Calculates the broadcasted shape with the given input_shape and broadcasted dimensions. diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 470772c3fae..d4ac323f90e 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -93,11 +93,25 @@ struct MIGRAPHX_EXPORT module std::vector args, std::vector module_args); + template {}...)> + instruction_ref add_instruction(operation op, const std::string& debug_symbol, Ts... args) + { return add_instruction(op, debug_symbol, {args...}); } + + instruction_ref add_instruction(const operation& op, + const std::string& debug_symbol, + std::vector args); + + instruction_ref add_instruction(const operation& op, + const std::string& debug_symbol, + std::vector args, + std::vector module_args); + template {}...)> instruction_ref insert_instruction(instruction_ref ins, operation op, Ts... args) { return insert_instruction(ins, op, {args...}); } + instruction_ref insert_instruction(instruction_ref ins, const operation& op, std::vector args); @@ -106,6 +120,24 @@ struct MIGRAPHX_EXPORT module std::vector args, std::vector module_args); + template {}...)> + instruction_ref insert_instruction(instruction_ref ins, + operation op, + const std::string& debug_symbol, + Ts... args) + { return insert_instruction(ins, op, debug_symbol, {args...}); } + + instruction_ref insert_instruction(instruction_ref ins, + const operation& op, + const std::string& debug_symbol, + std::vector args); + + instruction_ref insert_instruction(instruction_ref ins, + const operation& op, + const std::string& debug_symbol, + std::vector args, + std::vector module_args); + template {}...)> instruction_ref replace_instruction(instruction_ref ins, operation op, Ts... args) { @@ -173,6 +205,8 @@ struct MIGRAPHX_EXPORT module instruction_ref add_literal(literal l); + instruction_ref add_literal(literal l, const std::string& debug_symbol); + instruction_ref add_outline(const shape& s); instruction_ref add_parameter(std::string name, shape s); @@ -183,6 +217,8 @@ struct MIGRAPHX_EXPORT module instruction_ref insert_literal(instruction_ref ins, literal l); + instruction_ref insert_literal(instruction_ref ins, literal l, const std::string& debug_symbol); + instruction_ref insert_parameter(instruction_ref ins, std::string name, shape s); std::vector get_parameter_names() const; diff --git a/src/instruction.cpp b/src/instruction.cpp index 3e6640820a5..0364f83e72a 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -194,7 +194,13 @@ const std::vector& instruction::outputs() const { return output const std::set& instruction::get_debug_symbols() const { return debug_symbols; } -void instruction::add_debug_symbol(const std::string& symbol) { debug_symbols.insert(symbol); } +void instruction::add_debug_symbol(const std::string& symbol) +{ + if(not symbol.empty()) + { + debug_symbols.insert(symbol); + } +} bool operator==(const instruction& x, const instruction& y) { diff --git a/src/module.cpp b/src/module.cpp index 6e6270d4dbb..610ae0f9015 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -292,6 +292,26 @@ instruction_ref module::add_instruction(const operation& op, std::vectorinsert_end(), op, std::move(args)); } + +instruction_ref module::add_instruction(const operation& op, + const std::string& debug_symbol, + std::vector args) +{ return insert_instruction(this->insert_end(), op, debug_symbol, std::move(args)); } + +instruction_ref module::add_instruction(const operation& op, + std::vector args, + std::vector module_args) +{ return insert_instruction(this->insert_end(), op, std::move(args), std::move(module_args)); } + +instruction_ref module::add_instruction(const operation& op, + const std::string& debug_symbol, + std::vector args, + std::vector module_args) +{ + return insert_instruction( + this->insert_end(), op, debug_symbol, std::move(args), std::move(module_args)); +} + instruction_ref module::insert_instruction(instruction_ref ins, const operation& op, std::vector args) @@ -305,11 +325,14 @@ instruction_ref module::insert_instruction(instruction_ref ins, return result; } -instruction_ref module::add_instruction(const operation& op, - std::vector args, - std::vector module_args) +instruction_ref module::insert_instruction(instruction_ref ins, + const operation& op, + const std::string& debug_symbol, + std::vector args) { - return insert_instruction(this->insert_end(), op, std::move(args), std::move(module_args)); + auto new_ins = insert_instruction(ins, op, args); + new_ins->add_debug_symbol(debug_symbol); + return ins; } instruction_ref module::insert_instruction(instruction_ref ins, @@ -326,6 +349,17 @@ instruction_ref module::insert_instruction(instruction_ref ins, return result; } +instruction_ref module::insert_instruction(instruction_ref ins, + const operation& op, + const std::string& debug_symbol, + std::vector args, + std::vector module_args) +{ + auto new_ins = insert_instruction(ins, op, args, module_args); + new_ins->add_debug_symbol(debug_symbol); + return ins; +} + instruction_ref module::replace_instruction(instruction_ref ins, const operation& op, std::vector args) MIGRAPHX_TIDY_CONST @@ -560,6 +594,13 @@ module::insert_instructions(instruction_ref ins, instruction_ref module::add_literal(literal l) { return insert_literal(begin(), std::move(l)); } +instruction_ref module::add_literal(literal l, const std::string& debug_symbol) +{ + auto lit_ins = add_literal(l); + lit_ins->add_debug_symbol(debug_symbol); + return lit_ins; +} + instruction_ref module::add_outline(const shape& s) { impl->push_front({builtin::outline{s}, s, {}}); @@ -587,6 +628,14 @@ instruction_ref module::insert_literal(instruction_ref ins, literal l) return std::prev(ins); } +instruction_ref +module::insert_literal(instruction_ref ins, literal l, const std::string& debug_symbol) +{ + auto lit_ins = insert_literal(ins, l); + lit_ins->add_debug_symbol(debug_symbol); + return lit_ins; +} + instruction_ref module::insert_parameter(instruction_ref ins, std::string name, shape s) { assert(get_parameter_shape(name) == shape{}); diff --git a/src/onnx/include/migraphx/onnx/onnx_parser.hpp b/src/onnx/include/migraphx/onnx/onnx_parser.hpp index d0c550283b7..bfd9c9803a0 100644 --- a/src/onnx/include/migraphx/onnx/onnx_parser.hpp +++ b/src/onnx/include/migraphx/onnx/onnx_parser.hpp @@ -51,8 +51,11 @@ struct onnx_parser { attribute_map attributes{}; std::size_t num_outputs = 1; + // unique identifier for MIGX, not given ONNX node name std::string name = ""; module* mod = nullptr; + std::string onnx_node_name{}; + std::string onnx_op_type{}; instruction_ref make_contiguous(instruction_ref ins) const; instruction_ref add_bias(const std::vector& args, diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index 251c1468ea4..55774603200 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -172,48 +172,22 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s * operation. * */ -instruction_ref onnx_parser::nme ode_info::add_common_op(const std::string& op_name, - std::vector inputs) const -{ - auto ins = migraphx::add_common_op(*mod, make_op(op_name), std::move(inputs)); - auto debug_symbol = get_debug_symbol(); - if(not debug_symbol.empty()) - { - ins->add_debug_symbol(debug_symbol); - } - return ins; -} +instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name, + std::vector inputs) const +{ return migraphx::add_common_op(*mod, make_op(op_name), onnx_node_name, std::move(inputs)); } instruction_ref onnx_parser::node_info::add_instruction(const operation& op, const std::vector& args) const -{ - auto ins = mod->add_instruction(op, args); - auto debug_symbol = get_debug_symbol(); - if(not debug_symbol.empty()) - { - ins->add_debug_symbol(debug_symbol); - } - return ins; -} +{ return mod->add_instruction(op, onnx_node_name, args); } instruction_ref onnx_parser::node_info::add_instruction(const operation& op, const std::vector& args, const std::vector& mods) const -{ - auto ins = mod->add_instruction(op, args, mods); - auto debug_symbol = get_debug_symbol(); - if(not debug_symbol.empty()) - { - ins->add_debug_symbol(debug_symbol); - } - return ins; -} +{ return mod->add_instruction(op, onnx_node_name, args, mods); } instruction_ref onnx_parser::node_info::add_literal(literal l) const -{ - return mod->add_literal(std::move(l)); -} +{ return mod->add_literal(std::move(l), onnx_node_name); } onnx_parser::onnx_parser() { @@ -618,10 +592,9 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini else { std::string node_name = node.op_type() + "_" + std::to_string(mod->size()); - result = ops[node.op_type()]( - *this, - {get_attributes(node), output_num, node_name, mod, node.name(), node.op_type()}, - args); + node_info ninfo{ + get_attributes(node), output_num, node_name, mod, node.name(), node.op_type()}; + result = ops[node.op_type()](*this, ninfo, args); } output_num = std::min(output_num, result.size()); From c2381cbafdb2eabab7f908cadd40f0f29cfef3b6 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 18 Feb 2026 16:57:41 -0600 Subject: [PATCH 004/107] Change to use a set of debug_symbols when inserting ins --- src/common.cpp | 26 +++++++-------- src/include/migraphx/common.hpp | 7 ++-- src/include/migraphx/instruction.hpp | 2 +- src/include/migraphx/module.hpp | 21 ++++++------ src/instruction.cpp | 9 ++++-- src/module.cpp | 48 ++++++++++++++-------------- src/onnx/onnx_parser.cpp | 8 ++--- 7 files changed, 63 insertions(+), 58 deletions(-) diff --git a/src/common.cpp b/src/common.cpp index 72abcd3d173..fbc7bb01a90 100644 --- a/src/common.cpp +++ b/src/common.cpp @@ -153,7 +153,7 @@ shape common_shape(const std::vector& shapes) std::vector insert_common_args_impl(module& m, instruction_ref ins, - const std::string& debug_symbol, + const std::set& debug_symbols, std::vector inputs, common_options options) { @@ -170,7 +170,7 @@ std::vector insert_common_args_impl(module& m, inputs[0] = m.insert_instruction( ins, make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), - debug_symbol, + debug_symbols, inputs); std::transform(inputs.begin() + 1, inputs.end(), inputs.begin() + 1, [&](auto input) { // uses previous input to avoid recalculating the common shape from the @@ -179,7 +179,7 @@ std::vector insert_common_args_impl(module& m, return m.insert_instruction( ins, make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), - debug_symbol, + debug_symbols, input, inputs[0]); }); @@ -191,7 +191,7 @@ std::vector insert_common_args_impl(module& m, if(input->get_shape().type() != c_type) { input = m.insert_instruction( - ins, make_op("convert", {{"target_type", c_type}}), debug_symbol, input); + ins, make_op("convert", {{"target_type", c_type}}), debug_symbols, input); } return input; }); @@ -206,13 +206,13 @@ std::vector insert_common_args_impl(module& m, input = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", common.lens()}}), - debug_symbol, + debug_symbols, input); } if(options.common_type and input->get_shape().type() != common.type()) { input = m.insert_instruction( - ins, make_op("convert", {{"target_type", common.type()}}), debug_symbol, input); + ins, make_op("convert", {{"target_type", common.type()}}), debug_symbols, input); } return input; }); @@ -228,10 +228,10 @@ std::vector insert_common_args(module& m, std::vector insert_common_args(module& m, instruction_ref ins, - const std::string& debug_symbol, + const std::set& debug_symbols, std::vector inputs, common_options options) -{ return insert_common_args_impl(m, ins, debug_symbol, std::move(inputs), options); } +{ return insert_common_args_impl(m, ins, debug_symbols, std::move(inputs), options); } std::vector add_common_args(module& m, std::vector inputs, common_options options) @@ -251,15 +251,15 @@ instruction_ref insert_common_op(module& m, instruction_ref insert_common_op(module& m, instruction_ref ins, const operation& op, - const std::string& debug_symbol, + const std::set& debug_symbols, std::vector inputs, common_options options) { return m.insert_instruction( ins, op, - debug_symbol, - insert_common_args(m, ins, debug_symbol, std::move(inputs), options)); + debug_symbols, + insert_common_args(m, ins, debug_symbols, std::move(inputs), options)); } instruction_ref add_common_op(module& m, @@ -272,10 +272,10 @@ instruction_ref add_common_op(module& m, instruction_ref add_common_op(module& m, const operation& op, - const std::string& debug_symbol, + const std::set& debug_symbols, std::vector inputs, common_options options) -{ return insert_common_op(m, m.end(), op, debug_symbol, std::move(inputs), options); } +{ return insert_common_op(m, m.end(), op, debug_symbols, std::move(inputs), options); } shape make_bcast_shape(const shape& input_shape, const std::vector& bcast_lens) { diff --git a/src/include/migraphx/common.hpp b/src/include/migraphx/common.hpp index 9ac4dad1b9c..88f139f1bc3 100644 --- a/src/include/migraphx/common.hpp +++ b/src/include/migraphx/common.hpp @@ -27,6 +27,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -125,7 +126,7 @@ MIGRAPHX_EXPORT std::vector insert_common_args(module& m, MIGRAPHX_EXPORT std::vector insert_common_args(module& m, instruction_ref ins, - const std::string& debug_symbol, + const std::set& debug_symbols, std::vector inputs, common_options options = {}); @@ -144,7 +145,7 @@ MIGRAPHX_EXPORT instruction_ref insert_common_op(module& m, instruction_ref ins, const operation& op, - const std::string& debug_symbol, + const std::set& debug_symbols, std::vector inputs, common_options options = {}); @@ -159,7 +160,7 @@ instruction_ref add_common_op(module& m, MIGRAPHX_EXPORT instruction_ref add_common_op(module& m, const operation& op, - const std::string& debug_symbol, + const std::set& debug_symbols, std::vector inputs, common_options options = {}); diff --git a/src/include/migraphx/instruction.hpp b/src/include/migraphx/instruction.hpp index f55dacf0e2e..1050775a660 100644 --- a/src/include/migraphx/instruction.hpp +++ b/src/include/migraphx/instruction.hpp @@ -98,7 +98,7 @@ struct MIGRAPHX_EXPORT instruction const std::set& get_debug_symbols() const; - void add_debug_symbol(const std::string& symbol); + void add_debug_symbols(const std::set& symbols); MIGRAPHX_EXPORT friend bool operator==(const instruction& x, const instruction& y); diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index d4ac323f90e..448e3c90505 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -25,6 +25,7 @@ #define MIGRAPHX_GUARD_MIGRAPHLIB_MODULE_HPP #include +#include #include #include #include @@ -94,15 +95,15 @@ struct MIGRAPHX_EXPORT module std::vector module_args); template {}...)> - instruction_ref add_instruction(operation op, const std::string& debug_symbol, Ts... args) - { return add_instruction(op, debug_symbol, {args...}); } + instruction_ref add_instruction(operation op, const std::set& debug_symbols, Ts... args) + { return add_instruction(op, debug_symbols, {args...}); } instruction_ref add_instruction(const operation& op, - const std::string& debug_symbol, + const std::set& debug_symbols, std::vector args); instruction_ref add_instruction(const operation& op, - const std::string& debug_symbol, + const std::set& debug_symbols, std::vector args, std::vector module_args); @@ -123,18 +124,18 @@ struct MIGRAPHX_EXPORT module template {}...)> instruction_ref insert_instruction(instruction_ref ins, operation op, - const std::string& debug_symbol, + const std::set& debug_symbols, Ts... args) - { return insert_instruction(ins, op, debug_symbol, {args...}); } + { return insert_instruction(ins, op, debug_symbols, {args...}); } instruction_ref insert_instruction(instruction_ref ins, const operation& op, - const std::string& debug_symbol, + const std::set& debug_symbols, std::vector args); instruction_ref insert_instruction(instruction_ref ins, const operation& op, - const std::string& debug_symbol, + const std::set& debug_symbols, std::vector args, std::vector module_args); @@ -205,7 +206,7 @@ struct MIGRAPHX_EXPORT module instruction_ref add_literal(literal l); - instruction_ref add_literal(literal l, const std::string& debug_symbol); + instruction_ref add_literal(literal l, const std::set& debug_symbols); instruction_ref add_outline(const shape& s); @@ -217,7 +218,7 @@ struct MIGRAPHX_EXPORT module instruction_ref insert_literal(instruction_ref ins, literal l); - instruction_ref insert_literal(instruction_ref ins, literal l, const std::string& debug_symbol); + instruction_ref insert_literal(instruction_ref ins, literal l, const std::set& debug_symbols); instruction_ref insert_parameter(instruction_ref ins, std::string name, shape s); diff --git a/src/instruction.cpp b/src/instruction.cpp index 0364f83e72a..9d3d28f03e0 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -194,11 +194,14 @@ const std::vector& instruction::outputs() const { return output const std::set& instruction::get_debug_symbols() const { return debug_symbols; } -void instruction::add_debug_symbol(const std::string& symbol) +void instruction::add_debug_symbols(const std::set& symbols) { - if(not symbol.empty()) + for(const auto& symbol : symbols) { - debug_symbols.insert(symbol); + if(not symbol.empty()) + { + debug_symbols.insert(symbol); + } } } diff --git a/src/module.cpp b/src/module.cpp index 610ae0f9015..bdb3c01b2eb 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -293,23 +293,23 @@ instruction_ref module::add_instruction(const operation& op, std::vectorinsert_end(), op, std::move(args)); } -instruction_ref module::add_instruction(const operation& op, - const std::string& debug_symbol, - std::vector args) -{ return insert_instruction(this->insert_end(), op, debug_symbol, std::move(args)); } - instruction_ref module::add_instruction(const operation& op, std::vector args, std::vector module_args) { return insert_instruction(this->insert_end(), op, std::move(args), std::move(module_args)); } instruction_ref module::add_instruction(const operation& op, - const std::string& debug_symbol, + const std::set& debug_symbols, + std::vector args) +{ return insert_instruction(this->insert_end(), op, debug_symbols, std::move(args)); } + +instruction_ref module::add_instruction(const operation& op, + const std::set& debug_symbols, std::vector args, std::vector module_args) { return insert_instruction( - this->insert_end(), op, debug_symbol, std::move(args), std::move(module_args)); + this->insert_end(), op, debug_symbols, std::move(args), std::move(module_args)); } instruction_ref module::insert_instruction(instruction_ref ins, @@ -325,16 +325,6 @@ instruction_ref module::insert_instruction(instruction_ref ins, return result; } -instruction_ref module::insert_instruction(instruction_ref ins, - const operation& op, - const std::string& debug_symbol, - std::vector args) -{ - auto new_ins = insert_instruction(ins, op, args); - new_ins->add_debug_symbol(debug_symbol); - return ins; -} - instruction_ref module::insert_instruction(instruction_ref ins, const operation& op, std::vector args, @@ -351,13 +341,23 @@ instruction_ref module::insert_instruction(instruction_ref ins, instruction_ref module::insert_instruction(instruction_ref ins, const operation& op, - const std::string& debug_symbol, + const std::set& debug_symbols, + std::vector args) +{ + auto new_ins = insert_instruction(ins, op, args); + new_ins->add_debug_symbols(debug_symbols); + return new_ins; +} + +instruction_ref module::insert_instruction(instruction_ref ins, + const operation& op, + const std::set& debug_symbols, std::vector args, std::vector module_args) { auto new_ins = insert_instruction(ins, op, args, module_args); - new_ins->add_debug_symbol(debug_symbol); - return ins; + new_ins->add_debug_symbols(debug_symbols); + return new_ins; } instruction_ref module::replace_instruction(instruction_ref ins, @@ -594,10 +594,10 @@ module::insert_instructions(instruction_ref ins, instruction_ref module::add_literal(literal l) { return insert_literal(begin(), std::move(l)); } -instruction_ref module::add_literal(literal l, const std::string& debug_symbol) +instruction_ref module::add_literal(literal l, const std::set& debug_symbols) { auto lit_ins = add_literal(l); - lit_ins->add_debug_symbol(debug_symbol); + lit_ins->add_debug_symbols(debug_symbols); return lit_ins; } @@ -629,10 +629,10 @@ instruction_ref module::insert_literal(instruction_ref ins, literal l) } instruction_ref -module::insert_literal(instruction_ref ins, literal l, const std::string& debug_symbol) +module::insert_literal(instruction_ref ins, literal l, const std::set& debug_symbols) { auto lit_ins = insert_literal(ins, l); - lit_ins->add_debug_symbol(debug_symbol); + lit_ins->add_debug_symbols(debug_symbols); return lit_ins; } diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index 55774603200..7b56e9c352e 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -174,20 +174,20 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s */ instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name, std::vector inputs) const -{ return migraphx::add_common_op(*mod, make_op(op_name), onnx_node_name, std::move(inputs)); } +{ return migraphx::add_common_op(*mod, make_op(op_name), {onnx_node_name}, std::move(inputs)); } instruction_ref onnx_parser::node_info::add_instruction(const operation& op, const std::vector& args) const -{ return mod->add_instruction(op, onnx_node_name, args); } +{ return mod->add_instruction(op, {onnx_node_name}, args); } instruction_ref onnx_parser::node_info::add_instruction(const operation& op, const std::vector& args, const std::vector& mods) const -{ return mod->add_instruction(op, onnx_node_name, args, mods); } +{ return mod->add_instruction(op, {onnx_node_name}, args, mods); } instruction_ref onnx_parser::node_info::add_literal(literal l) const -{ return mod->add_literal(std::move(l), onnx_node_name); } +{ return mod->add_literal(std::move(l), {onnx_node_name}); } onnx_parser::onnx_parser() { From b00e7e2790494ea62892fd16555c18c5d48964e0 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 18 Feb 2026 22:01:07 -0600 Subject: [PATCH 005/107] Incomplete first parse design --- .gitignore | 7 --- src/common.cpp | 6 +- src/include/migraphx/common.hpp | 11 ++-- src/include/migraphx/module.hpp | 6 +- src/onnx/parse_batchnorm.cpp | 2 + src/onnx/parse_convolution.cpp | 3 + src/op/builder/batchnorm.cpp | 59 ++++++++++++------- src/op/builder/clip.cpp | 4 ++ src/op/builder/convolution.cpp | 7 ++- src/op/builder/gelu.cpp | 8 +++ .../include/migraphx/op/builder/insert.hpp | 8 +++ .../migraphx/op/builder/op_builder.hpp | 15 +++++ 12 files changed, 95 insertions(+), 41 deletions(-) diff --git a/.gitignore b/.gitignore index 20dead71ed3..fa78b2ef667 100644 --- a/.gitignore +++ b/.gitignore @@ -54,8 +54,6 @@ _toc.yml #==============================================================================# # Directories to ignore (do not add trailing '/'s, they skip symlinks). #==============================================================================# -# Nested build directory -/build* # Downloaded models test/onnx/models @@ -76,11 +74,6 @@ docs/_doxygen docs/html /_readthedocs -# JetBrains config directories (ignoring symlinks) -.idea/ -cmake-build*/ -build*/ - # Recommended location to install rbuild dependencies from README.md depend*/ diff --git a/src/common.cpp b/src/common.cpp index fbc7bb01a90..6413d5f8e95 100644 --- a/src/common.cpp +++ b/src/common.cpp @@ -211,8 +211,10 @@ std::vector insert_common_args_impl(module& m, } if(options.common_type and input->get_shape().type() != common.type()) { - input = m.insert_instruction( - ins, make_op("convert", {{"target_type", common.type()}}), debug_symbols, input); + input = m.insert_instruction(ins, + make_op("convert", {{"target_type", common.type()}}), + debug_symbols, + input); } return input; }); diff --git a/src/include/migraphx/common.hpp b/src/include/migraphx/common.hpp index 88f139f1bc3..ebb1e5d55d1 100644 --- a/src/include/migraphx/common.hpp +++ b/src/include/migraphx/common.hpp @@ -124,11 +124,12 @@ MIGRAPHX_EXPORT std::vector insert_common_args(module& m, std::vector inputs, common_options options = {}); -MIGRAPHX_EXPORT std::vector insert_common_args(module& m, - instruction_ref ins, - const std::set& debug_symbols, - std::vector inputs, - common_options options = {}); +MIGRAPHX_EXPORT std::vector +insert_common_args(module& m, + instruction_ref ins, + const std::set& debug_symbols, + std::vector inputs, + common_options options = {}); MIGRAPHX_EXPORT std::vector diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 448e3c90505..596579c473d 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -95,7 +95,8 @@ struct MIGRAPHX_EXPORT module std::vector module_args); template {}...)> - instruction_ref add_instruction(operation op, const std::set& debug_symbols, Ts... args) + instruction_ref + add_instruction(operation op, const std::set& debug_symbols, Ts... args) { return add_instruction(op, debug_symbols, {args...}); } instruction_ref add_instruction(const operation& op, @@ -218,7 +219,8 @@ struct MIGRAPHX_EXPORT module instruction_ref insert_literal(instruction_ref ins, literal l); - instruction_ref insert_literal(instruction_ref ins, literal l, const std::set& debug_symbols); + instruction_ref + insert_literal(instruction_ref ins, literal l, const std::set& debug_symbols); instruction_ref insert_parameter(instruction_ref ins, std::string name, shape s); diff --git a/src/onnx/parse_batchnorm.cpp b/src/onnx/parse_batchnorm.cpp index ceef3c850bd..3c405f07bcc 100644 --- a/src/onnx/parse_batchnorm.cpp +++ b/src/onnx/parse_batchnorm.cpp @@ -39,6 +39,8 @@ struct parse_batchnorm : op_parser const std::vector& args) const { value options = {}; + options.insert({"debug_symbols", std::set({info.onnx_node_name})}); + if(contains(info.attributes, "epsilon")) { const float epsilon = parser.parse_value(info.attributes.at("epsilon")).at(); diff --git a/src/onnx/parse_convolution.cpp b/src/onnx/parse_convolution.cpp index 1492607b098..e76be18a404 100644 --- a/src/onnx/parse_convolution.cpp +++ b/src/onnx/parse_convolution.cpp @@ -64,6 +64,9 @@ struct parse_convolution : op_parser check_padding_mode(info, opd.onnx_name); value options = {}; + std::set debug_s{info.onnx_node_name}; + options.insert({"debug_symbols", debug_s}); + if(contains(info.attributes, "strides")) { const auto& attr = info.attributes["strides"].ints(); diff --git a/src/op/builder/batchnorm.cpp b/src/op/builder/batchnorm.cpp index f1fa3479510..5f230c80657 100644 --- a/src/op/builder/batchnorm.cpp +++ b/src/op/builder/batchnorm.cpp @@ -59,34 +59,49 @@ struct batchnorm : op_builder auto x_rank = x_lens.size(); if(x_rank == 1 or x_rank == 2) { - auto eps = m.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}}); - auto x_sub_mean = insert_common_op(m, ins, "sub", args[0], args[3]); - auto var_eps = insert_common_op(m, ins, "add", args[4], eps); - auto rsqrt = m.insert_instruction(ins, make_op("rsqrt"), var_eps); - auto mul0 = insert_common_op(m, ins, "mul", args[1], rsqrt); - auto r0 = insert_common_op(m, ins, "mul", x_sub_mean, mul0); - return {insert_common_op(m, ins, "add", r0, args[2])}; + auto eps = + m.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}}, debug_symbols); + auto x_sub_mean = insert_common_op(m, ins, "sub", debug_symbols, args[0], args[3]); + auto var_eps = insert_common_op(m, ins, "add", debug_symbols, args[4], eps); + auto rsqrt = m.insert_instruction(ins, make_op("rsqrt"), debug_symbols, var_eps); + auto mul0 = insert_common_op(m, ins, "mul", debug_symbols, args[1], rsqrt); + auto r0 = insert_common_op(m, ins, "mul", debug_symbols, x_sub_mean, mul0); + return {insert_common_op(m, ins, "add", debug_symbols, r0, args[2])}; } else if(x_rank > 2) { // unsqueeze tensors of shape (C) to broadcast correctly std::vector unsqueeze_axes(x_lens.size() - 2); std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 1); - auto eps = m.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}}); - auto scale_unsqueeze = m.insert_instruction( - ins, migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[1]); - auto bias_unsqueeze = m.insert_instruction( - ins, migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[2]); - auto mean_unsqueeze = m.insert_instruction( - ins, migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[3]); - auto var_unsqueeze = m.insert_instruction( - ins, migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[4]); - auto x_sub_mean = insert_common_op(m, ins, "sub", args[0], mean_unsqueeze); - auto var_eps = insert_common_op(m, ins, "add", var_unsqueeze, eps); - auto rsqrt = m.insert_instruction(ins, make_op("rsqrt"), var_eps); - auto mul0 = insert_common_op(m, ins, "mul", scale_unsqueeze, rsqrt); - auto r0 = insert_common_op(m, ins, "mul", x_sub_mean, mul0); - return {insert_common_op(m, ins, "add", r0, bias_unsqueeze)}; + auto eps = + m.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}}, debug_symbols); + auto scale_unsqueeze = + m.insert_instruction(ins, + migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), + debug_symbols, + args[1]); + auto bias_unsqueeze = + m.insert_instruction(ins, + migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), + debug_symbols, + args[2]); + auto mean_unsqueeze = + m.insert_instruction(ins, + migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), + debug_symbols, + args[3]); + auto var_unsqueeze = + m.insert_instruction(ins, + migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), + debug_symbols, + args[4]); + auto x_sub_mean = + insert_common_op(m, ins, "sub", debug_symbols, args[0], mean_unsqueeze); + auto var_eps = insert_common_op(m, ins, "add", debug_symbols, var_unsqueeze, eps); + auto rsqrt = m.insert_instruction(ins, make_op("rsqrt"), debug_symbols, var_eps); + auto mul0 = insert_common_op(m, ins, "mul", debug_symbols, scale_unsqueeze, rsqrt); + auto r0 = insert_common_op(m, ins, "mul", debug_symbols, x_sub_mean, mul0); + return {insert_common_op(m, ins, "add", debug_symbols, r0, bias_unsqueeze)}; } else { diff --git a/src/op/builder/clip.cpp b/src/op/builder/clip.cpp index 7d5168e18a1..9d986423a9c 100644 --- a/src/op/builder/clip.cpp +++ b/src/op/builder/clip.cpp @@ -35,6 +35,10 @@ namespace builder { struct clip : op_builder { + template + static auto reflect(Self&, F) + { return pack(); } + std::vector insert(module& m, instruction_ref ins, const std::vector& args) const { diff --git a/src/op/builder/convolution.cpp b/src/op/builder/convolution.cpp index 629d745c56e..88acbf9cbb4 100644 --- a/src/op/builder/convolution.cpp +++ b/src/op/builder/convolution.cpp @@ -147,8 +147,9 @@ struct convolution : convolution_base auto kdims = in_lens.size() - 2; validate_or_init_attributes(kdims, x, weights); - - auto conv = m.insert_instruction(ins, make_conv_op("convolution"), x, weights); + std::cout << "conv debug symbols empty?: " << debug_symbols.empty() << std::endl; + auto conv = + m.insert_instruction(ins, make_conv_op("convolution"), debug_symbols, x, weights); return {add_bias(m, ins, args, conv, 1)}; } @@ -201,7 +202,7 @@ struct quant_convolution : convolution_base auto x_zp = get_zero_point(m, x, 2, args); auto w_zp = get_zero_point(m, weights, 3, args); handle_quant_inputs(m, ins, x, weights, x_zp, w_zp); - auto conv = m.insert_instruction(ins, op, x, weights); + auto conv = m.insert_instruction(ins, op, debug_symbols, x, weights); return {handle_quant_bias(m, ins, op, conv, x, weights, x_zp, w_zp)}; } diff --git a/src/op/builder/gelu.cpp b/src/op/builder/gelu.cpp index f360bae016f..60258b3a905 100644 --- a/src/op/builder/gelu.cpp +++ b/src/op/builder/gelu.cpp @@ -57,6 +57,10 @@ struct gelu_quick : op_builder struct gelu_erf : op_builder { + template + static auto reflect(Self&, F) + { return pack(); } + std::vector insert(module& m, instruction_ref ins, const std::vector& args) const { @@ -124,6 +128,10 @@ struct gelu_tanh : op_builder struct gelu_split : op_builder { + template + static auto reflect(Self&, F) + { return pack(); } + std::vector insert(module& m, instruction_ref ins, const std::vector& args) const { diff --git a/src/op/builder/include/migraphx/op/builder/insert.hpp b/src/op/builder/include/migraphx/op/builder/insert.hpp index ff64a05580e..d9e987cf476 100644 --- a/src/op/builder/include/migraphx/op/builder/insert.hpp +++ b/src/op/builder/include/migraphx/op/builder/insert.hpp @@ -67,6 +67,14 @@ insert_common_op(module& m, instruction_ref ins, const std::string& op_name, Ins return insert_common_op(m, ins, make_op(op_name), {args...}); } +template +instruction_ref insert_common_op(module& m, + instruction_ref ins, + const std::string& op_name, + const std::set& debug_symbols, + Ins... args) +{ return insert_common_op(m, ins, make_op(op_name), debug_symbols, {args...}); } + } // namespace builder } // namespace op } // namespace MIGRAPHX_INLINE_NS diff --git a/src/op/builder/include/migraphx/op/builder/op_builder.hpp b/src/op/builder/include/migraphx/op/builder/op_builder.hpp index 015d375687e..d3446afa8f8 100644 --- a/src/op/builder/include/migraphx/op/builder/op_builder.hpp +++ b/src/op/builder/include/migraphx/op/builder/op_builder.hpp @@ -53,6 +53,16 @@ struct op_builder_if MIGRAPHX_EXPORT void register_builder(const std::string& name, op_builder_if opb_if); +template +void apply_debug_symbols(T& x, const value& options) +{ + if(options.contains("debug_symbols")) + { + x.debug_symbols = + from_value>(options.at("debug_symbols").without_key()); + } +} + template auto invoke_builder(const std::string& /*name*/, module& m, @@ -62,6 +72,7 @@ auto invoke_builder(const std::string& /*name*/, const value& options) -> decltype(T{}.insert(m, ins, args, module_args)) { auto x = from_value(options); + apply_debug_symbols(x, options); return x.insert(m, ins, args, module_args); } @@ -76,6 +87,7 @@ auto invoke_builder(const std::string& /*name*/, if(not module_args.empty()) MIGRAPHX_THROW("Module args should be empty"); auto x = from_value(options); + apply_debug_symbols(x, options); return x.insert(m, ins, args); } @@ -88,6 +100,7 @@ auto invoke_builder(const std::string& name, const value& options) -> decltype(T{}.insert(name, m, ins, args, module_args)) { auto x = from_value(options); + apply_debug_symbols(x, options); return x.insert(name, m, ins, args, module_args); } @@ -102,6 +115,7 @@ auto invoke_builder(const std::string& name, if(not module_args.empty()) MIGRAPHX_THROW("Module args should be empty"); auto x = from_value(options); + apply_debug_symbols(x, options); return x.insert(name, m, ins, args); } @@ -140,6 +154,7 @@ struct register_builder_action template struct op_builder : auto_register { + std::set debug_symbols; static std::vector names() { static const std::string& name = get_type_name(); From 151743888485a3ddd78da80a5bb906bee424e963 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 18 Feb 2026 23:14:29 -0600 Subject: [PATCH 006/107] Version 2 using scoped debug symbols in module --- src/include/migraphx/module.hpp | 20 +++++-- src/module.cpp | 57 +++++++++++++----- src/onnx/onnx_parser.cpp | 17 ++++-- src/onnx/parse_batchnorm.cpp | 2 - src/onnx/parse_convolution.cpp | 3 - src/op/builder/batchnorm.cpp | 59 +++++++------------ src/op/builder/convolution.cpp | 7 +-- .../include/migraphx/op/builder/insert.hpp | 8 --- .../migraphx/op/builder/op_builder.hpp | 15 ----- 9 files changed, 94 insertions(+), 94 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 596579c473d..ca6e11866ad 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -53,6 +53,20 @@ using ins_dep_map = std::unordered_map symbols); + ~scoped_debug_symbols(); + scoped_debug_symbols(const scoped_debug_symbols&) = delete; + scoped_debug_symbols& operator=(const scoped_debug_symbols&) = delete; + scoped_debug_symbols(scoped_debug_symbols&& other) noexcept; + scoped_debug_symbols& operator=(scoped_debug_symbols&& other) noexcept; + + private: + module* mod; + std::shared_ptr> previous; +}; + /** * @brief Stores the instruction stream */ @@ -207,8 +221,6 @@ struct MIGRAPHX_EXPORT module instruction_ref add_literal(literal l); - instruction_ref add_literal(literal l, const std::set& debug_symbols); - instruction_ref add_outline(const shape& s); instruction_ref add_parameter(std::string name, shape s); @@ -219,9 +231,6 @@ struct MIGRAPHX_EXPORT module instruction_ref insert_literal(instruction_ref ins, literal l); - instruction_ref - insert_literal(instruction_ref ins, literal l, const std::set& debug_symbols); - instruction_ref insert_parameter(instruction_ref ins, std::string name, shape s); std::vector get_parameter_names() const; @@ -387,6 +396,7 @@ struct MIGRAPHX_EXPORT module friend bool operator!=(const module& x, const module& y) { return not(x == y); } friend struct program; + friend struct scoped_debug_symbols; private: void set_name(const std::string& name); diff --git a/src/module.cpp b/src/module.cpp index bdb3c01b2eb..ab6189836e7 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -61,6 +61,7 @@ struct module_impl uint32_t nparams = 0; bool bypass = false; // used for skipping compiler passes bit_signal<64> changed{}; + std::shared_ptr> active_debug_symbols; bool contains(instruction_ref ins) const { @@ -129,6 +130,38 @@ struct module_impl const operation& get_operation(instruction_ref ins) { return ins->get_operator(); } +scoped_debug_symbols::scoped_debug_symbols(module& m, std::set symbols) + : mod(&m), previous(m.impl->active_debug_symbols) +{ + mod->impl->active_debug_symbols = + std::make_shared>(std::move(symbols)); +} + +scoped_debug_symbols::~scoped_debug_symbols() +{ + if(mod != nullptr) + mod->impl->active_debug_symbols = previous; +} + +scoped_debug_symbols::scoped_debug_symbols(scoped_debug_symbols&& other) noexcept + : mod(other.mod), previous(std::move(other.previous)) +{ + other.mod = nullptr; +} + +scoped_debug_symbols& scoped_debug_symbols::operator=(scoped_debug_symbols&& other) noexcept +{ + if(this != &other) + { + if(mod != nullptr) + mod->impl->active_debug_symbols = previous; + mod = other.mod; + previous = std::move(other.previous); + other.mod = nullptr; + } + return *this; +} + module::module(const std::string& name) : impl(std::make_unique()) { impl->name = name; @@ -322,6 +355,8 @@ instruction_ref module::insert_instruction(instruction_ref ins, auto result = impl->insert(ins, {op, r, std::move(args)}); instruction::backreference(result); assert(result->valid(begin())); + if(impl->active_debug_symbols != nullptr) + result->add_debug_symbols(*impl->active_debug_symbols); return result; } @@ -336,6 +371,8 @@ instruction_ref module::insert_instruction(instruction_ref ins, auto result = impl->insert(ins, {op, out_shape, std::move(args), std::move(module_args)}); instruction::backreference(result); assert(result->valid(begin())); + if(impl->active_debug_symbols != nullptr) + result->add_debug_symbols(*impl->active_debug_symbols); return result; } @@ -594,13 +631,6 @@ module::insert_instructions(instruction_ref ins, instruction_ref module::add_literal(literal l) { return insert_literal(begin(), std::move(l)); } -instruction_ref module::add_literal(literal l, const std::set& debug_symbols) -{ - auto lit_ins = add_literal(l); - lit_ins->add_debug_symbols(debug_symbols); - return lit_ins; -} - instruction_ref module::add_outline(const shape& s) { impl->push_front({builtin::outline{s}, s, {}}); @@ -625,15 +655,10 @@ instruction_ref module::add_return(std::vector args) instruction_ref module::insert_literal(instruction_ref ins, literal l) { impl->emplace(ins, std::move(l)); - return std::prev(ins); -} - -instruction_ref -module::insert_literal(instruction_ref ins, literal l, const std::set& debug_symbols) -{ - auto lit_ins = insert_literal(ins, l); - lit_ins->add_debug_symbols(debug_symbols); - return lit_ins; + auto result = std::prev(ins); + if(impl->active_debug_symbols != nullptr) + result->add_debug_symbols(*impl->active_debug_symbols); + return result; } instruction_ref module::insert_parameter(instruction_ref ins, std::string name, shape s) diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index 7b56e9c352e..58adf3110c2 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -174,20 +174,28 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s */ instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name, std::vector inputs) const -{ return migraphx::add_common_op(*mod, make_op(op_name), {onnx_node_name}, std::move(inputs)); } +{ + return migraphx::add_common_op(*mod, make_op(op_name), std::move(inputs)); +} instruction_ref onnx_parser::node_info::add_instruction(const operation& op, const std::vector& args) const -{ return mod->add_instruction(op, {onnx_node_name}, args); } +{ + return mod->add_instruction(op, args); +} instruction_ref onnx_parser::node_info::add_instruction(const operation& op, const std::vector& args, const std::vector& mods) const -{ return mod->add_instruction(op, {onnx_node_name}, args, mods); } +{ + return mod->add_instruction(op, args, mods); +} instruction_ref onnx_parser::node_info::add_literal(literal l) const -{ return mod->add_literal(std::move(l), {onnx_node_name}); } +{ + return mod->add_literal(std::move(l)); +} onnx_parser::onnx_parser() { @@ -594,6 +602,7 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini std::string node_name = node.op_type() + "_" + std::to_string(mod->size()); node_info ninfo{ get_attributes(node), output_num, node_name, mod, node.name(), node.op_type()}; + scoped_debug_symbols guard(*mod, {node.name()}); result = ops[node.op_type()](*this, ninfo, args); } diff --git a/src/onnx/parse_batchnorm.cpp b/src/onnx/parse_batchnorm.cpp index 3c405f07bcc..ceef3c850bd 100644 --- a/src/onnx/parse_batchnorm.cpp +++ b/src/onnx/parse_batchnorm.cpp @@ -39,8 +39,6 @@ struct parse_batchnorm : op_parser const std::vector& args) const { value options = {}; - options.insert({"debug_symbols", std::set({info.onnx_node_name})}); - if(contains(info.attributes, "epsilon")) { const float epsilon = parser.parse_value(info.attributes.at("epsilon")).at(); diff --git a/src/onnx/parse_convolution.cpp b/src/onnx/parse_convolution.cpp index e76be18a404..1492607b098 100644 --- a/src/onnx/parse_convolution.cpp +++ b/src/onnx/parse_convolution.cpp @@ -64,9 +64,6 @@ struct parse_convolution : op_parser check_padding_mode(info, opd.onnx_name); value options = {}; - std::set debug_s{info.onnx_node_name}; - options.insert({"debug_symbols", debug_s}); - if(contains(info.attributes, "strides")) { const auto& attr = info.attributes["strides"].ints(); diff --git a/src/op/builder/batchnorm.cpp b/src/op/builder/batchnorm.cpp index 5f230c80657..f1fa3479510 100644 --- a/src/op/builder/batchnorm.cpp +++ b/src/op/builder/batchnorm.cpp @@ -59,49 +59,34 @@ struct batchnorm : op_builder auto x_rank = x_lens.size(); if(x_rank == 1 or x_rank == 2) { - auto eps = - m.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}}, debug_symbols); - auto x_sub_mean = insert_common_op(m, ins, "sub", debug_symbols, args[0], args[3]); - auto var_eps = insert_common_op(m, ins, "add", debug_symbols, args[4], eps); - auto rsqrt = m.insert_instruction(ins, make_op("rsqrt"), debug_symbols, var_eps); - auto mul0 = insert_common_op(m, ins, "mul", debug_symbols, args[1], rsqrt); - auto r0 = insert_common_op(m, ins, "mul", debug_symbols, x_sub_mean, mul0); - return {insert_common_op(m, ins, "add", debug_symbols, r0, args[2])}; + auto eps = m.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}}); + auto x_sub_mean = insert_common_op(m, ins, "sub", args[0], args[3]); + auto var_eps = insert_common_op(m, ins, "add", args[4], eps); + auto rsqrt = m.insert_instruction(ins, make_op("rsqrt"), var_eps); + auto mul0 = insert_common_op(m, ins, "mul", args[1], rsqrt); + auto r0 = insert_common_op(m, ins, "mul", x_sub_mean, mul0); + return {insert_common_op(m, ins, "add", r0, args[2])}; } else if(x_rank > 2) { // unsqueeze tensors of shape (C) to broadcast correctly std::vector unsqueeze_axes(x_lens.size() - 2); std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 1); - auto eps = - m.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}}, debug_symbols); - auto scale_unsqueeze = - m.insert_instruction(ins, - migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), - debug_symbols, - args[1]); - auto bias_unsqueeze = - m.insert_instruction(ins, - migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), - debug_symbols, - args[2]); - auto mean_unsqueeze = - m.insert_instruction(ins, - migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), - debug_symbols, - args[3]); - auto var_unsqueeze = - m.insert_instruction(ins, - migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), - debug_symbols, - args[4]); - auto x_sub_mean = - insert_common_op(m, ins, "sub", debug_symbols, args[0], mean_unsqueeze); - auto var_eps = insert_common_op(m, ins, "add", debug_symbols, var_unsqueeze, eps); - auto rsqrt = m.insert_instruction(ins, make_op("rsqrt"), debug_symbols, var_eps); - auto mul0 = insert_common_op(m, ins, "mul", debug_symbols, scale_unsqueeze, rsqrt); - auto r0 = insert_common_op(m, ins, "mul", debug_symbols, x_sub_mean, mul0); - return {insert_common_op(m, ins, "add", debug_symbols, r0, bias_unsqueeze)}; + auto eps = m.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}}); + auto scale_unsqueeze = m.insert_instruction( + ins, migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[1]); + auto bias_unsqueeze = m.insert_instruction( + ins, migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[2]); + auto mean_unsqueeze = m.insert_instruction( + ins, migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[3]); + auto var_unsqueeze = m.insert_instruction( + ins, migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[4]); + auto x_sub_mean = insert_common_op(m, ins, "sub", args[0], mean_unsqueeze); + auto var_eps = insert_common_op(m, ins, "add", var_unsqueeze, eps); + auto rsqrt = m.insert_instruction(ins, make_op("rsqrt"), var_eps); + auto mul0 = insert_common_op(m, ins, "mul", scale_unsqueeze, rsqrt); + auto r0 = insert_common_op(m, ins, "mul", x_sub_mean, mul0); + return {insert_common_op(m, ins, "add", r0, bias_unsqueeze)}; } else { diff --git a/src/op/builder/convolution.cpp b/src/op/builder/convolution.cpp index 88acbf9cbb4..629d745c56e 100644 --- a/src/op/builder/convolution.cpp +++ b/src/op/builder/convolution.cpp @@ -147,9 +147,8 @@ struct convolution : convolution_base auto kdims = in_lens.size() - 2; validate_or_init_attributes(kdims, x, weights); - std::cout << "conv debug symbols empty?: " << debug_symbols.empty() << std::endl; - auto conv = - m.insert_instruction(ins, make_conv_op("convolution"), debug_symbols, x, weights); + + auto conv = m.insert_instruction(ins, make_conv_op("convolution"), x, weights); return {add_bias(m, ins, args, conv, 1)}; } @@ -202,7 +201,7 @@ struct quant_convolution : convolution_base auto x_zp = get_zero_point(m, x, 2, args); auto w_zp = get_zero_point(m, weights, 3, args); handle_quant_inputs(m, ins, x, weights, x_zp, w_zp); - auto conv = m.insert_instruction(ins, op, debug_symbols, x, weights); + auto conv = m.insert_instruction(ins, op, x, weights); return {handle_quant_bias(m, ins, op, conv, x, weights, x_zp, w_zp)}; } diff --git a/src/op/builder/include/migraphx/op/builder/insert.hpp b/src/op/builder/include/migraphx/op/builder/insert.hpp index d9e987cf476..ff64a05580e 100644 --- a/src/op/builder/include/migraphx/op/builder/insert.hpp +++ b/src/op/builder/include/migraphx/op/builder/insert.hpp @@ -67,14 +67,6 @@ insert_common_op(module& m, instruction_ref ins, const std::string& op_name, Ins return insert_common_op(m, ins, make_op(op_name), {args...}); } -template -instruction_ref insert_common_op(module& m, - instruction_ref ins, - const std::string& op_name, - const std::set& debug_symbols, - Ins... args) -{ return insert_common_op(m, ins, make_op(op_name), debug_symbols, {args...}); } - } // namespace builder } // namespace op } // namespace MIGRAPHX_INLINE_NS diff --git a/src/op/builder/include/migraphx/op/builder/op_builder.hpp b/src/op/builder/include/migraphx/op/builder/op_builder.hpp index d3446afa8f8..015d375687e 100644 --- a/src/op/builder/include/migraphx/op/builder/op_builder.hpp +++ b/src/op/builder/include/migraphx/op/builder/op_builder.hpp @@ -53,16 +53,6 @@ struct op_builder_if MIGRAPHX_EXPORT void register_builder(const std::string& name, op_builder_if opb_if); -template -void apply_debug_symbols(T& x, const value& options) -{ - if(options.contains("debug_symbols")) - { - x.debug_symbols = - from_value>(options.at("debug_symbols").without_key()); - } -} - template auto invoke_builder(const std::string& /*name*/, module& m, @@ -72,7 +62,6 @@ auto invoke_builder(const std::string& /*name*/, const value& options) -> decltype(T{}.insert(m, ins, args, module_args)) { auto x = from_value(options); - apply_debug_symbols(x, options); return x.insert(m, ins, args, module_args); } @@ -87,7 +76,6 @@ auto invoke_builder(const std::string& /*name*/, if(not module_args.empty()) MIGRAPHX_THROW("Module args should be empty"); auto x = from_value(options); - apply_debug_symbols(x, options); return x.insert(m, ins, args); } @@ -100,7 +88,6 @@ auto invoke_builder(const std::string& name, const value& options) -> decltype(T{}.insert(name, m, ins, args, module_args)) { auto x = from_value(options); - apply_debug_symbols(x, options); return x.insert(name, m, ins, args, module_args); } @@ -115,7 +102,6 @@ auto invoke_builder(const std::string& name, if(not module_args.empty()) MIGRAPHX_THROW("Module args should be empty"); auto x = from_value(options); - apply_debug_symbols(x, options); return x.insert(name, m, ins, args); } @@ -154,7 +140,6 @@ struct register_builder_action template struct op_builder : auto_register { - std::set debug_symbols; static std::vector names() { static const std::string& name = get_type_name(); From b8e1cdb93058951af2defa9f696ca70e2209d13a Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 18 Feb 2026 23:31:07 -0600 Subject: [PATCH 007/107] Cleanup --- src/common.cpp | 64 +++---------------- src/include/migraphx/common.hpp | 22 ------- src/include/migraphx/module.hpp | 32 ---------- src/module.cpp | 43 +------------ .../include/migraphx/onnx/onnx_parser.hpp | 2 - src/onnx/onnx_parser.cpp | 19 ++---- 6 files changed, 18 insertions(+), 164 deletions(-) diff --git a/src/common.cpp b/src/common.cpp index 6413d5f8e95..10bc5d0707f 100644 --- a/src/common.cpp +++ b/src/common.cpp @@ -151,11 +151,10 @@ shape common_shape(const std::vector& shapes) return {compute_common_types(shapes), compute_common_lens(shapes)}; } -std::vector insert_common_args_impl(module& m, - instruction_ref ins, - const std::set& debug_symbols, - std::vector inputs, - common_options options) +std::vector insert_common_args(module& m, + instruction_ref ins, + std::vector inputs, + common_options options) { if(std::any_of( inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); })) @@ -168,10 +167,7 @@ std::vector insert_common_args_impl(module& m, auto s0 = inputs[0]->get_shape(); // always add both multibroadcast instructions for dynamic shapes inputs[0] = m.insert_instruction( - ins, - make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), - debug_symbols, - inputs); + ins, make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), inputs); std::transform(inputs.begin() + 1, inputs.end(), inputs.begin() + 1, [&](auto input) { // uses previous input to avoid recalculating the common shape from the // full set of input shapes at runtime @@ -179,7 +175,6 @@ std::vector insert_common_args_impl(module& m, return m.insert_instruction( ins, make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), - debug_symbols, input, inputs[0]); }); @@ -191,7 +186,7 @@ std::vector insert_common_args_impl(module& m, if(input->get_shape().type() != c_type) { input = m.insert_instruction( - ins, make_op("convert", {{"target_type", c_type}}), debug_symbols, input); + ins, make_op("convert", {{"target_type", c_type}}), input); } return input; }); @@ -203,18 +198,13 @@ std::vector insert_common_args_impl(module& m, std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { if(options.common_lens and input->get_shape().lens() != common.lens()) { - input = - m.insert_instruction(ins, - make_op("multibroadcast", {{"out_lens", common.lens()}}), - debug_symbols, - input); + input = m.insert_instruction( + ins, make_op("multibroadcast", {{"out_lens", common.lens()}}), input); } if(options.common_type and input->get_shape().type() != common.type()) { - input = m.insert_instruction(ins, - make_op("convert", {{"target_type", common.type()}}), - debug_symbols, - input); + input = m.insert_instruction( + ins, make_op("convert", {{"target_type", common.type()}}), input); } return input; }); @@ -222,19 +212,6 @@ std::vector insert_common_args_impl(module& m, return inputs; } -std::vector insert_common_args(module& m, - instruction_ref ins, - std::vector inputs, - common_options options) -{ return insert_common_args_impl(m, ins, {}, std::move(inputs), options); } - -std::vector insert_common_args(module& m, - instruction_ref ins, - const std::set& debug_symbols, - std::vector inputs, - common_options options) -{ return insert_common_args_impl(m, ins, debug_symbols, std::move(inputs), options); } - std::vector add_common_args(module& m, std::vector inputs, common_options options) { @@ -250,20 +227,6 @@ instruction_ref insert_common_op(module& m, return m.insert_instruction(ins, op, insert_common_args(m, ins, std::move(inputs), options)); } -instruction_ref insert_common_op(module& m, - instruction_ref ins, - const operation& op, - const std::set& debug_symbols, - std::vector inputs, - common_options options) -{ - return m.insert_instruction( - ins, - op, - debug_symbols, - insert_common_args(m, ins, debug_symbols, std::move(inputs), options)); -} - instruction_ref add_common_op(module& m, const operation& op, std::vector inputs, @@ -272,13 +235,6 @@ instruction_ref add_common_op(module& m, return insert_common_op(m, m.end(), op, std::move(inputs), options); } -instruction_ref add_common_op(module& m, - const operation& op, - const std::set& debug_symbols, - std::vector inputs, - common_options options) -{ return insert_common_op(m, m.end(), op, debug_symbols, std::move(inputs), options); } - shape make_bcast_shape(const shape& input_shape, const std::vector& bcast_lens) { assert(not input_shape.dynamic()); diff --git a/src/include/migraphx/common.hpp b/src/include/migraphx/common.hpp index ebb1e5d55d1..63a2dc43541 100644 --- a/src/include/migraphx/common.hpp +++ b/src/include/migraphx/common.hpp @@ -27,7 +27,6 @@ #include #include #include -#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -124,13 +123,6 @@ MIGRAPHX_EXPORT std::vector insert_common_args(module& m, std::vector inputs, common_options options = {}); -MIGRAPHX_EXPORT std::vector -insert_common_args(module& m, - instruction_ref ins, - const std::set& debug_symbols, - std::vector inputs, - common_options options = {}); - MIGRAPHX_EXPORT std::vector add_common_args(module& m, std::vector inputs, common_options options = {}); @@ -142,14 +134,6 @@ instruction_ref insert_common_op(module& m, std::vector inputs, common_options options = {}); -MIGRAPHX_EXPORT -instruction_ref insert_common_op(module& m, - instruction_ref ins, - const operation& op, - const std::set& debug_symbols, - std::vector inputs, - common_options options = {}); - /** * @brief Wrapper for insert_common_args() which inserts operation at the end of the module. */ @@ -158,12 +142,6 @@ instruction_ref add_common_op(module& m, const operation& op, std::vector inputs, common_options options = {}); -MIGRAPHX_EXPORT -instruction_ref add_common_op(module& m, - const operation& op, - const std::set& debug_symbols, - std::vector inputs, - common_options options = {}); /** * Calculates the broadcasted shape with the given input_shape and broadcasted dimensions. diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index ca6e11866ad..e8e4327b0f1 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -108,20 +108,6 @@ struct MIGRAPHX_EXPORT module std::vector args, std::vector module_args); - template {}...)> - instruction_ref - add_instruction(operation op, const std::set& debug_symbols, Ts... args) - { return add_instruction(op, debug_symbols, {args...}); } - - instruction_ref add_instruction(const operation& op, - const std::set& debug_symbols, - std::vector args); - - instruction_ref add_instruction(const operation& op, - const std::set& debug_symbols, - std::vector args, - std::vector module_args); - template {}...)> instruction_ref insert_instruction(instruction_ref ins, operation op, Ts... args) { @@ -136,24 +122,6 @@ struct MIGRAPHX_EXPORT module std::vector args, std::vector module_args); - template {}...)> - instruction_ref insert_instruction(instruction_ref ins, - operation op, - const std::set& debug_symbols, - Ts... args) - { return insert_instruction(ins, op, debug_symbols, {args...}); } - - instruction_ref insert_instruction(instruction_ref ins, - const operation& op, - const std::set& debug_symbols, - std::vector args); - - instruction_ref insert_instruction(instruction_ref ins, - const operation& op, - const std::set& debug_symbols, - std::vector args, - std::vector module_args); - template {}...)> instruction_ref replace_instruction(instruction_ref ins, operation op, Ts... args) { diff --git a/src/module.cpp b/src/module.cpp index ab6189836e7..b220b0e0a8a 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -145,9 +145,7 @@ scoped_debug_symbols::~scoped_debug_symbols() scoped_debug_symbols::scoped_debug_symbols(scoped_debug_symbols&& other) noexcept : mod(other.mod), previous(std::move(other.previous)) -{ - other.mod = nullptr; -} +{ other.mod = nullptr; } scoped_debug_symbols& scoped_debug_symbols::operator=(scoped_debug_symbols&& other) noexcept { @@ -155,8 +153,8 @@ scoped_debug_symbols& scoped_debug_symbols::operator=(scoped_debug_symbols&& oth { if(mod != nullptr) mod->impl->active_debug_symbols = previous; - mod = other.mod; - previous = std::move(other.previous); + mod = other.mod; + previous = std::move(other.previous); other.mod = nullptr; } return *this; @@ -331,20 +329,6 @@ instruction_ref module::add_instruction(const operation& op, std::vector module_args) { return insert_instruction(this->insert_end(), op, std::move(args), std::move(module_args)); } -instruction_ref module::add_instruction(const operation& op, - const std::set& debug_symbols, - std::vector args) -{ return insert_instruction(this->insert_end(), op, debug_symbols, std::move(args)); } - -instruction_ref module::add_instruction(const operation& op, - const std::set& debug_symbols, - std::vector args, - std::vector module_args) -{ - return insert_instruction( - this->insert_end(), op, debug_symbols, std::move(args), std::move(module_args)); -} - instruction_ref module::insert_instruction(instruction_ref ins, const operation& op, std::vector args) @@ -376,27 +360,6 @@ instruction_ref module::insert_instruction(instruction_ref ins, return result; } -instruction_ref module::insert_instruction(instruction_ref ins, - const operation& op, - const std::set& debug_symbols, - std::vector args) -{ - auto new_ins = insert_instruction(ins, op, args); - new_ins->add_debug_symbols(debug_symbols); - return new_ins; -} - -instruction_ref module::insert_instruction(instruction_ref ins, - const operation& op, - const std::set& debug_symbols, - std::vector args, - std::vector module_args) -{ - auto new_ins = insert_instruction(ins, op, args, module_args); - new_ins->add_debug_symbols(debug_symbols); - return new_ins; -} - instruction_ref module::replace_instruction(instruction_ref ins, const operation& op, std::vector args) MIGRAPHX_TIDY_CONST diff --git a/src/onnx/include/migraphx/onnx/onnx_parser.hpp b/src/onnx/include/migraphx/onnx/onnx_parser.hpp index bfd9c9803a0..f5348874f2f 100644 --- a/src/onnx/include/migraphx/onnx/onnx_parser.hpp +++ b/src/onnx/include/migraphx/onnx/onnx_parser.hpp @@ -54,8 +54,6 @@ struct onnx_parser // unique identifier for MIGX, not given ONNX node name std::string name = ""; module* mod = nullptr; - std::string onnx_node_name{}; - std::string onnx_op_type{}; instruction_ref make_contiguous(instruction_ref ins) const; instruction_ref add_bias(const std::vector& args, diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index 58adf3110c2..29f693c66fe 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -174,28 +174,20 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s */ instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name, std::vector inputs) const -{ - return migraphx::add_common_op(*mod, make_op(op_name), std::move(inputs)); -} +{ return migraphx::add_common_op(*mod, make_op(op_name), std::move(inputs)); } instruction_ref onnx_parser::node_info::add_instruction(const operation& op, const std::vector& args) const -{ - return mod->add_instruction(op, args); -} +{ return mod->add_instruction(op, args); } instruction_ref onnx_parser::node_info::add_instruction(const operation& op, const std::vector& args, const std::vector& mods) const -{ - return mod->add_instruction(op, args, mods); -} +{ return mod->add_instruction(op, args, mods); } instruction_ref onnx_parser::node_info::add_literal(literal l) const -{ - return mod->add_literal(std::move(l)); -} +{ return mod->add_literal(std::move(l)); } onnx_parser::onnx_parser() { @@ -600,8 +592,7 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini else { std::string node_name = node.op_type() + "_" + std::to_string(mod->size()); - node_info ninfo{ - get_attributes(node), output_num, node_name, mod, node.name(), node.op_type()}; + node_info ninfo{get_attributes(node), output_num, node_name, mod}; scoped_debug_symbols guard(*mod, {node.name()}); result = ops[node.op_type()](*this, ninfo, args); } From ec7487e3e5ad054a6ef6a4059dc69813af40383b Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 18 Feb 2026 23:38:38 -0600 Subject: [PATCH 008/107] cleanup 2 --- src/include/migraphx/module.hpp | 1 - src/module.cpp | 13 ++++++----- .../include/migraphx/onnx/onnx_parser.hpp | 1 - src/onnx/onnx_parser.cpp | 22 +++++++++++++------ 4 files changed, 22 insertions(+), 15 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index e8e4327b0f1..1e15eb94a08 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -113,7 +113,6 @@ struct MIGRAPHX_EXPORT module { return insert_instruction(ins, op, {args...}); } - instruction_ref insert_instruction(instruction_ref ins, const operation& op, std::vector args); diff --git a/src/module.cpp b/src/module.cpp index b220b0e0a8a..ede92984c3f 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -323,12 +323,6 @@ instruction_ref module::add_instruction(const operation& op, std::vectorinsert_end(), op, std::move(args)); } - -instruction_ref module::add_instruction(const operation& op, - std::vector args, - std::vector module_args) -{ return insert_instruction(this->insert_end(), op, std::move(args), std::move(module_args)); } - instruction_ref module::insert_instruction(instruction_ref ins, const operation& op, std::vector args) @@ -344,6 +338,13 @@ instruction_ref module::insert_instruction(instruction_ref ins, return result; } +instruction_ref module::add_instruction(const operation& op, + std::vector args, + std::vector module_args) +{ + return insert_instruction(this->insert_end(), op, std::move(args), std::move(module_args)); +} + instruction_ref module::insert_instruction(instruction_ref ins, const operation& op, std::vector args, diff --git a/src/onnx/include/migraphx/onnx/onnx_parser.hpp b/src/onnx/include/migraphx/onnx/onnx_parser.hpp index f5348874f2f..e81544a1683 100644 --- a/src/onnx/include/migraphx/onnx/onnx_parser.hpp +++ b/src/onnx/include/migraphx/onnx/onnx_parser.hpp @@ -54,7 +54,6 @@ struct onnx_parser // unique identifier for MIGX, not given ONNX node name std::string name = ""; module* mod = nullptr; - instruction_ref make_contiguous(instruction_ref ins) const; instruction_ref add_bias(const std::vector& args, instruction_ref curr_ins, diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index 29f693c66fe..90029b4b68c 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -174,20 +174,28 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s */ instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name, std::vector inputs) const -{ return migraphx::add_common_op(*mod, make_op(op_name), std::move(inputs)); } +{ + return migraphx::add_common_op(*mod, make_op(op_name), std::move(inputs)); +} instruction_ref onnx_parser::node_info::add_instruction(const operation& op, const std::vector& args) const -{ return mod->add_instruction(op, args); } +{ + return mod->add_instruction(op, args); +} instruction_ref onnx_parser::node_info::add_instruction(const operation& op, const std::vector& args, const std::vector& mods) const -{ return mod->add_instruction(op, args, mods); } +{ + return mod->add_instruction(op, args, mods); +} instruction_ref onnx_parser::node_info::add_literal(literal l) const -{ return mod->add_literal(std::move(l)); } +{ + return mod->add_literal(std::move(l)); +} onnx_parser::onnx_parser() { @@ -704,9 +712,9 @@ static shape parse_tensor_shape(const onnx::TensorProto& t) literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const { - auto tensor_shape = parse_tensor_shape(t); - const auto& dims = tensor_shape.lens(); - auto type = tensor_shape.type(); + auto tensor_shape = parse_tensor_shape(t); + const auto& dims = tensor_shape.lens(); + auto type = tensor_shape.type(); const auto& external_data = t.external_data(); if(not external_data.empty()) From f499489f5663ce8e703052d02282f37fbb87e569 Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 19 Feb 2026 11:24:49 -0600 Subject: [PATCH 009/107] Formatting --- src/include/migraphx/module.hpp | 4 +--- src/module.cpp | 8 ++------ src/onnx/onnx_parser.cpp | 22 +++++++--------------- 3 files changed, 10 insertions(+), 24 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 1e15eb94a08..43de97a9fea 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -110,9 +110,7 @@ struct MIGRAPHX_EXPORT module template {}...)> instruction_ref insert_instruction(instruction_ref ins, operation op, Ts... args) - { - return insert_instruction(ins, op, {args...}); - } + { return insert_instruction(ins, op, {args...}); } instruction_ref insert_instruction(instruction_ref ins, const operation& op, std::vector args); diff --git a/src/module.cpp b/src/module.cpp index ede92984c3f..5560d8cc79b 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -320,9 +320,7 @@ insert_generic_instructions(module& m, } instruction_ref module::add_instruction(const operation& op, std::vector args) -{ - return insert_instruction(this->insert_end(), op, std::move(args)); -} +{ return insert_instruction(this->insert_end(), op, std::move(args)); } instruction_ref module::insert_instruction(instruction_ref ins, const operation& op, std::vector args) @@ -341,9 +339,7 @@ instruction_ref module::insert_instruction(instruction_ref ins, instruction_ref module::add_instruction(const operation& op, std::vector args, std::vector module_args) -{ - return insert_instruction(this->insert_end(), op, std::move(args), std::move(module_args)); -} +{ return insert_instruction(this->insert_end(), op, std::move(args), std::move(module_args)); } instruction_ref module::insert_instruction(instruction_ref ins, const operation& op, diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index 90029b4b68c..29f693c66fe 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -174,28 +174,20 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s */ instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name, std::vector inputs) const -{ - return migraphx::add_common_op(*mod, make_op(op_name), std::move(inputs)); -} +{ return migraphx::add_common_op(*mod, make_op(op_name), std::move(inputs)); } instruction_ref onnx_parser::node_info::add_instruction(const operation& op, const std::vector& args) const -{ - return mod->add_instruction(op, args); -} +{ return mod->add_instruction(op, args); } instruction_ref onnx_parser::node_info::add_instruction(const operation& op, const std::vector& args, const std::vector& mods) const -{ - return mod->add_instruction(op, args, mods); -} +{ return mod->add_instruction(op, args, mods); } instruction_ref onnx_parser::node_info::add_literal(literal l) const -{ - return mod->add_literal(std::move(l)); -} +{ return mod->add_literal(std::move(l)); } onnx_parser::onnx_parser() { @@ -712,9 +704,9 @@ static shape parse_tensor_shape(const onnx::TensorProto& t) literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const { - auto tensor_shape = parse_tensor_shape(t); - const auto& dims = tensor_shape.lens(); - auto type = tensor_shape.type(); + auto tensor_shape = parse_tensor_shape(t); + const auto& dims = tensor_shape.lens(); + auto type = tensor_shape.type(); const auto& external_data = t.external_data(); if(not external_data.empty()) From 8324d9bad75e0f9e5ab03752e6fb786badaec5b5 Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 20 Feb 2026 14:16:54 -0600 Subject: [PATCH 010/107] Add debug propagate with matcher context --- src/include/migraphx/matcher.hpp | 29 ++ src/include/migraphx/module.hpp | 8 +- src/instruction.cpp | 12 +- src/module.cpp | 114 ++++++- src/onnx/onnx_parser.cpp | 12 +- src/program.cpp | 2 +- src/simplify_algebra.cpp | 46 ++- test/debug_symbols_test.cpp | 496 +++++++++++++++++++++++++++++++ 8 files changed, 684 insertions(+), 35 deletions(-) create mode 100644 test/debug_symbols_test.cpp diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index 36ad2fd9052..39bf5a53816 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -476,6 +476,21 @@ auto make_match_runner_with_trace(source_location location, Finder& f) } // If its already invalid dont validate it again bool invalidated = validate and get_module(mod).validate() != get_module(mod).end(); + + optional debug_guard; + + std::set symbols; + const auto& mr_result_symbols = r.result->get_debug_symbols(); + symbols.insert(mr_result_symbols.begin(), mr_result_symbols.end()); + for(const auto& mr_ins : r.instructions) + { + const auto& ins_symbols = mr_ins.second->get_debug_symbols(); + symbols.insert(ins_symbols.begin(), ins_symbols.end()); + } + + if(not symbols.empty()) + debug_guard.emplace(get_module(mod), symbols); + if(trace_enabled) { if(trace > 1) @@ -512,6 +527,17 @@ auto make_match_runner(Finder& f) match::matcher_result r = match::match_instruction(get_module(mod), ins, m); if(r.result == get_module(mod).end()) return false; + std::set symbols; + const auto& mr_result_symbols = r.result->get_debug_symbols(); + symbols.insert(mr_result_symbols.begin(), mr_result_symbols.end()); + for(const auto& mr_ins : r.instructions) + { + const auto& ins_symbols = mr_ins.second->get_debug_symbols(); + symbols.insert(ins_symbols.begin(), ins_symbols.end()); + } + optional debug_guard; + if(not symbols.empty()) + debug_guard.emplace(get_module(mod), symbols); f.apply(mod, r); return true; }; @@ -580,6 +606,9 @@ void find_matches_for(source_location location, Mod& mod, instruction_ref ins, M } // If its already invalid dont validate it again bool invalidated = validate and get_module(mod).validate() != get_module(mod).end(); + optional debug_guard; + if(not r.result->get_debug_symbols().empty()) + debug_guard.emplace(get_module(mod), r.result->get_debug_symbols()); auto apply_time = time>([&] { m.apply(mod, r); }); if(time_matchers or trace_for) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 43de97a9fea..7405b0398e5 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -64,7 +64,7 @@ struct MIGRAPHX_EXPORT scoped_debug_symbols private: module* mod; - std::shared_ptr> previous; + std::set previous; }; /** @@ -78,7 +78,7 @@ struct MIGRAPHX_EXPORT module const std::vector& inputs, const std::vector& mod_args)>; - module(const std::string& name = ""); + module(const std::string& name = "", bool use_debug_symbols = false); // move constructor module(module&&) noexcept; @@ -96,6 +96,9 @@ struct MIGRAPHX_EXPORT module bool bypass() const; void set_bypass(bool b = true); + bool get_use_debug_symbols() const; + void set_use_debug_symbols(bool b = true); + template {}...)> instruction_ref add_instruction(operation op, Ts... args) { @@ -370,6 +373,7 @@ struct MIGRAPHX_EXPORT module const module& pmod, instruction_ref ins, ins_dep_map& deps) const; + void propagate_replace_debug_symbols(instruction_ref rep_ins, const std::set& debug_symbols); std::unique_ptr impl; }; diff --git a/src/instruction.cpp b/src/instruction.cpp index 9d3d28f03e0..248e5bd9989 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -36,8 +36,6 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_SHOW_DEBUG_SYMBOLS) - template static auto equal_to(const T& x) { @@ -196,13 +194,7 @@ const std::set& instruction::get_debug_symbols() const { return deb void instruction::add_debug_symbols(const std::set& symbols) { - for(const auto& symbol : symbols) - { - if(not symbol.empty()) - { - debug_symbols.insert(symbol); - } - } + debug_symbols.insert(symbols.begin(), symbols.end()); } bool operator==(const instruction& x, const instruction& y) @@ -458,7 +450,7 @@ void instruction::print(std::ostream& os, os << ", target_id=" << ins->target_id; // print debug symbols if enabled - if(enabled(MIGRAPHX_SHOW_DEBUG_SYMBOLS{}) and not ins->debug_symbols.empty()) + if(not ins->debug_symbols.empty()) { os << " /* " << join_strings(ins->debug_symbols, ", ") << " */"; } diff --git a/src/module.cpp b/src/module.cpp index 5560d8cc79b..de7c1815905 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -51,6 +51,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_FINALIZE) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_DEBUG_SYMBOLS) struct module_impl { @@ -61,7 +62,8 @@ struct module_impl uint32_t nparams = 0; bool bypass = false; // used for skipping compiler passes bit_signal<64> changed{}; - std::shared_ptr> active_debug_symbols; + bool use_debug_symbols = false; + std::set active_debug_symbols; bool contains(instruction_ref ins) const { @@ -131,16 +133,15 @@ struct module_impl const operation& get_operation(instruction_ref ins) { return ins->get_operator(); } scoped_debug_symbols::scoped_debug_symbols(module& m, std::set symbols) - : mod(&m), previous(m.impl->active_debug_symbols) + : mod(&m), previous(std::move(m.impl->active_debug_symbols)) { - mod->impl->active_debug_symbols = - std::make_shared>(std::move(symbols)); + mod->impl->active_debug_symbols = std::move(symbols); } scoped_debug_symbols::~scoped_debug_symbols() { if(mod != nullptr) - mod->impl->active_debug_symbols = previous; + mod->impl->active_debug_symbols = std::move(previous); } scoped_debug_symbols::scoped_debug_symbols(scoped_debug_symbols&& other) noexcept @@ -152,7 +153,7 @@ scoped_debug_symbols& scoped_debug_symbols::operator=(scoped_debug_symbols&& oth if(this != &other) { if(mod != nullptr) - mod->impl->active_debug_symbols = previous; + mod->impl->active_debug_symbols = std::move(previous); mod = other.mod; previous = std::move(other.previous); other.mod = nullptr; @@ -160,9 +161,17 @@ scoped_debug_symbols& scoped_debug_symbols::operator=(scoped_debug_symbols&& oth return *this; } -module::module(const std::string& name) : impl(std::make_unique()) +module::module(const std::string& name, bool use_debug_symbols) : impl(std::make_unique()) { impl->name = name; + if(enabled(MIGRAPHX_ENABLE_DEBUG_SYMBOLS{})) + { + impl->use_debug_symbols = true; + } + else + { + impl->use_debug_symbols = use_debug_symbols; + } } module::module(module&&) noexcept = default; @@ -185,6 +194,9 @@ void module::set_name(const std::string& name) { impl->name = name; } bool module::bypass() const { return impl->bypass; } void module::set_bypass(bool b) { impl->bypass = b; } +bool module::get_use_debug_symbols() const { return impl->use_debug_symbols; } +void module::set_use_debug_symbols(bool b) { impl->use_debug_symbols = b; } + void module::assign(const module& m) { // copy the impl @@ -321,6 +333,7 @@ insert_generic_instructions(module& m, instruction_ref module::add_instruction(const operation& op, std::vector args) { return insert_instruction(this->insert_end(), op, std::move(args)); } + instruction_ref module::insert_instruction(instruction_ref ins, const operation& op, std::vector args) @@ -331,8 +344,8 @@ instruction_ref module::insert_instruction(instruction_ref ins, auto result = impl->insert(ins, {op, r, std::move(args)}); instruction::backreference(result); assert(result->valid(begin())); - if(impl->active_debug_symbols != nullptr) - result->add_debug_symbols(*impl->active_debug_symbols); + if(not impl->active_debug_symbols.empty()) + result->add_debug_symbols(impl->active_debug_symbols); return result; } @@ -352,11 +365,57 @@ instruction_ref module::insert_instruction(instruction_ref ins, auto result = impl->insert(ins, {op, out_shape, std::move(args), std::move(module_args)}); instruction::backreference(result); assert(result->valid(begin())); - if(impl->active_debug_symbols != nullptr) - result->add_debug_symbols(*impl->active_debug_symbols); + if(not impl->active_debug_symbols.empty()) + result->add_debug_symbols(impl->active_debug_symbols); return result; } +/** + * old_ins : instruction that will be replaced + * Traverse inputs of old_ins and gather debug_symbols of instructions that will become dead code. + */ +std::set gather_replace_debug_symbols(instruction_ref old_ins) +{ + std::set debug_symbols; + if(starts_with(old_ins->name(), "@")) + return debug_symbols; + const auto& old_ins_debug = old_ins->get_debug_symbols(); + if(old_ins_debug.empty()) + return debug_symbols; + debug_symbols.insert(old_ins_debug.begin(), old_ins_debug.end()); + for(auto input : old_ins->inputs()) + { + // check if only output to old_ins + if(input->outputs().size() == 1 and input->outputs().at(0) == old_ins) + { + const auto& gdebug = gather_replace_debug_symbols(input); + debug_symbols.insert(gdebug.begin(), gdebug.end()); + } + } + return debug_symbols; +} + +/** + * Add gathered debug_symbols to rep_ins and traverse it's inputs to add the same debug_symbols + * to instructions with empty debug_symbols. + */ +void module::propagate_replace_debug_symbols(instruction_ref rep_ins, const std::set& debug_symbols) +{ + if(starts_with(rep_ins->name(), "@")) + return ; + if(debug_symbols.empty()) + return; + rep_ins->add_debug_symbols(debug_symbols); + for(auto input : rep_ins->inputs()) + { + auto input_ds = input->get_debug_symbols(); + if(input_ds.empty() or input_ds == impl->active_debug_symbols) + { + propagate_replace_debug_symbols(input, debug_symbols); + } + } +} + instruction_ref module::replace_instruction(instruction_ref ins, const operation& op, std::vector args) MIGRAPHX_TIDY_CONST @@ -366,7 +425,16 @@ instruction_ref module::replace_instruction(instruction_ref ins, assert(not starts_with(op.name(), "@")); shape r = compute_shape(op, args); - instruction::replace(ins, op, r, std::move(args)); + if(get_use_debug_symbols()) + { + auto debug_symbols = gather_replace_debug_symbols(ins); + instruction::replace(ins, op, r, std::move(args)); + propagate_replace_debug_symbols(ins, debug_symbols); + } + else + { + instruction::replace(ins, op, r, std::move(args)); + } assert(ins->valid(begin())); return ins; } @@ -380,7 +448,16 @@ instruction_ref module::replace_instruction(instruction_ref ins, assert(has_instruction(ins)); assert(not starts_with(op.name(), "@")); auto out_shape = compute_shape(op, args, module_args); - instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args)); + if(get_use_debug_symbols()) + { + auto debug_symbols = gather_replace_debug_symbols(ins); + instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args)); + propagate_replace_debug_symbols(ins, debug_symbols); + } + else + { + instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args)); + } assert(ins->valid(begin())); return ins; } @@ -403,6 +480,10 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref { return rep; } + + std::set debug_symbols; + if(get_use_debug_symbols()) + debug_symbols = gather_replace_debug_symbols(ins); // Make a copy of outputs which can be changed when calling replace_argument auto outputs = ins->outputs(); for(auto out : outputs) @@ -414,6 +495,9 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref } assert(out->valid(begin())); } + if(get_use_debug_symbols()) + propagate_replace_debug_symbols(rep, debug_symbols); + // Replacement should not be dead code unless its the last instruction assert(not rep->outputs().empty() or rep == std::prev(end())); // Output of the original instruction should only be the replacement or empty @@ -616,8 +700,8 @@ instruction_ref module::insert_literal(instruction_ref ins, literal l) { impl->emplace(ins, std::move(l)); auto result = std::prev(ins); - if(impl->active_debug_symbols != nullptr) - result->add_debug_symbols(*impl->active_debug_symbols); + if(not impl->active_debug_symbols.empty()) + result->add_debug_symbols(impl->active_debug_symbols); return result; } diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index 29f693c66fe..e482a642f8d 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -44,6 +44,7 @@ inline namespace MIGRAPHX_INLINE_NS { namespace onnx { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_ONNX_PARSER) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_DEBUG_SYMBOLS) static shape shape_from_dyn_dims(shape::type_t shape_type, const std::vector& dyn_dims) @@ -593,8 +594,15 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini { std::string node_name = node.op_type() + "_" + std::to_string(mod->size()); node_info ninfo{get_attributes(node), output_num, node_name, mod}; - scoped_debug_symbols guard(*mod, {node.name()}); - result = ops[node.op_type()](*this, ninfo, args); + if(mod->get_use_debug_symbols()) + { + scoped_debug_symbols guard(*mod, {node.name()}); + result = ops[node.op_type()](*this, ninfo, args); + } + else + { + result = ops[node.op_type()](*this, ninfo, args); + } } output_num = std::min(output_num, result.size()); diff --git a/src/program.cpp b/src/program.cpp index b7b5ff11152..baa1459dc09 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -1370,7 +1370,7 @@ program& program::sort() return *this; } -bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); } +bool operator==(const program& x, const program& y) { return migraphx::to_string(x) == migraphx::to_string(y); } std::ostream& operator<<(std::ostream& os, const program& p) { diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index d77fc63bf88..33008111895 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -1633,7 +1633,7 @@ struct find_add_convs } }; -MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins) +MIGRAPHX_BASIC_MATCHER(horiz_conv_dot, match::matcher_context& ctx, instruction_ref ins) { // checking size to prevent matching block quantized quant_dot for now auto pred = [&](auto name) { @@ -1642,10 +1642,46 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins) i->inputs().at(1)->can_eval() and i->inputs().size() == 2; }; }; - auto dots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot")); - auto qdots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("quant_dot")); - auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution")); - return (dots >= 2 or convs >= 2 or qdots >= 2); + + // adding matched instructions to matcher_context to have their debug_symbols propagate + auto add_instructions_to_ctx = [&ctx](std::string key_prefix, std::vector ins_vec){ + int count = 1; + for(instruction_ref d : ins_vec) + { + std::stringstream ss; + ss << key_prefix << "_" << count; + ctx.instructions[ss.str()] = d; + count++; + } + }; + bool found_horiz = false; + std::vector dots; + std::copy_if(ins->outputs().begin(), ins->outputs().end(), std::back_inserter(dots), pred("dot")); + std::vector qdots; + std::copy_if(ins->outputs().begin(), ins->outputs().end(), std::back_inserter(dots), pred("quant_dot")); + std::vector convs; + std::copy_if(ins->outputs().begin(), ins->outputs().end(), std::back_inserter(dots), pred("convolution")); + if(dots.size() >= 2) + { + found_horiz = true; + add_instructions_to_ctx("dot", dots); + } + else if(qdots.size() >= 2) + { + found_horiz = true; + add_instructions_to_ctx("qdot", qdots); + } + else if(convs.size() >= 2) + { + found_horiz = true; + add_instructions_to_ctx("conv", convs); + } + + if(found_horiz) + { + return {ins}; + } + return nullopt; } struct find_conv_dot_horiz_fusion diff --git a/test/debug_symbols_test.cpp b/test/debug_symbols_test.cpp new file mode 100644 index 00000000000..1b4a223a10a --- /dev/null +++ b/test/debug_symbols_test.cpp @@ -0,0 +1,496 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +// Two adds fused into a single pointwise op via fuse_pointwise. +// Both symbols should appear on the fused pointwise instruction. +TEST_CASE(pw_double_add) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + mm->set_use_debug_symbols(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + migraphx::instruction_ref add1; + { + migraphx::scoped_debug_symbols guard0(*mm, {"onnx:add0"}); + add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + } + migraphx::instruction_ref add2; + { + migraphx::scoped_debug_symbols guard1(*mm, {"onnx:add1"}); + add2 = mm->add_instruction(migraphx::make_op("add"), add1, z); + } + mm->add_return({add2}); + } + migraphx::run_passes(p1, {migraphx::fuse_pointwise{}, migraphx::dead_code_elimination{}}); + + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + mm->set_use_debug_symbols(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto fadd = + add_pointwise(p2, "main:pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) { + auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); + return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]); + }); + fadd->add_debug_symbols({"onnx:add0", "onnx:add1"}); + mm->add_return({fadd}); + } + // BUG straight equality is not working even though both call migraphx::to_string + //EXPECT(p1 == p2); + EXPECT(to_string(p1) == to_string(p2)); +} + + +// Diamond pattern: add1 feeds into both add2 and add3, which then feed +// into add4. All four are fused into one pointwise op. Verifies that +// symbols from every instruction in the diamond appear on the fused result. +TEST_CASE(pw_used_twice_fused) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + mm->set_use_debug_symbols(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + add1->add_debug_symbols({"onnx:add1"}); + auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, x); + add2->add_debug_symbols({"onnx:add2"}); + auto add3 = mm->add_instruction(migraphx::make_op("add"), add1, y); + add3->add_debug_symbols({"onnx:add3"}); + auto add4 = mm->add_instruction(migraphx::make_op("add"), add2, add3); + add4->add_debug_symbols({"onnx:add4"}); + mm->add_return({add4}); + } + migraphx::run_passes(p1, {migraphx::fuse_pointwise{}, migraphx::dead_code_elimination{}}); + + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + mm->set_use_debug_symbols(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto fadd = add_pointwise(p2, "main:pointwise0", {x, y}, [=](auto* pm, const auto& inputs) { + auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); + auto add2 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[0]); + auto add3 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[1]); + return pm->add_instruction(migraphx::make_op("add"), add2, add3); + }); + fadd->add_debug_symbols({"onnx:add1", "onnx:add2", "onnx:add3", "onnx:add4"}); + mm->add_return({fadd}); + } + // BUG straight equality is not working even though both call migraphx::to_string + EXPECT(to_string(p1.sort()) == to_string(p2.sort())); +} + +// Horizontal fusion of two dot ops sharing the same input via +// simplify_algebra. The two dots are fused into concat + single dot + slices. +// Each new instruction inherits the symbols of the original dots it derives +// from (e.g. the concat and fused dot carry both "gemm1" and "gemm2"). +TEST_CASE(horiz_fusion_dot) +{ + auto type = migraphx::shape::int32_type; + auto s = migraphx::shape{type, {3, 2, 2}}; + migraphx::module m1; + { + m1.set_use_debug_symbols(); + auto input = m1.add_parameter("input", s); + auto a = m1.add_literal(migraphx::generate_literal(s, 0)); + auto b = m1.add_literal(migraphx::generate_literal(s, 1)); + auto x = m1.add_instruction(migraphx::make_op("dot"), input, a); + x->add_debug_symbols({"gemm1"}); + auto y = m1.add_instruction(migraphx::make_op("dot"), input, b); + y->add_debug_symbols({"gemm2"}); + auto sum = m1.add_instruction(migraphx::make_op("add"), x, y); + sum->add_debug_symbols({"sum"}); + m1.add_instruction(pass_op{}, sum); + } + migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); + + migraphx::module m2; + { + m2.set_use_debug_symbols(); + auto input = m2.add_parameter("input", s); + auto a = m2.add_literal(migraphx::generate_literal(s, 0)); + auto b = m2.add_literal(migraphx::generate_literal(s, 1)); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), a, b); + concat->add_debug_symbols({"gemm1", "gemm2"}); + auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat); + dot->add_debug_symbols({"gemm1", "gemm2"}); + auto x = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), dot); + x->add_debug_symbols({"gemm1"}); + auto y = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {4}}}), dot); + y->add_debug_symbols({"gemm2"}); + auto sum = m2.add_instruction(migraphx::make_op("add"), x, y); + sum->add_debug_symbols({"sum"}); + m2.add_instruction(pass_op{}, sum); + } + // BUG straight equality is not working even though both call migraphx::to_string + EXPECT(to_string(m1.sort()) == to_string(m2.sort())); +} + +// Tests symbol propagation through add reassociation in simplify_algebra +// (find_double_add_lit_broadcast). Input: add0(add1(x, 1), add2(y, 2)) +// is reassociated to add(add(x, y), add(1, 2)). gather_replace_debug_symbols +// collects symbols from add0 and its dead-code inputs (add1, add2), then +// propagate_replace_debug_symbols spreads them to all replacement instructions. +TEST_CASE(simplify_add_debug_symbols) +{ + migraphx::module m1; + { + m1.set_use_debug_symbols(); + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {1}}); + auto one = m1.add_literal(1); + auto two = m1.add_literal(2); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, one); + sum1->add_debug_symbols({"onnx:add1"}); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), y, two); + sum2->add_debug_symbols({"onnx:add2"}); + auto sum3 = m1.add_instruction(migraphx::make_op("add"), sum1, sum2); + sum3->add_debug_symbols({"onnx:add0"}); + m1.add_instruction(pass_op{}, sum3); + } + migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); + + migraphx::module m2; + { + m2.set_use_debug_symbols(); + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1}}); + auto one = m2.add_literal(1); + auto two = m2.add_literal(2); + auto sum1 = m2.add_instruction(migraphx::make_op("add"), one, two); + sum1->add_debug_symbols({"onnx:add0", "onnx:add1", "onnx:add2"}); + auto sum2 = m2.add_instruction(migraphx::make_op("add"), x, y); + sum2->add_debug_symbols({"onnx:add0", "onnx:add1", "onnx:add2"}); + auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum2, sum1); + sum3->add_debug_symbols({"onnx:add0", "onnx:add1", "onnx:add2"}); + m2.add_instruction(pass_op{}, sum3); + } + EXPECT(to_string(m1.sort()) == to_string(m2.sort())); +} + +// Tests the replace_instruction(ins, rep) overload that replaces an +// instruction with an existing instruction_ref. find_unit_ops simplifies +// add(relu(x), broadcast(0)) to relu(x). gather_replace_debug_symbols +// walks the dead-code chain (add -> relu) collecting both symbols, then +// propagates them onto the surviving relu instruction. +TEST_CASE(replace_with_insref_debug_symbols) +{ + migraphx::module m1; + { + m1.set_use_debug_symbols(); + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 3}}); + auto zero = m1.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {0.0f}}); + auto bcast = m1.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 3}}}), zero); + auto relu_x = m1.add_instruction(migraphx::make_op("relu"), x); + relu_x->add_debug_symbols({"onnx:relu"}); + auto add_r = m1.add_instruction(migraphx::make_op("add"), relu_x, bcast); + add_r->add_debug_symbols({"onnx:add"}); + m1.add_instruction(pass_op{}, add_r); + } + migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); + + migraphx::module m2; + { + m2.set_use_debug_symbols(); + auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 3}}); + auto relu_x = m2.add_instruction(migraphx::make_op("relu"), x); + relu_x->add_debug_symbols({"onnx:add", "onnx:relu"}); + m2.add_instruction(pass_op{}, relu_x); + } + EXPECT(to_string(m1.sort()) == to_string(m2.sort())); +} + +// Directly tests that gather_replace_debug_symbols walks the dead-code +// input chain. relu(x) feeds exclusively into add(relu, y). When add is +// replaced by mul(x, y) via replace_instruction(ins, rep), the gather step +// sees relu's sole consumer is add, so it collects "onnx:relu" along with +// "onnx:add" and propagates both to the replacement mul instruction. +TEST_CASE(gather_replace_chain_debug_symbols) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m1; + { + m1.set_use_debug_symbols(); + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto relu_x = m1.add_instruction(migraphx::make_op("relu"), x); + relu_x->add_debug_symbols({"onnx:relu"}); + auto add_r = m1.add_instruction(migraphx::make_op("add"), relu_x, y); + add_r->add_debug_symbols({"onnx:add"}); + m1.add_instruction(pass_op{}, add_r); + + auto mul_r = m1.insert_instruction(add_r, migraphx::make_op("mul"), x, y); + m1.replace_instruction(add_r, mul_r); + } + migraphx::run_passes(m1, {migraphx::dead_code_elimination{}}); + + migraphx::module m2; + { + m2.set_use_debug_symbols(); + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto mul_r = m2.add_instruction(migraphx::make_op("mul"), x, y); + mul_r->add_debug_symbols({"onnx:add", "onnx:relu"}); + m2.add_instruction(pass_op{}, mul_r); + } + EXPECT(to_string(m1.sort()) == to_string(m2.sort())); +} + +// Tests the distributive law transform in simplify_algebra (find_mul_add): +// mul(add(3, x), 2) -> add(mul(2, x), mul(2, 3)). The matcher's scoped +// guard captures "onnx:mul" from r.result, and gather_replace_debug_symbols +// also collects "onnx:add" from the dead-code add (its sole consumer was +// the mul being replaced). Both symbols propagate to all three new instructions. +TEST_CASE(simplify_mul_add_debug_symbols) +{ + migraphx::module m1; + { + m1.set_use_debug_symbols(); + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto one = m1.add_literal(3); + auto two = m1.add_literal(2); + auto sum = m1.add_instruction(migraphx::make_op("add"), one, x); + sum->add_debug_symbols({"onnx:add"}); + auto mul = m1.add_instruction(migraphx::make_op("mul"), sum, two); + mul->add_debug_symbols({"onnx:mul"}); + m1.add_instruction(pass_op{}, mul); + } + migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); + + migraphx::module m2; + { + m2.set_use_debug_symbols(); + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto one = m2.add_literal(3); + auto two = m2.add_literal(2); + auto mul1 = m2.add_instruction(migraphx::make_op("mul"), two, x); + mul1->add_debug_symbols({"onnx:add", "onnx:mul"}); + auto mul2 = m2.add_instruction(migraphx::make_op("mul"), two, one); + mul2->add_debug_symbols({"onnx:add", "onnx:mul"}); + auto sum = m2.add_instruction(migraphx::make_op("add"), mul1, mul2); + sum->add_debug_symbols({"onnx:add", "onnx:mul"}); + m2.add_instruction(pass_op{}, sum); + } + EXPECT(to_string(m1.sort()) == to_string(m2.sort())); +} + +// Tests symbol propagation through find_div_const in simplify_algebra: +// div(x, c) -> mul(x, recip(c)). The matcher's scoped guard sets "onnx:div" +// as the active symbol, so recip gets it via insert_instruction. The +// replace_instruction propagation then confirms "onnx:div" on the mul as well. +TEST_CASE(simplify_div_const_debug_symbols) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m1; + { + m1.set_use_debug_symbols(); + auto x = m1.add_parameter("x", s); + auto c = m1.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 3}}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); + auto div_r = m1.add_instruction(migraphx::make_op("div"), x, c); + div_r->add_debug_symbols({"onnx:div"}); + m1.add_instruction(pass_op{}, div_r); + } + migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); + + migraphx::module m2; + { + m2.set_use_debug_symbols(); + auto x = m2.add_parameter("x", s); + auto c = m2.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 3}}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); + auto recip = m2.add_instruction(migraphx::make_op("recip"), c); + recip->add_debug_symbols({"onnx:div"}); + auto mul_r = m2.add_instruction(migraphx::make_op("mul"), x, recip); + mul_r->add_debug_symbols({"onnx:div"}); + m2.add_instruction(pass_op{}, mul_r); + } + EXPECT(to_string(m1.sort()) == to_string(m2.sort())); +} + +// Unit test for the scoped_debug_symbols RAII guard's save/restore behavior. +// An outer guard sets "outer", a nested inner guard temporarily replaces it +// with "inner", and after the inner guard's destructor runs the outer symbols +// are restored. Instructions created in each scope should carry only the +// symbols active at the time they were added. +TEST_CASE(scoped_debug_symbols_nesting) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m("test", true); + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + + migraphx::instruction_ref add1; + migraphx::instruction_ref add2; + migraphx::instruction_ref add3; + { + migraphx::scoped_debug_symbols outer(m, {"outer"}); + add1 = m.add_instruction(migraphx::make_op("add"), x, y); + { + migraphx::scoped_debug_symbols inner(m, {"inner"}); + add2 = m.add_instruction(migraphx::make_op("add"), add1, x); + } + add3 = m.add_instruction(migraphx::make_op("add"), add2, y); + } + + EXPECT(add1->get_debug_symbols() == std::set{"outer"}); + EXPECT(add2->get_debug_symbols() == std::set{"inner"}); + EXPECT(add3->get_debug_symbols() == std::set{"outer"}); +} + +// Three sequential adds fused into a single pointwise op via fuse_pointwise: +// add(add(add(x, y), z), w). All three ONNX node symbols should appear on +// the fused pointwise instruction. Extends pw_double_add to a longer chain. +TEST_CASE(pw_triple_add_fused) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + mm->set_use_debug_symbols(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto w = mm->add_parameter("w", s); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + add1->add_debug_symbols({"onnx:add1"}); + auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, z); + add2->add_debug_symbols({"onnx:add2"}); + auto add3 = mm->add_instruction(migraphx::make_op("add"), add2, w); + add3->add_debug_symbols({"onnx:add3"}); + mm->add_return({add3}); + } + migraphx::run_passes(p1, {migraphx::fuse_pointwise{}, migraphx::dead_code_elimination{}}); + + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + mm->set_use_debug_symbols(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto w = mm->add_parameter("w", s); + auto fadd = + add_pointwise(p2, "main:pointwise0", {x, y, z, w}, [=](auto* pm, const auto& inputs) { + auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); + auto add2 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]); + return pm->add_instruction(migraphx::make_op("add"), add2, inputs[3]); + }); + fadd->add_debug_symbols({"onnx:add1", "onnx:add2", "onnx:add3"}); + mm->add_return({fadd}); + } + EXPECT(to_string(p1) == to_string(p2)); +} + +// Same add-reassociation pattern as simplify_add_debug_symbols but with +// set_use_debug_symbols() NOT called. The matcher's scoped guard still tags +// new instructions with the matched result's symbols ("onnx:add0") via +// insert_instruction, but replace_instruction skips gather/propagate when +// the flag is off. Dead-code symbols "onnx:add1" and "onnx:add2" are lost. +TEST_CASE(no_propagation_without_flag) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {1}}); + auto one = m1.add_literal(1); + auto two = m1.add_literal(2); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, one); + sum1->add_debug_symbols({"onnx:add1"}); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), y, two); + sum2->add_debug_symbols({"onnx:add2"}); + auto sum3 = m1.add_instruction(migraphx::make_op("add"), sum1, sum2); + sum3->add_debug_symbols({"onnx:add0"}); + m1.add_instruction(pass_op{}, sum3); + } + migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1}}); + auto one = m2.add_literal(1); + auto two = m2.add_literal(2); + auto sum1 = m2.add_instruction(migraphx::make_op("add"), one, two); + sum1->add_debug_symbols({"onnx:add0"}); + auto sum2 = m2.add_instruction(migraphx::make_op("add"), x, y); + sum2->add_debug_symbols({"onnx:add0"}); + auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum2, sum1); + sum3->add_debug_symbols({"onnx:add0"}); + m2.add_instruction(pass_op{}, sum3); + } + EXPECT(to_string(m1.sort()) == to_string(m2.sort())); +} + +// Verifies that debug symbols appear in the module's printed/serialized +// output using the expected "/* sym_a, sym_b */" comment format produced +// by instruction::print. +TEST_CASE(debug_symbols_in_print) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m; + m.set_use_debug_symbols(); + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto add = m.add_instruction(migraphx::make_op("add"), x, y); + add->add_debug_symbols({"sym_a", "sym_b"}); + m.add_instruction(pass_op{}, add); + + auto str = migraphx::to_string(m); + EXPECT(str.find("/* sym_a, sym_b */") != std::string::npos); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } From c9f567fa68e63bc70bb43a78bdec02ed37ba2d7d Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 20 Feb 2026 14:17:21 -0600 Subject: [PATCH 011/107] formatting --- src/include/migraphx/module.hpp | 3 +- src/instruction.cpp | 4 +- src/module.cpp | 12 +++--- src/program.cpp | 3 +- src/simplify_algebra.cpp | 14 +++++-- test/debug_symbols_test.cpp | 67 ++++++++++++++++----------------- 6 files changed, 54 insertions(+), 49 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 7405b0398e5..1e3261d4b61 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -373,7 +373,8 @@ struct MIGRAPHX_EXPORT module const module& pmod, instruction_ref ins, ins_dep_map& deps) const; - void propagate_replace_debug_symbols(instruction_ref rep_ins, const std::set& debug_symbols); + void propagate_replace_debug_symbols(instruction_ref rep_ins, + const std::set& debug_symbols); std::unique_ptr impl; }; diff --git a/src/instruction.cpp b/src/instruction.cpp index 248e5bd9989..dbd662f41a4 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -193,9 +193,7 @@ const std::vector& instruction::outputs() const { return output const std::set& instruction::get_debug_symbols() const { return debug_symbols; } void instruction::add_debug_symbols(const std::set& symbols) -{ - debug_symbols.insert(symbols.begin(), symbols.end()); -} +{ debug_symbols.insert(symbols.begin(), symbols.end()); } bool operator==(const instruction& x, const instruction& y) { diff --git a/src/module.cpp b/src/module.cpp index de7c1815905..adff2518d92 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -134,9 +134,7 @@ const operation& get_operation(instruction_ref ins) { return ins->get_operator() scoped_debug_symbols::scoped_debug_symbols(module& m, std::set symbols) : mod(&m), previous(std::move(m.impl->active_debug_symbols)) -{ - mod->impl->active_debug_symbols = std::move(symbols); -} +{ mod->impl->active_debug_symbols = std::move(symbols); } scoped_debug_symbols::~scoped_debug_symbols() { @@ -161,7 +159,8 @@ scoped_debug_symbols& scoped_debug_symbols::operator=(scoped_debug_symbols&& oth return *this; } -module::module(const std::string& name, bool use_debug_symbols) : impl(std::make_unique()) +module::module(const std::string& name, + bool use_debug_symbols) :impl(std::make_unique()) { impl->name = name; if(enabled(MIGRAPHX_ENABLE_DEBUG_SYMBOLS{})) @@ -399,10 +398,11 @@ std::set gather_replace_debug_symbols(instruction_ref old_ins) * Add gathered debug_symbols to rep_ins and traverse it's inputs to add the same debug_symbols * to instructions with empty debug_symbols. */ -void module::propagate_replace_debug_symbols(instruction_ref rep_ins, const std::set& debug_symbols) +void module::propagate_replace_debug_symbols(instruction_ref rep_ins, + const std::set& debug_symbols) { if(starts_with(rep_ins->name(), "@")) - return ; + return; if(debug_symbols.empty()) return; rep_ins->add_debug_symbols(debug_symbols); diff --git a/src/program.cpp b/src/program.cpp index baa1459dc09..20418f88bfa 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -1370,7 +1370,8 @@ program& program::sort() return *this; } -bool operator==(const program& x, const program& y) { return migraphx::to_string(x) == migraphx::to_string(y); } +bool operator==(const program& x, const program& y) +{ return migraphx::to_string(x) == migraphx::to_string(y); } std::ostream& operator<<(std::ostream& os, const program& p) { diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 33008111895..bfe58be81ea 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -1644,7 +1644,8 @@ MIGRAPHX_BASIC_MATCHER(horiz_conv_dot, match::matcher_context& ctx, instruction_ }; // adding matched instructions to matcher_context to have their debug_symbols propagate - auto add_instructions_to_ctx = [&ctx](std::string key_prefix, std::vector ins_vec){ + auto add_instructions_to_ctx = [&ctx](std::string key_prefix, + std::vector ins_vec) { int count = 1; for(instruction_ref d : ins_vec) { @@ -1656,11 +1657,16 @@ MIGRAPHX_BASIC_MATCHER(horiz_conv_dot, match::matcher_context& ctx, instruction_ }; bool found_horiz = false; std::vector dots; - std::copy_if(ins->outputs().begin(), ins->outputs().end(), std::back_inserter(dots), pred("dot")); + std::copy_if( + ins->outputs().begin(), ins->outputs().end(), std::back_inserter(dots), pred("dot")); std::vector qdots; - std::copy_if(ins->outputs().begin(), ins->outputs().end(), std::back_inserter(dots), pred("quant_dot")); + std::copy_if( + ins->outputs().begin(), ins->outputs().end(), std::back_inserter(dots), pred("quant_dot")); std::vector convs; - std::copy_if(ins->outputs().begin(), ins->outputs().end(), std::back_inserter(dots), pred("convolution")); + std::copy_if(ins->outputs().begin(), + ins->outputs().end(), + std::back_inserter(dots), + pred("convolution")); if(dots.size() >= 2) { found_horiz = true; diff --git a/test/debug_symbols_test.cpp b/test/debug_symbols_test.cpp index 1b4a223a10a..7dfce5bf1ad 100644 --- a/test/debug_symbols_test.cpp +++ b/test/debug_symbols_test.cpp @@ -44,11 +44,11 @@ TEST_CASE(pw_double_add) migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::program p1; { - auto* mm = p1.get_main_module(); + auto* mm = p1.get_main_module(); mm->set_use_debug_symbols(); - auto x = mm->add_parameter("x", s); - auto y = mm->add_parameter("y", s); - auto z = mm->add_parameter("z", s); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); migraphx::instruction_ref add1; { migraphx::scoped_debug_symbols guard0(*mm, {"onnx:add0"}); @@ -65,11 +65,11 @@ TEST_CASE(pw_double_add) migraphx::program p2; { - auto* mm = p2.get_main_module(); + auto* mm = p2.get_main_module(); mm->set_use_debug_symbols(); - auto x = mm->add_parameter("x", s); - auto y = mm->add_parameter("y", s); - auto z = mm->add_parameter("z", s); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); auto fadd = add_pointwise(p2, "main:pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) { auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); @@ -79,11 +79,10 @@ TEST_CASE(pw_double_add) mm->add_return({fadd}); } // BUG straight equality is not working even though both call migraphx::to_string - //EXPECT(p1 == p2); + // EXPECT(p1 == p2); EXPECT(to_string(p1) == to_string(p2)); } - // Diamond pattern: add1 feeds into both add2 and add3, which then feed // into add4. All four are fused into one pointwise op. Verifies that // symbols from every instruction in the diamond appear on the fused result. @@ -92,7 +91,7 @@ TEST_CASE(pw_used_twice_fused) migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::program p1; { - auto* mm = p1.get_main_module(); + auto* mm = p1.get_main_module(); mm->set_use_debug_symbols(); auto x = mm->add_parameter("x", s); auto y = mm->add_parameter("y", s); @@ -110,7 +109,7 @@ TEST_CASE(pw_used_twice_fused) migraphx::program p2; { - auto* mm = p2.get_main_module(); + auto* mm = p2.get_main_module(); mm->set_use_debug_symbols(); auto x = mm->add_parameter("x", s); auto y = mm->add_parameter("y", s); @@ -134,7 +133,7 @@ TEST_CASE(pw_used_twice_fused) TEST_CASE(horiz_fusion_dot) { auto type = migraphx::shape::int32_type; - auto s = migraphx::shape{type, {3, 2, 2}}; + auto s = migraphx::shape{type, {3, 2, 2}}; migraphx::module m1; { m1.set_use_debug_symbols(); @@ -143,9 +142,9 @@ TEST_CASE(horiz_fusion_dot) auto b = m1.add_literal(migraphx::generate_literal(s, 1)); auto x = m1.add_instruction(migraphx::make_op("dot"), input, a); x->add_debug_symbols({"gemm1"}); - auto y = m1.add_instruction(migraphx::make_op("dot"), input, b); + auto y = m1.add_instruction(migraphx::make_op("dot"), input, b); y->add_debug_symbols({"gemm2"}); - auto sum = m1.add_instruction(migraphx::make_op("add"), x, y); + auto sum = m1.add_instruction(migraphx::make_op("add"), x, y); sum->add_debug_symbols({"sum"}); m1.add_instruction(pass_op{}, sum); } @@ -159,9 +158,9 @@ TEST_CASE(horiz_fusion_dot) auto b = m2.add_literal(migraphx::generate_literal(s, 1)); auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), a, b); concat->add_debug_symbols({"gemm1", "gemm2"}); - auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat); + auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat); dot->add_debug_symbols({"gemm1", "gemm2"}); - auto x = m2.add_instruction( + auto x = m2.add_instruction( migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), dot); x->add_debug_symbols({"gemm1"}); auto y = m2.add_instruction( @@ -227,14 +226,14 @@ TEST_CASE(replace_with_insref_debug_symbols) migraphx::module m1; { m1.set_use_debug_symbols(); - auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 3}}); - auto zero = m1.add_literal( + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 3}}); + auto zero = m1.add_literal( migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {0.0f}}); - auto bcast = m1.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {2, 3}}}), zero); + auto bcast = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3}}}), zero); auto relu_x = m1.add_instruction(migraphx::make_op("relu"), x); relu_x->add_debug_symbols({"onnx:relu"}); - auto add_r = m1.add_instruction(migraphx::make_op("add"), relu_x, bcast); + auto add_r = m1.add_instruction(migraphx::make_op("add"), relu_x, bcast); add_r->add_debug_symbols({"onnx:add"}); m1.add_instruction(pass_op{}, add_r); } @@ -266,7 +265,7 @@ TEST_CASE(gather_replace_chain_debug_symbols) auto y = m1.add_parameter("y", s); auto relu_x = m1.add_instruction(migraphx::make_op("relu"), x); relu_x->add_debug_symbols({"onnx:relu"}); - auto add_r = m1.add_instruction(migraphx::make_op("add"), relu_x, y); + auto add_r = m1.add_instruction(migraphx::make_op("add"), relu_x, y); add_r->add_debug_symbols({"onnx:add"}); m1.add_instruction(pass_op{}, add_r); @@ -318,7 +317,7 @@ TEST_CASE(simplify_mul_add_debug_symbols) mul1->add_debug_symbols({"onnx:add", "onnx:mul"}); auto mul2 = m2.add_instruction(migraphx::make_op("mul"), two, one); mul2->add_debug_symbols({"onnx:add", "onnx:mul"}); - auto sum = m2.add_instruction(migraphx::make_op("add"), mul1, mul2); + auto sum = m2.add_instruction(migraphx::make_op("add"), mul1, mul2); sum->add_debug_symbols({"onnx:add", "onnx:mul"}); m2.add_instruction(pass_op{}, sum); } @@ -336,9 +335,9 @@ TEST_CASE(simplify_div_const_debug_symbols) { m1.set_use_debug_symbols(); auto x = m1.add_parameter("x", s); - auto c = m1.add_literal( - migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 3}}, - {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); + auto c = + m1.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 3}}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); auto div_r = m1.add_instruction(migraphx::make_op("div"), x, c); div_r->add_debug_symbols({"onnx:div"}); m1.add_instruction(pass_op{}, div_r); @@ -349,9 +348,9 @@ TEST_CASE(simplify_div_const_debug_symbols) { m2.set_use_debug_symbols(); auto x = m2.add_parameter("x", s); - auto c = m2.add_literal( - migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 3}}, - {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); + auto c = + m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 3}}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); auto recip = m2.add_instruction(migraphx::make_op("recip"), c); recip->add_debug_symbols({"onnx:div"}); auto mul_r = m2.add_instruction(migraphx::make_op("mul"), x, recip); @@ -401,10 +400,10 @@ TEST_CASE(pw_triple_add_fused) { auto* mm = p1.get_main_module(); mm->set_use_debug_symbols(); - auto x = mm->add_parameter("x", s); - auto y = mm->add_parameter("y", s); - auto z = mm->add_parameter("z", s); - auto w = mm->add_parameter("w", s); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto w = mm->add_parameter("w", s); auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); add1->add_debug_symbols({"onnx:add1"}); auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, z); From e66c848f31b25ebbe86523aecaad452b5a8ed459 Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 20 Feb 2026 15:39:20 -0600 Subject: [PATCH 012/107] tests and fixup --- src/include/migraphx/matcher.hpp | 35 +++--- test/debug_symbols_test.cpp | 189 +++++++++++++++++++++++++------ 2 files changed, 177 insertions(+), 47 deletions(-) diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index 39bf5a53816..e55557c7572 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -477,17 +477,20 @@ auto make_match_runner_with_trace(source_location location, Finder& f) // If its already invalid dont validate it again bool invalidated = validate and get_module(mod).validate() != get_module(mod).end(); - optional debug_guard; - + // Get debug symbols from matcher_context.instructions and matcher_result. std::set symbols; - const auto& mr_result_symbols = r.result->get_debug_symbols(); - symbols.insert(mr_result_symbols.begin(), mr_result_symbols.end()); - for(const auto& mr_ins : r.instructions) + if(get_module(mod).get_use_debug_symbols()) { - const auto& ins_symbols = mr_ins.second->get_debug_symbols(); - symbols.insert(ins_symbols.begin(), ins_symbols.end()); + const auto& mr_result_symbols = r.result->get_debug_symbols(); + symbols.insert(mr_result_symbols.begin(), mr_result_symbols.end()); + for(const auto& mr_ins : r.instructions) + { + const auto& ins_symbols = mr_ins.second->get_debug_symbols(); + symbols.insert(ins_symbols.begin(), ins_symbols.end()); + } } - + // `scoped_debug_symbols` for matcher apply. + optional debug_guard; if(not symbols.empty()) debug_guard.emplace(get_module(mod), symbols); @@ -527,14 +530,20 @@ auto make_match_runner(Finder& f) match::matcher_result r = match::match_instruction(get_module(mod), ins, m); if(r.result == get_module(mod).end()) return false; + + // Get debug symbols from matcher_context.instructions and matcher_result. std::set symbols; - const auto& mr_result_symbols = r.result->get_debug_symbols(); - symbols.insert(mr_result_symbols.begin(), mr_result_symbols.end()); - for(const auto& mr_ins : r.instructions) + if(get_module(mod).get_use_debug_symbols()) { - const auto& ins_symbols = mr_ins.second->get_debug_symbols(); - symbols.insert(ins_symbols.begin(), ins_symbols.end()); + const auto& mr_result_symbols = r.result->get_debug_symbols(); + symbols.insert(mr_result_symbols.begin(), mr_result_symbols.end()); + for(const auto& mr_ins : r.instructions) + { + const auto& ins_symbols = mr_ins.second->get_debug_symbols(); + symbols.insert(ins_symbols.begin(), ins_symbols.end()); + } } + // `scoped_debug_symbols` for matcher apply. optional debug_guard; if(not symbols.empty()) debug_guard.emplace(get_module(mod), symbols); diff --git a/test/debug_symbols_test.cpp b/test/debug_symbols_test.cpp index 7dfce5bf1ad..d0460420921 100644 --- a/test/debug_symbols_test.cpp +++ b/test/debug_symbols_test.cpp @@ -39,6 +39,18 @@ // Two adds fused into a single pointwise op via fuse_pointwise. // Both symbols should appear on the fused pointwise instruction. +// +// Before: After: +// +// x y x y z +// \ / \ | / +// add {add0} pointwise {add0, add1} +// | z | +// | / @return +// add {add1} +// | +// @return +// TEST_CASE(pw_double_add) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; @@ -86,6 +98,21 @@ TEST_CASE(pw_double_add) // Diamond pattern: add1 feeds into both add2 and add3, which then feed // into add4. All four are fused into one pointwise op. Verifies that // symbols from every instruction in the diamond appear on the fused result. +// +// Before: After: +// +// x y x y +// \ / \ / +// add1 {add1} pointwise {add1, add2, add3, add4} +// / \ | +// x y @return +// | | +// add2 add3 {add2} {add3} +// \ / +// add4 {add4} +// | +// @return +// TEST_CASE(pw_used_twice_fused) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; @@ -130,6 +157,22 @@ TEST_CASE(pw_used_twice_fused) // simplify_algebra. The two dots are fused into concat + single dot + slices. // Each new instruction inherits the symbols of the original dots it derives // from (e.g. the concat and fused dot carry both "gemm1" and "gemm2"). +// +// Before: After: +// +// input a input b a b +// \ / \ / \ / +// dot {g1} dot {g2} concat {g1, g2} +// \ / input | +// \ / \ | +// add {sum} dot {g1, g2} +// | / \ +// pass slice{g1} slice{g2} +// \ / +// add {sum} +// | +// pass +// TEST_CASE(horiz_fusion_dot) { auto type = migraphx::shape::int32_type; @@ -175,10 +218,18 @@ TEST_CASE(horiz_fusion_dot) } // Tests symbol propagation through add reassociation in simplify_algebra -// (find_double_add_lit_broadcast). Input: add0(add1(x, 1), add2(y, 2)) -// is reassociated to add(add(x, y), add(1, 2)). gather_replace_debug_symbols -// collects symbols from add0 and its dead-code inputs (add1, add2), then -// propagate_replace_debug_symbols spreads them to all replacement instructions. +// (find_double_add_lit_broadcast). Checks add(add(x,1), add(y,2)) -> (add(add(x,y), add(1,2)). +// +// Before: After: +// +// x 1 y 2 1 2 x y +// \ / \ / \ / \ / +// add1{a1} add2{a2} add{a0,a1,a2} add{a0,a1,a2} +// \ / \ / +// add0{a0} add{a0,a1,a2} +// | | +// pass pass +// TEST_CASE(simplify_add_debug_symbols) { migraphx::module m1; @@ -216,11 +267,20 @@ TEST_CASE(simplify_add_debug_symbols) EXPECT(to_string(m1.sort()) == to_string(m2.sort())); } -// Tests the replace_instruction(ins, rep) overload that replaces an -// instruction with an existing instruction_ref. find_unit_ops simplifies -// add(relu(x), broadcast(0)) to relu(x). gather_replace_debug_symbols -// walks the dead-code chain (add -> relu) collecting both symbols, then -// propagates them onto the surviving relu instruction. +// Tests the replace_instruction(ins, rep) overload via find_unit_ops which +// simplifies add(relu(x), broadcast(0)) to relu(x). +// +// Before: After: +// +// x 0 x +// | | | +// relu bcast relu {add, relu} +// {relu} (0.0) | +// \ / pass +// add {add} +// | +// pass +// TEST_CASE(replace_with_insref_debug_symbols) { migraphx::module m1; @@ -250,11 +310,19 @@ TEST_CASE(replace_with_insref_debug_symbols) EXPECT(to_string(m1.sort()) == to_string(m2.sort())); } -// Directly tests that gather_replace_debug_symbols walks the dead-code -// input chain. relu(x) feeds exclusively into add(relu, y). When add is -// replaced by mul(x, y) via replace_instruction(ins, rep), the gather step -// sees relu's sole consumer is add, so it collects "onnx:relu" along with -// "onnx:add" and propagates both to the replacement mul instruction. +// Tests that debug_symbols propagate through the dead-code chain. +// +// Before: After: +// +// x x y +// | \ / +// relu {relu} mul {add, relu} +// | y | +// | / pass +// add {add} +// | (relu becomes dead code, removed by DCE) +// pass +// TEST_CASE(gather_replace_chain_debug_symbols) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; @@ -287,10 +355,19 @@ TEST_CASE(gather_replace_chain_debug_symbols) } // Tests the distributive law transform in simplify_algebra (find_mul_add): -// mul(add(3, x), 2) -> add(mul(2, x), mul(2, 3)). The matcher's scoped -// guard captures "onnx:mul" from r.result, and gather_replace_debug_symbols -// also collects "onnx:add" from the dead-code add (its sole consumer was -// the mul being replaced). Both symbols propagate to all three new instructions. +// mul(add(3, x), 2) -> add(mul(2, x), mul(2, 3)). +// +// Before: After: +// +// 3 x 2 x 2 3 +// \ / \ / \ / +// add {add} mul{add,mul} mul{add,mul} +// | 2 \ / +// | / add{add,mul} +// mul {mul} | +// | pass +// pass +// TEST_CASE(simplify_mul_add_debug_symbols) { migraphx::module m1; @@ -325,9 +402,19 @@ TEST_CASE(simplify_mul_add_debug_symbols) } // Tests symbol propagation through find_div_const in simplify_algebra: -// div(x, c) -> mul(x, recip(c)). The matcher's scoped guard sets "onnx:div" -// as the active symbol, so recip gets it via insert_instruction. The -// replace_instruction propagation then confirms "onnx:div" on the mul as well. +// div(x, c) -> mul(x, recip(c)). +// +// Before: After: +// +// x c c +// \ / | +// div {div} recip {div} +// | x | +// pass \ | +// mul {div} +// | +// pass +// TEST_CASE(simplify_div_const_debug_symbols) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; @@ -363,8 +450,15 @@ TEST_CASE(simplify_div_const_debug_symbols) // Unit test for the scoped_debug_symbols RAII guard's save/restore behavior. // An outer guard sets "outer", a nested inner guard temporarily replaces it // with "inner", and after the inner guard's destructor runs the outer symbols -// are restored. Instructions created in each scope should carry only the -// symbols active at the time they were added. +// are restored. Instructions created in each scope carry only the symbols +// active at the time they were added. +// +// scope active symbols instruction gets symbol +// ------- ---------------- ----------- ----------- +// outer {"outer"} add1(x, y) {"outer"} +// inner {"inner"} add2(add1, x) {"inner"} +// outer {"outer"} add3(add2, y) {"outer"} +// TEST_CASE(scoped_debug_symbols_nesting) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; @@ -390,9 +484,24 @@ TEST_CASE(scoped_debug_symbols_nesting) EXPECT(add3->get_debug_symbols() == std::set{"outer"}); } -// Three sequential adds fused into a single pointwise op via fuse_pointwise: -// add(add(add(x, y), z), w). All three ONNX node symbols should appear on -// the fused pointwise instruction. Extends pw_double_add to a longer chain. +// Three sequential adds fused into a single pointwise op via fuse_pointwise. +// All three ONNX node symbols should appear on the fused pointwise instruction. +// Extends pw_double_add to a longer chain. +// +// Before: After: +// +// x y x y z w +// \ / \ | | / +// add {add1} pointwise {add1, add2, add3} +// | z | +// | / @return +// add {add2} +// | w +// | / +// add {add3} +// | +// @return +// TEST_CASE(pw_triple_add_fused) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; @@ -435,10 +544,21 @@ TEST_CASE(pw_triple_add_fused) } // Same add-reassociation pattern as simplify_add_debug_symbols but with -// set_use_debug_symbols() NOT called. The matcher's scoped guard still tags -// new instructions with the matched result's symbols ("onnx:add0") via -// insert_instruction, but replace_instruction skips gather/propagate when -// the flag is off. Dead-code symbols "onnx:add1" and "onnx:add2" are lost. +// set_use_debug_symbols() NOT called. +// +// Before: After (flag OFF): +// +// x 1 y 2 1 2 x y +// \ / \ / \ / \ / +// add1{a1} add2{a2} add{} add{} +// \ / \ / +// add0{a0} add{a0} +// | | +// pass pass +// +// (compare with simplify_add_debug_symbols where flag ON +// gives {a0, a1, a2} on every instruction) +// TEST_CASE(no_propagation_without_flag) { migraphx::module m1; @@ -464,9 +584,7 @@ TEST_CASE(no_propagation_without_flag) auto one = m2.add_literal(1); auto two = m2.add_literal(2); auto sum1 = m2.add_instruction(migraphx::make_op("add"), one, two); - sum1->add_debug_symbols({"onnx:add0"}); auto sum2 = m2.add_instruction(migraphx::make_op("add"), x, y); - sum2->add_debug_symbols({"onnx:add0"}); auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum2, sum1); sum3->add_debug_symbols({"onnx:add0"}); m2.add_instruction(pass_op{}, sum3); @@ -475,8 +593,11 @@ TEST_CASE(no_propagation_without_flag) } // Verifies that debug symbols appear in the module's printed/serialized -// output using the expected "/* sym_a, sym_b */" comment format produced -// by instruction::print. +// output using the expected comment format produced by instruction::print. +// +// Printed output includes: +// @2 = add(@0,@1) -> float_type, {2, 3} /* sym_a, sym_b */ +// TEST_CASE(debug_symbols_in_print) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; From a769c1116db853c4ee2e954b65a95b13e33107ab Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 20 Feb 2026 16:00:41 -0600 Subject: [PATCH 013/107] Add docs changes --- CHANGELOG.md | 1 + docs/reference/MIGraphX-dev-env-vars.rst | 8 ++++++++ src/instruction.cpp | 2 +- 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cf8b9c00c67..55880989be1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ Full documentation for MIGraphX is available at * Added a dedicated logger for MIGraphX. * [Linux] Use HSA API to query number of chiplets for architectures where this is applicable (ex. gfx90a). +* Added debug symbols for MIGraphX instructions such that parsed and compiled instructions can be tracked back to their ONNX origin node node with MIGRAPHX_ENABLE_DEBUG_SYMBOLS (#4626) ### Changed diff --git a/docs/reference/MIGraphX-dev-env-vars.rst b/docs/reference/MIGraphX-dev-env-vars.rst index 87bcb91af93..7253505ede9 100644 --- a/docs/reference/MIGraphX-dev-env-vars.rst +++ b/docs/reference/MIGraphX-dev-env-vars.rst @@ -514,6 +514,14 @@ Compilation tracing | Default: Quantization parameters aren't printed. + * - | ``MIGRAPHX_ENABLE_DEBUG_SYMBOLS`` + | Adds parsing and propagating of debug symbols through compiler passes such that the origin of instructions + | can be more easily determined. For ONNX models, the debug symbols are the ONNX node names. + + - | ``1``: Enable parsing and propagating debug symbols. + + | Default: Debug symbols are not parsed nor propagated. + MLIR ************************** diff --git a/src/instruction.cpp b/src/instruction.cpp index dbd662f41a4..c63a2453e40 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -447,7 +447,7 @@ void instruction::print(std::ostream& os, if(ins->target_id != 0) os << ", target_id=" << ins->target_id; - // print debug symbols if enabled + // print debug symbols if they exist if(not ins->debug_symbols.empty()) { os << " /* " << join_strings(ins->debug_symbols, ", ") << " */"; From 7321509e002dd85d1a40b1ce32176c390d38ae49 Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 20 Feb 2026 16:05:22 -0600 Subject: [PATCH 014/107] Have gather_replace_debug_symbols traverse empty Traverse over ins with empty debug symbols also just in case. Probably will not occur however --- src/module.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index adff2518d92..fafa3a470ca 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -379,8 +379,6 @@ std::set gather_replace_debug_symbols(instruction_ref old_ins) if(starts_with(old_ins->name(), "@")) return debug_symbols; const auto& old_ins_debug = old_ins->get_debug_symbols(); - if(old_ins_debug.empty()) - return debug_symbols; debug_symbols.insert(old_ins_debug.begin(), old_ins_debug.end()); for(auto input : old_ins->inputs()) { From 2794a6695d9431b4a42c546a02447f5457e1d4ab Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 20 Feb 2026 16:13:03 -0600 Subject: [PATCH 015/107] Copilot suggestions --- src/onnx/onnx_parser.cpp | 3 ++- src/simplify_algebra.cpp | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index e482a642f8d..d83783fbd67 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -596,7 +596,8 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini node_info ninfo{get_attributes(node), output_num, node_name, mod}; if(mod->get_use_debug_symbols()) { - scoped_debug_symbols guard(*mod, {node.name()}); + std::string debug_symbol = node.name().empty() ? std::string("migx_uid:") + node_name : node.name(); + scoped_debug_symbols guard(*mod, {debug_symbol}); result = ops[node.op_type()](*this, ninfo, args); } else diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index a1b7b0fe5b6..2332824a203 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -1661,11 +1661,11 @@ MIGRAPHX_BASIC_MATCHER(horiz_conv_dot, match::matcher_context& ctx, instruction_ ins->outputs().begin(), ins->outputs().end(), std::back_inserter(dots), pred("dot")); std::vector qdots; std::copy_if( - ins->outputs().begin(), ins->outputs().end(), std::back_inserter(dots), pred("quant_dot")); + ins->outputs().begin(), ins->outputs().end(), std::back_inserter(qdots), pred("quant_dot")); std::vector convs; std::copy_if(ins->outputs().begin(), ins->outputs().end(), - std::back_inserter(dots), + std::back_inserter(convs), pred("convolution")); if(dots.size() >= 2) { From 69c787a40e1e33d1be53660ff165d62e83b7888e Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 20 Feb 2026 16:15:30 -0600 Subject: [PATCH 016/107] more copilot --- src/simplify_algebra.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 2332824a203..a99986d1234 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -1645,7 +1645,7 @@ MIGRAPHX_BASIC_MATCHER(horiz_conv_dot, match::matcher_context& ctx, instruction_ // adding matched instructions to matcher_context to have their debug_symbols propagate auto add_instructions_to_ctx = [&ctx](std::string key_prefix, - std::vector ins_vec) { + const std::vector& ins_vec) { int count = 1; for(instruction_ref d : ins_vec) { From d90382c0b1ee5e985343638021ceb845fa43c521 Mon Sep 17 00:00:00 2001 From: Charlie Lin Date: Fri, 20 Feb 2026 17:29:53 -0500 Subject: [PATCH 017/107] Update CHANGELOG.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 55880989be1..1694bf695e9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ Full documentation for MIGraphX is available at * Added a dedicated logger for MIGraphX. * [Linux] Use HSA API to query number of chiplets for architectures where this is applicable (ex. gfx90a). -* Added debug symbols for MIGraphX instructions such that parsed and compiled instructions can be tracked back to their ONNX origin node node with MIGRAPHX_ENABLE_DEBUG_SYMBOLS (#4626) +* Added debug symbols for MIGraphX instructions such that parsed and compiled instructions can be tracked back to their ONNX origin node with MIGRAPHX_ENABLE_DEBUG_SYMBOLS (#4626) ### Changed From 107a7e78a7b93497ee90e35b4c05352e7c95e70f Mon Sep 17 00:00:00 2001 From: Charlie Lin Date: Fri, 20 Feb 2026 17:30:04 -0500 Subject: [PATCH 018/107] Update src/module.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/module.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/module.cpp b/src/module.cpp index fafa3a470ca..0c59b3c4330 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -393,7 +393,7 @@ std::set gather_replace_debug_symbols(instruction_ref old_ins) } /** - * Add gathered debug_symbols to rep_ins and traverse it's inputs to add the same debug_symbols + * Add gathered debug_symbols to rep_ins and traverse its inputs to add the same debug_symbols * to instructions with empty debug_symbols. */ void module::propagate_replace_debug_symbols(instruction_ref rep_ins, From f9bb2c8aab19dadc3c787f0389a2d1a64f1cee06 Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 23 Feb 2026 16:30:35 -0600 Subject: [PATCH 019/107] Revert matcher.hpp --- src/include/migraphx/matcher.hpp | 38 -------------------------------- 1 file changed, 38 deletions(-) diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index e55557c7572..36ad2fd9052 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -476,24 +476,6 @@ auto make_match_runner_with_trace(source_location location, Finder& f) } // If its already invalid dont validate it again bool invalidated = validate and get_module(mod).validate() != get_module(mod).end(); - - // Get debug symbols from matcher_context.instructions and matcher_result. - std::set symbols; - if(get_module(mod).get_use_debug_symbols()) - { - const auto& mr_result_symbols = r.result->get_debug_symbols(); - symbols.insert(mr_result_symbols.begin(), mr_result_symbols.end()); - for(const auto& mr_ins : r.instructions) - { - const auto& ins_symbols = mr_ins.second->get_debug_symbols(); - symbols.insert(ins_symbols.begin(), ins_symbols.end()); - } - } - // `scoped_debug_symbols` for matcher apply. - optional debug_guard; - if(not symbols.empty()) - debug_guard.emplace(get_module(mod), symbols); - if(trace_enabled) { if(trace > 1) @@ -530,23 +512,6 @@ auto make_match_runner(Finder& f) match::matcher_result r = match::match_instruction(get_module(mod), ins, m); if(r.result == get_module(mod).end()) return false; - - // Get debug symbols from matcher_context.instructions and matcher_result. - std::set symbols; - if(get_module(mod).get_use_debug_symbols()) - { - const auto& mr_result_symbols = r.result->get_debug_symbols(); - symbols.insert(mr_result_symbols.begin(), mr_result_symbols.end()); - for(const auto& mr_ins : r.instructions) - { - const auto& ins_symbols = mr_ins.second->get_debug_symbols(); - symbols.insert(ins_symbols.begin(), ins_symbols.end()); - } - } - // `scoped_debug_symbols` for matcher apply. - optional debug_guard; - if(not symbols.empty()) - debug_guard.emplace(get_module(mod), symbols); f.apply(mod, r); return true; }; @@ -615,9 +580,6 @@ void find_matches_for(source_location location, Mod& mod, instruction_ref ins, M } // If its already invalid dont validate it again bool invalidated = validate and get_module(mod).validate() != get_module(mod).end(); - optional debug_guard; - if(not r.result->get_debug_symbols().empty()) - debug_guard.emplace(get_module(mod), r.result->get_debug_symbols()); auto apply_time = time>([&] { m.apply(mod, r); }); if(time_matchers or trace_for) From 3d3bc4272a5227cf5b298efe932a60631870e91c Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 23 Feb 2026 16:37:00 -0600 Subject: [PATCH 020/107] Revert env_var doc --- docs/reference/MIGraphX-dev-env-vars.rst | 8 -------- 1 file changed, 8 deletions(-) diff --git a/docs/reference/MIGraphX-dev-env-vars.rst b/docs/reference/MIGraphX-dev-env-vars.rst index 7253505ede9..87bcb91af93 100644 --- a/docs/reference/MIGraphX-dev-env-vars.rst +++ b/docs/reference/MIGraphX-dev-env-vars.rst @@ -514,14 +514,6 @@ Compilation tracing | Default: Quantization parameters aren't printed. - * - | ``MIGRAPHX_ENABLE_DEBUG_SYMBOLS`` - | Adds parsing and propagating of debug symbols through compiler passes such that the origin of instructions - | can be more easily determined. For ONNX models, the debug symbols are the ONNX node names. - - - | ``1``: Enable parsing and propagating debug symbols. - - | Default: Debug symbols are not parsed nor propagated. - MLIR ************************** From 66b0a2f60ce2075f5c0f9193d2fdababaada3ff4 Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 23 Feb 2026 17:21:59 -0600 Subject: [PATCH 021/107] Review progress --- CHANGELOG.md | 2 +- src/include/migraphx/module.hpp | 21 +-- src/instruction.cpp | 3 - src/module.cpp | 153 ++++++------------ .../include/migraphx/onnx/onnx_parser.hpp | 1 + src/onnx/onnx_parser.cpp | 49 +++--- src/simplify_algebra.cpp | 124 +------------- 7 files changed, 86 insertions(+), 267 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 55880989be1..8e795276836 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ Full documentation for MIGraphX is available at * Added a dedicated logger for MIGraphX. * [Linux] Use HSA API to query number of chiplets for architectures where this is applicable (ex. gfx90a). -* Added debug symbols for MIGraphX instructions such that parsed and compiled instructions can be tracked back to their ONNX origin node node with MIGRAPHX_ENABLE_DEBUG_SYMBOLS (#4626) +* Added debug symbols for MIGraphX instructions such that parsed and compiled instructions can be tracked back to their ONNX origin node (#4626) ### Changed diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 1e3261d4b61..96112b07eda 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -53,20 +53,6 @@ using ins_dep_map = std::unordered_map symbols); - ~scoped_debug_symbols(); - scoped_debug_symbols(const scoped_debug_symbols&) = delete; - scoped_debug_symbols& operator=(const scoped_debug_symbols&) = delete; - scoped_debug_symbols(scoped_debug_symbols&& other) noexcept; - scoped_debug_symbols& operator=(scoped_debug_symbols&& other) noexcept; - - private: - module* mod; - std::set previous; -}; - /** * @brief Stores the instruction stream */ @@ -78,7 +64,7 @@ struct MIGRAPHX_EXPORT module const std::vector& inputs, const std::vector& mod_args)>; - module(const std::string& name = "", bool use_debug_symbols = false); + module(const std::string& name = ""); // move constructor module(module&&) noexcept; @@ -96,8 +82,9 @@ struct MIGRAPHX_EXPORT module bool bypass() const; void set_bypass(bool b = true); - bool get_use_debug_symbols() const; - void set_use_debug_symbols(bool b = true); + bool has_debug_symbols() const; + void add_debug_symbols(instruction_ref ins, std::set symbols); + void clear_debug_symbols(instruction_ref ins); template {}...)> instruction_ref add_instruction(operation op, Ts... args) diff --git a/src/instruction.cpp b/src/instruction.cpp index c63a2453e40..406c678242b 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -192,9 +192,6 @@ const std::vector& instruction::outputs() const { return output const std::set& instruction::get_debug_symbols() const { return debug_symbols; } -void instruction::add_debug_symbols(const std::set& symbols) -{ debug_symbols.insert(symbols.begin(), symbols.end()); } - bool operator==(const instruction& x, const instruction& y) { if(not std::equal(x.arguments.begin(), diff --git a/src/module.cpp b/src/module.cpp index fafa3a470ca..0be35410025 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -63,7 +63,7 @@ struct module_impl bool bypass = false; // used for skipping compiler passes bit_signal<64> changed{}; bool use_debug_symbols = false; - std::set active_debug_symbols; + std::size_t num_debug_symbols = 0; bool contains(instruction_ref ins) const { @@ -132,46 +132,8 @@ struct module_impl const operation& get_operation(instruction_ref ins) { return ins->get_operator(); } -scoped_debug_symbols::scoped_debug_symbols(module& m, std::set symbols) - : mod(&m), previous(std::move(m.impl->active_debug_symbols)) -{ mod->impl->active_debug_symbols = std::move(symbols); } - -scoped_debug_symbols::~scoped_debug_symbols() -{ - if(mod != nullptr) - mod->impl->active_debug_symbols = std::move(previous); -} - -scoped_debug_symbols::scoped_debug_symbols(scoped_debug_symbols&& other) noexcept - : mod(other.mod), previous(std::move(other.previous)) -{ other.mod = nullptr; } - -scoped_debug_symbols& scoped_debug_symbols::operator=(scoped_debug_symbols&& other) noexcept -{ - if(this != &other) - { - if(mod != nullptr) - mod->impl->active_debug_symbols = std::move(previous); - mod = other.mod; - previous = std::move(other.previous); - other.mod = nullptr; - } - return *this; -} - -module::module(const std::string& name, - bool use_debug_symbols) :impl(std::make_unique()) -{ - impl->name = name; - if(enabled(MIGRAPHX_ENABLE_DEBUG_SYMBOLS{})) - { - impl->use_debug_symbols = true; - } - else - { - impl->use_debug_symbols = use_debug_symbols; - } -} +module::module(const std::string& name) :impl(std::make_unique()) +{ impl->name = name; } module::module(module&&) noexcept = default; module::~module() noexcept = default; @@ -193,8 +155,7 @@ void module::set_name(const std::string& name) { impl->name = name; } bool module::bypass() const { return impl->bypass; } void module::set_bypass(bool b) { impl->bypass = b; } -bool module::get_use_debug_symbols() const { return impl->use_debug_symbols; } -void module::set_use_debug_symbols(bool b) { impl->use_debug_symbols = b; } +bool module::has_debug_symbols() const { return impl->has_debug_symbols; } void module::assign(const module& m) { @@ -343,8 +304,6 @@ instruction_ref module::insert_instruction(instruction_ref ins, auto result = impl->insert(ins, {op, r, std::move(args)}); instruction::backreference(result); assert(result->valid(begin())); - if(not impl->active_debug_symbols.empty()) - result->add_debug_symbols(impl->active_debug_symbols); return result; } @@ -364,52 +323,54 @@ instruction_ref module::insert_instruction(instruction_ref ins, auto result = impl->insert(ins, {op, out_shape, std::move(args), std::move(module_args)}); instruction::backreference(result); assert(result->valid(begin())); - if(not impl->active_debug_symbols.empty()) - result->add_debug_symbols(impl->active_debug_symbols); return result; } /** - * old_ins : instruction that will be replaced - * Traverse inputs of old_ins and gather debug_symbols of instructions that will become dead code. - */ -std::set gather_replace_debug_symbols(instruction_ref old_ins) -{ - std::set debug_symbols; - if(starts_with(old_ins->name(), "@")) - return debug_symbols; - const auto& old_ins_debug = old_ins->get_debug_symbols(); - debug_symbols.insert(old_ins_debug.begin(), old_ins_debug.end()); - for(auto input : old_ins->inputs()) - { - // check if only output to old_ins - if(input->outputs().size() == 1 and input->outputs().at(0) == old_ins) + * Traverse inputs of `ins` and gather instructions that output only to `ins` (would become deadcode + * without `ins`). + **/ +static std::unordered_set gather_splice(module_ref m, instruction_ref ins) +{ + std::unordered_set result = {ins}; + // TODO: add visited list + fix([&](auto self, const std::vector& inputs) { + for(auto input : inputs) { - const auto& gdebug = gather_replace_debug_symbols(input); - debug_symbols.insert(gdebug.begin(), gdebug.end()); + if(not m->has_instruction(input)) + continue; + if(contains(result, input)) + continue; + if(any_of(input->outputs(), + [&](instruction_ref output) { return not contains(result, output); })) + continue; + result.insert(input); + self(input->inputs()); } - } - return debug_symbols; + }); + return result; } -/** - * Add gathered debug_symbols to rep_ins and traverse it's inputs to add the same debug_symbols - * to instructions with empty debug_symbols. - */ -void module::propagate_replace_debug_symbols(instruction_ref rep_ins, - const std::set& debug_symbols) +void propagate_debug_symbols(module_ref m, + bool has_debug_symbols, + instruction_ref ins, + instruction_ref rep) { - if(starts_with(rep_ins->name(), "@")) - return; - if(debug_symbols.empty()) - return; - rep_ins->add_debug_symbols(debug_symbols); - for(auto input : rep_ins->inputs()) + if(has_debug_symbols) { - auto input_ds = input->get_debug_symbols(); - if(input_ds.empty() or input_ds == impl->active_debug_symbols) + auto old_splice = gather_splice(m, ins); + auto new_splice = gather_splice(m, rep); + std::unordered_set symbols; + for(auto i : old_splice) { - propagate_replace_debug_symbols(input, debug_symbols); + copy(ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); + } + for(auto i : new_splice) + { + for(const auto& symbol : symbols) + { + m->add_debug_symbol(i, symbol); + } } } } @@ -423,16 +384,8 @@ instruction_ref module::replace_instruction(instruction_ref ins, assert(not starts_with(op.name(), "@")); shape r = compute_shape(op, args); - if(get_use_debug_symbols()) - { - auto debug_symbols = gather_replace_debug_symbols(ins); - instruction::replace(ins, op, r, std::move(args)); - propagate_replace_debug_symbols(ins, debug_symbols); - } - else - { - instruction::replace(ins, op, r, std::move(args)); - } + propagate_debug_symbols(m, m.has_debug_symbols(), ins, rep); + instruction::replace(ins, op, r, std::move(args)); assert(ins->valid(begin())); return ins; } @@ -446,16 +399,8 @@ instruction_ref module::replace_instruction(instruction_ref ins, assert(has_instruction(ins)); assert(not starts_with(op.name(), "@")); auto out_shape = compute_shape(op, args, module_args); - if(get_use_debug_symbols()) - { - auto debug_symbols = gather_replace_debug_symbols(ins); - instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args)); - propagate_replace_debug_symbols(ins, debug_symbols); - } - else - { - instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args)); - } + propagate_debug_symbols(m, m.has_debug_symbols(), ins, rep); + instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args)); assert(ins->valid(begin())); return ins; } @@ -479,9 +424,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref return rep; } - std::set debug_symbols; - if(get_use_debug_symbols()) - debug_symbols = gather_replace_debug_symbols(ins); + propagate_debug_symbols(m, m.has_debug_symbols(), ins, rep); // Make a copy of outputs which can be changed when calling replace_argument auto outputs = ins->outputs(); for(auto out : outputs) @@ -493,8 +436,6 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref } assert(out->valid(begin())); } - if(get_use_debug_symbols()) - propagate_replace_debug_symbols(rep, debug_symbols); // Replacement should not be dead code unless its the last instruction assert(not rep->outputs().empty() or rep == std::prev(end())); @@ -698,8 +639,6 @@ instruction_ref module::insert_literal(instruction_ref ins, literal l) { impl->emplace(ins, std::move(l)); auto result = std::prev(ins); - if(not impl->active_debug_symbols.empty()) - result->add_debug_symbols(impl->active_debug_symbols); return result; } diff --git a/src/onnx/include/migraphx/onnx/onnx_parser.hpp b/src/onnx/include/migraphx/onnx/onnx_parser.hpp index e81544a1683..4f58cd085a9 100644 --- a/src/onnx/include/migraphx/onnx/onnx_parser.hpp +++ b/src/onnx/include/migraphx/onnx/onnx_parser.hpp @@ -104,6 +104,7 @@ struct onnx_parser std::unordered_map> map_dyn_input_dims; bool use_dyn_output = false; bool skip_unknown_operators = false; + bool use_debug_symbols = false; int64_t max_loop_iterations = 10; int64_t limit_max_iterations = std::numeric_limits::max(); int64_t opset_version = 13; diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index d83783fbd67..f295f60559c 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -44,7 +44,6 @@ inline namespace MIGRAPHX_INLINE_NS { namespace onnx { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_ONNX_PARSER) -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_DEBUG_SYMBOLS) static shape shape_from_dyn_dims(shape::type_t shape_type, const std::vector& dyn_dims) @@ -310,9 +309,13 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model) return version; } -static void print_added_instructions(module* mod, - const std::vector& args, - const std::vector& result) +/** + * Get the instructions added by the parser not in `args`. + * Does a DFS through inputs of result up to the instructions `args`. + */ +static std::vector +get_added_instructions(const std::vector& args, + const std::vector& result) { // Print instructions added by the parser not in args std::vector added_instructions; @@ -327,7 +330,7 @@ static void print_added_instructions(module* mod, added_instructions.push_back(ins); } })(result); - mod->debug_print(added_instructions); + return added_instructions; } static bool is_type_packed_int4(const onnx::TensorProto& t) @@ -583,6 +586,7 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini std::vector result; std::size_t output_num = node.output().size(); + std::string node_name = node.op_type() + "_" + std::to_string(mod->size()); if(ops.count(node.op_type()) == 0) { if(skip_unknown_operators) @@ -592,31 +596,36 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini } else { - std::string node_name = node.op_type() + "_" + std::to_string(mod->size()); - node_info ninfo{get_attributes(node), output_num, node_name, mod}; - if(mod->get_use_debug_symbols()) - { - std::string debug_symbol = node.name().empty() ? std::string("migx_uid:") + node_name : node.name(); - scoped_debug_symbols guard(*mod, {debug_symbol}); - result = ops[node.op_type()](*this, ninfo, args); - } - else - { - result = ops[node.op_type()](*this, ninfo, args); - } + result = ops[node.op_type()]( + *this, {get_attributes(node), output_num, node_name, mod}, args); } - output_num = std::min(output_num, result.size()); std::transform(node.output().begin(), node.output().begin() + output_num, result.begin(), std::inserter(instructions, instructions.end()), [](auto&& x, auto&& y) { return std::make_pair(x, y); }); - + auto added_instructions = get_added_instructions(args, result); + if(this->use_debug_symbols) + { + std::string debug_symbol = + node.name().empty() ? std::string("migx_uid:") + node_name : node.name(); + for(auto ins : added_instructions) + { + mod->add_debug_symbol(ins, debug_symbol); + } + } if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{})) { - print_added_instructions(mod, args, result); + mod->debug_print(added_instructions); } + + output_num = std::min(output_num, result.size()); + std::transform(node.output().begin(), + node.output().begin() + output_num, + result.begin(), + std::inserter(instructions, instructions.end()), + [](auto&& x, auto&& y) { return std::make_pair(x, y); }); } // Find instructions corresponding to the output diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index a99986d1234..d77fc63bf88 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -1633,7 +1633,7 @@ struct find_add_convs } }; -MIGRAPHX_BASIC_MATCHER(horiz_conv_dot, match::matcher_context& ctx, instruction_ref ins) +MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins) { // checking size to prevent matching block quantized quant_dot for now auto pred = [&](auto name) { @@ -1642,52 +1642,10 @@ MIGRAPHX_BASIC_MATCHER(horiz_conv_dot, match::matcher_context& ctx, instruction_ i->inputs().at(1)->can_eval() and i->inputs().size() == 2; }; }; - - // adding matched instructions to matcher_context to have their debug_symbols propagate - auto add_instructions_to_ctx = [&ctx](std::string key_prefix, - const std::vector& ins_vec) { - int count = 1; - for(instruction_ref d : ins_vec) - { - std::stringstream ss; - ss << key_prefix << "_" << count; - ctx.instructions[ss.str()] = d; - count++; - } - }; - bool found_horiz = false; - std::vector dots; - std::copy_if( - ins->outputs().begin(), ins->outputs().end(), std::back_inserter(dots), pred("dot")); - std::vector qdots; - std::copy_if( - ins->outputs().begin(), ins->outputs().end(), std::back_inserter(qdots), pred("quant_dot")); - std::vector convs; - std::copy_if(ins->outputs().begin(), - ins->outputs().end(), - std::back_inserter(convs), - pred("convolution")); - if(dots.size() >= 2) - { - found_horiz = true; - add_instructions_to_ctx("dot", dots); - } - else if(qdots.size() >= 2) - { - found_horiz = true; - add_instructions_to_ctx("qdot", qdots); - } - else if(convs.size() >= 2) - { - found_horiz = true; - add_instructions_to_ctx("conv", convs); - } - - if(found_horiz) - { - return {ins}; - } - return nullopt; + auto dots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot")); + auto qdots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("quant_dot")); + auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution")); + return (dots >= 2 or convs >= 2 or qdots >= 2); } struct find_conv_dot_horiz_fusion @@ -2160,83 +2118,11 @@ struct find_split_transpose } }; -// When a convolution's input is a spatially-broadcast constant (e.g. a bias -// vector broadcast to [N, IC, H, W] with stride-0 spatial dims), the full -// spatial convolution is redundant. Replace it with: -// W_reduced[oc,ic] = sum_{kh,kw} W[oc,ic,kh,kw] (reduce_sum) -// result = dot(input_2d, W_reduced^T) (tiny GEMM) -// multibroadcast result to the original output shape -struct find_conv_broadcast_input -{ - auto matcher() const - { - return match::name("convolution")(match::args( - match::name("broadcast", "multibroadcast")(match::args(match::any().bind("x"))) - .bind("bcast"), - match::is_constant().bind("w"))); - } - - void apply(module& m, const match::matcher_result& r) const - { - auto ins = r.result; - auto x_ins = r.instructions["x"]; - auto w_ins = r.instructions["w"]; - - if(ins->get_operator().to_value()["group"].to() != 1) - return; - - const auto& x_shape = x_ins->get_shape(); - const auto& w_shape = w_ins->get_shape(); - - const auto& x_lens = x_shape.lens(); - if(x_lens.size() > 2 and - std::any_of(x_lens.begin() + 2, x_lens.end(), [](auto l) { return l != 1; })) - return; - - auto oc = w_shape.lens()[0]; - auto ic = w_shape.lens()[1]; - - auto out_lens = ins->get_shape().lens(); - auto n = out_lens[0]; - - if(x_shape.elements() != n * ic) - return; - - auto ndim = w_shape.ndim(); - std::vector spatial_axes(ndim - 2); - std::iota(spatial_axes.begin(), spatial_axes.end(), 2); - - auto w_reduced = - m.insert_instruction(ins, make_op("reduce_sum", {{"axes", spatial_axes}}), w_ins); - auto w_2d = m.insert_instruction( - ins, make_op("reshape", {{"dims", std::vector{oc, ic}}}), w_reduced); - auto w_t = m.insert_instruction( - ins, make_op("transpose", {{"permutation", std::vector{1, 0}}}), w_2d); - - instruction_ref x_2d; - if(x_shape.ndim() == 1 and n == 1) - x_2d = m.insert_instruction( - ins, make_op("unsqueeze", {{"axes", std::vector{0}}}), x_ins); - else - x_2d = m.insert_instruction( - ins, make_op("reshape", {{"dims", std::vector{n, ic}}}), x_ins); - - auto dot_result = m.insert_instruction(ins, make_op("dot"), x_2d, w_t); - - auto dot_1d = m.insert_instruction( - ins, make_op("squeeze", {{"axes", std::vector{0}}}), dot_result); - - m.replace_instruction( - ins, make_op("broadcast", {{"axis", 1}, {"out_lens", out_lens}}), dot_1d); - } -}; - void simplify_algebra::apply(module& m) const { // Run simplifications multiple times m.repeat_while_changes(8, [&] { match::find_matches(m, - find_conv_broadcast_input{}, find_inner_broadcast{}, find_dot_broadcast{}, find_double_add_lit_broadcast{}, From b3c8a1cacaea1bcd4f4a9b2fb90b35a96d8cb828 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 24 Feb 2026 13:53:42 -0600 Subject: [PATCH 022/107] Module tracking number of instructions with debug_symbols --- src/include/migraphx/instruction.hpp | 4 ++++ src/include/migraphx/module.hpp | 5 ++++- src/instruction.cpp | 7 ++++++- src/module.cpp | 17 ++++++++++++++--- 4 files changed, 28 insertions(+), 5 deletions(-) diff --git a/src/include/migraphx/instruction.hpp b/src/include/migraphx/instruction.hpp index 1050775a660..855ded1e782 100644 --- a/src/include/migraphx/instruction.hpp +++ b/src/include/migraphx/instruction.hpp @@ -98,8 +98,12 @@ struct MIGRAPHX_EXPORT instruction const std::set& get_debug_symbols() const; + /// Avoid using directly because module will not track number of debug symbols void add_debug_symbols(const std::set& symbols); + /// Avoid using directly because module will not track number of debug symbols + void rm_debug_symbols(); + MIGRAPHX_EXPORT friend bool operator==(const instruction& x, const instruction& y); MIGRAPHX_EXPORT friend bool operator!=(const instruction& x, const instruction& y); diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 96112b07eda..6b2fe80e41a 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -82,9 +82,12 @@ struct MIGRAPHX_EXPORT module bool bypass() const; void set_bypass(bool b = true); + /// Number of instructions with debug symbols bool has_debug_symbols() const; + /// Merge given symbols with instruction's symbols void add_debug_symbols(instruction_ref ins, std::set symbols); - void clear_debug_symbols(instruction_ref ins); + /// Clear all debug symbols from instruction + void rm_debug_symbols(instruction_ref ins); template {}...)> instruction_ref add_instruction(operation op, Ts... args) diff --git a/src/instruction.cpp b/src/instruction.cpp index 406c678242b..c1ea60aa349 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -192,6 +192,11 @@ const std::vector& instruction::outputs() const { return output const std::set& instruction::get_debug_symbols() const { return debug_symbols; } +void instruction::add_debug_symbols(const std::set& symbols) +{ debug_symbols.insert(symbols.begin(), symbols.end()); } + +void instruction::rm_debug_symbols() { debug_symbols.clear(); } + bool operator==(const instruction& x, const instruction& y) { if(not std::equal(x.arguments.begin(), @@ -447,7 +452,7 @@ void instruction::print(std::ostream& os, // print debug symbols if they exist if(not ins->debug_symbols.empty()) { - os << " /* " << join_strings(ins->debug_symbols, ", ") << " */"; + os << " # " << join_strings(ins->debug_symbols, ", ") << " #"; } } diff --git a/src/module.cpp b/src/module.cpp index 0be35410025..790a9a8d248 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -62,8 +62,7 @@ struct module_impl uint32_t nparams = 0; bool bypass = false; // used for skipping compiler passes bit_signal<64> changed{}; - bool use_debug_symbols = false; - std::size_t num_debug_symbols = 0; + std::size_t num_debug_symbols = 0; // number of ins with debug symbols bool contains(instruction_ref ins) const { @@ -155,7 +154,19 @@ void module::set_name(const std::string& name) { impl->name = name; } bool module::bypass() const { return impl->bypass; } void module::set_bypass(bool b) { impl->bypass = b; } -bool module::has_debug_symbols() const { return impl->has_debug_symbols; } +bool module::has_debug_symbols() const { return impl->num_debug_symbols; } + +void module::add_debug_symbols(instruction_ref ins, std::set symbols) +{ + ins->add_debug_symbols(symbols); + impl->num_debug_symbols++; +} + +void module::rm_debug_symbols(instruction_ref) +{ + impl->num_debug_symbols--; + ins->rm_debug_symbols(symbols); +} void module::assign(const module& m) { From 5830120541039510088a5816c9ed8666f6ea7e12 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 24 Feb 2026 16:40:53 -0600 Subject: [PATCH 023/107] Cleanup --- src/include/migraphx/instruction.hpp | 5 ++- src/include/migraphx/module.hpp | 3 +- src/instruction.cpp | 5 +-- src/module.cpp | 48 ++++++++++++++++++---------- 4 files changed, 37 insertions(+), 24 deletions(-) diff --git a/src/include/migraphx/instruction.hpp b/src/include/migraphx/instruction.hpp index 855ded1e782..82f60dc9184 100644 --- a/src/include/migraphx/instruction.hpp +++ b/src/include/migraphx/instruction.hpp @@ -32,7 +32,6 @@ #include #include #include -#include #include #include @@ -96,10 +95,10 @@ struct MIGRAPHX_EXPORT instruction /// Where this instruction is used as an input to another instruction const std::vector& outputs() const; - const std::set& get_debug_symbols() const; + const std::unordered_set& get_debug_symbols() const; /// Avoid using directly because module will not track number of debug symbols - void add_debug_symbols(const std::set& symbols); + void add_debug_symbols(const std::unordered_set& symbols); /// Avoid using directly because module will not track number of debug symbols void rm_debug_symbols(); diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 6b2fe80e41a..f50a6fa1b3a 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -25,7 +25,6 @@ #define MIGRAPHX_GUARD_MIGRAPHLIB_MODULE_HPP #include -#include #include #include #include @@ -85,7 +84,7 @@ struct MIGRAPHX_EXPORT module /// Number of instructions with debug symbols bool has_debug_symbols() const; /// Merge given symbols with instruction's symbols - void add_debug_symbols(instruction_ref ins, std::set symbols); + void add_debug_symbols(instruction_ref ins, std::unordered_set symbols); /// Clear all debug symbols from instruction void rm_debug_symbols(instruction_ref ins); diff --git a/src/instruction.cpp b/src/instruction.cpp index c1ea60aa349..f5d0acb0015 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -190,9 +190,10 @@ const std::vector& instruction::module_inputs() const { return modul const std::vector& instruction::outputs() const { return output; } -const std::set& instruction::get_debug_symbols() const { return debug_symbols; } +const std::unordered_set& instruction::get_debug_symbols() const +{ return debug_symbols; } -void instruction::add_debug_symbols(const std::set& symbols) +void instruction::add_debug_symbols(const std::unordered_set& symbols) { debug_symbols.insert(symbols.begin(), symbols.end()); } void instruction::rm_debug_symbols() { debug_symbols.clear(); } diff --git a/src/module.cpp b/src/module.cpp index 790a9a8d248..f700ba86b01 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -156,16 +156,16 @@ void module::set_bypass(bool b) { impl->bypass = b; } bool module::has_debug_symbols() const { return impl->num_debug_symbols; } -void module::add_debug_symbols(instruction_ref ins, std::set symbols) +void module::add_debug_symbols(instruction_ref ins, std::unordered_set symbols) { ins->add_debug_symbols(symbols); impl->num_debug_symbols++; } -void module::rm_debug_symbols(instruction_ref) +void module::rm_debug_symbols(instruction_ref ins) { + ins->rm_debug_symbols(); impl->num_debug_symbols--; - ins->rm_debug_symbols(symbols); } void module::assign(const module& m) @@ -362,26 +362,23 @@ static std::unordered_set gather_splice(module_ref m, instructi return result; } -void propagate_debug_symbols(module_ref m, - bool has_debug_symbols, - instruction_ref ins, - instruction_ref rep) +static void propagate_debug_symbols(module_ref m, + bool has_debug_symbols, + instruction_ref ins, + instruction_ref rep) { if(has_debug_symbols) { auto old_splice = gather_splice(m, ins); auto new_splice = gather_splice(m, rep); std::unordered_set symbols; - for(auto i : old_splice) + for(auto old_ins : old_splice) { - copy(ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); + copy(old_ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); } - for(auto i : new_splice) + for(auto new_ins : new_splice) { - for(const auto& symbol : symbols) - { - m->add_debug_symbol(i, symbol); - } + m->add_debug_symbols(new_ins, symbols); } } } @@ -395,8 +392,26 @@ instruction_ref module::replace_instruction(instruction_ref ins, assert(not starts_with(op.name(), "@")); shape r = compute_shape(op, args); - propagate_debug_symbols(m, m.has_debug_symbols(), ins, rep); + std::unordered_set old_splice; + std::unordered_set new_splice; + if(has_debug_symbols()) + { + old_splice = gather_splice(this, ins); + } instruction::replace(ins, op, r, std::move(args)); + if(has_debug_symbols()) + { + new_splice = gather_splice(this, ins); + std::unordered_set symbols; + for(auto old_ins : old_splice) + { + copy(old_ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); + } + for(auto new_ins : new_splice) + { + this->add_debug_symbols(new_ins, symbols); + } + } assert(ins->valid(begin())); return ins; } @@ -649,8 +664,7 @@ instruction_ref module::add_return(std::vector args) instruction_ref module::insert_literal(instruction_ref ins, literal l) { impl->emplace(ins, std::move(l)); - auto result = std::prev(ins); - return result; + return std::prev(ins); } instruction_ref module::insert_parameter(instruction_ref ins, std::string name, shape s) From 593fc75a249488b3a4883af48074749166cf0093 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 24 Feb 2026 16:41:45 -0600 Subject: [PATCH 024/107] Revert simplify_algebra --- src/simplify_algebra.cpp | 72 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index d77fc63bf88..1cf1735efbc 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -2118,11 +2118,83 @@ struct find_split_transpose } }; +// When a convolution's input is a spatially-broadcast constant (e.g. a bias +// vector broadcast to [N, IC, H, W] with stride-0 spatial dims), the full +// spatial convolution is redundant. Replace it with: +// W_reduced[oc,ic] = sum_{kh,kw} W[oc,ic,kh,kw] (reduce_sum) +// result = dot(input_2d, W_reduced^T) (tiny GEMM) +// multibroadcast result to the original output shape +struct find_conv_broadcast_input +{ + auto matcher() const + { + return match::name("convolution")(match::args( + match::name("broadcast", "multibroadcast")(match::args(match::any().bind("x"))) + .bind("bcast"), + match::is_constant().bind("w"))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto x_ins = r.instructions["x"]; + auto w_ins = r.instructions["w"]; + + if(ins->get_operator().to_value()["group"].to() != 1) + return; + + const auto& x_shape = x_ins->get_shape(); + const auto& w_shape = w_ins->get_shape(); + + const auto& x_lens = x_shape.lens(); + if(x_lens.size() > 2 and + std::any_of(x_lens.begin() + 2, x_lens.end(), [](auto l) { return l != 1; })) + return; + + auto oc = w_shape.lens()[0]; + auto ic = w_shape.lens()[1]; + + auto out_lens = ins->get_shape().lens(); + auto n = out_lens[0]; + + if(x_shape.elements() != n * ic) + return; + + auto ndim = w_shape.ndim(); + std::vector spatial_axes(ndim - 2); + std::iota(spatial_axes.begin(), spatial_axes.end(), 2); + + auto w_reduced = + m.insert_instruction(ins, make_op("reduce_sum", {{"axes", spatial_axes}}), w_ins); + auto w_2d = m.insert_instruction( + ins, make_op("reshape", {{"dims", std::vector{oc, ic}}}), w_reduced); + auto w_t = m.insert_instruction( + ins, make_op("transpose", {{"permutation", std::vector{1, 0}}}), w_2d); + + instruction_ref x_2d; + if(x_shape.ndim() == 1 and n == 1) + x_2d = m.insert_instruction( + ins, make_op("unsqueeze", {{"axes", std::vector{0}}}), x_ins); + else + x_2d = m.insert_instruction( + ins, make_op("reshape", {{"dims", std::vector{n, ic}}}), x_ins); + + auto dot_result = m.insert_instruction(ins, make_op("dot"), x_2d, w_t); + + auto dot_1d = m.insert_instruction( + ins, make_op("squeeze", {{"axes", std::vector{0}}}), dot_result); + + m.replace_instruction( + ins, make_op("broadcast", {{"axis", 1}, {"out_lens", out_lens}}), dot_1d); + } +}; + void simplify_algebra::apply(module& m) const { // Run simplifications multiple times m.repeat_while_changes(8, [&] { match::find_matches(m, + find_conv_broadcast_input{}, find_inner_broadcast{}, find_dot_broadcast{}, find_double_add_lit_broadcast{}, From fbdd1e8fddb8b7b1a39bd275949afda482f3c9df Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 3 Mar 2026 14:37:25 -0600 Subject: [PATCH 025/107] Fix implementation and tests --- src/include/migraphx/instruction.hpp | 9 +- src/include/migraphx/module.hpp | 19 ++- src/instruction.cpp | 7 +- src/module.cpp | 129 ++++++++++----- src/onnx/onnx_parser.cpp | 2 +- src/simplify_algebra.cpp | 10 +- test/debug_symbols_test.cpp | 231 +++++++-------------------- 7 files changed, 180 insertions(+), 227 deletions(-) diff --git a/src/include/migraphx/instruction.hpp b/src/include/migraphx/instruction.hpp index 82f60dc9184..45cc5f9918c 100644 --- a/src/include/migraphx/instruction.hpp +++ b/src/include/migraphx/instruction.hpp @@ -31,8 +31,9 @@ #include #include #include -#include +#include #include +#include #include namespace migraphx { @@ -95,13 +96,13 @@ struct MIGRAPHX_EXPORT instruction /// Where this instruction is used as an input to another instruction const std::vector& outputs() const; - const std::unordered_set& get_debug_symbols() const; + const std::set& get_debug_symbols() const; /// Avoid using directly because module will not track number of debug symbols - void add_debug_symbols(const std::unordered_set& symbols); + void add_debug_symbols(const std::set& symbols); /// Avoid using directly because module will not track number of debug symbols - void rm_debug_symbols(); + void remove_debug_symbols(); MIGRAPHX_EXPORT friend bool operator==(const instruction& x, const instruction& y); diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index f50a6fa1b3a..bce7b4f75e9 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -81,12 +81,12 @@ struct MIGRAPHX_EXPORT module bool bypass() const; void set_bypass(bool b = true); - /// Number of instructions with debug symbols + /// If any instructions in this module have debug symbols bool has_debug_symbols() const; /// Merge given symbols with instruction's symbols - void add_debug_symbols(instruction_ref ins, std::unordered_set symbols); + void add_debug_symbols(instruction_ref ins, const std::set& symbols) const; /// Clear all debug symbols from instruction - void rm_debug_symbols(instruction_ref ins); + void remove_debug_symbols(instruction_ref ins) const; template {}...)> instruction_ref add_instruction(operation op, Ts... args) @@ -127,6 +127,19 @@ struct MIGRAPHX_EXPORT module instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep); + struct instruction_replacer + { + instruction_ref ins; + operation op; + std::vector args; + std::vector module_args; + }; + + /// Replaces an array of instructions within the same function to propertly handle debug symbols + /// propagation Returns vector of instruction_ref to replaced instructions + std::vector + batch_replace_instruction(const std::vector& replacers); + instruction_ref remove_instruction(instruction_ref ins); instruction_ref remove_instructions(instruction_ref first, instruction_ref last); diff --git a/src/instruction.cpp b/src/instruction.cpp index f5d0acb0015..547802a9783 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -190,13 +190,12 @@ const std::vector& instruction::module_inputs() const { return modul const std::vector& instruction::outputs() const { return output; } -const std::unordered_set& instruction::get_debug_symbols() const -{ return debug_symbols; } +const std::set& instruction::get_debug_symbols() const { return debug_symbols; } -void instruction::add_debug_symbols(const std::unordered_set& symbols) +void instruction::add_debug_symbols(const std::set& symbols) { debug_symbols.insert(symbols.begin(), symbols.end()); } -void instruction::rm_debug_symbols() { debug_symbols.clear(); } +void instruction::remove_debug_symbols() { debug_symbols.clear(); } bool operator==(const instruction& x, const instruction& y) { diff --git a/src/module.cpp b/src/module.cpp index f700ba86b01..4af989ab0ea 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -51,7 +51,6 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_FINALIZE) -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_DEBUG_SYMBOLS) struct module_impl { @@ -62,7 +61,7 @@ struct module_impl uint32_t nparams = 0; bool bypass = false; // used for skipping compiler passes bit_signal<64> changed{}; - std::size_t num_debug_symbols = 0; // number of ins with debug symbols + std::size_t num_ins_with_debug_symbols = 0; // number of ins with debug symbols bool contains(instruction_ref ins) const { @@ -154,18 +153,24 @@ void module::set_name(const std::string& name) { impl->name = name; } bool module::bypass() const { return impl->bypass; } void module::set_bypass(bool b) { impl->bypass = b; } -bool module::has_debug_symbols() const { return impl->num_debug_symbols; } +bool module::has_debug_symbols() const { return impl->num_ins_with_debug_symbols > 0; } -void module::add_debug_symbols(instruction_ref ins, std::unordered_set symbols) +void module::add_debug_symbols(instruction_ref ins, const std::set& symbols) const { + if(ins->get_debug_symbols().empty()) + { + impl->num_ins_with_debug_symbols++; + } ins->add_debug_symbols(symbols); - impl->num_debug_symbols++; } -void module::rm_debug_symbols(instruction_ref ins) +void module::remove_debug_symbols(instruction_ref ins) const { - ins->rm_debug_symbols(); - impl->num_debug_symbols--; + if(not ins->get_debug_symbols().empty() and impl->num_ins_with_debug_symbols > 0) + { + impl->num_ins_with_debug_symbols--; + } + ins->remove_debug_symbols(); } void module::assign(const module& m) @@ -341,15 +346,16 @@ instruction_ref module::insert_instruction(instruction_ref ins, * Traverse inputs of `ins` and gather instructions that output only to `ins` (would become deadcode * without `ins`). **/ -static std::unordered_set gather_splice(module_ref m, instruction_ref ins) +static std::unordered_set gather_splice(const_module_ref m, instruction_ref ins) { std::unordered_set result = {ins}; - // TODO: add visited list fix([&](auto self, const std::vector& inputs) { for(auto input : inputs) { if(not m->has_instruction(input)) continue; + if(starts_with(input->name(), "@")) + continue; if(contains(result, input)) continue; if(any_of(input->outputs(), @@ -358,31 +364,10 @@ static std::unordered_set gather_splice(module_ref m, instructi result.insert(input); self(input->inputs()); } - }); + })(ins->inputs()); return result; } -static void propagate_debug_symbols(module_ref m, - bool has_debug_symbols, - instruction_ref ins, - instruction_ref rep) -{ - if(has_debug_symbols) - { - auto old_splice = gather_splice(m, ins); - auto new_splice = gather_splice(m, rep); - std::unordered_set symbols; - for(auto old_ins : old_splice) - { - copy(old_ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); - } - for(auto new_ins : new_splice) - { - m->add_debug_symbols(new_ins, symbols); - } - } -} - instruction_ref module::replace_instruction(instruction_ref ins, const operation& op, std::vector args) MIGRAPHX_TIDY_CONST @@ -393,7 +378,6 @@ instruction_ref module::replace_instruction(instruction_ref ins, shape r = compute_shape(op, args); std::unordered_set old_splice; - std::unordered_set new_splice; if(has_debug_symbols()) { old_splice = gather_splice(this, ins); @@ -401,8 +385,8 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction::replace(ins, op, r, std::move(args)); if(has_debug_symbols()) { - new_splice = gather_splice(this, ins); - std::unordered_set symbols; + std::unordered_set new_splice = gather_splice(this, ins); + std::set symbols; for(auto old_ins : old_splice) { copy(old_ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); @@ -425,12 +409,47 @@ instruction_ref module::replace_instruction(instruction_ref ins, assert(has_instruction(ins)); assert(not starts_with(op.name(), "@")); auto out_shape = compute_shape(op, args, module_args); - propagate_debug_symbols(m, m.has_debug_symbols(), ins, rep); + std::unordered_set old_splice; + if(has_debug_symbols()) + { + old_splice = gather_splice(this, ins); + } instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args)); + if(has_debug_symbols()) + { + std::unordered_set new_splice = gather_splice(this, ins); + std::set symbols; + for(auto old_ins : old_splice) + { + copy(old_ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); + } + for(auto new_ins : new_splice) + { + this->add_debug_symbols(new_ins, symbols); + } + } assert(ins->valid(begin())); return ins; } +static void propagate_debug_symbols(const_module_ref m, instruction_ref ins, instruction_ref rep) +{ + if(m->has_debug_symbols()) + { + auto old_splice = gather_splice(m, ins); + auto new_splice = gather_splice(m, rep); + std::set symbols; + for(auto old_ins : old_splice) + { + copy(old_ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); + } + for(auto new_ins : new_splice) + { + m->add_debug_symbols(new_ins, symbols); + } + } +} + instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref rep) { impl->changed.notify(); @@ -450,7 +469,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref return rep; } - propagate_debug_symbols(m, m.has_debug_symbols(), ins, rep); + propagate_debug_symbols(this, ins, rep); // Make a copy of outputs which can be changed when calling replace_argument auto outputs = ins->outputs(); for(auto out : outputs) @@ -474,6 +493,42 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref return rep; } +std::vector +module::batch_replace_instruction(const std::vector& replacers) +{ + std::vector ret; + std::set symbols; + if(has_debug_symbols()) + { + // gather all previous debug symbols from splices + for(const auto& replacer : replacers) + { + std::unordered_set old_splice = gather_splice(this, replacer.ins); + for(auto old_ins : old_splice) + { + copy(old_ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); + } + } + } + + for(const auto& replacer : replacers) + { + auto out_shape = compute_shape(replacer.op, replacer.args, replacer.module_args); + instruction::replace( + replacer.ins, replacer.op, out_shape, replacer.args, replacer.module_args); + if(has_debug_symbols()) + { + std::unordered_set new_splice = gather_splice(this, replacer.ins); + for(auto new_ins : new_splice) + { + this->add_debug_symbols(new_ins, symbols); + } + } + ret.push_back(replacer.ins); + } + return ret; +} + instruction_ref module::remove_instruction(instruction_ref ins) { assert(has_instruction(ins)); diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index f295f60559c..8b805c466f9 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -612,7 +612,7 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini node.name().empty() ? std::string("migx_uid:") + node_name : node.name(); for(auto ins : added_instructions) { - mod->add_debug_symbol(ins, debug_symbol); + mod->add_debug_symbols(ins, {debug_symbol}); } } if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{})) diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 1cf1735efbc..8d74c1b139e 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -1704,19 +1704,19 @@ struct find_conv_dot_horiz_fusion auto concat = m.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args); auto fused = m.insert_instruction(std::next(input), op, input, concat); + std::vector replacers; int64_t offset = 0; for(auto arg : range(start, last)) { auto outputs = arg->outputs(); int64_t len = arg->get_shape().lens()[axis]; - m.replace_instruction( - arg, - make_op("slice", - {{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}), - fused); + auto slice_op = make_op( + "slice", {{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}); + replacers.push_back(module::instruction_replacer{arg, slice_op, {fused}, {}}); offset += len; } + m.batch_replace_instruction(replacers); }; auto outputs = ins->outputs(); diff --git a/test/debug_symbols_test.cpp b/test/debug_symbols_test.cpp index d0460420921..ca33c99ab66 100644 --- a/test/debug_symbols_test.cpp +++ b/test/debug_symbols_test.cpp @@ -44,10 +44,10 @@ // // x y x y z // \ / \ | / -// add {add0} pointwise {add0, add1} +// add {add1} pointwise {add1, add2} // | z | // | / @return -// add {add1} +// add {add2} // | // @return // @@ -56,21 +56,14 @@ TEST_CASE(pw_double_add) migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::program p1; { - auto* mm = p1.get_main_module(); - mm->set_use_debug_symbols(); + auto* mm = p1.get_main_module(); auto x = mm->add_parameter("x", s); auto y = mm->add_parameter("y", s); auto z = mm->add_parameter("z", s); - migraphx::instruction_ref add1; - { - migraphx::scoped_debug_symbols guard0(*mm, {"onnx:add0"}); - add1 = mm->add_instruction(migraphx::make_op("add"), x, y); - } - migraphx::instruction_ref add2; - { - migraphx::scoped_debug_symbols guard1(*mm, {"onnx:add1"}); - add2 = mm->add_instruction(migraphx::make_op("add"), add1, z); - } + migraphx::instruction_ref add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + mm->add_debug_symbols(add1, {"add1"}); + migraphx::instruction_ref add2 = mm->add_instruction(migraphx::make_op("add"), add1, z); + mm->add_debug_symbols(add2, {"add2"}); mm->add_return({add2}); } migraphx::run_passes(p1, {migraphx::fuse_pointwise{}, migraphx::dead_code_elimination{}}); @@ -78,7 +71,6 @@ TEST_CASE(pw_double_add) migraphx::program p2; { auto* mm = p2.get_main_module(); - mm->set_use_debug_symbols(); auto x = mm->add_parameter("x", s); auto y = mm->add_parameter("y", s); auto z = mm->add_parameter("z", s); @@ -87,7 +79,7 @@ TEST_CASE(pw_double_add) auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]); }); - fadd->add_debug_symbols({"onnx:add0", "onnx:add1"}); + mm->add_debug_symbols(fadd, {"add1", "add2"}); mm->add_return({fadd}); } // BUG straight equality is not working even though both call migraphx::to_string @@ -118,26 +110,24 @@ TEST_CASE(pw_used_twice_fused) migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::program p1; { - auto* mm = p1.get_main_module(); - mm->set_use_debug_symbols(); + auto* mm = p1.get_main_module(); auto x = mm->add_parameter("x", s); auto y = mm->add_parameter("y", s); auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); - add1->add_debug_symbols({"onnx:add1"}); + mm->add_debug_symbols(add1, {"onnx:add1"}); auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, x); - add2->add_debug_symbols({"onnx:add2"}); + mm->add_debug_symbols(add2, {"onnx:add2"}); auto add3 = mm->add_instruction(migraphx::make_op("add"), add1, y); - add3->add_debug_symbols({"onnx:add3"}); + mm->add_debug_symbols(add3, {"onnx:add3"}); auto add4 = mm->add_instruction(migraphx::make_op("add"), add2, add3); - add4->add_debug_symbols({"onnx:add4"}); + mm->add_debug_symbols(add4, {"onnx:add4"}); mm->add_return({add4}); } migraphx::run_passes(p1, {migraphx::fuse_pointwise{}, migraphx::dead_code_elimination{}}); migraphx::program p2; { - auto* mm = p2.get_main_module(); - mm->set_use_debug_symbols(); + auto* mm = p2.get_main_module(); auto x = mm->add_parameter("x", s); auto y = mm->add_parameter("y", s); auto fadd = add_pointwise(p2, "main:pointwise0", {x, y}, [=](auto* pm, const auto& inputs) { @@ -146,10 +136,9 @@ TEST_CASE(pw_used_twice_fused) auto add3 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[1]); return pm->add_instruction(migraphx::make_op("add"), add2, add3); }); - fadd->add_debug_symbols({"onnx:add1", "onnx:add2", "onnx:add3", "onnx:add4"}); + mm->add_debug_symbols(fadd, {"onnx:add1", "onnx:add2", "onnx:add3", "onnx:add4"}); mm->add_return({fadd}); } - // BUG straight equality is not working even though both call migraphx::to_string EXPECT(to_string(p1.sort()) == to_string(p2.sort())); } @@ -160,18 +149,16 @@ TEST_CASE(pw_used_twice_fused) // // Before: After: // -// input a input b a b +// a input b a b // \ / \ / \ / -// dot {g1} dot {g2} concat {g1, g2} -// \ / input | +// dot {g1} dot {g2} concat {g1, g2} +// \ / input | // \ / \ | // add {sum} dot {g1, g2} -// | / \ -// pass slice{g1} slice{g2} +// / \ +// slice{g1, g2} slice{g1, g2} // \ / -// add {sum} -// | -// pass +// add {sum} // TEST_CASE(horiz_fusion_dot) { @@ -179,41 +166,38 @@ TEST_CASE(horiz_fusion_dot) auto s = migraphx::shape{type, {3, 2, 2}}; migraphx::module m1; { - m1.set_use_debug_symbols(); auto input = m1.add_parameter("input", s); auto a = m1.add_literal(migraphx::generate_literal(s, 0)); auto b = m1.add_literal(migraphx::generate_literal(s, 1)); auto x = m1.add_instruction(migraphx::make_op("dot"), input, a); - x->add_debug_symbols({"gemm1"}); + m1.add_debug_symbols(x, {"gemm1"}); auto y = m1.add_instruction(migraphx::make_op("dot"), input, b); - y->add_debug_symbols({"gemm2"}); + m1.add_debug_symbols(y, {"gemm2"}); auto sum = m1.add_instruction(migraphx::make_op("add"), x, y); - sum->add_debug_symbols({"sum"}); - m1.add_instruction(pass_op{}, sum); + m1.add_debug_symbols(sum, {"sum"}); + m1.add_return({sum}); } migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); migraphx::module m2; { - m2.set_use_debug_symbols(); auto input = m2.add_parameter("input", s); auto a = m2.add_literal(migraphx::generate_literal(s, 0)); auto b = m2.add_literal(migraphx::generate_literal(s, 1)); auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), a, b); - concat->add_debug_symbols({"gemm1", "gemm2"}); + m2.add_debug_symbols(concat, {"gemm1", "gemm2"}); auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat); - dot->add_debug_symbols({"gemm1", "gemm2"}); + m2.add_debug_symbols(dot, {"gemm1", "gemm2"}); auto x = m2.add_instruction( migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), dot); - x->add_debug_symbols({"gemm1"}); + m2.add_debug_symbols(x, {"gemm1", "gemm2"}); auto y = m2.add_instruction( migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {4}}}), dot); - y->add_debug_symbols({"gemm2"}); + m2.add_debug_symbols(y, {"gemm1", "gemm2"}); auto sum = m2.add_instruction(migraphx::make_op("add"), x, y); - sum->add_debug_symbols({"sum"}); - m2.add_instruction(pass_op{}, sum); + m2.add_debug_symbols(sum, {"sum"}); + m2.add_return({sum}); } - // BUG straight equality is not working even though both call migraphx::to_string EXPECT(to_string(m1.sort()) == to_string(m2.sort())); } @@ -234,34 +218,32 @@ TEST_CASE(simplify_add_debug_symbols) { migraphx::module m1; { - m1.set_use_debug_symbols(); auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {1}}); auto one = m1.add_literal(1); auto two = m1.add_literal(2); auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, one); - sum1->add_debug_symbols({"onnx:add1"}); + m1.add_debug_symbols(sum1, {"onnx:add1"}); auto sum2 = m1.add_instruction(migraphx::make_op("add"), y, two); - sum2->add_debug_symbols({"onnx:add2"}); + m1.add_debug_symbols(sum2, {"onnx:add2"}); auto sum3 = m1.add_instruction(migraphx::make_op("add"), sum1, sum2); - sum3->add_debug_symbols({"onnx:add0"}); + m1.add_debug_symbols(sum3, {"onnx:add0"}); m1.add_instruction(pass_op{}, sum3); } migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); migraphx::module m2; { - m2.set_use_debug_symbols(); auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1}}); auto one = m2.add_literal(1); auto two = m2.add_literal(2); auto sum1 = m2.add_instruction(migraphx::make_op("add"), one, two); - sum1->add_debug_symbols({"onnx:add0", "onnx:add1", "onnx:add2"}); + m2.add_debug_symbols(sum1, {"onnx:add0", "onnx:add1", "onnx:add2"}); auto sum2 = m2.add_instruction(migraphx::make_op("add"), x, y); - sum2->add_debug_symbols({"onnx:add0", "onnx:add1", "onnx:add2"}); + m2.add_debug_symbols(sum2, {"onnx:add0", "onnx:add1", "onnx:add2"}); auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum2, sum1); - sum3->add_debug_symbols({"onnx:add0", "onnx:add1", "onnx:add2"}); + m2.add_debug_symbols(sum3, {"onnx:add0", "onnx:add1", "onnx:add2"}); m2.add_instruction(pass_op{}, sum3); } EXPECT(to_string(m1.sort()) == to_string(m2.sort())); @@ -285,26 +267,24 @@ TEST_CASE(replace_with_insref_debug_symbols) { migraphx::module m1; { - m1.set_use_debug_symbols(); auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 3}}); auto zero = m1.add_literal( migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {0.0f}}); auto bcast = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3}}}), zero); auto relu_x = m1.add_instruction(migraphx::make_op("relu"), x); - relu_x->add_debug_symbols({"onnx:relu"}); + m1.add_debug_symbols(relu_x, {"onnx:relu"}); auto add_r = m1.add_instruction(migraphx::make_op("add"), relu_x, bcast); - add_r->add_debug_symbols({"onnx:add"}); + m1.add_debug_symbols(add_r, {"onnx:add"}); m1.add_instruction(pass_op{}, add_r); } migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); migraphx::module m2; { - m2.set_use_debug_symbols(); auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 3}}); auto relu_x = m2.add_instruction(migraphx::make_op("relu"), x); - relu_x->add_debug_symbols({"onnx:add", "onnx:relu"}); + m2.add_debug_symbols(relu_x, {"onnx:add", "onnx:relu"}); m2.add_instruction(pass_op{}, relu_x); } EXPECT(to_string(m1.sort()) == to_string(m2.sort())); @@ -328,13 +308,12 @@ TEST_CASE(gather_replace_chain_debug_symbols) migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::module m1; { - m1.set_use_debug_symbols(); auto x = m1.add_parameter("x", s); auto y = m1.add_parameter("y", s); auto relu_x = m1.add_instruction(migraphx::make_op("relu"), x); - relu_x->add_debug_symbols({"onnx:relu"}); + m1.add_debug_symbols(relu_x, {"onnx:relu"}); auto add_r = m1.add_instruction(migraphx::make_op("add"), relu_x, y); - add_r->add_debug_symbols({"onnx:add"}); + m1.add_debug_symbols(add_r, {"onnx:add"}); m1.add_instruction(pass_op{}, add_r); auto mul_r = m1.insert_instruction(add_r, migraphx::make_op("mul"), x, y); @@ -344,11 +323,10 @@ TEST_CASE(gather_replace_chain_debug_symbols) migraphx::module m2; { - m2.set_use_debug_symbols(); auto x = m2.add_parameter("x", s); auto y = m2.add_parameter("y", s); auto mul_r = m2.add_instruction(migraphx::make_op("mul"), x, y); - mul_r->add_debug_symbols({"onnx:add", "onnx:relu"}); + m2.add_debug_symbols(mul_r, {"onnx:add", "onnx:relu"}); m2.add_instruction(pass_op{}, mul_r); } EXPECT(to_string(m1.sort()) == to_string(m2.sort())); @@ -372,30 +350,28 @@ TEST_CASE(simplify_mul_add_debug_symbols) { migraphx::module m1; { - m1.set_use_debug_symbols(); auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto one = m1.add_literal(3); auto two = m1.add_literal(2); auto sum = m1.add_instruction(migraphx::make_op("add"), one, x); - sum->add_debug_symbols({"onnx:add"}); + m1.add_debug_symbols(sum, {"onnx:add"}); auto mul = m1.add_instruction(migraphx::make_op("mul"), sum, two); - mul->add_debug_symbols({"onnx:mul"}); + m1.add_debug_symbols(mul, {"onnx:mul"}); m1.add_instruction(pass_op{}, mul); } migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); migraphx::module m2; { - m2.set_use_debug_symbols(); auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto one = m2.add_literal(3); auto two = m2.add_literal(2); auto mul1 = m2.add_instruction(migraphx::make_op("mul"), two, x); - mul1->add_debug_symbols({"onnx:add", "onnx:mul"}); + m2.add_debug_symbols(mul1, {"onnx:add", "onnx:mul"}); auto mul2 = m2.add_instruction(migraphx::make_op("mul"), two, one); - mul2->add_debug_symbols({"onnx:add", "onnx:mul"}); + m2.add_debug_symbols(mul2, {"onnx:add", "onnx:mul"}); auto sum = m2.add_instruction(migraphx::make_op("add"), mul1, mul2); - sum->add_debug_symbols({"onnx:add", "onnx:mul"}); + m2.add_debug_symbols(sum, {"onnx:add", "onnx:mul"}); m2.add_instruction(pass_op{}, sum); } EXPECT(to_string(m1.sort()) == to_string(m2.sort())); @@ -420,70 +396,31 @@ TEST_CASE(simplify_div_const_debug_symbols) migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::module m1; { - m1.set_use_debug_symbols(); auto x = m1.add_parameter("x", s); auto c = m1.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 3}}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); auto div_r = m1.add_instruction(migraphx::make_op("div"), x, c); - div_r->add_debug_symbols({"onnx:div"}); + m1.add_debug_symbols(div_r, {"onnx:div"}); m1.add_instruction(pass_op{}, div_r); } migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); migraphx::module m2; { - m2.set_use_debug_symbols(); auto x = m2.add_parameter("x", s); auto c = m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 3}}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); auto recip = m2.add_instruction(migraphx::make_op("recip"), c); - recip->add_debug_symbols({"onnx:div"}); + m2.add_debug_symbols(recip, {"onnx:div"}); auto mul_r = m2.add_instruction(migraphx::make_op("mul"), x, recip); - mul_r->add_debug_symbols({"onnx:div"}); + m2.add_debug_symbols(mul_r, {"onnx:div"}); m2.add_instruction(pass_op{}, mul_r); } EXPECT(to_string(m1.sort()) == to_string(m2.sort())); } -// Unit test for the scoped_debug_symbols RAII guard's save/restore behavior. -// An outer guard sets "outer", a nested inner guard temporarily replaces it -// with "inner", and after the inner guard's destructor runs the outer symbols -// are restored. Instructions created in each scope carry only the symbols -// active at the time they were added. -// -// scope active symbols instruction gets symbol -// ------- ---------------- ----------- ----------- -// outer {"outer"} add1(x, y) {"outer"} -// inner {"inner"} add2(add1, x) {"inner"} -// outer {"outer"} add3(add2, y) {"outer"} -// -TEST_CASE(scoped_debug_symbols_nesting) -{ - migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - migraphx::module m("test", true); - auto x = m.add_parameter("x", s); - auto y = m.add_parameter("y", s); - - migraphx::instruction_ref add1; - migraphx::instruction_ref add2; - migraphx::instruction_ref add3; - { - migraphx::scoped_debug_symbols outer(m, {"outer"}); - add1 = m.add_instruction(migraphx::make_op("add"), x, y); - { - migraphx::scoped_debug_symbols inner(m, {"inner"}); - add2 = m.add_instruction(migraphx::make_op("add"), add1, x); - } - add3 = m.add_instruction(migraphx::make_op("add"), add2, y); - } - - EXPECT(add1->get_debug_symbols() == std::set{"outer"}); - EXPECT(add2->get_debug_symbols() == std::set{"inner"}); - EXPECT(add3->get_debug_symbols() == std::set{"outer"}); -} - // Three sequential adds fused into a single pointwise op via fuse_pointwise. // All three ONNX node symbols should appear on the fused pointwise instruction. // Extends pw_double_add to a longer chain. @@ -507,18 +444,17 @@ TEST_CASE(pw_triple_add_fused) migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::program p1; { - auto* mm = p1.get_main_module(); - mm->set_use_debug_symbols(); + auto* mm = p1.get_main_module(); auto x = mm->add_parameter("x", s); auto y = mm->add_parameter("y", s); auto z = mm->add_parameter("z", s); auto w = mm->add_parameter("w", s); auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); - add1->add_debug_symbols({"onnx:add1"}); + mm->add_debug_symbols(add1, {"onnx:add1"}); auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, z); - add2->add_debug_symbols({"onnx:add2"}); + mm->add_debug_symbols(add2, {"onnx:add2"}); auto add3 = mm->add_instruction(migraphx::make_op("add"), add2, w); - add3->add_debug_symbols({"onnx:add3"}); + mm->add_debug_symbols(add3, {"onnx:add3"}); mm->add_return({add3}); } migraphx::run_passes(p1, {migraphx::fuse_pointwise{}, migraphx::dead_code_elimination{}}); @@ -526,7 +462,6 @@ TEST_CASE(pw_triple_add_fused) migraphx::program p2; { auto* mm = p2.get_main_module(); - mm->set_use_debug_symbols(); auto x = mm->add_parameter("x", s); auto y = mm->add_parameter("y", s); auto z = mm->add_parameter("z", s); @@ -537,61 +472,12 @@ TEST_CASE(pw_triple_add_fused) auto add2 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]); return pm->add_instruction(migraphx::make_op("add"), add2, inputs[3]); }); - fadd->add_debug_symbols({"onnx:add1", "onnx:add2", "onnx:add3"}); + mm->add_debug_symbols(fadd, {"onnx:add1", "onnx:add2", "onnx:add3"}); mm->add_return({fadd}); } EXPECT(to_string(p1) == to_string(p2)); } -// Same add-reassociation pattern as simplify_add_debug_symbols but with -// set_use_debug_symbols() NOT called. -// -// Before: After (flag OFF): -// -// x 1 y 2 1 2 x y -// \ / \ / \ / \ / -// add1{a1} add2{a2} add{} add{} -// \ / \ / -// add0{a0} add{a0} -// | | -// pass pass -// -// (compare with simplify_add_debug_symbols where flag ON -// gives {a0, a1, a2} on every instruction) -// -TEST_CASE(no_propagation_without_flag) -{ - migraphx::module m1; - { - auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); - auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {1}}); - auto one = m1.add_literal(1); - auto two = m1.add_literal(2); - auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, one); - sum1->add_debug_symbols({"onnx:add1"}); - auto sum2 = m1.add_instruction(migraphx::make_op("add"), y, two); - sum2->add_debug_symbols({"onnx:add2"}); - auto sum3 = m1.add_instruction(migraphx::make_op("add"), sum1, sum2); - sum3->add_debug_symbols({"onnx:add0"}); - m1.add_instruction(pass_op{}, sum3); - } - migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); - - migraphx::module m2; - { - auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); - auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1}}); - auto one = m2.add_literal(1); - auto two = m2.add_literal(2); - auto sum1 = m2.add_instruction(migraphx::make_op("add"), one, two); - auto sum2 = m2.add_instruction(migraphx::make_op("add"), x, y); - auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum2, sum1); - sum3->add_debug_symbols({"onnx:add0"}); - m2.add_instruction(pass_op{}, sum3); - } - EXPECT(to_string(m1.sort()) == to_string(m2.sort())); -} - // Verifies that debug symbols appear in the module's printed/serialized // output using the expected comment format produced by instruction::print. // @@ -602,15 +488,14 @@ TEST_CASE(debug_symbols_in_print) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::module m; - m.set_use_debug_symbols(); auto x = m.add_parameter("x", s); auto y = m.add_parameter("y", s); auto add = m.add_instruction(migraphx::make_op("add"), x, y); - add->add_debug_symbols({"sym_a", "sym_b"}); + m.add_debug_symbols(add, {"sym_a", "sym_b"}); m.add_instruction(pass_op{}, add); auto str = migraphx::to_string(m); - EXPECT(str.find("/* sym_a, sym_b */") != std::string::npos); + EXPECT(str.find("# sym_a, sym_b #") != std::string::npos); } int main(int argc, const char* argv[]) { test::run(argc, argv); } From 4e1cee9340e973489b47aebfb0db249d387c6272 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 3 Mar 2026 19:25:32 -0600 Subject: [PATCH 026/107] splice fix and add gemm->add->add test --- src/include/migraphx/onnx.hpp | 2 + src/module.cpp | 137 +++++++++++++++++++--------------- src/onnx/onnx.cpp | 1 + test/debug_symbols_test.cpp | 48 +++++++++++- 4 files changed, 125 insertions(+), 63 deletions(-) diff --git a/src/include/migraphx/onnx.hpp b/src/include/migraphx/onnx.hpp index a85fa131236..8247dc73835 100644 --- a/src/include/migraphx/onnx.hpp +++ b/src/include/migraphx/onnx.hpp @@ -58,6 +58,8 @@ struct onnx_options int64_t limit_max_iterations = std::numeric_limits::max(); /// Use dynamic output for operators when available bool use_dyn_output = false; + /// Parse in ONNX node names as debug symbols + bool use_debug_symbols = false; /// Path to use for the external data if it is stored at different location compared to onnx /// file std::string external_data_path = ""; diff --git a/src/module.cpp b/src/module.cpp index 4af989ab0ea..77299a165c0 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -343,10 +343,12 @@ instruction_ref module::insert_instruction(instruction_ref ins, } /** - * Traverse inputs of `ins` and gather instructions that output only to `ins` (would become deadcode - * without `ins`). + * Traverse inputs of `ins` and gather instructions that output only to `ins`. + * This splice is the total possibility of instructions that could be spliced by a + * replace_instruction. **/ -static std::unordered_set gather_splice(const_module_ref m, instruction_ref ins) +static std::unordered_set gather_max_splice(const_module_ref m, + instruction_ref ins) { std::unordered_set result = {ins}; fix([&](auto self, const std::vector& inputs) { @@ -368,6 +370,37 @@ static std::unordered_set gather_splice(const_module_ref m, ins return result; } +void propagate_debug_symbols(const_module_ref m, + instruction_ref ins, + const std::unordered_set old_max_splice) +{ + std::unordered_set new_max_splice = gather_max_splice(m, ins); + // Remove instructions from new_max_splice that are also in old_max_splice to get the actual + // new_splice set notation: {new_max_splice} - {old_max_splice} + std::unordered_set new_splice; + std::copy_if( + new_max_splice.cbegin(), + new_max_splice.cend(), + std::inserter(new_splice, new_splice.begin()), + [&old_max_splice](auto new_ins) { return (not contains(old_max_splice, new_ins)); }); + // Vice versa process as new_splice + std::unordered_set old_splice; + std::copy_if( + old_max_splice.cbegin(), + old_max_splice.cend(), + std::inserter(old_splice, old_splice.begin()), + [&new_max_splice](auto old_ins) { return (not contains(new_max_splice, old_ins)); }); + std::set symbols; + for(auto old_ins : old_splice) + { + copy(old_ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); + } + for(auto new_ins : new_splice) + { + m->add_debug_symbols(new_ins, symbols); + } +} + instruction_ref module::replace_instruction(instruction_ref ins, const operation& op, std::vector args) MIGRAPHX_TIDY_CONST @@ -377,24 +410,15 @@ instruction_ref module::replace_instruction(instruction_ref ins, assert(not starts_with(op.name(), "@")); shape r = compute_shape(op, args); - std::unordered_set old_splice; + std::unordered_set old_max_splice; if(has_debug_symbols()) { - old_splice = gather_splice(this, ins); + old_max_splice = gather_max_splice(this, ins); } instruction::replace(ins, op, r, std::move(args)); if(has_debug_symbols()) { - std::unordered_set new_splice = gather_splice(this, ins); - std::set symbols; - for(auto old_ins : old_splice) - { - copy(old_ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); - } - for(auto new_ins : new_splice) - { - this->add_debug_symbols(new_ins, symbols); - } + propagate_debug_symbols(this, ins, old_max_splice); } assert(ins->valid(begin())); return ins; @@ -409,47 +433,20 @@ instruction_ref module::replace_instruction(instruction_ref ins, assert(has_instruction(ins)); assert(not starts_with(op.name(), "@")); auto out_shape = compute_shape(op, args, module_args); - std::unordered_set old_splice; + std::unordered_set old_max_splice; if(has_debug_symbols()) { - old_splice = gather_splice(this, ins); + old_max_splice = gather_max_splice(this, ins); } instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args)); if(has_debug_symbols()) { - std::unordered_set new_splice = gather_splice(this, ins); - std::set symbols; - for(auto old_ins : old_splice) - { - copy(old_ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); - } - for(auto new_ins : new_splice) - { - this->add_debug_symbols(new_ins, symbols); - } + propagate_debug_symbols(this, ins, old_max_splice); } assert(ins->valid(begin())); return ins; } -static void propagate_debug_symbols(const_module_ref m, instruction_ref ins, instruction_ref rep) -{ - if(m->has_debug_symbols()) - { - auto old_splice = gather_splice(m, ins); - auto new_splice = gather_splice(m, rep); - std::set symbols; - for(auto old_ins : old_splice) - { - copy(old_ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); - } - for(auto new_ins : new_splice) - { - m->add_debug_symbols(new_ins, symbols); - } - } -} - instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref rep) { impl->changed.notify(); @@ -469,7 +466,11 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref return rep; } - propagate_debug_symbols(this, ins, rep); + if(has_debug_symbols()) + { + auto old_max_splice = gather_max_splice(this, ins); + propagate_debug_symbols(this, rep, old_max_splice); + } // Make a copy of outputs which can be changed when calling replace_argument auto outputs = ins->outputs(); for(auto out : outputs) @@ -497,34 +498,52 @@ std::vector module::batch_replace_instruction(const std::vector& replacers) { std::vector ret; - std::set symbols; + std::unordered_set old_max_splices; if(has_debug_symbols()) { - // gather all previous debug symbols from splices + // gather all previous debug symbols from max splices for(const auto& replacer : replacers) { - std::unordered_set old_splice = gather_splice(this, replacer.ins); - for(auto old_ins : old_splice) - { - copy(old_ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); - } + old_max_splices.merge(gather_max_splice(this, replacer.ins)); } } + std::unordered_set new_max_splices; for(const auto& replacer : replacers) { auto out_shape = compute_shape(replacer.op, replacer.args, replacer.module_args); instruction::replace( replacer.ins, replacer.op, out_shape, replacer.args, replacer.module_args); + ret.push_back(replacer.ins); if(has_debug_symbols()) { - std::unordered_set new_splice = gather_splice(this, replacer.ins); - for(auto new_ins : new_splice) - { - this->add_debug_symbols(new_ins, symbols); - } + new_max_splices.merge(gather_max_splice(this, replacer.ins)); + } + } + if(has_debug_symbols()) + { + std::unordered_set new_splices; + std::copy_if( + new_max_splices.cbegin(), + new_max_splices.cend(), + std::inserter(new_splices, new_splices.begin()), + [&old_max_splices](auto new_ins) { return (not contains(old_max_splices, new_ins)); }); + // Vice versa process as new_splice for the symbols + std::unordered_set old_splices; + std::copy_if( + old_max_splices.cbegin(), + old_max_splices.cend(), + std::inserter(old_splices, old_splices.begin()), + [&new_max_splices](auto old_ins) { return (not contains(new_max_splices, old_ins)); }); + std::set symbols; + for(auto old_ins : old_splices) + { + copy(old_ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); + } + for(auto new_ins : new_splices) + { + add_debug_symbols(new_ins, symbols); } - ret.push_back(replacer.ins); } return ret; } diff --git a/src/onnx/onnx.cpp b/src/onnx/onnx.cpp index 0517f38a997..7ea1edfb350 100644 --- a/src/onnx/onnx.cpp +++ b/src/onnx/onnx.cpp @@ -45,6 +45,7 @@ static program parse_onnx_from(const onnx_options& options, Ts&&... xs) parser.map_input_dims = options.map_input_dims; parser.dim_params = options.dim_params; parser.map_dyn_input_dims = options.map_dyn_input_dims; + parser.use_debug_symbols = options.use_debug_symbols; auto dim_val = options.default_dim_value; if(dim_val != 0) { diff --git a/test/debug_symbols_test.cpp b/test/debug_symbols_test.cpp index ca33c99ab66..49bdf0b1aaa 100644 --- a/test/debug_symbols_test.cpp +++ b/test/debug_symbols_test.cpp @@ -87,10 +87,6 @@ TEST_CASE(pw_double_add) EXPECT(to_string(p1) == to_string(p2)); } -// Diamond pattern: add1 feeds into both add2 and add3, which then feed -// into add4. All four are fused into one pointwise op. Verifies that -// symbols from every instruction in the diamond appear on the fused result. -// // Before: After: // // x y x y @@ -142,6 +138,50 @@ TEST_CASE(pw_used_twice_fused) EXPECT(to_string(p1.sort()) == to_string(p2.sort())); } +// To check that the debug symbols don't propagate above the fusion +TEST_CASE(gemm_pw) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); + auto z = mm->add_parameter("z", s1); + auto a = mm->add_literal(migraphx::generate_literal(s2, 0)); + auto gemm = mm->add_instruction(migraphx::make_op("dot"), x, a); + mm->add_debug_symbols(gemm, {"gemm1"}); + auto add1 = mm->add_instruction(migraphx::make_op("add"), gemm, y); + mm->add_debug_symbols(add1, {"add1"}); + auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, z); + mm->add_debug_symbols(add2, {"add2"}); + mm->add_return({add2}); + } + migraphx::run_passes(p1, {migraphx::fuse_pointwise{}, migraphx::dead_code_elimination{}}); + + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); + auto z = mm->add_parameter("z", s1); + auto a = mm->add_literal(migraphx::generate_literal(s2, 0)); + auto gemm = mm->add_instruction(migraphx::make_op("dot"), x, a); + mm->add_debug_symbols(gemm, {"gemm1"}); + auto fadd = + add_pointwise(p2, "main:pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) { + auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); + return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]); + }); + mm->add_debug_symbols(fadd, {"add1", "add2"}); + mm->add_return({fadd}); + } + // BUG straight equality is not working even though both call migraphx::to_string + // EXPECT(p1 == p2); + EXPECT(to_string(p1) == to_string(p2)); +} + // Horizontal fusion of two dot ops sharing the same input via // simplify_algebra. The two dots are fused into concat + single dot + slices. // Each new instruction inherits the symbols of the original dots it derives From dadc312e268a8f5289a4806f7b9a40d7f78a2b95 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 3 Mar 2026 20:55:55 -0600 Subject: [PATCH 027/107] Refine the set difference algo and tests --- src/instruction.cpp | 9 ++++- src/module.cpp | 79 +++++++++++++++++++------------------ test/debug_symbols_test.cpp | 17 ++++---- 3 files changed, 56 insertions(+), 49 deletions(-) diff --git a/src/instruction.cpp b/src/instruction.cpp index 547802a9783..87995344a63 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -484,7 +484,14 @@ void instruction::debug_print() const } if(not this->inputs().empty()) std::cout << ")"; - std::cout << " -> " << this->get_shape() << std::endl; + std::cout << " -> " << this->get_shape(); + + // print debug symbols if they exist + if(not debug_symbols.empty()) + { + std::cout << " # " << join_strings(debug_symbols, ", ") << " #"; + } + std::cout << std::endl; } std::vector instruction::get_output_alias(instruction_ref ins, bool shallow) diff --git a/src/module.cpp b/src/module.cpp index 77299a165c0..a3998f363a0 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -344,11 +344,9 @@ instruction_ref module::insert_instruction(instruction_ref ins, /** * Traverse inputs of `ins` and gather instructions that output only to `ins`. - * This splice is the total possibility of instructions that could be spliced by a - * replace_instruction. + * This splice is the total possibility of instructions that could be spliced by a replace_instruction. **/ -static std::unordered_set gather_max_splice(const_module_ref m, - instruction_ref ins) +static std::unordered_set gather_max_splice(const_module_ref m, instruction_ref ins) { std::unordered_set result = {ins}; fix([&](auto self, const std::vector& inputs) { @@ -370,27 +368,25 @@ static std::unordered_set gather_max_splice(const_module_ref m, return result; } -void propagate_debug_symbols(const_module_ref m, - instruction_ref ins, - const std::unordered_set old_max_splice) +void propagate_debug_symbols(const_module_ref m, instruction_ref ins, std::unordered_set old_max_splice) { + // Remove ins from old_max_splice, if it is there. To prevent it being in both old_max_splice and new_max_slice. + old_max_splice.erase(ins); std::unordered_set new_max_splice = gather_max_splice(m, ins); - // Remove instructions from new_max_splice that are also in old_max_splice to get the actual - // new_splice set notation: {new_max_splice} - {old_max_splice} - std::unordered_set new_splice; - std::copy_if( - new_max_splice.cbegin(), - new_max_splice.cend(), - std::inserter(new_splice, new_splice.begin()), - [&old_max_splice](auto new_ins) { return (not contains(old_max_splice, new_ins)); }); - // Vice versa process as new_splice + // Remove instructions from old_max_splice that are also in new_max_splice to get the actual old_splice + // set notation: old_splice = {old_max_splice} - {new_max_splice} std::unordered_set old_splice; - std::copy_if( - old_max_splice.cbegin(), - old_max_splice.cend(), - std::inserter(old_splice, old_splice.begin()), - [&new_max_splice](auto old_ins) { return (not contains(new_max_splice, old_ins)); }); - std::set symbols; + std::copy_if(old_max_splice.cbegin(), + old_max_splice.cend(), + std::inserter(old_splice, old_splice.begin()), + [&new_max_splice](auto old_ins){ return (not contains(new_max_splice, old_ins)); }); + // Vice versa process + std::unordered_set new_splice; + std::copy_if(new_max_splice.cbegin(), + new_max_splice.cend(), + std::inserter(new_splice, new_splice.begin()), + [&old_max_splice](auto new_ins){ return (not contains(old_max_splice, new_ins)); }); + std::set symbols = ins->get_debug_symbols(); for(auto old_ins : old_splice) { copy(old_ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); @@ -418,7 +414,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction::replace(ins, op, r, std::move(args)); if(has_debug_symbols()) { - propagate_debug_symbols(this, ins, old_max_splice); + propagate_debug_symbols(this, ins, std::move(old_max_splice)); } assert(ins->valid(begin())); return ins; @@ -441,7 +437,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args)); if(has_debug_symbols()) { - propagate_debug_symbols(this, ins, old_max_splice); + propagate_debug_symbols(this, ins, std::move(old_max_splice)); } assert(ins->valid(begin())); return ins; @@ -465,11 +461,13 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref { return rep; } - + if(has_debug_symbols()) { auto old_max_splice = gather_max_splice(this, ins); - propagate_debug_symbols(this, rep, old_max_splice); + // Remove rep incase it shows up in old_max_splice + old_max_splice.erase(rep); + propagate_debug_symbols(this, rep, std::move(old_max_splice)); } // Make a copy of outputs which can be changed when calling replace_argument auto outputs = ins->outputs(); @@ -504,10 +502,14 @@ module::batch_replace_instruction(const std::vector& repla // gather all previous debug symbols from max splices for(const auto& replacer : replacers) { - old_max_splices.merge(gather_max_splice(this, replacer.ins)); + auto ms = gather_max_splice(this, replacer.ins); + // Remove ins from old_max_splice to prevent it being in both old_max_splice and new_max_slice + ms.erase(replacer.ins); + old_max_splices.merge(ms); } } + std::set symbols; std::unordered_set new_max_splices; for(const auto& replacer : replacers) { @@ -517,25 +519,24 @@ module::batch_replace_instruction(const std::vector& repla ret.push_back(replacer.ins); if(has_debug_symbols()) { + auto ds = replacer.ins->get_debug_symbols(); + // add symbols from replacer.ins here because we removed replacer.ins from old_max_splice + symbols.insert(ds.begin(), ds.end()); new_max_splices.merge(gather_max_splice(this, replacer.ins)); } } if(has_debug_symbols()) { std::unordered_set new_splices; - std::copy_if( - new_max_splices.cbegin(), - new_max_splices.cend(), - std::inserter(new_splices, new_splices.begin()), - [&old_max_splices](auto new_ins) { return (not contains(old_max_splices, new_ins)); }); - // Vice versa process as new_splice for the symbols + std::copy_if(new_max_splices.cbegin(), + new_max_splices.cend(), + std::inserter(new_splices, new_splices.begin()), + [&old_max_splices](auto new_ins){ return (not contains(old_max_splices, new_ins)); }); std::unordered_set old_splices; - std::copy_if( - old_max_splices.cbegin(), - old_max_splices.cend(), - std::inserter(old_splices, old_splices.begin()), - [&new_max_splices](auto old_ins) { return (not contains(new_max_splices, old_ins)); }); - std::set symbols; + std::copy_if(old_max_splices.cbegin(), + old_max_splices.cend(), + std::inserter(old_splices, old_splices.begin()), + [&new_max_splices](auto old_ins){ return (not contains(new_max_splices, old_ins)); }); for(auto old_ins : old_splices) { copy(old_ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); diff --git a/test/debug_symbols_test.cpp b/test/debug_symbols_test.cpp index 49bdf0b1aaa..7b9517078d0 100644 --- a/test/debug_symbols_test.cpp +++ b/test/debug_symbols_test.cpp @@ -139,7 +139,7 @@ TEST_CASE(pw_used_twice_fused) } // To check that the debug symbols don't propagate above the fusion -TEST_CASE(gemm_pw) +TEST_CASE(gemm_add_add) { migraphx::shape s1{migraphx::shape::float_type, {2, 3}}; migraphx::shape s2{migraphx::shape::float_type, {3, 3}}; @@ -170,7 +170,7 @@ TEST_CASE(gemm_pw) auto gemm = mm->add_instruction(migraphx::make_op("dot"), x, a); mm->add_debug_symbols(gemm, {"gemm1"}); auto fadd = - add_pointwise(p2, "main:pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) { + add_pointwise(p2, "main:pointwise0", {gemm, y, z}, [=](auto* pm, const auto& inputs) { auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]); }); @@ -178,7 +178,6 @@ TEST_CASE(gemm_pw) mm->add_return({fadd}); } // BUG straight equality is not working even though both call migraphx::to_string - // EXPECT(p1 == p2); EXPECT(to_string(p1) == to_string(p2)); } @@ -297,11 +296,9 @@ TEST_CASE(simplify_add_debug_symbols) // x 0 x // | | | // relu bcast relu {add, relu} -// {relu} (0.0) | -// \ / pass +// {relu} (0.0) +// \ / // add {add} -// | -// pass // TEST_CASE(replace_with_insref_debug_symbols) { @@ -316,7 +313,7 @@ TEST_CASE(replace_with_insref_debug_symbols) m1.add_debug_symbols(relu_x, {"onnx:relu"}); auto add_r = m1.add_instruction(migraphx::make_op("add"), relu_x, bcast); m1.add_debug_symbols(add_r, {"onnx:add"}); - m1.add_instruction(pass_op{}, add_r); + m1.add_return({add_r}); } migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); @@ -325,7 +322,7 @@ TEST_CASE(replace_with_insref_debug_symbols) auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 3}}); auto relu_x = m2.add_instruction(migraphx::make_op("relu"), x); m2.add_debug_symbols(relu_x, {"onnx:add", "onnx:relu"}); - m2.add_instruction(pass_op{}, relu_x); + m2.add_return({relu_x}); } EXPECT(to_string(m1.sort()) == to_string(m2.sort())); } @@ -538,4 +535,6 @@ TEST_CASE(debug_symbols_in_print) EXPECT(str.find("# sym_a, sym_b #") != std::string::npos); } +//TODO make tests that directly call module::replace_instruction + int main(int argc, const char* argv[]) { test::run(argc, argv); } From 9a32dc45b7fb37eaecbb07a5d9a51a4a98badedd Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 3 Mar 2026 20:56:26 -0600 Subject: [PATCH 028/107] Formatting --- src/instruction.cpp | 2 +- src/module.cpp | 61 ++++++++++++++++++++++--------------- test/debug_symbols_test.cpp | 2 +- 3 files changed, 38 insertions(+), 27 deletions(-) diff --git a/src/instruction.cpp b/src/instruction.cpp index 87995344a63..eaddccef380 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -484,7 +484,7 @@ void instruction::debug_print() const } if(not this->inputs().empty()) std::cout << ")"; - std::cout << " -> " << this->get_shape(); + std::cout << " -> " << this->get_shape(); // print debug symbols if they exist if(not debug_symbols.empty()) diff --git a/src/module.cpp b/src/module.cpp index a3998f363a0..8b32cc0422e 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -344,9 +344,11 @@ instruction_ref module::insert_instruction(instruction_ref ins, /** * Traverse inputs of `ins` and gather instructions that output only to `ins`. - * This splice is the total possibility of instructions that could be spliced by a replace_instruction. + * This splice is the total possibility of instructions that could be spliced by a + * replace_instruction. **/ -static std::unordered_set gather_max_splice(const_module_ref m, instruction_ref ins) +static std::unordered_set gather_max_splice(const_module_ref m, + instruction_ref ins) { std::unordered_set result = {ins}; fix([&](auto self, const std::vector& inputs) { @@ -368,24 +370,29 @@ static std::unordered_set gather_max_splice(const_module_ref m, return result; } -void propagate_debug_symbols(const_module_ref m, instruction_ref ins, std::unordered_set old_max_splice) +void propagate_debug_symbols(const_module_ref m, + instruction_ref ins, + std::unordered_set old_max_splice) { - // Remove ins from old_max_splice, if it is there. To prevent it being in both old_max_splice and new_max_slice. + // Remove ins from old_max_splice, if it is there. To prevent it being in both old_max_splice + // and new_max_slice. old_max_splice.erase(ins); std::unordered_set new_max_splice = gather_max_splice(m, ins); - // Remove instructions from old_max_splice that are also in new_max_splice to get the actual old_splice - // set notation: old_splice = {old_max_splice} - {new_max_splice} + // Remove instructions from old_max_splice that are also in new_max_splice to get the actual + // old_splice set notation: old_splice = {old_max_splice} - {new_max_splice} std::unordered_set old_splice; - std::copy_if(old_max_splice.cbegin(), - old_max_splice.cend(), - std::inserter(old_splice, old_splice.begin()), - [&new_max_splice](auto old_ins){ return (not contains(new_max_splice, old_ins)); }); + std::copy_if( + old_max_splice.cbegin(), + old_max_splice.cend(), + std::inserter(old_splice, old_splice.begin()), + [&new_max_splice](auto old_ins) { return (not contains(new_max_splice, old_ins)); }); // Vice versa process std::unordered_set new_splice; - std::copy_if(new_max_splice.cbegin(), - new_max_splice.cend(), - std::inserter(new_splice, new_splice.begin()), - [&old_max_splice](auto new_ins){ return (not contains(old_max_splice, new_ins)); }); + std::copy_if( + new_max_splice.cbegin(), + new_max_splice.cend(), + std::inserter(new_splice, new_splice.begin()), + [&old_max_splice](auto new_ins) { return (not contains(old_max_splice, new_ins)); }); std::set symbols = ins->get_debug_symbols(); for(auto old_ins : old_splice) { @@ -461,7 +468,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref { return rep; } - + if(has_debug_symbols()) { auto old_max_splice = gather_max_splice(this, ins); @@ -503,7 +510,8 @@ module::batch_replace_instruction(const std::vector& repla for(const auto& replacer : replacers) { auto ms = gather_max_splice(this, replacer.ins); - // Remove ins from old_max_splice to prevent it being in both old_max_splice and new_max_slice + // Remove ins from old_max_splice to prevent it being in both old_max_splice and + // new_max_slice ms.erase(replacer.ins); old_max_splices.merge(ms); } @@ -520,7 +528,8 @@ module::batch_replace_instruction(const std::vector& repla if(has_debug_symbols()) { auto ds = replacer.ins->get_debug_symbols(); - // add symbols from replacer.ins here because we removed replacer.ins from old_max_splice + // add symbols from replacer.ins here because we removed replacer.ins from + // old_max_splice symbols.insert(ds.begin(), ds.end()); new_max_splices.merge(gather_max_splice(this, replacer.ins)); } @@ -528,15 +537,17 @@ module::batch_replace_instruction(const std::vector& repla if(has_debug_symbols()) { std::unordered_set new_splices; - std::copy_if(new_max_splices.cbegin(), - new_max_splices.cend(), - std::inserter(new_splices, new_splices.begin()), - [&old_max_splices](auto new_ins){ return (not contains(old_max_splices, new_ins)); }); + std::copy_if( + new_max_splices.cbegin(), + new_max_splices.cend(), + std::inserter(new_splices, new_splices.begin()), + [&old_max_splices](auto new_ins) { return (not contains(old_max_splices, new_ins)); }); std::unordered_set old_splices; - std::copy_if(old_max_splices.cbegin(), - old_max_splices.cend(), - std::inserter(old_splices, old_splices.begin()), - [&new_max_splices](auto old_ins){ return (not contains(new_max_splices, old_ins)); }); + std::copy_if( + old_max_splices.cbegin(), + old_max_splices.cend(), + std::inserter(old_splices, old_splices.begin()), + [&new_max_splices](auto old_ins) { return (not contains(new_max_splices, old_ins)); }); for(auto old_ins : old_splices) { copy(old_ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); diff --git a/test/debug_symbols_test.cpp b/test/debug_symbols_test.cpp index 7b9517078d0..bb12ae010f8 100644 --- a/test/debug_symbols_test.cpp +++ b/test/debug_symbols_test.cpp @@ -535,6 +535,6 @@ TEST_CASE(debug_symbols_in_print) EXPECT(str.find("# sym_a, sym_b #") != std::string::npos); } -//TODO make tests that directly call module::replace_instruction +// TODO make tests that directly call module::replace_instruction int main(int argc, const char* argv[]) { test::run(argc, argv); } From bd1997a3d8a6a3a4dbdbeaa490f73e29d74f2f59 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 4 Mar 2026 09:36:21 -0600 Subject: [PATCH 029/107] Formatting --- src/split_reduce.cpp | 4 +--- src/targets/gpu/compile_ops.cpp | 12 +++--------- src/targets/gpu/prepare_reduce.cpp | 12 +++--------- test/gpu/prepare_reduce.cpp | 4 +--- 4 files changed, 8 insertions(+), 24 deletions(-) diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index af13343ee0f..9fdf9bd8485 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -75,9 +75,7 @@ struct split_fused_reduce MIGRAPHX_REGISTER_OP(split_fused_reduce); static bool is_reduce(const instruction& ins) -{ - return contains(ins.name(), "reduce") or ins.name() == "argmin" or ins.name() == "argmax"; -} +{ return contains(ins.name(), "reduce") or ins.name() == "argmin" or ins.name() == "argmax"; } namespace { struct splitter diff --git a/src/targets/gpu/compile_ops.cpp b/src/targets/gpu/compile_ops.cpp index 8685469cda2..918e6ca8b5b 100644 --- a/src/targets/gpu/compile_ops.cpp +++ b/src/targets/gpu/compile_ops.cpp @@ -102,21 +102,15 @@ struct dynamic_code_object_op template static auto reflect(Self& self, F f) - { - return pack(f(self.pre_op, "pre_op")); - } + { return pack(f(self.pre_op, "pre_op")); } std::string name() const { return "gpu::dynamic_code_object_op"; } shape compute_shape(const std::vector& inputs, const std::vector& mods) const - { - return pre_op.compute_shape(inputs, mods); - } + { return pre_op.compute_shape(inputs, mods); } std::vector output_alias(const std::vector& shapes) const - { - return {shapes.size() - 1}; - } + { return {shapes.size() - 1}; } std::unordered_map build_param_map(const std::vector& args, const_module_ref mod) const { diff --git a/src/targets/gpu/prepare_reduce.cpp b/src/targets/gpu/prepare_reduce.cpp index 46c8d6fa7ad..d53aedcf501 100644 --- a/src/targets/gpu/prepare_reduce.cpp +++ b/src/targets/gpu/prepare_reduce.cpp @@ -64,9 +64,7 @@ struct arg_reduce template static auto reflect(Self& self, F f) - { - return pack(f(self.op, "op")); - } + { return pack(f(self.op, "op")); } std::string name() const { return "gpu::arg_reduce"; } @@ -85,16 +83,12 @@ struct make_indices template static auto reflect(Self& self, F f) - { - return pack(f(self.size, "size")); - } + { return pack(f(self.size, "size")); } std::string name() const { return "gpu::make_indices"; } shape compute_shape(const std::vector&) const - { - return shape{shape::uint32_type, {size}}; - } + { return shape{shape::uint32_type, {size}}; } }; MIGRAPHX_REGISTER_OP(make_indices); diff --git a/test/gpu/prepare_reduce.cpp b/test/gpu/prepare_reduce.cpp index e45f01c134f..39b97b3a847 100644 --- a/test/gpu/prepare_reduce.cpp +++ b/test/gpu/prepare_reduce.cpp @@ -31,9 +31,7 @@ #include static void run_pass(migraphx::module& m) -{ - migraphx::run_passes(m, {migraphx::gpu::prepare_reduce{}, migraphx::dead_code_elimination{}}); -} +{ migraphx::run_passes(m, {migraphx::gpu::prepare_reduce{}, migraphx::dead_code_elimination{}}); } // Helper to add the arg_reduce pattern: make_indices -> arg_reduce -> get_tuple_elem static migraphx::instruction_ref add_arg_reduce(migraphx::module& m, From 573ee7f9d83181da11745beabd094915db6870ae Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 4 Mar 2026 16:54:37 -0600 Subject: [PATCH 030/107] Update program serialization and fix EXPECT dangling ref --- src/module.cpp | 5 +++-- src/program.cpp | 5 ++++- test/debug_symbols_test.cpp | 23 ++++++++++------------- test/include/test.hpp | 34 ++++++++++++++++++++++------------ 4 files changed, 39 insertions(+), 28 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index 7f965f739de..85c24a56024 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -514,7 +514,7 @@ module::batch_replace_instruction(const std::vector& repla // Remove ins from old_max_splice to prevent it being in both old_max_splice and // new_max_slice ms.erase(replacer.ins); - old_max_splices.merge(ms); + old_max_splices.insert(ms.begin(), ms.end()); } } @@ -532,7 +532,8 @@ module::batch_replace_instruction(const std::vector& repla // add symbols from replacer.ins here because we removed replacer.ins from // old_max_splice symbols.insert(ds.begin(), ds.end()); - new_max_splices.merge(gather_max_splice(this, replacer.ins)); + auto ms = gather_max_splice(this, replacer.ins); + new_max_splices.insert(ms.begin(), ms.end()); } } if(has_debug_symbols()) diff --git a/src/program.cpp b/src/program.cpp index 20418f88bfa..fa433b48436 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -685,7 +685,7 @@ static std::string get_migraphx_version() program file version is for the data structure or format of the MXR file. Version should be bumped if any changes occur to the format of the MXR file. */ -const int program_file_version = 7; +const int program_file_version = 8; value program::to_value() const { @@ -711,6 +711,7 @@ value program::to_value() const if(ins->name() == "@literal") node["literal"] = migraphx::to_value(ins->get_literal()); node["operator"] = ins->get_operator().to_value(); + node["debug_symbols"] = ins->get_debug_symbols(); std::vector inputs; std::transform(ins->inputs().begin(), ins->inputs().end(), @@ -756,6 +757,7 @@ static void mod_from_val(module_ref mod, auto name = node.at("name").to(); auto fields = node.at("operator"); auto normalized = node.at("normalized").to(); + auto debug_symbols = node.at("debug_symbols").to>(); if(name == "@param") { @@ -808,6 +810,7 @@ static void mod_from_val(module_ref mod, output = mod->insert_instruction(mod->end(), op, inputs, module_inputs); } } + output->add_debug_symbols(debug_symbols); output->set_normalized(normalized); instructions[node.at("output").to()] = output; } diff --git a/test/debug_symbols_test.cpp b/test/debug_symbols_test.cpp index bb12ae010f8..9a0fe06d701 100644 --- a/test/debug_symbols_test.cpp +++ b/test/debug_symbols_test.cpp @@ -82,9 +82,7 @@ TEST_CASE(pw_double_add) mm->add_debug_symbols(fadd, {"add1", "add2"}); mm->add_return({fadd}); } - // BUG straight equality is not working even though both call migraphx::to_string - // EXPECT(p1 == p2); - EXPECT(to_string(p1) == to_string(p2)); + EXPECT(p1 == p2); } // Before: After: @@ -135,7 +133,7 @@ TEST_CASE(pw_used_twice_fused) mm->add_debug_symbols(fadd, {"onnx:add1", "onnx:add2", "onnx:add3", "onnx:add4"}); mm->add_return({fadd}); } - EXPECT(to_string(p1.sort()) == to_string(p2.sort())); + EXPECT(p1.sort() == p2.sort()); } // To check that the debug symbols don't propagate above the fusion @@ -177,8 +175,7 @@ TEST_CASE(gemm_add_add) mm->add_debug_symbols(fadd, {"add1", "add2"}); mm->add_return({fadd}); } - // BUG straight equality is not working even though both call migraphx::to_string - EXPECT(to_string(p1) == to_string(p2)); + EXPECT(p1 == p2); } // Horizontal fusion of two dot ops sharing the same input via @@ -237,7 +234,7 @@ TEST_CASE(horiz_fusion_dot) m2.add_debug_symbols(sum, {"sum"}); m2.add_return({sum}); } - EXPECT(to_string(m1.sort()) == to_string(m2.sort())); + EXPECT(m1.sort() == m2.sort()); } // Tests symbol propagation through add reassociation in simplify_algebra @@ -285,7 +282,7 @@ TEST_CASE(simplify_add_debug_symbols) m2.add_debug_symbols(sum3, {"onnx:add0", "onnx:add1", "onnx:add2"}); m2.add_instruction(pass_op{}, sum3); } - EXPECT(to_string(m1.sort()) == to_string(m2.sort())); + EXPECT(m1.sort() == m2.sort()); } // Tests the replace_instruction(ins, rep) overload via find_unit_ops which @@ -324,7 +321,7 @@ TEST_CASE(replace_with_insref_debug_symbols) m2.add_debug_symbols(relu_x, {"onnx:add", "onnx:relu"}); m2.add_return({relu_x}); } - EXPECT(to_string(m1.sort()) == to_string(m2.sort())); + EXPECT(m1.sort() == m2.sort()); } // Tests that debug_symbols propagate through the dead-code chain. @@ -366,7 +363,7 @@ TEST_CASE(gather_replace_chain_debug_symbols) m2.add_debug_symbols(mul_r, {"onnx:add", "onnx:relu"}); m2.add_instruction(pass_op{}, mul_r); } - EXPECT(to_string(m1.sort()) == to_string(m2.sort())); + EXPECT(m1.sort() == m2.sort()); } // Tests the distributive law transform in simplify_algebra (find_mul_add): @@ -411,7 +408,7 @@ TEST_CASE(simplify_mul_add_debug_symbols) m2.add_debug_symbols(sum, {"onnx:add", "onnx:mul"}); m2.add_instruction(pass_op{}, sum); } - EXPECT(to_string(m1.sort()) == to_string(m2.sort())); + EXPECT(m1.sort() == m2.sort()); } // Tests symbol propagation through find_div_const in simplify_algebra: @@ -455,7 +452,7 @@ TEST_CASE(simplify_div_const_debug_symbols) m2.add_debug_symbols(mul_r, {"onnx:div"}); m2.add_instruction(pass_op{}, mul_r); } - EXPECT(to_string(m1.sort()) == to_string(m2.sort())); + EXPECT(m1.sort() == m2.sort()); } // Three sequential adds fused into a single pointwise op via fuse_pointwise. @@ -512,7 +509,7 @@ TEST_CASE(pw_triple_add_fused) mm->add_debug_symbols(fadd, {"onnx:add1", "onnx:add2", "onnx:add3"}); mm->add_return({fadd}); } - EXPECT(to_string(p1) == to_string(p2)); + EXPECT(p1 == p2); } // Verifies that debug symbols appear in the module's printed/serialized diff --git a/test/include/test.hpp b/test/include/test.hpp index 7ae503e2db4..82e3295c7ab 100644 --- a/test/include/test.hpp +++ b/test/include/test.hpp @@ -220,10 +220,10 @@ template struct lhs_expression; template -lhs_expression make_lhs_expression(T&& lhs); +decltype(auto) make_lhs_expression(T&& lhs); template -lhs_expression make_lhs_expression(T&& lhs, Operator); +decltype(auto) make_lhs_expression(T&& lhs, Operator); // NOLINTNEXTLINE #define TEST_EXPR_BINARY_OPERATOR(op, name) \ @@ -237,6 +237,9 @@ lhs_expression make_lhs_expression(T&& lhs, Operator); #define TEST_EXPR_UNARY_OPERATOR(op, name) \ auto operator op() const { return make_lhs_expression(lhs, name{}); /* NOLINT */ } +//DEBUG +template struct TD; + template struct expression { @@ -245,38 +248,45 @@ struct expression friend std::ostream& operator<<(std::ostream& s, const expression& self) { + //DEBUG + //s << __PRETTY_FUNCTION__ << std::endl; print_stream(s, self.lhs); s << " " << Operator::as_string() << " "; print_stream(s, self.rhs); return s; } - friend decltype(auto) get_value(const expression& e) { return e.value(); } + friend decltype(auto) get_value(const expression& e) { + return e.value(); + } - decltype(auto) value() const { return Operator::call(get_value(lhs), get_value(rhs)); }; + decltype(auto) value() const { + return Operator::call(get_value(lhs), get_value(rhs)); + }; TEST_FOREACH_UNARY_OPERATORS(TEST_EXPR_UNARY_OPERATOR) TEST_FOREACH_BINARY_OPERATORS(TEST_EXPR_BINARY_OPERATOR) }; -// TODO: Remove rvalue references template -expression make_expression(T&& lhs, U&& rhs, Operator) +decltype(auto) make_expression(T&& lhs, U&& rhs, Operator) { - return {std::forward(lhs), std::forward(rhs)}; + //rvalue references to pass by value + return expression, std::decay_t, Operator>{lhs, rhs}; } -// TODO: Remove rvalue reference template -lhs_expression make_lhs_expression(T&& lhs) +decltype(auto) make_lhs_expression(T&& lhs) { - return lhs_expression{std::forward(lhs)}; + //rvalue references to pass by value + return lhs_expression>{lhs}; } template -lhs_expression make_lhs_expression(T&& lhs, Operator) +decltype(auto) make_lhs_expression(T&& lhs, Operator) { - return lhs_expression{std::forward(lhs)}; + //rvalue references to pass by value + return lhs_expression, Operator>{lhs}; } template From 8e8993d25860aff7e4aa9e2f8023b5b3686d2cd3 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 4 Mar 2026 16:55:06 -0600 Subject: [PATCH 031/107] Formatting --- test/include/test.hpp | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/test/include/test.hpp b/test/include/test.hpp index 82e3295c7ab..3c98e9b9a59 100644 --- a/test/include/test.hpp +++ b/test/include/test.hpp @@ -237,8 +237,9 @@ decltype(auto) make_lhs_expression(T&& lhs, Operator); #define TEST_EXPR_UNARY_OPERATOR(op, name) \ auto operator op() const { return make_lhs_expression(lhs, name{}); /* NOLINT */ } -//DEBUG -template struct TD; +// DEBUG +template +struct TD; template struct expression @@ -248,21 +249,17 @@ struct expression friend std::ostream& operator<<(std::ostream& s, const expression& self) { - //DEBUG - //s << __PRETTY_FUNCTION__ << std::endl; + // DEBUG + // s << __PRETTY_FUNCTION__ << std::endl; print_stream(s, self.lhs); s << " " << Operator::as_string() << " "; print_stream(s, self.rhs); return s; } - friend decltype(auto) get_value(const expression& e) { - return e.value(); - } + friend decltype(auto) get_value(const expression& e) { return e.value(); } - decltype(auto) value() const { - return Operator::call(get_value(lhs), get_value(rhs)); - }; + decltype(auto) value() const { return Operator::call(get_value(lhs), get_value(rhs)); }; TEST_FOREACH_UNARY_OPERATORS(TEST_EXPR_UNARY_OPERATOR) TEST_FOREACH_BINARY_OPERATORS(TEST_EXPR_BINARY_OPERATOR) @@ -271,21 +268,21 @@ struct expression template decltype(auto) make_expression(T&& lhs, U&& rhs, Operator) { - //rvalue references to pass by value + // rvalue references to pass by value return expression, std::decay_t, Operator>{lhs, rhs}; } template decltype(auto) make_lhs_expression(T&& lhs) { - //rvalue references to pass by value + // rvalue references to pass by value return lhs_expression>{lhs}; } template decltype(auto) make_lhs_expression(T&& lhs, Operator) { - //rvalue references to pass by value + // rvalue references to pass by value return lhs_expression, Operator>{lhs}; } From 0336f6d0f3a391ebb4f6b06aac967a36b37a9b91 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 4 Mar 2026 17:09:37 -0600 Subject: [PATCH 032/107] Undo serialization changes for another PR --- src/program.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/program.cpp b/src/program.cpp index fa433b48436..20418f88bfa 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -685,7 +685,7 @@ static std::string get_migraphx_version() program file version is for the data structure or format of the MXR file. Version should be bumped if any changes occur to the format of the MXR file. */ -const int program_file_version = 8; +const int program_file_version = 7; value program::to_value() const { @@ -711,7 +711,6 @@ value program::to_value() const if(ins->name() == "@literal") node["literal"] = migraphx::to_value(ins->get_literal()); node["operator"] = ins->get_operator().to_value(); - node["debug_symbols"] = ins->get_debug_symbols(); std::vector inputs; std::transform(ins->inputs().begin(), ins->inputs().end(), @@ -757,7 +756,6 @@ static void mod_from_val(module_ref mod, auto name = node.at("name").to(); auto fields = node.at("operator"); auto normalized = node.at("normalized").to(); - auto debug_symbols = node.at("debug_symbols").to>(); if(name == "@param") { @@ -810,7 +808,6 @@ static void mod_from_val(module_ref mod, output = mod->insert_instruction(mod->end(), op, inputs, module_inputs); } } - output->add_debug_symbols(debug_symbols); output->set_normalized(normalized); instructions[node.at("output").to()] = output; } From 278819643c6f9f6a3245ace3018999b98e5b4dbb Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 4 Mar 2026 17:44:55 -0600 Subject: [PATCH 033/107] Add more tests directly calling module::replace_ins --- test/debug_symbols_test.cpp | 253 +++++++++++++++++++++++++++++++++--- 1 file changed, 232 insertions(+), 21 deletions(-) diff --git a/test/debug_symbols_test.cpp b/test/debug_symbols_test.cpp index 9a0fe06d701..f6500ba011e 100644 --- a/test/debug_symbols_test.cpp +++ b/test/debug_symbols_test.cpp @@ -247,8 +247,6 @@ TEST_CASE(horiz_fusion_dot) // add1{a1} add2{a2} add{a0,a1,a2} add{a0,a1,a2} // \ / \ / // add0{a0} add{a0,a1,a2} -// | | -// pass pass // TEST_CASE(simplify_add_debug_symbols) { @@ -264,7 +262,7 @@ TEST_CASE(simplify_add_debug_symbols) m1.add_debug_symbols(sum2, {"onnx:add2"}); auto sum3 = m1.add_instruction(migraphx::make_op("add"), sum1, sum2); m1.add_debug_symbols(sum3, {"onnx:add0"}); - m1.add_instruction(pass_op{}, sum3); + m1.add_return({sum3}); } migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); @@ -280,7 +278,7 @@ TEST_CASE(simplify_add_debug_symbols) m2.add_debug_symbols(sum2, {"onnx:add0", "onnx:add1", "onnx:add2"}); auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum2, sum1); m2.add_debug_symbols(sum3, {"onnx:add0", "onnx:add1", "onnx:add2"}); - m2.add_instruction(pass_op{}, sum3); + m2.add_return({sum3}); } EXPECT(m1.sort() == m2.sort()); } @@ -331,8 +329,8 @@ TEST_CASE(replace_with_insref_debug_symbols) // x x y // | \ / // relu {relu} mul {add, relu} -// | y | -// | / pass +// | y +// | / // add {add} // | (relu becomes dead code, removed by DCE) // pass @@ -348,7 +346,7 @@ TEST_CASE(gather_replace_chain_debug_symbols) m1.add_debug_symbols(relu_x, {"onnx:relu"}); auto add_r = m1.add_instruction(migraphx::make_op("add"), relu_x, y); m1.add_debug_symbols(add_r, {"onnx:add"}); - m1.add_instruction(pass_op{}, add_r); + m1.add_return({add_r}); auto mul_r = m1.insert_instruction(add_r, migraphx::make_op("mul"), x, y); m1.replace_instruction(add_r, mul_r); @@ -361,7 +359,7 @@ TEST_CASE(gather_replace_chain_debug_symbols) auto y = m2.add_parameter("y", s); auto mul_r = m2.add_instruction(migraphx::make_op("mul"), x, y); m2.add_debug_symbols(mul_r, {"onnx:add", "onnx:relu"}); - m2.add_instruction(pass_op{}, mul_r); + m2.add_return({mul_r}); } EXPECT(m1.sort() == m2.sort()); } @@ -376,9 +374,7 @@ TEST_CASE(gather_replace_chain_debug_symbols) // add {add} mul{add,mul} mul{add,mul} // | 2 \ / // | / add{add,mul} -// mul {mul} | -// | pass -// pass +// mul {mul} // TEST_CASE(simplify_mul_add_debug_symbols) { @@ -391,7 +387,7 @@ TEST_CASE(simplify_mul_add_debug_symbols) m1.add_debug_symbols(sum, {"onnx:add"}); auto mul = m1.add_instruction(migraphx::make_op("mul"), sum, two); m1.add_debug_symbols(mul, {"onnx:mul"}); - m1.add_instruction(pass_op{}, mul); + m1.add_return({mul}); } migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); @@ -406,7 +402,7 @@ TEST_CASE(simplify_mul_add_debug_symbols) m2.add_debug_symbols(mul2, {"onnx:add", "onnx:mul"}); auto sum = m2.add_instruction(migraphx::make_op("add"), mul1, mul2); m2.add_debug_symbols(sum, {"onnx:add", "onnx:mul"}); - m2.add_instruction(pass_op{}, sum); + m2.add_return({sum}); } EXPECT(m1.sort() == m2.sort()); } @@ -419,11 +415,9 @@ TEST_CASE(simplify_mul_add_debug_symbols) // x c c // \ / | // div {div} recip {div} -// | x | -// pass \ | +// x | +// \ | // mul {div} -// | -// pass // TEST_CASE(simplify_div_const_debug_symbols) { @@ -436,7 +430,7 @@ TEST_CASE(simplify_div_const_debug_symbols) {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); auto div_r = m1.add_instruction(migraphx::make_op("div"), x, c); m1.add_debug_symbols(div_r, {"onnx:div"}); - m1.add_instruction(pass_op{}, div_r); + m1.add_return({div_r}); } migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); @@ -450,7 +444,7 @@ TEST_CASE(simplify_div_const_debug_symbols) m2.add_debug_symbols(recip, {"onnx:div"}); auto mul_r = m2.add_instruction(migraphx::make_op("mul"), x, recip); m2.add_debug_symbols(mul_r, {"onnx:div"}); - m2.add_instruction(pass_op{}, mul_r); + m2.add_return({mul_r}); } EXPECT(m1.sort() == m2.sort()); } @@ -526,12 +520,229 @@ TEST_CASE(debug_symbols_in_print) auto y = m.add_parameter("y", s); auto add = m.add_instruction(migraphx::make_op("add"), x, y); m.add_debug_symbols(add, {"sym_a", "sym_b"}); - m.add_instruction(pass_op{}, add); + m.add_return({add}); auto str = migraphx::to_string(m); EXPECT(str.find("# sym_a, sym_b #") != std::string::npos); } -// TODO make tests that directly call module::replace_instruction +// ----------------------------------------------------------------------- +// Direct replace_instruction / batch_replace_instruction tests +// ----------------------------------------------------------------------- + +// replace_instruction(ins, op, args) -- simple in-place, no splice chain. +// The replaced instruction retains its own debug symbols. +// +// Before: After: +// +// x y x y +// \ / \ / +// add {sym_add} mul {sym_add} +// | | +// @return @return +// +TEST_CASE(replace_op_args_no_splice) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto add = m.add_instruction(migraphx::make_op("add"), x, y); + m.add_debug_symbols(add, {"sym_add"}); + m.add_return({add}); + + m.replace_instruction(add, migraphx::make_op("mul"), x, y); + + EXPECT(add->name() == "mul"); + EXPECT(add->get_debug_symbols() == std::set{"sym_add"}); +} + +// replace_instruction(ins, op, args) -- splice chain propagation. +// relu only outputs to neg, so relu is in neg's splice chain. When neg +// is replaced with abs(x), relu's symbols propagate to the replacement. +// +// Before: After: +// +// x x +// | | +// relu {relu_sym} abs {neg_sym, relu_sym} +// | | +// neg {neg_sym} @return +// | +// @return +// +TEST_CASE(replace_op_args_spliced_inputs) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m; + auto x = m.add_parameter("x", s); + auto relu_x = m.add_instruction(migraphx::make_op("relu"), x); + m.add_debug_symbols(relu_x, {"relu_sym"}); + auto neg_r = m.add_instruction(migraphx::make_op("neg"), relu_x); + m.add_debug_symbols(neg_r, {"neg_sym"}); + m.add_return({neg_r}); + + m.replace_instruction(neg_r, migraphx::make_op("abs"), x); + + EXPECT(neg_r->name() == "abs"); + std::set expected{"neg_sym", "relu_sym"}; + EXPECT(neg_r->get_debug_symbols() == expected); +} + +// replace_instruction(ins, op, args, module_args) -- in-place replace with +// module arguments. The debug symbol logic is identical to the args-only +// overload; this test exercises the separate code path. +// +// Before: After: +// +// cond cond +// | | +// if {sym_if} if {sym_if} (different sub-modules) +// | | +// get[0] get[0] +// | | +// @return @return +// +TEST_CASE(replace_op_args_module_args) +{ + migraphx::shape cond_s{migraphx::shape::bool_type}; + migraphx::shape s{migraphx::shape::float_type, {5}}; + + migraphx::program p; + auto* mm = p.get_main_module(); + auto cond = mm->add_parameter("cond", cond_s); + + auto* then_mod1 = p.create_module("then1"); + std::vector d1 = {1, 2, 3, 4, 5}; + auto l1 = then_mod1->add_literal(migraphx::literal(s, d1)); + then_mod1->add_return({l1}); + + auto* else_mod1 = p.create_module("else1"); + std::vector d2 = {5, 4, 3, 2, 1}; + auto l2 = else_mod1->add_literal(migraphx::literal(s, d2)); + else_mod1->add_return({l2}); + + auto if_ins = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod1, else_mod1}); + mm->add_debug_symbols(if_ins, {"sym_if"}); + auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), if_ins); + mm->add_return({r}); + + auto* then_mod2 = p.create_module("then2"); + std::vector d3 = {10, 20, 30, 40, 50}; + auto l3 = then_mod2->add_literal(migraphx::literal(s, d3)); + then_mod2->add_return({l3}); + + auto* else_mod2 = p.create_module("else2"); + std::vector d4 = {50, 40, 30, 20, 10}; + auto l4 = else_mod2->add_literal(migraphx::literal(s, d4)); + else_mod2->add_return({l4}); + + mm->replace_instruction(if_ins, + migraphx::make_op("if"), + std::vector{cond}, + std::vector{then_mod2, else_mod2}); + + EXPECT(if_ins->get_debug_symbols() == std::set{"sym_if"}); +} + +// replace_instruction(ins, rep) -- redirect outputs from ins to rep. +// rep inherits the debug symbols of ins. +// +// Before: After (add becomes dead): +// +// x y x y +// |\ /| \ / +// | X | mul {sym_add} +// |/ \| | +// mul add {sym_add} @return +// | +// @return +// +TEST_CASE(replace_with_insref_no_splice) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto mul = m.add_instruction(migraphx::make_op("mul"), x, y); + auto add = m.add_instruction(migraphx::make_op("add"), x, y); + m.add_debug_symbols(add, {"sym_add"}); + m.add_return({add}); + + m.replace_instruction(add, mul); + + EXPECT(mul->get_debug_symbols() == std::set{"sym_add"}); +} + +// batch_replace_instruction -- single element with splice chain. +// Same topology as replace_op_args_spliced_inputs but via the batch API. +// +// Before: After: +// +// x x +// | | +// relu {relu_sym} abs {neg_sym, relu_sym} +// | | +// neg {neg_sym} @return +// | +// @return +// +TEST_CASE(batch_replace_single_with_splice) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m; + auto x = m.add_parameter("x", s); + auto relu_x = m.add_instruction(migraphx::make_op("relu"), x); + m.add_debug_symbols(relu_x, {"relu_sym"}); + auto neg_r = m.add_instruction(migraphx::make_op("neg"), relu_x); + m.add_debug_symbols(neg_r, {"neg_sym"}); + m.add_return({neg_r}); + + auto results = m.batch_replace_instruction({{neg_r, migraphx::make_op("abs"), {x}, {}}}); + + EXPECT(results.size() == 1); + std::set expected{"neg_sym", "relu_sym"}; + EXPECT(results[0]->get_debug_symbols() == expected); +} + +// batch_replace_instruction -- two simultaneous replacements. The batch +// collects symbols from ALL replaced instructions and propagates the +// merged set to every new-splice instruction. With individual replaces +// mul would only get {add1_sym} and div only {add2_sym}; the batch +// merges them so both receive the combined set. +// +// Before: After: +// +// x y x y +// |\ /| |\ /| +// | X | | X | +// |/ \| |/ \| +// add1 add2 mul div {add1_sym, add2_sym} +// {add1_sym} {add2_sym} \ / +// \ / @return +// @return +// +TEST_CASE(batch_replace_multi_merges_symbols) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto add1 = m.add_instruction(migraphx::make_op("add"), x, y); + m.add_debug_symbols(add1, {"add1_sym"}); + auto add2 = m.add_instruction(migraphx::make_op("add"), x, y); + m.add_debug_symbols(add2, {"add2_sym"}); + m.add_return({add1, add2}); + + auto results = m.batch_replace_instruction({ + {add1, migraphx::make_op("mul"), {x, y}, {}}, + {add2, migraphx::make_op("div"), {x, y}, {}}, + }); + + EXPECT(results.size() == 2); + std::set expected{"add1_sym", "add2_sym"}; + EXPECT(results[0]->get_debug_symbols() == expected); + EXPECT(results[1]->get_debug_symbols() == expected); +} int main(int argc, const char* argv[]) { test::run(argc, argv); } From 0c7506b7b82ab46febf349f3c91bc843f2fece0d Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 4 Mar 2026 17:45:41 -0600 Subject: [PATCH 034/107] Add onnx parse tests --- test/onnx/parse/debug_symbols_test.cpp | 68 ++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 test/onnx/parse/debug_symbols_test.cpp diff --git a/test/onnx/parse/debug_symbols_test.cpp b/test/onnx/parse/debug_symbols_test.cpp new file mode 100644 index 00000000000..40bc2876776 --- /dev/null +++ b/test/onnx/parse/debug_symbols_test.cpp @@ -0,0 +1,68 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +template +struct TD; + +TEST_CASE(debug_symbols_onnx_names) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}}); + auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 1, 3, 3}}); + auto conv_t = mm->add_instruction(migraphx::make_op("convolution_backwards"), l0, l1); + mm->add_debug_symbols(conv_t, {"conv1"}); + mm->add_return({conv_t}); + + migraphx::onnx_options options; + options.use_debug_symbols = true; + + auto prog = read_onnx("conv_transpose_test.onnx", options); + EXPECT(p == prog); +} + +TEST_CASE(debug_symbols_migx_names) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}}); + auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}}); + auto l2 = mm->add_parameter("2", {migraphx::shape::float_type, {1}}); + uint64_t axis = 1; + auto l3 = mm->add_instruction(migraphx::make_op("convolution"), l0, l1); + mm->add_debug_symbols(l3, {"migx_uid:Conv_3"}); + auto l4 = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2); + mm->add_debug_symbols(l4, {"migx_uid:Conv_3"}); + auto l5 = mm->add_instruction(migraphx::make_op("add"), l3, l4); + mm->add_debug_symbols(l5, {"migx_uid:Conv_3"}); + mm->add_return({l5}); + + migraphx::onnx_options options; + options.use_debug_symbols = true; + auto prog = read_onnx("conv_bias_test.onnx", options); + EXPECT(p == prog); +} From 95a5d9e07ecefab90802d8353725c5d6e37df219 Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 5 Mar 2026 15:43:22 -0600 Subject: [PATCH 035/107] Tidy fixes --- src/include/migraphx/module.hpp | 6 +++--- src/module.cpp | 6 +++--- test/include/test.hpp | 8 +------- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 8cea294529e..a82312ae188 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -135,10 +135,10 @@ struct MIGRAPHX_EXPORT module std::vector module_args; }; - /// Replaces an array of instructions within the same function to propertly handle debug symbols + /// Replaces an array of instructions within the same function to properly handle debug symbols /// propagation Returns vector of instruction_ref to replaced instructions - std::vector - batch_replace_instruction(const std::vector& replacers); + std::vector batch_replace_instruction( + const std::vector& replacers) MIGRAPHX_TIDY_CONST; instruction_ref remove_instruction(instruction_ref ins); instruction_ref remove_instructions(instruction_ref first, instruction_ref last); diff --git a/src/module.cpp b/src/module.cpp index 85c24a56024..5ee8cd1930e 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -371,9 +371,9 @@ static std::unordered_set gather_max_splice(const_module_ref m, return result; } -void propagate_debug_symbols(const_module_ref m, - instruction_ref ins, - std::unordered_set old_max_splice) +static void propagate_debug_symbols(const_module_ref m, + instruction_ref ins, + std::unordered_set old_max_splice) { // Remove ins from old_max_splice, if it is there. To prevent it being in both old_max_splice // and new_max_slice. diff --git a/test/include/test.hpp b/test/include/test.hpp index 3c98e9b9a59..c3ca931c6ae 100644 --- a/test/include/test.hpp +++ b/test/include/test.hpp @@ -237,10 +237,6 @@ decltype(auto) make_lhs_expression(T&& lhs, Operator); #define TEST_EXPR_UNARY_OPERATOR(op, name) \ auto operator op() const { return make_lhs_expression(lhs, name{}); /* NOLINT */ } -// DEBUG -template -struct TD; - template struct expression { @@ -249,8 +245,6 @@ struct expression friend std::ostream& operator<<(std::ostream& s, const expression& self) { - // DEBUG - // s << __PRETTY_FUNCTION__ << std::endl; print_stream(s, self.lhs); s << " " << Operator::as_string() << " "; print_stream(s, self.rhs); @@ -290,7 +284,7 @@ template struct lhs_expression { T lhs; - explicit lhs_expression(T e) : lhs(std::move(e)) {} + explicit lhs_expression(const T& e) : lhs(e) {} friend std::ostream& operator<<(std::ostream& s, const lhs_expression& self) { From 80f3e470abd8bd3a9f64553d8d3f65f090b705bd Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 5 Mar 2026 15:45:00 -0600 Subject: [PATCH 036/107] Licensing --- src/include/migraphx/onnx.hpp | 2 +- src/onnx/onnx.cpp | 2 +- src/program.cpp | 2 +- test/onnx/parse/debug_symbols_test.cpp | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/include/migraphx/onnx.hpp b/src/include/migraphx/onnx.hpp index 8247dc73835..13745994fd7 100644 --- a/src/include/migraphx/onnx.hpp +++ b/src/include/migraphx/onnx.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/onnx/onnx.cpp b/src/onnx/onnx.cpp index 7ea1edfb350..3b8be092d85 100644 --- a/src/onnx/onnx.cpp +++ b/src/onnx/onnx.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/program.cpp b/src/program.cpp index 20418f88bfa..1102231c2ed 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/onnx/parse/debug_symbols_test.cpp b/test/onnx/parse/debug_symbols_test.cpp index 40bc2876776..99e242004c9 100644 --- a/test/onnx/parse/debug_symbols_test.cpp +++ b/test/onnx/parse/debug_symbols_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From 3f6e1a77545a5a487d6e2476bc62d75fbc33a606 Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 5 Mar 2026 15:49:49 -0600 Subject: [PATCH 037/107] formatting --- src/include/migraphx/module.hpp | 4 +++- src/instruction.cpp | 4 +++- src/module.cpp | 16 +++++++++++----- src/onnx/onnx_parser.cpp | 16 ++++++++++++---- src/op/builder/clip.cpp | 4 +++- src/op/builder/gelu.cpp | 8 ++++++-- src/program.cpp | 12 +++++++----- src/split_reduce.cpp | 6 ++++-- src/targets/gpu/compile_ops.cpp | 12 +++++++++--- src/targets/gpu/prepare_reduce.cpp | 12 +++++++++--- test/debug_symbols_test.cpp | 20 ++++++++++---------- test/gpu/prepare_reduce.cpp | 4 +++- 12 files changed, 80 insertions(+), 38 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index a82312ae188..88a8aa83512 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -102,7 +102,9 @@ struct MIGRAPHX_EXPORT module template {}...)> instruction_ref insert_instruction(instruction_ref ins, operation op, Ts... args) - { return insert_instruction(ins, op, {args...}); } + { + return insert_instruction(ins, op, {args...}); + } instruction_ref insert_instruction(instruction_ref ins, const operation& op, std::vector args); diff --git a/src/instruction.cpp b/src/instruction.cpp index e58ca74cd58..e2599c2610f 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -194,7 +194,9 @@ const std::vector& instruction::outputs() const { return output const std::set& instruction::get_debug_symbols() const { return debug_symbols; } void instruction::add_debug_symbols(const std::set& symbols) -{ debug_symbols.insert(symbols.begin(), symbols.end()); } +{ + debug_symbols.insert(symbols.begin(), symbols.end()); +} void instruction::remove_debug_symbols() { debug_symbols.clear(); } diff --git a/src/module.cpp b/src/module.cpp index 5ee8cd1930e..b3a6d3bdf4e 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -132,7 +132,9 @@ struct module_impl const operation& get_operation(instruction_ref ins) { return ins->get_operator(); } module::module(const std::string& name) :impl(std::make_unique()) -{ impl->name = name; } +{ + impl->name = name; +} module::module(module&&) noexcept = default; module::~module() noexcept = default; @@ -202,7 +204,7 @@ void module::assign(const module& m) auto order = any_cast(ins->get_operator()).order; auto s = ins->get_shape(); copy_ins = impl->insert(impl->instructions.end(), - {builtin::param{name, order}, std::move(s), {}}); + {builtin::param{name, order}, std::move(s), {}}); impl->nparams++; } else if(ins->name() == "@outline") @@ -309,7 +311,9 @@ insert_generic_instructions(module& m, } instruction_ref module::add_instruction(const operation& op, std::vector args) -{ return insert_instruction(this->insert_end(), op, std::move(args)); } +{ + return insert_instruction(this->insert_end(), op, std::move(args)); +} instruction_ref module::insert_instruction(instruction_ref ins, const operation& op, @@ -327,7 +331,9 @@ instruction_ref module::insert_instruction(instruction_ref ins, instruction_ref module::add_instruction(const operation& op, std::vector args, std::vector module_args) -{ return insert_instruction(this->insert_end(), op, std::move(args), std::move(module_args)); } +{ + return insert_instruction(this->insert_end(), op, std::move(args), std::move(module_args)); +} instruction_ref module::insert_instruction(instruction_ref ins, const operation& op, @@ -1409,7 +1415,7 @@ std::unordered_map module::print( std::unordered_map names) const { const bool is_root = names.empty(); - int count = 0; + int count = 0; for(auto ins : iterator_for(*this)) { std::string var_name; diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index 8b805c466f9..3c753617aab 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -174,20 +174,28 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s */ instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name, std::vector inputs) const -{ return migraphx::add_common_op(*mod, make_op(op_name), std::move(inputs)); } +{ + return migraphx::add_common_op(*mod, make_op(op_name), std::move(inputs)); +} instruction_ref onnx_parser::node_info::add_instruction(const operation& op, const std::vector& args) const -{ return mod->add_instruction(op, args); } +{ + return mod->add_instruction(op, args); +} instruction_ref onnx_parser::node_info::add_instruction(const operation& op, const std::vector& args, const std::vector& mods) const -{ return mod->add_instruction(op, args, mods); } +{ + return mod->add_instruction(op, args, mods); +} instruction_ref onnx_parser::node_info::add_literal(literal l) const -{ return mod->add_literal(std::move(l)); } +{ + return mod->add_literal(std::move(l)); +} onnx_parser::onnx_parser() { diff --git a/src/op/builder/clip.cpp b/src/op/builder/clip.cpp index 9d986423a9c..f3f5e147301 100644 --- a/src/op/builder/clip.cpp +++ b/src/op/builder/clip.cpp @@ -37,7 +37,9 @@ struct clip : op_builder { template static auto reflect(Self&, F) - { return pack(); } + { + return pack(); + } std::vector insert(module& m, instruction_ref ins, const std::vector& args) const diff --git a/src/op/builder/gelu.cpp b/src/op/builder/gelu.cpp index 60258b3a905..b80a554a600 100644 --- a/src/op/builder/gelu.cpp +++ b/src/op/builder/gelu.cpp @@ -59,7 +59,9 @@ struct gelu_erf : op_builder { template static auto reflect(Self&, F) - { return pack(); } + { + return pack(); + } std::vector insert(module& m, instruction_ref ins, const std::vector& args) const @@ -130,7 +132,9 @@ struct gelu_split : op_builder { template static auto reflect(Self&, F) - { return pack(); } + { + return pack(); + } std::vector insert(module& m, instruction_ref ins, const std::vector& args) const diff --git a/src/program.cpp b/src/program.cpp index 1102231c2ed..7b5cabbc0b1 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -1043,10 +1043,10 @@ void program::perf_report( os << percentile_90_time << "ms, " << percentile_95_time << "ms, " << percentile_99_time << "ms)" << std::endl; os << "Total instructions time: " << total_instruction_time << "ms" << std::endl; - os << "Overhead time: " << overhead_time << "ms" - << ", " << calculate_overhead_time << "ms" << std::endl; - os << "Overhead: " << std::round(overhead_percent) << "%" - << ", " << std::round(calculate_overhead_percent) << "%" << std::endl; + os << "Overhead time: " << overhead_time << "ms" << ", " << calculate_overhead_time << "ms" + << std::endl; + os << "Overhead: " << std::round(overhead_percent) << "%" << ", " + << std::round(calculate_overhead_percent) << "%" << std::endl; } void program::debug_print() const { std::cout << *this << std::endl; } @@ -1371,7 +1371,9 @@ program& program::sort() } bool operator==(const program& x, const program& y) -{ return migraphx::to_string(x) == migraphx::to_string(y); } +{ + return migraphx::to_string(x) == migraphx::to_string(y); +} std::ostream& operator<<(std::ostream& os, const program& p) { diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index 9fdf9bd8485..41065c4770b 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -60,7 +60,7 @@ struct split_fused_reduce if(mods.size() != 1) MIGRAPHX_THROW("should have one submodule."); const auto* sm = mods.front(); - auto names = sm->get_parameter_names(); + auto names = sm->get_parameter_names(); check_shapes{inputs, *this}.has(names.size()).same_ndims(); auto result = @@ -75,7 +75,9 @@ struct split_fused_reduce MIGRAPHX_REGISTER_OP(split_fused_reduce); static bool is_reduce(const instruction& ins) -{ return contains(ins.name(), "reduce") or ins.name() == "argmin" or ins.name() == "argmax"; } +{ + return contains(ins.name(), "reduce") or ins.name() == "argmin" or ins.name() == "argmax"; +} namespace { struct splitter diff --git a/src/targets/gpu/compile_ops.cpp b/src/targets/gpu/compile_ops.cpp index 918e6ca8b5b..8685469cda2 100644 --- a/src/targets/gpu/compile_ops.cpp +++ b/src/targets/gpu/compile_ops.cpp @@ -102,15 +102,21 @@ struct dynamic_code_object_op template static auto reflect(Self& self, F f) - { return pack(f(self.pre_op, "pre_op")); } + { + return pack(f(self.pre_op, "pre_op")); + } std::string name() const { return "gpu::dynamic_code_object_op"; } shape compute_shape(const std::vector& inputs, const std::vector& mods) const - { return pre_op.compute_shape(inputs, mods); } + { + return pre_op.compute_shape(inputs, mods); + } std::vector output_alias(const std::vector& shapes) const - { return {shapes.size() - 1}; } + { + return {shapes.size() - 1}; + } std::unordered_map build_param_map(const std::vector& args, const_module_ref mod) const { diff --git a/src/targets/gpu/prepare_reduce.cpp b/src/targets/gpu/prepare_reduce.cpp index d53aedcf501..46c8d6fa7ad 100644 --- a/src/targets/gpu/prepare_reduce.cpp +++ b/src/targets/gpu/prepare_reduce.cpp @@ -64,7 +64,9 @@ struct arg_reduce template static auto reflect(Self& self, F f) - { return pack(f(self.op, "op")); } + { + return pack(f(self.op, "op")); + } std::string name() const { return "gpu::arg_reduce"; } @@ -83,12 +85,16 @@ struct make_indices template static auto reflect(Self& self, F f) - { return pack(f(self.size, "size")); } + { + return pack(f(self.size, "size")); + } std::string name() const { return "gpu::make_indices"; } shape compute_shape(const std::vector&) const - { return shape{shape::uint32_type, {size}}; } + { + return shape{shape::uint32_type, {size}}; + } }; MIGRAPHX_REGISTER_OP(make_indices); diff --git a/test/debug_symbols_test.cpp b/test/debug_symbols_test.cpp index f6500ba011e..d5f275e4c57 100644 --- a/test/debug_symbols_test.cpp +++ b/test/debug_symbols_test.cpp @@ -57,9 +57,9 @@ TEST_CASE(pw_double_add) migraphx::program p1; { auto* mm = p1.get_main_module(); - auto x = mm->add_parameter("x", s); - auto y = mm->add_parameter("y", s); - auto z = mm->add_parameter("z", s); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); migraphx::instruction_ref add1 = mm->add_instruction(migraphx::make_op("add"), x, y); mm->add_debug_symbols(add1, {"add1"}); migraphx::instruction_ref add2 = mm->add_instruction(migraphx::make_op("add"), add1, z); @@ -71,9 +71,9 @@ TEST_CASE(pw_double_add) migraphx::program p2; { auto* mm = p2.get_main_module(); - auto x = mm->add_parameter("x", s); - auto y = mm->add_parameter("y", s); - auto z = mm->add_parameter("z", s); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); auto fadd = add_pointwise(p2, "main:pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) { auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); @@ -490,10 +490,10 @@ TEST_CASE(pw_triple_add_fused) migraphx::program p2; { auto* mm = p2.get_main_module(); - auto x = mm->add_parameter("x", s); - auto y = mm->add_parameter("y", s); - auto z = mm->add_parameter("z", s); - auto w = mm->add_parameter("w", s); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto w = mm->add_parameter("w", s); auto fadd = add_pointwise(p2, "main:pointwise0", {x, y, z, w}, [=](auto* pm, const auto& inputs) { auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); diff --git a/test/gpu/prepare_reduce.cpp b/test/gpu/prepare_reduce.cpp index 39b97b3a847..e45f01c134f 100644 --- a/test/gpu/prepare_reduce.cpp +++ b/test/gpu/prepare_reduce.cpp @@ -31,7 +31,9 @@ #include static void run_pass(migraphx::module& m) -{ migraphx::run_passes(m, {migraphx::gpu::prepare_reduce{}, migraphx::dead_code_elimination{}}); } +{ + migraphx::run_passes(m, {migraphx::gpu::prepare_reduce{}, migraphx::dead_code_elimination{}}); +} // Helper to add the arg_reduce pattern: make_indices -> arg_reduce -> get_tuple_elem static migraphx::instruction_ref add_arg_reduce(migraphx::module& m, From e2ea3aec9f52612dbd141d465918d934e00f9c87 Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 5 Mar 2026 16:09:14 -0600 Subject: [PATCH 038/107] parser get_added_instructions only if used --- src/onnx/onnx_parser.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index 3c753617aab..fbd32608e68 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -613,7 +613,11 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini result.begin(), std::inserter(instructions, instructions.end()), [](auto&& x, auto&& y) { return std::make_pair(x, y); }); - auto added_instructions = get_added_instructions(args, result); + std::vector added_instructions; + if(this->use_debug_symbols or enabled(MIGRAPHX_TRACE_ONNX_PARSER{})) + { + added_instructions = get_added_instructions(args, result); + } if(this->use_debug_symbols) { std::string debug_symbol = From 6c0f1420a0ec538cb1a79b9171fece533d15ac23 Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 5 Mar 2026 16:17:49 -0600 Subject: [PATCH 039/107] Fixes --- src/include/migraphx/module.hpp | 3 --- src/module.cpp | 4 +++- src/onnx/onnx_parser.cpp | 7 ------- src/simplify_algebra.cpp | 2 -- test/debug_symbols_test.cpp | 2 +- 5 files changed, 4 insertions(+), 14 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 88a8aa83512..14bf3cc4def 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -374,7 +374,6 @@ struct MIGRAPHX_EXPORT module friend bool operator!=(const module& x, const module& y) { return not(x == y); } friend struct program; - friend struct scoped_debug_symbols; private: void set_name(const std::string& name); @@ -383,8 +382,6 @@ struct MIGRAPHX_EXPORT module const module& pmod, instruction_ref ins, ins_dep_map& deps) const; - void propagate_replace_debug_symbols(instruction_ref rep_ins, - const std::set& debug_symbols); std::unique_ptr impl; }; diff --git a/src/module.cpp b/src/module.cpp index b3a6d3bdf4e..03a1ce7d210 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -160,6 +160,8 @@ bool module::has_debug_symbols() const { return impl->num_ins_with_debug_symbols void module::add_debug_symbols(instruction_ref ins, const std::set& symbols) const { + if(symbols.empty()) + return; if(ins->get_debug_symbols().empty()) { impl->num_ins_with_debug_symbols++; @@ -231,7 +233,7 @@ void module::assign(const module& m) copy_ins = add_instruction(ins->get_operator(), copy_inputs, module_args); } } - + copy_ins->add_debug_symbols(ins->get_debug_symbols()); ins_map[ins] = copy_ins; } } diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index fbd32608e68..4c221033b32 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -631,13 +631,6 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini { mod->debug_print(added_instructions); } - - output_num = std::min(output_num, result.size()); - std::transform(node.output().begin(), - node.output().begin() + output_num, - result.begin(), - std::inserter(instructions, instructions.end()), - [](auto&& x, auto&& y) { return std::make_pair(x, y); }); } // Find instructions corresponding to the output diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 8d74c1b139e..7076c9060a9 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -1708,8 +1708,6 @@ struct find_conv_dot_horiz_fusion int64_t offset = 0; for(auto arg : range(start, last)) { - auto outputs = arg->outputs(); - int64_t len = arg->get_shape().lens()[axis]; auto slice_op = make_op( "slice", {{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}); diff --git a/test/debug_symbols_test.cpp b/test/debug_symbols_test.cpp index d5f275e4c57..700cf9bb5c1 100644 --- a/test/debug_symbols_test.cpp +++ b/test/debug_symbols_test.cpp @@ -510,7 +510,7 @@ TEST_CASE(pw_triple_add_fused) // output using the expected comment format produced by instruction::print. // // Printed output includes: -// @2 = add(@0,@1) -> float_type, {2, 3} /* sym_a, sym_b */ +// @2 = add(@0,@1) -> float_type, {2, 3} # sym_a, sym_b # // TEST_CASE(debug_symbols_in_print) { From ac241183d0dff90cbd84ffb4fde6ebb2e832ba32 Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 5 Mar 2026 16:23:46 -0600 Subject: [PATCH 040/107] more copilot fixes --- src/module.cpp | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index 03a1ce7d210..fa1dbf3337e 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -95,7 +95,8 @@ struct module_impl changed.notify(); instructions.clear(); instruction_set.clear(); - nparams = 0; + nparams = 0; + num_ins_with_debug_symbols = 0; } void push_front(const instruction& ins) { insert(instructions.begin(), ins); } @@ -574,6 +575,10 @@ instruction_ref module::remove_instruction(instruction_ref ins) { assert(has_instruction(ins)); assert(ins->outputs().empty()); + if(not ins->get_debug_symbols().empty() and impl->num_ins_with_debug_symbols > 0) + { + impl->num_ins_with_debug_symbols--; + } ins->clear_arguments(); return impl->erase(ins); } @@ -584,7 +589,13 @@ instruction_ref module::remove_instructions(instruction_ref first, instruction_r return first; // TODO: Check every element assert(has_instruction(first)); - std::for_each(first, last, [&](instruction& ins) { ins.clear_arguments(); }); + std::for_each(first, last, [&](instruction& ins) { + if(not ins.get_debug_symbols().empty() and impl->num_ins_with_debug_symbols > 0) + { + impl->num_ins_with_debug_symbols--; + } + ins.clear_arguments(); + }); assert(std::all_of(first, last, [&](const instruction& ins) { return ins.outputs().empty(); })); return impl->erase(first, last); } From 821aab6b756a35159e785fa27816e906994a2107 Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 6 Mar 2026 11:33:46 -0600 Subject: [PATCH 041/107] Cleanup env --- src/instruction.cpp | 1 - src/onnx/onnx_parser.cpp | 5 ++++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/instruction.cpp b/src/instruction.cpp index e2599c2610f..b6b26a2222c 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -28,7 +28,6 @@ #include #include #include -#include #include #include #include diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index 4c221033b32..c28f1d59fde 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -327,15 +327,18 @@ get_added_instructions(const std::vector& args, { // Print instructions added by the parser not in args std::vector added_instructions; + // Set for checking added_instructions faster + std::unordered_set visit_set; fix([&](auto self, const auto& r) { for(auto ins : r) { if(contains(args, ins)) continue; - if(contains(added_instructions, ins)) + if(contains(visit_set, ins)) continue; self(ins->inputs()); added_instructions.push_back(ins); + visit_set.insert(ins); } })(result); return added_instructions; From 3ba71eccca4600bb04977a731052fe19b0695a7c Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 6 Mar 2026 11:35:44 -0600 Subject: [PATCH 042/107] Add notify --- src/module.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/module.cpp b/src/module.cpp index fa1dbf3337e..89940348484 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -234,7 +234,7 @@ void module::assign(const module& m) copy_ins = add_instruction(ins->get_operator(), copy_inputs, module_args); } } - copy_ins->add_debug_symbols(ins->get_debug_symbols()); + add_debug_symbols(copy_ins, ins->get_debug_symbols()); ins_map[ins] = copy_ins; } } @@ -512,6 +512,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref std::vector module::batch_replace_instruction(const std::vector& replacers) { + impl->changed.notify(); std::vector ret; std::unordered_set old_max_splices; if(has_debug_symbols()) @@ -531,6 +532,8 @@ module::batch_replace_instruction(const std::vector& repla std::unordered_set new_max_splices; for(const auto& replacer : replacers) { + assert(has_instruction(replacer.ins)); + assert(not starts_with(replacer.op.name(), "@")); auto out_shape = compute_shape(replacer.op, replacer.args, replacer.module_args); instruction::replace( replacer.ins, replacer.op, out_shape, replacer.args, replacer.module_args); From cfe37d69050d7eea54959b49491ffc3c6030f3dc Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 6 Mar 2026 11:36:31 -0600 Subject: [PATCH 043/107] Module const func --- src/module.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/module.cpp b/src/module.cpp index 89940348484..b48c7c38caa 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -510,7 +510,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref } std::vector -module::batch_replace_instruction(const std::vector& replacers) +module::batch_replace_instruction(const std::vector& replacers) MIGRAPHX_TIDY_CONST { impl->changed.notify(); std::vector ret; From a5c6706213e71f39775a4f6ffad9a7c7c4ab5a1f Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 6 Mar 2026 11:43:35 -0600 Subject: [PATCH 044/107] Update gitignore --- .gitignore | 17 +++++++++++++++++ src/op/builder/celu.cpp | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index fa78b2ef667..13d85644f5b 100644 --- a/.gitignore +++ b/.gitignore @@ -54,10 +54,27 @@ _toc.yml #==============================================================================# # Directories to ignore (do not add trailing '/'s, they skip symlinks). #==============================================================================# +# Nested build directory +build*/ +debug_docker_build/ +docker_build/ + +# Cursor files +.cursor/ + +# Ctags files +.ctagsignore # Downloaded models test/onnx/models +# Python virtual env +venv_rbuild/ + +# ONNX test generated files +test/onnx/load_save_arg.msgpack +test/onnx/migraphx_api_load_save_argument.msgpack + # VS2017 and VSCode config files. .vscode .vs diff --git a/src/op/builder/celu.cpp b/src/op/builder/celu.cpp index 0c08f79fad1..f849dfec4ef 100644 --- a/src/op/builder/celu.cpp +++ b/src/op/builder/celu.cpp @@ -1,6 +1,6 @@ /* The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From cb0a3e8280ca150c43a4eb4a3932258c8ae6d894 Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 6 Mar 2026 11:44:57 -0600 Subject: [PATCH 045/107] formatting again --- src/module.cpp | 4 ++-- test/debug_symbols_test.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index b48c7c38caa..98433dbd35d 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -509,8 +509,8 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref return rep; } -std::vector -module::batch_replace_instruction(const std::vector& replacers) MIGRAPHX_TIDY_CONST +std::vector module::batch_replace_instruction( + const std::vector& replacers) MIGRAPHX_TIDY_CONST { impl->changed.notify(); std::vector ret; diff --git a/test/debug_symbols_test.cpp b/test/debug_symbols_test.cpp index 700cf9bb5c1..2c5310d9c11 100644 --- a/test/debug_symbols_test.cpp +++ b/test/debug_symbols_test.cpp @@ -510,7 +510,7 @@ TEST_CASE(pw_triple_add_fused) // output using the expected comment format produced by instruction::print. // // Printed output includes: -// @2 = add(@0,@1) -> float_type, {2, 3} # sym_a, sym_b # +// @2 = add(@0,@1) -> float_type, {2, 3} # sym_a, sym_b # // TEST_CASE(debug_symbols_in_print) { From 2a0365b5488672bf20147ac5ced56c260eee10c3 Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 9 Mar 2026 14:25:46 -0500 Subject: [PATCH 046/107] join_strings to use const ref instead --- src/include/migraphx/stringutils.hpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/include/migraphx/stringutils.hpp b/src/include/migraphx/stringutils.hpp index c4618036f5c..492c760ee34 100644 --- a/src/include/migraphx/stringutils.hpp +++ b/src/include/migraphx/stringutils.hpp @@ -71,16 +71,18 @@ inline bool ends_with(const std::string& value, const std::string& suffix) } template -inline std::string join_strings(Strings strings, const std::string& delim) +inline std::string join_strings(const typename std::remove_reference::type& strings, + const std::string& delim) { auto it = strings.begin(); if(it == strings.end()) return ""; auto nit = std::next(it); - return std::accumulate(nit, strings.end(), *it, [&](std::string x, std::string y) { - return std::move(x) + delim + std::move(y); - }); + return std::accumulate( + nit, strings.end(), *it, [&](const std::string& x, const std::string& y) { + return x + delim + y; + }); } inline std::vector split_string(const std::string& s, char delim) From 19f61659da4d4daf42ae0c2015339ca39ac86728 Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 9 Mar 2026 14:28:36 -0500 Subject: [PATCH 047/107] Comment formatting --- src/include/migraphx/module.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 14bf3cc4def..7b0e5e11595 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -138,7 +138,7 @@ struct MIGRAPHX_EXPORT module }; /// Replaces an array of instructions within the same function to properly handle debug symbols - /// propagation Returns vector of instruction_ref to replaced instructions + /// propagation. Returns vector of instruction_ref to replaced instructions. std::vector batch_replace_instruction( const std::vector& replacers) MIGRAPHX_TIDY_CONST; From 7a4b3f43298f0e862f85837655be290b5264a09b Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 9 Mar 2026 14:29:09 -0500 Subject: [PATCH 048/107] Remove redunant test --- test/debug_symbols_test.cpp | 57 ------------------------------------- 1 file changed, 57 deletions(-) diff --git a/test/debug_symbols_test.cpp b/test/debug_symbols_test.cpp index 2c5310d9c11..2f779a61d10 100644 --- a/test/debug_symbols_test.cpp +++ b/test/debug_symbols_test.cpp @@ -449,63 +449,6 @@ TEST_CASE(simplify_div_const_debug_symbols) EXPECT(m1.sort() == m2.sort()); } -// Three sequential adds fused into a single pointwise op via fuse_pointwise. -// All three ONNX node symbols should appear on the fused pointwise instruction. -// Extends pw_double_add to a longer chain. -// -// Before: After: -// -// x y x y z w -// \ / \ | | / -// add {add1} pointwise {add1, add2, add3} -// | z | -// | / @return -// add {add2} -// | w -// | / -// add {add3} -// | -// @return -// -TEST_CASE(pw_triple_add_fused) -{ - migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - migraphx::program p1; - { - auto* mm = p1.get_main_module(); - auto x = mm->add_parameter("x", s); - auto y = mm->add_parameter("y", s); - auto z = mm->add_parameter("z", s); - auto w = mm->add_parameter("w", s); - auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); - mm->add_debug_symbols(add1, {"onnx:add1"}); - auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, z); - mm->add_debug_symbols(add2, {"onnx:add2"}); - auto add3 = mm->add_instruction(migraphx::make_op("add"), add2, w); - mm->add_debug_symbols(add3, {"onnx:add3"}); - mm->add_return({add3}); - } - migraphx::run_passes(p1, {migraphx::fuse_pointwise{}, migraphx::dead_code_elimination{}}); - - migraphx::program p2; - { - auto* mm = p2.get_main_module(); - auto x = mm->add_parameter("x", s); - auto y = mm->add_parameter("y", s); - auto z = mm->add_parameter("z", s); - auto w = mm->add_parameter("w", s); - auto fadd = - add_pointwise(p2, "main:pointwise0", {x, y, z, w}, [=](auto* pm, const auto& inputs) { - auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); - auto add2 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]); - return pm->add_instruction(migraphx::make_op("add"), add2, inputs[3]); - }); - mm->add_debug_symbols(fadd, {"onnx:add1", "onnx:add2", "onnx:add3"}); - mm->add_return({fadd}); - } - EXPECT(p1 == p2); -} - // Verifies that debug symbols appear in the module's printed/serialized // output using the expected comment format produced by instruction::print. // From 287b79b6309ad14c2e221612d8ee110054c2a178 Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 9 Mar 2026 15:24:49 -0500 Subject: [PATCH 049/107] More tests --- src/include/migraphx/stringutils.hpp | 3 +- test/debug_symbols_test.cpp | 93 ++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 2 deletions(-) diff --git a/src/include/migraphx/stringutils.hpp b/src/include/migraphx/stringutils.hpp index 492c760ee34..022e11571bf 100644 --- a/src/include/migraphx/stringutils.hpp +++ b/src/include/migraphx/stringutils.hpp @@ -71,8 +71,7 @@ inline bool ends_with(const std::string& value, const std::string& suffix) } template -inline std::string join_strings(const typename std::remove_reference::type& strings, - const std::string& delim) +inline std::string join_strings(const Strings& strings, const std::string& delim) { auto it = strings.begin(); if(it == strings.end()) diff --git a/test/debug_symbols_test.cpp b/test/debug_symbols_test.cpp index 2f779a61d10..ff5e5297b9f 100644 --- a/test/debug_symbols_test.cpp +++ b/test/debug_symbols_test.cpp @@ -688,4 +688,97 @@ TEST_CASE(batch_replace_multi_merges_symbols) EXPECT(results[1]->get_debug_symbols() == expected); } +// ----------------------------------------------------------------------- +// module::remove_debug_symbols tests +// ----------------------------------------------------------------------- + +TEST_CASE(remove_single_symbol) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto add = m.add_instruction(migraphx::make_op("add"), x, y); + m.add_debug_symbols(add, {"sym_a", "sym_b"}); + m.add_return({add}); + + EXPECT(m.has_debug_symbols()); + m.remove_debug_symbols(add); + + EXPECT(add->get_debug_symbols().empty()); + EXPECT(not m.has_debug_symbols()); +} + +TEST_CASE(remove_noop_no_symbols) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto add = m.add_instruction(migraphx::make_op("add"), x, y); + m.add_return({add}); + + EXPECT(not m.has_debug_symbols()); + m.remove_debug_symbols(add); + + EXPECT(add->get_debug_symbols().empty()); + EXPECT(not m.has_debug_symbols()); +} + +TEST_CASE(remove_one_of_two) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto add1 = m.add_instruction(migraphx::make_op("add"), x, y); + m.add_debug_symbols(add1, {"sym_add1"}); + auto add2 = m.add_instruction(migraphx::make_op("add"), add1, y); + m.add_debug_symbols(add2, {"sym_add2"}); + m.add_return({add2}); + + EXPECT(m.has_debug_symbols()); + m.remove_debug_symbols(add1); + + EXPECT(add1->get_debug_symbols().empty()); + EXPECT(add2->get_debug_symbols() == std::set{"sym_add2"}); + EXPECT(m.has_debug_symbols()); +} + +TEST_CASE(remove_then_re_add) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto add = m.add_instruction(migraphx::make_op("add"), x, y); + m.add_debug_symbols(add, {"old_sym"}); + m.add_return({add}); + + m.remove_debug_symbols(add); + EXPECT(add->get_debug_symbols().empty()); + EXPECT(not m.has_debug_symbols()); + + m.add_debug_symbols(add, {"new_sym"}); + EXPECT(add->get_debug_symbols() == std::set{"new_sym"}); + EXPECT(m.has_debug_symbols()); +} + +TEST_CASE(remove_idempotent) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto add = m.add_instruction(migraphx::make_op("add"), x, y); + m.add_debug_symbols(add, {"sym"}); + m.add_return({add}); + + m.remove_debug_symbols(add); + m.remove_debug_symbols(add); + + EXPECT(add->get_debug_symbols().empty()); + EXPECT(not m.has_debug_symbols()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From 3f1ec9d16933282a4c71c563e624e597a65f5a46 Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 9 Mar 2026 19:18:18 -0500 Subject: [PATCH 050/107] Add tests for module::remove_instruction changes --- test/debug_symbols_test.cpp | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/test/debug_symbols_test.cpp b/test/debug_symbols_test.cpp index ff5e5297b9f..34f85ed4591 100644 --- a/test/debug_symbols_test.cpp +++ b/test/debug_symbols_test.cpp @@ -136,7 +136,23 @@ TEST_CASE(pw_used_twice_fused) EXPECT(p1.sort() == p2.sort()); } -// To check that the debug symbols don't propagate above the fusion +// Debug symbols should not propagate above the fusion boundary. +// The gemm (dot) keeps its own symbol; only the fused adds merge. +// +// Before: After: +// +// x a x a +// \ / \ / +// dot {gemm1} dot {gemm1} +// | y / +// | / gemm y z +// add {add1} \ | / +// | z pointwise {add1, add2} +// | / | +// add {add2} @return +// | +// @return +// TEST_CASE(gemm_add_add) { migraphx::shape s1{migraphx::shape::float_type, {2, 3}}; From 162c811360e5b52866c42347c1ffa2f171b85e57 Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 9 Mar 2026 19:41:57 -0500 Subject: [PATCH 051/107] Add debug symbols to MXR --- src/program.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/program.cpp b/src/program.cpp index 7b5cabbc0b1..0a4221e0211 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -685,7 +685,7 @@ static std::string get_migraphx_version() program file version is for the data structure or format of the MXR file. Version should be bumped if any changes occur to the format of the MXR file. */ -const int program_file_version = 7; +const int program_file_version = 8; value program::to_value() const { @@ -730,12 +730,11 @@ value program::to_value() const [&](auto mod_ref) { return mod_ref->name(); }); node["module_inputs"] = module_inputs; } - + nodes["debug_symbols"] = migraphx::to_value(ins->get_debug_symbols()); nodes.push_back(node); }, names); mod_val["nodes"] = nodes; - module_vals[mod->name()] = mod_val; } @@ -809,6 +808,7 @@ static void mod_from_val(module_ref mod, } } output->set_normalized(normalized); + output->add_debug_symbols(node.at("debug_symbols").to>()); instructions[node.at("output").to()] = output; } } From acffe7e269dd73f0dd0d5e0d0fe17c354f26f979 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 10 Mar 2026 12:42:26 -0500 Subject: [PATCH 052/107] Tidy fix --- src/targets/gpu/compile_gen.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/targets/gpu/compile_gen.cpp b/src/targets/gpu/compile_gen.cpp index 5a08677220b..7adf6c69230 100644 --- a/src/targets/gpu/compile_gen.cpp +++ b/src/targets/gpu/compile_gen.cpp @@ -303,9 +303,9 @@ std::size_t find_fast_axis(const std::vector& inputs) return it - permutation.begin(); } -std::string make_transformer_args(std::vector transformers) +std::string make_transformer_args(const std::vector& transformers) { - return join_strings(std::move(transformers), ", "); + return join_strings(transformers, ", "); } static void generate_pointwise(cpp_generator& gg, From 3e014fba885c488c9f1d354f2f77f2ac1c3f33d3 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 10 Mar 2026 15:44:28 -0500 Subject: [PATCH 053/107] Revert stringutils change --- src/include/migraphx/stringutils.hpp | 20 +++------ src/targets/gpu/compile_gen.cpp | 63 ++++++---------------------- 2 files changed, 18 insertions(+), 65 deletions(-) diff --git a/src/include/migraphx/stringutils.hpp b/src/include/migraphx/stringutils.hpp index 022e11571bf..47c13fd3556 100644 --- a/src/include/migraphx/stringutils.hpp +++ b/src/include/migraphx/stringutils.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -71,17 +71,16 @@ inline bool ends_with(const std::string& value, const std::string& suffix) } template -inline std::string join_strings(const Strings& strings, const std::string& delim) +inline std::string join_strings(Strings strings, const std::string& delim) { auto it = strings.begin(); if(it == strings.end()) return ""; auto nit = std::next(it); - return std::accumulate( - nit, strings.end(), *it, [&](const std::string& x, const std::string& y) { - return x + delim + y; - }); + return std::accumulate(nit, strings.end(), *it, [&](std::string x, std::string y) { + return std::move(x) + delim + std::move(y); + }); } inline std::vector split_string(const std::string& s, char delim) @@ -225,15 +224,6 @@ inline auto to_string(const T& x) return ss.str(); } -template -inline auto to_hex_float(const T& x) - -> decltype((std::declval() << x), std::string{}) -{ - std::stringstream ss; - ss << std::hexfloat << x; - return ss.str(); -} - } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/compile_gen.cpp b/src/targets/gpu/compile_gen.cpp index 7adf6c69230..ced8ab03bae 100644 --- a/src/targets/gpu/compile_gen.cpp +++ b/src/targets/gpu/compile_gen.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -247,8 +247,8 @@ tile tile::elements(const std::vector& inputs, std::size_t noutputs) auto tile_size = dim1 * dim2; result.ntiles = s.elements() / tile_size; - // equivalent to dim2 * (dim1 + 1) to avoid bank conflicts - auto tile_bytes = (tile_size + dim2) * s.type_size(); + // equivalent to dim1 * (dim2 + 1) to avoid bank conflicts + auto tile_bytes = (tile_size + dim1) * s.type_size(); if(tile_bytes > 65536) return {}; @@ -303,9 +303,9 @@ std::size_t find_fast_axis(const std::vector& inputs) return it - permutation.begin(); } -std::string make_transformer_args(const std::vector& transformers) +std::string make_transformer_args(std::vector transformers) { - return join_strings(transformers, ", "); + return join_strings(std::move(transformers), ", "); } static void generate_pointwise(cpp_generator& gg, @@ -409,31 +409,6 @@ void reduce_op::set(instruction_ref ins, const operation& op) set(rop.name(), input, output); read = "compose(array_apply(" + read + "), MIGRAPHX_LIFT(make_array))"; } - else if(op.name() == "gpu::arg_reduce") - { - // extract the inner argmin/argmax operation - auto inner_op = from_value(op.to_value().at("op")); - auto inner_v = inner_op.to_value(); - bool select_last = inner_v.get("select_last_index", false); - std::string select_last_str = select_last ? "true" : "false"; - - if(inner_op.name() == "argmin") - { - reduction = "op::argmin<" + select_last_str + ">{}"; - init = "make_tuple(highest{}, index_int{0})"; - } - else if(inner_op.name() == "argmax") - { - reduction = "op::argmax<" + select_last_str + ">{}"; - init = "make_tuple(lowest{}, index_int{0})"; - } - else - { - MIGRAPHX_THROW("Unsupported arg operation"); - } - // read creates tuples from (value, index), cast index to index_int - read = "[](auto val, auto idx) { return make_tuple(val, static_cast(idx)); }"; - } else { set(op.name(), ins->inputs().front()->get_shape(), ins->get_shape()); @@ -476,17 +451,6 @@ static void preload_params(module& m) } } -static std::vector get_rlens(const module& m) -{ - auto reduce = std::find_if( - m.begin(), m.end(), [&](const auto& ins) { return contains(ins.name(), "reduce"); }); - if(reduce == m.end()) - MIGRAPHX_THROW("Missing reduce operator"); - if(reduce->get_shape().type() == shape::tuple_type) - return reduce->get_shape().sub_shapes().front().lens(); - return reduce->get_shape().lens(); -} - std::string generate_reduce(module m, const std::string& name) { preload_params(m); @@ -495,7 +459,11 @@ std::string generate_reduce(module m, const std::string& name) cpp_generator g; g.always_return_tuple(); auto param_shapes = m.get_parameter_shapes(); - auto rlens = get_rlens(m); + auto max_shape = + std::max_element(param_shapes.begin(), + param_shapes.end(), + by(std::less<>{}, [](const auto& p) { return p.second.elements(); })); + auto ilens = max_shape->second.lens(); std::size_t i = 0; auto f = g.generate_module(m, [&](instruction_ref ins, const auto& names) { if(contains(ins->name(), "reduce")) @@ -512,7 +480,7 @@ std::string generate_reduce(module m, const std::string& name) ins->inputs().end(), std::back_inserter(tensors), [&](auto input) { - return input->get_shape().lens() != rlens and + return input->get_shape().lens() == ilens and not input->get_shape().broadcasted(); }); auto inner_names = names; @@ -553,13 +521,8 @@ std::string generate_reduce(module m, const std::string& name) { const auto& x = names.at(ins->inputs().front()); auto index = ins->get_operator().to_value()["index"].to(); - return interpolate_string("${x}[_c<${index}>]", - {{"x", x}, {"index", std::to_string(index)}}); - } - if(ins->name() == "gpu::make_indices") - { - auto size = ins->get_operator().to_value()["size"].to(); - return "reduce::make_indices(_c<" + std::to_string(size) + ">)"; + return interpolate_string("${x}[${index}]", + {{"x", x}, {"index", std::to_string(index)}}); } if(ins->name() == "identity") { From 3393535055e5c2ecc43712e25cfd6c6cb53d0c16 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 10 Mar 2026 15:46:33 -0500 Subject: [PATCH 054/107] Revert properly --- src/include/migraphx/stringutils.hpp | 11 +++++- src/targets/gpu/compile_gen.cpp | 59 ++++++++++++++++++++++------ 2 files changed, 58 insertions(+), 12 deletions(-) diff --git a/src/include/migraphx/stringutils.hpp b/src/include/migraphx/stringutils.hpp index 47c13fd3556..c4618036f5c 100644 --- a/src/include/migraphx/stringutils.hpp +++ b/src/include/migraphx/stringutils.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -224,6 +224,15 @@ inline auto to_string(const T& x) return ss.str(); } +template +inline auto to_hex_float(const T& x) + -> decltype((std::declval() << x), std::string{}) +{ + std::stringstream ss; + ss << std::hexfloat << x; + return ss.str(); +} + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/compile_gen.cpp b/src/targets/gpu/compile_gen.cpp index ced8ab03bae..5a08677220b 100644 --- a/src/targets/gpu/compile_gen.cpp +++ b/src/targets/gpu/compile_gen.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -247,8 +247,8 @@ tile tile::elements(const std::vector& inputs, std::size_t noutputs) auto tile_size = dim1 * dim2; result.ntiles = s.elements() / tile_size; - // equivalent to dim1 * (dim2 + 1) to avoid bank conflicts - auto tile_bytes = (tile_size + dim1) * s.type_size(); + // equivalent to dim2 * (dim1 + 1) to avoid bank conflicts + auto tile_bytes = (tile_size + dim2) * s.type_size(); if(tile_bytes > 65536) return {}; @@ -409,6 +409,31 @@ void reduce_op::set(instruction_ref ins, const operation& op) set(rop.name(), input, output); read = "compose(array_apply(" + read + "), MIGRAPHX_LIFT(make_array))"; } + else if(op.name() == "gpu::arg_reduce") + { + // extract the inner argmin/argmax operation + auto inner_op = from_value(op.to_value().at("op")); + auto inner_v = inner_op.to_value(); + bool select_last = inner_v.get("select_last_index", false); + std::string select_last_str = select_last ? "true" : "false"; + + if(inner_op.name() == "argmin") + { + reduction = "op::argmin<" + select_last_str + ">{}"; + init = "make_tuple(highest{}, index_int{0})"; + } + else if(inner_op.name() == "argmax") + { + reduction = "op::argmax<" + select_last_str + ">{}"; + init = "make_tuple(lowest{}, index_int{0})"; + } + else + { + MIGRAPHX_THROW("Unsupported arg operation"); + } + // read creates tuples from (value, index), cast index to index_int + read = "[](auto val, auto idx) { return make_tuple(val, static_cast(idx)); }"; + } else { set(op.name(), ins->inputs().front()->get_shape(), ins->get_shape()); @@ -451,6 +476,17 @@ static void preload_params(module& m) } } +static std::vector get_rlens(const module& m) +{ + auto reduce = std::find_if( + m.begin(), m.end(), [&](const auto& ins) { return contains(ins.name(), "reduce"); }); + if(reduce == m.end()) + MIGRAPHX_THROW("Missing reduce operator"); + if(reduce->get_shape().type() == shape::tuple_type) + return reduce->get_shape().sub_shapes().front().lens(); + return reduce->get_shape().lens(); +} + std::string generate_reduce(module m, const std::string& name) { preload_params(m); @@ -459,11 +495,7 @@ std::string generate_reduce(module m, const std::string& name) cpp_generator g; g.always_return_tuple(); auto param_shapes = m.get_parameter_shapes(); - auto max_shape = - std::max_element(param_shapes.begin(), - param_shapes.end(), - by(std::less<>{}, [](const auto& p) { return p.second.elements(); })); - auto ilens = max_shape->second.lens(); + auto rlens = get_rlens(m); std::size_t i = 0; auto f = g.generate_module(m, [&](instruction_ref ins, const auto& names) { if(contains(ins->name(), "reduce")) @@ -480,7 +512,7 @@ std::string generate_reduce(module m, const std::string& name) ins->inputs().end(), std::back_inserter(tensors), [&](auto input) { - return input->get_shape().lens() == ilens and + return input->get_shape().lens() != rlens and not input->get_shape().broadcasted(); }); auto inner_names = names; @@ -521,8 +553,13 @@ std::string generate_reduce(module m, const std::string& name) { const auto& x = names.at(ins->inputs().front()); auto index = ins->get_operator().to_value()["index"].to(); - return interpolate_string("${x}[${index}]", - {{"x", x}, {"index", std::to_string(index)}}); + return interpolate_string("${x}[_c<${index}>]", + {{"x", x}, {"index", std::to_string(index)}}); + } + if(ins->name() == "gpu::make_indices") + { + auto size = ins->get_operator().to_value()["size"].to(); + return "reduce::make_indices(_c<" + std::to_string(size) + ">)"; } if(ins->name() == "identity") { From bb471fb06e13b34543c908d741041bd39c18a4e4 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 11 Mar 2026 17:12:12 -0500 Subject: [PATCH 055/107] Prototype version with commented out new reaches code --- src/driver/main.cpp | 6 + src/include/migraphx/instruction.hpp | 2 + src/instruction.cpp | 80 ++++++++++++ src/module.cpp | 178 +++++++++++++++++++++------ 4 files changed, 226 insertions(+), 40 deletions(-) diff --git a/src/driver/main.cpp b/src/driver/main.cpp index 99f79fd5570..37999628d26 100644 --- a/src/driver/main.cpp +++ b/src/driver/main.cpp @@ -189,6 +189,7 @@ struct loader bool skip_unknown_operators = false; bool brief = false; bool verbose = false; + bool use_debug_symbols = false; std::string output_type; std::string output; std::string default_dyn_dim; @@ -221,6 +222,10 @@ struct loader ap.help("Skip unknown operators when parsing and continue to parse."), ap.set_value(true)); ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false)); + ap(use_debug_symbols, + {"--debug-symbols"}, + ap.help("Parse ONNX node names into MIGX instructions and propagate them as debug symbols."), + ap.set_value(true)); ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end")); ap(trim_size, {"--trim-size", "-s"}, ap.help("Number of instructions in the trim model")); ap(param_dims, @@ -404,6 +409,7 @@ struct loader } options.skip_unknown_operators = skip_unknown_operators; options.print_program_on_error = true; + options.use_debug_symbols = use_debug_symbols; options.map_input_dims = map_input_dims; options.map_dyn_input_dims = map_dyn_input_dims; options.dim_params = map_dim_params; diff --git a/src/include/migraphx/instruction.hpp b/src/include/migraphx/instruction.hpp index 45cc5f9918c..a73a2b065c4 100644 --- a/src/include/migraphx/instruction.hpp +++ b/src/include/migraphx/instruction.hpp @@ -51,6 +51,8 @@ MIGRAPHX_EXPORT bool reaches(instruction_ref start, instruction_ref end); MIGRAPHX_EXPORT bool reaches(instruction_ref start, instruction_ref end, const_module_ref m); +MIGRAPHX_EXPORT bool reaches(const std::unordered_set& starts, instruction_ref end, const_module_ref m); + MIGRAPHX_EXPORT bool is_interdependent(const std::vector& instructions, const_module_ref m, instruction_ref root); diff --git a/src/instruction.cpp b/src/instruction.cpp index b6b26a2222c..4b4ee11380b 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -599,6 +599,8 @@ static auto track_visits(instruction_ref start, instruction_ref end, F f) std::size_t n = std::distance(start, end); if(n < small) { + // Stop condition is ins distance to end > N or + // same instruction already visited. std::bitset visited; auto stop = [&](auto ins) { auto i = std::distance(ins, end); @@ -613,6 +615,9 @@ static auto track_visits(instruction_ref start, instruction_ref end, F f) } else { + // Make a hashmap of instructions between start and end. + // Stop condition is instruction not in the hashmap or + // same instruction already visited. auto instructions = range(start, std::next(end)); auto instruction_refs = iterator_for(instructions); std::unordered_set in_range(instruction_refs.begin(), @@ -622,6 +627,49 @@ static auto track_visits(instruction_ref start, instruction_ref end, F f) } } +// Version of track visits that works on an array of starting instructions +template +static auto track_visits(const T& starts, instruction_ref end, F f) +{ + const std::size_t small = 16; + instruction_ref first_start = *starts.begin(); + std::size_t max_distance = std::distance(first_start, end); + instruction_ref farthest_start = first_start; + // TODO optimize this loop + for(instruction_ref start : starts) + { + std::size_t dist = std::distance(start, end); + if(dist > max_distance) + { + max_distance = dist; + farthest_start = start; + } + } + if(max_distance < small) + { + std::bitset visited; + auto stop = [&](auto ins) { + auto i = std::distance(ins, end); + if(i > max_distance) + return true; + if(visited.test(i)) + return true; + visited.set(i); + return false; + }; + return f(stop); + } + else + { + auto instructions = range(farthest_start, std::next(end)); + auto instruction_refs = iterator_for(instructions); + std::unordered_set in_range(instruction_refs.begin(), + instruction_refs.end()); + auto stop = [&](auto ins) { return in_range.erase(ins) == 0; }; + return f(stop); + } +} + // DFS through inputs of `end` to find `start`. // `start` must be positioned before `end`. bool reaches(instruction_ref start, instruction_ref end) @@ -668,6 +716,38 @@ bool reaches(instruction_ref start, instruction_ref end, const_module_ref m) return reaches(start, end, m, [](auto) { return false; }); } +// Version of reaches with a unordered_map of starting instructions +template +static bool reaches(const T& starts, instruction_ref end, const_module_ref m, P predicate) +{ + if(contains(starts, end)) + return true; + if(not m->has_instruction(end)) + return false; + for(auto start : starts) + { + if(not m->has_instruction(start)) + return false; + } + + return track_visits(starts, end, [&](auto stop) { + return fix([&](auto self, auto ins) -> bool { + if(not m->has_instruction(ins)) + return false; + if(contains(starts, ins) or predicate(ins)) + return true; + if(stop(ins)) + return false; + return std::any_of(ins->inputs().begin(), ins->inputs().end(), self); + })(end); + }); +} + +bool reaches(const std::unordered_set& starts, instruction_ref end, const_module_ref m) +{ + return reaches(starts, end, m, [](auto) { return false; }); +} + bool is_interdependent(const std::vector& instructions, const_module_ref m, instruction_ref root) diff --git a/src/module.cpp b/src/module.cpp index 98433dbd35d..ba410a5da8d 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -366,13 +366,15 @@ static std::unordered_set gather_max_splice(const_module_ref m, { if(not m->has_instruction(input)) continue; - if(starts_with(input->name(), "@")) - continue; if(contains(result, input)) continue; if(any_of(input->outputs(), [&](instruction_ref output) { return not contains(result, output); })) + { + // include first instruction that is not solely depenedent + result.insert(input); continue; + } result.insert(input); self(input->inputs()); } @@ -380,30 +382,76 @@ static std::unordered_set gather_max_splice(const_module_ref m, return result; } +static std::unordered_set deduce_min_splice(const_module_ref m, const std::unordered_set& max_splice, const std::unordered_set& splice_intersection) +{ + std::unordered_set min_splice; + if(splice_intersection.empty()) + return min_splice; + + // make min_splice by going through max_splice instructions that come from the splice_intersection + // TODO reduce DFS calls by making another version of reaches that checks set of starting instructions + for(auto ins : max_splice) + { + if(contains(splice_intersection, ins)) + { + continue; + } + + //// If ins before max_splice no need to check. + //// Check up through instruction linked_list from ins for splice_intersection instructions. + //auto start_to_ins = range(m->begin(), std::next(ins)); + //auto instruction_refs = iterator_for(start_to_ins); + //if(std::any_of(instruction_refs.begin(), + // instruction_refs.end(), + // [&](instruction_ref i){return contains(splice_intersection, i);})) + //{ + // continue; + //} + + //if(reaches(splice_intersection, ins, m)) + //{ + // min_splice.insert(ins); + // continue; + //} + + for(auto intersect : splice_intersection) + { + if(std::distance(m->begin(), ins) < std::distance(m->begin(), intersect)) + { + continue; + } + if(reaches(intersect, ins, m)) + { + min_splice.insert(ins); + continue; + } + } + } + return min_splice; +} + static void propagate_debug_symbols(const_module_ref m, instruction_ref ins, std::unordered_set old_max_splice) { - // Remove ins from old_max_splice, if it is there. To prevent it being in both old_max_splice - // and new_max_slice. - old_max_splice.erase(ins); std::unordered_set new_max_splice = gather_max_splice(m, ins); - // Remove instructions from old_max_splice that are also in new_max_splice to get the actual - // old_splice set notation: old_splice = {old_max_splice} - {new_max_splice} - std::unordered_set old_splice; + + // Find the common instructions between old_max_splice and new_max_splice. + // {old_max_splice} ∩ {new_max_splice} + // This is like calculating the lowest common ancestor + std::unordered_set splice_intersection; std::copy_if( old_max_splice.cbegin(), old_max_splice.cend(), - std::inserter(old_splice, old_splice.begin()), - [&new_max_splice](auto old_ins) { return (not contains(new_max_splice, old_ins)); }); - // Vice versa process - std::unordered_set new_splice; - std::copy_if( - new_max_splice.cbegin(), - new_max_splice.cend(), - std::inserter(new_splice, new_splice.begin()), - [&old_max_splice](auto new_ins) { return (not contains(old_max_splice, new_ins)); }); - std::set symbols = ins->get_debug_symbols(); + std::inserter(splice_intersection, splice_intersection.begin()), + [&new_max_splice](auto old_ins) { return contains(new_max_splice, old_ins); }); + splice_intersection.erase(ins); + + // Remove instructions from old_max_splice that are the intersection or output to the intersection + std::unordered_set old_splice = deduce_min_splice(m, old_max_splice, splice_intersection); + std::unordered_set new_splice = deduce_min_splice(m, new_max_splice, splice_intersection); + + std::set symbols; for(auto old_ins : old_splice) { copy(old_ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); @@ -482,9 +530,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref if(has_debug_symbols()) { auto old_max_splice = gather_max_splice(this, ins); - // Remove rep incase it shows up in old_max_splice - old_max_splice.erase(rep); - propagate_debug_symbols(this, rep, std::move(old_max_splice)); + propagate_debug_symbols(this, rep, old_max_splice); } // Make a copy of outputs which can be changed when calling replace_argument auto outputs = ins->outputs(); @@ -521,14 +567,10 @@ std::vector module::batch_replace_instruction( for(const auto& replacer : replacers) { auto ms = gather_max_splice(this, replacer.ins); - // Remove ins from old_max_splice to prevent it being in both old_max_splice and - // new_max_slice - ms.erase(replacer.ins); old_max_splices.insert(ms.begin(), ms.end()); } } - std::set symbols; std::unordered_set new_max_splices; for(const auto& replacer : replacers) { @@ -540,36 +582,92 @@ std::vector module::batch_replace_instruction( ret.push_back(replacer.ins); if(has_debug_symbols()) { - auto ds = replacer.ins->get_debug_symbols(); - // add symbols from replacer.ins here because we removed replacer.ins from - // old_max_splice - symbols.insert(ds.begin(), ds.end()); auto ms = gather_max_splice(this, replacer.ins); new_max_splices.insert(ms.begin(), ms.end()); } } if(has_debug_symbols()) { - std::unordered_set new_splices; - std::copy_if( - new_max_splices.cbegin(), - new_max_splices.cend(), - std::inserter(new_splices, new_splices.begin()), - [&old_max_splices](auto new_ins) { return (not contains(old_max_splices, new_ins)); }); - std::unordered_set old_splices; + std::unordered_set splice_intersection; std::copy_if( old_max_splices.cbegin(), old_max_splices.cend(), - std::inserter(old_splices, old_splices.begin()), - [&new_max_splices](auto old_ins) { return (not contains(new_max_splices, old_ins)); }); - for(auto old_ins : old_splices) + std::inserter(splice_intersection, splice_intersection.begin()), + [&new_max_splices](auto old_ins) { return contains(new_max_splices, old_ins); }); + for(const auto& replacer : replacers) + { + splice_intersection.erase(replacer.ins); + } + // Remove instructions from old_max_splice that are the intersection or output to the intersection + std::unordered_set old_splice = deduce_min_splice(this, old_max_splices, splice_intersection); + std::unordered_set new_splice = deduce_min_splice(this, new_max_splices, splice_intersection); + std::set symbols; + for(auto old_ins : old_splice) { copy(old_ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); } - for(auto new_ins : new_splices) + for(auto new_ins : new_splice) { add_debug_symbols(new_ins, symbols); } + + //std::cout << "old_max_splice: {\n"; + //for(auto s : old_max_splices) + //{ + // s->debug_print(); + //} + //std::cout << "}" << std::endl; + //std::cout << "new_max_splice: {\n"; + //for(auto s : new_max_splices) + //{ + // s->debug_print(); + //} + //std::cout << "}" << std::endl; + //std::cout << "splice intersection: {\n"; + //for(auto s : splice_intersection) + //{ + // s->debug_print(); + //} + //std::cout << "}" << std::endl; + //std::cout << "old splice: {"; + //for(auto s : old_splice) + //{ + // s->debug_print(); + //} + //std::cout << "}" << std::endl; + //std::cout << "old symbols: {"; + //for(auto s : symbols) + //{ + // std::cout << s << ", "; + //} + //std::cout << "}" << std::endl; + //std::cout << "new_splice:\n"; + //for(auto s : new_splice) + //{ + // s->debug_print(); + //} + //std::cout << std::endl; + + //std::unordered_set new_splices; + //std::copy_if( + // new_max_splices.cbegin(), + // new_max_splices.cend(), + // std::inserter(new_splices, new_splices.begin()), + // [&old_max_splices](auto new_ins) { return (not contains(old_max_splices, new_ins)); }); + //std::unordered_set old_splices; + //std::copy_if( + // old_max_splices.cbegin(), + // old_max_splices.cend(), + // std::inserter(old_splices, old_splices.begin()), + // [&new_max_splices](auto old_ins) { return (not contains(new_max_splices, old_ins)); }); + //for(auto old_ins : old_splices) + //{ + // copy(old_ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); + //} + //for(auto new_ins : new_splices) + //{ + // add_debug_symbols(new_ins, symbols); + //} } return ret; } From 4a1469b5deaab766b01d86c50832c17a6b65d5c0 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 11 Mar 2026 21:18:50 -0500 Subject: [PATCH 056/107] Optimize the DFS calls --- src/instruction.cpp | 44 +++++++++++++++++++++-------- src/module.cpp | 54 ++++++++++++++++++------------------ test/debug_symbols_test.cpp | 55 +++++++++++++++++++++++++++++++++++++ 3 files changed, 114 insertions(+), 39 deletions(-) diff --git a/src/instruction.cpp b/src/instruction.cpp index 4b4ee11380b..6a232b0aded 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -629,21 +629,43 @@ static auto track_visits(instruction_ref start, instruction_ref end, F f) // Version of track visits that works on an array of starting instructions template -static auto track_visits(const T& starts, instruction_ref end, F f) +static auto track_visits(const_module_ref m, const T& starts, instruction_ref end, F f) { const std::size_t small = 16; - instruction_ref first_start = *starts.begin(); - std::size_t max_distance = std::distance(first_start, end); - instruction_ref farthest_start = first_start; // TODO optimize this loop - for(instruction_ref start : starts) - { - std::size_t dist = std::distance(start, end); - if(dist > max_distance) + //instruction_ref first_start = *starts.begin(); + //std::size_t max_distance = std::distance(first_start, end); + //instruction_ref farthest_start = first_start; + + //for(instruction_ref start : starts) + //{ + // std::size_t dist = std::distance(start, end); + // if(dist > max_distance) + // { + // max_distance = dist; + // farthest_start = start; + // } + //} + + // Find starts instruction with maximum distance from end + auto to_visit = starts; + instruction_ref ins = end; + std::size_t dist = 0; + std::size_t max_distance = 0; + instruction_ref farthest_start; + bool cond = (ins != m->begin()); + while(cond) + { + if(to_visit.empty() or ins == m->begin()) + cond = false; + if(contains(to_visit, ins)) { max_distance = dist; - farthest_start = start; + farthest_start = ins; + to_visit.erase(ins); } + ins = std::prev(ins); + dist++; } if(max_distance < small) { @@ -716,7 +738,7 @@ bool reaches(instruction_ref start, instruction_ref end, const_module_ref m) return reaches(start, end, m, [](auto) { return false; }); } -// Version of reaches with a unordered_map of starting instructions +// Version of reaches with an array of starting instructions template static bool reaches(const T& starts, instruction_ref end, const_module_ref m, P predicate) { @@ -730,7 +752,7 @@ static bool reaches(const T& starts, instruction_ref end, const_module_ref m, P return false; } - return track_visits(starts, end, [&](auto stop) { + return track_visits(m, starts, end, [&](auto stop) { return fix([&](auto self, auto ins) -> bool { if(not m->has_instruction(ins)) return false; diff --git a/src/module.cpp b/src/module.cpp index ba410a5da8d..060be1f602b 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -389,7 +389,6 @@ static std::unordered_set deduce_min_splice(const_module_ref m, return min_splice; // make min_splice by going through max_splice instructions that come from the splice_intersection - // TODO reduce DFS calls by making another version of reaches that checks set of starting instructions for(auto ins : max_splice) { if(contains(splice_intersection, ins)) @@ -397,35 +396,34 @@ static std::unordered_set deduce_min_splice(const_module_ref m, continue; } - //// If ins before max_splice no need to check. - //// Check up through instruction linked_list from ins for splice_intersection instructions. - //auto start_to_ins = range(m->begin(), std::next(ins)); - //auto instruction_refs = iterator_for(start_to_ins); - //if(std::any_of(instruction_refs.begin(), - // instruction_refs.end(), - // [&](instruction_ref i){return contains(splice_intersection, i);})) - //{ - // continue; - //} + // If ins is before all instructions in splice_intersection no need to check. + auto start_to_ins = range(m->begin(), std::next(ins)); + auto instruction_refs = iterator_for(start_to_ins); + if(std::none_of(instruction_refs.begin(), + instruction_refs.end(), + [&](instruction_ref x){return contains(splice_intersection, x);})) + { + continue; + } - //if(reaches(splice_intersection, ins, m)) - //{ - // min_splice.insert(ins); - // continue; - //} + if(reaches(splice_intersection, ins, m)) + { + min_splice.insert(ins); + continue; + } - for(auto intersect : splice_intersection) - { - if(std::distance(m->begin(), ins) < std::distance(m->begin(), intersect)) - { - continue; - } - if(reaches(intersect, ins, m)) - { - min_splice.insert(ins); - continue; - } - } + //for(auto intersect : splice_intersection) + //{ + // if(std::distance(m->begin(), ins) < std::distance(m->begin(), intersect)) + // { + // continue; + // } + // if(reaches(intersect, ins, m)) + // { + // min_splice.insert(ins); + // continue; + // } + //} } return min_splice; } diff --git a/test/debug_symbols_test.cpp b/test/debug_symbols_test.cpp index 34f85ed4591..c47cfa3c62e 100644 --- a/test/debug_symbols_test.cpp +++ b/test/debug_symbols_test.cpp @@ -33,9 +33,11 @@ #include #include #include +#include #include #include +#include // Two adds fused into a single pointwise op via fuse_pointwise. // Both symbols should appear on the fused pointwise instruction. @@ -253,6 +255,59 @@ TEST_CASE(horiz_fusion_dot) EXPECT(m1.sort() == m2.sort()); } +// Goes through the find_pointwise_reduce matcher. +// Making sure that the debug symbols are getting from the minimum splice of the old instructions. +TEST_CASE(pointwise_reduce_debug_symbols) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto curr = mm->add_instruction(migraphx::make_op("add"), x, y); + mm->add_debug_symbols(curr, {"add0"}); + curr = mm->add_instruction(migraphx::make_op("relu"), curr); + mm->add_debug_symbols(curr, {"relu0"}); + auto add = add_pointwise(p1, "main:pointwise0", {curr, z}, single_pointwise("add")); + mm->add_debug_symbols(add, {"pointwise"}); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), add); + mm->add_debug_symbols(rsum, {"reduce_sum"}); + curr = mm->add_instruction(migraphx::make_op("relu"), rsum); + mm->add_debug_symbols(curr, {"relu1"}); + mm->add_return({curr}); + } + migraphx::run_passes(p1, {migraphx::fuse_reduce{}, migraphx::dead_code_elimination{}}); + + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto curr = mm->add_instruction(migraphx::make_op("add"), x, y); + mm->add_debug_symbols(curr, {"add0"}); + curr = mm->add_instruction(migraphx::make_op("relu"), curr); + mm->add_debug_symbols(curr, {"relu0"}); + auto rsum = add_reduce( + p2, + "main:pointwise0:main:reduce_sum0", + {curr, z}, + {1}, + [&](auto* rm, const auto& inputs, const auto& axes) { + auto add = + add_pointwise(p2, rm, "main:pointwise0", inputs, single_pointwise("add")); + return rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), add); + }); + mm->add_debug_symbols(rsum, {"pointwise", "reduce_sum"}); + curr = mm->add_instruction(migraphx::make_op("relu"), rsum); + mm->add_debug_symbols(curr, {"relu1"}); + mm->add_return({curr}); + } + EXPECT(p1 == p2); +} + // Tests symbol propagation through add reassociation in simplify_algebra // (find_double_add_lit_broadcast). Checks add(add(x,1), add(y,2)) -> (add(add(x,y), add(1,2)). // From abfa3d259bb86dd73afdc3e95ad907c484ea6ddb Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 11 Mar 2026 21:30:03 -0500 Subject: [PATCH 057/107] Commented out code cleanup --- src/instruction.cpp | 15 ---------- src/module.cpp | 71 --------------------------------------------- 2 files changed, 86 deletions(-) diff --git a/src/instruction.cpp b/src/instruction.cpp index 6a232b0aded..a245b51adb5 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -632,21 +632,6 @@ template static auto track_visits(const_module_ref m, const T& starts, instruction_ref end, F f) { const std::size_t small = 16; - // TODO optimize this loop - //instruction_ref first_start = *starts.begin(); - //std::size_t max_distance = std::distance(first_start, end); - //instruction_ref farthest_start = first_start; - - //for(instruction_ref start : starts) - //{ - // std::size_t dist = std::distance(start, end); - // if(dist > max_distance) - // { - // max_distance = dist; - // farthest_start = start; - // } - //} - // Find starts instruction with maximum distance from end auto to_visit = starts; instruction_ref ins = end; diff --git a/src/module.cpp b/src/module.cpp index 060be1f602b..f9eccec002d 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -411,19 +411,6 @@ static std::unordered_set deduce_min_splice(const_module_ref m, min_splice.insert(ins); continue; } - - //for(auto intersect : splice_intersection) - //{ - // if(std::distance(m->begin(), ins) < std::distance(m->begin(), intersect)) - // { - // continue; - // } - // if(reaches(intersect, ins, m)) - // { - // min_splice.insert(ins); - // continue; - // } - //} } return min_splice; } @@ -608,64 +595,6 @@ std::vector module::batch_replace_instruction( { add_debug_symbols(new_ins, symbols); } - - //std::cout << "old_max_splice: {\n"; - //for(auto s : old_max_splices) - //{ - // s->debug_print(); - //} - //std::cout << "}" << std::endl; - //std::cout << "new_max_splice: {\n"; - //for(auto s : new_max_splices) - //{ - // s->debug_print(); - //} - //std::cout << "}" << std::endl; - //std::cout << "splice intersection: {\n"; - //for(auto s : splice_intersection) - //{ - // s->debug_print(); - //} - //std::cout << "}" << std::endl; - //std::cout << "old splice: {"; - //for(auto s : old_splice) - //{ - // s->debug_print(); - //} - //std::cout << "}" << std::endl; - //std::cout << "old symbols: {"; - //for(auto s : symbols) - //{ - // std::cout << s << ", "; - //} - //std::cout << "}" << std::endl; - //std::cout << "new_splice:\n"; - //for(auto s : new_splice) - //{ - // s->debug_print(); - //} - //std::cout << std::endl; - - //std::unordered_set new_splices; - //std::copy_if( - // new_max_splices.cbegin(), - // new_max_splices.cend(), - // std::inserter(new_splices, new_splices.begin()), - // [&old_max_splices](auto new_ins) { return (not contains(old_max_splices, new_ins)); }); - //std::unordered_set old_splices; - //std::copy_if( - // old_max_splices.cbegin(), - // old_max_splices.cend(), - // std::inserter(old_splices, old_splices.begin()), - // [&new_max_splices](auto old_ins) { return (not contains(new_max_splices, old_ins)); }); - //for(auto old_ins : old_splices) - //{ - // copy(old_ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); - //} - //for(auto new_ins : new_splices) - //{ - // add_debug_symbols(new_ins, symbols); - //} } return ret; } From 0abea08f8c9f115b82d08c3ac527eaab69c04c49 Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 12 Mar 2026 14:57:09 -0500 Subject: [PATCH 058/107] Fix serialization --- src/program.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/program.cpp b/src/program.cpp index 0a4221e0211..75c3ec60777 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -730,7 +730,10 @@ value program::to_value() const [&](auto mod_ref) { return mod_ref->name(); }); node["module_inputs"] = module_inputs; } - nodes["debug_symbols"] = migraphx::to_value(ins->get_debug_symbols()); + if(not ins->get_debug_symbols().empty()) + { + nodes["debug_symbols"] = migraphx::to_value(ins->get_debug_symbols()); + } nodes.push_back(node); }, names); @@ -808,7 +811,10 @@ static void mod_from_val(module_ref mod, } } output->set_normalized(normalized); - output->add_debug_symbols(node.at("debug_symbols").to>()); + if(node.contains("debug_symbols")) + { + output->add_debug_symbols(node.at("debug_symbols").to>()); + } instructions[node.at("output").to()] = output; } } From 6b3cce7b24b62320fe2a2d304df4dfee52264270 Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 12 Mar 2026 14:58:51 -0500 Subject: [PATCH 059/107] Formatting --- src/driver/main.cpp | 3 ++- src/instruction.cpp | 14 ++++++++------ src/module.cpp | 43 ++++++++++++++++++++++++++----------------- src/program.cpp | 2 +- 4 files changed, 37 insertions(+), 25 deletions(-) diff --git a/src/driver/main.cpp b/src/driver/main.cpp index 37999628d26..57ccda626d2 100644 --- a/src/driver/main.cpp +++ b/src/driver/main.cpp @@ -224,7 +224,8 @@ struct loader ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false)); ap(use_debug_symbols, {"--debug-symbols"}, - ap.help("Parse ONNX node names into MIGX instructions and propagate them as debug symbols."), + ap.help( + "Parse ONNX node names into MIGX instructions and propagate them as debug symbols."), ap.set_value(true)); ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end")); ap(trim_size, {"--trim-size", "-s"}, ap.help("Number of instructions in the trim model")); diff --git a/src/instruction.cpp b/src/instruction.cpp index a245b51adb5..c0a4b5b7a86 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -616,7 +616,7 @@ static auto track_visits(instruction_ref start, instruction_ref end, F f) else { // Make a hashmap of instructions between start and end. - // Stop condition is instruction not in the hashmap or + // Stop condition is instruction not in the hashmap or // same instruction already visited. auto instructions = range(start, std::next(end)); auto instruction_refs = iterator_for(instructions); @@ -633,9 +633,9 @@ static auto track_visits(const_module_ref m, const T& starts, instruction_ref en { const std::size_t small = 16; // Find starts instruction with maximum distance from end - auto to_visit = starts; - instruction_ref ins = end; - std::size_t dist = 0; + auto to_visit = starts; + instruction_ref ins = end; + std::size_t dist = 0; std::size_t max_distance = 0; instruction_ref farthest_start; bool cond = (ins != m->begin()); @@ -645,7 +645,7 @@ static auto track_visits(const_module_ref m, const T& starts, instruction_ref en cond = false; if(contains(to_visit, ins)) { - max_distance = dist; + max_distance = dist; farthest_start = ins; to_visit.erase(ins); } @@ -750,7 +750,9 @@ static bool reaches(const T& starts, instruction_ref end, const_module_ref m, P }); } -bool reaches(const std::unordered_set& starts, instruction_ref end, const_module_ref m) +bool reaches(const std::unordered_set& starts, + instruction_ref end, + const_module_ref m) { return reaches(starts, end, m, [](auto) { return false; }); } diff --git a/src/module.cpp b/src/module.cpp index f9eccec002d..7bab2bca890 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -382,13 +382,17 @@ static std::unordered_set gather_max_splice(const_module_ref m, return result; } -static std::unordered_set deduce_min_splice(const_module_ref m, const std::unordered_set& max_splice, const std::unordered_set& splice_intersection) +static std::unordered_set +deduce_min_splice(const_module_ref m, + const std::unordered_set& max_splice, + const std::unordered_set& splice_intersection) { std::unordered_set min_splice; if(splice_intersection.empty()) return min_splice; - // make min_splice by going through max_splice instructions that come from the splice_intersection + // make min_splice by going through max_splice instructions that come from the + // splice_intersection for(auto ins : max_splice) { if(contains(splice_intersection, ins)) @@ -399,9 +403,9 @@ static std::unordered_set deduce_min_splice(const_module_ref m, // If ins is before all instructions in splice_intersection no need to check. auto start_to_ins = range(m->begin(), std::next(ins)); auto instruction_refs = iterator_for(start_to_ins); - if(std::none_of(instruction_refs.begin(), - instruction_refs.end(), - [&](instruction_ref x){return contains(splice_intersection, x);})) + if(std::none_of(instruction_refs.begin(), instruction_refs.end(), [&](instruction_ref x) { + return contains(splice_intersection, x); + })) { continue; } @@ -423,18 +427,20 @@ static void propagate_debug_symbols(const_module_ref m, // Find the common instructions between old_max_splice and new_max_splice. // {old_max_splice} ∩ {new_max_splice} - // This is like calculating the lowest common ancestor + // This is like calculating the lowest common ancestor std::unordered_set splice_intersection; - std::copy_if( - old_max_splice.cbegin(), - old_max_splice.cend(), - std::inserter(splice_intersection, splice_intersection.begin()), - [&new_max_splice](auto old_ins) { return contains(new_max_splice, old_ins); }); + std::copy_if(old_max_splice.cbegin(), + old_max_splice.cend(), + std::inserter(splice_intersection, splice_intersection.begin()), + [&new_max_splice](auto old_ins) { return contains(new_max_splice, old_ins); }); splice_intersection.erase(ins); - // Remove instructions from old_max_splice that are the intersection or output to the intersection - std::unordered_set old_splice = deduce_min_splice(m, old_max_splice, splice_intersection); - std::unordered_set new_splice = deduce_min_splice(m, new_max_splice, splice_intersection); + // Remove instructions from old_max_splice that are the intersection or output to the + // intersection + std::unordered_set old_splice = + deduce_min_splice(m, old_max_splice, splice_intersection); + std::unordered_set new_splice = + deduce_min_splice(m, new_max_splice, splice_intersection); std::set symbols; for(auto old_ins : old_splice) @@ -583,9 +589,12 @@ std::vector module::batch_replace_instruction( { splice_intersection.erase(replacer.ins); } - // Remove instructions from old_max_splice that are the intersection or output to the intersection - std::unordered_set old_splice = deduce_min_splice(this, old_max_splices, splice_intersection); - std::unordered_set new_splice = deduce_min_splice(this, new_max_splices, splice_intersection); + // Remove instructions from old_max_splice that are the intersection or output to the + // intersection + std::unordered_set old_splice = + deduce_min_splice(this, old_max_splices, splice_intersection); + std::unordered_set new_splice = + deduce_min_splice(this, new_max_splices, splice_intersection); std::set symbols; for(auto old_ins : old_splice) { diff --git a/src/program.cpp b/src/program.cpp index 75c3ec60777..104b28e94d7 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -737,7 +737,7 @@ value program::to_value() const nodes.push_back(node); }, names); - mod_val["nodes"] = nodes; + mod_val["nodes"] = nodes; module_vals[mod->name()] = mod_val; } From 5162ceb0f77795c47cd85652e647f24e29b61a06 Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 12 Mar 2026 14:59:43 -0500 Subject: [PATCH 060/107] Tidy fix --- src/module.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/module.cpp b/src/module.cpp index 7bab2bca890..1b58fade0dc 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -421,7 +421,7 @@ deduce_min_splice(const_module_ref m, static void propagate_debug_symbols(const_module_ref m, instruction_ref ins, - std::unordered_set old_max_splice) + const std::unordered_set& old_max_splice) { std::unordered_set new_max_splice = gather_max_splice(m, ins); From a2a4bedac43a5d516681fb9ac6a8c2db75e4d818 Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 16 Mar 2026 13:30:23 -0500 Subject: [PATCH 061/107] revert test.hpp changes --- test/include/test.hpp | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/test/include/test.hpp b/test/include/test.hpp index c3ca931c6ae..7ae503e2db4 100644 --- a/test/include/test.hpp +++ b/test/include/test.hpp @@ -220,10 +220,10 @@ template struct lhs_expression; template -decltype(auto) make_lhs_expression(T&& lhs); +lhs_expression make_lhs_expression(T&& lhs); template -decltype(auto) make_lhs_expression(T&& lhs, Operator); +lhs_expression make_lhs_expression(T&& lhs, Operator); // NOLINTNEXTLINE #define TEST_EXPR_BINARY_OPERATOR(op, name) \ @@ -259,32 +259,31 @@ struct expression TEST_FOREACH_BINARY_OPERATORS(TEST_EXPR_BINARY_OPERATOR) }; +// TODO: Remove rvalue references template -decltype(auto) make_expression(T&& lhs, U&& rhs, Operator) +expression make_expression(T&& lhs, U&& rhs, Operator) { - // rvalue references to pass by value - return expression, std::decay_t, Operator>{lhs, rhs}; + return {std::forward(lhs), std::forward(rhs)}; } +// TODO: Remove rvalue reference template -decltype(auto) make_lhs_expression(T&& lhs) +lhs_expression make_lhs_expression(T&& lhs) { - // rvalue references to pass by value - return lhs_expression>{lhs}; + return lhs_expression{std::forward(lhs)}; } template -decltype(auto) make_lhs_expression(T&& lhs, Operator) +lhs_expression make_lhs_expression(T&& lhs, Operator) { - // rvalue references to pass by value - return lhs_expression, Operator>{lhs}; + return lhs_expression{std::forward(lhs)}; } template struct lhs_expression { T lhs; - explicit lhs_expression(const T& e) : lhs(e) {} + explicit lhs_expression(T e) : lhs(std::move(e)) {} friend std::ostream& operator<<(std::ostream& s, const lhs_expression& self) { From 3805caba43ecfa8762c3dcb6e29c78664e38c34d Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 17 Mar 2026 19:03:49 -0500 Subject: [PATCH 062/107] Replace instruction refactor To go from common ancestors down and avoid having max splice be the whole graph if possible. --- src/module.cpp | 172 ++++++++++++++++++++++++++----------------------- 1 file changed, 92 insertions(+), 80 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index 1b58fade0dc..e2f9c632c71 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -119,6 +119,8 @@ struct module_impl { changed.notify(); instruction_set.erase(std::addressof(*pos)); + if(num_ins_with_debug_symbols > 0) + --num_ins_with_debug_symbols; return instructions.erase(pos); } @@ -355,10 +357,12 @@ instruction_ref module::insert_instruction(instruction_ref ins, /** * Traverse inputs of `ins` and gather instructions that output only to `ins`. * This splice is the total possibility of instructions that could be spliced by a - * replace_instruction. + * replace_instruction and including the first not solely depenedent instruction. + * stops : optional set of instructions to stop search on if encoutered. **/ static std::unordered_set gather_max_splice(const_module_ref m, - instruction_ref ins) + instruction_ref ins, + std::unordered_set stops = {}) { std::unordered_set result = {ins}; fix([&](auto self, const std::vector& inputs) { @@ -368,88 +372,84 @@ static std::unordered_set gather_max_splice(const_module_ref m, continue; if(contains(result, input)) continue; - if(any_of(input->outputs(), + result.insert(input); + if(contains(stops, input) or any_of(input->outputs(), [&](instruction_ref output) { return not contains(result, output); })) { - // include first instruction that is not solely depenedent - result.insert(input); + // include first instruction that is not solely depenedent or in stops continue; } - result.insert(input); self(input->inputs()); } })(ins->inputs()); return result; } +/** + * end: instruction at end of splice + */ static std::unordered_set -deduce_min_splice(const_module_ref m, +deduce_min_splice(std::vector ends, const std::unordered_set& max_splice, - const std::unordered_set& splice_intersection) + const std::unordered_set& common_ancestors) { std::unordered_set min_splice; - if(splice_intersection.empty()) + min_splice.insert(ends.begin(), ends.end()); + if(common_ancestors.empty()) return min_splice; - - // make min_splice by going through max_splice instructions that come from the - // splice_intersection - for(auto ins : max_splice) + + // make min_splice by gathering outputs of common_ancestors within the max_splice + for(auto anc : common_ancestors) { - if(contains(splice_intersection, ins)) - { - continue; - } - - // If ins is before all instructions in splice_intersection no need to check. - auto start_to_ins = range(m->begin(), std::next(ins)); - auto instruction_refs = iterator_for(start_to_ins); - if(std::none_of(instruction_refs.begin(), instruction_refs.end(), [&](instruction_ref x) { - return contains(splice_intersection, x); - })) - { - continue; - } - - if(reaches(splice_intersection, ins, m)) - { - min_splice.insert(ins); - continue; - } + fix([&](auto self, const auto& outputs) { + for(auto output : outputs) + { + if(not contains(max_splice, output)) + continue; + if(contains(min_splice, output)) + continue; + if(contains(common_ancestors, output)) + continue; + min_splice.insert(output); + self(output->outputs()); + } + })(anc->outputs()); } return min_splice; } +/** + * ins: instruction that was/will be replaced + * rep: replacing instruction + */ static void propagate_debug_symbols(const_module_ref m, instruction_ref ins, - const std::unordered_set& old_max_splice) + instruction_ref rep, + std::unordered_set new_max_splice, + std::unordered_set old_max_splice) { - std::unordered_set new_max_splice = gather_max_splice(m, ins); - + // TODO: can get common ancestors within gather_max_splice as an optimization // Find the common instructions between old_max_splice and new_max_splice. // {old_max_splice} ∩ {new_max_splice} - // This is like calculating the lowest common ancestor - std::unordered_set splice_intersection; + // This is like calculating the lowest common ancestors + std::unordered_set common_ancestors; std::copy_if(old_max_splice.cbegin(), old_max_splice.cend(), - std::inserter(splice_intersection, splice_intersection.begin()), + std::inserter(common_ancestors, common_ancestors.begin()), [&new_max_splice](auto old_ins) { return contains(new_max_splice, old_ins); }); - splice_intersection.erase(ins); - // Remove instructions from old_max_splice that are the intersection or output to the - // intersection - std::unordered_set old_splice = - deduce_min_splice(m, old_max_splice, splice_intersection); - std::unordered_set new_splice = - deduce_min_splice(m, new_max_splice, splice_intersection); + // Deduce the correct (minimum) splice from the max splice and the intersection + std::unordered_set old_splice = deduce_min_splice({ins}, old_max_splice, common_ancestors); + std::unordered_set new_splice = deduce_min_splice({rep}, new_max_splice, common_ancestors); std::set symbols; - for(auto old_ins : old_splice) + for(auto x : old_splice) { - copy(old_ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); + copy(x->get_debug_symbols(), std::inserter(symbols, symbols.begin())); } - for(auto new_ins : new_splice) + for(auto x : new_splice) { - m->add_debug_symbols(new_ins, symbols); + m->add_debug_symbols(x, symbols); } } @@ -462,15 +462,19 @@ instruction_ref module::replace_instruction(instruction_ref ins, assert(not starts_with(op.name(), "@")); shape r = compute_shape(op, args); - std::unordered_set old_max_splice; + std::vector prev_args; if(has_debug_symbols()) { - old_max_splice = gather_max_splice(this, ins); + prev_args = ins->inputs(); } instruction::replace(ins, op, r, std::move(args)); if(has_debug_symbols()) { - propagate_debug_symbols(this, ins, std::move(old_max_splice)); + // placeholder identity instruction + auto id_ins = insert_instruction(ins, make_op("identity"), prev_args); + std::unordered_set old_max_splice = gather_max_splice(this, id_ins); + std::unordered_set new_max_splice = gather_max_splice(this, ins, old_max_splice); + propagate_debug_symbols(this, ins, id_ins, new_max_splice, old_max_splice); } assert(ins->valid(begin())); return ins; @@ -485,15 +489,18 @@ instruction_ref module::replace_instruction(instruction_ref ins, assert(has_instruction(ins)); assert(not starts_with(op.name(), "@")); auto out_shape = compute_shape(op, args, module_args); - std::unordered_set old_max_splice; + std::vector prev_args; if(has_debug_symbols()) { - old_max_splice = gather_max_splice(this, ins); + prev_args = ins->inputs(); } instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args)); if(has_debug_symbols()) { - propagate_debug_symbols(this, ins, std::move(old_max_splice)); + auto id_ins = insert_instruction(ins, make_op("identity"), prev_args); + std::unordered_set old_max_splice = gather_max_splice(this, id_ins); + std::unordered_set new_max_splice = gather_max_splice(this, ins, old_max_splice); + propagate_debug_symbols(this, ins, id_ins, new_max_splice, old_max_splice); } assert(ins->valid(begin())); return ins; @@ -520,9 +527,11 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref if(has_debug_symbols()) { - auto old_max_splice = gather_max_splice(this, ins); - propagate_debug_symbols(this, rep, old_max_splice); + std::unordered_set new_max_splice = gather_max_splice(this, rep); + std::unordered_set old_max_splice = gather_max_splice(this, ins, new_max_splice); + propagate_debug_symbols(this, ins, rep, new_max_splice, old_max_splice); } + // Make a copy of outputs which can be changed when calling replace_argument auto outputs = ins->outputs(); for(auto out : outputs) @@ -552,50 +561,53 @@ std::vector module::batch_replace_instruction( impl->changed.notify(); std::vector ret; std::unordered_set old_max_splices; - if(has_debug_symbols()) - { - // gather all previous debug symbols from max splices - for(const auto& replacer : replacers) - { - auto ms = gather_max_splice(this, replacer.ins); - old_max_splices.insert(ms.begin(), ms.end()); - } - } - std::unordered_set new_max_splices; for(const auto& replacer : replacers) { assert(has_instruction(replacer.ins)); assert(not starts_with(replacer.op.name(), "@")); + std::vector prev_args; + if(has_debug_symbols()) + { + prev_args = replacer.ins->inputs(); + } auto out_shape = compute_shape(replacer.op, replacer.args, replacer.module_args); instruction::replace( replacer.ins, replacer.op, out_shape, replacer.args, replacer.module_args); ret.push_back(replacer.ins); if(has_debug_symbols()) { - auto ms = gather_max_splice(this, replacer.ins); - new_max_splices.insert(ms.begin(), ms.end()); + auto id_ins = insert_instruction(replacer.ins, make_op("identity"), prev_args); + auto old_ms = gather_max_splice(this, id_ins); + auto new_ms = gather_max_splice(this, replacer.ins, old_ms); + old_max_splices.insert(old_ms.begin(), old_ms.end()); + new_max_splices.insert(new_ms.begin(), new_ms.end()); } } if(has_debug_symbols()) { - std::unordered_set splice_intersection; + std::unordered_set common_ancestors; std::copy_if( old_max_splices.cbegin(), old_max_splices.cend(), - std::inserter(splice_intersection, splice_intersection.begin()), + std::inserter(common_ancestors, common_ancestors.begin()), [&new_max_splices](auto old_ins) { return contains(new_max_splices, old_ins); }); - for(const auto& replacer : replacers) - { - splice_intersection.erase(replacer.ins); - } - // Remove instructions from old_max_splice that are the intersection or output to the - // intersection + + std::vector ends; + std::transform(replacers.begin(), replacers.end(), std::back_inserter(ends), [](const auto& rep){ return rep.ins; }); + std::unordered_set old_splice = - deduce_min_splice(this, old_max_splices, splice_intersection); + deduce_min_splice({ends}, old_max_splices, common_ancestors); std::unordered_set new_splice = - deduce_min_splice(this, new_max_splices, splice_intersection); + deduce_min_splice({ends}, new_max_splices, common_ancestors); + + // include in-place debug symbols if they're there std::set symbols; + for(const auto& replacer : replacers) + { + const auto& ds = replacer.ins->get_debug_symbols(); + symbols.insert(ds.begin(), ds.end()); + } for(auto old_ins : old_splice) { copy(old_ins->get_debug_symbols(), std::inserter(symbols, symbols.begin())); From b5e6816c1f6dea57c32d101fc0294b96dbb531cd Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 17 Mar 2026 19:08:01 -0500 Subject: [PATCH 063/107] revert instruction::reaches changes --- src/include/migraphx/instruction.hpp | 2 - src/instruction.cpp | 84 ---------------------------- 2 files changed, 86 deletions(-) diff --git a/src/include/migraphx/instruction.hpp b/src/include/migraphx/instruction.hpp index a73a2b065c4..45cc5f9918c 100644 --- a/src/include/migraphx/instruction.hpp +++ b/src/include/migraphx/instruction.hpp @@ -51,8 +51,6 @@ MIGRAPHX_EXPORT bool reaches(instruction_ref start, instruction_ref end); MIGRAPHX_EXPORT bool reaches(instruction_ref start, instruction_ref end, const_module_ref m); -MIGRAPHX_EXPORT bool reaches(const std::unordered_set& starts, instruction_ref end, const_module_ref m); - MIGRAPHX_EXPORT bool is_interdependent(const std::vector& instructions, const_module_ref m, instruction_ref root); diff --git a/src/instruction.cpp b/src/instruction.cpp index c0a4b5b7a86..ad87d80899a 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -627,56 +627,6 @@ static auto track_visits(instruction_ref start, instruction_ref end, F f) } } -// Version of track visits that works on an array of starting instructions -template -static auto track_visits(const_module_ref m, const T& starts, instruction_ref end, F f) -{ - const std::size_t small = 16; - // Find starts instruction with maximum distance from end - auto to_visit = starts; - instruction_ref ins = end; - std::size_t dist = 0; - std::size_t max_distance = 0; - instruction_ref farthest_start; - bool cond = (ins != m->begin()); - while(cond) - { - if(to_visit.empty() or ins == m->begin()) - cond = false; - if(contains(to_visit, ins)) - { - max_distance = dist; - farthest_start = ins; - to_visit.erase(ins); - } - ins = std::prev(ins); - dist++; - } - if(max_distance < small) - { - std::bitset visited; - auto stop = [&](auto ins) { - auto i = std::distance(ins, end); - if(i > max_distance) - return true; - if(visited.test(i)) - return true; - visited.set(i); - return false; - }; - return f(stop); - } - else - { - auto instructions = range(farthest_start, std::next(end)); - auto instruction_refs = iterator_for(instructions); - std::unordered_set in_range(instruction_refs.begin(), - instruction_refs.end()); - auto stop = [&](auto ins) { return in_range.erase(ins) == 0; }; - return f(stop); - } -} - // DFS through inputs of `end` to find `start`. // `start` must be positioned before `end`. bool reaches(instruction_ref start, instruction_ref end) @@ -723,40 +673,6 @@ bool reaches(instruction_ref start, instruction_ref end, const_module_ref m) return reaches(start, end, m, [](auto) { return false; }); } -// Version of reaches with an array of starting instructions -template -static bool reaches(const T& starts, instruction_ref end, const_module_ref m, P predicate) -{ - if(contains(starts, end)) - return true; - if(not m->has_instruction(end)) - return false; - for(auto start : starts) - { - if(not m->has_instruction(start)) - return false; - } - - return track_visits(m, starts, end, [&](auto stop) { - return fix([&](auto self, auto ins) -> bool { - if(not m->has_instruction(ins)) - return false; - if(contains(starts, ins) or predicate(ins)) - return true; - if(stop(ins)) - return false; - return std::any_of(ins->inputs().begin(), ins->inputs().end(), self); - })(end); - }); -} - -bool reaches(const std::unordered_set& starts, - instruction_ref end, - const_module_ref m) -{ - return reaches(starts, end, m, [](auto) { return false; }); -} - bool is_interdependent(const std::vector& instructions, const_module_ref m, instruction_ref root) From 6671be0180acf2a2e32b26256237efa6a8475f9c Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 17 Mar 2026 19:09:52 -0500 Subject: [PATCH 064/107] Edit debug symbols print --- src/instruction.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/instruction.cpp b/src/instruction.cpp index ad87d80899a..df9d521885d 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -454,7 +454,7 @@ void instruction::print(std::ostream& os, // print debug symbols if they exist if(not ins->debug_symbols.empty()) { - os << " # " << join_strings(ins->debug_symbols, ", ") << " #"; + os << " # " << join_strings(ins->debug_symbols, ", "); } } @@ -491,7 +491,7 @@ void instruction::debug_print() const // print debug symbols if they exist if(not debug_symbols.empty()) { - std::cout << " # " << join_strings(debug_symbols, ", ") << " #"; + std::cout << " # " << join_strings(debug_symbols, ", "); } std::cout << std::endl; } From 4b7c5ac6877a135525ad47457acdf3298e141ffb Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 17 Mar 2026 19:12:15 -0500 Subject: [PATCH 065/107] Revert op builder changes --- src/op/builder/celu.cpp | 6 ------ src/op/builder/clip.cpp | 6 ------ src/op/builder/gelu.cpp | 6 ------ 3 files changed, 18 deletions(-) diff --git a/src/op/builder/celu.cpp b/src/op/builder/celu.cpp index f849dfec4ef..5718f3aa754 100644 --- a/src/op/builder/celu.cpp +++ b/src/op/builder/celu.cpp @@ -37,12 +37,6 @@ struct celu : op_builder { float alpha = 1.0f; - template - static auto reflect(Self& self, F f) - { - return pack(f(self.alpha, "alpha")); - } - std::vector insert(module& m, instruction_ref ins, const std::vector& args) const { diff --git a/src/op/builder/clip.cpp b/src/op/builder/clip.cpp index f3f5e147301..7d5168e18a1 100644 --- a/src/op/builder/clip.cpp +++ b/src/op/builder/clip.cpp @@ -35,12 +35,6 @@ namespace builder { struct clip : op_builder { - template - static auto reflect(Self&, F) - { - return pack(); - } - std::vector insert(module& m, instruction_ref ins, const std::vector& args) const { diff --git a/src/op/builder/gelu.cpp b/src/op/builder/gelu.cpp index b80a554a600..f94e5db6824 100644 --- a/src/op/builder/gelu.cpp +++ b/src/op/builder/gelu.cpp @@ -37,12 +37,6 @@ struct gelu_quick : op_builder { float alpha = 1.0f; - template - static auto reflect(Self& self, F f) - { - return pack(f(self.alpha, "alpha")); - } - std::vector insert(module& m, instruction_ref ins, const std::vector& args) const { From caa2a98fcba8c7abd9907dfb4214979bdc085a3f Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 17 Mar 2026 19:12:28 -0500 Subject: [PATCH 066/107] Edit .build*/ to .build/ --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 13d85644f5b..c5eb85500ac 100644 --- a/.gitignore +++ b/.gitignore @@ -55,7 +55,7 @@ _toc.yml # Directories to ignore (do not add trailing '/'s, they skip symlinks). #==============================================================================# # Nested build directory -build*/ +build/ debug_docker_build/ docker_build/ From 4fb59f04fd1e6cf3796df7bf33f6bd3f85385966 Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 23 Mar 2026 14:01:34 -0500 Subject: [PATCH 067/107] Keep only build* fix --- .gitignore | 8 -------- 1 file changed, 8 deletions(-) diff --git a/.gitignore b/.gitignore index c5eb85500ac..fe25ef2c6a3 100644 --- a/.gitignore +++ b/.gitignore @@ -56,14 +56,6 @@ _toc.yml #==============================================================================# # Nested build directory build/ -debug_docker_build/ -docker_build/ - -# Cursor files -.cursor/ - -# Ctags files -.ctagsignore # Downloaded models test/onnx/models From b95aff8beaea4cd03f9be77ab69221ec8b55aeb7 Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 23 Mar 2026 14:49:09 -0500 Subject: [PATCH 068/107] Fix op_builders error --- src/op/builder/celu.cpp | 6 ++++++ src/op/builder/gelu.cpp | 18 ++++++------------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/op/builder/celu.cpp b/src/op/builder/celu.cpp index 5718f3aa754..f849dfec4ef 100644 --- a/src/op/builder/celu.cpp +++ b/src/op/builder/celu.cpp @@ -37,6 +37,12 @@ struct celu : op_builder { float alpha = 1.0f; + template + static auto reflect(Self& self, F f) + { + return pack(f(self.alpha, "alpha")); + } + std::vector insert(module& m, instruction_ref ins, const std::vector& args) const { diff --git a/src/op/builder/gelu.cpp b/src/op/builder/gelu.cpp index f94e5db6824..f360bae016f 100644 --- a/src/op/builder/gelu.cpp +++ b/src/op/builder/gelu.cpp @@ -37,6 +37,12 @@ struct gelu_quick : op_builder { float alpha = 1.0f; + template + static auto reflect(Self& self, F f) + { + return pack(f(self.alpha, "alpha")); + } + std::vector insert(module& m, instruction_ref ins, const std::vector& args) const { @@ -51,12 +57,6 @@ struct gelu_quick : op_builder struct gelu_erf : op_builder { - template - static auto reflect(Self&, F) - { - return pack(); - } - std::vector insert(module& m, instruction_ref ins, const std::vector& args) const { @@ -124,12 +124,6 @@ struct gelu_tanh : op_builder struct gelu_split : op_builder { - template - static auto reflect(Self&, F) - { - return pack(); - } - std::vector insert(module& m, instruction_ref ins, const std::vector& args) const { From 8b6053e77d1628e0ed5be062b43dabc224696131 Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 23 Mar 2026 15:00:26 -0500 Subject: [PATCH 069/107] Cleanup --- src/include/migraphx/module.hpp | 4 ++-- src/module.cpp | 6 +++--- src/simplify_algebra.cpp | 4 ++-- test/debug_symbols_test.cpp | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 7b0e5e11595..d4f3d22a0fb 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -129,7 +129,7 @@ struct MIGRAPHX_EXPORT module instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep); - struct instruction_replacer + struct instruction_replacement { instruction_ref ins; operation op; @@ -140,7 +140,7 @@ struct MIGRAPHX_EXPORT module /// Replaces an array of instructions within the same function to properly handle debug symbols /// propagation. Returns vector of instruction_ref to replaced instructions. std::vector batch_replace_instruction( - const std::vector& replacers) MIGRAPHX_TIDY_CONST; + const std::vector& replacers) MIGRAPHX_TIDY_CONST; instruction_ref remove_instruction(instruction_ref ins); instruction_ref remove_instructions(instruction_ref first, instruction_ref last); diff --git a/src/module.cpp b/src/module.cpp index e2f9c632c71..a3615715052 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -425,8 +425,8 @@ deduce_min_splice(std::vector ends, static void propagate_debug_symbols(const_module_ref m, instruction_ref ins, instruction_ref rep, - std::unordered_set new_max_splice, - std::unordered_set old_max_splice) + const std::unordered_set& new_max_splice, + const std::unordered_set& old_max_splice) { // TODO: can get common ancestors within gather_max_splice as an optimization // Find the common instructions between old_max_splice and new_max_splice. @@ -556,7 +556,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref } std::vector module::batch_replace_instruction( - const std::vector& replacers) MIGRAPHX_TIDY_CONST + const std::vector& replacers) MIGRAPHX_TIDY_CONST { impl->changed.notify(); std::vector ret; diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index ed9c2dcad86..1cb7b72183c 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -1704,14 +1704,14 @@ struct find_conv_dot_horiz_fusion auto concat = m.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args); auto fused = m.insert_instruction(std::next(input), op, input, concat); - std::vector replacers; + std::vector replacers; int64_t offset = 0; for(auto arg : range(start, last)) { int64_t len = arg->get_shape().lens()[axis]; auto slice_op = make_op( "slice", {{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}); - replacers.push_back(module::instruction_replacer{arg, slice_op, {fused}, {}}); + replacers.push_back(module::instruction_replacement{arg, slice_op, {fused}, {}}); offset += len; } m.batch_replace_instruction(replacers); diff --git a/test/debug_symbols_test.cpp b/test/debug_symbols_test.cpp index c47cfa3c62e..f2821a45789 100644 --- a/test/debug_symbols_test.cpp +++ b/test/debug_symbols_test.cpp @@ -537,7 +537,7 @@ TEST_CASE(debug_symbols_in_print) m.add_return({add}); auto str = migraphx::to_string(m); - EXPECT(str.find("# sym_a, sym_b #") != std::string::npos); + EXPECT(str.find("# sym_a, sym_b") != std::string::npos); } // ----------------------------------------------------------------------- From dfd5f696493f3817da837b24ba0ce32758fa4c66 Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 23 Mar 2026 18:45:57 -0500 Subject: [PATCH 070/107] Revert gitignore --- .gitignore | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index fe25ef2c6a3..20dead71ed3 100644 --- a/.gitignore +++ b/.gitignore @@ -55,18 +55,11 @@ _toc.yml # Directories to ignore (do not add trailing '/'s, they skip symlinks). #==============================================================================# # Nested build directory -build/ +/build* # Downloaded models test/onnx/models -# Python virtual env -venv_rbuild/ - -# ONNX test generated files -test/onnx/load_save_arg.msgpack -test/onnx/migraphx_api_load_save_argument.msgpack - # VS2017 and VSCode config files. .vscode .vs @@ -83,6 +76,11 @@ docs/_doxygen docs/html /_readthedocs +# JetBrains config directories (ignoring symlinks) +.idea/ +cmake-build*/ +build*/ + # Recommended location to install rbuild dependencies from README.md depend*/ From 716033c560fddc7a2f74fcdd9b9fd1aaae2658d4 Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 23 Mar 2026 18:52:10 -0500 Subject: [PATCH 071/107] add tests and tidy --- src/module.cpp | 20 ++-- src/program.cpp | 5 +- test/module_test.cpp | 186 +++++++++++++++++++++++++++++++++++++ test/serialize_program.cpp | 109 ++++++++++++++++++++++ 4 files changed, 305 insertions(+), 15 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index a3615715052..5ab18210a9c 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -119,7 +119,7 @@ struct module_impl { changed.notify(); instruction_set.erase(std::addressof(*pos)); - if(num_ins_with_debug_symbols > 0) + if(num_ins_with_debug_symbols > 0 and not pos->get_debug_symbols().empty()) --num_ins_with_debug_symbols; return instructions.erase(pos); } @@ -127,7 +127,11 @@ struct module_impl instruction_ref erase(instruction_ref start, instruction_ref last) { changed.notify(); - std::for_each(start, last, [&](auto& ins) { instruction_set.erase(std::addressof(ins)); }); + std::for_each(start, last, [&](auto& ins) { + instruction_set.erase(std::addressof(ins)); + if(num_ins_with_debug_symbols > 0 and not ins.get_debug_symbols().empty()) + --num_ins_with_debug_symbols; + }); return instructions.erase(start, last); } }; @@ -624,10 +628,6 @@ instruction_ref module::remove_instruction(instruction_ref ins) { assert(has_instruction(ins)); assert(ins->outputs().empty()); - if(not ins->get_debug_symbols().empty() and impl->num_ins_with_debug_symbols > 0) - { - impl->num_ins_with_debug_symbols--; - } ins->clear_arguments(); return impl->erase(ins); } @@ -638,13 +638,7 @@ instruction_ref module::remove_instructions(instruction_ref first, instruction_r return first; // TODO: Check every element assert(has_instruction(first)); - std::for_each(first, last, [&](instruction& ins) { - if(not ins.get_debug_symbols().empty() and impl->num_ins_with_debug_symbols > 0) - { - impl->num_ins_with_debug_symbols--; - } - ins.clear_arguments(); - }); + std::for_each(first, last, [&](instruction& ins) { ins.clear_arguments(); }); assert(std::all_of(first, last, [&](const instruction& ins) { return ins.outputs().empty(); })); return impl->erase(first, last); } diff --git a/src/program.cpp b/src/program.cpp index 104b28e94d7..9cece98b2ee 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -732,7 +732,7 @@ value program::to_value() const } if(not ins->get_debug_symbols().empty()) { - nodes["debug_symbols"] = migraphx::to_value(ins->get_debug_symbols()); + node["debug_symbols"] = migraphx::to_value(ins->get_debug_symbols()); } nodes.push_back(node); }, @@ -813,7 +813,8 @@ static void mod_from_val(module_ref mod, output->set_normalized(normalized); if(node.contains("debug_symbols")) { - output->add_debug_symbols(node.at("debug_symbols").to>()); + output->add_debug_symbols( + from_value>(node.at("debug_symbols"))); } instructions[node.at("output").to()] = output; } diff --git a/test/module_test.cpp b/test/module_test.cpp index 25c518200ff..ecf431376c2 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -2031,4 +2031,190 @@ TEST_CASE(move_output_instructions_after_cross_module_mixed) EXPECT(p1 == p2); } +TEST_CASE(debug_symbols_add_and_remove) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto z = m.add_parameter("z", s); + migraphx::instruction_ref add1 = m.add_instruction(migraphx::make_op("add"), x, y); + migraphx::instruction_ref add2 = m.add_instruction(migraphx::make_op("add"), add1, z); + m.add_debug_symbols(add2, {"add2"}); + m.add_return({add1}); + EXPECT(m.has_debug_symbols()); + m.remove_instruction(add2); + EXPECT(not m.has_debug_symbols()); +} + +TEST_CASE(debug_symbols_no_symbols) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto add = m.add_instruction(migraphx::make_op("add"), x, y); + m.add_return({add}); + EXPECT(not m.has_debug_symbols()); +} + +TEST_CASE(debug_symbols_multiple_instructions) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto add1 = m.add_instruction(migraphx::make_op("add"), x, y); + auto neg = m.add_instruction(migraphx::make_op("neg"), add1); + m.add_debug_symbols(add1, {"add_op"}); + m.add_debug_symbols(neg, {"neg_op"}); + m.add_return({neg}); + EXPECT(m.has_debug_symbols()); + + m.remove_debug_symbols(add1); + EXPECT(m.has_debug_symbols()); + + m.remove_debug_symbols(neg); + EXPECT(not m.has_debug_symbols()); +} + +TEST_CASE(debug_symbols_add_multiple_symbols_to_one_instruction) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto add = m.add_instruction(migraphx::make_op("add"), x, y); + m.add_debug_symbols(add, {"sym1"}); + m.add_debug_symbols(add, {"sym2", "sym3"}); + m.add_return({add}); + + EXPECT(m.has_debug_symbols()); + auto syms = add->get_debug_symbols(); + EXPECT(syms.count("sym1") == 1); + EXPECT(syms.count("sym2") == 1); + EXPECT(syms.count("sym3") == 1); +} + +TEST_CASE(debug_symbols_remove_instructions_range) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto add1 = m.add_instruction(migraphx::make_op("add"), x, y); + auto neg = m.add_instruction(migraphx::make_op("neg"), add1); + auto relu = m.add_instruction(migraphx::make_op("relu"), neg); + m.add_debug_symbols(neg, {"neg_op"}); + m.add_debug_symbols(relu, {"relu_op"}); + m.add_return({add1}); + + EXPECT(m.has_debug_symbols()); + m.remove_instructions(neg, m.end()); + EXPECT(not m.has_debug_symbols()); +} + +TEST_CASE(debug_symbols_copy_module) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto add = m.add_instruction(migraphx::make_op("add"), x, y); + m.add_debug_symbols(add, {"add_sym"}); + m.add_return({add}); + + auto m2 = m; + EXPECT(m2.has_debug_symbols()); +} + +TEST_CASE(erase_single_with_debug_symbols) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto add1 = m.add_instruction(migraphx::make_op("add"), x, y); + auto neg = m.add_instruction(migraphx::make_op("neg"), add1); + m.add_debug_symbols(add1, {"add_sym"}); + m.add_debug_symbols(neg, {"neg_sym"}); + m.add_return({add1}); + + EXPECT(m.has_debug_symbols()); + m.remove_instruction(neg); + EXPECT(m.has_debug_symbols()); +} + +TEST_CASE(erase_single_without_debug_symbols_preserves_counter) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto add1 = m.add_instruction(migraphx::make_op("add"), x, y); + auto neg = m.add_instruction(migraphx::make_op("neg"), add1); + m.add_debug_symbols(add1, {"add_sym"}); + m.add_return({add1}); + + EXPECT(m.has_debug_symbols()); + m.remove_instruction(neg); + EXPECT(m.has_debug_symbols()); +} + +TEST_CASE(erase_single_last_debug_symbol) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto add = m.add_instruction(migraphx::make_op("add"), x, y); + auto neg = m.add_instruction(migraphx::make_op("neg"), add); + m.add_debug_symbols(neg, {"neg_sym"}); + m.add_return({add}); + + EXPECT(m.has_debug_symbols()); + m.remove_instruction(neg); + EXPECT(not m.has_debug_symbols()); +} + +TEST_CASE(erase_range_with_debug_symbols) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto add1 = m.add_instruction(migraphx::make_op("add"), x, y); + auto neg = m.add_instruction(migraphx::make_op("neg"), add1); + auto relu = m.add_instruction(migraphx::make_op("relu"), neg); + m.add_debug_symbols(add1, {"add_sym"}); + m.add_debug_symbols(neg, {"neg_sym"}); + m.add_debug_symbols(relu, {"relu_sym"}); + m.add_return({add1}); + + EXPECT(m.has_debug_symbols()); + auto ret_ins = std::prev(m.end()); + m.remove_instructions(neg, ret_ins); + EXPECT(m.has_debug_symbols()); + EXPECT(add1->get_debug_symbols().count("add_sym") == 1); +} + +TEST_CASE(erase_range_all_debug_symbols) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto add1 = m.add_instruction(migraphx::make_op("add"), x, y); + auto neg = m.add_instruction(migraphx::make_op("neg"), add1); + auto relu = m.add_instruction(migraphx::make_op("relu"), neg); + m.add_debug_symbols(neg, {"neg_sym"}); + m.add_debug_symbols(relu, {"relu_sym"}); + m.add_return({add1}); + + EXPECT(m.has_debug_symbols()); + auto ret_ins = std::prev(m.end()); + m.remove_instructions(neg, ret_ins); + EXPECT(not m.has_debug_symbols()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/serialize_program.cpp b/test/serialize_program.cpp index 951cd40c790..17318c75243 100644 --- a/test/serialize_program.cpp +++ b/test/serialize_program.cpp @@ -26,6 +26,8 @@ #include #include "test.hpp" #include +#include +#include #include @@ -138,4 +140,111 @@ TEST_CASE(program_with_module) EXPECT(p1.sort() == p2.sort()); } +static migraphx::program create_program_with_debug_symbols() +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 3}}); + auto y = mm->add_parameter("y", {migraphx::shape::float_type, {2, 3}}); + auto add = mm->add_instruction(migraphx::make_op("add"), x, y); + mm->add_debug_symbols(add, {"onnx:Add_0", "onnx:Add_1"}); + auto relu = mm->add_instruction(migraphx::make_op("relu"), add); + mm->add_debug_symbols(relu, {"onnx:Relu_0"}); + mm->add_return({relu}); + return p; +} + +using symbol_map = std::map>; + +static symbol_map collect_debug_symbols(const migraphx::module& m) +{ + symbol_map result; + for(auto ins : migraphx::iterator_for(m)) + { + auto syms = ins->get_debug_symbols(); + if(not syms.empty()) + result[ins->name()] = std::move(syms); + } + return result; +} + +TEST_CASE(debug_symbols_as_value) +{ + migraphx::program p1 = create_program_with_debug_symbols(); + migraphx::program p2; + p2.from_value(p1.to_value()); + EXPECT(p1.sort() == p2.sort()); + + auto syms1 = collect_debug_symbols(*p1.get_main_module()); + auto syms2 = collect_debug_symbols(*p2.get_main_module()); + EXPECT(syms1 == syms2); +} + +TEST_CASE(debug_symbols_as_msgpack) +{ + migraphx::file_options options; + options.format = "msgpack"; + migraphx::program p1 = create_program_with_debug_symbols(); + auto buffer = migraphx::save_buffer(p1, options); + migraphx::program p2 = migraphx::load_buffer(buffer, options); + EXPECT(p1.sort() == p2.sort()); + + auto syms1 = collect_debug_symbols(*p1.get_main_module()); + auto syms2 = collect_debug_symbols(*p2.get_main_module()); + EXPECT(syms1 == syms2); +} + +TEST_CASE(debug_symbols_as_json) +{ + migraphx::file_options options; + options.format = "json"; + migraphx::program p1 = create_program_with_debug_symbols(); + auto buffer = migraphx::save_buffer(p1, options); + migraphx::program p2 = migraphx::load_buffer(buffer, options); + EXPECT(p1.sort() == p2.sort()); + + auto syms1 = collect_debug_symbols(*p1.get_main_module()); + auto syms2 = collect_debug_symbols(*p2.get_main_module()); + EXPECT(syms1 == syms2); +} + +TEST_CASE(debug_symbols_as_file) +{ + std::string filename = "migraphx_program_debug_symbols.mxr"; + migraphx::program p1 = create_program_with_debug_symbols(); + migraphx::save(p1, filename); + migraphx::program p2 = migraphx::load(filename); + std::remove(filename.c_str()); + EXPECT(p1.sort() == p2.sort()); + + auto syms1 = collect_debug_symbols(*p1.get_main_module()); + auto syms2 = collect_debug_symbols(*p2.get_main_module()); + EXPECT(syms1 == syms2); +} + +TEST_CASE(debug_symbols_compiled) +{ + migraphx::program p1 = create_program_with_debug_symbols(); + p1.compile(migraphx::make_target("ref")); + auto buffer = migraphx::save_buffer(p1); + migraphx::program p2 = migraphx::load_buffer(buffer); + EXPECT(p1.sort() == p2.sort()); + + auto syms1 = collect_debug_symbols(*p1.get_main_module()); + auto syms2 = collect_debug_symbols(*p2.get_main_module()); + EXPECT(syms1 == syms2); +} + +TEST_CASE(no_debug_symbols_roundtrip) +{ + migraphx::program p1 = create_program(); + migraphx::program p2; + p2.from_value(p1.to_value()); + EXPECT(p1.sort() == p2.sort()); + + auto syms2 = collect_debug_symbols(*p2.get_main_module()); + EXPECT(syms2.empty()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From 1e237d50deac6315298d7abeb7da9e0d7b7a7c0f Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 23 Mar 2026 19:14:50 -0500 Subject: [PATCH 072/107] Comment update --- src/module.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index 5ab18210a9c..4e13211d684 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -390,6 +390,7 @@ static std::unordered_set gather_max_splice(const_module_ref m, } /** + * Figure out the instructions actually being spliced (min splice). * end: instruction at end of splice */ static std::unordered_set @@ -422,10 +423,8 @@ deduce_min_splice(std::vector ends, return min_splice; } -/** - * ins: instruction that was/will be replaced - * rep: replacing instruction - */ +// ins: instruction that was/will be replaced +// rep: replacing instruction static void propagate_debug_symbols(const_module_ref m, instruction_ref ins, instruction_ref rep, @@ -559,6 +558,9 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref return rep; } +// For replacing multiple instructions within a single matcher. +// Handles debug symbol propagation by having all old splice debug symbols propagate +// to the the new splice instructions. std::vector module::batch_replace_instruction( const std::vector& replacers) MIGRAPHX_TIDY_CONST { From bfa594f617013c321c588ebace0da1a0f7cc78b0 Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 23 Mar 2026 19:15:40 -0500 Subject: [PATCH 073/107] formatting --- src/instruction.cpp | 4 +--- src/module.cpp | 32 ++++++++++++++++++++------------ src/program.cpp | 7 ++----- test/debug_symbols_test.cpp | 2 +- test/module_test.cpp | 6 +++--- 5 files changed, 27 insertions(+), 24 deletions(-) diff --git a/src/instruction.cpp b/src/instruction.cpp index df9d521885d..7f1ce186a92 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -193,9 +193,7 @@ const std::vector& instruction::outputs() const { return output const std::set& instruction::get_debug_symbols() const { return debug_symbols; } void instruction::add_debug_symbols(const std::set& symbols) -{ - debug_symbols.insert(symbols.begin(), symbols.end()); -} +{ debug_symbols.insert(symbols.begin(), symbols.end()); } void instruction::remove_debug_symbols() { debug_symbols.clear(); } diff --git a/src/module.cpp b/src/module.cpp index 4e13211d684..c28cb683c1c 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -364,9 +364,8 @@ instruction_ref module::insert_instruction(instruction_ref ins, * replace_instruction and including the first not solely depenedent instruction. * stops : optional set of instructions to stop search on if encoutered. **/ -static std::unordered_set gather_max_splice(const_module_ref m, - instruction_ref ins, - std::unordered_set stops = {}) +static std::unordered_set gather_max_splice( + const_module_ref m, instruction_ref ins, std::unordered_set stops = {}) { std::unordered_set result = {ins}; fix([&](auto self, const std::vector& inputs) { @@ -377,8 +376,9 @@ static std::unordered_set gather_max_splice(const_module_ref m, if(contains(result, input)) continue; result.insert(input); - if(contains(stops, input) or any_of(input->outputs(), - [&](instruction_ref output) { return not contains(result, output); })) + if(contains(stops, input) or any_of(input->outputs(), [&](instruction_ref output) { + return not contains(result, output); + })) { // include first instruction that is not solely depenedent or in stops continue; @@ -402,7 +402,7 @@ deduce_min_splice(std::vector ends, min_splice.insert(ends.begin(), ends.end()); if(common_ancestors.empty()) return min_splice; - + // make min_splice by gathering outputs of common_ancestors within the max_splice for(auto anc : common_ancestors) { @@ -442,8 +442,10 @@ static void propagate_debug_symbols(const_module_ref m, [&new_max_splice](auto old_ins) { return contains(new_max_splice, old_ins); }); // Deduce the correct (minimum) splice from the max splice and the intersection - std::unordered_set old_splice = deduce_min_splice({ins}, old_max_splice, common_ancestors); - std::unordered_set new_splice = deduce_min_splice({rep}, new_max_splice, common_ancestors); + std::unordered_set old_splice = + deduce_min_splice({ins}, old_max_splice, common_ancestors); + std::unordered_set new_splice = + deduce_min_splice({rep}, new_max_splice, common_ancestors); std::set symbols; for(auto x : old_splice) @@ -476,7 +478,8 @@ instruction_ref module::replace_instruction(instruction_ref ins, // placeholder identity instruction auto id_ins = insert_instruction(ins, make_op("identity"), prev_args); std::unordered_set old_max_splice = gather_max_splice(this, id_ins); - std::unordered_set new_max_splice = gather_max_splice(this, ins, old_max_splice); + std::unordered_set new_max_splice = + gather_max_splice(this, ins, old_max_splice); propagate_debug_symbols(this, ins, id_ins, new_max_splice, old_max_splice); } assert(ins->valid(begin())); @@ -502,7 +505,8 @@ instruction_ref module::replace_instruction(instruction_ref ins, { auto id_ins = insert_instruction(ins, make_op("identity"), prev_args); std::unordered_set old_max_splice = gather_max_splice(this, id_ins); - std::unordered_set new_max_splice = gather_max_splice(this, ins, old_max_splice); + std::unordered_set new_max_splice = + gather_max_splice(this, ins, old_max_splice); propagate_debug_symbols(this, ins, id_ins, new_max_splice, old_max_splice); } assert(ins->valid(begin())); @@ -531,7 +535,8 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref if(has_debug_symbols()) { std::unordered_set new_max_splice = gather_max_splice(this, rep); - std::unordered_set old_max_splice = gather_max_splice(this, ins, new_max_splice); + std::unordered_set old_max_splice = + gather_max_splice(this, ins, new_max_splice); propagate_debug_symbols(this, ins, rep, new_max_splice, old_max_splice); } @@ -600,7 +605,10 @@ std::vector module::batch_replace_instruction( [&new_max_splices](auto old_ins) { return contains(new_max_splices, old_ins); }); std::vector ends; - std::transform(replacers.begin(), replacers.end(), std::back_inserter(ends), [](const auto& rep){ return rep.ins; }); + std::transform(replacers.begin(), + replacers.end(), + std::back_inserter(ends), + [](const auto& rep) { return rep.ins; }); std::unordered_set old_splice = deduce_min_splice({ends}, old_max_splices, common_ancestors); diff --git a/src/program.cpp b/src/program.cpp index 9cece98b2ee..75f3111bde0 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -813,8 +813,7 @@ static void mod_from_val(module_ref mod, output->set_normalized(normalized); if(node.contains("debug_symbols")) { - output->add_debug_symbols( - from_value>(node.at("debug_symbols"))); + output->add_debug_symbols(from_value>(node.at("debug_symbols"))); } instructions[node.at("output").to()] = output; } @@ -1378,9 +1377,7 @@ program& program::sort() } bool operator==(const program& x, const program& y) -{ - return migraphx::to_string(x) == migraphx::to_string(y); -} +{ return migraphx::to_string(x) == migraphx::to_string(y); } std::ostream& operator<<(std::ostream& os, const program& p) { diff --git a/test/debug_symbols_test.cpp b/test/debug_symbols_test.cpp index f2821a45789..da816047b94 100644 --- a/test/debug_symbols_test.cpp +++ b/test/debug_symbols_test.cpp @@ -270,7 +270,7 @@ TEST_CASE(pointwise_reduce_debug_symbols) mm->add_debug_symbols(curr, {"add0"}); curr = mm->add_instruction(migraphx::make_op("relu"), curr); mm->add_debug_symbols(curr, {"relu0"}); - auto add = add_pointwise(p1, "main:pointwise0", {curr, z}, single_pointwise("add")); + auto add = add_pointwise(p1, "main:pointwise0", {curr, z}, single_pointwise("add")); mm->add_debug_symbols(add, {"pointwise"}); auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), add); mm->add_debug_symbols(rsum, {"reduce_sum"}); diff --git a/test/module_test.cpp b/test/module_test.cpp index ecf431376c2..12229da8e99 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -2035,9 +2035,9 @@ TEST_CASE(debug_symbols_add_and_remove) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::module m; - auto x = m.add_parameter("x", s); - auto y = m.add_parameter("y", s); - auto z = m.add_parameter("z", s); + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto z = m.add_parameter("z", s); migraphx::instruction_ref add1 = m.add_instruction(migraphx::make_op("add"), x, y); migraphx::instruction_ref add2 = m.add_instruction(migraphx::make_op("add"), add1, z); m.add_debug_symbols(add2, {"add2"}); From 60143f75d9f7ea7eb56765fcfde31d898364bb9b Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 24 Mar 2026 12:55:31 -0500 Subject: [PATCH 074/107] copilot fixes --- src/module.cpp | 17 +++++++++++++---- test/onnx/parse/debug_symbols_test.cpp | 3 --- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index c28cb683c1c..4ee76c40098 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -361,8 +361,8 @@ instruction_ref module::insert_instruction(instruction_ref ins, /** * Traverse inputs of `ins` and gather instructions that output only to `ins`. * This splice is the total possibility of instructions that could be spliced by a - * replace_instruction and including the first not solely depenedent instruction. - * stops : optional set of instructions to stop search on if encoutered. + * replace_instruction and including the first not solely dependent instruction. + * stops : optional set of instructions to stop search on if encountered. **/ static std::unordered_set gather_max_splice( const_module_ref m, instruction_ref ins, std::unordered_set stops = {}) @@ -481,6 +481,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, std::unordered_set new_max_splice = gather_max_splice(this, ins, old_max_splice); propagate_debug_symbols(this, ins, id_ins, new_max_splice, old_max_splice); + remove_instruction(id_ins); } assert(ins->valid(begin())); return ins; @@ -508,6 +509,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, std::unordered_set new_max_splice = gather_max_splice(this, ins, old_max_splice); propagate_debug_symbols(this, ins, id_ins, new_max_splice, old_max_splice); + remove_instruction(id_ins); } assert(ins->valid(begin())); return ins; @@ -573,6 +575,7 @@ std::vector module::batch_replace_instruction( std::vector ret; std::unordered_set old_max_splices; std::unordered_set new_max_splices; + std::vector id_instructions; for(const auto& replacer : replacers) { assert(has_instruction(replacer.ins)); @@ -589,6 +592,7 @@ std::vector module::batch_replace_instruction( if(has_debug_symbols()) { auto id_ins = insert_instruction(replacer.ins, make_op("identity"), prev_args); + id_instructions.push_back(id_ins); auto old_ms = gather_max_splice(this, id_ins); auto new_ms = gather_max_splice(this, replacer.ins, old_ms); old_max_splices.insert(old_ms.begin(), old_ms.end()); @@ -611,9 +615,9 @@ std::vector module::batch_replace_instruction( [](const auto& rep) { return rep.ins; }); std::unordered_set old_splice = - deduce_min_splice({ends}, old_max_splices, common_ancestors); + deduce_min_splice(ends, old_max_splices, common_ancestors); std::unordered_set new_splice = - deduce_min_splice({ends}, new_max_splices, common_ancestors); + deduce_min_splice(ends, new_max_splices, common_ancestors); // include in-place debug symbols if they're there std::set symbols; @@ -630,6 +634,11 @@ std::vector module::batch_replace_instruction( { add_debug_symbols(new_ins, symbols); } + // clean up the identity placeholder instructions + for(auto id_ins : id_instructions) + { + remove_instruction(id_ins); + } } return ret; } diff --git a/test/onnx/parse/debug_symbols_test.cpp b/test/onnx/parse/debug_symbols_test.cpp index 99e242004c9..5088368b6cc 100644 --- a/test/onnx/parse/debug_symbols_test.cpp +++ b/test/onnx/parse/debug_symbols_test.cpp @@ -24,9 +24,6 @@ #include -template -struct TD; - TEST_CASE(debug_symbols_onnx_names) { migraphx::program p; From 6a3ae3c1b98f0d85ed935696f0d4599246306e9a Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 24 Mar 2026 13:13:30 -0500 Subject: [PATCH 075/107] copilot review updates --- src/include/migraphx/module.hpp | 7 ++++--- src/module.cpp | 4 ++-- test/debug_symbols_test.cpp | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 1215790144b..33b4feab9cb 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -24,9 +24,6 @@ #ifndef MIGRAPHX_GUARD_MIGRAPHLIB_MODULE_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_MODULE_HPP -#include -#include -#include #include #include #include @@ -38,6 +35,10 @@ #include #include #include +#include +#include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { diff --git a/src/module.cpp b/src/module.cpp index 4ee76c40098..dc1b2c21f27 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -380,7 +380,7 @@ static std::unordered_set gather_max_splice( return not contains(result, output); })) { - // include first instruction that is not solely depenedent or in stops + // include first instruction that is not solely dependent or in stops continue; } self(input->inputs()); @@ -567,7 +567,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref // For replacing multiple instructions within a single matcher. // Handles debug symbol propagation by having all old splice debug symbols propagate -// to the the new splice instructions. +// to the new splice instructions. std::vector module::batch_replace_instruction( const std::vector& replacers) MIGRAPHX_TIDY_CONST { diff --git a/test/debug_symbols_test.cpp b/test/debug_symbols_test.cpp index da816047b94..3d0b43da548 100644 --- a/test/debug_symbols_test.cpp +++ b/test/debug_symbols_test.cpp @@ -524,7 +524,7 @@ TEST_CASE(simplify_div_const_debug_symbols) // output using the expected comment format produced by instruction::print. // // Printed output includes: -// @2 = add(@0,@1) -> float_type, {2, 3} # sym_a, sym_b # +// @2 = add(@0,@1) -> float_type, {2, 3} # sym_a, sym_b // TEST_CASE(debug_symbols_in_print) { From 17d3f1d97f3978bf6e0ff11025665cb273dd2aee Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 24 Mar 2026 13:13:58 -0500 Subject: [PATCH 076/107] licensing --- test/serialize_program.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/serialize_program.cpp b/test/serialize_program.cpp index 17318c75243..8ba5d216f93 100644 --- a/test/serialize_program.cpp +++ b/test/serialize_program.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From 56ce83acdae3fee84c069c7e10b11bf0e3b257b6 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 24 Mar 2026 13:15:06 -0500 Subject: [PATCH 077/107] Formatting --- src/instruction.cpp | 26 +++++++------------------- src/program.cpp | 4 +--- 2 files changed, 8 insertions(+), 22 deletions(-) diff --git a/src/instruction.cpp b/src/instruction.cpp index ad39aee0235..26e9fd172ea 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -70,9 +70,7 @@ struct replace_shape_order std::size_t location(instruction_ref x) const { return std::distance(start, x); } bool operator()(instruction_ref x, instruction_ref y) const - { - return location(x) > location(y); - } + { return location(x) > location(y); } }; void instruction::replace(const shape& r) @@ -125,9 +123,7 @@ void instruction::clear_arguments() } bool operator==(const instruction& i, instruction_ref ref) -{ - return std::addressof(i) == std::addressof(*ref); -} +{ return std::addressof(i) == std::addressof(*ref); } bool instruction::valid(instruction_ref start, bool check_order) const { @@ -282,7 +278,7 @@ void instruction::replace(operation o, std::vector mdl_args) { lit = literal{}; - op = std::move(o); + op = std::move(o); replace(r); replace(std::move(args), std::move(mdl_args)); } @@ -522,9 +518,7 @@ void instruction::set_normalized(bool value) { normalized = value; } bool instruction::is_normalized() const { return normalized; } bool instruction::need_normalization() const -{ - return this->get_operator().need_normalization() and not normalized; -} +{ return this->get_operator().need_normalization() and not normalized; } operation instruction::normalized_operator() const { @@ -550,9 +544,7 @@ std::vector to_shapes(const std::vector& args) } shape compute_shape(const operation& op, const std::vector& args) -{ - return op.compute_shape(to_shapes(args)); -} +{ return op.compute_shape(to_shapes(args)); } shape compute_shape(const operation& op, const std::vector& args, @@ -583,14 +575,10 @@ std::vector try_compute_shape(const operation& op, const std::vector::iterator& ins) noexcept -{ - return iterator_address(ins); -} +{ return iterator_address(ins); } const migraphx::instruction* as_address(const std::list::const_iterator& ins) noexcept -{ - return iterator_address(ins); -} +{ return iterator_address(ins); } template static auto track_visits(instruction_ref start, instruction_ref end, F f) diff --git a/src/program.cpp b/src/program.cpp index 47f258467ac..5093c954f88 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -81,9 +81,7 @@ struct program_impl program::program() : impl(std::make_unique()) { this->create_module("main"); } program::program(module m) : impl(std::make_unique()) -{ - this->create_module("main", std::move(m)); -} +{ this->create_module("main", std::move(m)); } program::program(program&&) noexcept = default; program::~program() noexcept = default; From 6394aaef84f297e0f35e7b7d0a33f27e6db77837 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 24 Mar 2026 16:07:47 -0500 Subject: [PATCH 078/107] Update to use protobuff to more resemble ONNX file --- codecov.yml | 1 - docs/driver/read.rst | 2 +- docs/migraphx-driver.rst | 2 +- src/CMakeLists.txt | 1 - src/driver/main.cpp | 4 +- src/include/migraphx/netron_output.hpp | 7 +- src/include/migraphx/program.hpp | 6 + src/netron_output.cpp | 283 ------------------------- src/onnx/netron_output.cpp | 245 +++++++++++++++++++++ src/program.cpp | 13 +- test/netron_output_test.cpp | 89 ++++++++ 11 files changed, 354 insertions(+), 299 deletions(-) delete mode 100644 src/netron_output.cpp create mode 100644 src/onnx/netron_output.cpp create mode 100644 test/netron_output_test.cpp diff --git a/codecov.yml b/codecov.yml index 9f2569b0669..03abe2daeb2 100644 --- a/codecov.yml +++ b/codecov.yml @@ -2,4 +2,3 @@ ignore: - "test/" - "src/driver" - "build/" - - "src/netron_output.cpp" diff --git a/docs/driver/read.rst b/docs/driver/read.rst index db32b11dda7..dfc0d17b0f0 100644 --- a/docs/driver/read.rst +++ b/docs/driver/read.rst @@ -80,7 +80,7 @@ Print out program as json. .. option:: --netron -Print out program as a Netron viewable json file. +Print out program as ONNX protobuf binary viewable in Netron. .. option:: --text diff --git a/docs/migraphx-driver.rst b/docs/migraphx-driver.rst index b84edb0f371..26422154c16 100644 --- a/docs/migraphx-driver.rst +++ b/docs/migraphx-driver.rst @@ -86,7 +86,7 @@ To learn which options can be used with which commands, see the :ref:`MIGraphX d * - --binary - Prints the program in binary format * - --netron - - Prints the program in Netron viewable JSON format + - Prints the program as ONNX protobuf binary viewable in Netron * - --output | -o - Writes output in a file * - --fill0 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index bd0083dfe6f..7782c8bf81f 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -96,7 +96,6 @@ add_library(migraphx memory_coloring.cpp module.cpp msgpack.cpp - netron_output.cpp normalize_attributes.cpp normalize_ops.cpp op_enums.cpp diff --git a/src/driver/main.cpp b/src/driver/main.cpp index 7088def715d..e0f90a7f498 100644 --- a/src/driver/main.cpp +++ b/src/driver/main.cpp @@ -283,7 +283,7 @@ struct loader ap.set_value("binary")); ap(output_type, {"--netron"}, - ap.help("Print out program as Netron readable json."), + ap.help("Print out program as ONNX protobuf binary viewable in Netron."), ap.set_value("netron")); ap(output, {"--output", "-o"}, ap.help("Output to file.")); } @@ -540,7 +540,7 @@ struct loader else if(type == "binary") write(*os, save_buffer(p)); else if(type == "netron") - *os << make_netron_output(p) << std::endl; + write_netron_output(p, *os); } }; diff --git a/src/include/migraphx/netron_output.hpp b/src/include/migraphx/netron_output.hpp index fb355a2d9f5..de1f3884366 100644 --- a/src/include/migraphx/netron_output.hpp +++ b/src/include/migraphx/netron_output.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -24,14 +24,15 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_NETRON_OUTPUT_HPP #define MIGRAPHX_GUARD_RTGLIB_NETRON_OUTPUT_HPP -#include +#include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -MIGRAPHX_EXPORT std::string make_netron_output(const program& prog); +MIGRAPHX_ONNX_EXPORT void write_netron_output(const program& prog, std::ostream& os); } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/include/migraphx/program.hpp b/src/include/migraphx/program.hpp index 8bc7310c2d2..625fc9d85ba 100644 --- a/src/include/migraphx/program.hpp +++ b/src/include/migraphx/program.hpp @@ -56,6 +56,7 @@ struct marker; */ struct MIGRAPHX_EXPORT program { + program(); explicit program(module m); @@ -79,6 +80,8 @@ struct MIGRAPHX_EXPORT program std::unordered_map get_parameter_shapes() const; + int get_program_file_version() const; + std::size_t total_instructions() const; std::vector eval(const parameter_map& params, @@ -165,6 +168,9 @@ struct MIGRAPHX_EXPORT program private: void assign(const program& p); std::unique_ptr impl; + // program file version is for the data structure or format of the MXR file. Version should be bumped + // if any changes occur to the format of the MXR file. + const int program_file_version = 7; }; } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/netron_output.cpp b/src/netron_output.cpp deleted file mode 100644 index 8323556f169..00000000000 --- a/src/netron_output.cpp +++ /dev/null @@ -1,283 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ - -#include -#include -#include -#include -#include -#include -#include - -namespace migraphx { -inline namespace MIGRAPHX_INLINE_NS { -namespace { - -// from https://onnx.ai/onnx/intro/concepts.html -int get_onnx_type(shape::type_t s_type) -{ - switch(s_type) - { - case shape::float_type: return 1; - case shape::uint8_type: return 2; - case shape::int8_type: return 3; - case shape::uint16_type: return 4; - case shape::int16_type: return 5; - case shape::int32_type: return 6; - case shape::int64_type: return 7; - case shape::bool_type: return 9; - case shape::half_type: return 10; - case shape::double_type: return 11; - case shape::uint32_type: return 12; - case shape::uint64_type: return 13; - case shape::bf16_type: return 16; - case shape::fp8e4m3fn_type: return 17; - case shape::fp8e4m3fnuz_type: return 18; - case shape::fp8e5m2_type: return 19; - case shape::fp8e5m2fnuz_type: return 20; - case shape::tuple_type: return 0; - case shape::fp4x2_type: return 21; // TODO update this when the type is added - } - MIGRAPHX_THROW("MIGraphX type " + std::to_string(s_type) + " not supported"); -} - -auto make_attribute(const migraphx::value& val) -{ - value attribute = value(std::unordered_map()); - attribute["name"] = val.get_key(); - auto val_string = val.to(); - std::string sub_str = val.get_key() + ":"; - auto find_key = val_string.find(sub_str); - if(find_key != std::string::npos) - { - val_string = val_string.substr(find_key + sub_str.length() + 1); - } - // TODO: doesn't work for some reason with Netron now - // attribute["s"] = base64_encode(val_string); - // attribute["type"] = "STRING"; - attribute["docString"] = val_string; - return attribute; -} - -/// Returns a value with the JSON structure needed for a node -auto make_onnx_json_node(instruction_ref ins, - std::unordered_map ins_uids) -{ - value node; - // TODO add support for module inputs - value input_arr = value({}); - for(instruction_ref input_ins : ins->inputs()) - { - auto name = input_ins->name(); - if(name == "@literal" or name == "@param") - { - input_arr.push_back(ins_uids.at(input_ins)); - } - // TODO make a better process for handling nodes to ignore - else if(name.find("hip::hip_allocate_memory") != std::string::npos) - { - continue; - } - else - { - input_arr.push_back(ins_uids.at(input_ins) + "->" + ins_uids.at(ins)); - } - } - value output_arr = value({}); - for(instruction_ref output_ins : ins->outputs()) - { - if(output_ins->name() == "@return") - { - output_arr.push_back(ins_uids.at(output_ins)); - } - else - { - output_arr.push_back(ins_uids.at(ins) + "->" + ins_uids.at(output_ins)); - } - } - node["input"] = input_arr; - node["output"] = output_arr; - node["name"] = ins_uids.at(ins); - node["opType"] = ins->name(); - value op_attribute_arr = value({}); - auto op_value = ins->get_operator().to_value(); - std::for_each(op_value.begin(), op_value.end(), [&](const auto& v) { - const std::string& attr_key = v.get_key(); - if(v.is_binary() or attr_key == "code_object") - { - return; - } - else if(attr_key == "symbol_name" or attr_key == "name") - { - node["opType"] = migraphx::from_value(v); - } - else - { - op_attribute_arr.push_back(make_attribute(v)); - } - }); - node["attribute"] = op_attribute_arr; - return node; -} - -// ONNX graph constant data called "initializer" -auto make_onnx_json_literal(instruction_ref ins, - std::unordered_map ins_uids) -{ - value lit; - lit["dims"] = ins->get_shape().lens(); - lit["dataType"] = get_onnx_type(ins->get_shape().type()); - lit["name"] = ins_uids.at(ins); - // ignoring literal data, setting to "NULL" in base64 - lit["rawData"] = "TlVMTA=="; - return lit; -} - -// TODO handle dynamic shapes -// TODO handle subshapes -auto make_onnx_json_shape(const shape& s) -{ - value ret; - value dim = value({}); - for(std::size_t len : s.lens()) - { - // cppcheck-suppress useStlAlgorithm - dim.push_back({{"dimValue", len}}); - } - ret["dim"] = dim; - return ret; -} - -// ONNX graph edges called "valueType" -auto make_onnx_json_edge(instruction_ref ins, - instruction_ref out_ins, - std::unordered_map ins_uids) -{ - value ret; - shape ins_shape = ins->get_shape(); - ret["name"] = ins_uids.at(ins) + "->" + ins_uids.at(out_ins); - value type = {{"tensorType", - {{"elemType", get_onnx_type(ins_shape.type())}, - {"shape", make_onnx_json_shape(ins_shape)}}}}; - ret["type"] = type; - return ret; -} - -auto make_onnx_json_in_out(instruction_ref ins, - std::unordered_map ins_uids) -{ - value ret; - shape ins_shape = ins->get_shape(); - ret["name"] = ins_uids.at(ins); - value type = {{"tensorType", - {{"elemType", get_onnx_type(ins_shape.type())}, - {"shape", make_onnx_json_shape(ins_shape)}}}}; - ret["type"] = type; - return ret; -} - -std::unordered_map make_ins_uids(const module& mod) -{ - std::unordered_map ret; - int count = 0; - for(auto ins : iterator_for(mod)) - { - std::string var_name; - var_name = mod.name() + ":"; - var_name.append(ins->name() + ":"); - if(ins->name() == "@param") - { - var_name.append(any_cast(ins->get_operator()).parameter + ":"); - } - var_name.append("@" + std::to_string(count)); - count++; - ret.emplace(ins, var_name); - } - return ret; -} - -value make_graph(const module* mod) -{ - value graph = {{"node", value({})}, - {"initializer", value({})}, - {"input", value({})}, - {"output", value({})}, - {"valueInfo", value({})}}; - auto ins_uids = make_ins_uids(*mod); - for(auto ins = mod->begin(); ins != mod->end(); ++ins) - { - const auto& name = ins->name(); - if(name == "@literal") - { - graph["initializer"].push_back(make_onnx_json_literal(ins, ins_uids)); - } - else if(name == "@param") - { - graph["input"].push_back(make_onnx_json_in_out(ins, ins_uids)); - } - else if(name == "@return") - { - graph["output"].push_back(make_onnx_json_in_out(ins, ins_uids)); - } - else if(name.find("hip::hip_allocate_memory") != std::string::npos) - { - continue; - } - else - { - graph["node"].push_back(make_onnx_json_node(ins, ins_uids)); - const auto& outputs = ins->outputs(); - for(auto out_ins : outputs) - { - if(out_ins->name() != "@return") - { - graph["valueInfo"].push_back(make_onnx_json_edge(ins, out_ins, ins_uids)); - } - } - } - } - return graph; -} - -} // namespace - -std::string make_netron_output(const program& prog) -{ - value output; - auto prog_value = prog.to_value(); - // ONNX IR version 6 - // TODO: investigate sure how this affects things - output["irVersion"] = 6; - output["producerName"] = "AMDMIGraphX"; - output["producerVersion"] = prog_value.at("migraphx_version").to(); - for(auto& mod : prog.get_modules()) - { - auto graph = make_graph(mod); - output["graph"] = graph; - } - return to_pretty_json_string(output, 4); -} - -} // namespace MIGRAPHX_INLINE_NS -} // namespace migraphx diff --git a/src/onnx/netron_output.cpp b/src/onnx/netron_output.cpp new file mode 100644 index 00000000000..b94743910ee --- /dev/null +++ b/src/onnx/netron_output.cpp @@ -0,0 +1,245 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +namespace onnx = onnx_for_migraphx; + +namespace { + +int get_onnx_type(shape::type_t s_type) +{ + switch(s_type) + { + case shape::float_type: return onnx::TensorProto::FLOAT; + case shape::uint8_type: return onnx::TensorProto::UINT8; + case shape::int8_type: return onnx::TensorProto::INT8; + case shape::uint16_type: return onnx::TensorProto::UINT16; + case shape::int16_type: return onnx::TensorProto::INT16; + case shape::int32_type: return onnx::TensorProto::INT32; + case shape::int64_type: return onnx::TensorProto::INT64; + case shape::bool_type: return onnx::TensorProto::BOOL; + case shape::half_type: return onnx::TensorProto::FLOAT16; + case shape::double_type: return onnx::TensorProto::DOUBLE; + case shape::uint32_type: return onnx::TensorProto::UINT32; + case shape::uint64_type: return onnx::TensorProto::UINT64; + case shape::bf16_type: return onnx::TensorProto::BFLOAT16; + case shape::fp8e4m3fn_type: return onnx::TensorProto::FLOAT8E4M3FN; + case shape::fp8e4m3fnuz_type: return onnx::TensorProto::FLOAT8E4M3FNUZ; + case shape::fp8e5m2_type: return onnx::TensorProto::FLOAT8E5M2; + case shape::fp8e5m2fnuz_type: return onnx::TensorProto::FLOAT8E5M2FNUZ; + case shape::tuple_type: return onnx::TensorProto::UNDEFINED; + case shape::fp4x2_type: return onnx::TensorProto::UINT4; + } + MIGRAPHX_THROW("MIGraphX type " + std::to_string(s_type) + " not supported"); +} + +std::unordered_map make_ins_uids(const module& mod) +{ + std::unordered_map ret; + int count = 0; + for(auto ins : iterator_for(mod)) + { + std::string var_name; + var_name = mod.name() + ":"; + var_name.append(ins->name() + ":"); + if(ins->name() == "@param") + { + var_name.append(any_cast(ins->get_operator()).parameter + ":"); + } + var_name.append("@" + std::to_string(count)); + count++; + ret.emplace(ins, var_name); + } + return ret; +} + +void set_shape_proto(onnx::TensorShapeProto* shape_proto, const shape& s) +{ + for(std::size_t len : s.lens()) + { + shape_proto->add_dim()->set_dim_value(len); + } +} + +void set_value_info(onnx::ValueInfoProto* vi, const std::string& name, const shape& s) +{ + vi->set_name(name); + auto* type = vi->mutable_type(); + auto* tensor = type->mutable_tensor_type(); + tensor->set_elem_type(get_onnx_type(s.type())); + set_shape_proto(tensor->mutable_shape(), s); +} + +void add_initializer(onnx::GraphProto* graph, + instruction_ref ins, + const std::unordered_map& ins_uids) +{ + auto* init = graph->add_initializer(); + init->set_name(ins_uids.at(ins)); + init->set_data_type(get_onnx_type(ins->get_shape().type())); + for(std::size_t d : ins->get_shape().lens()) + { + init->add_dims(d); + } + init->set_raw_data("NULL"); +} + +void add_node(onnx::GraphProto* graph, + instruction_ref ins, + const std::unordered_map& ins_uids) +{ + auto* node = graph->add_node(); + + std::string op_type = ins->name(); + auto op_value = ins->get_operator().to_value(); + std::for_each(op_value.begin(), op_value.end(), [&](const auto& v) { + const std::string& attr_key = v.get_key(); + if(v.is_binary() or attr_key == "code_object") + { + return; + } + else if(attr_key == "symbol_name" or attr_key == "name") + { + op_type = migraphx::from_value(v); + } + else + { + auto* attr = node->add_attribute(); + attr->set_name(attr_key); + + auto val_string = v.template to(); + std::string sub_str = attr_key + ":"; + auto find_key = val_string.find(sub_str); + if(find_key != std::string::npos) + { + val_string = val_string.substr(find_key + sub_str.length() + 1); + } + attr->set_type(onnx::AttributeProto::STRING); + attr->set_s(val_string); + } + }); + + node->set_op_type(op_type); + node->set_name(ins_uids.at(ins)); + + for(instruction_ref input_ins : ins->inputs()) + { + auto name = input_ins->name(); + if(name == "@literal" or name == "@param") + { + node->add_input(ins_uids.at(input_ins)); + } + else if(name.find("hip::hip_allocate_memory") != std::string::npos) + { + continue; + } + else + { + node->add_input(ins_uids.at(input_ins) + "->" + ins_uids.at(ins)); + } + } + + for(instruction_ref output_ins : ins->outputs()) + { + if(output_ins->name() == "@return") + { + node->add_output(ins_uids.at(output_ins)); + } + else + { + node->add_output(ins_uids.at(ins) + "->" + ins_uids.at(output_ins)); + } + } +} + +void build_graph(onnx::GraphProto* graph, const module* mod) +{ + auto ins_uids = make_ins_uids(*mod); + for(auto ins = mod->begin(); ins != mod->end(); ++ins) + { + const auto& name = ins->name(); + if(name == "@literal") + { + add_initializer(graph, ins, ins_uids); + } + else if(name == "@param") + { + set_value_info(graph->add_input(), ins_uids.at(ins), ins->get_shape()); + } + else if(name == "@return") + { + set_value_info(graph->add_output(), ins_uids.at(ins), ins->get_shape()); + } + else if(name.find("hip::hip_allocate_memory") != std::string::npos) + { + continue; + } + else + { + add_node(graph, ins, ins_uids); + for(auto out_ins : ins->outputs()) + { + if(out_ins->name() != "@return") + { + set_value_info(graph->add_value_info(), + ins_uids.at(ins) + "->" + ins_uids.at(out_ins), + ins->get_shape()); + } + } + } + } +} + +} // namespace + +void write_netron_output(const program& prog, std::ostream& os) +{ + onnx::ModelProto model; + auto prog_value = prog.to_value(); + model.set_ir_version(6); + model.set_producer_name("AMDMIGraphX"); + model.set_producer_version(prog_value.at("migraphx_version").to()); + + for(auto* mod : prog.get_modules()) + { + build_graph(model.mutable_graph(), mod); + } + + model.SerializeToOstream(&os); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/program.cpp b/src/program.cpp index 4ca08842e74..31601b44db0 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -160,6 +160,11 @@ std::unordered_map program::get_parameter_shapes() const return mm->get_parameter_shapes(); } +int program::get_program_file_version() const +{ + return program_file_version; +} + std::size_t program::size() const { return impl->modules.size(); } std::vector program::get_output_shapes() const @@ -681,16 +686,10 @@ static std::string get_migraphx_version() return ss.str(); } -/* -program file version is for the data structure or format of the MXR file. Version should be bumped -if any changes occur to the format of the MXR file. -*/ -const int program_file_version = 7; - value program::to_value() const { value result; - result["version"] = program_file_version; + result["version"] = get_program_file_version(); result["migraphx_version"] = get_migraphx_version(); result["targets"] = migraphx::to_value(this->impl->targets); result["contexts"] = migraphx::to_value(this->impl->contexts); diff --git a/test/netron_output_test.cpp b/test/netron_output_test.cpp new file mode 100644 index 00000000000..a228fce6cf6 --- /dev/null +++ b/test/netron_output_test.cpp @@ -0,0 +1,89 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include + +TEST_CASE(netron_output_basic) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 3}}); + auto y = mm->add_parameter("y", {migraphx::shape::float_type, {2, 3}}); + auto sum = mm->add_instruction(migraphx::make_op("add"), x, y); + mm->add_return({sum}); + + std::ostringstream os; + migraphx::write_netron_output(p, os); + + std::string output = os.str(); + EXPECT(not output.empty()); + EXPECT(output.size() > 10); +} + +TEST_CASE(netron_output_with_literal) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 3}}); + auto lit = mm->add_literal(migraphx::literal{{migraphx::shape::float_type, {2, 3}}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); + auto sum = mm->add_instruction(migraphx::make_op("add"), x, lit); + mm->add_return({sum}); + + std::ostringstream os; + migraphx::write_netron_output(p, os); + + std::string output = os.str(); + EXPECT(not output.empty()); + EXPECT(output.size() > 10); +} + +TEST_CASE(netron_output_roundtrip) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 3}}); + auto y = mm->add_parameter("y", {migraphx::shape::float_type, {2, 3}}); + auto sum = mm->add_instruction(migraphx::make_op("add"), x, y); + mm->add_return({sum}); + + std::ostringstream os; + migraphx::write_netron_output(p, os); + std::string output = os.str(); + + // The output should be parseable as a valid ONNX model + migraphx::onnx_options options; + options.skip_unknown_operators = true; + auto p2 = migraphx::parse_onnx_buffer(output.data(), output.size(), options); + EXPECT(p2.get_main_module()->size() > 0); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } From ee4aa217fc58997b1e466fe3b96d530380fa2be3 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 24 Mar 2026 16:32:45 -0500 Subject: [PATCH 079/107] cleanup --- src/onnx/netron_output.cpp | 2 +- test/netron_output_test.cpp | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/onnx/netron_output.cpp b/src/onnx/netron_output.cpp index b94743910ee..2a5f866467a 100644 --- a/src/onnx/netron_output.cpp +++ b/src/onnx/netron_output.cpp @@ -229,7 +229,7 @@ void write_netron_output(const program& prog, std::ostream& os) { onnx::ModelProto model; auto prog_value = prog.to_value(); - model.set_ir_version(6); + model.set_ir_version(prog.get_program_file_version()); model.set_producer_name("AMDMIGraphX"); model.set_producer_version(prog_value.at("migraphx_version").to()); diff --git a/test/netron_output_test.cpp b/test/netron_output_test.cpp index a228fce6cf6..26aa8b7f243 100644 --- a/test/netron_output_test.cpp +++ b/test/netron_output_test.cpp @@ -79,7 +79,8 @@ TEST_CASE(netron_output_roundtrip) migraphx::write_netron_output(p, os); std::string output = os.str(); - // The output should be parseable as a valid ONNX model + // The output should be parseable as a valid ONNX model. + // Most nodes will be unknown operators however. migraphx::onnx_options options; options.skip_unknown_operators = true; auto p2 = migraphx::parse_onnx_buffer(output.data(), output.size(), options); From 2874c28aa4d0c08c4d7dd1d1dd0f7a707441b0c4 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 24 Mar 2026 17:10:27 -0500 Subject: [PATCH 080/107] Add netron output to APIs --- src/api/api.cpp | 19 ++++++ src/api/include/migraphx/migraphx.h | 3 + src/api/include/migraphx/migraphx.hpp | 7 ++- src/api/migraphx.py | 4 ++ src/py/migraphx_py.cpp | 9 +++ test/api/CMakeLists.txt | 1 + test/api/test_netron_output.cpp | 88 +++++++++++++++++++++++++++ test/netron_output_test.cpp | 2 +- test/py/CMakeLists.txt | 1 + test/py/test_netron_output.py | 59 ++++++++++++++++++ tools/api/api.cpp | 8 +++ 11 files changed, 199 insertions(+), 2 deletions(-) create mode 100644 test/api/test_netron_output.cpp create mode 100644 test/py/test_netron_output.py diff --git a/src/api/api.cpp b/src/api/api.cpp index 46613777117..51662c08842 100644 --- a/src/api/api.cpp +++ b/src/api/api.cpp @@ -39,7 +39,9 @@ #include #include #include +#include #include +#include #include #include @@ -332,6 +334,12 @@ static std::vector get_output_shapes(program& p) { return p.get_output_sh static void print_program(const program& p) { std::cout << p << std::endl; } +static void write_netron_output_file(const program& p, const char* filename) +{ + std::ofstream os(filename, std::ios::binary); + write_netron_output(p, os); +} + static void print_module(const module& m) { std::cout << m << std::endl; } static migraphx::instruction_ref add_allocation(module& m, const migraphx::shape& s) @@ -1751,6 +1759,17 @@ extern "C" migraphx_status migraphx_program_print(const_migraphx_program_t progr return api_error_result; } +extern "C" migraphx_status migraphx_program_write_netron_output(const_migraphx_program_t program, + const char* filename) +{ + auto api_error_result = migraphx::try_([&] { + if(program == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); + migraphx::write_netron_output_file((program->object), (filename)); + }); + return api_error_result; +} + extern "C" migraphx_status migraphx_program_sort(migraphx_program_t program) { auto api_error_result = migraphx::try_([&] { diff --git a/src/api/include/migraphx/migraphx.h b/src/api/include/migraphx/migraphx.h index c4a992abcf5..35d01097141 100644 --- a/src/api/include/migraphx/migraphx.h +++ b/src/api/include/migraphx/migraphx.h @@ -464,6 +464,9 @@ MIGRAPHX_C_EXPORT migraphx_status migraphx_program_get_output_shapes(migraphx_sh MIGRAPHX_C_EXPORT migraphx_status migraphx_program_print(const_migraphx_program_t program); +MIGRAPHX_C_EXPORT migraphx_status +migraphx_program_write_netron_output(const_migraphx_program_t program, const char* filename); + MIGRAPHX_C_EXPORT migraphx_status migraphx_program_sort(migraphx_program_t program); MIGRAPHX_C_EXPORT migraphx_status migraphx_program_run(migraphx_arguments_t* out, diff --git a/src/api/include/migraphx/migraphx.hpp b/src/api/include/migraphx/migraphx.hpp index d059f7a95c8..a27c3c2810d 100644 --- a/src/api/include/migraphx/migraphx.hpp +++ b/src/api/include/migraphx/migraphx.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -1232,6 +1232,11 @@ struct program : MIGRAPHX_HANDLE_BASE(program) void print() const { call(&migraphx_program_print, this->get_handle_ptr()); } + void write_netron_output(const char* filename) const + { + call(&migraphx_program_write_netron_output, this->get_handle_ptr(), filename); + } + program sort() { call(&migraphx_program_sort, this->get_handle_ptr()); diff --git a/src/api/migraphx.py b/src/api/migraphx.py index 3c20e926268..5b1d8b9cdd1 100644 --- a/src/api/migraphx.py +++ b/src/api/migraphx.py @@ -284,6 +284,10 @@ def program(h): invoke='migraphx::get_output_shapes($@)', returns='std::vector') h.method('print', invoke='migraphx::print_program($@)', const=True) + h.method('write_netron_output', + api.params(filename='const char*'), + invoke='migraphx::write_netron_output_file($@)', + const=True) h.method('sort') h.method('run', api.params( diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp index e591474d499..49bdc048714 100644 --- a/src/py/migraphx_py.cpp +++ b/src/py/migraphx_py.cpp @@ -38,7 +38,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -567,6 +569,13 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) return ss.str(); }) .def("sort", &migraphx::program::sort) + .def("write_netron_output", + [](const migraphx::program& p, const std::string& filename) { + std::ofstream os(filename, std::ios::binary); + migraphx::write_netron_output(p, os); + }, + "Write program as ONNX protobuf binary viewable in Netron", + py::arg("filename")) .def("print", [](const migraphx::program& p) { std::cout << p << std::endl; }) .def("__eq__", std::equal_to{}) .def("__ne__", std::not_equal_to{}) diff --git a/test/api/CMakeLists.txt b/test/api/CMakeLists.txt index f4bd1d7140f..f0b22e6e967 100644 --- a/test/api/CMakeLists.txt +++ b/test/api/CMakeLists.txt @@ -60,6 +60,7 @@ add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR}) add_api_test(module_construct test_module_construct.cpp ${TEST_ONNX_DIR}) add_api_test(dynamic_shape test_dynamic_shape.cpp ${TEST_ONNX_DIR}) add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR}) +add_api_test(netron_output test_netron_output.cpp ${TEST_ONNX_DIR}) add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR}) add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR}) add_c_api_test(c_op test_c_op_construct.c ${TEST_ONNX_DIR}) diff --git a/test/api/test_netron_output.cpp b/test/api/test_netron_output.cpp new file mode 100644 index 00000000000..1da537dff27 --- /dev/null +++ b/test/api/test_netron_output.cpp @@ -0,0 +1,88 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include "test.hpp" + +TEST_CASE(netron_output_cpp_api) +{ + auto p = migraphx::parse_onnx("conv_relu_maxpool_test.onnx"); + std::string filename = "migraphx_api_netron_output_test.onnx"; + p.write_netron_output(filename.c_str()); + + std::ifstream ifs(filename, std::ios::binary | std::ios::ate); + EXPECT(ifs.good()); + auto size = ifs.tellg(); + EXPECT(size > 0); + + std::remove(filename.c_str()); +} + +TEST_CASE(netron_output_c_api) +{ + migraphx_program_t p; + migraphx_onnx_options_t onnx_options; + migraphx_onnx_options_create(&onnx_options); + auto status = migraphx_parse_onnx(&p, "conv_relu_maxpool_test.onnx", onnx_options); + EXPECT(status == migraphx_status_success); + + std::string filename = "migraphx_c_api_netron_output_test.onnx"; + status = migraphx_program_write_netron_output(p, filename.c_str()); + EXPECT(status == migraphx_status_success); + + std::ifstream ifs(filename, std::ios::binary | std::ios::ate); + EXPECT(ifs.good()); + auto size = ifs.tellg(); + EXPECT(size > 0); + + std::remove(filename.c_str()); + migraphx_program_destroy(p); + migraphx_onnx_options_destroy(onnx_options); +} + +TEST_CASE(netron_output_constructed_program) +{ + migraphx::program p; + migraphx::module m = p.get_main_module(); + migraphx::shape s{migraphx_shape_float_type, {2, 3}}; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto add_op = migraphx::operation("add"); + auto r = m.add_instruction(add_op, {x, y}); + m.add_return({r}); + + std::string filename = "migraphx_api_netron_constructed_test.onnx"; + p.write_netron_output(filename.c_str()); + + std::ifstream ifs(filename, std::ios::binary | std::ios::ate); + EXPECT(ifs.good()); + auto size = ifs.tellg(); + EXPECT(size > 0); + + std::remove(filename.c_str()); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/netron_output_test.cpp b/test/netron_output_test.cpp index 26aa8b7f243..83634c7d8a5 100644 --- a/test/netron_output_test.cpp +++ b/test/netron_output_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/py/CMakeLists.txt b/test/py/CMakeLists.txt index f2bfd2e8555..e3d5d05e622 100644 --- a/test/py/CMakeLists.txt +++ b/test/py/CMakeLists.txt @@ -99,6 +99,7 @@ endif() add_py_test(ref test_cpu.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(save_load test_save_load.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR}) +add_py_test(netron_output test_netron_output.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(op test_op.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(shape test_shape.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(module_construct test_module_construct.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR}) diff --git a/test/py/test_netron_output.py b/test/py/test_netron_output.py new file mode 100644 index 00000000000..aec1ecb4f22 --- /dev/null +++ b/test/py/test_netron_output.py @@ -0,0 +1,59 @@ +##################################################################################### +# The MIT License (MIT) +# +# Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +##################################################################################### +import migraphx, tempfile, os + + +def test_netron_output_parsed_model(): + p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx") + + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as t: + filename = t.name + + p.write_netron_output(filename) + size = os.path.getsize(filename) + assert size > 0, "Netron output file is empty" + os.remove(filename) + + +def test_netron_output_constructed_program(): + p = migraphx.program() + mm = p.get_main_module() + s = migraphx.shape(lens=[2, 3], type="float") + x = mm.add_parameter("x", s) + y = mm.add_parameter("y", s) + add_ins = mm.add_instruction(migraphx.op("add"), [x, y]) + mm.add_return([add_ins]) + + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as t: + filename = t.name + + p.write_netron_output(filename) + size = os.path.getsize(filename) + assert size > 0, "Netron output file is empty" + os.remove(filename) + + +if __name__ == "__main__": + test_netron_output_parsed_model() + test_netron_output_constructed_program() diff --git a/tools/api/api.cpp b/tools/api/api.cpp index cfefcf39f29..594923e840c 100644 --- a/tools/api/api.cpp +++ b/tools/api/api.cpp @@ -39,7 +39,9 @@ #include #include #include +#include #include +#include #include #include @@ -332,6 +334,12 @@ static std::vector get_output_shapes(program& p) { return p.get_output_sh static void print_program(const program& p) { std::cout << p << std::endl; } +static void write_netron_output_file(const program& p, const char* filename) +{ + std::ofstream os(filename, std::ios::binary); + write_netron_output(p, os); +} + static void print_module(const module& m) { std::cout << m << std::endl; } static migraphx::instruction_ref add_allocation(module& m, const migraphx::shape& s) From 115f1117798339df52b2e2d3591ae74d9d97d4a9 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 24 Mar 2026 18:17:59 -0500 Subject: [PATCH 081/107] Fix bug with replace_instruction debug symbols that have no inputs Hit the bug with replace_allocate bug --- src/module.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index dc1b2c21f27..871501856c6 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -386,6 +386,10 @@ static std::unordered_set gather_max_splice( self(input->inputs()); } })(ins->inputs()); + if(result.size() > 50) + { + std::cout << "max splice larger than 50 instructions" << std::endl; + } return result; } @@ -473,11 +477,13 @@ instruction_ref module::replace_instruction(instruction_ref ins, prev_args = ins->inputs(); } instruction::replace(ins, op, r, std::move(args)); - if(has_debug_symbols()) + if(has_debug_symbols() and not prev_args.empty()) { // placeholder identity instruction auto id_ins = insert_instruction(ins, make_op("identity"), prev_args); + // Get old_max_splice after replacement to get smallest dependent splice std::unordered_set old_max_splice = gather_max_splice(this, id_ins); + // TODO: if there are no common ancestors, this may traverse the majority of the graph std::unordered_set new_max_splice = gather_max_splice(this, ins, old_max_splice); propagate_debug_symbols(this, ins, id_ins, new_max_splice, old_max_splice); @@ -502,7 +508,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, prev_args = ins->inputs(); } instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args)); - if(has_debug_symbols()) + if(has_debug_symbols() and not prev_args.empty()) { auto id_ins = insert_instruction(ins, make_op("identity"), prev_args); std::unordered_set old_max_splice = gather_max_splice(this, id_ins); @@ -589,7 +595,7 @@ std::vector module::batch_replace_instruction( instruction::replace( replacer.ins, replacer.op, out_shape, replacer.args, replacer.module_args); ret.push_back(replacer.ins); - if(has_debug_symbols()) + if(has_debug_symbols() and not prev_args.empty()) { auto id_ins = insert_instruction(replacer.ins, make_op("identity"), prev_args); id_instructions.push_back(id_ins); From 9d8f0cfdb30c1eccb2da96660bbfad295722d70a Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 24 Mar 2026 18:31:54 -0500 Subject: [PATCH 082/107] Remove print --- src/module.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index 871501856c6..6c41e8fef8f 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -386,10 +386,6 @@ static std::unordered_set gather_max_splice( self(input->inputs()); } })(ins->inputs()); - if(result.size() > 50) - { - std::cout << "max splice larger than 50 instructions" << std::endl; - } return result; } From 3eea60cc019700bad510aa4728caba70e96737c7 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 25 Mar 2026 18:59:33 -0500 Subject: [PATCH 083/107] Move and change tests --- test/debug_symbols_test.cpp | 439 ++++++++++++++++++++------------- test/fuse_pointwise.cpp | 99 ++++++++ test/simplify_algebra_test.cpp | 230 +++++++++++++++++ 3 files changed, 591 insertions(+), 177 deletions(-) diff --git a/test/debug_symbols_test.cpp b/test/debug_symbols_test.cpp index 3d0b43da548..d2b56931110 100644 --- a/test/debug_symbols_test.cpp +++ b/test/debug_symbols_test.cpp @@ -22,8 +22,6 @@ * THE SOFTWARE. */ -#include -#include #include #include #include @@ -33,20 +31,18 @@ #include #include #include -#include #include -#include -#include -// Two adds fused into a single pointwise op via fuse_pointwise. -// Both symbols should appear on the fused pointwise instruction. +// Two adds replaced by a single pass_op via replace_instruction, +// emulating what fuse_pointwise does without running the pass. +// add1 feeds only into add2 (splice chain), so both symbols merge. // // Before: After: // // x y x y z // \ / \ | / -// add {add1} pointwise {add1, add2} +// add {add1} pass {add1, add2} // | z | // | / @return // add {add2} @@ -56,42 +52,44 @@ TEST_CASE(pw_double_add) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - migraphx::program p1; + migraphx::module m1; { - auto* mm = p1.get_main_module(); - auto x = mm->add_parameter("x", s); - auto y = mm->add_parameter("y", s); - auto z = mm->add_parameter("z", s); - migraphx::instruction_ref add1 = mm->add_instruction(migraphx::make_op("add"), x, y); - mm->add_debug_symbols(add1, {"add1"}); - migraphx::instruction_ref add2 = mm->add_instruction(migraphx::make_op("add"), add1, z); - mm->add_debug_symbols(add2, {"add2"}); - mm->add_return({add2}); + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto z = m1.add_parameter("z", s); + auto add1 = m1.add_instruction(migraphx::make_op("add"), x, y); + m1.add_debug_symbols(add1, {"add1"}); + auto add2 = m1.add_instruction(migraphx::make_op("add"), add1, z); + m1.add_debug_symbols(add2, {"add2"}); + m1.add_return({add2}); + + m1.replace_instruction(add2, pass_op{}, x, y, z); } - migraphx::run_passes(p1, {migraphx::fuse_pointwise{}, migraphx::dead_code_elimination{}}); + migraphx::run_passes(m1, {migraphx::dead_code_elimination{}}); - migraphx::program p2; + migraphx::module m2; { - auto* mm = p2.get_main_module(); - auto x = mm->add_parameter("x", s); - auto y = mm->add_parameter("y", s); - auto z = mm->add_parameter("z", s); - auto fadd = - add_pointwise(p2, "main:pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) { - auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); - return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]); - }); - mm->add_debug_symbols(fadd, {"add1", "add2"}); - mm->add_return({fadd}); + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto z = m2.add_parameter("z", s); + auto fadd = m2.add_instruction(pass_op{}, x, y, z); + m2.add_debug_symbols(fadd, {"add1", "add2"}); + m2.add_return({fadd}); } - EXPECT(p1 == p2); + EXPECT(m1 == m2); } +// Diamond of four adds replaced by a single pass_op via replace_instruction, +// emulating what fuse_pointwise does without running the pass. +// add2 and add3 each feed only into add4 (splice chain), and add1 +// feeds only into add2 and add3 (both in the chain), so all four +// symbols merge onto the replacement. +// // Before: After: // // x y x y // \ / \ / -// add1 {add1} pointwise {add1, add2, add3, add4} +// add1 {add1} pass {add1, add2, add3, add4} // / \ | // x y @return // | | @@ -104,42 +102,39 @@ TEST_CASE(pw_double_add) TEST_CASE(pw_used_twice_fused) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - migraphx::program p1; + migraphx::module m1; { - auto* mm = p1.get_main_module(); - auto x = mm->add_parameter("x", s); - auto y = mm->add_parameter("y", s); - auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); - mm->add_debug_symbols(add1, {"onnx:add1"}); - auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, x); - mm->add_debug_symbols(add2, {"onnx:add2"}); - auto add3 = mm->add_instruction(migraphx::make_op("add"), add1, y); - mm->add_debug_symbols(add3, {"onnx:add3"}); - auto add4 = mm->add_instruction(migraphx::make_op("add"), add2, add3); - mm->add_debug_symbols(add4, {"onnx:add4"}); - mm->add_return({add4}); + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto add1 = m1.add_instruction(migraphx::make_op("add"), x, y); + m1.add_debug_symbols(add1, {"onnx:add1"}); + auto add2 = m1.add_instruction(migraphx::make_op("add"), add1, x); + m1.add_debug_symbols(add2, {"onnx:add2"}); + auto add3 = m1.add_instruction(migraphx::make_op("add"), add1, y); + m1.add_debug_symbols(add3, {"onnx:add3"}); + auto add4 = m1.add_instruction(migraphx::make_op("add"), add2, add3); + m1.add_debug_symbols(add4, {"onnx:add4"}); + m1.add_return({add4}); + + m1.replace_instruction(add4, pass_op{}, x, y); } - migraphx::run_passes(p1, {migraphx::fuse_pointwise{}, migraphx::dead_code_elimination{}}); + migraphx::run_passes(m1, {migraphx::dead_code_elimination{}}); - migraphx::program p2; + migraphx::module m2; { - auto* mm = p2.get_main_module(); - auto x = mm->add_parameter("x", s); - auto y = mm->add_parameter("y", s); - auto fadd = add_pointwise(p2, "main:pointwise0", {x, y}, [=](auto* pm, const auto& inputs) { - auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); - auto add2 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[0]); - auto add3 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[1]); - return pm->add_instruction(migraphx::make_op("add"), add2, add3); - }); - mm->add_debug_symbols(fadd, {"onnx:add1", "onnx:add2", "onnx:add3", "onnx:add4"}); - mm->add_return({fadd}); + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto fadd = m2.add_instruction(pass_op{}, x, y); + m2.add_debug_symbols(fadd, {"onnx:add1", "onnx:add2", "onnx:add3", "onnx:add4"}); + m2.add_return({fadd}); } - EXPECT(p1.sort() == p2.sort()); + EXPECT(m1 == m2); } // Debug symbols should not propagate above the fusion boundary. -// The gemm (dot) keeps its own symbol; only the fused adds merge. +// The dot is a common ancestor in both the old and new splice, so +// its symbol stays on the dot instruction. Only add1 and add2 (the +// splice chain between dot and the replacement) merge their symbols. // // Before: After: // @@ -147,9 +142,9 @@ TEST_CASE(pw_used_twice_fused) // \ / \ / // dot {gemm1} dot {gemm1} // | y / -// | / gemm y z +// | / / y z // add {add1} \ | / -// | z pointwise {add1, add2} +// | z pass {add1, add2} // | / | // add {add2} @return // | @@ -159,47 +154,45 @@ TEST_CASE(gemm_add_add) { migraphx::shape s1{migraphx::shape::float_type, {2, 3}}; migraphx::shape s2{migraphx::shape::float_type, {3, 3}}; - migraphx::program p1; + migraphx::module m1; { - auto* mm = p1.get_main_module(); - auto x = mm->add_parameter("x", s1); - auto y = mm->add_parameter("y", s1); - auto z = mm->add_parameter("z", s1); - auto a = mm->add_literal(migraphx::generate_literal(s2, 0)); - auto gemm = mm->add_instruction(migraphx::make_op("dot"), x, a); - mm->add_debug_symbols(gemm, {"gemm1"}); - auto add1 = mm->add_instruction(migraphx::make_op("add"), gemm, y); - mm->add_debug_symbols(add1, {"add1"}); - auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, z); - mm->add_debug_symbols(add2, {"add2"}); - mm->add_return({add2}); + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s1); + auto z = m1.add_parameter("z", s1); + auto a = m1.add_literal(migraphx::generate_literal(s2, 0)); + auto gemm = m1.add_instruction(migraphx::make_op("dot"), x, a); + m1.add_debug_symbols(gemm, {"gemm1"}); + auto add1 = m1.add_instruction(migraphx::make_op("add"), gemm, y); + m1.add_debug_symbols(add1, {"add1"}); + auto add2 = m1.add_instruction(migraphx::make_op("add"), add1, z); + m1.add_debug_symbols(add2, {"add2"}); + m1.add_return({add2}); + + m1.replace_instruction(add2, pass_op{}, gemm, y, z); } - migraphx::run_passes(p1, {migraphx::fuse_pointwise{}, migraphx::dead_code_elimination{}}); + migraphx::run_passes(m1, {migraphx::dead_code_elimination{}}); - migraphx::program p2; + migraphx::module m2; { - auto* mm = p2.get_main_module(); - auto x = mm->add_parameter("x", s1); - auto y = mm->add_parameter("y", s1); - auto z = mm->add_parameter("z", s1); - auto a = mm->add_literal(migraphx::generate_literal(s2, 0)); - auto gemm = mm->add_instruction(migraphx::make_op("dot"), x, a); - mm->add_debug_symbols(gemm, {"gemm1"}); - auto fadd = - add_pointwise(p2, "main:pointwise0", {gemm, y, z}, [=](auto* pm, const auto& inputs) { - auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); - return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]); - }); - mm->add_debug_symbols(fadd, {"add1", "add2"}); - mm->add_return({fadd}); + auto x = m2.add_parameter("x", s1); + auto y = m2.add_parameter("y", s1); + auto z = m2.add_parameter("z", s1); + auto a = m2.add_literal(migraphx::generate_literal(s2, 0)); + auto gemm = m2.add_instruction(migraphx::make_op("dot"), x, a); + m2.add_debug_symbols(gemm, {"gemm1"}); + auto fadd = m2.add_instruction(pass_op{}, gemm, y, z); + m2.add_debug_symbols(fadd, {"add1", "add2"}); + m2.add_return({fadd}); } - EXPECT(p1 == p2); + EXPECT(m1 == m2); } -// Horizontal fusion of two dot ops sharing the same input via -// simplify_algebra. The two dots are fused into concat + single dot + slices. -// Each new instruction inherits the symbols of the original dots it derives -// from (e.g. the concat and fused dot carry both "gemm1" and "gemm2"). +// Horizontal fusion of two dot ops sharing the same input, emulating +// what simplify_algebra does via insert_instruction + batch_replace_instruction. +// The first replacement sees fused_dot with a single output (the second +// dot hasn't been replaced yet), so the new splice chain traverses +// through fused_dot and concat. All four new instructions receive the +// merged {gemm1, gemm2} symbols. // // Before: After: // @@ -210,7 +203,7 @@ TEST_CASE(gemm_add_add) // \ / \ | // add {sum} dot {g1, g2} // / \ -// slice{g1, g2} slice{g1, g2} +// slice {g1, g2} slice {g1, g2} // \ / // add {sum} // @@ -230,86 +223,105 @@ TEST_CASE(horiz_fusion_dot) auto sum = m1.add_instruction(migraphx::make_op("add"), x, y); m1.add_debug_symbols(sum, {"sum"}); m1.add_return({sum}); + + auto concat = m1.insert_instruction(x, migraphx::make_op("concat", {{"axis", 2}}), a, b); + auto fused_dot = m1.insert_instruction(x, migraphx::make_op("dot"), input, concat); + m1.batch_replace_instruction({ + {x, migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), {fused_dot}, {}}, + {y, migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {4}}}), {fused_dot}, {}}, + }); } - migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); migraphx::module m2; { - auto input = m2.add_parameter("input", s); - auto a = m2.add_literal(migraphx::generate_literal(s, 0)); - auto b = m2.add_literal(migraphx::generate_literal(s, 1)); + auto input = m2.add_parameter("input", s); + auto a = m2.add_literal(migraphx::generate_literal(s, 0)); + auto b = m2.add_literal(migraphx::generate_literal(s, 1)); auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), a, b); m2.add_debug_symbols(concat, {"gemm1", "gemm2"}); - auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat); - m2.add_debug_symbols(dot, {"gemm1", "gemm2"}); - auto x = m2.add_instruction( - migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), dot); - m2.add_debug_symbols(x, {"gemm1", "gemm2"}); - auto y = m2.add_instruction( - migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {4}}}), dot); - m2.add_debug_symbols(y, {"gemm1", "gemm2"}); - auto sum = m2.add_instruction(migraphx::make_op("add"), x, y); + auto fused_dot = m2.add_instruction(migraphx::make_op("dot"), input, concat); + m2.add_debug_symbols(fused_dot, {"gemm1", "gemm2"}); + auto sx = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), fused_dot); + m2.add_debug_symbols(sx, {"gemm1", "gemm2"}); + auto sy = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {4}}}), fused_dot); + m2.add_debug_symbols(sy, {"gemm1", "gemm2"}); + auto sum = m2.add_instruction(migraphx::make_op("add"), sx, sy); m2.add_debug_symbols(sum, {"sum"}); m2.add_return({sum}); } - EXPECT(m1.sort() == m2.sort()); + EXPECT(m1 == m2); } -// Goes through the find_pointwise_reduce matcher. -// Making sure that the debug symbols are getting from the minimum splice of the old instructions. -TEST_CASE(pointwise_reduce_debug_symbols) +// Emulates the find_pointwise_reduce fusion using pass_op + replace_instruction. +// A pass_op standing in for the pointwise feeds only into a pass_op standing +// in for the reduce (splice chain), so replace_instruction merges both symbols. +// +// Before: After: +// +// x y x y +// \ / \ / +// add {add0} add {add0} +// | | +// relu {relu0} relu {relu0} +// | z | z +// | / | / +// pass {pointwise} pass {pointwise, reduce_sum} +// | | +// pass {reduce_sum} relu {relu1} +// | | +// relu {relu1} @return +// | +// @return +// +TEST_CASE(pointwise_reduce) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - migraphx::program p1; + migraphx::module m1; { - auto* mm = p1.get_main_module(); - auto x = mm->add_parameter("x", s); - auto y = mm->add_parameter("y", s); - auto z = mm->add_parameter("z", s); - auto curr = mm->add_instruction(migraphx::make_op("add"), x, y); - mm->add_debug_symbols(curr, {"add0"}); - curr = mm->add_instruction(migraphx::make_op("relu"), curr); - mm->add_debug_symbols(curr, {"relu0"}); - auto add = add_pointwise(p1, "main:pointwise0", {curr, z}, single_pointwise("add")); - mm->add_debug_symbols(add, {"pointwise"}); - auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), add); - mm->add_debug_symbols(rsum, {"reduce_sum"}); - curr = mm->add_instruction(migraphx::make_op("relu"), rsum); - mm->add_debug_symbols(curr, {"relu1"}); - mm->add_return({curr}); + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto z = m1.add_parameter("z", s); + auto curr = m1.add_instruction(migraphx::make_op("add"), x, y); + m1.add_debug_symbols(curr, {"add0"}); + curr = m1.add_instruction(migraphx::make_op("relu"), curr); + m1.add_debug_symbols(curr, {"relu0"}); + auto pw = m1.add_instruction(pass_op{}, curr, z); + m1.add_debug_symbols(pw, {"pointwise"}); + auto rs = m1.add_instruction(pass_op{}, pw); + m1.add_debug_symbols(rs, {"reduce_sum"}); + auto relu1 = m1.add_instruction(migraphx::make_op("relu"), rs); + m1.add_debug_symbols(relu1, {"relu1"}); + m1.add_return({relu1}); + + m1.replace_instruction(rs, pass_op{}, curr, z); } - migraphx::run_passes(p1, {migraphx::fuse_reduce{}, migraphx::dead_code_elimination{}}); + migraphx::run_passes(m1, {migraphx::dead_code_elimination{}}); - migraphx::program p2; + migraphx::module m2; { - auto* mm = p2.get_main_module(); - auto x = mm->add_parameter("x", s); - auto y = mm->add_parameter("y", s); - auto z = mm->add_parameter("z", s); - auto curr = mm->add_instruction(migraphx::make_op("add"), x, y); - mm->add_debug_symbols(curr, {"add0"}); - curr = mm->add_instruction(migraphx::make_op("relu"), curr); - mm->add_debug_symbols(curr, {"relu0"}); - auto rsum = add_reduce( - p2, - "main:pointwise0:main:reduce_sum0", - {curr, z}, - {1}, - [&](auto* rm, const auto& inputs, const auto& axes) { - auto add = - add_pointwise(p2, rm, "main:pointwise0", inputs, single_pointwise("add")); - return rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), add); - }); - mm->add_debug_symbols(rsum, {"pointwise", "reduce_sum"}); - curr = mm->add_instruction(migraphx::make_op("relu"), rsum); - mm->add_debug_symbols(curr, {"relu1"}); - mm->add_return({curr}); + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto z = m2.add_parameter("z", s); + auto curr = m2.add_instruction(migraphx::make_op("add"), x, y); + m2.add_debug_symbols(curr, {"add0"}); + curr = m2.add_instruction(migraphx::make_op("relu"), curr); + m2.add_debug_symbols(curr, {"relu0"}); + auto fused = m2.add_instruction(pass_op{}, curr, z); + m2.add_debug_symbols(fused, {"pointwise", "reduce_sum"}); + auto relu1 = m2.add_instruction(migraphx::make_op("relu"), fused); + m2.add_debug_symbols(relu1, {"relu1"}); + m2.add_return({relu1}); } - EXPECT(p1 == p2); + EXPECT(m1 == m2); } -// Tests symbol propagation through add reassociation in simplify_algebra -// (find_double_add_lit_broadcast). Checks add(add(x,1), add(y,2)) -> (add(add(x,y), add(1,2)). +// Tests symbol propagation through add reassociation, emulating +// find_double_add_lit_broadcast via insert_instruction + replace_instruction. +// add(add(x,1), add(y,2)) -> add(add(x,y), add(1,2)). +// sum1 and sum2 each feed only into sum3 (splice chain), so all three +// symbols merge onto every instruction in the new splice. // // Before: After: // @@ -319,7 +331,7 @@ TEST_CASE(pointwise_reduce_debug_symbols) // \ / \ / // add0{a0} add{a0,a1,a2} // -TEST_CASE(simplify_add_debug_symbols) +TEST_CASE(simplify_add) { migraphx::module m1; { @@ -334,8 +346,12 @@ TEST_CASE(simplify_add_debug_symbols) auto sum3 = m1.add_instruction(migraphx::make_op("add"), sum1, sum2); m1.add_debug_symbols(sum3, {"onnx:add0"}); m1.add_return({sum3}); + + auto sumab = m1.insert_instruction(sum3, migraphx::make_op("add"), one, two); + auto sumxy = m1.insert_instruction(sum3, migraphx::make_op("add"), x, y); + m1.replace_instruction(sum3, migraphx::make_op("add"), sumxy, sumab); } - migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); + migraphx::run_passes(m1, {migraphx::dead_code_elimination{}}); migraphx::module m2; { @@ -354,8 +370,10 @@ TEST_CASE(simplify_add_debug_symbols) EXPECT(m1.sort() == m2.sort()); } -// Tests the replace_instruction(ins, rep) overload via find_unit_ops which -// simplifies add(relu(x), broadcast(0)) to relu(x). +// Tests the replace_instruction(ins, rep) overload directly, emulating +// what find_unit_ops does: add(relu(x), broadcast(0)) -> relu(x). +// replace_instruction(add_r, relu_x) redirects add's outputs to relu, +// and relu inherits add's {onnx:add} symbol. // // Before: After: // @@ -366,7 +384,7 @@ TEST_CASE(simplify_add_debug_symbols) // \ / // add {add} // -TEST_CASE(replace_with_insref_debug_symbols) +TEST_CASE(replace_with_insref) { migraphx::module m1; { @@ -380,8 +398,10 @@ TEST_CASE(replace_with_insref_debug_symbols) auto add_r = m1.add_instruction(migraphx::make_op("add"), relu_x, bcast); m1.add_debug_symbols(add_r, {"onnx:add"}); m1.add_return({add_r}); + + m1.replace_instruction(add_r, relu_x); } - migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); + migraphx::run_passes(m1, {migraphx::dead_code_elimination{}}); migraphx::module m2; { @@ -406,7 +426,7 @@ TEST_CASE(replace_with_insref_debug_symbols) // | (relu becomes dead code, removed by DCE) // pass // -TEST_CASE(gather_replace_chain_debug_symbols) +TEST_CASE(gather_replace_chain) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::module m1; @@ -435,8 +455,10 @@ TEST_CASE(gather_replace_chain_debug_symbols) EXPECT(m1.sort() == m2.sort()); } -// Tests the distributive law transform in simplify_algebra (find_mul_add): -// mul(add(3, x), 2) -> add(mul(2, x), mul(2, 3)). +// Tests the distributive law transform via insert_instruction + replace_instruction, +// emulating find_mul_add: mul(add(3, x), 2) -> add(mul(2, x), mul(2, 3)). +// add feeds only into mul (splice chain), so both symbols merge onto +// every instruction in the new splice. // // Before: After: // @@ -447,7 +469,7 @@ TEST_CASE(gather_replace_chain_debug_symbols) // | / add{add,mul} // mul {mul} // -TEST_CASE(simplify_mul_add_debug_symbols) +TEST_CASE(simplify_mul_add) { migraphx::module m1; { @@ -459,8 +481,12 @@ TEST_CASE(simplify_mul_add_debug_symbols) auto mul = m1.add_instruction(migraphx::make_op("mul"), sum, two); m1.add_debug_symbols(mul, {"onnx:mul"}); m1.add_return({mul}); + + auto ax = m1.insert_instruction(mul, migraphx::make_op("mul"), two, x); + auto ab = m1.insert_instruction(mul, migraphx::make_op("mul"), two, one); + m1.replace_instruction(mul, migraphx::make_op("add"), ax, ab); } - migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); + migraphx::run_passes(m1, {migraphx::dead_code_elimination{}}); migraphx::module m2; { @@ -478,8 +504,10 @@ TEST_CASE(simplify_mul_add_debug_symbols) EXPECT(m1.sort() == m2.sort()); } -// Tests symbol propagation through find_div_const in simplify_algebra: -// div(x, c) -> mul(x, recip(c)). +// Tests symbol propagation through insert_instruction + replace_instruction, +// emulating find_div_const: div(x, c) -> mul(x, recip(c)). +// div is the only instruction in the splice chain, so its {onnx:div} symbol +// propagates to both recip and mul. // // Before: After: // @@ -490,7 +518,7 @@ TEST_CASE(simplify_mul_add_debug_symbols) // \ | // mul {div} // -TEST_CASE(simplify_div_const_debug_symbols) +TEST_CASE(simplify_div_const) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::module m1; @@ -502,8 +530,11 @@ TEST_CASE(simplify_div_const_debug_symbols) auto div_r = m1.add_instruction(migraphx::make_op("div"), x, c); m1.add_debug_symbols(div_r, {"onnx:div"}); m1.add_return({div_r}); + + auto recip = m1.insert_instruction(div_r, migraphx::make_op("recip"), c); + m1.replace_instruction(div_r, migraphx::make_op("mul"), x, recip); } - migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); + migraphx::run_passes(m1, {migraphx::dead_code_elimination{}}); migraphx::module m2; { @@ -540,10 +571,6 @@ TEST_CASE(debug_symbols_in_print) EXPECT(str.find("# sym_a, sym_b") != std::string::npos); } -// ----------------------------------------------------------------------- -// Direct replace_instruction / batch_replace_instruction tests -// ----------------------------------------------------------------------- - // replace_instruction(ins, op, args) -- simple in-place, no splice chain. // The replaced instruction retains its own debug symbols. // @@ -759,6 +786,64 @@ TEST_CASE(batch_replace_multi_merges_symbols) EXPECT(results[1]->get_debug_symbols() == expected); } +// Emulates rewrite_nearest_resize: resize(x) -> gather(reshape(x), indices). +// resize is the only symbolized instruction in the splice chain, so +// {onnx:resize} propagates to both reshape and gather. The literal +// (gather indices) does not receive symbols. +// +// Before: After: +// +// x x +// | | +// resize {resize} reshape {resize} indices +// \ / +// gather {resize} +// +TEST_CASE(rewrite_resize_debug_symbols) +{ + migraphx::shape in_s{migraphx::shape::float_type, {1, 1, 2, 2}}; + // clang-format off + std::vector indices = {0, 0, 1, 1, + 0, 0, 1, 1, + 2, 2, 3, 3, + 2, 2, 3, 3}; + // clang-format on + + migraphx::module m1; + { + auto x = m1.add_parameter("x", in_s); + auto resize = m1.add_instruction( + migraphx::make_op("resize", + {{"scales", {1.0f, 1.0f, 2.0f, 2.0f}}, + {"nearest_mode", "floor"}, + {"coordinate_transformation_mode", "asymmetric"}}), + x); + m1.add_debug_symbols(resize, {"onnx:resize"}); + m1.add_return({resize}); + + auto rsp = + m1.insert_instruction(resize, migraphx::make_op("reshape", {{"dims", {4}}}), x); + auto ins_ind = m1.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {1, 1, 4, 4}}, indices}); + m1.replace_instruction(resize, migraphx::make_op("gather", {{"axis", 0}}), rsp, ins_ind); + } + migraphx::run_passes(m1, {migraphx::dead_code_elimination{}}); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", in_s); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), x); + m2.add_debug_symbols(rsp, {"onnx:resize"}); + auto ins_ind = m2.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {1, 1, 4, 4}}, indices}); + auto gather = + m2.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, ins_ind); + m2.add_debug_symbols(gather, {"onnx:resize"}); + m2.add_return({gather}); + } + EXPECT(m1.sort() == m2.sort()); +} + // ----------------------------------------------------------------------- // module::remove_debug_symbols tests // ----------------------------------------------------------------------- diff --git a/test/fuse_pointwise.cpp b/test/fuse_pointwise.cpp index 573c8ca91ab..b34e1493792 100644 --- a/test/fuse_pointwise.cpp +++ b/test/fuse_pointwise.cpp @@ -1316,4 +1316,103 @@ TEST_CASE(if_cross_module_multi_out_find_input) EXPECT(p1.sort() == p2.sort()); } +// Two adds fused into a single pointwise op via fuse_pointwise. +// Both symbols should appear on the fused pointwise instruction. +// +// Before: After: +// +// x y x y z +// \ / \ | / +// add {add1} pointwise {add1, add2} +// | z | +// | / @return +// add {add2} +// | +// @return +// +TEST_CASE(debug_symbols_pw_double_add) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + migraphx::instruction_ref add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + mm->add_debug_symbols(add1, {"add1"}); + migraphx::instruction_ref add2 = mm->add_instruction(migraphx::make_op("add"), add1, z); + mm->add_debug_symbols(add2, {"add2"}); + mm->add_return({add2}); + } + migraphx::run_passes(p1, {migraphx::fuse_pointwise{}, migraphx::dead_code_elimination{}}); + + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto fadd = + add_pointwise(p2, "main:pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) { + auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); + return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]); + }); + mm->add_debug_symbols(fadd, {"add1", "add2"}); + mm->add_return({fadd}); + } + EXPECT(p1 == p2); +} + +// Before: After: +// +// x y x y +// \ / \ / +// add1 {add1} pointwise {add1, add2, add3, add4} +// / \ | +// x y @return +// | | +// add2 add3 {add2} {add3} +// \ / +// add4 {add4} +// | +// @return +// +TEST_CASE(debug_symbols_pw_used_twice_fused) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + mm->add_debug_symbols(add1, {"onnx:add1"}); + auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, x); + mm->add_debug_symbols(add2, {"onnx:add2"}); + auto add3 = mm->add_instruction(migraphx::make_op("add"), add1, y); + mm->add_debug_symbols(add3, {"onnx:add3"}); + auto add4 = mm->add_instruction(migraphx::make_op("add"), add2, add3); + mm->add_debug_symbols(add4, {"onnx:add4"}); + mm->add_return({add4}); + } + migraphx::run_passes(p1, {migraphx::fuse_pointwise{}, migraphx::dead_code_elimination{}}); + + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto fadd = add_pointwise(p2, "main:pointwise0", {x, y}, [=](auto* pm, const auto& inputs) { + auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); + auto add2 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[0]); + auto add3 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[1]); + return pm->add_instruction(migraphx::make_op("add"), add2, add3); + }); + mm->add_debug_symbols(fadd, {"onnx:add1", "onnx:add2", "onnx:add3", "onnx:add4"}); + mm->add_return({fadd}); + } + EXPECT(p1.sort() == p2.sort()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/simplify_algebra_test.cpp b/test/simplify_algebra_test.cpp index 5e444558636..584f70a0427 100644 --- a/test/simplify_algebra_test.cpp +++ b/test/simplify_algebra_test.cpp @@ -5074,4 +5074,234 @@ TEST_CASE(pow3) EXPECT(m1.sort() == m2.sort()); } +// Horizontal fusion of two dot ops sharing the same input via +// simplify_algebra. The two dots are fused into concat + single dot + slices. +// Each new instruction inherits the symbols of the original dots it derives +// from (e.g. the concat and fused dot carry both "gemm1" and "gemm2"). +// +// Before: After: +// +// a input b a b +// \ / \ / \ / +// dot {g1} dot {g2} concat {g1, g2} +// \ / input | +// \ / \ | +// add {sum} dot {g1, g2} +// / \ +// slice{g1, g2} slice{g1, g2} +// \ / +// add {sum} +// +TEST_CASE(debug_symbols_horiz_fusion_dot) +{ + auto type = migraphx::shape::int32_type; + auto s = migraphx::shape{type, {3, 2, 2}}; + migraphx::module m1; + { + auto input = m1.add_parameter("input", s); + auto a = m1.add_literal(migraphx::generate_literal(s, 0)); + auto b = m1.add_literal(migraphx::generate_literal(s, 1)); + auto x = m1.add_instruction(migraphx::make_op("dot"), input, a); + m1.add_debug_symbols(x, {"gemm1"}); + auto y = m1.add_instruction(migraphx::make_op("dot"), input, b); + m1.add_debug_symbols(y, {"gemm2"}); + auto sum = m1.add_instruction(migraphx::make_op("add"), x, y); + m1.add_debug_symbols(sum, {"sum"}); + m1.add_return({sum}); + } + migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); + + migraphx::module m2; + { + auto input = m2.add_parameter("input", s); + auto a = m2.add_literal(migraphx::generate_literal(s, 0)); + auto b = m2.add_literal(migraphx::generate_literal(s, 1)); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), a, b); + m2.add_debug_symbols(concat, {"gemm1", "gemm2"}); + auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat); + m2.add_debug_symbols(dot, {"gemm1", "gemm2"}); + auto x = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), dot); + m2.add_debug_symbols(x, {"gemm1", "gemm2"}); + auto y = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {4}}}), dot); + m2.add_debug_symbols(y, {"gemm1", "gemm2"}); + auto sum = m2.add_instruction(migraphx::make_op("add"), x, y); + m2.add_debug_symbols(sum, {"sum"}); + m2.add_return({sum}); + } + EXPECT(m1.sort() == m2.sort()); +} + +// Tests symbol propagation through add reassociation in simplify_algebra +// (find_double_add_lit_broadcast). Checks add(add(x,1), add(y,2)) -> (add(add(x,y), add(1,2)). +// +// Before: After: +// +// x 1 y 2 1 2 x y +// \ / \ / \ / \ / +// add1{a1} add2{a2} add{a0,a1,a2} add{a0,a1,a2} +// \ / \ / +// add0{a0} add{a0,a1,a2} +// +TEST_CASE(debug_symbols_simplify_add) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {1}}); + auto one = m1.add_literal(1); + auto two = m1.add_literal(2); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, one); + m1.add_debug_symbols(sum1, {"onnx:add1"}); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), y, two); + m1.add_debug_symbols(sum2, {"onnx:add2"}); + auto sum3 = m1.add_instruction(migraphx::make_op("add"), sum1, sum2); + m1.add_debug_symbols(sum3, {"onnx:add0"}); + m1.add_return({sum3}); + } + migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1}}); + auto one = m2.add_literal(1); + auto two = m2.add_literal(2); + auto sum1 = m2.add_instruction(migraphx::make_op("add"), one, two); + m2.add_debug_symbols(sum1, {"onnx:add0", "onnx:add1", "onnx:add2"}); + auto sum2 = m2.add_instruction(migraphx::make_op("add"), x, y); + m2.add_debug_symbols(sum2, {"onnx:add0", "onnx:add1", "onnx:add2"}); + auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum2, sum1); + m2.add_debug_symbols(sum3, {"onnx:add0", "onnx:add1", "onnx:add2"}); + m2.add_return({sum3}); + } + EXPECT(m1.sort() == m2.sort()); +} + +// Tests the replace_instruction(ins, rep) overload via find_unit_ops which +// simplifies add(relu(x), broadcast(0)) to relu(x). +// +// Before: After: +// +// x 0 x +// | | | +// relu bcast relu {add, relu} +// {relu} (0.0) +// \ / +// add {add} +// +TEST_CASE(debug_symbols_replace_with_insref) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 3}}); + auto zero = m1.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {0.0f}}); + auto bcast = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3}}}), zero); + auto relu_x = m1.add_instruction(migraphx::make_op("relu"), x); + m1.add_debug_symbols(relu_x, {"onnx:relu"}); + auto add_r = m1.add_instruction(migraphx::make_op("add"), relu_x, bcast); + m1.add_debug_symbols(add_r, {"onnx:add"}); + m1.add_return({add_r}); + } + migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 3}}); + auto relu_x = m2.add_instruction(migraphx::make_op("relu"), x); + m2.add_debug_symbols(relu_x, {"onnx:add", "onnx:relu"}); + m2.add_return({relu_x}); + } + EXPECT(m1.sort() == m2.sort()); +} + +// Tests the distributive law transform in simplify_algebra (find_mul_add): +// mul(add(3, x), 2) -> add(mul(2, x), mul(2, 3)). +// +// Before: After: +// +// 3 x 2 x 2 3 +// \ / \ / \ / +// add {add} mul{add,mul} mul{add,mul} +// | 2 \ / +// | / add{add,mul} +// mul {mul} +// +TEST_CASE(debug_symbols_simplify_mul_add) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto one = m1.add_literal(3); + auto two = m1.add_literal(2); + auto sum = m1.add_instruction(migraphx::make_op("add"), one, x); + m1.add_debug_symbols(sum, {"onnx:add"}); + auto mul = m1.add_instruction(migraphx::make_op("mul"), sum, two); + m1.add_debug_symbols(mul, {"onnx:mul"}); + m1.add_return({mul}); + } + migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto one = m2.add_literal(3); + auto two = m2.add_literal(2); + auto mul1 = m2.add_instruction(migraphx::make_op("mul"), two, x); + m2.add_debug_symbols(mul1, {"onnx:add", "onnx:mul"}); + auto mul2 = m2.add_instruction(migraphx::make_op("mul"), two, one); + m2.add_debug_symbols(mul2, {"onnx:add", "onnx:mul"}); + auto sum = m2.add_instruction(migraphx::make_op("add"), mul1, mul2); + m2.add_debug_symbols(sum, {"onnx:add", "onnx:mul"}); + m2.add_return({sum}); + } + EXPECT(m1.sort() == m2.sort()); +} + +// Tests symbol propagation through find_div_const in simplify_algebra: +// div(x, c) -> mul(x, recip(c)). +// +// Before: After: +// +// x c c +// \ / | +// div {div} recip {div} +// x | +// \ | +// mul {div} +// +TEST_CASE(debug_symbols_simplify_div_const) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto c = + m1.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 3}}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); + auto div_r = m1.add_instruction(migraphx::make_op("div"), x, c); + m1.add_debug_symbols(div_r, {"onnx:div"}); + m1.add_return({div_r}); + } + migraphx::run_passes(m1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto c = + m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 3}}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); + auto recip = m2.add_instruction(migraphx::make_op("recip"), c); + m2.add_debug_symbols(recip, {"onnx:div"}); + auto mul_r = m2.add_instruction(migraphx::make_op("mul"), x, recip); + m2.add_debug_symbols(mul_r, {"onnx:div"}); + m2.add_return({mul_r}); + } + EXPECT(m1.sort() == m2.sort()); +} + + int main(int argc, const char* argv[]) { test::run(argc, argv); } From 9b2f7591898ff79b6e11d4de8e69fc5877a5cff9 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 25 Mar 2026 19:18:30 -0500 Subject: [PATCH 084/107] Update API to set onnx_options use_debug_symbols --- src/api/api.cpp | 16 ++++++++++++++++ src/api/include/migraphx/migraphx.h | 3 +++ src/api/include/migraphx/migraphx.hpp | 6 ++++++ src/api/migraphx.py | 5 +++++ src/py/migraphx_py.cpp | 14 ++++++++++---- tools/api/api.cpp | 5 +++++ 6 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/api/api.cpp b/src/api/api.cpp index 46613777117..f6fac27b50b 100644 --- a/src/api/api.cpp +++ b/src/api/api.cpp @@ -178,6 +178,11 @@ static void set_limit_loop_iterations(onnx_options& options, int64_t value) options.limit_max_iterations = value; } +static void set_use_debug_symbols(onnx_options& options, bool value) +{ + options.use_debug_symbols = value; +} + static void set_nhwc(tf_options& options, bool is_nhwc) { options.is_nhwc = is_nhwc; } static void set_default_dim_value(tf_options& options, size_t value) { options.batch_size = value; } @@ -1991,6 +1996,17 @@ migraphx_onnx_options_set_external_data_path(migraphx_onnx_options_t onnx_option return api_error_result; } +extern "C" migraphx_status +migraphx_onnx_options_set_use_debug_symbols(migraphx_onnx_options_t onnx_options, bool value) +{ + auto api_error_result = migraphx::try_([&] { + if(onnx_options == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer"); + migraphx::set_use_debug_symbols((onnx_options->object), (value)); + }); + return api_error_result; +} + extern "C" migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options) { auto api_error_result = migraphx::try_([&] { destroy((file_options)); }); diff --git a/src/api/include/migraphx/migraphx.h b/src/api/include/migraphx/migraphx.h index c4a992abcf5..891123f2d36 100644 --- a/src/api/include/migraphx/migraphx.h +++ b/src/api/include/migraphx/migraphx.h @@ -535,6 +535,9 @@ MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_limit_loop_iteration MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_external_data_path( migraphx_onnx_options_t onnx_options, const char* external_data_path); +MIGRAPHX_C_EXPORT migraphx_status +migraphx_onnx_options_set_use_debug_symbols(migraphx_onnx_options_t onnx_options, bool value); + MIGRAPHX_C_EXPORT migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options); diff --git a/src/api/include/migraphx/migraphx.hpp b/src/api/include/migraphx/migraphx.hpp index d059f7a95c8..518e9b6bc8c 100644 --- a/src/api/include/migraphx/migraphx.hpp +++ b/src/api/include/migraphx/migraphx.hpp @@ -1366,6 +1366,12 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options) this->get_handle_ptr(), external_data_path.c_str()); } + + /// Enable debug symbols from ONNX node names + void set_use_debug_symbols(bool value) + { + call(&migraphx_onnx_options_set_use_debug_symbols, this->get_handle_ptr(), value); + } }; /// Parse an onnx file into a migraphx program diff --git a/src/api/migraphx.py b/src/api/migraphx.py index 3c20e926268..e00bd782bda 100644 --- a/src/api/migraphx.py +++ b/src/api/migraphx.py @@ -370,6 +370,11 @@ def onnx_options(h): api.params(external_data_path='const char*'), invoke='migraphx::set_external_data_path($@)', ) + h.method( + 'set_use_debug_symbols', + api.params(value='bool'), + invoke='migraphx::set_use_debug_symbols($@)', + ) @auto_handle() diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp index e591474d499..43f20042233 100644 --- a/src/py/migraphx_py.cpp +++ b/src/py/migraphx_py.cpp @@ -634,7 +634,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) bool skip_unknown_operators, bool print_program_on_error, int64_t max_loop_iterations, - int64_t limit_max_iterations) { + int64_t limit_max_iterations, + bool use_debug_symbols) { migraphx::onnx_options options; options.default_dim_value = default_dim_value; options.default_dyn_dim_value = default_dyn_dim_value; @@ -645,6 +646,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) options.print_program_on_error = print_program_on_error; options.max_loop_iterations = max_loop_iterations; options.limit_max_iterations = limit_max_iterations; + options.use_debug_symbols = use_debug_symbols; return migraphx::parse_onnx(filename, options); }, "Parse onnx file", @@ -659,7 +661,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) py::arg("skip_unknown_operators") = false, py::arg("print_program_on_error") = false, py::arg("max_loop_iterations") = 10, - py::arg("limit_max_iterations") = std::numeric_limits::max()); + py::arg("limit_max_iterations") = std::numeric_limits::max(), + py::arg("use_debug_symbols") = false); m.def( "parse_onnx_buffer", @@ -671,7 +674,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) map_dyn_input_dims, bool skip_unknown_operators, bool print_program_on_error, - const std::string& external_data_path) { + const std::string& external_data_path, + bool use_debug_symbols) { migraphx::onnx_options options; options.default_dim_value = default_dim_value; options.default_dyn_dim_value = default_dyn_dim_value; @@ -680,6 +684,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) options.skip_unknown_operators = skip_unknown_operators; options.print_program_on_error = print_program_on_error; options.external_data_path = external_data_path; + options.use_debug_symbols = use_debug_symbols; return migraphx::parse_onnx_buffer(onnx_buffer, options); }, "Parse onnx file", @@ -691,7 +696,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) std::unordered_map>(), py::arg("skip_unknown_operators") = false, py::arg("print_program_on_error") = false, - py::arg("external_data_path") = ""); + py::arg("external_data_path") = "", + py::arg("use_debug_symbols") = false); m.def( "load", diff --git a/tools/api/api.cpp b/tools/api/api.cpp index cfefcf39f29..05181653be4 100644 --- a/tools/api/api.cpp +++ b/tools/api/api.cpp @@ -178,6 +178,11 @@ static void set_limit_loop_iterations(onnx_options& options, int64_t value) options.limit_max_iterations = value; } +static void set_use_debug_symbols(onnx_options& options, bool value) +{ + options.use_debug_symbols = value; +} + static void set_nhwc(tf_options& options, bool is_nhwc) { options.is_nhwc = is_nhwc; } static void set_default_dim_value(tf_options& options, size_t value) { options.batch_size = value; } From 58dfa497f6f614ae2c3e3d3ecb2cf2ea0013fab6 Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 27 Mar 2026 16:53:39 -0500 Subject: [PATCH 085/107] Add debug symbols --- src/onnx/netron_output.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/onnx/netron_output.cpp b/src/onnx/netron_output.cpp index 2a5f866467a..8f95daa3dcb 100644 --- a/src/onnx/netron_output.cpp +++ b/src/onnx/netron_output.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -183,6 +184,14 @@ void add_node(onnx::GraphProto* graph, node->add_output(ins_uids.at(ins) + "->" + ins_uids.at(output_ins)); } } + + if(not ins->get_debug_symbols().empty()) + { + auto* attr = node->add_attribute(); + attr->set_name("debug symbols"); + attr->set_type(onnx::AttributeProto::STRING); + attr->set_s(join_strings(ins->get_debug_symbols(), ", ")); + } } void build_graph(onnx::GraphProto* graph, const module* mod) From 17d45951e97d10cbde2cd58cf77d733117132205 Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 27 Mar 2026 17:21:51 -0500 Subject: [PATCH 086/107] Review comments --- src/api/include/migraphx/migraphx.hpp | 2 +- src/module.cpp | 9 +++--- test/module_test.cpp | 42 ++++++++++++++++++--------- 3 files changed, 34 insertions(+), 19 deletions(-) diff --git a/src/api/include/migraphx/migraphx.hpp b/src/api/include/migraphx/migraphx.hpp index 518e9b6bc8c..56adf089a74 100644 --- a/src/api/include/migraphx/migraphx.hpp +++ b/src/api/include/migraphx/migraphx.hpp @@ -1368,7 +1368,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options) } /// Enable debug symbols from ONNX node names - void set_use_debug_symbols(bool value) + void set_use_debug_symbols(bool value = true) { call(&migraphx_onnx_options_set_use_debug_symbols, this->get_handle_ptr(), value); } diff --git a/src/module.cpp b/src/module.cpp index 6c41e8fef8f..1322d7f3c91 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -62,7 +62,7 @@ struct module_impl uint32_t nparams = 0; bool bypass = false; // used for skipping compiler passes bit_signal<64> changed{}; - std::size_t num_ins_with_debug_symbols = 0; // number of ins with debug symbols + std::size_t num_ins_with_debug_symbols = 0; bool contains(instruction_ref ins) const { @@ -178,6 +178,7 @@ void module::add_debug_symbols(instruction_ref ins, const std::set& void module::remove_debug_symbols(instruction_ref ins) const { + assert(ins->get_debug_symbols().empty() or impl->num_ins_with_debug_symbols > 0); if(not ins->get_debug_symbols().empty() and impl->num_ins_with_debug_symbols > 0) { impl->num_ins_with_debug_symbols--; @@ -326,7 +327,7 @@ instruction_ref module::add_instruction(const operation& op, std::vector args) + std::vector args) MIGRAPHX_TIDY_CONST { assert(has_instruction(ins) or is_end(ins, this->end())); assert(not starts_with(op.name(), "@")); @@ -347,7 +348,7 @@ instruction_ref module::add_instruction(const operation& op, instruction_ref module::insert_instruction(instruction_ref ins, const operation& op, std::vector args, - std::vector module_args) + std::vector module_args) MIGRAPHX_TIDY_CONST { assert(has_instruction(ins) or is_end(ins, this->end())); assert(not starts_with(op.name(), "@")); @@ -517,7 +518,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, return ins; } -instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref rep) +instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref rep) MIGRAPHX_TIDY_CONST { impl->changed.notify(); assert(has_instruction(ins)); diff --git a/test/module_test.cpp b/test/module_test.cpp index 12229da8e99..845ac195dfd 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -2031,6 +2031,34 @@ TEST_CASE(move_output_instructions_after_cross_module_mixed) EXPECT(p1 == p2); } +TEST_CASE(debug_symbols_copy_module_verify) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m1; + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto add = m1.add_instruction(migraphx::make_op("add"), x, y); + auto neg = m1.add_instruction(migraphx::make_op("neg"), add); + auto ret = m1.add_return({neg}); + + m1.add_debug_symbols(add, {"add_node"}); + m1.add_debug_symbols(neg, {"neg_node", "extra_sym"}); + m1.add_debug_symbols(ret, {"@output:0:result_tensor"}); + + auto m2 = m1; + + EXPECT(m2.has_debug_symbols()); + + auto it1 = m1.begin(); + auto it2 = m2.begin(); + for(; it1 != m1.end() and it2 != m2.end(); ++it1, ++it2) + { + EXPECT(it1->get_debug_symbols() == it2->get_debug_symbols()); + } + EXPECT(it1 == m1.end()); + EXPECT(it2 == m2.end()); +} + TEST_CASE(debug_symbols_add_and_remove) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; @@ -2114,20 +2142,6 @@ TEST_CASE(debug_symbols_remove_instructions_range) EXPECT(not m.has_debug_symbols()); } -TEST_CASE(debug_symbols_copy_module) -{ - migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - migraphx::module m; - auto x = m.add_parameter("x", s); - auto y = m.add_parameter("y", s); - auto add = m.add_instruction(migraphx::make_op("add"), x, y); - m.add_debug_symbols(add, {"add_sym"}); - m.add_return({add}); - - auto m2 = m; - EXPECT(m2.has_debug_symbols()); -} - TEST_CASE(erase_single_with_debug_symbols) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; From e5b2009bc9359d112b60791c6476cc9bf666389f Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 27 Mar 2026 17:28:05 -0500 Subject: [PATCH 087/107] Formatting --- src/instruction.cpp | 28 +++++++++++++++++++++------- src/program.cpp | 8 ++++++-- test/debug_symbols_test.cpp | 24 ++++++++++++++---------- test/simplify_algebra_test.cpp | 1 - 4 files changed, 41 insertions(+), 20 deletions(-) diff --git a/src/instruction.cpp b/src/instruction.cpp index 26e9fd172ea..9a3c1e29a0c 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -70,7 +70,9 @@ struct replace_shape_order std::size_t location(instruction_ref x) const { return std::distance(start, x); } bool operator()(instruction_ref x, instruction_ref y) const - { return location(x) > location(y); } + { + return location(x) > location(y); + } }; void instruction::replace(const shape& r) @@ -123,7 +125,9 @@ void instruction::clear_arguments() } bool operator==(const instruction& i, instruction_ref ref) -{ return std::addressof(i) == std::addressof(*ref); } +{ + return std::addressof(i) == std::addressof(*ref); +} bool instruction::valid(instruction_ref start, bool check_order) const { @@ -189,7 +193,9 @@ const std::vector& instruction::outputs() const { return output const std::set& instruction::get_debug_symbols() const { return debug_symbols; } void instruction::add_debug_symbols(const std::set& symbols) -{ debug_symbols.insert(symbols.begin(), symbols.end()); } +{ + debug_symbols.insert(symbols.begin(), symbols.end()); +} void instruction::remove_debug_symbols() { debug_symbols.clear(); } @@ -518,7 +524,9 @@ void instruction::set_normalized(bool value) { normalized = value; } bool instruction::is_normalized() const { return normalized; } bool instruction::need_normalization() const -{ return this->get_operator().need_normalization() and not normalized; } +{ + return this->get_operator().need_normalization() and not normalized; +} operation instruction::normalized_operator() const { @@ -544,7 +552,9 @@ std::vector to_shapes(const std::vector& args) } shape compute_shape(const operation& op, const std::vector& args) -{ return op.compute_shape(to_shapes(args)); } +{ + return op.compute_shape(to_shapes(args)); +} shape compute_shape(const operation& op, const std::vector& args, @@ -575,10 +585,14 @@ std::vector try_compute_shape(const operation& op, const std::vector::iterator& ins) noexcept -{ return iterator_address(ins); } +{ + return iterator_address(ins); +} const migraphx::instruction* as_address(const std::list::const_iterator& ins) noexcept -{ return iterator_address(ins); } +{ + return iterator_address(ins); +} template static auto track_visits(instruction_ref start, instruction_ref end, F f) diff --git a/src/program.cpp b/src/program.cpp index 5093c954f88..42efc1adaeb 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -81,7 +81,9 @@ struct program_impl program::program() : impl(std::make_unique()) { this->create_module("main"); } program::program(module m) : impl(std::make_unique()) -{ this->create_module("main", std::move(m)); } +{ + this->create_module("main", std::move(m)); +} program::program(program&&) noexcept = default; program::~program() noexcept = default; @@ -1375,7 +1377,9 @@ program& program::sort() } bool operator==(const program& x, const program& y) -{ return migraphx::to_string(x) == migraphx::to_string(y); } +{ + return migraphx::to_string(x) == migraphx::to_string(y); +} std::ostream& operator<<(std::ostream& os, const program& p) { diff --git a/test/debug_symbols_test.cpp b/test/debug_symbols_test.cpp index d2b56931110..82788ee2179 100644 --- a/test/debug_symbols_test.cpp +++ b/test/debug_symbols_test.cpp @@ -227,21 +227,27 @@ TEST_CASE(horiz_fusion_dot) auto concat = m1.insert_instruction(x, migraphx::make_op("concat", {{"axis", 2}}), a, b); auto fused_dot = m1.insert_instruction(x, migraphx::make_op("dot"), input, concat); m1.batch_replace_instruction({ - {x, migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), {fused_dot}, {}}, - {y, migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {4}}}), {fused_dot}, {}}, + {x, + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), + {fused_dot}, + {}}, + {y, + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {4}}}), + {fused_dot}, + {}}, }); } migraphx::module m2; { - auto input = m2.add_parameter("input", s); - auto a = m2.add_literal(migraphx::generate_literal(s, 0)); - auto b = m2.add_literal(migraphx::generate_literal(s, 1)); + auto input = m2.add_parameter("input", s); + auto a = m2.add_literal(migraphx::generate_literal(s, 0)); + auto b = m2.add_literal(migraphx::generate_literal(s, 1)); auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), a, b); m2.add_debug_symbols(concat, {"gemm1", "gemm2"}); auto fused_dot = m2.add_instruction(migraphx::make_op("dot"), input, concat); m2.add_debug_symbols(fused_dot, {"gemm1", "gemm2"}); - auto sx = m2.add_instruction( + auto sx = m2.add_instruction( migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), fused_dot); m2.add_debug_symbols(sx, {"gemm1", "gemm2"}); auto sy = m2.add_instruction( @@ -821,8 +827,7 @@ TEST_CASE(rewrite_resize_debug_symbols) m1.add_debug_symbols(resize, {"onnx:resize"}); m1.add_return({resize}); - auto rsp = - m1.insert_instruction(resize, migraphx::make_op("reshape", {{"dims", {4}}}), x); + auto rsp = m1.insert_instruction(resize, migraphx::make_op("reshape", {{"dims", {4}}}), x); auto ins_ind = m1.add_literal( migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {1, 1, 4, 4}}, indices}); m1.replace_instruction(resize, migraphx::make_op("gather", {{"axis", 0}}), rsp, ins_ind); @@ -836,8 +841,7 @@ TEST_CASE(rewrite_resize_debug_symbols) m2.add_debug_symbols(rsp, {"onnx:resize"}); auto ins_ind = m2.add_literal( migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {1, 1, 4, 4}}, indices}); - auto gather = - m2.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, ins_ind); + auto gather = m2.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, ins_ind); m2.add_debug_symbols(gather, {"onnx:resize"}); m2.add_return({gather}); } diff --git a/test/simplify_algebra_test.cpp b/test/simplify_algebra_test.cpp index 584f70a0427..31f884d9b83 100644 --- a/test/simplify_algebra_test.cpp +++ b/test/simplify_algebra_test.cpp @@ -5303,5 +5303,4 @@ TEST_CASE(debug_symbols_simplify_div_const) EXPECT(m1.sort() == m2.sort()); } - int main(int argc, const char* argv[]) { test::run(argc, argv); } From e710cf6cd28b10139bbb4636289a52e45a48c506 Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 27 Mar 2026 17:29:09 -0500 Subject: [PATCH 088/107] License --- src/api/include/migraphx/migraphx.hpp | 2 +- src/api/migraphx.py | 2 +- src/py/migraphx_py.cpp | 2 +- tools/api/migraphx.h | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/api/include/migraphx/migraphx.hpp b/src/api/include/migraphx/migraphx.hpp index 56adf089a74..872fdcde1b9 100644 --- a/src/api/include/migraphx/migraphx.hpp +++ b/src/api/include/migraphx/migraphx.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/api/migraphx.py b/src/api/migraphx.py index e00bd782bda..18bfb365719 100644 --- a/src/api/migraphx.py +++ b/src/api/migraphx.py @@ -1,7 +1,7 @@ ##################################################################################### # The MIT License (MIT) # -# Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp index 43f20042233..7bd5d03b3f0 100644 --- a/src/py/migraphx_py.cpp +++ b/src/py/migraphx_py.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/tools/api/migraphx.h b/tools/api/migraphx.h index 90834fc7dc9..263dacd0160 100644 --- a/tools/api/migraphx.h +++ b/tools/api/migraphx.h @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From ce9d36cd781f5413d6b9791e3c0db9804e8e2cce Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 30 Mar 2026 14:46:30 -0500 Subject: [PATCH 089/107] Add tidy const to header --- src/include/migraphx/module.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 33b4feab9cb..190e964f323 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -107,12 +107,12 @@ struct MIGRAPHX_EXPORT module return insert_instruction(ins, op, {args...}); } instruction_ref - insert_instruction(instruction_ref ins, const operation& op, std::vector args); + insert_instruction(instruction_ref ins, const operation& op, std::vector args) MIGRAPHX_TIDY_CONST; instruction_ref insert_instruction(instruction_ref ins, const operation& op, std::vector args, - std::vector module_args); + std::vector module_args) MIGRAPHX_TIDY_CONST; template {}...)> instruction_ref replace_instruction(instruction_ref ins, operation op, Ts... args) @@ -128,7 +128,7 @@ struct MIGRAPHX_EXPORT module std::vector args, std::vector module_args) MIGRAPHX_TIDY_CONST; - instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep); + instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep) MIGRAPHX_TIDY_CONST; struct instruction_replacement { From a9519a5785ce101365396aa3613321fc9932d655 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 1 Apr 2026 11:44:28 -0500 Subject: [PATCH 090/107] Add debug symbols interface to api --- src/api/migraphx.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/api/migraphx.py b/src/api/migraphx.py index 18bfb365719..3168e906b00 100644 --- a/src/api/migraphx.py +++ b/src/api/migraphx.py @@ -220,7 +220,10 @@ def shapes(h): @api.handle('migraphx_instruction', 'migraphx::instruction_ref') def instruction(h): - pass + h.method('get_debug_symbols', + fname='get_debug_symbols', + returns='const std::set&', + const=True) @api.handle('migraphx_instructions', 'std::vector') @@ -265,6 +268,16 @@ def module(h): api.params(s='const migraphx::shape&'), invoke='migraphx::add_allocation($@)', returns='migraphx::instruction_ref') + h.method('has_debug_symbols', + fname='has_debug_symbols', + returns='bool', + const=True) + h.method('add_debug_symbols', + api.params(ins='migraphx::instruction_ref', symbols='std::set'), + fname='add_debug_symbols') + h.method('remove_debug_symbols', + api.params(ins='migraphx::instruction_ref'), + fname='remove_debug_symbols') @auto_handle() From 14cb668257bc62edbcb327b28b5549bb3c4ba904 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 1 Apr 2026 11:45:17 -0500 Subject: [PATCH 091/107] Revert "Add debug symbols interface to api" This reverts commit a9519a5785ce101365396aa3613321fc9932d655. --- src/api/migraphx.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/src/api/migraphx.py b/src/api/migraphx.py index 3168e906b00..18bfb365719 100644 --- a/src/api/migraphx.py +++ b/src/api/migraphx.py @@ -220,10 +220,7 @@ def shapes(h): @api.handle('migraphx_instruction', 'migraphx::instruction_ref') def instruction(h): - h.method('get_debug_symbols', - fname='get_debug_symbols', - returns='const std::set&', - const=True) + pass @api.handle('migraphx_instructions', 'std::vector') @@ -268,16 +265,6 @@ def module(h): api.params(s='const migraphx::shape&'), invoke='migraphx::add_allocation($@)', returns='migraphx::instruction_ref') - h.method('has_debug_symbols', - fname='has_debug_symbols', - returns='bool', - const=True) - h.method('add_debug_symbols', - api.params(ins='migraphx::instruction_ref', symbols='std::set'), - fname='add_debug_symbols') - h.method('remove_debug_symbols', - api.params(ins='migraphx::instruction_ref'), - fname='remove_debug_symbols') @auto_handle() From f0f5b78acaf5e8e999d0bb9c527fd83a8577e7b6 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 1 Apr 2026 11:47:15 -0500 Subject: [PATCH 092/107] clean up const-ness --- src/include/migraphx/module.hpp | 16 ++++++++-------- src/module.cpp | 16 ++++++++-------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 190e964f323..570491206bf 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -85,9 +85,9 @@ struct MIGRAPHX_EXPORT module /// If any instructions in this module have debug symbols bool has_debug_symbols() const; /// Merge given symbols with instruction's symbols - void add_debug_symbols(instruction_ref ins, const std::set& symbols) const; + void add_debug_symbols(instruction_ref ins, const std::set& symbols); /// Clear all debug symbols from instruction - void remove_debug_symbols(instruction_ref ins) const; + void remove_debug_symbols(instruction_ref ins); template {}...)> instruction_ref add_instruction(operation op, Ts... args) @@ -107,12 +107,12 @@ struct MIGRAPHX_EXPORT module return insert_instruction(ins, op, {args...}); } instruction_ref - insert_instruction(instruction_ref ins, const operation& op, std::vector args) MIGRAPHX_TIDY_CONST; + insert_instruction(instruction_ref ins, const operation& op, std::vector args); instruction_ref insert_instruction(instruction_ref ins, const operation& op, std::vector args, - std::vector module_args) MIGRAPHX_TIDY_CONST; + std::vector module_args); template {}...)> instruction_ref replace_instruction(instruction_ref ins, operation op, Ts... args) @@ -121,14 +121,14 @@ struct MIGRAPHX_EXPORT module } instruction_ref replace_instruction(instruction_ref ins, const operation& op, - std::vector args) MIGRAPHX_TIDY_CONST; + std::vector args); instruction_ref replace_instruction(instruction_ref ins, const operation& op, std::vector args, - std::vector module_args) MIGRAPHX_TIDY_CONST; + std::vector module_args); - instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep) MIGRAPHX_TIDY_CONST; + instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep); struct instruction_replacement { @@ -141,7 +141,7 @@ struct MIGRAPHX_EXPORT module /// Replaces an array of instructions within the same function to properly handle debug symbols /// propagation. Returns vector of instruction_ref to replaced instructions. std::vector batch_replace_instruction( - const std::vector& replacers) MIGRAPHX_TIDY_CONST; + const std::vector& replacers); instruction_ref remove_instruction(instruction_ref ins); instruction_ref remove_instructions(instruction_ref first, instruction_ref last); diff --git a/src/module.cpp b/src/module.cpp index 1322d7f3c91..cb4e4bb2ffe 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -165,7 +165,7 @@ void module::set_bypass(bool b) { impl->bypass = b; } bool module::has_debug_symbols() const { return impl->num_ins_with_debug_symbols > 0; } -void module::add_debug_symbols(instruction_ref ins, const std::set& symbols) const +void module::add_debug_symbols(instruction_ref ins, const std::set& symbols) { if(symbols.empty()) return; @@ -176,7 +176,7 @@ void module::add_debug_symbols(instruction_ref ins, const std::set& ins->add_debug_symbols(symbols); } -void module::remove_debug_symbols(instruction_ref ins) const +void module::remove_debug_symbols(instruction_ref ins) { assert(ins->get_debug_symbols().empty() or impl->num_ins_with_debug_symbols > 0); if(not ins->get_debug_symbols().empty() and impl->num_ins_with_debug_symbols > 0) @@ -327,7 +327,7 @@ instruction_ref module::add_instruction(const operation& op, std::vector args) MIGRAPHX_TIDY_CONST + std::vector args) { assert(has_instruction(ins) or is_end(ins, this->end())); assert(not starts_with(op.name(), "@")); @@ -348,7 +348,7 @@ instruction_ref module::add_instruction(const operation& op, instruction_ref module::insert_instruction(instruction_ref ins, const operation& op, std::vector args, - std::vector module_args) MIGRAPHX_TIDY_CONST + std::vector module_args) { assert(has_instruction(ins) or is_end(ins, this->end())); assert(not starts_with(op.name(), "@")); @@ -461,7 +461,7 @@ static void propagate_debug_symbols(const_module_ref m, instruction_ref module::replace_instruction(instruction_ref ins, const operation& op, - std::vector args) MIGRAPHX_TIDY_CONST + std::vector args) { impl->changed.notify(); assert(has_instruction(ins)); @@ -493,7 +493,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref module::replace_instruction(instruction_ref ins, const operation& op, std::vector args, - std::vector module_args) MIGRAPHX_TIDY_CONST + std::vector module_args) { impl->changed.notify(); assert(has_instruction(ins)); @@ -518,7 +518,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, return ins; } -instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref rep) MIGRAPHX_TIDY_CONST +instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref rep) { impl->changed.notify(); assert(has_instruction(ins)); @@ -572,7 +572,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref // Handles debug symbol propagation by having all old splice debug symbols propagate // to the new splice instructions. std::vector module::batch_replace_instruction( - const std::vector& replacers) MIGRAPHX_TIDY_CONST + const std::vector& replacers) { impl->changed.notify(); std::vector ret; From b7583007ff5d91e2eaaffc2fbe9eb6d487be6779 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 1 Apr 2026 11:49:36 -0500 Subject: [PATCH 093/107] Add to api --- src/api/migraphx.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/api/migraphx.py b/src/api/migraphx.py index 18bfb365719..3168e906b00 100644 --- a/src/api/migraphx.py +++ b/src/api/migraphx.py @@ -220,7 +220,10 @@ def shapes(h): @api.handle('migraphx_instruction', 'migraphx::instruction_ref') def instruction(h): - pass + h.method('get_debug_symbols', + fname='get_debug_symbols', + returns='const std::set&', + const=True) @api.handle('migraphx_instructions', 'std::vector') @@ -265,6 +268,16 @@ def module(h): api.params(s='const migraphx::shape&'), invoke='migraphx::add_allocation($@)', returns='migraphx::instruction_ref') + h.method('has_debug_symbols', + fname='has_debug_symbols', + returns='bool', + const=True) + h.method('add_debug_symbols', + api.params(ins='migraphx::instruction_ref', symbols='std::set'), + fname='add_debug_symbols') + h.method('remove_debug_symbols', + api.params(ins='migraphx::instruction_ref'), + fname='remove_debug_symbols') @auto_handle() From fe15f3249e4de3f8784fc2b933314bfda2f88053 Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 13 Apr 2026 15:54:30 -0500 Subject: [PATCH 094/107] Fix merge conflicts leftover --- src/module.cpp | 65 ------------------------------------------------- src/program.cpp | 8 ++---- 2 files changed, 2 insertions(+), 71 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index 6b08b022a68..ab4b4abea7d 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -166,11 +166,7 @@ void module::set_bypass(bool b) { impl->bypass = b; } bool module::has_debug_symbols() const { return impl->num_ins_with_debug_symbols > 0; } -<<<<<<< HEAD -void module::add_debug_symbols(instruction_ref ins, const std::set& symbols) const -======= void module::add_debug_symbols(instruction_ref ins, const std::set& symbols) ->>>>>>> d9c440eb76eb5f4991cbf7ea17706b72a12af934 { if(symbols.empty()) return; @@ -181,14 +177,9 @@ void module::add_debug_symbols(instruction_ref ins, const std::set& ins->add_debug_symbols(symbols); } -<<<<<<< HEAD -void module::remove_debug_symbols(instruction_ref ins) const -{ -======= void module::remove_debug_symbols(instruction_ref ins) { assert(ins->get_debug_symbols().empty() or impl->num_ins_with_debug_symbols > 0); ->>>>>>> d9c440eb76eb5f4991cbf7ea17706b72a12af934 if(not ins->get_debug_symbols().empty() and impl->num_ins_with_debug_symbols > 0) { impl->num_ins_with_debug_symbols--; @@ -400,15 +391,8 @@ static std::unordered_set gather_max_splice( return result; } -<<<<<<< HEAD -/** - * Figure out the instructions actually being spliced (min splice). - * end: instruction at end of splice - */ -======= // Figure out the instructions actually being spliced (min splice). // end: instruction at end of splice ->>>>>>> d9c440eb76eb5f4991cbf7ea17706b72a12af934 static std::unordered_set deduce_min_splice(std::vector ends, const std::unordered_set& max_splice, @@ -441,11 +425,7 @@ deduce_min_splice(std::vector ends, // ins: instruction that was/will be replaced // rep: replacing instruction -<<<<<<< HEAD -static void propagate_debug_symbols(const_module_ref m, -======= static void propagate_debug_symbols(module_ref m, ->>>>>>> d9c440eb76eb5f4991cbf7ea17706b72a12af934 instruction_ref ins, instruction_ref rep, const std::unordered_set& new_max_splice, @@ -478,8 +458,6 @@ static void propagate_debug_symbols(module_ref m, } } -<<<<<<< HEAD -======= // Adds a placeholder identity instruction for when replacing an instruction in place. static void propagate_debug_symbols_with_placeholder(module_ref m, instruction_ref ins, @@ -494,40 +472,11 @@ static void propagate_debug_symbols_with_placeholder(module_ref m, m->remove_instruction(id_ins); } ->>>>>>> d9c440eb76eb5f4991cbf7ea17706b72a12af934 instruction_ref module::replace_instruction(instruction_ref ins, const operation& op, std::vector args) { -<<<<<<< HEAD - impl->changed.notify(); - assert(has_instruction(ins)); - assert(not starts_with(op.name(), "@")); - - shape r = compute_shape(op, args); - std::vector prev_args; - if(has_debug_symbols()) - { - prev_args = ins->inputs(); - } - instruction::replace(ins, op, r, std::move(args)); - if(has_debug_symbols() and not prev_args.empty()) - { - // placeholder identity instruction - auto id_ins = insert_instruction(ins, make_op("identity"), prev_args); - // Get old_max_splice after replacement to get smallest dependent splice - std::unordered_set old_max_splice = gather_max_splice(this, id_ins); - // TODO: if there are no common ancestors, this may traverse the majority of the graph - std::unordered_set new_max_splice = - gather_max_splice(this, ins, old_max_splice); - propagate_debug_symbols(this, ins, id_ins, new_max_splice, old_max_splice); - remove_instruction(id_ins); - } - assert(ins->valid(begin())); - return ins; -======= return replace_instruction(ins, op, std::move(args), {}); ->>>>>>> d9c440eb76eb5f4991cbf7ea17706b72a12af934 } instruction_ref module::replace_instruction(instruction_ref ins, @@ -547,16 +496,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args)); if(has_debug_symbols() and not prev_args.empty()) { -<<<<<<< HEAD - auto id_ins = insert_instruction(ins, make_op("identity"), prev_args); - std::unordered_set old_max_splice = gather_max_splice(this, id_ins); - std::unordered_set new_max_splice = - gather_max_splice(this, ins, old_max_splice); - propagate_debug_symbols(this, ins, id_ins, new_max_splice, old_max_splice); - remove_instruction(id_ins); -======= propagate_debug_symbols_with_placeholder(this, ins, prev_args); ->>>>>>> d9c440eb76eb5f4991cbf7ea17706b72a12af934 } assert(ins->valid(begin())); return ins; @@ -615,13 +555,8 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref // For replacing multiple instructions within a single matcher. // Handles debug symbol propagation by having all old splice debug symbols propagate // to the new splice instructions. -<<<<<<< HEAD -std::vector module::batch_replace_instruction( - const std::vector& replacers) MIGRAPHX_TIDY_CONST -======= std::vector module::batch_replace_instruction(const std::vector& replacers) ->>>>>>> d9c440eb76eb5f4991cbf7ea17706b72a12af934 { impl->changed.notify(); std::vector ret; diff --git a/src/program.cpp b/src/program.cpp index d7be0258c77..25459961c38 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -1195,7 +1195,7 @@ void program::annotate(std::ostream& os, const std::functionimpl->modules) { - os << pp.first << ":" << std::endl; + os << pp.first << ":" << std::endl;https://charliel7.github.io/AMDMIGraphX_test/ pp.second.annotate(os, a); } } @@ -1320,7 +1320,7 @@ void program::remove_module(const std::string& name) // if an instruction has an input out side of the current module, need to remove // the instruction from its input's outputs - auto& mod = impl->modules.at(name); + auto& mod = impl->modules.at(name);https://charliel7.github.io/AMDMIGraphX_test/ for(auto ins : iterator_for(mod)) { auto inputs = ins->inputs(); @@ -1375,13 +1375,9 @@ program& program::sort() } bool operator==(const program& x, const program& y) -<<<<<<< HEAD -{ return migraphx::to_string(x) == migraphx::to_string(y); } -======= { return migraphx::to_string(x) == migraphx::to_string(y); } ->>>>>>> d9c440eb76eb5f4991cbf7ea17706b72a12af934 std::ostream& operator<<(std::ostream& os, const program& p) { From 57e649be252b85b578e02de8ba754c46c842e6dc Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 14 Apr 2026 13:47:59 -0500 Subject: [PATCH 095/107] Resolve conflict markers --- src/include/migraphx/module.hpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index bd4ba952e81..cb0206c9e37 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -140,13 +140,8 @@ struct MIGRAPHX_EXPORT module /// Replaces an array of instructions within the same function to properly handle debug symbols /// propagation. Returns vector of instruction_ref to replaced instructions. -<<<<<<< HEAD - std::vector batch_replace_instruction( - const std::vector& replacers) MIGRAPHX_TIDY_CONST; -======= std::vector batch_replace_instruction(const std::vector& replacers); ->>>>>>> d9c440eb76eb5f4991cbf7ea17706b72a12af934 instruction_ref remove_instruction(instruction_ref ins); instruction_ref remove_instructions(instruction_ref first, instruction_ref last); From 801bf36439ca01d740b07a0708aef8d7478e9169 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 14 Apr 2026 13:48:57 -0500 Subject: [PATCH 096/107] Make program ir version static constexpr --- src/include/migraphx/program.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/migraphx/program.hpp b/src/include/migraphx/program.hpp index 0ce8cb9a1b5..63c300ba444 100644 --- a/src/include/migraphx/program.hpp +++ b/src/include/migraphx/program.hpp @@ -170,7 +170,7 @@ struct MIGRAPHX_EXPORT program std::unique_ptr impl; // program file version is for the data structure or format of the MXR file. Version should be bumped // if any changes occur to the format of the MXR file. - const int program_file_version = 8; + static constexpr int program_file_version = 8; }; } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx From 31fd8c150124d57302c6c928fee46e93019e42b4 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 14 Apr 2026 13:51:50 -0500 Subject: [PATCH 097/107] netron output only main module --- src/onnx/netron_output.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/onnx/netron_output.cpp b/src/onnx/netron_output.cpp index 8f95daa3dcb..ee34fca4f3f 100644 --- a/src/onnx/netron_output.cpp +++ b/src/onnx/netron_output.cpp @@ -241,11 +241,10 @@ void write_netron_output(const program& prog, std::ostream& os) model.set_ir_version(prog.get_program_file_version()); model.set_producer_name("AMDMIGraphX"); model.set_producer_version(prog_value.at("migraphx_version").to()); - - for(auto* mod : prog.get_modules()) - { - build_graph(model.mutable_graph(), mod); - } + + // only exporting the main module + // TODO handle submodules as ONNX subgraphs + build_graph(model.mutable_graph(), prog.get_main_module()); model.SerializeToOstream(&os); } From 40e8e99e3c35017b2fb8dd05d230ee606079f966 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 14 Apr 2026 13:54:52 -0500 Subject: [PATCH 098/107] api check output stream --- src/api/api.cpp | 3 +++ tools/api/api.cpp | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/api/api.cpp b/src/api/api.cpp index 6ef8b298333..ff7a7e8590b 100644 --- a/src/api/api.cpp +++ b/src/api/api.cpp @@ -342,6 +342,9 @@ static void print_program(const program& p) { std::cout << p << std::endl; } static void write_netron_output_file(const program& p, const char* filename) { std::ofstream os(filename, std::ios::binary); + if(not os.is_open()) + MIGRAPHX_THROW(migraphx_status_bad_param, + "Failed to open file for writing: " + std::string(filename)); write_netron_output(p, os); } diff --git a/tools/api/api.cpp b/tools/api/api.cpp index 54fed4ef8b1..a473705a016 100644 --- a/tools/api/api.cpp +++ b/tools/api/api.cpp @@ -342,6 +342,9 @@ static void print_program(const program& p) { std::cout << p << std::endl; } static void write_netron_output_file(const program& p, const char* filename) { std::ofstream os(filename, std::ios::binary); + if(not os.is_open()) + MIGRAPHX_THROW(migraphx_status_bad_param, + "Failed to open file for writing: " + std::string(filename)); write_netron_output(p, os); } From 8c9599a802bf1e72e93d3fddcf69bbd9fd827f8a Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 14 Apr 2026 13:57:04 -0500 Subject: [PATCH 099/107] Remove raw_data entry --- src/onnx/netron_output.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/onnx/netron_output.cpp b/src/onnx/netron_output.cpp index ee34fca4f3f..67a42def0b0 100644 --- a/src/onnx/netron_output.cpp +++ b/src/onnx/netron_output.cpp @@ -115,7 +115,6 @@ void add_initializer(onnx::GraphProto* graph, { init->add_dims(d); } - init->set_raw_data("NULL"); } void add_node(onnx::GraphProto* graph, From 78445f8b2e28389d93435f5a24ba5330a7b76890 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 14 Apr 2026 14:04:24 -0500 Subject: [PATCH 100/107] Licensing --- src/include/migraphx/program.hpp | 2 +- src/onnx/netron_output.cpp | 2 +- test/api/CMakeLists.txt | 2 +- test/py/CMakeLists.txt | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/include/migraphx/program.hpp b/src/include/migraphx/program.hpp index 63c300ba444..f5d0d372795 100644 --- a/src/include/migraphx/program.hpp +++ b/src/include/migraphx/program.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/onnx/netron_output.cpp b/src/onnx/netron_output.cpp index 67a42def0b0..73d9b3a3285 100644 --- a/src/onnx/netron_output.cpp +++ b/src/onnx/netron_output.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/api/CMakeLists.txt b/test/api/CMakeLists.txt index f0b22e6e967..0740c840031 100644 --- a/test/api/CMakeLists.txt +++ b/test/api/CMakeLists.txt @@ -1,7 +1,7 @@ ##################################################################################### # The MIT License (MIT) # -# Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/test/py/CMakeLists.txt b/test/py/CMakeLists.txt index e3d5d05e622..b591f6ecf6b 100644 --- a/test/py/CMakeLists.txt +++ b/test/py/CMakeLists.txt @@ -1,7 +1,7 @@ ##################################################################################### # The MIT License (MIT) # -# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal From 836cf4d6718596014098b4397babd244b511a453 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 14 Apr 2026 14:06:40 -0500 Subject: [PATCH 101/107] Add ostream checking for python api --- src/py/migraphx_py.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp index 94d8f9c617a..666edbc72de 100644 --- a/src/py/migraphx_py.cpp +++ b/src/py/migraphx_py.cpp @@ -572,6 +572,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) .def("write_netron_output", [](const migraphx::program& p, const std::string& filename) { std::ofstream os(filename, std::ios::binary); + if(not os.is_open()) + throw std::runtime_error("Failed to open file for writing: " + filename); migraphx::write_netron_output(p, os); }, "Write program as ONNX protobuf binary viewable in Netron", From ab03d6d3ca618422aded30b231e0fe55c67b48fe Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 14 Apr 2026 14:10:47 -0500 Subject: [PATCH 102/107] remove typos --- src/program.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/program.cpp b/src/program.cpp index 25459961c38..d6460807e99 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -1195,7 +1195,7 @@ void program::annotate(std::ostream& os, const std::functionimpl->modules) { - os << pp.first << ":" << std::endl;https://charliel7.github.io/AMDMIGraphX_test/ + os << pp.first << ":" << std::endl; pp.second.annotate(os, a); } } @@ -1320,7 +1320,7 @@ void program::remove_module(const std::string& name) // if an instruction has an input out side of the current module, need to remove // the instruction from its input's outputs - auto& mod = impl->modules.at(name);https://charliel7.github.io/AMDMIGraphX_test/ + auto& mod = impl->modules.at(name); for(auto ins : iterator_for(mod)) { auto inputs = ins->inputs(); From b6c35dcd749ebb17861d0b1e88c9bf6ce3a8dc2b Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 14 Apr 2026 15:00:59 -0500 Subject: [PATCH 103/107] Update tests to have ONNX protobuff expected value --- test/CMakeLists.txt | 2 + test/netron_output_test.cpp | 193 ++++++++++++++++++++++++++++++++++-- 2 files changed, 187 insertions(+), 8 deletions(-) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 1db4fb29d4f..a3fe705fc88 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -45,6 +45,8 @@ foreach(TEST ${TESTS}) rocm_clang_tidy_check(test_${BASE_NAME}) endforeach() +target_link_libraries(test_netron_output_test onnx-proto) + if(MIGRAPHX_ENABLE_GPU) # gpu tests file(GLOB GPU_TESTS CONFIGURE_DEPENDS gpu/*.cpp) diff --git a/test/netron_output_test.cpp b/test/netron_output_test.cpp index 83634c7d8a5..b69f17585c0 100644 --- a/test/netron_output_test.cpp +++ b/test/netron_output_test.cpp @@ -25,9 +25,69 @@ #include #include #include +#include #include #include +namespace onnx = onnx_for_migraphx; + +static void set_value_info(onnx::ValueInfoProto* vi, + const std::string& name, + int elem_type, + const std::vector& dims) +{ + vi->set_name(name); + auto* tensor = vi->mutable_type()->mutable_tensor_type(); + tensor->set_elem_type(elem_type); + for(auto d : dims) + tensor->mutable_shape()->add_dim()->set_dim_value(d); +} + +static void add_initializer(onnx::GraphProto* graph, + const std::string& name, + int data_type, + const std::vector& dims) +{ + auto* init = graph->add_initializer(); + init->set_name(name); + init->set_data_type(data_type); + for(auto d : dims) + init->add_dims(d); +} + +static onnx::NodeProto* add_node(onnx::GraphProto* graph, + const std::string& op_type, + const std::string& name, + const std::vector& inputs, + const std::vector& outputs) +{ + auto* node = graph->add_node(); + node->set_op_type(op_type); + node->set_name(name); + for(const auto& in : inputs) + node->add_input(in); + for(const auto& out : outputs) + node->add_output(out); + return node; +} + +static void add_string_attribute(onnx::NodeProto* node, + const std::string& name, + const std::string& value) +{ + auto* attr = node->add_attribute(); + attr->set_name(name); + attr->set_type(onnx::AttributeProto::STRING); + attr->set_s(value); +} + +static onnx::GraphProto parse_graph(const std::string& proto_binary) +{ + onnx::ModelProto model; + model.ParseFromString(proto_binary); + return model.graph(); +} + TEST_CASE(netron_output_basic) { migraphx::program p; @@ -41,9 +101,17 @@ TEST_CASE(netron_output_basic) std::ostringstream os; migraphx::write_netron_output(p, os); - std::string output = os.str(); - EXPECT(not output.empty()); - EXPECT(output.size() > 10); + onnx::GraphProto expected; + add_node(&expected, + "add", + "main:add:@2", + {"main:@param:x:@1", "main:@param:y:@0"}, + {"main:@return:@3"}); + set_value_info(expected.add_input(), "main:@param:y:@0", onnx::TensorProto::FLOAT, {2, 3}); + set_value_info(expected.add_input(), "main:@param:x:@1", onnx::TensorProto::FLOAT, {2, 3}); + set_value_info(expected.add_output(), "main:@return:@3", onnx::TensorProto::FLOAT, {2, 3}); + + EXPECT(parse_graph(os.str()).SerializeAsString() == expected.SerializeAsString()); } TEST_CASE(netron_output_with_literal) @@ -60,9 +128,120 @@ TEST_CASE(netron_output_with_literal) std::ostringstream os; migraphx::write_netron_output(p, os); - std::string output = os.str(); - EXPECT(not output.empty()); - EXPECT(output.size() > 10); + onnx::GraphProto expected; + add_node(&expected, + "add", + "main:add:@2", + {"main:@param:x:@1", "main:@literal:@0"}, + {"main:@return:@3"}); + add_initializer(&expected, "main:@literal:@0", onnx::TensorProto::FLOAT, {2, 3}); + set_value_info(expected.add_input(), "main:@param:x:@1", onnx::TensorProto::FLOAT, {2, 3}); + set_value_info(expected.add_output(), "main:@return:@3", onnx::TensorProto::FLOAT, {2, 3}); + + EXPECT(parse_graph(os.str()).SerializeAsString() == expected.SerializeAsString()); +} + +TEST_CASE(netron_output_multiple_types) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + auto a = mm->add_parameter("a", {migraphx::shape::int32_type, {2, 3}}); + mm->add_parameter("b", {migraphx::shape::int64_type, {4}}); + mm->add_parameter("c", {migraphx::shape::half_type, {2, 3}}); + mm->add_parameter("d", {migraphx::shape::double_type, {2, 3}}); + mm->add_parameter("e", {migraphx::shape::bool_type, {2}}); + mm->add_parameter("f", {migraphx::shape::uint8_type, {2, 3}}); + mm->add_parameter("g", {migraphx::shape::bf16_type, {2, 3}}); + + auto lit = mm->add_literal(migraphx::literal{{migraphx::shape::int32_type, {2, 3}}, + {1, 2, 3, 4, 5, 6}}); + auto sum = mm->add_instruction(migraphx::make_op("add"), a, lit); + mm->add_return({sum}); + + std::ostringstream os; + migraphx::write_netron_output(p, os); + + onnx::GraphProto expected; + add_node(&expected, + "add", + "main:add:@8", + {"main:@param:a:@7", "main:@literal:@0"}, + {"main:@return:@9"}); + add_initializer(&expected, "main:@literal:@0", onnx::TensorProto::INT32, {2, 3}); + set_value_info(expected.add_input(), "main:@param:g:@1", onnx::TensorProto::BFLOAT16, {2, 3}); + set_value_info(expected.add_input(), "main:@param:f:@2", onnx::TensorProto::UINT8, {2, 3}); + set_value_info(expected.add_input(), "main:@param:e:@3", onnx::TensorProto::BOOL, {2}); + set_value_info(expected.add_input(), "main:@param:d:@4", onnx::TensorProto::DOUBLE, {2, 3}); + set_value_info(expected.add_input(), "main:@param:c:@5", onnx::TensorProto::FLOAT16, {2, 3}); + set_value_info(expected.add_input(), "main:@param:b:@6", onnx::TensorProto::INT64, {4}); + set_value_info(expected.add_input(), "main:@param:a:@7", onnx::TensorProto::INT32, {2, 3}); + set_value_info(expected.add_output(), "main:@return:@9", onnx::TensorProto::INT32, {2, 3}); + + EXPECT(parse_graph(os.str()).SerializeAsString() == expected.SerializeAsString()); +} + +TEST_CASE(netron_output_op_attributes_and_chain) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 3}}); + auto y = mm->add_parameter("y", {migraphx::shape::float_type, {2, 3}}); + auto sm = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), x); + auto sum = mm->add_instruction(migraphx::make_op("add"), sm, y); + mm->add_return({sum}); + + std::ostringstream os; + migraphx::write_netron_output(p, os); + + onnx::GraphProto expected; + auto* sm_node = add_node(&expected, + "softmax", + "main:softmax:@2", + {"main:@param:x:@1"}, + {"main:softmax:@2->main:add:@3"}); + add_string_attribute(sm_node, "axis", "1"); + add_node(&expected, + "add", + "main:add:@3", + {"main:softmax:@2->main:add:@3", "main:@param:y:@0"}, + {"main:@return:@4"}); + set_value_info(expected.add_input(), "main:@param:y:@0", onnx::TensorProto::FLOAT, {2, 3}); + set_value_info(expected.add_input(), "main:@param:x:@1", onnx::TensorProto::FLOAT, {2, 3}); + set_value_info(expected.add_output(), "main:@return:@4", onnx::TensorProto::FLOAT, {2, 3}); + set_value_info( + expected.add_value_info(), "main:softmax:@2->main:add:@3", onnx::TensorProto::FLOAT, {2, 3}); + + EXPECT(parse_graph(os.str()).SerializeAsString() == expected.SerializeAsString()); +} + +TEST_CASE(netron_output_debug_symbols) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 3}}); + auto y = mm->add_parameter("y", {migraphx::shape::float_type, {2, 3}}); + auto sum = mm->add_instruction(migraphx::make_op("add"), x, y); + mm->add_debug_symbols(sum, {"test_file.onnx:42", "origin_op:Add"}); + mm->add_return({sum}); + + std::ostringstream os; + migraphx::write_netron_output(p, os); + + onnx::GraphProto expected; + auto* node = add_node(&expected, + "add", + "main:add:@2", + {"main:@param:x:@1", "main:@param:y:@0"}, + {"main:@return:@3"}); + add_string_attribute(node, "debug symbols", "origin_op:Add, test_file.onnx:42"); + set_value_info(expected.add_input(), "main:@param:y:@0", onnx::TensorProto::FLOAT, {2, 3}); + set_value_info(expected.add_input(), "main:@param:x:@1", onnx::TensorProto::FLOAT, {2, 3}); + set_value_info(expected.add_output(), "main:@return:@3", onnx::TensorProto::FLOAT, {2, 3}); + + EXPECT(parse_graph(os.str()).SerializeAsString() == expected.SerializeAsString()); } TEST_CASE(netron_output_roundtrip) @@ -79,8 +258,6 @@ TEST_CASE(netron_output_roundtrip) migraphx::write_netron_output(p, os); std::string output = os.str(); - // The output should be parseable as a valid ONNX model. - // Most nodes will be unknown operators however. migraphx::onnx_options options; options.skip_unknown_operators = true; auto p2 = migraphx::parse_onnx_buffer(output.data(), output.size(), options); From 930d4806fb4190e5ecc5ef32572226ec2dba2346 Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 17 Apr 2026 15:31:51 -0500 Subject: [PATCH 104/107] license --- test/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index a3fe705fc88..8df09eb4fbb 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,7 +1,7 @@ # #################################################################################### # The MIT License (MIT) # -# Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal From 6681273036350d5daf9417c5b4dd5447b1d1eba1 Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 17 Apr 2026 15:32:58 -0500 Subject: [PATCH 105/107] formatting --- src/instruction.cpp | 24 ++++++++++++++----- src/program.cpp | 9 ++++--- src/py/migraphx_py.cpp | 19 ++++++++------- test/netron_output_test.cpp | 47 +++++++++++++++++++------------------ 4 files changed, 56 insertions(+), 43 deletions(-) diff --git a/src/instruction.cpp b/src/instruction.cpp index 17111afd96c..9a3c1e29a0c 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -70,7 +70,9 @@ struct replace_shape_order std::size_t location(instruction_ref x) const { return std::distance(start, x); } bool operator()(instruction_ref x, instruction_ref y) const - { return location(x) > location(y); } + { + return location(x) > location(y); + } }; void instruction::replace(const shape& r) @@ -123,7 +125,9 @@ void instruction::clear_arguments() } bool operator==(const instruction& i, instruction_ref ref) -{ return std::addressof(i) == std::addressof(*ref); } +{ + return std::addressof(i) == std::addressof(*ref); +} bool instruction::valid(instruction_ref start, bool check_order) const { @@ -520,7 +524,9 @@ void instruction::set_normalized(bool value) { normalized = value; } bool instruction::is_normalized() const { return normalized; } bool instruction::need_normalization() const -{ return this->get_operator().need_normalization() and not normalized; } +{ + return this->get_operator().need_normalization() and not normalized; +} operation instruction::normalized_operator() const { @@ -546,7 +552,9 @@ std::vector to_shapes(const std::vector& args) } shape compute_shape(const operation& op, const std::vector& args) -{ return op.compute_shape(to_shapes(args)); } +{ + return op.compute_shape(to_shapes(args)); +} shape compute_shape(const operation& op, const std::vector& args, @@ -577,10 +585,14 @@ std::vector try_compute_shape(const operation& op, const std::vector::iterator& ins) noexcept -{ return iterator_address(ins); } +{ + return iterator_address(ins); +} const migraphx::instruction* as_address(const std::list::const_iterator& ins) noexcept -{ return iterator_address(ins); } +{ + return iterator_address(ins); +} template static auto track_visits(instruction_ref start, instruction_ref end, F f) diff --git a/src/program.cpp b/src/program.cpp index d6460807e99..a571c0fa3ac 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -82,7 +82,9 @@ struct program_impl program::program() : impl(std::make_unique()) { this->create_module("main"); } program::program(module m) : impl(std::make_unique()) -{ this->create_module("main", std::move(m)); } +{ + this->create_module("main", std::move(m)); +} program::program(program&&) noexcept = default; program::~program() noexcept = default; @@ -159,10 +161,7 @@ std::unordered_map program::get_parameter_shapes() const return mm->get_parameter_shapes(); } -int program::get_program_file_version() const -{ - return program_file_version; -} +int program::get_program_file_version() const { return program_file_version; } std::size_t program::size() const { return impl->modules.size(); } diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp index 666edbc72de..70bb7d48c37 100644 --- a/src/py/migraphx_py.cpp +++ b/src/py/migraphx_py.cpp @@ -569,15 +569,16 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) return ss.str(); }) .def("sort", &migraphx::program::sort) - .def("write_netron_output", - [](const migraphx::program& p, const std::string& filename) { - std::ofstream os(filename, std::ios::binary); - if(not os.is_open()) - throw std::runtime_error("Failed to open file for writing: " + filename); - migraphx::write_netron_output(p, os); - }, - "Write program as ONNX protobuf binary viewable in Netron", - py::arg("filename")) + .def( + "write_netron_output", + [](const migraphx::program& p, const std::string& filename) { + std::ofstream os(filename, std::ios::binary); + if(not os.is_open()) + throw std::runtime_error("Failed to open file for writing: " + filename); + migraphx::write_netron_output(p, os); + }, + "Write program as ONNX protobuf binary viewable in Netron", + py::arg("filename")) .def("print", [](const migraphx::program& p) { std::cout << p << std::endl; }) .def("__eq__", std::equal_to{}) .def("__ne__", std::not_equal_to{}) diff --git a/test/netron_output_test.cpp b/test/netron_output_test.cpp index b69f17585c0..5b5bdffb120 100644 --- a/test/netron_output_test.cpp +++ b/test/netron_output_test.cpp @@ -44,9 +44,9 @@ static void set_value_info(onnx::ValueInfoProto* vi, } static void add_initializer(onnx::GraphProto* graph, - const std::string& name, - int data_type, - const std::vector& dims) + const std::string& name, + int data_type, + const std::vector& dims) { auto* init = graph->add_initializer(); init->set_name(name); @@ -56,10 +56,10 @@ static void add_initializer(onnx::GraphProto* graph, } static onnx::NodeProto* add_node(onnx::GraphProto* graph, - const std::string& op_type, - const std::string& name, - const std::vector& inputs, - const std::vector& outputs) + const std::string& op_type, + const std::string& name, + const std::vector& inputs, + const std::vector& outputs) { auto* node = graph->add_node(); node->set_op_type(op_type); @@ -71,9 +71,8 @@ static onnx::NodeProto* add_node(onnx::GraphProto* graph, return node; } -static void add_string_attribute(onnx::NodeProto* node, - const std::string& name, - const std::string& value) +static void +add_string_attribute(onnx::NodeProto* node, const std::string& name, const std::string& value) { auto* attr = node->add_attribute(); attr->set_name(name); @@ -121,7 +120,7 @@ TEST_CASE(netron_output_with_literal) auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 3}}); auto lit = mm->add_literal(migraphx::literal{{migraphx::shape::float_type, {2, 3}}, - {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); auto sum = mm->add_instruction(migraphx::make_op("add"), x, lit); mm->add_return({sum}); @@ -154,8 +153,8 @@ TEST_CASE(netron_output_multiple_types) mm->add_parameter("f", {migraphx::shape::uint8_type, {2, 3}}); mm->add_parameter("g", {migraphx::shape::bf16_type, {2, 3}}); - auto lit = mm->add_literal(migraphx::literal{{migraphx::shape::int32_type, {2, 3}}, - {1, 2, 3, 4, 5, 6}}); + auto lit = mm->add_literal( + migraphx::literal{{migraphx::shape::int32_type, {2, 3}}, {1, 2, 3, 4, 5, 6}}); auto sum = mm->add_instruction(migraphx::make_op("add"), a, lit); mm->add_return({sum}); @@ -197,10 +196,10 @@ TEST_CASE(netron_output_op_attributes_and_chain) onnx::GraphProto expected; auto* sm_node = add_node(&expected, - "softmax", - "main:softmax:@2", - {"main:@param:x:@1"}, - {"main:softmax:@2->main:add:@3"}); + "softmax", + "main:softmax:@2", + {"main:@param:x:@1"}, + {"main:softmax:@2->main:add:@3"}); add_string_attribute(sm_node, "axis", "1"); add_node(&expected, "add", @@ -210,8 +209,10 @@ TEST_CASE(netron_output_op_attributes_and_chain) set_value_info(expected.add_input(), "main:@param:y:@0", onnx::TensorProto::FLOAT, {2, 3}); set_value_info(expected.add_input(), "main:@param:x:@1", onnx::TensorProto::FLOAT, {2, 3}); set_value_info(expected.add_output(), "main:@return:@4", onnx::TensorProto::FLOAT, {2, 3}); - set_value_info( - expected.add_value_info(), "main:softmax:@2->main:add:@3", onnx::TensorProto::FLOAT, {2, 3}); + set_value_info(expected.add_value_info(), + "main:softmax:@2->main:add:@3", + onnx::TensorProto::FLOAT, + {2, 3}); EXPECT(parse_graph(os.str()).SerializeAsString() == expected.SerializeAsString()); } @@ -232,10 +233,10 @@ TEST_CASE(netron_output_debug_symbols) onnx::GraphProto expected; auto* node = add_node(&expected, - "add", - "main:add:@2", - {"main:@param:x:@1", "main:@param:y:@0"}, - {"main:@return:@3"}); + "add", + "main:add:@2", + {"main:@param:x:@1", "main:@param:y:@0"}, + {"main:@return:@3"}); add_string_attribute(node, "debug symbols", "origin_op:Add, test_file.onnx:42"); set_value_info(expected.add_input(), "main:@param:y:@0", onnx::TensorProto::FLOAT, {2, 3}); set_value_info(expected.add_input(), "main:@param:x:@1", onnx::TensorProto::FLOAT, {2, 3}); From a99ab17a82ba799c0b4e286037cfa8984edd8c69 Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 17 Apr 2026 15:50:37 -0500 Subject: [PATCH 106/107] add changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b58aa81760c..4a44af776bd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ Full documentation for MIGraphX is available at * Updated the ONNX clip operator to support opset 13 (#4518). * Updated `argmin` and `argmax` ops to be implemented as reduction ops, so they now have JIT support and can fuse (#4620). * Replaced usages of `std::cout` and `std::cerr` with the logger (#4732) +* Updated netron output to create an ONNX-like protobuff. Now also includes debug symbols if enabled. (#4701) ### Resolved issues From 7055913958d2b6b4a570433e8c6b44f1fab09c49 Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 20 Apr 2026 13:42:44 -0500 Subject: [PATCH 107/107] Add tests and pybind funcs --- src/py/migraphx_py.cpp | 16 +++ test/py/CMakeLists.txt | 1 + test/py/test_debug_symbols.py | 217 ++++++++++++++++++++++++++++++++++ 3 files changed, 234 insertions(+) create mode 100644 test/py/test_debug_symbols.py diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp index 4a77cf15ecb..67869d8529d 100644 --- a/src/py/migraphx_py.cpp +++ b/src/py/migraphx_py.cpp @@ -459,6 +459,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) .def("name", [](migraphx::instruction_ref i) { return i->name(); }) .def("get_literal", [](migraphx::instruction_ref i) { return i->get_literal().get_argument(); }) + .def("get_debug_symbols", + [](migraphx::instruction_ref i) { return i->get_debug_symbols(); }) .def(py::hash(py::self)) .def(py::self == py::self) .def(py::self != py::self); @@ -534,6 +536,20 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) py::arg("macro"), py::arg("args"), py::arg("mod_args") = std::vector{}) + .def("has_debug_symbols", &migraphx::module::has_debug_symbols) + .def( + "add_debug_symbols", + [](migraphx::module& mm, + migraphx::instruction_ref ins, + const std::set& symbols) { mm.add_debug_symbols(ins, symbols); }, + py::arg("ins"), + py::arg("symbols")) + .def( + "remove_debug_symbols", + [](migraphx::module& mm, migraphx::instruction_ref ins) { + mm.remove_debug_symbols(ins); + }, + py::arg("ins")) .def("__repr__", [](const migraphx::module& mm) { return migraphx::to_string(mm); }) .def( "__iter__", diff --git a/test/py/CMakeLists.txt b/test/py/CMakeLists.txt index cfa56ff564f..6db68fa0f93 100644 --- a/test/py/CMakeLists.txt +++ b/test/py/CMakeLists.txt @@ -106,6 +106,7 @@ add_py_test(module_construct test_module_construct.py common ${VENV} WORKING_DIR add_py_test(macro test_macro.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(literal test_literal.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(autocast_fp8 test_autocast_fp8.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR}) +add_py_test(debug_symbols test_debug_symbols.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR}) if(MIGRAPHX_ENABLE_GPU) add_py_test(gpu_offload test_gpu_offload.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(gpu test_gpu.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR}) diff --git a/test/py/test_debug_symbols.py b/test/py/test_debug_symbols.py new file mode 100644 index 00000000000..253e80cde9d --- /dev/null +++ b/test/py/test_debug_symbols.py @@ -0,0 +1,217 @@ +##################################################################################### +# The MIT License (MIT) +# +# Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +##################################################################################### +import migraphx + + +def test_module_has_no_debug_symbols(): + p = migraphx.program() + mm = p.get_main_module() + s = migraphx.shape(lens=[2, 3], type="float") + x = mm.add_parameter("x", s) + y = mm.add_parameter("y", s) + mm.add_instruction(migraphx.op("add"), [x, y]) + assert not mm.has_debug_symbols() + + +def test_module_add_debug_symbols(): + p = migraphx.program() + mm = p.get_main_module() + s = migraphx.shape(lens=[2, 3], type="float") + x = mm.add_parameter("x", s) + y = mm.add_parameter("y", s) + add_ins = mm.add_instruction(migraphx.op("add"), [x, y]) + + assert not mm.has_debug_symbols() + mm.add_debug_symbols(add_ins, {"sym_a", "sym_b"}) + assert mm.has_debug_symbols() + + +def test_module_add_debug_symbols_multiple_instructions(): + p = migraphx.program() + mm = p.get_main_module() + s = migraphx.shape(lens=[2, 3], type="float") + x = mm.add_parameter("x", s) + y = mm.add_parameter("y", s) + add_ins = mm.add_instruction(migraphx.op("add"), [x, y]) + relu_ins = mm.add_instruction(migraphx.op("relu"), [add_ins]) + + mm.add_debug_symbols(add_ins, {"onnx:add"}) + mm.add_debug_symbols(relu_ins, {"onnx:relu"}) + + assert mm.has_debug_symbols() + assert add_ins.get_debug_symbols() == {"onnx:add"} + assert relu_ins.get_debug_symbols() == {"onnx:relu"} + + +def test_module_add_debug_symbols_merge(): + p = migraphx.program() + mm = p.get_main_module() + s = migraphx.shape(lens=[2, 3], type="float") + x = mm.add_parameter("x", s) + y = mm.add_parameter("y", s) + add_ins = mm.add_instruction(migraphx.op("add"), [x, y]) + + mm.add_debug_symbols(add_ins, {"sym_a"}) + mm.add_debug_symbols(add_ins, {"sym_b"}) + assert add_ins.get_debug_symbols() == {"sym_a", "sym_b"} + + +def test_module_remove_debug_symbols(): + p = migraphx.program() + mm = p.get_main_module() + s = migraphx.shape(lens=[2, 3], type="float") + x = mm.add_parameter("x", s) + y = mm.add_parameter("y", s) + add_ins = mm.add_instruction(migraphx.op("add"), [x, y]) + + mm.add_debug_symbols(add_ins, {"sym_a", "sym_b"}) + assert mm.has_debug_symbols() + + mm.remove_debug_symbols(add_ins) + assert add_ins.get_debug_symbols() == set() + assert not mm.has_debug_symbols() + + +def test_module_remove_one_of_two(): + p = migraphx.program() + mm = p.get_main_module() + s = migraphx.shape(lens=[2, 3], type="float") + x = mm.add_parameter("x", s) + y = mm.add_parameter("y", s) + add_ins = mm.add_instruction(migraphx.op("add"), [x, y]) + relu_ins = mm.add_instruction(migraphx.op("relu"), [add_ins]) + + mm.add_debug_symbols(add_ins, {"sym_add"}) + mm.add_debug_symbols(relu_ins, {"sym_relu"}) + + mm.remove_debug_symbols(add_ins) + assert add_ins.get_debug_symbols() == set() + assert relu_ins.get_debug_symbols() == {"sym_relu"} + assert mm.has_debug_symbols() + + +def test_module_remove_then_readd(): + p = migraphx.program() + mm = p.get_main_module() + s = migraphx.shape(lens=[2, 3], type="float") + x = mm.add_parameter("x", s) + y = mm.add_parameter("y", s) + add_ins = mm.add_instruction(migraphx.op("add"), [x, y]) + + mm.add_debug_symbols(add_ins, {"old_sym"}) + mm.remove_debug_symbols(add_ins) + assert not mm.has_debug_symbols() + + mm.add_debug_symbols(add_ins, {"new_sym"}) + assert add_ins.get_debug_symbols() == {"new_sym"} + assert mm.has_debug_symbols() + + +def test_instruction_get_debug_symbols(): + p = migraphx.program() + mm = p.get_main_module() + s = migraphx.shape(lens=[2, 3], type="float") + x = mm.add_parameter("x", s) + y = mm.add_parameter("y", s) + add_ins = mm.add_instruction(migraphx.op("add"), [x, y]) + + assert add_ins.get_debug_symbols() == set() + + mm.add_debug_symbols(add_ins, {"sym_a", "sym_b", "sym_c"}) + assert add_ins.get_debug_symbols() == {"sym_a", "sym_b", "sym_c"} + + +def test_iterate_instructions_debug_symbols(): + p = migraphx.program() + mm = p.get_main_module() + s = migraphx.shape(lens=[2, 3], type="float") + x = mm.add_parameter("x", s) + y = mm.add_parameter("y", s) + z = mm.add_parameter("z", s) + add_ins = mm.add_instruction(migraphx.op("add"), [x, y]) + mul_ins = mm.add_instruction(migraphx.op("mul"), [add_ins, z]) + relu_ins = mm.add_instruction(migraphx.op("relu"), [mul_ins]) + mm.add_return([relu_ins]) + + mm.add_debug_symbols(add_ins, {"onnx:add"}) + mm.add_debug_symbols(mul_ins, {"onnx:mul"}) + mm.add_debug_symbols(relu_ins, {"onnx:relu"}) + + all_symbols = set() + for ins in mm: + all_symbols.update(ins.get_debug_symbols()) + + assert all_symbols == {"onnx:add", "onnx:mul", "onnx:relu"} + + +def test_iterate_only_symbolized_instructions(): + p = migraphx.program() + mm = p.get_main_module() + s = migraphx.shape(lens=[2, 3], type="float") + x = mm.add_parameter("x", s) + y = mm.add_parameter("y", s) + add_ins = mm.add_instruction(migraphx.op("add"), [x, y]) + relu_ins = mm.add_instruction(migraphx.op("relu"), [add_ins]) + mm.add_return([relu_ins]) + + mm.add_debug_symbols(relu_ins, {"onnx:relu"}) + + symbolized = {ins.name(): ins.get_debug_symbols() + for ins in mm if ins.get_debug_symbols()} + assert len(symbolized) == 1 + assert symbolized["relu"] == {"onnx:relu"} + + +def test_parse_onnx_with_debug_symbols(): + p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx", + use_debug_symbols=True) + mm = p.get_main_module() + assert mm.has_debug_symbols() + + all_symbols = set() + for ins in mm: + all_symbols.update(ins.get_debug_symbols()) + assert len(all_symbols) > 0 + + +def test_parse_onnx_without_debug_symbols(): + p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx", + use_debug_symbols=False) + mm = p.get_main_module() + assert not mm.has_debug_symbols() + + +if __name__ == "__main__": + test_module_has_no_debug_symbols() + test_module_add_debug_symbols() + test_module_add_debug_symbols_multiple_instructions() + test_module_add_debug_symbols_merge() + test_module_remove_debug_symbols() + test_module_remove_one_of_two() + test_module_remove_then_readd() + test_instruction_get_debug_symbols() + test_iterate_instructions_debug_symbols() + test_iterate_only_symbolized_instructions() + test_parse_onnx_with_debug_symbols() + test_parse_onnx_without_debug_symbols()