[CK TILE] Unification Work – Add MFMA specialisations for tf32_t#6768
Conversation
aea639e to
b7497d0
Compare
6852902 to
95c7bad
Compare
039e593 to
ffd8dd5
Compare
45d739c to
55dcd32
Compare
0e4c508 to
16c5543
Compare
chris-tsiaousis-hpc
left a comment
There was a problem hiding this comment.
Great work! I've added some comments, you can consider them as NITs!
| static constexpr int bias = 127; | ||
| static constexpr std::uint32_t nan_mask = 0x7F800000u; | ||
| static constexpr std::uint32_t head_mask = 0xFF800000u; | ||
| static constexpr std::uint32_t mant_mask = 0x7FFFFFu; |
There was a problem hiding this comment.
Why 23 bits mantissa, shouldn't it be 10-bit (0x7FE000u)?
There was a problem hiding this comment.
I think this value should never be used anyway.
|
|
||
| CK_TILE_TYPE_CONVERT(fp16x2_t, fp16x2, fp32x2_t, fp32x2) | ||
| CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, fp32x2_t, fp32x2) | ||
| #undef CK_TILE_TYPE_CONVERT |
There was a problem hiding this comment.
Previously, CK_TILE_TYPE_CONVERT was #undef'd after use in this file. Now the macro is used across multiple files but never #undef'd, so it leaks to all downstream includers. Consider undefining it in each file after use, or in a central location.
| CK_TILE_HOST_DEVICE constexpr Y scaled_type_convert(X x, float scale); | ||
|
|
||
| static constexpr uint32_t float32_exponent_mask = 0x7f800000u; | ||
| [[deprecated]] static constexpr uint32_t float32_exponent_mask = 0x7f800000u; |
There was a problem hiding this comment.
Since you add those [[deprecated]] tags it might be useful to add a comment explaining WHAT to use instead?
I think our compiler also supports direct messages like [[deprecated("Use ck_tile::tf32_t directly")]] but I'm not sure...
wj-laskowski
left a comment
There was a problem hiding this comment.
Nice work! Few comments/questions from my side
| return float_to_tf32<rounding>(x); | ||
| } | ||
|
|
||
| CK_TILE_TYPE_CONVERT(fp16_t, fp16, float, float) |
There was a problem hiding this comment.
Why are those type converts moved to data type headers?
There was a problem hiding this comment.
It's because data type headers need type_convert, but type_convert.hpp also needs data type headers, resulting in circular inclusion.
| [[deprecated]] static constexpr uint32_t float32_exponent_mask = 0x7f800000u; | ||
|
|
||
| enum class tf32_rounding_mode | ||
| enum class [[deprecated]] tf32_rounding_mode |
There was a problem hiding this comment.
good job marking it. I think it might be useful to have an issue to clean this up in another PR
| // minimum finite value, or minimum positive normalized value for float | ||
| CK_TILE_HOST_DEVICE static constexpr tf32_t min() { return bit_cast<tf32_t>(0x00800000u); } | ||
|
|
||
| // minumum finite value |
| amdgcn_mma<I8, I8, I32, 16u, 16u, 64u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_16x16x64_i8 | ||
| amdgcn_mma<I8, I8, I32, 32u, 32u, 32u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE> // mfma_i32_32x32x32_i8 | ||
| amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_f16 | ||
| amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_f16 |
| static_cast<int>(CtrlFlags::Cbsz), | ||
| static_cast<int>(CtrlFlags::Abid), | ||
| static_cast<int>(CtrlFlags::Blgp))}; |
| static_cast<int>(CtrlFlags::Cbsz), | ||
| static_cast<int>(CtrlFlags::Abid), | ||
| static_cast<int>(CtrlFlags::Blgp))}; |
| #include "ck_tile/core/arch/mma/mma_op_family.hpp" | ||
| #include "ck_tile/core/config.hpp" | ||
| #include "ck_tile/core/numeric/bfloat16.hpp" | ||
| #include "ck_tile/core/numeric/float8.hpp" | ||
| #include "ck_tile/core/numeric/half.hpp" | ||
| #include "ck_tile/core/numeric/int8.hpp" | ||
| #include "ck_tile/core/numeric/integer.hpp" | ||
| #include "ck_tile/core/numeric/tfloat32.hpp" | ||
| #include "ck_tile/core/numeric/vector_type.hpp" | ||
| #include "ck_tile/core/utility/bit_cast.hpp" | ||
|
|
||
| #include <cstdint> |
There was a problem hiding this comment.
Did the pre-commit scripts do this? I don't see why your additions warrant so many new includes.
There was a problem hiding this comment.
The reason I'd like to add them is because when I was fixing the circular inclusion, headers are sort of indirectly depending on other headers, which is a bit hard to debug. I found that directly including what they are using is a good practice in headers.
| // RTNE rounding. | ||
| if((i & f32_exp_mask) != f32_exp_mask) | ||
| { | ||
| // Add rounding bias for round-to-nearest-even (RTNE) before truncating: | ||
| // - 0xfff is the rounding bias corresponding to the 13 fraction bits that | ||
| // will be discarded. | ||
| // - (i >> 13) & 1 extracts the least significant of those discarded bits and | ||
| // adding it implements "ties to even" (round half-way cases to even). | ||
| i += 0xFFF + ((i >> 13) & 1); | ||
| } |
There was a problem hiding this comment.
This is cool but I have some questions about this:
- Will this break for subnormal numbers?
- I think in some existing code there was a choice between truncation and RNTE, but here it is always the latter. Why? The ISA is actually inconsistent about what actually happens to the input, mentioning both rounding and simple truncation. In principle we could figure out which it is with some test code.
- What is fundamentally the point of explicitly rounding or truncating the mantissa like this? The intrinsic already does this, and for verification code you could just use float with a custom rtol and atol (which probably already gets calculated anyway). If there is no real reason to do this the type could be much simpler.
There was a problem hiding this comment.
- No. It allows values with an exponent field that is not all 1s, and subnormal numbers, which have all-0 exponent field, can also go into this block.
- I think the ISA explicitly indicates rounding behaviour in section 7.1.2 (see CDNA3 ISA).
Later, in section 12.10, it states that:
“XF32 is a FP32 operation with FP32 inputs and outputs but implemented at reduced intermediate precision where the mantissa is truncated to 10 bits (not including leading 1 for non-zero values), and results are accumulated into an FP32 value with a 23-bit mantissa.”
I interpret the “truncated to 10 bits” wording as describing the reduced precision representation, while the actual behaviour is consistent with the rounding rules defined earlier in section 7.1.2. In other words, the truncation here should be understood in the context of a rounding step rather than a literal bit-wise truncation.
From this perspective, floating-point operations are still expected to follow standard rounding behaviour to avoid introducing unnecessary bias or excessive error. - The main reason is to treat
tf32_tas a concrete, well-defined type. Although the intrinsics already handle rounding, defining it as a type ensures that all values stored astf32_tstrictly conform to its precision rules.
This makestf32_teasier to use in a consistent way for both initialisation and subsequent computation, without requiring users to reason about the underlying bit-level behaviour. In practice,tf32_tis always constrained to its 10-bit mantissa definition, which simplifies its usage model.
| // TF32 tag type: 1 sign bit, 8 exponent bits, 10 mantissa bits (see numeric_traits<tf32_t>) | ||
| struct tf32_t | ||
| { | ||
| }; | ||
|
|
There was a problem hiding this comment.
Yeah it's quite strange that they chose to go with an empty tag here, while all other custom types do contain data. This seems quite error prone and requires a bunch of code to have conditional extra conditionals just for tf32_t, which isn't necessary for the other custom types. Can we hunt down this code remove the now deprecated conditionals? Also I am wondering if it is possible that there is any code that relies on the fact that it is currently an empty tag...
There was a problem hiding this comment.
Removed in 9daa9a6. I checked while removing, and I didn't see anything that relies on the fact of an empty tag.
|
Nice work! More involved than I expected. It seems like they chose a strange implementation for tf32_t with it being an empty tag. This seems very error prone and doesn't align with how all other custom types are constructed. I do wonder though: What would happen if we didn't change anything about the current tf32_t implementation? Would our new intrinsics work or would something break? Another thing I was wondering is whether it would be safer to, for now, keep the old tf32_t implementation just in case it being an empty tag somehow affects performance or functionality somewhere, and then slowly removing it later while checking for perf regressions in specific tests / examples. This might be overkill though. It seems like overall, tf32 in ck tile is very new anyway. Might be worth discussing with the original authors of TF32 in CK tile (#4302). |
krithalith
left a comment
There was a problem hiding this comment.
"Request changes" might be a bit strong but there are some things to think about / check.
16c5543 to
993d89a
Compare
993d89a to
fde3885
Compare
Motivation
This PR adds two specialisations related to
tf32_t.Technical Details
This change treats
tf32_tas a concrete type rather than an emptystruct. It also adds two new specialisations for MFMA dense builtins and resolves existing circular include issues.Test Plan
All the new wrappers were added to the test suite in test_amdgcn_mma_layout.inc.
Test Result
Test should pass.
Submission Checklist