Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion diskann-benchmark/src/backend/exhaustive/product.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ mod imp {
let quantizer = diskann_providers::model::pq::FixedChunkPQTable::new(
data.ncols(),
base.flatten().into(),
vec![0.0; data.ncols()].into(),
offsets.into(),
)?;

Expand Down
2 changes: 1 addition & 1 deletion diskann-disk/src/build/builder/quantizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ impl BuildQuantizer {
PQStorage::new(&pq_paths.pivots, &pq_paths.compressed_data, None);
pq_build_storage.write_pivot_data(
table.get_pq_table(),
table.get_centroids(),
None,
table.get_chunk_offsets(),
table.get_num_centers(),
table.get_dim(),
Expand Down
82 changes: 25 additions & 57 deletions diskann-disk/src/search/pq/quantizer_preprocess.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use diskann_providers::model::compute_pq_distance;
use diskann_providers::utils::BridgeErr;

use super::{PQData, PQScratch};
use crate::storage::quant::pq::pq_dataset::PQTable;

/// Preprocesses the query vector for PQ distance calculations.
/// This function rotates the query vector and prepares the PQ table distances
Expand All @@ -21,64 +20,33 @@ pub fn quantizer_preprocess(
metric: Metric,
id_to_calculate_pq_distance: &[u32],
) -> ANNResult<()> {
match &pq_data.pq_table() {
PQTable::Transposed(table) => {
let dim = table.dim();
let expected_len = table.ncenters() * table.nchunks();
let dst = diskann_utils::views::MutMatrixView::try_from(
&mut (*pq_scratch.aligned_pqtable_dist_scratch)[..expected_len],
table.nchunks(),
table.ncenters(),
)
.bridge_err()?;
let table = pq_data.pq_table();
let dim = table.dim();
let expected_len = table.ncenters() * table.nchunks();
let dst = diskann_utils::views::MutMatrixView::try_from(
&mut (*pq_scratch.aligned_pqtable_dist_scratch)[..expected_len],
table.nchunks(),
table.ncenters(),
)
.bridge_err()?;

match metric {
// Prior to the introduction of the `quantizer_preprocess` method, the
// disk index was hard-coded to use L2 distance for comparisons.
//
// We're keeping that behavior here - treating `Cosine` and `CosineNormalized`
// as L2 until a more thorough evaluation can be made.
Metric::L2 | Metric::Cosine | Metric::CosineNormalized => {
table.process_into::<diskann_quantization::distances::SquaredL2>(
&pq_scratch.rotated_query[..dim],
dst,
);
}
Metric::InnerProduct => {
table.process_into::<diskann_quantization::distances::InnerProduct>(
&pq_scratch.rotated_query[..dim],
dst,
);
}
}
match metric {
// Prior to the introduction of the `quantizer_preprocess` method, the
// disk index was hard-coded to use L2 distance for comparisons.
//
// We're keeping that behavior here - treating `Cosine` and `CosineNormalized`
// as L2 until a more thorough evaluation can be made.
Metric::L2 | Metric::Cosine | Metric::CosineNormalized => {
table.process_into::<diskann_quantization::distances::SquaredL2>(
&pq_scratch.rotated_query[..dim],
dst,
);
}
PQTable::Fixed(table) => {
match metric {
// Prior to the introduction of the `quantizer_preprocess` method, the
// disk index was hard-coded to use L2 distance for comparisons.
//
// We're keeping that behavior here - treating `Cosine` and `CosineNormalized`
// as L2 until a more thorough evaluation can be made.
Metric::L2 | Metric::Cosine | Metric::CosineNormalized => {
// The scratch only stores the aligned dimension. However, preprocessing
// wants the actual dimension used, so we have to shrink the rotated query
// accordingly.
let dim = table.get_dim();
table.preprocess_query(&mut pq_scratch.rotated_query[..dim]);

// Compute the distance between each chunk of the query to each pq centroids.
table.populate_chunk_distances(
pq_scratch.rotated_query.as_slice(),
&mut pq_scratch.aligned_pqtable_dist_scratch,
)?;
}
Metric::InnerProduct => {
table.populate_chunk_inner_products(
pq_scratch.rotated_query.as_slice(),
&mut pq_scratch.aligned_pqtable_dist_scratch,
)?;
}
}
Metric::InnerProduct => {
table.process_into::<diskann_quantization::distances::InnerProduct>(
&pq_scratch.rotated_query[..dim],
dst,
);
}
}

Expand Down
2 changes: 1 addition & 1 deletion diskann-disk/src/storage/quant/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub use generator::{GeneratorContext, QuantDataGenerator};

pub(crate) mod pq;
pub use pq::pq_generation::{PQGeneration, PQGenerationContext};
pub use pq::{PQData, PQTable};
pub use pq::PQData;

mod compressor;
pub use compressor::{CompressionStage, QuantCompressor};
1 change: 0 additions & 1 deletion diskann-disk/src/storage/quant/pq/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,5 @@

pub(crate) mod pq_dataset;
pub use pq_dataset::PQData;
pub use pq_dataset::PQTable;

pub mod pq_generation;
52 changes: 11 additions & 41 deletions diskann-disk/src/storage/quant/pq/pq_dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,10 @@ use diskann_providers::model::FixedChunkPQTable;
use diskann_quantization::product::TransposedTable;
use diskann_utils::views::Matrix;

/// Storage for the PQ pivot table held by `PQData`.
///
/// Behind the scenes, we can use either the [`FixedChunkPQTable`] or a
/// [`diskann_quantization::product::TransposedTable`]. The [`TransposedTable`] is much faster
/// for preprocessing, but does not support removal of the dataset centroid.
///
/// So, we can only use the [`TransposedTable`] when the dataset centroid
/// is all zero.
#[derive(Debug)]
pub enum PQTable {
/// Fast path: chosen when every component of the dataset centroid is exactly zero.
Transposed(TransposedTable),
/// Fallback: keeps the original [`FixedChunkPQTable`] when the dataset centroid is non-zero.
Fixed(FixedChunkPQTable),
}

#[derive(Debug)]
pub struct PQData {
// pq pivot table.
pq_pivot_table: PQTable,
pq_pivot_table: TransposedTable,

// pq compressed vectors, shape `num_points × num_pq_chunks`.
pq_compressed_data: Matrix<u8>,
Expand All @@ -36,18 +24,11 @@ impl PQData {
pq_pivot_table: FixedChunkPQTable,
pq_compressed_data: Matrix<u8>,
) -> ANNResult<Self> {
// Check if we can use the transposed table. If so, go for it.
let centroid_is_zero = pq_pivot_table.get_centroids().iter().all(|i| *i == 0.0);
let pq_pivot_table = if centroid_is_zero {
let transposed = TransposedTable::from_parts(
pq_pivot_table.view_pivots(),
pq_pivot_table.view_offsets().to_owned(),
)
.map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))?;
PQTable::Transposed(transposed)
} else {
PQTable::Fixed(pq_pivot_table)
};
let pq_pivot_table = TransposedTable::from_parts(
pq_pivot_table.view_pivots(),
pq_pivot_table.view_offsets().to_owned(),
)
.map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))?;

Ok(Self {
pq_pivot_table,
Expand All @@ -56,24 +37,18 @@ impl PQData {
}

/// Returns a reference to the underlying PQ pivot table.
pub fn pq_table(&self) -> &PQTable {
pub fn pq_table(&self) -> &TransposedTable {
&self.pq_pivot_table
}

/// Return the number of chunks in the underlying PQ schema.
pub fn get_num_chunks(&self) -> usize {
match &self.pq_pivot_table {
PQTable::Transposed(table) => table.nchunks(),
PQTable::Fixed(table) => table.get_num_chunks(),
}
self.pq_pivot_table.nchunks()
Comment on lines 39 to +46
}

/// Return the number of centers in the underlying PQ schema.
pub fn get_num_centers(&self) -> usize {
match &self.pq_pivot_table {
PQTable::Transposed(table) => table.ncenters(),
PQTable::Fixed(table) => table.get_num_centers(),
}
self.pq_pivot_table.ncenters()
}

/// Get pq_compressed_data
Expand All @@ -97,13 +72,8 @@ mod tests {
fn create_pq_data() -> ANNResult<PQData> {
let dim = 2;

let pq_pivot_table = FixedChunkPQTable::new(
dim,
Box::new([0.0, 0.0, 1.0, 1.0]),
Box::new([0.0, 0.0]),
Box::new([0, 2]),
)
.unwrap();
let pq_pivot_table =
FixedChunkPQTable::new(dim, Box::new([0.0, 0.0, 1.0, 1.0]), Box::new([0, 2])).unwrap();
let pq_compressed_data = Matrix::try_from(Box::new([123u8, 111, 255]) as Box<[u8]>, 3, 1)
.expect("valid matrix shape");

Expand Down
14 changes: 4 additions & 10 deletions diskann-disk/src/storage/quant/pq/pq_generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ where
context.num_centers,
context.num_chunks,
context.max_kmeans_reps,
context.metric == Metric::L2,
)?,
context.metric == Metric::L2,
&mut train_data,
&context.pq_storage,
context.storage_provider,
Expand Down Expand Up @@ -260,15 +260,9 @@ mod pq_generation_tests {

let pool = create_thread_pool_for_test();
generate_pq_pivots(
GeneratePivotArguments::new(
ndata,
dim,
num_centers,
num_chunks,
max_k_means_reps,
true,
)
.unwrap(),
GeneratePivotArguments::new(ndata, dim, num_centers, num_chunks, max_k_means_reps)
.unwrap(),
true,
&mut train_data,
&pq_storage,
&storage_provider,
Expand Down
23 changes: 5 additions & 18 deletions diskann-providers/src/index/diskann_async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,24 +71,21 @@ pub fn train_pq(
model::pq::NUM_PQ_CENTROIDS,
num_pq_chunks,
5,
false,
)?;
let mut centroid = vec![0.0; dim];
let mut offsets = vec![0; num_pq_chunks + 1];
let mut full_pivot_data = vec![0.0; model::pq::NUM_PQ_CENTROIDS * dim];

model::pq::generate_pq_pivots_from_membuf(
&pivot_args,
data.as_slice(),
&mut centroid,
&mut offsets,
&mut full_pivot_data,
rng,
&mut (false),
pool,
)?;

model::pq::FixedChunkPQTable::new(dim, full_pivot_data.into(), centroid.into(), offsets.into())
model::pq::FixedChunkPQTable::new(dim, full_pivot_data.into(), offsets.into())
}

pub type MemoryIndex<T, D = NoDeletes> = Arc<DiskANNIndex<FullPrecisionProvider<T, NoStore, D>>>;
Expand Down Expand Up @@ -1559,13 +1556,8 @@ pub(crate) mod tests {
)
.unwrap();

let pqtable = model::pq::FixedChunkPQTable::new(
dim,
Box::new([0.0, 0.0]),
Box::new([0.0, 0.0]),
Box::new([0, 2]),
)
.unwrap();
let pqtable =
model::pq::FixedChunkPQTable::new(dim, Box::new([0.0, 0.0]), Box::new([0, 2])).unwrap();

let index =
new_quant_index::<f32, _, _>(config, parameters, pqtable, TableBasedDeletes).unwrap();
Expand Down Expand Up @@ -1680,13 +1672,8 @@ pub(crate) mod tests {
)
.unwrap();

let pqtable = model::pq::FixedChunkPQTable::new(
dim,
Box::new([0.0, 0.0]),
Box::new([0.0, 0.0]),
Box::new([0, 2]),
)
.unwrap();
let pqtable =
model::pq::FixedChunkPQTable::new(dim, Box::new([0.0, 0.0]), Box::new([0, 2])).unwrap();

let index =
new_quant_index::<f32, _, _>(config, parameters, pqtable, TableBasedDeletes).unwrap();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2156,7 +2156,7 @@ where
let pq_table = &self.quant_vectors.pq_chunk_table;
pq_storage.write_pivot_data(
pq_table.get_pq_table(),
pq_table.get_centroids(),
None,
pq_table.get_chunk_offsets(),
NUM_PQ_CENTROIDS,
pq_table.get_dim(),
Expand Down Expand Up @@ -2683,7 +2683,6 @@ mod tests {
let pq_table = FixedChunkPQTable::new(
dim,
vec![0.0; dim * 256].into_boxed_slice(),
vec![0.0; dim].into_boxed_slice(),
Box::new([0, 4, dim]),
)
.unwrap();
Expand Down Expand Up @@ -2785,11 +2784,6 @@ mod tests {
loaded_pq.get_pq_table(),
"PQ table data mismatch"
);
assert_eq!(
original_pq.get_centroids(),
loaded_pq.get_centroids(),
"PQ table centroids mismatch"
);
assert_eq!(
original_pq.get_chunk_offsets(),
loaded_pq.get_chunk_offsets(),
Expand Down Expand Up @@ -2986,7 +2980,6 @@ mod tests {
let pq_table = FixedChunkPQTable::new(
dim,
vec![0.0; dim * 256].into_boxed_slice(),
vec![0.0; dim].into_boxed_slice(),
Box::new([0, 4, dim]),
)
.unwrap();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,13 +270,11 @@ mod tests {
#[tokio::test]
async fn common_errors() {
let dim = 5;
let centroid = vec![0.0; dim];
let offsets = vec![0, dim];
let full_pivot_data = vec![0.0; 256 * dim];

let pq_chunk_table =
FixedChunkPQTable::new(dim, full_pivot_data.into(), centroid.into(), offsets.into())
.unwrap();
FixedChunkPQTable::new(dim, full_pivot_data.into(), offsets.into()).unwrap();

let bf_tree_config = Config::default();
let provider =
Expand Down Expand Up @@ -306,7 +304,6 @@ mod tests {
let table = FixedChunkPQTable::new(
dim,
Box::new([0.0, 0.0, 1.0, 1.0, 2.0, 2.0]),
Box::new([0.0, 0.0]),
Box::new([0, dim]),
)
.unwrap();
Expand Down Expand Up @@ -379,12 +376,10 @@ mod tests {
#[tokio::test(flavor = "multi_thread", worker_threads = 5)]
async fn test_parallel_tree_traversal() {
let dim = 2;
let centroid = vec![0.0; dim];
let offsets = vec![0, dim];
let full_pivot_data = vec![0.0; 256 * dim];
let pq_chunk_table =
FixedChunkPQTable::new(dim, full_pivot_data.into(), centroid.into(), offsets.into())
.unwrap();
FixedChunkPQTable::new(dim, full_pivot_data.into(), offsets.into()).unwrap();

let bf_tree_config = Config::default();
let provider = Arc::new(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ impl TestMultiPQProviderAsync {
&vector_f32,
table.get_pq_table(),
table.get_num_centers(),
Some(table.get_centroids()),
table.get_chunk_offsets(),
&mut quant_vector,
)
Expand Down
Loading
Loading