Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 159 additions & 0 deletions tests/test_paged_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
get_context,
prepare_decode,
prepare_prefill,
prepare_prefill_packed,
)


Expand Down Expand Up @@ -47,6 +48,30 @@ def test_prepare_prefill_slot_mapping(self):
assert ctx.is_prefill
assert ctx.slot_mapping == [40, 41, 42, 43, 44]

def test_prepare_prefill_packed_slot_mapping(self):
    """Packed prefill maps each request's tokens into its own blocks."""
    # Request 0: 3 tokens in block 10 -> slots 40, 41, 42
    # Request 1: 2 tokens in block 20 -> slots 80, 81
    packed_requests = [([10], 3), ([20], 2)]
    prepare_prefill_packed(packed_requests, block_size=4)

    ctx = get_context()
    assert ctx is not None
    assert ctx.is_prefill
    assert ctx.slot_mapping == [40, 41, 42, 80, 81]
    assert ctx.cu_seqlens == [0, 3, 5]

def test_prepare_prefill_packed_single_request(self):
    """A lone packed request still yields a well-formed cu_seqlens."""
    prepare_prefill_packed([([5, 6], 5)], block_size=4)

    ctx = get_context()
    assert ctx is not None
    assert ctx.cu_seqlens == [0, 5]
    # Block 5 holds slots 20-23; the fifth token spills into block 6 (slot 24).
    assert ctx.slot_mapping == [20, 21, 22, 23, 24]

def test_prepare_decode(self):
# Arrange
requests = [([5, 6], 7)]
Expand All @@ -61,3 +86,137 @@ def test_prepare_decode(self):
assert ctx.slot_mapping == [27]
assert ctx.context_lens == [8]
assert ctx.offsets == [7]


class TestPackedCausalMask:
    """Tests for the block-diagonal causal mask used in packed prefill."""

    # Additive-mask "blocked" value expected from build_packed_causal_mask.
    NEG_INF = float("-inf")

    def test_single_sequence(self):
        from vllm_metal.metal_kernel_backend.packed_prefill_compat import (
            build_packed_causal_mask,
        )

        mask = build_packed_causal_mask([0, 3], total_len=3)
        # Standard causal: lower triangle open (0), upper triangle blocked.
        assert mask.shape == (1, 1, 3, 3)
        m = mask[0, 0]
        for row, col in ((0, 0), (1, 0), (1, 1)):
            assert m[row, col].item() == 0.0
        for row, col in ((0, 1), (0, 2)):
            assert m[row, col].item() == self.NEG_INF

    def test_two_sequences_isolation(self):
        from vllm_metal.metal_kernel_backend.packed_prefill_compat import (
            build_packed_causal_mask,
        )

        # Two packed sequences occupying [0, 2) and [2, 5).
        m = build_packed_causal_mask([0, 2, 5], total_len=5)[0, 0]
        # Cross-sequence attention must be blocked in both directions.
        for row, col in ((0, 2), (0, 3), (1, 2), (2, 0), (2, 1), (3, 0)):
            assert m[row, col].item() == self.NEG_INF
        # Within sequence 1 the mask remains causal.
        for row, col in ((2, 2), (3, 2), (3, 3)):
            assert m[row, col].item() == 0.0
        assert m[2, 3].item() == self.NEG_INF

    def test_mask_dtype_matches_request(self):
        import mlx.core as mx

        from vllm_metal.metal_kernel_backend.packed_prefill_compat import (
            build_packed_causal_mask,
        )

        mask = build_packed_causal_mask([0, 3], total_len=3, dtype=mx.bfloat16)
        assert mask.dtype == mx.bfloat16


class TestPackedRoPE:
    """Tests for per-request RoPE position reset in packed prefill."""

    def test_positions_reset_per_request(self):
        """Each packed request's RoPE should start from position 0."""
        import mlx.core as mx

        from vllm_metal.metal_kernel_backend.packed_prefill_compat import (
            apply_packed_rope,
        )

        class OffsetEchoRoPE:
            # Stub: returns input + offset, so any non-zero offset is visible.
            def rope(self, x, offset=0):
                return x + offset

        # Two packed requests (3 tokens + 2 tokens).
        # Shape: (batch=1, heads=1, total_len=5, head_dim=2)
        total_len = 5
        q = mx.zeros((1, 1, total_len, 2))
        k = mx.zeros((1, 1, total_len, 2))

        q_out, k_out = apply_packed_rope(OffsetEchoRoPE(), q, k, [0, 3, 5])

        # Offset is reset to 0 for every request, so outputs stay all-zero.
        assert q_out.shape == (1, 1, total_len, 2)
        assert mx.allclose(q_out, mx.zeros_like(q_out)).item()
        assert mx.allclose(k_out, mx.zeros_like(k_out)).item()


