From 316e21f921f0ad762118b1c41b4e76586cc94424 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Thu, 23 Apr 2026 19:52:14 -0400 Subject: [PATCH 01/13] Move ObjectPool from diskann to diskann-utils Relocates the object pool module so that it is available to crates that depend on diskann-utils but not diskann (notably diskann-quantization, which will gain pool-aware distance-table allocation in a follow-up). The diskann::utils::object_pool module is removed without leaving a re-export behind, so every importer is switched to use diskann_utils::object_pool directly: the direct importers in diskann-providers, diskann-disk, and diskann-garnet, as well as the internal diskann users (graph/index.rs and graph/search/scratch.rs). --- diskann-disk/src/search/provider/disk_provider.rs | 6 ++---- diskann-garnet/src/provider.rs | 6 ++---- .../graph/provider/async_/bf_tree/quant_vector_provider.rs | 7 ++----- .../provider/async_/fast_memory_quant_vector_provider.rs | 7 ++----- .../graph/provider/async_/memory_quant_vector_provider.rs | 3 ++- diskann-providers/src/model/pq/distance/common.rs | 2 +- diskann-providers/src/model/pq/distance/dynamic.rs | 3 ++- diskann-providers/src/model/pq/distance/innerproduct.rs | 6 ++---- diskann-providers/src/model/pq/distance/l2.rs | 6 ++---- diskann-utils/src/lib.rs | 1 + {diskann/src/utils => diskann-utils/src}/object_pool.rs | 0 diskann/src/graph/index.rs | 2 +- diskann/src/graph/search/scratch.rs | 3 ++- diskann/src/utils/mod.rs | 2 -- 14 files changed, 21 insertions(+), 33 deletions(-) rename {diskann/src/utils => diskann-utils/src}/object_pool.rs (100%) diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 0403a5019..b7b30e94a 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -31,10 +31,7 @@ use diskann::{ Accessor, BuildQueryComputer, DataProvider, DefaultContext, DelegateNeighbor, HasId, NeighborAccessor, NoopGuard, }, - utils::{ - 
object_pool::{ObjectPool, PoolOption, TryAsPooled}, - IntoUsize, VectorRepr, - }, + utils::{IntoUsize, VectorRepr}, ANNError, ANNResult, }; use diskann_providers::storage::StorageReadProvider; @@ -42,6 +39,7 @@ use diskann_providers::{ model::{compute_pq_distance, compute_pq_distance_for_pq_coordinates}, storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file, LoadWith}, }; +use diskann_utils::object_pool::{ObjectPool, PoolOption, TryAsPooled}; use crate::search::pq::{quantizer_preprocess, PQData, PQScratch}; use diskann_vector::{distance::Metric, DistanceFunction, PreprocessedDistanceFunction}; diff --git a/diskann-garnet/src/provider.rs b/diskann-garnet/src/provider.rs index 1bb0679ad..e2522b147 100644 --- a/diskann-garnet/src/provider.rs +++ b/diskann-garnet/src/provider.rs @@ -20,13 +20,11 @@ use diskann::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DataProvider, DelegateNeighbor, Delete, ElementStatus, HasId, NeighborAccessor, NeighborAccessorMut, NoopGuard, SetElement, }, - utils::{ - VectorRepr, - object_pool::{AsPooled, ObjectPool, PooledRef, Undef}, - }, + utils::VectorRepr, }; use diskann_providers::model::graph::provider::async_::common::FullPrecision; use diskann_utils::Reborrow; +use diskann_utils::object_pool::{AsPooled, ObjectPool, PooledRef, Undef}; use diskann_vector::{PreprocessedDistanceFunction, contains::ContainsSimd, distance::Metric}; use std::{ future, mem, diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/quant_vector_provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/quant_vector_provider.rs index 57e5120d9..0d26fa680 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/quant_vector_provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/quant_vector_provider.rs @@ -9,12 +9,9 @@ use std::sync::Arc; use bf_tree::{BfTree, Config}; use bytemuck::bytes_of; -use diskann::{ - ANNError, ANNErrorKind, ANNResult, - error::IntoANNResult, - 
utils::{VectorRepr, object_pool::ObjectPool}, -}; +use diskann::{ANNError, ANNErrorKind, ANNResult, error::IntoANNResult, utils::VectorRepr}; use diskann_quantization::CompressInto; +use diskann_utils::object_pool::ObjectPool; use diskann_vector::distance::Metric; use thiserror::Error; diff --git a/diskann-providers/src/model/graph/provider/async_/fast_memory_quant_vector_provider.rs b/diskann-providers/src/model/graph/provider/async_/fast_memory_quant_vector_provider.rs index e70776ded..62d9c05e1 100644 --- a/diskann-providers/src/model/graph/provider/async_/fast_memory_quant_vector_provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/fast_memory_quant_vector_provider.rs @@ -14,12 +14,9 @@ use std::sync::{Arc, Mutex}; use crate::storage::{StorageReadProvider, StorageWriteProvider}; -use diskann::{ - ANNError, ANNResult, - error::IntoANNResult, - utils::{VectorRepr, object_pool::ObjectPool}, -}; +use diskann::{ANNError, ANNResult, error::IntoANNResult, utils::VectorRepr}; use diskann_quantization::CompressInto; +use diskann_utils::object_pool::ObjectPool; use diskann_vector::distance::Metric; use super::common::{AlignedMemoryVectorStore, TestCallCount}; diff --git a/diskann-providers/src/model/graph/provider/async_/memory_quant_vector_provider.rs b/diskann-providers/src/model/graph/provider/async_/memory_quant_vector_provider.rs index bdd191f32..dca7ea762 100644 --- a/diskann-providers/src/model/graph/provider/async_/memory_quant_vector_provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/memory_quant_vector_provider.rs @@ -14,9 +14,10 @@ use crate::storage::{StorageReadProvider, StorageWriteProvider}; use arc_swap::{ArcSwap, Guard}; #[cfg(test)] use diskann::utils::VectorRepr; -use diskann::{ANNError, ANNResult, utils::object_pool::ObjectPool}; +use diskann::{ANNError, ANNResult}; #[cfg(test)] use diskann_quantization::CompressInto; +use diskann_utils::object_pool::ObjectPool; use diskann_vector::{DistanceFunction, 
PreprocessedDistanceFunction, distance::Metric}; use super::{VectorGuard, common::TestCallCount}; diff --git a/diskann-providers/src/model/pq/distance/common.rs b/diskann-providers/src/model/pq/distance/common.rs index 2888e6836..f36b5b386 100644 --- a/diskann-providers/src/model/pq/distance/common.rs +++ b/diskann-providers/src/model/pq/distance/common.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -use diskann::utils::object_pool::{self, ObjectPool}; +use diskann_utils::object_pool::{self, ObjectPool}; use crate::model::pq::fixed_chunk_pq_table::FixedChunkPQTable; diff --git a/diskann-providers/src/model/pq/distance/dynamic.rs b/diskann-providers/src/model/pq/distance/dynamic.rs index bc9d83b94..cefacacde 100644 --- a/diskann-providers/src/model/pq/distance/dynamic.rs +++ b/diskann-providers/src/model/pq/distance/dynamic.rs @@ -5,7 +5,8 @@ use std::{ops::Deref, sync::Arc}; -use diskann::{ANNResult, utils::object_pool::ObjectPool}; +use diskann::ANNResult; +use diskann_utils::object_pool::ObjectPool; use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction, distance::Metric}; // Concrete implementations diff --git a/diskann-providers/src/model/pq/distance/innerproduct.rs b/diskann-providers/src/model/pq/distance/innerproduct.rs index 8849f30ff..7d9abfc7e 100644 --- a/diskann-providers/src/model/pq/distance/innerproduct.rs +++ b/diskann-providers/src/model/pq/distance/innerproduct.rs @@ -5,10 +5,8 @@ use std::{ops::Deref, sync::Arc}; -use diskann::{ - ANNResult, - utils::object_pool::{self, ObjectPool, PoolOption}, -}; +use diskann::ANNResult; +use diskann_utils::object_pool::{self, ObjectPool, PoolOption}; use diskann_vector::PreprocessedDistanceFunction; use super::common::get_lookup_table_size; diff --git a/diskann-providers/src/model/pq/distance/l2.rs b/diskann-providers/src/model/pq/distance/l2.rs index e551ee35a..543534796 100644 --- a/diskann-providers/src/model/pq/distance/l2.rs +++ b/diskann-providers/src/model/pq/distance/l2.rs @@ 
-5,10 +5,8 @@ use std::{ops::Deref, sync::Arc}; -use diskann::{ - ANNResult, - utils::object_pool::{self, ObjectPool, PoolOption}, -}; +use diskann::ANNResult; +use diskann_utils::object_pool::{self, ObjectPool, PoolOption}; use diskann_vector::PreprocessedDistanceFunction; use super::common::get_lookup_table_size; diff --git a/diskann-utils/src/lib.rs b/diskann-utils/src/lib.rs index cd8c1b84d..089e0dece 100644 --- a/diskann-utils/src/lib.rs +++ b/diskann-utils/src/lib.rs @@ -15,6 +15,7 @@ pub use lifetime::WithLifetime; pub mod future; pub mod io; +pub mod object_pool; pub mod sampling; // Views diff --git a/diskann/src/utils/object_pool.rs b/diskann-utils/src/object_pool.rs similarity index 100% rename from diskann/src/utils/object_pool.rs rename to diskann-utils/src/object_pool.rs diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index cd81ba5f1..eca65b8e6 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -52,9 +52,9 @@ use crate::{ utils::{ IntoUsize, TryIntoVectorId, VectorId, async_tools::{self, DynamicBalancer}, - object_pool::{ObjectPool, PooledRef}, }, }; +use diskann_utils::object_pool::{ObjectPool, PooledRef}; #[derive(Debug)] pub struct DiskANNIndex { diff --git a/diskann/src/graph/search/scratch.rs b/diskann/src/graph/search/scratch.rs index 2a4706821..98a5b3127 100644 --- a/diskann/src/graph/search/scratch.rs +++ b/diskann/src/graph/search/scratch.rs @@ -11,8 +11,9 @@ use std::collections::VecDeque; use crate::{ neighbor::{Neighbor, NeighborPriorityQueue}, - utils::{VectorId, object_pool::AsPooled}, + utils::VectorId, }; +use diskann_utils::object_pool::AsPooled; use hashbrown::HashSet; /// In-mem index related limits diff --git a/diskann/src/utils/mod.rs b/diskann/src/utils/mod.rs index bf6ac5c40..86b78f0ab 100644 --- a/diskann/src/utils/mod.rs +++ b/diskann/src/utils/mod.rs @@ -3,8 +3,6 @@ * Licensed under the MIT license. 
*/ -pub mod object_pool; - pub mod async_tools; #[allow(clippy::module_inception)] From ca3a5293bbf6cc201491f3824c1abcc0b2c62871 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Fri, 24 Apr 2026 12:28:50 -0400 Subject: [PATCH 02/13] move chunks_offsets to diskann_quantization --- .../src/backend/exhaustive/product.rs | 2 +- .../src/search/pq/quantizer_preprocess.rs | 19 +++++------ diskann-providers/src/model/mod.rs | 2 +- .../src/model/pq/distance/test_utils.rs | 3 +- diskann-providers/src/model/pq/mod.rs | 7 ++-- .../src/model/pq/pq_construction.rs | 30 +---------------- diskann-quantization/src/views.rs | 33 +++++++++++++++++++ 7 files changed, 50 insertions(+), 46 deletions(-) diff --git a/diskann-benchmark/src/backend/exhaustive/product.rs b/diskann-benchmark/src/backend/exhaustive/product.rs index 4723753a0..f339a3c44 100644 --- a/diskann-benchmark/src/backend/exhaustive/product.rs +++ b/diskann-benchmark/src/backend/exhaustive/product.rs @@ -87,7 +87,7 @@ mod imp { 5, ); - let offsets = diskann_providers::model::pq::calculate_chunk_offsets_auto( + let offsets = diskann_quantization::views::calculate_chunk_offsets_auto( data.ncols(), input.num_pq_chunks.get(), ); diff --git a/diskann-disk/src/search/pq/quantizer_preprocess.rs b/diskann-disk/src/search/pq/quantizer_preprocess.rs index cc454ea7b..82f7cc7c5 100644 --- a/diskann-disk/src/search/pq/quantizer_preprocess.rs +++ b/diskann-disk/src/search/pq/quantizer_preprocess.rs @@ -3,6 +3,13 @@ * Licensed under the MIT license. */ +//! PQ quantizer query preprocessing. +//! +//! Prior to the introduction of the [`quantizer_preprocess`] method, the disk index was +//! hard-coded to use L2 distance for comparisons. We're keeping that behavior here - +//! treating `Cosine` and `CosineNormalized` as L2 until a more thorough evaluation can +//! be made. 
+ use diskann::ANNResult; use diskann_vector::distance::Metric; @@ -33,11 +40,7 @@ pub fn quantizer_preprocess( .bridge_err()?; match metric { - // Prior to the introduction of the `quantizer_preprocess` method, the - // disk index was hard-coded to use L2 distance for comparisons. - // - // We're keeping that behavior here - treating `Cosine` and `CosineNormalized` - // as L2 until a more thorough evaluation can be made. + // Cosine and CosineNormalized fall back to L2; see module docs. Metric::L2 | Metric::Cosine | Metric::CosineNormalized => { table.process_into::( &pq_scratch.rotated_query[..dim], @@ -54,11 +57,7 @@ pub fn quantizer_preprocess( } PQTable::Fixed(table) => { match metric { - // Prior to the introduction of the `quantizer_preprocess` method, the - // disk index was hard-coded to use L2 distance for comparisons. - // - // We're keeping that behavior here - treating `Cosine` and `CosineNormalized` - // as L2 until a more thorough evaluation can be made. + // Cosine and CosineNormalized fall back to L2; see module docs. Metric::L2 | Metric::Cosine | Metric::CosineNormalized => { // The scratch only stores the aligned dimension. 
However, preprocessing // wants the actual dimension used, so we have to shrink the rotated query diff --git a/diskann-providers/src/model/mod.rs b/diskann-providers/src/model/mod.rs index f6ae2be75..e9fceb869 100644 --- a/diskann-providers/src/model/mod.rs +++ b/diskann-providers/src/model/mod.rs @@ -11,7 +11,7 @@ pub use configuration::IndexConfiguration; pub mod pq; pub use pq::{ FixedChunkPQTable, GeneratePivotArguments, MAX_PQ_TRAINING_SET_SIZE, NUM_KMEANS_REPS_PQ, - NUM_PQ_CENTROIDS, accum_row_inplace, calculate_chunk_offsets_auto, compute_pq_distance, + NUM_PQ_CENTROIDS, accum_row_inplace, compute_pq_distance, compute_pq_distance_for_pq_coordinates, direct_distance_impl, distance, generate_pq_data_from_pivots_from_membuf, generate_pq_data_from_pivots_from_membuf_batch, generate_pq_pivots, generate_pq_pivots_from_membuf, diff --git a/diskann-providers/src/model/pq/distance/test_utils.rs b/diskann-providers/src/model/pq/distance/test_utils.rs index 76c240e1f..677fc32a2 100644 --- a/diskann-providers/src/model/pq/distance/test_utils.rs +++ b/diskann-providers/src/model/pq/distance/test_utils.rs @@ -13,7 +13,8 @@ use diskann_vector::{ use rand::{Rng, distr::Distribution}; use rand_distr::{Normal, Uniform}; -use crate::model::{FixedChunkPQTable, pq::calculate_chunk_offsets_auto}; +use crate::model::FixedChunkPQTable; +use diskann_quantization::views::calculate_chunk_offsets_auto; /// We need a way to generate random queries. 
/// diff --git a/diskann-providers/src/model/pq/mod.rs b/diskann-providers/src/model/pq/mod.rs index 6338e39ec..6c1962a4a 100644 --- a/diskann-providers/src/model/pq/mod.rs +++ b/diskann-providers/src/model/pq/mod.rs @@ -11,10 +11,9 @@ pub use fixed_chunk_pq_table::{ mod pq_construction; pub use pq_construction::{ MAX_PQ_TRAINING_SET_SIZE, NUM_KMEANS_REPS_PQ, NUM_PQ_CENTROIDS, accum_row_inplace, - calculate_chunk_offsets, calculate_chunk_offsets_auto, generate_pq_data_from_pivots, - generate_pq_data_from_pivots_from_membuf, generate_pq_data_from_pivots_from_membuf_batch, - generate_pq_pivots, generate_pq_pivots_from_membuf, get_chunk_from_training_data, - move_train_data_by_centroid, + generate_pq_data_from_pivots, generate_pq_data_from_pivots_from_membuf, + generate_pq_data_from_pivots_from_membuf_batch, generate_pq_pivots, + generate_pq_pivots_from_membuf, get_chunk_from_training_data, move_train_data_by_centroid, }; /// all metadata of individual sub-component files is written in first 4KB for unified files diff --git a/diskann-providers/src/model/pq/pq_construction.rs b/diskann-providers/src/model/pq/pq_construction.rs index b2ab3da85..69b95c92a 100644 --- a/diskann-providers/src/model/pq/pq_construction.rs +++ b/diskann-providers/src/model/pq/pq_construction.rs @@ -19,6 +19,7 @@ use diskann::{ use diskann_quantization::{ CompressInto, product::{BasicTableView, TransposedTable, train::TrainQuantizer}, + views::calculate_chunk_offsets, }; use diskann_utils::{ io::Metadata, @@ -328,35 +329,6 @@ pub fn move_train_data_by_centroid( } } -/// Calculate the number of chunks for the product quantization algorithm. Returns a vector of offsets where -/// each offset corresponds to a chunk based on the index of the chunk in the vector. -/// -/// # Arguments -/// * `dimensions` Number of dimensions of the input data -/// * `num_pq_chunks` - Number of chunks that will be used in the PQ calculation. 
Each vector will be split into these -/// number of chunks and each chunk will be compressed down to one byte. -/// * `offsets` - An output vector of offsets, where the size is equal to the number of pq chunks + 1. -#[inline] -pub fn calculate_chunk_offsets(dimensions: usize, num_pq_chunks: usize, offsets: &mut [usize]) { - // Calculate each chunk's offset - // If we have 8 dimension and 3 chunks then offsets would be [0,3,6,8] - let mut chunk_offset: usize = 0; - offsets[0] = chunk_offset; - for chunk_index in 0..num_pq_chunks { - chunk_offset += dimensions / num_pq_chunks; - if chunk_index < (dimensions % num_pq_chunks) { - chunk_offset += 1; - } - offsets[chunk_index + 1] = chunk_offset; - } -} - -pub fn calculate_chunk_offsets_auto(dimensions: usize, num_pq_chunks: usize) -> Vec { - let mut offsets = vec![0; num_pq_chunks + 1]; - calculate_chunk_offsets(dimensions, num_pq_chunks, offsets.as_mut_slice()); - offsets -} - /// Add the row `y` to every row in `x`. /// /// # Panics diff --git a/diskann-quantization/src/views.rs b/diskann-quantization/src/views.rs index 6ef928345..f8882e8b4 100644 --- a/diskann-quantization/src/views.rs +++ b/diskann-quantization/src/views.rs @@ -205,6 +205,39 @@ impl<'a> From> for &'a [usize] { } } +/// Calculate the chunk offsets for the product quantization algorithm. Fills `offsets` +/// with the prefix-sum partitioning of `dimensions` into `num_pq_chunks` chunks, where +/// the first `dimensions % num_pq_chunks` chunks are one element larger than the rest. +/// +/// # Arguments +/// * `dimensions` - Number of dimensions of the input data. +/// * `num_pq_chunks` - Number of chunks that will be used in the PQ calculation. Each +/// vector will be split into these number of chunks and each chunk will be compressed +/// down to one byte. +/// * `offsets` - An output slice of offsets, where the length must equal `num_pq_chunks + 1`. 
+#[inline] +pub fn calculate_chunk_offsets(dimensions: usize, num_pq_chunks: usize, offsets: &mut [usize]) { + // Calculate each chunk's offset + // If we have 8 dimension and 3 chunks then offsets would be [0,3,6,8] + let mut chunk_offset: usize = 0; + offsets[0] = chunk_offset; + for chunk_index in 0..num_pq_chunks { + chunk_offset += dimensions / num_pq_chunks; + if chunk_index < (dimensions % num_pq_chunks) { + chunk_offset += 1; + } + offsets[chunk_index + 1] = chunk_offset; + } +} + +/// Allocating wrapper around [`calculate_chunk_offsets`] that returns a fresh +/// `Vec` of length `num_pq_chunks + 1`. +pub fn calculate_chunk_offsets_auto(dimensions: usize, num_pq_chunks: usize) -> Vec { + let mut offsets = vec![0; num_pq_chunks + 1]; + calculate_chunk_offsets(dimensions, num_pq_chunks, offsets.as_mut_slice()); + offsets +} + /////////////// // ChunkView // /////////////// From 4fe3345408e3e817708eb54138715137fb694464 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Fri, 24 Apr 2026 14:26:43 -0400 Subject: [PATCH 03/13] remove redundant deref impls in PQ dynamic distance --- .../fast_memory_quant_vector_provider.rs | 9 +-- .../src/model/pq/distance/dynamic.rs | 65 ------------------- 2 files changed, 3 insertions(+), 71 deletions(-) diff --git a/diskann-providers/src/model/graph/provider/async_/fast_memory_quant_vector_provider.rs b/diskann-providers/src/model/graph/provider/async_/fast_memory_quant_vector_provider.rs index 62d9c05e1..9da2296ac 100644 --- a/diskann-providers/src/model/graph/provider/async_/fast_memory_quant_vector_provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/fast_memory_quant_vector_provider.rs @@ -444,21 +444,18 @@ mod tests { // Query Computer. 
let c = provider.query_computer(&[-0.5, -0.5]).unwrap(); let expected: f32 = 1.5 * 1.5 * 2.0; - assert_eq!( - c.evaluate_similarity(&provider.get_vector_sync(3)), - expected - ); + assert_eq!(c.evaluate_similarity(provider.get_vector_sync(3)), expected); // Distance Computer. let d = provider.distance_computer(); assert_eq!( - d.evaluate_similarity(&provider.get_vector_sync(0), &provider.get_vector_sync(3)), + d.evaluate_similarity(provider.get_vector_sync(0), provider.get_vector_sync(3)), 2.0 ); let slice: &[f32] = &[-0.5, -0.5]; assert_eq!( - d.evaluate_similarity(slice, &provider.get_vector_sync(3)), + d.evaluate_similarity(slice, provider.get_vector_sync(3)), expected, ); } diff --git a/diskann-providers/src/model/pq/distance/dynamic.rs b/diskann-providers/src/model/pq/distance/dynamic.rs index cefacacde..352912846 100644 --- a/diskann-providers/src/model/pq/distance/dynamic.rs +++ b/diskann-providers/src/model/pq/distance/dynamic.rs @@ -101,25 +101,6 @@ where } } -impl PreprocessedDistanceFunction<&Vec, f32> for QueryComputer -where - T: Deref, -{ - fn evaluate_similarity(&self, changing: &Vec) -> f32 { - self.evaluate_similarity(changing.as_slice()) - } -} - -impl PreprocessedDistanceFunction<&&[u8], f32> for QueryComputer -where - T: Deref, -{ - fn evaluate_similarity(&self, changing: &&[u8]) -> f32 { - let changing: &[u8] = changing; - self.evaluate_similarity(changing) - } -} - /// Pre-dispatched distance functions for the `FixedChunkPQTable`. #[derive(Debug)] pub struct VTable { @@ -233,52 +214,6 @@ where } } -/// Perform a comparison between a full-precision vector and quantized vector. 
-impl DistanceFunction<&[f32], &&[u8], f32> for DistanceComputer -where - T: Deref, -{ - #[inline(always)] - fn evaluate_similarity(&self, fp: &[f32], q: &&[u8]) -> f32 { - let q: &[u8] = q; - self.evaluate_similarity(fp, q) - } -} - -impl DistanceFunction<&[f32], &Vec, f32> for DistanceComputer -where - T: Deref, -{ - #[inline(always)] - fn evaluate_similarity(&self, fp: &[f32], q: &Vec) -> f32 { - self.evaluate_similarity(fp, q.as_slice()) - } -} - -/// Perform a comparison between two quantized vectors. -impl DistanceFunction<&&[u8], &&[u8], f32> for DistanceComputer -where - T: Deref, -{ - #[inline(always)] - fn evaluate_similarity(&self, q0: &&[u8], q1: &&[u8]) -> f32 { - let q0: &[u8] = q0; - let q1: &[u8] = q1; - self.evaluate_similarity(q0, q1) - } -} - -/// Perform a comparison between two quantized vectors. -impl DistanceFunction<&Vec, &Vec, f32> for DistanceComputer -where - T: Deref, -{ - #[inline(always)] - fn evaluate_similarity(&self, q0: &Vec, q1: &Vec) -> f32 { - self.evaluate_similarity(q0.as_slice(), q1.as_slice()) - } -} - #[cfg(test)] mod tests { use std::marker::PhantomData; From 40dbba944d51e87cd350a682932986fd6ab4ddfc Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Fri, 24 Apr 2026 15:02:40 -0400 Subject: [PATCH 04/13] fix bf_tree gated test for removed deref impls --- .../provider/async_/bf_tree/quant_vector_provider.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/quant_vector_provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/quant_vector_provider.rs index 0d26fa680..edcb593d8 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/quant_vector_provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/quant_vector_provider.rs @@ -353,7 +353,7 @@ mod tests { let c = provider.query_computer(&[-0.5, -0.5]).unwrap(); let expected: f32 = 1.5 * 1.5 * 2.0; assert_eq!( - 
c.evaluate_similarity(&provider.get_vector_sync(3).unwrap()), + c.evaluate_similarity(provider.get_vector_sync(3).unwrap().as_slice()), expected ); @@ -361,15 +361,15 @@ mod tests { let d = provider.distance_computer(); assert_eq!( d.evaluate_similarity( - &provider.get_vector_sync(0).unwrap(), - &provider.get_vector_sync(3).unwrap() + provider.get_vector_sync(0).unwrap().as_slice(), + provider.get_vector_sync(3).unwrap().as_slice() ), 2.0 ); let slice: &[f32] = &[-0.5, -0.5]; assert_eq!( - d.evaluate_similarity(slice, &provider.get_vector_sync(3).unwrap()), + d.evaluate_similarity(slice, provider.get_vector_sync(3).unwrap().as_slice()), expected, ); } From bf082868f23fcc27f44933f0b7f85f0c639f8957 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Fri, 24 Apr 2026 22:57:29 -0400 Subject: [PATCH 05/13] move accum_row_inplace to diskann-utils; demote get_chunk_from_training_data to test helper --- .../src/storage/quant/pq/pq_generation.rs | 7 +- diskann-providers/src/model/mod.rs | 8 +-- diskann-providers/src/model/pq/mod.rs | 7 +- .../src/model/pq/pq_construction.rs | 71 +++++++------------ diskann-utils/src/views.rs | 17 +++++ 5 files changed, 50 insertions(+), 60 deletions(-) diff --git a/diskann-disk/src/storage/quant/pq/pq_generation.rs b/diskann-disk/src/storage/quant/pq/pq_generation.rs index a8a1557c7..ac59b5ad1 100644 --- a/diskann-disk/src/storage/quant/pq/pq_generation.rs +++ b/diskann-disk/src/storage/quant/pq/pq_generation.rs @@ -9,15 +9,12 @@ use diskann::{utils::VectorRepr, ANNError}; use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider}; use diskann_providers::{ forward_threadpool, - model::{ - pq::{accum_row_inplace, generate_pq_pivots}, - GeneratePivotArguments, - }, + model::{pq::generate_pq_pivots, GeneratePivotArguments}, storage::PQStorage, utils::{AsThreadPool, BridgeErr, Timer}, }; use diskann_quantization::{product::TransposedTable, CompressInto}; -use diskann_utils::views::MatrixBase; +use 
diskann_utils::views::{accum_row_inplace, MatrixBase}; use diskann_vector::distance::Metric; use tracing::info; diff --git a/diskann-providers/src/model/mod.rs b/diskann-providers/src/model/mod.rs index e9fceb869..61addc05a 100644 --- a/diskann-providers/src/model/mod.rs +++ b/diskann-providers/src/model/mod.rs @@ -11,10 +11,10 @@ pub use configuration::IndexConfiguration; pub mod pq; pub use pq::{ FixedChunkPQTable, GeneratePivotArguments, MAX_PQ_TRAINING_SET_SIZE, NUM_KMEANS_REPS_PQ, - NUM_PQ_CENTROIDS, accum_row_inplace, compute_pq_distance, - compute_pq_distance_for_pq_coordinates, direct_distance_impl, distance, - generate_pq_data_from_pivots_from_membuf, generate_pq_data_from_pivots_from_membuf_batch, - generate_pq_pivots, generate_pq_pivots_from_membuf, + NUM_PQ_CENTROIDS, compute_pq_distance, compute_pq_distance_for_pq_coordinates, + direct_distance_impl, distance, generate_pq_data_from_pivots_from_membuf, + generate_pq_data_from_pivots_from_membuf_batch, generate_pq_pivots, + generate_pq_pivots_from_membuf, }; pub mod statistics; diff --git a/diskann-providers/src/model/pq/mod.rs b/diskann-providers/src/model/pq/mod.rs index 6c1962a4a..ba6a49d25 100644 --- a/diskann-providers/src/model/pq/mod.rs +++ b/diskann-providers/src/model/pq/mod.rs @@ -10,10 +10,9 @@ pub use fixed_chunk_pq_table::{ mod pq_construction; pub use pq_construction::{ - MAX_PQ_TRAINING_SET_SIZE, NUM_KMEANS_REPS_PQ, NUM_PQ_CENTROIDS, accum_row_inplace, - generate_pq_data_from_pivots, generate_pq_data_from_pivots_from_membuf, - generate_pq_data_from_pivots_from_membuf_batch, generate_pq_pivots, - generate_pq_pivots_from_membuf, get_chunk_from_training_data, move_train_data_by_centroid, + MAX_PQ_TRAINING_SET_SIZE, NUM_KMEANS_REPS_PQ, NUM_PQ_CENTROIDS, generate_pq_data_from_pivots, + generate_pq_data_from_pivots_from_membuf, generate_pq_data_from_pivots_from_membuf_batch, + generate_pq_pivots, generate_pq_pivots_from_membuf, move_train_data_by_centroid, }; /// all metadata of individual 
sub-component files is written in first 4KB for unified files diff --git a/diskann-providers/src/model/pq/pq_construction.rs b/diskann-providers/src/model/pq/pq_construction.rs index 69b95c92a..bae68e650 100644 --- a/diskann-providers/src/model/pq/pq_construction.rs +++ b/diskann-providers/src/model/pq/pq_construction.rs @@ -23,7 +23,7 @@ use diskann_quantization::{ }; use diskann_utils::{ io::Metadata, - views::{MatrixView, MutMatrixView}, + views::{MatrixView, MutMatrixView, accum_row_inplace}, }; use rand::{Rng, distr::Distribution}; use rayon::prelude::*; @@ -254,35 +254,6 @@ pub fn generate_pq_pivots_from_membuf, Pool: AsThreadPool>( Ok(()) } -/// Gets all instances of a chunk from the training data for all records in the training data. Each vector in the -/// training dataset is divided into chunks and the PQ algorithm handles each vector chunk individually. This method -/// gets the same chunk from each vector in the training data and creates a new vector out of all of them. -/// -/// # Example -/// See tests for examples -#[inline] -pub fn get_chunk_from_training_data( - train_data: &[f32], - num_train: usize, - raw_vector_dim: usize, - chunk_size: usize, - chunk_offset: usize, -) -> Vec { - let mut result: Vec = vec![0.0; num_train * chunk_size]; - - result - // group empty result data into chunks of chunk_size - .chunks_mut(chunk_size) - .enumerate() - // for each chunk, copy the chunk from the training data into the result vector - .for_each(|(chunk_number, result_chunk)| { - let train_data_start = chunk_number * raw_vector_dim + chunk_offset; - let train_data_end = train_data_start + chunk_size; - result_chunk.copy_from_slice(&train_data[train_data_start..train_data_end]); - }); - result -} - /// Calculates the centroid if needed and moves the train_data to to the centroid /// # Arguments /// * `train_data` Dataset @@ -329,23 +300,6 @@ pub fn move_train_data_by_centroid( } } -/// Add the row `y` to every row in `x`. 
-/// -/// # Panics -/// -/// Panics if `y.len() != x.ncols()`. -pub fn accum_row_inplace(mut x: MutMatrixView, y: &[T]) -where - T: Copy + std::ops::AddAssign, -{ - assert_eq!(x.ncols(), y.len()); - x.row_iter_mut().for_each(|row| { - std::iter::zip(row.iter_mut(), y.iter()).for_each(|(a, b)| { - *a += *b; - }); - }); -} - /// streams the base file (data_file), and computes the closest centers in each /// chunk to generate the compressed data_file and stores it in /// pq_compressed_vectors_path. @@ -655,6 +609,29 @@ mod pq_test { utils::{ParallelIteratorInPool, create_thread_pool_for_test, read_bin_from}, }; + /// Test helper: Gets all instances of a chunk from the training data for all records + /// in the training data. Each vector in the training dataset is divided into chunks + /// and the PQ algorithm handles each vector chunk individually. This helper gets the + /// same chunk from each vector in the training data and returns it as a flat vector. + fn get_chunk_from_training_data( + train_data: &[f32], + num_train: usize, + raw_vector_dim: usize, + chunk_size: usize, + chunk_offset: usize, + ) -> Vec { + let mut result: Vec = vec![0.0; num_train * chunk_size]; + result + .chunks_mut(chunk_size) + .enumerate() + .for_each(|(chunk_number, result_chunk)| { + let train_data_start = chunk_number * raw_vector_dim + chunk_offset; + let train_data_end = train_data_start + chunk_size; + result_chunk.copy_from_slice(&train_data[train_data_start..train_data_end]); + }); + result + } + #[test] fn test_move_train_data_by_centroid() { let dim = 20; diff --git a/diskann-utils/src/views.rs b/diskann-utils/src/views.rs index a9352918c..f4ff31883 100644 --- a/diskann-utils/src/views.rs +++ b/diskann-utils/src/views.rs @@ -673,6 +673,23 @@ impl<'a, T> From> for &'a [T] { } } +/// Add the row `y` to every row in `x`. +/// +/// # Panics +/// +/// Panics if `y.len() != x.ncols()`. 
+pub fn accum_row_inplace(mut x: MutMatrixView, y: &[T]) +where + T: Copy + std::ops::AddAssign, +{ + assert_eq!(x.ncols(), y.len()); + x.row_iter_mut().for_each(|row| { + std::iter::zip(row.iter_mut(), y.iter()).for_each(|(a, b)| { + *a += *b; + }); + }); +} + /// Return a reference to the item at entry `(row, col)` in the matrix. /// /// # Panics From f257ee218f3d72a3589945fc79f41843f364f254 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Mon, 4 May 2026 12:23:14 -0400 Subject: [PATCH 06/13] chunk offset constructor --- .../src/backend/exhaustive/product.rs | 8 +- .../src/model/pq/distance/test_utils.rs | 5 +- .../src/model/pq/pq_construction.rs | 31 ++-- diskann-quantization/src/views.rs | 148 +++++++++++++++--- 4 files changed, 148 insertions(+), 44 deletions(-) diff --git a/diskann-benchmark/src/backend/exhaustive/product.rs b/diskann-benchmark/src/backend/exhaustive/product.rs index f339a3c44..4526c06f7 100644 --- a/diskann-benchmark/src/backend/exhaustive/product.rs +++ b/diskann-benchmark/src/backend/exhaustive/product.rs @@ -87,10 +87,10 @@ mod imp { 5, ); - let offsets = diskann_quantization::views::calculate_chunk_offsets_auto( + let offsets = diskann_quantization::views::ChunkOffsets::from_dimensions( data.ncols(), input.num_pq_chunks.get(), - ); + )?; let base = { let threadpool = rayon::ThreadPoolBuilder::new() @@ -99,7 +99,7 @@ mod imp { threadpool.install(|| -> anyhow::Result<_> { Ok(parameters.train( data.as_view(), - diskann_quantization::views::ChunkOffsetsView::new(offsets.as_slice())?, + offsets.as_view(), diskann_quantization::Parallelism::Rayon, &diskann_quantization::random::StdRngBuilder::new(input.seed), &diskann_quantization::cancel::DontCancel, @@ -111,7 +111,7 @@ mod imp { data.ncols(), base.flatten().into(), vec![0.0; data.ncols()].into(), - offsets.into(), + offsets.as_slice().into(), )?; let training_time: MicroSeconds = start.elapsed().into(); diff --git a/diskann-providers/src/model/pq/distance/test_utils.rs 
b/diskann-providers/src/model/pq/distance/test_utils.rs index 677fc32a2..6b6c90d60 100644 --- a/diskann-providers/src/model/pq/distance/test_utils.rs +++ b/diskann-providers/src/model/pq/distance/test_utils.rs @@ -14,7 +14,7 @@ use rand::{Rng, distr::Distribution}; use rand_distr::{Normal, Uniform}; use crate::model::FixedChunkPQTable; -use diskann_quantization::views::calculate_chunk_offsets_auto; +use diskann_quantization::views::ChunkOffsets; /// We need a way to generate random queries. /// @@ -131,7 +131,8 @@ pub(crate) fn generate_expected_vector( /// * N + 1: The number of PQ Pivots pub(crate) fn seed_pivot_table(config: TableConfig) -> FixedChunkPQTable { // Get the chunk offsets for the selected dimension and bytes. - let offsets = calculate_chunk_offsets_auto(config.dim, config.pq_chunks); + let chunk_offsets = ChunkOffsets::from_dimensions(config.dim, config.pq_chunks).unwrap(); + let offsets = chunk_offsets.as_slice(); // Create the pivot table following the schema described in the docstring. 
let mut pivots = Vec::::new(); diff --git a/diskann-providers/src/model/pq/pq_construction.rs b/diskann-providers/src/model/pq/pq_construction.rs index bae68e650..216dbea8d 100644 --- a/diskann-providers/src/model/pq/pq_construction.rs +++ b/diskann-providers/src/model/pq/pq_construction.rs @@ -19,7 +19,7 @@ use diskann::{ use diskann_quantization::{ CompressInto, product::{BasicTableView, TransposedTable, train::TrainQuantizer}, - views::calculate_chunk_offsets, + views::{ChunkOffsets, ChunkOffsetsView}, }; use diskann_utils::{ io::Metadata, @@ -97,12 +97,8 @@ where ); } - let mut chunk_offsets: Vec = vec![0; parameters.num_pq_chunks() + 1]; - calculate_chunk_offsets( - parameters.dim(), - parameters.num_pq_chunks(), - &mut chunk_offsets, - ); + let chunk_offsets = + ChunkOffsets::from_dimensions(parameters.dim(), parameters.num_pq_chunks()).bridge_err()?; forward_threadpool!(pool = pool); let trainer = diskann_quantization::product::train::LightPQTrainingParameters::new( @@ -115,8 +111,7 @@ where .train( MatrixView::try_from(train_data, parameters.num_train(), parameters.dim()) .bridge_err()?, - diskann_quantization::views::ChunkOffsetsView::new(chunk_offsets.as_slice()) - .bridge_err()?, + chunk_offsets.as_view(), diskann_quantization::Parallelism::Rayon, &random_provider, &diskann_quantization::cancel::DontCancel, @@ -129,7 +124,7 @@ where pq_storage.write_pivot_data( &full_pivot_data, ¢roid, - &chunk_offsets, + chunk_offsets.as_slice(), parameters.num_centers(), parameters.dim(), storage_provider, @@ -206,8 +201,13 @@ pub fn generate_pq_pivots_from_membuf, Pool: AsThreadPool>( } } - // Calculate the chunk offsets - calculate_chunk_offsets(parameters.dim(), parameters.num_pq_chunks(), offsets); + // Calculate the chunk offsets, filling the caller-owned buffer. 
+ let chunk_offsets_view = ChunkOffsetsView::from_dimensions_into( + parameters.dim(), + parameters.num_pq_chunks(), + offsets, + ) + .bridge_err()?; forward_threadpool!(pool = pool); let trainer = diskann_quantization::product::train::LightPQTrainingParameters::new( @@ -240,7 +240,7 @@ pub fn generate_pq_pivots_from_membuf, Pool: AsThreadPool>( parameters.dim(), ) .bridge_err()?, - diskann_quantization::views::ChunkOffsetsView::new(offsets).bridge_err()?, + chunk_offsets_view, diskann_quantization::Parallelism::Rayon, &rng_builder, &cancelation, @@ -1037,9 +1037,8 @@ mod pq_test { // Pre-emptively construct an offset view to compare mismatched slices. // We want to check that the difference in the mismatched chunks is small. - let mut offsets = vec![0; num_pq_chunks + 1]; - calculate_chunk_offsets(train_dim, num_pq_chunks, &mut offsets); - let offset_view = diskann_quantization::views::ChunkOffsetsView::new(&offsets).unwrap(); + let chunk_offsets = ChunkOffsets::from_dimensions(train_dim, num_pq_chunks).unwrap(); + let offset_view = chunk_offsets.as_view(); let full_data = MatrixView::try_from(full_data_vector.as_slice(), num_train, train_dim).unwrap(); let pivot_view = diff --git a/diskann-quantization/src/views.rs b/diskann-quantization/src/views.rs index f8882e8b4..cc8ec1229 100644 --- a/diskann-quantization/src/views.rs +++ b/diskann-quantization/src/views.rs @@ -54,6 +54,8 @@ pub enum ChunkOffsetError { start: usize, next_val: usize, }, + #[error("scratch buffer length {actual} does not match expected length {expected}")] + ScratchLengthMismatch { expected: usize, actual: usize }, } impl ChunkOffsetsBase @@ -205,20 +207,52 @@ impl<'a> From> for &'a [usize] { } } -/// Calculate the chunk offsets for the product quantization algorithm. Fills `offsets` -/// with the prefix-sum partitioning of `dimensions` into `num_pq_chunks` chunks, where -/// the first `dimensions % num_pq_chunks` chunks are one element larger than the rest. 
-/// -/// # Arguments -/// * `dimensions` - Number of dimensions of the input data. -/// * `num_pq_chunks` - Number of chunks that will be used in the PQ calculation. Each -/// vector will be split into these number of chunks and each chunk will be compressed -/// down to one byte. -/// * `offsets` - An output slice of offsets, where the length must equal `num_pq_chunks + 1`. -#[inline] -pub fn calculate_chunk_offsets(dimensions: usize, num_pq_chunks: usize, offsets: &mut [usize]) { - // Calculate each chunk's offset - // If we have 8 dimension and 3 chunks then offsets would be [0,3,6,8] +impl ChunkOffsets { + /// Build a chunk-offset plan that partitions `dimensions` into `num_pq_chunks` + /// near-equal chunks. The first `dimensions % num_pq_chunks` chunks are one element + /// larger than the rest. + /// + /// Returns an error if the requested partition is not valid (e.g. `dimensions == 0`, + /// `num_pq_chunks == 0`, or `num_pq_chunks > dimensions`). + pub fn from_dimensions( + dimensions: usize, + num_pq_chunks: usize, + ) -> Result { + let mut offsets = vec![0usize; num_pq_chunks + 1].into_boxed_slice(); + fill_chunk_offsets(dimensions, num_pq_chunks, &mut offsets); + Self::new(offsets) + } +} + +impl<'a> ChunkOffsetsView<'a> { + /// Fill the caller-owned `scratch` buffer with the partition for `(dimensions, + /// num_pq_chunks)` and return a validated view borrowing it. + /// + /// See [`ChunkOffsets::from_dimensions`] for the partitioning rule. + /// + /// Returns an error if `scratch.len() != num_pq_chunks + 1` or if the requested + /// partition is not valid (e.g. `dimensions == 0`, `num_pq_chunks == 0`, or + /// `num_pq_chunks > dimensions`). 
+ pub fn from_dimensions_into( + dimensions: usize, + num_pq_chunks: usize, + scratch: &'a mut [usize], + ) -> Result { + let expected = num_pq_chunks + 1; + if scratch.len() != expected { + return Err(ChunkOffsetError::ScratchLengthMismatch { + expected, + actual: scratch.len(), + }); + } + fill_chunk_offsets(dimensions, num_pq_chunks, scratch); + Self::new(scratch) + } +} + +/// Internal helper: fill `offsets` (of length `num_pq_chunks + 1`) with the prefix-sum +/// partitioning of `dimensions` into `num_pq_chunks` chunks. +fn fill_chunk_offsets(dimensions: usize, num_pq_chunks: usize, offsets: &mut [usize]) { let mut chunk_offset: usize = 0; offsets[0] = chunk_offset; for chunk_index in 0..num_pq_chunks { @@ -230,14 +264,6 @@ pub fn calculate_chunk_offsets(dimensions: usize, num_pq_chunks: usize, offsets: } } -/// Allocating wrapper around [`calculate_chunk_offsets`] that returns a fresh -/// `Vec` of length `num_pq_chunks + 1`. -pub fn calculate_chunk_offsets_auto(dimensions: usize, num_pq_chunks: usize) -> Vec { - let mut offsets = vec![0; num_pq_chunks + 1]; - calculate_chunk_offsets(dimensions, num_pq_chunks, offsets.as_mut_slice()); - offsets -} - /////////////// // ChunkView // /////////////// @@ -458,6 +484,84 @@ mod tests { ); } + ////////////////////////////// + // from_dimensions builders // + ////////////////////////////// + + #[test] + fn from_dimensions_happy_path() { + // Even split: 9 / 3 = 3 each. + let offsets = ChunkOffsets::from_dimensions(9, 3).unwrap(); + assert_eq!(offsets.as_slice(), &[0, 3, 6, 9]); + assert_eq!(offsets.dim(), 9); + assert_eq!(offsets.len(), 3); + + // Uneven split: 8 / 3 = 2 r 2 -> first two chunks get an extra element. + let offsets = ChunkOffsets::from_dimensions(8, 3).unwrap(); + assert_eq!(offsets.as_slice(), &[0, 3, 6, 8]); + + // Single chunk degenerate case. 
+ let offsets = ChunkOffsets::from_dimensions(5, 1).unwrap(); + assert_eq!(offsets.as_slice(), &[0, 5]); + + // dimensions == num_pq_chunks: each chunk is size 1. + let offsets = ChunkOffsets::from_dimensions(4, 4).unwrap(); + assert_eq!(offsets.as_slice(), &[0, 1, 2, 3, 4]); + + // The view-into variant matches the owning constructor. + let mut scratch = [0usize; 4]; + let view = ChunkOffsetsView::from_dimensions_into(8, 3, &mut scratch).unwrap(); + assert_eq!(view.as_slice(), &[0, 3, 6, 8]); + assert_eq!(view.dim(), 8); + assert_eq!(view.len(), 3); + assert_eq!(scratch.as_slice(), &[0, 3, 6, 8]); + } + + #[test] + fn from_dimensions_construction_errors() { + // num_pq_chunks > dimensions -> some chunk would be empty -> NonMonotonic. + let err = ChunkOffsets::from_dimensions(3, 5).unwrap_err(); + assert!( + matches!(err, ChunkOffsetError::NonMonotonic { .. }), + "expected NonMonotonic, got {err:?}" + ); + + // dimensions == 0 -> trivially non-monotonic. + let err = ChunkOffsets::from_dimensions(0, 1).unwrap_err(); + assert!(matches!(err, ChunkOffsetError::NonMonotonic { .. })); + + // num_pq_chunks == 0 -> length 1 buffer, fails LengthNotAtLeastTwo. + let err = ChunkOffsets::from_dimensions(8, 0).unwrap_err(); + assert!(matches!(err, ChunkOffsetError::LengthNotAtLeastTwo(1))); + + // Scratch length too short -> ScratchLengthMismatch error (no panic). + let mut too_short = [0usize; 3]; + let err = ChunkOffsetsView::from_dimensions_into(8, 3, &mut too_short).unwrap_err(); + assert!(matches!( + err, + ChunkOffsetError::ScratchLengthMismatch { + expected: 4, + actual: 3 + } + )); + + // Scratch length too long -> ScratchLengthMismatch error. + let mut too_long = [0usize; 5]; + let err = ChunkOffsetsView::from_dimensions_into(8, 3, &mut too_long).unwrap_err(); + assert!(matches!( + err, + ChunkOffsetError::ScratchLengthMismatch { + expected: 4, + actual: 5 + } + )); + + // Partition validation errors propagate through the view builder too. 
+ let mut scratch = [0usize; 6]; + let err = ChunkOffsetsView::from_dimensions_into(3, 5, &mut scratch).unwrap_err(); + assert!(matches!(err, ChunkOffsetError::NonMonotonic { .. })); + } + /////////////// // ChunkView // /////////////// From 2aa27a7c852fb9a8242bf2785697ae1517862664 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Mon, 4 May 2026 15:13:41 -0400 Subject: [PATCH 07/13] broadcast row mut --- .../src/storage/quant/pq/pq_generation.rs | 16 +-- .../src/model/pq/pq_construction.rs | 6 +- diskann-providers/src/model/pq/views.rs | 8 ++ diskann-utils/src/views.rs | 119 +++++++++++++++--- 4 files changed, 123 insertions(+), 26 deletions(-) diff --git a/diskann-disk/src/storage/quant/pq/pq_generation.rs b/diskann-disk/src/storage/quant/pq/pq_generation.rs index 6e8568d49..446a386a7 100644 --- a/diskann-disk/src/storage/quant/pq/pq_generation.rs +++ b/diskann-disk/src/storage/quant/pq/pq_generation.rs @@ -5,15 +5,15 @@ use std::marker::PhantomData; -use diskann::{utils::VectorRepr, ANNError}; +use diskann::{ANNError, utils::VectorRepr}; use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider}; use diskann_providers::{ - model::{pq::generate_pq_pivots, GeneratePivotArguments}, + model::{GeneratePivotArguments, pq::generate_pq_pivots}, storage::PQStorage, utils::{BridgeErr, RayonThreadPoolRef, Timer}, }; -use diskann_quantization::{product::TransposedTable, CompressInto}; -use diskann_utils::views::{accum_row_inplace, MatrixBase}; +use diskann_quantization::{CompressInto, product::TransposedTable}; +use diskann_utils::views::MatrixBase; use diskann_vector::distance::Metric; use tracing::info; @@ -133,7 +133,9 @@ where ) .bridge_err()?; - accum_row_inplace(full_pivot_data_mat.as_mut_view(), centroid.as_slice()); + full_pivot_data_mat + .broadcast_rows_mut(centroid.as_slice(), |a, b| *a += *b) + .bridge_err()?; let table = TransposedTable::from_parts( full_pivot_data_mat.as_view(), @@ -173,12 +175,12 @@ where #[cfg(test)] mod 
pq_generation_tests { use diskann::ANNError; - use diskann_providers::model::pq::generate_pq_pivots; use diskann_providers::model::GeneratePivotArguments; + use diskann_providers::model::pq::generate_pq_pivots; use diskann_providers::storage::{ PQStorage, StorageReadProvider, StorageWriteProvider, VirtualStorageProvider, }; - use diskann_providers::utils::{create_thread_pool_for_test, RayonThreadPoolRef}; + use diskann_providers::utils::{RayonThreadPoolRef, create_thread_pool_for_test}; use diskann_utils::{ io::{read_bin, write_bin}, test_data_root, diff --git a/diskann-providers/src/model/pq/pq_construction.rs b/diskann-providers/src/model/pq/pq_construction.rs index 1e0dcf6fc..0f65b1d44 100644 --- a/diskann-providers/src/model/pq/pq_construction.rs +++ b/diskann-providers/src/model/pq/pq_construction.rs @@ -23,7 +23,7 @@ use diskann_quantization::{ }; use diskann_utils::{ io::Metadata, - views::{MatrixView, MutMatrixView, accum_row_inplace}, + views::{MatrixView, MutMatrixView}, }; use rand::{Rng, distr::Distribution}; use rayon::prelude::*; @@ -355,7 +355,9 @@ where let mut full_pivot_data_mat = MutMatrixView::try_from(full_pivot_data.as_mut_slice(), num_centers, full_dim) .bridge_err()?; - accum_row_inplace(full_pivot_data_mat.as_mut_view(), centroid.as_slice()); + full_pivot_data_mat + .broadcast_rows_mut(centroid.as_slice(), |a, b| *a += *b) + .bridge_err()?; pq_storage.write_compressed_pivot_metadata::( num_points, diff --git a/diskann-providers/src/model/pq/views.rs b/diskann-providers/src/model/pq/views.rs index 3329c31d2..0d0782562 100644 --- a/diskann-providers/src/model/pq/views.rs +++ b/diskann-providers/src/model/pq/views.rs @@ -32,6 +32,14 @@ impl From>> for ANNError { } } +// Compatibility with ANNError. 
+impl From> for ANNError { + #[track_caller] + fn from(value: Bridge) -> Self { + ANNError::log_pq_error(value.into_inner()) + } +} + /////////// // Tests // /////////// diff --git a/diskann-utils/src/views.rs b/diskann-utils/src/views.rs index f4ff31883..81c9da470 100644 --- a/diskann-utils/src/views.rs +++ b/diskann-utils/src/views.rs @@ -161,6 +161,15 @@ impl TryFromError { } } +/// Error returned when a broadcast vector's length does not match the matrix's column +/// count. +#[derive(Debug, Error)] +#[error("broadcast vector length {broadcast_len} does not match matrix ncols {ncols}")] +pub struct BroadcastLenMismatch { + pub ncols: usize, + pub broadcast_len: usize, +} + /// A generator for initializing the entries in a matrix via `Matrix::new`. pub trait Generator { fn generate(&mut self) -> T; @@ -388,6 +397,32 @@ where self.data.as_mut_slice().chunks_exact_mut(ncols) } + /// Apply `op(row[j], &broadcast[j])` to every entry of every row, broadcasting + /// `broadcast` across the rows of the matrix through an `op`. + /// + /// Returns an error if `broadcast.len() != self.ncols()`. The matrix is left + /// unmodified in the error case. + pub fn broadcast_rows_mut( + &mut self, + broadcast: &[U], + mut op: F, + ) -> Result<(), BroadcastLenMismatch> + where + T: MutDenseData, + F: FnMut(&mut T::Elem, &U), + { + if self.ncols() != broadcast.len() { + return Err(BroadcastLenMismatch { + ncols: self.ncols(), + broadcast_len: broadcast.len(), + }); + } + self.row_iter_mut().for_each(|row| { + std::iter::zip(row.iter_mut(), broadcast.iter()).for_each(|(a, b)| op(a, b)); + }); + Ok(()) + } + /// Return an iterator that divides the matrix into sub-matrices with (up to) /// `batchsize` rows with `self.ncols()` columns. /// @@ -673,23 +708,6 @@ impl<'a, T> From> for &'a [T] { } } -/// Add the row `y` to every row in `x`. -/// -/// # Panics -/// -/// Panics if `y.len() != x.ncols()`. 
-pub fn accum_row_inplace(mut x: MutMatrixView, y: &[T]) -where - T: Copy + std::ops::AddAssign, -{ - assert_eq!(x.ncols(), y.len()); - x.row_iter_mut().for_each(|row| { - std::iter::zip(row.iter_mut(), y.iter()).for_each(|(a, b)| { - *a += *b; - }); - }); -} - /// Return a reference to the item at entry `(row, col)` in the matrix. /// /// # Panics @@ -2197,4 +2215,71 @@ mod tests { // Verify all elements are 43 assert!(m.as_slice().iter().all(|&x| x == 43)); } + + #[test] + fn test_broadcast_rows_mut() { + // Use the canonical test fixture and broadcast-add a length-3 row across all + // rows. The resulting rows are the original rows offset by [10, 20, 30]. + let data = make_test_matrix(); + let mut m = Matrix::try_from(data.into(), 4, 3).unwrap(); + let bias = [10usize, 20, 30]; + + m.broadcast_rows_mut(&bias, |a, b| *a += *b).unwrap(); + assert_eq!(m.row(0), &[10, 21, 32]); + assert_eq!(m.row(1), &[11, 22, 33]); + assert_eq!(m.row(2), &[12, 23, 34]); + assert_eq!(m.row(3), &[13, 24, 35]); + + // Subtract the same bias to confirm op generality and that we round-trip back to + // the canonical fixture. + m.broadcast_rows_mut(&bias, |a, b| *a -= *b).unwrap(); + test_basic_indexing(&m); + + // Broadcast also works through a `MutMatrixView`. + let mut data = make_test_matrix(); + let mut view = MutMatrixView::try_from(data.as_mut_slice(), 4, 3).unwrap(); + view.broadcast_rows_mut(&bias, |a, b| *a += *b).unwrap(); + assert_eq!(view.row(0), &[10, 21, 32]); + assert_eq!(view.row(3), &[13, 24, 35]); + + // Heterogeneous element types: matrix is f32, broadcast is i32. + let mut fmat = Matrix::::new(0.0, 2, 3); + let scale: [i32; 3] = [1, 2, 3]; + fmat.broadcast_rows_mut(&scale, |a, b| *a += *b as f32) + .unwrap(); + assert_eq!(fmat.row(0), &[1.0, 2.0, 3.0]); + assert_eq!(fmat.row(1), &[1.0, 2.0, 3.0]); + + // Single-row degenerate case. 
+ let mut single = Matrix::::new(5, 1, 4); + single + .broadcast_rows_mut(&[1, 2, 3, 4], |a, b| *a += *b) + .unwrap(); + assert_eq!(single.row(0), &[6, 7, 8, 9]); + } + + #[test] + fn test_broadcast_rows_mut_length_mismatch() { + let data = make_test_matrix(); + let mut m = Matrix::try_from(data.clone().into(), 4, 3).unwrap(); + + // Too short. + let err = m.broadcast_rows_mut(&[1, 2], |a, b| *a += *b).unwrap_err(); + assert_eq!( + err.to_string(), + "broadcast vector length 2 does not match matrix ncols 3" + ); + // Matrix must be unmodified after a failed call. + assert_eq!(m.as_slice(), data.as_slice()); + + // Too long. + let err = m + .broadcast_rows_mut(&[1, 2, 3, 4], |a, b| *a += *b) + .unwrap_err(); + assert_eq!( + err.to_string(), + "broadcast vector length 4 does not match matrix ncols 3" + ); + assert_eq!(m.as_slice(), data.as_slice()); + } } From daca60e29863f448f2bbdcf32d814327cc59d9ca Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Mon, 4 May 2026 15:25:21 -0400 Subject: [PATCH 08/13] revert commebts in quantizer_preprocess --- .../src/search/pq/quantizer_preprocess.rs | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/diskann-disk/src/search/pq/quantizer_preprocess.rs b/diskann-disk/src/search/pq/quantizer_preprocess.rs index 82f7cc7c5..cc454ea7b 100644 --- a/diskann-disk/src/search/pq/quantizer_preprocess.rs +++ b/diskann-disk/src/search/pq/quantizer_preprocess.rs @@ -3,13 +3,6 @@ * Licensed under the MIT license. */ -//! PQ quantizer query preprocessing. -//! -//! Prior to the introduction of the [`quantizer_preprocess`] method, the disk index was -//! hard-coded to use L2 distance for comparisons. We're keeping that behavior here - -//! treating `Cosine` and `CosineNormalized` as L2 until a more thorough evaluation can -//! be made. 
- use diskann::ANNResult; use diskann_vector::distance::Metric; @@ -40,7 +33,11 @@ pub fn quantizer_preprocess( .bridge_err()?; match metric { - // Cosine and CosineNormalized fall back to L2; see module docs. + // Prior to the introduction of the `quantizer_preprocess` method, the + // disk index was hard-coded to use L2 distance for comparisons. + // + // We're keeping that behavior here - treating `Cosine` and `CosineNormalized` + // as L2 until a more thorough evaluation can be made. Metric::L2 | Metric::Cosine | Metric::CosineNormalized => { table.process_into::( &pq_scratch.rotated_query[..dim], @@ -57,7 +54,11 @@ pub fn quantizer_preprocess( } PQTable::Fixed(table) => { match metric { - // Cosine and CosineNormalized fall back to L2; see module docs. + // Prior to the introduction of the `quantizer_preprocess` method, the + // disk index was hard-coded to use L2 distance for comparisons. + // + // We're keeping that behavior here - treating `Cosine` and `CosineNormalized` + // as L2 until a more thorough evaluation can be made. Metric::L2 | Metric::Cosine | Metric::CosineNormalized => { // The scratch only stores the aligned dimension. 
However, preprocessing // wants the actual dimension used, so we have to shrink the rotated query From 0871f0b5c3a8558c897b18130b04c8959a76b864 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Mon, 4 May 2026 16:02:58 -0400 Subject: [PATCH 09/13] rustfmt --- diskann-disk/src/storage/quant/pq/pq_generation.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/diskann-disk/src/storage/quant/pq/pq_generation.rs b/diskann-disk/src/storage/quant/pq/pq_generation.rs index 446a386a7..4b32fab7a 100644 --- a/diskann-disk/src/storage/quant/pq/pq_generation.rs +++ b/diskann-disk/src/storage/quant/pq/pq_generation.rs @@ -5,14 +5,14 @@ use std::marker::PhantomData; -use diskann::{ANNError, utils::VectorRepr}; +use diskann::{utils::VectorRepr, ANNError}; use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider}; use diskann_providers::{ - model::{GeneratePivotArguments, pq::generate_pq_pivots}, + model::{pq::generate_pq_pivots, GeneratePivotArguments}, storage::PQStorage, utils::{BridgeErr, RayonThreadPoolRef, Timer}, }; -use diskann_quantization::{CompressInto, product::TransposedTable}; +use diskann_quantization::{product::TransposedTable, CompressInto}; use diskann_utils::views::MatrixBase; use diskann_vector::distance::Metric; use tracing::info; @@ -175,12 +175,12 @@ where #[cfg(test)] mod pq_generation_tests { use diskann::ANNError; - use diskann_providers::model::GeneratePivotArguments; use diskann_providers::model::pq::generate_pq_pivots; + use diskann_providers::model::GeneratePivotArguments; use diskann_providers::storage::{ PQStorage, StorageReadProvider, StorageWriteProvider, VirtualStorageProvider, }; - use diskann_providers::utils::{RayonThreadPoolRef, create_thread_pool_for_test}; + use diskann_providers::utils::{create_thread_pool_for_test, RayonThreadPoolRef}; use diskann_utils::{ io::{read_bin, write_bin}, test_data_root, From 081d45ad76bbc12a1389de40550149f5421b29f0 Mon Sep 17 00:00:00 2001 From: 
"Aditya Krishnan (from Dev Box)" Date: Mon, 4 May 2026 20:07:16 -0400 Subject: [PATCH 10/13] nonzero and cleanup --- .../src/backend/exhaustive/product.rs | 8 +- .../src/model/pq/distance/test_utils.rs | 6 +- .../src/model/pq/pq_construction.rs | 23 ++-- diskann-quantization/src/views.rs | 124 ++++++++---------- 4 files changed, 78 insertions(+), 83 deletions(-) diff --git a/diskann-benchmark/src/backend/exhaustive/product.rs b/diskann-benchmark/src/backend/exhaustive/product.rs index be3cdf50c..542ad853e 100644 --- a/diskann-benchmark/src/backend/exhaustive/product.rs +++ b/diskann-benchmark/src/backend/exhaustive/product.rs @@ -85,9 +85,11 @@ mod imp { 5, ); - let offsets = diskann_quantization::views::ChunkOffsets::from_dimensions( - data.ncols(), - input.num_pq_chunks.get(), + let dim = std::num::NonZeroUsize::new(data.ncols()) + .ok_or_else(|| anyhow::anyhow!("data has zero columns"))?; + let offsets = diskann_quantization::views::ChunkOffsets::from_dim( + dim, + input.num_pq_chunks, )?; let base = { diff --git a/diskann-providers/src/model/pq/distance/test_utils.rs b/diskann-providers/src/model/pq/distance/test_utils.rs index 6b6c90d60..e18d39b74 100644 --- a/diskann-providers/src/model/pq/distance/test_utils.rs +++ b/diskann-providers/src/model/pq/distance/test_utils.rs @@ -131,7 +131,11 @@ pub(crate) fn generate_expected_vector( /// * N + 1: The number of PQ Pivots pub(crate) fn seed_pivot_table(config: TableConfig) -> FixedChunkPQTable { // Get the chunk offsets for the selected dimension and bytes. - let chunk_offsets = ChunkOffsets::from_dimensions(config.dim, config.pq_chunks).unwrap(); + let chunk_offsets = ChunkOffsets::from_dim( + std::num::NonZeroUsize::new(config.dim).unwrap(), + std::num::NonZeroUsize::new(config.pq_chunks).unwrap(), + ) + .unwrap(); let offsets = chunk_offsets.as_slice(); // Create the pivot table following the schema described in the docstring. 
diff --git a/diskann-providers/src/model/pq/pq_construction.rs b/diskann-providers/src/model/pq/pq_construction.rs index 0f65b1d44..889223c52 100644 --- a/diskann-providers/src/model/pq/pq_construction.rs +++ b/diskann-providers/src/model/pq/pq_construction.rs @@ -6,6 +6,7 @@ use std::{ io::{Seek, SeekFrom, Write}, mem::size_of, + num::NonZeroUsize, sync::atomic::AtomicBool, vec, }; @@ -95,8 +96,11 @@ where ); } - let chunk_offsets = - ChunkOffsets::from_dimensions(parameters.dim(), parameters.num_pq_chunks()).bridge_err()?; + let dim = NonZeroUsize::new(parameters.dim()) + .ok_or_else(|| ANNError::log_pq_error("dim must be non-zero"))?; + let num_chunks = NonZeroUsize::new(parameters.num_pq_chunks()) + .ok_or_else(|| ANNError::log_pq_error("num_pq_chunks must be non-zero"))?; + let chunk_offsets = ChunkOffsets::from_dim(dim, num_chunks).bridge_err()?; let trainer = diskann_quantization::product::train::LightPQTrainingParameters::new( parameters.num_centers(), @@ -199,12 +203,9 @@ pub fn generate_pq_pivots_from_membuf>( } // Calculate the chunk offsets, filling the caller-owned buffer. - let chunk_offsets_view = ChunkOffsetsView::from_dimensions_into( - parameters.dim(), - parameters.num_pq_chunks(), - offsets, - ) - .bridge_err()?; + let dim = NonZeroUsize::new(parameters.dim()) + .ok_or_else(|| ANNError::log_pq_error("dim must be non-zero"))?; + let chunk_offsets_view = ChunkOffsetsView::from_dim_into(dim, offsets).bridge_err()?; let trainer = diskann_quantization::product::train::LightPQTrainingParameters::new( parameters.num_centers(), @@ -1028,7 +1029,11 @@ mod pq_test { // Pre-emptively construct an offset view to compare mismatched slices. // We want to check that the difference in the mismatched chunks is small. 
- let chunk_offsets = ChunkOffsets::from_dimensions(train_dim, num_pq_chunks).unwrap(); + let chunk_offsets = ChunkOffsets::from_dim( + NonZeroUsize::new(train_dim).unwrap(), + NonZeroUsize::new(num_pq_chunks).unwrap(), + ) + .unwrap(); let offset_view = chunk_offsets.as_view(); let full_data = MatrixView::try_from(full_data_vector.as_slice(), num_train, train_dim).unwrap(); diff --git a/diskann-quantization/src/views.rs b/diskann-quantization/src/views.rs index cc8ec1229..827d69137 100644 --- a/diskann-quantization/src/views.rs +++ b/diskann-quantization/src/views.rs @@ -54,8 +54,6 @@ pub enum ChunkOffsetError { start: usize, next_val: usize, }, - #[error("scratch buffer length {actual} does not match expected length {expected}")] - ScratchLengthMismatch { expected: usize, actual: usize }, } impl ChunkOffsetsBase @@ -208,56 +206,62 @@ impl<'a> From> for &'a [usize] { } impl ChunkOffsets { - /// Build a chunk-offset plan that partitions `dimensions` into `num_pq_chunks` - /// near-equal chunks. The first `dimensions % num_pq_chunks` chunks are one element - /// larger than the rest. + /// Build a chunk-offset plan that partitions `dim` into `num_chunks` + /// near-equal chunks. The first `dim.get() % num_chunks.get()` chunks are + /// one element larger than the rest. /// - /// Returns an error if the requested partition is not valid (e.g. `dimensions == 0`, - /// `num_pq_chunks == 0`, or `num_pq_chunks > dimensions`). - pub fn from_dimensions( - dimensions: usize, - num_pq_chunks: usize, + /// Returns an error if the requested partition is not valid (e.g. + /// `num_chunks.get() > dim.get()`). 
+ pub fn from_dim( + dim: NonZeroUsize, + num_chunks: NonZeroUsize, ) -> Result { - let mut offsets = vec![0usize; num_pq_chunks + 1].into_boxed_slice(); - fill_chunk_offsets(dimensions, num_pq_chunks, &mut offsets); + let mut offsets = vec![0usize; num_chunks.get() + 1].into_boxed_slice(); + fill_chunk_offsets(dim, &mut offsets); Self::new(offsets) } } impl<'a> ChunkOffsetsView<'a> { - /// Fill the caller-owned `scratch` buffer with the partition for `(dimensions, - /// num_pq_chunks)` and return a validated view borrowing it. + /// Fill the caller-owned `scratch` buffer with the partition for `dim` + /// into `scratch.len() - 1` chunks and return a validated view borrowing it. /// - /// See [`ChunkOffsets::from_dimensions`] for the partitioning rule. + /// See [`ChunkOffsets::from_dim`] for the partitioning rule. /// - /// Returns an error if `scratch.len() != num_pq_chunks + 1` or if the requested - /// partition is not valid (e.g. `dimensions == 0`, `num_pq_chunks == 0`, or - /// `num_pq_chunks > dimensions`). - pub fn from_dimensions_into( - dimensions: usize, - num_pq_chunks: usize, + /// Returns an error if `scratch.len() < 2` or if the requested partition is not + /// valid (e.g. `scratch.len() - 1 > dim.get()`). + pub fn from_dim_into( + dim: NonZeroUsize, scratch: &'a mut [usize], ) -> Result { - let expected = num_pq_chunks + 1; - if scratch.len() != expected { - return Err(ChunkOffsetError::ScratchLengthMismatch { - expected, - actual: scratch.len(), - }); + if scratch.len() < 2 { + return Err(ChunkOffsetError::LengthNotAtLeastTwo(scratch.len())); } - fill_chunk_offsets(dimensions, num_pq_chunks, scratch); + + fill_chunk_offsets(dim, scratch); Self::new(scratch) } } -/// Internal helper: fill `offsets` (of length `num_pq_chunks + 1`) with the prefix-sum -/// partitioning of `dimensions` into `num_pq_chunks` chunks. 
-fn fill_chunk_offsets(dimensions: usize, num_pq_chunks: usize, offsets: &mut [usize]) { +/// Internal helper: fill `offsets` with the prefix-sum +/// partitioning of `dimensions` into `num_chunks` near-equal chunks, where +/// `num_chunks = offsets.len() - 1`. +/// +/// The first `dimensions.get() % num_chunks` chunks are one element larger than the +/// rest, so each chunk has size `dimensions.get() / num_chunks` or +/// `dimensions.get() / num_chunks + 1` and the total covers `[0, dimensions.get()]`. +/// +/// # Panics +/// +/// Panics if `offsets.len() <= 1`. +fn fill_chunk_offsets(dimensions: NonZeroUsize, offsets: &mut [usize]) { + let num_chunks = offsets.len() - 1; + let dimensions = dimensions.get(); let mut chunk_offset: usize = 0; offsets[0] = chunk_offset; - for chunk_index in 0..num_pq_chunks { - chunk_offset += dimensions / num_pq_chunks; - if chunk_index < (dimensions % num_pq_chunks) { + for chunk_index in 0..num_chunks { + chunk_offset += dimensions / num_chunks; + if chunk_index < (dimensions % num_chunks) { chunk_offset += 1; } offsets[chunk_index + 1] = chunk_offset; @@ -490,27 +494,30 @@ mod tests { #[test] fn from_dimensions_happy_path() { + let nz = |x: usize| NonZeroUsize::new(x).unwrap(); + // Even split: 9 / 3 = 3 each. - let offsets = ChunkOffsets::from_dimensions(9, 3).unwrap(); + let offsets = ChunkOffsets::from_dim(nz(9), nz(3)).unwrap(); assert_eq!(offsets.as_slice(), &[0, 3, 6, 9]); assert_eq!(offsets.dim(), 9); assert_eq!(offsets.len(), 3); // Uneven split: 8 / 3 = 2 r 2 -> first two chunks get an extra element. - let offsets = ChunkOffsets::from_dimensions(8, 3).unwrap(); + let offsets = ChunkOffsets::from_dim(nz(8), nz(3)).unwrap(); assert_eq!(offsets.as_slice(), &[0, 3, 6, 8]); // Single chunk degenerate case. - let offsets = ChunkOffsets::from_dimensions(5, 1).unwrap(); + let offsets = ChunkOffsets::from_dim(nz(5), nz(1)).unwrap(); assert_eq!(offsets.as_slice(), &[0, 5]); // dimensions == num_pq_chunks: each chunk is size 1. 
- let offsets = ChunkOffsets::from_dimensions(4, 4).unwrap(); + let offsets = ChunkOffsets::from_dim(nz(4), nz(4)).unwrap(); assert_eq!(offsets.as_slice(), &[0, 1, 2, 3, 4]); - // The view-into variant matches the owning constructor. + // The view-into variant matches the owning constructor; num_chunks is + // inferred from `scratch.len() - 1`. let mut scratch = [0usize; 4]; - let view = ChunkOffsetsView::from_dimensions_into(8, 3, &mut scratch).unwrap(); + let view = ChunkOffsetsView::from_dim_into(nz(8), &mut scratch).unwrap(); assert_eq!(view.as_slice(), &[0, 3, 6, 8]); assert_eq!(view.dim(), 8); assert_eq!(view.len(), 3); @@ -519,46 +526,23 @@ mod tests { #[test] fn from_dimensions_construction_errors() { + let nz = |x: usize| NonZeroUsize::new(x).unwrap(); + // num_pq_chunks > dimensions -> some chunk would be empty -> NonMonotonic. - let err = ChunkOffsets::from_dimensions(3, 5).unwrap_err(); + let err = ChunkOffsets::from_dim(nz(3), nz(5)).unwrap_err(); assert!( matches!(err, ChunkOffsetError::NonMonotonic { .. }), "expected NonMonotonic, got {err:?}" ); - // dimensions == 0 -> trivially non-monotonic. - let err = ChunkOffsets::from_dimensions(0, 1).unwrap_err(); - assert!(matches!(err, ChunkOffsetError::NonMonotonic { .. })); - - // num_pq_chunks == 0 -> length 1 buffer, fails LengthNotAtLeastTwo. - let err = ChunkOffsets::from_dimensions(8, 0).unwrap_err(); + // Scratch length < 2 -> LengthNotAtLeastTwo (cannot infer num_chunks). + let mut too_short = [0usize; 1]; + let err = ChunkOffsetsView::from_dim_into(nz(8), &mut too_short).unwrap_err(); assert!(matches!(err, ChunkOffsetError::LengthNotAtLeastTwo(1))); - // Scratch length too short -> ScratchLengthMismatch error (no panic). 
- let mut too_short = [0usize; 3]; - let err = ChunkOffsetsView::from_dimensions_into(8, 3, &mut too_short).unwrap_err(); - assert!(matches!( - err, - ChunkOffsetError::ScratchLengthMismatch { - expected: 4, - actual: 3 - } - )); - - // Scratch length too long -> ScratchLengthMismatch error. - let mut too_long = [0usize; 5]; - let err = ChunkOffsetsView::from_dimensions_into(8, 3, &mut too_long).unwrap_err(); - assert!(matches!( - err, - ChunkOffsetError::ScratchLengthMismatch { - expected: 4, - actual: 5 - } - )); - // Partition validation errors propagate through the view builder too. let mut scratch = [0usize; 6]; - let err = ChunkOffsetsView::from_dimensions_into(3, 5, &mut scratch).unwrap_err(); + let err = ChunkOffsetsView::from_dim_into(nz(3), &mut scratch).unwrap_err(); assert!(matches!(err, ChunkOffsetError::NonMonotonic { .. })); } From abde0c0e0e48f7f5191a430eeb313f00379690e7 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Mon, 4 May 2026 20:26:23 -0400 Subject: [PATCH 11/13] fmt --- diskann-benchmark/src/backend/exhaustive/product.rs | 6 ++---- diskann-quantization/src/views.rs | 5 +---- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/diskann-benchmark/src/backend/exhaustive/product.rs b/diskann-benchmark/src/backend/exhaustive/product.rs index 542ad853e..0fbf456c9 100644 --- a/diskann-benchmark/src/backend/exhaustive/product.rs +++ b/diskann-benchmark/src/backend/exhaustive/product.rs @@ -87,10 +87,8 @@ mod imp { let dim = std::num::NonZeroUsize::new(data.ncols()) .ok_or_else(|| anyhow::anyhow!("data has zero columns"))?; - let offsets = diskann_quantization::views::ChunkOffsets::from_dim( - dim, - input.num_pq_chunks, - )?; + let offsets = + diskann_quantization::views::ChunkOffsets::from_dim(dim, input.num_pq_chunks)?; let base = { let threadpool = rayon::ThreadPoolBuilder::new() diff --git a/diskann-quantization/src/views.rs b/diskann-quantization/src/views.rs index 827d69137..d0483be8d 100644 --- 
a/diskann-quantization/src/views.rs +++ b/diskann-quantization/src/views.rs @@ -212,10 +212,7 @@ impl ChunkOffsets { /// /// Returns an error if the requested partition is not valid (e.g. /// `num_chunks.get() > dim.get()`). - pub fn from_dim( - dim: NonZeroUsize, - num_chunks: NonZeroUsize, - ) -> Result { + pub fn from_dim(dim: NonZeroUsize, num_chunks: NonZeroUsize) -> Result { let mut offsets = vec![0usize; num_chunks.get() + 1].into_boxed_slice(); fill_chunk_offsets(dim, &mut offsets); Self::new(offsets) From 83563c06977ea765a6a8fecbaf7e75e2b8c1d04d Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Mon, 4 May 2026 23:05:56 -0400 Subject: [PATCH 12/13] inline centroid centering --- .../src/storage/quant/pq/pq_generation.rs | 16 ++- .../src/model/pq/pq_construction.rs | 17 ++- diskann-providers/src/model/pq/views.rs | 8 -- diskann-utils/src/views.rs | 102 ------------------ 4 files changed, 27 insertions(+), 116 deletions(-) diff --git a/diskann-disk/src/storage/quant/pq/pq_generation.rs b/diskann-disk/src/storage/quant/pq/pq_generation.rs index 4b32fab7a..aa0de8a71 100644 --- a/diskann-disk/src/storage/quant/pq/pq_generation.rs +++ b/diskann-disk/src/storage/quant/pq/pq_generation.rs @@ -133,9 +133,19 @@ where ) .bridge_err()?; - full_pivot_data_mat - .broadcast_rows_mut(centroid.as_slice(), |a, b| *a += *b) - .bridge_err()?; + if full_pivot_data_mat.ncols() != centroid.len() { + return Err(ANNError::log_pq_error(format_args!( + "pivot data ncols {} does not match centroid length {}", + full_pivot_data_mat.ncols(), + centroid.len(), + ))); + } + + for row in full_pivot_data_mat.row_iter_mut() { + for (a, b) in std::iter::zip(row.iter_mut(), centroid.iter()) { + *a += *b; + } + } let table = TransposedTable::from_parts( full_pivot_data_mat.as_view(), diff --git a/diskann-providers/src/model/pq/pq_construction.rs b/diskann-providers/src/model/pq/pq_construction.rs index 889223c52..eb96a0509 100644 --- 
a/diskann-providers/src/model/pq/pq_construction.rs +++ b/diskann-providers/src/model/pq/pq_construction.rs @@ -356,9 +356,20 @@ where let mut full_pivot_data_mat = MutMatrixView::try_from(full_pivot_data.as_mut_slice(), num_centers, full_dim) .bridge_err()?; - full_pivot_data_mat - .broadcast_rows_mut(centroid.as_slice(), |a, b| *a += *b) - .bridge_err()?; + + if full_pivot_data_mat.ncols() != centroid.len() { + return Err(ANNError::log_pq_error(format_args!( + "pivot data ncols {} does not match centroid length {}", + full_pivot_data_mat.ncols(), + centroid.len(), + ))); + } + + for row in full_pivot_data_mat.row_iter_mut() { + for (a, b) in std::iter::zip(row.iter_mut(), centroid.iter()) { + *a += *b; + } + } pq_storage.write_compressed_pivot_metadata::( num_points, diff --git a/diskann-providers/src/model/pq/views.rs b/diskann-providers/src/model/pq/views.rs index 0d0782562..3329c31d2 100644 --- a/diskann-providers/src/model/pq/views.rs +++ b/diskann-providers/src/model/pq/views.rs @@ -32,14 +32,6 @@ impl From>> for ANNError { } } -// Compatibility with ANNError. -impl From> for ANNError { - #[track_caller] - fn from(value: Bridge) -> Self { - ANNError::log_pq_error(value.into_inner()) - } -} - /////////// // Tests // /////////// diff --git a/diskann-utils/src/views.rs b/diskann-utils/src/views.rs index 81c9da470..a9352918c 100644 --- a/diskann-utils/src/views.rs +++ b/diskann-utils/src/views.rs @@ -161,15 +161,6 @@ impl TryFromError { } } -/// Error returned when a broadcast vector's length does not match the matrix's column -/// count. -#[derive(Debug, Error)] -#[error("broadcast vector length {broadcast_len} does not match matrix ncols {ncols}")] -pub struct BroadcastLenMismatch { - pub ncols: usize, - pub broadcast_len: usize, -} - /// A generator for initializing the entries in a matrix via `Matrix::new`. 
pub trait Generator { fn generate(&mut self) -> T; @@ -397,32 +388,6 @@ where self.data.as_mut_slice().chunks_exact_mut(ncols) } - /// Apply `op(row[j], &broadcast[j])` to every entry of every row, broadcasting - /// `broadcast` across the rows of the matrix through an `op`. - /// - /// Returns an error if `broadcast.len() != self.ncols()`. The matrix is left - /// unmodified in the error case. - pub fn broadcast_rows_mut( - &mut self, - broadcast: &[U], - mut op: F, - ) -> Result<(), BroadcastLenMismatch> - where - T: MutDenseData, - F: FnMut(&mut T::Elem, &U), - { - if self.ncols() != broadcast.len() { - return Err(BroadcastLenMismatch { - ncols: self.ncols(), - broadcast_len: broadcast.len(), - }); - } - self.row_iter_mut().for_each(|row| { - std::iter::zip(row.iter_mut(), broadcast.iter()).for_each(|(a, b)| op(a, b)); - }); - Ok(()) - } - /// Return an iterator that divides the matrix into sub-matrices with (up to) /// `batchsize` rows with `self.ncols()` columns. /// @@ -2215,71 +2180,4 @@ mod tests { // Verify all elements are 43 assert!(m.as_slice().iter().all(|&x| x == 43)); } - - #[test] - fn test_broadcast_rows_mut() { - // Use the canonical test fixture and broadcast-add a length-3 row across all - // rows. The resulting rows are the original rows offset by [10, 20, 30]. - let data = make_test_matrix(); - let mut m = Matrix::try_from(data.into(), 4, 3).unwrap(); - let bias = [10usize, 20, 30]; - - m.broadcast_rows_mut(&bias, |a, b| *a += *b).unwrap(); - assert_eq!(m.row(0), &[10, 21, 32]); - assert_eq!(m.row(1), &[11, 22, 33]); - assert_eq!(m.row(2), &[12, 23, 34]); - assert_eq!(m.row(3), &[13, 24, 35]); - - // Subtract the same bias to confirm op generality and that we round-trip back to - // the canonical fixture. - m.broadcast_rows_mut(&bias, |a, b| *a -= *b).unwrap(); - test_basic_indexing(&m); - - // Broadcast also works through a `MutMatrixView`. 
- let mut data = make_test_matrix(); - let mut view = MutMatrixView::try_from(data.as_mut_slice(), 4, 3).unwrap(); - view.broadcast_rows_mut(&bias, |a, b| *a += *b).unwrap(); - assert_eq!(view.row(0), &[10, 21, 32]); - assert_eq!(view.row(3), &[13, 24, 35]); - - // Heterogeneous element types: matrix is f32, broadcast is i32. - let mut fmat = Matrix::::new(0.0, 2, 3); - let scale: [i32; 3] = [1, 2, 3]; - fmat.broadcast_rows_mut(&scale, |a, b| *a += *b as f32) - .unwrap(); - assert_eq!(fmat.row(0), &[1.0, 2.0, 3.0]); - assert_eq!(fmat.row(1), &[1.0, 2.0, 3.0]); - - // Single-row degenerate case. - let mut single = Matrix::::new(5, 1, 4); - single - .broadcast_rows_mut(&[1, 2, 3, 4], |a, b| *a += *b) - .unwrap(); - assert_eq!(single.row(0), &[6, 7, 8, 9]); - } - - #[test] - fn test_broadcast_rows_mut_length_mismatch() { - let data = make_test_matrix(); - let mut m = Matrix::try_from(data.clone().into(), 4, 3).unwrap(); - - // Too short. - let err = m.broadcast_rows_mut(&[1, 2], |a, b| *a += *b).unwrap_err(); - assert_eq!( - err.to_string(), - "broadcast vector length 2 does not match matrix ncols 3" - ); - // Matrix must be unmodified after a failed call. - assert_eq!(m.as_slice(), data.as_slice()); - - // Too long. 
- let err = m - .broadcast_rows_mut(&[1, 2, 3, 4], |a, b| *a += *b) - .unwrap_err(); - assert_eq!( - err.to_string(), - "broadcast vector length 4 does not match matrix ncols 3" - ); - assert_eq!(m.as_slice(), data.as_slice()); - } } From 58a398d3c5e64de49a28483a06f09f2bc55b70db Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Mon, 4 May 2026 23:19:44 -0400 Subject: [PATCH 13/13] throw err for chunk offset constr --- .../src/model/pq/pq_construction.rs | 1 + diskann-quantization/src/views.rs | 37 ++++++++++++++++--- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/diskann-providers/src/model/pq/pq_construction.rs b/diskann-providers/src/model/pq/pq_construction.rs index eb96a0509..226b413a6 100644 --- a/diskann-providers/src/model/pq/pq_construction.rs +++ b/diskann-providers/src/model/pq/pq_construction.rs @@ -100,6 +100,7 @@ where .ok_or_else(|| ANNError::log_pq_error("dim must be non-zero"))?; let num_chunks = NonZeroUsize::new(parameters.num_pq_chunks()) .ok_or_else(|| ANNError::log_pq_error("num_pq_chunks must be non-zero"))?; + let chunk_offsets = ChunkOffsets::from_dim(dim, num_chunks).bridge_err()?; let trainer = diskann_quantization::product::train::LightPQTrainingParameters::new( diff --git a/diskann-quantization/src/views.rs b/diskann-quantization/src/views.rs index d0483be8d..3c54af414 100644 --- a/diskann-quantization/src/views.rs +++ b/diskann-quantization/src/views.rs @@ -54,6 +54,8 @@ pub enum ChunkOffsetError { start: usize, next_val: usize, }, + #[error("num_chunks {num_chunks} must not exceed dim {dim}")] + TooManyChunks { num_chunks: usize, dim: usize }, } impl ChunkOffsetsBase @@ -213,6 +215,12 @@ impl ChunkOffsets { /// Returns an error if the requested partition is not valid (e.g. /// `num_chunks.get() > dim.get()`). 
pub fn from_dim(dim: NonZeroUsize, num_chunks: NonZeroUsize) -> Result { + if num_chunks.get() > dim.get() { + return Err(ChunkOffsetError::TooManyChunks { + num_chunks: num_chunks.get(), + dim: dim.get(), + }); + } let mut offsets = vec![0usize; num_chunks.get() + 1].into_boxed_slice(); fill_chunk_offsets(dim, &mut offsets); Self::new(offsets) @@ -234,6 +242,13 @@ impl<'a> ChunkOffsetsView<'a> { if scratch.len() < 2 { return Err(ChunkOffsetError::LengthNotAtLeastTwo(scratch.len())); } + let num_chunks = scratch.len() - 1; + if num_chunks > dim.get() { + return Err(ChunkOffsetError::TooManyChunks { + num_chunks, + dim: dim.get(), + }); + } fill_chunk_offsets(dim, scratch); Self::new(scratch) @@ -525,11 +540,17 @@ mod tests { fn from_dimensions_construction_errors() { let nz = |x: usize| NonZeroUsize::new(x).unwrap(); - // num_pq_chunks > dimensions -> some chunk would be empty -> NonMonotonic. + // num_chunks > dim -> TooManyChunks (caught explicitly before partitioning). let err = ChunkOffsets::from_dim(nz(3), nz(5)).unwrap_err(); assert!( - matches!(err, ChunkOffsetError::NonMonotonic { .. }), - "expected NonMonotonic, got {err:?}" + matches!( + err, + ChunkOffsetError::TooManyChunks { + num_chunks: 5, + dim: 3 + } + ), + "expected TooManyChunks, got {err:?}" ); // Scratch length < 2 -> LengthNotAtLeastTwo (cannot infer num_chunks). @@ -537,10 +558,16 @@ mod tests { let err = ChunkOffsetsView::from_dim_into(nz(8), &mut too_short).unwrap_err(); assert!(matches!(err, ChunkOffsetError::LengthNotAtLeastTwo(1))); - // Partition validation errors propagate through the view builder too. + // num_chunks > dim via the view builder too. let mut scratch = [0usize; 6]; let err = ChunkOffsetsView::from_dim_into(nz(3), &mut scratch).unwrap_err(); - assert!(matches!(err, ChunkOffsetError::NonMonotonic { .. })); + assert!(matches!( + err, + ChunkOffsetError::TooManyChunks { + num_chunks: 5, + dim: 3 + } + )); } ///////////////