diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index cf1f4a7c..2dd2bfc0 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -164,6 +164,118 @@ struct select_real_t> using type = U; }; +template +class SimpleArrayMixinSum +{ + +private: + + using internal_types = detail::SimpleArrayInternalTypes; + +public: + + using value_type = typename internal_types::value_type; + + value_type sum() const + { + auto athis = static_cast(this); + const size_t n = athis->size(); + if (n == 0) + { + return zero(); + } + // Either C- or F-contiguous arrays occupy a single dense block in + // memory, so summation order is irrelevant and we can sweep the + // buffer linearly regardless of axis order. + if (athis->is_c_contiguous() || athis->is_f_contiguous()) + { + return sum_contiguous(athis->data(), n); + } + return sum_strided(athis->data(), athis->shape(), athis->stride()); + } + +private: + + static constexpr value_type zero() + { + if constexpr (is_complex_v) + { + return value_type{}; + } + else + { + return value_type{0}; + } + } + + static void accumulate(value_type & acc, value_type v) + { + if constexpr (std::is_same_v>) + { + acc |= v; + } + else + { + acc += v; + } + } + + static value_type sum_contiguous(value_type const * data, size_t n) + { + value_type acc = zero(); + for (size_t i = 0; i < n; ++i) + { + accumulate(acc, data[i]); + } + return acc; + } + + // Walk a strided array by its innermost dimension: compute the row base + // offset once per outer iteration, then accumulate along the last axis. + // This avoids the per-element multi-dimensional index arithmetic that + // at(sidx) performs. + static value_type sum_strided(value_type const * data, + small_vector const & shape, + small_vector const & stride) + { + const size_t ndim = shape.size(); + const size_t last_dim = shape[ndim - 1]; + const size_t last_stride = stride[ndim - 1]; + + value_type acc = zero(); + small_vector prefix(ndim - 1, 0); + do + { + size_t offset = 0; + for (size_t i = 0; i + 1 < ndim; ++i) + { + offset += prefix[i] * stride[i]; + } + value_type const * row = data + offset; + for (size_t j = 0; j < last_dim; ++j) + { + accumulate(acc, row[j * last_stride]); + } + } while (next_prefix(prefix, shape)); + return acc; + } + + static bool next_prefix(small_vector & idx, + small_vector const & shape) + { + for (size_t i = idx.size(); i > 0; --i) + { + if (++idx[i - 1] < shape[i - 1]) + { + return true; + } + idx[i - 1] = 0; + } + return false; + } + +}; /* end class SimpleArrayMixinSum */ + template class SimpleArrayMixinCalculators { @@ -365,15 +477,12 @@ class SimpleArrayMixinCalculators value_type mean() const { auto athis = static_cast(this); - auto sidx = athis->first_sidx(); - value_type sum = 0; - int64_t total = 0; - do + const size_t n = athis->size(); + if (n == 0) { - sum += athis->at(sidx); - ++total; - } while (athis->next_sidx(sidx)); - return sum / static_cast(total); + throw std::runtime_error("SimpleArray::mean(): empty array"); + } + return athis->sum() / static_cast(n); } real_type var_op(small_vector & sv, size_t ddof) const @@ -488,36 +597,6 @@ class SimpleArrayMixinCalculators return initial; } - value_type sum() const - { - value_type initial; - if constexpr (is_complex_v) - { - initial = value_type(); - } - else - { - initial = 0; - } - - auto athis = static_cast(this); - if constexpr (!std::is_same_v>) - { - for (size_t i = 0; i < athis->size(); ++i) - { - initial += athis->data(i); - } - } - else - { - for (size_t i = 0; i < athis->size(); ++i) - { - initial |= athis->data(i); - } - } - return initial; - } - A abs() const { auto athis = static_cast(this); @@ -1382,6 +1461,7 @@ struct with_alignment_t template class SimpleArray : public detail::SimpleArrayMixinModifiers, T> + , public detail::SimpleArrayMixinSum, T> , public detail::SimpleArrayMixinCalculators, T> , public detail::SimpleArrayMixinSort, T> , public detail::SimpleArrayMixinSearch, T> @@ -1877,37 +1957,60 @@ class SimpleArray value_type const * body() const { return m_body; } value_type * body() { return m_body; } + bool is_c_contiguous() const { return is_c_contiguous(m_shape, m_stride); } + bool is_f_contiguous() const { return is_f_contiguous(m_shape, m_stride); } + private: - void check_c_contiguous(small_vector const & shape, - small_vector const & stride) const + static bool is_c_contiguous(small_vector const & shape, + small_vector const & stride) { if (stride[stride.size() - 1] != 1) { - throw std::runtime_error("SimpleArray: C contiguous stride must end with 1"); + return false; } for (size_t it = 0; it < shape.size() - 1; ++it) { if (stride[it] != shape[it + 1] * stride[it + 1]) { - throw std::runtime_error("SimpleArray: C contiguous stride must match shape"); + return false; } } + return true; } - void check_f_contiguous(small_vector const & shape, - small_vector const & stride) const + static bool is_f_contiguous(small_vector const & shape, + small_vector const & stride) { if (stride[0] != 1) { - throw std::runtime_error("SimpleArray: Fortran contiguous stride must start with 1"); + return false; } for (size_t it = 0; it < shape.size() - 1; ++it) { if (stride[it + 1] != shape[it] * stride[it]) { - throw std::runtime_error("SimpleArray: Fortran contiguous stride must match shape"); + return false; } } + return true; + } + + static void check_c_contiguous(small_vector const & shape, + small_vector const & stride) + { + if (!is_c_contiguous(shape, stride)) + { + throw std::runtime_error("SimpleArray: C contiguous stride must match shape and end with 1"); + } + } + + void check_f_contiguous(small_vector const & shape, + small_vector const & stride) const + { + if (!is_f_contiguous(shape, stride)) + { + throw std::runtime_error("SimpleArray: F contiguous stride must match shape and start with 1"); + } } void validate_range(ssize_t it) const diff --git a/tests/test_buffer.py b/tests/test_buffer.py index 2b3592f2..273d5bab 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -1443,6 +1443,29 @@ def test_minmaxsum(self): self.assertEqual(sarr.min(), -2.3) self.assertEqual(sarr.max(), 9.2) + def test_sum_non_contiguous(self): + # Strided slice that fails both C- and F-contiguity checks, so + # sum() must take the sum_strided path. Distinct integer values + # mean any indexing bug shifts the result. + nparr = np.arange(625, dtype='float64').reshape((5, 5, 5, 5)) + nparr = nparr[:3:2, 1:4:2, :3:3, 3:4:2] + sarr = modmesh.SimpleArrayFloat64(array=nparr) + self.assertEqual(sarr.sum(), np.sum(nparr)) + + def test_sum_empty_non_contiguous(self): + # Empty and non-contiguous: shape (3, 0, 4) with strides + # (0, 1, 0) is neither C- nor F-contiguous, so sum() must + # short-circuit on n == 0 instead of falling into the strided + # path and reading from an empty buffer. + sarr = modmesh.SimpleArrayFloat64(shape=(3, 4, 0), value=0.0) + sarr = sarr.transpose(axis=[0, 2, 1]) + self.assertEqual(sarr.sum(), 0.0) + + def test_mean_empty_raises(self): + sarr = modmesh.SimpleArrayFloat64(shape=(0, 3), value=0.0) + with self.assertRaisesRegex(RuntimeError, "empty array"): + sarr.mean() + def test_abs(self): sarr = modmesh.SimpleArrayInt64(shape=(3, 2), value=-2) self.assertEqual(sarr.sum(), -2 * 3 * 2)