Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/driver/verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ static std::vector<argument> run_target(program p,
auto arg = inputs.count(x.first) == 0 ? generate_argument(x.second) : inputs.at(x.first);
m[x.first] = options.offload_copy ? arg : t.copy_to(arg);
}

p.eval(m); // run once for hip graph capture

Comment on lines +141 to +142
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This warm-up p.eval(m) runs unconditionally and will execute the model twice even when hip-graph capture is disabled, which significantly slows verification and can skew perf numbers. Gate this extra run behind enabled(MIGRAPHX_ENABLE_HIP_GRAPH{}) (and/or only when the compiled target is GPU).

Copilot uses AI. Check for mistakes.
auto gpu_out = p.eval(m);
std::vector<argument> output(gpu_out.size());
std::cout << p << std::endl;
Expand Down
121 changes: 115 additions & 6 deletions src/include/migraphx/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/any_ptr.hpp>

namespace migraphx {
Expand Down Expand Up @@ -77,6 +78,17 @@ void finish_on_context(T&, any_ptr)
{
}

// Default no-op fallback invoked by the type-erased context wrapper when the
// wrapped context type does not implement one of the optional capture hooks.
template <class T, class... Ts>
void nop_context(T&, Ts&&...)
{
}

// Default fallback for the optional get_capture() hook: a context type with
// no capture support reports "no captured run available" via a null function.
template <class T>
std::function<std::vector<argument>()> get_capture_context(T&)
{
    return nullptr;
}

#ifdef TYPE_ERASED_DECLARATION

// Type-erased interface for:
Expand All @@ -92,6 +104,12 @@ struct MIGRAPHX_EXPORT context
void wait_for(any_ptr queue);
// (optional)
void finish_on(any_ptr queue);
// (optional)
void start_capture();
// (optional)
void end_capture(const std::vector<argument>& args);
// (optional)
std::function<std::vector<argument>()> get_capture();
//
void finish() const;
};
Expand Down Expand Up @@ -169,6 +187,50 @@ struct context
finish_on_context(private_detail_te_self, queue);
}

// Overload-resolution dispatch for the optional start_capture() member: the
// char overload is the better match for the char(0) argument and is viable
// only when T actually provides start_capture() (SFINAE on the trailing
// return type); otherwise resolution falls through to the float overload.
template <class T>
static auto private_detail_te_default_start_capture(char, T&& private_detail_te_self)
    -> decltype(private_detail_te_self.start_capture())
{
    private_detail_te_self.start_capture();
}

// Fallback when T has no start_capture(): do nothing.
template <class T>
static void private_detail_te_default_start_capture(float, T&& private_detail_te_self)
{
    nop_context(private_detail_te_self);
}

// Dispatch for the optional end_capture(args) member; selected via SFINAE on
// the trailing return type when T provides it (char overload preferred for a
// char(0) argument).
template <class T>
static auto private_detail_te_default_end_capture(char,
                                                  T&& private_detail_te_self,
                                                  const std::vector<argument>& args)
    -> decltype(private_detail_te_self.end_capture(args))
{
    private_detail_te_self.end_capture(args);
}

// Fallback when T has no end_capture(): do nothing.
template <class T>
static void private_detail_te_default_end_capture(float,
                                                  T&& private_detail_te_self,
                                                  const std::vector<argument>& args)
{
    nop_context(private_detail_te_self, args);
}

// Dispatch for the optional get_capture() member; forwards to T's own
// implementation when present (SFINAE on the trailing return type).
template <class T>
static auto private_detail_te_default_get_capture(char, T&& private_detail_te_self)
    -> decltype(private_detail_te_self.get_capture())
{
    return private_detail_te_self.get_capture();
}

// Fallback when T has no get_capture(): delegates to get_capture_context,
// which returns a null function (no captured run available).
template <class T>
static std::function<std::vector<argument>()>
private_detail_te_default_get_capture(float, T&& private_detail_te_self)
{
    return get_capture_context(private_detail_te_self);
}

