[Continuous Batching] Packed prefill with cu_seq_lens for multiple requests #151
Changes from all commits
5dcb9b0
d8b1e2d
59aed34
ad31726
1274c63
074c865
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SCAFFOLDING: remove when varlen kernel is ready. | ||
| # | ||
| # Dense causal mask and per-request RoPE helpers for packed prefill. | ||
| # These are temporary — the varlen kernel will handle masking and | ||
| # position encoding natively, making this module unnecessary. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import mlx.core as mx | ||
|
|
||
|
|
||
| def build_packed_causal_mask( | ||
| cu_seqlens: list[int], | ||
| total_len: int, | ||
| dtype: mx.Dtype = mx.float32, | ||
| ) -> mx.array: | ||
| """Build a block-diagonal causal mask for packed prefill. | ||
|
|
||
| Each request only attends to its own tokens (causally). Returns an | ||
| additive mask of shape ``(1, 1, total_len, total_len)`` with 0 for | ||
| allowed positions and ``-inf`` for blocked positions, suitable for | ||
| ``mx.fast.scaled_dot_product_attention``. | ||
|
|
||
| Args: | ||
| dtype: Construct the mask directly in this dtype to avoid a | ||
| transient float32 allocation followed by a cast. | ||
|
|
||
| SCAFFOLDING: remove when varlen kernel is ready. | ||
| """ | ||
| neg_inf = mx.array(-mx.inf, dtype=dtype) | ||
| # Start with all-blocked, then open causal windows per request | ||
| mask = mx.full((total_len, total_len), neg_inf) | ||
| for i in range(len(cu_seqlens) - 1): | ||
|
Collaborator
I understand this is temporary scaffolding. Since it is now active in the production path, please add a packed-length safety cap (or split batches) to avoid O(N²) mask blowups until the varlen kernel replaces it. (Packing currently happens unconditionally in the model_runner complete-prefill path.)
Contributor
Author
done |
||
| start = cu_seqlens[i] | ||
| end = cu_seqlens[i + 1] | ||
| seq_len = end - start | ||
| # Causal mask for this request's tokens | ||
| causal = mx.triu(mx.full((seq_len, seq_len), neg_inf), k=1) | ||
| mask[start:end, start:end] = causal | ||
| return mask.reshape(1, 1, total_len, total_len) | ||
|
|
||
|
|
||
| def apply_packed_rope( | ||
| attn_module: object, | ||
| queries: mx.array, | ||
| keys: mx.array, | ||
| cu_seqlens: list[int], | ||
| ) -> tuple[mx.array, mx.array]: | ||
| """Apply per-request RoPE with position reset for packed prefill. | ||
|
|
||
| SCAFFOLDING: remove when varlen kernel is ready. | ||
| """ | ||
| q_parts = [] | ||
| k_parts = [] | ||
| for i in range(len(cu_seqlens) - 1): | ||
| start = cu_seqlens[i] | ||
| end = cu_seqlens[i + 1] | ||
| q_parts.append(attn_module.rope(queries[:, :, start:end, :], offset=0)) | ||
| k_parts.append(attn_module.rope(keys[:, :, start:end, :], offset=0)) | ||
| return mx.concatenate(q_parts, axis=2), mx.concatenate(k_parts, axis=2) | ||
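On the review request above for a packed-length safety cap: the dense mask has `total_len * total_len` entries, so one way to bound it is to split the requests into several packed batches before `cu_seqlens` is built. The following is a minimal sketch in plain Python, not the change that landed in this PR. `split_packed_batches`, `max_packed_len`, and `mask_bytes` are hypothetical names, and the greedy in-order grouping is an assumption about how batches might be formed.

```python
from __future__ import annotations


def mask_bytes(total_len: int, itemsize: int = 4) -> int:
    """Size in bytes of a dense (total_len, total_len) mask."""
    return total_len * total_len * itemsize


def split_packed_batches(
    seq_lens: list[int],
    max_packed_len: int,
) -> list[list[int]]:
    """Greedily group requests, in order, so each packed batch fits the cap.

    Returns one cu_seqlens list per batch. A single request longer than
    the cap still gets a batch of its own, since this sketch never splits
    a request; the caller would have to handle that case separately.
    """
    if max_packed_len <= 0:
        raise ValueError("max_packed_len must be positive")

    batches: list[list[int]] = []
    current = [0]  # cu_seqlens of the batch being filled
    for n in seq_lens:
        if n <= 0:
            raise ValueError("sequence lengths must be positive")
        # Close the current batch if adding this request would exceed the cap.
        if len(current) > 1 and current[-1] + n > max_packed_len:
            batches.append(current)
            current = [0]
        current.append(current[-1] + n)
    if len(current) > 1:
        batches.append(current)
    return batches
```

Each returned list can be passed as `cu_seqlens`, with its last element as `total_len`, so the largest mask is bounded by `mask_bytes(max_packed_len)` whenever no single request exceeds the cap.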
Could we add one focused unit test for apply_packed_rope as well? We already validate packed mask isolation here; a small RoPE-reset test would lock in the other key packed-prefill contract.
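The contract such a test would lock in is that positions restart at 0 for every request in the packed sequence. The sketch below illustrates that property in plain Python with a single rotated feature pair. It is a stand-in, not the test added to this PR: `packed_positions` and `rotate_pair` are hypothetical helpers, and a real test would presumably call `apply_packed_rope` with an MLX attention module and compare against applying `rope` to each request separately.

```python
from __future__ import annotations

import math


def packed_positions(cu_seqlens: list[int]) -> list[int]:
    """Position index of every packed token, restarting at 0 per request."""
    positions: list[int] = []
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        positions.extend(range(end - start))
    return positions


def rotate_pair(
    x: float, y: float, pos: int, theta: float = 1.0
) -> tuple[float, float]:
    """Rotate one (x, y) feature pair by pos * theta, as RoPE does per frequency."""
    c, s = math.cos(pos * theta), math.sin(pos * theta)
    return (x * c - y * s, x * s + y * c)


def test_rope_resets_per_request() -> None:
    cu_seqlens = [0, 2, 5]  # two requests, of lengths 2 and 3
    packed = [(1.0, 0.0)] * 5
    rotated = [
        rotate_pair(x, y, p)
        for (x, y), p in zip(packed, packed_positions(cu_seqlens))
    ]
    # The first token of each request sits at position 0, so it is unrotated.
    for start in cu_seqlens[:-1]:
        assert rotated[start] == (1.0, 0.0)
    # The second token of each request gets the same rotation.
    assert rotated[1] == rotated[3]
    # Without the reset, packed index 3 would be rotated by position 3.
    assert rotated[3] != rotate_pair(1.0, 0.0, 3)
```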