Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 271 additions & 0 deletions kernels/mfma_preshuffle_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,277 @@ def lds_load_pack_k32(
"lds_store_16b_xor16",
"make_preshuffle_b_layout",
"load_b_pack_k32",
"load_b_raw_w4a16",
"unpack_b_w4a16",
"load_b_raw_w4a16_groupwise",
"unpack_b_w4a16_groupwise",
"load_b_raw_w4a8_k64",
"load_b_raw_w4a8_groupwise_k64",
"unpack_b_w4a8",
"unpack_b_w4a_fp8",
"swizzle_xor16",
"tile_chunk_coord_i32",
]


# ---------------------------------------------------------------------------
# W4A16 groupwise load / unpack helpers
# ---------------------------------------------------------------------------

def load_b_raw_w4a16_groupwise(
    buffer_ops,
    arith,
    vector,
    *,
    arg_b,
    b_rsrc,
    layout_b,
    base_k,
    ku: int,
    n_blk,
    n_intra,
    lane_div_16,
    elem_type,
    scale_rsrc,
    expert_offset,
    num_groups: int,
    group_size: int,
    n_per_expert: int,
    kpack_bytes: int = 8,
):
    """Phase 1 of W4A16 groupwise B load: buffer_loads for weight + scale.

    Reuses :func:`load_b_raw_w4a16` for the weight load, then issues an
    additional ``buffer_load_dword`` for the per-group scale.

    Args:
        base_k: K offset (in elements) of the current tile.
        ku: K32 step index within the tile; each step covers 32 K elements.
        num_groups: number of scale groups along K (``G``).
        group_size: K elements per scale group.
        n_per_expert: stride (in scales) between consecutive groups of the
            scale buffer for one expert.
        expert_offset: N offset of the current expert; already folded into
            ``n_blk``/``n_intra`` (see index compensation below).

    Returns ``(packed32, scale_val)``.
    """
    # Weight bytes: delegate to the non-groupwise loader (same layout math).
    packed32 = load_b_raw_w4a16(
        buffer_ops, arith, vector,
        arg_b=arg_b, b_rsrc=b_rsrc, layout_b=layout_b,
        base_k=base_k, ku=ku,
        n_blk=n_blk, n_intra=n_intra,
        lane_div_16=lane_div_16, elem_type=elem_type,
        kpack_bytes=kpack_bytes,
    )

    i32 = ir.IntegerType.get_signless(32)
    f32 = ir.F32Type.get()

    # Absolute K position of this K32 step; its scale group is k_pos // group_size.
    c_ku_elems = arith.constant(ku * 32, index=True)
    k_pos_base = base_k + c_ku_elems
    c_group_size = arith.constant(group_size, index=True)
    group_idx = k_pos_base / c_group_size

    # Flatten (n_blk, n_intra) into the N index (16 lanes per N block).
    c16 = arith.constant(16, index=True)
    n_global = n_blk * c16 + n_intra
    c_gm1 = arith.constant(num_groups - 1, index=True)
    c_npe = arith.constant(n_per_expert, index=True)
    # n_global is the GLOBAL N index (includes expert offset), so use (G-1)
    # to compensate: expert_offset*(G-1) + (expert_offset + n_within) = expert_offset*G + n_within
    scale_idx = expert_offset * c_gm1 + n_global + group_idx * c_npe
    scale_idx_i32 = arith.index_cast(i32, scale_idx)
    # One f32 scale per (group, n) pair.
    scale_val = buffer_ops.buffer_load(scale_rsrc, scale_idx_i32, vec_width=1, dtype=f32)

    return (packed32, scale_val)


def unpack_b_w4a16_groupwise(packed32, scale_val, arith, vector):
    """Phase 2 of W4A16 groupwise: unpack + scale + convert to bf16.

    Thin forwarding wrapper: the per-group ``scale_val`` loaded in phase 1
    is threaded into :func:`unpack_b_w4a16`, which applies it during the
    int4 -> bf16 conversion.
    """
    unpacked = unpack_b_w4a16(packed32, arith, vector, scale_val=scale_val)
    return unpacked


# ---------------------------------------------------------------------------
# W4A8 load / unpack helpers (8B K64 loads)
# ---------------------------------------------------------------------------

