Skip to content

Replace sgemm() panics with Result-based error handling#997

Open
arrayka wants to merge 9 commits intomainfrom
u/arrayka/sgemm_panic
Open

Replace sgemm() panics with Result-based error handling#997
arrayka wants to merge 9 commits intomainfrom
u/arrayka/sgemm_panic

Conversation

@arrayka
Copy link
Copy Markdown
Contributor

@arrayka arrayka commented Apr 29, 2026

Summary

Refactors diskann_linalg::sgemm() to return Result<(), SgemmError> instead of panicking on invalid inputs. This change improves error handling and adds overflow protection for matrix dimension calculations.

Principle: DiskANN must not panic in production for recoverable errors; panics are reserved for catastrophic failures indicating undefined behavior.

Changes

1. New Error Type (SgemmError)

  • Added SgemmError enum with two variants:
    • InvalidMatrixDimensions: Matrix has incorrect dimensions
    • DimensionOverflow: Multiplication would overflow usize
  • Implemented Display and std::error::Error traits for proper error reporting

2. Updated sgemm() Function Signature

Before:

pub fn sgemm(...) { ... }

After:

pub fn sgemm(...) -> Result<(), SgemmError> { ... }

3. Overflow Protection

Replaced unsafe multiplication with checked_mul() for all dimension calculations:

  • m * k validation (matrix a)
  • k * n validation (matrix b)
  • m * n validation (matrix c)

This prevents integer overflow vulnerabilities that could lead to undefined behavior.

4. Updated Call Sites

Updated all callers of sgemm() to handle the Result:

  • diskann-quantization: Propagates errors through TransformFailed::SgemmError variant
  • diskann-disk: Converts errors to ANNError with descriptive messages
  • diskann-linalg tests: Added wrapper function for compatibility

5. Comprehensive Test Coverage

Added 6 new unit tests:

  • test_sgemm_invalid_matrix_a_dimensions: Validates dimension check for matrix a
  • test_sgemm_invalid_matrix_b_dimensions: Validates dimension check for matrix b
  • test_sgemm_invalid_matrix_c_dimensions: Validates dimension check for matrix c
  • test_sgemm_m_times_k_overflow: Tests overflow detection for m * k
  • test_sgemm_k_times_n_overflow: Tests overflow detection for k * n
  • test_sgemm_m_times_n_overflow: Tests overflow detection for m * n

Review Order

  1. Core Changes (5 min)

📁 diskann-linalg/src/lib.rs

  • Lines 15-50: New SgemmError type
  • Lines 155-202: Overflow protection with checked_mul()
  1. Tests (3 min)

📁 diskann-linalg/src/lib.rs (lines 318-461)

  • 6 new tests covering all error variants
  • Verify error messages are tested exactly
  1. Call Sites (2 min)

📁 diskann-disk/src/utils/math_util.rs (lines 164-204)

  • Three .map_err() conversions to ANNError

@arrayka arrayka linked an issue Apr 29, 2026 that may be closed by this pull request
@codecov-commenter
Copy link
Copy Markdown

codecov-commenter commented Apr 29, 2026

Codecov Report

❌ Patch coverage is 97.31183% with 5 lines in your changes missing coverage. Please review.
✅ Project coverage is 89.51%. Comparing base (767879f) to head (37d4543).

Files with missing lines Patch % Lines
diskann-quantization/src/spherical/quantizer.rs 0.00% 3 Missing ⚠️
diskann-linalg/src/lib.rs 99.44% 1 Missing ⚠️
...ation/src/algorithms/transforms/random_rotation.rs 0.00% 1 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #997      +/-   ##
==========================================
+ Coverage   89.48%   89.51%   +0.02%     
==========================================
  Files         448      448              
  Lines       84095    84260     +165     
