diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 7a0977a54c..05314360eb 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -214,7 +214,20 @@ def func(expr: Expression) -> Expression: return self._with_callable(func) def len(self) -> Self: - return self._with_callable(lambda _expr: F("count")) + def func(df: DuckDBLazyFrame) -> list[Expression]: + if not self._metadata.preserves_length: + msg = "`len` is not supported after a length-changing expression" + raise NotImplementedError(msg) + + return [F("count") for _ in self(df)] + + return self.__class__( + func, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, + version=self._version, + implementation=self._implementation, + ) def null_count(self) -> Self: return self._with_callable(lambda expr: F("sum", expr.isnull().cast("int"))) diff --git a/narwhals/_ibis/expr.py b/narwhals/_ibis/expr.py index 9de16232bb..69ddee1b64 100644 --- a/narwhals/_ibis/expr.py +++ b/narwhals/_ibis/expr.py @@ -220,8 +220,12 @@ def n_unique(self) -> Self: ) def len(self) -> Self: - def func(df: IbisLazyFrame) -> Sequence[ir.IntegerScalar]: - return [df.native.count() for _ in self._evaluate_output_names(df)] + def func(df: IbisLazyFrame) -> list[ir.Value]: + if not self._metadata.preserves_length: + msg = "`len` is not supported after a length-changing expression" + raise NotImplementedError(msg) + + return [df.native.count() for _ in self(df)] return self.__class__( func, diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 404f31a818..c6f37f3391 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -310,11 +310,20 @@ def _is_in(expr: Column) -> Column: return self._with_elementwise(_is_in) def len(self) -> Self: - def _len(_expr: Column) -> Column: - # Use count(*) to count all rows including nulls - return self._F.count("*") + def func(df: SparkLikeLazyFrame) -> list[Column]: + if not self._metadata.preserves_length: + msg = "`len` is not supported after a length-changing expression" + raise NotImplementedError(msg) + + return [self._F.count("*") for _ in self(df)] - return self._with_callable(_len) + return self.__class__( + func, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, + version=self._version, + implementation=self._implementation, + ) def skew(self) -> Self: return self._with_callable(self._F.skewness) diff --git a/narwhals/_sql/expr.py b/narwhals/_sql/expr.py index 63fdec5562..03bd4eacb8 100644 --- a/narwhals/_sql/expr.py +++ b/narwhals/_sql/expr.py @@ -911,6 +911,22 @@ def mode(self, *, keep: ModeKeepStrategy) -> Self: return self._with_callable(lambda expr: self._function("mode", expr)) + def filter(self, *predicates: Self) -> Self: + plx = self.__narwhals_namespace__() + predicate = plx.all_horizontal(*predicates, ignore_nulls=False) + + def func(df: SQLLazyFrameT) -> list[NativeExprT]: + mask = df._evaluate_single_output_expr(predicate) + return [self._when(mask, value=expr) for expr in self(df)] + + return self.__class__( + func, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, + version=self._version, + implementation=self._implementation, + ) + # Namespaces @property def str(self) -> SQLExprStringNamespace[Self]: ... @@ -919,5 +935,4 @@ def str(self) -> SQLExprStringNamespace[Self]: ... def dt(self) -> SQLExprDateTimeNamesSpace[Self]: ... drop_nulls = not_implemented() # type: ignore[misc] - filter = not_implemented() # type: ignore[misc] unique = not_implemented() # type: ignore[misc] diff --git a/tests/expr_and_series/filter_test.py b/tests/expr_and_series/filter_test.py index 5f300ed4b8..e7a61eb64c 100644 --- a/tests/expr_and_series/filter_test.py +++ b/tests/expr_and_series/filter_test.py @@ -69,3 +69,35 @@ def test_filter_windows_over( df = nw.from_native(constructor(data)) result = df.filter(nw.col("i") == nw.col("i").min().over("b")).sort("i") assert_equal_data(result, expected_over) + + +def test_expr_filter(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "dask" in str(constructor): + reason = "not implemented" + request.applymarker(pytest.mark.xfail(reason=reason)) + + mask_multi = (nw.col("i") > 0, nw.col("a") < 4) + df = nw.from_native(constructor(data)).with_columns( + d=nw.when(nw.col("i") > 2).then(nw.col("i")).otherwise(nw.lit(None)) + # column d is [None, None, None, 3, 4] + ) + result = df.select( + nw.col("a").filter(nw.col("i") < 3).mean(), + nw.col("b").filter(*mask_multi).sum(), + nw.col("b", "c").filter(*mask_multi).min().name.suffix("_min"), + count_c=nw.col("c").filter(*mask_multi).count(), + count_d=nw.col("d").filter(*mask_multi).count(), + # len=nw.col("d").filter(*mask_multi).len() # noqa: ERA001 + # !NOTE: Result should be {"len": [3]}, but we can get: + # 5: Without changing the current implementation for SQL backends + # 1: By accounting for the expression, (e.g. for spark F("count", expr)) but ignoring nulls + ) + expected = { + "a": [1.0], + "b": [10], + "b_min": [2], + "c_min": [2], + "count_c": [3], + "count_d": [1], + } + assert_equal_data(result, expected) diff --git a/tests/expr_and_series/len_test.py b/tests/expr_and_series/len_test.py index 755c389a28..fb00e4fbeb 100644 --- a/tests/expr_and_series/len_test.py +++ b/tests/expr_and_series/len_test.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pytest + import narwhals as nw from tests.utils import Constructor, ConstructorEager, assert_equal_data @@ -16,10 +18,15 @@ def test_len_no_filter(constructor: Constructor) -> None: assert_equal_data(df, expected) -def test_len_chaining(constructor_eager: ConstructorEager) -> None: - data = {"a": list("xyz"), "b": [1, 2, 1]} +def test_len_chaining(constructor: Constructor, request: pytest.FixtureRequest) -> None: + lazy_non_polars = ("dask", "duckdb", "ibis", "spark") + if any(x in str(constructor) for x in lazy_non_polars): + reason = "not implemented" + request.applymarker(pytest.mark.xfail(reason=reason)) + + data = {"a": [0, 1, None], "b": [1, 2, 1]} expected = {"a1": [2], "a2": [1]} - df = nw.from_native(constructor_eager(data)).select( + df = nw.from_native(constructor(data)).select( nw.col("a").filter(nw.col("b") == 1).len().alias("a1"), nw.col("a").filter(nw.col("b") == 2).len().alias("a2"), )