def load_b_raw_w4a8_k64(
    buffer_ops,
    arith,
    vector,
    *,
    arg_b,
    b_rsrc,
    layout_b,
    base_k,
    ku: int,
    n_blk,
    n_intra,
    lane_div_16,
    elem_type,
    kpack_bytes: int = 8,
):
    """Phase 1 of W4A8 per-row B load: one 8-byte load covering a K64 step.

    A single ``buffer_load_dwordx2`` fetches both K32 halves (8 bytes =
    16 packed int4 values) at once.

    Returns ``(packed32_half0, packed32_half1)`` for :func:`unpack_b_w4a8`.

    Raises:
        ValueError: if ``kpack_bytes`` is anything other than 8.
    """
    if kpack_bytes != 8:
        raise ValueError(f"W4A8 requires kpack_bytes=8, got {kpack_bytes!r}")

    # K64 block coordinate: base_k is in elements, each ku step spans 64.
    c_zero = arith.constant(0, index=True)
    c_ku = arith.constant(ku, index=True)
    k_blk = base_k / arith.constant(64, index=True) + c_ku

    # Byte offset of the 8B pack within the preshuffled B layout.
    pack_coord = (n_blk, k_blk, lane_div_16, n_intra, c_zero)
    pack_off = crd2idx(pack_coord, layout_b)

    raw8 = _buffer_load_vec(
        buffer_ops, vector, b_rsrc, pack_off,
        elem_type=elem_type, vec_elems=8, elem_bytes=1, offset_in_bytes=True,
    )
    # Reinterpret the 8 bytes as two dwords and split into the K32 halves.
    dwords = vector.bitcast(T.vec(2, T.i32), raw8)
    lo_half = vector.extract(dwords, static_position=[0], dynamic_position=[])
    hi_half = vector.extract(dwords, static_position=[1], dynamic_position=[])
    return (lo_half, hi_half)


def load_b_raw_w4a8_groupwise_k64(
    buffer_ops,
    arith,
    vector,
    *,
    arg_b,
    b_rsrc,
    layout_b,
    base_k,
    ku: int,
    n_blk,
    n_intra,
    lane_div_16,
    elem_type,
    scale_rsrc,
    expert_offset,
    num_groups: int,
    group_size: int,
    n_per_expert: int,
    kpack_bytes: int = 8,
):
    """Phase 1 of W4A8 groupwise B load: 8B weight + two scale loads per K64.

    One ``buffer_load_dwordx2`` for both K32 weight halves, plus two
    ``buffer_load_dword`` for the per-group scales (each K32 half may
    belong to a different group).

    Args:
        base_k: K offset (in elements) of the current tile.
        ku: K64 step index; each step covers 64 K elements (two K32 halves).
        num_groups: number of scale groups along K (``G``).
        group_size: K elements per scale group.
        n_per_expert: stride (in scales) between consecutive groups.

    Returns ``(half0, half1, scale0, scale1)``.

    Raises:
        ValueError: if ``kpack_bytes`` is anything other than 8.
    """
    if kpack_bytes != 8:
        raise ValueError(f"W4A8 requires kpack_bytes=8, got {kpack_bytes!r}")

    i32 = ir.IntegerType.get_signless(32)
    f32 = ir.F32Type.get()

    # K64 block coordinate: base_k is in elements, each ku step spans 64.
    c64 = arith.constant(64, index=True)
    k0_base = base_k / c64
    k0 = k0_base + arith.constant(ku, index=True)
    # lane_div_16 selects which half of the 16-lane group this lane reads
    # from -- presumably lane_id // 16; TODO confirm against caller.
    k1 = lane_div_16

    coord_pack = (n_blk, k0, k1, n_intra, arith.constant(0, index=True))
    idx_pack = crd2idx(coord_pack, layout_b)

    # Single dwordx2 VMEM load: 8 bytes = 16 packed int4 = both K32 halves.
    b8 = _buffer_load_vec(
        buffer_ops, vector, b_rsrc, idx_pack,
        elem_type=elem_type, vec_elems=8, elem_bytes=1, offset_in_bytes=True,
    )
    b_i32x2 = vector.bitcast(T.vec(2, T.i32), b8)
    half0 = vector.extract(b_i32x2, static_position=[0], dynamic_position=[])
    half1 = vector.extract(b_i32x2, static_position=[1], dynamic_position=[])

    # Flatten (n_blk, n_intra) into the N index (16 lanes per N block).
    c16 = arith.constant(16, index=True)
    n_global = n_blk * c16 + n_intra
    c_group_size = arith.constant(group_size, index=True)
    c_gm1 = arith.constant(num_groups - 1, index=True)
    c_npe = arith.constant(n_per_expert, index=True)
    # n_global includes expert offset, so use (G-1) to compensate
    base_scale = expert_offset * c_gm1 + n_global

    # Scale for the first K32 half: K position base_k + ku*64.
    k_pos0 = base_k + arith.constant(ku * 2 * 32, index=True)
    group_idx0 = k_pos0 / c_group_size
    s0_i32 = arith.index_cast(i32, base_scale + group_idx0 * c_npe)
    scale0 = buffer_ops.buffer_load(scale_rsrc, s0_i32, vec_width=1, dtype=f32)

    # Scale for the second K32 half: 32 elements further along K (may land
    # in a different group than the first half when group_size == 32).
    k_pos1 = base_k + arith.constant((ku * 2 + 1) * 32, index=True)
    group_idx1 = k_pos1 / c_group_size
    s1_i32 = arith.index_cast(i32, base_scale + group_idx1 * c_npe)
    scale1 = buffer_ops.buffer_load(scale_rsrc, s1_i32, vec_width=1, dtype=f32)

    return (half0, half1, scale0, scale1)


