diff --git a/cli/src/main.rs b/cli/src/main.rs index 8f7aab7a42..29f495ac54 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -74,6 +74,9 @@ pub const STAGES: &[&str] = &[ "pulse", "pulse-to-type", "pulse-declutter", + "pulse-v2", + "pulse-v2-to-type", + "pulse-v2-declutter", "set", "set-declutter", "nnef-cycle", @@ -169,6 +172,7 @@ fn main() -> TractResult<()> { .arg(arg!(-O --optimize "Optimize before running")) .arg(arg!(--"assert-maximal-mm-quality-cost" [MAX] "Maximum value for quality category (0=assembly, 4=dreadful rust code)")) .arg(arg!(--pulse [PULSE] "Translate to pulse network")) + .arg(arg!(--"pulse-v2" [SYM] "Translate to pulse-v2 network (streaming axis symbol, default S)")) .arg(arg!(--"machine-friendly" "Machine friendly output")) .arg(arg!(--"timeout" [SECONDS] "Kill the process after this many seconds")) diff --git a/cli/src/params.rs b/cli/src/params.rs index 4f8bea628c..8b8de0cefd 100644 --- a/cli/src/params.rs +++ b/cli/src/params.rs @@ -725,6 +725,16 @@ impl Parameters { stage!("pulse-to-type", pulsed_model -> typed_model, |m:PulsedModel| m.into_typed()); stage!("pulse-declutter", typed_model -> typed_model, |m:TypedModel| m.into_decluttered()); } + if let Some(spec) = matches.get_one::("pulse-v2") { + stage!("pulse-v2", typed_model -> typed_model, |m:TypedModel| { + use tract_pulse::v2::PulseV2Model; + let stream_sym = m.symbols.sym(spec); + let pv2 = PulseV2Model::new(&m, stream_sym)?; + pv2.into_typed() + }); + stage!("pulse-v2-to-type", typed_model -> typed_model, |m:TypedModel| Ok(m)); + stage!("pulse-v2-declutter", typed_model -> typed_model, |m:TypedModel| m.into_decluttered()); + } } let mut transforms: Vec<&str> = matches .get_many::("transform") diff --git a/data/src/dim/tree.rs b/data/src/dim/tree.rs index be09079344..2a6012ca47 100644 --- a/data/src/dim/tree.rs +++ b/data/src/dim/tree.rs @@ -1902,3 +1902,17 @@ mod tests { assert_eq!(a_s, c_s, "8*(-1*B) should simplify the same as -8*B"); } } + +#[test] +fn mul_neg_b_by_8() { 
+ let s = crate::dim::SymbolScope::default(); + let b = Sym(s.sym("B")); + // 8*(-1*B) should equal -8*B + let a = Mul(vec![Val(8), MulInt(-1, Box::new(b.clone()))]); + let c = MulInt(-8, Box::new(b.clone())); + let a_s = a.simplify(); + let c_s = c.simplify(); + eprintln!("8*(-1*B) = {a_s}"); + eprintln!("-8*B = {c_s}"); + assert_eq!(a_s, c_s, "8*(-1*B) should simplify the same as -8*B"); +} diff --git a/harness/core-proptest-pulse/Cargo.toml b/harness/core-proptest-pulse/Cargo.toml index 669815987b..b9ed7ede6b 100644 --- a/harness/core-proptest-pulse/Cargo.toml +++ b/harness/core-proptest-pulse/Cargo.toml @@ -13,3 +13,4 @@ tract-pulse.workspace = true log.workspace = true proptest.workspace = true env_logger.workspace = true +tract-tensorflow.workspace = true diff --git a/harness/core-proptest-pulse/src/conv_plus_conv.rs b/harness/core-proptest-pulse/src/conv_plus_conv.rs index 9499aab160..3be81dccf3 100644 --- a/harness/core-proptest-pulse/src/conv_plus_conv.rs +++ b/harness/core-proptest-pulse/src/conv_plus_conv.rs @@ -63,7 +63,7 @@ impl Arbitrary for ConvOp { } #[derive(Debug, Clone)] -struct ConvPlusConvProblem { +pub(crate) struct ConvPlusConvProblem { input: Tensor, pulse: usize, convs: Vec, @@ -123,6 +123,10 @@ impl ConvPlusConvProblem { 2, ) } + + pub fn run_v2(&self) -> TestCaseResult { + crate::v2::run_and_compare_v2(Self::model(&self.convs), self.pulse, &self.input, 2) + } } proptest! 
{ diff --git a/harness/core-proptest-pulse/src/deconv.rs b/harness/core-proptest-pulse/src/deconv.rs index e823d581c7..1fc353f05b 100644 --- a/harness/core-proptest-pulse/src/deconv.rs +++ b/harness/core-proptest-pulse/src/deconv.rs @@ -68,7 +68,7 @@ impl Arbitrary for DeconvOp { } #[derive(Debug, Clone)] -struct DeconvProblem { +pub(crate) struct DeconvProblem { input: Array3, pulse: usize, deconv: DeconvOp, @@ -94,6 +94,17 @@ impl Arbitrary for DeconvProblem { } impl DeconvProblem { + pub fn run_v2(&self) -> TestCaseResult { + let mut model = TypedModel::default(); + let mut fact = f32::fact(self.input.shape()); + let s = model.symbols.sym("S"); + fact.shape.set(2, s.to_dim()); + let input = model.add_source("a", fact).unwrap(); + let id = self.deconv.chain("deconv1", &mut model, input); + model.select_output_outlets(&[id]).unwrap(); + crate::v2::run_and_compare_v2(model, self.pulse, &self.input.clone().into_tensor(), 2) + } + pub fn run(&self) -> TestCaseResult { let mut model = TypedModel::default(); let mut fact = f32::fact(self.input.shape()); diff --git a/harness/core-proptest-pulse/src/delay_plus_downsample.rs b/harness/core-proptest-pulse/src/delay_plus_downsample.rs index a095721968..f08c385106 100644 --- a/harness/core-proptest-pulse/src/delay_plus_downsample.rs +++ b/harness/core-proptest-pulse/src/delay_plus_downsample.rs @@ -6,7 +6,7 @@ use tract_core::tract_data::itertools::Itertools; use super::*; #[derive(Debug, Clone)] -struct DelayPlusDownsampleProblem { +pub(crate) struct DelayPlusDownsampleProblem { input: usize, pulse: usize, delay: usize, @@ -45,6 +45,22 @@ impl Arbitrary for DelayPlusDownsampleProblem { } impl DelayPlusDownsampleProblem { + pub fn run_v2(&self) -> TestCaseResult { + let mut model = TypedModel::default(); + let s = model.symbols.sym("S"); + let a = model.add_source("a", f32::fact(dims!(1, s, 1))).unwrap(); + let crop = model.wire_node("delay", Slice::new(1, self.delay, s), &[a]).unwrap(); + let ds = model + .wire_node( + 
"ds", + Downsample { axis: 1, stride: self.stride as isize, modulo: self.modulo }, + &crop, + ) + .unwrap(); + model.select_output_outlets(&ds).unwrap(); + crate::v2::run_and_compare_v2(model, self.pulse, &t(self.input).into_tensor(), 1) + } + pub fn run(&self) -> TestCaseResult { let mut model = TypedModel::default(); let s = model.symbols.sym("S"); diff --git a/harness/core-proptest-pulse/src/delay_plus_pool.rs b/harness/core-proptest-pulse/src/delay_plus_pool.rs index a8fb8fcf9f..85fa84f18f 100644 --- a/harness/core-proptest-pulse/src/delay_plus_pool.rs +++ b/harness/core-proptest-pulse/src/delay_plus_pool.rs @@ -5,7 +5,7 @@ use tract_core::ops::cnn::MaxPool; use super::*; #[derive(Debug, Clone)] -struct DelayPlusPoolProblem { +pub(crate) struct DelayPlusPoolProblem { input: Vec, pulse: usize, delay: usize, @@ -44,6 +44,27 @@ impl Arbitrary for DelayPlusPoolProblem { } impl DelayPlusPoolProblem { + pub fn run_v2(&self) -> TestCaseResult { + let mut model = TypedModel::default(); + let s = model.symbols.sym("S"); + let a = model.add_source("a", f32::fact(dims!(1, s, 1))).unwrap(); + let crop = model.wire_node("delay", Slice::new(1, self.delay, s), &[a]).unwrap(); + let pool_spec = PoolSpec::new( + DataFormat::NHWC, + tvec!(self.pool_window), + self.padding.clone(), + None, + Some(tvec!(self.stride)), + 1, + 1, + ); + let pool = model.wire_node("pool", MaxPool::new(pool_spec, None), &crop).unwrap(); + model.select_output_outlets(&pool).unwrap(); + let input = + arr1(&self.input).into_shape_with_order((1, self.input.len(), 1)).unwrap().into_dyn(); + crate::v2::run_and_compare_v2(model, self.pulse, &input.into_tensor(), 1) + } + pub fn run(&self) -> TestCaseResult { let mut model = TypedModel::default(); let s = model.symbols.sym("S"); diff --git a/harness/core-proptest-pulse/src/lib.rs b/harness/core-proptest-pulse/src/lib.rs index a5f360696f..b70f6296a0 100644 --- a/harness/core-proptest-pulse/src/lib.rs +++ b/harness/core-proptest-pulse/src/lib.rs @@ -20,12 
+20,13 @@ use tract_core::ops::nn::DataFormat; use tract_ndarray::prelude::*; use tract_pulse::internal::*; -mod conv_plus_conv; -mod deconv; -mod delay_plus_downsample; -mod delay_plus_pool; +pub(crate) mod conv_plus_conv; +pub(crate) mod deconv; +pub(crate) mod delay_plus_downsample; +pub(crate) mod delay_plus_pool; mod einsum; -mod pad_plus_conv; +pub(crate) mod pad_plus_conv; +mod v2; #[allow(dead_code)] fn setup_test_logger() { diff --git a/harness/core-proptest-pulse/src/pad_plus_conv.rs b/harness/core-proptest-pulse/src/pad_plus_conv.rs index 29bd87fdbe..35e268556d 100644 --- a/harness/core-proptest-pulse/src/pad_plus_conv.rs +++ b/harness/core-proptest-pulse/src/pad_plus_conv.rs @@ -5,7 +5,7 @@ use proptest::*; use super::*; #[derive(Debug, Clone)] -struct PadPlusConvProblem { +pub(crate) struct PadPlusConvProblem { pad_before: usize, pad_after: usize, pad_mode: PadMode, @@ -107,6 +107,50 @@ proptest! { fn proptest_conv(pb in PadPlusConvProblem::arbitrary()) { pb.run().unwrap() } } +impl PadPlusConvProblem { + pub fn run_v2(&self) -> TestCaseResult { + let mut model = TypedModel::default(); + let s = model.symbols.sym("S"); + let mut wire = model.add_source("a", f32::fact(dims!(1, 1, s))).unwrap(); + if self.pad_before > 0 || self.pad_after > 0 { + wire = model + .wire_node( + "pad", + Pad::new( + vec![(0, 0), (0, 0), (self.pad_before, self.pad_after)], + self.pad_mode.clone(), + ), + &[wire], + ) + .unwrap()[0]; + } + let kernel = model.add_const("kernel", self.ker.clone()).unwrap(); + let bias = model.add_const("bias", tensor0(0f32)).unwrap(); + let conv = model + .wire_node( + "conv", + Conv { + pool_spec: PoolSpec { + data_format: DataFormat::NCHW, + kernel_shape: self.ker.shape()[2..].into(), + padding: PaddingSpec::Valid, + dilations: Some(tvec!(self.dilation)), + strides: Some(tvec!(self.stride)), + input_channels: 1, + output_channels: 1, + }, + kernel_fmt: tract_core::ops::cnn::KernelFormat::OIHW, + group: 1, + q_params: None, + }, + &[wire, 
kernel, bias], + ) + .unwrap(); + model.select_output_outlets(&conv).unwrap(); + crate::v2::run_and_compare_v2(model, self.pulse, &self.input.clone().into_tensor(), 2) + } +} + #[test] fn conv_1() { PadPlusConvProblem { diff --git a/harness/core-proptest-pulse/src/v2.rs b/harness/core-proptest-pulse/src/v2.rs new file mode 100644 index 0000000000..9b3093644a --- /dev/null +++ b/harness/core-proptest-pulse/src/v2.rs @@ -0,0 +1,346 @@ +use proptest::prelude::*; +use proptest::test_runner::TestCaseResult; +use tract_core::ops::cnn::Conv; +use tract_core::plan::SimplePlan; +use tract_core::prelude::*; +use tract_pulse::v2::PulseV2Model; +use tract_pulse::v2_buffer::PulseV2Buffer; +use tract_pulse::v2_slice::PulseV2Slice; + +/// Pulsify a model via PulseV2, run pulse by pulse, stitch output, +/// and compare against batch reference. +pub fn run_and_compare_v2( + model: TypedModel, + pulse: usize, + input: &Tensor, + axis: usize, +) -> TestCaseResult { + let stream_sym = model + .input_fact(0) + .unwrap() + .shape + .iter() + .flat_map(|d| d.symbols()) + .next() + .expect("No streaming symbol found in model input"); + + let batch = + model.clone().into_runnable().unwrap().run(tvec!(input.clone().into_tvalue())).unwrap(); + + let pv2 = PulseV2Model::new(&model, stream_sym.clone()).unwrap(); + let t_sym = pv2.symbols.pulse_id.clone(); + let p_sym = pv2.symbols.pulse.clone(); + + let batch_output_fact = model.output_fact(0).unwrap(); + let output_axis = batch_output_fact + .shape + .iter() + .position(|d| d.symbols().contains(&stream_sym)) + .unwrap_or(axis); + + let typed = pv2.into_typed().unwrap(); + // Output-axis delay: each PulseV2Buffer pre-fills `lookback` zeros on the + // input axis; the corresponding garbage shows up at the OUTPUT axis after + // being divided by every stride applied between the buffer and the output. + // For `Source → Buffer(L) → Conv(stride=s) → output`, output garbage = + // L/s. For chained ops, the stride compounds. 
Approximation here: walk + // the model once, accumulate `total_lookback` from all buffers and + // `total_stride` from all Conv nodes that touch the streaming axis, + // delay = total_lookback / total_stride. Holds for the proptest models + // (single linear streaming path) — real graphs would need per-buffer + // accumulated downstream stride. + let mut total_lookback = 0usize; + let mut total_stride = 1usize; + let mut total_slice_start: i64 = 0; + for node in typed.nodes() { + if let Some(buf) = node.op.downcast_ref::() { + total_lookback += buf.lookback.iter().copied().max().unwrap_or(0); + } + if let Some(conv) = node.op.downcast_ref::() { + let h_axis = conv.pool_spec.data_format.h_axis(); + if axis >= h_axis { + let geo_ix = axis - h_axis; + let s = conv + .pool_spec + .strides + .as_ref() + .and_then(|st| st.get(geo_ix).copied()) + .unwrap_or(1); + total_stride *= s; + } + } + if let Some(slice) = node.op.downcast_ref::() { + if slice.axis == axis { + if let Ok(s) = slice.start.to_i64() { + total_slice_start += s; + } + } + } + } + let total_delay = total_lookback / total_stride.max(1) + total_slice_start as usize; + let plan = SimplePlan::new(typed).unwrap(); + let mut state = plan.spawn().unwrap(); + + let input_len = input.shape()[axis]; + let mut output_chunks: Vec = vec![]; + let mut written = 0; + let mut pulse_idx: i64 = 0; + let batch_len = batch[0].shape()[output_axis]; + + loop { + let chunk_len = pulse.min(input_len.saturating_sub(written)); + + state.turn_state.resolved_symbols.set(&t_sym, pulse_idx); + state.turn_state.resolved_symbols.set(&p_sym, pulse as i64); + state.turn_state.resolved_symbols.set(&stream_sym, input_len as i64); + + let chunk = if chunk_len == 0 { + let mut shape = input.shape().to_vec(); + shape[axis] = pulse; + Tensor::zero_dt(input.datum_type(), &shape).unwrap() + } else if chunk_len < pulse { + let chunk = input.slice(axis, written, written + chunk_len).unwrap(); + let mut padded_shape = input.shape().to_vec(); + 
padded_shape[axis] = pulse; + let mut padded = Tensor::zero_dt(input.datum_type(), &padded_shape).unwrap(); + padded.assign_slice(0..chunk_len, &chunk, 0..chunk_len, axis).unwrap(); + padded + } else { + input.slice(axis, written, written + chunk_len).unwrap().into_tensor() + }; + + let outputs = state.run(tvec!(chunk.into_tvalue())).unwrap(); + let out = outputs[0].clone().into_tensor(); + if out.shape()[output_axis] > 0 { + output_chunks.push(out); + } + written += pulse; + pulse_idx += 1; + + let total_out: usize = output_chunks.iter().map(|t| t.shape()[output_axis]).sum(); + if total_out >= batch_len + total_delay { + break; + } + if pulse_idx > 1000 { + panic!("Pulsed run exceeded 1000 pulses"); + } + } + + let pulsed_output = Tensor::stack_tensors(output_axis, &output_chunks).unwrap(); + + let batch_output = &batch[0]; + let pulsed_len = pulsed_output.shape()[output_axis]; + prop_assert!( + pulsed_len >= total_delay, + "Pulsed output ({pulsed_len}) shorter than total_delay ({total_delay})" + ); + let pulsed_valid = pulsed_output.slice(output_axis, total_delay, pulsed_len).unwrap(); + let pulsed_valid_len = pulsed_valid.shape()[output_axis]; + let compare_len = batch_len.min(pulsed_valid_len); + prop_assert!(compare_len > 0, "No output produced"); + let batch_slice = batch_output.slice(output_axis, 0, compare_len).unwrap(); + let pulsed_slice = pulsed_valid.slice(output_axis, 0, compare_len).unwrap(); + prop_assert!( + pulsed_slice.close_enough(&batch_slice, true).is_ok(), + "Mismatch:\nbatch: {:?}\npulsed: {:?}", + batch_slice, + pulsed_slice + ); + Ok(()) +} + +// ── Proptests using V1 problem generators ────────────────────────────── + +use super::conv_plus_conv::ConvPlusConvProblem; +use super::deconv::DeconvProblem; +use super::delay_plus_downsample::DelayPlusDownsampleProblem; +use super::delay_plus_pool::DelayPlusPoolProblem; +use super::pad_plus_conv::PadPlusConvProblem; + +#[test] +#[ignore] +fn proptest_v2_conv_full() { + use 
proptest::test_runner::{Config, TestRunner}; + let mut runner = TestRunner::new(Config::default()); + runner.run(&ConvPlusConvProblem::arbitrary(), |pb| pb.run_v2()).unwrap(); +} + +#[test] +#[ignore] +fn proptest_v2_pad_plus_conv_full() { + use proptest::test_runner::{Config, TestRunner}; + let mut runner = TestRunner::new(Config::default()); + runner.run(&PadPlusConvProblem::arbitrary(), |pb| pb.run_v2()).unwrap(); +} + +#[test] +#[ignore] +fn proptest_v2_delay_plus_pool_full() { + use proptest::test_runner::{Config, TestRunner}; + let mut runner = TestRunner::new(Config::default()); + runner.run(&DelayPlusPoolProblem::arbitrary(), |pb| pb.run_v2()).unwrap(); +} + +#[test] +#[ignore] +fn proptest_v2_deconv_full() { + use proptest::test_runner::{Config, TestRunner}; + let mut runner = TestRunner::new(Config::default()); + runner.run(&DeconvProblem::arbitrary(), |pb| pb.run_v2()).unwrap(); +} + +#[test] +#[ignore] +fn proptest_v2_delay_plus_downsample_full() { + use proptest::test_runner::{Config, TestRunner}; + let mut runner = TestRunner::new(Config::default()); + runner.run(&DelayPlusDownsampleProblem::arbitrary(), |pb| pb.run_v2()).unwrap(); +} + +// ── Focused proptests for supported subset ───────────────────────────── + +use tract_core::dims; +use tract_core::ops::cnn::*; +use tract_core::ops::nn::DataFormat; + +fn wire_conv( + model: &mut TypedModel, + name: &str, + input: OutletId, + kernel_size: usize, + dilation: usize, + stride: usize, +) -> TVec { + let ker_data: Vec = (0..kernel_size).map(|i| (i + 1) as f32).collect(); + let ker_tensor = tensor1(&ker_data).into_shape(&[1, 1, kernel_size]).unwrap(); + let k = model.add_const(format!("{name}.k"), ker_tensor).unwrap(); + let b = model.add_const(format!("{name}.b"), tensor0(0f32)).unwrap(); + model + .wire_node( + name, + Conv { + pool_spec: PoolSpec { + data_format: DataFormat::NCHW, + kernel_shape: tvec!(kernel_size), + padding: PaddingSpec::Valid, + dilations: if dilation > 1 { Some(tvec!(dilation)) } 
else { None }, + strides: if stride > 1 { Some(tvec!(stride)) } else { None }, + input_channels: 1, + output_channels: 1, + }, + kernel_fmt: KernelFormat::OIHW, + group: 1, + q_params: None, + }, + &[input, k, b], + ) + .unwrap() +} + +fn nchw_model_and_input( + ops: impl FnOnce(&mut TypedModel, OutletId) -> TVec, + input_len: usize, +) -> (TypedModel, Tensor) { + let mut model = TypedModel::default(); + let s = model.symbols.sym("S"); + let a = model.add_source("a", f32::fact(dims!(1, 1, s))).unwrap(); + let out = ops(&mut model, a); + model.select_output_outlets(&out).unwrap(); + let input_data: Vec = (0..input_len).map(|i| i as f32).collect(); + let input = tensor1(&input_data).into_shape(&[1, 1, input_len]).unwrap(); + (model, input) +} + +proptest! { + #[test] + fn proptest_v2_single_conv(kernel in 1usize..5, pulse_extra in 0usize..5, input_extra in 0usize..8) { + let pulse = kernel + pulse_extra; + let (model, input) = nchw_model_and_input( + |m, a| wire_conv(m, "conv", a, kernel, 1, 1), + kernel + pulse + input_extra, + ); + run_and_compare_v2(model, pulse, &input, 2)?; + } + + #[test] + fn proptest_v2_conv_chain(k1 in 1usize..4, k2 in 1usize..4, pulse_extra in 0usize..4, input_extra in 0usize..6) { + let lookback = (k1 - 1) + (k2 - 1); + let pulse = lookback.max(1) + pulse_extra; + let (model, input) = nchw_model_and_input( + |m, a| { let c1 = wire_conv(m, "c1", a, k1, 1, 1); wire_conv(m, "c2", c1[0], k2, 1, 1) }, + lookback + pulse + input_extra + 1, + ); + run_and_compare_v2(model, pulse, &input, 2)?; + } + + #[test] + fn proptest_v2_conv_dilation(kernel in 1usize..4, dilation in 1usize..4, pulse_extra in 0usize..4, input_extra in 0usize..6) { + let receptive = (kernel - 1) * dilation; + let pulse = receptive.max(1) + pulse_extra; + let (model, input) = nchw_model_and_input( + |m, a| wire_conv(m, "conv", a, kernel, dilation, 1), + receptive + pulse + input_extra + 1, + ); + run_and_compare_v2(model, pulse, &input, 2)?; + } +} + +proptest! 
{ + #[test] + fn proptest_v2_conv_stride(kernel in 1usize..4, conv_stride in 1usize..4, pulse_extra in 0usize..4, input_extra in 0usize..6) { + let min_factors = (kernel + conv_stride - 1) / conv_stride; + let pulse = conv_stride * (min_factors + pulse_extra); + let (model, input) = nchw_model_and_input( + |m, a| wire_conv(m, "conv", a, kernel, 1, conv_stride), + kernel + pulse + input_extra, + ); + run_and_compare_v2(model, pulse, &input, 2)?; + } +} + +proptest! { + #[test] + fn proptest_v2_crop(pulse in 1usize..5, input_len in 1usize..10, begin in 0usize..3, end_margin in 0usize..3) { + use tract_core::ops::array::Slice; + let full_len = input_len + begin + end_margin; + let mut model = TypedModel::default(); + let s = model.symbols.sym("S"); + let a = model.add_source("a", f32::fact(&[TDim::from(s.clone())])).unwrap(); + let slice = model.wire_node("slice", Slice::new(0, begin, begin + input_len), &[a]).unwrap(); + model.select_output_outlets(&slice).unwrap(); + let input = tensor1(&(1..=full_len).map(|i| i as f32).collect::>()); + run_and_compare_v2(model, pulse, &input, 0)?; + } +} + +proptest! { + #[test] + fn proptest_v2_pad(pulse in 1usize..5, input_len in 1usize..10, before in 0usize..4, after in 0usize..4) { + use tract_core::ops::array::{Pad, PadMode}; + let mut model = TypedModel::default(); + let s = model.symbols.sym("S"); + let a = model.add_source("a", f32::fact(&[TDim::from(s.clone())])).unwrap(); + let pad = model.wire_node("pad", Pad::new(vec![(before, after)], PadMode::Constant(Arc::new(Tensor::from(-1f32)))), &[a]).unwrap(); + model.select_output_outlets(&pad).unwrap(); + let input = tensor1(&(1..=input_len).map(|i| i as f32).collect::>()); + run_and_compare_v2(model, pulse, &input, 0)?; + } +} + +proptest! 
{ + #[test] + fn proptest_v2_pad_plus_conv(kernel in 1usize..4, pad_before in 0usize..4, pad_after in 0usize..4, pulse_extra in 0usize..4, input_extra in 0usize..6) { + use tract_core::ops::array::{Pad, PadMode}; + let pulse = kernel.max(1) + pulse_extra; + let input_len = kernel + pulse + input_extra; + let mut model = TypedModel::default(); + let s = model.symbols.sym("S"); + let a = model.add_source("a", f32::fact(dims!(1, 1, s))).unwrap(); + let wire = model.wire_node("pad", Pad::new(vec![(0, 0), (0, 0), (pad_before, pad_after)], PadMode::Constant(Arc::new(Tensor::from(-1f32)))), &[a]).unwrap()[0]; + let conv = wire_conv(&mut model, "conv", wire, kernel, 1, 1); + model.select_output_outlets(&conv).unwrap(); + let input_data: Vec = (0..input_len).map(|i| i as f32).collect(); + let input = tensor1(&input_data).into_shape(&[1, 1, input_len]).unwrap(); + run_and_compare_v2(model, pulse, &input, 2)?; + } +} diff --git a/harness/core-proptest-pulse/tests/hey_snips_v2.rs b/harness/core-proptest-pulse/tests/hey_snips_v2.rs new file mode 100644 index 0000000000..4817cab7cc --- /dev/null +++ b/harness/core-proptest-pulse/tests/hey_snips_v2.rs @@ -0,0 +1,169 @@ +//! End-to-end test of PulseV2 against the hey_snips WaveNet model. +//! +//! Loads the cached .pb, brings it to typed/decluttered form, runs both the +//! batch model and the pulsified v2 model on the same random input, stitches +//! the per-pulse outputs back together, and compares numerically against +//! batch (modulo the pulsified model's startup delay). 
+ +use std::path::PathBuf; + +use tract_core::ops::cnn::Conv; +use tract_core::plan::SimplePlan; +use tract_core::prelude::*; +use tract_pulse::v2::PulseV2Model; +use tract_pulse::v2_buffer::PulseV2Buffer; +use tract_pulse::v2_slice::PulseV2Slice; +use tract_tensorflow::prelude::*; + +fn cached_hey_snips() -> Option { + for candidate in [ + std::env::var("MODELS").ok().map(|m| PathBuf::from(m).join("hey_snips_v4_model17.pb")), + Some(PathBuf::from(".cached/hey_snips_v4_model17.pb")), + Some(PathBuf::from("../../.cached/hey_snips_v4_model17.pb")), + ] + .into_iter() + .flatten() + { + if candidate.exists() { + return Some(candidate); + } + } + None +} + +#[test] +fn hey_snips_pulsify_v2() -> TractResult<()> { + let Some(path) = cached_hey_snips() else { + eprintln!("hey_snips_v4_model17.pb not cached — skipping"); + return Ok(()); + }; + + let mut model = tensorflow().model_for_path(&path)?; + let s = model.symbols.sym("S"); + model.set_input_fact(0, f32::fact([s.to_dim(), 20.to_dim()]).into())?; + // Pulsification operates on the typed model. Don't call `into_optimized()` + // here — that decomposes Conv into Im2col + OptMatMul and inserts AxisOp + // wrappers that pulse-v2's RegionTransform inventory doesn't handle. + let model = model.into_typed()?.into_decluttered()?; + + // Concrete stream length for the comparison run. Wavenet's receptive + // field is ~180 samples; pick something well past it so steady-state + // dominates. + let stream_len: usize = 256; + let pulse: usize = 8; + let stream_axis = 0; + + // Random-but-deterministic input. + let input_shape = [stream_len, 20]; + let mut input_vec: Vec = Vec::with_capacity(stream_len * 20); + for i in 0..stream_len * 20 { + input_vec.push((i as f32 * 0.137).sin()); + } + let input = tract_ndarray::Array::from_shape_vec(input_shape, input_vec)?.into_tensor(); + + // Run batch reference. 
+ let concrete_input_fact = f32::fact(&input_shape); + let mut batch_model = model.clone(); + batch_model.set_input_fact(0, concrete_input_fact.into())?; + let batch_model = batch_model.into_decluttered()?; + let batch_outputs = batch_model.into_runnable()?.run(tvec!(input.clone().into_tvalue()))?; + let batch_output = batch_outputs[0].clone().into_tensor(); + + // Pulsify and run pulse-by-pulse. + let pv2 = PulseV2Model::new(&model, s.clone())?; + let t_sym = pv2.symbols.pulse_id.clone(); + let p_sym = pv2.symbols.pulse.clone(); + let typed = pv2.into_typed()?; + + let output_axis = typed + .output_fact(0)? + .shape + .iter() + .position(|d| d.symbols().contains(&p_sym)) + .expect("Pulsed output should have a streaming axis"); + + // Total delay = sum of buffer lookbacks (divided by total stride) + + // sum of slice starts on the streaming axis. + // Quick instrumentation of inserted ops, for debugging. + let mut buf_count = 0usize; + let mut conv_count = 0usize; + let mut slice_count = 0usize; + for node in typed.nodes() { + if node.op.downcast_ref::().is_some() { + buf_count += 1; + } + if node.op.downcast_ref::().is_some() { + conv_count += 1; + } + if node.op.downcast_ref::().is_some() { + slice_count += 1; + } + } + eprintln!( + "PulseV2 wavenet inventory: buffers={buf_count} convs={conv_count} slices={slice_count}" + ); + + let plan = SimplePlan::new(typed)?; + let mut state = plan.spawn()?; + + let batch_len = batch_output.shape()[output_axis]; + let mut output_chunks: Vec = vec![]; + let mut written = 0usize; + let mut pulse_idx: i64 = 0; + // Run exactly as many pulses as needed to consume the input stream. + // Per-pulse output is constant (= pulse, mod stride). After + // `num_pulses * out_per_pulse` total output samples, the steady-state + // tail of the pulsed output corresponds to the BATCH output's tail. 
+ let num_input_pulses = stream_len.div_ceil(pulse); + while pulse_idx < num_input_pulses as i64 { + let chunk_len = pulse.min(stream_len.saturating_sub(written)); + + state.turn_state.resolved_symbols.set(&t_sym, pulse_idx); + state.turn_state.resolved_symbols.set(&p_sym, pulse as i64); + state.turn_state.resolved_symbols.set(&s, stream_len as i64); + + let chunk = if chunk_len == 0 { + let mut shape = input.shape().to_vec(); + shape[stream_axis] = pulse; + Tensor::zero_dt(input.datum_type(), &shape)? + } else if chunk_len < pulse { + let chunk = input.slice(stream_axis, written, written + chunk_len)?; + let mut padded_shape = input.shape().to_vec(); + padded_shape[stream_axis] = pulse; + let mut padded = Tensor::zero_dt(input.datum_type(), &padded_shape)?; + padded.assign_slice(0..chunk_len, &chunk, 0..chunk_len, stream_axis)?; + padded + } else { + input.slice(stream_axis, written, written + chunk_len)?.into_tensor() + }; + + let outputs = state.run(tvec!(chunk.into_tvalue()))?; + let out = outputs[0].clone().into_tensor(); + if out.shape()[output_axis] > 0 { + output_chunks.push(out); + } + written += pulse; + pulse_idx += 1; + } + + let pulsed_output = Tensor::stack_tensors(output_axis, &output_chunks)?; + let pulsed_len = pulsed_output.shape()[output_axis]; + assert!(pulsed_len >= batch_len, "pulsed_len={pulsed_len} batch_len={batch_len}"); + // Steady-state tail: last `batch_len` samples of pulsed output line up + // with the full batch output (assuming the model has run long enough + // for warmup garbage to wash out, which `stream_len > 2 × receptive_field` + // ensures). 
+ let compare_len = batch_len; + let pulsed_slice = pulsed_output.slice(output_axis, pulsed_len - compare_len, pulsed_len)?; + let batch_slice = batch_output.slice(output_axis, 0, compare_len)?; + + println!( + "stream_len={stream_len} pulse={pulse} batch_len={batch_len} pulsed_len={pulsed_len} \ + compare_len={compare_len}" + ); + if let Err(e) = pulsed_slice.close_enough(&batch_slice, true) { + panic!("Mismatch:\nbatch: {batch_slice:?}\npulsed: {pulsed_slice:?}\n{e}"); + } + println!("PulseV2 hey_snips numerical match: {compare_len} samples on output axis"); + Ok(()) +} diff --git a/harness/core-proptest-pulse/tests/v2_tdim_blowup.rs b/harness/core-proptest-pulse/tests/v2_tdim_blowup.rs new file mode 100644 index 0000000000..457fd97aa5 --- /dev/null +++ b/harness/core-proptest-pulse/tests/v2_tdim_blowup.rs @@ -0,0 +1,155 @@ +//! Isolate the TDim simplifier blowup that hangs PulseV2 on deep conv chains. +//! +//! Build synthetic chains of dilated 1D convs of increasing depth and report +//! how long `PulseV2Model::new` takes plus the size of the output fact shape +//! expressions. We expect time (or expression size) to grow fast past some N. + +use std::time::Instant; + +use tract_core::ops::array::Slice; +use tract_core::ops::binary::TypedBinOp; +use tract_core::ops::cnn::pools::PoolSpec; +use tract_core::ops::cnn::{Conv, KernelFormat}; +use tract_core::ops::math::Add; +use tract_core::ops::nn::DataFormat; +use tract_core::prelude::*; +use tract_pulse::v2::PulseV2Model; + +/// Wire a 1D Conv with the given kernel size and dilation, NCHW layout. 
+fn wire_conv( + model: &mut TypedModel, + name: &str, + input: OutletId, + in_channels: usize, + out_channels: usize, + kernel: usize, + dilation: usize, +) -> OutletId { + let pool_spec = PoolSpec::new( + DataFormat::NCHW, + tvec!(kernel), + tract_core::ops::cnn::PaddingSpec::Valid, + Some(tvec!(dilation)), + Some(tvec!(1)), + in_channels, + out_channels, + ); + let kernel_t = Tensor::zero::(&[out_channels, in_channels, kernel]).unwrap(); + let kernel_outlet = model.add_const(format!("{name}.k"), kernel_t).unwrap(); + let bias_t = Tensor::zero::(&[out_channels]).unwrap(); + let bias_outlet = model.add_const(format!("{name}.b"), bias_t).unwrap(); + + let conv = Conv { pool_spec, kernel_fmt: KernelFormat::OIHW, group: 1, q_params: None }; + model.wire_node(name, conv, &[input, kernel_outlet, bias_outlet]).unwrap()[0] +} + +fn dilated_conv_chain(n_layers: usize) -> (TypedModel, Symbol) { + let mut model = TypedModel::default(); + let s = model.symbols.sym("S"); + // NCHW layout: [1, channels=1, time=S] + let fact = f32::fact([1.to_dim(), 1.to_dim(), s.to_dim()]); + let mut wire = model.add_source("input", fact).unwrap(); + for i in 0..n_layers { + let dilation = 1usize << i; + wire = wire_conv(&mut model, &format!("conv_{i}"), wire, 1, 1, 3, dilation); + } + model.select_output_outlets(&[wire]).unwrap(); + (model, s) +} + +/// Build a WaveNet-like residual chain: each block has `conv → conv` and a +/// skip connection from the block input that gets sliced to align with the +/// block output, then added back. Skip-Adds are where parallel paths merge — +/// this is what blows broadcast Max in v2. 
+fn skip_residual_chain(n_blocks: usize) -> (TypedModel, Symbol) { + let mut model = TypedModel::default(); + let s = model.symbols.sym("S"); + let fact = f32::fact([1.to_dim(), 1.to_dim(), s.to_dim()]); + let mut wire = model.add_source("input", fact).unwrap(); + let mut produced_overlap = 0i64; // cumulative receptive-field consumed so far + for i in 0..n_blocks { + let dilation = 1usize << (i % 4); + let block_in = wire; + let conv = wire_conv(&mut model, &format!("b{i}_c1"), block_in, 1, 1, 3, dilation); + let conv2 = wire_conv(&mut model, &format!("b{i}_c2"), conv, 1, 1, 3, dilation); + produced_overlap += 4 * dilation as i64; // two convs each consume 2*dilation + // Slice block_in to align with conv2's effective length. + // block_in has shape on streaming axis = S - produced_overlap_before. + // conv2 has shape = S - produced_overlap. Skip needs slice of length + // S - produced_overlap, taking the central window (offset 4*dilation). + let begin = (4 * dilation) as usize; + let slice = model + .wire_node( + format!("b{i}_skip"), + Slice { axis: 2, start: TDim::Val(begin as i64), end: s.to_dim() }, + &[block_in], + ) + .unwrap()[0]; + let _ = produced_overlap; + wire = model + .wire_node(format!("b{i}_add"), TypedBinOp(Box::new(Add), None), &[conv2, slice]) + .unwrap()[0]; + } + model.select_output_outlets(&[wire]).unwrap(); + (model, s) +} + +#[test] +fn skip_residual_chain_grows() { + for n in 1..=8 { + let (model, s) = skip_residual_chain(n); + let t0 = Instant::now(); + let result = PulseV2Model::new(&model, s); + let elapsed = t0.elapsed(); + match result { + Ok(pv2) => { + let typed = pv2.into_typed().unwrap(); + let out_fact = typed.output_fact(0).unwrap(); + let shape_str = format!("{:?}", out_fact.shape); + let len = shape_str.len(); + println!( + "n={n} {}ms shape_str_len={len} shape={shape_str}", + elapsed.as_millis() + ); + } + Err(e) => { + println!("n={n} {}ms ERROR: {:#}", elapsed.as_millis(), e); + break; + } + } + if elapsed.as_secs() > 10 { 
+ println!("(bail — >10s)"); + break; + } + } +} + +#[test] +fn dilated_chain_grows() { + for n in 1..=30 { + let (model, s) = dilated_conv_chain(n); + let t0 = Instant::now(); + let result = PulseV2Model::new(&model, s); + let elapsed = t0.elapsed(); + match result { + Ok(pv2) => { + let typed = pv2.into_typed().unwrap(); + let out_fact = typed.output_fact(0).unwrap(); + let shape_str = format!("{:?}", out_fact.shape); + let len = shape_str.len(); + println!( + "n={n:2} {:>7}ms shape_str_len={len:5} shape={shape_str}", + elapsed.as_millis() + ); + } + Err(e) => { + println!("n={n} {}ms ERROR: {:#}", elapsed.as_millis(), e); + break; + } + } + if elapsed.as_secs() > 10 { + println!("(bail — >10s)"); + break; + } + } +} diff --git a/pulse/Cargo.toml b/pulse/Cargo.toml index ece76961be..f55b4e8a03 100644 --- a/pulse/Cargo.toml +++ b/pulse/Cargo.toml @@ -17,6 +17,7 @@ maintenance = { status = "actively-developed" } downcast-rs.workspace = true dyn-eq.workspace = true erased-serde.workspace = true +inventory.workspace = true lazy_static.workspace = true log.workspace = true serde.workspace = true diff --git a/pulse/src/lib.rs b/pulse/src/lib.rs index 24ea6f3c5c..6a9b1811d1 100644 --- a/pulse/src/lib.rs +++ b/pulse/src/lib.rs @@ -5,6 +5,12 @@ pub mod macros; pub mod fact; pub mod model; pub mod ops; +pub mod v2; +pub mod v2_buffer; +pub mod v2_conv; +pub mod v2_deconv; +pub mod v2_pad; +pub mod v2_slice; pub mod internal { pub use std::fmt; diff --git a/pulse/src/v2.rs b/pulse/src/v2.rs new file mode 100644 index 0000000000..259388c450 --- /dev/null +++ b/pulse/src/v2.rs @@ -0,0 +1,584 @@ +/// PulseV2: region/increment-based pulsification. +/// +/// Symbols: +/// T — pulse index (0, 1, 2, …) +/// P — pulse size (kept symbolic until runtime) +/// +/// Each op declares a RegionTransform: "given my output region, what input +/// regions do I need?" 
The generic pulsifier compares what each op needs +/// against what the source provides, and inserts PulseV2Buffer where there's +/// a gap (lookback into previous pulses). +/// +/// Output shapes are computed by the batch model's output_facts — no separate +/// region propagation needed on the forward path. +use crate::internal::*; + +// ── Per-axis region ──────────────────────────────────────────────────── + +/// Per-axis specification at pulse T. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum AxisRegion { + /// This axis is not streaming — full extent every pulse. + Fixed(TDim), + /// This axis produces [start, end) at pulse T. + /// start and end are TDim expressions in symbols T and P. + Streaming { start: TDim, end: TDim }, +} + +impl AxisRegion { + pub fn size(&self) -> TDim { + match self { + AxisRegion::Fixed(d) => d.clone(), + AxisRegion::Streaming { start, end } => end.clone() - start.clone(), + } + } + + pub fn is_streaming(&self) -> bool { + matches!(self, AxisRegion::Streaming { .. }) + } +} + +/// Region description for a wire at pulse T. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct PulseV2Region { + pub axes: TVec, +} + +impl PulseV2Region { + pub fn rank(&self) -> usize { + self.axes.len() + } +} + +// ── Region transforms (inventory) ────────────────────────────────────── + +/// Result of a region transform: either adjust input regions (the generic +/// pulsifier handles buffer insertion and op wiring) or replace the op +/// entirely with pre-wired outlets. +pub enum PulseV2Action { + /// Use these input regions. None means "not streaming, pass as-is". + /// overlap: optional per-input, per-axis overlap (un-rounded). + /// When provided, lookback = region difference (rounded) and + /// overlap < lookback. The buffer trims (lookback - overlap) from the front. + InputRegions(TVec>, Option>>), + /// Skip this op — just forward the mapped inputs as outputs. + Skip, + /// Replace this op with a different one (same inputs). 
+ ReplaceOp(Box), + /// Wire the original op normally, then append a post-processing op on its output. + WireOpThenPostOp(Box), + /// Wire a pre-op on the data input (index 0), then wire the replacement op. + /// Used for decomposing Conv(padded) into PulseV2Pad + Conv(valid). + WirePreOpThenOp { pre_op: Box, main_op: Box }, +} + +pub type RegionTransformFn = fn( + op: &dyn TypedOp, + source_region: &PulseV2Region, + symbols: &PulseV2Symbols, +) -> TractResult>; + +/// Inventory entry: maps an op TypeId to its region transform. +pub struct RegionTransform { + pub type_id: std::any::TypeId, + pub func: RegionTransformFn, +} + +inventory::collect!(RegionTransform); + +fn lookup_region_transform( + op: &dyn TypedOp, + source_region: &PulseV2Region, + symbols: &PulseV2Symbols, +) -> TractResult> { + let type_id = op.type_id(); + for rt in inventory::iter:: { + if rt.type_id == type_id { + return (rt.func)(op, source_region, symbols); + } + } + Ok(None) +} + +// ── Symbols ──────────────────────────────────────────────────────────── + +#[derive(Clone, Debug)] +pub struct PulseV2Symbols { + pub stream: Symbol, + pub pulse: Symbol, + pub pulse_id: Symbol, +} + +// ── PulseV2Model ─────────────────────────────────────────────────────── + +pub struct PulseV2Model { + pub typed: TypedModel, + pub symbols: PulseV2Symbols, +} + +impl PulseV2Model { + /// Pulsify a batch model. + /// + /// For each Source, substitute S → P. For each op, ask its RegionTransform + /// what input regions it needs. Where an input needs lookback beyond the + /// source increment, insert a PulseV2Buffer. Wire the op with (potentially + /// buffered) inputs. Output shapes are computed by the op's output_facts. + pub fn new(batch_model: &TypedModel, stream_sym: Symbol) -> TractResult { + use crate::v2_buffer::PulseV2Buffer; + + // Pre-process: decompose Conv/MaxPool with non-valid padding on + // streaming axes into Pad + Conv/MaxPool(valid). 
+ let batch_model = Self::decompose_streaming_padding(batch_model, &stream_sym)?; + let batch_model = &batch_model; + + let p_sym = batch_model.symbols.sym("P"); + let t_sym = batch_model.symbols.sym("T"); + + // Assert S >= P: the stream must be at least one pulse long. + // This ensures min(T*P, lookback) is valid — there really are + // T*P cumulative source samples at pulse T. + batch_model.symbols.add_assertion(format!("{} >= {}", stream_sym, p_sym)).ok(); + + let symbols = PulseV2Symbols { + stream: stream_sym.clone(), + pulse: p_sym.clone(), + pulse_id: t_sym.clone(), + }; + + // The source increment: what one pulse provides on streaming axes. + // This is the baseline — ops that need more get a buffer. + let t = TDim::Sym(t_sym.clone()); + let p = TDim::Sym(p_sym.clone()); + + let mut typed = TypedModel::default(); + let mut mapping: HashMap = HashMap::new(); + // Track which pulsed wires are streaming and what their source region is. + let mut wire_regions: HashMap = HashMap::new(); + + let order = batch_model.eval_order()?; + + for &node_id in &order { + let node = batch_model.node(node_id); + + // Source: substitute S → P, record the source increment as region. 
+ if node.op.downcast_ref::().is_some() { + let batch_fact = batch_model.outlet_fact(OutletId::new(node_id, 0))?; + let mut axes = TVec::new(); + let mut pulse_shape = TVec::new(); + for dim in batch_fact.shape.iter() { + if dim.symbols().contains(&stream_sym) && dim == &TDim::Sym(stream_sym.clone()) + { + axes.push(AxisRegion::Streaming { + start: t.clone() * p.clone(), + end: (t.clone() + 1) * p.clone(), + }); + pulse_shape.push(p.clone()); + } else { + axes.push(AxisRegion::Fixed(dim.clone())); + pulse_shape.push(dim.clone()); + } + } + let pulse_fact = batch_fact.datum_type.fact(pulse_shape); + let new_outlet = typed.add_source(&node.name, pulse_fact)?; + mapping.insert(OutletId::new(node_id, 0), new_outlet); + wire_regions.insert(new_outlet, PulseV2Region { axes }); + continue; + } + + // Find the source region for the first streaming input. + let source_region = node + .inputs + .iter() + .find_map(|i| mapping.get(i).and_then(|o| wire_regions.get(o))) + .cloned(); + + // Ask the region transform what to do. + let typed_op: &dyn TypedOp = node.op.as_ref(); + let action = if let Some(src) = &source_region { + lookup_region_transform(typed_op, src, &symbols)? + } else { + None + }; + + // Handle Skip: forward inputs directly as outputs. + if matches!(action, Some(PulseV2Action::Skip)) { + for (slot, batch_input) in node.inputs.iter().enumerate() { + let pulsed = mapping.get(batch_input).copied().unwrap_or(*batch_input); + mapping.insert(OutletId::new(node_id, slot), pulsed); + } + if let Some(region) = &source_region { + for slot in 0..node.outputs.len() { + if let Some(&pulsed) = mapping.get(&OutletId::new(node_id, slot)) { + wire_regions.insert(pulsed, region.clone()); + } + } + } + continue; + } + + // Handle ReplaceOp: wire the replacement op instead. 
+ if let Some(PulseV2Action::ReplaceOp(replacement)) = action { + let inputs: TVec = + node.inputs.iter().map(|i| mapping.get(i).copied().unwrap_or(*i)).collect(); + let new_outlets = typed.wire_node(&node.name, replacement, &inputs)?; + for (slot, &new_outlet) in new_outlets.iter().enumerate() { + mapping.insert(OutletId::new(node_id, slot), new_outlet); + } + if let Some(region) = &source_region { + for &new_outlet in &new_outlets { + wire_regions.insert(new_outlet, region.clone()); + } + } + continue; + } + + // Handle WireOpThenPostOp: wire the original op, then append a post-op. + if let Some(PulseV2Action::WireOpThenPostOp(post_op)) = action { + let inputs: TVec = + node.inputs.iter().map(|i| mapping.get(i).copied().unwrap_or(*i)).collect(); + let op_outlets = typed.wire_node(&node.name, node.op.clone(), &inputs)?; + let post_outlets = + typed.wire_node(format!("{}.accum", node.name), post_op, &op_outlets)?; + for (slot, &new_outlet) in post_outlets.iter().enumerate() { + mapping.insert(OutletId::new(node_id, slot), new_outlet); + } + if let Some(region) = &source_region { + for &new_outlet in &post_outlets { + wire_regions.insert(new_outlet, region.clone()); + } + } + continue; + } + + // Handle WirePreOpThenOp: wire pre-op on input 0, then look up + // the main_op's region transform (for buffer insertion), then wire main_op. + if let Some(PulseV2Action::WirePreOpThenOp { pre_op, main_op }) = action { + let mut inputs: TVec = + node.inputs.iter().map(|i| mapping.get(i).copied().unwrap_or(*i)).collect(); + // Wire pre-op on the first (data) input. + let pre_outlets = + typed.wire_node(format!("{}.pre", node.name), pre_op, &[inputs[0]])?; + inputs[0] = pre_outlets[0]; + // Check if the main_op needs a buffer (e.g. Conv(valid) after Pad). + let main_typed: &dyn TypedOp = main_op.as_ref(); + if let Some(src) = &source_region { + if let Some(PulseV2Action::InputRegions(main_regions, _)) = + lookup_region_transform(main_typed, src, &symbols)? 
+ { + // Insert buffer on the data input if needed. + if let Some(Some(needed)) = main_regions.first() { + let provided = src; + let mut lookback = tvec![0usize; needed.rank()]; + let mut needs_buffer = false; + for (ax, (n, p)) in + needed.axes.iter().zip(provided.axes.iter()).enumerate() + { + if let ( + AxisRegion::Streaming { start: ns, .. }, + AxisRegion::Streaming { start: ps, .. }, + ) = (n, p) + { + let lb = (ps.clone() - ns.clone()).simplify(); + if let Ok(v) = lb.to_i64() { + if v > 0 { + lookback[ax] = v as usize; + needs_buffer = true; + } + } + } + } + if needs_buffer { + let buffer = PulseV2Buffer { + overlap: lookback.clone(), + lookback, + pulse_id: t_sym.clone(), + pulse_sym: p_sym.clone(), + }; + let buffered = typed.wire_node( + format!("{}.buffer", node.name), + buffer, + &[inputs[0]], + )?; + inputs[0] = buffered[0]; + } + } + } + } + let new_outlets = typed.wire_node(&node.name, main_op, &inputs)?; + for (slot, &new_outlet) in new_outlets.iter().enumerate() { + mapping.insert(OutletId::new(node_id, slot), new_outlet); + } + if let Some(region) = &source_region { + for &new_outlet in &new_outlets { + wire_regions.insert(new_outlet, region.clone()); + } + } + continue; + } + + let (input_regions, overlap_hints): ( + TVec>, + Option>>, + ) = match action { + Some(PulseV2Action::InputRegions(r, o)) => (r, o), + _ => { + let r = if let Some(src) = &source_region { + node.inputs.iter().map(|_| Some(src.clone())).collect() + } else { + tvec![None; node.inputs.len()] + }; + (r, None) + } + }; + + // Wire inputs, inserting buffers where needed. 
+ let mut inputs: TVec = TVec::new(); + for (ix, batch_input) in node.inputs.iter().enumerate() { + let pulsed_input = mapping.get(batch_input).copied().unwrap_or(*batch_input); + let wire_region = wire_regions.get(&pulsed_input); + let needed = input_regions.get(ix).and_then(|r| r.as_ref()); + + if let (Some(needed), Some(provided)) = (needed, wire_region) { + let mut lookback = tvec![0usize; needed.rank()]; + let mut needs_buffer = false; + for (ax, (needed_ax, provided_ax)) in + needed.axes.iter().zip(provided.axes.iter()).enumerate() + { + if let ( + AxisRegion::Streaming { start: needed_start, .. }, + AxisRegion::Streaming { start: provided_start, .. }, + ) = (needed_ax, provided_ax) + { + let lb = (provided_start.clone() - needed_start.clone()).simplify(); + if let Ok(v) = lb.to_i64() { + if v > 0 { + lookback[ax] = v as usize; + needs_buffer = true; + } + } + } + } + + if needs_buffer { + let overlap = overlap_hints + .as_ref() + .and_then(|h| h.get(ix)) + .cloned() + .unwrap_or_else(|| lookback.clone()); + let buffer = PulseV2Buffer { + overlap, + lookback, + pulse_id: t_sym.clone(), + pulse_sym: p_sym.clone(), + }; + let buffered = typed.wire_node( + format!("{}.buffer.{}", node.name, ix), + buffer, + &[pulsed_input], + )?; + inputs.push(buffered[0]); + } else { + inputs.push(pulsed_input); + } + } else { + inputs.push(pulsed_input); + } + } + + let new_outlets = typed.wire_node(&node.name, node.op.clone(), &inputs)?; + for (slot, &new_outlet) in new_outlets.iter().enumerate() { + mapping.insert(OutletId::new(node_id, slot), new_outlet); + if let Some(region) = &source_region { + wire_regions.insert(new_outlet, region.clone()); + } + // Intermediates stay in flat symbolic form; clamp only at sinks + // (below) to avoid nested Max trees blowing up the TDim simplifier + // on deep conv chains. 
+ } + } + + let batch_outputs = batch_model.output_outlets()?.to_vec(); + let pulsed_outputs: TVec = batch_outputs.iter().map(|o| mapping[o]).collect(); + typed.select_output_outlets(&pulsed_outputs)?; + + Ok(PulseV2Model { typed, symbols }) + } + + /// Decompose Conv/MaxPool with non-valid padding on streaming axes + /// into explicit Pad + Conv/MaxPool(valid-on-streaming-axis). + fn decompose_streaming_padding( + model: &TypedModel, + stream_sym: &Symbol, + ) -> TractResult { + use crate::fact::StreamFact; + use tract_pulse_opl::tract_core::ops::array::{Pad, PadMode}; + use tract_pulse_opl::tract_core::ops::cnn::{Conv, MaxPool, PaddingSpec, PoolSpec}; + + let mut new_model = TypedModel::default(); + new_model.symbols = model.symbols.clone(); + let mut mapping: HashMap = HashMap::new(); + let order = model.eval_order()?; + let mut changed = false; + + for &node_id in &order { + let node = model.node(node_id); + + // Check if this node is a Conv or MaxPool with non-valid padding + // on a streaming axis. + let pool_spec = node + .op + .downcast_ref::() + .map(|c| &c.pool_spec) + .or_else(|| node.op.downcast_ref::().map(|p| &p.pool_spec)); + + if let Some(spec) = pool_spec { + let input_fact = model.outlet_fact(node.inputs[0])?; + let stream_axis = input_fact.shape.stream_info(stream_sym).map(|(ax, _)| ax); + + if let Some(stream_ax) = stream_axis { + let geo_axis = stream_ax - spec.data_format.h_axis(); + if geo_axis < spec.kernel_shape.len() + && !spec.padding.valid_dim(geo_axis, false) + { + // Compute padding amounts. + let dummy_hw: TVec = + spec.kernel_shape.iter().map(|k| k * 10).collect(); + let computed = spec.computed_padding(&dummy_hw); + + // Build Pad op: only pad the streaming axis. 
+ let rank = input_fact.rank(); + let mut pads = vec![(0usize, 0usize); rank]; + pads[stream_ax] = ( + computed[geo_axis].pad_before.to_usize().unwrap_or(0), + computed[geo_axis].pad_after.to_usize().unwrap_or(0), + ); + + // Build new PoolSpec with zero padding on streaming axis. + let mut bef = tvec![]; + let mut aft = tvec![]; + for ix in 0..spec.kernel_shape.len() { + if ix == geo_axis { + bef.push(0); + aft.push(0); + } else { + bef.push(computed[ix].pad_before.to_usize().unwrap_or(0)); + aft.push(computed[ix].pad_after.to_usize().unwrap_or(0)); + } + } + let new_padding = + if bef.iter().all(|b| *b == 0) && aft.iter().all(|a| *a == 0) { + PaddingSpec::Valid + } else { + PaddingSpec::ExplicitOnnxPool(bef, aft, false) + }; + + // Wire: input → Pad → Conv/MaxPool(new_padding) + let data_input = + mapping.get(&node.inputs[0]).copied().unwrap_or(node.inputs[0]); + let pad_wire = new_model.wire_node( + format!("{}.pad", node.name), + Pad::new(pads, PadMode::Constant(Arc::new(Tensor::from(0.0f32)))), + &[data_input], + )?; + + let mut inputs: TVec = node + .inputs + .iter() + .map(|i| mapping.get(i).copied().unwrap_or(*i)) + .collect(); + inputs[0] = pad_wire[0]; + + let new_op: Box = + if let Some(conv) = node.op.downcast_ref::() { + Box::new(Conv { + pool_spec: PoolSpec { + padding: new_padding, + ..conv.pool_spec.clone() + }, + ..conv.clone() + }) + } else { + let pool = node.op.downcast_ref::().unwrap(); + Box::new(MaxPool { + pool_spec: PoolSpec { + padding: new_padding, + ..pool.pool_spec.clone() + }, + ..pool.clone() + }) + }; + + let new_outlets = new_model.wire_node(&node.name, new_op, &inputs)?; + for (slot, &outlet) in new_outlets.iter().enumerate() { + mapping.insert(OutletId::new(node_id, slot), outlet); + } + changed = true; + continue; + } + } + } + + // Default: copy the node as-is. 
+ if node + .op + .downcast_ref::() + .is_some() + { + let fact = model.outlet_fact(OutletId::new(node_id, 0))?; + let new_outlet = new_model.add_source(&node.name, fact.clone())?; + mapping.insert(OutletId::new(node_id, 0), new_outlet); + } else { + let inputs: TVec = + node.inputs.iter().map(|i| mapping.get(i).copied().unwrap_or(*i)).collect(); + let new_outlets = new_model.wire_node(&node.name, node.op.clone(), &inputs)?; + for (slot, &outlet) in new_outlets.iter().enumerate() { + mapping.insert(OutletId::new(node_id, slot), outlet); + } + } + } + + let outputs = model.output_outlets()?.to_vec(); + let new_outputs: TVec = outputs.iter().map(|o| mapping[o]).collect(); + new_model.select_output_outlets(&new_outputs)?; + + if changed { Ok(new_model) } else { Ok(model.clone()) } + } + + pub fn into_typed(self) -> TractResult { + Ok(self.typed) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_simple_streaming_input() { + let scope = SymbolScope::default(); + let s = scope.sym("S"); + let p = scope.sym("P"); + let t = scope.sym("T"); + + let batch_fact = DatumType::F32.fact(&[1.to_dim(), s.clone().into(), 16.to_dim()]); + + let mut axes = TVec::new(); + for dim in batch_fact.shape.iter() { + if dim == &TDim::Sym(s.clone()) { + axes.push(AxisRegion::Streaming { + start: TDim::Sym(t.clone()) * TDim::Sym(p.clone()), + end: (TDim::Sym(t.clone()) + 1) * TDim::Sym(p.clone()), + }); + } else { + axes.push(AxisRegion::Fixed(dim.clone())); + } + } + let region = PulseV2Region { axes }; + + assert_eq!(region.rank(), 3); + assert!(!region.axes[0].is_streaming()); + assert!(region.axes[1].is_streaming()); + assert!(!region.axes[2].is_streaming()); + assert_eq!(region.axes[1].size().simplify(), TDim::Sym(p)); + } +} diff --git a/pulse/src/v2_buffer.rs b/pulse/src/v2_buffer.rs new file mode 100644 index 0000000000..9d606258a9 --- /dev/null +++ b/pulse/src/v2_buffer.rs @@ -0,0 +1,159 @@ +/// PulseV2Buffer: fixed-size streaming-axis history buffer. 
+/// +/// At each pulse T, the op emits `lookback + current` samples on the buffered +/// axis: the last `lookback` samples seen plus this pulse's input. The history +/// is initialised to zeros at session start, so on T=0 the output is +/// `[zeros…lookback…, current]` — fixed shape, garbage prefix during ramp, +/// matching v1's `Delay` semantics. +/// +/// Output shape on the buffered axis = input + lookback (constant). All other +/// axes are passed through unchanged. Multi-axis lookback is not supported in +/// this revision (the streaming axis is normally the only buffered one). +use crate::internal::*; +use tract_pulse_opl::tract_core::ops::{FrozenOpState, OpStateFreeze}; + +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct PulseV2Buffer { + /// Per-axis lookback. lookback[i] = 0 means no buffering on axis i. + /// Exactly one axis is expected to have non-zero lookback. + pub lookback: TVec, + /// Per-axis overlap: the actual data overlap the consumer needs. + /// overlap[i] <= lookback[i]. The difference (lookback - overlap) is + /// stride alignment padding that gets trimmed from the output. + pub overlap: TVec, + /// Pulse index symbol (T). Kept for compatibility with the rest of the + /// pulsifier even though `eval` no longer needs it. + pub pulse_id: Symbol, + /// Pulse size symbol (P). + pub pulse_sym: Symbol, +} + +impl PulseV2Buffer { + /// Axis with nonzero lookback. Panics if zero or more than one axis is + /// buffered — the rewrite assumes a single streaming axis. 
+ pub fn buffered_axis(&self) -> Option<(usize, usize)> { + let mut it = self.lookback.iter().copied().enumerate().filter(|(_, lb)| *lb > 0); + let first = it.next(); + debug_assert!(it.next().is_none(), "PulseV2Buffer expects a single buffered axis"); + first + } +} + +impl Op for PulseV2Buffer { + fn name(&self) -> StaticName { + "PulseV2Buffer".into() + } + + fn info(&self) -> TractResult> { + Ok(vec![match self.buffered_axis() { + Some((ax, lb)) => format!("axis {ax}: lookback {lb}"), + None => "passthrough".to_string(), + }]) + } + + op_as_typed_op!(); +} + +impl EvalOp for PulseV2Buffer { + fn is_stateless(&self) -> bool { + false + } + + fn state( + &self, + _session: &TurnState, + _node_id: usize, + ) -> TractResult>> { + Ok(Some(Box::new(PulseV2BufferState { history: None }))) + } +} + +impl TypedOp for PulseV2Buffer { + as_op!(); + + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + let mut fact = inputs[0].clone(); + if let Some((axis, lookback)) = self.buffered_axis() { + let input_dim = fact.shape[axis].clone(); + fact.shape.set(axis, input_dim + TDim::Val(lookback as i64)); + } + Ok(tvec!(fact)) + } + + /// At declutter time, lower the trivial single-axis case to v1's `Delay`, + /// which has the in-place memmove fast path and NNEF round-tripping. The + /// semantic equivalence is exact: PulseV2Buffer { lookback: […N…] } with + /// a single axis matches Delay { axis, delay: 0, overlap: N } — same + /// state size, same per-pulse output (`input + N`), zero-initialised + /// history on first eval. + /// + /// Only fires when exactly one axis has non-zero lookback. Multi-axis + /// cases (none today, but the data structure leaves room) keep the v2 + /// op until a multi-axis Delay equivalent exists. + fn declutter( + &self, + model: &TypedModel, + node: &TypedNode, + ) -> TractResult> { + let Some((axis, lookback)) = self.buffered_axis() else { + // No buffered axis at all → identity, shunt entirely. 
+ return TypedModelPatch::shunt_one_op(model, node); + }; + let input_fact = model.outlet_fact(node.inputs[0])?; + let delay = tract_pulse_opl::ops::Delay::new_typed(input_fact, axis, 0, lookback); + Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, delay)?)) + } +} + +#[derive(Debug, Clone)] +pub struct PulseV2BufferState { + /// History on the buffered axis, shape `lookback` on that axis and + /// matching the input on all other axes. `None` until first eval, then + /// initialised to zeros. + history: Option, +} + +impl OpStateFreeze for PulseV2BufferState { + fn freeze(&self) -> Box { + unimplemented!("PulseV2BufferState::freeze not yet implemented") + } +} + +impl OpState for PulseV2BufferState { + fn eval( + &mut self, + _session: &mut TurnState, + op: &dyn Op, + inputs: TVec, + ) -> TractResult> { + let op = op.downcast_ref::().unwrap(); + let input = args_1!(inputs); + let input_tensor = input.into_tensor(); + + let Some((axis, lookback)) = op.buffered_axis() else { + return Ok(tvec!(input_tensor.into_tvalue())); + }; + + // Initialise history to zeros on first call. Shape matches the input's + // shape with the buffered axis replaced by `lookback`. + let history = match self.history.take() { + Some(h) => h, + None => { + let mut shape = input_tensor.shape().to_vec(); + shape[axis] = lookback; + Tensor::zero_dt(input_tensor.datum_type(), &shape)? + } + }; + + // Output = concat(history, input) on the buffered axis. + let output = Tensor::stack_tensors(axis, &[history, input_tensor])?; + + // Update history = last `lookback` samples of the output. 
+ let total = output.shape()[axis]; + debug_assert!(total >= lookback); + let new_history = output.slice(axis, total - lookback, total)?.into_tensor(); + self.history = Some(new_history); + + Ok(tvec!(output.into_tvalue())) + } +} diff --git a/pulse/src/v2_conv.rs b/pulse/src/v2_conv.rs new file mode 100644 index 0000000000..f1f6c847e7 --- /dev/null +++ b/pulse/src/v2_conv.rs @@ -0,0 +1,77 @@ +use crate::internal::*; +use crate::v2::{AxisRegion, PulseV2Action, PulseV2Region, PulseV2Symbols, RegionTransform}; +use tract_pulse_opl::tract_core::ops::cnn::MaxPool; +use tract_pulse_opl::tract_core::ops::cnn::{Conv, pools::PoolSpec}; + +/// Compute input regions for any op that has a PoolSpec (Conv, MaxPool, etc.). +/// Extends the streaming axis start backward by the overlap. +fn pool_spec_input_regions( + pool_spec: &PoolSpec, + source_region: &PulseV2Region, + n_inputs: usize, +) -> TractResult> { + let dilations = pool_spec.dilations(); + let strides = pool_spec.strides(); + let kernel_shape = &pool_spec.kernel_shape; + + let geo_axes = + pool_spec.data_format.h_axis()..pool_spec.data_format.h_axis() + kernel_shape.len(); + + let mut axes = source_region.axes.clone(); + let mut overlap_per_axis = tvec![0usize; source_region.rank()]; + for (geo_ix, ax_ix) in geo_axes.enumerate() { + if let Some(AxisRegion::Streaming { start, .. }) = axes.get_mut(ax_ix) { + let kernel_field = (kernel_shape[geo_ix] - 1) * dilations[geo_ix]; + let s = strides[geo_ix]; + let overlap = kernel_field.saturating_sub(s - 1); + let lookback = if s > 1 && overlap > 0 { ((overlap + s - 1) / s) * s } else { overlap }; + overlap_per_axis[ax_ix] = overlap; + if lookback > 0 { + *start = start.clone() - TDim::Val(lookback as i64); + } + } + } + let data_region = PulseV2Region { axes }; + + let mut regions = tvec![Some(data_region)]; + for _ in 1..n_inputs { + regions.push(None); + } + // Pass per-input overlap hints (only for the data input). 
+ let overlaps = tvec![overlap_per_axis]; + Ok(Some(PulseV2Action::InputRegions(regions, Some(overlaps)))) +} + +fn conv_input_regions( + op: &dyn TypedOp, + source_region: &PulseV2Region, + _symbols: &PulseV2Symbols, +) -> TractResult> { + let conv = op.downcast_ref::().unwrap(); + // Padding on streaming axes is decomposed into Pad + Conv(valid) by + // decompose_streaming_padding() before pulsification reaches here. + pool_spec_input_regions(&conv.pool_spec, source_region, 3) +} + +fn maxpool_input_regions( + op: &dyn TypedOp, + source_region: &PulseV2Region, + _symbols: &PulseV2Symbols, +) -> TractResult> { + let pool = op.downcast_ref::().unwrap(); + pool_spec_input_regions(&pool.pool_spec, source_region, 1) +} + +inventory::submit! { + RegionTransform { + type_id: std::any::TypeId::of::(), + func: conv_input_regions, + } +} + +inventory::submit! { + RegionTransform { + type_id: std::any::TypeId::of::(), + func: maxpool_input_regions, + } +} diff --git a/pulse/src/v2_deconv.rs b/pulse/src/v2_deconv.rs new file mode 100644 index 0000000000..25179f33cb --- /dev/null +++ b/pulse/src/v2_deconv.rs @@ -0,0 +1,158 @@ +use std::ops::AddAssign; + +use crate::internal::*; +use crate::v2::{PulseV2Action, PulseV2Region, PulseV2Symbols, RegionTransform}; +use tract_pulse_opl::tract_core::num_traits::Zero; +use tract_pulse_opl::tract_core::ops::cnn::deconv::Deconv; +use tract_pulse_opl::tract_core::ops::{FrozenOpState, OpStateFreeze}; + +fn deconv_transform( + op: &dyn TypedOp, + _source_region: &PulseV2Region, + _symbols: &PulseV2Symbols, +) -> TractResult> { + let deconv = op.downcast_ref::().unwrap(); + let dilations = deconv.pool_spec.dilations(); + let kernel_shape = &deconv.pool_spec.kernel_shape; + + // Overlap on the output = (K-1)*D for each spatial axis. + // For 1D: just one axis. 
+ let axis = deconv.pool_spec.data_format.h_axis(); + let overlap = (kernel_shape[0] - 1) * dilations[0]; + + if overlap == 0 { + return Ok(None); // No overlap, default pass-through is fine. + } + + // Wire the deconv op normally, then append PulseV2DeconvAccum. + Ok(Some(PulseV2Action::WireOpThenPostOp(Box::new(PulseV2DeconvAccum { axis, overlap })))) +} + +inventory::submit! { + RegionTransform { + type_id: std::any::TypeId::of::(), + func: deconv_transform, + } +} + +/// Stateful op that accumulates overlapping deconv output. +/// +/// The deconv produces (input_pulse + overlap) output samples per pulse. +/// This op: +/// 1. Adds the buffer (from previous pulse's tail) to the first `overlap` positions +/// 2. Saves the last `overlap` positions as the new buffer +/// 3. Emits the first (output_len - overlap) positions +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct PulseV2DeconvAccum { + pub axis: usize, + pub overlap: usize, +} + +impl Op for PulseV2DeconvAccum { + fn name(&self) -> StaticName { + "PulseV2DeconvAccum".into() + } + fn info(&self) -> TractResult> { + Ok(vec![format!("axis={} overlap={}", self.axis, self.overlap)]) + } + op_as_typed_op!(); +} + +impl EvalOp for PulseV2DeconvAccum { + fn is_stateless(&self) -> bool { + false + } + + fn state( + &self, + _session: &TurnState, + _node_id: usize, + ) -> TractResult>> { + Ok(Some(Box::new(PulseV2DeconvAccumState { buffer: None }))) + } +} + +impl TypedOp for PulseV2DeconvAccum { + as_op!(); + + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + let mut fact = inputs[0].clone(); + let dim = fact.shape[self.axis].clone() - self.overlap.to_dim(); + fact.shape.set(self.axis, dim); + Ok(tvec!(fact)) + } +} + +#[derive(Debug, Clone)] +struct PulseV2DeconvAccumState { + buffer: Option, +} + +impl OpStateFreeze for PulseV2DeconvAccumState { + fn freeze(&self) -> Box { + unimplemented!("PulseV2DeconvAccumState::freeze") + } +} + +impl OpState for PulseV2DeconvAccumState { + fn eval( + 
&mut self, + _session: &mut TurnState, + op: &dyn Op, + inputs: TVec, + ) -> TractResult> { + let op = op.downcast_ref::().unwrap(); + let input = args_1!(inputs); + let mut input = input.into_tensor(); + let input_len = input.shape()[op.axis]; + + if input_len <= op.overlap { + // Edge case: output is entirely overlap. Accumulate and emit nothing. + if let Some(ref buf) = self.buffer { + dispatch_numbers!(Self::add_buffer(input.datum_type())( + &mut input, buf, op.axis, input_len + ))?; + } + self.buffer = Some(input); + let mut shape = self.buffer.as_ref().unwrap().shape().to_vec(); + shape[op.axis] = 0; + return Ok(tvec!( + Tensor::zero_dt(self.buffer.as_ref().unwrap().datum_type(), &shape)?.into_tvalue() + )); + } + + // Add buffer to first `overlap` positions. + if let Some(ref buf) = self.buffer { + dispatch_numbers!(Self::add_buffer(input.datum_type())( + &mut input, buf, op.axis, op.overlap + ))?; + } + + // Save last `overlap` positions as new buffer. + let emit_len = input_len - op.overlap; + self.buffer = Some(input.slice(op.axis, emit_len, input_len)?.into_tensor()); + + // Emit the first part. 
+ let output = input.slice(op.axis, 0, emit_len)?; + Ok(tvec!(output.into_tvalue())) + } +} + +impl PulseV2DeconvAccumState { + fn add_buffer( + input: &mut Tensor, + buffer: &Tensor, + axis: usize, + count: usize, + ) -> TractResult<()> { + let buf_len = buffer.shape()[axis]; + let add_len = count.min(buf_len); + let mut input_view = input.to_plain_array_view_mut::()?; + let buffer_view = buffer.to_plain_array_view::()?; + let mut target = input_view.slice_axis_mut(tract_ndarray::Axis(axis), (0..add_len).into()); + let source = + buffer_view.slice_axis(tract_ndarray::Axis(axis), (buf_len - add_len..).into()); + target += &source; + Ok(()) + } +} diff --git a/pulse/src/v2_pad.rs b/pulse/src/v2_pad.rs new file mode 100644 index 0000000000..47a56fb144 --- /dev/null +++ b/pulse/src/v2_pad.rs @@ -0,0 +1,154 @@ +use crate::internal::*; +use crate::v2::{PulseV2Action, PulseV2Region, PulseV2Symbols, RegionTransform}; +use tract_pulse_opl::tract_core::ops::array::{Pad, PadMode}; + +fn pad_transform( + op: &dyn TypedOp, + _source_region: &PulseV2Region, + symbols: &PulseV2Symbols, +) -> TractResult> { + let pad = op.downcast_ref::().unwrap(); + let PadMode::Constant(value) = &pad.mode else { + return Ok(None); + }; + Ok(Some(PulseV2Action::ReplaceOp(Box::new(PulseV2Pad { + pads: pad.pads.clone(), + value: value.clone(), + pulse_id: symbols.pulse_id.clone(), + stream_sym: symbols.stream.clone(), + })))) +} + +inventory::submit! 
{
+ RegionTransform {
+ type_id: std::any::TypeId::of::(),
+ func: pad_transform,
+ }
+}
+
+// Per-pulse constant pad. `pulse_id` resolves to this pulse's index T,
+// `stream_sym` to the total stream length S; emission of the before/after
+// padding is scheduled against those two symbols.
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+pub struct PulseV2Pad {
+ pub pads: Vec<(usize, usize)>,
+ pub value: Arc,
+ pub pulse_id: Symbol,
+ pub stream_sym: Symbol,
+}
+
+impl Op for PulseV2Pad {
+ fn name(&self) -> StaticName {
+ "PulseV2Pad".into()
+ }
+ fn info(&self) -> TractResult> {
+ Ok(vec![format!("{:?}", self.pads)])
+ }
+ op_as_typed_op!();
+}
+
+impl EvalOp for PulseV2Pad {
+ fn is_stateless(&self) -> bool {
+ true
+ }
+
+ // Rebuilds each padded axis as [before-pad?][real data][after-pad?]:
+ // before-pad only on pulse 0, after-pad only on pulses past the stream end.
+ fn eval_with_session(
+ &self,
+ _node_id: usize,
+ session: &TurnState,
+ inputs: TVec,
+ ) -> TractResult> {
+ let input = args_1!(inputs);
+ // T = this pulse's index, S = stream length, resolved per turn
+ // (default 0 when unresolved).
+ let t = session.resolved_symbols.get(&self.pulse_id).unwrap_or(0) as usize;
+ let s = session.resolved_symbols.get(&self.stream_sym).unwrap_or(0) as usize;
+
+ let mut result = input.into_tensor();
+ for (axis, &(before, after)) in self.pads.iter().enumerate() {
+ if before == 0 && after == 0 {
+ continue;
+ }
+ let p = result.shape()[axis];
+ // Real (non-padding) payload of this pulse: min(P, max(0, S - T*P)).
+ let real = p.min(s.saturating_sub(t * p));
+ let mut parts: Vec = Vec::new();
+
+ // Before-pad: all at T=0.
+ if t == 0 && before > 0 {
+ let mut shape = result.shape().to_vec();
+ shape[axis] = before;
+ parts.push(Tensor::broadcast_scalar_to_shape(&self.value, &shape)?);
+ }
+
+ // Real data.
+ if real > 0 {
+ parts.push(result.slice(axis, 0, real)?.into_tensor());
+ }
+
+ // After-pad: emitted for positions past S.
+ // This pulse covers input-space positions [t*p, (t+1)*p).
+ // After-pad occupies positions [s, s+after).
+ // Intersection of [(t+1)*p capped to relevant range] with [s, s+after).
+ let pulse_end = t * p + p; // (t+1)*p in input space
+ if after > 0 && pulse_end > s {
+ let after_start = s.max(t * p);
+ let after_end = (s + after).min(pulse_end);
+ let ap_count = after_end.saturating_sub(after_start);
+ if ap_count > 0 {
+ let mut shape = result.shape().to_vec();
+ shape[axis] = ap_count;
+ parts.push(Tensor::broadcast_scalar_to_shape(&self.value, &shape)?);
+ }
+ }
+
+ // Empty pulse (everything outside this pulse's window): emit length 0.
+ result = if parts.is_empty() {
+ let mut shape = result.shape().to_vec();
+ shape[axis] = 0;
+ Tensor::zero_dt(result.datum_type(), &shape)?
+ } else {
+ Tensor::stack_tensors(axis, &parts)?
+ };
+ }
+ Ok(tvec!(result.into_tvalue()))
+ }
+}
+
+impl TypedOp for PulseV2Pad {
+ as_op!();
+
+ // Symbolic per-pulse output size along each padded axis; must agree
+ // exactly with the three-part emission in eval_with_session above.
+ fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> {
+ let mut fact = inputs[0].clone();
+ let t = TDim::Sym(self.pulse_id.clone());
+ let s = TDim::Sym(self.stream_sym.clone());
+ for (axis, &(before, after)) in self.pads.iter().enumerate() {
+ if before == 0 && after == 0 {
+ continue;
+ }
+ let p = fact.shape[axis].clone();
+ // real data = min(P, max(0, S - T*P))
+ let real_data = TDim::Min(vec![
+ p.clone(),
+ TDim::Max(vec![TDim::Val(0), s.clone() - t.clone() * p.clone()]),
+ ]);
+ // before-pad: the whole `before` pad is emitted at T=0 and nothing
+ // afterwards, i.e. `before * Eq(T, 0)` (the Eq term is 1 iff T == 0).
+ let bp =
+ TDim::Val(before as i64) * TDim::Eq(Box::new(t.clone()), Box::new(TDim::Val(0)));
+ // after-pad, incremental formula: cumulative after-pad emitted by
+ // pulses 0..=T-1 is min(after, max(0, T*P - S)); by pulses 0..=T it is
+ // min(after, max(0, (T+1)*P - S)). This pulse's share is the difference.
+ let cum_before = TDim::Max(vec![TDim::Val(0), t.clone() * p.clone() - s.clone()]);
+ let cum_after = TDim::Max(vec![TDim::Val(0), (t.clone() + 1) * p.clone() - s.clone()]);
+ let ap = TDim::Min(vec![TDim::Val(after as i64), cum_after])
+ - TDim::Min(vec![TDim::Val(after as i64), cum_before]);
+ let out_dim = bp + real_data + ap;
+ fact.shape.set(axis, out_dim);
+ }
+ Ok(tvec!(fact))
+ }
+}
diff --git a/pulse/src/v2_slice.rs b/pulse/src/v2_slice.rs
new file mode 100644
index 0000000000..9a70ecae11
--- /dev/null
+++ b/pulse/src/v2_slice.rs
@@ -0,0 +1,98 @@
+/// PulseV2 handling for Slice on the streaming axis.
+///
+/// Under fixed-pulse semantics, Slice on the streaming axis is purely
+/// bookkeeping: per-pulse output shape equals input shape (constant `P`),
+/// the data passes through unchanged. The slice's `start` is the v2-flavoured
+/// equivalent of v1's `delay` — it shifts the output stream's effective
+/// position relative to the input stream — and it's tracked at sink/merge
+/// time, not in per-pulse tensor sizes. The op exists at all so consumers
+/// (e.g. test harness, downstream merges) can read the `start`/`end`
+/// metadata off the graph; it shunts itself out at declutter time.
+///
+/// Slice on a non-streaming axis is left alone — the original `Slice` op
+/// runs normally there.
+use crate::internal::*; +use crate::v2::{AxisRegion, PulseV2Action, PulseV2Region, PulseV2Symbols, RegionTransform}; +use tract_pulse_opl::tract_core::ops::array::Slice; + +fn slice_transform( + op: &dyn TypedOp, + source_region: &PulseV2Region, + _symbols: &PulseV2Symbols, +) -> TractResult> { + let slice = op.downcast_ref::().unwrap(); + let on_streaming_axis = source_region + .axes + .get(slice.axis) + .is_some_and(|a| matches!(a, AxisRegion::Streaming { .. })); + if !on_streaming_axis { + // Non-streaming axis: original Slice op behaves correctly under + // fixed-pulse, no replacement needed. + return Ok(None); + } + Ok(Some(PulseV2Action::ReplaceOp(Box::new(PulseV2Slice { + axis: slice.axis, + start: slice.start.clone(), + end: slice.end.clone(), + })))) +} + +inventory::submit! { + RegionTransform { type_id: std::any::TypeId::of::(), func: slice_transform } +} + +/// Pass-through op on the streaming axis. `start` and `end` are kept as +/// metadata so the harness can account for the delay shift. 
+#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct PulseV2Slice { + pub axis: usize, + pub start: TDim, + pub end: TDim, +} + +impl Op for PulseV2Slice { + fn name(&self) -> StaticName { + "PulseV2Slice".into() + } + + fn info(&self) -> TractResult> { + Ok(vec![format!("axis={} [{}, {})", self.axis, self.start, self.end)]) + } + + op_as_typed_op!(); +} + +impl EvalOp for PulseV2Slice { + fn is_stateless(&self) -> bool { + true + } + + fn eval_with_session( + &self, + _node_id: usize, + _session: &TurnState, + inputs: TVec, + ) -> TractResult> { + Ok(inputs) + } +} + +impl TypedOp for PulseV2Slice { + as_op!(); + + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + Ok(tvec!(inputs[0].clone())) + } + + /// At declutter time the metadata role is over (any consumer that needed + /// to read `start`/`end` already had its chance during pulsification); + /// shunt the op so the resulting graph matches v1's topology — Buffer/ + /// Delay only, no streaming-axis Slice nodes. + fn declutter( + &self, + model: &TypedModel, + node: &TypedNode, + ) -> TractResult> { + TypedModelPatch::shunt_one_op(model, node) + } +}