diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index 6f7062be35..98b1e16b50 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -5,6 +5,7 @@ from narwhals._compliant.typing import ( CompliantExprT, + CompliantExprT_co, CompliantFrameT, CompliantLazyFrameT, DepthTrackingExprT, @@ -23,7 +24,7 @@ from narwhals.dependencies import is_numpy_array_2d if TYPE_CHECKING: - from collections.abc import Iterable, Sequence + from collections.abc import Collection, Iterable, Iterator, KeysView, Sequence from typing_extensions import TypeAlias, TypeIs @@ -98,6 +99,58 @@ def is_native(self, obj: Any, /) -> TypeIs[Any]: ... +class AlignDiagonal(Protocol[CompliantFrameT, CompliantExprT_co]): + """Mixin to help support `"diagonal*"` concatenation.""" + + def lit( + self, value: NonNestedLiteral, dtype: IntoDType | None + ) -> CompliantExprT_co: ... + def align_diagonal( + self, frames: Collection[CompliantFrameT], / + ) -> Sequence[CompliantFrameT]: + """Prepare frames with differing schemas for vertical concatenation. + + Adapted from [`convert_diagonal_concat`]. + + [`convert_diagonal_concat`]: https://github.com/pola-rs/polars/blob/c2412600210a21143835c9dfcb0a9182f462b619/crates/polars-plan/src/plans/conversion/dsl_to_ir/concat.rs#L10-L68 + """ + schemas = [frame.collect_schema() for frame in frames] + # 1 - Take the first schema, collecting any fields from the remaining that introduce new names + # Subtle difference from `dict |= dict |= ...` as we preserve the first `DType`, rather than the last. + it_schemas = iter(schemas) + union = dict(next(it_schemas)) + seen = union.keys() + for schema in it_schemas: + union.update((nm, dtype) for nm, dtype in schema.items() if nm not in seen) + return self._align_diagonal(frames, schemas, union) + + def _align_diagonal( + self, + frames: Iterable[CompliantFrameT], + schemas: Iterable[IntoSchema], + union_schema: IntoSchema, + ) -> Sequence[CompliantFrameT]: + union_names = union_schema.keys() + # Lazily populate null expressions as needed, shared across frames + null_exprs: dict[str, CompliantExprT_co] = {} + + def iter_missing_exprs(missing: Iterable[str]) -> Iterator[CompliantExprT_co]: + nonlocal null_exprs + for name in missing: + if (expr := null_exprs.get(name)) is None: + dtype = union_schema[name] + expr = null_exprs[name] = self.lit(None, dtype).alias(name) + yield expr + + def align(df: CompliantFrameT, columns: KeysView[str]) -> CompliantFrameT: + if missing := union_names - columns: + df = df.with_columns(*iter_missing_exprs(missing)) + # Even if all fields are present, we always reorder the columns to match between frames. + return df.simple_select(*union_names) + + return [align(frame, schema.keys()) for frame, schema in zip(frames, schemas)] + + class DepthTrackingNamespace( CompliantNamespace[CompliantFrameT, DepthTrackingExprT], Protocol[CompliantFrameT, DepthTrackingExprT], diff --git a/narwhals/_ibis/namespace.py b/narwhals/_ibis/namespace.py index d3116edd4f..ddda316742 100644 --- a/narwhals/_ibis/namespace.py +++ b/narwhals/_ibis/namespace.py @@ -8,6 +8,7 @@ import ibis import ibis.expr.types as ir +from narwhals._compliant.namespace import AlignDiagonal from narwhals._expression_parsing import ( combine_alias_output_names, combine_evaluate_output_names, @@ -26,7 +27,10 @@ from narwhals.typing import ConcatMethod, IntoDType, PythonLiteral -class IbisNamespace(SQLNamespace[IbisLazyFrame, IbisExpr, "ir.Table", "ir.Value"]): +class IbisNamespace( + SQLNamespace[IbisLazyFrame, IbisExpr, "ir.Table", "ir.Value"], + AlignDiagonal[IbisLazyFrame, IbisExpr], +): _implementation: Implementation = Implementation.IBIS def __init__(self, *, version: Version) -> None: @@ -63,17 +67,18 @@ def _coalesce(self, *exprs: ir.Value) -> ir.Value: def concat( self, items: Iterable[IbisLazyFrame], *, how: ConcatMethod ) -> IbisLazyFrame: + frames: Sequence[IbisLazyFrame] = tuple(items) if how == "diagonal": - msg = "diagonal concat not supported for Ibis. Please join instead." - raise NotImplementedError(msg) - - items = list(items) - native_items = [item.native for item in items] - schema = items[0].schema - if not all(x.schema == schema for x in items[1:]): - msg = "inputs should all have the same schema" - raise TypeError(msg) - return self._lazyframe.from_native(ibis.union(*native_items), context=self) + frames = self.align_diagonal(frames) + try: + result = ibis.union(*(lf.native for lf in frames)) + except ibis.IbisError: + first = frames[0].schema + if not all(x.schema == first for x in frames[1:]): + msg = "inputs should all have the same schema" + raise TypeError(msg) from None + raise + return frames[0]._with_native(result) def concat_str( self, *exprs: IbisExpr, separator: str, ignore_nulls: bool diff --git a/tests/frame/concat_test.py b/tests/frame/concat_test.py index 35ebd54d39..8d8b4b2ce6 100644 --- a/tests/frame/concat_test.py +++ b/tests/frame/concat_test.py @@ -1,12 +1,18 @@ from __future__ import annotations +import datetime as dt import re +from typing import TYPE_CHECKING, Any import pytest import narwhals as nw -from narwhals.exceptions import InvalidOperationError -from tests.utils import Constructor, ConstructorEager, assert_equal_data +from narwhals._utils import Implementation +from narwhals.exceptions import InvalidOperationError, NarwhalsError +from tests.utils import POLARS_VERSION, Constructor, ConstructorEager, assert_equal_data + +if TYPE_CHECKING: + from collections.abc import Iterator def test_concat_horizontal(constructor_eager: ConstructorEager) -> None: @@ -61,11 +67,7 @@ def test_concat_vertical(constructor: Constructor) -> None: nw.concat([df_left, df_left.select("d")], how="vertical").collect() -def test_concat_diagonal( - constructor: Constructor, request: pytest.FixtureRequest -) -> None: - if "ibis" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_concat_diagonal(constructor: Constructor) -> None: data_1 = {"a": [1, 3], "b": [4, 6]} data_2 = {"a": [100, 200], "z": ["x", "y"]} expected = { @@ -83,3 +85,56 @@ def test_concat_diagonal( with pytest.raises(ValueError, match="No items"): nw.concat([], how="diagonal") + + +def _from_natives( + constructor: Constructor, *sources: dict[str, list[Any]] +) -> Iterator[nw.LazyFrame[Any]]: + yield from (nw.from_native(constructor(data)).lazy() for data in sources) + + +def test_concat_diagonal_bigger(constructor: Constructor) -> None: + # NOTE: `ibis.union` doesn't guarantee the order of outputs + # https://github.com/narwhals-dev/narwhals/pull/3404#discussion_r2694556781 + data_1 = {"idx": [1, 2], "a": [1, 2], "b": [3, 4]} + data_2 = {"a": [5, 6], "c": [7, 8], "idx": [3, 4]} + data_3 = {"b": [9, 10], "idx": [5, 6], "c": [11, 12]} + expected = { + "idx": [1, 2, 3, 4, 5, 6], + "a": [1, 2, 5, 6, None, None], + "b": [3, 4, None, None, 9, 10], + "c": [None, None, 7, 8, 11, 12], + } + dfs = _from_natives(constructor, data_1, data_2, data_3) + result = nw.concat(dfs, how="diagonal").sort("idx") + assert_equal_data(result, expected) + + +def test_concat_diagonal_invalid( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + data_1 = {"a": [1, 3], "b": [4, 6]} + data_2 = { + "a": [dt.datetime(2000, 1, 1), dt.datetime(2000, 1, 2)], + "b": [4, 6], + "z": ["x", "y"], + } + df_1 = nw.from_native(constructor(data_1)).lazy() + bad_schema = nw.from_native(constructor(data_2)).lazy() + impl = df_1.implementation + request.applymarker( + pytest.mark.xfail( + impl not in {Implementation.IBIS, Implementation.POLARS}, + reason=f"{impl!r} does not validate schemas for `concat(how='diagonal')", + ) + ) + context: Any + if impl.is_polars() and POLARS_VERSION < (1,): # pragma: no cover + context = pytest.raises( + NarwhalsError, + match=re.compile(r"(int.+datetime)|(datetime.+int)", re.IGNORECASE), + ) + else: + context = pytest.raises((InvalidOperationError, TypeError), match=r"same schema") + with context: + nw.concat([df_1, bad_schema], how="diagonal").collect().to_dict(as_series=False)