Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions diskann-providers/src/model/pq/distance/multi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ mod tests {
config.start_value,
);

let expected = reference.evaluate_similarity(&expected0, &expected1);
let expected = reference.evaluate_similarity(&*expected0, &*expected1);

// Test full-precision/quant.
let got = computer
Expand Down Expand Up @@ -699,9 +699,9 @@ mod tests {
);

// Generate reference results.
let oo = reference.evaluate_similarity(&old_expected, &old_expected);
let nn = reference.evaluate_similarity(&new_expected, &new_expected);
let on = reference.evaluate_similarity(&old_expected, &new_expected);
let oo = reference.evaluate_similarity(&*old_expected, &*old_expected);
let nn = reference.evaluate_similarity(&*new_expected, &*new_expected);
let on = reference.evaluate_similarity(&*old_expected, &*new_expected);

// Quant + Quant
{
Expand Down
3 changes: 2 additions & 1 deletion diskann-vector/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ iai-callgrind.workspace = true
rand.workspace = true
criterion.workspace = true
rand_distr.workspace = true
half = { workspace = true, features = ["rand_distr", "num-traits"] }
half = { workspace = true, features = ["rand_distr", "num-traits", "bytemuck"] }
bytemuck = { workspace = true, features = ["must_cast"] }
Comment thread
arkrishn94 marked this conversation as resolved.

[[bench]]
name = "bench_main"
Expand Down
53 changes: 42 additions & 11 deletions diskann-vector/src/distance/distance_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@ use diskann_wide::arch::aarch64::Neon;
use diskann_wide::arch::x86_64::{V3, V4};
use diskann_wide::{
arch::{Dispatched2, FTarget2, Scalar},
lifetime::Ref,
Architecture,
};
use half::f16;

use super::{implementations::Specialize, Cosine, CosineNormalized, InnerProduct, SquaredL2};
use crate::distance::Metric;
use crate::{distance::Metric, AsUnaligned, UnalignedSlice};

/// Return a function pointer-like [`Distance`] to compute the requested metric.
///
Expand Down Expand Up @@ -46,6 +45,16 @@ pub trait DistanceProvider<T>: Sized + 'static {
fn distance_comparer(metric: Metric, dimension: Option<usize>) -> Distance<Self, T>;
}

#[derive(Debug)]
struct Unaligned<T>(std::marker::PhantomData<T>);

impl<T> diskann_wide::lifetime::AddLifetime for Unaligned<T>
where
T: 'static,
{
type Of<'a> = UnalignedSlice<'a, T>;
}
Comment thread
hildebrandmw marked this conversation as resolved.

/// A function pointer-like type for computing distances between `&[T]` and `&[U]`.
///
/// See: [`DistanceProvider`].
Expand All @@ -55,26 +64,37 @@ where
T: 'static,
U: 'static,
{
f: Dispatched2<f32, Ref<[T]>, Ref<[U]>>,
f: Dispatched2<f32, Unaligned<T>, Unaligned<U>>,
}

impl<T, U> Distance<T, U>
where
T: 'static,
U: 'static,
{
fn new(f: Dispatched2<f32, Ref<[T]>, Ref<[U]>>) -> Self {
fn new(f: Dispatched2<f32, Unaligned<T>, Unaligned<U>>) -> Self {
Self { f }
}

/// Compute a distances between `x` and `y`.
/// Compute the distance between `x` and `y`.
///
/// The actual distance computed depends on the metric supplied to [`DistanceProvider`].
///
/// Additionally, if a dimension were given to [`DistanceProvider`], this function may
/// panic if provided with slices with a length not equal to this dimension.
#[inline]
pub fn call(&self, x: &[T], y: &[U]) -> f32 {
self.f.call(x.as_unaligned(), y.as_unaligned())
}

/// Compute the distance between `x` and `y`.
///
/// The actual distance computed depends on the metric supplied to [`DistanceProvider`].
///
/// Additionally, if a dimension were given to [`DistanceProvider`], this function may
/// panic if provided with slices with a length not equal to this dimension.
#[inline]
pub fn call_unaligned(&self, x: UnalignedSlice<'_, T>, y: UnalignedSlice<'_, U>) -> f32 {
self.f.call(x, y)
}
}
Expand All @@ -89,6 +109,17 @@ where
}
}

impl<T, U> crate::DistanceFunction<UnalignedSlice<'_, T>, UnalignedSlice<'_, U>, f32>
for Distance<T, U>
where
T: 'static,
U: 'static,
{
fn evaluate_similarity(&self, x: UnalignedSlice<'_, T>, y: UnalignedSlice<'_, U>) -> f32 {
self.f.call(x, y)
}
}

