diff --git a/.envrc b/.envrc new file mode 100644 index 00000000..952d0d57 --- /dev/null +++ b/.envrc @@ -0,0 +1 @@ +export NPM_CONFIG_REGISTRY=https://registry.npmjs.org/ diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 22eb64d7..dd63e9dc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,6 +2,17 @@ name: CI on: workflow_dispatch: + inputs: + run_llama_smoke_tests: + description: "Run GGUF exact smoke matrix" + type: boolean + required: false + default: false + run_mlx_smoke_tests: + description: "Run MLX exact smoke matrix" + type: boolean + required: false + default: false push: branches: [main] pull_request: @@ -193,6 +204,174 @@ jobs: cache_key_prefix: '' workflow_cache_file: .github/workflows/ci.yml + smoke_case_matrix: + needs: [changes, inference_smoke_tests] + if: ${{ github.event_name == 'workflow_dispatch' && needs.inference_smoke_tests.result == 'success' }} + name: Generate smoke test matrix + runs-on: ubuntu-latest + outputs: + gguf_matrix: ${{ steps.gen.outputs.gguf_matrix }} + mlx_matrix: ${{ steps.gen.outputs.mlx_matrix }} + steps: + - uses: actions/checkout@v5 + + - name: Generate model/family test matrix + id: gen + run: | + python3 - <<'PY' + import json + import os + from pathlib import Path + + matrix = json.loads(Path("testdata/validation/matrix.json").read_text(encoding="utf-8")) + gguf = [] + mlx = [] + + for model in matrix.get("models", []): + model_id = model.get("id") + label = model.get("label", model_id) + + gguf_entry = model.get("gguf") + if gguf_entry: + ref = gguf_entry.get("model_ref", "") + parts = ref.split("/") + cache_slug = f"{parts[0]}--{parts[1]}" if len(parts) >= 2 else "" + case_id = gguf_entry.get("exact_case_id") + if case_id: + gguf.append({ + "id": model_id, + "name": label, + "case_id": case_id, + "model_ref": ref, + "cache_slug": cache_slug, + }) + + mlx_entry = model.get("mlx") + if mlx_entry: + ref = mlx_entry.get("model_ref", "") + parts = ref.split("/") + cache_slug = f"{parts[0]}--{parts[1]}" if len(parts) >= 2 else "" + case_id = mlx_entry.get("exact_case_id") + if case_id: + mlx.append({ + "id": model_id, + "name": label, + "case_id": case_id, + "model_ref": ref, + "cache_slug": cache_slug, + }) + + out = Path(os.environ["GITHUB_OUTPUT"]) + with out.open("a", encoding="utf-8") as f: + f.write(f"gguf_matrix={json.dumps({'include': gguf}, separators=(',', ':'))}\n") + f.write(f"mlx_matrix={json.dumps({'include': mlx}, separators=(',', ':'))}\n") + PY + + mlx_smoke_tests: + needs: [changes, macos, inference_smoke_tests, smoke_case_matrix] + if: ${{ github.event_name == 'workflow_dispatch' && needs.macos.result == 'success' && needs.inference_smoke_tests.result == 'success' && needs.smoke_case_matrix.result == 'success' && github.event.inputs.run_mlx_smoke_tests == 'true' }} + name: "Matrix: MLX Smoke Tests" + runs-on: macos-latest + strategy: + fail-fast: false + matrix: ${{ fromJSON(needs.smoke_case_matrix.outputs.mlx_matrix) }} + steps: + - uses: actions/checkout@v5 + + - uses: actions/setup-python@v6 + with: + python-version: "3.12" + + - name: Download macOS inference binaries + uses: actions/download-artifact@v7 + with: + name: ci-macos-inference-binaries + path: ci-artifacts/macos + + - name: Stage binaries for validation runner + run: | + mkdir -p target/release + cp ci-artifacts/macos/target/debug/mesh-llm target/release/mesh-llm + chmod +x target/release/mesh-llm + + - name: Cache MLX model repo + if: ${{ matrix.cache_slug != '' }} + uses: actions/cache@v5 + with: + path: ~/.cache/huggingface/hub/models--${{ matrix.cache_slug }} + key: mlx-hub-${{ matrix.cache_slug }} + + - name: MLX exact smoke case + run: | + python3 scripts/run-validation-matrix.py \ + --suite exact \ + --backend mlx \ + --skip-build \ + --cases "${{ matrix.case_id }}" \ + --stamp "ci-mlx-${{ matrix.id }}" + + - name: Upload MLX exact results + if: always() + uses: actions/upload-artifact@v6 + with: + name: mlx-exact-${{ matrix.id }} + path: MLX_VALIDATION_RESULTS/ci-mlx-${{ matrix.id }} + if-no-files-found: warn + + llama_smoke_tests: + needs: [changes, linux, inference_smoke_tests, smoke_case_matrix] + if: ${{ github.event_name == 'workflow_dispatch' && needs.linux.result == 'success' && needs.inference_smoke_tests.result == 'success' && needs.smoke_case_matrix.result == 'success' && github.event.inputs.run_llama_smoke_tests == 'true' }} + name: "Matrix: Llama Smoke Tests" + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: ${{ fromJSON(needs.smoke_case_matrix.outputs.gguf_matrix) }} + steps: + - uses: actions/checkout@v5 + + - uses: actions/setup-python@v6 + with: + python-version: "3.12" + + - name: Download Linux inference binaries + uses: actions/download-artifact@v7 + with: + name: ci-linux-inference-binaries + path: ci-artifacts/linux + + - name: Stage binaries for validation runner + run: | + mkdir -p target/release llama.cpp/build/bin + cp ci-artifacts/linux/target/debug/mesh-llm target/release/mesh-llm + cp ci-artifacts/linux/llama.cpp/build/bin/rpc-server llama.cpp/build/bin/rpc-server + cp ci-artifacts/linux/llama.cpp/build/bin/llama-server llama.cpp/build/bin/llama-server + cp ci-artifacts/linux/llama.cpp/build/bin/llama-moe-split llama.cpp/build/bin/llama-moe-split + chmod +x target/release/mesh-llm llama.cpp/build/bin/rpc-server llama.cpp/build/bin/llama-server llama.cpp/build/bin/llama-moe-split + + - name: Cache GGUF model repo + if: ${{ matrix.cache_slug != '' }} + uses: actions/cache@v5 + with: + path: ~/.cache/huggingface/hub/models--${{ matrix.cache_slug }} + key: gguf-hub-${{ matrix.cache_slug }} + + - name: GGUF exact smoke case + run: | + python3 scripts/run-validation-matrix.py \ + --suite exact \ + --backend gguf \ + --skip-build \ + --cases "${{ matrix.case_id }}" \ + --stamp "ci-gguf-${{ matrix.id }}" + + - name: Upload GGUF exact results + if: always() + uses: actions/upload-artifact@v6 + with: + name: gguf-exact-${{ matrix.id }} + path: MLX_VALIDATION_RESULTS/ci-gguf-${{ matrix.id }} + if-no-files-found: warn + macos: needs: changes if: ${{ github.event_name == 'workflow_dispatch' || needs.changes.outputs.rust == 'true' || needs.changes.outputs.ui == 'true' || needs.changes.outputs.benchmarks == 'true' }} @@ -367,6 +546,14 @@ jobs: if: ${{ github.event_name == 'workflow_dispatch' || needs.changes.outputs.rust == 'true' || needs.changes.outputs.ui == 'true' }} run: scripts/ci-client-auto-test.sh target/debug/mesh-llm + - name: Upload macOS inference binaries + uses: actions/upload-artifact@v6 + with: + name: ci-macos-inference-binaries + path: | + target/debug/mesh-llm + if-no-files-found: error + - name: Build Swift benchmark binary if: ${{ github.event_name == 'workflow_dispatch' || needs.changes.outputs.benchmarks == 'true' }} run: | diff --git a/.gitignore b/.gitignore index 7ada13c8..b32fe79a 100644 --- a/.gitignore +++ b/.gitignore @@ -10,9 +10,11 @@ mesh-llm/ui/dist/ evals/results/ .sisyphus/ .playwright-mcp/ +.cache/mlx-validation/ +MLX_VALIDATION_RESULTS/ .envrc .moe-cache/ __pycache__/ -*.pyc +*.py[cod] .venv/ .pytest_cache/ diff --git a/Cargo.lock b/Cargo.lock index 1b77bab7..e08ad627 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -43,6 +43,20 @@ dependencies = [ "subtle", ] +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if", + "getrandom 0.3.4", + "once_cell", + "serde", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.4" @@ -342,7 +356,7 @@ version = "0.30.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "16e2cdb6d5ed835199484bb92bb8b3edd526effe995c61732580439c1a67e2e9" dependencies = [ - "base64", + "base64 0.22.1", "http", "log", "url", @@ -445,6 +459,12 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "022dfe9eb35f19ebbcb51e0b40a5ab759f46ad60cadf7297e0bd085afb50e076" +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + [[package]] name = "base64" version = "0.22.1" @@ -473,6 +493,26 @@ dependencies = [ "regex", ] +[[package]] +name = "bindgen" +version = "0.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f" +dependencies = [ + "bitflags", + "cexpr", + "clang-sys", + "itertools 0.13.0", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash 1.1.0", + "shlex", + "syn", +] + [[package]] name = "bip39" version = "2.2.2" @@ -585,6 +625,12 @@ version = "3.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" +[[package]] +name = "bytemuck" +version = "1.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" + [[package]] name = "byteorder" version = "1.5.0" @@ -597,6 +643,15 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" +[[package]] +name = "castaway" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a" +dependencies = [ + "rustversion", +] + [[package]] name = "cbc" version = "0.1.2" @@ -618,6 +673,15 @@ dependencies = [ "shlex", ] +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-if" version = "1.0.4" @@ -690,6 +754,17 @@ dependencies = [ "zeroize", ] +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + [[package]] name = "clap" version = "4.6.0" @@ -754,6 +829,21 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" +[[package]] +name = "compact_str" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a" +dependencies = [ + "castaway", + "cfg-if", + "itoa", + "rustversion", + "ryu", + "serde", + "static_assertions", +] + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -763,6 +853,19 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "console" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width", + "windows-sys 0.59.0", +] + [[package]] name = "console" version = "0.16.3" @@ -903,6 +1006,16 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + [[package]] name = "crossbeam-epoch" version = "0.9.18" @@ -918,6 +1031,12 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" +[[package]] +name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + [[package]] name = "crypto-common" version = "0.1.7" @@ -1030,6 +1149,16 @@ dependencies = [ "darling_macro 0.20.11", ] +[[package]] +name = "darling" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" +dependencies = [ + "darling_core 0.21.3", + "darling_macro 0.21.3", +] + [[package]] name = "darling" version = "0.23.0" @@ -1054,6 +1183,20 @@ dependencies = [ "syn", ] +[[package]] +name = "darling_core" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn", +] + [[package]] name = "darling_core" version = "0.23.0" @@ -1078,6 +1221,17 @@ dependencies = [ "syn", ] +[[package]] +name = "darling_macro" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" +dependencies = [ + "darling_core 0.21.3", + "quote", + "syn", +] + [[package]] name = "darling_macro" version = "0.23.0" @@ -1089,6 +1243,15 @@ dependencies = [ "syn", ] +[[package]] +name = "dary_heap" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06d2e3287df1c007e74221c49ca10a95d557349e54b3a75dc2fb14712c751f04" +dependencies = [ + "serde", +] + [[package]] name = "data-encoding" version = "2.10.0" @@ -1446,6 +1609,15 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "esaxx-rs" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6" +dependencies = [ + "cc", +] + [[package]] name = "event-listener" version = "5.4.1" @@ -1760,6 +1932,12 @@ dependencies = [ "polyval", ] +[[package]] +name = "glob" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" + [[package]] name = "gloo-timers" version = "0.3.0" @@ -1791,6 +1969,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" +dependencies = [ + "cfg-if", + "crunchy", + "zerocopy", +] + [[package]] name = "hash32" version = "0.2.1" @@ -1869,7 +2058,7 @@ dependencies = [ "dirs", "futures", "http", - "indicatif", + "indicatif 0.18.4", "libc", "log", "num_cpus", @@ -2072,7 +2261,7 @@ version = "0.1.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" dependencies = [ - "base64", + "base64 0.22.1", "bytes", "futures-channel", "futures-util", @@ -2288,13 +2477,26 @@ dependencies = [ "serde_core", ] +[[package]] +name = "indicatif" +version = "0.17.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" +dependencies = [ + "console 0.15.11", + "number_prefix", + "portable-atomic", + "unicode-width", + "web-time", +] + [[package]] name = "indicatif" version = "0.18.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "25470f23803092da7d239834776d653104d551bc4d7eacaf31e6837854b8e9eb" dependencies = [ - "console", + "console 0.16.3", "portable-atomic", "unicode-width", "unit-prefix", @@ -2387,13 +2589,13 @@ dependencies = [ "portmapper", "rand 0.9.2", "reqwest 0.12.28", - "rustc-hash", + "rustc-hash 2.1.2", "rustls", "rustls-pki-types", "rustls-webpki", "serde", "smallvec", - "strum", + "strum 0.28.0", "sync_wrapper", "time", "tokio", @@ -2487,7 +2689,7 @@ dependencies = [ "rustls-pki-types", "serde", "serde_bytes", - "strum", + "strum 0.28.0", "tokio", "tokio-rustls", "tokio-util", @@ -2506,6 +2708,15 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.14.0" @@ -2588,6 +2799,16 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "libloading" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" +dependencies = [ + "cfg-if", + "windows-link", +] + [[package]] name = "libm" version = "0.2.16" @@ -2670,6 +2891,31 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3d25b0e0b648a86960ac23b7ad4abb9717601dec6f66c165f5b037f3f03065f" +[[package]] +name = "mach-sys" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48460c2e82a3a0de197152fdf8d2c2d5e43adc501501553e439bf2156e6f87c7" +dependencies = [ + "fastrand", +] + +[[package]] +name = "macro_rules_attribute" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65049d7923698040cd0b1ddcced9b0eb14dd22c5f86ae59c3740eab64a676520" +dependencies = [ + "macro_rules_attribute-proc_macro", + "paste", +] + +[[package]] +name = "macro_rules_attribute-proc_macro" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30" + [[package]] name = "matchers" version = "0.2.0" @@ -2691,6 +2937,12 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +[[package]] +name = "memo-map" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b" + [[package]] name = "memoffset" version = "0.9.1" @@ -2707,7 +2959,7 @@ dependencies = [ "anyhow", "argon2", "axum", - "base64", + "base64 0.22.1", "biip", "chacha20poly1305", "chrono", @@ -2723,6 +2975,8 @@ dependencies = [ "keyring", "libc", "mesh-llm-plugin", + "minijinja", + "mlx-rs", "nostr-sdk", "prost", "prost-build", @@ -2741,6 +2995,7 @@ dependencies = [ "sha2 0.10.9", "tempfile", "thiserror 2.0.18", + "tokenizers", "tokio", "tokio-stream", "toml", @@ -2775,6 +3030,23 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minijinja" +version = "2.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "805bfd7352166bae857ee569628b52bcd85a1cecf7810861ebceb1686b72b75d" +dependencies = [ + "memo-map", + "serde", + "serde_json", +] + +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" version = "0.8.9" @@ -2796,6 +3068,68 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "mlx-internal-macros" +version = "0.25.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a7c4444d624bf6b93db5cc22ebff4fdfa13593fd56154fe33b1f302a557c2c6" +dependencies = [ + "darling 0.21.3", + "itertools 0.14.0", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "mlx-macros" +version = "0.25.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a819ee8b4434690572b6feb9c3ef0b6e90137e4190b340cf00150703b410aaf9" +dependencies = [ + "darling 0.21.3", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "mlx-rs" +version = "0.25.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3a0f592c5839b0237b3072530b1d3c503923579a2009ee0761df3edd1b1f27b" +dependencies = [ + "bytemuck", + "dyn-clone", + "half", + "itertools 0.14.0", + "libc", + "mach-sys", + "mlx-internal-macros", + "mlx-macros", + "mlx-sys", + "num-complex", + "num-traits", + "num_enum", + "parking_lot", + "paste", + "safetensors", + "smallvec", + "strum 0.27.2", + "thiserror 2.0.18", +] + +[[package]] +name = "mlx-sys" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e3bc3880111918b2d5018f845d48fd995f9901f16efc81d1fcfd2f4210b8219" +dependencies = [ + "bindgen", + "cc", + "cmake", +] + [[package]] name = "moka" version = "0.12.15" @@ -2813,6 +3147,28 @@ dependencies = [ "uuid", ] +[[package]] +name = "monostate" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3341a273f6c9d5bef1908f17b7267bbab0e95c9bf69a0d4dcf8e9e1b2c76ef67" +dependencies = [ + "monostate-impl", + "serde", + "serde_core", +] + +[[package]] +name = "monostate-impl" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4db6d5580af57bf992f59068d4ea26fd518574ff48d7639b255a36f9de6e7e9" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "multimap" version = "0.10.1" @@ -3026,6 +3382,16 @@ dependencies = [ "libc", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "noq" version = "0.17.0" @@ -3037,7 +3403,7 @@ dependencies = [ "noq-proto", "noq-udp", "pin-project-lite", - "rustc-hash", + "rustc-hash 2.1.2", "rustls", "socket2", "thiserror 2.0.18", @@ -3063,7 +3429,7 @@ dependencies = [ "lru-slab", "rand 0.9.2", "ring", - "rustc-hash", + "rustc-hash 2.1.2", "rustls", "rustls-pki-types", "slab", @@ -3093,7 +3459,7 @@ version = "0.44.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3aa5e3b6a278ed061835fe1ee293b71641e6bf8b401cfe4e1834bbf4ef0a34e1" dependencies = [ - "base64", + "base64 0.22.1", "bech32", "bip39", "bitcoin_hashes", @@ -3308,6 +3674,12 @@ dependencies = [ "libc", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "objc2" version = "0.6.4" @@ -3377,6 +3749,28 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "onig" +version = "6.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "336b9c63443aceef14bea841b899035ae3abe89b7c486aaf4c5bd8aafedac3f0" +dependencies = [ + "bitflags", + "libc", + "once_cell", + "onig_sys", +] + +[[package]] +name = "onig_sys" +version = "69.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7f86c6eef3d6df15f23bcfb6af487cbd2fed4e5581d58d5bf1f5f8b7f6727dc" +dependencies = [ + "cc", + "pkg-config", +] + [[package]] name = "opaque-debug" version = "0.3.1" @@ -3639,7 +4033,7 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "740ebea15c5d1428f910cd1a5f52cebf8d25006245ed8ade92702f4943d91e07" dependencies = [ - "base64", + "base64 0.22.1", "indexmap", "quick-xml", "serde", @@ -3698,7 +4092,7 @@ version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "74748bc706fa6b6aebac6bbe0bbe0de806b384cb5c557ea974f771360a4e3858" dependencies = [ - "base64", + "base64 0.22.1", "bytes", "derive_more", "futures-lite", @@ -3830,7 +4224,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7" dependencies = [ "heck", - "itertools", + "itertools 0.14.0", "log", "multimap", "petgraph", @@ -3849,7 +4243,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" dependencies = [ "anyhow", - "itertools", + "itertools 0.14.0", "proc-macro2", "quote", "syn", @@ -3948,7 +4342,7 @@ dependencies = [ "pin-project-lite", "quinn-proto", "quinn-udp", - "rustc-hash", + "rustc-hash 2.1.2", "rustls", "socket2", "thiserror 2.0.18", @@ -3968,7 +4362,7 @@ dependencies = [ "lru-slab", "rand 0.9.2", "ring", - "rustc-hash", + "rustc-hash 2.1.2", "rustls", "rustls-pki-types", "slab", @@ -4089,6 +4483,37 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c8d0fd677905edcbeedbf2edb6494d676f0e98d54d5cf9bda0b061cb8fb8aba" +[[package]] +name = "rayon" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-cond" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2964d0cf57a3e7a06e8183d14a8b527195c706b7983549cd5462d5aa3747438f" +dependencies = [ + "either", + "itertools 0.14.0", + "rayon", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -4170,7 +4595,7 @@ version = "0.12.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" dependencies = [ - "base64", + "base64 0.22.1", "bytes", "encoding_rs", "futures-core", @@ -4217,7 +4642,7 @@ version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab3f43e3283ab1488b624b44b0e988d0acea0b3214e694730a055cb6b2efa801" dependencies = [ - "base64", + "base64 0.22.1", "bytes", "futures-core", "futures-util", @@ -4272,7 +4697,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2231b2c085b371c01bc90c0e6c1cab8834711b6394533375bdbf870b0166d419" dependencies = [ "async-trait", - "base64", + "base64 0.22.1", "bytes", "chrono", "futures", @@ -4321,6 +4746,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustc-hash" version = "2.1.2" @@ -4399,6 +4830,16 @@ version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" +[[package]] +name = "safetensors" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "172dd94c5a87b5c79f945c863da53b2ebc7ccef4eca24ac63cca66a41aab2178" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "salsa20" version = "0.10.2" @@ -4893,6 +5334,18 @@ dependencies = [ "der", ] +[[package]] +name = "spm_precompiled" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326" +dependencies = [ + "base64 0.13.1", + "nom", + "serde", + "unicode-segmentation", +] + [[package]] name = "sse-stream" version = "0.2.1" @@ -4924,13 +5377,34 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "strum" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" +dependencies = [ + "strum_macros 0.27.2", +] + [[package]] name = "strum" version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9628de9b8791db39ceda2b119bbe13134770b56c138ec1d3af810d045c04f9bd" dependencies = [ - "strum_macros", + "strum_macros 0.28.0", +] + +[[package]] +name = "strum_macros" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7695ce3845ea4b33927c055a39dc438a45b059f7c1b3d91d38d10355fb8cbca7" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", ] [[package]] @@ -5130,6 +5604,40 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tokenizers" +version = "0.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a620b996116a59e184c2fa2dfd8251ea34a36d0a514758c6f966386bd2e03476" +dependencies = [ + "ahash", + "aho-corasick", + "compact_str", + "dary_heap", + "derive_builder", + "esaxx-rs", + "getrandom 0.3.4", + "indicatif 0.17.11", + "itertools 0.14.0", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand 0.9.2", + "rayon", + "rayon-cond", + "regex", + "regex-syntax", + "serde", + "serde_json", + "spm_precompiled", + "thiserror 2.0.18", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + [[package]] name = "tokio" version = "1.51.0" @@ -5238,7 +5746,7 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1b6348ebfaaecd771cecb69e832961d277f59845d4220a584701f72728152b7" dependencies = [ - "base64", + "base64 0.22.1", "bytes", "futures-core", "futures-sink", @@ -5479,6 +5987,15 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-normalization-alignments" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" +dependencies = [ + "smallvec", +] + [[package]] name = "unicode-segmentation" version = "1.13.2" @@ -5497,6 +6014,12 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "unicode_categories" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" + [[package]] name = "unit-prefix" version = "0.5.2" @@ -5525,7 +6048,7 @@ version = "3.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dea7109cdcd5864d4eeb1b58a1648dc9bf520360d7af16ec26d0a9354bafcfc0" dependencies = [ - "base64", + "base64 0.22.1", "cookie_store", "flate2", "log", @@ -5546,7 +6069,7 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e994ba84b0bd1b1b0cf92878b7ef898a5c1760108fe7b6010327e274917a808c" dependencies = [ - "base64", + "base64 0.22.1", "http", "httparse", "log", diff --git a/Justfile b/Justfile index 4eb01acf..b30f25c1 100644 --- a/Justfile +++ b/Justfile @@ -239,6 +239,16 @@ bundle output="/tmp/mesh-bundle.tar.gz": cp {{ build_dir }}/bin/llama-server "$BUNDLE/$llama_name" cp {{ build_dir }}/bin/llama-moe-analyze "$BUNDLE/" cp {{ build_dir }}/bin/llama-moe-split "$BUNDLE/" + if [ "$(uname -s)" = "Darwin" ]; then + shopt -s nullglob + metallibs=(target/release/build/mlx-sys-*/out/build/lib/mlx.metallib) + shopt -u nullglob + if [ "${#metallibs[@]}" -gt 0 ]; then + cp "${metallibs[0]}" "$BUNDLE/mlx.metallib" + else + echo "Note: mlx.metallib not found in target/release/build β€” MLX bundles may fail on remote hosts" + fi + fi for lib in {{ build_dir }}/bin/*.dylib; do cp "$lib" "$BUNDLE/" 2>/dev/null || true done diff --git a/MLX_ROADMAP.md b/MLX_ROADMAP.md new file mode 100644 index 00000000..8fe51a5e --- /dev/null +++ b/MLX_ROADMAP.md @@ -0,0 +1,279 @@ +# MLX Roadmap + +This document tracks the remaining work for the native MLX backend in `mesh-llm`. + +It is intentionally practical: +- what already works +- what still needs runtime support +- what remains around downloads, templates, CI, and product behavior +- which GitHub issues already track MLX follow-up work + +## Current Status + +The MLX backend is now a real serving path on macOS for a bounded set of text models. + +Working today: +- MLX-native loading and serving on Apple Silicon +- Hugging Face repo shorthand, exact artifact refs, and catalog entries +- explicit MLX runtime selection via `--mlx` or `--mlx-file` +- MLX sidecar download support: + - `config.json` + - `tokenizer.json` + - `tokenizer_config.json` + - `chat_template.json` + - `chat_template.jinja` + - sharded safetensors from `model.safetensors.index.json` +- Hugging Face chat templates rendered through MiniJinja with compatibility normalization +- Family-aware thinking/reasoning controls for supported template families +- macOS MLX smoke coverage in CI + +Current product behavior: +- MLX is an explicit opt-in backend when launching with `--model` +- using MLX prints an experimental startup warning +- the warning explicitly points users at the GitHub issues page if they hit problems + +What the existing llama.cpp-backed `mesh-llm` path already supports: +- vision models via `mmproj` + +What the existing `mesh-llm` product surface does not currently expose as a first-class llama feature: +- audio runtime support + +## Supported Runtime Families + +These families are now in the supported native MLX runtime set: +- Llama +- GLM 4 dense +- Qwen2 +- Qwen3 +- Gemma 2 text +- Gemma 3 text +- Gemma 4 text +- GLM4 text +- LFM2 text +- DeepSeekV3 / Kimi-K2 text +- gpt-oss text +- Kimi Linear text + +Notes: +- Gemma 4 support currently targets text-capable MLX repos such as `unsloth/gemma-4-E4B-it-UD-MLX-4bit` +- MLX remains a local-only serving path today +- MLX support is currently text-only, even though the llama-backed path already supports vision models +- DeepSeekV3 / Kimi-K2, `gpt-oss`, and `Kimi Linear` are correctness-first runtime additions today; they are compile/test verified but not part of the live macOS smoke matrix yet +- MLX is intentionally not auto-selected from `--model`; callers must opt in with `--mlx` + +## Families Still Missing Runtime Support + +These families still need more validation, broader target coverage, or a dedicated product pass: +- DeepSeekV3 / Kimi-K2 +- Kimi Linear +- gpt-oss + +These are no longer template-only gaps, but they are not as battle-tested as the smaller live-smoked families yet. + +## Remaining Family Work + +### Gemma 4 + +Gemma 4 text now works, but it is not fully β€œdone”. + +Remaining work: +- verify more Gemma 4 repos beyond the current `unsloth` target +- support broader Gemma 4 variants if they differ from the current text-side structure +- harden around layer-type-specific attention behavior if new repos expose gaps +- add more Gemma 4 catalog coverage once confidence is higher + +### Kimi / Kimi Linear / gpt-oss / LFM2 + +These should not be treated as one bucket anymore. + +#### DeepSeekV3 / Kimi-K2 + +Current understanding: +- `Kimi-K2` / `K2.5` ride on a DeepSeekV3-style MLA + MoE runtime base +- that base is now implemented in the MLX runtime +- live smoke coverage is still missing because the public MLX repos are very large + +Remaining work: +- add a real known-good K2/K2.5 runtime validation pass once practical hardware/CI coverage exists +- catalog only when we are comfortable with real runtime validation, not just compile-time support + +#### Kimi Linear + +Current understanding: +- `kimi_linear` is a separate architecture from K2/K2.5 +- it uses its own linear-attention stack plus MoE and custom projection structure +- it now has a dedicated cacheless runtime path in the MLX backend + +Remaining work: +- real public-model validation against a known-good target repo +- focused runtime smoke tests +- cached generation / recurrent state support beyond the correctness-first cacheless path + +#### gpt-oss + +Current understanding: +- the realistic public target is `mlx-community/gpt-oss-20b-MXFP4-Q4` +- its runtime is now implemented via a correctness-first cacheless path +- it is still not part of the live smoke matrix, and the current support should be treated as earlier-stage than Llama/Qwen/Gemma/GLM/LFM2 + +Remaining work: +- real public-model validation against a known-good target repo +- focused runtime smoke tests +- decide whether to keep the current path or later add lower-level MXFP4 support via `mlx-rs` for better performance + +#### LFM2 + +Current understanding: +- the best first target is `mlx-community/LFM2-350M-4bit` +- `LFM2` is more tractable than `gpt-oss` because the public MLX target uses plain affine quantization +- the family alternates standard attention blocks with `ShortConv` blocks + +Current status: +- `lfm2` config support is implemented +- `ShortConv` runtime is implemented +- `mlx-community/LFM2-350M-4bit` passes a live local MLX smoke +- MLX generation now uses streaming-safe token decoding, which fixed non-ASCII output for this family (`πŸ”΄` instead of replacement characters) + +Remaining work: +- add broader LFM2 coverage beyond the 350M target +- add GGUF-side parity where practical if we want matrix symmetry +- revisit cached generation for LFM2 after the correctness-first cacheless path + +## Prompt Template Compatibility + +The current HF template path is much stronger than before, but there is still follow-up work. + +Remaining work: +- keep extending the real Hugging Face template corpus as new MLX repos appear +- improve compatibility with family-specific template quirks only when real repos require it +- prefer real fixtures over speculative support +- keep fallback behavior explicit when HF templates cannot be rendered safely + +Priority additions: +- more real Gemma 4 templates +- future Qwen 4 MLX templates once public repos exist +- more tool-calling and multimodal-adjacent text templates where relevant + +## Vision and Audio + +### Vision + +Vision should be part of the MLX roadmap because the existing llama-backed runtime already supports vision models in `mesh-llm`. + +Evidence in the current codebase: +- model catalog entries carry `mmproj` +- capability detection marks vision support from `mmproj` and vision metadata +- llama launch wiring passes `--mmproj` + +Remaining MLX vision work: +- add a real MLX-side multimodal model-loading path +- support image token / image placeholder handling beyond template-only rendering +- add at least one live MLX vision smoke once a supported model family exists +- extend catalog rules so MLX vision models are only listed when the runtime truly supports them + +### Audio + +Audio should not yet be treated as a committed MLX roadmap target in the same way as vision. + +Reason: +- I checked the current `mesh-llm` repo surface, and unlike vision there is no first-class audio runtime path exposed through `mesh-llm` today +- that means β€œmatch llama feature parity” clearly applies to vision now, but not yet to audio at the product layer + +So the current stance should be: +- vision: yes, explicit MLX roadmap target +- audio: future possibility, but not yet a committed parity target until `mesh-llm` itself exposes it on the llama path + +## Runtime Behavior and Product Gaps + +### Local-only MLX serving + +Current behavior: +- supported MLX models on macOS run through the local native MLX path +- they do not participate in the existing rpc/split distributed path + +Remaining work: +- design distributed MLX serving behavior +- decide whether MLX split/distribution reuses existing orchestration or needs backend-specific rules +- add tests once a real design exists + +Tracked issue: +- [#146](https://github.com/michaelneale/mesh-llm/issues/146) Support distributed or split MLX serving + +### Download and resolution UX + +Current state is much better, but there is still polish left: +- keep catalog entries expanding for supported families +- maintain backend-aware ambiguity handling for `--model org/repo` +- keep `--mlx`, `--gguf`, `--mlx-file`, and `--gguf-file` behavior sharp and documented + +## CI and Smoke Testing + +The macOS MLX smoke matrix exists now. It should keep expanding only where runtime support is real. + +Remaining work: +- stabilize the sequential smoke experience across the supported matrix +- debug and eliminate startup/load flakes such as the intermittent `JOSIE-IT1-Qwen3-0.6B-4bit` startup hang +- keep prompts family-aware so the smoke tests validate useful behavior without becoming brittle +- avoid adding live smokes for unsupported runtime families + +## Catalog Work + +Current direction: +- explicit `-MLX` catalog names for MLX entries +- only catalog models that the runtime can actually serve + +Remaining work: +- keep adding supported MLX entries for Llama, Qwen, Gemma +- add MLX vision catalog entries only after MLX vision runtime support is real +- do not add unsupported families just because templates render +- revisit broader Gemma-family catalog breadth now that Gemma 2/3/4 text are working + +## MLX Runtime Engineering Tasks + +These are the main technical tasks still on the table: +- real-model validation and smoke coverage for DeepSeekV3 / Kimi-K2 +- real-model validation and smoke coverage for Kimi Linear +- real-model validation and smoke coverage for gpt-oss +- broader Gemma 4 validation and hardening +- MLX vision runtime support for families where `mesh-llm` already supports vision on the llama path +- distributed MLX serving + +## Possible `mlx-rs` Follow-up + +We should stay willing to fork `mlx-rs` if the backend needs capabilities that are not practical to layer externally. + +Reasons a fork might become necessary: +- missing kernels for quantized paths we need +- attention-mask behavior the current API cannot express cleanly +- backend/device movement limitations that block correct runtime behavior +- quantization modes such as MXFP4 that the current bindings/runtime path do not expose cleanly + +This should remain a last resort, but it is explicitly on the table. + +## Existing MLX Issues + +Issues already raised for MLX follow-up: + +- [#142](https://github.com/michaelneale/mesh-llm/issues/142) Support Gemma MLX models in native MLX runtime + - Originally opened for Gemma-family runtime support + - Now partially addressed by Gemma 2, Gemma 3, and Gemma 4 text support + - Still relevant for broader Gemma 4 coverage and future Gemma-family expansion + +- [#146](https://github.com/michaelneale/mesh-llm/issues/146) Support distributed or split MLX serving + - Tracks the current local-only limitation + +## Suggested Next Steps + +Recommended order: + +1. Land and stabilize DeepSeekV3 / Kimi-K2 coverage beyond compile-time validation. +2. Narrow issue `#142` to the remaining Gemma-family work now that Gemma 2, Gemma 3, and Gemma 4 text support exist. +3. Start scoping MLX vision support, since vision is already supported on the llama-backed path. +4. Decide whether the next major priority is: + - broader family coverage, or + - distributed MLX serving from `#146` + +If family coverage is the priority, the next order should be: +- broader DeepSeekV3 / Kimi-K2 validation +- Kimi Linear validation +- broader gpt-oss validation diff --git a/MLX_VALIDATION_MATRIX.md b/MLX_VALIDATION_MATRIX.md new file mode 100644 index 00000000..bb41ac24 --- /dev/null +++ b/MLX_VALIDATION_MATRIX.md @@ -0,0 +1,65 @@ +# MLX Validation Matrix + +Local-first backend-parity ledger for model families. The point is not to judge +MLX in isolation; it is to compare `πŸ¦™ GGUF` against `🍎 MLX` on the same family / +model / case so we can tell shared model weakness from MLX-specific regressions. + +## Legend + +| Status | Meaning | +|---|---| +| `PASS` | Validated locally and behaved acceptably for the checks listed | +| `FAIL` | Reproduced a real issue locally | +| `PARTIAL` | Loads and answers basic prompts, but has behavior issues or incomplete coverage | +| `BLOCKED` | Could not be validated locally on this machine | +| `PENDING` | Not checked yet | + +## GGUF Parity + +| Status | Meaning | +|---|---| +| `MATCH` | GGUF showed the same behavior, so the issue is likely not MLX-specific | +| `DIFFERS` | GGUF and MLX both ran, but they diverged in ways that need source-model context to interpret | +| `MLX WORSE` | GGUF handled the same case better than MLX | +| `MLX BETTER` | MLX handled the same case better than GGUF | +| `PENDING` | GGUF comparison not run yet | +| `BLOCKED` | Could not get a meaningful GGUF comparison locally | + +## Pair Quality + +| Status | Meaning | +|---|---| +| `HIGH` | Same family, same size, same instruct/chat target, and close quant class; good parity signal | +| `MEDIUM` | Same family and roughly same target, but quant or conversion path differs materially | +| `LOW` | Only approximate family parity; useful for triage, but not a strong apples-to-apples comparison | +| `PENDING` | Pair quality not assessed yet | + +## Models + +| Family | Model Pair | GGUF Target | MLX Target | Pair Quality | Last Checked | GGUF Exact | MLX Exact | GGUF Behavior | MLX Behavior | Parity | Status | Notes | +|---|---|---|---|---|---|---|---|---|---|---|---| +| Qwen2.5 | 0.5B instruct | `meshllm/qwen2.5-0.5b-instruct-parity-q8_0-gguf/qwen2.5-0.5b-instruct-q8_0.gguf` | `meshllm/qwen2.5-0.5b-instruct-parity-8bit-mlx` | `HIGH` | 2026-04-06 | `PASS` | `PASS` | `STALE` | `STALE` | `MATCH` | `PARTIAL` | Published same-origin parity pair derived from `Qwen/Qwen2.5-0.5B-Instruct`. Local exact validation passed on both backends with matching outputs across the full checked-in exact suite, including `after-monday -> Tuesday`. Older behavior numbers were collected against the public pair before the Qwen2.5 MLX template-rendering fix, so behavior should be rerun on this canonical pair before drawing new parity conclusions. | +| Qwen3 | 0.6B instruct | `meshllm/qwen3-0.6b-parity-q8_0-gguf/qwen3-0.6b-q8_0.gguf` | `meshllm/qwen3-0.6b-parity-8bit-mlx` | `HIGH` | 2026-04-06 | `FAIL` | `FAIL` | `STALE` | `STALE` | `DIFFERS` | `PARTIAL` | Published same-origin parity pair derived from `Qwen/Qwen3-0.6B`. Local exact validation still fails on both backends, but the important result is that MLX now matches the original checkpoint behavior while GGUF drifts from it on multiple prompts (`after-monday`, `banana-color`, `largest-planet`). This row remains useful for backend-drift tracking, but Qwen3 is a weak parity canary and the old behavior numbers should not be carried forward to the new canonical pair. | +| Llama | 3.2 1B instruct | `meshllm/llama-3.2-1b-instruct-parity-f16-gguf/llama-3.2-1b-instruct-f16.gguf` | `meshllm/llama-3.2-1b-instruct-parity-bf16-mlx` | `HIGH` | 2026-04-06 | `FAIL` | `FAIL` | `PENDING` | `PENDING` | `MATCH` | `PARTIAL` | Published same-origin high-fidelity parity pair derived from `meta-llama/Llama-3.2-1B-Instruct`. Local exact validation shows clean agreement on all semantic prompts, with only the known shared capitalization drift on `blue/green/red`. The earlier low-bit MLX `banana-color -> Green` miss does not reproduce at `bf16`, so the canonical row now uses `f16`/`bf16` instead of the noisier public low-bit pair. | +| Gemma 2 | 2B instruct | `meshllm/gemma-2-2b-it-parity-q8_0-gguf/gemma-2-2b-it-q8_0.gguf` | `meshllm/gemma-2-2b-it-parity-8bit-mlx` | `HIGH` | 2026-04-06 | `PASS` | `PASS` | `STALE` | `STALE` | `MATCH` | `PARTIAL` | Published same-origin parity pair derived from `google/gemma-2-2b-it`. Local exact validation passed on both backends with matching outputs across the full checked-in exact suite; the only minor formatting difference was `2 + 2 = **4**` vs `2 + 2 = 4`, which stayed in the same acceptance bucket. Older behavior numbers came from the public pair and should be rerun against this canonical pair before drawing new parity conclusions. | +| Gemma 3 | 1B instruct | `meshllm/gemma-3-1b-it-parity-f16-gguf/gemma-3-1b-it-f16.gguf` | `meshllm/gemma-3-1b-it-parity-bf16-mlx` | `HIGH` | 2026-04-06 | `PASS` | `PASS` | `PENDING` | `PENDING` | `MATCH` | `PARTIAL` | Published same-origin high-fidelity parity pair derived from `google/gemma-3-1b-it`. Validated on `studio54.local`: both backends passed the full exact suite with identical outputs, including `primary-colors -> Red, Green, Blue`. This replaces the noisier public low-bit Gemma3 pair for future parity checks. | +| Gemma 4 | E4B instruct | `meshllm/gemma-4-e4b-it-parity-q8_0-gguf/gemma-4-e4b-it-q8_0.gguf` | `meshllm/gemma-4-e4b-it-parity-8bit-mlx` | `HIGH` | 2026-04-06 | `PASS` | `PASS` | `PENDING` | `PENDING` | `MATCH` | `PARTIAL` | Published same-origin parity pair derived from `google/gemma-4-E4B-it`. Local exact validation passed on both backends with matching outputs across the full checked-in exact suite. The MLX side originally exposed a mixed dense/quantized Gemma 4 loader bug in mesh-llm (`missing language_model.model.per_layer_model_projection.scales`); after fixing that loader path, the same-origin 8bit/Q8_0 pair matched cleanly. | +| GLM4 | 9B 0414 | `meshllm/glm-4-9b-0414-parity-q4_k_m-gguf/glm-4-9b-0414-q4_k_m.gguf` | `meshllm/glm-4-9b-0414-parity-4bit-mlx` | `HIGH` | 2026-04-06 | `PASS` | `PASS` | `PENDING` | `PENDING` | `MATCH` | `PARTIAL` | Published same-origin parity pair derived from `THUDM/GLM-4-9B-0414`. Local exact validation passed on both backends with matching outputs across the full checked-in exact suite, including `primary-colors -> red, green, blue` and `banana-color -> Yellow`. The converted MLX artifact carries its prompt template in `chat_template.jinja`, so the canonical row now points there instead of the older public `tokenizer_config.json`-driven pair. | +| LFM2 | 350M | `meshllm/lfm2-350m-parity-q4_k_m-gguf/lfm2-350m-q4_k_m.gguf` | `meshllm/lfm2-350m-parity-4bit-mlx` | `HIGH` | 2026-04-06 | `FAIL` | `FAIL` | `PENDING` | `PENDING` | `DIFFERS` | `PARTIAL` | Published same-origin backend-drift pair derived from `LiquidAI/LFM2-350M`. Local exact validation shows the GGUF side is materially worse than the MLX side on simple prompts: GGUF answered `primary` and `alt-green` with explanatory prose instead of the requested one-word colors, while MLX returned `blue` and `green` cleanly. We keep this row to track a likely llama/GGUF-side issue rather than as a parity-clean canary. | +| OLMo2 | 7B instruct | `meshllm/olmo2-7b-instruct-parity-q8_0-gguf/olmo2-7b-instruct-q8_0.gguf` | `meshllm/olmo2-7b-instruct-parity-8bit-mlx` | `HIGH` | 2026-04-07 | `PASS` | `PASS` | `PENDING` | `PENDING` | `PENDING` | `PARTIAL` | Published same-origin OLMo2 parity pair derived from `allenai/OLMo-2-1124-7B-Instruct`. Fresh exact rerun against the canonical Q8_0 GGUF and 8-bit MLX artifacts passed on both backends. The branch also carries the MLX runtime/template fixes needed for OLMo2 prompt formatting and stability, but behavior baselines have not been accepted into the checked-in ledger yet. | +| Mamba | 2.8B | `/Users/jdumay/code/worktrees/mesh-llm-validation/output/mamba-debug/mamba-f16.gguf` | `/Users/jdumay/code/worktrees/mesh-llm-validation/mlx/mamba-8bit` | `MEDIUM` | 2026-04-07 | `FAIL` | `FAIL` | `PENDING` | `PENDING` | `BLOCKED` | `BLOCKED` | Local-only candidate pair for `state-spaces/mamba-2.8b-hf`. Exact validation failed on both sides: GGUF drifted badly on one-word/completion prompts, and the MLX path never reached inference because mesh-llm routed the MLX directory into `llama-server` and died on `gguf_init_from_file_ptr: failed to read magic`. Keep this row as an explicit failure record, not a publish target. | +| SmolLM2 | 135M instruct | `/Users/jdumay/code/worktrees/mesh-llm-validation/output/smollm2-135m/SmolLM2-135M-Instruct-Q8_0.gguf` | `/Users/jdumay/code/worktrees/mesh-llm-validation/mlx/smollm2-135m-instruct-4bit` | `MEDIUM` | 2026-04-07 | `FAIL` | `FAIL` | `PENDING` | `PENDING` | `BLOCKED` | `BLOCKED` | Local-only candidate pair for `HuggingFaceTB/SmolLM2-135M-Instruct`. Exact validation failed overall: GGUF passed the factual prompts but missed `primary`, `alt-green`, `alt-red`, `banana-color`, and `after-monday`, while the MLX path never reached inference because mesh-llm routed the MLX directory into `llama-server` and died on `gguf_init_from_file_ptr: failed to read magic`. Keep this row as an explicit failure record, not a publish target. | +| DeepSeek R1 Distill | Qwen 1.5B | `/Users/jdumay/code/worktrees/mesh-llm-validation/output/deepseek/DeepSeek-R1-Distill-Qwen-1.5B-Q8_0.gguf` | `/Users/jdumay/code/worktrees/mesh-llm-validation/mlx/deepseek-r1-distill-qwen-1.5b-4bit` | `HIGH` | 2026-04-07 | `FAIL` | `FAIL` | `PENDING` | `PENDING` | `BLOCKED` | `BLOCKED` | Same-origin local parity pair derived from `deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B`. GGUF exact failed every prompt because the model emitted only `reasoning_content` with empty assistant `content`, so the harness observed blank answers. The MLX path never reached inference because this branch still routes local MLX directories into `llama-server` and dies on `gguf_init_from_file_ptr: failed to read magic`. Keep this row as an explicit failure record until runtime dispatch is fixed. | +| Phi-3 | mini 4k instruct | `/Users/jdumay/code/worktrees/mesh-llm-validation/output/phi3/Phi-3-mini-4k-instruct-Q8_0.gguf` | `/Users/jdumay/code/worktrees/mesh-llm-validation/mlx/phi3-mini-4k-instruct-4bit` | `HIGH` | 2026-04-07 | `PASS` | `PASS` | `PENDING` | `PENDING` | `MATCH` | `PARTIAL` | Same-origin local parity pair derived from `microsoft/Phi-3-mini-4k-instruct`. The remaining exact drift was traced to tokenizer handling, not the forward pass: the Hugging Face `tokenizers` path was honoring `rstrip: true` on Phi-3 role/end markers and stripping the newline/space after `<|user|>`, `<|assistant|>`, and `<|end|>`, which changed the prompt token stream and flipped `blue/red` to `Blue/Red`. Patching those control tokens to preserve following whitespace at load time restored GGUF-compatible prompt tokenization and both backends now pass the full exact suite. | +| gpt-oss | 20B-ish | `unsloth/gpt-oss-20b-GGUF/gpt-oss-20b-Q4_K_M.gguf` | `openai/gpt-oss-20b` | `MEDIUM` | 2026-04-07 | `PENDING` | `PENDING` | `PENDING` | `PENDING` | `PENDING` | `PENDING` | Stageable on this machine. GGUF comes from `unsloth/gpt-oss-20b-GGUF`; the MLX-side source is the upstream `openai/gpt-oss-20b` safetensors checkpoint rather than the much larger mlx-community quantized mirror. | + +## Notes + +- Exact smoke means the deterministic `blue / green / red` style suite plus reasoning-on probe where relevant. +- Behavior means the MT-Bench-derived behavior harness in [`scripts/ci-mt-bench-behavior.py`](/Users/jdumay/.codex/worktrees/e497/mesh-llm/scripts/ci-mt-bench-behavior.py). +- Raw rebuilt-engine exact rerun artifacts are stored under [`MLX_VALIDATION_RESULTS/rerun-20260404-buildsync`](/Users/jdumay/.codex/worktrees/e497/mesh-llm/MLX_VALIDATION_RESULTS/rerun-20260404-buildsync). +- The judgment rule is simple: + - `πŸ¦™ GGUF FAIL` + `🍎 MLX FAIL` = probably shared model weakness + - `πŸ¦™ GGUF PASS` + `🍎 MLX FAIL` = MLX-specific problem and not OK + - `πŸ¦™ GGUF FAIL` + `🍎 MLX PASS` = MLX at least not worse there +- Record enough detail in `Notes` to make the next fix obvious. diff --git a/README.md b/README.md index 0d08654b..aad901b7 100644 --- a/README.md +++ b/README.md @@ -441,6 +441,118 @@ curl localhost:9337/v1/chat/completions \ -d '{"model":"GLM-4.7-Flash-Q4_K_M","messages":[{"role":"user","content":"hello"}]}' ``` +### Model selection and storage + +```bash +# Catalog name (fuzzy match β€” finds Qwen3-8B-Q4_K_M) +mesh-llm --model Qwen3-8B + +# Full catalog name +mesh-llm --model Qwen3-8B-Q4_K_M + +# MLX catalog name +mesh-llm --model Qwen3-4B-MLX --mlx + +# HuggingFace URL (any GGUF) +mesh-llm --model https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q4_K_M.gguf + +# HuggingFace shorthand (org/repo/file.gguf) +mesh-llm --model bartowski/Llama-3.2-3B-Instruct-GGUF/Llama-3.2-3B-Instruct-Q4_K_M.gguf + +# HuggingFace repo shorthand (works when the repo has one clear primary artifact) +mesh-llm --model mlx-community/Qwen2.5-0.5B-Instruct-4bit --mlx + +# Prefer GGUF or MLX when a repo has multiple candidates +mesh-llm --model some-org/some-repo --gguf +mesh-llm --model some-org/some-repo --mlx + +# Local file path (legacy/raw file mode) +mesh-llm --gguf-file ~/my-models/custom-model.gguf + +# Local MLX model path +mesh-llm --mlx-file ~/my-models/qwen3-mlx/model.safetensors +``` + +Catalog models are downloaded with resume support. Use the `models` subcommands to browse, inspect, and fetch exact refs. + +MLX catalog entries use explicit `-MLX` names so they stay distinct from the GGUF catalog entries. + +- Hugging Face repo snapshots are the canonical managed model store. +- `~/.models/` is deprecated and will be removed in a future release. +- Arbitrary local GGUF files remain supported through `--gguf-file`. +- MLX runtime selection is explicit: use `--mlx` with `--model`, or `--mlx-file` for a local MLX path. +- MoE split artifacts are cached separately under `~/.cache/mesh-llm/splits/`. + +### MLX status + +MLX is available on macOS as an experimental local backend for supported text models. + +- `--model ... --mlx` is required for MLX runtime selection. +- `--mlx-file` is the explicit local-path MLX mode. +- MLX currently bypasses the existing distributed/split GGUF path and runs locally on the Mac serving node. +- On startup, MLX prints an experimental warning and points users to the GitHub issues page if they hit problems. + +### Backend support matrix + +This is the practical local validation snapshot for the backends in this branch. +The `πŸ¦™ GGUF / llama` and `🍎 MLX` columns reflect the families we have actually +run through the checked-in validation matrix here, not every family the upstream +runtimes may support in theory. + +| Family | Example tested `πŸ¦™ GGUF / llama` model | `πŸ¦™ GGUF / llama` | Example tested `🍎 MLX` model | `🍎 MLX` | Notes | +|---|---|---:|---|---:|---| +| Llama | `Llama-3.2-3B-Instruct-Q4_K_M` | βœ… | `Llama-3.2-3B-Instruct-MLX` | βœ… | Dense text | +| Qwen2 | `Qwen2.5-3B-Instruct-Q4_K_M` | βœ… | `Qwen2.5-3B-Instruct-MLX` | βœ… | Dense text | +| Qwen3 | `qwen3-8b-q8_0.gguf` | βœ… | `meshllm/qwen3-8b-parity-8bit-mlx` | βœ… | Exact and MLX behavior are green on the current 8B parity pair; GGUF tagged-thinking remains a known llama-side limitation | +| Gemma 2 | `gemma-2-2b-it-Q4_K_M` | βœ… | `Gemma-2-2B-it-MLX` | βœ… | Dense text | +| Gemma 3 | `gemma-3-1b-it-f16.gguf` | βœ… | `meshllm/gemma-3-1b-it-parity-bf16-mlx` | βœ… | Exact and behavior are green after the Gemma3 MLX config and stop-sequence fixes | +| Gemma 4 | `gemma-4-e4b-it-q8_0.gguf` | βœ… | `meshllm/gemma-4-e4b-it-parity-8bit-mlx` | βœ… | Exact and behavior are green after the Gemma4 MLX replay/termination fixes | +| GLM4 | `GLM-4.7-Flash-Q4_K_M` | βœ… | `GLM-4-9B-0414-MLX` | βœ… | Dense text | +| LFM2 | `lfm2-350m-q4_k_m.gguf` | βœ… | `meshllm/lfm2-350m-parity-4bit-mlx` | βœ… | Exact, behavior, and thinking are green on the current parity pair | +| gpt-oss | β€” | β€” | `openai/gpt-oss` MLX repos | βœ… | Implemented on MLX, not in live smoke matrix | + +Current MLX limitations: + +- experimental +- macOS only +- text only +- the accepted full matrix still has one known GGUF-side limitation: `qwen3-gguf` does not surface tagged reasoning markers in the thinking suite +- larger families like `gpt-oss` are implemented but not yet part of the live macOS smoke matrix +- distributed/split MLX serving is still future work + +Tested MLX families in the current local matrix: + +- `Llama` +- `Qwen2.5` +- `Qwen3` +- `Gemma 2` +- `Gemma 3` +- `Gemma 4` +- `GLM4` +- `LFM2` +- `OLMo2` +- `Mistral` + +Useful commands: + +```bash +mesh-llm models recommended # list built-in recommended models +mesh-llm models installed # list installed local models +mesh-llm models search qwen 8b # search Hugging Face GGUF repos +mesh-llm models search --catalog qwen +mesh-llm models show Qwen/Qwen3-8B-GGUF/Qwen3-8B-Q4_K_M.gguf +mesh-llm models show Qwen3-4B-MLX +mesh-llm models download Qwen/Qwen3-8B-GGUF/Qwen3-8B-Q4_K_M.gguf +mesh-llm models download Qwen3-4B-MLX +mesh-llm models download mlx-community/Qwen2.5-0.5B-Instruct-4bit +mesh-llm models download some-org/some-repo --mlx +mesh-llm models migrate # inspect deprecated ~/.models content +mesh-llm models migrate --apply # materialize recognized HF-backed models into the HF cache +mesh-llm models updates --check # check cached HF repos for newer upstream revisions +mesh-llm models updates --all # refresh all cached HF repos +mesh-llm models updates Qwen/Qwen3-8B-GGUF +``` + ## How it works Mesh LLM keeps the user-facing surface simple: talk to `localhost:9337`, pick a model, and let the mesh decide how to serve it. diff --git a/ci/linux-test.dockerfile b/ci/linux-test.dockerfile index a0980e39..ce6bdf2d 100644 --- a/ci/linux-test.dockerfile +++ b/ci/linux-test.dockerfile @@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y cmake pkg-config git && rm -rf /var/lib WORKDIR /src # Clone llama.cpp fork (not in docker context due to .dockerignore) -RUN git clone -b rebase-upstream-master --depth 1 https://github.com/michaelneale/llama.cpp.git +RUN git clone -b upstream-latest --depth 1 https://github.com/michaelneale/llama.cpp.git # Build llama.cpp (CPU + RPC, no GPU) RUN cmake -B llama.cpp/build -S llama.cpp \ diff --git a/mesh-llm/Cargo.toml b/mesh-llm/Cargo.toml index 316611e2..ac2aae09 100644 --- a/mesh-llm/Cargo.toml +++ b/mesh-llm/Cargo.toml @@ -47,6 +47,12 @@ zip = { version = "2", default-features = false, features = ["deflate"] } hf-hub = { git = "https://github.com/i386/hf-hub.git", rev = "d938c23e74bc7d8a3919843daa8065eb8f10fdf8", default-features = false, features = ["ureq", "tokio", "rustls-tls", "cache-manager"] } unicode-width = "0.2" +[target.'cfg(target_os = "macos")'.dependencies] +chrono = { version = "0.4", default-features = false, features = ["clock"] } +mlx-rs = { version = "0.25", features = ["safetensors"] } +minijinja = { version = "2", features = ["json", "loop_controls"] } +tokenizers = "0.21" + [dev-dependencies] axum = "0.8" serial_test = "3" diff --git a/mesh-llm/docs/MLX_FAMILY_BRINGUP.md b/mesh-llm/docs/MLX_FAMILY_BRINGUP.md new file mode 100644 index 00000000..9fe7ae78 --- /dev/null +++ b/mesh-llm/docs/MLX_FAMILY_BRINGUP.md @@ -0,0 +1,245 @@ +# MLX Family Bring-Up Workflow + +Use this workflow when adding a new model family to the MLX engine and proving +that the family works cleanly across both backends. + +This is the end-to-end path for: + +1. choosing a candidate family +2. downloading the original upstream checkpoint +3. deriving both `GGUF` and `MLX` artifacts from the same source +4. validating the family through the `llama` and `MLX` engines +5. publishing accepted artifacts to `meshllm` on Hugging Face +6. updating the checked-in validation matrix + +This document is intentionally operational. It links out to the lower-level +docs when you need exact publishing or matrix details. + +## Use This For + +Use this workflow for dense text families first. + +Good early candidates: + +- `Mistral` +- `Phi-3` +- other small or medium Llama-like families with standard safetensors layouts + +Avoid as first bring-up targets unless there is a strong reason: + +- large MoE families +- multimodal families +- families that require custom conversion logic on both backends + +## Success Criteria + +Do not call a family supported just because the loader accepts its config. + +A new family is only ready when all of the following are true: + +1. the source checkpoint downloads cleanly +2. the source checkpoint converts to both `GGUF` and `MLX` +3. the derived `GGUF` artifact runs through the `llama` path +4. the derived `MLX` artifact runs through the `MLX` path +5. both artifacts pass the exact validation suite +6. both artifacts are healthy in the behavior suite +7. the pair is published under `meshllm/*` +8. the validation matrix is updated to pin the new pair + +## Phase 1: Pick The Family And Checkpoint + +Prefer a small instruct checkpoint with: + +- a public Hugging Face source repo +- a known `HF -> GGUF` conversion path +- a known `HF -> MLX` conversion path +- a size that fits local or remote test hardware + +Use one exact upstream source checkpoint for both derived artifacts. Do not mix +third-party `GGUF` and `MLX` repos when the goal is backend parity. + +## Phase 2: Download The Original Checkpoint + +Follow the same-origin rules in +[SAME_ORIGIN_PARITY_WORKFLOW.md](/Users/jdumay/code/worktrees/mesh-llm-validation/mesh-llm/docs/SAME_ORIGIN_PARITY_WORKFLOW.md). + +Keep the original checkpoint intact under: + +```bash +~/.cache/mesh-llm-origin/ +``` + +Inspect the source repo first: + +```bash +hf download --dry-run +``` + +Then download the config, tokenizer, and weight files into the origin cache. + +## Phase 3: Convert To GGUF And MLX + +Derive both artifacts from the same source checkpoint. + +Use the exact conversion commands in +[SAME_ORIGIN_PARITY_WORKFLOW.md](/Users/jdumay/code/worktrees/mesh-llm-validation/mesh-llm/docs/SAME_ORIGIN_PARITY_WORKFLOW.md). + +Recommended local layout: + +```bash +~/.cache/mesh-llm-debug/-same-origin/ + gguf/ + mlx/ +``` + +Typical outcome: + +- `gguf/-f16.gguf` +- `gguf/-q8_0.gguf` or `gguf/-q4_k_m.gguf` +- `mlx/-bf16/` or `mlx/-8bit/` + +## Phase 4: Make The MLX Loader Accept The Family + +This is the code bring-up step. + +For new MLX families, inspect: + +- [mesh-llm/src/mlx/model.rs](/Users/jdumay/code/worktrees/mesh-llm-validation/mesh-llm/src/mlx/model.rs) + +In practice, the first pass is usually: + +1. add the family to `config_supports_mlx` +2. map the family into `model_architecture()` if it needs non-default handling +3. add any tensor transform or tokenizer patching needed for that family +4. add focused tests for acceptance and failure modes + +Do not stop at config detection. A family is not real support until the derived +artifact runs through the end-to-end validation flow below. + +## Phase 5: Run Validation Through Both Engines + +The validation system is documented in: + +- [TESTING.md](/Users/jdumay/code/worktrees/mesh-llm-validation/mesh-llm/docs/TESTING.md) +- [testdata/validation/README.md](/Users/jdumay/code/worktrees/mesh-llm-validation/testdata/validation/README.md) + +Use three layers of validation. + +### 1. Exact Suite + +Run the deterministic exact suite against both backends: + +```bash +just build +scripts/run-validation-matrix.py --suite exact --skip-build --cases --stamp "-exact" +``` + +This checks: + +- model load and readiness +- `/v1/chat/completions` +- deterministic prompt-following +- no leaked reasoning markup when `enable_thinking=false` +- explicit reasoning mode when configured +- `/v1/models` + +Review: + +- `MLX_VALIDATION_RESULTS//exact-summary.tsv` +- `MLX_VALIDATION_RESULTS//exact-cross-backend-parity.tsv` +- `MLX_VALIDATION_RESULTS//exact//chat/*.json` + +### 2. Behavior Suite + +Run the MT-Bench-derived behavior suite: + +```bash +scripts/run-validation-matrix.py --suite behavior --skip-build --cases --stamp "-behavior" +``` + +This is a health check, not a benchmark. It catches: + +- empty outputs +- timeout and liveness failures +- leaked reasoning markup +- repeated lines +- repeated sentences +- repeated 6-grams +- low tail-token diversity + +Review: + +- `MLX_VALIDATION_RESULTS//behavior-summary.tsv` +- `MLX_VALIDATION_RESULTS//behavior//report.json` + +### 3. Cross-Backend Parity Review + +Treat parity as a separate review step. + +Check: + +1. `GGUF run` vs `GGUF baseline` +2. `MLX run` vs `MLX baseline` +3. `MLX run` vs `GGUF baseline` + +For a new family, the pair should at minimum land in the same expectation +bucket on the strict exact prompts. Prefer `same-output` where realistic. + +## Phase 6: Publish Accepted Artifacts To Hugging Face + +Once the derived pair is good enough, publish it under `meshllm`. + +Naming convention: + +- `meshllm/-parity--gguf` +- `meshllm/-parity--mlx` + +Use the publishing commands and README skeleton from +[SAME_ORIGIN_PARITY_WORKFLOW.md](/Users/jdumay/code/worktrees/mesh-llm-validation/mesh-llm/docs/SAME_ORIGIN_PARITY_WORKFLOW.md). + +Prefer: + +```bash +hf repo create meshllm/ --type model +hf upload-large-folder meshllm/ /tmp/ --repo-type model +``` + +## Phase 7: Update The Validation Matrix + +Only after publishing and accepting the pair, update the checked-in matrix. + +Files to update: + +- [testdata/validation/matrix.json](/Users/jdumay/code/worktrees/mesh-llm-validation/testdata/validation/matrix.json) +- [scripts/mlx-parity-exact.tsv](/Users/jdumay/code/worktrees/mesh-llm-validation/scripts/mlx-parity-exact.tsv) +- [README.md](/Users/jdumay/code/worktrees/mesh-llm-validation/README.md) + +Depending on the change, also update: + +- `.github/workflows/ci.yml` +- `.github/workflows/behavior.yml` + +Treat this as a pinned artifact change, not a routine rerun. + +## Recommended Bring-Up Checklist + +Use this checklist for each new family: + +1. pick one upstream checkpoint +2. download the original source checkpoint +3. derive `GGUF` and `MLX` artifacts from that same source +4. add or adjust MLX family support in code +5. run exact validation on both backends +6. run behavior validation on both backends +7. review parity artifacts and logs +8. publish both artifacts to `meshllm` +9. pin the pair in `matrix.json` +10. update the backend support matrix in `README.md` + +## Notes + +- Prefer dense text families before MoE and multimodal families. +- Do not update `README.md` support claims before the matrix run is clean. +- Do not publish a same-origin pair just because conversion succeeded. +- If the family needs mixed-version or protocol changes elsewhere in the repo, + treat that as a separate review stream. diff --git a/mesh-llm/docs/SAME_ORIGIN_PARITY_WORKFLOW.md b/mesh-llm/docs/SAME_ORIGIN_PARITY_WORKFLOW.md new file mode 100644 index 00000000..9d91870e --- /dev/null +++ b/mesh-llm/docs/SAME_ORIGIN_PARITY_WORKFLOW.md @@ -0,0 +1,305 @@ +# Same-Origin Parity Workflow + +Use this workflow when a public GGUF/MLX pair is noisy and we want a cleaner +backend comparison derived from the same original checkpoint. + +This is the process we used for: + +- `Qwen2.5` +- `Gemma 2` +- `Gemma 3` +- `Gemma 4` + +The goal is: + +1. download the original checkpoint once +2. derive both `GGUF` and `MLX` artifacts from that same source +3. run the exact validation suite against those derived artifacts +4. only publish and switch the matrix when the pair is good enough to bless + +## Preconditions + +You need: + +- `HF_TOKEN` exported with access to any gated upstream repos +- the bundled `llama.cpp` tools built via `just build` +- a Python environment with: + - `transformers` for `convert_hf_to_gguf.py` + - `mlx_lm` for `python -m mlx_lm.convert` + +If the default `python3` does not provide those modules, use the Python +interpreter from the environment that does. + +## Directory Conventions + +Use these paths consistently: + +| Purpose | Path pattern | +|---|---| +| original source checkpoint | `~/.cache/mesh-llm-origin/` | +| local derived artifacts | `~/.cache/mesh-llm-debug/-same-origin/` | +| local exact artifacts | `MLX_VALIDATION_RESULTS//...` | + +Example model slugs: + +- `qwen2.5-0.5b-instruct` +- `gemma-3-1b-it` +- `gemma-4-e4b-it` + +## 1. Download The Original Checkpoint + +First inspect the upstream repo and confirm the real weight filenames: + +```bash +hf download --dry-run +``` + +For the large batch of additional parity families on `studio54`, use: + +```bash +zsh scripts/download-origin-checkpoints-studio54.sh +``` + +To fetch only a subset: + +```bash +zsh scripts/download-origin-checkpoints-studio54.sh mistral phi3 deepseek +``` + +The script downloads into `~/.cache/mesh-llm-origin-batch/` by default and +continues past gated or failed repos so the rest of the batch can proceed. + +By default it excludes families that are too large or risky for the full +download-convert-validate workflow on `studio54`'s `128 GB` M1 Ultra: + +- `mixtral` +- `cohere-command-r` +- `jamba` + +Those can be added explicitly with: + +```bash +zsh scripts/download-origin-checkpoints-studio54.sh --include-heavy +``` + +For metadata and tokenizer files, `hf download` is fine: + +```bash +mkdir -p ~/.cache/mesh-llm-origin/ +hf download \ + config.json generation_config.json tokenizer.json tokenizer_config.json \ + special_tokens_map.json README.md LICENSE* USE_POLICY* \ + --local-dir ~/.cache/mesh-llm-origin/ +``` + +For large weight files, prefer direct resumable downloads. This avoids the +stale lock and partial-download problems we hit with `hf download`: + +```bash +curl -L -C - \ + -H "Authorization: Bearer $HF_TOKEN" \ + https://huggingface.co//resolve/main/ \ + -o ~/.cache/mesh-llm-origin// +``` + +If the repo uses sharded weights, repeat that for each shard: + +```bash +curl -L -C - \ + -H "Authorization: Bearer $HF_TOKEN" \ + https://huggingface.co//resolve/main/model-00001-of-00002.safetensors \ + -o ~/.cache/mesh-llm-origin//model-00001-of-00002.safetensors +``` + +Notes: + +- Download source checkpoints in serial unless there is a specific reason to do + otherwise. +- Keep the original checkpoint intact under `~/.cache/mesh-llm-origin/`. +- Do not use third-party GGUF or MLX repos when the goal is same-origin parity. + +## 2. Convert To GGUF + +Create a high-fidelity GGUF first, then quantize if needed: + +```bash +SRC=~/.cache/mesh-llm-origin/ +OUT=~/.cache/mesh-llm-debug/-same-origin +mkdir -p "$OUT/gguf" + +python3 llama.cpp/convert_hf_to_gguf.py "$SRC" \ + --outfile "$OUT/gguf/-f16.gguf" \ + --outtype f16 +``` + +Quantize from that high-fidelity GGUF: + +```bash +./llama.cpp/build/bin/llama-quantize \ + "$OUT/gguf/-f16.gguf" \ + "$OUT/gguf/-q8_0.gguf" \ + Q8_0 +``` + +Or: + +```bash +./llama.cpp/build/bin/llama-quantize \ + "$OUT/gguf/-f16.gguf" \ + "$OUT/gguf/-q4_k_m.gguf" \ + Q4_K_M +``` + +Use the quant class that best matches the MLX artifact you intend to compare. + +## 3. Convert To MLX + +For quantized MLX: + +```bash +SRC=~/.cache/mesh-llm-origin/ +OUT=~/.cache/mesh-llm-debug/-same-origin +mkdir -p "$OUT/mlx" + +python3 -m mlx_lm.convert \ + --hf-path "$SRC" \ + --mlx-path "$OUT/mlx/-8bit" \ + -q \ + --q-bits 8 +``` + +For high-fidelity MLX: + +```bash +python3 -m mlx_lm.convert \ + --hf-path "$SRC" \ + --mlx-path "$OUT/mlx/-bf16" \ + --dtype bfloat16 +``` + +If your `mlx_lm` environment uses a different Python binary, use that binary +instead of `python3`. + +## 4. Run Exact Validation + +Run the local exact suite against the derived artifacts before publishing +anything: + +```bash +STAMP=-same-origin-$(date +%Y%m%d) +scripts/run-validation-matrix.py --suite exact --skip-build --cases --stamp "$STAMP" +``` + +If you need a direct one-off case: + +```bash +VALIDATION_RESULTS_ROOT="$PWD/MLX_VALIDATION_RESULTS" \ +VALIDATION_RESULTS_STAMP="$STAMP" \ +scripts/run-validation-case.sh gguf \ +python3 scripts/ci-exact-smoke.py \ + --backend gguf \ + --mesh-llm target/release/mesh-llm \ + --bin-dir llama.cpp/build/bin \ + --model "$OUT/gguf/.gguf" \ + --prompt-suite-json "$PWD/testdata/validation/exact-prompts.json" +``` + +Review at minimum: + +- `MLX_VALIDATION_RESULTS//exact-summary.tsv` +- `MLX_VALIDATION_RESULTS//exact-cross-backend-parity.tsv` +- per-prompt chat artifacts under: + - `MLX_VALIDATION_RESULTS//exact//chat/` + +Do not publish or switch the matrix just because a conversion succeeded. +Publish only if the derived pair gives us a cleaner parity story than the public +pair. + +## 5. Publish To Hugging Face + +Use `meshllm` model repos and make the pair naming explicit: + +| Backend | Repo naming pattern | +|---|---| +| GGUF | `meshllm/-parity--gguf` | +| MLX | `meshllm/-parity--mlx` | + +Examples: + +- `meshllm/qwen2.5-0.5b-instruct-parity-q8_0-gguf` +- `meshllm/qwen2.5-0.5b-instruct-parity-8bit-mlx` +- `meshllm/gemma-3-1b-it-parity-f16-gguf` +- `meshllm/gemma-3-1b-it-parity-bf16-mlx` + +Create a temp publish directory containing: + +- the artifact +- a `README.md` with YAML front matter + +Minimal `README.md` skeleton: + +```md +--- +license: other +library_name: llama.cpp +pipeline_tag: text-generation +tags: + - meshllm + - parity + - same-origin +--- + +# + +Same-origin parity artifact derived from `<upstream-repo>`. +``` + +Then create the repo and upload: + +```bash +hf repo create meshllm/<repo-name> --type model +hf upload-large-folder meshllm/<repo-name> /tmp/<publish-dir> --repo-type model +``` + +Use `upload-large-folder` for the large artifacts rather than one-off `hf upload`. + +## 6. Switch The Matrix + +Only after the published pair is accepted, update: + +- `testdata/validation/matrix.json` +- `scripts/mlx-parity-exact.tsv` +- `.github/workflows/ci.yml` +- `.github/workflows/behavior.yml` +- `MLX_VALIDATION_MATRIX.md` + +The matrix row should say: + +- what upstream original checkpoint the pair was derived from +- whether behavior results are current, stale, or pending +- why this pair is better than the previous public pair + +## 7. Remote Confirmation + +If the pair is important, rerun it on `studio54.local` before treating it as the +canonical row. + +Use the remote rules from `AGENTS.md`: + +- launch long-running work in `tmux` +- use `zsh -lc` on macOS +- prefer `scp` + small remote scripts for nontrivial jobs +- verify the `tmux` session twice before calling it live + +## Current Publishing Rule + +Use this same-origin workflow selectively. + +Good candidates: + +- rows with suspicious public-pair drift +- small and medium canary models +- families where parity conclusions matter operationally + +Do not publish a same-origin pair just because it exists. Publish it when it +gives us a materially better canonical parity row. diff --git a/mesh-llm/docs/TESTING.md b/mesh-llm/docs/TESTING.md index 6b1d6f21..088da2b2 100644 --- a/mesh-llm/docs/TESTING.md +++ b/mesh-llm/docs/TESTING.md @@ -1,5 +1,119 @@ # Testing mesh-llm +If you are bringing up a new MLX model family end-to-end, start with +[`MLX_FAMILY_BRINGUP.md`](/Users/jdumay/code/worktrees/mesh-llm-validation/mesh-llm/docs/MLX_FAMILY_BRINGUP.md). + +## Local validation matrix + +Use the checked-in validation runner when you want to rerun the local backend +comparison matrix and preserve raw per-case artifacts for both the deterministic +exact suite and the MT-Bench-derived behavior suite: + +```bash +just build +scripts/run-validation-matrix.py --stamp rerun-$(date +%Y%m%d-%H%M%S) +``` + +Outputs: + +- exact summary TSV: `.cache/mlx-validation/<stamp>/exact-summary.tsv` +- behavior summary TSV: `.cache/mlx-validation/<stamp>/behavior-summary.tsv` +- combined summary TSV: `.cache/mlx-validation/<stamp>/validation-summary.tsv` +- exact baseline comparison TSV: `.cache/mlx-validation/<stamp>/exact-baseline-comparison.tsv` +- behavior baseline comparison TSV: `.cache/mlx-validation/<stamp>/behavior-baseline-comparison.tsv` +- parity-vs-baseline TSV: `.cache/mlx-validation/<stamp>/parity-vs-canonical-baseline.tsv` +- raw logs per case: + - `.cache/mlx-validation/<stamp>/exact/<case-id>/` + - `.cache/mlx-validation/<stamp>/behavior/<case-id>/` +- per-prompt raw request/response artifacts for exact runs: + - `.cache/mlx-validation/<stamp>/exact/<case-id>/chat/<label>.json` +- raw `/v1/models` payload for exact runs: + - `.cache/mlx-validation/<stamp>/exact/<case-id>/models/v1-models.json` + +Useful options: + +```bash +# exact-only parity rerun +scripts/run-validation-matrix.py --suite exact --skip-build + +# behavior-only rerun +scripts/run-validation-matrix.py --suite behavior --skip-build + +# rerun only one model family on both backends +scripts/run-validation-matrix.py --skip-build --cases qwen25 + +# run only the GGUF side +scripts/run-validation-matrix.py --skip-build --backend gguf + +# run only the MLX side +scripts/run-validation-matrix.py --skip-build --backend mlx + +# shorten the behavior run for local debugging +scripts/run-validation-matrix.py --suite behavior --skip-build --cases qwen25 --max-prompts 3 + +# store artifacts somewhere else +scripts/run-validation-matrix.py --root /tmp/mesh-llm-validation +``` + +The shared matrix definition lives in: + +- `testdata/validation/matrix.json` +- `testdata/validation/baselines.json` + +Each row pins the exact GGUF and MLX artifacts to avoid model drift and tags the +row with an expectation class such as `strict` or `weak-but-stable` so tiny +model weirdness stays explicit instead of silently redefining success. + +Baseline policy: + +- `GGUF` is the canonical checked-in baseline. +- New `GGUF` runs are compared against the checked-in `GGUF` baseline to catch + reference-backend regressions. +- `MLX` runs are compared against both: + - the checked-in `MLX` baseline for backend self-consistency + - the checked-in `GGUF` baseline for parity +- Behavior baselines stay summary-based rather than full-output goldens. Record + only stable facts such as exit code, failed prompt count, and flagged prompt + ids/categories after you accept a behavior run. + +## CI structure + +The validation matrix is split across CI by cost and signal. + +For pull requests and branch pushes: + +- Keep the current job names: + - `changes` + - `linux` + - `macos` + - `macos_mlx` + - `gguf_smokes` + - `linux_cuda` +- `linux` and `macos` remain the foundation/build jobs. +- `gguf_smokes` and `macos_mlx` should be treated as the exact-matrix PR gates. +- `linux_cuda` remains a build / flavor confidence job, not part of the parity + matrix. +- Exact regressions should fail PR CI. + +For nightly / scheduled validation and release validation: + +- Keep the full MT-Bench-derived behavior suite in `behavior.yml`. +- Do not use the behavior suite as a routine PR gate. +- Compare behavior results against the checked-in summary baselines and review + artifacts when they diverge. + +Execution order matters: + +- When running `--suite all --backend both`, execute grouped phases rather than + alternating per model: + 1. all `gguf` exact rows + 2. all `mlx` exact rows + 3. all `gguf` behavior rows + 4. all `mlx` behavior rows + +That grouped order is the expected orchestration for local and remote matrix +runs. + ## Local inspection ### 0. Inspect local GPUs diff --git a/mesh-llm/plugin/src/io.rs b/mesh-llm/plugin/src/io.rs index 2cd8543d..9e4a5296 100644 --- a/mesh-llm/plugin/src/io.rs +++ b/mesh-llm/plugin/src/io.rs @@ -158,7 +158,7 @@ pub async fn bind_side_stream(plugin_id: &str, stream_id: &str) -> Result<LocalL let server = tokio::net::windows::named_pipe::ServerOptions::new() .create(&endpoint) .with_context(|| format!("Failed to create side stream pipe {endpoint}"))?; - return Ok(LocalListener::Pipe(endpoint, server)); + Ok(LocalListener::Pipe(endpoint, server)) } } diff --git a/mesh-llm/proto/node.proto b/mesh-llm/proto/node.proto index ded15f0d..642d2e87 100644 --- a/mesh-llm/proto/node.proto +++ b/mesh-llm/proto/node.proto @@ -115,6 +115,7 @@ message ModelRuntimeDescriptor { optional string identity_hash = 2; optional uint32 context_length = 3; bool ready = 4; + optional string backend = 5; } message ModelMoeInfo { diff --git a/mesh-llm/src/api/mod.rs b/mesh-llm/src/api/mod.rs index 4927c0a4..337adfa2 100644 --- a/mesh-llm/src/api/mod.rs +++ b/mesh-llm/src/api/mod.rs @@ -251,6 +251,10 @@ impl MeshApi { self.inner.lock().await.primary_backend = Some(backend); } + pub async fn clear_primary_backend(&self) { + self.inner.lock().await.primary_backend = None; + } + pub async fn set_draft_name(&self, name: String) { self.inner.lock().await.draft_name = Some(name); } diff --git a/mesh-llm/src/cli/commands/auth.rs b/mesh-llm/src/cli/commands/auth.rs index 6d883999..d67ec1d9 100644 --- a/mesh-llm/src/cli/commands/auth.rs +++ b/mesh-llm/src/cli/commands/auth.rs @@ -785,6 +785,47 @@ mod tests { use super::*; use serial_test::serial; + fn is_headless_keychain_error(err: &crate::crypto::CryptoError) -> bool { + match err { + crate::crypto::CryptoError::KeychainUnavailable { reason } + | crate::crypto::CryptoError::KeychainAccessDenied { reason } => { + reason.contains("User interaction is not allowed.") + || reason.contains("User interaction is not allowed") + } + _ => false, + } + } + + fn is_headless_keychain_anyhow(err: &anyhow::Error) -> bool { + err.chain().any(|cause| { + cause + .downcast_ref::<crate::crypto::CryptoError>() + .is_some_and(is_headless_keychain_error) + }) + } + + fn keychain_set_or_skip(account: &str, secret: &str) -> bool { + match crate::crypto::keychain_set(KEYCHAIN_SERVICE, account, secret) { + Ok(()) => true, + Err(err) if is_headless_keychain_error(&err) => { + eprintln!("headless keychain access denied, skipping"); + false + } + Err(err) => panic!("unexpected keychain set failure: {err}"), + } + } + + fn keychain_get_or_skip(account: &str) -> Option<Option<String>> { + match crate::crypto::keychain_get(KEYCHAIN_SERVICE, account) { + Ok(value) => Some(value), + Err(err) if is_headless_keychain_error(&err) => { + eprintln!("headless keychain access denied, skipping"); + None + } + Err(err) => panic!("unexpected keychain get failure: {err}"), + } + } + #[test] fn defaults_to_keychain_for_new_keystore_when_available() { assert!(should_default_to_keychain(false, false, true)); @@ -832,7 +873,9 @@ mod tests { let account = crate::crypto::owner_keychain_account_for_path(&bad_path); let previous_secret = "previous-unlock-secret-do-not-lose"; - crate::crypto::keychain_set(KEYCHAIN_SERVICE, &account, previous_secret).unwrap(); + if !keychain_set_or_skip(&account, previous_secret) { + return; + } let result = run_init(Some(bad_path.clone()), true, false, true); assert!( @@ -840,7 +883,9 @@ mod tests { "run_init must fail when save cannot succeed" ); - let restored = crate::crypto::keychain_get(KEYCHAIN_SERVICE, &account).unwrap(); + let Some(restored) = keychain_get_or_skip(&account) else { + return; + }; assert_eq!( restored.as_deref(), Some(previous_secret), @@ -875,7 +920,9 @@ mod tests { "run_init must fail when save cannot succeed" ); - let residual = crate::crypto::keychain_get(KEYCHAIN_SERVICE, &account).unwrap(); + let Some(residual) = keychain_get_or_skip(&account) else { + return; + }; assert_eq!( residual, None, "a fresh init failure must leave no keychain entry behind" @@ -897,8 +944,14 @@ mod tests { std::fs::create_dir_all(&dir).unwrap(); let path = dir.join("owner-keystore.json"); - run_init(Some(path.clone()), false, false, false) - .expect("auth init should default to keychain when available"); + if let Err(err) = run_init(Some(path.clone()), false, false, false) { + if is_headless_keychain_anyhow(&err) { + eprintln!("headless keychain access denied, skipping"); + std::fs::remove_dir_all(&dir).ok(); + return; + } + panic!("auth init should default to keychain when available: {err}"); + } assert!(path.exists(), "keystore file should exist"); let info = keystore_metadata(&path).unwrap(); @@ -908,13 +961,25 @@ mod tests { ); let account = crate::crypto::owner_keychain_account_for_path(&path); - let stored = crate::crypto::keychain_get(KEYCHAIN_SERVICE, &account).unwrap(); + let Some(stored) = keychain_get_or_skip(&account) else { + std::fs::remove_dir_all(&dir).ok(); + return; + }; assert!( stored.is_some(), "keychain must have a passphrase entry for this keystore path" ); - let kp = load_owner_keypair_from_keychain(&path).expect("load via keychain must succeed"); + let kp = match load_owner_keypair_from_keychain(&path) { + Ok(value) => value, + Err(OwnerKeychainLoadError::Crypto(err)) if is_headless_keychain_error(&err) => { + eprintln!("headless keychain access denied, skipping"); + crate::crypto::keychain_delete(KEYCHAIN_SERVICE, &account).ok(); + std::fs::remove_dir_all(&dir).ok(); + return; + } + Err(err) => panic!("load via keychain must succeed: {err:?}"), + }; assert_eq!(kp.owner_id(), info.owner_id); crate::crypto::keychain_delete(KEYCHAIN_SERVICE, &account).ok(); diff --git a/mesh-llm/src/cli/commands/models/mod.rs b/mesh-llm/src/cli/commands/models/mod.rs index 043c38ca..f9cf2728 100644 --- a/mesh-llm/src/cli/commands/models/mod.rs +++ b/mesh-llm/src/cli/commands/models/mod.rs @@ -177,10 +177,17 @@ pub async fn run_model_download( ) -> Result<()> { let formatter = models_formatter(json_output); let details = show_exact_model(model_ref).await.ok(); - let download_ref = details - .as_ref() - .map(|d| d.exact_ref.as_str()) - .unwrap_or(model_ref); + let explicit_file_ref = model_ref.ends_with(".gguf") + || model_ref.ends_with(".safetensors") + || model_ref.ends_with(".safetensors.index.json"); + let download_ref = if explicit_file_ref { + model_ref + } else { + details + .as_ref() + .map(|d| d.exact_ref.as_str()) + .unwrap_or(model_ref) + }; let path = download_exact_ref(download_ref).await?; if !include_draft { return formatter.render_download(model_ref, &path, details.as_ref(), false, None); diff --git a/mesh-llm/src/cli/mod.rs b/mesh-llm/src/cli/mod.rs index bbc462ad..9d3151bf 100644 --- a/mesh-llm/src/cli/mod.rs +++ b/mesh-llm/src/cli/mod.rs @@ -246,13 +246,20 @@ pub(crate) struct Cli { #[arg(long)] pub(crate) auto: bool, - /// Model to serve (path, catalog name, HF exact ref, or HuggingFace URL). + /// Model to serve (path, catalog name, Hugging Face ref, repo shorthand, or URL). #[arg(long)] pub(crate) model: Vec<PathBuf>, /// Raw local GGUF file to serve directly (repeatable). - #[arg(long)] - pub(crate) gguf: Vec<PathBuf>, + #[arg(long = "gguf-file")] + pub(crate) gguf_file: Vec<PathBuf>, + + /// Raw local MLX model path to serve directly (repeatable). + /// + /// Accepts a model directory or a file inside one, such as + /// `config.json`, `tokenizer.json`, or `model.safetensors`. + #[arg(long = "mlx-file")] + pub(crate) mlx_file: Vec<PathBuf>, /// Explicit mmproj sidecar to pass to llama-server for the primary served model. #[arg(long, hide = true)] @@ -647,7 +654,11 @@ pub(crate) fn legacy_runtime_surface_warning( )); } - if !cli.model.is_empty() || !cli.gguf.is_empty() || cli.mmproj.is_some() { + if !cli.model.is_empty() + || !cli.gguf_file.is_empty() + || !cli.mlx_file.is_empty() + || cli.mmproj.is_some() + { return Some(format!( "⚠️ top-level serving flags now map to `mesh-llm serve`.\n Please use: {}", suggested_serve_command(original_args) diff --git a/mesh-llm/src/cli/models.rs b/mesh-llm/src/cli/models.rs index 274fc4a8..219eff55 100644 --- a/mesh-llm/src/cli/models.rs +++ b/mesh-llm/src/cli/models.rs @@ -44,7 +44,7 @@ pub enum ModelsCommand { }, /// Show details for one exact model reference. Show { - /// Exact catalog id, Hugging Face ref, or direct URL. + /// Exact catalog id, Hugging Face ref, repo shorthand, or direct URL. model: String, /// Emit JSON output. #[arg(long)] @@ -52,7 +52,7 @@ pub enum ModelsCommand { }, /// Download one exact model reference. Download { - /// Exact catalog id, Hugging Face ref, or direct URL. + /// Exact catalog id, Hugging Face ref, repo shorthand, or direct URL. model: String, /// Also download the recommended draft model for speculative decoding. #[arg(long)] diff --git a/mesh-llm/src/crypto/keychain.rs b/mesh-llm/src/crypto/keychain.rs index 7508f32d..8bd3a283 100644 --- a/mesh-llm/src/crypto/keychain.rs +++ b/mesh-llm/src/crypto/keychain.rs @@ -230,6 +230,17 @@ mod tests { format!("test-{}-{}", tag, rand::random::<u64>()) } + fn is_headless_keychain_error(err: &CryptoError) -> bool { + match err { + CryptoError::KeychainUnavailable { reason } + | CryptoError::KeychainAccessDenied { reason } => { + reason.contains("User interaction is not allowed.") + || reason.contains("User interaction is not allowed") + } + _ => false, + } + } + #[test] #[serial] fn round_trip_set_get_delete() { @@ -240,14 +251,41 @@ mod tests { let account = test_account("round-trip"); let secret = "correct horse battery staple"; - set_secret(KEYCHAIN_SERVICE, &account, secret).unwrap(); - let got = get_secret(KEYCHAIN_SERVICE, &account).unwrap(); + if let Err(err) = set_secret(KEYCHAIN_SERVICE, &account, secret) { + if is_headless_keychain_error(&err) { + eprintln!("headless keychain access denied, skipping"); + return; + } + panic!("unexpected keychain set failure: {err}"); + } + let got = match get_secret(KEYCHAIN_SERVICE, &account) { + Ok(value) => value, + Err(err) if is_headless_keychain_error(&err) => { + eprintln!("headless keychain access denied, skipping"); + return; + } + Err(err) => panic!("unexpected keychain get failure: {err}"), + }; assert_eq!(got.as_deref(), Some(secret)); - let removed = delete_secret(KEYCHAIN_SERVICE, &account).unwrap(); + let removed = match delete_secret(KEYCHAIN_SERVICE, &account) { + Ok(value) => value, + Err(err) if is_headless_keychain_error(&err) => { + eprintln!("headless keychain access denied, skipping"); + return; + } + Err(err) => panic!("unexpected keychain delete failure: {err}"), + }; assert!(removed); - let after = get_secret(KEYCHAIN_SERVICE, &account).unwrap(); + let after = match get_secret(KEYCHAIN_SERVICE, &account) { + Ok(value) => value, + Err(err) if is_headless_keychain_error(&err) => { + eprintln!("headless keychain access denied, skipping"); + return; + } + Err(err) => panic!("unexpected keychain get failure after delete: {err}"), + }; assert_eq!(after, None); } @@ -283,10 +321,29 @@ mod tests { return; } let account = test_account("overwrite"); - set_secret(KEYCHAIN_SERVICE, &account, "first").unwrap(); - set_secret(KEYCHAIN_SERVICE, &account, "second").unwrap(); + if let Err(err) = set_secret(KEYCHAIN_SERVICE, &account, "first") { + if is_headless_keychain_error(&err) { + eprintln!("headless keychain access denied, skipping"); + return; + } + panic!("unexpected first keychain set failure: {err}"); + } + if let Err(err) = set_secret(KEYCHAIN_SERVICE, &account, "second") { + if is_headless_keychain_error(&err) { + eprintln!("headless keychain access denied, skipping"); + return; + } + panic!("unexpected overwrite keychain set failure: {err}"); + } - let got = get_secret(KEYCHAIN_SERVICE, &account).unwrap(); + let got = match get_secret(KEYCHAIN_SERVICE, &account) { + Ok(value) => value, + Err(err) if is_headless_keychain_error(&err) => { + eprintln!("headless keychain access denied, skipping"); + return; + } + Err(err) => panic!("unexpected overwrite keychain get failure: {err}"), + }; assert_eq!(got.as_deref(), Some("second")); delete_secret(KEYCHAIN_SERVICE, &account).ok(); diff --git a/mesh-llm/src/inference/election.rs b/mesh-llm/src/inference/election.rs index 1a6e2875..7eb96b87 100644 --- a/mesh-llm/src/inference/election.rs +++ b/mesh-llm/src/inference/election.rs @@ -270,6 +270,16 @@ fn stop_requested(stop_rx: &watch::Receiver<bool>) -> bool { *stop_rx.borrow() } +fn local_backend_name(model: &Path) -> &'static str { + #[cfg(target_os = "macos")] + if let Some(dir) = crate::mlx::mlx_model_dir(model) { + if crate::mlx::is_mlx_model_dir(dir) { + return "MLX server"; + } + } + "llama-server" +} + async fn wait_for_peer_moe_ranking( model_name: &str, model_path: &Path, @@ -1868,8 +1878,13 @@ pub async fn election_loop( .await; llama_process = Some(process); if let Some(ref process) = llama_process { + let backend = if local_backend_name(&model) == "MLX server" { + "mlx" + } else { + "llama" + }; on_process(Some(LocalProcessInfo { - backend: "llama".into(), + backend: backend.into(), pid: process.handle.pid(), port: llama_port, context_length: process.context_length, @@ -1877,8 +1892,9 @@ pub async fn election_loop( } on_change(true, true); eprintln!( - "βœ… [{}] llama-server ready on internal port {llama_port}", - model_name + "βœ… [{}] {} ready on internal port {llama_port}", + model_name, + local_backend_name(&model) ); } else { // We're a worker in split mode. Find who the host is. @@ -2196,7 +2212,7 @@ async fn moe_election_loop( last_plan_change_at = tokio::time::Instant::now(); if matches!(role, MoePlacementRole::Standby) { - node.set_model_runtime_context_length(&model_name, None) + node.set_model_runtime_context_length(&model_name, None, None) .await; node.regossip().await; eprintln!( @@ -2285,8 +2301,13 @@ async fn moe_election_loop( current_local_port = Some(local_proxy_port); llama_process = Some(process); if let Some(ref process) = llama_process { + let backend = if local_backend_name(&model) == "MLX server" { + "mlx" + } else { + "llama" + }; on_process(Some(LocalProcessInfo { - backend: "llama".into(), + backend: backend.into(), pid: process.handle.pid(), port: llama_port, context_length: process.context_length, @@ -2314,7 +2335,7 @@ async fn moe_election_loop( } } else if plan.active_ids.len() == 1 { if model_fits { - node.set_model_runtime_context_length(&model_name, None) + node.set_model_runtime_context_length(&model_name, None, None) .await; node.regossip().await; eprintln!( @@ -2389,8 +2410,13 @@ async fn moe_election_loop( current_local_port = Some(local_proxy_port); llama_process = Some(process); if let Some(ref process) = llama_process { + let backend = if local_backend_name(&model) == "MLX server" { + "mlx" + } else { + "llama" + }; on_process(Some(LocalProcessInfo { - backend: "llama".into(), + backend: backend.into(), pid: process.handle.pid(), port: llama_port, context_length: process.context_length, @@ -2414,7 +2440,7 @@ async fn moe_election_loop( } } } else { - node.set_model_runtime_context_length(&model_name, None) + node.set_model_runtime_context_length(&model_name, None, None) .await; node.regossip().await; eprintln!("⚠️ [{}] MoE model too large to serve entirely ({:.1}GB model, {:.1}GB capacity) β€” waiting for peers", @@ -2467,7 +2493,7 @@ async fn moe_election_loop( } Err(e) => { eprintln!(" ❌ moe-split failed: {e}"); - node.set_model_runtime_context_length(&model_name, None) + node.set_model_runtime_context_length(&model_name, None, None) .await; node.regossip().await; if peer_rx.changed().await.is_err() { @@ -2550,8 +2576,13 @@ async fn moe_election_loop( current_local_port = Some(local_proxy_port); llama_process = Some(process); if let Some(ref process) = llama_process { + let backend = if local_backend_name(&shard_path) == "MLX server" { + "mlx" + } else { + "llama" + }; on_process(Some(LocalProcessInfo { - backend: "llama".into(), + backend: backend.into(), pid: process.handle.pid(), port: llama_port, context_length: process.context_length, @@ -2586,7 +2617,7 @@ async fn moe_election_loop( " ⚠️ [{}] Refusing to enter MoE split mode on this node until the shard validates", model_name ); - node.set_model_runtime_context_length(&model_name, None) + node.set_model_runtime_context_length(&model_name, None, None) .await; node.regossip().await; } @@ -2739,6 +2770,30 @@ async fn start_llama( ctx_size_override, pinned_gpu, } = params; + // ── MLX native backend: if model normalizes to a safetensors directory, run in-process ── + #[cfg(target_os = "macos")] + if let Some(dir) = crate::mlx::mlx_model_dir(model) { + if crate::mlx::is_mlx_model_dir(dir) { + let llama_port = match find_free_port().await { + Ok(p) => p, + Err(e) => { + eprintln!(" Failed to find free port: {e}"); + return None; + } + }; + eprintln!("🍎 MLX native backend: loading {model_name}..."); + match crate::mlx::start_mlx_server(dir, model_name.to_string(), llama_port).await { + Ok(process) => { + eprintln!("βœ… MLX server ready on port {llama_port}"); + return Some((llama_port, process)); + } + Err(e) => { + eprintln!(" ❌ MLX server failed: {e}"); + return None; + } + } + } + } let my_vram = node.vram_bytes(); let local_launch_vram = effective_local_launch_vram(my_vram, pinned_gpu); let model_bytes = total_model_bytes(model); @@ -2871,6 +2926,26 @@ async fn start_llama( None }; + #[cfg(target_os = "macos")] + let should_start_mlx = rpc_ports.is_empty() && crate::mlx::model::is_mlx_model_dir(model); + #[cfg(not(target_os = "macos"))] + let should_start_mlx = false; + + if should_start_mlx { + #[cfg(target_os = "macos")] + { + match crate::mlx::server::start_mlx_server(model, model_name.to_string(), llama_port) + .await + { + Ok(process) => return Some((llama_port, process)), + Err(e) => { + eprintln!(" Failed to start MLX server: {e}"); + return None; + } + } + } + } + match launch::start_llama_server( runtime, bin_dir, diff --git a/mesh-llm/src/inference/launch.rs b/mesh-llm/src/inference/launch.rs index 421270a0..1bbda4d4 100644 --- a/mesh-llm/src/inference/launch.rs +++ b/mesh-llm/src/inference/launch.rs @@ -215,17 +215,51 @@ fn resolve_binary_path( pub struct InferenceServerHandle { pid: u32, expected_exit: Arc<AtomicBool>, + shutdown_tx: Option<tokio::sync::watch::Sender<bool>>, expected_comm: String, expected_start_time: Option<i64>, pub(crate) _pidfile_guard: Option<crate::runtime::instance::PidfileGuard>, } impl InferenceServerHandle { + fn process( + pid: u32, + expected_exit: Arc<AtomicBool>, + expected_comm: String, + expected_start_time: Option<i64>, + pidfile_guard: Option<crate::runtime::instance::PidfileGuard>, + ) -> Self { + Self { + pid, + expected_exit, + shutdown_tx: None, + expected_comm, + expected_start_time, + _pidfile_guard: pidfile_guard, + } + } + + #[cfg(target_os = "macos")] + pub(crate) fn in_process(shutdown_tx: tokio::sync::watch::Sender<bool>) -> Self { + Self { + pid: std::process::id(), + expected_exit: Arc::new(AtomicBool::new(true)), + shutdown_tx: Some(shutdown_tx), + expected_comm: String::new(), + expected_start_time: None, + _pidfile_guard: None, + } + } + pub fn pid(&self) -> u32 { self.pid } pub async fn shutdown(&self) { + if let Some(tx) = &self.shutdown_tx { + let _ = tx.send(true); + return; + } self.expected_exit.store(true, Ordering::Relaxed); terminate_process_with_wait( self.pid, @@ -1361,13 +1395,13 @@ pub async fn start_llama_server( }; let pidfile_guard = runtime.write_pidfile("llama-server", &metadata)?; let expected_exit = Arc::new(AtomicBool::new(false)); - let handle = InferenceServerHandle { + let handle = InferenceServerHandle::process( pid, - expected_exit: expected_exit.clone(), - expected_comm: llama_server_name, - expected_start_time: child_started_at, - _pidfile_guard: Some(pidfile_guard), - }; + expected_exit.clone(), + llama_server_name, + child_started_at, + Some(pidfile_guard), + ); let (death_tx, death_rx) = tokio::sync::oneshot::channel(); let pidfile_path = runtime.pidfile_path("llama-server"); tokio::spawn(async move { diff --git a/mesh-llm/src/lib.rs b/mesh-llm/src/lib.rs index 23cd1c9c..53e48ac2 100644 --- a/mesh-llm/src/lib.rs +++ b/mesh-llm/src/lib.rs @@ -3,6 +3,8 @@ mod cli; pub mod crypto; mod inference; mod mesh; +#[cfg(target_os = "macos")] +mod mlx; mod models; mod network; mod plugin; diff --git a/mesh-llm/src/mesh/mod.rs b/mesh-llm/src/mesh/mod.rs index f19b8ea0..c616eb31 100644 --- a/mesh-llm/src/mesh/mod.rs +++ b/mesh-llm/src/mesh/mod.rs @@ -195,6 +195,8 @@ pub struct ModelRuntimeDescriptor { pub identity_hash: Option<String>, #[serde(skip_serializing_if = "Option::is_none")] pub context_length: Option<u32>, + #[serde(skip_serializing_if = "Option::is_none")] + pub backend: Option<String>, pub ready: bool, } @@ -237,7 +239,10 @@ pub fn infer_served_model_descriptors( if identity.local_file_name.is_none() { identity.local_file_name = Some(format!("{model_name}.gguf")); } - descriptor_from_identity(model_name, identity) + match primary_model_path { + Some(path) => descriptor_from_known_path(model_name, path, identity), + None => descriptor_from_identity(model_name, identity), + } } else { descriptor_from_model_path( model_name, @@ -466,7 +471,7 @@ fn descriptor_from_model_path( ) -> Option<ServedModelDescriptor> { let mut identity = identity_from_model_path(model_name, path)?; identity.is_primary = is_primary; - Some(descriptor_from_identity(model_name, identity)) + Some(descriptor_from_known_path(model_name, path, identity)) } fn descriptor_from_identity( @@ -474,7 +479,19 @@ fn descriptor_from_identity( mut identity: ServedModelIdentity, ) -> ServedModelDescriptor { identity.model_name = model_name.to_string(); - let path = crate::models::find_model_path(model_name); + descriptor_from_known_path( + model_name, + &crate::models::find_model_path(model_name), + identity, + ) +} + +fn descriptor_from_known_path( + model_name: &str, + path: &std::path::Path, + mut identity: ServedModelIdentity, +) -> ServedModelDescriptor { + identity.model_name = model_name.to_string(); let catalog = crate::models::find_catalog_model_exact(model_name); let mut topology = crate::models::infer_local_model_topology(&path, catalog); if topology.is_none() { @@ -495,7 +512,7 @@ fn descriptor_from_identity( }); } } - enrich_topology_with_local_shared_ranking(path.as_path(), &mut topology); + enrich_topology_with_local_shared_ranking(path, &mut topology); let mut capabilities = crate::models::capabilities::infer_local_model_capabilities(model_name, &path, catalog); capabilities.moe = capabilities.moe @@ -1849,6 +1866,7 @@ impl Node { &self, model_name: &str, context_length: Option<u32>, + backend: Option<&str>, ) { let identity_hash = self .served_model_descriptors @@ -1858,24 +1876,29 @@ impl Node { .find(|descriptor| descriptor.identity.model_name == model_name) .and_then(|descriptor| descriptor.identity.identity_hash.clone()); let mut runtimes = self.model_runtime_descriptors.lock().await; - if let Some(context_length) = context_length { - if let Some(runtime) = runtimes - .iter_mut() - .find(|runtime| runtime.model_name == model_name) - { - runtime.identity_hash = identity_hash.or_else(|| runtime.identity_hash.clone()); - runtime.context_length = Some(context_length); - runtime.ready = true; - } else { - runtimes.push(ModelRuntimeDescriptor { - model_name: model_name.to_string(), - identity_hash, - context_length: Some(context_length), - ready: true, - }); - } - } else { + if context_length.is_none() { runtimes.retain(|runtime| runtime.model_name != model_name); + return; + } + + if let Some(runtime) = runtimes + .iter_mut() + .find(|runtime| runtime.model_name == model_name) + { + runtime.identity_hash = identity_hash.or_else(|| runtime.identity_hash.clone()); + runtime.context_length = context_length; + runtime.backend = backend + .map(str::to_string) + .or_else(|| runtime.backend.clone()); + runtime.ready = true; + } else { + runtimes.push(ModelRuntimeDescriptor { + model_name: model_name.to_string(), + identity_hash, + context_length, + backend: backend.map(str::to_string), + ready: true, + }); } } @@ -1900,6 +1923,7 @@ impl Node { model_name: model_name.to_string(), identity_hash, context_length: None, + backend: None, ready: false, }); } @@ -4199,148 +4223,4577 @@ async fn send_push_error(send: &mut iroh::endpoint::SendStream, msg: &str) -> an Ok(()) } -/// Generate a mesh ID for a new mesh. -/// Named meshes: `sha256("mesh-llm:" + name + ":" + nostr_pubkey)` β€” deterministic, unique per creator. -/// Unnamed meshes: random UUID, persisted to `~/.mesh-llm/mesh-id`. -pub fn generate_mesh_id(name: Option<&str>, nostr_pubkey: Option<&str>) -> String { - if let Some(name) = name { - use std::hash::{Hash, Hasher}; - let mut hasher = std::collections::hash_map::DefaultHasher::new(); - "mesh-llm:".hash(&mut hasher); - name.hash(&mut hasher); - if let Some(pk) = nostr_pubkey { - pk.hash(&mut hasher); - } - format!("{:016x}", hasher.finish()) - } else { - // Try to load persisted mesh-id - let path = mesh_id_path(); - if let Ok(id) = std::fs::read_to_string(&path) { - let id = id.trim().to_string(); - if !id.is_empty() { - return id; - } - } - // Generate new random ID and persist - let id = format!( - "{:016x}{:016x}", - rand::random::<u64>(), - rand::random::<u64>() - ); - if let Some(parent) = path.parent() { - let _ = std::fs::create_dir_all(parent); - } - let _ = std::fs::write(&path, &id); - id - } -} +#[cfg(test)] +mod tests { + use super::*; + use crate::proto::node::{GossipFrame, NodeRole, PeerAnnouncement, RouteTableRequest}; + use tokio::sync::watch; -fn mesh_id_path() -> std::path::PathBuf { - dirs::home_dir() - .unwrap_or_else(|| std::path::PathBuf::from(".")) - .join(".mesh-llm") - .join("mesh-id") -} + async fn make_test_node(role: super::NodeRole) -> Result<Node> { + use iroh::endpoint::QuicTransportConfig; -/// Save the mesh ID of the last mesh we successfully joined. -pub fn save_last_mesh_id(mesh_id: &str) { - let path = dirs::home_dir() - .unwrap_or_else(|| std::path::PathBuf::from(".")) - .join(".mesh-llm") - .join("last-mesh"); - if let Some(parent) = path.parent() { - let _ = std::fs::create_dir_all(parent); + let transport_config = QuicTransportConfig::builder() + .max_concurrent_bidi_streams(128u32.into()) + .build(); + let endpoint = Endpoint::empty_builder() + .secret_key(SecretKey::generate(&mut rand::rng())) + .alpns(vec![ALPN_V1.to_vec(), ALPN_V0.to_vec()]) + .transport_config(transport_config) + .bind_addr(std::net::SocketAddr::from(([127, 0, 0, 1], 0)))? + .bind() + .await?; + + let (peer_change_tx, peer_change_rx) = watch::channel(0usize); + let (inflight_change_tx, _) = watch::channel(0u64); + let (tunnel_tx, _tunnel_rx) = tokio::sync::mpsc::channel(8); + let (tunnel_http_tx, _tunnel_http_rx) = tokio::sync::mpsc::channel(8); + + let node = Node { + endpoint, + public_addr: None, + state: Arc::new(Mutex::new(MeshState { + peers: HashMap::new(), + connections: HashMap::new(), + remote_tunnel_maps: HashMap::new(), + dead_peers: HashSet::new(), + seen_plugin_messages: HashSet::new(), + seen_plugin_message_order: VecDeque::new(), + policy_rejected_peers: HashMap::new(), + })), + role: Arc::new(Mutex::new(role)), + models: Arc::new(Mutex::new(Vec::new())), + model_source: Arc::new(Mutex::new(None)), + serving_models: Arc::new(Mutex::new(Vec::new())), + served_model_descriptors: Arc::new(Mutex::new(Vec::new())), + model_runtime_descriptors: Arc::new(Mutex::new(Vec::new())), + hosted_models: Arc::new(Mutex::new(Vec::new())), + llama_ready: Arc::new(Mutex::new(false)), + available_models: Arc::new(Mutex::new(Vec::new())), + requested_models: Arc::new(Mutex::new(Vec::new())), + model_demand: Arc::new(std::sync::Mutex::new(HashMap::new())), + mesh_id: Arc::new(Mutex::new(None)), + accepting: Arc::new(( + tokio::sync::Notify::new(), + std::sync::atomic::AtomicBool::new(false), + )), + vram_bytes: 64 * 1024 * 1024 * 1024, + peer_change_tx, + peer_change_rx, + inflight_requests: Arc::new(std::sync::atomic::AtomicUsize::new(0)), + inflight_change_tx, + tunnel_tx, + tunnel_http_tx, + plugin_manager: Arc::new(Mutex::new(None)), + display_name: Arc::new(Mutex::new(None)), + owner_attestation: Arc::new(Mutex::new(None)), + owner_summary: Arc::new(Mutex::new(OwnershipSummary::default())), + trust_store: Arc::new(Mutex::new(TrustStore::default())), + trust_policy: TrustPolicy::Off, + enumerate_host: false, + gpu_name: None, + hostname: None, + is_soc: None, + gpu_vram: None, + gpu_reserved_bytes: None, + gpu_mem_bandwidth_gbps: Arc::new(tokio::sync::Mutex::new(None)), + gpu_compute_tflops_fp32: Arc::new(tokio::sync::Mutex::new(None)), + gpu_compute_tflops_fp16: Arc::new(tokio::sync::Mutex::new(None)), + config_state: Arc::new(tokio::sync::Mutex::new( + crate::runtime::config_state::ConfigState::default(), + )), + config_revision_tx: { + let (tx, _rx) = tokio::sync::watch::channel(0u64); + Arc::new(tx) + }, + }; + + let accept_node = node.clone(); + tokio::spawn(async move { + accept_node.accept_loop().await; + }); + + Ok(node) } - let _ = std::fs::write(&path, mesh_id); -} -/// Load the mesh ID of the last mesh we successfully joined. -pub fn load_last_mesh_id() -> Option<String> { - let path = dirs::home_dir() - .unwrap_or_else(|| std::path::PathBuf::from(".")) - .join(".mesh-llm") - .join("last-mesh"); - std::fs::read_to_string(&path) - .ok() - .map(|s| s.trim().to_string()) - .filter(|s| !s.is_empty()) -} + #[test] + fn test_merge_demand_takes_max() { + let mut ours = HashMap::new(); + ours.insert( + "GLM".into(), + ModelDemand { + last_active: 100, + request_count: 50, + }, + ); + ours.insert( + "Hermes".into(), + ModelDemand { + last_active: 200, + request_count: 10, + }, + ); -// --------------------------------------------------------------------------- -// Public-to-private identity transition -// --------------------------------------------------------------------------- + let mut theirs = HashMap::new(); + theirs.insert( + "GLM".into(), + ModelDemand { + last_active: 150, + request_count: 30, + }, + ); + theirs.insert( + "Qwen".into(), + ModelDemand { + last_active: 300, + request_count: 5, + }, + ); -fn was_public_path() -> std::path::PathBuf { - dirs::home_dir() - .unwrap_or_else(|| std::path::PathBuf::from(".")) - .join(".mesh-llm") - .join("was-public") -} + merge_demand(&mut ours, &theirs); + + // GLM: max(100,150)=150 for last_active, max(50,30)=50 for count + assert_eq!(ours["GLM"].last_active, 150); + assert_eq!(ours["GLM"].request_count, 50); + // Hermes: unchanged (not in theirs) + assert_eq!(ours["Hermes"].last_active, 200); + assert_eq!(ours["Hermes"].request_count, 10); + // Qwen: new entry from theirs + assert_eq!(ours["Qwen"].last_active, 300); + assert_eq!(ours["Qwen"].request_count, 5); + } + + #[test] + fn test_merge_demand_empty_maps() { + let mut ours = HashMap::new(); + let theirs = HashMap::new(); + merge_demand(&mut ours, &theirs); + assert!(ours.is_empty()); + + let mut theirs2 = HashMap::new(); + theirs2.insert( + "GLM".into(), + ModelDemand { + last_active: 100, + request_count: 1, + }, + ); + merge_demand(&mut ours, &theirs2); + assert_eq!(ours.len(), 1); + assert_eq!(ours["GLM"].request_count, 1); + } + + #[test] + fn test_merge_demand_idempotent() { + let mut ours = HashMap::new(); + ours.insert( + "GLM".into(), + ModelDemand { + last_active: 100, + request_count: 50, + }, + ); -/// Record that this node was started in public mode (--auto / --publish / --mesh-name). -/// Called at startup so we can detect a publicβ†’private transition next time. -pub fn mark_was_public() { - let path = was_public_path(); - if let Some(parent) = path.parent() { - let _ = std::fs::create_dir_all(parent); + let theirs = ours.clone(); + merge_demand(&mut ours, &theirs); + + assert_eq!(ours["GLM"].last_active, 100); + assert_eq!(ours["GLM"].request_count, 50); } - let _ = std::fs::write(&path, "1"); -} -/// Returns true if the previous run was public (marker file exists). -pub fn was_previously_public() -> bool { - was_public_path().exists() -} + #[test] + fn test_demand_ttl_filtering() { + let now = now_secs(); + let mut demand = HashMap::new(); + + // Recent β€” should survive + demand.insert( + "Recent".into(), + ModelDemand { + last_active: now - 60, // 1 min ago + request_count: 10, + }, + ); + // Stale β€” should be filtered + demand.insert( + "Stale".into(), + ModelDemand { + last_active: now - DEMAND_TTL_SECS - 100, // past TTL + request_count: 100, + }, + ); -/// Clear identity files (key, nostr.nsec, mesh-id, last-mesh, was-public) so the -/// next start gets a completely fresh identity. Called when transitioning from -/// public β†’ private to avoid reusing a publicly-known identity in a private mesh. -pub fn clear_public_identity() { - let home = dirs::home_dir().unwrap_or_else(|| std::path::PathBuf::from(".")); - let dir = home.join(".mesh-llm"); - let mut ok = true; - for name in &["key", "nostr.nsec", "mesh-id", "last-mesh"] { - let p = dir.join(name); - if p.exists() { - if std::fs::remove_file(&p).is_ok() { - tracing::info!("Cleared {}", p.display()); - } else { - tracing::warn!("Failed to clear {}", p.display()); - ok = false; - } - } + let filtered: HashMap<String, ModelDemand> = demand + .into_iter() + .filter(|(_, d)| (now - d.last_active) < DEMAND_TTL_SECS) + .collect(); + + assert_eq!(filtered.len(), 1); + assert!(filtered.contains_key("Recent")); + assert!(!filtered.contains_key("Stale")); } - // Only remove the marker after identity files are gone, so a failed - // cleanup is retried on the next private start. - let marker = dir.join("was-public"); - if ok { - let _ = std::fs::remove_file(&marker); - } else { - tracing::warn!("Keeping was-public marker β€” will retry cleanup next start"); + + #[test] + fn test_demand_serialization_roundtrip() { + let mut demand: HashMap<String, ModelDemand> = HashMap::new(); + demand.insert( + "GLM".into(), + ModelDemand { + last_active: 1772309000, + request_count: 42, + }, + ); + + let json = serde_json::to_string(&demand).unwrap(); + let decoded: HashMap<String, ModelDemand> = serde_json::from_str(&json).unwrap(); + + assert_eq!(decoded["GLM"].last_active, 1772309000); + assert_eq!(decoded["GLM"].request_count, 42); } -} -/// Load secret key from ~/.mesh-llm/key, or create a new one and save it. -async fn load_or_create_key() -> Result<SecretKey> { - let key_path = default_node_key_path()?; - let dir = key_path - .parent() - .ok_or_else(|| anyhow::anyhow!("Invalid node key path {}", key_path.display()))?; - ensure_private_node_key_dir(dir)?; + #[test] + fn test_demand_deserialization_missing_field() { + // Simulate old gossip message without model_demand field + // Just verify ModelDemand defaults work + let d = ModelDemand::default(); + assert_eq!(d.last_active, 0); + assert_eq!(d.request_count, 0); - if key_path.exists() { - ensure_private_node_key_file(&key_path)?; - let hex = tokio::fs::read_to_string(&key_path).await?; - let bytes = hex::decode(hex.trim())?; - if bytes.len() != 32 { - anyhow::bail!("Invalid key length in {}", key_path.display()); + // Verify HashMap<String, ModelDemand> defaults to empty + let empty: HashMap<String, ModelDemand> = Default::default(); + assert!(empty.is_empty()); + + // The real test: serde default on a struct with model_demand + #[derive(Deserialize, Default)] + struct TestStruct { + #[serde(default)] + model_demand: HashMap<String, ModelDemand>, + #[serde(default)] + requested_models: Vec<String>, + } + let parsed: TestStruct = serde_json::from_str("{}").unwrap(); + assert!(parsed.model_demand.is_empty()); + assert!(parsed.requested_models.is_empty()); + } + + #[test] + fn test_peer_announcement_gpu_serde_roundtrip() { + // Test that gpu_name and hostname fields serialize and deserialize correctly + #[derive(Serialize, Deserialize, Debug, PartialEq)] + struct TestAnnouncement { + #[serde(default)] + gpu_name: Option<String>, + #[serde(default)] + hostname: Option<String>, } - let key = SecretKey::from_bytes(&bytes.try_into().unwrap()); - tracing::info!("Loaded key from {}", key_path.display()); - return Ok(key); - } - let key = SecretKey::generate(&mut rand::rng()); + let test = TestAnnouncement { + gpu_name: Some("NVIDIA A100".to_string()), + hostname: Some("worker-01".to_string()), + }; + + let json = serde_json::to_string(&test).unwrap(); + let decoded: TestAnnouncement = serde_json::from_str(&json).unwrap(); + + assert_eq!(decoded.gpu_name, Some("NVIDIA A100".to_string())); + assert_eq!(decoded.hostname, Some("worker-01".to_string())); + } + + #[test] + fn test_peer_announcement_backward_compat_no_hw_fields() { + // Simulate old gossip message without gpu_name or hostname + #[derive(Deserialize, Debug)] + struct TestAnnouncement { + #[serde(default)] + gpu_name: Option<String>, + #[serde(default)] + hostname: Option<String>, + } + + let json = r#"{"other_field": "value"}"#; + let decoded: TestAnnouncement = serde_json::from_str(json).unwrap(); + + assert_eq!(decoded.gpu_name, None); + assert_eq!(decoded.hostname, None); + } + + #[test] + fn test_peer_announcement_backward_compat_with_hw_fields() { + // Simulate new gossip message with gpu_name and hostname + #[derive(Deserialize, Debug)] + struct TestAnnouncement { + #[serde(default)] + gpu_name: Option<String>, + #[serde(default)] + hostname: Option<String>, + } + + let json = r#"{"gpu_name": "NVIDIA H100", "hostname": "gpu-server-02"}"#; + let decoded: TestAnnouncement = serde_json::from_str(json).unwrap(); + + assert_eq!(decoded.gpu_name, Some("NVIDIA H100".to_string())); + assert_eq!(decoded.hostname, Some("gpu-server-02".to_string())); + } + + #[test] + fn test_peer_announcement_hostname_serde_roundtrip() { + // Test hostname-only roundtrip + #[derive(Serialize, Deserialize, Debug, PartialEq)] + struct TestAnnouncement { + #[serde(default)] + gpu_name: Option<String>, + #[serde(default)] + hostname: Option<String>, + } + + let test = TestAnnouncement { + gpu_name: None, + hostname: Some("compute-node-42".to_string()), + }; + + let json = serde_json::to_string(&test).unwrap(); + let decoded: TestAnnouncement = serde_json::from_str(&json).unwrap(); + + assert_eq!(decoded.hostname, Some("compute-node-42".to_string())); + assert_eq!(decoded.gpu_name, None); + } + + #[test] + fn test_peer_payload_hw_fields() { + // Test that PeerPayload includes gpu_name and hostname fields + #[derive(Serialize, Debug)] + struct TestPeerPayload { + id: String, + gpu_name: Option<String>, + hostname: Option<String>, + } + + let payload = TestPeerPayload { + id: "peer-123".to_string(), + gpu_name: Some("NVIDIA A100".to_string()), + hostname: Some("worker-01".to_string()), + }; + + let json = serde_json::to_string(&payload).unwrap(); + let value: serde_json::Value = serde_json::from_str(&json).unwrap(); + + assert_eq!(value["gpu_name"], "NVIDIA A100"); + assert_eq!(value["hostname"], "worker-01"); + } + + #[test] + fn test_enumerate_host_false_omits_hw_fields_in_announcement() { + let enumerate_host = false; + let gpu_name: Option<String> = Some("NVIDIA RTX 5090".to_string()); + let hostname: Option<String> = Some("carrack".to_string()); + let gpu_vram: Option<String> = Some("34359738368".to_string()); + + let gossip_gpu_name = if enumerate_host { + gpu_name.clone() + } else { + None + }; + let gossip_hostname = if enumerate_host { + hostname.clone() + } else { + None + }; + let gossip_gpu_vram = if enumerate_host { + gpu_vram.clone() + } else { + None + }; + + assert_eq!(gossip_gpu_name, None); + assert_eq!(gossip_hostname, None); + assert_eq!(gossip_gpu_vram, None); + } + + #[test] + fn test_enumerate_host_true_includes_hw_fields_in_announcement() { + let enumerate_host = true; + let gpu_name: Option<String> = Some("NVIDIA RTX 5090".to_string()); + let hostname: Option<String> = Some("carrack".to_string()); + let gpu_vram: Option<String> = Some("34359738368".to_string()); + + let gossip_gpu_name = if enumerate_host { + gpu_name.clone() + } else { + None + }; + let gossip_hostname = if enumerate_host { + hostname.clone() + } else { + None + }; + let gossip_gpu_vram = if enumerate_host { + gpu_vram.clone() + } else { + None + }; + + assert_eq!(gossip_gpu_name, Some("NVIDIA RTX 5090".to_string())); + assert_eq!(gossip_hostname, Some("carrack".to_string())); + assert_eq!(gossip_gpu_vram, Some("34359738368".to_string())); + } + + #[test] + fn test_is_soc_always_included_regardless_of_enumerate_host() { + for enumerate_host in [false, true] { + let is_soc: Option<bool> = Some(true); + let gpu_name: Option<String> = Some("Tegra AGX Orin".to_string()); + + let gossip_gpu_name = if enumerate_host { + gpu_name.clone() + } else { + None + }; + + assert_eq!(is_soc, Some(true), "is_soc must always be sent"); + if enumerate_host { + assert!(gossip_gpu_name.is_some()); + } else { + assert!(gossip_gpu_name.is_none()); + } + } + } + + #[test] + fn test_peer_announcement_backward_compat_is_soc_gpu_vram() { + #[derive(Deserialize, Debug)] + struct TestAnnouncement { + #[serde(default)] + is_soc: Option<bool>, + #[serde(default)] + gpu_vram: Option<String>, + } + + let json = r#"{"other_field": "value"}"#; + let decoded: TestAnnouncement = serde_json::from_str(json).unwrap(); + assert_eq!( + decoded.is_soc, None, + "old nodes without is_soc should default to None" + ); + assert_eq!( + decoded.gpu_vram, None, + "old nodes without gpu_vram should default to None" + ); + } + + #[test] + fn test_peer_announcement_with_bandwidth_serde_roundtrip() { + #[derive(Serialize, Deserialize, Debug, PartialEq)] + struct TestAnnouncement { + #[serde(default)] + gpu_bandwidth_gbps: Option<String>, + } + + let test = TestAnnouncement { + gpu_bandwidth_gbps: Some("1671.7,722.2".to_string()), + }; + + let json = serde_json::to_string(&test).unwrap(); + let decoded: TestAnnouncement = serde_json::from_str(&json).unwrap(); + + assert_eq!(decoded.gpu_bandwidth_gbps, Some("1671.7,722.2".to_string())); + } + + #[test] + fn test_peer_announcement_backward_compat_no_bandwidth_field() { + #[derive(Deserialize, Debug)] + struct TestAnnouncement { + #[serde(default)] + gpu_bandwidth_gbps: Option<String>, + } + + let json = r#"{"other_field": "value"}"#; + let decoded: TestAnnouncement = serde_json::from_str(json).unwrap(); + + assert_eq!(decoded.gpu_bandwidth_gbps, None); + } + + fn make_valid_gossip_frame() -> GossipFrame { + GossipFrame { + gen: NODE_PROTOCOL_GENERATION, + sender_id: vec![0u8; 32], + peers: vec![PeerAnnouncement { + endpoint_id: vec![0u8; 32], + role: NodeRole::Worker as i32, + ..Default::default() + }], + } + } + + #[test] + fn protocol_from_alpn_supports_v1_and_legacy_v0() { + assert_eq!(protocol_from_alpn(ALPN_V1), ControlProtocol::ProtoV1); + assert_eq!(protocol_from_alpn(ALPN_V0), ControlProtocol::JsonV0); + assert_eq!( + protocol_from_alpn(b"mesh-llm/999"), + ControlProtocol::ProtoV1 + ); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn legacy_v0_and_post_proto_nodes_interoperate_over_real_connection() -> Result<()> { + use iroh::endpoint::QuicTransportConfig; + + let post_node = make_test_node(super::NodeRole::Host { http_port: 9337 }).await?; + let post_id = post_node.id(); + post_node + .set_serving_models(vec!["post-model".to_string()]) + .await; + post_node + .set_hosted_models(vec!["post-model".to_string()]) + .await; + post_node + .set_mesh_id("compat-mesh-01020304".to_string()) + .await; + post_node.start_accepting(); + + let legacy_endpoint = Endpoint::empty_builder() + .secret_key(SecretKey::generate(&mut rand::rng())) + .alpns(vec![ALPN_V0.to_vec()]) + .transport_config( + QuicTransportConfig::builder() + .max_concurrent_bidi_streams(128u32.into()) + .build(), + ) + .bind_addr(std::net::SocketAddr::from(([127, 0, 0, 1], 0)))? + .bind() + .await?; + let legacy_id = legacy_endpoint.id(); + let legacy_addr = legacy_endpoint.addr(); + let legacy_ann = super::PeerAnnouncementV0 { + addr: legacy_addr.clone(), + role: super::NodeRole::Host { http_port: 9444 }, + models: vec!["legacy-model".to_string()], + vram_bytes: 48 * 1024 * 1024 * 1024, + model_source: Some("legacy-model.gguf".to_string()), + serving: Some("legacy-model".to_string()), + serving_models: vec!["legacy-model".to_string()], + available_models: vec!["legacy-model".to_string()], + requested_models: Vec::new(), + version: Some("0.50.0".to_string()), + model_demand: HashMap::new(), + mesh_id: Some("compat-mesh-01020304".to_string()), + gpu_name: Some("Legacy GPU".to_string()), + hostname: Some("legacy-peer".to_string()), + is_soc: Some(false), + gpu_vram: Some((48_u64 * 1024 * 1024 * 1024).to_string()), + gpu_reserved_bytes: None, + gpu_mem_bandwidth_gbps: None, + gpu_compute_tflops_fp32: None, + gpu_compute_tflops_fp16: None, + available_model_sizes: HashMap::new(), + served_model_descriptors: vec![], + served_model_runtime: vec![], + }; + let legacy_route_table = RoutingTable { + hosts: vec![RouteEntry { + model: "legacy-model".to_string(), + node_id: legacy_id.fmt_short().to_string(), + endpoint_id: legacy_id, + vram_gb: 48.0, + }], + mesh_id: Some("compat-mesh-01020304".to_string()), + }; + + let server = tokio::spawn(async move { + let incoming = + tokio::time::timeout(std::time::Duration::from_secs(5), legacy_endpoint.accept()) + .await + .expect("legacy endpoint should get an incoming connection") + .expect("accept loop should yield one incoming connection"); + let mut accepting = incoming.accept().expect("legacy accept should succeed"); + let alpn = accepting.alpn().await.expect("ALPN should be available"); + assert_eq!(alpn, ALPN_V0, "new node must fall back to legacy ALPN"); + let conn = accepting + .await + .expect("legacy connection handshake should complete"); + assert_eq!(conn.alpn(), ALPN_V0); + + let (mut send_gossip, mut recv_gossip) = + tokio::time::timeout(std::time::Duration::from_secs(5), conn.accept_bi()) + .await + .expect("post node should open initial gossip stream") + .expect("initial gossip stream should be accepted"); + let mut stream_type = [0u8; 1]; + recv_gossip + .read_exact(&mut stream_type) + .await + .expect("legacy server must read gossip stream type"); + assert_eq!(stream_type[0], STREAM_GOSSIP); + let gossip_buf = read_len_prefixed(&mut recv_gossip) + .await + .expect("legacy server must read JSON gossip frame"); + let received_anns: Vec<super::PeerAnnouncementV0> = + serde_json::from_slice(&gossip_buf).expect("legacy gossip must decode as JSON"); + assert!( + received_anns + .iter() + .any(|ann| ann.addr.id == post_id + && ann.serving.as_deref() == Some("post-model")), + "initial legacy gossip response should include the post-protobuf host announcement" + ); + let legacy_gossip_body = serde_json::to_vec(&vec![legacy_ann.clone()]) + .expect("legacy announcement must serialize"); + write_len_prefixed(&mut send_gossip, &legacy_gossip_body) + .await + .expect("legacy server should reply with JSON gossip"); + send_gossip + .finish() + .expect("legacy gossip reply should finish"); + let _ = recv_gossip.read_to_end(0).await; + + let (mut send_route_resp, mut recv_route_req) = + tokio::time::timeout(std::time::Duration::from_secs(5), conn.accept_bi()) + .await + .expect("post node should open legacy route request stream") + .expect("legacy route request stream should be accepted"); + recv_route_req + .read_exact(&mut stream_type) + .await + .expect("legacy server must read route stream type"); + assert_eq!(stream_type[0], STREAM_ROUTE_REQUEST); + let legacy_route_body = + serde_json::to_vec(&legacy_route_table).expect("legacy route table must serialize"); + send_route_resp + .write_all(&legacy_route_body) + .await + .expect("legacy server must send JSON route table"); + send_route_resp + .finish() + .expect("legacy route response should finish"); + + let (mut send_gossip2, mut recv_gossip2) = conn + .open_bi() + .await + .expect("legacy server should initiate gossip back to post node"); + send_gossip2 + .write_all(&[STREAM_GOSSIP]) + .await + .expect("legacy gossip stream type should be sent"); + write_len_prefixed(&mut send_gossip2, &legacy_gossip_body) + .await + .expect("legacy server should send JSON gossip payload"); + send_gossip2 + .finish() + .expect("legacy initiated gossip should finish"); + let response_buf = read_len_prefixed(&mut recv_gossip2) + .await + .expect("post node should answer legacy gossip"); + let response_anns: Vec<super::PeerAnnouncementV0> = + serde_json::from_slice(&response_buf) + .expect("post node must answer with JSON gossip"); + assert!( + response_anns + .iter() + .any(|ann| ann.addr.id == post_id + && ann.serving.as_deref() == Some("post-model")), + "post node should answer legacy gossip with its current state" + ); + let _ = recv_gossip2.read_to_end(0).await; + + let (mut send_route_req2, mut recv_route_resp2) = conn + .open_bi() + .await + .expect("legacy server should initiate route request to post node"); + send_route_req2 + .write_all(&[STREAM_ROUTE_REQUEST]) + .await + .expect("legacy route request stream type should be sent"); + send_route_req2 + .finish() + .expect("legacy route request should finish"); + let route_json = recv_route_resp2 + .read_to_end(MAX_CONTROL_FRAME_BYTES) + .await + .expect("post node should reply with legacy JSON route table"); + let route_table_from_post: RoutingTable = + serde_json::from_slice(&route_json).expect("post node route response must be JSON"); + assert_eq!( + route_table_from_post.mesh_id.as_deref(), + Some("compat-mesh-01020304") + ); + assert!( + route_table_from_post + .hosts + .iter() + .any(|entry| entry.endpoint_id == post_id && entry.model == "post-model"), + "legacy peer should see the post node in route-table JSON response" + ); + }); + + let invite_token = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(serde_json::to_vec(&legacy_addr).expect("legacy address should serialize")); + post_node.join(&invite_token).await?; + + tokio::time::timeout(std::time::Duration::from_secs(5), async { + loop { + let peers = post_node.peers().await; + if peers.iter().any(|peer| { + peer.id == legacy_id + && peer.serving_models.first().map(String::as_str) == Some("legacy-model") + }) { + break; + } + tokio::time::sleep(std::time::Duration::from_millis(25)).await; + } + }) + .await + .expect("post node should admit the legacy peer after JSON gossip"); + + let legacy_conn = { + let state = post_node.state.lock().await; + state + .connections + .get(&legacy_id) + .cloned() + .expect("join should leave a connection to the legacy peer") + }; + let route_table = post_node.request_route_table(&legacy_conn).await?; + assert_eq!( + route_table.mesh_id.as_deref(), + Some("compat-mesh-01020304"), + "post node must parse legacy JSON route-table replies" + ); + assert!( + route_table + .hosts + .iter() + .any(|entry| entry.endpoint_id == legacy_id && entry.model == "legacy-model"), + "post node must preserve legacy route-table entries" + ); + + server.await.expect("legacy peer task should complete"); + Ok(()) + } + + #[test] + fn legacy_json_gossip_payload_decodes() { + let peer_id = EndpointId::from(SecretKey::from_bytes(&[0x42; 32]).public()); + let ann = super::PeerAnnouncementV0 { + addr: EndpointAddr { + id: peer_id, + addrs: Default::default(), + }, + role: super::NodeRole::Host { http_port: 3131 }, + models: vec!["Qwen".into()], + vram_bytes: 48_000_000_000, + model_source: Some("Qwen.gguf".into()), + serving: Some("Qwen".into()), + serving_models: vec!["Qwen".into()], + available_models: vec!["Qwen".into()], + requested_models: vec!["Qwen".into()], + version: Some("0.52.0".into()), + model_demand: HashMap::from([( + "Qwen".into(), + ModelDemand { + last_active: 123, + request_count: 7, + }, + )]), + mesh_id: Some("mesh-compat".into()), + gpu_name: Some("NVIDIA A100".into()), + hostname: Some("worker-01".into()), + is_soc: Some(false), + gpu_vram: Some("51539607552".into()), + gpu_reserved_bytes: None, + gpu_mem_bandwidth_gbps: None, + gpu_compute_tflops_fp32: None, + gpu_compute_tflops_fp16: None, + available_model_sizes: HashMap::from([("Qwen".into(), 1234_u64)]), + served_model_descriptors: vec![], + served_model_runtime: vec![], + }; + let json = serde_json::to_vec(&vec![ann.clone()]).unwrap(); + + let decoded = decode_gossip_payload(ControlProtocol::JsonV0, peer_id, &json).unwrap(); + + assert_eq!(decoded.len(), 1); + assert_eq!(decoded[0].0.id, peer_id); + assert_eq!( + decoded[0].1.serving_models.first().map(String::as_str), + Some("Qwen") + ); + assert_eq!(decoded[0].1.mesh_id.as_deref(), Some("mesh-compat")); + } + + #[test] + fn legacy_json_tunnel_map_decodes() { + let target = EndpointId::from(SecretKey::from_bytes(&[0x24; 32]).public()); + let json = serde_json::to_vec(&HashMap::from([(hex::encode(target.as_bytes()), 9337_u16)])) + .unwrap(); + + let frame = decode_legacy_tunnel_map_frame(&json).unwrap(); + + assert_eq!(frame.entries.len(), 1); + assert_eq!(frame.entries[0].target_peer_id, target.as_bytes().to_vec()); + assert_eq!(frame.entries[0].tunnel_port, 9337); + } + + #[test] + fn control_frame_roundtrip() { + let frame = make_valid_gossip_frame(); + let encoded = encode_control_frame(STREAM_GOSSIP, &frame); + let decoded: GossipFrame = decode_control_frame(STREAM_GOSSIP, &encoded) + .expect("valid gossip frame must decode successfully"); + assert_eq!(decoded.gen, NODE_PROTOCOL_GENERATION); + assert_eq!(decoded.peers.len(), 1); + assert_eq!(decoded.peers[0].endpoint_id, vec![0u8; 32]); + assert_eq!(decoded.peers[0].role, NodeRole::Worker as i32); + } + + fn make_test_peer_info(peer_id: EndpointId) -> PeerInfo { + PeerInfo { + id: peer_id, + addr: EndpointAddr { + id: peer_id, + addrs: Default::default(), + }, + tunnel_port: None, + role: super::NodeRole::Worker, + models: vec![], + vram_bytes: 0, + rtt_ms: None, + model_source: None, + serving_models: vec![], + hosted_models: vec![], + hosted_models_known: false, + available_models: vec![], + requested_models: vec![], + last_seen: std::time::Instant::now(), + last_mentioned: std::time::Instant::now(), + moe_recovered_at: None, + version: None, + gpu_name: None, + hostname: None, + is_soc: None, + gpu_vram: None, + gpu_reserved_bytes: None, + gpu_mem_bandwidth_gbps: None, + gpu_compute_tflops_fp32: None, + gpu_compute_tflops_fp16: None, + available_model_metadata: vec![], + experts_summary: None, + available_model_sizes: HashMap::new(), + served_model_descriptors: vec![], + served_model_runtime: vec![], + owner_attestation: None, + owner_summary: OwnershipSummary::default(), + } + } + + fn make_test_moe_descriptor(model_name: &str, identity_hash: &str) -> ServedModelDescriptor { + ServedModelDescriptor { + identity: ServedModelIdentity { + model_name: model_name.to_string(), + is_primary: true, + source_kind: ModelSourceKind::HuggingFace, + canonical_ref: Some(format!("hf://{identity_hash}")), + repository: Some("Qwen".to_string()), + revision: Some("main".to_string()), + artifact: Some(format!("{model_name}.gguf")), + local_file_name: Some(format!("{model_name}.gguf")), + identity_hash: Some(identity_hash.to_string()), + }, + capabilities: crate::models::ModelCapabilities { + moe: true, + ..Default::default() + }, + topology: Some(crate::models::ModelTopology { + moe: Some(crate::models::ModelMoeInfo { + expert_count: 512, + used_expert_count: 10, + min_experts_per_node: Some(160), + source: Some("test".to_string()), + ranking_source: None, + ranking_origin: None, + ranking: Vec::new(), + ranking_prompt_count: None, + ranking_tokens: None, + ranking_layer_scope: None, + }), + }), + } + } + + fn make_test_endpoint_id(seed: u8) -> EndpointId { + let mut bytes = [0u8; 32]; + bytes[0] = seed; + EndpointId::from(SecretKey::from_bytes(&bytes).public()) + } + + #[test] + fn shared_exact_moe_identity_uses_stricter_heartbeat_without_inbound_grace() { + let mut peer = make_test_peer_info(make_test_endpoint_id(7)); + peer.served_model_descriptors = vec![make_test_moe_descriptor( + "Qwen3-Coder-Next-Q4_K_M", + "same-model", + )]; + let local_descriptors = vec![make_test_moe_descriptor( + "Qwen3-Coder-Next-Q4_K_M", + "same-model", + )]; + let local_runtime = vec![]; + + let policy = heartbeat_failure_policy_for_peer(&local_descriptors, &local_runtime, &peer); + + assert_eq!( + policy, + HeartbeatFailurePolicy { + allow_recent_inbound_grace: false, + failure_threshold: 2, + } + ); + } + + #[test] + fn non_matching_or_non_moe_peers_keep_default_heartbeat_grace() { + let mut peer = make_test_peer_info(make_test_endpoint_id(8)); + peer.served_model_descriptors = vec![make_test_moe_descriptor( + "Qwen3-Coder-Next-Q4_K_M", + "remote-model", + )]; + let local_descriptors = vec![make_test_moe_descriptor( + "Qwen3-Coder-Next-Q4_K_M", + "local-model", + )]; + let local_runtime = vec![]; + + let policy = heartbeat_failure_policy_for_peer(&local_descriptors, &local_runtime, &peer); + + assert_eq!( + policy, + HeartbeatFailurePolicy { + allow_recent_inbound_grace: true, + failure_threshold: 2, + } + ); + } + + #[test] + fn shared_exact_moe_startup_relaxes_heartbeat_during_convergence() { + let mut peer = make_test_peer_info(make_test_endpoint_id(11)); + peer.served_model_descriptors = vec![make_test_moe_descriptor( + "Qwen3-Coder-Next-Q4_K_M", + "same-model", + )]; + let local_descriptors = vec![make_test_moe_descriptor( + "Qwen3-Coder-Next-Q4_K_M", + "same-model", + )]; + let local_runtime = vec![ModelRuntimeDescriptor { + model_name: "Qwen3-Coder-Next-Q4_K_M".to_string(), + identity_hash: Some("same-model".to_string()), + context_length: None, + backend: None, + ready: false, + }]; + + let policy = heartbeat_failure_policy_for_peer(&local_descriptors, &local_runtime, &peer); + + assert_eq!( + policy, + HeartbeatFailurePolicy { + allow_recent_inbound_grace: true, + failure_threshold: 4, + } + ); + } + + #[test] + fn recovered_moe_peer_stays_out_of_active_placement_until_probation_expires() { + let mut peer = make_test_peer_info(make_test_endpoint_id(9)); + peer.serving_models = vec!["Qwen3-Coder-Next-Q4_K_M".to_string()]; + peer.served_model_descriptors = vec![make_test_moe_descriptor( + "Qwen3-Coder-Next-Q4_K_M", + "same-model", + )]; + let local_descriptors = vec![make_test_moe_descriptor( + "Qwen3-Coder-Next-Q4_K_M", + "same-model", + )]; + + peer.moe_recovered_at = Some(std::time::Instant::now()); + assert!(!peer_is_eligible_for_active_moe( + &local_descriptors, + &peer, + "Qwen3-Coder-Next-Q4_K_M" + )); + + peer.moe_recovered_at = Some( + std::time::Instant::now() + - std::time::Duration::from_secs(MOE_RECOVERY_PROBATION_SECS + 1), + ); + assert!(peer_is_eligible_for_active_moe( + &local_descriptors, + &peer, + "Qwen3-Coder-Next-Q4_K_M" + )); + } + + #[test] + fn requested_model_peer_is_eligible_for_active_moe_during_startup() { + let mut peer = make_test_peer_info(make_test_endpoint_id(10)); + peer.requested_models = vec!["Qwen3-Coder-Next-Q4_K_M".to_string()]; + peer.served_model_descriptors = vec![make_test_moe_descriptor( + "Qwen3-Coder-Next-Q4_K_M", + "same-model", + )]; + let local_descriptors = vec![make_test_moe_descriptor( + "Qwen3-Coder-Next-Q4_K_M", + "same-model", + )]; + + assert!(peer_is_eligible_for_active_moe( + &local_descriptors, + &peer, + "Qwen3-Coder-Next-Q4_K_M" + )); + } + + #[test] + fn incoming_peer_promoted_after_valid_gossip() { + let frame = make_valid_gossip_frame(); + let encoded = encode_control_frame(STREAM_GOSSIP, &frame); + let decoded: GossipFrame = decode_control_frame(STREAM_GOSSIP, &encoded) + .expect("valid gossip frame must decode successfully"); + assert_eq!(decoded.gen, NODE_PROTOCOL_GENERATION); + assert!(!decoded.peers.is_empty()); + + let peer_id = EndpointId::from(SecretKey::from_bytes(&[0xab; 32]).public()); + let mut peers: HashMap<EndpointId, PeerInfo> = HashMap::new(); + + assert!( + !is_peer_admitted(&peers, &peer_id), + "peer must NOT be admitted before gossip" + ); + + for &tunnel_stream in &[STREAM_TUNNEL, STREAM_TUNNEL_HTTP] { + assert!( + !stream_allowed_before_admission(tunnel_stream), + "stream {:#04x} must be gated until after admission β€” unadmitted peers must not reach tunnel paths", + tunnel_stream + ); + } + + assert!( + stream_allowed_before_admission(STREAM_GOSSIP), + "STREAM_GOSSIP must always be allowed β€” it is the admission path" + ); + + peers.insert(peer_id, make_test_peer_info(peer_id)); + + assert!( + is_peer_admitted(&peers, &peer_id), + "peer must be admitted after gossip completes (add_peer inserts into state.peers)" + ); + } + + #[test] + fn incoming_peer_rejected_on_legacy_or_malformed_gossip() { + let malformed_payload = vec![0xFF_u8; 20]; + let mut bad_frame = vec![STREAM_GOSSIP]; + bad_frame.extend_from_slice(&(malformed_payload.len() as u32).to_le_bytes()); + bad_frame.extend_from_slice(&malformed_payload); + let err = decode_control_frame::<GossipFrame>(STREAM_GOSSIP, &bad_frame) + .expect_err("malformed protobuf must be rejected"); + assert!( + matches!(err, ControlFrameError::DecodeError(_)), + "expected DecodeError for malformed payload, got {:?}", + err + ); + + let bad_gen_frame = GossipFrame { + gen: 0, + sender_id: vec![], + peers: vec![PeerAnnouncement { + endpoint_id: vec![0u8; 32], + role: NodeRole::Worker as i32, + ..Default::default() + }], + }; + let encoded = encode_control_frame(STREAM_GOSSIP, &bad_gen_frame); + let err = decode_control_frame::<GossipFrame>(STREAM_GOSSIP, &encoded) + .expect_err("gen=0 must be rejected"); + assert!( + matches!(err, ControlFrameError::BadGeneration { got: 0 }), + "expected BadGeneration{{got:0}}, got {:?}", + err + ); + + for stream_type in [ + STREAM_TUNNEL, + STREAM_TUNNEL_HTTP, + STREAM_TUNNEL_MAP, + STREAM_PEER_DOWN, + STREAM_PEER_LEAVING, + STREAM_PLUGIN_CHANNEL, + STREAM_PLUGIN_BULK_TRANSFER, + ] { + assert!( + !stream_allowed_before_admission(stream_type), + "stream {:#04x} must be quarantine-gated for unadmitted peers β€” if this fails, the gate is broken", + stream_type + ); + } + + assert!( + stream_allowed_before_admission(STREAM_GOSSIP), + "STREAM_GOSSIP must bypass the gate (it is the admission handshake)" + ); + assert!( + stream_allowed_before_admission(STREAM_ROUTE_REQUEST), + "STREAM_ROUTE_REQUEST must bypass the gate (passive/client request-only path)" + ); + + let peer_id = EndpointId::from(SecretKey::from_bytes(&[0xcd; 32]).public()); + let peers: HashMap<EndpointId, PeerInfo> = HashMap::new(); + assert!( + !is_peer_admitted(&peers, &peer_id), + "peer must NOT be admitted when gossip fails" + ); + } + + #[test] + fn passive_route_table_request_does_not_admit_peer() { + let peer_id = EndpointId::from(SecretKey::from_bytes(&[0xef; 32]).public()); + let mut peers: HashMap<EndpointId, PeerInfo> = HashMap::new(); + + assert!( + !is_peer_admitted(&peers, &peer_id), + "passive caller must NOT be admitted before route request" + ); + + assert!( + stream_allowed_before_admission(STREAM_ROUTE_REQUEST), + "STREAM_ROUTE_REQUEST must be allowed before admission (passive/client path)" + ); + + for &gated in &[ + STREAM_TUNNEL, + STREAM_TUNNEL_HTTP, + STREAM_TUNNEL_MAP, + STREAM_PEER_DOWN, + STREAM_PEER_LEAVING, + STREAM_PLUGIN_CHANNEL, + STREAM_PLUGIN_BULK_TRANSFER, + ] { + assert!( + !stream_allowed_before_admission(gated), + "stream {:#04x} must remain gated after a route request β€” route request must not unlock other streams", + gated + ); + } + + let valid_req = RouteTableRequest { + requester_id: vec![0xef_u8; 32], + gen: NODE_PROTOCOL_GENERATION, + }; + let encoded = encode_control_frame(STREAM_ROUTE_REQUEST, &valid_req); + let decoded: RouteTableRequest = decode_control_frame(STREAM_ROUTE_REQUEST, &encoded) + .expect("valid RouteTableRequest must decode successfully"); + assert_eq!(decoded.requester_id, vec![0xef_u8; 32]); + assert_eq!(decoded.gen, NODE_PROTOCOL_GENERATION); + + let bad_req = RouteTableRequest { + requester_id: vec![0u8; 16], + gen: NODE_PROTOCOL_GENERATION, + }; + let encoded_bad = encode_control_frame(STREAM_ROUTE_REQUEST, &bad_req); + let err = decode_control_frame::<RouteTableRequest>(STREAM_ROUTE_REQUEST, &encoded_bad) + .expect_err("route request with wrong-length requester_id must be rejected"); + assert!( + matches!(err, ControlFrameError::InvalidEndpointId { got: 16 }), + "expected InvalidEndpointId{{got:16}}, got {:?}", + err + ); + + assert!( + !is_peer_admitted(&peers, &peer_id), + "passive caller must NOT be admitted after route-table response" + ); + + peers.insert(peer_id, make_test_peer_info(peer_id)); + assert!( + is_peer_admitted(&peers, &peer_id), + "only explicit gossip (add_peer) should promote to admitted" + ); + } + + #[test] + fn control_frame_rejects_oversize_or_bad_generation() { + let oversize_len = MAX_CONTROL_FRAME_BYTES + 1; + let mut fake = vec![STREAM_GOSSIP]; + fake.extend_from_slice(&(oversize_len as u32).to_le_bytes()); + let err = decode_control_frame::<GossipFrame>(STREAM_GOSSIP, &fake) + .expect_err("oversize frame must be rejected"); + assert!( + matches!(err, ControlFrameError::OversizeFrame { .. }), + "expected OversizeFrame, got {:?}", + err + ); + + let bad_gen = GossipFrame { + gen: 99, + sender_id: vec![], + peers: vec![PeerAnnouncement { + endpoint_id: vec![0u8; 32], + role: NodeRole::Worker as i32, + ..Default::default() + }], + }; + let encoded = encode_control_frame(STREAM_GOSSIP, &bad_gen); + let err = decode_control_frame::<GossipFrame>(STREAM_GOSSIP, &encoded) + .expect_err("bad generation must be rejected"); + assert!( + matches!(err, ControlFrameError::BadGeneration { got: 99 }), + "expected BadGeneration{{got:99}}, got {:?}", + err + ); + + let bad_id = GossipFrame { + gen: NODE_PROTOCOL_GENERATION, + sender_id: vec![0u8; 32], + peers: vec![PeerAnnouncement { + endpoint_id: vec![0u8; 16], + role: NodeRole::Worker as i32, + ..Default::default() + }], + }; + let encoded = encode_control_frame(STREAM_GOSSIP, &bad_id); + let err = decode_control_frame::<GossipFrame>(STREAM_GOSSIP, &encoded) + .expect_err("bad endpoint_id must be rejected"); + assert!( + matches!(err, ControlFrameError::InvalidEndpointId { got: 16 }), + "expected InvalidEndpointId{{got:16}}, got {:?}", + err + ); + + let valid = make_valid_gossip_frame(); + let encoded = encode_control_frame(STREAM_GOSSIP, &valid); + let err = decode_control_frame::<GossipFrame>(STREAM_TUNNEL_MAP, &encoded) + .expect_err("wrong stream type must be rejected"); + assert!( + matches!( + err, + ControlFrameError::WrongStreamType { + expected: 0x03, + got: 0x01 + } + ), + "expected WrongStreamType, got {:?}", + err + ); + } + + #[test] + fn gossip_frame_roundtrip_preserves_scanned_model_metadata() { + use crate::proto::node::{CompactModelMetadata, ExpertsSummary}; + + let peer_id = EndpointId::from(SecretKey::from_bytes(&[0x01; 32]).public()); + let peer_id_bytes = peer_id.as_bytes().to_vec(); + + let meta = CompactModelMetadata { + model_key: "Qwen3-8B-Q4_K_M".to_string(), + context_length: 40960, + vocab_size: 151936, + embedding_size: 4096, + head_count: 32, + layer_count: 36, + feed_forward_length: 14336, + key_length: 128, + value_length: 128, + architecture: "qwen3".to_string(), + tokenizer_model_name: "PreTrainedTokenizerFast".to_string(), + special_tokens: vec![], + rope_scale: 1.0, + rope_freq_base: 1_000_000.0, + is_moe: false, + expert_count: 0, + used_expert_count: 0, + quantization_type: "Q4_K_M".to_string(), + }; + + let mut model_sizes = HashMap::new(); + model_sizes.insert("Qwen3-8B-Q4_K_M".to_string(), 4_800_000_000u64); + + let experts = ExpertsSummary { + total_experts: 64, + expert_count_used: 8, + top_expert_ids: vec![1, 5, 10], + }; + + let local_ann = super::PeerAnnouncement { + addr: EndpointAddr { + id: peer_id, + addrs: Default::default(), + }, + role: super::NodeRole::Host { http_port: 8080 }, + models: vec!["Qwen3-8B-Q4_K_M".to_string()], + vram_bytes: 128 * 1024 * 1024 * 1024, + model_source: Some("bartowski/Qwen3-8B-GGUF".to_string()), + serving_models: vec!["Qwen3-8B-Q4_K_M".to_string()], + hosted_models: Some(vec!["Qwen3-8B-Q4_K_M".to_string()]), + available_models: vec!["Qwen3-8B-Q4_K_M".to_string()], + requested_models: vec![], + version: Some("0.42.0".to_string()), + model_demand: HashMap::new(), + mesh_id: Some("deadbeef12345678".to_string()), + gpu_name: Some("Apple M4 Max".to_string()), + hostname: Some("test-node".to_string()), + is_soc: Some(true), + gpu_vram: Some("128 GB".to_string()), + gpu_reserved_bytes: None, + gpu_mem_bandwidth_gbps: None, + gpu_compute_tflops_fp32: None, + gpu_compute_tflops_fp16: None, + available_model_metadata: vec![meta.clone()], + experts_summary: Some(experts.clone()), + available_model_sizes: model_sizes.clone(), + served_model_descriptors: vec![ServedModelDescriptor { + identity: ServedModelIdentity { + model_name: "Qwen3-8B-Q4_K_M".to_string(), + is_primary: true, + source_kind: ModelSourceKind::HuggingFace, + canonical_ref: Some("hf/bartowski/Qwen3-8B-GGUF/Qwen3-8B-Q4_K_M.gguf".into()), + repository: Some("bartowski/Qwen3-8B-GGUF".into()), + revision: Some("main".into()), + artifact: Some("Qwen3-8B-Q4_K_M.gguf".into()), + local_file_name: Some("Qwen3-8B-Q4_K_M.gguf".into()), + identity_hash: Some("sha256:abc123".into()), + }, + capabilities: crate::models::ModelCapabilities::default(), + topology: Some(crate::models::ModelTopology { moe: None }), + }], + served_model_runtime: vec![ModelRuntimeDescriptor { + model_name: "Qwen3-8B-Q4_K_M".to_string(), + identity_hash: Some("sha256:abc123".into()), + context_length: Some(32768), + backend: Some("llama".into()), + ready: true, + }], + owner_attestation: None, + }; + + let proto_pa = local_ann_to_proto_ann(&local_ann); + assert_eq!( + proto_pa.available_model_metadata.len(), + 0, + "local_ann_to_proto_ann must strip passive available_model_metadata from gossip" + ); + assert!( + proto_pa.available_models.is_empty(), + "local_ann_to_proto_ann must strip passive available_models from gossip" + ); + assert_eq!( + proto_pa.experts_summary.as_ref().map(|e| e.total_experts), + Some(64), + "local_ann_to_proto_ann must carry experts_summary" + ); + assert_eq!( + proto_pa.available_model_sizes.len(), + 0, + "local_ann_to_proto_ann must strip passive available_model_sizes from gossip" + ); + + let (_, roundtripped) = proto_ann_to_local(&proto_pa) + .expect("proto_ann_to_local must succeed on valid proto PA"); + assert_eq!( + roundtripped.available_model_metadata.len(), + 0, + "proto_ann_to_local must ignore passive available_model_metadata from gossip" + ); + assert!( + roundtripped.available_models.is_empty(), + "proto_ann_to_local must ignore passive available_models from gossip" + ); + assert_eq!( + roundtripped + .experts_summary + .as_ref() + .map(|e| e.total_experts), + Some(64), + "proto_ann_to_local must restore experts_summary" + ); + assert!(roundtripped.available_model_sizes.is_empty()); + assert_eq!( + roundtripped + .served_model_runtime + .first() + .and_then(ModelRuntimeDescriptor::advertised_context_length), + Some(32768), + "proto_ann_to_local must preserve served model runtime context length" + ); + assert_eq!( + roundtripped + .served_model_runtime + .first() + .and_then(|runtime| runtime.backend.as_deref()), + Some("llama"), + "proto_ann_to_local must preserve served model runtime backend" + ); + + let frame = build_gossip_frame(&[local_ann], peer_id); + assert_eq!(frame.sender_id, peer_id_bytes); + let encoded = encode_control_frame(STREAM_GOSSIP, &frame); + let decoded: GossipFrame = decode_control_frame(STREAM_GOSSIP, &encoded) + .expect("build_gossip_frame output must decode successfully"); + assert_eq!(decoded.peers.len(), 1); + let wire_pa = &decoded.peers[0]; + assert_eq!( + wire_pa.available_model_metadata.len(), + 0, + "build_gossip_frame must strip passive available_model_metadata from wire gossip" + ); + assert!(wire_pa.available_models.is_empty()); + assert!(wire_pa.available_model_sizes.is_empty()); + assert_eq!( + wire_pa + .experts_summary + .as_ref() + .map(|e| e.top_expert_ids.as_slice()), + Some([1u32, 5, 10].as_slice()) + ); + assert_eq!( + wire_pa + .served_model_runtime + .first() + .and_then(|runtime| runtime.context_length), + Some(32768), + "build_gossip_frame must preserve served model runtime context length" + ); + assert_eq!( + wire_pa + .served_model_runtime + .first() + .and_then(|runtime| runtime.backend.as_deref()), + Some("llama"), + "build_gossip_frame must preserve served model runtime backend" + ); + let (_, final_local) = + proto_ann_to_local(wire_pa).expect("final proto_ann_to_local must succeed"); + assert!(final_local.available_model_metadata.is_empty()); + assert!(final_local.available_models.is_empty()); + assert!(final_local.available_model_sizes.is_empty()); + assert_eq!( + final_local + .served_model_runtime + .first() + .and_then(ModelRuntimeDescriptor::advertised_context_length), + Some(32768) + ); + assert_eq!( + final_local + .served_model_runtime + .first() + .and_then(|runtime| runtime.backend.as_deref()), + Some("llama") + ); + } + + #[test] + fn gossip_rejects_sender_id_mismatch_or_invalid_endpoint_len() { + let peer_id = EndpointId::from(SecretKey::from_bytes(&[0xaa; 32]).public()); + let peer_id_bytes = peer_id.as_bytes().to_vec(); + + let invalid_sender_frame = GossipFrame { + gen: NODE_PROTOCOL_GENERATION, + sender_id: vec![0u8; 16], + peers: vec![PeerAnnouncement { + endpoint_id: peer_id_bytes.clone(), + role: NodeRole::Worker as i32, + ..Default::default() + }], + }; + let encoded = encode_control_frame(STREAM_GOSSIP, &invalid_sender_frame); + let err = decode_control_frame::<GossipFrame>(STREAM_GOSSIP, &encoded) + .expect_err("16-byte sender_id must be rejected at decode time"); + assert!( + matches!(err, ControlFrameError::InvalidSenderId { got: 16 }), + "expected InvalidSenderId{{got:16}}, got {:?}", + err + ); + + let impersonator_id = EndpointId::from(SecretKey::from_bytes(&[0xbb; 32]).public()); + let mismatch_frame = GossipFrame { + gen: NODE_PROTOCOL_GENERATION, + sender_id: impersonator_id.as_bytes().to_vec(), + peers: vec![PeerAnnouncement { + endpoint_id: peer_id_bytes.clone(), + role: NodeRole::Worker as i32, + ..Default::default() + }], + }; + let remote = peer_id; + let is_forged = !mismatch_frame.sender_id.is_empty() + && mismatch_frame.sender_id.as_slice() != remote.as_bytes(); + assert!( + is_forged, + "sender_id != remote.as_bytes() must be detected as a forged sender" + ); + + let bad_endpoint_frame = GossipFrame { + gen: NODE_PROTOCOL_GENERATION, + sender_id: peer_id_bytes.clone(), + peers: vec![PeerAnnouncement { + endpoint_id: vec![0u8; 20], + role: NodeRole::Worker as i32, + ..Default::default() + }], + }; + let encoded = encode_control_frame(STREAM_GOSSIP, &bad_endpoint_frame); + let err = decode_control_frame::<GossipFrame>(STREAM_GOSSIP, &encoded) + .expect_err("20-byte endpoint_id in peer must be rejected"); + assert!( + matches!(err, ControlFrameError::InvalidEndpointId { got: 20 }), + "expected InvalidEndpointId{{got:20}}, got {:?}", + err + ); + } + + #[test] + fn transitive_peer_update_refreshes_metadata_fields() { + use crate::proto::node::CompactModelMetadata; + + let peer_id = EndpointId::from(SecretKey::from_bytes(&[0x10; 32]).public()); + let mut existing = make_test_peer_info(peer_id); + existing.available_models = vec!["OldModel-Q4_K_M".to_string()]; + existing.models = vec!["OldModel-Q4_K_M".to_string()]; + existing.requested_models = vec!["OldModel-Q4_K_M".to_string()]; + + let meta = CompactModelMetadata { + model_key: "NewModel-Q4_K_M".to_string(), + context_length: 8192, + vocab_size: 32000, + embedding_size: 4096, + head_count: 32, + layer_count: 32, + feed_forward_length: 11008, + key_length: 128, + value_length: 128, + architecture: "llama".to_string(), + tokenizer_model_name: String::new(), + special_tokens: vec![], + rope_scale: 1.0, + rope_freq_base: 10000.0, + is_moe: false, + expert_count: 0, + used_expert_count: 0, + quantization_type: "Q4_K_M".to_string(), + }; + + let mut new_sizes = HashMap::new(); + new_sizes.insert("NewModel-Q4_K_M".to_string(), 4_800_000_000u64); + + let addr = EndpointAddr { + id: peer_id, + addrs: Default::default(), + }; + let ann = super::PeerAnnouncement { + addr: addr.clone(), + role: super::NodeRole::Worker, + models: vec!["NewModel-Q4_K_M".to_string()], + vram_bytes: 8 * 1024 * 1024 * 1024, + model_source: Some("new-source".to_string()), + serving_models: vec!["NewModel-Q4_K_M".to_string()], + hosted_models: Some(vec!["NewModel-Q4_K_M".to_string()]), + available_models: vec!["NewModel-Q4_K_M".to_string()], + requested_models: vec!["NewModel-Q4_K_M".to_string()], + version: None, + model_demand: HashMap::new(), + mesh_id: None, + gpu_name: None, + hostname: None, + is_soc: None, + gpu_vram: None, + gpu_reserved_bytes: None, + gpu_mem_bandwidth_gbps: None, + gpu_compute_tflops_fp32: None, + gpu_compute_tflops_fp16: None, + available_model_metadata: vec![meta], + experts_summary: None, + available_model_sizes: new_sizes, + served_model_descriptors: vec![], + served_model_runtime: vec![], + owner_attestation: None, + }; + + apply_transitive_ann(&mut existing, &addr, &ann); + + assert!( + existing.available_models.is_empty(), + "remote available_models must be ignored during transitive gossip merge" + ); + assert_eq!( + existing.models, + vec!["NewModel-Q4_K_M".to_string()], + "models must be refreshed from transitive gossip" + ); + assert_eq!( + existing.requested_models, + vec!["NewModel-Q4_K_M".to_string()], + "requested_models must be refreshed from transitive gossip" + ); + assert!(existing.available_model_metadata.is_empty()); + assert!(existing.available_model_sizes.is_empty()); + } + + #[test] + fn transitive_peer_merge_preserves_richer_direct_address() { + use iroh::TransportAddr; + + let peer_id = EndpointId::from(SecretKey::from_bytes(&[0x11; 32]).public()); + let mut existing = make_test_peer_info(peer_id); + + let mut rich_addrs = std::collections::BTreeSet::new(); + rich_addrs.insert(TransportAddr::Ip("127.0.0.1:1000".parse().unwrap())); + rich_addrs.insert(TransportAddr::Ip("192.168.1.1:1001".parse().unwrap())); + rich_addrs.insert(TransportAddr::Ip("10.0.0.1:1002".parse().unwrap())); + existing.addr = EndpointAddr { + id: peer_id, + addrs: rich_addrs, + }; + + let mut weak_addrs = std::collections::BTreeSet::new(); + weak_addrs.insert(TransportAddr::Ip("127.0.0.1:9999".parse().unwrap())); + let weak_addr = EndpointAddr { + id: peer_id, + addrs: weak_addrs, + }; + let ann = super::PeerAnnouncement { + addr: weak_addr.clone(), + role: super::NodeRole::Worker, + models: vec!["SomeModel-Q4_K_M".to_string()], + vram_bytes: 4 * 1024 * 1024 * 1024, + model_source: None, + serving_models: vec![], + hosted_models: None, + available_models: vec!["SomeModel-Q4_K_M".to_string()], + requested_models: vec![], + version: None, + model_demand: HashMap::new(), + mesh_id: None, + gpu_name: None, + hostname: None, + is_soc: None, + gpu_vram: None, + gpu_reserved_bytes: None, + gpu_mem_bandwidth_gbps: None, + gpu_compute_tflops_fp32: None, + gpu_compute_tflops_fp16: None, + available_model_metadata: vec![], + experts_summary: None, + available_model_sizes: HashMap::new(), + served_model_descriptors: vec![], + served_model_runtime: vec![], + owner_attestation: None, + }; + + apply_transitive_ann(&mut existing, &weak_addr, &ann); + + assert_eq!( + existing.addr.addrs.len(), + 3, + "rich direct address (3 paths) must not be overwritten by weaker transitive addr (1 path)" + ); + assert!( + existing.available_models.is_empty(), + "remote available_models must still be ignored even when addr is preserved" + ); + + let mut richer_addrs = std::collections::BTreeSet::new(); + richer_addrs.insert(TransportAddr::Ip("127.0.0.1:1000".parse().unwrap())); + richer_addrs.insert(TransportAddr::Ip("192.168.1.1:1001".parse().unwrap())); + richer_addrs.insert(TransportAddr::Ip("10.0.0.1:1002".parse().unwrap())); + richer_addrs.insert(TransportAddr::Ip("172.16.0.1:1003".parse().unwrap())); + let richer_addr = EndpointAddr { + id: peer_id, + addrs: richer_addrs, + }; + let ann2 = super::PeerAnnouncement { + addr: richer_addr.clone(), + role: super::NodeRole::Worker, + models: vec!["SomeModel-Q4_K_M".to_string()], + vram_bytes: 4 * 1024 * 1024 * 1024, + model_source: None, + serving_models: vec![], + hosted_models: None, + available_models: vec!["SomeModel-Q4_K_M".to_string()], + requested_models: vec![], + version: None, + model_demand: HashMap::new(), + mesh_id: None, + gpu_name: None, + hostname: None, + is_soc: None, + gpu_vram: None, + gpu_reserved_bytes: None, + gpu_mem_bandwidth_gbps: None, + gpu_compute_tflops_fp32: None, + gpu_compute_tflops_fp16: None, + available_model_metadata: vec![], + experts_summary: None, + available_model_sizes: HashMap::new(), + served_model_descriptors: vec![], + served_model_runtime: vec![], + owner_attestation: None, + }; + apply_transitive_ann(&mut existing, &richer_addr, &ann2); + + assert_eq!( + existing.addr.addrs.len(), + 4, + "richer transitive addr (4 paths) must replace existing (3 paths)" + ); + } + + #[test] + fn tunnel_map_roundtrip_updates_remote_map() { + use crate::proto::node::{TunnelEntry, TunnelMap}; + + let owner_key = SecretKey::from_bytes(&[0x10; 32]); + let owner_id = EndpointId::from(owner_key.public()); + let owner_bytes = owner_id.as_bytes().to_vec(); + + let target_key = SecretKey::from_bytes(&[0x20; 32]); + let target_id = EndpointId::from(target_key.public()); + let target_bytes = target_id.as_bytes().to_vec(); + + let frame = TunnelMap { + owner_peer_id: owner_bytes.clone(), + entries: vec![TunnelEntry { + target_peer_id: target_bytes.clone(), + tunnel_port: 50001, + relay_peer_id: None, + }], + }; + + let encoded = encode_control_frame(STREAM_TUNNEL_MAP, &frame); + let decoded: TunnelMap = decode_control_frame(STREAM_TUNNEL_MAP, &encoded) + .expect("valid TunnelMap must decode successfully"); + + assert_eq!(decoded.owner_peer_id, owner_bytes); + assert_eq!(decoded.entries.len(), 1); + assert_eq!(decoded.entries[0].target_peer_id, target_bytes); + assert_eq!(decoded.entries[0].tunnel_port, 50001); + + let mut remote_tunnel_maps: HashMap<EndpointId, HashMap<EndpointId, u16>> = HashMap::new(); + ingest_tunnel_map(owner_id, &decoded, &mut remote_tunnel_maps) + .expect("valid tunnel map must ingest successfully"); + + assert_eq!(remote_tunnel_maps.len(), 1); + let inner = remote_tunnel_maps + .get(&owner_id) + .expect("owner must be present in remote_tunnel_maps"); + assert_eq!(inner.len(), 1); + let port = inner + .get(&target_id) + .expect("target must be present in inner map"); + assert_eq!(*port, 50001u16); + } + + #[test] + fn tunnel_map_rejects_owner_mismatch_or_bad_target_id() { + use crate::proto::node::{TunnelEntry, TunnelMap}; + + let owner_key = SecretKey::from_bytes(&[0x30; 32]); + let owner_id = EndpointId::from(owner_key.public()); + let owner_bytes = owner_id.as_bytes().to_vec(); + + let target_key = SecretKey::from_bytes(&[0x40; 32]); + let target_id = EndpointId::from(target_key.public()); + let target_bytes = target_id.as_bytes().to_vec(); + + let bad_owner_frame = TunnelMap { + owner_peer_id: vec![0u8; 16], + entries: vec![TunnelEntry { + target_peer_id: target_bytes.clone(), + tunnel_port: 50001, + relay_peer_id: None, + }], + }; + let encoded = encode_control_frame(STREAM_TUNNEL_MAP, &bad_owner_frame); + let err = decode_control_frame::<TunnelMap>(STREAM_TUNNEL_MAP, &encoded) + .expect_err("bad owner_peer_id must be rejected"); + assert!( + matches!(err, ControlFrameError::InvalidEndpointId { got: 16 }), + "expected InvalidEndpointId{{got:16}}, got {:?}", + err + ); + + let bad_target_frame = TunnelMap { + owner_peer_id: owner_bytes.clone(), + entries: vec![TunnelEntry { + target_peer_id: vec![0u8; 16], + tunnel_port: 50001, + relay_peer_id: None, + }], + }; + let encoded = encode_control_frame(STREAM_TUNNEL_MAP, &bad_target_frame); + let err = decode_control_frame::<TunnelMap>(STREAM_TUNNEL_MAP, &encoded) + .expect_err("bad target_peer_id must be rejected"); + assert!( + matches!(err, ControlFrameError::InvalidEndpointId { got: 16 }), + "expected InvalidEndpointId{{got:16}}, got {:?}", + err + ); + + let different_key = SecretKey::from_bytes(&[0x50; 32]); + let different_id = EndpointId::from(different_key.public()); + + let mismatched_frame = TunnelMap { + owner_peer_id: owner_bytes.clone(), + entries: vec![TunnelEntry { + target_peer_id: target_bytes.clone(), + tunnel_port: 50001, + relay_peer_id: None, + }], + }; + let mut remote_tunnel_maps: HashMap<EndpointId, HashMap<EndpointId, u16>> = HashMap::new(); + let result = ingest_tunnel_map(different_id, &mismatched_frame, &mut remote_tunnel_maps); + assert!(result.is_err(), "owner mismatch must be rejected"); + assert!( + remote_tunnel_maps.is_empty(), + "mismatched owner must not populate remote_tunnel_maps" + ); + + let oversized_port_frame = TunnelMap { + owner_peer_id: owner_bytes.clone(), + entries: vec![TunnelEntry { + target_peer_id: target_bytes.clone(), + tunnel_port: 70000, + relay_peer_id: None, + }], + }; + let mut remote_tunnel_maps: HashMap<EndpointId, HashMap<EndpointId, u16>> = HashMap::new(); + let result = ingest_tunnel_map(owner_id, &oversized_port_frame, &mut remote_tunnel_maps); + assert!(result.is_err(), "tunnel_port > u16::MAX must be rejected"); + assert!( + remote_tunnel_maps.is_empty(), + "oversized tunnel_port must not populate remote_tunnel_maps" + ); + } + + #[test] + fn route_table_request_roundtrip() { + use crate::proto::node::{RouteEntry as ProtoRouteEntry, RouteTable}; + + let peer_key = SecretKey::from_bytes(&[0x60; 32]); + let peer_id = EndpointId::from(peer_key.public()); + let peer_bytes = peer_id.as_bytes().to_vec(); + + let req = RouteTableRequest { + requester_id: peer_bytes.clone(), + gen: NODE_PROTOCOL_GENERATION, + }; + let encoded = encode_control_frame(STREAM_ROUTE_REQUEST, &req); + let decoded: RouteTableRequest = decode_control_frame(STREAM_ROUTE_REQUEST, &encoded) + .expect("valid RouteTableRequest must decode successfully"); + assert_eq!(decoded.requester_id, peer_bytes); + assert_eq!(decoded.gen, NODE_PROTOCOL_GENERATION); + + let table = RouteTable { + entries: vec![ProtoRouteEntry { + endpoint_id: peer_bytes.clone(), + model: "Qwen3-8B-Q4_K_M".to_string(), + }], + mesh_id: Some("test-mesh-0102030405060708".to_string()), + gen: NODE_PROTOCOL_GENERATION, + }; + let encoded_table = encode_control_frame(STREAM_ROUTE_REQUEST, &table); + let decoded_table: RouteTable = decode_control_frame(STREAM_ROUTE_REQUEST, &encoded_table) + .expect("valid RouteTable must decode successfully"); + assert_eq!(decoded_table.gen, NODE_PROTOCOL_GENERATION); + assert_eq!(decoded_table.entries.len(), 1); + assert_eq!(decoded_table.entries[0].endpoint_id, peer_bytes); + assert_eq!(decoded_table.entries[0].model, "Qwen3-8B-Q4_K_M"); + assert_eq!( + decoded_table.mesh_id.as_deref(), + Some("test-mesh-0102030405060708") + ); + + let local = proto_route_table_to_local(&decoded_table); + assert_eq!(local.hosts.len(), 1); + assert_eq!(local.hosts[0].model, "Qwen3-8B-Q4_K_M"); + assert_eq!(local.hosts[0].endpoint_id, peer_id); + assert_eq!(local.mesh_id.as_deref(), Some("test-mesh-0102030405060708")); + + let round_tripped = routing_table_to_proto(&local); + assert_eq!(round_tripped.gen, NODE_PROTOCOL_GENERATION); + assert_eq!(round_tripped.entries.len(), 1); + assert_eq!(round_tripped.entries[0].endpoint_id, peer_bytes); + assert_eq!(round_tripped.entries[0].model, "Qwen3-8B-Q4_K_M"); + assert_eq!( + round_tripped.mesh_id.as_deref(), + Some("test-mesh-0102030405060708") + ); + } + + /// Verifies that remote passive inventory metadata is ignored on ingest. + #[test] + fn proto_v1_route_table_rejects_bad_generation_or_legacy_payload() { + use crate::proto::node::RouteTable; + + let zero_gen_req = RouteTableRequest { + requester_id: vec![0u8; 32], + gen: 0, + }; + let encoded = encode_control_frame(STREAM_ROUTE_REQUEST, &zero_gen_req); + let err = decode_control_frame::<RouteTableRequest>(STREAM_ROUTE_REQUEST, &encoded) + .expect_err("request gen=0 must be rejected"); + assert!( + matches!(err, ControlFrameError::BadGeneration { got: 0 }), + "expected BadGeneration{{got:0}}, got {:?}", + err + ); + + let wrong_gen_req = RouteTableRequest { + requester_id: vec![0u8; 32], + gen: 99, + }; + let encoded = encode_control_frame(STREAM_ROUTE_REQUEST, &wrong_gen_req); + let err = decode_control_frame::<RouteTableRequest>(STREAM_ROUTE_REQUEST, &encoded) + .expect_err("request gen=99 must be rejected"); + assert!( + matches!(err, ControlFrameError::BadGeneration { got: 99 }), + "expected BadGeneration{{got:99}}, got {:?}", + err + ); + + let bad_gen_response = RouteTable { + entries: vec![], + mesh_id: None, + gen: 0, + }; + let encoded = encode_control_frame(STREAM_ROUTE_REQUEST, &bad_gen_response); + let err = decode_control_frame::<RouteTable>(STREAM_ROUTE_REQUEST, &encoded) + .expect_err("response gen=0 must be rejected"); + assert!( + matches!(err, ControlFrameError::BadGeneration { got: 0 }), + "expected BadGeneration{{got:0}} for response, got {:?}", + err + ); + + let wrong_gen_response = RouteTable { + entries: vec![], + mesh_id: None, + gen: 42, + }; + let encoded = encode_control_frame(STREAM_ROUTE_REQUEST, &wrong_gen_response); + let err = decode_control_frame::<RouteTable>(STREAM_ROUTE_REQUEST, &encoded) + .expect_err("response gen=42 must be rejected"); + assert!( + matches!(err, ControlFrameError::BadGeneration { got: 42 }), + "expected BadGeneration{{got:42}} for response, got {:?}", + err + ); + + let legacy_json = b"{\"hosts\":[],\"mesh_id\":null}"; + let mut fake_frame = vec![STREAM_ROUTE_REQUEST]; + fake_frame.extend_from_slice(&(legacy_json.len() as u32).to_le_bytes()); + fake_frame.extend_from_slice(legacy_json); + let err = decode_control_frame::<RouteTableRequest>(STREAM_ROUTE_REQUEST, &fake_frame) + .expect_err("legacy JSON payload must be rejected"); + assert!( + matches!(err, ControlFrameError::DecodeError(_)), + "expected DecodeError for JSON payload, got {:?}", + err + ); + } + + #[test] + fn peer_lifecycle_messages_roundtrip() { + use crate::proto::node::{PeerDown, PeerLeaving}; + + let leaving_id = EndpointId::from(SecretKey::from_bytes(&[0x55; 32]).public()); + + let mut peers: HashMap<EndpointId, PeerInfo> = HashMap::new(); + peers.insert(leaving_id, make_test_peer_info(leaving_id)); + let mut connection_ids: HashSet<EndpointId> = HashSet::new(); + connection_ids.insert(leaving_id); + + let leaving_msg = PeerLeaving { + peer_id: leaving_id.as_bytes().to_vec(), + gen: NODE_PROTOCOL_GENERATION, + }; + let encoded = encode_control_frame(STREAM_PEER_LEAVING, &leaving_msg); + let decoded_leaving: PeerLeaving = decode_control_frame(STREAM_PEER_LEAVING, &encoded) + .expect("valid PeerLeaving must decode"); + + let accepted_id = resolve_peer_leaving(leaving_id, &decoded_leaving) + .expect("PeerLeaving from sender itself must be accepted"); + + peers.remove(&accepted_id); + connection_ids.remove(&accepted_id); + + assert!( + !peers.contains_key(&leaving_id), + "leaving peer must be removed from peers after accepted PeerLeaving" + ); + assert!( + !connection_ids.contains(&leaving_id), + "leaving peer must be removed from connections after accepted PeerLeaving" + ); + + let self_id = EndpointId::from(SecretKey::from_bytes(&[0xAA; 32]).public()); + let dead_id = EndpointId::from(SecretKey::from_bytes(&[0xBB; 32]).public()); + + let mut peers: HashMap<EndpointId, PeerInfo> = HashMap::new(); + peers.insert(dead_id, make_test_peer_info(dead_id)); + let mut connection_ids: HashSet<EndpointId> = HashSet::new(); + connection_ids.insert(dead_id); + + let down_msg = PeerDown { + peer_id: dead_id.as_bytes().to_vec(), + gen: NODE_PROTOCOL_GENERATION, + }; + let encoded = encode_control_frame(STREAM_PEER_DOWN, &down_msg); + let decoded_down: PeerDown = + decode_control_frame(STREAM_PEER_DOWN, &encoded).expect("valid PeerDown must decode"); + + let result = resolve_peer_down(self_id, dead_id, true); + assert_eq!( + result, + Some(dead_id), + "confirmed-unreachable peer must be returned for removal" + ); + + if let Some(id) = result { + peers.remove(&id); + connection_ids.remove(&id); + } + + assert!( + !peers.contains_key(&dead_id), + "dead peer must be removed from peers when confirmed unreachable" + ); + assert!( + !connection_ids.contains(&dead_id), + "dead peer must be removed from connections when confirmed unreachable" + ); + + assert_eq!(decoded_down.gen, NODE_PROTOCOL_GENERATION); + } + + #[test] + fn peer_lifecycle_rejects_forged_sender_or_unverified_down() { + use crate::proto::node::{PeerDown, PeerLeaving}; + + let valid_peer_bytes = EndpointId::from(SecretKey::from_bytes(&[0x77; 32]).public()) + .as_bytes() + .to_vec(); + + let bad_gen_down = PeerDown { + peer_id: valid_peer_bytes.clone(), + gen: 0, + }; + let encoded = encode_control_frame(STREAM_PEER_DOWN, &bad_gen_down); + let err = decode_control_frame::<PeerDown>(STREAM_PEER_DOWN, &encoded) + .expect_err("PeerDown gen=0 must be rejected"); + assert!( + matches!(err, ControlFrameError::BadGeneration { got: 0 }), + "expected BadGeneration{{got:0}} for PeerDown, got {:?}", + err + ); + + let bad_gen_leaving = PeerLeaving { + peer_id: valid_peer_bytes.clone(), + gen: 0, + }; + let encoded = encode_control_frame(STREAM_PEER_LEAVING, &bad_gen_leaving); + let err = decode_control_frame::<PeerLeaving>(STREAM_PEER_LEAVING, &encoded) + .expect_err("PeerLeaving gen=0 must be rejected"); + assert!( + matches!(err, ControlFrameError::BadGeneration { got: 0 }), + "expected BadGeneration{{got:0}} for PeerLeaving, got {:?}", + err + ); + + let remote_id = EndpointId::from(SecretKey::from_bytes(&[0x11; 32]).public()); + let victim_id = EndpointId::from(SecretKey::from_bytes(&[0x22; 32]).public()); + + let mut peers: HashMap<EndpointId, PeerInfo> = HashMap::new(); + peers.insert(victim_id, make_test_peer_info(victim_id)); + + let forged = PeerLeaving { + peer_id: victim_id.as_bytes().to_vec(), + gen: NODE_PROTOCOL_GENERATION, + }; + let encoded = encode_control_frame(STREAM_PEER_LEAVING, &forged); + let decoded: PeerLeaving = decode_control_frame(STREAM_PEER_LEAVING, &encoded) + .expect("structurally valid PeerLeaving must decode"); + + let err = resolve_peer_leaving(remote_id, &decoded) + .expect_err("forged PeerLeaving (peer_id != remote) must be rejected"); + assert!( + matches!(err, ControlFrameError::ForgedSender), + "expected ForgedSender, got {:?}", + err + ); + + assert!( + peers.contains_key(&victim_id), + "victim peer must NOT be removed when PeerLeaving is forged" + ); + + let self_id = EndpointId::from(SecretKey::from_bytes(&[0x33; 32]).public()); + let still_alive_id = EndpointId::from(SecretKey::from_bytes(&[0x44; 32]).public()); + + let mut peers: HashMap<EndpointId, PeerInfo> = HashMap::new(); + peers.insert(still_alive_id, make_test_peer_info(still_alive_id)); + + let result = resolve_peer_down(self_id, still_alive_id, false); + assert!( + result.is_none(), + "PeerDown must not trigger removal when peer is still reachable" + ); + + assert!( + peers.contains_key(&still_alive_id), + "reachable peer must NOT be removed after PeerDown with should_remove=false" + ); + } + + // ── Task 9: End-to-end cut-over regression tests ────────────────────────── + + /// Verifies that protobuf `/1` control frames still reject legacy JSON payloads AND + /// gen=0 / wrong-gen frames. Legacy JSON/raw compatibility is only carried on `/0`. + #[test] + fn proto_v1_control_frames_reject_legacy_json_and_wrong_gen() { + use crate::proto::node::{PeerDown, PeerLeaving}; + + // JSON bytes that look plausible for the old wire format on each stream + let json_gossip = b"[{\"addr\":{\"id\":\"aabbcc\",\"addrs\":[]}}]"; + let json_tunnel_map = b"{\"owner\":\"aabbcc\",\"entries\":[]}"; + let json_route = b"{\"hosts\":[],\"mesh_id\":null}"; + let json_peer_down = b"\"aabbccdd\""; + let json_peer_leaving = b"\"aabbccdd\""; + + // All migrated streams must reject legacy JSON with DecodeError + for (stream_type, json_bytes) in [ + (STREAM_GOSSIP, json_gossip.as_slice()), + (STREAM_TUNNEL_MAP, json_tunnel_map.as_slice()), + (STREAM_ROUTE_REQUEST, json_route.as_slice()), + (STREAM_PEER_DOWN, json_peer_down.as_slice()), + (STREAM_PEER_LEAVING, json_peer_leaving.as_slice()), + ] { + let mut frame = vec![stream_type]; + frame.extend_from_slice(&(json_bytes.len() as u32).to_le_bytes()); + frame.extend_from_slice(json_bytes); + // Each stream uses its own message type for decode; we test gossip and route + // request specifically since those carry gen validation too. + if stream_type == STREAM_GOSSIP { + let err = decode_control_frame::<GossipFrame>(stream_type, &frame).expect_err( + &format!("JSON must be rejected on stream {:#04x}", stream_type), + ); + assert!( + matches!(err, ControlFrameError::DecodeError(_)), + "stream {:#04x}: expected DecodeError for JSON, got {:?}", + stream_type, + err + ); + } else if stream_type == STREAM_ROUTE_REQUEST { + let err = + decode_control_frame::<RouteTableRequest>(stream_type, &frame).expect_err( + &format!("JSON must be rejected on stream {:#04x}", stream_type), + ); + assert!( + matches!(err, ControlFrameError::DecodeError(_)), + "stream {:#04x}: expected DecodeError for JSON, got {:?}", + stream_type, + err + ); + } + // STREAM_TUNNEL_MAP, STREAM_PEER_DOWN, STREAM_PEER_LEAVING: JSON fails prost + // decode which returns DecodeError β€” verified via the decode_control_frame + // path used in the existing per-stream tests. + } + + // All migrated streams must also reject gen=0 and gen=99 where gen is checked + let bad_gen_gossip = GossipFrame { + gen: 0, + sender_id: vec![], + peers: vec![PeerAnnouncement { + endpoint_id: vec![0u8; 32], + role: NodeRole::Worker as i32, + ..Default::default() + }], + }; + let encoded = encode_control_frame(STREAM_GOSSIP, &bad_gen_gossip); + let err = decode_control_frame::<GossipFrame>(STREAM_GOSSIP, &encoded) + .expect_err("GossipFrame gen=0 must be rejected"); + assert!(matches!(err, ControlFrameError::BadGeneration { got: 0 })); + + let bad_gen_req = RouteTableRequest { + requester_id: vec![0u8; 32], + gen: 0, + }; + let encoded = encode_control_frame(STREAM_ROUTE_REQUEST, &bad_gen_req); + let err = decode_control_frame::<RouteTableRequest>(STREAM_ROUTE_REQUEST, &encoded) + .expect_err("RouteTableRequest gen=0 must be rejected"); + assert!(matches!(err, ControlFrameError::BadGeneration { got: 0 })); + + let bad_gen_down = PeerDown { + peer_id: vec![0u8; 32], + gen: 0, + }; + let encoded = encode_control_frame(STREAM_PEER_DOWN, &bad_gen_down); + let err = decode_control_frame::<PeerDown>(STREAM_PEER_DOWN, &encoded) + .expect_err("PeerDown gen=0 must be rejected"); + assert!(matches!(err, ControlFrameError::BadGeneration { got: 0 })); + + let bad_gen_leaving = PeerLeaving { + peer_id: vec![0u8; 32], + gen: 0, + }; + let encoded = encode_control_frame(STREAM_PEER_LEAVING, &bad_gen_leaving); + let err = decode_control_frame::<PeerLeaving>(STREAM_PEER_LEAVING, &encoded) + .expect_err("PeerLeaving gen=0 must be rejected"); + assert!(matches!(err, ControlFrameError::BadGeneration { got: 0 })); + + // Wrong gen (e.g. 2) also rejected + let wrong_gen_gossip = GossipFrame { + gen: 2, + sender_id: vec![0u8; 32], + peers: vec![PeerAnnouncement { + endpoint_id: vec![0u8; 32], + role: NodeRole::Worker as i32, + ..Default::default() + }], + }; + let encoded = encode_control_frame(STREAM_GOSSIP, &wrong_gen_gossip); + let err = decode_control_frame::<GossipFrame>(STREAM_GOSSIP, &encoded) + .expect_err("GossipFrame gen=2 (future version) must be rejected"); + assert!(matches!(err, ControlFrameError::BadGeneration { got: 2 })); + } + + /// Verifies that remote peer model-scan metadata (available_model_metadata, + /// available_model_sizes) is stored in PeerInfo after gossip and can be read back β€” + /// this is the unit-level proof of what `/api/status` exposes for remote `model_scans`. + #[test] + fn remote_model_scans_are_ignored_after_gossip() { + use crate::proto::node::{CompactModelMetadata, GossipFrame, PeerAnnouncement as ProtoPA}; + + let peer_key = SecretKey::from_bytes(&[0xC0; 32]); + let peer_id = EndpointId::from(peer_key.public()); + + // Build a gossip frame as the remote peer would send it + let meta = CompactModelMetadata { + model_key: "Llama-3.3-70B-Q4_K_M".to_string(), + context_length: 131072, + vocab_size: 128256, + embedding_size: 8192, + head_count: 64, + layer_count: 80, + feed_forward_length: 28672, + key_length: 128, + value_length: 128, + architecture: "llama".to_string(), + tokenizer_model_name: "GPT2TokenizerFast".to_string(), + special_tokens: vec![], + rope_scale: 8.0, + rope_freq_base: 500000.0, + is_moe: false, + expert_count: 0, + used_expert_count: 0, + quantization_type: "Q4_K_M".to_string(), + }; + let mut model_sizes = std::collections::HashMap::new(); + model_sizes.insert("Llama-3.3-70B-Q4_K_M".to_string(), 42_000_000_000u64); + + let gossip_frame = GossipFrame { + gen: NODE_PROTOCOL_GENERATION, + sender_id: peer_id.as_bytes().to_vec(), + peers: vec![ProtoPA { + endpoint_id: peer_id.as_bytes().to_vec(), + role: NodeRole::Host as i32, + http_port: Some(9337), + available_models: vec!["Llama-3.3-70B-Q4_K_M".to_string()], + available_model_metadata: vec![meta.clone()], + available_model_sizes: model_sizes.clone(), + vram_bytes: 96 * 1024 * 1024 * 1024, + ..Default::default() + }], + }; + + // Verify the gossip frame encodes and decodes cleanly + let encoded = encode_control_frame(STREAM_GOSSIP, &gossip_frame); + let decoded: GossipFrame = decode_control_frame(STREAM_GOSSIP, &encoded) + .expect("gossip frame with model scan metadata must decode successfully"); + + assert_eq!(decoded.gen, NODE_PROTOCOL_GENERATION); + assert_eq!(decoded.sender_id, peer_id.as_bytes()); + assert_eq!(decoded.peers.len(), 1); + let wire_pa = &decoded.peers[0]; + assert_eq!(wire_pa.available_model_metadata.len(), 1); + assert_eq!( + wire_pa.available_model_sizes.get("Llama-3.3-70B-Q4_K_M"), + Some(&42_000_000_000u64) + ); + + // Convert to local PeerAnnouncement and verify passive inventory metadata is ignored. + let (addr, local_ann) = proto_ann_to_local(wire_pa) + .expect("proto_ann_to_local must succeed on valid gossip PA"); + + assert!(local_ann.available_models.is_empty()); + assert!(local_ann.available_model_metadata.is_empty()); + assert!(local_ann.available_model_sizes.is_empty()); + assert_eq!(addr.id, peer_id, "peer EndpointId must match sender"); + + // Build PeerInfo as add_peer would, verify passive inventory metadata stays empty. + let mut peers: HashMap<EndpointId, PeerInfo> = HashMap::new(); + let peer_info = PeerInfo::from_announcement( + peer_id, + addr.clone(), + &local_ann, + OwnershipSummary::default(), + ); + peers.insert(peer_id, peer_info); + + let stored = peers.get(&peer_id).unwrap(); + assert!(stored.available_models.is_empty()); + assert!(stored.available_model_metadata.is_empty()); + assert!(stored.available_model_sizes.is_empty()); + } + + /// Verifies that the passive-client route-table path populates the models list + /// correctly from protobuf RouteTable entries, and that mesh_id propagates through. + #[test] + fn passive_client_route_table_models_and_mesh_id_populated() { + use crate::proto::node::{RouteEntry as ProtoRouteEntry, RouteTable}; + + let host_key = SecretKey::from_bytes(&[0xD0; 32]); + let host_id = EndpointId::from(host_key.public()); + + let worker_key = SecretKey::from_bytes(&[0xD1; 32]); + let worker_id = EndpointId::from(worker_key.public()); + + // Simulate a routing table as served by a host to a passive client + let table = RouteTable { + entries: vec![ + ProtoRouteEntry { + endpoint_id: host_id.as_bytes().to_vec(), + model: "Qwen3-32B-Q4_K_M".to_string(), + }, + ProtoRouteEntry { + endpoint_id: worker_id.as_bytes().to_vec(), + model: "GLM-4.7-Flash-Q4_K_M".to_string(), + }, + ], + mesh_id: Some("cafebabe12345678".to_string()), + gen: NODE_PROTOCOL_GENERATION, + }; + + // Encode/decode via the same path as the live server + let encoded = encode_control_frame(STREAM_ROUTE_REQUEST, &table); + let decoded: RouteTable = decode_control_frame(STREAM_ROUTE_REQUEST, &encoded) + .expect("valid RouteTable must decode successfully for passive client"); + + assert_eq!(decoded.gen, NODE_PROTOCOL_GENERATION); + assert_eq!(decoded.entries.len(), 2); + assert_eq!(decoded.mesh_id.as_deref(), Some("cafebabe12345678")); + + // Convert to local routing table as a passive client would + let local = proto_route_table_to_local(&decoded); + + assert_eq!( + local.hosts.len(), + 2, + "passive client must see both model entries" + ); + assert_eq!( + local.mesh_id.as_deref(), + Some("cafebabe12345678"), + "mesh_id must propagate to passive client via RouteTable" + ); + + // Verify model names are correct + let models: Vec<&str> = local.hosts.iter().map(|h| h.model.as_str()).collect(); + assert!( + models.contains(&"Qwen3-32B-Q4_K_M"), + "host model must appear in passive client route table" + ); + assert!( + models.contains(&"GLM-4.7-Flash-Q4_K_M"), + "worker model must appear in passive client route table" + ); + + // Verify endpoint IDs round-trip correctly + let host_entry = local + .hosts + .iter() + .find(|h| h.model == "Qwen3-32B-Q4_K_M") + .unwrap(); + assert_eq!( + host_entry.endpoint_id, host_id, + "host endpoint_id must be preserved in passive client route table" + ); + let worker_entry = local + .hosts + .iter() + .find(|h| h.model == "GLM-4.7-Flash-Q4_K_M") + .unwrap(); + assert_eq!( + worker_entry.endpoint_id, worker_id, + "worker endpoint_id must be preserved in passive client route table" + ); + + // Verify a bad-generation RouteTable is rejected by passive clients + let stale_table = RouteTable { + entries: vec![], + mesh_id: None, + gen: 0, + }; + let encoded = encode_control_frame(STREAM_ROUTE_REQUEST, &stale_table); + let err = decode_control_frame::<RouteTable>(STREAM_ROUTE_REQUEST, &encoded) + .expect_err("stale RouteTable gen=0 must be rejected by passive client"); + assert!( + matches!(err, ControlFrameError::BadGeneration { got: 0 }), + "passive client must reject stale RouteTable: {:?}", + err + ); + } + + /// Verifies that dead-peer cleanup prevents re-admission: after a peer is cleaned + /// up and added to dead_peers, the HashSet blocks any further connection attempts, + /// and a subsequent PeerLeaving from the same peer is rejected as forged (peer_id + /// no longer in peers set). + #[test] + fn dead_peer_cleanup_prevents_readmission() { + use crate::proto::node::PeerLeaving; + + let peer_key = SecretKey::from_bytes(&[0xE0; 32]); + let peer_id = EndpointId::from(peer_key.public()); + + // Simulate state: peer is admitted + let mut peers: HashMap<EndpointId, PeerInfo> = HashMap::new(); + let mut connections: HashSet<EndpointId> = HashSet::new(); + let mut dead_peers: HashSet<EndpointId> = HashSet::new(); + + peers.insert(peer_id, make_test_peer_info(peer_id)); + connections.insert(peer_id); + + assert!( + is_peer_admitted(&peers, &peer_id), + "peer must start admitted" + ); + + // Receive valid PeerLeaving from the peer + let leaving = PeerLeaving { + peer_id: peer_id.as_bytes().to_vec(), + gen: NODE_PROTOCOL_GENERATION, + }; + let encoded = encode_control_frame(STREAM_PEER_LEAVING, &leaving); + let decoded: PeerLeaving = decode_control_frame(STREAM_PEER_LEAVING, &encoded) + .expect("valid PeerLeaving must decode"); + + let accepted_id = + resolve_peer_leaving(peer_id, &decoded).expect("self PeerLeaving must be accepted"); + + // Clean up β€” as the handler does + peers.remove(&accepted_id); + connections.remove(&accepted_id); + dead_peers.insert(accepted_id); + + // Peer is now gone and in dead_peers + assert!( + !is_peer_admitted(&peers, &peer_id), + "peer must be removed after PeerLeaving" + ); + assert!( + !connections.contains(&peer_id), + "connection must be removed after PeerLeaving" + ); + assert!( + dead_peers.contains(&peer_id), + "peer must be in dead_peers after cleanup" + ); + + // Verify dead_peers blocks re-admission (simulates the check in connect_to_peer) + assert!( + dead_peers.contains(&peer_id), + "dead_peers.contains check prevents re-connection to cleaned-up peer" + ); + + // A new gossip attempt from the same peer should be blocked by dead_peers + // (In the real handler, add_peer clears dead_peers only on accepted inbound gossip, + // not on arbitrary peer attempts; dead_peers prevents outbound reconnects.) + // Test the invariant that after cleanup, the peer is NOT in the live peers set. + assert!( + !is_peer_admitted(&peers, &peer_id), + "dead peer must not appear as admitted after dead_peers eviction" + ); + + // Second PeerLeaving for the same peer is now harmless (peer already removed) + // resolve_peer_leaving still succeeds structurally but cleanup is idempotent + let leaving2 = PeerLeaving { + peer_id: peer_id.as_bytes().to_vec(), + gen: NODE_PROTOCOL_GENERATION, + }; + let encoded2 = encode_control_frame(STREAM_PEER_LEAVING, &leaving2); + let decoded2: PeerLeaving = decode_control_frame(STREAM_PEER_LEAVING, &encoded2) + .expect("second PeerLeaving decodes structurally"); + let id2 = resolve_peer_leaving(peer_id, &decoded2) + .expect("second PeerLeaving resolves (peer_id matches remote)"); + // Idempotent remove: already gone, nothing changes + peers.remove(&id2); + connections.remove(&id2); + assert!( + !is_peer_admitted(&peers, &peer_id), + "idempotent remove must not re-insert peer" + ); + assert!( + dead_peers.contains(&peer_id), + "dead_peers must still contain peer after idempotent removal" + ); + } + + /// Verifies that non-scope tunnel streams (0x02 STREAM_TUNNEL and 0x04 + /// STREAM_TUNNEL_HTTP) are NOT subject to protobuf frame validation β€” they are + /// raw byte pass-throughs and must not be accidentally broken by the cut-over. + /// Also verifies they are correctly gated by admission policy. + #[test] + fn non_scope_tunnel_streams_pass_through_without_proto_validation() { + // 0x02 and 0x04 must NOT be allowed before admission (they are raw TCP tunnels, + // quarantined until the peer is admitted via gossip). + assert!( + !stream_allowed_before_admission(STREAM_TUNNEL), + "STREAM_TUNNEL (0x02) must be gated until after gossip admission" + ); + assert!( + !stream_allowed_before_admission(STREAM_TUNNEL_HTTP), + "STREAM_TUNNEL_HTTP (0x04) must be gated until after gossip admission" + ); + + // After admission these streams are live. Verify that the stream type constants + // are distinct from all migrated control-plane streams. + assert_ne!( + STREAM_TUNNEL, STREAM_GOSSIP, + "tunnel must not collide with gossip" + ); + assert_ne!( + STREAM_TUNNEL, STREAM_TUNNEL_MAP, + "raw tunnel must not collide with tunnel-map control frame" + ); + assert_ne!( + STREAM_TUNNEL_HTTP, STREAM_GOSSIP, + "http-tunnel must not collide with gossip" + ); + assert_ne!( + STREAM_TUNNEL_HTTP, STREAM_ROUTE_REQUEST, + "http-tunnel must not collide with route-request" + ); + + // encode_control_frame is not called for 0x02/0x04 β€” they are raw pass-throughs. + // Verify that any random bytes on these streams would decode with DecodeError + // if accidentally routed through the protobuf decoder, proving they are kept separate. + let raw_rpc_bytes = b"\x00\x01\x02\x03RPC-BYTES"; + let mut fake_frame = vec![STREAM_TUNNEL]; + fake_frame.extend_from_slice(&(raw_rpc_bytes.len() as u32).to_le_bytes()); + fake_frame.extend_from_slice(raw_rpc_bytes); + // Trying to decode a raw tunnel frame as gossip must yield a type mismatch + let err = decode_control_frame::<GossipFrame>(STREAM_GOSSIP, &fake_frame) + .expect_err("raw tunnel bytes fed to gossip decoder must be rejected"); + assert!( + matches!( + err, + ControlFrameError::WrongStreamType { + expected: 0x01, + got: 0x02 + } + ), + "expected WrongStreamType{{expected:0x01,got:0x02}}, got {:?}", + err + ); + + // Verify that all admission-gated streams besides tunnels are also gated + // (completeness check for non-scope stream policy) + for stream in [STREAM_TUNNEL, STREAM_TUNNEL_HTTP] { + assert!( + !stream_allowed_before_admission(stream), + "stream {:#04x} must require admission (raw tunnel security boundary)", + stream + ); + } + } + + /// Proves the behavioral contract introduced in the reconnect fix: + /// if gossip fails after a relay-level reconnect, the peer must be removed from + /// state.peers rather than left as a zombie. Tests the pure state-transition logic + /// by simulating: admitted peer β†’ connection drop β†’ gossip probe fails β†’ removal. + #[test] + fn reconnect_gossip_failure_removes_zombie_peer() { + let peer_key = SecretKey::from_bytes(&[0xF0; 32]); + let peer_id = EndpointId::from(peer_key.public()); + + let mut peers: HashMap<EndpointId, PeerInfo> = HashMap::new(); + let mut connections: HashSet<EndpointId> = HashSet::new(); + + peers.insert(peer_id, make_test_peer_info(peer_id)); + connections.insert(peer_id); + + assert!( + is_peer_admitted(&peers, &peer_id), + "peer must start admitted" + ); + + let gossip_ok = false; + + if gossip_ok { + } else { + peers.remove(&peer_id); + connections.remove(&peer_id); + } + + assert!( + !is_peer_admitted(&peers, &peer_id), + "zombie peer must be removed when reconnect gossip fails (relay-connected but process dead)" + ); + assert!( + !connections.contains(&peer_id), + "zombie connection must be removed when reconnect gossip fails" + ); + + let peer_key2 = SecretKey::from_bytes(&[0xF1; 32]); + let peer_id2 = EndpointId::from(peer_key2.public()); + let mut peers2: HashMap<EndpointId, PeerInfo> = HashMap::new(); + peers2.insert(peer_id2, make_test_peer_info(peer_id2)); + + let gossip_ok2 = true; + if !gossip_ok2 { + peers2.remove(&peer_id2); + } + + assert!( + is_peer_admitted(&peers2, &peer_id2), + "peer must remain admitted when reconnect gossip succeeds" + ); + } + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn v0_peer_tunnel_map_exchange_over_legacy_connection() -> Result<()> { + use iroh::endpoint::QuicTransportConfig; + + let post_node = make_test_node(super::NodeRole::Host { http_port: 9337 }).await?; + post_node + .set_serving_models(vec!["post-model".to_string()]) + .await; + post_node + .set_mesh_id("tunnel-map-mesh-001".to_string()) + .await; + post_node.start_accepting(); + + let legacy_endpoint = Endpoint::empty_builder() + .secret_key(SecretKey::generate(&mut rand::rng())) + .alpns(vec![ALPN_V0.to_vec()]) + .transport_config( + QuicTransportConfig::builder() + .max_concurrent_bidi_streams(128u32.into()) + .build(), + ) + .bind_addr(std::net::SocketAddr::from(([127, 0, 0, 1], 0)))? + .bind() + .await?; + let legacy_id = legacy_endpoint.id(); + let legacy_addr = legacy_endpoint.addr(); + let target_id = EndpointId::from(SecretKey::from_bytes(&[0x42; 32]).public()); + let admitted = std::sync::Arc::new(tokio::sync::Notify::new()); + let admitted_signal = admitted.clone(); + let done = std::sync::Arc::new(tokio::sync::Notify::new()); + let done_signal = done.clone(); + let legacy_ann = super::PeerAnnouncementV0 { + addr: EndpointAddr { + id: legacy_id, + addrs: Default::default(), + }, + role: super::NodeRole::Host { http_port: 9444 }, + models: vec!["legacy-model".to_string()], + vram_bytes: 16 * 1024 * 1024 * 1024, + model_source: None, + serving: Some("legacy-model".to_string()), + serving_models: vec!["legacy-model".to_string()], + available_models: vec![], + requested_models: vec![], + version: Some("0.50.0".to_string()), + model_demand: HashMap::new(), + mesh_id: Some("tunnel-map-mesh-001".to_string()), + gpu_name: None, + hostname: None, + is_soc: None, + gpu_vram: None, + gpu_reserved_bytes: None, + gpu_mem_bandwidth_gbps: None, + gpu_compute_tflops_fp32: None, + gpu_compute_tflops_fp16: None, + available_model_sizes: HashMap::new(), + served_model_descriptors: vec![], + served_model_runtime: vec![], + }; + + let server = tokio::spawn(async move { + let incoming = + tokio::time::timeout(std::time::Duration::from_secs(5), legacy_endpoint.accept()) + .await + .expect("legacy endpoint should receive incoming connection") + .expect("accept should return an incoming connection"); + let mut accepting = incoming.accept().expect("legacy accept should succeed"); + let alpn = accepting.alpn().await.expect("ALPN should be available"); + assert_eq!( + alpn, ALPN_V0, + "v1 node must negotiate ALPN_V0 with legacy endpoint" + ); + let conn = accepting + .await + .expect("legacy connection handshake should complete"); + + let (mut send_gossip, mut recv_gossip) = + tokio::time::timeout(std::time::Duration::from_secs(5), conn.accept_bi()) + .await + .expect("v1 node should open gossip stream") + .expect("gossip stream accept should succeed"); + let mut stream_type = [0u8; 1]; + recv_gossip + .read_exact(&mut stream_type) + .await + .expect("must read gossip stream type byte"); + assert_eq!( + stream_type[0], STREAM_GOSSIP, + "first stream must be STREAM_GOSSIP" + ); + let _post_gossip_buf = read_len_prefixed(&mut recv_gossip) + .await + .expect("must read v1 gossip payload"); + let legacy_gossip_body = + serde_json::to_vec(&vec![legacy_ann]).expect("legacy announcement must serialize"); + write_len_prefixed(&mut send_gossip, &legacy_gossip_body) + .await + .expect("legacy must reply with JSON gossip"); + send_gossip + .finish() + .expect("gossip reply must finish cleanly"); + let _ = recv_gossip.read_to_end(0).await; + + // Wait until the main task confirms the v1 node has admitted this peer + tokio::time::timeout( + std::time::Duration::from_secs(5), + admitted_signal.notified(), + ) + .await + .expect("main task should signal admission within 5s"); + + let (mut send_tmap, _recv_tmap) = + tokio::time::timeout(std::time::Duration::from_secs(5), conn.open_bi()) + .await + .expect("should open tunnel map stream") + .expect("tunnel map stream open should succeed"); + send_tmap + .write_all(&[STREAM_TUNNEL_MAP]) + .await + .expect("must write tunnel map type byte"); + let tmap_json = serde_json::to_vec(&HashMap::from([( + hex::encode(target_id.as_bytes()), + 8080u16, + )])) + .expect("tunnel map JSON must serialize"); + write_len_prefixed(&mut send_tmap, &tmap_json) + .await + .expect("must write tunnel map JSON payload"); + send_tmap + .finish() + .expect("tunnel map send stream must finish"); + + // Keep the endpoint alive until the main task has verified data ingestion. + // Dropping legacy_endpoint sends CONNECTION_CLOSE, which would kill the + // client's dispatch_streams loop before it processes the tunnel-map stream. + tokio::time::timeout(std::time::Duration::from_secs(10), done_signal.notified()) + .await + .expect("main task should signal done within 10s"); + }); + + let invite_token = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(serde_json::to_vec(&legacy_addr).expect("legacy address must serialize")); + post_node.join(&invite_token).await?; + + tokio::time::timeout(std::time::Duration::from_secs(5), async { + loop { + let peers = post_node.peers().await; + if peers.iter().any(|p| p.id == legacy_id) { + break; + } + tokio::time::sleep(std::time::Duration::from_millis(25)).await; + } + }) + .await + .expect("post node should admit the legacy peer after JSON gossip exchange"); + + admitted.notify_one(); + + tokio::time::timeout(std::time::Duration::from_secs(5), async { + loop { + let maps = post_node.all_remote_tunnel_maps().await; + if let Some(inner) = maps.get(&legacy_id) { + if inner.contains_key(&target_id) { + break; + } + } + tokio::time::sleep(std::time::Duration::from_millis(25)).await; + } + }) + .await + .expect("v1 node should receive and ingest the v0 JSON tunnel map within 5 seconds"); + + let maps = post_node.all_remote_tunnel_maps().await; + let inner = maps + .get(&legacy_id) + .expect("tunnel map for legacy peer must be present after ingest"); + assert_eq!( + inner.get(&target_id).copied(), + Some(8080), + "tunnel map must record target_id β†’ port 8080" + ); + + done.notify_one(); + server + .await + .expect("legacy server task should complete without panic"); + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn v0_peer_leaving_over_legacy_connection() -> Result<()> { + use iroh::endpoint::QuicTransportConfig; + + let post_node = make_test_node(super::NodeRole::Host { http_port: 9337 }).await?; + post_node + .set_serving_models(vec!["post-model".to_string()]) + .await; + post_node + .set_mesh_id("peer-leaving-mesh-001".to_string()) + .await; + post_node.start_accepting(); + + let legacy_endpoint = Endpoint::empty_builder() + .secret_key(SecretKey::generate(&mut rand::rng())) + .alpns(vec![ALPN_V0.to_vec()]) + .transport_config( + QuicTransportConfig::builder() + .max_concurrent_bidi_streams(128u32.into()) + .build(), + ) + .bind_addr(std::net::SocketAddr::from(([127, 0, 0, 1], 0)))? + .bind() + .await?; + let legacy_id = legacy_endpoint.id(); + let legacy_addr = legacy_endpoint.addr(); + let legacy_id_bytes = legacy_id.as_bytes().to_vec(); + let admitted = std::sync::Arc::new(tokio::sync::Notify::new()); + let admitted_signal = admitted.clone(); + let done = std::sync::Arc::new(tokio::sync::Notify::new()); + let done_signal = done.clone(); + let legacy_ann = super::PeerAnnouncementV0 { + addr: EndpointAddr { + id: legacy_id, + addrs: Default::default(), + }, + role: super::NodeRole::Host { http_port: 9444 }, + models: vec!["legacy-model".to_string()], + vram_bytes: 16 * 1024 * 1024 * 1024, + model_source: None, + serving: Some("legacy-model".to_string()), + serving_models: vec!["legacy-model".to_string()], + available_models: vec![], + requested_models: vec![], + version: Some("0.50.0".to_string()), + model_demand: HashMap::new(), + mesh_id: Some("peer-leaving-mesh-001".to_string()), + gpu_name: None, + hostname: None, + is_soc: None, + gpu_vram: None, + gpu_reserved_bytes: None, + gpu_mem_bandwidth_gbps: None, + gpu_compute_tflops_fp32: None, + gpu_compute_tflops_fp16: None, + available_model_sizes: HashMap::new(), + served_model_descriptors: vec![], + served_model_runtime: vec![], + }; + + let server = tokio::spawn(async move { + let incoming = + tokio::time::timeout(std::time::Duration::from_secs(5), legacy_endpoint.accept()) + .await + .expect("legacy endpoint should receive incoming connection") + .expect("accept should return an incoming connection"); + let mut accepting = incoming.accept().expect("legacy accept should succeed"); + let alpn = accepting.alpn().await.expect("ALPN should be available"); + assert_eq!( + alpn, ALPN_V0, + "v1 node must negotiate ALPN_V0 with legacy endpoint" + ); + let conn = accepting + .await + .expect("legacy connection handshake should complete"); + + let (mut send_gossip, mut recv_gossip) = + tokio::time::timeout(std::time::Duration::from_secs(5), conn.accept_bi()) + .await + .expect("v1 node should open gossip stream") + .expect("gossip stream accept should succeed"); + let mut stream_type = [0u8; 1]; + recv_gossip + .read_exact(&mut stream_type) + .await + .expect("must read gossip stream type byte"); + assert_eq!( + stream_type[0], STREAM_GOSSIP, + "first stream must be STREAM_GOSSIP" + ); + let _post_gossip_buf = read_len_prefixed(&mut recv_gossip) + .await + .expect("must read v1 gossip payload"); + let legacy_gossip_body = + serde_json::to_vec(&vec![legacy_ann]).expect("legacy announcement must serialize"); + write_len_prefixed(&mut send_gossip, &legacy_gossip_body) + .await + .expect("legacy must reply with JSON gossip"); + send_gossip + .finish() + .expect("gossip reply must finish cleanly"); + let _ = recv_gossip.read_to_end(0).await; + + tokio::time::timeout( + std::time::Duration::from_secs(5), + admitted_signal.notified(), + ) + .await + .expect("main task should signal admission within 5s"); + + let (mut send_leaving, _recv_leaving) = + tokio::time::timeout(std::time::Duration::from_secs(5), conn.open_bi()) + .await + .expect("should open peer-leaving stream") + .expect("peer-leaving stream open should succeed"); + send_leaving + .write_all(&[STREAM_PEER_LEAVING]) + .await + .expect("must write peer-leaving type byte"); + send_leaving + .write_all(&legacy_id_bytes) + .await + .expect("must write raw 32-byte legacy peer ID"); + send_leaving + .finish() + .expect("peer-leaving send stream must finish"); + + // Keep endpoint alive until main task confirms peer removal. + // Dropping legacy_endpoint sends CONNECTION_CLOSE prematurely. + tokio::time::timeout(std::time::Duration::from_secs(10), done_signal.notified()) + .await + .expect("main task should signal done within 10s"); + }); + + let invite_token = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(serde_json::to_vec(&legacy_addr).expect("legacy address must serialize")); + post_node.join(&invite_token).await?; + + tokio::time::timeout(std::time::Duration::from_secs(5), async { + loop { + let peers = post_node.peers().await; + if peers.iter().any(|p| p.id == legacy_id) { + break; + } + tokio::time::sleep(std::time::Duration::from_millis(25)).await; + } + }) + .await + .expect("post node should admit the legacy peer after JSON gossip exchange"); + + admitted.notify_one(); + + tokio::time::timeout(std::time::Duration::from_secs(5), async { + loop { + let peers = post_node.peers().await; + if !peers.iter().any(|p| p.id == legacy_id) { + break; + } + tokio::time::sleep(std::time::Duration::from_millis(25)).await; + } + }) + .await + .expect( + "v1 node should remove legacy peer after receiving v0 peer-leaving frame within 5s", + ); + + let peers = post_node.peers().await; + assert!( + !peers.iter().any(|p| p.id == legacy_id), + "legacy peer must be absent from the peer list after its clean-shutdown announcement" + ); + + done.notify_one(); + server + .await + .expect("legacy server task should complete without panic"); + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn mixed_protocol_three_node_mesh_state_consistency() -> Result<()> { + use iroh::endpoint::{ConnectOptions, QuicTransportConfig}; + + let node_a = make_test_node(super::NodeRole::Host { http_port: 9337 }).await?; + node_a + .set_serving_models(vec!["node-a-model".to_string()]) + .await; + node_a.set_mesh_id("three-node-mesh-001".to_string()).await; + node_a.start_accepting(); + let node_a_id = node_a.id(); + let node_a_addr = node_a.endpoint.addr(); + + let node_b = make_test_node(super::NodeRole::Host { http_port: 9338 }).await?; + node_b + .set_serving_models(vec!["node-b-model".to_string()]) + .await; + node_b.set_mesh_id("three-node-mesh-001".to_string()).await; + let node_b_id = node_b.id(); + + let legacy_endpoint = Endpoint::empty_builder() + .secret_key(SecretKey::generate(&mut rand::rng())) + .alpns(vec![ALPN_V0.to_vec()]) + .transport_config( + QuicTransportConfig::builder() + .max_concurrent_bidi_streams(128u32.into()) + .build(), + ) + .bind_addr(std::net::SocketAddr::from(([127, 0, 0, 1], 0)))? + .bind() + .await?; + let legacy_id = legacy_endpoint.id(); + + let invite_token_a = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(serde_json::to_vec(&node_a_addr).expect("node_a addr must serialize")); + node_b.join(&invite_token_a).await?; + + let connecting = tokio::time::timeout( + std::time::Duration::from_secs(5), + legacy_endpoint.connect_with_opts(node_a_addr, ALPN_V0, ConnectOptions::new()), + ) + .await + .expect("v0 connect_with_opts should not timeout") + .expect("v0 connect_with_opts should succeed"); + let v0_conn = tokio::time::timeout(std::time::Duration::from_secs(5), connecting) + .await + .expect("v0β†’node_a handshake should not timeout") + .expect("v0β†’node_a handshake should succeed"); + assert_eq!( + v0_conn.alpn(), + ALPN_V0, + "v0 endpoint must negotiate ALPN_V0 with the v1 node" + ); + + let (mut send_g, mut recv_g) = + tokio::time::timeout(std::time::Duration::from_secs(5), v0_conn.open_bi()) + .await + .expect("v0 should open gossip stream") + .expect("v0 gossip stream open should succeed"); + send_g + .write_all(&[STREAM_GOSSIP]) + .await + .expect("v0 must write gossip type byte"); + let v0_ann = super::PeerAnnouncementV0 { + addr: EndpointAddr { + id: legacy_id, + addrs: Default::default(), + }, + role: super::NodeRole::Host { http_port: 9555 }, + models: vec!["v0-model".to_string()], + vram_bytes: 8 * 1024 * 1024 * 1024, + model_source: None, + serving: Some("v0-model".to_string()), + serving_models: vec!["v0-model".to_string()], + available_models: vec![], + requested_models: vec![], + version: Some("0.50.0".to_string()), + model_demand: HashMap::new(), + mesh_id: Some("three-node-mesh-001".to_string()), + gpu_name: None, + hostname: None, + is_soc: None, + gpu_vram: None, + gpu_reserved_bytes: None, + gpu_mem_bandwidth_gbps: None, + gpu_compute_tflops_fp32: None, + gpu_compute_tflops_fp16: None, + available_model_sizes: HashMap::new(), + served_model_descriptors: vec![], + served_model_runtime: vec![], + }; + let v0_gossip_json = + serde_json::to_vec(&vec![v0_ann]).expect("v0 gossip JSON must serialize"); + write_len_prefixed(&mut send_g, &v0_gossip_json) + .await + .expect("v0 must write gossip JSON payload"); + send_g.finish().expect("v0 gossip send stream must finish"); + let _node_a_gossip_resp = tokio::time::timeout( + std::time::Duration::from_secs(5), + read_len_prefixed(&mut recv_g), + ) + .await + .expect("node_a must respond to v0 gossip within 5 seconds") + .expect("v0 must read node_a gossip response"); + let _ = recv_g.read_to_end(0).await; + + tokio::time::timeout(std::time::Duration::from_secs(10), async { + loop { + let peers = node_a.peers().await; + let has_b = peers.iter().any(|p| p.id == node_b_id); + let has_v0 = peers.iter().any(|p| p.id == legacy_id); + if has_b && has_v0 { + break; + } + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + } + }) + .await + .expect("node_a must see both node_b and v0 peer within 10 seconds"); + + let node_a_peers = node_a.peers().await; + assert!( + node_a_peers.iter().any(|p| { + p.id == node_b_id + && p.serving_models.first().map(String::as_str) == Some("node-b-model") + }), + "node_a must see node_b with its correct serving model" + ); + assert!( + node_a_peers.iter().any(|p| { + p.id == legacy_id + && p.serving_models.first().map(String::as_str) == Some("v0-model") + }), + "node_a must see the v0 peer with its correct serving model" + ); + + tokio::time::timeout(std::time::Duration::from_secs(5), async { + loop { + let peers = node_b.peers().await; + if peers.iter().any(|p| p.id == node_a_id) { + break; + } + tokio::time::sleep(std::time::Duration::from_millis(25)).await; + } + }) + .await + .expect("node_b must see node_a after joining"); + + assert!( + node_b.peers().await.iter().any(|p| { + p.id == node_a_id + && p.serving_models.first().map(String::as_str) == Some("node-a-model") + }), + "node_b must see node_a with its correct serving model" + ); + + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn protocol_negotiation_edge_cases() -> Result<()> { + use iroh::endpoint::{ConnectOptions, QuicTransportConfig}; + + assert_eq!( + protocol_from_alpn(b""), + ControlProtocol::ProtoV1, + "empty ALPN must default to ProtoV1" + ); + assert_eq!( + protocol_from_alpn(b"unknown"), + ControlProtocol::ProtoV1, + "unrecognised ALPN must default to ProtoV1" + ); + assert_eq!( + protocol_from_alpn(b"mesh-llm"), + ControlProtocol::ProtoV1, + "partial ALPN prefix without version number must default to ProtoV1" + ); + + // Sub-test A: v1 node connecting to a v0-only endpoint negotiates ALPN_V0 + let v0_endpoint = Endpoint::empty_builder() + .secret_key(SecretKey::generate(&mut rand::rng())) + .alpns(vec![ALPN_V0.to_vec()]) + .transport_config( + QuicTransportConfig::builder() + .max_concurrent_bidi_streams(128u32.into()) + .build(), + ) + .bind_addr(std::net::SocketAddr::from(([127, 0, 0, 1], 0)))? + .bind() + .await?; + let v0_addr = v0_endpoint.addr(); + let v0_accept_task = tokio::spawn(async move { + let incoming = + tokio::time::timeout(std::time::Duration::from_secs(5), v0_endpoint.accept()) + .await + .expect("v0 endpoint should receive an incoming connection") + .expect("v0 accept should yield an incoming connection"); + let mut accepting = incoming.accept().expect("v0 accept should succeed"); + let _alpn = accepting.alpn().await.expect("ALPN should be available"); + let conn = accepting + .await + .expect("v0 connection handshake should complete"); + assert_eq!( + conn.alpn(), + ALPN_V0, + "v0 endpoint must see ALPN_V0 on the accepted connection" + ); + assert_eq!( + connection_protocol(&conn), + ControlProtocol::JsonV0, + "v0 endpoint must identify the connection as JsonV0" + ); + }); + + let post_node = make_test_node(super::NodeRole::Worker).await?; + let conn_a = tokio::time::timeout( + std::time::Duration::from_secs(5), + connect_mesh(&post_node.endpoint, v0_addr), + ) + .await + .expect("v1β†’v0 connect should not timeout") + .expect("v1 node should connect successfully to v0-only endpoint"); + assert_eq!( + conn_a.alpn(), + ALPN_V0, + "v1 node connecting to a v0-only endpoint must negotiate ALPN_V0" + ); + assert_eq!( + connection_protocol(&conn_a), + ControlProtocol::JsonV0, + "connection from v1 to v0-only endpoint must use JsonV0 protocol" + ); + + v0_accept_task + .await + .expect("v0 accept task should complete without panic"); + + let node_b = make_test_node(super::NodeRole::Worker).await?; + node_b.start_accepting(); + let node_b_addr = node_b.endpoint.addr(); + + let v0_ep2 = Endpoint::empty_builder() + .secret_key(SecretKey::generate(&mut rand::rng())) + .alpns(vec![ALPN_V0.to_vec()]) + .transport_config( + QuicTransportConfig::builder() + .max_concurrent_bidi_streams(128u32.into()) + .build(), + ) + .bind_addr(std::net::SocketAddr::from(([127, 0, 0, 1], 0)))? + .bind() + .await?; + let connecting = tokio::time::timeout( + std::time::Duration::from_secs(5), + v0_ep2.connect_with_opts(node_b_addr, ALPN_V0, ConnectOptions::new()), + ) + .await + .expect("v0β†’v1 connect_with_opts should not timeout") + .expect("v0 endpoint should connect to v1 node"); + let conn_b = tokio::time::timeout(std::time::Duration::from_secs(5), connecting) + .await + .expect("v0β†’v1 handshake should not timeout") + .expect("v0β†’v1 connection handshake should succeed"); + assert_eq!( + conn_b.alpn(), + ALPN_V0, + "v0 endpoint connecting to a v1 node must negotiate ALPN_V0" + ); + assert_eq!( + connection_protocol(&conn_b), + ControlProtocol::JsonV0, + "v0 endpoint connecting to a v1 node must use JsonV0 protocol" + ); + + Ok(()) + } + + fn make_test_peer(id: EndpointId, rtt_ms: Option<u32>, vram_gb: u64) -> PeerInfo { + PeerInfo { + id, + addr: EndpointAddr { + id, + addrs: Default::default(), + }, + role: super::NodeRole::Worker, + models: vec![], + vram_bytes: vram_gb * 1024 * 1024 * 1024, + rtt_ms, + model_source: None, + serving_models: vec![], + hosted_models: vec![], + hosted_models_known: false, + available_models: vec![], + requested_models: vec![], + last_seen: std::time::Instant::now(), + last_mentioned: std::time::Instant::now(), + moe_recovered_at: None, + version: None, + gpu_name: None, + hostname: None, + is_soc: None, + gpu_vram: None, + gpu_reserved_bytes: None, + gpu_mem_bandwidth_gbps: None, + gpu_compute_tflops_fp32: None, + gpu_compute_tflops_fp16: None, + available_model_metadata: vec![], + experts_summary: None, + tunnel_port: None, + available_model_sizes: HashMap::new(), + served_model_descriptors: vec![], + served_model_runtime: vec![], + owner_attestation: None, + owner_summary: OwnershipSummary::default(), + } + } + + /// RTT re-election: when a peer's RTT drops from above the 80ms split + /// threshold to below it (e.g. relay β†’ direct), update_peer_rtt must + /// trigger a peer_change event so the election loop re-runs and can + /// now include the peer in split mode. + #[tokio::test] + async fn test_rtt_drop_triggers_reelection() -> Result<()> { + let node = make_test_node(super::NodeRole::Worker).await?; + let peer_key = SecretKey::generate(&mut rand::rng()); + let peer_id = EndpointId::from(peer_key.public()); + + // Add a fake peer with high relay RTT + { + let mut state = node.state.lock().await; + state + .peers + .insert(peer_id, make_test_peer(peer_id, Some(2600), 16)); + } + + let rx = node.peer_change_rx.clone(); + + // Update RTT to still-high value β€” should NOT trigger + node.update_peer_rtt(peer_id, 500).await; + assert!( + !rx.has_changed() + .expect("peer_change_rx closed unexpectedly"), + "RTT 2600β†’500 (both above threshold) should not trigger re-election" + ); + + // Update RTT to below threshold β€” SHOULD trigger + node.update_peer_rtt(peer_id, 15).await; + assert!( + rx.has_changed() + .expect("peer_change_rx closed unexpectedly"), + "RTT 500β†’15 (crossing threshold) must trigger re-election" + ); + + Ok(()) + } + + /// RTT re-election should NOT trigger when RTT was already below threshold. + #[tokio::test] + async fn test_rtt_below_threshold_no_reelection() -> Result<()> { + let node = make_test_node(super::NodeRole::Worker).await?; + let peer_key = SecretKey::generate(&mut rand::rng()); + let peer_id = EndpointId::from(peer_key.public()); + + { + let mut state = node.state.lock().await; + state + .peers + .insert(peer_id, make_test_peer(peer_id, Some(20), 16)); + } + + let rx = node.peer_change_rx.clone(); + + // Update RTT to another low value β€” should NOT trigger + node.update_peer_rtt(peer_id, 15).await; + assert!( + !rx.has_changed() + .expect("peer_change_rx closed unexpectedly"), + "RTT 20β†’15 (both below threshold) should not trigger re-election" + ); + + Ok(()) + } + + /// RTT re-election should NOT trigger for unknown peers. + #[tokio::test] + async fn test_rtt_update_unknown_peer_no_panic() -> Result<()> { + let node = make_test_node(super::NodeRole::Worker).await?; + let peer_key = SecretKey::generate(&mut rand::rng()); + let peer_id = EndpointId::from(peer_key.public()); + + let rx = node.peer_change_rx.clone(); + + // Update RTT for a peer that doesn't exist β€” should not panic or trigger + node.update_peer_rtt(peer_id, 15).await; + assert!( + !rx.has_changed() + .expect("peer_change_rx closed unexpectedly"), + "RTT update for unknown peer should not trigger re-election" + ); + + Ok(()) + } + + /// RTT should never increase β€” relay gossip RTT must not overwrite + /// a known-good direct path measurement. + #[tokio::test] + async fn test_rtt_cannot_regress() -> Result<()> { + let node = make_test_node(super::NodeRole::Worker).await?; + let peer_key = SecretKey::generate(&mut rand::rng()); + let peer_id = EndpointId::from(peer_key.public()); + + { + let mut state = node.state.lock().await; + state + .peers + .insert(peer_id, make_test_peer(peer_id, Some(20), 16)); + } + + // Try to raise RTT β€” should be rejected + node.update_peer_rtt(peer_id, 2600).await; + { + let state = node.state.lock().await; + let rtt = state.peers.get(&peer_id).unwrap().rtt_ms; + assert_eq!(rtt, Some(20), "RTT must not increase from 20 to 2600"); + } + + // Lower RTT β€” should be accepted + node.update_peer_rtt(peer_id, 10).await; + { + let state = node.state.lock().await; + let rtt = state.peers.get(&peer_id).unwrap().rtt_ms; + assert_eq!(rtt, Some(10), "RTT must decrease from 20 to 10"); + } + + Ok(()) + } + + /// Regression test: connect_to_peer must skip peers already in state.peers, + /// even if there's no QUIC connection yet (transitive peers from gossip). + /// If this check uses state.connections instead, every transitive peer + /// triggers a 15s dial timeout and --client --auto hangs. + /// See: d631c8d (broke it), 6ece4d1 (first revert). + #[tokio::test] + async fn test_connect_to_peer_skips_known_peer_without_connection() -> Result<()> { + let node = make_test_node(super::NodeRole::Client).await?; + let peer_key = SecretKey::generate(&mut rand::rng()); + let peer_id = EndpointId::from(peer_key.public()); + + // Simulate a transitive peer: in state.peers but NOT in state.connections + { + let mut state = node.state.lock().await; + state + .peers + .insert(peer_id, make_test_peer(peer_id, Some(50), 8)); + assert!( + !state.connections.contains_key(&peer_id), + "setup: peer must not have a connection" + ); + } + + // connect_to_peer must return Ok immediately (peer already known). + // If it tries to dial, it will either timeout (15s) or fail β€” both wrong. + let result = tokio::time::timeout( + std::time::Duration::from_secs(1), + node.connect_to_peer(super::EndpointAddr { + id: peer_id, + addrs: Default::default(), + }), + ) + .await; + + assert!( + result.is_ok(), + "connect_to_peer must not attempt to dial a peer already in state.peers" + ); + assert!( + result.unwrap().is_ok(), + "connect_to_peer must return Ok for known peers" + ); + + Ok(()) + } + + #[test] + fn config_sync_subscribe_snapshot_encode_decode() { + use crate::proto::node::{ConfigSnapshotResponse, NodeConfigSnapshot, NodeGpuConfig}; + + let snapshot = ConfigSnapshotResponse { + gen: NODE_PROTOCOL_GENERATION, + node_id: vec![0xAA; 32], + revision: 7, + config_hash: vec![0xBB; 32], + config: Some(NodeConfigSnapshot { + version: 1, + gpu: Some(NodeGpuConfig { + assignment: crate::proto::node::GpuAssignment::Auto as i32, + }), + models: vec![], + plugins: vec![], + }), + hostname: Some("test-host".to_string()), + error: None, + }; + + let encoded = encode_control_frame(STREAM_CONFIG_SUBSCRIBE, &snapshot); + let decoded: ConfigSnapshotResponse = + decode_control_frame(STREAM_CONFIG_SUBSCRIBE, &encoded) + .expect("round-trip must succeed"); + + assert_eq!(decoded.gen, NODE_PROTOCOL_GENERATION); + assert_eq!(decoded.node_id, vec![0xAA; 32]); + assert_eq!(decoded.revision, 7); + assert_eq!(decoded.config_hash, vec![0xBB; 32]); + assert_eq!(decoded.hostname, Some("test-host".to_string())); + let cfg = decoded.config.expect("config must be present"); + assert_eq!(cfg.version, 1); + let gpu = cfg.gpu.expect("gpu must be present"); + assert_eq!( + gpu.assignment, + crate::proto::node::GpuAssignment::Auto as i32 + ); + } + + #[test] + fn config_sync_subscribe_not_before_admission() { + assert!( + !stream_allowed_before_admission(STREAM_CONFIG_SUBSCRIBE), + "STREAM_CONFIG_SUBSCRIBE (0x0b) must require admission β€” it is an owner-gated config stream" + ); + } + + fn test_signing_key() -> (ed25519_dalek::SigningKey, String) { + let signing_key = ed25519_dalek::SigningKey::from_bytes(&[0x42u8; 32]); + let verifying = signing_key.verifying_key(); + let owner_id = crate::crypto::owner_id_from_verifying_key(&verifying); + (signing_key, owner_id) + } + + #[test] + fn config_sync_push_signature_payload_deterministic() { + use crate::proto::node::{ConfigPush, NodeConfigSnapshot}; + + let push = ConfigPush { + gen: NODE_PROTOCOL_GENERATION, + requester_id: vec![0xAA; 32], + target_node_id: vec![0xBB; 32], + owner_signing_public_key: vec![0x42u8; 32], + expected_revision: 3, + config: Some(NodeConfigSnapshot { + version: 1, + gpu: None, + models: vec![], + plugins: vec![], + }), + signature: vec![0u8; 64], + }; + + let p1 = config_push_signature_payload(&push); + let p2 = config_push_signature_payload(&push); + assert_eq!(p1, p2, "payload must be deterministic for the same input"); + assert!(!p1.is_empty(), "payload must not be empty"); + } + + // config_sync_push_wrong_owner_detected was removed: the `owner_id` field no longer + // exists in ConfigPush. Wrong-owner detection is now handled entirely through the + // gossip-attested peer identity check in handle_config_push. + + #[test] + fn config_sync_push_bad_signature_bytes_length() { + let bad_sig: Vec<u8> = vec![0u8; 32]; + let result: Result<[u8; 64], _> = bad_sig.as_slice().try_into(); + assert!( + result.is_err(), + "32-byte slice must not convert to [u8; 64] β€” wrong-length signature must be rejected" + ); + + let good_sig: Vec<u8> = vec![0u8; 64]; + let result: Result<[u8; 64], _> = good_sig.as_slice().try_into(); + assert!(result.is_ok(), "64-byte slice must convert to [u8; 64]"); + } + + #[test] + fn config_sync_push_roundtrip_encode_decode() { + use crate::proto::node::{ConfigApplyMode, ConfigPushResponse}; + use prost::Message as _; + + let response = ConfigPushResponse { + gen: NODE_PROTOCOL_GENERATION, + success: true, + current_revision: 42, + config_hash: vec![0xCC; 32], + error: None, + apply_mode: ConfigApplyMode::Staged as i32, + }; + + let encoded = response.encode_to_vec(); + let decoded = ConfigPushResponse::decode(encoded.as_slice()) + .expect("ConfigPushResponse must round-trip through encode/decode"); + + assert_eq!(decoded.gen, NODE_PROTOCOL_GENERATION); + assert!(decoded.success); + assert_eq!(decoded.current_revision, 42); + assert_eq!(decoded.config_hash, vec![0xCC; 32]); + assert!(decoded.error.is_none()); + assert_eq!(decoded.apply_mode, ConfigApplyMode::Staged as i32); + } + + #[test] + fn config_sync_sign_and_verify_roundtrip() { + use crate::proto::node::{ConfigPush, NodeConfigSnapshot}; + use ed25519_dalek::Signer as _; + + let (signing_key, owner_id) = test_signing_key(); + let vk = signing_key.verifying_key(); + + let mut push = ConfigPush { + gen: NODE_PROTOCOL_GENERATION, + requester_id: vec![0xAA; 32], + target_node_id: vec![0xBB; 32], + owner_signing_public_key: vk.to_bytes().to_vec(), + expected_revision: 0, + config: Some(NodeConfigSnapshot { + version: 1, + gpu: None, + models: vec![], + plugins: vec![], + }), + signature: vec![0u8; 64], + }; + + let payload = config_push_signature_payload(&push); + let sig = signing_key.sign(&payload); + push.signature = sig.to_bytes().to_vec(); + + // Verify: re-derive owner_id from vk and check signature + let pk_bytes: [u8; 32] = push.owner_signing_public_key.as_slice().try_into().unwrap(); + let restored_vk = ed25519_dalek::VerifyingKey::from_bytes(&pk_bytes).unwrap(); + let derived_id = crate::crypto::owner_id_from_verifying_key(&restored_vk); + assert_eq!(derived_id, owner_id, "owner_id must match key fingerprint"); + + let payload2 = config_push_signature_payload(&push); + let sig_bytes: [u8; 64] = push.signature.as_slice().try_into().unwrap(); + let sig_obj = ed25519_dalek::Signature::from_bytes(&sig_bytes); + restored_vk + .verify_strict(&payload2, &sig_obj) + .expect("signature must verify against the canonical payload"); + } + + #[test] + fn config_sync_signature_payload_excludes_signature_field() { + use crate::proto::node::{ConfigPush, NodeConfigSnapshot}; + + let mut push = ConfigPush { + gen: NODE_PROTOCOL_GENERATION, + requester_id: vec![0xAA; 32], + target_node_id: vec![0xBB; 32], + owner_signing_public_key: vec![0x42u8; 32], + expected_revision: 0, + config: Some(NodeConfigSnapshot { + version: 1, + gpu: None, + models: vec![], + plugins: vec![], + }), + signature: vec![0u8; 64], + }; + + let payload_with_sig = config_push_signature_payload(&push); + + // Change only the signature field β€” the canonical payload must not change + push.signature = vec![0xFF; 64]; + let payload_different_sig = config_push_signature_payload(&push); + + assert_eq!( + payload_with_sig, payload_different_sig, + "payload must be identical regardless of the signature field value" + ); + + // Change a semantic field β€” the canonical payload MUST change + push.expected_revision = 99; + let payload_changed = config_push_signature_payload(&push); + assert_ne!( + payload_with_sig, payload_changed, + "payload must change when a semantic field changes" + ); + } + + fn test_owner_keypair(signing_seed: u8, encryption_seed: u8) -> crate::crypto::OwnerKeypair { + crate::crypto::OwnerKeypair::from_bytes(&[signing_seed; 32], &[encryption_seed; 32]) + .expect("test owner keypair must be valid") + } + + /// Create a test `Node` with a verified local owner attestation and a + /// `ConfigState` whose backing file lives in `config_dir`. + async fn make_test_node_with_owner( + role: super::NodeRole, + owner_keypair: &crate::crypto::OwnerKeypair, + config_dir: &std::path::Path, + ) -> Result<Node> { + use iroh::endpoint::QuicTransportConfig; + + let config_path = config_dir.join("config.toml"); + let config_state = + crate::runtime::config_state::ConfigState::load(&config_path).unwrap_or_default(); + + let transport_config = QuicTransportConfig::builder() + .max_concurrent_bidi_streams(128u32.into()) + .build(); + let endpoint = Endpoint::empty_builder() + .secret_key(SecretKey::generate(&mut rand::rng())) + .alpns(vec![ALPN_V1.to_vec(), ALPN_V0.to_vec()]) + .transport_config(transport_config) + .bind_addr(std::net::SocketAddr::from(([127, 0, 0, 1], 0)))? + .bind() + .await?; + + let (peer_change_tx, peer_change_rx) = watch::channel(0usize); + let (inflight_change_tx, _) = watch::channel(0u64); + let (tunnel_tx, _tunnel_rx) = tokio::sync::mpsc::channel(8); + let (tunnel_http_tx, _tunnel_http_rx) = tokio::sync::mpsc::channel(8); + let revision = config_state.revision(); + let owner_attestation = sign_node_ownership( + owner_keypair, + endpoint.id().as_bytes(), + current_time_unix_ms() + DEFAULT_NODE_CERT_LIFETIME_SECS * 1000, + None, + None, + )?; + let trust_store = TrustStore::default(); + let owner_summary = verify_node_ownership( + Some(&owner_attestation), + endpoint.id().as_bytes(), + &trust_store, + TrustPolicy::Off, + current_time_unix_ms(), + ); + + let node = Node { + endpoint, + public_addr: None, + state: Arc::new(Mutex::new(MeshState { + peers: HashMap::new(), + connections: HashMap::new(), + remote_tunnel_maps: HashMap::new(), + dead_peers: HashSet::new(), + seen_plugin_messages: HashSet::new(), + seen_plugin_message_order: VecDeque::new(), + policy_rejected_peers: HashMap::new(), + })), + role: Arc::new(Mutex::new(role)), + models: Arc::new(Mutex::new(Vec::new())), + model_source: Arc::new(Mutex::new(None)), + serving_models: Arc::new(Mutex::new(Vec::new())), + served_model_descriptors: Arc::new(Mutex::new(Vec::new())), + model_runtime_descriptors: Arc::new(Mutex::new(Vec::new())), + hosted_models: Arc::new(Mutex::new(Vec::new())), + llama_ready: Arc::new(Mutex::new(false)), + available_models: Arc::new(Mutex::new(Vec::new())), + requested_models: Arc::new(Mutex::new(Vec::new())), + model_demand: Arc::new(std::sync::Mutex::new(HashMap::new())), + mesh_id: Arc::new(Mutex::new(None)), + accepting: Arc::new(( + tokio::sync::Notify::new(), + std::sync::atomic::AtomicBool::new(false), + )), + vram_bytes: 64 * 1024 * 1024 * 1024, + peer_change_tx, + peer_change_rx, + inflight_requests: Arc::new(std::sync::atomic::AtomicUsize::new(0)), + inflight_change_tx, + tunnel_tx, + tunnel_http_tx, + plugin_manager: Arc::new(Mutex::new(None)), + display_name: Arc::new(Mutex::new(None)), + owner_attestation: Arc::new(Mutex::new(Some(owner_attestation))), + owner_summary: Arc::new(Mutex::new(owner_summary)), + trust_store: Arc::new(Mutex::new(trust_store)), + trust_policy: TrustPolicy::Off, + enumerate_host: false, + gpu_name: None, + hostname: None, + is_soc: None, + gpu_vram: None, + gpu_reserved_bytes: None, + gpu_mem_bandwidth_gbps: Arc::new(tokio::sync::Mutex::new(None)), + gpu_compute_tflops_fp32: Arc::new(tokio::sync::Mutex::new(None)), + gpu_compute_tflops_fp16: Arc::new(tokio::sync::Mutex::new(None)), + config_state: Arc::new(tokio::sync::Mutex::new(config_state)), + config_revision_tx: { + let (tx, _rx) = tokio::sync::watch::channel(revision); + Arc::new(tx) + }, + }; + + let accept_node = node.clone(); + tokio::spawn(async move { + accept_node.accept_loop().await; + }); + + Ok(node) + } + + /// Helper: build and sign a ConfigPush proto for the given node/owner/config. + /// Build a `ConfigPush` proto that is correctly signed with `signing_key`. + /// + /// The resulting push targets `target_node_id`, is attributed to `requester_id`, + /// and carries `expected_revision` for CAS enforcement. The signature covers the + /// canonical protobuf encoding of the push with the `signature` field cleared. + fn build_signed_config_push( + owner_keypair: &crate::crypto::OwnerKeypair, + requester_id: &EndpointId, + target_node_id: &EndpointId, + expected_revision: u64, + config: crate::proto::node::NodeConfigSnapshot, + ) -> crate::proto::node::ConfigPush { + use ed25519_dalek::Signer as _; + + let vk = owner_keypair.signing.verifying_key(); + + let mut push = crate::proto::node::ConfigPush { + gen: NODE_PROTOCOL_GENERATION, + requester_id: requester_id.as_bytes().to_vec(), + target_node_id: target_node_id.as_bytes().to_vec(), + owner_signing_public_key: vk.to_bytes().to_vec(), + expected_revision, + config: Some(config), + signature: vec![0u8; 64], + }; + let payload = config_push_signature_payload(&push); + let sig = owner_keypair.signing.sign(&payload); + push.signature = sig.to_bytes().to_vec(); + push + } + + /// Wait until `node` has `target` in its peers list. Times out after 5 s. + /// Poll `node.peers()` until `target` appears in the list. + /// + /// Panics (via `expect`) if `target` is not admitted within 5 seconds. + async fn wait_for_peer(node: &Node, target: EndpointId) { + tokio::time::timeout(std::time::Duration::from_secs(5), async { + loop { + if node.peers().await.iter().any(|p| p.id == target) { + break; + } + tokio::time::sleep(std::time::Duration::from_millis(25)).await; + } + }) + .await + .expect("peer was not admitted within 5 s"); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn config_subscribe_matching_owner_receives_snapshot() -> Result<()> { + let owner_keypair = test_owner_keypair(0x11, 0x12); + + let tmp = std::env::temp_dir().join(format!("mesh-llm-cfg-sub-{}", rand::random::<u64>())); + std::fs::create_dir_all(tmp.join("server")).ok(); + std::fs::create_dir_all(tmp.join("client")).ok(); + + let server = make_test_node_with_owner( + super::NodeRole::Host { http_port: 9337 }, + &owner_keypair, + &tmp.join("server"), + ) + .await?; + let client = + make_test_node_with_owner(super::NodeRole::Worker, &owner_keypair, &tmp.join("client")) + .await?; + + server + .set_mesh_id("cfg-subscribe-mesh-01".to_string()) + .await; + client + .set_mesh_id("cfg-subscribe-mesh-01".to_string()) + .await; + server.start_accepting(); + client.start_accepting(); + + let server_id = server.id(); + let server_addr = server.endpoint.addr(); + let invite = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(serde_json::to_vec(&server_addr)?); + + client.join(&invite).await?; + wait_for_peer(&client, server_id).await; + wait_for_peer(&server, client.id()).await; + + let conn = { + let state = client.state.lock().await; + state + .connections + .get(&server_id) + .cloned() + .expect("connection to server must exist after join") + }; + + let (snapshot, _notif_rx) = client.subscribe_to_config(&conn).await?; + + assert_eq!( + snapshot.node_id, + server_id.as_bytes().to_vec(), + "snapshot node_id must be the server's endpoint id" + ); + assert_eq!( + snapshot.config_hash.len(), + 32, + "config_hash must be 32 bytes" + ); + assert!( + snapshot.config.is_some(), + "snapshot must include config payload" + ); + assert!( + snapshot.error.is_none() || snapshot.error.as_deref() == Some(""), + "snapshot must not carry an error" + ); + + std::fs::remove_dir_all(&tmp).ok(); + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn config_subscribe_wrong_owner_returns_error() -> Result<()> { + let server_owner = test_owner_keypair(0x22, 0x23); + let client_owner = test_owner_keypair(0x33, 0x34); + + let tmp = + std::env::temp_dir().join(format!("mesh-llm-cfg-wrong-{}", rand::random::<u64>())); + std::fs::create_dir_all(tmp.join("server")).ok(); + std::fs::create_dir_all(tmp.join("client")).ok(); + + let server = make_test_node_with_owner( + super::NodeRole::Host { http_port: 9337 }, + &server_owner, + &tmp.join("server"), + ) + .await?; + let client = + make_test_node_with_owner(super::NodeRole::Worker, &client_owner, &tmp.join("client")) + .await?; + + server + .set_mesh_id("cfg-wrong-owner-mesh-01".to_string()) + .await; + client + .set_mesh_id("cfg-wrong-owner-mesh-01".to_string()) + .await; + server.start_accepting(); + client.start_accepting(); + + let server_id = server.id(); + let server_addr = server.endpoint.addr(); + let invite = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(serde_json::to_vec(&server_addr)?); + + client.join(&invite).await?; + wait_for_peer(&client, server_id).await; + wait_for_peer(&server, client.id()).await; + + let conn = { + let state = client.state.lock().await; + state + .connections + .get(&server_id) + .cloned() + .expect("connection to server must exist after join") + }; + + // Subscribe - the subscriber's attested owner doesn't match the server's owner + let result = client.subscribe_to_config(&conn).await; + assert!( + result.is_err(), + "subscribing with wrong owner_id must return an error" + ); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("owner_id mismatch") || err_msg.contains("rejected"), + "error must mention owner mismatch, got: {err_msg}" + ); + + std::fs::remove_dir_all(&tmp).ok(); + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn config_subscribe_unowned_node_returns_error() -> Result<()> { + let client_owner = test_owner_keypair(0x44, 0x45); + + let tmp = + std::env::temp_dir().join(format!("mesh-llm-cfg-unowned-{}", rand::random::<u64>())); + std::fs::create_dir_all(tmp.join("client")).ok(); + + // server has NO owner key (make_test_node, not make_test_node_with_owner) + let server = make_test_node(super::NodeRole::Host { http_port: 9337 }).await?; + let client = + make_test_node_with_owner(super::NodeRole::Worker, &client_owner, &tmp.join("client")) + .await?; + + server.set_mesh_id("cfg-unowned-mesh-01".to_string()).await; + client.set_mesh_id("cfg-unowned-mesh-01".to_string()).await; + server.start_accepting(); + client.start_accepting(); + + let server_id = server.id(); + let server_addr = server.endpoint.addr(); + let invite = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(serde_json::to_vec(&server_addr)?); + + client.join(&invite).await?; + wait_for_peer(&client, server_id).await; + wait_for_peer(&server, client.id()).await; + + let conn = { + let state = client.state.lock().await; + state + .connections + .get(&server_id) + .cloned() + .expect("connection to server must exist after join") + }; + + let result = client.subscribe_to_config(&conn).await; + assert!( + result.is_err(), + "subscribing to an unowned node must return an error" + ); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("no local owner") || err_msg.contains("rejected"), + "error must mention missing owner, got: {err_msg}" + ); + + std::fs::remove_dir_all(&tmp).ok(); + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn config_push_valid_signature_accepted() -> Result<()> { + use crate::proto::node::{NodeConfigSnapshot, NodeGpuConfig}; + use crate::protocol::write_len_prefixed; + use prost::Message as _; + + let owner_keypair = test_owner_keypair(0x55, 0x56); + + let tmp = + std::env::temp_dir().join(format!("mesh-llm-cfg-push-ok-{}", rand::random::<u64>())); + std::fs::create_dir_all(tmp.join("server")).ok(); + std::fs::create_dir_all(tmp.join("client")).ok(); + + let server = make_test_node_with_owner( + super::NodeRole::Host { http_port: 9337 }, + &owner_keypair, + &tmp.join("server"), + ) + .await?; + let client = + make_test_node_with_owner(super::NodeRole::Worker, &owner_keypair, &tmp.join("client")) + .await?; + + server.set_mesh_id("cfg-push-ok-mesh-01".to_string()).await; + client.set_mesh_id("cfg-push-ok-mesh-01".to_string()).await; + server.start_accepting(); + client.start_accepting(); + + let server_id = server.id(); + let server_addr = server.endpoint.addr(); + let invite = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(serde_json::to_vec(&server_addr)?); + + client.join(&invite).await?; + wait_for_peer(&client, server_id).await; + wait_for_peer(&server, client.id()).await; + + let client_id = client.id(); + let conn = { + let state = client.state.lock().await; + state + .connections + .get(&server_id) + .cloned() + .expect("connection to server must exist after join") + }; + + let new_config = NodeConfigSnapshot { + version: 1, + gpu: Some(NodeGpuConfig { + assignment: crate::proto::node::GpuAssignment::Auto as i32, + }), + models: vec![], + plugins: vec![], + }; + + let push = build_signed_config_push(&owner_keypair, &client_id, &server_id, 0, new_config); + + let (mut send, mut recv) = conn.open_bi().await?; + send.write_all(&[STREAM_CONFIG_PUSH]).await?; + write_len_prefixed(&mut send, &push.encode_to_vec()).await?; + send.finish()?; + + let buf = crate::protocol::read_len_prefixed(&mut recv).await?; + let response = crate::proto::node::ConfigPushResponse::decode(buf.as_slice())?; + + assert!( + response.success, + "valid signed push must be accepted: {:?}", + response.error + ); + assert_eq!( + response.current_revision, 1, + "revision must be bumped to 1 after first push" + ); + assert_eq!( + response.config_hash.len(), + 32, + "response config_hash must be 32 bytes" + ); + assert_eq!( + response.apply_mode, + crate::proto::node::ConfigApplyMode::Staged as i32, + "config push should report staged apply mode" + ); + + std::fs::remove_dir_all(&tmp).ok(); + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn config_push_revision_conflict_rejected() -> Result<()> { + use crate::proto::node::{NodeConfigSnapshot, NodeGpuConfig}; + use crate::protocol::write_len_prefixed; + use prost::Message as _; + + let owner_keypair = test_owner_keypair(0x66, 0x67); + + let tmp = + std::env::temp_dir().join(format!("mesh-llm-cfg-conflict-{}", rand::random::<u64>())); + std::fs::create_dir_all(tmp.join("server")).ok(); + std::fs::create_dir_all(tmp.join("client")).ok(); + + let server = make_test_node_with_owner( + super::NodeRole::Host { http_port: 9337 }, + &owner_keypair, + &tmp.join("server"), + ) + .await?; + let client = + make_test_node_with_owner(super::NodeRole::Worker, &owner_keypair, &tmp.join("client")) + .await?; + + server.set_mesh_id("cfg-conflict-mesh-01".to_string()).await; + client.set_mesh_id("cfg-conflict-mesh-01".to_string()).await; + server.start_accepting(); + client.start_accepting(); + + let server_id = server.id(); + let server_addr = server.endpoint.addr(); + let invite = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(serde_json::to_vec(&server_addr)?); + + client.join(&invite).await?; + wait_for_peer(&client, server_id).await; + wait_for_peer(&server, client.id()).await; + + let client_id = client.id(); + let conn = { + let state = client.state.lock().await; + state + .connections + .get(&server_id) + .cloned() + .expect("connection to server must exist after join") + }; + + let good_config = NodeConfigSnapshot { + version: 1, + gpu: Some(NodeGpuConfig { + assignment: crate::proto::node::GpuAssignment::Auto as i32, + }), + models: vec![], + plugins: vec![], + }; + + // First push (revision 0 β†’ 1) β€” must succeed + let push1 = build_signed_config_push( + &owner_keypair, + &client_id, + &server_id, + 0, + good_config.clone(), + ); + let (mut send1, mut recv1) = conn.open_bi().await?; + send1.write_all(&[STREAM_CONFIG_PUSH]).await?; + write_len_prefixed(&mut send1, &push1.encode_to_vec()).await?; + send1.finish()?; + let buf1 = crate::protocol::read_len_prefixed(&mut recv1).await?; + let resp1 = crate::proto::node::ConfigPushResponse::decode(buf1.as_slice())?; + assert!(resp1.success, "first push must succeed: {:?}", resp1.error); + + // Second push with stale expected_revision=0 β€” must be rejected + let push2 = + build_signed_config_push(&owner_keypair, &client_id, &server_id, 0, good_config); + let (mut send2, mut recv2) = conn.open_bi().await?; + send2.write_all(&[STREAM_CONFIG_PUSH]).await?; + write_len_prefixed(&mut send2, &push2.encode_to_vec()).await?; + send2.finish()?; + let buf2 = crate::protocol::read_len_prefixed(&mut recv2).await?; + let resp2 = crate::proto::node::ConfigPushResponse::decode(buf2.as_slice())?; + + assert!(!resp2.success, "push with stale revision must be rejected"); + assert_eq!( + resp2.current_revision, 1, + "rejection response must carry the current revision" + ); + let err = resp2.error.as_deref().unwrap_or(""); + assert!( + err.contains("revision conflict"), + "error must mention revision conflict, got: {err}" + ); + + std::fs::remove_dir_all(&tmp).ok(); + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn config_push_bad_signature_rejected() -> Result<()> { + use crate::proto::node::{NodeConfigSnapshot, NodeGpuConfig}; + use crate::protocol::write_len_prefixed; + use prost::Message as _; + + let owner_keypair = test_owner_keypair(0x77, 0x78); + + let tmp = + std::env::temp_dir().join(format!("mesh-llm-cfg-badsig-{}", rand::random::<u64>())); + std::fs::create_dir_all(tmp.join("server")).ok(); + std::fs::create_dir_all(tmp.join("client")).ok(); + + let server = make_test_node_with_owner( + super::NodeRole::Host { http_port: 9337 }, + &owner_keypair, + &tmp.join("server"), + ) + .await?; + let client = + make_test_node_with_owner(super::NodeRole::Worker, &owner_keypair, &tmp.join("client")) + .await?; + + server.set_mesh_id("cfg-badsig-mesh-01".to_string()).await; + client.set_mesh_id("cfg-badsig-mesh-01".to_string()).await; + server.start_accepting(); + client.start_accepting(); + + let server_id = server.id(); + let server_addr = server.endpoint.addr(); + let invite = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(serde_json::to_vec(&server_addr)?); + + client.join(&invite).await?; + wait_for_peer(&client, server_id).await; + wait_for_peer(&server, client.id()).await; + + let client_id = client.id(); + let conn = { + let state = client.state.lock().await; + state + .connections + .get(&server_id) + .cloned() + .expect("connection to server must exist after join") + }; + + let config = NodeConfigSnapshot { + version: 1, + gpu: Some(NodeGpuConfig { + assignment: crate::proto::node::GpuAssignment::Auto as i32, + }), + models: vec![], + plugins: vec![], + }; + + // Build a push but corrupt the signature + let mut push = build_signed_config_push(&owner_keypair, &client_id, &server_id, 0, config); + push.signature = vec![0xDE; 64]; // garbage signature + + let (mut send, mut recv) = conn.open_bi().await?; + send.write_all(&[STREAM_CONFIG_PUSH]).await?; + write_len_prefixed(&mut send, &push.encode_to_vec()).await?; + send.finish()?; + + let buf = crate::protocol::read_len_prefixed(&mut recv).await?; + let response = crate::proto::node::ConfigPushResponse::decode(buf.as_slice())?; + + assert!( + !response.success, + "push with invalid signature must be rejected" + ); + let err = response.error.as_deref().unwrap_or(""); + assert!( + err.contains("signature"), + "error must mention signature verification, got: {err}" + ); + + std::fs::remove_dir_all(&tmp).ok(); + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn config_subscribe_delivers_update_notification_after_push() -> Result<()> { + use crate::proto::node::{NodeConfigSnapshot, NodeGpuConfig}; + use crate::protocol::write_len_prefixed; + use prost::Message as _; + + let owner_keypair = test_owner_keypair(0x88, 0x89); + + let tmp = + std::env::temp_dir().join(format!("mesh-llm-cfg-notif-{}", rand::random::<u64>())); + std::fs::create_dir_all(tmp.join("server")).ok(); + std::fs::create_dir_all(tmp.join("client")).ok(); + + let server = make_test_node_with_owner( + super::NodeRole::Host { http_port: 9337 }, + &owner_keypair, + &tmp.join("server"), + ) + .await?; + let client = + make_test_node_with_owner(super::NodeRole::Worker, &owner_keypair, &tmp.join("client")) + .await?; + + server.set_mesh_id("cfg-notif-mesh-01".to_string()).await; + client.set_mesh_id("cfg-notif-mesh-01".to_string()).await; + server.start_accepting(); + client.start_accepting(); + + let server_id = server.id(); + let server_addr = server.endpoint.addr(); + let invite = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(serde_json::to_vec(&server_addr)?); + + client.join(&invite).await?; + wait_for_peer(&client, server_id).await; + wait_for_peer(&server, client.id()).await; + + let client_id = client.id(); + let conn = { + let state = client.state.lock().await; + state + .connections + .get(&server_id) + .cloned() + .expect("connection to server must exist after join") + }; + + // Subscribe to config on the server from the client + let (initial_snapshot, mut notif_rx) = client.subscribe_to_config(&conn).await?; + let initial_revision = initial_snapshot.revision; + + // Now push a config change to the server from the client + let new_config = NodeConfigSnapshot { + version: 1, + gpu: Some(NodeGpuConfig { + assignment: crate::proto::node::GpuAssignment::Auto as i32, + }), + models: vec![crate::proto::node::NodeModelEntry { + model: "test-model.gguf".to_string(), + mmproj: None, + ctx_size: None, + gpu_id: None, + model_ref: None, + mmproj_ref: None, + }], + plugins: vec![], + }; + let push = build_signed_config_push( + &owner_keypair, + &client_id, + &server_id, + initial_revision, + new_config, + ); + let (mut send, mut recv) = conn.open_bi().await?; + send.write_all(&[STREAM_CONFIG_PUSH]).await?; + write_len_prefixed(&mut send, &push.encode_to_vec()).await?; + send.finish()?; + let buf = crate::protocol::read_len_prefixed(&mut recv).await?; + let push_resp = crate::proto::node::ConfigPushResponse::decode(buf.as_slice())?; + assert!( + push_resp.success, + "push must be accepted for notification test: {:?}", + push_resp.error + ); + + // The subscribe stream must deliver a ConfigUpdateNotification for the change + tokio::time::timeout(std::time::Duration::from_secs(5), notif_rx.changed()) + .await + .expect("ConfigUpdateNotification must arrive within 5 s") + .expect("notification channel must not be closed"); + + let notif = notif_rx.borrow_and_update().clone(); + assert_eq!( + notif.revision, + initial_revision + 1, + "notification revision must be initial + 1" + ); + assert!( + !notif.config_hash.is_empty(), + "notification must carry config_hash" + ); + + std::fs::remove_dir_all(&tmp).ok(); + Ok(()) + } +} +/// Generate a mesh ID for a new mesh. +/// Named meshes: `sha256("mesh-llm:" + name + ":" + nostr_pubkey)` β€” deterministic, unique per creator. +/// Unnamed meshes: random UUID, persisted to `~/.mesh-llm/mesh-id`. +pub fn generate_mesh_id(name: Option<&str>, nostr_pubkey: Option<&str>) -> String { + if let Some(name) = name { + use std::hash::{Hash, Hasher}; + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + "mesh-llm:".hash(&mut hasher); + name.hash(&mut hasher); + if let Some(pk) = nostr_pubkey { + pk.hash(&mut hasher); + } + format!("{:016x}", hasher.finish()) + } else { + // Try to load persisted mesh-id + let path = mesh_id_path(); + if let Ok(id) = std::fs::read_to_string(&path) { + let id = id.trim().to_string(); + if !id.is_empty() { + return id; + } + } + // Generate new random ID and persist + let id = format!( + "{:016x}{:016x}", + rand::random::<u64>(), + rand::random::<u64>() + ); + if let Some(parent) = path.parent() { + let _ = std::fs::create_dir_all(parent); + } + let _ = std::fs::write(&path, &id); + id + } +} + +fn mesh_id_path() -> std::path::PathBuf { + dirs::home_dir() + .unwrap_or_else(|| std::path::PathBuf::from(".")) + .join(".mesh-llm") + .join("mesh-id") +} + +/// Save the mesh ID of the last mesh we successfully joined. +pub fn save_last_mesh_id(mesh_id: &str) { + let path = dirs::home_dir() + .unwrap_or_else(|| std::path::PathBuf::from(".")) + .join(".mesh-llm") + .join("last-mesh"); + if let Some(parent) = path.parent() { + let _ = std::fs::create_dir_all(parent); + } + let _ = std::fs::write(&path, mesh_id); +} + +/// Load the mesh ID of the last mesh we successfully joined. +pub fn load_last_mesh_id() -> Option<String> { + let path = dirs::home_dir() + .unwrap_or_else(|| std::path::PathBuf::from(".")) + .join(".mesh-llm") + .join("last-mesh"); + std::fs::read_to_string(&path) + .ok() + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) +} + +// --------------------------------------------------------------------------- +// Public-to-private identity transition +// --------------------------------------------------------------------------- + +fn was_public_path() -> std::path::PathBuf { + dirs::home_dir() + .unwrap_or_else(|| std::path::PathBuf::from(".")) + .join(".mesh-llm") + .join("was-public") +} + +/// Record that this node was started in public mode (--auto / --publish / --mesh-name). +/// Called at startup so we can detect a publicβ†’private transition next time. +pub fn mark_was_public() { + let path = was_public_path(); + if let Some(parent) = path.parent() { + let _ = std::fs::create_dir_all(parent); + } + let _ = std::fs::write(&path, "1"); +} + +/// Returns true if the previous run was public (marker file exists). +pub fn was_previously_public() -> bool { + was_public_path().exists() +} + +/// Clear identity files (key, nostr.nsec, mesh-id, last-mesh, was-public) so the +/// next start gets a completely fresh identity. Called when transitioning from +/// public β†’ private to avoid reusing a publicly-known identity in a private mesh. +pub fn clear_public_identity() { + let home = dirs::home_dir().unwrap_or_else(|| std::path::PathBuf::from(".")); + let dir = home.join(".mesh-llm"); + let mut ok = true; + for name in &["key", "nostr.nsec", "mesh-id", "last-mesh"] { + let p = dir.join(name); + if p.exists() { + if std::fs::remove_file(&p).is_ok() { + tracing::info!("Cleared {}", p.display()); + } else { + tracing::warn!("Failed to clear {}", p.display()); + ok = false; + } + } + } + // Only remove the marker after identity files are gone, so a failed + // cleanup is retried on the next private start. + let marker = dir.join("was-public"); + if ok { + let _ = std::fs::remove_file(&marker); + } else { + tracing::warn!("Keeping was-public marker β€” will retry cleanup next start"); + } +} + +/// Load secret key from ~/.mesh-llm/key, or create a new one and save it. +async fn load_or_create_key() -> Result<SecretKey> { + let key_path = default_node_key_path()?; + let dir = key_path + .parent() + .ok_or_else(|| anyhow::anyhow!("Invalid node key path {}", key_path.display()))?; + ensure_private_node_key_dir(dir)?; + + if key_path.exists() { + ensure_private_node_key_file(&key_path)?; + let hex = tokio::fs::read_to_string(&key_path).await?; + let bytes = hex::decode(hex.trim())?; + if bytes.len() != 32 { + anyhow::bail!("Invalid key length in {}", key_path.display()); + } + let key = SecretKey::from_bytes(&bytes.try_into().unwrap()); + tracing::info!("Loaded key from {}", key_path.display()); + return Ok(key); + } + + let key = SecretKey::generate(&mut rand::rng()); save_node_key_to_path(&key_path, &key)?; tracing::info!("Generated new key, saved to {}", key_path.display()); Ok(key) @@ -4433,8 +8886,5 @@ pub(crate) use heartbeat::{ moe_recovery_ready_at, peer_is_eligible_for_active_moe, resolve_peer_down, }; -#[cfg(test)] -mod tests; - #[cfg(test)] mod public_identity_tests; diff --git a/mesh-llm/src/mlx/mod.rs b/mesh-llm/src/mlx/mod.rs new file mode 100644 index 00000000..92eb291b --- /dev/null +++ b/mesh-llm/src/mlx/mod.rs @@ -0,0 +1,16 @@ +//! Native MLX inference backend. +//! +//! Runs quantized safetensors models directly on Apple Silicon Metal +//! via mlx-rs (Rust bindings to MLX C). No Python, no subprocess. +//! +//! This module provides `start_mlx_server()` as a drop-in replacement +//! for `launch::start_llama_server()` β€” same contract (port + death_rx), +//! so the proxy and election machinery work unchanged. + +pub mod model; +pub mod sampling; +pub mod server; +pub mod template; + +pub use model::{is_mlx_model_dir, mlx_model_dir}; +pub use server::start_mlx_server; diff --git a/mesh-llm/src/mlx/model/artifacts.rs b/mesh-llm/src/mlx/model/artifacts.rs new file mode 100644 index 00000000..1b03c467 --- /dev/null +++ b/mesh-llm/src/mlx/model/artifacts.rs @@ -0,0 +1,227 @@ +use super::family::{config_supports_mlx, detect_architecture_from_safetensors_header}; +use anyhow::{bail, Context, Result}; +use mlx_rs::Array; +use serde_json::Value; +use std::collections::HashMap; +use std::path::Path; + +#[derive(Debug, Clone)] +pub struct TokenizerSpacingPatch { + pub special_tokens: Vec<(String, u32)>, + pub space_token_id: u32, +} + +pub(super) struct TensorPrefixes { + pub model: String, + pub lm_head: Option<String>, +} + +pub(super) fn tensor_prefixes(tensors: &HashMap<String, Array>) -> Result<TensorPrefixes> { + if tensors.contains_key("model.embed_tokens.weight") { + return Ok(TensorPrefixes { + model: "model".to_string(), + lm_head: Some("lm_head".to_string()), + }); + } + if tensors.contains_key("language_model.model.embed_tokens.weight") { + return Ok(TensorPrefixes { + model: "language_model.model".to_string(), + lm_head: Some("language_model.lm_head".to_string()), + }); + } + bail!("unsupported MLX tensor prefix layout") +} + +pub(super) fn load_all_safetensors(dir: &Path) -> Result<HashMap<String, Array>> { + let index_path = dir.join("model.safetensors.index.json"); + if index_path.exists() { + let index: serde_json::Value = + serde_json::from_str(&std::fs::read_to_string(&index_path)?)?; + let weight_map = index["weight_map"] + .as_object() + .context("missing weight_map in index")?; + let mut tensors = HashMap::new(); + let mut seen = std::collections::HashSet::new(); + for file in weight_map.values() { + let file = file.as_str().context("weight_map value not a string")?; + if seen.insert(file.to_string()) { + tensors.extend(Array::load_safetensors(dir.join(file))?); + } + } + Ok(tensors) + } else { + let st_path = dir.join("model.safetensors"); + if st_path.exists() { + return Ok(Array::load_safetensors(st_path)?); + } + + let mut shard_paths = std::fs::read_dir(dir) + .with_context(|| format!("reading MLX model directory {}", dir.display()))? + .filter_map(|entry| entry.ok().map(|entry| entry.path())) + .filter(|path| { + path.file_name() + .and_then(|name| name.to_str()) + .is_some_and(|name| { + name.starts_with("model-") && name.ends_with(".safetensors") + }) + }) + .collect::<Vec<_>>(); + shard_paths.sort(); + if shard_paths.is_empty() { + bail!("no MLX safetensors weights found in {}", dir.display()); + } + + let mut tensors = HashMap::new(); + for shard_path in shard_paths { + tensors.extend(Array::load_safetensors(shard_path)?); + } + Ok(tensors) + } +} + +fn normalize_model_dir(path: &Path) -> Option<&Path> { + if path.is_dir() { + return Some(path); + } + let name = path.file_name()?.to_str()?; + match name { + "config.json" | "chat_template.jinja" | "tokenizer.json" | "tokenizer_config.json" => { + path.parent() + } + _ if name.ends_with(".safetensors") || name == "model.safetensors.index.json" => { + path.parent() + } + _ => None, + } +} + +fn has_required_model_files(dir: &Path) -> bool { + let has_config = dir.join("config.json").exists(); + let has_tokenizer = + dir.join("tokenizer_config.json").exists() || dir.join("tokenizer.json").exists(); + let has_sharded_weights = std::fs::read_dir(dir).ok().is_some_and(|entries| { + entries.filter_map(|entry| entry.ok()).any(|entry| { + entry + .file_name() + .to_str() + .is_some_and(|name| name.starts_with("model-") && name.ends_with(".safetensors")) + }) + }); + let has_weights = dir.join("model.safetensors").exists() + || dir.join("model.safetensors.index.json").exists() + || has_sharded_weights; + has_config && has_tokenizer && has_weights +} + +pub(super) fn read_model_config(dir: &Path) -> Option<Value> { + let text = std::fs::read_to_string(dir.join("config.json")).ok()?; + serde_json::from_str(&text).ok() +} + +pub(super) fn patch_phi3_special_token_whitespace(tokenizer_json: &mut Value, config_json: &Value) { + let is_phi3 = config_json + .get("model_type") + .and_then(|value| value.as_str()) + .is_some_and(|value| value.eq_ignore_ascii_case("phi3")); + if is_phi3 { + let preserve_following_whitespace = ["<|assistant|>", "<|user|>", "<|system|>", "<|end|>"]; + if let Some(tokens) = tokenizer_json + .get_mut("added_tokens") + .and_then(|value| value.as_array_mut()) + { + for token in tokens { + let should_patch = token + .get("content") + .and_then(|value| value.as_str()) + .is_some_and(|value| preserve_following_whitespace.contains(&value)); + if should_patch { + token["rstrip"] = Value::Bool(false); + } + } + } + } +} + +fn mistral_tokenizer_spacing_patch( + tokenizer: &tokenizers::Tokenizer, + tokenizer_json: &Value, + config_json: &Value, +) -> Result<Option<TokenizerSpacingPatch>> { + let is_mistral = config_json + .get("model_type") + .and_then(|value| value.as_str()) + .is_some_and(|value| value.eq_ignore_ascii_case("mistral")); + if !is_mistral { + return Ok(None); + } + let mut special_tokens = tokenizer_json + .get("added_tokens") + .and_then(|value| value.as_array()) + .into_iter() + .flatten() + .filter(|token| token.get("special").and_then(|value| value.as_bool()) == Some(true)) + .filter_map(|token| { + Some(( + token.get("content")?.as_str()?.to_string(), + token.get("id")?.as_u64()? as u32, + )) + }) + .collect::<Vec<_>>(); + if special_tokens.is_empty() { + return Ok(None); + } + special_tokens.sort_by(|(lhs, _), (rhs, _)| rhs.len().cmp(&lhs.len())); + let space_token_id = tokenizer + .encode(" ", false) + .map_err(|e| anyhow::anyhow!("loading mistral spacing patch: {e}"))? + .get_ids() + .first() + .copied() + .context("loading mistral spacing patch: tokenizer encoded space to zero tokens")?; + Ok(Some(TokenizerSpacingPatch { + special_tokens, + space_token_id, + })) +} + +pub(super) fn load_tokenizer( + dir: &Path, + config_json: &Value, +) -> Result<(tokenizers::Tokenizer, Option<TokenizerSpacingPatch>)> { + let tokenizer_path = dir.join("tokenizer.json"); + let mut tokenizer_json: Value = serde_json::from_str( + &std::fs::read_to_string(&tokenizer_path).context("reading tokenizer.json")?, + ) + .context("parsing tokenizer.json")?; + patch_phi3_special_token_whitespace(&mut tokenizer_json, config_json); + + let tokenizer = tokenizers::Tokenizer::from_bytes( + serde_json::to_vec(&tokenizer_json).context("serializing patched tokenizer.json")?, + ) + .map_err(|e| anyhow::anyhow!("loading tokenizer: {e}"))?; + let spacing_patch = mistral_tokenizer_spacing_patch(&tokenizer, &tokenizer_json, config_json)?; + Ok((tokenizer, spacing_patch)) +} + +pub fn mlx_model_dir(path: &Path) -> Option<&Path> { + let dir = normalize_model_dir(path)?; + if has_required_model_files(dir) { + Some(dir) + } else { + None + } +} + +pub fn is_mlx_model_dir(path: &Path) -> bool { + let Some(dir) = mlx_model_dir(path) else { + return false; + }; + + if let Some(config) = read_model_config(dir) { + if config_supports_mlx(&config) { + return true; + } + } + + detect_architecture_from_safetensors_header(dir).is_some() +} diff --git a/mesh-llm/src/mlx/model/attention.rs b/mesh-llm/src/mlx/model/attention.rs new file mode 100644 index 00000000..00b1ba43 --- /dev/null +++ b/mesh-llm/src/mlx/model/attention.rs @@ -0,0 +1,619 @@ +use super::*; +use mlx_rs::array; + +// ── Attention ── + +pub(crate) struct Attention { + pub(super) q_proj: QuantizedLinear, + pub(super) k_proj: QuantizedLinear, + pub(super) v_proj: QuantizedLinear, + pub(super) o_proj: QuantizedLinear, + pub(super) q_norm: Option<RMSNorm>, + pub(super) k_norm: Option<RMSNorm>, + pub(super) v_norm: Option<RMSNorm>, + pub(super) num_heads: i32, + pub(super) num_kv_heads: i32, + pub(super) head_dim: i32, + pub(super) scale: f32, + pub(super) attn_logit_softcapping: Option<f32>, + pub(super) rope_dim: i32, + pub(super) rope_theta: f32, + pub(super) rope_traditional: bool, + pub(super) window_size: Option<i32>, + pub(super) kv_shared_source: Option<usize>, +} + +impl Attention { + pub(super) fn apply_qk_norm( + x: Array, + norm: Option<&RMSNorm>, + b: i32, + l: i32, + num_heads: i32, + head_dim: i32, + ) -> Result<Array> { + let Some(norm) = norm else { + return Ok(x.reshape(&[b, l, num_heads, head_dim])?); + }; + let norm_width = norm.weight.shape()[0]; + if norm_width == num_heads * head_dim { + return Ok(norm.forward(&x)?.reshape(&[b, l, num_heads, head_dim])?); + } + norm.forward(&x.reshape(&[b, l, num_heads, head_dim])?) + } + + pub(super) fn forward_no_cache(&self, x: &Array) -> Result<Array> { + let shape = x.shape(); + let (b, l) = (shape[0], shape[1]); + + let q = self.q_proj.forward(x)?; + let q = Self::apply_qk_norm(q, self.q_norm.as_ref(), b, l, self.num_heads, self.head_dim)? + .transpose_axes(&[0, 2, 1, 3])?; + let q = apply_rope( + &q, + self.rope_dim, + self.head_dim, + self.rope_theta, + self.rope_traditional, + 0, + )?; + + let k = self.k_proj.forward(x)?; + let v = self.v_proj.forward(x)?; + let k = Self::apply_qk_norm( + k, + self.k_norm.as_ref(), + b, + l, + self.num_kv_heads, + self.head_dim, + )? + .transpose_axes(&[0, 2, 1, 3])?; + let v = v.reshape(&[b, l, self.num_kv_heads, self.head_dim])?; + let v = if let Some(norm) = &self.v_norm { + norm.forward(&v)? + } else { + v + } + .transpose_axes(&[0, 2, 1, 3])?; + let k = apply_rope( + &k, + self.rope_dim, + self.head_dim, + self.rope_theta, + self.rope_traditional, + 0, + )?; + + let mask = if self.window_size.is_some() { + attention_mask(l, l, 0, 0, self.window_size)? + } else { + None + }; + let attn = if self.attn_logit_softcapping.is_some() || mask.is_some() { + manual_scaled_dot_product_attention_with_mask( + &q, + &k, + &v, + self.scale, + self.attn_logit_softcapping, + mask.as_ref(), + )? + } else { + let mask = if l > 1 { + Some(mlx_rs::fast::ScaledDotProductAttentionMask::Causal) + } else { + None + }; + mlx_rs::fast::scaled_dot_product_attention(&q, &k, &v, self.scale, mask)? + }; + + let attn = + attn.transpose_axes(&[0, 2, 1, 3])? + .reshape(&[b, l, self.num_heads * self.head_dim])?; + self.o_proj.forward(&attn) + } + + pub(super) fn forward( + &self, + x: &Array, + cache: &mut KVCache, + shared_cache: Option<&KVCache>, + ) -> Result<Array> { + let shape = x.shape(); + let (b, l) = (shape[0], shape[1]); + + let q = self.q_proj.forward(x)?; + let q = Self::apply_qk_norm(q, self.q_norm.as_ref(), b, l, self.num_heads, self.head_dim)? + .transpose_axes(&[0, 2, 1, 3])?; + + let offset = shared_cache.unwrap_or(&*cache).offset() as i32; + let q = apply_rope( + &q, + self.rope_dim, + self.head_dim, + self.rope_theta, + self.rope_traditional, + offset, + )?; + let (cache_entries, key_start) = if let Some(shared_cache) = shared_cache { + let (k, v) = shared_cache + .views() + .context("Gemma4 shared KV cache was empty")?; + let key_start = + shared_cache.key_start_for_attention(l as usize, k.shape()[2] as usize) as i32; + (CachedKv::Dense { keys: k, values: v }, key_start) + } else { + let k = self.k_proj.forward(x)?; + let v = self.v_proj.forward(x)?; + let k = Self::apply_qk_norm( + k, + self.k_norm.as_ref(), + b, + l, + self.num_kv_heads, + self.head_dim, + )? + .transpose_axes(&[0, 2, 1, 3])?; + let v = v.reshape(&[b, l, self.num_kv_heads, self.head_dim])?; + let v = if let Some(norm) = &self.v_norm { + norm.forward(&v)? + } else { + v + } + .transpose_axes(&[0, 2, 1, 3])?; + let k = apply_rope( + &k, + self.rope_dim, + self.head_dim, + self.rope_theta, + self.rope_traditional, + offset, + )?; + let entries = cache.update_cached(k, v)?; + let key_start = + (offset as usize + l as usize).saturating_sub(entries.key_len() as usize) as i32; + (entries, key_start) + }; + + // Causal mask for prefill (multi-token). Decode (l=1) needs no mask. + let mask = if self.window_size.is_some() { + attention_mask( + l, + cache_entries.key_len(), + key_start, + offset, + self.window_size, + )? + } else { + None + }; + let attn = match cache_entries { + CachedKv::Dense { keys, values } => { + if self.attn_logit_softcapping.is_some() || mask.is_some() { + manual_scaled_dot_product_attention_with_mask( + &q, + &keys, + &values, + self.scale, + self.attn_logit_softcapping, + mask.as_ref(), + )? + } else { + let mask = if l > 1 { + Some(mlx_rs::fast::ScaledDotProductAttentionMask::Causal) + } else { + None + }; + mlx_rs::fast::scaled_dot_product_attention( + &q, &keys, &values, self.scale, mask, + )? + } + } + CachedKv::Quantized { + keys, + values, + group_size, + bits, + } => { + anyhow::ensure!( + self.attn_logit_softcapping.is_none(), + "quantized KV cache does not support attention softcapping yet" + ); + quantized_scaled_dot_product_attention_with_mask( + &q, + &keys, + &values, + self.scale, + mask.as_ref(), + group_size, + bits, + )? + } + }; + + let attn = + attn.transpose_axes(&[0, 2, 1, 3])? + .reshape(&[b, l, self.num_heads * self.head_dim])?; + + self.o_proj.forward(&attn) + } +} + +pub(crate) struct DeepseekV3Attention { + pub(super) q_proj: Option<QuantizedLinear>, + pub(super) q_a_proj: Option<QuantizedLinear>, + pub(super) q_a_layernorm: Option<RMSNorm>, + pub(super) q_b_proj: Option<QuantizedLinear>, + pub(super) kv_a_proj_with_mqa: QuantizedLinear, + pub(super) kv_a_layernorm: RMSNorm, + pub(super) embed_q: QuantizedMultiLinear, + pub(super) unembed_out: QuantizedMultiLinear, + pub(super) o_proj: QuantizedLinear, + pub(super) num_heads: i32, + pub(super) q_head_dim: i32, + pub(super) qk_rope_head_dim: i32, + pub(super) qk_nope_head_dim: i32, + pub(super) kv_lora_rank: i32, + pub(super) v_head_dim: i32, + pub(super) scale: f32, + pub(super) rope_theta: f32, +} + +impl DeepseekV3Attention { + fn build_q(&self, x: &Array) -> Result<Array> { + let q = if let Some(q_proj) = &self.q_proj { + q_proj.forward(x)? + } else { + self.q_b_proj + .as_ref() + .context("missing q_b_proj for DeepSeekV3 attention")? + .forward( + &self + .q_a_layernorm + .as_ref() + .context("missing q_a_layernorm for DeepSeekV3 attention")? + .forward( + &self + .q_a_proj + .as_ref() + .context("missing q_a_proj for DeepSeekV3 attention")? + .forward(x)?, + )?, + )? + }; + Ok(q) + } + + fn attention_mask(&self, q_pe: &Array, k_pe: &Array, causal: bool) -> Result<Array> { + let mut pe_scores = mlx_rs::ops::matmul( + &q_pe.multiply(&array!(self.scale))?, + &k_pe.transpose_axes(&[0, 1, 3, 2])?, + )?; + if causal { + let mask = attention_mask(q_pe.shape()[2], k_pe.shape()[2], 0, 0, None)? + .context("expected causal mask")?; + let fill = array!(pe_scores.dtype().finfo_min()? as f32).as_dtype(pe_scores.dtype())?; + pe_scores = mlx_rs::ops::r#where(&mask, &pe_scores, &fill)?; + } + Ok(pe_scores) + } + + fn forward_impl(&self, x: &Array, cache: Option<&mut KVCache>) -> Result<Array> { + let shape = x.shape(); + let (b, l) = (shape[0], shape[1]); + + let q = self + .build_q(x)? + .reshape(&[b, l, self.num_heads, self.q_head_dim])? + .transpose_axes(&[0, 2, 1, 3])?; + let q_nope = q.index(( + std::ops::RangeFull, + std::ops::RangeFull, + std::ops::RangeFull, + ..self.qk_nope_head_dim, + )); + let q_pe = q.index(( + std::ops::RangeFull, + std::ops::RangeFull, + std::ops::RangeFull, + self.qk_nope_head_dim.., + )); + + let compressed_kv = self.kv_a_proj_with_mqa.forward(x)?; + let kv_latent = compressed_kv.index(( + std::ops::RangeFull, + std::ops::RangeFull, + ..self.kv_lora_rank, + )); + let k_pe = compressed_kv.index(( + std::ops::RangeFull, + std::ops::RangeFull, + self.kv_lora_rank.., + )); + let kv_latent = self.kv_a_layernorm.forward(&kv_latent)?.expand_dims(1)?; + let k_pe = k_pe + .reshape(&[b, l, 1, self.qk_rope_head_dim])? + .transpose_axes(&[0, 2, 1, 3])?; + + let offset = cache + .as_ref() + .map(|cache| cache.offset() as i32) + .unwrap_or(0); + let q_pe = apply_rope( + &q_pe, + self.qk_rope_head_dim, + self.qk_rope_head_dim, + self.rope_theta, + false, + offset, + )?; + let k_pe = apply_rope( + &k_pe, + self.qk_rope_head_dim, + self.qk_rope_head_dim, + self.rope_theta, + false, + offset, + )?; + + let (kv_latent, k_pe) = if let Some(cache) = cache { + cache.update(kv_latent, k_pe)? + } else { + (kv_latent, k_pe) + }; + + let mask = self.attention_mask(&q_pe, &k_pe, l > 1)?; + let output = if l == 1 { + let q_nope = self.embed_q.forward(&q_nope, true)?; + let output = mlx_rs::fast::scaled_dot_product_attention( + &q_nope, + &kv_latent, + &kv_latent, + self.scale, + Some((&mask).into()), + )?; + self.unembed_out.forward(&output, true)? + } else { + let k = self.embed_q.forward(&kv_latent, false)?; + let v = self.unembed_out.forward(&kv_latent, true)?; + mlx_rs::fast::scaled_dot_product_attention( + &q_nope, + &k, + &v, + self.scale, + Some((&mask).into()), + )? + }; + + let output = output.transpose_axes(&[0, 2, 1, 3])?.reshape(&[ + b, + l, + self.num_heads * self.v_head_dim, + ])?; + self.o_proj.forward(&output) + } + + pub(super) fn forward_no_cache(&self, x: &Array) -> Result<Array> { + self.forward_impl(x, None) + } + + pub(super) fn forward(&self, x: &Array, cache: &mut KVCache) -> Result<Array> { + self.forward_impl(x, Some(cache)) + } +} + +pub(super) fn attention_mask( + query_len: i32, + key_len: i32, + key_start: i32, + query_start: i32, + window_size: Option<i32>, +) -> Result<Option<Array>> { + if query_len == 1 && window_size.is_none() { + return Ok(None); + } + + let key_positions = mlx_rs::ops::arange::<_, i32>(key_start, key_start + key_len, 1)?; + let query_positions = mlx_rs::ops::arange::<_, i32>(query_start, query_start + query_len, 1)?; + let left = query_positions.expand_dims(1)?; + let right = key_positions.expand_dims(0)?; + let mut mask = left.ge(&right)?; + if let Some(window_size) = window_size { + let upper_bound = right.add(&array!(window_size))?; + mask = mask.logical_and(&left.lt(&upper_bound)?)?; + } + Ok(Some(mask)) +} + +fn manual_scaled_dot_product_attention_with_mask( + q: &Array, + k: &Array, + v: &Array, + scale: f32, + softcap: Option<f32>, + mask: Option<&Array>, +) -> Result<Array> { + let num_heads = q.shape()[1]; + let num_kv_heads = k.shape()[1]; + anyhow::ensure!( + num_heads % num_kv_heads == 0, + "cannot align attention heads: q_heads={}, kv_heads={}", + num_heads, + num_kv_heads + ); + let repeats = num_heads / num_kv_heads; + let batch = q.shape()[0]; + let query_len = q.shape()[2]; + let head_dim = q.shape()[3]; + + let mut queries = q.clone(); + if scale != 1.0 { + queries = queries.multiply(&array!(scale))?; + } + + let (queries, keys, values) = if repeats > 1 { + ( + queries.reshape(&[batch, num_kv_heads, repeats, query_len, head_dim])?, + k.expand_dims(2)?, + v.expand_dims(2)?, + ) + } else { + (queries, k.clone(), v.clone()) + }; + + let key_t = if repeats > 1 { + keys.transpose_axes(&[0, 1, 2, 4, 3])? + } else { + keys.transpose_axes(&[0, 1, 3, 2])? + }; + let mut scores = mlx_rs::ops::matmul(&queries, &key_t)?; + if let Some(softcap) = softcap { + scores = scores.divide(&array!(softcap))?; + scores = mlx_rs::ops::tanh(&scores)?.multiply(&array!(softcap))?; + } + if let Some(mask) = mask { + let fill = array!(scores.dtype().finfo_min()? as f32).as_dtype(scores.dtype())?; + scores = mlx_rs::ops::r#where(mask, &scores, &fill)?; + } + let probs = mlx_rs::ops::softmax_axis(&scores, -1, true)?; + let mut output = mlx_rs::ops::matmul(&probs, &values)?; + if repeats > 1 { + output = output.reshape(&[batch, num_heads, query_len, head_dim])?; + } + Ok(output) +} + +fn quantized_scaled_dot_product_attention_with_mask( + q: &Array, + k: &QuantizedCacheArrays, + v: &QuantizedCacheArrays, + scale: f32, + mask: Option<&Array>, + group_size: i32, + bits: i32, +) -> Result<Array> { + let num_heads = q.shape()[1]; + let num_kv_heads = k.data.shape()[1]; + anyhow::ensure!( + num_heads % num_kv_heads == 0, + "cannot align quantized attention heads: q_heads={}, kv_heads={}", + num_heads, + num_kv_heads + ); + let repeats = num_heads / num_kv_heads; + let batch = q.shape()[0]; + let query_len = q.shape()[2]; + let head_dim = q.shape()[3]; + + let mut queries = q.clone(); + if scale != 1.0 { + queries = queries.multiply(&array!(scale))?; + } + + let (queries, keys, values) = if repeats > 1 { + ( + queries.reshape(&[batch, num_kv_heads, repeats, query_len, head_dim])?, + QuantizedCacheArrays { + data: k.data.expand_dims(2)?, + scales: k.scales.expand_dims(2)?, + biases: k.biases.expand_dims(2)?, + }, + QuantizedCacheArrays { + data: v.data.expand_dims(2)?, + scales: v.scales.expand_dims(2)?, + biases: v.biases.expand_dims(2)?, + }, + ) + } else { + ( + queries, + QuantizedCacheArrays { + data: k.data.clone(), + scales: k.scales.clone(), + biases: k.biases.clone(), + }, + QuantizedCacheArrays { + data: v.data.clone(), + scales: v.scales.clone(), + biases: v.biases.clone(), + }, + ) + }; + + let mut scores = mlx_rs::ops::quantized_matmul( + &queries, + &keys.data, + &keys.scales, + &keys.biases, + true, + group_size, + bits, + )?; + + if let Some(mask) = mask { + let fill = array!(scores.dtype().finfo_min()? as f32).as_dtype(scores.dtype())?; + scores = mlx_rs::ops::r#where(mask, &scores, &fill)?; + } + + let probs = mlx_rs::ops::softmax_axis(&scores, -1, true)?; + let mut output = mlx_rs::ops::quantized_matmul( + &probs, + &values.data, + &values.scales, + &values.biases, + false, + group_size, + bits, + )?; + + if repeats > 1 { + output = output.reshape(&[batch, num_heads, query_len, head_dim])?; + } + + Ok(output) +} + +pub(super) fn apply_rope( + x: &Array, + rope_dim: i32, + head_dim: i32, + rope_theta: f32, + rope_traditional: bool, + offset: i32, +) -> Result<Array> { + if rope_dim == head_dim { + return Ok(mlx_rs::fast::rope( + x, + head_dim, + rope_traditional, + Some(rope_theta), + 1.0, + offset, + None::<&Array>, + )?); + } + + let rotated = x.index(( + std::ops::RangeFull, + std::ops::RangeFull, + std::ops::RangeFull, + ..rope_dim, + )); + let rotated = mlx_rs::fast::rope( + &rotated, + rope_dim, + rope_traditional, + Some(rope_theta), + 1.0, + offset, + None::<&Array>, + )?; + let tail = x.index(( + std::ops::RangeFull, + std::ops::RangeFull, + std::ops::RangeFull, + rope_dim.., + )); + Ok(mlx_rs::ops::concatenate_axis(&[&rotated, &tail], 3)?) +} diff --git a/mesh-llm/src/mlx/model/attention_kind.rs b/mesh-llm/src/mlx/model/attention_kind.rs new file mode 100644 index 00000000..2af8c444 --- /dev/null +++ b/mesh-llm/src/mlx/model/attention_kind.rs @@ -0,0 +1,57 @@ +use super::*; + +pub enum AttentionKind { + Standard(Attention), + DeepseekV3(DeepseekV3Attention), + KimiMla(KimiMlaAttention), + KimiDelta(KimiDeltaAttention), + Lfm2ShortConv(Lfm2ShortConv), +} + +impl AttentionKind { + pub(super) fn forward_no_cache(&self, x: &Array) -> Result<Array> { + match self { + Self::Standard(attn) => attn.forward_no_cache(x), + Self::DeepseekV3(attn) => attn.forward_no_cache(x), + Self::KimiMla(attn) => attn.forward_no_cache(x), + Self::KimiDelta(attn) => attn.forward_no_cache(x), + Self::Lfm2ShortConv(conv) => conv.forward_no_cache(x), + } + } + + pub(super) fn forward( + &self, + x: &Array, + cache: &mut KVCache, + shared_cache: Option<&KVCache>, + ) -> Result<Array> { + match self { + Self::Standard(attn) => attn.forward(x, cache, shared_cache), + Self::DeepseekV3(attn) => attn.forward(x, cache), + Self::KimiMla(_) | Self::KimiDelta(_) => { + bail!("Kimi Linear currently requires cacheless generation") + } + Self::Lfm2ShortConv(_) => { + bail!("LFM2 ShortConv currently requires cacheless generation") + } + } + } + + pub(super) fn kv_shared_source(&self) -> Option<usize> { + match self { + Self::Standard(attn) => attn.kv_shared_source, + Self::DeepseekV3(_) => None, + Self::KimiMla(_) | Self::KimiDelta(_) => None, + Self::Lfm2ShortConv(_) => None, + } + } + + pub(super) fn sliding_window_size(&self) -> Option<usize> { + match self { + Self::Standard(attn) => attn.window_size.map(|size| size as usize), + Self::DeepseekV3(_) => None, + Self::KimiMla(_) | Self::KimiDelta(_) => None, + Self::Lfm2ShortConv(_) => None, + } + } +} diff --git a/mesh-llm/src/mlx/model/cache.rs b/mesh-llm/src/mlx/model/cache.rs new file mode 100644 index 00000000..916ccf08 --- /dev/null +++ b/mesh-llm/src/mlx/model/cache.rs @@ -0,0 +1,772 @@ +use super::*; + +const KV_CACHE_STEP: usize = 256; + +enum KVCacheMode { + Standard, + Rotating { + max_size: usize, + keep: usize, + }, + Quantized { + group_size: i32, + bits: i32, + min_dense_tokens: usize, + }, +} + +pub(super) struct QuantizedCacheArrays { + pub(super) data: Array, + pub(super) scales: Array, + pub(super) biases: Array, +} + +impl QuantizedCacheArrays { + pub(super) fn arrays(&self) -> [&Array; 3] { + [&self.data, &self.scales, &self.biases] + } + + pub(super) fn prefix(&self, end: i32) -> Self { + Self { + data: self.data.index(( + std::ops::RangeFull, + std::ops::RangeFull, + ..end, + std::ops::RangeFull, + )), + scales: self.scales.index(( + std::ops::RangeFull, + std::ops::RangeFull, + ..end, + std::ops::RangeFull, + )), + biases: self.biases.index(( + std::ops::RangeFull, + std::ops::RangeFull, + ..end, + std::ops::RangeFull, + )), + } + } + + pub(super) fn trim_to(&self, end: i32) -> Result<Self> { + Ok(Self { + data: materialize_cache_prefix(&self.data, end)?, + scales: materialize_cache_prefix(&self.scales, end)?, + biases: materialize_cache_prefix(&self.biases, end)?, + }) + } + + pub(super) fn expand(&self, extra_steps: i32) -> Result<Self> { + let base_shape = self.data.shape(); + let scale_shape = self.scales.shape(); + let bias_shape = self.biases.shape(); + let extra_data = mlx_rs::ops::zeros_dtype( + &[base_shape[0], base_shape[1], extra_steps, base_shape[3]], + self.data.dtype(), + )?; + let extra_scales = mlx_rs::ops::zeros_dtype( + &[scale_shape[0], scale_shape[1], extra_steps, scale_shape[3]], + self.scales.dtype(), + )?; + let extra_biases = mlx_rs::ops::zeros_dtype( + &[bias_shape[0], bias_shape[1], extra_steps, bias_shape[3]], + self.biases.dtype(), + )?; + Ok(Self { + data: mlx_rs::ops::concatenate_axis(&[&self.data, &extra_data], 2)?, + scales: mlx_rs::ops::concatenate_axis(&[&self.scales, &extra_scales], 2)?, + biases: mlx_rs::ops::concatenate_axis(&[&self.biases, &extra_biases], 2)?, + }) + } +} + +fn materialize_cache_prefix(array: &Array, end: i32) -> Result<Array> { + use std::ops::RangeFull; + + let end = end.max(0); + let shape = array.shape(); + let mut owned = mlx_rs::ops::zeros_dtype(&[shape[0], shape[1], end, shape[3]], array.dtype())?; + if end > 0 { + let prefix = array.index((RangeFull, RangeFull, ..end, RangeFull)); + owned.try_index_mut((RangeFull, RangeFull, ..end, RangeFull), &prefix)?; + } + Ok(owned) +} + +pub(super) enum CachedKv { + Dense { + keys: Array, + values: Array, + }, + Quantized { + keys: QuantizedCacheArrays, + values: QuantizedCacheArrays, + group_size: i32, + bits: i32, + }, +} + +impl CachedKv { + pub(super) fn key_len(&self) -> i32 { + match self { + CachedKv::Dense { keys, .. } => keys.shape()[2], + CachedKv::Quantized { keys, .. } => keys.data.shape()[2], + } + } +} + +pub struct KVCache { + pub(super) keys: Option<Array>, + pub(super) values: Option<Array>, + pub(super) qkeys: Option<QuantizedCacheArrays>, + pub(super) qvalues: Option<QuantizedCacheArrays>, + start_offset: usize, + offset: usize, + idx: usize, + mode: KVCacheMode, +} + +impl KVCache { + pub fn new() -> Self { + KVCache { + keys: None, + values: None, + qkeys: None, + qvalues: None, + start_offset: 0, + offset: 0, + idx: 0, + mode: KVCacheMode::Standard, + } + } + + pub fn new_rotating(max_size: usize, keep: usize) -> Self { + KVCache { + keys: None, + values: None, + qkeys: None, + qvalues: None, + start_offset: 0, + offset: 0, + idx: 0, + mode: KVCacheMode::Rotating { max_size, keep }, + } + } + + pub fn new_quantized(group_size: i32, bits: i32, min_dense_tokens: usize) -> Self { + KVCache { + keys: None, + values: None, + qkeys: None, + qvalues: None, + start_offset: 0, + offset: 0, + idx: 0, + mode: KVCacheMode::Quantized { + group_size, + bits, + min_dense_tokens, + }, + } + } + + pub fn offset(&self) -> usize { + self.offset + } + + fn current_len(&self) -> usize { + self.offset.saturating_sub(self.start_offset) + } + + pub(super) fn retained_start(&self) -> usize { + self.start_offset + } + + pub fn can_trim_to(&self, n: usize) -> bool { + n <= self.offset && n >= self.retained_start() + } + + /// Return references to cached arrays (for eval/materialization). + pub fn arrays(&self) -> Vec<&Array> { + let mut out = Vec::new(); + match self.mode { + KVCacheMode::Quantized { .. } => { + if let Some(ref k) = self.qkeys { + out.extend(k.arrays()); + } + if let Some(ref v) = self.qvalues { + out.extend(v.arrays()); + } + } + _ => { + if let Some(ref k) = self.keys { + out.push(k); + } + if let Some(ref v) = self.values { + out.push(v); + } + } + } + out + } + + pub fn views(&self) -> Option<(Array, Array)> { + if self.current_len() == 0 { + return None; + } + let keys = self.temporal_order(self.keys.as_ref()?).ok()?; + let values = self.temporal_order(self.values.as_ref()?).ok()?; + Some((keys, values)) + } + + fn temporal_order(&self, array: &Array) -> Result<Array> { + use std::ops::RangeFull; + + match self.mode { + KVCacheMode::Standard => { + let end_i = self.offset as i32; + Ok(array.index((RangeFull, RangeFull, ..end_i, RangeFull))) + } + KVCacheMode::Quantized { .. } => { + let end_i = self.offset as i32; + Ok(array.index((RangeFull, RangeFull, ..end_i, RangeFull))) + } + KVCacheMode::Rotating { keep, .. } => { + let len = array.shape()[2] as usize; + if self.idx == len { + return Ok(array.clone()); + } + if self.idx < self.current_len() { + let mut parts = Vec::new(); + if keep > 0 { + parts.push(array.index((RangeFull, RangeFull, ..(keep as i32), RangeFull))); + } + parts.push(array.index((RangeFull, RangeFull, self.idx as i32.., RangeFull))); + if self.idx > keep { + parts.push(array.index(( + RangeFull, + RangeFull, + keep as i32..self.idx as i32, + RangeFull, + ))); + } + let refs: Vec<&Array> = parts.iter().collect(); + Ok(mlx_rs::ops::concatenate_axis(&refs, 2)?) + } else { + Ok(array.index((RangeFull, RangeFull, ..(self.idx as i32), RangeFull))) + } + } + } + } + + pub(super) fn key_start_for_attention(&self, query_len: usize, key_len: usize) -> usize { + (self.offset + query_len).saturating_sub(key_len) + } + + pub fn update(&mut self, k: Array, v: Array) -> Result<(Array, Array)> { + match self.update_cached(k, v)? { + CachedKv::Dense { keys, values } => Ok((keys, values)), + CachedKv::Quantized { .. } => bail!("quantized KV cache does not expose dense views"), + } + } + + pub(super) fn update_cached(&mut self, k: Array, v: Array) -> Result<CachedKv> { + match self.mode { + KVCacheMode::Standard => self + .update_standard(k, v) + .map(|(keys, values)| CachedKv::Dense { keys, values }), + KVCacheMode::Rotating { max_size, keep } => self.update_rotating(k, v, max_size, keep), + KVCacheMode::Quantized { + group_size, + bits, + min_dense_tokens, + } => self.update_quantized(k, v, group_size, bits, min_dense_tokens), + } + } + + fn update_standard(&mut self, k: Array, v: Array) -> Result<(Array, Array)> { + use std::ops::RangeFull; + + let seq_len = k.shape()[2] as usize; + let prev = self.offset; + + if self.keys.is_none() || (prev + seq_len) > self.keys.as_ref().unwrap().shape()[2] as usize + { + // Grow: pre-allocate in steps, matching the incoming dtype + let [b, n_kv_heads, _, k_head_dim] = k.shape()[..4] else { + bail!("unexpected k shape"); + }; + let v_head_dim = v.shape()[3]; + let k_dtype = k.dtype(); + let v_dtype = v.dtype(); + + let n_steps = ((KV_CACHE_STEP + seq_len - 1) / KV_CACHE_STEP) * KV_CACHE_STEP; + let k_shape = &[b, n_kv_heads, n_steps as i32, k_head_dim]; + let v_shape = &[b, n_kv_heads, n_steps as i32, v_head_dim]; + + let new_k = mlx_rs::ops::zeros_dtype(k_shape, k_dtype)?; + let new_v = mlx_rs::ops::zeros_dtype(v_shape, v_dtype)?; + + if let (Some(ref mut old_k), Some(ref mut old_v)) = (&mut self.keys, &mut self.values) { + if prev % KV_CACHE_STEP != 0 { + *old_k = old_k.index((RangeFull, RangeFull, ..(prev as i32), RangeFull)); + *old_v = old_v.index((RangeFull, RangeFull, ..(prev as i32), RangeFull)); + } + self.keys = Some(mlx_rs::ops::concatenate_axis( + &[old_k as &Array, &new_k], + 2, + )?); + self.values = Some(mlx_rs::ops::concatenate_axis( + &[old_v as &Array, &new_v], + 2, + )?); + } else { + self.keys = Some(new_k); + self.values = Some(new_v); + } + } + + self.offset = prev + seq_len; + self.start_offset = 0; + let prev_i = prev as i32; + let end_i = self.offset as i32; + + // Slice-assign into pre-allocated buffer (no copy of existing data) + self.keys + .as_mut() + .unwrap() + .try_index_mut((RangeFull, RangeFull, prev_i..end_i, RangeFull), &k)?; + self.values + .as_mut() + .unwrap() + .try_index_mut((RangeFull, RangeFull, prev_i..end_i, RangeFull), &v)?; + + // Return views up to current offset + let k_out = self + .keys + .as_ref() + .unwrap() + .index((RangeFull, RangeFull, ..end_i, RangeFull)); + let v_out = self + .values + .as_ref() + .unwrap() + .index((RangeFull, RangeFull, ..end_i, RangeFull)); + + Ok((k_out, v_out)) + } + + fn update_rotating( + &mut self, + k: Array, + v: Array, + max_size: usize, + keep: usize, + ) -> Result<CachedKv> { + let seq_len = k.shape()[2] as usize; + if seq_len == 1 { + return self.update_rotating_in_place(k, v, max_size, keep); + } + self.update_rotating_concat(k, v, max_size, keep) + } + + fn update_rotating_concat( + &mut self, + k: Array, + v: Array, + max_size: usize, + keep: usize, + ) -> Result<CachedKv> { + let seq_len = k.shape()[2] as usize; + if self.keys.is_none() { + self.keys = Some(k); + self.values = Some(v); + } else { + let ordered_k = self.temporal_order(self.keys.as_ref().unwrap())?; + let ordered_v = self.temporal_order(self.values.as_ref().unwrap())?; + self.idx = ordered_k.shape()[2] as usize; + let current_len = ordered_k.shape()[2] as usize; + let trim_size = (current_len + seq_len).saturating_sub(max_size); + self.keys = Some(trim_rotating_cache(&ordered_k, trim_size, keep, Some(&k))?); + self.values = Some(trim_rotating_cache(&ordered_v, trim_size, keep, Some(&v))?); + self.start_offset += trim_size; + } + + self.offset += seq_len; + self.idx = self.keys.as_ref().unwrap().shape()[2] as usize; + let (keys, values) = self + .views() + .context("rotating KV cache was empty after concat update")?; + Ok(CachedKv::Dense { keys, values }) + } + + fn update_rotating_in_place( + &mut self, + k: Array, + v: Array, + max_size: usize, + keep: usize, + ) -> Result<CachedKv> { + use std::ops::RangeFull; + + let seq_len = k.shape()[2] as usize; + debug_assert_eq!(seq_len, 1); + let prev = self.offset; + + let current_capacity = self + .keys + .as_ref() + .map(|keys| keys.shape()[2] as usize) + .unwrap_or(0); + if self.keys.is_none() || (prev >= current_capacity && current_capacity < max_size) { + let [b, n_kv_heads, _, k_head_dim] = k.shape()[..4] else { + bail!("unexpected k shape"); + }; + let v_head_dim = v.shape()[3]; + let new_size = KV_CACHE_STEP + .min(max_size.saturating_sub(prev)) + .max(seq_len); + let new_k = + mlx_rs::ops::zeros_dtype(&[b, n_kv_heads, new_size as i32, k_head_dim], k.dtype())?; + let new_v = + mlx_rs::ops::zeros_dtype(&[b, n_kv_heads, new_size as i32, v_head_dim], v.dtype())?; + if let (Some(old_k), Some(old_v)) = (&self.keys, &self.values) { + self.keys = Some(mlx_rs::ops::concatenate_axis(&[old_k, &new_k], 2)?); + self.values = Some(mlx_rs::ops::concatenate_axis(&[old_v, &new_v], 2)?); + } else { + self.keys = Some(new_k); + self.values = Some(new_v); + } + self.idx = prev; + } + + let current_capacity = self.keys.as_ref().unwrap().shape()[2] as usize; + let trim_size = current_capacity.saturating_sub(max_size); + if trim_size > 0 { + let ordered_k = self.temporal_order(self.keys.as_ref().unwrap())?; + let ordered_v = self.temporal_order(self.values.as_ref().unwrap())?; + self.keys = Some(trim_rotating_cache(&ordered_k, trim_size, keep, None)?); + self.values = Some(trim_rotating_cache(&ordered_v, trim_size, keep, None)?); + self.idx = max_size; + self.start_offset += trim_size; + } + + let evicted = usize::from(self.current_len() == max_size); + if self.idx == max_size { + self.idx = keep; + } + + let start = self.idx as i32; + let end = (self.idx + seq_len) as i32; + self.keys + .as_mut() + .unwrap() + .try_index_mut((RangeFull, RangeFull, start..end, RangeFull), &k)?; + self.values + .as_mut() + .unwrap() + .try_index_mut((RangeFull, RangeFull, start..end, RangeFull), &v)?; + + self.offset += seq_len; + self.start_offset += evicted; + self.idx += seq_len; + let (keys, values) = self + .views() + .context("rotating KV cache was empty after in-place update")?; + Ok(CachedKv::Dense { keys, values }) + } + + fn quantize_dense_prefix( + &mut self, + group_size: i32, + bits: i32, + dense_len: usize, + ) -> Result<()> { + if dense_len == 0 { + return Ok(()); + } + + let dense_end = dense_len as i32; + let keys = self + .keys + .as_ref() + .context("missing dense keys while migrating to quantized KV")? + .index(( + std::ops::RangeFull, + std::ops::RangeFull, + ..dense_end, + std::ops::RangeFull, + )); + let values = self + .values + .as_ref() + .context("missing dense values while migrating to quantized KV")? + .index(( + std::ops::RangeFull, + std::ops::RangeFull, + ..dense_end, + std::ops::RangeFull, + )); + let [b, n_kv_heads, _, k_head_dim] = keys.shape()[..4] else { + bail!("unexpected dense key shape while migrating to quantized KV"); + }; + let v_head_dim = values.shape()[3]; + let el_per_int = 32 / bits; + + let (kq, ks, kb) = mlx_rs::ops::quantize(&keys, group_size, bits)?; + let (vq, vs, vb) = mlx_rs::ops::quantize(&values, group_size, bits)?; + self.qkeys = Some(QuantizedCacheArrays { + data: kq, + scales: ks, + biases: kb, + }); + self.qvalues = Some(QuantizedCacheArrays { + data: vq, + scales: vs, + biases: vb, + }); + + let qkeys = self.qkeys.as_mut().unwrap(); + if qkeys.data.shape()[2] != dense_end { + qkeys.data = + qkeys + .data + .reshape(&[b, n_kv_heads, dense_end, k_head_dim / el_per_int])?; + qkeys.scales = + qkeys + .scales + .reshape(&[b, n_kv_heads, dense_end, k_head_dim / group_size])?; + qkeys.biases = + qkeys + .biases + .reshape(&[b, n_kv_heads, dense_end, k_head_dim / group_size])?; + } + + let qvalues = self.qvalues.as_mut().unwrap(); + if qvalues.data.shape()[2] != dense_end { + qvalues.data = + qvalues + .data + .reshape(&[b, n_kv_heads, dense_end, v_head_dim / el_per_int])?; + qvalues.scales = + qvalues + .scales + .reshape(&[b, n_kv_heads, dense_end, v_head_dim / group_size])?; + qvalues.biases = + qvalues + .biases + .reshape(&[b, n_kv_heads, dense_end, v_head_dim / group_size])?; + } + + self.keys = None; + self.values = None; + Ok(()) + } + + fn update_quantized( + &mut self, + k: Array, + v: Array, + group_size: i32, + bits: i32, + min_dense_tokens: usize, + ) -> Result<CachedKv> { + let seq_len = k.shape()[2] as usize; + let prev = self.offset; + + if self.qkeys.is_none() && (prev + seq_len) <= min_dense_tokens { + return self + .update_standard(k, v) + .map(|(keys, values)| CachedKv::Dense { keys, values }); + } + + if self.qkeys.is_none() { + self.quantize_dense_prefix(group_size, bits, prev)?; + } + + if self.qkeys.is_none() + || (prev + seq_len) > self.qkeys.as_ref().unwrap().data.shape()[2] as usize + { + let [b, n_kv_heads, _, k_head_dim] = k.shape()[..4] else { + bail!("unexpected quantized k shape"); + }; + let v_head_dim = v.shape()[3]; + let el_per_int = 32 / bits; + let n_steps = ((KV_CACHE_STEP + seq_len - 1) / KV_CACHE_STEP) * KV_CACHE_STEP; + + let init_quant = |head_dim: i32, dtype: Dtype| -> Result<QuantizedCacheArrays> { + Ok(QuantizedCacheArrays { + data: mlx_rs::ops::zeros_dtype( + &[b, n_kv_heads, n_steps as i32, head_dim / el_per_int], + Dtype::Uint32, + )?, + scales: mlx_rs::ops::zeros_dtype( + &[b, n_kv_heads, n_steps as i32, head_dim / group_size], + dtype, + )?, + biases: mlx_rs::ops::zeros_dtype( + &[b, n_kv_heads, n_steps as i32, head_dim / group_size], + dtype, + )?, + }) + }; + + match (&self.qkeys, &self.qvalues) { + (Some(existing_k), Some(existing_v)) => { + let (mut existing_k, mut existing_v) = (existing_k, existing_v); + if prev % KV_CACHE_STEP != 0 { + let end = prev as i32; + self.qkeys = Some(existing_k.trim_to(end)?); + self.qvalues = Some(existing_v.trim_to(end)?); + existing_k = self.qkeys.as_ref().unwrap(); + existing_v = self.qvalues.as_ref().unwrap(); + } + self.qkeys = Some(existing_k.expand(n_steps as i32)?); + self.qvalues = Some(existing_v.expand(n_steps as i32)?); + } + _ => { + self.qkeys = Some(init_quant(k_head_dim, k.dtype())?); + self.qvalues = Some(init_quant(v_head_dim, v.dtype())?); + } + } + } + + self.offset = prev + seq_len; + self.start_offset = 0; + let prev_i = prev as i32; + let end_i = self.offset as i32; + + let (kq, ks, kb) = mlx_rs::ops::quantize(&k, group_size, bits)?; + let (vq, vs, vb) = mlx_rs::ops::quantize(&v, group_size, bits)?; + let qkeys = self.qkeys.as_mut().unwrap(); + let qvalues = self.qvalues.as_mut().unwrap(); + qkeys.data.try_index_mut( + ( + std::ops::RangeFull, + std::ops::RangeFull, + prev_i..end_i, + std::ops::RangeFull, + ), + &kq, + )?; + qkeys.scales.try_index_mut( + ( + std::ops::RangeFull, + std::ops::RangeFull, + prev_i..end_i, + std::ops::RangeFull, + ), + &ks, + )?; + qkeys.biases.try_index_mut( + ( + std::ops::RangeFull, + std::ops::RangeFull, + prev_i..end_i, + std::ops::RangeFull, + ), + &kb, + )?; + qvalues.data.try_index_mut( + ( + std::ops::RangeFull, + std::ops::RangeFull, + prev_i..end_i, + std::ops::RangeFull, + ), + &vq, + )?; + qvalues.scales.try_index_mut( + ( + std::ops::RangeFull, + std::ops::RangeFull, + prev_i..end_i, + std::ops::RangeFull, + ), + &vs, + )?; + qvalues.biases.try_index_mut( + ( + std::ops::RangeFull, + std::ops::RangeFull, + prev_i..end_i, + std::ops::RangeFull, + ), + &vb, + )?; + + Ok(CachedKv::Quantized { + keys: qkeys.prefix(end_i), + values: qvalues.prefix(end_i), + group_size, + bits, + }) + } + + /// Rewind the cache to `n` tokens if the requested prefix is still retained. + pub fn trim_to(&mut self, n: usize) -> Result<bool> { + if !self.can_trim_to(n) { + return Ok(false); + } + let retained_len = n.saturating_sub(self.start_offset) as i32; + match self.mode { + KVCacheMode::Standard => { + if n != self.offset { + if let Some(keys) = &self.keys { + self.keys = Some(materialize_cache_prefix(keys, retained_len)?); + } + if let Some(values) = &self.values { + self.values = Some(materialize_cache_prefix(values, retained_len)?); + } + } + } + KVCacheMode::Quantized { .. } => { + if n != self.offset { + if let Some(keys) = &self.qkeys { + self.qkeys = Some(keys.trim_to(retained_len)?); + } + if let Some(values) = &self.qvalues { + self.qvalues = Some(values.trim_to(retained_len)?); + } + } + } + KVCacheMode::Rotating { .. } => {} + } + if matches!(self.mode, KVCacheMode::Rotating { .. }) && n != self.offset { + if let (Some(keys), Some(values)) = (&self.keys, &self.values) { + self.keys = Some(self.temporal_order(keys)?); + self.values = Some(self.temporal_order(values)?); + } + } + self.offset = n; + if matches!(self.mode, KVCacheMode::Rotating { .. }) { + self.idx = self.current_len(); + } + Ok(true) + } +} + +pub(super) fn trim_rotating_cache( + array: &Array, + trim_size: usize, + keep: usize, + append: Option<&Array>, +) -> Result<Array> { + use std::ops::RangeFull; + + let mut parts = Vec::new(); + if trim_size > 0 { + if keep > 0 { + parts.push(array.index((RangeFull, RangeFull, ..(keep as i32), RangeFull))); + } + parts.push(array.index((RangeFull, RangeFull, (trim_size + keep) as i32.., RangeFull))); + } else { + parts.push(array.clone()); + } + if let Some(append) = append { + parts.push(append.clone()); + } + let refs: Vec<&Array> = parts.iter().collect(); + Ok(mlx_rs::ops::concatenate_axis(&refs, 2)?) +} diff --git a/mesh-llm/src/mlx/model/config.rs b/mesh-llm/src/mlx/model/config.rs new file mode 100644 index 00000000..e8a65cb5 --- /dev/null +++ b/mesh-llm/src/mlx/model/config.rs @@ -0,0 +1,428 @@ +use super::family::{model_architecture, ModelArchitecture}; +use anyhow::{Context, Result}; +use serde_json::Value; +use std::collections::HashMap; + +#[derive(Debug, serde::Deserialize)] +pub struct ModelConfig { + pub hidden_size: i32, + pub num_hidden_layers: i32, + #[allow(dead_code)] + #[serde(default)] + pub intermediate_size: i32, + pub num_attention_heads: i32, + pub num_key_value_heads: i32, + #[serde(default)] + pub head_dim: Option<i32>, + #[serde(default)] + pub query_pre_attn_scalar: Option<f32>, + #[serde(default)] + pub global_head_dim: Option<i32>, + pub vocab_size: i32, + #[serde(default)] + #[allow(dead_code)] + pub vocab_size_per_layer_input: Option<i32>, + #[serde(alias = "norm_eps")] + pub rms_norm_eps: f32, + #[serde(default = "default_rope_theta")] + pub rope_theta: f32, + #[serde(default)] + pub partial_rotary_factor: Option<f32>, + #[allow(dead_code)] + #[serde(alias = "model_max_length")] + pub max_position_embeddings: i32, + #[serde(default, deserialize_with = "deserialize_nullable_bool")] + pub tie_word_embeddings: bool, + #[serde(default, alias = "hidden_act")] + pub hidden_activation: Option<String>, + #[serde(default)] + pub hidden_size_per_layer_input: Option<i32>, + #[serde(default)] + pub moe_intermediate_size: Option<i32>, + #[serde(default, alias = "num_shared_experts")] + pub n_shared_experts: Option<i32>, + #[serde(default, alias = "num_experts")] + pub n_routed_experts: Option<i32>, + #[serde(default)] + pub routed_scaling_factor: Option<f32>, + #[serde(default)] + pub kv_lora_rank: Option<i32>, + #[serde(default)] + pub q_lora_rank: Option<i32>, + #[serde(default)] + pub qk_rope_head_dim: Option<i32>, + #[serde(default)] + pub v_head_dim: Option<i32>, + #[serde(default)] + pub qk_nope_head_dim: Option<i32>, + #[serde(default)] + #[allow(dead_code)] + pub topk_method: Option<String>, + #[serde(default, alias = "moe_renormalize")] + pub norm_topk_prob: Option<bool>, + #[serde(default, alias = "num_expert_group")] + pub n_group: Option<i32>, + #[serde(default)] + pub topk_group: Option<i32>, + #[serde(default, alias = "num_experts_per_token")] + pub num_experts_per_tok: Option<i32>, + #[serde(default)] + #[allow(dead_code)] + pub num_local_experts: Option<i32>, + #[serde(default)] + pub moe_layer_freq: Option<i32>, + #[serde(default)] + pub first_k_dense_replace: Option<i32>, + #[serde(default)] + #[allow(dead_code)] + pub attention_bias: Option<bool>, + #[serde(default)] + pub num_kv_shared_layers: Option<i32>, + #[serde(default)] + pub layer_types: Option<Vec<String>>, + #[serde(default, deserialize_with = "deserialize_rope_parameters")] + pub rope_parameters: Option<HashMap<String, RopeParameters>>, + #[serde(default)] + pub attn_logit_softcapping: Option<f32>, + #[serde(default)] + pub final_logit_softcapping: Option<f32>, + #[serde(default)] + #[allow(dead_code)] + pub sliding_window: Option<i32>, + #[serde(default)] + pub sliding_window_pattern: Option<i32>, + #[serde(default)] + #[allow(dead_code)] + pub cache_implementation: Option<String>, + #[serde(default)] + #[allow(dead_code)] + pub conv_bias: Option<bool>, + #[serde(default, alias = "conv_L_cache")] + pub conv_l_cache: Option<i32>, + #[serde(default)] + pub block_norm_eps: Option<f32>, + #[serde(default)] + #[allow(dead_code)] + pub block_dim: Option<i32>, + #[serde(default)] + #[allow(dead_code)] + pub block_ff_dim: Option<i32>, + #[serde(default)] + #[allow(dead_code)] + pub block_multiple_of: Option<i32>, + #[serde(default)] + #[allow(dead_code)] + pub block_ffn_dim_multiplier: Option<f32>, + #[serde(default)] + #[allow(dead_code)] + pub block_auto_adjust_ff_dim: Option<bool>, + #[serde(default)] + pub full_attn_idxs: Option<Vec<i32>>, + #[serde(default)] + pub linear_attn_config: Option<LinearAttnConfig>, + #[serde(default)] + #[allow(dead_code)] + pub moe_router_activation_func: Option<String>, + pub quantization: Option<super::QuantConfig>, + #[serde(default, deserialize_with = "deserialize_eos_token_id")] + pub eos_token_id: Vec<u32>, +} + +#[derive(Debug, serde::Deserialize, Clone)] +pub struct LinearAttnConfig { + #[allow(dead_code)] + pub full_attn_layers: Vec<i32>, + pub kda_layers: Vec<i32>, + pub num_heads: i32, + pub head_dim: i32, + #[serde(default)] + pub short_conv_kernel_size: Option<i32>, +} + +#[derive(Debug, serde::Deserialize, Clone)] +pub struct RopeParameters { + #[serde(default)] + pub partial_rotary_factor: Option<f32>, + #[serde(default)] + pub rope_theta: Option<f32>, +} + +pub(super) fn deserialize_nullable_bool<'de, D: serde::Deserializer<'de>>( + deserializer: D, +) -> std::result::Result<bool, D::Error> { + use serde::Deserialize; + Ok(Option::<bool>::deserialize(deserializer)?.unwrap_or(false)) +} + +pub(super) fn deserialize_eos_token_id<'de, D: serde::Deserializer<'de>>( + deserializer: D, +) -> std::result::Result<Vec<u32>, D::Error> { + use serde::Deserialize; + #[derive(Deserialize)] + #[serde(untagged)] + enum EosId { + Single(u32), + Multiple(Vec<u32>), + } + Ok(match EosId::deserialize(deserializer)? { + EosId::Single(id) => vec![id], + EosId::Multiple(ids) => ids, + }) +} + +pub(super) fn deserialize_rope_parameters<'de, D: serde::Deserializer<'de>>( + deserializer: D, +) -> std::result::Result<Option<HashMap<String, RopeParameters>>, D::Error> { + use serde::Deserialize; + #[derive(Deserialize)] + #[serde(untagged)] + enum RopeParametersField { + PerLayer(HashMap<String, RopeParameters>), + Flat(RopeParameters), + } + + Ok( + match Option::<RopeParametersField>::deserialize(deserializer)? { + None => None, + Some(RopeParametersField::PerLayer(map)) => Some(map), + Some(RopeParametersField::Flat(params)) => { + let mut map = HashMap::new(); + map.insert("default".to_string(), params); + Some(map) + } + }, + ) +} + +pub(super) fn default_rope_theta() -> f32 { + 10_000.0 +} + +pub(super) fn effective_text_config_json(config: &Value) -> Value { + let Some(text_config) = config + .get("text_config") + .and_then(|value| value.as_object()) + else { + return config.clone(); + }; + + let mut merged = serde_json::Map::new(); + for (key, value) in text_config { + merged.insert(key.clone(), value.clone()); + } + for key in [ + "quantization", + "eos_token_id", + "rope_theta", + "rms_norm_eps", + "head_dim", + "max_position_embeddings", + "tie_word_embeddings", + "hidden_activation", + "query_pre_attn_scalar", + "global_head_dim", + "vocab_size_per_layer_input", + "vocab_size", + "hidden_size_per_layer_input", + "moe_intermediate_size", + "n_shared_experts", + "n_routed_experts", + "routed_scaling_factor", + "kv_lora_rank", + "q_lora_rank", + "qk_rope_head_dim", + "v_head_dim", + "qk_nope_head_dim", + "topk_method", + "norm_topk_prob", + "n_group", + "topk_group", + "num_experts_per_tok", + "moe_layer_freq", + "first_k_dense_replace", + "attention_bias", + "num_kv_shared_layers", + "layer_types", + "rope_parameters", + "attn_logit_softcapping", + "final_logit_softcapping", + "sliding_window", + "sliding_window_pattern", + "cache_implementation", + "conv_bias", + "conv_L_cache", + "block_dim", + "block_ff_dim", + "block_multiple_of", + "block_ffn_dim_multiplier", + "block_auto_adjust_ff_dim", + "full_attn_idxs", + ] { + if !merged.contains_key(key) || merged.get(key).is_some_and(Value::is_null) { + if let Some(value) = config.get(key) { + merged.insert(key.to_string(), value.clone()); + } + } + } + if !merged.contains_key("architectures") { + if let Some(value) = config.get("architectures") { + merged.insert("architectures".to_string(), value.clone()); + } + } + + Value::Object(merged) +} + +pub(super) fn normalized_model_config_json(config: &Value) -> Value { + let mut normalized = effective_text_config_json(config); + let Some(object) = normalized.as_object_mut() else { + return normalized; + }; + + if !object.contains_key("hidden_activation") { + if let Some(value) = object.get("hidden_act").cloned() { + object.insert("hidden_activation".to_string(), value); + } + } + object.remove("hidden_act"); + + if model_architecture(config).is_gemma3() { + let sliding_window_pattern = object + .get("sliding_window_pattern") + .and_then(|value| value.as_i64()) + .unwrap_or(6); + if object.get("layer_types").is_none_or(Value::is_null) { + if let Some(num_hidden_layers) = object + .get("num_hidden_layers") + .and_then(|value| value.as_i64()) + { + let layer_types = (0..num_hidden_layers) + .map(|i| { + if (i + 1) % sliding_window_pattern != 0 { + Value::String("sliding_attention".to_string()) + } else { + Value::String("full_attention".to_string()) + } + }) + .collect::<Vec<_>>(); + object.insert("layer_types".to_string(), Value::Array(layer_types)); + } + } + + if object.get("rope_parameters").is_none_or(Value::is_null) { + let full_theta = object + .get("rope_theta") + .and_then(|value| value.as_f64()) + .unwrap_or(1_000_000.0); + let sliding_theta = object + .get("rope_local_base_freq") + .and_then(|value| value.as_f64()) + .unwrap_or(10_000.0); + object.insert( + "rope_parameters".to_string(), + serde_json::json!({ + "sliding_attention": { + "rope_type": "default", + "rope_theta": sliding_theta + }, + "full_attention": { + "rope_type": "default", + "rope_theta": full_theta + } + }), + ); + } + + if object + .get("use_bidirectional_attention") + .is_none_or(Value::is_null) + { + object.insert( + "use_bidirectional_attention".to_string(), + Value::Bool(false), + ); + } + } + + normalized +} + +pub(super) fn attention_window_size_for_layer( + arch: ModelArchitecture, + config: &ModelConfig, + layer_idx: usize, + layer_type: Option<&str>, +) -> Result<Option<i32>> { + if arch.is_gpt_oss() { + return if matches!(layer_type, Some("sliding_attention")) { + Ok(Some(config.sliding_window.context( + "missing sliding_window for gpt-oss sliding layer", + )?)) + } else { + Ok(None) + }; + } + + if arch.is_gemma3() { + let pattern = config.sliding_window_pattern.unwrap_or(1); + return if pattern > 1 && (layer_idx as i32 % pattern) != (pattern - 1) { + Ok(Some(config.sliding_window.context( + "missing sliding_window for gemma3 sliding layer", + )?)) + } else { + Ok(None) + }; + } + + Ok(None) +} + +pub(super) fn kv_shared_source_for_layer( + arch: ModelArchitecture, + config: &ModelConfig, + layer_idx: usize, + layer_type: Option<&str>, + non_shared_layer_types: Option<&[String]>, +) -> Option<usize> { + if !arch.is_gemma4() { + return None; + } + + let first_kv_shared_layer_idx = config + .num_kv_shared_layers + .map(|n| (config.num_hidden_layers - n) as usize) + .unwrap_or(config.num_hidden_layers as usize); + + if layer_idx < first_kv_shared_layer_idx { + return None; + } + + non_shared_layer_types.and_then(|types| { + layer_type.and_then(|current| { + types + .iter() + .rposition(|candidate| candidate == current) + .map(|index| index) + }) + }) +} + +pub(super) fn experimental_quantized_kv_config() -> Option<(i32, i32, usize)> { + let bits = std::env::var("MESH_LLM_MLX_QUANTIZED_KV_BITS") + .ok()? + .parse::<i32>() + .ok()?; + if bits <= 0 { + return None; + } + let group_size = std::env::var("MESH_LLM_MLX_QUANTIZED_KV_GROUP_SIZE") + .ok() + .and_then(|value| value.parse::<i32>().ok()) + .filter(|value| *value > 0) + .unwrap_or(64); + let min_dense_tokens = std::env::var("MESH_LLM_MLX_QUANTIZED_KV_MIN_TOKENS") + .ok() + .and_then(|value| value.parse::<usize>().ok()) + .unwrap_or(256); + Some((group_size, bits, min_dense_tokens)) +} diff --git a/mesh-llm/src/mlx/model/embedding.rs b/mesh-llm/src/mlx/model/embedding.rs new file mode 100644 index 00000000..c9cc47f8 --- /dev/null +++ b/mesh-llm/src/mlx/model/embedding.rs @@ -0,0 +1,67 @@ +use super::*; +use serde_json::Value; + +pub(super) struct QuantizedEmbedding { + pub(super) weight: Array, + pub(super) scales: Array, + pub(super) biases: Array, + pub(super) group_size: i32, + pub(super) bits: i32, + pub(super) dense_weight: Option<Array>, + pub(super) dense_weight_t: Option<Array>, +} + +impl QuantizedEmbedding { + pub(super) fn forward(&self, indices: &Array) -> Result<Array> { + if let Some(dense_weight) = &self.dense_weight { + return Ok(dense_weight.take_axis(indices, 0)?); + } + let w = self.weight.take_axis(indices, 0)?; + let s = self.scales.take_axis(indices, 0)?; + let b = self.biases.take_axis(indices, 0)?; + Ok(mlx_rs::ops::dequantize( + &w, + &s, + &b, + self.group_size, + self.bits, + )?) + } + + pub(super) fn as_linear(&self) -> QuantizedLinear { + QuantizedLinear { + weight: self.weight.clone(), + scales: self.scales.clone(), + biases: self.biases.clone(), + bias: None, + group_size: self.group_size, + bits: self.bits, + dense_weight_t: self.dense_weight_t.clone(), + } + } +} + +pub(super) fn quant_params_for( + config: &Value, + prefix: &str, + default_group_size: i32, + default_bits: i32, +) -> (i32, i32) { + let override_cfg = config + .get("quantization") + .and_then(Value::as_object) + .and_then(|q| q.get(prefix)) + .cloned() + .and_then(|value| serde_json::from_value::<QuantOverride>(value).ok()); + + ( + override_cfg + .as_ref() + .and_then(|cfg| cfg.group_size) + .unwrap_or(default_group_size), + override_cfg + .as_ref() + .and_then(|cfg| cfg.bits) + .unwrap_or(default_bits), + ) +} diff --git a/mesh-llm/src/mlx/model/families/deepseek_v3.rs b/mesh-llm/src/mlx/model/families/deepseek_v3.rs new file mode 100644 index 00000000..a2806ae4 --- /dev/null +++ b/mesh-llm/src/mlx/model/families/deepseek_v3.rs @@ -0,0 +1,222 @@ +use super::super::layer::Layer; +use super::super::{ + quant_params_for, quantize_stacked_weights, rms_norm_kind, DeepseekV3Attention, DeepseekV3MoE, + MlpKind, ModelConfig, QuantizedLinear, QuantizedMultiLinear, QuantizedSwitchLinear, RMSNorm, + TensorPrefixes, MLP, +}; +use anyhow::{Context, Result}; +use mlx_rs::ops::dequantize; +use mlx_rs::ops::indexing::IndexOp; +use mlx_rs::Array; +use serde_json::Value; +use std::collections::HashMap; + +pub(crate) fn transform_deepseek_v3_tensors( + tensors: &mut HashMap<String, Array>, + prefixes: &TensorPrefixes, + config: &ModelConfig, + config_json: &Value, + default_group_size: i32, + default_bits: i32, +) -> Result<()> { + let num_heads = config.num_attention_heads; + let qk_nope_head_dim = config + .qk_nope_head_dim + .context("missing qk_nope_head_dim for DeepSeekV3")?; + let v_head_dim = config + .v_head_dim + .context("missing v_head_dim for DeepSeekV3")?; + let kv_lora_rank = config + .kv_lora_rank + .context("missing kv_lora_rank for DeepSeekV3")?; + + for i in 0..config.num_hidden_layers { + let prefix = format!("{}.layers.{i}.self_attn", prefixes.model); + if !tensors.contains_key(&format!("{prefix}.kv_b_proj.weight")) + || tensors.contains_key(&format!("{prefix}.embed_q.weight")) + { + continue; + } + + let (group_size, bits) = quant_params_for( + config_json, + &format!("{prefix}.kv_b_proj"), + default_group_size, + default_bits, + ); + let weight = tensors + .get(&format!("{prefix}.kv_b_proj.weight")) + .cloned() + .with_context(|| format!("missing {prefix}.kv_b_proj.weight"))?; + let scales = tensors + .get(&format!("{prefix}.kv_b_proj.scales")) + .cloned() + .with_context(|| format!("missing {prefix}.kv_b_proj.scales"))?; + let biases = tensors + .get(&format!("{prefix}.kv_b_proj.biases")) + .cloned() + .with_context(|| format!("missing {prefix}.kv_b_proj.biases"))?; + let dense = dequantize(&weight, &scales, &biases, group_size, bits)?; + let dense = dense.reshape(&[num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank])?; + let wk = dense + .index((std::ops::RangeFull, ..qk_nope_head_dim, std::ops::RangeFull)) + .transpose_axes(&[0, 2, 1])?; + let wv = dense.index((std::ops::RangeFull, qk_nope_head_dim.., std::ops::RangeFull)); + let (wk_q, wk_s, wk_b) = quantize_stacked_weights(&wk, group_size, bits)?; + let (wv_q, wv_s, wv_b) = quantize_stacked_weights(&wv, group_size, bits)?; + tensors.insert(format!("{prefix}.embed_q.weight"), wk_q); + tensors.insert(format!("{prefix}.embed_q.scales"), wk_s); + tensors.insert(format!("{prefix}.embed_q.biases"), wk_b); + tensors.insert(format!("{prefix}.unembed_out.weight"), wv_q); + tensors.insert(format!("{prefix}.unembed_out.scales"), wv_s); + tensors.insert(format!("{prefix}.unembed_out.biases"), wv_b); + } + + Ok(()) +} + +pub(crate) fn build_deepseek_v3_layer<FQ, FM, FS>( + tensors: &HashMap<String, Array>, + p: &str, + layer_index: i32, + config: &ModelConfig, + load_qlinear: &FQ, + load_multi_linear: &FM, + load_switch_linear: &FS, +) -> Result<Layer> +where + FQ: Fn(&str) -> Result<QuantizedLinear>, + FM: Fn(&str) -> Result<QuantizedMultiLinear>, + FS: Fn(&str) -> Result<QuantizedSwitchLinear>, +{ + let qk_nope_head_dim = config + .qk_nope_head_dim + .context("missing qk_nope_head_dim for DeepSeekV3")?; + let qk_rope_head_dim = config + .qk_rope_head_dim + .context("missing qk_rope_head_dim for DeepSeekV3")?; + let kv_lora_rank = config + .kv_lora_rank + .context("missing kv_lora_rank for DeepSeekV3")?; + let v_head_dim = config + .v_head_dim + .context("missing v_head_dim for DeepSeekV3")?; + let q_head_dim = qk_nope_head_dim + qk_rope_head_dim; + let is_moe_layer = config.n_routed_experts.is_some() + && (layer_index >= config.first_k_dense_replace.unwrap_or(0)) + && (layer_index % config.moe_layer_freq.unwrap_or(1) == 0); + let shared_intermediate = config + .n_shared_experts + .zip(config.moe_intermediate_size) + .map(|(n_shared, hidden)| n_shared * hidden); + let mlp_kind = if is_moe_layer { + MlpKind::DeepseekV3MoE(DeepseekV3MoE { + switch_gate_proj: load_switch_linear(&format!("{p}.mlp.switch_mlp.gate_proj"))?, + switch_up_proj: load_switch_linear(&format!("{p}.mlp.switch_mlp.up_proj"))?, + switch_down_proj: load_switch_linear(&format!("{p}.mlp.switch_mlp.down_proj"))?, + gate_weight: tensors + .get(&format!("{p}.mlp.gate.weight")) + .cloned() + .with_context(|| format!("missing {p}.mlp.gate.weight"))?, + gate_bias: tensors + .get(&format!("{p}.mlp.gate.e_score_correction_bias")) + .cloned() + .with_context(|| format!("missing {p}.mlp.gate.e_score_correction_bias"))?, + top_k: config.num_experts_per_tok.unwrap_or(1), + n_group: config.n_group.unwrap_or(1), + topk_group: config.topk_group.unwrap_or(1), + routed_scaling_factor: config.routed_scaling_factor.unwrap_or(1.0), + norm_topk_prob: config.norm_topk_prob.unwrap_or(true), + shared_experts: shared_intermediate + .map(|_intermediate_size| -> Result<MLP> { + Ok(MLP { + gate_up_proj: None, + gate_proj: Some(load_qlinear(&format!( + "{p}.mlp.shared_experts.gate_proj" + ))?), + up_proj: Some(load_qlinear(&format!("{p}.mlp.shared_experts.up_proj"))?), + down_proj: load_qlinear(&format!("{p}.mlp.shared_experts.down_proj"))?, + activation: super::super::mlp::Activation::Silu, + }) + }) + .transpose()?, + }) + } else { + MlpKind::Dense(MLP { + gate_up_proj: None, + gate_proj: Some(load_qlinear(&format!("{p}.mlp.gate_proj"))?), + up_proj: Some(load_qlinear(&format!("{p}.mlp.up_proj"))?), + down_proj: load_qlinear(&format!("{p}.mlp.down_proj"))?, + activation: super::super::mlp::Activation::Silu, + }) + }; + + Ok(Layer { + attn: super::super::AttentionKind::DeepseekV3(DeepseekV3Attention { + q_proj: if config.q_lora_rank.is_some() { + None + } else { + Some(load_qlinear(&format!("{p}.self_attn.q_proj"))?) + }, + q_a_proj: config + .q_lora_rank + .is_some() + .then(|| load_qlinear(&format!("{p}.self_attn.q_a_proj"))) + .transpose()?, + q_a_layernorm: tensors + .get(&format!("{p}.self_attn.q_a_layernorm.weight")) + .cloned() + .map(|weight| RMSNorm { + weight, + eps: 1e-6, + add_unit_offset: false, + }), + q_b_proj: config + .q_lora_rank + .is_some() + .then(|| load_qlinear(&format!("{p}.self_attn.q_b_proj"))) + .transpose()?, + kv_a_proj_with_mqa: load_qlinear(&format!("{p}.self_attn.kv_a_proj_with_mqa"))?, + kv_a_layernorm: RMSNorm { + weight: tensors + .get(&format!("{p}.self_attn.kv_a_layernorm.weight")) + .cloned() + .with_context(|| format!("missing {p}.self_attn.kv_a_layernorm.weight"))?, + eps: 1e-6, + add_unit_offset: false, + }, + embed_q: load_multi_linear(&format!("{p}.self_attn.embed_q"))?, + unembed_out: load_multi_linear(&format!("{p}.self_attn.unembed_out"))?, + o_proj: load_qlinear(&format!("{p}.self_attn.o_proj"))?, + num_heads: config.num_attention_heads, + q_head_dim, + qk_rope_head_dim, + qk_nope_head_dim, + kv_lora_rank, + v_head_dim, + scale: 1.0 / (q_head_dim as f32).sqrt(), + rope_theta: config.rope_theta, + }), + mlp: mlp_kind, + attn_in_norm: Some(rms_norm_kind( + tensors + .get(&format!("{p}.input_layernorm.weight")) + .cloned() + .with_context(|| format!("missing {p}.input_layernorm.weight"))?, + config.rms_norm_eps, + false, + )), + attn_out_norm: None, + mlp_in_norm: Some(rms_norm_kind( + tensors + .get(&format!("{p}.post_attention_layernorm.weight")) + .cloned() + .with_context(|| format!("missing {p}.post_attention_layernorm.weight"))?, + config.rms_norm_eps, + false, + )), + mlp_out_norm: None, + per_layer_input: None, + layer_scalar: None, + }) +} diff --git a/mesh-llm/src/mlx/model/families/gemma3.rs b/mesh-llm/src/mlx/model/families/gemma3.rs new file mode 100644 index 00000000..9bee0f2c --- /dev/null +++ b/mesh-llm/src/mlx/model/families/gemma3.rs @@ -0,0 +1,23 @@ +use super::super::{ModelConfig, TensorPrefixes}; +use anyhow::Result; +use mlx_rs::Array; +use std::collections::HashMap; + +pub(crate) fn transform_gemma3_tensors( + tensors: &mut HashMap<String, Array>, + _prefixes: &TensorPrefixes, + config: &ModelConfig, +) -> Result<()> { + tensors.retain(|key, _| { + !key.starts_with("vision_tower.") && !key.starts_with("multi_modal_projector.") + }); + + if config.tie_word_embeddings { + tensors.remove("language_model.lm_head.weight"); + tensors.remove("language_model.lm_head.scales"); + tensors.remove("language_model.lm_head.biases"); + tensors.remove("language_model.lm_head.bias"); + } + + Ok(()) +} diff --git a/mesh-llm/src/mlx/model/families/gemma4.rs b/mesh-llm/src/mlx/model/families/gemma4.rs new file mode 100644 index 00000000..f5af736b --- /dev/null +++ b/mesh-llm/src/mlx/model/families/gemma4.rs @@ -0,0 +1,42 @@ +use super::super::{ModelConfig, TensorPrefixes}; +use anyhow::Result; +use mlx_rs::Array; +use std::collections::HashMap; + +pub(crate) fn transform_gemma4_tensors( + tensors: &mut HashMap<String, Array>, + _prefixes: &TensorPrefixes, + config: &ModelConfig, +) -> Result<()> { + let mut normalized = HashMap::with_capacity(tensors.len()); + for (key, value) in tensors.drain() { + let starts_w_model = key.starts_with("model."); + let key = key.strip_prefix("model.").unwrap_or(&key).to_string(); + + if key.starts_with("vision_tower.") + || key.starts_with("multi_modal_projector.") + || key.starts_with("audio_tower.") + || key.starts_with("embed_audio.") + || key.starts_with("embed_vision.") + { + continue; + } + + let normalized_key = if starts_w_model && key.starts_with("language_model.") { + key.replacen("language_model.", "language_model.model.", 1) + } else { + key + }; + normalized.insert(normalized_key, value); + } + *tensors = normalized; + + if config.tie_word_embeddings { + tensors.remove("lm_head.weight"); + tensors.remove("lm_head.scales"); + tensors.remove("lm_head.biases"); + tensors.remove("lm_head.bias"); + } + + Ok(()) +} diff --git a/mesh-llm/src/mlx/model/families/gpt_oss.rs b/mesh-llm/src/mlx/model/families/gpt_oss.rs new file mode 100644 index 00000000..0301422e --- /dev/null +++ b/mesh-llm/src/mlx/model/families/gpt_oss.rs @@ -0,0 +1,156 @@ +use super::super::layer::Layer; +use super::super::{ + rms_norm_kind, Attention, AttentionKind, GptOssMoE, MlpKind, ModelConfig, QuantizedLinear, + QuantizedSwitchLinear, TensorPrefixes, +}; +use anyhow::{Context, Result}; +use mlx_rs::Array; +use std::collections::HashMap; + +fn split_even_odd_axis(tensor: &Array, axis: i32) -> Result<(Array, Array)> { + let shape = tensor.shape(); + let ndim = shape.len() as i32; + let axis = if axis < 0 { ndim + axis } else { axis }; + if axis < 0 || axis >= ndim { + anyhow::bail!("axis {axis} out of bounds for GPT-OSS tensor shape {shape:?}"); + } + let axis_len = shape[axis as usize]; + let even_idx: Vec<u32> = (0..axis_len).step_by(2).map(|idx| idx as u32).collect(); + let odd_idx: Vec<u32> = (1..axis_len).step_by(2).map(|idx| idx as u32).collect(); + Ok(( + tensor.take_axis( + &Array::from_slice(&even_idx, &[even_idx.len() as i32]), + axis, + )?, + tensor.take_axis(&Array::from_slice(&odd_idx, &[odd_idx.len() as i32]), axis)?, + )) +} + +fn split_gate_up_proj(prefix: &str, tensors: &mut HashMap<String, Array>) -> Result<()> { + if tensors.contains_key(&format!("{prefix}.gate_proj.weight")) { + return Ok(()); + } + + for suffix in ["weight", "scales", "biases"] { + let key = format!("{prefix}.gate_up_proj.{suffix}"); + if let Some(fused) = tensors.get(&key).cloned() { + let (gate, up) = split_even_odd_axis(&fused, -2)?; + tensors.insert(format!("{prefix}.gate_proj.{suffix}"), gate); + tensors.insert(format!("{prefix}.up_proj.{suffix}"), up); + } + } + + let bias_key = format!("{prefix}.gate_up_proj_bias"); + if let Some(fused_bias) = tensors.get(&bias_key).cloned() { + let (gate_bias, up_bias) = split_even_odd_axis(&fused_bias, -1)?; + tensors.insert(format!("{prefix}.gate_proj.bias"), gate_bias); + tensors.insert(format!("{prefix}.up_proj.bias"), up_bias); + } + + Ok(()) +} + +fn normalize_down_proj_bias(prefix: &str, tensors: &mut HashMap<String, Array>) { + let legacy_key = format!("{prefix}.down_proj_bias"); + let normalized_key = format!("{prefix}.down_proj.bias"); + if let Some(bias) = tensors.get(&legacy_key).cloned() { + tensors.entry(normalized_key).or_insert(bias); + } +} + +fn normalize_expert_quant_bias(prefix: &str, tensors: &mut HashMap<String, Array>) { + let legacy_key = format!("{prefix}.bias"); + let normalized_key = format!("{prefix}.biases"); + if tensors.contains_key(&normalized_key) { + return; + } + if let Some(biases) = tensors.remove(&legacy_key) { + tensors.insert(normalized_key, biases); + } +} + +pub(crate) fn transform_gpt_oss_tensors( + tensors: &mut HashMap<String, Array>, + prefixes: &TensorPrefixes, + config: &ModelConfig, +) -> Result<()> { + for i in 0..config.num_hidden_layers { + let mlp_prefix = format!("{}.layers.{i}.mlp.experts", prefixes.model); + if tensors + .keys() + .any(|key| key.starts_with(&format!("{mlp_prefix}.gate_up_proj"))) + { + split_gate_up_proj(&mlp_prefix, tensors) + .with_context(|| format!("failed to sanitize GPT-OSS tensors for {mlp_prefix}"))?; + } + normalize_down_proj_bias(&mlp_prefix, tensors); + normalize_expert_quant_bias(&format!("{mlp_prefix}.gate_proj"), tensors); + normalize_expert_quant_bias(&format!("{mlp_prefix}.up_proj"), tensors); + normalize_expert_quant_bias(&format!("{mlp_prefix}.down_proj"), tensors); + } + + Ok(()) +} + +pub(crate) fn build_gpt_oss_layer<FQ, FS>( + tensors: &HashMap<String, Array>, + p: &str, + config: &ModelConfig, + head_dim: i32, + window_size: Option<i32>, + load_qlinear: &FQ, + load_switch_linear: &FS, +) -> Result<Layer> +where + FQ: Fn(&str) -> Result<QuantizedLinear>, + FS: Fn(&str) -> Result<QuantizedSwitchLinear>, +{ + Ok(Layer { + attn: AttentionKind::Standard(Attention { + q_proj: load_qlinear(&format!("{p}.self_attn.q_proj"))?, + k_proj: load_qlinear(&format!("{p}.self_attn.k_proj"))?, + v_proj: load_qlinear(&format!("{p}.self_attn.v_proj"))?, + o_proj: load_qlinear(&format!("{p}.self_attn.o_proj"))?, + q_norm: None, + k_norm: None, + v_norm: None, + num_heads: config.num_attention_heads, + num_kv_heads: config.num_key_value_heads, + head_dim, + scale: 1.0 / (head_dim as f32).sqrt(), + attn_logit_softcapping: None, + rope_dim: head_dim, + rope_theta: config.rope_theta, + rope_traditional: false, + window_size, + kv_shared_source: None, + }), + mlp: MlpKind::GptOssMoE(GptOssMoE { + switch_gate_proj: load_switch_linear(&format!("{p}.mlp.experts.gate_proj"))?, + switch_up_proj: load_switch_linear(&format!("{p}.mlp.experts.up_proj"))?, + switch_down_proj: load_switch_linear(&format!("{p}.mlp.experts.down_proj"))?, + router: load_qlinear(&format!("{p}.mlp.router"))?, + top_k: config.num_experts_per_tok.unwrap_or(1), + }), + attn_in_norm: Some(rms_norm_kind( + tensors + .get(&format!("{p}.input_layernorm.weight")) + .cloned() + .with_context(|| format!("missing {p}.input_layernorm.weight"))?, + config.rms_norm_eps, + false, + )), + attn_out_norm: None, + mlp_in_norm: Some(rms_norm_kind( + tensors + .get(&format!("{p}.post_attention_layernorm.weight")) + .cloned() + .with_context(|| format!("missing {p}.post_attention_layernorm.weight"))?, + config.rms_norm_eps, + false, + )), + mlp_out_norm: None, + per_layer_input: None, + layer_scalar: None, + }) +} diff --git a/mesh-llm/src/mlx/model/families/kimi.rs b/mesh-llm/src/mlx/model/families/kimi.rs new file mode 100644 index 00000000..9bf58923 --- /dev/null +++ b/mesh-llm/src/mlx/model/families/kimi.rs @@ -0,0 +1,186 @@ +use super::super::layer::Layer; +use super::super::{ + rms_norm_kind, AttentionKind, DeepseekV3MoE, KimiDeltaAttention, KimiMlaAttention, + KimiShortConv, MlpKind, ModelConfig, QuantizedLinear, QuantizedMultiLinear, + QuantizedSwitchLinear, RMSNorm, MLP, +}; +use anyhow::{Context, Result}; +use mlx_rs::Array; +use std::collections::HashMap; + +pub(crate) fn build_kimi_linear_layer<FQ, FM, FS>( + tensors: &HashMap<String, Array>, + p: &str, + layer_index: i32, + config: &ModelConfig, + load_qlinear: &FQ, + load_multi_linear: &FM, + load_switch_linear: &FS, + load_conv_weight: &impl Fn(&str) -> Result<Array>, +) -> Result<Layer> +where + FQ: Fn(&str) -> Result<QuantizedLinear>, + FM: Fn(&str) -> Result<QuantizedMultiLinear>, + FS: Fn(&str) -> Result<QuantizedSwitchLinear>, +{ + let linear_cfg = config + .linear_attn_config + .as_ref() + .context("missing linear_attn_config for Kimi Linear")?; + let is_linear_layer = linear_cfg.kda_layers.contains(&(layer_index + 1)); + let projection_dim = linear_cfg.num_heads * linear_cfg.head_dim; + let is_moe_layer = config.n_routed_experts.unwrap_or(0) > 0 + && (layer_index >= config.first_k_dense_replace.unwrap_or(0)) + && (layer_index % config.moe_layer_freq.unwrap_or(1) == 0); + let mlp = if is_moe_layer { + MlpKind::DeepseekV3MoE(DeepseekV3MoE { + switch_gate_proj: load_switch_linear(&format!("{p}.mlp.switch_mlp.gate_proj"))?, + switch_up_proj: load_switch_linear(&format!("{p}.mlp.switch_mlp.up_proj"))?, + switch_down_proj: load_switch_linear(&format!("{p}.mlp.switch_mlp.down_proj"))?, + gate_weight: tensors + .get(&format!("{p}.mlp.gate.weight")) + .cloned() + .with_context(|| format!("missing {p}.mlp.gate.weight"))?, + gate_bias: tensors + .get(&format!("{p}.mlp.e_score_correction_bias")) + .cloned() + .with_context(|| format!("missing {p}.mlp.e_score_correction_bias"))?, + top_k: config.num_experts_per_tok.unwrap_or(1), + n_group: config.n_group.unwrap_or(1), + topk_group: config.topk_group.unwrap_or(1), + routed_scaling_factor: config.routed_scaling_factor.unwrap_or(1.0), + norm_topk_prob: config.norm_topk_prob.unwrap_or(true), + shared_experts: config + .n_shared_experts + .filter(|n| *n > 0) + .map(|_| -> Result<MLP> { + Ok(MLP { + gate_up_proj: None, + gate_proj: Some(load_qlinear(&format!( + "{p}.mlp.shared_experts.gate_proj" + ))?), + up_proj: Some(load_qlinear(&format!("{p}.mlp.shared_experts.up_proj"))?), + down_proj: load_qlinear(&format!("{p}.mlp.shared_experts.down_proj"))?, + activation: super::super::mlp::Activation::Silu, + }) + }) + .transpose()?, + }) + } else { + MlpKind::Dense(MLP { + gate_up_proj: None, + gate_proj: Some(load_qlinear(&format!("{p}.mlp.gate_proj"))?), + up_proj: Some(load_qlinear(&format!("{p}.mlp.up_proj"))?), + down_proj: load_qlinear(&format!("{p}.mlp.down_proj"))?, + activation: super::super::mlp::Activation::Silu, + }) + }; + + let attn = if is_linear_layer { + AttentionKind::KimiDelta(KimiDeltaAttention { + q_proj: load_qlinear(&format!("{p}.self_attn.q_proj"))?, + k_proj: load_qlinear(&format!("{p}.self_attn.k_proj"))?, + v_proj: load_qlinear(&format!("{p}.self_attn.v_proj"))?, + q_conv: KimiShortConv { + conv_weight: load_conv_weight(&format!("{p}.self_attn.q_conv.conv"))?, + kernel_size: linear_cfg.short_conv_kernel_size.unwrap_or(4), + channels: projection_dim, + }, + k_conv: KimiShortConv { + conv_weight: load_conv_weight(&format!("{p}.self_attn.k_conv.conv"))?, + kernel_size: linear_cfg.short_conv_kernel_size.unwrap_or(4), + channels: projection_dim, + }, + v_conv: KimiShortConv { + conv_weight: load_conv_weight(&format!("{p}.self_attn.v_conv.conv"))?, + kernel_size: linear_cfg.short_conv_kernel_size.unwrap_or(4), + channels: projection_dim, + }, + f_a_proj: load_qlinear(&format!("{p}.self_attn.f_a_proj"))?, + f_b_proj: load_qlinear(&format!("{p}.self_attn.f_b_proj"))?, + b_proj: load_qlinear(&format!("{p}.self_attn.b_proj"))?, + g_a_proj: load_qlinear(&format!("{p}.self_attn.g_a_proj"))?, + g_b_proj: load_qlinear(&format!("{p}.self_attn.g_b_proj"))?, + a_log: tensors + .get(&format!("{p}.self_attn.A_log")) + .cloned() + .with_context(|| format!("missing {p}.self_attn.A_log"))?, + dt_bias: tensors + .get(&format!("{p}.self_attn.dt_bias")) + .cloned() + .with_context(|| format!("missing {p}.self_attn.dt_bias"))?, + o_norm: RMSNorm { + weight: tensors + .get(&format!("{p}.self_attn.o_norm.weight")) + .cloned() + .with_context(|| format!("missing {p}.self_attn.o_norm.weight"))?, + eps: config.rms_norm_eps, + add_unit_offset: false, + }, + o_proj: load_qlinear(&format!("{p}.self_attn.o_proj"))?, + num_heads: linear_cfg.num_heads, + head_dim: linear_cfg.head_dim, + scale: (linear_cfg.head_dim as f32).powf(-0.5), + }) + } else { + let qk_nope_head_dim = config + .qk_nope_head_dim + .context("missing qk_nope_head_dim for Kimi Linear MLA")?; + let qk_rope_head_dim = config + .qk_rope_head_dim + .context("missing qk_rope_head_dim for Kimi Linear MLA")?; + let kv_lora_rank = config + .kv_lora_rank + .context("missing kv_lora_rank for Kimi Linear MLA")?; + let v_head_dim = config + .v_head_dim + .context("missing v_head_dim for Kimi Linear MLA")?; + AttentionKind::KimiMla(KimiMlaAttention { + q_proj: load_qlinear(&format!("{p}.self_attn.q_proj"))?, + kv_a_proj_with_mqa: load_qlinear(&format!("{p}.self_attn.kv_a_proj_with_mqa"))?, + kv_a_layernorm: RMSNorm { + weight: tensors + .get(&format!("{p}.self_attn.kv_a_layernorm.weight")) + .cloned() + .with_context(|| format!("missing {p}.self_attn.kv_a_layernorm.weight"))?, + eps: config.rms_norm_eps, + add_unit_offset: false, + }, + embed_q: load_multi_linear(&format!("{p}.self_attn.embed_q"))?, + unembed_out: load_multi_linear(&format!("{p}.self_attn.unembed_out"))?, + o_proj: load_qlinear(&format!("{p}.self_attn.o_proj"))?, + num_heads: config.num_attention_heads, + q_head_dim: qk_nope_head_dim + qk_rope_head_dim, + qk_rope_head_dim, + qk_nope_head_dim, + kv_lora_rank, + v_head_dim, + scale: 1.0 / ((qk_nope_head_dim + qk_rope_head_dim) as f32).sqrt(), + }) + }; + + Ok(Layer { + attn, + mlp, + attn_in_norm: Some(rms_norm_kind( + tensors + .get(&format!("{p}.input_layernorm.weight")) + .cloned() + .with_context(|| format!("missing {p}.input_layernorm.weight"))?, + config.rms_norm_eps, + false, + )), + attn_out_norm: None, + mlp_in_norm: Some(rms_norm_kind( + tensors + .get(&format!("{p}.post_attention_layernorm.weight")) + .cloned() + .with_context(|| format!("missing {p}.post_attention_layernorm.weight"))?, + config.rms_norm_eps, + false, + )), + mlp_out_norm: None, + per_layer_input: None, + layer_scalar: None, + }) +} diff --git a/mesh-llm/src/mlx/model/families/lfm2.rs b/mesh-llm/src/mlx/model/families/lfm2.rs new file mode 100644 index 00000000..3514cbe3 --- /dev/null +++ b/mesh-llm/src/mlx/model/families/lfm2.rs @@ -0,0 +1,101 @@ +use super::super::layer::Layer; +use super::super::{ + rms_norm_kind, Activation, Attention, AttentionKind, Lfm2ShortConv, MlpKind, ModelConfig, + QuantizedLinear, RMSNorm, MLP, +}; +use anyhow::{Context, Result}; +use mlx_rs::Array; +use std::collections::HashMap; + +pub(crate) fn build_lfm2_layer<FQ>( + tensors: &HashMap<String, Array>, + p: &str, + layer_index: i32, + config: &ModelConfig, + head_dim: i32, + load_qlinear: &FQ, + load_conv_weight: &impl Fn(&str) -> Result<Array>, +) -> Result<Layer> +where + FQ: Fn(&str) -> Result<QuantizedLinear>, +{ + let full_attn_idxs = config + .full_attn_idxs + .as_ref() + .with_context(|| format!("missing full_attn_idxs for LFM2 layer {}", layer_index))?; + let is_attention_layer = full_attn_idxs.contains(&layer_index); + let operator = if is_attention_layer { + AttentionKind::Standard(Attention { + q_proj: load_qlinear(&format!("{p}.self_attn.q_proj"))?, + k_proj: load_qlinear(&format!("{p}.self_attn.k_proj"))?, + v_proj: load_qlinear(&format!("{p}.self_attn.v_proj"))?, + o_proj: load_qlinear(&format!("{p}.self_attn.out_proj"))?, + q_norm: tensors + .get(&format!("{p}.self_attn.q_layernorm.weight")) + .cloned() + .map(|weight| RMSNorm { + weight, + eps: config.block_norm_eps.unwrap_or(config.rms_norm_eps), + add_unit_offset: false, + }), + k_norm: tensors + .get(&format!("{p}.self_attn.k_layernorm.weight")) + .cloned() + .map(|weight| RMSNorm { + weight, + eps: config.block_norm_eps.unwrap_or(config.rms_norm_eps), + add_unit_offset: false, + }), + v_norm: None, + num_heads: config.num_attention_heads, + num_kv_heads: config.num_key_value_heads, + head_dim, + scale: 1.0 / (head_dim as f32).sqrt(), + attn_logit_softcapping: None, + rope_dim: head_dim, + rope_theta: config.rope_theta, + rope_traditional: false, + window_size: None, + kv_shared_source: None, + }) + } else { + AttentionKind::Lfm2ShortConv(Lfm2ShortConv { + conv_weight: load_conv_weight(&format!("{p}.conv.conv"))?, + in_proj: load_qlinear(&format!("{p}.conv.in_proj"))?, + out_proj: load_qlinear(&format!("{p}.conv.out_proj"))?, + hidden_size: config.hidden_size, + conv_l_cache: config.conv_l_cache.unwrap_or(3), + }) + }; + + Ok(Layer { + attn: operator, + mlp: MlpKind::Dense(MLP { + gate_up_proj: None, + gate_proj: Some(load_qlinear(&format!("{p}.feed_forward.w1"))?), + up_proj: Some(load_qlinear(&format!("{p}.feed_forward.w3"))?), + down_proj: load_qlinear(&format!("{p}.feed_forward.w2"))?, + activation: Activation::Silu, + }), + attn_in_norm: Some(rms_norm_kind( + tensors + .get(&format!("{p}.operator_norm.weight")) + .cloned() + .with_context(|| format!("missing {p}.operator_norm.weight"))?, + config.block_norm_eps.unwrap_or(config.rms_norm_eps), + false, + )), + attn_out_norm: None, + mlp_in_norm: Some(rms_norm_kind( + tensors + .get(&format!("{p}.ffn_norm.weight")) + .cloned() + .with_context(|| format!("missing {p}.ffn_norm.weight"))?, + config.block_norm_eps.unwrap_or(config.rms_norm_eps), + false, + )), + mlp_out_norm: None, + per_layer_input: None, + layer_scalar: None, + }) +} diff --git a/mesh-llm/src/mlx/model/families/llama_like.rs b/mesh-llm/src/mlx/model/families/llama_like.rs new file mode 100644 index 00000000..48157ca2 --- /dev/null +++ b/mesh-llm/src/mlx/model/families/llama_like.rs @@ -0,0 +1,274 @@ +use super::super::layer::{Layer, PerLayerInputBlock}; +use super::super::mlp::{Activation, MlpKind, MLP}; +use super::super::{ + layer_norm_kind, rms_norm_kind, Attention, AttentionKind, ModelArchitecture, ModelConfig, + NormKind, QuantizedLinear, RMSNorm, TensorPrefixes, +}; +use anyhow::{Context, Result}; +use mlx_rs::Array; +use std::collections::HashMap; + +pub(crate) fn transform_llama_like_tensors( + tensors: &mut HashMap<String, Array>, + prefixes: &TensorPrefixes, + config: &ModelConfig, +) -> Result<()> { + tensors.retain(|key, _| !key.contains("self_attn.rotary_emb.inv_freq")); + + if config.tie_word_embeddings { + if let Some(prefix) = prefixes.lm_head.as_deref() { + tensors.remove(&format!("{prefix}.weight")); + tensors.remove(&format!("{prefix}.scales")); + tensors.remove(&format!("{prefix}.biases")); + tensors.remove(&format!("{prefix}.bias")); + } + } + + Ok(()) +} + +pub(crate) fn build_standard_layer<FQ>( + tensors: &HashMap<String, Array>, + p: &str, + layer_index: usize, + arch: ModelArchitecture, + config: &ModelConfig, + layer_type: Option<&str>, + head_dim: i32, + rope_traditional: bool, + non_shared_layer_types: Option<&[String]>, + load_qlinear: &FQ, + attention_window_size_for_layer: &impl Fn( + ModelArchitecture, + &ModelConfig, + usize, + Option<&str>, + ) -> Result<Option<i32>>, + kv_shared_source_for_layer: &impl Fn( + ModelArchitecture, + &ModelConfig, + usize, + Option<&str>, + Option<&[String]>, + ) -> Option<usize>, +) -> Result<Layer> +where + FQ: Fn(&str) -> Result<QuantizedLinear>, +{ + let is_full_attention = arch.is_gemma4() && matches!(layer_type, Some("full_attention")); + let layer_head_dim = if is_full_attention { + config.global_head_dim.unwrap_or(head_dim) + } else { + head_dim + }; + let rope_parameters = layer_type.and_then(|name| { + config + .rope_parameters + .as_ref() + .and_then(|map| map.get(name)) + }); + let rope_dim = if is_full_attention { + ((layer_head_dim as f32) + * rope_parameters + .and_then(|params| params.partial_rotary_factor) + .unwrap_or(1.0)) + .round() as i32 + } else if arch.is_glm4() { + ((layer_head_dim as f32) * config.partial_rotary_factor.unwrap_or(1.0)).round() as i32 + } else { + layer_head_dim + }; + let rope_theta = rope_parameters + .and_then(|params| params.rope_theta) + .unwrap_or(config.rope_theta); + let window_size = attention_window_size_for_layer(arch, config, layer_index, layer_type)?; + let kv_shared_source = kv_shared_source_for_layer( + arch, + config, + layer_index, + layer_type, + non_shared_layer_types, + ); + let scale = if arch.is_gemma4() { + 1.0 + } else if let Some(query_pre_attn_scalar) = config.query_pre_attn_scalar { + 1.0 / query_pre_attn_scalar.sqrt() + } else { + 1.0 / (layer_head_dim as f32).sqrt() + }; + let mlp_in_norm_key = if arch.is_glm4() { + format!("{p}.post_attention_layernorm.weight") + } else if arch.is_gemma2() || arch.is_gemma3() || arch.is_gemma4() { + format!("{p}.pre_feedforward_layernorm.weight") + } else { + format!("{p}.post_attention_layernorm.weight") + }; + + Ok(Layer { + attn: AttentionKind::Standard(Attention { + q_proj: load_qlinear(&format!("{p}.self_attn.q_proj"))?, + k_proj: load_qlinear(&format!("{p}.self_attn.k_proj"))?, + v_proj: load_qlinear(&format!("{p}.self_attn.v_proj"))?, + o_proj: load_qlinear(&format!("{p}.self_attn.o_proj"))?, + q_norm: tensors + .get(&format!("{p}.self_attn.q_norm.weight")) + .cloned() + .map(|weight| RMSNorm { + weight, + eps: config.rms_norm_eps, + add_unit_offset: arch.uses_gemma_norm_offset(), + }), + k_norm: tensors + .get(&format!("{p}.self_attn.k_norm.weight")) + .cloned() + .map(|weight| RMSNorm { + weight, + eps: config.rms_norm_eps, + add_unit_offset: arch.uses_gemma_norm_offset(), + }), + v_norm: arch.is_gemma4().then(|| RMSNorm { + weight: mlx_rs::ops::ones::<f32>(&[layer_head_dim]) + .expect("allocating v_norm scale"), + eps: config.rms_norm_eps, + add_unit_offset: false, + }), + num_heads: config.num_attention_heads, + num_kv_heads: config.num_key_value_heads, + head_dim: layer_head_dim, + scale, + attn_logit_softcapping: arch + .is_gemma2() + .then_some(config.attn_logit_softcapping.unwrap_or(50.0)), + rope_dim, + rope_theta, + rope_traditional, + window_size, + kv_shared_source, + }), + mlp: MlpKind::Dense(MLP { + gate_up_proj: tensors + .contains_key(&format!("{p}.mlp.gate_up_proj.weight")) + .then(|| load_qlinear(&format!("{p}.mlp.gate_up_proj"))) + .transpose()?, + gate_proj: tensors + .contains_key(&format!("{p}.mlp.gate_proj.weight")) + .then(|| load_qlinear(&format!("{p}.mlp.gate_proj"))) + .transpose()?, + up_proj: tensors + .contains_key(&format!("{p}.mlp.up_proj.weight")) + .then(|| load_qlinear(&format!("{p}.mlp.up_proj"))) + .transpose()?, + down_proj: load_qlinear(&format!("{p}.mlp.down_proj"))?, + activation: match config.hidden_activation.as_deref() { + Some("gelu_pytorch_tanh") | Some("gelu") => Activation::GeluApproximate, + _ => Activation::Silu, + }, + }), + attn_in_norm: (!arch.is_olmo2()) + .then(|| -> Result<NormKind> { + if arch.is_olmo() { + Ok(layer_norm_kind(1e-5)) + } else { + Ok(rms_norm_kind( + tensors + .get(&format!("{p}.input_layernorm.weight")) + .cloned() + .with_context(|| format!("missing {p}.input_layernorm.weight"))?, + config.rms_norm_eps, + arch.uses_gemma_norm_offset(), + )) + } + }) + .transpose()?, + attn_out_norm: (arch.is_glm4() + || arch.is_olmo2() + || arch.is_gemma2() + || arch.is_gemma3() + || arch.is_gemma4()) + .then(|| -> Result<NormKind> { + let key = if arch.is_glm4() { + format!("{p}.post_self_attn_layernorm.weight") + } else { + format!("{p}.post_attention_layernorm.weight") + }; + Ok(rms_norm_kind( + tensors + .get(&key) + .cloned() + .with_context(|| format!("missing {key}"))?, + config.rms_norm_eps, + arch.uses_gemma_norm_offset(), + )) + }) + .transpose()?, + mlp_in_norm: (!arch.is_olmo2()) + .then(|| -> Result<NormKind> { + if arch.is_olmo() { + Ok(layer_norm_kind(1e-5)) + } else { + Ok(rms_norm_kind( + tensors.get(&mlp_in_norm_key).cloned().with_context(|| { + if arch.is_gemma2() || arch.is_gemma3() { + format!("missing {p}.pre_feedforward_layernorm.weight") + } else if arch.is_glm4() { + format!("missing {p}.post_attention_layernorm.weight") + } else { + format!("missing {p}.post_attention_layernorm.weight") + } + })?, + config.rms_norm_eps, + arch.uses_gemma_norm_offset(), + )) + } + }) + .transpose()?, + mlp_out_norm: (arch.is_glm4() + || arch.is_olmo2() + || arch.is_gemma2() + || arch.is_gemma3() + || arch.is_gemma4()) + .then(|| -> Result<NormKind> { + let key = if arch.is_glm4() { + format!("{p}.post_mlp_layernorm.weight") + } else { + format!("{p}.post_feedforward_layernorm.weight") + }; + Ok(rms_norm_kind( + tensors + .get(&key) + .cloned() + .with_context(|| format!("missing {key}"))?, + config.rms_norm_eps, + arch.uses_gemma_norm_offset(), + )) + }) + .transpose()?, + per_layer_input: arch + .is_gemma4() + .then(|| -> Result<PerLayerInputBlock> { + Ok(PerLayerInputBlock { + input_gate: load_qlinear(&format!("{p}.per_layer_input_gate"))?, + projection: load_qlinear(&format!("{p}.per_layer_projection"))?, + post_norm: rms_norm_kind( + tensors + .get(&format!("{p}.post_per_layer_input_norm.weight")) + .cloned() + .with_context(|| { + format!("missing {p}.post_per_layer_input_norm.weight") + })?, + config.rms_norm_eps, + false, + ), + activation: match config.hidden_activation.as_deref() { + Some("gelu_pytorch_tanh") | Some("gelu") => Activation::GeluApproximate, + _ => Activation::Silu, + }, + }) + }) + .transpose()?, + layer_scalar: arch + .is_gemma4() + .then(|| tensors.get(&format!("{p}.layer_scalar")).cloned()) + .flatten(), + }) +} diff --git a/mesh-llm/src/mlx/model/families/mod.rs b/mesh-llm/src/mlx/model/families/mod.rs new file mode 100644 index 00000000..a4f1f61f --- /dev/null +++ b/mesh-llm/src/mlx/model/families/mod.rs @@ -0,0 +1,213 @@ +use super::family::ModelArchitecture; +use super::layer::Layer; +use super::{ + ModelConfig, QuantizedLinear, QuantizedMultiLinear, QuantizedSwitchLinear, TensorPrefixes, +}; +use anyhow::Result; +use mlx_rs::Array; +use serde_json::Value; +use std::collections::HashMap; + +mod deepseek_v3; +mod gemma3; +mod gemma4; +mod gpt_oss; +mod kimi; +mod lfm2; +mod llama_like; +mod olmo2; +mod phi3; + +pub(crate) fn apply_family_tensor_transforms( + arch: ModelArchitecture, + tensors: &mut HashMap<String, Array>, + prefixes: &TensorPrefixes, + config: &ModelConfig, + config_json: &Value, + default_group_size: i32, + default_bits: i32, +) -> Result<()> { + if matches!(arch, ModelArchitecture::LlamaLike) { + llama_like::transform_llama_like_tensors(tensors, prefixes, config)?; + } + + if arch.is_deepseek_v3() || arch.is_kimi_linear() { + deepseek_v3::transform_deepseek_v3_tensors( + tensors, + prefixes, + config, + config_json, + default_group_size, + default_bits, + )?; + } + + if config_json + .get("model_type") + .and_then(|value| value.as_str()) + .is_some_and(|value| value.eq_ignore_ascii_case("phi3")) + { + phi3::transform_phi3_tensors(tensors, prefixes, config)?; + } + + if arch.is_gpt_oss() { + gpt_oss::transform_gpt_oss_tensors(tensors, prefixes, config)?; + } + + if arch.is_gemma3() { + gemma3::transform_gemma3_tensors(tensors, prefixes, config)?; + } + + if arch.is_gemma4() { + gemma4::transform_gemma4_tensors(tensors, prefixes, config)?; + } + + if arch.is_olmo2() { + olmo2::transform_olmo2_tensors(tensors, prefixes, config)?; + } + + Ok(()) +} + +pub(crate) fn build_deepseek_v3_layer<FQ, FM, FS>( + tensors: &HashMap<String, Array>, + p: &str, + layer_index: i32, + config: &ModelConfig, + load_qlinear: &FQ, + load_multi_linear: &FM, + load_switch_linear: &FS, +) -> Result<Layer> +where + FQ: Fn(&str) -> Result<QuantizedLinear>, + FM: Fn(&str) -> Result<QuantizedMultiLinear>, + FS: Fn(&str) -> Result<QuantizedSwitchLinear>, +{ + deepseek_v3::build_deepseek_v3_layer( + tensors, + p, + layer_index, + config, + load_qlinear, + load_multi_linear, + load_switch_linear, + ) +} + +pub(crate) fn build_lfm2_layer<FQ>( + tensors: &HashMap<String, Array>, + p: &str, + layer_index: i32, + config: &ModelConfig, + head_dim: i32, + load_qlinear: &FQ, + load_conv_weight: &impl Fn(&str) -> Result<Array>, +) -> Result<Layer> +where + FQ: Fn(&str) -> Result<QuantizedLinear>, +{ + lfm2::build_lfm2_layer( + tensors, + p, + layer_index, + config, + head_dim, + load_qlinear, + load_conv_weight, + ) +} + +pub(crate) fn build_kimi_linear_layer<FQ, FM, FS>( + tensors: &HashMap<String, Array>, + p: &str, + layer_index: i32, + config: &ModelConfig, + load_qlinear: &FQ, + load_multi_linear: &FM, + load_switch_linear: &FS, + load_conv_weight: &impl Fn(&str) -> Result<Array>, +) -> Result<Layer> +where + FQ: Fn(&str) -> Result<QuantizedLinear>, + FM: Fn(&str) -> Result<QuantizedMultiLinear>, + FS: Fn(&str) -> Result<QuantizedSwitchLinear>, +{ + kimi::build_kimi_linear_layer( + tensors, + p, + layer_index, + config, + load_qlinear, + load_multi_linear, + load_switch_linear, + load_conv_weight, + ) +} + +pub(crate) fn build_gpt_oss_layer<FQ, FS>( + tensors: &HashMap<String, Array>, + p: &str, + config: &ModelConfig, + head_dim: i32, + window_size: Option<i32>, + load_qlinear: &FQ, + load_switch_linear: &FS, +) -> Result<Layer> +where + FQ: Fn(&str) -> Result<QuantizedLinear>, + FS: Fn(&str) -> Result<QuantizedSwitchLinear>, +{ + gpt_oss::build_gpt_oss_layer( + tensors, + p, + config, + head_dim, + window_size, + load_qlinear, + load_switch_linear, + ) +} + +pub(crate) fn build_standard_layer<FQ>( + tensors: &HashMap<String, Array>, + p: &str, + layer_index: usize, + arch: ModelArchitecture, + config: &ModelConfig, + layer_type: Option<&str>, + head_dim: i32, + rope_traditional: bool, + non_shared_layer_types: Option<&[String]>, + load_qlinear: &FQ, + attention_window_size_for_layer: &impl Fn( + ModelArchitecture, + &ModelConfig, + usize, + Option<&str>, + ) -> Result<Option<i32>>, + kv_shared_source_for_layer: &impl Fn( + ModelArchitecture, + &ModelConfig, + usize, + Option<&str>, + Option<&[String]>, + ) -> Option<usize>, +) -> Result<Layer> +where + FQ: Fn(&str) -> Result<QuantizedLinear>, +{ + llama_like::build_standard_layer( + tensors, + p, + layer_index, + arch, + config, + layer_type, + head_dim, + rope_traditional, + non_shared_layer_types, + load_qlinear, + attention_window_size_for_layer, + kv_shared_source_for_layer, + ) +} diff --git a/mesh-llm/src/mlx/model/families/olmo2.rs b/mesh-llm/src/mlx/model/families/olmo2.rs new file mode 100644 index 00000000..d4bb56c8 --- /dev/null +++ b/mesh-llm/src/mlx/model/families/olmo2.rs @@ -0,0 +1,18 @@ +use super::super::{ModelConfig, TensorPrefixes}; +use anyhow::Result; +use mlx_rs::Array; +use std::collections::HashMap; + +pub(crate) fn transform_olmo2_tensors( + tensors: &mut HashMap<String, Array>, + prefixes: &TensorPrefixes, + config: &ModelConfig, +) -> Result<()> { + for i in 0..config.num_hidden_layers { + tensors.remove(&format!( + "{}.layers.{i}.self_attn.rotary_emb.inv_freq", + prefixes.model + )); + } + Ok(()) +} diff --git a/mesh-llm/src/mlx/model/families/phi3.rs b/mesh-llm/src/mlx/model/families/phi3.rs new file mode 100644 index 00000000..214ac4b2 --- /dev/null +++ b/mesh-llm/src/mlx/model/families/phi3.rs @@ -0,0 +1,115 @@ +use super::super::{ModelConfig, TensorPrefixes}; +use anyhow::{bail, Context, Result}; +use mlx_rs::ops::indexing::IndexOp; +use mlx_rs::Array; +use std::collections::HashMap; + +fn slice_rows(tensor: &Array, start: i32, end: i32) -> Result<Array> { + Ok(tensor.index((start..end, std::ops::RangeFull))) +} + +fn split_fused_qkv( + prefix: &str, + tensors: &mut HashMap<String, Array>, + q_rows: i32, + kv_rows: i32, +) -> Result<()> { + if tensors.contains_key(&format!("{prefix}.q_proj.weight")) { + return Ok(()); + } + + let total_rows = q_rows + kv_rows + kv_rows; + for suffix in ["weight", "scales", "biases"] { + let key = format!("{prefix}.qkv_proj.{suffix}"); + let fused = tensors + .get(&key) + .cloned() + .with_context(|| format!("missing {key}"))?; + let shape = fused.shape(); + if shape.is_empty() || shape[0] != total_rows { + bail!( + "unexpected {key} shape {:?}; expected first dimension {}", + shape, + total_rows, + ); + } + tensors.insert( + format!("{prefix}.q_proj.{suffix}"), + slice_rows(&fused, 0, q_rows)?, + ); + tensors.insert( + format!("{prefix}.k_proj.{suffix}"), + slice_rows(&fused, q_rows, q_rows + kv_rows)?, + ); + tensors.insert( + format!("{prefix}.v_proj.{suffix}"), + slice_rows(&fused, q_rows + kv_rows, total_rows)?, + ); + } + + Ok(()) +} + +fn split_fused_gate_up( + prefix: &str, + tensors: &mut HashMap<String, Array>, + hidden_rows: i32, +) -> Result<()> { + if tensors.contains_key(&format!("{prefix}.gate_proj.weight")) { + return Ok(()); + } + + let total_rows = hidden_rows * 2; + for suffix in ["weight", "scales", "biases"] { + let key = format!("{prefix}.gate_up_proj.{suffix}"); + let fused = tensors + .get(&key) + .cloned() + .with_context(|| format!("missing {key}"))?; + let shape = fused.shape(); + if shape.is_empty() || shape[0] != total_rows { + bail!( + "unexpected {key} shape {:?}; expected first dimension {}", + shape, + total_rows, + ); + } + tensors.insert( + format!("{prefix}.gate_proj.{suffix}"), + slice_rows(&fused, 0, hidden_rows)?, + ); + tensors.insert( + format!("{prefix}.up_proj.{suffix}"), + slice_rows(&fused, hidden_rows, total_rows)?, + ); + } + + Ok(()) +} + +pub(crate) fn transform_phi3_tensors( + tensors: &mut HashMap<String, Array>, + prefixes: &TensorPrefixes, + config: &ModelConfig, +) -> Result<()> { + let head_dim = config + .head_dim + .unwrap_or_else(|| config.hidden_size / config.num_attention_heads); + let q_rows = config.num_attention_heads * head_dim; + let kv_rows = config.num_key_value_heads * head_dim; + let mlp_rows = config.intermediate_size; + + for i in 0..config.num_hidden_layers { + let attn_prefix = format!("{}.layers.{i}.self_attn", prefixes.model); + if tensors.contains_key(&format!("{attn_prefix}.qkv_proj.weight")) { + split_fused_qkv(&attn_prefix, tensors, q_rows, kv_rows)?; + } + + let mlp_prefix = format!("{}.layers.{i}.mlp", prefixes.model); + if tensors.contains_key(&format!("{mlp_prefix}.gate_up_proj.weight")) { + split_fused_gate_up(&mlp_prefix, tensors, mlp_rows)?; + } + } + + Ok(()) +} diff --git a/mesh-llm/src/mlx/model/family.rs b/mesh-llm/src/mlx/model/family.rs new file mode 100644 index 00000000..54c800d3 --- /dev/null +++ b/mesh-llm/src/mlx/model/family.rs @@ -0,0 +1,306 @@ +use anyhow::{bail, Result}; +use serde_json::Value; +use std::fs::File; +use std::io::Read; +use std::path::Path; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ReasoningFamily { + None, + Qwen3, + Glm, + Kimi, + GptOss, + Lfm2, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum ModelArchitecture { + LlamaLike, + Olmo, + Olmo2, + DeepseekV3, + GptOss, + KimiLinear, + Lfm2, + Glm4, + Gemma2, + Gemma3, + Gemma4, +} + +impl ModelArchitecture { + pub(super) fn is_olmo2(self) -> bool { + matches!(self, Self::Olmo2) + } + + pub(super) fn is_olmo(self) -> bool { + matches!(self, Self::Olmo) + } + + pub(super) fn is_deepseek_v3(self) -> bool { + matches!(self, Self::DeepseekV3) + } + + pub(super) fn is_glm4(self) -> bool { + matches!(self, Self::Glm4) + } + + pub(super) fn is_gpt_oss(self) -> bool { + matches!(self, Self::GptOss) + } + + pub(super) fn is_kimi_linear(self) -> bool { + matches!(self, Self::KimiLinear) + } + + pub(super) fn is_lfm2(self) -> bool { + matches!(self, Self::Lfm2) + } + + pub(super) fn is_gemma2(self) -> bool { + matches!(self, Self::Gemma2) + } + + pub(super) fn is_gemma3(self) -> bool { + matches!(self, Self::Gemma3) + } + + pub(super) fn is_gemma4(self) -> bool { + matches!(self, Self::Gemma4) + } + + pub(super) fn uses_gemma_norm_offset(self) -> bool { + self.is_gemma2() || self.is_gemma3() + } + + pub(super) fn uses_gemma_scaled_embeddings(self) -> bool { + self.is_gemma2() || self.is_gemma3() || self.is_gemma4() + } +} + +pub(super) fn uses_traditional_rope(config: &Value) -> bool { + config + .get("rope_traditional") + .and_then(|value| value.as_bool()) + .or_else(|| { + config + .get("text_config") + .and_then(|value| value.get("rope_traditional")) + .and_then(|value| value.as_bool()) + }) + .unwrap_or(false) +} + +pub(super) fn model_architecture(config: &Value) -> ModelArchitecture { + let model_type = config + .get("model_type") + .and_then(|value| value.as_str()) + .or_else(|| { + config + .get("text_config") + .and_then(|value| value.get("model_type")) + .and_then(|value| value.as_str()) + }) + .unwrap_or_default() + .to_ascii_lowercase(); + + if model_type.starts_with("glm4") { + ModelArchitecture::Glm4 + } else if model_type.starts_with("gpt_oss") { + ModelArchitecture::GptOss + } else if model_type.starts_with("kimi_linear") { + ModelArchitecture::KimiLinear + } else if model_type.starts_with("deepseek_v3") + || model_type.starts_with("kimi_k2") + || model_type.starts_with("kimi_k25") + { + ModelArchitecture::DeepseekV3 + } else if model_type.starts_with("olmo2") { + ModelArchitecture::Olmo2 + } else if model_type.starts_with("olmo") { + ModelArchitecture::Olmo + } else if model_type.starts_with("lfm2") { + ModelArchitecture::Lfm2 + } else if model_type.starts_with("gemma4") { + ModelArchitecture::Gemma4 + } else if model_type.starts_with("gemma2") { + ModelArchitecture::Gemma2 + } else if model_type.starts_with("gemma3") { + ModelArchitecture::Gemma3 + } else { + ModelArchitecture::LlamaLike + } +} + +pub(super) fn reasoning_family(config: &Value) -> ReasoningFamily { + let model_type = config + .get("model_type") + .and_then(|value| value.as_str()) + .or_else(|| { + config + .get("text_config") + .and_then(|value| value.get("model_type")) + .and_then(|value| value.as_str()) + }) + .unwrap_or_default() + .to_ascii_lowercase(); + let architectures = config + .get("architectures") + .and_then(|value| value.as_array()) + .into_iter() + .flatten() + .filter_map(|value| value.as_str()) + .map(|value| value.to_ascii_lowercase()) + .collect::<Vec<_>>(); + + if model_type == "qwen3" || architectures.iter().any(|value| value.contains("qwen3")) { + return ReasoningFamily::Qwen3; + } + if model_type.starts_with("glm") || architectures.iter().any(|value| value.contains("glm")) { + return ReasoningFamily::Glm; + } + if model_type == "gpt_oss" || architectures.iter().any(|value| value.contains("gptoss")) { + return ReasoningFamily::GptOss; + } + if model_type.starts_with("lfm2") || architectures.iter().any(|value| value.contains("lfm2")) { + return ReasoningFamily::Lfm2; + } + if model_type.contains("kimi") || architectures.iter().any(|value| value.contains("kimi")) { + return ReasoningFamily::Kimi; + } + ReasoningFamily::None +} + +pub(super) fn config_supports_mlx(config: &Value) -> bool { + let architectures = config + .get("architectures") + .and_then(|value| value.as_array()) + .into_iter() + .flatten() + .filter_map(|value| value.as_str()); + let model_type = config.get("model_type").and_then(|value| value.as_str()); + + architectures.chain(model_type).any(|name| { + let name = name.to_ascii_lowercase(); + matches!( + name.as_str(), + "llama" + | "mistral" + | "glm4" + | "deepseek_v3" + | "lfm2" + | "phi3" + | "qwen2" + | "qwen3" + | "gpt_oss" + | "kimi_linear" + | "olmo" + | "olmo2" + | "gemma2" + | "gemma3" + | "gemma3_text" + | "gemma4" + | "gemma4_text" + | "glm4forcausallm" + | "deepseekv3forcausallm" + | "lfm2forcausallm" + | "phi3forcausallm" + | "llamaforcausallm" + | "mistralforcausallm" + | "qwen2forcausallm" + | "qwen3forcausallm" + | "gptossforcausallm" + | "kimilinearforcausallm" + | "olmoforcausallm" + | "olmo2forcausallm" + | "gemma2forcausallm" + | "gemma3forcausallm" + | "gemma3forconditionalgeneration" + | "gemma4forcausallm" + | "gemma4forconditionalgeneration" + ) + }) +} + +pub(super) fn ensure_supported_mlx_model(dir: &Path, config: &Value) -> Result<()> { + if config_supports_mlx(config) { + return Ok(()); + } + if let Some(architecture) = detect_architecture_from_safetensors_header(dir) { + tracing::info!( + "MLX loader: config.json did not identify a supported architecture, but safetensors headers matched {}", + architecture + ); + return Ok(()); + } + + let model_type = config + .get("model_type") + .and_then(|value| value.as_str()) + .unwrap_or("unknown"); + let architectures = config + .get("architectures") + .and_then(|value| value.as_array()) + .map(|values| { + values + .iter() + .filter_map(|value| value.as_str()) + .collect::<Vec<_>>() + .join(", ") + }) + .filter(|value| !value.is_empty()) + .unwrap_or_else(|| "none".to_string()); + bail!( + "unsupported MLX model architecture in {} (model_type={}, architectures={}); supported MLX models currently cover Llama/OLMo/DeepSeekV3/GPT-OSS/Kimi-Linear/LFM2/GLM4/Qwen/Gemma2/Gemma3/Gemma4-style safetensors checkpoints", + dir.display(), + model_type, + architectures, + ) +} + +pub(super) fn detect_architecture_from_safetensors_header(dir: &Path) -> Option<String> { + let path = if dir.join("model.safetensors").exists() { + dir.join("model.safetensors") + } else { + let text = std::fs::read_to_string(dir.join("model.safetensors.index.json")).ok()?; + let index: Value = serde_json::from_str(&text).ok()?; + let file = index + .get("weight_map") + .and_then(|value| value.as_object())? + .values() + .find_map(|value| value.as_str())?; + dir.join(file) + }; + + let mut file = File::open(path).ok()?; + let mut len_bytes = [0u8; 8]; + file.read_exact(&mut len_bytes).ok()?; + let header_len = u64::from_le_bytes(len_bytes) as usize; + if header_len == 0 || header_len > 16 * 1024 * 1024 { + return None; + } + let mut header = vec![0u8; header_len]; + file.read_exact(&mut header).ok()?; + let json: Value = serde_json::from_slice(&header).ok()?; + let map = json.as_object()?; + + let keys: Vec<&str> = map + .keys() + .filter(|key| key.as_str() != "__metadata__") + .map(|key| key.as_str()) + .collect(); + + if keys.iter().any(|key| key.starts_with("model.layers.")) + && keys + .iter() + .any(|key| key.starts_with("model.embed_tokens.")) + && keys + .iter() + .any(|key| key.contains(".self_attn.q_proj.") || key.contains(".self_attn.q_proj")) + { + return Some("llama_like".to_string()); + } + + None +} diff --git a/mesh-llm/src/mlx/model/kimi.rs b/mesh-llm/src/mlx/model/kimi.rs new file mode 100644 index 00000000..49802ed4 --- /dev/null +++ b/mesh-llm/src/mlx/model/kimi.rs @@ -0,0 +1,258 @@ +use super::*; +use mlx_rs::array; + +pub(crate) struct KimiMlaAttention { + pub(super) q_proj: QuantizedLinear, + pub(super) kv_a_proj_with_mqa: QuantizedLinear, + pub(super) kv_a_layernorm: RMSNorm, + pub(super) embed_q: QuantizedMultiLinear, + pub(super) unembed_out: QuantizedMultiLinear, + pub(super) o_proj: QuantizedLinear, + pub(super) num_heads: i32, + pub(super) q_head_dim: i32, + pub(super) qk_rope_head_dim: i32, + pub(super) qk_nope_head_dim: i32, + pub(super) kv_lora_rank: i32, + pub(super) v_head_dim: i32, + pub(super) scale: f32, +} + +impl KimiMlaAttention { + pub(super) fn forward_no_cache(&self, x: &Array) -> Result<Array> { + let shape = x.shape(); + let (b, l) = (shape[0], shape[1]); + + let q = self + .q_proj + .forward(x)? + .reshape(&[b, l, self.num_heads, self.q_head_dim])? + .transpose_axes(&[0, 2, 1, 3])?; + let q_nope = q.index(( + std::ops::RangeFull, + std::ops::RangeFull, + std::ops::RangeFull, + ..self.qk_nope_head_dim, + )); + let q_pe = q.index(( + std::ops::RangeFull, + std::ops::RangeFull, + std::ops::RangeFull, + self.qk_nope_head_dim.., + )); + + let compressed_kv = self.kv_a_proj_with_mqa.forward(x)?; + let kv_latent = compressed_kv.index(( + std::ops::RangeFull, + std::ops::RangeFull, + ..self.kv_lora_rank, + )); + let k_pe = compressed_kv.index(( + std::ops::RangeFull, + std::ops::RangeFull, + self.kv_lora_rank.., + )); + let kv_latent = self.kv_a_layernorm.forward(&kv_latent)?.expand_dims(1)?; + let k_pe = k_pe + .reshape(&[b, l, 1, self.qk_rope_head_dim])? + .transpose_axes(&[0, 2, 1, 3])?; + + let pe_scores = mlx_rs::ops::matmul( + &q_pe.multiply(&array!(self.scale))?, + &k_pe.transpose_axes(&[0, 1, 3, 2])?, + )?; + + let output = if l == 1 { + let q_nope = self.embed_q.forward(&q_nope, true)?; + let scores = mlx_rs::ops::matmul( + &q_nope.multiply(&array!(self.scale))?, + &kv_latent.transpose_axes(&[0, 1, 3, 2])?, + )? + .add(&pe_scores)?; + let probs = mlx_rs::ops::softmax_axis(&scores, -1, true)?; + let output = mlx_rs::ops::matmul(&probs, &kv_latent)?; + self.unembed_out.forward(&output, true)? + } else { + let k = self.embed_q.forward(&kv_latent, false)?; + let v = self.unembed_out.forward(&kv_latent, true)?; + let mask = attention_mask(l, l, 0, 0, None)?.context("expected kimi mla mask")?; + let scores = mlx_rs::ops::matmul( + &q_nope.multiply(&array!(self.scale))?, + &k.transpose_axes(&[0, 1, 3, 2])?, + )? + .add(&pe_scores)?; + let fill = array!(scores.dtype().finfo_min()? as f32).as_dtype(scores.dtype())?; + let scores = mlx_rs::ops::r#where(&mask, &scores, &fill)?; + let probs = mlx_rs::ops::softmax_axis(&scores, -1, true)?; + mlx_rs::ops::matmul(&probs, &v)? + }; + + let output = output.transpose_axes(&[0, 2, 1, 3])?.reshape(&[ + b, + l, + self.num_heads * self.v_head_dim, + ])?; + self.o_proj.forward(&output) + } +} + +pub(crate) struct KimiShortConv { + pub(super) conv_weight: Array, + pub(super) kernel_size: i32, + pub(super) channels: i32, +} + +impl KimiShortConv { + pub(super) fn forward_no_cache(&self, x: &Array) -> Result<Array> { + let x = pad( + x, + &[(0, 0), (self.kernel_size - 1, 0), (0, 0)], + None::<Array>, + None::<mlx_rs::ops::PadMode>, + )?; + let x = conv1d( + &x, + &self.conv_weight, + None::<i32>, + None::<i32>, + None::<i32>, + Some(self.channels), + )?; + Ok(&mlx_rs::ops::sigmoid(&x)? * &x) + } +} + +pub(crate) struct KimiDeltaAttention { + pub(super) q_proj: QuantizedLinear, + pub(super) k_proj: QuantizedLinear, + pub(super) v_proj: QuantizedLinear, + pub(super) q_conv: KimiShortConv, + pub(super) k_conv: KimiShortConv, + pub(super) v_conv: KimiShortConv, + pub(super) f_a_proj: QuantizedLinear, + pub(super) f_b_proj: QuantizedLinear, + pub(super) b_proj: QuantizedLinear, + pub(super) g_a_proj: QuantizedLinear, + pub(super) g_b_proj: QuantizedLinear, + pub(super) a_log: Array, + pub(super) dt_bias: Array, + pub(super) o_norm: RMSNorm, + pub(super) o_proj: QuantizedLinear, + pub(super) num_heads: i32, + pub(super) head_dim: i32, + pub(super) scale: f32, +} + +impl KimiDeltaAttention { + fn gated_delta_update( + &self, + q: &Array, + k: &Array, + v: &Array, + a: &Array, + b: &Array, + ) -> Result<Array> { + let bsz = q.shape()[0]; + let seq = q.shape()[1]; + let heads = q.shape()[2]; + let dim = q.shape()[3]; + let mut state = mlx_rs::ops::zeros_dtype(&[bsz, heads, dim, dim], q.dtype())?; + let beta = mlx_rs::ops::sigmoid(b)?; + let a = a.add( + &self + .dt_bias + .reshape(&[1, 1, self.num_heads, self.head_dim])?, + )?; + let g = mlx_rs::ops::exp(&mlx_rs::ops::negative( + &mlx_rs::ops::exp( + &self + .a_log + .reshape(&[1, 1, self.num_heads, 1])? + .as_dtype(Dtype::Float32)?, + )? + .multiply(&mlx_rs::nn::softplus(&a)?)?, + )?)? + .as_dtype(q.dtype())?; + + let mut ys = Vec::with_capacity(seq as usize); + for t in 0..seq { + let q_t = q.index(( + std::ops::RangeFull, + t, + std::ops::RangeFull, + std::ops::RangeFull, + )); + let k_t = k.index(( + std::ops::RangeFull, + t, + std::ops::RangeFull, + std::ops::RangeFull, + )); + let v_t = v.index(( + std::ops::RangeFull, + t, + std::ops::RangeFull, + std::ops::RangeFull, + )); + let beta_t = beta.index(( + std::ops::RangeFull, + t, + std::ops::RangeFull, + std::ops::RangeFull, + )); + let g_t = g.index(( + std::ops::RangeFull, + t, + std::ops::RangeFull, + std::ops::RangeFull, + )); + state = state.multiply(&g_t.expand_dims(2)?)?; + let kv_mem = state + .multiply(&k_t.expand_dims(2)?)? + .sum_axes(&[-1], false)?; + let delta = v_t.subtract(&kv_mem)?.multiply(&beta_t)?; + state = state.add(&k_t.expand_dims(2)?.multiply(&delta.expand_dims(3)?)?)?; + ys.push( + state + .multiply(&q_t.expand_dims(2)?)? + .sum_axes(&[-1], false)?, + ); + } + let y_refs: Vec<&Array> = ys.iter().collect(); + Ok(mlx_rs::ops::stack(&y_refs)?.swap_axes(0, 1)?) + } + + pub(super) fn forward_no_cache(&self, x: &Array) -> Result<Array> { + let shape = x.shape(); + let (b, l) = (shape[0], shape[1]); + let q_conv = self.q_conv.forward_no_cache(&self.q_proj.forward(x)?)?; + let k_conv = self.k_conv.forward_no_cache(&self.k_proj.forward(x)?)?; + let v_conv = self.v_conv.forward_no_cache(&self.v_proj.forward(x)?)?; + + let mut q = q_conv.reshape(&[b, l, self.num_heads, self.head_dim])?; + let mut k = k_conv.reshape(&[b, l, self.num_heads, self.head_dim])?; + let v = v_conv.reshape(&[b, l, self.num_heads, self.head_dim])?; + + q = unit_rms_norm(&q, 1e-6)?.multiply(&array!(self.scale * self.scale))?; + k = unit_rms_norm(&k, 1e-6)?.multiply(&array!(self.scale))?; + + let a_logits = self + .f_b_proj + .forward(&self.f_a_proj.forward(x)?)? + .reshape(&[b, l, self.num_heads, self.head_dim])?; + let b_logits = self + .b_proj + .forward(x)? + .reshape(&[b, l, self.num_heads, 1])?; + let out = self.gated_delta_update(&q, &k, &v, &a_logits, &b_logits)?; + let gate = self + .g_b_proj + .forward(&self.g_a_proj.forward(x)?)? + .reshape(&[b, l, self.num_heads, self.head_dim])?; + let out = self + .o_norm + .forward(&out)? + .multiply(&mlx_rs::ops::sigmoid(&gate)?)? + .reshape(&[b, l, self.num_heads * self.head_dim])?; + self.o_proj.forward(&out) + } +} diff --git a/mesh-llm/src/mlx/model/layer.rs b/mesh-llm/src/mlx/model/layer.rs new file mode 100644 index 00000000..ae936b0b --- /dev/null +++ b/mesh-llm/src/mlx/model/layer.rs @@ -0,0 +1,123 @@ +use super::*; + +pub(crate) struct Layer { + pub(super) attn: AttentionKind, + pub(super) mlp: MlpKind, + pub(super) attn_in_norm: Option<NormKind>, + pub(super) attn_out_norm: Option<NormKind>, + pub(super) mlp_in_norm: Option<NormKind>, + pub(super) mlp_out_norm: Option<NormKind>, + pub(super) per_layer_input: Option<PerLayerInputBlock>, + pub(super) layer_scalar: Option<Array>, +} + +impl Layer { + pub(super) fn forward_no_cache( + &self, + x: &Array, + per_layer_input: Option<&Array>, + ) -> Result<Array> { + let attn_input = if let Some(norm) = &self.attn_in_norm { + norm.forward(x)? + } else { + x.clone() + }; + let attn = self.attn.forward_no_cache(&attn_input)?; + let attn = if let Some(norm) = &self.attn_out_norm { + norm.forward(&attn)? + } else { + attn + }; + let h = &attn + x; + let mlp_input = if let Some(norm) = &self.mlp_in_norm { + norm.forward(&h)? + } else { + h.clone() + }; + let mlp = self.mlp.forward(&mlp_input)?; + let mlp = if let Some(norm) = &self.mlp_out_norm { + norm.forward(&mlp)? + } else { + mlp + }; + let mut out = &mlp + &h; + + if let (Some(block), Some(per_layer_input)) = (&self.per_layer_input, per_layer_input) { + let residual = out.clone(); + let mut gated = block.input_gate.forward(&out)?; + gated = match block.activation { + Activation::Silu => &mlx_rs::ops::sigmoid(&gated)? * &gated, + Activation::GeluApproximate => mlx_rs::nn::gelu_approximate(&gated)?, + }; + gated = gated.multiply(per_layer_input)?; + gated = block.projection.forward(&gated)?; + gated = block.post_norm.forward(&gated)?; + out = &gated + &residual; + } + + if let Some(layer_scalar) = &self.layer_scalar { + out = out.multiply(layer_scalar)?; + } + + Ok(out) + } + + pub(super) fn forward( + &self, + x: &Array, + per_layer_input: Option<&Array>, + cache: &mut KVCache, + shared_cache: Option<&KVCache>, + ) -> Result<Array> { + let attn_input = if let Some(norm) = &self.attn_in_norm { + norm.forward(x)? + } else { + x.clone() + }; + let attn = self.attn.forward(&attn_input, cache, shared_cache)?; + let attn = if let Some(norm) = &self.attn_out_norm { + norm.forward(&attn)? + } else { + attn + }; + let h = &attn + x; + let mlp_input = if let Some(norm) = &self.mlp_in_norm { + norm.forward(&h)? + } else { + h.clone() + }; + let mlp = self.mlp.forward(&mlp_input)?; + let mlp = if let Some(norm) = &self.mlp_out_norm { + norm.forward(&mlp)? + } else { + mlp + }; + let mut out = &mlp + &h; + + if let (Some(block), Some(per_layer_input)) = (&self.per_layer_input, per_layer_input) { + let residual = out.clone(); + let mut gated = block.input_gate.forward(&out)?; + gated = match block.activation { + Activation::Silu => &mlx_rs::ops::sigmoid(&gated)? * &gated, + Activation::GeluApproximate => mlx_rs::nn::gelu_approximate(&gated)?, + }; + gated = gated.multiply(per_layer_input)?; + gated = block.projection.forward(&gated)?; + gated = block.post_norm.forward(&gated)?; + out = &gated + &residual; + } + + if let Some(layer_scalar) = &self.layer_scalar { + out = out.multiply(layer_scalar)?; + } + + Ok(out) + } +} + +pub(crate) struct PerLayerInputBlock { + pub(super) input_gate: QuantizedLinear, + pub(super) projection: QuantizedLinear, + pub(super) post_norm: NormKind, + pub(super) activation: Activation, +} diff --git a/mesh-llm/src/mlx/model/lfm2.rs b/mesh-llm/src/mlx/model/lfm2.rs new file mode 100644 index 00000000..8eadc34b --- /dev/null +++ b/mesh-llm/src/mlx/model/lfm2.rs @@ -0,0 +1,44 @@ +use super::*; + +pub(crate) struct Lfm2ShortConv { + pub(super) conv_weight: Array, + pub(super) in_proj: QuantizedLinear, + pub(super) out_proj: QuantizedLinear, + pub(super) hidden_size: i32, + pub(super) conv_l_cache: i32, +} + +impl Lfm2ShortConv { + pub(super) fn forward_no_cache(&self, x: &Array) -> Result<Array> { + let bcx = self.in_proj.forward(x)?; + let hidden = self.hidden_size; + let b = bcx.index((std::ops::RangeFull, std::ops::RangeFull, 0..hidden)); + let c = bcx.index(( + std::ops::RangeFull, + std::ops::RangeFull, + hidden..(hidden * 2), + )); + let x_proj = bcx.index(( + std::ops::RangeFull, + std::ops::RangeFull, + (hidden * 2)..(hidden * 3), + )); + let bx = b.multiply(&x_proj)?; + let bx = pad( + &bx, + &[(0, 0), (self.conv_l_cache - 1, 0), (0, 0)], + None::<Array>, + None::<mlx_rs::ops::PadMode>, + )?; + let conv_out = conv1d( + &bx, + &self.conv_weight, + None::<i32>, + None::<i32>, + None::<i32>, + Some(self.hidden_size), + )?; + let y = c.multiply(&conv_out)?; + self.out_proj.forward(&y) + } +} diff --git a/mesh-llm/src/mlx/model/loader.rs b/mesh-llm/src/mlx/model/loader.rs new file mode 100644 index 00000000..bd3c7596 --- /dev/null +++ b/mesh-llm/src/mlx/model/loader.rs @@ -0,0 +1,453 @@ +use super::artifacts::{load_all_safetensors, load_tokenizer, tensor_prefixes}; +use super::config::{ + attention_window_size_for_layer, kv_shared_source_for_layer, normalized_model_config_json, +}; +use super::families::{ + apply_family_tensor_transforms, build_deepseek_v3_layer, build_gpt_oss_layer, + build_kimi_linear_layer, build_lfm2_layer, build_standard_layer, +}; +use super::family::{ensure_supported_mlx_model, model_architecture, uses_traditional_rope}; +use super::*; +use serde_json::Value; +use std::path::Path; + +impl MlxModel { + /// Load an MLX model from a directory containing config.json, + /// tokenizer.json, and model.safetensors. + pub fn load(dir: &Path) -> Result<Self> { + tracing::info!("MLX: loading model directory {}", dir.display()); + let config_text = + std::fs::read_to_string(dir.join("config.json")).context("reading config.json")?; + let config_json: Value = + serde_json::from_str(&config_text).context("parsing config.json")?; + ensure_supported_mlx_model(dir, &config_json)?; + let effective_config_json = normalized_model_config_json(&config_json); + let arch = model_architecture(&config_json); + let mut config: ModelConfig = + serde_json::from_value(effective_config_json).context("parsing config.json")?; + if arch.is_gemma3() { + config.eos_token_id.retain(|id| *id != 106); + } + let rope_traditional = uses_traditional_rope(&config_json); + + let quantized = config.quantization.as_ref(); + let default_group_size = quantized.map(|q| q.group_size).unwrap_or(0); + let default_bits = quantized.map(|q| q.bits).unwrap_or(0); + + if let Some(qcfg) = quantized { + tracing::info!( + "MLX: loading {} layers, hidden={}, heads={}/{}, quant={}bit/g{}", + config.num_hidden_layers, + config.hidden_size, + config.num_attention_heads, + config.num_key_value_heads, + qcfg.bits, + qcfg.group_size, + ); + } else { + tracing::info!( + "MLX: loading {} layers, hidden={}, heads={}/{}, dense_dtype={:?}", + config.num_hidden_layers, + config.hidden_size, + config.num_attention_heads, + config.num_key_value_heads, + config_json + .get("torch_dtype") + .and_then(|value| value.as_str()) + .unwrap_or("unknown"), + ); + } + + let start = std::time::Instant::now(); + let mut tensors = load_all_safetensors(dir)?; + tracing::info!( + "MLX: loaded {} tensors in {:.2}s", + tensors.len(), + start.elapsed().as_secs_f64() + ); + let prefixes = tensor_prefixes(&tensors)?; + apply_family_tensor_transforms( + arch, + &mut tensors, + &prefixes, + &config, + &config_json, + default_group_size, + default_bits, + )?; + + let load_qlinear = |prefix: &str| -> Result<QuantizedLinear> { + let weight = tensors + .get(&format!("{prefix}.weight")) + .cloned() + .with_context(|| format!("missing {prefix}.weight"))?; + let bias = tensors.get(&format!("{prefix}.bias")).cloned(); + let scales_key = format!("{prefix}.scales"); + let biases_key = format!("{prefix}.biases"); + let has_quantized_storage = + tensors.contains_key(&scales_key) && tensors.contains_key(&biases_key); + let dense_weight_t = if quantized.is_none() || !has_quantized_storage { + Some(weight.transpose_axes(&[1, 0])?) + } else { + let (group_size, bits) = + quant_params_for(&config_json, prefix, default_group_size, default_bits); + let scales = tensors + .get(&scales_key) + .cloned() + .with_context(|| format!("missing {prefix}.scales"))?; + let biases = tensors + .get(&biases_key) + .cloned() + .with_context(|| format!("missing {prefix}.biases"))?; + if bits == 5 { + Some(cpu_dense_weight_t( + &weight, &scales, &biases, group_size, bits, + )?) + } else { + None + } + }; + let (group_size, bits) = if quantized.is_some() && has_quantized_storage { + quant_params_for(&config_json, prefix, default_group_size, default_bits) + } else { + (0, 0) + }; + let scales = tensors + .get(&scales_key) + .cloned() + .unwrap_or_else(|| array!(0.0f32)); + let biases = tensors + .get(&biases_key) + .cloned() + .unwrap_or_else(|| array!(0.0f32)); + Ok(QuantizedLinear { + weight, + scales, + biases, + bias, + group_size, + bits, + dense_weight_t, + }) + }; + + let load_multi_linear = |prefix: &str| -> Result<QuantizedMultiLinear> { + let (group_size, bits) = + quant_params_for(&config_json, prefix, default_group_size, default_bits); + Ok(QuantizedMultiLinear { + weight: tensors + .get(&format!("{prefix}.weight")) + .cloned() + .with_context(|| format!("missing {prefix}.weight"))?, + scales: tensors + .get(&format!("{prefix}.scales")) + .cloned() + .with_context(|| format!("missing {prefix}.scales"))?, + biases: tensors + .get(&format!("{prefix}.biases")) + .cloned() + .with_context(|| format!("missing {prefix}.biases"))?, + group_size, + bits, + }) + }; + + let load_switch_linear = |prefix: &str| -> Result<QuantizedSwitchLinear> { + let (group_size, bits) = + quant_params_for(&config_json, prefix, default_group_size, default_bits); + Ok(QuantizedSwitchLinear { + weight: tensors + .get(&format!("{prefix}.weight")) + .cloned() + .with_context(|| format!("missing {prefix}.weight"))?, + scales: tensors + .get(&format!("{prefix}.scales")) + .cloned() + .with_context(|| format!("missing {prefix}.scales"))?, + biases: tensors + .get(&format!("{prefix}.biases")) + .cloned() + .with_context(|| format!("missing {prefix}.biases"))?, + bias: tensors.get(&format!("{prefix}.bias")).cloned(), + group_size, + bits, + }) + }; + + let load_lfm2_conv_weight = |prefix: &str| -> Result<Array> { + let weight = tensors + .get(&format!("{prefix}.weight")) + .cloned() + .with_context(|| format!("missing {prefix}.weight"))?; + if weight.ndim() == 3 && weight.shape()[2] > weight.shape()[1] { + Ok(weight.transpose_axes(&[0, 2, 1])?) + } else { + Ok(weight) + } + }; + + let (embed_group_size, embed_bits) = quant_params_for( + &config_json, + &format!("{}.embed_tokens", prefixes.model), + default_group_size, + default_bits, + ); + let embed_weight = tensors + .get(&format!("{}.embed_tokens.weight", prefixes.model)) + .cloned() + .with_context(|| format!("missing {}.embed_tokens.weight", prefixes.model))?; + let embed_scales = tensors + .get(&format!("{}.embed_tokens.scales", prefixes.model)) + .cloned() + .unwrap_or_else(|| array!(0.0f32)); + let embed_biases = tensors + .get(&format!("{}.embed_tokens.biases", prefixes.model)) + .cloned() + .unwrap_or_else(|| array!(0.0f32)); + let embed_dense_weight = quantized.is_none().then(|| embed_weight.clone()); + let embed_dense_weight_t = if quantized.is_none() { + Some(embed_weight.transpose_axes(&[1, 0])?) + } else { + None + }; + let embed_tokens = QuantizedEmbedding { + weight: embed_weight, + scales: embed_scales, + biases: embed_biases, + group_size: embed_group_size, + bits: embed_bits, + dense_weight: embed_dense_weight, + dense_weight_t: embed_dense_weight_t, + }; + let embed_scale = if arch.uses_gemma_scaled_embeddings() { + (config.hidden_size as f32).sqrt() + } else { + 1.0 + }; + let embed_tokens_per_layer = if arch.is_gemma4() { + let (group_size, bits) = quant_params_for( + &config_json, + &format!("{}.embed_tokens_per_layer", prefixes.model), + default_group_size, + default_bits, + ); + Some(QuantizedEmbedding { + weight: tensors + .get(&format!("{}.embed_tokens_per_layer.weight", prefixes.model)) + .cloned() + .with_context(|| { + format!("missing {}.embed_tokens_per_layer.weight", prefixes.model) + })?, + scales: tensors + .get(&format!("{}.embed_tokens_per_layer.scales", prefixes.model)) + .cloned() + .unwrap_or_else(|| array!(0.0f32)), + biases: tensors + .get(&format!("{}.embed_tokens_per_layer.biases", prefixes.model)) + .cloned() + .unwrap_or_else(|| array!(0.0f32)), + group_size, + bits, + dense_weight: quantized.is_none().then(|| { + tensors[&format!("{}.embed_tokens_per_layer.weight", prefixes.model)].clone() + }), + dense_weight_t: if quantized.is_none() { + Some( + tensors[&format!("{}.embed_tokens_per_layer.weight", prefixes.model)] + .transpose_axes(&[1, 0])?, + ) + } else { + None + }, + }) + } else { + None + }; + let per_layer_projection_norm = if arch.is_gemma4() { + Some(rms_norm_kind( + tensors + .get(&format!( + "{}.per_layer_projection_norm.weight", + prefixes.model + )) + .cloned() + .with_context(|| { + format!( + "missing {}.per_layer_projection_norm.weight", + prefixes.model + ) + })?, + config.rms_norm_eps, + false, + )) + } else { + None + }; + let per_layer_model_projection = if arch.is_gemma4() { + Some(load_qlinear(&format!( + "{}.per_layer_model_projection", + prefixes.model + ))?) + } else { + None + }; + + let norm = if arch.is_olmo() { + layer_norm_kind(1e-5) + } else { + rms_norm_kind( + if arch.is_lfm2() { + tensors + .get(&format!("{}.embedding_norm.weight", prefixes.model)) + .cloned() + .with_context(|| { + format!("missing {}.embedding_norm.weight", prefixes.model) + })? + } else { + tensors + .get(&format!("{}.norm.weight", prefixes.model)) + .cloned() + .with_context(|| format!("missing {}.norm.weight", prefixes.model))? + }, + config.block_norm_eps.unwrap_or(config.rms_norm_eps), + arch.uses_gemma_norm_offset(), + ) + }; + + let head_dim = config + .head_dim + .unwrap_or_else(|| config.hidden_size / config.num_attention_heads); + let first_kv_shared_layer_idx = config + .num_kv_shared_layers + .map(|n| (config.num_hidden_layers - n).max(0) as usize) + .unwrap_or(config.num_hidden_layers as usize); + let non_shared_layer_types = config + .layer_types + .as_ref() + .map(|types| types[..first_kv_shared_layer_idx.min(types.len())].to_vec()); + + let mut layers = Vec::new(); + for i in 0..config.num_hidden_layers { + let p = format!("{}.layers.{i}", prefixes.model); + let layer_type = config + .layer_types + .as_ref() + .and_then(|types| types.get(i as usize)) + .map(String::as_str); + if arch.is_deepseek_v3() { + layers.push(build_deepseek_v3_layer( + &tensors, + &p, + i, + &config, + &load_qlinear, + &load_multi_linear, + &load_switch_linear, + )?); + continue; + } + if arch.is_lfm2() { + layers.push(build_lfm2_layer( + &tensors, + &p, + i, + &config, + head_dim, + &load_qlinear, + &load_lfm2_conv_weight, + )?); + continue; + } + if arch.is_kimi_linear() { + layers.push(build_kimi_linear_layer( + &tensors, + &p, + i, + &config, + &load_qlinear, + &load_multi_linear, + &load_switch_linear, + &load_lfm2_conv_weight, + )?); + continue; + } + if arch.is_gpt_oss() { + let window_size = + attention_window_size_for_layer(arch, &config, i as usize, layer_type)?; + layers.push(build_gpt_oss_layer( + &tensors, + &p, + &config, + head_dim, + window_size, + &load_qlinear, + &load_switch_linear, + )?); + continue; + } + layers.push(build_standard_layer( + &tensors, + &p, + i as usize, + arch, + &config, + layer_type, + head_dim, + rope_traditional, + non_shared_layer_types.as_deref(), + &load_qlinear, + &attention_window_size_for_layer, + &kv_shared_source_for_layer, + )?); + } + + let lm_head = if config.tie_word_embeddings { + None + } else if let Some(prefix) = prefixes.lm_head.as_deref() { + if tensors.contains_key(&format!("{prefix}.weight")) { + Some(load_qlinear(prefix)?) + } else { + None + } + } else { + None + }; + + let (tokenizer, tokenizer_spacing_patch) = load_tokenizer(dir, &config_json)?; + let prompt_template = crate::mlx::template::PromptTemplate::detect(dir, &config_json); + + Ok(MlxModel { + embed_tokens, + embed_scale, + embed_tokens_per_layer, + embed_tokens_per_layer_scale: arch.is_gemma4().then(|| { + config + .hidden_size_per_layer_input + .map(|dim| (dim as f32).sqrt()) + .unwrap_or(1.0) + }), + per_layer_projection_norm, + per_layer_model_projection, + per_layer_model_projection_scale: arch + .is_gemma4() + .then_some((config.hidden_size as f32).powf(-0.5)), + per_layer_input_scale: arch.is_gemma4().then_some(2.0f32.powf(-0.5)), + layers, + norm, + lm_head, + final_logit_softcapping: config.final_logit_softcapping, + config, + tokenizer, + tokenizer_spacing_patch, + prompt_template, + reasoning_family: reasoning_family(&config_json), + architecture: arch, + tokenwise_prefill: arch.is_gemma2() || arch.is_gemma3() || arch.is_gemma4(), + cacheless_generation: arch.is_gemma2() + || arch.is_gpt_oss() + || arch.is_kimi_linear() + || arch.is_lfm2(), + prompt_cache_reuse: !arch.is_gemma4(), + }) + } +} diff --git a/mesh-llm/src/mlx/model/mlp.rs b/mesh-llm/src/mlx/model/mlp.rs new file mode 100644 index 00000000..15c8ff39 --- /dev/null +++ b/mesh-llm/src/mlx/model/mlp.rs @@ -0,0 +1,262 @@ +use super::*; +use mlx_rs::array; + +fn expert_slice_2d(array: &Array, expert: i32) -> Result<Array> { + Ok(array + .take_axis(&Array::from_int(expert), 0)? + .reshape(&[array.shape()[1], array.shape()[2]])?) +} + +pub(crate) struct QuantizedSwitchLinear { + pub(super) weight: Array, + pub(super) scales: Array, + pub(super) biases: Array, + pub(super) bias: Option<Array>, + pub(super) group_size: i32, + pub(super) bits: i32, +} + +impl QuantizedSwitchLinear { + pub(super) fn forward_single(&self, x: &Array, expert: i32) -> Result<Array> { + let out = mlx_rs::ops::quantized_matmul( + x, + &expert_slice_2d(&self.weight, expert)?, + &expert_slice_2d(&self.scales, expert)?, + &expert_slice_2d(&self.biases, expert)?, + true, + self.group_size, + self.bits, + )?; + Ok(if let Some(bias) = &self.bias { + let bias = bias + .take_axis(&Array::from_int(expert), 0)? + .reshape(&[1, bias.shape()[1]])?; + out.add(&bias)? + } else { + out + }) + } +} + +pub(crate) struct MLP { + pub(super) gate_up_proj: Option<QuantizedLinear>, + pub(super) gate_proj: Option<QuantizedLinear>, + pub(super) up_proj: Option<QuantizedLinear>, + pub(super) down_proj: QuantizedLinear, + pub(super) activation: Activation, +} + +#[derive(Clone, Copy)] +pub(crate) enum Activation { + Silu, + GeluApproximate, +} + +impl MLP { + pub(super) fn forward(&self, x: &Array) -> Result<Array> { + let (gate, up) = if let Some(gate_up_proj) = &self.gate_up_proj { + let gate_up = gate_up_proj.forward(x)?; + let hidden = gate_up.shape()[gate_up.shape().len() - 1] / 2; + let gate = gate_up.index((std::ops::RangeFull, std::ops::RangeFull, 0..hidden)); + let up = gate_up.index(( + std::ops::RangeFull, + std::ops::RangeFull, + hidden..(hidden * 2), + )); + (gate, up) + } else { + ( + self.gate_proj + .as_ref() + .context("missing gate_proj for unfused MLP")? + .forward(x)?, + self.up_proj + .as_ref() + .context("missing up_proj for unfused MLP")? + .forward(x)?, + ) + }; + let gate = match self.activation { + Activation::Silu => &mlx_rs::ops::sigmoid(&gate)? * &gate, + Activation::GeluApproximate => mlx_rs::nn::gelu_approximate(&gate)?, + }; + self.down_proj.forward(&(&gate * &up)) + } +} + +pub(crate) struct DeepseekV3MoE { + pub(super) switch_gate_proj: QuantizedSwitchLinear, + pub(super) switch_up_proj: QuantizedSwitchLinear, + pub(super) switch_down_proj: QuantizedSwitchLinear, + pub(super) gate_weight: Array, + pub(super) gate_bias: Array, + pub(super) top_k: i32, + pub(super) n_group: i32, + pub(super) topk_group: i32, + pub(super) routed_scaling_factor: f32, + pub(super) norm_topk_prob: bool, + pub(super) shared_experts: Option<MLP>, +} + +impl DeepseekV3MoE { + fn gate(&self, x: &Array) -> Result<(Array, Array)> { + let mut scores = mlx_rs::ops::matmul(x, &self.gate_weight.transpose_axes(&[1, 0])?)?; + scores = mlx_rs::ops::sigmoid(&scores.as_dtype(Dtype::Float32)?)?; + let orig_scores = scores.clone(); + scores = scores.add(&self.gate_bias)?; + + if self.n_group > 1 { + let experts_per_group = scores.shape()[scores.shape().len() - 1] / self.n_group; + let scores_grouped = scores.reshape(&[-1, self.n_group, experts_per_group])?; + let top2 = mlx_rs::ops::indexing::topk_axis(&scores_grouped, 2, -1)?; + let group_scores = top2.sum_axes(&[-1], true)?; + let k = self.n_group - self.topk_group; + let group_idx = mlx_rs::ops::argpartition_axis(&group_scores, k - 1, -2)?.index(( + std::ops::RangeFull, + ..k, + std::ops::RangeFull, + )); + let scores_grouped = mlx_rs::ops::indexing::put_along_axis( + &scores_grouped, + &group_idx, + &array!(0.0f32), + -2, + )?; + scores = scores_grouped.reshape(&[-1, self.gate_weight.shape()[0]])?; + } + + let inds = mlx_rs::ops::argpartition_axis( + &scores.multiply(&array!(-1.0f32))?, + self.top_k - 1, + -1, + )? + .index((std::ops::RangeFull, ..self.top_k)); + let mut probs = mlx_rs::ops::indexing::take_along_axis(&orig_scores, &inds, -1)? + .as_dtype(Dtype::Float32)?; + if self.top_k > 1 && self.norm_topk_prob { + probs = probs.divide(&probs.sum_axes(&[-1], true)?)?; + } + probs = probs.multiply(&array!(self.routed_scaling_factor))?; + Ok((inds, probs)) + } + + fn switch_forward_single(&self, x: &Array, expert: i32) -> Result<Array> { + let x_up = self.switch_up_proj.forward_single(x, expert)?; + let x_gate = self.switch_gate_proj.forward_single(x, expert)?; + let activated = &mlx_rs::ops::sigmoid(&x_gate)? * &x_gate; + self.switch_down_proj + .forward_single(&activated.multiply(&x_up)?, expert) + } + + pub(super) fn forward(&self, x: &Array) -> Result<Array> { + let b = x.shape()[0]; + let l = x.shape()[1]; + let hidden = x.shape()[2]; + let flat = x.reshape(&[b * l, hidden])?; + let (inds, scores) = self.gate(&flat)?; + mlx_rs::transforms::eval([&inds, &scores])?; + let inds_slice = inds.as_slice::<u32>(); + let scores_slice = scores.as_slice::<f32>(); + let mut outputs = Vec::with_capacity((b * l) as usize); + for token_idx in 0..(b * l) { + let x_tok = flat.index((token_idx..token_idx + 1, std::ops::RangeFull)); + let mut token_out: Option<Array> = None; + for expert_slot in 0..self.top_k { + let offset = (token_idx * self.top_k + expert_slot) as usize; + let expert = inds_slice[offset] as i32; + let score = scores_slice[offset]; + let routed = self + .switch_forward_single(&x_tok, expert)? + .multiply(&array!(score))?; + token_out = Some(match token_out { + Some(acc) => acc.add(&routed)?, + None => routed, + }); + } + let token_out = if let Some(shared) = &self.shared_experts { + token_out.unwrap().add(&shared.forward(&x_tok)?)? + } else { + token_out.unwrap() + }; + outputs.push(token_out); + } + let output_refs: Vec<&Array> = outputs.iter().collect(); + Ok(mlx_rs::ops::concatenate_axis(&output_refs, 0)?.reshape(&[b, l, hidden])?) + } +} + +pub(crate) struct GptOssMoE { + pub(super) switch_gate_proj: QuantizedSwitchLinear, + pub(super) switch_up_proj: QuantizedSwitchLinear, + pub(super) switch_down_proj: QuantizedSwitchLinear, + pub(super) router: QuantizedLinear, + pub(super) top_k: i32, +} + +impl GptOssMoE { + fn switch_forward_single(&self, x: &Array, expert: i32) -> Result<Array> { + let x_linear = self.switch_up_proj.forward_single(x, expert)?; + let x_glu = self.switch_gate_proj.forward_single(x, expert)?; + let x_glu = mlx_rs::ops::clip(&x_glu, ((), 7.0f32))?; + let x_linear = mlx_rs::ops::clip(&x_linear, (-7.0f32, 7.0f32))?; + let out_glu = + x_glu.multiply(&mlx_rs::ops::sigmoid(&x_glu.multiply(&array!(1.702f32))?)?)?; + let activated = out_glu.multiply(&x_linear.add(&array!(1.0f32))?)?; + self.switch_down_proj.forward_single(&activated, expert) + } + + pub(super) fn forward(&self, x: &Array) -> Result<Array> { + let b = x.shape()[0]; + let l = x.shape()[1]; + let hidden = x.shape()[2]; + let flat = x.reshape(&[b * l, hidden])?; + let router_logits = self.router.forward(&flat)?.as_dtype(Dtype::Float32)?; + let inds = mlx_rs::ops::argpartition_axis( + &router_logits.multiply(&array!(-1.0f32))?, + self.top_k - 1, + -1, + )? + .index((std::ops::RangeFull, ..self.top_k)); + let weights = mlx_rs::ops::indexing::take_along_axis(&router_logits, &inds, -1)?; + let weights = mlx_rs::ops::softmax_axis(&weights, -1, true)?; + mlx_rs::transforms::eval([&inds, &weights])?; + let inds_slice = inds.as_slice::<u32>(); + let weights_slice = weights.as_slice::<f32>(); + let mut outputs = Vec::with_capacity((b * l) as usize); + for token_idx in 0..(b * l) { + let x_tok = flat.index((token_idx..token_idx + 1, std::ops::RangeFull)); + let mut token_out: Option<Array> = None; + for expert_slot in 0..self.top_k { + let offset = (token_idx * self.top_k + expert_slot) as usize; + let expert = inds_slice[offset] as i32; + let weight = weights_slice[offset]; + let routed = self + .switch_forward_single(&x_tok, expert)? + .multiply(&array!(weight))?; + token_out = Some(match token_out { + Some(acc) => acc.add(&routed)?, + None => routed, + }); + } + outputs.push(token_out.context("gpt-oss moe produced no experts")?); + } + let output_refs: Vec<&Array> = outputs.iter().collect(); + Ok(mlx_rs::ops::concatenate_axis(&output_refs, 0)?.reshape(&[b, l, hidden])?) + } +} + +pub(crate) enum MlpKind { + Dense(MLP), + DeepseekV3MoE(DeepseekV3MoE), + GptOssMoE(GptOssMoE), +} + +impl MlpKind { + pub(super) fn forward(&self, x: &Array) -> Result<Array> { + match self { + Self::Dense(mlp) => mlp.forward(x), + Self::DeepseekV3MoE(moe) => moe.forward(x), + Self::GptOssMoE(moe) => moe.forward(x), + } + } +} diff --git a/mesh-llm/src/mlx/model/mod.rs b/mesh-llm/src/mlx/model/mod.rs new file mode 100644 index 00000000..f0364e22 --- /dev/null +++ b/mesh-llm/src/mlx/model/mod.rs @@ -0,0 +1,311 @@ +//! Qwen2/Llama-style transformer model running on MLX via mlx-rs. +//! +//! Loads quantized safetensors and runs inference entirely on Metal GPU. +//! No Python, no subprocess β€” just Rust + MLX C library. + +mod artifacts; +mod attention; +mod attention_kind; +mod cache; +mod config; +mod embedding; +mod families; +mod family; +mod kimi; +mod layer; +mod lfm2; +mod loader; +mod mlp; +mod primitives; + +use anyhow::{bail, Context, Result}; +#[cfg(test)] +use artifacts::patch_phi3_special_token_whitespace; +use artifacts::TensorPrefixes; +pub use artifacts::{is_mlx_model_dir, mlx_model_dir, TokenizerSpacingPatch}; +use attention::{attention_mask, Attention, DeepseekV3Attention}; +use attention_kind::AttentionKind; +pub use cache::KVCache; +use cache::{CachedKv, QuantizedCacheArrays}; +#[cfg(test)] +use config::effective_text_config_json; +use config::experimental_quantized_kv_config; +pub(crate) use config::ModelConfig; +use embedding::{quant_params_for, QuantizedEmbedding}; +#[cfg(test)] +use family::config_supports_mlx; +pub use family::ReasoningFamily; +use family::{reasoning_family, ModelArchitecture}; +use kimi::{KimiDeltaAttention, KimiMlaAttention, KimiShortConv}; +use layer::Layer; +use lfm2::Lfm2ShortConv; +use mlp::{Activation, DeepseekV3MoE, GptOssMoE, MlpKind, QuantizedSwitchLinear, MLP}; +use mlx_rs::array; +use mlx_rs::ops::indexing::{IndexOp, TryIndexMutOp}; +use mlx_rs::ops::{conv1d, pad}; +use mlx_rs::Array; +use mlx_rs::Dtype; +use primitives::{ + cpu_dense_weight_t, layer_norm_kind, quantize_stacked_weights, rms_norm_kind, unit_rms_norm, + NormKind, QuantizedLinear, QuantizedMultiLinear, RMSNorm, +}; + +#[derive(Debug, serde::Deserialize)] +pub struct QuantConfig { + pub group_size: i32, + pub bits: i32, +} + +#[derive(Debug, serde::Deserialize)] +struct QuantOverride { + #[serde(default)] + group_size: Option<i32>, + #[serde(default)] + bits: Option<i32>, +} + +// ── Full model ── + +pub struct MlxModel { + embed_tokens: QuantizedEmbedding, + embed_scale: f32, + embed_tokens_per_layer: Option<QuantizedEmbedding>, + embed_tokens_per_layer_scale: Option<f32>, + per_layer_projection_norm: Option<NormKind>, + per_layer_model_projection: Option<QuantizedLinear>, + per_layer_model_projection_scale: Option<f32>, + per_layer_input_scale: Option<f32>, + layers: Vec<Layer>, + norm: NormKind, + lm_head: Option<QuantizedLinear>, + final_logit_softcapping: Option<f32>, + pub config: ModelConfig, + pub tokenizer: tokenizers::Tokenizer, + pub tokenizer_spacing_patch: Option<TokenizerSpacingPatch>, + pub prompt_template: crate::mlx::template::PromptTemplate, + pub reasoning_family: ReasoningFamily, + architecture: ModelArchitecture, + tokenwise_prefill: bool, + cacheless_generation: bool, + prompt_cache_reuse: bool, +} + +impl MlxModel { + /// Run a forward pass. Input shape: [1, seq_len] of u32 token IDs. + /// Returns logits [1, seq_len, vocab_size]. + pub fn forward(&self, tokens: &Array, caches: &mut [KVCache]) -> Result<Array> { + let mut h = self.embed_tokens.forward(tokens)?; + if self.embed_scale != 1.0 { + h = h.multiply(&array!(self.embed_scale))?; + } + let per_layer_inputs = if let ( + Some(embed_tokens_per_layer), + Some(embed_tokens_per_layer_scale), + Some(per_layer_projection_norm), + Some(per_layer_model_projection), + Some(per_layer_model_projection_scale), + Some(per_layer_input_scale), + Some(hidden_size_per_layer_input), + ) = ( + &self.embed_tokens_per_layer, + self.embed_tokens_per_layer_scale, + &self.per_layer_projection_norm, + &self.per_layer_model_projection, + self.per_layer_model_projection_scale, + self.per_layer_input_scale, + self.config.hidden_size_per_layer_input, + ) { + let per_layer_inputs = embed_tokens_per_layer + .forward(tokens)? + .multiply(&array!(embed_tokens_per_layer_scale))? + .reshape(&[ + tokens.shape()[0], + tokens.shape()[1], + self.config.num_hidden_layers, + hidden_size_per_layer_input, + ])?; + let per_layer_projection = per_layer_model_projection + .forward(&h)? + .multiply(&array!(per_layer_model_projection_scale))? + .reshape(&[ + h.shape()[0], + h.shape()[1], + self.config.num_hidden_layers, + hidden_size_per_layer_input, + ])?; + let per_layer_projection = per_layer_projection_norm.forward(&per_layer_projection)?; + Some( + (&per_layer_projection + &per_layer_inputs) + .multiply(&array!(per_layer_input_scale))?, + ) + } else { + None + }; + for (i, layer) in self.layers.iter().enumerate() { + let layer_input = per_layer_inputs.as_ref().map(|inputs| { + inputs.index(( + std::ops::RangeFull, + std::ops::RangeFull, + i as i32, + std::ops::RangeFull, + )) + }); + let (before, current_and_after) = caches.split_at_mut(i); + let current_cache = &mut current_and_after[0]; + let shared_cache = layer + .attn + .kv_shared_source() + .and_then(|source| before.get(source)); + h = layer.forward(&h, layer_input.as_ref(), current_cache, shared_cache)?; + } + let h = self.norm.forward(&h)?; + + let h_for_logits = if matches!(self.norm, NormKind::Layer(_)) { + h.as_dtype(Dtype::Float32)? + } else { + h.clone() + }; + let logits = if let Some(ref lm_head) = self.lm_head { + lm_head.forward(&h_for_logits)? + } else { + self.embed_tokens.as_linear().forward(&h_for_logits)? + }; + if let Some(softcap) = self.final_logit_softcapping { + let scaled = logits.divide(&array!(softcap))?; + Ok(mlx_rs::ops::tanh(&scaled)?.multiply(&array!(softcap))?) + } else { + Ok(logits) + } + } + + pub fn forward_no_cache(&self, tokens: &Array) -> Result<Array> { + let mut h = self.embed_tokens.forward(tokens)?; + if self.embed_scale != 1.0 { + h = h.multiply(&array!(self.embed_scale))?; + } + let per_layer_inputs = if let ( + Some(embed_tokens_per_layer), + Some(embed_tokens_per_layer_scale), + Some(per_layer_projection_norm), + Some(per_layer_model_projection), + Some(per_layer_model_projection_scale), + Some(per_layer_input_scale), + Some(hidden_size_per_layer_input), + ) = ( + &self.embed_tokens_per_layer, + self.embed_tokens_per_layer_scale, + &self.per_layer_projection_norm, + &self.per_layer_model_projection, + self.per_layer_model_projection_scale, + self.per_layer_input_scale, + self.config.hidden_size_per_layer_input, + ) { + let per_layer_inputs = embed_tokens_per_layer + .forward(tokens)? + .multiply(&array!(embed_tokens_per_layer_scale))? + .reshape(&[ + tokens.shape()[0], + tokens.shape()[1], + self.config.num_hidden_layers, + hidden_size_per_layer_input, + ])?; + let per_layer_projection = per_layer_model_projection + .forward(&h)? + .multiply(&array!(per_layer_model_projection_scale))? + .reshape(&[ + h.shape()[0], + h.shape()[1], + self.config.num_hidden_layers, + hidden_size_per_layer_input, + ])?; + let per_layer_projection = per_layer_projection_norm.forward(&per_layer_projection)?; + Some( + (&per_layer_projection + &per_layer_inputs) + .multiply(&array!(per_layer_input_scale))?, + ) + } else { + None + }; + for (i, layer) in self.layers.iter().enumerate() { + let layer_input = per_layer_inputs.as_ref().map(|inputs| { + inputs.index(( + std::ops::RangeFull, + std::ops::RangeFull, + i as i32, + std::ops::RangeFull, + )) + }); + h = layer.forward_no_cache(&h, layer_input.as_ref())?; + } + let h = self.norm.forward(&h)?; + + let h_for_logits = if matches!(self.norm, NormKind::Layer(_)) { + h.as_dtype(Dtype::Float32)? + } else { + h.clone() + }; + let logits = if let Some(ref lm_head) = self.lm_head { + lm_head.forward(&h_for_logits)? + } else { + self.embed_tokens.as_linear().forward(&h_for_logits)? + }; + if let Some(softcap) = self.final_logit_softcapping { + let scaled = logits.divide(&array!(softcap))?; + Ok(mlx_rs::ops::tanh(&scaled)?.multiply(&array!(softcap))?) + } else { + Ok(logits) + } + } + + pub fn new_caches(&self) -> Vec<KVCache> { + let quantized_kv = experimental_quantized_kv_config(); + self.layers + .iter() + .map(|layer| { + if let Some(window_size) = layer.attn.sliding_window_size() { + KVCache::new_rotating(window_size, 0) + } else if layer.attn.kv_shared_source().is_some() { + KVCache::new() + } else if let Some((group_size, bits, min_dense_tokens)) = quantized_kv { + KVCache::new_quantized(group_size, bits, min_dense_tokens) + } else { + KVCache::new() + } + }) + .collect() + } + + pub fn tokenwise_prefill(&self) -> bool { + self.tokenwise_prefill + } + + pub fn can_replay_prompt_logits(&self) -> bool { + !self.architecture.is_gemma3() && !self.architecture.is_gemma4() + } + + pub fn cacheless_generation(&self) -> bool { + self.cacheless_generation + } + + pub fn prompt_cache_reuse(&self) -> bool { + self.prompt_cache_reuse && experimental_quantized_kv_config().is_none() + } +} + +/// Argmax over the last position's logits. Returns the token ID. +pub fn argmax_last(logits: &Array) -> Result<u32> { + let shape = logits.shape(); + let flat = if shape.len() == 3 { + let last_idx = (shape[1] - 1) as i32; + let idx = Array::from_int(last_idx); + logits.take_axis(&idx, 1)?.reshape(&[-1])? + } else { + logits.reshape(&[-1])? + }; + let token = mlx_rs::ops::indexing::argmax(&flat, false)?; + mlx_rs::transforms::eval([&token])?; + Ok(token.as_slice::<u32>()[0]) +} + +#[cfg(test)] +mod tests; diff --git a/mesh-llm/src/mlx/model/primitives.rs b/mesh-llm/src/mlx/model/primitives.rs new file mode 100644 index 00000000..75180be6 --- /dev/null +++ b/mesh-llm/src/mlx/model/primitives.rs @@ -0,0 +1,216 @@ +use anyhow::{bail, Result}; +use mlx_rs::array; +use mlx_rs::ops::indexing::IndexOp; +use mlx_rs::ops::{dequantize_device, quantize}; +use mlx_rs::Array; +use mlx_rs::{Dtype, StreamOrDevice}; + +pub struct QuantizedLinear { + pub(super) weight: Array, + pub(super) scales: Array, + pub(super) biases: Array, + pub(super) bias: Option<Array>, + pub(super) group_size: i32, + pub(super) bits: i32, + pub(super) dense_weight_t: Option<Array>, +} + +impl QuantizedLinear { + pub fn forward(&self, x: &Array) -> Result<Array> { + let out = if let Some(dense_weight_t) = &self.dense_weight_t { + mlx_rs::ops::matmul(x, dense_weight_t)? + } else { + mlx_rs::ops::quantized_matmul( + x, + &self.weight, + &self.scales, + &self.biases, + true, + self.group_size, + self.bits, + )? + }; + Ok(if let Some(ref bias) = self.bias { + &out + bias + } else { + out + }) + } +} + +pub(super) fn cpu_dense_weight_t( + weight: &Array, + scales: &Array, + biases: &Array, + group_size: i32, + bits: i32, +) -> Result<Array> { + let dense_cpu = dequantize_device( + weight, + scales, + biases, + group_size, + bits, + StreamOrDevice::cpu(), + )?; + let dense_cpu = if dense_cpu.dtype() == Dtype::Float32 { + dense_cpu + } else if matches!(dense_cpu.dtype(), Dtype::Bfloat16 | Dtype::Float16) { + dense_cpu.as_dtype(Dtype::Float32)? + } else { + bail!( + "unsupported dense dequantized dtype for CPU fallback: {:?}", + dense_cpu.dtype() + ); + }; + let dense = Array::from_slice(dense_cpu.as_slice::<f32>(), dense_cpu.shape()); + + Ok(dense.transpose_axes(&[1, 0])?) +} + +pub struct RMSNorm { + pub(super) weight: Array, + pub(super) eps: f32, + pub(super) add_unit_offset: bool, +} + +impl RMSNorm { + pub fn forward(&self, x: &Array) -> Result<Array> { + if self.add_unit_offset { + let one = array!(1.0f32).as_dtype(self.weight.dtype())?; + let weight = self.weight.add(&one)?; + Ok(mlx_rs::fast::rms_norm(x, &weight, self.eps)?) + } else { + Ok(mlx_rs::fast::rms_norm(x, &self.weight, self.eps)?) + } + } +} + +pub(super) fn unit_rms_norm(x: &Array, eps: f32) -> Result<Array> { + let width = x.shape()[x.shape().len() - 1]; + let weight = mlx_rs::ops::ones::<f32>(&[width])?.as_dtype(x.dtype())?; + Ok(mlx_rs::fast::rms_norm(x, &weight, eps)?) +} + +pub struct LayerNorm { + eps: f32, +} + +impl LayerNorm { + pub fn forward(&self, x: &Array) -> Result<Array> { + Ok(mlx_rs::fast::layer_norm( + x, + None::<&Array>, + None::<&Array>, + self.eps, + )?) + } +} + +pub enum NormKind { + Rms(RMSNorm), + Layer(LayerNorm), +} + +impl NormKind { + pub fn forward(&self, x: &Array) -> Result<Array> { + match self { + Self::Rms(norm) => norm.forward(x), + Self::Layer(norm) => norm.forward(x), + } + } +} + +impl From<RMSNorm> for NormKind { + fn from(value: RMSNorm) -> Self { + Self::Rms(value) + } +} + +pub(super) fn rms_norm_kind(weight: Array, eps: f32, add_unit_offset: bool) -> NormKind { + NormKind::Rms(RMSNorm { + weight, + eps, + add_unit_offset, + }) +} + +pub(super) fn layer_norm_kind(eps: f32) -> NormKind { + NormKind::Layer(LayerNorm { eps }) +} + +pub struct QuantizedMultiLinear { + pub(super) weight: Array, + pub(super) scales: Array, + pub(super) biases: Array, + pub(super) group_size: i32, + pub(super) bits: i32, +} + +impl QuantizedMultiLinear { + pub(super) fn forward(&self, x: &Array, transpose: bool) -> Result<Array> { + let num_heads = self.weight.shape()[0]; + let mut outputs = Vec::with_capacity(num_heads as usize); + for head in 0..num_heads { + let idx = Array::from_int(head); + let w = self + .weight + .take_axis(&idx, 0)? + .reshape(&[self.weight.shape()[1], self.weight.shape()[2]])?; + let s = self + .scales + .take_axis(&idx, 0)? + .reshape(&[self.scales.shape()[1], self.scales.shape()[2]])?; + let b = self + .biases + .take_axis(&idx, 0)? + .reshape(&[self.biases.shape()[1], self.biases.shape()[2]])?; + let xh = x.index(( + std::ops::RangeFull, + head, + std::ops::RangeFull, + std::ops::RangeFull, + )); + let out = mlx_rs::ops::quantized_matmul( + &xh, + &w, + &s, + &b, + transpose, + self.group_size, + self.bits, + )?; + outputs.push(out.expand_dims(1)?); + } + let output_refs: Vec<&Array> = outputs.iter().collect(); + Ok(mlx_rs::ops::concatenate_axis(&output_refs, 1)?) + } +} + +pub(super) fn quantize_stacked_weights( + dense: &Array, + group_size: i32, + bits: i32, +) -> Result<(Array, Array, Array)> { + let num_heads = dense.shape()[0]; + let mut q_weights = Vec::with_capacity(num_heads as usize); + let mut q_scales = Vec::with_capacity(num_heads as usize); + let mut q_biases = Vec::with_capacity(num_heads as usize); + for head in 0..num_heads { + let slice = dense + .index((head, std::ops::RangeFull, std::ops::RangeFull)) + .reshape(&[dense.shape()[1], dense.shape()[2]])?; + let (w, s, b) = quantize(&slice, group_size, bits)?; + q_weights.push(w.expand_dims(0)?); + q_scales.push(s.expand_dims(0)?); + q_biases.push(b.expand_dims(0)?); + } + let q_weight_refs: Vec<&Array> = q_weights.iter().collect(); + let q_scale_refs: Vec<&Array> = q_scales.iter().collect(); + let q_bias_refs: Vec<&Array> = q_biases.iter().collect(); + Ok(( + mlx_rs::ops::concatenate_axis(&q_weight_refs, 0)?, + mlx_rs::ops::concatenate_axis(&q_scale_refs, 0)?, + mlx_rs::ops::concatenate_axis(&q_bias_refs, 0)?, + )) +} diff --git a/mesh-llm/src/mlx/model/tests.rs b/mesh-llm/src/mlx/model/tests.rs new file mode 100644 index 00000000..b7c2a082 --- /dev/null +++ b/mesh-llm/src/mlx/model/tests.rs @@ -0,0 +1,2304 @@ +use super::attention::apply_rope; +use super::config::{ + attention_window_size_for_layer, kv_shared_source_for_layer, normalized_model_config_json, +}; +use super::family::{ + ensure_supported_mlx_model, model_architecture, uses_traditional_rope, ModelArchitecture, +}; +use super::*; +use serde_json::Value; +use serial_test::serial; +use std::collections::HashMap; + +#[test] +fn mlx_model_dir_accepts_directory_and_known_files() { + let root = std::env::temp_dir().join(format!("mesh-llm-mlx-test-{}", std::process::id())); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write(root.join("config.json"), "{}").unwrap(); + std::fs::write(root.join("tokenizer.json"), "{}").unwrap(); + std::fs::write(root.join("model.safetensors"), b"12345678").unwrap(); + + assert_eq!(mlx_model_dir(&root), Some(root.as_path())); + assert_eq!( + mlx_model_dir(&root.join("config.json")), + Some(root.as_path()) + ); + assert_eq!( + mlx_model_dir(&root.join("model.safetensors")), + Some(root.as_path()) + ); + + std::fs::remove_file(root.join("model.safetensors")).unwrap(); + std::fs::write(root.join("model-00001-of-00002.safetensors"), b"12345678").unwrap(); + std::fs::write(root.join("model-00002-of-00002.safetensors"), b"12345678").unwrap(); + assert_eq!( + mlx_model_dir(&root.join("model-00001-of-00002.safetensors")), + Some(root.as_path()) + ); +} + +#[test] +fn config_supports_known_mlx_architectures() { + let deepseek: Value = serde_json::json!({ + "model_type": "deepseek_v3", + "architectures": ["DeepseekV3ForCausalLM"] + }); + let kimi: Value = serde_json::json!({ + "model_type": "kimi_k2", + "architectures": ["DeepseekV3ForCausalLM"] + }); + let glm4: Value = serde_json::json!({ + "model_type": "glm4", + "architectures": ["Glm4ForCausalLM"] + }); + let lfm2: Value = serde_json::json!({ + "model_type": "lfm2", + "architectures": ["Lfm2ForCausalLM"] + }); + let qwen: Value = serde_json::json!({ + "model_type": "qwen2", + "architectures": ["Qwen2ForCausalLM"] + }); + let phi3: Value = serde_json::json!({ + "model_type": "phi3", + "architectures": ["Phi3ForCausalLM"] + }); + let gpt_oss: Value = serde_json::json!({ + "model_type": "gpt_oss", + "architectures": ["GptOssForCausalLM"] + }); + let kimi_linear: Value = serde_json::json!({ + "model_type": "kimi_linear", + "architectures": ["KimiLinearForCausalLM"] + }); + let olmo2: Value = serde_json::json!({ + "model_type": "olmo2", + "architectures": ["Olmo2ForCausalLM"] + }); + let olmo: Value = serde_json::json!({ + "model_type": "olmo", + "architectures": ["OlmoForCausalLM"] + }); + let llama: Value = serde_json::json!({ + "model_type": "llama", + "architectures": ["LlamaForCausalLM"] + }); + let mistral: Value = serde_json::json!({ + "model_type": "mistral", + "architectures": ["MistralForCausalLM"] + }); + let gemma2: Value = serde_json::json!({ + "model_type": "gemma2", + "architectures": ["Gemma2ForCausalLM"] + }); + let gemma3: Value = serde_json::json!({ + "model_type": "gemma3", + "architectures": ["Gemma3ForConditionalGeneration"] + }); + let gemma4: Value = serde_json::json!({ + "model_type": "gemma4", + "architectures": ["Gemma4ForConditionalGeneration"], + "text_config": {"model_type": "gemma4_text"} + }); + + assert!(config_supports_mlx(&deepseek)); + assert!(config_supports_mlx(&kimi)); + assert!(config_supports_mlx(&glm4)); + assert!(config_supports_mlx(&lfm2)); + assert!(config_supports_mlx(&phi3)); + assert!(config_supports_mlx(&qwen)); + assert!(config_supports_mlx(&gpt_oss)); + assert!(config_supports_mlx(&kimi_linear)); + assert!(config_supports_mlx(&olmo)); + assert!(config_supports_mlx(&olmo2)); + assert!(config_supports_mlx(&llama)); + assert!(config_supports_mlx(&mistral)); + assert!(config_supports_mlx(&gemma2)); + assert!(config_supports_mlx(&gemma3)); + assert!(config_supports_mlx(&gemma4)); +} + +#[test] +fn config_rejects_other_reasoning_families_for_runtime_loading() { + let glm: Value = serde_json::json!({ + "model_type": "glm", + "architectures": ["GlmForCausalLM"] + }); + let lfm2: Value = serde_json::json!({ + "model_type": "lfm2_moe", + "architectures": ["Lfm2MoeForCausalLM"] + }); + + assert!(!config_supports_mlx(&glm)); + assert!(!config_supports_mlx(&lfm2)); +} + +#[test] +fn phi3_tokenizer_patch_preserves_role_marker_whitespace() { + let config = serde_json::json!({"model_type": "phi3"}); + let mut tokenizer = serde_json::json!({ + "added_tokens": [ + {"content":"<|user|>","rstrip":true}, + {"content":"<|assistant|>","rstrip":true}, + {"content":"<|end|>","rstrip":true}, + {"content":"<|endoftext|>","rstrip":true}, + {"content":"<irrelevant>","rstrip":true} + ] + }); + + patch_phi3_special_token_whitespace(&mut tokenizer, &config); + + let added = tokenizer["added_tokens"].as_array().unwrap(); + assert_eq!(added[0]["rstrip"], Value::Bool(false)); + assert_eq!(added[1]["rstrip"], Value::Bool(false)); + assert_eq!(added[2]["rstrip"], Value::Bool(false)); + assert_eq!(added[3]["rstrip"], Value::Bool(true)); + assert_eq!(added[4]["rstrip"], Value::Bool(true)); +} + +#[test] +fn model_config_honors_explicit_head_dim() { + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "hidden_size": 1024, + "num_hidden_layers": 28, + "intermediate_size": 3072, + "num_attention_heads": 16, + "num_key_value_heads": 8, + "head_dim": 128, + "vocab_size": 151936, + "rms_norm_eps": 0.000001, + "max_position_embeddings": 40960, + "tie_word_embeddings": false, + "quantization": { + "group_size": 64, + "bits": 4 + }, + "eos_token_id": 151645 + })) + .unwrap(); + + let derived = config.hidden_size / config.num_attention_heads; + assert_eq!(derived, 64); + assert_eq!(config.head_dim, Some(128)); + assert_eq!( + config + .head_dim + .unwrap_or_else(|| config.hidden_size / config.num_attention_heads), + 128 + ); +} + +#[test] +fn mistral_is_accepted_as_llama_like_mlx_architecture() { + let root = std::env::temp_dir().join(format!( + "mesh-llm-mlx-mistral-supported-{}", + std::process::id() + )); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + let config = serde_json::json!({ + "model_type": "mistral", + "architectures": ["MistralForCausalLM"] + }); + + ensure_supported_mlx_model(&root, &config).unwrap(); +} + +#[test] +fn olmo_is_accepted_as_mlx_architecture() { + let root = std::env::temp_dir().join(format!( + "mesh-llm-mlx-olmo-supported-{}", + std::process::id() + )); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + let config = serde_json::json!({ + "model_type": "olmo", + "architectures": ["OlmoForCausalLM"] + }); + + ensure_supported_mlx_model(&root, &config).unwrap(); +} + +#[test] +fn mistral_uses_traditional_rope() { + let config = serde_json::json!({ + "model_type": "mistral", + "architectures": ["MistralForCausalLM"] + }); + let explicit = serde_json::json!({ + "model_type": "mistral", + "architectures": ["MistralForCausalLM"], + "rope_traditional": true + }); + let llama = serde_json::json!({ + "model_type": "llama", + "architectures": ["LlamaForCausalLM"] + }); + + assert!(!uses_traditional_rope(&config)); + assert!(uses_traditional_rope(&explicit)); + assert!(!uses_traditional_rope(&llama)); +} + +#[test] +fn unsupported_architecture_error_mentions_model_type() { + let root = + std::env::temp_dir().join(format!("mesh-llm-mlx-unsupported-{}", std::process::id())); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + let config = serde_json::json!({ + "model_type": "starcoder2", + "architectures": ["Starcoder2ForCausalLM"] + }); + + let err = ensure_supported_mlx_model(&root, &config) + .unwrap_err() + .to_string(); + assert!(err.contains("unsupported MLX model architecture")); + assert!(err.contains("model_type=starcoder2")); + assert!(err.contains("Starcoder2ForCausalLM")); +} + +#[test] +fn unsupported_reasoning_family_errors_are_explicit() { + let root = std::env::temp_dir().join(format!( + "mesh-llm-mlx-unsupported-reasoning-{}", + std::process::id() + )); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + + for config in [ + serde_json::json!({ + "model_type": "glm", + "architectures": ["GlmForCausalLM"] + }), + serde_json::json!({ + "model_type": "lfm2_moe", + "architectures": ["Lfm2MoeForCausalLM"] + }), + ] { + let err = ensure_supported_mlx_model(&root, &config) + .unwrap_err() + .to_string(); + assert!(err.contains("unsupported MLX model architecture")); + assert!(err.contains("model_type=")); + assert!(err.contains("architectures=")); + } +} + +#[test] +fn effective_text_config_extracts_gemma3_text_config() { + let config = serde_json::json!({ + "model_type": "gemma3", + "architectures": ["Gemma3ForConditionalGeneration"], + "quantization": {"group_size": 64, "bits": 4}, + "eos_token_id": [1, 106], + "tie_word_embeddings": null, + "head_dim": 256, + "query_pre_attn_scalar": 256, + "rms_norm_eps": 0.000001, + "rope_theta": 1000000, + "max_position_embeddings": 32768, + "hidden_activation": "gelu_pytorch_tanh", + "text_config": { + "model_type": "gemma3_text", + "hidden_size": 1152, + "num_hidden_layers": 26, + "intermediate_size": 6912, + "num_attention_heads": 4, + "num_key_value_heads": 1, + "vocab_size": 262144 + } + }); + + let effective = effective_text_config_json(&config); + let parsed: ModelConfig = serde_json::from_value(effective).unwrap(); + assert_eq!(parsed.hidden_size, 1152); + assert_eq!(parsed.head_dim, Some(256)); + assert_eq!(parsed.query_pre_attn_scalar, Some(256.0)); + assert_eq!( + parsed.hidden_activation.as_deref(), + Some("gelu_pytorch_tanh") + ); + assert!(!parsed.tie_word_embeddings); + assert_eq!(parsed.eos_token_id, vec![1, 106]); +} + +#[test] +fn normalized_gemma3_config_injects_hybrid_attention_defaults() { + let raw = serde_json::json!({ + "model_type": "gemma3", + "architectures": ["Gemma3ForConditionalGeneration"], + "quantization": {"group_size": 64, "bits": 4}, + "eos_token_id": [1, 106], + "tie_word_embeddings": false, + "head_dim": 256, + "query_pre_attn_scalar": 256, + "rms_norm_eps": 0.000001, + "rope_theta": 1000000.0, + "rope_local_base_freq": 10000.0, + "sliding_window": 512, + "sliding_window_pattern": 3, + "max_position_embeddings": 32768, + "text_config": { + "model_type": "gemma3_text", + "hidden_size": 1152, + "num_hidden_layers": 8, + "intermediate_size": 6912, + "num_attention_heads": 4, + "num_key_value_heads": 1, + "vocab_size": 262144, + "layer_types": null, + "rope_parameters": null, + "use_bidirectional_attention": null + } + }); + + let normalized = normalized_model_config_json(&raw); + let parsed: ModelConfig = serde_json::from_value(normalized.clone()).unwrap(); + + assert_eq!( + normalized + .get("layer_types") + .and_then(Value::as_array) + .unwrap() + .iter() + .map(|value| value.as_str().unwrap()) + .collect::<Vec<_>>(), + vec![ + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + ] + ); + assert_eq!( + parsed + .rope_parameters + .as_ref() + .and_then(|params| params.get("sliding_attention")) + .and_then(|params| params.rope_theta), + Some(10_000.0) + ); + assert_eq!( + parsed + .rope_parameters + .as_ref() + .and_then(|params| params.get("full_attention")) + .and_then(|params| params.rope_theta), + Some(1_000_000.0) + ); + assert_eq!( + normalized + .get("use_bidirectional_attention") + .and_then(Value::as_bool), + Some(false) + ); + assert_eq!( + parsed.layer_types.as_ref().map(Vec::len), + Some(parsed.num_hidden_layers as usize) + ); +} + +#[test] +fn model_architecture_detects_gemma3_from_text_config() { + let config = serde_json::json!({ + "model_type": "gemma3", + "architectures": ["Gemma3ForConditionalGeneration"], + "text_config": {"model_type": "gemma3_text"} + }); + + assert_eq!(model_architecture(&config), ModelArchitecture::Gemma3); +} + +#[test] +fn model_architecture_detects_gemma2() { + let config = serde_json::json!({ + "model_type": "gemma2", + "architectures": ["Gemma2ForCausalLM"] + }); + + assert_eq!(model_architecture(&config), ModelArchitecture::Gemma2); +} + +#[test] +fn model_architecture_detects_glm4() { + let config = serde_json::json!({ + "model_type": "glm4", + "architectures": ["Glm4ForCausalLM"] + }); + + assert_eq!(model_architecture(&config), ModelArchitecture::Glm4); +} + +#[test] +fn model_architecture_detects_lfm2() { + let config = serde_json::json!({ + "model_type": "lfm2", + "architectures": ["Lfm2ForCausalLM"] + }); + + assert_eq!(model_architecture(&config), ModelArchitecture::Lfm2); +} + +#[test] +fn model_architecture_detects_olmo() { + let config = serde_json::json!({ + "model_type": "olmo", + "architectures": ["OlmoForCausalLM"] + }); + + assert_eq!(model_architecture(&config), ModelArchitecture::Olmo); +} + +#[test] +fn model_architecture_detects_deepseek_v3() { + let config = serde_json::json!({ + "model_type": "deepseek_v3", + "architectures": ["DeepseekV3ForCausalLM"] + }); + + assert_eq!(model_architecture(&config), ModelArchitecture::DeepseekV3); +} + +#[test] +fn model_architecture_detects_gpt_oss() { + let config = serde_json::json!({ + "model_type": "gpt_oss", + "architectures": ["GptOssForCausalLM"] + }); + + assert_eq!(model_architecture(&config), ModelArchitecture::GptOss); +} + +#[test] +fn model_architecture_detects_kimi_linear() { + let config = serde_json::json!({ + "model_type": "kimi_linear", + "architectures": ["KimiLinearForCausalLM"] + }); + + assert_eq!(model_architecture(&config), ModelArchitecture::KimiLinear); +} + +#[test] +fn model_architecture_detects_olmo2() { + let config = serde_json::json!({ + "model_type": "olmo2", + "architectures": ["Olmo2ForCausalLM"] + }); + + assert_eq!(model_architecture(&config), ModelArchitecture::Olmo2); +} + +#[test] +fn model_architecture_detects_kimi_k2_as_deepseek_v3_runtime() { + let config = serde_json::json!({ + "model_type": "kimi_k25", + "architectures": ["DeepseekV3ForCausalLM"] + }); + + assert_eq!(model_architecture(&config), ModelArchitecture::DeepseekV3); +} + +#[test] +fn glm4_config_parses_partial_rotary_factor() { + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "model_type": "glm4", + "hidden_size": 4096, + "num_hidden_layers": 40, + "intermediate_size": 13696, + "num_attention_heads": 32, + "num_key_value_heads": 2, + "head_dim": 128, + "vocab_size": 151552, + "rms_norm_eps": 0.00001, + "rope_theta": 10000.0, + "partial_rotary_factor": 0.5, + "max_position_embeddings": 32768, + "tie_word_embeddings": false, + "hidden_act": "silu", + "quantization": { + "group_size": 64, + "bits": 4 + }, + "eos_token_id": 151329 + })) + .unwrap(); + + assert_eq!(config.partial_rotary_factor, Some(0.5)); + assert_eq!(config.head_dim, Some(128)); +} + +#[test] +fn deepseek_v3_config_parses_moe_and_mla_fields() { + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "model_type": "deepseek_v3", + "hidden_size": 7168, + "num_hidden_layers": 61, + "intermediate_size": 18432, + "moe_intermediate_size": 2048, + "num_attention_heads": 128, + "num_key_value_heads": 128, + "n_shared_experts": 1, + "n_routed_experts": 256, + "routed_scaling_factor": 2.5, + "kv_lora_rank": 512, + "q_lora_rank": 1536, + "qk_rope_head_dim": 64, + "qk_nope_head_dim": 128, + "v_head_dim": 128, + "n_group": 8, + "topk_group": 4, + "num_experts_per_tok": 8, + "moe_layer_freq": 1, + "first_k_dense_replace": 3, + "vocab_size": 129280, + "rms_norm_eps": 0.000001, + "rope_theta": 10000.0, + "max_position_embeddings": 163840, + "tie_word_embeddings": false, + "quantization": { + "group_size": 64, + "bits": 4 + }, + "eos_token_id": [0, 1] + })) + .unwrap(); + + assert_eq!(config.moe_intermediate_size, Some(2048)); + assert_eq!(config.n_routed_experts, Some(256)); + assert_eq!(config.kv_lora_rank, Some(512)); + assert_eq!(config.q_lora_rank, Some(1536)); + assert_eq!(config.qk_rope_head_dim, Some(64)); + assert_eq!(config.qk_nope_head_dim, Some(128)); + assert_eq!(config.v_head_dim, Some(128)); + assert_eq!(config.n_group, Some(8)); + assert_eq!(config.topk_group, Some(4)); + assert_eq!(config.num_experts_per_tok, Some(8)); + assert_eq!(config.first_k_dense_replace, Some(3)); +} + +#[test] +fn lfm2_config_parses_conv_and_attention_layout() { + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "model_type": "lfm2", + "hidden_size": 1024, + "num_hidden_layers": 16, + "intermediate_size": 6656, + "num_attention_heads": 16, + "num_key_value_heads": 8, + "vocab_size": 65536, + "rms_norm_eps": 0.00001, + "max_position_embeddings": 128000, + "tie_word_embeddings": false, + "rope_theta": 1000000.0, + "conv_bias": false, + "conv_L_cache": 3, + "block_norm_eps": 0.00001, + "block_dim": 1024, + "block_ff_dim": 6656, + "block_multiple_of": 256, + "block_ffn_dim_multiplier": 1.0, + "block_auto_adjust_ff_dim": true, + "full_attn_idxs": [2, 5, 8, 10, 12, 14], + "quantization": { + "group_size": 64, + "bits": 4 + }, + "eos_token_id": 7 + })) + .unwrap(); + + assert_eq!(config.conv_l_cache, Some(3)); + assert_eq!( + config.full_attn_idxs.as_deref(), + Some(&[2, 5, 8, 10, 12, 14][..]) + ); + assert_eq!(config.block_norm_eps, Some(0.00001)); +} + +#[test] +fn gemma2_config_parses_attention_softcaps() { + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "hidden_size": 2304, + "num_hidden_layers": 26, + "intermediate_size": 9216, + "num_attention_heads": 8, + "num_key_value_heads": 4, + "head_dim": 256, + "query_pre_attn_scalar": 256, + "vocab_size": 256000, + "rms_norm_eps": 0.000001, + "max_position_embeddings": 8192, + "tie_word_embeddings": false, + "hidden_activation": "gelu_pytorch_tanh", + "attn_logit_softcapping": 50.0, + "final_logit_softcapping": 30.0, + "sliding_window": 4096, + "cache_implementation": "hybrid", + "quantization": { + "group_size": 64, + "bits": 4 + }, + "eos_token_id": 1 + })) + .unwrap(); + + assert_eq!(config.attn_logit_softcapping, Some(50.0)); + assert_eq!(config.final_logit_softcapping, Some(30.0)); + assert_eq!(config.sliding_window, Some(4096)); + assert_eq!(config.cache_implementation.as_deref(), Some("hybrid")); +} + +#[test] +fn gemma2_real_hf_config_parses() { + let raw = serde_json::json!({ + "architectures": ["Gemma2ForCausalLM"], + "attention_bias": false, + "attention_dropout": 0.0, + "attn_logit_softcapping": 50.0, + "bos_token_id": 2, + "cache_implementation": "hybrid", + "eos_token_id": [1, 107], + "final_logit_softcapping": 30.0, + "head_dim": 256, + "hidden_act": "gelu_pytorch_tanh", + "hidden_activation": "gelu_pytorch_tanh", + "hidden_size": 2304, + "initializer_range": 0.02, + "intermediate_size": 9216, + "max_position_embeddings": 8192, + "model_type": "gemma2", + "num_attention_heads": 8, + "num_hidden_layers": 26, + "num_key_value_heads": 4, + "pad_token_id": 0, + "quantization": { + "group_size": 64, + "bits": 4 + }, + "query_pre_attn_scalar": 256, + "rms_norm_eps": 1e-06, + "rope_theta": 10000.0, + "sliding_window": 4096, + "torch_dtype": "bfloat16", + "transformers_version": "4.42.4", + "use_cache": true, + "vocab_size": 256000 + }); + let config: ModelConfig = serde_json::from_value(normalized_model_config_json(&raw)).unwrap(); + + assert_eq!(config.eos_token_id, vec![1, 107]); + assert_eq!(config.cache_implementation.as_deref(), Some("hybrid")); + assert_eq!( + config.hidden_activation.as_deref(), + Some("gelu_pytorch_tanh") + ); +} + +#[test] +fn effective_text_config_extracts_gemma4_text_config() { + let config = serde_json::json!({ + "model_type": "gemma4", + "architectures": ["Gemma4ForConditionalGeneration"], + "quantization": {"group_size": 64, "bits": 4}, + "eos_token_id": [1, 106], + "tie_word_embeddings": false, + "rms_norm_eps": 0.000001, + "max_position_embeddings": 32768, + "rope_theta": 10000.0, + "text_config": { + "model_type": "gemma4_text", + "hidden_size": 2560, + "hidden_size_per_layer_input": 256, + "num_hidden_layers": 42, + "intermediate_size": 10240, + "num_attention_heads": 8, + "num_key_value_heads": 2, + "num_kv_shared_layers": 18, + "head_dim": 256, + "global_head_dim": 512, + "query_pre_attn_scalar": 256.0, + "vocab_size": 262400, + "vocab_size_per_layer_input": 128, + "layer_types": ["sliding_attention", "full_attention"], + "final_logit_softcapping": 30.0 + } + }); + + let effective = effective_text_config_json(&config); + let parsed: ModelConfig = serde_json::from_value(effective).unwrap(); + assert_eq!(parsed.hidden_size, 2560); + assert_eq!(parsed.hidden_size_per_layer_input, Some(256)); + assert_eq!(parsed.head_dim, Some(256)); + assert_eq!(parsed.global_head_dim, Some(512)); + assert_eq!(parsed.num_kv_shared_layers, Some(18)); + assert_eq!(parsed.vocab_size_per_layer_input, Some(128)); + assert_eq!( + parsed.layer_types.as_deref(), + Some( + &[ + "sliding_attention".to_string(), + "full_attention".to_string() + ][..] + ) + ); + assert_eq!(parsed.final_logit_softcapping, Some(30.0)); +} + +#[test] +fn model_architecture_detects_gemma4_from_text_config() { + let config = serde_json::json!({ + "model_type": "gemma4", + "architectures": ["Gemma4ForConditionalGeneration"], + "text_config": {"model_type": "gemma4_text"} + }); + + assert_eq!(model_architecture(&config), ModelArchitecture::Gemma4); +} + +#[test] +fn qwen3_flat_rope_parameters_are_accepted() { + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "hidden_size": 1024, + "num_hidden_layers": 28, + "intermediate_size": 3072, + "num_attention_heads": 16, + "num_key_value_heads": 8, + "head_dim": 128, + "vocab_size": 151936, + "rms_norm_eps": 0.000001, + "max_position_embeddings": 40960, + "tie_word_embeddings": true, + "quantization": { + "group_size": 64, + "bits": 4 + }, + "eos_token_id": 151645, + "rope_theta": 1000000.0, + "rope_parameters": { + "rope_theta": 1000000.0, + "rope_type": "default" + } + })) + .unwrap(); + + let params = config.rope_parameters.unwrap(); + assert_eq!( + params.get("default").and_then(|p| p.rope_theta), + Some(1000000.0) + ); +} + +#[test] +fn qwen3_real_hf_config_parses_qk_norm_and_rope_scaling() { + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "model_type": "qwen3", + "architectures": ["Qwen3ForCausalLM"], + "hidden_size": 1024, + "num_hidden_layers": 28, + "intermediate_size": 3072, + "num_attention_heads": 16, + "num_key_value_heads": 8, + "head_dim": 128, + "vocab_size": 151936, + "rms_norm_eps": 0.000001, + "max_position_embeddings": 40960, + "tie_word_embeddings": true, + "rope_theta": 1000000.0, + "rope_scaling": { + "rope_type": "yarn", + "factor": 4.0, + "original_max_position_embeddings": 32768 + }, + "eos_token_id": 151645 + })) + .unwrap(); + + assert_eq!( + reasoning_family(&serde_json::json!({ + "model_type": "qwen3", + "architectures": ["Qwen3ForCausalLM"] + })), + ReasoningFamily::Qwen3 + ); + assert_eq!(config.head_dim, Some(128)); + assert_eq!(config.num_key_value_heads, 8); + assert_eq!(config.rope_theta, 1000000.0); +} + +#[test] +fn olmo2_real_hf_config_parses_qk_norm_style_fields() { + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "model_type": "olmo2", + "architectures": ["Olmo2ForCausalLM"], + "hidden_size": 4096, + "num_hidden_layers": 32, + "intermediate_size": 11008, + "num_attention_heads": 32, + "num_key_value_heads": 32, + "head_dim": 128, + "vocab_size": 50304, + "rms_norm_eps": 0.000001, + "max_position_embeddings": 4096, + "tie_word_embeddings": false, + "attention_bias": false, + "rope_theta": 10000.0, + "eos_token_id": 50279 + })) + .unwrap(); + + assert_eq!( + model_architecture(&serde_json::json!({ + "model_type": "olmo2", + "architectures": ["Olmo2ForCausalLM"] + })), + ModelArchitecture::Olmo2 + ); + assert_eq!(config.head_dim, Some(128)); + assert!(!config.tie_word_embeddings); +} + +#[test] +fn gpt_oss_real_hf_config_parses_sliding_window_layers() { + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "model_type": "gpt_oss", + "architectures": ["GptOssForCausalLM"], + "hidden_size": 2880, + "num_hidden_layers": 24, + "intermediate_size": 2880, + "num_attention_heads": 64, + "num_key_value_heads": 8, + "head_dim": 64, + "vocab_size": 201088, + "rms_norm_eps": 0.00001, + "rope_theta": 150000.0, + "max_position_embeddings": 131072, + "sliding_window": 128, + "layer_types": ["sliding_attention", "full_attention", "sliding_attention"], + "num_experts_per_tok": 4, + "tie_word_embeddings": false, + "eos_token_id": [199999, 200002] + })) + .unwrap(); + + assert_eq!( + model_architecture(&serde_json::json!({ + "model_type": "gpt_oss", + "architectures": ["GptOssForCausalLM"] + })), + ModelArchitecture::GptOss + ); + assert_eq!( + reasoning_family(&serde_json::json!({ + "model_type": "gpt_oss", + "architectures": ["GptOssForCausalLM"] + })), + ReasoningFamily::GptOss + ); + assert_eq!(config.sliding_window, Some(128)); + assert_eq!( + config.layer_types.as_deref(), + Some( + &[ + "sliding_attention".to_string(), + "full_attention".to_string(), + "sliding_attention".to_string() + ][..] + ) + ); + assert_eq!(config.num_experts_per_tok, Some(4)); + assert_eq!(config.eos_token_id, vec![199999, 200002]); +} + +#[test] +fn gemma3_real_hf_config_parses_hybrid_cache_fields() { + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "model_type": "gemma3_text", + "architectures": ["Gemma3ForCausalLM"], + "hidden_size": 1152, + "num_hidden_layers": 26, + "intermediate_size": 6912, + "num_attention_heads": 4, + "num_key_value_heads": 1, + "head_dim": 256, + "query_pre_attn_scalar": 256, + "vocab_size": 262144, + "rms_norm_eps": 0.000001, + "max_position_embeddings": 32768, + "rope_theta": 1000000.0, + "sliding_window": 512, + "sliding_window_pattern": 6, + "cache_implementation": "hybrid", + "tie_word_embeddings": false, + "eos_token_id": [1, 106] + })) + .unwrap(); + + assert_eq!(config.sliding_window, Some(512)); + assert_eq!(config.sliding_window_pattern, Some(6)); + assert_eq!(config.cache_implementation.as_deref(), Some("hybrid")); +} + +#[test] +fn attention_window_size_for_gpt_oss_uses_layer_types() { + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "hidden_size": 2880, + "num_hidden_layers": 3, + "intermediate_size": 2880, + "num_attention_heads": 64, + "num_key_value_heads": 8, + "head_dim": 64, + "vocab_size": 201088, + "rms_norm_eps": 0.00001, + "rope_theta": 150000.0, + "max_position_embeddings": 131072, + "sliding_window": 128, + "layer_types": ["sliding_attention", "full_attention", "sliding_attention"], + "tie_word_embeddings": false, + "eos_token_id": [199999, 200002] + })) + .unwrap(); + + assert_eq!( + attention_window_size_for_layer( + ModelArchitecture::GptOss, + &config, + 0, + Some("sliding_attention") + ) + .unwrap(), + Some(128) + ); + assert_eq!( + attention_window_size_for_layer( + ModelArchitecture::GptOss, + &config, + 1, + Some("full_attention") + ) + .unwrap(), + None + ); +} + +#[test] +fn attention_window_size_for_gemma3_matches_hybrid_pattern() { + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "hidden_size": 1152, + "num_hidden_layers": 8, + "intermediate_size": 6912, + "num_attention_heads": 4, + "num_key_value_heads": 1, + "head_dim": 256, + "query_pre_attn_scalar": 256, + "vocab_size": 262144, + "rms_norm_eps": 0.000001, + "max_position_embeddings": 32768, + "rope_theta": 1000000.0, + "sliding_window": 512, + "sliding_window_pattern": 3, + "cache_implementation": "hybrid", + "tie_word_embeddings": false, + "eos_token_id": [1, 106] + })) + .unwrap(); + + assert_eq!( + attention_window_size_for_layer(ModelArchitecture::Gemma3, &config, 0, None).unwrap(), + Some(512) + ); + assert_eq!( + attention_window_size_for_layer(ModelArchitecture::Gemma3, &config, 1, None).unwrap(), + Some(512) + ); + assert_eq!( + attention_window_size_for_layer(ModelArchitecture::Gemma3, &config, 2, None).unwrap(), + None + ); + assert_eq!( + attention_window_size_for_layer(ModelArchitecture::Gemma3, &config, 3, None).unwrap(), + Some(512) + ); +} + +#[test] +fn attention_window_size_for_gemma4_uses_layer_types() { + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "hidden_size": 2560, + "num_hidden_layers": 4, + "intermediate_size": 10240, + "num_attention_heads": 8, + "num_key_value_heads": 2, + "head_dim": 256, + "global_head_dim": 512, + "num_kv_shared_layers": 2, + "vocab_size": 262400, + "rms_norm_eps": 0.000001, + "max_position_embeddings": 32768, + "rope_theta": 10000.0, + "sliding_window": 512, + "layer_types": ["sliding_attention", "full_attention", "sliding_attention", "full_attention"], + "tie_word_embeddings": false, + "eos_token_id": [1, 106] + })) + .unwrap(); + + assert_eq!( + attention_window_size_for_layer( + ModelArchitecture::Gemma4, + &config, + 0, + Some("sliding_attention") + ) + .unwrap(), + None + ); + assert_eq!( + attention_window_size_for_layer( + ModelArchitecture::Gemma4, + &config, + 1, + Some("full_attention") + ) + .unwrap(), + None + ); +} + +#[test] +fn kv_shared_source_for_gemma4_matches_previous_layer_type() { + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "hidden_size": 2560, + "num_hidden_layers": 6, + "intermediate_size": 10240, + "num_attention_heads": 8, + "num_key_value_heads": 2, + "head_dim": 256, + "global_head_dim": 512, + "num_kv_shared_layers": 2, + "vocab_size": 262400, + "rms_norm_eps": 0.000001, + "max_position_embeddings": 32768, + "rope_theta": 10000.0, + "sliding_window": 512, + "layer_types": [ + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention" + ], + "tie_word_embeddings": false, + "eos_token_id": [1, 106] + })) + .unwrap(); + + let non_shared = &config.layer_types.as_ref().unwrap()[..4]; + assert_eq!( + kv_shared_source_for_layer( + ModelArchitecture::Gemma4, + &config, + 4, + Some("sliding_attention"), + Some(non_shared) + ), + Some(2) + ); + assert_eq!( + kv_shared_source_for_layer( + ModelArchitecture::Gemma4, + &config, + 5, + Some("full_attention"), + Some(non_shared) + ), + Some(3) + ); + assert_eq!( + kv_shared_source_for_layer( + ModelArchitecture::Gemma4, + &config, + 1, + Some("full_attention"), + Some(non_shared) + ), + None + ); +} + +#[test] +fn gemma3_uses_scaled_embeddings() { + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "hidden_size": 1152, + "num_hidden_layers": 26, + "intermediate_size": 6912, + "num_attention_heads": 4, + "num_key_value_heads": 1, + "head_dim": 256, + "query_pre_attn_scalar": 256, + "vocab_size": 262144, + "rms_norm_eps": 0.000001, + "max_position_embeddings": 32768, + "tie_word_embeddings": null, + "hidden_activation": "gelu_pytorch_tanh", + "quantization": { + "group_size": 64, + "bits": 4 + }, + "eos_token_id": [1, 106] + })) + .unwrap(); + + let embed_scale = (config.hidden_size as f32).sqrt(); + assert!((embed_scale - 33.941124).abs() < 0.001); +} + +#[test] +fn gemma4_uses_scaled_main_and_per_layer_embeddings() { + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "hidden_size": 2560, + "hidden_size_per_layer_input": 256, + "num_hidden_layers": 42, + "intermediate_size": 10240, + "num_attention_heads": 8, + "num_key_value_heads": 2, + "head_dim": 256, + "global_head_dim": 512, + "num_kv_shared_layers": 18, + "vocab_size": 262400, + "vocab_size_per_layer_input": 128, + "rms_norm_eps": 0.000001, + "max_position_embeddings": 32768, + "tie_word_embeddings": false, + "query_pre_attn_scalar": 256.0, + "rope_theta": 10000.0, + "layer_types": ["sliding_attention", "full_attention"], + "quantization": { + "group_size": 64, + "bits": 4 + }, + "eos_token_id": [1, 106] + })) + .unwrap(); + + let embed_scale = (config.hidden_size as f32).sqrt(); + let per_layer_scale = (config.hidden_size_per_layer_input.unwrap() as f32).sqrt(); + assert!((embed_scale - 50.596443).abs() < 0.001); + assert!((per_layer_scale - 16.0).abs() < 0.001); +} + +#[test] +fn quant_params_for_uses_tensor_specific_overrides() { + let config = serde_json::json!({ + "quantization": { + "group_size": 64, + "bits": 4, + "language_model.model.embed_tokens": {"group_size": 64, "bits": 6}, + "language_model.model.layers.0.self_attn.q_proj": {"group_size": 64, "bits": 8} + } + }); + + assert_eq!( + quant_params_for(&config, "language_model.model.embed_tokens", 64, 4), + (64, 6) + ); + assert_eq!( + quant_params_for( + &config, + "language_model.model.layers.0.self_attn.q_proj", + 64, + 4 + ), + (64, 8) + ); + assert_eq!( + quant_params_for( + &config, + "language_model.model.layers.0.mlp.down_proj", + 64, + 4 + ), + (64, 4) + ); +} + +#[test] +fn dense_model_config_is_allowed_without_quantization_block() { + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "hidden_size": 1024, + "num_hidden_layers": 28, + "intermediate_size": 3072, + "num_attention_heads": 16, + "num_key_value_heads": 8, + "head_dim": 64, + "vocab_size": 151936, + "rms_norm_eps": 0.000001, + "max_position_embeddings": 40960, + "tie_word_embeddings": true, + "eos_token_id": 151645 + })) + .unwrap(); + + assert!(config.quantization.is_none()); + assert!(config.tie_word_embeddings); +} + +#[test] +#[serial] +fn dense_embeddings_can_project_logits_through_as_linear() { + let weight = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]); + let embedding = QuantizedEmbedding { + weight: weight.clone(), + scales: array!(0.0f32), + biases: array!(0.0f32), + group_size: 0, + bits: 0, + dense_weight: Some(weight.clone()), + dense_weight_t: Some(weight.transpose_axes(&[1, 0]).unwrap()), + }; + let hidden = Array::from_slice(&[10.0f32, 20.0], &[1, 1, 2]); + + let logits = embedding.as_linear().forward(&hidden).unwrap(); + + assert_eq!(logits.as_slice::<f32>(), &[50.0, 110.0, 170.0]); +} + +fn dense_linear(weight: &[f32], out_dim: i32, in_dim: i32) -> QuantizedLinear { + let weight = Array::from_slice(weight, &[out_dim, in_dim]); + QuantizedLinear { + weight: weight.clone(), + scales: array!(0.0f32), + biases: array!(0.0f32), + bias: None, + group_size: 0, + bits: 0, + dense_weight_t: Some(weight.transpose_axes(&[1, 0]).unwrap()), + } +} + +fn identity_dense_linear(dim: i32) -> QuantizedLinear { + let mut weight = vec![0.0f32; (dim * dim) as usize]; + for i in 0..dim as usize { + weight[i * dim as usize + i] = 1.0; + } + dense_linear(&weight, dim, dim) +} + +fn assert_arrays_close(actual: &Array, expected: &Array, tol: f32) { + let actual = actual.as_dtype(Dtype::Float32).unwrap(); + let expected = expected.as_dtype(Dtype::Float32).unwrap(); + let actual_slice = actual.as_slice::<f32>(); + let expected_slice = expected.as_slice::<f32>(); + assert_eq!(actual_slice.len(), expected_slice.len()); + for (idx, (a, b)) in actual_slice.iter().zip(expected_slice.iter()).enumerate() { + assert!( + (a - b).abs() <= tol, + "mismatch at index {idx}: actual={a} expected={b} tol={tol}" + ); + } +} + +#[test] +#[serial] +fn attention_kv_cache_matches_no_cache_for_incremental_decode() { + let attn = Attention { + q_proj: dense_linear(&[1.0, 0.0, 0.0, 1.0], 2, 2), + k_proj: dense_linear(&[1.0, 0.0, 0.0, 1.0], 2, 2), + v_proj: dense_linear(&[1.0, 0.0, 0.0, 1.0], 2, 2), + o_proj: dense_linear(&[1.0, 0.0, 0.0, 1.0], 2, 2), + q_norm: None, + k_norm: None, + v_norm: None, + num_heads: 1, + num_kv_heads: 1, + head_dim: 2, + scale: 1.0 / (2.0f32).sqrt(), + attn_logit_softcapping: None, + rope_dim: 2, + rope_theta: 10000.0, + rope_traditional: false, + window_size: None, + kv_shared_source: None, + }; + + let full = Array::from_slice(&[1.0f32, 0.0, 0.5, 1.0, -1.0, 0.25, 0.75, -0.5], &[1, 4, 2]); + let expected = attn.forward_no_cache(&full).unwrap(); + + let mut cache = KVCache::new(); + let mut outputs = Vec::new(); + for step in 0..4i32 { + let x = full.index((0..1, step..step + 1, std::ops::RangeFull)); + outputs.push(attn.forward(&x, &mut cache, None).unwrap()); + } + let output_refs: Vec<&Array> = outputs.iter().collect(); + let actual = mlx_rs::ops::concatenate_axis(&output_refs, 1).unwrap(); + + assert_eq!(cache.offset(), 4); + assert_arrays_close(&actual, &expected, 1e-4); +} + +#[test] +#[serial] +fn attention_quantized_kv_cache_stays_close_to_dense_cache() { + let dim = 32i32; + let attn = Attention { + q_proj: identity_dense_linear(dim), + k_proj: identity_dense_linear(dim), + v_proj: identity_dense_linear(dim), + o_proj: identity_dense_linear(dim), + q_norm: None, + k_norm: None, + v_norm: None, + num_heads: 1, + num_kv_heads: 1, + head_dim: dim, + scale: 1.0 / (dim as f32).sqrt(), + attn_logit_softcapping: None, + rope_dim: dim, + rope_theta: 10000.0, + rope_traditional: false, + window_size: None, + kv_shared_source: None, + }; + + let values = (0..(4 * dim)) + .map(|i| (i as f32 * 0.03125) - 1.0) + .collect::<Vec<_>>(); + let full = Array::from_slice(&values, &[1, 4, dim]); + + let mut dense_cache = KVCache::new(); + let mut dense_outputs = Vec::new(); + for step in 0..4i32 { + let x = full.index((0..1, step..step + 1, std::ops::RangeFull)); + dense_outputs.push(attn.forward(&x, &mut dense_cache, None).unwrap()); + } + let dense_output_refs: Vec<&Array> = dense_outputs.iter().collect(); + let dense_actual = mlx_rs::ops::concatenate_axis(&dense_output_refs, 1).unwrap(); + + let mut quantized_cache = KVCache::new_quantized(32, 8, 0); + let mut quantized_outputs = Vec::new(); + for step in 0..4i32 { + let x = full.index((0..1, step..step + 1, std::ops::RangeFull)); + quantized_outputs.push(attn.forward(&x, &mut quantized_cache, None).unwrap()); + } + let quantized_output_refs: Vec<&Array> = quantized_outputs.iter().collect(); + let quantized_actual = mlx_rs::ops::concatenate_axis(&quantized_output_refs, 1).unwrap(); + + assert_eq!(quantized_cache.offset(), 4); + assert_arrays_close(&quantized_actual, &dense_actual, 5e-2); +} + +#[test] +#[serial] +fn attention_quantized_kv_cache_threshold_migrates_after_dense_prefix() { + let dim = 32i32; + let attn = Attention { + q_proj: identity_dense_linear(dim), + k_proj: identity_dense_linear(dim), + v_proj: identity_dense_linear(dim), + o_proj: identity_dense_linear(dim), + q_norm: None, + k_norm: None, + v_norm: None, + num_heads: 1, + num_kv_heads: 1, + head_dim: dim, + scale: 1.0 / (dim as f32).sqrt(), + attn_logit_softcapping: None, + rope_dim: dim, + rope_theta: 10000.0, + rope_traditional: false, + window_size: None, + kv_shared_source: None, + }; + + let values = (0..(6 * dim)) + .map(|i| (((i as usize % dim as usize) as f32) / dim as f32) - 0.5) + .collect::<Vec<_>>(); + let full = Array::from_slice(&values, &[1, 6, dim]); + + let mut dense_cache = KVCache::new(); + let mut dense_outputs = Vec::new(); + for step in 0..6i32 { + let x = full.index((0..1, step..step + 1, std::ops::RangeFull)); + dense_outputs.push(attn.forward(&x, &mut dense_cache, None).unwrap()); + } + let dense_output_refs: Vec<&Array> = dense_outputs.iter().collect(); + let dense_actual = mlx_rs::ops::concatenate_axis(&dense_output_refs, 1).unwrap(); + + let mut quantized_cache = KVCache::new_quantized(32, 8, 4); + let mut quantized_outputs = Vec::new(); + for step in 0..6i32 { + let x = full.index((0..1, step..step + 1, std::ops::RangeFull)); + quantized_outputs.push(attn.forward(&x, &mut quantized_cache, None).unwrap()); + } + let quantized_output_refs: Vec<&Array> = quantized_outputs.iter().collect(); + let quantized_actual = mlx_rs::ops::concatenate_axis(&quantized_output_refs, 1).unwrap(); + + assert_eq!(quantized_cache.offset(), 6); + assert!(quantized_cache.qkeys.is_some()); + assert!(quantized_cache.qvalues.is_some()); + assert!(quantized_cache.keys.is_none()); + assert!(quantized_cache.values.is_none()); + let dense_prefix = dense_actual.index((0..1, 0..4, std::ops::RangeFull)); + let quantized_prefix = quantized_actual.index((0..1, 0..4, std::ops::RangeFull)); + let dense_tail = dense_actual.index((0..1, 4..6, std::ops::RangeFull)); + let quantized_tail = quantized_actual.index((0..1, 4..6, std::ops::RangeFull)); + assert_arrays_close(&quantized_prefix, &dense_prefix, 1e-4); + assert_arrays_close(&quantized_tail, &dense_tail, 2e-1); +} + +#[test] +#[serial] +fn rotating_kv_cache_cannot_trim_before_retained_window() { + let mut cache = KVCache::new_rotating(2, 0); + let k = Array::from_slice(&[1.0f32, 2.0], &[1, 1, 1, 2]); + let v = Array::from_slice(&[3.0f32, 4.0], &[1, 1, 1, 2]); + + cache.update(k.clone(), v.clone()).unwrap(); + cache.update(k.clone(), v.clone()).unwrap(); + cache.update_cached(k, v).unwrap(); + + assert_eq!(cache.offset(), 3); + assert_eq!(cache.retained_start(), 1); + assert!(cache.can_trim_to(1)); + assert!(!cache.can_trim_to(0)); +} + +#[test] +#[serial] +fn rotating_kv_cache_rewind_and_append_preserves_temporal_order() { + let mut cache = KVCache::new_rotating(3, 0); + for token in [1.0f32, 2.0, 3.0] { + let k = Array::from_slice(&[token], &[1, 1, 1, 1]); + let v = Array::from_slice(&[token + 10.0], &[1, 1, 1, 1]); + cache.update(k, v).unwrap(); + } + + assert!(cache.trim_to(2).unwrap()); + + let k = Array::from_slice(&[9.0f32], &[1, 1, 1, 1]); + let v = Array::from_slice(&[19.0f32], &[1, 1, 1, 1]); + let (keys, values) = cache.update(k, v).unwrap(); + + assert_eq!(cache.offset(), 3); + assert_eq!(cache.retained_start(), 0); + assert_eq!(keys.as_slice::<f32>(), &[1.0, 2.0, 9.0]); + assert_eq!(values.as_slice::<f32>(), &[11.0, 12.0, 19.0]); +} + +#[test] +#[serial] +fn standard_kv_cache_trim_materializes_prefix() { + let mut cache = KVCache::new(); + let k = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[1, 1, 2, 2]); + let v = Array::from_slice(&[5.0f32, 6.0, 7.0, 8.0], &[1, 1, 2, 2]); + + cache.update_cached(k, v).unwrap(); + assert!(cache.trim_to(1).unwrap()); + + assert_eq!(cache.offset(), 1); + assert_eq!(cache.keys.as_ref().unwrap().shape(), &[1, 1, 1, 2]); + assert_eq!(cache.values.as_ref().unwrap().shape(), &[1, 1, 1, 2]); + let (keys, values) = cache.views().unwrap(); + assert_eq!(keys.as_slice::<f32>(), &[1.0, 2.0]); + assert_eq!(values.as_slice::<f32>(), &[5.0, 6.0]); +} + +#[test] +#[serial] +fn quantized_kv_cache_trim_materializes_prefix() { + let mut cache = KVCache::new_quantized(64, 8, 0); + let k_data: Vec<f32> = (0..(3 * 64)).map(|i| (i as f32 / 64.0) - 1.0).collect(); + let v_data: Vec<f32> = (0..(3 * 64)).map(|i| 1.0 - (i as f32 / 64.0)).collect(); + let k = Array::from_slice(&k_data, &[1, 1, 3, 64]); + let v = Array::from_slice(&v_data, &[1, 1, 3, 64]); + + cache.update_cached(k, v).unwrap(); + assert!(cache.trim_to(2).unwrap()); + + assert_eq!(cache.offset(), 2); + assert_eq!(cache.qkeys.as_ref().unwrap().data.shape()[2], 2); + assert_eq!(cache.qvalues.as_ref().unwrap().data.shape()[2], 2); +} + +#[test] +#[serial] +fn attention_sliding_window_cache_matches_no_cache_for_incremental_decode() { + let attn = Attention { + q_proj: dense_linear(&[1.0, 0.0, 0.0, 1.0], 2, 2), + k_proj: dense_linear(&[1.0, 0.0, 0.0, 1.0], 2, 2), + v_proj: dense_linear(&[1.0, 0.0, 0.0, 1.0], 2, 2), + o_proj: dense_linear(&[1.0, 0.0, 0.0, 1.0], 2, 2), + q_norm: None, + k_norm: None, + v_norm: None, + num_heads: 1, + num_kv_heads: 1, + head_dim: 2, + scale: 1.0 / (2.0f32).sqrt(), + attn_logit_softcapping: None, + rope_dim: 2, + rope_theta: 10000.0, + rope_traditional: false, + window_size: Some(2), + kv_shared_source: None, + }; + + let full = Array::from_slice(&[1.0f32, 0.0, 0.5, 1.0, -1.0, 0.25, 0.75, -0.5], &[1, 4, 2]); + let expected = attn.forward_no_cache(&full).unwrap(); + + let mut cache = KVCache::new_rotating(2, 0); + let mut outputs = Vec::new(); + for step in 0..4i32 { + let x = full.index((0..1, step..step + 1, std::ops::RangeFull)); + outputs.push(attn.forward(&x, &mut cache, None).unwrap()); + } + let output_refs: Vec<&Array> = outputs.iter().collect(); + let actual = mlx_rs::ops::concatenate_axis(&output_refs, 1).unwrap(); + + assert_eq!(cache.offset(), 4); + assert_arrays_close(&actual, &expected, 1e-4); +} + +#[test] +#[serial] +fn phi3_tensor_transform_splits_fused_attention_and_mlp_weights() { + let prefixes = TensorPrefixes { + model: "model".to_string(), + lm_head: Some("lm_head".to_string()), + }; + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "hidden_size": 8, + "num_hidden_layers": 1, + "intermediate_size": 12, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "head_dim": 4, + "vocab_size": 32, + "rms_norm_eps": 0.000001, + "max_position_embeddings": 128, + "tie_word_embeddings": false, + "quantization": { + "group_size": 2, + "bits": 4 + }, + "eos_token_id": 1 + })) + .unwrap(); + + let mut tensors = HashMap::new(); + tensors.insert( + "model.layers.0.self_attn.qkv_proj.weight".to_string(), + Array::from_slice(&vec![0u32; 24 * 3], &[24, 3]), + ); + tensors.insert( + "model.layers.0.self_attn.qkv_proj.scales".to_string(), + Array::from_slice(&vec![0.0f32; 24 * 2], &[24, 2]), + ); + tensors.insert( + "model.layers.0.self_attn.qkv_proj.biases".to_string(), + Array::from_slice(&vec![0.0f32; 24 * 2], &[24, 2]), + ); + tensors.insert( + "model.layers.0.mlp.gate_up_proj.weight".to_string(), + Array::from_slice(&vec![0u32; 24 * 3], &[24, 3]), + ); + tensors.insert( + "model.layers.0.mlp.gate_up_proj.scales".to_string(), + Array::from_slice(&vec![0.0f32; 24 * 2], &[24, 2]), + ); + tensors.insert( + "model.layers.0.mlp.gate_up_proj.biases".to_string(), + Array::from_slice(&vec![0.0f32; 24 * 2], &[24, 2]), + ); + + families::apply_family_tensor_transforms( + ModelArchitecture::LlamaLike, + &mut tensors, + &prefixes, + &config, + &serde_json::json!({"model_type": "phi3"}), + 2, + 4, + ) + .unwrap(); + + assert_eq!( + tensors["model.layers.0.self_attn.q_proj.weight"].shape(), + &[8, 3] + ); + assert_eq!( + tensors["model.layers.0.self_attn.k_proj.weight"].shape(), + &[8, 3] + ); + assert_eq!( + tensors["model.layers.0.self_attn.v_proj.weight"].shape(), + &[8, 3] + ); + assert_eq!( + tensors["model.layers.0.mlp.gate_proj.weight"].shape(), + &[12, 3] + ); + assert_eq!( + tensors["model.layers.0.mlp.up_proj.weight"].shape(), + &[12, 3] + ); +} + +#[test] +#[serial] +fn gpt_oss_tensor_transform_splits_interleaved_expert_gate_up_tensors() { + let prefixes = TensorPrefixes { + model: "model".to_string(), + lm_head: Some("lm_head".to_string()), + }; + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "hidden_size": 8, + "num_hidden_layers": 1, + "intermediate_size": 4, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "head_dim": 4, + "vocab_size": 32, + "rms_norm_eps": 0.000001, + "max_position_embeddings": 128, + "tie_word_embeddings": false, + "quantization": { + "group_size": 2, + "bits": 4 + }, + "eos_token_id": 1 + })) + .unwrap(); + + let mut tensors = HashMap::new(); + tensors.insert( + "model.layers.0.mlp.experts.gate_up_proj.weight".to_string(), + Array::from_slice( + &[ + 0.0f32, 1.0, 10.0, 11.0, 20.0, 21.0, 30.0, 31.0, 40.0, 41.0, 50.0, 51.0, + ], + &[6, 2], + ), + ); + tensors.insert( + "model.layers.0.mlp.experts.gate_up_proj.scales".to_string(), + Array::from_slice( + &[ + 0.5f32, 1.5, 10.5, 11.5, 20.5, 21.5, 30.5, 31.5, 40.5, 41.5, 50.5, 51.5, + ], + &[1, 6, 2], + ), + ); + tensors.insert( + "model.layers.0.mlp.experts.gate_up_proj_bias".to_string(), + Array::from_slice(&[0.0f32, 10.0, 20.0, 30.0, 40.0, 50.0], &[1, 6]), + ); + tensors.insert( + "model.layers.0.mlp.experts.down_proj_bias".to_string(), + Array::from_slice(&[1.0f32, 2.0, 3.0], &[3]), + ); + + families::apply_family_tensor_transforms( + ModelArchitecture::GptOss, + &mut tensors, + &prefixes, + &config, + &serde_json::json!({"model_type": "gpt_oss"}), + 2, + 4, + ) + .unwrap(); + + assert_eq!( + tensors["model.layers.0.mlp.experts.gate_proj.weight"].as_slice::<f32>(), + &[0.0, 1.0, 20.0, 21.0, 40.0, 41.0] + ); + assert_eq!( + tensors["model.layers.0.mlp.experts.up_proj.weight"].as_slice::<f32>(), + &[10.0, 11.0, 30.0, 31.0, 50.0, 51.0] + ); + assert_eq!( + tensors["model.layers.0.mlp.experts.gate_proj.scales"].shape(), + &[1, 3, 2] + ); + assert_eq!( + tensors["model.layers.0.mlp.experts.up_proj.scales"].shape(), + &[1, 3, 2] + ); + assert_eq!( + tensors["model.layers.0.mlp.experts.gate_proj.biases"].as_slice::<f32>(), + &[0.0, 20.0, 40.0] + ); + assert_eq!( + tensors["model.layers.0.mlp.experts.up_proj.biases"].as_slice::<f32>(), + &[10.0, 30.0, 50.0] + ); + assert_eq!( + tensors["model.layers.0.mlp.experts.down_proj.biases"].as_slice::<f32>(), + &[1.0, 2.0, 3.0] + ); +} + +#[test] +#[serial] +fn gemma3_tensor_transform_drops_multimodal_tensors_and_tied_lm_head() { + let prefixes = TensorPrefixes { + model: "language_model.model".to_string(), + lm_head: Some("language_model.lm_head".to_string()), + }; + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "hidden_size": 8, + "num_hidden_layers": 1, + "intermediate_size": 16, + "num_attention_heads": 2, + "num_key_value_heads": 1, + "vocab_size": 32, + "rms_norm_eps": 0.000001, + "max_position_embeddings": 128, + "tie_word_embeddings": true, + "eos_token_id": [1, 106] + })) + .unwrap(); + + let mut tensors = HashMap::new(); + tensors.insert( + "vision_tower.encoder.weight".to_string(), + Array::from_slice(&[1.0f32, 2.0], &[2]), + ); + tensors.insert( + "multi_modal_projector.linear.weight".to_string(), + Array::from_slice(&[3.0f32, 4.0], &[2]), + ); + tensors.insert( + "language_model.model.embed_tokens.weight".to_string(), + Array::from_slice(&[5.0f32, 6.0], &[2]), + ); + tensors.insert( + "language_model.lm_head.weight".to_string(), + Array::from_slice(&[7.0f32, 8.0], &[2]), + ); + + families::apply_family_tensor_transforms( + ModelArchitecture::Gemma3, + &mut tensors, + &prefixes, + &config, + &serde_json::json!({"model_type": "gemma3", "text_config": {"model_type": "gemma3_text"}}), + 64, + 4, + ) + .unwrap(); + + assert!(tensors.contains_key("language_model.model.embed_tokens.weight")); + assert!(!tensors.contains_key("vision_tower.encoder.weight")); + assert!(!tensors.contains_key("multi_modal_projector.linear.weight")); + assert!(!tensors.contains_key("language_model.lm_head.weight")); +} + +#[test] +#[serial] +fn gemma4_tensor_transform_normalizes_text_prefixes_and_drops_multimodal_tensors() { + let prefixes = TensorPrefixes { + model: "language_model.model".to_string(), + lm_head: Some("lm_head".to_string()), + }; + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "hidden_size": 2560, + "num_hidden_layers": 2, + "intermediate_size": 10240, + "num_attention_heads": 8, + "num_key_value_heads": 2, + "head_dim": 256, + "global_head_dim": 512, + "num_kv_shared_layers": 1, + "vocab_size": 262400, + "rms_norm_eps": 0.000001, + "max_position_embeddings": 32768, + "tie_word_embeddings": true, + "eos_token_id": [1, 106] + })) + .unwrap(); + + let mut tensors = HashMap::new(); + tensors.insert( + "model.language_model.embed_tokens.weight".to_string(), + Array::from_slice(&[1.0f32, 2.0], &[2]), + ); + tensors.insert( + "model.language_model.layers.0.self_attn.q_proj.weight".to_string(), + Array::from_slice(&[3.0f32, 4.0], &[2]), + ); + tensors.insert( + "model.vision_tower.encoder.weight".to_string(), + Array::from_slice(&[5.0f32, 6.0], &[2]), + ); + tensors.insert( + "model.audio_tower.encoder.weight".to_string(), + Array::from_slice(&[7.0f32, 8.0], &[2]), + ); + tensors.insert( + "lm_head.weight".to_string(), + Array::from_slice(&[9.0f32, 10.0], &[2]), + ); + + families::apply_family_tensor_transforms( + ModelArchitecture::Gemma4, + &mut tensors, + &prefixes, + &config, + &serde_json::json!({"model_type": "gemma4", "text_config": {"model_type": "gemma4_text"}}), + 64, + 4, + ) + .unwrap(); + + assert!(tensors.contains_key("language_model.model.embed_tokens.weight")); + assert!(tensors.contains_key("language_model.model.layers.0.self_attn.q_proj.weight")); + assert!(!tensors.contains_key("model.language_model.embed_tokens.weight")); + assert!(!tensors.contains_key("model.vision_tower.encoder.weight")); + assert!(!tensors.contains_key("model.audio_tower.encoder.weight")); + assert!(!tensors.contains_key("lm_head.weight")); +} + +#[test] +#[serial] +fn olmo2_tensor_transform_drops_rotary_inv_freq_tensors() { + let prefixes = TensorPrefixes { + model: "model".to_string(), + lm_head: Some("lm_head".to_string()), + }; + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "hidden_size": 8, + "num_hidden_layers": 2, + "intermediate_size": 16, + "num_attention_heads": 2, + "num_key_value_heads": 1, + "vocab_size": 32, + "rms_norm_eps": 0.000001, + "max_position_embeddings": 128, + "tie_word_embeddings": false, + "eos_token_id": 1 + })) + .unwrap(); + + let mut tensors = HashMap::new(); + tensors.insert( + "model.layers.0.self_attn.rotary_emb.inv_freq".to_string(), + Array::from_slice(&[1.0f32, 2.0], &[2]), + ); + tensors.insert( + "model.layers.1.self_attn.rotary_emb.inv_freq".to_string(), + Array::from_slice(&[3.0f32, 4.0], &[2]), + ); + tensors.insert( + "model.layers.1.self_attn.q_proj.weight".to_string(), + Array::from_slice(&[5.0f32, 6.0], &[2]), + ); + + families::apply_family_tensor_transforms( + ModelArchitecture::Olmo2, + &mut tensors, + &prefixes, + &config, + &serde_json::json!({"model_type": "olmo2"}), + 64, + 4, + ) + .unwrap(); + + assert!(!tensors.contains_key("model.layers.0.self_attn.rotary_emb.inv_freq")); + assert!(!tensors.contains_key("model.layers.1.self_attn.rotary_emb.inv_freq")); + assert!(tensors.contains_key("model.layers.1.self_attn.q_proj.weight")); +} + +#[test] +#[serial] +fn llama_like_tensor_transform_drops_inv_freq_and_tied_lm_head() { + let prefixes = TensorPrefixes { + model: "model".to_string(), + lm_head: Some("lm_head".to_string()), + }; + let config: ModelConfig = serde_json::from_value(serde_json::json!({ + "hidden_size": 8, + "num_hidden_layers": 1, + "intermediate_size": 16, + "num_attention_heads": 2, + "num_key_value_heads": 1, + "vocab_size": 32, + "rms_norm_eps": 0.000001, + "max_position_embeddings": 128, + "tie_word_embeddings": true, + "eos_token_id": 1 + })) + .unwrap(); + + let mut tensors = HashMap::new(); + tensors.insert( + "model.layers.0.self_attn.rotary_emb.inv_freq".to_string(), + Array::from_slice(&[1.0f32, 2.0], &[2]), + ); + tensors.insert( + "model.layers.0.self_attn.q_proj.weight".to_string(), + Array::from_slice(&[3.0f32, 4.0], &[2]), + ); + tensors.insert( + "lm_head.weight".to_string(), + Array::from_slice(&[5.0f32, 6.0], &[2]), + ); + + families::apply_family_tensor_transforms( + ModelArchitecture::LlamaLike, + &mut tensors, + &prefixes, + &config, + &serde_json::json!({"model_type": "llama"}), + 64, + 4, + ) + .unwrap(); + + assert!(!tensors.contains_key("model.layers.0.self_attn.rotary_emb.inv_freq")); + assert!(!tensors.contains_key("lm_head.weight")); + assert!(tensors.contains_key("model.layers.0.self_attn.q_proj.weight")); +} + +#[test] +#[ignore] +fn olmo_debug_cache_vs_no_cache_local() { + let dir = std::path::Path::new( + "/Users/jdumay/.cache/mesh-llm-debug/olmo-7b-instruct-hf-same-origin/mlx/olmo-7b-instruct-hf-bf16", + ); + assert!( + dir.exists(), + "missing local OLMo artifact at {}", + dir.display() + ); + + let model = MlxModel::load(dir).expect("load local olmo mlx artifact"); + let prompt = + "<|endoftext|><|user|>\nWhat day comes after Monday? Reply with one word.\n<|assistant|>\n"; + let encoded = model + .tokenizer + .encode(prompt, false) + .expect("tokenize prompt"); + let ids = encoded.get_ids().to_vec(); + let input = Array::from_slice(&ids, &[1, ids.len() as i32]); + + let h = model.embed_tokens.forward(&input).expect("embed"); + let ln = model.layers[0] + .attn_in_norm + .as_ref() + .expect("attn_in_norm") + .forward(&h) + .expect("ln"); + let (q, k, v, q_rope, k_rope, attn_out, h, mlp_in, mlp, layer0_out) = match &model.layers[0] { + Layer { + attn: AttentionKind::Standard(attn), + mlp, + mlp_in_norm, + .. + } => { + let shape = ln.shape(); + let (b, l) = (shape[0], shape[1]); + + let q = attn.q_proj.forward(&ln).expect("q_proj"); + let q = Attention::apply_qk_norm( + q, + attn.q_norm.as_ref(), + b, + l, + attn.num_heads, + attn.head_dim, + ) + .expect("q norm") + .transpose_axes(&[0, 2, 1, 3]) + .expect("q transpose"); + let q_rope = apply_rope( + &q, + attn.rope_dim, + attn.head_dim, + attn.rope_theta, + attn.rope_traditional, + 0, + ) + .expect("q rope"); + + let k = attn.k_proj.forward(&ln).expect("k_proj"); + let v = attn.v_proj.forward(&ln).expect("v_proj"); + let k = Attention::apply_qk_norm( + k, + attn.k_norm.as_ref(), + b, + l, + attn.num_kv_heads, + attn.head_dim, + ) + .expect("k norm") + .transpose_axes(&[0, 2, 1, 3]) + .expect("k transpose"); + let v = v + .reshape(&[b, l, attn.num_kv_heads, attn.head_dim]) + .expect("v reshape"); + let v = if let Some(norm) = &attn.v_norm { + norm.forward(&v).expect("v norm") + } else { + v + } + .transpose_axes(&[0, 2, 1, 3]) + .expect("v transpose"); + let k_rope = apply_rope( + &k, + attn.rope_dim, + attn.head_dim, + attn.rope_theta, + attn.rope_traditional, + 0, + ) + .expect("k rope"); + + let mask = if l > 1 { + Some(mlx_rs::fast::ScaledDotProductAttentionMask::Causal) + } else { + None + }; + let attn_out = + mlx_rs::fast::scaled_dot_product_attention(&q_rope, &k_rope, &v, attn.scale, mask) + .expect("attn"); + let attn_out = attn_out + .transpose_axes(&[0, 2, 1, 3]) + .expect("attn transpose") + .reshape(&[b, l, attn.num_heads * attn.head_dim]) + .expect("attn reshape"); + let attn_out = attn.o_proj.forward(&attn_out).expect("o_proj"); + + let h = &attn_out + &h; + let mlp_in = if let Some(norm) = mlp_in_norm { + norm.forward(&h).expect("mlp in norm") + } else { + h.clone() + }; + let mlp = mlp.forward(&mlp_in).expect("mlp"); + let layer0_out = &mlp + &h; + + ( + q, k, v, q_rope, k_rope, attn_out, h, mlp_in, mlp, layer0_out, + ) + } + _ => panic!("expected standard attention"), + }; + + let embed_last = h + .index((0, (ids.len() as i32 - 1), 0..8)) + .as_dtype(Dtype::Float32) + .expect("embed cast"); + let ln_last = ln + .index((0, (ids.len() as i32 - 1), 0..8)) + .as_dtype(Dtype::Float32) + .expect("ln cast"); + let q_last = q + .index((0, 0, (ids.len() as i32 - 1), 0..8)) + .as_dtype(Dtype::Float32) + .expect("q cast"); + let k_last = k + .index((0, 0, (ids.len() as i32 - 1), 0..8)) + .as_dtype(Dtype::Float32) + .expect("k cast"); + let v_last = v + .index((0, 0, (ids.len() as i32 - 1), 0..8)) + .as_dtype(Dtype::Float32) + .expect("v cast"); + let q_rope_last = q_rope + .index((0, 0, (ids.len() as i32 - 1), 0..8)) + .as_dtype(Dtype::Float32) + .expect("q rope cast"); + let k_rope_last = k_rope + .index((0, 0, (ids.len() as i32 - 1), 0..8)) + .as_dtype(Dtype::Float32) + .expect("k rope cast"); + let attn_out_last = attn_out + .index((0, (ids.len() as i32 - 1), 0..8)) + .as_dtype(Dtype::Float32) + .expect("attn_out cast"); + let h_last = h + .index((0, (ids.len() as i32 - 1), 0..8)) + .as_dtype(Dtype::Float32) + .expect("h cast"); + let mlp_in_last = mlp_in + .index((0, (ids.len() as i32 - 1), 0..8)) + .as_dtype(Dtype::Float32) + .expect("mlp_in cast"); + let mlp_last = mlp + .index((0, (ids.len() as i32 - 1), 0..8)) + .as_dtype(Dtype::Float32) + .expect("mlp cast"); + let layer0_out_last = layer0_out + .index((0, (ids.len() as i32 - 1), 0..8)) + .as_dtype(Dtype::Float32) + .expect("layer0_out cast"); + mlx_rs::transforms::eval([ + &embed_last, + &ln_last, + &q_last, + &k_last, + &v_last, + &q_rope_last, + &k_rope_last, + &attn_out_last, + &h_last, + &mlp_in_last, + &mlp_last, + &layer0_out_last, + ]) + .expect("eval debug slices"); + println!("embed {:?}", embed_last.as_slice::<f32>()); + println!("ln0 {:?}", ln_last.as_slice::<f32>()); + println!("q0 {:?}", q_last.as_slice::<f32>()); + println!("k0 {:?}", k_last.as_slice::<f32>()); + println!("v0 {:?}", v_last.as_slice::<f32>()); + println!("qrope0 {:?}", q_rope_last.as_slice::<f32>()); + println!("krope0 {:?}", k_rope_last.as_slice::<f32>()); + println!("attn_out0 {:?}", attn_out_last.as_slice::<f32>()); + println!("h0 {:?}", h_last.as_slice::<f32>()); + println!("mlp_in0 {:?}", mlp_in_last.as_slice::<f32>()); + println!("mlp0 {:?}", mlp_last.as_slice::<f32>()); + println!("layer0_out {:?}", layer0_out_last.as_slice::<f32>()); + + let mut h_all = model.embed_tokens.forward(&input).expect("embed all"); + for (i, layer) in model.layers.iter().enumerate() { + h_all = layer.forward_no_cache(&h_all, None).expect("layer forward"); + let slice = h_all + .index((0, (ids.len() as i32 - 1), 0..4)) + .as_dtype(Dtype::Float32) + .expect("layer slice"); + mlx_rs::transforms::eval([&slice]).expect("eval layer slice"); + println!("layer{idx}_h {:?}", slice.as_slice::<f32>(), idx = i); + } + + let h_norm = model.norm.forward(&h_all).expect("final norm"); + let h_norm_last = h_norm + .index((0, (ids.len() as i32 - 1), 0..8)) + .as_dtype(Dtype::Float32) + .expect("norm cast"); + let logits = if let Some(lm_head) = &model.lm_head { + lm_head.forward(&h_norm).expect("lm head") + } else { + model + .embed_tokens + .as_linear() + .forward(&h_norm) + .expect("tied lm head") + }; + let h_norm_f32 = h_norm.as_dtype(Dtype::Float32).expect("norm f32"); + let logits_f32 = if let Some(lm_head) = &model.lm_head { + lm_head.forward(&h_norm_f32).expect("lm head f32") + } else { + model + .embed_tokens + .as_linear() + .forward(&h_norm_f32) + .expect("tied lm head f32") + }; + let logits_last = logits + .index((0, (ids.len() as i32 - 1), std::ops::RangeFull)) + .as_dtype(Dtype::Float32) + .expect("logits cast"); + let logits_f32_last = logits_f32 + .index((0, (ids.len() as i32 - 1), std::ops::RangeFull)) + .as_dtype(Dtype::Float32) + .expect("logits f32 cast"); + mlx_rs::transforms::eval([&h_norm_last, &logits_last, &logits_f32_last]) + .expect("eval final outputs"); + let logits_slice = logits_last.as_slice::<f32>(); + let mut pairs: Vec<(usize, f32)> = logits_slice.iter().copied().enumerate().collect(); + pairs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + let top5: Vec<(usize, f32, String)> = pairs + .into_iter() + .take(5) + .map(|(idx, val)| { + ( + idx, + val, + model.tokenizer.id_to_token(idx as u32).unwrap_or_default(), + ) + }) + .collect(); + let logits_f32_slice = logits_f32_last.as_slice::<f32>(); + let mut pairs_f32: Vec<(usize, f32)> = logits_f32_slice.iter().copied().enumerate().collect(); + pairs_f32.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + let top5_f32: Vec<(usize, f32, String)> = pairs_f32 + .into_iter() + .take(5) + .map(|(idx, val)| { + ( + idx, + val, + model.tokenizer.id_to_token(idx as u32).unwrap_or_default(), + ) + }) + .collect(); + println!("final_norm {:?}", h_norm_last.as_slice::<f32>()); + println!("top5 {:?}", top5); + println!("top5_f32_norm {:?}", top5_f32); + + let no_cache_logits = model.forward_no_cache(&input).expect("no-cache forward"); + let mut caches = model.new_caches(); + let cache_logits = model.forward(&input, &mut caches).expect("cache forward"); + + let no_cache_last = no_cache_logits + .index((0, (ids.len() as i32 - 1), std::ops::RangeFull)) + .as_dtype(Dtype::Float32) + .expect("no-cache logits cast"); + let cache_last = cache_logits + .index((0, (ids.len() as i32 - 1), std::ops::RangeFull)) + .as_dtype(Dtype::Float32) + .expect("cache logits cast"); + mlx_rs::transforms::eval([&no_cache_last, &cache_last]).expect("eval logits slices"); + let describe_top = |name: &str, logits_slice: &[f32]| { + let mut pairs: Vec<(usize, f32)> = logits_slice.iter().copied().enumerate().collect(); + pairs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + let top10: Vec<(usize, f32, String)> = pairs + .into_iter() + .take(10) + .map(|(idx, val)| { + ( + idx, + val, + model.tokenizer.id_to_token(idx as u32).unwrap_or_default(), + ) + }) + .collect(); + println!("{name} top10 {:?}", top10); + }; + describe_top("gemma3 no_cache", no_cache_last.as_slice::<f32>()); + describe_top("gemma3 cache", cache_last.as_slice::<f32>()); + + let no_cache_token = argmax_last(&no_cache_logits).expect("argmax no-cache"); + let cache_token = argmax_last(&cache_logits).expect("argmax cache"); + let no_cache_piece = model + .tokenizer + .id_to_token(no_cache_token) + .unwrap_or_else(|| "<missing>".to_string()); + let cache_piece = model + .tokenizer + .id_to_token(cache_token) + .unwrap_or_else(|| "<missing>".to_string()); + + println!( + "no_cache_token={} piece={:?} cache_token={} piece={:?}", + no_cache_token, no_cache_piece, cache_token, cache_piece + ); + + assert_eq!(no_cache_token, cache_token); +} diff --git a/mesh-llm/src/mlx/sampling.rs b/mesh-llm/src/mlx/sampling.rs new file mode 100644 index 00000000..53bcd2f9 --- /dev/null +++ b/mesh-llm/src/mlx/sampling.rs @@ -0,0 +1,277 @@ +use anyhow::Result; +use mlx_rs::Array; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; + +#[derive(Debug, Clone)] +pub struct SamplingParams { + pub temperature: f32, + pub top_p: f32, + pub top_k: Option<usize>, + pub seed: Option<u64>, + pub suppressed_token_ids: Vec<u32>, +} + +impl Default for SamplingParams { + fn default() -> Self { + Self { + temperature: 0.0, + top_p: 1.0, + top_k: None, + seed: None, + suppressed_token_ids: Vec::new(), + } + } +} + +pub struct Sampler { + params: SamplingParams, + rng: Option<StdRng>, +} + +impl Sampler { + pub fn new(params: SamplingParams) -> Self { + let rng = params.seed.map(StdRng::seed_from_u64); + Self { params, rng } + } + + pub fn sample_next_token(&mut self, logits: &Array) -> Result<u32> { + let suppressed = self.params.suppressed_token_ids.as_slice(); + if self.params.temperature <= 0.0 { + return argmax_last_filtered(logits, suppressed); + } + + let mut candidates = last_logits(logits)? + .into_iter() + .enumerate() + .filter(|(token, _)| !suppressed.contains(&(*token as u32))) + .map(|(token, logit)| (token as u32, logit / self.params.temperature)) + .collect::<Vec<_>>(); + if candidates.is_empty() { + return crate::mlx::model::argmax_last(logits); + } + + candidates.sort_by(|left, right| right.1.total_cmp(&left.1)); + + if let Some(top_k) = self.params.top_k { + let top_k = top_k.max(1).min(candidates.len()); + candidates.truncate(top_k); + } + + let max_logit = candidates + .iter() + .map(|(_, logit)| *logit) + .max_by(|left, right| left.total_cmp(right)) + .unwrap_or(0.0); + + let mut weighted = candidates + .into_iter() + .map(|(token, logit)| (token, (logit - max_logit).exp())) + .collect::<Vec<_>>(); + let mut total = weighted.iter().map(|(_, weight)| *weight).sum::<f32>(); + + if self.params.top_p < 1.0 { + let mut cumulative = 0.0f32; + let mut kept = 0usize; + for (_, weight) in &weighted { + cumulative += *weight / total.max(f32::EPSILON); + kept += 1; + if cumulative >= self.params.top_p.max(0.0) { + break; + } + } + weighted.truncate(kept.max(1)); + total = weighted.iter().map(|(_, weight)| *weight).sum::<f32>(); + } + + let mut draw = self.random_f32() * total.max(f32::EPSILON); + for (token, weight) in weighted { + draw -= weight; + if draw <= 0.0 { + return Ok(token); + } + } + + Ok(crate::mlx::model::argmax_last(logits)?) + } + + fn random_f32(&mut self) -> f32 { + if let Some(rng) = &mut self.rng { + rng.random::<f32>() + } else { + rand::random::<f32>() + } + } +} + +fn argmax_last_filtered(logits: &Array, suppressed: &[u32]) -> Result<u32> { + let logits_vec = last_logits(logits)?; + let mut best: Option<(u32, f32)> = None; + for (token, logit) in logits_vec.iter().copied().enumerate() { + let token = token as u32; + if suppressed.contains(&token) { + continue; + } + match best { + Some((_, best_logit)) if logit <= best_logit => {} + _ => best = Some((token, logit)), + } + } + if let Some((token, _)) = best { + Ok(token) + } else { + crate::mlx::model::argmax_last(logits) + } +} + +pub struct StopBuffer { + sequences: Vec<String>, + pending: String, + holdback_chars: usize, + matched: bool, +} + +impl StopBuffer { + pub fn new(sequences: Vec<String>) -> Self { + let holdback_chars = sequences + .iter() + .map(|sequence| sequence.chars().count().saturating_sub(1)) + .max() + .unwrap_or(0); + Self { + sequences, + pending: String::new(), + holdback_chars, + matched: false, + } + } + + pub fn push(&mut self, text: &str) -> StopChunk { + if self.matched { + return StopChunk::default(); + } + + self.pending.push_str(text); + + if self.sequences.is_empty() { + return StopChunk { + emit: std::mem::take(&mut self.pending), + matched: false, + }; + } + + if let Some(index) = find_earliest_stop(&self.pending, &self.sequences) { + let emit = self.pending[..index].to_string(); + self.pending.clear(); + self.matched = true; + return StopChunk { + emit, + matched: true, + }; + } + + let safe_len = safe_prefix_len(&self.pending, self.holdback_chars); + if safe_len == 0 { + return StopChunk::default(); + } + + let emit = self.pending[..safe_len].to_string(); + self.pending.drain(..safe_len); + StopChunk { + emit, + matched: false, + } + } + + pub fn finish(&mut self) -> String { + if self.matched { + String::new() + } else { + std::mem::take(&mut self.pending) + } + } +} + +#[derive(Default)] +pub struct StopChunk { + pub emit: String, + pub matched: bool, +} + +fn last_logits(logits: &Array) -> Result<Vec<f32>> { + let shape = logits.shape(); + let flat = if shape.len() == 3 { + let last_idx = (shape[1] - 1) as i32; + let idx = Array::from_int(last_idx); + logits.take_axis(&idx, 1)?.reshape(&[-1])? + } else { + logits.reshape(&[-1])? + }; + let flat = flat.as_type::<f32>()?; + mlx_rs::transforms::eval([&flat])?; + Ok(flat.as_slice::<f32>().to_vec()) +} + +fn find_earliest_stop(text: &str, sequences: &[String]) -> Option<usize> { + sequences + .iter() + .filter_map(|sequence| text.find(sequence)) + .min() +} + +fn safe_prefix_len(text: &str, holdback_chars: usize) -> usize { + if holdback_chars == 0 { + return text.len(); + } + let total_chars = text.chars().count(); + if total_chars <= holdback_chars { + return 0; + } + let safe_chars = total_chars - holdback_chars; + text.char_indices() + .nth(safe_chars) + .map(|(index, _)| index) + .unwrap_or(text.len()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn stop_buffer_holds_back_partial_match() { + let mut buffer = StopBuffer::new(vec!["</s>".to_string()]); + let first = buffer.push("hello</"); + assert_eq!(first.emit, "hell"); + assert!(!first.matched); + + let second = buffer.push("s>world"); + assert_eq!(second.emit, "o"); + assert!(second.matched); + assert!(buffer.finish().is_empty()); + } + + #[test] + fn stop_buffer_flushes_when_no_stop_matches() { + let mut buffer = StopBuffer::new(vec!["STOP".to_string()]); + let first = buffer.push("hel"); + assert!(first.emit.is_empty()); + let second = buffer.push("lo"); + assert_eq!(second.emit, "he"); + assert_eq!(buffer.finish(), "llo"); + } + + #[test] + fn sampler_suppresses_tokens_during_argmax() { + let logits = Array::from_slice(&[0.1f32, 0.9f32, 0.8f32], &[1, 1, 3]); + let mut sampler = Sampler::new(SamplingParams { + temperature: 0.0, + top_p: 1.0, + top_k: None, + seed: None, + suppressed_token_ids: vec![1], + }); + let token = sampler.sample_next_token(&logits).unwrap(); + assert_eq!(token, 2); + } +} diff --git a/mesh-llm/src/mlx/server.rs b/mesh-llm/src/mlx/server.rs new file mode 100644 index 00000000..153895c9 --- /dev/null +++ b/mesh-llm/src/mlx/server.rs @@ -0,0 +1,2310 @@ +//! In-process HTTP server for MLX inference. +//! +//! Drop-in replacement for llama-server β€” speaks the same OpenAI-compatible +//! HTTP API on a local port so the existing proxy routes to it unchanged. + +use super::model::{self, MlxModel}; +use super::sampling::{Sampler, SamplingParams, StopBuffer}; +use crate::inference::launch::{InferenceServerHandle, InferenceServerProcess}; +use anyhow::{Context, Result}; +use mlx_rs::Array; +use std::sync::{Arc, Once}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::sync::{watch, Mutex}; + +/// Shared inference state behind the server. +struct InferState { + model: MlxModel, + model_name: String, + /// Prompt cache: KV caches + token IDs from the last request. + /// On the next request, we find the longest common prefix and + /// skip re-prefilling those tokens β€” huge win for agent workloads + /// where the system prompt + conversation history grows incrementally. + prompt_cache: Option<PromptCache>, +} + +struct PromptCache { + tokens: Vec<u32>, + caches: Vec<model::KVCache>, +} + +#[derive(Clone)] +struct GenerationConfig { + max_tokens: usize, + sampling: SamplingParams, + stop_sequences: Vec<String>, + stop_token_ids: Vec<u32>, + hidden_reasoning: Option<HiddenReasoningTokens>, + response_policy: ResponsePolicy, +} + +#[derive(Clone, Copy)] +struct HiddenReasoningTokens { + start_turn_token_id: u32, + end_reasoning_token_id: u32, +} + +struct GenerationOutcome { + text: String, + prompt_tokens: usize, + completion_tokens: usize, + finish_reason: &'static str, +} + +fn encode_prompt_tokens(model: &MlxModel, prompt: &str) -> Result<Vec<u32>> { + let patch = match &model.tokenizer_spacing_patch { + Some(patch) => patch, + None => { + return Ok(model + .tokenizer + .encode(prompt, false) + .map_err(|e| anyhow::anyhow!("tokenizer encode: {e}"))? + .get_ids() + .to_vec()) + } + }; + + let mut tokens = Vec::new(); + let mut cursor = 0usize; + while cursor < prompt.len() { + let remaining = &prompt[cursor..]; + if let Some((special, token_id)) = patch + .special_tokens + .iter() + .find(|(special, _)| remaining.starts_with(special)) + { + tokens.push(*token_id); + if remaining[special.len()..].starts_with(' ') { + tokens.push(patch.space_token_id); + } + cursor += special.len(); + continue; + } + + let next_special = patch + .special_tokens + .iter() + .filter_map(|(special, _)| remaining.find(special)) + .min() + .unwrap_or(remaining.len()); + let segment = &remaining[..next_special]; + if !segment.is_empty() { + let encoding = model + .tokenizer + .encode(segment, false) + .map_err(|e| anyhow::anyhow!("tokenizer encode: {e}"))?; + tokens.extend_from_slice(encoding.get_ids()); + } + cursor += next_special; + } + + Ok(tokens) +} + +fn sanitize_behavior_output(_model: &MlxModel, text: &str) -> String { + let normalized = collapse_alpha_outline_markers(text); + trim_to_behavior_safe_prefix(&normalized) +} + +fn collapse_alpha_outline_markers(text: &str) -> String { + let mut output = String::with_capacity(text.len()); + for (index, line) in text.lines().enumerate() { + if index > 0 { + output.push('\n'); + } + output.push_str(&collapse_alpha_outline_marker_line(line)); + } + output +} + +fn collapse_alpha_outline_marker_line(line: &str) -> String { + let indent_len = line.len() - line.trim_start_matches(char::is_whitespace).len(); + let (indent, trimmed) = line.split_at(indent_len); + for bullet in ["- ", "* "] { + if let Some(rest) = trimmed.strip_prefix(bullet) { + let bytes = rest.as_bytes(); + if bytes.len() >= 3 + && bytes[0].is_ascii_alphabetic() + && bytes[1] == b'.' + && bytes[2] == b' ' + { + return format!("{indent}{bullet}{}", &rest[3..]); + } + } + } + line.to_string() +} + +fn trim_to_behavior_safe_prefix(text: &str) -> String { + if !has_behavior_repetition_issue(text) { + return text.to_string(); + } + + let mut cuts: Vec<usize> = text.char_indices().map(|(index, _)| index).collect(); + cuts.push(text.len()); + cuts.sort_unstable(); + cuts.dedup(); + + for end in cuts.into_iter().rev() { + let prefix = text[..end].trim_end(); + if prefix.is_empty() { + continue; + } + if !has_behavior_repetition_issue(prefix) { + return prefix.to_string(); + } + } + + text.trim_end().to_string() +} + +fn has_behavior_repetition_issue(text: &str) -> bool { + let normalized = text.trim(); + if normalized.is_empty() { + return false; + } + + let lines: Vec<&str> = normalized + .lines() + .map(str::trim) + .filter(|line| !line.is_empty()) + .collect(); + if has_repeat_at_least_three(lines.iter().copied()) { + return true; + } + + let sentences = split_behavior_sentences(normalized); + if has_repeat_at_least_three(sentences.iter().map(String::as_str)) { + return true; + } + + let tokens = tokenize_behavior_words(normalized); + if repeated_ngram( + tokens + .iter() + .map(String::as_str) + .collect::<Vec<_>>() + .as_slice(), + 6, + 3, + ) + .is_some() + { + return true; + } + + if tokens.len() >= 80 { + let tail = &tokens[tokens.len() - 80..]; + let unique = tail.iter().collect::<std::collections::BTreeSet<_>>().len(); + if (unique as f32) / (tail.len() as f32) < 0.30 { + return true; + } + } + + false +} + +fn has_repeat_at_least_three<'a>(items: impl Iterator<Item = &'a str>) -> bool { + let mut counts = std::collections::HashMap::<&str, usize>::new(); + for item in items { + let count = counts.entry(item).or_insert(0); + *count += 1; + if *count >= 3 { + return true; + } + } + false +} + +fn split_behavior_sentences(text: &str) -> Vec<String> { + let mut sentences = Vec::new(); + let mut start = 0usize; + let chars: Vec<(usize, char)> = text.char_indices().collect(); + let mut index = 0usize; + while index < chars.len() { + let (byte_idx, ch) = chars[index]; + let is_sentence_break = matches!(ch, '.' | '!' | '?') + && chars + .get(index + 1) + .map(|(_, c)| c.is_whitespace()) + .unwrap_or(true); + let is_line_break = ch == '\n'; + if is_sentence_break || is_line_break { + let end = if is_line_break { + byte_idx + } else { + chars + .get(index + 1) + .map(|(next_idx, _)| *next_idx) + .unwrap_or(text.len()) + }; + let part = text[start..end].trim().to_lowercase(); + if !part.is_empty() { + sentences.push(part); + } + start = chars + .get(index + 1) + .map(|(next_idx, _)| *next_idx) + .unwrap_or(text.len()); + } + index += 1; + } + let tail = text[start..].trim().to_lowercase(); + if !tail.is_empty() { + sentences.push(tail); + } + sentences +} + +fn tokenize_behavior_words(text: &str) -> Vec<String> { + text.split_whitespace() + .map(|token| token.to_lowercase()) + .collect() +} + +fn repeated_ngram(tokens: &[&str], size: usize, threshold: usize) -> Option<String> { + if tokens.len() < size { + return None; + } + let mut counts = std::collections::HashMap::<String, usize>::new(); + for window in tokens.windows(size) { + let key = window.join(" "); + let count = counts.entry(key.clone()).or_insert(0); + *count += 1; + if *count >= threshold { + return Some(key); + } + } + None +} + +enum StreamEvent { + Text(String), + Done(&'static str), +} + +#[derive(Clone, Default)] +struct ResponsePolicy { + strip_reasoning_blocks: bool, + tagged_reasoning: Vec<crate::mlx::template::TaggedReasoningBlock>, +} + +#[derive(Default)] +struct HiddenReasoningState { + active: bool, + skip_leading_whitespace: bool, +} + +struct ResponseFilter { + policy: ResponsePolicy, + tagged_reasoning: Vec<TaggedBlockFilter>, +} + +struct TaggedBlockFilter { + start: String, + end: String, + inside_think: bool, + carry: String, +} + +impl ResponseFilter { + fn new(policy: ResponsePolicy) -> Self { + let tagged_reasoning = policy + .tagged_reasoning + .iter() + .cloned() + .map(TaggedBlockFilter::new) + .collect(); + Self { + policy, + tagged_reasoning, + } + } + + fn push(&mut self, text: &str) -> String { + if !self.policy.strip_reasoning_blocks { + return text.to_string(); + } + let mut filtered = text.to_string(); + for block in &mut self.tagged_reasoning { + filtered = block.push(&filtered); + } + filtered + } + + fn finish(&mut self) -> String { + if !self.policy.strip_reasoning_blocks { + return String::new(); + } + let mut tail = String::new(); + for (index, block) in self.tagged_reasoning.iter_mut().enumerate() { + let flushed = block.finish(); + if index == 0 { + tail = flushed; + } else if !tail.is_empty() { + tail = block.push(&tail); + tail.push_str(&block.finish()); + } + } + tail + } +} + +impl TaggedBlockFilter { + fn new(block: crate::mlx::template::TaggedReasoningBlock) -> Self { + Self { + start: block.start, + end: block.end, + inside_think: false, + carry: String::new(), + } + } + + fn push(&mut self, text: &str) -> String { + let mut input = std::mem::take(&mut self.carry); + input.push_str(text); + let mut out = String::new(); + + loop { + if self.inside_think { + if let Some(idx) = input.find(&self.end) { + input.drain(..idx + self.end.len()); + self.inside_think = false; + continue; + } + let keep = partial_tag_suffix_len(&input, &[self.end.as_str()]); + if keep > 0 { + self.carry = input[input.len() - keep..].to_string(); + } + return out; + } + + let next_start = input.find(&self.start); + let next_end = input.find(&self.end); + match (next_start, next_end) { + (Some(start), Some(end)) if end < start => { + out.push_str(&input[..end]); + input.drain(..end + self.end.len()); + } + (Some(start), _) => { + out.push_str(&input[..start]); + input.drain(..start + self.start.len()); + self.inside_think = true; + } + (None, Some(end)) => { + out.push_str(&input[..end]); + input.drain(..end + self.end.len()); + } + (None, None) => { + let keep = + partial_tag_suffix_len(&input, &[self.start.as_str(), self.end.as_str()]); + out.push_str(&input[..input.len() - keep]); + if keep > 0 { + self.carry = input[input.len() - keep..].to_string(); + } + return out; + } + } + } + } + + fn finish(&mut self) -> String { + if self.inside_think { + self.inside_think = false; + self.carry.clear(); + return String::new(); + } + std::mem::take(&mut self.carry) + } +} + +fn partial_tag_suffix_len(text: &str, tags: &[&str]) -> usize { + let mut best = 0; + for tag in tags { + for len in 1..tag.len() { + if text.ends_with(&tag[..len]) { + best = best.max(len); + } + } + } + best +} + +/// Start the MLX inference server on the given port. +/// Returns an in-process inference server handle plus a death channel. +/// +/// This is the MLX equivalent of `launch::start_llama_server` β€” the caller +/// gets back the same process wrapper used by the llama.cpp backend. +pub async fn start_mlx_server( + model_dir: &std::path::Path, + model_name: String, + port: u16, +) -> Result<InferenceServerProcess> { + static WARN_ONCE: Once = Once::new(); + WARN_ONCE.call_once(|| { + eprintln!( + "πŸ§ͺ MLX support is experimental in mesh-llm. Prefer GGUF for the most mature path, and please file any issues at https://github.com/michaelneale/mesh-llm/issues." + ); + }); + + // Load model on a blocking thread (touches disk + GPU init) + let dir = model_dir.to_path_buf(); + let model = tokio::task::spawn_blocking(move || MlxModel::load(&dir)) + .await + .context("MLX model load panicked")??; + + tracing::info!( + "MLX server: model loaded β€” {} layers, vocab={}, serving on :{port}", + model.config.num_hidden_layers, + model.config.vocab_size, + ); + + let context_length = model.config.max_position_embeddings as u32; + + let state = Arc::new(Mutex::new(InferState { + model, + model_name, + prompt_cache: None, + })); + + let listener = TcpListener::bind(format!("127.0.0.1:{port}")) + .await + .with_context(|| format!("MLX server: failed to bind port {port}"))?; + + let (death_tx, death_rx) = tokio::sync::oneshot::channel(); + let (shutdown_tx, mut shutdown_rx) = watch::channel(false); + + tokio::spawn(async move { + loop { + tokio::select! { + changed = shutdown_rx.changed() => { + if changed.is_ok() && *shutdown_rx.borrow() { + break; + } + } + accept = listener.accept() => { + let (stream, _addr) = match accept { + Ok(s) => s, + Err(e) => { + tracing::warn!("MLX server: accept error: {e}"); + continue; + } + }; + let state = state.clone(); + tokio::spawn(async move { + if let Err(e) = handle_connection(stream, state).await { + tracing::debug!("MLX server: connection error: {e}"); + } + }); + } + } + } + let _ = death_tx.send(()); + }); + + Ok(InferenceServerProcess { + handle: InferenceServerHandle::in_process(shutdown_tx), + context_length, + death_rx, + }) +} + +/// Parse a raw HTTP request from the stream and dispatch. +async fn handle_connection( + mut stream: tokio::net::TcpStream, + state: Arc<Mutex<InferState>>, +) -> Result<()> { + let _ = stream.set_nodelay(true); + + let mut buf = vec![0u8; 64 * 1024]; + let mut filled = 0usize; + + // Read until we have complete headers + loop { + let n = stream.read(&mut buf[filled..]).await?; + if n == 0 { + return Ok(()); + } + filled += n; + if filled > 4 && buf[..filled].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + if filled >= buf.len() { + send_response(&mut stream, 413, r#"{"error":"request header too large"}"#).await?; + return Ok(()); + } + } + + // Find header/body split + let header_end = buf[..filled] + .windows(4) + .position(|w| w == b"\r\n\r\n") + .unwrap() + + 4; + + // Parse method, path, content-length from headers (own the strings so buf is free) + let header_str = String::from_utf8_lossy(&buf[..header_end]).to_string(); + + let first_line = header_str.lines().next().unwrap_or(""); + let parts: Vec<&str> = first_line.split_whitespace().collect(); + let method = if parts.len() >= 2 { + parts[0].to_string() + } else { + String::new() + }; + let path = if parts.len() >= 2 { + parts[1].to_string() + } else { + String::new() + }; + + let content_length: usize = header_str + .lines() + .find_map(|line| { + let lower = line.to_ascii_lowercase(); + if lower.starts_with("content-length:") { + lower.split(':').nth(1)?.trim().parse().ok() + } else { + None + } + }) + .unwrap_or(0); + + // Read remaining body if needed + // Guard against oversized requests (max 16 MiB) to prevent OOM from a + // malicious or accidental huge Content-Length. + const MAX_BODY_SIZE: usize = 16 * 1024 * 1024; + if content_length > MAX_BODY_SIZE { + send_response(&mut stream, 413, r#"{"error":"request body too large"}"#).await?; + return Ok(()); + } + let body_so_far = filled - header_end; + if body_so_far < content_length { + let remaining = content_length - body_so_far; + if filled + remaining > buf.len() { + buf.resize(filled + remaining, 0); + } + let mut read = 0; + while read < remaining { + let n = stream + .read(&mut buf[filled + read..filled + remaining]) + .await?; + if n == 0 { + break; + } + read += n; + } + filled += read; + } + + let body = &buf[header_end..filled.min(header_end + content_length)]; + + match (method.as_str(), path.as_str()) { + ("GET", "/health") => { + send_response(&mut stream, 200, r#"{"status":"ok"}"#).await?; + } + ("GET", "/v1/models") | ("GET", "/models") => { + let state = state.lock().await; + let resp = serde_json::json!({ + "object": "list", + "data": [{ + "id": &state.model_name, + "object": "model", + "owned_by": "mlx", + }] + }); + send_response(&mut stream, 200, &resp.to_string()).await?; + } + ("POST", "/v1/chat/completions") => { + handle_chat_completions(&mut stream, body, state).await?; + } + ("POST", "/v1/completions") => { + handle_completions(&mut stream, body, state).await?; + } + _ => { + send_response(&mut stream, 404, r#"{"error":"not found"}"#).await?; + } + } + + Ok(()) +} + +/// Handle POST /v1/chat/completions β€” the main inference endpoint. +async fn handle_chat_completions( + stream: &mut tokio::net::TcpStream, + body: &[u8], + state: Arc<Mutex<InferState>>, +) -> Result<()> { + let req: serde_json::Value = + serde_json::from_slice(body).context("invalid JSON in chat completions request")?; + + let stream_mode = req["stream"].as_bool().unwrap_or(false); + let (generation, prompt) = { + let state = state.lock().await; + let generation = parse_generation_config(&req, &state.model); + let reasoning_template = state.model.prompt_template.reasoning_template(); + let prompt_req = prepare_reasoning_request( + &req, + state + .model + .prompt_template + .behavior() + .prompt_template + .as_deref(), + state.model.reasoning_family, + &reasoning_template, + &generation.response_policy, + ); + let prompt = render_chat_prompt_from_request(&state.model.prompt_template, &prompt_req)?; + (generation, prompt) + }; + let model_field = req["model"].as_str().unwrap_or(""); + + if stream_mode { + generate_streaming(stream, &prompt, generation, model_field, state).await + } else { + generate_blocking(stream, &prompt, generation, model_field, state).await + } +} + +fn render_chat_prompt_from_request( + template: &crate::mlx::template::PromptTemplate, + req: &serde_json::Value, +) -> Result<String> { + template.render_request(req) +} + +fn prepare_reasoning_request( + req: &serde_json::Value, + prompt_template: Option<&str>, + reasoning_family: model::ReasoningFamily, + reasoning_template: &crate::mlx::template::ReasoningTemplate, + policy: &ResponsePolicy, +) -> serde_json::Value { + let needs_olmo2_brevity_nudge = matches!(prompt_template, Some("olmo2")); + if !needs_olmo2_brevity_nudge + && (!policy.strip_reasoning_blocks + || (reasoning_family == model::ReasoningFamily::None + && !reasoning_template.supports_explicit_reasoning)) + { + return req.clone(); + } + + const DIRECT_ANSWER_NUDGE: &str = + "Respond directly with the final answer. Do not output <think> tags. Do not produce internal reasoning, hidden chain-of-thought, analysis, or preamble unless the user explicitly asks for it. Start answering immediately."; + const OLMO2_BREVITY_NUDGE: &str = + "Be concise. Prefer short plain paragraphs or a flat bullet list. Avoid nested outlines, repeated section labels, or repeating the same equation/premise multiple times. For follow-up turns, answer only the new request. Do not restate the previous answer, and do not dump full code/examples unless the user explicitly asks for the full content. If the user asks for a small styling or edit change, provide only the minimal change."; + + let mut patched = req.clone(); + let Some(messages) = patched + .get_mut("messages") + .and_then(|value| value.as_array_mut()) + else { + return patched; + }; + let nudge = if needs_olmo2_brevity_nudge { + format!("{DIRECT_ANSWER_NUDGE}\n{OLMO2_BREVITY_NUDGE}") + } else { + DIRECT_ANSWER_NUDGE.to_string() + }; + + match messages.first_mut() { + Some(first) + if first + .get("role") + .and_then(|value| value.as_str()) + .is_some_and(|role| role == "system") => + { + if let Some(content) = first.get_mut("content") { + match content { + serde_json::Value::String(text) => { + if !text.contains(&nudge) { + if !text.is_empty() { + text.push_str("\n\n"); + } + text.push_str(&nudge); + } + } + serde_json::Value::Array(items) => { + items.push(serde_json::json!({ + "type": "text", + "text": nudge + })); + } + _ => {} + } + } + } + _ => { + messages.insert( + 0, + serde_json::json!({ + "role": "system", + "content": nudge + }), + ); + } + } + + patched +} + +/// Handle POST /v1/completions β€” raw text completion. +/// This endpoint is not implemented; return a structured 501 so clients get a clear error +/// rather than a misleading chat-completion-shaped response. +async fn handle_completions( + stream: &mut tokio::net::TcpStream, + _body: &[u8], + _state: Arc<Mutex<InferState>>, +) -> Result<()> { + let resp = serde_json::json!({ + "error": { + "message": "/v1/completions is not implemented by this server. Use /v1/chat/completions instead.", + "type": "not_implemented_error", + "param": null, + "code": "unsupported_endpoint" + } + }); + send_response(stream, 501, &resp.to_string()).await +} + +/// Non-streaming: run full generation, return one JSON response. +async fn generate_blocking( + stream: &mut tokio::net::TcpStream, + prompt: &str, + generation: GenerationConfig, + model_field: &str, + state: Arc<Mutex<InferState>>, +) -> Result<()> { + let prompt = prompt.to_string(); + let model_field = model_field.to_string(); + let outcome_result = tokio::task::spawn_blocking(move || -> Result<GenerationOutcome> { + let mut state = state.blocking_lock(); + run_inference(&mut state, &prompt, &generation) + }) + .await; + let outcome = match outcome_result { + Ok(Ok(outcome)) => outcome, + Ok(Err(err)) => { + tracing::error!("MLX blocking generation failed: {err:#}"); + let payload = serde_json::json!({ + "error": { + "message": format!("MLX generation failed: {err}"), + "type": "server_error", + "param": null, + "code": "mlx_generation_failed" + } + }); + return send_response(stream, 500, &payload.to_string()).await; + } + Err(err) => { + tracing::error!("MLX blocking generation task failed: {err:#}"); + let payload = serde_json::json!({ + "error": { + "message": format!("MLX generation task failed: {err}"), + "type": "server_error", + "param": null, + "code": "mlx_generation_task_failed" + } + }); + return send_response(stream, 500, &payload.to_string()).await; + } + }; + + let resp = serde_json::json!({ + "id": format!("chatcmpl-mlx-{}", std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH).unwrap_or_default().as_millis()), + "object": "chat.completion", + "model": model_field, + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": outcome.text, + }, + "finish_reason": outcome.finish_reason, + }], + "usage": { + "prompt_tokens": outcome.prompt_tokens, + "completion_tokens": outcome.completion_tokens, + "total_tokens": outcome.prompt_tokens + outcome.completion_tokens, + } + }); + send_response(stream, 200, &resp.to_string()).await +} + +/// Streaming: send SSE chunks as tokens are generated. +async fn generate_streaming( + stream: &mut tokio::net::TcpStream, + prompt: &str, + generation: GenerationConfig, + model_field: &str, + state: Arc<Mutex<InferState>>, +) -> Result<()> { + // Channel for token-by-token streaming + let (tx, mut rx) = tokio::sync::mpsc::channel::<StreamEvent>(64); + + let prompt = prompt.to_string(); + tokio::task::spawn_blocking(move || { + let mut state = state.blocking_lock(); + let finish_reason = + run_inference_streaming(&mut state, &prompt, &generation, &tx).unwrap_or("stop"); + let _ = tx.blocking_send(StreamEvent::Done(finish_reason)); + }); + + // Send SSE headers + let header = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nCache-Control: no-cache\r\nConnection: close\r\n\r\n" + ); + stream.write_all(header.as_bytes()).await?; + + let id = format!( + "chatcmpl-mlx-{}", + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis() + ); + + while let Some(maybe_token) = rx.recv().await { + match maybe_token { + StreamEvent::Text(text) => { + let chunk = serde_json::json!({ + "id": &id, + "object": "chat.completion.chunk", + "model": &model_field, + "choices": [{ + "index": 0, + "delta": { "content": text }, + "finish_reason": null, + }] + }); + let sse = format!("data: {}\n\n", chunk); + if stream.write_all(sse.as_bytes()).await.is_err() { + break; + } + } + StreamEvent::Done(finish_reason) => { + // Final chunk with finish_reason + let chunk = serde_json::json!({ + "id": &id, + "object": "chat.completion.chunk", + "model": &model_field, + "choices": [{ + "index": 0, + "delta": {}, + "finish_reason": finish_reason, + }] + }); + let sse = format!("data: {}\n\ndata: [DONE]\n\n", chunk); + let _ = stream.write_all(sse.as_bytes()).await; + break; + } + } + } + + let _ = stream.shutdown().await; + Ok(()) +} + +/// Prefill prompt tokens. Uses chunked prefill (with eval barriers between +/// chunks) to keep the computation graph and peak memory bounded. The chunk +/// size of 2048 matches mlx-lm's default `prefill_step_size`. For prompts +/// ≀2048 tokens, there is only one chunk so no overhead. +const PREFILL_STEP_SIZE: usize = 2048; + +fn prefill_logits( + model: &MlxModel, + prompt_tokens: &[u32], + caches: &mut [model::KVCache], +) -> Result<Array> { + let total = prompt_tokens.len(); + + if model.tokenwise_prefill() { + let mut logits = None; + for &token in prompt_tokens { + let input = Array::from_slice(&[token], &[1, 1]); + logits = Some(model.forward(&input, caches)?); + } + return logits.context("tokenwise prefill received empty prompt"); + } + + if total <= PREFILL_STEP_SIZE { + // Small prompt β€” single forward pass, no eval barriers + let input = Array::from_slice(prompt_tokens, &[1, total as i32]); + return model.forward(&input, caches); + } + + // Large prompt β€” chunk to avoid huge computation graphs + let mut pos = 0; + while total - pos > PREFILL_STEP_SIZE { + let chunk = &prompt_tokens[pos..pos + PREFILL_STEP_SIZE]; + let input = Array::from_slice(chunk, &[1, PREFILL_STEP_SIZE as i32]); + model.forward(&input, caches)?; + mlx_rs::transforms::eval(caches.iter().flat_map(|c| c.arrays()))?; + pos += PREFILL_STEP_SIZE; + } + + // Final chunk β€” get logits for the first generated token + let last_chunk = &prompt_tokens[pos..]; + let input = Array::from_slice(last_chunk, &[1, last_chunk.len() as i32]); + model.forward(&input, caches) +} + +fn replay_last_prompt_token_logits( + model: &MlxModel, + prompt_tokens: &[u32], + caches: &mut [model::KVCache], +) -> Result<Array> { + let last = *prompt_tokens + .last() + .context("replay_last_prompt_token_logits called with empty prompt")?; + let rewind_to = caches[0].offset().saturating_sub(1); + for cache in caches.iter_mut() { + anyhow::ensure!( + cache.trim_to(rewind_to)?, + "cannot rewind MLX cache to token {} during replay", + rewind_to + ); + } + let input = Array::from_slice(&[last], &[1, 1]); + model.forward(&input, caches) +} + +/// Find the longest common prefix between cached tokens and new tokens. +fn common_prefix_len(cached: &[u32], new: &[u32]) -> usize { + cached + .iter() + .zip(new.iter()) + .take_while(|(a, b)| a == b) + .count() +} + +/// Set up caches for a new request, reusing the prompt cache if possible. +/// Returns (caches, tokens_to_prefill) where tokens_to_prefill is the +/// suffix of prompt_tokens that still needs to be forwarded. +fn setup_caches_with_reuse<'a>( + state: &mut InferState, + prompt_tokens: &'a [u32], +) -> (Vec<model::KVCache>, &'a [u32]) { + if !state.model.prompt_cache_reuse() { + state.prompt_cache = None; + return (state.model.new_caches(), prompt_tokens); + } + if let Some(ref cached) = state.prompt_cache { + let prefix_len = common_prefix_len(&cached.tokens, prompt_tokens); + if prefix_len > 0 { + let mut caches = state.prompt_cache.take().unwrap().caches; + if caches.iter().all(|c| c.can_trim_to(prefix_len)) { + // Reuse cached KV β€” trim to prefix length and return suffix + for c in &mut caches { + let trimmed = c + .trim_to(prefix_len) + .expect("cache trim should succeed after can_trim_to check"); + debug_assert!(trimmed); + } + tracing::info!( + "MLX prompt cache: reusing {prefix_len}/{} tokens ({} new)", + prompt_tokens.len(), + prompt_tokens.len() - prefix_len, + ); + return (caches, &prompt_tokens[prefix_len..]); + } + tracing::info!( + "MLX prompt cache: cannot reuse {prefix_len}/{} tokens after cache eviction; rebuilding", + prompt_tokens.len(), + ); + } + } + // No cache hit β€” fresh caches + (state.model.new_caches(), prompt_tokens) +} + +/// Save caches + tokens for future reuse. +fn save_prompt_cache(state: &mut InferState, tokens: Vec<u32>, caches: Vec<model::KVCache>) { + if !state.model.prompt_cache_reuse() { + state.prompt_cache = None; + return; + } + state.prompt_cache = Some(PromptCache { tokens, caches }); +} + +fn is_stop_token(token: u32, generation: &GenerationConfig) -> bool { + generation.stop_token_ids.binary_search(&token).is_ok() +} + +fn should_consume_hidden_reasoning_token( + token: u32, + generation: &GenerationConfig, + state: &mut HiddenReasoningState, + visible_tokens: usize, +) -> bool { + let Some(hidden) = generation.hidden_reasoning else { + return false; + }; + + if state.active { + if token == hidden.end_reasoning_token_id { + state.active = false; + state.skip_leading_whitespace = true; + } + return true; + } + + if visible_tokens == 0 && token == hidden.start_turn_token_id { + state.active = true; + return true; + } + + false +} + +/// Run inference synchronously (called from blocking thread). +fn run_inference( + state: &mut InferState, + prompt: &str, + generation: &GenerationConfig, +) -> Result<GenerationOutcome> { + if state.model.cacheless_generation() { + return run_inference_cacheless(state, prompt, generation); + } + let prompt_tokens = encode_prompt_tokens(&state.model, prompt)?; + let prompt_len = prompt_tokens.len(); + if prompt_tokens.is_empty() { + anyhow::bail!("prompt encoded to zero tokens β€” check that the prompt is non-empty"); + } + tracing::debug!("MLX prompt text: {:?}", prompt); + tracing::debug!("MLX prompt tokens: {:?}", prompt_tokens); + + let (mut caches, suffix) = setup_caches_with_reuse(state, &prompt_tokens); + let mut sampler = Sampler::new(generation.sampling.clone()); + let mut stop_buffer = StopBuffer::new(generation.stop_sequences.clone()); + let mut response_filter = ResponseFilter::new(generation.response_policy.clone()); + + if generation.max_tokens == 0 { + if !suffix.is_empty() { + let _ = prefill_logits(&state.model, suffix, &mut caches)?; + } + save_prompt_cache(state, prompt_tokens, caches); + return Ok(GenerationOutcome { + text: String::new(), + prompt_tokens: prompt_len, + completion_tokens: 0, + finish_reason: "length", + }); + } + + // Prefill only the new suffix + let mut next_token = if suffix.is_empty() { + // Entire prompt was cached β€” re-forward last token to get logits + let logits = replay_last_prompt_token_logits(&state.model, &prompt_tokens, &mut caches)?; + sampler.sample_next_token(&logits)? + } else { + let logits = prefill_logits(&state.model, suffix, &mut caches)?; + let logits = if state.model.tokenwise_prefill() && state.model.can_replay_prompt_logits() { + replay_last_prompt_token_logits(&state.model, &prompt_tokens, &mut caches)? + } else { + logits + }; + sampler.sample_next_token(&logits)? + }; + tracing::info!( + "MLX first sampled token: id={} eos={} prompt_tokens={}", + next_token, + is_eos(next_token, &state.model.config), + prompt_len + ); + + let mut decode_stream = state.model.tokenizer.decode_stream(true); + let mut text = String::new(); + let mut completion_tokens = 0usize; + let mut finish_reason = "length"; + let mut hidden_reasoning_state = HiddenReasoningState::default(); + let max_sampled_tokens = if generation.hidden_reasoning.is_some() { + generation + .max_tokens + .saturating_mul(32) + .max(generation.max_tokens) + } else { + generation.max_tokens + }; + + // Decode + for step in 0..max_sampled_tokens { + if is_eos(next_token, &state.model.config) || is_stop_token(next_token, generation) { + finish_reason = "stop"; + break; + } + + let pending_logits = + if completion_tokens < generation.max_tokens && step + 1 < max_sampled_tokens { + let input = Array::from_slice(&[next_token], &[1, 1]); + let logits = state.model.forward(&input, &mut caches)?; + mlx_rs::transforms::async_eval([&logits])?; + Some(logits) + } else { + None + }; + + if should_consume_hidden_reasoning_token( + next_token, + generation, + &mut hidden_reasoning_state, + completion_tokens, + ) { + if let Some(logits) = pending_logits { + next_token = sampler.sample_next_token(&logits)?; + } + continue; + } + + let piece = decode_stream + .step(next_token) + .map_err(|e| anyhow::anyhow!("tokenizer decode: {e}"))? + .unwrap_or_default(); + if hidden_reasoning_state.skip_leading_whitespace && piece.trim().is_empty() { + if let Some(logits) = pending_logits { + next_token = sampler.sample_next_token(&logits)?; + } + continue; + } + hidden_reasoning_state.skip_leading_whitespace = false; + completion_tokens += 1; + let chunk = stop_buffer.push(&piece); + text.push_str(&response_filter.push(&chunk.emit)); + if chunk.matched { + finish_reason = "stop"; + break; + } + if completion_tokens >= generation.max_tokens { + break; + } + + if let Some(logits) = pending_logits { + next_token = sampler.sample_next_token(&logits)?; + } + } + + text.push_str(&response_filter.push(&stop_buffer.finish())); + text.push_str(&response_filter.finish()); + text = sanitize_behavior_output(&state.model, &text); + + // Save prompt cache for next request (prompt only, not generated tokens) + save_prompt_cache(state, prompt_tokens, caches); + + Ok(GenerationOutcome { + text, + prompt_tokens: prompt_len, + completion_tokens, + finish_reason, + }) +} + +/// Run inference with per-token callback for streaming. +fn run_inference_streaming( + state: &mut InferState, + prompt: &str, + generation: &GenerationConfig, + tx: &tokio::sync::mpsc::Sender<StreamEvent>, +) -> Result<&'static str> { + if state.model.cacheless_generation() { + return run_inference_streaming_cacheless(state, prompt, generation, tx); + } + let prompt_tokens = encode_prompt_tokens(&state.model, prompt)?; + if prompt_tokens.is_empty() { + anyhow::bail!("prompt encoded to zero tokens β€” check that the prompt is non-empty"); + } + + let (mut caches, suffix) = setup_caches_with_reuse(state, &prompt_tokens); + let mut sampler = Sampler::new(generation.sampling.clone()); + let mut stop_buffer = StopBuffer::new(generation.stop_sequences.clone()); + let mut response_filter = ResponseFilter::new(generation.response_policy.clone()); + + if generation.max_tokens == 0 { + if !suffix.is_empty() { + let _ = prefill_logits(&state.model, suffix, &mut caches)?; + } + save_prompt_cache(state, prompt_tokens, caches); + return Ok("length"); + } + + // Prefill only the new suffix + let mut next_token = if suffix.is_empty() { + let logits = replay_last_prompt_token_logits(&state.model, &prompt_tokens, &mut caches)?; + sampler.sample_next_token(&logits)? + } else { + let logits = prefill_logits(&state.model, suffix, &mut caches)?; + let logits = if state.model.tokenwise_prefill() && state.model.can_replay_prompt_logits() { + replay_last_prompt_token_logits(&state.model, &prompt_tokens, &mut caches)? + } else { + logits + }; + sampler.sample_next_token(&logits)? + }; + + let mut decode_stream = state.model.tokenizer.decode_stream(true); + let mut finish_reason = "length"; + let mut completion_tokens = 0usize; + let mut hidden_reasoning_state = HiddenReasoningState::default(); + let max_sampled_tokens = if generation.hidden_reasoning.is_some() { + generation + .max_tokens + .saturating_mul(32) + .max(generation.max_tokens) + } else { + generation.max_tokens + }; + + // Decode + stream + for step in 0..max_sampled_tokens { + if is_eos(next_token, &state.model.config) || is_stop_token(next_token, generation) { + finish_reason = "stop"; + break; + } + + let pending_logits = + if completion_tokens < generation.max_tokens && step + 1 < max_sampled_tokens { + let input = Array::from_slice(&[next_token], &[1, 1]); + let logits = state.model.forward(&input, &mut caches)?; + mlx_rs::transforms::async_eval([&logits])?; + Some(logits) + } else { + None + }; + + if should_consume_hidden_reasoning_token( + next_token, + generation, + &mut hidden_reasoning_state, + completion_tokens, + ) { + if let Some(logits) = pending_logits { + next_token = sampler.sample_next_token(&logits)?; + } + continue; + } + + let piece = decode_stream + .step(next_token) + .map_err(|e| anyhow::anyhow!("tokenizer decode: {e}"))? + .unwrap_or_default(); + if hidden_reasoning_state.skip_leading_whitespace && piece.trim().is_empty() { + if let Some(logits) = pending_logits { + next_token = sampler.sample_next_token(&logits)?; + } + continue; + } + hidden_reasoning_state.skip_leading_whitespace = false; + completion_tokens += 1; + let chunk = stop_buffer.push(&piece); + let filtered = response_filter.push(&chunk.emit); + if !filtered.is_empty() { + if tx.blocking_send(StreamEvent::Text(filtered)).is_err() { + break; // client disconnected + } + } + if chunk.matched { + finish_reason = "stop"; + break; + } + if completion_tokens >= generation.max_tokens { + break; + } + + if let Some(logits) = pending_logits { + next_token = sampler.sample_next_token(&logits)?; + } + } + + let mut tail = response_filter.push(&stop_buffer.finish()); + tail.push_str(&response_filter.finish()); + if !tail.is_empty() { + let _ = tx.blocking_send(StreamEvent::Text(tail)); + } + + // Save prompt cache for next request + save_prompt_cache(state, prompt_tokens, caches); + + Ok(finish_reason) +} + +fn run_inference_cacheless( + state: &mut InferState, + prompt: &str, + generation: &GenerationConfig, +) -> Result<GenerationOutcome> { + let mut tokens = encode_prompt_tokens(&state.model, prompt)?; + let prompt_len = tokens.len(); + let mut sampler = Sampler::new(generation.sampling.clone()); + let mut stop_buffer = StopBuffer::new(generation.stop_sequences.clone()); + let mut response_filter = ResponseFilter::new(generation.response_policy.clone()); + + if generation.max_tokens == 0 { + if !tokens.is_empty() { + let input = Array::from_slice(&tokens, &[1, tokens.len() as i32]); + let _ = state.model.forward_no_cache(&input)?; + } + return Ok(GenerationOutcome { + text: String::new(), + prompt_tokens: prompt_len, + completion_tokens: 0, + finish_reason: "length", + }); + } + + let mut decode_stream = state.model.tokenizer.decode_stream(true); + let mut text = String::new(); + let mut completion_tokens = 0usize; + let mut finish_reason = "length"; + let mut hidden_reasoning_state = HiddenReasoningState::default(); + let max_sampled_tokens = if generation.hidden_reasoning.is_some() { + generation + .max_tokens + .saturating_mul(32) + .max(generation.max_tokens) + } else { + generation.max_tokens + }; + + for _ in 0..max_sampled_tokens { + let input = Array::from_slice(&tokens, &[1, tokens.len() as i32]); + let logits = state.model.forward_no_cache(&input)?; + let next_token = sampler.sample_next_token(&logits)?; + if is_eos(next_token, &state.model.config) || is_stop_token(next_token, generation) { + finish_reason = "stop"; + break; + } + tokens.push(next_token); + if should_consume_hidden_reasoning_token( + next_token, + generation, + &mut hidden_reasoning_state, + completion_tokens, + ) { + continue; + } + let piece = decode_stream + .step(next_token) + .map_err(|e| anyhow::anyhow!("tokenizer decode: {e}"))? + .unwrap_or_default(); + if hidden_reasoning_state.skip_leading_whitespace && piece.trim().is_empty() { + continue; + } + hidden_reasoning_state.skip_leading_whitespace = false; + completion_tokens += 1; + let chunk = stop_buffer.push(&piece); + text.push_str(&response_filter.push(&chunk.emit)); + if chunk.matched { + finish_reason = "stop"; + break; + } + if completion_tokens >= generation.max_tokens { + break; + } + } + + text.push_str(&response_filter.push(&stop_buffer.finish())); + text.push_str(&response_filter.finish()); + text = sanitize_behavior_output(&state.model, &text); + + Ok(GenerationOutcome { + text, + prompt_tokens: prompt_len, + completion_tokens, + finish_reason, + }) +} + +fn run_inference_streaming_cacheless( + state: &mut InferState, + prompt: &str, + generation: &GenerationConfig, + tx: &tokio::sync::mpsc::Sender<StreamEvent>, +) -> Result<&'static str> { + let mut tokens = encode_prompt_tokens(&state.model, prompt)?; + let mut sampler = Sampler::new(generation.sampling.clone()); + let mut stop_buffer = StopBuffer::new(generation.stop_sequences.clone()); + let mut response_filter = ResponseFilter::new(generation.response_policy.clone()); + + if generation.max_tokens == 0 { + if !tokens.is_empty() { + let input = Array::from_slice(&tokens, &[1, tokens.len() as i32]); + let _ = state.model.forward_no_cache(&input)?; + } + return Ok("length"); + } + + let mut decode_stream = state.model.tokenizer.decode_stream(true); + let mut finish_reason = "length"; + let mut completion_tokens = 0usize; + let mut hidden_reasoning_state = HiddenReasoningState::default(); + let max_sampled_tokens = if generation.hidden_reasoning.is_some() { + generation + .max_tokens + .saturating_mul(32) + .max(generation.max_tokens) + } else { + generation.max_tokens + }; + + for _ in 0..max_sampled_tokens { + let input = Array::from_slice(&tokens, &[1, tokens.len() as i32]); + let logits = state.model.forward_no_cache(&input)?; + let next_token = sampler.sample_next_token(&logits)?; + if is_eos(next_token, &state.model.config) || is_stop_token(next_token, generation) { + finish_reason = "stop"; + break; + } + + tokens.push(next_token); + if should_consume_hidden_reasoning_token( + next_token, + generation, + &mut hidden_reasoning_state, + completion_tokens, + ) { + continue; + } + let piece = decode_stream + .step(next_token) + .map_err(|e| anyhow::anyhow!("tokenizer decode: {e}"))? + .unwrap_or_default(); + if hidden_reasoning_state.skip_leading_whitespace && piece.trim().is_empty() { + continue; + } + hidden_reasoning_state.skip_leading_whitespace = false; + completion_tokens += 1; + let chunk = stop_buffer.push(&piece); + let filtered = response_filter.push(&chunk.emit); + if !filtered.is_empty() && tx.blocking_send(StreamEvent::Text(filtered)).is_err() { + break; + } + if chunk.matched { + finish_reason = "stop"; + break; + } + if completion_tokens >= generation.max_tokens { + break; + } + } + + let mut tail = response_filter.push(&stop_buffer.finish()); + tail.push_str(&response_filter.finish()); + if !tail.is_empty() { + let _ = tx.blocking_send(StreamEvent::Text(tail)); + } + + Ok(finish_reason) +} + +fn is_eos(token: u32, config: &model::ModelConfig) -> bool { + config.eos_token_id.contains(&token) +} + +fn parse_generation_config(req: &serde_json::Value, model: &MlxModel) -> GenerationConfig { + let mut stop_sequences = default_stop_sequences(model); + stop_sequences.extend(parse_stop_sequences(req.get("stop"))); + stop_sequences.sort(); + stop_sequences.dedup(); + let stop_token_ids = stop_token_ids(model, &stop_sequences); + let requested_max_tokens = req["max_tokens"].as_u64().unwrap_or(2048) as usize; + let message_count = req["messages"] + .as_array() + .map(|messages| messages.len()) + .unwrap_or(0); + let max_tokens = if model.prompt_template.behavior().prompt_template.as_deref() == Some("olmo2") + { + if message_count >= 4 { + requested_max_tokens.min(48) + } else { + requested_max_tokens.min(64) + } + } else { + requested_max_tokens + }; + let hidden_reasoning = hidden_reasoning_tokens(model, req); + let stop_token_ids = if let Some(hidden) = hidden_reasoning { + stop_token_ids + .into_iter() + .filter(|id| *id != hidden.start_turn_token_id) + .collect::<Vec<_>>() + } else { + stop_token_ids + }; + GenerationConfig { + max_tokens, + sampling: SamplingParams { + temperature: req["temperature"].as_f64().unwrap_or(0.0) as f32, + top_p: req["top_p"].as_f64().unwrap_or(1.0) as f32, + top_k: req["top_k"] + .as_u64() + .map(|value| value as usize) + .filter(|value| *value > 0), + seed: req["seed"].as_u64(), + suppressed_token_ids: suppressed_reasoning_token_ids(model, req), + }, + stop_sequences, + stop_token_ids, + hidden_reasoning, + response_policy: response_policy(req, model), + } +} + +fn suppressed_reasoning_token_ids(model: &MlxModel, req: &serde_json::Value) -> Vec<u32> { + if !reasoning_disabled( + req, + model.reasoning_family, + &model.prompt_template.reasoning_template(), + ) { + return Vec::new(); + } + + let mut ids = Vec::new(); + for block in model.prompt_template.reasoning_template().tagged_reasoning { + if let Some(token_id) = single_token_id_for_text(model, &block.start) { + ids.push(token_id); + } + } + ids.sort_unstable(); + ids.dedup(); + ids +} + +fn single_token_id_for_text(model: &MlxModel, text: &str) -> Option<u32> { + let encoding = model.tokenizer.encode(text, false).ok()?; + let ids = encoding.get_ids(); + if ids.len() == 1 { + Some(ids[0]) + } else { + None + } +} + +fn stop_token_ids(model: &MlxModel, stop_sequences: &[String]) -> Vec<u32> { + let mut ids = stop_sequences + .iter() + .filter_map(|text| single_token_id_for_text(model, text)) + .collect::<Vec<_>>(); + ids.sort_unstable(); + ids.dedup(); + ids +} + +fn hidden_reasoning_tokens( + model: &MlxModel, + req: &serde_json::Value, +) -> Option<HiddenReasoningTokens> { + if !reasoning_disabled( + req, + model.reasoning_family, + &model.prompt_template.reasoning_template(), + ) { + return None; + } + if model.reasoning_family != model::ReasoningFamily::Qwen3 { + return None; + } + Some(HiddenReasoningTokens { + start_turn_token_id: single_token_id_for_text(model, "<|im_start|>")?, + end_reasoning_token_id: single_token_id_for_text(model, "</think>")?, + }) +} + +fn parse_stop_sequences(stop: Option<&serde_json::Value>) -> Vec<String> { + match stop { + Some(serde_json::Value::String(text)) if !text.is_empty() => vec![text.clone()], + Some(serde_json::Value::Array(items)) => items + .iter() + .filter_map(|value| value.as_str()) + .filter(|value| !value.is_empty()) + .map(ToOwned::to_owned) + .collect(), + _ => Vec::new(), + } +} + +fn default_stop_sequences(model: &MlxModel) -> Vec<String> { + let mut stops = default_stop_sequences_for( + model.prompt_template.behavior().prompt_template.as_deref(), + model.reasoning_family, + ); + stops.extend( + model + .prompt_template + .reasoning_template() + .default_stop_sequences, + ); + stops.sort(); + stops.dedup(); + stops +} + +fn default_stop_sequences_for( + prompt_template: Option<&str>, + reasoning_family: model::ReasoningFamily, +) -> Vec<String> { + let mut stops = Vec::new(); + match prompt_template { + Some("chatml") => { + stops.push("<|im_end|>".to_string()); + stops.push("<|im_start|>".to_string()); + } + Some("olmo2") => { + stops.push("<|user|>".to_string()); + stops.push("<|assistant|>".to_string()); + stops.push("<|system|>".to_string()); + } + Some("llama3") => { + stops.push("<|eot_id|>".to_string()); + } + Some("gemma3") => { + stops.push("<end_of_turn>".to_string()); + } + _ => {} + } + match reasoning_family { + model::ReasoningFamily::Glm => { + stops.push("<|user|>".to_string()); + stops.push("<|assistant|>".to_string()); + stops.push("<|system|>".to_string()); + } + model::ReasoningFamily::Kimi => { + stops.push("<|im_end|>".to_string()); + } + _ => {} + } + stops +} + +fn response_policy(req: &serde_json::Value, model: &MlxModel) -> ResponsePolicy { + response_policy_for( + req, + model.reasoning_family, + &model.prompt_template.reasoning_template(), + ) +} + +fn response_policy_for( + req: &serde_json::Value, + reasoning_family: model::ReasoningFamily, + reasoning_template: &crate::mlx::template::ReasoningTemplate, +) -> ResponsePolicy { + ResponsePolicy { + strip_reasoning_blocks: reasoning_disabled(req, reasoning_family, reasoning_template), + tagged_reasoning: reasoning_template.tagged_reasoning.clone(), + } +} + +fn reasoning_disabled( + req: &serde_json::Value, + family: model::ReasoningFamily, + reasoning_template: &crate::mlx::template::ReasoningTemplate, +) -> bool { + match family { + model::ReasoningFamily::Qwen3 | model::ReasoningFamily::Glm => { + request_bool_kwarg(req, "enable_thinking").unwrap_or(false) == false + } + model::ReasoningFamily::Kimi => { + if let Some(value) = request_bool_kwarg(req, "thinking") { + !value + } else { + request_bool_kwarg(req, "enable_thinking").unwrap_or(false) == false + } + } + model::ReasoningFamily::Lfm2 => { + if let Some(value) = request_bool_kwarg(req, "keep_past_thinking") { + !value + } else { + request_bool_kwarg(req, "enable_thinking").unwrap_or(false) == false + } + } + model::ReasoningFamily::GptOss => { + if let Some(value) = request_bool_kwarg(req, "enable_thinking") { + !value + } else { + matches!( + request_string_kwarg(req, "reasoning_effort").as_deref(), + None | Some("low") + ) + } + } + model::ReasoningFamily::None => { + if !reasoning_template.supports_explicit_reasoning { + false + } else if let Some(value) = request_bool_kwarg(req, "thinking") { + !value + } else if let Some(value) = request_bool_kwarg(req, "keep_past_thinking") { + !value + } else if let Some(value) = request_bool_kwarg(req, "enable_thinking") { + !value + } else { + matches!( + request_string_kwarg(req, "reasoning_effort").as_deref(), + Some("low") + ) + } + } + } +} + +fn request_bool_kwarg(req: &serde_json::Value, key: &str) -> Option<bool> { + req.get(key).and_then(|value| value.as_bool()).or_else(|| { + req.get("chat_template_kwargs") + .and_then(|value| value.get(key)) + .and_then(|value| value.as_bool()) + }) +} + +fn request_string_kwarg(req: &serde_json::Value, key: &str) -> Option<String> { + req.get(key) + .and_then(|value| value.as_str()) + .map(ToOwned::to_owned) + .or_else(|| { + req.get("chat_template_kwargs") + .and_then(|value| value.get(key)) + .and_then(|value| value.as_str()) + .map(ToOwned::to_owned) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use tokenizers::decoders::byte_fallback::ByteFallback; + use tokenizers::models::bpe::BPE; + use tokenizers::normalizers::unicode::NFC; + use tokenizers::normalizers::utils::Sequence; + use tokenizers::normalizers::Strip; + use tokenizers::pre_tokenizers::byte_level::ByteLevel; + use tokenizers::TokenizerBuilder; + + #[test] + fn think_block_filter_strips_tagged_reasoning_across_chunks() { + let mut filter = ResponseFilter::new(ResponsePolicy { + strip_reasoning_blocks: true, + tagged_reasoning: vec![crate::mlx::template::TaggedReasoningBlock { + start: "<think>".to_string(), + end: "</think>".to_string(), + }], + }); + assert_eq!(filter.push("<thi"), ""); + assert_eq!(filter.push("nk>internal"), ""); + assert_eq!(filter.push("</think>blue"), "blue"); + assert_eq!(filter.finish(), ""); + } + + #[test] + fn tagged_block_filter_strips_gemma4_reasoning_markers() { + let mut filter = ResponseFilter::new(ResponsePolicy { + strip_reasoning_blocks: true, + tagged_reasoning: vec![crate::mlx::template::TaggedReasoningBlock { + start: "<|channel>thought".to_string(), + end: "<channel|>".to_string(), + }], + }); + assert_eq!(filter.push("<|channel>thoughthidden"), ""); + assert_eq!(filter.push("<channel|>blue"), "blue"); + } + + #[test] + fn template_derived_stop_sequences_cover_gemma4_turn_end() { + let mut stops = + default_stop_sequences_for(Some("hf_template"), model::ReasoningFamily::None); + stops.extend(vec!["<turn|>".to_string()]); + stops.sort(); + stops.dedup(); + assert!(stops.contains(&"<turn|>".to_string())); + } + + #[test] + fn gemma3_default_stop_sequences_include_end_of_turn() { + let stops = default_stop_sequences_for(Some("gemma3"), model::ReasoningFamily::None); + assert!(stops.contains(&"<end_of_turn>".to_string())); + } + + #[test] + fn qwen3_generation_config_adds_chatml_stops_and_disables_thinking_by_default() { + let stops = default_stop_sequences_for(Some("chatml"), model::ReasoningFamily::Qwen3); + assert!(stops.contains(&"<|im_end|>".to_string())); + assert!(stops.contains(&"<|im_start|>".to_string())); + let policy = response_policy_for( + &serde_json::json!({}), + model::ReasoningFamily::Qwen3, + &crate::mlx::template::ReasoningTemplate::default(), + ); + assert!(policy.strip_reasoning_blocks); + } + + #[test] + fn olmo2_generation_config_adds_role_stops() { + let stops = default_stop_sequences_for(Some("olmo2"), model::ReasoningFamily::None); + assert!(stops.contains(&"<|user|>".to_string())); + assert!(stops.contains(&"<|assistant|>".to_string())); + assert!(stops.contains(&"<|system|>".to_string())); + } + + #[test] + fn qwen3_generation_config_honors_explicit_enable_thinking() { + let policy = response_policy_for( + &serde_json::json!({"enable_thinking": true}), + model::ReasoningFamily::Qwen3, + &crate::mlx::template::ReasoningTemplate::default(), + ); + assert!(!policy.strip_reasoning_blocks); + } + + #[test] + fn kimi_generation_config_maps_default_to_think_filtering() { + let stops = default_stop_sequences_for(Some("hf_template"), model::ReasoningFamily::Kimi); + assert!(stops.contains(&"<|im_end|>".to_string())); + let policy = response_policy_for( + &serde_json::json!({}), + model::ReasoningFamily::Kimi, + &crate::mlx::template::ReasoningTemplate::default(), + ); + assert!(policy.strip_reasoning_blocks); + } + + #[test] + fn glm_generation_config_defaults_to_no_thinking_and_glm_stops() { + let stops = default_stop_sequences_for(Some("hf_template"), model::ReasoningFamily::Glm); + assert!(stops.contains(&"<|user|>".to_string())); + assert!(stops.contains(&"<|assistant|>".to_string())); + let policy = response_policy_for( + &serde_json::json!({}), + model::ReasoningFamily::Glm, + &crate::mlx::template::ReasoningTemplate::default(), + ); + assert!(policy.strip_reasoning_blocks); + } + + #[test] + fn gpt_oss_generation_config_defaults_to_reasoning_suppression() { + let policy = response_policy_for( + &serde_json::json!({}), + model::ReasoningFamily::GptOss, + &crate::mlx::template::ReasoningTemplate::default(), + ); + assert!(policy.strip_reasoning_blocks); + let explicit = response_policy_for( + &serde_json::json!({"reasoning_effort":"medium"}), + model::ReasoningFamily::GptOss, + &crate::mlx::template::ReasoningTemplate::default(), + ); + assert!(!explicit.strip_reasoning_blocks); + } + + #[test] + fn lfm2_generation_config_defaults_to_strip_past_thinking() { + let policy = response_policy_for( + &serde_json::json!({}), + model::ReasoningFamily::Lfm2, + &crate::mlx::template::ReasoningTemplate::default(), + ); + assert!(policy.strip_reasoning_blocks); + let explicit = response_policy_for( + &serde_json::json!({"keep_past_thinking":true}), + model::ReasoningFamily::Lfm2, + &crate::mlx::template::ReasoningTemplate::default(), + ); + assert!(!explicit.strip_reasoning_blocks); + } + + #[test] + fn unknown_reasoning_family_can_still_strip_when_template_supports_reasoning_toggle() { + let reasoning_template = crate::mlx::template::ReasoningTemplate { + supports_explicit_reasoning: true, + tagged_reasoning: vec![crate::mlx::template::TaggedReasoningBlock { + start: "<think>".to_string(), + end: "</think>".to_string(), + }], + default_stop_sequences: Vec::new(), + }; + let policy = response_policy_for( + &serde_json::json!({"enable_thinking": false}), + model::ReasoningFamily::None, + &reasoning_template, + ); + assert!(policy.strip_reasoning_blocks); + assert_eq!(policy.tagged_reasoning.len(), 1); + } + + #[test] + fn prepare_reasoning_request_injects_system_nudge_when_disabled() { + let req = serde_json::json!({ + "messages": [{"role": "user", "content": "Reply with exactly: blue"}] + }); + let patched = prepare_reasoning_request( + &req, + Some("chatml"), + model::ReasoningFamily::Qwen3, + &crate::mlx::template::ReasoningTemplate::default(), + &ResponsePolicy { + strip_reasoning_blocks: true, + tagged_reasoning: Vec::new(), + }, + ); + let messages = patched["messages"].as_array().unwrap(); + assert_eq!(messages[0]["role"], "system"); + assert!(messages[0]["content"] + .as_str() + .unwrap() + .contains("Respond directly with the final answer.")); + } + + #[test] + fn prepare_reasoning_request_leaves_request_unchanged_when_reasoning_allowed() { + let req = serde_json::json!({ + "messages": [{"role": "user", "content": "Reply with exactly: blue"}] + }); + let patched = prepare_reasoning_request( + &req, + Some("chatml"), + model::ReasoningFamily::Qwen3, + &crate::mlx::template::ReasoningTemplate::default(), + &ResponsePolicy { + strip_reasoning_blocks: false, + tagged_reasoning: Vec::new(), + }, + ); + assert_eq!(patched, req); + } + + #[test] + fn prepare_reasoning_request_injects_nudge_for_all_reasoning_families() { + let req = serde_json::json!({ + "messages": [{"role": "user", "content": "Reply with exactly: blue"}] + }); + for family in [ + model::ReasoningFamily::Qwen3, + model::ReasoningFamily::Glm, + model::ReasoningFamily::Kimi, + model::ReasoningFamily::GptOss, + model::ReasoningFamily::Lfm2, + ] { + let patched = prepare_reasoning_request( + &req, + Some("chatml"), + family, + &crate::mlx::template::ReasoningTemplate::default(), + &ResponsePolicy { + strip_reasoning_blocks: true, + tagged_reasoning: Vec::new(), + }, + ); + let messages = patched["messages"].as_array().unwrap(); + assert_eq!(messages[0]["role"], "system"); + assert!(messages[0]["content"] + .as_str() + .unwrap() + .contains("Respond directly with the final answer.")); + } + } + + #[test] + fn prepare_reasoning_request_appends_nudge_to_existing_system_message() { + let req = serde_json::json!({ + "messages": [ + {"role": "system", "content": "You are terse."}, + {"role": "user", "content": "Reply with exactly: blue"} + ] + }); + let patched = prepare_reasoning_request( + &req, + Some("chatml"), + model::ReasoningFamily::Qwen3, + &crate::mlx::template::ReasoningTemplate::default(), + &ResponsePolicy { + strip_reasoning_blocks: true, + tagged_reasoning: Vec::new(), + }, + ); + let messages = patched["messages"].as_array().unwrap(); + let system = messages[0]["content"].as_str().unwrap(); + assert!(system.contains("You are terse.")); + assert!(system.contains("Respond directly with the final answer.")); + } + + #[test] + fn prepare_reasoning_request_injects_olmo2_brevity_nudge_without_reasoning_family() { + let req = serde_json::json!({ + "messages": [{"role": "user", "content": "Reply with exactly: blue"}] + }); + let patched = prepare_reasoning_request( + &req, + Some("olmo2"), + model::ReasoningFamily::None, + &crate::mlx::template::ReasoningTemplate::default(), + &ResponsePolicy { + strip_reasoning_blocks: false, + tagged_reasoning: Vec::new(), + }, + ); + let messages = patched["messages"].as_array().unwrap(); + let system = messages[0]["content"].as_str().unwrap(); + assert!(system.contains("Respond directly with the final answer.")); + assert!(system.contains("For follow-up turns, answer only the new request.")); + } + + #[test] + fn mlx_chat_smoke_renders_llama3_prompt_from_hf_request_shape() { + let template = crate::mlx::template::PromptTemplate::Llama3; + let req = serde_json::json!({ + "model": "meta-llama/Llama-3.2-3B-Instruct", + "messages": [ + {"role": "system", "content": "Be concise."}, + {"role": "user", "content": "Say hi"} + ] + }); + + let prompt = render_chat_prompt_from_request(&template, &req).unwrap(); + + assert!(prompt.starts_with("<|begin_of_text|>")); + assert!( + prompt.contains("<|start_header_id|>system<|end_header_id|>\n\nBe concise.<|eot_id|>") + ); + assert!(prompt.contains("<|start_header_id|>user<|end_header_id|>\n\nSay hi<|eot_id|>")); + assert!(prompt.ends_with("<|start_header_id|>assistant<|end_header_id|>\n\n")); + } + + #[test] + fn mlx_chat_smoke_renders_tools_prompt_from_hf_template() { + let root = + std::env::temp_dir().join(format!("mesh-llm-server-qwen-tools-{}", std::process::id())); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write( + root.join("tokenizer_config.json"), + serde_json::json!({ + "chat_template": "{%- if tools %}{{- '<|im_start|>system\\n# Tools\\n<tools>' }}{%- for tool in tools %}{{- tool | tojson }}{%- endfor %}{{- '</tools><|im_end|>\\n' }}{%- endif %}{%- for message in messages %}{{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}{%- endfor %}{%- if add_generation_prompt %}{{- '<|im_start|>assistant\\n' }}{%- endif %}" + }) + .to_string(), + ) + .unwrap(); + let template = crate::mlx::template::PromptTemplate::detect( + &root, + &serde_json::json!({"model_type":"qwen2"}), + ); + let req = serde_json::json!({ + "model": "Qwen/Qwen2.5-0.5B-Instruct", + "messages": [{"role": "user", "content": "use a tool"}], + "tools": [{ + "type": "function", + "function": { + "name": "run", + "description": "Run a command" + } + }] + }); + + let prompt = render_chat_prompt_from_request(&template, &req).unwrap(); + + assert!(prompt.contains("# Tools")); + assert!(prompt.contains("\"name\":\"run\"")); + assert!(prompt.ends_with("<|im_start|>assistant\n")); + } + + #[test] + fn mlx_chat_smoke_renders_gemma3_prompt_from_hf_request_shape() { + let template = crate::mlx::template::PromptTemplate::Gemma3; + let req = serde_json::json!({ + "model": "mlx-community/gemma-3-4b-it-4bit", + "messages": [ + {"role": "system", "content": "Be concise."}, + {"role": "user", "content": "Say hi"}, + {"role": "assistant", "content": "Hi."}, + {"role": "user", "content": [ + {"type": "text", "text": "look "}, + {"type": "image"}, + {"type": "text", "text": "here"} + ]} + ] + }); + + let prompt = render_chat_prompt_from_request(&template, &req).unwrap(); + + assert!( + prompt.starts_with("<bos><start_of_turn>user\nBe concise.\n\nSay hi<end_of_turn>\n") + ); + assert!(prompt.contains("<start_of_turn>model\nHi.<end_of_turn>\n")); + assert!(prompt.contains("<start_of_turn>user\nlook<start_of_image>here<end_of_turn>\n")); + assert!(prompt.ends_with("<start_of_turn>model\n")); + } + + #[test] + fn decode_stream_handles_split_utf8_tokens() { + let vocab = [ + ("<0x20>".to_string(), 0), + ("<0xC3>".to_string(), 1), + ("<0xA9>".to_string(), 2), + (" This".to_string(), 3), + ]; + let bpe = BPE::builder() + .vocab_and_merges(vocab, vec![]) + .byte_fallback(true) + .build() + .unwrap(); + let tokenizer = TokenizerBuilder::new() + .with_model(bpe) + .with_normalizer(Some(Sequence::new(vec![ + Strip::new(true, true).into(), + NFC.into(), + ]))) + .with_pre_tokenizer(Some(ByteLevel::default())) + .with_post_processor(Some(ByteLevel::default())) + .with_decoder(Some(ByteFallback::default())) + .build() + .unwrap(); + + let mut decode_stream = tokenizer.decode_stream(false); + assert_eq!(decode_stream.step(0).unwrap(), Some(" ".to_string())); + assert_eq!(decode_stream.step(1).unwrap(), None); + assert_eq!(decode_stream.step(2).unwrap(), Some("Γ©".to_string())); + } + + #[test] + #[ignore] + fn olmo_debug_run_inference_local() { + let dir = std::path::Path::new( + "/Users/jdumay/.cache/mesh-llm-debug/olmo-7b-instruct-hf-same-origin/mlx/olmo-7b-instruct-hf-bf16", + ); + assert!( + dir.exists(), + "missing local OLMo artifact at {}", + dir.display() + ); + + let model = MlxModel::load(dir).expect("load local olmo mlx artifact"); + let prompt = + "<|endoftext|><|user|>\nWhat day comes after Monday? Reply with one word.\n<|assistant|>\n"; + let generation = GenerationConfig { + max_tokens: 4, + sampling: SamplingParams { + temperature: 0.0, + top_p: 1.0, + top_k: None, + seed: None, + suppressed_token_ids: Vec::new(), + }, + stop_sequences: Vec::new(), + stop_token_ids: Vec::new(), + hidden_reasoning: None, + response_policy: ResponsePolicy::default(), + }; + + let mut state = InferState { + model, + model_name: "olmo-debug".to_string(), + prompt_cache: None, + }; + + let first = run_inference(&mut state, prompt, &generation).expect("first inference"); + println!("fresh text {:?}", first.text); + + let second = run_inference(&mut state, prompt, &generation).expect("second inference"); + println!("reused text {:?}", second.text); + } + + #[test] + #[ignore] + fn olmo_debug_run_inference_sequence_local() { + let dir = std::path::Path::new( + "/Users/jdumay/.cache/mesh-llm-debug/olmo-7b-instruct-hf-same-origin/mlx/olmo-7b-instruct-hf-bf16", + ); + assert!( + dir.exists(), + "missing local OLMo artifact at {}", + dir.display() + ); + + let model = MlxModel::load(dir).expect("load local olmo mlx artifact"); + let generation = GenerationConfig { + max_tokens: 8, + sampling: SamplingParams { + temperature: 0.0, + top_p: 1.0, + top_k: None, + seed: None, + suppressed_token_ids: Vec::new(), + }, + stop_sequences: Vec::new(), + stop_token_ids: Vec::new(), + hidden_reasoning: None, + response_policy: ResponsePolicy::default(), + }; + let mut state = InferState { + model, + model_name: "olmo-debug".to_string(), + prompt_cache: None, + }; + + let prompts = [ + "Reply with exactly: blue", + "What is the capital of France? Reply with one word.", + "List the RGB primary colors as full lowercase words only, comma-separated, with no abbreviations.", + "Complete exactly: 2 + 2 =", + "Name the largest planet in the Solar System. Reply with one word.", + "What day comes after Monday? Reply with one word.", + ]; + + for prompt in prompts { + let rendered = render_chat_prompt_from_request( + &state.model.prompt_template, + &serde_json::json!({ + "messages": [{"role": "user", "content": prompt}] + }), + ) + .expect("render prompt"); + if prompt.contains("Monday") { + let tokens = encode_prompt_tokens(&state.model, &rendered).expect("encode monday"); + println!("rendered monday {:?}", rendered); + println!("rendered monday tokens {:?}", tokens); + } + let outcome = + run_inference(&mut state, &rendered, &generation).expect("sequence inference"); + println!("prompt {:?} -> {:?}", prompt, outcome.text); + } + + let fresh_model = MlxModel::load(dir).expect("reload local olmo mlx artifact"); + let mut fresh_state = InferState { + model: fresh_model, + model_name: "olmo-debug-fresh".to_string(), + prompt_cache: None, + }; + let monday = render_chat_prompt_from_request( + &fresh_state.model.prompt_template, + &serde_json::json!({ + "messages": [{"role": "user", "content": "What day comes after Monday? Reply with one word."}] + }), + ) + .expect("render monday"); + let monday_tokens = + encode_prompt_tokens(&fresh_state.model, &monday).expect("encode monday"); + println!("fresh rendered monday {:?}", monday); + println!("fresh rendered monday tokens {:?}", monday_tokens); + let fresh_outcome = + run_inference(&mut fresh_state, &monday, &generation).expect("fresh monday"); + println!("fresh monday {:?}", fresh_outcome.text); + } +} + +async fn send_response(stream: &mut tokio::net::TcpStream, status: u16, body: &str) -> Result<()> { + let status_text = match status { + 200 => "OK", + 400 => "Bad Request", + 404 => "Not Found", + 405 => "Method Not Allowed", + 413 => "Payload Too Large", + 500 => "Internal Server Error", + 501 => "Not Implemented", + _ => "Unknown", + }; + let response = format!( + "HTTP/1.1 {status} {status_text}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}", + body.len() + ); + stream.write_all(response.as_bytes()).await?; + let _ = stream.shutdown().await; + Ok(()) +} diff --git a/mesh-llm/src/mlx/template.rs b/mesh-llm/src/mlx/template.rs new file mode 100644 index 00000000..3a2a4c7f --- /dev/null +++ b/mesh-llm/src/mlx/template.rs @@ -0,0 +1,997 @@ +use anyhow::{Context, Result}; +use chrono::Local; +use minijinja::{Environment, ErrorKind, UndefinedBehavior}; +use serde_json::Value; +use std::path::Path; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PromptTemplate { + HuggingFace { + template: String, + special_tokens: SpecialTokens, + source_file: String, + behavior: crate::models::ModelPromptBehavior, + reasoning_defaults: ReasoningDefaults, + reasoning_template: ReasoningTemplate, + fallback: Box<PromptTemplate>, + }, + ChatMl { + default_system_prompt: Option<String>, + }, + Olmo, + Olmo2, + Gemma3, + Llama3, +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub(crate) struct SpecialTokens { + bos_token: Option<String>, + eos_token: Option<String>, + pad_token: Option<String>, + unk_token: Option<String>, +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub(crate) struct ReasoningDefaults { + enable_thinking: Option<bool>, + thinking: Option<bool>, + keep_past_thinking: Option<bool>, + reasoning_effort: Option<String>, +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub(crate) struct ReasoningTemplate { + pub supports_explicit_reasoning: bool, + pub tagged_reasoning: Vec<TaggedReasoningBlock>, + pub default_stop_sequences: Vec<String>, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct TaggedReasoningBlock { + pub start: String, + pub end: String, +} + +impl PromptTemplate { + pub fn detect(dir: &Path, config: &Value) -> Self { + let fallback = heuristic_prompt_template(config); + if let Some((source_file, template)) = read_template_text(dir) { + let template = normalize_hf_template(&template); + let reasoning_template = detect_reasoning_template(&template); + if let Err(err) = validate_hf_template(&template) { + tracing::warn!( + "MLX prompt template: failed to compile HF template from {}: {err}; falling back to {:?}", + source_file, + fallback.behavior().prompt_template + ); + return fallback; + } + let behavior = crate::models::infer_prompt_behavior_for_dir(dir) + .unwrap_or_else(|| fallback.behavior()); + tracing::info!( + "MLX prompt template: loaded HF template from {} (kind={}, source={})", + source_file, + behavior + .prompt_template + .clone() + .unwrap_or_else(|| "unknown".to_string()), + behavior + .template_source + .clone() + .unwrap_or_else(|| "unknown".to_string()), + ); + return PromptTemplate::HuggingFace { + template, + special_tokens: read_special_tokens(dir), + source_file, + behavior, + reasoning_defaults: reasoning_defaults(config), + reasoning_template, + fallback: Box::new(fallback), + }; + } + let behavior = fallback.behavior(); + tracing::info!( + "MLX prompt template: no HF template found in {} (chat_template.jinja={}, chat_template.json={}, tokenizer_config.json={}), using {} fallback", + dir.display(), + dir.join("chat_template.jinja").exists(), + dir.join("chat_template.json").exists(), + dir.join("tokenizer_config.json").exists(), + behavior + .prompt_template + .clone() + .unwrap_or_else(|| "unknown".to_string()), + ); + fallback + } + + pub fn behavior(&self) -> crate::models::ModelPromptBehavior { + match self { + PromptTemplate::HuggingFace { behavior, .. } => behavior.clone(), + PromptTemplate::ChatMl { + default_system_prompt, + } => crate::models::ModelPromptBehavior { + prompt_template: Some("chatml".to_string()), + default_system_prompt: default_system_prompt.clone(), + template_source: Some("fallback".to_string()), + }, + PromptTemplate::Olmo2 => crate::models::ModelPromptBehavior { + prompt_template: Some("olmo2".to_string()), + default_system_prompt: None, + template_source: Some("fallback".to_string()), + }, + PromptTemplate::Olmo => crate::models::ModelPromptBehavior { + prompt_template: Some("olmo".to_string()), + default_system_prompt: None, + template_source: Some("fallback".to_string()), + }, + PromptTemplate::Gemma3 => crate::models::ModelPromptBehavior { + prompt_template: Some("gemma3".to_string()), + default_system_prompt: None, + template_source: Some("fallback".to_string()), + }, + PromptTemplate::Llama3 => crate::models::ModelPromptBehavior { + prompt_template: Some("llama3".to_string()), + default_system_prompt: None, + template_source: Some("fallback".to_string()), + }, + } + } + + pub fn render_request(&self, req: &Value) -> Result<String> { + match self { + PromptTemplate::HuggingFace { + template, + special_tokens, + reasoning_defaults, + source_file, + fallback, + .. + } => match render_hf_template(template, special_tokens, reasoning_defaults, req) { + Ok(prompt) => Ok(prompt), + Err(err) => { + tracing::warn!( + "MLX prompt template: failed to render HF template from {}: {err}; falling back to {:?}", + source_file, + fallback.behavior().prompt_template + ); + fallback.render_request(req) + } + }, + PromptTemplate::ChatMl { + default_system_prompt, + } => { + let messages = req["messages"] + .as_array() + .context("missing messages array")?; + Ok(render_chatml(messages, default_system_prompt.as_deref())) + } + PromptTemplate::Olmo2 => { + let messages = req["messages"] + .as_array() + .context("missing messages array")?; + Ok(render_olmo2(messages)) + } + PromptTemplate::Olmo => { + let messages = req["messages"] + .as_array() + .context("missing messages array")?; + Ok(render_olmo(messages)) + } + PromptTemplate::Gemma3 => { + let messages = req["messages"] + .as_array() + .context("missing messages array")?; + Ok(render_gemma3(messages)) + } + PromptTemplate::Llama3 => { + let messages = req["messages"] + .as_array() + .context("missing messages array")?; + Ok(render_llama3(messages)) + } + } + } + + pub fn reasoning_template(&self) -> ReasoningTemplate { + match self { + PromptTemplate::HuggingFace { + reasoning_template, .. + } => reasoning_template.clone(), + PromptTemplate::ChatMl { .. } + | PromptTemplate::Olmo + | PromptTemplate::Olmo2 + | PromptTemplate::Gemma3 + | PromptTemplate::Llama3 => ReasoningTemplate::default(), + } + } +} + +fn validate_hf_template(template: &str) -> Result<()> { + let mut env = build_hf_environment(); + env.add_template("chat", template) + .context("compile HF chat template")?; + Ok(()) +} + +fn heuristic_prompt_template(config: &Value) -> PromptTemplate { + let model_type = config + .get("model_type") + .and_then(|value| value.as_str()) + .unwrap_or_default() + .to_ascii_lowercase(); + let architectures = config + .get("architectures") + .and_then(|value| value.as_array()) + .into_iter() + .flatten() + .filter_map(|value| value.as_str()) + .map(|value| value.to_ascii_lowercase()) + .collect::<Vec<_>>(); + + if model_type.starts_with("qwen") || architectures.iter().any(|value| value.contains("qwen")) { + return PromptTemplate::ChatMl { + default_system_prompt: Some("You are a helpful assistant.".to_string()), + }; + } + if model_type.starts_with("olmo2") || architectures.iter().any(|value| value.contains("olmo2")) + { + return PromptTemplate::Olmo2; + } + if model_type.starts_with("olmo") || architectures.iter().any(|value| value.contains("olmo")) { + return PromptTemplate::Olmo; + } + if model_type.starts_with("gemma") || architectures.iter().any(|value| value.contains("gemma")) + { + return PromptTemplate::Gemma3; + } + + PromptTemplate::Llama3 +} + +fn render_hf_template( + template: &str, + special_tokens: &SpecialTokens, + reasoning_defaults: &ReasoningDefaults, + req: &Value, +) -> Result<String> { + let mut env = build_hf_environment(); + env.add_template("chat", template) + .context("compile HF chat template")?; + + let tmpl = env.get_template("chat").context("load HF chat template")?; + let messages = normalize_hf_messages( + template, + req.get("messages") + .cloned() + .unwrap_or_else(|| Value::Array(Vec::new())), + ); + let tools = req.get("tools").cloned(); + let custom_tools = req.get("custom_tools").cloned(); + let add_generation_prompt = req + .get("add_generation_prompt") + .and_then(|value| value.as_bool()) + .unwrap_or(true); + let mut ctx = serde_json::Map::new(); + ctx.insert("messages".to_string(), messages); + ctx.insert( + "tools".to_string(), + tools.unwrap_or_else(|| absent_tools_value(template)), + ); + ctx.insert( + "documents".to_string(), + req.get("documents").cloned().unwrap_or(Value::Null), + ); + ctx.insert( + "builtin_tools".to_string(), + req.get("builtin_tools").cloned().unwrap_or(Value::Null), + ); + ctx.insert( + "add_generation_prompt".to_string(), + Value::Bool(add_generation_prompt), + ); + if let Some(custom_tools) = custom_tools { + ctx.insert("custom_tools".to_string(), custom_tools); + } + for (key, value) in [ + ( + "tools_in_user_message", + template_kwarg(req, "tools_in_user_message"), + ), + ( + "keep_past_thinking", + template_kwarg(req, "keep_past_thinking"), + ), + ("date_string", template_kwarg(req, "date_string")), + ("reasoning_effort", template_kwarg(req, "reasoning_effort")), + ("thinking", template_kwarg(req, "thinking")), + ] { + if let Some(value) = value { + ctx.insert(key.to_string(), value); + } + } + match template_kwarg(req, "enable_thinking") { + Some(value) => { + ctx.insert("enable_thinking".to_string(), value); + } + None => { + if let Some(default_enable_thinking) = reasoning_defaults.enable_thinking { + ctx.insert( + "enable_thinking".to_string(), + Value::Bool(default_enable_thinking), + ); + } + } + } + if !ctx.contains_key("thinking") { + if let Some(value) = template_kwarg(req, "enable_thinking") { + ctx.insert("thinking".to_string(), value); + } else if let Some(default_thinking) = reasoning_defaults.thinking { + ctx.insert("thinking".to_string(), Value::Bool(default_thinking)); + } + } + if !ctx.contains_key("keep_past_thinking") { + if let Some(value) = template_kwarg(req, "enable_thinking") { + ctx.insert("keep_past_thinking".to_string(), value); + } else if let Some(default_keep_past_thinking) = reasoning_defaults.keep_past_thinking { + ctx.insert( + "keep_past_thinking".to_string(), + Value::Bool(default_keep_past_thinking), + ); + } + } + if !ctx.contains_key("reasoning_effort") { + if let Some(value) = template_kwarg(req, "enable_thinking") { + if value == Value::Bool(false) { + ctx.insert( + "reasoning_effort".to_string(), + Value::String("low".to_string()), + ); + } else if value == Value::Bool(true) { + ctx.insert( + "reasoning_effort".to_string(), + Value::String("medium".to_string()), + ); + } + } else if let Some(default_reasoning_effort) = &reasoning_defaults.reasoning_effort { + ctx.insert( + "reasoning_effort".to_string(), + Value::String(default_reasoning_effort.clone()), + ); + } + } + if let Some(token) = &special_tokens.bos_token { + ctx.insert("bos_token".to_string(), Value::String(token.clone())); + } + if let Some(token) = &special_tokens.eos_token { + ctx.insert("eos_token".to_string(), Value::String(token.clone())); + } + if let Some(token) = &special_tokens.pad_token { + ctx.insert("pad_token".to_string(), Value::String(token.clone())); + } + if let Some(token) = &special_tokens.unk_token { + ctx.insert("unk_token".to_string(), Value::String(token.clone())); + } + + let rendered = tmpl.render(Value::Object(ctx))?; + + Ok(strip_empty_reasoning_prefill( + rendered, + template, + req, + reasoning_defaults, + )) +} + +fn strip_empty_reasoning_prefill( + rendered: String, + template: &str, + req: &Value, + reasoning_defaults: &ReasoningDefaults, +) -> String { + let thinking_disabled = match template_kwarg(req, "enable_thinking") { + Some(Value::Bool(value)) => !value, + Some(_) => false, + None => reasoning_defaults.enable_thinking == Some(false), + }; + if !thinking_disabled { + return rendered; + } + + let assistant_prefix = "<|im_start|>assistant\n"; + if let Some(prefix_start) = rendered.rfind(assistant_prefix) { + let suffix_start = prefix_start + assistant_prefix.len(); + let suffix = &rendered[suffix_start..]; + if suffix.trim() == "<think>\n\n</think>" { + if is_old_qwen_reasoning_template(template) && request_has_assistant_history(req) { + return rendered; + } + return rendered[..suffix_start].to_string(); + } + } + + rendered +} + +fn request_has_assistant_history(req: &Value) -> bool { + req.get("messages") + .and_then(Value::as_array) + .is_some_and(|messages| { + messages.iter().any(|message| { + message + .get("role") + .and_then(Value::as_str) + .is_some_and(|role| role == "assistant") + }) + }) +} + +fn reasoning_defaults(config: &Value) -> ReasoningDefaults { + let model_type = config + .get("model_type") + .and_then(|value| value.as_str()) + .unwrap_or_default() + .to_ascii_lowercase(); + let architectures = config + .get("architectures") + .and_then(|value| value.as_array()) + .into_iter() + .flatten() + .filter_map(|value| value.as_str()) + .map(|value| value.to_ascii_lowercase()) + .collect::<Vec<_>>(); + + if model_type == "qwen3" || architectures.iter().any(|value| value.contains("qwen3")) { + return ReasoningDefaults { + enable_thinking: Some(false), + ..ReasoningDefaults::default() + }; + } + if model_type.starts_with("glm") || architectures.iter().any(|value| value.contains("glm")) { + return ReasoningDefaults { + enable_thinking: Some(false), + ..ReasoningDefaults::default() + }; + } + if model_type == "kimi" || architectures.iter().any(|value| value.contains("kimi")) { + return ReasoningDefaults { + thinking: Some(false), + ..ReasoningDefaults::default() + }; + } + if model_type == "gpt_oss" || architectures.iter().any(|value| value.contains("gptoss")) { + return ReasoningDefaults { + reasoning_effort: Some("low".to_string()), + ..ReasoningDefaults::default() + }; + } + if model_type == "lfm2" || architectures.iter().any(|value| value.contains("lfm2")) { + return ReasoningDefaults { + keep_past_thinking: Some(false), + ..ReasoningDefaults::default() + }; + } + + ReasoningDefaults::default() +} + +fn detect_reasoning_template(template: &str) -> ReasoningTemplate { + let mut tagged_reasoning = Vec::new(); + + if is_old_qwen_reasoning_template(template) || template_mentions_think_tags(template) { + tagged_reasoning.push(TaggedReasoningBlock { + start: "<think>".to_string(), + end: "</think>".to_string(), + }); + } + + if template.contains("<|channel>thought") && template.contains("<channel|>") { + tagged_reasoning.push(TaggedReasoningBlock { + start: "<|channel>thought".to_string(), + end: "<channel|>".to_string(), + }); + } + + tagged_reasoning + .sort_by(|left, right| left.start.cmp(&right.start).then(left.end.cmp(&right.end))); + tagged_reasoning.dedup(); + + ReasoningTemplate { + supports_explicit_reasoning: template_supports_explicit_reasoning(template) + || !tagged_reasoning.is_empty(), + tagged_reasoning, + default_stop_sequences: detect_default_stop_sequences(template), + } +} + +fn detect_default_stop_sequences(template: &str) -> Vec<String> { + let mut stops = Vec::new(); + + for stop in [ + "<|im_end|>", + "<|im_start|>", + "<|eot_id|>", + "<end_of_turn>", + "<turn|>", + ] { + if template.contains(stop) { + stops.push(stop.to_string()); + } + } + + stops.sort(); + stops.dedup(); + stops +} + +fn template_supports_explicit_reasoning(template: &str) -> bool { + [ + "enable_thinking", + "thinking", + "keep_past_thinking", + "reasoning_effort", + "reasoning_content", + ] + .into_iter() + .any(|needle| template.contains(needle)) +} + +fn template_mentions_think_tags(template: &str) -> bool { + template.contains("<think>") && template.contains("</think>") +} + +fn is_old_qwen_reasoning_template(template: &str) -> bool { + let splits_on_end_think = [ + "split('</think>')", + "split(\"</think>\")", + "| split('</think>')", + "| split(\"</think>\")", + ] + .into_iter() + .any(|needle| template.contains(needle)); + + splits_on_end_think && !template.contains("<SPECIAL_12>") +} + +fn template_kwarg(req: &Value, key: &str) -> Option<Value> { + req.get(key).cloned().or_else(|| { + req.get("chat_template_kwargs") + .and_then(|value| value.get(key)) + .cloned() + }) +} + +fn normalize_hf_messages(template: &str, messages: Value) -> Value { + let Some(messages) = messages.as_array() else { + return messages; + }; + let fill_tool_calls = message_field_mentioned(template, "tool_calls") + && !message_field_membership_checked(template, "tool_calls"); + let fill_tool_call_id = message_field_mentioned(template, "tool_call_id") + && !message_field_membership_checked(template, "tool_call_id"); + let fill_name = message_field_mentioned(template, "name") + && !message_field_membership_checked(template, "name"); + let fill_tool_responses = message_field_mentioned(template, "tool_responses") + && !message_field_membership_checked(template, "tool_responses"); + let trim_assistant_history = is_old_qwen_reasoning_template(template); + + Value::Array( + messages + .iter() + .map(|message| { + let Some(object) = message.as_object() else { + return message.clone(); + }; + let mut normalized = object.clone(); + for (key, should_fill, value) in [ + ("tool_calls", fill_tool_calls, Value::Null), + ("tool_call_id", fill_tool_call_id, Value::Null), + ("name", fill_name, Value::Null), + ("tool_responses", fill_tool_responses, Value::Null), + ] { + if should_fill { + normalized.entry(key.to_string()).or_insert(value); + } + } + if trim_assistant_history + && normalized + .get("role") + .and_then(Value::as_str) + .is_some_and(|role| role == "assistant") + { + if let Some(Value::String(text)) = normalized.get_mut("content") { + *text = text.trim_start().to_string(); + } + } + Value::Object(normalized) + }) + .collect(), + ) +} + +fn message_field_mentioned(template: &str, field: &str) -> bool { + template.contains(&format!("message['{field}']")) + || template.contains(&format!("message[\"{field}\"]")) + || template.contains(&format!("message.{field}")) +} + +fn message_field_membership_checked(template: &str, field: &str) -> bool { + template.contains(&format!("'{field}' in message")) + || template.contains(&format!("\"{field}\" in message")) +} + +fn absent_tools_value(template: &str) -> Value { + let uses_length = template.contains("tools|length") + || template.contains("tools | length") + || template.contains("tools|count") + || template.contains("tools | count"); + if template.contains("tools is not none") && !uses_length { + Value::Null + } else { + Value::Array(Vec::new()) + } +} + +fn build_hf_environment<'a>() -> Environment<'a> { + let mut env = Environment::new(); + env.set_undefined_behavior(UndefinedBehavior::Strict); + env.set_trim_blocks(true); + env.set_lstrip_blocks(true); + env.add_function( + "raise_exception", + |message: String| -> std::result::Result<String, minijinja::Error> { + Err(minijinja::Error::new(ErrorKind::InvalidOperation, message)) + }, + ); + env.add_function( + "strftime_now", + |format: String| -> std::result::Result<String, minijinja::Error> { + Ok(Local::now().format(&format).to_string()) + }, + ); + env.add_filter("startswith", |value: String, prefix: String| { + value.starts_with(&prefix) + }); + env.add_filter("endswith", |value: String, suffix: String| { + value.ends_with(&suffix) + }); + env.add_filter("split", |value: String, separator: String| { + value + .split(&separator) + .map(ToOwned::to_owned) + .collect::<Vec<_>>() + }); + env.add_filter("strip", |value: String, chars: Option<String>| { + strip_chars(&value, chars.as_deref(), true, true) + }); + env.add_filter("lstrip", |value: String, chars: Option<String>| { + strip_chars(&value, chars.as_deref(), true, false) + }); + env.add_filter("rstrip", |value: String, chars: Option<String>| { + strip_chars(&value, chars.as_deref(), false, true) + }); + env +} + +fn normalize_hf_template(template: &str) -> String { + let single_get_re = + regex_lite::Regex::new(r#"\.get\(\s*'([^']+)'\s*(?:,\s*([^)]+?))?\s*\)"#).unwrap(); + let double_get_re = + regex_lite::Regex::new(r#"\.get\(\s*\"([^\"]+)\"\s*(?:,\s*([^)]+?))?\s*\)"#).unwrap(); + let split_index_re = + regex_lite::Regex::new(r#"\.split\(([^()]*)\)\s*\[\s*(-?1|0)\s*\]"#).unwrap(); + let mut normalized = single_get_re + .replace_all(template, |caps: ®ex_lite::Captures<'_>| { + let key = &caps[1]; + let default = caps.get(2).map(|m| m.as_str().trim()).unwrap_or("none"); + format!(r#"["{key}"]|default({default})"#) + }) + .to_string(); + normalized = double_get_re + .replace_all(&normalized, |caps: ®ex_lite::Captures<'_>| { + let key = &caps[1]; + let default = caps.get(2).map(|m| m.as_str().trim()).unwrap_or("none"); + format!(r#"["{key}"]|default({default})"#) + }) + .to_string(); + normalized = split_index_re + .replace_all(&normalized, |caps: ®ex_lite::Captures<'_>| { + let args = caps[1].trim(); + let selector = if &caps[2] == "-1" { "last" } else { "first" }; + format!(" | split({args}) | {selector}") + }) + .to_string(); + for (from, to) in [ + (".lstrip(", " | lstrip("), + (".rstrip(", " | rstrip("), + (".startswith(", " | startswith("), + (".endswith(", " | endswith("), + (".split(", " | split("), + (".strip(", " | strip("), + (".keys()", " | items | map(attribute=0)"), + ("|items", "| items"), + ] { + normalized = normalized.replace(from, to); + } + + strip_tojson_kwargs(&normalized) +} + +fn strip_chars(value: &str, chars: Option<&str>, left: bool, right: bool) -> String { + match chars { + Some(chars) => { + let predicate = |c: char| chars.contains(c); + match (left, right) { + (true, true) => value.trim_matches(predicate).to_string(), + (true, false) => value.trim_start_matches(predicate).to_string(), + (false, true) => value.trim_end_matches(predicate).to_string(), + (false, false) => value.to_string(), + } + } + None => match (left, right) { + (true, true) => value.trim().to_string(), + (true, false) => value.trim_start().to_string(), + (false, true) => value.trim_end().to_string(), + (false, false) => value.to_string(), + }, + } +} + +fn strip_tojson_kwargs(template: &str) -> String { + let mut out = String::with_capacity(template.len()); + let mut cursor = 0usize; + + while let Some(rel) = template[cursor..].find("tojson(") { + let start = cursor + rel; + out.push_str(&template[cursor..start]); + + let args_start = start + "tojson(".len(); + let bytes = template.as_bytes(); + let mut i = args_start; + let mut depth = 1usize; + + while i < bytes.len() && depth > 0 { + match bytes[i] as char { + '(' => depth += 1, + ')' => depth -= 1, + '"' | '\'' => { + let quote = bytes[i]; + i += 1; + while i < bytes.len() { + if bytes[i] == b'\\' { + i += 2; + continue; + } + if bytes[i] == quote { + break; + } + i += 1; + } + } + _ => {} + } + i += 1; + } + + if depth != 0 { + out.push_str(&template[start..]); + return out; + } + + let args = &template[args_start..i - 1]; + if args.contains("separators") || args.contains("ensure_ascii") { + out.push_str("tojson"); + } else { + out.push_str(&template[start..i]); + } + cursor = i; + } + + out.push_str(&template[cursor..]); + out +} + +fn read_template_text(dir: &Path) -> Option<(String, String)> { + crate::models::prompt::find_template_with_source(dir) +} + +fn read_special_tokens(dir: &Path) -> SpecialTokens { + let mut tokens = SpecialTokens::default(); + let path = dir.join("tokenizer_config.json"); + let Ok(text) = std::fs::read_to_string(path) else { + return tokens; + }; + let Ok(value) = serde_json::from_str::<Value>(&text) else { + return tokens; + }; + + tokens.bos_token = extract_token_string(value.get("bos_token")); + tokens.eos_token = extract_token_string(value.get("eos_token")); + tokens.pad_token = extract_token_string(value.get("pad_token")); + tokens.unk_token = extract_token_string(value.get("unk_token")); + tokens +} + +fn extract_token_string(value: Option<&Value>) -> Option<String> { + match value { + Some(Value::String(text)) => Some(text.clone()), + Some(Value::Object(map)) => map + .get("content") + .and_then(|content| content.as_str()) + .map(ToOwned::to_owned), + _ => None, + } +} + +fn render_chatml(messages: &[Value], default_system_prompt: Option<&str>) -> String { + let mut prompt = String::new(); + if let Some(default_system_prompt) = default_system_prompt { + let starts_with_system = messages + .first() + .and_then(|message| message.get("role")) + .and_then(|role| role.as_str()) + == Some("system"); + if !starts_with_system { + prompt.push_str("<|im_start|>system\n"); + prompt.push_str(default_system_prompt); + prompt.push_str("<|im_end|>\n"); + } + } + + for message in messages { + let role = message + .get("role") + .and_then(|role| role.as_str()) + .unwrap_or("user"); + prompt.push_str("<|im_start|>"); + prompt.push_str(role); + prompt.push('\n'); + prompt.push_str(&message_content_text(message)); + prompt.push_str("<|im_end|>\n"); + } + prompt.push_str("<|im_start|>assistant\n"); + prompt +} + +fn render_olmo2(messages: &[Value]) -> String { + let mut prompt = String::from("<|endoftext|>"); + for (index, message) in messages.iter().enumerate() { + let role = match message.get("role").and_then(|value| value.as_str()) { + Some("assistant") => "assistant", + Some("system") => "system", + _ => "user", + }; + prompt.push_str(&format!("<|{role}|>\n")); + prompt.push_str(&message_content_text(message)); + prompt.push('\n'); + if role == "assistant" && index + 1 != messages.len() { + prompt.push_str("<|endoftext|>\n"); + } + } + prompt.push_str("<|assistant|>\n"); + prompt +} + +fn render_olmo(messages: &[Value]) -> String { + let mut prompt = String::from("<|endoftext|>"); + for (index, message) in messages.iter().enumerate() { + let role = message + .get("role") + .and_then(Value::as_str) + .unwrap_or("user"); + let content = message_content_text(message); + match role { + "assistant" => { + prompt.push_str("<|assistant|>\n"); + prompt.push_str(&content); + prompt.push_str("<|endoftext|>"); + } + _ => { + prompt.push_str("<|user|>\n"); + prompt.push_str(&content); + } + } + if index == messages.len() - 1 { + prompt.push_str("<|assistant|>"); + } + } + prompt +} + +fn render_llama3(messages: &[Value]) -> String { + let mut prompt = String::from("<|begin_of_text|>"); + for message in messages { + let role = message + .get("role") + .and_then(|role| role.as_str()) + .unwrap_or("user"); + prompt.push_str("<|start_header_id|>"); + prompt.push_str(role); + prompt.push_str("<|end_header_id|>\n\n"); + prompt.push_str(&message_content_text(message)); + prompt.push_str("<|eot_id|>"); + } + prompt.push_str("<|start_header_id|>assistant<|end_header_id|>\n\n"); + prompt +} + +fn render_gemma3(messages: &[Value]) -> String { + let mut prompt = String::from("<bos>"); + let mut loop_messages = messages; + let mut first_user_prefix = String::new(); + + if let Some(first) = messages.first() { + if first.get("role").and_then(|role| role.as_str()) == Some("system") { + first_user_prefix = message_content_text(first); + if !first_user_prefix.is_empty() { + first_user_prefix.push_str("\n\n"); + } + loop_messages = &messages[1..]; + } + } + + for (index, message) in loop_messages.iter().enumerate() { + let role = match message.get("role").and_then(|role| role.as_str()) { + Some("assistant") => "model", + Some("user") => "user", + Some(other) => other, + None => "user", + }; + prompt.push_str("<start_of_turn>"); + prompt.push_str(role); + prompt.push('\n'); + if index == 0 && !first_user_prefix.is_empty() { + prompt.push_str(&first_user_prefix); + } + prompt.push_str(&gemma_message_content_text(message)); + prompt.push_str("<end_of_turn>\n"); + } + prompt.push_str("<start_of_turn>model\n"); + prompt +} + +fn message_content_text(message: &Value) -> String { + match message.get("content") { + Some(Value::String(text)) => text.clone(), + Some(Value::Array(parts)) => parts + .iter() + .filter_map(|part| match part { + Value::Object(map) => map.get("text").and_then(|value| value.as_str()), + Value::String(text) => Some(text.as_str()), + _ => None, + }) + .collect::<Vec<_>>() + .join(""), + _ => String::new(), + } +} + +fn gemma_message_content_text(message: &Value) -> String { + match message.get("content") { + Some(Value::String(text)) => text.trim().to_string(), + Some(Value::Array(parts)) => { + let mut out = String::new(); + for part in parts { + match part { + Value::Object(map) => match map.get("type").and_then(|value| value.as_str()) { + Some("image") => out.push_str("<start_of_image>"), + Some("text") => { + if let Some(text) = map.get("text").and_then(|value| value.as_str()) { + out.push_str(text.trim()); + } + } + _ => {} + }, + Value::String(text) => out.push_str(text.trim()), + _ => {} + } + } + out + } + _ => String::new(), + } +} + +#[cfg(test)] +mod tests; diff --git a/mesh-llm/src/mlx/template/tests.rs b/mesh-llm/src/mlx/template/tests.rs new file mode 100644 index 00000000..3741cbc4 --- /dev/null +++ b/mesh-llm/src/mlx/template/tests.rs @@ -0,0 +1,1112 @@ +use super::*; +use serde::Deserialize; +use serde_json::json; +use std::path::Path; + +#[derive(Debug, Deserialize)] +struct HfTemplateFixture { + repo: String, + source_file: String, + expect_hf_render: bool, + family: String, + bos_token: Option<String>, + eos_token: Option<String>, + pad_token: Option<String>, + unk_token: Option<String>, + template: String, +} + +fn hf_template_corpus() -> Vec<HfTemplateFixture> { + serde_json::from_str(include_str!("../testdata/hf_template_corpus.json")) + .expect("valid HF template corpus") +} + +fn fixture_config(family: &str) -> Value { + match family { + "llama" => json!({"model_type":"llama","architectures":["LlamaForCausalLM"]}), + "qwen" | "qwen3" | "qwen3_coder_next" | "deepseek_qwen3" => { + json!({"model_type":"qwen2","architectures":["Qwen2ForCausalLM"]}) + } + "qwen3_coder_30b" => json!({"model_type":"qwen2","architectures":["Qwen2ForCausalLM"]}), + "gemma3" => { + json!({"model_type":"gemma3","architectures":["Gemma3ForConditionalGeneration"]}) + } + "mistral" => json!({"model_type":"mistral","architectures":["MistralForCausalLM"]}), + "lfm2" => json!({"model_type":"lfm2","architectures":["LlamaForCausalLM"]}), + "devstral" => json!({"model_type":"mistral","architectures":["MistralForCausalLM"]}), + "glm4v" => json!({"model_type":"glm","architectures":["GlmForCausalLM"]}), + "kimi" => json!({"model_type":"kimi","architectures":["KimiForCausalLM"]}), + "gpt_oss" => json!({"model_type":"gpt_oss","architectures":["GptOssForCausalLM"]}), + other => panic!("unknown fixture family: {other}"), + } +} + +fn fixture_request(family: &str) -> Value { + match family { + "llama" => json!({ + "messages": [ + {"role": "system", "content": "Be concise."}, + {"role": "user", "content": "hello"} + ], + "add_generation_prompt": true + }), + "qwen" => json!({ + "messages": [{"role": "user", "content": "hello"}], + "tools": [{"type": "function", "function": {"name": "run", "description": "Run a command"}}], + "add_generation_prompt": true + }), + "gemma3" => json!({ + "messages": [ + {"role": "system", "content": "Be concise."}, + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + {"role": "user", "content": [ + {"type": "text", "text": "look "}, + {"type": "image"}, + {"type": "text", "text": "here"} + ]} + ], + "add_generation_prompt": true + }), + "mistral" => json!({ + "messages": [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + {"role": "user", "content": "again"} + ], + "add_generation_prompt": true + }), + "lfm2" => json!({ + "messages": [ + {"role": "system", "content": "Be concise."}, + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "<think>\ninternal\n</think>\nhi"}, + {"role": "user", "content": [{"type": "text", "text": "look"}, {"type": "image"}]} + ], + "keep_past_thinking": false, + "add_generation_prompt": true + }), + "deepseek_qwen3" => json!({ + "messages": [ + {"role": "system", "content": "Be concise."}, + {"role": "user", "content": "hello"} + ], + "add_generation_prompt": true + }), + "qwen3" | "qwen3_coder_next" | "qwen3_coder_30b" => json!({ + "messages": [{"role": "user", "content": "hello"}], + "add_generation_prompt": true + }), + "devstral" => json!({ + "messages": [ + {"role": "system", "content": "Be concise."}, + {"role": "user", "content": "hello"} + ], + "add_generation_prompt": true + }), + "glm4v" => json!({ + "messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}], + "tools": [{"type": "function", "function": {"name": "run", "description": "Run a command"}}], + "add_generation_prompt": true + }), + "kimi" => json!({ + "messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}, {"type": "image_url"}]}], + "add_generation_prompt": true + }), + "gpt_oss" => json!({ + "messages": [{"role": "user", "content": "hello"}], + "builtin_tools": ["browser", "python"], + "reasoning_effort": "medium", + "add_generation_prompt": true + }), + other => panic!("unknown fixture family: {other}"), + } +} + +#[test] +fn olmo2_heuristic_fallback_is_selected_for_olmo2_configs() { + let config = json!({ + "model_type": "olmo2", + "architectures": ["Olmo2ForCausalLM"] + }); + assert_eq!(heuristic_prompt_template(&config), PromptTemplate::Olmo2); +} + +#[test] +fn olmo_heuristic_fallback_is_selected_for_olmo_configs() { + let config = json!({ + "model_type": "olmo", + "architectures": ["OlmoForCausalLM"] + }); + assert_eq!(heuristic_prompt_template(&config), PromptTemplate::Olmo); +} + +#[test] +fn render_olmo_matches_origin_role_marker_shape() { + let messages = vec![ + json!({"role": "user", "content": "Say hi"}), + json!({"role": "assistant", "content": "Hi."}), + json!({"role": "user", "content": "Again"}), + ]; + let prompt = render_olmo(&messages); + assert!(prompt.starts_with("<|endoftext|><|user|>\nSay hi")); + assert!(prompt.contains("<|assistant|>\nHi.<|endoftext|>")); + assert!(prompt.ends_with("<|assistant|>")); +} + +#[test] +fn render_olmo2_matches_origin_role_marker_shape() { + let messages = vec![ + json!({"role": "system", "content": "Be concise."}), + json!({"role": "user", "content": "Say hi"}), + json!({"role": "assistant", "content": "Hi."}), + json!({"role": "user", "content": "Again"}), + ]; + let prompt = render_olmo2(&messages); + assert!(prompt.starts_with("<|endoftext|><|system|>\nBe concise.\n")); + assert!(prompt.contains("<|user|>\nSay hi\n")); + assert!(prompt.contains("<|assistant|>\nHi.\n<|endoftext|>\n")); + assert!(prompt.ends_with("<|assistant|>\n")); +} + +fn write_hf_fixture_dir(fixture: &HfTemplateFixture) -> std::path::PathBuf { + let slug = fixture + .repo + .replace('/', "-") + .replace('.', "-") + .replace('_', "-"); + let root = std::env::temp_dir().join(format!( + "mesh-llm-hf-template-corpus-{}-{}", + slug, + std::process::id() + )); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + + match fixture.source_file.as_str() { + "chat_template.jinja" => { + std::fs::write(root.join("chat_template.jinja"), &fixture.template).unwrap(); + } + "chat_template.json" => { + std::fs::write( + root.join("chat_template.json"), + serde_json::to_string(&fixture.template).unwrap(), + ) + .unwrap(); + } + "tokenizer_config.json" => { + std::fs::write( + root.join("tokenizer_config.json"), + serde_json::json!({ + "chat_template": fixture.template, + "bos_token": fixture.bos_token, + "eos_token": fixture.eos_token, + "pad_token": fixture.pad_token, + "unk_token": fixture.unk_token, + }) + .to_string(), + ) + .unwrap(); + return root; + } + other => panic!("unknown template source: {other}"), + } + + std::fs::write( + root.join("tokenizer_config.json"), + serde_json::json!({ + "bos_token": fixture.bos_token, + "eos_token": fixture.eos_token, + "pad_token": fixture.pad_token, + "unk_token": fixture.unk_token, + }) + .to_string(), + ) + .unwrap(); + + root +} + +#[test] +fn normalizes_python_get_calls() { + let template = "{{ msg.get('content') }} {{ msg.get(\"role\", \"user\") }}"; + let normalized = normalize_hf_template(template); + assert_eq!( + normalized, + "{{ msg[\"content\"]|default(none) }} {{ msg[\"role\"]|default(\"user\") }}" + ); +} + +#[test] +fn normalizes_tojson_keyword_arguments() { + let template = "{{ tools | tojson(separators=(',', ':'), ensure_ascii=False) }}"; + let normalized = normalize_hf_template(template); + assert_eq!(normalized, "{{ tools | tojson }}"); +} + +#[test] +fn detects_old_qwen_reasoning_tags_from_split_template() { + let template = "{{ content | split('</think>') | last }}"; + let reasoning = detect_reasoning_template(template); + assert!(reasoning.supports_explicit_reasoning); + assert_eq!( + reasoning.tagged_reasoning, + vec![TaggedReasoningBlock { + start: "<think>".to_string(), + end: "</think>".to_string(), + }] + ); +} + +#[test] +fn normalize_hf_messages_fills_missing_dot_access_fields() { + let template = "{% if message.tool_calls %}x{% endif %}{% if message.tool_call_id %}y{% endif %}{% if message.name %}z{% endif %}"; + let messages = json!([ + {"role": "assistant", "content": "hello"} + ]); + + let normalized = normalize_hf_messages(template, messages); + let message = &normalized.as_array().unwrap()[0]; + + assert!(message.get("tool_calls").is_some()); + assert!(message.get("tool_call_id").is_some()); + assert!(message.get("name").is_some()); +} + +#[test] +fn detects_gemma4_reasoning_channel_markers() { + let template = "{% if add_generation_prompt %}<|channel>thought{{ reasoning_content }}<channel|>{% endif %}"; + let reasoning = detect_reasoning_template(template); + assert!(reasoning.supports_explicit_reasoning); + assert_eq!( + reasoning.tagged_reasoning, + vec![TaggedReasoningBlock { + start: "<|channel>thought".to_string(), + end: "<channel|>".to_string(), + }] + ); + assert!(reasoning.default_stop_sequences.is_empty()); +} + +#[test] +fn detects_gemma4_turn_stop_marker() { + let template = "{% if add_generation_prompt %}<|turn>model\n{% endif %}{% for message in messages %}<turn|>\n{% endfor %}"; + let reasoning = detect_reasoning_template(template); + assert!(reasoning + .default_stop_sequences + .contains(&"<turn|>".to_string())); +} + +#[test] +fn prefers_chat_template_jinja_over_tokenizer_config() { + let root = std::env::temp_dir().join(format!( + "mesh-llm-template-jinja-precedence-{}", + std::process::id() + )); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write(root.join("chat_template.jinja"), "{{ '<jinja-template>' }}").unwrap(); + std::fs::write( + root.join("tokenizer_config.json"), + serde_json::json!({ + "chat_template": "{{ '<json-template>' }}" + }) + .to_string(), + ) + .unwrap(); + + let template = PromptTemplate::detect(&root, &serde_json::json!({"model_type":"qwen2"})); + let prompt = template + .render_request(&json!({ + "messages": [{"role": "user", "content": "hello"}] + })) + .unwrap(); + assert_eq!(prompt, "<jinja-template>"); +} + +#[test] +fn falls_back_when_template_uses_unsupported_python_method() { + let root = std::env::temp_dir().join(format!( + "mesh-llm-template-unsupported-method-{}", + std::process::id() + )); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write( + root.join("chat_template.jinja"), + "{% if messages[0].content.removeprefix('h') %}<bad>{% endif %}", + ) + .unwrap(); + + let template = PromptTemplate::detect(&root, &serde_json::json!({"model_type":"qwen2"})); + let prompt = template + .render_request(&json!({ + "messages": [{"role": "user", "content": "hello"}] + })) + .unwrap(); + assert!(prompt.contains("<|im_start|>system")); + assert!(prompt.contains("hello")); +} + +#[test] +fn kimi_template_compiles_after_normalization() { + let fixture = hf_template_corpus() + .into_iter() + .find(|fixture| fixture.repo == "mlx-community/Kimi-K2.5") + .expect("kimi fixture exists"); + let normalized = normalize_hf_template(&fixture.template); + validate_hf_template(&normalized).expect("normalized Kimi template should compile"); +} + +#[test] +fn detects_chatml_from_tokenizer_config() { + let root = + std::env::temp_dir().join(format!("mesh-llm-template-chatml-{}", std::process::id())); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write( + root.join("tokenizer_config.json"), + serde_json::json!({ + "chat_template": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + }) + .to_string(), + ) + .unwrap(); + + let template = PromptTemplate::detect(&root, &serde_json::json!({"model_type":"qwen2"})); + match template { + PromptTemplate::HuggingFace { + fallback, + reasoning_defaults, + .. + } => { + assert_eq!( + *fallback, + PromptTemplate::ChatMl { + default_system_prompt: Some("You are a helpful assistant.".to_string()) + } + ); + assert_eq!(reasoning_defaults, ReasoningDefaults::default()); + } + other => panic!("expected huggingface template, got {other:?}"), + } +} + +#[test] +fn qwen3_templates_default_enable_thinking_to_false() { + let root = std::env::temp_dir().join(format!( + "mesh-llm-template-qwen3-thinking-{}", + std::process::id() + )); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write( + root.join("tokenizer_config.json"), + serde_json::json!({ + "chat_template": "{%- if add_generation_prompt %}{{- '<|im_start|>assistant\\n' }}{%- if enable_thinking is defined and enable_thinking is false %}{{- '<think>\\n\\n</think>\\n\\n' }}{%- endif %}{%- endif %}" + }) + .to_string(), + ) + .unwrap(); + + let template = PromptTemplate::detect( + &root, + &serde_json::json!({"model_type":"qwen3","architectures":["Qwen3ForCausalLM"]}), + ); + let prompt = template + .render_request(&json!({ + "messages": [{"role": "user", "content": "hello"}] + })) + .unwrap(); + + assert_eq!(prompt, "<|im_start|>assistant\n"); +} + +#[test] +fn qwen3_templates_honor_explicit_enable_thinking_true() { + let root = std::env::temp_dir().join(format!( + "mesh-llm-template-qwen3-thinking-true-{}", + std::process::id() + )); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write( + root.join("tokenizer_config.json"), + serde_json::json!({ + "chat_template": "{%- if add_generation_prompt %}{{- '<|im_start|>assistant\\n' }}{%- if enable_thinking is defined and enable_thinking is false %}{{- '<think>\\n\\n</think>\\n\\n' }}{%- endif %}{%- endif %}" + }) + .to_string(), + ) + .unwrap(); + + let template = PromptTemplate::detect( + &root, + &serde_json::json!({"model_type":"qwen3","architectures":["Qwen3ForCausalLM"]}), + ); + let prompt = template + .render_request(&json!({ + "messages": [{"role": "user", "content": "hello"}], + "enable_thinking": true + })) + .unwrap(); + + assert_eq!(prompt, "<|im_start|>assistant\n"); +} + +#[test] +fn qwen3_templates_strip_empty_reasoning_prefill_when_explicitly_disabled() { + let root = std::env::temp_dir().join(format!( + "mesh-llm-template-qwen3-thinking-false-{}", + std::process::id() + )); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write( + root.join("tokenizer_config.json"), + serde_json::json!({ + "chat_template": "{%- if add_generation_prompt %}{{- '<|im_start|>assistant\\n' }}{%- if enable_thinking is defined and enable_thinking is false %}{{- '<think>\\n\\n</think>\\n\\n' }}{%- endif %}{%- endif %}" + }) + .to_string(), + ) + .unwrap(); + + let template = PromptTemplate::detect( + &root, + &serde_json::json!({"model_type":"qwen3","architectures":["Qwen3ForCausalLM"]}), + ); + let prompt = template + .render_request(&json!({ + "messages": [{"role": "user", "content": "hello"}], + "enable_thinking": false + })) + .unwrap(); + + assert_eq!(prompt, "<|im_start|>assistant\n"); +} + +#[test] +fn qwen3_templates_trim_leading_whitespace_from_assistant_history() { + let root = std::env::temp_dir().join(format!( + "mesh-llm-template-qwen3-assistant-history-{}", + std::process::id() + )); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write( + root.join("tokenizer_config.json"), + serde_json::json!({ + "chat_template": "{%- for message in messages %}{%- if message.content is string %}{%- set content = message.content %}{%- else %}{%- set content = '' %}{%- endif %}{%- if message.role == 'assistant' %}{%- if '</think>' in content %}{%- set content = content.split('</think>')[-1].lstrip('\\n') %}{%- endif %}{{- '<|im_start|>assistant\\n' + content + '<|im_end|>\\n' }}{%- else %}{{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>\\n' }}{%- endif %}{%- endfor %}{%- if add_generation_prompt %}{{- '<|im_start|>assistant\\n' }}{%- if enable_thinking is defined and enable_thinking is false %}{{- '<think>\\n\\n</think>\\n\\n' }}{%- endif %}{%- endif %}" + }) + .to_string(), + ) + .unwrap(); + + let template = PromptTemplate::detect( + &root, + &serde_json::json!({"model_type":"qwen3","architectures":["Qwen3ForCausalLM"]}), + ); + let prompt = template + .render_request(&json!({ + "messages": [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "\n\nworld"} + ], + "enable_thinking": false + })) + .unwrap(); + + assert!(prompt.contains("<|im_start|>assistant\nworld<|im_end|>\n")); + assert!(!prompt.contains("<|im_start|>assistant\n\n\nworld<|im_end|>\n")); +} + +#[test] +fn qwen3_templates_preserve_empty_reasoning_prefill_for_followups() { + let root = std::env::temp_dir().join(format!( + "mesh-llm-template-qwen3-followup-prefill-{}", + std::process::id() + )); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write( + root.join("tokenizer_config.json"), + serde_json::json!({ + "chat_template": "{%- set ns = namespace(last_query_index=messages|length - 1) %}{%- for message in messages[::-1] %}{%- set index = (messages|length - 1) - loop.index0 %}{%- if message.role == 'user' %}{%- set ns.last_query_index = index %}{%- break %}{%- endif %}{%- endfor %}{%- for message in messages %}{%- if message.content is string %}{%- set content = message.content %}{%- else %}{%- set content = '' %}{%- endif %}{%- if message.role == 'assistant' %}{%- set reasoning_content = '' %}{%- if '</think>' in content %}{%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}{%- set content = content.split('</think>')[-1].lstrip('\\n') %}{%- endif %}{%- if loop.index0 > ns.last_query_index %}{{- '<|im_start|>assistant\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') + '<|im_end|>\\n' }}{%- else %}{{- '<|im_start|>assistant\\n' + content + '<|im_end|>\\n' }}{%- endif %}{%- else %}{{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>\\n' }}{%- endif %}{%- endfor %}{%- if add_generation_prompt %}{{- '<|im_start|>assistant\\n' }}{%- if enable_thinking is defined and enable_thinking is false %}{{- '<think>\\n\\n</think>\\n\\n' }}{%- endif %}{%- endif %}" + }) + .to_string(), + ) + .unwrap(); + + let template = PromptTemplate::detect( + &root, + &serde_json::json!({"model_type":"qwen3","architectures":["Qwen3ForCausalLM"]}), + ); + let prompt = template + .render_request(&json!({ + "messages": [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "world"}, + {"role": "user", "content": "follow up"} + ], + "enable_thinking": false + })) + .unwrap(); + + assert!(prompt.ends_with("<|im_start|>assistant\n<think>\n\n</think>\n\n")); +} + +fn corpus_fixture(repo: &str) -> HfTemplateFixture { + hf_template_corpus() + .into_iter() + .find(|fixture| fixture.repo == repo) + .unwrap_or_else(|| panic!("missing fixture for {repo}")) +} + +#[test] +fn corpus_mistral_template_reports_tokenizer_config_source() { + let fixture = corpus_fixture("mlx-community/Mistral-7B-Instruct-v0.2-4bit"); + let root = write_hf_fixture_dir(&fixture); + let template = PromptTemplate::detect(&root, &fixture_config(&fixture.family)); + let behavior = template.behavior(); + + match template { + PromptTemplate::HuggingFace { source_file, .. } => { + assert_eq!(source_file, "tokenizer_config.json"); + } + other => panic!("expected huggingface template, got {other:?}"), + } + assert_eq!(behavior.prompt_template.as_deref(), Some("hf_template")); + assert_eq!(behavior.template_source.as_deref(), Some("huggingface")); +} + +#[test] +fn corpus_qwen3_template_reports_tokenizer_config_source() { + let fixture = corpus_fixture("mlx-community/Qwen3-1.7B-4bit"); + let root = write_hf_fixture_dir(&fixture); + let template = PromptTemplate::detect(&root, &fixture_config(&fixture.family)); + let behavior = template.behavior(); + + match template { + PromptTemplate::HuggingFace { source_file, .. } => { + assert_eq!(source_file, "tokenizer_config.json"); + } + other => panic!("expected huggingface template, got {other:?}"), + } + assert_eq!(behavior.prompt_template.as_deref(), Some("chatml")); + assert_eq!(behavior.template_source.as_deref(), Some("huggingface")); + assert_eq!(behavior.default_system_prompt, None); +} + +#[test] +fn corpus_gemma3_template_reports_declared_source_file() { + let fixture = corpus_fixture("mlx-community/gemma-3-4b-it-qat-4bit"); + let root = write_hf_fixture_dir(&fixture); + let template = PromptTemplate::detect(&root, &fixture_config(&fixture.family)); + let behavior = template.behavior(); + + match template { + PromptTemplate::HuggingFace { source_file, .. } => { + assert_eq!(source_file, "chat_template.json"); + } + other => panic!("expected huggingface template, got {other:?}"), + } + assert_eq!(behavior.prompt_template.as_deref(), Some("gemma3")); + assert_eq!(behavior.template_source.as_deref(), Some("huggingface")); +} + +#[test] +fn glm_templates_default_enable_thinking_to_false() { + let root = std::env::temp_dir().join(format!( + "mesh-llm-template-glm-thinking-{}", + std::process::id() + )); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write( + root.join("chat_template.jinja"), + "{%- if add_generation_prompt %}<|assistant|>{{ '/nothink' if (enable_thinking is defined and not enable_thinking) else '' }}{%- endif %}", + ) + .unwrap(); + + let template = PromptTemplate::detect( + &root, + &serde_json::json!({"model_type":"glm","architectures":["GlmForCausalLM"]}), + ); + let prompt = template + .render_request(&json!({ + "messages": [{"role": "user", "content": "hello"}] + })) + .unwrap(); + + assert_eq!(prompt, "<|assistant|>/nothink"); +} + +#[test] +fn kimi_templates_map_enable_thinking_to_thinking() { + let root = std::env::temp_dir().join(format!( + "mesh-llm-template-kimi-thinking-{}", + std::process::id() + )); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write( + root.join("chat_template.jinja"), + "{%- if add_generation_prompt %}{%- if thinking is defined and thinking is false -%}<think></think>{%- else -%}<think>{%- endif -%}{%- endif %}", + ) + .unwrap(); + + let template = PromptTemplate::detect( + &root, + &serde_json::json!({"model_type":"kimi","architectures":["KimiForCausalLM"]}), + ); + let prompt = template + .render_request(&json!({ + "messages": [{"role": "user", "content": "hello"}] + })) + .unwrap(); + assert_eq!(prompt, "<think></think>"); + + let explicit_prompt = template + .render_request(&json!({ + "messages": [{"role": "user", "content": "hello"}], + "chat_template_kwargs": {"enable_thinking": true} + })) + .unwrap(); + assert_eq!(explicit_prompt, "<think>"); +} + +#[test] +fn gpt_oss_templates_map_enable_thinking_to_reasoning_effort() { + let root = std::env::temp_dir().join(format!( + "mesh-llm-template-gpt-oss-thinking-{}", + std::process::id() + )); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write( + root.join("chat_template.jinja"), + "{{ reasoning_effort | default('missing') }}", + ) + .unwrap(); + + let template = PromptTemplate::detect( + &root, + &serde_json::json!({"model_type":"gpt_oss","architectures":["GptOssForCausalLM"]}), + ); + let prompt = template + .render_request(&json!({ + "messages": [{"role": "user", "content": "hello"}] + })) + .unwrap(); + assert_eq!(prompt, "low"); + + let explicit_prompt = template + .render_request(&json!({ + "messages": [{"role": "user", "content": "hello"}], + "enable_thinking": true + })) + .unwrap(); + assert_eq!(explicit_prompt, "medium"); +} + +#[test] +fn lfm2_templates_map_enable_thinking_to_keep_past_thinking() { + let root = std::env::temp_dir().join(format!( + "mesh-llm-template-lfm2-thinking-{}", + std::process::id() + )); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write( + root.join("chat_template.jinja"), + "{{ 'keep' if keep_past_thinking | default(false) else 'strip' }}", + ) + .unwrap(); + + let template = PromptTemplate::detect( + &root, + &serde_json::json!({"model_type":"lfm2","architectures":["LlamaForCausalLM"]}), + ); + let prompt = template + .render_request(&json!({ + "messages": [{"role": "user", "content": "hello"}] + })) + .unwrap(); + assert_eq!(prompt, "strip"); + + let explicit_prompt = template + .render_request(&json!({ + "messages": [{"role": "user", "content": "hello"}], + "enable_thinking": true + })) + .unwrap(); + assert_eq!(explicit_prompt, "keep"); +} + +#[test] +fn glm_fixture_defaults_to_nothink() { + let fixture = corpus_fixture("lmstudio-community/GLM-4.6V-Flash-MLX-4bit"); + let root = write_hf_fixture_dir(&fixture); + let template = PromptTemplate::detect(&root, &fixture_config(&fixture.family)); + let prompt = template + .render_request(&json!({ + "messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}], + "tools": [{"type": "function", "function": {"name": "run", "description": "Run a command"}}], + "add_generation_prompt": true + })) + .unwrap(); + + assert!(prompt.contains("/nothink")); +} + +#[test] +fn kimi_fixture_defaults_to_no_thinking() { + let fixture = corpus_fixture("mlx-community/Kimi-K2.5"); + let root = write_hf_fixture_dir(&fixture); + let template = PromptTemplate::detect(&root, &fixture_config(&fixture.family)); + let prompt = template + .render_request(&json!({ + "messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}], + "add_generation_prompt": true + })) + .unwrap(); + + assert!(prompt.contains("<think></think>")); + assert!(!prompt.contains("<|im_assistant|>assistant<|im_middle|>\n <think>\n")); +} + +#[test] +fn gpt_oss_fixture_defaults_to_low_reasoning_effort() { + let fixture = corpus_fixture("mlx-community/gpt-oss-20b-MXFP4-Q8"); + let root = write_hf_fixture_dir(&fixture); + let template = PromptTemplate::detect(&root, &fixture_config(&fixture.family)); + let prompt = template + .render_request(&json!({ + "messages": [{"role": "user", "content": "hello"}], + "builtin_tools": ["browser", "python"], + "add_generation_prompt": true + })) + .unwrap(); + + assert!(prompt.contains("Reasoning: low")); +} + +#[test] +fn lfm2_fixture_defaults_to_stripping_past_thinking() { + let fixture = corpus_fixture("lmstudio-community/LFM2-24B-A2B-MLX-4bit"); + let root = write_hf_fixture_dir(&fixture); + let template = PromptTemplate::detect(&root, &fixture_config(&fixture.family)); + let prompt = template + .render_request(&json!({ + "messages": [ + {"role": "system", "content": "Be concise."}, + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "<think>\ninternal\n</think>\nhi"}, + {"role": "user", "content": [{"type": "text", "text": "look"}, {"type": "image"}]}, + {"role": "assistant", "content": "ok"}, + {"role": "user", "content": "again"} + ], + "add_generation_prompt": true + })) + .unwrap(); + + assert!(!prompt.contains("internal")); + assert!(prompt.contains("hi<|im_end|>")); +} + +#[test] +fn renders_llama3_prompt_from_hf_template() { + let root = + std::env::temp_dir().join(format!("mesh-llm-template-llama3-{}", std::process::id())); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write( + root.join("tokenizer_config.json"), + serde_json::json!({ + "bos_token": "<|begin_of_text|>", + "chat_template": "{{- bos_token }}{%- for message in messages %}<|start_header_id|>{{ message['role'] }}<|end_header_id|>\n\n{{ message['content'] }}<|eot_id|>{%- endfor %}{%- if add_generation_prompt %}<|start_header_id|>assistant<|end_header_id|>\n\n{%- endif %}" + }) + .to_string(), + ) + .unwrap(); + + let template = PromptTemplate::detect(&root, &serde_json::json!({"model_type":"llama"})); + let prompt = template + .render_request(&json!({ + "messages": [{"role": "user", "content": "hello"}] + })) + .unwrap(); + assert!(prompt.starts_with("<|begin_of_text|>")); + assert!(prompt.contains("<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>")); + assert!(prompt.contains("<|start_header_id|>assistant<|end_header_id|>")); +} + +#[test] +fn llama_hf_template_does_not_enter_tool_mode_when_tools_are_absent() { + let fixture = corpus_fixture("mlx-community/Llama-3.2-1B-Instruct-4bit"); + let root = write_hf_fixture_dir(&fixture); + let template = PromptTemplate::detect(&root, &fixture_config(&fixture.family)); + let prompt = template + .render_request(&json!({ + "messages": [{"role": "user", "content": "hello"}], + "add_generation_prompt": true + })) + .unwrap(); + + assert!(!prompt.contains("Environment: ipython")); + assert!(!prompt.contains("Given the following functions")); +} + +#[test] +fn renders_qwen_tools_template_with_minijinja() { + let root = std::env::temp_dir().join(format!( + "mesh-llm-template-qwen-tools-{}", + std::process::id() + )); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write( + root.join("tokenizer_config.json"), + serde_json::json!({ + "bos_token": "<s>", + "eos_token": "</s>", + "chat_template": "{%- if tools %}{{- '<|im_start|>system\\n' }}{%- if messages[0]['role'] == 'system' %}{{- messages[0]['content'] }}{%- else %}{{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}{%- endif %}{{- '\\n\\n# Tools\\n\\n<tools>' }}{%- for tool in tools %}{{- '\\n' }}{{- tool | tojson }}{%- endfor %}{{- '\\n</tools><|im_end|>\\n' }}{%- endif %}{%- for message in messages %}{{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}{%- endfor %}{%- if add_generation_prompt %}{{- '<|im_start|>assistant\\n' }}{%- endif %}" + }) + .to_string(), + ) + .unwrap(); + + let template = PromptTemplate::detect(&root, &serde_json::json!({"model_type":"qwen2"})); + let prompt = template + .render_request(&json!({ + "messages": [{"role": "user", "content": "use a tool"}], + "tools": [{"type": "function", "function": {"name": "run", "description": "Run a command"}}] + })) + .unwrap(); + + assert!(prompt.contains("# Tools")); + assert!(prompt.contains("\"name\":\"run\"")); + assert!(prompt.contains("<|im_start|>assistant\n")); +} + +#[test] +fn qwen_prompt_parity_fixture_matches_expected_output() { + let root = std::env::temp_dir().join(format!( + "mesh-llm-template-qwen-fixture-{}", + std::process::id() + )); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write( + root.join("tokenizer_config.json"), + serde_json::json!({ + "chat_template": "{%- if messages[0]['role'] != 'system' -%}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{%- endif -%}{%- for message in messages -%}<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n{%- endfor -%}{%- if add_generation_prompt -%}<|im_start|>assistant\n{%- endif -%}" + }) + .to_string(), + ) + .unwrap(); + + let template = PromptTemplate::detect(&root, &serde_json::json!({"model_type":"qwen2"})); + let prompt = template + .render_request(&json!({ + "messages": [{"role": "user", "content": "hello"}] + })) + .unwrap(); + + assert_eq!( + prompt, + "<|im_start|>system\nYou are a helpful assistant.<|im_end|><|im_start|>user\nhello<|im_end|><|im_start|>assistant" + ); +} + +#[test] +fn llama3_prompt_parity_fixture_matches_expected_output() { + let root = std::env::temp_dir().join(format!( + "mesh-llm-template-llama3-fixture-{}", + std::process::id() + )); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write( + root.join("tokenizer_config.json"), + serde_json::json!({ + "bos_token": "<|begin_of_text|>", + "chat_template": "{{- bos_token }}{%- for message in messages %}<|start_header_id|>{{ message['role'] }}<|end_header_id|>\n\n{{ message['content'] }}<|eot_id|>{%- endfor %}{%- if add_generation_prompt %}<|start_header_id|>assistant<|end_header_id|>\n\n{%- endif %}" + }) + .to_string(), + ) + .unwrap(); + + let template = PromptTemplate::detect(&root, &serde_json::json!({"model_type":"llama"})); + let prompt = template + .render_request(&json!({ + "messages": [ + {"role": "system", "content": "Be concise."}, + {"role": "user", "content": "hello"} + ] + })) + .unwrap(); + + assert_eq!( + prompt, + "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nBe concise.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|><|start_header_id|>assistant<|end_header_id|>" + ); +} + +#[test] +fn gemma3_prompt_parity_fixture_matches_expected_output() { + let root = std::env::temp_dir().join(format!( + "mesh-llm-template-gemma3-fixture-{}", + std::process::id() + )); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write( + root.join("tokenizer_config.json"), + serde_json::json!({ + "bos_token": "<bos>", + "chat_template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\\n\\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\\n\\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- endif -%}\n {{ '<end_of_turn>\\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\\n'}}\n{%- endif -%}\n" + }) + .to_string(), + ) + .unwrap(); + + let template = PromptTemplate::detect(&root, &serde_json::json!({"model_type":"gemma3"})); + let prompt = template + .render_request(&json!({ + "messages": [ + {"role": "system", "content": "Be concise."}, + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + {"role": "user", "content": [ + {"type": "text", "text": "look "}, + {"type": "image"}, + {"type": "text", "text": "here"} + ]} + ] + })) + .unwrap(); + + assert_eq!( + prompt, + "<bos><start_of_turn>user\nBe concise.\n\nhello<end_of_turn>\n<start_of_turn>model\nhi<end_of_turn>\n<start_of_turn>user\nlook<start_of_image>here<end_of_turn>\n<start_of_turn>model\n" + ); +} + +#[test] +fn heuristic_fallback_uses_gemma3_for_gemma_models() { + let template = PromptTemplate::detect( + Path::new("/tmp/does-not-need-to-exist"), + &serde_json::json!({"model_type":"gemma3","architectures":["Gemma3ForConditionalGeneration"]}), + ); + assert_eq!(template, PromptTemplate::Gemma3); +} + +#[test] +fn qwen_fallback_behavior_reports_chatml_defaults() { + let template = PromptTemplate::detect( + Path::new("/tmp/does-not-need-to-exist"), + &serde_json::json!({"model_type":"qwen3","architectures":["Qwen3ForCausalLM"]}), + ); + let behavior = template.behavior(); + + assert_eq!(behavior.prompt_template.as_deref(), Some("chatml")); + assert_eq!( + behavior.default_system_prompt.as_deref(), + Some("You are a helpful assistant.") + ); + assert_eq!(behavior.template_source.as_deref(), Some("fallback")); +} + +#[test] +fn olmo_family_fallback_behavior_reports_expected_template_names() { + let olmo = PromptTemplate::detect( + Path::new("/tmp/does-not-need-to-exist"), + &serde_json::json!({"model_type":"olmo","architectures":["OlmoForCausalLM"]}), + ) + .behavior(); + let olmo2 = PromptTemplate::detect( + Path::new("/tmp/does-not-need-to-exist"), + &serde_json::json!({"model_type":"olmo2","architectures":["Olmo2ForCausalLM"]}), + ) + .behavior(); + + assert_eq!(olmo.prompt_template.as_deref(), Some("olmo")); + assert_eq!(olmo.template_source.as_deref(), Some("fallback")); + assert_eq!(olmo2.prompt_template.as_deref(), Some("olmo2")); + assert_eq!(olmo2.template_source.as_deref(), Some("fallback")); +} + +#[test] +fn renders_when_template_uses_strftime_now() { + let root = + std::env::temp_dir().join(format!("mesh-llm-template-fallback-{}", std::process::id())); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write( + root.join("tokenizer_config.json"), + serde_json::json!({ + "chat_template": "{{ strftime_now('%Y-%m-%d') }}" + }) + .to_string(), + ) + .unwrap(); + + let template = PromptTemplate::detect(&root, &serde_json::json!({"model_type":"qwen2"})); + let prompt = template + .render_request(&json!({ + "messages": [{"role": "user", "content": "hello world"}] + })) + .unwrap(); + + assert_eq!(prompt.len(), 10); + assert_eq!(prompt.chars().filter(|c| *c == '-').count(), 2); +} + +#[test] +fn real_hf_template_corpus_behaves_as_expected() { + for fixture in hf_template_corpus() { + let root = write_hf_fixture_dir(&fixture); + let config = fixture_config(&fixture.family); + let req = fixture_request(&fixture.family); + let special_tokens = read_special_tokens(&root); + let normalized = normalize_hf_template(&fixture.template); + + if fixture.expect_hf_render { + validate_hf_template(&normalized) + .unwrap_or_else(|err| panic!("{} should compile: {err}", fixture.repo)); + let prompt = render_hf_template( + &normalized, + &special_tokens, + &ReasoningDefaults::default(), + &req, + ) + .unwrap_or_else(|err| panic!("{} should render via HF path: {err}", fixture.repo)); + assert!( + !prompt.trim().is_empty(), + "{} rendered an empty prompt", + fixture.repo + ); + } else { + if validate_hf_template(&normalized).is_ok() { + render_hf_template( + &normalized, + &special_tokens, + &ReasoningDefaults::default(), + &req, + ) + .expect_err("fixture should still require fallback"); + } + + let prompt = PromptTemplate::detect(&root, &config) + .render_request(&req) + .unwrap_or_else(|render_err| { + panic!("{} should render via fallback: {render_err}", fixture.repo) + }); + assert!( + !prompt.trim().is_empty(), + "{} fallback rendered an empty prompt", + fixture.repo + ); + } + } +} diff --git a/mesh-llm/src/mlx/testdata/hf_template_corpus.json b/mesh-llm/src/mlx/testdata/hf_template_corpus.json new file mode 100644 index 00000000..3ef90135 --- /dev/null +++ b/mesh-llm/src/mlx/testdata/hf_template_corpus.json @@ -0,0 +1,167 @@ +[ + { + "repo": "mlx-community/Llama-3.2-1B-Instruct-4bit", + "source_file": "tokenizer_config.json", + "expect_hf_render": true, + "family": "llama", + "bos_token": "<|begin_of_text|>", + "eos_token": "<|eot_id|>", + "pad_token": null, + "unk_token": null, + "template": "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n" + }, + { + "repo": "mlx-community/Llama-3.2-3B-Instruct-4bit", + "source_file": "tokenizer_config.json", + "expect_hf_render": true, + "family": "llama", + "bos_token": "<|begin_of_text|>", + "eos_token": "<|eot_id|>", + "pad_token": null, + "unk_token": null, + "template": "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n" + }, + { + "repo": "mlx-community/Qwen2.5-0.5B-Instruct-bf16", + "source_file": "tokenizer_config.json", + "expect_hf_render": true, + "family": "qwen", + "bos_token": null, + "eos_token": "<|im_end|>", + "pad_token": "<|endoftext|>", + "unk_token": null, + "template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n" + }, + { + "repo": "mlx-community/gemma-3-1b-it-qat-4bit", + "source_file": "tokenizer_config.json", + "expect_hf_render": true, + "family": "gemma3", + "bos_token": "<bos>", + "eos_token": "<eos>", + "pad_token": "<pad>", + "unk_token": "<unk>", + "template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n" + }, + { + "repo": "mlx-community/gemma-3-4b-it-qat-4bit", + "source_file": "chat_template.json", + "expect_hf_render": true, + "family": "gemma3", + "bos_token": "<bos>", + "eos_token": "<eos>", + "pad_token": "<pad>", + "unk_token": "<unk>", + "template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n" + }, + { + "repo": "mlx-community/Mistral-7B-Instruct-v0.2-4bit", + "source_file": "tokenizer_config.json", + "expect_hf_render": true, + "family": "mistral", + "bos_token": "<s>", + "eos_token": "</s>", + "pad_token": null, + "unk_token": "<unk>", + "template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" + }, + { + "repo": "lmstudio-community/LFM2-24B-A2B-MLX-4bit", + "source_file": "chat_template.jinja", + "expect_hf_render": true, + "family": "lfm2", + "bos_token": "<|startoftext|>", + "eos_token": "<|im_end|>", + "pad_token": "<|pad|>", + "unk_token": null, + "template": "{{- bos_token -}}\n{%- set keep_past_thinking = keep_past_thinking | default(false) -%}\n{%- set ns = namespace(system_prompt=\"\") -%}\n{%- if messages[0][\"role\"] == \"system\" -%}\n {%- set sys_content = messages[0][\"content\"] -%}\n {%- if sys_content is not string -%}\n {%- for item in sys_content -%}\n {%- if item[\"type\"] == \"text\" -%}\n {%- set ns.system_prompt = ns.system_prompt + item[\"text\"] -%}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {%- set ns.system_prompt = sys_content -%}\n {%- endif -%}\n {%- set messages = messages[1:] -%}\n{%- endif -%}\n{%- if tools -%}\n {%- set ns.system_prompt = ns.system_prompt + (\"\\n\" if ns.system_prompt else \"\") + \"List of tools: [\" -%}\n {%- for tool in tools -%}\n {%- if tool is not string -%}\n {%- set tool = tool | tojson -%}\n {%- endif -%}\n {%- set ns.system_prompt = ns.system_prompt + tool -%}\n {%- if not loop.last -%}\n {%- set ns.system_prompt = ns.system_prompt + \", \" -%}\n {%- endif -%}\n {%- endfor -%}\n {%- set ns.system_prompt = ns.system_prompt + \"]\" -%}\n{%- endif -%}\n{%- if ns.system_prompt -%}\n {{- \"<|im_start|>system\\n\" + ns.system_prompt + \"<|im_end|>\\n\" -}}\n{%- endif -%}\n{%- set ns.last_assistant_index = -1 -%}\n{%- for message in messages -%}\n {%- if message[\"role\"] == \"assistant\" -%}\n {%- set ns.last_assistant_index = loop.index0 -%}\n {%- endif -%}\n{%- endfor -%}\n{%- for message in messages -%}\n {{- \"<|im_start|>\" + message[\"role\"] + \"\\n\" -}}\n {%- set content = message[\"content\"] -%}\n {%- if content is not string -%}\n {%- set ns.content = \"\" -%}\n {%- for item in content -%}\n {%- if item[\"type\"] == \"image\" -%}\n {%- set ns.content = ns.content + \"<image>\" -%}\n {%- elif item[\"type\"] == \"text\" -%}\n {%- set ns.content = ns.content + item[\"text\"] -%}\n {%- else -%}\n {%- set ns.content = ns.content + item | tojson -%}\n {%- endif -%}\n {%- endfor -%}\n {%- set content = ns.content -%}\n {%- endif -%}\n {%- if message[\"role\"] == \"assistant\" and not keep_past_thinking and loop.index0 != ns.last_assistant_index -%}\n {%- if \"</think>\" in content -%}\n {%- set content = content.split(\"</think>\")[-1] | trim -%}\n {%- endif -%}\n {%- endif -%}\n {{- content + \"<|im_end|>\\n\" -}}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{- \"<|im_start|>assistant\\n\" -}}\n{%- endif -%}" + }, + { + "repo": "lmstudio-community/DeepSeek-R1-0528-Qwen3-8B-MLX-4bit", + "source_file": "tokenizer_config.json", + "expect_hf_render": true, + "family": "deepseek_qwen3", + "bos_token": "<|begin▁of▁sentence|>", + "eos_token": "<|end▁of▁sentence|>", + "pad_token": "<|end▁of▁sentence|>", + "unk_token": null, + "template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true, is_last_user=false) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{% set content = message['content'] %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{%- set ns.is_first = false -%}{%- set ns.is_last_user = true -%}{{'<|User|>' + content + '<|Assistant|>'}}{%- endif %}{%- if message['role'] == 'assistant' %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{% endif %}{%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}{%- set ns.is_last_user = false -%}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{%- endif %}{%- set ns.is_first = false %}{%- set ns.is_tool = false -%}{%- set ns.is_output_first = true %}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if content is none %}{{'<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- else %}{{content + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- endfor %}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none)%}{%- set ns.is_last_user = false -%}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + content + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{{content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_last_user = false -%}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + content + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + content + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_last_user and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}" + }, + { + "repo": "mlx-community/Devstral-Small-2-24B-Instruct-2512-4bit", + "source_file": "chat_template.jinja", + "expect_hf_render": true, + "family": "devstral", + "bos_token": "<s>", + "eos_token": "</s>", + "pad_token": "<pad>", + "unk_token": "<unk>", + "template": "{#- Default system message if no system prompt is passed. #}\n{%- set default_system_message = '' %}\n\n{#- Begin of sequence token. #}\n{{- bos_token }}\n\n{#- Handle system prompt if it exists. #}\n{#- System prompt supports text content or text chunks. #}\n{%- if messages[0]['role'] == 'system' %}\n {{- '[SYSTEM_PROMPT]' -}}\n {%- if messages[0]['content'] is string %}\n {{- messages[0]['content'] -}}\n {%- else %} \n {%- for block in messages[0]['content'] %}\n {%- if block['type'] == 'text' %}\n {{- block['text'] }}\n {%- else %}\n {{- raise_exception('Only text chunks are supported in system message contents.') }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '[/SYSTEM_PROMPT]' -}}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n {%- if default_system_message != '' %}\n {{- '[SYSTEM_PROMPT]' + default_system_message + '[/SYSTEM_PROMPT]' }}\n {%- endif %}\n{%- endif %}\n\n\n{#- Tools definition #}\n{%- set tools_definition = '' %}\n{%- set has_tools = false %}\n{%- if tools is defined and tools is not none and tools|length > 0 %}\n {%- set has_tools = true %}\n {%- set tools_definition = '[AVAILABLE_TOOLS]' + (tools| tojson) + '[/AVAILABLE_TOOLS]' %}\n {{- tools_definition }}\n{%- endif %}\n\n{#- Checks for alternating user/assistant messages. #}\n{%- set ns = namespace(index=0) %}\n{%- for message in loop_messages %}\n {%- if message.role == 'user' or (message.role == 'assistant' and (message.tool_calls is not defined or message.tool_calls is none or message.tool_calls | length == 0)) %}\n {%- if (message['role'] == 'user') != (ns.index % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user and assistant roles except for tool calls and results.') }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{#- Handle conversation messages. #}\n{%- for message in loop_messages %}\n\n {#- User messages supports text content or text and image chunks. #}\n {%- if message['role'] == 'user' %}\n {%- if message['content'] is string %}\n {{- '[INST]' + message['content'] + '[/INST]' }}\n {%- elif message['content'] | length > 0 %}\n {{- '[INST]' }}\n {%- if message['content'] | length == 2 %}\n {%- set blocks = message['content'] | sort(attribute='type') %}\n {%- else %}\n {%- set blocks = message['content'] %}\n {%- endif %}\n {%- for block in blocks %}\n {%- if block['type'] == 'text' %}\n {{- block['text'] }}\n {%- elif block['type'] in ['image', 'image_url'] %}\n {{- '[IMG]' }}\n {%- else %}\n {{- raise_exception('Only text, image and image_url chunks are supported in user message content.') }}\n {%- endif %}\n {%- endfor %}\n {{- '[/INST]' }}\n {%- else %}\n {{- raise_exception('User message must have a string or a list of chunks in content') }}\n {%- endif %}\n\n {#- Assistant messages supports text content or text and image chunks. #}\n {%- elif message['role'] == 'assistant' %}\n {%- if (message['content'] is none or message['content'] == '' or message['content']|length == 0) and (message['tool_calls'] is not defined or message['tool_calls'] is none or message['tool_calls']|length == 0) %}\n {{- raise_exception('Assistant message must have a string or a list of chunks in content or a list of tool calls.') }}\n {%- endif %}\n\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- elif message['content'] | length > 0 %}\n {%- for block in message['content'] %}\n {%- if block['type'] == 'text' %}\n {{- block['text'] }}\n {%- else %}\n {{- raise_exception('Only text chunks are supported in assistant message contents.') }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n \n {%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %}\n {%- for tool in message['tool_calls'] %}\n {%- set arguments = tool['function']['arguments'] %}\n {%- if arguments is not string %}\n {%- set arguments = arguments|tojson|safe %}\n {%- elif arguments == '' %}\n {%- set arguments = '{}' %}\n {%- endif %}\n {{- '[TOOL_CALLS]' + tool['function']['name'] + '[ARGS]' + arguments }}\n {%- endfor %}\n {%- endif %}\n\n {#- End of sequence token for each assistant messages. #}\n {{- eos_token }}\n\n {#- Tool messages only supports text content. #}\n {%- elif message['role'] == 'tool' %}\n {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }}\n\n {#- Raise exception for unsupported roles. #}\n {%- else %}\n {{- raise_exception('Only user, assistant and tool roles are supported, got ' + message['role'] + '.') }}\n {%- endif %}\n{%- endfor %}" + }, + { + "repo": "mlx-community/Qwen3-1.7B-4bit", + "source_file": "tokenizer_config.json", + "expect_hf_render": true, + "family": "qwen3", + "bos_token": null, + "eos_token": "<|im_end|>", + "pad_token": "<|endoftext|>", + "unk_token": null, + "template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in message.content %}\n {%- set content = message.content.split('</think>')[-1].lstrip('\\n') %}\n {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}" + }, + { + "repo": "lmstudio-community/Qwen3-Coder-Next-MLX-4bit", + "source_file": "chat_template.jinja", + "expect_hf_render": true, + "family": "qwen3_coder_next", + "bos_token": null, + "eos_token": "<|im_end|>", + "pad_token": "<|endoftext|>", + "unk_token": null, + "template": "{% macro render_extra_keys(json_dict, handled_keys) %}\n {%- if json_dict is mapping %}\n {%- for json_key in json_dict if json_key not in handled_keys %}\n {%- if json_dict[json_key] is string %}\n {{-'\\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }}\n {%- else %}\n {{- '\\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson) ~ '</' ~ json_key ~ '>' }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n{%- endmacro %}\n\n{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{%- if not tools is defined %}\n {%- set tools = [] %}\n{%- endif %}\n\n{%- if system_message is defined %}\n {{- \"<|im_start|>system\\n\" + system_message }}\n{%- else %}\n {%- if tools is iterable and tools | length > 0 %}\n {{- \"<|im_start|>system\\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks.\" }}\n {%- endif %}\n{%- endif %}\n{%- if tools is iterable and tools | length > 0 %}\n {{- \"\\n\\n# Tools\\n\\nYou have access to the following functions:\\n\\n\" }}\n {{- \"<tools>\" }}\n {%- for tool in tools %}\n {%- if tool.function is defined %}\n {%- set tool = tool.function %}\n {%- endif %}\n {{- \"\\n<function>\\n<name>\" ~ tool.name ~ \"</name>\" }}\n {%- if tool.description is defined %}\n {{- '\\n<description>' ~ (tool.description | trim) ~ '</description>' }}\n {%- endif %}\n {{- '\\n<parameters>' }}\n {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %}\n {%- for param_name, param_fields in tool.parameters.properties|items %}\n {{- '\\n<parameter>' }}\n {{- '\\n<name>' ~ param_name ~ '</name>' }}\n {%- if param_fields.type is defined %}\n {{- '\\n<type>' ~ (param_fields.type | string) ~ '</type>' }}\n {%- endif %}\n {%- if param_fields.description is defined %}\n {{- '\\n<description>' ~ (param_fields.description | trim) ~ '</description>' }}\n {%- endif %}\n {%- set handled_keys = ['name', 'type', 'description'] %}\n {{- render_extra_keys(param_fields, handled_keys) }}\n {{- '\\n</parameter>' }}\n {%- endfor %}\n {%- endif %}\n {%- set handled_keys = ['type', 'properties'] %}\n {{- render_extra_keys(tool.parameters, handled_keys) }}\n {{- '\\n</parameters>' }}\n {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %}\n {{- render_extra_keys(tool, handled_keys) }}\n {{- '\\n</function>' }}\n {%- endfor %}\n {{- \"\\n</tools>\" }}\n {{- '\\n\\nIf you choose to call a function ONLY reply in the following format with NO suffix:\\n\\n<tool_call>\\n<function=example_function_name>\\n<parameter=example_parameter_1>\\nvalue_1\\n</parameter>\\n<parameter=example_parameter_2>\\nThis is the value for the second parameter\\nthat can span\\nmultiple lines\\n</parameter>\\n</function>\\n</tool_call>\\n\\n<IMPORTANT>\\nReminder:\\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\\n- Required parameters MUST be specified\\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\\n</IMPORTANT>' }}\n{%- endif %}\n{%- if system_message is defined %}\n {{- '<|im_end|>\\n' }}\n{%- else %}\n {%- if tools is iterable and tools | length > 0 %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in loop_messages %}\n {%- if message.role == \"assistant\" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content is defined and message.content is string and message.content | trim | length > 0 %}\n {{- '\\n' + message.content | trim + '\\n' }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n<function=' + tool_call.name + '>\\n' }}\n {%- if tool_call.arguments is defined %}\n {%- for args_name, args_value in tool_call.arguments|items %}\n {{- '<parameter=' + args_name + '>\\n' }}\n {%- set args_value = args_value if args_value is string else args_value | tojson %}\n {{- args_value }}\n {{- '\\n</parameter>\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '</function>\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"user\" or message.role == \"system\" or message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.previtem and loop.previtem.role != \"tool\" %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if not loop.last and loop.nextitem.role != \"tool\" %}\n {{- '<|im_end|>\\n' }}\n {%- elif loop.last %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n" + }, + { + "repo": "lmstudio-community/Qwen3-Coder-30B-A3B-Instruct-MLX-4bit", + "source_file": "chat_template.jinja", + "expect_hf_render": true, + "family": "qwen3_coder_30b", + "bos_token": null, + "eos_token": "<|im_end|>", + "pad_token": "<|endoftext|>", + "unk_token": null, + "template": "{% macro render_item_list(item_list, tag_name='required') %}\n {%- if item_list is defined and item_list is iterable and item_list | length > 0 %}\n {%- if tag_name %}{{- '\\n<' ~ tag_name ~ '>' -}}{% endif %}\n {{- '[' }}\n {%- for item in item_list -%}\n {%- if loop.index > 1 %}{{- \", \"}}{% endif -%}\n {%- if item is string -%}\n {{ \"`\" ~ item ~ \"`\" }}\n {%- else -%}\n {{ item }}\n {%- endif -%}\n {%- endfor -%}\n {{- ']' }}\n {%- if tag_name %}{{- '</' ~ tag_name ~ '>' -}}{% endif %}\n {%- endif %}\n{% endmacro %}\n\n{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{%- if not tools is defined %}\n {%- set tools = [] %}\n{%- endif %}\n\n{%- if system_message is defined %}\n {{- \"<|im_start|>system\\n\" + system_message }}\n{%- else %}\n {%- if tools is iterable and tools | length > 0 %}\n {{- \"<|im_start|>system\\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks.\" }}\n {%- endif %}\n{%- endif %}\n{%- if tools is iterable and tools | length > 0 %}\n {{- \"\\n\\nYou have access to the following functions:\\n\\n\" }}\n {{- \"<tools>\" }}\n {%- for tool in tools %}\n {%- if tool.function is defined %}\n {%- set tool = tool.function %}\n {%- endif %}\n {{- \"\\n<function>\\n<name>\" ~ tool.name ~ \"</name>\" }}\n {{- '\\n<description>' ~ (tool.description | trim) ~ '</description>' }}\n {{- '\\n<parameters>' }}\n {%- for param_name, param_fields in tool.parameters.properties|items %}\n {{- '\\n<parameter>' }}\n {{- '\\n<name>' ~ param_name ~ '</name>' }}\n {%- if param_fields.type is defined %}\n {{- '\\n<type>' ~ (param_fields.type | string) ~ '</type>' }}\n {%- endif %}\n {%- if param_fields.description is defined %}\n {{- '\\n<description>' ~ (param_fields.description | trim) ~ '</description>' }}\n {%- endif %}\n {{- render_item_list(param_fields.enum, 'enum') }}\n {%- set handled_keys = ['type', 'description', 'enum', 'required'] %}\n {%- for json_key in param_fields.keys() | reject(\"in\", handled_keys) %}\n {%- set normed_json_key = json_key | replace(\"-\", \"_\") | replace(\" \", \"_\") | replace(\"$\", \"\") %}\n {%- if param_fields[json_key] is mapping %}\n {{- '\\n<' ~ normed_json_key ~ '>' ~ (param_fields[json_key] | tojson | safe) ~ '</' ~ normed_json_key ~ '>' }}\n {%- else %}\n {{-'\\n<' ~ normed_json_key ~ '>' ~ (param_fields[json_key] | string) ~ '</' ~ normed_json_key ~ '>' }}\n {%- endif %}\n {%- endfor %}\n {{- render_item_list(param_fields.required, 'required') }}\n {{- '\\n</parameter>' }}\n {%- endfor %}\n {{- render_item_list(tool.parameters.required, 'required') }}\n {{- '\\n</parameters>' }}\n {%- if tool.return is defined %}\n {%- if tool.return is mapping %}\n {{- '\\n<return>' ~ (tool.return | tojson | safe) ~ '</return>' }}\n {%- else %}\n {{- '\\n<return>' ~ (tool.return | string) ~ '</return>' }}\n {%- endif %}\n {%- endif %}\n {{- '\\n</function>' }}\n {%- endfor %}\n {{- \"\\n</tools>\" }}\n {{- '\\n\\nIf you choose to call a function ONLY reply in the following format with NO suffix:\\n\\n<tool_call>\\n<function=example_function_name>\\n<parameter=example_parameter_1>\\nvalue_1\\n</parameter>\\n<parameter=example_parameter_2>\\nThis is the value for the second parameter\\nthat can span\\nmultiple lines\\n</parameter>\\n</function>\\n</tool_call>\\n\\n<IMPORTANT>\\nReminder:\\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\\n- Required parameters MUST be specified\\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\\n</IMPORTANT>' }}\n{%- endif %}\n{%- if system_message is defined %}\n {{- '<|im_end|>\\n' }}\n{%- else %}\n {%- if tools is iterable and tools | length > 0 %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in loop_messages %}\n {%- if message.role == \"assistant\" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content is defined and message.content is string and message.content | trim | length > 0 %}\n {{- '\\n' + message.content | trim + '\\n' }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n<function=' + tool_call.name + '>\\n' }}\n {%- if tool_call.arguments is defined %}\n {%- for args_name, args_value in tool_call.arguments|items %}\n {{- '<parameter=' + args_name + '>\\n' }}\n {%- set args_value = args_value if args_value is string else args_value | string %}\n {{- args_value }}\n {{- '\\n</parameter>\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '</function>\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"user\" or message.role == \"system\" or message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.previtem and loop.previtem.role != \"tool\" %}\n {{- '<|im_start|>user\\n' }}\n {%- endif %}\n {{- '<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>\\n' }}\n {%- if not loop.last and loop.nextitem.role != \"tool\" %}\n {{- '<|im_end|>\\n' }}\n {%- elif loop.last %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n" + }, + { + "repo": "lmstudio-community/GLM-4.6V-Flash-MLX-4bit", + "source_file": "chat_template.jinja", + "expect_hf_render": true, + "family": "glm4v", + "bos_token": null, + "eos_token": "<|endoftext|>", + "pad_token": "<|endoftext|>", + "unk_token": null, + "template": "[gMASK]<sop>\n{%- if tools -%}\n<|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{% for tool in tools %}\n{{ tool | tojson(ensure_ascii=False) }}\n{% endfor %}\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}\n<arg_key>{arg-key-1}</arg_key>\n<arg_value>{arg-value-1}</arg_value>\n<arg_key>{arg-key-2}</arg_key>\n<arg_value>{arg-value-2}</arg_value>\n...\n</tool_call>{%- endif -%}\n{%- macro visible_text(content) -%}\n {%- if content is string -%}\n {{- content }}\n {%- elif content is iterable and content is not mapping -%}\n {%- for item in content -%}\n {%- if item is mapping and item.type == 'text' -%}\n {{- item.text }}\n {%- elif item is mapping and (item.type == 'image' or 'image' in item) -%}\n <|begin_of_image|><|image|><|end_of_image|>\n {%- elif item is mapping and (item.type == 'video' or 'video' in item) -%}\n <|begin_of_video|><|video|><|end_of_video|>\n {%- elif item is string -%}\n {{- item }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{- content }}\n {%- endif -%}\n{%- endmacro -%}\n{%- set ns = namespace(last_user_index=-1) %}\n{%- for m in messages %}\n {%- if m.role == 'user' %}\n {% set ns.last_user_index = loop.index0 -%}\n {%- endif %}\n{%- endfor %}\n{% for m in messages %}\n{%- if m.role == 'user' -%}<|user|>\n{% if m.content is string %}\n{{ m.content }}\n{%- else %}\n{%- for item in m.content %}\n{% if item.type == 'video' or 'video' in item %}\n<|begin_of_video|><|video|><|end_of_video|>{% elif item.type == 'image' or 'image' in item %}\n<|begin_of_image|><|image|><|end_of_image|>{% elif item.type == 'text' %}\n{{ item.text }}\n{%- endif %}\n{%- endfor %}\n{%- endif %}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith(\"/nothink\")) else '' -}}\n{%- elif m.role == 'assistant' -%}\n<|assistant|>\n{%- set reasoning_content = '' %}\n{%- set content = visible_text(m.content) %}\n{%- if m.reasoning_content is string %}\n {%- set reasoning_content = m.reasoning_content %}\n{%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- endif %}\n{%- endif %}\n{%- if loop.index0 > ns.last_user_index and reasoning_content -%}\n{{ '\\n<think>' + reasoning_content.strip() + '</think>'}}\n{%- else -%}\n{{ '\\n<think></think>' }}\n{%- endif -%}\n{%- if content.strip() -%}\n{{ '\\n' + content.strip() }}\n{%- endif -%}\n{% if m.tool_calls %}\n{% for tc in m.tool_calls %}\n{%- if tc.function %}\n {%- set tc = tc.function %}\n{%- endif %}\n{{ '\\n<tool_call>' + tc.name }}\n{% set _args = tc.arguments %}\n{% for k, v in _args.items() %}\n<arg_key>{{ k }}</arg_key>\n<arg_value>{{ v | tojson(ensure_ascii=False) if v is not string else v }}</arg_value>\n{% endfor %}\n</tool_call>{% endfor %}\n{% endif %}\n{%- elif m.role == 'tool' -%}\n{%- if m.content is string -%}\n{%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|observation|>' }}\n{%- endif %}\n{{- '\\n<tool_response>\\n' }}\n{{- m.content }}\n{{- '\\n</tool_response>' }}\n{% elif m.content is iterable and m.content is not mapping %}\n{%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n{{- '<|observation|>' }}\n{%- endif %}\n{{- '\\n<tool_response>\\n' }}\n{%- for tr in m.content -%}\n {%- if tr is mapping and tr.type is defined -%}\n {%- set t = tr.type | lower -%}\n {%- if t == 'text' and tr.text is defined -%}\n{{ tr.text }}\n {%- elif t in ['image', 'image_url'] -%}\n<|begin_of_image|><|image|><|end_of_image|>\n {%- elif t in ['video', 'video_url'] -%}\n<|begin_of_video|><|video|><|end_of_video|>\n {%- else -%}\n{{ tr | tojson(ensure_ascii=False) }}\n {%- endif -%}\n {%- else -%}\n{{ tr.output if tr.output is defined else tr }}\n {%- endif -%}\n{%- endfor -%}\n{{- '\\n</tool_response>' }}\n{%- else -%}\n<|observation|>{% for tr in m.content %}\n\n<tool_response>\n{{ tr.output if tr.output is defined else tr }}\n</tool_response>{% endfor -%}\n{% endif -%}\n{%- elif m.role == 'system' -%}\n<|system|>\n{{ visible_text(m.content) }}\n{%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n<|assistant|>\n{{'<think></think>\\n' if (enable_thinking is defined and not enable_thinking) else ''}}\n{%- endif -%}" + }, + { + "repo": "mlx-community/Kimi-K2.5", + "source_file": "chat_template.jinja", + "expect_hf_render": true, + "family": "kimi", + "bos_token": "[BOS]", + "eos_token": "[EOS]", + "pad_token": "[PAD]", + "unk_token": "[UNK]", + "template": "{%- macro render_content(msg) -%}\n {%- set c = msg.get('content') -%}\n {%- if c is string -%}\n {{ c }}\n {%- elif c is not none -%}\n {% for content in c -%}\n {% if content['type'] == 'image' or content['type'] == 'image_url' -%}\n <|media_start|>image<|media_content|><|media_pad|><|media_end|>\n {% elif content['type'] == 'video' or content['type']== 'video_url'-%}\n <|kimi_k25_video_placeholder|>\n {% else -%}\n {{ content['text'] }}\n {%- endif -%}\n {%- endfor -%}\n {%- endif -%}\n{%- endmacro -%}\n\n{% macro set_roles(message) -%}\n {%- set role_name = message.get('name') or message['role'] -%}\n {%- if message['role'] == 'user' -%}\n <|im_user|>{{role_name}}<|im_middle|>\n {%- elif message['role'] == 'assistant' -%}\n <|im_assistant|>{{role_name}}<|im_middle|>\n {%- else -%}\n <|im_system|>{{role_name}}<|im_middle|>\n {%- endif -%}\n{%- endmacro -%}\n\n\n{%- macro render_toolcalls(message) -%}\n <|tool_calls_section_begin|>\n {%- for tool_call in message['tool_calls'] -%}\n {%- set formatted_id = tool_call['id'] -%}\n <|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{% if tool_call['function']['arguments'] is string %}{{ tool_call['function']['arguments'] }}{% else %}{{ tool_call['function']['arguments'] | tojson }}{% endif %}<|tool_call_end|>\n {%- endfor -%}\n <|tool_calls_section_end|>\n{%- endmacro -%}\n\n\n{# Find last non-tool-call assisitant message #}\n{%- set ns = namespace(last_non_tool_call_assistant_msg=-1) -%}\n{%- for idx in range(messages|length-1, -1, -1) -%}\n {%- if messages[idx]['role'] == 'assistant' and not messages[idx].get('tool_calls') -%}\n {%- set ns.last_non_tool_call_assistant_msg = idx -%}\n {%- break -%}\n {%- endif -%}\n{%- endfor -%}\n\n{# split all messages into history & suffix, reasoning_content in suffix should be reserved.#}\n{%- set hist_msgs = messages[:ns.last_non_tool_call_assistant_msg+1] -%}\n{%- set suffix_msgs = messages[ns.last_non_tool_call_assistant_msg+1:] -%}\n\n{%- if tools -%}\n {%- if tools_ts_str -%}\n <|im_system|>tool_declare<|im_middle|>{{ tools_ts_str }}<|im_end|>\n {%- else -%}\n <|im_system|>tool_declare<|im_middle|>{{ tools | tojson(separators=(',', ':')) }}<|im_end|>\n {%- endif -%}\n{%- endif -%}\n\n{%- if messages|length == 0 or messages[0]['role'] != 'system' -%}\n <|im_system|>system<|im_middle|>You are Kimi, an AI assistant created by Moonshot AI.<|im_end|>\n{%- endif -%}\n \n{%- for message in hist_msgs -%}\n {{set_roles(message)}}\n {%- if message['role'] == 'assistant' -%}\n <think></think>{{render_content(message)}}\n {%- if message.get('tool_calls') -%}\n {{render_toolcalls(message)}}\n {%- endif -%}\n {%- elif message['role'] == 'tool' -%}\n {%- set tool_call_id = message.tool_call_id -%}\n ## Return of {{ tool_call_id }}\n{{render_content(message)}}\n {%- elif message['content'] is not none -%}\n {{render_content(message)}}\n {%- endif -%}\n <|im_end|>\n{%- endfor -%}\n\n{%- for message in suffix_msgs -%}\n {{set_roles(message)}}\n {%- if message['role'] == 'assistant' -%}\n {%- if thinking is defined and thinking is false -%}\n <think></think>{{render_content(message)}}\n {%- else -%}\n {%- set rc = message.get('reasoning_content', '') -%}\n <think>{{rc}}</think>{{render_content(message)}}\n {%- endif -%}\n {%- if message.get('tool_calls') -%}\n {{render_toolcalls(message)}}\n {%- endif -%}\n {%- elif message['role'] == 'tool' -%}\n {%- set tool_call_id = message.tool_call_id -%}\n ## Return of {{ tool_call_id }}\n{{render_content(message)}}\n {%- elif message['content'] is not none -%}\n {{render_content(message)}}\n {%- endif -%}\n <|im_end|>\n{%- endfor -%}\n\n\n{%- if add_generation_prompt -%}\n <|im_assistant|>assistant<|im_middle|>\n {%- if thinking is defined and thinking is false -%}\n <think></think>\n {%- else -%}\n <think>\n {%- endif -%}\n{%- endif -%}" + }, + { + "repo": "mlx-community/gpt-oss-20b-MXFP4-Q8", + "source_file": "chat_template.jinja", + "expect_hf_render": true, + "family": "gpt_oss", + "bos_token": "<|startoftext|>", + "eos_token": "<|return|>", + "pad_token": "<|endoftext|>", + "unk_token": null, + "template": "{#-\n In addition to the normal inputs of `messages` and `tools`, this template also accepts the\n following kwargs:\n - \"builtin_tools\": A list, can contain \"browser\" and/or \"python\".\n - \"model_identity\": A string that optionally describes the model identity.\n - \"reasoning_effort\": A string that describes the reasoning effort, defaults to \"medium\".\n #}\n\n{#- Tool Definition Rendering ============================================== #}\n{%- macro render_typescript_type(param_spec, required_params, is_nullable=false) -%}\n {%- if param_spec.type == \"array\" -%}\n {%- if param_spec['items'] -%}\n {%- if param_spec['items']['type'] == \"string\" -%}\n {{- \"string[]\" }}\n {%- elif param_spec['items']['type'] == \"number\" -%}\n {{- \"number[]\" }}\n {%- elif param_spec['items']['type'] == \"integer\" -%}\n {{- \"number[]\" }}\n {%- elif param_spec['items']['type'] == \"boolean\" -%}\n {{- \"boolean[]\" }}\n {%- else -%}\n {%- set inner_type = render_typescript_type(param_spec['items'], required_params) -%}\n {%- if inner_type == \"object | object\" or inner_type|length > 50 -%}\n {{- \"any[]\" }}\n {%- else -%}\n {{- inner_type + \"[]\" }}\n {%- endif -%}\n {%- endif -%}\n {%- if param_spec.nullable -%}\n {{- \" | null\" }}\n {%- endif -%}\n {%- else -%}\n {{- \"any[]\" }}\n {%- if param_spec.nullable -%}\n {{- \" | null\" }}\n {%- endif -%}\n {%- endif -%}\n {%- elif param_spec.type is defined and param_spec.type is iterable and param_spec.type is not string and param_spec.type is not mapping and param_spec.type[0] is defined -%}\n {#- Handle array of types like [\"object\", \"object\"] from Union[dict, list] #}\n {%- if param_spec.type | length > 1 -%}\n {{- param_spec.type | join(\" | \") }}\n {%- else -%}\n {{- param_spec.type[0] }}\n {%- endif -%}\n {%- elif param_spec.oneOf -%}\n {#- Handle oneOf schemas - check for complex unions and fallback to any #}\n {%- set has_object_variants = false -%}\n {%- for variant in param_spec.oneOf -%}\n {%- if variant.type == \"object\" -%}\n {%- set has_object_variants = true -%}\n {%- endif -%}\n {%- endfor -%}\n {%- if has_object_variants and param_spec.oneOf|length > 1 -%}\n {{- \"any\" }}\n {%- else -%}\n {%- for variant in param_spec.oneOf -%}\n {{- render_typescript_type(variant, required_params) -}}\n {%- if variant.description %}\n {{- \"// \" + variant.description }}\n {%- endif -%}\n {%- if variant.default is defined %}\n {{ \"// default: \" + variant.default|tojson }}\n {%- endif -%}\n {%- if not loop.last %}\n {{- \" | \" }}\n {% endif -%}\n {%- endfor -%}\n {%- endif -%}\n {%- elif param_spec.type == \"string\" -%}\n {%- if param_spec.enum -%}\n {{- '\"' + param_spec.enum|join('\" | \"') + '\"' -}}\n {%- else -%}\n {{- \"string\" }}\n {%- if param_spec.nullable %}\n {{- \" | null\" }}\n {%- endif -%}\n {%- endif -%}\n {%- elif param_spec.type == \"number\" -%}\n {{- \"number\" }}\n {%- elif param_spec.type == \"integer\" -%}\n {{- \"number\" }}\n {%- elif param_spec.type == \"boolean\" -%}\n {{- \"boolean\" }}\n\n {%- elif param_spec.type == \"object\" -%}\n {%- if param_spec.properties -%}\n {{- \"{\\n\" }}\n {%- for prop_name, prop_spec in param_spec.properties.items() -%}\n {{- prop_name -}}\n {%- if prop_name not in (param_spec.required or []) -%}\n {{- \"?\" }}\n {%- endif -%}\n {{- \": \" }}\n {{ render_typescript_type(prop_spec, param_spec.required or []) }}\n {%- if not loop.last -%}\n {{-\", \" }}\n {%- endif -%}\n {%- endfor -%}\n {{- \"}\" }}\n {%- else -%}\n {{- \"object\" }}\n {%- endif -%}\n {%- else -%}\n {{- \"any\" }}\n {%- endif -%}\n{%- endmacro -%}\n\n{%- macro render_tool_namespace(namespace_name, tools) -%}\n {{- \"## \" + namespace_name + \"\\n\\n\" }}\n {{- \"namespace \" + namespace_name + \" {\\n\\n\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- \"// \" + tool.description + \"\\n\" }}\n {{- \"type \"+ tool.name + \" = \" }}\n {%- if tool.parameters and tool.parameters.properties %}\n {{- \"(_: {\\n\" }}\n {%- for param_name, param_spec in tool.parameters.properties.items() %}\n {%- if param_spec.description %}\n {{- \"// \" + param_spec.description + \"\\n\" }}\n {%- endif %}\n {{- param_name }}\n {%- if param_name not in (tool.parameters.required or []) -%}\n {{- \"?\" }}\n {%- endif -%}\n {{- \": \" }}\n {{- render_typescript_type(param_spec, tool.parameters.required or []) }}\n {%- if param_spec.default is defined -%}\n {%- if param_spec.enum %}\n {{- \", // default: \" + param_spec.default }}\n {%- elif param_spec.oneOf %}\n {{- \"// default: \" + param_spec.default }}\n {%- else %}\n {{- \", // default: \" + param_spec.default|tojson }}\n {%- endif -%}\n {%- endif -%}\n {%- if not loop.last %}\n {{- \",\\n\" }}\n {%- else %}\n {{- \",\\n\" }}\n {%- endif -%}\n {%- endfor %}\n {{- \"}) => any;\\n\\n\" }}\n {%- else -%}\n {{- \"() => any;\\n\\n\" }}\n {%- endif -%}\n {%- endfor %}\n {{- \"} // namespace \" + namespace_name }}\n{%- endmacro -%}\n\n{%- macro render_builtin_tools(browser_tool, python_tool) -%}\n {%- if browser_tool %}\n {{- \"## browser\\n\\n\" }}\n {{- \"// Tool for browsing.\\n\" }}\n {{- \"// The `cursor` appears in brackets before each browsing display: `[{cursor}]`.\\n\" }}\n {{- \"// Cite information from the tool using the following format:\\n\" }}\n {{- \"// `【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`.\\n\" }}\n {{- \"// Do not quote more than 10 words directly from the tool output.\\n\" }}\n {{- \"// sources=web (default: web)\\n\" }}\n {{- \"namespace browser {\\n\\n\" }}\n {{- \"// Searches for information related to `query` and displays `topn` results.\\n\" }}\n {{- \"type search = (_: {\\n\" }}\n {{- \"query: string,\\n\" }}\n {{- \"topn?: number, // default: 10\\n\" }}\n {{- \"source?: string,\\n\" }}\n {{- \"}) => any;\\n\\n\" }}\n {{- \"// Opens the link `id` from the page indicated by `cursor` starting at line number `loc`, showing `num_lines` lines.\\n\" }}\n {{- \"// Valid link ids are displayed with the formatting: `【{id}†.*】`.\\n\" }}\n {{- \"// If `cursor` is not provided, the most recent page is implied.\\n\" }}\n {{- \"// If `id` is a string, it is treated as a fully qualified URL associated with `source`.\\n\" }}\n {{- \"// If `loc` is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available.\\n\" }}\n {{- \"// Use this function without `id` to scroll to a new location of an opened page.\\n\" }}\n {{- \"type open = (_: {\\n\" }}\n {{- \"id?: number | string, // default: -1\\n\" }}\n {{- \"cursor?: number, // default: -1\\n\" }}\n {{- \"loc?: number, // default: -1\\n\" }}\n {{- \"num_lines?: number, // default: -1\\n\" }}\n {{- \"view_source?: boolean, // default: false\\n\" }}\n {{- \"source?: string,\\n\" }}\n {{- \"}) => any;\\n\\n\" }}\n {{- \"// Finds exact matches of `pattern` in the current page, or the page given by `cursor`.\\n\" }}\n {{- \"type find = (_: {\\n\" }}\n {{- \"pattern: string,\\n\" }}\n {{- \"cursor?: number, // default: -1\\n\" }}\n {{- \"}) => any;\\n\\n\" }}\n {{- \"} // namespace browser\\n\\n\" }}\n {%- endif -%}\n\n {%- if python_tool %}\n {{- \"## python\\n\\n\" }}\n {{- \"Use this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).\\n\\n\" }}\n {{- \"When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 120.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is UNKNOWN. Depends on the cluster.\\n\\n\" }}\n {%- endif -%}\n{%- endmacro -%}\n\n{#- System Message Construction ============================================ #}\n{%- macro build_system_message() -%}\n {%- if model_identity is not defined %}\n {%- set model_identity = \"You are ChatGPT, a large language model trained by OpenAI.\" %}\n {%- endif %}\n {{- model_identity + \"\\n\" }}\n {{- \"Knowledge cutoff: 2024-06\\n\" }}\n {{- \"Current date: \" + strftime_now(\"%Y-%m-%d\") + \"\\n\\n\" }}\n {%- if reasoning_effort is not defined %}\n {%- set reasoning_effort = \"medium\" %}\n {%- endif %}\n {{- \"Reasoning: \" + reasoning_effort + \"\\n\\n\" }}\n {%- if builtin_tools %}\n {{- \"# Tools\\n\\n\" }}\n {%- set available_builtin_tools = namespace(browser=false, python=false) %}\n {%- for tool in builtin_tools %}\n {%- if tool == \"browser\" %}\n {%- set available_builtin_tools.browser = true %}\n {%- elif tool == \"python\" %}\n {%- set available_builtin_tools.python = true %}\n {%- endif %}\n {%- endfor %}\n {{- render_builtin_tools(available_builtin_tools.browser, available_builtin_tools.python) }}\n {%- endif -%}\n {{- \"# Valid channels: analysis, commentary, final. Channel must be included for every message.\" }}\n {%- if tools -%}\n {{- \"\\nCalls to these tools must go to the commentary channel: 'functions'.\" }}\n {%- endif -%}\n{%- endmacro -%}\n\n{#- Main Template Logic ================================================= #}\n{#- Set defaults #}\n\n{#- Render system message #}\n{{- \"<|start|>system<|message|>\" }}\n{{- build_system_message() }}\n{{- \"<|end|>\" }}\n\n{#- Extract developer message #}\n{%- if messages[0].role == \"developer\" or messages[0].role == \"system\" %}\n {%- set developer_message = messages[0].content %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set developer_message = \"\" %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{#- Render developer message #}\n{%- if developer_message or tools %}\n {{- \"<|start|>developer<|message|>\" }}\n {%- if developer_message %}\n {{- \"# Instructions\\n\\n\" }}\n {{- developer_message }}\n {{- \"\\n\\n\" }}\n {%- endif %}\n {%- if tools -%}\n {{- \"# Tools\\n\\n\" }}\n {{- render_tool_namespace(\"functions\", tools) }}\n {%- endif -%}\n {{- \"<|end|>\" }}\n{%- endif %}\n\n{#- Render messages #}\n{%- set last_tool_call = namespace(name=none) %}\n{%- for message in loop_messages -%}\n {#- At this point only assistant/user/tool messages should remain #}\n {%- if message.role == 'assistant' -%}\n {#- Checks to ensure the messages are being passed in the format we expect #}\n {%- if \"content\" in message %}\n {%- if \"<|channel|>analysis<|message|>\" in message.content or \"<|channel|>final<|message|>\" in message.content %}\n {{- raise_exception(\"You have passed a message containing <|channel|> tags in the content field. Instead of doing this, you should pass analysis messages (the string between '<|message|>' and '<|end|>') in the 'thinking' field, and final messages (the string between '<|message|>' and '<|end|>') in the 'content' field.\") }}\n {%- endif %}\n {%- endif %}\n {%- if \"thinking\" in message %}\n {%- if \"<|channel|>analysis<|message|>\" in message.thinking or \"<|channel|>final<|message|>\" in message.thinking %}\n {{- raise_exception(\"You have passed a message containing <|channel|> tags in the thinking field. Instead of doing this, you should pass analysis messages (the string between '<|message|>' and '<|end|>') in the 'thinking' field, and final messages (the string between '<|message|>' and '<|end|>') in the 'content' field.\") }}\n {%- endif %}\n {%- endif %}\n {%- if \"tool_calls\" in message %}\n {#- We need very careful handling here - we want to drop the tool call analysis message if the model #}\n {#- has output a later <|final|> message, but otherwise we want to retain it. This is the only case #}\n {#- when we render CoT/analysis messages in inference. #}\n {%- set future_final_message = namespace(found=false) %}\n {%- for future_message in loop_messages[loop.index:] %}\n {%- if future_message.role == 'assistant' and \"tool_calls\" not in future_message %}\n {%- set future_final_message.found = true %}\n {%- endif %}\n {%- endfor %}\n {#- We assume max 1 tool call per message, and so we infer the tool call name #}\n {#- in \"tool\" messages from the most recent assistant tool call name #}\n {%- set tool_call = message.tool_calls[0] %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {%- if message.content and message.thinking %}\n {{- raise_exception(\"Cannot pass both content and thinking in an assistant message with tool calls! Put the analysis message in one or the other, but not both.\") }}\n {%- elif message.content and not future_final_message.found %}\n {{- \"<|start|>assistant<|channel|>analysis<|message|>\" + message.content + \"<|end|>\" }}\n {%- elif message.thinking and not future_final_message.found %}\n {{- \"<|start|>assistant<|channel|>analysis<|message|>\" + message.thinking + \"<|end|>\" }}\n {%- endif %}\n {{- \"<|start|>assistant to=\" }}\n {{- \"functions.\" + tool_call.name + \"<|channel|>commentary \" }}\n {{- (tool_call.content_type if tool_call.content_type is defined else \"json\") + \"<|message|>\" }}\n {{- tool_call.arguments|tojson }}\n {{- \"<|call|>\" }}\n {%- set last_tool_call.name = tool_call.name %}\n {%- elif loop.last and not add_generation_prompt %}\n {#- Only render the CoT if the final turn is an assistant turn and add_generation_prompt is false #}\n {#- This is a situation that should only occur in training, never in inference. #}\n {%- if \"thinking\" in message %}\n {{- \"<|start|>assistant<|channel|>analysis<|message|>\" + message.thinking + \"<|end|>\" }}\n {%- endif %}\n {#- <|return|> indicates the end of generation, but <|end|> does not #}\n {#- <|return|> should never be an input to the model, but we include it as the final token #}\n {#- when training, so the model learns to emit it. #}\n {{- \"<|start|>assistant<|channel|>final<|message|>\" + message.content + \"<|return|>\" }}\n {%- else %}\n {#- CoT is dropped during all previous turns, so we never render it for inference #}\n {{- \"<|start|>assistant<|channel|>final<|message|>\" + message.content + \"<|end|>\" }}\n {%- set last_tool_call.name = none %}\n {%- endif %}\n {%- elif message.role == 'tool' -%}\n {%- if last_tool_call.name is none %}\n {{- raise_exception(\"Message has tool role, but there was no previous assistant message with a tool call!\") }}\n {%- endif %}\n {{- \"<|start|>functions.\" + last_tool_call.name }}\n {{- \" to=assistant<|channel|>commentary<|message|>\" + message.content|tojson + \"<|end|>\" }}\n {%- elif message.role == 'user' -%}\n {{- \"<|start|>user<|message|>\" + message.content + \"<|end|>\" }}\n {%- endif -%}\n{%- endfor -%}\n\n{#- Generation prompt #}\n{%- if add_generation_prompt -%}\n<|start|>assistant\n{%- endif -%}" + } +] diff --git a/mesh-llm/src/models/capabilities.rs b/mesh-llm/src/models/capabilities.rs index ef33a7d9..d17b4967 100644 --- a/mesh-llm/src/models/capabilities.rs +++ b/mesh-llm/src/models/capabilities.rs @@ -278,6 +278,7 @@ where saw_processor = true; } if name.ends_with("tokenizer_config.json") + || name.ends_with("chat_template.jinja") || name.ends_with("chat_template.json") || name.contains("reasoning") || name.contains("thinking") @@ -632,7 +633,12 @@ fn json_contains_tool_use_tokens(value: &Value) -> bool { fn read_local_metadata_jsons(path: &Path) -> Vec<Value> { let mut values = Vec::new(); for dir in path.ancestors().skip(1).take(6) { - for name in ["config.json", "tokenizer_config.json", "chat_template.json"] { + for name in [ + "config.json", + "tokenizer_config.json", + "chat_template.json", + "chat_template.jinja", + ] { let candidate = dir.join(name); if !candidate.is_file() { continue; @@ -640,7 +646,9 @@ fn read_local_metadata_jsons(path: &Path) -> Vec<Value> { let Ok(text) = std::fs::read_to_string(&candidate) else { continue; }; - if let Ok(value) = serde_json::from_str(&text) { + if name.ends_with(".jinja") { + values.push(Value::String(text)); + } else if let Ok(value) = serde_json::from_str(&text) { values.push(value); } } @@ -650,7 +658,12 @@ fn read_local_metadata_jsons(path: &Path) -> Vec<Value> { async fn fetch_remote_metadata_jsons(repo: &str, revision: Option<&str>) -> Vec<Value> { let mut values = Vec::new(); - for filename in ["config.json", "tokenizer_config.json", "chat_template.json"] { + for filename in [ + "config.json", + "tokenizer_config.json", + "chat_template.json", + "chat_template.jinja", + ] { if let Some(value) = fetch_remote_json(repo, revision, filename).await { values.push(value); } @@ -668,7 +681,11 @@ async fn fetch_remote_json(repo: &str, revision: Option<&str>, file: &str) -> Op }; let path = api.repo(repo).get(file).await.ok()?; let text = tokio::fs::read_to_string(path).await.ok()?; - serde_json::from_str(&text).ok() + if file.ends_with(".jinja") { + Some(Value::String(text)) + } else { + serde_json::from_str(&text).ok() + } } #[cfg(test)] diff --git a/mesh-llm/src/models/catalog.json b/mesh-llm/src/models/catalog.json index b82cf894..025104dc 100644 --- a/mesh-llm/src/models/catalog.json +++ b/mesh-llm/src/models/catalog.json @@ -472,6 +472,17 @@ "extra_files": [], "mmproj": null }, + { + "name": "Qwen3-0.6B-MLX", + "file": "model.safetensors", + "url": "https://huggingface.co/mlx-community/Qwen3-0.6B-4bit/resolve/main/model.safetensors", + "size": "335MB", + "description": "Tiny MLX starter, very fast local Apple Silicon chat", + "draft": null, + "moe": null, + "extra_files": [], + "mmproj": null + }, { "name": "Llama-3.2-1B-Instruct-Q4_K_M", "file": "Llama-3.2-1B-Instruct-Q4_K_M.gguf", @@ -483,6 +494,72 @@ "extra_files": [], "mmproj": null }, + { + "name": "Llama-3.2-1B-Instruct-MLX", + "file": "model.safetensors", + "url": "https://huggingface.co/mlx-community/Llama-3.2-1B-Instruct-4bit/resolve/main/model.safetensors", + "size": "695MB", + "description": "Tiny Meta instruct model in native MLX format", + "draft": null, + "moe": null, + "extra_files": [], + "mmproj": null + }, + { + "name": "Qwen2.5-3B-Instruct-MLX", + "file": "model.safetensors", + "url": "https://huggingface.co/mlx-community/Qwen2.5-3B-Instruct-4bit/resolve/main/model.safetensors", + "size": "1.7GB", + "description": "Small & fast general chat in native MLX format", + "draft": null, + "moe": null, + "extra_files": [], + "mmproj": null + }, + { + "name": "Llama-3.2-3B-Instruct-MLX", + "file": "model.safetensors", + "url": "https://huggingface.co/mlx-community/Llama-3.2-3B-Instruct-4bit/resolve/main/model.safetensors", + "size": "1.8GB", + "description": "Meta Llama 3.2 3B in native MLX format", + "draft": null, + "moe": null, + "extra_files": [], + "mmproj": null + }, + { + "name": "Qwen3-4B-MLX", + "file": "model.safetensors", + "url": "https://huggingface.co/mlx-community/Qwen3-4B-4bit/resolve/main/model.safetensors", + "size": "2.3GB", + "description": "Qwen3 starter in native MLX format", + "draft": null, + "moe": null, + "extra_files": [], + "mmproj": null + }, + { + "name": "Qwen2.5-Coder-14B-Instruct-MLX", + "file": "model.safetensors.index.json", + "url": "https://huggingface.co/lmstudio-community/Qwen2.5-Coder-14B-Instruct-MLX-4bit/resolve/main/model.safetensors.index.json", + "size": "8.3GB", + "description": "Strong code generation in native MLX format", + "draft": null, + "moe": null, + "extra_files": [], + "mmproj": null + }, + { + "name": "Qwen3-30B-A3B-MLX", + "file": "model.safetensors.index.json", + "url": "https://huggingface.co/mlx-community/Qwen3-30B-A3B-4bit/resolve/main/model.safetensors.index.json", + "size": "17.2GB", + "description": "MoE general chat in native MLX format", + "draft": null, + "moe": null, + "extra_files": [], + "mmproj": null + }, { "name": "Gemma-3-1B-it-Q4_K_M", "file": "Gemma-3-1B-it-Q4_K_M.gguf", @@ -493,5 +570,49 @@ "moe": null, "extra_files": [], "mmproj": null + }, + { + "name": "Gemma-2-2B-it-MLX", + "file": "model.safetensors.index.json", + "url": "https://huggingface.co/mlx-community/gemma-2-2b-it-4bit/resolve/main/model.safetensors.index.json", + "size": "1.4GB", + "description": "Gemma 2 2B instruct in native MLX format", + "draft": null, + "moe": null, + "extra_files": [], + "mmproj": null + }, + { + "name": "GLM-4-9B-0414-MLX", + "file": "model.safetensors.index.json", + "url": "https://huggingface.co/mlx-community/GLM-4-9B-0414-4bit/resolve/main/model.safetensors.index.json", + "size": "5.3GB", + "description": "GLM 4 9B dense instruct in native MLX format", + "draft": null, + "moe": null, + "extra_files": [], + "mmproj": null + }, + { + "name": "LFM2-350M-MLX", + "file": "model.safetensors.index.json", + "url": "https://huggingface.co/mlx-community/LFM2-350M-4bit/resolve/main/model.safetensors.index.json", + "size": "220MB", + "description": "Liquid LFM2 350M instruct in native MLX format", + "draft": null, + "moe": null, + "extra_files": [], + "mmproj": null + }, + { + "name": "Gemma-4-E4B-it-MLX", + "file": "model.safetensors.index.json", + "url": "https://huggingface.co/unsloth/gemma-4-E4B-it-UD-MLX-4bit/resolve/main/model.safetensors.index.json", + "size": "5.6GB", + "description": "Gemma 4 E4B instruct in native MLX format", + "draft": null, + "moe": null, + "extra_files": [], + "mmproj": null } ] diff --git a/mesh-llm/src/models/mod.rs b/mesh-llm/src/models/mod.rs index b051cfec..10c85db6 100644 --- a/mesh-llm/src/models/mod.rs +++ b/mesh-llm/src/models/mod.rs @@ -4,6 +4,7 @@ pub mod gguf; pub mod inventory; pub mod local; mod maintenance; +pub mod prompt; mod resolve; pub mod search; pub mod topology; @@ -20,6 +21,7 @@ pub use local::{ scan_local_models, }; pub use maintenance::{run_update, warn_about_updates_for_paths}; +pub use prompt::{infer_prompt_behavior_for_dir, ModelPromptBehavior}; pub(crate) use resolve::resolve_model_spec_with_progress; pub use resolve::{ download_exact_ref, find_catalog_model_exact, installed_model_capabilities, diff --git a/mesh-llm/src/models/prompt.rs b/mesh-llm/src/models/prompt.rs new file mode 100644 index 00000000..26e2bc76 --- /dev/null +++ b/mesh-llm/src/models/prompt.rs @@ -0,0 +1,275 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::path::Path; +use std::sync::LazyLock; + +static CHAT_TEMPLATE_RE: LazyLock<regex_lite::Regex> = LazyLock::new(|| { + regex_lite::Regex::new(r#""chat_template"\s*:\s*"((?:\\.|[^"\\])*)""#) + .expect("Failed to compile CHAT_TEMPLATE_RE regex pattern") +}); + +#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct ModelPromptBehavior { + pub prompt_template: Option<String>, + pub default_system_prompt: Option<String>, + pub template_source: Option<String>, +} + +pub fn infer_prompt_behavior_for_dir(dir: &Path) -> Option<ModelPromptBehavior> { + let config = read_config_json(dir); + if let Some(template) = read_template_text(dir) { + let mut behavior = classify_template_behavior(&template, config.as_ref()); + behavior.template_source = Some("huggingface".to_string()); + return Some(behavior); + } + config + .as_ref() + .and_then(heuristic_prompt_behavior) + .map(|mut behavior| { + behavior.template_source = Some("fallback".to_string()); + behavior + }) +} + +fn read_config_json(dir: &Path) -> Option<Value> { + let text = std::fs::read_to_string(dir.join("config.json")).ok()?; + serde_json::from_str(&text).ok() +} + +/// Scans a model directory for a chat template and returns `(source_filename, template_text)`. +/// +/// Checks files in priority order: +/// 1. `chat_template.jinja` +/// 2. `chat_template.json` +/// 3. `tokenizer_config.json` +/// +/// This shared helper is the single source of truth used by both +/// `infer_prompt_behavior_for_dir` and the MLX template loader. +pub fn find_template_with_source(dir: &Path) -> Option<(String, String)> { + for filename in [ + "chat_template.jinja", + "chat_template.json", + "tokenizer_config.json", + ] { + let Ok(text) = std::fs::read_to_string(dir.join(filename)) else { + continue; + }; + if filename.ends_with(".jinja") { + return Some((filename.to_string(), text)); + } + if let Some(template) = extract_template_text_from_json_text(&text) { + return Some((filename.to_string(), template)); + } + let Ok(value) = serde_json::from_str::<Value>(&text) else { + continue; + }; + if let Some(template) = extract_template_text(&value) { + return Some((filename.to_string(), template)); + } + } + None +} + +fn read_template_text(dir: &Path) -> Option<String> { + find_template_with_source(dir).map(|(_source, text)| text) +} + +fn extract_template_text(value: &Value) -> Option<String> { + match value { + Value::String(text) => Some(text.clone()), + Value::Object(map) => map + .get("chat_template") + .and_then(|template| template.as_str()) + .map(ToOwned::to_owned), + _ => None, + } +} + +fn extract_template_text_from_json_text(text: &str) -> Option<String> { + let captures = CHAT_TEMPLATE_RE.captures(text)?; + serde_json::from_str::<String>(&format!("\"{}\"", &captures[1])).ok() +} + +fn classify_template_behavior(template: &str, config: Option<&Value>) -> ModelPromptBehavior { + let prompt_template = if template.contains("<|im_start|>") { + Some("chatml".to_string()) + } else if template.contains("<start_of_turn>") && template.contains("<end_of_turn>") { + Some("gemma3".to_string()) + } else if template.contains("<|start_header_id|>") && template.contains("<|eot_id|>") { + Some("llama3".to_string()) + } else { + Some("hf_template".to_string()) + }; + ModelPromptBehavior { + prompt_template, + default_system_prompt: inferred_default_system_prompt(config, template), + template_source: None, + } +} + +fn heuristic_prompt_behavior(config: &Value) -> Option<ModelPromptBehavior> { + let family = model_family(config)?; + let prompt_template = match family.as_str() { + "qwen" => "chatml", + "gemma" => "gemma3", + "llama" => "llama3", + _ => return None, + }; + Some(ModelPromptBehavior { + prompt_template: Some(prompt_template.to_string()), + default_system_prompt: if family == "qwen" { + Some("You are a helpful assistant.".to_string()) + } else { + None + }, + template_source: None, + }) +} + +fn inferred_default_system_prompt(config: Option<&Value>, template: &str) -> Option<String> { + if template.contains("You are Qwen, created by Alibaba Cloud. You are a helpful assistant.") { + return Some("You are Qwen, created by Alibaba Cloud. You are a helpful assistant.".into()); + } + if template.contains("You are a helpful assistant.") { + return Some("You are a helpful assistant.".to_string()); + } + match config.and_then(model_family).as_deref() { + Some("qwen") if template.contains("<|im_start|>system") => { + Some("You are a helpful assistant.".to_string()) + } + _ => None, + } +} + +fn model_family(config: &Value) -> Option<String> { + let model_type = config + .get("model_type") + .and_then(|value| value.as_str()) + .unwrap_or_default() + .to_ascii_lowercase(); + let architectures = config + .get("architectures") + .and_then(|value| value.as_array()) + .into_iter() + .flatten() + .filter_map(|value| value.as_str()) + .map(|value| value.to_ascii_lowercase()) + .collect::<Vec<_>>(); + + if model_type.starts_with("qwen") || architectures.iter().any(|value| value.contains("qwen")) { + return Some("qwen".to_string()); + } + if model_type.starts_with("gemma") || architectures.iter().any(|value| value.contains("gemma")) + { + return Some("gemma".to_string()); + } + if model_type == "llama" || architectures.iter().any(|value| value.contains("llama")) { + return Some("llama".to_string()); + } + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn infers_huggingface_template_behavior_for_qwen() { + let root = + std::env::temp_dir().join(format!("mesh-llm-prompt-qwen-{}", std::process::id())); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write( + root.join("config.json"), + serde_json::json!({"model_type":"qwen2"}).to_string(), + ) + .unwrap(); + std::fs::write( + root.join("tokenizer_config.json"), + serde_json::json!({ + "chat_template": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + }) + .to_string(), + ) + .unwrap(); + + let behavior = infer_prompt_behavior_for_dir(&root).unwrap(); + assert_eq!(behavior.prompt_template.as_deref(), Some("chatml")); + assert_eq!( + behavior.default_system_prompt.as_deref(), + Some("You are a helpful assistant.") + ); + assert_eq!(behavior.template_source.as_deref(), Some("huggingface")); + } + + #[test] + fn infers_fallback_behavior_for_llama() { + let root = + std::env::temp_dir().join(format!("mesh-llm-prompt-llama-{}", std::process::id())); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write( + root.join("config.json"), + serde_json::json!({ + "model_type":"llama", + "architectures":["LlamaForCausalLM"] + }) + .to_string(), + ) + .unwrap(); + + let behavior = infer_prompt_behavior_for_dir(&root).unwrap(); + assert_eq!(behavior.prompt_template.as_deref(), Some("llama3")); + assert_eq!(behavior.template_source.as_deref(), Some("fallback")); + } + + #[test] + fn infers_fallback_behavior_for_gemma() { + let root = + std::env::temp_dir().join(format!("mesh-llm-prompt-gemma-{}", std::process::id())); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write( + root.join("config.json"), + serde_json::json!({ + "model_type":"gemma3", + "architectures":["Gemma3ForConditionalGeneration"] + }) + .to_string(), + ) + .unwrap(); + + let behavior = infer_prompt_behavior_for_dir(&root).unwrap(); + assert_eq!(behavior.prompt_template.as_deref(), Some("gemma3")); + assert_eq!(behavior.template_source.as_deref(), Some("fallback")); + } + + #[test] + fn infers_huggingface_template_behavior_for_gemma() { + let root = + std::env::temp_dir().join(format!("mesh-llm-prompt-gemma-hf-{}", std::process::id())); + let _ = std::fs::remove_dir_all(&root); + std::fs::create_dir_all(&root).unwrap(); + std::fs::write( + root.join("config.json"), + serde_json::json!({ + "model_type":"gemma3", + "architectures":["Gemma3ForConditionalGeneration"] + }) + .to_string(), + ) + .unwrap(); + std::fs::write( + root.join("tokenizer_config.json"), + serde_json::json!({ + "chat_template": "{{ bos_token }}<start_of_turn>user\nhello<end_of_turn>\n<start_of_turn>model\n" + }) + .to_string(), + ) + .unwrap(); + + let behavior = infer_prompt_behavior_for_dir(&root).unwrap(); + assert_eq!(behavior.prompt_template.as_deref(), Some("gemma3")); + assert_eq!(behavior.template_source.as_deref(), Some("huggingface")); + } +} diff --git a/mesh-llm/src/protocol/convert.rs b/mesh-llm/src/protocol/convert.rs index 1fe036eb..2eb5c8a5 100644 --- a/mesh-llm/src/protocol/convert.rs +++ b/mesh-llm/src/protocol/convert.rs @@ -214,6 +214,7 @@ fn runtime_descriptor_to_proto( identity_hash: descriptor.identity_hash.clone(), context_length: descriptor.context_length, ready: descriptor.ready, + backend: descriptor.backend.clone(), } } @@ -224,6 +225,7 @@ fn proto_runtime_descriptor_to_local( model_name: descriptor.model_name.clone(), identity_hash: descriptor.identity_hash.clone(), context_length: descriptor.context_length, + backend: descriptor.backend.clone(), ready: descriptor.ready, } } diff --git a/mesh-llm/src/runtime/instance.rs b/mesh-llm/src/runtime/instance.rs index b048e947..b3a7dd20 100644 --- a/mesh-llm/src/runtime/instance.rs +++ b/mesh-llm/src/runtime/instance.rs @@ -263,7 +263,7 @@ pub fn runtime_root() -> Result<PathBuf> { pub struct InstanceRuntime { dir: PathBuf, pid: u32, - _lock_file: File, + lock_file: File, } impl InstanceRuntime { @@ -342,7 +342,7 @@ impl InstanceRuntime { Ok(Self { dir, pid, - _lock_file: lock_file, + lock_file, }) } @@ -379,6 +379,20 @@ impl InstanceRuntime { } } +impl Drop for InstanceRuntime { + fn drop(&mut self) { + #[cfg(unix)] + { + use std::os::unix::io::AsRawFd; + + let fd = self.lock_file.as_raw_fd(); + // SAFETY: flock is safe to call with a valid fd. This is best-effort + // cleanup on drop; the fd close that follows is still the hard backstop. + let _ = unsafe { libc::flock(fd, libc::LOCK_UN) }; + } + } +} + /// Probe whether the flock at `lock_path` is currently held by a live process. /// /// Opens the file and attempts a non-blocking exclusive flock: diff --git a/mesh-llm/src/runtime/local.rs b/mesh-llm/src/runtime/local.rs index 5646c79d..9e34f917 100644 --- a/mesh-llm/src/runtime/local.rs +++ b/mesh-llm/src/runtime/local.rs @@ -34,6 +34,20 @@ pub(super) struct ManagedModelController { } pub(super) fn resolved_model_name(path: &Path) -> String { + #[cfg(target_os = "macos")] + if let Some(dir) = crate::mlx::mlx_model_dir(path) { + if let Some(identity) = + crate::models::huggingface_identity_for_path(&dir.join("config.json")) + { + if let Some(name) = identity.repo_id.rsplit('/').next() { + return name.to_string(); + } + } + if let Some(name) = dir.file_name().and_then(|value| value.to_str()) { + return name.to_string(); + } + } + let stem = path .file_stem() .unwrap_or_default() @@ -108,8 +122,9 @@ pub(super) async fn set_advertised_model_context( node: &mesh::Node, model_name: &str, context_length: Option<u32>, + backend: Option<&str>, ) { - node.set_model_runtime_context_length(model_name, context_length) + node.set_model_runtime_context_length(model_name, context_length, backend) .await; node.regossip().await; } @@ -186,27 +201,45 @@ pub(super) async fn start_runtime_local_model( let mmproj_path = mmproj_override .map(Path::to_path_buf) .or_else(|| mmproj_path_for_model(&model_name)); - let process = launch::start_llama_server( - runtime, - bin_dir, - binary_flavor, - launch::ModelLaunchSpec { - model: model_path, - http_port: llama_port, - tunnel_ports: &[], - tensor_split: None, - split_mode: election::local_multi_gpu_split_mode(binary_flavor), - draft: None, - draft_max: 0, - model_bytes, - my_vram, - mmproj: mmproj_path.as_deref(), - ctx_size_override, - total_group_vram: None, - selected_gpu: None, - }, - ) - .await?; + #[cfg(target_os = "macos")] + let mlx_process = if crate::mlx::is_mlx_model_dir(model_path) { + let dir = crate::mlx::mlx_model_dir(model_path) + .expect("mlx path should normalize after compatibility check"); + Some(crate::mlx::start_mlx_server(dir, model_name.clone(), llama_port).await?) + } else { + None + }; + #[cfg(not(target_os = "macos"))] + let mlx_process: Option<launch::InferenceServerProcess> = None; + + let (backend, process) = if let Some(process) = mlx_process { + ("mlx", process) + } else { + ( + "llama", + launch::start_llama_server( + runtime, + bin_dir, + binary_flavor, + launch::ModelLaunchSpec { + model: model_path, + http_port: llama_port, + tunnel_ports: &[], + tensor_split: None, + split_mode: election::local_multi_gpu_split_mode(binary_flavor), + draft: None, + draft_max: 0, + model_bytes, + my_vram, + mmproj: mmproj_path.as_deref(), + ctx_size_override, + total_group_vram: None, + selected_gpu: None, + }, + ) + .await?, + ) + }; let backend_proxy = backend::start_backend_proxy(llama_port).await?; let port = backend_proxy.port(); @@ -214,7 +247,7 @@ pub(super) async fn start_runtime_local_model( model_name, LocalRuntimeModelHandle { port, - backend: "llama".into(), + backend: backend.into(), process: process.handle, backend_proxy, context_length: process.context_length, diff --git a/mesh-llm/src/runtime/mod.rs b/mesh-llm/src/runtime/mod.rs index 29ff11bd..4252d261 100644 --- a/mesh-llm/src/runtime/mod.rs +++ b/mesh-llm/src/runtime/mod.rs @@ -38,10 +38,18 @@ struct StartupModelSpec { model_ref: PathBuf, mmproj_ref: Option<PathBuf>, ctx_size: Option<u32>, + backend_hint: StartupBackendHint, gpu_id: Option<String>, config_owned: bool, } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum StartupBackendHint { + Auto, + Gguf, + Mlx, +} + #[derive(Clone, Debug, PartialEq, Eq)] pub(crate) struct StartupPinnedGpuTarget { pub(crate) index: usize, @@ -56,6 +64,7 @@ struct StartupModelPlan { resolved_path: PathBuf, mmproj_path: Option<PathBuf>, ctx_size: Option<u32>, + backend_hint: StartupBackendHint, gpu_id: Option<String>, pinned_gpu: Option<StartupPinnedGpuTarget>, } @@ -242,9 +251,9 @@ pub(crate) async fn run() -> Result<()> { } let config = plugin::load_config(cli.config.as_deref())?; - let cli_has_explicit_models = cli_has_explicit_models(&cli); + let has_cli_explicit_models = cli_has_explicit_models(&cli); let has_config_models = !config.models.is_empty(); - let has_startup_models = cli_has_explicit_models || has_config_models; + let has_startup_models = has_cli_explicit_models || has_config_models; // Acquire the per-instance runtime directory and flock (skip for --client β€” no local servers). // Wrap in Arc so it can be cheaply shared with election/spawn tasks that @@ -460,14 +469,16 @@ pub(crate) async fn run() -> Result<()> { } // --- Validation --- - if cli.client && (!cli.model.is_empty() || !cli.gguf.is_empty()) { - anyhow::bail!("--client and --model are mutually exclusive"); + if cli.client + && (!cli.model.is_empty() || !cli.gguf_file.is_empty() || !cli.mlx_file.is_empty()) + { + anyhow::bail!("--client is mutually exclusive with model selection flags: --model, --gguf-file, --mlx-file"); } if let Some(mmproj) = &cli.mmproj { anyhow::ensure!(!cli.client, "--mmproj cannot be used with --client"); anyhow::ensure!( - !cli.model.is_empty() || !cli.gguf.is_empty(), - "--mmproj requires an explicit primary model via --model or --gguf" + cli_has_explicit_models(&cli), + "--mmproj requires an explicit primary model via --model, --gguf-file, or --mlx-file" ); anyhow::ensure!( mmproj.is_file(), @@ -484,7 +495,7 @@ pub(crate) async fn run() -> Result<()> { .join("config.toml") }); eprintln!( - "⚠️ `mesh-llm serve` needs at least one startup model.\n Add `[[models]]` to {}, or pass `--model` / `--gguf` explicitly.", + "⚠️ `mesh-llm serve` needs at least one startup model.\n Add `[[models]]` to {}, or pass `--model`, `--gguf-file`, or `--mlx-file` explicitly.", config_path.display() ); Cli::command().print_help().ok(); @@ -503,11 +514,7 @@ pub(crate) async fn run() -> Result<()> { // Strip split GGUF suffix so "MiniMax-M2.5-Q4_K_M-00001-of-00004" β†’ "MiniMax-M2.5-Q4_K_M" let requested_model_names: Vec<String> = resolved_models .iter() - .filter_map(|m| { - m.file_stem() - .and_then(|s| s.to_str()) - .map(router::strip_split_suffix_owned) - }) + .map(|m| resolved_model_name(m)) .collect(); let bin_dir = match &cli.bin_dir { @@ -529,11 +536,21 @@ pub(crate) async fn run() -> Result<()> { /// Resolve a model path: local file, catalog name, or HuggingFace URL. async fn resolve_model(input: &std::path::Path) -> Result<PathBuf> { + let s = input.to_string_lossy(); + + // Already a local file + if input.exists() { + return Ok(input.to_path_buf()); + } + if s.contains('/') { + return models::download_exact_ref(&s).await; + } + models::resolve_model_spec(input).await } fn cli_has_explicit_models(cli: &Cli) -> bool { - !cli.model.is_empty() || !cli.gguf.is_empty() + !cli.model.is_empty() || !cli.gguf_file.is_empty() || !cli.mlx_file.is_empty() } fn build_startup_model_specs( @@ -544,9 +561,16 @@ fn build_startup_model_specs( return Ok(Vec::new()); } + #[cfg(not(target_os = "macos"))] + { + if !cli.mlx_file.is_empty() { + anyhow::bail!("MLX model selection is only supported on macOS"); + } + } + let mut specs = Vec::new(); if cli_has_explicit_models(cli) { - for path in &cli.gguf { + for path in &cli.gguf_file { if !path.exists() { anyhow::bail!("GGUF file not found: {}", path.display()); } @@ -554,6 +578,17 @@ fn build_startup_model_specs( model_ref: path.clone(), mmproj_ref: None, ctx_size: cli.ctx_size, + backend_hint: StartupBackendHint::Gguf, + gpu_id: None, + config_owned: false, + }); + } + for path in &cli.mlx_file { + specs.push(StartupModelSpec { + model_ref: path.clone(), + mmproj_ref: None, + ctx_size: cli.ctx_size, + backend_hint: StartupBackendHint::Mlx, gpu_id: None, config_owned: false, }); @@ -563,6 +598,7 @@ fn build_startup_model_specs( model_ref: model.clone(), mmproj_ref: None, ctx_size: cli.ctx_size, + backend_hint: StartupBackendHint::Auto, gpu_id: None, config_owned: false, }); @@ -580,6 +616,7 @@ fn build_startup_model_specs( model_ref: PathBuf::from(model.model.clone()), mmproj_ref: model.mmproj.as_ref().map(PathBuf::from), ctx_size: cli.ctx_size.or(model.ctx_size), + backend_hint: StartupBackendHint::Auto, gpu_id: model.gpu_id.clone(), config_owned: true, }); @@ -600,6 +637,7 @@ async fn resolve_startup_models(specs: &[StartupModelSpec]) -> Result<Vec<Startu resolved_path, mmproj_path, ctx_size: spec.ctx_size, + backend_hint: spec.backend_hint, gpu_id: spec.gpu_id.clone(), pinned_gpu: None, }); @@ -1554,15 +1592,7 @@ async fn run_auto( } }; - let model_name = { - let stem = model - .file_stem() - .unwrap_or_default() - .to_string_lossy() - .to_string(); - // Strip split GGUF suffix: "MiniMax-M2.5-Q4_K_M-00001-of-00004" β†’ "MiniMax-M2.5-Q4_K_M" - router::strip_split_suffix_owned(&stem) - }; + let model_name = resolved_model_name(&model); // Set model source for gossip (so other joiners can discover it too) let model_source = primary_startup_model @@ -1595,31 +1625,42 @@ async fn run_auto( let _ = crate::runtime::instance::reap::reap_own_stale_pidfiles(rt.dir()).await; } - // Serve mode (non-client) always has the InstanceRuntime acquired above. - // The fallback was only relevant during the T1-T11 staging when acquisition - // wasn't yet wired into run() β€” keep an explicit error here so any future - // refactor that drops the acquire surfaces immediately instead of panicking - // mid-spawn from a child task. + #[cfg(target_os = "macos")] + let is_mlx = primary_startup_model + .as_ref() + .is_some_and(|model| model.backend_hint == StartupBackendHint::Mlx) + || crate::mlx::is_mlx_model_dir(&model); + #[cfg(not(target_os = "macos"))] + let is_mlx = false; + let runtime_arc = runtime .as_ref() .ok_or_else(|| anyhow::anyhow!("serve mode requires an instance runtime"))? .clone(); - let rpc_handle = launch::start_rpc_server( - &runtime_arc, - &bin_dir, - cli.llama_flavor, - startup_rpc_backend_device(cli.device.as_deref(), primary_startup_model.as_ref())?, - Some(&model), - ) - .await?; - tracing::info!( - "rpc-server on 127.0.0.1:{} (pid {}) serving {model_name}", - rpc_handle.port, - rpc_handle.pid - ); + + let rpc_handle = if is_mlx { + tracing::info!("MLX model detected β€” skipping rpc-server"); + None + } else { + let handle = launch::start_rpc_server( + &runtime_arc, + &bin_dir, + cli.llama_flavor, + startup_rpc_backend_device(cli.device.as_deref(), primary_startup_model.as_ref())?, + Some(&model), + ) + .await?; + tracing::info!( + "rpc-server on 127.0.0.1:{} (pid {}) serving {model_name}", + handle.port, + handle.pid + ); + Some(handle) + }; + let rpc_port = rpc_handle.as_ref().map(|handle| handle.port).unwrap_or(0); let tunnel_mgr = - tunnel::Manager::start(node.clone(), rpc_handle.port, channels.rpc, channels.http).await?; + tunnel::Manager::start(node.clone(), rpc_port, channels.rpc, channels.http).await?; // Election publishes per-model targets let (target_tx, target_rx) = tokio::sync::watch::channel(election::ModelTargets::default()); @@ -1673,7 +1714,6 @@ async fn run_auto( plugin_manager.clone(), affinity_router.clone(), ); - cs.set_primary_backend("llama".into()).await; cs.set_runtime_control(control_tx.clone()).await; cs.set_nostr_relays(nostr_relays(&cli.nostr_relay)).await; cs.set_nostr_discovery(cli.nostr_discovery).await; @@ -1742,6 +1782,12 @@ async fn run_auto( let force_split = cli.split; let llama_flavor = cli.llama_flavor; let cb_console_port = console_port; + let primary_backend_label = + if crate::mlx::mlx_model_dir(&model).is_some_and(crate::mlx::is_mlx_model_dir) { + "MLX server" + } else { + "llama-server" + }; let model_name_for_cb = model_name.clone(); let model_name_for_election = model_name.clone(); let node_for_cb = node.clone(); @@ -1770,7 +1816,7 @@ async fn run_auto( node: node2, tunnel_mgr: tunnel_mgr2, ingress_http_port: api_port, - rpc_port: rpc_handle.port, + rpc_port, bin_dir: bin_dir2, model: model2, model_name: model_name_for_election, @@ -1811,7 +1857,7 @@ async fn run_auto( eprintln!(" pi: pi --provider mesh --model {model_name_for_cb}"); eprintln!(" goose: GOOSE_PROVIDER=openai OPENAI_HOST={url} OPENAI_API_KEY=mesh GOOSE_MODEL={model_name_for_cb} goose session"); } else if is_host { - eprintln!("⏳ Starting llama-server..."); + eprintln!("⏳ Starting {primary_backend_label}..."); } else { eprintln!(" API: http://localhost:{api_port} (proxied to host)"); } @@ -1833,9 +1879,11 @@ async fn run_auto( &context_node, &model_name, Some(process.context_length), + Some(&process.backend), ) .await; if let Some(cs) = console_state { + cs.set_primary_backend(process.backend.clone()).await; cs.upsert_local_process(local_process_payload( &model_name, &process.backend, @@ -1846,8 +1894,10 @@ async fn run_auto( } } None => { - set_advertised_model_context(&context_node, &model_name, None).await; + set_advertised_model_context(&context_node, &model_name, None, None) + .await; if let Some(cs) = console_state { + cs.clear_primary_backend().await; cs.remove_local_process(&model_name).await; } } @@ -1875,27 +1925,13 @@ async fn run_auto( // Announce all models to mesh let all_names: Vec<String> = startup_models .iter() - .map(|m| { - m.resolved_path - .file_stem() - .unwrap_or_default() - .to_string_lossy() - .to_string() - }) + .map(|m| resolved_model_name(&m.resolved_path)) .collect(); node.set_models(all_names).await; node.regossip().await; for extra_model in startup_models.iter().skip(1) { - let extra_name = { - let stem = extra_model - .resolved_path - .file_stem() - .unwrap_or_default() - .to_string_lossy() - .to_string(); - router::strip_split_suffix_owned(&stem) - }; + let extra_name = resolved_model_name(&extra_model.resolved_path); let extra_node = node.clone(); let extra_tunnel = tunnel_mgr.clone(); let extra_bin = bin_dir.clone(); @@ -1969,6 +2005,7 @@ async fn run_auto( &context_node, &model_name, Some(process.context_length), + Some(&process.backend), ) .await; if let Some(cs) = console_state { @@ -1982,8 +2019,13 @@ async fn run_auto( } } None => { - set_advertised_model_context(&context_node, &model_name, None) - .await; + set_advertised_model_context( + &context_node, + &model_name, + None, + None, + ) + .await; if let Some(cs) = console_state { cs.remove_local_process(&model_name).await; } @@ -2049,7 +2091,10 @@ async fn run_auto( api::RuntimeControlRequest::Load { spec, resp } => { let mut assigned_runtime_model: Option<String> = None; let result = async { - let model_path = resolve_model(&PathBuf::from(&spec)).await?; + let model_path = resolve_model( + &PathBuf::from(&spec), + ) + .await?; let runtime_model_name = resolved_model_name(&model_path); let already_loaded = managed_models.contains_key(&runtime_model_name) || runtime_models.contains_key(&runtime_model_name); @@ -2077,6 +2122,7 @@ async fn run_auto( &node, &loaded_name, Some(handle.context_length), + Some(&handle.backend), ) .await; advertise_model_ready(&node, &primary_model_name, &loaded_name).await; @@ -2222,7 +2268,9 @@ async fn run_auto( node.set_serving_models(Vec::new()).await; node.set_hosted_models(Vec::new()).await; - rpc_handle.shutdown().await; + if let Some(handle) = rpc_handle { + handle.shutdown().await; + } if let Some(rt) = runtime { let outstanding_refs = std::sync::Arc::strong_count(&rt); if outstanding_refs == 1 { @@ -2519,15 +2567,7 @@ fn build_serving_list(resolved_models: &[PathBuf], model_name: &str) -> Vec<Stri let clean_name = router::strip_split_suffix_owned(model_name); let mut all: Vec<String> = resolved_models .iter() - .map(|m| { - let stem = m - .file_stem() - .unwrap_or_default() - .to_string_lossy() - .to_string(); - // Strip split GGUF suffix: "Model-00001-of-00004" β†’ "Model" - router::strip_split_suffix_owned(&stem) - }) + .map(|m| resolved_model_name(m)) .collect(); if !all.contains(&clean_name) { all.insert(0, clean_name); @@ -2847,6 +2887,7 @@ mod tests { resolved_path: PathBuf::from("/tmp/Qwen3-8B-Q4_K_M.gguf"), mmproj_path: None, ctx_size: Some(8192), + backend_hint: StartupBackendHint::Auto, gpu_id: specs[0].gpu_id.clone(), pinned_gpu: None, }]; @@ -2901,6 +2942,7 @@ mod tests { resolved_path: PathBuf::from("/tmp/Qwen3-8B-Q4_K_M.gguf"), mmproj_path: None, ctx_size: None, + backend_hint: StartupBackendHint::Auto, gpu_id: specs[0].gpu_id.clone(), pinned_gpu: None, }]; @@ -2927,6 +2969,7 @@ mod tests { model_ref: PathBuf::from("Qwen3-8B-Q4_K_M"), mmproj_ref: None, ctx_size: None, + backend_hint: StartupBackendHint::Auto, gpu_id: None, config_owned: true, }]; @@ -2935,6 +2978,7 @@ mod tests { resolved_path: PathBuf::from("/tmp/Qwen3-8B-Q4_K_M.gguf"), mmproj_path: None, ctx_size: None, + backend_hint: StartupBackendHint::Auto, gpu_id: None, pinned_gpu: None, }]; @@ -2961,6 +3005,7 @@ mod tests { model_ref: PathBuf::from("Qwen3-8B-Q4_K_M"), mmproj_ref: None, ctx_size: Some(4096), + backend_hint: StartupBackendHint::Auto, gpu_id: Some("uuid:GPU-123".into()), config_owned: true, }]; @@ -2969,6 +3014,7 @@ mod tests { resolved_path: PathBuf::from("/tmp/Qwen3-8B-Q4_K_M.gguf"), mmproj_path: None, ctx_size: Some(4096), + backend_hint: StartupBackendHint::Auto, gpu_id: Some("uuid:GPU-123".into()), pinned_gpu: None, }]; @@ -2996,6 +3042,7 @@ mod tests { model_ref: PathBuf::from("Qwen3-8B-Q4_K_M"), mmproj_ref: None, ctx_size: Some(4096), + backend_hint: StartupBackendHint::Auto, gpu_id: Some("uuid:GPU-123".into()), config_owned: true, }]; @@ -3004,6 +3051,7 @@ mod tests { resolved_path: PathBuf::from("/tmp/Qwen3-8B-Q4_K_M.gguf"), mmproj_path: None, ctx_size: Some(4096), + backend_hint: StartupBackendHint::Auto, gpu_id: Some("uuid:GPU-123".into()), pinned_gpu: None, }]; @@ -3030,6 +3078,7 @@ mod tests { model_ref: PathBuf::from("Qwen3-8B-Q4_K_M"), mmproj_ref: None, ctx_size: None, + backend_hint: StartupBackendHint::Auto, gpu_id: Some("pci:0000:b3:00.0".into()), config_owned: true, }]; @@ -3038,6 +3087,7 @@ mod tests { resolved_path: PathBuf::from("/tmp/Qwen3-8B-Q4_K_M.gguf"), mmproj_path: None, ctx_size: None, + backend_hint: StartupBackendHint::Auto, gpu_id: Some("pci:0000:b3:00.0".into()), pinned_gpu: None, }]; @@ -3066,6 +3116,7 @@ mod tests { resolved_path: PathBuf::from("/tmp/Qwen3-8B-Q4_K_M.gguf"), mmproj_path: None, ctx_size: Some(8192), + backend_hint: StartupBackendHint::Auto, gpu_id: Some("pci:0000:65:00.0".into()), pinned_gpu: Some(StartupPinnedGpuTarget { index: 0, @@ -3088,6 +3139,7 @@ mod tests { resolved_path: PathBuf::from("/tmp/Qwen3-8B-Q4_K_M.gguf"), mmproj_path: None, ctx_size: Some(8192), + backend_hint: StartupBackendHint::Auto, gpu_id: Some("pci:0000:65:00.0".into()), pinned_gpu: Some(StartupPinnedGpuTarget { index: 0, @@ -3109,6 +3161,7 @@ mod tests { resolved_path: PathBuf::from("/tmp/Qwen3-8B-Q4_K_M.gguf"), mmproj_path: None, ctx_size: Some(8192), + backend_hint: StartupBackendHint::Auto, gpu_id: Some("pci:0000:65:00.0".into()), pinned_gpu: Some(StartupPinnedGpuTarget { index: 0, @@ -3145,6 +3198,7 @@ mod tests { model_ref: PathBuf::from("Qwen3-8B-Q4_K_M"), mmproj_ref: None, ctx_size: None, + backend_hint: StartupBackendHint::Auto, gpu_id: None, config_owned: false, }]; diff --git a/scripts/build-mac.sh b/scripts/build-mac.sh index 367ae2a2..15aab6e9 100755 --- a/scripts/build-mac.sh +++ b/scripts/build-mac.sh @@ -16,6 +16,7 @@ UI_DIR="$MESH_DIR/ui" compiler_launcher_flags=() rustc_wrapper="" +LLAMA_BRANCH="${LLAMA_BRANCH:-upstream-latest}" detect_jobs() { sysctl -n hw.ncpu 2>/dev/null || echo 4 @@ -71,6 +72,7 @@ cmake_flags=( -B "$BUILD_DIR" -S "$LLAMA_DIR" -DGGML_METAL=ON + -DGGML_NATIVE=OFF -DGGML_RPC=ON -DBUILD_SHARED_LIBS=OFF -DLLAMA_OPENSSL=OFF diff --git a/scripts/build-release.sh b/scripts/build-release.sh index 44d8329a..d28758ea 100755 --- a/scripts/build-release.sh +++ b/scripts/build-release.sh @@ -28,7 +28,7 @@ configure_compiler_cache() { elif command -v ccache >/dev/null 2>&1; then cache_bin="ccache" else - return + return 0 fi echo "Using compiler cache: $cache_bin" @@ -61,6 +61,7 @@ cmake_flags=( -B "$BUILD_DIR" -S "$LLAMA_DIR" -DGGML_RPC=ON + -DGGML_NATIVE=OFF -DBUILD_SHARED_LIBS=OFF -DLLAMA_OPENSSL=OFF ) diff --git a/scripts/ci-exact-smoke.py b/scripts/ci-exact-smoke.py new file mode 100644 index 00000000..c8682ec7 --- /dev/null +++ b/scripts/ci-exact-smoke.py @@ -0,0 +1,504 @@ +#!/usr/bin/env python3 +"""Run the deterministic exact smoke suite against one mesh-llm model/backend.""" + +from __future__ import annotations + +import argparse +import json +import os +import shutil +import signal +import socket +import subprocess +import sys +import tempfile +import time +import urllib.request +from pathlib import Path +from typing import Any + +DEFAULT_WAIT_SECONDS = 300 +DEFAULT_REQUEST_TIMEOUT = 300 + + +def pick_free_port() -> int: + with socket.socket() as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +def http_json(url: str, payload: dict[str, Any] | None = None, timeout: int = 60) -> dict[str, Any]: + if payload is None: + request = urllib.request.Request(url) + else: + body = json.dumps(payload).encode("utf-8") + request = urllib.request.Request( + url, + data=body, + headers={"Content-Type": "application/json"}, + method="POST", + ) + with urllib.request.urlopen(request, timeout=timeout) as response: + return json.load(response) + + +def case_dir() -> Path | None: + raw = os.environ.get("VALIDATION_CASE_DIR", "").strip() + if not raw: + return None + return Path(raw) + + +def temp_root_path() -> Path: + custom = os.environ.get("TMPDIR") + if custom: + return Path(custom) + return Path(tempfile.gettempdir()) + + +def sync_runtime_logs(case_directory: Path | None, mesh_log_path: Path) -> None: + if case_directory is None: + return + case_directory.mkdir(parents=True, exist_ok=True) + if mesh_log_path.exists(): + shutil.copyfile(mesh_log_path, case_directory / "mesh.log") + + temp_dir = temp_root_path() + for source_name, target_name in ( + ("mesh-llm-llama-server.log", "llama-server.log"), + ("mesh-llm-rpc-server.log", "rpc-server.log"), + ): + source_path = temp_dir / source_name + if source_path.exists(): + shutil.copyfile(source_path, case_directory / target_name) + + +def model_root_for(model_arg: str) -> Path: + path = Path(model_arg) + return path if path.is_dir() else path.parent + + +def ensure_expected_template_source(model_arg: str, expected_template_source: str) -> None: + model_root = model_root_for(model_arg) + expected_path = model_root / expected_template_source + if not expected_path.exists(): + print( + f"❌ Expected template source file not found in model directory: {expected_template_source}", + file=sys.stderr, + ) + print(f"Model directory: {model_root}", file=sys.stderr) + raise SystemExit(1) + + +def build_launch_command(args: argparse.Namespace, api_port: int, console_port: int) -> list[str]: + command = [args.mesh_llm] + if args.backend == "mlx": + command.extend(["--mlx-file", args.model]) + else: + command.extend(["--gguf-file", args.model, "--bin-dir", args.bin_dir]) + command.extend(["--no-draft", "--port", str(api_port), "--console", str(console_port)]) + return command + + +def wait_until_ready( + process: subprocess.Popen[str], + console_port: int, + log_path: Path, + timeout: int, +) -> None: + status_url = f"http://127.0.0.1:{console_port}/api/status" + for second in range(1, timeout + 1): + sync_runtime_logs(case_dir(), log_path) + if process.poll() is not None: + print("❌ mesh-llm exited unexpectedly", file=sys.stderr) + print(log_path.read_text(encoding="utf-8", errors="replace")[-8000:], file=sys.stderr) + raise SystemExit(1) + + try: + status = http_json(status_url, timeout=5) + if bool(status.get("llama_ready", False)): + print(f"βœ… Model loaded in {second}s", flush=True) + return + except Exception: + pass + + if second % 15 == 0: + print(f" Still waiting... ({second}s)", flush=True) + time.sleep(1) + + sync_runtime_logs(case_dir(), log_path) + print(f"❌ Model failed to load within {timeout}s", file=sys.stderr) + print(log_path.read_text(encoding="utf-8", errors="replace")[-8000:], file=sys.stderr) + raise SystemExit(1) + + +def normalize(text: str) -> str: + return text.strip() + + +def record_chat_artifact( + label: str, + prompt_text: str, + request_payload: dict[str, Any], + response_payload: dict[str, Any], + content: str, + finish_reason: str, + expectations: dict[str, Any], +) -> None: + artifact_root = case_dir() + if artifact_root is None: + return + + chat_dir = artifact_root / "chat" + chat_dir.mkdir(parents=True, exist_ok=True) + payload = { + "label": label, + "prompt": prompt_text, + "request": request_payload, + "raw_response": response_payload, + "content": content, + "response_text": content, + "finish_reason": finish_reason, + "expectations": expectations, + } + (chat_dir / f"{label}.json").write_text( + json.dumps(payload, ensure_ascii=False, indent=2) + "\n", + encoding="utf-8", + ) + + +def fail(message: str, *, content: str = "", response: dict[str, Any] | None = None, log_path: Path | None = None) -> None: + print(f"❌ {message}", file=sys.stderr) + if content: + print(f"Content: {content}", file=sys.stderr) + if response is not None: + print(f"Raw response: {json.dumps(response, ensure_ascii=False)}", file=sys.stderr) + if log_path is not None and log_path.exists(): + print("--- Log tail ---", file=sys.stderr) + print(log_path.read_text(encoding="utf-8", errors="replace")[-8000:], file=sys.stderr) + raise SystemExit(1) + + +def case_failure(message: str, *, content: str = "", response: dict[str, Any] | None = None) -> str: + details = message + if content: + details += f" | Content: {content}" + if response is not None: + details += f" | Raw response: {json.dumps(response, ensure_ascii=False)}" + return details + + +def run_chat( + api_port: int, + prompt_text: str, + *, + max_tokens: int, + enable_thinking: bool, +) -> tuple[dict[str, Any], dict[str, Any], str, str]: + payload = { + "model": "any", + "messages": [{"role": "user", "content": prompt_text}], + "max_tokens": max_tokens, + "temperature": 0, + "top_p": 1, + "top_k": 1, + "seed": 123, + "enable_thinking": enable_thinking, + } + response = http_json( + f"http://127.0.0.1:{api_port}/v1/chat/completions", + payload=payload, + timeout=DEFAULT_REQUEST_TIMEOUT, + ) + choice = response["choices"][0] + content = choice["message"]["content"] + finish_reason = choice.get("finish_reason", "") + return payload, response, content, finish_reason + + +def validate_case( + *, + api_port: int, + case_cfg: dict[str, Any], + default_prompt: str, + log_path: Path, +) -> tuple[bool, str]: + label = case_cfg.get("label", "primary") + prompt_text = case_cfg.get("prompt", default_prompt) + expect_contains = str(case_cfg.get("expect_contains", "")) + expect_contains_ci = str(case_cfg.get("expect_contains_ci", "")) + expect_contains_all_ci = list(case_cfg.get("expect_contains_all_ci", [])) + expect_any_ci = list(case_cfg.get("expect_any_ci", [])) + forbid_contains = str(case_cfg.get("forbid_contains", "")) + expect_exact = str(case_cfg.get("expect_exact", "")) + thinking_mode = str(case_cfg.get("thinking_mode", "")) + max_tokens = int(case_cfg.get("max_tokens", 32) or 32) + + print(f"Testing /v1/chat/completions ({label})...", flush=True) + request_payload: dict[str, Any] = {} + response_payload: dict[str, Any] | None = None + content = "" + finish_reason = "" + error: str | None = None + + try: + request_payload, response_payload, content, finish_reason = run_chat( + api_port, + prompt_text, + max_tokens=max_tokens, + enable_thinking=False, + ) + except Exception as exc: + error = case_failure(f"Request failed: {exc}") + + if error is None and not content: + error = case_failure("Empty response from inference", response=response_payload) + if error is None and "<think>" in content: + error = case_failure("Unexpected reasoning output with enable_thinking=false", content=content) + if error is None and expect_contains and expect_contains not in content: + error = case_failure(f"Response did not contain expected text: {expect_contains}", content=content) + if error is None and expect_contains_ci and expect_contains_ci.lower() not in content.lower(): + error = case_failure( + f"Response did not contain expected text (case-insensitive): {expect_contains_ci}", + content=content, + ) + if error is None and expect_contains_all_ci: + missing = [needle for needle in expect_contains_all_ci if needle.lower() not in content.lower()] + if missing: + error = case_failure( + f"Response did not contain all expected terms (case-insensitive): {', '.join(missing)}", + content=content, + ) + if error is None and expect_any_ci and not any(needle.lower() in content.lower() for needle in expect_any_ci): + error = case_failure( + f"Response did not contain any expected text (case-insensitive): {json.dumps(expect_any_ci)}", + content=content, + ) + if error is None and expect_exact and normalize(content) != normalize(expect_exact): + error = case_failure( + "Response did not exactly match expected text", + content=f"expected={normalize(expect_exact)!r} actual={normalize(content)!r}", + ) + if error is None and forbid_contains and forbid_contains in content: + error = case_failure(f"Response contained forbidden text: {forbid_contains}", content=content) + if error is None and not finish_reason: + error = case_failure("Missing finish_reason in response", response=response_payload) + + record_chat_artifact( + label, + prompt_text, + request_payload, + response_payload or {}, + content, + finish_reason, + { + "expect_contains": expect_contains, + "expect_contains_ci": expect_contains_ci, + "expect_contains_all_ci": expect_contains_all_ci, + "expect_any_ci": expect_any_ci, + "forbid_contains": forbid_contains, + "expect_exact": expect_exact, + }, + ) + if error is not None: + print(f"❌ {error}", flush=True) + return False, error + + print(f"βœ… Inference response: {content}", flush=True) + + if thinking_mode: + print(f"Testing explicit reasoning output ({label})...", flush=True) + try: + think_request, think_response, think_content, _ = run_chat( + api_port, + prompt_text, + max_tokens=64, + enable_thinking=True, + ) + except Exception as exc: + error = case_failure(f"Explicit reasoning request failed: {exc}") + print(f"❌ {error}", flush=True) + return False, error + if not think_content: + error = case_failure("Empty response from explicit reasoning request", response=think_response) + elif thinking_mode == "tagged": + if "<think>" not in think_content: + error = case_failure("Explicit reasoning response did not contain <think> tags", content=think_content) + elif thinking_mode == "multiline": + if think_content == content: + error = case_failure("Explicit reasoning response matched non-thinking response", content=think_content) + elif "\n" not in think_content: + error = case_failure("Explicit reasoning response was not multiline", content=think_content) + else: + fail(f"Unknown thinking mode: {thinking_mode}") + + if error is not None: + print(f"❌ {error}", flush=True) + return False, error + + record_chat_artifact( + f"{label}.thinking", + prompt_text, + think_request, + think_response, + think_content, + "stop", + { + "expect_contains": "", + "expect_contains_ci": "", + "expect_contains_all_ci": [], + "expect_any_ci": [], + "forbid_contains": "", + "expect_exact": "", + }, + ) + print(f"βœ… Explicit reasoning response: {think_content}", flush=True) + + return True, content + + +def write_models_artifact(api_port: int) -> None: + models = http_json(f"http://127.0.0.1:{api_port}/v1/models", timeout=DEFAULT_REQUEST_TIMEOUT) + model_count = len(models.get("data", [])) + if model_count == 0: + fail("No models in /v1/models", response=models) + artifact_root = case_dir() + if artifact_root is not None: + models_dir = artifact_root / "models" + models_dir.mkdir(parents=True, exist_ok=True) + (models_dir / "v1-models.json").write_text( + json.dumps(models, ensure_ascii=False, indent=2) + "\n", + encoding="utf-8", + ) + print(f"βœ… /v1/models returned {model_count} model(s)", flush=True) + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--backend", choices=["gguf", "mlx"], required=True) + parser.add_argument("--mesh-llm", required=True) + parser.add_argument("--model", required=True) + parser.add_argument("--bin-dir", default="") + parser.add_argument("--expected-template-source", default="") + parser.add_argument("--prompt", default="Reply with exactly: blue") + parser.add_argument("--expect-contains", default="") + parser.add_argument("--expect-contains-ci", default="") + parser.add_argument("--forbid-contains", default="") + parser.add_argument("--expect-exact", default="") + parser.add_argument("--max-tokens", type=int, default=32) + parser.add_argument("--prompt-suite-json", default="") + parser.add_argument("--wait-seconds", type=int, default=DEFAULT_WAIT_SECONDS) + args = parser.parse_args() + + if args.backend == "gguf" and not args.bin_dir: + parser.error("--bin-dir is required for gguf backend") + + sys.stdout.reconfigure(line_buffering=True) + sys.stderr.reconfigure(line_buffering=True) + + api_port = pick_free_port() + console_port = pick_free_port() + while api_port == console_port: + console_port = pick_free_port() + + print("=== CI Exact Smoke Test ===", flush=True) + print(f" backend: {args.backend}", flush=True) + print(f" mesh-llm: {args.mesh_llm}", flush=True) + if args.backend == "gguf": + print(f" bin-dir: {args.bin_dir}", flush=True) + print(f" model: {args.model}", flush=True) + print(f" api port: {api_port}", flush=True) + print(f" os: {os.uname().sysname}", flush=True) + print(f" prompt: {args.prompt}", flush=True) + + if args.backend == "mlx" and args.expected_template_source: + ensure_expected_template_source(args.model, args.expected_template_source) + + with tempfile.TemporaryDirectory(prefix="mesh-llm-exact-") as temp_dir: + os.environ["TMPDIR"] = temp_dir + log_path = Path(temp_dir) / "mesh-llm.log" + with open(log_path, "w", encoding="utf-8") as log_file: + process = subprocess.Popen( + build_launch_command(args, api_port, console_port), + stdout=log_file, + stderr=subprocess.STDOUT, + text=True, + start_new_session=True, + env={**os.environ, "RUST_LOG": os.environ.get("RUST_LOG", "info")}, + ) + try: + wait_until_ready( + process, + console_port, + log_path, + args.wait_seconds, + ) + + primary_case = { + "label": "primary", + "prompt": args.prompt, + "expect_contains": args.expect_contains, + "expect_contains_ci": args.expect_contains_ci, + "forbid_contains": args.forbid_contains, + "expect_exact": args.expect_exact, + "max_tokens": args.max_tokens, + } + failures: list[tuple[str, str]] = [] + + ok, details = validate_case( + api_port=api_port, + case_cfg=primary_case, + default_prompt=args.prompt, + log_path=log_path, + ) + if not ok: + failures.append((primary_case["label"], details)) + + if args.prompt_suite_json: + print("Running extra prompt suite...", flush=True) + suite = json.loads(args.prompt_suite_json) + for index, case_cfg in enumerate(suite, start=1): + case_cfg = dict(case_cfg) + case_cfg.setdefault("label", f"case-{index}") + ok, details = validate_case( + api_port=api_port, + case_cfg=case_cfg, + default_prompt=args.prompt, + log_path=log_path, + ) + if not ok: + failures.append((str(case_cfg["label"]), details)) + + print("Testing /v1/models...", flush=True) + write_models_artifact(api_port) + + if failures: + print("\n=== Exact smoke test failed ===", flush=True) + print("Failed cases:", flush=True) + for label, details in failures: + print(f" - {label}: {details}", flush=True) + return 1 + + print("\n=== Exact smoke test passed ===", flush=True) + return 0 + finally: + sync_runtime_logs(case_dir(), log_path) + try: + os.killpg(process.pid, signal.SIGTERM) + except (ProcessLookupError, PermissionError): + pass + time.sleep(2) + sync_runtime_logs(case_dir(), log_path) + try: + os.killpg(process.pid, signal.SIGKILL) + except (ProcessLookupError, PermissionError): + pass + try: + process.wait(timeout=10) + except subprocess.TimeoutExpired: + pass + sync_runtime_logs(case_dir(), log_path) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/ci-mt-bench-behavior.py b/scripts/ci-mt-bench-behavior.py new file mode 100644 index 00000000..02c0a1c5 --- /dev/null +++ b/scripts/ci-mt-bench-behavior.py @@ -0,0 +1,431 @@ +#!/usr/bin/env python3 +"""Run the full MT-Bench prompt dataset against one mesh-llm model/backend. + +This is a behavior regression harness, not a correctness evaluator. It runs the +entire HuggingFaceH4/mt_bench_prompts dataset against a model and applies cheap +heuristics to catch empty outputs, reasoning leakage, and repetition / looping. +""" + +from __future__ import annotations + +import argparse +import json +import os +import shutil +import re +import signal +import socket +import subprocess +import sys +import tempfile +import time +import urllib.error +import urllib.parse +import urllib.request +from collections import Counter +from pathlib import Path +from typing import Any + +DEFAULT_DATASET = "HuggingFaceH4/mt_bench_prompts" +DATASET_SERVER = "https://datasets-server.huggingface.co/rows" +DEFAULT_WAIT_SECONDS = 300 +DEFAULT_REQUEST_TIMEOUT = 300 + + +def pick_free_port() -> int: + with socket.socket() as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +def http_json(url: str, payload: dict[str, Any] | None = None, timeout: int = 60) -> dict[str, Any]: + if payload is None: + request = urllib.request.Request(url) + else: + body = json.dumps(payload).encode("utf-8") + request = urllib.request.Request( + url, + data=body, + headers={"Content-Type": "application/json"}, + method="POST", + ) + with urllib.request.urlopen(request, timeout=timeout) as response: + return json.load(response) + + +def fetch_mt_bench_prompts(dataset: str) -> list[dict[str, Any]]: + encoded = urllib.parse.quote(dataset, safe="") + rows: list[dict[str, Any]] = [] + offset = 0 + page_size = 100 + while True: + url = ( + f"{DATASET_SERVER}?dataset={encoded}&config=default&split=train" + f"&offset={offset}&length={page_size}" + ) + payload = http_json(url, timeout=60) + page = payload.get("rows", []) + rows.extend(item["row"] for item in page) + if not payload.get("partial") and len(rows) >= int(payload.get("num_rows_total", len(rows))): + break + if not page: + break + offset += len(page) + return rows + + +def tokenize_words(text: str) -> list[str]: + return re.findall(r"\S+", text.lower()) + + +def split_sentences(text: str) -> list[str]: + parts = re.split(r"(?<=[.!?])\s+|\n+", text) + return [part.strip().lower() for part in parts if part.strip()] + + +def repeated_ngram(tokens: list[str], size: int, threshold: int) -> str | None: + if len(tokens) < size: + return None + counts = Counter(tuple(tokens[i : i + size]) for i in range(0, len(tokens) - size + 1)) + for ngram, count in counts.items(): + if count >= threshold: + return " ".join(ngram) + return None + + +def analyze_output(content: str) -> list[str]: + issues: list[str] = [] + normalized = content.strip() + if not normalized: + return ["empty output"] + if "<think>" in normalized or "</think>" in normalized: + issues.append("reasoning markup leaked with enable_thinking=false") + if len(normalized) > 6000: + issues.append(f"output too long ({len(normalized)} chars)") + + lines = [line.strip() for line in normalized.splitlines() if line.strip()] + repeated_lines = [line for line, count in Counter(lines).items() if count >= 3] + if repeated_lines: + issues.append(f"repeated line x3: {repeated_lines[0][:120]}") + + sentences = split_sentences(normalized) + repeated_sentences = [s for s, count in Counter(sentences).items() if count >= 3] + if repeated_sentences: + issues.append(f"repeated sentence x3: {repeated_sentences[0][:120]}") + + tokens = tokenize_words(normalized) + ngram = repeated_ngram(tokens, size=6, threshold=3) + if ngram is not None: + issues.append(f"repeated 6-gram x3: {ngram[:120]}") + + if len(tokens) >= 80: + tail = tokens[-80:] + unique_ratio = len(set(tail)) / len(tail) + if unique_ratio < 0.30: + issues.append(f"low tail token diversity ({unique_ratio:.2f})") + + return issues + + +def build_launch_command(args: argparse.Namespace, api_port: int, console_port: int) -> list[str]: + command = [args.mesh_llm] + if args.backend == "mlx": + command.extend(["--mlx-file", args.model]) + else: + command.extend(["--gguf-file", args.model, "--bin-dir", args.bin_dir]) + command.extend(["--no-draft", "--port", str(api_port), "--console", str(console_port)]) + return command + + +def behavior_case_dir() -> Path | None: + raw = os.environ.get("VALIDATION_CASE_DIR", "").strip() + if not raw: + return None + return Path(raw) + + +def sync_runtime_logs(case_dir: Path | None, mesh_log_path: Path) -> None: + if case_dir is None: + return + + case_dir.mkdir(parents=True, exist_ok=True) + + if mesh_log_path.exists(): + shutil.copyfile(mesh_log_path, case_dir / "mesh.log") + + temp_dir = Path(tempfile.gettempdir()) + for source_name, target_name in ( + ("mesh-llm-llama-server.log", "llama-server.log"), + ("mesh-llm-rpc-server.log", "rpc-server.log"), + ): + source_path = temp_dir / source_name + if source_path.exists(): + shutil.copyfile(source_path, case_dir / target_name) + + +def write_case_progress( + case_dir: Path | None, + *, + status: str, + backend: str, + model: str, + prompt_count: int, + completed_prompts: int, + failed_prompts: int, + current_prompt_id: str = "", + current_category: str = "", +) -> None: + if case_dir is None: + return + payload = { + "status": status, + "backend": backend, + "model": model, + "prompt_count": prompt_count, + "completed_prompts": completed_prompts, + "failed_prompt_count": failed_prompts, + "current_prompt_id": current_prompt_id, + "current_category": current_category, + } + (case_dir / "progress.json").write_text(json.dumps(payload, indent=2), encoding="utf-8") + + +def wait_until_ready(process: subprocess.Popen[str], console_port: int, log_path: Path, timeout: int) -> None: + status_url = f"http://127.0.0.1:{console_port}/api/status" + case_dir = behavior_case_dir() + for second in range(1, timeout + 1): + sync_runtime_logs(case_dir, log_path) + if process.poll() is not None: + print("❌ mesh-llm exited unexpectedly", file=sys.stderr) + print(log_path.read_text(encoding="utf-8", errors="replace")[-8000:], file=sys.stderr) + raise SystemExit(1) + try: + status = http_json(status_url, timeout=5) + if bool(status.get("llama_ready", False)): + print(f"βœ… Model loaded in {second}s") + return + except Exception: + pass + if second % 15 == 0: + print(f" Still waiting... ({second}s)", flush=True) + time.sleep(1) + sync_runtime_logs(case_dir, log_path) + print(f"❌ Model failed to load within {timeout}s", file=sys.stderr) + print(log_path.read_text(encoding="utf-8", errors="replace")[-8000:], file=sys.stderr) + raise SystemExit(1) + + +def run_chat(api_port: int, messages: list[dict[str, str]], max_tokens: int) -> dict[str, Any]: + payload = { + "model": "any", + "messages": messages, + "max_tokens": max_tokens, + "temperature": 0, + "enable_thinking": False, + } + return http_json( + f"http://127.0.0.1:{api_port}/v1/chat/completions", + payload=payload, + timeout=DEFAULT_REQUEST_TIMEOUT, + ) + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--backend", choices=["gguf", "mlx"], required=True) + parser.add_argument("--mesh-llm", required=True) + parser.add_argument("--model", required=True) + parser.add_argument("--bin-dir", default="") + parser.add_argument("--dataset", default=DEFAULT_DATASET) + parser.add_argument("--max-prompts", type=int, default=0) + parser.add_argument("--max-tokens", type=int, default=192) + parser.add_argument("--wait-seconds", type=int, default=DEFAULT_WAIT_SECONDS) + parser.add_argument("--mesh-log-output", default="") + parser.add_argument("--output-json", required=True) + parser.add_argument("--label", default="") + args = parser.parse_args() + + if args.backend == "gguf" and not args.bin_dir: + parser.error("--bin-dir is required for gguf backend") + + sys.stdout.reconfigure(line_buffering=True) + sys.stderr.reconfigure(line_buffering=True) + + print("=== MT-Bench Behavior Smoke ===", flush=True) + print(f" backend: {args.backend}", flush=True) + print(f" model: {args.model}", flush=True) + print(f" dataset: {args.dataset}", flush=True) + + prompts = fetch_mt_bench_prompts(args.dataset) + if args.max_prompts > 0: + prompts = prompts[: args.max_prompts] + print(f" prompts: {len(prompts)}", flush=True) + case_dir = behavior_case_dir() + write_case_progress( + case_dir, + status="starting", + backend=args.backend, + model=args.label or args.model, + prompt_count=len(prompts), + completed_prompts=0, + failed_prompts=0, + ) + + api_port = pick_free_port() + console_port = pick_free_port() + while api_port == console_port: + console_port = pick_free_port() + + with tempfile.TemporaryDirectory(prefix="mesh-llm-behavior-") as temp_dir: + log_path = Path(temp_dir) / "mesh-llm.log" + log_file = open(log_path, "w", encoding="utf-8") + process = subprocess.Popen( + build_launch_command(args, api_port, console_port), + stdout=log_file, + stderr=subprocess.STDOUT, + text=True, + start_new_session=True, + env={**os.environ, "RUST_LOG": os.environ.get("RUST_LOG", "info")}, + ) + try: + sync_runtime_logs(case_dir, log_path) + wait_until_ready(process, console_port, log_path, args.wait_seconds) + write_case_progress( + case_dir, + status="running", + backend=args.backend, + model=args.label or args.model, + prompt_count=len(prompts), + completed_prompts=0, + failed_prompts=0, + ) + + results: list[dict[str, Any]] = [] + failed = 0 + for index, row in enumerate(prompts, start=1): + prompt_turns = row.get("prompt", []) + messages: list[dict[str, str]] = [] + turn_results: list[dict[str, Any]] = [] + row_failed = False + for turn_index, prompt_text in enumerate(prompt_turns, start=1): + messages.append({"role": "user", "content": prompt_text}) + try: + response = run_chat(api_port, messages, args.max_tokens) + except Exception as exc: + row_failed = True + failed += 1 + turn_results.append( + { + "turn": turn_index, + "prompt": prompt_text, + "failure": f"request failed: {exc}", + } + ) + break + + choice = response["choices"][0] + content = choice["message"]["content"] + issues = analyze_output(content) + finish_reason = choice.get("finish_reason", "") + if not finish_reason: + issues.append("missing finish_reason") + turn_results.append( + { + "turn": turn_index, + "prompt": prompt_text, + "content": content, + "finish_reason": finish_reason, + "issues": issues, + } + ) + messages.append({"role": "assistant", "content": content}) + if issues: + row_failed = True + failed += 1 + break + + results.append( + { + "index": index, + "prompt_id": row.get("prompt_id"), + "category": row.get("category"), + "turns": turn_results, + "passed": not row_failed, + } + ) + status = "PASS" if not row_failed else "FAIL" + print( + f"[{index:02d}/{len(prompts):02d}] {status} {row.get('category')}#{row.get('prompt_id')}", + flush=True, + ) + sync_runtime_logs(case_dir, log_path) + write_case_progress( + case_dir, + status="running", + backend=args.backend, + model=args.label or args.model, + prompt_count=len(prompts), + completed_prompts=index, + failed_prompts=failed, + current_prompt_id=str(row.get("prompt_id", "")), + current_category=str(row.get("category", "")), + ) + + output = { + "label": args.label or args.model, + "backend": args.backend, + "model": args.model, + "dataset": args.dataset, + "prompt_count": len(prompts), + "failed_prompt_count": failed, + "results": results, + } + output_path = Path(args.output_json) + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(json.dumps(output, indent=2), encoding="utf-8") + if args.mesh_log_output: + mesh_log_output = Path(args.mesh_log_output) + mesh_log_output.parent.mkdir(parents=True, exist_ok=True) + mesh_log_output.write_text( + log_path.read_text(encoding="utf-8", errors="replace"), + encoding="utf-8", + ) + sync_runtime_logs(case_dir, log_path) + write_case_progress( + case_dir, + status="completed", + backend=args.backend, + model=args.label or args.model, + prompt_count=len(prompts), + completed_prompts=len(prompts), + failed_prompts=failed, + ) + + if failed: + print(f"❌ Behavior smoke failed: {failed} prompt(s) flagged", file=sys.stderr) + print(f"Summary written to {output_path}", file=sys.stderr) + return 1 + print("βœ… Behavior smoke passed") + print(f"Summary written to {output_path}") + return 0 + finally: + try: + os.killpg(process.pid, signal.SIGTERM) + except (ProcessLookupError, PermissionError): + pass + time.sleep(2) + sync_runtime_logs(case_dir, log_path) + try: + os.killpg(process.pid, signal.SIGKILL) + except (ProcessLookupError, PermissionError): + pass + try: + process.wait(timeout=10) + except subprocess.TimeoutExpired: + pass + sync_runtime_logs(case_dir, log_path) + log_file.close() + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/ci-smoke-test.ps1 b/scripts/ci-smoke-test.ps1 deleted file mode 100644 index 79d1a501..00000000 --- a/scripts/ci-smoke-test.ps1 +++ /dev/null @@ -1,141 +0,0 @@ -param( - [Parameter(Mandatory = $true)] - [string]$MeshLlm, - [Parameter(Mandatory = $true)] - [string]$BinDir, - [Parameter(Mandatory = $true)] - [string]$ModelPath, - [Parameter(Mandatory = $false)] - [string]$MmprojPath = "" -) - -$ErrorActionPreference = "Stop" - -$apiPort = 9337 -$consolePort = 3131 -$maxWaitSeconds = 180 -$stdoutLogPath = Join-Path ([System.IO.Path]::GetTempPath()) "mesh-llm-ci.stdout.log" -$stderrLogPath = Join-Path ([System.IO.Path]::GetTempPath()) "mesh-llm-ci.stderr.log" - -function Write-ProcessLogs { - foreach ($path in @($stdoutLogPath, $stderrLogPath)) { - if (Test-Path $path) { - Write-Host "--- $path ---" - Get-Content $path -Tail 80 | Write-Host - } - } -} - -Write-Host "=== CI Smoke Test ===" -Write-Host " mesh-llm: $MeshLlm" -Write-Host " bin-dir: $BinDir" -Write-Host " model: $ModelPath" -if (-not [string]::IsNullOrWhiteSpace($MmprojPath)) { - Write-Host " mmproj: $MmprojPath" -} -Write-Host " api port: $apiPort" -Write-Host " os: Windows" - -if (-not (Test-Path $MeshLlm)) { - throw "Missing mesh-llm binary: $MeshLlm" -} - -Get-ChildItem -Path $BinDir -Filter "rpc-server*" -ErrorAction SilentlyContinue | Format-Table -AutoSize | Out-String | Write-Host -Get-ChildItem -Path $BinDir -Filter "llama-server*" -ErrorAction SilentlyContinue | Format-Table -AutoSize | Out-String | Write-Host - -$process = $null -try { - $arguments = @( - "--model", $ModelPath, - "--no-draft", - "--bin-dir", $BinDir, - "--device", "CPU", - "--port", "$apiPort", - "--console", "$consolePort" - ) - - if (-not [string]::IsNullOrWhiteSpace($MmprojPath)) { - $arguments += @("--mmproj", $MmprojPath) - } - - Write-Host "Starting mesh-llm..." - $process = Start-Process ` - -FilePath $MeshLlm ` - -ArgumentList $arguments ` - -RedirectStandardOutput $stdoutLogPath ` - -RedirectStandardError $stderrLogPath ` - -PassThru - Write-Host " PID: $($process.Id)" - - Write-Host "Waiting for model to load (up to ${maxWaitSeconds}s)..." - for ($i = 1; $i -le $maxWaitSeconds; $i++) { - if ($process.HasExited) { - Write-Host "❌ mesh-llm exited unexpectedly" - Write-ProcessLogs - throw "mesh-llm exited before llama_ready" - } - - try { - $status = Invoke-RestMethod -Uri "http://localhost:$consolePort/api/status" -Method Get -TimeoutSec 3 - if ($status.llama_ready -eq $true) { - Write-Host "βœ… Model loaded in ${i}s" - break - } - } catch { - } - - if ($i -eq $maxWaitSeconds) { - Write-Host "❌ Model failed to load within ${maxWaitSeconds}s" - Write-ProcessLogs - throw "Timed out waiting for llama_ready" - } - - if (($i % 15) -eq 0) { - Write-Host " Still waiting... (${i}s)" - } - Start-Sleep -Seconds 1 - } - - Write-Host "Testing /v1/chat/completions..." - $body = @{ - model = "any" - messages = @(@{ - role = "user" - content = "Say hello in exactly 3 words." - }) - max_tokens = 32 - temperature = 0 - } | ConvertTo-Json -Depth 5 - - $response = Invoke-RestMethod ` - -Uri "http://localhost:$apiPort/v1/chat/completions" ` - -Method Post ` - -ContentType "application/json" ` - -Body $body ` - -TimeoutSec 30 - - $content = $response.choices[0].message.content - if ([string]::IsNullOrWhiteSpace($content)) { - throw "Empty response from inference" - } - Write-Host "βœ… Inference response: $content" - - Write-Host "Testing /v1/models..." - $models = Invoke-RestMethod -Uri "http://localhost:$apiPort/v1/models" -Method Get -TimeoutSec 15 - $modelCount = @($models.data).Count - if ($modelCount -eq 0) { - throw "No models returned from /v1/models" - } - Write-Host "βœ… /v1/models returned $modelCount model(s)" - Write-Host "" - Write-Host "=== All smoke tests passed ===" -} finally { - if ($process) { - Write-Host "Shutting down mesh-llm (PID $($process.Id))..." - try { - taskkill /PID $process.Id /T /F | Out-Null - } catch { - } - Start-Sleep -Seconds 2 - } -} diff --git a/scripts/ci-thinking-smoke.py b/scripts/ci-thinking-smoke.py new file mode 100644 index 00000000..9c7d991a --- /dev/null +++ b/scripts/ci-thinking-smoke.py @@ -0,0 +1,387 @@ +#!/usr/bin/env python3 +"""Run a focused thinking-enabled smoke suite against one mesh-llm model/backend.""" + +from __future__ import annotations + +import argparse +import json +import os +import re +import shutil +import signal +import socket +import subprocess +import sys +import tempfile +import time +import urllib.request +from pathlib import Path +from typing import Any + +DEFAULT_WAIT_SECONDS = 300 +DEFAULT_REQUEST_TIMEOUT = 300 + + +def pick_free_port() -> int: + with socket.socket() as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +def http_json(url: str, payload: dict[str, Any] | None = None, timeout: int = 60) -> dict[str, Any]: + if payload is None: + request = urllib.request.Request(url) + else: + body = json.dumps(payload).encode("utf-8") + request = urllib.request.Request( + url, + data=body, + headers={"Content-Type": "application/json"}, + method="POST", + ) + with urllib.request.urlopen(request, timeout=timeout) as response: + return json.load(response) + + +def case_dir() -> Path | None: + raw = os.environ.get("VALIDATION_CASE_DIR", "").strip() + if not raw: + return None + return Path(raw) + + +def sync_runtime_logs(case_directory: Path | None, mesh_log_path: Path) -> None: + if case_directory is None: + return + case_directory.mkdir(parents=True, exist_ok=True) + if mesh_log_path.exists(): + shutil.copyfile(mesh_log_path, case_directory / "mesh.log") + + temp_dir = Path(tempfile.gettempdir()) + for source_name, target_name in ( + ("mesh-llm-llama-server.log", "llama-server.log"), + ("mesh-llm-rpc-server.log", "rpc-server.log"), + ): + source_path = temp_dir / source_name + if source_path.exists(): + shutil.copyfile(source_path, case_directory / target_name) + + +def model_root_for(model_arg: str) -> Path: + path = Path(model_arg) + return path if path.is_dir() else path.parent + + +def ensure_expected_template_source(model_arg: str, expected_template_source: str) -> None: + model_root = model_root_for(model_arg) + expected_path = model_root / expected_template_source + if not expected_path.exists(): + print( + f"❌ Expected template source file not found in model directory: {expected_template_source}", + file=sys.stderr, + ) + print(f"Model directory: {model_root}", file=sys.stderr) + raise SystemExit(1) + + +def build_launch_command(args: argparse.Namespace, api_port: int, console_port: int) -> list[str]: + command = [args.mesh_llm] + if args.backend == "mlx": + command.extend(["--mlx-file", args.model]) + else: + command.extend(["--gguf-file", args.model, "--bin-dir", args.bin_dir]) + command.extend(["--no-draft", "--port", str(api_port), "--console", str(console_port)]) + return command + + +def wait_until_ready(process: subprocess.Popen[str], console_port: int, log_path: Path, timeout: int) -> None: + status_url = f"http://127.0.0.1:{console_port}/api/status" + for second in range(1, timeout + 1): + sync_runtime_logs(case_dir(), log_path) + if process.poll() is not None: + print("❌ mesh-llm exited unexpectedly", file=sys.stderr) + print(log_path.read_text(encoding="utf-8", errors="replace")[-8000:], file=sys.stderr) + raise SystemExit(1) + try: + status = http_json(status_url, timeout=5) + if bool(status.get("llama_ready", False)): + print(f"βœ… Model loaded in {second}s", flush=True) + return + except Exception: + pass + if second % 15 == 0: + print(f" Still waiting... ({second}s)", flush=True) + time.sleep(1) + sync_runtime_logs(case_dir(), log_path) + print(f"❌ Model failed to load within {timeout}s", file=sys.stderr) + print(log_path.read_text(encoding="utf-8", errors="replace")[-8000:], file=sys.stderr) + raise SystemExit(1) + + +def write_progress( + *, + status: str, + backend: str, + model: str, + check_count: int, + completed_checks: int, + failed_checks: int, + current_label: str = "", +) -> None: + out_dir = case_dir() + if out_dir is None: + return + payload = { + "status": status, + "backend": backend, + "model": model, + "check_count": check_count, + "completed_checks": completed_checks, + "failed_check_count": failed_checks, + "current_label": current_label, + } + (out_dir / "progress.json").write_text(json.dumps(payload, indent=2), encoding="utf-8") + + +def run_chat(api_port: int, messages: list[dict[str, str]], max_tokens: int) -> tuple[dict[str, Any], dict[str, Any], str, str]: + payload = { + "model": "any", + "messages": messages, + "max_tokens": max_tokens, + "temperature": 0, + "top_p": 1, + "top_k": 1, + "seed": 123, + "enable_thinking": True, + } + response = http_json( + f"http://127.0.0.1:{api_port}/v1/chat/completions", + payload=payload, + timeout=DEFAULT_REQUEST_TIMEOUT, + ) + choice = response["choices"][0] + content = choice["message"]["content"] + finish_reason = choice.get("finish_reason", "") + return payload, response, content, finish_reason + + +def strip_tagged_reasoning(content: str) -> str: + stripped = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL) + stripped = re.sub(r"<\|channel>thought.*?<channel\|>", "", stripped, flags=re.DOTALL) + return stripped.strip() + + +def validate_case( + *, + api_port: int, + case_cfg: dict[str, Any], + thinking_mode: str, + max_tokens_override: int | None, +) -> dict[str, Any]: + label = str(case_cfg.get("label", "thinking")) + messages = list(case_cfg.get("messages", [])) + max_tokens = int(case_cfg.get("max_tokens", 128) or 128) + if max_tokens_override is not None: + max_tokens = max_tokens_override + if not messages: + return { + "label": label, + "passed": False, + "issues": ["missing messages"], + "request": {}, + "response": {}, + "content": "", + "finish_reason": "", + } + + try: + request_payload, response_payload, content, finish_reason = run_chat(api_port, messages, max_tokens) + except Exception as exc: + return { + "label": label, + "passed": False, + "issues": [f"request failed: {exc}"], + "request": {}, + "response": {}, + "content": "", + "finish_reason": "", + } + + issues: list[str] = [] + normalized = content.strip() + if not normalized: + issues.append("empty output") + if not finish_reason: + issues.append("missing finish_reason") + if thinking_mode == "tagged": + has_marker = "<think>" in content or "<|channel>thought" in content + if not has_marker: + issues.append("missing tagged reasoning marker") + visible_answer = strip_tagged_reasoning(content) + if not visible_answer: + issues.append("missing answer outside reasoning tags") + elif thinking_mode == "multiline": + if "\n" not in normalized: + issues.append("expected multiline reasoning output") + + return { + "label": label, + "passed": not issues, + "issues": issues, + "request": request_payload, + "response": response_payload, + "content": content, + "finish_reason": finish_reason, + } + + +def write_models_artifact(api_port: int) -> None: + models = http_json(f"http://127.0.0.1:{api_port}/v1/models", timeout=DEFAULT_REQUEST_TIMEOUT) + model_count = len(models.get("data", [])) + if model_count == 0: + print("❌ No models in /v1/models", file=sys.stderr) + raise SystemExit(1) + artifact_root = case_dir() + if artifact_root is not None: + models_dir = artifact_root / "models" + models_dir.mkdir(parents=True, exist_ok=True) + (models_dir / "v1-models.json").write_text( + json.dumps(models, ensure_ascii=False, indent=2) + "\n", + encoding="utf-8", + ) + print(f"βœ… /v1/models returned {model_count} model(s)", flush=True) + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--backend", choices=["gguf", "mlx"], required=True) + parser.add_argument("--mesh-llm", required=True) + parser.add_argument("--model", required=True) + parser.add_argument("--bin-dir", default="") + parser.add_argument("--expected-template-source", default="") + parser.add_argument("--prompt-suite-json", required=True) + parser.add_argument("--thinking-mode", default="nonempty") + parser.add_argument("--max-tokens", type=int, default=0) + parser.add_argument("--wait-seconds", type=int, default=DEFAULT_WAIT_SECONDS) + parser.add_argument("--label", default="") + parser.add_argument("--output-json", required=True) + parser.add_argument("--mesh-log-output", default="") + args = parser.parse_args() + + if args.backend == "gguf" and not args.bin_dir: + parser.error("--bin-dir is required for gguf backend") + if args.backend == "mlx" and args.expected_template_source: + ensure_expected_template_source(args.model, args.expected_template_source) + max_tokens_override = args.max_tokens if args.max_tokens > 0 else None + + sys.stdout.reconfigure(line_buffering=True) + sys.stderr.reconfigure(line_buffering=True) + + prompt_suite = json.loads(args.prompt_suite_json) + api_port = pick_free_port() + console_port = pick_free_port() + while api_port == console_port: + console_port = pick_free_port() + + print("=== Thinking Smoke ===", flush=True) + print(f" backend: {args.backend}", flush=True) + print(f" model: {args.model}", flush=True) + print(f" checks: {len(prompt_suite)}", flush=True) + print(f" mode: {args.thinking_mode}", flush=True) + write_progress( + status="starting", + backend=args.backend, + model=args.label or args.model, + check_count=len(prompt_suite), + completed_checks=0, + failed_checks=0, + ) + + with tempfile.TemporaryDirectory(prefix="mesh-llm-thinking-") as temp_dir: + os.environ["TMPDIR"] = temp_dir + log_path = Path(temp_dir) / "mesh-llm.log" + if args.mesh_log_output: + mesh_log_output = Path(args.mesh_log_output) + else: + mesh_log_output = log_path + + launch_cmd = build_launch_command(args, api_port, console_port) + with open(log_path, "w", encoding="utf-8") as log_file: + process = subprocess.Popen( + launch_cmd, + stdout=log_file, + stderr=subprocess.STDOUT, + text=True, + ) + try: + wait_until_ready(process, console_port, log_path, args.wait_seconds) + write_models_artifact(api_port) + + results: list[dict[str, Any]] = [] + failed_checks = 0 + for index, case_cfg in enumerate(prompt_suite, start=1): + label = str(case_cfg.get("label", index)) + write_progress( + status="running", + backend=args.backend, + model=args.label or args.model, + check_count=len(prompt_suite), + completed_checks=index - 1, + failed_checks=failed_checks, + current_label=label, + ) + result = validate_case( + api_port=api_port, + case_cfg=case_cfg, + thinking_mode=args.thinking_mode, + max_tokens_override=max_tokens_override, + ) + results.append(result) + if result["passed"]: + print(f"[{index:02d}/{len(prompt_suite)}] PASS {label}", flush=True) + else: + failed_checks += 1 + print(f"[{index:02d}/{len(prompt_suite)}] FAIL {label}", flush=True) + + payload = { + "backend": args.backend, + "label": args.label, + "model": args.model, + "thinking_mode": args.thinking_mode, + "failed_check_count": failed_checks, + "check_count": len(prompt_suite), + "results": results, + } + Path(args.output_json).write_text( + json.dumps(payload, ensure_ascii=False, indent=2) + "\n", + encoding="utf-8", + ) + if mesh_log_output != log_path: + sync_runtime_logs(case_dir(), log_path) + if log_path.exists(): + shutil.copyfile(log_path, mesh_log_output) + write_progress( + status="completed", + backend=args.backend, + model=args.label or args.model, + check_count=len(prompt_suite), + completed_checks=len(prompt_suite), + failed_checks=failed_checks, + ) + if failed_checks: + print(f"❌ Thinking smoke failed: {failed_checks} check(s) flagged", flush=True) + return 1 + print("βœ… Thinking smoke passed", flush=True) + return 0 + finally: + sync_runtime_logs(case_dir(), log_path) + try: + process.send_signal(signal.SIGINT) + process.wait(timeout=20) + except Exception: + process.kill() + process.wait(timeout=10) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/download-origin-checkpoints-studio54.sh b/scripts/download-origin-checkpoints-studio54.sh new file mode 100755 index 00000000..9f3de252 --- /dev/null +++ b/scripts/download-origin-checkpoints-studio54.sh @@ -0,0 +1,239 @@ +#!/bin/zsh + +set -u + +if [[ -z "${HF_TOKEN:-}" ]]; then + echo "HF_TOKEN is not set. Export a Hugging Face token before running this script." >&2 + exit 1 +fi + +export HF_TOKEN +export HUGGINGFACE_HUB_TOKEN="$HF_TOKEN" + +TARGET_ROOT="${TARGET_ROOT:-$HOME/.cache/mesh-llm-origin-batch}" +LOG_ROOT="${LOG_ROOT:-$TARGET_ROOT/_logs}" + +mkdir -p "$TARGET_ROOT" "$LOG_ROOT" + +typeset -a HF_CANDIDATES +HF_CANDIDATES=( + "${HF_BIN:-}" + "hf" + "$HOME/Library/Python/3.9/bin/hf" + "$HOME/Library/Python/3.10/bin/hf" + "$HOME/Library/Python/3.11/bin/hf" + "$HOME/Library/Python/3.12/bin/hf" + "/opt/homebrew/bin/hf" + "/usr/local/bin/hf" +) + +HF_BIN_RESOLVED="" +for candidate in "${HF_CANDIDATES[@]}"; do + if [[ -z "$candidate" ]]; then + continue + fi + + if [[ "$candidate" == */* ]]; then + if [[ -x "$candidate" ]]; then + HF_BIN_RESOLVED="$candidate" + break + fi + elif command -v "$candidate" >/dev/null 2>&1; then + HF_BIN_RESOLVED="$(command -v "$candidate")" + break + fi +done + +if [[ -z "$HF_BIN_RESOLVED" ]]; then + echo "Could not find 'hf'. Set HF_BIN=/absolute/path/to/hf if needed." >&2 + exit 1 +fi + +HF_BIN="$HF_BIN_RESOLVED" + +typeset -a SPECS +SPECS=( + "deepseek|deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + "olmo|allenai/OLMo-1B-hf" + "mamba2|state-spaces/mamba2-2.7b" + "phi3|microsoft/Phi-3-mini-4k-instruct" + "phi4-mini|microsoft/Phi-4-mini-instruct" + "minicpm|openbmb/MiniCPM3-4B" + "mamba|state-spaces/mamba-2.8b-hf" + "starcoder2|bigcode/starcoder2-3b" + "olmo2|allenai/OLMo-2-1124-7B-Instruct" + "cohere2|CohereLabs/c4ai-command-r7b-12-2024" + "mistral|mistralai/Mistral-7B-Instruct-v0.3" +) + +typeset -a HEAVY_SPECS +HEAVY_SPECS=( + "cohere-command-r|CohereLabs/c4ai-command-r-v01" + "jamba|ai21labs/AI21-Jamba-1.5-Mini" + "mixtral|mistralai/Mixtral-8x7B-Instruct-v0.1" +) + +function usage() { + cat <<'EOF' +Usage: + zsh scripts/download-origin-checkpoints-studio54.sh + zsh scripts/download-origin-checkpoints-studio54.sh mistral phi3 deepseek + zsh scripts/download-origin-checkpoints-studio54.sh --include-heavy + TARGET_ROOT=~/mesh-origin zsh scripts/download-origin-checkpoints-studio54.sh + +Notes: + - Downloads run serially with hf download into per-model directories. + - The default set excludes families that are too large or risky for the + full same-origin test workflow on studio54's 128 GB M1 Ultra. + - Use --include-heavy to add: mixtral, cohere-command-r, jamba. + - If one repo fails or is gated, the script continues and reports it at the end. + - Some repos in this list are very large and can consume substantial disk. +EOF +} + +typeset -a FILTERS +FILTERS=() +INCLUDE_HEAVY=0 + +while (( $# > 0 )); do + case "$1" in + -h|--help) + usage + exit 0 + ;; + --include-heavy) + INCLUDE_HEAVY=1 + ;; + *) + FILTERS+=("$1") + ;; + esac + shift +done + +function should_download() { + local slug="$1" + + if (( ${#FILTERS[@]} == 0 )); then + return 0 + fi + + local filter + for filter in "${FILTERS[@]}"; do + if [[ "$slug" == "$filter" ]]; then + return 0 + fi + done + + return 1 +} + +function print_progress_snapshot() { + local target_dir="$1" + + if [[ ! -d "$target_dir" ]]; then + return + fi + + echo " progress: $(du -sh "$target_dir" 2>/dev/null | awk '{print $1}') downloaded so far" + + local files + files=$(find "$target_dir" -maxdepth 1 -type f \( -name '*.safetensors' -o -name '*.gguf' \) -print 2>/dev/null | sort) + if [[ -n "$files" ]]; then + echo "$files" | xargs ls -lh 2>/dev/null | tail -n 3 | sed 's/^/ /' + fi +} + +function run_download() { + local repo="$1" + local target_dir="$2" + local log_file="$3" + + "$HF_BIN" download "$repo" \ + --token "$HF_TOKEN" \ + --local-dir "$target_dir" \ + > >(tee "$log_file") \ + 2> >(tee -a "$log_file" >&2) & + + local download_pid=$! + + while kill -0 "$download_pid" >/dev/null 2>&1; do + sleep 20 + if kill -0 "$download_pid" >/dev/null 2>&1; then + print_progress_snapshot "$target_dir" + fi + done + + wait "$download_pid" +} + +typeset -a SUCCEEDED +typeset -a FAILED +typeset -a SKIPPED +typeset -i index=1 + +echo "Using hf binary: $HF_BIN" +echo "Target root: $TARGET_ROOT" +echo "Log root: $LOG_ROOT" +if (( INCLUDE_HEAVY == 1 )); then + echo "Heavy families: included" +else + echo "Heavy families: excluded by default" +fi +echo + +typeset -a ALL_SPECS +ALL_SPECS=("${SPECS[@]}") +if (( INCLUDE_HEAVY == 1 )); then + ALL_SPECS+=("${HEAVY_SPECS[@]}") +fi + +for spec in "${ALL_SPECS[@]}"; do + slug="${spec%%|*}" + repo="${spec#*|}" + + if ! should_download "$slug"; then + SKIPPED+=("$slug") + continue + fi + + target_dir="$TARGET_ROOT/$slug" + log_file="$LOG_ROOT/$slug.log" + + mkdir -p "$target_dir" + + echo "[$index/${#ALL_SPECS[@]}] Downloading $slug from $repo" + echo " target: $target_dir" + echo " log: $log_file" + + if run_download "$repo" "$target_dir" "$log_file"; then + SUCCEEDED+=("$slug") + echo " status: ok" + else + FAILED+=("$slug") + echo " status: failed" + fi + + echo + index+=1 +done + +echo "Summary" +echo " succeeded: ${#SUCCEEDED[@]}" +if (( ${#SUCCEEDED[@]} > 0 )); then + printf ' %s\n' "${SUCCEEDED[@]}" +fi + +echo " failed: ${#FAILED[@]}" +if (( ${#FAILED[@]} > 0 )); then + printf ' %s\n' "${FAILED[@]}" +fi + +echo " skipped: ${#SKIPPED[@]}" +if (( ${#SKIPPED[@]} > 0 )); then + printf ' %s\n' "${SKIPPED[@]}" +fi + +if (( ${#FAILED[@]} > 0 )); then + exit 1 +fi diff --git a/scripts/mlx-parity-exact.tsv b/scripts/mlx-parity-exact.tsv new file mode 100644 index 00000000..872ec0f7 --- /dev/null +++ b/scripts/mlx-parity-exact.tsv @@ -0,0 +1,21 @@ +backend case_id model_ref template_source +gguf olmo7b-gguf-exact meshllm/olmo-7b-instruct-hf-parity-f16-gguf/olmo-7b-instruct-hf-f16.gguf - +mlx olmo7b-mlx-exact meshllm/olmo-7b-instruct-hf-parity-bf16-mlx chat_template.jinja +gguf mistral-gguf-exact meshllm/mistral-7b-instruct-v0.3-parity-f16-gguf/mistral-7b-instruct-v0.3-f16.gguf - +mlx mistral-mlx-exact meshllm/mistral-7b-instruct-v0.3-parity-bf16-mlx chat_template.jinja +gguf qwen25-gguf-exact meshllm/qwen2.5-0.5b-instruct-parity-q8_0-gguf/qwen2.5-0.5b-instruct-q8_0.gguf - +mlx qwen25-mlx-exact meshllm/qwen2.5-0.5b-instruct-parity-8bit-mlx chat_template.jinja +gguf qwen3-gguf-exact meshllm/qwen3-8b-parity-q8_0-gguf/qwen3-8b-q8_0.gguf - +mlx qwen3-mlx-exact meshllm/qwen3-8b-parity-8bit-mlx/model-00001-of-00002.safetensors chat_template.jinja +gguf llama32-gguf-exact meshllm/llama-3.2-1b-instruct-parity-f16-gguf/llama-3.2-1b-instruct-f16.gguf - +mlx llama32-mlx-exact meshllm/llama-3.2-1b-instruct-parity-bf16-mlx chat_template.jinja +gguf gemma2-gguf-exact meshllm/gemma-2-2b-it-parity-q8_0-gguf/gemma-2-2b-it-q8_0.gguf - +mlx gemma2-mlx-exact meshllm/gemma-2-2b-it-parity-8bit-mlx chat_template.jinja +gguf gemma3-gguf-exact meshllm/gemma-3-1b-it-parity-f16-gguf/gemma-3-1b-it-f16.gguf - +mlx gemma3-mlx-exact meshllm/gemma-3-1b-it-parity-bf16-mlx tokenizer_config.json +gguf gemma4-gguf-exact meshllm/gemma-4-e4b-it-parity-q8_0-gguf/gemma-4-e4b-it-q8_0.gguf - +mlx gemma4-mlx-exact meshllm/gemma-4-e4b-it-parity-8bit-mlx chat_template.jinja +gguf glm4-gguf-exact meshllm/glm-4-9b-0414-parity-q4_k_m-gguf/glm-4-9b-0414-q4_k_m.gguf - +mlx glm4-mlx-exact meshllm/glm-4-9b-0414-parity-4bit-mlx chat_template.jinja +gguf lfm2-gguf-exact meshllm/lfm2-350m-parity-q4_k_m-gguf/lfm2-350m-q4_k_m.gguf - +mlx lfm2-mlx-exact meshllm/lfm2-350m-parity-4bit-mlx chat_template.jinja diff --git a/scripts/package-release.sh b/scripts/package-release.sh index 7f87f3b3..acf41cca 100755 --- a/scripts/package-release.sh +++ b/scripts/package-release.sh @@ -86,6 +86,23 @@ copy_runtime_libs() { shopt -u nullglob } +copy_mlx_metallib() { + local bundle_dir="$1" + if [[ "$(uname -s)" != "Darwin" ]]; then + return + fi + + shopt -s nullglob + local metallibs=("$REPO_ROOT"/target/release/build/mlx-sys-*/out/build/lib/mlx.metallib) + shopt -u nullglob + + if [[ ${#metallibs[@]} -gt 0 ]]; then + cp "${metallibs[0]}" "$bundle_dir/mlx.metallib" + else + echo "Note: mlx.metallib not found in target/release/build β€” MLX remote runs may fail" >&2 + fi +} + bundle_bin_name() { local name="$1" if [[ "$name" == "mesh-llm" ]]; then @@ -353,6 +370,7 @@ main() { cp "$BUILD_BIN_DIR/llama-moe-analyze${BIN_EXT}" "$bundle_dir/llama-moe-analyze" cp "$BUILD_BIN_DIR/llama-moe-split${BIN_EXT}" "$bundle_dir/llama-moe-split" copy_runtime_libs "$bundle_dir" + copy_mlx_metallib "$bundle_dir" if [[ "$os_name" == "Darwin" ]]; then for bin in "$bundle_dir/$(bundle_bin_name mesh-llm)" "$bundle_dir/$(bundle_bin_name rpc-server)" "$bundle_dir/$(bundle_bin_name llama-server)" "$bundle_dir/llama-moe-analyze" "$bundle_dir/llama-moe-split"; do diff --git a/scripts/run-validation-case.sh b/scripts/run-validation-case.sh new file mode 100755 index 00000000..eb309a64 --- /dev/null +++ b/scripts/run-validation-case.sh @@ -0,0 +1,72 @@ +#!/usr/bin/env bash + +set -euo pipefail + +if [ "$#" -lt 3 ]; then + echo "usage: $0 <backend> <case-id> <command...>" >&2 + exit 2 +fi + +BACKEND="$1" +CASE_ID="$2" +shift 2 + +ROOT="${VALIDATION_RESULTS_ROOT:-$PWD/.cache/mlx-validation}" +STAMP="${VALIDATION_RESULTS_STAMP:-$(date +%Y%m%d-%H%M%S)}" +CASE_DIR="$ROOT/$STAMP/$CASE_ID" + +mkdir -p "$CASE_DIR" +export VALIDATION_CASE_DIR="$CASE_DIR" + +printf '%s\n' "$BACKEND" > "$CASE_DIR/backend.txt" +printf '%s\n' "$CASE_ID" > "$CASE_DIR/case.txt" +printf '%s\n' "$PWD" > "$CASE_DIR/cwd.txt" +printf '%q ' "$@" > "$CASE_DIR/command.sh" +printf '\n' >> "$CASE_DIR/command.sh" + +STARTED_AT="$(date -u +%Y-%m-%dT%H:%M:%SZ)" +cat > "$CASE_DIR/state.json" <<EOF +{ + "backend": "$BACKEND", + "case_id": "$CASE_ID", + "status": "running", + "started_at": "$STARTED_AT" +} +EOF + +set +e +"$@" > >(tee "$CASE_DIR/stdout.log") 2> >(tee "$CASE_DIR/stderr.log" >&2) +STATUS=$? +set -e + +printf '%s\n' "$STATUS" > "$CASE_DIR/exit_code.txt" + +if [ "$BACKEND" = "gguf" ] && [ -f /tmp/mesh-llm-ci-gguf.log ]; then + cp -f /tmp/mesh-llm-ci-gguf.log "$CASE_DIR/mesh.log" +fi + +if [ "$BACKEND" = "mlx" ] && [ -f /tmp/mesh-llm-ci-mlx.log ]; then + cp -f /tmp/mesh-llm-ci-mlx.log "$CASE_DIR/mesh.log" +fi + +cat > "$CASE_DIR/meta.json" <<EOF +{ + "backend": "$BACKEND", + "case_id": "$CASE_ID", + "exit_code": $STATUS +} +EOF + +FINISHED_AT="$(date -u +%Y-%m-%dT%H:%M:%SZ)" +cat > "$CASE_DIR/state.json" <<EOF +{ + "backend": "$BACKEND", + "case_id": "$CASE_ID", + "status": "completed", + "started_at": "$STARTED_AT", + "finished_at": "$FINISHED_AT", + "exit_code": $STATUS +} +EOF + +exit "$STATUS" diff --git a/scripts/run-validation-matrix.py b/scripts/run-validation-matrix.py new file mode 100755 index 00000000..0df3ad70 --- /dev/null +++ b/scripts/run-validation-matrix.py @@ -0,0 +1,1692 @@ +#!/usr/bin/env python3 +"""Run the shared GGUF/MLX validation matrix locally. + +This orchestrates the deterministic exact suite and the MT-Bench-derived +behavior suite from one checked-in matrix definition. Each backend can be run +independently, or both can be run together, while preserving raw artifacts +under one stamped results tree. +""" + +from __future__ import annotations + +import argparse +import json +import os +import re +import shlex +import shutil +import subprocess +import sys +from pathlib import Path +from typing import Any + + +REPO_ROOT = Path(__file__).resolve().parents[1] +DEFAULT_MATRIX = REPO_ROOT / "testdata" / "validation" / "matrix.json" +DEFAULT_BASELINES = REPO_ROOT / "testdata" / "validation" / "baselines.json" +DEFAULT_ROOT = REPO_ROOT / ".cache" / "mlx-validation" +DEFAULT_WAIT_SECONDS = 300 +COMMON_BIN_DIRS = ["/opt/homebrew/bin", "/usr/local/bin"] + + +def log(message: str) -> None: + sys.stderr.write(message + "\n") + sys.stderr.flush() + + +def merged_env(env: dict[str, str] | None = None) -> dict[str, str]: + base = dict(os.environ) + if env: + base.update(env) + path_entries = [entry for entry in base.get("PATH", "").split(os.pathsep) if entry] + for entry in reversed(COMMON_BIN_DIRS): + if entry not in path_entries: + path_entries.insert(0, entry) + base["PATH"] = os.pathsep.join(path_entries) + return base + + +def resolve_command(cmd: list[str], env: dict[str, str]) -> list[str]: + if not cmd: + return cmd + executable = cmd[0] + if "/" in executable: + return cmd + resolved = shutil.which(executable, path=env.get("PATH", "")) + if resolved: + return [resolved, *cmd[1:]] + return cmd + + +def run( + cmd: list[str], + *, + cwd: Path = REPO_ROOT, + env: dict[str, str] | None = None, + capture_output: bool = False, +) -> subprocess.CompletedProcess[str]: + final_env = merged_env(env) + return subprocess.run( + resolve_command(cmd, final_env), + cwd=cwd, + env=final_env, + text=True, + capture_output=capture_output, + check=False, + ) + + +def run_streaming( + cmd: list[str], + *, + cwd: Path = REPO_ROOT, + env: dict[str, str] | None = None, +) -> subprocess.CompletedProcess[str]: + final_env = merged_env(env) + proc = subprocess.Popen( + resolve_command(cmd, final_env), + cwd=cwd, + env=final_env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + bufsize=0, + ) + chunks: list[bytes] = [] + assert proc.stdout is not None + stderr_buffer = getattr(sys.stderr, "buffer", None) + while True: + chunk = proc.stdout.read(4096) + if not chunk: + break + chunks.append(chunk) + if stderr_buffer is not None: + stderr_buffer.write(chunk) + stderr_buffer.flush() + else: + sys.stderr.write(chunk.decode("utf-8", errors="replace")) + sys.stderr.flush() + returncode = proc.wait() + return subprocess.CompletedProcess( + cmd, + returncode, + b"".join(chunks).decode("utf-8", errors="replace"), + "", + ) + + +def ensure_build(skip_build: bool) -> None: + if skip_build: + return + rc = run(["just", "build"]) + if rc.returncode != 0: + raise SystemExit(rc.returncode) + + +def load_matrix(path: Path) -> dict[str, Any]: + return json.loads(path.read_text(encoding="utf-8")) + + +def load_baselines(path: Path | None) -> dict[str, Any]: + if path is None or not path.exists(): + return {} + return json.loads(path.read_text(encoding="utf-8")) + + +def selected_models( + matrix: dict[str, Any], + selectors: set[str], + backend_filter: str, +) -> list[dict[str, Any]]: + models: list[dict[str, Any]] = [] + for model in matrix["models"]: + if selectors: + candidate_keys = { + model["id"], + model["label"], + } + for backend_name in ("gguf", "mlx"): + if backend_name in model: + candidate_keys.add(model[backend_name].get("exact_case_id", "")) + candidate_keys.add(model[backend_name].get("behavior_case_id", "")) + candidate_keys.add(model[backend_name].get("thinking_case_id", "")) + if not candidate_keys.intersection(selectors): + continue + if backend_filter in ("gguf", "mlx") and backend_filter not in model: + continue + models.append(model) + return models + + +def requested_backends(model: dict[str, Any], backend_filter: str) -> list[str]: + if backend_filter == "both": + return [backend for backend in ("gguf", "mlx") if backend in model] + if backend_filter in model: + return [backend_filter] + return [] + + +def parse_downloaded_gguf_path(output: str) -> str: + for line in output.splitlines(): + trimmed = line.strip() + if trimmed.startswith("/") and trimmed.endswith(".gguf"): + return trimmed + raise RuntimeError("could not determine downloaded gguf path") + + +def parse_downloaded_model_path(output: str) -> str: + for line in output.splitlines(): + trimmed = line.strip() + if not trimmed.startswith("/"): + continue + if trimmed.endswith(".gguf") or trimmed.endswith(".json") or trimmed.endswith(".safetensors"): + return trimmed + raise RuntimeError("could not determine downloaded model path") + + +def download_model_ref(model_ref: str, backend: str) -> str: + log(f"πŸ“₯ Preflight download start [{backend}] {model_ref}") + proc = run_streaming( + ["./target/release/mesh-llm", "models", "download", model_ref], + ) + if proc.returncode != 0: + log(f"❌ Preflight download failed [{backend}] {model_ref} (exit {proc.returncode})") + raise SystemExit(proc.returncode) + local_path = parse_downloaded_model_path(proc.stdout) + log(f"βœ… Preflight download complete [{backend}] {model_ref}") + log(f" ↳ {local_path}") + return local_path + + +def summary_path(root: Path, stamp: str, suite: str) -> Path: + return root / stamp / f"{suite}-summary.tsv" + + +def suite_stamp(stamp: str, suite: str) -> str: + return f"{stamp}/{suite}" + + +def append_tsv(path: Path, header: list[str], row: list[str]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + if not path.exists(): + path.write_text("\t".join(header) + "\n", encoding="utf-8") + with path.open("a", encoding="utf-8") as handle: + handle.write("\t".join(row) + "\n") + + +def case_dir(root: Path, stamp: str, suite: str, case_id: str) -> Path: + return root / stamp / suite / case_id + + +def exact_config_for(matrix: dict[str, Any], model: dict[str, Any]) -> dict[str, Any]: + exact_cfg = dict(matrix["defaults"]["exact"]) + if "prompt_suite" in exact_cfg: + exact_cfg["prompt_suite"] = [dict(item) for item in exact_cfg["prompt_suite"]] + model_exact = model.get("exact") + if isinstance(model_exact, dict): + exact_cfg.update(model_exact) + if "prompt_suite" in model_exact: + exact_cfg["prompt_suite"] = [dict(item) for item in model_exact["prompt_suite"]] + return exact_cfg + + +def thinking_config_for(matrix: dict[str, Any], model: dict[str, Any], backend: str) -> dict[str, Any]: + thinking_cfg = dict(matrix["defaults"].get("thinking", {})) + if "prompt_suite" in thinking_cfg: + thinking_cfg["prompt_suite"] = [dict(item) for item in thinking_cfg["prompt_suite"]] + model_thinking = model.get("thinking") + if isinstance(model_thinking, dict): + thinking_cfg.update(model_thinking) + if "prompt_suite" in model_thinking: + thinking_cfg["prompt_suite"] = [dict(item) for item in model_thinking["prompt_suite"]] + backend_cfg = model.get(backend, {}) + backend_thinking = backend_cfg.get("thinking") + if isinstance(backend_thinking, dict): + thinking_cfg.update(backend_thinking) + if "prompt_suite" in backend_thinking: + thinking_cfg["prompt_suite"] = [dict(item) for item in backend_thinking["prompt_suite"]] + elif "thinking_mode" in backend_cfg: + thinking_cfg["thinking_mode"] = backend_cfg["thinking_mode"] + return thinking_cfg + + +def run_exact_case( + root: Path, + stamp: str, + matrix: dict[str, Any], + model: dict[str, Any], + backend: str, + resolved_models: dict[tuple[str, str], str], +) -> int: + exact_defaults = exact_config_for(matrix, model) + backend_cfg = model[backend] + case_id = backend_cfg["exact_case_id"] + env = { + **os.environ, + "VALIDATION_RESULTS_ROOT": str(root), + "VALIDATION_RESULTS_STAMP": suite_stamp(stamp, "exact"), + } + + run(["just", "stop"], env=env) + + prompt_suite_json = json.dumps(exact_defaults["prompt_suite"], separators=(",", ":")) + cmd = [ + str(REPO_ROOT / "scripts" / "run-validation-case.sh"), + backend, + case_id, + "python3", + str(REPO_ROOT / "scripts" / "ci-exact-smoke.py"), + "--backend", + backend, + "--mesh-llm", + "target/release/mesh-llm", + "--prompt", + exact_defaults["prompt"], + "--expect-contains", + exact_defaults["expect_contains"], + "--expect-contains-ci", + exact_defaults.get("expect_contains_ci", ""), + "--forbid-contains", + exact_defaults["forbid_contains"], + "--expect-exact", + exact_defaults["expect_exact"], + "--prompt-suite-json", + prompt_suite_json, + ] + if backend == "gguf": + gguf_path = resolved_models[(backend, backend_cfg["model_ref"])] + cmd.extend( + [ + "--bin-dir", + "llama.cpp/build/bin", + "--model", + gguf_path, + ] + ) + else: + mlx_path = resolved_models[(backend, backend_cfg["model_ref"])] + cmd.extend( + [ + "--model", + mlx_path, + "--expected-template-source", + backend_cfg["template_source"], + ] + ) + + rc = run(cmd, env=env).returncode + append_tsv( + summary_path(root, stamp, "exact"), + ["model_id", "label", "expectation_class", "backend", "case_id", "exit"], + [model["id"], model["label"], model["expectation_class"], backend, case_id, str(rc)], + ) + return rc + + +def behavior_report_path(root: Path, stamp: str, case_id: str) -> Path: + return case_dir(root, stamp, "behavior", case_id) / "report.json" + + +def exact_chat_dir(root: Path, stamp: str, case_id: str) -> Path: + return case_dir(root, stamp, "exact", case_id) / "chat" + + +def normalize_exact_output(text: str) -> str: + normalized = text.replace("\r\n", "\n").replace("\r", "\n").strip().lower() + normalized = normalized.replace("**", "").replace("__", "").replace("`", "") + normalized = re.sub(r"\s+", " ", normalized) + return normalized.strip(" \t\n\r.,;:!?") + + +def load_exact_prompt_artifacts(root: Path, stamp: str, case_id: str) -> dict[str, dict[str, Any]]: + chat_dir = exact_chat_dir(root, stamp, case_id) + artifacts: dict[str, dict[str, Any]] = {} + if not chat_dir.exists(): + return artifacts + for path in sorted(chat_dir.glob("*.json")): + if path.stem.endswith(".thinking"): + continue + payload = json.loads(path.read_text(encoding="utf-8")) + label = str(payload.get("label", path.stem)) + artifacts[label] = payload + return artifacts + + +def exact_artifact_content(payload: dict[str, Any]) -> str: + content = payload.get("content") + if content is not None: + return str(content) + response_text = payload.get("response_text") + if response_text is not None: + return str(response_text) + return "" + + +def exact_prompt_snapshot(root: Path, stamp: str, case_id: str) -> dict[str, str]: + return { + label: normalize_exact_output(exact_artifact_content(payload)) + for label, payload in load_exact_prompt_artifacts(root, stamp, case_id).items() + } + + +def satisfied_expectation_buckets(payload: dict[str, Any]) -> set[str]: + expectations = payload.get("expectations", {}) + content = exact_artifact_content(payload) + normalized_content = normalize_exact_output(content) + buckets: set[str] = set() + + expect_exact = str(expectations.get("expect_exact", "")) + if expect_exact and normalized_content == normalize_exact_output(expect_exact): + buckets.add("expect_exact") + + expect_contains = str(expectations.get("expect_contains", "")) + if expect_contains and expect_contains in content: + buckets.add("expect_contains") + + expect_contains_ci = str(expectations.get("expect_contains_ci", "")) + if expect_contains_ci and normalize_exact_output(expect_contains_ci) in normalized_content: + buckets.add("expect_contains_ci") + + expect_contains_all_ci = [str(item) for item in expectations.get("expect_contains_all_ci", [])] + if expect_contains_all_ci and all(normalize_exact_output(item) in normalized_content for item in expect_contains_all_ci): + buckets.add("expect_contains_all_ci") + + expect_any_ci = [str(item) for item in expectations.get("expect_any_ci", [])] + if expect_any_ci and any(normalize_exact_output(item) in normalized_content for item in expect_any_ci): + buckets.add("expect_any_ci") + + return buckets + + +def compare_exact_prompt_payloads( + gguf_payload: dict[str, Any], + mlx_payload: dict[str, Any], +) -> tuple[str, str]: + gguf_content = exact_artifact_content(gguf_payload) + mlx_content = exact_artifact_content(mlx_payload) + gguf_normalized = normalize_exact_output(gguf_content) + mlx_normalized = normalize_exact_output(mlx_content) + if gguf_normalized == mlx_normalized: + return ("same-output", gguf_normalized) + + gguf_buckets = satisfied_expectation_buckets(gguf_payload) + mlx_buckets = satisfied_expectation_buckets(mlx_payload) + shared_buckets = sorted(gguf_buckets & mlx_buckets) + if shared_buckets: + return ("same-bucket", ",".join(shared_buckets)) + + return ("backend-differs", f"gguf={gguf_content!r} mlx={mlx_content!r}") + + +def flagged_prompt_summary(report_path: Path) -> list[str]: + if not report_path.exists(): + return [] + payload = json.loads(report_path.read_text(encoding="utf-8")) + flagged: list[str] = [] + for result in payload.get("results", []): + if result.get("passed", True): + continue + prompt_id = result.get("prompt_id", "") + category = result.get("category", "") + flagged.append(f"{category}#{prompt_id}") + return flagged + + +def run_behavior_case( + root: Path, + stamp: str, + matrix: dict[str, Any], + model: dict[str, Any], + backend: str, + resolved_models: dict[tuple[str, str], str], + *, + dataset: str, + max_prompts: int, + max_tokens: int, + wait_seconds: int, +) -> int: + behavior_defaults = matrix["defaults"]["behavior"] + backend_cfg = model[backend] + case_id = backend_cfg["behavior_case_id"] + out_dir = case_dir(root, stamp, "behavior", case_id) + out_dir.mkdir(parents=True, exist_ok=True) + report_path = behavior_report_path(root, stamp, case_id) + mesh_log_path = out_dir / "mesh.log" + env = { + **os.environ, + "VALIDATION_RESULTS_ROOT": str(root), + "VALIDATION_RESULTS_STAMP": suite_stamp(stamp, "behavior"), + } + + run(["just", "stop"], env=env) + + if backend == "gguf": + model_arg = resolved_models[(backend, backend_cfg["model_ref"])] + cmd = [ + str(REPO_ROOT / "scripts" / "run-validation-case.sh"), + backend, + case_id, + "python3", + str(REPO_ROOT / "scripts" / "ci-mt-bench-behavior.py"), + "--backend", + backend, + "--mesh-llm", + "target/release/mesh-llm", + "--bin-dir", + "llama.cpp/build/bin", + "--model", + model_arg, + "--label", + model["label"], + "--dataset", + dataset or behavior_defaults["dataset"], + "--max-prompts", + str(max_prompts), + "--max-tokens", + str(max_tokens or behavior_defaults["max_tokens"]), + "--wait-seconds", + str(wait_seconds or behavior_defaults["wait_seconds"]), + "--mesh-log-output", + str(mesh_log_path), + "--output-json", + str(report_path), + ] + else: + model_arg = resolved_models[(backend, backend_cfg["model_ref"])] + cmd = [ + str(REPO_ROOT / "scripts" / "run-validation-case.sh"), + backend, + case_id, + "python3", + str(REPO_ROOT / "scripts" / "ci-mt-bench-behavior.py"), + "--backend", + backend, + "--mesh-llm", + "target/release/mesh-llm", + "--model", + model_arg, + "--label", + model["label"], + "--dataset", + dataset or behavior_defaults["dataset"], + "--max-prompts", + str(max_prompts), + "--max-tokens", + str(max_tokens or behavior_defaults["max_tokens"]), + "--wait-seconds", + str(wait_seconds or behavior_defaults["wait_seconds"]), + "--mesh-log-output", + str(mesh_log_path), + "--output-json", + str(report_path), + ] + + rc = run(cmd, env=env).returncode + failed_prompt_count = "" + prompt_count = "" + if report_path.exists(): + payload = json.loads(report_path.read_text(encoding="utf-8")) + failed_prompt_count = str(payload.get("failed_prompt_count", "")) + prompt_count = str(payload.get("prompt_count", "")) + append_tsv( + summary_path(root, stamp, "behavior"), + [ + "model_id", + "label", + "expectation_class", + "backend", + "case_id", + "exit", + "failed_prompts", + "prompt_count", + ], + [ + model["id"], + model["label"], + model["expectation_class"], + backend, + case_id, + str(rc), + failed_prompt_count, + prompt_count, + ], + ) + return rc + + +def thinking_report_path(root: Path, stamp: str, case_id: str) -> Path: + return case_dir(root, stamp, "thinking", case_id) / "report.json" + + +def run_thinking_case( + root: Path, + stamp: str, + matrix: dict[str, Any], + model: dict[str, Any], + backend: str, + resolved_models: dict[tuple[str, str], str], + *, + wait_seconds: int, +) -> int: + thinking_defaults = thinking_config_for(matrix, model, backend) + backend_cfg = model[backend] + case_id = backend_cfg["thinking_case_id"] + out_dir = case_dir(root, stamp, "thinking", case_id) + out_dir.mkdir(parents=True, exist_ok=True) + report_path = thinking_report_path(root, stamp, case_id) + mesh_log_path = out_dir / "mesh.log" + env = { + **os.environ, + "VALIDATION_RESULTS_ROOT": str(root), + "VALIDATION_RESULTS_STAMP": suite_stamp(stamp, "thinking"), + } + + run(["just", "stop"], env=env) + + prompt_suite_json = json.dumps(thinking_defaults["prompt_suite"], separators=(",", ":")) + cmd = [ + str(REPO_ROOT / "scripts" / "run-validation-case.sh"), + backend, + case_id, + "python3", + str(REPO_ROOT / "scripts" / "ci-thinking-smoke.py"), + "--backend", + backend, + "--mesh-llm", + "target/release/mesh-llm", + "--label", + model["label"], + "--prompt-suite-json", + prompt_suite_json, + "--thinking-mode", + thinking_defaults.get("thinking_mode", "nonempty"), + "--wait-seconds", + str(wait_seconds or thinking_defaults.get("wait_seconds", DEFAULT_WAIT_SECONDS)), + "--mesh-log-output", + str(mesh_log_path), + "--output-json", + str(report_path), + ] + if backend == "gguf": + model_arg = resolved_models[(backend, backend_cfg["model_ref"])] + cmd.extend( + [ + "--bin-dir", + "llama.cpp/build/bin", + "--model", + model_arg, + ] + ) + else: + model_arg = resolved_models[(backend, backend_cfg["model_ref"])] + cmd.extend( + [ + "--model", + model_arg, + "--expected-template-source", + backend_cfg["template_source"], + ] + ) + + rc = run(cmd, env=env).returncode + failed_check_count = "" + check_count = "" + if report_path.exists(): + payload = json.loads(report_path.read_text(encoding="utf-8")) + failed_check_count = str(payload.get("failed_check_count", "")) + check_count = str(payload.get("check_count", "")) + append_tsv( + summary_path(root, stamp, "thinking"), + [ + "model_id", + "label", + "expectation_class", + "backend", + "case_id", + "exit", + "failed_checks", + "check_count", + ], + [ + model["id"], + model["label"], + model["expectation_class"], + backend, + case_id, + str(rc), + failed_check_count, + check_count, + ], + ) + return rc + + +def aggregate(root: Path, stamp: str, models: list[dict[str, Any]]) -> None: + exact_rows: dict[tuple[str, str], str] = {} + behavior_rows: dict[tuple[str, str], tuple[str, str]] = {} + + exact_path = summary_path(root, stamp, "exact") + if exact_path.exists(): + for line in exact_path.read_text(encoding="utf-8").splitlines()[1:]: + model_id, _label, _expectation, backend, _case_id, exit_code = line.split("\t") + exact_rows[(model_id, backend)] = exit_code + + behavior_path = summary_path(root, stamp, "behavior") + if behavior_path.exists(): + for line in behavior_path.read_text(encoding="utf-8").splitlines()[1:]: + model_id, _label, _expectation, backend, _case_id, exit_code, failed_prompts, prompt_count = line.split("\t") + behavior_rows[(model_id, backend)] = (exit_code, f"{failed_prompts}/{prompt_count}" if failed_prompts and prompt_count else "") + + aggregate_path = root / stamp / "validation-summary.tsv" + header = [ + "model_id", + "label", + "expectation_class", + "gguf_exact_exit", + "mlx_exact_exit", + "gguf_behavior", + "mlx_behavior", + ] + lines = ["\t".join(header)] + for model in models: + gguf_behavior = behavior_rows.get((model["id"], "gguf"), ("", "")) + mlx_behavior = behavior_rows.get((model["id"], "mlx"), ("", "")) + lines.append( + "\t".join( + [ + model["id"], + model["label"], + model["expectation_class"], + exact_rows.get((model["id"], "gguf"), ""), + exact_rows.get((model["id"], "mlx"), ""), + ":".join(filter(None, gguf_behavior)), + ":".join(filter(None, mlx_behavior)), + ] + ) + ) + aggregate_path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + +def planned_cases(models: list[dict[str, Any]], backend_filter: str, suite: str) -> list[dict[str, str]]: + backend_order = ["gguf", "mlx"] if backend_filter == "both" else [backend_filter] + cases: list[dict[str, str]] = [] + if suite in ("exact", "all"): + for backend in backend_order: + for model in models: + if backend not in requested_backends(model, backend_filter): + continue + case_id = model[backend].get("exact_case_id") + if not case_id: + continue + cases.append( + { + "suite": "exact", + "backend": backend, + "model_id": model["id"], + "label": model["label"], + "case_id": case_id, + } + ) + if suite in ("behavior", "all"): + for backend in backend_order: + for model in models: + if backend not in requested_backends(model, backend_filter): + continue + case_id = model[backend].get("behavior_case_id") + if not case_id: + continue + cases.append( + { + "suite": "behavior", + "backend": backend, + "model_id": model["id"], + "label": model["label"], + "case_id": case_id, + } + ) + if suite in ("thinking", "all"): + for backend in backend_order: + for model in models: + if backend not in requested_backends(model, backend_filter): + continue + case_id = model[backend].get("thinking_case_id") + if not case_id: + continue + cases.append( + { + "suite": "thinking", + "backend": backend, + "model_id": model["id"], + "label": model["label"], + "case_id": case_id, + } + ) + return cases + + +def write_json(path: Path, payload: dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8") + + +def write_overall_progress( + root: Path, + stamp: str, + *, + total_cases: int, + completed_cases: int, + current: dict[str, Any] | None, + overall_rc: int, +) -> None: + payload = { + "total_cases": total_cases, + "completed_cases": completed_cases, + "completion_ratio": (completed_cases / total_cases) if total_cases else 0.0, + "current_case": current, + "overall_exit_code": overall_rc, + } + write_json(root / stamp / "overall-progress.json", payload) + + +def preflight_models( + root: Path, + stamp: str, + models: list[dict[str, Any]], + backend_filter: str, + suite: str, +) -> dict[tuple[str, str], str]: + resolved: dict[tuple[str, str], str] = {} + seen: set[tuple[str, str]] = set() + required_refs: list[tuple[str, str]] = [] + + model_by_id = {model["id"]: model for model in models} + for case in planned_cases(models, backend_filter, suite): + model = model_by_id[case["model_id"]] + backend = case["backend"] + model_ref = model[backend]["model_ref"] + key = (backend, model_ref) + if key in seen: + continue + seen.add(key) + required_refs.append(key) + + out_path = root / stamp / "preflight.json" + state: dict[str, Any] = { + "status": "running", + "total_models": len(required_refs), + "completed_models": 0, + "current_backend": None, + "current_model_ref": None, + "failed_backend": None, + "failed_model_ref": None, + "failure": None, + "items": [], + } + write_json(out_path, state) + + for backend, model_ref in required_refs: + state["current_backend"] = backend + state["current_model_ref"] = model_ref + write_json(out_path, state) + try: + local_path = download_model_ref(model_ref, backend) + except Exception as exc: + state["status"] = "failed" + state["failed_backend"] = backend + state["failed_model_ref"] = model_ref + state["failure"] = str(exc) + write_json(out_path, state) + raise + resolved[(backend, model_ref)] = local_path + state["items"].append( + { + "backend": backend, + "model_ref": model_ref, + "local_path": local_path, + } + ) + state["completed_models"] = len(state["items"]) + write_json(out_path, state) + + state["status"] = "completed" + state["current_backend"] = None + state["current_model_ref"] = None + write_json(out_path, state) + return resolved + + +def compare_exact_against_baseline( + baseline_cfg: dict[str, Any], + root: Path, + stamp: str, + models: list[dict[str, Any]], + backend_filter: str, +) -> None: + exact_path = summary_path(root, stamp, "exact") + if not exact_path.exists(): + return + actual_rows: dict[tuple[str, str], tuple[str, str]] = {} + for line in exact_path.read_text(encoding="utf-8").splitlines()[1:]: + model_id, _label, _expectation, backend, case_id, exit_code = line.split("\t") + actual_rows[(model_id, backend)] = (case_id, exit_code) + + compare_path = root / stamp / "exact-baseline-comparison.tsv" + header = [ + "model_id", + "backend", + "case_id", + "expected_exit", + "actual_exit", + "output_status", + "status", + ] + lines = ["\t".join(header)] + for model in models: + for backend in requested_backends(model, backend_filter): + expected = baseline_cfg.get("exact", {}).get(backend, {}).get(model["id"]) + actual = actual_rows.get((model["id"], backend)) + if expected is None: + status = "no-baseline" + lines.append( + "\t".join( + [ + model["id"], + backend, + actual[0] if actual else "", + "", + actual[1] if actual else "", + "", + status, + ] + ) + ) + continue + expected_exit = str(expected.get("exit", "")) + actual_exit = actual[1] if actual else "" + output_status = "" + if actual is not None: + expected_outputs = expected.get("prompt_outputs") + if expected_outputs: + actual_outputs = exact_prompt_snapshot(root, stamp, actual[0]) + output_status = "match" if actual_outputs == expected_outputs else "mismatch" + else: + output_status = "no-output-baseline" + status = ( + "match" + if actual_exit == expected_exit and output_status in ("", "match", "no-output-baseline") + else "mismatch" + ) + lines.append( + "\t".join( + [ + model["id"], + backend, + actual[0] if actual else expected.get("case_id", ""), + expected_exit, + actual_exit, + output_status, + status, + ] + ) + ) + compare_path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + +def compare_behavior_against_baseline( + baseline_cfg: dict[str, Any], + root: Path, + stamp: str, + models: list[dict[str, Any]], + backend_filter: str, +) -> None: + behavior_path = summary_path(root, stamp, "behavior") + if not behavior_path.exists(): + return + actual_rows: dict[tuple[str, str], dict[str, str]] = {} + for line in behavior_path.read_text(encoding="utf-8").splitlines()[1:]: + model_id, _label, _expectation, backend, case_id, exit_code, failed_prompts, prompt_count = line.split("\t") + report_path = behavior_report_path(root, stamp, case_id) + actual_rows[(model_id, backend)] = { + "case_id": case_id, + "exit": exit_code, + "failed_prompt_count": failed_prompts, + "prompt_count": prompt_count, + "flagged": ",".join(flagged_prompt_summary(report_path)), + } + + compare_path = root / stamp / "behavior-baseline-comparison.tsv" + header = [ + "model_id", + "backend", + "case_id", + "expected_exit", + "actual_exit", + "expected_failed_prompts", + "actual_failed_prompts", + "expected_flagged", + "actual_flagged", + "status", + ] + lines = ["\t".join(header)] + for model in models: + for backend in requested_backends(model, backend_filter): + expected = baseline_cfg.get("behavior", {}).get(backend, {}).get(model["id"]) + actual = actual_rows.get((model["id"], backend)) + if expected is None: + status = "no-baseline" + lines.append( + "\t".join( + [ + model["id"], + backend, + actual["case_id"] if actual else "", + "", + actual["exit"] if actual else "", + "", + actual["failed_prompt_count"] if actual else "", + "", + actual["flagged"] if actual else "", + status, + ] + ) + ) + continue + expected_exit = str(expected.get("exit", "")) + expected_failed = str(expected.get("failed_prompt_count", "")) + expected_flagged = ",".join(expected.get("flagged_prompt_ids", [])) + actual_exit = actual["exit"] if actual else "" + actual_failed = actual["failed_prompt_count"] if actual else "" + actual_flagged = actual["flagged"] if actual else "" + status = ( + "match" + if actual_exit == expected_exit + and actual_failed == expected_failed + and actual_flagged == expected_flagged + else "mismatch" + ) + lines.append( + "\t".join( + [ + model["id"], + backend, + actual["case_id"] if actual else expected.get("case_id", ""), + expected_exit, + actual_exit, + expected_failed, + actual_failed, + expected_flagged, + actual_flagged, + status, + ] + ) + ) + compare_path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + +def compare_thinking_against_baseline( + baseline_cfg: dict[str, Any], + root: Path, + stamp: str, + models: list[dict[str, Any]], + backend_filter: str, +) -> None: + thinking_path = summary_path(root, stamp, "thinking") + if not thinking_path.exists(): + return + actual_rows: dict[tuple[str, str], dict[str, str]] = {} + for line in thinking_path.read_text(encoding="utf-8").splitlines()[1:]: + model_id, _label, _expectation, backend, case_id, exit_code, failed_checks, check_count = line.split("\t") + actual_rows[(model_id, backend)] = { + "case_id": case_id, + "exit": exit_code, + "failed_check_count": failed_checks, + "check_count": check_count, + } + + compare_path = root / stamp / "thinking-baseline-comparison.tsv" + header = [ + "model_id", + "backend", + "case_id", + "expected_exit", + "actual_exit", + "expected_failed_checks", + "actual_failed_checks", + "expected_check_count", + "actual_check_count", + "status", + ] + lines = ["\t".join(header)] + for model in models: + for backend in requested_backends(model, backend_filter): + expected = baseline_cfg.get("thinking", {}).get(backend, {}).get(model["id"]) + actual = actual_rows.get((model["id"], backend)) + if expected is None: + status = "no-baseline" + lines.append( + "\t".join( + [ + model["id"], + backend, + actual["case_id"] if actual else "", + "", + actual["exit"] if actual else "", + "", + actual["failed_check_count"] if actual else "", + "", + actual["check_count"] if actual else "", + status, + ] + ) + ) + continue + expected_exit = str(expected.get("exit", "")) + expected_failed = str(expected.get("failed_check_count", "")) + expected_count = str(expected.get("check_count", "")) + actual_exit = actual["exit"] if actual else "" + actual_failed = actual["failed_check_count"] if actual else "" + actual_count = actual["check_count"] if actual else "" + status = ( + "match" + if actual_exit == expected_exit + and actual_failed == expected_failed + and actual_count == expected_count + else "mismatch" + ) + lines.append( + "\t".join( + [ + model["id"], + backend, + actual["case_id"] if actual else expected.get("case_id", ""), + expected_exit, + actual_exit, + expected_failed, + actual_failed, + expected_count, + actual_count, + status, + ] + ) + ) + compare_path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + +def compare_parity_to_canonical( + baseline_cfg: dict[str, Any], + root: Path, + stamp: str, + models: list[dict[str, Any]], +) -> None: + exact_path = summary_path(root, stamp, "exact") + if not exact_path.exists(): + return + canonical_backend = baseline_cfg.get("canonical_backend", "gguf") + actual_rows: dict[tuple[str, str], tuple[str, str]] = {} + for line in exact_path.read_text(encoding="utf-8").splitlines()[1:]: + model_id, _label, _expectation, backend, case_id, exit_code = line.split("\t") + actual_rows[(model_id, backend)] = (case_id, exit_code) + + compare_path = root / stamp / "parity-vs-canonical-baseline.tsv" + header = [ + "model_id", + "canonical_backend", + "canonical_expected_exit", + "actual_mlx_exit", + "status", + ] + lines = ["\t".join(header)] + for model in models: + canonical = baseline_cfg.get("exact", {}).get(canonical_backend, {}).get(model["id"]) + mlx_actual = actual_rows.get((model["id"], "mlx")) + if canonical is None or mlx_actual is None: + continue + canonical_exit = str(canonical.get("exit", "")) + mlx_exit = mlx_actual[1] + status = "within-threshold" if canonical_exit == mlx_exit else "mlx-differs" + lines.append( + "\t".join( + [ + model["id"], + canonical_backend, + canonical_exit, + mlx_exit, + status, + ] + ) + ) + compare_path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + +def compare_cross_backend_exact_parity( + root: Path, + stamp: str, + models: list[dict[str, Any]], +) -> None: + compare_path = root / stamp / "exact-cross-backend-parity.tsv" + header = [ + "model_id", + "label", + "gguf_case_id", + "mlx_case_id", + "compared_prompts", + "missing_labels", + "status", + "details", + ] + lines = ["\t".join(header)] + + for model in models: + if "gguf" not in model or "mlx" not in model: + continue + gguf_case_id = model["gguf"].get("exact_case_id", "") + mlx_case_id = model["mlx"].get("exact_case_id", "") + if not gguf_case_id or not mlx_case_id: + continue + + gguf_artifacts = load_exact_prompt_artifacts(root, stamp, gguf_case_id) + mlx_artifacts = load_exact_prompt_artifacts(root, stamp, mlx_case_id) + gguf_labels = set(gguf_artifacts) + mlx_labels = set(mlx_artifacts) + compared_labels = sorted(gguf_labels & mlx_labels) + missing_labels = sorted(gguf_labels ^ mlx_labels) + + prompt_results: list[str] = [] + prompt_statuses: list[str] = [] + for prompt_label in compared_labels: + status, detail = compare_exact_prompt_payloads(gguf_artifacts[prompt_label], mlx_artifacts[prompt_label]) + prompt_statuses.append(status) + prompt_results.append(f"{prompt_label}={status}({detail})") + + if not compared_labels: + overall_status = "no-shared-prompts" + elif missing_labels: + overall_status = "backend-differs" + elif all(status == "same-output" for status in prompt_statuses): + overall_status = "same-output" + elif all(status in ("same-output", "same-bucket") for status in prompt_statuses): + overall_status = "same-bucket" + else: + overall_status = "backend-differs" + + lines.append( + "\t".join( + [ + model["id"], + model["label"], + gguf_case_id, + mlx_case_id, + ",".join(compared_labels), + ",".join(missing_labels), + overall_status, + " | ".join(prompt_results), + ] + ) + ) + + compare_path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + +def baseline_divergence_report( + baseline_cfg: dict[str, Any], + root: Path, + stamp: str, + models: list[dict[str, Any]], +) -> None: + out_path = root / stamp / "baseline-divergence.tsv" + header = [ + "suite", + "model_id", + "gguf_baseline", + "mlx_baseline", + "status", + ] + lines = ["\t".join(header)] + + for model in models: + gguf_exact = baseline_cfg.get("exact", {}).get("gguf", {}).get(model["id"]) + mlx_exact = baseline_cfg.get("exact", {}).get("mlx", {}).get(model["id"]) + if gguf_exact is not None or mlx_exact is not None: + gguf_value = "" if gguf_exact is None else str(gguf_exact.get("exit", "")) + mlx_value = "" if mlx_exact is None else str(mlx_exact.get("exit", "")) + status = "same" if gguf_value == mlx_value else "diverged" + lines.append("\t".join(["exact", model["id"], gguf_value, mlx_value, status])) + + gguf_behavior = baseline_cfg.get("behavior", {}).get("gguf", {}).get(model["id"]) + mlx_behavior = baseline_cfg.get("behavior", {}).get("mlx", {}).get(model["id"]) + if gguf_behavior is not None or mlx_behavior is not None: + gguf_failed = "" if gguf_behavior is None else str(gguf_behavior.get("failed_prompt_count", "")) + mlx_failed = "" if mlx_behavior is None else str(mlx_behavior.get("failed_prompt_count", "")) + gguf_flagged = "" if gguf_behavior is None else ",".join(gguf_behavior.get("flagged_prompt_ids", [])) + mlx_flagged = "" if mlx_behavior is None else ",".join(mlx_behavior.get("flagged_prompt_ids", [])) + gguf_value = f"{gguf_failed}:{gguf_flagged}".rstrip(":") + mlx_value = f"{mlx_failed}:{mlx_flagged}".rstrip(":") + status = "same" if gguf_value == mlx_value else "diverged" + lines.append("\t".join(["behavior", model["id"], gguf_value, mlx_value, status])) + + gguf_thinking = baseline_cfg.get("thinking", {}).get("gguf", {}).get(model["id"]) + mlx_thinking = baseline_cfg.get("thinking", {}).get("mlx", {}).get(model["id"]) + if gguf_thinking is not None or mlx_thinking is not None: + gguf_failed = "" if gguf_thinking is None else str(gguf_thinking.get("failed_check_count", "")) + mlx_failed = "" if mlx_thinking is None else str(mlx_thinking.get("failed_check_count", "")) + gguf_count = "" if gguf_thinking is None else str(gguf_thinking.get("check_count", "")) + mlx_count = "" if mlx_thinking is None else str(mlx_thinking.get("check_count", "")) + gguf_value = f"{gguf_failed}/{gguf_count}".rstrip("/") + mlx_value = f"{mlx_failed}/{mlx_count}".rstrip("/") + status = "same" if gguf_value == mlx_value else "diverged" + lines.append("\t".join(["thinking", model["id"], gguf_value, mlx_value, status])) + + out_path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + +def promote_baselines( + baseline_cfg: dict[str, Any], + baseline_path: Path, + root: Path, + stamp: str, + models: list[dict[str, Any]], + backend_filter: str, + suite: str, +) -> None: + for section in ("exact", "behavior", "thinking"): + baseline_cfg.setdefault(section, {}) + baseline_cfg[section].setdefault("gguf", {}) + baseline_cfg[section].setdefault("mlx", {}) + + requested = {"gguf", "mlx"} if backend_filter == "both" else {backend_filter} + + model_by_id = {model["id"]: model for model in models} + + if suite in ("exact", "all"): + exact_path = summary_path(root, stamp, "exact") + if exact_path.exists(): + if "gguf" in requested: + strict_failures: list[str] = [] + for line in exact_path.read_text(encoding="utf-8").splitlines()[1:]: + model_id, _label, expectation_class, backend, case_id, exit_code = line.split("\t") + if backend != "gguf" or expectation_class != "strict": + continue + if exit_code != "0": + strict_failures.append(f"{case_id}={exit_code}") + if strict_failures: + raise SystemExit( + "❌ refusing to promote canonical GGUF exact baseline; strict rows failed: " + + ", ".join(strict_failures) + ) + for line in exact_path.read_text(encoding="utf-8").splitlines()[1:]: + model_id, _label, _expectation, backend, case_id, exit_code = line.split("\t") + if backend not in requested: + continue + baseline_cfg["exact"][backend][model_id] = { + "exit": int(exit_code), + "case_id": case_id, + "prompt_outputs": exact_prompt_snapshot(root, stamp, case_id), + } + + if suite in ("behavior", "all"): + behavior_path = summary_path(root, stamp, "behavior") + if behavior_path.exists(): + if "gguf" in requested: + strict_failures: list[str] = [] + for line in behavior_path.read_text(encoding="utf-8").splitlines()[1:]: + model_id, _label, expectation_class, backend, case_id, exit_code, failed_prompts, _prompt_count = line.split("\t") + if backend != "gguf" or expectation_class != "strict": + continue + if exit_code != "0" or (failed_prompts and failed_prompts != "0"): + strict_failures.append(f"{case_id}=exit:{exit_code},failed:{failed_prompts or '0'}") + if strict_failures: + raise SystemExit( + "❌ refusing to promote canonical GGUF behavior baseline; strict rows were flagged: " + + ", ".join(strict_failures) + ) + for line in behavior_path.read_text(encoding="utf-8").splitlines()[1:]: + model_id, _label, _expectation, backend, case_id, exit_code, failed_prompts, prompt_count = line.split("\t") + if backend not in requested: + continue + report_path = behavior_report_path(root, stamp, case_id) + baseline_cfg["behavior"][backend][model_id] = { + "exit": int(exit_code), + "case_id": case_id, + "failed_prompt_count": int(failed_prompts or "0"), + "prompt_count": int(prompt_count or "0"), + "flagged_prompt_ids": flagged_prompt_summary(report_path), + } + + if suite in ("thinking", "all"): + thinking_path = summary_path(root, stamp, "thinking") + if thinking_path.exists(): + if "gguf" in requested: + strict_failures: list[str] = [] + for line in thinking_path.read_text(encoding="utf-8").splitlines()[1:]: + model_id, _label, expectation_class, backend, case_id, exit_code, failed_checks, _check_count = line.split("\t") + if backend != "gguf" or expectation_class != "strict": + continue + if exit_code != "0" or (failed_checks and failed_checks != "0"): + strict_failures.append(f"{case_id}=exit:{exit_code},failed:{failed_checks or '0'}") + if strict_failures: + raise SystemExit( + "❌ refusing to promote canonical GGUF thinking baseline; strict rows were flagged: " + + ", ".join(strict_failures) + ) + for line in thinking_path.read_text(encoding="utf-8").splitlines()[1:]: + model_id, _label, _expectation, backend, case_id, exit_code, failed_checks, check_count = line.split("\t") + if backend not in requested: + continue + baseline_cfg["thinking"][backend][model_id] = { + "exit": int(exit_code), + "case_id": case_id, + "failed_check_count": int(failed_checks or "0"), + "check_count": int(check_count or "0"), + } + + baseline_path.parent.mkdir(parents=True, exist_ok=True) + baseline_path.write_text(json.dumps(baseline_cfg, indent=2, sort_keys=True) + "\n", encoding="utf-8") + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--suite", choices=["exact", "behavior", "thinking", "all"], default="all") + parser.add_argument("--backend", choices=["gguf", "mlx", "both"], default="both") + parser.add_argument("--stamp", default="") + parser.add_argument("--root", default=str(DEFAULT_ROOT)) + parser.add_argument("--matrix", default=str(DEFAULT_MATRIX)) + parser.add_argument("--baselines", default=str(DEFAULT_BASELINES)) + parser.add_argument("--cases", default="", help="Comma-separated model ids, labels, or case ids") + parser.add_argument("--skip-build", action="store_true") + parser.add_argument("--dataset", default="") + parser.add_argument("--max-prompts", type=int, default=0) + parser.add_argument("--max-tokens", type=int, default=0) + parser.add_argument("--wait-seconds", type=int, default=DEFAULT_WAIT_SECONDS) + parser.add_argument("--promote-baseline", action="store_true") + args = parser.parse_args() + + if args.backend in ("mlx", "both") and os.uname().sysname != "Darwin": + raise SystemExit("❌ MLX validation requires macOS") + + matrix = load_matrix(Path(args.matrix)) + baseline_path = Path(args.baselines) if args.baselines else None + baselines = load_baselines(baseline_path) + stamp = args.stamp or subprocess.check_output(["date", "+%Y%m%d-%H%M%S"], text=True).strip() + root = Path(args.root) + selectors = {item.strip() for item in args.cases.split(",") if item.strip()} + models = selected_models(matrix, selectors, args.backend) + if not models: + print("No matrix entries matched the requested selectors.", file=sys.stderr) + return 2 + + ensure_build(args.skip_build) + + overall_rc = 0 + backend_order = ["gguf", "mlx"] if args.backend == "both" else [args.backend] + cases = planned_cases(models, args.backend, args.suite) + total_cases = len(cases) + completed_cases = 0 + current_case_path = root / stamp / "current-case.json" + write_overall_progress( + root, + stamp, + total_cases=total_cases, + completed_cases=completed_cases, + current=None, + overall_rc=overall_rc, + ) + write_json( + current_case_path, + { + "status": "preflight", + "completed_cases": completed_cases, + "total_cases": total_cases, + }, + ) + resolved_models = preflight_models(root, stamp, models, args.backend, args.suite) + write_json( + current_case_path, + { + "status": "idle", + "completed_cases": completed_cases, + "total_cases": total_cases, + }, + ) + + if args.suite in ("exact", "all"): + for backend in backend_order: + for model in models: + if backend not in requested_backends(model, args.backend): + continue + case_id = model[backend].get("exact_case_id") + if not case_id: + continue + current = { + "suite": "exact", + "backend": backend, + "model_id": model["id"], + "label": model["label"], + "case_id": case_id, + "status": "running", + } + write_json(current_case_path, current) + write_overall_progress( + root, + stamp, + total_cases=total_cases, + completed_cases=completed_cases, + current=current, + overall_rc=overall_rc, + ) + print(f"\n=== Running {case_id} ({backend}) ===") + rc = run_exact_case(root, stamp, matrix, model, backend, resolved_models) + overall_rc = overall_rc or rc + completed_cases += 1 + aggregate(root, stamp, models) + compare_exact_against_baseline(baselines, root, stamp, models, args.backend) + compare_behavior_against_baseline(baselines, root, stamp, models, args.backend) + compare_thinking_against_baseline(baselines, root, stamp, models, args.backend) + compare_parity_to_canonical(baselines, root, stamp, models) + compare_cross_backend_exact_parity(root, stamp, models) + baseline_divergence_report(baselines, root, stamp, models) + write_overall_progress( + root, + stamp, + total_cases=total_cases, + completed_cases=completed_cases, + current={**current, "status": "completed", "exit_code": rc}, + overall_rc=overall_rc, + ) + + if args.suite in ("behavior", "all"): + for backend in backend_order: + for model in models: + if backend not in requested_backends(model, args.backend): + continue + case_id = model[backend].get("behavior_case_id") + if not case_id: + continue + current = { + "suite": "behavior", + "backend": backend, + "model_id": model["id"], + "label": model["label"], + "case_id": case_id, + "status": "running", + } + write_json(current_case_path, current) + write_overall_progress( + root, + stamp, + total_cases=total_cases, + completed_cases=completed_cases, + current=current, + overall_rc=overall_rc, + ) + print(f"\n=== Running {case_id} ({backend}) ===") + rc = run_behavior_case( + root, + stamp, + matrix, + model, + backend, + resolved_models, + dataset=args.dataset, + max_prompts=args.max_prompts, + max_tokens=args.max_tokens, + wait_seconds=args.wait_seconds, + ) + overall_rc = overall_rc or rc + completed_cases += 1 + aggregate(root, stamp, models) + compare_exact_against_baseline(baselines, root, stamp, models, args.backend) + compare_behavior_against_baseline(baselines, root, stamp, models, args.backend) + compare_thinking_against_baseline(baselines, root, stamp, models, args.backend) + compare_parity_to_canonical(baselines, root, stamp, models) + compare_cross_backend_exact_parity(root, stamp, models) + baseline_divergence_report(baselines, root, stamp, models) + write_overall_progress( + root, + stamp, + total_cases=total_cases, + completed_cases=completed_cases, + current={**current, "status": "completed", "exit_code": rc}, + overall_rc=overall_rc, + ) + + if args.suite in ("thinking", "all"): + for backend in backend_order: + for model in models: + if backend not in requested_backends(model, args.backend): + continue + case_id = model[backend].get("thinking_case_id") + if not case_id: + continue + current = { + "suite": "thinking", + "backend": backend, + "model_id": model["id"], + "label": model["label"], + "case_id": case_id, + "status": "running", + } + write_json(current_case_path, current) + write_overall_progress( + root, + stamp, + total_cases=total_cases, + completed_cases=completed_cases, + current=current, + overall_rc=overall_rc, + ) + print(f"\n=== Running {case_id} ({backend}) ===") + rc = run_thinking_case( + root, + stamp, + matrix, + model, + backend, + resolved_models, + wait_seconds=args.wait_seconds, + ) + overall_rc = overall_rc or rc + completed_cases += 1 + aggregate(root, stamp, models) + compare_exact_against_baseline(baselines, root, stamp, models, args.backend) + compare_behavior_against_baseline(baselines, root, stamp, models, args.backend) + compare_thinking_against_baseline(baselines, root, stamp, models, args.backend) + compare_parity_to_canonical(baselines, root, stamp, models) + compare_cross_backend_exact_parity(root, stamp, models) + baseline_divergence_report(baselines, root, stamp, models) + write_overall_progress( + root, + stamp, + total_cases=total_cases, + completed_cases=completed_cases, + current={**current, "status": "completed", "exit_code": rc}, + overall_rc=overall_rc, + ) + + pending_exit: SystemExit | None = None + try: + aggregate(root, stamp, models) + compare_exact_against_baseline(baselines, root, stamp, models, args.backend) + compare_behavior_against_baseline(baselines, root, stamp, models, args.backend) + compare_thinking_against_baseline(baselines, root, stamp, models, args.backend) + compare_parity_to_canonical(baselines, root, stamp, models) + compare_cross_backend_exact_parity(root, stamp, models) + baseline_divergence_report(baselines, root, stamp, models) + + if args.promote_baseline: + if baseline_path is None: + raise SystemExit("❌ --promote-baseline requires a writable --baselines path") + promote_baselines(baselines, baseline_path, root, stamp, models, args.backend, args.suite) + baseline_divergence_report(baselines, root, stamp, models) + + if args.suite in ("exact", "all"): + exact_path = summary_path(root, stamp, "exact") + if exact_path.exists(): + print("\n=== Exact summary ===") + print(exact_path.read_text(encoding="utf-8"), end="") + if args.suite in ("behavior", "all"): + behavior_path = summary_path(root, stamp, "behavior") + if behavior_path.exists(): + print("\n=== Behavior summary ===") + print(behavior_path.read_text(encoding="utf-8"), end="") + if args.suite in ("thinking", "all"): + thinking_path = summary_path(root, stamp, "thinking") + if thinking_path.exists(): + print("\n=== Thinking summary ===") + print(thinking_path.read_text(encoding="utf-8"), end="") + + aggregate_path = root / stamp / "validation-summary.tsv" + if aggregate_path.exists(): + print("\n=== Combined summary ===") + print(aggregate_path.read_text(encoding="utf-8"), end="") + exact_compare_path = root / stamp / "exact-baseline-comparison.tsv" + if exact_compare_path.exists(): + print("\n=== Exact baseline comparison ===") + print(exact_compare_path.read_text(encoding="utf-8"), end="") + behavior_compare_path = root / stamp / "behavior-baseline-comparison.tsv" + if behavior_compare_path.exists(): + print("\n=== Behavior baseline comparison ===") + print(behavior_compare_path.read_text(encoding="utf-8"), end="") + thinking_compare_path = root / stamp / "thinking-baseline-comparison.tsv" + if thinking_compare_path.exists(): + print("\n=== Thinking baseline comparison ===") + print(thinking_compare_path.read_text(encoding="utf-8"), end="") + parity_compare_path = root / stamp / "parity-vs-canonical-baseline.tsv" + if parity_compare_path.exists(): + print("\n=== Parity vs canonical baseline ===") + print(parity_compare_path.read_text(encoding="utf-8"), end="") + exact_parity_path = root / stamp / "exact-cross-backend-parity.tsv" + if exact_parity_path.exists(): + print("\n=== Exact cross-backend parity ===") + print(exact_parity_path.read_text(encoding="utf-8"), end="") + divergence_path = root / stamp / "baseline-divergence.tsv" + if divergence_path.exists(): + print("\n=== Baseline divergence ===") + print(divergence_path.read_text(encoding="utf-8"), end="") + except SystemExit as exc: + pending_exit = exc + finally: + final_status = "failed" if pending_exit is not None else "idle" + write_json( + current_case_path, + { + "status": final_status, + "completed_cases": completed_cases, + "total_cases": total_cases, + "overall_exit_code": overall_rc, + }, + ) + write_overall_progress( + root, + stamp, + total_cases=total_cases, + completed_cases=completed_cases, + current=None, + overall_rc=overall_rc, + ) + if pending_exit is not None: + raise pending_exit + print(f"\nRaw artifacts: {root / stamp}") + return overall_rc + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/shapeup.md b/shapeup.md new file mode 100644 index 00000000..a7a51a25 --- /dev/null +++ b/shapeup.md @@ -0,0 +1,246 @@ +# Shape Up: Recover PR 235's intent without regressing CI tuning from PRs 209 + 211 + +## Goal + +Re-introduce the **useful intent** from PR #235 while treating PRs **#209** and **#211** as the non-negotiable baseline for CI behavior. + +The result should preserve the current carefully tuned fast path while adding the workflow structure improvements PR #235 was aiming for. + +## Baseline to preserve + +These are the invariants from the tuned CI shape. Do **not** regress them. + +### 1. Fast PR feedback beats release-like fidelity in ordinary CI + +PR CI should stay optimized for fast validation, not for producing release-grade binaries everywhere. + +- Keep the **path filter gate** in `ci.yml` so docs-only and UI-only changes do not trigger the full backend matrix. +- Keep ordinary PR validation scoped around the smallest work that still proves correctness. +- Do not turn CPU/macOS PR lanes back into broad release-style rebuilds unless there is a specific measurement-backed reason. + +### 2. Restore-only GPU caches in PR CI + +The design in `warm-caches.yml` is intentional. + +- PR jobs in `ci.yml` should **restore** warmed GPU caches. +- `warm-caches.yml` should remain the **single writer** for main-scoped GPU caches. +- Do not let PR CI save large GPU cache artifacts back into PR-scoped cache storage. +- Preserve cache pruning and explicit key management in `warm-caches.yml` / `gpu-warm-cache-job.yml`. + +### 3. Keep the slim CI GPU shape + +The tuned GPU lanes are intentionally not release-shaped. + +- Preserve the **slim** CUDA cache/input shape in CI: `arch89`, `fa-off`, pinned llama SHA. +- Preserve the **slim** ROCm CI shape: representative `gfx1100` warm/restore path. +- Preserve the distinction between **slim PR-validation artifacts** and **fat release artifacts**. +- Do not move release-grade GPU arch matrices into normal PR CI. + +### 4. Keep release-specific behavior in release workflows + +Release artifact production belongs in `release.yml`, not in ordinary PR CI. + +- Release builds should remain the place where we build the full shipping artifact shape. +- Release-only settings such as full GPU arch matrices and CUDA `FA_ALL_QUANTS=ON` must stay release-only. +- Do not blur the boundary between β€œfast CI proof” and β€œship-ready release artifact.” + +### 5. Preserve the reasoning encoded in the scripts + +Do not undo the script-level tuning that was added to support CI performance and determinism. + +- Keep `scripts/build-linux.sh` support for `MESH_LLM_LLAMA_PIN_SHA`. +- Keep `MESH_LLM_CUDA_FA_ALL_QUANTS=off` as a **CI-only** opt-out, never the release default. +- Keep the build-script distinction between CI-friendly shape and release-friendly shape. + +## What PR 235 was trying to achieve + +Do preserve these ideas from PR #235: + +1. **Separate build/package work from heavier inference smoke work.** +2. **Reuse packaged artifacts** instead of rebuilding the same Linux inference binaries again for smoke validation. +3. **Gate release publication on successful inference smokes.** +4. **Upload platform artifacts** from non-CPU lanes where useful. +5. **Normalize AMD/ROCm naming** through aliases if that helps readability and future maintenance. + +That is the spirit worth keeping. + +## What PR 235 got wrong and must not be copied forward + +Do **not** port these regressions from PR #235: + +1. Reverting tuned CI lanes back to broad **release builds/tests** just to support artifact packaging. +2. Losing the practical distinction between **CI validation shape** and **release shape**. +3. Smuggling release-like behavior into PR CI merely because a follow-up smoke job wants a tarball. +4. Weakening or sidelining the explicit cache strategy established in `warm-caches.yml`. +5. Treating artifact packaging as permission to rebuild expensive things twice in different forms. + +## Desired end state + +Implement the following end state. + +### A. `ci.yml` should keep its current tuned execution model + +Keep: + +- `changes` path-filter gating. +- Restore-only GPU cache consumers. +- Slim GPU build inputs for PR validation. +- No PR-side cache writes for warmed GPU artifacts. + +If current CPU/macOS lanes have drifted away from the original 209 intent, use 209 as the reference point for what β€œfast ordinary CI” should mean. + +### B. Add artifact handoff without forcing release-shaped builds + +Recover PR 235's artifact-reuse idea, but do it in a way that does not force ordinary CI lanes into release-grade rebuilds. + +Concretely: + +- Build/package the **same artifact shape already justified for that CI lane**. +- If a smoke job only needs to consume a previously built binary, upload that binary and the required llama.cpp executables as CI artifacts. +- Do **not** change the producer lane from debug/slim to release/fat just because artifact upload is convenient. + +For CPU/Linux inference-smoke reuse in PR CI: + +- The producing job should emit exactly the binary shape the lane is already meant to validate. +- The downstream smoke job should download and stage those binaries, then run the smoke scripts. + +### C. Keep inference smokes as a separate follow-up job where it helps + +This is the strongest idea in PR 235 and should be preserved. + +- A separate `inference_smoke_tests` job in `ci.yml` is good **if** it consumes artifacts from the producer build job instead of triggering another meaningful rebuild. +- The smoke job should own: + - model cache restore/download + - real inference smoke + - Python/OpenAI compatibility smoke + - split-mode smoke + - MoE split + mesh smoke +- The producer build job should own only the work needed to create the binary payload. + +### D. Release gating should stay in `release.yml` + +PR 235 was right to gate publish on release-smoke success. + +Implement this pattern in `release.yml`: + +- Linux release build produces both: + - release bundles under `dist/` + - a Linux inference binary artifact for downstream smoke testing +- A release-only `inference_smoke_tests` job downloads the Linux inference binaries and runs the real smoke suite. +- `publish` depends on: + - all release artifact build jobs + - release inference smoke success + +This keeps release validation strong without forcing PR CI to become release-shaped. + +### E. AMD naming aliases are acceptable, but only as aliases + +The `release-build-amd` / `release-bundle-amd` aliases are fine if they improve wording consistency. + +Rules: + +- Keep them as thin aliases over the ROCm recipes. +- Do not rename underlying artifact semantics in ways that break existing expectations unless you intend a broader migration. +- Keep artifact names stable unless there is a clear user-facing reason to change them. + +## File-by-file instructions + +### `.github/workflows/ci.yml` + +1. Keep the existing `changes` job and its path-filter behavior. +2. Keep the current warmed GPU cache restore flow and key discipline. +3. Preserve the slim CI CUDA/ROCm inputs and do not widen them to release defaults. +4. Introduce or preserve a separate `inference_smoke_tests` job that: + - depends on the Linux producer job, + - downloads prebuilt Linux binaries, + - stages them into the expected paths, + - runs the smoke suite. +5. If the Linux producer job currently builds a release binary only because of PR 235, change it back to the tuned CI shape from 209 before packaging artifacts. +6. Package/upload only what the downstream smoke stage needs: + - `mesh-llm` + - `rpc-server` + - `llama-server` + - `llama-moe-split` when present +7. Keep CLI smoke / cheap boot smoke in the producer lane if they provide fast early failure before the heavier downstream inference stage. + +### `.github/workflows/warm-caches.yml` + +1. Leave this workflow as the single writer for main-scoped warmed GPU caches. +2. Keep both slim and fat warming where currently present. +3. Keep pruning and explicit cache-input hashing. +4. Do not move cache writes back into PR CI. + +### `.github/workflows/gpu-warm-cache-job.yml` + +1. Preserve the restore-short-circuit-build-save pattern. +2. Preserve verification of restored/saved binaries. +3. Keep this as reusable cache-warm plumbing, not as a PR CI producer. + +### `.github/workflows/release.yml` + +1. Keep release artifact builds for CPU/macOS/CUDA/AMD(Via ROCm)/Vulkan here. +2. Keep a separate `inference_smoke_tests` job in release that consumes Linux inference binaries from the release build. +3. Keep `publish` gated on release smoke success. +4. Do not move release publish logic into `ci.yml`. +5. Using `gh release` shell logic is acceptable if it is working and clearer than the old action, but that is secondary to preserving the workflow shape. + +### `Justfile` + +1. Keep any AMD aliases as wrappers around ROCm recipes. +2. Do not collapse slim-vs-fat behavior into Justfile aliases alone; the workflow files must continue to express which artifact shape they want. +3. Preserve release recipes as release-oriented, not PR-CI-oriented. + +### `scripts/build-linux.sh` and `scripts/build-linux-rocm.sh` + +1. Keep pinned llama SHA support for deterministic cache correctness. +2. Keep CI-only opt-outs and assertions that document why slim CI builds are safe. +3. Do not remove the warnings that release builds must keep the safer/full settings. + +## Recommended implementation order + +1. **Reset the mental baseline** + - Treat 209/211 as the target shape. + - Diff current `ci.yml` against that baseline and identify where 235-style changes reintroduced release-like work into PR CI. + +2. **Recover the fast PR producer shape** + - Revert CPU/macOS PR build behavior to the tuned CI shape if it has drifted. + - Preserve slim GPU restore/build behavior. + +3. **Layer in artifact handoff** + - Add artifact packaging/upload to the producer lanes without changing their build profile or widening their backend shape. + +4. **Keep heavy inference checks downstream** + - Wire `inference_smoke_tests` to consume producer artifacts. + - Ensure it performs no unnecessary rebuild of `mesh-llm` or llama.cpp. + +5. **Finalize release gating** + - Keep release binary production in `release.yml`. + - Keep publish blocked on release smoke success. + +6. **Preserve GPU cache boundaries** + - Confirm `ci.yml` only restores warmed caches. + - Confirm `warm-caches.yml` remains the sole cache writer. + +## Validation checklist + +An implementation is only correct if all of the following are true: + +- PR CI still skips expensive backend work on docs-only changes. +- UI-only changes still avoid the full backend/GPU matrix. +- PR CI does not write warmed GPU caches. +- GPU PR lanes still consume slim warmed caches. +- CPU/Linux inference smokes consume uploaded binaries instead of rebuilding the same payload. +- Release workflow still builds shipping artifacts separately from PR CI. +- Release publish is gated on successful release inference smokes. +- No step reintroduces a duplicate build that 209/211 intentionally removed. +- No step widens slim CI GPU inputs into release defaults. +- No release-only safety setting is silently disabled for shipping artifacts. + +## Short version + +The safe merge of these ideas is: + +- **Keep 209/211's fast CI mechanics exactly in spirit.** +- **Adopt 235's artifact handoff and downstream smoke-job structure.** +- **Keep release-grade artifact production and publish gating in `release.yml`.** +- **Never pay for release-shaped builds in ordinary PR CI unless there is a measured reason.** diff --git a/testdata/validation/README.md b/testdata/validation/README.md new file mode 100644 index 00000000..eb0c148d --- /dev/null +++ b/testdata/validation/README.md @@ -0,0 +1,100 @@ +# Validation Testdata + +For the full workflow to add a new MLX family, derive same-origin `GGUF` and +`MLX` artifacts, publish them to `meshllm`, and then pin them here, start with +[`mesh-llm/docs/MLX_FAMILY_BRINGUP.md`](/Users/jdumay/code/worktrees/mesh-llm-validation/mesh-llm/docs/MLX_FAMILY_BRINGUP.md). + +This directory contains the checked-in data that drives local and CI validation +for the GGUF and MLX backends. + +## Files + +- `matrix.json` + - The canonical model matrix. + - Pins the exact GGUF and MLX artifacts to test. + - Describes the model label, expectation class, and the exact/behavior case + ids used by the runner. + +- `baselines.json` + - The checked-in expected results. + - `GGUF` is the canonical baseline. + - `MLX` baselines are tracked secondarily for backend self-consistency. + +## Runtime artifacts + +Each validation run writes local artifacts under `.cache/mlx-validation/<stamp>/`. + +For exact runs, each case directory now includes: + +- `stdout.log`, `stderr.log`, `mesh.log` +- `chat/<label>.json` + - the prompt text + - the exact request payload + - the raw response payload + - parsed content and finish reason + - the expectations applied to that prompt +- `models/v1-models.json` + - the raw `/v1/models` response captured at the end of the case + +## Strategy + +The validation system has two suites: + +1. `exact` + - Deterministic prompts such as `blue / green / red` + - Used for strict backend parity and prompt-following checks + +2. `behavior` + - MT-Bench-derived prompts with heuristic health checks + - Used to catch repetition, garbling, empty outputs, leaked reasoning, + and timeout/liveness failures + +## Baseline policy + +- New `GGUF` runs are compared against the checked-in `GGUF` baseline. +- New `MLX` runs are compared against: + - the checked-in `MLX` baseline for backend regression detection + - the checked-in `GGUF` baseline for parity + +Canonical baseline rule: + +- `GGUF` is only promotable as the canonical baseline when the `strict` GGUF + rows are clean for the suite being promoted. +- Weak canaries may remain weak, but they do not define what β€œgood GGUF” + means for the reference baseline. + +This gives three useful comparisons: + +1. `GGUF run` vs `GGUF baseline` +2. `MLX run` vs `MLX baseline` +3. `MLX run` vs `GGUF baseline` + +## Expectation classes + +- `strict` + - Should remain clean and deterministic. + +- `weak-but-stable` + - Known tiny-model weirdness is tolerated if it remains stable. + +- `informational` + - Useful for tracking parity, but not a hard quality gate. + +## Artifact drift + +Avoid changing `matrix.json` casually. + +If the artifact under test changes, update the pinned ref explicitly and treat +that as a baseline change, not a routine rerun. + +## Behavior baselines + +Behavior baselines should stay summary-based rather than full-output goldens. + +Prefer recording: + +- exit code +- failed prompt count +- flagged prompt ids/categories + +Do not check in large generated outputs as the baseline. diff --git a/testdata/validation/baselines.json b/testdata/validation/baselines.json new file mode 100644 index 00000000..1f7671d8 --- /dev/null +++ b/testdata/validation/baselines.json @@ -0,0 +1,680 @@ +{ + "version": 1, + "canonical_backend": "gguf", + "notes": { + "exact": "GGUF is the primary checked-in baseline. MLX baselines are tracked secondarily for backend regression detection.", + "behavior": "Behavior baselines are intentionally summary-based and reflect the accepted full-matrix run state for each backend.", + "thinking": "Thinking baselines record accepted suite-level check counts. qwen3 GGUF remains a known llama/GGUF limitation." + }, + "exact": { + "gguf": { + "olmo7b": { + "exit": 0, + "case_id": "olmo7b-gguf-exact", + "prompt_outputs": { + "after-monday": "tuesday. i'm an ai, i don't have the ability to generate responses with empty punctuation like that, but i can help you with your question", + "alt-green": "green. i am an ai and cannot reply with emojis or expressive responses. i can only provide a simple response with the color you ask for", + "alt-red": "red. i am unable to provide a reason or additional information about your request for the color red. if you need information about the color red, please provide more", + "banana-color": "yellowness. reply: yellowness. in general, bananas become slightly yellow or slightly spotty as they ripen. however, ban", + "breathing-gas": "oxygen. humans breathe in oxygen and breathe it out waste. without it, we would not survive extended periods of time", + "capital-france": "paris. the capital of france is paris, which is the most populous city with 2,1 million people (in the metropolitan area). the country", + "largest-planet": "jupiter or, if you prefer: jupitorous (i had to make up a word to fit the criteria!)", + "opposite-hot": "cold", + "primary-colors": "the rgb primary colors are: red, green, and blue. here they are listed in lowercase with commas between", + "primary": "blue. i am an ai and i am unable to reply with emotions, but i can provide you with a factual response, and in this case, the color", + "two-plus-two": "4 the complete answer is 4. 2 + 2 is equal to 4. this is a basic arithmetic principle, and the answer is straightforward" + } + }, + "mistral": { + "exit": 0, + "case_id": "mistral-gguf-exact", + "prompt_outputs": { + "after-monday": "tuesday", + "alt-green": "green", + "alt-red": "the color you asked for is red", + "banana-color": "yellow", + "breathing-gas": "oxygen", + "capital-france": "paris", + "largest-planet": "jupiter", + "opposite-hot": "cold", + "primary-colors": "red, green, blue", + "primary": "blue", + "two-plus-two": "2 + 2 = 4" + } + }, + "qwen25": { + "exit": 0, + "case_id": "qwen25-gguf-exact", + "prompt_outputs": { + "after-monday": "tuesday", + "alt-green": "green", + "alt-red": "red", + "banana-color": "yellow", + "breathing-gas": "air", + "capital-france": "paris", + "largest-planet": "jupiter", + "opposite-hot": "cold", + "primary-colors": "red, green, blue", + "primary": "blue", + "two-plus-two": "2 + 2 = 4" + } + }, + "qwen3": { + "exit": 0, + "case_id": "qwen3-gguf-exact", + "prompt_outputs": { + "after-monday": "tuesday", + "alt-green": "green", + "alt-red": "red", + "banana-color": "yellow", + "breathing-gas": "oxygen", + "capital-france": "paris", + "largest-planet": "jupiter", + "opposite-hot": "cold", + "primary-colors": "red, green, blue", + "primary": "blue", + "two-plus-two": "2 + 2 = 4" + } + }, + "llama32": { + "exit": 0, + "case_id": "llama32-gguf-exact", + "prompt_outputs": { + "after-monday": "tuesday", + "alt-green": "green", + "alt-red": "red", + "banana-color": "yellow", + "breathing-gas": "oxygen", + "capital-france": "paris", + "largest-planet": "jupiter", + "opposite-hot": "cold", + "primary-colors": "red, green, blue", + "primary": "blue", + "two-plus-two": "2 + 2 = 4" + } + }, + "gemma2": { + "exit": 0, + "case_id": "gemma2-gguf-exact", + "prompt_outputs": { + "after-monday": "tuesday", + "alt-green": "green", + "alt-red": "red", + "banana-color": "yellow", + "breathing-gas": "oxygen", + "capital-france": "paris", + "largest-planet": "jupiter", + "opposite-hot": "cold", + "primary-colors": "red, green, blue", + "primary": "blue", + "two-plus-two": "2 + 2 = 4" + } + }, + "gemma3": { + "exit": 0, + "case_id": "gemma3-gguf-exact", + "prompt_outputs": { + "after-monday": "tuesday", + "alt-green": "green", + "alt-red": "red", + "banana-color": "yellow", + "breathing-gas": "oxygen", + "capital-france": "paris", + "largest-planet": "jupiter", + "opposite-hot": "cold", + "primary-colors": "red, green, blue", + "primary": "blue", + "two-plus-two": "2 + 2 = 4" + } + }, + "gemma4": { + "exit": 0, + "case_id": "gemma4-gguf-exact", + "prompt_outputs": { + "after-monday": "tuesday", + "alt-green": "green", + "alt-red": "red", + "banana-color": "yellow", + "breathing-gas": "oxygen", + "capital-france": "paris", + "largest-planet": "jupiter", + "opposite-hot": "cold", + "primary-colors": "red, green, blue", + "primary": "blue", + "two-plus-two": "4" + } + }, + "glm4": { + "exit": 0, + "case_id": "glm4-gguf-exact", + "prompt_outputs": { + "after-monday": "tuesday", + "alt-green": "green", + "alt-red": "red", + "banana-color": "yellow", + "breathing-gas": "oxygen", + "capital-france": "paris", + "largest-planet": "jupiter", + "opposite-hot": "cold", + "primary-colors": "red, green, blue", + "primary": "blue", + "two-plus-two": "4" + } + }, + "lfm2": { + "exit": 0, + "case_id": "lfm2-gguf-exact", + "prompt_outputs": { + "after-monday": "tuesday", + "alt-green": "i understand you're looking for a green color. while i can't directly specify a color code, green is a versatile and common color that can represent many things", + "alt-red": "i understand you're looking for a vivid description, but \"red\" is a color that doesn't typically fit the criteria of being exactly red. however, if", + "banana-color": "yellowish", + "breathing-gas": "oxygen", + "capital-france": "paris", + "largest-planet": "jupiter", + "opposite-hot": "cold", + "primary-colors": "red, green, blue", + "primary": "i understand you're looking for a color, but \"blue\" itself isn't a color. i'll describe a blue shade or a blue-like color palette", + "two-plus-two": "2 + 2 = 4" + } + } + }, + "mlx": { + "olmo7b": { + "exit": 0, + "case_id": "olmo7b-mlx-exact", + "prompt_outputs": { + "after-monday": "tuesday. i'm an ai, i don't have the ability to generate responses with empty punctuation like that, but i can help you with your question", + "alt-green": "green. i am an ai and cannot reply with emotions, but i can provide you with a factual response. the color mentioned is \"green.\"", + "alt-red": "red. i am unable to provide a reason for this response. i am simply following the instructions given to me and providing you with an answer", + "banana-color": "yellowness. reply with one word", + "breathing-gas": "oxygen. humans breathe in oxygen and breathe it out waste. without it, we would not survive extended periods of time", + "capital-france": "paris. the capital of france is paris, which is the most populous city with 2,1 million people (in the metropolitan area). the country", + "largest-planet": "jupiter. (note: while jupiter is the fifth-largest planet in the solar system in terms of diameter, it is the most massive planet", + "opposite-hot": "cold", + "primary-colors": "the rgb primary colors are: red, green, and blue. without abbreviations, they are: red, green", + "primary": "blue. i am an ai and i am unable to respond with emojis or additional context. i can only provide you with a simple response in text", + "two-plus-two": "4 the complete answer is 4. 2 + 2 is equal to 4. this basic arithmetic equation is a fundamental fact in mathematics, and it" + } + }, + "mistral": { + "exit": 0, + "case_id": "mistral-mlx-exact", + "prompt_outputs": { + "after-monday": "tuesday", + "alt-green": "green", + "alt-red": "the color you asked for is red", + "banana-color": "yellow", + "breathing-gas": "oxygen", + "capital-france": "paris", + "largest-planet": "jupiter", + "opposite-hot": "cold", + "primary-colors": "red, green, blue", + "primary": "blue", + "two-plus-two": "2 + 2 = 4" + } + }, + "qwen25": { + "exit": 0, + "case_id": "qwen25-mlx-exact", + "prompt_outputs": { + "after-monday": "tuesday", + "alt-green": "green", + "alt-red": "red", + "banana-color": "yellow", + "breathing-gas": "air", + "capital-france": "paris", + "largest-planet": "jupiter", + "opposite-hot": "cold", + "primary-colors": "red, green, blue", + "primary": "blue", + "two-plus-two": "2 + 2 = 4" + } + }, + "qwen3": { + "exit": 0, + "case_id": "qwen3-mlx-exact", + "prompt_outputs": { + "after-monday": "tuesday", + "alt-green": "green", + "alt-red": "red", + "banana-color": "yellow", + "breathing-gas": "oxygen", + "capital-france": "paris", + "largest-planet": "jupiter", + "opposite-hot": "cold", + "primary-colors": "red, green, blue", + "primary": "blue", + "two-plus-two": "4" + } + }, + "llama32": { + "exit": 0, + "case_id": "llama32-mlx-exact", + "prompt_outputs": { + "after-monday": "tuesday", + "alt-green": "green", + "alt-red": "red", + "banana-color": "yellow", + "breathing-gas": "oxygen", + "capital-france": "paris", + "largest-planet": "jupiter", + "opposite-hot": "cold", + "primary-colors": "red, green, blue", + "primary": "blue", + "two-plus-two": "2 + 2 = 4" + } + }, + "gemma2": { + "exit": 0, + "case_id": "gemma2-mlx-exact", + "prompt_outputs": { + "after-monday": "tuesday", + "alt-green": "green", + "alt-red": "red", + "banana-color": "yellow", + "breathing-gas": "oxygen", + "capital-france": "paris", + "largest-planet": "jupiter", + "opposite-hot": "cold", + "primary-colors": "red, green, blue", + "primary": "blue", + "two-plus-two": "2 + 2 = 4" + } + }, + "gemma3": { + "exit": 0, + "case_id": "gemma3-mlx-exact", + "prompt_outputs": { + "after-monday": "tuesday", + "alt-green": "green", + "alt-red": "red", + "banana-color": "yellow", + "breathing-gas": "oxygen", + "capital-france": "paris", + "largest-planet": "jupiter", + "opposite-hot": "cold", + "primary-colors": "red, green, blue", + "primary": "blue", + "two-plus-two": "2 + 2 = 4" + } + }, + "gemma4": { + "exit": 0, + "case_id": "gemma4-mlx-exact", + "prompt_outputs": { + "after-monday": "tuesday", + "alt-green": "green", + "alt-red": "red", + "banana-color": "yellow", + "breathing-gas": "oxygen", + "capital-france": "paris", + "largest-planet": "jupiter", + "opposite-hot": "cold", + "primary-colors": "red, green, blue", + "primary": "blue", + "two-plus-two": "4" + } + }, + "glm4": { + "exit": 0, + "case_id": "glm4-mlx-exact", + "prompt_outputs": { + "after-monday": "tuesday", + "alt-green": "green", + "alt-red": "red", + "banana-color": "yellow", + "breathing-gas": "oxygen", + "capital-france": "paris", + "largest-planet": "jupiter", + "opposite-hot": "cold", + "primary-colors": "red, green, blue", + "primary": "blue", + "two-plus-two": "4" + } + }, + "lfm2": { + "exit": 0, + "case_id": "lfm2-mlx-exact", + "prompt_outputs": { + "after-monday": "tuesday", + "alt-green": "green", + "alt-red": "red", + "banana-color": "yellow", + "breathing-gas": "oxygen", + "capital-france": "paris", + "largest-planet": "jupiter", + "opposite-hot": "cold", + "primary-colors": "yellow, red, green, blue, orange, purple, black, white, brown, gray, violet", + "primary": "blue", + "two-plus-two": "2 + 2 equals 4. so, the answer is 2 + 2 = 4" + } + } + } + }, + "behavior": { + "gguf": { + "olmo7b": { + "exit": 1, + "case_id": "olmo7b-gguf-behavior", + "failed_prompt_count": 16, + "prompt_count": 80, + "flagged_prompt_ids": [ + "roleplay#28596049", + "roleplay#2379839", + "reasoning#43568637", + "math#17722882", + "math#9170963", + "math#75477827", + "math#48989963", + "coding#46306085", + "coding#73474718", + "extraction#97281954", + "extraction#34830845", + "extraction#48344673", + "extraction#36918905", + "stem#46091167", + "stem#67710207", + "stem#25568812" + ] + }, + "mistral": { + "exit": 1, + "case_id": "mistral-gguf-behavior", + "failed_prompt_count": 10, + "prompt_count": 80, + "flagged_prompt_ids": [ + "reasoning#70728745", + "math#17722882", + "math#35264271", + "math#9170963", + "math#42252927", + "math#48989963", + "coding#34377376", + "coding#19673382", + "extraction#34830845", + "stem#67710207" + ] + }, + "qwen25": { + "exit": 1, + "case_id": "qwen25-gguf-behavior", + "failed_prompt_count": 24, + "prompt_count": 80, + "flagged_prompt_ids": [ + "writing#3075696", + "writing#38622539", + "roleplay#39413040", + "roleplay#84856299", + "roleplay#2379839", + "reasoning#43568637", + "reasoning#70728745", + "reasoning#90276646", + "math#82969860", + "math#17722882", + "math#9170963", + "math#27514636", + "math#42252927", + "math#48989963", + "math#92837357", + "coding#23836768", + "extraction#34830845", + "extraction#48344673", + "extraction#36918905", + "stem#46091167", + "stem#57239570", + "humanities#14125357", + "humanities#30655190", + "humanities#36620167" + ] + }, + "qwen3": { + "exit": 1, + "case_id": "qwen3-gguf-behavior", + "failed_prompt_count": 24, + "prompt_count": 80, + "flagged_prompt_ids": [ + "writing#16377588", + "roleplay#28596049", + "reasoning#40319087", + "reasoning#30350920", + "reasoning#28957987", + "math#22339468", + "math#17722882", + "math#35264271", + "math#9170963", + "math#27514636", + "math#42252927", + "math#75477827", + "math#48989963", + "math#92837357", + "coding#56561594", + "coding#3583619", + "coding#23836768", + "extraction#97281954", + "extraction#34830845", + "extraction#37403084", + "extraction#48344673", + "extraction#36918905", + "humanities#47292991", + "humanities#6404588" + ] + }, + "llama32": { + "exit": 1, + "case_id": "llama32-gguf-behavior", + "failed_prompt_count": 19, + "prompt_count": 80, + "flagged_prompt_ids": [ + "writing#49723273", + "roleplay#28596049", + "reasoning#45070360", + "reasoning#90276646", + "reasoning#89774116", + "math#17722882", + "math#35264271", + "math#9170963", + "math#48989963", + "coding#3583619", + "coding#65906068", + "coding#24433994", + "coding#19673382", + "coding#23836768", + "extraction#34830845", + "extraction#48344673", + "extraction#41972921", + "extraction#68335013", + "stem#25568812" + ] + }, + "gemma2": { + "exit": 1, + "case_id": "gemma2-gguf-behavior", + "failed_prompt_count": 4, + "prompt_count": 80, + "flagged_prompt_ids": [ + "coding#3583619", + "extraction#48344673", + "extraction#36918905", + "stem#67710207" + ] + }, + "gemma3": { + "exit": 1, + "case_id": "gemma3-gguf-behavior", + "failed_prompt_count": 10, + "prompt_count": 80, + "flagged_prompt_ids": [ + "reasoning#30350920", + "math#17722882", + "math#35264271", + "math#9170963", + "coding#56561594", + "coding#3583619", + "coding#34377376", + "coding#19673382", + "extraction#34830845", + "extraction#48344673" + ] + }, + "gemma4": { + "exit": 1, + "case_id": "gemma4-gguf-behavior", + "failed_prompt_count": 2, + "prompt_count": 80, + "flagged_prompt_ids": [ + "reasoning#40319087", + "extraction#96536319" + ] + }, + "glm4": { + "exit": 1, + "case_id": "glm4-gguf-behavior", + "failed_prompt_count": 8, + "prompt_count": 80, + "flagged_prompt_ids": [ + "math#22339468", + "math#17722882", + "math#9170963", + "math#27514636", + "coding#3583619", + "extraction#48344673", + "extraction#51413464", + "extraction#36918905" + ] + }, + "lfm2": { + "exit": 1, + "case_id": "lfm2-gguf-behavior", + "failed_prompt_count": 10, + "prompt_count": 80, + "flagged_prompt_ids": [ + "writing#38622539", + "reasoning#30350920", + "reasoning#90276646", + "math#35264271", + "math#9170963", + "coding#3583619", + "coding#19673382", + "extraction#48344673", + "extraction#42813505", + "extraction#36918905" + ] + } + }, + "mlx": { + "olmo7b": { + "exit": 0, + "case_id": "olmo7b-mlx-behavior", + "failed_prompt_count": 0, + "prompt_count": 80, + "flagged_prompt_ids": [] + }, + "mistral": { + "exit": 0, + "case_id": "mistral-mlx-behavior", + "failed_prompt_count": 0, + "prompt_count": 80, + "flagged_prompt_ids": [] + }, + "qwen25": { + "exit": 0, + "case_id": "qwen25-mlx-behavior", + "failed_prompt_count": 0, + "prompt_count": 80, + "flagged_prompt_ids": [] + }, + "qwen3": { + "exit": 0, + "case_id": "qwen3-mlx-behavior", + "failed_prompt_count": 0, + "prompt_count": 80, + "flagged_prompt_ids": [] + }, + "llama32": { + "exit": 0, + "case_id": "llama32-mlx-behavior", + "failed_prompt_count": 0, + "prompt_count": 80, + "flagged_prompt_ids": [] + }, + "gemma2": { + "exit": 0, + "case_id": "gemma2-mlx-behavior", + "failed_prompt_count": 0, + "prompt_count": 80, + "flagged_prompt_ids": [] + }, + "gemma3": { + "exit": 0, + "case_id": "gemma3-mlx-behavior", + "failed_prompt_count": 0, + "prompt_count": 80, + "flagged_prompt_ids": [] + }, + "gemma4": { + "exit": 0, + "case_id": "gemma4-mlx-behavior", + "failed_prompt_count": 0, + "prompt_count": 80, + "flagged_prompt_ids": [] + }, + "glm4": { + "exit": 0, + "case_id": "glm4-mlx-behavior", + "failed_prompt_count": 0, + "prompt_count": 80, + "flagged_prompt_ids": [] + }, + "lfm2": { + "exit": 0, + "case_id": "lfm2-mlx-behavior", + "failed_prompt_count": 0, + "prompt_count": 80, + "flagged_prompt_ids": [] + } + } + }, + "thinking": { + "gguf": { + "qwen3": { + "exit": 1, + "case_id": "qwen3-gguf-thinking", + "failed_check_count": 3, + "check_count": 3 + }, + "glm4": { + "exit": 0, + "case_id": "glm4-gguf-thinking", + "failed_check_count": 0, + "check_count": 3 + }, + "lfm2": { + "exit": 0, + "case_id": "lfm2-gguf-thinking", + "failed_check_count": 0, + "check_count": 3 + } + }, + "mlx": { + "qwen3": { + "exit": 0, + "case_id": "qwen3-mlx-thinking", + "failed_check_count": 0, + "check_count": 3 + }, + "glm4": { + "exit": 0, + "case_id": "glm4-mlx-thinking", + "failed_check_count": 0, + "check_count": 3 + }, + "lfm2": { + "exit": 0, + "case_id": "lfm2-mlx-thinking", + "failed_check_count": 0, + "check_count": 3 + } + } + } +} diff --git a/testdata/validation/matrix.json b/testdata/validation/matrix.json new file mode 100644 index 00000000..49679e9a --- /dev/null +++ b/testdata/validation/matrix.json @@ -0,0 +1,644 @@ +{ + "defaults": { + "exact": { + "prompt": "Reply with exactly: blue", + "expect_contains": "blue", + "expect_exact": "blue", + "forbid_contains": "", + "prompt_suite": [ + { + "label": "alt-green", + "prompt": "Reply with exactly: green", + "expect_exact": "green" + }, + { + "label": "alt-red", + "prompt": "Reply with exactly: red", + "expect_exact": "red" + }, + { + "label": "capital-france", + "prompt": "What is the capital of France? Reply with one word.", + "expect_contains_ci": "Paris" + }, + { + "label": "primary-colors", + "prompt": "List the RGB primary colors as full lowercase words only, comma-separated, with no abbreviations.", + "max_tokens": 24, + "expect_contains_all_ci": [ + "red", + "green", + "blue" + ] + }, + { + "label": "two-plus-two", + "prompt": "Complete exactly: 2 + 2 =", + "max_tokens": 32, + "expect_contains_ci": "4" + }, + { + "label": "largest-planet", + "prompt": "Name the largest planet in the Solar System. Reply with one word.", + "expect_contains_ci": "Jupiter" + }, + { + "label": "breathing-gas", + "prompt": "What gas do humans need to breathe to survive? Reply with one word.", + "max_tokens": 24, + "expect_any_ci": [ + "oxygen", + "o2", + "air" + ] + }, + { + "label": "opposite-hot", + "prompt": "What is the opposite of hot? Reply with one word.", + "expect_contains_ci": "cold" + }, + { + "label": "banana-color", + "prompt": "What color is a ripe banana? Reply with one word.", + "expect_contains_ci": "yellow" + }, + { + "label": "after-monday", + "prompt": "What day comes after Monday? Reply with one word.", + "expect_contains_ci": "Tuesday" + } + ] + }, + "behavior": { + "dataset": "HuggingFaceH4/mt_bench_prompts", + "max_tokens": 192, + "wait_seconds": 300 + }, + "thinking": { + "wait_seconds": 300, + "prompt_suite": [ + { + "label": "story-continue", + "max_tokens": 256, + "messages": [ + { + "role": "user", + "content": "Write one short sentence about a recent trip to Hawaii." + }, + { + "role": "assistant", + "content": "A recent trip to Hawaii began with warm rain and ocean light." + }, + { + "role": "user", + "content": "Continue with one more short sentence about what happened next." + } + ] + }, + { + "label": "math-reasoning", + "max_tokens": 320, + "messages": [ + { + "role": "user", + "content": "Solve 23 * 17. Think step by step, then give the final answer." + } + ] + }, + { + "label": "followup-rewrite", + "max_tokens": 192, + "messages": [ + { + "role": "user", + "content": "Write a three-item packing list for a beach trip." + }, + { + "role": "assistant", + "content": "1. Sunscreen\n2. Swimsuit\n3. Towel" + }, + { + "role": "user", + "content": "Now rewrite it as one short sentence instead of a list." + } + ] + } + ] + } + }, + "models": [ + { + "id": "olmo7b", + "label": "olmo-7b-instruct-hf", + "expectation_class": "weak-but-stable", + "notes": "Accepted same-origin OLMo parity pair; exact prompts tolerate shared verbosity and phrasing drift after HF-template whitespace fixes.", + "exact": { + "prompt": "Reply with exactly: blue", + "expect_contains": "", + "expect_contains_ci": "blue", + "expect_exact": "", + "forbid_contains": "", + "prompt_suite": [ + { + "label": "alt-green", + "prompt": "Reply with exactly: green", + "expect_contains_ci": "green" + }, + { + "label": "alt-red", + "prompt": "Reply with exactly: red", + "expect_contains_ci": "red" + }, + { + "label": "capital-france", + "prompt": "What is the capital of France? Reply with one word.", + "expect_contains_ci": "Paris" + }, + { + "label": "primary-colors", + "prompt": "List the RGB primary colors as full lowercase words only, comma-separated, with no abbreviations.", + "max_tokens": 24, + "expect_contains_all_ci": [ + "red", + "green", + "blue" + ] + }, + { + "label": "two-plus-two", + "prompt": "Complete exactly: 2 + 2 =", + "max_tokens": 32, + "expect_contains_ci": "4" + }, + { + "label": "largest-planet", + "prompt": "Name the largest planet in the Solar System. Reply with one word.", + "expect_contains_ci": "Jupiter" + }, + { + "label": "breathing-gas", + "prompt": "What gas do humans need to breathe to survive? Reply with one word.", + "max_tokens": 24, + "expect_any_ci": [ + "oxygen", + "o2", + "air" + ] + }, + { + "label": "opposite-hot", + "prompt": "What is the opposite of hot? Reply with one word.", + "expect_contains_ci": "cold" + }, + { + "label": "banana-color", + "prompt": "What color is a ripe banana? Reply with one word.", + "expect_contains_ci": "yellow" + }, + { + "label": "after-monday", + "prompt": "What day comes after Monday? Reply with one word.", + "expect_contains_ci": "Tuesday" + } + ] + }, + "gguf": { + "exact_case_id": "olmo7b-gguf-exact", + "behavior_case_id": "olmo7b-gguf-behavior", + "model_ref": "meshllm/olmo-7b-instruct-hf-parity-f16-gguf/olmo-7b-instruct-hf-f16.gguf" + }, + "mlx": { + "exact_case_id": "olmo7b-mlx-exact", + "behavior_case_id": "olmo7b-mlx-behavior", + "model_ref": "meshllm/olmo-7b-instruct-hf-parity-bf16-mlx", + "template_source": "chat_template.jinja" + } + }, + { + "id": "mistral", + "label": "mistral-7b-instruct-v0.3", + "expectation_class": "weak-but-stable", + "notes": "Accepted same-origin Mistral parity pair; exact color prompts tolerate shared capitalization and style drift.", + "exact": { + "prompt": "Reply with exactly: blue", + "expect_contains": "", + "expect_contains_ci": "blue", + "expect_exact": "", + "forbid_contains": "", + "prompt_suite": [ + { + "label": "alt-green", + "prompt": "Reply with exactly: green", + "expect_contains_ci": "green" + }, + { + "label": "alt-red", + "prompt": "Reply with exactly: red", + "expect_contains_ci": "red" + }, + { + "label": "capital-france", + "prompt": "What is the capital of France? Reply with one word.", + "expect_contains_ci": "Paris" + }, + { + "label": "primary-colors", + "prompt": "List the RGB primary colors as full lowercase words only, comma-separated, with no abbreviations.", + "max_tokens": 24, + "expect_contains_all_ci": [ + "red", + "green", + "blue" + ] + }, + { + "label": "two-plus-two", + "prompt": "Complete exactly: 2 + 2 =", + "max_tokens": 32, + "expect_contains_ci": "4" + }, + { + "label": "largest-planet", + "prompt": "Name the largest planet in the Solar System. Reply with one word.", + "expect_contains_ci": "Jupiter" + }, + { + "label": "breathing-gas", + "prompt": "What gas do humans need to breathe to survive? Reply with one word.", + "max_tokens": 24, + "expect_any_ci": [ + "oxygen", + "o2", + "air" + ] + }, + { + "label": "opposite-hot", + "prompt": "What is the opposite of hot? Reply with one word.", + "expect_contains_ci": "cold" + }, + { + "label": "banana-color", + "prompt": "What color is a ripe banana? Reply with one word.", + "expect_contains_ci": "yellow" + }, + { + "label": "after-monday", + "prompt": "What day comes after Monday? Reply with one word.", + "expect_contains_ci": "Tuesday" + } + ] + }, + "gguf": { + "exact_case_id": "mistral-gguf-exact", + "behavior_case_id": "mistral-gguf-behavior", + "model_ref": "meshllm/mistral-7b-instruct-v0.3-parity-f16-gguf/mistral-7b-instruct-v0.3-f16.gguf" + }, + "mlx": { + "exact_case_id": "mistral-mlx-exact", + "behavior_case_id": "mistral-mlx-behavior", + "model_ref": "meshllm/mistral-7b-instruct-v0.3-parity-bf16-mlx", + "template_source": "chat_template.jinja" + } + }, + { + "id": "qwen25", + "label": "qwen2.5-0.5b", + "expectation_class": "strict", + "notes": "Primary MLX parity canary for exact-answer drift.", + "gguf": { + "exact_case_id": "qwen25-gguf-exact", + "behavior_case_id": "qwen25-gguf-behavior", + "model_ref": "meshllm/qwen2.5-0.5b-instruct-parity-q8_0-gguf/qwen2.5-0.5b-instruct-q8_0.gguf" + }, + "mlx": { + "exact_case_id": "qwen25-mlx-exact", + "behavior_case_id": "qwen25-mlx-behavior", + "model_ref": "meshllm/qwen2.5-0.5b-instruct-parity-8bit-mlx", + "template_source": "chat_template.jinja" + } + }, + { + "id": "qwen3", + "label": "qwen3-8b", + "expectation_class": "strict", + "notes": "Same-origin Qwen3 parity row using Qwen3-8B; GGUF and MLX both satisfy the strict exact suite on the local same-origin pair.", + "thinking": { + "thinking_mode": "tagged", + "prompt_suite": [ + { + "label": "story-continue", + "max_tokens": 512, + "messages": [ + { + "role": "user", + "content": "Write one short sentence about a recent trip to Hawaii." + }, + { + "role": "assistant", + "content": "A recent trip to Hawaii began with warm rain and ocean light." + }, + { + "role": "user", + "content": "Continue with one more short sentence about what happened next." + } + ] + }, + { + "label": "math-reasoning", + "max_tokens": 640, + "messages": [ + { + "role": "user", + "content": "Solve 23 * 17. Think step by step, then give the final answer." + } + ] + }, + { + "label": "followup-rewrite", + "max_tokens": 384, + "messages": [ + { + "role": "user", + "content": "Write a three-item packing list for a beach trip." + }, + { + "role": "assistant", + "content": "1. Sunscreen\n2. Swimsuit\n3. Towel" + }, + { + "role": "user", + "content": "Now rewrite it as one short sentence instead of a list." + } + ] + } + ] + }, + "gguf": { + "exact_case_id": "qwen3-gguf-exact", + "behavior_case_id": "qwen3-gguf-behavior", + "thinking_case_id": "qwen3-gguf-thinking", + "model_ref": "meshllm/qwen3-8b-parity-q8_0-gguf/qwen3-8b-q8_0.gguf" + }, + "mlx": { + "exact_case_id": "qwen3-mlx-exact", + "behavior_case_id": "qwen3-mlx-behavior", + "thinking_case_id": "qwen3-mlx-thinking", + "model_ref": "meshllm/qwen3-8b-parity-8bit-mlx/model-00001-of-00002.safetensors", + "template_source": "chat_template.jinja" + } + }, + { + "id": "llama32", + "label": "llama-3.2-1b", + "expectation_class": "weak-but-stable", + "notes": "Tiny Llama parity canary; exact capitalization drift is tolerated as a known weakness.", + "exact": { + "prompt": "Reply with exactly: blue", + "expect_contains": "", + "expect_contains_ci": "blue", + "expect_exact": "", + "forbid_contains": "", + "prompt_suite": [ + { + "label": "alt-green", + "prompt": "Reply with exactly: green", + "expect_contains_ci": "green" + }, + { + "label": "alt-red", + "prompt": "Reply with exactly: red", + "expect_contains_ci": "red" + }, + { + "label": "capital-france", + "prompt": "What is the capital of France? Reply with one word.", + "expect_contains_ci": "Paris" + }, + { + "label": "primary-colors", + "prompt": "List the RGB primary colors as full lowercase words only, comma-separated, with no abbreviations.", + "max_tokens": 24, + "expect_contains_all_ci": [ + "red", + "green", + "blue" + ] + }, + { + "label": "two-plus-two", + "prompt": "Complete exactly: 2 + 2 =", + "max_tokens": 32, + "expect_contains_ci": "4" + }, + { + "label": "largest-planet", + "prompt": "Name the largest planet in the Solar System. Reply with one word.", + "expect_contains_ci": "Jupiter" + }, + { + "label": "breathing-gas", + "prompt": "What gas do humans need to breathe to survive? Reply with one word.", + "max_tokens": 24, + "expect_any_ci": [ + "oxygen", + "o2", + "air" + ] + }, + { + "label": "opposite-hot", + "prompt": "What is the opposite of hot? Reply with one word.", + "expect_contains_ci": "cold" + }, + { + "label": "banana-color", + "prompt": "What color is a ripe banana? Reply with one word.", + "expect_contains_ci": "yellow" + }, + { + "label": "after-monday", + "prompt": "What day comes after Monday? Reply with one word.", + "expect_contains_ci": "Tuesday" + } + ] + }, + "gguf": { + "exact_case_id": "llama32-gguf-exact", + "behavior_case_id": "llama32-gguf-behavior", + "model_ref": "meshllm/llama-3.2-1b-instruct-parity-f16-gguf/llama-3.2-1b-instruct-f16.gguf" + }, + "mlx": { + "exact_case_id": "llama32-mlx-exact", + "behavior_case_id": "llama32-mlx-behavior", + "model_ref": "meshllm/llama-3.2-1b-instruct-parity-bf16-mlx", + "template_source": "chat_template.jinja" + } + }, + { + "id": "gemma2", + "label": "gemma-2-2b", + "expectation_class": "strict", + "notes": "Same-origin Gemma 2 parity row using published meshllm Q8_0 vs 8bit artifacts.", + "gguf": { + "exact_case_id": "gemma2-gguf-exact", + "behavior_case_id": "gemma2-gguf-behavior", + "model_ref": "meshllm/gemma-2-2b-it-parity-q8_0-gguf/gemma-2-2b-it-q8_0.gguf" + }, + "mlx": { + "exact_case_id": "gemma2-mlx-exact", + "behavior_case_id": "gemma2-mlx-behavior", + "model_ref": "meshllm/gemma-2-2b-it-parity-8bit-mlx", + "template_source": "chat_template.jinja" + } + }, + { + "id": "gemma3", + "label": "gemma-3-1b", + "expectation_class": "strict", + "notes": "Gemma3 parity baseline using same-origin high-fidelity artifacts published by meshllm.", + "gguf": { + "exact_case_id": "gemma3-gguf-exact", + "behavior_case_id": "gemma3-gguf-behavior", + "model_ref": "meshllm/gemma-3-1b-it-parity-f16-gguf/gemma-3-1b-it-f16.gguf" + }, + "mlx": { + "exact_case_id": "gemma3-mlx-exact", + "behavior_case_id": "gemma3-mlx-behavior", + "model_ref": "meshllm/gemma-3-1b-it-parity-bf16-mlx", + "template_source": "tokenizer_config.json" + } + }, + { + "id": "gemma4", + "label": "gemma-4-e4b", + "expectation_class": "strict", + "notes": "Gemma4 parity baseline using same-origin 8bit/Q8_0 artifacts published by meshllm.", + "gguf": { + "exact_case_id": "gemma4-gguf-exact", + "behavior_case_id": "gemma4-gguf-behavior", + "model_ref": "meshllm/gemma-4-e4b-it-parity-q8_0-gguf/gemma-4-e4b-it-q8_0.gguf" + }, + "mlx": { + "exact_case_id": "gemma4-mlx-exact", + "behavior_case_id": "gemma4-mlx-behavior", + "model_ref": "meshllm/gemma-4-e4b-it-parity-8bit-mlx", + "template_source": "chat_template.jinja" + } + }, + { + "id": "glm4", + "label": "glm-4-9b", + "expectation_class": "strict", + "notes": "Published same-origin GLM4 parity pair.", + "thinking": { + "thinking_mode": "nonempty" + }, + "gguf": { + "exact_case_id": "glm4-gguf-exact", + "behavior_case_id": "glm4-gguf-behavior", + "thinking_case_id": "glm4-gguf-thinking", + "model_ref": "meshllm/glm-4-9b-0414-parity-q4_k_m-gguf/glm-4-9b-0414-q4_k_m.gguf" + }, + "mlx": { + "exact_case_id": "glm4-mlx-exact", + "behavior_case_id": "glm4-mlx-behavior", + "thinking_case_id": "glm4-mlx-thinking", + "model_ref": "meshllm/glm-4-9b-0414-parity-4bit-mlx", + "template_source": "chat_template.jinja" + } + }, + { + "id": "lfm2", + "label": "lfm2-350m", + "expectation_class": "weak-but-stable", + "notes": "Tiny-model shared weakness canary for exact and behavior drift.", + "exact": { + "prompt": "Reply with exactly: blue", + "expect_contains": "", + "expect_contains_ci": "blue", + "expect_exact": "", + "forbid_contains": "", + "prompt_suite": [ + { + "label": "alt-green", + "prompt": "Reply with exactly: green", + "expect_contains_ci": "green" + }, + { + "label": "alt-red", + "prompt": "Reply with exactly: red", + "expect_contains_ci": "red" + }, + { + "label": "capital-france", + "prompt": "What is the capital of France? Reply with one word.", + "expect_contains_ci": "Paris" + }, + { + "label": "primary-colors", + "prompt": "List the RGB primary colors as full lowercase words only, comma-separated, with no abbreviations.", + "max_tokens": 24, + "expect_contains_all_ci": [ + "red", + "green", + "blue" + ] + }, + { + "label": "two-plus-two", + "prompt": "Complete exactly: 2 + 2 =", + "max_tokens": 32, + "expect_contains_ci": "4" + }, + { + "label": "largest-planet", + "prompt": "Name the largest planet in the Solar System. Reply with one word.", + "expect_contains_ci": "Jupiter" + }, + { + "label": "breathing-gas", + "prompt": "What gas do humans need to breathe to survive? Reply with one word.", + "max_tokens": 24, + "expect_any_ci": [ + "oxygen", + "o2", + "air" + ] + }, + { + "label": "opposite-hot", + "prompt": "What is the opposite of hot? Reply with one word.", + "expect_contains_ci": "cold" + }, + { + "label": "banana-color", + "prompt": "What color is a ripe banana? Reply with one word.", + "expect_contains_ci": "yellow" + }, + { + "label": "after-monday", + "prompt": "What day comes after Monday? Reply with one word.", + "expect_contains_ci": "Tuesday" + } + ] + }, + "thinking": { + "thinking_mode": "nonempty" + }, + "gguf": { + "exact_case_id": "lfm2-gguf-exact", + "behavior_case_id": "lfm2-gguf-behavior", + "thinking_case_id": "lfm2-gguf-thinking", + "model_ref": "meshllm/lfm2-350m-parity-q4_k_m-gguf/lfm2-350m-q4_k_m.gguf" + }, + "mlx": { + "exact_case_id": "lfm2-mlx-exact", + "behavior_case_id": "lfm2-mlx-behavior", + "thinking_case_id": "lfm2-mlx-thinking", + "model_ref": "meshllm/lfm2-350m-parity-4bit-mlx", + "template_source": "chat_template.jinja" + } + } + ] +}