Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
0a70420
Add Cache aware multi-vector distance functions
suri-kumkaran Mar 25, 2026
08459cf
Merge branch 'main' into users/suryangupta/multi-vector-distance-impl
suri-kumkaran Mar 31, 2026
0bf3fc0
Merge branch 'main' into users/suryangupta/multi-vector-distance-impl
suri-kumkaran Apr 1, 2026
359da08
Improve design - make it more extensible and generic
suri-kumkaran Apr 1, 2026
888de9f
cfg flag fix in tests
Apr 1, 2026
cfa8b76
Make design more solid and powerful and add f16 kernels
suri-kumkaran Apr 1, 2026
816057a
Merge branch 'main' into users/suryangupta/multi-vector-distance-impl
suri-kumkaran Apr 7, 2026
3222c93
Enable dynamic dispatch of multi-vector distance function based on arch
suri-kumkaran Apr 8, 2026
3f7b544
Fix miri tests and increase code coverage
Apr 8, 2026
02a6acc
Use Target traits for runtime dispatch, add QueryComputer type that h…
Apr 10, 2026
b5c8895
Improve testing and code quality
suri-kumkaran Apr 10, 2026
aa1990b
Address Copilot review comments
suri-kumkaran Apr 13, 2026
2913912
Fix clippy
Apr 13, 2026
3258e67
Move preparation step to tiles rather than panels
suri-kumkaran Apr 16, 2026
8a42b0f
Merge branch 'main' into users/suryangupta/multi-vector-distance-impl
suri-kumkaran Apr 17, 2026
2198497
Address review comments
suri-kumkaran Apr 21, 2026
793fb9b
Merge branch 'main' into users/suryangupta/multi-vector-distance-impl
suri-kumkaran Apr 21, 2026
e06c48a
Improve testing and address review comments
suri-kumkaran Apr 24, 2026
ff199a4
Address review comments
suri-kumkaran May 1, 2026
dffa75f
Merge branch 'main' into users/suryangupta/multi-vector-distance-impl
May 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion diskann-quantization/src/multi_vector/block_transposed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
use std::{alloc::Layout, marker::PhantomData, ptr::NonNull};

use diskann_utils::{
ReborrowMut,
Reborrow, ReborrowMut,
strided::StridedView,
views::{MatrixView, MutMatrixView},
};
Expand Down Expand Up @@ -231,6 +231,15 @@ impl<T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedRepr<T, GROU
self.nrows % GROUP
}

/// Number of row slots in the backing allocation: `GROUP` times the block count.
///
/// This is the logical row count rounded up to the next multiple of `GROUP`,
/// so it also counts the zero-padded rows of a trailing (possibly partial)
/// block.
#[inline]
pub fn available_rows(&self) -> usize {
    GROUP * self.num_blocks()
}

/// The stride (in elements) between the start of consecutive blocks.
#[inline]
fn block_stride(&self) -> usize {
Expand Down Expand Up @@ -743,6 +752,15 @@ impl<'a, T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedRef<'a,
self.data.repr().remainder()
}

/// Number of row slots in the backing allocation, as reported by the
/// underlying representation.
///
/// This is the logical row count rounded up to the next multiple of `GROUP`,
/// so it also counts the zero-padded rows of a trailing (possibly partial)
/// block.
#[inline]
pub fn available_rows(&self) -> usize {
    let repr = self.data.repr();
    repr.available_rows()
}
Comment thread
suri-kumkaran marked this conversation as resolved.

