diff --git a/diskann-benchmark/src/backend/exhaustive/product.rs b/diskann-benchmark/src/backend/exhaustive/product.rs
index 43f55a8ff..be3cdf50c 100644
--- a/diskann-benchmark/src/backend/exhaustive/product.rs
+++ b/diskann-benchmark/src/backend/exhaustive/product.rs
@@ -85,10 +85,10 @@ mod imp {
             5,
         );
 
-        let offsets = diskann_providers::model::pq::calculate_chunk_offsets_auto(
+        let offsets = diskann_quantization::views::ChunkOffsets::from_dimensions(
             data.ncols(),
             input.num_pq_chunks.get(),
-        );
+        )?;
 
         let base = {
             let threadpool = rayon::ThreadPoolBuilder::new()
@@ -97,7 +97,7 @@ mod imp {
             threadpool.install(|| -> anyhow::Result<_> {
                 Ok(parameters.train(
                     data.as_view(),
-                    diskann_quantization::views::ChunkOffsetsView::new(offsets.as_slice())?,
+                    offsets.as_view(),
                     diskann_quantization::Parallelism::Rayon,
                     &diskann_quantization::random::StdRngBuilder::new(input.seed),
                     &diskann_quantization::cancel::DontCancel,
@@ -109,7 +109,7 @@ mod imp {
             data.ncols(),
             base.flatten().into(),
             vec![0.0; data.ncols()].into(),
-            offsets.into(),
+            offsets.as_slice().into(),
         )?;
 
         let training_time: MicroSeconds = start.elapsed().into();
diff --git a/diskann-disk/src/storage/quant/pq/pq_generation.rs b/diskann-disk/src/storage/quant/pq/pq_generation.rs
index ccd2c30a7..4b32fab7a 100644
--- a/diskann-disk/src/storage/quant/pq/pq_generation.rs
+++ b/diskann-disk/src/storage/quant/pq/pq_generation.rs
@@ -8,10 +8,7 @@ use std::marker::PhantomData;
 use diskann::{utils::VectorRepr, ANNError};
 use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
 use diskann_providers::{
-    model::{
-        pq::{accum_row_inplace, generate_pq_pivots},
-        GeneratePivotArguments,
-    },
+    model::{pq::generate_pq_pivots, GeneratePivotArguments},
     storage::PQStorage,
     utils::{BridgeErr, RayonThreadPoolRef, Timer},
 };
@@ -136,7 +133,9 @@ where
         )
         .bridge_err()?;
 
-        accum_row_inplace(full_pivot_data_mat.as_mut_view(), centroid.as_slice());
+        full_pivot_data_mat
+            .broadcast_rows_mut(centroid.as_slice(), |a, b| *a += *b)
+            .bridge_err()?;
 
         let table = TransposedTable::from_parts(
            full_pivot_data_mat.as_view(),
diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/quant_vector_provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/quant_vector_provider.rs
index 0d26fa680..edcb593d8 100644
--- a/diskann-providers/src/model/graph/provider/async_/bf_tree/quant_vector_provider.rs
+++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/quant_vector_provider.rs
@@ -353,7 +353,7 @@ mod tests {
         let c = provider.query_computer(&[-0.5, -0.5]).unwrap();
         let expected: f32 = 1.5 * 1.5 * 2.0;
         assert_eq!(
-            c.evaluate_similarity(&provider.get_vector_sync(3).unwrap()),
+            c.evaluate_similarity(provider.get_vector_sync(3).unwrap().as_slice()),
             expected
         );
 
@@ -361,15 +361,15 @@ mod tests {
         let d = provider.distance_computer();
         assert_eq!(
             d.evaluate_similarity(
-                &provider.get_vector_sync(0).unwrap(),
-                &provider.get_vector_sync(3).unwrap()
+                provider.get_vector_sync(0).unwrap().as_slice(),
+                provider.get_vector_sync(3).unwrap().as_slice()
             ),
             2.0
         );
 
         let slice: &[f32] = &[-0.5, -0.5];
         assert_eq!(
-            d.evaluate_similarity(slice, &provider.get_vector_sync(3).unwrap()),
+            d.evaluate_similarity(slice, provider.get_vector_sync(3).unwrap().as_slice()),
             expected,
         );
     }
diff --git a/diskann-providers/src/model/graph/provider/async_/fast_memory_quant_vector_provider.rs b/diskann-providers/src/model/graph/provider/async_/fast_memory_quant_vector_provider.rs
index 62d9c05e1..9da2296ac 100644
--- a/diskann-providers/src/model/graph/provider/async_/fast_memory_quant_vector_provider.rs
+++ b/diskann-providers/src/model/graph/provider/async_/fast_memory_quant_vector_provider.rs
@@ -444,21 +444,18 @@ mod tests {
         // Query Computer.
         let c = provider.query_computer(&[-0.5, -0.5]).unwrap();
         let expected: f32 = 1.5 * 1.5 * 2.0;
-        assert_eq!(
-            c.evaluate_similarity(&provider.get_vector_sync(3)),
-            expected
-        );
+        assert_eq!(c.evaluate_similarity(provider.get_vector_sync(3)), expected);
 
         // Distance Computer.
         let d = provider.distance_computer();
         assert_eq!(
-            d.evaluate_similarity(&provider.get_vector_sync(0), &provider.get_vector_sync(3)),
+            d.evaluate_similarity(provider.get_vector_sync(0), provider.get_vector_sync(3)),
             2.0
         );
 
         let slice: &[f32] = &[-0.5, -0.5];
         assert_eq!(
-            d.evaluate_similarity(slice, &provider.get_vector_sync(3)),
+            d.evaluate_similarity(slice, provider.get_vector_sync(3)),
             expected,
         );
     }
diff --git a/diskann-providers/src/model/mod.rs b/diskann-providers/src/model/mod.rs
index f6ae2be75..61addc05a 100644
--- a/diskann-providers/src/model/mod.rs
+++ b/diskann-providers/src/model/mod.rs
@@ -11,10 +11,10 @@ pub use configuration::IndexConfiguration;
 pub mod pq;
 pub use pq::{
     FixedChunkPQTable, GeneratePivotArguments, MAX_PQ_TRAINING_SET_SIZE, NUM_KMEANS_REPS_PQ,
-    NUM_PQ_CENTROIDS, accum_row_inplace, calculate_chunk_offsets_auto, compute_pq_distance,
-    compute_pq_distance_for_pq_coordinates, direct_distance_impl, distance,
-    generate_pq_data_from_pivots_from_membuf, generate_pq_data_from_pivots_from_membuf_batch,
-    generate_pq_pivots, generate_pq_pivots_from_membuf,
+    NUM_PQ_CENTROIDS, compute_pq_distance, compute_pq_distance_for_pq_coordinates,
+    direct_distance_impl, distance, generate_pq_data_from_pivots_from_membuf,
+    generate_pq_data_from_pivots_from_membuf_batch, generate_pq_pivots,
+    generate_pq_pivots_from_membuf,
 };
 
 pub mod statistics;
diff --git a/diskann-providers/src/model/pq/distance/dynamic.rs b/diskann-providers/src/model/pq/distance/dynamic.rs
index cefacacde..352912846 100644
--- a/diskann-providers/src/model/pq/distance/dynamic.rs
+++ b/diskann-providers/src/model/pq/distance/dynamic.rs
@@ -101,25 +101,6 @@ where
     }
 }
 
-impl PreprocessedDistanceFunction<&Vec, f32> for QueryComputer
-where
-    T: Deref,
-{
-    fn evaluate_similarity(&self, changing: &Vec) -> f32 {
-        self.evaluate_similarity(changing.as_slice())
-    }
-}
-
-impl PreprocessedDistanceFunction<&&[u8], f32> for QueryComputer
-where
-    T: Deref,
-{
-    fn evaluate_similarity(&self, changing: &&[u8]) -> f32 {
-        let changing: &[u8] = changing;
-        self.evaluate_similarity(changing)
-    }
-}
-
 /// Pre-dispatched distance functions for the `FixedChunkPQTable`.
 #[derive(Debug)]
 pub struct VTable {
@@ -233,52 +214,6 @@ where
     }
 }
 
-/// Perform a comparison between a full-precision vector and quantized vector.
-impl DistanceFunction<&[f32], &&[u8], f32> for DistanceComputer
-where
-    T: Deref,
-{
-    #[inline(always)]
-    fn evaluate_similarity(&self, fp: &[f32], q: &&[u8]) -> f32 {
-        let q: &[u8] = q;
-        self.evaluate_similarity(fp, q)
-    }
-}
-
-impl DistanceFunction<&[f32], &Vec, f32> for DistanceComputer
-where
-    T: Deref,
-{
-    #[inline(always)]
-    fn evaluate_similarity(&self, fp: &[f32], q: &Vec) -> f32 {
-        self.evaluate_similarity(fp, q.as_slice())
-    }
-}
-
-/// Perform a comparison between two quantized vectors.
-impl DistanceFunction<&&[u8], &&[u8], f32> for DistanceComputer
-where
-    T: Deref,
-{
-    #[inline(always)]
-    fn evaluate_similarity(&self, q0: &&[u8], q1: &&[u8]) -> f32 {
-        let q0: &[u8] = q0;
-        let q1: &[u8] = q1;
-        self.evaluate_similarity(q0, q1)
-    }
-}
-
-/// Perform a comparison between two quantized vectors.
-impl DistanceFunction<&Vec, &Vec, f32> for DistanceComputer
-where
-    T: Deref,
-{
-    #[inline(always)]
-    fn evaluate_similarity(&self, q0: &Vec, q1: &Vec) -> f32 {
-        self.evaluate_similarity(q0.as_slice(), q1.as_slice())
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use std::marker::PhantomData;
diff --git a/diskann-providers/src/model/pq/distance/test_utils.rs b/diskann-providers/src/model/pq/distance/test_utils.rs
index 76c240e1f..6b6c90d60 100644
--- a/diskann-providers/src/model/pq/distance/test_utils.rs
+++ b/diskann-providers/src/model/pq/distance/test_utils.rs
@@ -13,7 +13,8 @@ use diskann_vector::{
 use rand::{Rng, distr::Distribution};
 use rand_distr::{Normal, Uniform};
 
-use crate::model::{FixedChunkPQTable, pq::calculate_chunk_offsets_auto};
+use crate::model::FixedChunkPQTable;
+use diskann_quantization::views::ChunkOffsets;
 
 /// We need a way to generate random queries.
 ///
@@ -130,7 +131,8 @@ pub(crate) fn generate_expected_vector(
 /// * N + 1: The number of PQ Pivots
 pub(crate) fn seed_pivot_table(config: TableConfig) -> FixedChunkPQTable {
     // Get the chunk offsets for the selected dimension and bytes.
-    let offsets = calculate_chunk_offsets_auto(config.dim, config.pq_chunks);
+    let chunk_offsets = ChunkOffsets::from_dimensions(config.dim, config.pq_chunks).unwrap();
+    let offsets = chunk_offsets.as_slice();
 
     // Create the pivot table following the schema described in the docstring.
     let mut pivots = Vec::::new();
diff --git a/diskann-providers/src/model/pq/mod.rs b/diskann-providers/src/model/pq/mod.rs
index 6338e39ec..ba6a49d25 100644
--- a/diskann-providers/src/model/pq/mod.rs
+++ b/diskann-providers/src/model/pq/mod.rs
@@ -10,11 +10,9 @@ pub use fixed_chunk_pq_table::{
 
 mod pq_construction;
 pub use pq_construction::{
-    MAX_PQ_TRAINING_SET_SIZE, NUM_KMEANS_REPS_PQ, NUM_PQ_CENTROIDS, accum_row_inplace,
-    calculate_chunk_offsets, calculate_chunk_offsets_auto, generate_pq_data_from_pivots,
+    MAX_PQ_TRAINING_SET_SIZE, NUM_KMEANS_REPS_PQ, NUM_PQ_CENTROIDS, generate_pq_data_from_pivots,
     generate_pq_data_from_pivots_from_membuf, generate_pq_data_from_pivots_from_membuf_batch,
-    generate_pq_pivots, generate_pq_pivots_from_membuf, get_chunk_from_training_data,
-    move_train_data_by_centroid,
+    generate_pq_pivots, generate_pq_pivots_from_membuf, move_train_data_by_centroid,
 };
 
 /// all metadata of individual sub-component files is written in first 4KB for unified files
diff --git a/diskann-providers/src/model/pq/pq_construction.rs b/diskann-providers/src/model/pq/pq_construction.rs
index 2862d7e26..0f65b1d44 100644
--- a/diskann-providers/src/model/pq/pq_construction.rs
+++ b/diskann-providers/src/model/pq/pq_construction.rs
@@ -19,6 +19,7 @@ use diskann::{
 use diskann_quantization::{
     CompressInto,
     product::{BasicTableView, TransposedTable, train::TrainQuantizer},
+    views::{ChunkOffsets, ChunkOffsetsView},
 };
 use diskann_utils::{
     io::Metadata,
@@ -94,12 +95,8 @@ where
         );
     }
 
-    let mut chunk_offsets: Vec = vec![0; parameters.num_pq_chunks() + 1];
-    calculate_chunk_offsets(
-        parameters.dim(),
-        parameters.num_pq_chunks(),
-        &mut chunk_offsets,
-    );
+    let chunk_offsets =
+        ChunkOffsets::from_dimensions(parameters.dim(), parameters.num_pq_chunks()).bridge_err()?;
 
     let trainer =
         diskann_quantization::product::train::LightPQTrainingParameters::new(
             parameters.num_centers(),
@@ -111,8 +108,7 @@ where
         .train(
             MatrixView::try_from(train_data, parameters.num_train(), parameters.dim())
                 .bridge_err()?,
-            diskann_quantization::views::ChunkOffsetsView::new(chunk_offsets.as_slice())
-                .bridge_err()?,
+            chunk_offsets.as_view(),
             diskann_quantization::Parallelism::Rayon,
             &random_provider,
             &diskann_quantization::cancel::DontCancel,
@@ -125,7 +121,7 @@ where
     pq_storage.write_pivot_data(
         &full_pivot_data,
         &centroid,
-        &chunk_offsets,
+        chunk_offsets.as_slice(),
         parameters.num_centers(),
         parameters.dim(),
         storage_provider,
@@ -202,8 +198,13 @@ pub fn generate_pq_pivots_from_membuf>(
         }
     }
 
-    // Calculate the chunk offsets
-    calculate_chunk_offsets(parameters.dim(), parameters.num_pq_chunks(), offsets);
+    // Calculate the chunk offsets, filling the caller-owned buffer.
+    let chunk_offsets_view = ChunkOffsetsView::from_dimensions_into(
+        parameters.dim(),
+        parameters.num_pq_chunks(),
+        offsets,
+    )
+    .bridge_err()?;
 
     let trainer = diskann_quantization::product::train::LightPQTrainingParameters::new(
         parameters.num_centers(),
@@ -235,7 +236,7 @@ pub fn generate_pq_pivots_from_membuf>(
             parameters.dim(),
         )
         .bridge_err()?,
-        diskann_quantization::views::ChunkOffsetsView::new(offsets).bridge_err()?,
+        chunk_offsets_view,
         diskann_quantization::Parallelism::Rayon,
         &rng_builder,
         &cancelation,
@@ -249,35 +250,6 @@ pub fn generate_pq_pivots_from_membuf>(
     Ok(())
 }
 
-/// Gets all instances of a chunk from the training data for all records in the training data. Each vector in the
-/// training dataset is divided into chunks and the PQ algorithm handles each vector chunk individually. This method
-/// gets the same chunk from each vector in the training data and creates a new vector out of all of them.
-///
-/// # Example
-/// See tests for examples
-#[inline]
-pub fn get_chunk_from_training_data(
-    train_data: &[f32],
-    num_train: usize,
-    raw_vector_dim: usize,
-    chunk_size: usize,
-    chunk_offset: usize,
-) -> Vec {
-    let mut result: Vec = vec![0.0; num_train * chunk_size];
-
-    result
-        // group empty result data into chunks of chunk_size
-        .chunks_mut(chunk_size)
-        .enumerate()
-        // for each chunk, copy the chunk from the training data into the result vector
-        .for_each(|(chunk_number, result_chunk)| {
-            let train_data_start = chunk_number * raw_vector_dim + chunk_offset;
-            let train_data_end = train_data_start + chunk_size;
-            result_chunk.copy_from_slice(&train_data[train_data_start..train_data_end]);
-        });
-    result
-}
-
 /// Calculates the centroid if needed and moves the train_data to to the centroid
 /// # Arguments
 /// * `train_data` Dataset
@@ -324,52 +296,6 @@ pub fn move_train_data_by_centroid(
     }
 }
 
-/// Calculate the number of chunks for the product quantization algorithm. Returns a vector of offsets where
-/// each offset corresponds to a chunk based on the index of the chunk in the vector.
-///
-/// # Arguments
-/// * `dimensions` Number of dimensions of the input data
-/// * `num_pq_chunks` - Number of chunks that will be used in the PQ calculation. Each vector will be split into these
-/// number of chunks and each chunk will be compressed down to one byte.
-/// * `offsets` - An output vector of offsets, where the size is equal to the number of pq chunks + 1.
-#[inline]
-pub fn calculate_chunk_offsets(dimensions: usize, num_pq_chunks: usize, offsets: &mut [usize]) {
-    // Calculate each chunk's offset
-    // If we have 8 dimension and 3 chunks then offsets would be [0,3,6,8]
-    let mut chunk_offset: usize = 0;
-    offsets[0] = chunk_offset;
-    for chunk_index in 0..num_pq_chunks {
-        chunk_offset += dimensions / num_pq_chunks;
-        if chunk_index < (dimensions % num_pq_chunks) {
-            chunk_offset += 1;
-        }
-        offsets[chunk_index + 1] = chunk_offset;
-    }
-}
-
-pub fn calculate_chunk_offsets_auto(dimensions: usize, num_pq_chunks: usize) -> Vec {
-    let mut offsets = vec![0; num_pq_chunks + 1];
-    calculate_chunk_offsets(dimensions, num_pq_chunks, offsets.as_mut_slice());
-    offsets
-}
-
-/// Add the row `y` to every row in `x`.
-///
-/// # Panics
-///
-/// Panics if `y.len() != x.ncols()`.
-pub fn accum_row_inplace(mut x: MutMatrixView, y: &[T])
-where
-    T: Copy + std::ops::AddAssign,
-{
-    assert_eq!(x.ncols(), y.len());
-    x.row_iter_mut().for_each(|row| {
-        std::iter::zip(row.iter_mut(), y.iter()).for_each(|(a, b)| {
-            *a += *b;
-        });
-    });
-}
-
 /// streams the base file (data_file), and computes the closest centers in each
 /// chunk to generate the compressed data_file and stores it in
 /// pq_compressed_vectors_path.
@@ -429,7 +355,9 @@ where
     let mut full_pivot_data_mat =
         MutMatrixView::try_from(full_pivot_data.as_mut_slice(), num_centers, full_dim)
             .bridge_err()?;
-    accum_row_inplace(full_pivot_data_mat.as_mut_view(), centroid.as_slice());
+    full_pivot_data_mat
+        .broadcast_rows_mut(centroid.as_slice(), |a, b| *a += *b)
+        .bridge_err()?;
 
     pq_storage.write_compressed_pivot_metadata::(
         num_points,
@@ -672,6 +600,29 @@ mod pq_test {
         utils::{ParallelIteratorInPool, create_thread_pool_for_test, read_bin_from},
     };
 
+    /// Test helper: Gets all instances of a chunk from the training data for all records
+    /// in the training data. Each vector in the training dataset is divided into chunks
+    /// and the PQ algorithm handles each vector chunk individually. This helper gets the
+    /// same chunk from each vector in the training data and returns it as a flat vector.
+    fn get_chunk_from_training_data(
+        train_data: &[f32],
+        num_train: usize,
+        raw_vector_dim: usize,
+        chunk_size: usize,
+        chunk_offset: usize,
+    ) -> Vec {
+        let mut result: Vec = vec![0.0; num_train * chunk_size];
+        result
+            .chunks_mut(chunk_size)
+            .enumerate()
+            .for_each(|(chunk_number, result_chunk)| {
+                let train_data_start = chunk_number * raw_vector_dim + chunk_offset;
+                let train_data_end = train_data_start + chunk_size;
+                result_chunk.copy_from_slice(&train_data[train_data_start..train_data_end]);
+            });
+        result
+    }
+
     #[test]
     fn test_move_train_data_by_centroid() {
         let dim = 20;
@@ -1077,9 +1028,8 @@ mod pq_test {
         // Pre-emptively construct an offset view to compare mismatched slices.
         // We want to check that the difference in the mismatched chunks is small.
-        let mut offsets = vec![0; num_pq_chunks + 1];
-        calculate_chunk_offsets(train_dim, num_pq_chunks, &mut offsets);
-        let offset_view = diskann_quantization::views::ChunkOffsetsView::new(&offsets).unwrap();
+        let chunk_offsets = ChunkOffsets::from_dimensions(train_dim, num_pq_chunks).unwrap();
+        let offset_view = chunk_offsets.as_view();
         let full_data =
             MatrixView::try_from(full_data_vector.as_slice(), num_train, train_dim).unwrap();
         let pivot_view =
diff --git a/diskann-providers/src/model/pq/views.rs b/diskann-providers/src/model/pq/views.rs
index 3329c31d2..0d0782562 100644
--- a/diskann-providers/src/model/pq/views.rs
+++ b/diskann-providers/src/model/pq/views.rs
@@ -32,6 +32,14 @@ impl From>> for ANNError {
     }
 }
 
+// Compatibility with ANNError.
+impl From> for ANNError {
+    #[track_caller]
+    fn from(value: Bridge) -> Self {
+        ANNError::log_pq_error(value.into_inner())
+    }
+}
+
 ///////////
 // Tests //
 ///////////
diff --git a/diskann-quantization/src/views.rs b/diskann-quantization/src/views.rs
index 6ef928345..cc8ec1229 100644
--- a/diskann-quantization/src/views.rs
+++ b/diskann-quantization/src/views.rs
@@ -54,6 +54,8 @@ pub enum ChunkOffsetError {
         start: usize,
         next_val: usize,
     },
+    #[error("scratch buffer length {actual} does not match expected length {expected}")]
+    ScratchLengthMismatch { expected: usize, actual: usize },
 }
 
 impl ChunkOffsetsBase
@@ -205,6 +207,63 @@ impl<'a> From> for &'a [usize] {
     }
 }
 
+impl ChunkOffsets {
+    /// Build a chunk-offset plan that partitions `dimensions` into `num_pq_chunks`
+    /// near-equal chunks. The first `dimensions % num_pq_chunks` chunks are one element
+    /// larger than the rest.
+    ///
+    /// Returns an error if the requested partition is not valid (e.g. `dimensions == 0`,
+    /// `num_pq_chunks == 0`, or `num_pq_chunks > dimensions`).
+    pub fn from_dimensions(
+        dimensions: usize,
+        num_pq_chunks: usize,
+    ) -> Result {
+        let mut offsets = vec![0usize; num_pq_chunks + 1].into_boxed_slice();
+        fill_chunk_offsets(dimensions, num_pq_chunks, &mut offsets);
+        Self::new(offsets)
+    }
+}
+
+impl<'a> ChunkOffsetsView<'a> {
+    /// Fill the caller-owned `scratch` buffer with the partition for `(dimensions,
+    /// num_pq_chunks)` and return a validated view borrowing it.
+    ///
+    /// See [`ChunkOffsets::from_dimensions`] for the partitioning rule.
+    ///
+    /// Returns an error if `scratch.len() != num_pq_chunks + 1` or if the requested
+    /// partition is not valid (e.g. `dimensions == 0`, `num_pq_chunks == 0`, or
+    /// `num_pq_chunks > dimensions`).
+    pub fn from_dimensions_into(
+        dimensions: usize,
+        num_pq_chunks: usize,
+        scratch: &'a mut [usize],
+    ) -> Result {
+        let expected = num_pq_chunks + 1;
+        if scratch.len() != expected {
+            return Err(ChunkOffsetError::ScratchLengthMismatch {
+                expected,
+                actual: scratch.len(),
+            });
+        }
+        fill_chunk_offsets(dimensions, num_pq_chunks, scratch);
+        Self::new(scratch)
+    }
+}
+
+/// Internal helper: fill `offsets` (of length `num_pq_chunks + 1`) with the prefix-sum
+/// partitioning of `dimensions` into `num_pq_chunks` chunks.
+fn fill_chunk_offsets(dimensions: usize, num_pq_chunks: usize, offsets: &mut [usize]) {
+    let mut chunk_offset: usize = 0;
+    offsets[0] = chunk_offset;
+    for chunk_index in 0..num_pq_chunks {
+        chunk_offset += dimensions / num_pq_chunks;
+        if chunk_index < (dimensions % num_pq_chunks) {
+            chunk_offset += 1;
+        }
+        offsets[chunk_index + 1] = chunk_offset;
+    }
+}
+
 ///////////////
 // ChunkView //
 ///////////////
@@ -425,6 +484,84 @@ mod tests {
         );
     }
 
+    //////////////////////////////
+    // from_dimensions builders //
+    //////////////////////////////
+
+    #[test]
+    fn from_dimensions_happy_path() {
+        // Even split: 9 / 3 = 3 each.
+        let offsets = ChunkOffsets::from_dimensions(9, 3).unwrap();
+        assert_eq!(offsets.as_slice(), &[0, 3, 6, 9]);
+        assert_eq!(offsets.dim(), 9);
+        assert_eq!(offsets.len(), 3);
+
+        // Uneven split: 8 / 3 = 2 r 2 -> first two chunks get an extra element.
+        let offsets = ChunkOffsets::from_dimensions(8, 3).unwrap();
+        assert_eq!(offsets.as_slice(), &[0, 3, 6, 8]);
+
+        // Single chunk degenerate case.
+        let offsets = ChunkOffsets::from_dimensions(5, 1).unwrap();
+        assert_eq!(offsets.as_slice(), &[0, 5]);
+
+        // dimensions == num_pq_chunks: each chunk is size 1.
+        let offsets = ChunkOffsets::from_dimensions(4, 4).unwrap();
+        assert_eq!(offsets.as_slice(), &[0, 1, 2, 3, 4]);
+
+        // The view-into variant matches the owning constructor.
+        let mut scratch = [0usize; 4];
+        let view = ChunkOffsetsView::from_dimensions_into(8, 3, &mut scratch).unwrap();
+        assert_eq!(view.as_slice(), &[0, 3, 6, 8]);
+        assert_eq!(view.dim(), 8);
+        assert_eq!(view.len(), 3);
+        assert_eq!(scratch.as_slice(), &[0, 3, 6, 8]);
+    }
+
+    #[test]
+    fn from_dimensions_construction_errors() {
+        // num_pq_chunks > dimensions -> some chunk would be empty -> NonMonotonic.
+        let err = ChunkOffsets::from_dimensions(3, 5).unwrap_err();
+        assert!(
+            matches!(err, ChunkOffsetError::NonMonotonic { .. }),
+            "expected NonMonotonic, got {err:?}"
+        );
+
+        // dimensions == 0 -> trivially non-monotonic.
+        let err = ChunkOffsets::from_dimensions(0, 1).unwrap_err();
+        assert!(matches!(err, ChunkOffsetError::NonMonotonic { .. }));
+
+        // num_pq_chunks == 0 -> length 1 buffer, fails LengthNotAtLeastTwo.
+        let err = ChunkOffsets::from_dimensions(8, 0).unwrap_err();
+        assert!(matches!(err, ChunkOffsetError::LengthNotAtLeastTwo(1)));
+
+        // Scratch length too short -> ScratchLengthMismatch error (no panic).
+        let mut too_short = [0usize; 3];
+        let err = ChunkOffsetsView::from_dimensions_into(8, 3, &mut too_short).unwrap_err();
+        assert!(matches!(
+            err,
+            ChunkOffsetError::ScratchLengthMismatch {
+                expected: 4,
+                actual: 3
+            }
+        ));
+
+        // Scratch length too long -> ScratchLengthMismatch error.
+        let mut too_long = [0usize; 5];
+        let err = ChunkOffsetsView::from_dimensions_into(8, 3, &mut too_long).unwrap_err();
+        assert!(matches!(
+            err,
+            ChunkOffsetError::ScratchLengthMismatch {
+                expected: 4,
+                actual: 5
+            }
+        ));
+
+        // Partition validation errors propagate through the view builder too.
+        let mut scratch = [0usize; 6];
+        let err = ChunkOffsetsView::from_dimensions_into(3, 5, &mut scratch).unwrap_err();
+        assert!(matches!(err, ChunkOffsetError::NonMonotonic { .. }));
+    }
+
     ///////////////
     // ChunkView //
     ///////////////
diff --git a/diskann-utils/src/views.rs b/diskann-utils/src/views.rs
index a9352918c..81c9da470 100644
--- a/diskann-utils/src/views.rs
+++ b/diskann-utils/src/views.rs
@@ -161,6 +161,15 @@ impl TryFromError {
     }
 }
 
+/// Error returned when a broadcast vector's length does not match the matrix's column
+/// count.
+#[derive(Debug, Error)]
+#[error("broadcast vector length {broadcast_len} does not match matrix ncols {ncols}")]
+pub struct BroadcastLenMismatch {
+    pub ncols: usize,
+    pub broadcast_len: usize,
+}
+
 /// A generator for initializing the entries in a matrix via `Matrix::new`.
 pub trait Generator {
     fn generate(&mut self) -> T;
@@ -388,6 +397,32 @@ where
         self.data.as_mut_slice().chunks_exact_mut(ncols)
     }
 
+    /// Apply `op(row[j], &broadcast[j])` to every entry of every row, broadcasting
+    /// `broadcast` across the rows of the matrix through an `op`.
+    ///
+    /// Returns an error if `broadcast.len() != self.ncols()`. The matrix is left
+    /// unmodified in the error case.
+    pub fn broadcast_rows_mut(
+        &mut self,
+        broadcast: &[U],
+        mut op: F,
+    ) -> Result<(), BroadcastLenMismatch>
+    where
+        T: MutDenseData,
+        F: FnMut(&mut T::Elem, &U),
+    {
+        if self.ncols() != broadcast.len() {
+            return Err(BroadcastLenMismatch {
+                ncols: self.ncols(),
+                broadcast_len: broadcast.len(),
+            });
+        }
+        self.row_iter_mut().for_each(|row| {
+            std::iter::zip(row.iter_mut(), broadcast.iter()).for_each(|(a, b)| op(a, b));
+        });
+        Ok(())
+    }
+
     /// Return an iterator that divides the matrix into sub-matrices with (up to)
     /// `batchsize` rows with `self.ncols()` columns.
     ///
@@ -2180,4 +2215,71 @@ mod tests {
         // Verify all elements are 43
         assert!(m.as_slice().iter().all(|&x| x == 43));
     }
+
+    #[test]
+    fn test_broadcast_rows_mut() {
+        // Use the canonical test fixture and broadcast-add a length-3 row across all
+        // rows. The resulting rows are the original rows offset by [10, 20, 30].
+        let data = make_test_matrix();
+        let mut m = Matrix::try_from(data.into(), 4, 3).unwrap();
+        let bias = [10usize, 20, 30];
+
+        m.broadcast_rows_mut(&bias, |a, b| *a += *b).unwrap();
+        assert_eq!(m.row(0), &[10, 21, 32]);
+        assert_eq!(m.row(1), &[11, 22, 33]);
+        assert_eq!(m.row(2), &[12, 23, 34]);
+        assert_eq!(m.row(3), &[13, 24, 35]);
+
+        // Subtract the same bias to confirm op generality and that we round-trip back to
+        // the canonical fixture.
+        m.broadcast_rows_mut(&bias, |a, b| *a -= *b).unwrap();
+        test_basic_indexing(&m);
+
+        // Broadcast also works through a `MutMatrixView`.
+        let mut data = make_test_matrix();
+        let mut view = MutMatrixView::try_from(data.as_mut_slice(), 4, 3).unwrap();
+        view.broadcast_rows_mut(&bias, |a, b| *a += *b).unwrap();
+        assert_eq!(view.row(0), &[10, 21, 32]);
+        assert_eq!(view.row(3), &[13, 24, 35]);
+
+        // Heterogeneous element types: matrix is f32, broadcast is i32.
+        let mut fmat = Matrix::::new(0.0, 2, 3);
+        let scale: [i32; 3] = [1, 2, 3];
+        fmat.broadcast_rows_mut(&scale, |a, b| *a += *b as f32)
+            .unwrap();
+        assert_eq!(fmat.row(0), &[1.0, 2.0, 3.0]);
+        assert_eq!(fmat.row(1), &[1.0, 2.0, 3.0]);
+
+        // Single-row degenerate case.
+        let mut single = Matrix::::new(5, 1, 4);
+        single
+            .broadcast_rows_mut(&[1, 2, 3, 4], |a, b| *a += *b)
+            .unwrap();
+        assert_eq!(single.row(0), &[6, 7, 8, 9]);
+    }
+
+    #[test]
+    fn test_broadcast_rows_mut_length_mismatch() {
+        let data = make_test_matrix();
+        let mut m = Matrix::try_from(data.clone().into(), 4, 3).unwrap();
+
+        // Too short.
+        let err = m.broadcast_rows_mut(&[1, 2], |a, b| *a += *b).unwrap_err();
+        assert_eq!(
+            err.to_string(),
+            "broadcast vector length 2 does not match matrix ncols 3"
+        );
+        // Matrix must be unmodified after a failed call.
+        assert_eq!(m.as_slice(), data.as_slice());
+
+        // Too long.
+        let err = m
+            .broadcast_rows_mut(&[1, 2, 3, 4], |a, b| *a += *b)
+            .unwrap_err();
+        assert_eq!(
+            err.to_string(),
+            "broadcast vector length 4 does not match matrix ncols 3"
+        );
+        assert_eq!(m.as_slice(), data.as_slice());
+    }
 }
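
For readers migrating their own call sites, the sketch below (not part of the patch) shows the two new APIs side by side: ChunkOffsets::from_dimensions replacing calculate_chunk_offsets_auto, and broadcast_rows_mut replacing accum_row_inplace. The function name, the inputs (dim, num_pq_chunks, num_centers, pivot_data, centroid), and the exact import paths are illustrative assumptions; error handling is collapsed to unwrap for brevity.

// Illustrative only: a hypothetical call site, not code from this patch.
use diskann_quantization::views::ChunkOffsets;
use diskann_utils::views::MutMatrixView;

fn add_centroid_to_pivots(
    dim: usize,
    num_pq_chunks: usize,
    num_centers: usize,
    pivot_data: &mut [f32],
    centroid: &[f32],
) {
    // Previously: calculate_chunk_offsets_auto(dim, num_pq_chunks) returned a raw
    // Vec<usize> with no validation. The new constructor validates the partition.
    let offsets = ChunkOffsets::from_dimensions(dim, num_pq_chunks).unwrap();
    assert_eq!(offsets.len(), num_pq_chunks);

    // Previously: accum_row_inplace(mat.as_mut_view(), centroid) panicked on a length
    // mismatch. broadcast_rows_mut reports the mismatch as a typed error instead.
    let mut mat = MutMatrixView::try_from(pivot_data, num_centers, dim).unwrap();
    mat.broadcast_rows_mut(centroid, |a, b| *a += *b).unwrap();

    // A plain &[usize] is still available where older interfaces expect one.
    let _offsets_slice: &[usize] = offsets.as_slice();
}

Call sites that own their scratch buffer (as in generate_pq_pivots_from_membuf) can use ChunkOffsetsView::from_dimensions_into instead, which fills the caller-provided slice and returns a borrowed, validated view.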