diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py
index 805e00c8e4..3e9dbcca82 100644
--- a/narwhals/_spark_like/expr.py
+++ b/narwhals/_spark_like/expr.py
@@ -16,6 +16,7 @@
 )
 from narwhals._sql.expr import SQLExpr
 from narwhals._utils import Implementation, Version, extend_bool, no_default, zip_strict
+from narwhals.dependencies import get_pandas
 
 if TYPE_CHECKING:
     from collections.abc import Iterator, Mapping, Sequence
@@ -407,6 +408,63 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
             implementation=self._implementation,
         )
 
+    def map_batches(
+        self,
+        function: Callable[[Any], Any],
+        return_dtype: IntoDType | None,
+        *,
+        returns_scalar: bool,
+    ) -> Self:  # pragma: no cover
+        if self._implementation.is_sqlframe():
+            msg = "`Expr.map_batches` is not supported for the sqlframe backend."
+            raise NotImplementedError(msg)
+
+        if returns_scalar:
+            msg = "`returns_scalar=True` is not supported for the pyspark backend."
+            raise NotImplementedError(msg)
+
+        def func(df: SparkLikeLazyFrame) -> list[Column]:
+            pd = get_pandas()
+            F = self._F
+
+            native_cols = self(df)
+            result_columns: list[Column] = []
+
+            for col_expr in native_cols:
+                if return_dtype is not None:
+                    spark_type = narwhals_to_native_dtype(
+                        return_dtype,
+                        self._version,
+                        self._native_dtypes,
+                        df.native.sparkSession,
+                    )
+                else:
+                    spark_type = "float"
+
+                @F.pandas_udf(spark_type)  # type: ignore[call-overload]
+                def udf_wrapper(s):  # noqa: ANN001, ANN202
+                    result = function(s)
+                    if isinstance(result, pd.Series):
+                        return result
+                    # if the function returns a scalar, broadcast it to the length of s
+                    series = pd.Series(result)
+                    if len(series) == 1:
+                        return pd.Series([result] * len(s))
+                    return series
+
+                result_columns.append(udf_wrapper(col_expr))
+
+            return result_columns
+
+        return self.__class__(
+            func,
+            None,
+            evaluate_output_names=self._evaluate_output_names,
+            alias_output_names=self._alias_output_names,
+            version=self._version,
+            implementation=self._implementation,
+        )
+
     @property
     def str(self) -> SparkLikeExprStringNamespace:
         return SparkLikeExprStringNamespace(self)
diff --git a/narwhals/expr.py b/narwhals/expr.py
index 9b6f616f71..c81f302421 100644
--- a/narwhals/expr.py
+++ b/narwhals/expr.py
@@ -532,11 +532,7 @@ def map_batches(
         |2 3 6 4.0 7.0|
         └───────────────────────────┘
         """
-        kind = (
-            ExprKind.ORDERABLE_AGGREGATION
-            if returns_scalar
-            else ExprKind.ORDERABLE_FILTRATION
-        )
+        kind = ExprKind.AGGREGATION if returns_scalar else ExprKind.ELEMENTWISE
         return self._append_node(
             ExprNode(
                 kind,
diff --git a/tests/expr_and_series/map_batches_test.py b/tests/expr_and_series/map_batches_test.py
index d6f8cc8b32..391a759846 100644
--- a/tests/expr_and_series/map_batches_test.py
+++ b/tests/expr_and_series/map_batches_test.py
@@ -10,6 +10,7 @@
 from tests.utils import (
     PANDAS_VERSION,
     POLARS_VERSION,
+    Constructor,
     ConstructorEager,
     assert_equal_data,
 )
@@ -17,15 +18,31 @@
 if TYPE_CHECKING:
     from narwhals.dtypes import DType
 
+
+def xfail_not_implemented(
+    constructor: Constructor, request: pytest.FixtureRequest
+) -> None:
+    """XFAIL if the constructor doesn't support map_batches."""
+    constructor_id = str(request.node.callspec.id)
+    if any(x in str(constructor) for x in ("dask", "duckdb", "ibis", "sqlframe")) or (
+        constructor_id == "pyspark[connect]"
+    ):
+        pytest.xfail("constructor doesn't support map_batches")
+
+
 data = {"a": [1, 2, 3], "b": [4, 5, 6], "z": [7.0, 8.0, 9.0]}
 
 
-def test_map_batches_expr_compliant(constructor_eager: ConstructorEager) -> None:
-    df = nw.from_native(constructor_eager(data))
+def test_map_batches_expr_compliant(
+    constructor: Constructor, request: pytest.FixtureRequest
+) -> None:
+    xfail_not_implemented(constructor, request)
+    df = nw.from_native(constructor(data))
     expected = df.select(nw.col("a", "b").map_batches(lambda s: s + 1).name.suffix("1"))
     assert_equal_data(expected, {"a1": [2, 3, 4], "b1": [5, 6, 7]})
 
 
+# pyspark doesn't support returns_scalar=True
 @pytest.mark.parametrize(
     ("value", "dtype"),
     [(1, nw.Int64()), ("foo", nw.String()), ([1, 2], nw.List(nw.Int64()))],
@@ -48,16 +65,6 @@ def test_map_batches_expr_scalar(
     assert_equal_data(expected, {"a": [value], "b": [value]})
 
 
-def test_map_batches_expr_numpy_array(constructor_eager: ConstructorEager) -> None:
-    df = nw.from_native(constructor_eager(data))
-    expected = df.select(
-        nw.col("a")
-        .map_batches(lambda s: s.to_numpy() + 1, return_dtype=nw.Float64())
-        .sum()
-    )
-    assert_equal_data(expected, {"a": [9.0]})
-
-
 def test_map_batches_expr_numpy_scalar(constructor_eager: ConstructorEager) -> None:
     df = nw.from_native(constructor_eager(data))
@@ -67,8 +74,24 @@ def test_map_batches_expr_numpy_scalar(constructor_eager: ConstructorEager) -> None:
     assert_equal_data(expected, {"a": [2], "b": [2], "z": [2]})
 
 
-def test_map_batches_expr_names(constructor_eager: ConstructorEager) -> None:
-    df = nw.from_native(constructor_eager(data))
+def test_map_batches_expr_numpy_array(
+    constructor: Constructor, request: pytest.FixtureRequest
+) -> None:
+    xfail_not_implemented(constructor, request)
+    df = nw.from_native(constructor(data))
+    expected = df.select(
+        nw.col("a")
+        .map_batches(lambda s: s.to_numpy() + 1, return_dtype=nw.Float64())
+        .sum()
+    )
+    assert_equal_data(expected, {"a": [9.0]})
+
+
+def test_map_batches_expr_names(
+    constructor: Constructor, request: pytest.FixtureRequest
+) -> None:
+    xfail_not_implemented(constructor, request)
+    df = nw.from_native(constructor(data))
     expected = nw.from_native(df.select(nw.all().map_batches(lambda x: x.to_numpy())))
     assert_equal_data(expected, {"a": [1, 2, 3], "b": [4, 5, 6], "z": [7.0, 8.0, 9.0]})
 
@@ -89,3 +112,22 @@ def test_map_batches_exception(
     with pytest.raises(TypeError, match=msg):
         df.select(nw.all().map_batches(lambda s: s.to_numpy().argmax()))
+
+
+@pytest.mark.parametrize(
+    ("value", "dtype", "expected"),
+    [(1, None, [1.0] * 3), ("asd", nw.String(), ["asd"] * 3)],
+)
+def test_map_batches_pyspark_scalar(
+    constructor: Constructor,
+    request: pytest.FixtureRequest,
+    value: Any,
+    dtype: DType | None,
+    expected: Any,
+) -> None:  # pragma: no cover
+    constructor_id = str(request.node.callspec.id)
+    if constructor_id != "pyspark":
+        pytest.xfail("Test only valid for pyspark")
+    df = nw.from_native(constructor(data))
+    result = df.select(nw.col("a").map_batches(lambda _: value, dtype))
+    assert_equal_data(result, {"a": expected})
diff --git a/tests/utils.py b/tests/utils.py
index 4d01223b2a..9d6c2c543c 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -187,6 +187,7 @@ def pyspark_session() -> SparkSession:  # pragma: no cover
         from pyspark.sql.connect.session import SparkSession
     else:
         from pyspark.sql import SparkSession
+    os.environ.setdefault("PYSPARK_PYTHON", sys.executable)
     builder = cast("SparkSession.Builder", SparkSession.builder).appName("unit-tests")
     builder = (
         builder.remote(f"sc://localhost:{os.environ.get('SPARK_PORT', '15002')}")
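For anyone who wants to try the new pyspark path locally, here is a minimal usage sketch (not part of the diff; the session setup and column names are illustrative only). It assumes a local PySpark installation with pandas and pyarrow available, since the implementation routes `map_batches` through `F.pandas_udf`:

```python
# Hypothetical demo of Expr.map_batches on the pyspark backend.
# Assumes local PySpark with pandas + pyarrow installed (needed by pandas_udf).
import narwhals as nw
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("map-batches-demo").getOrCreate()
native = spark.createDataFrame([(1, 7.0), (2, 8.0), (3, 9.0)], ["a", "z"])
df = nw.from_native(native)  # a narwhals LazyFrame

# The function receives each batch as a pandas Series; returning a Series of
# the same length behaves elementwise (ExprKind.ELEMENTWISE after this change).
shifted = df.select(nw.col("a").map_batches(lambda s: s + 1, return_dtype=nw.Int64()))
print(shifted.to_native().toPandas())  # a -> 2, 3, 4

# A scalar return is broadcast to the batch length by udf_wrapper; with
# return_dtype=None the UDF's return type falls back to "float".
constant = df.select(nw.col("a").map_batches(lambda _: 1))
print(constant.to_native().toPandas())  # a -> 1.0, 1.0, 1.0

spark.stop()
```

Note that `returns_scalar=True` still raises `NotImplementedError` on this backend and sqlframe is rejected outright; the new `xfail_not_implemented` helper mirrors those gaps on the test side.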