Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
2044d1d
Rename the parameters of load_interleaved_pk_type and load_and_conver…
SamiAario-AMD Jan 12, 2026
f8b027b
Add load_tile_transpose_convert for mixed precision transpose loading
SamiAario-AMD Jan 26, 2026
89c358f
Add and use load_with_type_convert
SamiAario-AMD Nov 12, 2025
ab97b13
Add FillIdentity for host tensors
SamiAario-AMD Mar 17, 2026
1eb1a03
Add test_load_and_convert_tile
SamiAario-AMD Jan 13, 2026
a6309fa
In LoadAndConvertKernel, modify the input tensor view instead of the …
SamiAario-AMD Mar 19, 2026
cf86511
Only run test-load-tile-transpose on gfx950
SamiAario-AMD Mar 20, 2026
1918d45
Remove an unnecessary seed in a call to FillUniformDistributionIntege…
SamiAario-AMD Mar 23, 2026
5c2c2ae
Replace the matrix type macros with constants from an enum struct
SamiAario-AMD Mar 24, 2026
f0e5e57
Add width and precision parameters to print_matrix
SamiAario-AMD Mar 24, 2026
8c0d3fe
Move print_matrix to a separate header file under include/ck_tile/host
SamiAario-AMD Mar 24, 2026
cd42fd7
Move TestLoadAndConvert to a separate utility file
SamiAario-AMD Mar 24, 2026
9ff4970
Split the tests into three separate suites, so that only a subset of …
SamiAario-AMD Mar 24, 2026
daa36ad
Merge branch 'develop' into users/samaario/ck/test-load-tile-transpose
SamiAario-AMD Mar 24, 2026
f73169f
fixup! Split the tests into three separate suites, so that only a sub…
SamiAario-AMD Mar 24, 2026
97a4bab
All transposed loads are a gfx950 feature, so they should all be run …
SamiAario-AMD Mar 24, 2026
a185043
Merge branch 'develop' into users/samaario/ck/test-load-tile-transpose
SamiAario-AMD Mar 24, 2026
e29da0f
Merge branch 'develop' into users/samaario/ck/test-load-tile-transpose
SamiAario-AMD Mar 25, 2026
587955e
Adjust compile options
SamiAario-AMD Mar 26, 2026
a807103
Limit tests to gfx9 because of linker problems detected on gfx1201 an…
SamiAario-AMD Mar 26, 2026
3f45f79
Merge branch 'develop' into users/samaario/ck/test-load-tile-transpose
SamiAario-AMD Mar 27, 2026
4cc024f
Modify the M, N, and K parameters to span multiple tiles
SamiAario-AMD Mar 27, 2026
4ced776
No need to send C tensor to device
SamiAario-AMD Mar 27, 2026
f1b3bc1
Merge branch 'develop' into users/samaario/ck/test-load-tile-transpose
SamiAario-AMD Mar 27, 2026
1b4e66c
Merge branch 'develop' into users/samaario/ck/test-load-tile-transpose
SamiAario-AMD Mar 30, 2026
bff31f4
Merge branch 'develop' into users/samaario/ck/test-load-tile-transpose
SamiAario-AMD Mar 30, 2026
6a4fecb
Merge branch 'develop' into users/samaario/ck/test-load-tile-transpose
SamiAario-AMD Mar 31, 2026
b9ddf2a
Merge branch 'develop' into users/samaario/ck/test-load-tile-transpose
SamiAario-AMD Apr 2, 2026
6989169
Merge branch 'develop' into users/samaario/ck/test-load-tile-transpose
SamiAario-AMD Apr 7, 2026
fbae349
Fix test to work with multiple warps per block, and fix tile distribu…
SamiAario-AMD Apr 7, 2026
fa2c74d
Introduce DetermineWarpPrecType for determining warp GEMM precision t…
SamiAario-AMD Oct 9, 2025
b6a16fa
Add functionality and tests for bf16 x fp8 and fp8 x bf16
SamiAario-AMD Oct 9, 2025
26ce440
Add functionality and tests for fp16 x fp8 and fp8 x fp16
SamiAario-AMD Nov 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"

namespace ck_tile {

Expand Down Expand Up @@ -529,4 +530,118 @@ load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_
return out_tensor;
}