////////////////////
// Implementation //
////////////////////
Expand Down Expand Up @@ -305,7 +336,7 @@ struct Spec<const N: usize>;
impl<A, F, const N: usize, T, U> TrySpecialize<A, F, T, U> for Spec<N>
where
A: Architecture,
Specialize<N, F>: for<'a, 'b> FTarget2<A, f32, &'a [T], &'b [U]>,
Specialize<N, F>: for<'a, 'b> FTarget2<A, f32, UnalignedSlice<'a, T>, UnalignedSlice<'b, U>>,
T: 'static,
U: 'static,
{
Expand All @@ -314,7 +345,7 @@ where
if d == N {
return Some(Distance::new(
// NOTE: This line here is what actually compiles the specialized kernel.
arch.dispatch2::<Specialize<N, F>, f32, Ref<[T]>, Ref<[U]>>(),
arch.dispatch2::<Specialize<N, F>, f32, Unaligned<T>, Unaligned<U>>(),
));
}
}
Expand Down Expand Up @@ -352,7 +383,7 @@ impl<Head, Tail> Cons<Head, Tail> {
fn specialize<A, F, T, U>(&self, arch: A, _f: F, dim: Option<usize>) -> Distance<T, U>
where
A: Architecture,
F: for<'a, 'b> FTarget2<A, f32, &'a [T], &'b [U]>,
F: for<'a, 'b> FTarget2<A, f32, UnalignedSlice<'a, T>, UnalignedSlice<'b, U>>,
Head: TrySpecialize<A, F, T, U>,
Tail: TrySpecialize<A, F, T, U>,
T: 'static,
Expand All @@ -361,7 +392,7 @@ impl<Head, Tail> Cons<Head, Tail> {
if let Some(f) = self.try_specialize(arch, dim) {
f
} else {
Distance::new(arch.dispatch2::<F, f32, Ref<[T]>, Ref<[U]>>())
Distance::new(arch.dispatch2::<F, f32, Unaligned<T>, Unaligned<U>>())
}
}
}
Expand Down Expand Up @@ -465,7 +496,7 @@ mod test_unaligned_distance_provider {
// Unwrap the SimilarityScore for the reference implementation.
let converted = |a: &[T], b: &[T]| -> f32 { reference(a, b).into_inner() };

let checker = test_util::Checker::<T, T, f32>::new(
let mut checker = test_util::Checker::<T, T, f32>::new(
|a, b| under_test.call(a, b),
converted,
|got: f32, expected: f32| {
Expand All @@ -479,7 +510,7 @@ mod test_unaligned_distance_provider {
);

test_util::test_distance_function(
checker,
&mut checker,
distribution.clone(),
distribution.clone(),
dim,
Expand Down
92 changes: 38 additions & 54 deletions diskann-vector/src/distance/implementations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,58 +7,36 @@ use diskann_wide::{arch::Target2, Architecture, ARCH};

/// Experimental traits for distance functions.
use super::simd;
use crate::{Half, MathematicalValue, PureDistanceFunction, SimilarityScore};

trait ToSlice {
type Target;
fn to_slice(&self) -> &[Self::Target];
}

impl<T> ToSlice for &[T] {
type Target = T;
fn to_slice(&self) -> &[T] {
self
}
}
impl<T, const N: usize> ToSlice for &[T; N] {
type Target = T;
fn to_slice(&self) -> &[T] {
&self[..]
}
}
impl<T, const N: usize> ToSlice for [T; N] {
type Target = T;
fn to_slice(&self) -> &[T] {
&self[..]
}
}
use crate::{
AsUnaligned, Half, MathematicalValue, PureDistanceFunction, SimilarityScore, UnalignedSlice,
};

macro_rules! architecture_hook {
($functor:ty, $impl:path) => {
impl<A, T, L, R> diskann_wide::arch::Target2<A, T, L, R> for $functor
where
A: Architecture,
L: ToSlice,
R: ToSlice,
$impl: simd::SIMDSchema<L::Target, R::Target, A>,
Self: PostOp<<$impl as simd::SIMDSchema<L::Target, R::Target, A>>::Return, T>,
L: AsUnaligned,
R: AsUnaligned,
$impl: simd::SIMDSchema<L::Element, R::Element, A>,
Comment thread
hildebrandmw marked this conversation as resolved.
Self: PostOp<<$impl as simd::SIMDSchema<L::Element, R::Element, A>>::Return, T>,
{
#[inline(always)]
fn run(self, arch: A, left: L, right: R) -> T {
Self::post_op(simd::simd_op(
&$impl,
arch,
left.to_slice(),
right.to_slice(),
left.as_unaligned(),
right.as_unaligned(),
))
}
}

impl<A, T, L, R> diskann_wide::arch::FTarget2<A, T, L, R> for $functor
where
A: Architecture,
L: ToSlice,
R: ToSlice,
L: AsUnaligned,
R: AsUnaligned,
Self: diskann_wide::arch::Target2<A, T, L, R>,
{
#[inline(always)]
Expand All @@ -73,35 +51,31 @@ macro_rules! architecture_hook {
#[derive(Debug, Clone, Copy)]
pub(crate) struct Specialize<const N: usize, F>(std::marker::PhantomData<F>);

impl<A, T, L, R, const N: usize, F> diskann_wide::arch::FTarget2<A, T, &[L], &[R]>
impl<A, T, L, R, const N: usize, F>
diskann_wide::arch::FTarget2<A, T, UnalignedSlice<'_, L>, UnalignedSlice<'_, R>>
for Specialize<N, F>
where
A: Architecture,
F: for<'a, 'b> diskann_wide::arch::Target2<A, T, &'a [L; N], &'b [R; N]> + Default,
F: for<'a, 'b> diskann_wide::arch::Target2<A, T, UnalignedSlice<'a, L>, UnalignedSlice<'b, R>>
+ Default,
Comment thread
hildebrandmw marked this conversation as resolved.
{
#[inline(always)]
fn run(arch: A, x: &[L], y: &[R]) -> T {
fn run(arch: A, x: UnalignedSlice<'_, L>, y: UnalignedSlice<'_, R>) -> T {
if (x.len() != N) | (y.len() != N) {
fail_length_check(x, y, N);
}

// SAFETY: We have checked that both arguments have the correct length.
//
// The alignment requirements of arrays are the alignment requirements of
// `Left` and `Right` respectively, which is provided by the corresponding slices.
arch.run2(
F::default(),
unsafe { &*(x.as_ptr() as *const [L; N]) },
unsafe { &*(y.as_ptr() as *const [R; N]) },
)
// The validation of `x.len()` and `y.len()` is sufficient (and indeed necessary)
// to trigger constant propagation and unrolling.
arch.run2(F::default(), x, y)
}
}

// Outline the panic formatting and keep the calling convention the same as
// the top function. This keeps code generation extremely lightweight.
#[inline(never)]
#[allow(clippy::panic)]
fn fail_length_check<L, R>(x: &[L], y: &[R], len: usize) -> ! {
fn fail_length_check<L, R>(x: UnalignedSlice<'_, L>, y: UnalignedSlice<'_, R>, len: usize) -> ! {
let message = if x.len() != len {
("first", x.len())
} else {
Expand Down Expand Up @@ -136,6 +110,7 @@ macro_rules! use_simd_implementation {
<$functor>::default().run(ARCH, x, y)
}
}

// Statically Sized
impl<const N: usize> PureDistanceFunction<&[$T; N], &[$U; N], SimilarityScore<f32>>
for $functor
Expand Down Expand Up @@ -566,8 +541,11 @@ mod tests {
let (x, y) = random_normal_arguments(DIM, -100.0, 100.0, 0x023457AA);

let reference: f32 = SquaredL2::evaluate(x.as_slice(), y.as_slice());
let evaluated: f32 =
Specialize::<DIM, SquaredL2>::run(diskann_wide::ARCH, x.as_slice(), y.as_slice());
let evaluated: f32 = Specialize::<DIM, SquaredL2>::run(
diskann_wide::ARCH,
x.as_slice().as_unaligned(),
y.as_slice().as_unaligned(),
);

// Equality should be exact.
assert_eq!(reference, evaluated);
Expand All @@ -582,8 +560,11 @@ mod tests {
let x = vec![0.0f32; DIM + 1];
let y = vec![0.0f32; DIM];
// Since `x` does not have the correct dimensions, this should panic.
let _: f32 =
Specialize::<DIM, SquaredL2>::run(diskann_wide::ARCH, x.as_slice(), y.as_slice());
let _: f32 = Specialize::<DIM, SquaredL2>::run(
diskann_wide::ARCH,
x.as_slice().as_unaligned(),
y.as_slice().as_unaligned(),
);
}

#[test]
Expand All @@ -595,8 +576,11 @@ mod tests {
let x = vec![0.0f32; DIM];
let y = vec![0.0f32; DIM + 1];
// Since `y` does not have the correct dimensions, this should panic.
let _: f32 =
Specialize::<DIM, SquaredL2>::run(diskann_wide::ARCH, x.as_slice(), y.as_slice());
let _: f32 = Specialize::<DIM, SquaredL2>::run(
diskann_wide::ARCH,
x.as_slice().as_unaligned(),
y.as_slice().as_unaligned(),
);
}

////////////////////
Expand Down Expand Up @@ -650,7 +634,7 @@ mod tests {
To: GetInner + Copy,
Callback: FnMut(To, To),
{
let checker =
let mut checker =
test_util::Checker::<L, R, To>::new(under_test, reference, |got, expected| {
// Invoke the callback with the received numbers.
cb(got, expected);
Expand All @@ -663,7 +647,7 @@ mod tests {
});

test_util::test_distance_function(
checker,
&mut checker,
distribution.clone(),
distribution.clone(),
dim,
Expand Down
Loading
Loading