diff --git a/tests/test_metal_unified_attention.py b/tests/test_metal_unified_attention.py index 4a3187c..f40fcfe 100644 --- a/tests/test_metal_unified_attention.py +++ b/tests/test_metal_unified_attention.py @@ -357,12 +357,6 @@ def test_metal_unified_attn( query_lens = [x[0] for x in seq_lens] kv_lens = [x[1] for x in seq_lens] - # xfail cases that need features not yet in the v2 kernel: - # varlen (q_len > 1), sliding window, or soft capping. - # Decode-only cases with no extras already work and should pass. - max_query_len_val = max(query_lens) - if max_query_len_val > 1 or sliding_window is not None or soft_cap is not None: - pytest.xfail("v2 varlen/sliding-window/soft-cap not yet implemented") num_query_heads = num_heads[0] num_kv_heads = num_heads[1] assert num_query_heads % num_kv_heads == 0 diff --git a/vllm_metal/metal/__init__.py b/vllm_metal/metal/__init__.py index 92fdf22..29f4c18 100644 --- a/vllm_metal/metal/__init__.py +++ b/vllm_metal/metal/__init__.py @@ -87,35 +87,33 @@ def metal_unified_attention( ) -> None: """Unified varlen paged attention for Metal. - Currently supports decode-only (max_seqlen_q=1). Sliding window and - soft capping are not yet supported. These will be enabled when the v2 - kernel is extended to handle variable-length queries (prefill + decode). + Supports variable-length queries (prefill + decode) with online softmax, + paged KV cache, causal masking, sliding window, and soft capping. + + Grid: one threadgroup per (head, query_token). Each threadgroup uses + binary search on cu_seqlens_q to find its sequence and computes causal + attention against the paged KV cache. """ + assert causal, "Only causal attention is supported" import mlx.core as mx - if max_seqlen_q != 1: - raise NotImplementedError( - f"metal_unified_attention only supports decode (max_seqlen_q=1), " - f"got {max_seqlen_q}" - ) - if window_size != (-1, -1): - raise NotImplementedError( - f"Sliding window not yet supported, got window_size={window_size}" - ) - if softcap != 0: - raise NotImplementedError( - f"Soft capping not yet supported, got softcap={softcap}" - ) - # Extract dimensions from cache shape # k shape: [num_blocks, block_size, num_kv_heads, head_size] num_kv_heads = k.shape[2] block_size = k.shape[1] + # Convert window_size tuple to a single sliding_window int. + # window_size = (left, right) where left = sw-1, right = 0 for causal. + # sliding_window = left + 1 = total window size. -1 = disabled. + if window_size == (-1, -1): + sliding_window = -1 + else: + sliding_window = window_size[0] + 1 + ops = get_ops() # Ensure all inputs are evaluated before raw Metal dispatch - mx.eval(out, q, k, v, block_table, seqused_k) + mx.eval(out, q, k, v, block_table, seqused_k, cu_seqlens_q) ops.paged_attention_v2_online( out, @@ -124,10 +122,13 @@ def metal_unified_attention( v, num_kv_heads, softmax_scale, + softcap, block_table, seqused_k, + cu_seqlens_q, block_size, max_seqlen_k, + sliding_window, ) mx.synchronize() diff --git a/vllm_metal/metal/kernels_v2/pagedattention.metal b/vllm_metal/metal/kernels_v2/pagedattention.metal index 48eea53..b0d7620 100644 --- a/vllm_metal/metal/kernels_v2/pagedattention.metal +++ b/vllm_metal/metal/kernels_v2/pagedattention.metal @@ -756,6 +756,34 @@ inline float block_sum(threadgroup float *red_smem, float sum, uint simd_tid, #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) +// Binary search to find which sequence a global query token belongs to. +// +// In varlen (ragged-batch) attention, queries from multiple sequences are +// packed contiguously into a flat array: +// q[0..q_len_0-1] → seq 0, q[q_len_0..q_len_0+q_len_1-1] → seq 1, ... +// The kernel launches one threadgroup per (head, query_token) in a flat grid. +// Each threadgroup needs to discover which sequence it belongs to so it can +// look up the correct block_table row, kv_len, and causal mask boundary. +// +// This is the same approach used by the upstream vLLM unified Triton kernel +// (triton_unified_attention.py:find_seq_idx) and FlashAttention's varlen API. +// +// cu_seqlens_q is sorted ascending: [0, q_len_0, q_len_0+q_len_1, ...]. +// Returns seq_idx such that cu_seqlens_q[seq_idx] <= q_token_idx < cu_seqlens_q[seq_idx+1]. +inline int find_seq_idx(const device int32_t *cu_seqlens_q, + int q_token_idx, int num_seqs) { + int lo = 0, hi = num_seqs; + while (lo < hi) { + int mid = (lo + hi + 1) / 2; + if (cu_seqlens_q[mid] <= q_token_idx) { + lo = mid; + } else { + hi = mid - 1; + } + } + return lo; +} + constant bool use_partitioning [[function_constant(10)]]; constant bool use_alibi [[function_constant(20)]]; constant bool use_fp8_scales [[function_constant(30)]]; @@ -795,24 +823,41 @@ template 0; - const uint32_t context_len = context_lens[seq_idx]; - if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { + const uint32_t context_len = context_lens[seq_idx]; // total KV length for this seq + + // Causal: this query token can attend to KV positions [0, effective_context_len). + const int effective_context_len = (int)context_len - q_len + q_pos_in_seq + 1; + if (effective_context_len <= 0) { + // No KV tokens to attend to. Caller guarantees out is zero-initialized. + return; + } + + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= effective_context_len) { // No work to do. Terminate the thread block. return; } - const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int num_context_blocks = DIVIDE_ROUND_UP(effective_context_len, BLOCK_SIZE); const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; @@ -867,7 +912,7 @@ template ::dot( q_vecs[thread_group_offset], k_vecs); - if (softcapping != 1.0) { + if (softcapping > 0.0f) { qk = tanh(qk / softcapping) * softcapping; } qk += - (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; + (alibi_slope != 0) ? alibi_slope * (token_idx - effective_context_len + 1) : 0; if (thread_group_offset == 0) { - const bool mask = token_idx >= context_len; + // Causal mask: only attend to KV positions < effective_context_len. + bool mask = token_idx >= effective_context_len; + // Sliding window mask: skip positions too far in the past. + if (sliding_window >= 0) { + mask = mask || (token_idx < effective_context_len - sliding_window); + } warp_scores[physical_block_offset] = mask ? -FLT_MAX : qk; } } @@ -981,7 +1031,7 @@ template (shared_mem); const device float *max_logits_ptr = - max_logits + seq_idx * num_heads * max_num_partitions + + max_logits + q_token_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; float max_logit = -FLT_MAX; for (int i = thread_position_in_threadgroup.x; i < num_partitions; @@ -1242,7 +1301,7 @@ template ( shared_mem + sizeof(float) * num_partitions); const device float *exp_sums_ptr = exp_sums + - seq_idx * num_heads * max_num_partitions + + q_token_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; float global_exp_sum = 0.0f; for (int i = thread_position_in_threadgroup.x; i < num_partitions; @@ -1265,10 +1324,10 @@ template (out_h); auto& query = *nb::inst_ptr(query_h); @@ -276,14 +279,17 @@ void paged_attention_v2_online_impl( auto& value_cache = *nb::inst_ptr(value_cache_h); auto& block_tables = *nb::inst_ptr(block_tables_h); auto& seq_lens = *nb::inst_ptr(seq_lens_h); + auto& cu_seqlens_q = *nb::inst_ptr(cu_seqlens_q_h); auto s = default_stream(Device::gpu); auto& d = metal::device(Device::gpu); - int num_seqs = static_cast(query.shape(0)); + // Varlen: query shape is [total_q_tokens, num_heads, head_size] + int total_q_tokens = static_cast(query.shape(0)); int num_heads = static_cast(query.shape(1)); int head_size = static_cast(query.shape(2)); int max_blocks = static_cast(block_tables.shape(1)); + int num_seqs = static_cast(cu_seqlens_q.shape(0)) - 1; // Same kernel name format as v1 — the template instantiation is identical. auto dt = dtype_to_metal(query.dtype()); @@ -325,7 +331,7 @@ void paged_attention_v2_online_impl( enc.set_compute_pipeline_state(kernel); enc.set_threadgroup_memory_length(shmem, 0); - // Buffer bindings — identical to v1. + // Buffer bindings enc.set_output_array(out, 2); enc.set_input_array(query, 3); enc.set_input_array(key_cache, 4); @@ -334,7 +340,8 @@ void paged_attention_v2_online_impl( int32_t nkv = static_cast(num_kv_heads); enc.set_bytes(nkv, 8); enc.set_bytes(scale, 9); - float softcapping = 1.0f; + // softcap: 0.0 = disabled, >0 = enabled. Passed through to kernel as-is. + float softcapping = softcap; enc.set_bytes(softcapping, 10); enc.set_input_array(block_tables, 11); @@ -350,8 +357,16 @@ void paged_attention_v2_online_impl( enc.set_bytes(kv_block_stride, 16); enc.set_bytes(kv_head_stride, 17); + // Varlen buffers (new in v2) + enc.set_input_array(cu_seqlens_q, 19); + int32_t num_seqs_i = static_cast(num_seqs); + enc.set_bytes(num_seqs_i, 20); + int32_t sliding_window_i = static_cast(sliding_window); + enc.set_bytes(sliding_window_i, 21); + + // Grid: one threadgroup per (head, query_token) enc.dispatch_threadgroups( - MTL::Size::Make(num_heads, num_seqs, 1), + MTL::Size::Make(num_heads, total_q_tokens, 1), MTL::Size::Make(NUM_THREADS, 1, 1)); d.add_temporary(out, s.index); @@ -360,6 +375,7 @@ void paged_attention_v2_online_impl( d.add_temporary(value_cache, s.index); d.add_temporary(block_tables, s.index); d.add_temporary(seq_lens, s.index); + d.add_temporary(cu_seqlens_q, s.index); } // --------------------------------------------------------------------------- @@ -393,7 +409,10 @@ NB_MODULE(_paged_ops, m) { nb::arg("out"), nb::arg("query"), nb::arg("key_cache"), nb::arg("value_cache"), nb::arg("num_kv_heads"), nb::arg("scale"), + nb::arg("softcap"), nb::arg("block_tables"), nb::arg("seq_lens"), + nb::arg("cu_seqlens_q"), nb::arg("block_size"), nb::arg("max_seq_len"), - "Online-softmax paged attention (v2, decode-only)."); + nb::arg("sliding_window"), + "Online-softmax varlen paged attention (v2, unified prefill+decode)."); } diff --git a/vllm_metal/metal_kernel_backend/paged_attention.py b/vllm_metal/metal_kernel_backend/paged_attention.py index a590d49..fe0f6d0 100644 --- a/vllm_metal/metal_kernel_backend/paged_attention.py +++ b/vllm_metal/metal_kernel_backend/paged_attention.py @@ -226,7 +226,11 @@ def _metal_kernel_decode_attention( max_seq_len = max(ctx.context_lens) scale = attn_module.scale - # Zero-copy paged attention (v2, online softmax) + # Build cu_seqlens_q for varlen dispatch: decode has q_len=1 per sequence. + cu_seqlens_q = mx.arange(B + 1, dtype=mx.int32) + mx.eval(cu_seqlens_q) + + # Zero-copy paged attention (v2, online softmax, varlen-capable) ops.paged_attention_v2_online( out, q_3d, @@ -234,10 +238,13 @@ def _metal_kernel_decode_attention( cache.value_caches[layer_idx], cache.num_kv_heads, scale, + 0.0, # softcap (0 = disabled) block_tables, seq_lens, + cu_seqlens_q, cache.block_size, max_seq_len, + -1, # sliding_window (-1 = disabled) ) # Synchronize GPU: paged_attention_v2_online wrote to out's buffer via a raw