diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 12b6ba5edb..84463f494f 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -126,8 +126,8 @@ def mean_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr: def func(df: ArrowDataFrame) -> list[ArrowSeries]: expr_results = tuple(chain.from_iterable(expr(df) for expr in exprs)) - series = [s.fill_null(0, strategy=None, limit=None) for s in expr_results] - non_na = [1 - s.is_null().cast(int_64) for s in expr_results] + series = (s.fill_null(0, strategy=None, limit=None) for s in expr_results) + non_na = (1 - s.is_null().cast(int_64) for s in expr_results) return [reduce(operator.add, series) / reduce(operator.add, non_na)] return self._expr._from_callable( @@ -137,15 +137,15 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: context=self, ) - def min_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr: + def _min_max_horizontal( + self, exprs: Sequence[ArrowExpr], /, op: Literal["min", "max"] + ) -> ArrowExpr: + agg = pc.min_element_wise if op == "min" else pc.max_element_wise + def func(df: ArrowDataFrame) -> list[ArrowSeries]: - init_series, *series = tuple(chain.from_iterable(expr(df) for expr in exprs)) - native_series = reduce( - pc.min_element_wise, [s.native for s in series], init_series.native - ) - return [ - ArrowSeries(native_series, name=init_series.name, version=self._version) - ] + series = tuple(chain.from_iterable(expr(df) for expr in exprs)) + result = agg(*(s.native[0] if s._broadcast else s.native for s in series)) + return [ArrowSeries(result, name=series[0].name, version=self._version)] return self._expr._from_callable( func=func, @@ -154,22 +154,11 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: context=self, ) - def max_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr: - def func(df: ArrowDataFrame) -> list[ArrowSeries]: - init_series, *series = tuple(chain.from_iterable(expr(df) for expr in exprs)) - native_series = reduce( - pc.max_element_wise, [s.native for s in series], init_series.native - ) - return [ - ArrowSeries(native_series, name=init_series.name, version=self._version) - ] + def min_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr: + return self._min_max_horizontal(exprs, "min") - return self._expr._from_callable( - func=func, - evaluate_output_names=combine_evaluate_output_names(*exprs), - alias_output_names=combine_alias_output_names(*exprs), - context=self, - ) + def max_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr: + return self._min_max_horizontal(exprs, "max") def _concat_diagonal(self, dfs: Sequence[pa.Table], /) -> pa.Table: if self._backend_version >= (14,): diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index af3a9bc2f9..83d2ddcf1c 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -172,13 +172,13 @@ def concat( def mean_horizontal(self, *exprs: DaskExpr) -> DaskExpr: def func(df: DaskLazyFrame) -> list[dx.Series]: - expr_results = [s for _expr in exprs for s in _expr(df)] - series = align_series_full_broadcast(df, *(s.fillna(0) for s in expr_results)) - non_na = align_series_full_broadcast( - df, *(1 - s.isna() for s in expr_results) + expr_results = align_series_full_broadcast( + df, *[s for _expr in exprs for s in _expr(df)] ) - num = reduce(lambda x, y: x + y, series) # pyright: ignore[reportOperatorIssue] - den = reduce(lambda x, y: x + y, non_na) # pyright: ignore[reportOperatorIssue] + series = (s.fillna(0) for s in expr_results) + non_na = (1 - s.isna() for s in expr_results) + num = reduce(operator.add, series) # pyright: ignore[reportOperatorIssue] + den = reduce(operator.add, non_na) # pyright: ignore[reportOperatorIssue] return [cast("dx.Series", num / den)] # pyright: ignore[reportOperatorIssue] return self._expr( diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 604d06bd25..94aca63464 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -235,7 +235,9 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: def min_horizontal(self, *exprs: PandasLikeExpr) -> PandasLikeExpr: def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: - series = list(chain.from_iterable(expr(df) for expr in exprs)) + series = self._series._align_full_broadcast( + *chain.from_iterable(expr(df) for expr in exprs) + ) return [ PandasLikeSeries( self.concat( @@ -255,7 +257,9 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: def max_horizontal(self, *exprs: PandasLikeExpr) -> PandasLikeExpr: def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: - series = list(chain.from_iterable(expr(df) for expr in exprs)) + series = self._series._align_full_broadcast( + *chain.from_iterable(expr(df) for expr in exprs) + ) return [ PandasLikeSeries( self.concat( diff --git a/narwhals/functions.py b/narwhals/functions.py index de200bae11..88de309598 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -1266,7 +1266,9 @@ def _expr_with_horizontal_op(name: str, *exprs: IntoExpr, **kwargs: Any) -> Expr ) -def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: +def sum_horizontal( + *exprs: PythonLiteral | IntoExpr | Iterable[PythonLiteral | IntoExpr], +) -> Expr: """Sum all values horizontally across columns. Warning: @@ -1300,7 +1302,9 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: return _expr_with_horizontal_op("sum_horizontal", *flatten(exprs)) -def min_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: +def min_horizontal( + *exprs: PythonLiteral | IntoExpr | Iterable[PythonLiteral | IntoExpr], +) -> Expr: """Get the minimum value horizontally across columns. Notes: @@ -1332,7 +1336,9 @@ def min_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: return _expr_with_horizontal_op("min_horizontal", *flatten(exprs)) -def max_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: +def max_horizontal( + *exprs: PythonLiteral | IntoExpr | Iterable[PythonLiteral | IntoExpr], +) -> Expr: """Get the maximum value horizontally across columns. Notes: @@ -1431,7 +1437,10 @@ def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When: return When(*predicates) -def all_horizontal(*exprs: IntoExpr | Iterable[IntoExpr], ignore_nulls: bool) -> Expr: +def all_horizontal( + *exprs: PythonLiteral | IntoExpr | Iterable[PythonLiteral | IntoExpr], + ignore_nulls: bool, +) -> Expr: r"""Compute the bitwise AND horizontally across columns. Arguments: @@ -1560,7 +1569,10 @@ def lit(value: PythonLiteral, dtype: IntoDType | None = None) -> Expr: return Expr(ExprNode(ExprKind.LITERAL, "lit", value=value, dtype=dtype)) -def any_horizontal(*exprs: IntoExpr | Iterable[IntoExpr], ignore_nulls: bool) -> Expr: +def any_horizontal( + *exprs: PythonLiteral | IntoExpr | Iterable[PythonLiteral | IntoExpr], + ignore_nulls: bool, +) -> Expr: r"""Compute the bitwise OR horizontally across columns. Arguments: @@ -1608,7 +1620,9 @@ def any_horizontal(*exprs: IntoExpr | Iterable[IntoExpr], ignore_nulls: bool) -> ) -def mean_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: +def mean_horizontal( + *exprs: PythonLiteral | IntoExpr | Iterable[PythonLiteral | IntoExpr], +) -> Expr: """Compute the mean of all values horizontally across columns. Arguments: diff --git a/tests/expr_and_series/all_horizontal_test.py b/tests/expr_and_series/all_horizontal_test.py index d980b3def3..1d70917739 100644 --- a/tests/expr_and_series/all_horizontal_test.py +++ b/tests/expr_and_series/all_horizontal_test.py @@ -1,13 +1,16 @@ from __future__ import annotations from contextlib import nullcontext as does_not_raise -from typing import Any +from typing import TYPE_CHECKING, Any import pytest import narwhals as nw from tests.utils import POLARS_VERSION, Constructor, ConstructorEager, assert_equal_data +if TYPE_CHECKING: + from narwhals.typing import PythonLiteral + def test_allh(constructor: Constructor) -> None: data = {"a": [False, False, True], "b": [False, True, True]} @@ -157,3 +160,24 @@ def test_horizontal_expressions_empty(constructor: Constructor) -> None: ValueError, match=r"At least one expression must be passed.*min_horizontal" ): df.select(nw.min_horizontal()) + + +@pytest.mark.parametrize( + ("exprs", "name"), + [ + ((nw.col("a"), True), "a"), + ((nw.col("a"), nw.lit(True)), "a"), + ((True, nw.col("a")), "literal"), + ((nw.lit(True), nw.col("a")), "literal"), + ], +) +def test_allh_with_scalars( + constructor: Constructor, exprs: tuple[PythonLiteral | nw.Expr, ...], name: str +) -> None: + if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 19): + name = "a" + + data = {"a": [False, True]} + df = nw.from_native(constructor(data)) + result = df.select(nw.all_horizontal(*exprs, ignore_nulls=True)) + assert_equal_data(result, {name: [False, True]}) diff --git a/tests/expr_and_series/any_horizontal_test.py b/tests/expr_and_series/any_horizontal_test.py index 04f0cba76c..44da3b55a2 100644 --- a/tests/expr_and_series/any_horizontal_test.py +++ b/tests/expr_and_series/any_horizontal_test.py @@ -1,12 +1,16 @@ from __future__ import annotations from contextlib import nullcontext as does_not_raise +from typing import TYPE_CHECKING import pytest import narwhals as nw from tests.utils import Constructor, assert_equal_data +if TYPE_CHECKING: + from narwhals.typing import PythonLiteral + def test_anyh(constructor: Constructor) -> None: data = {"a": [False, False, True], "b": [False, True, True]} @@ -85,3 +89,21 @@ def test_anyh_all(constructor: Constructor) -> None: result = df.select(nw.any_horizontal(nw.all(), ignore_nulls=False)) expected = {"a": [False, True, True]} assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("exprs", "name"), + [ + ((nw.col("a"), False), "a"), + ((nw.col("a"), nw.lit(False)), "a"), + ((False, nw.col("a")), "literal"), + ((nw.lit(False), nw.col("a")), "literal"), + ], +) +def test_anyh_with_scalars( + constructor: Constructor, exprs: tuple[PythonLiteral | nw.Expr, ...], name: str +) -> None: + data = {"a": [False, True]} + df = nw.from_native(constructor(data)) + result = df.select(nw.any_horizontal(*exprs, ignore_nulls=True)) + assert_equal_data(result, {name: [False, True]}) diff --git a/tests/expr_and_series/max_horizontal_test.py b/tests/expr_and_series/max_horizontal_test.py index cc0bddfb1a..8e8a6da9b3 100644 --- a/tests/expr_and_series/max_horizontal_test.py +++ b/tests/expr_and_series/max_horizontal_test.py @@ -1,10 +1,15 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import pytest import narwhals as nw from tests.utils import Constructor, assert_equal_data +if TYPE_CHECKING: + from narwhals.typing import PythonLiteral + data = {"a": [1, 3, None, None], "b": [4, None, 6, None], "z": [3, 1, None, None]} expected_values = [4, 3, 6, None] @@ -23,3 +28,20 @@ def test_maxh_all(constructor: Constructor) -> None: result = df.select(nw.max_horizontal(nw.all()), c=nw.max_horizontal(nw.all())) expected = {"a": expected_values, "c": expected_values} assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("exprs", "name"), + [ + ((nw.col("a"), 2), "a"), + ((nw.col("a"), nw.lit(2)), "a"), + ((2, nw.col("a")), "literal"), + ((nw.lit(2), nw.col("a")), "literal"), + ], +) +def test_maxh_with_scalars( + constructor: Constructor, exprs: tuple[PythonLiteral | nw.Expr, ...], name: str +) -> None: + df = nw.from_native(constructor({"a": [1, 2, 3]})) + result = df.select(nw.max_horizontal(*exprs)) + assert_equal_data(result, {name: [2, 2, 3]}) diff --git a/tests/expr_and_series/mean_horizontal_test.py b/tests/expr_and_series/mean_horizontal_test.py index bc5bc12fa6..5888918651 100644 --- a/tests/expr_and_series/mean_horizontal_test.py +++ b/tests/expr_and_series/mean_horizontal_test.py @@ -1,10 +1,15 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import pytest import narwhals as nw from tests.utils import Constructor, assert_equal_data +if TYPE_CHECKING: + from narwhals.typing import PythonLiteral + def test_meanh(constructor: Constructor) -> None: data = {"a": [1, 3, None, None], "b": [4, None, 6, None]} @@ -14,11 +19,7 @@ def test_meanh(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_meanh_with_literal( - constructor: Constructor, request: pytest.FixtureRequest -) -> None: - if "dask" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_meanh_with_literal(constructor: Constructor) -> None: data = {"a": [1, 3, None, None], "b": [4, None, 6, None]} df = nw.from_native(constructor(data)) result = df.select(horizontal_mean=nw.mean_horizontal(nw.lit(1), "a", nw.col("b"))) @@ -35,3 +36,21 @@ def test_meanh_all(constructor: Constructor) -> None: result = df.select(c=nw.mean_horizontal(nw.all())) expected = {"c": [6, 12, 18]} assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("exprs", "name"), + [ + ((nw.col("a"), 1), "a"), + ((nw.col("a"), nw.lit(1)), "a"), + ((1, nw.col("a")), "literal"), + ((nw.lit(1), nw.col("a")), "literal"), + ], +) +def test_meanh_with_scalars( + constructor: Constructor, exprs: tuple[PythonLiteral | nw.Expr, ...], name: str +) -> None: + data = {"a": [1, 2, 3]} + df = nw.from_native(constructor(data)) + result = df.select(nw.mean_horizontal(*exprs)) + assert_equal_data(result, {name: [1.0, 1.5, 2.0]}) diff --git a/tests/expr_and_series/min_horizontal_test.py b/tests/expr_and_series/min_horizontal_test.py index df9ff31feb..4342eeb075 100644 --- a/tests/expr_and_series/min_horizontal_test.py +++ b/tests/expr_and_series/min_horizontal_test.py @@ -1,10 +1,15 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import pytest import narwhals as nw from tests.utils import Constructor, assert_equal_data +if TYPE_CHECKING: + from narwhals.typing import PythonLiteral + data = {"a": [1, 3, None, None], "b": [4, None, 6, None], "z": [3, 1, None, None]} expected_values = [1, 1, 6, None] @@ -23,3 +28,20 @@ def test_minh_all(constructor: Constructor) -> None: result = df.select(nw.min_horizontal(nw.all()), c=nw.min_horizontal(nw.all())) expected = {"a": expected_values, "c": expected_values} assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("exprs", "name"), + [ + ((nw.col("a"), 2), "a"), + ((nw.col("a"), nw.lit(2)), "a"), + ((2, nw.col("a")), "literal"), + ((nw.lit(2), nw.col("a")), "literal"), + ], +) +def test_minh_with_scalars( + constructor: Constructor, exprs: tuple[PythonLiteral | nw.Expr, ...], name: str +) -> None: + df = nw.from_native(constructor({"a": [1, 2, 3]})) + result = df.select(nw.min_horizontal(*exprs)) + assert_equal_data(result, {name: [1, 2, 2]}) diff --git a/tests/expr_and_series/sum_horizontal_test.py b/tests/expr_and_series/sum_horizontal_test.py index 94cc32d4b0..16580e05a3 100644 --- a/tests/expr_and_series/sum_horizontal_test.py +++ b/tests/expr_and_series/sum_horizontal_test.py @@ -1,12 +1,15 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any import pytest import narwhals as nw from tests.utils import DUCKDB_VERSION, Constructor, assert_equal_data +if TYPE_CHECKING: + from narwhals.typing import PythonLiteral + def test_sumh(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]} @@ -60,3 +63,21 @@ def test_sumh_transformations(constructor: Constructor) -> None: result = df.select(d=nw.sum_horizontal("a", nw.lit(None, dtype=nw.Float64), "c")) expected = {"d": [8.0, 10.0, 12.0]} assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("exprs", "name"), + [ + ((nw.col("a"), 1), "a"), + ((nw.col("a"), nw.lit(1)), "a"), + ((1, nw.col("a")), "literal"), + ((nw.lit(1), nw.col("a")), "literal"), + ], +) +def test_sumh_with_scalars( + constructor: Constructor, exprs: tuple[PythonLiteral | nw.Expr, ...], name: str +) -> None: + data = {"a": [1, 2, 3]} + df = nw.from_native(constructor(data)) + result = df.select(nw.sum_horizontal(*exprs)) + assert_equal_data(result, {name: [2, 3, 4]})