Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
8f20464
Rename the parameters of load_interleaved_pk_type and load_and_conver…
SamiAario-AMD Jan 12, 2026
f2ebac4
Add load_tile_transpose_convert for mixed precision transpose loading
SamiAario-AMD Jan 26, 2026
e4bc421
Add and use load_with_type_convert
SamiAario-AMD Nov 12, 2025
952d945
Introduce DetermineWarpPrecType for determining warp GEMM precision t…
SamiAario-AMD Oct 9, 2025
47427e4
Add functionality and tests for bf16 x fp8 and fp8 x bf16
SamiAario-AMD Oct 9, 2025
1831bc0
Add functionality and tests for fp16 x fp8 and fp8 x fp16
SamiAario-AMD Nov 12, 2025
9769281
Make some V3 pipeline functions device-only
SamiAario-AMD Apr 29, 2026
5224025
Make some C shuffle epilogue functions device-only
SamiAario-AMD Apr 30, 2026
359db8e
Add AttrNumAccessV_ as a template parameter to WarpGemmAttributeMfma:…
SamiAario-AMD Mar 2, 2026
24df65e
Refactor type conversions out of MakeBLdsBlockDescriptor
SamiAario-AMD Dec 18, 2025
4ad7eb7
Add a changelog entry
SamiAario-AMD Jan 28, 2026
4d2f508
PackNumAccess WIP
SamiAario-AMD Apr 29, 2026
2b33aa8
Add support and tests for mixed precision fp16-x-fp8 with transposed …
SamiAario-AMD Apr 30, 2026
5d0ac84
Add support and tests for mixed precision fp8-x-fp16 with transposed …
SamiAario-AMD Apr 30, 2026
8e73da5
Add support and tests for mixed precision bf16-x-fp8 with transposed …
SamiAario-AMD May 4, 2026
a41d430
Add support and tests for mixed precision fp8-x-bf16 with transposed …
SamiAario-AMD May 4, 2026
c872c6b
Remod changes
SamiAario-AMD May 5, 2026
753b287
Add missing warp GEMM instances requested by test_tile_gemm_quant_bqu…
SamiAario-AMD May 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions projects/composablekernel/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
## Composable Kernel 1.2.0 for ROCm 7.2.0

### Added
* Added support for fp16 x fp8, bf16 x fp8, fp8 x fp16, and fp8 x bf16 for the V3 pipeline
* Added tests for f8 x bf8 on CompV3, and f8 x bf8 with K_BlockSize 32 on CompV4
* Added CK-Tile dispatcher - a unified kernel dispatch, code generation and architecture-based kernel filtering system with C++ and Python frontends starting with GEMM support.
* Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"

namespace ck_tile {

Expand Down Expand Up @@ -529,4 +530,118 @@ load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_
return out_tensor;
}

/**
 * @brief Mixed-precision transpose load: converts input data type to output data type while
 * transposing.
 *
 * This function enables transposing from one data type (e.g., fp8) to another (e.g., fp16) in a
 * single operation. The input tile distribution encoding must be valid for the input data type,
 * and the output distribution will be generated based on the output data type.
 *
 * @tparam DistributedTensor_ The output tensor type with desired output data type.
 * @tparam BottomTensorView_ The input tensor view (may have different data type than output).
 * @tparam WindowLengths_ The type representing the window lengths.
 * @tparam TileDistribution_ The type representing the tile distribution for input.
 * @tparam NumCoord_ The number of coordinates (dimensions).
 * @tparam UnaryOpSize_ Vector width of the per-access conversion (currently must be 8).
 * @tparam Policy The transpose policy (should validate against input type).
 *
 * @param out_tensor Destination distributed tensor; receives converted elements.
 * @param tile_window Source window; its data type may differ from the destination's.
 * @param offset Element offset passed through to the underlying transpose load.
 *
 * @note
 * - Input and output must have compatible element space sizes (total byte count per Y-space).
 * - Type conversion is performed element-by-element during the copy.
 * - The validation uses the input data type for quad pattern checking.
 * - The output distribution is generated based on the output data type.
 */
template <
    typename DistributedTensor_,
    typename BottomTensorView_,
    typename WindowLengths_,
    typename TileDistribution_,
    index_t NumCoord_,
    index_t UnaryOpSize_,
    typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
    typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
                                                          typename BottomTensorView_::DataType,
                                                          Policy>::distr_encoding_valid,
                                Policy>>