class TestBatchSplitting:
    """Tests for the packed-prefill batch splitting logic."""

    @staticmethod
    def _split_batches(
        entries: list[tuple[int, int]],
        max_tokens: int,
    ) -> list[list[tuple[int, int]]]:
        """Reproduce the batch splitting algorithm from _run_packed_prefill.

        entries: list of (index, num_tokens) for simplicity.
        """
        batches: list[list[tuple[int, int]]] = [[]]
        used = 0
        for entry in entries:
            tokens = entry[1]
            # Start a new batch only when this entry would overflow a
            # non-empty batch; an oversized entry still gets its own batch.
            if batches[-1] and used + tokens > max_tokens:
                batches.append([])
                used = 0
            batches[-1].append(entry)
            used += tokens
        return batches

    def test_all_fit_single_batch(self):
        entries = [(0, 100), (1, 200), (2, 300)]
        batches = self._split_batches(entries, max_tokens=4096)
        assert batches == [entries]

    def test_split_into_two_batches(self):
        batches = self._split_batches([(0, 3000), (1, 2000)], max_tokens=4096)
        assert batches == [[(0, 3000)], [(1, 2000)]]

    def test_single_large_request_not_dropped(self):
        # A request exceeding the cap should still go into its own batch.
        batches = self._split_batches([(0, 5000)], max_tokens=4096)
        assert batches == [[(0, 5000)]]

    def test_preserves_all_entries(self):
        entries = [(i, 1000) for i in range(10)]
        batches = self._split_batches(entries, max_tokens=4096)
        assert [e for batch in batches for e in batch] == entries
61 changes: 61 additions & 0 deletions vllm_metal/metal_kernel_backend/packed_prefill_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SCAFFOLDING: remove when varlen kernel is ready.
#
# Dense causal mask and per-request RoPE helpers for packed prefill.
# These are temporary — the varlen kernel will handle masking and
# position encoding natively, making this module unnecessary.

from __future__ import annotations

import mlx.core as mx


def build_packed_causal_mask(
    cu_seqlens: list[int],
    total_len: int,
    dtype: mx.Dtype = mx.float32,
) -> mx.array:
    """Build a block-diagonal causal mask for packed prefill.

    Each request only attends to its own tokens (causally). Returns an
    additive mask of shape ``(1, 1, total_len, total_len)`` with 0 for
    allowed positions and ``-inf`` for blocked positions, suitable for
    ``mx.fast.scaled_dot_product_attention``.

    Args:
        dtype: Construct the mask directly in this dtype to avoid a
            transient float32 allocation followed by a cast.

    SCAFFOLDING: remove when varlen kernel is ready.
    """
    blocked = mx.array(-mx.inf, dtype=dtype)
    # Everything starts blocked; each request opens its own causal window.
    mask = mx.full((total_len, total_len), blocked)
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        span = end - start
        # Strictly-upper-triangular -inf => causal within this request.
        mask[start:end, start:end] = mx.triu(
            mx.full((span, span), blocked), k=1
        )
    return mask.reshape(1, 1, total_len, total_len)


def apply_packed_rope(
    attn_module: object,
    queries: mx.array,
    keys: mx.array,
    cu_seqlens: list[int],
) -> tuple[mx.array, mx.array]:
    """Apply per-request RoPE with position reset for packed prefill.

    SCAFFOLDING: remove when varlen kernel is ready.
    """
    rotated_q: list[mx.array] = []
    rotated_k: list[mx.array] = []
    # Rotate each request's slice independently so positions restart at 0.
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        rotated_q.append(attn_module.rope(queries[:, :, start:end, :], offset=0))
        rotated_k.append(attn_module.rope(keys[:, :, start:end, :], offset=0))
    return mx.concatenate(rotated_q, axis=2), mx.concatenate(rotated_k, axis=2)
