diff --git a/projects/composablekernel/include/ck_tile/core/tensor/load_tile_transpose.hpp b/projects/composablekernel/include/ck_tile/core/tensor/load_tile_transpose.hpp index 5f73d4934a1f..3636447f82f3 100644 --- a/projects/composablekernel/include/ck_tile/core/tensor/load_tile_transpose.hpp +++ b/projects/composablekernel/include/ck_tile/core/tensor/load_tile_transpose.hpp @@ -14,6 +14,7 @@ #include "ck_tile/core/container/statically_indexed_array.hpp" #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" namespace ck_tile { @@ -529,4 +530,118 @@ load_tile_transpose(const tile_window_with_static_distribution, + typename = std::enable_if_t::distr_encoding_valid, + Policy>> +CK_TILE_DEVICE void load_tile_transpose_convert_with_offset( + DistributedTensor_& out_tensor, + const tile_window_with_static_distribution& __restrict__ tile_window, + const index_t offset, + number = {}) +{ + using SrcDataType = typename BottomTensorView_::DataType; + using DstDataType = typename DistributedTensor_::DataType; + + auto trans_tensor = tile_window.template load_transpose_with_offset(offset); + constexpr auto input_distr = TileDistribution_{}; + constexpr auto output_distr = typename DistributedTensor_::StaticTileDistribution{}; + + constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor(); + constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor(); + + constexpr auto y_in_lengths = to_sequence(y_in_desc.get_lengths()); + constexpr auto y_out_lengths = to_sequence(y_out_desc.get_lengths()); + + constexpr auto y_in_element_space_size = y_in_desc.get_element_space_size(); + constexpr auto y_out_element_space_size = y_out_desc.get_element_space_size(); + + // For mixed precision: element space size must be the same (total bytes match) + static_assert(y_in_element_space_size == y_out_element_space_size, + "For mixed precision transpose, input and output element space size must match!"); + + // Ensure total element counts are consistent and divisible by the input vector length. + constexpr index_t total_elems_in = + reduce_on_sequence(y_in_lengths, multiplies<>{}, number<1>{}); + constexpr index_t total_elems_out = + reduce_on_sequence(y_out_lengths, multiplies<>{}, number<1>{}); + static_assert(total_elems_in == total_elems_out, + "For mixed precision transpose, input/output element counts must match!"); + static_assert(total_elems_in % number{} == 0, + "Input vector length must evenly divide total elements."); + + constexpr index_t num_of_access = total_elems_in / number{}; + + // Read as input type, convert to output type + using SrcDataVec = ext_vector_t{}>; + using DstDataVec = ext_vector_t{}>; + static_for<0, num_of_access, 1>{}([&](auto i) { + static_assert(number{} == 8, "Only PassThroughPack8 is supported for now."); + const element_wise::PassThroughPack8 elementwise_op{}; + + elementwise_op(out_tensor.get_thread_buffer().template get_as()(i), + trans_tensor.get_thread_buffer().template get_as()[i]); + }); +} + +/** + * @brief Mixed-precision transpose load with zero offset. + * + * Convenience wrapper for load_tile_transpose_convert_with_offset with offset=0. + */ +template < + typename DistributedTensor_, + typename BottomTensorView_, + typename WindowLengths_, + typename TileDistribution_, + index_t NumCoord_, + index_t UnaryOpSize_, + typename Policy = DefaultTranspose, + typename = std::enable_if_t::distr_encoding_valid, + Policy>> +CK_TILE_DEVICE void load_tile_transpose_convert( + DistributedTensor_& out_tensor, + const tile_window_with_static_distribution& __restrict__ tile_window, + number = {}) +{ + load_tile_transpose_convert_with_offset(out_tensor, tile_window, 0, number{}); +} + } // namespace ck_tile diff --git a/projects/composablekernel/include/ck_tile/host.hpp b/projects/composablekernel/include/ck_tile/host.hpp index 995d8545364f..fbe083335dd8 100644 --- a/projects/composablekernel/include/ck_tile/host.hpp +++ b/projects/composablekernel/include/ck_tile/host.hpp @@ -17,6 +17,7 @@ #include "ck_tile/host/joinable_thread.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/permute_pk_int4.hpp" +#include "ck_tile/host/print_matrix.hpp" #include "ck_tile/host/ranges.hpp" #include "ck_tile/host/reference/reference_batched_contraction.hpp" #include "ck_tile/host/reference/reference_batched_dropout.hpp" diff --git a/projects/composablekernel/include/ck_tile/host/fill.hpp b/projects/composablekernel/include/ck_tile/host/fill.hpp index bddc0ae2d2ca..847ab8e34823 100644 --- a/projects/composablekernel/include/ck_tile/host/fill.hpp +++ b/projects/composablekernel/include/ck_tile/host/fill.hpp @@ -461,6 +461,44 @@ struct FillConstant } }; +template +struct FillIdentity +{ + std::size_t rows_{0}; + std::size_t cols_{0}; + T zero_{type_convert(0)}; + T one_{type_convert(1)}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + if(rows_ == 0 || cols_ == 0 || first == last) + return; + + const auto total = static_cast(std::distance(first, last)); + if(total < rows_ * cols_) + { + throw std::runtime_error("FillIdentity requires range size >= rows_ * cols_."); + } + + std::fill(first, first + rows_ * cols_, zero_); + + const auto min_dim = std::min(rows_, cols_); + for(std::size_t i = 0; i < min_dim; ++i) + *(first + i * cols_ + i) = one_; + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + //---------------------------------------------------------------------------------------------- /// @brief Transforms given input to fit 2:4 structured sparsity pattern so /// every subgroup of 4 elements contain at most 2 non-zero elements diff --git a/projects/composablekernel/include/ck_tile/host/print_matrix.hpp b/projects/composablekernel/include/ck_tile/host/print_matrix.hpp new file mode 100644 index 000000000000..8cf5588e619a --- /dev/null +++ b/projects/composablekernel/include/ck_tile/host/print_matrix.hpp @@ -0,0 +1,33 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +// Helper to print matrix (for debugging) +template +void print_matrix(const ck_tile::HostTensor& mat, + const std::string& name = "Matrix", + const int width = 3, + const int precision = 3) +{ + const auto lens = mat.get_lengths(); + assert(len(lens) == 2); + const ck_tile::index_t rows = lens[0]; + const ck_tile::index_t cols = lens[1]; + const ck_tile::index_t limit = 32; + + std::cout << name << " (" << rows << "×" << cols << "):\n"; + for(ck_tile::index_t i = 0; i < std::min(rows, ck_tile::index_t(limit)); ++i) + { + for(ck_tile::index_t j = 0; j < std::min(cols, ck_tile::index_t(limit)); ++j) + { + std::cout << std::setw(width) << std::setprecision(precision) + << ck_tile::type_convert(mat(i, j)) << " "; + } + if(cols > limit) + std::cout << "..."; + std::cout << "\n"; + } + if(rows > limit) + std::cout << "...\n"; + std::cout << "\n"; +} diff --git a/projects/composablekernel/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/projects/composablekernel/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp index aa0f632c2169..a62bbe981cca 100644 --- a/projects/composablekernel/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp +++ b/projects/composablekernel/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp @@ -7,6 +7,7 @@ #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/batched_contraction.hpp b/projects/composablekernel/include/ck_tile/ops/batched_contraction.hpp index 9c90db67eddc..71919b61873d 100644 --- a/projects/composablekernel/include/ck_tile/ops/batched_contraction.hpp +++ b/projects/composablekernel/include/ck_tile/ops/batched_contraction.hpp @@ -5,6 +5,7 @@ #include "ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp" #include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp" #include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/batched_transpose.hpp b/projects/composablekernel/include/ck_tile/ops/batched_transpose.hpp index 9cac035c4457..924db5fb60e4 100644 --- a/projects/composablekernel/include/ck_tile/ops/batched_transpose.hpp +++ b/projects/composablekernel/include/ck_tile/ops/batched_transpose.hpp @@ -10,6 +10,7 @@ #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/common.hpp b/projects/composablekernel/include/ck_tile/ops/common.hpp index ad7da5c18339..0113d8c9a280 100644 --- a/projects/composablekernel/include/ck_tile/ops/common.hpp +++ b/projects/composablekernel/include/ck_tile/ops/common.hpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: MIT #pragma once +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/common/determine_warp_prec_type.hpp b/projects/composablekernel/include/ck_tile/ops/common/determine_warp_prec_type.hpp new file mode 100644 index 000000000000..1fa105937f9d --- /dev/null +++ b/projects/composablekernel/include/ck_tile/ops/common/determine_warp_prec_type.hpp @@ -0,0 +1,134 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" + +// DetermineWarpPrecType is a set of pattern-matching rules to determine the right precision types +// to use for the warp GEMM, given the precision types defined in the problem, and the compute data +// type. This gives rise to a type conversion: type conversions are sometimes needed to obtain +// a pair of types that are compatible with the hardware matrix operations available. A typical +// use case is mixed precision GEMMs. + +namespace ck_tile { +// For the most general case, default to no conversion. +template +struct DetermineWarpPrecType +{ + using a_prec_type = APrecType; + using b_prec_type = BPrecType; +}; + +// Use tf32_t if compute type if tf32_t +template +struct DetermineWarpPrecType +{ + using a_prec_type = ck_tile::tf32_t; + using b_prec_type = ck_tile::tf32_t; +}; + +// For pk_fp4_t x pk_fp4_t, keep pk_fp4_t +template +struct DetermineWarpPrecType +{ + using a_prec_type = ck_tile::pk_fp4_t; + using b_prec_type = ck_tile::pk_fp4_t; +}; + +// For pk_int4_t x B, use the B type. +template +struct DetermineWarpPrecType +{ + using a_prec_type = BPrecType; + using b_prec_type = BPrecType; +}; + +// For A x pk_int4_t, use the A type. +template +struct DetermineWarpPrecType +{ + using a_prec_type = APrecType; + using b_prec_type = APrecType; +}; + +// For pk_fp4_t x B, use the B type. +template +struct DetermineWarpPrecType +{ + using a_prec_type = BPrecType; + using b_prec_type = BPrecType; +}; + +// For A x pk_fp4_t, use the A type. +template +struct DetermineWarpPrecType +{ + using a_prec_type = APrecType; + using b_prec_type = APrecType; +}; + +// For pk_fp4_raw_t x B, use the B type. +template +struct DetermineWarpPrecType +{ + using a_prec_type = BPrecType; + using b_prec_type = BPrecType; +}; + +// For A x pk_fp4_raw_t, use the A type. +template +struct DetermineWarpPrecType +{ + using a_prec_type = APrecType; + using b_prec_type = APrecType; +}; + +// For fp8 x bf16, use fp8 +template +struct DetermineWarpPrecType +{ + using a_prec_type = ck_tile::fp8_t; + using b_prec_type = ck_tile::fp8_t; +}; + +// For bf16 x fp8, use bf16 +template +struct DetermineWarpPrecType +{ + using a_prec_type = ck_tile::bf16_t; + using b_prec_type = ck_tile::bf16_t; +}; + +// For bf8 x bf16, use bf8 +template +struct DetermineWarpPrecType +{ + using a_prec_type = ck_tile::bf8_t; + using b_prec_type = ck_tile::bf8_t; +}; + +// For bf16 x bf8, use bf16 +template +struct DetermineWarpPrecType +{ + using a_prec_type = ck_tile::bf16_t; + using b_prec_type = ck_tile::bf16_t; +}; + +// For fp8 x fp16, use fp8 +template +struct DetermineWarpPrecType +{ + using a_prec_type = ck_tile::fp8_t; + using b_prec_type = ck_tile::fp8_t; +}; + +// For fp16 x fp8, use fp16 +template +struct DetermineWarpPrecType +{ + using a_prec_type = ck_tile::half_t; + using b_prec_type = ck_tile::half_t; +}; +}; // namespace ck_tile diff --git a/projects/composablekernel/include/ck_tile/ops/common/load_and_convert_tile.hpp b/projects/composablekernel/include/ck_tile/ops/common/load_and_convert_tile.hpp index 0748c5fb49e8..6dc2ff35c0a1 100644 --- a/projects/composablekernel/include/ck_tile/ops/common/load_and_convert_tile.hpp +++ b/projects/composablekernel/include/ck_tile/ops/common/load_and_convert_tile.hpp @@ -10,15 +10,19 @@ namespace ck_tile { -template +template struct ConverterLoader { template - CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& dst, const WarpWindow& src) + CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& dst, const WarpWindow& src_window) { + static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t or pk_fp4_t"); static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; - const auto tmp = load_tile(src); + const auto src = load_tile(src_window); // NOTE: we rely on types packing neatly here using RawSrcType = typename SrcDataType::type; @@ -30,30 +34,56 @@ struct ConverterLoader const element_wise::PassThroughPack8 elementwise_op{}; elementwise_op(dst.get_thread_buffer().template get_as()(i), - tmp.get_thread_buffer().template get_as()[i]); + src.get_thread_buffer().template get_as()[i]); }); } + + template + CK_TILE_DEVICE static void load_with_type_convert(WarpTile& dst, const WarpWindow& src_window) + { + if constexpr(LoadTranspose) + { + if constexpr(std::is_same_v) + { + load_tile_transpose(dst, src_window); + } + else + { + load_tile_transpose_convert(dst, src_window, number{}); + } + } + else + { + if constexpr(std::is_same_v) + { + load_tile(dst, src_window); + } + else + { + auto tmp = load_tile(src_window); + sweep_tile([&](auto i) { + dst(i) = type_convert(type_convert(tmp(i))); + }); + } + } + } }; template -CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src) +CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src_window) { using SrcDataType = typename WarpWindow::Base::DataType; using DstDataType = typename WarpTile::DataType; if constexpr(is_packed_type_v && !is_packed_type_v) { - static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t or pk_fp4_t"); - ConverterLoader::load_interleaved_pk_type(dst, src); - } - else if constexpr(LoadTranspose) - { - load_tile_transpose(dst, src); + ConverterLoader:: + load_interleaved_pk_type(dst, src_window); } else { - load_tile(dst, src); + ConverterLoader:: + load_with_type_convert(dst, src_window); } } - } // namespace ck_tile diff --git a/projects/composablekernel/include/ck_tile/ops/elementwise.hpp b/projects/composablekernel/include/ck_tile/ops/elementwise.hpp index bc72f3b0ba1f..2c0ae4ad093f 100644 --- a/projects/composablekernel/include/ck_tile/ops/elementwise.hpp +++ b/projects/composablekernel/include/ck_tile/ops/elementwise.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp" #include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/projects/composablekernel/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index 4ad699629c02..eea1eb6acd62 100644 --- a/projects/composablekernel/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/projects/composablekernel/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -447,6 +447,15 @@ CK_TILE_HOST_DEVICE bf16x8_t fp8x8_to_bf16x8_scale(const fp8x8_t& src, const flo return y; } +CK_TILE_HOST_DEVICE fp8x8_t bf16x8_to_fp8x8_scale(const bf16x8_t& src, const float& scale) +{ + fp8x8_t y; + static_for<0, 8, 1>{}([&](auto i) { + y[i.value] = type_convert(type_convert(src[i.value]) * scale); + }); + return y; +} + CK_TILE_HOST_DEVICE fp16x8_t fp8x8_to_fp16x8_scale(const fp8x8_t& src, const float& scale) { fp16x8_t y; @@ -491,6 +500,15 @@ CK_TILE_HOST_DEVICE fp16x8_t fp8x8_to_fp16x8_scale(const fp8x8_t& src, const flo return y; } +CK_TILE_HOST_DEVICE fp8x8_t fp16x8_to_fp8x8_scale(const fp16x8_t& src, const float& scale) +{ + fp8x8_t y; + static_for<0, 8, 1>{}([&](auto i) { + y[i.value] = type_convert(type_convert(src[i.value]) * scale); + }); + return y; +} + CK_TILE_HOST_DEVICE fp16x8_t bf8x8_to_fp16x8_scale(const bf8x8_t& src, const float& scale) { fp16x8_t y; @@ -620,12 +638,32 @@ struct PassThroughPack8 template CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const; + CK_TILE_HOST_DEVICE constexpr void operator()(fp16x8_t& y, const fp8x8_t& x) const + { + y = fp8x8_to_fp16x8_scale(x, 1.0f); + } + + CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const fp16x8_t& x) const + { + y = fp16x8_to_fp8x8_scale(x, 1.0f); + } + CK_TILE_HOST_DEVICE constexpr void operator()(fp16x8_t& y, const pk_int4x4_t& x) const { y.lo = i4_to_half4(bit_cast(x)); y.hi = i4_to_half4(bit_cast(x) >> 8); } + CK_TILE_HOST_DEVICE constexpr void operator()(bf16x8_t& y, const fp8x8_t& x) const + { + y = fp8x8_to_bf16x8_scale(x, 1.0f); + } + + CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const bf16x8_t& x) const + { + y = bf16x8_to_fp8x8_scale(x, 1.0f); + } + CK_TILE_HOST_DEVICE constexpr void operator()(bf16x8_t& y, const pk_int4x4_t& x) const { y.lo = i4_to_bhalf4(bit_cast(x)); diff --git a/projects/composablekernel/include/ck_tile/ops/epilogue.hpp b/projects/composablekernel/include/ck_tile/ops/epilogue.hpp index d1b38a8bca6f..0eb9e59e723c 100644 --- a/projects/composablekernel/include/ck_tile/ops/epilogue.hpp +++ b/projects/composablekernel/include/ck_tile/ops/epilogue.hpp @@ -10,6 +10,7 @@ #include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp" #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/projects/composablekernel/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index fba831e20534..d4fc127a3b76 100644 --- a/projects/composablekernel/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/projects/composablekernel/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -7,6 +7,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/utils.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" @@ -100,21 +101,10 @@ struct CShuffleEpilogue // For warp gemm selection: use tf32_t if compute type was tf32_t // For pk_int4/pk_fp4: use the other data type using ATypeToUse = - std::conditional_t, - tf32_t, - std::conditional_t || - std::is_same_v, - BDataTypeBuf, - ADataTypeBuf>>; + typename DetermineWarpPrecType::a_prec_type; // Used for weight-only quantization kernel, B would be dequantized to the same data type as A using BTypeToUse = - std::conditional_t, - tf32_t, - std::conditional_t || - std::is_same_v || - sizeof(BDataTypeBuf) < sizeof(ADataTypeBuf), - ADataTypeBuf, - BDataTypeBuf>>; + typename DetermineWarpPrecType::b_prec_type; using ELayout = remove_cvref_t; using CDElementwise = remove_cvref_t; diff --git a/projects/composablekernel/include/ck_tile/ops/flatmm.hpp b/projects/composablekernel/include/ck_tile/ops/flatmm.hpp index e08fac48c7e9..2e71957ac774 100644 --- a/projects/composablekernel/include/ck_tile/ops/flatmm.hpp +++ b/projects/composablekernel/include/ck_tile/ops/flatmm.hpp @@ -21,6 +21,7 @@ #include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" #include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/fmha.hpp b/projects/composablekernel/include/ck_tile/ops/fmha.hpp index 8a5d77bf462e..633a36b7ff68 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha.hpp @@ -61,6 +61,7 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/fused_moe.hpp b/projects/composablekernel/include/ck_tile/ops/fused_moe.hpp index 60f5bd1c4e35..2eb4abd64117 100644 --- a/projects/composablekernel/include/ck_tile/ops/fused_moe.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fused_moe.hpp @@ -14,6 +14,7 @@ #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/gemm.hpp b/projects/composablekernel/include/ck_tile/ops/gemm.hpp index 7c087e9186db..8f2125cb6748 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm.hpp @@ -84,6 +84,7 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/projects/composablekernel/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index f7f5cd33dbb9..c5ed8931e7ee 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -95,12 +95,9 @@ struct BlockUniversalGemmAsBsCr using CDataType = remove_cvref_t; using ATypeToUse = - std::conditional_t, BDataType, ADataType>; - using BTypeToUse = std::conditional_t || - std::is_same_v || - sizeof(BDataType) < sizeof(ADataType), - ADataType, - BDataType>; + typename DetermineWarpPrecType::a_prec_type; + using BTypeToUse = + typename DetermineWarpPrecType::b_prec_type; using WarpGemm = remove_cvref_t; diff --git a/projects/composablekernel/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/projects/composablekernel/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index b4a8e9e8cb48..0f5b12fc65dc 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -910,12 +910,10 @@ struct UniversalGemmPipelineAgBgCrPolicy using BDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; - using ATypeToUse = if_select_t; - using BTypeToUse = std::conditional_t || - std::is_same_v || - sizeof(BDataType) < sizeof(ADataType), - ADataType, - BDataType>; + using ATypeToUse = + typename DetermineWarpPrecType::a_prec_type; + using BTypeToUse = + typename DetermineWarpPrecType::b_prec_type; using WarpGemm = WarpGemmDispatcher, diff --git a/projects/composablekernel/include/ck_tile/ops/gemm_mx.hpp b/projects/composablekernel/include/ck_tile/ops/gemm_mx.hpp index 29fccf8057b9..fd04c1d0234a 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm_mx.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm_mx.hpp @@ -6,6 +6,7 @@ #include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp" #include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp" #include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/gemm_quant.hpp b/projects/composablekernel/include/ck_tile/ops/gemm_quant.hpp index 5b2ce7ff1915..907aabc0f1e7 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm_quant.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm_quant.hpp @@ -33,6 +33,7 @@ #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp" #include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/grouped_convolution.hpp b/projects/composablekernel/include/ck_tile/ops/grouped_convolution.hpp index 5bc4f0c6a042..1a115f9489e6 100644 --- a/projects/composablekernel/include/ck_tile/ops/grouped_convolution.hpp +++ b/projects/composablekernel/include/ck_tile/ops/grouped_convolution.hpp @@ -12,6 +12,7 @@ #include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/image_to_column.hpp b/projects/composablekernel/include/ck_tile/ops/image_to_column.hpp index 07d99890869e..faa165a8b001 100644 --- a/projects/composablekernel/include/ck_tile/ops/image_to_column.hpp +++ b/projects/composablekernel/include/ck_tile/ops/image_to_column.hpp @@ -5,6 +5,7 @@ #include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp" #include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp" #include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/layernorm2d.hpp b/projects/composablekernel/include/ck_tile/ops/layernorm2d.hpp index 8f9ab205ac45..2266c138729a 100644 --- a/projects/composablekernel/include/ck_tile/ops/layernorm2d.hpp +++ b/projects/composablekernel/include/ck_tile/ops/layernorm2d.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/norm_reduce.hpp b/projects/composablekernel/include/ck_tile/ops/norm_reduce.hpp index eae0ea14a337..9f572ff5cbb2 100644 --- a/projects/composablekernel/include/ck_tile/ops/norm_reduce.hpp +++ b/projects/composablekernel/include/ck_tile/ops/norm_reduce.hpp @@ -5,6 +5,7 @@ #include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp" #include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp" #include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/permute.hpp b/projects/composablekernel/include/ck_tile/ops/permute.hpp index 4d37f4fbc12a..c7747a67e70c 100644 --- a/projects/composablekernel/include/ck_tile/ops/permute.hpp +++ b/projects/composablekernel/include/ck_tile/ops/permute.hpp @@ -4,6 +4,7 @@ #include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp" #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/pooling.hpp b/projects/composablekernel/include/ck_tile/ops/pooling.hpp index faa77d53273e..43b24c7f8caf 100644 --- a/projects/composablekernel/include/ck_tile/ops/pooling.hpp +++ b/projects/composablekernel/include/ck_tile/ops/pooling.hpp @@ -6,6 +6,7 @@ #include "ck_tile/ops/pooling/pipeline/pool_default_policy.hpp" #include "ck_tile/ops/pooling/pipeline/pool_problem.hpp" #include "ck_tile/ops/pooling/pipeline/pool_shape.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/reduce.hpp b/projects/composablekernel/include/ck_tile/ops/reduce.hpp index b5e53283e485..e680d2574538 100644 --- a/projects/composablekernel/include/ck_tile/ops/reduce.hpp +++ b/projects/composablekernel/include/ck_tile/ops/reduce.hpp @@ -13,6 +13,7 @@ #include "ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp" #include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp" #include "ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/rmsnorm2d.hpp b/projects/composablekernel/include/ck_tile/ops/rmsnorm2d.hpp index f271be50068c..7ee67334d6b9 100644 --- a/projects/composablekernel/include/ck_tile/ops/rmsnorm2d.hpp +++ b/projects/composablekernel/include/ck_tile/ops/rmsnorm2d.hpp @@ -9,6 +9,7 @@ #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/smoothquant.hpp b/projects/composablekernel/include/ck_tile/ops/smoothquant.hpp index 4c2fe9bee434..ad984c033f02 100644 --- a/projects/composablekernel/include/ck_tile/ops/smoothquant.hpp +++ b/projects/composablekernel/include/ck_tile/ops/smoothquant.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp" #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp" #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/softmax.hpp b/projects/composablekernel/include/ck_tile/ops/softmax.hpp index c79ba06abfea..b810a57dda0d 100644 --- a/projects/composablekernel/include/ck_tile/ops/softmax.hpp +++ b/projects/composablekernel/include/ck_tile/ops/softmax.hpp @@ -4,6 +4,7 @@ #include "ck_tile/ops/softmax/block/block_softmax_2d.hpp" #include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/sparse_attn.hpp b/projects/composablekernel/include/ck_tile/ops/sparse_attn.hpp index c7c4171874aa..56e074794ddf 100644 --- a/projects/composablekernel/include/ck_tile/ops/sparse_attn.hpp +++ b/projects/composablekernel/include/ck_tile/ops/sparse_attn.hpp @@ -6,6 +6,7 @@ #include "ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp" #include "ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp" #include "ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/topk.hpp b/projects/composablekernel/include/ck_tile/ops/topk.hpp index 474ba932270c..13d818174e62 100644 --- a/projects/composablekernel/include/ck_tile/ops/topk.hpp +++ b/projects/composablekernel/include/ck_tile/ops/topk.hpp @@ -4,6 +4,7 @@ #include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp" #include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/topk_softmax.hpp b/projects/composablekernel/include/ck_tile/ops/topk_softmax.hpp index 066fbf5feea2..b7219511faa3 100644 --- a/projects/composablekernel/include/ck_tile/ops/topk_softmax.hpp +++ b/projects/composablekernel/include/ck_tile/ops/topk_softmax.hpp @@ -6,6 +6,7 @@ #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp" #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp" #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/test/ck_tile/CMakeLists.txt b/projects/composablekernel/test/ck_tile/CMakeLists.txt index 63bf1746437c..f59646140a23 100644 --- a/projects/composablekernel/test/ck_tile/CMakeLists.txt +++ b/projects/composablekernel/test/ck_tile/CMakeLists.txt @@ -69,4 +69,5 @@ add_subdirectory(fmha) add_subdirectory(gemm_tile_engine) add_subdirectory(pooling) add_subdirectory(grouped_conv) +add_subdirectory(load_and_convert_tile) add_subdirectory(pooling_tile_engine) diff --git a/projects/composablekernel/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/projects/composablekernel/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index 47a0267020e7..98d365f552ac 100644 --- a/projects/composablekernel/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/projects/composablekernel/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -93,11 +93,15 @@ using KernelTypesCompV3 = ::testing::Types< std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, BF8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, F16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, BF16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, F8, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, F8, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, @@ -113,11 +117,15 @@ using KernelTypesCompV3 = ::testing::Types< std::tuple< Col, Row, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, BF8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, F16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, BF16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, F8, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, F8, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, diff --git a/projects/composablekernel/test/ck_tile/load_and_convert_tile/CMakeLists.txt b/projects/composablekernel/test/ck_tile/load_and_convert_tile/CMakeLists.txt new file mode 100644 index 000000000000..83edc3248ea0 --- /dev/null +++ b/projects/composablekernel/test/ck_tile/load_and_convert_tile/CMakeLists.txt @@ -0,0 +1,18 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +if(GPU_TARGETS MATCHES "gfx9") + set(LOAD_TILE_TRANSPOSE_COMPILE_OPTIONS) + if(CK_USE_OCP_FP8) + list(APPEND LOAD_TILE_TRANSPOSE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) + endif() + list(APPEND LOAD_TILE_TRANSPOSE_COMPILE_OPTIONS) + + add_gtest_executable(test_load_and_convert_tile_no_transpose test_load_and_convert_tile_no_transpose.cpp) + target_compile_options(test_load_and_convert_tile_no_transpose PRIVATE ${LOAD_TILE_TRANSPOSE_COMPILE_OPTIONS}) + + if(GPU_TARGETS MATCHES "gfx950") + add_gtest_executable(test_load_and_convert_tile_transposed test_load_and_convert_tile_transposed.cpp) + target_compile_options(test_load_and_convert_tile_transposed PRIVATE ${LOAD_TILE_TRANSPOSE_COMPILE_OPTIONS}) + endif() +endif() diff --git a/projects/composablekernel/test/ck_tile/load_and_convert_tile/kernel.hpp b/projects/composablekernel/test/ck_tile/load_and_convert_tile/kernel.hpp new file mode 100644 index 000000000000..9689038a4d84 --- /dev/null +++ b/projects/composablekernel/test/ck_tile/load_and_convert_tile/kernel.hpp @@ -0,0 +1,227 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +namespace ck_tile { + +template +struct LoadAndConvertShape +{ + static constexpr index_t Block_M = BlockTile::at(number<0>{}); + static constexpr index_t Block_N = BlockTile::at(number<1>{}); + + static constexpr index_t Warp_M = WarpTile::at(number<0>{}); + static constexpr index_t Warp_N = WarpTile::at(number<1>{}); + + static constexpr index_t Vector_N = Vector::at(number<1>{}); + + static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{}); + static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{}); + + static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); + static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); + + static constexpr index_t BlockSize = + ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{}); +}; + +template +struct LoadAndConvertProblem +{ + using XDataType = remove_cvref_t; + using YDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; + using LoadTranspose = remove_cvref_t; +}; + +template +struct LoadAndConvertPolicy +{ + using Problem = ck_tile::remove_cvref_t; + using XDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + using LoadTranspose = ck_tile::remove_cvref_t; + + template + CK_TILE_DEVICE static constexpr auto get_warp_dstr_encoding() + { + using S = typename Problem::BlockShape; + + if constexpr(NumAccess == 1) + return tile_distribution_encoding< + sequence<1>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + else + return tile_distribution_encoding< + sequence<1>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<1, 2>>{}; + } + + template + CK_TILE_DEVICE static constexpr auto GetVectorSize() + { + return DS_READ_TR_SIZE() / sizeof(DataType); + } + + template + CK_TILE_DEVICE static constexpr auto MakeDRAMDistribution() + { + using S = typename Problem::BlockShape; + + constexpr index_t thread_elements = S::Warp_M * S::Warp_N / get_warp_size(); + constexpr index_t NumAccess = + LoadTranspose::value ? thread_elements / GetVectorSize() : 1; + + constexpr auto a_block_outer_dstr_encode = tile_distribution_encoding< + sequence<>, + tuple, + sequence<>>, + tuple>, + tuple>, + sequence<1>, + sequence<0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encode, get_warp_dstr_encoding()); + + return make_static_tile_distribution(a_block_dstr_encode); + } + + template + CK_TILE_DEVICE static constexpr auto MakeDRAMTransposedDistribution() + { + return make_static_tile_distribution( + typename InputTileDistributionTraits< + typename decltype(MakeDRAMDistribution())::DstrEncode, + DataType>::TransposedDstrEncode{}); + } +}; + +template +struct LoadAndConvertKernel +{ + using Problem = ck_tile::remove_cvref_t; + using XDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + using LoadTranspose = ck_tile::remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; + + CK_TILE_HOST static auto BlockSize() + { + if(ck_tile::is_wave32()) + { + return kBlockSize / 2; + } + else + { + return kBlockSize; + } + } + + CK_TILE_DEVICE void operator()(const XDataType* a, YDataType* c, index_t M, index_t N) const + { + using S = typename Problem::BlockShape; + + constexpr auto block_dims = make_tuple(S::Block_M, S::Block_N); + constexpr auto block_strides = make_tuple(1, S::Block_M); + + const index_t m_block_base = get_block_id() * S::Block_M; + + // LDS buffer + __shared__ XDataType a_lds[S::Block_M * S::Block_N]; + + auto a_lds_view = make_naive_tensor_view( + a_lds, block_dims, block_strides, number<1>{}, number<1>{}); + + auto a_block_lds_write_window = make_tile_window(a_lds_view, block_dims, {0, 0}); + + auto a_block_lds_read_window = [&] { + if constexpr(LoadTranspose::value) + { + constexpr auto block_dims_t = make_tuple(S::Block_N, S::Block_M); + constexpr auto block_strides_t = make_tuple(S::Block_M, 1); + + auto view = make_naive_tensor_view( + a_lds, + block_dims_t, + block_strides_t, + number()>{}, + number<1>{}); + + return make_tile_window( + view, + block_dims_t, + {0, 0}, + Policy::template MakeDRAMTransposedDistribution()); + } + else + { + return make_tile_window( + a_lds_view, + block_dims, + {0, 0}, + Policy::template MakeDRAMDistribution()); + } + }(); + + // Input tensor + const auto a_tensor = make_naive_tensor_view( + a, make_tuple(M, N), make_tuple(1, M), number<1>{}, number<1>{}); + + auto a_block_window = + make_tile_window(a_tensor, + block_dims, + {m_block_base, 0}, + Policy::template MakeDRAMDistribution()); + + // Output tensor + const auto c_tensor = make_naive_tensor_view( + c, make_tuple(M, N), make_tuple(1, M), number<1>{}, number<1>{}); + + auto c_block_window = + make_tile_window(c_tensor, + block_dims, + {m_block_base, 0}, + Policy::template MakeDRAMDistribution()); + + const index_t num_n_loops = N / S::Block_N; + for(index_t n_iter = 0; n_iter < num_n_loops; ++n_iter) + { + auto dram_tile = load_tile(a_block_window); + store_tile(a_block_lds_write_window, dram_tile); + block_sync_lds(); + + decltype(load_tile(c_block_window)) c_tile; + load_and_convert_tile<8, LoadTranspose::value>(c_tile, a_block_lds_read_window); + store_tile(c_block_window, c_tile); + + if(n_iter < num_n_loops - 1) + { + move_tile_window(a_block_window, {0, S::Block_N}); + move_tile_window(c_block_window, {0, S::Block_N}); + } + } + } +}; + +} // namespace ck_tile diff --git a/projects/composablekernel/test/ck_tile/load_and_convert_tile/test_load_and_convert_tile_no_transpose.cpp b/projects/composablekernel/test/ck_tile/load_and_convert_tile/test_load_and_convert_tile_no_transpose.cpp new file mode 100644 index 000000000000..bffb830a7267 --- /dev/null +++ b/projects/composablekernel/test/ck_tile/load_and_convert_tile/test_load_and_convert_tile_no_transpose.cpp @@ -0,0 +1,16 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_load_and_convert_tile_util.hpp" + +using TestTypes = ::testing::Types, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple>; + +TYPED_TEST_SUITE(TestLoadAndConvert, TestTypes); + +TYPED_TEST(TestLoadAndConvert, TestNoTranspose) { this->RunTest(); } diff --git a/projects/composablekernel/test/ck_tile/load_and_convert_tile/test_load_and_convert_tile_transposed.cpp b/projects/composablekernel/test/ck_tile/load_and_convert_tile/test_load_and_convert_tile_transposed.cpp new file mode 100644 index 000000000000..07717902b1e3 --- /dev/null +++ b/projects/composablekernel/test/ck_tile/load_and_convert_tile/test_load_and_convert_tile_transposed.cpp @@ -0,0 +1,16 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_load_and_convert_tile_util.hpp" + +using TestTypes = ::testing::Types, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple>; + +TYPED_TEST_SUITE(TestLoadAndConvert, TestTypes); + +TYPED_TEST(TestLoadAndConvert, TestTransposed) { this->RunTest(); } diff --git a/projects/composablekernel/test/ck_tile/load_and_convert_tile/test_load_and_convert_tile_util.hpp b/projects/composablekernel/test/ck_tile/load_and_convert_tile/test_load_and_convert_tile_util.hpp new file mode 100644 index 000000000000..b2984ece2b2f --- /dev/null +++ b/projects/composablekernel/test/ck_tile/load_and_convert_tile/test_load_and_convert_tile_util.hpp @@ -0,0 +1,101 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/common.hpp" +#include "kernel.hpp" + +// Enum struct specifying what kind of test matrix to use +enum struct TestMatrixType +{ + MonotonicSequence = 0, + Identity = 1, + UniformDistribution = 2 +}; + +static constexpr auto matrix_type = TestMatrixType::UniformDistribution; + +#define PRINT_MATRICES 0 + +template +class TestLoadAndConvert : public ::testing::Test +{ + public: + using XDataType = std::tuple_element_t<0, Tuple>; + using YDataType = std::tuple_element_t<1, Tuple>; + using LoadTranspose = std::tuple_element_t<2, Tuple>; + + protected: + void RunTest() + { + constexpr ck_tile::index_t M = 256; + constexpr ck_tile::index_t N = 256; + + ck_tile::HostTensor h_a({M, N}); + ck_tile::HostTensor h_c({M, N}); + + if constexpr(matrix_type == TestMatrixType::MonotonicSequence) + { + ck_tile::HostTensor h_a_tmp({M, N}); + ck_tile::FillMonotonicSeq{0.0, 0.1}(h_a_tmp); + ck_tile::reference_unary_elementwise( + h_a_tmp, h_a, [](const auto& x) { return x; }); + } + else if constexpr(matrix_type == TestMatrixType::Identity) + { + ck_tile::FillIdentity{M, N}(h_a); + } + else + { + ck_tile::FillUniformDistributionIntegerValue{-5.0, 5.0}(h_a); + } + + ck_tile::DeviceMem d_a(h_a.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_c(h_c.get_element_space_size_in_bytes()); + + d_a.ToDevice(h_a.data()); + + using BlockWarps = ck_tile::sequence<4, 4>; + using BlockTile = ck_tile::sequence<512, 32>; + using WarpTile = ck_tile::sequence<64, 8>; + using Vector = ck_tile::sequence<1, 8>; + + using Shape = ck_tile::LoadAndConvertShape; + using Problem = ck_tile::LoadAndConvertProblem; + using Policy = ck_tile::LoadAndConvertPolicy; + using Kernel = ck_tile::LoadAndConvertKernel; + + const ck_tile::index_t block_size = Kernel::BlockSize(); + const ck_tile::index_t grid_size = ck_tile::integer_divide_ceil(M, Shape::Block_M) * + ck_tile::integer_divide_ceil(N, Shape::Block_N); + + launch_kernel(ck_tile::stream_config{nullptr, true}, + make_kernel<1>(Kernel{}, + dim3(grid_size), + dim3(block_size), + 0, + static_cast(d_a.GetDeviceBuffer()), + static_cast(d_c.GetDeviceBuffer()), + M, + N)); + + ck_tile::hip_check_error(hipDeviceSynchronize()); + d_c.FromDevice(h_c.data()); + ck_tile::HostTensor h_a_ref({M, N}); + ck_tile::reference_unary_elementwise( + h_a, h_a_ref, [](const auto& x) { return x; }); + bool pass = ck_tile::check_err(h_c, h_a_ref); + +#if PRINT_MATRICES + auto [width, precision] = matrix_type == TestMatrixType::MonotonicSequence + ? std::make_pair(3, 3) + : std::make_pair(2, 6); + print_matrix(h_a, "Matrix A", width, precision); + print_matrix(h_c, "Matrix C", width, precision); +#endif + + EXPECT_TRUE(pass); + } +};