CK_TILE_DEVICE void load_tile_transpose_convert_with_offset(
    DistributedTensor_& out_tensor,
    const tile_window_with_static_distribution<BottomTensorView_,
                                               WindowLengths_,
                                               TileDistribution_,
                                               NumCoord_>& __restrict__ tile_window,
    const index_t offset,
    number<UnaryOpSize_> = {})
{
    using SrcDataType = typename BottomTensorView_::DataType;
    using DstDataType = typename DistributedTensor_::DataType;

    // Only the 8-wide pack-through conversion is implemented; check this constraint up front
    // (it is independent of the loop index) rather than inside the copy loop.
    static_assert(UnaryOpSize_ == 8, "Only PassThroughPack8 is supported for now.");

    auto trans_tensor = tile_window.template load_transpose_with_offset<Policy>(offset);
    constexpr auto input_distr  = TileDistribution_{};
    constexpr auto output_distr = typename DistributedTensor_::StaticTileDistribution{};

    constexpr auto y_in_desc  = input_distr.get_ys_to_d_descriptor();
    constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();

    constexpr auto y_in_lengths  = to_sequence(y_in_desc.get_lengths());
    constexpr auto y_out_lengths = to_sequence(y_out_desc.get_lengths());

    constexpr auto y_in_element_space_size  = y_in_desc.get_element_space_size();
    constexpr auto y_out_element_space_size = y_out_desc.get_element_space_size();

    // For mixed precision: element space size must be the same (total bytes match)
    static_assert(y_in_element_space_size == y_out_element_space_size,
                  "For mixed precision transpose, input and output element space size must match!");

    // Ensure total element counts are consistent and divisible by the input vector length.
    constexpr index_t total_elems_in =
        reduce_on_sequence(y_in_lengths, multiplies<>{}, number<1>{});
    constexpr index_t total_elems_out =
        reduce_on_sequence(y_out_lengths, multiplies<>{}, number<1>{});
    static_assert(total_elems_in == total_elems_out,
                  "For mixed precision transpose, input/output element counts must match!");
    static_assert(total_elems_in % number<UnaryOpSize_>{} == 0,
                  "Input vector length must evenly divide total elements.");

    constexpr index_t num_of_access = total_elems_in / number<UnaryOpSize_>{};

    // Read as input type, convert to output type
    using SrcDataVec = ext_vector_t<SrcDataType, number<UnaryOpSize_>{}>;
    using DstDataVec = ext_vector_t<DstDataType, number<UnaryOpSize_>{}>;

    // The conversion functor is stateless; construct it once instead of per unrolled iteration.
    const element_wise::PassThroughPack8 elementwise_op{};
    static_for<0, num_of_access, 1>{}([&](auto i) {
        elementwise_op(out_tensor.get_thread_buffer().template get_as<DstDataVec>()(i),
                       trans_tensor.get_thread_buffer().template get_as<SrcDataVec>()[i]);
    });
}

/**
 * @brief Mixed-precision transpose load starting at the window origin.
 *
 * Thin convenience wrapper: identical to load_tile_transpose_convert_with_offset,
 * but always anchors the load at element offset zero.
 */
template <
    typename DistributedTensor_,
    typename BottomTensorView_,
    typename WindowLengths_,
    typename TileDistribution_,
    index_t NumCoord_,
    index_t UnaryOpSize_,
    typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
    typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
                                                          typename BottomTensorView_::DataType,
                                                          Policy>::distr_encoding_valid,
                                Policy>>
CK_TILE_DEVICE void load_tile_transpose_convert(
    DistributedTensor_& out_tensor,
    const tile_window_with_static_distribution<BottomTensorView_,
                                               WindowLengths_,
                                               TileDistribution_,
                                               NumCoord_>& __restrict__ tile_window,
    number<UnaryOpSize_> = {})
{
    // Forward to the offset-taking overload with a zero element offset.
    constexpr index_t zero_offset = 0;
    load_tile_transpose_convert_with_offset(
        out_tensor, tile_window, zero_offset, number<UnaryOpSize_>{});
}

} // namespace ck_tile
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
#include "ck_tile/ops/common/determine_warp_prec_type.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp"
#include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp"
#include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp"
#include "ck_tile/ops/common/determine_warp_prec_type.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp"
#include "ck_tile/ops/common/determine_warp_prec_type.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
Expand Down
1 change: 1 addition & 0 deletions projects/composablekernel/include/ck_tile/ops/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-License-Identifier: MIT
#pragma once

#include "ck_tile/ops/common/determine_warp_prec_type.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"

