Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions diskann-providers/src/model/pq/distance/multi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ mod tests {
config.start_value,
);

let expected = reference.evaluate_similarity(&expected0, &expected1);
let expected = reference.evaluate_similarity(&*expected0, &*expected1);

// Test full-precision/quant.
let got = computer
Expand Down Expand Up @@ -699,9 +699,9 @@ mod tests {
);

// Generate reference results.
let oo = reference.evaluate_similarity(&old_expected, &old_expected);
let nn = reference.evaluate_similarity(&new_expected, &new_expected);
let on = reference.evaluate_similarity(&old_expected, &new_expected);
let oo = reference.evaluate_similarity(&*old_expected, &*old_expected);
let nn = reference.evaluate_similarity(&*new_expected, &*new_expected);
let on = reference.evaluate_similarity(&*old_expected, &*new_expected);

// Quant + Quant
{
Expand Down
3 changes: 2 additions & 1 deletion diskann-vector/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ iai-callgrind.workspace = true
rand.workspace = true
criterion.workspace = true
rand_distr.workspace = true
half = { workspace = true, features = ["rand_distr", "num-traits"] }
half = { workspace = true, features = ["rand_distr", "num-traits", "bytemuck"] }
bytemuck = { workspace = true, features = ["must_cast"] }
Comment thread
arkrishn94 marked this conversation as resolved.

[[bench]]
name = "bench_main"
Expand Down
53 changes: 42 additions & 11 deletions diskann-vector/src/distance/distance_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@ use diskann_wide::arch::aarch64::Neon;
use diskann_wide::arch::x86_64::{V3, V4};
use diskann_wide::{
arch::{Dispatched2, FTarget2, Scalar},
lifetime::Ref,
Architecture,
};
use half::f16;

use super::{implementations::Specialize, Cosine, CosineNormalized, InnerProduct, SquaredL2};
use crate::distance::Metric;
use crate::{distance::Metric, AsUnaligned, UnalignedSlice};

/// Return a function pointer-like [`Distance`] to compute the requested metric.
///
Expand Down Expand Up @@ -46,6 +45,16 @@ pub trait DistanceProvider<T>: Sized + 'static {
fn distance_comparer(metric: Metric, dimension: Option<usize>) -> Distance<Self, T>;
}

#[derive(Debug)]
struct Unaligned<T>(std::marker::PhantomData<T>);

impl<T> diskann_wide::lifetime::AddLifetime for Unaligned<T>
where
T: 'static,
{
type Of<'a> = UnalignedSlice<'a, T>;
}
Comment thread
hildebrandmw marked this conversation as resolved.

/// A function pointer-like type for computing distances between `&[T]` and `&[U]`.
///
/// See: [`DistanceProvider`].
Expand All @@ -55,26 +64,37 @@ where
T: 'static,
U: 'static,
{
f: Dispatched2<f32, Ref<[T]>, Ref<[U]>>,
f: Dispatched2<f32, Unaligned<T>, Unaligned<U>>,
}

impl<T, U> Distance<T, U>
where
T: 'static,
U: 'static,
{
fn new(f: Dispatched2<f32, Ref<[T]>, Ref<[U]>>) -> Self {
fn new(f: Dispatched2<f32, Unaligned<T>, Unaligned<U>>) -> Self {
Self { f }
}

/// Compute a distances between `x` and `y`.
/// Compute the distance between `x` and `y`.
///
/// The actual distance computed depends on the metric supplied to [`DistanceProvider`].
///
/// Additionally, if a dimension were given to [`DistanceProvider`], this function may
/// panic if provided with slices with a length not equal to this dimension.
#[inline]
pub fn call(&self, x: &[T], y: &[U]) -> f32 {
self.call_unaligned(x.as_unaligned(), y.as_unaligned())
}

/// Compute the distance between `x` and `y`.
///
/// The actual distance computed depends on the metric supplied to [`DistanceProvider`].
///
/// Additionally, if a dimension were given to [`DistanceProvider`], this function may
/// panic if provided with slices with a length not equal to this dimension.
#[inline]
pub fn call_unaligned(&self, x: UnalignedSlice<'_, T>, y: UnalignedSlice<'_, U>) -> f32 {
self.f.call(x, y)
}
}
Expand All @@ -89,6 +109,17 @@ where
}
}

impl<T, U> crate::DistanceFunction<UnalignedSlice<'_, T>, UnalignedSlice<'_, U>, f32>
for Distance<T, U>
where
T: 'static,
U: 'static,
{
fn evaluate_similarity(&self, x: UnalignedSlice<'_, T>, y: UnalignedSlice<'_, U>) -> f32 {
self.f.call(x, y)
}
}

