diff --git a/Cargo.toml b/Cargo.toml index f52166307f..11682edf1d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ members = [ "examples/stable-diffusion", "examples/stable-diffusion-3", "examples/stable-diffusion-xl", + "examples/wasm-model-bench", "harness/core-proptest-pulse", "harness/nnef-inceptionv3", diff --git a/examples/wasm-model-bench/Cargo.toml b/examples/wasm-model-bench/Cargo.toml new file mode 100644 index 0000000000..0a05e96a64 --- /dev/null +++ b/examples/wasm-model-bench/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "wasm-model-bench" +version = "0.1.0" +license = "MIT OR Apache-2.0" +edition = "2024" + +[lib] +path = "src/lib.rs" + +[[bin]] +name = "bench-onnx" +path = "src/bench_onnx.rs" + +[[bin]] +name = "bench-nnef" +path = "src/bench_nnef.rs" + +[dependencies] +anyhow.workspace = true +tract-nnef.workspace = true +tract-onnx.workspace = true diff --git a/examples/wasm-model-bench/MMM_MACRO_ATTRIBUTION.pdf b/examples/wasm-model-bench/MMM_MACRO_ATTRIBUTION.pdf new file mode 100644 index 0000000000..5d98322303 Binary files /dev/null and b/examples/wasm-model-bench/MMM_MACRO_ATTRIBUTION.pdf differ diff --git a/examples/wasm-model-bench/src/bench_nnef.rs b/examples/wasm-model-bench/src/bench_nnef.rs new file mode 100644 index 0000000000..c42a1ad871 --- /dev/null +++ b/examples/wasm-model-bench/src/bench_nnef.rs @@ -0,0 +1,29 @@ +//! Bench an NNEF model. Loads via tract-nnef, runs N timed reps, reports +//! min/median/max/spread. 
+ +use anyhow::Result; +use tract_nnef::prelude::*; + +fn main() -> Result<()> { + let mut args = std::env::args(); + let _prog = args.next(); + let model_path = args.next().ok_or_else(|| { + anyhow::anyhow!("usage: bench-nnef [warmup] [timed] [reps]") + })?; + let warmup_iters: usize = args.next().map(|s| s.parse()).transpose()?.unwrap_or(20); + let timed_iters: usize = args.next().map(|s| s.parse()).transpose()?.unwrap_or(50); + let repetitions: usize = args.next().map(|s| s.parse()).transpose()?.unwrap_or(10); + + let model = + tract_nnef::nnef().model_for_path(&model_path)?.into_optimized()?.into_runnable()?; + + let inputs = wasm_model_bench::build_zero_inputs(&model)?; + + let samples = + wasm_model_bench::run_bench(&model, &inputs, warmup_iters, timed_iters, repetitions)?; + wasm_model_bench::print_stats(&model_path, &samples); + if std::env::var("TRACT_BENCH_QUALITY").ok().as_deref() == Some("1") { + wasm_model_bench::run_quality_check(&model, &inputs)?; + } + Ok(()) +} diff --git a/examples/wasm-model-bench/src/bench_onnx.rs b/examples/wasm-model-bench/src/bench_onnx.rs new file mode 100644 index 0000000000..906d35154b --- /dev/null +++ b/examples/wasm-model-bench/src/bench_onnx.rs @@ -0,0 +1,87 @@ +//! Bench an ONNX model. +//! +//! Usage: bench-onnx MODEL.onnx [SHAPE_SPEC|-] [warmup] [timed] [reps] +//! +//! SHAPE_SPEC can be a single per-input fact like "1,3,224,224,f32" or +//! multi-input with ";" separators: "1,1,100,32,f32;1,2,100,96,f32". +//! Use "-" or empty to skip override for that input. +//! +//! Optional env var TRACT_BENCH_SYMBOLS="S=100,BATCH=1" concretizes the +//! model's symbols instead; when it is set the shape spec is ignored and +//! input shapes are derived from the symbol substitution. 
+ +use anyhow::{Result, anyhow}; +use tract_onnx::prelude::*; + +fn parse_fact_spec(spec: &str) -> Result { + let parts: Vec<&str> = spec.split(',').collect(); + if parts.len() < 2 { + anyhow::bail!("shape spec needs at least 'shape,dtype': {spec}"); + } + let dt_str = parts.last().unwrap(); + let dt = match *dt_str { + "f32" => DatumType::F32, + "f16" => DatumType::F16, + "i64" => DatumType::I64, + "i32" => DatumType::I32, + "u8" => DatumType::U8, + s => return Err(anyhow!("unsupported dtype: {s}")), + }; + let shape: Vec = parts[..parts.len() - 1] + .iter() + .map(|s| s.parse::().map_err(anyhow::Error::from)) + .collect::>()?; + Ok(InferenceFact::dt_shape(dt, shape)) +} + +fn main() -> Result<()> { + let mut args = std::env::args(); + let _prog = args.next(); + let model_path = args.next().ok_or_else(|| { + anyhow!("usage: bench-onnx [|-] [warmup] [timed] [reps]") + })?; + let shape_spec = args.next(); + let warmup_iters: usize = args.next().map(|s| s.parse()).transpose()?.unwrap_or(20); + let timed_iters: usize = args.next().map(|s| s.parse()).transpose()?.unwrap_or(50); + let repetitions: usize = args.next().map(|s| s.parse()).transpose()?.unwrap_or(10); + + let mut model = tract_onnx::onnx().model_for_path(&model_path)?; + let symbols_env = std::env::var("TRACT_BENCH_SYMBOLS").ok().filter(|s| !s.is_empty()); + + // When symbols are provided, the model is symbolic-shaped and we go straight + // to TypedModel → substitute_symbols, ignoring shape_spec (input shapes will + // be derived from the symbol substitution). When no symbols, we use the + // shape_spec to pin input facts on the InferenceModel before into_typed. 
+ let typed = if let Some(symbols_str) = symbols_env { + let typed = model.into_typed()?; + let mut subs = std::collections::HashMap::new(); + for kv in symbols_str.split(',') { + let (k, v) = kv.split_once('=').ok_or_else(|| anyhow!("bad symbol: {kv}"))?; + let sym = typed.symbols.sym(k); + subs.insert(sym, TDim::Val(v.parse::()?)); + } + typed.substitute_symbols(&subs)? + } else { + if let Some(spec) = shape_spec.as_deref().filter(|s| *s != "-") { + for (i, one) in spec.split(';').enumerate() { + if one.is_empty() || one == "-" { + continue; + } + let fact = parse_fact_spec(one)?; + model.set_input_fact(i, fact)?; + } + } + model.into_typed()? + }; + let model = typed.into_optimized()?.into_runnable()?; + + let inputs = wasm_model_bench::build_zero_inputs(&model)?; + + let samples = + wasm_model_bench::run_bench(&model, &inputs, warmup_iters, timed_iters, repetitions)?; + wasm_model_bench::print_stats(&model_path, &samples); + if std::env::var("TRACT_BENCH_QUALITY").ok().as_deref() == Some("1") { + wasm_model_bench::run_quality_check(&model, &inputs)?; + } + Ok(()) +} diff --git a/examples/wasm-model-bench/src/lib.rs b/examples/wasm-model-bench/src/lib.rs new file mode 100644 index 0000000000..9500c2329e --- /dev/null +++ b/examples/wasm-model-bench/src/lib.rs @@ -0,0 +1,99 @@ +//! Shared bench harness for WASM E2E model timing. 
+ +use anyhow::Result; +use std::sync::Arc; +use std::time::Instant; +use tract_nnef::internal::DimLike; +use tract_nnef::prelude::*; + +pub type Runnable = Arc; + +pub fn build_zero_inputs(model: &Runnable) -> Result> { + let mut inputs = tvec![]; + let typed = model.model(); + for &outlet in typed.input_outlets()?.iter() { + let fact = typed.outlet_fact(outlet)?; + let shape: Vec = + fact.shape.iter().map(|d| d.to_usize()).collect::>>()?; + let dt = fact.datum_type; + let tensor = Tensor::zero_dt(dt, &shape)?; + inputs.push(tensor.into_tvalue()); + } + Ok(inputs) +} + +pub fn run_bench( + model: &Runnable, + inputs: &TVec, + warmup_iters: usize, + timed_iters: usize, + repetitions: usize, +) -> Result> { + for _ in 0..warmup_iters { + let _ = model.run(inputs.clone())?; + } + + let mut samples = Vec::with_capacity(repetitions); + for _ in 0..repetitions { + let t0 = Instant::now(); + for _ in 0..timed_iters { + let _ = model.run(inputs.clone())?; + } + let elapsed = t0.elapsed(); + let ns_per_call = elapsed.as_secs_f64() / timed_iters as f64 * 1e9; + samples.push(ns_per_call); + } + samples.sort_by(|a, b| a.partial_cmp(b).unwrap()); + Ok(samples) +} + +/// Print quality metrics for the model's output(s). Use after bench. With +/// fixed-shape models and a deterministic input (zeros), running this with +/// baseline vs relaxed-simd builds and comparing outputs is the quality +/// regression check (FMA gives ~1 ulp drift; mul+add is bit-stable). 
+pub fn run_quality_check(model: &Runnable, inputs: &TVec) -> Result<()> { + let outputs = model.run(inputs.clone())?; + for (i, out) in outputs.iter().enumerate() { + let dt = out.datum_type(); + let shape = out.shape(); + if dt == DatumType::F32 { + let tensor: &Tensor = &*out; + let slice: &[f32] = unsafe { tensor.as_slice_unchecked::() }; + let n = slice.len(); + let l2: f64 = slice.iter().map(|x| (*x as f64) * (*x as f64)).sum::().sqrt(); + let mean: f64 = slice.iter().map(|x| *x as f64).sum::() / n as f64; + let preview: Vec = slice.iter().take(5).copied().collect(); + let last5: Vec = + slice.iter().rev().take(5).copied().collect::>().into_iter().rev().collect(); + // Cheap deterministic checksum: XOR of all bit-patterns + let mut xor: u32 = 0; + for x in slice { + xor ^= x.to_bits(); + } + eprintln!( + " output[{i}] shape={shape:?} dt=F32 n={n} L2={l2:.6e} mean={mean:.6e} xor=0x{xor:08x} first5={preview:?} last5={last5:?}" + ); + } else { + eprintln!(" output[{i}] shape={shape:?} dt={dt:?} (skip non-F32 stats)"); + } + } + Ok(()) +} + +pub fn print_stats(label: &str, samples: &[f64]) { + let min = samples[0]; + let median = samples[samples.len() / 2]; + let max = samples[samples.len() - 1]; + let pct_spread = (max - min) / min * 100.0; + let target = if cfg!(target_feature = "relaxed-simd") { + "+relaxed-simd (FMA)" + } else if cfg!(target_family = "wasm") { + "+simd128 only (mul+add)" + } else { + "native" + }; + eprintln!( + "[{target}] {label}: min={min:.0} median={median:.0} max={max:.0} ns/inference (spread {pct_spread:.0}%, n={})", + samples.len() + ); +} diff --git a/linalg/Cargo.toml b/linalg/Cargo.toml index 5cc725e385..674255031d 100644 --- a/linalg/Cargo.toml +++ b/linalg/Cargo.toml @@ -123,3 +123,7 @@ harness = false [[bench]] name = "avx512_zombies" harness = false + +[[bench]] +name = "wasm" +harness = false diff --git a/linalg/WASM_RELAXED_SIMD.md b/linalg/WASM_RELAXED_SIMD.md new file mode 100644 index 0000000000..3797a0a3d9 --- /dev/null 
+++ b/linalg/WASM_RELAXED_SIMD.md @@ -0,0 +1,105 @@ +# tract-linalg on `wasm32` — relaxed-simd FMA + +The WASM MMM kernels (`wasm_f32_4x4`, `4x1`, `8x1`, `16x1`, `32x1`, `8x8`) +and the WASM sigmoid/tanh activations all flip between two emit modes at +compile time, gated on `cfg(target_feature = "relaxed-simd")`: + +- **Without** `+relaxed-simd`: pure `f32x4_add(_, f32x4_mul(_, _))` (mul+add). + Runs on any WASM runtime that supports `simd128`. +- **With** `+relaxed-simd`: `f32x4_relaxed_madd(_, _, _)`. Fused, single-rounded + multiply-add on hosts whose CPU has hardware FMA (all ARM64, x86_64 + FMA3). + Universal browser/runtime support since 2023 (Chrome 114+, Firefox 120+, + Safari 17+, wasmtime 16+). + +The speedup of the relaxed path over the baseline is typically **1.40–1.55× at +the kernel level** and **1.08–1.46× end-to-end** across vision CNNs, +transformer attention and RNN audio models. Bit-pattern drift versus the +mul+add path is bounded at one ulp (FMA single-rounding); within +`Approximation::Close` (1e-4). + +## Build flags + +```sh +# Baseline (any wasm32 runtime supporting simd128) +RUSTFLAGS='-C target-feature=+simd128' \ + cargo build --release --target wasm32-wasip1 -p tract-linalg + +# Relaxed (requires host support for relaxed-simd; ~1.40× faster on FMA-capable hosts) +RUSTFLAGS='-C target-feature=+simd128,+relaxed-simd' \ + cargo build --release --target wasm32-wasip1 -p tract-linalg +``` + +Same on `wasm32-unknown-unknown` if shipping for the browser. + +## Why two binaries (and not in-process runtime dispatch) + +WASM validates the entire module at instantiation, before any code runs. +A binary containing `f32x4.relaxed_madd` fails to instantiate on hosts without +relaxed-simd — `LinkError` / `CompileError`, not a runtime trap. 
So the +x86/ARM pattern (one binary, both paths in source, runtime CPU detection picks +at execution time) cannot be replicated in-binary on WASM: the FMA opcodes are +either present (and host support is required) or absent. + +Runtime dispatch happens one layer up — at the host runtime / consumer layer +— by selecting the correct binary at module-load time. + +## Consumer-side dispatch + +### Browser / `WebAssembly.validate` + +```js +async function loadTract(baseUrl) { + const candidate = await fetch(`${baseUrl}/tract-relaxed.wasm`); + const bytes = await candidate.arrayBuffer(); + + // validate() is false when the host rejects any opcode in the module, + // e.g. f32x4.relaxed_madd on an engine without relaxed-simd. + const wantRelaxed = WebAssembly.validate(bytes); + + const url = wantRelaxed + ? `${baseUrl}/tract-relaxed.wasm` + : `${baseUrl}/tract.wasm`; + + const final = await fetch(url); + return WebAssembly.instantiateStreaming(final); +} +``` + +Alternative when probing up front is not wanted: try to instantiate the +relaxed binary, catch `LinkError` / `CompileError`, and retry with the +baseline. + +### `wasmtime` (server / native) + +```rust +use wasmtime::{Config, Engine}; + +let mut config = Config::new(); +config.wasm_relaxed_simd(true); // gate on host-CPU detection if needed +let engine = Engine::new(&config)?; + +let bytes = std::fs::read(if relaxed_supported { + "tract-relaxed.wasm" +} else { + "tract.wasm" +})?; +let module = wasmtime::Module::new(&engine, &bytes)?; +``` + +`wasmtime::Config::wasm_relaxed_simd` configures the engine; a separate +`wasmtime::Module::validate()` call against the engine is the equivalent of +the browser's `WebAssembly.validate` for picking which binary to load. + +## Quality + +The two binaries are **not bit-identical**. FMA's single-rounding produces +≤1 ulp drift from explicit mul+add. 
Verified end-to-end on Inception v3 and +DFN3 sub-models: + +| model | output shape | baseline L2 | relaxed L2 | +|--------------|--------------------|-------------:|-------------:| +| Inception v3 | [1, 1001] | 6.477089e-2 | 6.477089e-2 | +| DFN3 df_dec | [1, 100, 96, 10] | 1.080686e-2 | 1.080686e-2 | + +L2 norms agree to 7 significant figures; per-element values diverge in the +7th–8th decimal place. Within tract's `Approximation::Close` (1e-4). diff --git a/linalg/benches/wasm.rs b/linalg/benches/wasm.rs new file mode 100644 index 0000000000..c1d998eef7 --- /dev/null +++ b/linalg/benches/wasm.rs @@ -0,0 +1,235 @@ +//! WASM kernel microbenches. Run on wasm32 only. +//! +//! RUSTFLAGS='-C target-feature=+simd128' \ +//! CARGO_TARGET_WASM32_WASIP1_RUNNER='wasmtime --env RUST_TEST_NOCAPTURE=1 --' \ +//! cargo bench --target wasm32-wasip1 -p tract-linalg --bench wasm +//! +//! Re-run with `+simd128,+relaxed-simd` to compare baseline mul+add against +//! the FMA emit driven by the `madd_f32x4!` macro in `linalg/src/wasm.rs`. + +#[cfg(not(target_arch = "wasm32"))] +fn main() { + eprintln!("this bench only runs on wasm32 targets — skipping on host"); +} + +#[cfg(target_arch = "wasm32")] +fn main() { + let target = if cfg!(target_feature = "relaxed-simd") { + "+simd128,+relaxed-simd (FMA)" + } else { + "+simd128 only (mul+add)" + }; + + eprintln!("=== WASM 8x8 GEMM microbench ({target}) ==="); + bench_8x8::run(); + + eprintln!(); + eprintln!("=== Isolated 32x1 GEMV microbench ({target}) ==="); + bench_32x1::run(); +} + +#[cfg(target_arch = "wasm32")] +mod bench_8x8 { + //! Microbench: time `wasm_f32_8x8` (the GEMM kernel for N>=2) at shapes + //! relevant to DFN3, transformer FFN, and CNN→GEMM workloads. 
+ + use std::time::Instant; + use tract_data::internal::*; + use tract_linalg::mmm::{AsInputValue, FusedSpec}; + + fn run_one( + kernel: &dyn tract_linalg::mmm::MatMatMul, + m: usize, + k: usize, + n: usize, + iters: usize, + ) -> f64 { + let packing = &kernel.packings()[0]; + let a = Tensor::zero::(&[m, k]).unwrap(); + let pa = packing.0.prepare_one(&a, 1, 0).unwrap(); + let b = Tensor::zero::(&[k, n]).unwrap(); + let pb = packing.1.prepare_one(&b, 0, 1).unwrap(); + let mut c = Tensor::zero::(&[m, n]).unwrap(); + + for _ in 0..50 { + unsafe { + kernel + .run( + m, + n, + &[ + FusedSpec::AddMatMul { + a: AsInputValue::Borrowed(&*pa), + b: AsInputValue::Borrowed(&*pb), + packing: 0, + }, + FusedSpec::Store(kernel.c_view(Some(0), Some(1)).wrap(&c.view_mut())), + ], + ) + .unwrap(); + } + } + + let t0 = Instant::now(); + for _ in 0..iters { + unsafe { + kernel + .run( + m, + n, + &[ + FusedSpec::AddMatMul { + a: AsInputValue::Borrowed(&*pa), + b: AsInputValue::Borrowed(&*pb), + packing: 0, + }, + FusedSpec::Store(kernel.c_view(Some(0), Some(1)).wrap(&c.view_mut())), + ], + ) + .unwrap(); + } + } + let elapsed = t0.elapsed(); + elapsed.as_secs_f64() / iters as f64 * 1e9 + } + + fn pick(name: &str) -> Box { + let mut ops = tract_linalg::generic(); + tract_linalg::wasm::plug(&mut ops); + for impl_ in ops.mmm_impls() { + if impl_.name() == name { + return impl_.clone(); + } + } + panic!("kernel {name} not registered") + } + + fn bench_shape(label: &str, m: usize, k: usize, n: usize, iters: usize) { + let k88 = pick("wasm_f32_8x8"); + let ns = run_one(&*k88, m, k, n, iters); + let m_tiles = m.div_ceil(8); + let n_tiles = n.div_ceil(8); + let total_tiles = m_tiles * n_tiles; + let per_tile_ns = ns / total_tiles as f64; + eprintln!( + "{label} (m={m} k={k} n={n}, iters={iters}): {ns:.0} ns/call \ + ({total_tiles} 8x8 tiles, {per_tile_ns:.1} ns/tile)" + ); + } + + pub fn run() { + // DFN3 N>1 GEMM case (the primary 8x8 hit on DFN3). 
+ bench_shape("DFN3-style m=64 k=64 n=8", 64, 64, 8, 50_000); + // Larger N — typical batched/transformer GEMM. + bench_shape("m=64 k=64 n=64", 64, 64, 64, 10_000); + bench_shape("m=128 k=128 n=8", 128, 128, 8, 20_000); + bench_shape("m=128 k=128 n=64", 128, 128, 64, 5_000); + bench_shape("m=256 k=256 n=8", 256, 256, 8, 5_000); + bench_shape("m=256 k=256 n=64", 256, 256, 64, 1_000); + // Whisper-tiny FFN-ish (large K, small N). + bench_shape("m=384 k=1536 n=8", 384, 1536, 8, 1_000); + } +} + +#[cfg(target_arch = "wasm32")] +mod bench_32x1 { + //! Isolated, statistics-aware microbench for `wasm_f32_32x1` to investigate + //! the apparent regression at M=100/256 in `microbench_dispatch_gemv`. That + //! bench loops all 4 GEMV kernels back-to-back at every shape, biasing the + //! later-running kernel (32x1) with cache contention and thermal buildup. + //! This module benches 32x1 alone, with min-of-N reporting across + //! repetitions to expose variance honestly. + + use std::time::Instant; + use tract_data::internal::*; + use tract_linalg::mmm::{AsInputValue, FusedSpec}; + + fn run_one(kernel: &dyn tract_linalg::mmm::MatMatMul, m: usize, k: usize, iters: usize) -> f64 { + let packing = &kernel.packings()[0]; + let a = Tensor::zero::(&[m, k]).unwrap(); + let pa = packing.0.prepare_one(&a, 1, 0).unwrap(); + let b = Tensor::zero::(&[k, 1]).unwrap(); + let pb = packing.1.prepare_one(&b, 0, 1).unwrap(); + let mut c = Tensor::zero::(&[m, 1]).unwrap(); + + // Generous warmup — 200 calls primes the JIT and hot caches. 
+ for _ in 0..200 { + unsafe { + kernel + .run( + m, + 1, + &[ + FusedSpec::AddMatMul { + a: AsInputValue::Borrowed(&*pa), + b: AsInputValue::Borrowed(&*pb), + packing: 0, + }, + FusedSpec::Store(kernel.c_view(Some(0), Some(0)).wrap(&c.view_mut())), + ], + ) + .unwrap(); + } + } + + let t0 = Instant::now(); + for _ in 0..iters { + unsafe { + kernel + .run( + m, + 1, + &[ + FusedSpec::AddMatMul { + a: AsInputValue::Borrowed(&*pa), + b: AsInputValue::Borrowed(&*pb), + packing: 0, + }, + FusedSpec::Store(kernel.c_view(Some(0), Some(0)).wrap(&c.view_mut())), + ], + ) + .unwrap(); + } + } + let elapsed = t0.elapsed(); + elapsed.as_secs_f64() / iters as f64 * 1e9 + } + + fn pick(name: &str) -> Box { + let mut ops = tract_linalg::generic(); + tract_linalg::wasm::plug(&mut ops); + for impl_ in ops.mmm_impls() { + if impl_.name() == name { + return impl_.clone(); + } + } + panic!("kernel {name} not registered") + } + + fn bench_min_of_n(label: &str, m: usize, k: usize, iters: usize, repetitions: usize) { + let kernel = pick("wasm_f32_32x1"); + let mut samples: Vec = Vec::with_capacity(repetitions); + for _ in 0..repetitions { + samples.push(run_one(&*kernel, m, k, iters)); + } + samples.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let min = samples[0]; + let median = samples[samples.len() / 2]; + let max = samples[samples.len() - 1]; + let pct_spread = (max - min) / min * 100.0; + eprintln!( + "{label} (m={m} k={k}, {iters} iters × {repetitions} reps): \ + min={min:.0} median={median:.0} max={max:.0} ns/call (spread {pct_spread:.0}%)" + ); + } + + pub fn run() { + // Suspect shapes from microbench_dispatch_gemv (apparent regression): + bench_min_of_n("M=100 k=256", 100, 256, 10_000, 10); + bench_min_of_n("M=256 k=256", 256, 256, 5_000, 10); + bench_min_of_n("M=256 k=512", 256, 512, 2_000, 10); + // Reference shapes (showed clean speedup before): + bench_min_of_n("M=24 k=256", 24, 256, 30_000, 10); + bench_min_of_n("M=64 k=96", 64, 96, 20_000, 10); + } +} diff --git 
a/linalg/src/wasm.rs b/linalg/src/wasm.rs index c3c1ff9e71..501c78cdae 100644 --- a/linalg/src/wasm.rs +++ b/linalg/src/wasm.rs @@ -14,6 +14,29 @@ use crate::{Ops, Scaler}; #[cfg(target_feature = "relaxed-simd")] use crate::frame::element_wise::ElementWiseKer; +// f32x4 mul+add → relaxed FMA when the build has +relaxed-simd, else explicit +// mul+add. Lets the MMM kernels emit f32x4.relaxed_madd without duplicating +// kernel source. Per PR #2195: LLVM does not auto-emit relaxed_madd from +// f32x4_add(f32x4_mul(...)) even with +relaxed-simd — hand emission is needed. +// +// Caller must have `use std::arch::wasm32::*;` in scope (every kernel does). +// Args are passed (acc, a, b); evaluation order differs between the two arms +// (acc-first in baseline, acc-last in FMA), so callers must pass simple +// variable names rather than expressions with side effects. +#[cfg(target_feature = "relaxed-simd")] +macro_rules! madd_f32x4 { + ($acc:expr, $a:expr, $b:expr) => { + f32x4_relaxed_madd($a, $b, $acc) + }; +} + +#[cfg(not(target_feature = "relaxed-simd"))] +macro_rules! 
madd_f32x4 { + ($acc:expr, $a:expr, $b:expr) => { + f32x4_add($acc, f32x4_mul($a, $b)) + }; +} + pub fn plug(ops: &mut Ops) { ops.mmm_impls.push(wasm_f32_4x4.mmm()); ops.mmm_impls.push(wasm_f32_4x1.mmm()); @@ -277,10 +300,10 @@ unsafe fn kernel_f32_4x4(mut pnl: *const FusedKerSpec) -> isize { } FusedKerSpec::AddRowColProducts(rows, cols) => { let cols = v128_load(cols as *const v128); - ab0 = f32x4_add(ab0, f32x4_mul(f32x4_splat(*rows.add(0)), cols)); - ab1 = f32x4_add(ab1, f32x4_mul(f32x4_splat(*rows.add(1)), cols)); - ab2 = f32x4_add(ab2, f32x4_mul(f32x4_splat(*rows.add(2)), cols)); - ab3 = f32x4_add(ab3, f32x4_mul(f32x4_splat(*rows.add(3)), cols)); + ab0 = madd_f32x4!(ab0, f32x4_splat(*rows.add(0)), cols); + ab1 = madd_f32x4!(ab1, f32x4_splat(*rows.add(1)), cols); + ab2 = madd_f32x4!(ab2, f32x4_splat(*rows.add(2)), cols); + ab3 = madd_f32x4!(ab3, f32x4_splat(*rows.add(3)), cols); } FusedKerSpec::Store(tile) => { let mut ptr: *mut u8 = tile.ptr; @@ -322,10 +345,10 @@ unsafe fn kernel_f32_4x4(mut pnl: *const FusedKerSpec) -> isize { for i in 0..k { let a = std::slice::from_raw_parts(a.offset(4 * i as isize), 4); let b = v128_load(b.offset(i as isize)); - ab0 = f32x4_add(ab0, f32x4_mul(f32x4_splat(a[0]), b)); - ab1 = f32x4_add(ab1, f32x4_mul(f32x4_splat(a[1]), b)); - ab2 = f32x4_add(ab2, f32x4_mul(f32x4_splat(a[2]), b)); - ab3 = f32x4_add(ab3, f32x4_mul(f32x4_splat(a[3]), b)); + ab0 = madd_f32x4!(ab0, f32x4_splat(a[0]), b); + ab1 = madd_f32x4!(ab1, f32x4_splat(a[1]), b); + ab2 = madd_f32x4!(ab2, f32x4_splat(a[2]), b); + ab3 = madd_f32x4!(ab3, f32x4_splat(a[3]), b); } } } @@ -460,7 +483,7 @@ unsafe fn kernel_f32_4x1(mut pnl: *const FusedKerSpec) -> isize { // ab[i] += rows[i] * cols[0] (cols[0] is the single col) let r = v128_load(rows as *const v128); let c = f32x4_splat(*cols); - ab = f32x4_add(ab, f32x4_mul(r, c)); + ab = madd_f32x4!(ab, r, c); } FusedKerSpec::Store(tile) => { // 4 rows × 1 col, write each lane to a separate row @@ -482,7 +505,7 @@ unsafe fn 
kernel_f32_4x1(mut pnl: *const FusedKerSpec) -> isize { for i in 0..k { let a_vec = v128_load(a.offset(i as isize)); let b_splat = f32x4_splat(*b.offset(i as isize)); - ab = f32x4_add(ab, f32x4_mul(a_vec, b_splat)); + ab = madd_f32x4!(ab, a_vec, b_splat); } } } @@ -688,8 +711,8 @@ unsafe fn kernel_f32_8x1(mut pnl: *const FusedKerSpec) -> isize { let r_t = v128_load(p); let r_b = v128_load(p.add(1)); let c = f32x4_splat(*cols); - ab_top = f32x4_add(ab_top, f32x4_mul(r_t, c)); - ab_bot = f32x4_add(ab_bot, f32x4_mul(r_b, c)); + ab_top = madd_f32x4!(ab_top, r_t, c); + ab_bot = madd_f32x4!(ab_bot, r_b, c); } FusedKerSpec::Store(tile) => { // 8 rows × 1 col, write each lane to a separate row @@ -720,8 +743,8 @@ unsafe fn kernel_f32_8x1(mut pnl: *const FusedKerSpec) -> isize { let a_t = v128_load(a.offset((2 * i) as isize)); let a_b = v128_load(a.offset((2 * i + 1) as isize)); let b_splat = f32x4_splat(*b.offset(i as isize)); - ab_top = f32x4_add(ab_top, f32x4_mul(a_t, b_splat)); - ab_bot = f32x4_add(ab_bot, f32x4_mul(a_b, b_splat)); + ab_top = madd_f32x4!(ab_top, a_t, b_splat); + ab_bot = madd_f32x4!(ab_bot, a_b, b_splat); } } } @@ -944,10 +967,10 @@ unsafe fn kernel_f32_16x1(mut pnl: *const FusedKerSpec) -> isize { FusedKerSpec::AddRowColProducts(rows, cols) => { let p = rows as *const v128; let c = f32x4_splat(*cols); - ab_q0 = f32x4_add(ab_q0, f32x4_mul(v128_load(p), c)); - ab_q1 = f32x4_add(ab_q1, f32x4_mul(v128_load(p.add(1)), c)); - ab_q2 = f32x4_add(ab_q2, f32x4_mul(v128_load(p.add(2)), c)); - ab_q3 = f32x4_add(ab_q3, f32x4_mul(v128_load(p.add(3)), c)); + ab_q0 = madd_f32x4!(ab_q0, v128_load(p), c); + ab_q1 = madd_f32x4!(ab_q1, v128_load(p.add(1)), c); + ab_q2 = madd_f32x4!(ab_q2, v128_load(p.add(2)), c); + ab_q3 = madd_f32x4!(ab_q3, v128_load(p.add(3)), c); } FusedKerSpec::Store(tile) => { // 16 rows × 1 col, write each lane to a separate row @@ -975,10 +998,10 @@ unsafe fn kernel_f32_16x1(mut pnl: *const FusedKerSpec) -> isize { let a2 = v128_load(a.offset((4 * 
i + 2) as isize)); let a3 = v128_load(a.offset((4 * i + 3) as isize)); let bs = f32x4_splat(*b.offset(i as isize)); - ab_q0 = f32x4_add(ab_q0, f32x4_mul(a0, bs)); - ab_q1 = f32x4_add(ab_q1, f32x4_mul(a1, bs)); - ab_q2 = f32x4_add(ab_q2, f32x4_mul(a2, bs)); - ab_q3 = f32x4_add(ab_q3, f32x4_mul(a3, bs)); + ab_q0 = madd_f32x4!(ab_q0, a0, bs); + ab_q1 = madd_f32x4!(ab_q1, a1, bs); + ab_q2 = madd_f32x4!(ab_q2, a2, bs); + ab_q3 = madd_f32x4!(ab_q3, a3, bs); } } } @@ -1315,14 +1338,14 @@ unsafe fn kernel_f32_32x1(mut pnl: *const FusedKerSpec) -> isize { FusedKerSpec::AddRowColProducts(rows, cols) => { let p = rows as *const v128; let c = f32x4_splat(*cols); - ab_q0 = f32x4_add(ab_q0, f32x4_mul(v128_load(p), c)); - ab_q1 = f32x4_add(ab_q1, f32x4_mul(v128_load(p.add(1)), c)); - ab_q2 = f32x4_add(ab_q2, f32x4_mul(v128_load(p.add(2)), c)); - ab_q3 = f32x4_add(ab_q3, f32x4_mul(v128_load(p.add(3)), c)); - ab_q4 = f32x4_add(ab_q4, f32x4_mul(v128_load(p.add(4)), c)); - ab_q5 = f32x4_add(ab_q5, f32x4_mul(v128_load(p.add(5)), c)); - ab_q6 = f32x4_add(ab_q6, f32x4_mul(v128_load(p.add(6)), c)); - ab_q7 = f32x4_add(ab_q7, f32x4_mul(v128_load(p.add(7)), c)); + ab_q0 = madd_f32x4!(ab_q0, v128_load(p), c); + ab_q1 = madd_f32x4!(ab_q1, v128_load(p.add(1)), c); + ab_q2 = madd_f32x4!(ab_q2, v128_load(p.add(2)), c); + ab_q3 = madd_f32x4!(ab_q3, v128_load(p.add(3)), c); + ab_q4 = madd_f32x4!(ab_q4, v128_load(p.add(4)), c); + ab_q5 = madd_f32x4!(ab_q5, v128_load(p.add(5)), c); + ab_q6 = madd_f32x4!(ab_q6, v128_load(p.add(6)), c); + ab_q7 = madd_f32x4!(ab_q7, v128_load(p.add(7)), c); } FusedKerSpec::Store(tile) => { // 32 rows × 1 col, write each lane to a separate row @@ -1354,14 +1377,14 @@ unsafe fn kernel_f32_32x1(mut pnl: *const FusedKerSpec) -> isize { let a6 = v128_load(a.offset((8 * i + 6) as isize)); let a7 = v128_load(a.offset((8 * i + 7) as isize)); let bs = f32x4_splat(*b.offset(i as isize)); - ab_q0 = f32x4_add(ab_q0, f32x4_mul(a0, bs)); - ab_q1 = f32x4_add(ab_q1, f32x4_mul(a1, 
bs)); - ab_q2 = f32x4_add(ab_q2, f32x4_mul(a2, bs)); - ab_q3 = f32x4_add(ab_q3, f32x4_mul(a3, bs)); - ab_q4 = f32x4_add(ab_q4, f32x4_mul(a4, bs)); - ab_q5 = f32x4_add(ab_q5, f32x4_mul(a5, bs)); - ab_q6 = f32x4_add(ab_q6, f32x4_mul(a6, bs)); - ab_q7 = f32x4_add(ab_q7, f32x4_mul(a7, bs)); + ab_q0 = madd_f32x4!(ab_q0, a0, bs); + ab_q1 = madd_f32x4!(ab_q1, a1, bs); + ab_q2 = madd_f32x4!(ab_q2, a2, bs); + ab_q3 = madd_f32x4!(ab_q3, a3, bs); + ab_q4 = madd_f32x4!(ab_q4, a4, bs); + ab_q5 = madd_f32x4!(ab_q5, a5, bs); + ab_q6 = madd_f32x4!(ab_q6, a6, bs); + ab_q7 = madd_f32x4!(ab_q7, a7, bs); } } } @@ -1972,29 +1995,29 @@ unsafe fn kernel_f32_8x8(mut pnl: *const FusedKerSpec) -> isize { let clo = v128_load(p); let chi = v128_load(p.add(1)); let r0 = f32x4_splat(*rows.add(0)); - a0lo = f32x4_add(a0lo, f32x4_mul(r0, clo)); - a0hi = f32x4_add(a0hi, f32x4_mul(r0, chi)); + a0lo = madd_f32x4!(a0lo, r0, clo); + a0hi = madd_f32x4!(a0hi, r0, chi); let r1 = f32x4_splat(*rows.add(1)); - a1lo = f32x4_add(a1lo, f32x4_mul(r1, clo)); - a1hi = f32x4_add(a1hi, f32x4_mul(r1, chi)); + a1lo = madd_f32x4!(a1lo, r1, clo); + a1hi = madd_f32x4!(a1hi, r1, chi); let r2 = f32x4_splat(*rows.add(2)); - a2lo = f32x4_add(a2lo, f32x4_mul(r2, clo)); - a2hi = f32x4_add(a2hi, f32x4_mul(r2, chi)); + a2lo = madd_f32x4!(a2lo, r2, clo); + a2hi = madd_f32x4!(a2hi, r2, chi); let r3 = f32x4_splat(*rows.add(3)); - a3lo = f32x4_add(a3lo, f32x4_mul(r3, clo)); - a3hi = f32x4_add(a3hi, f32x4_mul(r3, chi)); + a3lo = madd_f32x4!(a3lo, r3, clo); + a3hi = madd_f32x4!(a3hi, r3, chi); let r4 = f32x4_splat(*rows.add(4)); - a4lo = f32x4_add(a4lo, f32x4_mul(r4, clo)); - a4hi = f32x4_add(a4hi, f32x4_mul(r4, chi)); + a4lo = madd_f32x4!(a4lo, r4, clo); + a4hi = madd_f32x4!(a4hi, r4, chi); let r5 = f32x4_splat(*rows.add(5)); - a5lo = f32x4_add(a5lo, f32x4_mul(r5, clo)); - a5hi = f32x4_add(a5hi, f32x4_mul(r5, chi)); + a5lo = madd_f32x4!(a5lo, r5, clo); + a5hi = madd_f32x4!(a5hi, r5, chi); let r6 = f32x4_splat(*rows.add(6)); - a6lo = 
f32x4_add(a6lo, f32x4_mul(r6, clo)); - a6hi = f32x4_add(a6hi, f32x4_mul(r6, chi)); + a6lo = madd_f32x4!(a6lo, r6, clo); + a6hi = madd_f32x4!(a6hi, r6, chi); let r7 = f32x4_splat(*rows.add(7)); - a7lo = f32x4_add(a7lo, f32x4_mul(r7, clo)); - a7hi = f32x4_add(a7hi, f32x4_mul(r7, chi)); + a7lo = madd_f32x4!(a7lo, r7, clo); + a7hi = madd_f32x4!(a7hi, r7, chi); } FusedKerSpec::Store(tile) => { // 8 rows × 8 cols stores @@ -2039,29 +2062,29 @@ unsafe fn kernel_f32_8x8(mut pnl: *const FusedKerSpec) -> isize { let blo = v128_load(b.offset((2 * i) as isize)); let bhi = v128_load(b.offset((2 * i + 1) as isize)); let s = f32x4_splat(arow[0]); - a0lo = f32x4_add(a0lo, f32x4_mul(s, blo)); - a0hi = f32x4_add(a0hi, f32x4_mul(s, bhi)); + a0lo = madd_f32x4!(a0lo, s, blo); + a0hi = madd_f32x4!(a0hi, s, bhi); let s = f32x4_splat(arow[1]); - a1lo = f32x4_add(a1lo, f32x4_mul(s, blo)); - a1hi = f32x4_add(a1hi, f32x4_mul(s, bhi)); + a1lo = madd_f32x4!(a1lo, s, blo); + a1hi = madd_f32x4!(a1hi, s, bhi); let s = f32x4_splat(arow[2]); - a2lo = f32x4_add(a2lo, f32x4_mul(s, blo)); - a2hi = f32x4_add(a2hi, f32x4_mul(s, bhi)); + a2lo = madd_f32x4!(a2lo, s, blo); + a2hi = madd_f32x4!(a2hi, s, bhi); let s = f32x4_splat(arow[3]); - a3lo = f32x4_add(a3lo, f32x4_mul(s, blo)); - a3hi = f32x4_add(a3hi, f32x4_mul(s, bhi)); + a3lo = madd_f32x4!(a3lo, s, blo); + a3hi = madd_f32x4!(a3hi, s, bhi); let s = f32x4_splat(arow[4]); - a4lo = f32x4_add(a4lo, f32x4_mul(s, blo)); - a4hi = f32x4_add(a4hi, f32x4_mul(s, bhi)); + a4lo = madd_f32x4!(a4lo, s, blo); + a4hi = madd_f32x4!(a4hi, s, bhi); let s = f32x4_splat(arow[5]); - a5lo = f32x4_add(a5lo, f32x4_mul(s, blo)); - a5hi = f32x4_add(a5hi, f32x4_mul(s, bhi)); + a5lo = madd_f32x4!(a5lo, s, blo); + a5hi = madd_f32x4!(a5hi, s, bhi); let s = f32x4_splat(arow[6]); - a6lo = f32x4_add(a6lo, f32x4_mul(s, blo)); - a6hi = f32x4_add(a6hi, f32x4_mul(s, bhi)); + a6lo = madd_f32x4!(a6lo, s, blo); + a6hi = madd_f32x4!(a6hi, s, bhi); let s = f32x4_splat(arow[7]); - a7lo = 
f32x4_add(a7lo, f32x4_mul(s, blo)); - a7hi = f32x4_add(a7hi, f32x4_mul(s, bhi)); + a7lo = madd_f32x4!(a7lo, s, blo); + a7hi = madd_f32x4!(a7hi, s, bhi); } } } @@ -2414,6 +2437,7 @@ mod microbench_dispatch_gemv { bench_shape("M=256 k=256", 256, 256, 5_000); } } + // Relaxed-SIMD activation kernels (f32, FMA path). // // `f32x4_relaxed_madd(a, b, c)` computes `a * b + c`. On hosts with hardware