Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 271 additions & 0 deletions kernels/mfma_preshuffle_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,277 @@ def lds_load_pack_k32(
"lds_store_16b_xor16",
"make_preshuffle_b_layout",
"load_b_pack_k32",
"load_b_raw_w4a16",
"unpack_b_w4a16",
"load_b_raw_w4a16_groupwise",
"unpack_b_w4a16_groupwise",
"load_b_raw_w4a8_k64",
"load_b_raw_w4a8_groupwise_k64",
"unpack_b_w4a8",
"unpack_b_w4a_fp8",
"swizzle_xor16",
"tile_chunk_coord_i32",
]


# ---------------------------------------------------------------------------
# W4A16 groupwise load / unpack helpers
# ---------------------------------------------------------------------------

def load_b_raw_w4a16_groupwise(
    buffer_ops,
    arith,
    vector,
    *,
    arg_b,
    b_rsrc,
    layout_b,
    base_k,
    ku: int,
    n_blk,
    n_intra,
    lane_div_16,
    elem_type,
    scale_rsrc,
    expert_offset,
    num_groups: int,
    group_size: int,
    n_per_expert: int,
    kpack_bytes: int = 8,
):
    """Phase 1 of W4A16 groupwise B load: buffer_loads for weight + scale.

    Delegates the weight fetch to :func:`load_b_raw_w4a16`, then issues one
    additional ``buffer_load_dword`` for the per-group dequant scale.

    Returns ``(packed32, scale_val)``.
    """
    # Weight dword: identical to the non-groupwise W4A16 path.
    packed32 = load_b_raw_w4a16(
        buffer_ops, arith, vector,
        arg_b=arg_b, b_rsrc=b_rsrc, layout_b=layout_b,
        base_k=base_k, ku=ku,
        n_blk=n_blk, n_intra=n_intra,
        lane_div_16=lane_div_16, elem_type=elem_type,
        kpack_bytes=kpack_bytes,
    )

    i32_ty = ir.IntegerType.get_signless(32)
    f32_ty = ir.F32Type.get()

    # K position of this K32 step (elements), mapped to its scale group.
    k_elems = base_k + arith.constant(ku * 32, index=True)
    grp = k_elems / arith.constant(group_size, index=True)

    # Global N column of this lane; n_blk counts 16-wide N blocks.
    col = n_blk * arith.constant(16, index=True) + n_intra

    # col is the GLOBAL N index (it already includes the expert offset),
    # so scale the explicit expert_offset term by (G-1) to compensate:
    # expert_offset*(G-1) + (expert_offset + n_within) = expert_offset*G + n_within
    idx = (
        expert_offset * arith.constant(num_groups - 1, index=True)
        + col
        + grp * arith.constant(n_per_expert, index=True)
    )
    scale_val = buffer_ops.buffer_load(
        scale_rsrc, arith.index_cast(i32_ty, idx), vec_width=1, dtype=f32_ty
    )

    return (packed32, scale_val)


def unpack_b_w4a16_groupwise(packed32, scale_val, arith, vector):
    """Phase 2 of W4A16 groupwise: unpack + scale + convert to bf16."""
    # Same unpack as the per-row path, with the per-group scale applied.
    unpacked = unpack_b_w4a16(packed32, arith, vector, scale_val=scale_val)
    return unpacked


# ---------------------------------------------------------------------------
# W4A8 load / unpack helpers (8B K64 loads)
# ---------------------------------------------------------------------------

def load_b_raw_w4a8_k64(
    buffer_ops,
    arith,
    vector,
    *,
    arg_b,
    b_rsrc,
    layout_b,
    base_k,
    ku: int,
    n_blk,
    n_intra,
    lane_div_16,
    elem_type,
    kpack_bytes: int = 8,
):
    """Phase 1 of W4A8 per-row B load: one 8-byte VMEM fetch per K64 step.

    A single ``buffer_load_dwordx2`` pulls both K32 halves at once.

    Returns ``(packed32_half0, packed32_half1)`` for :func:`unpack_b_w4a8`.

    Raises:
        ValueError: if ``kpack_bytes`` is not 8.
    """
    if kpack_bytes != 8:
        raise ValueError(f"W4A8 requires kpack_bytes=8, got {kpack_bytes!r}")

    # K64-granular step index: base_k is in elements, 64 elements per step.
    k_step = base_k / arith.constant(64, index=True) + arith.constant(ku, index=True)
    zero = arith.constant(0, index=True)
    byte_off = crd2idx((n_blk, k_step, lane_div_16, n_intra, zero), layout_b)

    raw8 = _buffer_load_vec(
        buffer_ops, vector, b_rsrc, byte_off,
        elem_type=elem_type, vec_elems=8, elem_bytes=1, offset_in_bytes=True,
    )
    pair = vector.bitcast(T.vec(2, T.i32), raw8)
    return (
        vector.extract(pair, static_position=[0], dynamic_position=[]),
        vector.extract(pair, static_position=[1], dynamic_position=[]),
    )