// DetermineWarpPrecType is a set of pattern-matching rules to determine the right precision types
// to use for the warp GEMM, given the precision types defined in the problem, and the compute data
// type. This gives rise to a type conversion: type conversions are sometimes needed to obtain
// a pair of types that are compatible with the hardware matrix operations available. A typical
// use case is mixed precision GEMMs.

namespace ck_tile {
// For the most general case, default to no conversion.
template <typename APrecType, typename BPrecType>
struct DetermineWarpPrecType
{
    using a_prec_type = APrecType;
    using b_prec_type = BPrecType;
};

// For pk_fp4_t x pk_fp4_t, keep pk_fp4_t.
// NOTE: this full specialization also resolves the otherwise-ambiguous match between the
// <pk_fp4_t, B> and <A, pk_fp4_t> partial specializations below.
template <>
struct DetermineWarpPrecType<ck_tile::pk_fp4_t, ck_tile::pk_fp4_t>
{
    using a_prec_type = ck_tile::pk_fp4_t;
    using b_prec_type = ck_tile::pk_fp4_t;
};

// For pk_int4_t x pk_int4_t, keep pk_int4_t.
// Without this, <pk_int4_t, B> and <A, pk_int4_t> are equally specialized and the
// instantiation is ambiguous (same resolution as the pk_fp4_t x pk_fp4_t case above).
template <>
struct DetermineWarpPrecType<ck_tile::pk_int4_t, ck_tile::pk_int4_t>
{
    using a_prec_type = ck_tile::pk_int4_t;
    using b_prec_type = ck_tile::pk_int4_t;
};

// For pk_fp4_raw_t x pk_fp4_raw_t, keep pk_fp4_raw_t (resolves partial-specialization
// ambiguity, mirroring the pk_fp4_t x pk_fp4_t case above).
template <>
struct DetermineWarpPrecType<ck_tile::pk_fp4_raw_t, ck_tile::pk_fp4_raw_t>
{
    using a_prec_type = ck_tile::pk_fp4_raw_t;
    using b_prec_type = ck_tile::pk_fp4_raw_t;
};

// For pk_int4_t x B, use the B type.
template <typename BPrecType>
struct DetermineWarpPrecType<ck_tile::pk_int4_t, BPrecType>
{
    using a_prec_type = BPrecType;
    using b_prec_type = BPrecType;
};

// For A x pk_int4_t, use the A type.
template <typename APrecType>
struct DetermineWarpPrecType<APrecType, ck_tile::pk_int4_t>
{
    using a_prec_type = APrecType;
    using b_prec_type = APrecType;
};

// For pk_fp4_t x B, use the B type.
template <typename BPrecType>
struct DetermineWarpPrecType<ck_tile::pk_fp4_t, BPrecType>
{
    using a_prec_type = BPrecType;
    using b_prec_type = BPrecType;
};

// For A x pk_fp4_t, use the A type.
template <typename APrecType>
struct DetermineWarpPrecType<APrecType, ck_tile::pk_fp4_t>
{
    using a_prec_type = APrecType;
    using b_prec_type = APrecType;
};

// For pk_fp4_raw_t x B, use the B type.
template <typename BPrecType>
struct DetermineWarpPrecType<ck_tile::pk_fp4_raw_t, BPrecType>
{
    using a_prec_type = BPrecType;
    using b_prec_type = BPrecType;
};

// For A x pk_fp4_raw_t, use the A type.
template <typename APrecType>
struct DetermineWarpPrecType<APrecType, ck_tile::pk_fp4_raw_t>
{
    using a_prec_type = APrecType;
    using b_prec_type = APrecType;
};

// All mixed 8-bit x 16-bit (and 16-bit x 8-bit) rules below follow one pattern:
// the warp GEMM computes in the A operand's precision, so B is converted to A's type.

// For fp8 x bf16, use fp8
template <>
struct DetermineWarpPrecType<ck_tile::fp8_t, ck_tile::bf16_t>
{
    using a_prec_type = ck_tile::fp8_t;
    using b_prec_type = ck_tile::fp8_t;
};

// For bf16 x fp8, use bf16
template <>
struct DetermineWarpPrecType<ck_tile::bf16_t, ck_tile::fp8_t>
{
    using a_prec_type = ck_tile::bf16_t;
    using b_prec_type = ck_tile::bf16_t;
};

// For bf8 x bf16, use bf8
template <>
struct DetermineWarpPrecType<ck_tile::bf8_t, ck_tile::bf16_t>
{
    using a_prec_type = ck_tile::bf8_t;
    using b_prec_type = ck_tile::bf8_t;
};

// For bf16 x bf8, use bf16
template <>
struct DetermineWarpPrecType<ck_tile::bf16_t, ck_tile::bf8_t>
{
    using a_prec_type = ck_tile::bf16_t;
    using b_prec_type = ck_tile::bf16_t;
};

// For fp8 x fp16, use fp8
template <>
struct DetermineWarpPrecType<ck_tile::fp8_t, ck_tile::half_t>
{
    using a_prec_type = ck_tile::fp8_t;
    using b_prec_type = ck_tile::fp8_t;
};

// For fp16 x fp8, use fp16
template <>
struct DetermineWarpPrecType<ck_tile::half_t, ck_tile::fp8_t>
{
    using a_prec_type = ck_tile::half_t;
    using b_prec_type = ck_tile::half_t;
};

// For fp16 x bf8, use fp16
template <>
struct DetermineWarpPrecType<ck_tile::half_t, ck_tile::bf8_t>
{
    using a_prec_type = ck_tile::half_t;
    using b_prec_type = ck_tile::half_t;
};

// For bf8 x fp16, use bf8 (completes the symmetric counterpart of fp16 x bf8,
// following the A-operand-precision pattern used by every other mixed rule here).
template <>
struct DetermineWarpPrecType<ck_tile::bf8_t, ck_tile::half_t>
{
    using a_prec_type = ck_tile::bf8_t;
    using b_prec_type = ck_tile::bf8_t;
};
} // namespace ck_tile
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,19 @@

