Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions cpp/src/arrow/extension/fixed_shape_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "arrow/array/array_nested.h"
#include "arrow/array/array_primitive.h"
#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep
#include "arrow/tensor.h"
#include "arrow/util/int_util_overflow.h"
#include "arrow/util/logging.h"
#include "arrow/util/sort.h"
Expand Down Expand Up @@ -140,6 +141,164 @@ std::shared_ptr<Array> FixedShapeTensorType::MakeArray(
return std::make_shared<ExtensionArray>(data);
}

Result<std::shared_ptr<FixedShapeTensorArray>> FixedShapeTensorArray::FromTensor(
const std::shared_ptr<Tensor>& tensor) {
auto cell_shape = tensor->shape();
cell_shape.erase(cell_shape.begin());

std::vector<std::string> dim_names;
for (size_t i = 1; i < tensor->dim_names().size(); ++i) {
dim_names.emplace_back(tensor->dim_names()[i]);
}

auto permutation = internal::ArgSort(tensor->strides(), std::greater<>());
if (permutation[0] != 0) {
return Status::Invalid(
"Only first-major tensors can be zero-copy converted to arrays");
}
permutation.erase(permutation.begin());
for (size_t i = 0; i < permutation.size(); ++i) {
--permutation[i];
}

auto ext_type = internal::checked_pointer_cast<ExtensionType>(
fixed_shape_tensor(tensor->type(), cell_shape, permutation, dim_names));

std::shared_ptr<Array> value_array;
switch (tensor->type_id()) {
case Type::UINT8: {
value_array = std::make_shared<UInt8Array>(tensor->size(), tensor->data());
break;
}
case Type::INT8: {
value_array = std::make_shared<Int8Array>(tensor->size(), tensor->data());
break;
}
case Type::UINT16: {
value_array = std::make_shared<UInt16Array>(tensor->size(), tensor->data());
break;
}
case Type::INT16: {
value_array = std::make_shared<Int16Array>(tensor->size(), tensor->data());
break;
}
case Type::UINT32: {
value_array = std::make_shared<UInt32Array>(tensor->size(), tensor->data());
break;
}
case Type::INT32: {
value_array = std::make_shared<Int32Array>(tensor->size(), tensor->data());
break;
}
case Type::UINT64: {
value_array = std::make_shared<Int64Array>(tensor->size(), tensor->data());
break;
}
case Type::INT64: {
value_array = std::make_shared<Int64Array>(tensor->size(), tensor->data());
break;
}
case Type::HALF_FLOAT: {
value_array = std::make_shared<HalfFloatArray>(tensor->size(), tensor->data());
break;
}
case Type::FLOAT: {
value_array = std::make_shared<FloatArray>(tensor->size(), tensor->data());
break;
}
case Type::DOUBLE: {
value_array = std::make_shared<DoubleArray>(tensor->size(), tensor->data());
break;
}
default: {
return Status::NotImplemented("Unsupported tensor type: ",
tensor->type()->ToString());
}
}
auto cell_size = static_cast<int32_t>(tensor->size() / tensor->shape()[0]);
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> arr,
FixedSizeListArray::FromArrays(value_array, cell_size));
std::shared_ptr<Array> ext_arr = ExtensionType::WrapArray(ext_type, arr);
return std::reinterpret_pointer_cast<FixedShapeTensorArray>(ext_arr);
}

Status FixedShapeTensorType::ComputeStrides(const FixedWidthType& type,
const std::vector<int64_t>& shape,
const std::vector<int64_t>& permutation,
std::vector<int64_t>* strides) {
if (permutation.empty()) {
return internal::ComputeRowMajorStrides(type, shape, strides);
}

const int byte_width = type.byte_width();

int64_t remaining = 0;
if (!shape.empty() && shape.front() > 0) {
remaining = byte_width;
for (auto i : permutation) {
if (i > 0) {
if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) {
return Status::Invalid(
"Strides computed from shape would not fit in 64-bit integer");
}
}
}
}

if (remaining == 0) {
strides->assign(shape.size(), byte_width);
return Status::OK();
}

strides->push_back(remaining);
for (auto i : permutation) {
if (i > 0) {
remaining /= shape[i];
strides->push_back(remaining);
}
}
internal::Permute(permutation, strides);

return Status::OK();
}

const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
// To convert an array of n dimensional tensors to a n+1 dimensional tensor we
// interpret the array's length as the first dimension the new tensor.

auto ext_arr = internal::checked_pointer_cast<FixedSizeListArray>(this->storage());
auto ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(this->type());
ARROW_RETURN_IF(!is_fixed_width(*ext_arr->value_type()),
Status::Invalid(ext_arr->value_type()->ToString(),
" is not valid data type for a tensor"));
std::vector<int64_t> shape = ext_type->shape();
shape.insert(shape.begin(), 1, this->length());