def load_b_raw_w4a8_groupwise_k64(
    buffer_ops,
    arith,
    vector,
    *,
    arg_b,
    b_rsrc,
    layout_b,
    base_k,
    ku: int,
    n_blk,
    n_intra,
    lane_div_16,
    elem_type,
    scale_rsrc,
    expert_offset,
    num_groups: int,
    group_size: int,
    n_per_expert: int,
    kpack_bytes: int = 8,
):
    """Phase 1 of W4A8 groupwise B load: 8B weight + two scale loads per K64.

    Reuses :func:`load_b_raw_w4a8_k64` for the single ``buffer_load_dwordx2``
    weight fetch (both K32 halves in one VMEM instruction), then issues two
    ``buffer_load_dword`` for the per-group scales — each K32 half may
    belong to a different group.

    Returns ``(half0, half1, scale0, scale1)``.

    Raises:
        ValueError: if ``kpack_bytes`` is not 8 (raised by the weight load).
    """
    # Weight halves: identical to the non-groupwise W4A8 path.  Delegating
    # (as the W4A16 groupwise loader does) avoids duplicating the load body
    # and also performs the kpack_bytes == 8 check.
    half0, half1 = load_b_raw_w4a8_k64(
        buffer_ops, arith, vector,
        arg_b=arg_b, b_rsrc=b_rsrc, layout_b=layout_b,
        base_k=base_k, ku=ku,
        n_blk=n_blk, n_intra=n_intra,
        lane_div_16=lane_div_16, elem_type=elem_type,
        kpack_bytes=kpack_bytes,
    )

    i32 = ir.IntegerType.get_signless(32)
    f32 = ir.F32Type.get()

    c16 = arith.constant(16, index=True)
    n_global = n_blk * c16 + n_intra
    c_group_size = arith.constant(group_size, index=True)
    c_npe = arith.constant(n_per_expert, index=True)
    # n_global includes the expert offset, so use (G-1) to compensate:
    # expert_offset*(G-1) + (expert_offset + n_within) = expert_offset*G + n_within
    c_gm1 = arith.constant(num_groups - 1, index=True)
    base_scale = expert_offset * c_gm1 + n_global

    # Scale for the first K32 half: elements [ku*64, ku*64 + 32).
    k_pos0 = base_k + arith.constant(ku * 2 * 32, index=True)
    group_idx0 = k_pos0 / c_group_size
    s0_i32 = arith.index_cast(i32, base_scale + group_idx0 * c_npe)
    scale0 = buffer_ops.buffer_load(scale_rsrc, s0_i32, vec_width=1, dtype=f32)

    # Scale for the second K32 half: elements [ku*64 + 32, ku*64 + 64).
    k_pos1 = base_k + arith.constant((ku * 2 + 1) * 32, index=True)
    group_idx1 = k_pos1 / c_group_size
    s1_i32 = arith.index_cast(i32, base_scale + group_idx1 * c_npe)
    scale1 = buffer_ops.buffer_load(scale_rsrc, s1_i32, vec_width=1, dtype=f32)

    return (half0, half1, scale0, scale1)


def unpack_b_w4a8(packed32, arith, vector):
    """Phase 2 of W4A8 B load: 7-op unpack from packed int4 to int8 i64.

    Converts one dword of packed signed int4 into one i64 holding 8 signed
    int8 bytes for a single MFMA K32 step.  Per byte lane the low nibble is
    kept as-is; when its sign bit (0x8) is set, ``(sign_bit * 0x1E)`` yields
    ``0xF0``, so OR-ing it in fills the high nibble — a branch-free
    int4 -> int8 sign extension.
    """
    sign_mask = arith.constant(0x08080808, type=T.i32)
    nibble_mask = arith.constant(0x0F0F0F0F, type=T.i32)
    ext_factor = arith.constant(0x1E, type=T.i32)
    four = arith.constant(4, type=T.i32)

    # Even nibbles -> four sign-extended int8 lanes (low dword).
    evens = (packed32 & nibble_mask) | ((packed32 & sign_mask) * ext_factor)

    # Odd nibbles -> four sign-extended int8 lanes (high dword).
    shifted = packed32 >> four
    odds = (shifted & nibble_mask) | ((shifted & sign_mask) * ext_factor)

    pair = vector.from_elements(T.vec(2, T.i32), [evens, odds])
    as_i64 = vector.bitcast(T.vec(1, T.i64), pair)
    return vector.extract(as_i64, static_position=[0], dynamic_position=[])


