[CK] Add new FMHA batch prefill kernel with FP8 per-tensor and per-block KV quantization on gfx950 (#6054)
Add FP8BF16 per-tensor quantization path to the FMHA forward V3 pipeline on gfx950. This includes:
- FP8 32x32x32 warp gemm with C-transposed distribution
- FP8 warp gemm dispatcher entries
- V3 kernel support for per-tensor descale (q/k/v descale pointers)
- V3 pipeline FP8 data path with asm volatile for P conversion
- FP8 instruction scheduling optimization in CoreLoopScheduler
- Codegen: FP8BF16 V3 tile size (256x64x128) and pipeline variants
- Codegen: V3 dispatch condition extended for fp8bf16+pertensor
- LLVM scheduler TRANS mask for scheduling control
- Fix mask_info default initialization for the no_mask case

Note: V3 dispatch is disabled by default pending further validation.
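For context, a minimal sketch of where per-tensor descale factors typically enter the FMHA math (plain C++; the struct name and the fold into the softmax scale are illustrative assumptions, not the pipeline's actual code):

```cpp
// Illustrative sketch: with per-tensor FP8, Q*K^T dequantizes as
// (q_descale * k_descale) * (Q8 * K8^T), so both factors can ride along
// with the softmax scale; v_descale is applied once to the output.
struct PerTensorDescale {
    float q_descale, k_descale, v_descale;

    float scale_s(float softmax_scale) const {
        return softmax_scale * q_descale * k_descale; // folded into S
    }
    float scale_o() const {
        return v_descale; // applied to o_acc after the P*V GEMM
    }
};
```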
Remove debug macros (ENABLE_DEBUG_STMTS, DEBUG_STMTS, WARP_ID, LANE_ID), debug lambdas (print_dist_tensor, print_lds, print_lds_1d), unused LDS windows (s/p/o/m_lds_window), their helper methods (MakeSimpleLdsDesc, MakeSimpleLdsDesc1D), and unused KPack variables in the policy file. Assembly verified identical (sched_diff=0), 176/176 fp8 tests pass.
The P matrix (attention weights) lives entirely in registers (VGPRs) via sp_compute_type, not in LDS. Remove the P buffer terms from GetSmemSize() so it reports only the actual KV buffer usage. Assembly-verified: before/after diff shows identical GPU and host code.
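A minimal sketch of the corrected accounting, with a hypothetical policy type standing in for the real one (member names and tile sizes are invented for illustration):

```cpp
// P stays in VGPRs (sp_compute_type), so shared memory only needs K and V.
struct ExamplePolicy { // illustrative numbers, not the real tile config
    static constexpr int kKTileBytes = 64 * 128;  // one K LDS buffer
    static constexpr int kVTileBytes = 64 * 128;  // one V LDS buffer
};

template <typename Policy>
constexpr int get_smem_size()
{
    // double-buffered K and V; the P buffer term is gone
    return 2 * Policy::kKTileBytes + 2 * Policy::kVTileBytes;
}

static_assert(get_smem_size<ExamplePolicy>() == 32768);
```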
- Separate smem_k[2]/smem_v[2] pointers for explicit buffer control
- Use async_load_tile_raw / init_raw for raw async copies
- Remove dead P buffer from LDS size calculation
- Reformat operator() signatures for readability
Revert changes from debug commit that swapped NumWarps and LaneGroups in MakeVDramTileDistribution(), MakeVLdsStoreBlockDescriptor(), and MakeVLdsLoadBlockDescriptor(). These were unrelated to the 4-buffer LDS architecture refactor. Restores the original dimension ordering:
- MakeVDramTileDistribution: N1=LaneGroups, N2=NumWarps
- MakeVLdsStoreBlockDescriptor: shape (NumIssues, LaneGroups, NumWarps, ...)
- MakeVLdsLoadBlockDescriptor: merge sequence<0, 2, 1> for correct reorder

Testing: 176/176 FP8 MHA tests pass
Remove fine-grained can_dispatch_v3 runtime guard. Try V3 first when enabled; unsupported configs return -1 and fall back to V2.
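The resulting dispatch shape, as a minimal sketch (the function names are hypothetical stand-ins for the generated entry points, not the real API):

```cpp
// Illustrative skeleton only; the return convention follows the commit
// text (unsupported configs return -1).
struct fmha_args { /* problem description */ };

inline float run_v3(const fmha_args&) { return -1.f; } // -1 => no V3 match
inline float run_v2(const fmha_args&) { return 0.f;  } // full-coverage V2

float dispatch(const fmha_args& args, bool v3_enabled)
{
    if (v3_enabled) {
        if (float t = run_v3(args); t >= 0.f)
            return t;        // V3 trait match found
    }
    return run_v2(args);     // everything else falls back to V2
}
```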
Remove duplicate plain using definitions of WarpGemmMfma_f32_32x32x32_fp8_fp8, WarpGemmMfma_f32_32x32x32_bf8_bf8 that conflicted with the templated #if gfx950/#else versions, and deduplicate the corresponding dispatcher entry.
Drop the epilogue shared-memory buffer and smem_ptr parameter that were left over after prior refactoring, and silence the -Wunreachable-code diagnostic in the V3/V2 dispatch fallback.
The bare `False` statement was a no-op, causing bias/dropout kernels to be generated for (192, 128) hdim configurations instead of being filtered out.
Add kernel_attr_for<ArchTag, Attrs...> to kernel_launch.hpp that composes an architecture tag with kernel attributes. When no attributes are provided, kernel_attr_for<ArchTag> is an identity alias for ArchTag itself (is_same_v is true). With attributes, it creates a unique type that inherits both the arch tag and attribute mixins. The existing kattr_no_packed_fp32_ops_v SFINAE detection works transparently through the inheritance chain.

Usage:
- kernel_attr_for<gfx950_t> -> gfx950_t
- kernel_attr_for<gfx950_t, kernel_attr<true>> -> unique type
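A minimal sketch of the mechanism as described (simplified; the real kernel_launch.hpp implementation may differ in detail):

```cpp
#include <type_traits>

struct gfx950_t {};                                   // architecture tag
template <bool NoPackedFp32>
struct kernel_attr {                                  // attribute mixin
    static constexpr bool no_packed_fp32_ops = NoPackedFp32;
};

// Unique type inheriting both the arch tag and the attribute mixins, so
// member-based SFINAE detection still works through the inheritance chain.
template <typename ArchTag, typename... Attrs>
struct kernel_attr_composed : ArchTag, Attrs... {};

// No attributes -> identity alias for ArchTag; otherwise the composed type.
template <typename ArchTag, typename... Attrs>
using kernel_attr_for =
    std::conditional_t<sizeof...(Attrs) == 0, ArchTag,
                       kernel_attr_composed<ArchTag, Attrs...>>;

static_assert(std::is_same_v<kernel_attr_for<gfx950_t>, gfx950_t>);
static_assert(kernel_attr_for<gfx950_t, kernel_attr<true>>::no_packed_fp32_ops);
```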
Refactor fmha_batch_prefill.py to match fmha_fwd.py patterns, preparing
for V3 pipeline integration:
- Add ArchTrait to FmhaFwdApiTrait and FmhaFwdKernel with arch
preprocessor guards
- Refactor FmhaFwdApiPool to hierarchical OrderedDict[arch][dtype][hdim]
with render() method
- Split API template into HEADER/FUNC_TEMPLATE/PER_ARCH/FOOTER
- Add ProblemContext, KernelContext, CompatibilityRule, is_compatible(),
create_kernel() abstractions
- Extract inline filtering into CompatibilityRuleFactory and Product
- Add factory hierarchy with get_factories_for_targets()
- Add extensible _get_cpp_kernel_class_name(),
_get_cpp_kargs_creator_func_name(),
_get_cpp_pipeline_problem_name() methods to FmhaFwdKernel
- Filename includes arch suffix: {name}_{arch}.cpp
Tested: 14848 passed, 0 failed across 4 combinations
(stock/custom compiler x before/after refactoring).
Add batch prefill V3 pipeline and kernel with scatter-gather paged KV support, and a simplified dispatch that relies on trait matching for fallback.
- V3 pipeline: 4-phase double warp group, async buffer loads, 4-buffer LDS
- V3 kernel: LINEAR layout only, SGLang + vLLM page tables
- Codegen: 80 V3 kernels (bf16/fp8, no/causal mask, page sizes 1/16/1024)
- Dispatch: try V3 first when enabled, fall back to V2 via trait matching
- Static asserts enforce V3 constraints (LINEAR, no bias/dropout/kv_blockscale)
V3 batch prefill is only needed for fp8bf16. Remove the bf16/fp16 V3 tile (256x64, 8 warps) and pipeline (qr_async_trload_v3) entries from KernelComponentFactoryGfx950. bf16/fp16 continue to use V2 (qr_async).
Skip page table lookup in K_mem_load/V_mem_load when the next sequence position exceeds seqlen_k_end. Without this guard, the auto-advance reads page indices from the padding region of kv_page_indices and computes scatter-gather offsets that produce buffer_load addresses mapping to unmapped GPU pages, causing XNACK faults on gfx950. The contiguous V3 fwd pipeline doesn't have this issue because its move_tile_window is simple pointer arithmetic protected by the buffer descriptor's NUM_RECORDS field. The scatter-gather path computes physical offsets from the page table, bypassing NUM_RECORDS protection.
Separate K/V page offset updates from K/V_mem_load into dedicated K_page_advance/V_page_advance lambdas, called at the very end of each phase after Scheduler::schedule. This keeps async_load_tile + ds_read in one uninterrupted basic block, preventing the XNACK guard branch from fragmenting the load+ds_read scheduling. Recovers 6-13% of the 14-26% guard overhead (measured on FP8 batch_prefill sweep). The guard branch still exists but only fragments the tail of each phase, not the critical load/ds_read interleaving.
Replace the branched XNACK guard (if/s_cbranch) with branchless min(page_id, max_page_table_idx) inside load_physical_pages(), sketched below. The max_page_table_idx is computed as (seqlen_k - 1) / kPageBlockSize in the kernel and threaded through to all load_physical_pages() call sites. This eliminates the 14-26% guard overhead that was caused by:
- s_cbranch fragmenting sched_group_barrier-scheduled basic blocks
- serialized global_load_dword + s_waitcnt at conditional join points
- +14 VGPRs from extended live ranges across branch boundaries

V3 FP8 batch_prefill is now 8-23% faster than V2 (was 4-30% slower). Paged KV overhead reduced from 20-47% to 7-17% vs contiguous varlen.
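A host-side sketch of the clamp, assuming illustrative parameter names (the real code lives inside load_physical_pages() and uses the pipeline's index types):

```cpp
#include <algorithm>
#include <cstdint>

// Clamping the page-table index keeps the lookup in bounds without an
// s_cbranch, so the sched_group_barrier-scheduled basic blocks stay intact.
inline int32_t safe_page_table_idx(int32_t page_id,
                                   int32_t seqlen_k,
                                   int32_t page_block_size)
{
    const int32_t max_page_table_idx = (seqlen_k - 1) / page_block_size;
    return std::min(page_id, max_page_table_idx); // lowers to v_min, no branch
}
```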
Split K_page_advance/V_page_advance into issue/consume pairs so the global_load_dword (page table lookup) is issued BEFORE the buffer_loads from cl_load, placing it oldest in the vmcnt FIFO. At consume time, s_waitcnt(N) drains only the oldest global_load while keeping the N buffer_loads in flight. sched_barrier(0) brackets prevent the compiler from reordering the global_load_dword across the buffer_loads, which would undo the FIFO ordering. Applied to all 4 core loop load phases (WG0 phases 1/3, WG1 phases 0/2), the pre-stage, and the WG0 pre-loop setup.

Correctness: 16480 passed, 5376 skipped, 0 failed (matches baseline)
Performance (avg of 5 runs, FP8 batch_prefill, paged KV): s=4k-8k: +8-9%, s=16k: +6%, s=32k+: +2-5%, MHA h=40/40: +5%
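A structural model of the split (host-compilable sketch; the empty stub bodies stand in for the real global_load_dword/buffer_load/s_waitcnt/sched_barrier intrinsics):

```cpp
struct PageAdvance {
    void issue()   {} // global_load_dword: enters the vmcnt FIFO FIRST
    void consume() {} // s_waitcnt(vmcnt(N)): drains only the oldest entry
                      // (the page lookup); the N buffer_loads stay in flight
};

inline void issue_buffer_loads() {} // the N cl_load buffer_loads
inline void sched_barrier0()     {} // sched_barrier(0) bracket

void core_loop_load_phase(PageAdvance& k_page)
{
    sched_barrier0();          // pin ordering: compiler may not sink the
    k_page.issue();            // page lookup below the buffer_loads
    sched_barrier0();
    issue_buffer_loads();      // these queue up behind the page lookup
    // ... ds_reads / MFMA work overlap with all outstanding loads ...
    k_page.consume();          // wait for the page lookup alone
    // next iteration's scatter-gather offsets are now available
}
```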
Add per-page FP8 K/V dequantization scale support to the V3 batch prefill pipeline, matching the existing V2 implementation.

Kernel: add FmhaFwdKVBlockScaleKargs with nblock/nhead strides; update Kargs type selection and MakeKargs for both batch/group mode. scale_s uses q_descale only (k_descale is deferred to the pipeline).

Pipeline: FP8 shift trick in fmha_alu0 (subtract 8.0/7.0 from the row max to implicitly scale P), k_descale applied after GEMM0, v_descale rescale trick around GEMM1 (divide before, multiply after). Double-buffered saved_k/v_descale indexed by LDS buffer slot, saved before each K_page_issue.

Codegen: add "kv_blockscale" to V3 pipeline generation. The existing check_page_size filter enforces page_size >= kN0.

Performance tuning not yet done (the extra element-wise passes for the v_descale rescale are not merged into fmha_alu_D_upd).
Reduce KV_BLOCKSCALE overhead from 28% to 11% vs pertensor by:
1. Merging v_descale into o_acc_scale: maintain o_acc in v_descale-scaled space, folding the v_descale_prev/v_descale_cur ratio into the existing softmax rescale factor. Eliminates 2 full-tile VALU passes (divide before GEMM1 + multiply after GEMM1). The final v_descale is applied in epilogue normalization.
2. Replacing the scalar k_descale multiply with v_pk_mul_f32: halves the instruction count for the remaining k_descale pass by operating on float2 pairs.

Both changes are guarded by if constexpr(KV_BLOCKSCALE); pertensor assembly is structurally identical before/after.
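A scalar sketch of the fold under these assumptions (illustrative names; the real code operates on per-row register tiles): keeping o_acc in v_descale-scaled space means the per-iteration correction is just a ratio absorbed into the existing online-softmax rescale.

```cpp
#include <cmath>

struct OAccState {
    float o_acc = 0.f;        // stands in for the per-row accumulator tile
    float m_prev = -INFINITY; // running row max
    float v_descale_prev = 1.f;

    void rescale(float m_cur, float v_descale_cur)
    {
        // standard online-softmax rescale with the descale ratio folded in:
        // one pass instead of divide-before / multiply-after GEMM1
        const float r = std::exp(m_prev - m_cur)
                        * (v_descale_prev / v_descale_cur);
        o_acc *= r; // afterwards, P * V8 contributions add in directly
        m_prev = m_cur;
        v_descale_prev = v_descale_cur;
    }

    float finalize(float row_sum) const
    {
        // final v_descale applied once, in epilogue normalization
        return o_acc * v_descale_prev / row_sum;
    }
};
```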
- Add block_fmha_batch_prefill_v3_pipeline_default_policy.hpp with FP8-specific LDS descriptors (separate from shared V3 fwd policy)
- Add SRD rebasing (rebase_k/v_window) to V3 pipeline for kPageBlockSize >= kN0, enabling page_size=1024 V3 dispatch
- Fix load_tile_transpose_with_offset to use develop's void API
- Remove codegen page_size >= kN0 restriction for V3 pipeline
Add BatchPrefillCoreLoopScheduler with FP8-tuned VALU budgets (6/6 per
MFMA half) matching the feature branch's CoreLoopScheduler tuning.
The develop branch used VALU 4/3, which undershoots the actual VALU work
because the v_pk_mul_f32 asm volatile is invisible to the compiler.
Relax fma_impl_vsv from asm volatile("v_fma_f32") to plain C++ FMA.
The asm volatile anchor prevented compiler reordering across
sched_barrier(0) boundaries, causing s_nop 7+3 stalls (22 extra NOP
cycles per phase2). Plain C++ gives the compiler freedom to fill MFMA
latency gaps.
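The relaxation in miniature (a device-side sketch that only compiles for AMDGPU targets; the real helper lives in the shared V3 detail header):

```cpp
// Before: asm volatile pins the FMA in place; the scheduler cannot move it
// across sched_barrier(0) boundaries, forcing s_nop padding.
inline float fma_pinned(float a, float b, float c)
{
    float d;
    asm volatile("v_fma_f32 %0, %1, %2, %3"
                 : "=v"(d) : "v"(a), "v"(b), "v"(c));
    return d;
}

// After: plain C++ FMA; the compiler is free to schedule it into MFMA
// latency gaps.
inline float fma_free(float a, float b, float c)
{
    return __builtin_fmaf(a, b, c);
}
```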
Use kernel_attr_for<gfx950_t, kernel_attr<true>> for V3 batch prefill
kernels to apply target("no-packed-fp32-ops") to the kernel entry point.
This prevents the compiler from generating v_pk_mul_f32 for non-asm
FP32 operations, allowing separated v_mul_f32 to co-execute with MFMA.
V2 kernels use the plain arch tag (unchanged behavior).
Replaces the -mllvm --amdgpu-disable-packed-fp32=1 flag which was
silently skipped by the stock ROCm 7.1.1 compiler. The target attribute
is honored by both stock and custom compilers.
Measured +2-5% over stock compiler without the attribute.
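In sketch form, the attribute lands directly on the kernel entry point (HIP-style pseudo-kernel with an illustrative name; the routing through kernel_attr_for is omitted here):

```cpp
// The target attribute is attached to the kernel entry point itself, so it
// is honored by stock and custom compilers alike (unlike the -mllvm flag,
// which stock ROCm 7.1.1 silently skipped).
__global__ __attribute__((target("no-packed-fp32-ops")))
void fmha_batch_prefill_v3_kernel_sketch() // illustrative name
{
    // Non-asm FP32 math in here stays as separate v_mul_f32 instead of
    // being fused into v_pk_mul_f32, so it can co-execute with MFMA.
}
```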
Extract CoreLoopSchedulingParams, block_gemm_mfma_count_v, detail:: VALU helpers (fma_impl_vsv, add_impl_vv, mul_impl_vv, cvt_pk_*, pk_mul_f32), and macros (CK_TILE_FMHA_V3_ASM_MARKER, CK_TILE_FMHA_V3_ADD_SBARRIER_FOR_PHASE0) into a shared header block_fmha_fwd_v3_detail.hpp. Both fmha_fwd V3 and batch_prefill V3 pipelines now include this header instead of batch_prefill including the full fwd V3 pipeline header.

Also:
- Remove __gfx950__ guards inside V3 pipelines (keep only the permlane32_swap path; the entire pipeline is gfx950-only)
- Add top-level gfx950 guard: operator() returns empty output on non-gfx950 devices
- Remove CK_TILE_DISABLE_PACKED_FP32 macro (always 0; V3 uses kernel_attr_for instead of the -mllvm flag)
- Add CK_TILE_ prefix to all custom macros
- Add design comments for the quant/LSC-unaware scheduler

Assembly output verified identical (only __hip_cuid differs).
V3 batch prefill generates only LINEAR KV layout kernels. VECTORIZED layout requires sub-dword async loads that violate V3's buffer addressing constraints, and KV layout optimization has not been done for V3. VECTORIZED requests fall back to V2 via trait matching.
The attribute conflicts with explicit v_pk_mul_f32 inline asm in the KV_BLOCKSCALE descale path. Benchmarks (6 sweeps, avg runs 2-6) show removing it costs 2-4% on pertensor but is neutral-to-positive on blockscale, making the trade-off acceptable vs the asm conflict.
Preserve the extraction refactoring (v3_detail.hpp shared header) while adapting post-review fixes from PR 6051 (fmha_fwd V3) into both the fwd and batch_prefill V3 pipelines:
1. Move s_waitcnt/s_barrier outside the `if(2 < num_total_loop)` guard in the pre-stage to ensure the K1+V0 async loads are drained before core_loop reads K1 from LDS (bug fix for the num_total_loop==2 case; sketched below).
2. Replace manual __builtin_amdgcn_permlane32_swap intrinsic calls with block_tile_reduce/block_tile_reduce_sync in fmha_alu0 (rowmax) and fmha_alu1 (rowsum), preserving the kFoldKDescale logic in batch_prefill.
3. Split fmha_alu_D_upd into unpack/pack with interleaved scheduling.
4. Add a CK_TILE_DISABLE_PACKED_FP32 guard on schedule_gemm1_compute().
5. Add a fmha_alu_D_reg_cnt % 2 == 0 assertion.
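A control-flow-only sketch of fix 1, with stub functions standing in for the intrinsics:

```cpp
inline void issue_async_loads_k1_v0() {} // async_load_tile calls for K1+V0
inline void wait_async_loads()        {} // s_waitcnt
inline void lds_barrier()             {} // s_barrier

void pre_stage(int num_total_loop)
{
    issue_async_loads_k1_v0();

    // Fix: drain unconditionally. Previously this wait sat inside the
    // guard below, so for num_total_loop == 2 the core loop could read
    // K1 from LDS before the async loads had landed.
    wait_async_loads();
    lds_barrier();

    if (2 < num_total_loop) {
        // guarded prefetch work for deeper pipelines only
    }
}
```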
Lines 174-175 duplicated the specializations already at lines 125-126, causing redefinition errors. The duplicates were introduced by the merge (both branches added the same entries independently).
The merge conflict resolution dropped the Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false> specialization, causing compilation errors when cshuffle_epilogue instantiates WarpGemmDispatcher with fp8 types and isCTransposed=false.
Resolve merge conflicts from StreamLLM sink token support (#6479) landing on develop after the V3 batch prefill refactor. Integrate F_sink parameter into the refactored multi-arch factory codegen.
The refactoring to CompatibilityRuleFactory dropped the original `if mode != "group": continue` guard, allowing batch-mode kernel instantiations that hit the static_assert in the pipeline problem.
Add #pragma clang diagnostic push/pop around the v3 dispatch block in the API footer template, matching the pattern already used in fmha_fwd.py. Without this, gfx942 builds (which have no v3 kernels) emit if(false) and fail with -Werror,-Wunreachable-code.
Return None from get_factory for targets without a factory instead of raising an exception. Filter unsupported targets before dispatching to get_factories_for_targets so that builds targeting e.g. gfx1101 produce an empty kernel pool instead of crashing at configure time.
# Conflicts:
#   projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py
Fix -Werror,-Wunused-variable compilation error.
To fix the OOB issue surfaced by AICK-1171, I based the V2 changes on your PR and opened a separate PR — #6932 — for the V2 changes. Verified on MI-308X with the AICK-1171 reproducer and the full FMHA batch prefill suite on gfx942/gfx950. Splitting it out lets the OOB fix land on its own timeline; the rest of #6054 stays as-is. Credit to your original approach.
Summary
V3 FMHA pipeline for batch prefill with paged KV cache on gfx950 (FP8 only). Uses the same 8-warp, 256x64 tile, double-buffered LDS architecture as the contiguous fmha_fwd V3 kernel, extended with scatter-gather page table support.

Supported configurations:
Performance vs V2 (FP8 batch prefill, paged KV, stock ROCm 7.1.1, avg runs 2–6, MI355X):
Per-tensor (ps=1, sglang):
KV block-scale (ps=1024, sglang):
Key design decisions
- Branchless page_id clamping — The scatter-gather path bypasses the buffer descriptor's NUM_RECORDS protection, so auto-advancing past seqlen_k causes XNACK faults. Solved with min(page_id, max_page_table_idx) instead of branch guards, which would fragment sched_group_barrier-scheduled basic blocks.
- Page advance issue/consume split — Page table lookups (global_load_dword) are issued BEFORE buffer_loads so they sit oldest in the vmcnt FIFO. At consume, s_waitcnt(N) drains only the page lookup while keeping N buffer_loads in flight. sched_barrier(0) prevents reordering.
- KV block-scale k_descale fold — k_descale is folded into the scalar row_max via the FP8 shift trick, eliminating a full-tile VALU pass. v_descale is merged into the softmax rescale factor (o_acc_scale), with the final correction applied in the epilogue.
- No no-packed-fp32-ops kernel attribute — V3 does NOT use kernel_attr<true> (which adds target("no-packed-fp32-ops")) despite it giving a +2–4% benefit on per-tensor workloads. The attribute conflicts with the explicit v_pk_mul_f32 inline asm used in the KV block-scale descale path (the pk_mul_f32 helper): when the asm is inlined into the attributed kernel entry, the assembler rejects the instruction. Benchmarks show the attribute is neutral-to-positive for block-scale (removing it actually gives +1–3%), so we accept the 2–4% per-tensor trade-off to avoid the asm conflict. A future per-variant solution (attribute on for per-tensor, off for block-scale) could recover this, but the codegen currently uses a single F_kernel_attr per kernel.

New files
- fmha_batch_prefill_v3_kernel.hpp
- block_fmha_batch_prefill_v3_pipeline.hpp
- block_fmha_batch_prefill_v3_pipeline_default_policy.hpp
- block_fmha_fwd_v3_detail.hpp: CoreLoopSchedulingParams, VALU intrinsics, macros

Codegen changes (fmha_batch_prefill.py)

Refactored to support per-architecture kernel generation (matching fmha_fwd.py patterns):
- ArchTrait / factory hierarchy: gfx9 -> gfx950 with arch-specific tiles, pipelines, and compatibility rules
- fmha_batch_prefill_v2() / fmha_batch_prefill_v3() dispatch functions (V3 tried first, falls back to V2)

Bug fix in fmha_fwd.py

Fixed missing return in check_hdim — bare False was a no-op, causing unnecessary bias/dropout kernels for hdim=(192,128).

Dependencies

Based on #4437 (users/poyenc/fa-v3-fp8-pertensor). That PR must be merged first.

Test plan
- test_batch_prefill.py suite