-
Notifications
You must be signed in to change notification settings - Fork 191
fix: Add fallback for incompatible string concatenation in pandas #3548
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
9a5a726
b94ba8e
4187c38
4d6368c
b1f312d
ddf5615
b2e4cec
c9f28b8
1938a27
b8b66d3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,7 @@ | |
| from narwhals._pandas_like.utils import ( | ||
| NUMPY_VERSION, | ||
| align_and_extract_native, | ||
| binary_string_sum_fallback, | ||
| broadcast_series_to_index, | ||
| get_dtype_backend, | ||
| import_array_module, | ||
|
|
@@ -399,23 +400,22 @@ def first(self) -> PythonLiteral: | |
| def last(self) -> PythonLiteral: | ||
| return self.native.iloc[-1] if len(self.native) else None | ||
|
|
||
| def _with_binary(self, op: Callable[..., PandasLikeSeries], other: Any) -> Self: | ||
| def _with_binary(self, op: Callable[..., pd.Series], other: Any) -> Self: | ||
| ser, other_native = align_and_extract_native(self, other) | ||
| preserve_broadcast = self._broadcast and getattr(other, "_broadcast", True) | ||
| if ( | ||
| str(self.native.dtype) == "large_string[pyarrow]" | ||
| and isinstance(other_native, str) | ||
| and op.__name__ == "add" | ||
| ): | ||
| # https://github.com/pandas-dev/pandas/issues/64393 | ||
| import pyarrow as pa # ignore-banned-import | ||
|
|
||
| other_native = pa.scalar(other_native, type=pa.large_string()) | ||
| return self._with_native( | ||
| op(ser, other_native), preserve_broadcast=preserve_broadcast | ||
| ).alias(self.name) | ||
| try: | ||
| res = op(ser, other_native) | ||
| except Exception: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
very long tracebacks
import pandas as pd
import narwhals as nw
df = pd.DataFrame({"count": [1, 2, 3], "fruit": ["apple", "banana", "orange"]})
nw.from_native(df).with_columns(concat=nw.col("count") + nw.col("fruit"))
produces
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for your review! Regarding making the error more specific, I'd prefer to avoid that in this PR. Currently, this keeps existing behaviour the same, but just changes some previously-failing paths to now pass. If we change the error, then we're changing what may previously have been […]. Agree with checking the dtypes before the fallback though. Have added that. |
||
| if op.__name__ == "add": | ||
| pdx = self.__native_namespace__() | ||
| res = binary_string_sum_fallback(ser, other_native, pdx) | ||
| else: | ||
| raise | ||
|
FBruzzesi marked this conversation as resolved.
|
||
| return self._with_native(res, preserve_broadcast=preserve_broadcast).alias( | ||
| self.name | ||
| ) | ||
|
|
||
| def _with_binary_right(self, op: Callable[..., PandasLikeSeries], other: Any) -> Self: | ||
| def _with_binary_right(self, op: Callable[..., pd.Series], other: Any) -> Self: | ||
| return self._with_binary(lambda x, y: op(y, x), other).alias(self.name) | ||
|
|
||
| def __eq__(self, other: object) -> Self: # type: ignore[override] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import Any | ||
|
|
||
| import pytest | ||
|
|
||
| import narwhals as nw | ||
| from tests.utils import assert_equal_data | ||
|
|
||
| pytest.importorskip("pandas", minversion="3.0.0") | ||
| pytest.importorskip("pyarrow") | ||
|
|
||
| import numpy as np | ||
| import pandas as pd | ||
| import pyarrow as pa | ||
|
|
||
| STRING_DTYPE_NAN = pd.StringDtype("pyarrow", na_value=np.nan) # type: ignore[call-arg] | ||
| STRING_DTYPE_NA = pd.StringDtype("pyarrow", na_value=pd.NA) # type: ignore[call-arg] | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would there be any purpose to add non-pyarrow backed StringDtypes into this test?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure, done, thanks! |
||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| ("left_dtype", "right_dtype", "result_dtype"), | ||
| [ | ||
| (STRING_DTYPE_NAN, STRING_DTYPE_NAN, STRING_DTYPE_NAN), | ||
| (STRING_DTYPE_NAN, STRING_DTYPE_NA, STRING_DTYPE_NAN), | ||
| (STRING_DTYPE_NAN, pd.ArrowDtype(pa.string()), STRING_DTYPE_NAN), | ||
| (STRING_DTYPE_NAN, pd.ArrowDtype(pa.large_string()), STRING_DTYPE_NAN), | ||
| (STRING_DTYPE_NA, STRING_DTYPE_NAN, STRING_DTYPE_NA), | ||
| (STRING_DTYPE_NA, STRING_DTYPE_NA, STRING_DTYPE_NA), | ||
| (STRING_DTYPE_NA, pd.ArrowDtype(pa.string()), STRING_DTYPE_NA), | ||
| (STRING_DTYPE_NA, pd.ArrowDtype(pa.large_string()), STRING_DTYPE_NA), | ||
| (pd.ArrowDtype(pa.string()), STRING_DTYPE_NAN, pd.ArrowDtype(pa.large_string())), | ||
| (pd.ArrowDtype(pa.string()), STRING_DTYPE_NA, pd.ArrowDtype(pa.large_string())), | ||
| ( | ||
| pd.ArrowDtype(pa.string()), | ||
| pd.ArrowDtype(pa.string()), | ||
| pd.ArrowDtype(pa.string()), | ||
| ), | ||
| ( | ||
| pd.ArrowDtype(pa.string()), | ||
| pd.ArrowDtype(pa.large_string()), | ||
| pd.ArrowDtype(pa.large_string()), | ||
| ), | ||
| ( | ||
| pd.ArrowDtype(pa.large_string()), | ||
| STRING_DTYPE_NAN, | ||
| pd.ArrowDtype(pa.large_string()), | ||
| ), | ||
| ( | ||
| pd.ArrowDtype(pa.large_string()), | ||
| STRING_DTYPE_NA, | ||
| pd.ArrowDtype(pa.large_string()), | ||
| ), | ||
| ( | ||
| pd.ArrowDtype(pa.large_string()), | ||
| pd.ArrowDtype(pa.string()), | ||
| pd.ArrowDtype(pa.large_string()), | ||
| ), | ||
| ( | ||
| pd.ArrowDtype(pa.large_string()), | ||
| pd.ArrowDtype(pa.large_string()), | ||
| pd.ArrowDtype(pa.large_string()), | ||
| ), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The cases effectively serve as a handwritten truth table. Can this be refactored along the lines of the sketch below?
This reduces the length of the parameters spec and codifies the expected rules for returning a datatype.
DTYPES = [
pd.StringDtype("pyarrow", na_value=np.nan), # type: ignore[call-arg]
pd.StringDtype("pyarrow", na_value=pd.NA), # type: ignore[call-arg]
pd.ArrowDtype(pa.string()),
pd.ArrowDtype(pa.large_string()),
]
@pytest.mark.parametrize(
("left_dtype", "right_dtype"),
[*product(DTYPES, repeat=2)],
)
def test_pandas_str_types(...):
...
assert_equal_data(res, expected)
result_dtype = res.to_native()["concat_col"].dtype
if isinstance(left_dtype, pd.StringDtype):
expected_dtype = left_dtype
elif isinstance(left_dtype, pd.ArrowDtype) and isinstance(right_dtype, pd.StringDtype):
expected_dtype = pd.ArrowDtype(pa.large_string())
elif isinstance(left_dtype, pd.ArrowDtype) and isinstance(right_dtype, pd.ArrowDtype):
left_type = left_dtype.pyarrow_dtype
right_type = right_dtype.pyarrow_dtype
if pa.types.is_large_string(left_type) or pa.types.is_large_string(right_type):
expected_dtype = pd.ArrowDtype(pa.large_string())
else:
expected_dtype = pd.ArrowDtype(pa.string())
assert result_dtype == expected_dtype
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks for taking a look! i appreciate the suggestion. however, i generally really dislike having logic in tests and try to minimise it where possible, I much prefer writing out the test cases by hand. there was a blog post on this a few* years ago which i really liked https://testing.googleblog.com/2014/07/testing-on-toilet-dont-put-logic-in.html *has it really been 12 years already? 😮 |
||
| ], | ||
| ) | ||
| def test_pandas_str_types(left_dtype: Any, right_dtype: Any, result_dtype: Any) -> None: | ||
| import pandas as pd | ||
|
|
||
| df = pd.DataFrame({"fruit": ["apple", "banana"]}, dtype=left_dtype) | ||
| df["new_str_col"] = "!" | ||
| df["new_str_col"] = df["new_str_col"].astype(right_dtype) # pyrefly: ignore[missing-attribute] https://github.com/facebook/pyrefly/issues/3299 | ||
| res = nw.from_native(df).with_columns( | ||
| concat_col=nw.concat_str([nw.col("fruit"), nw.col("new_str_col")]) | ||
| ) | ||
| expected = { | ||
| "fruit": ["apple", "banana"], | ||
| "new_str_col": ["!", "!"], | ||
| "concat_col": ["apple!", "banana!"], | ||
| } | ||
| assert_equal_data(res, expected) | ||
| assert res.to_native()["concat_col"].dtype == result_dtype | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice catch here :)