
[CK] Add new FMHA batch prefill kernel with FP8 per-tensor and per-block KV quantization on gfx950#6054

Open
poyenc wants to merge 63 commits into develop from
users/poyenc/ck/batch-prefill-v3

Conversation

@poyenc
Contributor

@poyenc poyenc commented Mar 31, 2026

Summary

V3 FMHA pipeline for batch prefill with paged KV cache on gfx950 (FP8 only). Uses the same 8-warp, 256x64 tile, double-buffered LDS architecture as the contiguous fmha_fwd V3 kernel, extended with scatter-gather page table support.

Supported configurations:

  • FP8 per-tensor and KV block-scale quantization
  • SGLang (1D) and vLLM (2D) page tables
  • LINEAR KV layout only (VECTORIZED falls back to V2 via trait matching)
  • No mask / causal mask
  • Page sizes: 1, 16, 1024
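The two page-table formats differ only in indexing. As a rough sketch (names like `page_table_2d`, `kv_indices`, `kv_indptr` are illustrative, not the actual Kargs fields), a logical token position maps to a physical page like this:

```python
# Hypothetical sketch of the two page-table layouts the kernel accepts.
# Names (page_table_2d, kv_indices, kv_indptr) are illustrative only.

def physical_page_vllm(page_table_2d, batch, token_pos, page_size):
    """vLLM-style 2D table: one row of page ids per sequence."""
    return page_table_2d[batch][token_pos // page_size]

def physical_page_sglang(kv_indices, kv_indptr, batch, token_pos, page_size):
    """SGLang-style 1D table: rows flattened, indexed via per-batch offsets."""
    return kv_indices[kv_indptr[batch] + token_pos // page_size]

if __name__ == "__main__":
    page_size = 16
    table_2d = [[7, 3, 9], [5, 1, 4]]        # 2 sequences, 3 pages each
    indices = [7, 3, 9, 5, 1, 4]             # same data, flattened
    indptr = [0, 3]                          # row starts per sequence
    for b in range(2):
        for pos in (0, 17, 40):
            assert physical_page_vllm(table_2d, b, pos, page_size) == \
                   physical_page_sglang(indices, indptr, b, pos, page_size)
```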

Performance vs V2 (FP8 batch prefill, paged KV, stock ROCm 7.1.1, avg runs 2–6, MI355X):

Per-tensor (ps=1, sglang):

| Problem | V2 TFlops | V3 TFlops | V3/V2 |
|---|---:|---:|---:|
| h=6/1 sq=1k c | 43.1 | 47.7 | 1.11x |
| h=6/1 sq=2k c | 95.4 | 120.1 | 1.26x |
| h=6/1 sq=4k c | 197.5 | 259.5 | 1.31x |
| h=6/1 sq=8k c | 396.3 | 549.4 | 1.39x |
| h=6/1 sq=16k c | 778.7 | 1066.1 | 1.37x |
| h=6/1 sq=32k c | 1000.2 | 1282.3 | 1.28x |
| h=6/1 sq=65k c | 1045.0 | 1375.5 | 1.32x |
| h=6/1 sq=131k c | 1072.9 | 1390.2 | 1.30x |
| h=16/1 sq=65k c | 1072.3 | 1381.8 | 1.29x |
| h=40/40 sq=37k nc | 930.1 | 1304.3 | 1.40x |

KV block-scale (ps=1024, sglang):

| Problem | V2 TFlops | V3 TFlops | V3/V2 |
|---|---:|---:|---:|
| h=6/1 sq=1k c | 42.6 | 45.2 | 1.06x |
| h=6/1 sq=2k c | 95.2 | 111.6 | 1.17x |
| h=6/1 sq=4k c | 197.8 | 241.1 | 1.22x |
| h=6/1 sq=8k c | 389.9 | 509.7 | 1.31x |
| h=6/1 sq=16k c | 745.6 | 1013.1 | 1.36x |
| h=6/1 sq=32k c | 1016.1 | 1205.4 | 1.19x |
| h=6/1 sq=65k c | 1057.0 | 1313.5 | 1.24x |
| h=6/1 sq=131k c | 1069.6 | 1330.3 | 1.24x |
| h=16/1 sq=65k c | 1084.4 | 1324.2 | 1.22x |
| h=40/40 sq=37k nc | 945.0 | 1251.6 | 1.32x |

Key design decisions

Branchless page_id clamping — The scatter-gather path bypasses the buffer descriptor's NUM_RECORDS protection, so auto-advancing past seqlen_k causes XNACK faults. Solved with a branchless min(page_id, max_page_table_idx) clamp rather than branch guards, since the guard branches would fragment the sched_group_barrier-scheduled basic blocks.

Page advance issue/consume split — Page table lookups (global_load_dword) are issued BEFORE buffer_loads so they sit oldest in the vmcnt FIFO. At consume, s_waitcnt(N) drains only the page lookup while keeping N buffer_loads in flight. sched_barrier(0) prevents reordering.
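The FIFO argument can be illustrated with a toy model (conceptual only, hardware behavior simplified): memory operations retire in issue order, and `s_waitcnt(n)` blocks until at most `n` remain outstanding, so issuing the page lookup first means it is the only op drained by `s_waitcnt(N)`.

```python
from collections import deque

# Toy model of the vmcnt FIFO (conceptual; real hardware details omitted).
# Ops retire in issue order; s_waitcnt(n) blocks until at most n remain
# outstanding, so whatever was issued first drains first.

class VmcntFifo:
    def __init__(self):
        self.outstanding = deque()

    def issue(self, op):
        self.outstanding.append(op)

    def s_waitcnt(self, n):
        drained = []
        while len(self.outstanding) > n:
            drained.append(self.outstanding.popleft())
        return drained

fifo = VmcntFifo()
fifo.issue("global_load_dword page_id")   # page lookup issued first (oldest)
for i in range(4):
    fifo.issue(f"buffer_load kv[{i}]")    # N = 4 tile loads follow

# s_waitcnt(4) drains only the oldest op: the page lookup.
assert fifo.s_waitcnt(4) == ["global_load_dword page_id"]
assert len(fifo.outstanding) == 4         # buffer_loads still in flight
```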

KV block-scale k_descale fold — k_descale is folded into the scalar row_max via the FP8 shift trick, eliminating a full-tile VALU pass. v_descale is merged into the softmax rescale factor (o_acc_scale), with the final correction applied in the epilogue.
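The v_descale fold can be checked numerically in a simplified scalar model (not the kernel's tile code; names are illustrative): keep the accumulator in "v_descale-scaled space", fold the v_descale_prev/v_descale_cur ratio into the existing softmax rescale factor, and apply the final v_descale once in the epilogue.

```python
import math

# Scalar model of the descale fold. Both functions compute the same
# softmax-weighted sum over blocks with per-block V descale factors.

def reference(scores, values, v_descales):
    """Plain online softmax: descale applied to V on every block."""
    m, l, o = -math.inf, 0.0, 0.0
    for s, v, d in zip(scores, values, v_descales):
        m_new = max(m, s)
        r = math.exp(m - m_new) if m != -math.inf else 0.0
        p = math.exp(s - m_new)
        o = o * r + p * (v * d)           # per-block descale pass
        l = l * r + p
        m = m_new
    return o / l

def folded(scores, values, v_descales):
    """o_acc kept in v_descale-scaled space; one correction in the epilogue."""
    m, l, o = -math.inf, 0.0, 0.0
    d_prev = v_descales[0]
    for s, v, d in zip(scores, values, v_descales):
        m_new = max(m, s)
        r = math.exp(m - m_new) if m != -math.inf else 0.0
        p = math.exp(s - m_new)
        o = o * (r * d_prev / d) + p * v  # descale ratio folded into rescale
        l = l * r + p
        m, d_prev = m_new, d
    return (o * d_prev) / l               # final v_descale in the epilogue

scores, values, descales = [1.0, 3.0, 2.0], [0.5, -1.5, 2.5], [0.25, 0.5, 0.125]
assert abs(reference(scores, values, descales)
           - folded(scores, values, descales)) < 1e-12
```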

No no-packed-fp32-ops kernel attribute — V3 does NOT use kernel_attr<true> (which adds target("no-packed-fp32-ops")) despite it giving a +2–4% benefit on per-tensor workloads. The attribute conflicts with explicit v_pk_mul_f32 inline asm used in the KV block-scale descale path (pk_mul_f32 helper): when the asm is inlined into the attributed kernel entry, the assembler rejects the instruction. Benchmarks show the attribute is neutral-to-positive for block-scale (removing it actually gives +1–3%), so we accept the 2–4% per-tensor trade-off to avoid the asm conflict. A future per-variant solution (attribute on for per-tensor, off for block-scale) could recover this, but the codegen currently uses a single F_kernel_attr per kernel.

New files

| File | Description |
|---|---|
| fmha_batch_prefill_v3_kernel.hpp | V3 kernel with paged KV DRAM views, page table Kargs, batch/group mode |
| block_fmha_batch_prefill_v3_pipeline.hpp | V3 pipeline with scatter-gather loads, KV block-scale support |
| block_fmha_batch_prefill_v3_pipeline_default_policy.hpp | FP8 LDS descriptors, SRD rebasing for large page sizes |
| block_fmha_fwd_v3_detail.hpp | Shared V3 helpers: CoreLoopSchedulingParams, VALU intrinsics, macros |

Codegen changes (fmha_batch_prefill.py)

Refactored to support per-architecture kernel generation (matching fmha_fwd.py patterns):

  • ArchTrait / factory hierarchy: gfx9 -> gfx950 with arch-specific tiles, pipelines, and compatibility rules
  • Separate fmha_batch_prefill_v2() / fmha_batch_prefill_v3() dispatch functions (V3 tried first, falls back to V2)
  • 80 V3 kernels generated for gfx950
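The try-V3-first dispatch reduces to a simple fallthrough. A hedged sketch (stub functions only; the real dispatch functions take kernel traits and a stream config, and the unsupported-config sentinel is -1 as stated above):

```python
# Illustrative stubs — not the generated API.

def fmha_batch_prefill_v3(traits):
    # Unsupported configs return -1 (no matching V3 trait).
    return 1.23 if traits.get("layout") == "LINEAR" else -1

def fmha_batch_prefill_v2(traits):
    return 4.56  # V2 covers everything V3 rejects

def fmha_batch_prefill(traits):
    t = fmha_batch_prefill_v3(traits)   # V3 tried first
    return t if t >= 0 else fmha_batch_prefill_v2(traits)

assert fmha_batch_prefill({"layout": "LINEAR"}) == 1.23
assert fmha_batch_prefill({"layout": "VECTORIZED"}) == 4.56
```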

Bug fix in fmha_fwd.py

Fixed missing return in check_hdim — bare False was a no-op, causing unnecessary bias/dropout kernels for hdim=(192,128).
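The bug class is worth spelling out, since it is easy to miss in review. A simplified reconstruction (the real check_hdim inspects more trait fields): a bare `False` is an expression statement whose value is discarded, so the filter never rejected anything.

```python
# Simplified reconstruction of the bug — not the actual check_hdim body.

def check_hdim_buggy(hdim, bias, dropout):
    if hdim == (192, 128) and (bias or dropout):
        False          # no-op: the value is evaluated and discarded
    return True

def check_hdim_fixed(hdim, bias, dropout):
    if hdim == (192, 128) and (bias or dropout):
        return False   # actually filters the kernel out
    return True

assert check_hdim_buggy((192, 128), True, False) is True   # bug: not filtered
assert check_hdim_fixed((192, 128), True, False) is False
assert check_hdim_fixed((128, 128), True, False) is True
```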

Dependencies

Based on #4437 (users/poyenc/fa-v3-fp8-pertensor). That PR must be merged first.

Test plan

  • Full test_batch_prefill.py suite

poyenc added 30 commits March 13, 2026 02:35
Add FP8BF16 per-tensor quantization path to the FMHA forward V3
pipeline on gfx950. This includes:

- FP8 32x32x32 warp gemm with C-transposed distribution
- FP8 warp gemm dispatcher entries
- V3 kernel support for per-tensor descale (q/k/v descale pointers)
- V3 pipeline FP8 data path with asm volatile for P conversion
- FP8 instruction scheduling optimization in CoreLoopScheduler
- Codegen: FP8BF16 V3 tile size (256x64x128) and pipeline variants
- Codegen: V3 dispatch condition extended for fp8bf16+pertensor
- LLVM scheduler TRANS mask for scheduling control
- Fix mask_info default initialization for no_mask case

Note: V3 dispatch is disabled by default pending further validation.
Remove debug macros (ENABLE_DEBUG_STMTS, DEBUG_STMTS, WARP_ID, LANE_ID),
debug lambdas (print_dist_tensor, print_lds, print_lds_1d), unused LDS
windows (s/p/o/m_lds_window), their helper methods (MakeSimpleLdsDesc,
MakeSimpleLdsDesc1D), and unused KPack variables in the policy file.

Assembly verified identical (sched_diff=0), 176/176 fp8 tests pass.
The P matrix (attention weights) lives entirely in registers (VGPRs)
via sp_compute_type, not in LDS. Remove the P buffer terms from
GetSmemSize() so it reports only the actual KV buffer usage.

Assembly-verified: before/after diff shows identical GPU and host code.
- Separate smem_k[2]/smem_v[2] pointers for explicit buffer control
- Use async_load_tile_raw / init_raw for raw async copies
- Remove dead P buffer from LDS size calculation
- Reformat operator() signatures for readability
Revert changes from debug commit that swapped NumWarps and LaneGroups
in MakeVDramTileDistribution(), MakeVLdsStoreBlockDescriptor(), and
MakeVLdsLoadBlockDescriptor(). These were unrelated to the 4-buffer
LDS architecture refactor.

Restores the original dimension ordering:
- MakeVDramTileDistribution: N1=LaneGroups, N2=NumWarps
- MakeVLdsStoreBlockDescriptor: shape (NumIssues, LaneGroups, NumWarps, ...)
- MakeVLdsLoadBlockDescriptor: merge sequence<0, 2, 1> for correct reorder

Testing: 176/176 FP8 MHA tests pass
Remove fine-grained can_dispatch_v3 runtime guard. Try V3 first when
enabled; unsupported configs return -1 and fall back to V2.
…ution

Remove duplicate plain using definitions of WarpGemmMfma_f32_32x32x32_fp8_fp8,
WarpGemmMfma_f32_32x32x32_bf8_bf8 that conflicted with the templated
#if gfx950/#else versions, and deduplicate the corresponding dispatcher entry.
…rning

Drop the epilogue shared-memory buffer and smem_ptr parameter that were
left over after prior refactoring, and silence the -Wunreachable-code
diagnostic in the V3/V2 dispatch fallback.
The bare `False` statement was a no-op, causing bias/dropout kernels
to be generated for (192, 128) hdim configurations instead of being
filtered out.
Add kernel_attr_for<ArchTag, Attrs...> to kernel_launch.hpp that
composes an architecture tag with kernel attributes. When no attributes
are provided, kernel_attr_for<ArchTag> is an identity alias for ArchTag
itself (is_same_v is true). With attributes, it creates a unique type
that inherits both the arch tag and attribute mixins.

The existing kattr_no_packed_fp32_ops_v SFINAE detection works
transparently through the inheritance chain.

Usage:
  kernel_attr_for<gfx950_t>                       -> gfx950_t
  kernel_attr_for<gfx950_t, kernel_attr<true>>    -> unique type
Refactor fmha_batch_prefill.py to match fmha_fwd.py patterns, preparing
for V3 pipeline integration:

- Add ArchTrait to FmhaFwdApiTrait and FmhaFwdKernel with arch
  preprocessor guards
- Refactor FmhaFwdApiPool to hierarchical OrderedDict[arch][dtype][hdim]
  with render() method
- Split API template into HEADER/FUNC_TEMPLATE/PER_ARCH/FOOTER
- Add ProblemContext, KernelContext, CompatibilityRule, is_compatible(),
  create_kernel() abstractions
- Extract inline filtering into CompatibilityRuleFactory and Product
- Add factory hierarchy with get_factories_for_targets()
- Add extensible _get_cpp_kernel_class_name(),
  _get_cpp_kargs_creator_func_name(),
  _get_cpp_pipeline_problem_name() methods to FmhaFwdKernel
- Filename includes arch suffix: {name}_{arch}.cpp

Tested: 14848 passed, 0 failed across 4 combinations
(stock/custom compiler x before/after refactoring).
Add batch prefill V3 pipeline and kernel with scatter-gather paged KV
support, simplified dispatch that relies on trait matching for fallback.

- V3 pipeline: 4-phase double warp group, async buffer loads, 4-buffer LDS
- V3 kernel: LINEAR layout only, SGLang + vLLM page tables
- Codegen: 80 V3 kernels (bf16/fp8, no/causal mask, page sizes 1/16/1024)
- Dispatch: try V3 first when enabled, fall back to V2 via trait matching
- Static asserts enforce V3 constraints (LINEAR, no bias/dropout/kv_blockscale)
…odegen

V3 batch prefill is only needed for fp8bf16. Remove the bf16/fp16 V3
tile (256x64, 8 warps) and pipeline (qr_async_trload_v3) entries from
KernelComponentFactoryGfx950. bf16/fp16 continue to use V2 (qr_async).
Skip page table lookup in K_mem_load/V_mem_load when the next sequence
position exceeds seqlen_k_end. Without this guard, the auto-advance
reads page indices from the padding region of kv_page_indices and
computes scatter-gather offsets that produce buffer_load addresses
mapping to unmapped GPU pages, causing XNACK faults on gfx950.

The contiguous V3 fwd pipeline doesn't have this issue because its
move_tile_window is simple pointer arithmetic protected by the buffer
descriptor's NUM_RECORDS field. The scatter-gather path computes
physical offsets from the page table, bypassing NUM_RECORDS protection.
Separate K/V page offset updates from K/V_mem_load into dedicated
K_page_advance/V_page_advance lambdas, called at the very end of each
phase after Scheduler::schedule. This keeps async_load_tile + ds_read
in one uninterrupted basic block, preventing the XNACK guard branch
from fragmenting the load+ds_read scheduling.

Recovers 6-13% of the 14-26% guard overhead (measured on FP8
batch_prefill sweep). The guard branch still exists but only fragments
the tail of each phase, not the critical load/ds_read interleaving.
Replace the branched XNACK guard (if/s_cbranch) with branchless
min(page_id, max_page_table_idx) inside load_physical_pages(). The
max_page_table_idx is computed as (seqlen_k - 1) / kPageBlockSize in
the kernel and threaded through to all load_physical_pages() call sites.

This eliminates the 14-26% guard overhead that was caused by:
- s_cbranch fragmenting sched_group_barrier-scheduled basic blocks
- serialized global_load_dword + s_waitcnt at conditional join points
- +14 VGPRs from extended live ranges across branch boundaries

V3 FP8 batch_prefill is now 8-23% faster than V2 (was 4-30% slower).
Paged KV overhead reduced from 20-47% to 7-17% vs contiguous varlen.
…overlap

Split K_page_advance/V_page_advance into issue/consume pairs so the
global_load_dword (page table lookup) is issued BEFORE the buffer_loads
from cl_load, placing it oldest in the vmcnt FIFO. At consume time,
s_waitcnt(N) drains only the oldest global_load while keeping the N
buffer_loads in flight.

sched_barrier(0) brackets prevent the compiler from reordering the
global_load_dword across the buffer_loads, which would undo the FIFO
ordering.

Applied to all 4 core loop load phases (WG0 phases 1/3, WG1 phases
0/2), the pre-stage, and WG0 pre-loop setup.

Correctness: 16480 passed, 5376 skipped, 0 failed (matches baseline)
Performance (avg of 5 runs, FP8 batch_prefill, paged KV):
  s=4k-8k: +8-9%, s=16k: +6%, s=32k+: +2-5%, MHA h=40/40: +5%
Add per-page FP8 K/V dequantization scale support to the V3 batch
prefill pipeline, matching the existing V2 implementation.

Kernel: add FmhaFwdKVBlockScaleKargs with nblock/nhead strides,
update Kargs type selection and MakeKargs for both batch/group mode.
scale_s uses q_descale only (k_descale deferred to pipeline).

Pipeline: FP8 shift trick in fmha_alu0 (subtract 8.0/7.0 from row
max to implicitly scale P), k_descale applied after GEMM0, v_descale
rescale trick around GEMM1 (divide before, multiply after). Double-
buffered saved_k/v_descale indexed by LDS buffer slot, saved before
each K_page_issue.

Codegen: add "kv_blockscale" to V3 pipeline generation. Existing
check_page_size filter enforces page_size >= kN0.

Performance tuning not yet done (extra element-wise passes for
v_descale rescale not merged into fmha_alu_D_upd).
…asses

Reduce KV_BLOCKSCALE overhead from 28% to 11% vs pertensor by:

1. Merge v_descale into o_acc_scale: maintain o_acc in v_descale-scaled
   space, folding v_descale_prev/v_descale_cur ratio into the existing
   softmax rescale factor. Eliminates 2 full-tile VALU passes (divide
   before GEMM1 + multiply after GEMM1). Final v_descale applied in
   epilogue normalization.

2. Replace scalar k_descale multiply with v_pk_mul_f32: halves
   instruction count for the remaining k_descale pass by operating on
   float2 pairs.

Both changes are guarded by if constexpr(KV_BLOCKSCALE) — pertensor
assembly is structurally identical before/after.
…elop

- Add block_fmha_batch_prefill_v3_pipeline_default_policy.hpp with
  FP8-specific LDS descriptors (separate from shared V3 fwd policy)
- Add SRD rebasing (rebase_k/v_window) to V3 pipeline for
  kPageBlockSize >= kN0, enabling page_size=1024 V3 dispatch
- Fix load_tile_transpose_with_offset to use develop's void API
- Remove codegen page_size >= kN0 restriction for V3 pipeline
Add BatchPrefillCoreLoopScheduler with FP8-tuned VALU budgets (6/6 per
MFMA half) matching the feature branch's CoreLoopScheduler tuning.
The develop branch had VALU:4/3 which undershoots actual VALU work
because v_pk_mul_f32 asm volatile is invisible to the compiler.

Relax fma_impl_vsv from asm volatile("v_fma_f32") to plain C++ FMA.
The asm volatile anchor prevented compiler reordering across
sched_barrier(0) boundaries, causing s_nop 7+3 stalls (22 extra NOP
cycles per phase2). Plain C++ gives the compiler freedom to fill MFMA
latency gaps.
Use kernel_attr_for<gfx950_t, kernel_attr<true>> for V3 batch prefill
kernels to apply target("no-packed-fp32-ops") to the kernel entry point.
This prevents the compiler from generating v_pk_mul_f32 for non-asm
FP32 operations, allowing separated v_mul_f32 to co-execute with MFMA.

V2 kernels use the plain arch tag (unchanged behavior).

Replaces the -mllvm --amdgpu-disable-packed-fp32=1 flag which was
silently skipped by the stock ROCm 7.1.1 compiler. The target attribute
is honored by both stock and custom compilers.

Measured +2-5% over stock compiler without the attribute.
…aders

Extract CoreLoopSchedulingParams, block_gemm_mfma_count_v, detail::
VALU helpers (fma_impl_vsv, add_impl_vv, mul_impl_vv, cvt_pk_*,
pk_mul_f32), and macros (CK_TILE_FMHA_V3_ASM_MARKER,
CK_TILE_FMHA_V3_ADD_SBARRIER_FOR_PHASE0) into a shared header
block_fmha_fwd_v3_detail.hpp. Both fmha_fwd V3 and batch_prefill V3
pipelines now include this header instead of batch_prefill including
the full fwd V3 pipeline header.

Also:
- Remove __gfx950__ guards inside V3 pipelines (keep only
  permlane32_swap path; entire pipeline is gfx950-only)
- Add top-level gfx950 guard: operator() returns empty output on
  non-gfx950 device
- Remove CK_TILE_DISABLE_PACKED_FP32 macro (always 0; V3 uses
  kernel_attr_for instead of -mllvm flag)
- Add CK_TILE_ prefix to all custom macros
- Add design comments for quant/LSC-unaware scheduler

Assembly output verified identical (only __hip_cuid differs).
V3 batch prefill generates only LINEAR KV layout kernels. VECTORIZED
layout requires sub-dword async loads that violate V3's buffer
addressing constraints, and KV layout optimization has not been done
for V3. VECTORIZED requests fall back to V2 via trait matching.
The attribute conflicts with explicit v_pk_mul_f32 inline asm in the
KV_BLOCKSCALE descale path. Benchmarks (6 sweeps, avg runs 2-6) show
removing it costs 2-4% on pertensor but is neutral-to-positive on
blockscale, making the trade-off acceptable vs the asm conflict.
@poyenc
Contributor Author

poyenc commented Apr 3, 2026

Still blocked by unrelated compilation errors.

poyenc added 12 commits April 3, 2026 09:58
Preserve extraction refactoring (v3_detail.hpp shared header) while
adapting post-review fixes from PR 6051 (fmha_fwd V3) into both the
fwd and batch_prefill V3 pipelines:

1. Move s_waitcnt/s_barrier outside `if(2 < num_total_loop)` guard in
   pre-stage to ensure K1+V0 async loads are drained before core_loop
   reads K1 from LDS (bug fix for num_total_loop==2 case).

2. Replace manual __builtin_amdgcn_permlane32_swap intrinsic calls with
   block_tile_reduce/block_tile_reduce_sync in fmha_alu0 (rowmax) and
   fmha_alu1 (rowsum), preserving kFoldKDescale logic in batch_prefill.

3. Split fmha_alu_D_upd into unpack/pack with interleaved scheduling.

4. Add CK_TILE_DISABLE_PACKED_FP32 guard on schedule_gemm1_compute().

5. Add fmha_alu_D_reg_cnt % 2 == 0 assertion.
Lines 174-175 duplicated the specializations already at lines 125-126,
causing redefinition errors. The duplicates were introduced by the merge
(both branches added the same entries independently).
The merge conflict resolution dropped the Dispatcher<fp8_t, fp8_t, float,
32, 32, 32, false> specialization, causing compilation errors when
cshuffle_epilogue instantiates WarpGemmDispatcher with fp8 types and
isCTransposed=false.
@poyenc
Contributor Author

poyenc commented Apr 18, 2026

The CI consistently fails to pull the CK image.

poyenc added 3 commits April 20, 2026 13:47
Resolve merge conflicts from StreamLLM sink token support (#6479)
landing on develop after the V3 batch prefill refactor. Integrate
F_sink parameter into the refactored multi-arch factory codegen.
The refactoring to CompatibilityRuleFactory dropped the original
`if mode != "group": continue` guard, allowing batch-mode kernel
instantiations that hit the static_assert in pipeline problem.
@poyenc poyenc force-pushed the users/poyenc/ck/batch-prefill-v3 branch from 51085d9 to 515cb8f on April 23, 2026 03:46
poyenc added 6 commits April 23, 2026 11:47
…ispatch

Add #pragma clang diagnostic push/pop around the v3 dispatch block in the
API footer template, matching the pattern already used in fmha_fwd.py.
Without this, gfx942 builds (which have no v3 kernels) emit if(false)
and fail with -Werror,-Wunreachable-code.
Return None from get_factory for targets without a factory instead of
raising an exception. Filter unsupported targets before dispatching to
get_factories_for_targets so that builds targeting e.g. gfx1101 produce
an empty kernel pool instead of crashing at configure time.
# Conflicts:
#	projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py
Fix -Werror,-Wunused-variable compilation error.
@poyenc poyenc requested a review from Jeff-Huang April 29, 2026 06:11
@Jeff-Huang
Contributor

Jeff-Huang commented Apr 29, 2026

To fix the OOB issue surfaced by AICK-1171, I based the V2 changes on your PR and opened a separate PR — #6932 — for the V2 load_physical_pages part only. Two small tweaks while extracting:

  • max_page_table_idx is mandatory (no INT32_MAX default), so every callsite has to pass the bound explicitly. With an optional default, any callsite that wasn't updated would silently skip the clamp.
  • Clamp applied to all branches in load_physical_pages (K prefetch + V LINEAR + V crosses-pages + V single-page lane0 broadcast) in V2.

Verified on MI-308X with the AICK-1171 reproducer and the full FMHA batch prefill suite on gfx942/gfx950. Splitting it out lets the OOB fix land on its own timeline; the rest of #6054 stays as-is.

Credit to your original approach — same min(page_id, max_bound) idea, just hardened.

