
[Paged KV] Inline metal kernel. deprecate hf pytorch kernel#136

Merged
LxYuan0420 merged 15 commits into vllm-project:main from WindChimeRan:page_kernel
Mar 9, 2026

Conversation

@WindChimeRan
Collaborator

@WindChimeRan WindChimeRan commented Mar 5, 2026

(benchmark comparison chart: bench_comparison)

Overview

This PR eliminates the external kernels-community/paged-attention HuggingFace dependency by vendoring the Metal shader files directly and replacing the PyTorch MPS dispatch layer with a nanobind C++ extension that dispatches through MLX's own Metal command encoder.

This removes the MLX↔PyTorch MPS fence synchronization bubble.

Key architectural change:

  • original: mlx.array → PyTorch MPS bridge → HF kernel → PyTorch MPS bridge → mlx.array
  • this PR: mlx.array → native Metal dispatch → mlx.array

Fully zero-copy, with a single Metal command queue.

FAQ:

  • Q: Why can't mlx.core.fast.metal_kernel be used here? It would be much simpler.

  • A: mlx.core.fast.metal_kernel always allocates fresh output arrays and provides inputs as read-only. Paged KV
    cache requires in-place, scatter-write mutation of a persistent buffer (via reshape_and_cache) and zero-copy reads from that same buffer (via paged_attention_v1). If we copied the entire cache in and out on every token, it would negate the whole purpose of paging. The C++ nanobind bridge is the minimum complexity needed to get direct Metal buffer access through MLX's internal command encoder.
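To make the access pattern concrete, here is a minimal NumPy sketch of the scatter-write semantics (hypothetical shapes and a simplified 4D cache layout for illustration only — not the kernel's actual signature or tiling):

```python
import numpy as np

# Illustrative sizes only.
num_blocks, block_size, num_kv_heads, head_size = 8, 4, 2, 8
key_cache = np.zeros((num_blocks, num_kv_heads, head_size, block_size), dtype=np.float32)

def reshape_and_cache_sketch(key, slot_mapping, key_cache):
    """Scatter each new token's key into its assigned cache slot, in place."""
    for token_idx, slot in enumerate(slot_mapping):
        block_idx, block_offset = divmod(slot, block_size)
        # In-place mutation of a persistent buffer -- exactly what
        # metal_kernel's fresh-output model cannot express.
        key_cache[block_idx, :, :, block_offset] = key[token_idx]

key = np.random.rand(3, num_kv_heads, head_size).astype(np.float32)
slot_mapping = [5, 6, 17]  # flat slot id = block_idx * block_size + block_offset
reshape_and_cache_sketch(key, slot_mapping, key_cache)
```

The real kernel performs the same scatter directly on a Metal buffer obtained through MLX's command encoder, with no copies in or out.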

  • Q: What's the difference between kernel and kernel_v1?

  • A: kernel is a drop-in replacement for the HuggingFace kernel, originally vendored from an older version of mistral.rs. This PR's end-to-end deterministic tests confirm no regression from the replacement. kernel_v1 is the latest Metal kernel from the mistral.rs repo — more mature, with preliminary scaffolding for variable-length kernels and gpt-oss sink attention support. That said, neither will persist beyond this PR. Both are slated for deprecation once we introduce first-class variable-length kernel support, which is a prerequisite for continuous batching, chunked prefill, and MQA Scorer speculative decoding.

Test

python -m pytest tests/test_paged_deterministic.py -v -s

All 5 tests pass. The model output is identical to the HF kernel version.

Benchmark (M1 Pro, 32 GB RAM, macOS 26.1)

Usage

Page KV Cache On

VLLM_METAL_USE_PAGED_ATTENTION=1 VLLM_METAL_MEMORY_FRACTION=0.3 vllm serve Qwen/Qwen3-0.6B --max-model-len 2048
vllm bench serve --backend vllm --model Qwen/Qwen3-0.6B \
    --endpoint /v1/completions \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
    --num-prompts 100 \
    --request-rate 10 \
    --max-concurrency 32

Baseline: default (paged KV cache off, mlx_lm path)

vllm serve Qwen/Qwen3-0.6B --max-model-len 2048
mlx_lm baseline
============ Serving Benchmark Result ============
Successful requests:                     100
Failed requests:                         0
Maximum request concurrency:             32
Request rate configured (RPS):           10.00
Benchmark duration (s):                  345.33
Total input tokens:                      23260
Total generated tokens:                  22061
Request throughput (req/s):              0.29
Output token throughput (tok/s):         63.88
Peak output token throughput (tok/s):    96.00
Peak concurrent requests:                35.00
Total token throughput (tok/s):          131.24
---------------Time to First Token----------------
Mean TTFT (ms):                          69638.28
Median TTFT (ms):                        80145.51
P99 TTFT (ms):                           111926.69
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          156.87
Median TPOT (ms):                        92.12
P99 TPOT (ms):                           1007.68
---------------Inter-token Latency----------------
Mean ITL (ms):                           92.78
Median ITL (ms):                         68.17
P99 ITL (ms):                            298.20
==================================================
hf kernel main
============ Serving Benchmark Result ============
Successful requests:                     100
Failed requests:                         0
Maximum request concurrency:             32
Request rate configured (RPS):           10.00
Benchmark duration (s):                  336.21
Total input tokens:                      23260
Total generated tokens:                  22061
Request throughput (req/s):              0.30
Output token throughput (tok/s):         65.62
Peak output token throughput (tok/s):    128.00
Peak concurrent requests:                35.00
Total token throughput (tok/s):          134.80
---------------Time to First Token----------------
Mean TTFT (ms):                          1861.20
Median TTFT (ms):                        1019.14
P99 TTFT (ms):                           6470.41
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          365.51
Median TPOT (ms):                        349.28
P99 TPOT (ms):                           844.43
---------------Inter-token Latency----------------
Mean ITL (ms):                           340.26
Median ITL (ms):                         312.74
P99 ITL (ms):                            811.23
==================================================
This PR (Inline paged_kernel)
============ Serving Benchmark Result ============
Successful requests:                     100
Failed requests:                         0
Maximum request concurrency:             32
Request rate configured (RPS):           10.00
Benchmark duration (s):                  92.14
Total input tokens:                      23260
Total generated tokens:                  22061
Request throughput (req/s):              1.09
Output token throughput (tok/s):         239.43
Peak output token throughput (tok/s):    358.00
Peak concurrent requests:                35.00
Total token throughput (tok/s):          491.87
---------------Time to First Token----------------
Mean TTFT (ms):                          844.34
Median TTFT (ms):                        434.30
P99 TTFT (ms):                           2657.34
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          120.13
Median TPOT (ms):                        107.88
P99 TPOT (ms):                           522.26
---------------Inter-token Latency----------------
Mean ITL (ms):                           99.98
Median ITL (ms):                         86.99
P99 ITL (ms):                            471.89
==================================================

Signed-off-by: ran <hzz5361@psu.edu>
@WindChimeRan WindChimeRan marked this pull request as ready for review March 6, 2026 04:32
@WindChimeRan
Collaborator Author

@LxYuan0420 request for review.

  • I'm not very familiar with the cpp + JIT stuff. I hope it works on your machine out-of-the-box.
  • Ideally, in the future, maybe we can release the compiled Metal kernel binary and install everything in one line together with the wheel (see "Bump vllm to 0.16.0 and install from wheel" #134).

LxYuan0420 added a commit that referenced this pull request Mar 6, 2026
## Summary
- Add `test_paged_deterministic.py`: 5-prompt smoke test using vLLM
offline inference (temp=0, greedy) against hardcoded golden token IDs
from Qwen3-0.6B
- Golden values generated on `main` from both MLX inline cache and HF
paged KV cache paths
- Add `tools/gen_golden.py` helper to regenerate golden values

## Motivation
Prerequisite for the native Metal kernel PR (#136). After inlining the
vendored Metal shaders, paged attention output must remain identical to
the current HF kernel baseline. This test anchors that.

## Test
- `python -m pytest tests/test_paged_deterministic.py -v -s` (paged path
by default)
- Passes on `main` with HF kernel: 5/5

## Relevant Issue & PR

* Issue #119 
* PR #136: this inline Metal kernel needs to either pass this test or explain any possible non-determinism from the kernel.

upstream batch invariant feature
* blog:
https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/
* main feature: vllm-project/vllm#27433
* vllm upstream batch-invariant feature is only compatible with H/B-series NVIDIA GPUs; A100 does not work. See my experiment results: https://github.com/WindChimeRan/spec_deterministic
* community work: vllm-project/vllm#30018

Batch invariant is hardware & kernel dependent. Supporting this feature
is non-trivial on metal.


Output example: (screenshot attached in the original PR)

---------

Signed-off-by: ran <hzz5361@psu.edu>
Signed-off-by: Yuan Lik Xun <lxyuan0420@gmail.com>
Co-authored-by: Yuan Lik Xun <lxyuan0420@gmail.com>
Signed-off-by: ran <hzz5361@psu.edu>
@ricky-chaoju
Contributor

It works on my machine, but I had to manually install nanobind first. Until pre-compiled binaries are bundled in the wheel, should nanobind be added to the pyproject.toml dependencies?

@LxYuan0420
Collaborator

> @LxYuan0420 request for review.
>
>   • I'm not very familiar with the cpp + JIT stuff. I hope it works on your machine out-of-the-box.
>   • Ideally, in the future, maybe we can release the compiled Metal kernel binary and install everything in one line together with the wheel (see "Bump vllm to 0.16.0 and install from wheel" #134).

Thanks for the context. Totally understandable.

Given this PR introduces a new native C++/JIT execution path, I would like a bit more time to complete review and local validation before final approval.

Could you please share before/after perf numbers (main vs this PR, same machine/config)?

The 5-pass deterministic test is helpful for correctness smoke testing, but I would also like performance evidence for this backend swap.

Signed-off-by: ran <hzz5361@psu.edu>
Resolve README.md conflict: keep both upstream memory settings docs
and our acknowledgements section.

Signed-off-by: ran <hzz5361@psu.edu>
@WindChimeRan
Collaborator Author

@ricky-chaoju Thanks for testing! pyproject.toml updated.

@WindChimeRan
Collaborator Author

@LxYuan0420

Benchmarks are now in the PR description: TL;DR ~3.6x throughput improvement over the HF kernel path.

Worth noting that the scope here is fairly contained: it's a backend swap behind the existing VLLM_METAL_USE_PAGED_ATTENTION=1 toggle (off by default), and the deterministic tests confirm identical output.

The kernels themselves are temporary. They'll be replaced once we have first-class variable-length support (might be my next PR).

Happy to address any concerns, but wanted to flag that this is low-risk to land and iterate on.

@LxYuan0420
Collaborator

> @LxYuan0420
>
> Benchmarks are now in the PR description: TL;DR ~3.6x throughput improvement over the HF kernel path.
>
> Worth noting that the scope here is fairly contained: it's a backend swap behind the existing VLLM_METAL_USE_PAGED_ATTENTION=1 toggle (off by default), and the deterministic tests confirm identical output.
>
> The kernels themselves are temporary. They'll be replaced once we have first-class variable-length support (might be my next PR).
>
> Happy to address any concerns, but wanted to flag that this is low-risk to land and iterate on.

This is great.


@LxYuan0420 LxYuan0420 left a comment


Thanks for driving this forward. A few minor changes are needed before we merge this.

Signed-off-by: ran <hzz5361@psu.edu>
@WindChimeRan WindChimeRan requested a review from LxYuan0420 March 8, 2026 14:35
@WindChimeRan
Copy link
Collaborator Author

@LxYuan0420 Fixed. Will clean up the kernel files and try the var-len flash attention kernel in the next PR.


@LxYuan0420 LxYuan0420 left a comment


LGTM.

Thanks for the quick update.

@LxYuan0420 LxYuan0420 merged commit 8f60f3e into vllm-project:main Mar 9, 2026
5 checks passed
LxYuan0420 pushed a commit that referenced this pull request Mar 9, 2026
## Simplify paged KV cache to unified 4D layout 

**The goal is to align with vllm varlen kernel's KV cache layout**

Replace the x-tiled key cache layout (`[num_blocks, kv_heads,
head_dim//x, block_size, x]`) and separate value layout (`[num_blocks,
kv_heads, head_dim, block_size]`) with a single unified layout for both
K and V:

```
[num_blocks, block_size, num_kv_heads, head_size]
```
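For reference, the flat offset into this row-major layout can be sketched in plain Python (an illustrative helper with made-up sizes, not code from this PR):

```python
def kv_cache_offset(block_idx, block_offset, head_idx, dim,
                    block_size, num_kv_heads, head_size):
    """Row-major flat offset into [num_blocks, block_size, num_kv_heads, head_size]."""
    return ((block_idx * block_size + block_offset) * num_kv_heads
            + head_idx) * head_size + dim

# E.g. with block_size=4, num_kv_heads=3, head_size=8:
# kv_cache_offset(1, 2, 1, 3, 4, 3, 8) -> 155
```

Using one formula for both K and V is what lets `reshape_and_cache.metal` and `paged_ops.cpp` drop the separate x-tiled stride math.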

### Changes

- `cache.py` — Both K and V use the same simple 4D layout
- `reshape_and_cache.metal` — Simplified from 5D/4D indexing to
row-major 4D
- `pagedattention.metal` — K access simplified (no x-tiling), V
accumulation rewritten for contiguous head_size
- `gather_kv_cache.metal` — Updated to unified layout
- `paged_ops.cpp` — Updated stride calculations

### Testing

- `test_paged_deterministic.py` — **All 5 prompts pass** (golden token
match, `max_num_seqs=1`)
- `test_metal_kernel_paged.py::test_greedy_output_matches` — **Now
passes** (was xfail before)
- `test_metal_kernel_paged.py::test_batched_decode_matches` — Flipped to
xfail. Root cause is B=1 vs B=2 GEMM floating-point sensitivity in MLX
(not a kernel bug). The deterministic test is the more robust
correctness check and it passes cleanly.

No regression in algorithm correctness.


### Benchmark

Run the same benchmark script as #136.

Results:
```
============ Serving Benchmark Result ============
Successful requests:                     100
Failed requests:                         0
Maximum request concurrency:             32
Request rate configured (RPS):           10.00
Benchmark duration (s):                  103.34
Total input tokens:                      23260
Total generated tokens:                  22061
Request throughput (req/s):              0.97
Output token throughput (tok/s):         213.49
Peak output token throughput (tok/s):    325.00
Peak concurrent requests:                35.00
Total token throughput (tok/s):          438.58
---------------Time to First Token----------------
Mean TTFT (ms):                          2001.54
Median TTFT (ms):                        422.31
P99 TTFT (ms):                           8293.96
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          130.97
Median TPOT (ms):                        116.35
P99 TPOT (ms):                           590.88
---------------Inter-token Latency----------------
Mean ITL (ms):                           107.96
Median ITL (ms):                         96.17
P99 ITL (ms):                            420.17
==================================================
```



**Benchmark impact:** This refactoring adds ~8 ms to mean ITL (108 ms vs the 100 ms baseline), ~1.2 s to mean TTFT (2002 ms vs 844 ms), and reduces output throughput by ~26 tok/s (213 vs 239). This is the expected cost of aligning with upstream vLLM's architecture as the foundation for continuous batching.

---------

Signed-off-by: ran <hzz5361@psu.edu>
LxYuan0420 added a commit that referenced this pull request Mar 11, 2026
This PR is:
- To remove a stale `xfail` on `test_greedy_output_matches` that was
originally added for issue #119.
- To align test expectation with current `main` behavior after
paged-path fixes already merged.
- To keep parity tracking accurate while leaving batched behavior to its
own tracking path.

## Context

Issue #119 reported token mismatch parity failures between:
- standard MLX KV cache path, and
- Metal paged-attention path.

Since then, two key fixes landed:
- #125 corrected paged KV cache dtype inference/fallback behavior and KV
cache size accounting used by paged memory/block calculations.
- #136 replaced the HF/PyTorch kernel-bridge path with native MLX +
inline Metal JIT dispatch (`get_ops`/nanobind), removing cross-framework
bridge behavior from paged execution.

With those changes, the old greedy mismatch from #119 no longer
reproduces on `main`, so the greedy `xfail` is stale.

## Verification

```bash
pytest -q tests/test_metal_kernel_paged.py::TestMetalKernelPagedVsStandard::test_greedy_output_matches -s
pytest -m slow -q tests/test_metal_kernel_paged.py
```

Signed-off-by: Yuan Lik Xun <lxyuan0420@gmail.com>
LxYuan0420 pushed a commit that referenced this pull request Mar 17, 2026
## Summary

- Add `find_seq_idx` binary search to the v2 Metal kernel so each
threadgroup discovers its sequence from a flat `cu_seqlens_q` array,
enabling variable-length queries (prefill + decode in one launch)
- **This PR does not take actual effect in production.** Production still uses `mx.sdpa` for prefill and this PR's v2 kernel for decode; with certain parameters frozen, kernels_v2 is identical to the previous v1.
- Passes all triangle tests; safe to move forward to stage 3 continuous batching.
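The `find_seq_idx` lookup can be sketched in plain Python (a hypothetical reimplementation of the idea, not the Metal source): binary-search the exclusive prefix-sum array `cu_seqlens_q` for the sequence that owns a flat query position.

```python
def find_seq_idx(cu_seqlens_q, token_pos):
    """Return the index of the sequence owning flat query position token_pos.

    cu_seqlens_q is the exclusive prefix sum of query lengths,
    e.g. lengths [3, 1, 5] -> [0, 3, 4, 9].
    """
    lo, hi = 0, len(cu_seqlens_q) - 1
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if cu_seqlens_q[mid] <= token_pos:
            lo = mid  # token_pos is at or past this sequence's start
        else:
            hi = mid
    return lo

# Example: lengths [3, 1, 5] -> positions 0-2 belong to seq 0,
# position 3 to seq 1, positions 4-8 to seq 2.
```

In the kernel, each threadgroup runs this over the device-side `cu_seqlens_q` array, so prefill and decode queries of different lengths can share one launch.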

Notes:
- Vendored features from upstream vllm: sliding window support and soft capping in the v2 kernel
- Update production decode path to match the new function signature (default params, no behavior change)
- **These features are NOT TESTED in end-to-end production usage; they are expected to be bound to specific models, such as early versions of the Mistral models.**

## Triangle Test Status

```
        ref (pure-MLX naive)
       /         \
  edge 1        edge 3
     /             \
   v1  ── edge 2 ── v2
```

- **Edge 1** (v1 == ref): 6 pass (unchanged)
- **Edge 2** (v2 == v1): 6 pass (unchanged)
- **Edge 3** (v2 == ref): **24 pass** (was 3 pass + 21 xfail)
  - varlen (q_len > 1): now passing
  - sliding window (128): now passing
  - soft capping (50.0): now passing

**Before:** 15 passed + 21 xfail → **After:** 36 passed + 0 xfail


## What's NOT in this PR

The kernel now supports unified prefill+decode, but production still
uses the split path (MLX SDPA for prefill, v2 kernel for decode). Wiring
`metal_unified_attention()` into `model_runner.py` is a follow-up.

## Numeric Stability


```
python -m pytest tests/test_paged_deterministic.py -v -s
```

* Before this PR: 5/6 match the mlx_lm path
* After this PR: 6/6 match the mlx_lm path

However, I don't want to change the test for now; its result will flip back and forth in the following PRs.


## Benchmark

Run the same benchmark script as #136.

<details>

<summary>This PR: </summary>


```
============ Serving Benchmark Result ============
Successful requests:                     100
Failed requests:                         0
Maximum request concurrency:             32
Request rate configured (RPS):           10.00
Benchmark duration (s):                  107.24
Total input tokens:                      23260
Total generated tokens:                  22061
Request throughput (req/s):              0.93
Output token throughput (tok/s):         205.71
Peak output token throughput (tok/s):    319.00
Peak concurrent requests:                35.00
Total token throughput (tok/s):          422.60
---------------Time to First Token----------------
Mean TTFT (ms):                          593.33
Median TTFT (ms):                        386.44
P99 TTFT (ms):                           2147.57
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          134.30
Median TPOT (ms):                        127.79
P99 TPOT (ms):                           477.35
---------------Inter-token Latency----------------
Mean ITL (ms):                           117.58
Median ITL (ms):                         104.05
P99 ITL (ms):                            473.33
==================================================
```

</details>

<details>

<summary>before this PR:</summary>

```
============ Serving Benchmark Result ============
Successful requests:                     100
Failed requests:                         0
Maximum request concurrency:             32
Request rate configured (RPS):           10.00
Benchmark duration (s):                  106.42
Total input tokens:                      23260
Total generated tokens:                  22061
Request throughput (req/s):              0.94
Output token throughput (tok/s):         207.30
Peak output token throughput (tok/s):    320.00
Peak concurrent requests:                35.00
Total token throughput (tok/s):          425.87
---------------Time to First Token----------------
Mean TTFT (ms):                          982.74
Median TTFT (ms):                        452.35
P99 TTFT (ms):                           3030.71
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          132.63
Median TPOT (ms):                        124.33
P99 TPOT (ms):                           442.78
---------------Inter-token Latency----------------
Mean ITL (ms):                           115.19
Median ITL (ms):                         101.69
P99 ITL (ms):                            440.37
==================================================
```

</details>

This PR has no effect on performance. It paves the way for continuous batching.

## Possible Limitations

* The binary search is translated from the Triton kernel, but it may not be necessary: Triton uses it to avoid a CPU-GPU data copy, while we are on unified memory, so we could prebuild the reverse map instead. For this data range, O(log n) performs the same as O(1) while taking less space.
* Didn't check with partitioning on or off.

---------

Signed-off-by: ran <hzz5361@psu.edu>