diff --git a/projects/composablekernel/CHANGELOG.md b/projects/composablekernel/CHANGELOG.md index f6812a8520f1..11ab9ca11669 100644 --- a/projects/composablekernel/CHANGELOG.md +++ b/projects/composablekernel/CHANGELOG.md @@ -31,6 +31,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## Composable Kernel 1.2.0 for ROCm 7.2.0 ### Added +* Added support for fp16 x fp8, bf16 x fp8, fp8 x fp16, and fp8 x bf16 for the V3 pipeline * Added tests for f8 x bf8 on CompV3, and f8 x bf8 with K_BlockSize 32 on CompV4 * Added CK-Tile dispatcher - a unified kernel dispatch, code generation and architecture-based kernel filtering system with with C++ and Python frontends starting with GEMM support. * Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle. diff --git a/projects/composablekernel/include/ck_tile/core/tensor/load_tile_transpose.hpp b/projects/composablekernel/include/ck_tile/core/tensor/load_tile_transpose.hpp index 5f73d4934a1f..3636447f82f3 100644 --- a/projects/composablekernel/include/ck_tile/core/tensor/load_tile_transpose.hpp +++ b/projects/composablekernel/include/ck_tile/core/tensor/load_tile_transpose.hpp @@ -14,6 +14,7 @@ #include "ck_tile/core/container/statically_indexed_array.hpp" #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" namespace ck_tile { @@ -529,4 +530,118 @@ load_tile_transpose(const tile_window_with_static_distribution, + typename = std::enable_if_t::distr_encoding_valid, + Policy>> +CK_TILE_DEVICE void load_tile_transpose_convert_with_offset( + DistributedTensor_& out_tensor, + const tile_window_with_static_distribution& __restrict__ tile_window, + const index_t offset, + number = {}) +{ + using SrcDataType = typename BottomTensorView_::DataType; + using DstDataType = typename DistributedTensor_::DataType; + + auto trans_tensor = tile_window.template load_transpose_with_offset(offset); + constexpr auto input_distr = TileDistribution_{}; + constexpr auto output_distr = typename DistributedTensor_::StaticTileDistribution{}; + + constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor(); + constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor(); + + constexpr auto y_in_lengths = to_sequence(y_in_desc.get_lengths()); + constexpr auto y_out_lengths = to_sequence(y_out_desc.get_lengths()); + + constexpr auto y_in_element_space_size = y_in_desc.get_element_space_size(); + constexpr auto y_out_element_space_size = y_out_desc.get_element_space_size(); + + // For mixed precision: element space size must be the same (total bytes match) + static_assert(y_in_element_space_size == y_out_element_space_size, + "For mixed precision transpose, input and output element space size must match!"); + + // Ensure total element counts are consistent and divisible by the input vector length. + constexpr index_t total_elems_in = + reduce_on_sequence(y_in_lengths, multiplies<>{}, number<1>{}); + constexpr index_t total_elems_out = + reduce_on_sequence(y_out_lengths, multiplies<>{}, number<1>{}); + static_assert(total_elems_in == total_elems_out, + "For mixed precision transpose, input/output element counts must match!"); + static_assert(total_elems_in % number{} == 0, + "Input vector length must evenly divide total elements."); + + constexpr index_t num_of_access = total_elems_in / number{}; + + // Read as input type, convert to output type + using SrcDataVec = ext_vector_t{}>; + using DstDataVec = ext_vector_t{}>; + static_for<0, num_of_access, 1>{}([&](auto i) { + static_assert(number{} == 8, "Only PassThroughPack8 is supported for now."); + const element_wise::PassThroughPack8 elementwise_op{}; + + elementwise_op(out_tensor.get_thread_buffer().template get_as()(i), + trans_tensor.get_thread_buffer().template get_as()[i]); + }); +} + +/** + * @brief Mixed-precision transpose load with zero offset. + * + * Convenience wrapper for load_tile_transpose_convert_with_offset with offset=0. + */ +template < + typename DistributedTensor_, + typename BottomTensorView_, + typename WindowLengths_, + typename TileDistribution_, + index_t NumCoord_, + index_t UnaryOpSize_, + typename Policy = DefaultTranspose, + typename = std::enable_if_t::distr_encoding_valid, + Policy>> +CK_TILE_DEVICE void load_tile_transpose_convert( + DistributedTensor_& out_tensor, + const tile_window_with_static_distribution& __restrict__ tile_window, + number = {}) +{ + load_tile_transpose_convert_with_offset(out_tensor, tile_window, 0, number{}); +} + } // namespace ck_tile diff --git a/projects/composablekernel/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/projects/composablekernel/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp index aa0f632c2169..a62bbe981cca 100644 --- a/projects/composablekernel/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp +++ b/projects/composablekernel/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp @@ -7,6 +7,7 @@ #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/batched_contraction.hpp b/projects/composablekernel/include/ck_tile/ops/batched_contraction.hpp index 9c90db67eddc..71919b61873d 100644 --- a/projects/composablekernel/include/ck_tile/ops/batched_contraction.hpp +++ b/projects/composablekernel/include/ck_tile/ops/batched_contraction.hpp @@ -5,6 +5,7 @@ #include "ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp" #include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp" #include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/batched_transpose.hpp b/projects/composablekernel/include/ck_tile/ops/batched_transpose.hpp index 9cac035c4457..924db5fb60e4 100644 --- a/projects/composablekernel/include/ck_tile/ops/batched_transpose.hpp +++ b/projects/composablekernel/include/ck_tile/ops/batched_transpose.hpp @@ -10,6 +10,7 @@ #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/common.hpp b/projects/composablekernel/include/ck_tile/ops/common.hpp index ad7da5c18339..0113d8c9a280 100644 --- a/projects/composablekernel/include/ck_tile/ops/common.hpp +++ b/projects/composablekernel/include/ck_tile/ops/common.hpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: MIT #pragma once +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/common/determine_warp_prec_type.hpp b/projects/composablekernel/include/ck_tile/ops/common/determine_warp_prec_type.hpp new file mode 100644 index 000000000000..acdc32008d60 --- /dev/null +++ b/projects/composablekernel/include/ck_tile/ops/common/determine_warp_prec_type.hpp @@ -0,0 +1,134 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" + +// DetermineWarpPrecType is a set of pattern-matching rules to determine the right precision types +// to use for the warp GEMM, given the precision types defined in the problem, and the compute data +// type. This gives rise to a type conversion: type conversions are sometimes needed to obtain +// a pair of types that are compatible with the hardware matrix operations available. A typical +// use case is mixed precision GEMMs. + +namespace ck_tile { +// For the most general case, default to no conversion. +template +struct DetermineWarpPrecType +{ + using a_prec_type = APrecType; + using b_prec_type = BPrecType; +}; + +// For pk_fp4_t x pk_fp4_t, keep pk_fp4_t +template <> +struct DetermineWarpPrecType +{ + using a_prec_type = ck_tile::pk_fp4_t; + using b_prec_type = ck_tile::pk_fp4_t; +}; + +// For pk_int4_t x B, use the B type. +template +struct DetermineWarpPrecType +{ + using a_prec_type = BPrecType; + using b_prec_type = BPrecType; +}; + +// For A x pk_int4_t, use the A type. +template +struct DetermineWarpPrecType +{ + using a_prec_type = APrecType; + using b_prec_type = APrecType; +}; + +// For pk_fp4_t x B, use the B type. +template +struct DetermineWarpPrecType +{ + using a_prec_type = BPrecType; + using b_prec_type = BPrecType; +}; + +// For A x pk_fp4_t, use the A type. +template +struct DetermineWarpPrecType +{ + using a_prec_type = APrecType; + using b_prec_type = APrecType; +}; + +// For pk_fp4_raw_t x B, use the B type. +template +struct DetermineWarpPrecType +{ + using a_prec_type = BPrecType; + using b_prec_type = BPrecType; +}; + +// For A x pk_fp4_raw_t, use the A type. +template +struct DetermineWarpPrecType +{ + using a_prec_type = APrecType; + using b_prec_type = APrecType; +}; + +// For fp8 x bf16, use fp8 +template <> +struct DetermineWarpPrecType +{ + using a_prec_type = ck_tile::fp8_t; + using b_prec_type = ck_tile::fp8_t; +}; + +// For bf16 x fp8, use bf16 +template <> +struct DetermineWarpPrecType +{ + using a_prec_type = ck_tile::bf16_t; + using b_prec_type = ck_tile::bf16_t; +}; + +// For bf8 x bf16, use bf8 +template <> +struct DetermineWarpPrecType +{ + using a_prec_type = ck_tile::bf8_t; + using b_prec_type = ck_tile::bf8_t; +}; + +// For bf16 x bf8, use bf16 +template <> +struct DetermineWarpPrecType +{ + using a_prec_type = ck_tile::bf16_t; + using b_prec_type = ck_tile::bf16_t; +}; + +// For fp8 x fp16, use fp8 +template <> +struct DetermineWarpPrecType +{ + using a_prec_type = ck_tile::fp8_t; + using b_prec_type = ck_tile::fp8_t; +}; + +// For fp16 x fp8, use fp16 +template <> +struct DetermineWarpPrecType +{ + using a_prec_type = ck_tile::half_t; + using b_prec_type = ck_tile::half_t; +}; + +// For fp16 x bf8, use fp16 +template <> +struct DetermineWarpPrecType +{ + using a_prec_type = ck_tile::half_t; + using b_prec_type = ck_tile::half_t; +}; +}; // namespace ck_tile diff --git a/projects/composablekernel/include/ck_tile/ops/common/load_and_convert_tile.hpp b/projects/composablekernel/include/ck_tile/ops/common/load_and_convert_tile.hpp index 0748c5fb49e8..6dc2ff35c0a1 100644 --- a/projects/composablekernel/include/ck_tile/ops/common/load_and_convert_tile.hpp +++ b/projects/composablekernel/include/ck_tile/ops/common/load_and_convert_tile.hpp @@ -10,15 +10,19 @@ namespace ck_tile { -template +template struct ConverterLoader { template - CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& dst, const WarpWindow& src) + CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& dst, const WarpWindow& src_window) { + static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t or pk_fp4_t"); static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; - const auto tmp = load_tile(src); + const auto src = load_tile(src_window); // NOTE: we rely on types packing neatly here using RawSrcType = typename SrcDataType::type; @@ -30,30 +34,56 @@ struct ConverterLoader const element_wise::PassThroughPack8 elementwise_op{}; elementwise_op(dst.get_thread_buffer().template get_as()(i), - tmp.get_thread_buffer().template get_as()[i]); + src.get_thread_buffer().template get_as()[i]); }); } + + template + CK_TILE_DEVICE static void load_with_type_convert(WarpTile& dst, const WarpWindow& src_window) + { + if constexpr(LoadTranspose) + { + if constexpr(std::is_same_v) + { + load_tile_transpose(dst, src_window); + } + else + { + load_tile_transpose_convert(dst, src_window, number{}); + } + } + else + { + if constexpr(std::is_same_v) + { + load_tile(dst, src_window); + } + else + { + auto tmp = load_tile(src_window); + sweep_tile([&](auto i) { + dst(i) = type_convert(type_convert(tmp(i))); + }); + } + } + } }; template -CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src) +CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src_window) { using SrcDataType = typename WarpWindow::Base::DataType; using DstDataType = typename WarpTile::DataType; if constexpr(is_packed_type_v && !is_packed_type_v) { - static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t or pk_fp4_t"); - ConverterLoader::load_interleaved_pk_type(dst, src); - } - else if constexpr(LoadTranspose) - { - load_tile_transpose(dst, src); + ConverterLoader:: + load_interleaved_pk_type(dst, src_window); } else { - load_tile(dst, src); + ConverterLoader:: + load_with_type_convert(dst, src_window); } } - } // namespace ck_tile diff --git a/projects/composablekernel/include/ck_tile/ops/elementwise.hpp b/projects/composablekernel/include/ck_tile/ops/elementwise.hpp index bc72f3b0ba1f..2c0ae4ad093f 100644 --- a/projects/composablekernel/include/ck_tile/ops/elementwise.hpp +++ b/projects/composablekernel/include/ck_tile/ops/elementwise.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp" #include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/projects/composablekernel/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index 4e971649d00c..1d1d1813dc93 100644 --- a/projects/composablekernel/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/projects/composablekernel/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -447,6 +447,15 @@ CK_TILE_HOST_DEVICE bf16x8_t fp8x8_to_bf16x8_scale(const fp8x8_t& src, const flo return y; } +CK_TILE_HOST_DEVICE fp8x8_t bf16x8_to_fp8x8_scale(const bf16x8_t& src, const float& scale) +{ + fp8x8_t y; + static_for<0, 8, 1>{}([&](auto i) { + y[i.value] = type_convert(type_convert(src[i.value]) * scale); + }); + return y; +} + CK_TILE_HOST_DEVICE fp16x8_t fp8x8_to_fp16x8_scale(const fp8x8_t& src, const float& scale) { fp16x8_t y; @@ -491,6 +500,15 @@ CK_TILE_HOST_DEVICE fp16x8_t fp8x8_to_fp16x8_scale(const fp8x8_t& src, const flo return y; } +CK_TILE_HOST_DEVICE fp8x8_t fp16x8_to_fp8x8_scale(const fp16x8_t& src, const float& scale) +{ + fp8x8_t y; + static_for<0, 8, 1>{}([&](auto i) { + y[i.value] = type_convert(type_convert(src[i.value]) * scale); + }); + return y; +} + CK_TILE_HOST_DEVICE fp16x8_t bf8x8_to_fp16x8_scale(const bf8x8_t& src, const float& scale) { fp16x8_t y; @@ -620,12 +638,32 @@ struct PassThroughPack8 template CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const; + CK_TILE_HOST_DEVICE constexpr void operator()(fp16x8_t& y, const fp8x8_t& x) const + { + y = fp8x8_to_fp16x8_scale(x, 1.0f); + } + + CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const fp16x8_t& x) const + { + y = fp16x8_to_fp8x8_scale(x, 1.0f); + } + CK_TILE_HOST_DEVICE constexpr void operator()(fp16x8_t& y, const pk_int4x4_t& x) const { y.lo = i4_to_half4(bit_cast(x)); y.hi = i4_to_half4(bit_cast(x) >> 8); } + CK_TILE_HOST_DEVICE constexpr void operator()(bf16x8_t& y, const fp8x8_t& x) const + { + y = fp8x8_to_bf16x8_scale(x, 1.0f); + } + + CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const bf16x8_t& x) const + { + y = bf16x8_to_fp8x8_scale(x, 1.0f); + } + CK_TILE_HOST_DEVICE constexpr void operator()(bf16x8_t& y, const pk_int4x4_t& x) const { y.lo = i4_to_bhalf4(bit_cast(x)); diff --git a/projects/composablekernel/include/ck_tile/ops/epilogue.hpp b/projects/composablekernel/include/ck_tile/ops/epilogue.hpp index b7a119d756cf..b97390f9e422 100644 --- a/projects/composablekernel/include/ck_tile/ops/epilogue.hpp +++ b/projects/composablekernel/include/ck_tile/ops/epilogue.hpp @@ -11,6 +11,7 @@ #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp" #include "ck_tile/ops/epilogue/permuten_epilogue.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/epilogue/chainer/cshuffle_epilogue_chainer_ops.hpp b/projects/composablekernel/include/ck_tile/ops/epilogue/chainer/cshuffle_epilogue_chainer_ops.hpp index e8bd8c0c7de2..3f1e4e47b48c 100644 --- a/projects/composablekernel/include/ck_tile/ops/epilogue/chainer/cshuffle_epilogue_chainer_ops.hpp +++ b/projects/composablekernel/include/ck_tile/ops/epilogue/chainer/cshuffle_epilogue_chainer_ops.hpp @@ -239,7 +239,7 @@ struct CShuffleEpilogueChainBaseOp * * @return The vector store size for C tensor. */ - CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC() + CK_TILE_DEVICE static constexpr index_t GetVectorSizeC() { if constexpr(FixedVectorSize) { @@ -268,7 +268,7 @@ struct CShuffleEpilogueChainBaseOp * @return The vector store size for Di tensor. */ template - CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number index) + CK_TILE_DEVICE static constexpr index_t GetVectorSizeD(number index) { constexpr index_t max_vector_size = 16; using DiDataType = remove_cvref_t>; @@ -354,7 +354,7 @@ struct CShuffleEpilogueChainBaseOp sequence>; template - CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor() + CK_TILE_DEVICE static constexpr auto MakeLdsBlockDescriptor() { // N is contiguous dimension if constexpr(std::is_same_v) @@ -409,7 +409,7 @@ struct CShuffleEpilogueChainBaseOp return block_dstr_encoding; } - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + CK_TILE_DEVICE static constexpr index_t GetSmemSize() { return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType); } diff --git a/projects/composablekernel/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/projects/composablekernel/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index b0e55d239f46..0e02a3efa0b1 100644 --- a/projects/composablekernel/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/projects/composablekernel/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -7,6 +7,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/utils.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" @@ -100,19 +101,13 @@ struct CShuffleEpilogue using ATypeToUse = std::conditional_t, tf32_t, - std::conditional_t || - std::is_same_v, - BDataTypeBuf, - ADataTypeBuf>>; + typename DetermineWarpPrecType::a_prec_type>; + // Used for weight-only quantization kernel, B would be dequantized to the same data type as A using BTypeToUse = std::conditional_t, tf32_t, - std::conditional_t || - std::is_same_v || - sizeof(BDataTypeBuf) < sizeof(ADataTypeBuf), - ADataTypeBuf, - BDataTypeBuf>>; + typename DetermineWarpPrecType::b_prec_type>; using ELayout = remove_cvref_t; using CDElementwise = remove_cvref_t; diff --git a/projects/composablekernel/include/ck_tile/ops/flatmm.hpp b/projects/composablekernel/include/ck_tile/ops/flatmm.hpp index e08fac48c7e9..2e71957ac774 100644 --- a/projects/composablekernel/include/ck_tile/ops/flatmm.hpp +++ b/projects/composablekernel/include/ck_tile/ops/flatmm.hpp @@ -21,6 +21,7 @@ #include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" #include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/fmha.hpp b/projects/composablekernel/include/ck_tile/ops/fmha.hpp index cf651312d911..a8a4ae86fdd5 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha.hpp @@ -56,13 +56,14 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/fused_moe.hpp b/projects/composablekernel/include/ck_tile/ops/fused_moe.hpp index 60f5bd1c4e35..2eb4abd64117 100644 --- a/projects/composablekernel/include/ck_tile/ops/fused_moe.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fused_moe.hpp @@ -14,6 +14,7 @@ #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/gemm.hpp b/projects/composablekernel/include/ck_tile/ops/gemm.hpp index 7c087e9186db..41f993bcbfaf 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm.hpp @@ -17,7 +17,10 @@ #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_trload_creg_v2_prefetch_n.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp" @@ -84,6 +87,7 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/projects/composablekernel/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 2b64f6e340c4..42e1f3764ee0 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -94,13 +94,8 @@ struct BlockUniversalGemmAsBsCr using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; - using ATypeToUse = - std::conditional_t, BDataType, ADataType>; - using BTypeToUse = std::conditional_t || - std::is_same_v || - sizeof(BDataType) < sizeof(ADataType), - ADataType, - BDataType>; + using ATypeToUse = typename DetermineWarpPrecType::a_prec_type; + using BTypeToUse = typename DetermineWarpPrecType::b_prec_type; using WarpGemm = remove_cvref_t; @@ -140,6 +135,7 @@ struct BlockUniversalGemmAsBsCr using I0 = number<0>; using I1 = number<1>; + template CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() { constexpr index_t KPerThread = Traits::KPerThread; @@ -159,12 +155,28 @@ struct BlockUniversalGemmAsBsCr tuple>, sequence<1, 2>, sequence<0, 0>>{}; - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + if constexpr(convert) + { + using Attr = typename WarpGemm::WarpGemmAttribute; + + constexpr auto NumAccessA = + Attr::AttrNumAccessAV * sizeof(ADataType) / sizeof(ComputeDataType); + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, + WarpGemm::WarpGemmAttribute::template get_awarp_dstr_encoding()); + + return a_block_dstr_encode; + } + else + { + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - return a_block_dstr_encode; + return a_block_dstr_encode; + } } + template CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() { constexpr index_t KPerThread = Traits::KPerThread; @@ -184,10 +196,24 @@ struct BlockUniversalGemmAsBsCr tuple>, sequence<1, 2>, sequence<0, 0>>{}; - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + if constexpr(convert) + { + using Attr = typename WarpGemm::WarpGemmAttribute; + constexpr auto NumAccessB = + Attr::AttrNumAccessBV * sizeof(BDataType) / sizeof(ComputeDataType); + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, + WarpGemm::WarpGemmAttribute::template get_bwarp_dstr_encoding()); + + return b_block_dstr_encode; + } + else + { + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - return b_block_dstr_encode; + return b_block_dstr_encode; + } } template diff --git a/projects/composablekernel/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/projects/composablekernel/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 83f7f8082410..f69d1aa2ba62 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -441,10 +441,10 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); // Tile distribution for load from lds - constexpr auto a_lds_load_tile_distr = - make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); - constexpr auto b_lds_load_tile_distr = - make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + constexpr auto a_lds_load_tile_distr = make_static_tile_distribution( + BlockGemm::template MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = make_static_tile_distribution( + BlockGemm::template MakeBBlockDistributionEncode()); // A DRAM tile window for load // A LDS tile window for store diff --git a/projects/composablekernel/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/projects/composablekernel/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index b4a8e9e8cb48..ab72a9b56077 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -321,15 +321,12 @@ struct UniversalGemmBasePolicy * @tparam Problem Gemm pipeline problem. * @return B tensor LDS block descriptor. */ - template + template > CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { - using BLayout = remove_cvref_t; - constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v; - using BDataType = std::conditional_t; - + using BLayout = remove_cvref_t; + using BDataType = OverrideBDataType; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; @@ -709,7 +706,7 @@ struct UniversalGemmBasePolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + CK_TILE_DEVICE static constexpr auto MakeADramTileDistribution() { constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; @@ -746,7 +743,7 @@ struct UniversalGemmBasePolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() + CK_TILE_DEVICE static constexpr auto MakeBDramTileDistribution() { constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; @@ -786,7 +783,7 @@ struct UniversalGemmBasePolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution() + CK_TILE_DEVICE static constexpr auto MakeShuffledARegTileDistribution() { using ALayout = remove_cvref_t< std::tuple_element_t{}, remove_cvref_t>>; @@ -807,7 +804,7 @@ struct UniversalGemmBasePolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegTileDistribution() + CK_TILE_DEVICE static constexpr auto MakeShuffledBRegTileDistribution() { using BLayout = remove_cvref_t< std::tuple_element_t{}, remove_cvref_t>>; @@ -828,7 +825,7 @@ struct UniversalGemmBasePolicy } template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackA() + CK_TILE_DEVICE static constexpr index_t GetSmemPackA() { using A = remove_cvref_t; using BlockGemm = remove_cvref_t())>; @@ -840,7 +837,7 @@ struct UniversalGemmBasePolicy } template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackB() + CK_TILE_DEVICE static constexpr index_t GetSmemPackB() { using B = remove_cvref_t; using BlockGemm = remove_cvref_t())>; @@ -852,7 +849,7 @@ struct UniversalGemmBasePolicy } template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() + CK_TILE_DEVICE static constexpr index_t GetSmemSizeA() { using ADataType = remove_cvref_t; constexpr auto APackedSize = numeric_traits::PackedSize; @@ -863,7 +860,7 @@ struct UniversalGemmBasePolicy } template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() + CK_TILE_DEVICE static constexpr index_t GetSmemSizeB() { constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v; using BDataType = std::conditional_t - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + CK_TILE_DEVICE static constexpr index_t GetSmemSize() { constexpr index_t smem_size_a = GetSmemSizeA(); constexpr index_t smem_size_b = GetSmemSizeB(); @@ -891,31 +888,31 @@ struct UniversalGemmPipelineAgBgCrPolicy : public UniversalGemmBasePolicy { template - CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + CK_TILE_DEVICE static constexpr auto GetBlockGemm() { using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - constexpr index_t vector_size = - DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType); + constexpr bool is_load_tr = is_a_load_tr || is_b_load_tr; + constexpr index_t vector_size = DS_READ_TR_SIZE() / sizeof(typename Problem::ADataType); constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size(); constexpr auto wg_attr_num_access = - !(is_a_load_tr || is_b_load_tr) ? WGAttrNumAccessEnum::Single - : vector_size == thread_elements ? WGAttrNumAccessEnum::Single - : vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double - : vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad - : WGAttrNumAccessEnum::Invalid; + !(is_load_tr) ? WGAttrNumAccessEnum::Single + : vector_size == thread_elements ? WGAttrNumAccessEnum::Single + : vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double + : vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad + : WGAttrNumAccessEnum::Invalid; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; - using ATypeToUse = if_select_t; - using BTypeToUse = std::conditional_t || - std::is_same_v || - sizeof(BDataType) < sizeof(ADataType), - ADataType, - BDataType>; + using ATypeToUse = typename DetermineWarpPrecType::a_prec_type; + using BTypeToUse = typename DetermineWarpPrecType::b_prec_type; + + // Use the packed data access when the data types have different sizes to make their data + // access patterns compatible when using transposed load. + constexpr bool use_pack_num_access = sizeof(ADataType) != sizeof(BDataType) && is_load_tr; using WarpGemm = WarpGemmDispatcher, @@ -927,7 +924,9 @@ struct UniversalGemmPipelineAgBgCrPolicy Problem::TransposeC, false, Problem::UseStructuredSparsity, - wg_attr_num_access>; + wg_attr_num_access, + wg_attr_num_access, + use_pack_num_access>; using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy, AttrNumAccess>>; -template +template using WarpGemmMfmaTf32Tf32F32M16N16K32 = WarpGemmImpl, - AttrNumAccess>>; + AttrNumAccessA, + AttrNumAccessB, + PackNumAccess>>; #endif // fp16 @@ -79,16 +83,24 @@ using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl< WarpGemmAttributeMfma>>; #if defined(__gfx950__) -template +template using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl< WarpGemmAttributeMfma, - AttrNumAccess>>; + AttrNumAccessA, + AttrNumAccessB, + PackNumAccess>>; #else -template +template using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl, 2, - AttrNumAccess>>; + AttrNumAccessA, + AttrNumAccessB, + PackNumAccess>>; #endif #if defined(__gfx950__) @@ -200,48 +212,65 @@ using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl< WarpGemmAttributeMfma>>; #if defined(__gfx950__) -template +template using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl< WarpGemmAttributeMfma, - AttrNumAccess>>; + AttrNumAccessA, + AttrNumAccessB, + PackNumAccess>>; #else -template +template using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl, 2, - AttrNumAccess>>; + AttrNumAccessA, + AttrNumAccessB, + PackNumAccess>>; #endif #if defined(__gfx950__) template + WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA, + bool PackNumAccess = AttrNumAccessA != AttrNumAccessB> using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccessA, - AttrNumAccessB>>; + AttrNumAccessB, + PackNumAccess>>; template + WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA, + bool PackNumAccess = AttrNumAccessA != AttrNumAccessB> using WarpGemmMfmaBf16Bf16F32M16N16K64 = WarpGemmImpl, 2, AttrNumAccessA, - AttrNumAccessB>>; + AttrNumAccessB, + PackNumAccess>>; #else template + WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA, + bool PackNumAccess = AttrNumAccessA != AttrNumAccessB> using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl, 2, - AttrNumAccessA>>; + AttrNumAccessA, + AttrNumAccessB, + PackNumAccess>>; template + WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA, + bool PackNumAccess = AttrNumAccessA != AttrNumAccessB> using WarpGemmMfmaBf16Bf16F32M16N16K64 = WarpGemmImpl, 4, AttrNumAccessA, - AttrNumAccessB>>; + AttrNumAccessB, + PackNumAccess>>; #endif using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl>; #endif +template using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl< - WarpGemmAttributeMfma>>; + WarpGemmAttributeMfma, + AttrNumAccessA, + AttrNumAccessB, + PackNumAccess>>; using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl< WarpGemmAttributeMfma>>; diff --git a/projects/composablekernel/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/projects/composablekernel/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index f79741ea9606..323222fff4d8 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -48,7 +48,8 @@ struct get_wgattr_num_access template + WGAttrNumAccessEnum AttrNumAccessB_ = AttrNumAccessA_, + bool PackNumAccess_ = AttrNumAccessA_ != AttrNumAccessB_> struct WarpGemmAttributeMfma { using Impl = remove_cvref_t; @@ -57,7 +58,7 @@ struct WarpGemmAttributeMfma static constexpr auto AttrNumAccessB = AttrNumAccessB_; static constexpr auto AttrNumAccessBV = get_wgattr_num_access::value; - static constexpr bool UsePackNumAccess = AttrNumAccessA != AttrNumAccessB; + static constexpr bool UsePackNumAccess = PackNumAccess_; using ADataType = typename Impl::ADataType; using BDataType = typename Impl::BDataType; @@ -132,8 +133,23 @@ struct WarpGemmAttributeMfma } } } - using AWarpDstrEncoding = decltype(get_warp_dstr_encoding()); - using BWarpDstrEncoding = decltype(get_warp_dstr_encoding()); + template + static constexpr auto get_awarp_dstr_encoding() + { + return get_warp_dstr_encoding(); + } + + template + static constexpr auto get_bwarp_dstr_encoding() + { + return get_warp_dstr_encoding(); + } + + template + using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding()); + + template + using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding()); using CWarpDstrEncoding = tile_distribution_encoding< sequence<>, @@ -187,7 +203,8 @@ struct WarpGemmAttributeMfma template + WGAttrNumAccessEnum AttrNumAccessB_ = AttrNumAccessA_, + bool PackNumAccess_ = AttrNumAccessA_ != AttrNumAccessB_> struct WarpGemmAttributeMfmaIterateK { static_assert(kKIter > 0, "wrong!"); @@ -198,7 +215,7 @@ struct WarpGemmAttributeMfmaIterateK static constexpr auto AttrNumAccessB = AttrNumAccessB_; static constexpr auto AttrNumAccessBV = get_wgattr_num_access::value; - static constexpr bool UsePackNumAccess = AttrNumAccessA != AttrNumAccessB; + static constexpr bool UsePackNumAccess = PackNumAccess_; using ADataType = typename Impl::ADataType; using BDataType = typename Impl::BDataType; @@ -297,6 +314,24 @@ struct WarpGemmAttributeMfmaIterateK } } + template + CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding() + { + return get_warp_dstr_encoding(); + } + + template + CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding() + { + return get_warp_dstr_encoding(); + } + CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding() { if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) @@ -335,14 +370,12 @@ struct WarpGemmAttributeMfmaIterateK } } - using AWarpDstrEncoding = decltype(get_warp_dstr_encoding()); - using BWarpDstrEncoding = decltype(get_warp_dstr_encoding()); + template + using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding()); + + template + using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding()); + using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding()); // c_vec += a_vec * b_vec @@ -421,10 +454,25 @@ struct WarpGemmAttributeMfmaTransposedCDistribution static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); - using AWarpDstrEncoding = - typename WarpGemmAttributeMfma::BWarpDstrEncoding; - using BWarpDstrEncoding = - typename WarpGemmAttributeMfma::AWarpDstrEncoding; + template + CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding() + { + return WarpGemmAttributeMfma::template get_bwarp_dstr_encoding< + AttrNumAccessV_>(); + } + + template + CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding() + { + return WarpGemmAttributeMfma::template get_awarp_dstr_encoding< + AttrNumAccessV_>(); + } + + template + using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding()); + + template + using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding()); using CWarpDstrEncoding = tile_distribution_encoding< sequence<>, @@ -491,6 +539,7 @@ struct WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); + template using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -499,6 +548,7 @@ struct WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB sequence<2>, sequence<1>>; #if 0 + template using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple>; #else // TODO: more test not only 32x32 + template using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution { - using Impl = remove_cvref_t; - static constexpr auto AttrNumAccess = AttrNumAccess_; + using Impl = remove_cvref_t; + static constexpr auto AttrNumAccess = AttrNumAccess_; + static constexpr auto AttrNumAccessV = static_cast(AttrNumAccess); // swap A and B using ADataType = typename Impl::BDataType; @@ -641,10 +693,12 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution } } - using AWarpDstrEncoding = - typename WarpGemmAttributeMfmaIterateK::BWarpDstrEncoding; - using BWarpDstrEncoding = - typename WarpGemmAttributeMfmaIterateK::AWarpDstrEncoding; + template + using AWarpDstrEncoding = typename WarpGemmAttributeMfmaIterateK:: + template BWarpDstrEncoding; + template + using BWarpDstrEncoding = typename WarpGemmAttributeMfmaIterateK:: + template AWarpDstrEncoding; using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding()); // c_vec += a_vec * b_vec @@ -723,6 +777,7 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); + template using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -731,6 +786,7 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB sequence<2>, sequence<1>>; #if 0 + template using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple>; #else // TODO: more test not only 32x32 + template using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence<1>>; + template using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence>, diff --git a/projects/composablekernel/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp b/projects/composablekernel/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp index ef31d06c9c27..4a275848b2b6 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp @@ -67,6 +67,10 @@ template struct WarpGemmAttributeWmma { using Impl = remove_cvref_t; + // AttrNumAccessV is required for compatibility with the block GEMM, and is currently ignored + // within WarpGemmAttributeWmma + static constexpr auto AttrNumAccess = WGAttrNumAccessEnum::Single; + static constexpr auto AttrNumAccessV = static_cast(AttrNumAccess); // When kTransC is true and A/B types differ, we need an impl with swapped types using TransposedImpl = @@ -99,8 +103,22 @@ struct WarpGemmAttributeWmma // 16 bit input, kAMLane = 16, kABK0PerLane = 4, kABKLane = 2, kABK1PerLane = 2 // 8 bit input, kAMLane = 16, kABK0PerLane = 2, kABKLane = 2, kABK1PerLane = 4 - using AWarpDstrEncoding = typename AWarpDstrEncodingTrait::type; - using BWarpDstrEncoding = typename BWarpDstrEncodingTrait::type; + template + static constexpr auto get_awarp_dstr_encoding() + { + return typename AWarpDstrEncodingTrait::type{}; + } + + template + static constexpr auto get_bwarp_dstr_encoding() + { + return typename BWarpDstrEncodingTrait::type{}; + } + + template + using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding()); + template + using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding()); // kCM0PerLane = 1, kCMLane = 2, kCM1PerLane = 2, kCNLane = 16 using CWarpDstrEncoding = diff --git a/projects/composablekernel/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/projects/composablekernel/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index f59bd61db76e..b4be74a7c848 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -27,7 +27,8 @@ template + WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA, + bool PackNumAccess = AttrNumAccessA != AttrNumAccessB> struct Dispatcher; // clang-format off @@ -51,6 +52,7 @@ template<> struct Dispatcher struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K32<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K32; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K32; }; template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K32; }; #endif // Note: For gfx11/gfx12 and other architectures that don't support tf32, @@ -63,6 +65,7 @@ template<> struct Dispatcher { using template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16; }; template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; }; template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; }; @@ -95,14 +98,17 @@ template<> struct Dispatcher { using template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16;}; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; }; @@ -121,7 +127,8 @@ template<> struct Dispatcher { u // fp8 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity -template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed; }; @@ -208,7 +215,8 @@ template + WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA, + bool PackNumAccess_ = AttrNumAccessA != AttrNumAccessB> using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< // AType, BType, @@ -220,6 +228,7 @@ using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< // SwizzleA, UseStructuredSparsity, AttrNumAccessA, - AttrNumAccessB>::Type; + AttrNumAccessB, + PackNumAccess_>::Type; } // namespace ck_tile diff --git a/projects/composablekernel/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp b/projects/composablekernel/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp index ca7c32b6af55..5ff0660f49c4 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp @@ -25,8 +25,8 @@ struct WarpGemmImpl using BDataType = typename WarpGemmAttribute::BDataType; using CDataType = typename WarpGemmAttribute::CDataType; - using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding; - using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding; + using AWarpDstrEncoding = typename WarpGemmAttribute::template AWarpDstrEncoding<>; + using BWarpDstrEncoding = typename WarpGemmAttribute::template BWarpDstrEncoding<>; using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding; using AWarpDstr = remove_cvref_t; diff --git a/projects/composablekernel/include/ck_tile/ops/gemm_mx.hpp b/projects/composablekernel/include/ck_tile/ops/gemm_mx.hpp index 29fccf8057b9..fd04c1d0234a 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm_mx.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm_mx.hpp @@ -6,6 +6,7 @@ #include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp" #include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp" #include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/gemm_quant.hpp b/projects/composablekernel/include/ck_tile/ops/gemm_quant.hpp index 5b2ce7ff1915..907aabc0f1e7 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm_quant.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm_quant.hpp @@ -33,6 +33,7 @@ #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp" #include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/grouped_convolution.hpp b/projects/composablekernel/include/ck_tile/ops/grouped_convolution.hpp index 5bc4f0c6a042..1a115f9489e6 100644 --- a/projects/composablekernel/include/ck_tile/ops/grouped_convolution.hpp +++ b/projects/composablekernel/include/ck_tile/ops/grouped_convolution.hpp @@ -12,6 +12,7 @@ #include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/image_to_column.hpp b/projects/composablekernel/include/ck_tile/ops/image_to_column.hpp index 07d99890869e..faa165a8b001 100644 --- a/projects/composablekernel/include/ck_tile/ops/image_to_column.hpp +++ b/projects/composablekernel/include/ck_tile/ops/image_to_column.hpp @@ -5,6 +5,7 @@ #include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp" #include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp" #include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/layernorm2d.hpp b/projects/composablekernel/include/ck_tile/ops/layernorm2d.hpp index 8f9ab205ac45..2266c138729a 100644 --- a/projects/composablekernel/include/ck_tile/ops/layernorm2d.hpp +++ b/projects/composablekernel/include/ck_tile/ops/layernorm2d.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/norm_reduce.hpp b/projects/composablekernel/include/ck_tile/ops/norm_reduce.hpp index eae0ea14a337..9f572ff5cbb2 100644 --- a/projects/composablekernel/include/ck_tile/ops/norm_reduce.hpp +++ b/projects/composablekernel/include/ck_tile/ops/norm_reduce.hpp @@ -5,6 +5,7 @@ #include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp" #include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp" #include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/permute.hpp b/projects/composablekernel/include/ck_tile/ops/permute.hpp index 4d37f4fbc12a..c7747a67e70c 100644 --- a/projects/composablekernel/include/ck_tile/ops/permute.hpp +++ b/projects/composablekernel/include/ck_tile/ops/permute.hpp @@ -4,6 +4,7 @@ #include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp" #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/pooling.hpp b/projects/composablekernel/include/ck_tile/ops/pooling.hpp index faa77d53273e..43b24c7f8caf 100644 --- a/projects/composablekernel/include/ck_tile/ops/pooling.hpp +++ b/projects/composablekernel/include/ck_tile/ops/pooling.hpp @@ -6,6 +6,7 @@ #include "ck_tile/ops/pooling/pipeline/pool_default_policy.hpp" #include "ck_tile/ops/pooling/pipeline/pool_problem.hpp" #include "ck_tile/ops/pooling/pipeline/pool_shape.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/reduce.hpp b/projects/composablekernel/include/ck_tile/ops/reduce.hpp index b5e53283e485..e680d2574538 100644 --- a/projects/composablekernel/include/ck_tile/ops/reduce.hpp +++ b/projects/composablekernel/include/ck_tile/ops/reduce.hpp @@ -13,6 +13,7 @@ #include "ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp" #include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp" #include "ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/rmsnorm2d.hpp b/projects/composablekernel/include/ck_tile/ops/rmsnorm2d.hpp index f271be50068c..7ee67334d6b9 100644 --- a/projects/composablekernel/include/ck_tile/ops/rmsnorm2d.hpp +++ b/projects/composablekernel/include/ck_tile/ops/rmsnorm2d.hpp @@ -9,6 +9,7 @@ #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/smoothquant.hpp b/projects/composablekernel/include/ck_tile/ops/smoothquant.hpp index 4c2fe9bee434..ad984c033f02 100644 --- a/projects/composablekernel/include/ck_tile/ops/smoothquant.hpp +++ b/projects/composablekernel/include/ck_tile/ops/smoothquant.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp" #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp" #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/softmax.hpp b/projects/composablekernel/include/ck_tile/ops/softmax.hpp index c79ba06abfea..b810a57dda0d 100644 --- a/projects/composablekernel/include/ck_tile/ops/softmax.hpp +++ b/projects/composablekernel/include/ck_tile/ops/softmax.hpp @@ -4,6 +4,7 @@ #include "ck_tile/ops/softmax/block/block_softmax_2d.hpp" #include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/sparse_attn.hpp b/projects/composablekernel/include/ck_tile/ops/sparse_attn.hpp index c7c4171874aa..56e074794ddf 100644 --- a/projects/composablekernel/include/ck_tile/ops/sparse_attn.hpp +++ b/projects/composablekernel/include/ck_tile/ops/sparse_attn.hpp @@ -6,6 +6,7 @@ #include "ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp" #include "ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp" #include "ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/topk.hpp b/projects/composablekernel/include/ck_tile/ops/topk.hpp index 474ba932270c..13d818174e62 100644 --- a/projects/composablekernel/include/ck_tile/ops/topk.hpp +++ b/projects/composablekernel/include/ck_tile/ops/topk.hpp @@ -4,6 +4,7 @@ #include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp" #include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/topk_softmax.hpp b/projects/composablekernel/include/ck_tile/ops/topk_softmax.hpp index 066fbf5feea2..b7219511faa3 100644 --- a/projects/composablekernel/include/ck_tile/ops/topk_softmax.hpp +++ b/projects/composablekernel/include/ck_tile/ops/topk_softmax.hpp @@ -6,6 +6,7 @@ #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp" #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp" #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/projects/composablekernel/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/projects/composablekernel/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index 47a0267020e7..02c5792834d9 100644 --- a/projects/composablekernel/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/projects/composablekernel/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -83,41 +83,60 @@ using KernelTypesMemWmma = ::testing::Types< using KernelTypesCompV3 = ::testing::Types< std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Row, Row, F16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Row, Row, BF16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Row, Row, F8, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Row, Row, F8, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, BF8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, F16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, F16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, BF16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, F8, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, F8, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, BF8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Row, Row, F16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Row, Row, BF16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Row, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Row, Row, F8, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Row, Row, F8, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, BF8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, F16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, F16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, BF16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, F8, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, F8, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,