diff --git a/tests/test_paged_attention.py b/tests/test_paged_attention.py index 3da0c62e..bee3ec76 100644 --- a/tests/test_paged_attention.py +++ b/tests/test_paged_attention.py @@ -13,6 +13,7 @@ get_context, prepare_decode, prepare_prefill, + prepare_prefill_packed, ) @@ -47,6 +48,30 @@ def test_prepare_prefill_slot_mapping(self): assert ctx.is_prefill assert ctx.slot_mapping == [40, 41, 42, 43, 44] + def test_prepare_prefill_packed_slot_mapping(self): + # Two requests: 3 tokens in block 10, 2 tokens in block 20 + requests = [([10], 3), ([20], 2)] + prepare_prefill_packed(requests, block_size=4) + ctx = get_context() + + assert ctx is not None + assert ctx.is_prefill + # Request 0: block 10, slots 40,41,42 + # Request 1: block 20, slots 80,81 + assert ctx.slot_mapping == [40, 41, 42, 80, 81] + assert ctx.cu_seqlens == [0, 3, 5] + + def test_prepare_prefill_packed_single_request(self): + # Single request should still produce valid cu_seqlens + requests = [([5, 6], 5)] + prepare_prefill_packed(requests, block_size=4) + ctx = get_context() + + assert ctx is not None + assert ctx.cu_seqlens == [0, 5] + # block 5: slots 20,21,22,23; block 6: slot 24 + assert ctx.slot_mapping == [20, 21, 22, 23, 24] + def test_prepare_decode(self): # Arrange requests = [([5, 6], 7)] @@ -61,3 +86,137 @@ def test_prepare_decode(self): assert ctx.slot_mapping == [27] assert ctx.context_lens == [8] assert ctx.offsets == [7] + + +class TestPackedCausalMask: + """Tests for the block-diagonal causal mask used in packed prefill.""" + + def test_single_sequence(self): + from vllm_metal.metal_kernel_backend.packed_prefill_compat import ( + build_packed_causal_mask, + ) + + mask = build_packed_causal_mask([0, 3], total_len=3) + # Standard causal: lower-triangular (0) with upper-triangular (-inf) + assert mask.shape == (1, 1, 3, 3) + m = mask[0, 0] + # Diagonal and below should be 0 + assert m[0, 0].item() == 0.0 + assert m[1, 0].item() == 0.0 + assert m[1, 1].item() == 0.0 + # Above diagonal should be -inf + assert m[0, 1].item() == float("-inf") + assert m[0, 2].item() == float("-inf") + + def test_two_sequences_isolation(self): + from vllm_metal.metal_kernel_backend.packed_prefill_compat import ( + build_packed_causal_mask, + ) + + # Two sequences: [0,2) and [2,5) + mask = build_packed_causal_mask([0, 2, 5], total_len=5) + m = mask[0, 0] + # Seq 0 tokens should not attend to seq 1 tokens + assert m[0, 2].item() == float("-inf") + assert m[0, 3].item() == float("-inf") + assert m[1, 2].item() == float("-inf") + # Seq 1 tokens should not attend to seq 0 tokens + assert m[2, 0].item() == float("-inf") + assert m[2, 1].item() == float("-inf") + assert m[3, 0].item() == float("-inf") + # Within seq 1: causal + assert m[2, 2].item() == 0.0 + assert m[3, 2].item() == 0.0 + assert m[3, 3].item() == 0.0 + assert m[2, 3].item() == float("-inf") + + def test_mask_dtype_matches_request(self): + import mlx.core as mx + + from vllm_metal.metal_kernel_backend.packed_prefill_compat import ( + build_packed_causal_mask, + ) + + mask = build_packed_causal_mask([0, 3], total_len=3, dtype=mx.bfloat16) + assert mask.dtype == mx.bfloat16 + + +class TestPackedRoPE: + """Tests for per-request RoPE position reset in packed prefill.""" + + def test_positions_reset_per_request(self): + """Each packed request's RoPE should start from position 0.""" + import mlx.core as mx + + from vllm_metal.metal_kernel_backend.packed_prefill_compat import ( + apply_packed_rope, + ) + + # Minimal RoPE stub: returns input + offset so we can verify offsets + class FakeRoPE: + def rope(self, x, offset=0): + return x + offset + + module = FakeRoPE() + # Two requests packed: 3 tokens + 2 tokens + # Shape: (1, heads=1, total_len=5, head_dim=2) + q = mx.zeros((1, 1, 5, 2)) + k = mx.zeros((1, 1, 5, 2)) + cu_seqlens = [0, 3, 5] + + q_out, k_out = apply_packed_rope(module, q, k, cu_seqlens) + + # All values should be 0 (offset=0 for every request) + assert q_out.shape == (1, 1, 5, 2) + assert mx.allclose(q_out, mx.zeros_like(q_out)).item() + assert mx.allclose(k_out, mx.zeros_like(k_out)).item() + + +class TestBatchSplitting: + """Tests for the packed-prefill batch splitting logic.""" + + @staticmethod + def _split_batches( + entries: list[tuple[int, int]], + max_tokens: int, + ) -> list[list[tuple[int, int]]]: + """Reproduce the batch splitting algorithm from _run_packed_prefill. + + entries: list of (index, num_tokens) for simplicity. + """ + batches: list[list[tuple[int, int]]] = [[]] + batch_tokens = 0 + for entry in entries: + entry_tokens = entry[1] + if batch_tokens + entry_tokens > max_tokens and batches[-1]: + batches.append([]) + batch_tokens = 0 + batches[-1].append(entry) + batch_tokens += entry_tokens + return batches + + def test_all_fit_single_batch(self): + entries = [(0, 100), (1, 200), (2, 300)] + batches = self._split_batches(entries, max_tokens=4096) + assert len(batches) == 1 + assert batches[0] == entries + + def test_split_into_two_batches(self): + entries = [(0, 3000), (1, 2000)] + batches = self._split_batches(entries, max_tokens=4096) + assert len(batches) == 2 + assert batches[0] == [(0, 3000)] + assert batches[1] == [(1, 2000)] + + def test_single_large_request_not_dropped(self): + # A request exceeding the cap should still go into its own batch + entries = [(0, 5000)] + batches = self._split_batches(entries, max_tokens=4096) + assert len(batches) == 1 + assert batches[0] == [(0, 5000)] + + def test_preserves_all_entries(self): + entries = [(i, 1000) for i in range(10)] + batches = self._split_batches(entries, max_tokens=4096) + flat = [e for batch in batches for e in batch] + assert flat == entries diff --git a/vllm_metal/metal_kernel_backend/packed_prefill_compat.py b/vllm_metal/metal_kernel_backend/packed_prefill_compat.py new file mode 100644 index 00000000..abd92816 --- /dev/null +++ b/vllm_metal/metal_kernel_backend/packed_prefill_compat.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +# SCAFFOLDING: remove when varlen kernel is ready. +# +# Dense causal mask and per-request RoPE helpers for packed prefill. +# These are temporary — the varlen kernel will handle masking and +# position encoding natively, making this module unnecessary. + +from __future__ import annotations + +import mlx.core as mx + + +def build_packed_causal_mask( + cu_seqlens: list[int], + total_len: int, + dtype: mx.Dtype = mx.float32, +) -> mx.array: + """Build a block-diagonal causal mask for packed prefill. + + Each request only attends to its own tokens (causally). Returns an + additive mask of shape ``(1, 1, total_len, total_len)`` with 0 for + allowed positions and ``-inf`` for blocked positions, suitable for + ``mx.fast.scaled_dot_product_attention``. + + Args: + dtype: Construct the mask directly in this dtype to avoid a + transient float32 allocation followed by a cast. + + SCAFFOLDING: remove when varlen kernel is ready. + """ + neg_inf = mx.array(-mx.inf, dtype=dtype) + # Start with all-blocked, then open causal windows per request + mask = mx.full((total_len, total_len), neg_inf) + for i in range(len(cu_seqlens) - 1): + start = cu_seqlens[i] + end = cu_seqlens[i + 1] + seq_len = end - start + # Causal mask for this request's tokens + causal = mx.triu(mx.full((seq_len, seq_len), neg_inf), k=1) + mask[start:end, start:end] = causal + return mask.reshape(1, 1, total_len, total_len) + + +def apply_packed_rope( + attn_module: object, + queries: mx.array, + keys: mx.array, + cu_seqlens: list[int], +) -> tuple[mx.array, mx.array]: + """Apply per-request RoPE with position reset for packed prefill. + + SCAFFOLDING: remove when varlen kernel is ready. + """ + q_parts = [] + k_parts = [] + for i in range(len(cu_seqlens) - 1): + start = cu_seqlens[i] + end = cu_seqlens[i + 1] + q_parts.append(attn_module.rope(queries[:, :, start:end, :], offset=0)) + k_parts.append(attn_module.rope(keys[:, :, start:end, :], offset=0)) + return mx.concatenate(q_parts, axis=2), mx.concatenate(k_parts, axis=2) diff --git a/vllm_metal/metal_kernel_backend/paged_attention.py b/vllm_metal/metal_kernel_backend/paged_attention.py index 4806bb60..43b6a58f 100644 --- a/vllm_metal/metal_kernel_backend/paged_attention.py +++ b/vllm_metal/metal_kernel_backend/paged_attention.py @@ -68,6 +68,10 @@ from vllm_metal.metal import get_ops from vllm_metal.metal_kernel_backend.cache import MetalPagedKVCache +from vllm_metal.metal_kernel_backend.packed_prefill_compat import ( + apply_packed_rope, + build_packed_causal_mask, +) from vllm_metal.paged_attention_common import ( PagedAttentionContext, find_layers_and_attr, @@ -89,25 +93,36 @@ def _metal_kernel_prefill_attention( ctx: PagedAttentionContext, offset_cache: Any, ) -> mx.array: - """Prefill: B=1, L=prompt_len. + """Prefill: B=1, L=prompt_len (single) or L=total_tokens (packed). Inline causal SDPA in MLX, then write K/V to paged cache via - ``reshape_and_cache``. + ``reshape_and_cache``. When ``ctx.cu_seqlens`` is set, builds a + block-diagonal causal mask so packed requests don't cross-attend. """ B, _, L, _ = queries.shape # noqa: N806 - # RoPE + # RoPE — per-request position reset for packed prefill if not hasattr(attn_module, "rope"): raise NotImplementedError( f"Attention module {type(attn_module).__name__} does not have a 'rope' " "attribute. Only RoPE-based models are supported by paged attention." ) - offset = offset_cache.offset if offset_cache is not None else 0 - queries = attn_module.rope(queries, offset=offset) - keys = attn_module.rope(keys, offset=offset) - # Causal SDPA (inline — K/V already in hand) - attn_mask = "causal" if L > 1 else None + # SCAFFOLDING: packed RoPE + mask — remove when varlen kernel is ready. + if ctx.cu_seqlens is not None: + queries, keys = apply_packed_rope(attn_module, queries, keys, ctx.cu_seqlens) + else: + offset = offset_cache.offset if offset_cache is not None else 0 + queries = attn_module.rope(queries, offset=offset) + keys = attn_module.rope(keys, offset=offset) + + # Causal SDPA + # SCAFFOLDING: dense mask — remove when varlen kernel is ready. + if ctx.cu_seqlens is not None and len(ctx.cu_seqlens) > 2: + attn_mask = build_packed_causal_mask(ctx.cu_seqlens, L, dtype=queries.dtype) + else: + attn_mask = "causal" if L > 1 else None + output = mx.fast.scaled_dot_product_attention( queries, keys, values, scale=attn_module.scale, mask=attn_mask ) diff --git a/vllm_metal/paged_attention_common.py b/vllm_metal/paged_attention_common.py index b8f2d484..615b7a2b 100644 --- a/vllm_metal/paged_attention_common.py +++ b/vllm_metal/paged_attention_common.py @@ -42,6 +42,10 @@ class PagedAttentionContext: block_tables: list[list[int]] = field(default_factory=list) context_lens: list[int] = field(default_factory=list) offsets: list[int] = field(default_factory=list) + # packed prefill fields — set when multiple requests are packed into + # a single forward pass. cu_seqlens is a cumulative sequence length + # array: [0, len0, len0+len1, ...] (length = num_requests + 1). + cu_seqlens: list[int] | None = None def set_context(ctx: PagedAttentionContext) -> None: @@ -164,6 +168,39 @@ def prepare_prefill( ) +def prepare_prefill_packed( + requests: list[tuple[list[int], int]], + block_size: int, +) -> None: + """Compute slot_mapping and cu_seqlens for packed prefill. + + Packs multiple prefill requests into a single forward pass. The + attention wrapper uses ``cu_seqlens`` to build a block-diagonal + causal mask so that each request only attends to its own tokens. + + Args: + requests: list of (block_ids, num_tokens) per request. + block_size: tokens per block. + """ + slot_mapping: list[int] = [] + cu_seqlens: list[int] = [0] + + for block_ids, num_tokens in requests: + for pos in range(num_tokens): + block_idx = block_ids[pos // block_size] + slot = block_idx * block_size + (pos % block_size) + slot_mapping.append(slot) + cu_seqlens.append(cu_seqlens[-1] + num_tokens) + + set_context( + PagedAttentionContext( + is_prefill=True, + slot_mapping=slot_mapping, + cu_seqlens=cu_seqlens, + ) + ) + + def prepare_decode( requests: list[tuple[list[int], int]], block_size: int, diff --git a/vllm_metal/v1/model_runner.py b/vllm_metal/v1/model_runner.py index d35bd31f..77449a93 100644 --- a/vllm_metal/v1/model_runner.py +++ b/vllm_metal/v1/model_runner.py @@ -54,6 +54,7 @@ clear_context, prepare_decode, prepare_prefill, + prepare_prefill_packed, ) from vllm_metal.pytorch_backend.tensor_bridge import mlx_to_torch from vllm_metal.stt.config import ( @@ -703,6 +704,12 @@ def decode( return tokens +# SCAFFOLDING: remove when varlen kernel is ready. +# Cap total packed-prefill tokens to bound the O(N²) dense causal mask. +# Batches exceeding this limit are split into multiple forward passes. +MAX_PACKED_PREFILL_TOKENS = 4096 + + class MetalModelRunner: """Model runner for MLX-based inference on Metal. @@ -1619,6 +1626,158 @@ def _prefill_single_request_paged( return next_token + def _prefill_packed_paged( + self, + pack_reqs: list[ + tuple[ + str, + list[int], + SamplingParams, + list[int], + torch.Generator | None, + int | None, + ] + ], + ) -> list[int]: + """Packed paged-attention prefill for multiple requests. + + Concatenates token_ids from all requests into a single forward + pass using ``cu_seqlens`` to build a block-diagonal causal mask. + This avoids the overhead of N separate forward passes. + + Args: + pack_reqs: list of + (req_id, token_ids, sampling_params, block_ids, + generator, prompt_len) tuples. + + Returns: + List of sampled next tokens, one per request. + """ + # Build packed input + all_token_ids: list[int] = [] + block_requests: list[tuple[list[int], int]] = [] + for _, token_ids, _, block_ids, _, _ in pack_reqs: + all_token_ids.extend(token_ids) + block_requests.append((block_ids, len(token_ids))) + + # Stash packed context (slot_mapping + cu_seqlens) + prepare_prefill_packed(block_requests, self._paged_block_size) + + offset_caches = [OffsetCache(0) for _ in range(self.num_layers)] + input_ids = mx.array([all_token_ids], dtype=mx.int32) + try: + model_output = self.model(input_ids, cache=offset_caches) + logits = self._extract_logits(model_output) + finally: + clear_context() + + # Extract per-request last-token logits and sample + cu_seqlens = [0] + for _, token_ids, _, _, _, _ in pack_reqs: + cu_seqlens.append(cu_seqlens[-1] + len(token_ids)) + + next_tokens: list[int] = [] + for i, ( + req_id, + token_ids, + sampling_params, + _, + generator, + prompt_len, + ) in enumerate(pack_reqs): + last_idx = cu_seqlens[i + 1] - 1 + last_logits = logits[:, last_idx : last_idx + 1, :] + + if prompt_len is None: + prompt_len = len(token_ids) + + is_greedy = sampling_params.temperature < 1e-5 + needs_advanced = ( + sampling_params.top_k > 0 + or sampling_params.top_p < 1.0 + or sampling_params.frequency_penalty != 0 + or sampling_params.presence_penalty != 0 + or sampling_params.repetition_penalty != 1.0 + ) + + if is_greedy and not needs_advanced: + next_token_mlx = _mlx_greedy_sample(last_logits[0]) + mx.eval(next_token_mlx) + next_token = int(next_token_mlx.item()) + else: + mx.eval(last_logits) + logits_torch = mlx_to_torch( + last_logits[0].astype(mx.float32), device=self.device + ) + generators = {} if generator is None else {0: generator} + metadata = self._make_sampling_metadata( + [sampling_params], + [token_ids[:prompt_len]], + [token_ids[prompt_len:]], + generators=generators, + ) + output = self._sampler.forward(logits_torch, metadata) + next_token = int(output.sampled_token_ids[0, 0].item()) + + self._paged_request_seq_lens[req_id] = len(token_ids) + next_tokens.append(next_token) + + return next_tokens + + def _run_packed_prefill( + self, + paged_complete: list[ + tuple[ + int, + str, + list[int], + SamplingParams, + list[int], + torch.Generator | None, + ] + ], + sampled_tokens: list[list[int]], + ) -> None: + """Batch, dispatch, and write back state for packed paged prefill. + + Splits *paged_complete* into batches that fit within + ``MAX_PACKED_PREFILL_TOKENS``, runs each batch through + ``_prefill_packed_paged``, and fills *sampled_tokens* in-place. + + SCAFFOLDING: batching removed when varlen kernel is ready. + """ + # Split into batches that fit within the packed-length cap. + batches: list[list[tuple]] = [[]] + batch_tokens = 0 + for entry in paged_complete: + entry_tokens = len(entry[2]) # token_ids + if batch_tokens + entry_tokens > MAX_PACKED_PREFILL_TOKENS and batches[-1]: + batches.append([]) + batch_tokens = 0 + batches[-1].append(entry) + batch_tokens += entry_tokens + + for batch in batches: + pack_input = [ + (rid, tids, sp, bids, gen, None) + for _, rid, tids, sp, bids, gen in batch + ] + next_tokens = self._prefill_packed_paged(pack_input) + for i, (idx, rid, tids, sp, bids, gen) in enumerate(batch): + nt = next_tokens[i] + sampled_tokens[idx] = [nt] + self._request_states[rid] = RequestState( + token_ids=list(tids) + [nt], + prompt_len=len(tids), + cache=[], + sampling_params=sp, + generator=gen, + generated_tokens=1, + block_ids=bids, + ) + if self._rust_state_manager is not None: + self._rust_state_manager.add_request(rid, list(tids) + [nt]) + def _batched_decode_paged( self, decode_reqs: list[tuple[str, RequestState]] ) -> list[int]: @@ -1757,71 +1916,84 @@ def execute_model( # === PHASE 1: Process new requests (prefill phase) === new_reqs = scheduler_output.scheduled_new_reqs + # First pass: handle intermediate chunks immediately, collect + # complete paged prefill requests for potential packing. + paged_complete: list[ + tuple[ + int, str, list[int], SamplingParams, list[int], torch.Generator | None + ] + ] = [] + for new_req in new_reqs: req_id = new_req.req_id token_ids = new_req.prompt_token_ids or [] sampling_params = new_req.sampling_params or SamplingParams() req_ids.append(req_id) - req_id_to_index[req_id] = len(req_ids) - 1 + output_idx = len(req_ids) - 1 + req_id_to_index[req_id] = output_idx - if token_ids: - generator = _create_request_generator(self.device, sampling_params) + if not token_ids: + sampled_tokens.append([0]) # Fallback + continue - if self._paged_kv_cache is not None: - # Paged attention path (Metal kernel) - sched_block_ids = list(new_req.block_ids[0]) - scheduled_tokens = scheduler_output.num_scheduled_tokens.get( - req_id, 0 - ) - computed_tokens = new_req.num_computed_tokens - prompt_len = len(token_ids) - if computed_tokens + scheduled_tokens < prompt_len: - # Intermediate chunk: sample then drop (async scheduler - # allocates no placeholder for intermediate chunks). - cur_len = computed_tokens + scheduled_tokens - _discarded = self._prefill_single_request_paged( - req_id, - token_ids[:cur_len], - sampling_params, - block_ids=sched_block_ids, - generator=generator, - ) - cache: list = [] - sampled_tokens.append([]) - self._request_states[req_id] = RequestState( - token_ids=list(token_ids), - prompt_len=prompt_len, - cache=cache, - sampling_params=sampling_params, - generator=generator, - generated_tokens=0, - block_ids=sched_block_ids, - ) - if self._rust_state_manager is not None: - self._rust_state_manager.add_request( - req_id, list(token_ids[:cur_len]) - ) - continue - # Prompt complete: generate first output token. - next_token = self._prefill_single_request_paged( + generator = _create_request_generator(self.device, sampling_params) + + if self._paged_kv_cache is not None: + # Paged attention path (Metal kernel) + sched_block_ids = list(new_req.block_ids[0]) + scheduled_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0) + computed_tokens = new_req.num_computed_tokens + prompt_len = len(token_ids) + + if computed_tokens + scheduled_tokens < prompt_len: + # Intermediate chunk: sample then drop (async scheduler + # allocates no placeholder for intermediate chunks). + cur_len = computed_tokens + scheduled_tokens + _discarded = self._prefill_single_request_paged( req_id, - token_ids, + token_ids[:cur_len], sampling_params, block_ids=sched_block_ids, generator=generator, ) - cache = [] # No per-request KV cache needed - else: - next_token, cache = self._prefill_single( + cache: list = [] + sampled_tokens.append([]) + self._request_states[req_id] = RequestState( + token_ids=list(token_ids), + prompt_len=prompt_len, + cache=cache, + sampling_params=sampling_params, + generator=generator, + generated_tokens=0, + block_ids=sched_block_ids, + ) + if self._rust_state_manager is not None: + self._rust_state_manager.add_request( + req_id, list(token_ids[:cur_len]) + ) + continue + + # Complete prefill — collect for packed processing + sampled_tokens.append([]) # placeholder, filled below + paged_complete.append( + ( + output_idx, req_id, token_ids, sampling_params, - generator=generator, + sched_block_ids, + generator, ) + ) + else: + next_token, cache = self._prefill_single( + req_id, + token_ids, + sampling_params, + generator=generator, + ) sampled_tokens.append([next_token]) - - # Store request state with cache for future decoding self._request_states[req_id] = RequestState( token_ids=list(token_ids) + [next_token], prompt_len=len(token_ids), @@ -1829,18 +2001,17 @@ def execute_model( sampling_params=sampling_params, generator=generator, generated_tokens=1, - block_ids=sched_block_ids - if self._paged_kv_cache is not None - else [], + block_ids=[], ) - - # Register with Rust state manager if available if self._rust_state_manager is not None: self._rust_state_manager.add_request( req_id, list(token_ids) + [next_token] ) - else: - sampled_tokens.append([0]) # Fallback + + # Process collected complete paged prefill requests via unified + # packed path (handles 1 or more requests). + if paged_complete: + self._run_packed_prefill(paged_complete, sampled_tokens) # === PHASE 2: Process cached requests (TRUE batched decode) === cached_reqs = scheduler_output.scheduled_cached_reqs