Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
d69685a
Add bindings and tests for FixedShapeTensorType and Array
AlenkaF Apr 4, 2023
a6292f8
Fix linter error
AlenkaF Apr 5, 2023
1bdba1d
Add pa.fixedshapetensor factory function and update docstring examples
AlenkaF Apr 5, 2023
7c395b0
Apply suggestions from code review - Joris
AlenkaF Apr 5, 2023
d27d48f
Use pa.FixedSizeListArray.from_arrays(..) in from_numpy_ndarray()
AlenkaF Apr 5, 2023
8e790b4
Change fixedshapetensor to fixed_shape_tensor
AlenkaF Apr 5, 2023
64e0cd0
Add tests for all the custom attributes
AlenkaF Apr 5, 2023
48cbeb3
Add test for numpy F-contiguous
AlenkaF Apr 5, 2023
d9ca165
Correct dim_names() to return list of strings, not bytes
AlenkaF Apr 5, 2023
d3530af
Correct dim_names and permutation methods to return None and not empt…
AlenkaF Apr 5, 2023
e2ce8ba
Replace FixedShapeTensorType with fixed_shape_tensor in FixedShapeTen…
AlenkaF Apr 5, 2023
ee5d25c
Update from_numpy_ndarray docstrings
AlenkaF Apr 5, 2023
f5a5c0c
Update public-api.pxi
AlenkaF Apr 5, 2023
52f9e7e
Update python/pyarrow/types.pxi
AlenkaF Apr 5, 2023
f9dee9e
Merge branch 'main' into python-binding-tensor-extension-type
AlenkaF Apr 5, 2023
b171d00
Use ravel instead of flatten and raise ValueError if numpy array is no…
AlenkaF Apr 5, 2023
c0ec94c
Remove CFixedShapeTensorType binding in libarrow
AlenkaF Apr 5, 2023
f2d9fe7
Fix doctest failure
AlenkaF Apr 6, 2023
8b5dc93
Add explanation of permutation from the spec to the docstring of fixe…
AlenkaF Apr 6, 2023
570f086
from_numpy_ndarray should be a static method
AlenkaF Apr 6, 2023
3dbbe20
Apply suggestions from code review
AlenkaF Apr 6, 2023
223968a
Apply suggestions from code review
AlenkaF Apr 6, 2023
dd8fd31
Update to_numpy_ndarray docstring
AlenkaF Apr 7, 2023
1ebb829
Add a check for non-trivial permutation in to_numpy_ndarray
AlenkaF Apr 11, 2023
b2d0453
Update python/pyarrow/array.pxi
AlenkaF Apr 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/pyarrow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def print_entry(label, value):
union, sparse_union, dense_union,
dictionary,
run_end_encoded,
fixedshapetensor,
fixed_shape_tensor,
field,
type_for_alias,
DataType, DictionaryType, StructType,
Expand Down
5 changes: 2 additions & 3 deletions python/pyarrow/array.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -3148,9 +3148,8 @@ class FixedShapeTensorArray(ExtensionArray):
size = obj.size / obj.shape[0]

return ExtensionArray.from_storage(
FixedShapeTensorType(arrow_type, shape),
array([t.flatten() for t in obj],
list_(arrow_type, size))
fixed_shape_tensor(arrow_type, shape),
FixedSizeListArray.from_arrays(obj.flatten(), size)
)


Expand Down
56 changes: 28 additions & 28 deletions python/pyarrow/tests/test_extension_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import numpy as np
import pyarrow as pa
from pyarrow.lib import tobytes

import pytest

Expand Down Expand Up @@ -1130,13 +1131,32 @@ def test_cpp_extension_in_python(tmpdir):


def test_tensor_type():
tensor_type = pa.FixedShapeTensorType(pa.int8(), (2, 3))
tensor_type = pa.fixed_shape_tensor(pa.int8(), [2, 3])
assert tensor_type.extension_name == "arrow.fixed_shape_tensor"
assert tensor_type.storage_type == pa.list_(pa.int8(), 6)
assert tensor_type.shape == [2, 3]
assert not tensor_type.dim_names
assert not tensor_type.permutation

tensor_type = pa.fixed_shape_tensor(pa.float64(), [2, 2, 3],
permutation=[0, 2, 1])
assert tensor_type.extension_name == "arrow.fixed_shape_tensor"
assert tensor_type.storage_type == pa.list_(pa.float64(), 12)
assert tensor_type.shape == [2, 2, 3]
assert not tensor_type.dim_names
assert tensor_type.permutation == [0, 2, 1]

tensor_type = pa.fixed_shape_tensor(pa.bool_(), [2, 2, 3],
dim_names=['C', 'H', 'W'])
assert tensor_type.extension_name == "arrow.fixed_shape_tensor"
assert tensor_type.storage_type == pa.list_(pa.bool_(), 12)
assert tensor_type.shape == [2, 2, 3]
assert tensor_type.dim_names == [tobytes(x) for x in ['C', 'H', 'W']]
assert not tensor_type.permutation


def test_tensor_class_methods():
tensor_type = pa.FixedShapeTensorType(pa.float32(), (2, 3))
tensor_type = pa.fixed_shape_tensor(pa.float32(), [2, 3])
storage = pa.array([[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]],
pa.list_(pa.float32(), 6))
arr = pa.ExtensionArray.from_storage(tensor_type, storage)
Expand All @@ -1153,9 +1173,9 @@ def test_tensor_class_methods():


