diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 98282a575e..83bb1c2bce 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -25,7 +25,7 @@ from narwhals._arrow.typing import ChunkedArrayAny, Incomplete, ScalarAny from narwhals._utils import Version - from narwhals.typing import IntoDType, NonNestedLiteral + from narwhals.typing import IntoDType, PythonLiteral class ArrowNamespace( @@ -64,7 +64,7 @@ def len(self) -> ArrowExpr: version=self._version, ) - def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> ArrowExpr: + def lit(self, value: PythonLiteral, dtype: IntoDType | None) -> ArrowExpr: def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries: arrow_series = ArrowSeries.from_iterable( data=[value], name="literal", context=self diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index c4791a7e0d..af3a9bc2f9 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -22,7 +22,7 @@ combine_alias_output_names, combine_evaluate_output_names, ) -from narwhals._utils import Implementation, zip_strict +from narwhals._utils import Implementation, is_nested_literal, zip_strict if TYPE_CHECKING: from collections.abc import Iterable, Iterator @@ -55,6 +55,10 @@ def __init__(self, *, version: Version) -> None: self._version = version def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> DaskExpr: + if is_nested_literal(value): + msg = f"Nested structures are not supported for Dask backend, found {type(value).__name__}" + raise NotImplementedError(msg) + def func(df: DaskLazyFrame) -> list[dx.Series]: if dtype is not None: native_dtype = narwhals_to_native_dtype(dtype, self._version) diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index ebc5041e68..71a94e4a95 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -35,7 +35,7 @@ from narwhals._compliant.window import WindowInputs from narwhals._utils 
import Version - from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral + from narwhals.typing import ConcatMethod, IntoDType, PythonLiteral VARCHAR = duckdb_dtypes.VARCHAR @@ -130,8 +130,12 @@ def func(cols: Iterable[Expression]) -> Expression: return self._expr._from_elementwise_horizontal_op(func, *exprs) - def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> DuckDBExpr: + def lit(self, value: PythonLiteral, dtype: IntoDType | None) -> DuckDBExpr: def func(df: DuckDBLazyFrame) -> list[Expression]: + if isinstance(value, dict) and not value: + msg = "Cannot create an empty struct type for DuckDB backend" + raise NotImplementedError(msg) + tz = DeferredTimeZone(df.native) if dtype is not None: target = narwhals_to_native_dtype(dtype, self._version, tz) diff --git a/narwhals/_ibis/namespace.py b/narwhals/_ibis/namespace.py index ddda316742..801a04dbd4 100644 --- a/narwhals/_ibis/namespace.py +++ b/narwhals/_ibis/namespace.py @@ -114,10 +114,15 @@ def func(cols: Iterable[ir.Value]) -> ir.Value: return self._expr._from_elementwise_horizontal_op(func, *exprs) - def lit(self, value: Any, dtype: IntoDType | None) -> IbisExpr: + def lit(self, value: PythonLiteral, dtype: IntoDType | None) -> IbisExpr: def func(_df: IbisLazyFrame) -> Sequence[ir.Value]: ibis_dtype = narwhals_to_native_dtype(dtype, self._version) if dtype else None - return [lit(value, ibis_dtype)] + if not isinstance(value, dict): + return [lit(value, ibis_dtype)] + if value: + return [ibis.struct(value, type=ibis_dtype)] + msg = "Cannot create an empty struct type for Ibis backend" + raise NotImplementedError(msg) return self._expr( func, diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index c822ecec04..257f52dcaa 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -10,6 +10,7 @@ from narwhals._pandas_like.series import PANDAS_TO_NUMPY_DTYPE_MISSING, PandasLikeSeries from narwhals._pandas_like.utils 
import ( align_and_extract_native, + broadcast_series_to_index, get_dtype_backend, import_array_module, iter_dtype_backends, @@ -307,8 +308,12 @@ def _with_native(self, df: Any, *, validate_column_names: bool = True) -> Self: def _extract_comparand(self, other: PandasLikeSeries) -> pd.Series[Any]: index = self.native.index if other._broadcast: - s = other.native - return type(s)(s.iloc[0], index=index, dtype=s.dtype, name=s.name) + native = other.native + is_nested = other.dtype.is_nested() + return broadcast_series_to_index( + native, index, is_nested=is_nested, series_class=type(native) + ) + if (len_other := len(other)) != (len_idx := len(index)): msg = f"Expected object of length {len_idx}, got: {len_other}." raise ShapeError(msg) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 96b7a22290..604d06bd25 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -25,7 +25,7 @@ from typing_extensions import TypeAlias from narwhals._utils import Implementation, Version - from narwhals.typing import IntoDType, NonNestedLiteral + from narwhals.typing import IntoDType, PythonLiteral Incomplete: TypeAlias = Any @@ -83,17 +83,46 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: context=self, ) - def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> PandasLikeExpr: + def lit(self, value: PythonLiteral, dtype: IntoDType | None) -> PandasLikeExpr: def _lit_pandas_series(df: PandasLikeDataFrame) -> PandasLikeSeries: - pandas_series = self._series.from_iterable( + if isinstance(value, (list, tuple, dict)): + try: + import pandas as pd # ignore-banned-import + import pyarrow as pa # ignore-banned-import + except ImportError as exc: # pragma: no cover + msg = ( + "Nested structures require pyarrow to be installed for pandas backend. 
" + "Please install pyarrow: pip install pyarrow" + ) + raise ImportError(msg) from exc + + from narwhals._arrow.utils import ( + narwhals_to_native_dtype as _to_arrow_dtype, + ) + + array_value = list(value) if isinstance(value, tuple) else value + pa_dtype = _to_arrow_dtype(dtype, self._version) if dtype else None + pa_array = pa.array([array_value], type=pa_dtype) # type: ignore[arg-type, list-item] + + # Use ArrowExtensionArray to avoid pandas unpacking the nested structure + ns = self._implementation.to_native_namespace() + pandas_series_native = ns.Series( + pd.arrays.ArrowExtensionArray(pa_array), # type: ignore[attr-defined] + name="literal", + index=df._native_frame.index[0:1], + ) + + return self._series.from_native(pandas_series_native, context=self) + + pandas_like_series = self._series.from_iterable( data=[value], name="literal", index=df._native_frame.index[0:1], context=self, ) if dtype: - return pandas_series.cast(dtype) - return pandas_series + return pandas_like_series.cast(dtype) + return pandas_like_series return PandasLikeExpr( lambda df: [_lit_pandas_series(df)], diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 09cae7120e..187a5c6c42 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -14,6 +14,7 @@ from narwhals._pandas_like.series_struct import PandasLikeSeriesStructNamespace from narwhals._pandas_like.utils import ( align_and_extract_native, + broadcast_series_to_index, get_dtype_backend, import_array_module, narwhals_to_native_dtype, @@ -211,8 +212,8 @@ def _align_full_broadcast(cls, *series: Self) -> Sequence[Self]: reindexed = [] for s in series: if s._broadcast: - native = Series( - s.native.iloc[0], index=idx, name=s.name, dtype=s.native.dtype + native = broadcast_series_to_index( + s.native, idx, is_nested=s.dtype.is_nested(), series_class=Series ) compliant = s._with_native(native) elif s.native.index is not idx: diff --git a/narwhals/_pandas_like/utils.py 
b/narwhals/_pandas_like/utils.py index 2de5813b81..1b5b60e71b 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -663,3 +663,37 @@ class PandasLikeSeriesNamespace(EagerSeriesNamespace["PandasLikeSeries", Any]): def make_group_by_kwargs(*, drop_null_keys: bool) -> dict[str, bool]: return {"sort": False, "as_index": True, "dropna": drop_null_keys, "observed": True} + + +def broadcast_series_to_index( + native: pd.Series[Any], + index: Any, + *, + is_nested: bool, + series_class: type[pd.Series[Any]], +) -> pd.Series[Any]: + """Broadcast a scalar value from a (one element) Series to match a target index. + + For nested (arrow-backed) types, we rely on + [`pandas.array`](https://pandas.pydata.org/docs/reference/api/pandas.array.html). + + Arguments: + native: The native pandas-like Series containing the scalar value to broadcast. + index: The target index to broadcast to. + is_nested: Whether the Series has a nested (arrow-backed) dtype. + series_class: Series class to use for constructing the result. + + Returns: + A new Series with the scalar value broadcast to match the target index. + """ + value = native.iloc[0] + if is_nested: + from narwhals._arrow.utils import repeat + + # NOTE: Ignore typing because `pandas-stubs` are wrong + # TODO(FBruzzesi): Should we pass the `copy=False` flag? 
+ pa_array = pd.array(repeat(value, len(index)), dtype=native.dtype) # type: ignore[arg-type] + + return series_class(pa_array, index=index, name=native.name) + + return series_class(value, index=index, dtype=native.dtype, name=native.name) diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index c660b67298..b96b807575 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -28,7 +28,7 @@ from narwhals._compliant.window import WindowInputs from narwhals._spark_like.dataframe import SQLFrameDataFrame # noqa: F401 from narwhals._utils import Implementation, Version - from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral, PythonLiteral + from narwhals.typing import ConcatMethod, IntoDType, PythonLiteral # Adjust slight SQL vs PySpark differences FUNCTION_REMAPPINGS = { @@ -91,9 +91,22 @@ def _when( def _coalesce(self, *exprs: Column) -> Column: return self._F.coalesce(*exprs) - def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> SparkLikeExpr: + def lit(self, value: PythonLiteral, dtype: IntoDType | None) -> SparkLikeExpr: def func(df: SparkLikeLazyFrame) -> list[Column]: - column = df._F.lit(value) + F = df._F + + if isinstance(value, (list, tuple)): + lit_values = [F.lit(v) for v in value] + column = F.lit(F.array(lit_values)) + elif isinstance(value, dict): + if (not self._implementation.is_pyspark()) and (len(value) == 0): + msg = f"Cannot create an empty struct type for {self._implementation} backend" + raise NotImplementedError(msg) + lit_values = [F.lit(v).alias(k) for k, v in value.items()] + column = F.struct(*lit_values) + else: + column = F.lit(value) + if dtype: native_dtype = narwhals_to_native_dtype( dtype, self._version, df._native_dtypes, df.native.sparkSession diff --git a/narwhals/_utils.py b/narwhals/_utils.py index a02552e90a..bc1795fb59 100644 --- a/narwhals/_utils.py +++ b/narwhals/_utils.py @@ -122,6 +122,7 @@ FileSource, IntoSeriesT, 
MultiIndexSelector, + NestedLiteral, SingleIndexSelector, SizedMultiBoolSelector, SizedMultiIndexSelector, @@ -1371,6 +1372,10 @@ def is_sequence_of(obj: Any, tp: type[_T]) -> TypeIs[Sequence[_T]]: ) +def is_nested_literal(obj: Any) -> TypeIs[NestedLiteral]: + return isinstance(obj, (list, tuple, dict)) + + def validate_strict_and_pass_though( strict: bool | None, # noqa: FBT001 pass_through: bool | None, # noqa: FBT001 diff --git a/narwhals/functions.py b/narwhals/functions.py index 322da7019e..a19478eb21 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -12,6 +12,7 @@ deprecate_native_namespace, flatten, is_eager_allowed, + is_nested_literal, is_sequence_but_not_str, normalize_path, supports_arrow_c_stream, @@ -46,6 +47,7 @@ IntoExpr, IntoSchema, NonNestedLiteral, + PythonLiteral, _2DArray, ) @@ -1422,27 +1424,70 @@ def all_horizontal(*exprs: IntoExpr | Iterable[IntoExpr], ignore_nulls: bool) -> ) -def lit(value: NonNestedLiteral, dtype: IntoDType | None = None) -> Expr: +def lit(value: PythonLiteral, dtype: IntoDType | None = None) -> Expr: """Return an expression representing a literal value. Arguments: - value: The value to use as literal. + value: The value to use as literal. Can be a scalar value, list, tuple, or dict. + Lists and tuples are converted to `List` dtype, dicts to `Struct` dtype. dtype: The data type of the literal value. If not provided, the data type will - be inferred by the native library. + be inferred by the native library. For empty lists/dicts, dtype must be + specified explicitly. 
Examples: - >>> import pandas as pd + Scalar literals: + + >>> import pyarrow as pa >>> import narwhals as nw >>> - >>> df_native = pd.DataFrame({"a": [1, 2]}) - >>> nw.from_native(df_native).with_columns(nw.lit(3)) + >>> df_nw = nw.from_native(pa.table({"a": [1, 2]})) + >>> df_nw.with_columns(nw.lit(3)) ┌──────────────────┐ |Narwhals DataFrame| |------------------| - | a literal | - | 0 1 3 | - | 1 2 3 | + | pyarrow.Table | + | a: int64 | + | literal: int64 | + | ---- | + | a: [[1,2]] | + | literal: [[3,3]] | └──────────────────┘ + + List literals (creates a List column): + + >>> df_nw.with_columns(nw.lit([1, 2, 3]).alias("list_col")) + ┌─────────────────────────────┐ + | Narwhals DataFrame | + |-----------------------------| + |pyarrow.Table | + |a: int64 | + |list_col: list | + | child 0, item: int64 | + |---- | + |a: [[1,2]] | + |list_col: [[[1,2,3],[1,2,3]]]| + └─────────────────────────────┘ + + Dict literals (creates a Struct column): + + >>> df_nw.with_columns(nw.lit({"x": 1, "y": 2}).alias("struct_col")) + ┌──────────────────────────────────────┐ + | Narwhals DataFrame | + |--------------------------------------| + |pyarrow.Table | + |a: int64 | + |struct_col: struct| + | child 0, x: int64 | + | child 1, y: int64 | + |---- | + |a: [[1,2]] | + |struct_col: [ | + | -- is_valid: all not null | + | -- child 0 type: int64 | + |[1,1] | + | -- child 1 type: int64 | + |[2,2]] | + └──────────────────────────────────────┘ """ if is_numpy_array(value): msg = ( @@ -1450,11 +1495,18 @@ def lit(value: NonNestedLiteral, dtype: IntoDType | None = None) -> Expr: "Consider using `with_columns` to create a new column from the array." ) raise ValueError(msg) - - if isinstance(value, (list, tuple)): - msg = f"Nested datatypes are not supported yet. Got {value}" - raise NotImplementedError(msg) - + if is_nested_literal(value): + if not value: + if not dtype: + msg = "Cannot infer dtype for empty nested structure. Please provide an explicit dtype parameter." 
+ raise ValueError(msg) + elif isinstance(value, dict): + if any(is_nested_literal(v) for v in value.values()): + msg = "Nested structures with nested values are not supported." + raise NotImplementedError(msg) + elif is_nested_literal(value[0]): + msg = "Nested structures with nested values are not supported." + raise NotImplementedError(msg) return Expr(ExprNode(ExprKind.LITERAL, "lit", value=value, dtype=dtype)) diff --git a/narwhals/stable/v2/__init__.py b/narwhals/stable/v2/__init__.py index 1627384601..864bd0a0e2 100644 --- a/narwhals/stable/v2/__init__.py +++ b/narwhals/stable/v2/__init__.py @@ -661,9 +661,11 @@ def lit(value: NonNestedLiteral, dtype: IntoDType | None = None) -> Expr: """Return an expression representing a literal value. Arguments: - value: The value to use as literal. + value: The value to use as literal. Can be a scalar value, list, tuple, or dict. + Lists and tuples are converted to `List` dtype, dicts to `Struct` dtype. dtype: The data type of the literal value. If not provided, the data type will - be inferred by the native library. + be inferred by the native library. For empty lists/dicts, dtype must be + specified explicitly. """ return _stableify(nw.lit(value, dtype)) diff --git a/narwhals/typing.py b/narwhals/typing.py index cf2852fc83..9454186722 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -270,7 +270,8 @@ def Binary(self) -> type[dtypes.Binary]: ... NonNestedLiteral: TypeAlias = ( "NumericLiteral | TemporalLiteral | str | bool | bytes | None" ) -PythonLiteral: TypeAlias = "NonNestedLiteral | list[Any] | tuple[Any, ...]" +NestedLiteral: TypeAlias = "list[Any] | tuple[Any, ...] 
| dict[str, Any]" +PythonLiteral: TypeAlias = "NonNestedLiteral | NestedLiteral" NonNestedDType: TypeAlias = "dtypes.NumericType | dtypes.TemporalType | dtypes.String | dtypes.Boolean | dtypes.Binary | dtypes.Categorical | dtypes.Unknown | dtypes.Object" """Any Narwhals DType that does not have required arguments.""" diff --git a/tests/expr_and_series/lit_test.py b/tests/expr_and_series/lit_test.py index 4a23ab9629..f0106f76e2 100644 --- a/tests/expr_and_series/lit_test.py +++ b/tests/expr_and_series/lit_test.py @@ -1,6 +1,5 @@ from __future__ import annotations -import re from datetime import date from typing import TYPE_CHECKING, Any @@ -11,12 +10,16 @@ CUDF_VERSION, DASK_VERSION, PANDAS_VERSION, + POLARS_VERSION, + PYARROW_VERSION, Constructor, assert_equal_data, + is_pyspark_connect, ) if TYPE_CHECKING: from narwhals.dtypes import DType + from narwhals.typing import IntoDType, PythonLiteral @pytest.mark.parametrize( @@ -45,14 +48,6 @@ def test_lit_error(constructor: Constructor) -> None: ValueError, match="numpy arrays are not supported as literal values" ): _ = df.with_columns(nw.lit(np.array([1, 2])).alias("lit")) # pyright: ignore[reportArgumentType] - with pytest.raises( - NotImplementedError, match=re.escape("Nested datatypes are not supported yet.") - ): - _ = df.with_columns(nw.lit((1, 2)).alias("lit")) # type: ignore[arg-type] - with pytest.raises( - NotImplementedError, match=re.escape("Nested datatypes are not supported yet.") - ): - _ = df.with_columns(nw.lit([1, 2]).alias("lit")) # type: ignore[arg-type] def test_lit_out_name(constructor: Constructor) -> None: @@ -143,3 +138,103 @@ def test_pyarrow_lit_string() -> None: assert pa.types.is_string(result.type) result = df.select(nw.lit("foo", dtype=nw.String)).to_native().schema.field("literal") assert pa.types.is_string(result.type) + + +@pytest.mark.parametrize( + ("value", "dtype"), + [ + # Empty nested structures + ((), nw.List(nw.Int32())), + ([], nw.List(nw.Int32())), + ({}, nw.Struct({})), + # 
Nested structures with different size from dataframe + (("foo", "bar"), None), + (["orca", "narwhal"], None), + ({"field_1": 42}, None), + # Nested structures with same size as dataframe + (("foo", "bar", "baz"), None), + (["orca", "narwhal", "penguin"], None), + ( + {"field_1": 42, "field_2": 1.2, "field_3": True}, + nw.Struct( + {"field_1": nw.Int32(), "field_2": nw.Float64(), "field_3": nw.Boolean()} + ), + ), + ], +) +def test_nested_structures( + request: pytest.FixtureRequest, + constructor: Constructor, + value: PythonLiteral, + dtype: IntoDType | None, +) -> None: + is_empty_dict = isinstance(value, dict) and len(value) == 0 + non_pyspark_sql_like = ("duckdb", "sqlframe", "ibis") + is_non_pyspark_sql_like = any(x in str(constructor) for x in non_pyspark_sql_like) + if (is_non_pyspark_sql_like or is_pyspark_connect(constructor)) and is_empty_dict: + reason = "Cannot create an empty struct type for backend" + request.applymarker(pytest.mark.xfail(reason=reason, raises=NotImplementedError)) + + # TODO(FBruzzesi): Check cudf + if any(x in str(constructor) for x in ("cudf", "dask")): + reason = "Nested structures are not supported for backend" + request.applymarker(pytest.mark.xfail(reason=reason, raises=NotImplementedError)) + + if any(x in str(constructor) for x in ("pandas", "modin")) and ( + PYARROW_VERSION == (0, 0, 0) or PANDAS_VERSION < (2, 0) + ): # pragma: no cover + reason = "Requires pyarrow and pandas 2.0+" + pytest.skip(reason=reason) + + if ( + "polars" in str(constructor) + and isinstance(value, dict) + and POLARS_VERSION < (1, 0, 0) + ): # pragma: no cover + reason = "polars<1.0 does not support dict to struct in lit" + pytest.skip(reason=reason) + + size = 3 + data = {"a": list(range(size))} + expr = nw.lit(value, dtype=dtype).alias("nested") + + value_ = list(value) if isinstance(value, tuple) else value + expected_nested = {"nested": [value_] * size} + + frame = nw.from_native(constructor(data)) + + result_with_cols = frame.with_columns(expr) + 
assert_equal_data(result_with_cols, {**data, **expected_nested}) + + result_select = frame.select(expr, nw.col("a")) + assert_equal_data(result_select, {**expected_nested, **data}) + + +@pytest.mark.parametrize("value", [[], (), {}]) +def test_raise_empty_nested_structures(value: PythonLiteral) -> None: + msg = "Cannot infer dtype for empty nested structure. Please provide an explicit dtype parameter." + with pytest.raises(ValueError, match=msg): + nw.lit(value=value) + + +@pytest.mark.parametrize( + "value", + [ + # List containing nested structures + [[1, 2], [3, 4]], + [(1, 2), (3, 4)], + [{"a": 1}, {"a": 2}], + # Tuple containing nested structures + ([1, 2], [3, 4]), + ((1, 2), (3, 4)), + ({"a": 1}, {"a": 2}), + # Dict containing nested structures + {"a": [1, 2], "b": [3, 4]}, + {"a": (1, 2), "b": (3, 4)}, + {"a": {"x": 1}, "b": {"y": 2}}, + ], +) +def test_raise_nested_structures_with_nested_values(value: Any) -> None: + msg = "Nested structures with nested values are not supported." 
+ with pytest.raises(NotImplementedError, match=msg): + nw.lit(value=value) diff --git a/tests/expr_and_series/replace_strict_test.py b/tests/expr_and_series/replace_strict_test.py index 7d4b5f4bf6..14bce078e1 100644 --- a/tests/expr_and_series/replace_strict_test.py +++ b/tests/expr_and_series/replace_strict_test.py @@ -1,13 +1,18 @@ from __future__ import annotations -import os from typing import TYPE_CHECKING, Any import pytest import narwhals as nw from narwhals.exceptions import InvalidOperationError -from tests.utils import POLARS_VERSION, Constructor, ConstructorEager, assert_equal_data +from tests.utils import ( + POLARS_VERSION, + Constructor, + ConstructorEager, + assert_equal_data, + xfail_if_pyspark_connect, +) if TYPE_CHECKING: from collections.abc import Mapping, Sequence @@ -30,15 +35,6 @@ def xfail_if_no_default(constructor: Constructor, request: pytest.FixtureRequest request.applymarker(pytest.mark.xfail(reason=reason)) -def xfail_if_pyspark_connect( # pragma: no cover - constructor: Constructor, request: pytest.FixtureRequest -) -> None: - is_spark_connect = os.environ.get("SPARK_CONNECT", None) - if is_spark_connect and "pyspark" in str(constructor): - reason = "`mapping_expr[expr]` raises: pyspark.errors.exceptions.base.PySparkTypeError: [UNSUPPORTED_DATA_TYPE] Unsupported DataType `Column`." - request.applymarker(pytest.mark.xfail(reason=reason)) - - @pytest.mark.parametrize( ("old", "new", "return_dtype"), [ @@ -139,7 +135,11 @@ def test_replace_strict_pandas_unnamed_series() -> None: def test_replace_strict_expr_with_default( constructor: Constructor, request: pytest.FixtureRequest, return_dtype: DType | None ) -> None: - xfail_if_pyspark_connect(constructor, request) + spark_connect_reason = ( + "`mapping_expr[expr]` raises: pyspark.errors.exceptions.base.PySparkTypeError: " + "[UNSUPPORTED_DATA_TYPE] Unsupported DataType `Column`." 
+ ) + xfail_if_pyspark_connect(constructor, request, reason=spark_connect_reason) if "polars" in str(constructor) and polars_lt_v1: pytest.skip(reason=pl_skip_reason) diff --git a/tests/utils.py b/tests/utils.py index 693569feb6..47bcd34884 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -8,6 +8,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, cast +import pytest + import narwhals as nw from narwhals._utils import Implementation, parse_version, zip_strict from narwhals.dependencies import get_pandas @@ -17,7 +19,6 @@ from collections.abc import Mapping, Sequence import pandas as pd - import pytest from pyspark.sql import SparkSession from sqlframe.duckdb import DuckDBSession from typing_extensions import TypeAlias @@ -252,3 +253,15 @@ def time_unit_compat(time_unit: TimeUnit, request: pytest.FixtureRequest, /) -> if PANDAS_VERSION < (2,) and any(name in request_id for name in pandas_like): return "ns" return time_unit + + +def is_pyspark_connect(constructor: Constructor) -> bool: + is_spark_connect = bool(os.environ.get("SPARK_CONNECT", None)) + return is_spark_connect and ("pyspark" in str(constructor)) + + +def xfail_if_pyspark_connect( # pragma: no cover + constructor: Constructor, request: pytest.FixtureRequest, reason: str = "" +) -> None: + if is_pyspark_connect(constructor): + request.applymarker(pytest.mark.xfail(reason=reason))