def _with_binary(self, op: Callable[..., pd.Series], other: Any) -> Self:
    """Apply the binary operator `op` to this series and `other`.

    The operands are first aligned and unwrapped to native objects via
    `align_and_extract_native`. If `op` raises and it is an addition, the
    operation is retried through `binary_string_sum_fallback`; any other
    failing operator re-raises the original exception unchanged.

    Arguments:
        op: Callable taking `(native_series, native_other)` and returning
            the native result.
        other: Right-hand operand (anything `align_and_extract_native`
            accepts).

    Returns:
        A new series wrapping the result, named after `self`.
    """
    ser, other_native = align_and_extract_native(self, other)
    preserve_broadcast = self._broadcast and getattr(other, "_broadcast", True)
    try:
        res = op(ser, other_native)
    except Exception:
        # Only addition has a workaround. `__name__` is read defensively because
        # an arbitrary callable (e.g. `functools.partial`) may not define it, and
        # an AttributeError raised here would hide the real failure.
        if getattr(op, "__name__", None) != "add":
            raise
        pdx = self.__native_namespace__()
        res = binary_string_sum_fallback(ser, other_native, pdx)
    return self._with_native(res, preserve_broadcast=preserve_broadcast).alias(
        self.name
    )


def _with_binary_right(self, op: Callable[..., pd.Series], other: Any) -> Self:
    """Apply `op` with the operands swapped, i.e. compute `op(other, self)`.

    The swap is done with a lambda, whose `__name__` is `"<lambda>"`, so the
    addition fallback in `_with_binary` is never triggered for reflected
    operations.
    """
    return self._with_binary(lambda x, y: op(y, x), other).alias(self.name)
def binary_string_sum_fallback(left: pd.Series, right: Any, pdx: Any) -> pd.Series:
    """Retry `left + right` after reconciling the operands' dtypes.

    Works around these upstream issues:

    - https://github.com/pandas-dev/pandas/issues/64393
    - https://github.com/pandas-dev/pandas/issues/65220

    Arguments:
        left: Native left-hand series.
        right: Native right-hand operand (a scalar or a `pdx.Series`).
        pdx: Native namespace; only `pdx.Series` is used, for `isinstance`.

    Returns:
        The result of the addition once one operand has been cast (or, when
        no rule applies, of the plain `left + right`).
    """
    lhs_dtype = left.dtype
    lhs_name = str(lhs_dtype)

    if isinstance(right, str) and lhs_name == "large_string[pyarrow]":
        import pyarrow as pa  # ignore-banned-import

        # Wrap the Python string in a scalar of the matching Arrow type.
        return left + pa.scalar(right, type=pa.large_string())

    if not isinstance(right, pdx.Series):
        # No workaround applies: redo the plain addition.
        return left + right  # pragma: no cover

    rhs_dtype = right.dtype
    if lhs_name == "object":  # pragma: no cover
        # Only for pandas pre 3.0. Anything is better than `object`, so take RHS.
        return left.astype(rhs_dtype) + right

    if hasattr(left.values, "__arrow_array__") and hasattr(
        right.values, "__arrow_array__"
    ):
        import pyarrow as pa  # ignore-banned-import

        lhs_arrow = left.values.__arrow_array__().type  # noqa: PD011 # type: ignore[attr-defined]
        rhs_arrow = right.values.__arrow_array__().type  # noqa: PD011 # type: ignore[attr-defined]
        if pa.types.is_string(lhs_arrow) and pa.types.is_large_string(rhs_arrow):
            # LHS is smaller than RHS, so use the latter.
            return left.astype(rhs_dtype) + right

    # Give precedence to the left-hand-side dtype.
    # NOTE(review): this cast also runs for non-string operands; confirm that
    # callers only reach this point with string-like dtypes.
    return left + right.astype(lhs_dtype)
@pytest.mark.parametrize(
    ("left_dtype", "right_dtype", "result_dtype"),
    [
        (STRING_DTYPE_NAN, STRING_DTYPE_NAN, STRING_DTYPE_NAN),
        (STRING_DTYPE_NAN, STRING_DTYPE_NA, STRING_DTYPE_NAN),
        (STRING_DTYPE_NAN, pd.ArrowDtype(pa.string()), STRING_DTYPE_NAN),
        (STRING_DTYPE_NAN, pd.ArrowDtype(pa.large_string()), STRING_DTYPE_NAN),
        (STRING_DTYPE_NA, STRING_DTYPE_NAN, STRING_DTYPE_NA),
        (STRING_DTYPE_NA, STRING_DTYPE_NA, STRING_DTYPE_NA),
        (STRING_DTYPE_NA, pd.ArrowDtype(pa.string()), STRING_DTYPE_NA),
        (STRING_DTYPE_NA, pd.ArrowDtype(pa.large_string()), STRING_DTYPE_NA),
        (pd.ArrowDtype(pa.string()), STRING_DTYPE_NAN, STRING_DTYPE_NAN),
        (pd.ArrowDtype(pa.string()), STRING_DTYPE_NA, STRING_DTYPE_NA),
        (
            pd.ArrowDtype(pa.string()),
            pd.ArrowDtype(pa.string()),
            pd.ArrowDtype(pa.string()),
        ),
        (
            pd.ArrowDtype(pa.string()),
            pd.ArrowDtype(pa.large_string()),
            pd.ArrowDtype(pa.large_string()),
        ),
        (
            pd.ArrowDtype(pa.large_string()),
            STRING_DTYPE_NAN,
            pd.ArrowDtype(pa.large_string()),
        ),
        (
            pd.ArrowDtype(pa.large_string()),
            STRING_DTYPE_NA,
            pd.ArrowDtype(pa.large_string()),
        ),
        (
            pd.ArrowDtype(pa.large_string()),
            pd.ArrowDtype(pa.string()),
            pd.ArrowDtype(pa.large_string()),
        ),
        (
            pd.ArrowDtype(pa.large_string()),
            pd.ArrowDtype(pa.large_string()),
            pd.ArrowDtype(pa.large_string()),
        ),
    ],
)
def test_pandas_str_types(left_dtype: Any, right_dtype: Any, result_dtype: Any) -> None:
    """Concatenate two pandas string columns of possibly different dtypes.

    The "fruit" column is built with `left_dtype`, the "new_str_col" column is
    cast to `right_dtype`, and the `concat_str` result must hold the expected
    values and carry `result_dtype`.
    """
    import pandas as pd

    frame = pd.DataFrame({"fruit": ["apple", "banana"]}, dtype=left_dtype)
    # Create the column from a scalar first, then cast it to the dtype under test.
    frame["new_str_col"] = "!"
    frame["new_str_col"] = frame["new_str_col"].astype(right_dtype)

    concatenated = nw.concat_str([nw.col("fruit"), nw.col("new_str_col")])
    res = nw.from_native(frame).with_columns(concat_col=concatenated)

    expected = {
        "fruit": ["apple", "banana"],
        "new_str_col": ["!", "!"],
        "concat_col": ["apple!", "banana!"],
    }
    assert_equal_data(res, expected)
    assert res.to_native()["concat_col"].dtype == result_dtype