Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions tests/test_metal_unified_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,12 +357,6 @@ def test_metal_unified_attn(
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]

# xfail cases that need features not yet in the v2 kernel:
# varlen (q_len > 1), sliding window, or soft capping.
# Decode-only cases with no extras already work and should pass.
max_query_len_val = max(query_lens)
if max_query_len_val > 1 or sliding_window is not None or soft_cap is not None:
pytest.xfail("v2 varlen/sliding-window/soft-cap not yet implemented")
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
Expand Down
37 changes: 19 additions & 18 deletions vllm_metal/metal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,35 +87,33 @@ def metal_unified_attention(
) -> None:
"""Unified varlen paged attention for Metal.

Currently supports decode-only (max_seqlen_q=1). Sliding window and
soft capping are not yet supported. These will be enabled when the v2
kernel is extended to handle variable-length queries (prefill + decode).
Supports variable-length queries (prefill + decode) with online softmax,
paged KV cache, causal masking, sliding window, and soft capping.

Grid: one threadgroup per (head, query_token). Each threadgroup uses
binary search on cu_seqlens_q to find its sequence and computes causal
attention against the paged KV cache.
"""
assert causal, "Only causal attention is supported"
import mlx.core as mx

if max_seqlen_q != 1:
raise NotImplementedError(
f"metal_unified_attention only supports decode (max_seqlen_q=1), "
f"got {max_seqlen_q}"
)
if window_size != (-1, -1):
raise NotImplementedError(
f"Sliding window not yet supported, got window_size={window_size}"
)
if softcap != 0:
raise NotImplementedError(
f"Soft capping not yet supported, got softcap={softcap}"
)

# Extract dimensions from cache shape
# k shape: [num_blocks, block_size, num_kv_heads, head_size]
num_kv_heads = k.shape[2]
block_size = k.shape[1]

# Convert window_size tuple to a single sliding_window int.
# window_size = (left, right) where left = sw-1, right = 0 for causal.
# sliding_window = left + 1 = total window size. -1 = disabled.
if window_size == (-1, -1):
sliding_window = -1
else:
sliding_window = window_size[0] + 1

ops = get_ops()

# Ensure all inputs are evaluated before raw Metal dispatch
mx.eval(out, q, k, v, block_table, seqused_k)
mx.eval(out, q, k, v, block_table, seqused_k, cu_seqlens_q)

ops.paged_attention_v2_online(
out,
Expand All @@ -124,10 +122,13 @@ def metal_unified_attention(
v,
num_kv_heads,
softmax_scale,
softcap,
block_table,
seqused_k,
cu_seqlens_q,
block_size,
max_seqlen_k,
sliding_window,
)
mx.synchronize()

Expand Down
104 changes: 84 additions & 20 deletions vllm_metal/metal/kernels_v2/pagedattention.metal
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,34 @@ inline float block_sum(threadgroup float *red_smem, float sum, uint simd_tid,
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

// Map a flat (packed) query-token index back to its owning sequence.
//
// Varlen attention packs the queries of all sequences back-to-back:
//   tokens [0, q_len_0) belong to seq 0, [q_len_0, q_len_0 + q_len_1) to
//   seq 1, and so on. The kernel grid is flat over (head, query_token), so
//   each threadgroup must recover its sequence index before it can read the
//   right block_table row, kv_len, and causal-mask boundary.
//
// Mirrors the lookup done by upstream vLLM's unified Triton kernel
// (triton_unified_attention.py:find_seq_idx) and FlashAttention's varlen API.
//
// cu_seqlens_q holds ascending prefix sums [0, q_len_0, q_len_0+q_len_1, ...]
// with num_seqs + 1 entries. Returns the unique seq_idx satisfying
// cu_seqlens_q[seq_idx] <= q_token_idx < cu_seqlens_q[seq_idx + 1].
inline int find_seq_idx(const device int32_t *cu_seqlens_q,
                        int q_token_idx, int num_seqs) {
  int left = 0;
  int right = num_seqs;
  // Invariant: cu_seqlens_q[left] <= q_token_idx; answer lies in [left, right].
  while (left < right) {
    // Upper-biased midpoint (probe > left) guarantees progress and
    // termination even when right == left + 1.
    const int probe = left + (right - left + 1) / 2;
    if (cu_seqlens_q[probe] <= q_token_idx) {
      // probe's prefix sum has not passed the token: answer is >= probe.
      left = probe;
    } else {
      // probe overshot: answer is strictly below probe.
      right = probe - 1;
    }
  }
  return left;
}

constant bool use_partitioning [[function_constant(10)]];
constant bool use_alibi [[function_constant(20)]];
constant bool use_fp8_scales [[function_constant(30)]];
Expand Down Expand Up @@ -795,24 +823,41 @@ template <typename T, typename CACHE_T, int HEAD_SIZE, int BLOCK_SIZE,
const constant int &kv_head_stride [[buffer(17)]],
const device float *sinks
[[buffer(18), function_constant(use_sinks)]], // [num_heads]
device const int32_t *cu_seqlens_q [[buffer(19)]], // [num_seqs + 1]
const constant int &num_seqs [[buffer(20)]],
const constant int &sliding_window [[buffer(21)]], // -1 = disabled
threadgroup char *shared_mem [[threadgroup(0)]],
uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
uint3 threadgroups_per_grid [[threadgroups_per_grid]],
uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]],
uint simd_tid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
const int seq_idx = threadgroup_position_in_grid.y;
// Varlen: each threadgroup handles one query token.
// Use binary search on cu_seqlens_q to find which sequence it belongs to.
const int q_token_idx = threadgroup_position_in_grid.y;
const int seq_idx = find_seq_idx(cu_seqlens_q, q_token_idx, num_seqs);
const int q_seq_start = cu_seqlens_q[seq_idx];
const int q_len = cu_seqlens_q[seq_idx + 1] - q_seq_start;
const int q_pos_in_seq = q_token_idx - q_seq_start;
const int partition_idx = threadgroup_position_in_grid.z;
const int max_num_partitions = threadgroups_per_grid.z;
const int thread_idx = thread_position_in_threadgroup.x;
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
const uint32_t context_len = context_lens[seq_idx];
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
const uint32_t context_len = context_lens[seq_idx]; // total KV length for this seq

