Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 149 additions & 46 deletions cpp/modmesh/buffer/SimpleArray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,118 @@ struct select_real_t<Complex<U>>
using type = U;
};

template <typename A, typename T>
class SimpleArrayMixinSum
{

private:

using internal_types = detail::SimpleArrayInternalTypes<T>;

public:

using value_type = typename internal_types::value_type;

value_type sum() const
{
auto athis = static_cast<A const *>(this);
const size_t n = athis->size();
if (n == 0)
{
return zero();
}
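// A C- or F-contiguous array occupies a single dense block of memory with
// no gaps, so a linear sweep over the buffer visits every element exactly
// once. The sweep follows memory order, which is the natural axis order of
// each layout (last axis fastest for C, first axis fastest for F).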
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment is wrong, please correct it.

sweep the buffer linearly regardless of axis order

Sweeping the buffer linearly is exactly following the axis order for C and F contiguity, respectively.

It took me a while to realize the comment is wrong.

if (athis->is_c_contiguous() || athis->is_f_contiguous())
{
return sum_contiguous(athis->data(), n);
}
return sum_strided(athis->data(), athis->shape(), athis->stride());
Comment on lines +190 to +194
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Properly handle the sum.

}

private:

// Additive identity used to seed every accumulation in this mixin.
// Complex value types are value-initialized; every other value type is
// brace-initialized from the literal 0.
static constexpr value_type zero()
{
    if constexpr (!is_complex_v<value_type>)
    {
        return value_type{0};
    }
    else
    {
        return value_type{};
    }
}

// Fold one element into the running total.
// acc : running total, updated in place.
// v   : element to fold in.
// Boolean arrays are combined with bitwise OR (the total is true as soon as
// any element is true); every other value type uses ordinary addition.
static void accumulate(value_type & acc, value_type v)
{
    constexpr bool is_boolean = std::is_same_v<bool, std::remove_const_t<value_type>>;
    if constexpr (!is_boolean)
    {
        acc += v;
    }
    else
    {
        acc |= v;
    }
}

// Sum n consecutive elements starting at data.
// data : pointer to the first element of a densely packed run.
// n    : number of elements in the run.
// Returns zero() when n is 0; elements are folded in increasing address
// order through accumulate().
static value_type sum_contiguous(value_type const * data, size_t n)
{
    value_type total = zero();
    value_type const * const stop = data + n;
    for (value_type const * cursor = data; cursor != stop; ++cursor)
    {
        accumulate(total, *cursor);
    }
    return total;
}

// Sum a strided array by walking it one innermost row at a time.
//
// data   : pointer to the element addressed by the all-zero index.
// shape  : extent of every axis.
// stride : step between consecutive indices along every axis. It is applied
//          directly to a value_type pointer, i.e. it is used in units of
//          elements.
//
// For every combination of the leading ndim - 1 indices the base offset of
// the row is computed once, then the last axis is accumulated with a single
// multiply per element. This avoids the per-element multi-dimensional index
// arithmetic that at(sidx) performs.
//
// Returns zero() without touching data when shape has no axes or when any
// axis has zero extent.
static value_type sum_strided(value_type const * data,
                              small_vector<size_t> const & shape,
                              small_vector<size_t> const & stride)
{
    const size_t ndim = shape.size();
    if (ndim == 0)
    {
        // There is no last axis to walk, and ndim - 1 would wrap around to
        // the largest size_t and index far outside shape/stride.
        return zero();
    }
    for (size_t i = 0; i < ndim; ++i)
    {
        if (shape[i] == 0)
        {
            // The do/while below always visits the first row before asking
            // next_prefix() whether another one exists, so an array with a
            // zero-extent axis must be rejected before any element is read.
            return zero();
        }
    }

    const size_t last_dim = shape[ndim - 1];
    const size_t last_stride = stride[ndim - 1];

    value_type acc = zero();
    // Multi-index over every axis except the last; starts at all zeros.
    small_vector<size_t> prefix(ndim - 1, 0);
    do
    {
        // Base offset of the current row, derived from the leading indices.
        size_t offset = 0;
        for (size_t i = 0; i + 1 < ndim; ++i)
        {
            offset += prefix[i] * stride[i];
        }
        value_type const * row = data + offset;
        for (size_t j = 0; j < last_dim; ++j)
        {
            accumulate(acc, row[j * last_stride]);
        }
    } while (next_prefix(prefix, shape));
    return acc;
}

// Advance a multi-index to its successor, with the last entry of idx
// varying fastest.
// idx   : multi-index to advance in place; only its own entries are
//         touched, so it may be shorter than shape.
// shape : extents bounding each entry of idx.
// Returns true when idx now holds a new valid position. Returns false once
// every entry has rolled over, leaving idx reset to all zeros; this is also
// the result for an empty idx.
static bool next_prefix(small_vector<size_t> & idx,
                        small_vector<size_t> const & shape)
{
    size_t axis = idx.size();
    while (axis != 0)
    {
        --axis;
        idx[axis] += 1;
        if (idx[axis] < shape[axis])
        {
            return true;
        }
        // This axis is exhausted: reset it and carry into the next one.
        idx[axis] = 0;
    }
    return false;
}

}; /* end class SimpleArrayMixinSum */

template <typename A, typename T>
class SimpleArrayMixinCalculators
{
Expand Down Expand Up @@ -365,15 +477,12 @@ class SimpleArrayMixinCalculators
value_type mean() const
{
auto athis = static_cast<A const *>(this);
auto sidx = athis->first_sidx();
value_type sum = 0;
int64_t total = 0;
do
const size_t n = athis->size();
if (n == 0)
{
sum += athis->at(sidx);
++total;
} while (athis->next_sidx(sidx));
return sum / static_cast<value_type>(total);
throw std::runtime_error("SimpleArray::mean(): empty array");
}
return athis->sum() / static_cast<value_type>(n);
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Meat.

}

real_type var_op(small_vector<value_type> & sv, size_t ddof) const
Expand Down Expand Up @@ -488,36 +597,6 @@ class SimpleArrayMixinCalculators
return initial;
}

// Fold every element reachable through data(i), for i below size(), into a
// single value. Boolean arrays are combined with bitwise OR; every other
// value type is added. The accumulator starts from a value-initialized
// value_type for complex types and from 0 otherwise.
value_type sum() const
{
    auto const * self = static_cast<A const *>(this);
    size_t const count = self->size();

    value_type total;
    if constexpr (is_complex_v<value_type>)
    {
        total = value_type();
    }
    else
    {
        total = 0;
    }

    for (size_t k = 0; k < count; ++k)
    {
        if constexpr (std::is_same_v<bool, std::remove_const_t<value_type>>)
        {
            total |= self->data(k);
        }
        else
        {
            total += self->data(k);
        }
    }
    return total;
}

A abs() const
{
auto athis = static_cast<A const *>(this);
Expand Down Expand Up @@ -1382,6 +1461,7 @@ struct with_alignment_t
template <typename T>
class SimpleArray
: public detail::SimpleArrayMixinModifiers<SimpleArray<T>, T>
, public detail::SimpleArrayMixinSum<SimpleArray<T>, T>
, public detail::SimpleArrayMixinCalculators<SimpleArray<T>, T>
, public detail::SimpleArrayMixinSort<SimpleArray<T>, T>
, public detail::SimpleArrayMixinSearch<SimpleArray<T>, T>
Expand Down Expand Up @@ -1877,37 +1957,60 @@ class SimpleArray
// Accessors for the m_body pointer: the const overload returns a
// pointer-to-const, the non-const overload returns a mutable pointer.
value_type const * body() const { return m_body; }
value_type * body() { return m_body; }

// Report whether this array's own shape and stride pass the C-contiguity
// test implemented by the two-argument overload.
bool is_c_contiguous() const
{
    return is_c_contiguous(m_shape, m_stride);
}
// Report whether this array's own shape and stride pass the F-contiguity
// test implemented by the two-argument overload.
bool is_f_contiguous() const
{
    return is_f_contiguous(m_shape, m_stride);
}

private:
void check_c_contiguous(small_vector<size_t> const & shape,
small_vector<size_t> const & stride) const
static bool is_c_contiguous(small_vector<size_t> const & shape,
small_vector<size_t> const & stride)
{
if (stride[stride.size() - 1] != 1)
{
throw std::runtime_error("SimpleArray: C contiguous stride must end with 1");
return false;
}
for (size_t it = 0; it < shape.size() - 1; ++it)
{
if (stride[it] != shape[it + 1] * stride[it + 1])
{
throw std::runtime_error("SimpleArray: C contiguous stride must match shape");
return false;
}
}
return true;
}

void check_f_contiguous(small_vector<size_t> const & shape,
small_vector<size_t> const & stride) const
static bool is_f_contiguous(small_vector<size_t> const & shape,
small_vector<size_t> const & stride)
{
if (stride[0] != 1)
{
throw std::runtime_error("SimpleArray: Fortran contiguous stride must start with 1");
return false;
}
for (size_t it = 0; it < shape.size() - 1; ++it)
{
if (stride[it + 1] != shape[it] * stride[it])
{
throw std::runtime_error("SimpleArray: Fortran contiguous stride must match shape");
return false;
}
}
return true;
}

// Throwing counterpart of is_c_contiguous(shape, stride).
// shape  : extent of every axis.
// stride : stride of every axis.
// Throws std::runtime_error when the pair is not C-contiguous; returns
// normally otherwise.
static void check_c_contiguous(small_vector<size_t> const & shape,
                               small_vector<size_t> const & stride)
{
    if (is_c_contiguous(shape, stride))
    {
        return;
    }
    throw std::runtime_error("SimpleArray: C contiguous stride must match shape and end with 1");
}

// Throwing counterpart of is_f_contiguous(shape, stride).
// shape  : extent of every axis.
// stride : stride of every axis.
// Throws std::runtime_error when the pair is not F-contiguous; returns
// normally otherwise.
//
// Declared static, like check_c_contiguous() and both is_*_contiguous()
// helpers: the check depends only on its arguments and never reads the
// object. Existing unqualified calls from member functions keep working.
static void check_f_contiguous(small_vector<size_t> const & shape,
                               small_vector<size_t> const & stride)
{
    if (!is_f_contiguous(shape, stride))
    {
        throw std::runtime_error("SimpleArray: F contiguous stride must match shape and start with 1");
    }
}

void validate_range(ssize_t it) const
Expand Down
23 changes: 23 additions & 0 deletions tests/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,6 +1443,29 @@ def test_minmaxsum(self):
self.assertEqual(sarr.min(), -2.3)
self.assertEqual(sarr.max(), 9.2)

def test_sum_non_contiguous(self):
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a few test cases to increase the coverage.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add tests for contiguous memory too. I did not find test_sum() in this file.

# Strided slice that fails both C- and F-contiguity checks, so
# sum() must take the sum_strided path. Distinct integer values
# mean any indexing bug shifts the result.
nparr = np.arange(625, dtype='float64').reshape((5, 5, 5, 5))
nparr = nparr[:3:2, 1:4:2, :3:3, 3:4:2]
sarr = modmesh.SimpleArrayFloat64(array=nparr)
self.assertEqual(sarr.sum(), np.sum(nparr))

def test_sum_empty_non_contiguous(self):
# Empty and non-contiguous at the same time. The array is created
# with shape (3, 4, 0) and its last two axes are then swapped,
# giving shape (3, 0, 4) with strides (0, 1, 0). Neither the first
# nor the last stride is 1, so the layout is neither C- nor
# F-contiguous. sum() must short-circuit on n == 0 instead of
# falling into the strided path and reading from an empty buffer;
# the expected total is 0.0.
sarr = modmesh.SimpleArrayFloat64(shape=(3, 4, 0), value=0.0)
sarr = sarr.transpose(axis=[0, 2, 1])
self.assertEqual(sarr.sum(), 0.0)

def test_mean_empty_raises(self):
# A (0, 3) array holds no elements, so there is nothing to average.
# mean() is expected to raise RuntimeError with a message containing
# "empty array" rather than return a value.
sarr = modmesh.SimpleArrayFloat64(shape=(0, 3), value=0.0)
with self.assertRaisesRegex(RuntimeError, "empty array"):
sarr.mean()

def test_abs(self):
sarr = modmesh.SimpleArrayInt64(shape=(3, 2), value=-2)
self.assertEqual(sarr.sum(), -2 * 3 * 2)
Expand Down
Loading