58 changes: 58 additions & 0 deletions narwhals/_spark_like/expr.py
@@ -16,6 +16,7 @@
)
from narwhals._sql.expr import SQLExpr
from narwhals._utils import Implementation, Version, extend_bool, no_default, zip_strict
from narwhals.dependencies import get_pandas

if TYPE_CHECKING:
from collections.abc import Iterator, Mapping, Sequence
@@ -407,6 +408,63 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
implementation=self._implementation,
)

def map_batches(
self,
function: Callable[[Any], Any],
return_dtype: IntoDType | None,
*,
returns_scalar: bool,
) -> Self: # pragma: no cover
if self._implementation.is_sqlframe():
msg = "`Expr.map_batches` is not supported for the sqlframe backend."
raise NotImplementedError(msg)

if returns_scalar:
msg = "`returns_scalar=True` is not supported for the pyspark backend."
raise NotImplementedError(msg)

def func(df: SparkLikeLazyFrame) -> list[Column]:
pd = get_pandas()
F = self._F

native_cols = self(df)
result_columns: list[Column] = []

for col_expr in native_cols:
if return_dtype is not None:
spark_type = narwhals_to_native_dtype(
return_dtype,
self._version,
self._native_dtypes,
df.native.sparkSession,
)
else:
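                    # no return_dtype given: fall back to Spark's FloatType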
spark_type = "float"

@F.pandas_udf(spark_type) # type: ignore[call-overload]
def udf_wrapper(s): # noqa: ANN001, ANN202
result = function(s)
if isinstance(result, pd.Series):
return result
# if the function returns a scalar, broadcast it to the length of s
series = pd.Series(result)
if len(series) == 1:
return pd.Series([result] * len(s))
return series

result_columns.append(udf_wrapper(col_expr))

return result_columns

return self.__class__(
func,
None,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
implementation=self._implementation,
)

@property
def str(self) -> SparkLikeExprStringNamespace:
return SparkLikeExprStringNamespace(self)
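Reviewer note: the broadcast-or-passthrough wrapper above can be exercised outside narwhals. Below is a minimal sketch of the same `pandas_udf` pattern; it assumes a local PySpark session with pandas installed, and `make_udf` plus the toy DataFrame are illustrative names, not part of this PR.

```python
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").appName("sketch").getOrCreate()
df = spark.createDataFrame([(1,), (2,), (3,)], ["a"])


def make_udf(function, spark_type):
    @F.pandas_udf(spark_type)
    def wrapper(s: pd.Series) -> pd.Series:
        result = function(s)
        if isinstance(result, pd.Series):
            return result
        series = pd.Series(result)
        if len(series) == 1:
            # scalar result: broadcast it to the batch length
            return pd.Series([result] * len(s))
        return series

    return wrapper


plus_one = make_udf(lambda s: s + 1, "bigint")
df.select(plus_one("a").alias("a1")).show()  # a1: 2, 3, 4
```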
6 changes: 1 addition & 5 deletions narwhals/expr.py
@@ -532,11 +532,7 @@ def map_batches(
|2 3 6 4.0 7.0|
└───────────────────────────┘
"""
-        kind = (
-            ExprKind.ORDERABLE_AGGREGATION
-            if returns_scalar
-            else ExprKind.ORDERABLE_FILTRATION
-        )
+        kind = ExprKind.AGGREGATION if returns_scalar else ExprKind.ELEMENTWISE
return self._append_node(
ExprNode(
kind,
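With the new kinds, `map_batches` without `returns_scalar` is treated as elementwise (length-preserving), while `returns_scalar=True` is an aggregation. A usage sketch against the pandas backend; the data and lambdas are illustrative, not from this diff:

```python
import pandas as pd

import narwhals as nw

df = nw.from_native(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))

# ELEMENTWISE: the batch function preserves length, so the result can sit
# next to untouched columns in the same select.
out = df.select(
    nw.col("a"),
    nw.col("b").map_batches(lambda s: s * 2, return_dtype=nw.Int64()),
)

# AGGREGATION: returns_scalar=True reduces each column to a single row.
agg = df.select(
    nw.col("a").map_batches(
        lambda s: s.to_numpy().sum(), return_dtype=nw.Int64(), returns_scalar=True
    )
)
```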
70 changes: 56 additions & 14 deletions tests/expr_and_series/map_batches_test.py
@@ -10,22 +10,39 @@
from tests.utils import (
PANDAS_VERSION,
POLARS_VERSION,
Constructor,
ConstructorEager,
assert_equal_data,
)

if TYPE_CHECKING:
from narwhals.dtypes import DType


def xfail_not_implemented(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
"""XFAIL if the constructor doesn't support map_batches."""
constructor_id = str(request.node.callspec.id)
if any(x in str(constructor) for x in ("dask", "duckdb", "ibis", "sqlframe")) or (
constructor_id == "pyspark[connect]"
):
pytest.xfail("constructor doesn't support map_batches")


data = {"a": [1, 2, 3], "b": [4, 5, 6], "z": [7.0, 8.0, 9.0]}


-def test_map_batches_expr_compliant(constructor_eager: ConstructorEager) -> None:
-    df = nw.from_native(constructor_eager(data))
+def test_map_batches_expr_compliant(
+    constructor: Constructor, request: pytest.FixtureRequest
+) -> None:
+    xfail_not_implemented(constructor, request)
+    df = nw.from_native(constructor(data))
expected = df.select(nw.col("a", "b").map_batches(lambda s: s + 1).name.suffix("1"))
assert_equal_data(expected, {"a1": [2, 3, 4], "b1": [5, 6, 7]})


# pyspark doesn't support returns_scalar=True
@pytest.mark.parametrize(
("value", "dtype"),
[(1, nw.Int64()), ("foo", nw.String()), ([1, 2], nw.List(nw.Int64()))],
@@ -48,16 +65,6 @@ def test_map_batches_expr_scalar(
assert_equal_data(expected, {"a": [value], "b": [value]})


-def test_map_batches_expr_numpy_array(constructor_eager: ConstructorEager) -> None:
-    df = nw.from_native(constructor_eager(data))
-    expected = df.select(
-        nw.col("a")
-        .map_batches(lambda s: s.to_numpy() + 1, return_dtype=nw.Float64())
-        .sum()
-    )
-    assert_equal_data(expected, {"a": [9.0]})


def test_map_batches_expr_numpy_scalar(constructor_eager: ConstructorEager) -> None:
df = nw.from_native(constructor_eager(data))

@@ -67,8 +74,24 @@ def test_map_batches_expr_numpy_scalar(constructor_eager: ConstructorEager) -> None:
assert_equal_data(expected, {"a": [2], "b": [2], "z": [2]})


-def test_map_batches_expr_names(constructor_eager: ConstructorEager) -> None:
-    df = nw.from_native(constructor_eager(data))
+def test_map_batches_expr_numpy_array(
+    constructor: Constructor, request: pytest.FixtureRequest
+) -> None:
+    xfail_not_implemented(constructor, request)
+    df = nw.from_native(constructor(data))
+    expected = df.select(
+        nw.col("a")
+        .map_batches(lambda s: s.to_numpy() + 1, return_dtype=nw.Float64())
+        .sum()
+    )
+    assert_equal_data(expected, {"a": [9.0]})


+def test_map_batches_expr_names(
+    constructor: Constructor, request: pytest.FixtureRequest
+) -> None:
+    xfail_not_implemented(constructor, request)
+    df = nw.from_native(constructor(data))
expected = nw.from_native(df.select(nw.all().map_batches(lambda x: x.to_numpy())))
assert_equal_data(expected, {"a": [1, 2, 3], "b": [4, 5, 6], "z": [7.0, 8.0, 9.0]})

@@ -89,3 +112,22 @@ def test_map_batches_exception(

with pytest.raises(TypeError, match=msg):
df.select(nw.all().map_batches(lambda s: s.to_numpy().argmax()))


@pytest.mark.parametrize(
("value", "dtype", "expected"),
[(1, None, [1.0] * 3), ("asd", nw.String(), ["asd"] * 3)],
)
def test_map_batches_pyspark_scalar(
constructor: Constructor,
request: pytest.FixtureRequest,
value: Any,
    dtype: DType | None,
expected: Any,
) -> None: # pragma: no cover
constructor_id = str(request.node.callspec.id)
if constructor_id != "pyspark":
pytest.xfail("Test only valid for pyspark")
df = nw.from_native(constructor(data))
    result = df.select(nw.col("a").map_batches(lambda _: value, dtype))
    assert_equal_data(result, {"a": expected})
1 change: 1 addition & 0 deletions tests/utils.py
@@ -187,6 +187,7 @@ def pyspark_session() -> SparkSession: # pragma: no cover
from pyspark.sql.connect.session import SparkSession
else:
from pyspark.sql import SparkSession
os.environ.setdefault("PYSPARK_PYTHON", sys.executable)
builder = cast("SparkSession.Builder", SparkSession.builder).appName("unit-tests")
builder = (
builder.remote(f"sc://localhost:{os.environ.get('SPARK_PORT', '15002')}")
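Why this default matters: pandas UDFs are pickled on the driver and executed by worker Python processes, so the workers must run a compatible interpreter; pinning `PYSPARK_PYTHON` to `sys.executable` keeps them on the test process's environment. A sketch of the equivalent via Spark config for a local session (the `spark.pyspark.python` key is standard Spark; the app name is illustrative):

```python
import sys

from pyspark.sql import SparkSession

# Equivalent to exporting PYSPARK_PYTHON before launching Spark.
spark = (
    SparkSession.builder.master("local[1]")
    .appName("unit-tests")
    .config("spark.pyspark.python", sys.executable)
    .getOrCreate()
)
```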