// Causal: this query token can attend to KV positions [0, effective_context_len).
const int effective_context_len = (int)context_len - q_len + q_pos_in_seq + 1;
if (effective_context_len <= 0) {
// No KV tokens to attend to. Caller guarantees out is zero-initialized.
return;
}

if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= effective_context_len) {
// No work to do. Terminate the thread block.
return;
}

const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
const int num_context_blocks = DIVIDE_ROUND_UP(effective_context_len, BLOCK_SIZE);
const int num_blocks_per_partition =
USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;

Expand Down Expand Up @@ -867,7 +912,7 @@ template <typename T, typename CACHE_T, int HEAD_SIZE, int BLOCK_SIZE,
// For example, if the thread group size is 4, then the first thread in the
// group has 0, 4, 8, ... th vectors of the query, and the second thread has
// 1, 5, 9, ... th vectors of the query, and so on.
const device T *q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
const device T *q_ptr = q + q_token_idx * q_stride + head_idx * HEAD_SIZE;
threadgroup Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
Expand Down Expand Up @@ -955,15 +1000,20 @@ template <typename T, typename CACHE_T, int HEAD_SIZE, int BLOCK_SIZE,
float qk = scale * Qk_dot<T, THREAD_GROUP_SIZE>::dot(
q_vecs[thread_group_offset], k_vecs);

