diff --git a/tests/test_metal_kernel_paged.py b/tests/test_metal_kernel_paged.py deleted file mode 100644 index 438637c4..00000000 --- a/tests/test_metal_kernel_paged.py +++ /dev/null @@ -1,338 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Tests for Metal kernel paged attention — verifies output matches non-paged path. - -Run with: - python -m pytest tests/test_metal_kernel_paged.py -v -s -""" - -from __future__ import annotations - -import pytest - -MODEL_NAME = "Qwen/Qwen3-0.6B" -BLOCK_SIZE = 16 - -try: - import mlx.core as mx - from mlx_lm import load as mlx_lm_load - from mlx_lm.models.cache import make_prompt_cache - - from vllm_metal.kv_cache_dtype import infer_kv_cache_dtype_from_model -except ImportError as exc: - pytest.skip( - f"Metal kernel paged attention tests require mlx/mlx_lm: {exc}", - allow_module_level=True, - ) - -try: - from vllm_metal.metal import get_ops - from vllm_metal.metal_kernel_backend.cache import MetalPagedKVCache - from vllm_metal.metal_kernel_backend.paged_attention import ( - MetalKernelPagedAttentionWrapper, - patch_model_attention_metal_kernel, - ) - from vllm_metal.paged_attention_common import ( - OffsetCache, - clear_context, - prepare_decode, - prepare_prefill, - ) -except ImportError as exc: - pytest.skip( - f"Metal kernel paged attention tests require vllm-metal paged backend: {exc}", - allow_module_level=True, - ) - - -@pytest.fixture(scope="module", autouse=True) -def _paged_attention_ops_available() -> None: - """Fail early if the native paged-attention ops cannot be loaded.""" - get_ops() - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _test_infer_paged_kv_dtype(model) -> mx.Dtype: - """Test-only helper: choose a float dtype for MetalPagedKVCache. - - This is deliberately local to this test module. Production code uses - `vllm_metal.kv_cache_dtype.infer_kv_cache_dtype_from_model()`. - """ - result = infer_kv_cache_dtype_from_model(model) - if result.warning is not None: - raise AssertionError( - "KV cache dtype inference unexpectedly fell back during tests: " - f"{result.warning}" - ) - return result.dtype - - -def _greedy_generate_standard(model, token_ids: list[int], max_new: int) -> list[int]: - """Generate tokens using the standard mlx_lm KVCache path.""" - cache = make_prompt_cache(model) - - # Prefill - input_ids = mx.array([token_ids], dtype=mx.int32) - logits = model(input_ids, cache=cache) - next_tok = int(mx.argmax(logits[:, -1, :], axis=-1).item()) - mx.eval(mx.array(next_tok), *[c.state for c in cache]) - generated = [next_tok] - - # Decode - for _ in range(max_new - 1): - input_ids = mx.array([[generated[-1]]], dtype=mx.int32) - logits = model(input_ids, cache=cache) - next_tok = int(mx.argmax(logits[:, -1, :], axis=-1).item()) - mx.eval(mx.array(next_tok), *[c.state for c in cache]) - generated.append(next_tok) - - return generated - - -def _greedy_generate_metal_kernel( - model, token_ids: list[int], max_new: int -) -> list[int]: - """Generate tokens using the Metal kernel paged attention path.""" - args = model.args - num_layers = args.num_hidden_layers - num_kv_heads = args.num_key_value_heads - head_dim = args.head_dim - - # Allocate generous block pool - total_tokens = len(token_ids) + max_new + BLOCK_SIZE - num_blocks = (total_tokens + BLOCK_SIZE - 1) // BLOCK_SIZE + 4 - - metal_cache = MetalPagedKVCache( - num_layers=num_layers, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - num_blocks=num_blocks, - block_size=BLOCK_SIZE, - dtype=_test_infer_paged_kv_dtype(model), - ) - - n_patched = patch_model_attention_metal_kernel(model, metal_cache, BLOCK_SIZE) - assert n_patched == num_layers - - # Assign block IDs for this sequence (manual allocation) - seq_blocks_needed = (len(token_ids) + max_new + BLOCK_SIZE - 1) // BLOCK_SIZE - block_ids = list(range(seq_blocks_needed)) - - # --- Prefill --- - prepare_prefill(block_ids, len(token_ids), BLOCK_SIZE) - offset_caches = [OffsetCache(0) for _ in range(num_layers)] - - input_ids = mx.array([token_ids], dtype=mx.int32) - logits = model(input_ids, cache=offset_caches) - next_tok = int(mx.argmax(logits[:, -1, :], axis=-1).item()) - mx.eval(mx.array(next_tok)) - clear_context() - generated = [next_tok] - - seq_len = len(token_ids) # tokens stored in cache so far - - # --- Decode --- - for _ in range(max_new - 1): - prepare_decode([(block_ids, seq_len)], BLOCK_SIZE) - offset_caches = [OffsetCache(seq_len) for _ in range(num_layers)] - - input_ids = mx.array([[generated[-1]]], dtype=mx.int32) - logits = model(input_ids, cache=offset_caches) - next_tok = int(mx.argmax(logits[:, -1, :], axis=-1).item()) - mx.eval(mx.array(next_tok)) - clear_context() - generated.append(next_tok) - seq_len += 1 - - return generated - - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture(scope="module") -def qwen3_model(): - """Load Qwen3-0.6B once for all tests in this module.""" - model, tokenizer = mlx_lm_load(MODEL_NAME) - return model, tokenizer - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - - -class TestMetalKernelPagedVsStandard: - @pytest.mark.slow - def test_greedy_output_matches(self, qwen3_model): - """Metal kernel paged attention greedy decode must match standard path.""" - model, tokenizer = qwen3_model - prompt = "The capital of France is" - token_ids = tokenizer.encode(prompt) - max_new = 20 - - # Standard path - ref_tokens = _greedy_generate_standard(model, token_ids, max_new) - - # Metal kernel path - mk_tokens = _greedy_generate_metal_kernel(model, token_ids, max_new) - - assert ref_tokens == mk_tokens, ( - f"Token mismatch!\n" - f" Standard: {ref_tokens}\n" - f" Metal kernel: {mk_tokens}" - ) - - @pytest.mark.slow - @pytest.mark.xfail( - reason="B=2 batched GEMM produces different floats than B=1, " - "causing token divergence after ~5 decode steps (not a kernel bug). " - "See https://github.com/vllm-project/vllm-metal/issues/119" - ) - def test_batched_decode_matches(self, qwen3_model): - """Batched Metal kernel paged decode must match per-request sequential.""" - model, tokenizer = qwen3_model - prompts = [ - "The capital of France is", - "Machine learning is", - ] - max_new = 10 - - # Generate reference tokens independently - ref_all = [] - for prompt in prompts: - token_ids = tokenizer.encode(prompt) - ref_all.append(_greedy_generate_standard(model, token_ids, max_new)) - - # Metal kernel path: prefill each, then batched decode - args = model.args - num_layers = args.num_hidden_layers - num_kv_heads = args.num_key_value_heads - head_dim = args.head_dim - - total_max = ( - max(len(tokenizer.encode(p)) for p in prompts) + max_new + BLOCK_SIZE - ) - num_blocks = ((total_max + BLOCK_SIZE - 1) // BLOCK_SIZE) * len(prompts) + 8 - - metal_cache = MetalPagedKVCache( - num_layers=num_layers, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - num_blocks=num_blocks, - block_size=BLOCK_SIZE, - dtype=_test_infer_paged_kv_dtype(model), - ) - patch_model_attention_metal_kernel(model, metal_cache, BLOCK_SIZE) - - # Prefill each prompt - all_token_ids = [] - all_block_ids = [] - all_seq_lens = [] - all_generated: list[list[int]] = [] - - block_offset = 0 - for _i, prompt in enumerate(prompts): - tids = tokenizer.encode(prompt) - all_token_ids.append(tids) - needed = (len(tids) + max_new + BLOCK_SIZE - 1) // BLOCK_SIZE - bids = list(range(block_offset, block_offset + needed)) - block_offset += needed - all_block_ids.append(bids) - - prepare_prefill(bids, len(tids), BLOCK_SIZE) - offset_caches = [OffsetCache(0) for _ in range(num_layers)] - input_ids = mx.array([tids], dtype=mx.int32) - logits = model(input_ids, cache=offset_caches) - next_tok = int(mx.argmax(logits[:, -1, :], axis=-1).item()) - mx.eval(mx.array(next_tok)) - clear_context() - - all_generated.append([next_tok]) - all_seq_lens.append(len(tids)) - - # Batched decode steps - for _step in range(max_new - 1): - requests_info = [] - for i in range(len(prompts)): - requests_info.append((all_block_ids[i], all_seq_lens[i])) - - prepare_decode(requests_info, BLOCK_SIZE) - - max_offset = max(all_seq_lens) - offset_caches = [OffsetCache(max_offset) for _ in range(num_layers)] - - last_tokens = [gen[-1] for gen in all_generated] - batched_input = mx.array(last_tokens, dtype=mx.int32)[:, None] - logits = model(batched_input, cache=offset_caches) - next_toks = mx.argmax(logits[:, -1, :], axis=-1) - mx.eval(next_toks) - clear_context() - - for i in range(len(prompts)): - tok = int(next_toks[i].item()) - all_generated[i].append(tok) - all_seq_lens[i] += 1 - - # Compare - for i, prompt in enumerate(prompts): - assert all_generated[i] == ref_all[i], ( - f"Mismatch for prompt {i} ({prompt!r}):\n" - f" Standard: {ref_all[i]}\n" - f" Metal kernel: {all_generated[i]}" - ) - - -class TestMetalKernelPatchRouting: - """Verify that the wrapper routes to metal kernel vs fallback.""" - - @pytest.mark.slow - def test_patch_replaces_self_attn(self, qwen3_model): - """After patching, each layer's self_attn should be a wrapper.""" - model, _ = qwen3_model - args = model.args - - metal_cache = MetalPagedKVCache( - num_layers=args.num_hidden_layers, - num_kv_heads=args.num_key_value_heads, - head_dim=args.head_dim, - num_blocks=32, - block_size=BLOCK_SIZE, - dtype=_test_infer_paged_kv_dtype(model), - ) - patch_model_attention_metal_kernel(model, metal_cache, BLOCK_SIZE) - - layers = model.model.layers - for i, layer in enumerate(layers): - assert isinstance(layer.self_attn, MetalKernelPagedAttentionWrapper), ( - f"Layer {i} self_attn is {type(layer.self_attn).__name__}, " - f"expected MetalKernelPagedAttentionWrapper" - ) - - @pytest.mark.slow - def test_fallback_when_no_context(self, qwen3_model): - """Without paged context, calls must fall back to original attention.""" - model, _ = qwen3_model - args = model.args - - metal_cache = MetalPagedKVCache( - num_layers=args.num_hidden_layers, - num_kv_heads=args.num_key_value_heads, - head_dim=args.head_dim, - num_blocks=32, - block_size=BLOCK_SIZE, - dtype=_test_infer_paged_kv_dtype(model), - ) - patch_model_attention_metal_kernel(model, metal_cache, BLOCK_SIZE) - - # Run forward without setting context → should use fallback - cache = make_prompt_cache(model) - input_ids = mx.array([[1, 2, 3]], dtype=mx.int32) - logits = model(input_ids, cache=cache) - mx.eval(logits) - # If we got here without error, fallback worked diff --git a/tests/test_paged_deterministic.py b/tests/test_paged_deterministic.py index 0b1d3420..50ab97b8 100644 --- a/tests/test_paged_deterministic.py +++ b/tests/test_paged_deterministic.py @@ -46,6 +46,7 @@ "One plus one equals", "The largest planet in our solar system is", "Water boils at a temperature of", + "Machine learning is", ] # fmt: off @@ -57,6 +58,7 @@ "One plus one equals": [825, 11, 825, 5519, 825, 16819, 1378, 13, 2055, 11], "The largest planet in our solar system is": [1112, 30, 362, 13, 43562, 425, 13, 48976, 356, 13], "Water boils at a temperature of": [220, 16, 15, 15, 30937, 13, 3555, 374, 279, 9315], + "Machine learning is": [264, 7988, 5392, 429, 702, 13791, 1506, 279, 2070, 315], } # Golden token IDs from paged KV cache (HF kernel on main branch), greedy decoding. @@ -68,6 +70,7 @@ "One plus one equals": [825, 11, 825, 5519, 825, 16819, 1378, 13, 2055, 11], "The largest planet in our solar system is": [1112, 30, 362, 13, 43562, 425, 13, 48976, 356, 13], "Water boils at a temperature of": [220, 16, 15, 15, 30937, 13, 3555, 374, 279, 9315], + "Machine learning is": [264, 7988, 5392, 429, 702, 13791, 1506, 279, 2070, 315], } # fmt: on diff --git a/tools/gen_golden_token_ids_for_deterministics.py b/tools/gen_golden_token_ids_for_deterministics.py index 1043ffe5..d29a59d9 100644 --- a/tools/gen_golden_token_ids_for_deterministics.py +++ b/tools/gen_golden_token_ids_for_deterministics.py @@ -32,6 +32,7 @@ "One plus one equals", "The largest planet in our solar system is", "Water boils at a temperature of", + "Machine learning is", ] if __name__ == "__main__":