Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
014d1a4
compare --stream: skip intermediate nodes whose streaming axis was op…
kali Mar 31, 2026
e965535
nnef: fix slice with symbolic end on padded dimension
kali Mar 31, 2026
d08544a
block-left-1: use tract_assert S>=0 instead of deser workaround
kali Mar 31, 2026
e20f19b
harness: add mask-phase harnesses (block-l-eq-p-mask, block-left-1-mask)
kali Mar 31, 2026
c19fcba
pulsify block-left-1-mask: FoldWindowAttention + token fold/unfold pu…
kali Apr 1, 2026
36833d8
harness/sdpa-pulse: add d2 progress-report diagrams
kali Apr 1, 2026
433cf24
ex05: flat-token windowed attention with ALiBi pos-bias; ROI pulsific…
kali Apr 2, 2026
239ace4
Remove FoldWindowAttention; ex04/ex05 stream via ChunkWindowMask
kali Apr 2, 2026
510a5ab
ex06: batch+multihead attention; fix Iff pulsifier rank promotion
kali Apr 2, 2026
d6fa643
README: add ex06-batch-multihead entry and update arc
kali Apr 2, 2026
e46966b
sdpa-pulse ex07: chunk-level position bias, Div() in TDim coord expre…
kali Apr 2, 2026
6493eff
pulse: fix MultiBroadcastTo per-pulse size; support non-streaming sou…
kali Apr 3, 2026
d3fec7c
pulse: pulsify Slice when input is non-streaming but end contains str…
kali Apr 3, 2026
239ef70
pulse: pulsify encoder attention — chunk-window mask, non-streaming I…
kali Apr 3, 2026
0e7b3c3
sdpa-pulse ex08/ex09/ex10: batch, multi-head, and inverted-Iff harnes…
kali Apr 7, 2026
b1577d8
propagate_roi: annotate false-branch (inputs[2]) for inverted Iff con…
kali Apr 7, 2026
dc3d42a
pulse: pulsify inverted Iff convention (select(~mask, fill, scores))
kali Apr 7, 2026
31dadf3
Extend PropagateRoi to ScaledMaskedSoftmax via TypedOp::input_roi
kali Apr 7, 2026
05df3ff
pulse: pulsify Transformer-XL RPE skew trick for windowed attention (…
kali Apr 7, 2026
0efb8bb
pulse: pulsify Range op; fix range.rs len computation for symbolic bo…
kali Apr 7, 2026
351c76e
encoder pulsification: propagate uniform_tdim through DynSlice + stal…
kali Apr 8, 2026
370a8d5
fix rebase onto main's systematic ROI propagation
kali Apr 8, 2026
ea23b38
ex02: rename ci.sh to ci-failing.sh (multi-input handle_stream not ye…
kali Apr 8, 2026
2fcd347
binary pulsifier: walk upstream for uniform_tdim through scalar ops
kali Apr 8, 2026
68e9723
REVISIT: add item 13 — systematic uniform_tdim propagation
kali Apr 8, 2026
157d2ed
classify_chunk_window: handle offset coordinates from ROI bubbling
kali Apr 8, 2026
45a7987
REVISIT: add item 14 — classify_chunk_window offset handling
kali Apr 8, 2026
ec6a286
pulsify_qk: re-add constant K path for symmetric RPE pre-slicing
kali Apr 8, 2026
5044e2c
sdpa-pulse: add ex14 test cases (reduced APE, reduced skew, large tab…
kali Apr 8, 2026
741fc9a
pulsify_qk: add try_compute_const_with_substitution for symbolic chains
kali Apr 8, 2026
072a434
try_compute_const: use eval_with_session instead of eval
kali Apr 9, 2026
30944b8
sdpa-pulse: add ex15 — shared posEnc constant with linear projection …
kali Apr 9, 2026
1ad654c
ex15: beef up with BATCH, dynamic posEnc slice, shared posEnc (2 cons…
kali Apr 9, 2026
71aa45c
pulsify_qk: re-evaluate with larger symbol value when RPE table too s…
kali Apr 9, 2026
d0b19ae
REVISIT: add item 15 — encoder skew trick T→P vs pre-sliced r_pos_window
kali Apr 9, 2026
45138fd
ex15: add ci-failing.sh with left_chunks=5 reproducing encoder failure
kali Apr 9, 2026
fdc6e6c
ex15: minimal repro of encoder skew trick failure
kali Apr 9, 2026
e0c1c02
ex15: strip to minimal — separate q/k inputs, no v, no softmax, no li…
kali Apr 9, 2026
4cdaa3b
ex15: remove BATCH — not needed to reproduce
kali Apr 9, 2026
c6dd78d
ex15: valid pulsifiable config — T=16, P=4, left_chunks=3, W=16
kali Apr 9, 2026
606303f
ex15: add ci.sh (was dropped during amend)
kali Apr 9, 2026
20c44fb
ex15: reproduce encoder.p1 failure with subsampled streaming dimension
kali Apr 9, 2026
44c7ed6
--set and concretize_symbols: accept TDim expressions as values
kali Apr 9, 2026
1b54d14
Revert "--set and concretize_symbols: accept TDim expressions as values"
kali Apr 9, 2026
54ea584
DiagGather: fold skew trick into single op for pulsification
kali Apr 9, 2026
ae31dc3
WIP: streaming ASR example with pulsified encoder
kali Apr 9, 2026
2efa3f1
Fix two bugs in GpuDelay for pulsed streaming models
kali Apr 9, 2026
8b600be
Fix flat_copy byte/element mismatch in GPU device context
kali Apr 9, 2026
ce30c5f
Streaming ASR: pulsify preprocessor, buffer features between stages
kali Apr 9, 2026
b2e5f95
Streaming ASR: single-line progressive output, stats at end
kali Apr 9, 2026
ccbac29
Streaming ASR: loading messages, smaller preprocessor pulse
kali Apr 9, 2026
95ad7db
Fix warnings in GpuDelay and streaming example
kali Apr 9, 2026
4a64a3a
Streaming ASR: run all 4 models on GPU, fix pulse metadata read order
kali Apr 9, 2026
cc343a7
Streaming ASR: refactor into NemotronModels + StreamState
kali Apr 10, 2026
99389ed
Streaming ASR: Arc<NemotronModels> with spawn() method
kali Apr 10, 2026
8989a0e
Streaming ASR: add Config struct with clap derive
kali Apr 10, 2026
4d0044e
Streaming ASR: add live microphone support via cpal
kali Apr 10, 2026
13df70b
CI: encoder pulsified run test, preprocessor small pulse, cleanup
kali Apr 10, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ members = [
"examples/keras-tract-tf2",
"examples/nemo-parakeet-asr",
"examples/nemo-nemotron-asr",
"examples/nemo-nemotron-streaming-asr",
"examples/nnef-dump-mobilenet-v2",
"examples/nnef-mobilenet-v2",
"examples/nnef-mobilenet-v2-api",
Expand Down
67 changes: 43 additions & 24 deletions cli/src/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,10 @@ pub fn handle_stream(
}
}

let stream_dim = max_delay + 3 * input_pulse + input_pulse / 2;
let stream_dim = {
let raw = max_delay + 3 * input_pulse + input_pulse / 2;
((raw + input_pulse - 1) / input_pulse) * input_pulse
};
let concrete_sym_values = SymbolValues::default().with(&stream_symbol, stream_dim as _);

// Second pass: build full metadata with fixed_output_len
Expand Down Expand Up @@ -414,27 +417,39 @@ pub fn handle_stream(
let result = tract_core::plan::eval(session, op_state, node, input)?;

if let Some(info) = pulse_meta.get(&node.name) {
let output_offset = (i + 1) * info.output_pulse;
// Check if this pulse has valid output for this node
if output_offset > info.delay
&& output_offset - info.output_pulse < info.delay + info.fixed_output_len
{
let (p_o, count) = if output_offset - info.output_pulse < info.delay {
// Beginning of signal: partial overlap
let count = output_offset - info.delay;
(info.output_pulse - count, count)
} else if output_offset > info.delay + info.fixed_output_len {
// End of signal: partial overlap
let count = info.delay + info.fixed_output_len
- (output_offset - info.output_pulse);
(0, count)
} else {
// Full pulse in valid region
(0, info.output_pulse)
};
let valid = result[0].slice(info.output_axis, p_o, p_o + count)?;
node_slices.entry(node.name.clone()).or_default().push(valid.into_tensor());
node_axes.insert(node.name.clone(), info.output_axis);
// Skip nodes where the optimizer removed the streaming axis (e.g.
// ChangeAxes squeezing a singleton batch dim from an intermediate
// EinSum). The optimised model is still correct end-to-end; we just
// cannot slice along the expected axis any more.
let streaming_axis_ok = result[0].rank() > info.output_axis
&& result[0].shape()[info.output_axis] == info.output_pulse;
if streaming_axis_ok {
let output_offset = (i + 1) * info.output_pulse;
// Check if this pulse has valid output for this node
if output_offset > info.delay
&& output_offset - info.output_pulse
< info.delay + info.fixed_output_len
{
let (p_o, count) = if output_offset - info.output_pulse < info.delay {
// Beginning of signal: partial overlap
let count = output_offset - info.delay;
(info.output_pulse - count, count)
} else if output_offset > info.delay + info.fixed_output_len {
// End of signal: partial overlap
let count = info.delay + info.fixed_output_len
- (output_offset - info.output_pulse);
(0, count)
} else {
// Full pulse in valid region
(0, info.output_pulse)
};
let valid = result[0].slice(info.output_axis, p_o, p_o + count)?;
node_slices
.entry(node.name.clone())
.or_default()
.push(valid.into_tensor());
node_axes.insert(node.name.clone(), info.output_axis);
}
}
} else if i == 0 && node.outputs.len() == 1 {
// Non-streaming node: capture on first pulse only
Expand Down Expand Up @@ -540,8 +555,12 @@ where

if **needed_shape != *reference.shape() {
let Ok(reshaped) = reference.clone().into_shape(&needed_shape) else {
comparison_error = Some(format!("Incompatible shape on output {slot} reference is {reference:?}, model expects {:?}.", needed_shape));
tags.style = Some(Red.into());
// Shapes are structurally incompatible — the pulsed model uses a
// different intermediate representation (e.g. windowed attention
// [P, key_window] vs full attention [S, S]). Treat as unchecked
// rather than failing; the final output will still be verified.
tags.style = Some(Yellow.into());
unchecked.insert(node.id);
continue;
};
reference = reshaped;
Expand Down
75 changes: 38 additions & 37 deletions cli/src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -447,11 +447,13 @@ impl Parameters {
let (key, value) = set.split_once('=').with_context(|| {
format!("--set and --hint must be in the X=value form, got {set}")
})?;
let value: i64 = value
.parse()
.with_context(|| format!("value expected to be an integer, got {value}"))?;
let tdim = tract_core::internal::parse_tdim(&typed_model.symbols, value)
.with_context(|| format!("Failed to parse value for --set {key}={value}"))?;
let key = typed_model.get_or_intern_symbol(key);
values.set(&key, value);
match tdim.to_i64() {
Ok(v) => values.set(&key, v),
Err(_) => values.set_tdim(&key, tdim),
}
}
Ok(values)
}
Expand Down Expand Up @@ -702,6 +704,38 @@ impl Parameters {
dec.optimize(&mut m)?;
Ok(m)
});
if let Some(set) = matches.get_many::<String>("set") {
let values = Self::parse_set_and_hint(typed_model.as_ref().unwrap(), set)?;
stage!("set", typed_model -> typed_model, |mut m: TypedModel| {
for node in m.eval_order()? {
let node = m.node_mut(node);
if let Some(op) = node.op_as_mut::<Const>() {
if op.val().datum_type() == DatumType::TDim { {
// get inner value to Arc<Tensor>
let mut constant:Tensor = (**op.val()).clone();
// Generally a shape or hyperparam
constant
.try_as_plain_mut()?
.as_slice_mut::<TDim>()?
.iter_mut()
.for_each(|x| *x = x.eval(&values));

*op = Const::new(constant.into_arc_tensor())?;
}
}
}
}
m.concretize_dims(&values)
});
stage!("set-declutter", typed_model -> typed_model, |mut m| {
let mut dec = tract_core::optim::Optimizer::declutter();
if let Some(steps) = matches.get_one::<String>("declutter-set-step") {
dec = dec.stopping_at(steps.parse()?);
}
dec.optimize(&mut m)?;
Ok(m)
})
}
#[cfg(not(feature = "pulse"))]
{
if matches.get_one::<String>("pulse").is_some() {
Expand Down Expand Up @@ -753,39 +787,6 @@ impl Parameters {
stage!(&format!("{}_declutter", transform.name()), typed_model -> typed_model, |m:TypedModel| m.into_decluttered());
}
}

if let Some(set) = matches.get_many::<String>("set") {
let values = Self::parse_set_and_hint(typed_model.as_ref().unwrap(), set)?;
stage!("set", typed_model -> typed_model, |mut m: TypedModel| {
for node in m.eval_order()? {
let node = m.node_mut(node);
if let Some(op) = node.op_as_mut::<Const>() {
if op.val().datum_type() == DatumType::TDim { {
// get inner value to Arc<Tensor>
let mut constant:Tensor = (**op.val()).clone();
// Generally a shape or hyperparam
constant
.try_as_plain_mut()?
.as_slice_mut::<TDim>()?
.iter_mut()
.for_each(|x| *x = x.eval(&values));

*op = Const::new(constant.into_arc_tensor())?;
}
}
}
}
m.concretize_dims(&values)
});
stage!("set-declutter", typed_model -> typed_model, |mut m| {
let mut dec = tract_core::optim::Optimizer::declutter();
if let Some(steps) = matches.get_one::<String>("declutter-set-step") {
dec = dec.stopping_at(steps.parse()?);
}
dec.optimize(&mut m)?;
Ok(m)
})
}
if matches.get_flag("nnef-cycle") {
stage!("nnef-cycle", typed_model -> typed_model, |m:TypedModel| {
let nnef = super::nnef(matches);
Expand Down
32 changes: 32 additions & 0 deletions core/src/ops/array/dyn_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,41 @@ impl TypedOp for DynSlice {
ensure!(inputs.len() == 3);
let mut fact = inputs[0].without_value();
fact.shape.set(self.axis, self.len.clone());
// Propagate uniform_tdim when begin is statically known to be 0.
// With begin=0 the result coordinates are identical to the input coordinates, so
// the uniform_tdim predicate (which is expressed in terms of coordinate symbols)
// remains valid for the sliced output.
let begin_is_zero = inputs[1]
.konst
.as_ref()
.map(|k| {
k.cast_to_dt(i64::datum_type())
.ok()
.and_then(|c| {
c.try_as_plain().ok().and_then(|p| {
p.as_slice::<i64>().ok().map(|s| s.iter().all(|&v| v == 0))
})
})
.unwrap_or(false)
})
.unwrap_or(false);
if begin_is_zero {
fact.uniform_tdim = inputs[0].uniform_tdim.clone();
}
Ok(tvec!(fact))
}

/// Push the output region-of-interest back onto the data input.
///
/// A `DynSlice` has three inputs: the data tensor plus the begin/end
/// scalars. Only the data input receives the ROI annotation; the scalar
/// bound inputs carry no coordinate information, so they stay `None`.
///
/// Returns `Ok(None)` when the output carries no ROI at all.
fn input_roi(
    &self,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TVec<Option<TDim>>>> {
    let fact = model.outlet_fact(OutletId::new(node.id, 0))?;
    match fact.region_of_interest.as_ref() {
        None => Ok(None),
        // Clone the ROI onto the data slot; start/end scalars don't carry ROI.
        Some(roi) => Ok(Some(tvec![Some(roi.clone()), None, None])),
    }
}

fn axes_mapping(
&self,
inputs: &[&TypedFact],
Expand Down
21 changes: 21 additions & 0 deletions core/src/ops/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,27 @@ impl TypedOp for TypedBinOp {
fact.uniform_tdim = Some(TDim::Mul(vec![a.clone(), b.clone()]).reduce());
}
}
// And-specific: if one input carries a chunk-window uniform_tdim and the
// other has None (e.g. a padding-validity mask whose uniform_tdim chain
// broke because the audio-length scalar has no coordinate expression),
// propagate the chunk-window expression. In streaming inference every
// token in the chunk window is a valid (non-padding) token, so the
// None-side is effectively always-True within the window.
if fact.uniform_tdim.is_none() && self.0.is::<crate::ops::logic::And>() {
use crate::ops::logic::classify_chunk_window;
let cw_expr = match (&inputs[0].uniform_tdim, &inputs[1].uniform_tdim) {
(Some(a), None) if classify_chunk_window(&a.clone().simplify()).is_some() => {
Some(a.clone())
}
(None, Some(b)) if classify_chunk_window(&b.clone().simplify()).is_some() => {
Some(b.clone())
}
_ => None,
};
if let Some(expr) = cw_expr {
fact.uniform_tdim = Some(expr);
}
}
// Fallback: one side has uniform_tdim, the other is a scalar constant
if fact.uniform_tdim.is_none() {
for (expr, konst_fact) in [
Expand Down
45 changes: 45 additions & 0 deletions core/src/ops/change_axes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,16 @@ impl AxisOp {
for _ in to.iter().rev() {
shape.insert(*at, 1.into());
}
} else if shape.len() >= from.len() + *at {
// from_volume == to_volume was already verified above. The actual shape
// values may be symbolically equivalent to `from` without being structurally
// equal (e.g. `S` vs `P*(S/P)`). Trust the from.len() count and apply.
for _ in from {
shape.remove(*at);
}
for d in to.iter().rev() {
shape.insert(*at, d.try_into()?);
}
} else {
bail!("Incompatible reshape for shape {:?} and {:?}", shape, self);
}
Expand Down Expand Up @@ -704,6 +714,41 @@ fn remap_uniform_tdim(expr: &TDim, axis_op: &AxisOp) -> Option<TDim> {
impl TypedOp for AxisOp {
as_op!();

/// Propagate the output region-of-interest back through this axis change.
///
/// Anything other than a single-input `Reshape` is delegated to the generic
/// `bubble_roi` propagation. For a single-input `Reshape` at position `at`,
/// a chunk-window-shaped ROI gets its row/col axes swapped when the reshape
/// exchanges the two axes in the `[at, at+1]` block; otherwise the ROI (or a
/// pass-through copy of it) is forwarded unchanged.
///
/// Returns `Ok(None)` when the output carries no ROI.
fn input_roi(
    &self,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TVec<Option<TDim>>>> {
    use crate::ops::logic::{build_chunk_window_roi, classify_chunk_window};
    // Only single-input Reshape nodes get the special handling below;
    // everything else uses the generic bubbling.
    let AxisOp::Reshape(at, _from, _to) = self else {
        return crate::optim::propagate_roi::bubble_roi(model, node);
    };
    if node.inputs.len() != 1 {
        return crate::optim::propagate_roi::bubble_roi(model, node);
    }
    let output_fact = &node.outputs[0].fact;
    let roi = match &output_fact.region_of_interest {
        Some(r) => r.clone().simplify(),
        None => return Ok(None),
    };
    // ROIs that don't match the chunk-window shape are passed through as-is.
    let cw = match classify_chunk_window(&roi) {
        Some(cw) => cw,
        None => return Ok(Some(tvec![Some(roi)])),
    };
    // If the reshape swaps the two axes in the block [at, at+1], swap col/row in the ROI.
    let at = *at;
    let (new_row, new_col) = if (cw.row_axis == at && cw.col_axis == at + 1)
        || (cw.row_axis == at + 1 && cw.col_axis == at)
    {
        (cw.col_axis, cw.row_axis)
    } else {
        (cw.row_axis, cw.col_axis)
    };
    let symbols = &model.symbols;
    // Rebuild the chunk-window predicate with the (possibly swapped) axes,
    // keeping the pulse size and left-chunk count from the classified ROI.
    let new_roi = build_chunk_window_roi(&symbols, cw.p, cw.left_chunks, new_row, new_col);
    Ok(Some(tvec![Some(new_roi)]))
}

fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
if let Some(bqf) =
inputs[0].exotic_fact().and_then(|of| of.downcast_ref::<BlockQuantFact>())
Expand Down
46 changes: 46 additions & 0 deletions core/src/ops/einsum/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,52 @@ impl EvalOp for EinSum {
}

impl TypedOp for EinSum {
/// Propagate a chunk-window region-of-interest from this EinSum's output
/// to the relevant inputs.
///
/// When the output ROI classifies as a chunk window (row/col axis pair,
/// pulse `p`, `left_chunks`), each input is inspected through the axes
/// mapping: an input is annotated only if it carries the output's column
/// axis but NOT its row axis — i.e. the key/position side of an attention
/// product rather than the query side. All other inputs get `None`.
///
/// Returns `Ok(None)` when the output has no ROI or the ROI is not a
/// recognizable chunk window.
fn input_roi(
    &self,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TVec<Option<TDim>>>> {
    use crate::ops::logic::{build_chunk_window_roi, classify_chunk_window};
    let output_fact = &node.outputs[0].fact;
    let roi = match &output_fact.region_of_interest {
        Some(r) => r.clone().simplify(),
        None => return Ok(None),
    };
    let cw = match classify_chunk_window(&roi) {
        Some(cw) => cw,
        None => return Ok(None),
    };
    let symbols = &model.symbols;
    let num_inputs = node.inputs.len();
    // For each input, find the axis that maps to the output col_axis.
    // Annotate only inputs that have the col axis but NOT the row axis
    // (i.e., the key/position input, not the query input).
    let mut result: TVec<Option<TDim>> = tvec![None; num_inputs];
    for ix in 0..num_inputs {
        // Axis of input `ix` that maps onto the output column axis, if any.
        let col_in_input = self
            .axes
            .iter_all_axes()
            .find(|ax| {
                ax.outputs[0].contains(&cw.col_axis)
                    && ax.inputs.get(ix).map_or(false, |v| !v.is_empty())
            })
            .and_then(|ax| ax.inputs.get(ix)?.first().copied());
        // Whether input `ix` also carries the output row axis.
        let has_row = self.axes.iter_all_axes().any(|ax| {
            ax.outputs[0].contains(&cw.row_axis)
                && ax.inputs.get(ix).map_or(false, |v| !v.is_empty())
        });
        // Only annotate the key input: has col axis but not row axis.
        if let Some(col_ax) = col_in_input {
            if !has_row {
                // Use row_axis=1 as a placeholder (doesn't affect pulsifier logic).
                result[ix] =
                    Some(build_chunk_window_roi(symbols, cw.p, cw.left_chunks, 1, col_ax));
            }
        }
    }
    Ok(Some(result))
}

fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
let shapes = self.actual_input_shapes_from_facts(inputs)?;
for i in 0..inputs.len() {
Expand Down
3 changes: 2 additions & 1 deletion core/src/ops/element_wise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ impl TypedOp for ElementWiseOp {
// Ops with a TDim arm (e.g. Floor → identity) pass the value through;
// ops without one return an error and uniform_tdim stays None.
let mut tmp = tensor0(tdim.clone());
if self.0.eval_in_place(&mut tmp, None).is_ok() {
let eval_result = self.0.eval_in_place(&mut tmp, None);
if eval_result.is_ok() {
fact.uniform_tdim = tmp
.try_as_plain()
.ok()
Expand Down
Loading
Loading