
Add TriAttention KV cache compression #985

Open

Blaizzy wants to merge 5 commits into main from pc/add-triattn

Conversation

Blaizzy (Owner) commented Apr 9, 2026

Summary

  • Implements TriAttention (arXiv:2604.04921) — a KV cache compression method that prunes low-importance tokens using trigonometric series scoring derived from pre-RoPE Q/K concentration
  • Adds offline calibration script (triattention_calibrate.py) and drop-in TriAttentionKVCache that integrates with the existing generation pipeline via --triattention-calib and --triattention-budget CLI args
  • No changes to individual model attention implementations — works as a transparent cache wrapper

Key design decisions

  • No inverse RoPE needed — post-RoPE keys scored directly via phase cancellation (position terms cancel in the Q-center/K-phase difference)
  • Efficient scoring — the a·cos(tω) − b·sin(tω) decomposition enables a single matrix multiply over 17 log-spaced offsets instead of a naive loop
  • Sliding layers auto-skipped — only full-attention KVCache layers are compressed; RotatingKVCache (Gemma4 sliding window) left untouched
  • Follows established TurboQuant pattern: composition over inheritance, hot-swap via from_cache, self-contained compression inside update_and_fetch
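To make the second bullet concrete, here is a minimal numpy sketch of the a·cos(tω) − b·sin(tω) trick: precomputing cos/sin tables for the log-spaced offsets lets one matrix multiply replace the per-offset loop. All names and shapes (`n_freq`, the coefficient vectors `a`/`b`) are illustrative assumptions, not the PR's actual internals.

```python
import numpy as np

rng = np.random.default_rng(0)
n_freq = 64                                            # rotary frequency pairs per head (assumed)
offsets = np.geomspace(1, 4096, 17)                    # 17 log-spaced offsets t
w = 1.0 / (10000 ** (np.arange(n_freq) / n_freq))      # standard RoPE frequencies

a = rng.standard_normal(n_freq)                        # calibrated cos coefficients (placeholder)
b = rng.standard_normal(n_freq)                        # calibrated sin coefficients (placeholder)

# Naive: one trigonometric-series evaluation per offset
naive = np.array([(a * np.cos(t * w) - b * np.sin(t * w)).sum() for t in offsets])

# Vectorized: build the (offsets x freqs) phase table once, then two matmuls
T = np.outer(offsets, w)                               # shape (17, n_freq)
vectorized = np.cos(T) @ a - np.sin(T) @ b             # shape (17,)

assert np.allclose(naive, vectorized)
```

Both paths evaluate the same series term by term, so they agree to floating-point precision; the vectorized form just amortizes the trig tables across all offsets.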

How It Works

  1. Offline calibration — Run a forward pass to compute per-head Q-center statistics (mean direction and magnitude in the frequency domain)
  2. Online scoring — During generation, every 128 tokens, score each cached key using:
    • S_trig: Trigonometric series based on Q-K distance preferences
    • S_norm: Norm-based signal weighted by Q/K concentration (Mean Resultant Length)
  3. Pruning — Retain the top-B scoring keys, evict the rest. Attention sinks and recent tokens are always protected.
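The pruning step above can be sketched as budgeted top-B selection with protected positions. This is a simplified illustration: the parameter names (`n_sink`, `n_recent`) and the flat score vector are assumptions for the sketch, not the cache's real API.

```python
import numpy as np

def keep_indices(scores, budget, n_sink=4, n_recent=128):
    """Return sorted indices of tokens to retain under a cache budget.

    Attention-sink tokens (the first n_sink) and the most recent
    n_recent tokens are always protected; the remaining slots go to
    the highest-scoring middle tokens.
    """
    n = len(scores)
    if n <= budget:
        return np.arange(n)
    protected = set(range(n_sink)) | set(range(n - n_recent, n))
    middle = [i for i in range(n) if i not in protected]
    n_keep = max(0, budget - len(protected))
    ranked = sorted(middle, key=lambda i: scores[i], reverse=True)[:n_keep]
    return np.array(sorted(protected | set(ranked)))

scores = np.random.default_rng(1).random(1024)
kept = keep_indices(scores, budget=512)
assert len(kept) == 512
assert set(range(4)) <= set(kept.tolist())            # sinks survive
assert set(range(896, 1024)) <= set(kept.tolist())    # recent window survives
```

Running this sketch only every 128 tokens (as the PR describes) keeps the eviction overhead negligible relative to the attention computation itself.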

Quick Start

Step 1: Calibrate (one-time per model, takes ~30s)

```shell
python -m mlx_vlm.triattention_calibrate \
  --model google/gemma-4-31b-it \
  --output gemma4_calib.safetensors
```

Step 2: Generate with compression

```shell
mlx_vlm generate \
  --model google/gemma-4-31b-it \
  --triattention-calib gemma4_calib.safetensors \
  --triattention-budget 512 \
  --prompt "Your prompt here..." \
  --max-tokens 2048
```

Or via the Python API:

```python
from mlx_vlm import generate

result = generate(
    model, processor, prompt,
    triattention_calib="gemma4_calib.safetensors",
    triattention_budget=512,
    max_tokens=2048,
)
```

Benchmarks

Code: https://gist.github.com/Blaizzy/008df4f0a2f6df88db6f36569f06ea25

MATH 500 (30 problems, Gemma4-26B-A4B 5-bit, max_tokens=4096):

| Mode | Accuracy | Avg tok/s |
|---|---|---|
| Baseline | 23/30 (76.7%) | 77.2 |
| TA-2048 | 22/30 (73.3%) | 76.4 |
| TA-1024 | 21/30 (70.0%) | 77.0 |
| TA-512 | 19/30 (63.3%) | 77.7 |

MM-NIAH (Gemma4-31B, multimodal needle-in-a-haystack, 1K–60K tokens):

| Context | KV Baseline | KV TA-512 | Saved | Correct (BL/TA) |
|---|---|---|---|---|
| ~1K | 0.66 GB | 0.64 GB | 3% | Y / Y |
| ~7K | 1.25 GB | 0.82 GB | 34% | Y / Y |
| ~30K | 2.64 GB | 0.82 GB | 69% | Y / N |
| ~60K | 4.43 GB | 0.82 GB | 81% | Y / N |

Files changed

  • mlx_vlm/triattention.py — Core: RoPEConfig, scoring, TriAttentionKVCache, calibration I/O
  • mlx_vlm/triattention_calibrate.py — Calibration script with CaptureWrapper hooks
  • mlx_vlm/generate.py — CLI args + maybe_apply_triattention integration
  • mlx_vlm/models/cache.py — Re-export TriAttentionKVCache
  • README.md — Documentation with quick start, benchmarks, compatibility

Test plan

  • Run calibration: `python -m mlx_vlm.triattention_calibrate --model <model> --output calib.safetensors`
  • Generate with compression: `mlx_vlm generate --model <model> --triattention-calib calib.safetensors --triattention-budget 512 --prompt "..." --max-tokens 2048`
  • Verify no regression without TriAttention args (standard generation unchanged)
  • Test with Gemma4, LLaVA, and other nn.RoPE models

Blaizzy added 2 commits April 8, 2026 14:06
…oogle/gemma-4-31b-it for triattention calibration and generation commands.
thegodone pushed a commit to guillaume-osmo/mlx-lm that referenced this pull request Apr 10, 2026
…#985

Add trigonometric-series-based token pruning for KV cache compression,
from "TriAttention: Efficient Long Reasoning with Trigonometric KV
Compression" (Lin et al., 2026, arXiv:2604.04921).

Scores post-RoPE key importance using calibrated Q-center statistics
and evicts low-scoring tokens when cache exceeds budget.  No inverse
RoPE needed — position terms cancel in the phase difference.

New files:
- mlx_lm/models/triattention.py: core scoring + TriAttentionKVCache
- mlx_lm/triattention_calibrate.py: offline Q-center calibration CLI

Usage:
  python -m mlx_lm.triattention_calibrate --model <m> --output calib.safetensors
  python -m mlx_lm.generate --model <m> --triattention-calib calib.safetensors \
      --triattention-budget 512

Tested on SmolLM3-3B: exact match at budget>=seq_len, coherent output
with 22% memory reduction at aggressive budget=200.

Ported from Blaizzy/mlx-vlm#985 (Blaizzy).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Blaizzy and others added 3 commits April 14, 2026 14:49
Q/K centers are model-intrinsic properties that converge from very few
tokens (paper Appendix H). Online mode computes calibration from prefill
tokens automatically — no separate calibration file needed.

Usage: just pass --triattention-budget 512 (no --triattention-calib).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>