namespace ck_tile {

template <typename SrcDataType, typename DstDataType, index_t UnaryOpSize>
template <typename SrcDataType,
typename DstDataType,
index_t UnaryOpSize,
bool LoadTranspose = false>
struct ConverterLoader
{
template <typename WarpWindow, typename WarpTile>
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& dst, const WarpWindow& src)
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& dst, const WarpWindow& src_window)
{
static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t or pk_fp4_t");
static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
const auto tmp = load_tile(src);
const auto src = load_tile(src_window);

// NOTE: we rely on types packing neatly here
using RawSrcType = typename SrcDataType::type;
Expand All @@ -30,30 +34,56 @@ struct ConverterLoader
const element_wise::PassThroughPack8 elementwise_op{};

elementwise_op(dst.get_thread_buffer().template get_as<DstVectorType>()(i),
tmp.get_thread_buffer().template get_as<SrcVectorType>()[i]);
src.get_thread_buffer().template get_as<SrcVectorType>()[i]);
});
}

template <typename WarpWindow, typename WarpTile>
CK_TILE_DEVICE static void load_with_type_convert(WarpTile& dst, const WarpWindow& src_window)
{
if constexpr(LoadTranspose)
{
if constexpr(std::is_same_v<SrcDataType, DstDataType>)
{
load_tile_transpose(dst, src_window);
}
else
{
load_tile_transpose_convert(dst, src_window, number<UnaryOpSize>{});
}
}
else
{
if constexpr(std::is_same_v<SrcDataType, DstDataType>)
{
load_tile(dst, src_window);
}
else
{
auto tmp = load_tile(src_window);
sweep_tile<WarpTile>([&](auto i) {
dst(i) = type_convert<DstDataType>(type_convert<float>(tmp(i)));
});
}
}
}
};

template <index_t UnaryOpSize, bool LoadTranspose = false, typename WarpTile, typename WarpWindow>
CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src)
CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src_window)
{
using SrcDataType = typename WarpWindow::Base::DataType;
using DstDataType = typename WarpTile::DataType;

if constexpr(is_packed_type_v<SrcDataType> && !is_packed_type_v<DstDataType>)
{
static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t or pk_fp4_t");
ConverterLoader<SrcDataType, DstDataType, UnaryOpSize>::load_interleaved_pk_type(dst, src);
}
else if constexpr(LoadTranspose)
{
load_tile_transpose(dst, src);
ConverterLoader<SrcDataType, DstDataType, UnaryOpSize, LoadTranspose>::
load_interleaved_pk_type(dst, src_window);
}
else
{
load_tile(dst, src);
ConverterLoader<SrcDataType, DstDataType, UnaryOpSize, LoadTranspose>::
load_with_type_convert(dst, src_window);
}
}

} // namespace ck_tile
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp"
#include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/determine_warp_prec_type.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
Expand Down
Loading
Loading