diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b989326f65f..4bd220bf587 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -218,7 +218,6 @@ register_migraphx_ops( gather gathernd get_tuple_elem - gqa_rotary_embedding greater group gru diff --git a/src/include/migraphx/op/gqa_rotary_embedding.hpp b/src/include/migraphx/op/gqa_rotary_embedding.hpp deleted file mode 100644 index c24463a19b6..00000000000 --- a/src/include/migraphx/op/gqa_rotary_embedding.hpp +++ /dev/null @@ -1,221 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - * - */ -#ifndef MIGRAPHX_GUARD_OPERATORS_GQA_GQA_ROTARY_EMBEDDING_HPP -#define MIGRAPHX_GUARD_OPERATORS_GQA_GQA_ROTARY_EMBEDDING_HPP - -#include -#include -#include -#include -#include - -namespace migraphx { -inline namespace MIGRAPHX_INLINE_NS { -namespace op { - -struct gqa_rotary_embedding -{ - size_t num_heads = 1; - size_t kv_num_heads = 1; - bool interleaved = false; - - template - static auto reflect(Self& self, F f) - { - return pack(f(self.num_heads, "num_heads"), - f(self.kv_num_heads, "kv_num_heads"), - f(self.interleaved, "interleaved")); - } - - std::string name() const { return "gqa_rotary_embedding"; } - - shape compute_shape(std::vector inputs) const { return inputs.front(); } - - struct rotary_parameters - { - size_t batch_size = 0; - size_t sequence_length = 0; - size_t head_size = 0; - size_t num_heads = 0; - size_t rotary_embedding_dim = 0; - size_t max_sequence_length = 0; // Sequence length used by cos/sin cache - size_t head_stride = 0; - size_t seq_stride = 0; - size_t batch_stride = 0; - bool position_ids_use_batch = false; - }; - - template - void run_rotary_embedding(T input, - T cos_cache, - T sin_cache, - T output, - const size_t* pos_ids, - rotary_parameters params) const - { - const size_t half_rotary_emb_dim = params.rotary_embedding_dim / 2; - - const size_t loop_len = params.batch_size * params.sequence_length * params.num_heads; - par_for(loop_len, [&](auto idx) { - const size_t b = (idx / params.num_heads) / params.sequence_length; - const size_t s = (idx / params.num_heads) % params.sequence_length; - const size_t n = idx % params.num_heads; - const size_t block_offset = - b * params.batch_stride + s * params.seq_stride + n * params.head_stride; - auto input_data = input + block_offset; - auto output_data = output + block_offset; - - const size_t position_id = params.position_ids_use_batch - ? pos_ids[b * params.sequence_length + s] - : pos_ids[0] + s; - - const size_t cache_offset = position_id * half_rotary_emb_dim; - auto cos_data = cos_cache + cache_offset; - auto sin_data = sin_cache + cache_offset; - - size_t cache_idx = 0; - float sign = 0.0; - size_t j = 0; - for(size_t i = 0; i < params.rotary_embedding_dim; i++) - { - if(interleaved) - { - cache_idx = (i / 2) % half_rotary_emb_dim; - sign = (i % 2 == 0) ? -1.0 : 1.0; - j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign - } - else - { - cache_idx = i % half_rotary_emb_dim; - sign = (i < half_rotary_emb_dim) ? -1.0 : 1.0; - j = (i + half_rotary_emb_dim) % params.rotary_embedding_dim; - } - output_data[i] = input_data[i] * cos_data[cache_idx] + - sign * input_data[j] * sin_data[cache_idx]; - } - std::copy(input_data + params.rotary_embedding_dim, - input_data + params.head_size, - output_data + params.rotary_embedding_dim); - }); - } - - template - void pack_v_into_rotary_qkv(rotary_parameters params, const T input, T output) const - { - const size_t loop_len = params.batch_size * params.sequence_length * kv_num_heads; - par_for(loop_len, [&](const auto idx) { - const size_t b = (idx / kv_num_heads) / params.sequence_length; - const size_t s = (idx / kv_num_heads) % params.sequence_length; - const size_t n = idx % kv_num_heads; - const size_t block_offset = - b * params.batch_stride + s * params.seq_stride + n * params.head_stride; - const T input_data = input + block_offset; - T output_data = output + block_offset; - for(size_t i = 0; i < params.head_size; i++) - { - output_data[i] = input_data[i]; - } - }); - } - - // Args: - // 0 - packed QKV (batch_size, num_heads + 2 * kv_num_heads, sequence_length, head_size) - // 1 - seqlens_k (batch_size) - // 2 - cos cache (max_rotary_sequence_length, head_size / 2) - // 3 - sin cache (max_rotary_sequence_length, head_size / 2) - argument compute(const shape& output_shape, std::vector args) const - { - rotary_parameters params; - - const auto& qkv_lens = args[0].get_shape().lens(); - params.batch_size = qkv_lens[0]; - params.sequence_length = qkv_lens[2]; - params.head_size = qkv_lens[3]; - const auto& cache_lens = args[2].get_shape().lens(); - params.max_sequence_length = cache_lens[0]; - params.rotary_embedding_dim = cache_lens[1] * 2; - params.seq_stride = params.head_size; - params.head_stride = params.sequence_length * params.seq_stride; - params.batch_stride = - (num_heads + 2 * kv_num_heads) * params.sequence_length * params.head_size; - params.position_ids_use_batch = params.sequence_length == 1; - - argument result{output_shape}; - - visit_all(result, args[0], args[2], args[3])( - [&](auto output, auto qkv, auto cos_cache, auto sin_cache) { - visit_all(args[1])([&](auto seqlens_k) { - std::vector pos_ids(params.position_ids_use_batch ? params.batch_size - : 1); - if(params.position_ids_use_batch) - { - std::transform(seqlens_k.begin(), - seqlens_k.end(), - pos_ids.begin(), - [](auto len) { return len; }); - } - else - { - pos_ids[0] = 0; - } - - auto q_input = qkv.begin(); - auto k_input = q_input + num_heads * params.head_stride; - auto q_rotary = output.begin(); - auto k_rotary = q_rotary + num_heads * params.head_stride; - - params.num_heads = num_heads; - run_rotary_embedding(q_input, - cos_cache.begin(), - sin_cache.begin(), - q_rotary, - pos_ids.data(), - params); - - params.num_heads = kv_num_heads; - run_rotary_embedding(k_input, - cos_cache.begin(), - sin_cache.begin(), - k_rotary, - pos_ids.data(), - params); - - auto v_input = k_input + kv_num_heads * params.head_stride; - auto v_rotary = k_rotary + kv_num_heads * params.head_stride; - params.num_heads = num_heads; - - pack_v_into_rotary_qkv(params, v_input, v_rotary); - }); - }); - - return result; - } -}; - -} // namespace op -} // namespace MIGRAPHX_INLINE_NS -} // namespace migraphx - -#endif diff --git a/src/include/migraphx/operators.hpp b/src/include/migraphx/operators.hpp index acc60f91496..2a8cf1deef0 100644 --- a/src/include/migraphx/operators.hpp +++ b/src/include/migraphx/operators.hpp @@ -64,7 +64,6 @@ #include #include #include -#include #include #include #include diff --git a/src/targets/gpu/jit/gqa_rotary_embedding.cpp b/src/targets/gpu/jit/gqa_rotary_embedding.cpp deleted file mode 100644 index 34061635280..00000000000 --- a/src/targets/gpu/jit/gqa_rotary_embedding.cpp +++ /dev/null @@ -1,114 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace migraphx { -inline namespace MIGRAPHX_INLINE_NS { - -namespace gpu { - -using namespace migraphx::gpu::gen; // NOLINT - -// NOLINTNEXTLINE -static const char* const gqa_rotary_embedding_kernel = R"__migraphx__( -#include -#include -#include -#include -#include - -namespace migraphx { - - - -extern "C" { - - -MIGRAPHX_GLOBAL void ${kernel}(${params}) -{ - transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) { - - gqa_rotary_embedding(xs..., make_gqa_parameters(${gqa_params})); - }); -} - - -} - -} // namespace migraphx - -)__migraphx__"; - -struct gqa_rotary_embedding_compiler : compiler -{ - std::vector names() const - { - return {"gqa_rotary_embedding", "gpu::gqa_rotary_embedding"}; - } - - operation compile_op(context& ctx, const std::vector& inputs, const value& v) const - { - auto params = init_params(inputs, v); - auto gqa_params_str = params.make_init_str(); - - hip_compile_options options; - options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements())); - options.inputs = inputs; - options.output = inputs.back(); - options.kernel_name = v.get("kernel", "gqa_rotary_embedding_kernel"); - - auto src = interpolate_string(gqa_rotary_embedding_kernel, - {{"params", enum_params(inputs.size(), "void * private_p")}, - {"args", enum_params(inputs.size(), "private_p")}, - {"gqa_params", gqa_params_str}, - {"kernel", options.kernel_name}}); - return compile_hip_code_object(ctx, src, options); - } - - compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const - { - auto shapes = to_shapes(ins->inputs()); - auto v = op.to_value(); - return compile_op(ctx, shapes, v); - } -}; - -} // namespace gpu -} // namespace MIGRAPHX_INLINE_NS -} // namespace migraphx diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/gqa_rotary_embedding.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/gqa_rotary_embedding.hpp deleted file mode 100644 index 7951b7b2270..00000000000 --- a/src/targets/gpu/kernels/include/migraphx/kernels/gqa_rotary_embedding.hpp +++ /dev/null @@ -1,178 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ -#ifndef MIGRAPHX_GUARD_KERNELS_ROTARY_EMBEDDING_HPP -#define MIGRAPHX_GUARD_KERNELS_ROTARY_EMBEDDING_HPP - -#include -#include -#include - -namespace migraphx { - -template -__device__ void run_rotary_embedding(Input input, - CosCache cos_cache, - SinCache sin_cache, - Output output, - PosIDs pos_ids, - Params params, - index_int idx, - bool is_query = false) -{ - const index_int batch_size = params.batch_size; - const index_int sequence_length = params.sequence_length; - const index_int n_heads = is_query ? params.num_heads : params.kv_num_heads; - const index_int head_size = params.head_size; - const index_int head_stride = params.head_stride; - const index_int seq_stride = params.seq_stride; - const index_int batch_stride = params.batch_stride; - const int position_ids_format = params.position_ids_format; - const index_int rotary_emb_dim = params.rotary_embedding_dim; - const index_int half_rotary_emb_dim = rotary_emb_dim / 2; - - const index_int loop_len = batch_size * sequence_length * n_heads; - const index_int i = idx / head_size; - const index_int ii = idx % head_size; - if(i < loop_len) - { - const index_int b = (i / n_heads) / sequence_length; - const index_int s = (i / n_heads) % sequence_length; - const index_int n = i % n_heads; - const index_int block_offset = b * batch_stride + s * seq_stride + n * head_stride; - auto input_data = input + block_offset; - auto output_data = output + block_offset; - - // Cache is (M, H/2) or (M, rotary_embedding_dim/2) - int position_id = (position_ids_format == 0) - ? static_cast(pos_ids[0]) + s - : static_cast(pos_ids[b * sequence_length + s]); - position_id = (sequence_length == 1) ? position_id : s; - - const index_int cache_offset = position_id * half_rotary_emb_dim; - auto cos_data = cos_cache + cache_offset; - auto sin_data = sin_cache + cache_offset; - - int cache_idx = 0; - double sign = 0.0; - int j = 0; - if(ii < rotary_emb_dim) - { - if(params.interleaved) - { - cache_idx = (ii / 2) % half_rotary_emb_dim; - sign = (ii % 2 == 0) ? -1.0 : 1.0; - j = (ii % 2 == 0) ? ii + 1 : ii - 1; // i - sign - } - else - { - cache_idx = ii % half_rotary_emb_dim; - sign = (ii < half_rotary_emb_dim) ? -1.0 : 1.0; - j = (ii + half_rotary_emb_dim) % rotary_emb_dim; - } - double out_data = - static_cast(input_data[ii]) * static_cast(cos_data[cache_idx]) + - sign * static_cast(input_data[j]) * - static_cast(sin_data[cache_idx]); - output_data[ii] = out_data; - } - else if(ii < head_size) - { - output_data[ii] = input_data[ii]; - } - } -} - -template -__device__ void -pack_v_into_rotary_qkv(Params params, const Input input, Output output, index_int idx) -{ - const index_int loop_len = params.batch_size * params.sequence_length * params.kv_num_heads; - auto i = idx / params.head_size; - auto ii = idx % params.head_size; - if(i < loop_len) - { - const index_int b = (i / params.kv_num_heads) / params.sequence_length; - const index_int s = (i / params.kv_num_heads) % params.sequence_length; - const index_int n = i % params.kv_num_heads; - const index_int block_offset = - b * params.batch_stride + s * params.seq_stride + n * params.head_stride; - const Input input_data = input + block_offset; - Output output_data = output + block_offset; - if(ii < params.head_size) - { - output_data[ii] = input_data[ii]; - } - } -} - -template -__device__ void gqa_rotary_embedding(Output output, - Query query, - SeqLensK seqlens_k, - CosCache cos_cache, - SinCache sin_cache, - Params params) -{ - auto ind = make_index(); - ind.global_stride(output.get_shape().elements(), [&](auto idx) { - auto q_input = query.begin(); - auto q_rotary = output.begin(); - auto k_input = q_input + params.num_heads * params.sequence_length * params.head_size; - auto k_rotary = q_rotary + params.num_heads * params.sequence_length * params.head_size; - auto v_input = k_input + params.kv_num_heads * params.sequence_length * params.head_size; - auto v_rotary = k_rotary + params.kv_num_heads * params.sequence_length * params.head_size; - auto q_chunk_size = - params.batch_size * params.num_heads * params.sequence_length * params.head_size; - auto kv_chunk_size = - params.batch_size * params.kv_num_heads * params.sequence_length * params.head_size; - if(idx < q_chunk_size) - { - run_rotary_embedding(q_input, - cos_cache.begin(), - sin_cache.begin(), - q_rotary, - seqlens_k.begin(), - params, - idx, - true); - } - else if(idx < q_chunk_size + kv_chunk_size) - { - run_rotary_embedding(k_input, - cos_cache.begin(), - sin_cache.begin(), - k_rotary, - seqlens_k.begin(), - params, - idx - q_chunk_size); - } - else if(idx < output.get_shape().elements()) - { - pack_v_into_rotary_qkv(params, v_input, v_rotary, idx - (q_chunk_size + kv_chunk_size)); - } - }); -} - -} // namespace migraphx -#endif diff --git a/test/fuse_attention.cpp b/test/fuse_attention.cpp index 4d40a4d0d17..245daf6f55e 100644 --- a/test/fuse_attention.cpp +++ b/test/fuse_attention.cpp @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -1324,13 +1325,7 @@ TEST_CASE(kv_cache_attention) mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 6, 2}}}), query); auto tsp_q = mm->add_instruction( migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), rsp_q); - auto rope = mm->add_instruction( - migraphx::make_op("gqa_rotary_embedding", - {{"num_heads", 2}, {"kv_num_heads", 2}, {"interleaved", 0}}), - tsp_q, - slk, - cos_cache, - sin_cache); + auto rope = migraphx::op::builder::add("rotary_embedding", *mm, {tsp_q, slk, cos_cache, sin_cache}, {{"interleaved", false}}).at(0); auto slc_k = mm->add_instruction( migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {4}}}), rope); auto slc_v = mm->add_instruction( @@ -1397,13 +1392,7 @@ TEST_CASE(kv_cache_attention) mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 6, 2}}}), query); auto tsp_q = mm->add_instruction( migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), rsp_q); - auto rope = mm->add_instruction( - migraphx::make_op("gqa_rotary_embedding", - {{"num_heads", 2}, {"kv_num_heads", 2}, {"interleaved", 0}}), - tsp_q, - slk, - cos_cache, - sin_cache); + auto rope = migraphx::op::builder::add("rotary_embedding", *mm, {tsp_q, slk, cos_cache, sin_cache}, {{"interleaved", false}}).at(0); auto slc_k = mm->add_instruction( migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {4}}}), rope); auto slc_v = mm->add_instruction( diff --git a/test/ref/gqa_rotary_embedding.cpp b/test/ref/gqa_rotary_embedding.cpp deleted file mode 100644 index 5252c79b921..00000000000 --- a/test/ref/gqa_rotary_embedding.cpp +++ /dev/null @@ -1,500 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - * - */ -#include -#include -#include -#include -#include -#include - -#include - -TEST_CASE(gqa_rotary_embedding_test) -{ - migraphx::program p; - auto* mm = p.get_main_module(); - - const size_t batch_size = 1; - const size_t sequence_length = 8; - const size_t num_heads = 4; - const size_t kv_num_heads = 2; - const size_t head_size = 16; - const size_t max_cache_sequence_length = 8; - const size_t total_sequence_length = 8; - const size_t max_rotary_seq_length = max_cache_sequence_length; - const size_t rotary_dim = head_size / 2; - const bool interleaved = false; - - migraphx::shape qkv_shape{ - migraphx::shape::float_type, - {batch_size, num_heads + 2 * kv_num_heads, sequence_length, head_size}}; - migraphx::shape key_total_sequence_lens_shape(migraphx::shape::int32_type, {batch_size}); - migraphx::shape cos_cache_shape(migraphx::shape::float_type, - {max_rotary_seq_length, rotary_dim}); - migraphx::shape sin_cache_shape(migraphx::shape::float_type, - {max_rotary_seq_length, rotary_dim}); - - auto qkv = mm->add_parameter("qkv", qkv_shape); - auto ktsl = mm->add_parameter("ktsl", key_total_sequence_lens_shape); - auto cos_cache = mm->add_parameter("cos_cache", cos_cache_shape); - auto sin_cache = mm->add_parameter("sin_cache", sin_cache_shape); - auto rotary = mm->add_instruction(migraphx::make_op("gqa_rotary_embedding", - {{"num_heads", num_heads}, - {"kv_num_heads", kv_num_heads}, - {"interleaved", interleaved}}), - qkv, - ktsl, - cos_cache, - sin_cache); - mm->add_return({rotary}); - - std::vector qkv_val{ - 0.41749f, -0.69577f, -1.70273f, -0.79187f, 0.07310f, -0.27880f, -0.75174f, -0.72621f, - -2.06164f, -1.44138f, -0.01891f, -0.05486f, -0.50047f, 0.35353f, -0.76615f, -1.74248f, - 0.59954f, -1.00662f, 1.02043f, -0.07006f, -1.96865f, -0.42948f, 1.38265f, 1.39979f, - 1.61774f, 0.76564f, -0.06511f, 0.51657f, -0.78820f, -0.53020f, 1.00847f, 0.75522f, - 1.10552f, 0.33774f, -1.25448f, -0.74513f, 0.32448f, 0.18892f, -0.80532f, -0.47895f, - -0.15562f, 0.23953f, 1.14514f, -0.72504f, 0.60569f, -0.01937f, 1.17494f, 0.44646f, - 0.17420f, -1.29149f, 0.09795f, 1.30044f, 0.65743f, -0.43525f, 1.21967f, -0.26364f, - -1.17287f, 1.08942f, -0.20404f, -0.67642f, -1.04814f, -0.30387f, -0.34929f, -0.67358f, - -0.45570f, 1.01702f, -0.20984f, -0.66244f, -1.03716f, 1.67139f, 0.22007f, 0.94488f, - -0.15520f, 0.00094f, 0.93557f, -1.05383f, 0.36045f, 0.05115f, -0.19947f, 1.57586f, - 0.00395f, 0.20170f, -1.59494f, -2.09666f, 1.28563f, 0.10925f, -1.95444f, -0.10990f, - -1.39617f, 0.96968f, 0.43793f, 1.49594f, 0.69040f, -0.45398f, -0.43307f, 0.13872f, - -1.23702f, 0.27749f, -1.54765f, -0.33909f, 0.95985f, -0.80969f, -0.06333f, -0.35975f, - -1.03807f, 0.95710f, -0.72967f, 0.73937f, -0.01144f, 0.52790f, 0.81002f, -0.27235f, - 1.68369f, -1.10408f, -1.12985f, 0.08155f, -1.80029f, -1.07238f, -0.50818f, 0.48642f, - 1.65785f, 0.34942f, -0.29293f, -0.02263f, -0.28906f, 0.21972f, -0.31756f, -0.61890f, - -0.97612f, 0.70052f, 1.17989f, 1.03988f, 0.28753f, 0.44653f, 0.58795f, -0.75481f, - -0.86908f, -0.86280f, 0.46151f, -0.40664f, 0.93974f, -0.59152f, 0.62067f, -1.16862f, - 0.19509f, -2.08621f, -1.25910f, -0.69394f, -0.46405f, 1.65313f, -0.62303f, 0.57686f, - 0.12457f, -0.50076f, 0.52905f, -0.21097f, 0.13185f, -0.06517f, -0.81022f, -0.15991f, - -0.85608f, -0.66973f, 0.82188f, 0.84750f, 1.11381f, -0.27893f, 0.96663f, 0.17090f, - -0.58623f, 0.08810f, -0.64246f, 1.13702f, 0.78309f, 0.92382f, -0.37715f, 0.37516f, - 0.56550f, -0.14689f, -0.78602f, 1.58871f, -1.83865f, -0.05113f, -0.00113f, -0.20537f, - -0.25702f, -0.48554f, -0.42965f, -0.80981f, -0.77238f, -0.23455f, -1.04290f, 1.03438f, - -2.45761f, 0.44027f, 1.38838f, -1.69399f, -0.26094f, 1.95926f, 0.15914f, 0.56232f, - 1.48083f, 0.08590f, -0.19955f, 0.75460f, -0.52121f, -1.55349f, -0.34823f, -0.26756f, - -1.55488f, 0.92425f, -0.71535f, 0.61587f, -0.45286f, -0.92762f, -0.14223f, -1.34620f, - 1.03421f, 0.22707f, -0.90337f, 0.42504f, 1.18391f, 0.13978f, -0.16003f, -0.27895f, - -0.62004f, -0.68513f, 0.15930f, 0.81287f, 0.70400f, 0.62487f, 0.92692f, -1.02183f, - 0.67300f, 0.54589f, 0.81747f, -1.05813f, -0.44010f, -3.08890f, 0.98540f, 0.35170f, - -0.17988f, -0.37412f, 0.25646f, -0.60839f, 0.13135f, -0.06886f, 0.55302f, -0.49637f, - -0.41079f, 0.71376f, -0.92411f, -0.04047f, -0.49496f, -0.16855f, -0.23395f, 1.22807f, - 0.98338f, -0.72929f, -1.31395f, 1.29489f, 0.32306f, -0.69334f, 0.08734f, -0.15889f, - 1.89157f, 0.64903f, 0.08721f, 1.88299f, 0.90821f, 0.57134f, 0.29974f, -0.02099f, - -1.47984f, -1.19396f, -0.86907f, -0.46785f, -1.51203f, -0.48414f, -0.28719f, -0.35233f, - -1.87752f, 1.27135f, 0.65182f, -0.97655f, 0.51972f, -0.56604f, -0.78326f, -0.22718f, - 0.03281f, 0.60940f, -0.27093f, 1.08492f, -2.64958f, 0.47716f, -1.31534f, -1.52301f, - 1.54800f, -0.89364f, -1.19203f, -1.02952f, -0.89654f, 0.19871f, 0.22578f, -0.24013f, - 0.74516f, -0.76311f, -0.70077f, -0.05940f, 2.51484f, 0.45830f, -0.58173f, 0.88659f, - -0.75417f, -0.97556f, 1.05278f, 0.12707f, 0.86963f, -0.41560f, -1.37358f, 0.73604f, - 0.34192f, 0.80609f, 1.53944f, 0.40806f, 0.14858f, 0.28071f, 0.10681f, -0.88577f, - 1.25632f, -0.91592f, -1.26772f, 1.14000f, 2.31817f, 1.38137f, -0.50686f, -0.02952f, - -0.39148f, 0.32838f, -0.36859f, -1.35995f, 1.32267f, -1.54589f, -1.57510f, -0.66400f, - -0.19672f, -1.07867f, 0.70195f, -2.03537f, -0.45462f, -0.64457f, 0.31647f, 1.20802f, - -0.42087f, 0.68382f, 2.00249f, 1.45079f, -0.95763f, -1.11372f, 0.87328f, -0.39358f, - 1.41147f, -1.37817f, -0.47008f, -0.61774f, -0.05532f, 0.95466f, -1.10584f, 0.17766f, - -0.73075f, -0.69812f, -0.68570f, -0.07292f, 0.63213f, 1.15353f, -0.40322f, -0.05030f, - -0.83515f, -0.06614f, 1.02871f, 0.31555f, 0.13493f, -1.07473f, -1.43802f, 0.74787f, - -0.17787f, 2.72647f, 1.12792f, -0.87049f, 2.23661f, -0.72028f, -1.94251f, -0.82372f, - -0.39990f, -0.83586f, -0.15177f, -1.16006f, -1.00450f, 0.46886f, 1.12177f, 2.36713f, - -0.85803f, 1.14115f, -0.54052f, -0.33579f, 0.80441f, -0.03780f, 0.10341f, 0.08724f, - 0.32319f, 0.20989f, -0.11073f, 0.25509f, -0.29188f, -0.68462f, 0.27836f, -0.16048f, - -0.58644f, -0.75119f, -0.71220f, 0.36131f, -2.72436f, -0.05491f, 0.15442f, -0.42460f, - -0.45834f, -0.73080f, 0.64778f, 0.76119f, -2.56361f, 0.04359f, 0.71025f, 2.14647f, - 1.51538f, 0.01322f, 2.17842f, -0.20094f, 2.44536f, 0.01128f, -0.94017f, 2.85214f, - 0.61998f, 1.25180f, -0.71116f, -0.37180f, -1.22676f, 0.45919f, 0.23963f, 0.86249f, - -0.58306f, 0.27836f, -0.20477f, -1.56387f, -0.45388f, 0.40267f, 1.77693f, 0.41270f, - -2.19566f, 0.46689f, -1.00889f, 1.18128f, -1.69895f, 0.94235f, -0.48070f, 0.88864f, - 1.25445f, -0.85912f, 0.63483f, 2.48495f, 1.16718f, 0.94282f, 1.01831f, 0.97040f, - -0.48674f, -0.98615f, -0.53975f, -0.33098f, 0.80657f, 1.15605f, 0.09327f, -1.08562f, - -1.64618f, -0.29966f, 0.21517f, -2.44469f, -1.25757f, 0.51321f, 0.61697f, 0.31420f, - -0.84412f, -0.40086f, -0.28057f, -0.24301f, -1.28083f, -0.43024f, -0.65606f, -0.14499f, - -0.70121f, 0.70340f, 1.35571f, -0.25670f, 0.44737f, -1.85772f, -1.01866f, 0.57725f, - -2.49752f, 0.42235f, -0.23843f, -1.39177f, -0.76879f, 0.41183f, 2.10108f, 0.28865f, - 0.06077f, -1.14342f, 0.34651f, 0.67033f, -0.08636f, -0.84150f, 0.55234f, -0.52409f, - -0.84508f, 1.20556f, -0.98627f, 0.02287f, 2.32880f, -1.82350f, -0.65655f, 1.44936f, - 0.10410f, 2.05022f, 1.08761f, 1.57381f, -0.01312f, -0.66371f, -1.02632f, 0.24346f, - -1.47868f, 1.05582f, 0.38794f, -0.88618f, -0.48330f, -1.93097f, 0.38556f, 0.45613f, - -1.53924f, -1.21630f, 0.97219f, 0.68307f, 0.36223f, -1.26150f, -0.27909f, -1.29269f, - 0.81657f, -0.05342f, -1.31958f, -0.60319f, -1.04079f, 0.39469f, -0.75666f, 1.50323f, - 0.25399f, -0.10152f, -0.76702f, 1.83365f, 0.57462f, -0.77335f, -1.35593f, -0.46823f, - -0.04678f, -0.66350f, -0.24358f, -1.57921f, -0.69181f, -0.36496f, -0.01369f, -0.70503f, - -1.71507f, 1.58063f, -0.58568f, -1.55290f, -0.58359f, 1.35833f, 0.54550f, -0.70589f, - -0.24310f, -1.23408f, 0.42297f, -0.51576f, 1.30628f, -0.81409f, -0.35845f, 0.89670f, - 1.01546f, 0.39702f, -1.26967f, -0.36245f, 1.29258f, -0.49741f, -0.25744f, 0.92564f, - 0.59873f, 0.85128f, -0.86577f, 1.25455f, 0.36780f, 0.70003f, 1.60151f, -0.67805f, - 0.25288f, -1.44379f, 0.59452f, -0.74957f, 0.92938f, 0.93066f, 1.91045f, 0.58974f, - -0.33494f, 0.52111f, 0.73124f, -0.27959f, 0.81104f, 0.25312f, -1.38699f, 0.20307f, - 0.76345f, 0.14203f, 0.90887f, -0.83386f, -0.21649f, -0.13751f, 2.50241f, 0.38624f, - -0.67405f, -2.28033f, 1.00245f, 2.74601f, 0.78501f, 2.08665f, 0.72543f, 0.57949f, - -2.17947f, -0.52484f, -0.23829f, -1.05610f, 0.67673f, -0.81046f, 1.36806f, 0.27357f, - -0.68868f, 0.17381f, -1.02706f, 1.17405f, 1.05738f, -0.74259f, 0.57364f, 0.67129f, - 1.16308f, -0.78270f, -2.66280f, 0.29430f, -1.55534f, 0.02705f, -0.49464f, 1.54097f, - 0.58723f, 0.08179f, 0.39286f, -0.74650f, -0.38050f, -0.98030f, -0.29283f, -0.65102f, - 0.32762f, 0.84588f, -0.42554f, -0.15737f, -0.85467f, 0.05529f, 0.69491f, -1.04209f, - -0.85984f, -1.32812f, -0.29532f, 0.66120f, 0.11061f, 0.84870f, 0.86388f, -1.53822f, - 0.44723f, 0.13878f, -0.17476f, -0.41955f, -1.01415f, -1.36983f, 0.41508f, -1.88003f, - -1.13967f, -1.00110f, -0.31072f, -0.78613f, 0.69637f, -0.52052f, 1.20858f, -0.24648f, - 1.31078f, -1.65722f, 0.52834f, 1.24568f, 0.21253f, -0.40617f, -0.10599f, -0.49272f, - -0.12628f, 0.37853f, -0.03980f, 0.58942f, -1.50853f, 0.32210f, 0.40937f, -0.49893f, - 0.37507f, 0.19290f, -1.85422f, -0.03451f, 1.00867f, 0.74706f, 1.14846f, 0.29571f, - 0.77227f, -0.32998f, 0.14887f, 0.55427f, -0.71247f, 1.00684f, -1.00142f, 0.19765f, - -1.54164f, 0.75906f, -3.29353f, 0.62726f, -0.96910f, 1.45244f, 0.37286f, 0.70223f, - 0.13735f, 1.50773f, 0.91042f, 0.19532f, -0.16148f, -1.98474f, -0.99344f, 0.45995f, - -2.43026f, -0.08176f, -1.10964f, 0.18464f, -1.20368f, -1.53090f, -0.47588f, 0.36687f, - -0.30770f, 2.05622f, 0.80959f, 0.24895f, -0.55335f, 2.20205f, -0.59102f, 0.61142f, - 1.54316f, -0.70976f, 0.40827f, 0.09174f, -0.81059f, -0.94260f, 0.83612f, 1.54785f, - 0.52762f, 0.29894f, 0.51916f, 0.18861f, -0.48849f, 0.37812f, -0.55108f, 1.16715f, - 0.18136f, -2.05249f, -0.52255f, -3.17783f, -1.23747f, 0.96977f, 0.94546f, 0.53913f, - -0.87105f, -0.23456f, 0.94457f, 0.33633f, 1.67793f, -0.64343f, 1.64205f, -0.27973f, - -0.28263f, 0.59008f, 0.46478f, -0.67101f, -0.97408f, 1.09617f, -1.49934f, 1.65995f, - 0.48849f, 2.09345f, 1.12164f, 0.00400f, 0.41253f, -1.82862f, 0.02542f, 0.15923f, - -0.02319f, -1.10920f, 0.33385f, 0.44133f, 0.52444f, 0.17839f, 0.35989f, 1.74398f, - 3.12485f, 0.32066f, 1.83530f, 0.05643f, -1.48623f, -2.86640f, -0.98590f, 2.43114f, - -0.35453f, 1.19489f, 1.21466f, 0.43728f, 1.17409f, 1.10532f, -0.54684f, -0.27415f, - 2.10486f, 0.42502f, 0.31442f, 1.09029f, -1.39990f, 0.34544f, -0.04665f, 1.93631f, - -1.80739f, -0.72702f, 0.26711f, -0.49996f, -0.08280f, 1.45220f, -1.87458f, 2.56975f, - -0.35613f, -0.82097f, -0.73138f, 0.64083f, -0.82572f, 0.29276f, -1.20876f, 0.17214f, - -0.98109f, -0.93785f, 0.23218f, 0.74644f, -0.55791f, 0.76360f, -2.14018f, 0.41112f, - -1.00239f, 0.71319f, 0.23396f, 1.44287f, 0.08250f, 2.22372f, -0.53487f, 2.49612f, - -1.27615f, -0.19306f, -0.08782f, 1.84034f, 0.32768f, 0.56849f, 0.04441f, -0.65827f, - -0.23335f, -0.23385f, -0.99781f, 0.02744f, 0.27255f, -0.16298f, 0.16596f, -0.43781f, - 0.45376f, 1.61600f, 0.43510f, -0.15598f, 0.02663f, 0.20136f, 0.16456f, 2.32006f, - -0.17321f, -0.29193f, 0.91427f, 1.58508f, -0.46040f, -1.24638f, -0.67820f, -1.15898f, - 0.17969f, 1.22766f, 2.75006f, 0.42565f, -0.70687f, 0.32029f, 1.39965f, 0.56489f, - 0.71663f, 0.96956f, 0.11987f, 0.13721f, -1.87023f, -0.23010f, 0.06482f, 0.04463f, - 1.68528f, 0.76126f, 1.49722f, -0.38899f, 0.55481f, 1.01654f, 1.39907f, 0.52457f, - -1.95718f, 0.11925f, -0.93854f, -0.97164f, 0.36083f, 0.15714f, -1.08132f, 0.18311f, - 0.04688f, 0.84368f, -0.93179f, -0.51981f, -0.10838f, -0.21285f, -1.50430f, 0.11392f, - 1.62670f, -1.49010f, 0.52821f, -0.38828f, 0.48409f, -0.39124f, 0.59059f, -0.09530f, - -0.25620f, -1.30062f, 0.56251f, -0.36103f, 0.99661f, 1.14148f, 1.92076f, -0.42200f, - -0.98999f, 2.53498f, -0.86633f, 0.25785f, -1.04939f, 1.24230f, 0.19481f, -0.23041f, - 0.23570f, -1.02525f, 1.16260f, -1.04002f, 0.22364f, 0.28499f, -0.34476f, -0.67831f, - -0.55750f, -0.91398f, 1.37583f, 0.64503f, -2.02422f, -1.52848f, -0.27042f, 0.41021f, - -0.43892f, 0.75682f, 0.18781f, 0.92758f, 0.50460f, 0.73314f, 0.03367f, -0.27875f, - -0.63667f, 0.18394f, 1.42434f, 2.00770f, -0.88286f, -0.55983f, -1.12401f, 0.34193f, - -2.51687f, -1.04707f, -0.63970f, -0.70438f, 0.59782f, 0.74183f, 0.31749f, -0.28442f, - -1.95803f, -1.79381f, 0.46461f, -0.17142f, 0.41181f, 0.27836f, -0.02363f, 0.93865f}; - - std::vector ktsl_val(key_total_sequence_lens_shape.elements(), total_sequence_length - 1); - - std::vector cos_cache_val{ - 0.60305f, 0.94544f, 0.59646f, -0.94253f, -0.92642f, 0.19489f, -0.97555f, -0.99972f, - -0.35706f, -0.77041f, -0.87666f, 0.63573f, -0.13506f, 0.51438f, 0.44535f, -0.99741f, - 0.80078f, 0.27214f, 0.75584f, 0.82948f, 0.45954f, -0.68932f, -0.76928f, 0.81090f, - 0.75784f, 0.88164f, -0.93231f, 0.04034f, 0.57829f, 0.90856f, 0.55415f, -0.96171f, - 0.14195f, 0.61549f, 0.45128f, 0.91793f, 0.17327f, 0.99181f, 0.99467f, -0.99735f, - 0.99103f, 0.67559f, 0.98632f, -0.58485f, 0.21572f, -0.53796f, -0.78902f, 0.17645f, - 0.97012f, -0.39310f, 0.33858f, 0.82521f, -0.35662f, 0.39955f, 0.24059f, 0.32539f, - 0.82187f, 0.89318f, -0.54623f, 0.05817f, 0.80386f, -0.63893f, -0.19885f, -0.96778f}; - - std::vector sin_cache_val{ - -0.79771f, -0.32581f, -0.80264f, 0.33412f, 0.37650f, -0.98082f, -0.21977f, 0.02385f, - -0.93408f, 0.63755f, -0.48110f, -0.77191f, -0.99084f, -0.85756f, 0.89536f, 0.07192f, - 0.59896f, -0.96226f, 0.65475f, -0.55853f, 0.88816f, 0.72445f, -0.63891f, -0.58519f, - 0.65244f, 0.47192f, -0.36167f, -0.99919f, -0.81583f, 0.41776f, 0.83242f, 0.27406f, - -0.98987f, -0.78815f, -0.89238f, -0.39675f, 0.98487f, -0.12775f, 0.10315f, 0.07274f, - -0.13363f, 0.73728f, -0.16483f, 0.81114f, 0.97646f, -0.84297f, 0.61436f, -0.98431f, - -0.24262f, -0.91950f, -0.94094f, 0.56482f, 0.93425f, 0.91671f, -0.97063f, 0.94558f, - -0.56968f, -0.44970f, 0.83764f, 0.99831f, 0.59482f, 0.76926f, 0.98003f, 0.25181f}; - - p.compile(migraphx::make_target("ref")); - migraphx::parameter_map pm; - pm["qkv"] = migraphx::argument(qkv_shape, qkv_val.data()); - pm["ktsl"] = migraphx::argument(key_total_sequence_lens_shape, ktsl_val.data()); - pm["cos_cache"] = migraphx::argument(cos_cache_shape, cos_cache_val.data()); - pm["sin_cache"] = migraphx::argument(sin_cache_shape, sin_cache_val.data()); - auto qkv_rotary = p.eval(pm).front(); - std::vector qkv_rotary_vals(qkv_shape.elements()); - qkv_rotary.visit([&](auto output) { qkv_rotary_vals.assign(output.begin(), output.end()); }); - - std::vector qkv_rotary_gold{ - -1.39282, -1.12742, -1.03079, 0.764691, 0.120706, 0.292414, 0.564983, - 0.767565, -1.57631, -1.13605, 1.3554, -0.212872, 0.491168, 0.342352, - 0.912628, 1.72467, 1.29703, 0.287376, -0.925895, 0.354206, -0.515094, - -0.675594, -0.287181, -1.45048, -1.13765, -1.23163, -0.43385, 0.382479, - 2.05707, 0.0955806, 1.68709, -0.652591, 0.978488, 0.322403, -1.69797, - -1.02303, -0.388838, -0.116194, 1.3702, -0.127117, 0.537545, -0.259808, - 0.0441719, -0.185229, 0.566529, 0.150215, -0.389331, 0.642311, 0.897243, - -1.65275, -0.165115, -0.623412, -0.474919, -0.268506, 0.966636, 0.438147, - -0.775193, 0.350996, 0.154803, -1.32667, -1.14248, -0.457914, 0.821719, - 0.575535, -0.218314, 0.626707, 0.740187, -1.02618, -0.534705, 1.66424, - 0.239472, -1.057, 0.429053, -0.800986, 0.609461, -0.704519, -0.959013, - -0.162789, -0.175707, -1.50295, -0.182656, -0.578659, -1.50094, 0.0128149, - -0.396812, -0.441464, 1.80815, 0.117152, -1.38417, 0.803815, 0.694833, - -2.57559, 1.4043, 0.152129, -0.859029, 0.132653, -1.45191, 0.770972, - -1.21058, -0.697431, -0.331614, -0.807443, 0.770993, 0.14047, -0.706927, - -0.631388, 1.20919, 0.418611, 0.90082, -0.531328, 0.256353, -0.428792, - 2.32822, -0.829008, 0.862528, 0.0273355, -1.27524, 0.516154, 0.41227, - -0.314902, 0.403373, 0.8086, -0.7864, 0.0800958, -1.30321, -0.965325, - -0.434885, 0.721444, -1.28192, 0.381191, 1.07418, -0.844252, -0.620186, - -0.49315, -0.43717, 0.78247, 0.254562, -1.04396, -0.671755, 0.730715, - -0.762339, -0.553247, -0.734708, 1.15029, 0.0466995, 1.9265, 1.35833, - -0.604008, 0.193317, 0.79445, 0.447972, -0.563865, -0.226709, -0.944273, - 0.141956, 0.401539, 0.441992, -1.45118, -0.918668, 0.200984, -0.334403, - -0.0974852, 1.04186, 1.33804, -0.183669, -0.476989, -0.984574, 0.358123, - -0.982199, 0.66843, 0.052529, 0.469781, 1.3491, -0.838878, -0.327456, - 0.204208, 0.596249, 0.099632, 0.577423, -0.745065, -1.6934, 0.0515309, - 0.867505, -0.0859758, 0.174175, -0.497392, 0.684847, -1.62009, 1.05337, - -0.234463, -0.578864, -1.05106, 1.11697, 0.338684, 0.448474, -1.25558, - 0.468111, 1.74476, 0.194212, -0.541368, 2.64292, -0.294128, -1.32902, - 1.36476, -0.347302, -1.79106, -0.329959, 0.307754, -1.40273, 0.457, - -0.854466, -0.704959, -1.25373, 0.616853, 0.210538, -0.51211, 1.23271, - 0.834837, -0.773101, 0.250972, -0.186807, 0.70676, 0.0388865, 1.27586, - -0.43823, 0.77127, 0.823126, 1.26844, 0.160103, 3.08129, 1.17947, - -0.665054, 0.803325, 0.415388, 0.126887, -0.414054, 0.81466, -0.661345, - -0.662619, -0.851782, -0.381857, -0.0131786, 0.633985, 0.00501156, 0.399999, - 0.173655, 0.11931, 0.171137, -0.235142, 0.805758, 0.719598, -0.609716, - -0.319749, 0.0547204, 0.588497, -1.31349, 2.10195, -0.478039, -0.71372, - -1.84962, -0.64123, 0.425257, -0.0193307, 0.159346, 0.356259, 0.851229, - 1.10665, -1.34213, -0.719752, 0.79139, -0.311606, 0.0171946, -1.22536, - 0.109289, 1.07547, -1.05123, 0.719174, -0.734445, 0.5734, 0.367756, - 2.05268, -1.74067, -0.153315, -0.259684, 1.42799, 0.124019, -0.605963, - 0.201252, -0.900916, -0.694072, 0.575702, 0.324902, -0.421317, -0.472871, - 1.15612, -1.37553, 1.25926, -0.829596, -1.07838, -1.45993, -2.76525, - 0.208704, 0.666696, 0.696529, 1.05676, -0.212402, 1.03409, 0.124571, - 2.16378, 0.590014, 0.82103, -1.05436, -0.085368, -1.22022, -0.72807, - 0.0644779, -1.54878, -0.186138, -1.24541, -0.464878, 1.29213, -0.225742, - -0.436569, 0.826866, -2.25735, 0.454881, 0.158523, 0.88557, -0.160122, - -1.19906, -1.94586, 0.884542, 0.548001, 1.3342, -0.493141, -0.0349891, - -0.414256, 1.01713, -0.247845, 2.44634, 0.729245, 0.288274, 1.04836, - 1.0719, -0.142642, -0.486631, 0.753102, 0.0872765, 1.19346, 1.64989, - -1.21738, 0.866737, -0.0658435, -1.53604, 0.235686, 1.54612, 0.393193, - -1.32013, -0.863259, -0.296059, 1.47141, -0.0870139, -2.04338, 0.30967, - -0.874938, -0.639524, -1.11369, -0.314353, -1.07635, -0.65329, -0.487139, - -0.319258, 0.427885, 0.0897219, 1.48948, -0.139642, -0.270091, 0.25487, - -1.13628, -0.0544412, 0.484468, 1.57404, -0.109217, -0.73644, -0.426269, - 2.30538, 0.550942, 1.20806, -1.69385, 0.319492, 2.14155, 0.767033, - -0.099271, -1.67857, -0.995838, 0.802543, 1.77267, 0.797841, -0.667437, - -2.38611, 0.608253, -1.01297, 0.42058, -0.0165653, -0.39785, -0.606546, - -0.203179, -0.0754723, 0.68607, 0.565839, 0.357117, 0.421368, -0.75762, - -0.319739, 0.216557, 0.166339, -0.195082, -0.907648, -0.962443, 0.724847, - 1.02494, 0.00627179, 0.334994, 0.911785, -0.718284, 0.52396, 0.0233051, - 0.429589, -3.59775, -0.069827, -0.645042, 1.98904, 0.743916, -0.579094, - -2.28817, -0.379605, 0.413299, -0.181583, -0.720468, -2.97931, 1.45854, - 1.10988, -0.124848, 0.185779, -2.70442, 0.421914, -0.649825, -0.0478077, - -2.25618, 0.539307, -0.992722, -0.96685, 1.5946, 0.519757, 1.81704, - -0.476246, 0.26548, 0.0679767, -0.272559, 1.7048, -0.74139, 0.883191, - -0.294848, -0.856265, 1.17815, 0.146656, 0.537179, -1.18485, -0.535799, - 0.467316, -0.860768, -0.897359, -0.650006, -1.29965, -0.637005, 2.20922, - 1.3137, -1.41668, 0.552017, -1.14673, -1.80179, -0.250794, -0.191147, - -1.88013, 1.64509, 0.599458, -0.488355, 0.239337, -0.419502, 0.433115, - -0.297457, -1.58134, -0.718115, 0.298562, -0.756691, 0.249923, -1.99909, - 0.818194, -0.540811, 1.37449, 0.816915, 0.870149, -1.85656, -0.631336, - -1.65317, 0.0609156, 1.26583, -0.337225, -0.351895, -1.6922, -1.41612, - -0.133992, -0.637481, -0.688252, -0.58494, -0.639447, -0.796788, -1.95253, - -0.683125, 0.489376, -0.558102, 1.51232, -0.866393, 0.202415, -2.18996, - 0.469978, 0.51911, -1.46145, -1.41838, -2.25265, -0.766826, 0.316467, - -0.477101, -1.99732, -0.802287, -0.275634, 0.43074, 0.493704, -0.863341, - -1.77821, 0.0782743, -0.424081, -0.747217, -0.437439, -1.72169, -0.382408, - 1.59882, 0.229693, 1.09085, 0.583644, -0.268739, -0.168567, -0.26805, - 1.15586, -0.36085, -0.881849, -0.156566, -1.18596, 0.760397, 1.97544, - 0.223005, 0.223615, 0.627005, -1.50396, -0.232102, -0.550169, -0.739993, - 0.643522, 0.130261, -0.632877, 0.5045, -1.89587, -0.868859, -0.654663, - -1.13629, 0.549711, -0.484092, 0.000221789, 0.113144, -1.63008, -1.38763, - 1.24321, 0.579567, 0.638793, 1.66319, -2.00534, 0.713527, 0.142681, - -0.348421, -0.980949, -0.300271, -0.94567, 1.08636, -0.359409, -1.39501, - -0.805637, -0.0803066, 0.857691, -0.780778, -0.504082, 0.457663, 0.867831, - -0.644647, -1.02772, 1.34149, 0.0427136, -1.42178, -1.03076, 0.164061, - 1.04671, 0.889346, -0.460635, -1.08915, 0.139808, -0.886619, -0.000123426, - -0.386286, 1.12272, -0.311824, -0.654093, 0.57904, 0.954279, -2.18804, - 0.623723, 0.243464, -0.898606, -1.33614, -2.78988, -0.640967, -1.51732, - -1.20855, -0.519717, -0.988904, -2.10062, 0.213738, -0.672715, 0.502266, - -1.439, 2.30819, -0.46356, -1.8637, -0.439576, -0.96649, 0.603132, - -1.02504, -0.886298, -1.20854, -0.289504, 1.32328, 0.335325, -0.421339, - -1.45944, -0.724789, 0.650192, -0.860273, -0.664577, 0.133231, 0.550855, - 2.52338, -0.389135, -0.16695, -0.826752, 0.0419003, -1.49016, -1.29609, - -0.562022, 0.936668, -0.701746, 1.59248, -0.527444, -0.573293, 0.76016, - 0.777361, -1.0478, -0.128279, 0.238765, -0.490994, -0.652953, 0.0173612, - -1.74518, -0.492311, -1.17539, -0.501837, 0.636348, -0.708254, -0.544971, - -1.10855, -0.637522, 1.0825, 0.594793, 0.0505524, -0.802418, -0.0183533, - -1.02712, -0.77603, 1.87559, -0.571897, -0.817117, 0.352893, 0.387498, - 1.23008, -1.04518, 1.01526, -0.278199, 0.0610645, -0.721664, 0.202913, - 1.3773, 1.52253, -0.361695, -0.147652, 0.527706, -1.31543, 1.53912, - -0.489441, 0.0468228, -0.0520686, 0.37135, 0.396255, 0.461767, 0.474904, - 0.373609, -1.80432, -0.429407, 0.913289, 0.446848, -0.290926, 0.246727, - 0.715222, -0.0807099, 0.452465, -0.352157, 0.831232, -1.17139, 1.49571, - -0.256195, -1.46225, 1.08797, -0.258473, 0.407301, 0.496463, 2.39975, - -0.874556, -0.206421, 0.507279, -1.29064, 3.40726, 0.515469, -0.847795, - 0.538463, -0.600921, 0.813678, -2.17265, 0.851656, -0.0720263, -0.237789, - -0.638447, -0.715811, 0.673846, -0.509011, 1.13158, 1.87334, -1.3717, - 0.198809, -1.16079, -2.58462, -0.348852, -0.499339, 1.54316, -0.70976, - 0.40827, 0.09174, -0.81059, -0.9426, 0.83612, 1.54785, 0.52762, - 0.29894, 0.51916, 0.18861, -0.48849, 0.37812, -0.55108, 1.16715, - 0.18136, -2.05249, -0.52255, -3.17783, -1.23747, 0.96977, 0.94546, - 0.53913, -0.87105, -0.23456, 0.94457, 0.33633, 1.67793, -0.64343, - 1.64205, -0.27973, -0.28263, 0.59008, 0.46478, -0.67101, -0.97408, - 1.09617, -1.49934, 1.65995, 0.48849, 2.09345, 1.12164, 0.004, - 0.41253, -1.82862, 0.02542, 0.15923, -0.02319, -1.1092, 0.33385, - 0.44133, 0.52444, 0.17839, 0.35989, 1.74398, 3.12485, 0.32066, - 1.8353, 0.05643, -1.48623, -2.8664, -0.9859, 2.43114, -0.35453, - 1.19489, 1.21466, 0.43728, 1.17409, 1.10532, -0.54684, -0.27415, - 2.10486, 0.42502, 0.31442, 1.09029, -1.3999, 0.34544, -0.04665, - 1.93631, -1.80739, -0.72702, 0.26711, -0.49996, -0.0828, 1.4522, - -1.87458, 2.56975, -0.35613, -0.82097, -0.73138, 0.64083, -0.82572, - 0.29276, -1.20876, 0.17214, -0.98109, -0.93785, 0.23218, 0.74644, - -0.55791, 0.7636, -2.14018, 0.41112, -1.00239, 0.71319, 0.23396, - 1.44287, 0.0825, 2.22372, -0.53487, 2.49612, -1.27615, -0.19306, - -0.08782, 1.84034, 0.32768, 0.56849, 0.04441, -0.65827, -0.23335, - -0.23385, -0.99781, 0.02744, 0.27255, -0.16298, 0.16596, -0.43781, - 0.45376, 1.616, 0.4351, -0.15598, 0.02663, 0.20136, 0.16456, - 2.32006, -0.17321, -0.29193, 0.91427, 1.58508, -0.4604, -1.24638, - -0.6782, -1.15898, 0.17969, 1.22766, 2.75006, 0.42565, -0.70687, - 0.32029, 1.39965, 0.56489, 0.71663, 0.96956, 0.11987, 0.13721, - -1.87023, -0.2301, 0.06482, 0.04463, 1.68528, 0.76126, 1.49722, - -0.38899, 0.55481, 1.01654, 1.39907, 0.52457, -1.95718, 0.11925, - -0.93854, -0.97164, 0.36083, 0.15714, -1.08132, 0.18311, 0.04688, - 0.84368, -0.93179, -0.51981, -0.10838, -0.21285, -1.5043, 0.11392, - 1.6267, -1.4901, 0.52821, -0.38828, 0.48409, -0.39124, 0.59059, - -0.0953, -0.2562, -1.30062, 0.56251, -0.36103, 0.99661, 1.14148, - 1.92076, -0.422, -0.98999, 2.53498, -0.86633, 0.25785, -1.04939, - 1.2423, 0.19481, -0.23041, 0.2357, -1.02525, 1.1626, -1.04002, - 0.22364, 0.28499, -0.34476, -0.67831, -0.5575, -0.91398, 1.37583, - 0.64503, -2.02422, -1.52848, -0.27042, 0.41021, -0.43892, 0.75682, - 0.18781, 0.92758, 0.5046, 0.73314, 0.03367, -0.27875, -0.63667, - 0.18394, 1.42434, 2.0077, -0.88286, -0.55983, -1.12401, 0.34193, - -2.51687, -1.04707, -0.6397, -0.70438, 0.59782, 0.74183, 0.31749, - -0.28442, -1.95803, -1.79381, 0.46461, -0.17142, 0.41181, 0.27836, - -0.02363, 0.93865}; - - EXPECT(migraphx::verify::verify_rms_range(qkv_rotary_vals, qkv_rotary_gold)); -} - -TEST_CASE(gqa_rotary_embedding_interleaved_test) -{ - migraphx::program p; - auto* mm = p.get_main_module(); - - const size_t batch_size = 2; - const size_t sequence_length = 1; - const size_t num_heads = 2; - const size_t kv_num_heads = 1; - const size_t head_size = 16; - const size_t max_cache_sequence_length = 8; - const size_t total_sequence_length = 8; - const size_t max_rotary_seq_length = max_cache_sequence_length; - const size_t rotary_dim = head_size / 2; - const bool interleaved = true; - - migraphx::shape qkv_shape{ - migraphx::shape::float_type, - {batch_size, num_heads + 2 * kv_num_heads, sequence_length, head_size}}; - migraphx::shape key_total_sequence_lens_shape(migraphx::shape::int32_type, {batch_size}); - migraphx::shape cos_cache_shape(migraphx::shape::float_type, - {max_rotary_seq_length, rotary_dim}); - migraphx::shape sin_cache_shape(migraphx::shape::float_type, - {max_rotary_seq_length, rotary_dim}); - - auto qkv = mm->add_parameter("qkv", qkv_shape); - auto ktsl = mm->add_parameter("ktsl", key_total_sequence_lens_shape); - auto cos_cache = mm->add_parameter("cos_cache", cos_cache_shape); - auto sin_cache = mm->add_parameter("sin_cache", sin_cache_shape); - - auto rotary = mm->add_instruction(migraphx::make_op("gqa_rotary_embedding", - {{"num_heads", num_heads}, - {"kv_num_heads", kv_num_heads}, - {"interleaved", interleaved}}), - qkv, - ktsl, - cos_cache, - sin_cache); - mm->add_return({rotary}); - - std::vector qkv_val{ - -0.65048f, -0.73475f, -1.16252f, 1.23505f, 0.30815f, 1.41725f, -0.99702f, 1.83288f, - 0.17508f, -0.44192f, -0.60220f, 0.57942f, -1.13502f, -0.21030f, -0.21183f, -0.59764f, - 0.03369f, -2.07573f, 0.26817f, 0.79531f, 0.82783f, 1.75045f, -0.13390f, 0.55881f, - -0.68510f, -0.22383f, -1.07129f, -0.37183f, 0.59560f, -0.24106f, 0.72188f, 0.96579f, - 1.10218f, 0.71842f, -0.13842f, -1.18598f, -0.97063f, 0.34577f, -0.09583f, -0.12853f, - -0.86645f, -0.41244f, 0.26598f, -0.01910f, -0.14762f, -0.01239f, -0.42813f, -0.25926f, - 0.51649f, -1.49558f, -0.64219f, -1.09694f, 0.17579f, -0.52930f, 0.99243f, -0.48142f, - 2.87901f, -0.34344f, 1.67369f, -1.01097f, 1.18906f, 0.79148f, 0.03848f, -0.08710f, - 1.06663f, 0.53670f, 0.18055f, 0.06149f, -1.27977f, 1.04653f, -0.09765f, 0.66428f, - 1.37472f, -0.90719f, 2.09439f, -0.54025f, 0.67836f, -0.12357f, -1.05392f, 1.01185f, - 1.20504f, -0.34832f, 1.38105f, -0.43522f, -0.08815f, -0.12122f, 0.66614f, -0.24025f, - 0.78835f, -1.20457f, -1.46959f, 1.04148f, -0.15917f, -0.11583f, 0.32542f, -0.82666f, - -1.77097f, -0.02825f, 0.03732f, 0.36346f, -1.30739f, 0.38991f, -0.11768f, 0.23001f, - -1.09870f, -0.80748f, 1.09745f, 0.33018f, -0.80205f, 0.10119f, -1.22517f, -0.54121f, - 0.50709f, -0.39303f, -0.94137f, 0.54072f, -0.17975f, 0.04328f, 0.37207f, 2.18807f, - -0.53601f, -0.44769f, 2.41322f, -1.96112f, -0.13698f, 0.57829f, -1.85719f, 0.77514f}; - - std::vector ktsl_val(key_total_sequence_lens_shape.elements(), total_sequence_length - 1); - - std::vector cos_cache_val{ - 0.15911f, -0.05395f, -0.59862f, 0.98028f, -0.55443f, -0.81122f, -0.47045f, 0.91331f, - -0.84324f, 0.98846f, -0.71105f, -0.79251f, -0.57709f, -0.82186f, -0.07956f, 0.78094f, - -0.65349f, 0.99776f, 0.91546f, -0.05603f, -0.37510f, 0.05328f, -0.90493f, -0.50526f, - -0.67425f, 0.99716f, -0.51956f, 0.52490f, -0.45744f, 0.85886f, 0.87088f, 0.20695f, - 0.10653f, -0.79970f, 0.00293f, -0.81702f, 0.38169f, 0.86232f, 0.99098f, -0.75832f, - 0.14130f, -0.97297f, 0.44135f, 0.97386f, -0.08836f, -0.99652f, -0.29374f, 0.67556f, - -0.24402f, 0.51681f, 0.01236f, -0.04189f, 0.89800f, -0.76051f, -0.50112f, -0.96679f, - -0.76057f, -0.74721f, -0.56633f, 0.05373f, 0.51381f, 0.76501f, 0.19223f, -0.36373f}; - - std::vector sin_cache_val{ - -0.98726f, 0.99854f, 0.80103f, 0.19763f, 0.83223f, -0.58474f, -0.88243f, 0.40727f, - -0.53754f, -0.15150f, 0.70314f, -0.60985f, 0.81668f, 0.56969f, -0.99683f, 0.62461f, - -0.75693f, -0.06687f, 0.40241f, 0.99843f, -0.92699f, -0.99858f, 0.42557f, -0.86297f, - -0.73850f, -0.07530f, -0.85444f, 0.85116f, 0.88924f, 0.51221f, -0.49150f, -0.97835f, - -0.99431f, -0.60041f, -1.00000f, 0.57661f, 0.92429f, 0.50637f, 0.13403f, 0.65189f, - -0.98997f, -0.23095f, 0.89734f, -0.22716f, 0.99609f, 0.08341f, 0.95588f, -0.73730f, - -0.96977f, 0.85610f, -0.99992f, -0.99912f, 0.44000f, -0.64933f, -0.86538f, 0.25557f, - 0.64926f, 0.66459f, 0.82418f, -0.99856f, 0.85790f, -0.64402f, -0.98135f, 0.93151f}; - - p.compile(migraphx::make_target("ref")); - migraphx::parameter_map pm; - pm["qkv"] = migraphx::argument(qkv_shape, qkv_val.data()); - pm["ktsl"] = migraphx::argument(key_total_sequence_lens_shape, ktsl_val.data()); - pm["cos_cache"] = migraphx::argument(cos_cache_shape, cos_cache_val.data()); - pm["sin_cache"] = migraphx::argument(sin_cache_shape, sin_cache_val.data()); - - auto qkv_rotary = p.eval(pm).front(); - std::vector qkv_rotary_vals(qkv_shape.elements()); - qkv_rotary.visit([&](auto output) { qkv_rotary_vals.assign(output.begin(), output.end()); }); - - std::vector qkv_rotary_gold{ - 0.971779, 0.136498, 0.0478448, -1.69544, -1.34258, -0.54866, 1.77667, - 1.09406, 0.469081, -0.0768618, -0.0875309, 0.831091, -0.424563, 1.07343, - 0.633757, 0.0200578, 1.32206, 1.60061, -0.728934, -0.416041, -1.91151, - -0.309051, 0.550811, 0.163732, -0.159987, -0.702753, -1.05901, 0.405479, - -0.122072, -0.630831, -1.16221, 0.321152, -1.30473, 0.169193, 0.891619, - 0.794184, 0.26472, -0.995794, -0.133494, 0.0887861, -0.0913584, -0.955243, - 0.191177, -0.185908, -0.0405359, 0.142485, 0.397227, -0.304507, 0.51649, - -1.49558, -0.64219, -1.09694, 0.17579, -0.5293, 0.99243, -0.48142, - 2.87901, -0.34344, 1.67369, -1.01097, 1.18906, 0.79148, 0.03848, - -0.0871, -1.1597, 0.284322, -0.175774, 0.0740458, -0.137757, -1.64744, - 0.658077, 0.133201, 1.48462, 0.713249, 1.2543, -1.76213, 0.00913571, - -0.689462, -0.559206, -1.34978, -0.690367, 1.04731, -0.742692, 1.24303, - 0.149829, -0.00400095, -0.204112, -0.678089, 1.43846, 0.0574054, -0.453517, - 1.74319, -0.144267, 0.133935, 0.651677, 0.603813, 1.36529, -1.12833, - -0.269438, -0.246778, 0.419058, -1.29834, 0.223356, 0.129869, 0.128214, - -1.35747, 1.0522, -0.454189, -0.0548753, 0.806544, 0.949774, -0.944404, - 0.50709, -0.39303, -0.94137, 0.54072, -0.17975, 0.04328, 0.37207, - 2.18807, -0.53601, -0.44769, 2.41322, -1.96112, -0.13698, 0.57829, - -1.85719, 0.77514}; - - EXPECT(migraphx::verify::verify_rms_range(qkv_rotary_vals, qkv_rotary_gold)); -} diff --git a/test/verify/test_group_query_attention.cpp b/test/verify/test_group_query_attention.cpp index 8324ae6a448..f38fa38cc1a 100644 --- a/test/verify/test_group_query_attention.cpp +++ b/test/verify/test_group_query_attention.cpp @@ -27,6 +27,22 @@ #include #include #include +#include +#include + +static migraphx::instruction_ref insert_rotary(migraphx::module& m, + bool interleaved, + std::size_t sequence_length, + std::vector args) +{ + // GQA position semantics: prefill starts from 0, decode uses seqlens_k + auto& pos_ids = args.at(1); + if(sequence_length > 1) + { + pos_ids = m.add_literal(migraphx::literal{migraphx::shape{pos_ids->get_shape().type(), {1}}, {0}}); + } + return migraphx::op::builder::add("rotary_embedding", m, args, {{"interleaved", interleaved}}).at(0); +} // NOLINTNEXTLINE(readability-function-size) static migraphx::program create_gqa_program(const size_t batch_size, @@ -79,36 +95,24 @@ static migraphx::program create_gqa_program(const size_t batch_size, transposed_qkv = mm->add_instruction( migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), transposed_qkv); - auto rotary_qkv = transposed_qkv; + auto qk = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {num_heads + kv_num_heads}}}), transposed_qkv); + auto cur_v = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {num_heads + kv_num_heads}}, {"ends", {num_heads + (2 * kv_num_heads)}}}), transposed_qkv); + if(do_rotary) { - std::vector rotary_inputs{ - transposed_qkv, slk, cos_cache, sin_cache}; - rotary_qkv = mm->add_instruction( - migraphx::make_op( - "gqa_rotary_embedding", - {{"kv_num_heads", kv_num_heads}, {"num_heads", num_heads}, {"interleaved", false}}), - rotary_inputs); + qk = insert_rotary(*mm, false, sequence_length, {qk, slk, cos_cache, sin_cache}); + if(test_rotary) { - mm->add_return({rotary_qkv}); + mm->add_return({qk}); return p; } } - auto rotary_k = mm->add_instruction( - migraphx::make_op( - "slice", - {{"axes", {1}}, {"starts", {num_heads}}, {"ends", {num_heads + kv_num_heads}}}), - rotary_qkv); - auto rotary_v = - mm->add_instruction(migraphx::make_op("slice", - {{"axes", {1}}, - {"starts", {num_heads + kv_num_heads}}, - {"ends", {num_heads + (2 * kv_num_heads)}}}), - rotary_qkv); - std::vector concat_k_inputs{rotary_k, slk, k}; - std::vector concat_v_inputs{rotary_v, slk, v}; + auto q = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {num_heads}}}), qk); + auto cur_k = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {num_heads}}, {"ends", {num_heads + kv_num_heads}}}), qk); + std::vector concat_k_inputs{cur_k, slk, k}; + std::vector concat_v_inputs{cur_v, slk, v}; k = mm->add_instruction( migraphx::make_op("concat_past_present", {{"kv_num_heads", kv_num_heads}}), @@ -131,10 +135,6 @@ static migraphx::program create_gqa_program(const size_t batch_size, auto past_sl = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", {batch_size, num_heads}}}), slk); - auto q = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {num_heads}}}), - rotary_qkv); - if(kv_num_heads_factor != 1) { auto kv_new_lens = kv_lens;