template <class PrivateDetailTypeErasedT>
struct private_te_unwrap_reference
{
Expand Down Expand Up @@ -196,6 +258,14 @@ struct context
char(0), std::declval<PrivateDetailTypeErasedT>(), std::declval<any_ptr>()),
private_detail_te_default_finish_on(
char(0), std::declval<PrivateDetailTypeErasedT>(), std::declval<any_ptr>()),
private_detail_te_default_start_capture(char(0),
std::declval<PrivateDetailTypeErasedT>()),
private_detail_te_default_end_capture(
char(0),
std::declval<PrivateDetailTypeErasedT>(),
std::declval<const std::vector<argument>&>()),
private_detail_te_default_get_capture(char(0),
std::declval<PrivateDetailTypeErasedT>()),
std::declval<PrivateDetailTypeErasedT>().finish(),
void());

Expand Down Expand Up @@ -301,6 +371,24 @@ struct context
(*this).private_detail_te_get_handle().finish_on(queue);
}

// Begin capture on the wrapped context (optional hook; a no-op for context
// types that do not implement start_capture).
void start_capture()
{
    assert((*this).private_detail_te_handle_mem_var);
    (*this).private_detail_te_get_handle().start_capture();
}

// End a previously started capture, handing over the output arguments the
// captured run should produce (optional hook; no-op if unimplemented).
void end_capture(const std::vector<argument>& args)
{
    assert((*this).private_detail_te_handle_mem_var);
    (*this).private_detail_te_get_handle().end_capture(args);
}

// Retrieve a callable that replays the captured run, or a null function if
// no capture is available for this context.
std::function<std::vector<argument>()> get_capture()
{
    assert((*this).private_detail_te_handle_mem_var);
    return (*this).private_detail_te_get_handle().get_capture();
}

void finish() const
{
assert((*this).private_detail_te_handle_mem_var);
Expand All @@ -320,12 +408,15 @@ struct context
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;

virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0;
virtual any_ptr get_queue() = 0;
virtual void wait_for(any_ptr queue) = 0;
virtual void finish_on(any_ptr queue) = 0;
virtual void finish() const = 0;
virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0;
virtual any_ptr get_queue() = 0;
virtual void wait_for(any_ptr queue) = 0;
virtual void finish_on(any_ptr queue) = 0;
virtual void start_capture() = 0;
virtual void end_capture(const std::vector<argument>& args) = 0;
virtual std::function<std::vector<argument>()> get_capture() = 0;
virtual void finish() const = 0;
};

template <typename PrivateDetailTypeErasedT>
Expand Down Expand Up @@ -386,6 +477,24 @@ struct context
private_detail_te_default_finish_on(char(0), private_detail_te_value, queue);
}

void start_capture() override
{
    // Calls T::start_capture() if present, otherwise the no-op fallback.
    private_detail_te_default_start_capture(char(0), private_detail_te_value);
}

void end_capture(const std::vector<argument>& args) override
{
    // Calls T::end_capture(args) if present, otherwise the no-op fallback.
    private_detail_te_default_end_capture(char(0), private_detail_te_value, args);
}

std::function<std::vector<argument>()> get_capture() override
{
    // Calls T::get_capture() if present, otherwise returns a null function.
    return private_detail_te_default_get_capture(char(0), private_detail_te_value);
}

void finish() const override { private_detail_te_value.finish(); }

PrivateDetailTypeErasedT private_detail_te_value;
Expand Down
1 change: 1 addition & 0 deletions src/include/migraphx/program.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ inline namespace MIGRAPHX_INLINE_NS {

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_COMPILE)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIP_GRAPH)

struct program_impl;

