Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 27 additions & 38 deletions diskann-benchmark-runner/src/benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::{
/// Benchmarks consist of an [`Input`] and a corresponding serialized `Output`. Inputs will
/// first be validated with the benchmark using [`try_match`](Self::try_match). Only
/// successful matches will be passed to [`run`](Self::run).
pub trait Benchmark {
pub trait Benchmark: 'static {
Comment thread
hildebrandmw marked this conversation as resolved.
/// The [`Input`] type this benchmark matches against.
type Input: Input + 'static;

Expand All @@ -32,14 +32,15 @@ pub trait Benchmark {
/// On failure, returns `Err(FailureScore)`. In the [`crate::registry::Benchmarks`]
/// registry, [`FailureScore`]s will be used to rank the "nearest misses". Implementations
/// are encouraged to generate ranked [`FailureScore`]s to assist in user level debugging.
fn try_match(input: &Self::Input) -> Result<MatchScore, FailureScore>;
fn try_match(&self, input: &Self::Input) -> Result<MatchScore, FailureScore>;

/// Return descriptive information about the benchmark.
///
/// If `input` is `None`, then high level information about the benchmark should be relayed.
/// If `input` is `Some`, and is an unsuccessful match, diagnostic information about what
/// was expected should be generated to help users.
fn description(
&self,
f: &mut std::fmt::Formatter<'_>,
input: Option<&Self::Input>,
) -> std::fmt::Result;
Expand All @@ -52,6 +53,7 @@ pub trait Benchmark {
///
/// Implementors may assume that [`Self::try_match`] returned `Ok` on `input`.
fn run(
&self,
input: &Self::Input,
checkpoint: Checkpoint<'_>,
output: &mut dyn Output,
Expand Down Expand Up @@ -88,6 +90,7 @@ pub trait Regression: Benchmark<Output: for<'a> Deserialize<'a>> {
/// stream. Instead, all diagnostics should be encoded in the returned [`PassFail`] type
/// for reporting upstream.
fn check(
&self,
tolerances: &Self::Tolerances,
input: &Self::Input,
before: &Self::Output,
Expand All @@ -109,8 +112,6 @@ pub enum PassFail<P, F> {
pub(crate) mod internal {
use super::*;

use std::marker::PhantomData;

use anyhow::Context;
use thiserror::Error;

Expand Down Expand Up @@ -176,38 +177,32 @@ pub(crate) mod internal {
}
}

pub(crate) trait AsRegression {
fn as_regression(&self) -> Option<&dyn Regression>;
pub(crate) trait AsRegression<T> {
fn as_regression(benchmark: &T) -> Option<&dyn Regression>;
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Copy)]
pub(crate) struct NoRegression;

impl AsRegression for NoRegression {
fn as_regression(&self) -> Option<&dyn Regression> {
impl<T> AsRegression<T> for NoRegression {
fn as_regression(_benchmark: &T) -> Option<&dyn Regression> {
None
}
}

#[derive(Debug, Clone, Copy)]
pub(crate) struct WithRegression<T>(PhantomData<T>);
pub(crate) struct WithRegression;

impl<T> WithRegression<T> {
pub(crate) const fn new() -> Self {
Self(PhantomData)
}
}

impl<T> AsRegression for WithRegression<T>
impl<T> AsRegression<T> for WithRegression
where
T: super::Regression,
{
fn as_regression(&self) -> Option<&dyn Regression> {
Some(self)
fn as_regression(benchmark: &T) -> Option<&dyn Regression> {
Some(benchmark)
}
}

impl<T> Regression for WithRegression<T>
impl<T> Regression for T
where
T: super::Regression,
{
Comment thread
hildebrandmw marked this conversation as resolved.
Expand Down Expand Up @@ -242,7 +237,7 @@ pub(crate) mod internal {
let after = T::Output::deserialize(after)
.map_err(|err| DeserializationError::new(Kind::After, err))?;

let passfail = match T::check(tolerance, input, &before, &after)? {
let passfail = match self.check(tolerance, input, &before, &after)? {
Comment thread
hildebrandmw marked this conversation as resolved.
PassFail::Pass(pass) => PassFail::Pass(Checked::new(pass)?),
PassFail::Fail(fail) => PassFail::Fail(Checked::new(fail)?),
};
Expand All @@ -253,21 +248,15 @@ pub(crate) mod internal {

#[derive(Debug, Clone, Copy)]
pub(crate) struct Wrapper<T, R = NoRegression> {
regression: R,
_type: PhantomData<T>,
}

impl<T> Wrapper<T, NoRegression> {
pub(crate) const fn new() -> Self {
Self::new_with(NoRegression)
}
benchmark: T,
_regression: R,
}

impl<T, R> Wrapper<T, R> {
pub(crate) const fn new_with(regression: R) -> Self {
pub(crate) const fn new(benchmark: T, regression: R) -> Self {
Self {
regression,
_type: PhantomData,
benchmark,
_regression: regression,
}
}
}
Expand All @@ -278,11 +267,11 @@ pub(crate) mod internal {
impl<T, R> Benchmark for Wrapper<T, R>
where
T: super::Benchmark,
R: AsRegression,
R: AsRegression<T>,
{
fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore> {
if let Some(cast) = input.downcast_ref::<T::Input>() {
T::try_match(cast)
self.benchmark.try_match(cast)
} else {
Err(MATCH_FAIL)
}
Expand All @@ -295,7 +284,7 @@ pub(crate) mod internal {
) -> std::fmt::Result {
match input {
Some(input) => match input.downcast_ref::<T::Input>() {
Some(cast) => T::description(f, Some(cast)),
Some(cast) => self.benchmark.description(f, Some(cast)),
None => write!(
f,
"expected tag \"{}\" - instead got \"{}\"",
Expand All @@ -305,7 +294,7 @@ pub(crate) mod internal {
},
None => {
writeln!(f, "tag \"{}\"", <T::Input as Input>::tag())?;
T::description(f, None)
self.benchmark.description(f, None)
}
}
}
Expand All @@ -318,7 +307,7 @@ pub(crate) mod internal {
) -> anyhow::Result<serde_json::Value> {
match input.downcast_ref::<T::Input>() {
Some(input) => {
let result = T::run(input, checkpoint, output)?;
let result = self.benchmark.run(input, checkpoint, output)?;
Ok(serde_json::to_value(result)?)
}
None => Err(BadDownCast::new(T::Input::tag(), input.tag()).into()),
Expand All @@ -327,7 +316,7 @@ pub(crate) mod internal {

// Extensions
fn as_regression(&self) -> Option<&dyn Regression> {
self.regression.as_regression()
R::as_regression(&self.benchmark)
}
}

Expand Down
18 changes: 11 additions & 7 deletions diskann-benchmark-runner/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,16 @@ impl Benchmarks {
}

/// Register a new benchmark with the given name.
pub fn register<T>(&mut self, name: impl Into<String>)
pub fn register<T>(&mut self, name: impl Into<String>, benchmark: T)
where
T: Benchmark + 'static,
T: Benchmark,
{
self.benchmarks.push(RegisteredBenchmark {
name: name.into(),
benchmark: Box::new(benchmark::internal::Wrapper::<T>::new()),
benchmark: Box::new(benchmark::internal::Wrapper::<T, _>::new(
benchmark,
benchmark::internal::NoRegression,
)),
});
}

Expand Down Expand Up @@ -212,12 +215,13 @@ impl Benchmarks {
///
/// Upon registration, the associated [`Regression::Tolerances`] input and the benchmark
/// itself will be reachable via [`Check`](crate::app::Check).
pub fn register_regression<T>(&mut self, name: impl Into<String>)
pub fn register_regression<T>(&mut self, name: impl Into<String>, benchmark: T)
where
T: Regression + 'static,
T: Regression,
{
let registered = benchmark::internal::Wrapper::<T, _>::new_with(
benchmark::internal::WithRegression::<T>::new(),
let registered = benchmark::internal::Wrapper::<T, _>::new(
benchmark,
benchmark::internal::WithRegression,
);
self.benchmarks.push(RegisteredBenchmark {
name: name.into(),
Expand Down
19 changes: 15 additions & 4 deletions diskann-benchmark-runner/src/test/dim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,19 @@ impl Benchmark for SimpleBench {
type Input = DimInput;
type Output = usize;

fn try_match(input: &DimInput) -> Result<MatchScore, FailureScore> {
fn try_match(&self, input: &DimInput) -> Result<MatchScore, FailureScore> {
if input.dim.is_none() {
Ok(MatchScore(0))
} else {
Err(FailureScore(1000))
}
}

fn description(f: &mut std::fmt::Formatter<'_>, input: Option<&DimInput>) -> std::fmt::Result {
fn description(
&self,
f: &mut std::fmt::Formatter<'_>,
input: Option<&DimInput>,
) -> std::fmt::Result {
match input {
Some(input) if input.dim.is_none() => write!(f, "successful match"),
Some(_) => write!(f, "expected dim=None"),
Expand All @@ -116,6 +120,7 @@ impl Benchmark for SimpleBench {
}

fn run(
&self,
input: &DimInput,
_checkpoint: Checkpoint<'_>,
mut output: &mut dyn Output,
Expand All @@ -133,11 +138,15 @@ impl Benchmark for DimBench {
type Input = DimInput;
type Output = usize;

fn try_match(_input: &DimInput) -> Result<MatchScore, FailureScore> {
fn try_match(&self, _input: &DimInput) -> Result<MatchScore, FailureScore> {
Ok(MatchScore(0))
}

fn description(f: &mut std::fmt::Formatter<'_>, input: Option<&DimInput>) -> std::fmt::Result {
fn description(
&self,
f: &mut std::fmt::Formatter<'_>,
input: Option<&DimInput>,
) -> std::fmt::Result {
if input.is_some() {
write!(f, "perfect match")
} else {
Expand All @@ -146,6 +155,7 @@ impl Benchmark for DimBench {
}

fn run(
&self,
input: &DimInput,
_checkpoint: Checkpoint<'_>,
mut output: &mut dyn Output,
Expand All @@ -161,6 +171,7 @@ impl Regression for DimBench {
type Fail = &'static str;

fn check(
&self,
tolerance: &Tolerance,
input: &DimInput,
before: &usize,
Expand Down
13 changes: 8 additions & 5 deletions diskann-benchmark-runner/src/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@ pub fn register_inputs(inputs: &mut registry::Inputs) -> anyhow::Result<()> {
}

pub fn register_benchmarks(benchmarks: &mut registry::Benchmarks) {
benchmarks.register_regression::<typed::TypeBench<f32>>("type-bench-f32");
benchmarks.register_regression::<typed::TypeBench<i8>>("type-bench-i8");
benchmarks.register_regression::<typed::ExactTypeBench<f32, 1000>>("exact-type-bench-f32-1000");
benchmarks.register_regression("type-bench-f32", typed::TypeBench::<f32>::new());
benchmarks.register_regression("type-bench-i8", typed::TypeBench::<i8>::new());
benchmarks.register_regression(
"exact-type-bench-f32-1000",
typed::ExactTypeBench::<f32, 1000>::new(),
);

benchmarks.register::<dim::SimpleBench>("simple-bench");
benchmarks.register_regression::<dim::DimBench>("dim-bench");
benchmarks.register("simple-bench", dim::SimpleBench);
benchmarks.register_regression("dim-bench", dim::DimBench);
}
Loading
Loading