def unpack_b_w4a8(packed32, arith, vector):
    """Phase 2 of W4A8 B load: expand one dword of packed int4 to int8 i64.

    Produces one i64 holding 8 signed int8 bytes for one MFMA K32 step.

    Sign-extension trick (per nibble, applied bytewise): bit 3 is the int4
    sign bit, and (b & 0x08) * 0x1E == 0xF0 when set, so OR-ing it over the
    masked low nibble yields the two's-complement int8 value in 7 ops total.
    """
    mask_sign = arith.constant(0x08080808, type=T.i32)
    mask_nib = arith.constant(0x0F0F0F0F, type=T.i32)
    ext_mul = arith.constant(0x1E, type=T.i32)
    four = arith.constant(4, type=T.i32)

    # Even-indexed int4 values live in the low nibble of each byte.
    even_bytes = (packed32 & mask_nib) | ((packed32 & mask_sign) * ext_mul)

    # Odd-indexed values: shift the high nibbles down, then same expansion.
    shifted = packed32 >> four
    odd_bytes = (shifted & mask_nib) | ((shifted & mask_sign) * ext_mul)

    # Pack the two dwords into a single i64 scalar (even half low, odd high).
    pair = vector.from_elements(T.vec(2, T.i32), [even_bytes, odd_bytes])
    as_i64 = vector.bitcast(T.vec(1, T.i64), pair)
    return vector.extract(as_i64, static_position=[0], dynamic_position=[])