Expand Down
4 changes: 1 addition & 3 deletions src/layout_convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,7 @@ void remove_layout(module& m)
{
if(ins->name() != "layout")
continue;
auto perm = ins->get_operator().to_value()["permutation"].to_vector<std::int64_t>();
auto iperm = find_permutation(ins->inputs().front()->get_shape());
if(perm != iperm)
if(find_permutation(ins->get_shape()) != find_permutation(ins->inputs().front()->get_shape()))
continue;
m.replace_instruction(ins, ins->inputs().front());
}
Expand Down
15 changes: 15 additions & 0 deletions src/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,15 @@ std::vector<argument> program::eval(const parameter_map& params,
{
auto& contexts = this->impl->contexts;

if(contexts.size() == 1 and enabled(MIGRAPHX_ENABLE_HIP_GRAPH{}))
{
auto& ctx = contexts.front();
auto run = ctx.get_capture();
if(run != nullptr)
return run();
ctx.start_capture();
}
Comment on lines +593 to +600
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When a capture already exists, program::eval returns run() before handling exec_env.async (wait_for/finish_on) and before applying any MIGRAPHX_TRACE_EVAL behavior. This changes observable semantics for callers using async execution environments or tracing. Consider disabling the hip-graph fast path when exec_env.async or tracing is enabled, or make the cached path honor the same synchronization/tracing behavior.

Copilot uses AI. Check for mistakes.

auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{});
std::vector<argument> ret;

Expand Down Expand Up @@ -664,6 +673,12 @@ std::vector<argument> program::eval(const parameter_map& params,
contexts.front().finish_on(exec_env.queue);
}

if(contexts.size() == 1 and enabled(MIGRAPHX_ENABLE_HIP_GRAPH{}))
{
auto& ctx = contexts.front();
ctx.end_capture(ret);
}
Comment on lines +676 to +680
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If ctx.start_capture() is called and generic_eval throws (e.g., missing/shape-mismatched parameters), the stream capture is left open and end_capture is never reached, which can break subsequent executions on that stream. This needs RAII/try-catch cleanup to end/abort capture on exceptions (and to only call end_capture when a capture is actually active).

Copilot uses AI. Check for mistakes.

return ret;
}

Expand Down
73 changes: 73 additions & 0 deletions src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,23 @@ static bool concat_const_foldable(Iterator start, Iterator last, std::size_t iax
});
}

// Matches a pointwise (1x1) convolution: non-grouped, unit stride and
// dilation, zero padding, and a weight tensor whose spatial dimensions are
// all 1. Such a convolution is a pure linear mixing of input channels.
MIGRAPHX_PRED_MATCHER(conv_1x1, instruction_ref ins)
{
    if(ins->name() != "convolution")
        return false;
    auto v = ins->get_operator().to_value();
    // Grouped convolutions do not mix across all input channels
    if(v.at("group").to<int>() != 1)
        return false;
    if(not all_of(v.at("stride"), [](const value& x) { return x.to<std::size_t>() == 1; }))
        return false;
    if(not all_of(v.at("padding"), [](const value& x) { return x.to<std::size_t>() == 0; }))
        return false;
    if(not all_of(v.at("dilation"), [](const value& x) { return x.to<std::size_t>() == 1; }))
        return false;
    // All spatial dims of the kernel must be 1 (lens = [out, in, k1, k2, ...])
    auto w = ins->inputs().at(1)->get_shape();
    return std::all_of(w.lens().begin() + 2, w.lens().end(), [](std::size_t i) { return i == 1; });
}

// conv(x, w) * a => conv(x, a * w)
struct find_mul_conv
{
Expand Down Expand Up @@ -1093,6 +1110,61 @@ struct find_concat_conv
}
};

