diff --git a/README.md b/README.md index ee1c799..07d76b0 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ vLLM Metal is a plugin that enables vLLM to run on Apple Silicon Macs using MLX - **MLX-accelerated inference**: faster than PyTorch MPS on Apple Silicon - **Unified memory**: True zero-copy operations leveraging Apple Silicon's unified memory architecture - **vLLM compatibility**: Full integration with vLLM's engine, scheduler, and OpenAI-compatible API -- **Paged attention** *(experimental)*: Efficient KV cache management for long sequences — opt-in via `VLLM_METAL_USE_PAGED_ATTENTION=1` (requires `pip install 'vllm-metal[paged]'`); default path uses MLX-managed KV cache +- **Paged attention** *(experimental)*: Efficient KV cache management for long sequences — opt-in via `VLLM_METAL_USE_PAGED_ATTENTION=1`; default path uses MLX-managed KV cache. When enabled, expect significantly better serving performance (~82x TTFT, ~3.75x throughput in early benchmarks on Qwen3-0.6B). Other models may have rough edges. 
- **GQA support**: Grouped-Query Attention for efficient inference ## Requirements @@ -95,14 +95,13 @@ Environment variables for customization: | `VLLM_METAL_USE_MLX` | `1` | Use MLX for compute (1=yes, 0=no) | | `VLLM_MLX_DEVICE` | `gpu` | MLX device (`gpu` or `cpu`) | | `VLLM_METAL_BLOCK_SIZE` | `16` | KV cache block size | -| `VLLM_METAL_USE_PAGED_ATTENTION` | `0` | Enable experimental paged KV cache (requires `pip install 'vllm-metal[paged]'`) | +| `VLLM_METAL_USE_PAGED_ATTENTION` | `0` | Enable experimental paged KV cache | | `VLLM_METAL_DEBUG` | `0` | Enable debug logging | | `VLLM_USE_MODELSCOPE` | `False` | Set True to change model registry to | | `VLLM_METAL_MODELSCOPE_CACHE` | None | Specify the absolute path of the local model | | `VLLM_METAL_PREFIX_CACHE` | (unset) | Set to enable prefix caching for shared prompt reuse | | `VLLM_METAL_PREFIX_CACHE_FRACTION` | `0.05` | Fraction of MLX working set for prefix cache (0, 1] | - ## Paged KV vs MLX KV memory settings - MLX path (`VLLM_METAL_USE_PAGED_ATTENTION=0`): `VLLM_METAL_MEMORY_FRACTION` must be `auto`. @@ -115,3 +114,7 @@ Environment variables for customization: `auto` | `1` | Yes | Paged KV path; defaults to 0.9 internally `0.7` | `1` | Yes | Paged KV path with explicit memory budget `0.7` | `0` | No | Explicit fraction without paged KV is invalid + +## Acknowledgements + +- The Metal paged attention kernels are currently adapted from [mistral.rs](https://github.com/EricLBuehler/mistral.rs) (MIT license), via [HuggingFace kernels-community](https://github.com/huggingface/kernels-community). We plan to develop custom kernels in the future. 
diff --git a/pyproject.toml b/pyproject.toml index 763c53d..846e42f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,16 +35,14 @@ dependencies = [ "transformers>=4.40.0", "accelerate>=0.26.0", "safetensors>=0.4.0", + # Native Metal extension JIT build + "nanobind>=2.0.0; platform_system == 'Darwin' and platform_machine == 'arm64'", # Core utilities "numpy>=1.24.0", "psutil>=5.9.0", ] [project.optional-dependencies] -paged = [ - # Paged attention Metal kernel (opt-in, experimental) - "kernels>=0.4.5; platform_system == 'Darwin' and platform_machine == 'arm64'", -] vllm = ["vllm>=0.14.0"] stt = [ # Speech-to-text audio processing (Whisper models) @@ -58,7 +56,7 @@ dev = [ "mypy>=1.19.1", ] all = [ - "vllm-metal[vllm,paged,stt,dev]", + "vllm-metal[vllm,stt,dev]", ] [project.urls] diff --git a/tests/test_kernel_loader.py b/tests/test_kernel_loader.py deleted file mode 100644 index 0872c8a..0000000 --- a/tests/test_kernel_loader.py +++ /dev/null @@ -1,128 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Tests for kernel_loader: OS-aware revision pinning for Metal compatibility. 
- -Verifies that: -- macOS 16+ uses the latest HF kernel (default revision) -- macOS 15 and earlier pins to the Nov 2025 compat revision (Metal 3.2) -- Both revisions actually load and expose the expected ops - -Run with: - python -m pytest tests/test_kernel_loader.py -v -s -""" - -from __future__ import annotations - -from unittest import mock - -import pytest - -pytest.importorskip("kernels") - -# --------------------------------------------------------------------------- -# Unit tests (no network, no GPU) -# --------------------------------------------------------------------------- - - -class TestNeedsCompatRevision: - """Test _needs_compat_revision() with mocked macOS versions.""" - - @pytest.mark.parametrize( - "ver, expected", - [ - ("15.7.4", True), # macOS 15 — needs compat - ("14.5", True), # macOS 14 — needs compat - ("26.3", False), # macOS 26 — modern - ("", False), # empty — safe default - ], - ) - def test_version_check(self, ver, expected): - from vllm_metal.metal_kernel_backend.kernel_loader import _needs_compat_revision - - with mock.patch("platform.mac_ver", return_value=(ver, ("", "", ""), "")): - assert _needs_compat_revision() is expected - - -class TestGetKernelRevisionSelection: - """Test that get_paged_attention_ops passes the right revision to get_kernel.""" - - def _reset_kernel_cache(self): - import vllm_metal.metal_kernel_backend.kernel_loader as kl - - kl._kernel = None - - def test_macos_15_uses_compat_revision(self): - self._reset_kernel_cache() - with ( - mock.patch("platform.mac_ver", return_value=("15.7.4", ("", "", ""), "")), - mock.patch("kernels.get_kernel", return_value=mock.MagicMock()) as mk, - ): - from vllm_metal.metal_kernel_backend.kernel_loader import ( - _MACOS15_COMPAT_REVISION, - get_paged_attention_ops, - ) - - get_paged_attention_ops() - mk.assert_called_once_with( - "kernels-community/paged-attention", - revision=_MACOS15_COMPAT_REVISION, - ) - self._reset_kernel_cache() - - def test_macos_26_uses_latest(self): - 
self._reset_kernel_cache() - with ( - mock.patch("platform.mac_ver", return_value=("26.3", ("", "", ""), "")), - mock.patch("kernels.get_kernel", return_value=mock.MagicMock()) as mk, - ): - from vllm_metal.metal_kernel_backend.kernel_loader import ( - get_paged_attention_ops, - ) - - get_paged_attention_ops() - mk.assert_called_once_with( - "kernels-community/paged-attention", - revision=None, - ) - self._reset_kernel_cache() - - -# --------------------------------------------------------------------------- -# Integration tests (require network + MPS) -# --------------------------------------------------------------------------- - - -def _mps_available() -> bool: - try: - import torch - - return torch.backends.mps.is_available() - except Exception: - return False - - -@pytest.mark.skipif(not _mps_available(), reason="MPS not available") -class TestKernelLoadsForReal: - """Actually load the kernel from HuggingFace and verify ops exist.""" - - _EXPECTED_OPS = {"reshape_and_cache", "paged_attention_v1"} - - def test_latest_revision_loads(self): - from kernels import get_kernel - - kernel = get_kernel("kernels-community/paged-attention") - ops = set(dir(kernel)) - assert self._EXPECTED_OPS <= ops, f"Missing ops: {self._EXPECTED_OPS - ops}" - - def test_compat_revision_loads(self): - from kernels import get_kernel - - from vllm_metal.metal_kernel_backend.kernel_loader import ( - _MACOS15_COMPAT_REVISION, - ) - - kernel = get_kernel( - "kernels-community/paged-attention", - revision=_MACOS15_COMPAT_REVISION, - ) - ops = set(dir(kernel)) - assert self._EXPECTED_OPS <= ops, f"Missing ops: {self._EXPECTED_OPS - ops}" diff --git a/tests/test_metal_kernel_paged.py b/tests/test_metal_kernel_paged.py index 9c0c839..a1d77b6 100644 --- a/tests/test_metal_kernel_paged.py +++ b/tests/test_metal_kernel_paged.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """Tests for Metal kernel paged attention — verifies output matches non-paged path. 
-Requires ``kernels`` package with ``kernels-community/paged-attention`` support. - Run with: python -m pytest tests/test_metal_kernel_paged.py -v -s """ @@ -16,20 +14,19 @@ try: import mlx.core as mx - import torch from mlx_lm import load as mlx_lm_load from mlx_lm.models.cache import make_prompt_cache from vllm_metal.kv_cache_dtype import infer_kv_cache_dtype_from_model except ImportError as exc: pytest.skip( - f"Metal kernel paged attention tests require mlx/torch/mlx_lm: {exc}", + f"Metal kernel paged attention tests require mlx/mlx_lm: {exc}", allow_module_level=True, ) try: - from vllm_metal.metal_kernel_backend.cache import MPSPagedKVCache - from vllm_metal.metal_kernel_backend.kernel_loader import get_paged_attention_ops + from vllm_metal.metal import get_ops + from vllm_metal.metal_kernel_backend.cache import MetalPagedKVCache from vllm_metal.metal_kernel_backend.paged_attention import ( MetalKernelPagedAttentionWrapper, patch_model_attention_metal_kernel, @@ -42,22 +39,15 @@ ) except ImportError as exc: pytest.skip( - "Metal kernel paged attention tests require the vllm-metal paged backend: " - f"{exc}. 
Install with: pip install 'vllm-metal[paged]'", + f"Metal kernel paged attention tests require vllm-metal paged backend: {exc}", allow_module_level=True, ) @pytest.fixture(scope="module", autouse=True) def _paged_attention_ops_available() -> None: - """Skip this module if the paged-attention ops cannot be loaded.""" - - try: - get_paged_attention_ops() - except ImportError as exc: - pytest.skip(str(exc)) - except Exception as exc: - pytest.skip(f"kernels-community/paged-attention not available: {exc}") + """Fail early if the native paged-attention ops cannot be loaded.""" + get_ops() # --------------------------------------------------------------------------- @@ -65,8 +55,8 @@ def _paged_attention_ops_available() -> None: # --------------------------------------------------------------------------- -def _test_infer_paged_kv_dtype(model) -> torch.dtype: - """Test-only helper: choose a float dtype for MPSPagedKVCache. +def _test_infer_paged_kv_dtype(model) -> mx.Dtype: + """Test-only helper: choose a float dtype for MetalPagedKVCache. This is deliberately local to this test module. Production code uses `vllm_metal.kv_cache_dtype.infer_kv_cache_dtype_from_model()`. 
@@ -115,7 +105,7 @@ def _greedy_generate_metal_kernel( total_tokens = len(token_ids) + max_new + BLOCK_SIZE num_blocks = (total_tokens + BLOCK_SIZE - 1) // BLOCK_SIZE + 4 - mps_cache = MPSPagedKVCache( + metal_cache = MetalPagedKVCache( num_layers=num_layers, num_kv_heads=num_kv_heads, head_dim=head_dim, @@ -124,7 +114,7 @@ def _greedy_generate_metal_kernel( dtype=_test_infer_paged_kv_dtype(model), ) - n_patched = patch_model_attention_metal_kernel(model, mps_cache, BLOCK_SIZE) + n_patched = patch_model_attention_metal_kernel(model, metal_cache, BLOCK_SIZE) assert n_patched == num_layers # Assign block IDs for this sequence (manual allocation) @@ -228,7 +218,7 @@ def test_batched_decode_matches(self, qwen3_model): ) num_blocks = ((total_max + BLOCK_SIZE - 1) // BLOCK_SIZE) * len(prompts) + 8 - mps_cache = MPSPagedKVCache( + metal_cache = MetalPagedKVCache( num_layers=num_layers, num_kv_heads=num_kv_heads, head_dim=head_dim, @@ -236,7 +226,7 @@ def test_batched_decode_matches(self, qwen3_model): block_size=BLOCK_SIZE, dtype=_test_infer_paged_kv_dtype(model), ) - patch_model_attention_metal_kernel(model, mps_cache, BLOCK_SIZE) + patch_model_attention_metal_kernel(model, metal_cache, BLOCK_SIZE) # Prefill each prompt all_token_ids = [] @@ -305,7 +295,7 @@ def test_patch_replaces_self_attn(self, qwen3_model): model, _ = qwen3_model args = model.args - mps_cache = MPSPagedKVCache( + metal_cache = MetalPagedKVCache( num_layers=args.num_hidden_layers, num_kv_heads=args.num_key_value_heads, head_dim=args.head_dim, @@ -313,7 +303,7 @@ def test_patch_replaces_self_attn(self, qwen3_model): block_size=BLOCK_SIZE, dtype=_test_infer_paged_kv_dtype(model), ) - patch_model_attention_metal_kernel(model, mps_cache, BLOCK_SIZE) + patch_model_attention_metal_kernel(model, metal_cache, BLOCK_SIZE) layers = model.model.layers for i, layer in enumerate(layers): @@ -328,7 +318,7 @@ def test_fallback_when_no_context(self, qwen3_model): model, _ = qwen3_model args = model.args - mps_cache = 
MPSPagedKVCache( + metal_cache = MetalPagedKVCache( num_layers=args.num_hidden_layers, num_kv_heads=args.num_key_value_heads, head_dim=args.head_dim, @@ -336,7 +326,7 @@ def test_fallback_when_no_context(self, qwen3_model): block_size=BLOCK_SIZE, dtype=_test_infer_paged_kv_dtype(model), ) - patch_model_attention_metal_kernel(model, mps_cache, BLOCK_SIZE) + patch_model_attention_metal_kernel(model, metal_cache, BLOCK_SIZE) # Run forward without setting context → should use fallback cache = make_prompt_cache(model) diff --git a/vllm_metal/kv_cache_dtype.py b/vllm_metal/kv_cache_dtype.py index 365b53f..b52501e 100644 --- a/vllm_metal/kv_cache_dtype.py +++ b/vllm_metal/kv_cache_dtype.py @@ -2,7 +2,7 @@ """KV cache dtype inference and policy. The Metal paged-attention backend stores *activation* K/V tensors in an -MPS-backed cache. Those tensors must be floating point. Some models may have +MLX-native cache. Those tensors must be floating point. Some models may have quantized *weights* (e.g. int8), so we must not derive the KV cache dtype from weights without enforcing a float-only policy. 
""" @@ -12,17 +12,16 @@ from dataclasses import dataclass from typing import Any -import torch +import mlx.core as mx from vllm_metal.paged_attention_common import find_layers_and_attr -from vllm_metal.pytorch_backend.tensor_bridge import MLX_TO_TORCH_DTYPE -DEFAULT_KV_CACHE_DTYPE = torch.float16 -ALLOWED_KV_CACHE_DTYPES: frozenset[torch.dtype] = frozenset( +DEFAULT_KV_CACHE_DTYPE: mx.Dtype = mx.float16 +ALLOWED_KV_CACHE_DTYPES: frozenset[mx.Dtype] = frozenset( { - torch.float16, - torch.bfloat16, - torch.float32, + mx.float16, + mx.bfloat16, + mx.float32, } ) @@ -31,18 +30,17 @@ class KvCacheDtypeInference: """Result of inferring the KV cache dtype from a model.""" - dtype: torch.dtype + dtype: mx.Dtype warning: str | None = None def infer_kv_cache_dtype_from_model( - model: Any, *, default: torch.dtype = DEFAULT_KV_CACHE_DTYPE + model: Any, *, default: mx.Dtype = DEFAULT_KV_CACHE_DTYPE ) -> KvCacheDtypeInference: """Infer a float KV-cache dtype from an MLX(-LM/-VLM) model. Policy: - - If we can map the model's attention weight dtype to torch and it's a - supported float dtype, use it. + - If the model's attention weight dtype is a supported float dtype, use it. - Otherwise, fall back to *default* and provide a warning string the caller may log. 
""" @@ -62,20 +60,13 @@ def infer_kv_cache_dtype_from_model( warning=f"Cannot infer KV cache dtype from model ({exc}); using {default}", ) - torch_dtype = MLX_TO_TORCH_DTYPE.get(mlx_dtype) - if torch_dtype is None: - return KvCacheDtypeInference( - dtype=default, - warning=f"Unsupported MLX dtype for KV cache ({mlx_dtype!r}); using {default}", - ) - - if torch_dtype not in ALLOWED_KV_CACHE_DTYPES: + if mlx_dtype not in ALLOWED_KV_CACHE_DTYPES: return KvCacheDtypeInference( dtype=default, warning=( - f"Model weight dtype {mlx_dtype!r} maps to non-float torch dtype " - f"{torch_dtype}; using {default} for KV cache instead" + f"Model weight dtype {mlx_dtype!r} is not a supported float dtype; " + f"using {default} for KV cache instead" ), ) - return KvCacheDtypeInference(dtype=torch_dtype) + return KvCacheDtypeInference(dtype=mlx_dtype) diff --git a/vllm_metal/metal/README.md b/vllm_metal/metal/README.md new file mode 100644 index 0000000..1d65a78 --- /dev/null +++ b/vllm_metal/metal/README.md @@ -0,0 +1,43 @@ +# Metal Kernel Sources + +This directory contains two sets of Metal paged-attention shaders, both vendored from [mistral.rs](https://github.com/EricLBuehler/mistral.rs) (MIT license). + +## `kernels/` — active (current) + +Drop-in replacement for the HuggingFace kernels-community paged-attention shaders, originally vendored from an older version of mistral.rs. This is what `paged_ops.cpp` compiles and dispatches today via MLX. 
+ +| File | Purpose | +|------|---------| +| `utils.metal` | bfloat16 polyfill, operator overloads | +| `float8.metal` | FP8 E4M3/E5M2 encode/decode helpers | +| `attention/paged_attention.metal` | paged attention v1/v2 kernels | +| `cache/reshape_and_cache.metal` | write projected K/V into block cache | +| `cache/copy_blocks.metal` | block-level cache copy kernel | +| `convert_fp8.metal` | FP8 precision conversion kernel | + +### Reference only (not compiled, kept for context) + +- `paged_attention.mm` — PyTorch MPS dispatch (Obj-C++), replaced by `paged_ops.cpp` +- `cache.mm` — PyTorch MPS cache ops (Obj-C++), replaced by `paged_ops.cpp` + +## `kernels_v1/` — next-generation (not yet wired up) + +Latest Metal kernels from the mistral.rs repo. More mature than `kernels/`, with preliminary scaffolding for variable-length sequences and gpt-oss sink attention support. + +| File | Purpose | +|------|---------| +| `utils.metal` | shared types and helpers | +| `float8.metal` | FP8 encode/decode helpers | +| `pagedattention.metal` | paged attention kernel (restructured) | +| `reshape_and_cache.metal` | K/V cache reshape kernel | +| `copy_blocks.metal` | block-level cache copy kernel | +| `gather_kv_cache.metal` | *new* — gather KV from non-contiguous blocks | +| `kv_scale_update.metal` | *new* — KV scale update for quantised caches | + +## Deprecation plan + +Neither kernel set will persist long-term. Both are slated for deprecation once we introduce first-class variable-length kernel support, which is a prerequisite for: + +- Continuous batching +- Chunked prefill +- MQA Scorer speculative decoding diff --git a/vllm_metal/metal/__init__.py b/vllm_metal/metal/__init__.py new file mode 100644 index 0000000..2cbc142 --- /dev/null +++ b/vllm_metal/metal/__init__.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Native paged-attention Metal kernels dispatched through MLX. 
+ +Usage:: + + from vllm_metal.metal import get_ops + ops = get_ops() + ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping) + ops.paged_attention_v1(out, query, key_cache, value_cache, ...) +""" + +from __future__ import annotations + +import importlib +import importlib.util +import logging +import re +from pathlib import Path +from types import ModuleType + +logger = logging.getLogger(__name__) + +_THIS_DIR = Path(__file__).resolve().parent +_KERNELS_DIR = _THIS_DIR / "kernels" + +# Cached after first get_ops() call. The Metal shaders are JIT-compiled once +# and held in MLX's library cache for the lifetime of the process. Editing +# .metal source files requires restarting the Python interpreter to pick up +# changes (the .cpp extension itself is rebuilt automatically by build.py when +# paged_ops.cpp is newer than the .so). +_ops_module: ModuleType | None = None + + +def _read_metal_source(path: Path) -> str: + """Read a .metal file and strip local #include directives.""" + text = path.read_text() + # Remove #include "..." for our vendored files (keep etc.) + text = re.sub(r'#include\s+"[^"]*"', "", text) + return text + + +def _build_reshape_cache_source() -> str: + """Concatenate utils + float8 + reshape_and_cache into a single source.""" + parts = [ + _read_metal_source(_KERNELS_DIR / "utils.metal"), + _read_metal_source(_KERNELS_DIR / "float8.metal"), + _read_metal_source(_KERNELS_DIR / "cache" / "reshape_and_cache.metal"), + ] + return "\n".join(parts) + + +def _build_paged_attention_source() -> str: + """Concatenate utils + float8 + paged_attention into a single source.""" + parts = [ + _read_metal_source(_KERNELS_DIR / "utils.metal"), + _read_metal_source(_KERNELS_DIR / "float8.metal"), + _read_metal_source(_KERNELS_DIR / "attention" / "paged_attention.metal"), + ] + return "\n".join(parts) + + +def get_ops() -> ModuleType: + """JIT-build and import the native paged_ops extension. 
+ + The Metal shader sources are read, pre-processed (includes inlined), + and passed to the C++ extension which JIT-compiles them via + ``mlx::core::metal::Device::get_library()``. + + Returns: + The ``_paged_ops`` module with ``reshape_and_cache()`` and + ``paged_attention_v1()``. + """ + global _ops_module + if _ops_module is not None: + return _ops_module + + # 1. JIT-build the C++ extension if needed + from vllm_metal.metal.build import build + + so_path = build() + + # 2. Import the built extension + spec = importlib.util.spec_from_file_location("_paged_ops", str(so_path)) + if spec is None or spec.loader is None: + raise ImportError(f"Cannot load extension from {so_path}") + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + + # 3. Initialise Metal libraries (JIT-compile shaders) + reshape_src = _build_reshape_cache_source() + paged_attn_src = _build_paged_attention_source() + mod.init_libraries(reshape_src, paged_attn_src) + + _ops_module = mod + logger.info("Native paged-attention Metal kernels loaded") + return mod diff --git a/vllm_metal/metal/build.py b/vllm_metal/metal/build.py new file mode 100644 index 0000000..137e519 --- /dev/null +++ b/vllm_metal/metal/build.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: Apache-2.0 +"""JIT build script for the native paged-attention Metal extension. + +Compiles ``paged_ops.cpp`` + nanobind into a shared library that dispatches +Metal shaders through MLX's own command encoder. 
+""" + +from __future__ import annotations + +import logging +import subprocess +import sysconfig +from pathlib import Path + +logger = logging.getLogger(__name__) + +_THIS_DIR = Path(__file__).resolve().parent +_SRC = _THIS_DIR / "paged_ops.cpp" +_EXT_SUFFIX = sysconfig.get_config_var("EXT_SUFFIX") or ".so" +_CACHE_DIR = Path.home() / ".cache" / "vllm-metal" +_CACHE_DIR.mkdir(parents=True, exist_ok=True) +_OUT = _CACHE_DIR / f"_paged_ops{_EXT_SUFFIX}" + + +def _find_package_path(name: str) -> Path: + """Resolve a Python package's root directory.""" + import importlib + + mod = importlib.import_module(name) + paths = getattr(mod, "__path__", None) + if paths: + return Path(list(paths)[0]) + f = getattr(mod, "__file__", None) + if f: + return Path(f).parent + raise RuntimeError(f"Cannot locate package '{name}'") + + +def needs_rebuild() -> bool: + """Return True if the .so is missing or older than the source.""" + if not _OUT.exists(): + return True + src_mtime = _SRC.stat().st_mtime + return _OUT.stat().st_mtime < src_mtime + + +def build() -> Path: + """JIT-build the native extension, returning the path to the .so.""" + if not needs_rebuild(): + return _OUT + + logger.info("Building native paged-attention extension ...") + + py_include = sysconfig.get_paths()["include"] + nb_path = _find_package_path("nanobind") + mlx_path = _find_package_path("mlx") + mlx_include = mlx_path / "include" + mlx_lib = mlx_path / "lib" + metal_cpp = mlx_include / "metal_cpp" + + # Verify critical paths exist + for p, label in [ + (py_include, "Python include"), + (nb_path, "nanobind"), + (mlx_include, "MLX include"), + (mlx_lib / "libmlx.dylib", "MLX lib"), + ]: + if not Path(p).exists(): + raise FileNotFoundError(f"{label} not found: {p}") + + nb_src = nb_path / "src" / "nb_combined.cpp" + if not nb_src.exists(): + raise FileNotFoundError(f"nanobind source not found: {nb_src}") + + cmd = [ + "clang++", + "-std=c++17", + "-shared", + "-fPIC", + "-O2", + "-fvisibility=default", + 
f"-I{py_include}", + f"-I{nb_path / 'include'}", + f"-I{nb_path / 'src'}", + f"-I{nb_path / 'ext' / 'robin_map' / 'include'}", + f"-I{mlx_include}", + f"-I{metal_cpp}", + f"-L{mlx_lib}", + "-lmlx", + "-framework", + "Metal", + "-framework", + "Foundation", + f"-Wl,-rpath,{mlx_lib}", + "-D_METAL_", + "-DACCELERATE_NEW_LAPACK", + "-undefined", + "dynamic_lookup", + str(nb_src), + str(_SRC), + "-o", + str(_OUT), + ] + + logger.info(" %s", " ".join(cmd)) + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + raise RuntimeError( + f"Failed to build paged_ops extension:\n" + f"stdout:\n{result.stdout}\n" + f"stderr:\n{result.stderr}" + ) + + logger.info("Built %s", _OUT) + return _OUT diff --git a/vllm_metal/metal/kernels/README.md b/vllm_metal/metal/kernels/README.md new file mode 100644 index 0000000..32b4195 --- /dev/null +++ b/vllm_metal/metal/kernels/README.md @@ -0,0 +1,17 @@ +# Metal Kernel Sources + +Vendored from [mistral.rs](https://github.com/EricLBuehler/mistral.rs) (MIT license), via [HuggingFace kernels-community](https://github.com/huggingface/kernels-community). 
+ +## Active (used by `paged_ops.cpp` via MLX dispatch) + +- `utils.metal` — bfloat16 polyfill, operator overloads +- `float8.metal` — FP8 E4M3/E5M2 encode/decode helpers +- `attention/paged_attention.metal` — paged attention v1/v2 kernels +- `cache/reshape_and_cache.metal` — write projected K/V into block cache +- `cache/copy_blocks.metal` — block-level cache copy kernel +- `convert_fp8.metal` — FP8 precision conversion kernel + +## Reference only (not compiled, kept for future use) + +- `paged_attention.mm` — PyTorch MPS dispatch (Obj-C++), replaced by `../paged_ops.cpp` +- `cache.mm` — PyTorch MPS cache ops (Obj-C++), replaced by `../paged_ops.cpp` diff --git a/vllm_metal/metal/kernels/attention/paged_attention.metal b/vllm_metal/metal/kernels/attention/paged_attention.metal new file mode 100644 index 0000000..22d972d --- /dev/null +++ b/vllm_metal/metal/kernels/attention/paged_attention.metal @@ -0,0 +1,1401 @@ +// Updated from MLX commit hash f70764a + +#include "../utils.metal" +#include "../float8.metal" +#include <metal_stdlib> +#include <metal_simdgroup> + +using namespace metal; + +// ========================================== Generic vector types + +// A vector type to store Q, K, V elements. +template <typename T, int VEC_SIZE> struct Vec {}; + +// A vector type to store FP32 accumulators. +template <typename T> struct FloatVec {}; + +// Template vector operations. +template <typename Acc, typename A, typename B> inline Acc mul(A a, B b); + +template <typename T> inline float sum(T v); + +template <typename T> inline float dot(T a, T b) { + return sum(mul<T, T, T>(a, b)); +} + +template <typename A, typename T> inline float dot(T a, T b) { + return sum(mul<A, T, T>(a, b)); +} + +// FP32 vector data types. 
+struct Float8_ { + float4 x; + float4 y; +}; + +template <> struct Vec { + using Type = float; +}; +template <> struct Vec { + using Type = float2; +}; +template <> struct Vec { + using Type = float4; +}; +template <> struct Vec { + using Type = Float8_; +}; + +template <> struct FloatVec { + using Type = float; +}; +template <> struct FloatVec { + using Type = float2; +}; +template <> struct FloatVec { + using Type = float4; +}; +template <> struct FloatVec { + using Type = Float8_; +}; + +template <> inline float mul(float a, float b) { return a * b; } + +template <> inline float2 mul(float2 a, float2 b) { return a * b; } + +template <> inline float4 mul(float4 a, float4 b) { return a * b; } + +template <> inline Float8_ mul(Float8_ a, Float8_ b) { + Float8_ c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template <> inline float sum(float a) { return a; } + +template <> inline float sum(float2 a) { return a.x + a.y; } + +template <> inline float sum(float4 a) { return a.x + a.y + a.z + a.w; } + +template <> inline float sum(Float8_ a) { return sum(a.x) + sum(a.y); } + +inline Float8_ fma(Float8_ a, Float8_ b, Float8_ c) { + Float8_ res; + res.x = fma(a.x, b.x, c.x); + res.y = fma(a.y, b.y, c.y); + return res; +} + +inline void from_float(thread float &dst, float src) { dst = src; } +inline void from_float(thread float2 &dst, float2 src) { dst = src; } +inline void from_float(thread float4 &dst, float4 src) { dst = src; } +inline void from_float(thread Float8_ &dst, Float8_ src) { dst = src; } + +// BF16 vector data types. 
+// #if defined(__HAVE_BFLOAT__) + +// struct Bfloat8_ { +// bfloat4 x; +// bfloat4 y; +// }; + +// template<> +// struct Vec { +// using Type = bfloat; +// }; +// template<> +// struct Vec { +// using Type = bfloat2; +// }; +// template<> +// struct Vec { +// using Type = bfloat4; +// }; +// template<> +// struct Vec { +// using Type = Bfloat8_; +// }; + +// template<> +// struct FloatVec { +// using Type = float; +// }; +// template<> +// struct FloatVec { +// using Type = float2; +// }; +// template<> +// struct FloatVec { +// using Type = float4; +// }; +// template<> +// struct FloatVec { +// using Type = Float8_; +// }; + +// template<> +// inline float mul(bfloat a, bfloat b) { +// return (float)a * (float)b; +// } +// template<> +// inline bfloat mul(bfloat a, bfloat b) { +// return a*b; +// } + +// template<> +// inline float2 mul(bfloat2 a, bfloat2 b) { +// return (float2)a * (float2)b; +// } +// template<> +// inline bfloat2 mul(bfloat2 a, bfloat2 b) { +// return a * b; +// } + +// template<> +// inline float4 mul(bfloat4 a, bfloat4 b) { +// return (float4)a * (float4)b; +// } +// template<> +// inline bfloat4 mul(bfloat4 a, bfloat4 b) { +// return a * b; +// } + +// template<> +// inline Float8_ mul(Bfloat8_ a, Bfloat8_ b) { +// Float8_ c; +// c.x = mul(a.x, b.x); +// c.y = mul(a.y, b.y); +// return c; +// } +// template<> +// inline Bfloat8_ mul(Bfloat8_ a, Bfloat8_ b) { +// Bfloat8_ c; +// c.x = mul(a.x, b.x); +// c.y = mul(a.y, b.y); +// return c; +// } + +// template<> +// inline float sum(bfloat a) { +// return (float)a; +// } + +// template<> +// inline float sum(bfloat2 a) { +// return (float)a.x + (float)a.y; +// } + +// template<> +// inline float sum(bfloat4 a) { +// return sum(a.x) + sum(a.y); +// } + +// template<> +// inline float sum(Bfloat8_ a) { +// return sum(a.x) + sum(a.y); +// } + +// inline float fma(bfloat a, bfloat b, float c) { +// return (float)a * (float)b + c; +// } + +// inline float2 fma(bfloat2 a, bfloat2 b, float2 c) { +// 
return (float2)a * (float2)b + c; +// } + +// inline float4 fma(bfloat4 a, bfloat4 b, float4 c) { +// return (float4)a * (float4)b + c; +// } + +// inline Float8_ fma(Bfloat8_ a, Bfloat8_ b, Float8_ c) { +// Float8_ res; +// res.x = fma((float4)a.x, (float4)b.x, (float4)c.x); +// res.y = fma((float4)a.y, (float4)b.y, (float4)c.y); +// return res; +// } +// inline Bfloat8_ fma(Bfloat8_ a, Bfloat8_ b, Bfloat8_ c) { +// Bfloat8_ res; +// res.x = (bfloat4)fma((float4)a.x, (float4)b.x, (float4)c.x); +// res.y = (bfloat4)fma((float4)a.y, (float4)b.x, (float4)c.y); +// return c; +// } + +// inline void from_float(thread bfloat& dst, float src) { +// dst = static_cast(src); +// } +// inline void from_float(thread bfloat2& dst, float2 src) { +// dst.x = static_cast(src.x); +// dst.y = static_cast(src.y); +// } +// inline void from_float(thread bfloat4& dst, float4 src) { +// dst.x = static_cast(src.x); +// dst.y = static_cast(src.y); +// dst.z = static_cast(src.z); +// dst.w = static_cast(src.w); +// } +// inline void from_float(thread Bfloat8_& dst, Float8_ src) { +// bfloat4 x; +// bfloat4 y; +// from_float(x, src.x); +// from_float(y, src.y); +// dst.x = x; +// dst.y = y; +// } + +// #else + +struct Bfloat2_ { + bfloat16_t x; + bfloat16_t y; +}; + +struct Bfloat4_ { + Bfloat2_ x; + Bfloat2_ y; +}; + +struct Bfloat8_ { + Bfloat4_ x; + Bfloat4_ y; +}; + +template <> struct Vec { + using Type = bfloat16_t; +}; +template <> struct Vec { + using Type = Bfloat2_; +}; +template <> struct Vec { + using Type = Bfloat4_; +}; +template <> struct Vec { + using Type = Bfloat8_; +}; + +template <> struct FloatVec { + using Type = float; +}; +template <> struct FloatVec { + using Type = float2; +}; +template <> struct FloatVec { + using Type = float4; +}; +template <> struct FloatVec { + using Type = Float8_; +}; + +template <> inline float mul(bfloat16_t a, bfloat16_t b) { + return (float)a * (float)b; +} +template <> inline bfloat16_t mul(bfloat16_t a, bfloat16_t b) { return a * b; } 
// ========================================== BF16 vector arithmetic
// Lane-wise products, horizontal sums, fused multiply-adds and fp32->bf16
// conversion for the bfloat16 wrapper vectors (Bfloat2_/Bfloat4_/Bfloat8_).
// NOTE(review): the template-argument lists below were reconstructed; the
// primary templates (mul/sum/Vec/FloatVec) are declared earlier in this file
// — confirm the parameter order matches those declarations.

template <> inline float2 mul<float2, Bfloat2_>(Bfloat2_ a, Bfloat2_ b) {
  float2 a_f((float)a.x, (float)a.y);
  float2 b_f((float)b.x, (float)b.y);
  return a_f * b_f;
}
template <> inline Bfloat2_ mul<Bfloat2_, Bfloat2_>(Bfloat2_ a, Bfloat2_ b) {
  Bfloat2_ c;
  c.x = a.x * b.x;
  c.y = a.y * b.y;
  return c;
}

template <> inline float4 mul<float4, Bfloat4_>(Bfloat4_ a, Bfloat4_ b) {
  float2 x = mul<float2, Bfloat2_>(a.x, b.x);
  float2 y = mul<float2, Bfloat2_>(a.y, b.y);
  float4 c;
  c.x = x.x;
  c.y = x.y;
  c.z = y.x;
  c.w = y.y;
  return c;
}
template <> inline Bfloat4_ mul<Bfloat4_, Bfloat4_>(Bfloat4_ a, Bfloat4_ b) {
  Bfloat4_ c;
  c.x = mul<Bfloat2_, Bfloat2_>(a.x, b.x);
  c.y = mul<Bfloat2_, Bfloat2_>(a.y, b.y);
  return c;
}

template <> inline Float8_ mul<Float8_, Bfloat8_>(Bfloat8_ a, Bfloat8_ b) {
  Float8_ c;
  c.x = mul<float4, Bfloat4_>(a.x, b.x);
  c.y = mul<float4, Bfloat4_>(a.y, b.y);
  return c;
}
template <> inline Bfloat8_ mul<Bfloat8_, Bfloat8_>(Bfloat8_ a, Bfloat8_ b) {
  Bfloat8_ c;
  c.x = mul<Bfloat4_, Bfloat4_>(a.x, b.x);
  c.y = mul<Bfloat4_, Bfloat4_>(a.y, b.y);
  return c;
}

// Horizontal sums: always accumulate in fp32 for accuracy.
template <> inline float sum(bfloat16_t a) { return (float)a; }

template <> inline float sum(Bfloat2_ a) { return (float)a.x + (float)a.y; }

template <> inline float sum(Bfloat4_ a) { return sum(a.x) + sum(a.y); }

template <> inline float sum(Bfloat8_ a) { return sum(a.x) + sum(a.y); }

// Fused multiply-add: the float-accumulator overloads widen the products to
// fp32; the bf16-accumulator overloads stay in bf16 end to end.
inline float fma(bfloat16_t a, bfloat16_t b, float c) {
  return (float)a * (float)b + c;
}
inline bfloat16_t fma(bfloat16_t a, bfloat16_t b, bfloat16_t c) {
  return a * b + c;
}

inline float2 fma(Bfloat2_ a, Bfloat2_ b, float2 c) {
  float2 a_f((float)a.x, (float)a.y);
  float2 b_f((float)b.x, (float)b.y);
  return a_f * b_f + c;
}
inline Bfloat2_ fma(Bfloat2_ a, Bfloat2_ b, Bfloat2_ c) {
  Bfloat2_ res;
  res.x = a.x * b.x + c.x;
  res.y = a.y * b.y + c.y;
  return res;
}

inline float4 fma(Bfloat4_ a, Bfloat4_ b, float4 c) {
  float4 res;
  res.x = fma(a.x.x, b.x.x, c.x);
  res.y = fma(a.x.y, b.x.y, c.y);
  res.z = fma(a.y.x, b.y.x, c.z);
  res.w = fma(a.y.y, b.y.y, c.w);
  return res;
}
inline Bfloat4_ fma(Bfloat4_ a, Bfloat4_ b, Bfloat4_ c) {
  Bfloat4_ res;
  res.x = fma(a.x, b.x, c.x);
  res.y = fma(a.y, b.y, c.y);
  return res;
}

inline Float8_ fma(Bfloat8_ a, Bfloat8_ b, Float8_ c) {
  float4 x = fma(a.x, b.x, c.x);
  float4 y = fma(a.y, b.y, c.y);
  Float8_ res;
  res.x = x;
  res.y = y;
  return res;
}
inline Bfloat8_ fma(Bfloat8_ a, Bfloat8_ b, Bfloat8_ c) {
  Bfloat8_ res;
  res.x = fma(a.x, b.x, c.x);
  res.y = fma(a.y, b.y, c.y);
  return res;
}

// fp32 -> bf16 narrowing stores (used to write softmax logits as T).
inline void from_float(thread bfloat16_t &dst, float src) {
  dst = static_cast<bfloat16_t>(src);
}
inline void from_float(thread Bfloat2_ &dst, float2 src) {
  dst.x = static_cast<bfloat16_t>(src.x);
  dst.y = static_cast<bfloat16_t>(src.y);
}
inline void from_float(thread Bfloat4_ &dst, float4 src) {
  dst.x.x = static_cast<bfloat16_t>(src.x);
  dst.x.y = static_cast<bfloat16_t>(src.y);
  dst.y.x = static_cast<bfloat16_t>(src.z);
  dst.y.y = static_cast<bfloat16_t>(src.w);
}
inline void from_float(thread Bfloat8_ &dst, Float8_ src) {
  Bfloat4_ x;
  Bfloat4_ y;
  from_float(x, src.x);
  from_float(y, src.y);
  dst.x = x;
  dst.y = y;
}

// #endif

// FP16 vector data types.
struct Half8_ {
  half4 x;
  half4 y;
};

template <> struct Vec<half, 1> {
  using Type = half;
};
template <> struct Vec<half, 2> {
  using Type = half2;
};
template <> struct Vec<half, 4> {
  using Type = half4;
};
template <> struct Vec<half, 8> {
  using Type = Half8_;
};

template <> struct FloatVec<half> {
  using Type = float;
};
template <> struct FloatVec<half2> {
  using Type = float2;
};
template <> struct FloatVec<half4> {
  using Type = float4;
};
template <> struct FloatVec<Half8_> {
  using Type = Float8_;
};

template <> inline float mul<float, half>(half a, half b) {
  return (float)a * (float)b;
}
template <> inline half mul<half, half>(half a, half b) { return a * b; }

template <> inline float2 mul<float2, half2>(half2 a, half2 b) {
  return (float2)a * (float2)b;
}
template <> inline half2 mul<half2, half2>(half2 a, half2 b) { return a * b; }

template <> inline float4 mul<float4, half4>(half4 a, half4 b) {
  return (float4)a * (float4)b;
}
template <> inline half4 mul<half4, half4>(half4 a, half4 b) { return a * b; }

template <> inline Float8_ mul<Float8_, Half8_>(Half8_ a, Half8_ b) {
  float4 x = mul<float4, half4>(a.x, b.x);
  float4 y = mul<float4, half4>(a.y, b.y);
  Float8_ c;
  c.x = x;
  c.y = y;
  return c;
}
template <> inline Half8_ mul<Half8_, Half8_>(Half8_ a, Half8_ b) {
  Half8_ c;
  c.x = mul<half4, half4>(a.x, b.x);
  c.y = mul<half4, half4>(a.y, b.y);
  return c;
}

template <> inline float sum(half a) { return (float)a; }

template <> inline float sum(half2 a) { return (float)a.x + (float)a.y; }

template <> inline float sum(half4 a) { return a.x + a.y + a.z + a.w; }

template <> inline float sum(Half8_ a) { return sum(a.x) + sum(a.y); }

inline float fma(half a, half b, float c) { return (float)a * (float)b + c; }

inline float2 fma(half2 a, half2 b, float2 c) {
  return (float2)a * (float2)b + c;
}

inline float4 fma(half4 a, half4 b, float4 c) {
  return (float4)a * (float4)b + c;
}

inline Float8_ fma(Half8_ a, Half8_ b, Float8_ c) {
  float4 x = fma(a.x, b.x, c.x);
  float4 y = fma(a.y, b.y, c.y);
  Float8_ res;
  res.x = x;
  res.y = y;
  return res;
}
inline Half8_ fma(Half8_ a, Half8_ b, Half8_ c) {
  Half8_ res;
  res.x = fma(a.x, b.x, c.x);
  res.y = fma(a.y, b.y, c.y);
  return res;
}

inline void from_float(thread half &dst, float src) {
  dst = static_cast<half>(src);
}
inline void from_float(thread half2 &dst, float2 src) {
  dst.x = static_cast<half>(src.x);
  dst.y = static_cast<half>(src.y);
}
inline void from_float(thread half4 &dst, float4 src) {
  dst.x = static_cast<half>(src.x);
  dst.y = static_cast<half>(src.y);
  dst.z = static_cast<half>(src.z);
  dst.w = static_cast<half>(src.w);
}
inline void from_float(thread Half8_ &dst, Float8_ src) {
  half4 x;
  half4 y;
  from_float(x, src.x);
  from_float(y, src.y);
  dst.x = x;
  dst.y = y;
}

// ========================================== FP8 (uchar) vector data types.

// 8-lane uchar vector - Metal only provides up to uchar4, so build our own.
struct Uchar8_ {
  uchar4 x;
  uchar4 y;
};

// Vec specialisations so Vec<uchar, N>::Type resolves correctly.
template <> struct Vec<uchar, 1> {
  using Type = uchar;
};
template <> struct Vec<uchar, 2> {
  using Type = uchar2;
};
template <> struct Vec<uchar, 4> {
  using Type = uchar4;
};
template <> struct Vec<uchar, 8> {
  using Type = Uchar8_;
};

// Compile-time test for the FP8 (byte) cache dtype.
// General case: not uchar.
template <typename T> inline constexpr bool is_uchar() { return false; }

// Specialization: T is uchar.
template <> inline constexpr bool is_uchar<uchar>() { return true; }

// Generic fallback - will fail to compile if a required specialisation is
// missing.
template <typename Vec, typename Quant_vec>
inline Vec fp8_convert(const thread Quant_vec &, float scale) {
  static_assert(sizeof(Vec) == 0, "Missing fp8_convert specialisation");
}

// ========================================== FP8 -> float/half/bfloat
// Dequantize one FP8 E4M3 byte: decode to fp32 then apply the tensor scale.
inline float __dequant_single(uchar v, float scale) {
  return fp8_e4m3_to_float(v) * scale;
}

// ---- 1-lane ----
template <>
inline float fp8_convert<float, uchar>(const thread uchar &in, float scale) {
  return __dequant_single(in, scale);
}
template <>
inline half fp8_convert<half, uchar>(const thread uchar &in, float scale) {
  return half(__dequant_single(in, scale));
}
template <>
inline bfloat16_t fp8_convert<bfloat16_t, uchar>(const thread uchar &in,
                                                 float scale) {
  return bfloat16_t(__dequant_single(in, scale));
}

// ---- 2-lane ----
template <>
inline float2 fp8_convert<float2, uchar2>(const thread uchar2 &in,
                                          float scale) {
  return float2(__dequant_single(in.x, scale), __dequant_single(in.y, scale));
}
template <>
inline half2 fp8_convert<half2, uchar2>(const thread uchar2 &in, float scale) {
  half2 out;
  out.x = half(__dequant_single(in.x, scale));
  out.y = half(__dequant_single(in.y, scale));
  return out;
}
template <>
inline Bfloat2_ fp8_convert<Bfloat2_, uchar2>(const thread uchar2 &in,
                                              float scale) {
  Bfloat2_ out;
  out.x = bfloat16_t(__dequant_single(in.x, scale));
  out.y = bfloat16_t(__dequant_single(in.y, scale));
  return out;
}

// ---- 4-lane ----
template <>
inline float4 fp8_convert<float4, uchar4>(const thread uchar4 &in,
                                          float scale) {
  return float4(__dequant_single(in.x, scale), __dequant_single(in.y, scale),
                __dequant_single(in.z, scale), __dequant_single(in.w, scale));
}
template <>
inline half4 fp8_convert<half4, uchar4>(const thread uchar4 &in, float scale) {
  half4 out;
  out.x = half(__dequant_single(in.x, scale));
  out.y = half(__dequant_single(in.y, scale));
  out.z = half(__dequant_single(in.z, scale));
  out.w = half(__dequant_single(in.w, scale));
  return out;
}
template <>
inline Bfloat4_ fp8_convert<Bfloat4_, uchar4>(const thread uchar4 &in,
                                              float scale) {
  Bfloat4_ out;
  out.x.x = bfloat16_t(__dequant_single(in.x, scale));
  out.x.y = bfloat16_t(__dequant_single(in.y, scale));
  out.y.x = bfloat16_t(__dequant_single(in.z, scale));
  out.y.y = bfloat16_t(__dequant_single(in.w, scale));
  return out;
}

// ---- 8-lane ----
template <>
inline Float8_ fp8_convert<Float8_, Uchar8_>(const thread Uchar8_ &in,
                                             float scale) {
  Float8_ out;
  out.x =
      float4(__dequant_single(in.x.x, scale), __dequant_single(in.x.y, scale),
             __dequant_single(in.x.z, scale), __dequant_single(in.x.w, scale));
  out.y =
      float4(__dequant_single(in.y.x, scale), __dequant_single(in.y.y, scale),
             __dequant_single(in.y.z, scale), __dequant_single(in.y.w, scale));
  return out;
}
template <>
inline Half8_ fp8_convert<Half8_, Uchar8_>(const thread Uchar8_ &in,
                                           float scale) {
  Half8_ out;
  out.x = half4(half(__dequant_single(in.x.x, scale)),
                half(__dequant_single(in.x.y, scale)),
                half(__dequant_single(in.x.z, scale)),
                half(__dequant_single(in.x.w, scale)));
  out.y = half4(half(__dequant_single(in.y.x, scale)),
                half(__dequant_single(in.y.y, scale)),
                half(__dequant_single(in.y.z, scale)),
                half(__dequant_single(in.y.w, scale)));
  return out;
}
template <>
inline Bfloat8_ fp8_convert<Bfloat8_, Uchar8_>(const thread Uchar8_ &in,
                                               float scale) {
  Bfloat8_ out;
  // first 4
  out.x.x.x = bfloat16_t(__dequant_single(in.x.x, scale));
  out.x.x.y = bfloat16_t(__dequant_single(in.x.y, scale));
  out.x.y.x = bfloat16_t(__dequant_single(in.x.z, scale));
  out.x.y.y = bfloat16_t(__dequant_single(in.x.w, scale));
  // second 4
  out.y.x.x = bfloat16_t(__dequant_single(in.y.x, scale));
  out.y.x.y = bfloat16_t(__dequant_single(in.y.y, scale));
  out.y.y.x = bfloat16_t(__dequant_single(in.y.z, scale));
  out.y.y.y = bfloat16_t(__dequant_single(in.y.w, scale));
  return out;
}

// ========================================== Dot product utilities

// Q*K^T partial dot product for one thread group: accumulate lane-wise
// products in fp32, then reduce across the THREAD_GROUP_SIZE threads that
// cooperatively hold one key.
// TODO(EricLBuehler): optimize with vectorization
template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline float qk_dot_(const threadgroup Vec (&q)[N], const thread Vec (&k)[N]) {
  // Compute the parallel products for Q*K^T (treat vector lanes separately).
  using A_vec = typename FloatVec<Vec>::Type;
  A_vec qk_vec = mul<A_vec, Vec>(q[0], k[0]);
#pragma unroll
  for (int ii = 1; ii < N; ++ii) {
    qk_vec = fma(q[ii], k[ii], qk_vec);
  }

  // Finalize the reduction across lanes.
  float qk = sum(qk_vec);
#pragma unroll
  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
    qk += simd_shuffle_xor(qk, mask);
  }
  return qk;
}