/**
 * @brief Mixed-precision transpose load: converts input data type to output data type while
 * transposing.
 *
 * This function enables transposing from one data type (e.g., fp8) to another (e.g., fp16) in a
 * single operation. The input tile distribution encoding must be valid for the input data type,
 * and the output distribution will be generated based on the output data type.
 *
 * @tparam DistributedTensor_ The output tensor type with desired output data type.
 * @tparam BottomTensorView_ The input tensor view (may have different data type than output).
 * @tparam WindowLengths_ The type representing the window lengths.
 * @tparam TileDistribution_ The type representing the tile distribution for input.
 * @tparam NumCoord_ The number of coordinates (dimensions).
 * @tparam UnaryOpSize_ Number of elements converted per element-wise call (currently must be 8,
 *         matching element_wise::PassThroughPack8).
 * @tparam Policy The transpose policy (should validate against input type).
 *
 * @param out_tensor  Destination distributed tensor (output data type).
 * @param tile_window Source tile window to load-transpose from.
 * @param offset      Element offset forwarded to load_transpose_with_offset.
 *
 * @note
 * - The input and output Y-space descriptors must have equal element space sizes, each counted
 *   in elements of its own data type.
 * - Type conversion is performed element-by-element during the copy.
 * - The validation uses the input data type for quad pattern checking.
 * - The output distribution is generated based on the output data type.
 */
template <
    typename DistributedTensor_,
    typename BottomTensorView_,
    typename WindowLengths_,
    typename TileDistribution_,
    index_t NumCoord_,
    index_t UnaryOpSize_,
    typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
    typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
                                                          typename BottomTensorView_::DataType,
                                                          Policy>::distr_encoding_valid,
                                Policy>>
CK_TILE_DEVICE void load_tile_transpose_convert_with_offset(
    DistributedTensor_& out_tensor,
    const tile_window_with_static_distribution<BottomTensorView_,
                                               WindowLengths_,
                                               TileDistribution_,
                                               NumCoord_>& __restrict__ tile_window,
    const index_t offset,
    number<UnaryOpSize_> = {})
{
    using SrcDataType = typename BottomTensorView_::DataType;
    using DstDataType = typename DistributedTensor_::DataType;

    // Only the pack-of-8 pass-through conversion is implemented; check once, up front.
    static_assert(UnaryOpSize_ == 8, "Only PassThroughPack8 is supported for now.");

    auto trans_tensor = tile_window.template load_transpose_with_offset<Policy>(offset);
    constexpr auto input_distr  = TileDistribution_{};
    constexpr auto output_distr = typename DistributedTensor_::StaticTileDistribution{};

    constexpr auto y_in_desc  = input_distr.get_ys_to_d_descriptor();
    constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();

    constexpr auto y_in_lengths  = to_sequence(y_in_desc.get_lengths());
    constexpr auto y_out_lengths = to_sequence(y_out_desc.get_lengths());

    constexpr auto y_in_element_space_size  = y_in_desc.get_element_space_size();
    constexpr auto y_out_element_space_size = y_out_desc.get_element_space_size();

    // For mixed precision: element space sizes (each in units of its own data type) must match.
    static_assert(y_in_element_space_size == y_out_element_space_size,
                  "For mixed precision transpose, input and output element space size must match!");

    // Ensure total element counts are consistent and divisible by the input vector length.
    constexpr index_t total_elems_in =
        reduce_on_sequence(y_in_lengths, multiplies<>{}, number<1>{});
    constexpr index_t total_elems_out =
        reduce_on_sequence(y_out_lengths, multiplies<>{}, number<1>{});
    static_assert(total_elems_in == total_elems_out,
                  "For mixed precision transpose, input/output element counts must match!");
    static_assert(total_elems_in % UnaryOpSize_ == 0,
                  "Input vector length must evenly divide total elements.");

    constexpr index_t num_of_access = total_elems_in / UnaryOpSize_;

    // Read as input type, convert to output type, UnaryOpSize_ elements per access.
    using SrcDataVec = ext_vector_t<SrcDataType, number<UnaryOpSize_>{}>;
    using DstDataVec = ext_vector_t<DstDataType, number<UnaryOpSize_>{}>;

    // Stateless converter; hoisted out of the unrolled loop instead of being
    // re-declared on every iteration.
    const element_wise::PassThroughPack8 elementwise_op{};

    static_for<0, num_of_access, 1>{}([&](auto i) {
        elementwise_op(out_tensor.get_thread_buffer().template get_as<DstDataVec>()(i),
                       trans_tensor.get_thread_buffer().template get_as<SrcDataVec>()[i]);
    });
}