// conv(conv(x, wnxn), w1x1) => conv(x, fuse(wnxn, w1x1))
//
// A 1x1 convolution (unit stride/dilation, zero padding) applied after
// another convolution is a pure linear mixing of the inner convolution's
// output channels, so the two weight tensors can be contracted ahead of time
// and the pair collapsed into a single convolution.
struct find_conv_conv_1x1
{
    auto matcher() const
    {
        return conv_1x1(
            match::arg(0)(match::used_once(), match::name("convolution").bind("input")));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins   = r.result;                // the outer 1x1 convolution
        auto input = r.instructions["input"]; // the inner convolution
        auto x_ins = input->inputs().front(); // original input tensor
        auto wnxn  = input->inputs()[1];      // inner weights [mid, in/g, k...]
        auto w1x1  = ins->inputs()[1];        // 1x1 weights [out, mid, 1...]

        auto out_channels          = w1x1->get_shape().lens()[0];
        auto mid_channels          = w1x1->get_shape().lens()[1];
        auto in_channels_per_group = wnxn->get_shape().lens()[1];
        auto groups                = x_ins->get_shape().lens()[1] / in_channels_per_group;
        // The 1x1 conv must consume exactly the inner conv's output channels,
        // and those channels must split evenly across the inner conv's groups;
        // otherwise the reshapes below would be invalid.
        if(groups == 0 or wnxn->get_shape().lens()[0] != mid_channels or
           mid_channels % groups != 0)
            return;
        // Product of the spatial dims of the inner conv kernel
        auto w_size = std::accumulate(wnxn->get_shape().lens().begin() + 2,
                                      wnxn->get_shape().lens().end(),
                                      std::size_t{1},
                                      std::multiplies<>{});

        // Shape of the fused weights: the leading dim is the 1x1 conv's output
        // channel count (the dot below produces out_channels rows), and the
        // input-channel dim covers all groups of x since the fused conv is
        // non-grouped.
        auto mw_dims = wnxn->get_shape().lens();
        mw_dims[0]   = out_channels;
        mw_dims[1] *= groups;

        // [out, mid] -> [groups, out, mid/g] so each group's slice of the 1x1
        // weights can be contracted against that group's inner weights.
        auto w1x1_reshaped = m.insert_instruction(
            ins,
            make_op("reshape", {{"dims", {out_channels, groups, mid_channels / groups}}}),
            w1x1);
        auto w1x1_grouped = m.insert_instruction(
            ins, make_op("transpose", {{"permutation", {1, 0, 2}}}), w1x1_reshaped);

        // [mid, in/g, k...] -> [groups, mid/g, (in/g)*k]
        auto wnxn_reshaped = m.insert_instruction(
            ins,
            make_op("reshape",
                    {{"dims", {groups, mid_channels / groups, in_channels_per_group * w_size}}}),
            wnxn);

        // [groups, out, mid/g] x [groups, mid/g, (in/g)*k] -> [groups, out, (in/g)*k]
        auto mw = m.insert_instruction(ins, make_op("dot"), w1x1_grouped, wnxn_reshaped);
        // -> [out, groups, (in/g)*k] -> [out, groups*(in/g), k...]
        auto mw_transposed =
            m.insert_instruction(ins, make_op("transpose", {{"permutation", {1, 0, 2}}}), mw);
        auto mw_reshaped =
            m.insert_instruction(ins, make_op("reshape", {{"dims", mw_dims}}), mw_transposed);

        // Replace the pair with one non-grouped convolution that keeps the
        // inner conv's stride/padding/dilation.
        auto op = input->get_operator();
        op.from_value({{"group", 1}});
        auto conv = m.insert_instruction(ins, op, x_ins, mw_reshaped);
        m.replace_instruction(ins, conv);
    }
};

static void
move_instructions_back(module& m, instruction_ref pos, std::vector<instruction_ref> inss)
{
Expand Down Expand Up @@ -2395,6 +2467,7 @@ void simplify_algebra::apply(module& m) const
find_zero_ops{},
find_dot_add{},
find_conv_add{},
find_conv_conv_1x1{},
find_div_const{},
find_sub_const{},
find_rsqrt{},
Expand Down
3 changes: 2 additions & 1 deletion src/targets/gpu/compile_hip_code_object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <migraphx/context.hpp>
#include <migraphx_kernels.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/program.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down Expand Up @@ -151,7 +152,7 @@ compute_global_for(const context& ctx, std::size_t n, std::size_t over)
ctx.get_current_device().get_max_workitems_per_cu();
return [n, over, max_global](std::size_t local) {
std::size_t num_elements = n;
if(not hip_accept_non_uniform_wg())
if(enabled(MIGRAPHX_ENABLE_HIP_GRAPH{}) or not hip_accept_non_uniform_wg())
{
num_elements = (1 + (n - 1) / local) * local;
}
Expand Down
17 changes: 9 additions & 8 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,14 +372,15 @@ auto is_mlir_conv(mlir_mode mode)
if(mode == mlir_mode::all)
return true;
// No winograd for group convolution
if(group > 1)
return true;
auto w = ins->inputs().at(1)->get_shape();
if(w.lens().size() != 4)
return true;
if(w.lens()[2] != w.lens()[3])
return true;
return (w.lens()[3] % 3) != 0;
return group == 1;
// if(group > 1)
// return true;
// auto w = ins->inputs().at(1)->get_shape();
// if(w.lens().size() != 4)
// return true;
// if(w.lens()[2] != w.lens()[3])
// return true;
// return (w.lens()[3] % 3) != 0;
Comment on lines 374 to +383
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_mlir_conv now unconditionally returns group == 1 in fast mode and leaves the previous heuristic commented out below a return, making the intent hard to interpret (the nearby “No winograd for group convolution” comment is also ambiguous with the new logic). Please either remove the dead commented code or replace it with a clear, active heuristic (and update the comment) so future changes don’t accidentally reintroduce the old behavior.

Copilot uses AI. Check for mistakes.
});
}

Expand Down
Loading
Loading