Accelerate mean() with a dedicated fast sum() path #741
Open
tigercosmos wants to merge 1 commit into solvcon:master from
Add SimpleArrayMixinSum with two code paths for sum(): a contiguous path that walks the buffer directly, and a strided path that walks by the innermost dimension and computes the row base offset once per outer iteration. Rewrite mean() as sum() / size() so both share the same fast path. This also fixes a latent bug where the old sum() walked data(i) by logical index on non-contiguous arrays and read the wrong elements.
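As a rough illustration of the two paths described in the commit message, here is a small Python model. It is not the C++ code in this PR. It assumes strides are counted in elements (not bytes) and that the array has at least one dimension. The function names mirror `sum_contiguous` and `sum_strided` from the diff, but the signatures and the `offset` parameter are made up for the sketch.

```python
from itertools import product


def sum_contiguous(buf, n):
    """Contiguous path: one straight pass over the first n buffer slots."""
    total = 0
    for i in range(n):
        total += buf[i]
    return total


def sum_strided(buf, shape, stride, offset=0):
    """Strided path: loop over the innermost dimension and compute each
    row's base offset once per outer iteration (assumes ndim >= 1)."""
    if any(extent == 0 for extent in shape):
        return 0  # empty shape: nothing to read
    *outer_shape, inner_n = shape
    *outer_stride, inner_stride = stride
    total = 0
    for outer_idx in product(*(range(extent) for extent in outer_shape)):
        # Base offset of this row, computed once rather than per element.
        base = offset + sum(i * s for i, s in zip(outer_idx, outer_stride))
        for j in range(inner_n):
            total += buf[base + j * inner_stride]
    return total


# A 3x4 row-major buffer, then a view that keeps every other column.
buf = list(range(12))
view_shape, view_stride = (3, 2), (4, 2)

strided_total = sum_strided(buf, view_shape, view_stride)  # 0+2+4+6+8+10 = 30
linear_total = sum_contiguous(buf, 6)                      # 0+1+2+3+4+5 = 15
```

The last two lines show the kind of mismatch the commit message calls a latent bug: walking the first six buffer slots linearly gives a different total from following the view's strides.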
tigercosmos (Collaborator, Author) commented Apr 28, 2026
tigercosmos left a comment
@yungyuc Please take a look, thanks!
    return sum / static_cast<value_type>(total);
        throw std::runtime_error("SimpleArray::mean(): empty array");
    }
    return athis->sum() / static_cast<value_type>(n);
Comment on lines +190 to +194
    if (athis->is_c_contiguous() || athis->is_f_contiguous())
    {
        return sum_contiguous(athis->data(), n);
    }
    return sum_strided(athis->data(), athis->shape(), athis->stride());
Properly handle the sum.
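The guard visible in the `mean()` snippet above can be modeled in a few lines of Python. This is only a sketch of the behavior, not the mixin itself: `array_sum` and `array_mean` are hypothetical names, and a flat Python list stands in for the array buffer.

```python
def array_sum(values):
    """Sum of a flat sequence; an empty input yields zero without reading anything."""
    total = 0
    for v in values:
        total += v
    return total


def array_mean(values):
    """Mean defined as sum / size, with an explicit guard for empty input."""
    n = len(values)
    if n == 0:
        # Mirrors the std::runtime_error thrown in the diff instead of computing 0/0.
        raise RuntimeError("SimpleArray::mean(): empty array")
    return array_sum(values) / n
```

Because `array_mean` reuses `array_sum`, any speedup or fix in the sum path applies to both.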
        self.assertEqual(sarr.min(), -2.3)
        self.assertEqual(sarr.max(), 9.2)

    def test_sum_non_contiguous(self):
Add a few test cases to increase the coverage.
Motivation

The existing `SimpleArray::mean()` walks every element through a multi-dimensional index iterator (`first_sidx` / `next_sidx` / `at(sidx)`). That per-element index arithmetic dominates the runtime and makes `mean()` several times slower than it needs to be, especially on large arrays. This PR continues the work started in #589 and takes another pass at the same goal: make `mean()` fast without introducing threading or guessed micro-optimizations.

What changes
A new mixin `SimpleArrayMixinSum` provides `sum()` with two code paths. The contiguous path walks the buffer with a straight pointer loop so the compiler can auto-vectorize; the strided path walks by the innermost dimension, computing each row's base offset once per outer iteration instead of for every element. `mean()` now delegates to this `sum()` and simply divides by `size()`.

The refactor also fixes a latent correctness bug. The previous `sum()` iterated `data(i)` by a linear buffer index, so for a non-contiguous array (any slice or transpose) it silently summed the wrong elements. The new strided path respects the stride, and there is a regression test covering it.

What is deliberately not included
This PR drops the threading and manual loop unrolling from the earlier iteration of #589, following the review feedback on that PR. Threads belong to a larger scheduling system that does not exist yet, and a hand-rolled unroll that funnels every partial into a single accumulator does not actually help the compiler vectorize. Private helpers that do not touch `this` are marked `static`, matching the reviewer's request.

Edge cases covered
Empty arrays no longer read past the end of the buffer: `sum()` returns zero on any shape whose product is zero, and `mean()` throws on an empty input rather than producing `0/0` (NaN for floats, UB for integers). Ghost-cell behavior is unchanged: both old and new paths include ghost cells, which matches how `data()` and the old `first_sidx` iterator behaved.

Benchmarks
Times are microseconds per call, best-of-5 runs of 10 iterations each, then averaged across two independent full runs. Machine: Apple Silicon, Python 3.14, release build. "before" = `upstream/master` at 3d73a5d; "after" = this branch HEAD.

`mean()`: 3× to 6× faster

`sum()`: flat on contiguous, modest cost on strided

Note: the `sum()` "before" numbers on strided cases reflect the old buggy behavior (wrong data, read linearly along the buffer), so the comparison is about cost, not correctness. The new strided sum is correct, and on the 3D strided 100³ case it is somewhat slower than the old buggy linear walk (0.83×) because it now follows the stride properly. On contiguous shapes, `sum()` is unchanged or slightly faster.

How to reproduce the benchmark
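The timing protocol stated under Benchmarks (best of 5 runs, 10 iterations each, then averaged over two full runs) could be sketched in Python roughly as follows. This is not the author's script. `best_of` is a hypothetical helper, and the built-in `sum` over a list stands in for the array methods being measured.

```python
import time


def best_of(fn, repeats=5, iterations=10):
    """Return the smallest per-call time, in microseconds, over `repeats` runs
    of `iterations` calls each."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        for _ in range(iterations):
            fn()
        elapsed = time.perf_counter() - start
        best = min(best, elapsed / iterations * 1e6)
    return best


# Stand-in workload; a real run would call the array's mean() or sum() instead.
data = list(range(100_000))
run_a = best_of(lambda: sum(data))
run_b = best_of(lambda: sum(data))
print(f"{(run_a + run_b) / 2:.1f} us per call")
```

Taking the minimum within a run filters out scheduling noise, and averaging two independent runs gives a rough check that the minimum is stable.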
The benchmark script and raw results are not part of this PR; they live locally at `profiling/bench_mean_compare.py`. A minimal reproducer:

Procedure to get before/after numbers:

1. On `upstream/master`, build with `make buildext`, run the script, record numbers.
2. On this branch, build with `make buildext`, run the same script, record numbers.

`sa.mean` should drop by 3–6× on every shape; `sa.sum` stays about the same on contiguous shapes.

Relation to #589
This is a rework of #589 based on the review there. It keeps the core idea (a dedicated fast `sum()` path used by `mean()`), drops the parts the reviewer rejected (threads, unclear unroll), and adds the empty-array guards plus the strided-sum correctness fix.