
[varlen Kernel] Extend paged attention v2 to varlen [4/n]#166

Merged
LxYuan0420 merged 6 commits into vllm-project:main from WindChimeRan:varlen_kernel_final
Mar 17, 2026

Conversation

@WindChimeRan (Collaborator) commented Mar 16, 2026

Summary

  • Add a find_seq_idx binary search to the v2 Metal kernel so each threadgroup discovers its sequence from a flat cu_seqlens_q array, enabling variable-length queries (prefill + decode in one launch)
  • This PR does not take effect in production yet. Production still uses mx.sdpa for prefill and this PR's v2 kernel for decode; with some parameters frozen, the new v2 kernel is identical to the previous v1.
  • Passes all triangle tests; safe to move forward to stage 3 (continuous batching).
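As an illustration of the lookup described in the first bullet, here is a pure-Python model of a find_seq_idx-style binary search over cu_seqlens_q (a sketch only; the actual kernel is Metal, and this mirrors the logic, not the source):

```python
def find_seq_idx(cu_seqlens_q, q_idx):
    """Return the largest i such that cu_seqlens_q[i] <= q_idx,
    i.e. the sequence that query row q_idx belongs to."""
    lo, hi = 0, len(cu_seqlens_q) - 1
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if cu_seqlens_q[mid] <= q_idx:
            lo = mid    # q_idx is at or past the start of sequence mid
        else:
            hi = mid - 1
    return lo

# Example: three sequences with q_lens 4 (prefill), 1 (decode), 2.
cu = [0, 4, 5, 7]
assert [find_seq_idx(cu, q) for q in range(7)] == [0, 0, 0, 0, 1, 2, 2]
```

In the kernel each threadgroup runs this search once on its own row index, which is what lets prefill and decode sequences share one launch.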

Notes:

  • Vendored features from upstream vLLM: sliding window support and soft capping in the v2 kernel
  • Updated the production decode path to match the new function signature (default params, no behavior change)
  • These features are NOT TESTED in end-to-end production usage; they are expected to be bound to specific models, such as early Mistral variants.
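For reference, the two vendored score transforms can be sketched in scalar Python as follows (hypothetical function and parameter names; the kernel applies the same math per attention logit):

```python
import math

def apply_softcap_and_window(score, q_pos, k_pos, window, softcap):
    """Sketch of the two vendored transforms: soft capping squashes
    logits through tanh; sliding window masks keys more than
    `window` positions behind the query. 0 disables either feature."""
    if softcap > 0:
        score = softcap * math.tanh(score / softcap)  # bounded in (-softcap, softcap)
    if window > 0 and q_pos - k_pos >= window:
        score = float("-inf")                         # key outside the window
    return score
```

With softcap=50.0 a raw logit of 100.0 is squashed below 50.0, and with window=128 any key 128 or more positions behind the query is masked out, matching the triangle-test configurations below.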

Triangle Test Status

        ref (pure-MLX naive)
       /         \
  edge 1        edge 3
     /             \
   v1  ── edge 2 ── v2
  • Edge 1 (v1 == ref): 6 pass (unchanged)
  • Edge 2 (v2 == v1): 6 pass (unchanged)
  • Edge 3 (v2 == ref): 24 pass (was 3 pass + 21 xfail)
    • varlen (q_len > 1): now passing
    • sliding window (128): now passing
    • soft capping (50.0): now passing

Before: 15 passed + 21 xfail → After: 36 passed + 0 xfail

What's NOT in this PR

The kernel now supports unified prefill+decode, but production still uses the split path (MLX SDPA for prefill, v2 kernel for decode). Wiring metal_unified_attention() into model_runner.py is a follow-up.

Numeric Stability

python -m pytest tests/test_paged_deterministic.py -v -s
  • Before this PR: 5/6 match mlx_lm path
  • After this PR: 6/6 match mlx_lm path

However, I don't want to change the test for now; the result will flip back and forth in the following PRs.

Benchmark

Ran the same benchmark script as #136.

This PR:
============ Serving Benchmark Result ============
Successful requests:                     100
Failed requests:                         0
Maximum request concurrency:             32
Request rate configured (RPS):           10.00
Benchmark duration (s):                  107.24
Total input tokens:                      23260
Total generated tokens:                  22061
Request throughput (req/s):              0.93
Output token throughput (tok/s):         205.71
Peak output token throughput (tok/s):    319.00
Peak concurrent requests:                35.00
Total token throughput (tok/s):          422.60
---------------Time to First Token----------------
Mean TTFT (ms):                          593.33
Median TTFT (ms):                        386.44
P99 TTFT (ms):                           2147.57
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          134.30
Median TPOT (ms):                        127.79
P99 TPOT (ms):                           477.35
---------------Inter-token Latency----------------
Mean ITL (ms):                           117.58
Median ITL (ms):                         104.05
P99 ITL (ms):                            473.33
==================================================
before this PR:
============ Serving Benchmark Result ============
Successful requests:                     100
Failed requests:                         0
Maximum request concurrency:             32
Request rate configured (RPS):           10.00
Benchmark duration (s):                  106.42
Total input tokens:                      23260
Total generated tokens:                  22061
Request throughput (req/s):              0.94
Output token throughput (tok/s):         207.30
Peak output token throughput (tok/s):    320.00
Peak concurrent requests:                35.00
Total token throughput (tok/s):          425.87
---------------Time to First Token----------------
Mean TTFT (ms):                          982.74
Median TTFT (ms):                        452.35
P99 TTFT (ms):                           3030.71
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          132.63
Median TPOT (ms):                        124.33
P99 TPOT (ms):                           442.78
---------------Inter-token Latency----------------
Mean ITL (ms):                           115.19
Median ITL (ms):                         101.69
P99 ITL (ms):                            440.37
==================================================

This PR has no effect on performance; it paves the way for continuous batching.

Possible Limitation

  • The binary search is translated from the Triton kernel, but it may not be necessary. Triton uses it to avoid a CPU-GPU copy, whereas we are on unified memory, so we could prebuild the reverse map instead. For our data ranges, though, O(log n) performs the same as O(1) while using less space.
  • Didn't check behavior with the partition on vs. off.
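The prebuilt reverse map mentioned above would look something like this (a hypothetical alternative, not code from this PR): expand cu_seqlens_q into a per-row table once on the host, trading O(total_tokens) space for an O(1) lookup per threadgroup.

```python
def build_seq_map(cu_seqlens_q):
    """Expand the cu_seqlens_q prefix sums into a per-query-row
    sequence-index table, so seq_map[q_idx] replaces the binary search."""
    seq_map = []
    for seq_idx in range(len(cu_seqlens_q) - 1):
        run = cu_seqlens_q[seq_idx + 1] - cu_seqlens_q[seq_idx]
        seq_map.extend([seq_idx] * run)
    return seq_map

cu = [0, 4, 5, 7]
assert build_seq_map(cu) == [0, 0, 0, 0, 1, 2, 2]
```

Since the map grows with total tokens while cu_seqlens_q grows only with batch size, the binary search stays the better default here.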

Signed-off-by: ran <hzz5361@psu.edu>
@WindChimeRan WindChimeRan changed the title varlen prototype [varlen Kernel] Extend paged attention to varlen [4/n] Mar 16, 2026
@WindChimeRan WindChimeRan changed the title [varlen Kernel] Extend paged attention to varlen [4/n] [varlen Kernel] Extend paged attention v2 to varlen [4/n] Mar 16, 2026
@WindChimeRan WindChimeRan marked this pull request as ready for review March 17, 2026 05:22
@LxYuan0420 LxYuan0420 merged commit ed2cefa into vllm-project:main Mar 17, 2026
5 checks passed