From da5f60c7946f44991c1f3e7952b959182d9cd0fb Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Thu, 30 Apr 2026 04:52:53 -0500 Subject: [PATCH] [composablekernel] Support biased SwiGLU in MXFP4 MoE Port ROCm/composable_kernel#3735 into rocm-libraries so MXFP4 MoE handles bias-aware SwiGLU and avoids stale split-k accumulation when k_batch is one. Made-with: Cursor --- .../device/impl/device_moe_mx_gemm_bns.hpp | 6 +- .../impl/device_moe_mx_gemm_bpreshuffle.hpp | 6 +- .../gridwise_gemm_xdl_cshuffle_common.hpp | 9 ++ .../gpu/grid/gridwise_moe_mx_gemm_bns.hpp | 72 +++++++++- .../grid/gridwise_moe_mx_gemm_bpreshuffle.hpp | 136 ++++++++++++++++-- .../ops/flatmm/kernel/moe_flatmm_kernel.hpp | 27 +++- 6 files changed, 237 insertions(+), 19 deletions(-) diff --git a/projects/composablekernel/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp b/projects/composablekernel/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp index d81e73630b26..6865cc6ab01e 100644 --- a/projects/composablekernel/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp +++ b/projects/composablekernel/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp @@ -74,7 +74,8 @@ template + typename ComputeTypeB = BDataType, + typename BiasDataType = CDataType> struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle p_ds, void* p_c, index_t NumTokens, @@ -444,6 +446,7 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle(p_a_scale), static_cast(p_b), static_cast(p_b_scale), + static_cast(p_bias), p_ds, static_cast(p_c), NumTokens, @@ -493,6 +496,7 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle(p_a_scale), static_cast(p_b), static_cast(p_b_scale), + nullptr, p_ds, static_cast(p_c), M, // randoms set, no use diff --git a/projects/composablekernel/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp 
b/projects/composablekernel/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp index e64970145bd7..62610b94752f 100644 --- a/projects/composablekernel/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp +++ b/projects/composablekernel/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp @@ -74,7 +74,8 @@ template + typename ComputeTypeB = BDataType, + typename BiasDataType = CDataType> struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle p_ds, void* p_c, index_t NumTokens, @@ -471,6 +473,7 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle(p_a_scale), static_cast(p_b), static_cast(p_b_scale), + static_cast(p_bias), p_ds, static_cast(p_c), NumTokens, @@ -520,6 +523,7 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle(p_a_scale), static_cast(p_b), static_cast(p_b_scale), + nullptr, p_ds, static_cast(p_c), M, // randoms set, no use diff --git a/projects/composablekernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp b/projects/composablekernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp index 2f9a9cd21b18..a9661cb4fbe6 100644 --- a/projects/composablekernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp +++ b/projects/composablekernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp @@ -33,6 +33,15 @@ enum Activation swiglustep_and_mul = 2 }; +__host__ __device__ inline float compute_swiglu_and_mul(float gate, float linear) /* Clamped SwiGLU helper: caps gate at +7 (upper bound only), clamps linear to [-7, 7], and returns gate * sigmoid(1.702 * gate) * (linear + 1). */ +{ + constexpr float alpha = 1.702f; + constexpr float limit = 7.0f; + gate = gate < limit ? gate : limit; + linear = linear < limit ? (linear > -limit ? 
linear : -limit) : limit; /* NOTE(review): both clamps are comparison selects, so under IEEE comparison semantics a NaN operand compares false and is replaced by +limit instead of propagating -- confirm this is intended. */ + return gate * (1.0f / (1.0f + math::exp(-alpha * gate))) * (linear + 1.0f); +} + template + typename ComputeTypeB = BDataType, + typename BiasDataType = CDataType> struct GridwiseMoeGemmMXBNS : public GridwiseGemm_xdl_cshuffle_base< ALayout, @@ -703,6 +705,7 @@ struct GridwiseMoeGemmMXBNS const AScaleDataType* p_a_scale_grid_, const BDataType* p_b_grid_, const BScaleDataType* p_b_scale_grid_, + const BiasDataType* p_bias_grid_, std::array p_ds_grid_, CDataType* p_c_grid_, index_t NumTokens_, @@ -739,6 +742,7 @@ struct GridwiseMoeGemmMXBNS p_a_scale_grid{p_a_scale_grid_}, p_b_grid{p_b_grid_}, p_b_scale_grid{p_b_scale_grid_}, + p_bias_grid{p_bias_grid_}, p_ds_grid{}, p_c_grid{p_c_grid_}, a_element_op{a_element_op_}, @@ -762,6 +766,7 @@ struct GridwiseMoeGemmMXBNS const AScaleDataType* p_a_scale_grid; const BDataType* p_b_grid; const BScaleDataType* p_b_scale_grid; + const BiasDataType* p_bias_grid; DsGridPointer p_ds_grid; CDataType* p_c_grid; @@ -1099,6 +1104,7 @@ struct GridwiseMoeGemmMXBNS const AScaleDataType* p_a_scale_grid, const BDataType* p_b_grid, const BScaleDataType* p_b_scale_grid, + const BiasDataType* p_bias_grid, DsGridPointer& p_ds_grid, CDataType* p_c_grid, void* p_shared, @@ -1517,10 +1523,34 @@ struct GridwiseMoeGemmMXBNS static_assert(M5 == 4); const index_t m1 = get_warp_local_1d_id() / NWave; // Mwave id const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl; + const BiasDataType* p_bias_col = nullptr; + const BiasDataType* p_bias_col_up = nullptr; + if(p_bias_grid != nullptr) + { + const long_index_t expert_bias_stride = + static_cast(problem.N) * (IsInputGemm ? 
2 : 1); + const long_index_t base_n = static_cast(block_n_id) * NPerBlock + + static_cast(waveId_n) * NXdlPack * + NPerXdl + + threadIdx.x % NPerXdl; + p_bias_col = + p_bias_grid + static_cast(expert_id) * expert_bias_stride + base_n; + if constexpr(IsInputGemm) + { + p_bias_col_up = p_bias_col + problem.N; + } + } vector_type topk_weights; // for gemm2 only static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) { static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack + constexpr index_t n_offset = + n0 * NWave * NXdlPack * NPerXdl + inxdl * NPerXdl; + const float bias = + p_bias_col != nullptr ? type_convert(p_bias_col[n_offset]) : 0.0f; + const float bias_up = p_bias_col_up != nullptr + ? type_convert(p_bias_col_up[n_offset]) + : 0.0f; static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk @@ -1553,9 +1583,32 @@ struct GridwiseMoeGemmMXBNS gate = gate * topk_weights.AsType()[m5]; up = up * topk_weights.AsType()[m5]; } + if(p_bias_col != nullptr) + { + gate += bias; + up += bias_up; + } tensor_operation::element_wise::Silu{}(gate, gate); c_thread_buf_fp32(cidx) = gate * up; } + else if(ActivationOperation == + Activation::swiglustep_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + up = up * topk_weights.AsType()[m5]; + } + if(p_bias_col != nullptr) + { + gate += bias; + up += bias_up; + } + c_thread_buf_fp32(cidx) = + compute_swiglu_and_mul(gate, up); + } else if(ActivationOperation == Activation::gelu_and_mul) { float gate = c_thread_buf[cidx]; @@ -1565,6 +1618,11 @@ struct GridwiseMoeGemmMXBNS gate = gate * topk_weights.AsType()[m5]; up = up * topk_weights.AsType()[m5]; } + if(p_bias_col != nullptr) + { + gate += bias; + up += bias_up; + } tensor_operation::element_wise::Gelu{}(gate, gate); 
c_thread_buf_fp32(cidx) = gate * up; @@ -1581,13 +1639,16 @@ struct GridwiseMoeGemmMXBNS } else { - c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + float out_val = c_thread_buf[cidx]; + if(p_bias_col != nullptr) + { + out_val += bias; + } if constexpr(MulRoutedWeight) { - c_thread_buf_fp32(cidx) = - topk_weights.AsType()[m5] * - c_thread_buf_fp32[cidx]; + out_val = topk_weights.AsType()[m5] * out_val; } + c_thread_buf_fp32(cidx) = out_val; } }); }); @@ -1630,6 +1691,7 @@ struct GridwiseMoeGemmMXBNS const AScaleDataType* p_a_scale_grid, const BDataType* p_b_grid, const BScaleDataType* p_b_scale_grid, + const BiasDataType* p_bias_grid, DsGridPointer& p_ds_grid, CDataType* p_c_grid, void* p_shared, diff --git a/projects/composablekernel/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp b/projects/composablekernel/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp index 2e5f10e7915e..7fa2ba6d3940 100644 --- a/projects/composablekernel/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp +++ b/projects/composablekernel/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp @@ -59,6 +59,7 @@ __launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, karg.p_b_grid + splitk_batch_offset.b_k_split_offset, karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_bias_grid, karg.p_ds_grid, karg.p_c_grid, p_shared, @@ -101,6 +102,7 @@ __launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, karg.p_b_grid + splitk_batch_offset.b_k_split_offset, karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_bias_grid, karg.p_ds_grid, karg.p_c_grid, p_shared_0, @@ -170,7 +172,8 @@ template + typename ComputeTypeB = BDataType, + typename BiasDataType = CDataType> struct GridwiseMoeGemmMX_BPreshuffle : public 
GridwiseGemm_xdl_cshuffle_base< ALayout, @@ -845,6 +848,7 @@ struct GridwiseMoeGemmMX_BPreshuffle const AScaleDataType* p_a_scale_grid_, const BDataType* p_b_grid_, const BScaleDataType* p_b_scale_grid_, + const BiasDataType* p_bias_grid_, std::array p_ds_grid_, CDataType* p_c_grid_, index_t NumTokens_, @@ -881,6 +885,7 @@ struct GridwiseMoeGemmMX_BPreshuffle p_a_scale_grid{p_a_scale_grid_}, p_b_grid{p_b_grid_}, p_b_scale_grid{p_b_scale_grid_}, + p_bias_grid{p_bias_grid_}, p_ds_grid{}, p_c_grid{p_c_grid_}, a_element_op{a_element_op_}, @@ -904,6 +909,7 @@ struct GridwiseMoeGemmMX_BPreshuffle const AScaleDataType* p_a_scale_grid; const BDataType* p_b_grid; const BScaleDataType* p_b_scale_grid; + const BiasDataType* p_bias_grid; DsGridPointer p_ds_grid; CDataType* p_c_grid; @@ -1222,6 +1228,7 @@ struct GridwiseMoeGemmMX_BPreshuffle const AScaleDataType* p_a_scale_grid, const BDataType* p_b_grid, const BScaleDataType* p_b_scale_grid, + const BiasDataType* p_bias_grid, DsGridPointer& p_ds_grid, CDataType* p_c_grid, void* p_shared, @@ -1613,10 +1620,34 @@ struct GridwiseMoeGemmMX_BPreshuffle static_assert(M5 == 4); const index_t m1 = get_warp_local_1d_id() / NWave; const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl; + const BiasDataType* p_bias_col = nullptr; + const BiasDataType* p_bias_col_up = nullptr; + if(p_bias_grid != nullptr) + { + const long_index_t expert_bias_stride = + static_cast(problem.N) * (IsInputGemm ? 
2 : 1); + const long_index_t base_n = static_cast(block_n_id) * NPerBlock + + static_cast(waveId_n) * NXdlPack * + NPerXdl + + threadIdx.x % NPerXdl; + p_bias_col = + p_bias_grid + static_cast(expert_id) * expert_bias_stride + base_n; + if constexpr(IsInputGemm) + { + p_bias_col_up = p_bias_col + problem.N; + } + } vector_type topk_weights; // for gemm2 only static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) { static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack + constexpr index_t n_offset = + n0 * NWave * NXdlPack * NPerXdl + inxdl * NPerXdl; + const float bias = + p_bias_col != nullptr ? type_convert(p_bias_col[n_offset]) : 0.0f; + const float bias_up = p_bias_col_up != nullptr + ? type_convert(p_bias_col_up[n_offset]) + : 0.0f; static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk @@ -1648,9 +1679,32 @@ struct GridwiseMoeGemmMX_BPreshuffle gate = gate * topk_weights.AsType()[m5]; up = up * topk_weights.AsType()[m5]; } + if(p_bias_col != nullptr) + { + gate += bias; + up += bias_up; + } tensor_operation::element_wise::Silu{}(gate, gate); c_thread_buf_fp32(cidx) = gate * up; } + else if(ActivationOperation == + Activation::swiglustep_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + up = up * topk_weights.AsType()[m5]; + } + if(p_bias_col != nullptr) + { + gate += bias; + up += bias_up; + } + c_thread_buf_fp32(cidx) = + compute_swiglu_and_mul(gate, up); + } else if(ActivationOperation == Activation::gelu_and_mul) { float gate = c_thread_buf[cidx]; @@ -1660,19 +1714,27 @@ struct GridwiseMoeGemmMX_BPreshuffle gate = gate * topk_weights.AsType()[m5]; up = up * topk_weights.AsType()[m5]; } + if(p_bias_col != nullptr) + { + gate += bias; + up += bias_up; + } tensor_operation::element_wise::Gelu{}(gate, 
gate); c_thread_buf_fp32(cidx) = gate * up; } } else { - c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + float out_val = c_thread_buf[cidx]; + if(p_bias_col != nullptr) + { + out_val += bias; + } if constexpr(MulRoutedWeight) { - c_thread_buf_fp32(cidx) = - topk_weights.AsType()[m5] * - c_thread_buf_fp32[cidx]; + out_val = topk_weights.AsType()[m5] * out_val; } + c_thread_buf_fp32(cidx) = out_val; } }); }); @@ -1714,6 +1776,7 @@ struct GridwiseMoeGemmMX_BPreshuffle const AScaleDataType* p_a_scale_grid, const BDataType* p_b_grid, const BScaleDataType* p_b_scale_grid, + const BiasDataType* p_bias_grid, DsGridPointer& p_ds_grid, CDataType* p_c_grid, void* p_shared_0, @@ -2110,10 +2173,34 @@ struct GridwiseMoeGemmMX_BPreshuffle static_assert(M5 == 4); const index_t m1 = get_warp_local_1d_id() / NWave; const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl; + const BiasDataType* p_bias_col = nullptr; + const BiasDataType* p_bias_col_up = nullptr; + if(p_bias_grid != nullptr) + { + const long_index_t expert_bias_stride = + static_cast(problem.N) * (IsInputGemm ? 2 : 1); + const long_index_t base_n = static_cast(block_n_id) * NPerBlock + + static_cast(waveId_n) * NXdlPack * + NPerXdl + + threadIdx.x % NPerXdl; + p_bias_col = + p_bias_grid + static_cast(expert_id) * expert_bias_stride + base_n; + if constexpr(IsInputGemm) + { + p_bias_col_up = p_bias_col + problem.N; + } + } vector_type topk_weights; // for gemm2 only static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) { static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack + constexpr index_t n_offset = + n0 * NWave * NXdlPack * NPerXdl + inxdl * NPerXdl; + const float bias = + p_bias_col != nullptr ? type_convert(p_bias_col[n_offset]) : 0.0f; + const float bias_up = p_bias_col_up != nullptr + ? 
type_convert(p_bias_col_up[n_offset]) + : 0.0f; static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk @@ -2145,9 +2232,32 @@ struct GridwiseMoeGemmMX_BPreshuffle gate = gate * topk_weights.AsType()[m5]; up = up * topk_weights.AsType()[m5]; } + if(p_bias_col != nullptr) + { + gate += bias; + up += bias_up; + } tensor_operation::element_wise::Silu{}(gate, gate); c_thread_buf_fp32(cidx) = gate * up; } + else if(ActivationOperation == + Activation::swiglustep_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + up = up * topk_weights.AsType()[m5]; + } + if(p_bias_col != nullptr) + { + gate += bias; + up += bias_up; + } + c_thread_buf_fp32(cidx) = + compute_swiglu_and_mul(gate, up); + } else if(ActivationOperation == Activation::gelu_and_mul) { float gate = c_thread_buf[cidx]; @@ -2157,19 +2267,27 @@ struct GridwiseMoeGemmMX_BPreshuffle gate = gate * topk_weights.AsType()[m5]; up = up * topk_weights.AsType()[m5]; } + if(p_bias_col != nullptr) + { + gate += bias; + up += bias_up; + } tensor_operation::element_wise::Gelu{}(gate, gate); c_thread_buf_fp32(cidx) = gate * up; } } else { - c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + float out_val = c_thread_buf[cidx]; + if(p_bias_col != nullptr) + { + out_val += bias; + } if constexpr(MulRoutedWeight) { - c_thread_buf_fp32(cidx) = - topk_weights.AsType()[m5] * - c_thread_buf_fp32[cidx]; + out_val = topk_weights.AsType()[m5] * out_val; } + c_thread_buf_fp32(cidx) = out_val; } }); }); diff --git a/projects/composablekernel/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index bffd02c5f6c9..55552b70075f 100644 --- 
a/projects/composablekernel/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -1476,12 +1476,33 @@ struct MoeFlatmmKernel c_scatter_offsets[mIter], c_scatter_valids[mIter]); - if constexpr(!IsInputGemm || - decltype(c_block_window.get_bottom_tensor_view())::DstInMemOp == - memory_operation_enum::atomic_add) + if constexpr(!IsInputGemm) + { c_scatter_tile_window.update(c_out_tensor); + } + else if constexpr(decltype(c_block_window.get_bottom_tensor_view())::DstInMemOp == + memory_operation_enum::atomic_add) + { + if constexpr(IsGemm1SplitK) + { + if(kargs.k_batch == 1) + { + c_scatter_tile_window.store(c_out_tensor); + } + else + { + c_scatter_tile_window.update(c_out_tensor); + } + } + else + { + c_scatter_tile_window.update(c_out_tensor); + } + } else + { c_scatter_tile_window.store(c_out_tensor); + } if constexpr(iAccess != num_access - 1) {