template <typename T, int THREAD_GROUP_SIZE> struct Qk_dot {
  template <typename Vec, int N>
  static inline float dot(const threadgroup Vec (&q)[N],
                          const thread Vec (&k)[N]) {
    return qk_dot_<THREAD_GROUP_SIZE>(q, k);
  }
};

// ========================================== Block sum utility

// Threadgroup-wide sum used by the attention softmax: reduce within each
// simdgroup, stage the per-simdgroup partials in red_smem (NUM_WARPS floats),
// then reduce the partials and broadcast the total to every thread.
template <int NUM_WARPS, int NUM_SIMD_LANES>
inline float block_sum(threadgroup float *red_smem, float sum, uint simd_tid,
                       uint simd_lid) {
  // Compute the sum per simdgroup.
#pragma unroll
  for (int mask = NUM_SIMD_LANES / 2; mask >= 1; mask /= 2) {
    sum += simd_shuffle_xor(sum, mask);
  }

  // Simd leaders store the data to shared memory.
  if (simd_lid == 0) {
    red_smem[simd_tid] = sum;
  }

  // Make sure the data is in shared memory.
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // The warps compute the final sums.
  if (simd_lid < NUM_WARPS) {
    sum = red_smem[simd_lid];
  }

  // Parallel reduction inside the simd group.
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    sum += simd_shuffle_xor(sum, mask);
  }

  // Broadcast to other threads.
  return simd_shuffle(sum, 0);
}
// ========================================== Paged Attention kernel

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

// Function constants: the host specializes the pipeline state at creation
// time instead of the kernel branching on runtime flags.
constant bool use_partitioning [[function_constant(10)]];
constant bool use_alibi [[function_constant(20)]];
constant bool use_fp8_scales [[function_constant(30)]];

// Single-query paged attention over a block-table-indexed KV cache.
// T is the activation dtype; CACHE_T is the cache dtype (uchar => FP8 E4M3,
// dequantized via k_scale/v_scale). PARTITION_SIZE == 0 selects the v1 path
// (one threadgroup covers the whole context); PARTITION_SIZE > 0 selects the
// v2 path, where each threadgroup handles one partition and emits partial
// results (exp_sums/max_logits) for paged_attention_v2_reduce to combine.
template <typename T, typename CACHE_T, int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, int NUM_SIMD_LANES, int PARTITION_SIZE = 0>
[[kernel]] void paged_attention(
    device float *exp_sums
    [[buffer(0)]], // [num_seqs, num_heads, max_num_partitions] - only used
                   // when use_partitioning
    device float *max_logits
    [[buffer(1)]], // [num_seqs, num_heads, max_num_partitions] - only used
                   // when use_partitioning
    device T *out
    [[buffer(2)]], // [num_seqs, num_heads, max_num_partitions, head_size]
    device const T *q [[buffer(3)]], // [num_seqs, num_heads, head_size]
    device const CACHE_T *k_cache
    [[buffer(4)]], // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    device const CACHE_T *v_cache
    [[buffer(5)]], // [num_blocks, num_kv_heads, head_size, block_size]
    const device float *__restrict__ k_scale
    [[buffer(6)]], // [1] - only used when use_fp8_scales
    const device float *__restrict__ v_scale
    [[buffer(7)]], // [1] - only used when use_fp8_scales
    const constant int &num_kv_heads [[buffer(8)]],
    const constant float &scale [[buffer(9)]],
    const constant float &softcapping [[buffer(10)]],
    device const uint32_t *block_tables
    [[buffer(11)]], // [num_seqs, max_num_blocks_per_seq]
    device const uint32_t *context_lens [[buffer(12)]], // [num_seqs]
    const constant int &max_num_blocks_per_seq [[buffer(13)]],
    device const float *alibi_slopes
    [[buffer(14)]], // [num_heads] - only used when use_alibi
    const constant int &q_stride [[buffer(15)]],
    const constant int &kv_block_stride [[buffer(16)]],
    const constant int &kv_head_stride [[buffer(17)]],
    threadgroup char *shared_mem [[threadgroup(0)]],
    uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
    uint3 threadgroups_per_grid [[threadgroups_per_grid]],
    uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]],
    uint simd_tid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  const int seq_idx = threadgroup_position_in_grid.y;
  const int partition_idx = threadgroup_position_in_grid.z;
  const int max_num_partitions = threadgroups_per_grid.z;
  const int thread_idx = thread_position_in_threadgroup.x;
  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
  const uint32_t context_len = context_lens[seq_idx];
  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
    // No work to do. Terminate the thread block.
    return;
  }

  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
  const int num_blocks_per_partition =
      USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;

  // [start_block_idx, end_block_idx) is the range of blocks to process.
  const int start_block_idx =
      USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
  const int end_block_idx =
      MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
  const int num_blocks = end_block_idx - start_block_idx;

  // [start_token_idx, end_token_idx) is the range of tokens to process.
  const int start_token_idx = start_block_idx * BLOCK_SIZE;
  const int end_token_idx =
      MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
  const int num_tokens = end_token_idx - start_token_idx;

  constexpr int THREAD_GROUP_SIZE = MAX(NUM_SIMD_LANES / BLOCK_SIZE, 1);
  constexpr int NUM_THREAD_GROUPS =
      NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE
                                       // divides NUM_THREADS
  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
  constexpr int NUM_TOKENS_PER_THREAD_GROUP =
      DIVIDE_ROUND_UP(BLOCK_SIZE, NUM_SIMD_LANES);
  constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES;
  const int warp_idx = simd_tid;
  const int lane = simd_lid;

  const int head_idx = threadgroup_position_in_grid.x;
  const int num_heads = threadgroups_per_grid.x;
  const int num_queries_per_kv = num_heads / num_kv_heads;
  const int kv_head_idx = head_idx / num_queries_per_kv;
  const float alibi_slope = !use_alibi ? 0.f : alibi_slopes[head_idx];

  // A vector type to store a part of a key or a query.
  // The vector size is configured in such a way that the threads in a thread
  // group fetch or compute 16 bytes at a time. For example, if the size of a
  // thread group is 4 and the data type is half, then the vector size is 16 /
  // (4 * sizeof(half)) == 2.
  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(T)), 1);
  using K_vec = typename Vec<T, VEC_SIZE>::Type;
  using Q_vec = typename Vec<T, VEC_SIZE>::Type;
  using Quant_vec = typename Vec<CACHE_T, VEC_SIZE>::Type;

  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;

  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;

  // Load the query to registers.
  // Each thread in a thread group has a different part of the query.
  // For example, if the thread group size is 4, then the first thread in the
  // group has 0, 4, 8, ... th vectors of the query, and the second thread has
  // 1, 5, 9, ... th vectors of the query, and so on.
  const device T *q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
  threadgroup Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
       i += NUM_THREAD_GROUPS) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
    q_vecs[thread_group_offset][i] =
        *reinterpret_cast<const device Q_vec *>(q_ptr + vec_idx * VEC_SIZE);
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Use fp32 on softmax logits for better accuracy
  threadgroup float *logits = reinterpret_cast<threadgroup float *>(shared_mem);
  // Workspace for reduction
  threadgroup float red_smem[2 * NUM_WARPS];

  // x == THREAD_GROUP_SIZE * VEC_SIZE
  // Each thread group fetches x elements from the key at a time.
  constexpr int x = 16 / sizeof(CACHE_T);
  float qk_max = -FLT_MAX;

  // Iterate over the key blocks.
  // Each warp fetches a block of keys for each iteration.
  // Each thread group in a warp fetches a key from the block, and computes
  // dot product with the query.
  const device uint32_t *block_table =
      block_tables + seq_idx * max_num_blocks_per_seq;
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
       block_idx += NUM_WARPS) {
    // NOTE: The block number is stored in int32. However, we cast it to int64
    // because int32 can lead to overflow when this variable is multiplied by
    // large numbers (e.g., kv_block_stride).
    const int64_t physical_block_number =
        static_cast<int64_t>(block_table[block_idx]);

    // Load a key to registers.
    // Each thread in a thread group has a different part of the key.
    // For example, if the thread group size is 4, then the first thread in the
    // group has 0, 4, 8, ... th vectors of the key, and the second thread has
    // 1, 5, 9, ... th vectors of the key, and so on.
    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
      const int physical_block_offset =
          (thread_group_idx + i * NUM_SIMD_LANES) % BLOCK_SIZE;
      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
      K_vec k_vecs[NUM_VECS_PER_THREAD];

#pragma unroll
      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
        const device CACHE_T *k_ptr =
            k_cache + physical_block_number * kv_block_stride +
            kv_head_idx * kv_head_stride + physical_block_offset * x;
        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
        const int offset1 = (vec_idx * VEC_SIZE) / x;
        const int offset2 = (vec_idx * VEC_SIZE) % x;

        if constexpr (is_uchar<CACHE_T>()) {
          // FP8 support
          Quant_vec k_vec_quant = *reinterpret_cast<const device Quant_vec *>(
              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
          k_vecs[j] = fp8_convert<K_vec, Quant_vec>(k_vec_quant, *k_scale);
        } else {
          // Non-FP8 default
          k_vecs[j] = *reinterpret_cast<const device K_vec *>(
              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
        }
      }

      // Compute dot product.
      // This includes a reduction across the threads in the same thread group.
      float qk = scale * Qk_dot<T, THREAD_GROUP_SIZE>::dot(
                             q_vecs[thread_group_offset], k_vecs);

      // Apply softcapping
      if (softcapping != 1.0) {
        qk = precise::tanh(qk / softcapping) * softcapping;
      }

      // Add the ALiBi bias if slopes are given.
      if (use_alibi && alibi_slope != 0) {
        // Compute bias with explicit float precision to minimize precision
        // loss
        int position_offset = token_idx - int(context_len) + 1;
        float alibi_bias = alibi_slope * float(position_offset);
        qk += alibi_bias;
      }

      if (thread_group_offset == 0) {
        // Store the partial reductions to shared memory.
        // NOTE: It is required to zero out the masked logits.
        const bool mask = token_idx >= context_len;
        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
        // Update the max value.
        qk_max = mask ? qk_max : max(qk_max, qk);
      }
    }
  }

  // Perform reduction across the threads in the same warp to get the
  // max qk value for each "warp" (not across the thread block yet).
  // The 0-th thread of each thread group already has its max qk value.
