[Continuous Batching] Packed prefill with cu_seq_lens for multiple requests#151
Signed-off-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com>
Nice first step toward Stage 3. A few principles I'd like us to follow for this PR:
Backward plumbing, not forward optimization. This PR's real value is the cu_seq_lens plumbing that will wire into the varlen kernel later. The dense mask approach is correct but temporary. Let's optimize for code clarity, not runtime performance at this stage.
Separate scaffolding from durable plumbing. _build_packed_causal_mask and the mask branching in the attention wrapper are throwaway code that gets deleted when the varlen kernel lands. Please move them out of paged_attention.py (e.g. into packed_prefill_compat.py) and add clear # SCAFFOLDING: remove when varlen kernel is ready markers. The durable pieces should be easy to distinguish at a glance.
Unified path, no fallback. The elif len(paged_complete) == 1 branch adds ~20 lines of duplicated state management for a negligible gain. The packed path already handles 1 request correctly. One code path = less to review, less to break, less to swap out later. I saw a similar pattern elsewhere as well.
Align with upstream vLLM interfaces. Since this plumbing will eventually connect to a real kernel, let's match upstream naming and conventions where possible (e.g. cu_seqlens without underscore, matching flash_attn_varlen_func). Less friction when integrating varlen kernel or doing future upstream alignment.
Benchmark the performance gain or loss; see #136 for how.
vllm_metal/v1/model_runner.py (outdated diff):

```python
if self._rust_state_manager is not None:
    self._rust_state_manager.add_request(rid, list(tids) + [nt])
elif len(paged_complete) == 1:
    # Single request: no packing overhead
```
Could we always go through the packed path and drop this elif branch? I don't want to overcomplicate the codebase. In the future, once we have a varlen kernel, one request vs. multiple requests will take the same path.
@ricky-chaoju failed on the preemption test `python tools/repro_block_exhaustion.py` — could you please take a look?

Fixed in latest commit (59aed34). The packed causal mask was float32 but the model runs bfloat16 — added .astype(queries.dtype). repro_block_exhaustion.py passes locally but KV cache usage only reached 31% (my machine has 192GB RAM), so preemption likely didn't trigger. What fraction triggers exhaustion on your setup? I can rerun with that value.

@ricky-chaoju my bad. I was not using the latest code. It works now.
LGTM, do you want to take another look? @LxYuan0420
```python
    """
    # Start with all-blocked, then open causal windows per request
    mask = mx.full((total_len, total_len), -mx.inf)
    for i in range(len(cu_seqlens) - 1):
```
I understand this is temporary scaffolding. Since it is active in production path now, please add a packed-length safety cap (or split batches) to avoid O(N²) mask blowups until varlen kernel replaces it. (Packing currently happens unconditionally in model_runner complete-prefill path.)
```python
# Causal SDPA
# SCAFFOLDING: dense mask — remove when varlen kernel is ready.
if ctx.cu_seqlens is not None and len(ctx.cu_seqlens) > 2:
    attn_mask = build_packed_causal_mask(ctx.cu_seqlens, L).astype(queries.dtype)
```
Small follow-up suggestion: could we pass dtype into build_packed_causal_mask and construct the mask directly in that dtype? Right now we allocate default float32 and then cast to queries.dtype, which can increase transient peak memory in the packed path.
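A minimal NumPy sketch of what the suggested dtype parameter could look like (a hypothetical standalone version for illustration; the real helper builds an MLX array, and the function name is taken from the diff):

```python
import numpy as np

def build_packed_causal_mask(cu_seqlens, total_len, dtype=np.float32):
    # Start fully blocked; open a causal window per request so tokens
    # attend only to earlier tokens of their own request.
    mask = np.full((total_len, total_len), -np.inf, dtype=dtype)
    for i in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        seg = end - start
        # -inf strictly above the diagonal, 0 on and below it.
        mask[start:end, start:end] = np.triu(
            np.full((seg, seg), -np.inf, dtype=dtype), k=1
        )
    return mask

# Two packed requests of lengths 2 and 3, built directly in half precision:
m = build_packed_causal_mask([0, 2, 5], 5, dtype=np.float16)
```

Constructing directly in the target dtype avoids the transient float32 allocation that the current allocate-then-cast pattern pays in the packed path.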
```python
assert ctx.offsets == [7]


class TestPackedCausalMask:
```
Could we add one focused unit test for apply_packed_rope as well? We already validate packed mask isolation here; a small RoPE-reset test would lock in the other key packed-prefill contract.
vllm_metal/v1/model_runner.py (outdated diff):
```python
# SCAFFOLDING: cap packed length to avoid O(N²) mask blowup.
# Remove when varlen kernel is ready.
max_packed_tokens = 4096
if paged_complete:
    pack_input = [
        (rid, tids, sp, bids, gen, None)
        for _, rid, tids, sp, bids, gen in paged_complete
    ]
    next_tokens = self._prefill_packed_paged(pack_input)
    for i, (idx, rid, tids, sp, bids, gen) in enumerate(paged_complete):
        nt = next_tokens[i]
        sampled_tokens[idx] = [nt]
        self._request_states[rid] = RequestState(
            token_ids=list(tids) + [nt],
            prompt_len=len(tids),
            cache=[],
            sampling_params=sp,
            generator=gen,
            generated_tokens=1,
            block_ids=bids,
        )
        if self._rust_state_manager is not None:
            self._rust_state_manager.add_request(rid, list(tids) + [nt])
# Split into batches that fit within the packed-length cap.
batches: list[
    list[
        tuple[
            int,
            str,
            list[int],
            SamplingParams,
            list[int],
            torch.Generator | None,
        ]
    ]
] = [[]]
batch_tokens = 0
for entry in paged_complete:
    entry_tokens = len(entry[2])  # token_ids
    if batch_tokens + entry_tokens > max_packed_tokens and batches[-1]:
        batches.append([])
        batch_tokens = 0
    batches[-1].append(entry)
    batch_tokens += entry_tokens

for batch in batches:
    pack_input = [
        (rid, tids, sp, bids, gen, None)
        for _, rid, tids, sp, bids, gen in batch
    ]
    next_tokens = self._prefill_packed_paged(pack_input)
    for i, (idx, rid, tids, sp, bids, gen) in enumerate(batch):
        nt = next_tokens[i]
        sampled_tokens[idx] = [nt]
        self._request_states[rid] = RequestState(
            token_ids=list(tids) + [nt],
            prompt_len=len(tids),
            cache=[],
            sampling_params=sp,
            generator=gen,
            generated_tokens=1,
            block_ids=bids,
        )
        if self._rust_state_manager is not None:
            self._rust_state_manager.add_request(rid, list(tids) + [nt])
```
This commit makes this path harder to reason about.
Could you share your reasoning/trade-offs for this long inlined block so we can align before more changes?
The intent is to split packed requests into batches capped at 4096 tokens, bounding the O(N²) dense mask allocation until the varlen kernel replaces it. I'll extract the batching + state-update loop into a _run_packed_prefill helper to keep the main path readable.
Good, that is the right intent.
Please keep the follow-up scoped to:
- extract only batching + dispatch + state writeback into _run_packed_prefill (no behavior change)
- make 4096 a named constant with a short rationale comment
- add one focused test for split-batch mapping correctness
- keep this as scaffolding-only and avoid adding more policy logic here.
Done in 074c865:
- Extracted `_run_packed_prefill` helper — batching + dispatch + state writeback
- `MAX_PACKED_PREFILL_TOKENS = 4096` as module-level named constant with rationale
- Added `TestBatchSplitting` with 4 focused tests (single batch, split, oversized request, entry preservation)
- No behavior change, scaffolding only
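For reference, the greedy split described above can be sketched as a standalone function (a hypothetical simplification; the tuple layout with token_ids at index 2 is assumed from the diff):

```python
MAX_PACKED_PREFILL_TOKENS = 4096  # cap dense-mask size until the varlen kernel lands

def split_into_batches(entries, max_tokens=MAX_PACKED_PREFILL_TOKENS):
    # entries: (idx, rid, token_ids, ...) tuples; token_ids sit at index 2.
    # Greedy packing: start a new batch only when adding the next entry
    # would exceed the cap AND the current batch is non-empty, so an
    # oversized single request still gets a batch of its own.
    batches, batch_tokens = [[]], 0
    for entry in entries:
        n = len(entry[2])
        if batch_tokens + n > max_tokens and batches[-1]:
            batches.append([])
            batch_tokens = 0
        batches[-1].append(entry)
        batch_tokens += n
    return batches

# With a tiny cap of 5, prompt lengths 3, 3, 7 split into three batches;
# the oversized third request is not dropped, just isolated.
demo = [(0, "a", [1] * 3), (1, "b", [2] * 3), (2, "c", [3] * 7)]
batches = split_into_batches(demo, max_tokens=5)
```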
The xfail on `test_batched_decode_matches` was added for issue vllm-project#119 (B=2 batched GEMM producing different floats than B=1). The test now passes consistently on main after recent paged kernel fixes (vllm-project#146, vllm-project#151). This follows PR vllm-project#149 which removed the same stale xfail for the greedy single-request test. Signed-off-by: Qiang <qren@integralads.com>


Summary

Pack multiple complete prefill requests into a single forward pass, using cumulative sequence lengths (`cu_seq_lens`) as separators. This is the first step of Stage 3 (Chunked Prefilling & Continuous Batching) in the roadmap (#148).

Changes
- `paged_attention_common.py`: Add `cu_seq_lens` field to `PagedAttentionContext`; add `prepare_prefill_packed()`
- `paged_attention.py`: Add `_build_packed_causal_mask()`; update `_metal_kernel_prefill_attention()` for packed mode
- `model_runner.py`: Add `_prefill_packed_paged()`; restructure Phase 1 to collect complete prefills and batch them
- `test_paged_attention.py`: Add tests for packed slot_mapping, cu_seq_lens, and block-diagonal causal mask isolation

Test
- All tests pass (`pytest tests/test_paged_attention.py`)
- With `VLLM_METAL_USE_PAGED_ATTENTION=1`, 3 concurrent requests all return correct responses
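The `cu_seq_lens` boundaries are just a cumulative sum over per-request prompt lengths; a minimal sketch (the helper name here is hypothetical, for illustration only):

```python
from itertools import accumulate

def make_cu_seqlens(seq_lens):
    # [0, l0, l0+l1, ...]: request i occupies packed rows
    # cu[i]..cu[i+1] in the flattened token dimension.
    return [0] + list(accumulate(seq_lens))

cu = make_cu_seqlens([3, 2, 4])  # three packed prefills
```

The final entry doubles as the total packed length, which is what the dense mask builder receives as `total_len`.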