diff --git a/README.md b/README.md
index ee1c799..07d76b0 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@ vLLM Metal is a plugin that enables vLLM to run on Apple Silicon Macs using MLX
- **MLX-accelerated inference**: faster than PyTorch MPS on Apple Silicon
- **Unified memory**: True zero-copy operations leveraging Apple Silicon's unified memory architecture
- **vLLM compatibility**: Full integration with vLLM's engine, scheduler, and OpenAI-compatible API
-- **Paged attention** *(experimental)*: Efficient KV cache management for long sequences — opt-in via `VLLM_METAL_USE_PAGED_ATTENTION=1` (requires `pip install 'vllm-metal[paged]'`); default path uses MLX-managed KV cache
+- **Paged attention** *(experimental)*: Efficient KV cache management for long sequences — opt-in via `VLLM_METAL_USE_PAGED_ATTENTION=1`; default path uses MLX-managed KV cache. When enabled, expect significantly better serving performance (~82x TTFT, ~3.75x throughput in early benchmarks on Qwen3-0.6B). Other models may have rough edges.
- **GQA support**: Grouped-Query Attention for efficient inference
## Requirements
@@ -95,14 +95,13 @@ Environment variables for customization:
| `VLLM_METAL_USE_MLX` | `1` | Use MLX for compute (1=yes, 0=no) |
| `VLLM_MLX_DEVICE` | `gpu` | MLX device (`gpu` or `cpu`) |
| `VLLM_METAL_BLOCK_SIZE` | `16` | KV cache block size |
-| `VLLM_METAL_USE_PAGED_ATTENTION` | `0` | Enable experimental paged KV cache (requires `pip install 'vllm-metal[paged]'`) |
+| `VLLM_METAL_USE_PAGED_ATTENTION` | `0` | Enable experimental paged KV cache |
| `VLLM_METAL_DEBUG` | `0` | Enable debug logging |
| `VLLM_USE_MODELSCOPE` | `False` | Set True to change model registry to |
| `VLLM_METAL_MODELSCOPE_CACHE` | None | Specify the absolute path of the local model |
| `VLLM_METAL_PREFIX_CACHE` | (unset) | Set to enable prefix caching for shared prompt reuse |
| `VLLM_METAL_PREFIX_CACHE_FRACTION` | `0.05` | Fraction of MLX working set for prefix cache (0, 1] |
-
## Paged KV vs MLX KV memory settings
- MLX path (`VLLM_METAL_USE_PAGED_ATTENTION=0`): `VLLM_METAL_MEMORY_FRACTION` must be `auto`.
@@ -115,3 +114,7 @@ Environment variables for customization:
`auto` | `1` | Yes | Paged KV path; defaults to 0.9 internally
`0.7` | `1` | Yes | Paged KV path with explicit memory budget
`0.7` | `0` | No | Explicit fraction without paged KV is invalid
+
+## Acknowledgements
+
+- The Metal paged attention kernels are currently adapted from [mistral.rs](https://github.com/EricLBuehler/mistral.rs) (MIT license), via [HuggingFace kernels-community](https://github.com/huggingface/kernels-community). We plan to develop custom kernels in the future.
diff --git a/pyproject.toml b/pyproject.toml
index 763c53d..846e42f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,16 +35,14 @@ dependencies = [
"transformers>=4.40.0",
"accelerate>=0.26.0",
"safetensors>=0.4.0",
+ # Native Metal extension JIT build
+ "nanobind>=2.0.0; platform_system == 'Darwin' and platform_machine == 'arm64'",
# Core utilities
"numpy>=1.24.0",
"psutil>=5.9.0",
]
[project.optional-dependencies]
-paged = [
- # Paged attention Metal kernel (opt-in, experimental)
- "kernels>=0.4.5; platform_system == 'Darwin' and platform_machine == 'arm64'",
-]
vllm = ["vllm>=0.14.0"]
stt = [
# Speech-to-text audio processing (Whisper models)
@@ -58,7 +56,7 @@ dev = [
"mypy>=1.19.1",
]
all = [
- "vllm-metal[vllm,paged,stt,dev]",
+ "vllm-metal[vllm,stt,dev]",
]
[project.urls]
diff --git a/tests/test_kernel_loader.py b/tests/test_kernel_loader.py
deleted file mode 100644
index 0872c8a..0000000
--- a/tests/test_kernel_loader.py
+++ /dev/null
@@ -1,128 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-"""Tests for kernel_loader: OS-aware revision pinning for Metal compatibility.
-
-Verifies that:
-- macOS 16+ uses the latest HF kernel (default revision)
-- macOS 15 and earlier pins to the Nov 2025 compat revision (Metal 3.2)
-- Both revisions actually load and expose the expected ops
-
-Run with:
- python -m pytest tests/test_kernel_loader.py -v -s
-"""
-
-from __future__ import annotations
-
-from unittest import mock
-
-import pytest
-
-pytest.importorskip("kernels")
-
-# ---------------------------------------------------------------------------
-# Unit tests (no network, no GPU)
-# ---------------------------------------------------------------------------
-
-
-class TestNeedsCompatRevision:
- """Test _needs_compat_revision() with mocked macOS versions."""
-
- @pytest.mark.parametrize(
- "ver, expected",
- [
- ("15.7.4", True), # macOS 15 — needs compat
- ("14.5", True), # macOS 14 — needs compat
- ("26.3", False), # macOS 26 — modern
- ("", False), # empty — safe default
- ],
- )
- def test_version_check(self, ver, expected):
- from vllm_metal.metal_kernel_backend.kernel_loader import _needs_compat_revision
-
- with mock.patch("platform.mac_ver", return_value=(ver, ("", "", ""), "")):
- assert _needs_compat_revision() is expected
-
-
-class TestGetKernelRevisionSelection:
- """Test that get_paged_attention_ops passes the right revision to get_kernel."""
-
- def _reset_kernel_cache(self):
- import vllm_metal.metal_kernel_backend.kernel_loader as kl
-
- kl._kernel = None
-
- def test_macos_15_uses_compat_revision(self):
- self._reset_kernel_cache()
- with (
- mock.patch("platform.mac_ver", return_value=("15.7.4", ("", "", ""), "")),
- mock.patch("kernels.get_kernel", return_value=mock.MagicMock()) as mk,
- ):
- from vllm_metal.metal_kernel_backend.kernel_loader import (
- _MACOS15_COMPAT_REVISION,
- get_paged_attention_ops,
- )
-
- get_paged_attention_ops()
- mk.assert_called_once_with(
- "kernels-community/paged-attention",
- revision=_MACOS15_COMPAT_REVISION,
- )
- self._reset_kernel_cache()
-
- def test_macos_26_uses_latest(self):
- self._reset_kernel_cache()
- with (
- mock.patch("platform.mac_ver", return_value=("26.3", ("", "", ""), "")),
- mock.patch("kernels.get_kernel", return_value=mock.MagicMock()) as mk,
- ):
- from vllm_metal.metal_kernel_backend.kernel_loader import (
- get_paged_attention_ops,
- )
-
- get_paged_attention_ops()
- mk.assert_called_once_with(
- "kernels-community/paged-attention",
- revision=None,
- )
- self._reset_kernel_cache()
-
-
-# ---------------------------------------------------------------------------
-# Integration tests (require network + MPS)
-# ---------------------------------------------------------------------------
-
-
-def _mps_available() -> bool:
- try:
- import torch
-
- return torch.backends.mps.is_available()
- except Exception:
- return False
-
-
-@pytest.mark.skipif(not _mps_available(), reason="MPS not available")
-class TestKernelLoadsForReal:
- """Actually load the kernel from HuggingFace and verify ops exist."""
-
- _EXPECTED_OPS = {"reshape_and_cache", "paged_attention_v1"}
-
- def test_latest_revision_loads(self):
- from kernels import get_kernel
-
- kernel = get_kernel("kernels-community/paged-attention")
- ops = set(dir(kernel))
- assert self._EXPECTED_OPS <= ops, f"Missing ops: {self._EXPECTED_OPS - ops}"
-
- def test_compat_revision_loads(self):
- from kernels import get_kernel
-
- from vllm_metal.metal_kernel_backend.kernel_loader import (
- _MACOS15_COMPAT_REVISION,
- )
-
- kernel = get_kernel(
- "kernels-community/paged-attention",
- revision=_MACOS15_COMPAT_REVISION,
- )
- ops = set(dir(kernel))
- assert self._EXPECTED_OPS <= ops, f"Missing ops: {self._EXPECTED_OPS - ops}"
diff --git a/tests/test_metal_kernel_paged.py b/tests/test_metal_kernel_paged.py
index 9c0c839..a1d77b6 100644
--- a/tests/test_metal_kernel_paged.py
+++ b/tests/test_metal_kernel_paged.py
@@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for Metal kernel paged attention — verifies output matches non-paged path.
-Requires ``kernels`` package with ``kernels-community/paged-attention`` support.
-
Run with:
python -m pytest tests/test_metal_kernel_paged.py -v -s
"""
@@ -16,20 +14,19 @@
try:
import mlx.core as mx
- import torch
from mlx_lm import load as mlx_lm_load
from mlx_lm.models.cache import make_prompt_cache
from vllm_metal.kv_cache_dtype import infer_kv_cache_dtype_from_model
except ImportError as exc:
pytest.skip(
- f"Metal kernel paged attention tests require mlx/torch/mlx_lm: {exc}",
+ f"Metal kernel paged attention tests require mlx/mlx_lm: {exc}",
allow_module_level=True,
)
try:
- from vllm_metal.metal_kernel_backend.cache import MPSPagedKVCache
- from vllm_metal.metal_kernel_backend.kernel_loader import get_paged_attention_ops
+ from vllm_metal.metal import get_ops
+ from vllm_metal.metal_kernel_backend.cache import MetalPagedKVCache
from vllm_metal.metal_kernel_backend.paged_attention import (
MetalKernelPagedAttentionWrapper,
patch_model_attention_metal_kernel,
@@ -42,22 +39,15 @@
)
except ImportError as exc:
pytest.skip(
- "Metal kernel paged attention tests require the vllm-metal paged backend: "
- f"{exc}. Install with: pip install 'vllm-metal[paged]'",
+ f"Metal kernel paged attention tests require vllm-metal paged backend: {exc}",
allow_module_level=True,
)
@pytest.fixture(scope="module", autouse=True)
def _paged_attention_ops_available() -> None:
- """Skip this module if the paged-attention ops cannot be loaded."""
-
- try:
- get_paged_attention_ops()
- except ImportError as exc:
- pytest.skip(str(exc))
- except Exception as exc:
- pytest.skip(f"kernels-community/paged-attention not available: {exc}")
+ """Fail early if the native paged-attention ops cannot be loaded."""
+ get_ops()
# ---------------------------------------------------------------------------
@@ -65,8 +55,8 @@ def _paged_attention_ops_available() -> None:
# ---------------------------------------------------------------------------
-def _test_infer_paged_kv_dtype(model) -> torch.dtype:
- """Test-only helper: choose a float dtype for MPSPagedKVCache.
+def _test_infer_paged_kv_dtype(model) -> mx.Dtype:
+ """Test-only helper: choose a float dtype for MetalPagedKVCache.
This is deliberately local to this test module. Production code uses
`vllm_metal.kv_cache_dtype.infer_kv_cache_dtype_from_model()`.
@@ -115,7 +105,7 @@ def _greedy_generate_metal_kernel(
total_tokens = len(token_ids) + max_new + BLOCK_SIZE
num_blocks = (total_tokens + BLOCK_SIZE - 1) // BLOCK_SIZE + 4
- mps_cache = MPSPagedKVCache(
+ metal_cache = MetalPagedKVCache(
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
@@ -124,7 +114,7 @@ def _greedy_generate_metal_kernel(
dtype=_test_infer_paged_kv_dtype(model),
)
- n_patched = patch_model_attention_metal_kernel(model, mps_cache, BLOCK_SIZE)
+ n_patched = patch_model_attention_metal_kernel(model, metal_cache, BLOCK_SIZE)
assert n_patched == num_layers
# Assign block IDs for this sequence (manual allocation)
@@ -228,7 +218,7 @@ def test_batched_decode_matches(self, qwen3_model):
)
num_blocks = ((total_max + BLOCK_SIZE - 1) // BLOCK_SIZE) * len(prompts) + 8
- mps_cache = MPSPagedKVCache(
+ metal_cache = MetalPagedKVCache(
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
@@ -236,7 +226,7 @@ def test_batched_decode_matches(self, qwen3_model):
block_size=BLOCK_SIZE,
dtype=_test_infer_paged_kv_dtype(model),
)
- patch_model_attention_metal_kernel(model, mps_cache, BLOCK_SIZE)
+ patch_model_attention_metal_kernel(model, metal_cache, BLOCK_SIZE)
# Prefill each prompt
all_token_ids = []
@@ -305,7 +295,7 @@ def test_patch_replaces_self_attn(self, qwen3_model):
model, _ = qwen3_model
args = model.args
- mps_cache = MPSPagedKVCache(
+ metal_cache = MetalPagedKVCache(
num_layers=args.num_hidden_layers,
num_kv_heads=args.num_key_value_heads,
head_dim=args.head_dim,
@@ -313,7 +303,7 @@ def test_patch_replaces_self_attn(self, qwen3_model):
block_size=BLOCK_SIZE,
dtype=_test_infer_paged_kv_dtype(model),
)
- patch_model_attention_metal_kernel(model, mps_cache, BLOCK_SIZE)
+ patch_model_attention_metal_kernel(model, metal_cache, BLOCK_SIZE)
layers = model.model.layers
for i, layer in enumerate(layers):
@@ -328,7 +318,7 @@ def test_fallback_when_no_context(self, qwen3_model):
model, _ = qwen3_model
args = model.args
- mps_cache = MPSPagedKVCache(
+ metal_cache = MetalPagedKVCache(
num_layers=args.num_hidden_layers,
num_kv_heads=args.num_key_value_heads,
head_dim=args.head_dim,
@@ -336,7 +326,7 @@ def test_fallback_when_no_context(self, qwen3_model):
block_size=BLOCK_SIZE,
dtype=_test_infer_paged_kv_dtype(model),
)
- patch_model_attention_metal_kernel(model, mps_cache, BLOCK_SIZE)
+ patch_model_attention_metal_kernel(model, metal_cache, BLOCK_SIZE)
# Run forward without setting context → should use fallback
cache = make_prompt_cache(model)
diff --git a/vllm_metal/kv_cache_dtype.py b/vllm_metal/kv_cache_dtype.py
index 365b53f..b52501e 100644
--- a/vllm_metal/kv_cache_dtype.py
+++ b/vllm_metal/kv_cache_dtype.py
@@ -2,7 +2,7 @@
"""KV cache dtype inference and policy.
The Metal paged-attention backend stores *activation* K/V tensors in an
-MPS-backed cache. Those tensors must be floating point. Some models may have
+MLX-native cache. Those tensors must be floating point. Some models may have
quantized *weights* (e.g. int8), so we must not derive the KV cache dtype from
weights without enforcing a float-only policy.
"""
@@ -12,17 +12,16 @@
from dataclasses import dataclass
from typing import Any
-import torch
+import mlx.core as mx
from vllm_metal.paged_attention_common import find_layers_and_attr
-from vllm_metal.pytorch_backend.tensor_bridge import MLX_TO_TORCH_DTYPE
-DEFAULT_KV_CACHE_DTYPE = torch.float16
-ALLOWED_KV_CACHE_DTYPES: frozenset[torch.dtype] = frozenset(
+DEFAULT_KV_CACHE_DTYPE: mx.Dtype = mx.float16
+ALLOWED_KV_CACHE_DTYPES: frozenset[mx.Dtype] = frozenset(
{
- torch.float16,
- torch.bfloat16,
- torch.float32,
+ mx.float16,
+ mx.bfloat16,
+ mx.float32,
}
)
@@ -31,18 +30,17 @@
class KvCacheDtypeInference:
"""Result of inferring the KV cache dtype from a model."""
- dtype: torch.dtype
+ dtype: mx.Dtype
warning: str | None = None
def infer_kv_cache_dtype_from_model(
- model: Any, *, default: torch.dtype = DEFAULT_KV_CACHE_DTYPE
+ model: Any, *, default: mx.Dtype = DEFAULT_KV_CACHE_DTYPE
) -> KvCacheDtypeInference:
"""Infer a float KV-cache dtype from an MLX(-LM/-VLM) model.
Policy:
- - If we can map the model's attention weight dtype to torch and it's a
- supported float dtype, use it.
+ - If the model's attention weight dtype is a supported float dtype, use it.
- Otherwise, fall back to *default* and provide a warning string the caller
may log.
"""
@@ -62,20 +60,13 @@ def infer_kv_cache_dtype_from_model(
warning=f"Cannot infer KV cache dtype from model ({exc}); using {default}",
)
- torch_dtype = MLX_TO_TORCH_DTYPE.get(mlx_dtype)
- if torch_dtype is None:
- return KvCacheDtypeInference(
- dtype=default,
- warning=f"Unsupported MLX dtype for KV cache ({mlx_dtype!r}); using {default}",
- )
-
- if torch_dtype not in ALLOWED_KV_CACHE_DTYPES:
+ if mlx_dtype not in ALLOWED_KV_CACHE_DTYPES:
return KvCacheDtypeInference(
dtype=default,
warning=(
- f"Model weight dtype {mlx_dtype!r} maps to non-float torch dtype "
- f"{torch_dtype}; using {default} for KV cache instead"
+ f"Model weight dtype {mlx_dtype!r} is not a supported float dtype; "
+ f"using {default} for KV cache instead"
),
)
- return KvCacheDtypeInference(dtype=torch_dtype)
+ return KvCacheDtypeInference(dtype=mlx_dtype)
diff --git a/vllm_metal/metal/README.md b/vllm_metal/metal/README.md
new file mode 100644
index 0000000..1d65a78
--- /dev/null
+++ b/vllm_metal/metal/README.md
@@ -0,0 +1,43 @@
+# Metal Kernel Sources
+
+This directory contains two sets of Metal paged-attention shaders, both vendored from [mistral.rs](https://github.com/EricLBuehler/mistral.rs) (MIT license).
+
+## `kernels/` — active (current)
+
+Drop-in replacement for the HuggingFace kernels-community paged-attention shaders, originally vendored from an older version of mistral.rs. This is what `paged_ops.cpp` compiles and dispatches today via MLX.
+
+| File | Purpose |
+|------|---------|
+| `utils.metal` | bfloat16 polyfill, operator overloads |
+| `float8.metal` | FP8 E4M3/E5M2 encode/decode helpers |
+| `attention/paged_attention.metal` | paged attention v1/v2 kernels |
+| `cache/reshape_and_cache.metal` | write projected K/V into block cache |
+| `cache/copy_blocks.metal` | block-level cache copy kernel |
+| `convert_fp8.metal` | FP8 precision conversion kernel |
+
+### Reference only (not compiled, kept for context)
+
+- `paged_attention.mm` — PyTorch MPS dispatch (Obj-C++), replaced by `paged_ops.cpp`
+- `cache.mm` — PyTorch MPS cache ops (Obj-C++), replaced by `paged_ops.cpp`
+
+## `kernels_v1/` — next-generation (not yet wired up)
+
+Latest Metal kernels from the mistral.rs repo. More mature than `kernels/`, with preliminary scaffolding for variable-length sequences and gpt-oss sink attention support.
+
+| File | Purpose |
+|------|---------|
+| `utils.metal` | shared types and helpers |
+| `float8.metal` | FP8 encode/decode helpers |
+| `pagedattention.metal` | paged attention kernel (restructured) |
+| `reshape_and_cache.metal` | K/V cache reshape kernel |
+| `copy_blocks.metal` | block-level cache copy kernel |
+| `gather_kv_cache.metal` | *new* — gather KV from non-contiguous blocks |
+| `kv_scale_update.metal` | *new* — KV scale update for quantised caches |
+
+## Deprecation plan
+
+Neither kernel set will persist long-term. Both are slated for deprecation once we introduce first-class variable-length kernel support, which is a prerequisite for:
+
+- Continuous batching
+- Chunked prefill
+- MQA Scorer speculative decoding
diff --git a/vllm_metal/metal/__init__.py b/vllm_metal/metal/__init__.py
new file mode 100644
index 0000000..2cbc142
--- /dev/null
+++ b/vllm_metal/metal/__init__.py
@@ -0,0 +1,96 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Native paged-attention Metal kernels dispatched through MLX.
+
+Usage::
+
+ from vllm_metal.metal import get_ops
+ ops = get_ops()
+ ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
+ ops.paged_attention_v1(out, query, key_cache, value_cache, ...)
+"""
+
+from __future__ import annotations
+
+import importlib
+import importlib.util
+import logging
+import re
+from pathlib import Path
+from types import ModuleType
+
+logger = logging.getLogger(__name__)
+
+_THIS_DIR = Path(__file__).resolve().parent
+_KERNELS_DIR = _THIS_DIR / "kernels"
+
+# Cached after first get_ops() call. The Metal shaders are JIT-compiled once
+# and held in MLX's library cache for the lifetime of the process. Editing
+# .metal source files requires restarting the Python interpreter to pick up
+# changes (the .cpp extension itself is rebuilt automatically by build.py when
+# paged_ops.cpp is newer than the .so).
+_ops_module: ModuleType | None = None
+
+
+def _read_metal_source(path: Path) -> str:
+ """Read a .metal file and strip local #include directives."""
+ text = path.read_text()
+    # Remove #include "..." for our vendored files (keep <metal_stdlib> etc.)
+ text = re.sub(r'#include\s+"[^"]*"', "", text)
+ return text
+
+
+def _build_reshape_cache_source() -> str:
+ """Concatenate utils + float8 + reshape_and_cache into a single source."""
+ parts = [
+ _read_metal_source(_KERNELS_DIR / "utils.metal"),
+ _read_metal_source(_KERNELS_DIR / "float8.metal"),
+ _read_metal_source(_KERNELS_DIR / "cache" / "reshape_and_cache.metal"),
+ ]
+ return "\n".join(parts)
+
+
+def _build_paged_attention_source() -> str:
+ """Concatenate utils + float8 + paged_attention into a single source."""
+ parts = [
+ _read_metal_source(_KERNELS_DIR / "utils.metal"),
+ _read_metal_source(_KERNELS_DIR / "float8.metal"),
+ _read_metal_source(_KERNELS_DIR / "attention" / "paged_attention.metal"),
+ ]
+ return "\n".join(parts)
+
+
+def get_ops() -> ModuleType:
+ """JIT-build and import the native paged_ops extension.
+
+ The Metal shader sources are read, pre-processed (includes inlined),
+ and passed to the C++ extension which JIT-compiles them via
+ ``mlx::core::metal::Device::get_library()``.
+
+ Returns:
+ The ``_paged_ops`` module with ``reshape_and_cache()`` and
+ ``paged_attention_v1()``.
+ """
+ global _ops_module
+ if _ops_module is not None:
+ return _ops_module
+
+ # 1. JIT-build the C++ extension if needed
+ from vllm_metal.metal.build import build
+
+ so_path = build()
+
+ # 2. Import the built extension
+ spec = importlib.util.spec_from_file_location("_paged_ops", str(so_path))
+ if spec is None or spec.loader is None:
+ raise ImportError(f"Cannot load extension from {so_path}")
+ mod = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(mod)
+
+ # 3. Initialise Metal libraries (JIT-compile shaders)
+ reshape_src = _build_reshape_cache_source()
+ paged_attn_src = _build_paged_attention_source()
+ mod.init_libraries(reshape_src, paged_attn_src)
+
+ _ops_module = mod
+ logger.info("Native paged-attention Metal kernels loaded")
+ return mod
diff --git a/vllm_metal/metal/build.py b/vllm_metal/metal/build.py
new file mode 100644
index 0000000..137e519
--- /dev/null
+++ b/vllm_metal/metal/build.py
@@ -0,0 +1,116 @@
+# SPDX-License-Identifier: Apache-2.0
+"""JIT build script for the native paged-attention Metal extension.
+
+Compiles ``paged_ops.cpp`` + nanobind into a shared library that dispatches
+Metal shaders through MLX's own command encoder.
+"""
+
+from __future__ import annotations
+
+import logging
+import subprocess
+import sysconfig
+from pathlib import Path
+
+logger = logging.getLogger(__name__)
+
+_THIS_DIR = Path(__file__).resolve().parent
+_SRC = _THIS_DIR / "paged_ops.cpp"
+_EXT_SUFFIX = sysconfig.get_config_var("EXT_SUFFIX") or ".so"
+_CACHE_DIR = Path.home() / ".cache" / "vllm-metal"
+_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+_OUT = _CACHE_DIR / f"_paged_ops{_EXT_SUFFIX}"
+
+
+def _find_package_path(name: str) -> Path:
+ """Resolve a Python package's root directory."""
+ import importlib
+
+ mod = importlib.import_module(name)
+ paths = getattr(mod, "__path__", None)
+ if paths:
+ return Path(list(paths)[0])
+ f = getattr(mod, "__file__", None)
+ if f:
+ return Path(f).parent
+ raise RuntimeError(f"Cannot locate package '{name}'")
+
+
+def needs_rebuild() -> bool:
+ """Return True if the .so is missing or older than the source."""
+ if not _OUT.exists():
+ return True
+ src_mtime = _SRC.stat().st_mtime
+ return _OUT.stat().st_mtime < src_mtime
+
+
+def build() -> Path:
+ """JIT-build the native extension, returning the path to the .so."""
+ if not needs_rebuild():
+ return _OUT
+
+ logger.info("Building native paged-attention extension ...")
+
+ py_include = sysconfig.get_paths()["include"]
+ nb_path = _find_package_path("nanobind")
+ mlx_path = _find_package_path("mlx")
+ mlx_include = mlx_path / "include"
+ mlx_lib = mlx_path / "lib"
+ metal_cpp = mlx_include / "metal_cpp"
+
+ # Verify critical paths exist
+ for p, label in [
+ (py_include, "Python include"),
+ (nb_path, "nanobind"),
+ (mlx_include, "MLX include"),
+ (mlx_lib / "libmlx.dylib", "MLX lib"),
+ ]:
+ if not Path(p).exists():
+ raise FileNotFoundError(f"{label} not found: {p}")
+
+ nb_src = nb_path / "src" / "nb_combined.cpp"
+ if not nb_src.exists():
+ raise FileNotFoundError(f"nanobind source not found: {nb_src}")
+
+ cmd = [
+ "clang++",
+ "-std=c++17",
+ "-shared",
+ "-fPIC",
+ "-O2",
+ "-fvisibility=default",
+ f"-I{py_include}",
+ f"-I{nb_path / 'include'}",
+ f"-I{nb_path / 'src'}",
+ f"-I{nb_path / 'ext' / 'robin_map' / 'include'}",
+ f"-I{mlx_include}",
+ f"-I{metal_cpp}",
+ f"-L{mlx_lib}",
+ "-lmlx",
+ "-framework",
+ "Metal",
+ "-framework",
+ "Foundation",
+ f"-Wl,-rpath,{mlx_lib}",
+ "-D_METAL_",
+ "-DACCELERATE_NEW_LAPACK",
+ "-undefined",
+ "dynamic_lookup",
+ str(nb_src),
+ str(_SRC),
+ "-o",
+ str(_OUT),
+ ]
+
+ logger.info(" %s", " ".join(cmd))
+ result = subprocess.run(cmd, capture_output=True, text=True)
+
+ if result.returncode != 0:
+ raise RuntimeError(
+ f"Failed to build paged_ops extension:\n"
+ f"stdout:\n{result.stdout}\n"
+ f"stderr:\n{result.stderr}"
+ )
+
+ logger.info("Built %s", _OUT)
+ return _OUT
diff --git a/vllm_metal/metal/kernels/README.md b/vllm_metal/metal/kernels/README.md
new file mode 100644
index 0000000..32b4195
--- /dev/null
+++ b/vllm_metal/metal/kernels/README.md
@@ -0,0 +1,17 @@
+# Metal Kernel Sources
+
+Vendored from [mistral.rs](https://github.com/EricLBuehler/mistral.rs) (MIT license), via [HuggingFace kernels-community](https://github.com/huggingface/kernels-community).
+
+## Active (used by `paged_ops.cpp` via MLX dispatch)
+
+- `utils.metal` — bfloat16 polyfill, operator overloads
+- `float8.metal` — FP8 E4M3/E5M2 encode/decode helpers
+- `attention/paged_attention.metal` — paged attention v1/v2 kernels
+- `cache/reshape_and_cache.metal` — write projected K/V into block cache
+- `cache/copy_blocks.metal` — block-level cache copy kernel
+- `convert_fp8.metal` — FP8 precision conversion kernel
+
+## Reference only (not compiled, kept for future use)
+
+- `paged_attention.mm` — PyTorch MPS dispatch (Obj-C++), replaced by `../paged_ops.cpp`
+- `cache.mm` — PyTorch MPS cache ops (Obj-C++), replaced by `../paged_ops.cpp`
diff --git a/vllm_metal/metal/kernels/attention/paged_attention.metal b/vllm_metal/metal/kernels/attention/paged_attention.metal
new file mode 100644
index 0000000..22d972d
--- /dev/null
+++ b/vllm_metal/metal/kernels/attention/paged_attention.metal
@@ -0,0 +1,1401 @@
+// Updated from MLX commit hash f70764a
+
+#include "../utils.metal"
+#include "../float8.metal"
+#include <metal_stdlib>
+#include <metal_simdgroup>
+
+using namespace metal;
+
+// ========================================== Generic vector types
+
+// A vector type to store Q, K, V elements.
+template <typename T, int VEC_SIZE> struct Vec {};
+
+// A vector type to store FP32 accumulators.
+template <typename T> struct FloatVec {};
+
+// Template vector operations.
+template <typename Acc, typename A, typename B> inline Acc mul(A a, B b);
+
+template <typename T> inline float sum(T v);
+
+template <typename T> inline float dot(T a, T b) {
+  return sum(mul<float, T, T>(a, b));
+}
+
+template <typename A, typename T> inline float dot(T a, T b) {
+  return sum(mul<A, T, T>(a, b));
+}
+
+// FP32 vector data types.
+struct Float8_ {
+ float4 x;
+ float4 y;
+};
+
+template <> struct Vec<float, 1> {
+  using Type = float;
+};
+template <> struct Vec<float, 2> {
+  using Type = float2;
+};
+template <> struct Vec<float, 4> {
+  using Type = float4;
+};
+template <> struct Vec<float, 8> {
+  using Type = Float8_;
+};
+
+template <> struct FloatVec<float> {
+  using Type = float;
+};
+template <> struct FloatVec<float2> {
+  using Type = float2;
+};
+template <> struct FloatVec<float4> {
+  using Type = float4;
+};
+template <> struct FloatVec<Float8_> {
+  using Type = Float8_;
+};
+
+template <> inline float mul<float, float>(float a, float b) { return a * b; }
+
+template <> inline float2 mul<float2, float2>(float2 a, float2 b) { return a * b; }
+
+template <> inline float4 mul<float4, float4>(float4 a, float4 b) { return a * b; }
+
+template <> inline Float8_ mul<Float8_, Float8_>(Float8_ a, Float8_ b) {
+ Float8_ c;
+ c.x = a.x * b.x;
+ c.y = a.y * b.y;
+ return c;
+}
+
+template <> inline float sum(float a) { return a; }
+
+template <> inline float sum(float2 a) { return a.x + a.y; }
+
+template <> inline float sum(float4 a) { return a.x + a.y + a.z + a.w; }
+
+template <> inline float sum(Float8_ a) { return sum(a.x) + sum(a.y); }
+
+inline Float8_ fma(Float8_ a, Float8_ b, Float8_ c) {
+ Float8_ res;
+ res.x = fma(a.x, b.x, c.x);
+ res.y = fma(a.y, b.y, c.y);
+ return res;
+}
+
+inline void from_float(thread float &dst, float src) { dst = src; }
+inline void from_float(thread float2 &dst, float2 src) { dst = src; }
+inline void from_float(thread float4 &dst, float4 src) { dst = src; }
+inline void from_float(thread Float8_ &dst, Float8_ src) { dst = src; }
+
+// BF16 vector data types.
+// #if defined(__HAVE_BFLOAT__)
+
+// struct Bfloat8_ {
+// bfloat4 x;
+// bfloat4 y;
+// };
+
+// template<>
+// struct Vec {
+// using Type = bfloat;
+// };
+// template<>
+// struct Vec {
+// using Type = bfloat2;
+// };
+// template<>
+// struct Vec {
+// using Type = bfloat4;
+// };
+// template<>
+// struct Vec {
+// using Type = Bfloat8_;
+// };
+
+// template<>
+// struct FloatVec {
+// using Type = float;
+// };
+// template<>
+// struct FloatVec {
+// using Type = float2;
+// };
+// template<>
+// struct FloatVec {
+// using Type = float4;
+// };
+// template<>
+// struct FloatVec {
+// using Type = Float8_;
+// };
+
+// template<>
+// inline float mul(bfloat a, bfloat b) {
+// return (float)a * (float)b;
+// }
+// template<>
+// inline bfloat mul(bfloat a, bfloat b) {
+// return a*b;
+// }
+
+// template<>
+// inline float2 mul(bfloat2 a, bfloat2 b) {
+// return (float2)a * (float2)b;
+// }
+// template<>
+// inline bfloat2 mul(bfloat2 a, bfloat2 b) {
+// return a * b;
+// }
+
+// template<>
+// inline float4 mul(bfloat4 a, bfloat4 b) {
+// return (float4)a * (float4)b;
+// }
+// template<>
+// inline bfloat4 mul(bfloat4 a, bfloat4 b) {
+// return a * b;
+// }
+
+// template<>
+// inline Float8_ mul(Bfloat8_ a, Bfloat8_ b) {
+// Float8_ c;
+// c.x = mul(a.x, b.x);
+// c.y = mul(a.y, b.y);
+// return c;
+// }
+// template<>
+// inline Bfloat8_ mul(Bfloat8_ a, Bfloat8_ b) {
+// Bfloat8_ c;
+// c.x = mul(a.x, b.x);
+// c.y = mul(a.y, b.y);
+// return c;
+// }
+
+// template<>
+// inline float sum(bfloat a) {
+// return (float)a;
+// }
+
+// template<>
+// inline float sum(bfloat2 a) {
+// return (float)a.x + (float)a.y;
+// }
+
+// template<>
+// inline float sum(bfloat4 a) {
+// return sum(a.x) + sum(a.y);
+// }
+
+// template<>
+// inline float sum(Bfloat8_ a) {
+// return sum(a.x) + sum(a.y);
+// }
+
+// inline float fma(bfloat a, bfloat b, float c) {
+// return (float)a * (float)b + c;
+// }
+
+// inline float2 fma(bfloat2 a, bfloat2 b, float2 c) {
+// return (float2)a * (float2)b + c;
+// }
+
+// inline float4 fma(bfloat4 a, bfloat4 b, float4 c) {
+// return (float4)a * (float4)b + c;
+// }
+
+// inline Float8_ fma(Bfloat8_ a, Bfloat8_ b, Float8_ c) {
+// Float8_ res;
+// res.x = fma((float4)a.x, (float4)b.x, (float4)c.x);
+// res.y = fma((float4)a.y, (float4)b.y, (float4)c.y);
+// return res;
+// }
+// inline Bfloat8_ fma(Bfloat8_ a, Bfloat8_ b, Bfloat8_ c) {
+// Bfloat8_ res;
+// res.x = (bfloat4)fma((float4)a.x, (float4)b.x, (float4)c.x);
+// res.y = (bfloat4)fma((float4)a.y, (float4)b.x, (float4)c.y);
+// return c;
+// }
+
+// inline void from_float(thread bfloat& dst, float src) {
+// dst = static_cast(src);
+// }
+// inline void from_float(thread bfloat2& dst, float2 src) {
+// dst.x = static_cast(src.x);
+// dst.y = static_cast(src.y);
+// }
+// inline void from_float(thread bfloat4& dst, float4 src) {
+// dst.x = static_cast(src.x);
+// dst.y = static_cast(src.y);
+// dst.z = static_cast(src.z);
+// dst.w = static_cast(src.w);
+// }
+// inline void from_float(thread Bfloat8_& dst, Float8_ src) {
+// bfloat4 x;
+// bfloat4 y;
+// from_float(x, src.x);
+// from_float(y, src.y);
+// dst.x = x;
+// dst.y = y;
+// }
+
+// #else
+
+struct Bfloat2_ {
+ bfloat16_t x;
+ bfloat16_t y;
+};
+
+struct Bfloat4_ {
+ Bfloat2_ x;
+ Bfloat2_ y;
+};
+
+struct Bfloat8_ {
+ Bfloat4_ x;
+ Bfloat4_ y;
+};
+
+template <> struct Vec<bfloat16_t, 1> {
+  using Type = bfloat16_t;
+};
+template <> struct Vec<bfloat16_t, 2> {
+  using Type = Bfloat2_;
+};
+template <> struct Vec<bfloat16_t, 4> {
+  using Type = Bfloat4_;
+};
+template <> struct Vec<bfloat16_t, 8> {
+  using Type = Bfloat8_;
+};
+
+template <> struct FloatVec<bfloat16_t> {
+  using Type = float;
+};
+template <> struct FloatVec<Bfloat2_> {
+  using Type = float2;
+};
+template <> struct FloatVec<Bfloat4_> {
+  using Type = float4;
+};
+template <> struct FloatVec<Bfloat8_> {
+  using Type = Float8_;
+};
+
+template <> inline float mul<float, bfloat16_t>(bfloat16_t a, bfloat16_t b) {
+  return (float)a * (float)b;
+}
+template <> inline bfloat16_t mul<bfloat16_t, bfloat16_t>(bfloat16_t a, bfloat16_t b) { return a * b; }
+
+template <> inline float2 mul<float2, Bfloat2_>(Bfloat2_ a, Bfloat2_ b) {
+  float2 a_f((float)a.x, (float)a.y);
+  float2 b_f((float)b.x, (float)b.y);
+  return a_f * b_f;
+}
+template <> inline Bfloat2_ mul<Bfloat2_, Bfloat2_>(Bfloat2_ a, Bfloat2_ b) {
+  Bfloat2_ c;
+  c.x = a.x * b.x;
+  c.y = a.y * b.y;
+  return c;
+}
+
+template <> inline float4 mul<float4, Bfloat4_>(Bfloat4_ a, Bfloat4_ b) {
+  float2 x = mul<float2, Bfloat2_, Bfloat2_>(a.x, b.x);
+  float2 y = mul<float2, Bfloat2_, Bfloat2_>(a.y, b.y);
+  float4 c;
+  c.x = x.x;
+  c.y = x.y;
+  c.z = y.x;
+  c.w = y.y;
+  return c;
+}
+template <> inline Bfloat4_ mul<Bfloat4_, Bfloat4_>(Bfloat4_ a, Bfloat4_ b) {
+  Bfloat4_ c;
+  c.x = mul<Bfloat2_, Bfloat2_, Bfloat2_>(a.x, b.x);
+  c.y = mul<Bfloat2_, Bfloat2_, Bfloat2_>(a.y, b.y);
+  return c;
+}
+
+template <> inline Float8_ mul<Float8_, Bfloat8_>(Bfloat8_ a, Bfloat8_ b) {
+  Float8_ c;
+  c.x = mul<float4, Bfloat4_, Bfloat4_>(a.x, b.x);
+  c.y = mul<float4, Bfloat4_, Bfloat4_>(a.y, b.y);
+  return c;
+}
+template <> inline Bfloat8_ mul<Bfloat8_, Bfloat8_>(Bfloat8_ a, Bfloat8_ b) {
+  Bfloat8_ c;
+  c.x = mul<Bfloat4_, Bfloat4_, Bfloat4_>(a.x, b.x);
+  c.y = mul<Bfloat4_, Bfloat4_, Bfloat4_>(a.y, b.y);
+  return c;
+}
+
+template <> inline float sum(bfloat16_t a) { return (float)a; }
+
+template <> inline float sum(Bfloat2_ a) { return (float)a.x + (float)a.y; }
+
+template <> inline float sum(Bfloat4_ a) { return sum(a.x) + sum(a.y); }
+
+template <> inline float sum(Bfloat8_ a) { return sum(a.x) + sum(a.y); }
+
+inline float fma(bfloat16_t a, bfloat16_t b, float c) {
+ return (float)a * (float)b + c;
+}
+inline bfloat16_t fma(bfloat16_t a, bfloat16_t b, bfloat16_t c) {
+ return a * b + c;
+}
+
+inline float2 fma(Bfloat2_ a, Bfloat2_ b, float2 c) {
+ float2 a_f((float)a.x, (float)a.y);
+ float2 b_f((float)b.x, (float)b.y);
+ return a_f * b_f + c;
+}
+inline Bfloat2_ fma(Bfloat2_ a, Bfloat2_ b, Bfloat2_ c) {
+ Bfloat2_ res;
+ res.x = a.x * b.x + c.x;
+ res.y = a.y * b.y + c.y;
+ return res;
+}
+
+inline float4 fma(Bfloat4_ a, Bfloat4_ b, float4 c) {
+ float4 res;
+ res.x = fma(a.x.x, b.x.x, c.x);
+ res.y = fma(a.x.y, b.x.y, c.y);
+ res.z = fma(a.y.x, b.y.x, c.z);
+ res.w = fma(a.y.y, b.y.y, c.w);
+ return res;
+}
+inline Bfloat4_ fma(Bfloat4_ a, Bfloat4_ b, Bfloat4_ c) {
+ Bfloat4_ res;
+ res.x = fma(a.x, b.x, c.x);
+ res.y = fma(a.y, b.y, c.y);
+ return res;
+}
+
+inline Float8_ fma(Bfloat8_ a, Bfloat8_ b, Float8_ c) {
+ float4 x = fma(a.x, b.x, c.x);
+ float4 y = fma(a.y, b.y, c.y);
+ Float8_ res;
+ res.x = x;
+ res.y = y;
+ return res;
+}
+inline Bfloat8_ fma(Bfloat8_ a, Bfloat8_ b, Bfloat8_ c) {
+ Bfloat8_ res;
+ res.x = fma(a.x, b.x, c.x);
+ res.y = fma(a.y, b.y, c.y);
+ return res;
+}
+
+inline void from_float(thread bfloat16_t &dst, float src) {
+ dst = static_cast(src);
+}
+inline void from_float(thread Bfloat2_ &dst, float2 src) {
+ dst.x = static_cast(src.x);
+ dst.y = static_cast(src.y);
+}
+inline void from_float(thread Bfloat4_ &dst, float4 src) {
+ dst.x.x = static_cast(src.x);
+ dst.x.y = static_cast(src.y);
+ dst.y.x = static_cast(src.z);
+ dst.y.y = static_cast(src.w);
+}
+inline void from_float(thread Bfloat8_ &dst, Float8_ src) {
+ Bfloat4_ x;
+ Bfloat4_ y;
+ from_float(x, src.x);
+ from_float(y, src.y);
+ dst.x = x;
+ dst.y = y;
+}
+
+// #endif
+
+// FP16 vector data types.
+struct Half8_ {
+  half4 x;
+  half4 y;
+};
+
+// NOTE(review): the "<...>" specialization arguments in this section were
+// stripped when the diff was captured; restored to mirror the BF16 section
+// and the Vec<T, VEC_SIZE> usage in paged_attention — confirm against
+// upstream.
+template <> struct Vec<half, 1> {
+  using Type = half;
+};
+template <> struct Vec<half, 2> {
+  using Type = half2;
+};
+template <> struct Vec<half, 4> {
+  using Type = half4;
+};
+template <> struct Vec<half, 8> {
+  using Type = Half8_;
+};
+
+// Matching float accumulator type for each FP16 vector width.
+template <> struct FloatVec<half> {
+  using Type = float;
+};
+template <> struct FloatVec<half2> {
+  using Type = float2;
+};
+template <> struct FloatVec<half4> {
+  using Type = float4;
+};
+template <> struct FloatVec<Half8_> {
+  using Type = Float8_;
+};
+
+// Elementwise products; first template argument selects the accumulator.
+template <> inline float mul<float, half>(half a, half b) {
+  return (float)a * (float)b;
+}
+template <> inline half mul<half, half>(half a, half b) { return a * b; }
+
+template <> inline float2 mul<float2, half2>(half2 a, half2 b) {
+  return (float2)a * (float2)b;
+}
+template <> inline half2 mul<half2, half2>(half2 a, half2 b) { return a * b; }
+
+template <> inline float4 mul<float4, half4>(half4 a, half4 b) {
+  return (float4)a * (float4)b;
+}
+template <> inline half4 mul<half4, half4>(half4 a, half4 b) { return a * b; }
+
+template <> inline Float8_ mul<Float8_, Half8_>(Half8_ a, Half8_ b) {
+  float4 x = mul<float4, half4>(a.x, b.x);
+  float4 y = mul<float4, half4>(a.y, b.y);
+  Float8_ c;
+  c.x = x;
+  c.y = y;
+  return c;
+}
+template <> inline Half8_ mul<Half8_, Half8_>(Half8_ a, Half8_ b) {
+  Half8_ c;
+  c.x = mul<half4, half4>(a.x, b.x);
+  c.y = mul<half4, half4>(a.y, b.y);
+  return c;
+}
+
+// Horizontal reductions to float.
+template <> inline float sum<half>(half a) { return (float)a; }
+
+template <> inline float sum<half2>(half2 a) { return (float)a.x + (float)a.y; }
+
+template <> inline float sum<half4>(half4 a) { return a.x + a.y + a.z + a.w; }
+
+template <> inline float sum<Half8_>(Half8_ a) { return sum(a.x) + sum(a.y); }
+
+// Fused multiply-add with float accumulation (plain overloads).
+inline float fma(half a, half b, float c) { return (float)a * (float)b + c; }
+
+inline float2 fma(half2 a, half2 b, float2 c) {
+  return (float2)a * (float2)b + c;
+}
+
+inline float4 fma(half4 a, half4 b, float4 c) {
+  return (float4)a * (float4)b + c;
+}
+
+inline Float8_ fma(Half8_ a, Half8_ b, Float8_ c) {
+  float4 x = fma(a.x, b.x, c.x);
+  float4 y = fma(a.y, b.y, c.y);
+  Float8_ res;
+  res.x = x;
+  res.y = y;
+  return res;
+}
+inline Half8_ fma(Half8_ a, Half8_ b, Half8_ c) {
+  Half8_ res;
+  res.x = fma(a.x, b.x, c.x);
+  res.y = fma(a.y, b.y, c.y);
+  return res;
+}
+
+// float -> half conversions.
+inline void from_float(thread half &dst, float src) {
+  dst = static_cast<half>(src);
+}
+inline void from_float(thread half2 &dst, float2 src) {
+  dst.x = static_cast<half>(src.x);
+  dst.y = static_cast<half>(src.y);
+}
+inline void from_float(thread half4 &dst, float4 src) {
+  dst.x = static_cast<half>(src.x);
+  dst.y = static_cast<half>(src.y);
+  dst.z = static_cast<half>(src.z);
+  dst.w = static_cast<half>(src.w);
+}
+inline void from_float(thread Half8_ &dst, Float8_ src) {
+  half4 x;
+  half4 y;
+  from_float(x, src.x);
+  from_float(y, src.y);
+  dst.x = x;
+  dst.y = y;
+}
+
+// ========================================== FP8 (uchar) vector data types.
+
+// 8‑lane uchar vector – Metal only provides up to uchar4, so build our own.
+struct Uchar8_ {
+  uchar4 x;
+  uchar4 y;
+};
+
+// Vec specialisations so Vec<uchar, N>::Type resolves correctly.
+// NOTE(review): the "<...>" argument lists in this section were stripped when
+// the diff was captured; restored from usage — confirm against upstream.
+template <> struct Vec<uchar, 1> {
+  using Type = uchar;
+};
+template <> struct Vec<uchar, 2> {
+  using Type = uchar2;
+};
+template <> struct Vec<uchar, 4> {
+  using Type = uchar4;
+};
+template <> struct Vec<uchar, 8> {
+  using Type = Uchar8_;
+};
+
+// General case: not uchar
+template <typename T> inline constexpr bool is_uchar() { return false; }
+
+// Specialization: T is uchar
+template <> inline constexpr bool is_uchar<uchar>() { return true; }
+
+// Generic fallback – will fail to compile if a required specialisation is
+// missing.
+template <typename Vec, typename Quant_vec>
+inline Vec fp8_convert(const thread Quant_vec &, float scale) {
+  static_assert(sizeof(Vec) == 0, "Missing fp8_convert specialisation");
+}
+
+// ========================================== FP8 → float/half/bfloat
+// Dequantize one FP8 E4M3 byte and apply the cache scale.
+// NOTE(review): double-underscore identifiers are reserved; consider renaming.
+inline float __dequant_single(uchar v, float scale) {
+  return fp8_e4m3_to_float(v) * scale;
+}
+
+// ---- 1‑lane ----
+template <>
+inline float fp8_convert<float, uchar>(const thread uchar &in, float scale) {
+  return __dequant_single(in, scale);
+}
+template <>
+inline half fp8_convert<half, uchar>(const thread uchar &in, float scale) {
+  return half(__dequant_single(in, scale));
+}
+template <>
+inline bfloat16_t fp8_convert<bfloat16_t, uchar>(const thread uchar &in,
+                                                 float scale) {
+  return bfloat16_t(__dequant_single(in, scale));
+}
+
+// ---- 2‑lane ----
+template <>
+inline float2 fp8_convert<float2, uchar2>(const thread uchar2 &in,
+                                          float scale) {
+  return float2(__dequant_single(in.x, scale), __dequant_single(in.y, scale));
+}
+template <>
+inline half2 fp8_convert<half2, uchar2>(const thread uchar2 &in, float scale) {
+  half2 out;
+  out.x = half(__dequant_single(in.x, scale));
+  out.y = half(__dequant_single(in.y, scale));
+  return out;
+}
+template <>
+inline Bfloat2_ fp8_convert<Bfloat2_, uchar2>(const thread uchar2 &in,
+                                              float scale) {
+  Bfloat2_ out;
+  out.x = bfloat16_t(__dequant_single(in.x, scale));
+  out.y = bfloat16_t(__dequant_single(in.y, scale));
+  return out;
+}
+
+// ---- 4‑lane ----
+template <>
+inline float4 fp8_convert<float4, uchar4>(const thread uchar4 &in,
+                                          float scale) {
+  return float4(__dequant_single(in.x, scale), __dequant_single(in.y, scale),
+                __dequant_single(in.z, scale), __dequant_single(in.w, scale));
+}
+template <>
+inline half4 fp8_convert<half4, uchar4>(const thread uchar4 &in, float scale) {
+  half4 out;
+  out.x = half(__dequant_single(in.x, scale));
+  out.y = half(__dequant_single(in.y, scale));
+  out.z = half(__dequant_single(in.z, scale));
+  out.w = half(__dequant_single(in.w, scale));
+  return out;
+}
+template <>
+inline Bfloat4_ fp8_convert<Bfloat4_, uchar4>(const thread uchar4 &in,
+                                              float scale) {
+  Bfloat4_ out;
+  out.x.x = bfloat16_t(__dequant_single(in.x, scale));
+  out.x.y = bfloat16_t(__dequant_single(in.y, scale));
+  out.y.x = bfloat16_t(__dequant_single(in.z, scale));
+  out.y.y = bfloat16_t(__dequant_single(in.w, scale));
+  return out;
+}
+
+// ---- 8‑lane ----
+template <>
+inline Float8_ fp8_convert<Float8_, Uchar8_>(const thread Uchar8_ &in,
+                                             float scale) {
+  Float8_ out;
+  out.x =
+      float4(__dequant_single(in.x.x, scale), __dequant_single(in.x.y, scale),
+             __dequant_single(in.x.z, scale), __dequant_single(in.x.w, scale));
+  out.y =
+      float4(__dequant_single(in.y.x, scale), __dequant_single(in.y.y, scale),
+             __dequant_single(in.y.z, scale), __dequant_single(in.y.w, scale));
+  return out;
+}
+template <>
+inline Half8_ fp8_convert<Half8_, Uchar8_>(const thread Uchar8_ &in,
+                                           float scale) {
+  Half8_ out;
+  out.x = half4(half(__dequant_single(in.x.x, scale)),
+                half(__dequant_single(in.x.y, scale)),
+                half(__dequant_single(in.x.z, scale)),
+                half(__dequant_single(in.x.w, scale)));
+  out.y = half4(half(__dequant_single(in.y.x, scale)),
+                half(__dequant_single(in.y.y, scale)),
+                half(__dequant_single(in.y.z, scale)),
+                half(__dequant_single(in.y.w, scale)));
+  return out;
+}
+template <>
+inline Bfloat8_ fp8_convert<Bfloat8_, Uchar8_>(const thread Uchar8_ &in,
+                                               float scale) {
+  Bfloat8_ out;
+  // first 4
+  out.x.x.x = bfloat16_t(__dequant_single(in.x.x, scale));
+  out.x.x.y = bfloat16_t(__dequant_single(in.x.y, scale));
+  out.x.y.x = bfloat16_t(__dequant_single(in.x.z, scale));
+  out.x.y.y = bfloat16_t(__dequant_single(in.x.w, scale));
+  // second 4
+  out.y.x.x = bfloat16_t(__dequant_single(in.y.x, scale));
+  out.y.x.y = bfloat16_t(__dequant_single(in.y.y, scale));
+  out.y.y.x = bfloat16_t(__dequant_single(in.y.z, scale));
+  out.y.y.y = bfloat16_t(__dequant_single(in.y.w, scale));
+  return out;
+}
+
+// ========================================== Dot product utilities
+
+// Q*K^T partial dot product: per-thread vector FMA chain, then a butterfly
+// reduction across the THREAD_GROUP_SIZE lanes that share one key token.
+// NOTE(review): template headers and explicit mul<> arguments were stripped
+// from this diff; reconstructed from usage — confirm against upstream.
+// TODO(EricLBuehler): optimize with vectorization
+template <int THREAD_GROUP_SIZE, typename Vec, int N>
+inline float qk_dot_(const threadgroup Vec (&q)[N], const thread Vec (&k)[N]) {
+  // Compute the parallel products for Q*K^T (treat vector lanes separately).
+  using A_vec = typename FloatVec<Vec>::Type;
+  A_vec qk_vec = mul<A_vec, Vec>(q[0], k[0]);
+#pragma unroll
+  for (int ii = 1; ii < N; ++ii) {
+    qk_vec = fma(q[ii], k[ii], qk_vec);
+  }
+
+  // Finalize the reduction across lanes.
+  float qk = sum(qk_vec);
+#pragma unroll
+  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
+    qk += simd_shuffle_xor(qk, mask);
+  }
+  return qk;
+}
+
+template <typename T, int THREAD_GROUP_SIZE> struct Qk_dot {
+  template <typename Vec, int N>
+  static inline float dot(const threadgroup Vec (&q)[N],
+                          const thread Vec (&k)[N]) {
+    return qk_dot_<THREAD_GROUP_SIZE>(q, k);
+  }
+};
+
+// ========================================== Block sum utility
+
+// Threadgroup-wide sum used by the attention softmax: reduce within each
+// simdgroup, stage one partial per simdgroup in red_smem (needs NUM_WARPS
+// floats), reduce the partials in the first simdgroup, then broadcast.
+// NOTE(review): the template header was stripped from this diff; restored
+// from the block_sum<NUM_WARPS, NUM_SIMD_LANES>(...) call sites.
+template <int NUM_WARPS, int NUM_SIMD_LANES>
+inline float block_sum(threadgroup float *red_smem, float sum, uint simd_tid,
+                       uint simd_lid) {
+  // Compute the sum per simdgroup.
+#pragma unroll
+  for (int mask = NUM_SIMD_LANES / 2; mask >= 1; mask /= 2) {
+    sum += simd_shuffle_xor(sum, mask);
+  }
+
+  // Simd leaders store the data to shared memory.
+  if (simd_lid == 0) {
+    red_smem[simd_tid] = sum;
+  }
+
+  // Make sure the data is in shared memory.
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // The warps compute the final sums.
+  if (simd_lid < NUM_WARPS) {
+    sum = red_smem[simd_lid];
+  }
+
+  // Parallel reduction inside the simd group.
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    sum += simd_shuffle_xor(sum, mask);
+  }
+
+  // Broadcast to other threads.
+  return simd_shuffle(sum, 0);
+}
+
+// ========================================== Paged Attention kernel
+
+// Integer helpers for the compile-time sizing arithmetic below.
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
+
+// Function constants, specialized by the host at pipeline-creation time.
+// use_partitioning gates the exp_sums/max_logits writes; use_alibi gates the
+// alibi_slopes read. use_fp8_scales presumably gates k_scale/v_scale binding
+// on the host side — not read in the visible kernel body; confirm upstream.
+constant bool use_partitioning [[function_constant(10)]];
+constant bool use_alibi [[function_constant(20)]];
+constant bool use_fp8_scales [[function_constant(30)]];
+
+// Paged attention kernel. One threadgroup handles one (head, sequence,
+// partition) triple; grid is (num_heads, num_seqs, max_num_partitions).
+// Template arguments:
+//   T              - compute dtype of q/out
+//   CACHE_T        - storage dtype of the KV cache (== T, or uchar for FP8)
+//   HEAD_SIZE      - head dimension
+//   BLOCK_SIZE     - tokens per KV-cache block
+//   NUM_THREADS    - threads per threadgroup
+//   NUM_SIMD_LANES - lanes per simdgroup
+//   PARTITION_SIZE - context partition length in tokens (0 = no partitioning)
+// NOTE(review): the template parameter list and all "<...>" in casts/aliases
+// were stripped from this diff; reconstructed from the body's usage —
+// confirm against upstream.
+template <typename T, typename CACHE_T, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS, int NUM_SIMD_LANES, int PARTITION_SIZE = 0>
+[[kernel]] void paged_attention(
+    device float *exp_sums
+    [[buffer(0)]], // [num_seqs, num_heads, max_num_partitions] - only used when
+                   // use_partitioning
+    device float *max_logits
+    [[buffer(1)]], // [num_seqs, num_heads, max_num_partitions] - only used when
+                   // use_partitioning
+    device T *out
+    [[buffer(2)]], // [num_seqs, num_heads, max_num_partitions, head_size]
+    device const T *q [[buffer(3)]], // [num_seqs, num_heads, head_size]
+    device const CACHE_T *k_cache
+    [[buffer(4)]], // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+    device const CACHE_T *v_cache
+    [[buffer(5)]], // [num_blocks, num_kv_heads, head_size, block_size]
+    const device float *__restrict__ k_scale
+    [[buffer(6)]], // [1] - only used when use_fp8_scales
+    const device float *__restrict__ v_scale
+    [[buffer(7)]], // [1] - only used when use_fp8_scales
+    const constant int &num_kv_heads [[buffer(8)]], // [num_heads]
+    const constant float &scale [[buffer(9)]],
+    const constant float &softcapping [[buffer(10)]],
+    device const uint32_t *block_tables
+    [[buffer(11)]], // [num_seqs, max_num_blocks_per_seq]
+    device const uint32_t *context_lens [[buffer(12)]], // [num_seqs]
+    const constant int &max_num_blocks_per_seq [[buffer(13)]],
+    device const float *alibi_slopes
+    [[buffer(14)]], // [num_heads] - only used when use_alibi
+    const constant int &q_stride [[buffer(15)]],
+    const constant int &kv_block_stride [[buffer(16)]],
+    const constant int &kv_head_stride [[buffer(17)]],
+    threadgroup char *shared_mem [[threadgroup(0)]],
+    uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
+    uint3 threadgroups_per_grid [[threadgroups_per_grid]],
+    uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]],
+    uint simd_tid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  const int seq_idx = threadgroup_position_in_grid.y;
+  const int partition_idx = threadgroup_position_in_grid.z;
+  const int max_num_partitions = threadgroups_per_grid.z;
+  const int thread_idx = thread_position_in_threadgroup.x;
+  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
+  const uint32_t context_len = context_lens[seq_idx];
+  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
+    // No work to do. Terminate the thread block.
+    return;
+  }
+
+  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
+  const int num_blocks_per_partition =
+      USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
+
+  // [start_block_idx, end_block_idx) is the range of blocks to process.
+  const int start_block_idx =
+      USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
+  const int end_block_idx =
+      MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
+  const int num_blocks = end_block_idx - start_block_idx;
+
+  // [start_token_idx, end_token_idx) is the range of tokens to process.
+  const int start_token_idx = start_block_idx * BLOCK_SIZE;
+  const int end_token_idx =
+      MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
+  const int num_tokens = end_token_idx - start_token_idx;
+
+  constexpr int THREAD_GROUP_SIZE = MAX(NUM_SIMD_LANES / BLOCK_SIZE, 1);
+  constexpr int NUM_THREAD_GROUPS =
+      NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE
+                                       // divides NUM_THREADS
+  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
+  constexpr int NUM_TOKENS_PER_THREAD_GROUP =
+      DIVIDE_ROUND_UP(BLOCK_SIZE, NUM_SIMD_LANES);
+  constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES;
+  const int warp_idx = simd_tid;
+  const int lane = simd_lid;
+
+  const int head_idx = threadgroup_position_in_grid.x;
+  const int num_heads = threadgroups_per_grid.x;
+  const int num_queries_per_kv = num_heads / num_kv_heads;
+  const int kv_head_idx = head_idx / num_queries_per_kv;
+  const float alibi_slope = !use_alibi ? 0.f : alibi_slopes[head_idx];
+
+  // A vector type to store a part of a key or a query.
+  // The vector size is configured in such a way that the threads in a thread
+  // group fetch or compute 16 bytes at a time. For example, if the size of a
+  // thread group is 4 and the data type is half, then the vector size is 16 /
+  // (4 * sizeof(half)) == 2.
+  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(T)), 1);
+  using K_vec = typename Vec<T, VEC_SIZE>::Type;
+  using Q_vec = typename Vec<T, VEC_SIZE>::Type;
+  using Quant_vec = typename Vec<CACHE_T, VEC_SIZE>::Type;
+
+  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
+  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
+
+  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
+  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
+
+  // Load the query to registers.
+  // Each thread in a thread group has a different part of the query.
+  // For example, if the thread group size is 4, then the first thread in the
+  // group has 0, 4, 8, ... th vectors of the query, and the second thread has
+  // 1, 5, 9, ... th vectors of the query, and so on.
+  const device T *q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
+  threadgroup Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
+#pragma unroll
+  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
+       i += NUM_THREAD_GROUPS) {
+    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
+    q_vecs[thread_group_offset][i] =
+        *reinterpret_cast<const device Q_vec *>(q_ptr + vec_idx * VEC_SIZE);
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // Use fp32 on softmax logits for better accuracy
+  threadgroup float *logits = reinterpret_cast<threadgroup float *>(shared_mem);
+  // Workspace for reduction
+  threadgroup float red_smem[2 * NUM_WARPS];
+
+  // x == THREAD_GROUP_SIZE * VEC_SIZE
+  // Each thread group fetches x elements from the key at a time.
+  constexpr int x = 16 / sizeof(CACHE_T);
+  float qk_max = -FLT_MAX;
+
+  // Iterate over the key blocks.
+  // Each warp fetches a block of keys for each iteration.
+  // Each thread group in a warp fetches a key from the block, and computes
+  // dot product with the query.
+  const device uint32_t *block_table =
+      block_tables + seq_idx * max_num_blocks_per_seq;
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
+       block_idx += NUM_WARPS) {
+    // NOTE: The block number is stored in int32. However, we cast it to int64
+    // because int32 can lead to overflow when this variable is multiplied by
+    // large numbers (e.g., kv_block_stride).
+    const int64_t physical_block_number =
+        static_cast<int64_t>(block_table[block_idx]);
+
+    // Load a key to registers.
+    // Each thread in a thread group has a different part of the key.
+    // For example, if the thread group size is 4, then the first thread in the
+    // group has 0, 4, 8, ... th vectors of the key, and the second thread has
+    // 1, 5, 9, ... th vectors of the key, and so on.
+    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
+      const int physical_block_offset =
+          (thread_group_idx + i * NUM_SIMD_LANES) % BLOCK_SIZE;
+      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+      K_vec k_vecs[NUM_VECS_PER_THREAD];
+
+#pragma unroll
+      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
+        const device CACHE_T *k_ptr =
+            k_cache + physical_block_number * kv_block_stride +
+            kv_head_idx * kv_head_stride + physical_block_offset * x;
+        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
+        const int offset1 = (vec_idx * VEC_SIZE) / x;
+        const int offset2 = (vec_idx * VEC_SIZE) % x;
+
+        if constexpr (is_uchar<CACHE_T>()) {
+          // FP8 support
+          Quant_vec k_vec_quant = *reinterpret_cast<const device Quant_vec *>(
+              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+          k_vecs[j] = fp8_convert<K_vec>(k_vec_quant, *k_scale);
+        } else {
+          // Non-FP8 default
+          k_vecs[j] = *reinterpret_cast<const device K_vec *>(
+              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+        }
+      }
+
+      // Compute dot product.
+      // This includes a reduction across the threads in the same thread group.
+      float qk = scale * Qk_dot<T, THREAD_GROUP_SIZE>::dot(
+                             q_vecs[thread_group_offset], k_vecs);
+
+      // Apply softcapping
+      if (softcapping != 1.0) {
+        qk = precise::tanh(qk / softcapping) * softcapping;
+      }
+
+      // Add the ALiBi bias if slopes are given.
+      if (use_alibi && alibi_slope != 0) {
+        // Compute bias with explicit float precision to minimize precision loss
+        int position_offset = token_idx - int(context_len) + 1;
+        float alibi_bias = alibi_slope * float(position_offset);
+        qk += alibi_bias;
+      }
+
+      if (thread_group_offset == 0) {
+        // Store the partial reductions to shared memory.
+        // NOTE: It is required to zero out the masked logits.
+        const bool mask = token_idx >= context_len;
+        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
+        // Update the max value.
+        qk_max = mask ? qk_max : max(qk_max, qk);
+      }
+    }
+  }
+
+  // Perform reduction across the threads in the same warp to get the
+  // max qk value for each "warp" (not across the thread block yet).
+  // The 0-th thread of each thread group already has its max qk value.
+#pragma unroll
+  for (int mask = NUM_SIMD_LANES / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
+    qk_max = max(qk_max, simd_shuffle_xor(qk_max, mask));
+  }
+  if (lane == 0) {
+    red_smem[warp_idx] = qk_max;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // Get the max qk value for the sequence.
+  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    qk_max = max(qk_max, simd_shuffle_xor(qk_max, mask));
+  }
+  // Broadcast the max qk value to all threads.
+  qk_max = simd_shuffle(qk_max, 0);
+
+  // Get the sum of the exp values.
+  float exp_sum = 0.f;
+  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+    float val = exp(logits[i] - qk_max);
+    logits[i] = val;
+    exp_sum += val;
+  }
+  exp_sum = block_sum<NUM_WARPS, NUM_SIMD_LANES>(&red_smem[NUM_WARPS], exp_sum,
+                                                 simd_tid, simd_lid);
+
+  // Compute softmax.
+  const float inv_sum = divide(1.f, exp_sum + 1e-6f);
+  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+    logits[i] *= inv_sum;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // If partitioning is enabled, store the max logit and exp_sum.
+  if (USE_PARTITIONING && thread_idx == 0 && use_partitioning) {
+    device float *max_logits_ptr =
+        max_logits + seq_idx * num_heads * max_num_partitions +
+        head_idx * max_num_partitions + partition_idx;
+    *max_logits_ptr = qk_max;
+    device float *exp_sums_ptr = exp_sums +
+                                 seq_idx * num_heads * max_num_partitions +
+                                 head_idx * max_num_partitions + partition_idx;
+    *exp_sums_ptr = exp_sum;
+  }
+
+  // Each thread will fetch 16 bytes from the value cache at a time.
+  constexpr int V_VEC_SIZE = MIN(16 / sizeof(T), BLOCK_SIZE);
+  using V_vec = typename Vec<T, V_VEC_SIZE>::Type;
+  using L_vec = typename Vec<T, V_VEC_SIZE>::Type;
+  using Float_L_vec = typename FloatVec<L_vec>::Type;
+  using V_quant_vec = typename Vec<CACHE_T, V_VEC_SIZE>::Type;
+
+  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
+  constexpr int NUM_ROWS_PER_ITER = NUM_SIMD_LANES / NUM_V_VECS_PER_ROW;
+  constexpr int NUM_ROWS_PER_THREAD =
+      DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
+
+  // NOTE: We use FP32 for the accumulator for better accuracy.
+  float accs[NUM_ROWS_PER_THREAD];
+#pragma unroll
+  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+    accs[i] = 0.f;
+  }
+
+  T zero_value = 0;
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
+       block_idx += NUM_WARPS) {
+    // NOTE: The block number is stored in int32. However, we cast it to int64
+    // because int32 can lead to overflow when this variable is multiplied by
+    // large numbers (e.g., kv_block_stride).
+    const int64_t physical_block_number =
+        static_cast<int64_t>(block_table[block_idx]);
+    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
+    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+    L_vec logits_vec;
+    Float_L_vec logits_float_vec = *reinterpret_cast<threadgroup Float_L_vec *>(
+        logits + token_idx - start_token_idx);
+    from_float(logits_vec, logits_float_vec);
+
+    const device CACHE_T *v_ptr = v_cache +
+                                  physical_block_number * kv_block_stride +
+                                  kv_head_idx * kv_head_stride;
+#pragma unroll
+    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+      if (row_idx < HEAD_SIZE) {
+        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
+        // NOTE: When v_vec contains the tokens that are out of the context,
+        // we should explicitly zero out the values since they may contain NaNs.
+        // See
+        // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
+        V_vec v_vec;
+
+        if constexpr (is_uchar<CACHE_T>()) {
+          // FP8 support
+          V_quant_vec v_quant_vec =
+              *reinterpret_cast<const device V_quant_vec *>(v_ptr + offset);
+          v_vec = fp8_convert<V_vec>(v_quant_vec, *v_scale);
+        } else {
+          // Non-FP8 default
+          v_vec = *reinterpret_cast<const device V_vec *>(v_ptr + offset);
+        }
+
+        if (block_idx == num_context_blocks - 1) {
+          thread T *v_vec_ptr = reinterpret_cast<thread T *>(&v_vec);
+#pragma unroll
+          for (int j = 0; j < V_VEC_SIZE; j++) {
+            v_vec_ptr[j] =
+                token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
+          }
+        }
+        accs[i] += dot(logits_vec, v_vec);
+      }
+    }
+  }
+
+  // Perform reduction within each warp.
+#pragma unroll
+  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+    float acc = accs[i];
+#pragma unroll
+    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
+      acc += simd_shuffle_xor(acc, mask);
+    }
+    accs[i] = acc;
+  }
+
+  // NOTE: A barrier is required because the shared memory space for logits
+  // is reused for the output.
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // Perform reduction across warps.
+  threadgroup float *out_smem =
+      reinterpret_cast<threadgroup float *>(shared_mem);
+  // Inner loop variables renamed to "r" so they no longer shadow the outer
+  // tree-reduction counter "i".
+#pragma unroll
+  for (int i = NUM_WARPS; i > 1; i /= 2) {
+    int mid = i / 2;
+    // Upper warps write to shared memory.
+    if (warp_idx >= mid && warp_idx < i) {
+      threadgroup float *dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
+#pragma unroll
+      for (int r = 0; r < NUM_ROWS_PER_THREAD; r++) {
+        const int row_idx = lane / NUM_V_VECS_PER_ROW + r * NUM_ROWS_PER_ITER;
+        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+          dst[row_idx] = accs[r];
+        }
+      }
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // Lower warps update the output.
+    if (warp_idx < mid) {
+      const threadgroup float *src = &out_smem[warp_idx * HEAD_SIZE];
+#pragma unroll
+      for (int r = 0; r < NUM_ROWS_PER_THREAD; r++) {
+        const int row_idx = lane / NUM_V_VECS_PER_ROW + r * NUM_ROWS_PER_ITER;
+        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+          accs[r] += src[row_idx];
+        }
+      }
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
+  // Write the final output.
+  if (warp_idx == 0) {
+    device T *out_ptr =
+        out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+        head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
+#pragma unroll
+    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+        *(out_ptr + row_idx) = T(accs[i]);
+      }
+    }
+  }
+}
+
+// Reduction kernel for partitioned ("V2") paged attention: merges the
+// per-partition partial outputs in tmp_out into the final output, rescaling
+// each partition by exp(max_i - global_max) * exp_sum_i / global_exp_sum.
+// Shared memory layout: [num_partitions] max logits, then [num_partitions]
+// rescaled exp sums.
+// NOTE(review): the template parameter list and cast target types were
+// stripped from this diff; reconstructed from the body's usage — confirm
+// against upstream.
+template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
+          int PARTITION_SIZE>
+[[kernel]] void paged_attention_v2_reduce(
+    device T *out [[buffer(0)]], const device float *exp_sums [[buffer(1)]],
+    const device float *max_logits [[buffer(2)]],
+    const device T *tmp_out [[buffer(3)]],
+    device uint32_t *context_lens [[buffer(4)]],
+    const constant int &max_num_partitions [[buffer(5)]],
+    threadgroup char *shared_mem [[threadgroup(0)]],
+    uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
+    uint3 threadgroups_per_grid [[threadgroups_per_grid]],
+    uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]],
+    uint3 threads_per_threadgroup [[threads_per_threadgroup]],
+    uint simd_tid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  const int num_heads = threadgroups_per_grid.x;
+  const int head_idx = threadgroup_position_in_grid.x;
+  const int seq_idx = threadgroup_position_in_grid.y;
+  const uint32_t context_len = context_lens[seq_idx];
+  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
+  if (num_partitions == 1) {
+    // No need to reduce. Only copy tmp_out to out.
+    device T *out_ptr =
+        out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+    const device T *tmp_out_ptr =
+        tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+        head_idx * max_num_partitions * HEAD_SIZE;
+    for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE;
+         i += threads_per_threadgroup.x) {
+      out_ptr[i] = tmp_out_ptr[i];
+    }
+    // Terminate the thread block.
+    return;
+  }
+
+  constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES;
+  const int warp_idx = simd_tid;
+  const int lane = simd_lid;
+
+  // Workspace for reduction.
+  threadgroup float red_smem[2 * NUM_WARPS];
+
+  // Load max logits to shared memory.
+  threadgroup float *shared_max_logits =
+      reinterpret_cast<threadgroup float *>(shared_mem);
+  const device float *max_logits_ptr =
+      max_logits + seq_idx * num_heads * max_num_partitions +
+      head_idx * max_num_partitions;
+  float max_logit = -FLT_MAX;
+  for (int i = thread_position_in_threadgroup.x; i < num_partitions;
+       i += threads_per_threadgroup.x) {
+    const float l = max_logits_ptr[i];
+    shared_max_logits[i] = l;
+    max_logit = max(max_logit, l);
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // Get the global max logit.
+  // Reduce within the warp.
+#pragma unroll
+  for (int mask = NUM_SIMD_LANES / 2; mask >= 1; mask /= 2) {
+    max_logit = max(max_logit, simd_shuffle_xor(max_logit, mask));
+  }
+  if (lane == 0) {
+    red_smem[warp_idx] = max_logit;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  // Reduce across warps.
+  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    max_logit = max(max_logit, simd_shuffle_xor(max_logit, mask));
+  }
+  // Broadcast the max value to all threads.
+  max_logit = simd_shuffle(max_logit, 0);
+
+  // Load rescaled exp sums to shared memory.
+  threadgroup float *shared_exp_sums = reinterpret_cast<threadgroup float *>(
+      shared_mem + sizeof(float) * num_partitions);
+  const device float *exp_sums_ptr = exp_sums +
+                                     seq_idx * num_heads * max_num_partitions +
+                                     head_idx * max_num_partitions;
+  float global_exp_sum = 0.0f;
+  for (int i = thread_position_in_threadgroup.x; i < num_partitions;
+       i += threads_per_threadgroup.x) {
+    float l = shared_max_logits[i];
+    float rescaled_exp_sum = exp_sums_ptr[i] * exp(l - max_logit);
+    global_exp_sum += rescaled_exp_sum;
+    shared_exp_sums[i] = rescaled_exp_sum;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  global_exp_sum = block_sum<NUM_WARPS, NUM_SIMD_LANES>(
+      &red_smem[NUM_WARPS], global_exp_sum, simd_tid, simd_lid);
+  const float inv_global_exp_sum = divide(1.0f, global_exp_sum + 1e-6f);
+
+  // Aggregate tmp_out to out.
+  const device T *tmp_out_ptr =
+      tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+      head_idx * max_num_partitions * HEAD_SIZE;
+  device T *out_ptr =
+      out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+#pragma unroll
+  for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE;
+       i += NUM_THREADS) {
+    float acc = 0.0f;
+    for (int j = 0; j < num_partitions; ++j) {
+      acc += float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
+             inv_global_exp_sum;
+    }
+    out_ptr[i] = T(acc);
+  }
+}
+
+// Explicit instantiation of paged_attention with a stable host_name so the
+// host can look the pipeline up by dtype/head-size/block-size/etc.
+// NOTE(review): the "<...>" instantiation arguments were stripped from this
+// diff; restored from the macro parameters — confirm against upstream.
+#define instantiate_paged_attention_inner(type, cache_type, head_size,        \
+                                          block_size, num_threads,            \
+                                          num_simd_lanes, partition_size)     \
+  template [[host_name("paged_attention_" #type "_cache_" #cache_type         \
+                       "_hs" #head_size "_bs" #block_size "_nt" #num_threads  \
+                       "_nsl" #num_simd_lanes                                 \
+                       "_ps" #partition_size)]] [[kernel]] void               \
+  paged_attention<type, cache_type, head_size, block_size, num_threads,       \
+                  num_simd_lanes, partition_size>(                            \
+      device float *exp_sums [[buffer(0)]],                                   \
+      device float *max_logits [[buffer(1)]],                                 \
+      device type *out [[buffer(2)]], device const type *q [[buffer(3)]],     \
+      device const cache_type *k_cache [[buffer(4)]],                         \
+      device const cache_type *v_cache [[buffer(5)]],                         \
+      const device float *__restrict__ k_scale [[buffer(6)]],                 \
+      const device float *__restrict__ v_scale [[buffer(7)]],                 \
+      const constant int &num_kv_heads [[buffer(8)]],                         \
+      const constant float &scale [[buffer(9)]],                              \
+      const constant float &softcapping [[buffer(10)]],                       \
+      device const uint32_t *block_tables [[buffer(11)]],                     \
+      device const uint32_t *context_lens [[buffer(12)]],                     \
+      const constant int &max_num_blocks_per_seq [[buffer(13)]],              \
+      device const float *alibi_slopes [[buffer(14)]],                        \
+      const constant int &q_stride [[buffer(15)]],                            \
+      const constant int &kv_block_stride [[buffer(16)]],                     \
+      const constant int &kv_head_stride [[buffer(17)]],                      \
+      threadgroup char *shared_mem [[threadgroup(0)]],                        \
+      uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],    \
+      uint3 threadgroups_per_grid [[threadgroups_per_grid]],                  \
+      uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]],\
+      uint simd_tid [[simdgroup_index_in_threadgroup]],                       \
+      uint simd_lid [[thread_index_in_simdgroup]]);
+
+// Emits an explicit instantiation of the v2 reduction kernel under a
+// host-visible name encoding its template parameters.
+// NOTE(review): the template-argument list was absent (stripped angle
+// brackets); restored in macro-parameter order.
+#define instantiate_paged_attention_v2_reduce_inner(                         \
+    type, head_size, num_threads, num_simd_lanes, partition_size)            \
+  template [[host_name("paged_attention_v2_reduce_" #type "_hs" #head_size   \
+                       "_nt" #num_threads "_nsl" #num_simd_lanes             \
+                       "_ps" #partition_size)]] [[kernel]] void              \
+  paged_attention_v2_reduce<type, head_size, num_threads, num_simd_lanes,    \
+                            partition_size>(                                 \
+      device type * out [[buffer(0)]],                                       \
+      const device float *exp_sums [[buffer(1)]],                            \
+      const device float *max_logits [[buffer(2)]],                          \
+      const device type *tmp_out [[buffer(3)]],                              \
+      device uint32_t *context_lens [[buffer(4)]],                           \
+      const constant int &max_num_partitions [[buffer(5)]],                  \
+      threadgroup char *shared_mem [[threadgroup(0)]],                       \
+      uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],   \
+      uint3 threadgroups_per_grid [[threadgroups_per_grid]],                 \
+      uint3 thread_position_in_threadgroup                                   \
+          [[thread_position_in_threadgroup]],                                \
+      uint3 threads_per_threadgroup [[threads_per_threadgroup]],             \
+      uint simd_tid [[simdgroup_index_in_threadgroup]],                      \
+      uint simd_lid [[thread_index_in_simdgroup]]);
+
+// Instantiates the paged_attention kernel for every supported head size
+// (32, 64, 80, 96, 112, 120, 128, 192, 256) with the given element/cache
+// types and launch parameters.
+#define instantiate_paged_attention_heads( \
+    type, cache_type, block_size, num_threads, num_simd_lanes, partition_size) \
+  instantiate_paged_attention_inner(type, cache_type, 32, block_size, \
+                                    num_threads, num_simd_lanes, \
+                                    partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 64, block_size, \
+                                    num_threads, num_simd_lanes, \
+                                    partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 80, block_size, \
+                                    num_threads, num_simd_lanes, \
+                                    partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 96, block_size, \
+                                    num_threads, num_simd_lanes, \
+                                    partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 112, block_size, \
+                                    num_threads, num_simd_lanes, \
+                                    partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 120, block_size, \
+                                    num_threads, num_simd_lanes, \
+                                    partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 128, block_size, \
+                                    num_threads, num_simd_lanes, \
+                                    partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 192, block_size, \
+                                    num_threads, num_simd_lanes, \
+                                    partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 256, block_size, \
+                                    num_threads, num_simd_lanes, \
+                                    partition_size);
+
+// Instantiates the v2 reduction kernel for the same head-size list as
+// instantiate_paged_attention_heads above (the two lists must agree).
+#define instantiate_paged_attention_v2_reduce_heads( \
+    type, num_threads, num_simd_lanes, partition_size) \
+  instantiate_paged_attention_v2_reduce_inner(type, 32, num_threads, \
+                                              num_simd_lanes, partition_size); \
+  instantiate_paged_attention_v2_reduce_inner(type, 64, num_threads, \
+                                              num_simd_lanes, partition_size); \
+  instantiate_paged_attention_v2_reduce_inner(type, 80, num_threads, \
+                                              num_simd_lanes, partition_size); \
+  instantiate_paged_attention_v2_reduce_inner(type, 96, num_threads, \
+                                              num_simd_lanes, partition_size); \
+  instantiate_paged_attention_v2_reduce_inner(type, 112, num_threads, \
+                                              num_simd_lanes, partition_size); \
+  instantiate_paged_attention_v2_reduce_inner(type, 120, num_threads, \
+                                              num_simd_lanes, partition_size); \
+  instantiate_paged_attention_v2_reduce_inner(type, 128, num_threads, \
+                                              num_simd_lanes, partition_size); \
+  instantiate_paged_attention_v2_reduce_inner(type, 192, num_threads, \
+                                              num_simd_lanes, partition_size); \
+  instantiate_paged_attention_v2_reduce_inner(type, 256, num_threads, \
+                                              num_simd_lanes, partition_size);
+
+// Instantiates every head size for each supported KV-cache block size
+// (8, 16, 32 tokens per block).
+#define instantiate_paged_attention_block_size(type, cache_type, num_threads, \
+                                               num_simd_lanes, partition_size) \
+  instantiate_paged_attention_heads(type, cache_type, 8, num_threads, \
+                                    num_simd_lanes, partition_size); \
+  instantiate_paged_attention_heads(type, cache_type, 16, num_threads, \
+                                    num_simd_lanes, partition_size); \
+  instantiate_paged_attention_heads(type, cache_type, 32, num_threads, \
+                                    num_simd_lanes, partition_size);
+
+// v1 entry points: partition_size = 0, i.e. each sequence is handled in a
+// single pass with no cross-partition reduction.
+// TODO: tune num_threads = 256
+#define instantiate_paged_attention_v1(type, cache_type, num_simd_lanes) \
+  instantiate_paged_attention_block_size(type, cache_type, 256, \
+                                         num_simd_lanes, 0);
+
+// v2 entry points: partition_size = 512 — sequences are split into
+// partitions whose partial results are combined by paged_attention_v2_reduce.
+// TODO: tune num_threads = 256
+#define instantiate_paged_attention_v2(type, cache_type, num_simd_lanes) \
+  instantiate_paged_attention_block_size(type, cache_type, 256, \
+                                         num_simd_lanes, 512);
+
+// v2 reduction entry points; partition_size must match the v2 kernels (512).
+// TODO: tune num_threads = 256
+#define instantiate_paged_attention_v2_reduce(type, num_simd_lanes) \
+  instantiate_paged_attention_v2_reduce_heads(type, 256, num_simd_lanes, 512);
+
+// v1 (single-pass) kernels: matching compute/cache dtypes.
+instantiate_paged_attention_v1(float, float, 32);
+instantiate_paged_attention_v1(bfloat16_t, bfloat16_t, 32);
+instantiate_paged_attention_v1(half, half, 32);
+
+// v1 kernels with byte-typed (uchar, FP8) caches.
+instantiate_paged_attention_v1(float, uchar, 32);
+instantiate_paged_attention_v1(bfloat16_t, uchar, 32);
+instantiate_paged_attention_v1(half, uchar, 32);
+
+// v2 reduction kernels, one per output dtype.
+instantiate_paged_attention_v2_reduce(float, 32);
+instantiate_paged_attention_v2_reduce(bfloat16_t, 32);
+instantiate_paged_attention_v2_reduce(half, 32);
+
+// v2 (two-pass) kernels: matching compute/cache dtypes.
+instantiate_paged_attention_v2(float, float, 32);
+instantiate_paged_attention_v2(bfloat16_t, bfloat16_t, 32);
+instantiate_paged_attention_v2(half, half, 32);
+
+// v2 kernels with byte-typed (uchar, FP8) caches.
+instantiate_paged_attention_v2(float, uchar, 32);
+instantiate_paged_attention_v2(bfloat16_t, uchar, 32);
+instantiate_paged_attention_v2(half, uchar, 32);
diff --git a/vllm_metal/metal/kernels/cache.mm b/vllm_metal/metal/kernels/cache.mm
new file mode 100644
index 0000000..cf67260
--- /dev/null
+++ b/vllm_metal/metal/kernels/cache.mm
@@ -0,0 +1,562 @@
+// NOTE(review): the original header names were lost (angle-bracket text
+// stripped); the set below is reconstructed from usage in this file —
+// verify against the build.
+#include <algorithm> // std::min
+#include <string>
+#include <dlfcn.h> // dladdr, used to locate the bundled metallib
+
+#import <Foundation/Foundation.h>
+#import <Metal/Metal.h>
+
+#include <torch/extension.h> // torch::Tensor, TORCH_CHECK
+#include <ATen/mps/MPSDevice.h>
+#include <ATen/mps/MPSStream.h> // at::mps::getCurrentMPSStream
+
+// Reinterpret a torch MPS tensor's opaque storage pointer as the MTLBuffer
+// that backs it.
+// NOTE(review): id<MTLBuffer> restored from stripped markup — this is the
+// standard PyTorch MPS extension pattern; confirm against the original.
+static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) {
+  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
+}
+
+// Returns the directory containing the shared object this function is
+// compiled into (used to find the bundled .metallib next to the extension
+// binary). Falls back to "." when the path cannot be resolved.
+static std::string getModuleDirectory() {
+  Dl_info info;
+  if (dladdr(reinterpret_cast<void *>(&getModuleDirectory), &info) != 0) {
+    const std::string full_path(info.dli_fname);
+    const size_t last_slash = full_path.find_last_of('/');
+    if (last_slash != std::string::npos) {
+      return full_path.substr(0, last_slash);
+    }
+  }
+  return ".";
+}
+
+// Copies whole KV-cache blocks between src and dst according to
+// block_mapping (shape [num_blocks, 2]: column 0 = src block id,
+// column 1 = dst block id; must live on CPU). MPS<->MPS copies are encoded
+// as Metal blit commands on the current MPS stream; any other device
+// combination falls back to per-block PyTorch tensor copies.
+void swap_blocks(torch::Tensor &src, torch::Tensor &dst,
+                 const torch::Tensor &block_mapping) {
+  TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
+
+  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
+  const int64_t num_blocks = block_mapping.size(0);
+
+  if (src.device().is_mps() && dst.device().is_mps()) {
+    // MPS to MPS: use a Metal blit encoder.
+    @autoreleasepool {
+      at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
+      TORCH_CHECK(stream, "Failed to get current MPS stream");
+
+      id<MTLCommandBuffer> commandBuffer = stream->commandBuffer();
+      TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");
+
+      dispatch_queue_t serialQueue = stream->queue();
+
+      dispatch_sync(serialQueue, ^{
+        id<MTLBlitCommandEncoder> blitEncoder =
+            [commandBuffer blitCommandEncoder];
+        TORCH_CHECK(blitEncoder, "Failed to create blit command encoder");
+
+        id<MTLBuffer> srcBuf = getMTLBufferStorage(src);
+        id<MTLBuffer> dstBuf = getMTLBufferStorage(dst);
+
+        for (int64_t i = 0; i < num_blocks; ++i) {
+          // NOTE(review): .item<T>() template arguments were missing
+          // (stripped markup); restored as int64_t block ids.
+          int64_t src_block_number = block_mapping[i][0].item<int64_t>();
+          int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
+          NSUInteger src_offset = src_block_number * block_size_in_bytes;
+          NSUInteger dst_offset = dst_block_number * block_size_in_bytes;
+
+          [blitEncoder copyFromBuffer:srcBuf
+                         sourceOffset:src_offset
+                             toBuffer:dstBuf
+                    destinationOffset:dst_offset
+                                 size:block_size_in_bytes];
+        }
+
+        [blitEncoder endEncoding];
+        stream->synchronize(at::mps::SyncType::COMMIT);
+      });
+    }
+  } else {
+    // Cross-device transfers (MPS-CPU, CPU-MPS, CPU-CPU): PyTorch copies.
+    for (int64_t i = 0; i < num_blocks; ++i) {
+      int64_t src_block_number = block_mapping[i][0].item<int64_t>();
+      int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
+
+      // Copy the entire block.
+      dst[dst_block_number].copy_(src[src_block_number]);
+    }
+  }
+}
+
+// Batched block-to-block copies inside each layer's key/value caches, driven
+// by the copy_blocks_* Metal kernels. block_mapping has shape
+// [num_pairs, 2] with (src_block, dst_block) int64 pairs; all cache tensors
+// must be on the same MPS device.
+void copy_blocks(const std::vector<torch::Tensor> &key_caches,
+                 const std::vector<torch::Tensor> &value_caches,
+                 const torch::Tensor &block_mapping) {
+  const int64_t num_layers = key_caches.size();
+  TORCH_CHECK(num_layers == static_cast<int64_t>(value_caches.size()),
+              "key_caches and value_caches must have the same length");
+  if (num_layers == 0) {
+    return;
+  }
+
+  // --- Preconditions --------------------------------------------------
+  torch::Device dev = key_caches[0].device();
+  TORCH_CHECK(dev.is_mps(), "copy_blocks: expected MPS tensors");
+
+  // The mapping is read on the host to build the kernel argument buffer.
+  torch::Tensor block_mapping_cpu = block_mapping;
+  if (block_mapping.device().is_mps()) {
+    block_mapping_cpu = block_mapping.cpu();
+  }
+
+  for (int64_t i = 0; i < num_layers; ++i) {
+    TORCH_CHECK(key_caches[i].device() == dev &&
+                    value_caches[i].device() == dev,
+                "All cache tensors must be on the same MPS device");
+    TORCH_CHECK(key_caches[i].dtype() == value_caches[i].dtype(),
+                "Key/value cache dtype mismatch at layer ", i);
+  }
+
+  const int64_t num_pairs = block_mapping.size(0);
+  const int32_t numel_per_block =
+      static_cast<int32_t>(key_caches[0][0].numel());
+
+  @autoreleasepool {
+    at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
+    TORCH_CHECK(stream, "Failed to get current MPS stream");
+
+    id<MTLDevice> device = stream->device();
+    id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
+    TORCH_CHECK(cmdBuf, "Failed to get command buffer");
+
+    // Locate the pre-compiled metallib next to this extension binary.
+    std::string moduleDir = getModuleDirectory();
+    std::string metallibPath = moduleDir + "/" + METALLIB_PATH;
+
+    NSString *metallibPathStr =
+        [NSString stringWithUTF8String:metallibPath.c_str()];
+    NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr];
+    NSError *error = nil;
+    id<MTLLibrary> lib = [device newLibraryWithURL:metallibURL error:&error];
+    // Fail fast: continuing with a nil library would only surface later as a
+    // confusing "missing kernel function" error.
+    TORCH_CHECK(lib, "Failed to load pre-compiled Metal library at ",
+                metallibPath, ": ",
+                error ? error.localizedDescription.UTF8String
+                      : "unknown error");
+
+    // Process each layer separately.
+    for (int64_t layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
+      // Pick the kernel specialization matching the cache dtype.
+      NSString *kernName = nil;
+      switch (key_caches[layer_idx].scalar_type()) {
+      case torch::kFloat:
+        kernName = @"copy_blocks_float";
+        break;
+      case torch::kHalf:
+        kernName = @"copy_blocks_half";
+        break;
+      case torch::kBFloat16:
+        kernName = @"copy_blocks_bfloat16_t";
+        break;
+      case torch::kUInt8:
+        kernName = @"copy_blocks_uchar";
+        break;
+      default:
+        TORCH_CHECK(false, "Unsupported dtype for copy_blocks");
+      }
+
+      id<MTLFunction> fn = [lib newFunctionWithName:kernName];
+      TORCH_CHECK(fn, "Missing Metal kernel function: ", kernName.UTF8String);
+
+      id<MTLComputePipelineState> pso =
+          [device newComputePipelineStateWithFunction:fn error:&error];
+      TORCH_CHECK(pso, error.localizedDescription.UTF8String);
+
+      dispatch_queue_t q = stream->queue();
+      dispatch_sync(q, ^{
+        id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
+        TORCH_CHECK(enc, "Failed to create compute encoder");
+
+        [enc setComputePipelineState:pso];
+
+        // Key and value cache buffers (offsets account for storage offset).
+        [enc setBuffer:getMTLBufferStorage(key_caches[layer_idx])
+                offset:key_caches[layer_idx].storage_offset() *
+                       key_caches[layer_idx].element_size()
+               atIndex:0];
+        [enc setBuffer:getMTLBufferStorage(value_caches[layer_idx])
+                offset:value_caches[layer_idx].storage_offset() *
+                       value_caches[layer_idx].element_size()
+               atIndex:1];
+
+        // (src, dst) block-id pairs as a flat int64 buffer.
+        id<MTLBuffer> mappingBuf =
+            [device newBufferWithBytes:block_mapping_cpu.data_ptr()
+                                length:num_pairs * 2 * sizeof(int64_t)
+                               options:MTLResourceStorageModeShared];
+        [enc setBuffer:mappingBuf offset:0 atIndex:2];
+
+        // numel_per_block as a tiny shared buffer.
+        id<MTLBuffer> numelBuf =
+            [device newBufferWithBytes:&numel_per_block
+                                length:sizeof(int32_t)
+                               options:MTLResourceStorageModeShared];
+        [enc setBuffer:numelBuf offset:0 atIndex:3];
+
+        // One threadgroup per block pair; threads stride over the block.
+        const uint32_t threadsPerThreadgroup =
+            std::min<int32_t>(256, numel_per_block);
+        MTLSize tg = MTLSizeMake(threadsPerThreadgroup, 1, 1);
+        MTLSize grid = MTLSizeMake(threadsPerThreadgroup * num_pairs, 1, 1);
+
+        [enc dispatchThreads:grid threadsPerThreadgroup:tg];
+        [enc endEncoding];
+      });
+    }
+
+    stream->synchronize(at::mps::SyncType::COMMIT);
+  }
+}
+
+// Scatters per-token key/value vectors into the paged KV cache layouts used
+// by the paged-attention kernels:
+//   key_cache:   [num_blocks, num_heads, head_size/x, block_size, x]
+//   value_cache: [num_blocks, num_heads, head_size, block_size]
+// slot_mapping[i] gives the flat slot for token i. When kv_cache_dtype is
+// "fp8"/"fp8_e4m3" the values are divided by k_scale / v_scale in the kernel
+// and stored as FP8 E4M3 bytes in a UInt8 cache.
+void reshape_and_cache(
+    torch::Tensor &key,   // [num_tokens, num_heads, head_size]
+    torch::Tensor &value, // [num_tokens, num_heads, head_size]
+    torch::Tensor
+        &key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
+    torch::Tensor
+        &value_cache, // [num_blocks, num_heads, head_size, block_size]
+    torch::Tensor &slot_mapping, // [num_tokens]
+    const std::string &kv_cache_dtype, torch::Tensor &k_scale,
+    torch::Tensor &v_scale) {
+
+  // Determine cache dtype and FP8 usage.
+  torch::ScalarType cache_dtype = key_cache.scalar_type();
+  bool use_fp8_scales =
+      (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3");
+  if (use_fp8_scales) {
+    TORCH_CHECK(cache_dtype == torch::kUInt8,
+                "FP8 cache requires UInt8 tensor type");
+    TORCH_CHECK(k_scale.numel() == 1 && v_scale.numel() == 1,
+                "FP8 scales must be scalars");
+    TORCH_CHECK(k_scale.scalar_type() == torch::kFloat32 &&
+                    v_scale.scalar_type() == torch::kFloat32,
+                "FP8 scales must be float32");
+  }
+
+  TORCH_CHECK(key.device().is_mps() && value.device().is_mps() &&
+                  key_cache.device().is_mps() && value_cache.device().is_mps(),
+              "All tensors must be on MPS device");
+
+  // slot_mapping is read on the host to build the kernel argument buffer.
+  torch::Tensor slot_mapping_cpu = slot_mapping;
+  if (slot_mapping.device().is_mps()) {
+    slot_mapping_cpu = slot_mapping.cpu();
+  }
+
+  const int64_t num_tokens = key.size(0);
+  const int64_t num_heads = key.size(1);
+  const int64_t head_size = key.size(2);
+  const int64_t block_size = key_cache.size(3);
+  const int64_t x = key_cache.size(4);
+
+  const int32_t key_stride = key.stride(0);
+  const int32_t value_stride = value.stride(0);
+
+  @autoreleasepool {
+    at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
+    TORCH_CHECK(stream, "Failed to get current MPS stream");
+
+    id<MTLDevice> device = stream->device();
+    id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
+    TORCH_CHECK(cmdBuf, "Failed to get command buffer");
+
+    // Locate the pre-compiled metallib next to this extension binary.
+    std::string moduleDir = getModuleDirectory();
+    std::string metallibPath = moduleDir + "/" + METALLIB_PATH;
+
+    NSString *metallibPathStr =
+        [NSString stringWithUTF8String:metallibPath.c_str()];
+    NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr];
+    NSError *error = nil;
+    id<MTLLibrary> lib = [device newLibraryWithURL:metallibURL error:&error];
+    // Fail fast instead of continuing with a nil library.
+    TORCH_CHECK(lib, "Failed to load pre-compiled Metal library at ",
+                metallibPath, ": ",
+                error ? error.localizedDescription.UTF8String
+                      : "unknown error");
+
+    // Kernel name encodes both the KV dtype and the cache dtype.
+    std::string kv_dtype_str, cache_dtype_str;
+    switch (key.scalar_type()) {
+    case torch::kFloat:
+      kv_dtype_str = "float";
+      break;
+    case torch::kHalf:
+      kv_dtype_str = "half";
+      break;
+    case torch::kBFloat16:
+      kv_dtype_str = "bfloat16_t";
+      break;
+    default:
+      TORCH_CHECK(false, "Unsupported dtype for reshape_and_cache");
+    }
+
+    switch (cache_dtype) {
+    case torch::kFloat:
+      cache_dtype_str = "float";
+      break;
+    case torch::kHalf:
+      cache_dtype_str = "half";
+      break;
+    case torch::kBFloat16:
+      cache_dtype_str = "bfloat16_t";
+      break;
+    case torch::kUInt8:
+      cache_dtype_str = "uchar";
+      break;
+    default:
+      TORCH_CHECK(false, "Unsupported cache dtype for reshape_and_cache");
+    }
+
+    std::string kernName_str =
+        "reshape_and_cache_kv_" + kv_dtype_str + "_cache_" + cache_dtype_str;
+    NSString *kernName = [NSString stringWithUTF8String:kernName_str.c_str()];
+
+    // use_fp8_scales is a Metal function constant (index 10); specializing
+    // the pipeline lets the compiler strip the unused path.
+    MTLFunctionConstantValues *constants =
+        [[MTLFunctionConstantValues alloc] init];
+    [constants setConstantValue:&use_fp8_scales
+                           type:MTLDataTypeBool
+                        atIndex:10];
+
+    id<MTLFunction> fn = [lib newFunctionWithName:kernName
+                                   constantValues:constants
+                                            error:&error];
+    TORCH_CHECK(fn, "Missing Metal kernel function: ", kernName.UTF8String,
+                error ? [NSString stringWithFormat:@": %@",
+                                                   error.localizedDescription]
+                            .UTF8String
+                      : "");
+
+    id<MTLComputePipelineState> pso =
+        [device newComputePipelineStateWithFunction:fn error:&error];
+    TORCH_CHECK(pso, error.localizedDescription.UTF8String);
+
+    dispatch_queue_t q = stream->queue();
+    dispatch_sync(q, ^{
+      id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
+      TORCH_CHECK(enc, "Failed to create compute encoder");
+
+      [enc setComputePipelineState:pso];
+
+      // Tensor buffers (offsets account for each tensor's storage offset).
+      [enc setBuffer:getMTLBufferStorage(key)
+              offset:key.storage_offset() * key.element_size()
+             atIndex:0];
+      [enc setBuffer:getMTLBufferStorage(value)
+              offset:value.storage_offset() * value.element_size()
+             atIndex:1];
+      [enc setBuffer:getMTLBufferStorage(key_cache)
+              offset:key_cache.storage_offset() * key_cache.element_size()
+             atIndex:2];
+      [enc setBuffer:getMTLBufferStorage(value_cache)
+              offset:value_cache.storage_offset() * value_cache.element_size()
+             atIndex:3];
+
+      // Slot mapping (int64 per token).
+      id<MTLBuffer> slotMappingBuf =
+          [device newBufferWithBytes:slot_mapping_cpu.data_ptr()
+                              length:num_tokens * sizeof(int64_t)
+                             options:MTLResourceStorageModeShared];
+      [enc setBuffer:slotMappingBuf offset:0 atIndex:4];
+
+      // Buffers 5/6 are only read when the use_fp8_scales function constant
+      // is true, so they may be left unbound otherwise.
+      if (use_fp8_scales) {
+        [enc setBuffer:getMTLBufferStorage(k_scale)
+                offset:k_scale.storage_offset() * k_scale.element_size()
+               atIndex:5];
+        [enc setBuffer:getMTLBufferStorage(v_scale)
+                offset:v_scale.storage_offset() * v_scale.element_size()
+               atIndex:6];
+      }
+
+      // Scalar parameters as tiny shared buffers (indices 7-12).
+      id<MTLBuffer> keyStrideBuf =
+          [device newBufferWithBytes:&key_stride
+                              length:sizeof(int32_t)
+                             options:MTLResourceStorageModeShared];
+      [enc setBuffer:keyStrideBuf offset:0 atIndex:7];
+
+      id<MTLBuffer> valueStrideBuf =
+          [device newBufferWithBytes:&value_stride
+                              length:sizeof(int32_t)
+                             options:MTLResourceStorageModeShared];
+      [enc setBuffer:valueStrideBuf offset:0 atIndex:8];
+
+      const int32_t num_heads_i32 = static_cast<int32_t>(num_heads);
+      id<MTLBuffer> numHeadsBuf =
+          [device newBufferWithBytes:&num_heads_i32
+                              length:sizeof(int32_t)
+                             options:MTLResourceStorageModeShared];
+      [enc setBuffer:numHeadsBuf offset:0 atIndex:9];
+
+      const int32_t head_size_i32 = static_cast<int32_t>(head_size);
+      id<MTLBuffer> headSizeBuf =
+          [device newBufferWithBytes:&head_size_i32
+                              length:sizeof(int32_t)
+                             options:MTLResourceStorageModeShared];
+      [enc setBuffer:headSizeBuf offset:0 atIndex:10];
+
+      const int32_t block_size_i32 = static_cast<int32_t>(block_size);
+      id<MTLBuffer> blockSizeBuf =
+          [device newBufferWithBytes:&block_size_i32
+                              length:sizeof(int32_t)
+                             options:MTLResourceStorageModeShared];
+      [enc setBuffer:blockSizeBuf offset:0 atIndex:11];
+
+      const int32_t x_i32 = static_cast<int32_t>(x);
+      id<MTLBuffer> xBuf =
+          [device newBufferWithBytes:&x_i32
+                              length:sizeof(int32_t)
+                             options:MTLResourceStorageModeShared];
+      [enc setBuffer:xBuf offset:0 atIndex:12];
+
+      // One threadgroup per token; threads stride over num_heads * head_size.
+      const uint64_t threads_per_threadgroup =
+          std::min<int64_t>(512, num_heads * head_size);
+      MTLSize tg = MTLSizeMake(threads_per_threadgroup, 1, 1);
+      MTLSize grid = MTLSizeMake(num_tokens, 1, 1);
+
+      [enc dispatchThreadgroups:grid threadsPerThreadgroup:tg];
+      [enc endEncoding];
+    });
+
+    stream->synchronize(at::mps::SyncType::COMMIT);
+  }
+}
+
+// Flash-layout variant: scatters per-token K/V into caches laid out as
+// [num_blocks, block_size, num_heads, head_size]. kv_cache_dtype, k_scale
+// and v_scale are accepted for API parity but are not used by this path.
+void reshape_and_cache_flash(
+    torch::Tensor &key,       // [num_tokens, num_heads, head_size]
+    torch::Tensor &value,     // [num_tokens, num_heads, head_size]
+    torch::Tensor &key_cache, // [num_blocks, block_size, num_heads, head_size]
+    torch::Tensor
+        &value_cache, // [num_blocks, block_size, num_heads, head_size]
+    torch::Tensor &slot_mapping, // [num_tokens]
+    const std::string &kv_cache_dtype, torch::Tensor &k_scale,
+    torch::Tensor &v_scale) {
+
+  TORCH_CHECK(key.device().is_mps() && value.device().is_mps() &&
+                  key_cache.device().is_mps() && value_cache.device().is_mps(),
+              "All tensors must be on MPS device");
+
+  // slot_mapping is read on the host to build the kernel argument buffer.
+  torch::Tensor slot_mapping_cpu = slot_mapping;
+  if (slot_mapping.device().is_mps()) {
+    slot_mapping_cpu = slot_mapping.cpu();
+  }
+
+  const int64_t num_tokens = key.size(0);
+  const int64_t num_heads = key.size(1);
+  const int64_t head_size = key.size(2);
+  const int64_t block_size = key_cache.size(1);
+
+  const int32_t key_stride = key.stride(0);
+  const int32_t value_stride = value.stride(0);
+
+  @autoreleasepool {
+    at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
+    TORCH_CHECK(stream, "Failed to get current MPS stream");
+
+    id<MTLDevice> device = stream->device();
+    id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
+    TORCH_CHECK(cmdBuf, "Failed to get command buffer");
+
+    // Locate the pre-compiled metallib next to this extension binary.
+    std::string moduleDir = getModuleDirectory();
+    std::string metallibPath = moduleDir + "/" + METALLIB_PATH;
+
+    NSString *metallibPathStr =
+        [NSString stringWithUTF8String:metallibPath.c_str()];
+    NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr];
+    NSError *error = nil;
+    id<MTLLibrary> lib = [device newLibraryWithURL:metallibURL error:&error];
+    // Fail fast instead of continuing with a nil library.
+    TORCH_CHECK(lib, "Failed to load pre-compiled Metal library at ",
+                metallibPath, ": ",
+                error ? error.localizedDescription.UTF8String
+                      : "unknown error");
+
+    // Kernel specialization for the KV dtype.
+    NSString *kernName = nil;
+    switch (key.scalar_type()) {
+    case torch::kFloat:
+      kernName = @"reshape_and_cache_flash_float";
+      break;
+    case torch::kHalf:
+      kernName = @"reshape_and_cache_flash_half";
+      break;
+    case torch::kBFloat16:
+      kernName = @"reshape_and_cache_flash_bfloat16_t";
+      break;
+    default:
+      TORCH_CHECK(false, "Unsupported dtype for reshape_and_cache_flash");
+    }
+
+    id<MTLFunction> fn = [lib newFunctionWithName:kernName];
+    TORCH_CHECK(fn, "Missing Metal kernel function: ", kernName.UTF8String);
+
+    id<MTLComputePipelineState> pso =
+        [device newComputePipelineStateWithFunction:fn error:&error];
+    TORCH_CHECK(pso, error.localizedDescription.UTF8String);
+
+    dispatch_queue_t q = stream->queue();
+    dispatch_sync(q, ^{
+      id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
+      TORCH_CHECK(enc, "Failed to create compute encoder");
+
+      [enc setComputePipelineState:pso];
+
+      // Tensor buffers (offsets account for each tensor's storage offset).
+      [enc setBuffer:getMTLBufferStorage(key)
+              offset:key.storage_offset() * key.element_size()
+             atIndex:0];
+      [enc setBuffer:getMTLBufferStorage(value)
+              offset:value.storage_offset() * value.element_size()
+             atIndex:1];
+      [enc setBuffer:getMTLBufferStorage(key_cache)
+              offset:key_cache.storage_offset() * key_cache.element_size()
+             atIndex:2];
+      [enc setBuffer:getMTLBufferStorage(value_cache)
+              offset:value_cache.storage_offset() * value_cache.element_size()
+             atIndex:3];
+
+      // Slot mapping (int64 per token).
+      id<MTLBuffer> slotMappingBuf =
+          [device newBufferWithBytes:slot_mapping_cpu.data_ptr()
+                              length:num_tokens * sizeof(int64_t)
+                             options:MTLResourceStorageModeShared];
+      [enc setBuffer:slotMappingBuf offset:0 atIndex:4];
+
+      // Scalar parameters as tiny shared buffers (indices 5-9).
+      id<MTLBuffer> keyStrideBuf =
+          [device newBufferWithBytes:&key_stride
+                              length:sizeof(int32_t)
+                             options:MTLResourceStorageModeShared];
+      [enc setBuffer:keyStrideBuf offset:0 atIndex:5];
+
+      id<MTLBuffer> valueStrideBuf =
+          [device newBufferWithBytes:&value_stride
+                              length:sizeof(int32_t)
+                             options:MTLResourceStorageModeShared];
+      [enc setBuffer:valueStrideBuf offset:0 atIndex:6];
+
+      const int32_t num_heads_i32 = static_cast<int32_t>(num_heads);
+      id<MTLBuffer> numHeadsBuf =
+          [device newBufferWithBytes:&num_heads_i32
+                              length:sizeof(int32_t)
+                             options:MTLResourceStorageModeShared];
+      [enc setBuffer:numHeadsBuf offset:0 atIndex:7];
+
+      const int32_t head_size_i32 = static_cast<int32_t>(head_size);
+      id<MTLBuffer> headSizeBuf =
+          [device newBufferWithBytes:&head_size_i32
+                              length:sizeof(int32_t)
+                             options:MTLResourceStorageModeShared];
+      [enc setBuffer:headSizeBuf offset:0 atIndex:8];
+
+      const int32_t block_size_i32 = static_cast<int32_t>(block_size);
+      id<MTLBuffer> blockSizeBuf =
+          [device newBufferWithBytes:&block_size_i32
+                              length:sizeof(int32_t)
+                             options:MTLResourceStorageModeShared];
+      [enc setBuffer:blockSizeBuf offset:0 atIndex:9];
+
+      // One threadgroup per token; threads stride over num_heads * head_size.
+      const uint64_t threads_per_threadgroup =
+          std::min<int64_t>(512, num_heads * head_size);
+      MTLSize tg = MTLSizeMake(threads_per_threadgroup, 1, 1);
+      MTLSize grid = MTLSizeMake(num_tokens, 1, 1);
+
+      [enc dispatchThreadgroups:grid threadsPerThreadgroup:tg];
+      [enc endEncoding];
+    });
+
+    stream->synchronize(at::mps::SyncType::COMMIT);
+  }
+}
\ No newline at end of file
diff --git a/vllm_metal/metal/kernels/cache/copy_blocks.metal b/vllm_metal/metal/kernels/cache/copy_blocks.metal
new file mode 100644
index 0000000..31595cf
--- /dev/null
+++ b/vllm_metal/metal/kernels/cache/copy_blocks.metal
@@ -0,0 +1,51 @@
+#include "../utils.metal"
+// NOTE(review): the system header name was lost (stripped angle brackets);
+// restored as the Metal standard library header.
+#include <metal_stdlib>
+
+using namespace metal;
+
+// Copies one (src -> dst) block pair per threadgroup: threads stride over the
+// numel_per_block elements of both the key and the value cache.
+// block_mapping is a flat [num_pairs, 2] array of (src, dst) block ids.
+// NOTE(review): the template header was missing (stripped angle brackets);
+// restored as template <typename T>.
+template <typename T>
+[[kernel]] void copy_blocks(device T *key_cache [[buffer(0)]],
+                            device T *value_cache [[buffer(1)]],
+                            const device int64_t *block_mapping [[buffer(2)]],
+                            device const int &numel_per_block,
+                            uint tgid [[threadgroup_position_in_grid]],
+                            uint tid [[thread_position_in_threadgroup]],
+                            uint threads_per_threadgroup
+                            [[threads_per_threadgroup]]) {
+  const int pair_idx = tgid;
+
+  int64_t src_block_number = block_mapping[2 * pair_idx];
+  int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
+
+  const int64_t src_block_offset = src_block_number * numel_per_block;
+  const int64_t dst_block_offset = dst_block_number * numel_per_block;
+
+  // Copy key cache blocks.
+  for (int i = tid; i < numel_per_block; i += threads_per_threadgroup) {
+    int64_t src_offset = src_block_offset + i;
+    int64_t dst_offset = dst_block_offset + i;
+    key_cache[dst_offset] = key_cache[src_offset];
+  }
+
+  // Copy value cache blocks.
+  for (int i = tid; i < numel_per_block; i += threads_per_threadgroup) {
+    int64_t src_offset = src_block_offset + i;
+    int64_t dst_offset = dst_block_offset + i;
+    value_cache[dst_offset] = value_cache[src_offset];
+  }
+}
+
+// Host-visible instantiations, one per cache dtype.
+// NOTE(review): the explicit template argument on copy_blocks was missing
+// (stripped angle brackets); restored as copy_blocks<type>.
+#define instantiate_copy_blocks(type)                                  \
+  template [[host_name("copy_blocks_" #type)]] [[kernel]] void         \
+  copy_blocks<type>(device type * key_cache [[buffer(0)]],             \
+                    device type * value_cache [[buffer(1)]],           \
+                    const device int64_t *block_mapping [[buffer(2)]], \
+                    device const int &numel_per_block,                 \
+                    uint tgid [[threadgroup_position_in_grid]],        \
+                    uint tid [[thread_position_in_threadgroup]],       \
+                    uint threads_per_threadgroup                       \
+                        [[threads_per_threadgroup]]);
+
+instantiate_copy_blocks(float);
+instantiate_copy_blocks(bfloat16_t);
+instantiate_copy_blocks(half);
+instantiate_copy_blocks(uchar);
diff --git a/vllm_metal/metal/kernels/cache/reshape_and_cache.metal b/vllm_metal/metal/kernels/cache/reshape_and_cache.metal
new file mode 100644
index 0000000..28ff7db
--- /dev/null
+++ b/vllm_metal/metal/kernels/cache/reshape_and_cache.metal
@@ -0,0 +1,193 @@
+#include "../utils.metal"
+#include "../float8.metal"
+// NOTE(review): the system header name was lost (stripped angle brackets);
+// restored as the Metal standard library header.
+#include <metal_stdlib>
+
+using namespace metal;
+
+// Converts a KV element into the cache element type. The primary template is
+// deleted so only the specializations below are usable; the uchar target
+// encodes FP8 E4M3 via float_to_fp8_e4m3.
+// NOTE(review): the template headers and explicit arguments were missing
+// (stripped angle brackets). CACHE_T is the return type and cannot be
+// deduced, so callers must spell it: to_cache<CACHE_T>(v).
+template <typename CACHE_T, typename KV_T>
+inline CACHE_T to_cache(KV_T v) = delete;
+
+template <> inline uchar to_cache<uchar>(float v) {
+  return float_to_fp8_e4m3(v);
+}
+
+template <> inline uchar to_cache<uchar>(bfloat16_t v) {
+  return float_to_fp8_e4m3((float)v);
+}
+
+template <> inline uchar to_cache<uchar>(half v) {
+  return float_to_fp8_e4m3((float)v);
+}
+
+template <> inline float to_cache<float>(float v) { return v; }
+
+template <> inline bfloat16_t to_cache<bfloat16_t>(bfloat16_t v) {
+  return v;
+}
+
+template <> inline half to_cache<half>(half v) { return v; }
+
+// When true (bound at pipeline creation via function-constant index 10), key
+// and value elements are divided by *k_scale / *v_scale before being stored.
+constant bool use_fp8_scales [[function_constant(10)]];
+
+// Scatters one token's K/V vectors into the paged cache layouts:
+//   key_cache:   [num_blocks, num_heads, head_size/x, block_size, x]
+//   value_cache: [num_blocks, num_heads, head_size, block_size]
+// One threadgroup per token; threads stride over num_heads * head_size.
+// NOTE(review): the template header and to_cache arguments were missing
+// (stripped angle brackets); restored as <KV_T, CACHE_T> / to_cache<CACHE_T>
+// to match the instantiation macro below.
+template <typename KV_T, typename CACHE_T>
+[[kernel]] void reshape_and_cache(
+    const device KV_T *__restrict__ key
+    [[buffer(0)]], // [num_tokens, num_heads, head_size]
+    const device KV_T *__restrict__ value
+    [[buffer(1)]], // [num_tokens, num_heads, head_size]
+    device CACHE_T *__restrict__ key_cache
+    [[buffer(2)]], // [num_blocks, num_heads, head_size/x, block_size, x]
+    device CACHE_T *__restrict__ value_cache
+    [[buffer(3)]], // [num_blocks, num_heads, head_size, block_size]
+    const device int64_t *__restrict__ slot_mapping
+    [[buffer(4)]], // [num_tokens]
+    const device float *__restrict__ k_scale
+    [[buffer(5)]], // [1] - only used when use_fp8_scales
+    const device float *__restrict__ v_scale
+    [[buffer(6)]], // [1] - only used when use_fp8_scales
+    device const int &key_stride [[buffer(7)]],
+    device const int &value_stride [[buffer(8)]],
+    device const int &num_heads [[buffer(9)]],
+    device const int &head_size [[buffer(10)]],
+    device const int &block_size [[buffer(11)]],
+    device const int &x [[buffer(12)]],
+    uint gid [[threadgroup_position_in_grid]],
+    uint tid [[thread_position_in_threadgroup]],
+    uint threads_per_threadgroup [[threads_per_threadgroup]]) {
+  const int64_t token_idx = gid;
+  const int64_t slot_idx = slot_mapping[token_idx];
+  if (slot_idx < 0) {
+    // Padding token that should be ignored.
+    return;
+  }
+
+  const int64_t block_idx = slot_idx / block_size;
+  const int64_t block_offset = slot_idx % block_size;
+
+  const int n = num_heads * head_size;
+  for (int i = tid; i < n; i += threads_per_threadgroup) {
+    const int64_t src_key_idx = token_idx * key_stride + i;
+    const int64_t src_value_idx = token_idx * value_stride + i;
+
+    const int head_idx = i / head_size;
+    const int head_offset = i % head_size;
+    const int x_idx = head_offset / x;
+    const int x_offset = head_offset % x;
+
+    const int64_t tgt_key_idx =
+        block_idx * num_heads * (head_size / x) * block_size * x +
+        head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
+        block_offset * x + x_offset;
+    const int64_t tgt_value_idx =
+        block_idx * num_heads * head_size * block_size +
+        head_idx * head_size * block_size + head_offset * block_size +
+        block_offset;
+
+    if (use_fp8_scales) {
+      key_cache[tgt_key_idx] =
+          to_cache<CACHE_T>(KV_T((float)key[src_key_idx] / *k_scale));
+      value_cache[tgt_value_idx] =
+          to_cache<CACHE_T>(KV_T((float)value[src_value_idx] / *v_scale));
+    } else {
+      key_cache[tgt_key_idx] = to_cache<CACHE_T>(key[src_key_idx]);
+      value_cache[tgt_value_idx] = to_cache<CACHE_T>(value[src_value_idx]);
+    }
+  }
+}
+
+// Host-visible instantiations; the name encodes KV dtype and cache dtype.
+// NOTE(review): the explicit template arguments were missing (stripped angle
+// brackets); restored as reshape_and_cache<kv_type, cache_type>.
+#define instantiate_reshape_and_cache(kv_type, cache_type)           \
+  template [[host_name("reshape_and_cache_kv_" #kv_type              \
+                       "_cache_" #cache_type)]] [[kernel]] void      \
+  reshape_and_cache<kv_type, cache_type>(                            \
+      const device kv_type *__restrict__ key [[buffer(0)]],          \
+      const device kv_type *__restrict__ value [[buffer(1)]],        \
+      device cache_type *__restrict__ key_cache [[buffer(2)]],       \
+      device cache_type *__restrict__ value_cache [[buffer(3)]],     \
+      const device int64_t *__restrict__ slot_mapping [[buffer(4)]], \
+      const device float *__restrict__ k_scale [[buffer(5)]],        \
+      const device float *__restrict__ v_scale [[buffer(6)]],        \
+      device const int &key_stride [[buffer(7)]],                    \
+      device const int &value_stride [[buffer(8)]],                  \
+      device const int &num_heads [[buffer(9)]],                     \
+      device const int &head_size [[buffer(10)]],                    \
+      device const int &block_size [[buffer(11)]],                   \
+      device const int &x [[buffer(12)]],                            \
+      uint gid [[threadgroup_position_in_grid]],                     \
+      uint tid [[thread_position_in_threadgroup]],                   \
+      uint threads_per_threadgroup [[threads_per_threadgroup]]);
+
+// Matching compute/cache dtypes.
+instantiate_reshape_and_cache(float, float);
+instantiate_reshape_and_cache(bfloat16_t, bfloat16_t);
+instantiate_reshape_and_cache(half, half);
+
+// FP8 E4M3 caches (uchar storage).
+instantiate_reshape_and_cache(float, uchar);
+instantiate_reshape_and_cache(bfloat16_t, uchar);
+instantiate_reshape_and_cache(half, uchar);
+
+// Flash variant with cache layout [num_blocks, block_size, num_heads,
+// head_size] for both caches; no FP8 path. One threadgroup per token.
+// NOTE(review): the template header was missing (stripped angle brackets);
+// restored as template <typename T>.
+template <typename T>
+[[kernel]] void reshape_and_cache_flash(
+    const device T *__restrict__ key
+    [[buffer(0)]], // [num_tokens, num_heads, head_size]
+    const device T *__restrict__ value
+    [[buffer(1)]], // [num_tokens, num_heads, head_size]
+    device T *__restrict__ key_cache
+    [[buffer(2)]], // [num_blocks, block_size, num_heads, head_size]
+    device T *__restrict__ value_cache
+    [[buffer(3)]], // [num_blocks, block_size, num_heads, head_size]
+    const device int64_t *__restrict__ slot_mapping
+    [[buffer(4)]], // [num_tokens]
+    device const int &key_stride, device const int &value_stride,
+    device const int &num_heads, device const int &head_size,
+    device const int &block_size, uint gid [[threadgroup_position_in_grid]],
+    uint tid [[thread_position_in_threadgroup]],
+    uint threads_per_threadgroup [[threads_per_threadgroup]]) {
+  const int64_t token_idx = gid;
+  const int64_t slot_idx = slot_mapping[token_idx];
+  if (slot_idx < 0) {
+    // Padding token that should be ignored.
+    return;
+  }
+
+  const int64_t block_idx = slot_idx / block_size;
+  const int64_t block_offset = slot_idx % block_size;
+
+  const int n = num_heads * head_size;
+  for (int i = tid; i < n; i += threads_per_threadgroup) {
+    const int64_t src_key_idx = token_idx * key_stride + i;
+    const int64_t src_value_idx = token_idx * value_stride + i;
+
+    const int head_idx = i / head_size;
+    const int head_offset = i % head_size;
+
+    // Both caches share the same [block, slot, head, elem] layout, so the
+    // key and value targets coincide.
+    const int64_t tgt_idx = block_idx * block_size * num_heads * head_size +
+                            block_offset * num_heads * head_size +
+                            head_idx * head_size + head_offset;
+    key_cache[tgt_idx] = key[src_key_idx];
+    value_cache[tgt_idx] = value[src_value_idx];
+  }
+}
+
+// Host-visible instantiations for the flash layout.
+// NOTE(review): the explicit template argument was missing (stripped angle
+// brackets); restored as reshape_and_cache_flash<type>.
+#define instantiate_reshape_and_cache_flash(type)                          \
+  template [[host_name("reshape_and_cache_flash_" #type)]] [[kernel]] void \
+  reshape_and_cache_flash<type>(                                           \
+      const device type *__restrict__ key [[buffer(0)]],                   \
+      const device type *__restrict__ value [[buffer(1)]],                 \
+      device type *__restrict__ key_cache [[buffer(2)]],                   \
+      device type *__restrict__ value_cache [[buffer(3)]],                 \
+      const device int64_t *__restrict__ slot_mapping [[buffer(4)]],       \
+      device const int &key_stride, device const int &value_stride,        \
+      device const int &num_heads, device const int &head_size,            \
+      device const int &block_size,                                        \
+      uint gid [[threadgroup_position_in_grid]],                           \
+      uint tid [[thread_position_in_threadgroup]],                         \
+      uint threads_per_threadgroup [[threads_per_threadgroup]]);
+
+instantiate_reshape_and_cache_flash(float);
+instantiate_reshape_and_cache_flash(bfloat16_t);
+instantiate_reshape_and_cache_flash(half);
diff --git a/vllm_metal/metal/kernels/convert_fp8.metal b/vllm_metal/metal/kernels/convert_fp8.metal
new file mode 100644
index 0000000..22028ce
--- /dev/null
+++ b/vllm_metal/metal/kernels/convert_fp8.metal
@@ -0,0 +1,77 @@
+#include "float8.metal"
+#include "utils.metal"
+#include <metal_stdlib>
+
+using namespace metal;
+
+// Convert between different precision formats for cache tensors
+// This kernel handles conversions like float->fp8, fp8->float, etc.
+// NOTE(review): the angle-bracketed generics in this file were lost in
+// transit (empty "#include", bare "template", "is_same_v" without
+// arguments) and have been reconstructed below -- verify against upstream.
+
+// One thread per element.  `uchar` is the storage type for packed FP8
+// (E4M3).  Dequantization multiplies by `scale`; quantization divides.
+template <typename SRC_T, typename DST_T>
+[[kernel]] void convert_fp8_kernel(
+    const device SRC_T *__restrict__ src [[buffer(0)]],
+    device DST_T *__restrict__ dst [[buffer(1)]],
+    const device float &scale [[buffer(2)]],
+    const device uint32_t &num_elements [[buffer(3)]],
+    uint gid [[thread_position_in_grid]]) {
+
+  // Grid may be padded beyond the element count.
+  if (gid >= num_elements) {
+    return;
+  }
+
+  // Load source value
+  SRC_T src_val = src[gid];
+
+  // Convert based on source and destination types
+  if constexpr (is_same_v<SRC_T, uchar> && !is_same_v<DST_T, uchar>) {
+    // FP8 -> higher precision (dequantization): value = fp8 * scale
+    float fp32_val = fp8_e4m3_to_float(src_val) * scale;
+    dst[gid] = static_cast<DST_T>(fp32_val);
+  } else if constexpr (!is_same_v<SRC_T, uchar> && is_same_v<DST_T, uchar>) {
+    // Higher precision -> FP8 (quantization): fp8 = value / scale
+    float fp32_val = static_cast<float>(src_val) / scale;
+    dst[gid] = float_to_fp8_e4m3(fp32_val);
+  } else if constexpr (is_same_v<SRC_T, uchar> && is_same_v<DST_T, uchar>) {
+    // FP8 -> FP8 (with rescaling)
+    float fp32_val = fp8_e4m3_to_float(src_val) * scale;
+    dst[gid] = float_to_fp8_e4m3(fp32_val);
+  } else {
+    // Regular precision -> regular precision (with scaling)
+    float fp32_val = static_cast<float>(src_val) * scale;
+    dst[gid] = static_cast<DST_T>(fp32_val);
+  }
+}
+
+// Instantiate all required combinations; [[host_name]] fixes the symbol
+// ("convert_fp8_<src>_to_<dst>") the host pipeline lookup uses.
+#define INSTANTIATE_CONVERT_FP8(src_type, dst_type)                            \
+  template [[host_name("convert_fp8_" #src_type "_to_" #dst_type)]]           \
+  [[kernel]] void convert_fp8_kernel<src_type, dst_type>(                      \
+      const device src_type *__restrict__ src [[buffer(0)]],                   \
+      device dst_type *__restrict__ dst [[buffer(1)]],                         \
+      const device float &scale [[buffer(2)]],                                 \
+      const device uint32_t &num_elements [[buffer(3)]],                       \
+      uint gid [[thread_position_in_grid]]);
+
+// FP8 to other formats (dequantization)
+INSTANTIATE_CONVERT_FP8(uchar, float);
+INSTANTIATE_CONVERT_FP8(uchar, half);
+INSTANTIATE_CONVERT_FP8(uchar, bfloat16_t);
+
+// Other formats to FP8 (quantization)
+INSTANTIATE_CONVERT_FP8(float, uchar);
+INSTANTIATE_CONVERT_FP8(half, uchar);
+INSTANTIATE_CONVERT_FP8(bfloat16_t, uchar);
+
+// FP8 to FP8 (rescaling)
+INSTANTIATE_CONVERT_FP8(uchar, uchar);
+
+// Regular precision conversions with scaling
+INSTANTIATE_CONVERT_FP8(float, float);
+INSTANTIATE_CONVERT_FP8(float, half);
+INSTANTIATE_CONVERT_FP8(float, bfloat16_t);
+INSTANTIATE_CONVERT_FP8(half, float);
+INSTANTIATE_CONVERT_FP8(half, half);
+INSTANTIATE_CONVERT_FP8(half, bfloat16_t);
+INSTANTIATE_CONVERT_FP8(bfloat16_t, float);
+INSTANTIATE_CONVERT_FP8(bfloat16_t, half);
+INSTANTIATE_CONVERT_FP8(bfloat16_t, bfloat16_t);
\ No newline at end of file
diff --git a/vllm_metal/metal/kernels/float8.metal b/vllm_metal/metal/kernels/float8.metal
new file mode 100644
index 0000000..b911eba
--- /dev/null
+++ b/vllm_metal/metal/kernels/float8.metal
@@ -0,0 +1,122 @@
+#include <metal_stdlib>
+using namespace metal;
+
+// Helpers: reinterpret float <-> raw IEEE-754 bits (Metal's as_type is the
+// MSL equivalent of std::bit_cast).  Angle-bracket arguments restored after
+// they were stripped from the original text -- verify against upstream.
+static inline uint as_bits(float x) { return as_type<uint>(x); }
+static inline float from_bits(uint b) { return as_type<float>(b); }
+
+// -------------------------------------------------------------------
+// FP8 E4M3 (bias = 7)
+// -------------------------------------------------------------------
+// Decode one FP8 E4M3 byte (1 sign, 4 exponent, 3 mantissa bits, bias 7)
+// to float: handles signed zero, sub-normals, and the exp==0xF specials.
+inline float fp8_e4m3_to_float(uchar v) {
+  const uint s = v >> 7;
+  const uint exp = (v >> 3) & 0xF;
+  const uint man = v & 0x7;
+
+  if (exp == 0) { // zero / sub-normal
+    if (man == 0)
+      return s ? -0.f : 0.f;
+    const float m = float(man) / 8.f; // already scaled by 2^-3
+    float val = ldexp(m, 1 - 7); // 2^(1-bias) = 2^-6
+    return s ? -val : val;
+  }
+
+  // NOTE(review): OCP FP8 E4M3FN reserves only mantissa==0b111 (0x7F/0xFF)
+  // as NaN and defines no infinities; exp==0xF with man<7 encodes finite
+  // values up to +/-448.  This decoder instead treats all of exp==0xF as
+  // Inf/NaN (classic IEEE-style E4M3), matching the paired encoder's
+  // saturate-to-Inf pattern.  Confirm which variant the cache quantizer and
+  // any interop with torch.float8_e4m3fn expect.
+  if (exp == 0xF) { // Inf / NaN (E4M3FN keeps only NaN)
+    if (man != 0)
+      return NAN;
+    return s ? -INFINITY : INFINITY;
+  }
+
+  // Normal: implicit leading 1, exponent unbiased by 7.
+  const float m = 1.f + float(man) / 8.f;
+  float val = ldexp(m, int(exp) - 7);
+  return s ? -val : val;
+}
+
+// -------------------------------------------------------------------
+// FP8 E5M2 (bias = 15)
+// -------------------------------------------------------------------
+// Decode one FP8 E5M2 byte (1 sign, 5 exponent, 2 mantissa bits, bias 15)
+// to float, including signed zero, sub-normals, Inf and NaN.
+inline float fp8_e5m2_to_float(uchar v) {
+  const float sgn = (v & 0x80u) ? -1.f : 1.f;
+  const uint exp_field = (v >> 2) & 0x1Fu;
+  const uint man_field = v & 0x3u;
+
+  if (exp_field == 0x1Fu) {
+    // All-ones exponent: NaN when any mantissa bit is set, else +/-Inf.
+    return (man_field != 0u) ? NAN : sgn * INFINITY;
+  }
+
+  if (exp_field == 0u) {
+    // Zero / sub-normal: (man/4) * 2^(1-bias); man==0 yields +/-0.
+    return sgn * ldexp(float(man_field) / 4.f, 1 - 15);
+  }
+
+  // Normal: implicit leading 1, exponent unbiased by 15.
+  return sgn * ldexp(1.f + float(man_field) / 4.f, int(exp_field) - 15);
+}
+
+// -------------------------------------------------------------------
+// Encoding helpers (round-to-nearest-even, gradual under-flow, sat-to-∞)
+// -------------------------------------------------------------------
+namespace detail {
+// Encode an IEEE-754 float as an FP8 byte with EXP_BITS exponent bits,
+// MAN_BITS mantissa bits, and exponent bias BIAS.  Round-to-nearest-even on
+// the normal path; NaN propagates (quietened into the mantissa LSB); Inf and
+// finite overflow saturate to the all-ones-exponent pattern.
+// NOTE(review): the sub-normal path rounds half-away-from-zero rather than
+// to-even -- a common shortcut, but a slight deviation from RNE on exact ties.
+// (The "template <...>" header was stripped from the original text and has
+// been restored from the <4,3,7>/<5,2,15> call sites below.)
+template <int EXP_BITS, int MAN_BITS, int BIAS>
+inline uchar fp32_to_fp8(float f) {
+  const uint bits = as_bits(f);
+  const uint s = bits >> 31;
+  const uint abs = bits & 0x7FFFFFFF;
+
+  // NaN propagates, Inf saturates
+  if (abs >= 0x7F800000u) {
+    return uchar((s << 7) | (((1u << EXP_BITS) - 1u) << MAN_BITS) |
+                 (abs != 0x7F800000u));
+  }
+
+  int e = int((abs >> 23) & 0xFF) - 127; // unbiased exponent
+  uint m = abs & 0x7FFFFFu; // 23-bit mantissa
+  const int EXP_MAX = (1 << EXP_BITS) - 2; // last finite exponent
+
+  // ---------- Normal path -------------------------------------------------
+  int e_fp8 = e + BIAS;
+  if (e_fp8 >= 1 && e_fp8 <= EXP_MAX) {
+    // round-to-nearest-even
+    const int shift = 23 - MAN_BITS;
+    uint mant = m >> shift;
+    const uint lsb = mant & 1u;
+    const uint round = (m >> (shift - 1)) & 1u;
+    const uint sticky = (m & ((1u << (shift - 1)) - 1u)) != 0u;
+    mant += (round & (sticky | lsb));
+    if (mant >> MAN_BITS) { // mantissa overflow
+      mant = 0;
+      ++e_fp8;
+      if (e_fp8 > EXP_MAX)
+        return uchar((s << 7) | (((1u << EXP_BITS) - 1u) << MAN_BITS)); // ∞
+    }
+    return uchar((s << 7) | (uint(e_fp8) << MAN_BITS) |
+                 (mant & ((1u << MAN_BITS) - 1u)));
+  }
+
+  // ---------- Finite over-flow --------------------------------------------
+  // BUGFIX: finite magnitudes above the largest exponent previously fell
+  // through to the sub-normal path, where rshift went negative (undefined
+  // behavior).  Saturate them to the Inf pattern like the rounding-overflow
+  // case above.
+  if (e_fp8 > EXP_MAX)
+    return uchar((s << 7) | (((1u << EXP_BITS) - 1u) << MAN_BITS)); // ∞
+
+  // ---------- Sub-normal / under-flow ------------------------------------
+  if (e_fp8 < 1 - MAN_BITS) // too small -> ±0
+    return uchar(s << 7);
+
+  // shift so that exponent becomes 1
+  int rshift = (1 - e_fp8) + (23 - MAN_BITS);
+  uint mant = (0x800000u | m); // implicit 1
+  uint rounded = (mant + (1u << (rshift - 1))) >> rshift;
+  if (rounded == 0)
+    return uchar(s << 7); // rounds to zero
+
+  // BUGFIX: do not mask off the rounding carry.  When rounding bumps the
+  // value up to the smallest normal, rounded == 1 << MAN_BITS and the carry
+  // must land in the exponent field; the previous mask zeroed it, encoding
+  // the result as ±0.  rounded never exceeds 1 << MAN_BITS here, so a plain
+  // OR is safe and cannot touch the sign bit.
+  return uchar((s << 7) | rounded);
+}
+} // namespace detail
+
+// Encode float as FP8 E4M3 (4 exponent bits, 3 mantissa bits, bias 7).
+inline uchar float_to_fp8_e4m3(float f) {
+  return detail::fp32_to_fp8<4, 3, 7>(f);
+}
+// Encode float as FP8 E5M2 (5 exponent bits, 2 mantissa bits, bias 15).
+inline uchar float_to_fp8_e5m2(float f) {
+  return detail::fp32_to_fp8<5, 2, 15>(f);
+}
\ No newline at end of file
diff --git a/vllm_metal/metal/kernels/paged_attention.mm b/vllm_metal/metal/kernels/paged_attention.mm
new file mode 100644
index 0000000..71c398d
--- /dev/null
+++ b/vllm_metal/metal/kernels/paged_attention.mm
@@ -0,0 +1,693 @@
+#include <algorithm>
+#include <cstdint>
+#include <dlfcn.h>
+
+#import <Foundation/Foundation.h>
+#import <Metal/Metal.h>
+// NOTE(review): the original header names were lost from the #include /
+// #import lines above and below (they were left empty); these were
+// reconstructed from the symbols this file uses (torch::Tensor, TORCH_CHECK,
+// dladdr/Dl_info, id<MTLBuffer>, std::string, std::vector, std::max).
+// Verify against the upstream file before merging.
+#include <torch/extension.h>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+// Reinterpret a torch MPS tensor's opaque storage pointer as the Metal
+// buffer backing it.  Non-owning view: the tensor retains ownership, so the
+// returned id<MTLBuffer> must not outlive the tensor.
+static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) {
+  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
+}
+
+// Directory containing the shared object this function was compiled into,
+// located via dladdr() on our own address.  Falls back to "." when the
+// lookup fails or the reported path has no directory component.
+static std::string getModuleDirectory() {
+  Dl_info info;
+  if (dladdr(reinterpret_cast<void *>(&getModuleDirectory), &info) == 0) {
+    return ".";
+  }
+  const std::string full_path(info.dli_fname);
+  const size_t last_slash = full_path.find_last_of('/');
+  return (last_slash == std::string::npos) ? "."
+                                           : full_path.substr(0, last_slash);
+}
+
+// Helper function to get kernel name based on dtype and parameters.
+// Builds the [[host_name]] lookup string for a specialized kernel:
+//   <base>_<dtype>_cache_<cache_dtype>_hs<H>_bs<B>_nt<T>_nsl<L>[_ps<P>]
+// `uchar` is the cache dtype string used for packed FP8 caches.
+// Fails via TORCH_CHECK for dtypes with no Metal instantiation.
+static std::string getKernelName(const std::string &base_name,
+                                 torch::ScalarType dtype,
+                                 torch::ScalarType cache_dtype,
+                                 int head_size,
+                                 int block_size, int num_threads,
+                                 int num_simd_lanes, int partition_size = 0) {
+  std::string dtype_str;
+  switch (dtype) {
+  case torch::kFloat:
+    dtype_str = "float";
+    break;
+  case torch::kHalf:
+    dtype_str = "half";
+    break;
+  case torch::kBFloat16:
+    dtype_str = "bfloat16_t";
+    break;
+  default:
+    TORCH_CHECK(false, "Unsupported dtype for paged attention: ", dtype);
+  }
+
+  std::string cache_dtype_str;
+  switch (cache_dtype) {
+  case torch::kFloat:
+    cache_dtype_str = "float";
+    break;
+  case torch::kHalf:
+    cache_dtype_str = "half";
+    break;
+  case torch::kBFloat16:
+    cache_dtype_str = "bfloat16_t";
+    break;
+  case torch::kUInt8:
+    cache_dtype_str = "uchar"; // packed FP8 cache storage
+    break;
+  default:
+    TORCH_CHECK(false, "Unsupported cache dtype for paged attention: ", cache_dtype);
+  }
+
+  std::string kernel_name =
+      base_name + "_" + dtype_str + "_cache_" + cache_dtype_str + "_hs" + std::to_string(head_size) + "_bs" +
+      std::to_string(block_size) + "_nt" + std::to_string(num_threads) +
+      "_nsl" + std::to_string(num_simd_lanes);
+
+  // NOTE(review): with the default partition_size = 0 this condition is
+  // always true, so every default call gets a "_ps0" suffix; a caller must
+  // pass a negative value to suppress it.  A ">= 0" test combined with a 0
+  // default looks like a sentinel mismatch (-1 vs 0) -- confirm against the
+  // kernel instantiation names before changing either side.
+  if (partition_size >= 0) {
+    kernel_name += "_ps" + std::to_string(partition_size);
+  }
+
+  return kernel_name;
+}
+
+// Helper function to calculate shared memory size
+static size_t calculateSharedMemorySize(int max_seq_len, int head_size,
+ int num_threads, int num_simd_lanes) {
+ // Logits storage: max_seq_len * sizeof(float)
+ size_t logits_size = max_seq_len * sizeof(float);
+
+ // Reduction workspace: 2 * (num_threads / num_simd_lanes) * sizeof(float)
+ size_t reduction_size = 2 * (num_threads / num_simd_lanes) * sizeof(float);
+
+ // Output workspace for cross-warp reduction: head_size * sizeof(float)
+ size_t output_size = head_size * sizeof(float);
+ return std::max(logits_size + reduction_size, output_size);
+}
+
+// Helper function to get supported configurations
+static bool isValidConfiguration(int head_size, int block_size) {
+ // Supported head sizes from the Metal kernel instantiations
+ std::vector