diff --git a/diskann-benchmark/src/backend/exhaustive/product.rs b/diskann-benchmark/src/backend/exhaustive/product.rs index 43f55a8ff..c9284d367 100644 --- a/diskann-benchmark/src/backend/exhaustive/product.rs +++ b/diskann-benchmark/src/backend/exhaustive/product.rs @@ -108,7 +108,6 @@ mod imp { let quantizer = diskann_providers::model::pq::FixedChunkPQTable::new( data.ncols(), base.flatten().into(), - vec![0.0; data.ncols()].into(), offsets.into(), )?; diff --git a/diskann-disk/src/build/builder/quantizer.rs b/diskann-disk/src/build/builder/quantizer.rs index c3eac75fc..aca72aec7 100644 --- a/diskann-disk/src/build/builder/quantizer.rs +++ b/diskann-disk/src/build/builder/quantizer.rs @@ -73,7 +73,7 @@ impl BuildQuantizer { PQStorage::new(&pq_paths.pivots, &pq_paths.compressed_data, None); pq_build_storage.write_pivot_data( table.get_pq_table(), - table.get_centroids(), + None, table.get_chunk_offsets(), table.get_num_centers(), table.get_dim(), diff --git a/diskann-disk/src/search/pq/quantizer_preprocess.rs b/diskann-disk/src/search/pq/quantizer_preprocess.rs index cc454ea7b..ab7e72743 100644 --- a/diskann-disk/src/search/pq/quantizer_preprocess.rs +++ b/diskann-disk/src/search/pq/quantizer_preprocess.rs @@ -10,7 +10,6 @@ use diskann_providers::model::compute_pq_distance; use diskann_providers::utils::BridgeErr; use super::{PQData, PQScratch}; -use crate::storage::quant::pq::pq_dataset::PQTable; /// Preprocesses the query vector for PQ distance calculations. 
/// This function rotates the query vector and prepares the PQ table distances @@ -21,64 +20,33 @@ pub fn quantizer_preprocess( metric: Metric, id_to_calculate_pq_distance: &[u32], ) -> ANNResult<()> { - match &pq_data.pq_table() { - PQTable::Transposed(table) => { - let dim = table.dim(); - let expected_len = table.ncenters() * table.nchunks(); - let dst = diskann_utils::views::MutMatrixView::try_from( - &mut (*pq_scratch.aligned_pqtable_dist_scratch)[..expected_len], - table.nchunks(), - table.ncenters(), - ) - .bridge_err()?; + let table = pq_data.pq_table(); + let dim = table.dim(); + let expected_len = table.ncenters() * table.nchunks(); + let dst = diskann_utils::views::MutMatrixView::try_from( + &mut (*pq_scratch.aligned_pqtable_dist_scratch)[..expected_len], + table.nchunks(), + table.ncenters(), + ) + .bridge_err()?; - match metric { - // Prior to the introduction of the `quantizer_preprocess` method, the - // disk index was hard-coded to use L2 distance for comparisons. - // - // We're keeping that behavior here - treating `Cosine` and `CosineNormalized` - // as L2 until a more thorough evaluation can be made. - Metric::L2 | Metric::Cosine | Metric::CosineNormalized => { - table.process_into::( - &pq_scratch.rotated_query[..dim], - dst, - ); - } - Metric::InnerProduct => { - table.process_into::( - &pq_scratch.rotated_query[..dim], - dst, - ); - } - } + match metric { + // Prior to the introduction of the `quantizer_preprocess` method, the + // disk index was hard-coded to use L2 distance for comparisons. + // + // We're keeping that behavior here - treating `Cosine` and `CosineNormalized` + // as L2 until a more thorough evaluation can be made. 
+ Metric::L2 | Metric::Cosine | Metric::CosineNormalized => { + table.process_into::( + &pq_scratch.rotated_query[..dim], + dst, + ); } - PQTable::Fixed(table) => { - match metric { - // Prior to the introduction of the `quantizer_preprocess` method, the - // disk index was hard-coded to use L2 distance for comparisons. - // - // We're keeping that behavior here - treating `Cosine` and `CosineNormalized` - // as L2 until a more thorough evaluation can be made. - Metric::L2 | Metric::Cosine | Metric::CosineNormalized => { - // The scratch only stores the aligned dimension. However, preprocessing - // wants the actual dimension used, so we have to shrink the rotated query - // accordingly. - let dim = table.get_dim(); - table.preprocess_query(&mut pq_scratch.rotated_query[..dim]); - - // Compute the distance between each chunk of the query to each pq centroids. - table.populate_chunk_distances( - pq_scratch.rotated_query.as_slice(), - &mut pq_scratch.aligned_pqtable_dist_scratch, - )?; - } - Metric::InnerProduct => { - table.populate_chunk_inner_products( - pq_scratch.rotated_query.as_slice(), - &mut pq_scratch.aligned_pqtable_dist_scratch, - )?; - } - } + Metric::InnerProduct => { + table.process_into::( + &pq_scratch.rotated_query[..dim], + dst, + ); } } diff --git a/diskann-disk/src/storage/quant/mod.rs b/diskann-disk/src/storage/quant/mod.rs index d6988638c..5652ea1c9 100644 --- a/diskann-disk/src/storage/quant/mod.rs +++ b/diskann-disk/src/storage/quant/mod.rs @@ -8,7 +8,7 @@ pub use generator::{GeneratorContext, QuantDataGenerator}; pub(crate) mod pq; pub use pq::pq_generation::{PQGeneration, PQGenerationContext}; -pub use pq::{PQData, PQTable}; +pub use pq::PQData; mod compressor; pub use compressor::{CompressionStage, QuantCompressor}; diff --git a/diskann-disk/src/storage/quant/pq/mod.rs b/diskann-disk/src/storage/quant/pq/mod.rs index d5661724a..bdda3b4e8 100644 --- a/diskann-disk/src/storage/quant/pq/mod.rs +++ b/diskann-disk/src/storage/quant/pq/mod.rs @@ 
-5,6 +5,5 @@ pub(crate) mod pq_dataset; pub use pq_dataset::PQData; -pub use pq_dataset::PQTable; pub mod pq_generation; diff --git a/diskann-disk/src/storage/quant/pq/pq_dataset.rs b/diskann-disk/src/storage/quant/pq/pq_dataset.rs index 049825a19..4eaf66428 100644 --- a/diskann-disk/src/storage/quant/pq/pq_dataset.rs +++ b/diskann-disk/src/storage/quant/pq/pq_dataset.rs @@ -10,22 +10,10 @@ use diskann_providers::model::FixedChunkPQTable; use diskann_quantization::product::TransposedTable; use diskann_utils::views::Matrix; -/// Behind the scenes, we can use either the [`FixedChunkPQTable`] or a -/// [`diskann_quantization::product::TransposedTable`]. The [`TransposedTable`] is much faster -/// for preprocessing, but does not support removal of the dataset centroid. -/// -/// So, we can only use the [`TransposedTable`] when the dataset centroid -/// is all zero. -#[derive(Debug)] -pub enum PQTable { - Transposed(TransposedTable), - Fixed(FixedChunkPQTable), -} - #[derive(Debug)] pub struct PQData { // pq pivot table. - pq_pivot_table: PQTable, + pq_pivot_table: TransposedTable, // pq compressed vectors, shape `num_points × num_pq_chunks`. pq_compressed_data: Matrix, @@ -36,18 +24,11 @@ impl PQData { pq_pivot_table: FixedChunkPQTable, pq_compressed_data: Matrix, ) -> ANNResult { - // Check if we can use the transposed table. If so, go for it. 
- let centroid_is_zero = pq_pivot_table.get_centroids().iter().all(|i| *i == 0.0); - let pq_pivot_table = if centroid_is_zero { - let transposed = TransposedTable::from_parts( - pq_pivot_table.view_pivots(), - pq_pivot_table.view_offsets().to_owned(), - ) - .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))?; - PQTable::Transposed(transposed) - } else { - PQTable::Fixed(pq_pivot_table) - }; + let pq_pivot_table = TransposedTable::from_parts( + pq_pivot_table.view_pivots(), + pq_pivot_table.view_offsets().to_owned(), + ) + .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))?; Ok(Self { pq_pivot_table, @@ -56,24 +37,18 @@ impl PQData { } /// Get pq_table - pub fn pq_table(&self) -> &PQTable { + pub fn pq_table(&self) -> &TransposedTable { &self.pq_pivot_table } /// Return the number of chunks in the underlying PQ schema. pub fn get_num_chunks(&self) -> usize { - match &self.pq_pivot_table { - PQTable::Transposed(table) => table.nchunks(), - PQTable::Fixed(table) => table.get_num_chunks(), - } + self.pq_pivot_table.nchunks() } /// Return the number of centers in the underlying PQ schema. 
pub fn get_num_centers(&self) -> usize { - match &self.pq_pivot_table { - PQTable::Transposed(table) => table.ncenters(), - PQTable::Fixed(table) => table.get_num_centers(), - } + self.pq_pivot_table.ncenters() } /// Get pq_compressed_data @@ -97,13 +72,8 @@ mod tests { fn create_pq_data() -> ANNResult { let dim = 2; - let pq_pivot_table = FixedChunkPQTable::new( - dim, - Box::new([0.0, 0.0, 1.0, 1.0]), - Box::new([0.0, 0.0]), - Box::new([0, 2]), - ) - .unwrap(); + let pq_pivot_table = + FixedChunkPQTable::new(dim, Box::new([0.0, 0.0, 1.0, 1.0]), Box::new([0, 2])).unwrap(); let pq_compressed_data = Matrix::try_from(Box::new([123u8, 111, 255]) as Box<[u8]>, 3, 1) .expect("valid matrix shape"); diff --git a/diskann-disk/src/storage/quant/pq/pq_generation.rs b/diskann-disk/src/storage/quant/pq/pq_generation.rs index ccd2c30a7..0e7b07fd4 100644 --- a/diskann-disk/src/storage/quant/pq/pq_generation.rs +++ b/diskann-disk/src/storage/quant/pq/pq_generation.rs @@ -100,8 +100,8 @@ where context.num_centers, context.num_chunks, context.max_kmeans_reps, - context.metric == Metric::L2, )?, + context.metric == Metric::L2, &mut train_data, &context.pq_storage, context.storage_provider, @@ -260,15 +260,9 @@ mod pq_generation_tests { let pool = create_thread_pool_for_test(); generate_pq_pivots( - GeneratePivotArguments::new( - ndata, - dim, - num_centers, - num_chunks, - max_k_means_reps, - true, - ) - .unwrap(), + GeneratePivotArguments::new(ndata, dim, num_centers, num_chunks, max_k_means_reps) + .unwrap(), + true, &mut train_data, &pq_storage, &storage_provider, diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index 2a8e24668..0d6da678e 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -71,16 +71,13 @@ pub fn train_pq( model::pq::NUM_PQ_CENTROIDS, num_pq_chunks, 5, - false, )?; - let mut centroid = vec![0.0; dim]; let mut offsets = vec![0; num_pq_chunks + 1]; 
let mut full_pivot_data = vec![0.0; model::pq::NUM_PQ_CENTROIDS * dim]; model::pq::generate_pq_pivots_from_membuf( &pivot_args, data.as_slice(), - &mut centroid, &mut offsets, &mut full_pivot_data, rng, @@ -88,7 +85,7 @@ pub fn train_pq( pool, )?; - model::pq::FixedChunkPQTable::new(dim, full_pivot_data.into(), centroid.into(), offsets.into()) + model::pq::FixedChunkPQTable::new(dim, full_pivot_data.into(), offsets.into()) } pub type MemoryIndex = Arc>>; @@ -1559,13 +1556,8 @@ pub(crate) mod tests { ) .unwrap(); - let pqtable = model::pq::FixedChunkPQTable::new( - dim, - Box::new([0.0, 0.0]), - Box::new([0.0, 0.0]), - Box::new([0, 2]), - ) - .unwrap(); + let pqtable = + model::pq::FixedChunkPQTable::new(dim, Box::new([0.0, 0.0]), Box::new([0, 2])).unwrap(); let index = new_quant_index::(config, parameters, pqtable, TableBasedDeletes).unwrap(); @@ -1680,13 +1672,8 @@ pub(crate) mod tests { ) .unwrap(); - let pqtable = model::pq::FixedChunkPQTable::new( - dim, - Box::new([0.0, 0.0]), - Box::new([0.0, 0.0]), - Box::new([0, 2]), - ) - .unwrap(); + let pqtable = + model::pq::FixedChunkPQTable::new(dim, Box::new([0.0, 0.0]), Box::new([0, 2])).unwrap(); let index = new_quant_index::(config, parameters, pqtable, TableBasedDeletes).unwrap(); diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs index f7289e146..d5baabd4b 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs @@ -2156,7 +2156,7 @@ where let pq_table = &self.quant_vectors.pq_chunk_table; pq_storage.write_pivot_data( pq_table.get_pq_table(), - pq_table.get_centroids(), + None, pq_table.get_chunk_offsets(), NUM_PQ_CENTROIDS, pq_table.get_dim(), @@ -2683,7 +2683,6 @@ mod tests { let pq_table = FixedChunkPQTable::new( dim, vec![0.0; dim * 256].into_boxed_slice(), - vec![0.0; dim].into_boxed_slice(), Box::new([0, 
4, dim]), ) .unwrap(); @@ -2785,11 +2784,6 @@ mod tests { loaded_pq.get_pq_table(), "PQ table data mismatch" ); - assert_eq!( - original_pq.get_centroids(), - loaded_pq.get_centroids(), - "PQ table centroids mismatch" - ); assert_eq!( original_pq.get_chunk_offsets(), loaded_pq.get_chunk_offsets(), @@ -2986,7 +2980,6 @@ mod tests { let pq_table = FixedChunkPQTable::new( dim, vec![0.0; dim * 256].into_boxed_slice(), - vec![0.0; dim].into_boxed_slice(), Box::new([0, 4, dim]), ) .unwrap(); diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/quant_vector_provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/quant_vector_provider.rs index 0d26fa680..11ba43f61 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/quant_vector_provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/quant_vector_provider.rs @@ -270,13 +270,11 @@ mod tests { #[tokio::test] async fn common_errors() { let dim = 5; - let centroid = vec![0.0; dim]; let offsets = vec![0, dim]; let full_pivot_data = vec![0.0; 256 * dim]; let pq_chunk_table = - FixedChunkPQTable::new(dim, full_pivot_data.into(), centroid.into(), offsets.into()) - .unwrap(); + FixedChunkPQTable::new(dim, full_pivot_data.into(), offsets.into()).unwrap(); let bf_tree_config = Config::default(); let provider = @@ -306,7 +304,6 @@ mod tests { let table = FixedChunkPQTable::new( dim, Box::new([0.0, 0.0, 1.0, 1.0, 2.0, 2.0]), - Box::new([0.0, 0.0]), Box::new([0, dim]), ) .unwrap(); @@ -379,12 +376,10 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 5)] async fn test_parallel_tree_traversal() { let dim = 2; - let centroid = vec![0.0; dim]; let offsets = vec![0, dim]; let full_pivot_data = vec![0.0; 256 * dim]; let pq_chunk_table = - FixedChunkPQTable::new(dim, full_pivot_data.into(), centroid.into(), offsets.into()) - .unwrap(); + FixedChunkPQTable::new(dim, full_pivot_data.into(), offsets.into()).unwrap(); let bf_tree_config = 
Config::default(); let provider = Arc::new( diff --git a/diskann-providers/src/model/graph/provider/async_/experimental/multi_pq_async.rs b/diskann-providers/src/model/graph/provider/async_/experimental/multi_pq_async.rs index c77e257d2..1f4fe2242 100644 --- a/diskann-providers/src/model/graph/provider/async_/experimental/multi_pq_async.rs +++ b/diskann-providers/src/model/graph/provider/async_/experimental/multi_pq_async.rs @@ -150,7 +150,6 @@ impl TestMultiPQProviderAsync { &vector_f32, table.get_pq_table(), table.get_num_centers(), - Some(table.get_centroids()), table.get_chunk_offsets(), &mut quant_vector, ) diff --git a/diskann-providers/src/model/graph/provider/async_/fast_memory_quant_vector_provider.rs b/diskann-providers/src/model/graph/provider/async_/fast_memory_quant_vector_provider.rs index 62d9c05e1..857689955 100644 --- a/diskann-providers/src/model/graph/provider/async_/fast_memory_quant_vector_provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/fast_memory_quant_vector_provider.rs @@ -261,7 +261,7 @@ impl FastMemoryQuantVectorProviderAsync { let table = &self.pq_chunk_table; pq_storage.write_pivot_data( table.get_pq_table(), - table.get_centroids(), + None, table.get_chunk_offsets(), table.get_num_centers(), table.get_dim(), @@ -367,13 +367,11 @@ mod tests { #[tokio::test] async fn common_errors() { let dim = 5; - let centroid = vec![0.0; dim]; let offsets = vec![0, dim]; let full_pivot_data = vec![0.0; 256 * dim]; let pq_chunk_table = - FixedChunkPQTable::new(dim, full_pivot_data.into(), centroid.into(), offsets.into()) - .unwrap(); + FixedChunkPQTable::new(dim, full_pivot_data.into(), offsets.into()).unwrap(); let provider = FastMemoryQuantVectorProviderAsync::new(Metric::L2, 10, pq_chunk_table); // try to set an out of bounds vector @@ -400,7 +398,6 @@ mod tests { let table = FixedChunkPQTable::new( dim, Box::new([0.0, 0.0, 1.0, 1.0, 2.0, 2.0]), - Box::new([0.0, 0.0]), Box::new([0, dim]), ) .unwrap(); diff --git 
a/diskann-providers/src/model/graph/provider/async_/inmem/provider.rs b/diskann-providers/src/model/graph/provider/async_/inmem/provider.rs index 9ebe2c82b..81a8554ff 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/provider.rs @@ -156,7 +156,6 @@ use crate::{ /// let table = FixedChunkPQTable::new( /// dim, /// Box::new([0.0, 0.0, 0.0, 0.0]), -/// Box::new([0.0, 0.0, 0.0, 0.0]), /// Box::new([0, dim]), /// ).unwrap(); /// @@ -205,7 +204,6 @@ use crate::{ /// let table = FixedChunkPQTable::new( /// dim, /// Box::new([0.0, 0.0, 0.0, 0.0]), -/// Box::new([0.0, 0.0, 0.0, 0.0]), /// Box::new([0, dim]), /// ).unwrap(); /// diff --git a/diskann-providers/src/model/graph/provider/async_/memory_quant_vector_provider.rs b/diskann-providers/src/model/graph/provider/async_/memory_quant_vector_provider.rs index dca7ea762..1cec0b0ac 100644 --- a/diskann-providers/src/model/graph/provider/async_/memory_quant_vector_provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/memory_quant_vector_provider.rs @@ -195,7 +195,7 @@ impl MemoryQuantVectorProviderAsync { let table = &self.pq_chunk_table; pq_storage.write_pivot_data( table.get_pq_table(), - table.get_centroids(), + None, table.get_chunk_offsets(), table.get_num_centers(), table.get_dim(), @@ -335,7 +335,6 @@ mod tests { let table = FixedChunkPQTable::new( dim, Box::new([0.0, 0.0, 1.0, 1.0, 2.0, 2.0]), - Box::new([0.0, 0.0]), Box::new([0, dim]), ) .unwrap(); diff --git a/diskann-providers/src/model/pq/distance/l2.rs b/diskann-providers/src/model/pq/distance/l2.rs index 543534796..c003fb64a 100644 --- a/diskann-providers/src/model/pq/distance/l2.rs +++ b/diskann-providers/src/model/pq/distance/l2.rs @@ -80,12 +80,7 @@ where // Alignment means that the size of `query` gets increased ... // This makes is VERY hard to do error checking on dimension propagation. 
assert!(self.parent.get_dim() <= query.len()); - let mut local_query: Vec = query.iter().map(|x| (*x).into()).collect(); - - // This function does the following: - // 1. Centers the data (if the centorid is non-zero). - // 2. Applies the OPQ transformation matrix (if it exists). - self.parent.preprocess_query(&mut local_query); + let local_query: Vec = query.iter().map(|x| (*x).into()).collect(); // Compute the partial distances into the lookup-table. self.parent diff --git a/diskann-providers/src/model/pq/distance/test_utils.rs b/diskann-providers/src/model/pq/distance/test_utils.rs index 76c240e1f..0f535eb21 100644 --- a/diskann-providers/src/model/pq/distance/test_utils.rs +++ b/diskann-providers/src/model/pq/distance/test_utils.rs @@ -149,10 +149,7 @@ pub(crate) fn seed_pivot_table(config: TableConfig) -> FixedChunkPQTable { } assert_eq!(pivots.len(), config.dim * config.num_pivots); - - let centroid = vec![0.0f32; config.dim]; - - FixedChunkPQTable::new(config.dim, pivots.into(), centroid.into(), offsets.into()).unwrap() + FixedChunkPQTable::new(config.dim, pivots.into(), offsets.into()).unwrap() } /// Generate a random PQ code spanning the requested number of pivots and chunks. @@ -198,9 +195,7 @@ pub(super) fn test_l2_inner<'a, T, F, R>( { for _ in 0..num_trials { let input: Vec = T::generate(config.dim, rng); - let mut input_f32: Vec = input.iter().map(|x| (*x).into()).collect(); - - table.preprocess_query(&mut input_f32); + let input_f32: Vec = input.iter().map(|x| (*x).into()).collect(); let computer = create(table, &input); for _ in 0..num_trials { diff --git a/diskann-providers/src/model/pq/fixed_chunk_pq_table.rs b/diskann-providers/src/model/pq/fixed_chunk_pq_table.rs index 1f58205cb..8a80ce82c 100644 --- a/diskann-providers/src/model/pq/fixed_chunk_pq_table.rs +++ b/diskann-providers/src/model/pq/fixed_chunk_pq_table.rs @@ -24,9 +24,6 @@ use crate::utils::{Bridge, BridgeErr}; pub struct FixedChunkPQTable { /// The underlying table representation. 
table: BasicTable, - - /// centroid of each dimension - centroids: Box<[f32]>, } // These free functions use internals of the `FixedChunkPQTable`. @@ -109,9 +106,6 @@ impl FixedChunkPQTable { /// /// Refer to the later section for the expected layout of this table. /// - /// * `centroids`: The dimension-wise mean of the training data. The slice underlying - /// this representation must have length `dim`. - /// /// * `chunk_offsets`: A vector marking the beginning and end of each chunk. That is, /// the offsets of the start of chunk `i` is `chunk_offsets[i]` and the end is /// `chunk_offsets[i+1]`. @@ -135,12 +129,7 @@ impl FixedChunkPQTable { /// ... | ... | ... | ... | .... | ... | /// pivot K | cK00 cK01 ... | cK10 cK11 ... | cK20 cK21 ... | .... | ... | /// ``` - pub fn new( - dim: usize, - pq_table: Box<[f32]>, - centroids: Box<[f32]>, - chunk_offsets: Box<[usize]>, - ) -> ANNResult { + pub fn new(dim: usize, pq_table: Box<[f32]>, chunk_offsets: Box<[usize]>) -> ANNResult { let len = pq_table.len(); let table = BasicTable::new( MatrixBase::try_from(pq_table, len / dim, dim).bridge_err()?, @@ -148,15 +137,7 @@ impl FixedChunkPQTable { ) .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))?; - if centroids.len() != dim { - return Err(ANNError::log_pq_error(format_args!( - "centroids slice has length {} but the expected dim is {}", - centroids.len(), - dim - ))); - } - - Ok(Self { table, centroids }) + Ok(Self { table }) } /// Get chunk number. @@ -164,14 +145,6 @@ impl FixedChunkPQTable { self.table.nchunks() } - /// Shifting the query according to mean or the whole corpus. The output is a rotated query vector, - /// which is later used to calculate the distance between each query chunk and each centroid using populate_chunk_distances. 
- pub fn preprocess_query(&self, rotated_query_vec: &mut [f32]) { - for (query, ¢roid) in rotated_query_vec.iter_mut().zip(self.centroids.iter()) { - *query -= centroid; - } - } - pub fn populate_chunk_distances_impl( &self, rotated_query_vec: &[f32], @@ -394,8 +367,8 @@ impl FixedChunkPQTable { /// # Panics /// /// Panics under the following condition: + /// /// * `base_vec.length() != self.get_dim()`. - /// * Any entry in `base_vec` exceeds `self.get_centroids()`. pub fn inflate_vector(&self, base_vec: &[u8]) -> Vec { let mut out_vec: Vec = vec![0.0; self.get_dim()]; self.inflate_vector_into(base_vec, &mut out_vec); @@ -417,10 +390,7 @@ impl FixedChunkPQTable { let stop = chunk_offsets[i + 1]; let out_slice = &mut out[start..stop]; let pivot = &pq_table[(dim * b + start)..(dim * b + stop)]; - let centroid = &self.centroids[start..stop]; - std::iter::zip(out_slice.iter_mut(), pivot.iter()) - .zip(centroid.iter()) - .for_each(|((o, p), c)| *o = *p + *c); + std::iter::zip(out_slice.iter_mut(), pivot.iter()).for_each(|(o, p)| *o = *p); }); } @@ -439,11 +409,6 @@ impl FixedChunkPQTable { self.table.view_offsets().into() } - /// Returns an immutable reference to the `centroids`. - pub fn get_centroids(&self) -> &[f32] { - &self.centroids - } - /// Returns the original dimension of the vectors. pub fn get_dim(&self) -> usize { self.table.dim() @@ -482,13 +447,8 @@ where /// Internally, this calls [`diskann_quantization::product::BasicTable::compress_into`]. /// See the documentation for that method about the failure modes for this function. 
fn compress_into(&self, from: &[T], to: &mut [u8]) -> Result<(), Self::Error> { - let translated: Vec = std::iter::zip(from.iter(), self.centroids.iter()) - .map(|(f, c)| { - let f: f32 = (*f).into(); - f - *c - }) - .collect(); - self.table.compress_into(&*translated, to) + let converted: Box<[f32]> = from.iter().map(|&v| v.into()).collect(); + self.table.compress_into(&*converted, to) } } @@ -727,70 +687,58 @@ mod fixed_chunk_pq_table_test { #[test] fn constructor_errors() { // Test that we verify all the requirements in the constructor. - type PreSchema = (usize, Box<[f32]>, Box<[f32]>, Box<[usize]>); + type PreSchema = (usize, Box<[f32]>, Box<[usize]>); fn create_valid_schema() -> PreSchema { let dim = 5; - ( - dim, - vec![0.0; dim * 4].into(), - vec![0.0; dim].into(), - Box::new([0, 2, 3, dim]), - ) + (dim, vec![0.0; dim * 4].into(), Box::new([0, 2, 3, dim])) } // Check that our valid schema is indeed valid. { - let (dim, pq_table, centroids, chunk_offsets) = create_valid_schema(); - assert!(FixedChunkPQTable::new(dim, pq_table, centroids, chunk_offsets).is_ok()); + let (dim, pq_table, chunk_offsets) = create_valid_schema(); + assert!(FixedChunkPQTable::new(dim, pq_table, chunk_offsets).is_ok()); } // `pq_table` length not evenly divisible by `dim`.. { - let (dim, _, centroids, chunk_offsets) = create_valid_schema(); + let (dim, _, chunk_offsets) = create_valid_schema(); let pq_table = vec![0.0; dim * 3 + 1].into(); - assert!(FixedChunkPQTable::new(dim, pq_table, centroids, chunk_offsets).is_err()); - } - - // `centroids` length not equal to `dim`.. - { - let (dim, pq_table, _, chunk_offsets) = create_valid_schema(); - let centroids = vec![0.0; dim - 1].into(); - assert!(FixedChunkPQTable::new(dim, pq_table, centroids, chunk_offsets).is_err()); + assert!(FixedChunkPQTable::new(dim, pq_table, chunk_offsets).is_err()); } // `offsets` does not begin at zero. 
{ - let (dim, pq_table, centroids, _) = create_valid_schema(); + let (dim, pq_table, _) = create_valid_schema(); let chunk_offsets = Box::new([1, 2, dim]); - assert!(FixedChunkPQTable::new(dim, pq_table, centroids, chunk_offsets).is_err()); + assert!(FixedChunkPQTable::new(dim, pq_table, chunk_offsets).is_err()); } // `offsets` empty { - let (dim, pq_table, centroids, _) = create_valid_schema(); + let (dim, pq_table, _) = create_valid_schema(); let chunk_offsets = Box::new([]); - assert!(FixedChunkPQTable::new(dim, pq_table, centroids, chunk_offsets).is_err()); + assert!(FixedChunkPQTable::new(dim, pq_table, chunk_offsets).is_err()); } // `offsets` has length 1. { - let (dim, pq_table, centroids, _) = create_valid_schema(); + let (dim, pq_table, _) = create_valid_schema(); let chunk_offsets = Box::new([0]); - assert!(FixedChunkPQTable::new(dim, pq_table, centroids, chunk_offsets).is_err()); + assert!(FixedChunkPQTable::new(dim, pq_table, chunk_offsets).is_err()); } // `offsets` not strictly monotonic. { - let (dim, pq_table, centroids, _) = create_valid_schema(); + let (dim, pq_table, _) = create_valid_schema(); let chunk_offsets = Box::new([0, 1, 2, 2, dim]); - assert!(FixedChunkPQTable::new(dim, pq_table, centroids, chunk_offsets).is_err()); + assert!(FixedChunkPQTable::new(dim, pq_table, chunk_offsets).is_err()); } // `offsets` does not end at `dim`. 
{ - let (dim, pq_table, centroids, _) = create_valid_schema(); + let (dim, pq_table, _) = create_valid_schema(); let chunk_offsets = Box::new([0, 1, 2, dim, dim + 1]); - assert!(FixedChunkPQTable::new(dim, pq_table, centroids, chunk_offsets).is_err()); + assert!(FixedChunkPQTable::new(dim, pq_table, chunk_offsets).is_err()); } } @@ -859,100 +807,39 @@ mod fixed_chunk_pq_table_test { fn load_pivot_test() { let storage_provider = VirtualStorageProvider::new_overlay(test_data_root()); let pq_pivots_path: &str = "/sift/siftsmall_learn_pq_pivots.bin"; - let (dim, pq_table, centroids, chunk_offsets) = + let (dim, pq_table, chunk_offsets) = load_pq_pivots_bin(pq_pivots_path, &1, &storage_provider).unwrap(); let fixed_chunk_pq_table = - FixedChunkPQTable::new(dim, pq_table.into(), centroids.into(), chunk_offsets.into()) - .unwrap(); + FixedChunkPQTable::new(dim, pq_table.into(), chunk_offsets.into()).unwrap(); assert_eq!(dim, DIM); assert_eq!(fixed_chunk_pq_table.table.dim(), DIM); assert_eq!(fixed_chunk_pq_table.table.ncenters(), NUM_PQ_CENTROIDS); - assert_eq!(fixed_chunk_pq_table.centroids.len(), DIM); assert_eq!(fixed_chunk_pq_table.get_chunk_offsets(), &[0, DIM]); } - #[test] - fn clone_pivot_table() { - let dim = 128; - let num_pq_centroids = 4; - let pq_table = vec![1.0; dim * num_pq_centroids]; - let centroids = vec![1.0; dim]; - let chunk_offsets = vec![0, 7, 9, 11, 22, 34, 78, dim]; - - let base = - FixedChunkPQTable::new(dim, pq_table.into(), centroids.into(), chunk_offsets.into()) - .unwrap(); - - let clone = base.clone(); - let FixedChunkPQTable { table, centroids } = clone; - - assert_eq!(table.view_pivots(), base.table.view_pivots()); - assert_eq!(table.view_offsets(), base.table.view_offsets()); - assert_eq!(centroids, base.centroids); - } - #[test] fn get_num_chunks_test() { let num_chunks = 7; let pa_table = vec![0.0; DIM * NUM_PQ_CENTROIDS]; - let centroids = vec![0.0; DIM]; let chunk_offsets = vec![0, 7, 9, 11, 22, 34, 78, 128]; let fixed_chunk_pq_table 
= - FixedChunkPQTable::new(DIM, pa_table.into(), centroids.into(), chunk_offsets.into()) - .unwrap(); + FixedChunkPQTable::new(DIM, pa_table.into(), chunk_offsets.into()).unwrap(); let chunk: usize = fixed_chunk_pq_table.get_num_chunks(); assert_eq!(chunk, num_chunks); } - #[test] - fn preprocess_query_test() { - let storage_provider = VirtualStorageProvider::new_overlay(test_data_root()); - - let pq_pivots_path: &str = "/sift/siftsmall_learn_pq_pivots.bin"; - let (dim, pq_table, centroids, chunk_offsets) = - load_pq_pivots_bin(pq_pivots_path, &1, &storage_provider).unwrap(); - let fixed_chunk_pq_table = - FixedChunkPQTable::new(dim, pq_table.into(), centroids.into(), chunk_offsets.into()) - .unwrap(); - - let mut query_vec: Vec = vec![ - 32.39f32, 78.57f32, 50.32f32, 80.46f32, 6.47f32, 69.76f32, 94.2f32, 83.36f32, 5.8f32, - 68.78f32, 42.32f32, 61.77f32, 90.26f32, 60.41f32, 3.86f32, 61.21f32, 16.6f32, 54.46f32, - 7.29f32, 54.24f32, 92.49f32, 30.18f32, 65.36f32, 99.09f32, 3.8f32, 36.4f32, 86.72f32, - 65.18f32, 29.87f32, 62.21f32, 58.32f32, 43.23f32, 94.3f32, 79.61f32, 39.67f32, - 11.18f32, 48.88f32, 38.19f32, 93.95f32, 10.46f32, 36.7f32, 14.75f32, 81.64f32, - 59.18f32, 99.03f32, 74.23f32, 1.26f32, 82.69f32, 35.7f32, 38.39f32, 46.17f32, 64.75f32, - 7.15f32, 36.55f32, 77.32f32, 18.65f32, 32.8f32, 74.84f32, 18.12f32, 20.19f32, 70.06f32, - 48.37f32, 40.18f32, 45.69f32, 88.3f32, 39.15f32, 60.97f32, 71.29f32, 61.79f32, - 47.23f32, 94.71f32, 58.04f32, 52.4f32, 34.66f32, 59.1f32, 47.11f32, 30.2f32, 58.72f32, - 74.35f32, 83.68f32, 66.8f32, 28.57f32, 29.45f32, 52.02f32, 91.95f32, 92.44f32, - 65.25f32, 38.3f32, 35.6f32, 41.67f32, 91.33f32, 76.81f32, 74.88f32, 33.17f32, 48.36f32, - 41.42f32, 23f32, 8.31f32, 81.69f32, 80.08f32, 50.55f32, 54.46f32, 23.79f32, 43.46f32, - 84.5f32, 10.42f32, 29.51f32, 19.73f32, 46.48f32, 35.01f32, 52.3f32, 66.97f32, 4.8f32, - 74.81f32, 2.82f32, 61.82f32, 25.06f32, 17.3f32, 17.29f32, 63.2f32, 64.1f32, 61.68f32, - 37.42f32, 3.39f32, 97.45f32, 5.32f32, 
59.02f32, 35.6f32, - ]; - fixed_chunk_pq_table.preprocess_query(&mut query_vec); - assert_eq!(query_vec[0], 32.39f32 - fixed_chunk_pq_table.centroids[0]); - assert_eq!( - query_vec[127], - 35.6f32 - fixed_chunk_pq_table.centroids[127] - ); - } - #[test] fn calculate_distances_tests() { let storage_provider = VirtualStorageProvider::new_overlay(test_data_root()); let pq_pivots_path: &str = "/sift/siftsmall_learn_pq_pivots.bin"; - let (dim, pq_table, centroids, chunk_offsets) = + let (dim, pq_table, chunk_offsets) = load_pq_pivots_bin(pq_pivots_path, &1, &storage_provider).unwrap(); let fixed_chunk_pq_table = - FixedChunkPQTable::new(dim, pq_table.into(), centroids.into(), chunk_offsets.into()) - .unwrap(); + FixedChunkPQTable::new(dim, pq_table.into(), chunk_offsets.into()).unwrap(); let query_vec: Vec = vec![ 32.39f32, 78.57f32, 50.32f32, 80.46f32, 6.47f32, 69.76f32, 94.2f32, 83.36f32, 5.8f32, @@ -1010,18 +897,9 @@ mod fixed_chunk_pq_table_test { // inflate_vector_test let inflate_vector = fixed_chunk_pq_table.inflate_vector(&base_vec); assert_eq!(inflate_vector.len(), DIM); - assert_eq!( - inflate_vector[0], - pivots[(3, 0)] + fixed_chunk_pq_table.centroids[0] - ); - assert_eq!( - inflate_vector[1], - pivots[(3, 1)] + fixed_chunk_pq_table.centroids[1] - ); - assert_eq!( - inflate_vector[127], - pivots[(3, 127)] + fixed_chunk_pq_table.centroids[127] - ); + assert_eq!(inflate_vector[0], pivots[(3, 0)]); + assert_eq!(inflate_vector[1], pivots[(3, 1)]); + assert_eq!(inflate_vector[127], pivots[(3, 127)]); } #[test] @@ -1032,7 +910,6 @@ mod fixed_chunk_pq_table_test { let num_centers = 3; let dim = 11; let offsets = vec![0, 4, 8, dim]; - let centroid = vec![0.0; dim]; let pq_pivots_pre = vec![ vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0], @@ -1062,8 +939,7 @@ mod fixed_chunk_pq_table_test { acc }); - let table = - FixedChunkPQTable::new(dim, pq_table.into(), centroid.into(), offsets.into()).unwrap(); + let table = FixedChunkPQTable::new(dim, pq_table.into(), 
offsets.into()).unwrap(); let max_relative: f32 = 1.0e-7; let range: Range = 0..(num_centers as u8); @@ -1112,7 +988,7 @@ mod fixed_chunk_pq_table_test { } } - type LoadPQPivotResult = (usize, Vec, Vec, Vec); + type LoadPQPivotResult = (usize, Vec, Vec); fn load_pq_pivots_bin( pq_pivots_path: &str, num_pq_chunks: &usize, @@ -1134,7 +1010,7 @@ mod fixed_chunk_pq_table_test { } let file_offset_data = offsets.map(|x| x.into_usize()); - let pivots = read_bin_from::(&mut reader, file_offset_data[(0, 0)])?; + let mut pivots = read_bin_from::(&mut reader, file_offset_data[(0, 0)])?; if pivots.nrows() != NUM_PQ_CENTROIDS { return Err(ANNError::log_pq_error(format_args!( @@ -1158,6 +1034,10 @@ mod fixed_chunk_pq_table_test { ))); } + pivots.row_iter_mut().for_each(|row| { + std::iter::zip(row.iter_mut(), centroids.as_slice().iter()).for_each(|(p, c)| *p += *c); + }); + let chunk_offsets_m = read_bin_from::(&mut reader, file_offset_data[(2, 0)])?; if chunk_offsets_m.nrows() != num_pq_chunks + 1 || chunk_offsets_m.ncols() != 1 { return Err(ANNError::log_pq_error(format_args!( @@ -1173,7 +1053,6 @@ mod fixed_chunk_pq_table_test { Ok(( dim, pivots.into_inner().into_vec(), - centroids.into_inner().into_vec(), chunk_offsets.into_inner().into_vec(), )) } @@ -1186,15 +1065,9 @@ mod fixed_chunk_pq_table_test { let mut rng = crate::utils::create_rnd_in_tests(); let pq_table: Vec = (0..NUM_PQ_CENTROIDS * dim).map(|_| rng.random()).collect(); - let centroids: Vec = (0..dim).map(|_| rng.random()).collect(); let chunk_offsets = vec![0, 8]; - let fixed_chunk_pq_table = FixedChunkPQTable::new( - dim, - pq_table.into(), - centroids.into(), - chunk_offsets.clone().into(), - ) - .unwrap(); + let fixed_chunk_pq_table = + FixedChunkPQTable::new(dim, pq_table.into(), chunk_offsets.clone().into()).unwrap(); let rotated_query_vec: Vec = (0..dim).map(|_| rng.random()).collect(); let mut aligned_pq_table_dist_scratch = vec![0.0; num_pq_chunks * NUM_PQ_CENTROIDS]; @@ -1222,11 +1095,8 @@ mod 
fixed_chunk_pq_table_test { fn test_populate_chunk_distances_invalid_input() { let dim = 6; let pq_table = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; - let centroids = vec![0.0; dim]; let chunk_offsets = vec![0, 2, 4, 6]; - let pq_table = - FixedChunkPQTable::new(dim, pq_table.into(), centroids.into(), chunk_offsets.into()) - .unwrap(); + let pq_table = FixedChunkPQTable::new(dim, pq_table.into(), chunk_offsets.into()).unwrap(); let mut aligned_pq_table_dist_scratch = [0.0; 2]; let rotated_query_vec = vec![0.0; dim]; diff --git a/diskann-providers/src/model/pq/generate_pivot_arguments.rs b/diskann-providers/src/model/pq/generate_pivot_arguments.rs index 752656a9f..a70252f8a 100644 --- a/diskann-providers/src/model/pq/generate_pivot_arguments.rs +++ b/diskann-providers/src/model/pq/generate_pivot_arguments.rs @@ -32,10 +32,6 @@ use thiserror::Error; /// /// * `max_k_means_reps` - The maximum number of iterations for the k-means clustering algorithm. /// Increasing this value can improve clustering quality at the cost of additional computation time. -/// -/// * `translate_to_center` - A boolean flag indicating whether the data should be translated -/// (centered) to the origin before clustering. Centering can improve clustering performance by -/// reducing variance caused by the global offset of the data. 
#[derive(Debug, Clone)] pub struct GeneratePivotArguments { num_train: usize, @@ -43,7 +39,6 @@ pub struct GeneratePivotArguments { num_centers: usize, num_pq_chunks: usize, max_k_means_reps: usize, - translate_to_center: bool, } #[derive(Error, Debug, PartialEq)] @@ -79,7 +74,6 @@ impl GeneratePivotArguments { num_centers: usize, num_pq_chunks: usize, max_k_means_reps: usize, - translate_to_center: bool, ) -> Result { if num_pq_chunks > dim { return Err(GeneratePivotArgumentsError::NumChunksMoreThanDim { num_pq_chunks, dim }); @@ -107,7 +101,6 @@ impl GeneratePivotArguments { num_centers, num_pq_chunks, max_k_means_reps, - translate_to_center, }) } @@ -135,11 +128,6 @@ impl GeneratePivotArguments { pub fn max_k_means_reps(&self) -> usize { self.max_k_means_reps } - - /// Get whether to translate to center - pub fn translate_to_center(&self) -> bool { - self.translate_to_center - } } #[cfg(test)] @@ -162,7 +150,6 @@ mod arguments_test { num_centers, num_pq_chunks, max_k_means_reps, - true, ); assert!(result.is_err()); @@ -186,7 +173,6 @@ mod arguments_test { num_centers, num_pq_chunks, max_k_means_reps, - true, ); assert!(result.is_err()); @@ -210,7 +196,6 @@ mod arguments_test { num_centers, num_pq_chunks, max_k_means_reps, - true, ); assert!(result.is_err()); @@ -234,7 +219,6 @@ mod arguments_test { num_centers, num_pq_chunks, max_k_means_reps, - true, ); assert!(result.is_err()); diff --git a/diskann-providers/src/model/pq/pq_construction.rs b/diskann-providers/src/model/pq/pq_construction.rs index 2862d7e26..68d71f7b4 100644 --- a/diskann-providers/src/model/pq/pq_construction.rs +++ b/diskann-providers/src/model/pq/pq_construction.rs @@ -63,8 +63,16 @@ where /// k-means in each chunk to compute the PQ pivots and stores in bin format in /// file pq_pivots_path as a s num_centers*dim floating point binary file /// PQ pivot table layout: {pivot offsets data: METADATA_SIZE}{pivot vector:[dim; num_centroid]}{centroid vector:[dim; 1]}{chunk offsets:[chunk_num+1; 1]} 
+/// +/// Argument `legacy_center_data` will center the provided data by the dataset mean. +/// This is to supply backwards compatibility with some `diskann-disk` tests that used this +/// feature and require exact reproducibility in some tests. +/// +/// This argument should **only** be used if the distance metric being used is L2. Otherwise +/// any computed distance on the resulting PQ compressed data will be incorrect. pub fn generate_pq_pivots( parameters: GeneratePivotArguments, + legacy_center_data: bool, train_data: &mut [f32], pq_storage: &PQStorage, storage_provider: &Storage, @@ -85,7 +93,7 @@ where } let mut centroid: Vec = vec![0.0; parameters.dim()]; - if parameters.translate_to_center() { + if legacy_center_data { move_train_data_by_centroid( train_data, parameters.num_train(), @@ -124,7 +132,7 @@ where pq_storage.write_pivot_data( &full_pivot_data, - ¢roid, + Some(¢roid), &chunk_offsets, parameters.num_centers(), parameters.dim(), @@ -151,7 +159,6 @@ where pub fn generate_pq_pivots_from_membuf>( parameters: &GeneratePivotArguments, train_data_slice: &[T], - centroid: &mut [f32], offsets: &mut [usize], full_pivot_data: &mut [f32], rng: &mut (impl Rng + ?Sized), @@ -164,12 +171,6 @@ pub fn generate_pq_pivots_from_membuf>( )); } - if centroid.len() != parameters.dim() { - return Err(ANNError::log_pq_error( - "Error: centroid size is not equal to dim.", - )); - } - if offsets.len() != parameters.num_pq_chunks() + 1 { return Err(ANNError::log_pq_error( "Error: invalid offsets buffer input size.", @@ -183,25 +184,11 @@ pub fn generate_pq_pivots_from_membuf>( } // Convert train_data to f32 - let mut train_data = train_data_slice + let train_data = train_data_slice .iter() .map(|x| (*x).into()) .collect::>(); - // Calculate the centroid if needed and move the train_data to the centroid - if parameters.translate_to_center() { - move_train_data_by_centroid( - &mut train_data, - parameters.num_train(), - parameters.dim(), - centroid, - ); - } else { - for val 
in centroid.iter_mut() { - *val = 0.0; - } - } - // Calculate the chunk offsets calculate_chunk_offsets(parameters.dim(), parameters.num_pq_chunks(), offsets); @@ -555,7 +542,6 @@ pub fn generate_pq_data_from_pivots_from_membuf>( vector_data: &[T], pivot_data: &[f32], num_pivots: usize, - centroid: Option<&[f32]>, offsets: &[usize], pq_out: &mut [u8], ) -> ANNResult<()> { @@ -573,26 +559,11 @@ pub fn generate_pq_data_from_pivots_from_membuf>( ) .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))?; - let mut data = vector_data + let data = vector_data .iter() .map(|x| (*x).into()) .collect::>(); - // Validate centroid dimensionality is correct (if provided). - // Furthermore, if the centroid is provided, use it to adjust our local copy of the - // data. - centroid.map_or(Ok(()), |centroid_unwrapped| -> ANNResult<()> { - if centroid_unwrapped.len() != vector_data.len() { - return Err(ANNError::log_pq_error( - "Error: centroids vector size does not match dimension!", - )); - } - for (dim_index, item) in data.iter_mut().enumerate() { - *item -= centroid_unwrapped[dim_index]; - } - Ok(()) - })?; - table .compress_into(data.as_slice(), pq_out) .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err))) @@ -610,7 +581,6 @@ pub fn generate_pq_data_from_pivots_from_membuf_batch parameters: &GeneratePivotArguments, vector_data: &[T], pivot_data: &[f32], - centroid: &[f32], offsets: &[usize], pq_out: &mut [u8], pool: RayonThreadPoolRef<'_>, @@ -633,8 +603,6 @@ pub fn generate_pq_data_from_pivots_from_membuf_batch "Error: Invalid PQ buffer input size.", )); } - let translate_to_center = parameters.translate_to_center(); - let centroid_option: Option<&[f32]> = translate_to_center.then_some(centroid); pq_out .par_chunks_mut(num_pq_chunks) @@ -644,7 +612,6 @@ pub fn generate_pq_data_from_pivots_from_membuf_batch vector_slice, pivot_data, parameters.num_centers(), - centroid_option, offsets, pq_slice, ) @@ -706,7 +673,8 @@ mod 
pq_test { ]; let pool = create_thread_pool_for_test(); generate_pq_pivots( - GeneratePivotArguments::new(5, 8, 2, 2, 5, true).unwrap(), + GeneratePivotArguments::new(5, 8, 2, 2, 5).unwrap(), + true, &mut train_data, &pq_storage, &storage_provider, @@ -747,11 +715,9 @@ mod pq_test { } #[rstest] - #[case(false, 2)] - #[case(true, 2)] - #[case(false, 3)] - #[case(true, 3)] - fn generate_pq_pivots_membuf_test(#[case] make_zero_mean: bool, #[case] num_pq_chunks: usize) { + #[case(2)] + #[case(3)] + fn generate_pq_pivots_membuf_test(#[case] num_pq_chunks: usize) { let num_train = 5; let dim = 8; let num_centers = 2; @@ -764,21 +730,11 @@ mod pq_test { ]; let mut full_pivot_data: Vec = vec![0.0; num_centers * dim]; - let mut centroids: Vec = vec![0.0; dim]; let mut offsets: Vec = vec![0; num_pq_chunks + 1]; let pool = create_thread_pool_for_test(); let result = generate_pq_pivots_from_membuf( - &GeneratePivotArguments::new( - num_train, - dim, - num_centers, - num_pq_chunks, - 5, - make_zero_mean, - ) - .unwrap(), + &GeneratePivotArguments::new(num_train, dim, num_centers, num_pq_chunks, 5).unwrap(), &train_data, // train_data - &mut centroids, &mut offsets, &mut full_pivot_data, &mut crate::utils::create_rnd_in_tests(), @@ -813,9 +769,9 @@ mod pq_test { num_centers, num_pq_chunks, max_k_means_reps, - true, ) .unwrap(), + true, &mut train_data, &pq_storage, &storage_provider, @@ -857,7 +813,8 @@ mod pq_test { PQStorage::new(pq_pivots_path, pq_compressed_vectors_path, Some(data_file)); let pool = create_thread_pool_for_test(); generate_pq_pivots( - GeneratePivotArguments::new(5, 8, 2, 2, 5, true).unwrap(), + GeneratePivotArguments::new(5, 8, 2, 2, 5).unwrap(), + true, &mut train_data, &pq_storage, &storage_provider, @@ -892,14 +849,9 @@ mod pq_test { } #[rstest] - #[case(false, 2)] - #[case(true, 2)] - #[case(false, 3)] - #[case(true, 3)] - fn generate_pq_data_from_pivots_membuf_test( - #[case] make_zero_mean: bool, - #[case] num_pq_chunks: usize, - ) { + #[case(2)] + 
#[case(3)] + fn generate_pq_data_from_pivots_membuf_test(#[case] num_pq_chunks: usize) { let num_train: usize = 5; let dim: usize = 8; let num_centers: usize = 2; @@ -912,7 +864,6 @@ mod pq_test { 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, ]; - let mut centroids: Vec = vec![f32::MAX; dim]; let mut offsets: Vec = vec![usize::MAX; num_pq_chunks + 1]; let mut pivot_data: Vec = vec![f32::MAX; num_centers * dim]; let pool = create_thread_pool_for_test(); @@ -923,11 +874,9 @@ mod pq_test { num_centers, num_pq_chunks, max_k_means_reps, - make_zero_mean, ) .unwrap(), &train_data, - &mut centroids, &mut offsets, &mut pivot_data, &mut crate::utils::create_rnd_in_tests(), @@ -942,18 +891,12 @@ mod pq_test { &train_data[dim * i..dim * (i + 1)], &pivot_data, num_centers, - make_zero_mean.then_some(¢roids), &offsets, &mut pq, ) .unwrap(); } - // Check if any value is equal to max - assert!( - !centroids.contains(&f32::MAX), - "centroids contains max value!" - ); assert!( !offsets.contains(&usize::MAX), "offsets contains max value!" @@ -962,22 +905,15 @@ mod pq_test { !pivot_data.contains(&f32::MAX), "pivot_data contains max value!" ); - - if !make_zero_mean { - assert!( - centroids.iter().all(|&x| x == 0.0), - "centroids is not all 0" - ); - } } #[rstest] - #[case(true, 16)] - #[case(true, 32)] - #[case(true, 17)] - #[case(true, 13)] + #[case(false, 16)] + #[case(false, 32)] + #[case(false, 17)] + #[case(false, 13)] fn verify_identical_results_for_membuf_api( - #[case] make_zero_mean: bool, + #[case] legacy_center_data: bool, #[case] num_pq_chunks: usize, ) { // Creates a new filesystem using a read/write MemoryFS with PhysicalFS as a fall-back read-only filesystem. 
@@ -1006,9 +942,9 @@ mod pq_test { NUM_PQ_CENTROIDS, num_pq_chunks, NUM_KMEANS_REPS_PQ, - false, ) .expect("Failed to create pivot parameters"), + legacy_center_data, &mut full_data_vector, &pq_storage, &storage_provider, @@ -1053,7 +989,6 @@ mod pq_test { &full_data_vector[train_dim * i..train_dim * (i + 1)], &full_pivot_data, NUM_PQ_CENTROIDS, - make_zero_mean.then_some(¢roid), &offsets, membuf_slice, ) @@ -1150,7 +1085,6 @@ mod pq_test { RandGenStrategy::RandDivByRand )] rand_strategy: RandGenStrategy, - #[values(false, true)] make_zero_mean: bool, #[values(256)] npts: usize, #[case] dim: usize, #[case] num_pq_chunks: usize, @@ -1185,7 +1119,6 @@ mod pq_test { // Generate pivot data let mut full_pivot_data: Vec = vec![0.0; NUM_PQ_CENTROIDS * dim]; - let mut centroids: Vec = vec![0.0; dim]; let mut offsets: Vec = vec![0; num_pq_chunks + 1]; let pool = create_thread_pool_for_test(); let result = generate_pq_pivots_from_membuf( @@ -1195,11 +1128,9 @@ mod pq_test { NUM_PQ_CENTROIDS, num_pq_chunks, crate::model::pq::pq_construction::NUM_KMEANS_REPS_PQ, - make_zero_mean, ) .unwrap(), &full_data_vector, - &mut centroids, &mut offsets, &mut full_pivot_data, &mut crate::utils::create_rnd_in_tests(), @@ -1214,7 +1145,6 @@ mod pq_test { &full_data_vector[(dim * i)..(dim * (i + 1))], &full_pivot_data, NUM_PQ_CENTROIDS, - make_zero_mean.then_some(¢roids), &offsets, &mut membuf_pq_data, ); @@ -1336,7 +1266,6 @@ mod pq_test { // Generate pivot data let mut full_pivot_data: Vec = vec![0.0; NUM_PQ_CENTROIDS * train_dim]; - let mut centroid: Vec = vec![0.0; train_dim]; let mut offsets: Vec = vec![0; num_pq_chunks + 1]; let pivot_args = GeneratePivotArguments::new( train_size, @@ -1344,7 +1273,6 @@ mod pq_test { NUM_PQ_CENTROIDS, num_pq_chunks, crate::model::pq::pq_construction::NUM_KMEANS_REPS_PQ, - false, ) .unwrap(); let pool = create_thread_pool_for_test(); @@ -1352,7 +1280,6 @@ mod pq_test { generate_pq_pivots_from_membuf( &pivot_args, &train_data_vector, - &mut centroid, 
&mut offsets, &mut full_pivot_data, &mut crate::utils::create_rnd_in_tests(), @@ -1377,7 +1304,6 @@ mod pq_test { pivot_args.num_centers(), pivot_args.num_pq_chunks(), pivot_args.max_k_means_reps(), - pivot_args.translate_to_center(), ) .unwrap(); @@ -1387,20 +1313,14 @@ mod pq_test { &pivot_args, &full_data_vector, &full_pivot_data, - ¢roid, &offsets, &mut pq_data, pool.as_ref(), ) .unwrap(); - let fixed_chunk_pq_table = FixedChunkPQTable::new( - train_dim, - full_pivot_data.into(), - centroid.clone().into(), - offsets.into(), - ) - .unwrap(); + let fixed_chunk_pq_table = + FixedChunkPQTable::new(train_dim, full_pivot_data.into(), offsets.into()).unwrap(); // Hook into here to test pairwise distances. let pairs = [(0, 1), (1, 0), (10, 10), (23, 42)]; @@ -1410,9 +1330,7 @@ mod pq_test { let self_l2 = fixed_chunk_pq_table.qq_l2_distance(left, right); - let mut inflated = fixed_chunk_pq_table.inflate_vector(left); - fixed_chunk_pq_table.preprocess_query(&mut inflated); - + let inflated = fixed_chunk_pq_table.inflate_vector(left); let from_inflated = fixed_chunk_pq_table.l2_distance(&inflated, right); assert_relative_eq!(self_l2, from_inflated, max_relative = 1e-6); } diff --git a/diskann-providers/src/storage/pq_storage.rs b/diskann-providers/src/storage/pq_storage.rs index c54cef944..7985a8998 100644 --- a/diskann-providers/src/storage/pq_storage.rs +++ b/diskann-providers/src/storage/pq_storage.rs @@ -17,7 +17,7 @@ use rand::Rng; use tracing::info; use crate::{ - model::{FixedChunkPQTable, NUM_PQ_CENTROIDS, pq::METADATA_SIZE}, + model::{FixedChunkPQTable, NUM_PQ_CENTROIDS, accum_row_inplace, pq::METADATA_SIZE}, utils::{gen_random_slice, read_bin_from, write_bin_from}, }; @@ -61,16 +61,22 @@ impl PQStorage { Ok(()) } - /// Write the pivot table to file + /// Write the pivot table to file. + /// /// # Arguments /// * `full_pivot_data` - the pivot table data - /// * `centroid` - the centroid of the pivot table + /// * `centroid` - Optional per-dimension centroid. 
Pass `None` for the standard + /// (non-legacy) code path; a zero vector of length `dim` is written to preserve + /// the on-disk file format. Pass `Some(centroid)` only when legacy centroid + /// centering is enabled (see the `legacy_center_data` argument of `generate_pq_pivots`). /// * `chunk_offsets` - the chunk offsets of the pivot table /// * `num_centers` - the number of centers /// * `dim` - the dimension of the pivot table /// * `storage_provider` - the storage provider + /// /// # Return /// * `Result` - the result of writing the pivot table + /// /// # Remarks /// * 4k bytes are reserved for metadata at the beginning of the file /// * the metadata is written in the following order: @@ -84,7 +90,7 @@ impl PQStorage { pub fn write_pivot_data( &self, full_pivot_data: &[f32], - centroid: &[f32], + centroid: Option<&[f32]>, chunk_offsets: &[usize], num_centers: usize, dim: usize, @@ -105,7 +111,11 @@ impl PQStorage { cumul_bytes[1] = cumul_bytes[0] + write_bin(pivot_view, writer)?; // Write the centroid of PQ centroid vectors - cumul_bytes[2] = cumul_bytes[1] + write_bin(MatrixView::column_vector(centroid), writer)?; + let centroid_bytes = match centroid { + Some(centroid) => write_bin(MatrixView::column_vector(centroid), writer)?, + None => write_bin(Matrix::::new(0.0, dim, 1).as_view(), writer)?, + }; + cumul_bytes[2] = cumul_bytes[1] + centroid_bytes; // Write PQ chunk offsets let chunk_offsets_u32: Vec = chunk_offsets.iter().map(|&x| x as u32).collect(); @@ -146,6 +156,16 @@ impl PQStorage { Ok(Metadata::read(reader)?.into_dims()) } + /// Load the raw pivot data, centroid, and chunk offsets from a pivot file. + /// + /// Unlike [`Self::load_pq_pivots_bin`], this method returns the centroid + /// separately without folding it into the pivot data. Callers that need the + /// effective (centroid-adjusted) pivots must apply the centroid themselves, + /// e.g. via [`accum_row_inplace`](crate::model::pq::accum_row_inplace).
+ /// + /// For files written without legacy centering (`centroid = None` in + /// [`Self::write_pivot_data`]), the returned centroid will be all zeros and + /// can safely be accumulated as a no-op. pub fn load_existing_pivot_data( &self, num_pq_chunks: &usize, @@ -278,7 +298,7 @@ impl PQStorage { } let file_offset_data = offsets.map(|x| x.into_usize()); - let pivots = read_bin_from::(&mut reader, file_offset_data[(0, 0)])?; + let mut pivots = read_bin_from::(&mut reader, file_offset_data[(0, 0)])?; if pivots.nrows() > NUM_PQ_CENTROIDS { return Err(ANNError::log_pq_error(format_args!( "Error reading pq_pivots file {}. file_num_centers = {}, but expecting {} centers.", @@ -316,12 +336,13 @@ impl PQStorage { } let chunk_offsets = chunk_offsets_m.map(|x| x.into_usize()); - FixedChunkPQTable::new( - dim, - pivots.into_inner(), - centroids.into_inner(), - chunk_offsets.into_inner(), - ) + // If the centroid is non-zero, we need to add it to the pivots to restore the + // numeric behavior. + if centroids.as_slice().iter().any(|c| *c != 0.0) { + accum_row_inplace(pivots.as_mut_view(), centroids.as_slice()) + } + + FixedChunkPQTable::new(dim, pivots.into_inner(), chunk_offsets.into_inner()) } /// streams data from the file, and samples each vector with probability p_val @@ -445,6 +466,115 @@ mod pq_storage_tests { assert_eq!(chunk_offsets.len(), 2); } + /// Write pivot data with `centroid = None`, read it back via + /// `load_existing_pivot_data`, and verify the pivots are unchanged and the + /// centroid is all zeros. 
+ #[test] + fn write_read_roundtrip_no_centroid() { + let storage_provider = VirtualStorageProvider::new_memory(); + let pivot_path = "/roundtrip_no_centroid_pivots.bin"; + + let num_centers = 3; + let dim = 4; + let num_pq_chunks = 2; + let pivots: Vec = (0..num_centers * dim).map(|i| i as f32).collect(); + let chunk_offsets = vec![0, 2, dim]; + + let pq_storage = PQStorage::new(pivot_path, PQ_COMPRESSED_PATH, None); + pq_storage + .write_pivot_data( + &pivots, + None, + &chunk_offsets, + num_centers, + dim, + &storage_provider, + ) + .unwrap(); + + let (loaded_pivots, loaded_centroid, loaded_offsets) = pq_storage + .load_existing_pivot_data(&num_pq_chunks, &num_centers, &dim, &storage_provider) + .unwrap(); + + assert_eq!( + loaded_pivots, pivots, + "pivots should survive the round-trip unchanged" + ); + assert!( + loaded_centroid.iter().all(|&c| c == 0.0), + "centroid should be all zeros when written with None" + ); + assert_eq!(loaded_offsets, chunk_offsets); + } + + /// Write pivot data with a non-zero centroid, read it back, and verify that + /// folding the centroid via `accum_row_inplace` produces the expected + /// adjusted pivots. 
+ #[test] + fn write_read_roundtrip_with_legacy_centroid() { + use crate::model::pq::accum_row_inplace; + use diskann_utils::views::MutMatrixView; + + let storage_provider = VirtualStorageProvider::new_memory(); + let pivot_path = "/roundtrip_legacy_centroid_pivots.bin"; + + let num_centers = 3; + let dim = 4; + let num_pq_chunks = 2; + let pivots: Vec = (0..num_centers * dim).map(|i| i as f32).collect(); + let centroid: Vec = vec![10.0, 20.0, 30.0, 40.0]; + let chunk_offsets = vec![0, 2, dim]; + + let pq_storage = PQStorage::new(pivot_path, PQ_COMPRESSED_PATH, None); + pq_storage + .write_pivot_data( + &pivots, + Some(¢roid), + &chunk_offsets, + num_centers, + dim, + &storage_provider, + ) + .unwrap(); + + let (mut loaded_pivots, loaded_centroid, loaded_offsets) = pq_storage + .load_existing_pivot_data(&num_pq_chunks, &num_centers, &dim, &storage_provider) + .unwrap(); + + assert_eq!( + loaded_pivots, pivots, + "raw pivots should match what was written" + ); + assert_eq!( + loaded_centroid, centroid, + "centroid should round-trip exactly" + ); + assert_eq!(loaded_offsets, chunk_offsets); + + // Fold the centroid into the pivots — this is what production callers do. + let mut pivot_mat = + MutMatrixView::try_from(loaded_pivots.as_mut_slice(), num_centers, dim).unwrap(); + accum_row_inplace(pivot_mat.as_mut_view(), &loaded_centroid); + + // Each pivot row should have the centroid added element-wise. + for (idx, (pivot, &orig)) in loaded_pivots.iter().zip(pivots.iter()).enumerate() { + let d = idx % dim; + let expected = orig + centroid[d]; + assert_eq!( + *pivot, expected, + "pivot[{}]: expected {expected}, got {pivot}", + idx + ); + } + + // Check that `load_pq_pivots_bin` correctly does the centroid folding. 
+ let table = pq_storage + .load_pq_pivots_bin(pivot_path, num_pq_chunks, &storage_provider) + .unwrap(); + + assert_eq!(loaded_pivots, table.view_pivots().as_slice()); + } + #[test] fn gen_random_slice_test() { let storage_provider = VirtualStorageProvider::new_memory(); diff --git a/diskann-tools/src/utils/build_pq.rs b/diskann-tools/src/utils/build_pq.rs index bdb73e51e..15321cfca 100644 --- a/diskann-tools/src/utils/build_pq.rs +++ b/diskann-tools/src/utils/build_pq.rs @@ -74,8 +74,8 @@ pub fn build_pq( NUM_PQ_CENTROIDS, num_pq_chunks, NUM_KMEANS_REPS_PQ, - false, )?, + false, &mut train_data_vector, &pq_storage, &storage_provider,