#pragma unroll
  for (int mask = NUM_SIMD_LANES / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
    qk_max = max(qk_max, simd_shuffle_xor(qk_max, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = qk_max;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Get the max qk value for the sequence.
  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    qk_max = max(qk_max, simd_shuffle_xor(qk_max, mask));
  }
  // Broadcast the max qk value to all threads.
  qk_max = simd_shuffle(qk_max, 0);

  // Get the sum of the exp values.
  float exp_sum = 0.f;
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    float val = exp(logits[i] - qk_max);
    logits[i] = val;
    exp_sum += val;
  }
  exp_sum = block_sum<NUM_WARPS, NUM_SIMD_LANES>(&red_smem[NUM_WARPS], exp_sum,
                                                 simd_tid, simd_lid);

  // Compute softmax.
  const float inv_sum = divide(1.f, exp_sum + 1e-6f);
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    logits[i] *= inv_sum;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // If partitioning is enabled, store the max logit and exp_sum.
  if (USE_PARTITIONING && thread_idx == 0 && use_partitioning) {
    device float *max_logits_ptr =
        max_logits + seq_idx * num_heads * max_num_partitions +
        head_idx * max_num_partitions + partition_idx;
    *max_logits_ptr = qk_max;
    device float *exp_sums_ptr = exp_sums +
                                 seq_idx * num_heads * max_num_partitions +
                                 head_idx * max_num_partitions + partition_idx;
    *exp_sums_ptr = exp_sum;
  }

  // Each thread will fetch 16 bytes from the value cache at a time.
  constexpr int V_VEC_SIZE = MIN(16 / sizeof(T), BLOCK_SIZE);
  using V_vec = typename Vec<T, V_VEC_SIZE>::Type;
  using L_vec = typename Vec<T, V_VEC_SIZE>::Type;
  using Float_L_vec = typename FloatVec<L_vec>::Type;
  using V_quant_vec = typename Vec<CACHE_T, V_VEC_SIZE>::Type;

  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
  constexpr int NUM_ROWS_PER_ITER = NUM_SIMD_LANES / NUM_V_VECS_PER_ROW;
  constexpr int NUM_ROWS_PER_THREAD =
      DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);

  // NOTE: We use FP32 for the accumulator for better accuracy.
  float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    accs[i] = 0.f;
  }

  T zero_value = 0;
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
       block_idx += NUM_WARPS) {
    // NOTE: The block number is stored in int32. However, we cast it to int64
    // because int32 can lead to overflow when this variable is multiplied by
    // large numbers (e.g., kv_block_stride).
    const int64_t physical_block_number =
        static_cast<int64_t>(block_table[block_idx]);
    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
    L_vec logits_vec;
    Float_L_vec logits_float_vec = *reinterpret_cast<threadgroup Float_L_vec *>(
        logits + token_idx - start_token_idx);
    from_float(logits_vec, logits_float_vec);

    const device CACHE_T *v_ptr = v_cache +
                                  physical_block_number * kv_block_stride +
                                  kv_head_idx * kv_head_stride;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE) {
        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
        // NOTE: When v_vec contains the tokens that are out of the context,
        // we should explicitly zero out the values since they may contain
        // NaNs. See
        // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
        V_vec v_vec;

        if constexpr (is_uchar<CACHE_T>()) {
          // FP8 support
          V_quant_vec v_quant_vec =
              *reinterpret_cast<const device V_quant_vec *>(v_ptr + offset);
          v_vec = fp8_convert<V_vec, V_quant_vec>(v_quant_vec, *v_scale);
        } else {
          // Non-FP8 default
          v_vec = *reinterpret_cast<const device V_vec *>(v_ptr + offset);
        }

        if (block_idx == num_context_blocks - 1) {
          thread T *v_vec_ptr = reinterpret_cast<thread T *>(&v_vec);
#pragma unroll
          for (int j = 0; j < V_VEC_SIZE; j++) {
            v_vec_ptr[j] =
                token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
          }
        }
        accs[i] += dot(logits_vec, v_vec);
      }
    }
  }

  // Perform reduction within each warp.
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    float acc = accs[i];
#pragma unroll
    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
      acc += simd_shuffle_xor(acc, mask);
    }
    accs[i] = acc;
  }

  // NOTE: A barrier is required because the shared memory space for logits
  // is reused for the output.
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Perform reduction across warps.
  threadgroup float *out_smem =
      reinterpret_cast<threadgroup float *>(shared_mem);
#pragma unroll
  for (int i = NUM_WARPS; i > 1; i /= 2) {
    int mid = i / 2;
    // Upper warps write to shared memory.
    if (warp_idx >= mid && warp_idx < i) {
      threadgroup float *dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          dst[row_idx] = accs[i];
        }
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Lower warps update the output.
    if (warp_idx < mid) {
      const threadgroup float *src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          accs[i] += src[row_idx];
        }
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // Write the final output.
  if (warp_idx == 0) {
    device T *out_ptr =
        out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
        head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        *(out_ptr + row_idx) = T(accs[i]);
      }
    }
  }
}

// v2 reduction pass: combine the per-partition partial outputs (tmp_out) of
// paged_attention using the stored per-partition max logits and exp sums,
// producing the final softmax-weighted output for each (seq, head).
template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
          int PARTITION_SIZE>
