Replace sgemm() panics with Result-based error handling#997
Replace sgemm() panics with Result-based error handling#997
Conversation
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## main #997 +/- ##
==========================================
+ Coverage 89.48% 89.51% +0.02%
==========================================
Files 448 448
Lines 84095 84260 +165
==========================================
+ Hits 75250 75422 +172
+ Misses 8845 8838 -7
Flags with carried forward coverage won't be shown. Click here to find out more.
🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Pull request overview
This PR refactors diskann_linalg::sgemm() to return Result<(), SgemmError> instead of panicking, adds overflow-safe dimension validation via checked_mul(), and updates downstream call sites/tests to handle the new error path.
Changes:
- Introduces
SgemmErrorand convertssgemm()’s dimension checks fromassert_eq!toResult-based validation with overflow protection. - Updates callers (notably
diskann-diskanddiskann-quantization’s linalg-backed transform) to handlesgemm()errors. - Adds unit tests in
diskann-linalgcovering invalid dimensions and overflow cases.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| diskann-linalg/src/lib.rs | Adds SgemmError, changes sgemm() to return Result, adds overflow checks and new unit tests. |
| diskann-disk/src/utils/math_util.rs | Converts 3 sgemm() call sites to propagate errors as ANNError. |
| diskann-quantization/src/algorithms/transforms/utils.rs | Adds TransformFailed::SgemmError (cfg’d on linalg) and removes unconditional Copy derive. |
| diskann-quantization/src/algorithms/transforms/random_rotation.rs | Propagates sgemm() errors via ? from transform_into(). |
| diskann-quantization/src/spherical/quantizer.rs | Adds match arms for TransformFailed::SgemmError, but currently handles it with panic!. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| #[derive(Debug, Clone, Error, PartialEq)] | ||
| pub enum TransformFailed { | ||
| #[error("incorrect transform input vector - expected length {expected} but got {found}")] | ||
| SourceMismatch { expected: usize, found: usize }, | ||
| #[error("incorrect transform output vector - expected length {expected} but got {found}")] | ||
| DestinationMismatch { expected: usize, found: usize }, | ||
| #[error(transparent)] | ||
| AllocatorError(#[from] AllocatorError), | ||
| #[cfg(feature = "linalg")] | ||
| #[error("SGEMM operation failed: {0}")] | ||
| SgemmError(#[from] diskann_linalg::SgemmError), |
There was a problem hiding this comment.
TransformFailed no longer derives Copy. If the goal is only to support SgemmError when feature = "linalg" is enabled, consider conditionally deriving Copy when the feature is disabled (e.g., cfg_attr(not(feature = "linalg"), derive(Copy))) to avoid an unnecessary API regression for non-linalg builds.
| Err(TransformFailed::SgemmError(_)) => { | ||
| panic!("SGEMM should not fail with valid dimensions - this is a logic error"); |
There was a problem hiding this comment.
TransformFailed::SgemmError is handled via panic!, which reintroduces a production panic path for a recoverable error coming from diskann_linalg::sgemm (notably DimensionOverflow can be triggered by large/untrusted dimensions). Instead of panicking here, propagate the error (e.g., add a CompressionError variant for SGEMM failures, or map it into an existing error type) so callers can handle it without crashing.
| Err(TransformFailed::SgemmError(_)) => { | |
| panic!("SGEMM should not fail with valid dimensions - this is a logic error"); | |
| Err(TransformFailed::SgemmError(err)) => { | |
| return Err(CompressionError::SgemmError(err)); |
| Err(TransformFailed::SgemmError(_)) => { | ||
| panic!("SGEMM should not fail with valid dimensions - this is a logic error"); |
There was a problem hiding this comment.
This match arm panics on TransformFailed::SgemmError, which defeats the Result-based error handling in diskann_linalg::sgemm and can crash the process on DimensionOverflow / invalid dims. Please propagate or convert the error instead of panicking (similar to how AllocatorError is handled).
| Err(TransformFailed::SgemmError(_)) => { | |
| panic!("SGEMM should not fail with valid dimensions - this is a logic error"); | |
| Err(TransformFailed::SgemmError(err)) => { | |
| return Err(err.into()); |
| Err(TransformFailed::SgemmError(_)) => { | ||
| panic!("SGEMM should not fail with valid dimensions - this is a logic error"); |
There was a problem hiding this comment.
Panicking on TransformFailed::SgemmError reintroduces a runtime panic path for what is now an explicit error condition (diskann_linalg::SgemmError). To match the PR’s stated goal (no panics for recoverable input issues), this should be returned/propagated as an error rather than crashing.
| Err(TransformFailed::SgemmError(_)) => { | |
| panic!("SGEMM should not fail with valid dimensions - this is a logic error"); | |
| Err(TransformFailed::SgemmError(err)) => { | |
| return Err(err.into()); |
hildebrandmw
left a comment
There was a problem hiding this comment.
Thanks Alex - just a few general error handling comments from my end.
| expected_cols: usize, | ||
| expected_len: usize, | ||
| actual_len: usize, | ||
| }, |
There was a problem hiding this comment.
This is a large error type (~48 bytes: 32 bytes for the usize values and 16 bytes for &'static str). This is mainly a concern when it gets rolled into higher level errors, such as TransformFailed. We can do a little bit to bring it down in size.
First, expected_len can be computed from expected_rows and expected_cols, so we can save 8 bytes there. Next, an enum can be used for matrix_name:
#[derive(Debug, Clone, Copy)]
pub enum MatrixName {
A,
B,
C,
}
impl std::fmt::Display for MatrixName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::A => write!(f, "a"),
Self::B => write!(f, "b"),
Self::C => write!(f, "c"),
}
}
}This will drop the size of the error down to 32 bytes, which is still a little big, but better.
There was a problem hiding this comment.
Heh, that was my original plan, but I decided that I'm overengineering it.
Fixed.
| c: &mut [f32], | ||
| ) { | ||
| sgemm(atranspose, btranspose, m, n, k, alpha, a, b, beta, c).unwrap(); | ||
| } |
There was a problem hiding this comment.
I wonder if we should update the signature of TestProblem::check to be Result<(), SgemmError> for uniformity. If you don't want to repeat all the checking logic, a trick I've used in the past is something like this:
impl SgemmError {
fn check(m: usize, n: usize, ...) -> Result<(), Self>;
}Then everything preflight with
fn sgemm(...) -> Result<(), SgemmError> {
SgemmError::check(args...)?;
sgemm_impl(...);
Ok(())
}| dist_matrix, | ||
| ); | ||
| ) | ||
| .map_err(|e| ANNError::log_index_error(format_args!("{}", e)))?; |
There was a problem hiding this comment.
For future reference, you may want to prefer ANNError::new(ANNErrorKind::IndexError, e) instead of the old log* style constructors. The former defers any kind of string formatting until the entire error is formatter higher in the call stack, where-as the latter forces eager string evaluation lower in the callstack. Not that it matters that much though.
There was a problem hiding this comment.
Done + created a task to remove the old log-style constructors: #1003
Summary
Refactors
diskann_linalg::sgemm()to returnResult<(), SgemmError>instead of panicking on invalid inputs. This change improves error handling and adds overflow protection for matrix dimension calculations.Principle: DiskANN must not panic in production for recoverable errors; panics are reserved for catastrophic failures indicating undefined behavior.
Changes
1. New Error Type (
SgemmError)SgemmErrorenum with two variants:InvalidMatrixDimensions: Matrix has incorrect dimensionsDimensionOverflow: Multiplication would overflowusizeDisplayandstd::error::Errortraits for proper error reporting2. Updated
sgemm()Function SignatureBefore:
After:
3. Overflow Protection
Replaced unsafe multiplication with
checked_mul()for all dimension calculations:m * kvalidation (matrix a)k * nvalidation (matrix b)m * nvalidation (matrix c)This prevents integer overflow vulnerabilities that could lead to undefined behavior.
4. Updated Call Sites
Updated all callers of
sgemm()to handle theResult:TransformFailed::SgemmErrorvariantANNErrorwith descriptive messages5. Comprehensive Test Coverage
Added 6 new unit tests:
test_sgemm_invalid_matrix_a_dimensions: Validates dimension check for matrix atest_sgemm_invalid_matrix_b_dimensions: Validates dimension check for matrix btest_sgemm_invalid_matrix_c_dimensions: Validates dimension check for matrix ctest_sgemm_m_times_k_overflow: Tests overflow detection form * ktest_sgemm_k_times_n_overflow: Tests overflow detection fork * ntest_sgemm_m_times_n_overflow: Tests overflow detection form * nReview Order
📁 diskann-linalg/src/lib.rs
SgemmErrortypechecked_mul()📁 diskann-linalg/src/lib.rs (lines 318-461)
📁 diskann-disk/src/utils/math_util.rs (lines 164-204)