diff --git a/kernels/mfma_preshuffle_pipeline.py b/kernels/mfma_preshuffle_pipeline.py index 3f74a588..ac7aceb0 100644 --- a/kernels/mfma_preshuffle_pipeline.py +++ b/kernels/mfma_preshuffle_pipeline.py @@ -82,11 +82,11 @@ def make_preshuffle_b_layout( if elem_bytes not in (1, 2): raise ValueError(f"elem_bytes must be 1 or 2, got {elem_bytes!r}") - c_k_bytes = c_k * arith.constant(int(elem_bytes), index=True) + c_k_bytes = c_k * fx.Index(int(elem_bytes)) c_k0 = c_k_bytes // c64 n0 = c_n // c16 - c_kpack_elems = c_kpack if elem_bytes == 1 else (c_kpack // arith.constant(int(elem_bytes), index=True)) + c_kpack_elems = c_kpack if elem_bytes == 1 else (c_kpack // fx.Index(int(elem_bytes))) stride_nlane = c_kpack_elems stride_klane = c16 * stride_nlane @@ -108,6 +108,30 @@ def make_preshuffle_b_layout( return PreshuffleBLayout(layout_b=layout_b, kpack_bytes=kpack_bytes) +def _unpack_int4_to_int8_pair(packed32): + """Split packed int4 dword into two int8 dwords (even/odd nibbles). + + 7-op bit manipulation shared by all int4 unpack paths (W4A8, W4A16, W4A_FP8). + """ + c_08 = fx.Int32(0x08080808) + c_0f = fx.Int32(0x0F0F0F0F) + c_1e = fx.Int32(0x1E) + c_4 = fx.Int32(4) + s0 = (packed32 & c_08) * c_1e + even = (packed32 & c_0f) | s0 + t = packed32 >> c_4 + s1 = (t & c_08) * c_1e + odd = (t & c_0f) | s1 + return even, odd + + +def _pack_i32_pair_to_i64(lo, hi, vector): + """Pack two i32 values into one i64 via vector bitcast.""" + v2 = vector.from_elements(T.vec(2, T.i32), [lo, hi]) + v64 = vector.bitcast(T.vec(1, T.i64), v2) + return vector.extract(v64, static_position=[0], dynamic_position=[]) + + def _i8x4_in_i32_to_bf16x4_i64(val_i32, arith, vector, scale_val=None): """Convert one i32 (4 signed int8 bytes) to 4 bf16 packed as i64. @@ -205,18 +229,7 @@ def unpack_b_w4a16(packed32, arith, vector, scale_val=None): Takes raw packed32 from load_b_raw_w4a16 and produces (b0, b1) -- two i64 values each containing 4 bf16 for one MFMA. """ - c_08080808 = fx.Int32(0x08080808) - c_0f0f0f0f = fx.Int32(0x0F0F0F0F) - c_1e = fx.Int32(0x1E) - c_4_i32 = fx.Int32(4) - - s0 = (packed32 & c_08080808) * c_1e - even = (packed32 & c_0f0f0f0f) | s0 - - t = packed32 >> c_4_i32 - s1 = (t & c_08080808) * c_1e - odd = (t & c_0f0f0f0f) | s1 - + even, odd = _unpack_int4_to_int8_pair(packed32) b0 = _i8x4_in_i32_to_bf16x4_i64(even, arith, vector, scale_val=scale_val) b1 = _i8x4_in_i32_to_bf16x4_i64(odd, arith, vector, scale_val=scale_val) return (b0, b1) @@ -252,12 +265,12 @@ def load_b_pack_k32( raise ValueError(f"elem_bytes must be 1 or 2, got {elem_bytes!r}") c64 = fx.Index(64) - base_k_bytes = base_k * arith.constant(int(elem_bytes), index=True) + base_k_bytes = base_k * fx.Index(int(elem_bytes)) k0_base = base_k_bytes // c64 - k0 = k0_base + arith.constant(ki_step // 2, index=True) + k0 = k0_base + fx.Index(ki_step // 2) k1 = lane_div_16 half_bytes = kpack_bytes // 2 - k2_base = arith.constant((ki_step % 2) * half_bytes, index=True) + k2_base = fx.Index((ki_step % 2) * half_bytes) coord_pack = (n_blk, k0, k1, n_intra, fx.Index(0)) idx_pack = crd2idx(coord_pack, layout_b) @@ -273,22 +286,8 @@ def load_b_pack_k32( static_position=[0], dynamic_position=[], ) - - c_08080808 = fx.Int32(0x08080808) - c_0f0f0f0f = fx.Int32(0x0F0F0F0F) - c_1e = fx.Int32(0x1E) - c_4_i32 = fx.Int32(4) - - s0 = (packed32 & c_08080808) * c_1e - even = (packed32 & c_0f0f0f0f) | s0 - - t = packed32 >> c_4_i32 - s1 = (t & c_08080808) * c_1e - odd = (t & c_0f0f0f0f) | s1 - - v2 = vector.from_elements(T.vec(2, T.i32), [even, odd]) - v64 = vector.bitcast(T.vec(1, T.i64), v2) - return vector.extract(v64, static_position=[0], dynamic_position=[]) + even, odd = _unpack_int4_to_int8_pair(packed32) + return _pack_i32_pair_to_i64(even, odd, vector) vec_elems = kpack_bytes // int(elem_bytes) b16 = _buffer_load_vec( @@ -324,7 +323,7 @@ def tile_chunk_coord_i32( """Map (thread, chunk_id) -> (row_local, col_local_i32) for X/A loads.""" if chunk_i32 not in (1, 2, 4): raise ValueError(f"chunk_i32 must be one of (1,2,4), got {chunk_i32!r}") - chunk_off_i32 = arith.constant(i * total_threads * chunk_i32, index=True) + chunk_off_i32 = fx.Index(i * total_threads * chunk_i32) tile_idx_i32 = tx_i32_base + chunk_off_i32 coord_local = fx.idx2crd(tile_idx_i32, layout_tile_div4) row_local = fx.get(coord_local, 0) @@ -476,6 +475,251 @@ def lds_load_pack_k32( "lds_store_16b_xor16", "make_preshuffle_b_layout", "load_b_pack_k32", + "load_b_raw_w4a16", + "unpack_b_w4a16", + "load_b_raw_w4a16_groupwise", + "unpack_b_w4a16_groupwise", + "load_b_raw_w4a8_k64", + "load_b_raw_w4a8_groupwise_k64", + "unpack_b_w4a8", + "unpack_b_w4a_fp8", "swizzle_xor16", "tile_chunk_coord_i32", ] + + +# --------------------------------------------------------------------------- +# Groupwise scale load helper (shared by W4A16 and W4A8 groupwise paths) +# --------------------------------------------------------------------------- + +def _load_groupwise_scale( + buffer_ops, + arith, + *, + scale_rsrc, + expert_offset, + n_blk, + n_intra, + k_pos, + num_groups: int, + group_size: int, + n_per_expert: int, +): + """Load one per-group scale value from the scale buffer. + + Computes the linear index into the scale tensor from expert offset, + N position, and group index derived from ``k_pos``. + """ + c16 = fx.Index(16) + n_global = n_blk * c16 + n_intra + c_group_size = fx.Index(group_size) + c_gm1 = fx.Index(num_groups - 1) + c_npe = fx.Index(n_per_expert) + # n_global is the GLOBAL N index (includes expert offset), so use (G-1) + # to compensate: expert_offset*(G-1) + (expert_offset + n_within) = expert_offset*G + n_within + base_scale = expert_offset * c_gm1 + n_global + group_idx = k_pos // c_group_size + scale_idx_i32 = arith.index_cast(T.i32, base_scale + group_idx * c_npe) + return buffer_ops.buffer_load(scale_rsrc, scale_idx_i32, vec_width=1, dtype=T.f32) + + +# --------------------------------------------------------------------------- +# W4A16 groupwise load / unpack helpers +# --------------------------------------------------------------------------- + +def load_b_raw_w4a16_groupwise( + buffer_ops, + arith, + vector, + *, + arg_b, + b_rsrc, + layout_b, + base_k, + ku: int, + n_blk, + n_intra, + lane_div_16, + elem_type, + scale_rsrc, + expert_offset, + num_groups: int, + group_size: int, + n_per_expert: int, + kpack_bytes: int = 8, +): + """Phase 1 of W4A16 groupwise B load: buffer_loads for weight + scale. + + Reuses :func:`load_b_raw_w4a16` for the weight load, then issues an + additional ``buffer_load_dword`` for the per-group scale. + + Returns ``(packed32, scale_val)``. + """ + packed32 = load_b_raw_w4a16( + buffer_ops, arith, vector, + arg_b=arg_b, b_rsrc=b_rsrc, layout_b=layout_b, + base_k=base_k, ku=ku, + n_blk=n_blk, n_intra=n_intra, + lane_div_16=lane_div_16, elem_type=elem_type, + kpack_bytes=kpack_bytes, + ) + k_pos = base_k + fx.Index(ku * 32) + scale_val = _load_groupwise_scale( + buffer_ops, arith, + scale_rsrc=scale_rsrc, expert_offset=expert_offset, + n_blk=n_blk, n_intra=n_intra, k_pos=k_pos, + num_groups=num_groups, group_size=group_size, n_per_expert=n_per_expert, + ) + return (packed32, scale_val) + + +def unpack_b_w4a16_groupwise(packed32, scale_val, arith, vector): + """Phase 2 of W4A16 groupwise: unpack + scale + convert to bf16.""" + return unpack_b_w4a16(packed32, arith, vector, scale_val=scale_val) + + +# --------------------------------------------------------------------------- +# W4A8 load / unpack helpers (8B K64 loads) +# --------------------------------------------------------------------------- + +def load_b_raw_w4a8_k64( + buffer_ops, + arith, + vector, + *, + arg_b, + b_rsrc, + layout_b, + base_k, + ku: int, + n_blk, + n_intra, + lane_div_16, + elem_type, + kpack_bytes: int = 8, +): + """Phase 1 of W4A8 per-row B load: 8-byte buffer_load_dwordx2 for one K64 step. + + Loads both K32 halves in a single VMEM instruction (``buffer_load_dwordx2``). + Returns ``(packed32_half0, packed32_half1)`` for :func:`unpack_b_w4a8`. + """ + if kpack_bytes != 8: + raise ValueError(f"W4A8 requires kpack_bytes=8, got {kpack_bytes!r}") + + c64 = fx.Index(64) + k0_base = base_k // c64 + k0 = k0_base + fx.Index(ku) + k1 = lane_div_16 + + coord_pack = (n_blk, k0, k1, n_intra, fx.Index(0)) + idx_pack = crd2idx(coord_pack, layout_b) + + b8 = _buffer_load_vec( + buffer_ops, vector, b_rsrc, idx_pack, + elem_type=elem_type, vec_elems=8, elem_bytes=1, offset_in_bytes=True, + ) + b_i32x2 = vector.bitcast(T.vec(2, T.i32), b8) + half0 = vector.extract(b_i32x2, static_position=[0], dynamic_position=[]) + half1 = vector.extract(b_i32x2, static_position=[1], dynamic_position=[]) + return (half0, half1) + + +def load_b_raw_w4a8_groupwise_k64( + buffer_ops, + arith, + vector, + *, + arg_b, + b_rsrc, + layout_b, + base_k, + ku: int, + n_blk, + n_intra, + lane_div_16, + elem_type, + scale_rsrc, + expert_offset, + num_groups: int, + group_size: int, + n_per_expert: int, + kpack_bytes: int = 8, +): + """Phase 1 of W4A8 groupwise B load: 8B weight + two scale loads per K64. + + Reuses :func:`load_b_raw_w4a8_k64` for the weight load, then issues two + ``buffer_load_dword`` for per-group scales (each K32 half may belong to a + different group). + + Returns ``(half0, half1, scale0, scale1)``. + """ + half0, half1 = load_b_raw_w4a8_k64( + buffer_ops, arith, vector, + arg_b=arg_b, b_rsrc=b_rsrc, layout_b=layout_b, + base_k=base_k, ku=ku, + n_blk=n_blk, n_intra=n_intra, + lane_div_16=lane_div_16, elem_type=elem_type, + kpack_bytes=kpack_bytes, + ) + + scale_kw = dict( + scale_rsrc=scale_rsrc, expert_offset=expert_offset, + n_blk=n_blk, n_intra=n_intra, + num_groups=num_groups, group_size=group_size, n_per_expert=n_per_expert, + ) + scale0 = _load_groupwise_scale( + buffer_ops, arith, k_pos=base_k + fx.Index(ku * 2 * 32), **scale_kw, + ) + scale1 = _load_groupwise_scale( + buffer_ops, arith, k_pos=base_k + fx.Index((ku * 2 + 1) * 32), **scale_kw, + ) + return (half0, half1, scale0, scale1) + + +def unpack_b_w4a8(packed32, arith, vector): + """Phase 2 of W4A8 B load: 7-op unpack from packed int4 to int8 i64. + + Takes a raw ``packed32`` (one dword of packed int4) and produces one i64 + value containing 8 signed int8 bytes for one MFMA K32 step. + """ + even, odd = _unpack_int4_to_int8_pair(packed32) + return _pack_i32_pair_to_i64(even, odd, vector) + + +def unpack_b_w4a_fp8(packed32, arith, vector, rocdl): + """Unpack packed int4 (i32) to fp8 i64 for mfma_f32_16x16x32_fp8_fp8. + + Pipeline: int4 -> int8 (7-op unpack) -> f32 (byte extract + sitofp) + -> fp8 (cvt_pk_fp8_f32) -> i64. + """ + even, odd = _unpack_int4_to_int8_pair(packed32) + + c_8 = fx.Int32(8) + c_16 = fx.Int32(16) + c_24 = fx.Int32(24) + + from flydsl._mlir.dialects._arith_ops_gen import ShRSIOp as _ShRSIOp + _uw = arith._to_raw + _av = arith.ArithValue + + def _i32_int8x4_to_fp8x4(val): + """Convert i32 containing 4 signed int8 bytes -> i32 containing 4 fp8 bytes.""" + def _sext_byte(src, shl_amount, shr_amount): + shifted = src << shl_amount + shrsi_result = _ShRSIOp(_uw(shifted), _uw(shr_amount)).result + return _uw(arith.sitofp(T.f32, _av(shrsi_result))) + + f0 = _sext_byte(val, c_24, c_24) + f1 = _sext_byte(val, c_16, c_24) + f2 = _sext_byte(val, c_8, c_24) + b3 = _ShRSIOp(_uw(val), _uw(c_24)).result + f3 = _uw(arith.sitofp(T.f32, _av(b3))) + + zero = _uw(fx.Int32(0)) + pk = rocdl.cvt_pk_fp8_f32(src_a=f0, src_b=f1, old=zero, word_sel=0, res=T.i32) + pk = rocdl.cvt_pk_fp8_f32(src_a=f2, src_b=f3, old=_uw(pk), word_sel=1, res=T.i32) + return pk + + even_fp8 = _i32_int8x4_to_fp8x4(even) + odd_fp8 = _i32_int8x4_to_fp8x4(odd) + return _pack_i32_pair_to_i64(even_fp8, odd_fp8, vector) diff --git a/kernels/moe_blockscale_2stage.py b/kernels/moe_blockscale_2stage.py index e20850be..411c222c 100644 --- a/kernels/moe_blockscale_2stage.py +++ b/kernels/moe_blockscale_2stage.py @@ -26,7 +26,9 @@ from flydsl._mlir import ir from flydsl._mlir.dialects import llvm, scf, memref +from flydsl._mlir.dialects import fly as _fly_dialect from flydsl._mlir.dialects import math as math_dialect +from kernels.kernels_common import _create_llvm_ptr from flydsl.expr.typing import T from flydsl.expr.arith import ArithValue @@ -2247,7 +2249,7 @@ def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row): # stable path here.) out_base_idx = None if out_is_bf16: - out_base_idx = memref.extract_aligned_pointer_as_index(arg_out) + _out_raw = arg_out.value if hasattr(arg_out, "value") else arg_out; _out_ptr = _fly_dialect.extract_aligned_pointer_as_index(ir.Type.parse("!llvm.ptr"), _out_raw); out_base_idx = arith.index_cast(T.index, llvm.PtrToIntOp(T.i64, _out_ptr).result) def write_row_to_lds( *, @@ -2311,7 +2313,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): byte_off = idx_elem_even * c2_i32 byte_off_idx = arith.index_cast(T.index, byte_off) ptr_addr_idx = out_base_idx + byte_off_idx - out_ptr = buffer_ops.create_llvm_ptr(ptr_addr_idx, address_space=1) + out_ptr = _create_llvm_ptr(ptr_addr_idx, address_space=1) out_ptr_v = out_ptr._value if hasattr(out_ptr, "_value") else out_ptr frag_v = frag._value if hasattr(frag, "_value") else frag llvm.AtomicRMWOp( diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 9eefa4ed..7486dc05 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -41,6 +41,8 @@ def bf16_global_atomics_arch_description() -> str: from flydsl._mlir import ir from flydsl._mlir.dialects import llvm, scf, memref +from flydsl._mlir.dialects import fly as _fly_dialect +from kernels.kernels_common import _create_llvm_ptr from flydsl.expr.typing import T @@ -55,6 +57,12 @@ def bf16_global_atomics_arch_description() -> str: load_b_pack_k32, load_b_raw_w4a16, unpack_b_w4a16, + load_b_raw_w4a16_groupwise, + unpack_b_w4a16_groupwise, + load_b_raw_w4a8_k64, + load_b_raw_w4a8_groupwise_k64, + unpack_b_w4a8, + unpack_b_w4a_fp8, tile_chunk_coord_i32, swizzle_xor16, ) @@ -118,20 +126,25 @@ def compile_moe_gemm1( have a distinct input scaling before quantization). - "int4": W4A8 path: X is int8, W is packed int4 (2 values per byte) unpacked to int8 in-kernel - "int4_bf16": W4A16 path: X is bf16, W is packed int4 unpacked to bf16 in-kernel + - "int4_fp8": W4A_FP8 path: X is fp8, W is packed int4 unpacked to fp8 in-kernel """ gpu_arch = get_hip_arch() allocator = SmemAllocator(None, arch=gpu_arch) _state = {} # legacy; kept until stage2/reduction are migrated - if in_dtype not in ("fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16"): + _valid_dtypes = ("fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16", "int4_fp8") + if in_dtype not in _valid_dtypes: raise ValueError( - f"in_dtype must be one of ('fp8','fp16','bf16','int8','int8smooth','int4','int4_bf16'), got {in_dtype!r}" + f"in_dtype must be one of {_valid_dtypes}, got {in_dtype!r}" ) - is_int4_bf16 = in_dtype == "int4_bf16" + is_int4_bf16 = in_dtype == "int4_bf16" # W4A16: bf16 activations, packed int4 weights + is_int4_fp8 = in_dtype == "int4_fp8" # W4A_FP8: fp8 activations, packed int4 weights is_f16 = in_dtype == "fp16" is_bf16 = is_int4_bf16 or in_dtype == "bf16" is_f16_or_bf16 = is_f16 or is_bf16 + # fp8/int8/int4/int4_fp8: both scales needed + needs_scale_x = not is_f16_or_bf16 # True for fp8/int8/int4/int4_fp8 needs_scale_w = (not is_f16_or_bf16) or is_int4_bf16 elem_bytes = 2 if is_f16_or_bf16 else 1 if out_dtype not in ("f16", "bf16"): @@ -150,12 +163,32 @@ def compile_moe_gemm1( f"(tile_k={tile_k}, elem_bytes={elem_bytes})" ) is_int4 = in_dtype == "int4" + # w_is_int4: True for any variant where weights are packed int4. + w_is_int4 = is_int4 or is_int4_bf16 or is_int4_fp8 # INT4 here means W4A8: X is int8, W is packed int4 and unpacked to int8 in-kernel. is_int8 = (in_dtype == "int8") or is_int4 x_is_token_slot = in_dtype == "int8smooth" # "int8smooth" still uses int8 MFMA, but X/scale_x are provided per (token,slot). is_int8 = is_int8 or x_is_token_slot + # Group-wise scale support for W4A16, W4A8, and W4A_FP8 + # NOTE: Only group_size=32 is supported due to int4 preshuffle layout constraints. + use_groupwise_scale = w_is_int4 and group_size > 0 + if use_groupwise_scale and group_size != 32: + raise ValueError( + f"FlyDSL groupwise scale only supports group_size=32, got {group_size}. " + f"This is due to int4 preshuffle layout constraints. " + f"Please use Triton kernel for other group sizes." + ) + is_int4_groupwise = is_int4 and use_groupwise_scale + is_int4_fp8_groupwise = is_int4_fp8 and use_groupwise_scale + is_int4_bf16_groupwise = is_int4_bf16 and use_groupwise_scale + num_groups = model_dim // group_size if use_groupwise_scale else 1 + scale_w_size_stage1 = experts * (2 * inter_dim) * num_groups + # For groupwise scale, weight scale is applied per-group in the K loop, + # so epilogue can skip weight scale multiplication (uses 1.0 for sw). + # For W4A8/W4A_FP8 groupwise: still need scale_x in epilogue. + mfma_i32_k32 = None if is_int8: mfma_i32_k32 = getattr(rocdl, "mfma_i32_16x16x32i8", None) or getattr( @@ -181,8 +214,8 @@ def compile_moe_gemm1( DYN = ir.ShapedType.get_dynamic_size() size_out = DYN size_x = DYN - # W is packed int4 for W4A8: 2 values per byte. - size_w = (experts * (2 * inter_dim) * model_dim) // 2 if (is_int4 or is_int4_bf16) else (experts * (2 * inter_dim) * model_dim) + # W is packed int4 for W4A8/W4A16/W4A_FP8: 2 values per byte. + size_w = (experts * (2 * inter_dim) * model_dim) // 2 if w_is_int4 else (experts * (2 * inter_dim) * model_dim) size_sorted = DYN size_expert_ids = DYN @@ -213,10 +246,12 @@ def compile_moe_gemm1( epilog_tag = "cshuffle" if use_cshuffle_epilog else "direct" # IMPORTANT: module name participates in FlyDSL's compile cache key. # Keep an explicit ABI tag so signature changes can't accidentally reuse an old binary. + _gs_tag = f"_g{group_size}" if use_groupwise_scale else "" module_name = ( f"mfma_moe1_{in_dtype}_{out_dtype}_{epilog_tag}" f"_t{tile_m}x{tile_n}x{tile_k}" - f"_abi3" # also mask sentinel token ids on loads (X/scale_x) to avoid illegal address faults + f"{_gs_tag}" + f"_abi3" # also mask sentinel token ids on loads (X/scale_x) to avoid illegal address faults ).replace("-", "_") # ── LDS sizing (pure Python; no MLIR Context needed) ───────────────────── @@ -260,8 +295,8 @@ def moe_gemm1( tokens_i32_v = i32_tokens_in k_i32_v = i32_k_in x_elem = T.bf16 if is_bf16 else (T.f16 if is_f16 else (T.i8 if is_int8 else T.f8)) - # For int4/int4_bf16, weights are stored as packed bytes (i8) and unpacked in-kernel. - w_elem = T.i8 if (is_int4 or is_int4_bf16) else (T.bf16 if is_bf16 else (T.f16 if is_f16 else (T.i8 if is_int8 else T.f8))) + # For int4/int4_bf16/int4_fp8, weights are stored as packed bytes (i8) and unpacked in-kernel. + w_elem = T.i8 if w_is_int4 else (T.bf16 if is_bf16 else (T.f16 if is_f16 else (T.i8 if is_int8 else T.f8))) vec16_elems = 16 if elem_bytes == 1 else 8 vec8_elems = 8 if elem_bytes == 1 else 4 vec4_elems = 4 if elem_bytes == 1 else 2 @@ -284,18 +319,20 @@ def silu(x): acc_init = ( arith.constant_vector(0, T.i32x4) - if is_int8 + if (is_int8 and not is_int4_groupwise) else arith.constant_vector(0.0, T.f32x4) ) + zero_i32_acc = arith.constant_vector(0, T.i32x4) if is_int4_groupwise else None + zero_f32_acc = arith.constant_vector(0.0, T.f32x4) if is_int4_fp8_groupwise else None # Layouts (use i32 values; fly.make_shape requires i32/i64, not index) layout_x = fx.make_layout((tokens_i32_v, k_i32_v), stride=(k_i32_v, 1)) # B preshuffle layout: match GEMM test helper exactly. c_n_total = arith.index(experts * (2 * inter_dim)) - # For packed int4 (W4A8/W4A16), kpack_bytes=8. - kpack_bytes = 8 if (is_int4 or is_int4_bf16) else 16 - w_elem_bytes = 1 if (is_int4 or is_int4_bf16) else elem_bytes + # For packed int4 (W4A8/W4A16/W4A_FP8), kpack_bytes=8. + kpack_bytes = 8 if w_is_int4 else 16 + w_elem_bytes = 1 if w_is_int4 else elem_bytes b_layout = make_preshuffle_b_layout( arith, c_n=c_n_total, c_k=k_in, kpack_bytes=kpack_bytes, elem_bytes=w_elem_bytes ) @@ -603,10 +640,42 @@ def load_b_tile(base_k, blk_list, intra_list): Returns a list of length `k_unroll`, where each entry is a tuple: (packs_half0[ni], packs_half1[ni]) for the K64 micro-step. + For groupwise variants, each entry also includes per-group scales: + (packs0[ni], packs1[ni], scales0[ni], scales1[ni]) """ - if is_int4_bf16: - # W4A16: 2-phase load+unpack for VMEM latency hiding - # Phase 1: Issue ALL buffer_loads first. + if is_int4_bf16_groupwise: + # W4A16 groupwise: 2-phase load+unpack with per-group scale + raw_data = [] + for ku in range_constexpr(k_unroll): + raw_ku = [] + for ni in range_constexpr(num_acc_n): + packed32, scale_val = load_b_raw_w4a16_groupwise( + buffer_ops, arith, vector, + arg_b=arg_w, b_rsrc=w_rsrc, layout_b=layout_b, + base_k=base_k, ku=ku, + n_blk=blk_list[ni], n_intra=intra_list[ni], + lane_div_16=lane_div_16, elem_type=w_elem, + scale_rsrc=sw_rsrc, + expert_offset=expert_off_idx, + num_groups=num_groups, + group_size=group_size, + n_per_expert=2*inter_dim, + kpack_bytes=kpack_bytes, + ) + raw_ku.append((packed32, scale_val)) + raw_data.append(raw_ku) + b_tile = [] + for ku in range_constexpr(k_unroll): + packs0, packs1 = [], [] + for ni in range_constexpr(num_acc_n): + packed32, scale_val = raw_data[ku][ni] + b0, b1 = unpack_b_w4a16_groupwise(packed32, scale_val, arith, vector) + packs0.append(b0) + packs1.append(b1) + b_tile.append((packs0, packs1)) + return b_tile + elif is_int4_bf16: + # W4A16 per-row: 2-phase load+unpack for VMEM latency hiding raw_data = [] for ku in range_constexpr(k_unroll): raw_ku = [] @@ -621,7 +690,6 @@ def load_b_tile(base_k, blk_list, intra_list): ) raw_ku.append(raw) raw_data.append(raw_ku) - # Phase 2: Unpack ALL (by now early loads have completed). b_tile = [] for ku in range_constexpr(k_unroll): packs0 = [] @@ -632,19 +700,109 @@ def load_b_tile(base_k, blk_list, intra_list): packs1.append(b1) b_tile.append((packs0, packs1)) return b_tile - b_tile = [] - for ku in range_constexpr(k_unroll): - packs0 = [] - packs1 = [] - for ni in range_constexpr(num_acc_n): - ki0 = (ku * 2) + 0 - ki1 = (ku * 2) + 1 - b0 = load_b_pack(base_k, ki0, ni, blk_list, intra_list) - b1 = load_b_pack(base_k, ki1, ni, blk_list, intra_list) - packs0.append(b0) - packs1.append(b1) - b_tile.append((packs0, packs1)) - return b_tile + elif is_int4_groupwise: + # W4A8 groupwise: 8B K64 weight load + 2 scale loads + unpack + b_tile = [] + for ku in range_constexpr(k_unroll): + packs0, packs1 = [], [] + scales0, scales1 = [], [] + for ni in range_constexpr(num_acc_n): + h0, h1, sc0, sc1 = load_b_raw_w4a8_groupwise_k64( + buffer_ops, arith, vector, + arg_b=arg_w, b_rsrc=w_rsrc, layout_b=layout_b, + base_k=base_k, ku=ku, + n_blk=blk_list[ni], n_intra=intra_list[ni], + lane_div_16=lane_div_16, elem_type=w_elem, + scale_rsrc=sw_rsrc, + expert_offset=expert_off_idx, + num_groups=num_groups, + group_size=group_size, + n_per_expert=2*inter_dim, + kpack_bytes=kpack_bytes, + ) + packs0.append(unpack_b_w4a8(h0, arith, vector)) + packs1.append(unpack_b_w4a8(h1, arith, vector)) + scales0.append(sc0) + scales1.append(sc1) + b_tile.append((packs0, packs1, scales0, scales1)) + return b_tile + elif is_int4_fp8_groupwise: + # W4A_FP8 groupwise: 8B K64 weight load + 2 scale loads + unpack to fp8 + b_tile = [] + for ku in range_constexpr(k_unroll): + packs0, packs1 = [], [] + scales0, scales1 = [], [] + for ni in range_constexpr(num_acc_n): + h0, h1, sc0, sc1 = load_b_raw_w4a8_groupwise_k64( + buffer_ops, arith, vector, + arg_b=arg_w, b_rsrc=w_rsrc, layout_b=layout_b, + base_k=base_k, ku=ku, + n_blk=blk_list[ni], n_intra=intra_list[ni], + lane_div_16=lane_div_16, elem_type=w_elem, + scale_rsrc=sw_rsrc, + expert_offset=expert_off_idx, + num_groups=num_groups, + group_size=group_size, + n_per_expert=2*inter_dim, + kpack_bytes=kpack_bytes, + ) + packs0.append(unpack_b_w4a_fp8(h0, arith, vector, rocdl)) + packs1.append(unpack_b_w4a_fp8(h1, arith, vector, rocdl)) + scales0.append(sc0) + scales1.append(sc1) + b_tile.append((packs0, packs1, scales0, scales1)) + return b_tile + elif is_int4_fp8: + # W4A_FP8 per-row: 8B K64 loads + unpack to fp8 + b_tile = [] + for ku in range_constexpr(k_unroll): + packs0, packs1 = [], [] + for ni in range_constexpr(num_acc_n): + h0, h1 = load_b_raw_w4a8_k64( + buffer_ops, arith, vector, + arg_b=arg_w, b_rsrc=w_rsrc, layout_b=layout_b, + base_k=base_k, ku=ku, + n_blk=blk_list[ni], n_intra=intra_list[ni], + lane_div_16=lane_div_16, elem_type=w_elem, + kpack_bytes=kpack_bytes, + ) + packs0.append(unpack_b_w4a_fp8(h0, arith, vector, rocdl)) + packs1.append(unpack_b_w4a_fp8(h1, arith, vector, rocdl)) + b_tile.append((packs0, packs1)) + return b_tile + elif is_int4: + # W4A8 per-row: 8B K64 loads + unpack + b_tile = [] + for ku in range_constexpr(k_unroll): + packs0, packs1 = [], [] + for ni in range_constexpr(num_acc_n): + h0, h1 = load_b_raw_w4a8_k64( + buffer_ops, arith, vector, + arg_b=arg_w, b_rsrc=w_rsrc, layout_b=layout_b, + base_k=base_k, ku=ku, + n_blk=blk_list[ni], n_intra=intra_list[ni], + lane_div_16=lane_div_16, elem_type=w_elem, + kpack_bytes=kpack_bytes, + ) + packs0.append(unpack_b_w4a8(h0, arith, vector)) + packs1.append(unpack_b_w4a8(h1, arith, vector)) + b_tile.append((packs0, packs1)) + return b_tile + else: + # fp8/int8/bf16/fp16: original code path + b_tile = [] + for ku in range_constexpr(k_unroll): + packs0 = [] + packs1 = [] + for ni in range_constexpr(num_acc_n): + ki0 = (ku * 2) + 0 + ki1 = (ku * 2) + 1 + b0 = load_b_pack(base_k, ki0, ni, blk_list, intra_list) + b1 = load_b_pack(base_k, ki1, ni, blk_list, intra_list) + packs0.append(b0) + packs1.append(b1) + b_tile.append((packs0, packs1)) + return b_tile acc_gate = [acc_init] * (num_acc_n * m_repeat) acc_up = [acc_init] * (num_acc_n * m_repeat) @@ -740,7 +898,7 @@ def compute_tile( # Optional: prefetch epilogue scales while we are about to run the last MFMA tile, # matching the preshuffle GEMM pattern of overlapping scale loads with MFMA. epilogue_pf = None - if prefetch_epilogue: + if prefetch_epilogue and not use_groupwise_scale: expert_off_pf = expert_off_idx sw_gate_pf = [] sw_up_pf = [] @@ -783,40 +941,112 @@ def mfma_k64(acc_in, a0, a1, b0, b1): b1v = _i64_to_v4i16(b1) acc_mid = mfma_fn(mfma_res_ty, [a0v, b0v, acc_in, 0, 0, 0]) return mfma_fn(mfma_res_ty, [a1v, b1v, acc_mid, 0, 0, 0]) + # fp8/int8: both A and B are raw i64. acc_mid = mfma_fn(mfma_res_ty, [a0, b0, acc_in, 0, 0, 0]) return mfma_fn(mfma_res_ty, [a1, b1, acc_mid, 0, 0, 0]) + + def _acc_scaled_i32_to_f32(f32_acc_vec, i32_partial_vec, scale_val): + """i32 MFMA partial -> sitofp -> scale -> add to f32 accumulator.""" + new_vals = [] + for ii in range_constexpr(4): + vi = vector.extract(i32_partial_vec, static_position=[ii], dynamic_position=[]) + vf = arith.sitofp(T.f32, vi) * scale_val + old = vector.extract(f32_acc_vec, static_position=[ii], dynamic_position=[]) + new_vals.append(old + vf) + return vector.from_elements(T.f32x4, new_vals) + + def _acc_scaled_f32(f32_acc_vec, f32_partial_vec, scale_val): + """FP8 MFMA f32 partial -> scale -> add to f32 accumulator.""" + new_vals = [] + for ii in range_constexpr(4): + vi = vector.extract(f32_partial_vec, static_position=[ii], dynamic_position=[]) + vf = vi * scale_val + old = vector.extract(f32_acc_vec, static_position=[ii], dynamic_position=[]) + new_vals.append(old + vf) + return vector.from_elements(T.f32x4, new_vals) + + if is_int4_fp8_groupwise: + # W4A_FP8 groupwise: per-K32 FP8 MFMA with fresh f32 acc -> scale -> f32 running acc. + for ku in range_constexpr(k_unroll): + b_gate_packs0, b_gate_packs1, b_gate_sc0, b_gate_sc1 = b_gate_tile_in[ku] + b_up_packs0, b_up_packs1, b_up_sc0, b_up_sc1 = b_up_tile_in[ku] + ki64 = arith.index(ku * 64) + col_base = col_offset_base_bytes + ki64 + for mi in range_constexpr(m_repeat): + mi_val = arith.index(mi * 16) + curr_row_a_lds = row_a_lds + mi_val + if (a0_prefetch is not None) and (ku == 0) and (mi == 0): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) + for ni in range_constexpr(num_acc_n): + acc_idx = mi * num_acc_n + ni + # Half 0: fresh acc -> MFMA -> scale -> add to running acc + g0 = rocdl.mfma_f32_16x16x32_fp8_fp8(T.f32x4, [a0, b_gate_packs0[ni], zero_f32_acc, 0, 0, 0]) + gate_list[acc_idx] = _acc_scaled_f32(gate_list[acc_idx], g0, b_gate_sc0[ni]) + u0 = rocdl.mfma_f32_16x16x32_fp8_fp8(T.f32x4, [a0, b_up_packs0[ni], zero_f32_acc, 0, 0, 0]) + up_list[acc_idx] = _acc_scaled_f32(up_list[acc_idx], u0, b_up_sc0[ni]) + # Half 1 + g1 = rocdl.mfma_f32_16x16x32_fp8_fp8(T.f32x4, [a1, b_gate_packs1[ni], zero_f32_acc, 0, 0, 0]) + gate_list[acc_idx] = _acc_scaled_f32(gate_list[acc_idx], g1, b_gate_sc1[ni]) + u1 = rocdl.mfma_f32_16x16x32_fp8_fp8(T.f32x4, [a1, b_up_packs1[ni], zero_f32_acc, 0, 0, 0]) + up_list[acc_idx] = _acc_scaled_f32(up_list[acc_idx], u1, b_up_sc1[ni]) + elif is_int4_groupwise: + # W4A8 groupwise: per-K32 i32 MFMA with fresh i32 acc -> sitofp -> scale -> f32 running acc. + for ku in range_constexpr(k_unroll): + b_gate_packs0, b_gate_packs1, b_gate_sc0, b_gate_sc1 = b_gate_tile_in[ku] + b_up_packs0, b_up_packs1, b_up_sc0, b_up_sc1 = b_up_tile_in[ku] + ki64 = arith.index(ku * 64) + col_base = col_offset_base_bytes + ki64 + for mi in range_constexpr(m_repeat): + mi_val = arith.index(mi * 16) + curr_row_a_lds = row_a_lds + mi_val + if (a0_prefetch is not None) and (ku == 0) and (mi == 0): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) + for ni in range_constexpr(num_acc_n): + acc_idx = mi * num_acc_n + ni + g0 = mfma_i32_k32(T.i32x4, [a0, b_gate_packs0[ni], zero_i32_acc, 0, 0, 0]) + gate_list[acc_idx] = _acc_scaled_i32_to_f32(gate_list[acc_idx], g0, b_gate_sc0[ni]) + u0 = mfma_i32_k32(T.i32x4, [a0, b_up_packs0[ni], zero_i32_acc, 0, 0, 0]) + up_list[acc_idx] = _acc_scaled_i32_to_f32(up_list[acc_idx], u0, b_up_sc0[ni]) + g1 = mfma_i32_k32(T.i32x4, [a1, b_gate_packs1[ni], zero_i32_acc, 0, 0, 0]) + gate_list[acc_idx] = _acc_scaled_i32_to_f32(gate_list[acc_idx], g1, b_gate_sc1[ni]) + u1 = mfma_i32_k32(T.i32x4, [a1, b_up_packs1[ni], zero_i32_acc, 0, 0, 0]) + up_list[acc_idx] = _acc_scaled_i32_to_f32(up_list[acc_idx], u1, b_up_sc1[ni]) + else: + for ku in range_constexpr(k_unroll): + b_gate_packs0, b_gate_packs1 = b_gate_tile_in[ku] + b_up_packs0, b_up_packs1 = b_up_tile_in[ku] + ki64 = arith.index(ku * 64) + col_base = col_offset_base_bytes + ki64 - for ku in range_constexpr(k_unroll): - b_gate_packs0, b_gate_packs1 = b_gate_tile_in[ku] - b_up_packs0, b_up_packs1 = b_up_tile_in[ku] - ki64 = arith.index(ku * 64) - col_base = col_offset_base_bytes + ki64 - - for mi in range_constexpr(m_repeat): - mi_val = arith.index(mi * 16) - curr_row_a_lds = row_a_lds + mi_val - - if (a0_prefetch is not None) and (ku == 0) and (mi == 0): - a0, a1 = a0_prefetch - else: - a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) + for mi in range_constexpr(m_repeat): + mi_val = arith.index(mi * 16) + curr_row_a_lds = row_a_lds + mi_val - for ni in range_constexpr(num_acc_n): - acc_idx = mi * num_acc_n + ni - gate_list[acc_idx] = mfma_k64( - gate_list[acc_idx], - a0, - a1, - b_gate_packs0[ni], - b_gate_packs1[ni], - ) - up_list[acc_idx] = mfma_k64( - up_list[acc_idx], - a0, - a1, - b_up_packs0[ni], - b_up_packs1[ni], - ) + if (a0_prefetch is not None) and (ku == 0) and (mi == 0): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) + + for ni in range_constexpr(num_acc_n): + acc_idx = mi * num_acc_n + ni + gate_list[acc_idx] = mfma_k64( + gate_list[acc_idx], + a0, + a1, + b_gate_packs0[ni], + b_gate_packs1[ni], + ) + up_list[acc_idx] = mfma_k64( + up_list[acc_idx], + a0, + a1, + b_up_packs0[ni], + b_up_packs1[ni], + ) return gate_list, up_list, epilogue_pf # ---------------- 2-stage pipeline (ping-pong LDS + B tile prefetch) ---------------- @@ -972,10 +1202,13 @@ def hot_loop_scheduler(): bx_m0 = bx_m tokens_i32_v = tokens_i32 topk_i32_v = topk_i32 - inter_i32_v = fx.Int32(inter_dim) - mask24_i32 = fx.Int32(0xFFFFFF) - - if epilogue_pf is not None: + inter_i32_v = arith.constant(inter_dim, type=T.i32) + mask24_i32 = arith.constant(0xFFFFFF, type=T.i32) + + if use_groupwise_scale: + sw_gate_vals = [arith.constant(1.0, type=T.f32)] * num_acc_n + sw_up_vals = [arith.constant(1.0, type=T.f32)] * num_acc_n + elif epilogue_pf is not None: sw_gate_vals, sw_up_vals = epilogue_pf else: sw_gate_vals = [] @@ -1068,7 +1301,7 @@ def write_row_to_lds( acc_up[acc_idx], static_position=[ii], dynamic_position=[] ) - if is_int8: + if is_int8 and not is_int4_groupwise: vg = arith.sitofp(T.f32, vg) vu = arith.sitofp(T.f32, vu) vg = vg * sx * sw_gate @@ -1186,7 +1419,7 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): acc_up[acc_idx], static_position=[ii], dynamic_position=[] ) - if is_int8: + if is_int8 and not is_int4_groupwise: vg = arith.sitofp(T.f32, vg) vu = arith.sitofp(T.f32, vu) vg = vg * sx * sw_gate @@ -1289,6 +1522,7 @@ def compile_moe_gemm2( - "int8": A2/W are int8 - "int4": W4A8 path: A2 is int8, W is packed int4 unpacked to int8 in-kernel - "int4_bf16": W4A16 path: A2 is bf16, W is packed int4 unpacked to bf16 in-kernel + - "int4_fp8": W4A_FP8 path: A2 is fp8, W is packed int4 unpacked to fp8 in-kernel Stage2 output supports: - out_dtype="f16": fp16 half2 atomics (fast, can overflow to +/-inf for bf16 workloads) @@ -1301,14 +1535,17 @@ def compile_moe_gemm2( allocator = SmemAllocator(None, arch=gpu_arch) _state = {} - if in_dtype not in ("fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16"): + _valid_dtypes = ("fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16", "int4_fp8") + if in_dtype not in _valid_dtypes: raise ValueError( - f"in_dtype must be one of ('fp8','fp16','bf16','int8','int8smooth','int4','int4_bf16'), got {in_dtype!r}" + f"in_dtype must be one of {_valid_dtypes}, got {in_dtype!r}" ) - is_int4_bf16 = in_dtype == "int4_bf16" + is_int4_bf16 = in_dtype == "int4_bf16" # W4A16: bf16 activations, packed int4 weights + is_int4_fp8 = in_dtype == "int4_fp8" # W4A_FP8: fp8 activations, packed int4 weights is_f16 = in_dtype == "fp16" is_bf16 = is_int4_bf16 or in_dtype == "bf16" is_f16_or_bf16 = is_f16 or is_bf16 + needs_scale_x = not is_f16_or_bf16 # True for fp8/int8/int4/int4_fp8 needs_scale_w = (not is_f16_or_bf16) or is_int4_bf16 elem_bytes = 2 if is_f16_or_bf16 else 1 out_s = str(out_dtype).strip().lower() @@ -1319,9 +1556,26 @@ def compile_moe_gemm2( if (not bool(accumulate)) and out_is_f32: raise ValueError("compile_moe_gemm2(accumulate=False) only supports out_dtype in {'f16','bf16'}") is_int4 = in_dtype == "int4" + # w_is_int4: True for any variant where weights are packed int4. + w_is_int4 = is_int4 or is_int4_bf16 or is_int4_fp8 # INT4 here means W4A8: A2 is int8, W is packed int4 and unpacked to int8 in-kernel. is_int8 = (in_dtype in ("int8", "int8smooth")) or is_int4 + # Group-wise scale support for W4A16, W4A8, and W4A_FP8 + use_groupwise_scale = w_is_int4 and group_size > 0 + if use_groupwise_scale and group_size != 32: + raise ValueError( + f"FlyDSL groupwise scale only supports group_size=32, got {group_size}. " + f"This is due to int4 preshuffle layout constraints. " + f"Please use Triton kernel for other group sizes." + ) + is_int4_groupwise = is_int4 and use_groupwise_scale + is_int4_fp8_groupwise = is_int4_fp8 and use_groupwise_scale + is_int4_bf16_groupwise = is_int4_bf16 and use_groupwise_scale + # Stage2 K dimension is inter_dim (weight shape: [E, model_dim, inter_dim]) + num_groups = inter_dim // group_size if use_groupwise_scale else 1 + scale_w_size_stage2 = experts * model_dim * num_groups + mfma_i32_k32 = None if is_int8: mfma_i32_k32 = getattr(rocdl, "mfma_i32_16x16x32i8", None) or getattr( @@ -1350,8 +1604,8 @@ def compile_moe_gemm2( size_sorted = DYN size_expert_ids_shape = DYN size_scale_x = DYN - # W is packed int4 for W4A8/W4A16: 2 values per byte. - size_w = (experts * model_dim * inter_dim) // 2 if (is_int4 or is_int4_bf16) else (experts * model_dim * inter_dim) + # W is packed int4 for W4A8/W4A16/W4A_FP8: 2 values per byte. + size_w = (experts * model_dim * inter_dim) // 2 if w_is_int4 else (experts * model_dim * inter_dim) total_threads = 256 tile_k_bytes = int(tile_k) * int(elem_bytes) @@ -1409,9 +1663,11 @@ def out_elem(): # IMPORTANT: module name participates in FlyDSL's compile cache key. # Dynamic-shape variant: safe to reuse across (tokens/sorted_size/size_expert_ids) at runtime. # Keep a distinct ABI tag so the compile cache never mixes with historical signatures. + _gs_tag = f"_g{group_size}" if use_groupwise_scale else "" module_name = ( f"mfma_moe2_{in_dtype}_{out_s}_{epilog_tag}" f"_t{tile_m}x{tile_n}x{tile_k}" + f"{_gs_tag}" f"_abi2" # mask sentinel token ids on loads/stores to avoid illegal address faults ).replace("-", "_") @@ -1466,8 +1722,8 @@ def moe_gemm2( tokens_i32_v = i32_tokens_in k_i32_v = i32_k_in x_elem = T.bf16 if is_bf16 else (T.f16 if is_f16 else (T.i8 if is_int8 else T.f8)) - # For int4/int4_bf16, weights are stored as packed bytes (i8) and unpacked in-kernel. - w_elem = T.i8 if (is_int4 or is_int4_bf16) else (T.bf16 if is_bf16 else (T.f16 if is_f16 else (T.i8 if is_int8 else T.f8))) + # For int4/int4_bf16/int4_fp8, weights are stored as packed bytes (i8) and unpacked in-kernel. + w_elem = T.i8 if w_is_int4 else (T.bf16 if is_bf16 else (T.f16 if is_f16 else (T.i8 if is_int8 else T.f8))) vec16_elems = 16 if elem_bytes == 1 else 8 vec8_elems = 8 if elem_bytes == 1 else 4 vec4_elems = 4 if elem_bytes == 1 else 2 @@ -1476,9 +1732,11 @@ def moe_gemm2( acc_init = ( arith.constant_vector(0, T.i32x4) - if is_int8 + if (is_int8 and not is_int4_groupwise) else arith.constant_vector(0.0, T.f32x4) ) + zero_i32_acc = arith.constant_vector(0, T.i32x4) if is_int4_groupwise else None + zero_f32_acc = arith.constant_vector(0.0, T.f32x4) if is_int4_fp8_groupwise else None # A2 layout (flatten token-slot -> M; use i32 for fly.make_shape). topk_idx = fx.Index(topk) @@ -1488,8 +1746,9 @@ def moe_gemm2( # B preshuffle layout: [experts*model_dim, inter_dim] c_n_total = arith.index(experts * model_dim) - kpack_bytes = 8 if (is_int4 or is_int4_bf16) else 16 - w_elem_bytes = 1 if (is_int4 or is_int4_bf16) else elem_bytes + # For packed int4 (W4A8/W4A16/W4A_FP8), kpack_bytes=8. + kpack_bytes = 8 if w_is_int4 else 16 + w_elem_bytes = 1 if w_is_int4 else elem_bytes b_layout = make_preshuffle_b_layout( arith, c_n=c_n_total, c_k=k_in, kpack_bytes=kpack_bytes, elem_bytes=w_elem_bytes ) @@ -1790,9 +2049,42 @@ def load_b_tile(base_k): Returns a list of length `k_unroll`, where each entry is a tuple: (packs_half0[ni], packs_half1[ni]) for the K64 micro-step. + For groupwise variants, each entry also includes per-group scales: + (packs0[ni], packs1[ni], scales0[ni], scales1[ni]) """ - if is_int4_bf16: - # W4A16: 2-phase load+unpack for VMEM latency hiding + if is_int4_bf16_groupwise: + # W4A16 groupwise: 2-phase load+unpack with per-group scale + raw_data = [] + for ku in range_constexpr(k_unroll): + raw_ku = [] + for ni in range_constexpr(num_acc_n): + packed32, scale_val = load_b_raw_w4a16_groupwise( + buffer_ops, arith, vector, + arg_b=arg_w, b_rsrc=w_rsrc, layout_b=layout_b, + base_k=base_k, ku=ku, + n_blk=n_blk_list[ni], n_intra=n_intra_list[ni], + lane_div_16=lane_div_16, elem_type=w_elem, + scale_rsrc=sw_rsrc, + expert_offset=expert_off_idx, + num_groups=num_groups, + group_size=group_size, + n_per_expert=model_dim, + kpack_bytes=kpack_bytes, + ) + raw_ku.append((packed32, scale_val)) + raw_data.append(raw_ku) + b_tile = [] + for ku in range_constexpr(k_unroll): + packs0, packs1 = [], [] + for ni in range_constexpr(num_acc_n): + packed32, scale_val = raw_data[ku][ni] + b0, b1 = unpack_b_w4a16_groupwise(packed32, scale_val, arith, vector) + packs0.append(b0) + packs1.append(b1) + b_tile.append((packs0, packs1)) + return b_tile + elif is_int4_bf16: + # W4A16 per-row: 2-phase load+unpack for VMEM latency hiding raw_data = [] for ku in range_constexpr(k_unroll): raw_ku = [] @@ -1817,19 +2109,109 @@ def load_b_tile(base_k): packs1.append(b1) b_tile.append((packs0, packs1)) return b_tile - b_tile = [] - for ku in range_constexpr(k_unroll): - packs0 = [] - packs1 = [] - for ni in range_constexpr(num_acc_n): - ki0 = (ku * 2) + 0 - ki1 = (ku * 2) + 1 - b0 = load_b_pack(base_k, ki0, ni) - b1 = load_b_pack(base_k, ki1, ni) - packs0.append(b0) - packs1.append(b1) - b_tile.append((packs0, packs1)) - return b_tile + elif is_int4_groupwise: + # W4A8 groupwise: 8B K64 weight load + 2 scale loads + unpack + b_tile = [] + for ku in range_constexpr(k_unroll): + packs0, packs1 = [], [] + scales0, scales1 = [], [] + for ni in range_constexpr(num_acc_n): + h0, h1, sc0, sc1 = load_b_raw_w4a8_groupwise_k64( + buffer_ops, arith, vector, + arg_b=arg_w, b_rsrc=w_rsrc, layout_b=layout_b, + base_k=base_k, ku=ku, + n_blk=n_blk_list[ni], n_intra=n_intra_list[ni], + lane_div_16=lane_div_16, elem_type=w_elem, + scale_rsrc=sw_rsrc, + expert_offset=expert_off_idx, + num_groups=num_groups, + group_size=group_size, + n_per_expert=model_dim, + kpack_bytes=kpack_bytes, + ) + packs0.append(unpack_b_w4a8(h0, arith, vector)) + packs1.append(unpack_b_w4a8(h1, arith, vector)) + scales0.append(sc0) + scales1.append(sc1) + b_tile.append((packs0, packs1, scales0, scales1)) + return b_tile + elif is_int4_fp8_groupwise: + # W4A_FP8 groupwise: 8B K64 weight load + 2 scale loads + unpack to fp8 + b_tile = [] + for ku in range_constexpr(k_unroll): + packs0, packs1 = [], [] + scales0, scales1 = [], [] + for ni in range_constexpr(num_acc_n): + h0, h1, sc0, sc1 = load_b_raw_w4a8_groupwise_k64( + buffer_ops, arith, vector, + arg_b=arg_w, b_rsrc=w_rsrc, layout_b=layout_b, + base_k=base_k, ku=ku, + n_blk=n_blk_list[ni], n_intra=n_intra_list[ni], + lane_div_16=lane_div_16, elem_type=w_elem, + scale_rsrc=sw_rsrc, + expert_offset=expert_off_idx, + num_groups=num_groups, + group_size=group_size, + n_per_expert=model_dim, + kpack_bytes=kpack_bytes, + ) + packs0.append(unpack_b_w4a_fp8(h0, arith, vector, rocdl)) + packs1.append(unpack_b_w4a_fp8(h1, arith, vector, rocdl)) + scales0.append(sc0) + scales1.append(sc1) + b_tile.append((packs0, packs1, scales0, scales1)) + return b_tile + elif is_int4_fp8: + # W4A_FP8 per-row: 8B K64 loads + unpack to fp8 + b_tile = [] + for ku in range_constexpr(k_unroll): + packs0, packs1 = [], [] + for ni in range_constexpr(num_acc_n): + h0, h1 = load_b_raw_w4a8_k64( + buffer_ops, arith, vector, + arg_b=arg_w, b_rsrc=w_rsrc, layout_b=layout_b, + base_k=base_k, ku=ku, + n_blk=n_blk_list[ni], n_intra=n_intra_list[ni], + lane_div_16=lane_div_16, elem_type=w_elem, + kpack_bytes=kpack_bytes, + ) + packs0.append(unpack_b_w4a_fp8(h0, arith, vector, rocdl)) + packs1.append(unpack_b_w4a_fp8(h1, arith, vector, rocdl)) + b_tile.append((packs0, packs1)) + return b_tile + elif is_int4: + # W4A8 per-row: 8B K64 loads + unpack + b_tile = [] + for ku in range_constexpr(k_unroll): + packs0, packs1 = [], [] + for ni in range_constexpr(num_acc_n): + h0, h1 = load_b_raw_w4a8_k64( + buffer_ops, arith, vector, + arg_b=arg_w, b_rsrc=w_rsrc, layout_b=layout_b, + base_k=base_k, ku=ku, + n_blk=n_blk_list[ni], n_intra=n_intra_list[ni], + lane_div_16=lane_div_16, elem_type=w_elem, + kpack_bytes=kpack_bytes, + ) + packs0.append(unpack_b_w4a8(h0, arith, vector)) + packs1.append(unpack_b_w4a8(h1, arith, vector)) + b_tile.append((packs0, packs1)) + return b_tile + else: + # fp8/int8/bf16/fp16: original code path + b_tile = [] + for ku in range_constexpr(k_unroll): + packs0 = [] + packs1 = [] + for ni in range_constexpr(num_acc_n): + ki0 = (ku * 2) + 0 + ki1 = (ku * 2) + 1 + b0 = load_b_pack(base_k, ki0, ni) + b1 = load_b_pack(base_k, ki1, ni) + packs0.append(b0) + packs1.append(b1) + b_tile.append((packs0, packs1)) + return b_tile # ---- Pipeline helpers: store X tile to LDS with ping-pong base ---- def store_x_tile_to_lds(vec_x_in_parts, lds_base): @@ -1910,7 +2292,7 @@ def compute_tile(acc_in, b_tile_in, lds_base, *, prefetch_epilogue: bool = False ) epilogue_pf = None - if prefetch_epilogue: + if prefetch_epilogue and not use_groupwise_scale: expert_off_pf = expert_off_idx sw_pf = [] for ni in range_constexpr(num_acc_n): @@ -1965,30 +2347,91 @@ def mfma_k64(acc0, a0, a1, b0, b1): return mfma_fn(mfma_res_ty, [a1v, b1v, acc1, 0, 0, 0]) acc1 = mfma_fn(mfma_res_ty, [a0, b0, acc0, 0, 0, 0]) return mfma_fn(mfma_res_ty, [a1, b1, acc1, 0, 0, 0]) + + def _acc_scaled_i32_to_f32(f32_acc_vec, i32_partial_vec, scale_val): + """i32 MFMA partial -> sitofp -> scale -> add to f32 accumulator.""" + new_vals = [] + for ii in range_constexpr(4): + vi = vector.extract(i32_partial_vec, static_position=[ii], dynamic_position=[]) + vf = arith.sitofp(T.f32, vi) * scale_val + old = vector.extract(f32_acc_vec, static_position=[ii], dynamic_position=[]) + new_vals.append(old + vf) + return vector.from_elements(T.f32x4, new_vals) + + def _acc_scaled_f32(f32_acc_vec, f32_partial_vec, scale_val): + """FP8 MFMA f32 partial -> scale -> add to f32 accumulator.""" + new_vals = [] + for ii in range_constexpr(4): + vi = vector.extract(f32_partial_vec, static_position=[ii], dynamic_position=[]) + vf = vi * scale_val + old = vector.extract(f32_acc_vec, static_position=[ii], dynamic_position=[]) + new_vals.append(old + vf) + return vector.from_elements(T.f32x4, new_vals) + + if is_int4_fp8_groupwise: + # W4A_FP8 groupwise: per-K32 FP8 MFMA with fresh f32 acc -> scale -> f32 running acc. + for ku in range_constexpr(k_unroll): + b_packs0, b_packs1, b_sc0, b_sc1 = b_tile_in[ku] + ki64 = arith.index(ku * 64) + col_base = col_offset_base_bytes + ki64 + for mi in range_constexpr(m_repeat): + mi_val = arith.index(mi * 16) + curr_row_a_lds = row_a_lds + mi_val + if (a0_prefetch is not None) and (ku == 0) and (mi == 0): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) + for ni in range_constexpr(num_acc_n): + acc_idx = mi * num_acc_n + ni + # Half 0: fresh acc -> MFMA -> scale -> add to running acc + p0 = rocdl.mfma_f32_16x16x32_fp8_fp8(T.f32x4, [a0, b_packs0[ni], zero_f32_acc, 0, 0, 0]) + acc_list[acc_idx] = _acc_scaled_f32(acc_list[acc_idx], p0, b_sc0[ni]) + # Half 1 + p1 = rocdl.mfma_f32_16x16x32_fp8_fp8(T.f32x4, [a1, b_packs1[ni], zero_f32_acc, 0, 0, 0]) + acc_list[acc_idx] = _acc_scaled_f32(acc_list[acc_idx], p1, b_sc1[ni]) + elif is_int4_groupwise: + # W4A8 groupwise: per-K32 i32 MFMA with fresh i32 acc -> sitofp -> scale -> f32 running acc. + for ku in range_constexpr(k_unroll): + b_packs0, b_packs1, b_sc0, b_sc1 = b_tile_in[ku] + ki64 = arith.index(ku * 64) + col_base = col_offset_base_bytes + ki64 + for mi in range_constexpr(m_repeat): + mi_val = arith.index(mi * 16) + curr_row_a_lds = row_a_lds + mi_val + if (a0_prefetch is not None) and (ku == 0) and (mi == 0): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) + for ni in range_constexpr(num_acc_n): + acc_idx = mi * num_acc_n + ni + p0 = mfma_i32_k32(T.i32x4, [a0, b_packs0[ni], zero_i32_acc, 0, 0, 0]) + acc_list[acc_idx] = _acc_scaled_i32_to_f32(acc_list[acc_idx], p0, b_sc0[ni]) + p1 = mfma_i32_k32(T.i32x4, [a1, b_packs1[ni], zero_i32_acc, 0, 0, 0]) + acc_list[acc_idx] = _acc_scaled_i32_to_f32(acc_list[acc_idx], p1, b_sc1[ni]) + else: + for ku in range_constexpr(k_unroll): + b_packs0, b_packs1 = b_tile_in[ku] + ki64 = arith.index(ku * 64) + col_base = col_offset_base_bytes + ki64 - for ku in range_constexpr(k_unroll): - b_packs0, b_packs1 = b_tile_in[ku] - ki64 = arith.index(ku * 64) - col_base = col_offset_base_bytes + ki64 - - for mi in range_constexpr(m_repeat): - mi_val = arith.index(mi * 16) - curr_row_a_lds = row_a_lds + mi_val - - if (a0_prefetch is not None) and (ku == 0) and (mi == 0): - a0, a1 = a0_prefetch - else: - a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) + for mi in range_constexpr(m_repeat): + mi_val = arith.index(mi * 16) + curr_row_a_lds = row_a_lds + mi_val - for ni in range_constexpr(num_acc_n): - acc_idx = mi * num_acc_n + ni - acc_list[acc_idx] = mfma_k64( - acc_list[acc_idx], - a0, - a1, - b_packs0[ni], - b_packs1[ni], - ) + if (a0_prefetch is not None) and (ku == 0) and (mi == 0): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) + + for ni in range_constexpr(num_acc_n): + acc_idx = mi * num_acc_n + ni + acc_list[acc_idx] = mfma_k64( + acc_list[acc_idx], + a0, + a1, + b_packs0[ni], + b_packs1[ni], + ) return acc_list, epilogue_pf # ---------------- 2-stage pipeline (ping-pong LDS + B tile prefetch) ---------------- @@ -2190,7 +2633,10 @@ def atomic_add_f16x2(val_f16x2, byte_off_i32): sw_pf, tw_pf = epilogue_pf # Weight scales for the N tile (col_g depends on lane/wave/by but not on (t,s)). - if sw_pf is not None: + if use_groupwise_scale: + # Groupwise: weight scale already applied per-group in K-loop. + sw_vals = [arith.constant(1.0, type=T.f32)] * num_acc_n + elif sw_pf is not None: sw_vals = sw_pf else: sw_vals = [] @@ -2257,7 +2703,7 @@ def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row): sw = sw_vals[ni] acc_idx = mi * num_acc_n + ni v = vector.extract(acc[acc_idx], static_position=[ii], dynamic_position=[]) - if is_int8: + if is_int8 and not is_int4_groupwise: v = arith.sitofp(T.f32, v) v = v * sx * sw if doweight_stage2: @@ -2286,7 +2732,7 @@ def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row): # stable path here.) out_base_idx = None if out_is_bf16: - out_base_idx = memref.extract_aligned_pointer_as_index(arg_out) + _out_raw = arg_out.value if hasattr(arg_out, "value") else arg_out; _out_ptr = _fly_dialect.extract_aligned_pointer_as_index(ir.Type.parse("!llvm.ptr"), _out_raw); out_base_idx = arith.index_cast(T.index, llvm.PtrToIntOp(T.i64, _out_ptr).result) def write_row_to_lds( *, @@ -2333,7 +2779,7 @@ def write_row_to_lds( sw = sw_vals[ni] acc_idx = mi * num_acc_n + ni v = vector.extract(acc[acc_idx], static_position=[ii], dynamic_position=[]) - if is_int8: + if is_int8 and not is_int4_groupwise: v = arith.sitofp(T.f32, v) v = v * sx * sw if doweight_stage2: @@ -2377,7 +2823,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): byte_off = idx_elem_even * c2_i32 byte_off_idx = arith.index_cast(T.index, byte_off) ptr_addr_idx = out_base_idx + byte_off_idx - out_ptr = buffer_ops.create_llvm_ptr(ptr_addr_idx, address_space=1) + out_ptr = _create_llvm_ptr(ptr_addr_idx, address_space=1) out_ptr_v = out_ptr._value if hasattr(out_ptr, "_value") else out_ptr frag_v = frag._value if hasattr(frag, "_value") else frag llvm.AtomicRMWOp( diff --git a/python/flydsl/expr/vector.py b/python/flydsl/expr/vector.py index c4af4c0e..00646757 100644 --- a/python/flydsl/expr/vector.py +++ b/python/flydsl/expr/vector.py @@ -93,10 +93,17 @@ def load_op(result_type, memref, indices, *, loc=None, ip=None): def bitcast(result_type, source, *, loc=None, ip=None): """Wrapper around `vector.BitCastOp(...).result`.""" from . import arith as _arith_ext + from .._mlir import ir as _ir + + unwrapped = _arith_ext.unwrap(source, loc=loc) + # MLIR 23+ requires vector input for vector.bitcast; auto-wrap scalars. + if not isinstance(unwrapped.type, _ir.VectorType): + vec1_ty = _ir.VectorType.get([1], unwrapped.type) + unwrapped = _vector.FromElementsOp(vec1_ty, [unwrapped], loc=loc, ip=ip).result return _vector.BitCastOp( result_type, - _arith_ext.unwrap(source, loc=loc), + unwrapped, loc=loc, ip=ip, ).result diff --git a/tests/kernels/test_moe_gemm.py b/tests/kernels/test_moe_gemm.py index f7edd1fb..26d0647c 100644 --- a/tests/kernels/test_moe_gemm.py +++ b/tests/kernels/test_moe_gemm.py @@ -383,21 +383,29 @@ def run_moe_stage1( blocks, ) = routing - if in_dtype not in ("fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16"): + _valid = ("fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16", "int4_fp8") + if in_dtype not in _valid: raise ValueError( - f"in_dtype must be one of ('fp8','fp16','bf16','int8','int8smooth','int4','int4_bf16'), got {in_dtype!r}" + f"in_dtype must be one of {_valid}, got {in_dtype!r}" ) is_int4 = in_dtype == "int4" is_int4_bf16 = in_dtype == "int4_bf16" # W4A16: bf16 activations, packed int4 weights + is_int4_fp8 = in_dtype == "int4_fp8" + w_is_int4 = is_int4 or is_int4_bf16 or is_int4_fp8 is_int8 = in_dtype in ("int8", "int8smooth", "int4") is_int8smooth = in_dtype == "int8smooth" # Quantize inputs / weights. - if in_dtype == "fp8": - x_q, scale_x = pertoken_quant(x_fp32, quant_dtype=DTYPE_FP8) # [tokens,K], [tokens,1] - w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=DTYPE_FP8) # [E,2*inter,K], [E,2*inter,1] + if in_dtype in ("fp8", "int4_fp8"): + if is_int4_fp8: + x_q, scale_x = pertoken_quant(x_fp32, quant_dtype=DTYPE_FP8) + w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=torch.int8, dtypeMax=7) + w2_q, _scale_w2_unused = pertoken_quant(w2_fp32, quant_dtype=torch.int8, dtypeMax=7) + else: + x_q, scale_x = pertoken_quant(x_fp32, quant_dtype=DTYPE_FP8) # [tokens,K], [tokens,1] + w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=DTYPE_FP8) # [E,2*inter,K], [E,2*inter,1] # w2 is not used by our kernel, but required by aiter stage1 API - w2_q, _scale_w2_unused = pertoken_quant(w2_fp32, quant_dtype=DTYPE_FP8) + w2_q, _scale_w2_unused = pertoken_quant(w2_fp32, quant_dtype=DTYPE_FP8) elif in_dtype == "fp16": x_q = x_fp32.to(torch.float16) w1_q = w1_fp32.to(torch.float16) @@ -448,8 +456,8 @@ def run_moe_stage1( w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=torch.int8, dtypeMax=7) w2_q, _scale_w2_unused = pertoken_quant(w2_fp32, quant_dtype=torch.int8, dtypeMax=7) - # --- Groupwise scale for W4A16 --- - use_groupwise_scale = is_int4_bf16 and group_size > 0 + # --- Groupwise scale for int4 variants --- + use_groupwise_scale = (is_int4_bf16 or is_int4 or is_int4_fp8) and group_size > 0 scale_w1_groups = None # [E, K//group_size, 2*inter_dim] for kernel (Opt 0 layout) if use_groupwise_scale: N_total = 2 * inter_dim @@ -482,9 +490,9 @@ def run_moe_stage1( if is_int8smooth else x_q.contiguous().view(tokens, model_dim) ) - # Pack weights for int4 variants (W4A8 and W4A16). - # Both use the same interleaved packing: [ (v4<<4)|v0, (v5<<4)|v1, (v6<<4)|v2, (v7<<4)|v3 ] - use_packed_int4 = is_int4 or is_int4_bf16 + # Pack weights for int4 variants (W4A8, W4A16, W4A_FP8). + # All use the same interleaved packing: [ (v4<<4)|v0, (v5<<4)|v1, (v6<<4)|v2, (v7<<4)|v3 ] + use_packed_int4 = w_is_int4 w_kernel = ( _pack_shuffled_int8_to_packed_int4_no_perm(w1_shuffled_flat) if use_packed_int4 else w1_shuffled_flat ).contiguous() @@ -579,8 +587,8 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): scale_w1_groups=scale_w1_groups, ) - rtol = 0.5 if (is_int4 or is_int4_bf16) else 0.25 - atol = 0.5 if (is_int4 or is_int4_bf16) else 0.25 + rtol = 0.5 if w_is_int4 else 0.25 + atol = 0.5 if w_is_int4 else 0.25 assert verify_output(out.to(torch.float32), ref, rtol=rtol, atol=atol) # Note: kernel launches full expert-block range; effective work is gated by num_valid_ids. @@ -819,20 +827,28 @@ def run_moe_stage2( # NOTE: routing uses `moe_sorting` output directly (no host trim/pad). Extra launched blocks # are gated by `num_valid_ids` inside the kernels. - if in_dtype not in ("fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16"): + _valid = ("fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16", "int4_fp8") + if in_dtype not in _valid: raise ValueError( - f"in_dtype must be one of ('fp8','fp16','bf16','int8','int8smooth','int4','int4_bf16'), got {in_dtype!r}" + f"in_dtype must be one of {_valid}, got {in_dtype!r}" ) is_int4 = in_dtype == "int4" is_int4_bf16 = in_dtype == "int4_bf16" # W4A16: bf16 activations, packed int4 weights + is_int4_fp8 = in_dtype == "int4_fp8" + w_is_int4 = is_int4 or is_int4_bf16 or is_int4_fp8 is_int8 = in_dtype in ("int8", "int8smooth", "int4") is_int8smooth = in_dtype == "int8smooth" # Quantize inputs / weights. - if in_dtype == "fp8": - x_q, scale_x = pertoken_quant(x_fp32, quant_dtype=DTYPE_FP8) - w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=DTYPE_FP8) - w2_q, scale_w2 = pertoken_quant(w2_fp32, quant_dtype=DTYPE_FP8) + if in_dtype in ("fp8", "int4_fp8"): + if is_int4_fp8: + x_q, scale_x = pertoken_quant(x_fp32, quant_dtype=DTYPE_FP8) + w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=torch.int8, dtypeMax=7) + w2_q, scale_w2 = pertoken_quant(w2_fp32, quant_dtype=torch.int8, dtypeMax=7) + else: + x_q, scale_x = pertoken_quant(x_fp32, quant_dtype=DTYPE_FP8) + w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=DTYPE_FP8) + w2_q, scale_w2 = pertoken_quant(w2_fp32, quant_dtype=DTYPE_FP8) elif in_dtype == "fp16": x_q = x_fp32.to(torch.float16) w1_q = w1_fp32.to(torch.float16) @@ -869,8 +885,8 @@ def run_moe_stage2( w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=torch.int8, dtypeMax=7) w2_q, scale_w2 = pertoken_quant(w2_fp32, quant_dtype=torch.int8, dtypeMax=7) - # --- Groupwise scale for W4A16 (stage 2) --- - use_groupwise_scale = is_int4_bf16 and group_size > 0 + # --- Groupwise scale for int4 variants (stage 2) --- + use_groupwise_scale = (is_int4_bf16 or is_int4 or is_int4_fp8) and group_size > 0 scale_w2_groups = None # [E, inter_dim//group_size, model_dim] Opt 0 layout if use_groupwise_scale: num_groups_s2 = inter_dim // group_size @@ -914,7 +930,7 @@ def run_moe_stage2( inter_dim=inter_dim, doweight_stage1=bool(doweight_stage1), ) # [tokens, topk, inter] fp32 - if in_dtype == "fp8": + if in_dtype in ("fp8", "int4_fp8"): a2_q, a2_scale = pertoken_quant(out1_ref, quant_dtype=DTYPE_FP8) elif in_dtype == "fp16": a2_q = out1_ref.to(torch.float16) @@ -937,9 +953,9 @@ def run_moe_stage2( w2_shuffled_flat = w2_shuffled.view(experts * model_dim, inter_dim) scale_w2_flat = None if scale_w2 is None else scale_w2.view(experts * model_dim, 1) - # For W4A8 and W4A16, pack preshuffled int8 weights into packed int4 bytes. - # Both use the same interleaved packing: [ (v4<<4)|v0, (v5<<4)|v1, (v6<<4)|v2, (v7<<4)|v3 ] - use_packed_int4 = is_int4 or is_int4_bf16 + # For int4 variants (W4A8, W4A16, W4A_FP8), pack preshuffled int8 weights into packed int4 bytes. + # All use the same interleaved packing: [ (v4<<4)|v0, (v5<<4)|v1, (v6<<4)|v2, (v7<<4)|v3 ] + use_packed_int4 = w_is_int4 w2_kernel = w2_shuffled_flat if use_packed_int4: w2_kernel = _pack_shuffled_int8_to_packed_int4_no_perm(w2_shuffled_flat) @@ -965,10 +981,12 @@ def run_moe_stage2( out_s = str(out_dtype).strip().lower() if out_s in ("f16", "fp16", "half"): out_torch_dtype = torch.float16 + elif out_s in ("bf16", "bfloat16"): + out_torch_dtype = torch.bfloat16 elif out_s in ("f32", "fp32", "float"): out_torch_dtype = torch.float32 else: - raise ValueError(f"out_dtype must be 'f16' or 'f32', got {out_dtype!r}") + raise ValueError(f"out_dtype must be 'f16', 'bf16', or 'f32', got {out_dtype!r}") out = torch.zeros((tokens, model_dim), device=device, dtype=out_torch_dtype) out_perf = torch.zeros_like(out) @@ -990,6 +1008,10 @@ def run_moe_stage2( is_reduce_exe = (getattr(exe, "mode", None) == MoeGemm2Mode.REDUCE) or bool(use_reduce) def launch(o, x, w, sx, sw, st, eids, sw_sorted): + # Atomic accumulation mode adds to output; must zero between launches + # to avoid multi-run overflow (especially for f16 atomics). + if not is_reduce_exe: + o.zero_() stream = torch.cuda.current_stream() valid_mask = None if is_reduce_exe and bool(use_valid_mask): @@ -1088,7 +1110,7 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): a2_elem_bytes = 2 if in_dtype in ("int4_bf16", "bf16", "fp16") else 1 # bf16/fp16 activations bytes_moved += tokens * topk * inter_dim * a2_elem_bytes # a2 (logical) bytes_moved += (experts * model_dim * inter_dim) // (2 if is_int4 else 1) # w2 (packed for int4) - bytes_moved += tokens * model_dim * (2 if out_torch_dtype == torch.float16 else 4) # out + bytes_moved += tokens * model_dim * (2 if out_torch_dtype in (torch.float16, torch.bfloat16) else 4) # out bytes_moved += tokens * topk * 4 # a2_scale f32 (logical) bytes_moved += experts * model_dim * 4 # w2_scale f32 (1D) bytes_moved += int(sorted_weights.numel()) * 4 @@ -1203,8 +1225,8 @@ def launch_ck(o, a2_, w1_, w2_, sorted_ids_, sorted_eids_, num_valid_, w2_scale_ pytest.param(333, 4096, 2048, 17, 9, 64, 128, 128, 256, 128, False, id="L", marks=pytest.mark.large_shape), ], ) -@pytest.mark.parametrize("in_dtype", ["fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16"]) -@pytest.mark.parametrize("out_dtype", ["f16", "f32"], ids=["out_f16", "out_f32"]) +@pytest.mark.parametrize("in_dtype", ["fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16", "int4_fp8"]) +@pytest.mark.parametrize("out_dtype", ["f16", "bf16", "f32"], ids=["out_f16", "out_bf16", "out_f32"]) @pytest.mark.parametrize("use_reduce", [False, True], ids=["atomic", "reduce"]) @pytest.mark.parametrize("use_valid_mask", [False, True], ids=["nomask", "mask"]) @pytest.mark.parametrize("test_graph", [ @@ -1249,8 +1271,8 @@ def test_moe_gemm_2stage( out_s = str(out_dtype).strip().lower() if bool(use_reduce) and out_s in ("f32", "fp32", "float"): pytest.skip("reduce mode does not support out_dtype='f32' (compile_moe_gemm2(accumulate=False) forbids it).") - if group_size > 0 and in_dtype != "int4_bf16": - pytest.skip("groupwise scale only applies to int4_bf16") + if group_size > 0 and in_dtype not in ("int4_bf16", "int4", "int4_fp8"): + pytest.skip("groupwise scale only applies to int4_bf16, int4, int4_fp8") device = torch.device("cuda") # torch.manual_seed(int(seed)) @@ -1323,7 +1345,7 @@ def test_moe_gemm_2stage( # a2_q = torch.ones_like(out1_fp16, dtype=torch.float32) / 5 # w2_fp32 = torch.ones_like(w2_fp32, dtype=torch.float32) / 10 a2_scale = None - elif in_dtype == "fp8": + elif in_dtype in ("fp8", "int4_fp8"): out1_fp32 = out1_fp16.to(torch.float32) a2_q, a2_scale = pertoken_quant(out1_fp32, quant_dtype=DTYPE_FP8) elif in_dtype == "fp16": diff --git a/tests/test_common.py b/tests/test_common.py index 2f27592e..66632bf5 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -179,7 +179,6 @@ def run_perftest( needTrace=False, **kwargs, ): - @perftest( num_iters=num_iters, num_warmup=num_warmup,