31 changes: 23 additions & 8 deletions vllm_metal/metal_kernel_backend/paged_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@

from vllm_metal.metal import get_ops
from vllm_metal.metal_kernel_backend.cache import MetalPagedKVCache
from vllm_metal.metal_kernel_backend.packed_prefill_compat import (
apply_packed_rope,
build_packed_causal_mask,
)
from vllm_metal.paged_attention_common import (
PagedAttentionContext,
find_layers_and_attr,
Expand All @@ -89,25 +93,36 @@ def _metal_kernel_prefill_attention(
ctx: PagedAttentionContext,
offset_cache: Any,
) -> mx.array:
"""Prefill: B=1, L=prompt_len.
"""Prefill: B=1, L=prompt_len (single) or L=total_tokens (packed).

Inline causal SDPA in MLX, then write K/V to paged cache via
``reshape_and_cache``.
``reshape_and_cache``. When ``ctx.cu_seqlens`` is set, builds a
block-diagonal causal mask so packed requests don't cross-attend.
"""
B, _, L, _ = queries.shape # noqa: N806

# RoPE
# RoPE — per-request position reset for packed prefill
if not hasattr(attn_module, "rope"):
raise NotImplementedError(
f"Attention module {type(attn_module).__name__} does not have a 'rope' "
"attribute. Only RoPE-based models are supported by paged attention."
)
offset = offset_cache.offset if offset_cache is not None else 0
queries = attn_module.rope(queries, offset=offset)
keys = attn_module.rope(keys, offset=offset)

# Causal SDPA (inline — K/V already in hand)
attn_mask = "causal" if L > 1 else None
# SCAFFOLDING: packed RoPE + mask — remove when varlen kernel is ready.
if ctx.cu_seqlens is not None:
queries, keys = apply_packed_rope(attn_module, queries, keys, ctx.cu_seqlens)
else:
offset = offset_cache.offset if offset_cache is not None else 0
queries = attn_module.rope(queries, offset=offset)
keys = attn_module.rope(keys, offset=offset)

# Causal SDPA
# SCAFFOLDING: dense mask — remove when varlen kernel is ready.
if ctx.cu_seqlens is not None and len(ctx.cu_seqlens) > 2:
attn_mask = build_packed_causal_mask(ctx.cu_seqlens, L, dtype=queries.dtype)
else:
attn_mask = "causal" if L > 1 else None

output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=attn_module.scale, mask=attn_mask
)
Expand Down
37 changes: 37 additions & 0 deletions vllm_metal/paged_attention_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ class PagedAttentionContext:
block_tables: list[list[int]] = field(default_factory=list)
context_lens: list[int] = field(default_factory=list)
offsets: list[int] = field(default_factory=list)
# packed prefill fields — set when multiple requests are packed into
# a single forward pass. cu_seqlens is a cumulative sequence length
# array: [0, len0, len0+len1, ...] (length = num_requests + 1).
cu_seqlens: list[int] | None = None


def set_context(ctx: PagedAttentionContext) -> None:
Expand Down Expand Up @@ -164,6 +168,39 @@ def prepare_prefill(
)


def prepare_prefill_packed(
    requests: list[tuple[list[int], int]],
    block_size: int,
) -> None:
    """Compute slot_mapping and cu_seqlens for packed prefill.

    Packs multiple prefill requests into a single forward pass. The
    attention wrapper uses ``cu_seqlens`` to build a block-diagonal
    causal mask so that each request only attends to its own tokens.

    Args:
        requests: list of (block_ids, num_tokens) per request.
        block_size: tokens per block.
    """
    slot_mapping: list[int] = []
    cu_seqlens: list[int] = [0]
    total_tokens = 0

    for block_ids, num_tokens in requests:
        # Map each token position to its flat slot in the paged cache.
        slot_mapping.extend(
            block_ids[pos // block_size] * block_size + pos % block_size
            for pos in range(num_tokens)
        )
        total_tokens += num_tokens
        cu_seqlens.append(total_tokens)

    set_context(
        PagedAttentionContext(
            is_prefill=True,
            slot_mapping=slot_mapping,
            cu_seqlens=cu_seqlens,
        )
    )


def prepare_decode(
requests: list[tuple[list[int], int]],
block_size: int,
Expand Down
Loading