/**
 * @brief Mixed-precision transpose load with zero offset.
 *
 * Thin convenience wrapper: delegates to load_tile_transpose_convert_with_offset, anchored at
 * the start of the window (offset 0). See that function for the full contract.
 */
template <
    typename DistributedTensor_,
    typename BottomTensorView_,
    typename WindowLengths_,
    typename TileDistribution_,
    index_t NumCoord_,
    index_t UnaryOpSize_,
    typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
    typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
                                                          typename BottomTensorView_::DataType,
                                                          Policy>::distr_encoding_valid,
                                Policy>>
CK_TILE_DEVICE void load_tile_transpose_convert(
    DistributedTensor_& out_tensor,
    const tile_window_with_static_distribution<BottomTensorView_,
                                               WindowLengths_,
                                               TileDistribution_,
                                               NumCoord_>& __restrict__ tile_window,
    number<UnaryOpSize_> = {})
{
    // No displacement into the window: forward a zero offset.
    constexpr index_t zero_offset = 0;
    load_tile_transpose_convert_with_offset(
        out_tensor, tile_window, zero_offset, number<UnaryOpSize_>{});
}

} // namespace ck_tile
1 change: 1 addition & 0 deletions projects/composablekernel/include/ck_tile/host.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "ck_tile/host/joinable_thread.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/permute_pk_int4.hpp"
#include "ck_tile/host/print_matrix.hpp"
#include "ck_tile/host/ranges.hpp"
#include "ck_tile/host/reference/reference_batched_contraction.hpp"
#include "ck_tile/host/reference/reference_batched_dropout.hpp"
Expand Down
38 changes: 38 additions & 0 deletions projects/composablekernel/include/ck_tile/host/fill.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,44 @@ struct FillConstant
}
};

/// @brief Functor filling a row-major rows_ x cols_ range with the identity matrix:
/// ones on the main diagonal, zeros everywhere else.
///
/// The iterator overload performs a single forward pass, so it genuinely supports
/// forward iterators (the previous implementation used random-access arithmetic such
/// as `first + i * cols_ + i`, which failed to compile for non-random-access iterators).
template <typename T>
struct FillIdentity
{
    std::size_t rows_{0};
    std::size_t cols_{0};
    T zero_{type_convert<T>(0)};
    T one_{type_convert<T>(1)};

    /// @brief Fill the first rows_ * cols_ elements of [first, last) with the identity.
    /// @throws std::runtime_error if the range holds fewer than rows_ * cols_ elements.
    template <typename ForwardIter>
    void operator()(ForwardIter first, ForwardIter last) const
    {
        // Degenerate shapes or an empty range: nothing to do.
        if(rows_ == 0 || cols_ == 0 || first == last)
            return;

        const auto total = static_cast<std::size_t>(std::distance(first, last));
        if(total < rows_ * cols_)
        {
            throw std::runtime_error("FillIdentity requires range size >= rows_ * cols_.");
        }

        // Single forward pass: element (i, j) is 1 exactly on the main diagonal.
        // For i >= cols_ the condition i == j can never hold, so rows beyond the
        // square part are filled entirely with zeros, as before.
        auto it = first;
        for(std::size_t i = 0; i < rows_; ++i)
        {
            for(std::size_t j = 0; j < cols_; ++j, ++it)
            {
                *it = (i == j) ? one_ : zero_;
            }
        }
    }

    /// @brief Range overload; SFINAE-constrained to ranges whose iterators the
    /// iterator overload accepts.
    template <typename ForwardRange>
    auto operator()(ForwardRange&& range) const
        -> std::void_t<decltype(std::declval<const FillIdentity&>()(
            std::begin(std::forward<ForwardRange>(range)),
            std::end(std::forward<ForwardRange>(range))))>
    {
        (*this)(std::begin(std::forward<ForwardRange>(range)),
                std::end(std::forward<ForwardRange>(range)));
    }
};