[[kernel]] void paged_attention_v2_reduce(
    device T *out [[buffer(0)]], const device float *exp_sums [[buffer(1)]],
    const device float *max_logits [[buffer(2)]],
    const device T *tmp_out [[buffer(3)]],
    device uint32_t *context_lens [[buffer(4)]],
    const constant int &max_num_partitions [[buffer(5)]],
    threadgroup char *shared_mem [[threadgroup(0)]],
    uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
    uint3 threadgroups_per_grid [[threadgroups_per_grid]],
    uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]],
    uint3 threads_per_threadgroup [[threads_per_threadgroup]],
    uint simd_tid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  const int num_heads = threadgroups_per_grid.x;
  const int head_idx = threadgroup_position_in_grid.x;
  const int seq_idx = threadgroup_position_in_grid.y;
  const uint32_t context_len = context_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
  if (num_partitions == 1) {
    // No need to reduce. Only copy tmp_out to out.
    device T *out_ptr =
        out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
    const device T *tmp_out_ptr =
        tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
        head_idx * max_num_partitions * HEAD_SIZE;
    for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE;
         i += threads_per_threadgroup.x) {
      out_ptr[i] = tmp_out_ptr[i];
    }
    // Terminate the thread block.
    return;
  }

  constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES;
  const int warp_idx = simd_tid;
  const int lane = simd_lid;

  // Workspace for reduction.
  threadgroup float red_smem[2 * NUM_WARPS];

  // Load max logits to shared memory.
  threadgroup float *shared_max_logits =
      reinterpret_cast<threadgroup float *>(shared_mem);
  const device float *max_logits_ptr =
      max_logits + seq_idx * num_heads * max_num_partitions +
      head_idx * max_num_partitions;
  float max_logit = -FLT_MAX;
  for (int i = thread_position_in_threadgroup.x; i < num_partitions;
       i += threads_per_threadgroup.x) {
    const float l = max_logits_ptr[i];
    shared_max_logits[i] = l;
    max_logit = max(max_logit, l);
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Get the global max logit.
  // Reduce within the warp.
#pragma unroll
  for (int mask = NUM_SIMD_LANES / 2; mask >= 1; mask /= 2) {
    max_logit = max(max_logit, simd_shuffle_xor(max_logit, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = max_logit;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Reduce across warps.
  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    max_logit = max(max_logit, simd_shuffle_xor(max_logit, mask));
  }
  // Broadcast the max value to all threads.
  max_logit = simd_shuffle(max_logit, 0);

  // Load rescaled exp sums to shared memory.
  threadgroup float *shared_exp_sums = reinterpret_cast<threadgroup float *>(
      shared_mem + sizeof(float) * num_partitions);
  const device float *exp_sums_ptr = exp_sums +
                                     seq_idx * num_heads * max_num_partitions +
                                     head_idx * max_num_partitions;
  float global_exp_sum = 0.0f;
  for (int i = thread_position_in_threadgroup.x; i < num_partitions;
       i += threads_per_threadgroup.x) {
    float l = shared_max_logits[i];
    float rescaled_exp_sum = exp_sums_ptr[i] * exp(l - max_logit);
    global_exp_sum += rescaled_exp_sum;
    shared_exp_sums[i] = rescaled_exp_sum;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  global_exp_sum = block_sum<NUM_WARPS, NUM_SIMD_LANES>(
      &red_smem[NUM_WARPS], global_exp_sum, simd_tid, simd_lid);
  const float inv_global_exp_sum = divide(1.0f, global_exp_sum + 1e-6f);

  // Aggregate tmp_out to out.
  const device T *tmp_out_ptr =
      tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
      head_idx * max_num_partitions * HEAD_SIZE;
  device T *out_ptr =
      out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
  for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE;
       i += NUM_THREADS) {
    float acc = 0.0f;
    for (int j = 0; j < num_partitions; ++j) {
      acc += float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
             inv_global_exp_sum;
    }
    out_ptr[i] = T(acc);
  }
}

// Explicit instantiations: one host-visible entry point per
// (dtype, cache dtype, head size, block size, threads, simd lanes,
// partition size) combination, named so the host can look it up.
#define instantiate_paged_attention_inner(type, cache_type, head_size,         \
                                          block_size, num_threads,             \
                                          num_simd_lanes, partition_size)      \
  template [[host_name("paged_attention_" #type "_cache_" #cache_type         \
                       "_hs" #head_size "_bs" #block_size "_nt" #num_threads  \
                       "_nsl" #num_simd_lanes                                 \
                       "_ps" #partition_size)]] [[kernel]] void               \
  paged_attention<type, cache_type, head_size, block_size, num_threads,       \
                  num_simd_lanes, partition_size>(                            \
      device float *exp_sums [[buffer(0)]],                                   \
      device float *max_logits [[buffer(1)]],                                 \
      device type *out [[buffer(2)]], device const type *q [[buffer(3)]],     \
      device const cache_type *k_cache [[buffer(4)]],                         \
      device const cache_type *v_cache [[buffer(5)]],                         \
      const device float *__restrict__ k_scale [[buffer(6)]],                 \
      const device float *__restrict__ v_scale [[buffer(7)]],                 \
      const constant int &num_kv_heads [[buffer(8)]],                         \
      const constant float &scale [[buffer(9)]],                              \
      const constant float &softcapping [[buffer(10)]],                       \
      device const uint32_t *block_tables [[buffer(11)]],                     \
      device const uint32_t *context_lens [[buffer(12)]],                     \
      const constant int &max_num_blocks_per_seq [[buffer(13)]],              \
      device const float *alibi_slopes [[buffer(14)]],                        \
      const constant int &q_stride [[buffer(15)]],                            \
      const constant int &kv_block_stride [[buffer(16)]],                     \
      const constant int &kv_head_stride [[buffer(17)]],                      \
      threadgroup char *shared_mem [[threadgroup(0)]],                        \
      uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],    \
      uint3 threadgroups_per_grid [[threadgroups_per_grid]],                  \
      uint3 thread_position_in_threadgroup                                    \
      [[thread_position_in_threadgroup]],                                     \
      uint simd_tid [[simdgroup_index_in_threadgroup]],                       \
      uint simd_lid [[thread_index_in_simdgroup]]);

#define instantiate_paged_attention_v2_reduce_inner(                           \
    type, head_size, num_threads, num_simd_lanes, partition_size)              \
  template [[host_name("paged_attention_v2_reduce_" #type "_hs" #head_size    \
                       "_nt" #num_threads "_nsl" #num_simd_lanes              \
                       "_ps" #partition_size)]] [[kernel]] void               \
  paged_attention_v2_reduce<type, head_size, num_threads, num_simd_lanes,     \
                            partition_size>(                                  \
      device type * out [[buffer(0)]],                                        \
      const device float *exp_sums [[buffer(1)]],                             \
      const device float *max_logits [[buffer(2)]],                           \
      const device type *tmp_out [[buffer(3)]],                               \
      device uint32_t *context_lens [[buffer(4)]],                            \
      const constant int &max_num_partitions [[buffer(5)]],                   \
      threadgroup char *shared_mem [[threadgroup(0)]],                        \
      uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],    \
      uint3 threadgroups_per_grid [[threadgroups_per_grid]],                  \
      uint3 thread_position_in_threadgroup                                    \
      [[thread_position_in_threadgroup]],                                     \
      uint3 threads_per_threadgroup [[threads_per_threadgroup]],              \
      uint simd_tid [[simdgroup_index_in_threadgroup]],                       \
      uint simd_lid [[thread_index_in_simdgroup]]);

#define instantiate_paged_attention_heads(                                     \
    type, cache_type, block_size, num_threads, num_simd_lanes, partition_size) \
  instantiate_paged_attention_inner(type, cache_type, 32, block_size,          \
                                    num_threads, num_simd_lanes,               \
                                    partition_size);                           \
  instantiate_paged_attention_inner(type, cache_type, 64, block_size,          \
                                    num_threads, num_simd_lanes,               \
                                    partition_size);                           \
  instantiate_paged_attention_inner(type, cache_type, 80, block_size,          \
                                    num_threads, num_simd_lanes,               \
                                    partition_size);                           \
  instantiate_paged_attention_inner(type, cache_type, 96, block_size,          \
                                    num_threads, num_simd_lanes,               \
                                    partition_size);                           \
  instantiate_paged_attention_inner(type, cache_type, 112, block_size,         \
                                    num_threads, num_simd_lanes,               \
                                    partition_size);                           \
  instantiate_paged_attention_inner(type, cache_type, 120, block_size,         \
                                    num_threads, num_simd_lanes,               \
                                    partition_size);                           \
  instantiate_paged_attention_inner(type, cache_type, 128, block_size,         \
                                    num_threads, num_simd_lanes,               \
                                    partition_size);                           \
  instantiate_paged_attention_inner(type, cache_type, 192, block_size,         \
                                    num_threads, num_simd_lanes,               \
                                    partition_size);                           \
  instantiate_paged_attention_inner(type, cache_type, 256, block_size,         \
                                    num_threads, num_simd_lanes,               \
                                    partition_size);

#define instantiate_paged_attention_v2_reduce_heads(                           \
    type, num_threads, num_simd_lanes, partition_size)                         \
  instantiate_paged_attention_v2_reduce_inner(type, 32, num_threads,           \
                                              num_simd_lanes, partition_size); \
  instantiate_paged_attention_v2_reduce_inner(type, 64, num_threads,           \
                                              num_simd_lanes, partition_size); \
  instantiate_paged_attention_v2_reduce_inner(type, 80, num_threads,           \
                                              num_simd_lanes, partition_size); \
  instantiate_paged_attention_v2_reduce_inner(type, 96, num_threads,           \
                                              num_simd_lanes, partition_size); \
  instantiate_paged_attention_v2_reduce_inner(type, 112, num_threads,          \
                                              num_simd_lanes, partition_size); \
  instantiate_paged_attention_v2_reduce_inner(type, 120, num_threads,          \
                                              num_simd_lanes, partition_size); \
  instantiate_paged_attention_v2_reduce_inner(type, 128, num_threads,          \
                                              num_simd_lanes, partition_size); \
  instantiate_paged_attention_v2_reduce_inner(type, 192, num_threads,          \
                                              num_simd_lanes, partition_size); \
  instantiate_paged_attention_v2_reduce_inner(type, 256, num_threads,          \
                                              num_simd_lanes, partition_size);

#define instantiate_paged_attention_block_size(type, cache_type, num_threads,  \
                                               num_simd_lanes, partition_size) \
  instantiate_paged_attention_heads(type, cache_type, 8, num_threads,          \
                                    num_simd_lanes, partition_size);           \
  instantiate_paged_attention_heads(type, cache_type, 16, num_threads,         \
                                    num_simd_lanes, partition_size);           \
  instantiate_paged_attention_heads(type, cache_type, 32, num_threads,         \
                                    num_simd_lanes, partition_size);

// TODO: tune num_threads = 256
// NOTE: partition_size = 0
#define instantiate_paged_attention_v1(type, cache_type, num_simd_lanes)       \
  instantiate_paged_attention_block_size(type, cache_type, 256,                \
                                         num_simd_lanes, 0);

// TODO: tune num_threads = 256
// NOTE: partition_size = 512
#define instantiate_paged_attention_v2(type, cache_type, num_simd_lanes)       \
  instantiate_paged_attention_block_size(type, cache_type, 256,                \
                                         num_simd_lanes, 512);

// TODO: tune num_threads = 256
// NOTE: partition_size = 512
#define instantiate_paged_attention_v2_reduce(type, num_simd_lanes)            \
  instantiate_paged_attention_v2_reduce_heads(type, 256, num_simd_lanes, 512);

instantiate_paged_attention_v1(float, float, 32);
instantiate_paged_attention_v1(bfloat16_t, bfloat16_t, 32);
instantiate_paged_attention_v1(half, half, 32);

instantiate_paged_attention_v1(float, uchar, 32);
instantiate_paged_attention_v1(bfloat16_t, uchar, 32);
instantiate_paged_attention_v1(half, uchar, 32);

instantiate_paged_attention_v2_reduce(float, 32);
instantiate_paged_attention_v2_reduce(bfloat16_t, 32);
+instantiate_paged_attention_v2_reduce(half, 32); + +instantiate_paged_attention_v2(float, float, 32); +instantiate_paged_attention_v2(bfloat16_t, bfloat16_t, 32); +instantiate_paged_attention_v2(half, half, 32); + +instantiate_paged_attention_v2(float, uchar, 32); +instantiate_paged_attention_v2(bfloat16_t, uchar, 32); +instantiate_paged_attention_v2(half, uchar, 32); diff --git a/vllm_metal/metal/kernels/cache.mm b/vllm_metal/metal/kernels/cache.mm new file mode 100644 index 0000000..cf67260 --- /dev/null +++ b/vllm_metal/metal/kernels/cache.mm @@ -0,0 +1,562 @@ +#include +#include +#include + +#import +#import +#include +#include +#include + +static inline id getMTLBufferStorage(const torch::Tensor &tensor) { + return __builtin_bit_cast(id, tensor.storage().data()); +} + +static std::string getModuleDirectory() { + Dl_info dl_info; + if (dladdr((void *)getModuleDirectory, &dl_info)) { + std::string path(dl_info.dli_fname); + size_t pos = path.find_last_of('/'); + if (pos != std::string::npos) { + return path.substr(0, pos); + } + } + return "."; +} + +void swap_blocks(torch::Tensor &src, torch::Tensor &dst, + const torch::Tensor &block_mapping) { + TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); + + const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); + const int64_t num_blocks = block_mapping.size(0); + + // Handle different device combinations + if (src.device().is_mps() && dst.device().is_mps()) { + // MPS to MPS: Use Metal blit encoder + @autoreleasepool { + at::mps::MPSStream *stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream, "Failed to get current MPS stream"); + + id commandBuffer = stream->commandBuffer(); + TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer"); + + dispatch_queue_t serialQueue = stream->queue(); + + dispatch_sync(serialQueue, ^{ + id blitEncoder = + [commandBuffer blitCommandEncoder]; + TORCH_CHECK(blitEncoder, "Failed to create blit command encoder"); + + id srcBuf 
= getMTLBufferStorage(src); + id dstBuf = getMTLBufferStorage(dst); + + for (int64_t i = 0; i < num_blocks; ++i) { + int64_t src_block_number = block_mapping[i][0].item(); + int64_t dst_block_number = block_mapping[i][1].item(); + NSUInteger src_offset = src_block_number * block_size_in_bytes; + NSUInteger dst_offset = dst_block_number * block_size_in_bytes; + + [blitEncoder copyFromBuffer:srcBuf + sourceOffset:src_offset + toBuffer:dstBuf + destinationOffset:dst_offset + size:block_size_in_bytes]; + } + + [blitEncoder endEncoding]; + stream->synchronize(at::mps::SyncType::COMMIT); + }); + } + } else { + // Cross-device transfers (MPS-CPU, CPU-MPS, CPU-CPU): Use PyTorch's copy + for (int64_t i = 0; i < num_blocks; ++i) { + int64_t src_block_number = block_mapping[i][0].item(); + int64_t dst_block_number = block_mapping[i][1].item(); + + // Copy the entire block + dst[dst_block_number].copy_(src[src_block_number]); + } + } +} + +void copy_blocks(const std::vector &key_caches, + const std::vector &value_caches, + const torch::Tensor &block_mapping) { + const int64_t num_layers = key_caches.size(); + TORCH_CHECK(num_layers == static_cast(value_caches.size()), + "key_caches and value_caches must have the same length"); + if (num_layers == 0) { + return; + } + + // --- Preconditions -------------------------------------------------- + torch::Device dev = key_caches[0].device(); + TORCH_CHECK(dev.is_mps(), "copy_blocks: expected MPS tensors"); + + // Move block_mapping to CPU if it's on MPS + torch::Tensor block_mapping_cpu = block_mapping; + if (block_mapping.device().is_mps()) { + block_mapping_cpu = block_mapping.cpu(); + } + + for (int64_t i = 0; i < num_layers; ++i) { + TORCH_CHECK(key_caches[i].device() == dev && + value_caches[i].device() == dev, + "All cache tensors must be on the same MPS device"); + TORCH_CHECK(key_caches[i].dtype() == value_caches[i].dtype(), + "Key/value cache dtype mismatch at layer ", i); + } + + const int64_t num_pairs = 
block_mapping.size(0); + const int32_t numel_per_block = + static_cast(key_caches[0][0].numel()); + + @autoreleasepool { + at::mps::MPSStream *stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream, "Failed to get current MPS stream"); + + id device = stream->device(); + id cmdBuf = stream->commandBuffer(); + TORCH_CHECK(cmdBuf, "Failed to get command buffer"); + + // Construct the full path to the metallib file + std::string moduleDir = getModuleDirectory(); + std::string metallibPath = moduleDir + "/" + METALLIB_PATH; + + NSString *metallibPathStr = + [NSString stringWithUTF8String:metallibPath.c_str()]; + NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr]; + NSError *error = nil; + id lib = [device newLibraryWithURL:metallibURL error:&error]; + if (!lib) { + NSLog(@"[cache.mm] Failed to load pre-compiled Metal library at %@: %@", + metallibPathStr, error.localizedDescription); + } + + // Process each layer separately + for (int64_t layer_idx = 0; layer_idx < num_layers; ++layer_idx) { + NSString *kernName = nil; + switch (key_caches[layer_idx].scalar_type()) { + case torch::kFloat: + kernName = @"copy_blocks_float"; + break; + case torch::kHalf: + kernName = @"copy_blocks_half"; + break; + case torch::kBFloat16: + kernName = @"copy_blocks_bfloat16_t"; + break; + case torch::kUInt8: + kernName = @"copy_blocks_uchar"; + break; + default: + TORCH_CHECK(false, "Unsupported dtype for copy_blocks"); + } + + id fn = [lib newFunctionWithName:kernName]; + TORCH_CHECK(fn, "Missing Metal kernel function: ", kernName.UTF8String); + + id pso = + [device newComputePipelineStateWithFunction:fn error:&error]; + TORCH_CHECK(pso, error.localizedDescription.UTF8String); + + dispatch_queue_t q = stream->queue(); + dispatch_sync(q, ^{ + id enc = [cmdBuf computeCommandEncoder]; + TORCH_CHECK(enc, "Failed to create compute encoder"); + + [enc setComputePipelineState:pso]; + + // Set key and value cache buffers + [enc 
setBuffer:getMTLBufferStorage(key_caches[layer_idx]) + offset:key_caches[layer_idx].storage_offset() * + key_caches[layer_idx].element_size() + atIndex:0]; + [enc setBuffer:getMTLBufferStorage(value_caches[layer_idx]) + offset:value_caches[layer_idx].storage_offset() * + value_caches[layer_idx].element_size() + atIndex:1]; + + // Set block mapping buffer + id mappingBuf = + [device newBufferWithBytes:block_mapping_cpu.data_ptr() + length:num_pairs * 2 * sizeof(int64_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:mappingBuf offset:0 atIndex:2]; + + // Set numel_per_block as buffer + id numelBuf = + [device newBufferWithBytes:&numel_per_block + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:numelBuf offset:0 atIndex:3]; + + const uint32_t threadsPerThreadgroup = + std::min(256, numel_per_block); + MTLSize tg = MTLSizeMake(threadsPerThreadgroup, 1, 1); + MTLSize grid = MTLSizeMake(threadsPerThreadgroup * num_pairs, 1, 1); + + [enc dispatchThreads:grid threadsPerThreadgroup:tg]; + [enc endEncoding]; + }); + } + + stream->synchronize(at::mps::SyncType::COMMIT); + } +} + +void reshape_and_cache( + torch::Tensor &key, // [num_tokens, num_heads, head_size] + torch::Tensor &value, // [num_tokens, num_heads, head_size] + torch::Tensor + &key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor + &value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor &slot_mapping, // [num_tokens] + const std::string &kv_cache_dtype, torch::Tensor &k_scale, + torch::Tensor &v_scale) { + + // Determine cache dtype and FP8 usage + torch::ScalarType cache_dtype = key_cache.scalar_type(); + bool use_fp8_scales = (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3"); + if (use_fp8_scales) { + TORCH_CHECK(cache_dtype == torch::kUInt8, "FP8 cache requires UInt8 tensor type"); + TORCH_CHECK(k_scale.numel() == 1 && v_scale.numel() == 1, "FP8 scales must be scalars"); + 
TORCH_CHECK(k_scale.scalar_type() == torch::kFloat32 && v_scale.scalar_type() == torch::kFloat32, + "FP8 scales must be float32"); + } + + TORCH_CHECK(key.device().is_mps() && value.device().is_mps() && + key_cache.device().is_mps() && value_cache.device().is_mps(), + "All tensors must be on MPS device"); + + // Move slot_mapping to CPU if it's on MPS + torch::Tensor slot_mapping_cpu = slot_mapping; + if (slot_mapping.device().is_mps()) { + slot_mapping_cpu = slot_mapping.cpu(); + } + + const int64_t num_tokens = key.size(0); + const int64_t num_heads = key.size(1); + const int64_t head_size = key.size(2); + const int64_t block_size = key_cache.size(3); + const int64_t x = key_cache.size(4); + + const int32_t key_stride = key.stride(0); + const int32_t value_stride = value.stride(0); + + @autoreleasepool { + at::mps::MPSStream *stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream, "Failed to get current MPS stream"); + + id device = stream->device(); + id cmdBuf = stream->commandBuffer(); + TORCH_CHECK(cmdBuf, "Failed to get command buffer"); + + // Construct the full path to the metallib file + std::string moduleDir = getModuleDirectory(); + std::string metallibPath = moduleDir + "/" + METALLIB_PATH; + + NSString *metallibPathStr = + [NSString stringWithUTF8String:metallibPath.c_str()]; + NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr]; + NSError *error = nil; + id lib = [device newLibraryWithURL:metallibURL error:&error]; + if (!lib) { + NSLog(@"[cache.mm] Failed to load pre-compiled Metal library at %@: %@", + metallibPathStr, error.localizedDescription); + } + + NSString *kernName = nil; + std::string kv_dtype_str, cache_dtype_str; + + // Get KV dtype string + switch (key.scalar_type()) { + case torch::kFloat: + kv_dtype_str = "float"; + break; + case torch::kHalf: + kv_dtype_str = "half"; + break; + case torch::kBFloat16: + kv_dtype_str = "bfloat16_t"; + break; + default: + TORCH_CHECK(false, "Unsupported dtype for reshape_and_cache"); + 
} + + // Get cache dtype string + switch (cache_dtype) { + case torch::kFloat: + cache_dtype_str = "float"; + break; + case torch::kHalf: + cache_dtype_str = "half"; + break; + case torch::kBFloat16: + cache_dtype_str = "bfloat16_t"; + break; + case torch::kUInt8: + cache_dtype_str = "uchar"; + break; + default: + TORCH_CHECK(false, "Unsupported cache dtype for reshape_and_cache"); + } + + std::string kernName_str = "reshape_and_cache_kv_" + kv_dtype_str + "_cache_" + cache_dtype_str; + kernName = [NSString stringWithUTF8String:kernName_str.c_str()]; + + // Create function constants for FP8 support + MTLFunctionConstantValues *constants = [[MTLFunctionConstantValues alloc] init]; + [constants setConstantValue:&use_fp8_scales type:MTLDataTypeBool atIndex:10]; + + id fn = [lib newFunctionWithName:kernName constantValues:constants error:&error]; + TORCH_CHECK(fn, "Missing Metal kernel function: ", kernName.UTF8String, + error ? [NSString stringWithFormat:@": %@", error.localizedDescription].UTF8String : ""); + + id pso = + [device newComputePipelineStateWithFunction:fn error:&error]; + TORCH_CHECK(pso, error.localizedDescription.UTF8String); + + dispatch_queue_t q = stream->queue(); + dispatch_sync(q, ^{ + id enc = [cmdBuf computeCommandEncoder]; + TORCH_CHECK(enc, "Failed to create compute encoder"); + + [enc setComputePipelineState:pso]; + + // Set tensor buffers + [enc setBuffer:getMTLBufferStorage(key) + offset:key.storage_offset() * key.element_size() + atIndex:0]; + [enc setBuffer:getMTLBufferStorage(value) + offset:value.storage_offset() * value.element_size() + atIndex:1]; + [enc setBuffer:getMTLBufferStorage(key_cache) + offset:key_cache.storage_offset() * key_cache.element_size() + atIndex:2]; + [enc setBuffer:getMTLBufferStorage(value_cache) + offset:value_cache.storage_offset() * value_cache.element_size() + atIndex:3]; + + // Set slot mapping buffer + id slotMappingBuf = + [device newBufferWithBytes:slot_mapping_cpu.data_ptr() + length:num_tokens * 
sizeof(int64_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:slotMappingBuf offset:0 atIndex:4]; + + // k_scale and v_scale buffers (for FP8) + if (use_fp8_scales) { + [enc setBuffer:getMTLBufferStorage(k_scale) + offset:k_scale.storage_offset() * k_scale.element_size() + atIndex:5]; + [enc setBuffer:getMTLBufferStorage(v_scale) + offset:v_scale.storage_offset() * v_scale.element_size() + atIndex:6]; + } else { + // For non-FP8, we still need to increment buffer indices + // The Metal kernel expects buffers at indices 5 and 6 even if unused + } + + // Set parameters as individual buffers (matching mistralrs pattern) + id keyStrideBuf = + [device newBufferWithBytes:&key_stride + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:keyStrideBuf offset:0 atIndex:7]; + + id valueStrideBuf = + [device newBufferWithBytes:&value_stride + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:valueStrideBuf offset:0 atIndex:8]; + + const int32_t num_heads_i32 = static_cast(num_heads); + id numHeadsBuf = + [device newBufferWithBytes:&num_heads_i32 + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:numHeadsBuf offset:0 atIndex:9]; + + const int32_t head_size_i32 = static_cast(head_size); + id headSizeBuf = + [device newBufferWithBytes:&head_size_i32 + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:headSizeBuf offset:0 atIndex:10]; + + const int32_t block_size_i32 = static_cast(block_size); + id blockSizeBuf = + [device newBufferWithBytes:&block_size_i32 + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:blockSizeBuf offset:0 atIndex:11]; + + const int32_t x_i32 = static_cast(x); + id xBuf = + [device newBufferWithBytes:&x_i32 + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:xBuf offset:0 atIndex:12]; + + const uint64_t threads_per_threadgroup = + std::min(512, num_heads * 
head_size); + MTLSize tg = MTLSizeMake(threads_per_threadgroup, 1, 1); + MTLSize grid = MTLSizeMake(num_tokens, 1, 1); + + [enc dispatchThreadgroups:grid threadsPerThreadgroup:tg]; + [enc endEncoding]; + }); + + stream->synchronize(at::mps::SyncType::COMMIT); + } +} + +void reshape_and_cache_flash( + torch::Tensor &key, // [num_tokens, num_heads, head_size] + torch::Tensor &value, // [num_tokens, num_heads, head_size] + torch::Tensor &key_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor + &value_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor &slot_mapping, // [num_tokens] + const std::string &kv_cache_dtype, torch::Tensor &k_scale, + torch::Tensor &v_scale) { + + TORCH_CHECK(key.device().is_mps() && value.device().is_mps() && + key_cache.device().is_mps() && value_cache.device().is_mps(), + "All tensors must be on MPS device"); + + // Move slot_mapping to CPU if it's on MPS + torch::Tensor slot_mapping_cpu = slot_mapping; + if (slot_mapping.device().is_mps()) { + slot_mapping_cpu = slot_mapping.cpu(); + } + + const int64_t num_tokens = key.size(0); + const int64_t num_heads = key.size(1); + const int64_t head_size = key.size(2); + const int64_t block_size = key_cache.size(1); + + const int32_t key_stride = key.stride(0); + const int32_t value_stride = value.stride(0); + + @autoreleasepool { + at::mps::MPSStream *stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream, "Failed to get current MPS stream"); + + id device = stream->device(); + id cmdBuf = stream->commandBuffer(); + TORCH_CHECK(cmdBuf, "Failed to get command buffer"); + + // Construct the full path to the metallib file + std::string moduleDir = getModuleDirectory(); + std::string metallibPath = moduleDir + "/" + METALLIB_PATH; + + NSString *metallibPathStr = + [NSString stringWithUTF8String:metallibPath.c_str()]; + NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr]; + NSError *error = nil; + id lib = [device newLibraryWithURL:metallibURL 
error:&error]; + if (!lib) { + NSLog(@"[cache.mm] Failed to load pre-compiled Metal library at %@: %@", + metallibPathStr, error.localizedDescription); + } + + NSString *kernName = nil; + switch (key.scalar_type()) { + case torch::kFloat: + kernName = @"reshape_and_cache_flash_float"; + break; + case torch::kHalf: + kernName = @"reshape_and_cache_flash_half"; + break; + case torch::kBFloat16: + kernName = @"reshape_and_cache_flash_bfloat16_t"; + break; + default: + TORCH_CHECK(false, "Unsupported dtype for reshape_and_cache_flash"); + } + + id fn = [lib newFunctionWithName:kernName]; + TORCH_CHECK(fn, "Missing Metal kernel function: ", kernName.UTF8String); + + id pso = + [device newComputePipelineStateWithFunction:fn error:&error]; + TORCH_CHECK(pso, error.localizedDescription.UTF8String); + + dispatch_queue_t q = stream->queue(); + dispatch_sync(q, ^{ + id enc = [cmdBuf computeCommandEncoder]; + TORCH_CHECK(enc, "Failed to create compute encoder"); + + [enc setComputePipelineState:pso]; + + // Set tensor buffers + [enc setBuffer:getMTLBufferStorage(key) + offset:key.storage_offset() * key.element_size() + atIndex:0]; + [enc setBuffer:getMTLBufferStorage(value) + offset:value.storage_offset() * value.element_size() + atIndex:1]; + [enc setBuffer:getMTLBufferStorage(key_cache) + offset:key_cache.storage_offset() * key_cache.element_size() + atIndex:2]; + [enc setBuffer:getMTLBufferStorage(value_cache) + offset:value_cache.storage_offset() * value_cache.element_size() + atIndex:3]; + + // Set slot mapping buffer + id slotMappingBuf = + [device newBufferWithBytes:slot_mapping_cpu.data_ptr() + length:num_tokens * sizeof(int64_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:slotMappingBuf offset:0 atIndex:4]; + + // Set parameters as individual buffers + id keyStrideBuf = + [device newBufferWithBytes:&key_stride + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:keyStrideBuf offset:0 atIndex:5]; + + id valueStrideBuf = + 
[device newBufferWithBytes:&value_stride + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:valueStrideBuf offset:0 atIndex:6]; + + const int32_t num_heads_i32 = static_cast(num_heads); + id numHeadsBuf = + [device newBufferWithBytes:&num_heads_i32 + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:numHeadsBuf offset:0 atIndex:7]; + + const int32_t head_size_i32 = static_cast(head_size); + id headSizeBuf = + [device newBufferWithBytes:&head_size_i32 + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:headSizeBuf offset:0 atIndex:8]; + + const int32_t block_size_i32 = static_cast(block_size); + id blockSizeBuf = + [device newBufferWithBytes:&block_size_i32 + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:blockSizeBuf offset:0 atIndex:9]; + + const uint64_t threads_per_threadgroup = + std::min(512, num_heads * head_size); + MTLSize tg = MTLSizeMake(threads_per_threadgroup, 1, 1); + MTLSize grid = MTLSizeMake(num_tokens, 1, 1); + + [enc dispatchThreadgroups:grid threadsPerThreadgroup:tg]; + [enc endEncoding]; + }); + + stream->synchronize(at::mps::SyncType::COMMIT); + } +} \ No newline at end of file diff --git a/vllm_metal/metal/kernels/cache/copy_blocks.metal b/vllm_metal/metal/kernels/cache/copy_blocks.metal new file mode 100644 index 0000000..31595cf --- /dev/null +++ b/vllm_metal/metal/kernels/cache/copy_blocks.metal @@ -0,0 +1,51 @@ +#include "../utils.metal" +#include + +using namespace metal; + +template +[[kernel]] void copy_blocks(device T *key_cache [[buffer(0)]], + device T *value_cache [[buffer(1)]], + const device int64_t *block_mapping [[buffer(2)]], + device const int &numel_per_block, + uint tgid [[threadgroup_position_in_grid]], + uint tid [[thread_position_in_threadgroup]], + uint threads_per_threadgroup + [[threads_per_threadgroup]]) { + const int pair_idx = tgid; + + int64_t src_block_number = block_mapping[2 * 
pair_idx]; + int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; + + const int64_t src_block_offset = src_block_number * numel_per_block; + const int64_t dst_block_offset = dst_block_number * numel_per_block; + + // Copy key cache blocks + for (int i = tid; i < numel_per_block; i += threads_per_threadgroup) { + int64_t src_offset = src_block_offset + i; + int64_t dst_offset = dst_block_offset + i; + key_cache[dst_offset] = key_cache[src_offset]; + } + + // Copy value cache blocks + for (int i = tid; i < numel_per_block; i += threads_per_threadgroup) { + int64_t src_offset = src_block_offset + i; + int64_t dst_offset = dst_block_offset + i; + value_cache[dst_offset] = value_cache[src_offset]; + } +} + +#define instantiate_copy_blocks(type) \ + template [[host_name("copy_blocks_" #type)]] [[kernel]] void \ + copy_blocks(device type * key_cache [[buffer(0)]], \ + device type * value_cache [[buffer(1)]], \ + const device int64_t *block_mapping [[buffer(2)]], \ + device const int &numel_per_block, \ + uint tgid [[threadgroup_position_in_grid]], \ + uint tid [[thread_position_in_threadgroup]], \ + uint threads_per_threadgroup [[threads_per_threadgroup]]); + +instantiate_copy_blocks(float); +instantiate_copy_blocks(bfloat16_t); +instantiate_copy_blocks(half); +instantiate_copy_blocks(uchar); diff --git a/vllm_metal/metal/kernels/cache/reshape_and_cache.metal b/vllm_metal/metal/kernels/cache/reshape_and_cache.metal new file mode 100644 index 0000000..28ff7db --- /dev/null +++ b/vllm_metal/metal/kernels/cache/reshape_and_cache.metal @@ -0,0 +1,193 @@ +#include "../utils.metal" +#include "../float8.metal" +#include + +using namespace metal; + +template +inline CACHE_T to_cache(KV_T v) = delete; + +template <> inline uchar to_cache(float v) { + return float_to_fp8_e4m3(v); +} + +template <> inline uchar to_cache(bfloat16_t v) { + return float_to_fp8_e4m3((float)v); +} + +template <> inline uchar to_cache(half v) { + return float_to_fp8_e4m3((float)v); +} + +template 
<> inline float to_cache(float v) { return v; } + +template <> inline bfloat16_t to_cache(bfloat16_t v) { + return v; +} + +template <> inline half to_cache(half v) { return v; } + +constant bool use_fp8_scales [[function_constant(10)]]; + +template +[[kernel]] void reshape_and_cache( + const device KV_T *__restrict__ key + [[buffer(0)]], // [num_tokens, num_heads, head_size] + const device KV_T *__restrict__ value + [[buffer(1)]], // [num_tokens, num_heads, head_size] + device CACHE_T *__restrict__ key_cache + [[buffer(2)]], // [num_blocks, num_heads, head_size/x, block_size, x] + device CACHE_T *__restrict__ value_cache + [[buffer(3)]], // [num_blocks, num_heads, head_size, block_size] + const device int64_t *__restrict__ slot_mapping + [[buffer(4)]], // [num_tokens] + const device float *__restrict__ k_scale + [[buffer(5)]], // [1] - only used when use_fp8_scales + const device float *__restrict__ v_scale + [[buffer(6)]], // [1] - only used when use_fp8_scales + device const int &key_stride [[buffer(7)]], + device const int &value_stride [[buffer(8)]], + device const int &num_heads [[buffer(9)]], + device const int &head_size [[buffer(10)]], + device const int &block_size [[buffer(11)]], + device const int &x [[buffer(12)]], + uint gid [[threadgroup_position_in_grid]], + uint tid [[thread_position_in_threadgroup]], + uint threads_per_threadgroup [[threads_per_threadgroup]]) { + const int64_t token_idx = gid; + const int64_t slot_idx = slot_mapping[token_idx]; + if (slot_idx < 0) { + // Padding token that should be ignored. 
+ return; + } + + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + + const int n = num_heads * head_size; + for (int i = tid; i < n; i += threads_per_threadgroup) { + const int64_t src_key_idx = token_idx * key_stride + i; + const int64_t src_value_idx = token_idx * value_stride + i; + + const int head_idx = i / head_size; + const int head_offset = i % head_size; + const int x_idx = head_offset / x; + const int x_offset = head_offset % x; + + const int64_t tgt_key_idx = + block_idx * num_heads * (head_size / x) * block_size * x + + head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + + block_offset * x + x_offset; + const int64_t tgt_value_idx = + block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + head_offset * block_size + + block_offset; + + if (use_fp8_scales) { + key_cache[tgt_key_idx] = + to_cache(KV_T((float)key[src_key_idx] / *k_scale)); + value_cache[tgt_value_idx] = + to_cache(KV_T((float)value[src_value_idx] / *v_scale)); + } else { + key_cache[tgt_key_idx] = to_cache(key[src_key_idx]); + value_cache[tgt_value_idx] = to_cache(value[src_value_idx]); + } + } +} + +#define instantiate_reshape_and_cache(kv_type, cache_type) \ + template [[host_name("reshape_and_cache_kv_" #kv_type \ + "_cache_" #cache_type)]] [[kernel]] void \ + reshape_and_cache( \ + const device kv_type *__restrict__ key [[buffer(0)]], \ + const device kv_type *__restrict__ value [[buffer(1)]], \ + device cache_type *__restrict__ key_cache [[buffer(2)]], \ + device cache_type *__restrict__ value_cache [[buffer(3)]], \ + const device int64_t *__restrict__ slot_mapping [[buffer(4)]], \ + const device float *__restrict__ k_scale [[buffer(5)]], \ + const device float *__restrict__ v_scale [[buffer(6)]], \ + device const int &key_stride [[buffer(7)]], \ + device const int &value_stride [[buffer(8)]], \ + device const int &num_heads [[buffer(9)]], \ + device const int &head_size 
[[buffer(10)]], \ + device const int &block_size [[buffer(11)]], \ + device const int &x [[buffer(12)]], \ + uint gid [[threadgroup_position_in_grid]], \ + uint tid [[thread_position_in_threadgroup]], \ + uint threads_per_threadgroup [[threads_per_threadgroup]]); + +instantiate_reshape_and_cache(float, float); +instantiate_reshape_and_cache(bfloat16_t, bfloat16_t); +instantiate_reshape_and_cache(half, half); + +instantiate_reshape_and_cache(float, uchar); +instantiate_reshape_and_cache(bfloat16_t, uchar); +instantiate_reshape_and_cache(half, uchar); + +// Flash version with different cache layout: [num_blocks, block_size, +// num_heads, head_size] +template +[[kernel]] void reshape_and_cache_flash( + const device T *__restrict__ key + [[buffer(0)]], // [num_tokens, num_heads, head_size] + const device T *__restrict__ value + [[buffer(1)]], // [num_tokens, num_heads, head_size] + device T *__restrict__ key_cache + [[buffer(2)]], // [num_blocks, block_size, num_heads, head_size] + device T *__restrict__ value_cache + [[buffer(3)]], // [num_blocks, block_size, num_heads, head_size] + const device int64_t *__restrict__ slot_mapping + [[buffer(4)]], // [num_tokens] + device const int &key_stride, device const int &value_stride, + device const int &num_heads, device const int &head_size, + device const int &block_size, uint gid [[threadgroup_position_in_grid]], + uint tid [[thread_position_in_threadgroup]], + uint threads_per_threadgroup [[threads_per_threadgroup]]) { + const int64_t token_idx = gid; + const int64_t slot_idx = slot_mapping[token_idx]; + if (slot_idx < 0) { + // Padding token that should be ignored. 
+ return; + } + + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + + const int n = num_heads * head_size; + for (int i = tid; i < n; i += threads_per_threadgroup) { + const int64_t src_key_idx = token_idx * key_stride + i; + const int64_t src_value_idx = token_idx * value_stride + i; + + const int head_idx = i / head_size; + const int head_offset = i % head_size; + + // Flash cache layout: [num_blocks, block_size, num_heads, head_size] + const int64_t tgt_key_idx = block_idx * block_size * num_heads * head_size + + block_offset * num_heads * head_size + + head_idx * head_size + head_offset; + const int64_t tgt_value_idx = + block_idx * block_size * num_heads * head_size + + block_offset * num_heads * head_size + head_idx * head_size + + head_offset; + key_cache[tgt_key_idx] = key[src_key_idx]; + value_cache[tgt_value_idx] = value[src_value_idx]; + } +} + +#define instantiate_reshape_and_cache_flash(type) \ + template [[host_name("reshape_and_cache_flash_" #type)]] [[kernel]] void \ + reshape_and_cache_flash( \ + const device type *__restrict__ key [[buffer(0)]], \ + const device type *__restrict__ value [[buffer(1)]], \ + device type *__restrict__ key_cache [[buffer(2)]], \ + device type *__restrict__ value_cache [[buffer(3)]], \ + const device int64_t *__restrict__ slot_mapping [[buffer(4)]], \ + device const int &key_stride, device const int &value_stride, \ + device const int &num_heads, device const int &head_size, \ + device const int &block_size, uint gid [[threadgroup_position_in_grid]], \ + uint tid [[thread_position_in_threadgroup]], \ + uint threads_per_threadgroup [[threads_per_threadgroup]]); + +instantiate_reshape_and_cache_flash(float); +instantiate_reshape_and_cache_flash(bfloat16_t); +instantiate_reshape_and_cache_flash(half); diff --git a/vllm_metal/metal/kernels/convert_fp8.metal b/vllm_metal/metal/kernels/convert_fp8.metal new file mode 100644 index 0000000..22028ce --- /dev/null +++ 
b/vllm_metal/metal/kernels/convert_fp8.metal @@ -0,0 +1,77 @@ +#include "float8.metal" +#include "utils.metal" +#include + +using namespace metal; + +// Convert between different precision formats for cache tensors +// This kernel handles conversions like float->fp8, fp8->float, etc. + +template +[[kernel]] void convert_fp8_kernel( + const device SRC_T *__restrict__ src [[buffer(0)]], + device DST_T *__restrict__ dst [[buffer(1)]], + const device float &scale [[buffer(2)]], + const device uint32_t &num_elements [[buffer(3)]], + uint gid [[thread_position_in_grid]]) { + + if (gid >= num_elements) { + return; + } + + // Load source value + SRC_T src_val = src[gid]; + + // Convert based on source and destination types + if constexpr (is_same_v && !is_same_v) { + // FP8 -> higher precision (dequantization) + float fp32_val = fp8_e4m3_to_float(src_val) * scale; + dst[gid] = static_cast(fp32_val); + } else if constexpr (!is_same_v && is_same_v) { + // Higher precision -> FP8 (quantization) + float fp32_val = static_cast(src_val) / scale; + dst[gid] = float_to_fp8_e4m3(fp32_val); + } else if constexpr (is_same_v && is_same_v) { + // FP8 -> FP8 (with rescaling) + float fp32_val = fp8_e4m3_to_float(src_val) * scale; + dst[gid] = float_to_fp8_e4m3(fp32_val); + } else { + // Regular precision -> regular precision (with scaling) + float fp32_val = static_cast(src_val) * scale; + dst[gid] = static_cast(fp32_val); + } +} + +// Instantiate all required combinations +#define INSTANTIATE_CONVERT_FP8(src_type, dst_type) \ + template [[host_name("convert_fp8_" #src_type "_to_" #dst_type)]] \ + [[kernel]] void convert_fp8_kernel( \ + const device src_type *__restrict__ src [[buffer(0)]], \ + device dst_type *__restrict__ dst [[buffer(1)]], \ + const device float &scale [[buffer(2)]], \ + const device uint32_t &num_elements [[buffer(3)]], \ + uint gid [[thread_position_in_grid]]); + +// FP8 to other formats (dequantization) +INSTANTIATE_CONVERT_FP8(uchar, float); 
+INSTANTIATE_CONVERT_FP8(uchar, half); +INSTANTIATE_CONVERT_FP8(uchar, bfloat16_t); + +// Other formats to FP8 (quantization) +INSTANTIATE_CONVERT_FP8(float, uchar); +INSTANTIATE_CONVERT_FP8(half, uchar); +INSTANTIATE_CONVERT_FP8(bfloat16_t, uchar); + +// FP8 to FP8 (rescaling) +INSTANTIATE_CONVERT_FP8(uchar, uchar); + +// Regular precision conversions with scaling +INSTANTIATE_CONVERT_FP8(float, float); +INSTANTIATE_CONVERT_FP8(float, half); +INSTANTIATE_CONVERT_FP8(float, bfloat16_t); +INSTANTIATE_CONVERT_FP8(half, float); +INSTANTIATE_CONVERT_FP8(half, half); +INSTANTIATE_CONVERT_FP8(half, bfloat16_t); +INSTANTIATE_CONVERT_FP8(bfloat16_t, float); +INSTANTIATE_CONVERT_FP8(bfloat16_t, half); +INSTANTIATE_CONVERT_FP8(bfloat16_t, bfloat16_t); \ No newline at end of file diff --git a/vllm_metal/metal/kernels/float8.metal b/vllm_metal/metal/kernels/float8.metal new file mode 100644 index 0000000..b911eba --- /dev/null +++ b/vllm_metal/metal/kernels/float8.metal @@ -0,0 +1,122 @@ +#include +using namespace metal; + +// Helpers ------------------------------------------------------------ +static inline uint as_bits(float x) { return as_type(x); } +static inline float from_bits(uint b) { return as_type(b); } + +// ------------------------------------------------------------------- +// FP8 E4M3 (bias = 7) +// ------------------------------------------------------------------- +inline float fp8_e4m3_to_float(uchar v) { + const uint s = v >> 7; + const uint exp = (v >> 3) & 0xF; + const uint man = v & 0x7; + + if (exp == 0) { // zero / sub-normal + if (man == 0) + return s ? -0.f : 0.f; + const float m = float(man) / 8.f; // already scaled by 2^-3 + float val = ldexp(m, 1 - 7); // 2^(1-bias) = 2^-6 + return s ? -val : val; + } + + if (exp == 0xF) { // Inf / NaN (E4M3FN keeps only NaN) + if (man != 0) + return NAN; + return s ? -INFINITY : INFINITY; + } + + const float m = 1.f + float(man) / 8.f; + float val = ldexp(m, int(exp) - 7); + return s ? 
-val : val; +} + +// ------------------------------------------------------------------- +// FP8 E5M2 (bias = 15) +// ------------------------------------------------------------------- +inline float fp8_e5m2_to_float(uchar v) { + const uint s = v >> 7; + const uint exp = (v >> 2) & 0x1F; + const uint man = v & 0x3; + + if (exp == 0) { + if (man == 0) + return s ? -0.f : 0.f; + const float m = float(man) / 4.f; + float val = ldexp(m, 1 - 15); // 2^(1-bias) = 2^-14 + return s ? -val : val; + } + + if (exp == 0x1F) { + if (man != 0) + return NAN; + return s ? -INFINITY : INFINITY; + } + + const float m = 1.f + float(man) / 4.f; + float val = ldexp(m, int(exp) - 15); + return s ? -val : val; +} + +// ------------------------------------------------------------------- +// Encoding helpers (round-to-nearest-even, gradual under-flow, sat-to-∞) +// ------------------------------------------------------------------- +namespace detail { +template +inline uchar fp32_to_fp8(float f) { + const uint bits = as_bits(f); + const uint s = bits >> 31; + const uint abs = bits & 0x7FFFFFFF; + + // NaN propagates, Inf saturates + if (abs >= 0x7F800000u) { + return uchar((s << 7) | (((1u << EXP_BITS) - 1u) << MAN_BITS) | + (abs != 0x7F800000u)); + } + + int e = int((abs >> 23) & 0xFF) - 127; // unbiased exponent + uint m = abs & 0x7FFFFFu; // 23-bit mantissa + const int EXP_MAX = (1 << EXP_BITS) - 2; // last finite exponent + + // ---------- Normal path ------------------------------------------------- + int e_fp8 = e + BIAS; + if (e_fp8 >= 1 && e_fp8 <= EXP_MAX) { + // round-to-nearest-even + const int shift = 23 - MAN_BITS; + uint mant = m >> shift; + const uint lsb = mant & 1u; + const uint round = (m >> (shift - 1)) & 1u; + const uint sticky = (m & ((1u << (shift - 1)) - 1u)) != 0u; + mant += (round & (sticky | lsb)); + if (mant >> MAN_BITS) { // mantissa overflow + mant = 0; + ++e_fp8; + if (e_fp8 > EXP_MAX) + return uchar((s << 7) | (((1u << EXP_BITS) - 1u) << MAN_BITS)); // ∞ + } 
+ return uchar((s << 7) | (uint(e_fp8) << MAN_BITS) | + (mant & ((1u << MAN_BITS) - 1u))); + } + + // ---------- Sub-normal / under-flow ------------------------------------ + if (e_fp8 < 1 - MAN_BITS) // too small -> ±0 + return uchar(s << 7); + + // shift so that exponent becomes 1 + int rshift = (1 - e_fp8) + (23 - MAN_BITS); + uint mant = (0x800000u | m); // implicit 1 + uint rounded = (mant + (1u << (rshift - 1))) >> rshift; + if (rounded == 0) + return uchar(s << 7); // rounds to zero + + return uchar((s << 7) | (rounded & ((1u << MAN_BITS) - 1u))); +} +} // namespace detail + +inline uchar float_to_fp8_e4m3(float f) { + return detail::fp32_to_fp8<4, 3, 7>(f); +} +inline uchar float_to_fp8_e5m2(float f) { + return detail::fp32_to_fp8<5, 2, 15>(f); +} \ No newline at end of file diff --git a/vllm_metal/metal/kernels/paged_attention.mm b/vllm_metal/metal/kernels/paged_attention.mm new file mode 100644 index 0000000..71c398d --- /dev/null +++ b/vllm_metal/metal/kernels/paged_attention.mm @@ -0,0 +1,693 @@ +#include +#include +#include + +#import +#import +#include +#include +#include +#include +#include + +static inline id getMTLBufferStorage(const torch::Tensor &tensor) { + return __builtin_bit_cast(id, tensor.storage().data()); +} + +static std::string getModuleDirectory() { + Dl_info dl_info; + if (dladdr((void *)getModuleDirectory, &dl_info)) { + std::string path(dl_info.dli_fname); + size_t pos = path.find_last_of('/'); + if (pos != std::string::npos) { + return path.substr(0, pos); + } + } + return "."; +} + +// Helper function to get kernel name based on dtype and parameters +static std::string getKernelName(const std::string &base_name, + torch::ScalarType dtype, + torch::ScalarType cache_dtype, + int head_size, + int block_size, int num_threads, + int num_simd_lanes, int partition_size = 0) { + std::string dtype_str; + switch (dtype) { + case torch::kFloat: + dtype_str = "float"; + break; + case torch::kHalf: + dtype_str = "half"; + break; + case 
torch::kBFloat16: + dtype_str = "bfloat16_t"; + break; + default: + TORCH_CHECK(false, "Unsupported dtype for paged attention: ", dtype); + } + + std::string cache_dtype_str; + switch (cache_dtype) { + case torch::kFloat: + cache_dtype_str = "float"; + break; + case torch::kHalf: + cache_dtype_str = "half"; + break; + case torch::kBFloat16: + cache_dtype_str = "bfloat16_t"; + break; + case torch::kUInt8: + cache_dtype_str = "uchar"; + break; + default: + TORCH_CHECK(false, "Unsupported cache dtype for paged attention: ", cache_dtype); + } + + std::string kernel_name = + base_name + "_" + dtype_str + "_cache_" + cache_dtype_str + "_hs" + std::to_string(head_size) + "_bs" + + std::to_string(block_size) + "_nt" + std::to_string(num_threads) + + "_nsl" + std::to_string(num_simd_lanes); + + if (partition_size >= 0) { + kernel_name += "_ps" + std::to_string(partition_size); + } + + return kernel_name; +} + +// Helper function to calculate shared memory size +static size_t calculateSharedMemorySize(int max_seq_len, int head_size, + int num_threads, int num_simd_lanes) { + // Logits storage: max_seq_len * sizeof(float) + size_t logits_size = max_seq_len * sizeof(float); + + // Reduction workspace: 2 * (num_threads / num_simd_lanes) * sizeof(float) + size_t reduction_size = 2 * (num_threads / num_simd_lanes) * sizeof(float); + + // Output workspace for cross-warp reduction: head_size * sizeof(float) + size_t output_size = head_size * sizeof(float); + return std::max(logits_size + reduction_size, output_size); +} + +// Helper function to get supported configurations +static bool isValidConfiguration(int head_size, int block_size) { + // Supported head sizes from the Metal kernel instantiations + std::vector supported_head_sizes = {32, 64, 80, 96, 112, + 120, 128, 192, 256}; + std::vector supported_block_sizes = {8, 16, 32}; + + return std::find(supported_head_sizes.begin(), supported_head_sizes.end(), + head_size) != supported_head_sizes.end() && + 
std::find(supported_block_sizes.begin(), supported_block_sizes.end(), + block_size) != supported_block_sizes.end(); +} + +void paged_attention_v1( + torch::Tensor &out, // [num_seqs, num_heads, head_size] + torch::Tensor &query, // [num_seqs, num_heads, head_size] + torch::Tensor + &key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor + &value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor &block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor &seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const std::optional &alibi_slopes, + const std::string &kv_cache_dtype, torch::Tensor &k_scale, + torch::Tensor &v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + + // Validate block sparse is not supported yet + // TODO: support blocksparse. 
+ TORCH_CHECK( + !is_block_sparse, + "Block sparse attention is not yet supported in Metal implementation"); + + // Determine cache dtype based on kv_cache_dtype + torch::ScalarType cache_dtype = key_cache.scalar_type(); + bool use_fp8_scales = (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3"); + if (use_fp8_scales) { + TORCH_CHECK(cache_dtype == torch::kUInt8, "FP8 cache requires UInt8 tensor type"); + TORCH_CHECK(k_scale.numel() == 1 && v_scale.numel() == 1, "FP8 scales must be scalars"); + TORCH_CHECK(k_scale.scalar_type() == torch::kFloat32 && v_scale.scalar_type() == torch::kFloat32, + "FP8 scales must be float32"); + } + + // Validate input tensors + TORCH_CHECK(out.device().is_mps() && query.device().is_mps() && + key_cache.device().is_mps() && + value_cache.device().is_mps() && + block_tables.device().is_mps() && seq_lens.device().is_mps(), + "All tensors must be on MPS device"); + + const int64_t num_seqs = query.size(0); + const int64_t num_heads = query.size(1); + const int64_t head_size = query.size(2); + const int64_t max_num_blocks_per_seq = block_tables.size(1); + + // Validate configurations + TORCH_CHECK(isValidConfiguration(head_size, block_size), + "Unsupported head_size/block_size combination: ", head_size, "/", + block_size); + + // For v1, no partitioning - each sequence processed by one threadgroup + // Kernel configuration (should match the instantiated kernels) + const int num_threads = 256; + const int num_simd_lanes = 32; + const int partition_size = 0; // v1 doesn't use partitioning + + // Calculate shared memory requirements (from mistral.rs) + const int num_simds = num_threads / num_simd_lanes; + const int padded_max_context_len = + ((max_seq_len + block_size - 1) / block_size) * block_size; + const int logits_size = padded_max_context_len * sizeof(float); + const int outputs_size = (num_simds / 2) * head_size * sizeof(float); + const size_t shared_memory_size = std::max(logits_size, outputs_size); + + // Get kernel name - v1 
kernels have partition_size=0 in their name + std::string kernel_name = + getKernelName("paged_attention", query.scalar_type(), cache_dtype, head_size, + block_size, num_threads, num_simd_lanes, partition_size); + + @autoreleasepool { + id device = MTLCreateSystemDefaultDevice(); + + // Load Metal library + std::string moduleDir = getModuleDirectory(); + std::string metallibPath = moduleDir + "/" + METALLIB_PATH; + NSString *metallibPathStr = + [NSString stringWithUTF8String:metallibPath.c_str()]; + NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr]; + NSError *error = nil; + id lib = [device newLibraryWithURL:metallibURL error:&error]; + TORCH_CHECK(lib, "Failed to load Metal library at ", metallibPath, ": ", + error ? error.localizedDescription.UTF8String + : "unknown error"); + + // Create function constants for conditional compilation + MTLFunctionConstantValues *constants = + [[MTLFunctionConstantValues alloc] init]; + bool use_partitioning = false; + bool use_alibi = alibi_slopes.has_value(); + [constants setConstantValue:&use_partitioning + type:MTLDataTypeBool + atIndex:10]; + [constants setConstantValue:&use_alibi type:MTLDataTypeBool atIndex:20]; + [constants setConstantValue:&use_fp8_scales type:MTLDataTypeBool atIndex:30]; + + NSString *kernelNameStr = + [NSString stringWithUTF8String:kernel_name.c_str()]; + id fn = [lib newFunctionWithName:kernelNameStr + constantValues:constants + error:&error]; + TORCH_CHECK( + fn, "Failed to create Metal function '", kernel_name, + "': ", error ? error.localizedDescription.UTF8String : "unknown error"); + + id pso = + [device newComputePipelineStateWithFunction:fn error:&error]; + TORCH_CHECK(pso, "Failed to create compute pipeline state: ", + error ? 
error.localizedDescription.UTF8String + : "unknown error"); + + // Setup command buffer and encoder + at::mps::MPSStream *stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream, "Failed to get current MPS stream"); + + id cmdBuf = stream->commandBuffer(); + TORCH_CHECK(cmdBuf, "Failed to get MPS command buffer"); + + dispatch_queue_t q = stream->queue(); + dispatch_sync(q, ^{ + id enc = [cmdBuf computeCommandEncoder]; + TORCH_CHECK(enc, "Failed to create compute command encoder"); + + [enc setComputePipelineState:pso]; + + // Set threadgroup memory + [enc setThreadgroupMemoryLength:shared_memory_size atIndex:0]; + + // Buffer arguments (matching the Metal kernel signature) + int buffer_idx = 0; + + // Skip exp_sums and max_logits for v1 (buffers 0, 1) + buffer_idx = 2; + + // out buffer + [enc setBuffer:getMTLBufferStorage(out) + offset:out.storage_offset() * out.element_size() + atIndex:buffer_idx++]; + + // query buffer + [enc setBuffer:getMTLBufferStorage(query) + offset:query.storage_offset() * query.element_size() + atIndex:buffer_idx++]; + + // key_cache buffer + [enc setBuffer:getMTLBufferStorage(key_cache) + offset:key_cache.storage_offset() * key_cache.element_size() + atIndex:buffer_idx++]; + + // value_cache buffer + [enc setBuffer:getMTLBufferStorage(value_cache) + offset:value_cache.storage_offset() * value_cache.element_size() + atIndex:buffer_idx++]; + + // k_scale and v_scale (for FP8) + if (use_fp8_scales) { + [enc setBuffer:getMTLBufferStorage(k_scale) + offset:k_scale.storage_offset() * k_scale.element_size() + atIndex:buffer_idx++]; + [enc setBuffer:getMTLBufferStorage(v_scale) + offset:v_scale.storage_offset() * v_scale.element_size() + atIndex:buffer_idx++]; + } else { + buffer_idx += 2; // Skip k_scale and v_scale buffer slots + } + + // num_kv_heads + int32_t num_kv_heads_i32 = static_cast(num_kv_heads); + [enc setBytes:&num_kv_heads_i32 + length:sizeof(int32_t) + atIndex:buffer_idx++]; + + // scale + float scale_f32 = 
static_cast(scale); + [enc setBytes:&scale_f32 length:sizeof(float) atIndex:buffer_idx++]; + + // softcapping (default to 1.0 for no capping) + float softcapping = 1.0f; + [enc setBytes:&softcapping length:sizeof(float) atIndex:buffer_idx++]; + + // block_tables buffer + [enc setBuffer:getMTLBufferStorage(block_tables) + offset:block_tables.storage_offset() * block_tables.element_size() + atIndex:buffer_idx++]; + + // seq_lens buffer (context_lens in kernel) + [enc setBuffer:getMTLBufferStorage(seq_lens) + offset:seq_lens.storage_offset() * seq_lens.element_size() + atIndex:buffer_idx++]; + + // max_num_blocks_per_seq + int32_t max_num_blocks_per_seq_i32 = + static_cast(max_num_blocks_per_seq); + [enc setBytes:&max_num_blocks_per_seq_i32 + length:sizeof(int32_t) + atIndex:buffer_idx++]; + + // alibi_slopes (optional) + if (use_alibi) { + [enc setBuffer:getMTLBufferStorage(alibi_slopes.value()) + offset:alibi_slopes.value().storage_offset() * + alibi_slopes.value().element_size() + atIndex:buffer_idx++]; + } else { + buffer_idx++; // Skip this buffer slot + } + + // Stride parameters + int32_t q_stride = static_cast(query.stride(0)); + int32_t kv_block_stride = static_cast(key_cache.stride(0)); + int32_t kv_head_stride = static_cast(key_cache.stride(1)); + + [enc setBytes:&q_stride length:sizeof(int32_t) atIndex:buffer_idx++]; + [enc setBytes:&kv_block_stride + length:sizeof(int32_t) + atIndex:buffer_idx++]; + [enc setBytes:&kv_head_stride + length:sizeof(int32_t) + atIndex:buffer_idx++]; + + // Dispatch configuration + // Grid: (num_heads, num_seqs, 1) - no partitioning for v1 + MTLSize grid = MTLSizeMake(num_heads, num_seqs, 1); + MTLSize threadgroup = MTLSizeMake(num_threads, 1, 1); + + [enc dispatchThreadgroups:grid threadsPerThreadgroup:threadgroup]; + [enc endEncoding]; + + stream->synchronize(at::mps::SyncType::COMMIT); + }); + } +} + +void paged_attention_v2( + torch::Tensor &out, // [num_seqs, num_heads, head_size] + torch::Tensor &exp_sums, // [num_seqs, 
num_heads, max_num_partitions] + torch::Tensor &max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor + &tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor &query, // [num_seqs, num_heads, head_size] + torch::Tensor + &key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor + &value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor &block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor &seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const std::optional &alibi_slopes, + const std::string &kv_cache_dtype, torch::Tensor &k_scale, + torch::Tensor &v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + + // TODO: support blocksparse. 
+ // Validate block sparse is not supported yet + TORCH_CHECK( + !is_block_sparse, + "Block sparse attention is not yet supported in Metal implementation"); + + // Determine cache dtype based on kv_cache_dtype + torch::ScalarType cache_dtype = key_cache.scalar_type(); + bool use_fp8_scales = (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3"); + if (use_fp8_scales) { + TORCH_CHECK(cache_dtype == torch::kUInt8, "FP8 cache requires UInt8 tensor type"); + TORCH_CHECK(k_scale.numel() == 1 && v_scale.numel() == 1, "FP8 scales must be scalars"); + TORCH_CHECK(k_scale.scalar_type() == torch::kFloat32 && v_scale.scalar_type() == torch::kFloat32, + "FP8 scales must be float32"); + } + + // Validate input tensors + TORCH_CHECK(out.device().is_mps() && query.device().is_mps() && + key_cache.device().is_mps() && + value_cache.device().is_mps() && exp_sums.device().is_mps() && + max_logits.device().is_mps() && tmp_out.device().is_mps() && + block_tables.device().is_mps() && seq_lens.device().is_mps(), + "All tensors must be on MPS device"); + + const int64_t num_seqs = query.size(0); + const int64_t num_heads = query.size(1); + const int64_t head_size = query.size(2); + const int64_t max_num_blocks_per_seq = block_tables.size(1); + const int64_t max_num_partitions = exp_sums.size(2); + + // Validate configurations + TORCH_CHECK(isValidConfiguration(head_size, block_size), + "Unsupported head_size/block_size combination: ", head_size, "/", + block_size); + + // For v2, use partitioning (matching the instantiated kernels) + const int num_threads = 256; + const int num_simd_lanes = 32; + const int partition_size = 512; // v2 uses partitioning + + // Calculate shared memory requirements (from mistral.rs) + const int num_simds = num_threads / num_simd_lanes; + const int logits_size = partition_size * sizeof(float); + const int outputs_size = (num_simds / 2) * head_size * sizeof(float); + const size_t shared_memory_size = std::max(logits_size, outputs_size); + + // Get kernel 
names + std::string kernel_name = + getKernelName("paged_attention", query.scalar_type(), cache_dtype, head_size, + block_size, num_threads, num_simd_lanes, partition_size); + // Reduce kernel doesn't have block_size in its name + std::string reduce_kernel_name = "paged_attention_v2_reduce"; + switch (query.scalar_type()) { + case torch::kFloat: + reduce_kernel_name += "_float"; + break; + case torch::kHalf: + reduce_kernel_name += "_half"; + break; + case torch::kBFloat16: + reduce_kernel_name += "_bfloat16_t"; + break; + default: + TORCH_CHECK(false, + "Unsupported dtype for paged attention: ", query.scalar_type()); + } + reduce_kernel_name += "_hs" + std::to_string(head_size) + "_nt" + + std::to_string(num_threads) + "_nsl" + + std::to_string(num_simd_lanes) + "_ps" + + std::to_string(partition_size); + + @autoreleasepool { + id device = MTLCreateSystemDefaultDevice(); + + // Load Metal library + std::string moduleDir = getModuleDirectory(); + std::string metallibPath = moduleDir + "/" + METALLIB_PATH; + NSString *metallibPathStr = + [NSString stringWithUTF8String:metallibPath.c_str()]; + NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr]; + NSError *error = nil; + id lib = [device newLibraryWithURL:metallibURL error:&error]; + TORCH_CHECK(lib, "Failed to load Metal library at ", metallibPath, ": ", + error ? 
error.localizedDescription.UTF8String + : "unknown error"); + + // Setup command buffer and queue + at::mps::MPSStream *stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream, "Failed to get current MPS stream"); + + id cmdBuf = stream->commandBuffer(); + TORCH_CHECK(cmdBuf, "Failed to get MPS command buffer"); + + dispatch_queue_t q = stream->queue(); + dispatch_sync(q, ^{ + // ================================================================== + // Phase 1: Main paged attention kernel with partitioning + // ================================================================== + + // Create function constants for main kernel + MTLFunctionConstantValues *mainConstants = + [[MTLFunctionConstantValues alloc] init]; + bool use_partitioning = true; + bool use_alibi = alibi_slopes.has_value(); + [mainConstants setConstantValue:&use_partitioning + type:MTLDataTypeBool + atIndex:10]; + [mainConstants setConstantValue:&use_alibi + type:MTLDataTypeBool + atIndex:20]; + [mainConstants setConstantValue:&use_fp8_scales + type:MTLDataTypeBool + atIndex:30]; + + NSString *kernelNameStr = + [NSString stringWithUTF8String:kernel_name.c_str()]; + NSError *mainError = nil; + id mainFn = [lib newFunctionWithName:kernelNameStr + constantValues:mainConstants + error:&mainError]; + TORCH_CHECK(mainFn, "Failed to create Metal function '", kernel_name, + "': ", + mainError ? mainError.localizedDescription.UTF8String + : "unknown error"); + + NSError *psoError = nil; + id mainPso = + [device newComputePipelineStateWithFunction:mainFn error:&psoError]; + TORCH_CHECK(mainPso, "Failed to create compute pipeline state: ", + psoError ? 
psoError.localizedDescription.UTF8String + : "unknown error"); + + id enc = [cmdBuf computeCommandEncoder]; + TORCH_CHECK(enc, "Failed to create compute command encoder"); + + [enc setComputePipelineState:mainPso]; + [enc setThreadgroupMemoryLength:shared_memory_size atIndex:0]; + + // Set buffers for main kernel + int buffer_idx = 0; + + // exp_sums buffer + [enc setBuffer:getMTLBufferStorage(exp_sums) + offset:exp_sums.storage_offset() * exp_sums.element_size() + atIndex:buffer_idx++]; + + // max_logits buffer + [enc setBuffer:getMTLBufferStorage(max_logits) + offset:max_logits.storage_offset() * max_logits.element_size() + atIndex:buffer_idx++]; + + // tmp_out buffer + [enc setBuffer:getMTLBufferStorage(tmp_out) + offset:tmp_out.storage_offset() * tmp_out.element_size() + atIndex:buffer_idx++]; + + // query buffer + [enc setBuffer:getMTLBufferStorage(query) + offset:query.storage_offset() * query.element_size() + atIndex:buffer_idx++]; + + // key_cache buffer + [enc setBuffer:getMTLBufferStorage(key_cache) + offset:key_cache.storage_offset() * key_cache.element_size() + atIndex:buffer_idx++]; + + // value_cache buffer + [enc setBuffer:getMTLBufferStorage(value_cache) + offset:value_cache.storage_offset() * value_cache.element_size() + atIndex:buffer_idx++]; + + // k_scale and v_scale (for FP8) + if (use_fp8_scales) { + [enc setBuffer:getMTLBufferStorage(k_scale) + offset:k_scale.storage_offset() * k_scale.element_size() + atIndex:buffer_idx++]; + [enc setBuffer:getMTLBufferStorage(v_scale) + offset:v_scale.storage_offset() * v_scale.element_size() + atIndex:buffer_idx++]; + } else { + buffer_idx += 2; // Skip k_scale and v_scale buffer slots + } + + // num_kv_heads + int32_t num_kv_heads_i32 = static_cast(num_kv_heads); + [enc setBytes:&num_kv_heads_i32 + length:sizeof(int32_t) + atIndex:buffer_idx++]; + + // scale + float scale_f32 = static_cast(scale); + [enc setBytes:&scale_f32 length:sizeof(float) atIndex:buffer_idx++]; + + // softcapping (default to 1.0 for 
no capping) + float softcapping = 1.0f; + [enc setBytes:&softcapping length:sizeof(float) atIndex:buffer_idx++]; + + // block_tables buffer + [enc setBuffer:getMTLBufferStorage(block_tables) + offset:block_tables.storage_offset() * block_tables.element_size() + atIndex:buffer_idx++]; + + // seq_lens buffer (context_lens in kernel) + [enc setBuffer:getMTLBufferStorage(seq_lens) + offset:seq_lens.storage_offset() * seq_lens.element_size() + atIndex:buffer_idx++]; + + // max_num_blocks_per_seq + int32_t max_num_blocks_per_seq_i32 = + static_cast(max_num_blocks_per_seq); + [enc setBytes:&max_num_blocks_per_seq_i32 + length:sizeof(int32_t) + atIndex:buffer_idx++]; + + // alibi_slopes (optional) + if (use_alibi) { + [enc setBuffer:getMTLBufferStorage(alibi_slopes.value()) + offset:alibi_slopes.value().storage_offset() * + alibi_slopes.value().element_size() + atIndex:buffer_idx++]; + } else { + buffer_idx++; // Skip this buffer slot + } + + // Stride parameters + int32_t q_stride = static_cast(query.stride(0)); + int32_t kv_block_stride = static_cast(key_cache.stride(0)); + int32_t kv_head_stride = static_cast(key_cache.stride(1)); + + [enc setBytes:&q_stride length:sizeof(int32_t) atIndex:buffer_idx++]; + [enc setBytes:&kv_block_stride + length:sizeof(int32_t) + atIndex:buffer_idx++]; + [enc setBytes:&kv_head_stride + length:sizeof(int32_t) + atIndex:buffer_idx++]; + + // Dispatch main kernel + // Grid: (num_heads, num_seqs, max_num_partitions) - with partitioning for + // v2 + MTLSize mainGrid = MTLSizeMake(num_heads, num_seqs, max_num_partitions); + MTLSize mainThreadgroup = MTLSizeMake(num_threads, 1, 1); + + [enc dispatchThreadgroups:mainGrid threadsPerThreadgroup:mainThreadgroup]; + [enc endEncoding]; + + // ================================================================== + // Phase 2: Reduction kernel to combine partitions + // ================================================================== + + // Create reduction kernel + NSString *reduceKernelNameStr = + 
[NSString stringWithUTF8String:reduce_kernel_name.c_str()]; + id reduceFn = [lib newFunctionWithName:reduceKernelNameStr]; + TORCH_CHECK(reduceFn, "Failed to create Metal function '", + reduce_kernel_name, "'"); + + NSError *reducePsoError = nil; + id reducePso = + [device newComputePipelineStateWithFunction:reduceFn + error:&reducePsoError]; + TORCH_CHECK( + reducePso, "Failed to create compute pipeline state for reduction: ", + reducePsoError ? reducePsoError.localizedDescription.UTF8String + : "unknown error"); + + // Calculate shared memory for reduction kernel + size_t reduce_shared_memory_size = + max_num_partitions * sizeof(float) * 2; // max_logits + exp_sums + + id reduceEnc = [cmdBuf computeCommandEncoder]; + TORCH_CHECK(reduceEnc, + "Failed to create compute command encoder for reduction"); + + [reduceEnc setComputePipelineState:reducePso]; + [reduceEnc setThreadgroupMemoryLength:reduce_shared_memory_size + atIndex:0]; + + // Set buffers for reduction kernel + buffer_idx = 0; + + // out buffer (final output) + [reduceEnc setBuffer:getMTLBufferStorage(out) + offset:out.storage_offset() * out.element_size() + atIndex:buffer_idx++]; + + // exp_sums buffer + [reduceEnc setBuffer:getMTLBufferStorage(exp_sums) + offset:exp_sums.storage_offset() * exp_sums.element_size() + atIndex:buffer_idx++]; + + // max_logits buffer + [reduceEnc + setBuffer:getMTLBufferStorage(max_logits) + offset:max_logits.storage_offset() * max_logits.element_size() + atIndex:buffer_idx++]; + + // tmp_out buffer + [reduceEnc setBuffer:getMTLBufferStorage(tmp_out) + offset:tmp_out.storage_offset() * tmp_out.element_size() + atIndex:buffer_idx++]; + + // seq_lens buffer (context_lens in kernel) + [reduceEnc setBuffer:getMTLBufferStorage(seq_lens) + offset:seq_lens.storage_offset() * seq_lens.element_size() + atIndex:buffer_idx++]; + + // max_num_partitions + int32_t max_num_partitions_i32 = static_cast(max_num_partitions); + [reduceEnc setBytes:&max_num_partitions_i32 + 
length:sizeof(int32_t) + atIndex:buffer_idx++]; + + // Dispatch reduction kernel + // Grid: (num_heads, num_seqs) - one threadgroup per sequence/head + // combination + MTLSize reduceGrid = MTLSizeMake(num_heads, num_seqs, 1); + MTLSize reduceThreadgroup = MTLSizeMake(num_threads, 1, 1); + + [reduceEnc dispatchThreadgroups:reduceGrid + threadsPerThreadgroup:reduceThreadgroup]; + [reduceEnc endEncoding]; + + stream->synchronize(at::mps::SyncType::COMMIT); + }); + } +} \ No newline at end of file diff --git a/vllm_metal/metal/kernels/utils.metal b/vllm_metal/metal/kernels/utils.metal new file mode 100644 index 0000000..d3b638a --- /dev/null +++ b/vllm_metal/metal/kernels/utils.metal @@ -0,0 +1,246 @@ +#include +using namespace metal; + +#if defined(__HAVE_BFLOAT__) + +typedef bfloat bfloat16_t; + +#else + +///////////////////////////////////////////////////////////////////////////// +// Helpers +///////////////////////////////////////////////////////////////////////////// + +constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) { + // Check for nan + if ((as_type(x) & ~_fp_encoding_traits::sign_mask) > + _fp_encoding_traits::inf_mask) { + return uint16_t(as_type(0x7FC0)); + } + // Take bits + uint32_t float_bits = as_type(x); + + // Round to nearest even + float_bits += ((float_bits >> 16) & 1) + as_type(0x7FFF); + + // Take upper 16 bits + return float_bits >> 16; +} + +constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) { + // Upper 16 bits are the data and lower 16 bits are 0s + return as_type((uint32_t)x << 16); +} + +struct _MLX_BFloat16; + +template +static constexpr constant bool can_convert_to_bfloat = + !is_same_v && is_convertible_v; + +template +static constexpr constant bool can_convert_from_bfloat = + !is_same_v && is_convertible_v; + +///////////////////////////////////////////////////////////////////////////// +// Bfloat struct +///////////////////////////////////////////////////////////////////////////// + +struct _MLX_BFloat16 { + 
///////////////////////////////////////////////////////////////////////////// + // Constructors + uint16_t bits_; + _MLX_BFloat16() thread = default; + _MLX_BFloat16() threadgroup = default; + _MLX_BFloat16() device = default; + _MLX_BFloat16() constant = default; + + struct bits_to_bfloat_struct {}; + static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() { + return bits_to_bfloat_struct(); + } + constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct) + : bits_(bits) {} + + ///////////////////////////////////////////////////////////////////////////// + // Conversions to bfloat + + template >::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) thread + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template >::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template >::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) device + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template >::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) constant + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + ///////////////////////////////////////////////////////////////////////////// + // Conversions from bfloat + + template >::type> + constexpr METAL_FUNC operator T() const thread { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template >::type> + constexpr METAL_FUNC operator T() const threadgroup { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template >::type> + constexpr METAL_FUNC operator T() const device { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template >::type> + constexpr METAL_FUNC operator T() constant { + return static_cast(bfloat_bits_to_float(bits_)); + } +}; + +///////////////////////////////////////////////////////////////////////////// +// Bfloat operators +///////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////// 
+// Unary ops +constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) { + return -static_cast(x); +} + +///////////////////////////////////////////////////////////////////////////// +// Binary operators +#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ + constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +///////////////////////////////////////////////////////////////////////////// +// Arithmetic Operators +#define bfloat_binop(_op_, _operator_) \ + bfloat_binop_base(_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, \ + _MLX_BFloat16, float); \ + bfloat_binop_helper(_op_, _operator_, float, float, float); \ + bfloat_binop_helper(_op_, _operator_, float, half, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); + +bfloat_binop(+, operator+); +bfloat_binop(-, operator-); +bfloat_binop(*, operator*); +bfloat_binop(/, operator/); + +///////////////////////////////////////////////////////////////////////////// +// Comparison ops +#define bfloat_compop(__op__, __operator__) \ + bfloat_binop_base(__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, \ + float); \ + bfloat_binop_helper(__op__, __operator__, bool, float, float); \ + bfloat_binop_helper(__op__, __operator__, bool, half, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ + 
bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); + +bfloat_compop(>, operator>); +bfloat_compop(<, operator<); +bfloat_compop(>=, operator>=); +bfloat_compop(<=, operator<=); +bfloat_compop(==, operator==); +bfloat_compop(!=, operator!=); + +#undef bfloat_compop +#undef bfloat_binop_base +#undef bfloat_binop_helper +#undef bfloat_binop + +///////////////////////////////////////////////////////////////////////////// +// Inplace Operators +#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16 &__operator__( \ + addr_space _MLX_BFloat16 &lhs, itype rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } \ + constexpr METAL_FUNC addr_space itype &__operator__(addr_space itype &lhs, \ + _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } + +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \ + bfloat_inplace_op_helper(__op__, __operator__, itype, device); \ + bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \ + bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup); + +#define bfloat_inplace_op(itype) \ + bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \ + bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \ + bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \ + bfloat_inplace_op_addr_space_helper(/, operator/=, itype); + +bfloat_inplace_op(float); +bfloat_inplace_op(half); +bfloat_inplace_op(int16_t); +bfloat_inplace_op(int32_t); +bfloat_inplace_op(int64_t); +bfloat_inplace_op(uint16_t); +bfloat_inplace_op(uint32_t); +bfloat_inplace_op(uint64_t); + +#undef bfloat_inplace_op_helper +#undef bfloat_inplace_op_addr_space_helper +#undef bfloat_inplace_op + +#define bfloat_inplace_op_helper(__op__, 
__operator__, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16 &__operator__( \ + addr_space _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } + +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \ + bfloat_inplace_op_helper(__op__, __operator__, device); \ + bfloat_inplace_op_helper(__op__, __operator__, thread); \ + bfloat_inplace_op_helper(__op__, __operator__, threadgroup); + +bfloat_inplace_op_addr_space_helper(+, operator+=); +bfloat_inplace_op_addr_space_helper(-, operator-=); +bfloat_inplace_op_addr_space_helper(*, operator*=); +bfloat_inplace_op_addr_space_helper(/, operator/=); + +#undef bfloat_inplace_op_helper +#undef bfloat_inplace_op_addr_space_helper + +///////////////////////////////////////////////////////////////////////////// +// Bfloat typedef +///////////////////////////////////////////////////////////////////////////// + +typedef struct _MLX_BFloat16 bfloat16_t; + +#endif diff --git a/vllm_metal/metal/kernels_v1/copy_blocks.metal b/vllm_metal/metal/kernels_v1/copy_blocks.metal new file mode 100644 index 0000000..8408947 --- /dev/null +++ b/vllm_metal/metal/kernels_v1/copy_blocks.metal @@ -0,0 +1,58 @@ +#include "utils.metal" +#include + +using namespace metal; + +template +[[kernel]] void copy_blocks(device T *key_cache [[buffer(0)]], + device T *value_cache [[buffer(1)]], + const device int64_t *block_mapping [[buffer(2)]], + device const int &numel_per_block_key, + device const int &numel_per_block_value, + uint gid [[thread_position_in_grid]], + uint tid [[thread_position_in_threadgroup]], + uint threads_per_threadgroup + [[threads_per_threadgroup]]) { + const int pair_idx = gid; + + int64_t src_block_number = block_mapping[2 * pair_idx]; + int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; + + const int64_t src_block_offset_key = src_block_number * numel_per_block_key; + const int64_t dst_block_offset_key = dst_block_number * 
numel_per_block_key; + + // Copy key cache blocks + for (int i = tid; i < numel_per_block_key; i += threads_per_threadgroup) { + int64_t src_offset = src_block_offset_key + i; + int64_t dst_offset = dst_block_offset_key + i; + key_cache[dst_offset] = key_cache[src_offset]; + } + + const int64_t src_block_offset_value = + src_block_number * numel_per_block_value; + const int64_t dst_block_offset_value = + dst_block_number * numel_per_block_value; + + // Copy value cache blocks + for (int i = tid; i < numel_per_block_value; i += threads_per_threadgroup) { + int64_t src_offset = src_block_offset_value + i; + int64_t dst_offset = dst_block_offset_value + i; + value_cache[dst_offset] = value_cache[src_offset]; + } +} + +#define instantiate_copy_blocks(type) \ + template [[host_name("copy_blocks_" #type)]] [[kernel]] void \ + copy_blocks(device type * key_cache_ptrs [[buffer(0)]], \ + device type * value_cache_ptrs [[buffer(1)]], \ + const device int64_t *block_mapping [[buffer(2)]], \ + device const int &numel_per_block_key, \ + device const int &numel_per_block_value, \ + uint gid [[thread_position_in_grid]], \ + uint tid [[thread_position_in_threadgroup]], \ + uint threads_per_threadgroup [[threads_per_threadgroup]]); + +instantiate_copy_blocks(float); +instantiate_copy_blocks(bfloat16_t); +instantiate_copy_blocks(half); +instantiate_copy_blocks(uchar); diff --git a/vllm_metal/metal/kernels_v1/float8.metal b/vllm_metal/metal/kernels_v1/float8.metal new file mode 100644 index 0000000..3773ca8 --- /dev/null +++ b/vllm_metal/metal/kernels_v1/float8.metal @@ -0,0 +1,149 @@ +#include +using namespace metal; + +// Helpers ------------------------------------------------------------ +static inline uint as_bits(float x) { return as_type(x); } +static inline float from_bits(uint b) { return as_type(b); } + +// ------------------------------------------------------------------- +// FP8 E4M3 (bias = 7) +// ------------------------------------------------------------------- 
+inline float fp8_e4m3_to_float(uchar v) { + const uint s = v >> 7; + const uint exp = (v >> 3) & 0xF; + const uint man = v & 0x7; + + if (exp == 0) { // zero / sub-normal + if (man == 0) + return s ? -0.f : 0.f; + const float m = float(man) / 8.f; // already scaled by 2^-3 + float val = ldexp(m, 1 - 7); // 2^(1-bias) = 2^-6 + return s ? -val : val; + } + + // E4M3 has NO infinity - only NaN when exp=15 and mantissa=7 + if (exp == 0xF && man == 0x7) { + return NAN; + } + + // Normalized (including exp=0xF with mantissa 0-6, which are valid numbers) + const float m = 1.f + float(man) / 8.f; + float val = ldexp(m, int(exp) - 7); + return s ? -val : val; +} + +// ------------------------------------------------------------------- +// FP8 E5M2 (bias = 15) +// ------------------------------------------------------------------- +inline float fp8_e5m2_to_float(uchar v) { + const uint s = v >> 7; + const uint exp = (v >> 2) & 0x1F; + const uint man = v & 0x3; + + if (exp == 0) { + if (man == 0) + return s ? -0.f : 0.f; + const float m = float(man) / 4.f; + float val = ldexp(m, 1 - 15); // 2^(1-bias) = 2^-14 + return s ? -val : val; + } + + if (exp == 0x1F) { + if (man != 0) + return NAN; + return s ? -INFINITY : INFINITY; + } + + const float m = 1.f + float(man) / 4.f; + float val = ldexp(m, int(exp) - 15); + return s ? 
-val : val; +} + +// ------------------------------------------------------------------- +// Encoding helpers (round-to-nearest-even, gradual under-flow, sat-to-∞) +// ------------------------------------------------------------------- +namespace detail { +template +inline uchar fp32_to_fp8(float f) { + const uint bits = as_bits(f); + const uint s = bits >> 31; + const uint abs = bits & 0x7FFFFFFF; + + // NaN propagates, Inf saturates + if (abs >= 0x7F800000u) { + return uchar((s << 7) | (((1u << EXP_BITS) - 1u) << MAN_BITS) | + (abs != 0x7F800000u)); + } + + int e = int((abs >> 23) & 0xFF) - 127; // unbiased exponent + uint m = abs & 0x7FFFFFu; // 23-bit mantissa + const int EXP_MAX = (1 << EXP_BITS) - 2; // last finite exponent + + // ---------- Normal path ------------------------------------------------- + int e_fp8 = e + BIAS; + if (e_fp8 >= 1 && e_fp8 <= EXP_MAX) { + // round-to-nearest-even + const int shift = 23 - MAN_BITS; + uint mant = m >> shift; + const uint lsb = mant & 1u; + const uint round = (m >> (shift - 1)) & 1u; + const uint sticky = (m & ((1u << (shift - 1)) - 1u)) != 0u; + mant += (round & (sticky | lsb)); + if (mant >> MAN_BITS) { // mantissa overflow + mant = 0; + ++e_fp8; + if (e_fp8 > EXP_MAX) + return uchar((s << 7) | (((1u << EXP_BITS) - 1u) << MAN_BITS)); // ∞ + } + return uchar((s << 7) | (uint(e_fp8) << MAN_BITS) | + (mant & ((1u << MAN_BITS) - 1u))); + } + + // ---------- Sub-normal / under-flow ------------------------------------ + if (e_fp8 < 1 - MAN_BITS) // too small -> ±0 + return uchar(s << 7); + + // shift so that exponent becomes 1 + int rshift = (1 - e_fp8) + (23 - MAN_BITS); + uint mant = (0x800000u | m); // implicit 1 + uint rounded = (mant + (1u << (rshift - 1))) >> rshift; + if (rounded == 0) + return uchar(s << 7); // rounds to zero + + return uchar((s << 7) | (rounded & ((1u << MAN_BITS) - 1u))); +} +} // namespace detail + +inline uchar float_to_fp8_e4m3(float f) { + // E4M3 has no infinity - must handle specially + 
// Max value is 448 (exp=15, mantissa=6), mantissa=7 is NaN + + if (isnan(f)) { + return 0x7F; // positive NaN (exp=15, mantissa=7) + } + + const uint bits = as_bits(f); + const uint s = bits >> 31; + + // Clamp infinity and overflow to max value (448) + if (isinf(f) || fabs(f) > 448.0f) { + // E4M3 max: exp=15, mantissa=6 (value = 1.75 * 2^8 = 448) + return uchar((s << 7) | (0xF << 3) | 0x6); + } + + // Use the template for normal values, but check result + uchar result = detail::fp32_to_fp8<4, 3, 7>(f); + + // Ensure we don't accidentally create NaN or invalid encoding + uint exp_bits = (result >> 3) & 0xF; + uint man_bits = result & 0x7; + if (exp_bits == 0xF && man_bits == 0x7) { + // Would be NaN, clamp to max value instead + return uchar((s << 7) | (0xF << 3) | 0x6); + } + + return result; +} +inline uchar float_to_fp8_e5m2(float f) { + return detail::fp32_to_fp8<5, 2, 15>(f); +} diff --git a/vllm_metal/metal/kernels_v1/gather_kv_cache.metal b/vllm_metal/metal/kernels_v1/gather_kv_cache.metal new file mode 100644 index 0000000..82bba47 --- /dev/null +++ b/vllm_metal/metal/kernels_v1/gather_kv_cache.metal @@ -0,0 +1,161 @@ +#include "utils.metal" +#include + +using namespace metal; + +// Convert from cache type to output type, with optional FP8 dequantization. 
+template +inline OUT_T from_cache(CACHE_T v) = delete; + +// Identity conversions (cache_t == out_t) +template <> inline float from_cache(float v) { return v; } +template <> inline bfloat16_t from_cache(bfloat16_t v) { + return v; +} +template <> inline half from_cache(half v) { return v; } + +// FP8 E4M3 -> output type conversions +template <> inline float from_cache(uchar v) { + return fp8_e4m3_to_float(v); +} +template <> inline half from_cache(uchar v) { + return (half)fp8_e4m3_to_float(v); +} +template <> inline bfloat16_t from_cache(uchar v) { + return (bfloat16_t)fp8_e4m3_to_float(v); +} + +constant bool use_fp8_scales [[function_constant(10)]]; + +/// Gather K and V from paged KV cache into contiguous output tensors. +/// +/// One threadgroup per output token. Threads cooperatively copy +/// kv_heads * head_size elements for both K and V. +/// +/// Uses binary search on cu_seq_lens to find batch_id. +/// +/// K cache layout: [num_blocks, kv_heads, head_size/x, block_size, x] +/// V cache layout: [num_blocks, kv_heads, head_size, block_size] +/// K/V output: [num_tokens, kv_heads, head_size] +template +[[kernel]] void gather_kv_cache( + const device CACHE_T *__restrict__ key_cache + [[buffer(0)]], // [num_blocks, kv_heads, head_size/x, block_size, x] + const device CACHE_T *__restrict__ value_cache + [[buffer(1)]], // [num_blocks, kv_heads, head_size, block_size] + device OUT_T *__restrict__ k_out + [[buffer(2)]], // [num_tokens, kv_heads, head_size] + device OUT_T *__restrict__ v_out + [[buffer(3)]], // [num_tokens, kv_heads, head_size] + const device float *__restrict__ k_scale + [[buffer(4), function_constant(use_fp8_scales)]], + const device float *__restrict__ v_scale + [[buffer(5), function_constant(use_fp8_scales)]], + const device int *__restrict__ block_table + [[buffer(6)]], // [batch, max_blocks] + const device int *__restrict__ cu_seq_lens [[buffer(7)]], // [batch + 1] + device const int &num_tokens [[buffer(8)]], + device const int &num_seqs 
[[buffer(9)]], + device const int &block_size [[buffer(10)]], + device const int &block_table_stride [[buffer(11)]], + device const int &num_kv_heads [[buffer(12)]], + device const int &head_size [[buffer(13)]], + device const int &x [[buffer(14)]], + uint gid [[threadgroup_position_in_grid]], + uint tid [[thread_position_in_threadgroup]], + uint threads_per_threadgroup [[threads_per_threadgroup]]) { + const int token_id = gid; + if (token_id >= num_tokens) { + return; + } + + // Binary search cu_seq_lens to find batch_id + int lo = 0, hi = num_seqs; + while (lo < hi) { + int mid = (lo + hi + 1) / 2; + if (cu_seq_lens[mid] <= token_id) { + lo = mid; + } else { + hi = mid - 1; + } + } + const int batch_id = lo; + + const int batch_offset = token_id - cu_seq_lens[batch_id]; + const int block_table_id = batch_offset / block_size; + const int slot = batch_offset % block_size; + const int block_id = + block_table[batch_id * block_table_stride + block_table_id]; + + const int n = num_kv_heads * head_size; + const long out_base = (long)token_id * num_kv_heads * head_size; + + // Precompute strides + const long k_block_stride = + (long)num_kv_heads * (head_size / x) * block_size * x; + const long k_head_stride = (long)(head_size / x) * block_size * x; + const long v_block_stride = (long)num_kv_heads * head_size * block_size; + const long v_head_stride = (long)head_size * block_size; + + for (int i = tid; i < n; i += threads_per_threadgroup) { + const int head_idx = i / head_size; + const int d = i % head_size; + + // K: [block_id, head_idx, d/x, slot, d%x] + const int x_idx = d / x; + const int x_offset = d % x; + const long k_src_idx = (long)block_id * k_block_stride + + head_idx * k_head_stride + x_idx * block_size * x + + slot * x + x_offset; + + // V: [block_id, head_idx, d, slot] + const long v_src_idx = (long)block_id * v_block_stride + + head_idx * v_head_stride + d * block_size + slot; + + if (use_fp8_scales) { + k_out[out_base + i] = OUT_T( + 
(float)from_cache(key_cache[k_src_idx]) * (*k_scale)); + v_out[out_base + i] = + OUT_T((float)from_cache(value_cache[v_src_idx]) * + (*v_scale)); + } else { + k_out[out_base + i] = from_cache(key_cache[k_src_idx]); + v_out[out_base + i] = from_cache(value_cache[v_src_idx]); + } + } +} + +#define instantiate_gather_kv_cache(cache_type, out_type) \ + template [[host_name("gather_kv_cache_cache_" #cache_type \ + "_out_" #out_type)]] [[kernel]] void \ + gather_kv_cache( \ + const device cache_type *__restrict__ key_cache [[buffer(0)]], \ + const device cache_type *__restrict__ value_cache [[buffer(1)]], \ + device out_type *__restrict__ k_out [[buffer(2)]], \ + device out_type *__restrict__ v_out [[buffer(3)]], \ + const device float *__restrict__ k_scale \ + [[buffer(4), function_constant(use_fp8_scales)]], \ + const device float *__restrict__ v_scale \ + [[buffer(5), function_constant(use_fp8_scales)]], \ + const device int *__restrict__ block_table [[buffer(6)]], \ + const device int *__restrict__ cu_seq_lens [[buffer(7)]], \ + device const int &num_tokens [[buffer(8)]], \ + device const int &num_seqs [[buffer(9)]], \ + device const int &block_size [[buffer(10)]], \ + device const int &block_table_stride [[buffer(11)]], \ + device const int &num_kv_heads [[buffer(12)]], \ + device const int &head_size [[buffer(13)]], \ + device const int &x [[buffer(14)]], \ + uint gid [[threadgroup_position_in_grid]], \ + uint tid [[thread_position_in_threadgroup]], \ + uint threads_per_threadgroup [[threads_per_threadgroup]]); + +// Same-type (no dequant) +instantiate_gather_kv_cache(float, float); +instantiate_gather_kv_cache(bfloat16_t, bfloat16_t); +instantiate_gather_kv_cache(half, half); + +// FP8 E4M3 -> compute type (dequant) +instantiate_gather_kv_cache(uchar, float); +instantiate_gather_kv_cache(uchar, bfloat16_t); +instantiate_gather_kv_cache(uchar, half); diff --git a/vllm_metal/metal/kernels_v1/kv_scale_update.metal b/vllm_metal/metal/kernels_v1/kv_scale_update.metal 
new file mode 100644 index 0000000..f57df96 --- /dev/null +++ b/vllm_metal/metal/kernels_v1/kv_scale_update.metal @@ -0,0 +1,95 @@ +#include "utils.metal" +#include + +using namespace metal; + +#define DIV_CONST 240.0f + +template +[[kernel]] void kv_scale_update(const device T *k [[buffer(0)]], + const device T *v [[buffer(1)]], + device atomic *k_scale [[buffer(2)]], + device atomic *v_scale [[buffer(3)]], + constant long &num_elements [[buffer(4)]], + uint gid [[thread_position_in_grid]], + uint grid_size [[threads_per_grid]], + uint tid [[thread_position_in_threadgroup]], + uint tg_size [[threads_per_threadgroup]], + threadgroup float *shared_k [[threadgroup(0)]], + threadgroup float *shared_v + [[threadgroup(1)]]) { + + // Per-thread local maxima + float local_max_k = 0.0f; + float local_max_v = 0.0f; + + // Strided loop covering entire array + for (long idx = gid; idx < num_elements; idx += grid_size) { + float avk = abs(static_cast(k[idx])); + float avv = abs(static_cast(v[idx])); + local_max_k = max(local_max_k, avk); + local_max_v = max(local_max_v, avv); + } + + // Store per-thread maxima to shared memory + shared_k[tid] = local_max_k; + shared_v[tid] = local_max_v; + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Parallel reduction in shared memory to find block maxima + for (uint s = tg_size / 2; s > 0; s >>= 1) { + if (tid < s) { + shared_k[tid] = max(shared_k[tid], shared_k[tid + s]); + shared_v[tid] = max(shared_v[tid], shared_v[tid + s]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Thread 0 of block updates global scales atomically + if (tid == 0) { + float candidate_k_scale = shared_k[0] / DIV_CONST; + float candidate_v_scale = shared_v[0] / DIV_CONST; + + // Atomic max update for k_scale + if (candidate_k_scale > 0.0f) { + float current = atomic_load_explicit(k_scale, memory_order_relaxed); + while (candidate_k_scale > current) { + if (atomic_compare_exchange_weak_explicit( + k_scale, ¤t, candidate_k_scale, 
memory_order_relaxed, + memory_order_relaxed)) { + break; + } + } + } + + // Atomic max update for v_scale + if (candidate_v_scale > 0.0f) { + float current = atomic_load_explicit(v_scale, memory_order_relaxed); + while (candidate_v_scale > current) { + if (atomic_compare_exchange_weak_explicit( + v_scale, ¤t, candidate_v_scale, memory_order_relaxed, + memory_order_relaxed)) { + break; + } + } + } + } +} + +#define instantiate_kv_scale_update(type) \ + template [[host_name("kv_scale_update_" #type)]] [[kernel]] void \ + kv_scale_update(const device type *k [[buffer(0)]], \ + const device type *v [[buffer(1)]], \ + device atomic *k_scale [[buffer(2)]], \ + device atomic *v_scale [[buffer(3)]], \ + constant long &num_elements [[buffer(4)]], \ + uint gid [[thread_position_in_grid]], \ + uint grid_size [[threads_per_grid]], \ + uint tid [[thread_position_in_threadgroup]], \ + uint tg_size [[threads_per_threadgroup]], \ + threadgroup float *shared_k [[threadgroup(0)]], \ + threadgroup float *shared_v [[threadgroup(1)]]); + +instantiate_kv_scale_update(float); +instantiate_kv_scale_update(bfloat16_t); +instantiate_kv_scale_update(half); diff --git a/vllm_metal/metal/kernels_v1/pagedattention.metal b/vllm_metal/metal/kernels_v1/pagedattention.metal new file mode 100644 index 0000000..be3a9e5 --- /dev/null +++ b/vllm_metal/metal/kernels_v1/pagedattention.metal @@ -0,0 +1,1429 @@ +// Portions of this file are adapted from Apple's MLX framework +// (https://github.com/ml-explore/mlx) +// Licensed under the Apache License 2.0 +// Copyright © 2023 Apple Inc. + +// Portions of this file are adapted from the vLLM project +// (https://github.com/vllm-project/vllm) +// Licensed under the Apache License 2.0 +// Copyright contributors to the vLLM project + +#include "utils.metal" +#include +#include + +using namespace metal; + +// ========================================== Generic vector types + +// A vector type to store Q, K, V elements. 
+template struct Vec {}; + +// A vector type to store FP32 accumulators. +template struct FloatVec {}; + +// Template vector operations. +template inline Acc mul(A a, B b); + +template inline float sum(T v); + +template inline float dot(T a, T b) { + return sum(mul(a, b)); +} + +template inline float dot(T a, T b) { + return sum(mul(a, b)); +} + +// FP32 vector data types. +struct Float8_ { + float4 x; + float4 y; +}; + +template <> struct Vec { + using Type = float; +}; +template <> struct Vec { + using Type = float2; +}; +template <> struct Vec { + using Type = float4; +}; +template <> struct Vec { + using Type = Float8_; +}; + +template <> struct FloatVec { + using Type = float; +}; +template <> struct FloatVec { + using Type = float2; +}; +template <> struct FloatVec { + using Type = float4; +}; +template <> struct FloatVec { + using Type = Float8_; +}; + +template <> inline float mul(float a, float b) { return a * b; } + +template <> inline float2 mul(float2 a, float2 b) { return a * b; } + +template <> inline float4 mul(float4 a, float4 b) { return a * b; } + +template <> inline Float8_ mul(Float8_ a, Float8_ b) { + Float8_ c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template <> inline float sum(float a) { return a; } + +template <> inline float sum(float2 a) { return a.x + a.y; } + +template <> inline float sum(float4 a) { return a.x + a.y + a.z + a.w; } + +template <> inline float sum(Float8_ a) { return sum(a.x) + sum(a.y); } + +inline Float8_ fma(Float8_ a, Float8_ b, Float8_ c) { + Float8_ res; + res.x = fma(a.x, b.x, c.x); + res.y = fma(a.y, b.y, c.y); + return res; +} + +inline void from_float(thread float &dst, float src) { dst = src; } +inline void from_float(thread float2 &dst, float2 src) { dst = src; } +inline void from_float(thread float4 &dst, float4 src) { dst = src; } +inline void from_float(thread Float8_ &dst, Float8_ src) { dst = src; } + +// BF16 vector data types. 
+// #if defined(__HAVE_BFLOAT__) + +// struct Bfloat8_ { +// bfloat4 x; +// bfloat4 y; +// }; + +// template<> +// struct Vec { +// using Type = bfloat; +// }; +// template<> +// struct Vec { +// using Type = bfloat2; +// }; +// template<> +// struct Vec { +// using Type = bfloat4; +// }; +// template<> +// struct Vec { +// using Type = Bfloat8_; +// }; + +// template<> +// struct FloatVec { +// using Type = float; +// }; +// template<> +// struct FloatVec { +// using Type = float2; +// }; +// template<> +// struct FloatVec { +// using Type = float4; +// }; +// template<> +// struct FloatVec { +// using Type = Float8_; +// }; + +// template<> +// inline float mul(bfloat a, bfloat b) { +// return (float)a * (float)b; +// } +// template<> +// inline bfloat mul(bfloat a, bfloat b) { +// return a*b; +// } + +// template<> +// inline float2 mul(bfloat2 a, bfloat2 b) { +// return (float2)a * (float2)b; +// } +// template<> +// inline bfloat2 mul(bfloat2 a, bfloat2 b) { +// return a * b; +// } + +// template<> +// inline float4 mul(bfloat4 a, bfloat4 b) { +// return (float4)a * (float4)b; +// } +// template<> +// inline bfloat4 mul(bfloat4 a, bfloat4 b) { +// return a * b; +// } + +// template<> +// inline Float8_ mul(Bfloat8_ a, Bfloat8_ b) { +// Float8_ c; +// c.x = mul(a.x, b.x); +// c.y = mul(a.y, b.y); +// return c; +// } +// template<> +// inline Bfloat8_ mul(Bfloat8_ a, Bfloat8_ b) { +// Bfloat8_ c; +// c.x = mul(a.x, b.x); +// c.y = mul(a.y, b.y); +// return c; +// } + +// template<> +// inline float sum(bfloat a) { +// return (float)a; +// } + +// template<> +// inline float sum(bfloat2 a) { +// return (float)a.x + (float)a.y; +// } + +// template<> +// inline float sum(bfloat4 a) { +// return sum(a.x) + sum(a.y); +// } + +// template<> +// inline float sum(Bfloat8_ a) { +// return sum(a.x) + sum(a.y); +// } + +// inline float fma(bfloat a, bfloat b, float c) { +// return (float)a * (float)b + c; +// } + +// inline float2 fma(bfloat2 a, bfloat2 b, float2 c) { +// 
return (float2)a * (float2)b + c; +// } + +// inline float4 fma(bfloat4 a, bfloat4 b, float4 c) { +// return (float4)a * (float4)b + c; +// } + +// inline Float8_ fma(Bfloat8_ a, Bfloat8_ b, Float8_ c) { +// Float8_ res; +// res.x = fma((float4)a.x, (float4)b.x, (float4)c.x); +// res.y = fma((float4)a.y, (float4)b.y, (float4)c.y); +// return res; +// } +// inline Bfloat8_ fma(Bfloat8_ a, Bfloat8_ b, Bfloat8_ c) { +// Bfloat8_ res; +// res.x = (bfloat4)fma((float4)a.x, (float4)b.x, (float4)c.x); +// res.y = (bfloat4)fma((float4)a.y, (float4)b.x, (float4)c.y); +// return c; +// } + +// inline void from_float(thread bfloat& dst, float src) { +// dst = static_cast(src); +// } +// inline void from_float(thread bfloat2& dst, float2 src) { +// dst.x = static_cast(src.x); +// dst.y = static_cast(src.y); +// } +// inline void from_float(thread bfloat4& dst, float4 src) { +// dst.x = static_cast(src.x); +// dst.y = static_cast(src.y); +// dst.z = static_cast(src.z); +// dst.w = static_cast(src.w); +// } +// inline void from_float(thread Bfloat8_& dst, Float8_ src) { +// bfloat4 x; +// bfloat4 y; +// from_float(x, src.x); +// from_float(y, src.y); +// dst.x = x; +// dst.y = y; +// } + +// #else + +struct Bfloat2_ { + bfloat16_t x; + bfloat16_t y; +}; + +struct Bfloat4_ { + Bfloat2_ x; + Bfloat2_ y; +}; + +struct Bfloat8_ { + Bfloat4_ x; + Bfloat4_ y; +}; + +template <> struct Vec { + using Type = bfloat16_t; +}; +template <> struct Vec { + using Type = Bfloat2_; +}; +template <> struct Vec { + using Type = Bfloat4_; +}; +template <> struct Vec { + using Type = Bfloat8_; +}; + +template <> struct FloatVec { + using Type = float; +}; +template <> struct FloatVec { + using Type = float2; +}; +template <> struct FloatVec { + using Type = float4; +}; +template <> struct FloatVec { + using Type = Float8_; +}; + +template <> inline float mul(bfloat16_t a, bfloat16_t b) { + return (float)a * (float)b; +} +template <> inline bfloat16_t mul(bfloat16_t a, bfloat16_t b) { return a * b; } 
+ +template <> inline float2 mul(Bfloat2_ a, Bfloat2_ b) { + float2 a_f((float)a.x, (float)a.y); + float2 b_f((float)b.x, (float)b.y); + return a_f * b_f; +} +template <> inline Bfloat2_ mul(Bfloat2_ a, Bfloat2_ b) { + Bfloat2_ c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template <> inline float4 mul(Bfloat4_ a, Bfloat4_ b) { + float2 x = mul(a.x, b.x); + float2 y = mul(a.y, b.y); + float4 c; + c.x = x.x; + c.y = x.y; + c.z = y.x; + c.w = y.y; + return c; +} +template <> inline Bfloat4_ mul(Bfloat4_ a, Bfloat4_ b) { + Bfloat4_ c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template <> inline Float8_ mul(Bfloat8_ a, Bfloat8_ b) { + Float8_ c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} +template <> inline Bfloat8_ mul(Bfloat8_ a, Bfloat8_ b) { + Bfloat8_ c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template <> inline float sum(bfloat16_t a) { return (float)a; } + +template <> inline float sum(Bfloat2_ a) { return (float)a.x + (float)a.y; } + +template <> inline float sum(Bfloat4_ a) { return sum(a.x) + sum(a.y); } + +template <> inline float sum(Bfloat8_ a) { return sum(a.x) + sum(a.y); } + +inline float fma(bfloat16_t a, bfloat16_t b, float c) { + return (float)a * (float)b + c; +} +inline bfloat16_t fma(bfloat16_t a, bfloat16_t b, bfloat16_t c) { + return a * b + c; +} + +inline float2 fma(Bfloat2_ a, Bfloat2_ b, float2 c) { + float2 a_f((float)a.x, (float)a.y); + float2 b_f((float)b.x, (float)b.y); + return a_f * b_f + c; +} +inline Bfloat2_ fma(Bfloat2_ a, Bfloat2_ b, Bfloat2_ c) { + Bfloat2_ res; + res.x = a.x * b.x + c.x; + res.y = a.y * b.y + c.y; + return res; +} + +inline float4 fma(Bfloat4_ a, Bfloat4_ b, float4 c) { + float4 res; + res.x = fma(a.x.x, b.x.x, c.x); + res.y = fma(a.x.y, b.x.y, c.y); + res.z = fma(a.y.x, b.y.x, c.z); + res.w = fma(a.y.y, b.y.y, c.w); + return res; +} +inline Bfloat4_ fma(Bfloat4_ a, Bfloat4_ b, Bfloat4_ c) { + Bfloat4_ res; + res.x = fma(a.x, b.x, c.x); 
+ res.y = fma(a.y, b.y, c.y); + return res; +} + +inline Float8_ fma(Bfloat8_ a, Bfloat8_ b, Float8_ c) { + float4 x = fma(a.x, b.x, c.x); + float4 y = fma(a.y, b.y, c.y); + Float8_ res; + res.x = x; + res.y = y; + return res; +} +inline Bfloat8_ fma(Bfloat8_ a, Bfloat8_ b, Bfloat8_ c) { + Bfloat8_ res; + res.x = fma(a.x, b.x, c.x); + res.y = fma(a.y, b.y, c.y); + return res; +} + +inline void from_float(thread bfloat16_t &dst, float src) { + dst = static_cast(src); +} +inline void from_float(thread Bfloat2_ &dst, float2 src) { + dst.x = static_cast(src.x); + dst.y = static_cast(src.y); +} +inline void from_float(thread Bfloat4_ &dst, float4 src) { + dst.x.x = static_cast(src.x); + dst.x.y = static_cast(src.y); + dst.y.x = static_cast(src.z); + dst.y.y = static_cast(src.w); +} +inline void from_float(thread Bfloat8_ &dst, Float8_ src) { + Bfloat4_ x; + Bfloat4_ y; + from_float(x, src.x); + from_float(y, src.y); + dst.x = x; + dst.y = y; +} + +// #endif + +// ========================================== FP8 (uchar) vector data types. + +// 8‑lane uchar vector – Metal only provides up to uchar4, so build our own. +struct Uchar8_ { + uchar4 x; + uchar4 y; +}; + +// Vec specialisations so Vec::Type resolves correctly. +template <> struct Vec { + using Type = uchar; +}; +template <> struct Vec { + using Type = uchar2; +}; +template <> struct Vec { + using Type = uchar4; +}; +template <> struct Vec { + using Type = Uchar8_; +}; + +// FP16 vector data types. 
+struct Half8_ { + half4 x; + half4 y; +}; + +template <> struct Vec { + using Type = half; +}; +template <> struct Vec { + using Type = half2; +}; +template <> struct Vec { + using Type = half4; +}; +template <> struct Vec { + using Type = Half8_; +}; + +template <> struct FloatVec { + using Type = float; +}; +template <> struct FloatVec { + using Type = float2; +}; +template <> struct FloatVec { + using Type = float4; +}; +template <> struct FloatVec { + using Type = Float8_; +}; + +template <> inline float mul(half a, half b) { return (float)a * (float)b; } +template <> inline half mul(half a, half b) { return a * b; } + +template <> inline float2 mul(half2 a, half2 b) { + return (float2)a * (float2)b; +} +template <> inline half2 mul(half2 a, half2 b) { return a * b; } + +template <> inline float4 mul(half4 a, half4 b) { + return (float4)a * (float4)b; +} +template <> inline half4 mul(half4 a, half4 b) { return a * b; } + +template <> inline Float8_ mul(Half8_ a, Half8_ b) { + float4 x = mul(a.x, b.x); + float4 y = mul(a.y, b.y); + Float8_ c; + c.x = x; + c.y = y; + return c; +} +template <> inline Half8_ mul(Half8_ a, Half8_ b) { + Half8_ c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template <> inline float sum(half a) { return (float)a; } + +template <> inline float sum(half2 a) { return (float)a.x + (float)a.y; } + +template <> inline float sum(half4 a) { return a.x + a.y + a.z + a.w; } + +template <> inline float sum(Half8_ a) { return sum(a.x) + sum(a.y); } + +inline float fma(half a, half b, float c) { return (float)a * (float)b + c; } + +inline float2 fma(half2 a, half2 b, float2 c) { + return (float2)a * (float2)b + c; +} + +inline float4 fma(half4 a, half4 b, float4 c) { + return (float4)a * (float4)b + c; +} + +inline Float8_ fma(Half8_ a, Half8_ b, Float8_ c) { + float4 x = fma(a.x, b.x, c.x); + float4 y = fma(a.y, b.y, c.y); + Float8_ res; + res.x = x; + res.y = y; + return res; +} +inline Half8_ fma(Half8_ a, Half8_ b, Half8_ 
c) { + Half8_ res; + res.x = fma(a.x, b.x, c.x); + res.y = fma(a.y, b.y, c.y); + return res; +} + +inline void from_float(thread half &dst, float src) { + dst = static_cast(src); +} +inline void from_float(thread half2 &dst, float2 src) { + dst.x = static_cast(src.x); + dst.y = static_cast(src.y); +} +inline void from_float(thread half4 &dst, float4 src) { + dst.x = static_cast(src.x); + dst.y = static_cast(src.y); + dst.z = static_cast(src.z); + dst.w = static_cast(src.w); +} +inline void from_float(thread Half8_ &dst, Float8_ src) { + half4 x; + half4 y; + from_float(x, src.x); + from_float(y, src.y); + dst.x = x; + dst.y = y; +} + +// General case: not uchar +template inline constexpr bool is_uchar() { return false; } + +// Specialization: T is uchar +template <> inline constexpr bool is_uchar() { return true; } + +// Generic fallback – will fail to compile if a required specialisation is +// missing. +template +inline Vec fp8_convert(const thread Quant_vec &, float scale) { + static_assert(sizeof(Vec) == 0, "Missing fp8_convert specialisation"); +} + +// ========================================== FP8 -> float/half/bfloat +inline float __dequant_single(uchar v, float scale) { + return fp8_e4m3_to_float(v) * scale; +} + +// ---- 1‑lane ---- +template <> +inline float fp8_convert(const thread uchar &in, float scale) { + return __dequant_single(in, scale); +} +template <> +inline half fp8_convert(const thread uchar &in, float scale) { + return half(__dequant_single(in, scale)); +} +template <> +inline bfloat16_t fp8_convert(const thread uchar &in, + float scale) { + return bfloat16_t(__dequant_single(in, scale)); +} + +// ---- 2‑lane ---- +template <> +inline float2 fp8_convert(const thread uchar2 &in, + float scale) { + return float2(__dequant_single(in.x, scale), __dequant_single(in.y, scale)); +} +template <> +inline half2 fp8_convert(const thread uchar2 &in, float scale) { + half2 out; + out.x = half(__dequant_single(in.x, scale)); + out.y = 
half(__dequant_single(in.y, scale)); + return out; +} +template <> +inline Bfloat2_ fp8_convert(const thread uchar2 &in, + float scale) { + Bfloat2_ out; + out.x = bfloat16_t(__dequant_single(in.x, scale)); + out.y = bfloat16_t(__dequant_single(in.y, scale)); + return out; +} + +// ---- 4‑lane ---- +template <> +inline float4 fp8_convert(const thread uchar4 &in, + float scale) { + return float4(__dequant_single(in.x, scale), __dequant_single(in.y, scale), + __dequant_single(in.z, scale), __dequant_single(in.w, scale)); +} +template <> +inline half4 fp8_convert(const thread uchar4 &in, float scale) { + half4 out; + out.x = half(__dequant_single(in.x, scale)); + out.y = half(__dequant_single(in.y, scale)); + out.z = half(__dequant_single(in.z, scale)); + out.w = half(__dequant_single(in.w, scale)); + return out; +} +template <> +inline Bfloat4_ fp8_convert(const thread uchar4 &in, + float scale) { + Bfloat4_ out; + out.x.x = bfloat16_t(__dequant_single(in.x, scale)); + out.x.y = bfloat16_t(__dequant_single(in.y, scale)); + out.y.x = bfloat16_t(__dequant_single(in.z, scale)); + out.y.y = bfloat16_t(__dequant_single(in.w, scale)); + return out; +} + +// ---- 8‑lane ---- +template <> +inline Float8_ fp8_convert(const thread Uchar8_ &in, + float scale) { + Float8_ out; + out.x = + float4(__dequant_single(in.x.x, scale), __dequant_single(in.x.y, scale), + __dequant_single(in.x.z, scale), __dequant_single(in.x.w, scale)); + out.y = + float4(__dequant_single(in.y.x, scale), __dequant_single(in.y.y, scale), + __dequant_single(in.y.z, scale), __dequant_single(in.y.w, scale)); + return out; +} +template <> +inline Half8_ fp8_convert(const thread Uchar8_ &in, + float scale) { + Half8_ out; + out.x = half4(half(__dequant_single(in.x.x, scale)), + half(__dequant_single(in.x.y, scale)), + half(__dequant_single(in.x.z, scale)), + half(__dequant_single(in.x.w, scale))); + out.y = half4(half(__dequant_single(in.y.x, scale)), + half(__dequant_single(in.y.y, scale)), + 
half(__dequant_single(in.y.z, scale)), + half(__dequant_single(in.y.w, scale))); + return out; +} +template <> +inline Bfloat8_ fp8_convert(const thread Uchar8_ &in, + float scale) { + Bfloat8_ out; + // first 4 + out.x.x.x = bfloat16_t(__dequant_single(in.x.x, scale)); + out.x.x.y = bfloat16_t(__dequant_single(in.x.y, scale)); + out.x.y.x = bfloat16_t(__dequant_single(in.x.z, scale)); + out.x.y.y = bfloat16_t(__dequant_single(in.x.w, scale)); + // second 4 + out.y.x.x = bfloat16_t(__dequant_single(in.y.x, scale)); + out.y.x.y = bfloat16_t(__dequant_single(in.y.y, scale)); + out.y.y.x = bfloat16_t(__dequant_single(in.y.z, scale)); + out.y.y.y = bfloat16_t(__dequant_single(in.y.w, scale)); + return out; +} + +// ========================================== Dot product utilities + +// TODO(EricLBuehler): optimize with vectorization +template +inline float qk_dot_(const threadgroup Vec (&q)[N], const thread Vec (&k)[N]) { + // Compute the parallel products for Q*K^T (treat vector lanes separately). + using A_vec = typename FloatVec::Type; + A_vec qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. + float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { + qk += simd_shuffle_xor(qk, mask); + } + return qk; +} + +template struct Qk_dot { + template + static inline float dot(const threadgroup Vec (&q)[N], + const thread Vec (&k)[N]) { + return qk_dot_(q, k); + } +}; + +// ========================================== Block sum utility + +// Utility function for attention softmax. +template +inline float block_sum(threadgroup float *red_smem, float sum, uint simd_tid, + uint simd_lid) { + // Compute the sum per simdgroup. +#pragma unroll + for (int mask = NUM_SIMD_LANES / 2; mask >= 1; mask /= 2) { + sum += simd_shuffle_xor(sum, mask); + } + + // Simd leaders store the data to shared memory. 
+ if (simd_lid == 0) { + red_smem[simd_tid] = sum; + } + + // Make sure the data is in shared memory. + threadgroup_barrier(mem_flags::mem_threadgroup); + + // The warps compute the final sums. + if (simd_lid < NUM_WARPS) { + sum = red_smem[simd_lid]; + } + + // Parallel reduction inside the simd group. +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + sum += simd_shuffle_xor(sum, mask); + } + + // Broadcast to other threads. + return simd_shuffle(sum, 0); +} + +// ========================================== Paged Attention kernel + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +constant bool use_partitioning [[function_constant(10)]]; +constant bool use_alibi [[function_constant(20)]]; +constant bool use_fp8_scales [[function_constant(30)]]; +constant bool use_sinks [[function_constant(40)]]; + +template +[[kernel]] void paged_attention( + device float *exp_sums + [[buffer(0), function_constant(use_partitioning)]], // [num_seqs, num_heads, + // max_num_partitions] + device float *max_logits + [[buffer(1), function_constant(use_partitioning)]], // [num_seqs, num_heads, + // max_num_partitions] + device T *out + [[buffer(2)]], // [num_seqs, num_heads, max_num_partitions, head_size] + device const T *q [[buffer(3)]], // [num_seqs, num_heads, head_size] + device const CACHE_T *k_cache + [[buffer(4)]], // [num_blocks, num_kv_heads, head_size/x, block_size, x] + device const CACHE_T *v_cache + [[buffer(5)]], // [num_blocks, num_kv_heads, head_size, block_size] + const device float *__restrict__ k_scale + [[buffer(6), function_constant(use_fp8_scales)]], // [1] + const device float *__restrict__ v_scale + [[buffer(7), function_constant(use_fp8_scales)]], // [1] + const constant int &num_kv_heads [[buffer(8)]], // [num_heads] + const constant float &scale [[buffer(9)]], + const constant float &softcapping [[buffer(10)]], + device const uint32_t 
*block_tables + [[buffer(11)]], // [num_seqs, max_num_blocks_per_seq] + device const uint32_t *context_lens [[buffer(12)]], // [num_seqs] + const constant int &max_num_blocks_per_seq [[buffer(13)]], + device const float *alibi_slopes + [[buffer(14), function_constant(use_alibi)]], // [num_heads] + const constant int &q_stride [[buffer(15)]], + const constant int &kv_block_stride [[buffer(16)]], + const constant int &kv_head_stride [[buffer(17)]], + const device float *sinks + [[buffer(18), function_constant(use_sinks)]], // [num_heads] + threadgroup char *shared_mem [[threadgroup(0)]], + uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], + uint3 threadgroups_per_grid [[threadgroups_per_grid]], + uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], + uint simd_tid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + const int seq_idx = threadgroup_position_in_grid.y; + const int partition_idx = threadgroup_position_in_grid.z; + const int max_num_partitions = threadgroups_per_grid.z; + const int thread_idx = thread_position_in_threadgroup.x; + constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; + const uint32_t context_len = context_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { + // No work to do. Terminate the thread block. + return; + } + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int num_blocks_per_partition = + USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; + + // [start_block_idx, end_block_idx) is the range of blocks to process. + const int start_block_idx = + USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = + MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); + const int num_blocks = end_block_idx - start_block_idx; + + // [start_token_idx, end_token_idx) is the range of tokens to process. 
+ const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = + MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); + const int num_tokens = end_token_idx - start_token_idx; + + constexpr int THREAD_GROUP_SIZE = MAX(NUM_SIMD_LANES / BLOCK_SIZE, 1); + constexpr int NUM_THREAD_GROUPS = + NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE + // divides NUM_THREADS + assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = + DIVIDE_ROUND_UP(BLOCK_SIZE, NUM_SIMD_LANES); + constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES; + const int warp_idx = simd_tid; + const int lane = simd_lid; + + const int head_idx = threadgroup_position_in_grid.x; + const int num_heads = threadgroups_per_grid.x; + const int num_queries_per_kv = num_heads / num_kv_heads; + const int kv_head_idx = head_idx / num_queries_per_kv; + const float alibi_slope = !use_alibi ? 0.f : alibi_slopes[head_idx]; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread + // group fetch or compute 16 bytes at a time. For example, if the size of a + // thread group is 4 and the data type is half, then the vector size is 16 / + // (4 * sizeof(half)) == 2. + constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(T)), 1); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + using Quant_vec = typename Vec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the thread group size is 4, then the first thread in the + // group has 0, 4, 8, ... 
th vectors of the query, and the second thread has + // 1, 5, 9, ... th vectors of the query, and so on. + const device T *q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + threadgroup Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; + i += NUM_THREAD_GROUPS) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[thread_group_offset][i] = + *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Use fp32 on softmax logits for better accuracy + threadgroup float *logits = reinterpret_cast(shared_mem); + // Workspace for reduction + threadgroup float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(CACHE_T); + float qk_max = -FLT_MAX; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + const device uint32_t *block_table = + block_tables + seq_idx * max_num_blocks_per_seq; + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE: The block number is stored in int32. However, we cast it to int64 + // because int32 can lead to overflow when this variable is multiplied by + // large numbers (e.g., kv_block_stride). + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + + // Load a key to registers. + // Each thread in a thread group has a different part of the key. + // For example, if the thread group size is 4, then the first thread in the + // group has 0, 4, 8, ... th vectors of the key, and the second thread has + // 1, 5, 9, ... th vectors of the key, and so on. 
+ for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * NUM_SIMD_LANES) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + K_vec k_vecs[NUM_VECS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const device CACHE_T *k_ptr = + k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + physical_block_offset * x; + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + + if constexpr (is_uchar()) { + // FP8 support + Quant_vec k_vec_quant = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = fp8_convert(k_vec_quant, *k_scale); + } else { + // Non-FP8 default + k_vecs[j] = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + float qk = scale * Qk_dot::dot( + q_vecs[thread_group_offset], k_vecs); + + // Apply softcapping + if (softcapping != 1.0) { + qk = tanh(qk / softcapping) * softcapping; + } + + // Add the ALiBi bias if slopes are given. + qk += + (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE: It is required to zero out the masked logits. + const bool mask = token_idx >= context_len; + logits[token_idx - start_token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : max(qk_max, qk); + } + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. 
+#pragma unroll + for (int mask = NUM_SIMD_LANES / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = max(qk_max, simd_shuffle_xor(qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = max(qk_max, simd_shuffle_xor(qk_max, mask)); + } + // Broadcast the max qk value to all threads. + qk_max = simd_shuffle(qk_max, 0); + + // For non-partitioned (V1) mode, include the sink in the max. + // For V2 (partitioned), the sink is handled once in the reduce kernel. + if (!USE_PARTITIONING && use_sinks) { + qk_max = max(qk_max, sinks[head_idx]); + } + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + float val = exp(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum, + simd_tid, simd_lid); + + // For non-partitioned (V1) mode, include the sink in the exp sum. + if (!USE_PARTITIONING && use_sinks) { + exp_sum += exp(sinks[head_idx] - qk_max); + } + + // Compute softmax. + const float inv_sum = 1.f / (exp_sum + 1e-6f); + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // If partitioning is enabled, store the max logit and exp_sum. + if (USE_PARTITIONING && thread_idx == 0 && use_partitioning) { + device float *max_logits_ptr = + max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *max_logits_ptr = qk_max; + device float *exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *exp_sums_ptr = exp_sum; + } + + // Each thread will fetch 16 bytes from the value cache at a time. 
+ constexpr int V_VEC_SIZE = MIN(16 / sizeof(T), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using L_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + using V_quant_vec = typename Vec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = NUM_SIMD_LANES / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = + DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + + // NOTE: We use FP32 for the accumulator for better accuracy. + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + T zero_value = 0; + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE: The block number is stored in int32. However, we cast it to int64 + // because int32 can lead to overflow when this variable is multiplied by + // large numbers (e.g., kv_block_stride). + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + Float_L_vec logits_float_vec = *reinterpret_cast( + logits + token_idx - start_token_idx); + from_float(logits_vec, logits_float_vec); + + const device CACHE_T *v_ptr = v_cache + + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + // NOTE: When v_vec contains the tokens that are out of the context, + // we should explicitly zero out the values since they may contain NaNs. 
+ // See + // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + V_vec v_vec; + + if constexpr (is_uchar()) { + // FP8 support + V_quant_vec v_quant_vec = + *reinterpret_cast(v_ptr + offset); + v_vec = fp8_convert(v_quant_vec, *v_scale); + } else { + // Non-FP8 default + v_vec = *reinterpret_cast(v_ptr + offset); + } + + if (block_idx == num_context_blocks - 1) { + thread T *v_vec_ptr = reinterpret_cast(&v_vec); +#pragma unroll + for (int j = 0; j < V_VEC_SIZE; j++) { + v_vec_ptr[j] = + token_idx + j < context_len ? v_vec_ptr[j] : zero_value; + } + } + accs[i] += dot(logits_vec, v_vec); + } + } + } + + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += simd_shuffle_xor(acc, mask); + } + accs[i] = acc; + } + + // NOTE: A barrier is required because the shared memory space for logits + // is reused for the output. + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Perform reduction across warps. + threadgroup float *out_smem = + reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + threadgroup float *dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Lower warps update the output. 
+ if (warp_idx < mid) { + const threadgroup float *src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Write the final output. + if (warp_idx == 0) { + device T *out_ptr = + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + *(out_ptr + row_idx) = T(accs[i]); + } + } + } +} + +template +[[kernel]] void paged_attention_v2_reduce( + device T *out [[buffer(0)]], const device float *exp_sums [[buffer(1)]], + const device float *max_logits [[buffer(2)]], + const device T *tmp_out [[buffer(3)]], + device uint32_t *context_lens [[buffer(4)]], + const constant int &max_num_partitions [[buffer(5)]], + const device float *sinks + [[buffer(6), function_constant(use_sinks)]], // [num_heads] + threadgroup char *shared_mem [[threadgroup(0)]], + uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], + uint3 threadgroups_per_grid [[threadgroups_per_grid]], + uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], + uint3 threads_per_threadgroup [[threads_per_threadgroup]], + uint simd_tid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + const int num_heads = threadgroups_per_grid.x; + const int head_idx = threadgroup_position_in_grid.x; + const int seq_idx = threadgroup_position_in_grid.y; + const uint32_t context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + if (num_partitions == 1 && 
!use_sinks) { + // No need to reduce. Only copy tmp_out to out. + device T *out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const device T *tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE; + i += threads_per_threadgroup.x) { + out_ptr[i] = tmp_out_ptr[i]; + } + // Terminate the thread block. + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES; + const int warp_idx = simd_tid; + const int lane = simd_lid; + + // Workspace for reduction. + threadgroup float red_smem[2 * NUM_WARPS]; + + // Load max logits to shared memory. + threadgroup float *shared_max_logits = + reinterpret_cast(shared_mem); + const device float *max_logits_ptr = + max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float max_logit = -FLT_MAX; + for (int i = thread_position_in_threadgroup.x; i < num_partitions; + i += threads_per_threadgroup.x) { + const float l = max_logits_ptr[i]; + shared_max_logits[i] = l; + max_logit = max(max_logit, l); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Get the global max logit. + // Reduce within the warp. +#pragma unroll + for (int mask = NUM_SIMD_LANES / 2; mask >= 1; mask /= 2) { + max_logit = max(max_logit, simd_shuffle_xor(max_logit, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = max_logit; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + // Reduce across warps. + max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + max_logit = max(max_logit, simd_shuffle_xor(max_logit, mask)); + } + // Broadcast the max value to all threads. + max_logit = simd_shuffle(max_logit, 0); + + // Include the sink in the global max before rescaling. 
+ if (use_sinks) { + max_logit = max(max_logit, sinks[head_idx]); + } + + // Load rescaled exp sums to shared memory. + threadgroup float *shared_exp_sums = reinterpret_cast( + shared_mem + sizeof(float) * num_partitions); + const device float *exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float global_exp_sum = 0.0f; + for (int i = thread_position_in_threadgroup.x; i < num_partitions; + i += threads_per_threadgroup.x) { + float l = shared_max_logits[i]; + float rescaled_exp_sum = exp_sums_ptr[i] * exp(l - max_logit); + global_exp_sum += rescaled_exp_sum; + shared_exp_sums[i] = rescaled_exp_sum; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + global_exp_sum = block_sum( + &red_smem[NUM_WARPS], global_exp_sum, simd_tid, simd_lid); + + // Include the sink in the global exp sum. + if (use_sinks) { + global_exp_sum += exp(sinks[head_idx] - max_logit); + } + + const float inv_global_exp_sum = 1.0f / (global_exp_sum + 1e-6f); + + // Aggregate tmp_out to out. 
+ const device T *tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + device T *out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE; + i += NUM_THREADS) { + float acc = 0.0f; + for (int j = 0; j < num_partitions; ++j) { + acc += float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * + inv_global_exp_sum; + } + out_ptr[i] = T(acc); + } +} + +#define instantiate_paged_attention_inner(type, cache_type, head_size, \ + block_size, num_threads, \ + num_simd_lanes, partition_size) \ + template [[host_name("paged_attention_" #type "_cache_" #cache_type \ + "_hs" #head_size "_bs" #block_size "_nt" #num_threads \ + "_nsl" #num_simd_lanes \ + "_ps" #partition_size)]] [[kernel]] void \ + paged_attention( \ + device float *exp_sums \ + [[buffer(0), function_constant(use_partitioning)]], \ + device float *max_logits \ + [[buffer(1), function_constant(use_partitioning)]], \ + device type *out [[buffer(2)]], device const type *q [[buffer(3)]], \ + device const cache_type *k_cache [[buffer(4)]], \ + device const cache_type *v_cache [[buffer(5)]], \ + device const float *k_scale \ + [[buffer(6), function_constant(use_fp8_scales)]], \ + device const float *v_scale \ + [[buffer(7), function_constant(use_fp8_scales)]], \ + const constant int &num_kv_heads [[buffer(8)]], \ + const constant float &scale [[buffer(9)]], \ + const constant float &softcapping [[buffer(10)]], \ + device const uint32_t *block_tables [[buffer(11)]], \ + device const uint32_t *context_lens [[buffer(12)]], \ + const constant int &max_num_blocks_per_seq [[buffer(13)]], \ + device const float *alibi_slopes \ + [[buffer(14), function_constant(use_alibi)]], \ + const constant int &q_stride [[buffer(15)]], \ + const constant int &kv_block_stride [[buffer(16)]], \ + const constant int &kv_head_stride [[buffer(17)]], \ + const device float 
*sinks [[buffer(18), function_constant(use_sinks)]], \ + threadgroup char *shared_mem [[threadgroup(0)]], \ + uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \ + uint3 threadgroups_per_grid [[threadgroups_per_grid]], \ + uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], \ + uint simd_tid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +#define instantiate_paged_attention_v2_reduce_inner( \ + type, head_size, num_threads, num_simd_lanes, partition_size) \ + template [[host_name("paged_attention_v2_reduce_" #type "_hs" #head_size \ + "_nt" #num_threads "_nsl" #num_simd_lanes \ + "_ps" #partition_size)]] [[kernel]] void \ + paged_attention_v2_reduce( \ + device type * out [[buffer(0)]], \ + const device float *exp_sums [[buffer(1)]], \ + const device float *max_logits [[buffer(2)]], \ + const device type *tmp_out [[buffer(3)]], \ + device uint32_t *context_lens [[buffer(4)]], \ + const constant int &max_num_partitions [[buffer(5)]], \ + const device float *sinks [[buffer(6), function_constant(use_sinks)]], \ + threadgroup char *shared_mem [[threadgroup(0)]], \ + uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \ + uint3 threadgroups_per_grid [[threadgroups_per_grid]], \ + uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], \ + uint3 threads_per_threadgroup [[threads_per_threadgroup]], \ + uint simd_tid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +#define instantiate_paged_attention_heads( \ + type, cache_type, block_size, num_threads, num_simd_lanes, partition_size) \ + instantiate_paged_attention_inner(type, cache_type, 64, block_size, \ + num_threads, num_simd_lanes, \ + partition_size); \ + instantiate_paged_attention_inner(type, cache_type, 80, block_size, \ + num_threads, num_simd_lanes, \ + partition_size); \ + instantiate_paged_attention_inner(type, cache_type, 96, block_size, \ + num_threads, 
num_simd_lanes, \ + partition_size); \ + instantiate_paged_attention_inner(type, cache_type, 112, block_size, \ + num_threads, num_simd_lanes, \ + partition_size); \ + instantiate_paged_attention_inner(type, cache_type, 128, block_size, \ + num_threads, num_simd_lanes, \ + partition_size); \ + instantiate_paged_attention_inner(type, cache_type, 192, block_size, \ + num_threads, num_simd_lanes, \ + partition_size); \ + instantiate_paged_attention_inner(type, cache_type, 256, block_size, \ + num_threads, num_simd_lanes, \ + partition_size); + +#define instantiate_paged_attention_v2_reduce_heads( \ + type, num_threads, num_simd_lanes, partition_size) \ + instantiate_paged_attention_v2_reduce_inner(type, 64, num_threads, \ + num_simd_lanes, partition_size); \ + instantiate_paged_attention_v2_reduce_inner(type, 80, num_threads, \ + num_simd_lanes, partition_size); \ + instantiate_paged_attention_v2_reduce_inner(type, 96, num_threads, \ + num_simd_lanes, partition_size); \ + instantiate_paged_attention_v2_reduce_inner(type, 112, num_threads, \ + num_simd_lanes, partition_size); \ + instantiate_paged_attention_v2_reduce_inner(type, 128, num_threads, \ + num_simd_lanes, partition_size); \ + instantiate_paged_attention_v2_reduce_inner(type, 192, num_threads, \ + num_simd_lanes, partition_size); \ + instantiate_paged_attention_v2_reduce_inner(type, 256, num_threads, \ + num_simd_lanes, partition_size); + +#define instantiate_paged_attention_block_size(type, cache_type, num_threads, \ + num_simd_lanes, partition_size) \ + instantiate_paged_attention_heads(type, cache_type, 8, num_threads, \ + num_simd_lanes, partition_size); \ + instantiate_paged_attention_heads(type, cache_type, 16, num_threads, \ + num_simd_lanes, partition_size); \ + instantiate_paged_attention_heads(type, cache_type, 32, num_threads, \ + num_simd_lanes, partition_size); + +// TODO: tune num_threads = 256 +// NOTE: partition_size = 0 +#define instantiate_paged_attention_v1(type, cache_type, num_simd_lanes) 
\ + instantiate_paged_attention_block_size(type, cache_type, 256, \ + num_simd_lanes, 0); + +// TODO: tune num_threads = 256 +// NOTE: partition_size = 512 +#define instantiate_paged_attention_v2(type, cache_type, num_simd_lanes) \ + instantiate_paged_attention_block_size(type, cache_type, 256, \ + num_simd_lanes, 512); + +// TODO: tune num_threads = 256 +// NOTE: partition_size = 512 +#define instantiate_paged_attention_v2_reduce(type, num_simd_lanes) \ + instantiate_paged_attention_v2_reduce_heads(type, 256, num_simd_lanes, 512); + +instantiate_paged_attention_v1(float, float, 32); +instantiate_paged_attention_v1(bfloat16_t, bfloat16_t, 32); +instantiate_paged_attention_v1(half, half, 32); + +instantiate_paged_attention_v1(float, uchar, 32); +instantiate_paged_attention_v1(bfloat16_t, uchar, 32); +instantiate_paged_attention_v1(half, uchar, 32); + +instantiate_paged_attention_v2_reduce(float, 32); +instantiate_paged_attention_v2_reduce(bfloat16_t, 32); +instantiate_paged_attention_v2_reduce(half, 32); + +instantiate_paged_attention_v2(float, float, 32); +instantiate_paged_attention_v2(bfloat16_t, bfloat16_t, 32); +instantiate_paged_attention_v2(half, half, 32); + +instantiate_paged_attention_v2(float, uchar, 32); +instantiate_paged_attention_v2(bfloat16_t, uchar, 32); +instantiate_paged_attention_v2(half, uchar, 32); diff --git a/vllm_metal/metal/kernels_v1/reshape_and_cache.metal b/vllm_metal/metal/kernels_v1/reshape_and_cache.metal new file mode 100644 index 0000000..982a4c8 --- /dev/null +++ b/vllm_metal/metal/kernels_v1/reshape_and_cache.metal @@ -0,0 +1,127 @@ +#include "utils.metal" +#include + +using namespace metal; + +template +inline CACHE_T to_cache(KV_T v) = delete; + +template <> inline uchar to_cache(float v) { + return float_to_fp8_e4m3(v); +} + +template <> inline uchar to_cache(bfloat16_t v) { + return float_to_fp8_e4m3((float)v); +} + +template <> inline uchar to_cache(half v) { + return float_to_fp8_e4m3((float)v); +} + +template <> inline 
float to_cache(float v) { return v; } + +template <> inline bfloat16_t to_cache(bfloat16_t v) { + return v; +} + +template <> inline half to_cache(half v) { return v; } + +constant bool use_fp8_scales [[function_constant(10)]]; + +template +[[kernel]] void reshape_and_cache( + const device KV_T *__restrict__ key + [[buffer(0)]], // [num_tokens, num_heads, head_size] + const device KV_T *__restrict__ value + [[buffer(1)]], // [num_tokens, num_heads, head_size] + device CACHE_T *__restrict__ key_cache + [[buffer(2)]], // [num_blocks, num_heads, head_size/x, block_size, x] + device CACHE_T *__restrict__ value_cache + [[buffer(3)]], // [num_blocks, num_heads, head_size, block_size] + const device int64_t *__restrict__ slot_mapping + [[buffer(4)]], // [num_tokens] + const device float *__restrict__ k_scale + [[buffer(5), function_constant(use_fp8_scales)]], // [1] + const device float *__restrict__ v_scale + [[buffer(6), function_constant(use_fp8_scales)]], // [1] + device const int &key_stride [[buffer(7)]], + device const int &value_stride [[buffer(8)]], + device const int &num_heads [[buffer(9)]], + device const int &head_size [[buffer(10)]], + device const int &block_size [[buffer(11)]], + device const int &x [[buffer(12)]], + uint gid [[threadgroup_position_in_grid]], + uint tid [[thread_position_in_threadgroup]], + uint threads_per_threadgroup [[threads_per_threadgroup]]) { + const int64_t token_idx = gid; + const int64_t slot_idx = slot_mapping[token_idx]; + if (slot_idx < 0) { + // Padding token that should be ignored. 
+ return; + } + + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + + const int n = num_heads * head_size; + for (int i = tid; i < n; i += threads_per_threadgroup) { + const int64_t src_key_idx = token_idx * key_stride + i; + const int64_t src_value_idx = token_idx * value_stride + i; + + const int head_idx = i / head_size; + const int head_offset = i % head_size; + const int x_idx = head_offset / x; + const int x_offset = head_offset % x; + + const int64_t tgt_key_idx = + block_idx * num_heads * (head_size / x) * block_size * x + + head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + + block_offset * x + x_offset; + const int64_t tgt_value_idx = + block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + head_offset * block_size + + block_offset; + + if (use_fp8_scales) { + key_cache[tgt_key_idx] = + to_cache(KV_T((float)key[src_key_idx] / *k_scale)); + value_cache[tgt_value_idx] = + to_cache(KV_T((float)value[src_value_idx] / *v_scale)); + } else { + key_cache[tgt_key_idx] = to_cache(key[src_key_idx]); + value_cache[tgt_value_idx] = + to_cache(value[src_value_idx]); + } + } +} + +#define instantiate_reshape_and_cache(kv_type, cache_type) \ + template [[host_name("reshape_and_cache_kv_" #kv_type \ + "_cache_" #cache_type)]] [[kernel]] void \ + reshape_and_cache( \ + const device kv_type *__restrict__ key [[buffer(0)]], \ + const device kv_type *__restrict__ value [[buffer(1)]], \ + device cache_type *__restrict__ key_cache [[buffer(2)]], \ + device cache_type *__restrict__ value_cache [[buffer(3)]], \ + const device int64_t *__restrict__ slot_mapping [[buffer(4)]], \ + const device float *__restrict__ k_scale \ + [[buffer(5), function_constant(use_fp8_scales)]], \ + const device float *__restrict__ v_scale \ + [[buffer(6), function_constant(use_fp8_scales)]], \ + device const int &key_stride [[buffer(7)]], \ + device const int &value_stride [[buffer(8)]], \ + 
device const int &num_heads [[buffer(9)]], \ + device const int &head_size [[buffer(10)]], \ + device const int &block_size [[buffer(11)]], \ + device const int &x [[buffer(12)]], \ + uint gid [[threadgroup_position_in_grid]], \ + uint tid [[thread_position_in_threadgroup]], \ + uint threads_per_threadgroup [[threads_per_threadgroup]]); + +instantiate_reshape_and_cache(float, float); +instantiate_reshape_and_cache(bfloat16_t, bfloat16_t); +instantiate_reshape_and_cache(half, half); + +instantiate_reshape_and_cache(float, uchar); +instantiate_reshape_and_cache(bfloat16_t, uchar); +instantiate_reshape_and_cache(half, uchar); diff --git a/vllm_metal/metal/kernels_v1/utils.metal b/vllm_metal/metal/kernels_v1/utils.metal new file mode 100644 index 0000000..f9b5cdb --- /dev/null +++ b/vllm_metal/metal/kernels_v1/utils.metal @@ -0,0 +1,253 @@ +// Portions of this file are adapted from Apple's MLX framework +// (https://github.com/ml-explore/mlx) +// Licensed under the Apache License 2.0 +// Copyright © 2023 Apple Inc. 
+ +#include "float8.metal" +#include + +using namespace metal; + +#if defined(__HAVE_BFLOAT__) + +typedef bfloat bfloat16_t; + +#else + +///////////////////////////////////////////////////////////////////////////// +// Helpers +///////////////////////////////////////////////////////////////////////////// + +constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) { + // Check for nan + if ((as_type(x) & ~_fp_encoding_traits::sign_mask) > + _fp_encoding_traits::inf_mask) { + return uint16_t(as_type(0x7FC0)); + } + // Take bits + uint32_t float_bits = as_type(x); + + // Round to nearest even + float_bits += ((float_bits >> 16) & 1) + as_type(0x7FFF); + + // Take upper 16 bits + return float_bits >> 16; +} + +constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) { + // Upper 16 bits are the data and lower 16 bits are 0s + return as_type((uint32_t)x << 16); +} + +struct _MLX_BFloat16; + +template +static constexpr constant bool can_convert_to_bfloat = + !is_same_v && is_convertible_v; + +template +static constexpr constant bool can_convert_from_bfloat = + !is_same_v && is_convertible_v; + +///////////////////////////////////////////////////////////////////////////// +// Bfloat struct +///////////////////////////////////////////////////////////////////////////// + +struct _MLX_BFloat16 { + ///////////////////////////////////////////////////////////////////////////// + // Constructors + uint16_t bits_; + _MLX_BFloat16() thread = default; + _MLX_BFloat16() threadgroup = default; + _MLX_BFloat16() device = default; + _MLX_BFloat16() constant = default; + + struct bits_to_bfloat_struct {}; + static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() { + return bits_to_bfloat_struct(); + } + constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct) + : bits_(bits) {} + + ///////////////////////////////////////////////////////////////////////////// + // Conversions to bfloat + + template >::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) thread 
+ : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template >::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template >::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) device + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template >::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) constant + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + ///////////////////////////////////////////////////////////////////////////// + // Conversions from bfloat + + template >::type> + constexpr METAL_FUNC operator T() const thread { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template >::type> + constexpr METAL_FUNC operator T() const threadgroup { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template >::type> + constexpr METAL_FUNC operator T() const device { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template >::type> + constexpr METAL_FUNC operator T() constant { + return static_cast(bfloat_bits_to_float(bits_)); + } +}; + +///////////////////////////////////////////////////////////////////////////// +// Bfloat operators +///////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////// +// Unary ops +constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) { + return -static_cast(x); +} + +///////////////////////////////////////////////////////////////////////////// +// Binary operators +#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ + constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + constexpr METAL_FUNC otype __operator__(itype lhs, 
_MLX_BFloat16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +///////////////////////////////////////////////////////////////////////////// +// Arithmetic Operators +#define bfloat_binop(_op_, _operator_) \ + bfloat_binop_base(_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, \ + _MLX_BFloat16, float); \ + bfloat_binop_helper(_op_, _operator_, float, float, float); \ + bfloat_binop_helper(_op_, _operator_, float, half, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); + +bfloat_binop(+, operator+); +bfloat_binop(-, operator-); +bfloat_binop(*, operator*); +bfloat_binop(/, operator/); + +///////////////////////////////////////////////////////////////////////////// +// Comparison ops +#define bfloat_compop(__op__, __operator__) \ + bfloat_binop_base(__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, \ + float); \ + bfloat_binop_helper(__op__, __operator__, bool, float, float); \ + bfloat_binop_helper(__op__, __operator__, bool, half, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); + +bfloat_compop(>, operator>); +bfloat_compop(<, operator<); +bfloat_compop(>=, operator>=); +bfloat_compop(<=, operator<=); +bfloat_compop(==, operator==); +bfloat_compop(!=, operator!=); + +#undef bfloat_compop +#undef bfloat_binop_base +#undef bfloat_binop_helper +#undef bfloat_binop + +///////////////////////////////////////////////////////////////////////////// +// Inplace Operators +#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \ + constexpr 
METAL_FUNC addr_space _MLX_BFloat16 &__operator__( \ + addr_space _MLX_BFloat16 &lhs, itype rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } \ + constexpr METAL_FUNC addr_space itype &__operator__(addr_space itype &lhs, \ + _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } + +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \ + bfloat_inplace_op_helper(__op__, __operator__, itype, device); \ + bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \ + bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup); + +#define bfloat_inplace_op(itype) \ + bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \ + bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \ + bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \ + bfloat_inplace_op_addr_space_helper(/, operator/=, itype); + +bfloat_inplace_op(float); +bfloat_inplace_op(half); +bfloat_inplace_op(int16_t); +bfloat_inplace_op(int32_t); +bfloat_inplace_op(int64_t); +bfloat_inplace_op(uint16_t); +bfloat_inplace_op(uint32_t); +bfloat_inplace_op(uint64_t); + +#undef bfloat_inplace_op_helper +#undef bfloat_inplace_op_addr_space_helper +#undef bfloat_inplace_op + +#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16 &__operator__( \ + addr_space _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } + +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \ + bfloat_inplace_op_helper(__op__, __operator__, device); \ + bfloat_inplace_op_helper(__op__, __operator__, thread); \ + bfloat_inplace_op_helper(__op__, __operator__, threadgroup); + +bfloat_inplace_op_addr_space_helper(+, operator+=); +bfloat_inplace_op_addr_space_helper(-, operator-=); +bfloat_inplace_op_addr_space_helper(*, operator*=); +bfloat_inplace_op_addr_space_helper(/, operator/=); + 
+#undef bfloat_inplace_op_helper +#undef bfloat_inplace_op_addr_space_helper + +///////////////////////////////////////////////////////////////////////////// +// Bfloat typedef +///////////////////////////////////////////////////////////////////////////// + +typedef struct _MLX_BFloat16 bfloat16_t; + +#endif diff --git a/vllm_metal/metal/paged_ops.cpp b/vllm_metal/metal/paged_ops.cpp new file mode 100644 index 0000000..1928041 --- /dev/null +++ b/vllm_metal/metal/paged_ops.cpp @@ -0,0 +1,269 @@ +// SPDX-License-Identifier: Apache-2.0 +// C++ nanobind bridge for paged attention Metal kernels. +// +// Dispatches reshape_and_cache and paged_attention_v1 through MLX's own +// Metal command encoder, eliminating the PyTorch MPS bridge. +// +// Uses nb::handle + nb::inst_ptr() to extract the C++ array from +// the Python mlx.core.array object, bypassing nanobind's cross-module +// RTTI matching which fails due to hidden symbol visibility in libmlx. + +#include +#include + +#include +#include + +#include "mlx/mlx.h" +#include "mlx/backend/metal/device.h" + +namespace nb = nanobind; +using namespace mlx::core; + +// --------------------------------------------------------------------------- +// Library caching +// --------------------------------------------------------------------------- + +static std::string reshape_cache_source_; +static std::string paged_attention_source_; + +void init_libraries( + const std::string& reshape_src, + const std::string& paged_attn_src) { + reshape_cache_source_ = reshape_src; + paged_attention_source_ = paged_attn_src; + + auto& d = metal::device(Device::gpu); + d.get_library( + "paged_reshape_cache", + [&]() { return reshape_cache_source_; }); + d.get_library( + "paged_attention_kern", + [&]() { return paged_attention_source_; }); +} + +// --------------------------------------------------------------------------- +// Helper: dtype → Metal type string +// --------------------------------------------------------------------------- + 
+static std::string dtype_to_metal(Dtype dt) { + switch (dt) { + case float16: return "half"; + case bfloat16: return "bfloat16_t"; + case float32: return "float"; + default: + throw std::runtime_error( + "Unsupported dtype for paged attention kernel"); + } +} + +// --------------------------------------------------------------------------- +// reshape_and_cache +// --------------------------------------------------------------------------- + +void reshape_and_cache_impl( + nb::handle key_h, + nb::handle value_h, + nb::handle key_cache_h, + nb::handle value_cache_h, + nb::handle slot_mapping_h +) { + // Extract C++ arrays from Python handles + auto& key = *nb::inst_ptr(key_h); + auto& value = *nb::inst_ptr(value_h); + auto& key_cache = *nb::inst_ptr(key_cache_h); + auto& value_cache = *nb::inst_ptr(value_cache_h); + auto& slot_mapping = *nb::inst_ptr(slot_mapping_h); + + auto s = default_stream(Device::gpu); + auto& d = metal::device(Device::gpu); + + int num_tokens = static_cast(key.shape(0)); + int num_heads = static_cast(key.shape(1)); + int head_size = static_cast(key.shape(2)); + int block_size = static_cast(key_cache.shape(3)); + int x_val = static_cast(key_cache.shape(4)); + + // Contiguous strides (arrays must be row-major after mx.eval) + int32_t key_stride = static_cast(num_heads * head_size); + int32_t value_stride = static_cast(num_heads * head_size); + int32_t num_heads_i = static_cast(num_heads); + int32_t head_size_i = static_cast(head_size); + int32_t block_size_i = static_cast(block_size); + int32_t x_i = static_cast(x_val); + + // Kernel name: same kv and cache dtype (no FP8) + auto dt = dtype_to_metal(key.dtype()); + std::string kname = + "reshape_and_cache_kv_" + dt + "_cache_" + dt; + + // Get library & specialise kernel with function constants + auto* lib = d.get_library("paged_reshape_cache"); + bool use_fp8 = false; + auto* kernel = d.get_kernel( + kname, lib, kname, + {{&use_fp8, MTL::DataType::DataTypeBool, NS::UInteger(10)}}); + + // 
Dispatch on the current MLX command encoder + auto& enc = d.get_command_encoder(s.index); + enc.set_compute_pipeline_state(kernel); + + // Buffer bindings (match reshape_and_cache.metal signature) + enc.set_input_array(key, 0); + enc.set_input_array(value, 1); + enc.set_output_array(key_cache, 2); + enc.set_output_array(value_cache, 3); + enc.set_input_array(slot_mapping, 4); + // 5, 6: k_scale / v_scale — unused (use_fp8_scales=false) + enc.set_bytes(key_stride, 7); + enc.set_bytes(value_stride, 8); + enc.set_bytes(num_heads_i, 9); + enc.set_bytes(head_size_i, 10); + enc.set_bytes(block_size_i, 11); + enc.set_bytes(x_i, 12); + + int tpg = std::min(512, num_heads * head_size); + enc.dispatch_threadgroups( + MTL::Size::Make(num_tokens, 1, 1), + MTL::Size::Make(tpg, 1, 1)); + + // Keep ALL referenced arrays alive until the command buffer completes + d.add_temporary(key, s.index); + d.add_temporary(value, s.index); + d.add_temporary(key_cache, s.index); + d.add_temporary(value_cache, s.index); + d.add_temporary(slot_mapping, s.index); +} + +// --------------------------------------------------------------------------- +// paged_attention_v1 +// --------------------------------------------------------------------------- + +void paged_attention_v1_impl( + nb::handle out_h, + nb::handle query_h, + nb::handle key_cache_h, + nb::handle value_cache_h, + int num_kv_heads, + float scale, + nb::handle block_tables_h, + nb::handle seq_lens_h, + int block_size, + int max_seq_len +) { + auto& out = *nb::inst_ptr(out_h); + auto& query = *nb::inst_ptr(query_h); + auto& key_cache = *nb::inst_ptr(key_cache_h); + auto& value_cache = *nb::inst_ptr(value_cache_h); + auto& block_tables = *nb::inst_ptr(block_tables_h); + auto& seq_lens = *nb::inst_ptr(seq_lens_h); + + auto s = default_stream(Device::gpu); + auto& d = metal::device(Device::gpu); + + int num_seqs = static_cast(query.shape(0)); + int num_heads = static_cast(query.shape(1)); + int head_size = static_cast(query.shape(2)); + 
int max_blocks = static_cast(block_tables.shape(1)); + + // Kernel name + auto dt = dtype_to_metal(query.dtype()); + std::string kname = + "paged_attention_" + dt + "_cache_" + dt + + "_hs" + std::to_string(head_size) + + "_bs" + std::to_string(block_size) + + "_nt256_nsl32_ps0"; + + // Function constants + bool use_partitioning = false; + bool use_alibi = false; + bool use_fp8 = false; + + auto* lib = d.get_library("paged_attention_kern"); + auto* kernel = d.get_kernel( + kname, lib, kname, + {{&use_partitioning, MTL::DataType::DataTypeBool, NS::UInteger(10)}, + {&use_alibi, MTL::DataType::DataTypeBool, NS::UInteger(20)}, + {&use_fp8, MTL::DataType::DataTypeBool, NS::UInteger(30)}}); + + // Threadgroup shared memory + constexpr int NUM_THREADS = 256; + constexpr int NUM_SIMD_LANES = 32; + int padded_ctx = ((max_seq_len + block_size - 1) / block_size) * block_size; + int logits_bytes = padded_ctx * static_cast(sizeof(float)); + int outputs_bytes = (NUM_THREADS / NUM_SIMD_LANES / 2) + * head_size * static_cast(sizeof(float)); + size_t shmem = static_cast(std::max(logits_bytes, outputs_bytes)); + + // Dispatch + auto& enc = d.get_command_encoder(s.index); + enc.set_compute_pipeline_state(kernel); + enc.set_threadgroup_memory_length(shmem, 0); + + // Buffer bindings (match paged_attention.metal signature) + // 0: exp_sums — skipped (v1, no partitioning) + // 1: max_logits — skipped + enc.set_output_array(out, 2); + enc.set_input_array(query, 3); + enc.set_input_array(key_cache, 4); + enc.set_input_array(value_cache, 5); + // 6: k_scale — skipped (no FP8) + // 7: v_scale — skipped + + int32_t nkv = static_cast(num_kv_heads); + enc.set_bytes(nkv, 8); + enc.set_bytes(scale, 9); + float softcapping = 1.0f; + enc.set_bytes(softcapping, 10); + + enc.set_input_array(block_tables, 11); + enc.set_input_array(seq_lens, 12); + + int32_t max_blocks_i = static_cast(max_blocks); + enc.set_bytes(max_blocks_i, 13); + // 14: alibi_slopes — skipped + + // Strides (contiguous row-major) 
+ int32_t q_stride = static_cast(num_heads * head_size); + int32_t kv_block_stride = static_cast(key_cache.strides()[0]); + int32_t kv_head_stride = static_cast(key_cache.strides()[1]); + enc.set_bytes(q_stride, 15); + enc.set_bytes(kv_block_stride, 16); + enc.set_bytes(kv_head_stride, 17); + + enc.dispatch_threadgroups( + MTL::Size::Make(num_heads, num_seqs, 1), + MTL::Size::Make(NUM_THREADS, 1, 1)); + + // Keep ALL referenced arrays alive until the command buffer completes + d.add_temporary(out, s.index); + d.add_temporary(query, s.index); + d.add_temporary(key_cache, s.index); + d.add_temporary(value_cache, s.index); + d.add_temporary(block_tables, s.index); + d.add_temporary(seq_lens, s.index); +} + +// --------------------------------------------------------------------------- +// nanobind module +// --------------------------------------------------------------------------- + +NB_MODULE(_paged_ops, m) { + m.def("init_libraries", &init_libraries, + nb::arg("reshape_src"), nb::arg("paged_attn_src"), + "JIT-compile the vendored Metal shaders."); + + m.def("reshape_and_cache", &reshape_and_cache_impl, + nb::arg("key"), nb::arg("value"), + nb::arg("key_cache"), nb::arg("value_cache"), + nb::arg("slot_mapping"), + "Write projected K/V into the paged cache."); + + m.def("paged_attention_v1", &paged_attention_v1_impl, + nb::arg("out"), nb::arg("query"), + nb::arg("key_cache"), nb::arg("value_cache"), + nb::arg("num_kv_heads"), nb::arg("scale"), + nb::arg("block_tables"), nb::arg("seq_lens"), + nb::arg("block_size"), nb::arg("max_seq_len"), + "Zero-copy paged attention (v1, no partitioning)."); +} diff --git a/vllm_metal/metal_kernel_backend/cache.py b/vllm_metal/metal_kernel_backend/cache.py index 43ce6d3..85dc7d1 100644 --- a/vllm_metal/metal_kernel_backend/cache.py +++ b/vllm_metal/metal_kernel_backend/cache.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 -"""MPS-backed paged KV cache for the HF Metal kernel. 
+"""MLX-backed paged KV cache for native Metal paged attention. -Stores per-layer key/value caches as PyTorch MPS tensors in the layout -expected by ``reshape_and_cache`` and ``paged_attention_v1``: +Stores per-layer key/value caches as MLX arrays in the layout expected by +``reshape_and_cache`` and ``paged_attention_v1``: - key_cache: [num_blocks, num_kv_heads, head_dim // x, block_size, x] where x = 16 // element_size (8 for float16) @@ -13,11 +13,18 @@ from __future__ import annotations -import torch +import mlx.core as mx +# mx.Dtype → element size in bytes +_DTYPE_SIZE = { + mx.float16: 2, + mx.bfloat16: 2, + mx.float32: 4, +} -class MPSPagedKVCache: - """Per-layer MPS tensors for the HF paged-attention kernel.""" + +class MetalPagedKVCache: + """Per-layer MLX arrays for native Metal paged attention.""" def __init__( self, @@ -26,7 +33,7 @@ def __init__( head_dim: int, num_blocks: int, block_size: int, - dtype: torch.dtype = torch.float16, + dtype: mx.Dtype = mx.float16, ) -> None: self.num_layers = num_layers self.num_kv_heads = num_kv_heads @@ -35,14 +42,9 @@ def __init__( self.block_size = block_size self.dtype = dtype - # The key cache uses a 5D layout for vectorized memory access: - # [num_blocks, num_kv_heads, head_dim // x, block_size, x] - # where x = 16 // element_size ensures each innermost vector is - # exactly 16 bytes, matching the Metal kernel's load granularity. 
- # This layout is required by the HF paged-attention Metal kernel - # (ported from vLLM CUDA / mistral.rs): - # https://github.com/huggingface/kernels-community/blob/main/paged-attention/paged-attention-metal/paged_attention.mm - element_size = torch.tensor([], dtype=dtype).element_size() + element_size = _DTYPE_SIZE.get(dtype) + if element_size is None: + raise ValueError(f"Unsupported dtype for paged KV cache: {dtype}") self.x = 16 // element_size # 8 for float16, 4 for float32 if head_dim % self.x != 0: @@ -53,31 +55,21 @@ def __init__( ) # Per-layer caches - self.key_caches: list[torch.Tensor] = [] - self.value_caches: list[torch.Tensor] = [] + self.key_caches: list[mx.array] = [] + self.value_caches: list[mx.array] = [] for _ in range(num_layers): self.key_caches.append( - torch.zeros( - num_blocks, - num_kv_heads, - head_dim // self.x, - block_size, - self.x, + mx.zeros( + (num_blocks, num_kv_heads, head_dim // self.x, block_size, self.x), dtype=dtype, - device="mps", ) ) self.value_caches.append( - torch.zeros( - num_blocks, - num_kv_heads, - head_dim, - block_size, + mx.zeros( + (num_blocks, num_kv_heads, head_dim, block_size), dtype=dtype, - device="mps", ) ) - # Scale tensors (identity scaling) - self.k_scale_tensor = torch.tensor(1.0, dtype=torch.float32, device="mps") - self.v_scale_tensor = torch.tensor(1.0, dtype=torch.float32, device="mps") + # Force allocation so Metal buffers exist before kernel dispatch + mx.eval(*self.key_caches, *self.value_caches) diff --git a/vllm_metal/metal_kernel_backend/kernel_loader.py b/vllm_metal/metal_kernel_backend/kernel_loader.py deleted file mode 100644 index 2870b84..0000000 --- a/vllm_metal/metal_kernel_backend/kernel_loader.py +++ /dev/null @@ -1,65 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Load the HuggingFace paged-attention Metal kernel. 
- -Uses ``kernels.get_kernel()`` to fetch the community paged-attention shader -which provides ``reshape_and_cache`` and ``paged_attention_v1`` ops that run -natively on Apple Metal (MPS). - -The latest HF builds (Jan 2026+) are compiled with Metal language version 4.0 -which requires macOS 16 (Tahoe). On macOS 15 (Sequoia) and earlier, we pin -to the Nov 2025 build (commit ``8968951``) which targets Metal 3.2. -""" - -from __future__ import annotations - -import logging -import platform -from typing import Any - -logger = logging.getLogger(__name__) - -_kernel: Any = None - -# Latest HF build requires Metal 4.0 (macOS 16+). This older revision -# was built with Metal 3.2 and works on macOS 15 and earlier. -_MACOS15_COMPAT_REVISION = "8968951" - - -def _needs_compat_revision() -> bool: - """Return True when the current macOS only supports Metal < 4.0.""" - ver = platform.mac_ver()[0] - if not ver: - return False - major = int(ver.split(".")[0]) - return major <= 15 - - -def get_paged_attention_ops() -> Any: - """Return the loaded paged-attention kernel module. - - The module exposes at minimum: - - ``reshape_and_cache(...)`` - - ``paged_attention_v1(...)`` - - The kernel is loaded once and cached for subsequent calls. - """ - global _kernel - if _kernel is None: - try: - from kernels import get_kernel - except ImportError: - raise ImportError( - "Paged attention requires the 'kernels' package. 
" - "Install it with: pip install 'vllm-metal[paged]'" - ) from None - - revision = _MACOS15_COMPAT_REVISION if _needs_compat_revision() else None - _kernel = get_kernel("kernels-community/paged-attention", revision=revision) - if revision: - logger.info( - "Loaded HF paged-attention Metal kernel (compat revision %s)", - revision, - ) - else: - logger.info("Loaded HF paged-attention Metal kernel") - return _kernel diff --git a/vllm_metal/metal_kernel_backend/paged_attention.py b/vllm_metal/metal_kernel_backend/paged_attention.py index 730881a..4806bb6 100644 --- a/vllm_metal/metal_kernel_backend/paged_attention.py +++ b/vllm_metal/metal_kernel_backend/paged_attention.py @@ -1,13 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 -"""Paged attention using the HF ``kernels-community/paged-attention`` Metal -kernel for zero-copy decode. +"""Paged attention using vendored Metal kernels dispatched through MLX. -Prefill: MLX inline SDPA (causal), then bridge K/V to MPS and call -``reshape_and_cache`` to write into the paged cache. +Prefill: MLX inline SDPA (causal), then ``reshape_and_cache`` to write +projected K/V into the paged cache. -Decode: MLX projections + per-request RoPE, bridge Q/K/V to MPS, call -``reshape_and_cache`` then ``paged_attention_v1`` (zero-copy read from -block tables), bridge output back to MLX. +Decode: MLX projections + per-request RoPE, ``reshape_and_cache`` to write +the new token, then ``paged_attention_v1`` for zero-copy attention over +all cached K/V blocks. + +All operations use MLX arrays end-to-end — no PyTorch MPS bridge. Reuses ``PagedAttentionContext``, ``OffsetCache``, ``prepare_prefill``, ``prepare_decode``, ``clear_context`` from ``paged_attention_common``. @@ -15,8 +16,7 @@ Backend replacement guide ------------------------- This module exists because there is no flash attention library for Apple -Silicon. The HF Metal kernel is no longer actively maintained, so it may -be replaced in the future. 
To swap in a new attention backend: +Silicon. To swap in a new attention backend: 1. **Cache**: Create a new cache class that allocates per-layer KV storage addressable by block index. Block allocation is managed externally @@ -65,16 +65,14 @@ import mlx.core as mx import mlx.nn as nn -import torch -from vllm_metal.metal_kernel_backend.cache import MPSPagedKVCache -from vllm_metal.metal_kernel_backend.kernel_loader import get_paged_attention_ops +from vllm_metal.metal import get_ops +from vllm_metal.metal_kernel_backend.cache import MetalPagedKVCache from vllm_metal.paged_attention_common import ( PagedAttentionContext, find_layers_and_attr, get_context, ) -from vllm_metal.pytorch_backend.tensor_bridge import mlx_to_torch, torch_to_mlx # --------------------------------------------------------------------------- # Prefill attention (MLX SDPA + reshape_and_cache write) @@ -86,14 +84,14 @@ def _metal_kernel_prefill_attention( queries: mx.array, keys: mx.array, values: mx.array, - cache: MPSPagedKVCache, + cache: MetalPagedKVCache, layer_idx: int, ctx: PagedAttentionContext, offset_cache: Any, ) -> mx.array: """Prefill: B=1, L=prompt_len. - Inline causal SDPA in MLX, then write K/V to MPS paged cache via + Inline causal SDPA in MLX, then write K/V to paged cache via ``reshape_and_cache``. 
""" B, _, L, _ = queries.shape # noqa: N806 @@ -114,26 +112,25 @@ def _metal_kernel_prefill_attention( queries, keys, values, scale=attn_module.scale, mask=attn_mask ) - # Write K/V into paged MPS cache via reshape_and_cache + # Write K/V into paged cache via reshape_and_cache # keys/values: (1, kv_heads, L, head_dim) → (L, kv_heads, head_dim) k_flat = keys[0].transpose(1, 0, 2) # (L, kv_heads, head_dim) v_flat = values[0].transpose(1, 0, 2) - k_mps = mlx_to_torch(k_flat, device="mps").to(dtype=cache.dtype) - v_mps = mlx_to_torch(v_flat, device="mps").to(dtype=cache.dtype) + # Ensure contiguous + correct dtype + k_flat = mx.contiguous(k_flat.astype(cache.dtype)) + v_flat = mx.contiguous(v_flat.astype(cache.dtype)) - slot_mapping_mps = torch.tensor(ctx.slot_mapping, dtype=torch.long, device="mps") + slot_mapping = mx.array(ctx.slot_mapping, dtype=mx.int64) + mx.eval(k_flat, v_flat, slot_mapping) - ops = get_paged_attention_ops() + ops = get_ops() ops.reshape_and_cache( - k_mps, - v_mps, + k_flat, + v_flat, cache.key_caches[layer_idx], cache.value_caches[layer_idx], - slot_mapping_mps, - "auto", - cache.k_scale_tensor, - cache.v_scale_tensor, + slot_mapping, ) # output: (B, heads, L, head_dim) → (B, L, heads, head_dim) → (B, L, D) @@ -151,7 +148,7 @@ def _metal_kernel_decode_attention( queries: mx.array, keys: mx.array, values: mx.array, - cache: MPSPagedKVCache, + cache: MetalPagedKVCache, layer_idx: int, ctx: PagedAttentionContext, ) -> mx.array: @@ -178,38 +175,38 @@ def _metal_kernel_decode_attention( queries = mx.concatenate(q_parts, axis=0) # (B, heads, 1, head_dim) keys_new = mx.concatenate(k_parts, axis=0) # (B, kv_heads, 1, head_dim) - # Bridge Q, new K/V to MPS - # (B, heads, 1, hd) → squeeze seq dim → (B, heads, hd) - q_mps = mlx_to_torch(queries[:, :, 0, :], device="mps").to(dtype=cache.dtype) - k_mps = mlx_to_torch(keys_new[:, :, 0, :], device="mps").to(dtype=cache.dtype) - v_mps = mlx_to_torch(values[:, :, 0, :], device="mps").to(dtype=cache.dtype) + # 
Squeeze seq dim: (B, heads, 1, hd) → (B, heads, hd) + q_3d = mx.contiguous(queries[:, :, 0, :].astype(cache.dtype)) + k_3d = mx.contiguous(keys_new[:, :, 0, :].astype(cache.dtype)) + v_3d = mx.contiguous(values[:, :, 0, :].astype(cache.dtype)) + + slot_mapping = mx.array(ctx.slot_mapping, dtype=mx.int64) + + # Build block_tables and seq_lens + max_blocks_per_seq = max(len(bt) for bt in ctx.block_tables) + block_tables_list = [ + bt + [0] * (max_blocks_per_seq - len(bt)) for bt in ctx.block_tables + ] + block_tables = mx.array(block_tables_list, dtype=mx.int32) + seq_lens = mx.array(ctx.context_lens, dtype=mx.int32) - slot_mapping_mps = torch.tensor(ctx.slot_mapping, dtype=torch.long, device="mps") + # Eval all inputs before kernel dispatch + mx.eval(q_3d, k_3d, v_3d, slot_mapping, block_tables, seq_lens) - ops = get_paged_attention_ops() + ops = get_ops() # Write new K/V tokens into paged cache ops.reshape_and_cache( - k_mps, - v_mps, + k_3d, + v_3d, cache.key_caches[layer_idx], cache.value_caches[layer_idx], - slot_mapping_mps, - "auto", - cache.k_scale_tensor, - cache.v_scale_tensor, + slot_mapping, ) - # Build block_tables and seq_lens tensors - max_blocks_per_seq = max(len(bt) for bt in ctx.block_tables) - block_tables_list = [ - bt + [0] * (max_blocks_per_seq - len(bt)) for bt in ctx.block_tables - ] - block_tables_mps = torch.tensor(block_tables_list, dtype=torch.int32, device="mps") - seq_lens_mps = torch.tensor(ctx.context_lens, dtype=torch.int32, device="mps") - - # Allocate output tensor - out = torch.zeros(B, n_heads, head_dim, dtype=cache.dtype, device="mps") + # Allocate output + out = mx.zeros((B, n_heads, head_dim), dtype=cache.dtype) + mx.eval(out) max_seq_len = max(ctx.context_lens) scale = attn_module.scale @@ -217,30 +214,24 @@ def _metal_kernel_decode_attention( # Zero-copy paged attention ops.paged_attention_v1( out, - q_mps, + q_3d, cache.key_caches[layer_idx], cache.value_caches[layer_idx], cache.num_kv_heads, scale, - block_tables_mps, - 
seq_lens_mps, + block_tables, + seq_lens, cache.block_size, max_seq_len, - None, # alibi_slopes - "auto", # kv_cache_dtype - cache.k_scale_tensor, - cache.v_scale_tensor, - 0, # tp_rank - 0, # blocksparse_local_blocks - 0, # blocksparse_vert_stride - 64, # blocksparse_block_size - 0, # blocksparse_head_sliding_step ) - # Bridge output back to MLX: (B, heads, hd) → (B, 1, heads*hd) - out_mlx = torch_to_mlx(out) # (B, heads, head_dim) - out_mlx = out_mlx.reshape(B, 1, n_heads * head_dim) - return attn_module.o_proj(out_mlx) + # Synchronize GPU: paged_attention_v1 wrote to out's buffer via a raw + # Metal dispatch that MLX's lazy graph doesn't track. mx.eval(out) would + # be a no-op here (out was already evaluated as zeros), so we must use + # mx.synchronize() to flush the command encoder and wait for the kernel. + mx.synchronize() + out = out.reshape(B, 1, n_heads * head_dim) + return attn_module.o_proj(out) # --------------------------------------------------------------------------- @@ -249,7 +240,7 @@ def _metal_kernel_decode_attention( class MetalKernelPagedAttentionWrapper(nn.Module): - """Wraps an mlx_lm Attention module to use the HF Metal kernel for paged KV. + """Wraps an mlx_lm Attention module to use native Metal paged KV. Uses ``object.__setattr__`` to bypass MLX nn.Module's ``__setattr__``. 
@@ -260,7 +251,7 @@ def __init__( self, inner: nn.Module, layer_idx: int, - kv_cache: MPSPagedKVCache, + kv_cache: MetalPagedKVCache, block_size: int, ) -> None: super().__init__() @@ -318,7 +309,7 @@ def __call__(self, x: mx.array, mask: Any = None, cache: Any = None) -> mx.array def patch_model_attention_metal_kernel( model: Any, - kv_cache: MPSPagedKVCache, + kv_cache: MetalPagedKVCache, block_size: int, ) -> int: """Walk model layers and replace each attention module with a diff --git a/vllm_metal/v1/model_runner.py b/vllm_metal/v1/model_runner.py index 02a3ec2..13dd88b 100644 --- a/vllm_metal/v1/model_runner.py +++ b/vllm_metal/v1/model_runner.py @@ -641,10 +641,10 @@ def __init__( self._prefix_cache = PrefixCacheManager() # Paged attention state (set by worker when enabled) - self._paged_kv_cache: Any = None # MPSPagedKVCache, set by worker + self._paged_kv_cache: Any = None # MetalPagedKVCache, set by worker self._paged_block_size: int = 0 self._paged_request_seq_lens: dict[str, int] = {} # req_id → seq_len - self.kv_cache_dtype: torch.dtype | None = None + self.kv_cache_dtype: mx.Dtype | None = None def _is_vlm_model(self) -> bool: """Check if the model is a vision-language model (VLM). 
@@ -825,6 +825,11 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: if self.kv_cache_dtype is None: raise RuntimeError("KV cache dtype not initialized; load_model() first") + # FullAttentionSpec (upstream vLLM) expects torch.dtype + from vllm_metal.pytorch_backend.tensor_bridge import MLX_TO_TORCH_DTYPE + + torch_dtype = MLX_TO_TORCH_DTYPE[self.kv_cache_dtype] + # Create a spec for each layer specs: dict[str, KVCacheSpec] = {} for layer_idx in range(self.num_layers): @@ -833,7 +838,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: block_size=block_size, num_kv_heads=self.num_kv_heads, head_size=self.head_dim, - dtype=self.kv_cache_dtype, + dtype=torch_dtype, ) return specs @@ -861,7 +866,7 @@ def get_cache_block_size_bytes(self) -> int: # Block memory = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype_size if self.kv_cache_dtype is None: raise RuntimeError("KV cache dtype not initialized; load_model() first") - dtype_size = self.kv_cache_dtype.itemsize + dtype_size = self.kv_cache_dtype.size return ( 2 * self.num_layers @@ -900,42 +905,39 @@ def warm_up(self) -> None: self._warm_up_paged_attention_kernel() def _warm_up_paged_attention_kernel(self) -> None: - """Load the HF paged-attention kernel and verify Metal ops work. + """JIT-compile vendored Metal shaders and verify ops work. - Forces ``newLibraryWithData`` inside the .so by running a single-token - ``reshape_and_cache`` against layer 0 of the already-allocated cache. - If the embedded metallib targets a Metal language version unsupported - by this OS, the error surfaces here instead of mid-inference. + Calls ``get_ops()`` which triggers JIT build of the C++ nanobind + extension + Metal shader compilation via MLX's device.get_library(). + Then runs a single-token ``reshape_and_cache`` smoke test against + layer 0 of the already-allocated cache. 
""" import platform - from vllm_metal.metal_kernel_backend.kernel_loader import ( - get_paged_attention_ops, - ) + from vllm_metal.metal import get_ops cache = self._paged_kv_cache logger.info("Warming up paged attention Metal kernel...") try: - ops = get_paged_attention_ops() + ops = get_ops() except Exception as e: raise RuntimeError( - f"Failed to load paged-attention Metal kernel: {e}. " + f"Failed to build/load native paged-attention Metal kernel: {e}. " f"macOS version: {platform.mac_ver()[0]}" ) from e # Smoke-test: single-token reshape_and_cache on layer 0 try: - dummy_k = torch.zeros( - 1, - cache.num_kv_heads, - cache.head_dim, - dtype=cache.dtype, - device="mps", + dummy_k = mx.zeros( + (1, cache.num_kv_heads, cache.head_dim), dtype=cache.dtype + ) + dummy_v = mx.zeros( + (1, cache.num_kv_heads, cache.head_dim), dtype=cache.dtype ) - dummy_v = torch.zeros_like(dummy_k) - dummy_slot = torch.zeros(1, dtype=torch.long, device="mps") + dummy_slot = mx.zeros((1,), dtype=mx.int64) + mx.eval(dummy_k, dummy_v, dummy_slot) ops.reshape_and_cache( dummy_k, @@ -943,10 +945,8 @@ def _warm_up_paged_attention_kernel(self) -> None: cache.key_caches[0], cache.value_caches[0], dummy_slot, - "auto", - cache.k_scale_tensor, - cache.v_scale_tensor, ) + mx.eval(cache.key_caches[0]) logger.info("Paged attention Metal kernel warm-up complete") except RuntimeError as e: mac_ver = platform.mac_ver()[0] diff --git a/vllm_metal/v1/worker.py b/vllm_metal/v1/worker.py index ea0fce3..c3eb7bf 100644 --- a/vllm_metal/v1/worker.py +++ b/vllm_metal/v1/worker.py @@ -139,7 +139,7 @@ def load_model(self) -> None: self._setup_paged_attention() def _setup_paged_attention(self) -> None: - """Create MPSPagedKVCache and patch model attention for HF Metal kernel. + """Create MetalPagedKVCache and patch model attention for native Metal kernel. 
Computes num_blocks from available system RAM, model weight size, and a configurable memory fraction, rather than blindly scaling from @@ -147,7 +147,7 @@ def _setup_paged_attention(self) -> None: """ import psutil - from vllm_metal.metal_kernel_backend.cache import MPSPagedKVCache + from vllm_metal.metal_kernel_backend.cache import MetalPagedKVCache from vllm_metal.metal_kernel_backend.paged_attention import ( patch_model_attention_metal_kernel, ) @@ -243,7 +243,8 @@ def _setup_paged_attention(self) -> None: # --- Create cache and patch model --- if runner.kv_cache_dtype is None: raise RuntimeError("KV cache dtype not initialized; runner.load_model()") - mps_kv_cache = MPSPagedKVCache( + + metal_kv_cache = MetalPagedKVCache( num_layers=runner.num_layers, num_kv_heads=runner.num_kv_heads, head_dim=runner.head_dim, @@ -253,7 +254,7 @@ def _setup_paged_attention(self) -> None: ) n_patched = patch_model_attention_metal_kernel( - runner.model, mps_kv_cache, block_size + runner.model, metal_kv_cache, block_size ) logger.info( "Metal kernel paged attention enabled: %d layers patched, " @@ -266,7 +267,7 @@ def _setup_paged_attention(self) -> None: ) # Store on model runner for use by paged prefill/decode - runner._paged_kv_cache = mps_kv_cache + runner._paged_kv_cache = metal_kv_cache runner._paged_block_size = block_size def _get_model_memory_usage(self) -> int: @@ -299,7 +300,7 @@ def _one_sequence_kv_bytes(self) -> int: """Bytes for one max-length sequence of KV cache (K + V).""" runner = self.model_runner dtype_size = ( - runner.kv_cache_dtype.itemsize if runner.kv_cache_dtype is not None else 2 + runner.kv_cache_dtype.size if runner.kv_cache_dtype is not None else 2 ) return ( 2 # K and V