def unpack_b_w4a_fp8(packed32, arith, vector, rocdl):
    """Unpack packed int4 (i32) to fp8 i64 for mfma_f32_16x16x32_fp8_fp8.

    Pipeline: int4 -> int8 (7-op unpack) -> f32 (byte extract + sitofp)
    -> fp8 (cvt_pk_fp8_f32) -> i64.
    """
    i32 = ir.IntegerType.get_signless(32)
    f32 = ir.F32Type.get()

    c_08080808 = arith.constant(0x08080808, type=i32)
    c_0f0f0f0f = arith.constant(0x0F0F0F0F, type=i32)
    c_1e = arith.constant(0x1E, type=i32)
    c_4_i32 = arith.constant(4, type=i32)

    # Same 7-op int4 -> int8 sign-extension as unpack_b_w4a8:
    # (byte & 0x08) * 0x1E == 0xF0 fills the high nibble when the int4
    # sign bit (bit 3) is set.
    s0 = (packed32 & c_08080808) * c_1e
    even = (packed32 & c_0f0f0f0f) | s0

    t = packed32 >> c_4_i32
    s1 = (t & c_08080808) * c_1e
    odd = (t & c_0f0f0f0f) | s1

    c_8 = arith.constant(8, type=i32)
    c_16 = arith.constant(16, type=i32)
    c_24 = arith.constant(24, type=i32)

    # Raw op access: arith's overloaded >> presumably lowers to an
    # unsigned/ambiguous shift, so use ShRSIOp directly for the arithmetic
    # (sign-preserving) right shifts needed below.
    from flydsl._mlir.dialects._arith_ops_gen import ShRSIOp as _ShRSIOp
    _uw = arith._to_raw
    _av = arith.ArithValue

    def _i32_int8x4_to_fp8x4(val):
        """Convert i32 containing 4 signed int8 bytes -> i32 containing 4 fp8 bytes."""
        def _sext_byte(src, shl_amount, shr_amount):
            # Classic (x << k) >> 24 trick: isolate one byte in the top
            # position, then arithmetic-shift down to sign-extend it.
            shifted = src << shl_amount
            shrsi_result = _ShRSIOp(_uw(shifted), _uw(shr_amount)).result
            return _uw(arith.sitofp(f32, _av(shrsi_result)))

        f0 = _sext_byte(val, c_24, c_24)   # byte 0 (lowest)
        f1 = _sext_byte(val, c_16, c_24)   # byte 1
        f2 = _sext_byte(val, c_8, c_24)    # byte 2
        # Byte 3 is already in the top position; shift-only sign extension.
        b3 = _ShRSIOp(_uw(val), _uw(c_24)).result
        f3 = _uw(arith.sitofp(f32, _av(b3)))

        # Two cvt_pk_fp8_f32 ops pack the four f32s into one fp8x4 dword:
        # word_sel=0 writes bytes 0-1, word_sel=1 writes bytes 2-3.
        zero = _uw(arith.constant(0, type=i32))
        pk = rocdl.cvt_pk_fp8_f32(src_a=f0, src_b=f1, old=zero, word_sel=0, res=i32)
        pk = rocdl.cvt_pk_fp8_f32(src_a=f2, src_b=f3, old=_uw(pk), word_sel=1, res=i32)
        return pk

    even_fp8 = _i32_int8x4_to_fp8x4(even)
    odd_fp8 = _i32_int8x4_to_fp8x4(odd)

    # Pack the two fp8x4 dwords into a single i64 scalar.
    v2 = vector.from_elements(T.vec(2, T.i32), [even_fp8, odd_fp8])
    v64 = vector.bitcast(T.vec(1, T.i64), v2)
    return vector.extract(v64, static_position=[0], dynamic_position=[])
6 changes: 4 additions & 2 deletions kernels/moe_blockscale_2stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

from flydsl._mlir import ir
from flydsl._mlir.dialects import llvm, scf, memref
from flydsl._mlir.dialects import fly as _fly_dialect
from kernels.kernels_common import _create_llvm_ptr
from flydsl.expr.typing import T

from kernels.mfma_preshuffle_pipeline import (
Expand Down Expand Up @@ -2115,7 +2117,7 @@ def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row):
# stable path here.)
out_base_idx = None
if out_is_bf16:
out_base_idx = memref.extract_aligned_pointer_as_index(arg_out)
_out_raw = arg_out.value if hasattr(arg_out, "value") else arg_out; _out_ptr = _fly_dialect.extract_aligned_pointer_as_index(ir.Type.parse("!llvm.ptr"), _out_raw); out_base_idx = arith.index_cast(T.index, llvm.PtrToIntOp(T.i64, _out_ptr).result)

def write_row_to_lds(
*,
Expand Down Expand Up @@ -2206,7 +2208,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag):
byte_off = idx_elem_even * c2_i32
byte_off_idx = arith.index_cast(T.index, byte_off)
ptr_addr_idx = out_base_idx + byte_off_idx
out_ptr = buffer_ops.create_llvm_ptr(ptr_addr_idx, address_space=1)
out_ptr = _create_llvm_ptr(ptr_addr_idx, address_space=1)
out_ptr_v = out_ptr._value if hasattr(out_ptr, "_value") else out_ptr
frag_v = frag._value if hasattr(frag, "_value") else frag
llvm.AtomicRMWOp(
Expand Down
Loading
Loading