diff --git a/src/rewrite_reduce.cpp b/src/rewrite_reduce.cpp
index 24b6e03ee78..8aa2b9d7d95 100644
--- a/src/rewrite_reduce.cpp
+++ b/src/rewrite_reduce.cpp
@@ -32,8 +32,10 @@
 #include
 #include
 #include
+#include <optional>
 #include
 #include
+#include <unordered_set>
 
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_FP32_SOFTMAX);
 
@@ -41,6 +43,116 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace {
+
+// Walk forward through single-consumer ops looking for an instruction with
+// the given name. Returns start itself if it already matches.
+static std::optional<instruction_ref> find_downstream_named(instruction_ref start,
+                                                            const std::string& target)
+{
+    auto path = get_output_path(start);
+    auto it   = std::find_if(
+        path.begin(), path.end(), [&](instruction_ref ins) { return ins->name() == target; });
+    if(it == path.end())
+        return std::nullopt;
+    return *it;
+}
+
+// Walk backward through the data-flow chain looking for an instruction with
+// the given name. Returns start itself if it already matches. Single-input ops
+// are followed directly; multi-input ops follow the first non-constant,
+// non-bool input.
+static std::optional<instruction_ref> find_upstream_named(instruction_ref start,
+                                                          const std::string& target)
+{
+    auto path = unfold(start, [](instruction_ref current) -> std::optional<instruction_ref> {
+        const auto& inputs = current->inputs();
+        if(inputs.empty())
+            return std::nullopt;
+        if(inputs.size() == 1)
+            return inputs.front();
+        auto it = std::find_if(inputs.begin(), inputs.end(), [](instruction_ref i) {
+            return not i->can_eval() and i->get_shape().type() != shape::bool_type;
+        });
+        if(it == inputs.end())
+            return std::nullopt;
+        return *it;
+    });
+    auto it = std::find_if(
+        path.begin(), path.end(), [&](instruction_ref ins) { return ins->name() == target; });
+    if(it == path.end())
+        return std::nullopt;
+    return *it;
+}
+
+// Scan the module for attention dots by matching the decomposed softmax
+// pattern (match::softmax matches the final div). A softmax whose input
+// reaches a dot upstream and whose output reaches another dot downstream
+// identifies the Q*K^T and softmax*V dots of attention; both are marked so
+// find_dot leaves them alone.
+static std::unordered_set<instruction_ref> collect_attention_dots(module& m)
+{
+    std::unordered_set<instruction_ref> result;
+    for(auto ins : iterator_for(m))
+    {
+        auto r = match::match_instruction(m, ins, match::softmax());
+        if(r.result == m.end())
+            continue;
+        auto x     = r.instructions["x"];
+        auto q_dot = find_upstream_named(x, "dot");
+        auto v_dot = find_downstream_named(ins, "dot");
+        if(q_dot.has_value() and v_dot.has_value())
+        {
+            result.insert(*q_dot);
+            result.insert(*v_dot);
+        }
+    }
+    return result;
+}
+
+struct find_dot
+{
+    std::unordered_set<instruction_ref> attention_dots;
+
+    auto matcher() const { return match::name("dot"); }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins = r.result;
+        if(attention_dots.count(ins) != 0)
+            return;
+        auto a_mat   = ins->inputs().front();
+        auto b_mat   = ins->inputs().back();
+        auto a_shape = a_mat->get_shape();
+        auto b_shape = b_mat->get_shape();
+        auto ndim    = a_shape.ndim();
+        auto rows    = a_shape.lens().at(ndim - 2);
+        if(rows > 2)
+            return;
+
+        std::vector<std::int64_t> permutation(ndim);
+        std::iota(permutation.begin(), permutation.end(), 0);
+        std::swap(permutation.back(), permutation.at(ndim - 2));
+
+        // If the b matrix is const foldable then make sure it's a transposed layout unless it's
+        // broadcasting
+        if(b_mat->can_eval() and not b_shape.transposed())
+        {
+            b_mat =
+                m.insert_instruction(ins, make_op("layout", {{"permutation", permutation}}), b_mat);
+        }
+
+        auto a_unsqueeze =
+            m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {ndim - 1}}}), a_mat);
+        auto b_transpose =
+            m.insert_instruction(ins, make_op("transpose", {{"permutation", permutation}}), b_mat);
+        auto b_unsqueeze =
+            m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {ndim - 2}}}), b_transpose);
+        auto mul    = insert_common_op(m, ins, make_op("mul"), {a_unsqueeze, b_unsqueeze});
+        auto reduce = m.insert_instruction(ins, make_op("reduce_sum", {{"axes", {ndim}}}), mul);
+        m.replace_instruction(ins, make_op("squeeze", {{"axes", {ndim}}}), reduce);
+    }
+};
+
 struct find_logsoftmax
 {
     auto matcher() const { return match::name("logsoftmax"); }
@@ -300,6 +412,9 @@ void rewrite_reduce::apply(module& m) const
 {
     match::find_matches(m, find_logsoftmax{});
     match::find_matches(m, find_softmax{}, find_reduce_mean_variance{});
+    // Match the decomposed softmax pattern to identify dots participating in
+    // attention (Q*K^T and softmax*V) so find_dot can skip them.
+    match::find_matches(m, find_dot{collect_attention_dots(m)});
 
     if(not enabled(MIGRAPHX_DISABLE_FP32_SOFTMAX{}))
     {
diff --git a/test/rewrite_reduce.cpp b/test/rewrite_reduce.cpp
index 9790d346464..5ddfad50c78 100644
--- a/test/rewrite_reduce.cpp
+++ b/test/rewrite_reduce.cpp
@@ -78,10 +78,261 @@ TEST_CASE(softmax_upcast)
     }));
 }
 