/// Return a raw typed pointer to the start of the backing data.
#[inline]
pub fn as_ptr(&self) -> *const T {
Expand Down Expand Up @@ -870,6 +888,7 @@ impl<'a, T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedMut<'a,
delegate_to_ref!(pub fn full_blocks(&self) -> usize);
delegate_to_ref!(pub fn num_blocks(&self) -> usize);
delegate_to_ref!(pub fn remainder(&self) -> usize);
delegate_to_ref!(pub fn available_rows(&self) -> usize);
delegate_to_ref!(pub fn as_ptr(&self) -> *const T);
delegate_to_ref!(pub fn as_slice(&self) -> &[T]);
delegate_to_ref!(#[allow(clippy::missing_safety_doc)] unsafe pub fn block_ptr_unchecked(&self, block: usize) -> *const T);
Expand Down Expand Up @@ -1017,6 +1036,7 @@ impl<T: Copy, const GROUP: usize, const PACK: usize> BlockTransposed<T, GROUP, P
delegate_to_ref!(pub fn full_blocks(&self) -> usize);
delegate_to_ref!(pub fn num_blocks(&self) -> usize);
delegate_to_ref!(pub fn remainder(&self) -> usize);
delegate_to_ref!(pub fn available_rows(&self) -> usize);
delegate_to_ref!(pub fn as_ptr(&self) -> *const T);
delegate_to_ref!(pub fn as_slice(&self) -> &[T]);
delegate_to_ref!(#[allow(clippy::missing_safety_doc)] unsafe pub fn block_ptr_unchecked(&self, block: usize) -> *const T);
Expand Down Expand Up @@ -1072,6 +1092,19 @@ impl<T: Copy, const GROUP: usize, const PACK: usize> BlockTransposed<T, GROUP, P
}
}

// ── Reborrow ─────────────────────────────────────────────────────

/// Reborrowing a [`BlockTransposed`] yields the value returned by its
/// `as_view()` method, typed as a [`BlockTransposedRef`] bound to the
/// lifetime of the `&self` borrow.
impl<'this, T, const GROUP: usize, const PACK: usize> Reborrow<'this>
    for BlockTransposed<T, GROUP, PACK>
where
    T: Copy,
{
    type Target = BlockTransposedRef<'this, T, GROUP, PACK>;

    /// Returns `self.as_view()`.
    #[inline]
    fn reborrow(&'this self) -> Self::Target {
        self.as_view()
    }
}

// ── Factory methods ──────────────────────────────────────────────

impl<T: Copy + Default, const GROUP: usize, const PACK: usize> BlockTransposed<T, GROUP, PACK> {
Expand Down Expand Up @@ -1676,6 +1709,15 @@ mod tests {
}
}

// ── available_rows() returns padded row count ────────────

assert_eq!(
transpose.as_view().available_rows(),
padded_nrows,
"available_rows() mismatch -- {}",
context,
);

// ── from_matrix_view produces identical results ──────────

if nrows > 0 && ncols > 0 {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

//! Simple kernel implementation of multi-vector distance computation.
//! Fallback kernel implementation of multi-vector distance computation.

use std::ops::Deref;

Expand Down Expand Up @@ -49,17 +49,17 @@ impl<'a, T: Repr> Deref for QueryMatRef<'a, T> {
}
}

//////////////////
// SimpleKernel //
//////////////////
////////////////////
// FallbackKernel //
////////////////////

/// Simple double-loop kernel to compute max-sim distances over multi-vectors.
/// Fallback double-loop kernel to compute max-sim distances over multi-vectors.
///
/// This kernel performs a simple double-loop over the rows of `query`
/// and the `doc` and dispatches to [`InnerProduct`] to compute the similarity.
pub struct SimpleKernel;
pub struct FallbackKernel;

impl SimpleKernel {
impl FallbackKernel {
/// Core kernel for computing per-query-vector max similarities (min negated inner-product).
///
/// For each `query` vector, computes the maximum similarity (negated inner product)
Expand Down Expand Up @@ -128,7 +128,7 @@ where
return Err(MaxSimError::InvalidBufferLength(size, n_queries));
}

SimpleKernel::max_sim_kernel(query, doc, |i, score| {
FallbackKernel::max_sim_kernel(query, doc, |i, score| {
// SAFETY: We asserted that self.size() == query.num_vectors(),
// and i < query.num_vectors() due to the kernel loop bound.
unsafe { *self.scores.get_unchecked_mut(i) = score };
Expand All @@ -151,7 +151,7 @@ where
fn evaluate(query: QueryMatRef<'_, Standard<T>>, doc: MatRef<'_, Standard<T>>) -> f32 {
let mut sum = 0.0f32;

SimpleKernel::max_sim_kernel(query, doc, |_i, score| {
FallbackKernel::max_sim_kernel(query, doc, |_i, score| {
sum += score;
});

Expand Down Expand Up @@ -270,8 +270,8 @@ mod tests {
);
}

// Check that SimpleKernel is also correct.
SimpleKernel::max_sim_kernel(query, doc, |i, score| {
// Check that FallbackKernel is also correct.
FallbackKernel::max_sim_kernel(query, doc, |i, score| {
assert!((scores[i] - score).abs() <= 1e-6)
});

Expand Down Expand Up @@ -299,7 +299,7 @@ mod tests {
// No query vectors means sum is 0
assert_eq!(result, 0.0);

let result = Chamfer::evaluate(doc.into(), query.deref().reborrow());
let result = Chamfer::evaluate(QueryMatRef::from(doc), query.deref().reborrow());

assert_eq!(result, 0.0);
}
Expand Down
Loading
Loading