Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@ struct BasicInvoker
static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
// ADataTypeCompute: compute type (tf32_t for TF32 mode, used for warp gemm selection)
// ADataTypeBuf: buffer/storage type (fp32 when tf32)
using ADataTypeCompute = ADataType_;
using BDataTypeCompute = BDataType_;
using ADataTypeBuf = ck_tile::if_select_t<ADataType_, ck_tile::tf32_t, float, ADataType_>;
using BDataTypeBuf = ck_tile::if_select_t<BDataType_, ck_tile::tf32_t, float, BDataType_>;
using ADataTypeBuf = ADataType_;
using BDataTypeBuf = BDataType_;

if constexpr(std::is_same_v<ADataTypeCompute, ck_tile::tf32_t>)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@

#pragma once

#include <string>
#include <variant>

#include "ck_tile/core.hpp"
#include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/utility/json_dump.hpp"

#include <string>
#include <variant>

struct GemmConfigBase
{
static constexpr bool kPadM = false;
Expand All @@ -35,10 +34,6 @@ struct GemmConfigBase
static constexpr bool TiledMMAPermuteN = false;
};

// Type trait for tf32 storage type (tf32 uses float for memory layout calculations)
template <typename T>
using prec_storage_type = ck_tile::if_select_t<T, ck_tile::tf32_t, float, T>;

template <typename PrecType>
struct GemmConfigMemoryInterwave : public GemmConfigBase
{
Expand Down Expand Up @@ -85,7 +80,7 @@ struct GemmConfigComputeV3 : public GemmConfigBase
// Compute V3 only support Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(prec_storage_type<PrecType>);
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);

static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
Expand Down Expand Up @@ -125,7 +120,7 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(prec_storage_type<PrecType>);
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);

static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
Expand Down Expand Up @@ -297,7 +292,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(prec_storage_type<PrecType>);
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);

static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
Expand All @@ -306,7 +301,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<prec_storage_type<PrecType>, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();

static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
Expand All @@ -331,8 +326,8 @@ struct GemmTypeConfig;
template <>
struct GemmTypeConfig<ck_tile::tf32_t, ck_tile::tf32_t, float>
{
using ADataType = float;
using BDataType = float;
using ADataType = ck_tile::tf32_t;
using BDataType = ck_tile::tf32_t;
using AccDataType = float;
using CDataType = float;
};
Expand Down
1 change: 1 addition & 0 deletions projects/composablekernel/include/ck_tile/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
#include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/core/numeric/pk_fp6.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/tfloat32.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/tensor/buffer_view.hpp"
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@

#pragma once

#include "ck_tile/core/numeric/integer.hpp"

#include <type_traits>
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
#include <concepts>
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER

