Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
312 changes: 278 additions & 34 deletions kernels/mfma_preshuffle_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ def make_preshuffle_b_layout(

if elem_bytes not in (1, 2):
raise ValueError(f"elem_bytes must be 1 or 2, got {elem_bytes!r}")
c_k_bytes = c_k * arith.constant(int(elem_bytes), index=True)
c_k_bytes = c_k * fx.Index(int(elem_bytes))
c_k0 = c_k_bytes // c64
n0 = c_n // c16

c_kpack_elems = c_kpack if elem_bytes == 1 else (c_kpack // arith.constant(int(elem_bytes), index=True))
c_kpack_elems = c_kpack if elem_bytes == 1 else (c_kpack // fx.Index(int(elem_bytes)))

stride_nlane = c_kpack_elems
stride_klane = c16 * stride_nlane
Expand All @@ -98,6 +98,30 @@ def make_preshuffle_b_layout(
return PreshuffleBLayout(layout_b=layout_b, kpack_bytes=kpack_bytes)


def _unpack_int4_to_int8_pair(packed32):
    """Split a packed-int4 dword into two dwords of sign-extended int8 lanes.

    SIMD-within-a-register: each byte of ``packed32`` carries two 4-bit
    values.  The low nibbles become ``even`` and the high nibbles become
    ``odd``, each widened to a signed int8 in its byte lane.  Sign extension
    relies on the identity ``(b & 0x08) * 0x1E == 0xF0`` (and ``0`` when
    bit 3 is clear), applied to all four bytes at once.

    7-op bit manipulation shared by all int4 unpack paths (W4A8, W4A16,
    W4A_FP8).

    Returns:
        ``(even, odd)`` -- two i32 values, each holding four int8 lanes.
    """
    sign_bits = fx.Int32(0x08080808)
    low_nibbles = fx.Int32(0x0F0F0F0F)
    sign_fill = fx.Int32(0x1E)
    four = fx.Int32(4)

    def _widen(nibbles):
        # OR 0xF0 into any byte whose nibble has its sign bit (bit 3) set.
        return (nibbles & low_nibbles) | ((nibbles & sign_bits) * sign_fill)

    even = _widen(packed32)
    odd = _widen(packed32 >> four)
    return even, odd


def _pack_i32_pair_to_i64(lo, hi, vector):
    """Combine two i32 values into a single i64 via a vector bitcast.

    ``lo`` occupies the low dword and ``hi`` the high dword: a <2 x i32>
    vector is reinterpreted as <1 x i64> and its sole element extracted.
    """
    pair = vector.from_elements(T.vec(2, T.i32), [lo, hi])
    as_i64 = vector.bitcast(T.vec(1, T.i64), pair)
    return vector.extract(as_i64, static_position=[0], dynamic_position=[])


def _i8x4_in_i32_to_bf16x4_i64(val_i32, arith, vector, scale_val=None):
"""Convert one i32 (4 signed int8 bytes) to 4 bf16 packed as i64.

Expand Down Expand Up @@ -195,18 +219,7 @@ def unpack_b_w4a16(packed32, arith, vector, scale_val=None):
Takes raw packed32 from load_b_raw_w4a16 and produces (b0, b1) --
two i64 values each containing 4 bf16 for one MFMA.
"""
c_08080808 = fx.Int32(0x08080808)
c_0f0f0f0f = fx.Int32(0x0F0F0F0F)
c_1e = fx.Int32(0x1E)
c_4_i32 = fx.Int32(4)

s0 = (packed32 & c_08080808) * c_1e
even = (packed32 & c_0f0f0f0f) | s0

t = packed32 >> c_4_i32
s1 = (t & c_08080808) * c_1e
odd = (t & c_0f0f0f0f) | s1

even, odd = _unpack_int4_to_int8_pair(packed32)
b0 = _i8x4_in_i32_to_bf16x4_i64(even, arith, vector, scale_val=scale_val)
b1 = _i8x4_in_i32_to_bf16x4_i64(odd, arith, vector, scale_val=scale_val)
return (b0, b1)
Expand Down Expand Up @@ -242,12 +255,12 @@ def load_b_pack_k32(
raise ValueError(f"elem_bytes must be 1 or 2, got {elem_bytes!r}")

c64 = fx.Index(64)
base_k_bytes = base_k * arith.constant(int(elem_bytes), index=True)
base_k_bytes = base_k * fx.Index(int(elem_bytes))
k0_base = base_k_bytes // c64
k0 = k0_base + arith.constant(ki_step // 2, index=True)
k0 = k0_base + fx.Index(ki_step // 2)
k1 = lane_div_16
half_bytes = kpack_bytes // 2
k2_base = arith.constant((ki_step % 2) * half_bytes, index=True)
k2_base = fx.Index((ki_step % 2) * half_bytes)

coord_pack = (n_blk, k0, k1, n_intra, fx.Index(0))
idx_pack = crd2idx(coord_pack, layout_b)
Expand All @@ -263,22 +276,8 @@ def load_b_pack_k32(
static_position=[0],
dynamic_position=[],
)

c_08080808 = fx.Int32(0x08080808)
c_0f0f0f0f = fx.Int32(0x0F0F0F0F)
c_1e = fx.Int32(0x1E)
c_4_i32 = fx.Int32(4)

s0 = (packed32 & c_08080808) * c_1e
even = (packed32 & c_0f0f0f0f) | s0

t = packed32 >> c_4_i32
s1 = (t & c_08080808) * c_1e
odd = (t & c_0f0f0f0f) | s1

v2 = vector.from_elements(T.vec(2, T.i32), [even, odd])
v64 = vector.bitcast(T.vec(1, T.i64), v2)
return vector.extract(v64, static_position=[0], dynamic_position=[])
even, odd = _unpack_int4_to_int8_pair(packed32)
return _pack_i32_pair_to_i64(even, odd, vector)

vec_elems = kpack_bytes // int(elem_bytes)
b16 = _buffer_load_vec(
Expand Down Expand Up @@ -314,7 +313,7 @@ def tile_chunk_coord_i32(
"""Map (thread, chunk_id) -> (row_local, col_local_i32) for X/A loads."""
if chunk_i32 not in (1, 2, 4):
raise ValueError(f"chunk_i32 must be one of (1,2,4), got {chunk_i32!r}")
chunk_off_i32 = arith.constant(i * total_threads * chunk_i32, index=True)
chunk_off_i32 = fx.Index(i * total_threads * chunk_i32)
tile_idx_i32 = tx_i32_base + chunk_off_i32
coord_local = idx2crd(tile_idx_i32, layout_tile_div4)
row_local = layout_get(coord_local, 0)
Expand Down Expand Up @@ -466,6 +465,251 @@ def lds_load_pack_k32(
"lds_store_16b_xor16",
"make_preshuffle_b_layout",
"load_b_pack_k32",
"load_b_raw_w4a16",
"unpack_b_w4a16",
"load_b_raw_w4a16_groupwise",
"unpack_b_w4a16_groupwise",
"load_b_raw_w4a8_k64",
"load_b_raw_w4a8_groupwise_k64",
"unpack_b_w4a8",
"unpack_b_w4a_fp8",
"swizzle_xor16",
"tile_chunk_coord_i32",
]


# ---------------------------------------------------------------------------
# Groupwise scale load helper (shared by W4A16 and W4A8 groupwise paths)
# ---------------------------------------------------------------------------

def _load_groupwise_scale(
    buffer_ops,
    arith,
    *,
    scale_rsrc,
    expert_offset,
    n_blk,
    n_intra,
    k_pos,
    num_groups: int,
    group_size: int,
    n_per_expert: int,
):
    """Load one per-group scale value from the scale buffer.

    The linear scale index is derived from the expert offset, the N
    position, and the K-group that ``k_pos`` falls into.
    """
    # Global N index; n_blk is in units of 16 lanes.
    n_global = n_blk * fx.Index(16) + n_intra
    # n_global is the GLOBAL N index (includes expert offset), so use (G-1)
    # to compensate: expert_offset*(G-1) + (expert_offset + n_within)
    #             == expert_offset*G + n_within
    base = expert_offset * fx.Index(num_groups - 1) + n_global
    group_idx = k_pos // fx.Index(group_size)
    linear = base + group_idx * fx.Index(n_per_expert)
    return buffer_ops.buffer_load(
        scale_rsrc, arith.index_cast(T.i32, linear), vec_width=1, dtype=T.f32
    )


# ---------------------------------------------------------------------------
# W4A16 groupwise load / unpack helpers
# ---------------------------------------------------------------------------

def load_b_raw_w4a16_groupwise(
    buffer_ops,
    arith,
    vector,
    *,
    arg_b,
    b_rsrc,
    layout_b,
    base_k,
    ku: int,
    n_blk,
    n_intra,
    lane_div_16,
    elem_type,
    scale_rsrc,
    expert_offset,
    num_groups: int,
    group_size: int,
    n_per_expert: int,
    kpack_bytes: int = 8,
):
    """Phase 1 of W4A16 groupwise B load: buffer_loads for weight + scale.

    Delegates the weight load to :func:`load_b_raw_w4a16`, then adds one
    ``buffer_load_dword`` for the per-group scale at this K32 step.

    Returns ``(packed32, scale_val)``.
    """
    weight_dword = load_b_raw_w4a16(
        buffer_ops, arith, vector,
        arg_b=arg_b, b_rsrc=b_rsrc, layout_b=layout_b,
        base_k=base_k, ku=ku,
        n_blk=n_blk, n_intra=n_intra,
        lane_div_16=lane_div_16, elem_type=elem_type,
        kpack_bytes=kpack_bytes,
    )
    # Each ku step advances K by 32 elements.
    scale = _load_groupwise_scale(
        buffer_ops, arith,
        scale_rsrc=scale_rsrc, expert_offset=expert_offset,
        n_blk=n_blk, n_intra=n_intra,
        k_pos=base_k + fx.Index(ku * 32),
        num_groups=num_groups, group_size=group_size,
        n_per_expert=n_per_expert,
    )
    return (weight_dword, scale)


def unpack_b_w4a16_groupwise(packed32, scale_val, arith, vector):
    """Phase 2 of W4A16 groupwise: unpack int4, apply the per-group scale,
    and convert to bf16 -- a thin delegate to :func:`unpack_b_w4a16`.
    """
    return unpack_b_w4a16(packed32, arith, vector, scale_val=scale_val)


# ---------------------------------------------------------------------------
# W4A8 load / unpack helpers (8B K64 loads)
# ---------------------------------------------------------------------------

def load_b_raw_w4a8_k64(
    buffer_ops,
    arith,
    vector,
    *,
    arg_b,
    b_rsrc,
    layout_b,
    base_k,
    ku: int,
    n_blk,
    n_intra,
    lane_div_16,
    elem_type,
    kpack_bytes: int = 8,
):
    """Phase 1 of W4A8 per-row B load: 8-byte buffer_load_dwordx2 for one K64 step.

    Loads both K32 halves in a single VMEM instruction (``buffer_load_dwordx2``).
    Returns ``(packed32_half0, packed32_half1)`` for :func:`unpack_b_w4a8`.
    """
    if kpack_bytes != 8:
        raise ValueError(f"W4A8 requires kpack_bytes=8, got {kpack_bytes!r}")

    # One ku step covers a full K64, so k0 advances by one per step.
    k0 = base_k // fx.Index(64) + fx.Index(ku)
    coord = (n_blk, k0, lane_div_16, n_intra, fx.Index(0))
    byte_idx = crd2idx(coord, layout_b)

    raw8 = _buffer_load_vec(
        buffer_ops, vector, b_rsrc, byte_idx,
        elem_type=elem_type, vec_elems=8, elem_bytes=1, offset_in_bytes=True,
    )
    dwords = vector.bitcast(T.vec(2, T.i32), raw8)
    halves = [
        vector.extract(dwords, static_position=[i], dynamic_position=[])
        for i in (0, 1)
    ]
    return (halves[0], halves[1])


def load_b_raw_w4a8_groupwise_k64(
    buffer_ops,
    arith,
    vector,
    *,
    arg_b,
    b_rsrc,
    layout_b,
    base_k,
    ku: int,
    n_blk,
    n_intra,
    lane_div_16,
    elem_type,
    scale_rsrc,
    expert_offset,
    num_groups: int,
    group_size: int,
    n_per_expert: int,
    kpack_bytes: int = 8,
):
    """Phase 1 of W4A8 groupwise B load: 8B weight + two scale loads per K64.

    Delegates the weight load to :func:`load_b_raw_w4a8_k64`, then issues
    two ``buffer_load_dword`` for per-group scales (each K32 half may fall
    in a different group).

    Returns ``(half0, half1, scale0, scale1)``.
    """
    half0, half1 = load_b_raw_w4a8_k64(
        buffer_ops, arith, vector,
        arg_b=arg_b, b_rsrc=b_rsrc, layout_b=layout_b,
        base_k=base_k, ku=ku,
        n_blk=n_blk, n_intra=n_intra,
        lane_div_16=lane_div_16, elem_type=elem_type,
        kpack_bytes=kpack_bytes,
    )

    def _scale_for_k32(step: int):
        # step counts K32 chunks from base_k; each chunk is 32 elements.
        return _load_groupwise_scale(
            buffer_ops, arith,
            scale_rsrc=scale_rsrc, expert_offset=expert_offset,
            n_blk=n_blk, n_intra=n_intra,
            k_pos=base_k + fx.Index(step * 32),
            num_groups=num_groups, group_size=group_size,
            n_per_expert=n_per_expert,
        )

    scale0 = _scale_for_k32(ku * 2)
    scale1 = _scale_for_k32(ku * 2 + 1)
    return (half0, half1, scale0, scale1)


def unpack_b_w4a8(packed32, arith, vector):
    """Phase 2 of W4A8 B load: 7-op unpack from packed int4 to int8 i64.

    Takes a raw ``packed32`` (one dword of packed int4) and produces one
    i64 holding 8 signed int8 bytes for one MFMA K32 step.  (``arith`` is
    unused here but kept for signature parity with the other unpackers.)
    """
    lo_dword, hi_dword = _unpack_int4_to_int8_pair(packed32)
    return _pack_i32_pair_to_i64(lo_dword, hi_dword, vector)


def unpack_b_w4a_fp8(packed32, arith, vector, rocdl):
    """Unpack packed int4 (i32) to fp8 i64 for mfma_f32_16x16x32_fp8_fp8.

    Pipeline: int4 -> int8 (7-op unpack) -> f32 (byte extract + sitofp)
    -> fp8 (cvt_pk_fp8_f32) -> i64.
    """
    # Two i32 dwords, each holding 4 sign-extended int8 lanes.
    even, odd = _unpack_int4_to_int8_pair(packed32)

    # Byte-lane shift amounts used to isolate each int8 within an i32.
    c_8 = fx.Int32(8)
    c_16 = fx.Int32(16)
    c_24 = fx.Int32(24)

    # Raw arithmetic-shift-right op: needed because the wrapped `arith`
    # facade does not expose a signed right shift directly.
    from flydsl._mlir.dialects._arith_ops_gen import ShRSIOp as _ShRSIOp
    _uw = arith._to_raw
    _av = arith.ArithValue

    def _i32_int8x4_to_fp8x4(val):
        """Convert i32 containing 4 signed int8 bytes -> i32 containing 4 fp8 bytes."""
        def _sext_byte(src, shl_amount, shr_amount):
            # Classic (x << k) >> 24 trick: left-align the target byte,
            # then arithmetic-shift right to sign-extend it to full i32,
            # and finally convert to f32.
            shifted = src << shl_amount
            shrsi_result = _ShRSIOp(_uw(shifted), _uw(shr_amount)).result
            return _uw(arith.sitofp(T.f32, _av(shrsi_result)))

        f0 = _sext_byte(val, c_24, c_24)  # byte 0 (lowest)
        f1 = _sext_byte(val, c_16, c_24)  # byte 1
        f2 = _sext_byte(val, c_8, c_24)   # byte 2
        # Byte 3 is already top-aligned: a single arithmetic shift suffices.
        b3 = _ShRSIOp(_uw(val), _uw(c_24)).result
        f3 = _uw(arith.sitofp(T.f32, _av(b3)))

        # Pack four f32 into four fp8 bytes: word_sel=0 fills the low half,
        # word_sel=1 fills the high half on top of the previous result.
        zero = _uw(fx.Int32(0))
        pk = rocdl.cvt_pk_fp8_f32(src_a=f0, src_b=f1, old=zero, word_sel=0, res=T.i32)
        pk = rocdl.cvt_pk_fp8_f32(src_a=f2, src_b=f3, old=_uw(pk), word_sel=1, res=T.i32)
        return pk

    even_fp8 = _i32_int8x4_to_fp8x4(even)
    odd_fp8 = _i32_int8x4_to_fp8x4(odd)
    return _pack_i32_pair_to_i64(even_fp8, odd_fp8, vector)
6 changes: 4 additions & 2 deletions kernels/moe_blockscale_2stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

from flydsl._mlir import ir
from flydsl._mlir.dialects import llvm, scf, memref
from flydsl._mlir.dialects import fly as _fly_dialect
from kernels.kernels_common import _create_llvm_ptr
from flydsl.expr.typing import T

from kernels.mfma_preshuffle_pipeline import (
Expand Down Expand Up @@ -2115,7 +2117,7 @@ def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row):
# stable path here.)
out_base_idx = None
if out_is_bf16:
out_base_idx = memref.extract_aligned_pointer_as_index(arg_out)
_out_raw = arg_out.value if hasattr(arg_out, "value") else arg_out; _out_ptr = _fly_dialect.extract_aligned_pointer_as_index(ir.Type.parse("!llvm.ptr"), _out_raw); out_base_idx = arith.index_cast(T.index, llvm.PtrToIntOp(T.i64, _out_ptr).result)

def write_row_to_lds(
*,
Expand Down Expand Up @@ -2206,7 +2208,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag):
byte_off = idx_elem_even * c2_i32
byte_off_idx = arith.index_cast(T.index, byte_off)
ptr_addr_idx = out_base_idx + byte_off_idx
out_ptr = buffer_ops.create_llvm_ptr(ptr_addr_idx, address_space=1)
out_ptr = _create_llvm_ptr(ptr_addr_idx, address_space=1)
out_ptr_v = out_ptr._value if hasattr(out_ptr, "_value") else out_ptr
frag_v = frag._value if hasattr(frag, "_value") else frag
llvm.AtomicRMWOp(
Expand Down
Loading
Loading