[Paged KV] Inline Metal kernel, deprecate HF PyTorch kernel #136
LxYuan0420 merged 15 commits into vllm-project:main
Conversation
Signed-off-by: ran <hzz5361@psu.edu>
Force-pushed a580e21 to 02c26c0.
@LxYuan0420 Requesting review.
## Summary

- Add `test_paged_deterministic.py`: a 5-prompt smoke test using vLLM offline inference (temp=0, greedy) against hardcoded golden token IDs from Qwen3-0.6B.
- Golden values were generated on `main` from both the MLX inline cache and HF paged KV cache paths.
- Add a `tools/gen_golden.py` helper to regenerate the golden values.

## Motivation

Prerequisite for the native Metal kernel PR (#136). After inlining the vendored Metal shaders, paged attention output must remain identical to the current HF kernel baseline. This test anchors that.

## Test

- `python -m pytest tests/test_paged_deterministic.py -v -s` (paged path by default)
- Passes on `main` with the HF kernel: 5/5

## Relevant Issue & PR

* Issue #119
* PR #136: the inline Metal kernel needs to either pass this test or explain any non-determinism the kernel could introduce.

Upstream batch-invariance feature:

* blog: https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/
* main feature: vllm-project/vllm#27433
* vLLM's upstream batch-invariance feature is only compatible with H- and B-series NVIDIA GPUs; A100 does not work. See my experiment results: https://github.com/WindChimeRan/spec_deterministic
* community work: vllm-project/vllm#30018

Batch invariance is hardware- and kernel-dependent, so supporting this feature is non-trivial on Metal.

Output example:

<img width="1061" height="721" alt="image" src="https://github.com/user-attachments/assets/bf423b90-c567-408b-8682-e2c36050fb8f" />

---------

Signed-off-by: ran <hzz5361@psu.edu>
Signed-off-by: Yuan Lik Xun <lxyuan0420@gmail.com>
Co-authored-by: Yuan Lik Xun <lxyuan0420@gmail.com>
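The golden-token smoke-test pattern described above can be sketched as follows. The prompt, token IDs, and function names here are hypothetical placeholders, not the repo's actual values; the real test drives vLLM offline inference with temperature 0.

```python
# Minimal sketch of a golden-token determinism check (hypothetical values).
# Greedy decoding (temperature=0) should reproduce the exact token IDs
# recorded on a known-good baseline run.
GOLDEN = {
    "What is 2 + 2?": [17, 402, 9, 311],  # hypothetical golden token IDs
}

def check_deterministic(generate_fn):
    """generate_fn(prompt) -> list of greedily generated token IDs."""
    failures = []
    for prompt, expected in GOLDEN.items():
        got = generate_fn(prompt)
        if got != expected:
            failures.append((prompt, expected, got))
    return failures

# A backend that replays the golden IDs exactly passes the check:
assert check_deterministic(lambda p: list(GOLDEN[p])) == []
```

The value of this shape of test is that it pins the whole decode path, so any kernel swap that perturbs even one logit enough to flip an argmax shows up immediately.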
Thanks for the context; totally understandable. Given this PR introduces a new native C++/JIT execution path, I would like a bit more time to complete review and local validation before final approval. Could you please share before/after perf numbers (main vs this PR, same machine/config)? The 5-pass deterministic test is helpful for correctness smoke testing, but I would also like performance evidence for this backend swap.
Resolve README.md conflict: keep both upstream memory settings docs and our acknowledgements section. Signed-off-by: ran <hzz5361@psu.edu>
Force-pushed b31e77b to 69f917d.
@ricky-chaoju Thanks for testing! pyproject.toml updated.
Benchmarks are now in the PR description. TL;DR: ~3.6x throughput improvement over the HF kernel path.

Worth noting that the scope here is fairly contained: it's a backend swap behind the existing interface. The kernels themselves are temporary; they'll be replaced once we have first-class variable-length support (might be my next PR).

Happy to address any concerns, but wanted to flag that this is low-risk to land and iterate on.
This is great. |
LxYuan0420 left a comment:
Thanks for driving this forward. A few minor changes are needed before we merge this.
@LxYuan0420 Fixed. Will clean up the kernel files and try variable-length flash attention in the next PR.
LxYuan0420 left a comment:
LGTM.
Thanks for the quick update.
## Simplify paged KV cache to unified 4D layout

**The goal is to align with the vLLM varlen kernel's KV cache layout.**

Replace the x-tiled key cache layout (`[num_blocks, kv_heads, head_dim//x, block_size, x]`) and the separate value layout (`[num_blocks, kv_heads, head_dim, block_size]`) with a single unified layout for both K and V:

```
[num_blocks, block_size, num_kv_heads, head_size]
```

### Changes

- `cache.py` — Both K and V use the same simple 4D layout
- `reshape_and_cache.metal` — Simplified from 5D/4D indexing to row-major 4D
- `pagedattention.metal` — K access simplified (no x-tiling), V accumulation rewritten for contiguous head_size
- `gather_kv_cache.metal` — Updated to unified layout
- `paged_ops.cpp` — Updated stride calculations

### Testing

- `test_paged_deterministic.py` — **All 5 prompts pass** (golden token match, `max_num_seqs=1`)
- `test_metal_kernel_paged.py::test_greedy_output_matches` — **Now passes** (was xfail before)
- `test_metal_kernel_paged.py::test_batched_decode_matches` — Flipped to xfail. Root cause is B=1 vs B=2 GEMM floating-point sensitivity in MLX (not a kernel bug). The deterministic test is the more robust correctness check, and it passes cleanly. No regression in algorithm correctness.

### Benchmark

Run with the same benchmark script as #136. Results:

```
============ Serving Benchmark Result ============
Successful requests:                     100
Failed requests:                         0
Maximum request concurrency:             32
Request rate configured (RPS):           10.00
Benchmark duration (s):                  103.34
Total input tokens:                      23260
Total generated tokens:                  22061
Request throughput (req/s):              0.97
Output token throughput (tok/s):         213.49
Peak output token throughput (tok/s):    325.00
Peak concurrent requests:                35.00
Total token throughput (tok/s):          438.58
---------------Time to First Token----------------
Mean TTFT (ms):                          2001.54
Median TTFT (ms):                        422.31
P99 TTFT (ms):                           8293.96
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          130.97
Median TPOT (ms):                        116.35
P99 TPOT (ms):                           590.88
---------------Inter-token Latency----------------
Mean ITL (ms):                           107.96
Median ITL (ms):                         96.17
P99 ITL (ms):                            420.17
==================================================
```

**Benchmark impact:** This refactoring adds ~8ms to mean ITL (baseline 100ms → 108ms), ~1.2s to mean TTFT (baseline 844ms → 2002ms), and reduces output throughput by ~26 tok/s (baseline 239 → 213), which is the expected cost of aligning with upstream vLLM's architecture as the foundation for continuous batching.

---------

Signed-off-by: ran <hzz5361@psu.edu>
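As a sanity check on the stride math for the unified layout, here is a minimal sketch of the row-major offset computation. The function name is illustrative, not an actual symbol from `paged_ops.cpp`:

```python
def kv_cache_offset(block, pos, head, dim, block_size, num_kv_heads, head_size):
    # Row-major offset into [num_blocks, block_size, num_kv_heads, head_size].
    # The innermost head_size dimension is contiguous, which is what lets the
    # attention kernel accumulate V over contiguous memory without x-tiling.
    return ((block * block_size + pos) * num_kv_heads + head) * head_size + dim

# Block 1 starts immediately after all of block 0's block_size*heads*head_size
# elements, so the layout is dense with no padding between blocks:
assert kv_cache_offset(1, 0, 0, 0, 16, 4, 64) == 16 * 4 * 64
```

Because K and V now share this single formula, the same stride tuple can drive `reshape_and_cache`, the attention kernel, and `gather_kv_cache`, which is what the per-file changes above reduce to.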
This PR is:

- To remove a stale `xfail` on `test_greedy_output_matches` that was originally added for issue #119.
- To align test expectations with current `main` behavior after the paged-path fixes already merged.
- To keep parity tracking accurate while leaving batched behavior to its own tracking path.

## Context

Issue #119 reported token-mismatch parity failures between:

- the standard MLX KV cache path, and
- the Metal paged-attention path.

Since then, two key fixes landed:

- #125 corrected paged KV cache dtype inference/fallback behavior and the KV cache size accounting used by paged memory/block calculations.
- #136 replaced the HF/PyTorch kernel-bridge path with native MLX + inline Metal JIT dispatch (`get_ops`/nanobind), removing cross-framework bridge behavior from paged execution.

With those changes, the old greedy mismatch from #119 no longer reproduces on `main`, so the greedy `xfail` is stale.

## Verification

```bash
pytest -q tests/test_metal_kernel_paged.py::TestMetalKernelPagedVsStandard::test_greedy_output_matches -s
pytest -m slow -q tests/test_metal_kernel_paged.py
```

Signed-off-by: Yuan Lik Xun <lxyuan0420@gmail.com>
## Summary
- Add a `find_seq_idx` binary search to the v2 Metal kernel so each
threadgroup discovers its sequence from a flat `cu_seqlens_q` array,
enabling variable-length queries (prefill + decode in one launch).
- **This PR does not take actual effect in production.** The current
production path still uses `mx.sdpa` for prefill and this PR's v2 kernel
for decode, and kernels_v2 is identical to the previous v1 once certain
parameters are frozen.
- Passes all triangle tests; safe to move forward to stage 3 continuous
batching.
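The `find_seq_idx` lookup can be modeled in Python as a binary search over the cumulative-length prefix array. This is an illustrative sketch of the Metal-side logic, not the kernel source; `bisect_right` stands in for the hand-written loop:

```python
import bisect

def find_seq_idx(cu_seqlens_q, token_idx):
    """Map a flat query-token index to its sequence index.

    cu_seqlens_q is the cumulative-length prefix array: sequences of
    lengths 3, 4, 3 give [0, 3, 7, 10]. Each threadgroup runs the
    equivalent O(log n) search on-GPU rather than receiving a
    precomputed token-to-sequence map from the host.
    """
    return bisect.bisect_right(cu_seqlens_q, token_idx) - 1

cu = [0, 3, 7, 10]
assert [find_seq_idx(cu, t) for t in range(10)] == [0, 0, 0, 1, 1, 1, 1, 2, 2, 2]
```

This is what lets prefill (q_len > 1) and decode (q_len == 1) share one flat launch: the grid is sized by total query tokens, and each threadgroup recovers its own sequence boundaries from `cu_seqlens_q`.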
Notes:
- Vendored features from upstream vLLM: sliding-window support and logit
soft capping added to the v2 kernel.
- Updated the production decode path to match the new function signature
(default params, no behavior change).
- **These features are NOT tested in end-to-end production usage; they
are expected to be bound to specific models, such as early versions of
Mistral models.**
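For reference, the two vendored features reduce to small per-score transforms. A hedged Python sketch of the semantics (the actual Metal code applies these element-wise inside the attention loop, and the function names here are illustrative):

```python
import math

def softcap(score, cap=50.0):
    # Logit soft capping: squashes attention scores into (-cap, cap)
    # while staying approximately the identity for |score| << cap.
    return cap * math.tanh(score / cap)

def in_sliding_window(q_pos, k_pos, window=128):
    # Causal sliding-window mask: the key must lie within `window`
    # positions at or before the query position.
    return 0 <= q_pos - k_pos < window

assert in_sliding_window(200, 150) and not in_sliding_window(200, 50)
```

Both transforms change which scores reach the softmax, which is why they need model-specific wiring (window size and cap value come from the model config) before they can be trusted end to end.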
## Triangle Test Status
```
ref (pure-MLX naive)
/ \
edge 1 edge 3
/ \
v1 ── edge 2 ── v2
```
- **Edge 1** (v1 == ref): 6 pass (unchanged)
- **Edge 2** (v2 == v1): 6 pass (unchanged)
- **Edge 3** (v2 == ref): **24 pass** (was 3 pass + 21 xfail)
- varlen (q_len > 1): now passing
- sliding window (128): now passing
- soft capping (50.0): now passing
**Before:** 15 passed + 21 xfail → **After:** 36 passed + 0 xfail
## What's NOT in this PR
The kernel now supports unified prefill+decode, but production still
uses the split path (MLX SDPA for prefill, v2 kernel for decode). Wiring
`metal_unified_attention()` into `model_runner.py` is a follow-up.
## Numeric Stability
```
python -m pytest tests/test_paged_deterministic.py -v -s
```
* Before this PR: 5/6 match mlx_lm path
* After this PR: 6/6 match mlx_lm path
However, I don't want to change the test for now; its result will
flip on and off in the following PRs.
## Benchmark
Run with the same benchmark script as #136.
<details>
<summary>This PR: </summary>
```
============ Serving Benchmark Result ============
Successful requests: 100
Failed requests: 0
Maximum request concurrency: 32
Request rate configured (RPS): 10.00
Benchmark duration (s): 107.24
Total input tokens: 23260
Total generated tokens: 22061
Request throughput (req/s): 0.93
Output token throughput (tok/s): 205.71
Peak output token throughput (tok/s): 319.00
Peak concurrent requests: 35.00
Total token throughput (tok/s): 422.60
---------------Time to First Token----------------
Mean TTFT (ms): 593.33
Median TTFT (ms): 386.44
P99 TTFT (ms): 2147.57
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 134.30
Median TPOT (ms): 127.79
P99 TPOT (ms): 477.35
---------------Inter-token Latency----------------
Mean ITL (ms): 117.58
Median ITL (ms): 104.05
P99 ITL (ms): 473.33
==================================================
```
</details>
<details>
<summary>before this PR:</summary>
```
============ Serving Benchmark Result ============
Successful requests: 100
Failed requests: 0
Maximum request concurrency: 32
Request rate configured (RPS): 10.00
Benchmark duration (s): 106.42
Total input tokens: 23260
Total generated tokens: 22061
Request throughput (req/s): 0.94
Output token throughput (tok/s): 207.30
Peak output token throughput (tok/s): 320.00
Peak concurrent requests: 35.00
Total token throughput (tok/s): 425.87
---------------Time to First Token----------------
Mean TTFT (ms): 982.74
Median TTFT (ms): 452.35
P99 TTFT (ms): 3030.71
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 132.63
Median TPOT (ms): 124.33
P99 TPOT (ms): 442.78
---------------Inter-token Latency----------------
Mean ITL (ms): 115.19
Median ITL (ms): 101.69
P99 ITL (ms): 440.37
==================================================
```
</details>
This PR has no effect on performance; it paves the way for
continuous batching.
## Possible Limitation
* The binary search is translated from the Triton kernel, but it may not
be necessary here. Triton uses it to avoid a CPU-GPU data copy, but we
are on unified memory, so we could prebuild the reverse map instead. For
the data ranges involved, O(log n) performs the same as O(1) while using
less space.
* Haven't checked behavior with partitioning on vs. off.
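The prebuilt reverse-map alternative mentioned above would look roughly like this (an illustrative sketch: it trades O(total_tokens) memory for an O(1) per-threadgroup read):

```python
def build_reverse_map(cu_seqlens_q):
    # Expand the cumulative-length prefix array into a flat
    # token-index -> sequence-index table, so each threadgroup
    # would do a single O(1) lookup instead of a binary search.
    rmap = []
    for seq, (start, end) in enumerate(zip(cu_seqlens_q, cu_seqlens_q[1:])):
        rmap.extend([seq] * (end - start))
    return rmap

assert build_reverse_map([0, 3, 7, 10]) == [0, 0, 0, 1, 1, 1, 1, 2, 2, 2]
```

On unified memory the host could build this table cheaply per step, but as noted, for realistic sequence counts the binary search is just as fast and needs no extra buffer.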
---------
Signed-off-by: ran <hzz5361@psu.edu>

## Overview

This PR eliminates the external kernels-community/paged-attention HuggingFace dependency by vendoring the Metal shader files directly and replacing the PyTorch MPS dispatch layer with a nanobind C++ extension that dispatches through MLX's own Metal command encoder. This removes the MLX↔PyTorch MPS fence synchronization bubble.

Key architectural change: fully zero-copy, single Metal command queue.
## FAQ

**Q: Why can't `mlx.core.fast.metal_kernel` be used here? It is much simpler.**

A: `mlx.core.fast.metal_kernel` always allocates fresh output arrays and provides inputs as read-only. A paged KV cache requires in-place, scatter-write mutation of a persistent buffer (via `reshape_and_cache`) and zero-copy reads from that same buffer (via `paged_attention_v1`). If we copied the entire cache in and out on every token, it would negate the whole purpose of paging. The C++ nanobind bridge is the minimum complexity needed to get direct Metal buffer access through MLX's internal command encoder.

**Q: What's the difference between `kernel` and `kernel_v1`?**

A: `kernel` is a drop-in replacement for the HuggingFace kernel, originally vendored from an older version of mistral.rs. This PR's end-to-end deterministic tests confirm no regression from the replacement. `kernel_v1` is the latest Metal kernel from the mistral.rs repo — more mature, with preliminary scaffolding for variable-length kernels and gpt-oss sink attention support. That said, neither will persist beyond this PR. Both are slated for deprecation once we introduce first-class variable-length kernel support, which is a prerequisite for continuous batching, chunked prefill, and MQA Scorer speculative decoding.
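A minimal NumPy model of the scatter-write semantics the FAQ refers to — illustrative only, since the real `reshape_and_cache` op mutates Metal buffers through the nanobind bridge rather than NumPy arrays, and the shapes here are made-up small values:

```python
import numpy as np

def reshape_and_cache(key, value, k_cache, v_cache, slot_mapping, block_size):
    # In-place scatter-write: each new token's K/V row lands in the paged
    # slot assigned by the scheduler, mutating the persistent cache buffers.
    # This is exactly what an allocate-fresh-outputs API cannot express.
    for i, slot in enumerate(slot_mapping):
        block, pos = divmod(slot, block_size)
        k_cache[block, pos] = key[i]
        v_cache[block, pos] = value[i]

# Caches use the unified [num_blocks, block_size, num_kv_heads, head_size] layout.
k_cache = np.zeros((4, 2, 1, 8), dtype=np.float32)
v_cache = np.zeros_like(k_cache)
key = np.ones((3, 1, 8), dtype=np.float32)
value = 2 * np.ones((3, 1, 8), dtype=np.float32)
reshape_and_cache(key, value, k_cache, v_cache, slot_mapping=[0, 1, 5], block_size=2)
assert k_cache[2, 1].sum() == 8.0  # slot 5 -> block 2, pos 1
```

Only the three mapped slots are touched; the rest of the cache keeps earlier tokens intact, which is the property that makes paging cheap per decode step.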
## Test

5/5 passed. The model output is identical to the HF kernel version before this change.
## Benchmark (M1 Pro, 32GB RAM, macOS 26.1)

[Benchmark table residue: rows compared the mlx_lm baseline, the HF kernel on `main`, and this PR's inline paged kernel; the numbers did not survive extraction.]

## Usage

Paged KV cache on. Baseline: the default configuration (paged KV cache off, mlx_lm).