[Continuous Batching] Packed prefill with cu_seq_lens for multiple requests#151

Merged
LxYuan0420 merged 6 commits into vllm-project:main from ricky-chaoju:feat/packed-prefill-cu-seqlen
Mar 11, 2026

Conversation

Contributor

@ricky-chaoju ricky-chaoju commented Mar 9, 2026

Summary

  • Pack multiple complete prefill requests into a single forward pass using cumulative sequence lengths (cu_seq_lens) as separators
  • Build block-diagonal causal mask so packed requests don't cross-attend
  • Apply per-request RoPE position reset within packed sequences
  • Fall back to single-request path when only 1 prefill is scheduled

This is the first step of Stage 3 (Chunked Prefilling & Continuous Batching) in the roadmap (#148).
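The packing scheme above can be sketched as follows (NumPy is used for illustration and the names are not the PR's; the actual implementation is MLX-based):

```python
import numpy as np

def block_diagonal_causal_mask(cu_seqlens):
    """Additive attention mask for packed prefill: token i may attend to
    token j only if both belong to the same request and j <= i."""
    total_len = cu_seqlens[-1]
    # Start with everything blocked, including cross-request attention.
    mask = np.full((total_len, total_len), -np.inf)
    for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        n = e - s
        # Open a causal (lower-triangular) window for this request only.
        mask[s:e, s:e] = np.triu(np.full((n, n), -np.inf), k=1)
    return mask

# Two requests of lengths 2 and 3 packed with cu_seqlens = [0, 2, 5].
m = block_diagonal_causal_mask([0, 2, 5])
```

The mask is additive: it is added to the attention logits before softmax, so `-inf` entries zero out both cross-request and future-token attention in one pass.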

Changes

  • paged_attention_common.py: Add cu_seq_lens field to PagedAttentionContext; add prepare_prefill_packed()
  • paged_attention.py: Add _build_packed_causal_mask(); update _metal_kernel_prefill_attention() for packed mode
  • model_runner.py: Add _prefill_packed_paged(); restructure Phase 1 to collect complete prefills and batch them
  • test_paged_attention.py: Add tests for packed slot_mapping, cu_seq_lens, and block-diagonal causal mask isolation

Test

  • Unit tests: 9/9 pass (pytest tests/test_paged_attention.py)
  • E2E: Qwen3-0.6B with VLLM_METAL_USE_PAGED_ATTENTION=1, 3 concurrent requests all return correct responses
  • Verified packed path triggered via debug log: scheduler batches 2 complete requests into packed prefill

@ricky-chaoju ricky-chaoju force-pushed the feat/packed-prefill-cu-seqlen branch 2 times, most recently from a564827 to 8af0edf on March 9, 2026 15:14
Signed-off-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com>
Collaborator

@WindChimeRan WindChimeRan left a comment

Nice first step toward Stage 3. A few principles I'd like us to follow for this PR:

Backward plumbing, not forward optimization. This PR's real value is the cu_seq_lens plumbing that will wire into the varlen kernel later. The dense mask approach is correct but temporary. Let's optimize for code clarity, not runtime performance at this stage.

Separate scaffolding from durable plumbing. _build_packed_causal_mask and the mask branching in the attention wrapper are throwaway code that gets deleted when the varlen kernel lands. Please move them out of paged_attention.py (e.g. into packed_prefill_compat.py) and add clear # SCAFFOLDING: remove when varlen kernel is ready markers. The durable pieces should be easy to distinguish at a glance.

Unified path, no fallback. The elif len(paged_complete) == 1 branch adds ~20 lines of duplicated state management for a negligible gain. The packed path already handles 1 request correctly. One code path = less to review, less to break, less to swap out later. I saw a similar pattern elsewhere as well.

Align with upstream vLLM interfaces. Since this plumbing will eventually connect to a real kernel, let's match upstream naming and conventions where possible (e.g. cu_seqlens without underscore, matching flash_attn_varlen_func). Less friction when integrating varlen kernel or doing future upstream alignment.

Please benchmark the performance gain or loss; see #136 for how.
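For context, the upstream `cu_seqlens` shape (as consumed by `flash_attn_varlen_func`) is a prefix sum with a leading zero; a minimal sketch of deriving it from per-request lengths:

```python
from itertools import accumulate

def make_cu_seqlens(seq_lens):
    # Upstream convention: leading 0, then running totals. Entry i is the
    # start offset of request i; the last entry is the total token count.
    return [0] + list(accumulate(seq_lens))
```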

            if self._rust_state_manager is not None:
                self._rust_state_manager.add_request(rid, list(tids) + [nt])
        elif len(paged_complete) == 1:
            # Single request: no packing overhead
Collaborator

Could we always go through the packed path and drop this elif branch? I don't want to overcomplicate the codebase. In the future, once we have a varlen kernel, one request vs. multiple requests will be the same.

…ify path

Signed-off-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com>
Signed-off-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com>
@ricky-chaoju
Contributor Author

Thanks for the review! All four points addressed:

  1. cu_seq_lens → cu_seqlens — aligned with upstream naming
  2. Scaffolding separated — build_packed_causal_mask + apply_packed_rope moved to packed_prefill_compat.py with # SCAFFOLDING: remove when varlen kernel is ready
  3. Unified path — dropped elif len(paged_complete) == 1, always goes through packed path
  4. Benchmark — no regression (TTFT -8.8%, throughput +0.7%, ITL -1.4%). Real gains expected when varlen kernel consumes cu_seqlens natively.

Also fixed a dtype bug: packed mask was float32 but model runs bfloat16.

@WindChimeRan
Collaborator

@ricky-chaoju this failed on the preemption test:

python tools/repro_block_exhaustion.py

could you please take a look?

[rank0]:   File "/Users/ran/workspace/vllm-metal/vllm_metal/metal_kernel_backend/paged_attention.py", line 338, in __call__
[rank0]:     return _metal_kernel_prefill_attention(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/Users/ran/workspace/vllm-metal/vllm_metal/metal_kernel_backend/paged_attention.py", line 153, in _metal_kernel_prefill_attention
[rank0]:     output = mx.fast.scaled_dot_product_attention(
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: ValueError: [scaled_dot_product_attention] Mask type must promote to output type bfloat16.

@ricky-chaoju
Contributor Author

Fixed in latest commit (59aed34). The packed causal mask was float32 but the model runs bfloat16 — added .astype(queries.dtype).

repro_block_exhaustion.py passes locally but KV cache usage only reached 31% (my machine has 192GB RAM), so preemption likely didn't trigger. What fraction triggers exhaustion on your setup? I can rerun with that value.

@WindChimeRan
Collaborator

@ricky-chaoju my bad. I was not using the latest code. It works now.

Collaborator

@WindChimeRan WindChimeRan left a comment

LGTM, do you want to take another look? @LxYuan0420

    """
    # Start with all-blocked, then open causal windows per request
    mask = mx.full((total_len, total_len), -mx.inf)
    for i in range(len(cu_seqlens) - 1):
Collaborator

I understand this is temporary scaffolding. Since it is active in the production path now, please add a packed-length safety cap (or split batches) to avoid O(N²) mask blowups until the varlen kernel replaces it. (Packing currently happens unconditionally in the model_runner complete-prefill path.)

Contributor Author

done

Signed-off-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com>
# Causal SDPA
# SCAFFOLDING: dense mask — remove when varlen kernel is ready.
if ctx.cu_seqlens is not None and len(ctx.cu_seqlens) > 2:
    attn_mask = build_packed_causal_mask(ctx.cu_seqlens, L).astype(queries.dtype)
Collaborator
Small follow-up suggestion: could we pass dtype into build_packed_causal_mask and construct the mask directly in that dtype? Right now we allocate default float32 and then cast to queries.dtype, which can increase transient peak memory in the packed path.
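A sketch of that suggestion (NumPy is used here for illustration; the real builder is MLX-based, and mx.full likewise accepts a dtype argument):

```python
import numpy as np

def build_packed_causal_mask(cu_seqlens, total_len, dtype=np.float32):
    # Allocate directly in the target dtype: no float32 intermediate, no cast.
    mask = np.full((total_len, total_len), -np.inf, dtype=dtype)
    for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        n = e - s
        # Open the causal window for this request only.
        mask[s:e, s:e] = np.triu(np.full((n, n), -np.inf, dtype=dtype), k=1)
    return mask

# Callers would pass queries.dtype so the mask matches the model precision.
m = build_packed_causal_mask([0, 2, 5], 5, dtype=np.float16)
```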

        assert ctx.offsets == [7]


class TestPackedCausalMask:
Collaborator
Could we add one focused unit test for apply_packed_rope as well? We already validate packed mask isolation here; a small RoPE-reset test would lock in the other key packed-prefill contract.
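Such a test might look like this sketch (hypothetical: packed_positions is a stand-in for the position computation that apply_packed_rope is expected to perform, not the PR's actual API):

```python
import numpy as np

def packed_positions(cu_seqlens):
    # Stand-in for the positions apply_packed_rope should use:
    # each packed request restarts its RoPE positions at 0.
    return np.concatenate(
        [np.arange(e - s) for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:])]
    )

def test_rope_positions_reset_per_request():
    # Requests of lengths 2 and 3: positions restart at each boundary.
    pos = packed_positions([0, 2, 5])
    assert pos.tolist() == [0, 1, 0, 1, 2]
```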

Signed-off-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com>
Comment on lines +1953 to +1998
# SCAFFOLDING: cap packed length to avoid O(N²) mask blowup.
# Remove when varlen kernel is ready.
max_packed_tokens = 4096
if paged_complete:
    # Split into batches that fit within the packed-length cap.
    batches: list[
        list[
            tuple[
                int,
                str,
                list[int],
                SamplingParams,
                list[int],
                torch.Generator | None,
            ]
        ]
    ] = [[]]
    batch_tokens = 0
    for entry in paged_complete:
        entry_tokens = len(entry[2])  # token_ids
        if batch_tokens + entry_tokens > max_packed_tokens and batches[-1]:
            batches.append([])
            batch_tokens = 0
        batches[-1].append(entry)
        batch_tokens += entry_tokens

    for batch in batches:
        pack_input = [
            (rid, tids, sp, bids, gen, None)
            for _, rid, tids, sp, bids, gen in batch
        ]
        next_tokens = self._prefill_packed_paged(pack_input)
        for i, (idx, rid, tids, sp, bids, gen) in enumerate(batch):
            nt = next_tokens[i]
            sampled_tokens[idx] = [nt]
            self._request_states[rid] = RequestState(
                token_ids=list(tids) + [nt],
                prompt_len=len(tids),
                cache=[],
                sampling_params=sp,
                generator=gen,
                generated_tokens=1,
                block_ids=bids,
            )
            if self._rust_state_manager is not None:
                self._rust_state_manager.add_request(rid, list(tids) + [nt])
Collaborator

This commit makes this path harder to reason about.
Could you share your reasoning/trade-offs for this long inlined block so we can align before more changes?

Contributor Author

The intent is to split packed requests into batches capped at 4096 tokens, bounding the O(N²) dense mask allocation until the varlen kernel replaces it. I'll extract the batching + state-update loop into a _run_packed_prefill helper to keep the main path readable.

Collaborator

Good, that is the right intent.

Please keep the follow-up scoped to:

  • extract only batching + dispatch + state writeback into _run_packed_prefill (no behavior change)
  • make 4096 a named constant with a short rationale comment
  • add one focused test for split-batch mapping correctness
  • keep this as scaffolding-only and avoid adding more policy logic here.
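The split-batch mapping test could be sketched like this (illustrative: split_into_batches is a hypothetical stand-in for the batching step inside _run_packed_prefill, with entries simplified to (idx, rid, token_ids) tuples):

```python
def split_into_batches(entries, max_packed_tokens=4096):
    # entries: (idx, rid, token_ids, ...) tuples; token_ids at position 2.
    batches, batch_tokens = [[]], 0
    for entry in entries:
        n = len(entry[2])
        # Start a new batch when the cap would be exceeded, but never
        # leave a batch empty (an oversized request still gets one batch).
        if batch_tokens + n > max_packed_tokens and batches[-1]:
            batches.append([])
            batch_tokens = 0
        batches[-1].append(entry)
        batch_tokens += n
    return batches

def test_split_preserves_order_and_cap():
    entries = [(i, f"req{i}", list(range(3000))) for i in range(3)]
    batches = split_into_batches(entries, max_packed_tokens=4096)
    # 3000 + 3000 > 4096, so each request lands in its own batch.
    assert [len(b) for b in batches] == [1, 1, 1]
    # Order and identity are preserved across the split.
    assert [e for b in batches for e in b] == entries
```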

Contributor Author

@ricky-chaoju ricky-chaoju Mar 10, 2026

Done in 074c865

  1. Extracted _run_packed_prefill helper — batching + dispatch + state writeback
  2. MAX_PACKED_PREFILL_TOKENS = 4096 as module-level named constant with rationale
  3. Added TestBatchSplitting with 4 focused tests (single batch, split, oversized request, entry preservation)
  4. No behavior change, scaffolding only

Signed-off-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com>
@ricky-chaoju ricky-chaoju requested a review from LxYuan0420 March 11, 2026 01:04
Collaborator

@LxYuan0420 LxYuan0420 left a comment

One follow-up: please validate cu_seqlens in the packed mask builder (starts at 0, strictly increasing, last entry equals total_len) so an invalid context fails fast.
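That validation might look like this sketch (illustrative; the function name and error messages are assumptions, not the PR's code):

```python
def validate_cu_seqlens(cu_seqlens, total_len):
    # Fail fast on malformed packing metadata instead of building a bad mask.
    if not cu_seqlens or cu_seqlens[0] != 0:
        raise ValueError("cu_seqlens must start at 0")
    if any(b <= a for a, b in zip(cu_seqlens, cu_seqlens[1:])):
        raise ValueError("cu_seqlens must be strictly increasing")
    if cu_seqlens[-1] != total_len:
        raise ValueError(
            f"cu_seqlens[-1] ({cu_seqlens[-1]}) must equal total_len ({total_len})"
        )
```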

@LxYuan0420 LxYuan0420 merged commit fac064f into vllm-project:main Mar 11, 2026
5 checks passed
@ricky-chaoju ricky-chaoju deleted the feat/packed-prefill-cu-seqlen branch March 11, 2026 02:49
renqHIT added a commit to renqHIT/vllm-metal that referenced this pull request Mar 13, 2026
The xfail on `test_batched_decode_matches` was added for issue vllm-project#119
(B=2 batched GEMM producing different floats than B=1). The test now
passes consistently on main after recent paged kernel fixes (vllm-project#146, vllm-project#151).
This follows PR vllm-project#149 which removed the same stale xfail for the
greedy single-request test.

Signed-off-by: Qiang <qren@integralads.com>
@WindChimeRan WindChimeRan changed the title [Paged KV] Packed prefill with cu_seq_lens for multiple requests [Continuous Batching] Packed prefill with cu_seq_lens for multiple requests Mar 18, 2026