namespace ck_tile::core::arch::mma {

/**
Expand Down Expand Up @@ -50,13 +57,12 @@ static constexpr bool is_mma_op_mfma_v = is_mma_op_mfma<MmaOp>::value;
*/
struct DefaultMfmaCtrlFlags
{
static constexpr uint32_t Cbsz = 0; // CBSZ flag, default 0
static constexpr uint32_t Abid = 0; // ABID flag, default 0
static constexpr uint32_t Blgp = 0; // BLGP flag, default 0
static constexpr int32_t Cbsz = 0; // CBSZ flag, default 0
static constexpr int32_t Abid = 0; // ABID flag, default 0
static constexpr int32_t Blgp = 0; // BLGP flag, default 0
};

#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
#include <concepts>

/**
* @concept CtrlFlagsGfx9I
Expand All @@ -65,9 +71,9 @@ struct DefaultMfmaCtrlFlags
template <typename CtrlFlags>
concept CtrlFlagsGfx9I = requires(CtrlFlags ctrlFlags) {
// Flag members for Gfx9 MFMA instructions
{ CtrlFlags::Cbsz } -> std::convertible_to<int>;
{ CtrlFlags::Abid } -> std::convertible_to<int>;
{ CtrlFlags::Blgp } -> std::convertible_to<int>;
{ CtrlFlags::Cbsz } -> std::convertible_to<int32_t>;
{ CtrlFlags::Abid } -> std::convertible_to<int32_t>;
{ CtrlFlags::Blgp } -> std::convertible_to<int32_t>;
};

#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/ext_vector_base.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/ext_vector_base.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#if CK_TILE_USE_LLVM_BUILTIN_BF16
#include <hip/hip_bfloat16.h>
#endif
#include <stdint.h>

#pragma once

namespace ck_tile {

enum class bf16_rounding_mode
Expand Down Expand Up @@ -365,7 +366,7 @@ struct numeric<bfloat16_t>
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0080));
}

// minumum finite value
// minimum finite value
CK_TILE_HOST_DEVICE static constexpr bfloat16_t lowest()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0xff7f));
Expand Down Expand Up @@ -535,4 +536,24 @@ CK_TILE_DEVICE void convert_float_to_bf16_pairs(const ext_vector_t<float, VecSiz
#endif
}

#if !CK_TILE_USE_CUSTOM_DATA_TYPE
// type_convert specialization: bf16_t -> float.
// Delegates the conversion to bf16_to_float.
template <>
CK_TILE_HOST_DEVICE constexpr float type_convert<float, bf16_t>(bf16_t x)
{
    const auto as_float = bf16_to_float(x);
    return as_float;
}

// type_convert specialization: float -> bf16_t.
// Delegates the conversion to float_to_bf16, passing only the value.
template <>
CK_TILE_HOST_DEVICE constexpr bf16_t type_convert<bf16_t, float>(float x)
{
    const auto as_bf16 = float_to_bf16(x);
    return as_bf16;
}

// type_convert specialization: fp32x2_t -> bf16x2_t.
// Delegates the two-element conversion to fp32x2_to_bf16x2.
template <>
CK_TILE_HOST_DEVICE constexpr bf16x2_t type_convert<bf16x2_t, fp32x2_t>(fp32x2_t x)
{
    const auto packed = fp32x2_to_bf16x2(x);
    return packed;
}
#endif

} // namespace ck_tile
48 changes: 38 additions & 10 deletions projects/composablekernel/include/ck_tile/core/numeric/float8.hpp
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/random.hpp"

#include <stdint.h>
#include <type_traits>

#pragma once

#if(defined(__gfx94__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
#define CK_TILE_FP8_CVT_DEVICE 1
#else
Expand Down Expand Up @@ -803,7 +804,7 @@ struct numeric<fp8_t>
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x08)); // 0b00001000 = 2^-6
}

// minumum finite value
// minimum finite value
CK_TILE_HOST_DEVICE static constexpr fp8_t lowest()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xfe)); // 0b11111110 = -448
Expand Down Expand Up @@ -862,7 +863,7 @@ struct numeric<bf8_t>
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x04)); // 0b00000100 = 2^-14
}

// minumum finite value
// minimum finite value
CK_TILE_HOST_DEVICE static constexpr bf8_t lowest()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xfb)); // 0b11111011 = -57344
Expand Down Expand Up @@ -926,7 +927,7 @@ struct numeric<fp8_t>
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x08));
}

// minumum finite value
// minimum finite value
CK_TILE_HOST_DEVICE static constexpr fp8_t lowest()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xff));
Expand Down Expand Up @@ -993,7 +994,7 @@ struct numeric<bf8_t>
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x04));
}

// minumum finite value
// minimum finite value
CK_TILE_HOST_DEVICE static constexpr bf8_t lowest()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xff));
Expand Down Expand Up @@ -1115,6 +1116,33 @@ bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); }

// bf8_t overload of log, marked CK_TILE_DEVICE: casts x to float, applies the
// __logf intrinsic, and casts the result back to bf8_t.
CK_TILE_DEVICE
bf8_t log(bf8_t x) { return static_cast<bf8_t>(__logf(static_cast<float>(x))); }

#else

// type_convert specialization: fp8_t -> float.
// Delegates the conversion to fp8_to_float.
template <>
CK_TILE_HOST_DEVICE constexpr float type_convert<float, fp8_t>(fp8_t x)
{
    const auto as_float = fp8_to_float(x);
    return as_float;
}

// type_convert specialization: bf8_t -> float.
// Delegates the conversion to bf8_to_float.
template <>
CK_TILE_HOST_DEVICE constexpr float type_convert<float, bf8_t>(bf8_t x)
{
    const auto as_float = bf8_to_float(x);
    return as_float;
}

// type_convert specialization: float -> fp8_t.
// Delegates the conversion to float_to_fp8, passing only the value.
template <>
CK_TILE_HOST_DEVICE constexpr fp8_t type_convert<fp8_t, float>(float x)
{
    const auto as_fp8 = float_to_fp8(x);
    return as_fp8;
}

// type_convert specialization: float -> bf8_t.
// Delegates the conversion to float_to_bf8, passing only the value.
template <>
CK_TILE_HOST_DEVICE constexpr bf8_t type_convert<bf8_t, float>(float x)
{
    const auto as_bf8 = float_to_bf8(x);
    return as_bf8;
}

#endif

} // namespace ck_tile
30 changes: 27 additions & 3 deletions projects/composablekernel/include/ck_tile/core/numeric/half.hpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"

#include <hip/hip_fp16.h>

#pragma once
#include <cstdint>

namespace ck_tile {

Expand Down Expand Up @@ -165,7 +169,7 @@ struct numeric<half_t>
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x0400));
}

// minumum finite value
// minimum finite value
CK_TILE_HOST_DEVICE static constexpr half_t lowest()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0xFBFF));
Expand Down Expand Up @@ -320,4 +324,24 @@ constexpr fp16x2_t fp32x2_to_fp16x2(const fp32x2_t& x)
{
return fp16x2_t{float_to_fp16(x.x), float_to_fp16(x.y)};
}

#if !CK_TILE_USE_CUSTOM_DATA_TYPE
// type_convert specialization: fp16_t -> float.
// Delegates the conversion to fp16_to_float.
template <>
CK_TILE_HOST_DEVICE constexpr float type_convert<float, fp16_t>(fp16_t x)
{
    const auto as_float = fp16_to_float(x);
    return as_float;
}

// type_convert specialization: float -> fp16_t.
// Delegates the conversion to float_to_fp16.
template <>
CK_TILE_HOST_DEVICE constexpr fp16_t type_convert<fp16_t, float>(float x)
{
    const auto as_half = float_to_fp16(x);
    return as_half;
}

// type_convert specialization: fp32x2_t -> fp16x2_t.
// Delegates the two-element conversion to fp32x2_to_fp16x2.
template <>
CK_TILE_HOST_DEVICE constexpr fp16x2_t type_convert<fp16x2_t, fp32x2_t>(fp32x2_t x)
{
    const auto packed = fp32x2_to_fp16x2(x);
    return packed;
}
#endif
} // namespace ck_tile
Loading
Loading