////////////////////
// Implementation //
////////////////////
Expand Down Expand Up @@ -305,7 +336,7 @@ struct Spec<const N: usize>;
impl<A, F, const N: usize, T, U> TrySpecialize<A, F, T, U> for Spec<N>
where
A: Architecture,
Specialize<N, F>: for<'a, 'b> FTarget2<A, f32, &'a [T], &'b [U]>,
Specialize<N, F>: for<'a, 'b> FTarget2<A, f32, UnalignedSlice<'a, T>, UnalignedSlice<'b, U>>,
T: 'static,
U: 'static,
{
Expand All @@ -314,7 +345,7 @@ where
if d == N {
return Some(Distance::new(
// NOTE: This line here is what actually compiles the specialized kernel.
arch.dispatch2::<Specialize<N, F>, f32, Ref<[T]>, Ref<[U]>>(),
arch.dispatch2::<Specialize<N, F>, f32, Unaligned<T>, Unaligned<U>>(),
));
}
}
Expand Down Expand Up @@ -352,7 +383,7 @@ impl<Head, Tail> Cons<Head, Tail> {
fn specialize<A, F, T, U>(&self, arch: A, _f: F, dim: Option<usize>) -> Distance<T, U>
where
A: Architecture,
F: for<'a, 'b> FTarget2<A, f32, &'a [T], &'b [U]>,
F: for<'a, 'b> FTarget2<A, f32, UnalignedSlice<'a, T>, UnalignedSlice<'b, U>>,
Head: TrySpecialize<A, F, T, U>,
Tail: TrySpecialize<A, F, T, U>,
T: 'static,
Expand All @@ -361,7 +392,7 @@ impl<Head, Tail> Cons<Head, Tail> {
if let Some(f) = self.try_specialize(arch, dim) {
f
} else {
Distance::new(arch.dispatch2::<F, f32, Ref<[T]>, Ref<[U]>>())
Distance::new(arch.dispatch2::<F, f32, Unaligned<T>, Unaligned<U>>())
}
}
}
Expand Down Expand Up @@ -465,7 +496,7 @@ mod test_unaligned_distance_provider {
// Unwrap the SimilarityScore for the reference implementation.
let converted = |a: &[T], b: &[T]| -> f32 { reference(a, b).into_inner() };

let checker = test_util::Checker::<T, T, f32>::new(
let mut checker = test_util::Checker::<T, T, f32>::new(
|a, b| under_test.call(a, b),
converted,
|got: f32, expected: f32| {
Expand All @@ -479,7 +510,7 @@ mod test_unaligned_distance_provider {
);

test_util::test_distance_function(
checker,
&mut checker,
distribution.clone(),
distribution.clone(),
dim,
Expand Down
92 changes: 38 additions & 54 deletions diskann-vector/src/distance/implementations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,58 +7,36 @@ use diskann_wide::{arch::Target2, Architecture, ARCH};

/// Experimental traits for distance functions.
use super::simd;
use crate::{Half, MathematicalValue, PureDistanceFunction, SimilarityScore};

trait ToSlice {
type Target;
fn to_slice(&self) -> &[Self::Target];
}

impl<T> ToSlice for &[T] {
type Target = T;
fn to_slice(&self) -> &[T] {
self
}
}
impl<T, const N: usize> ToSlice for &[T; N] {
type Target = T;
fn to_slice(&self) -> &[T] {
&self[..]
}
}
impl<T, const N: usize> ToSlice for [T; N] {
type Target = T;
fn to_slice(&self) -> &[T] {
&self[..]
}
}
use crate::{
AsUnaligned, Half, MathematicalValue, PureDistanceFunction, SimilarityScore, UnalignedSlice,
};

macro_rules! architecture_hook {
($functor:ty, $impl:path) => {
impl<A, T, L, R> diskann_wide::arch::Target2<A, T, L, R> for $functor
where
A: Architecture,
L: ToSlice,
R: ToSlice,
$impl: simd::SIMDSchema<L::Target, R::Target, A>,
Self: PostOp<<$impl as simd::SIMDSchema<L::Target, R::Target, A>>::Return, T>,
L: AsUnaligned,
R: AsUnaligned,
$impl: simd::SIMDSchema<L::Element, R::Element, A>,
Comment thread
hildebrandmw marked this conversation as resolved.
Self: PostOp<<$impl as simd::SIMDSchema<L::Element, R::Element, A>>::Return, T>,
{
#[inline(always)]
fn run(self, arch: A, left: L, right: R) -> T {
Self::post_op(simd::simd_op(
&$impl,
arch,
left.to_slice(),
right.to_slice(),
left.as_unaligned(),
right.as_unaligned(),
))
}
}

impl<A, T, L, R> diskann_wide::arch::FTarget2<A, T, L, R> for $functor
where
A: Architecture,
L: ToSlice,
R: ToSlice,
L: AsUnaligned,
R: AsUnaligned,
Self: diskann_wide::arch::Target2<A, T, L, R>,
{
#[inline(always)]
Expand All @@ -73,35 +51,31 @@ macro_rules! architecture_hook {
#[derive(Debug, Clone, Copy)]
pub(crate) struct Specialize<const N: usize, F>(std::marker::PhantomData<F>);

impl<A, T, L, R, const N: usize, F> diskann_wide::arch::FTarget2<A, T, &[L], &[R]>
impl<A, T, L, R, const N: usize, F>
diskann_wide::arch::FTarget2<A, T, UnalignedSlice<'_, L>, UnalignedSlice<'_, R>>
for Specialize<N, F>
where
A: Architecture,
F: for<'a, 'b> diskann_wide::arch::Target2<A, T, &'a [L; N], &'b [R; N]> + Default,
F: for<'a, 'b> diskann_wide::arch::Target2<A, T, UnalignedSlice<'a, L>, UnalignedSlice<'b, R>>
+ Default,
Comment thread
hildebrandmw marked this conversation as resolved.
{
#[inline(always)]
fn run(arch: A, x: &[L], y: &[R]) -> T {
fn run(arch: A, x: UnalignedSlice<'_, L>, y: UnalignedSlice<'_, R>) -> T {
if (x.len() != N) | (y.len() != N) {
fail_length_check(x, y, N);
}

// SAFETY: We have checked that both arguments have the correct length.
//
// The alignment requirements of arrays are the alignment requirements of
// `Left` and `Right` respectively, which is provided by the corresponding slices.
arch.run2(
F::default(),
unsafe { &*(x.as_ptr() as *const [L; N]) },
unsafe { &*(y.as_ptr() as *const [R; N]) },
)
// The validation of `x.len()` and `y.len()` is sufficient (and indeed necessary)
// to trigger constant propagation and unrolling.
arch.run2(F::default(), x, y)
}
}

// Outline the panic formatting and keep the calling convention the same as
// the top function. This keeps code generation extremely lightweight.
#[inline(never)]
#[allow(clippy::panic)]
fn fail_length_check<L, R>(x: &[L], y: &[R], len: usize) -> ! {
fn fail_length_check<L, R>(x: UnalignedSlice<'_, L>, y: UnalignedSlice<'_, R>, len: usize) -> ! {
let message = if x.len() != len {
("first", x.len())
} else {
Expand Down Expand Up @@ -136,6 +110,7 @@ macro_rules! use_simd_implementation {
<$functor>::default().run(ARCH, x, y)
}
}

// Statically Sized
impl<const N: usize> PureDistanceFunction<&[$T; N], &[$U; N], SimilarityScore<f32>>
for $functor
Expand Down Expand Up @@ -566,8 +541,11 @@ mod tests {
let (x, y) = random_normal_arguments(DIM, -100.0, 100.0, 0x023457AA);

let reference: f32 = SquaredL2::evaluate(x.as_slice(), y.as_slice());
let evaluated: f32 =
Specialize::<DIM, SquaredL2>::run(diskann_wide::ARCH, x.as_slice(), y.as_slice());
let evaluated: f32 = Specialize::<DIM, SquaredL2>::run(
diskann_wide::ARCH,
x.as_slice().as_unaligned(),
y.as_slice().as_unaligned(),
);

// Equality should be exact.
assert_eq!(reference, evaluated);
Expand All @@ -582,8 +560,11 @@ mod tests {
let x = vec![0.0f32; DIM + 1];
let y = vec![0.0f32; DIM];
// Since `x` does not have the correct dimensions, this should panic.
let _: f32 =
Specialize::<DIM, SquaredL2>::run(diskann_wide::ARCH, x.as_slice(), y.as_slice());
let _: f32 = Specialize::<DIM, SquaredL2>::run(
diskann_wide::ARCH,
x.as_slice().as_unaligned(),
y.as_slice().as_unaligned(),
);
}

#[test]
Expand All @@ -595,8 +576,11 @@ mod tests {
let x = vec![0.0f32; DIM];
let y = vec![0.0f32; DIM + 1];
// Since `y` does not have the correct dimensions, this should panic.
let _: f32 =
Specialize::<DIM, SquaredL2>::run(diskann_wide::ARCH, x.as_slice(), y.as_slice());
let _: f32 = Specialize::<DIM, SquaredL2>::run(
diskann_wide::ARCH,
x.as_slice().as_unaligned(),
y.as_slice().as_unaligned(),
);
}

////////////////////
Expand Down Expand Up @@ -650,7 +634,7 @@ mod tests {
To: GetInner + Copy,
Callback: FnMut(To, To),
{
let checker =
let mut checker =
test_util::Checker::<L, R, To>::new(under_test, reference, |got, expected| {
// Invoke the callback with the received numbers.
cb(got, expected);
Expand All @@ -663,7 +647,7 @@ mod tests {
});

test_util::test_distance_function(
checker,
&mut checker,
distribution.clone(),
distribution.clone(),
dim,
Expand Down
Loading
Loading