Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,8 @@ mod tests {
let d = provider.distance_computer();
assert_eq!(
d.evaluate_similarity(
&provider.get_vector_sync(0).unwrap(),
&provider.get_vector_sync(3).unwrap()
provider.get_vector_sync(0).unwrap().as_slice(),
provider.get_vector_sync(3).unwrap().as_slice(),
),
2.0
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,13 +452,13 @@ mod tests {
// Distance Computer.
let d = provider.distance_computer();
assert_eq!(
d.evaluate_similarity(&provider.get_vector_sync(0), &provider.get_vector_sync(3)),
d.evaluate_similarity(provider.get_vector_sync(0), provider.get_vector_sync(3),),
2.0
);

let slice: &[f32] = &[-0.5, -0.5];
assert_eq!(
d.evaluate_similarity(slice, &provider.get_vector_sync(3)),
d.evaluate_similarity(slice, provider.get_vector_sync(3)),
expected,
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use diskann::{ANNError, ANNResult};
#[cfg(test)]
use diskann_quantization::CompressInto;
use diskann_utils::object_pool::ObjectPool;
use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction, distance::Metric};
use diskann_vector::distance::Metric;

use super::{VectorGuard, common::TestCallCount};
#[cfg(test)]
Expand Down Expand Up @@ -283,37 +283,6 @@ impl storage::bin::GetData for MemoryQuantVectorProviderAsync {
}
}

/// Overload `DistanceFunction` for `Guard<Arc<Vec<u8>>>` by dereferencing the
/// guard to a slice.
impl DistanceFunction<&[f32], &Guard<Arc<Vec<u8>>>, f32> for DistanceComputer {
#[inline(always)]
fn evaluate_similarity(&self, left: &[f32], right: &Guard<Arc<Vec<u8>>>) -> f32 {
let right: &[u8] = right;
self.evaluate_similarity(left, right)
}
}

/// Overload `DistanceFunction` for `Guard<Arc<Vec<u8>>>` by dereferencing the
/// guard to a slice.
impl DistanceFunction<&Guard<Arc<Vec<u8>>>, &Guard<Arc<Vec<u8>>>, f32> for DistanceComputer {
#[inline(always)]
fn evaluate_similarity(&self, left: &Guard<Arc<Vec<u8>>>, right: &Guard<Arc<Vec<u8>>>) -> f32 {
let left: &[u8] = left;
let right: &[u8] = right;
self.evaluate_similarity(left, right)
}
}

/// Overload `PreprocessedDistanceFunction` for `Guard<Arc<Vec<u8>>>` by dereferencing the
/// guard to a slice.
impl PreprocessedDistanceFunction<&Guard<Arc<Vec<u8>>>, f32> for QueryComputer {
#[inline(always)]
fn evaluate_similarity(&self, changing: &Guard<Arc<Vec<u8>>>) -> f32 {
let changing: &[u8] = changing;
self.evaluate_similarity(changing)
}
}

///////////
// Tests //
///////////
Expand All @@ -322,6 +291,8 @@ impl PreprocessedDistanceFunction<&Guard<Arc<Vec<u8>>>, f32> for QueryComputer {
mod tests {
use std::num::NonZeroUsize;

use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction};

use crate::storage::VirtualStorageProvider;

use super::*;
Expand Down Expand Up @@ -372,23 +343,23 @@ mod tests {
let c = provider.query_computer(&[-0.5, -0.5]).unwrap();
let expected: f32 = 1.5 * 1.5 * 2.0;
assert_eq!(
c.evaluate_similarity(&provider.get_vector_sync(3).unwrap()),
c.evaluate_similarity(provider.get_vector_sync(3).unwrap().as_slice()),
expected
);

// Distance Computer.
let d = provider.distance_computer();
assert_eq!(
d.evaluate_similarity(
&provider.get_vector_sync(0).unwrap(),
&provider.get_vector_sync(3).unwrap()
provider.get_vector_sync(0).unwrap().as_slice(),
provider.get_vector_sync(3).unwrap().as_slice(),
),
2.0
);

let slice: &[f32] = &[-0.5, -0.5];
assert_eq!(
d.evaluate_similarity(slice, &provider.get_vector_sync(3).unwrap()),
d.evaluate_similarity(slice, &**provider.get_vector_sync(3).unwrap()),
Comment thread
hildebrandmw marked this conversation as resolved.
Outdated
expected,
);
}
Expand Down
46 changes: 0 additions & 46 deletions diskann-providers/src/model/pq/distance/dynamic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,52 +233,6 @@ where
}
}

/// Perform a comparison between a full-precision vector and quantized vector.
impl<T> DistanceFunction<&[f32], &&[u8], f32> for DistanceComputer<T>
where
T: Deref<Target = FixedChunkPQTable>,
{
#[inline(always)]
fn evaluate_similarity(&self, fp: &[f32], q: &&[u8]) -> f32 {
let q: &[u8] = q;
self.evaluate_similarity(fp, q)
}
}

impl<T> DistanceFunction<&[f32], &Vec<u8>, f32> for DistanceComputer<T>
where
T: Deref<Target = FixedChunkPQTable>,
{
#[inline(always)]
fn evaluate_similarity(&self, fp: &[f32], q: &Vec<u8>) -> f32 {
self.evaluate_similarity(fp, q.as_slice())
}
}

/// Perform a comparison between two quantized vectors.
impl<T> DistanceFunction<&&[u8], &&[u8], f32> for DistanceComputer<T>
where
T: Deref<Target = FixedChunkPQTable>,
{
#[inline(always)]
fn evaluate_similarity(&self, q0: &&[u8], q1: &&[u8]) -> f32 {
let q0: &[u8] = q0;
let q1: &[u8] = q1;
self.evaluate_similarity(q0, q1)
}
}

/// Perform a comparison between two quantized vectors.
impl<T> DistanceFunction<&Vec<u8>, &Vec<u8>, f32> for DistanceComputer<T>
where
T: Deref<Target = FixedChunkPQTable>,
{
#[inline(always)]
fn evaluate_similarity(&self, q0: &Vec<u8>, q1: &Vec<u8>) -> f32 {
self.evaluate_similarity(q0.as_slice(), q1.as_slice())
}
}

#[cfg(test)]
mod tests {
use std::marker::PhantomData;
Expand Down
Loading