diff --git a/core/src/lib.rs b/core/src/lib.rs index 3cfd830932..8994faa1cd 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -45,13 +45,13 @@ //! tract-tensorflow or tract-onnx crates. //! -#[cfg(feature="blas")] -extern crate cblas; -#[cfg(feature="accelerate")] +#[cfg(feature = "accelerate")] extern crate accelerate_src; -#[cfg(feature="blis")] +#[cfg(feature = "blis")] extern crate blis_src; -#[cfg(feature="openblas")] +#[cfg(feature = "blas")] +extern crate cblas; +#[cfg(feature = "openblas")] extern crate openblas_src; extern crate bit_set; @@ -81,8 +81,8 @@ pub mod ops; pub mod axes; pub mod broadcast; -pub mod framework; pub mod floats; +pub mod framework; pub mod model; pub mod optim; pub mod plan; @@ -98,7 +98,7 @@ mod late_bind; pub mod prelude { pub use crate::framework::Framework; pub use crate::model::*; - pub use crate::plan::{SimplePlan, SimpleState, PlanOptions}; + pub use crate::plan::{PlanOptions, SimplePlan, SimpleState}; pub use crate::value::{IntoTValue, TValue}; pub use std::sync::Arc; pub use tract_data::prelude::*; @@ -118,8 +118,9 @@ pub mod internal { pub use crate::ops::change_axes::*; pub use crate::ops::element_wise::ElementWiseMiniOp; pub use crate::ops::{Cost, EvalOp, FrozenOpState, Op, OpState, Validation}; - pub use crate::plan::{ SessionState, SessionStateHandler }; + pub use crate::plan::{SessionState, SessionStateHandler}; pub use crate::prelude::*; + pub use crate::runtime::{DefaultRuntime, Runnable, Runtime, State}; pub use dims; pub use downcast_rs as tract_downcast_rs; pub use std::borrow::Cow; @@ -131,10 +132,9 @@ pub mod internal { dispatch_copy, dispatch_datum, dispatch_datum_by_size, dispatch_floatlike, dispatch_numbers, }; pub use tvec; - pub use {args_1, args_2, args_3, args_4, args_5, args_6, args_7, args_8}; + pub use {args_1, args_2, args_3, args_4, args_5, args_6, args_7, args_8, args_9}; pub use {as_op, impl_op_same_as, not_a_typed_op, op_as_typed_op}; pub use {bin_to_super_type, element_wise, 
element_wise_oop}; - pub use crate::runtime::{Runtime, Runnable, State, DefaultRuntime}; } #[cfg(test)] diff --git a/core/src/ops/macros.rs b/core/src/ops/macros.rs index 3a9794da14..0ee4f9a19f 100644 --- a/core/src/ops/macros.rs +++ b/core/src/ops/macros.rs @@ -173,6 +173,30 @@ macro_rules! args_8 { }}; } +#[allow(unused_macros)] +#[macro_export] +macro_rules! args_9 { + ($inputs:expr) => {{ + let mut inputs = $inputs; + if inputs.len() != 9 { + $crate::internal::bail!("Expected 9 arg, got {:?}", inputs) + } + inputs.reverse(); + let result = ( + inputs.pop().unwrap(), + inputs.pop().unwrap(), + inputs.pop().unwrap(), + inputs.pop().unwrap(), + inputs.pop().unwrap(), + inputs.pop().unwrap(), + inputs.pop().unwrap(), + inputs.pop().unwrap(), + inputs.pop().unwrap(), + ); + result + }}; +} + #[macro_export] macro_rules! impl_op_same_as { () => { @@ -233,4 +257,3 @@ macro_rules! trivial_op_state_freeeze { } }; } - diff --git a/core/src/ops/mod.rs b/core/src/ops/mod.rs index 4760d1207b..f4a95d2636 100644 --- a/core/src/ops/mod.rs +++ b/core/src/ops/mod.rs @@ -32,6 +32,7 @@ pub mod scan; pub mod source; pub mod submodel; pub mod unimpl; +pub mod vptq; pub use downsample::Downsample; pub use memory::*; diff --git a/core/src/ops/vptq.rs b/core/src/ops/vptq.rs new file mode 100644 index 0000000000..98a9f6900a --- /dev/null +++ b/core/src/ops/vptq.rs @@ -0,0 +1,334 @@ +use std::collections::HashSet; + +use tract_data::itertools::Itertools; +use tract_ndarray::Array1; + +use crate::{ + internal::*, + ops::{ + array::{Gather, Topk}, + math::shift_left, + }, +}; +use tract_linalg::{mmm::FusedSpec, ops}; + +#[derive(Debug, Clone)] +pub struct VPTQGemm { + pub vector_len: usize, + pub in_features: usize, + pub out_features: usize, + pub is_indice_packed: bool, + pub group_size: usize, + pub outlier_size: usize, +} + +impl Op for VPTQGemm { + fn name(&self) -> Cow { + "VPTQGemm".into() + } + + op_as_typed_op!(); +} +fn shift_right_zero_and_1(input: TValue, shift_value: TValue) 
-> TractResult { + let input = input.to_array_view::()?; + let shift_value = shift_value.to_array_view::()?; + let out_shape = crate::broadcast::multi_broadcast(&[input.shape(), shift_value.shape()])?; + let mut out = unsafe { Tensor::uninitialized_dt(DatumType::I32, &out_shape)? }; + crate::ndarray::Zip::from(out.to_array_view_mut::()?) + .and_broadcast(input) + .and_broadcast(shift_value) + .for_each(|c, a, b| *c = a.checked_shr(*b as u32).unwrap_or(0i32) & 1i32); + Ok(out.into_tvalue()) +} + +fn gather_all_elements(centroids: &Tensor, indices: &Tensor) -> TractResult { + let &[_, _, vlen] = centroids.shape() else { bail!("wrong centroids shape") }; + let &[b, n_indices_x, n_indices_y] = indices.shape() else { + bail!("wrong indice shape {:?}", indices.shape()) + }; + let mut out = unsafe { + Tensor::uninitialized_dt(centroids.datum_type(), &[b, n_indices_x * n_indices_y, vlen])? + }; + indices.to_array_view::()?.iter().enumerate().for_each(|(idx, idx_val)| { + let idx_val = *idx_val as usize; + out.assign_slice(idx..idx + 1, centroids, idx_val..(idx_val + 1), 1).unwrap(); + }); + Ok(out) +} + +impl VPTQGemm { + /// decompression of indexes + fn eval_unpack_index_tensor( + &self, + pack_tensor: Tensor, + index_bits: usize, + num_elements: usize, + ) -> TractResult { + let wf = tensor1(&(0..32i32).collect_vec()).into_shape(&[1, 1, 1, 32])?; + + let pack_tensor_shape = pack_tensor.shape().to_vec(); + + let mut pre_shift_pack_tensor_shape = pack_tensor_shape.clone(); + pre_shift_pack_tensor_shape.push(1); + + let mut out = shift_right_zero_and_1( + pack_tensor.into_shape(&pre_shift_pack_tensor_shape)?.into(), + wf.into(), + )?; + + let mut post_shift_pack_tensor_shape = pack_tensor_shape.clone(); + let pval = post_shift_pack_tensor_shape.pop().unwrap(); + post_shift_pack_tensor_shape.push(32 * pval); + out = out.into_tensor().into_shape(&post_shift_pack_tensor_shape)?.into_tvalue(); + + let pad_size = (pack_tensor_shape.last().unwrap_or(&0) * 32) % (index_bits * 
num_elements); + if pad_size > 0 { + let axis = out.rank() - 1; + let end = out.shape()[axis] - pad_size; + out = out.slice(axis, 0, end)?.into(); + } + + let mut post_pad_pack_tensor_shape = pack_tensor_shape.clone(); + post_pad_pack_tensor_shape.pop(); + let auto = out.shape().last().unwrap() / index_bits; + post_pad_pack_tensor_shape.push(auto); + post_pad_pack_tensor_shape.push(index_bits); + out = out.into_tensor().into_shape(&post_pad_pack_tensor_shape)?.into(); + + let wf1 = Tensor::from( + Array1::from_iter(0..(index_bits as i32)).to_shape([1, 1, 1, index_bits])?.into_owned(), + ); + + out = shift_left().eval(tvec!(out, wf1.into()))?.pop().unwrap(); + + let axis = out.rank() - 1; + out = out + .into_tensor() + .to_array_view_mut::()? + .sum_axis(tract_ndarray::Axis(axis)) + .into_tvalue(); + + let unpack_indice = out.cast_to_dt(DatumType::I32)?; + + let mut indices = + unsafe { Tensor::uninitialized_dt(DatumType::I32, unpack_indice.shape())? }; + + crate::ndarray::Zip::from(&mut indices.to_array_view_mut::()?) + .and_broadcast(unpack_indice.to_array_view::()?) + .for_each(|indice, upack_indice| *indice = upack_indice & ((1 << index_bits) - 1)); + indices = indices.slice(2, 0, num_elements)?; + + Ok(indices) + } + + fn eval_extract_from_vector_quant( + &self, + centroids: Tensor, + indices: Tensor, + group_size: usize, + ) -> TractResult { + let mut indices = indices.clone(); + let [num_codebooks, num_centroids, vector_len] = *centroids.shape() else { + unimplemented!("unexected centroid shape ?") + }; + + if self.is_indice_packed { + let index_bits = (num_centroids as f32).log2().ceil() as usize; + indices = self.eval_unpack_index_tensor(indices, index_bits, group_size)?; + } + + let selected_centroids = gather_all_elements(&centroids, &indices)?; + + let remain = selected_centroids.volume() / (num_codebooks * group_size * vector_len); + + let mut qweight = selected_centroids + .into_shape(&[num_codebooks, remain, group_size, vector_len])?
+ .permute_axes(&[0, 1, 3, 2])? // NOTE: costly in tract (applied in memory) + .into_shape(&[num_codebooks, remain * vector_len, group_size])? + .permute_axes(&[1, 0, 2])? // NOTE: costly in tract (applied in memory) + .into_shape(&[vector_len * remain, num_codebooks * group_size])?; + + let dim0 = qweight.shape()[0]; + let padding = (-(self.out_features as i16)).wrapping_rem_euclid(vector_len as i16); + if padding > 0 { + let end = dim0 as i16 - padding; + qweight = qweight.slice(0, 0, end as usize)?; + } + Ok(qweight) + } +} + +impl EvalOp for VPTQGemm { + fn is_stateless(&self) -> bool { + true + } + + fn eval(&self, inputs: TVec) -> TractResult> { + let ( + input, + indices, + centroids, + outlier_indices, + outlier_centroids, + perm, + weight_scale, + weight_bias, + bias, + ) = args_9!(inputs); + let indices = indices.into_tensor(); + let mut centroids = centroids.into_tensor(); + let outlier_indices = outlier_indices.into_tensor(); + let mut outlier_centroids = outlier_centroids.into_tensor(); + let perm = perm.into_tensor(); + let weight_scale = weight_scale.into_tensor(); + let weight_bias = weight_bias.into_tensor(); + let bias = bias.into_tensor(); + + if weight_scale.len() > 1 { + unimplemented!("'weight scale' for vptq not yet supported !"); + } + if weight_bias.len() > 1 { + unimplemented!("'weight bias' for vptq not yet supported !"); + } + let enable_norm = weight_scale.len() > 1 && weight_bias.len() > 1; + if bias.len() > 1 { + unimplemented!("'bias' for vptq not yet supported !"); + } + assert!([2, 3].contains(&input.rank())); + assert!(input.datum_type().is_float()); + + assert_eq!(indices.rank(), 3); + assert_eq!(indices.datum_type(), DatumType::I32); + assert_eq!(centroids.rank(), 3); + assert!(centroids.datum_type().is_float()); + + let enable_outlier = outlier_indices.len() > 0; + if enable_outlier { + assert_eq!(outlier_indices.rank(), 3); + assert_eq!(outlier_indices.datum_type(), DatumType::I32); + assert_eq!(outlier_centroids.rank(), 3); + 
assert!(outlier_centroids.datum_type().is_float()); + } + let _fdtypes = [input.datum_type(), centroids.datum_type(), outlier_centroids.datum_type()]; + let fdtypes = HashSet::from(_fdtypes); + if fdtypes.len() != 1 { + log::warn!("force cast centroids to be same type as input: {:?}", input.datum_type()); + centroids = centroids.cast_to_dt(input.datum_type())?.into_owned(); + outlier_centroids = outlier_centroids.cast_to_dt(input.datum_type())?.into_owned(); + } + let _fdtypes = [input.datum_type(), centroids.datum_type(), outlier_centroids.datum_type()]; + let fdtypes = HashSet::from(_fdtypes); + assert!(fdtypes.len() == 1, "mixed dtypes: {_fdtypes:?}"); + + let mut qweight = + self.eval_extract_from_vector_quant(centroids, indices, self.group_size)?; + if enable_outlier { + // same as centroids to qweights except for outlier + let outlier_qweight = self.eval_extract_from_vector_quant( + outlier_centroids, + outlier_indices, + self.outlier_size, + )?; + + qweight = + Tensor::stack_tensors(1, &[&outlier_qweight, &qweight]).with_context(|| { + format!( + "outlier.shape:{:?}, main.shape:{:?}", + &outlier_qweight.shape(), + &qweight.shape() + ) + })?; + } + + let enable_perm = perm.len() > 1; + if enable_perm { + let axis = 0; + let dim = perm.shape()[0]; + let top_k = Topk { axis, largest: false, fallback_k: dim.into() }; + let invert_perm = + top_k.eval(tvec!(perm.into_tvalue(), tensor0(dim as u16).into()))?.remove(1); + // TODO: manage case with quant dim == 'in' ? + // if self.vector_quant_dim == "in": + // assert True, "Not implemented" + // qweight = qweight[invert_perm, :] + + let perm_gather_axis = 1; + let gather_perm = Gather { axis: perm_gather_axis }; + qweight = gather_perm + .eval(tvec!(qweight.into(), invert_perm))? + .pop() + .context("apply gather to permutation") + .unwrap() + .into_tensor(); + } + + let data_type = *fdtypes.iter().next().unwrap(); + + if enable_norm { + qweight = (qweight.into_array::()? * weight_scale.to_array_view::()? 
+ + weight_bias.to_array_view::()?) + .into_tensor(); + } + // NOTE: next steps is fast matmul equivalent of { + // let einsum_op = EinSum::new("ik,kj->ij".parse()?, f32::datum_type()); + // einsum_op.eval(tvec!(input, qweight.permute_axes(&[1, 0])?.into_tvalue())) + // } + qweight = qweight.permute_axes(&[1, 0])?; + let op = ops(); + let ishape = input.shape(); + + let &n = qweight.shape().last().unwrap(); + + let (m, k, out_shape) = match *ishape { + [m, k] => (m, k, vec![m, n]), + [b, m, k] => (m, k, vec![b, m, n]), + _ => { + bail!("unexpected rank {:?}", ishape.len()) + } + }; + + let input_offset = input.rank() - 2; + let weight_offset = qweight.rank() - 2; + + /* this would be better for Intel where there is no f16 support, but the kernel selection + APIs are not up to the task (yet) + + let acc_type = if tract_linalg::has_fp16() { + f16::datum_type() + } else { + f32::datum_type() + }; + + */ + let mmm = op.mmm(data_type, Some(m), Some(k), Some(n)).unwrap(); + let (pack_a, pack_b) = &mmm.packings()[0]; + + let cstore = unsafe { mmm.c_view(input_offset, 1 + input_offset) }; + + let a = pack_a.prepare_tensor(&input, 1 + input_offset, input_offset)?; + let b = pack_b.prepare_tensor(&qweight, weight_offset, 1 + weight_offset)?; + unsafe { + let out = Tensor::uninitialized_dt(data_type, &out_shape)?; + let non_linear = &[ + FusedSpec::AddMatMul { + a: tract_linalg::mmm::AsInputValue::Owned(a), + b: tract_linalg::mmm::AsInputValue::Owned(b), + packing: 0, + }, + FusedSpec::Store(cstore.wrap(&out.view())), + ]; + mmm.run(m, n, non_linear)?; + Ok(tvec!(out.into_tvalue())) + } + } +} + +impl TypedOp for VPTQGemm { + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + let mut tfact = inputs[0].without_value(); + tfact.shape.set(tfact.rank() - 1, self.out_features.into()); + Ok(tvec!(tfact)) + } + + as_op!(); +} diff --git a/harness/nnef-test-cases/vptq-basic/io.npz b/harness/nnef-test-cases/vptq-basic/io.npz new file mode 100644 index 
0000000000..cfdaabb905 Binary files /dev/null and b/harness/nnef-test-cases/vptq-basic/io.npz differ diff --git a/harness/nnef-test-cases/vptq-basic/model.nnef.tgz b/harness/nnef-test-cases/vptq-basic/model.nnef.tgz new file mode 100644 index 0000000000..32afbde349 Binary files /dev/null and b/harness/nnef-test-cases/vptq-basic/model.nnef.tgz differ diff --git a/harness/nnef-test-cases/vptq-basic/runme.sh b/harness/nnef-test-cases/vptq-basic/runme.sh new file mode 100755 index 0000000000..828dcb90de --- /dev/null +++ b/harness/nnef-test-cases/vptq-basic/runme.sh @@ -0,0 +1,8 @@ +#!/bin/sh + +cd $(dirname $0) +set -ex + +: ${TRACT_RUN:=cargo run -p tract $CARGO_OPTS --} + +$TRACT_RUN ./model.nnef.tgz --nnef-tract-core run --input-from-bundle ./io.npz --assert-output-bundle ./io.npz -q diff --git a/harness/nnef-test-cases/vptq-with-perm/io.npz b/harness/nnef-test-cases/vptq-with-perm/io.npz new file mode 100644 index 0000000000..8c32731c49 Binary files /dev/null and b/harness/nnef-test-cases/vptq-with-perm/io.npz differ diff --git a/harness/nnef-test-cases/vptq-with-perm/model.nnef.tgz b/harness/nnef-test-cases/vptq-with-perm/model.nnef.tgz new file mode 100644 index 0000000000..21fc4f1483 Binary files /dev/null and b/harness/nnef-test-cases/vptq-with-perm/model.nnef.tgz differ diff --git a/harness/nnef-test-cases/vptq-with-perm/runme.sh b/harness/nnef-test-cases/vptq-with-perm/runme.sh new file mode 100755 index 0000000000..828dcb90de --- /dev/null +++ b/harness/nnef-test-cases/vptq-with-perm/runme.sh @@ -0,0 +1,8 @@ +#!/bin/sh + +cd $(dirname $0) +set -ex + +: ${TRACT_RUN:=cargo run -p tract $CARGO_OPTS --} + +$TRACT_RUN ./model.nnef.tgz --nnef-tract-core run --input-from-bundle ./io.npz --assert-output-bundle ./io.npz -q diff --git a/linalg/src/frame/mmm/scratch.rs b/linalg/src/frame/mmm/scratch.rs index ac34fa6da4..80e1fb0020 100644 --- a/linalg/src/frame/mmm/scratch.rs +++ b/linalg/src/frame/mmm/scratch.rs @@ -47,7 +47,8 @@ impl TLSScratch { 
ker_specs.extend_from_slice(&scratch.ker_specs); unsafe { - self.blob.ensure_size_and_align(scratch.blob_size, scratch.blob_align); + self.blob + .ensure_size_and_align(scratch.blob_size, scratch.blob_align); for LocDependant { loc, ker_spec, .. } in &scratch.loc_dependant { #[allow(clippy::single_match)] @@ -121,7 +122,13 @@ impl ScratchSpaceImpl { let mut offset = 0; let mut align = std::mem::size_of::<*const ()>(); fn ld(spec: usize, uspec: usize, loc: usize) -> LocDependant { - LocDependant { spec, ker_spec: uspec, loc, buffer_a: None, buffer_b: None } + LocDependant { + spec, + ker_spec: uspec, + loc, + buffer_a: None, + buffer_b: None, + } } for (ix, spec) in specs.iter().enumerate() { offset = offset.next_multiple_of(&align); @@ -138,25 +145,35 @@ impl ScratchSpaceImpl { FS::RoundingShiftRight(s, rp) => FKS::RoundingShiftRight(*s, *rp), FS::QScale(s, rp, m) => FKS::QScale(*s, *rp, *m), FS::BinPerRow(_, _) => { - self.loc_dependant.push(ld(ix, self.ker_specs.len(), offset)); + self.loc_dependant + .push(ld(ix, self.ker_specs.len(), offset)); offset += TI::datum_type().size_of() * ker.mr(); FusedKerSpec::Done } FS::BinPerCol(_, _) => { - self.loc_dependant.push(ld(ix, self.ker_specs.len(), offset)); + self.loc_dependant + .push(ld(ix, self.ker_specs.len(), offset)); offset += TI::datum_type().size_of() * ker.nr(); FusedKerSpec::Done } FS::AddRowColProducts(_, _) => { - self.loc_dependant.push(ld(ix, self.ker_specs.len(), offset)); + self.loc_dependant + .push(ld(ix, self.ker_specs.len(), offset)); offset += TI::datum_type().size_of() * (ker.mr() + ker.nr()); FusedKerSpec::Done } - FS::Store(_) | FS::AddUnicast(_) => { - self.loc_dependant.push(ld(ix, self.ker_specs.len(), offset)); + FS::AddUnicast(_) => { + self.loc_dependant + .push(ld(ix, self.ker_specs.len(), offset)); offset += TI::datum_type().size_of() * ker.mr() * ker.nr(); FusedKerSpec::Done } + FS::Store(store) => { + self.loc_dependant + .push(ld(ix, self.ker_specs.len(), offset)); + offset += 
store.item_size * ker.mr() * ker.nr(); + FusedKerSpec::Done + } FS::LeakyRelu(t) => FKS::LeakyRelu(*t.to_scalar()?), FS::AddMatMul { a, b, packing } => { let mut ld = ld(ix, self.ker_specs.len(), offset); @@ -206,10 +223,16 @@ impl ScratchSpaceImpl { let err = ker.kernel(tls.ker_specs()); debug_assert_eq!(err, 0, "Kernel return error {err}"); } else { - let remnant_down = - if down < self.valid_down_tiles { ker.mr() } else { self.remnant_down }; - let remnant_right = - if right < self.valid_right_tiles { ker.nr() } else { self.remnant_right }; + let remnant_down = if down < self.valid_down_tiles { + ker.mr() + } else { + self.remnant_down + }; + let remnant_right = if right < self.valid_right_tiles { + ker.nr() + } else { + self.remnant_right + }; self.for_border_tile(ker, specs, tls, down, right, remnant_down, remnant_right)?; let err = ker.kernel(tls.ker_specs()); debug_assert_eq!(err, 0, "Kernel return error {err}"); @@ -230,9 +253,20 @@ impl ScratchSpaceImpl { ) -> TractResult<()> { use FusedKerSpec as FKS; use FusedSpec as FS; - let ScratchSpaceImpl { ker_specs, loc_dependant, .. } = self; + let ScratchSpaceImpl { + ker_specs, + loc_dependant, + .. 
+ } = self; debug_assert!(specs.len() + 2 == ker_specs.len()); - for LocDependant { spec, ker_spec, loc, buffer_a, buffer_b } in loc_dependant { + for LocDependant { + spec, + ker_spec, + loc, + buffer_a, + buffer_b, + } in loc_dependant + { let spec = specs.get_unchecked(*spec); let it = match spec { FS::BinPerRow(v, op) => { @@ -265,8 +299,9 @@ impl ScratchSpaceImpl { FS::AddUnicast(store) => FKS::AddUnicast(store.tile_c(down, right)), FS::Store(c_store) => FKS::Store(c_store.tile_c(down, right)), FS::AddMatMul { a, b, packing } => { - let scratch = - (tls.blob.as_mut_ptr().add(*loc) as *mut AddMatMulTemp).as_mut().unwrap(); + let scratch = (tls.blob.as_mut_ptr().add(*loc) as *mut AddMatMulTemp) + .as_mut() + .unwrap(); if scratch.panel_a_id != down { scratch.ptr_a = a.panel_bytes(down, buffer_a.map(|o| tls.blob.as_mut_ptr().add(o)))?; @@ -305,7 +340,14 @@ impl ScratchSpaceImpl { ) -> TractResult<()> { use FusedKerSpec as FKS; use FusedSpec as FS; - for LocDependant { spec, ker_spec: uspec, loc, buffer_a, buffer_b } in &self.loc_dependant { + for LocDependant { + spec, + ker_spec: uspec, + loc, + buffer_a, + buffer_b, + } in &self.loc_dependant + { let loc = tls.blob.as_mut_ptr().add(*loc); let spec = specs.get_unchecked(*spec); let it = match spec { @@ -442,13 +484,13 @@ impl ScratchSpaceImpl { FS::AddMatMul { a, b, packing } => { let scratch = (loc as *mut AddMatMulTemp).as_mut().unwrap(); if scratch.panel_a_id != down { - scratch.ptr_a = a - .panel_bytes(down, buffer_a.map(|o| tls.blob.as_mut_ptr().add(o)))?; + scratch.ptr_a = + a.panel_bytes(down, buffer_a.map(|o| tls.blob.as_mut_ptr().add(o)))?; scratch.panel_a_id = down; } if scratch.panel_b_id != right { - scratch.ptr_b = b - .panel_bytes(right, buffer_b.map(|o| tls.blob.as_mut_ptr().add(o)))?; + scratch.ptr_b = + b.panel_bytes(right, buffer_b.map(|o| tls.blob.as_mut_ptr().add(o)))?; scratch.panel_b_id = right; } FKS::AddMatMul { @@ -482,7 +524,12 @@ impl ScratchSpaceImpl { where TI: LADatum, { - for 
LocDependant { spec, ker_spec: uspec, .. } in self.loc_dependant.iter() { + for LocDependant { + spec, + ker_spec: uspec, + .. + } in self.loc_dependant.iter() + { let spec = specs.get_unchecked(*spec); let ker_spec = tls.ker_specs::().get_unchecked(*uspec); if let (FusedSpec::Store(c_store), FusedKerSpec::Store(tmp)) = (spec, ker_spec) { diff --git a/nnef/src/ops/core.rs b/nnef/src/ops/core.rs index dec4aeec94..2a607e0f76 100644 --- a/nnef/src/ops/core.rs +++ b/nnef/src/ops/core.rs @@ -27,6 +27,7 @@ mod store; mod submodel; mod topk; mod trilu; +mod vptq; pub fn register(registry: &mut Registry) { registry.register_unit_element_wise("tract_core_round_even", &ops::math::RoundHalfToEven {}); @@ -67,4 +68,5 @@ pub fn register(registry: &mut Registry) { range::register(registry); topk::register(registry); trilu::register(registry); + vptq::register(registry); } diff --git a/nnef/src/ops/core/vptq.rs b/nnef/src/ops/core/vptq.rs new file mode 100644 index 0000000000..118a0885b3 --- /dev/null +++ b/nnef/src/ops/core/vptq.rs @@ -0,0 +1,108 @@ +use crate::internal::*; +use crate::ser::*; +use tract_core::ops::vptq::VPTQGemm; + +pub fn register(registry: &mut Registry) { + registry.register_dumper(ser_vptq_gemm); + registry.register_primitive( + "tract_core_vptq_gemm", + &[ + TypeName::Scalar.tensor().named("input"), + TypeName::Scalar.tensor().named("indices"), + TypeName::Scalar.tensor().named("centroids"), + TypeName::Scalar.tensor().named("outlier_indices"), + TypeName::Scalar.tensor().named("outlier_centroids"), + TypeName::Scalar.tensor().named("perm"), + TypeName::Scalar.tensor().named("weight_scale"), + TypeName::Scalar.tensor().named("weight_bias"), + TypeName::Scalar.tensor().named("bias"), + TypeName::Integer.named("vector_len"), + TypeName::Integer.tensor().named("in_features"), + TypeName::Integer.tensor().named("out_features"), + TypeName::Integer.tensor().named("group_size"), + TypeName::Integer.tensor().named("outlier_size"), + ], + &[("output", 
TypeName::Scalar.tensor())], + de_vptq_gemm, + ); +} + +fn ser_vptq_gemm( + ast: &mut IntoAst, + node: &TypedNode, + op: &VPTQGemm, +) -> TractResult>> { + let input = ast.mapping[&node.inputs[0]].clone(); + let indices = ast.mapping[&node.inputs[1]].clone(); + let centroids = ast.mapping[&node.inputs[2]].clone(); + let outlier_indices = ast.mapping[&node.inputs[3]].clone(); + let outlier_centroids = ast.mapping[&node.inputs[4]].clone(); + let perm = ast.mapping[&node.inputs[5]].clone(); + let weight_scale = ast.mapping[&node.inputs[6]].clone(); + let weight_bias = ast.mapping[&node.inputs[7]].clone(); + let bias = ast.mapping[&node.inputs[8]].clone(); + + Ok(Some(invocation( + "tract_core_vptq_gemm", + &[ + input, + indices, + centroids, + outlier_indices, + outlier_centroids, + perm, + weight_scale, + weight_bias, + bias, + ], + &[ + ("vector_len", numeric(op.vector_len)), + ("in_features", numeric(op.in_features)), + ("out_features", numeric(op.out_features)), + ("group_size", numeric(op.group_size)), + ("outlier_size", numeric(op.outlier_size)), + ], + ))) +} + +fn de_vptq_gemm(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult { + let input = invocation.named_arg_as(builder, "input")?; + let indices = invocation.named_arg_as(builder, "indices")?; + let centroids = invocation.named_arg_as(builder, "centroids")?; + let outlier_indices = invocation.named_arg_as(builder, "outlier_indices")?; + let outlier_centroids = invocation.named_arg_as(builder, "outlier_centroids")?; + let perm = invocation.named_arg_as(builder, "perm")?; + let weight_scale = invocation.named_arg_as(builder, "weight_scale")?; + let weight_bias = invocation.named_arg_as(builder, "weight_bias")?; + let bias = invocation.named_arg_as(builder, "bias")?; + + let vector_len = invocation.named_arg_as(builder, "vector_len")?; + let in_features = invocation.named_arg_as(builder, "in_features")?; + let out_features = invocation.named_arg_as(builder, "out_features")?; + let 
is_indice_packed = invocation.named_arg_as(builder, "is_indice_packed")?; + + let group_size = invocation.named_arg_as(builder, "group_size")?; + let outlier_size = invocation.named_arg_as(builder, "outlier_size")?; + + builder.wire( + VPTQGemm { + vector_len, + in_features, + out_features, + is_indice_packed, + group_size, + outlier_size, + }, + &[ + input, + indices, + centroids, + outlier_indices, + outlier_centroids, + perm, + weight_scale, + weight_bias, + bias, + ], + ) +}