Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ vLLM Metal is a plugin that enables vLLM to run on Apple Silicon Macs using MLX
- **MLX-accelerated inference**: faster than PyTorch MPS on Apple Silicon
- **Unified memory**: True zero-copy operations leveraging Apple Silicon's unified memory architecture
- **vLLM compatibility**: Full integration with vLLM's engine, scheduler, and OpenAI-compatible API
- **Paged attention** *(experimental)*: Efficient KV cache management for long sequences — opt-in via `VLLM_METAL_USE_PAGED_ATTENTION=1` (requires `pip install 'vllm-metal[paged]'`); default path uses MLX-managed KV cache
- **Paged attention** *(experimental)*: Efficient KV cache management for long sequences — opt-in via `VLLM_METAL_USE_PAGED_ATTENTION=1`; default path uses MLX-managed KV cache. When enabled, expect significantly better serving performance (~82x faster TTFT, ~3.75x higher throughput in early benchmarks on Qwen3-0.6B). Other models may have rough edges.
- **GQA support**: Grouped-Query Attention for efficient inference

## Requirements
Expand Down Expand Up @@ -95,14 +95,13 @@ Environment variables for customization:
| `VLLM_METAL_USE_MLX` | `1` | Use MLX for compute (1=yes, 0=no) |
| `VLLM_MLX_DEVICE` | `gpu` | MLX device (`gpu` or `cpu`) |
| `VLLM_METAL_BLOCK_SIZE` | `16` | KV cache block size |
| `VLLM_METAL_USE_PAGED_ATTENTION` | `0` | Enable experimental paged KV cache (requires `pip install 'vllm-metal[paged]'`) |
| `VLLM_METAL_USE_PAGED_ATTENTION` | `0` | Enable experimental paged KV cache |
| `VLLM_METAL_DEBUG` | `0` | Enable debug logging |
| `VLLM_USE_MODELSCOPE` | `False` | Set to `True` to change the model registry to <https://www.modelscope.cn/> |
| `VLLM_METAL_MODELSCOPE_CACHE` | None | Specify the absolute path of the local model |
| `VLLM_METAL_PREFIX_CACHE` | (unset) | Set to enable prefix caching for shared prompt reuse |
| `VLLM_METAL_PREFIX_CACHE_FRACTION` | `0.05` | Fraction of MLX working set for prefix cache (0, 1] |


## Paged KV vs MLX KV memory settings

- MLX path (`VLLM_METAL_USE_PAGED_ATTENTION=0`): `VLLM_METAL_MEMORY_FRACTION` must be `auto`.
Expand All @@ -115,3 +114,7 @@ Environment variables for customization:
`auto` | `1` | Yes | Paged KV path; defaults to 0.9 internally
`0.7` | `1` | Yes | Paged KV path with explicit memory budget
`0.7` | `0` | No | Explicit fraction without paged KV is invalid

## Acknowledgements