//----------------------------------------------------------------------------------------------
/// @brief Transforms given input to fit 2:4 structured sparsity pattern so
/// every subgroup of 4 elements contain at most 2 non-zero elements
Expand Down
33 changes: 33 additions & 0 deletions projects/composablekernel/include/ck_tile/host/print_matrix.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once

#include <algorithm>
#include <iomanip>
#include <iostream>
#include <string>

#include <cassert>

// NOTE(review): this header also relies on ck_tile::HostTensor / ck_tile::index_t /
// ck_tile::type_convert — confirm the appropriate ck_tile header is included by all users.

// Helper to print matrix (for debugging).
//
// Prints up to the top-left 32x32 corner of a 2-D host tensor, converting each
// element to float for display. An ellipsis marks truncated rows/columns.
//
// @param mat       2-D tensor to print
// @param name      label printed above the matrix
// @param width     std::setw field width used per element
// @param precision std::setprecision used per element
template <typename T>
void print_matrix(const ck_tile::HostTensor<T>& mat,
                  const std::string& name = "Matrix",
                  const int width = 3,
                  const int precision = 3)
{
    const auto lens = mat.get_lengths();
    // Fixed: `len(lens)` is not valid C++ (Python-ism); query the container's size.
    assert(lens.size() == 2);
    const ck_tile::index_t rows  = lens[0];
    const ck_tile::index_t cols  = lens[1];
    const ck_tile::index_t limit = 32;

    std::cout << name << " (" << rows << "×" << cols << "):\n";
    for(ck_tile::index_t i = 0; i < std::min(rows, ck_tile::index_t(limit)); ++i)
    {
        for(ck_tile::index_t j = 0; j < std::min(cols, ck_tile::index_t(limit)); ++j)
        {
            std::cout << std::setw(width) << std::setprecision(precision)
                      << ck_tile::type_convert<float>(mat(i, j)) << " ";
        }
        if(cols > limit)
            std::cout << "...";
        std::cout << "\n";
    }
    if(rows > limit)
        std::cout << "...\n";
    std::cout << "\n";
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
#include "ck_tile/ops/common/determine_warp_prec_type.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp"
#include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp"
#include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp"
#include "ck_tile/ops/common/determine_warp_prec_type.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp"
#include "ck_tile/ops/common/determine_warp_prec_type.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
Expand Down
1 change: 1 addition & 0 deletions projects/composablekernel/include/ck_tile/ops/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-License-Identifier: MIT
#pragma once

#include "ck_tile/ops/common/determine_warp_prec_type.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"

// DetermineWarpPrecType is a set of pattern-matching rules to determine the right precision types
// to use for the warp GEMM, given the precision types defined in the problem, and the compute data
// type. This gives rise to a type conversion: type conversions are sometimes needed to obtain
// a pair of types that are compatible with the hardware matrix operations available. A typical
// use case is mixed precision GEMMs.

namespace ck_tile {
// For the most general case, default to no conversion: the warp GEMM uses the
// problem's A/B precision types as-is.
//
// NOTE(review): the one-sided partial specializations below (pk_int4_t x B,
// A x pk_int4_t, pk_fp4_t x B, A x pk_fp4_t, pk_fp4_raw_t x B, A x pk_fp4_raw_t)
// are pairwise ambiguous when BOTH operands are packed types (e.g.
// pk_int4_t x pk_fp4_t); instantiating such a combination would fail to compile.
// Only pk_fp4_t x pk_fp4_t has a disambiguating full specialization — confirm the
// other packed-x-packed pairings are never instantiated.
template <typename APrecType, typename BPrecType, typename ComputePrecType>
struct DetermineWarpPrecType
{
    using a_prec_type = APrecType;
    using b_prec_type = BPrecType;
};

// Use tf32_t for both operands if the compute type is tf32_t.
template <typename APrecType, typename BPrecType>
struct DetermineWarpPrecType<APrecType, BPrecType, ck_tile::tf32_t>
{
    using a_prec_type = ck_tile::tf32_t;
    using b_prec_type = ck_tile::tf32_t;
};

// For pk_fp4_t x pk_fp4_t, keep pk_fp4_t (also disambiguates the two one-sided
// pk_fp4_t rules below).
template <typename ComputePrecType>
struct DetermineWarpPrecType<ck_tile::pk_fp4_t, ck_tile::pk_fp4_t, ComputePrecType>
{
    using a_prec_type = ck_tile::pk_fp4_t;
    using b_prec_type = ck_tile::pk_fp4_t;
};

// For pk_int4_t x B, use the B type (A is converted to B's precision).
template <typename BPrecType, typename ComputePrecType>
struct DetermineWarpPrecType<ck_tile::pk_int4_t, BPrecType, ComputePrecType>
{
    using a_prec_type = BPrecType;
    using b_prec_type = BPrecType;
};

// For A x pk_int4_t, use the A type (B is converted to A's precision).
template <typename APrecType, typename ComputePrecType>
struct DetermineWarpPrecType<APrecType, ck_tile::pk_int4_t, ComputePrecType>
{
    using a_prec_type = APrecType;
    using b_prec_type = APrecType;
};

// For pk_fp4_t x B, use the B type.
template <typename BPrecType, typename ComputePrecType>
struct DetermineWarpPrecType<ck_tile::pk_fp4_t, BPrecType, ComputePrecType>
{
    using a_prec_type = BPrecType;
    using b_prec_type = BPrecType;
};

// For A x pk_fp4_t, use the A type.
template <typename APrecType, typename ComputePrecType>
struct DetermineWarpPrecType<APrecType, ck_tile::pk_fp4_t, ComputePrecType>
{
    using a_prec_type = APrecType;
    using b_prec_type = APrecType;
};

// For pk_fp4_raw_t x B, use the B type.
template <typename BPrecType, typename ComputePrecType>
struct DetermineWarpPrecType<ck_tile::pk_fp4_raw_t, BPrecType, ComputePrecType>
{
    using a_prec_type = BPrecType;
    using b_prec_type = BPrecType;
};

// For A x pk_fp4_raw_t, use the A type.
template <typename APrecType, typename ComputePrecType>
struct DetermineWarpPrecType<APrecType, ck_tile::pk_fp4_raw_t, ComputePrecType>
{
    using a_prec_type = APrecType;
    using b_prec_type = APrecType;
};

// Mixed 8-bit/16-bit float pairs: the warp GEMM runs entirely in the A operand's
// type; the other operand is converted to match.

// For fp8 x bf16, use fp8.
template <typename ComputePrecType>
struct DetermineWarpPrecType<ck_tile::fp8_t, ck_tile::bf16_t, ComputePrecType>
{
    using a_prec_type = ck_tile::fp8_t;
    using b_prec_type = ck_tile::fp8_t;
};

// For bf16 x fp8, use bf16.
template <typename ComputePrecType>
struct DetermineWarpPrecType<ck_tile::bf16_t, ck_tile::fp8_t, ComputePrecType>
{
    using a_prec_type = ck_tile::bf16_t;
    using b_prec_type = ck_tile::bf16_t;
};

// For bf8 x bf16, use bf8.
template <typename ComputePrecType>
struct DetermineWarpPrecType<ck_tile::bf8_t, ck_tile::bf16_t, ComputePrecType>
{
    using a_prec_type = ck_tile::bf8_t;
    using b_prec_type = ck_tile::bf8_t;
};

// For bf16 x bf8, use bf16.
template <typename ComputePrecType>
struct DetermineWarpPrecType<ck_tile::bf16_t, ck_tile::bf8_t, ComputePrecType>
{
    using a_prec_type = ck_tile::bf16_t;
    using b_prec_type = ck_tile::bf16_t;
};

// For fp8 x fp16, use fp8.
template <typename ComputePrecType>
struct DetermineWarpPrecType<ck_tile::fp8_t, ck_tile::half_t, ComputePrecType>
{
    using a_prec_type = ck_tile::fp8_t;
    using b_prec_type = ck_tile::fp8_t;
};

// For fp16 x fp8, use fp16.
template <typename ComputePrecType>
struct DetermineWarpPrecType<ck_tile::half_t, ck_tile::fp8_t, ComputePrecType>
{
    using a_prec_type = ck_tile::half_t;
    using b_prec_type = ck_tile::half_t;
};
}; // namespace ck_tile
Loading