-
Notifications
You must be signed in to change notification settings - Fork 58
Accelerate mean() with a dedicated fast sum() path #741
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -164,6 +164,118 @@ struct select_real_t<Complex<U>> | |
| using type = U; | ||
| }; | ||
|
|
||
| template <typename A, typename T> | ||
| class SimpleArrayMixinSum | ||
| { | ||
|
|
||
| private: | ||
|
|
||
| using internal_types = detail::SimpleArrayInternalTypes<T>; | ||
|
|
||
| public: | ||
|
|
||
| using value_type = typename internal_types::value_type; | ||
|
|
||
| value_type sum() const | ||
| { | ||
| auto athis = static_cast<A const *>(this); | ||
| const size_t n = athis->size(); | ||
| if (n == 0) | ||
| { | ||
| return zero(); | ||
| } | ||
| // Either C- or F-contiguous arrays occupy a single dense block in | ||
| // memory, so summation order is irrelevant and we can sweep the | ||
| // buffer linearly regardless of axis order. | ||
| if (athis->is_c_contiguous() || athis->is_f_contiguous()) | ||
| { | ||
| return sum_contiguous(athis->data(), n); | ||
| } | ||
| return sum_strided(athis->data(), athis->shape(), athis->stride()); | ||
|
Comment on lines
+190
to
+194
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Properly handle the sum. |
||
| } | ||
|
|
||
| private: | ||
|
|
||
| static constexpr value_type zero() | ||
| { | ||
| if constexpr (is_complex_v<value_type>) | ||
| { | ||
| return value_type{}; | ||
| } | ||
| else | ||
| { | ||
| return value_type{0}; | ||
| } | ||
| } | ||
|
|
||
| static void accumulate(value_type & acc, value_type v) | ||
| { | ||
| if constexpr (std::is_same_v<bool, std::remove_const_t<value_type>>) | ||
| { | ||
| acc |= v; | ||
| } | ||
| else | ||
| { | ||
| acc += v; | ||
| } | ||
| } | ||
|
|
||
| static value_type sum_contiguous(value_type const * data, size_t n) | ||
| { | ||
| value_type acc = zero(); | ||
| for (size_t i = 0; i < n; ++i) | ||
| { | ||
| accumulate(acc, data[i]); | ||
| } | ||
| return acc; | ||
| } | ||
|
|
||
| // Walk a strided array by its innermost dimension: compute the row base | ||
| // offset once per outer iteration, then accumulate along the last axis. | ||
| // This avoids the per-element multi-dimensional index arithmetic that | ||
| // at(sidx) performs. | ||
| static value_type sum_strided(value_type const * data, | ||
| small_vector<size_t> const & shape, | ||
| small_vector<size_t> const & stride) | ||
| { | ||
| const size_t ndim = shape.size(); | ||
| const size_t last_dim = shape[ndim - 1]; | ||
| const size_t last_stride = stride[ndim - 1]; | ||
|
|
||
| value_type acc = zero(); | ||
| small_vector<size_t> prefix(ndim - 1, 0); | ||
| do | ||
| { | ||
| size_t offset = 0; | ||
| for (size_t i = 0; i + 1 < ndim; ++i) | ||
| { | ||
| offset += prefix[i] * stride[i]; | ||
| } | ||
| value_type const * row = data + offset; | ||
| for (size_t j = 0; j < last_dim; ++j) | ||
| { | ||
| accumulate(acc, row[j * last_stride]); | ||
| } | ||
| } while (next_prefix(prefix, shape)); | ||
| return acc; | ||
| } | ||
|
|
||
| static bool next_prefix(small_vector<size_t> & idx, | ||
| small_vector<size_t> const & shape) | ||
| { | ||
| for (size_t i = idx.size(); i > 0; --i) | ||
| { | ||
| if (++idx[i - 1] < shape[i - 1]) | ||
| { | ||
| return true; | ||
| } | ||
| idx[i - 1] = 0; | ||
| } | ||
| return false; | ||
| } | ||
|
|
||
| }; /* end class SimpleArrayMixinSum */ | ||
|
|
||
| template <typename A, typename T> | ||
| class SimpleArrayMixinCalculators | ||
| { | ||
|
|
@@ -365,15 +477,12 @@ class SimpleArrayMixinCalculators | |
| value_type mean() const | ||
| { | ||
| auto athis = static_cast<A const *>(this); | ||
| auto sidx = athis->first_sidx(); | ||
| value_type sum = 0; | ||
| int64_t total = 0; | ||
| do | ||
| const size_t n = athis->size(); | ||
| if (n == 0) | ||
| { | ||
| sum += athis->at(sidx); | ||
| ++total; | ||
| } while (athis->next_sidx(sidx)); | ||
| return sum / static_cast<value_type>(total); | ||
| throw std::runtime_error("SimpleArray::mean(): empty array"); | ||
| } | ||
| return athis->sum() / static_cast<value_type>(n); | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Meat. |
||
| } | ||
|
|
||
| real_type var_op(small_vector<value_type> & sv, size_t ddof) const | ||
|
|
@@ -488,36 +597,6 @@ class SimpleArrayMixinCalculators | |
| return initial; | ||
| } | ||
|
|
||
| value_type sum() const | ||
| { | ||
| value_type initial; | ||
| if constexpr (is_complex_v<value_type>) | ||
| { | ||
| initial = value_type(); | ||
| } | ||
| else | ||
| { | ||
| initial = 0; | ||
| } | ||
|
|
||
| auto athis = static_cast<A const *>(this); | ||
| if constexpr (!std::is_same_v<bool, std::remove_const_t<value_type>>) | ||
| { | ||
| for (size_t i = 0; i < athis->size(); ++i) | ||
| { | ||
| initial += athis->data(i); | ||
| } | ||
| } | ||
| else | ||
| { | ||
| for (size_t i = 0; i < athis->size(); ++i) | ||
| { | ||
| initial |= athis->data(i); | ||
| } | ||
| } | ||
| return initial; | ||
| } | ||
|
|
||
| A abs() const | ||
| { | ||
| auto athis = static_cast<A const *>(this); | ||
|
|
@@ -1382,6 +1461,7 @@ struct with_alignment_t | |
| template <typename T> | ||
| class SimpleArray | ||
| : public detail::SimpleArrayMixinModifiers<SimpleArray<T>, T> | ||
| , public detail::SimpleArrayMixinSum<SimpleArray<T>, T> | ||
| , public detail::SimpleArrayMixinCalculators<SimpleArray<T>, T> | ||
| , public detail::SimpleArrayMixinSort<SimpleArray<T>, T> | ||
| , public detail::SimpleArrayMixinSearch<SimpleArray<T>, T> | ||
|
|
@@ -1877,37 +1957,60 @@ class SimpleArray | |
| value_type const * body() const { return m_body; } | ||
| value_type * body() { return m_body; } | ||
|
|
||
| bool is_c_contiguous() const { return is_c_contiguous(m_shape, m_stride); } | ||
| bool is_f_contiguous() const { return is_f_contiguous(m_shape, m_stride); } | ||
|
|
||
| private: | ||
| void check_c_contiguous(small_vector<size_t> const & shape, | ||
| small_vector<size_t> const & stride) const | ||
| static bool is_c_contiguous(small_vector<size_t> const & shape, | ||
| small_vector<size_t> const & stride) | ||
| { | ||
| if (stride[stride.size() - 1] != 1) | ||
| { | ||
| throw std::runtime_error("SimpleArray: C contiguous stride must end with 1"); | ||
| return false; | ||
| } | ||
| for (size_t it = 0; it < shape.size() - 1; ++it) | ||
| { | ||
| if (stride[it] != shape[it + 1] * stride[it + 1]) | ||
| { | ||
| throw std::runtime_error("SimpleArray: C contiguous stride must match shape"); | ||
| return false; | ||
| } | ||
| } | ||
| return true; | ||
| } | ||
|
|
||
| void check_f_contiguous(small_vector<size_t> const & shape, | ||
| small_vector<size_t> const & stride) const | ||
| static bool is_f_contiguous(small_vector<size_t> const & shape, | ||
| small_vector<size_t> const & stride) | ||
| { | ||
| if (stride[0] != 1) | ||
| { | ||
| throw std::runtime_error("SimpleArray: Fortran contiguous stride must start with 1"); | ||
| return false; | ||
| } | ||
| for (size_t it = 0; it < shape.size() - 1; ++it) | ||
| { | ||
| if (stride[it + 1] != shape[it] * stride[it]) | ||
| { | ||
| throw std::runtime_error("SimpleArray: Fortran contiguous stride must match shape"); | ||
| return false; | ||
| } | ||
| } | ||
| return true; | ||
| } | ||
|
|
||
| static void check_c_contiguous(small_vector<size_t> const & shape, | ||
| small_vector<size_t> const & stride) | ||
| { | ||
| if (!is_c_contiguous(shape, stride)) | ||
| { | ||
| throw std::runtime_error("SimpleArray: C contiguous stride must match shape and end with 1"); | ||
| } | ||
| } | ||
|
|
||
| void check_f_contiguous(small_vector<size_t> const & shape, | ||
| small_vector<size_t> const & stride) const | ||
| { | ||
| if (!is_f_contiguous(shape, stride)) | ||
| { | ||
| throw std::runtime_error("SimpleArray: F contiguous stride must match shape and start with 1"); | ||
| } | ||
| } | ||
|
|
||
| void validate_range(ssize_t it) const | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1443,6 +1443,29 @@ def test_minmaxsum(self): | |
| self.assertEqual(sarr.min(), -2.3) | ||
| self.assertEqual(sarr.max(), 9.2) | ||
|
|
||
| def test_sum_non_contiguous(self): | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a few test cases to increase the coverage.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add tests for contiguous memory too. I did not find |
||
| # Strided slice that fails both C- and F-contiguity checks, so | ||
| # sum() must take the sum_strided path. Distinct integer values | ||
| # mean any indexing bug shifts the result. | ||
| nparr = np.arange(625, dtype='float64').reshape((5, 5, 5, 5)) | ||
| nparr = nparr[:3:2, 1:4:2, :3:3, 3:4:2] | ||
| sarr = modmesh.SimpleArrayFloat64(array=nparr) | ||
| self.assertEqual(sarr.sum(), np.sum(nparr)) | ||
|
|
||
| def test_sum_empty_non_contiguous(self): | ||
| # Empty and non-contiguous: shape (3, 0, 4) with strides | ||
| # (0, 1, 0) is neither C- nor F-contiguous, so sum() must | ||
| # short-circuit on n == 0 instead of falling into the strided | ||
| # path and reading from an empty buffer. | ||
| sarr = modmesh.SimpleArrayFloat64(shape=(3, 4, 0), value=0.0) | ||
| sarr = sarr.transpose(axis=[0, 2, 1]) | ||
| self.assertEqual(sarr.sum(), 0.0) | ||
|
|
||
| def test_mean_empty_raises(self): | ||
| sarr = modmesh.SimpleArrayFloat64(shape=(0, 3), value=0.0) | ||
| with self.assertRaisesRegex(RuntimeError, "empty array"): | ||
| sarr.mean() | ||
|
|
||
| def test_abs(self): | ||
| sarr = modmesh.SimpleArrayInt64(shape=(3, 2), value=-2) | ||
| self.assertEqual(sarr.sum(), -2 * 3 * 2) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment is wrong, please correct it.
Sweeping the buffer linearly is exactly following the axis order for C and F contiguity, respectively.
It took me a while to realize the comment is wrong.