- The Metal paged attention kernels are currently adapted from [mistral.rs](https://github.com/EricLBuehler/mistral.rs) (MIT license), via [HuggingFace kernels-community](https://github.com/huggingface/kernels-community). We plan to develop custom kernels in the future.
8 changes: 3 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,14 @@ dependencies = [
"transformers>=4.40.0",
"accelerate>=0.26.0",
"safetensors>=0.4.0",
# Native Metal extension JIT build
"nanobind>=2.0.0; platform_system == 'Darwin' and platform_machine == 'arm64'",
# Core utilities
"numpy>=1.24.0",
"psutil>=5.9.0",
]

[project.optional-dependencies]
paged = [
# Paged attention Metal kernel (opt-in, experimental)
"kernels>=0.4.5; platform_system == 'Darwin' and platform_machine == 'arm64'",
]
vllm = ["vllm>=0.14.0"]
stt = [
# Speech-to-text audio processing (Whisper models)
Expand All @@ -58,7 +56,7 @@ dev = [
"mypy>=1.19.1",
]
all = [
"vllm-metal[vllm,paged,stt,dev]",
"vllm-metal[vllm,stt,dev]",
]

[project.urls]
Expand Down
128 changes: 0 additions & 128 deletions tests/test_kernel_loader.py

This file was deleted.

42 changes: 16 additions & 26 deletions tests/test_metal_kernel_paged.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for Metal kernel paged attention — verifies output matches non-paged path.

Requires ``kernels`` package with ``kernels-community/paged-attention`` support.

Run with:
python -m pytest tests/test_metal_kernel_paged.py -v -s
"""
Expand All @@ -16,20 +14,19 @@

try:
import mlx.core as mx
import torch
from mlx_lm import load as mlx_lm_load
from mlx_lm.models.cache import make_prompt_cache

from vllm_metal.kv_cache_dtype import infer_kv_cache_dtype_from_model
except ImportError as exc:
pytest.skip(
f"Metal kernel paged attention tests require mlx/torch/mlx_lm: {exc}",
f"Metal kernel paged attention tests require mlx/mlx_lm: {exc}",
allow_module_level=True,
)

try:
from vllm_metal.metal_kernel_backend.cache import MPSPagedKVCache
from vllm_metal.metal_kernel_backend.kernel_loader import get_paged_attention_ops
from vllm_metal.metal import get_ops
from vllm_metal.metal_kernel_backend.cache import MetalPagedKVCache
from vllm_metal.metal_kernel_backend.paged_attention import (
MetalKernelPagedAttentionWrapper,
patch_model_attention_metal_kernel,
Expand All @@ -42,31 +39,24 @@
)
except ImportError as exc:
pytest.skip(
"Metal kernel paged attention tests require the vllm-metal paged backend: "
f"{exc}. Install with: pip install 'vllm-metal[paged]'",
f"Metal kernel paged attention tests require vllm-metal paged backend: {exc}",
allow_module_level=True,
)


@pytest.fixture(scope="module", autouse=True)
def _paged_attention_ops_available() -> None:
"""Skip this module if the paged-attention ops cannot be loaded."""

try:
get_paged_attention_ops()
except ImportError as exc:
pytest.skip(str(exc))
except Exception as exc:
pytest.skip(f"kernels-community/paged-attention not available: {exc}")
"""Fail early if the native paged-attention ops cannot be loaded."""
get_ops()


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _test_infer_paged_kv_dtype(model) -> torch.dtype:
"""Test-only helper: choose a float dtype for MPSPagedKVCache.
def _test_infer_paged_kv_dtype(model) -> mx.Dtype:
"""Test-only helper: choose a float dtype for MetalPagedKVCache.

This is deliberately local to this test module. Production code uses
`vllm_metal.kv_cache_dtype.infer_kv_cache_dtype_from_model()`.
Expand Down Expand Up @@ -115,7 +105,7 @@ def _greedy_generate_metal_kernel(
total_tokens = len(token_ids) + max_new + BLOCK_SIZE
num_blocks = (total_tokens + BLOCK_SIZE - 1) // BLOCK_SIZE + 4

mps_cache = MPSPagedKVCache(
metal_cache = MetalPagedKVCache(
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
Expand All @@ -124,7 +114,7 @@ def _greedy_generate_metal_kernel(
dtype=_test_infer_paged_kv_dtype(model),
)

n_patched = patch_model_attention_metal_kernel(model, mps_cache, BLOCK_SIZE)
n_patched = patch_model_attention_metal_kernel(model, metal_cache, BLOCK_SIZE)
assert n_patched == num_layers

# Assign block IDs for this sequence (manual allocation)
Expand Down Expand Up @@ -228,15 +218,15 @@ def test_batched_decode_matches(self, qwen3_model):
)
num_blocks = ((total_max + BLOCK_SIZE - 1) // BLOCK_SIZE) * len(prompts) + 8

mps_cache = MPSPagedKVCache(
metal_cache = MetalPagedKVCache(
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
num_blocks=num_blocks,
block_size=BLOCK_SIZE,
dtype=_test_infer_paged_kv_dtype(model),
)
patch_model_attention_metal_kernel(model, mps_cache, BLOCK_SIZE)
patch_model_attention_metal_kernel(model, metal_cache, BLOCK_SIZE)

# Prefill each prompt
all_token_ids = []
Expand Down Expand Up @@ -305,15 +295,15 @@ def test_patch_replaces_self_attn(self, qwen3_model):
model, _ = qwen3_model
args = model.args

mps_cache = MPSPagedKVCache(
metal_cache = MetalPagedKVCache(
num_layers=args.num_hidden_layers,
num_kv_heads=args.num_key_value_heads,
head_dim=args.head_dim,
num_blocks=32,
block_size=BLOCK_SIZE,
dtype=_test_infer_paged_kv_dtype(model),
)
patch_model_attention_metal_kernel(model, mps_cache, BLOCK_SIZE)
patch_model_attention_metal_kernel(model, metal_cache, BLOCK_SIZE)

layers = model.model.layers
for i, layer in enumerate(layers):
Expand All @@ -328,15 +318,15 @@ def test_fallback_when_no_context(self, qwen3_model):
model, _ = qwen3_model
args = model.args

mps_cache = MPSPagedKVCache(
metal_cache = MetalPagedKVCache(
num_layers=args.num_hidden_layers,
num_kv_heads=args.num_key_value_heads,
head_dim=args.head_dim,
num_blocks=32,
block_size=BLOCK_SIZE,
dtype=_test_infer_paged_kv_dtype(model),
)
patch_model_attention_metal_kernel(model, mps_cache, BLOCK_SIZE)
patch_model_attention_metal_kernel(model, metal_cache, BLOCK_SIZE)

# Run forward without setting context → should use fallback
cache = make_prompt_cache(model)
Expand Down
Loading