[CK TILE] Unification Work – Add MFMA specialisations for tf32_t#6768

Draft
yungshengtu wants to merge 9 commits into users/wj-laskowski/ck/add_mfma_specializations from users/yungshengtu/ck/unification/add_mfma_tf32_specialisations

Conversation

@yungshengtu
Contributor

@yungshengtu yungshengtu commented Apr 24, 2026

Motivation

This PR adds two specialisations related to tf32_t.

Technical Details

This change treats tf32_t as a concrete type rather than an empty struct. It also adds two new specialisations for MFMA dense builtins and resolves existing circular include issues.

Test Plan

All the new wrappers were added to the test suite in test_amdgcn_mma_layout.inc.

Test Result

The tests should pass.

Submission Checklist

@yungshengtu yungshengtu self-assigned this Apr 24, 2026
@yungshengtu yungshengtu force-pushed the users/yungshengtu/ck/unification/add_mfma_tf32_specialisations branch 15 times, most recently from aea639e to b7497d0 Compare April 25, 2026 20:04
@yungshengtu yungshengtu changed the base branch from develop to users/wj-laskowski/ck/add_mfma_specializations April 28, 2026 08:40
@yungshengtu yungshengtu force-pushed the users/yungshengtu/ck/unification/add_mfma_tf32_specialisations branch 2 times, most recently from 6852902 to 95c7bad Compare April 28, 2026 10:28
@yungshengtu yungshengtu force-pushed the users/yungshengtu/ck/unification/add_mfma_tf32_specialisations branch 2 times, most recently from 039e593 to ffd8dd5 Compare April 28, 2026 13:40
@yungshengtu yungshengtu requested a review from krithalith May 4, 2026 11:17
@wj-laskowski wj-laskowski force-pushed the users/wj-laskowski/ck/add_mfma_specializations branch from 45d739c to 55dcd32 Compare May 5, 2026 13:44
@yungshengtu yungshengtu force-pushed the users/yungshengtu/ck/unification/add_mfma_tf32_specialisations branch 3 times, most recently from 0e4c508 to 16c5543 Compare May 7, 2026 11:48
Contributor

@chris-tsiaousis-hpc chris-tsiaousis-hpc left a comment

Great work! I've added some comments, you can consider them as NITs!

Comment thread projects/composablekernel/include/ck_tile/core/numeric/tfloat32.hpp
static constexpr int bias = 127;
static constexpr std::uint32_t nan_mask = 0x7F800000u;
static constexpr std::uint32_t head_mask = 0xFF800000u;
static constexpr std::uint32_t mant_mask = 0x7FFFFFu;
Contributor

Why a 23-bit mantissa mask? Shouldn't it cover only 10 bits (0x7FE000u)?
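For reference, the two masks in question can be derived from the bit widths. This is a minimal sketch that assumes the 10 TF32 mantissa bits occupy the top of the 23-bit FP32 mantissa field; the constant names are illustrative and not taken from tfloat32.hpp.

```cpp
#include <cstdint>

// Illustrative names only; none of these come from the PR.
constexpr int f32_mant_bits  = 23; // FP32 mantissa width
constexpr int tf32_mant_bits = 10; // TF32 mantissa width
constexpr int dropped_bits   = f32_mant_bits - tf32_mant_bits; // 13 low bits discarded

// All 23 FP32 mantissa bits: 0x7FFFFF.
constexpr std::uint32_t f32_mant_mask = (1u << f32_mant_bits) - 1u;

// Only the 10 most significant mantissa bits: 0x7FE000.
constexpr std::uint32_t tf32_mant_mask = f32_mant_mask & ~((1u << dropped_bits) - 1u);

static_assert(f32_mant_mask == 0x7FFFFFu, "full FP32 mantissa mask");
static_assert(tf32_mant_mask == 0x7FE000u, "10-bit TF32 mantissa mask");
```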

Contributor

I think this value should never be used anyway.

Contributor Author

Modified in 993d89a.


CK_TILE_TYPE_CONVERT(fp16x2_t, fp16x2, fp32x2_t, fp32x2)
CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, fp32x2_t, fp32x2)
#undef CK_TILE_TYPE_CONVERT
Contributor

Previously, CK_TILE_TYPE_CONVERT was #undef'd after use in this file. Now the macro is used across multiple files but never #undef'd, so it leaks to all downstream includers. Consider undefining it in each file after use, or in a central location.
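The define/use/undef pattern being asked for looks roughly like the sketch below. SKETCH_TYPE_CONVERT is a hypothetical stand-in, not the real CK_TILE_TYPE_CONVERT, whose expansion is not shown in this thread.

```cpp
// Hypothetical helper macro: generates a named conversion function.
#define SKETCH_TYPE_CONVERT(dst, src) \
    constexpr dst convert_##src##_to_##dst(src x) { return static_cast<dst>(x); }

SKETCH_TYPE_CONVERT(float, int) // defines convert_int_to_float
SKETCH_TYPE_CONVERT(int, float) // defines convert_float_to_int

// Undefine right after the last use so the macro does not leak to includers.
#undef SKETCH_TYPE_CONVERT

#ifdef SKETCH_TYPE_CONVERT
#error "SKETCH_TYPE_CONVERT leaked past its intended scope"
#endif

static_assert(convert_int_to_float(3) == 3.0f, "int -> float");
static_assert(convert_float_to_int(2.5f) == 2, "float -> int truncates");
```

When several headers need the same macro, one option is to keep its definition in a small shared header and have each user undefine it after its last use; whether that fits this code base is a judgment call for the authors.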

Contributor Author

Modified in 993d89a.

CK_TILE_HOST_DEVICE constexpr Y scaled_type_convert(X x, float scale);

static constexpr uint32_t float32_exponent_mask = 0x7f800000u;
[[deprecated]] static constexpr uint32_t float32_exponent_mask = 0x7f800000u;
Contributor

Since you are adding those [[deprecated]] tags, it might be useful to add a comment explaining what to use instead.
I think our compiler also supports messages directly in the attribute, like [[deprecated("Use ck_tile::tf32_t directly")]], but I'm not sure.
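The message form of the attribute has been standard since C++14, so it should be available wherever plain [[deprecated]] is. A minimal sketch follows; the sketch namespace, the tf32_bits struct, the enumerator names, and the message texts are all placeholders, not names from the PR.

```cpp
#include <cstdint>

namespace sketch {

// Hypothetical replacement location for the constant.
struct tf32_bits
{
    static constexpr std::uint32_t exponent_mask = 0x7f800000u;
};

// The string is shown in the compiler diagnostic at every use of the entity.
[[deprecated("Use sketch::tf32_bits::exponent_mask instead")]]
constexpr std::uint32_t float32_exponent_mask = 0x7f800000u;

// The attribute goes between the enum key and the name.
// Enumerators here are invented for illustration.
enum class [[deprecated("placeholder: name the replacement here")]] tf32_rounding_mode
{
    standard,
    truncate
};

} // namespace sketch

// Using the non-deprecated replacement produces no warning.
static_assert(sketch::tf32_bits::exponent_mask == 0x7f800000u, "replacement constant");
```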

Contributor Author

Modified in 993d89a.

Comment thread projects/composablekernel/include/ck_tile/core/numeric/numeric.hpp
Contributor

@wj-laskowski wj-laskowski left a comment

Nice work! A few comments/questions from my side.

return float_to_tf32<rounding>(x);
}

CK_TILE_TYPE_CONVERT(fp16_t, fp16, float, float)
Contributor

Why are those type converts moved to data type headers?

Contributor Author

It's because data type headers need type_convert, but type_convert.hpp also needs data type headers, resulting in circular inclusion.
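A single-file sketch of how moving the conversions next to the type breaks such a cycle. It assumes type_convert is a function template that is specialised per type pair; fixed8_t and its helpers are invented for illustration and do not exist in CK Tile.

```cpp
#include <cstdint>

namespace sketch {

// Stand-in for "type_convert.hpp": the generic template only.
// It no longer needs to include any data type header.
template <typename Y, typename X>
constexpr Y type_convert(X x)
{
    return static_cast<Y>(x);
}

// Stand-in for a data type header: the type, its helpers, and its
// type_convert specialisations all live together.
struct fixed8_t
{
    std::int8_t raw; // value = raw / 16 (hypothetical fixed-point format)
};

constexpr fixed8_t float_to_fixed8(float x)
{
    return fixed8_t{static_cast<std::int8_t>(x * 16.0f)};
}

constexpr float fixed8_to_float(fixed8_t x)
{
    return static_cast<float>(x.raw) / 16.0f;
}

template <>
constexpr fixed8_t type_convert<fixed8_t, float>(float x)
{
    return float_to_fixed8(x);
}

template <>
constexpr float type_convert<float, fixed8_t>(fixed8_t x)
{
    return fixed8_to_float(x);
}

} // namespace sketch

static_assert(sketch::type_convert<int>(2.9) == 2, "generic path");
static_assert(sketch::type_convert<float>(sketch::type_convert<sketch::fixed8_t>(1.5f)) == 1.5f,
              "round trip through the specialisations");
```

With this layout the dependency runs one way only: the type header includes the generic template, and the generic template includes nothing type-specific.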

[[deprecated]] static constexpr uint32_t float32_exponent_mask = 0x7f800000u;

enum class tf32_rounding_mode
enum class [[deprecated]] tf32_rounding_mode
Contributor

Good job marking it. I think it might be useful to open an issue to clean this up in another PR.

Comment thread projects/composablekernel/include/ck_tile/core/numeric/numeric.hpp
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr tf32_t min() { return bit_cast<tf32_t>(0x00800000u); }

// minumum finite value
Contributor

minimum

Contributor Author

Modified in 993d89a.

amdgcn_mma<I8, I8, I32, 16u, 16u, 64u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_16x16x64_i8
amdgcn_mma<I8, I8, I32, 32u, 32u, 32u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE> // mfma_i32_32x32x32_i8
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_f16
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_f16
Contributor

These (lines 320 and 321) appear to be identical.

Contributor Author

This is from #6014, but I'll change it anyway.

Contributor Author

Modified in 993d89a.

Comment on lines +1261 to +1263
static_cast<int>(CtrlFlags::Cbsz),
static_cast<int>(CtrlFlags::Abid),
static_cast<int>(CtrlFlags::Blgp))};
Contributor

int32_t

Contributor Author

Modified in 993d89a.

Comment on lines +1290 to +1292
static_cast<int>(CtrlFlags::Cbsz),
static_cast<int>(CtrlFlags::Abid),
static_cast<int>(CtrlFlags::Blgp))};
Contributor

int32_t

Contributor Author

Modified in 993d89a.

Comment on lines +10 to +21
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/int8.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/tfloat32.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"

#include <cstdint>
Contributor

@krithalith krithalith May 8, 2026

Did the pre-commit scripts do this? I don't see why your additions warrant so many new includes.

Contributor Author

@yungshengtu yungshengtu May 8, 2026

I'd like to add them because, while fixing the circular inclusion, I found that headers were depending on other headers indirectly, which is hard to debug. Directly including what a header uses is good practice.

Comment on lines +65 to +74
// RTNE rounding.
if((i & f32_exp_mask) != f32_exp_mask)
{
// Add rounding bias for round-to-nearest-even (RTNE) before truncating:
// - 0xfff is the rounding bias corresponding to the 13 fraction bits that
// will be discarded.
// - (i >> 13) & 1 extracts the least significant of those discarded bits and
// adding it implements "ties to even" (round half-way cases to even).
i += 0xFFF + ((i >> 13) & 1);
}
Contributor

This is cool but I have some questions about this:

  1. Will this break for subnormal numbers?
  2. I think in some existing code there was a choice between truncation and RTNE, but here it is always the latter. Why? The ISA is actually inconsistent about what happens to the input, mentioning both rounding and simple truncation. In principle we could figure out which it is with some test code.
  3. What is fundamentally the point of explicitly rounding or truncating the mantissa like this? The intrinsic already does this, and for verification code you could just use float with a custom rtol and atol (which probably already gets calculated anyway). If there is no real reason to do this the type could be much simpler.

Contributor Author

  1. No. The block admits every value whose exponent field is not all 1s. Subnormal numbers have an all-zero exponent field, so they also pass through it and are rounded the same way.
  2. I think the ISA explicitly indicates rounding behaviour in section 7.1.2 (see CDNA3 ISA).
    Later, in section 12.10, it states that:
    “XF32 is a FP32 operation with FP32 inputs and outputs but implemented at reduced intermediate precision where the mantissa is truncated to 10 bits (not including leading 1 for non-zero values), and results are accumulated into an FP32 value with a 23-bit mantissa.”
    I interpret the “truncated to 10 bits” wording as describing the reduced precision representation, while the actual behaviour is consistent with the rounding rules defined earlier in section 7.1.2. In other words, the truncation here should be understood in the context of a rounding step rather than a literal bit-wise truncation.
    From this perspective, floating-point operations are still expected to follow standard rounding behaviour to avoid introducing unnecessary bias or excessive error.
  3. The main reason is to treat tf32_t as a concrete, well-defined type. Although the intrinsics already handle rounding, defining it as a type ensures that all values stored as tf32_t strictly conform to its precision rules.
    This makes tf32_t easier to use in a consistent way for both initialisation and subsequent computation, without requiring users to reason about the underlying bit-level behaviour. In practice, tf32_t is always constrained to its 10-bit mantissa definition, which simplifies its usage model.
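The rounding discussed in this thread can be sketched on raw FP32 bit patterns as follows. This is not the PR's implementation: the function names are hypothetical, and both the final masking step and the choice to pass Inf/NaN through unchanged are assumptions, since only the bias addition is quoted above.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical helper: round an FP32 bit pattern to TF32 precision (RTNE).
constexpr std::uint32_t round_f32_bits_to_tf32(std::uint32_t i)
{
    constexpr std::uint32_t f32_exp_mask = 0x7F800000u;
    if((i & f32_exp_mask) == f32_exp_mask)
    {
        return i; // Inf/NaN: passed through unchanged (assumption)
    }
    // 0xFFF is just under half of the 13 discarded bits. Adding bit 13,
    // the lowest retained mantissa bit, pushes exact ties towards the
    // neighbour whose retained mantissa is even.
    i += 0xFFFu + ((i >> 13) & 1u);
    return i & 0xFFFFE000u; // clear the 13 discarded bits (assumed step)
}

// Float wrapper around the bit-level routine.
inline float round_to_tf32(float x)
{
    std::uint32_t i = 0;
    std::memcpy(&i, &x, sizeof(i));
    i = round_f32_bits_to_tf32(i);
    std::memcpy(&x, &i, sizeof(x));
    return x;
}

static_assert(round_f32_bits_to_tf32(0x3F801000u) == 0x3F800000u, "tie rounds down to even");
static_assert(round_f32_bits_to_tf32(0x3F803000u) == 0x3F804000u, "tie rounds up to even");
static_assert(round_f32_bits_to_tf32(0x00001001u) == 0x00002000u, "subnormal rounds normally");
```

The subnormal case needs no special handling here because the bias is added to the whole 32-bit pattern, and a carry out of the mantissa moves into the exponent field in the same way it does for normal values.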

Comment on lines -12 to -16
// TF32 tag type: 1 sign bit, 8 exponent bits, 10 mantissa bits (see numeric_traits<tf32_t>)
struct tf32_t
{
};

Contributor

Yeah, it's quite strange that they chose to go with an empty tag here, while all other custom types do contain data. This seems quite error prone and requires a bunch of code to have extra conditionals just for tf32_t, which isn't necessary for the other custom types. Can we hunt down this code and remove the now-deprecated conditionals? Also, I am wondering whether there is any code that relies on the fact that it is currently an empty tag...

Contributor Author

Removed in 9daa9a6. I checked while removing, and I didn't see anything that relies on it being an empty tag.

@krithalith
Contributor

Nice work! More involved than I expected. It seems like they chose a strange implementation for tf32_t with it being an empty tag. This seems very error prone and doesn't align with how all other custom types are constructed. I do wonder though: What would happen if we didn't change anything about the current tf32_t implementation? Would our new intrinsics work or would something break?

Another thing I was wondering is whether it would be safer, for now, to keep the old tf32_t implementation, just in case it being an empty tag somehow affects performance or functionality somewhere, and then slowly remove it later while checking for perf regressions in specific tests / examples. This might be overkill though. It seems like overall, tf32 in ck tile is very new anyway. Might be worth discussing with the original authors of TF32 in CK tile (#4302).

Contributor

@krithalith krithalith left a comment

"Request changes" might be a bit strong but there are some things to think about / check.

@yungshengtu yungshengtu force-pushed the users/yungshengtu/ck/unification/add_mfma_tf32_specialisations branch from 16c5543 to 993d89a Compare May 8, 2026 12:15
@yungshengtu yungshengtu force-pushed the users/yungshengtu/ck/unification/add_mfma_tf32_specialisations branch from 993d89a to fde3885 Compare May 8, 2026 13:11
4 participants