@pytest.mark.parametrize("tensor_type", (
pa.FixedShapeTensorType(pa.int8(), (2, 2, 3)),
pa.FixedShapeTensorType(pa.int8(), (2, 2, 3), permutation=[0, 2, 1]),
pa.FixedShapeTensorType(pa.int8(), (2, 2, 3), dim_names=['C', 'H', 'W'])
pa.fixed_shape_tensor(pa.int8(), [2, 2, 3]),
pa.fixed_shape_tensor(pa.int8(), [2, 2, 3], permutation=[0, 2, 1]),
pa.fixed_shape_tensor(pa.int8(), [2, 2, 3], dim_names=['C', 'H', 'W'])
))
def test_tensor_type_ipc(tensor_type):
storage = pa.array([[1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6]], pa.list_(pa.int8(), 12))
Expand All @@ -1181,32 +1201,12 @@ def test_tensor_type_ipc(tensor_type):
assert result.type.value_type == pa.int8()
assert result.type.shape == [2, 2, 3]

# using different parametrization as how it was registered
tensor_type_uint = tensor_type.__class__(pa.uint8(), (2, 3))
assert tensor_type_uint.extension_name == "arrow.fixed_shape_tensor"
assert tensor_type_uint.value_type == pa.uint8()
assert tensor_type_uint.shape == [2, 3]

storage = pa.array([[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]],
pa.list_(pa.uint8(), 6))
arr = pa.ExtensionArray.from_storage(tensor_type_uint, storage)
batch = pa.RecordBatch.from_arrays([arr], ["ext"])

buf = ipc_write_batch(batch)
del batch
batch = ipc_read_batch(buf)
result = batch.column(0)
assert isinstance(result.type, pa.FixedShapeTensorType)
assert result.type.value_type == pa.uint8()
assert result.type.shape == [2, 3]
assert type(result) == tensor_class


def test_tensor_type_equality():
tensor_type = pa.FixedShapeTensorType(pa.int8(), (2, 2, 3))
tensor_type = pa.fixed_shape_tensor(pa.int8(), [2, 2, 3])
assert tensor_type.extension_name == "arrow.fixed_shape_tensor"

tensor_type2 = pa.FixedShapeTensorType(pa.int8(), (2, 2, 3))
tensor_type3 = pa.FixedShapeTensorType(pa.uint8(), (2, 2, 3))
tensor_type2 = pa.fixed_shape_tensor(pa.int8(), [2, 2, 3])
tensor_type3 = pa.fixed_shape_tensor(pa.uint8(), [2, 2, 3])
assert tensor_type == tensor_type2
assert not tensor_type == tensor_type3
18 changes: 9 additions & 9 deletions python/pyarrow/types.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1504,14 +1504,14 @@ cdef class FixedShapeTensorType(BaseExtensionType):
Create an instance of fixed shape tensor extension type:

>>> import pyarrow as pa
>>> pa.fixedshapetensor(pa.int32(), [2, 2])
>>> pa.fixed_shape_tensor(pa.int32(), [2, 2])
FixedShapeTensorType(extension<arrow.fixed_shape_tensor>)

Create an instance of fixed shape tensor extension type with
permutation:

>>> tensor_type = pa.fixedshapetensor(pa.int8(), (2, 2, 3),
... permutation=[0, 2, 1])
>>> tensor_type = pa.fixed_shape_tensor(pa.int8(), (2, 2, 3),
... permutation=[0, 2, 1])
>>> tensor_type.permutation
[0, 2, 1]
"""
Expand Down Expand Up @@ -4615,7 +4615,7 @@ def run_end_encoded(run_end_type, value_type):
return pyarrow_wrap_data_type(ree_type)


def fixedshapetensor(DataType value_type, shape, dim_names=None, permutation=None):
def fixed_shape_tensor(DataType value_type, shape, dim_names=None, permutation=None):
"""
Create instance of fixed shape tensor extension type with shape and optional
names of tensor dimensions and indices of the desired ordering.
Expand All @@ -4636,7 +4636,7 @@ def fixedshapetensor(DataType value_type, shape, dim_names=None, permutation=Non
Create an instance of fixed shape tensor extension type:

>>> import pyarrow as pa
>>> tensor_type = pa.fixedshapetensor(pa.int32(), [2, 2])
>>> tensor_type = pa.fixed_shape_tensor(pa.int32(), [2, 2])
>>> tensor_type
FixedShapeTensorType(extension<arrow.fixed_shape_tensor>)

Expand All @@ -4661,16 +4661,16 @@ def fixedshapetensor(DataType value_type, shape, dim_names=None, permutation=Non
Create an instance of fixed shape tensor extension type with names
of tensor dimensions:

>>> tensor_type = pa.fixedshapetensor(pa.int8(), (2, 2, 3),
... dim_names=['C', 'H', 'W'])
>>> tensor_type = pa.fixed_shape_tensor(pa.int8(), (2, 2, 3),
... dim_names=['C', 'H', 'W'])
>>> tensor_type.dim_names
[b'C', b'H', b'W']

Create an instance of fixed shape tensor extension type with
permutation:

>>> tensor_type = pa.fixedshapetensor(pa.int8(), (2, 2, 3),
... permutation=[0, 2, 1])
>>> tensor_type = pa.fixed_shape_tensor(pa.int8(), (2, 2, 3),
... permutation=[0, 2, 1])
>>> tensor_type.permutation
[0, 2, 1]

Expand Down