Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ members = [
"examples/stable-diffusion",
"examples/stable-diffusion-3",
"examples/stable-diffusion-xl",
"examples/wasm-model-bench",

"harness/core-proptest-pulse",
"harness/nnef-inceptionv3",
Expand Down
21 changes: 21 additions & 0 deletions examples/wasm-model-bench/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Bench crate for end-to-end WASM model timing: a shared harness library plus
# two CLI entry points (ONNX and NNEF loaders).
[package]
name = "wasm-model-bench"
version = "0.1.0"
license = "MIT OR Apache-2.0"
edition = "2024"

# Shared harness: zero-input building, timing loop, stats and quality output.
[lib]
path = "src/lib.rs"

# Bench an ONNX model (supports shape specs and symbol substitution).
[[bin]]
name = "bench-onnx"
path = "src/bench_onnx.rs"

# Bench an NNEF model archive.
[[bin]]
name = "bench-nnef"
path = "src/bench_nnef.rs"

[dependencies]
anyhow.workspace = true
tract-nnef.workspace = true
tract-onnx.workspace = true
Binary file not shown.
29 changes: 29 additions & 0 deletions examples/wasm-model-bench/src/bench_nnef.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
//! Bench an NNEF model. Loads via tract-nnef, runs N timed reps, reports
//! min/median/max/spread.

use anyhow::Result;
use tract_nnef::prelude::*;

/// CLI entry point: load an NNEF archive, run the shared bench harness,
/// optionally fingerprint the outputs (TRACT_BENCH_QUALITY=1).
fn main() -> Result<()> {
    // Positional args after the program name: path, then three optional counts.
    let argv: Vec<String> = std::env::args().skip(1).collect();
    let model_path = argv.first().cloned().ok_or_else(|| {
        anyhow::anyhow!("usage: bench-nnef <model.nnef.tgz> [warmup] [timed] [reps]")
    })?;
    // Absent args fall back to a default; malformed args are a hard error.
    let numeric_arg = |idx: usize, default: usize| -> Result<usize> {
        match argv.get(idx) {
            Some(raw) => Ok(raw.parse()?),
            None => Ok(default),
        }
    };
    let warmup_iters = numeric_arg(1, 20)?;
    let timed_iters = numeric_arg(2, 50)?;
    let repetitions = numeric_arg(3, 10)?;

    let model = tract_nnef::nnef()
        .model_for_path(&model_path)?
        .into_optimized()?
        .into_runnable()?;

    let inputs = wasm_model_bench::build_zero_inputs(&model)?;

    let samples =
        wasm_model_bench::run_bench(&model, &inputs, warmup_iters, timed_iters, repetitions)?;
    wasm_model_bench::print_stats(&model_path, &samples);
    // Opt-in output fingerprinting, for comparing builds against each other.
    if std::env::var("TRACT_BENCH_QUALITY").ok().as_deref() == Some("1") {
        wasm_model_bench::run_quality_check(&model, &inputs)?;
    }
    Ok(())
}
87 changes: 87 additions & 0 deletions examples/wasm-model-bench/src/bench_onnx.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
//! Bench an ONNX model.
//!
//! Usage: bench-onnx <model.onnx> [<shape_spec>|-] [warmup] [timed] [reps]
//!
//! <shape_spec> can be a single per-input fact like "1,3,224,224,f32" or
//! multi-input with ";" separators: "1,1,100,32,f32;1,2,100,96,f32".
//! Use "-" or empty to skip override for that input.
//!
//! Optional env var TRACT_BENCH_SYMBOLS="S=100,BATCH=1" applies global
//! symbol concretization for models whose internal nodes still reference
//! symbols after input facts are set.

use anyhow::{Result, anyhow};
use tract_onnx::prelude::*;

/// Parse a per-input fact spec of the form "d0,d1,...,dtype"
/// (e.g. "1,3,224,224,f32") into an `InferenceFact`.
fn parse_fact_spec(spec: &str) -> Result<InferenceFact> {
    let fields: Vec<&str> = spec.split(',').collect();
    if fields.len() < 2 {
        anyhow::bail!("shape spec needs at least 'shape,dtype': {spec}");
    }
    // Last field is the dtype token, everything before it is the shape.
    let (dtype_field, dim_fields) = fields.split_last().expect("len checked above");
    let dt = match *dtype_field {
        "f32" => DatumType::F32,
        "f16" => DatumType::F16,
        "i64" => DatumType::I64,
        "i32" => DatumType::I32,
        "u8" => DatumType::U8,
        other => return Err(anyhow!("unsupported dtype: {other}")),
    };
    let mut shape = Vec::with_capacity(dim_fields.len());
    for field in dim_fields {
        shape.push(field.parse::<usize>()?);
    }
    Ok(InferenceFact::dt_shape(dt, shape))
}

/// CLI entry point: load an ONNX model, optionally concretize its shapes
/// (via a shape spec or TRACT_BENCH_SYMBOLS), then run the shared bench
/// harness and, optionally, the output quality fingerprint.
fn main() -> Result<()> {
    let mut args = std::env::args();
    let _prog = args.next();
    // First positional arg: path to the .onnx file (mandatory).
    let model_path = args.next().ok_or_else(|| {
        anyhow!("usage: bench-onnx <model.onnx> [<shape_spec>|-] [warmup] [timed] [reps]")
    })?;
    // Second positional arg: optional per-input shape override ("-" = skip).
    let shape_spec = args.next();
    // Iteration counts: absent args use defaults, malformed args propagate.
    let warmup_iters: usize = args.next().map(|s| s.parse()).transpose()?.unwrap_or(20);
    let timed_iters: usize = args.next().map(|s| s.parse()).transpose()?.unwrap_or(50);
    let repetitions: usize = args.next().map(|s| s.parse()).transpose()?.unwrap_or(10);

    let mut model = tract_onnx::onnx().model_for_path(&model_path)?;
    // An empty TRACT_BENCH_SYMBOLS is treated the same as unset.
    let symbols_env = std::env::var("TRACT_BENCH_SYMBOLS").ok().filter(|s| !s.is_empty());

    // When symbols are provided, the model is symbolic-shaped and we go straight
    // to TypedModel → substitute_symbols, ignoring shape_spec (input shapes will
    // be derived from the symbol substitution). When no symbols, we use the
    // shape_spec to pin input facts on the InferenceModel before into_typed.
    let typed = if let Some(symbols_str) = symbols_env {
        let typed = model.into_typed()?;
        // Build a symbol → concrete-dim substitution table from "NAME=INT" pairs.
        let mut subs = std::collections::HashMap::new();
        for kv in symbols_str.split(',') {
            let (k, v) = kv.split_once('=').ok_or_else(|| anyhow!("bad symbol: {kv}"))?;
            // sym() resolves the name in this model's symbol scope, so the key
            // matches the symbol instances referenced inside the graph.
            let sym = typed.symbols.sym(k);
            subs.insert(sym, TDim::Val(v.parse::<i64>()?));
        }
        typed.substitute_symbols(&subs)?
    } else {
        if let Some(spec) = shape_spec.as_deref().filter(|s| *s != "-") {
            // One ";"-separated fact per model input, positionally matched;
            // "-" or an empty entry keeps that input's declared fact.
            for (i, one) in spec.split(';').enumerate() {
                if one.is_empty() || one == "-" {
                    continue;
                }
                let fact = parse_fact_spec(one)?;
                model.set_input_fact(i, fact)?;
            }
        }
        model.into_typed()?
    };
    let model = typed.into_optimized()?.into_runnable()?;

    let inputs = wasm_model_bench::build_zero_inputs(&model)?;

    let samples =
        wasm_model_bench::run_bench(&model, &inputs, warmup_iters, timed_iters, repetitions)?;
    wasm_model_bench::print_stats(&model_path, &samples);
    // Opt-in output fingerprinting, for comparing builds against each other.
    if std::env::var("TRACT_BENCH_QUALITY").ok().as_deref() == Some("1") {
        wasm_model_bench::run_quality_check(&model, &inputs)?;
    }
    Ok(())
}
99 changes: 99 additions & 0 deletions examples/wasm-model-bench/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
//! Shared bench harness for WASM E2E model timing.

use anyhow::Result;
use std::sync::Arc;
use std::time::Instant;
use tract_nnef::internal::DimLike;
use tract_nnef::prelude::*;

pub type Runnable = Arc<TypedSimplePlan>;

/// Build one all-zeros tensor per model input, matching each input's
/// concrete shape and datum type. Requires fully-concrete input shapes
/// (`to_usize` fails on symbolic dims).
pub fn build_zero_inputs(model: &Runnable) -> Result<TVec<TValue>> {
    let graph = model.model();
    let mut zeros = tvec![];
    for outlet in graph.input_outlets()? {
        let fact = graph.outlet_fact(*outlet)?;
        let dims = fact
            .shape
            .iter()
            .map(|d| d.to_usize())
            .collect::<TractResult<Vec<usize>>>()?;
        zeros.push(Tensor::zero_dt(fact.datum_type, &dims)?.into_tvalue());
    }
    Ok(zeros)
}

/// Time `model.run` over (clones of) `inputs`.
///
/// Performs `warmup_iters` untimed calls first, then `repetitions` timed
/// batches of `timed_iters` calls each. Returns one sample per repetition —
/// nanoseconds per inference — sorted ascending so callers can index
/// min/median/max directly.
pub fn run_bench(
    model: &Runnable,
    inputs: &TVec<TValue>,
    warmup_iters: usize,
    timed_iters: usize,
    repetitions: usize,
) -> Result<Vec<f64>> {
    // Warmup: results discarded, only side effect is priming caches/tiers.
    for _ in 0..warmup_iters {
        let _ = model.run(inputs.clone())?;
    }

    let mut samples = Vec::with_capacity(repetitions);
    for _ in 0..repetitions {
        let t0 = Instant::now();
        for _ in 0..timed_iters {
            let _ = model.run(inputs.clone())?;
        }
        let elapsed = t0.elapsed();
        // Average over the batch, converted to ns per single inference.
        let ns_per_call = elapsed.as_secs_f64() / timed_iters as f64 * 1e9;
        samples.push(ns_per_call);
    }
    // total_cmp is a total order on f64 — unlike partial_cmp().unwrap(),
    // it cannot panic if a sample were ever NaN.
    samples.sort_by(|a, b| a.total_cmp(b));
    Ok(samples)
}

/// Print quality metrics for the model's output(s). Use after bench. With
/// fixed-shape models and a deterministic input (zeros), running this with
/// baseline vs relaxed-simd builds and comparing outputs is the quality
/// regression check (FMA gives ~1 ulp drift; mul+add is bit-stable).
///
/// Emits, per F32 output: L2 norm, mean, an XOR-of-bit-patterns checksum,
/// and the first/last five elements. Non-F32 outputs are listed but skipped.
pub fn run_quality_check(model: &Runnable, inputs: &TVec<TValue>) -> Result<()> {
    let outputs = model.run(inputs.clone())?;
    for (i, out) in outputs.iter().enumerate() {
        let dt = out.datum_type();
        let shape = out.shape();
        if dt == DatumType::F32 {
            // Checked typed view: the datum type was verified just above, so
            // this cannot fail — and it avoids an unsafe block entirely.
            let slice: &[f32] = out.as_slice::<f32>()?;
            let n = slice.len();
            let l2: f64 = slice.iter().map(|x| (*x as f64) * (*x as f64)).sum::<f64>().sqrt();
            let mean: f64 = slice.iter().map(|x| *x as f64).sum::<f64>() / n as f64;
            let preview: Vec<f32> = slice.iter().take(5).copied().collect();
            // Last up-to-5 elements, kept in their original order.
            let last5: Vec<f32> = slice[n.saturating_sub(5)..].to_vec();
            // Cheap deterministic checksum: XOR of all bit-patterns
            let mut xor: u32 = 0;
            for x in slice {
                xor ^= x.to_bits();
            }
            eprintln!(
                " output[{i}] shape={shape:?} dt=F32 n={n} L2={l2:.6e} mean={mean:.6e} xor=0x{xor:08x} first5={preview:?} last5={last5:?}"
            );
        } else {
            eprintln!(" output[{i}] shape={shape:?} dt={dt:?} (skip non-F32 stats)");
        }
    }
    Ok(())
}

/// Print a one-line latency summary (min/median/max in ns, percent spread)
/// for `samples`, tagged with the compile-time SIMD flavor of this binary.
///
/// `samples` must be sorted ascending (as returned by `run_bench`). An empty
/// slice is reported as such instead of panicking on `samples[0]`.
pub fn print_stats(label: &str, samples: &[f64]) {
    // Identify which emit mode this binary was compiled with so logs from
    // baseline and relaxed-simd runs are distinguishable side by side.
    let target = if cfg!(target_feature = "relaxed-simd") {
        "+relaxed-simd (FMA)"
    } else if cfg!(target_family = "wasm") {
        "+simd128 only (mul+add)"
    } else {
        "native"
    };
    if samples.is_empty() {
        eprintln!("[{target}] {label}: no samples");
        return;
    }
    let min = samples[0];
    let median = samples[samples.len() / 2];
    let max = samples[samples.len() - 1];
    let pct_spread = (max - min) / min * 100.0;
    eprintln!(
        "[{target}] {label}: min={min:.0} median={median:.0} max={max:.0} ns/inference (spread {pct_spread:.0}%, n={})",
        samples.len()
    );
}
4 changes: 4 additions & 0 deletions linalg/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,7 @@ harness = false
[[bench]]
name = "avx512_zombies"
harness = false

[[bench]]
name = "wasm"
harness = false
105 changes: 105 additions & 0 deletions linalg/WASM_RELAXED_SIMD.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# tract-linalg on `wasm32` — relaxed-simd FMA

The WASM MMM kernels (`wasm_f32_4x4`, `4x1`, `8x1`, `16x1`, `32x1`, `8x8`)
and the WASM sigmoid/tanh activations all flip between two emit modes at
compile time, gated on `cfg(target_feature = "relaxed-simd")`:

- **Without** `+relaxed-simd`: pure `f32x4_add(_, f32x4_mul(_, _))` (mul+add).
Runs on any WASM runtime that supports `simd128`.
- **With** `+relaxed-simd`: `f32x4_relaxed_madd(_, _, _)`. Fused, single-rounded
multiply-add on hosts whose CPU has hardware FMA (all ARM64, x86_64 + FMA3).
Universal browser/runtime support since 2023 (Chrome 114+, Firefox 120+,
Safari 17+, wasmtime 16+).

The speedup of the relaxed path over the baseline is typically **1.40–1.55× at
the kernel level** and **1.08–1.46× end-to-end** across vision CNNs,
transformer attention and RNN audio models. Bit-pattern drift versus the
mul+add path is bounded at one ulp (FMA single-rounding); within
`Approximation::Close` (1e-4).

## Build flags

```sh
# Baseline (any wasm32 runtime supporting simd128)
RUSTFLAGS='-C target-feature=+simd128' \
cargo build --release --target wasm32-wasip1 -p tract-linalg

# Relaxed (requires host support for relaxed-simd; ~1.40× faster on FMA-capable hosts)
RUSTFLAGS='-C target-feature=+simd128,+relaxed-simd' \
cargo build --release --target wasm32-wasip1 -p tract-linalg
```

Same on `wasm32-unknown-unknown` if shipping for the browser.

## Why two binaries (and not in-process runtime dispatch)

WASM validates the entire module at instantiation, before any code runs.
A binary containing `f32x4.relaxed_madd` fails to instantiate on hosts without
relaxed-simd — `LinkError` / `CompileError`, not a runtime trap. So the
x86/ARM pattern (one binary, both paths in source, runtime CPU detection picks
at execution time) cannot be replicated in-binary on WASM: the FMA opcodes are
either present (and host support is required) or absent.

Runtime dispatch happens one layer up — at the host runtime / consumer layer
— by selecting the correct binary at module-load time.

## Consumer-side dispatch

### Browser / `WebAssembly.validate`

```js
async function loadTract(baseUrl) {
const candidate = await fetch(`${baseUrl}/tract-relaxed.wasm`);
const bytes = await candidate.arrayBuffer();

const wantRelaxed = WebAssembly.validate(bytes, {
builtins: ['relaxed_simd'],
});

const url = wantRelaxed
? `${baseUrl}/tract-relaxed.wasm`
: `${baseUrl}/tract.wasm`;

const final = await fetch(url);
return WebAssembly.instantiateStreaming(final);
}
```

Fallback for hosts without the `WebAssembly.validate(bytes, { ... })`
options-arg: try-instantiate the relaxed binary, catch `LinkError` /
`CompileError`, retry with the baseline.

### `wasmtime` (server / native)

```rust
use wasmtime::{Config, Engine};

let mut config = Config::new();
config.wasm_relaxed_simd(true); // gate on host-CPU detection if needed
let engine = Engine::new(&config)?;

let bytes = std::fs::read(if relaxed_supported {
"tract-relaxed.wasm"
} else {
"tract.wasm"
})?;
let module = wasmtime::Module::new(&engine, &bytes)?;
```

`wasmtime::Engine`'s `wasm_relaxed_simd` configures the runtime; a separate
`wasmtime::Module::validate()` call against the engine is the equivalent of
the browser's `WebAssembly.validate` for picking which binary to load.

## Quality

The two binaries are **not bit-identical**. FMA's single-rounding produces
≤1 ulp drift from explicit mul+add. Verified end-to-end on Inception v3 and
DFN3 sub-models:

| model | output shape | baseline L2 | relaxed L2 |
|--------------|--------------------|-------------:|-------------:|
| Inception v3 | [1, 1001] | 6.477089e-2 | 6.477089e-2 |
| DFN3 df_dec | [1, 100, 96, 10] | 1.080686e-2 | 1.080686e-2 |

The printed L2 norms agree to all 7 significant figures shown; per-element
values diverge only in the 7th–8th decimal place. Within tract's
`Approximation::Close` (1e-4).
Loading