From 04f3d9a084f4e142155e09ee2fb0b84c04560eb2 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sat, 18 Apr 2026 17:51:38 +0200 Subject: [PATCH] fix: Allow float('nan') as value in join for duckdb --- narwhals/_duckdb/dataframe.py | 7 ++++++- narwhals/_pandas_like/series.py | 2 +- tests/frame/join_test.py | 33 +++++++++++++++++++++++++++++++++ tests/hypothesis/join_test.py | 4 ++++ 4 files changed, 44 insertions(+), 2 deletions(-) diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 23ef4daa65..14c01d4d4a 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -307,8 +307,13 @@ def join( # help mypy assert left_on is not None # noqa: S101 assert right_on is not None # noqa: S101 + # Use `==` to preserve the polars semantic of `null != null` in joins, + # and additionally match `NaN = NaN` since `==` in DuckDB treats `NaN` + # like SQL `NULL` (see https://github.com/narwhals-dev/narwhals/issues/3554). + # `isnan` is safe to apply as it returns `false` for non-float columns. it = ( - col(f'lhs."{left}"') == col(f'rhs."{right}"') + (col(f'lhs."{left}"') == col(f'rhs."{right}"')) + | (F("isnan", col(f'lhs."{left}"')) & F("isnan", col(f'rhs."{right}"'))) for left, right in zip_strict(left_on, right_on) ) condition: Expression = reduce(and_, it) diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index e0c4e23b02..c955828e86 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -567,7 +567,7 @@ def is_nan(self) -> Self: if not self.dtype.is_numeric(): msg = f"`.is_nan` only supported for numeric dtype and not {self.dtype}, did you mean `.is_null`?" raise InvalidOperationError(msg) - # If/when pandas exposes an API which distinguishes NaN vs null, use that. + # TODO(Unassigned): If/when pandas exposes an API which distinguishes NaN vs null, use that. return self._with_native(ser != ser, preserve_broadcast=True) # noqa: PLR0124 def fill_null( diff --git a/tests/frame/join_test.py b/tests/frame/join_test.py index 42d52adafc..c1c342ed70 100644 --- a/tests/frame/join_test.py +++ b/tests/frame/join_test.py @@ -919,3 +919,36 @@ def test_full_join_with_overlapping_non_key_columns_and_nulls( } assert_equal_data(result, expected) + + +def test_join_with_float_nan( + request: pytest.FixtureRequest, constructor: Constructor +) -> None: + if any(x in str(constructor) for x in ("cudf", "dask", "modin", "pandas")): + request.applymarker(pytest.mark.xfail) + + data = {"a": [0, 0, 0], "b": [0, 0, 0], "c": [0.0, 0.0, float("nan")]} + join_cols = ["a", "c"] + frame = from_native_lazy(constructor(data)) + + result = ( + frame.join(frame, on=join_cols, how="inner").sort("c", nulls_last=True).collect() + ) + + zero_cols = ("a", "b", "b_right") + for col in zero_cols: + assert (result.get_column(col) == 0).all() + + assert (result.get_column("c").is_nan().sum()) == 1 + """ + NOTE: polars result is the following: + expected = { + "a": [0, 0, 0, 0, 0], + "b": [0, 0, 0, 0, 0], + "c": [0., 0., 0., 0., float("nan")], + "b_right": [0, 0, 0, 0, 0], + } + + How can we sort the data to use: + assert_equal_data(result, expected) + """ diff --git a/tests/hypothesis/join_test.py b/tests/hypothesis/join_test.py index 037854a861..3432e4d489 100644 --- a/tests/hypothesis/join_test.py +++ b/tests/hypothesis/join_test.py @@ -1,5 +1,6 @@ from __future__ import annotations +import math from typing import TYPE_CHECKING, Any, cast import pytest @@ -44,6 +45,9 @@ def test_join( # pragma: no cover floats: st.SearchStrategy[list[float]], cols: st.SearchStrategy[list[str]], ) -> None: + # See https://github.com/narwhals-dev/narwhals/issues/3554 + # for why we need to assume that all float values are finite + assume(all(math.isfinite(f) for f in cast("list[float]", floats))) data: Mapping[str, Any] = {"a": integers, "b": other_integers, "c": floats} join_cols = cast("list[str]", cols)