if (softcapping != 1.0) {
if (softcapping > 0.0f) {
qk = tanh(qk / softcapping) * softcapping;
}

qk +=
(alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
(alibi_slope != 0) ? alibi_slope * (token_idx - effective_context_len + 1) : 0;

if (thread_group_offset == 0) {
const bool mask = token_idx >= context_len;
// Causal mask: only attend to KV positions < effective_context_len.
bool mask = token_idx >= effective_context_len;
// Sliding window mask: skip positions too far in the past.
if (sliding_window >= 0) {
mask = mask || (token_idx < effective_context_len - sliding_window);
}
warp_scores[physical_block_offset] = mask ? -FLT_MAX : qk;
}
}
Expand All @@ -981,7 +1031,7 @@ template <typename T, typename CACHE_T, int HEAD_SIZE, int BLOCK_SIZE,
// Valid tokens in this block:
const int block_start_token = block_idx * BLOCK_SIZE;
const int block_valid_tokens =
MIN(BLOCK_SIZE, (int)context_len - block_start_token);
MIN(BLOCK_SIZE, effective_context_len - block_start_token);

// Find max score in this block (all lanes participate for speed).
float block_max = -FLT_MAX;
Expand Down Expand Up @@ -1058,13 +1108,14 @@ template <typename T, typename CACHE_T, int HEAD_SIZE, int BLOCK_SIZE,
// reached by all threads in the threadgroup).

// If partitioning is enabled, store the partial result for the reduce kernel.
// Indexed by q_token_idx (not seq_idx) for varlen compatibility.
if (USE_PARTITIONING && thread_idx == 0 && use_partitioning) {
device float *max_logits_ptr =
max_logits + seq_idx * num_heads * max_num_partitions +
max_logits + q_token_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*max_logits_ptr = warp_m;
device float *exp_sums_ptr = exp_sums +
seq_idx * num_heads * max_num_partitions +
q_token_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*exp_sums_ptr = warp_l;
}
Expand Down Expand Up @@ -1143,7 +1194,7 @@ template <typename T, typename CACHE_T, int HEAD_SIZE, int BLOCK_SIZE,
const float inv_l = 1.f / (warp_l + 1e-6f);

device T *out_ptr =
out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
out + q_token_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll
for (int j = 0; j < V_ELEMS_PER_THREAD; j++) {
Expand All @@ -1165,6 +1216,8 @@ template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
const constant int &max_num_partitions [[buffer(5)]],
const device float *sinks
[[buffer(6), function_constant(use_sinks)]], // [num_heads]
device const int32_t *cu_seqlens_q [[buffer(7)]], // [num_seqs + 1]
const constant int &num_seqs [[buffer(8)]],
threadgroup char *shared_mem [[threadgroup(0)]],
uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
uint3 threadgroups_per_grid [[threadgroups_per_grid]],
Expand All @@ -1174,15 +1227,21 @@ template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
uint simd_lid [[thread_index_in_simdgroup]]) {
const int num_heads = threadgroups_per_grid.x;
const int head_idx = threadgroup_position_in_grid.x;
const int seq_idx = threadgroup_position_in_grid.y;
// Varlen: grid.y is q_token_idx (one per query token), not seq_idx.
const int q_token_idx = threadgroup_position_in_grid.y;
const int seq_idx = find_seq_idx(cu_seqlens_q, q_token_idx, num_seqs);
const int q_seq_start = cu_seqlens_q[seq_idx];
const int q_len = cu_seqlens_q[seq_idx + 1] - q_seq_start;
const int q_pos_in_seq = q_token_idx - q_seq_start;
const uint32_t context_len = context_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
const int effective_context_len = (int)context_len - q_len + q_pos_in_seq + 1;
const int num_partitions = DIVIDE_ROUND_UP(effective_context_len, PARTITION_SIZE);
if (num_partitions == 1 && !use_sinks) {
// No need to reduce. Only copy tmp_out to out.
device T *out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
out + q_token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
const device T *tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
tmp_out + q_token_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE;
for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE;
i += threads_per_threadgroup.x) {
Expand All @@ -1203,7 +1262,7 @@ template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
threadgroup float *shared_max_logits =
reinterpret_cast<threadgroup float *>(shared_mem);
const device float *max_logits_ptr =
max_logits + seq_idx * num_heads * max_num_partitions +
max_logits + q_token_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float max_logit = -FLT_MAX;
for (int i = thread_position_in_threadgroup.x; i < num_partitions;
Expand Down Expand Up @@ -1242,7 +1301,7 @@ template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
threadgroup float *shared_exp_sums = reinterpret_cast<threadgroup float *>(
shared_mem + sizeof(float) * num_partitions);
const device float *exp_sums_ptr = exp_sums +
seq_idx * num_heads * max_num_partitions +
q_token_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float global_exp_sum = 0.0f;
for (int i = thread_position_in_threadgroup.x; i < num_partitions;
Expand All @@ -1265,10 +1324,10 @@ template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,

// Aggregate tmp_out to out.
const device T *tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
tmp_out + q_token_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE;
device T *out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
out + q_token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE;
i += NUM_THREADS) {
Expand Down Expand Up @@ -1313,6 +1372,9 @@ template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
const constant int &kv_block_stride [[buffer(16)]], \
const constant int &kv_head_stride [[buffer(17)]], \
const device float *sinks [[buffer(18), function_constant(use_sinks)]], \
device const int32_t *cu_seqlens_q [[buffer(19)]], \
const constant int &num_seqs [[buffer(20)]], \
const constant int &sliding_window [[buffer(21)]], \
threadgroup char *shared_mem [[threadgroup(0)]], \
uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
Expand All @@ -1334,6 +1396,8 @@ template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
device uint32_t *context_lens [[buffer(4)]], \
const constant int &max_num_partitions [[buffer(5)]], \
const device float *sinks [[buffer(6), function_constant(use_sinks)]], \
device const int32_t *cu_seqlens_q [[buffer(7)]], \
const constant int &num_seqs [[buffer(8)]], \
threadgroup char *shared_mem [[threadgroup(0)]], \
uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
Expand Down
Loading
Loading