Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions src/rewrite_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,127 @@
#include <migraphx/common.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/eliminate_convert.hpp>
#include <migraphx/instruction_traversal.hpp>
#include <migraphx/unfold.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <unordered_set>

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_FP32_SOFTMAX);

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

namespace {

// Walk forward through single-consumer ops looking for an instruction with
// the given name. Returns start itself if it already matches.
static std::optional<instruction_ref> find_downstream_named(instruction_ref start,

Check warning on line 49 in src/rewrite_reduce.cpp

View workflow job for this annotation

GitHub Actions / tidy

'find_downstream_named' is a static definition in anonymous namespace; static is redundant here [readability-static-definition-in-anonymous-namespace,-warnings-as-errors]
const std::string& target)
{
auto path = get_output_path(start);
auto it = std::find_if(
path.begin(), path.end(), [&](instruction_ref ins) { return ins->name() == target; });
if(it == path.end())
return std::nullopt;
return *it;
}

// Walk backward through the data-flow chain looking for an instruction with
// the given name. Returns start itself if it already matches. Single-input ops
// are followed directly; multi-input ops follow the first non-constant,
// non-bool input.
static std::optional<instruction_ref> find_upstream_named(instruction_ref start,

Check warning on line 64 in src/rewrite_reduce.cpp

View workflow job for this annotation

GitHub Actions / tidy

'find_upstream_named' is a static definition in anonymous namespace; static is redundant here [readability-static-definition-in-anonymous-namespace,-warnings-as-errors]
const std::string& target)
{
auto path = unfold(start, [](instruction_ref current) -> std::optional<instruction_ref> {
const auto& inputs = current->inputs();
if(inputs.empty())
return std::nullopt;
if(inputs.size() == 1)
return inputs.front();
auto it = std::find_if(inputs.begin(), inputs.end(), [](instruction_ref i) {
return not i->can_eval() and i->get_shape().type() != shape::bool_type;
});
if(it == inputs.end())
return std::nullopt;
return *it;
});
auto it = std::find_if(
path.begin(), path.end(), [&](instruction_ref ins) { return ins->name() == target; });
if(it == path.end())
return std::nullopt;
return *it;
}

// Scan the module for attention dots by matching the decomposed softmax
// pattern (match::softmax matches the final div). A softmax whose input
// reaches a dot upstream and whose output reaches another dot downstream
// identifies the Q*K^T and softmax*V dots of attention; both are marked so
// find_dot leaves them alone.
// Note: `static` is omitted because this definition lives in an anonymous
// namespace, where it is redundant
// (clang-tidy: readability-static-definition-in-anonymous-namespace).
std::unordered_set<instruction_ref> collect_attention_dots(module& m)
{
    std::unordered_set<instruction_ref> result;
    for(auto ins : iterator_for(m))
    {
        auto r = match::match_instruction(m, ins, match::softmax());
        if(r.result == m.end())
            continue;
        // "x" binds the softmax input; search up from it for Q*K^T and down
        // from the softmax output for the softmax*V dot.
        auto x     = r.instructions["x"];
        auto q_dot = find_upstream_named(x, "dot");
        auto v_dot = find_downstream_named(ins, "dot");
        // Only mark when BOTH dots are present; a lone dot next to a softmax
        // is not the attention pattern.
        if(q_dot.has_value() and v_dot.has_value())
        {
            result.insert(*q_dot);
            result.insert(*v_dot);
        }
    }
    return result;
}

// Rewrites skinny gemms (<= 2 output rows) as an elementwise multiply plus a
// reduce_sum, which maps better to reduction kernels than a full gemm.
// Attention dots (Q*K^T and softmax*V) are excluded so attention fusion still
// sees plain gemms.
struct find_dot
{
    // Dots identified by collect_attention_dots; apply() skips these.
    std::unordered_set<instruction_ref> attention_dots;

    auto matcher() const { return match::name("dot"); }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins = r.result;
        // Leave attention dots untouched.
        if(attention_dots.count(ins) != 0)
            return;
        auto a_mat   = ins->inputs().front();
        auto b_mat   = ins->inputs().back();
        auto a_shape = a_mat->get_shape();
        auto b_shape = b_mat->get_shape();
        auto ndim    = a_shape.ndim();
        auto rows    = a_shape.lens().at(ndim - 2);
        // Only rewrite when the output has at most 2 rows; larger gemms are
        // better served by the gemm kernels.
        if(rows > 2)
            return;

        // Permutation that swaps the last two axes (identity elsewhere).
        std::vector<int64_t> permutation(ndim);
        std::iota(permutation.begin(), permutation.end(), 0);
        std::swap(permutation.back(), permutation.at(ndim - 2));

        // If the b matrix is const-foldable then make sure it is in a
        // transposed layout, unless it is broadcasting: forcing a layout on a
        // broadcasted constant would materialize the full tensor and inflate
        // compile time/memory for no benefit.
        if(b_mat->can_eval() and not b_shape.transposed() and not b_shape.broadcasted())
        {
            b_mat =
                m.insert_instruction(ins, make_op("layout", {{"permutation", permutation}}), b_mat);
        }
        // dot(a, b) == squeeze(reduce_sum(unsqueeze(a) * unsqueeze(transpose(b))))
        // along the shared contraction axis appended at position ndim.
        auto a_unsqueeze =
            m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {ndim - 1}}}), a_mat);
        auto b_transpose =
            m.insert_instruction(ins, make_op("transpose", {{"permutation", permutation}}), b_mat);
        auto b_unsqueeze =
            m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {ndim - 2}}}), b_transpose);
        auto mul    = insert_common_op(m, ins, make_op("mul"), {a_unsqueeze, b_unsqueeze});
        auto reduce = m.insert_instruction(ins, make_op("reduce_sum", {{"axes", {ndim}}}), mul);
        m.replace_instruction(ins, make_op("squeeze", {{"axes", {ndim}}}), reduce);
    }
};

struct find_logsoftmax
{
auto matcher() const { return match::name("logsoftmax"); }
Expand Down Expand Up @@ -300,6 +412,9 @@
{
match::find_matches(m, find_logsoftmax{});
match::find_matches(m, find_softmax{}, find_reduce_mean_variance{});
// Match the decomposed softmax pattern to identify dots participating in
// attention (Q*K^T and softmax*V) so find_dot can skip them.
match::find_matches(m, find_dot{collect_attention_dots(m)});

if(not enabled(MIGRAPHX_DISABLE_FP32_SOFTMAX{}))
{
Expand Down
Loading
Loading