def unpack_b_w4a_fp8(packed32, arith, vector, rocdl):
    """Unpack packed int4 (i32) to fp8 i64 for mfma_f32_16x16x32_fp8_fp8.

    Pipeline: int4 -> int8 (7-op unpack) -> f32 (byte extract + sitofp)
    -> fp8 (cvt_pk_fp8_f32) -> i64.
    """
    i32 = ir.IntegerType.get_signless(32)
    f32 = ir.F32Type.get()

    # Byte-parallel int4 sign extension: for each byte lane,
    # (b & 0x08) * 0x1E == 0xF0 when the nibble's sign bit is set, so
    # (b & 0x0F) | that fills the high nibble -> signed int8, no branches.
    c_08080808 = arith.constant(0x08080808, type=i32)
    c_0f0f0f0f = arith.constant(0x0F0F0F0F, type=i32)
    c_1e = arith.constant(0x1E, type=i32)
    c_4_i32 = arith.constant(4, type=i32)

    # Even (low) nibbles of each byte -> 4 sign-extended int8 lanes.
    s0 = (packed32 & c_08080808) * c_1e
    even = (packed32 & c_0f0f0f0f) | s0

    # Odd (high) nibbles -> the other 4 int8 lanes.
    t = packed32 >> c_4_i32
    s1 = (t & c_08080808) * c_1e
    odd = (t & c_0f0f0f0f) | s1

    # Shift amounts for extracting individual bytes below.
    c_8 = arith.constant(8, type=i32)
    c_16 = arith.constant(16, type=i32)
    c_24 = arith.constant(24, type=i32)

    # The DSL's >> appears to lower elsewhere; use the raw arith ShRSIOp
    # (arithmetic shift right) directly to guarantee sign-preserving shifts.
    from flydsl._mlir.dialects._arith_ops_gen import ShRSIOp as _ShRSIOp
    _uw = arith._to_raw
    _av = arith.ArithValue

    def _i32_int8x4_to_fp8x4(val):
        """Convert i32 containing 4 signed int8 bytes -> i32 containing 4 fp8 bytes."""
        def _sext_byte(src, shl_amount, shr_amount):
            # Classic shl/ashr pair: move the target byte to bits [31:24],
            # then arithmetic-shift back down to sign-extend it to i32,
            # and convert the signed value to f32.
            shifted = src << shl_amount
            shrsi_result = _ShRSIOp(_uw(shifted), _uw(shr_amount)).result
            return _uw(arith.sitofp(f32, _av(shrsi_result)))

        f0 = _sext_byte(val, c_24, c_24)  # byte 0
        f1 = _sext_byte(val, c_16, c_24)  # byte 1
        f2 = _sext_byte(val, c_8, c_24)   # byte 2
        # Byte 3 needs no left shift — a single arithmetic shift suffices.
        b3 = _ShRSIOp(_uw(val), _uw(c_24)).result
        f3 = _uw(arith.sitofp(f32, _av(b3)))

        # Pack the four f32s into four fp8 bytes: word_sel=0 writes the low
        # half-word (bytes 0-1), word_sel=1 the high half-word (bytes 2-3),
        # threading the partial result through `old`.
        zero = _uw(arith.constant(0, type=i32))
        pk = rocdl.cvt_pk_fp8_f32(src_a=f0, src_b=f1, old=zero, word_sel=0, res=i32)
        pk = rocdl.cvt_pk_fp8_f32(src_a=f2, src_b=f3, old=_uw(pk), word_sel=1, res=i32)
        return pk

    even_fp8 = _i32_int8x4_to_fp8x4(even)
    odd_fp8 = _i32_int8x4_to_fp8x4(odd)

    # Reassemble the two fp8x4 dwords into a single i64 MFMA operand.
    v2 = vector.from_elements(T.vec(2, T.i32), [even_fp8, odd_fp8])
    v64 = vector.bitcast(T.vec(1, T.i64), v2)
    return vector.extract(v64, static_position=[0], dynamic_position=[])
Loading
Loading