+// Skinny dot [M=1, K] @ [K, N] gets rewritten to mul + reduce_sum.
+TEST_CASE(dot_skinny_rewrite) +{ + migraphx::shape a_shape{migraphx::shape::float_type, {1, 128}}; + migraphx::shape b_shape{migraphx::shape::float_type, {128, 4}}; + migraphx::module m1; + { + auto a = m1.add_parameter("a", a_shape); + auto b = m1.add_parameter("b", b_shape); + auto dot = m1.add_instruction(migraphx::make_op("dot"), a, b); + m1.add_return({dot}); + } + run_pass(m1); + + migraphx::module m2; + { + auto a = m2.add_parameter("a", a_shape); + auto b = m2.add_parameter("b", b_shape); + auto a_unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), a); + auto b_unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), b); + auto b_trans = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), b_unsq); + auto a_bc = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 128}}}), a_unsq); + auto mul = m2.add_instruction(migraphx::make_op("mul"), a_bc, b_trans); + auto red = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), mul); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), red); + m2.add_return({sq}); + } + EXPECT(m1.sort() == m2.sort()); +} + +// Skinny dot with M=2 also gets rewritten. 
+TEST_CASE(dot_skinny_m2_rewrite) +{ + migraphx::shape a_shape{migraphx::shape::float_type, {2, 128}}; + migraphx::shape b_shape{migraphx::shape::float_type, {128, 4}}; + migraphx::module m1; + { + auto a = m1.add_parameter("a", a_shape); + auto b = m1.add_parameter("b", b_shape); + auto dot = m1.add_instruction(migraphx::make_op("dot"), a, b); + m1.add_return({dot}); + } + run_pass(m1); + + migraphx::module m2; + { + auto a = m2.add_parameter("a", a_shape); + auto b = m2.add_parameter("b", b_shape); + auto a_unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), a); + auto a_bc = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 128}}}), a_unsq); + auto b_trans = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b); + auto b_bc = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 4, 128}}}), b_trans); + auto mul = m2.add_instruction(migraphx::make_op("mul"), a_bc, b_bc); + auto red = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), mul); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), red); + m2.add_return({sq}); + } + EXPECT(m1.sort() == m2.sort()); +} + +// rows > 2 exceeds the skinny threshold so the dot is left alone. 
+TEST_CASE(dot_wide_no_rewrite) +{ + migraphx::shape a_shape{migraphx::shape::float_type, {3, 128}}; + migraphx::shape b_shape{migraphx::shape::float_type, {128, 4}}; + migraphx::module m1; + { + auto a = m1.add_parameter("a", a_shape); + auto b = m1.add_parameter("b", b_shape); + auto dot = m1.add_instruction(migraphx::make_op("dot"), a, b); + m1.add_return({dot}); + } + run_pass(m1); + + migraphx::module m2; + { + auto a = m2.add_parameter("a", a_shape); + auto b = m2.add_parameter("b", b_shape); + auto dot = m2.add_instruction(migraphx::make_op("dot"), a, b); + m2.add_return({dot}); + } + EXPECT(m1.sort() == m2.sort()); +} + +// Batched skinny dot [B, M=1, K] @ [B, K, N] gets rewritten. +TEST_CASE(dot_batched_skinny_rewrite) +{ + migraphx::shape a_shape{migraphx::shape::float_type, {4, 1, 128}}; + migraphx::shape b_shape{migraphx::shape::float_type, {4, 128, 8}}; + migraphx::module m1; + { + auto a = m1.add_parameter("a", a_shape); + auto b = m1.add_parameter("b", b_shape); + auto dot = m1.add_instruction(migraphx::make_op("dot"), a, b); + m1.add_return({dot}); + } + run_pass(m1); + + migraphx::module m2; + { + auto a = m2.add_parameter("a", a_shape); + auto b = m2.add_parameter("b", b_shape); + auto a_unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), a); + auto b_unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), b); + auto b_trans = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), b_unsq); + auto a_bc = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {4, 1, 8, 128}}}), a_unsq); + auto mul = m2.add_instruction(migraphx::make_op("mul"), a_bc, b_trans); + auto red = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), mul); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {3}}}), red); + m2.add_return({sq}); + } + EXPECT(m1.sort() == m2.sort()); +} + +// Batched dot with M=2 gets rewritten too. 
+TEST_CASE(dot_batched_m2_rewrite) +{ + migraphx::shape a_shape{migraphx::shape::float_type, {1, 12, 2, 128}}; + migraphx::shape b_shape{migraphx::shape::float_type, {1, 12, 128, 64}}; + migraphx::module m1; + { + auto a = m1.add_parameter("a", a_shape); + auto b = m1.add_parameter("b", b_shape); + auto dot = m1.add_instruction(migraphx::make_op("dot"), a, b); + m1.add_return({dot}); + } + run_pass(m1); + + migraphx::module m2; + { + auto a = m2.add_parameter("a", a_shape); + auto b = m2.add_parameter("b", b_shape); + auto a_unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {3}}}), a); + auto a_bc = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 12, 2, 64, 128}}}), a_unsq); + auto b_unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), b); + auto b_trans = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 2, 4, 3}}}), b_unsq); + auto b_bc = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 12, 2, 64, 128}}}), b_trans); + auto mul = m2.add_instruction(migraphx::make_op("mul"), a_bc, b_bc); + auto red = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {4}}}), mul); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {4}}}), red); + m2.add_return({sq}); + } + EXPECT(m1.sort() == m2.sort()); +} + +// Batched dot feeding a softmax that returns (no downstream dot) is not +// attention; find_dot rewrites the dot and find_softmax decomposes the softmax. +// Using float_type avoids the fp16->fp32 upcast wrapping. 
+TEST_CASE(dot_softmax_return_rewrite) +{ + migraphx::shape a_shape{migraphx::shape::float_type, {1, 12, 1, 128}}; + migraphx::shape b_shape{migraphx::shape::float_type, {1, 12, 128, 128}}; + migraphx::module m1; + { + auto a = m1.add_parameter("a", a_shape); + auto b = m1.add_parameter("b", b_shape); + auto dot = m1.add_instruction(migraphx::make_op("dot"), a, b); + auto softmax = m1.add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), dot); + m1.add_return({softmax}); + } + run_pass(m1); + + migraphx::module m2; + { + auto a = m2.add_parameter("a", a_shape); + auto b = m2.add_parameter("b", b_shape); + auto a_unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {3}}}), a); + auto b_unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {3}}}), b); + auto b_trans = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 4, 2}}}), b_unsq); + auto a_bc = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 12, 1, 128, 128}}}), a_unsq); + auto mul = m2.add_instruction(migraphx::make_op("mul"), a_bc, b_trans); + auto red = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {4}}}), mul); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {4}}}), red); + auto rmax = m2.add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), sq); + auto rmax_bc = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 12, 1, 128}}}), rmax); + auto sub = m2.add_instruction(migraphx::make_op("sub"), sq, rmax_bc); + auto exp = m2.add_instruction(migraphx::make_op("exp"), sub); + auto rsum = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp); + auto rsum_bc = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 12, 1, 128}}}), rsum); + auto div = m2.add_instruction(migraphx::make_op("div"), exp, rsum_bc); + m2.add_return({div}); + } + EXPECT(m1.sort() == m2.sort()); +} + +// dot -> mul -> softmax -> return: 
still not attention (no dot after softmax). +TEST_CASE(dot_mul_softmax_return_rewrite) +{ + migraphx::shape a_shape{migraphx::shape::float_type, {1, 12, 1, 128}}; + migraphx::shape b_shape{migraphx::shape::float_type, {1, 12, 128, 128}}; + migraphx::shape scale_shape{migraphx::shape::float_type, {1}}; + migraphx::module m1; + { + auto a = m1.add_parameter("a", a_shape); + auto b = m1.add_parameter("b", b_shape); + auto scale = m1.add_parameter("scale", scale_shape); + auto scale_bc = m1.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 12, 1, 128}}}), scale); + auto dot = m1.add_instruction(migraphx::make_op("dot"), a, b); + auto mul = m1.add_instruction(migraphx::make_op("mul"), dot, scale_bc); + auto softmax = m1.add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), mul); + m1.add_return({softmax}); + } + run_pass(m1); + + migraphx::module m2; + { + auto a = m2.add_parameter("a", a_shape); + auto b = m2.add_parameter("b", b_shape); + auto scale = m2.add_parameter("scale", scale_shape); + auto scale_bc = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 12, 1, 128}}}), scale); + auto a_unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {3}}}), a); + auto b_unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {3}}}), b); + auto b_trans = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 4, 2}}}), b_unsq); + auto a_bc = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 12, 1, 128, 128}}}), a_unsq); + auto mul_ab = m2.add_instruction(migraphx::make_op("mul"), a_bc, b_trans); + auto red = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {4}}}), mul_ab); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {4}}}), red); + auto mul_scale = m2.add_instruction(migraphx::make_op("mul"), sq, scale_bc); + auto rmax = m2.add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), mul_scale); + auto 
rmax_bc = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 12, 1, 128}}}), rmax); + auto sub = m2.add_instruction(migraphx::make_op("sub"), mul_scale, rmax_bc); + auto exp = m2.add_instruction(migraphx::make_op("exp"), sub); + auto rsum = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp); + auto rsum_bc = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 12, 1, 128}}}), rsum); + auto div = m2.add_instruction(migraphx::make_op("div"), exp, rsum_bc); + m2.add_return({div}); + } + EXPECT(m1.sort() == m2.sort()); +} + TEST_CASE(softmax_dot_scale_where_fp32_convert_after) { migraphx::shape dot_shape{migraphx::shape::half_type, {1, 12, 1, 128}}; migraphx::shape k_shape{migraphx::shape::half_type, {1, 12, 128, 128}}; + migraphx::shape v_shape{migraphx::shape::half_type, {1, 12, 128, 128}}; migraphx::shape scale_shape{migraphx::shape::half_type, {1}}; migraphx::shape mask_shape{migraphx::shape::bool_type, {1, 12, 1, 128}}; migraphx::shape f32_dot_shape{migraphx::shape::float_type, dot_shape.lens()}; @@ -102,23 +353,26 @@ TEST_CASE(softmax_dot_scale_where_fp32_convert_after) return std::make_tuple(dot, scale_bc, mask, ninf_bc); }; - // Input module: dot -> mul -> where -> softmax + // Input module: dot -> mul -> where -> softmax -> dot(V) migraphx::module m1; { auto [dot, scale_bc, mask, ninf_bc] = make_dot(m1, dot_shape, k_shape, scale_shape, mask_shape); + auto v = m1.add_parameter("v", v_shape); auto mul = m1.add_instruction(migraphx::make_op("mul"), dot, scale_bc); auto where = m1.add_instruction(migraphx::make_op("where"), mask, ninf_bc, mul); auto softmax = m1.add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), where); - m1.add_return({softmax}); + auto dot_v = m1.add_instruction(migraphx::make_op("dot"), softmax, v); + m1.add_return({dot_v}); } // Expected module: dot(f16) -> convert(f32) -> mul(f32) -> where(f32) -> - // softmax_decomposed(f32) -> convert(f16) + // 
softmax_decomposed(f32) -> convert(f16) -> dot(V) migraphx::module m2; { auto [dot, scale_bc, mask, ninf_bc] = make_dot(m2, dot_shape, k_shape, scale_shape, mask_shape); + auto v = m2.add_parameter("v", v_shape); auto cvt_dot = m2.add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), dot); auto cvt_scale = m2.add_instruction( @@ -138,7 +392,8 @@ TEST_CASE(softmax_dot_scale_where_fp32_convert_after) auto div = m2.add_instruction(migraphx::make_op("div"), exp, rsum_bc); auto cvt_out = m2.add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), div); - m2.add_return({cvt_out}); + auto dot_v = m2.add_instruction(migraphx::make_op("dot"), cvt_out, v); + m2.add_return({dot_v}); } run_pass(m1); @@ -149,26 +404,30 @@ TEST_CASE(softmax_dot_scale_fp32_convert_after) { migraphx::shape dot_shape{migraphx::shape::half_type, {1, 12, 1, 128}}; migraphx::shape k_shape{migraphx::shape::half_type, {1, 12, 128, 128}}; + migraphx::shape v_shape{migraphx::shape::half_type, {1, 12, 128, 128}}; migraphx::shape scale_shape{migraphx::shape::half_type, {1}}; migraphx::shape f32_dot_shape{migraphx::shape::float_type, dot_shape.lens()}; - // Input module: dot -> mul -> softmax + // Input module: dot -> mul -> softmax -> dot(V) migraphx::module m1; auto q1 = m1.add_parameter("q", dot_shape); auto k1 = m1.add_parameter("k", k_shape); + auto v1 = m1.add_parameter("v", v_shape); auto scale1 = m1.add_parameter("scale", scale_shape); auto scale_bc1 = m1.add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", dot_shape.lens()}}), scale1); auto dot1 = m1.add_instruction(migraphx::make_op("dot"), q1, k1); auto mul1 = m1.add_instruction(migraphx::make_op("mul"), dot1, scale_bc1); auto softmax1 = m1.add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), mul1); - m1.add_return({softmax1}); + auto dot_v1 = m1.add_instruction(migraphx::make_op("dot"), softmax1, v1); + m1.add_return({dot_v1}); // Expected module: 
dot(f16) -> convert(f32) -> mul(f32) -> softmax_decomposed(f32) -> - // convert(f16) + // convert(f16) -> dot(V) migraphx::module m2; auto q2 = m2.add_parameter("q", dot_shape); auto k2 = m2.add_parameter("k", k_shape); + auto v2 = m2.add_parameter("v", v_shape); auto scale2 = m2.add_parameter("scale", scale_shape); auto scale_bc2 = m2.add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", dot_shape.lens()}}), scale2); @@ -189,7 +448,8 @@ TEST_CASE(softmax_dot_scale_fp32_convert_after) auto div2 = m2.add_instruction(migraphx::make_op("div"), exp2, rsum_bc2); auto cvt_out2 = m2.add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), div2); - m2.add_return({cvt_out2}); + auto dot_v2 = m2.add_instruction(migraphx::make_op("dot"), cvt_out2, v2); + m2.add_return({dot_v2}); run_pass(m1); EXPECT(m1 == m2); @@ -200,18 +460,22 @@ TEST_CASE(softmax_dot_only) { migraphx::shape dot_shape{migraphx::shape::half_type, {1, 12, 1, 128}}; migraphx::shape k_shape{migraphx::shape::half_type, {1, 12, 128, 128}}; + migraphx::shape v_shape{migraphx::shape::half_type, {1, 12, 128, 128}}; migraphx::shape f32_dot_shape{migraphx::shape::float_type, dot_shape.lens()}; migraphx::module m1; auto q1 = m1.add_parameter("q", dot_shape); auto k1 = m1.add_parameter("k", k_shape); + auto v1 = m1.add_parameter("v", v_shape); auto dot1 = m1.add_instruction(migraphx::make_op("dot"), q1, k1); auto softmax = m1.add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), dot1); - m1.add_return({softmax}); + auto dot_v1 = m1.add_instruction(migraphx::make_op("dot"), softmax, v1); + m1.add_return({dot_v1}); migraphx::module m2; auto q2 = m2.add_parameter("q", dot_shape); auto k2 = m2.add_parameter("k", k_shape); + auto v2 = m2.add_parameter("v", v_shape); auto dot2 = m2.add_instruction(migraphx::make_op("dot"), q2, k2); auto cvt_dot = m2.add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), dot2); @@ -226,7 +490,8 @@ 
TEST_CASE(softmax_dot_only) auto div = m2.add_instruction(migraphx::make_op("div"), exp, rsum_bc); auto cvt_out = m2.add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), div); - m2.add_return({cvt_out}); + auto dot_v2 = m2.add_instruction(migraphx::make_op("dot"), cvt_out, v2); + m2.add_return({dot_v2}); run_pass(m1); EXPECT(m1 == m2); @@ -313,6 +578,7 @@ TEST_CASE(softmax_dot_scale_double_where) { migraphx::shape dot_shape{migraphx::shape::half_type, {1, 12, 1, 128}}; migraphx::shape k_shape{migraphx::shape::half_type, {1, 12, 128, 128}}; + migraphx::shape v_shape{migraphx::shape::half_type, {1, 12, 128, 128}}; migraphx::shape scale_shape{migraphx::shape::half_type, {1}}; migraphx::shape mask_shape{migraphx::shape::bool_type, {1, 12, 1, 128}}; migraphx::shape f32_dot_shape{migraphx::shape::float_type, dot_shape.lens()}; @@ -339,17 +605,20 @@ TEST_CASE(softmax_dot_scale_double_where) { auto [dot, scale_bc, mask1, mask2, ninf_bc] = make_inputs(m1, dot_shape, k_shape, scale_shape, mask_shape); + auto v = m1.add_parameter("v", v_shape); auto mul = m1.add_instruction(migraphx::make_op("mul"), dot, scale_bc); auto where1 = m1.add_instruction(migraphx::make_op("where"), mask1, ninf_bc, mul); auto where2 = m1.add_instruction(migraphx::make_op("where"), mask2, ninf_bc, where1); auto softmax = m1.add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), where2); - m1.add_return({softmax}); + auto dot_v = m1.add_instruction(migraphx::make_op("dot"), softmax, v); + m1.add_return({dot_v}); } migraphx::module m2; { auto [dot, scale_bc, mask1, mask2, ninf_bc] = make_inputs(m2, dot_shape, k_shape, scale_shape, mask_shape); + auto v = m2.add_parameter("v", v_shape); auto cvt_dot = m2.add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), dot); auto cvt_scale = m2.add_instruction( @@ -370,7 +639,8 @@ TEST_CASE(softmax_dot_scale_double_where) auto div = m2.add_instruction(migraphx::make_op("div"), 
exp, rsum_bc); auto cvt_out = m2.add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), div); - m2.add_return({cvt_out}); + auto dot_v = m2.add_instruction(migraphx::make_op("dot"), cvt_out, v); + m2.add_return({dot_v}); } run_pass(m1); @@ -383,20 +653,24 @@ TEST_CASE(softmax_dot_relu_upcast) { migraphx::shape dot_shape{migraphx::shape::half_type, {1, 12, 1, 128}}; migraphx::shape k_shape{migraphx::shape::half_type, {1, 12, 128, 128}}; + migraphx::shape v_shape{migraphx::shape::half_type, {1, 12, 128, 128}}; migraphx::shape f32_dot_shape{migraphx::shape::float_type, dot_shape.lens()}; migraphx::module m1; auto q1 = m1.add_parameter("q", dot_shape); auto k1 = m1.add_parameter("k", k_shape); + auto v1 = m1.add_parameter("v", v_shape); auto dot1 = m1.add_instruction(migraphx::make_op("dot"), q1, k1); auto relu1 = m1.add_instruction(migraphx::make_op("relu"), dot1); auto softmax = m1.add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), relu1); - m1.add_return({softmax}); + auto dot_v1 = m1.add_instruction(migraphx::make_op("dot"), softmax, v1); + m1.add_return({dot_v1}); // Expected: dot stays f16, convert(f16->f32) after dot, relu upcasted to f32 migraphx::module m2; auto q2 = m2.add_parameter("q", dot_shape); auto k2 = m2.add_parameter("k", k_shape); + auto v2 = m2.add_parameter("v", v_shape); auto dot2 = m2.add_instruction(migraphx::make_op("dot"), q2, k2); auto cvt_dot = m2.add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), dot2); @@ -412,7 +686,8 @@ TEST_CASE(softmax_dot_relu_upcast) auto div = m2.add_instruction(migraphx::make_op("div"), exp, rsum_bc); auto cvt_out = m2.add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), div); - m2.add_return({cvt_out}); + auto dot_v2 = m2.add_instruction(migraphx::make_op("dot"), cvt_out, v2); + m2.add_return({dot_v2}); run_pass(m1); EXPECT(m1 == m2); @@ -425,6 +700,7 @@ 
TEST_CASE(softmax_dot_scale_left) { migraphx::shape dot_shape{migraphx::shape::half_type, {1, 12, 1, 128}}; migraphx::shape k_shape{migraphx::shape::half_type, {1, 12, 128, 128}}; + migraphx::shape v_shape{migraphx::shape::half_type, {1, 12, 128, 128}}; migraphx::shape scale_shape{migraphx::shape::half_type, {1}}; migraphx::shape f32_dot_shape{migraphx::shape::float_type, dot_shape.lens()}; @@ -441,14 +717,17 @@ TEST_CASE(softmax_dot_scale_left) migraphx::module m1; { auto [dot, scale_bc] = make_graph(m1, dot_shape, k_shape, scale_shape); + auto v = m1.add_parameter("v", v_shape); auto mul = m1.add_instruction(migraphx::make_op("mul"), scale_bc, dot); auto softmax = m1.add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), mul); - m1.add_return({softmax}); + auto dot_v = m1.add_instruction(migraphx::make_op("dot"), softmax, v); + m1.add_return({dot_v}); } migraphx::module m2; { auto [dot, scale_bc] = make_graph(m2, dot_shape, k_shape, scale_shape); + auto v = m2.add_parameter("v", v_shape); auto cvt_scale = m2.add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), scale_bc); auto cvt_dot = m2.add_instruction( @@ -465,7 +744,8 @@ TEST_CASE(softmax_dot_scale_left) auto div = m2.add_instruction(migraphx::make_op("div"), exp, rsum_bc); auto cvt_out = m2.add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), div); - m2.add_return({cvt_out}); + auto dot_v = m2.add_instruction(migraphx::make_op("dot"), cvt_out, v); + m2.add_return({dot_v}); } run_pass(m1);