Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,20 @@ def func(expr: Expression) -> Expression:
return self._with_callable(func)

def len(self) -> Self:
return self._with_callable(lambda _expr: F("count"))
def func(df: DuckDBLazyFrame) -> list[Expression]:
if not self._metadata.preserves_length:
msg = "`len` is not supported after a length-changing expression"
raise NotImplementedError(msg)

return [F("count") for _ in self(df)]

return self.__class__(
func,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
implementation=self._implementation,
)

def null_count(self) -> Self:
return self._with_callable(lambda expr: F("sum", expr.isnull().cast("int")))
Expand Down
8 changes: 6 additions & 2 deletions narwhals/_ibis/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,12 @@ def n_unique(self) -> Self:
)

def len(self) -> Self:
def func(df: IbisLazyFrame) -> Sequence[ir.IntegerScalar]:
return [df.native.count() for _ in self._evaluate_output_names(df)]
def func(df: IbisLazyFrame) -> list[ir.Value]:
if not self._metadata.preserves_length:
msg = "`len` is not supported after a length-changing expression"
raise NotImplementedError(msg)

return [df.native.count() for _ in self(df)]

return self.__class__(
func,
Expand Down
17 changes: 13 additions & 4 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,11 +310,20 @@ def _is_in(expr: Column) -> Column:
return self._with_elementwise(_is_in)

def len(self) -> Self:
def _len(_expr: Column) -> Column:
# Use count(*) to count all rows including nulls
return self._F.count("*")
def func(df: SparkLikeLazyFrame) -> list[Column]:
if not self._metadata.preserves_length:
msg = "`len` is not supported after a length-changing expression"
raise NotImplementedError(msg)

return [self._F.count("*") for _ in self(df)]

return self._with_callable(_len)
return self.__class__(
func,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
implementation=self._implementation,
)

def skew(self) -> Self:
return self._with_callable(self._F.skewness)
Expand Down
17 changes: 16 additions & 1 deletion narwhals/_sql/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,22 @@ def mode(self, *, keep: ModeKeepStrategy) -> Self:

return self._with_callable(lambda expr: self._function("mode", expr))

def filter(self, *predicates: Self) -> Self:
plx = self.__narwhals_namespace__()
predicate = plx.all_horizontal(*predicates, ignore_nulls=False)

def func(df: SQLLazyFrameT) -> list[NativeExprT]:
mask = df._evaluate_single_output_expr(predicate)
return [self._when(mask, value=expr) for expr in self(df)]

return self.__class__(
func,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
implementation=self._implementation,
)

# Namespaces
@property
def str(self) -> SQLExprStringNamespace[Self]: ...
Expand All @@ -919,5 +935,4 @@ def str(self) -> SQLExprStringNamespace[Self]: ...
def dt(self) -> SQLExprDateTimeNamesSpace[Self]: ...

drop_nulls = not_implemented() # type: ignore[misc]
filter = not_implemented() # type: ignore[misc]
unique = not_implemented() # type: ignore[misc]
32 changes: 32 additions & 0 deletions tests/expr_and_series/filter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,35 @@ def test_filter_windows_over(
df = nw.from_native(constructor(data))
result = df.filter(nw.col("i") == nw.col("i").min().over("b")).sort("i")
assert_equal_data(result, expected_over)


def test_expr_filter(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if "dask" in str(constructor):
reason = "not implemented"
request.applymarker(pytest.mark.xfail(reason=reason))

mask_multi = (nw.col("i") > 0, nw.col("a") < 4)
df = nw.from_native(constructor(data)).with_columns(
d=nw.when(nw.col("i") > 2).then(nw.col("i")).otherwise(nw.lit(None))
# column d is [None, None, None, 3, 4]
)
result = df.select(
nw.col("a").filter(nw.col("i") < 3).mean(),
nw.col("b").filter(*mask_multi).sum(),
nw.col("b", "c").filter(*mask_multi).min().name.suffix("_min"),
count_c=nw.col("c").filter(*mask_multi).count(),
count_d=nw.col("d").filter(*mask_multi).count(),
# len=nw.col("d").filter(*mask_multi).len() # noqa: ERA001
# !NOTE: Result should be {"len": [3]}, but we can get:
# 5: Without changing the current implementation for SQL backends
# 1: By accounting for the expression, (e.g. for spark F("count", expr)) but ignoring nulls
)
Comment on lines +79 to +94
Copy link
Copy Markdown
Member

@dangotbanned dangotbanned Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@FBruzzesi are we using Scalar Subqueries much in the sql backends?

This would be like:

import duckdb
import polars as pl

df = pl.DataFrame(data).with_columns(d=pl.when(pl.col("i") > 2).then("i"))


def _(**named_agg: str) -> str:
    alias, expr = next(iter(named_agg.items()))
    return f"SELECT {expr} AS {alias}"


from_df = "FROM df"
where = "WHERE i > 0 AND a < 4"

query = (
    f"FROM ({_(a='mean(a)')} {from_df} WHERE i < 3),"
    f"({_(b='sum(b)')} {from_df} {where}),"
    f"({_(b_min='min(b)')} {from_df} {where}),"
    f"({_(count_c='count(c)')} {from_df} {where}),"
    f"({_(count_d='count(d)')} {from_df} {where}),"
    f"({_(len='count()')} {from_df} {where})"
)
duckdb.sql(query)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”
β”‚   a    β”‚   b    β”‚ b_min β”‚ count_c β”‚ count_d β”‚  len  β”‚
β”‚ double β”‚ int128 β”‚ int64 β”‚  int64  β”‚  int64  β”‚ int64 β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€
β”‚    1.0 β”‚     10 β”‚     2 β”‚       3 β”‚       1 β”‚     3 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”˜

Note

If we're already using them - at least the len issue can be resolved πŸ˜„

expected = {
"a": [1.0],
"b": [10],
"b_min": [2],
"c_min": [2],
"count_c": [3],
"count_d": [1],
}
assert_equal_data(result, expected)
13 changes: 10 additions & 3 deletions tests/expr_and_series/len_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import pytest

import narwhals as nw
from tests.utils import Constructor, ConstructorEager, assert_equal_data

Expand All @@ -16,10 +18,15 @@ def test_len_no_filter(constructor: Constructor) -> None:
assert_equal_data(df, expected)


def test_len_chaining(constructor_eager: ConstructorEager) -> None:
data = {"a": list("xyz"), "b": [1, 2, 1]}
def test_len_chaining(constructor: Constructor, request: pytest.FixtureRequest) -> None:
lazy_non_polars = ("dask", "duckdb", "ibis", "spark")
if any(x in str(constructor) for x in lazy_non_polars):
reason = "not implemented"
request.applymarker(pytest.mark.xfail(reason=reason))

data = {"a": [0, 1, None], "b": [1, 2, 1]}
expected = {"a1": [2], "a2": [1]}
df = nw.from_native(constructor_eager(data)).select(
df = nw.from_native(constructor(data)).select(
nw.col("a").filter(nw.col("b") == 1).len().alias("a1"),
nw.col("a").filter(nw.col("b") == 2).len().alias("a2"),
)
Expand Down
Loading