==========================================
+ Hits        75250    75422     +172     
+ Misses       8845     8838       -7     
Flag Coverage Δ
miri 89.51% <97.31%> (+0.02%) ⬆️
unittests 89.35% <97.31%> (+0.02%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files with missing lines Coverage Δ
diskann-disk/src/utils/math_util.rs 98.82% <100.00%> (+<0.01%) ⬆️
...nn-quantization/src/algorithms/transforms/utils.rs 90.24% <ø> (ø)
diskann-linalg/src/lib.rs 98.55% <99.44%> (+5.43%) ⬆️
...ation/src/algorithms/transforms/random_rotation.rs 99.25% <0.00%> (-0.75%) ⬇️
diskann-quantization/src/spherical/quantizer.rs 95.06% <0.00%> (-0.22%) ⬇️

... and 1 file with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@arrayka arrayka marked this pull request as ready for review April 29, 2026 03:03
@arrayka arrayka requested review from a team and Copilot April 29, 2026 03:03
@arrayka arrayka enabled auto-merge (squash) April 29, 2026 03:04
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR refactors diskann_linalg::sgemm() to return Result<(), SgemmError> instead of panicking, adds overflow-safe dimension validation via checked_mul(), and updates downstream call sites/tests to handle the new error path.

Changes:

  • Introduces SgemmError and converts sgemm()’s dimension checks from assert_eq! to Result-based validation with overflow protection.
  • Updates callers (notably diskann-disk and diskann-quantization’s linalg-backed transform) to handle sgemm() errors.
  • Adds unit tests in diskann-linalg covering invalid dimensions and overflow cases.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
diskann-linalg/src/lib.rs Adds SgemmError, changes sgemm() to return Result, adds overflow checks and new unit tests.
diskann-disk/src/utils/math_util.rs Converts 3 sgemm() call sites to propagate errors as ANNError.
diskann-quantization/src/algorithms/transforms/utils.rs Adds TransformFailed::SgemmError (cfg’d on linalg) and removes unconditional Copy derive.
diskann-quantization/src/algorithms/transforms/random_rotation.rs Propagates sgemm() errors via ? from transform_into().
diskann-quantization/src/spherical/quantizer.rs Adds match arms for TransformFailed::SgemmError, but currently handles it with panic!.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +11 to +21
#[derive(Debug, Clone, Error, PartialEq)]
pub enum TransformFailed {
#[error("incorrect transform input vector - expected length {expected} but got {found}")]
SourceMismatch { expected: usize, found: usize },
#[error("incorrect transform output vector - expected length {expected} but got {found}")]
DestinationMismatch { expected: usize, found: usize },
#[error(transparent)]
AllocatorError(#[from] AllocatorError),
#[cfg(feature = "linalg")]
#[error("SGEMM operation failed: {0}")]
SgemmError(#[from] diskann_linalg::SgemmError),
Copy link

Copilot AI Apr 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TransformFailed no longer derives Copy. If the goal is only to support SgemmError when feature = "linalg" is enabled, consider conditionally deriving Copy when the feature is disabled (e.g., cfg_attr(not(feature = "linalg"), derive(Copy))) to avoid an unnecessary API regression for non-linalg builds.

Copilot uses AI. Check for mistakes.
Comment on lines +765 to +766
Err(TransformFailed::SgemmError(_)) => {
panic!("SGEMM should not fail with valid dimensions - this is a logic error");
Copy link

Copilot AI Apr 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TransformFailed::SgemmError is handled via panic!, which reintroduces a production panic path for a recoverable error coming from diskann_linalg::sgemm (notably DimensionOverflow can be triggered by large/untrusted dimensions). Instead of panicking here, propagate the error (e.g., add a CompressionError variant for SGEMM failures, or map it into an existing error type) so callers can handle it without crashing.

Suggested change
Err(TransformFailed::SgemmError(_)) => {
panic!("SGEMM should not fail with valid dimensions - this is a logic error");
Err(TransformFailed::SgemmError(err)) => {
return Err(CompressionError::SgemmError(err));

Copilot uses AI. Check for mistakes.
Comment on lines +843 to +844
Err(TransformFailed::SgemmError(_)) => {
panic!("SGEMM should not fail with valid dimensions - this is a logic error");
Copy link

Copilot AI Apr 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This match arm panics on TransformFailed::SgemmError, which defeats the Result-based error handling in diskann_linalg::sgemm and can crash the process on DimensionOverflow / invalid dims. Please propagate or convert the error instead of panicking (similar to how AllocatorError is handled).

Suggested change
Err(TransformFailed::SgemmError(_)) => {
panic!("SGEMM should not fail with valid dimensions - this is a logic error");
Err(TransformFailed::SgemmError(err)) => {
return Err(err.into());

Copilot uses AI. Check for mistakes.
Comment on lines +1167 to +1168
Err(TransformFailed::SgemmError(_)) => {
panic!("SGEMM should not fail with valid dimensions - this is a logic error");
Copy link

Copilot AI Apr 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Panicking on TransformFailed::SgemmError reintroduces a runtime panic path for what is now an explicit error condition (diskann_linalg::SgemmError). To match the PR’s stated goal (no panics for recoverable input issues), this should be returned/propagated as an error rather than crashing.

Suggested change
Err(TransformFailed::SgemmError(_)) => {
panic!("SGEMM should not fail with valid dimensions - this is a logic error");
Err(TransformFailed::SgemmError(err)) => {
return Err(err.into());

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor

@hildebrandmw hildebrandmw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Alex - just a few general error handling comments from my end.

Comment thread diskann-linalg/src/lib.rs
expected_cols: usize,
expected_len: usize,
actual_len: usize,
},
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a large error type (~48 bytes: 32 bytes for the usize values and 16 bytes for &'static str). This is mainly a concern when it gets rolled into higher level errors, such as TransformFailed. We can do a little bit to bring it down in size.

First, expected_len can be computed from expected_rows and expected_cols, so we can save 8 bytes there. Next, an enum can be used for matrix_name:

#[derive(Debug, Clone, Copy)]
pub enum MatrixName {
    A,
    B,
    C,
}

impl std::fmt::Display for MatrixName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::A => write!(f, "a"),
            Self::B => write!(f, "b"),
            Self::C => write!(f, "c"), 
        }
    }
}

This will drop the size of the error down to 32 bytes, which is still a little big, but better.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Heh, that was my original plan, but I decided that I'm overengineering it.
Fixed.

Comment thread diskann-linalg/src/lib.rs
c: &mut [f32],
) {
sgemm(atranspose, btranspose, m, n, k, alpha, a, b, beta, c).unwrap();
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should update the signature of TestProblem::check to be Result<(), SgemmError> for uniformity. If you don't want to repeat all the checking logic, a trick I've used in the past is something like this:

impl SgemmError {
    fn check(m: usize, n: usize, ...) -> Result<(), Self>;
}

Then everything preflight with

fn sgemm(...) -> Result<(), SgemmError> {
    SgemmError::check(args...)?;
    sgemm_impl(...);
    Ok(())
}

Comment thread diskann-quantization/src/algorithms/transforms/utils.rs Outdated
Comment thread diskann-disk/src/utils/math_util.rs Outdated
dist_matrix,
);
)
.map_err(|e| ANNError::log_index_error(format_args!("{}", e)))?;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For future reference, you may want to prefer ANNError::new(ANNErrorKind::IndexError, e) instead of the old log* style constructors. The former defers any kind of string formatting until the entire error is formatter higher in the call stack, where-as the latter forces eager string evaluation lower in the callstack. Not that it matters that much though.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done + created a task to remove the old log-style constructors: #1003

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

diskann_linalg::sgemm() should not panic, but should return Result()

5 participants