std::vector<int64_t> tensor_strides;
auto value_type = internal::checked_pointer_cast<FixedWidthType>(ext_arr->value_type());
auto permutation = ext_type->permutation();
for (size_t i = 0; i < permutation.size(); ++i) {
++permutation[i];
}
permutation.insert(permutation.begin(), 1, 0);
ARROW_RETURN_NOT_OK(FixedShapeTensorType::ComputeStrides(*value_type.get(), shape,
permutation, &tensor_strides));

std::vector<std::string> dim_names;
if (!ext_type->dim_names().empty()) {
dim_names = ext_type->dim_names();
dim_names.insert(dim_names.begin(), 1, "");
} else {
dim_names = {};
}

ARROW_ASSIGN_OR_RAISE(auto buffers, ext_arr->Flatten());
ARROW_ASSIGN_OR_RAISE(
auto tensor, Tensor::Make(ext_arr->value_type(), buffers->data()->buffers[1], shape,
tensor_strides, dim_names));
return tensor;
}

Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape,
const std::vector<int64_t>& permutation, const std::vector<std::string>& dim_names) {
Expand All @@ -157,6 +316,17 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
shape, permutation, dim_names);
}

const std::vector<int64_t>& FixedShapeTensorType::strides() {
if (strides_.empty()) {
auto value_type = internal::checked_pointer_cast<FixedWidthType>(this->value_type_);
std::vector<int64_t> tensor_strides;
ARROW_CHECK_OK(ComputeStrides(*value_type.get(), this->shape(), this->permutation(),
&tensor_strides));
strides_ = tensor_strides;
}
return strides_;
}

std::shared_ptr<DataType> fixed_shape_tensor(const std::shared_ptr<DataType>& value_type,
const std::vector<int64_t>& shape,
const std::vector<int64_t>& permutation,
Expand Down
31 changes: 31 additions & 0 deletions cpp/src/arrow/extension/fixed_shape_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,25 @@ namespace extension {
class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray {
public:
using ExtensionArray::ExtensionArray;

/// \brief Create a FixedShapeTensorArray from a Tensor
///
/// This method will create a FixedShapeTensorArray from a Tensor, taking its first
/// dimension as the number of elements in the resulting array and the remaining
/// dimensions as the shape of the individual tensors. If Tensor provides strides,
/// they will be used to determine dimension permutation. Otherwise, row-major layout
/// (i.e. no permutation) will be assumed.
///
/// \param[in] tensor The Tensor to convert to a FixedShapeTensorArray
static Result<std::shared_ptr<FixedShapeTensorArray>> FromTensor(
const std::shared_ptr<Tensor>& tensor);

/// \brief Create a Tensor from FixedShapeTensorArray
///
/// This method will create a Tensor from a FixedShapeTensorArray, setting its
/// first dimension as length equal to the FixedShapeTensorArray's length and the
/// remaining dimensions as the FixedShapeTensorType's shape.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add here to the docstring that this will automatically reshape the resulting Tensor according to the permutation metadata of the FixedShapeTensorArray?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

const Result<std::shared_ptr<Tensor>> ToTensor() const;
};

/// \brief Concrete type class for constant-size Tensor data.
Expand Down Expand Up @@ -51,6 +70,11 @@ class ARROW_EXPORT FixedShapeTensorType : public ExtensionType {
/// Value type of tensor elements
const std::shared_ptr<DataType> value_type() const { return value_type_; }

/// Strides of tensor elements. Strides state offset in bytes between adjacent
/// elements along each dimension. In case permutation is non-empty strides are
/// computed from permuted tensor element's shape.
const std::vector<int64_t>& strides();

/// Permutation mapping from logical to physical memory layout of tensor elements
const std::vector<int64_t>& permutation() const { return permutation_; }

Expand All @@ -74,10 +98,17 @@ class ARROW_EXPORT FixedShapeTensorType : public ExtensionType {
const std::vector<int64_t>& permutation = {},
const std::vector<std::string>& dim_names = {});

/// \brief Compute strides of FixedShapeTensorType
static Status ComputeStrides(const FixedWidthType& type,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a need for having this public in addition to strides() ?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason it's like this now is because both FixedShapeTensorType and FixedShapeTensorArray are using ComputeStrides, but ComputeStrides is a FixedShapeTensorType method. I'll see if I can put it into another namespace.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to anonymous namespace.

const std::vector<int64_t>& shape,
const std::vector<int64_t>& permutation,
std::vector<int64_t>* strides);

private:
std::shared_ptr<DataType> storage_type_;
std::shared_ptr<DataType> value_type_;
std::vector<int64_t> shape_;
std::vector<int64_t> strides_;
std::vector<int64_t> permutation_;
std::vector<std::string> dim_names_;
};
Expand Down
Loading