Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion narwhals/_compliant/namespace.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

from collections import deque
from functools import partial
from typing import TYPE_CHECKING, Any, Protocol, overload

from narwhals._compliant.typing import (
CompliantExprT,
CompliantExprT_co,
CompliantFrameT,
CompliantLazyFrameT,
DepthTrackingExprT,
Expand All @@ -23,7 +25,7 @@
from narwhals.dependencies import is_numpy_array_2d

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from collections.abc import Collection, Iterable, Sequence

from typing_extensions import TypeAlias, TypeIs

Expand Down Expand Up @@ -98,6 +100,61 @@ def is_native(self, obj: Any, /) -> TypeIs[Any]:
...


class AlignDiagonal(Protocol[CompliantFrameT, CompliantExprT_co]):
"""Mixin to help support `"diagonal*"` concatenation."""

def lit(
self, value: NonNestedLiteral, dtype: IntoDType | None
) -> CompliantExprT_co: ...
def align_diagonal(
self, frames: Collection[CompliantFrameT], /
) -> Sequence[CompliantFrameT]:
"""Convert the inputs to `concat(..., how="diagonal")` into `concat(..., how="vertical")`.
Comment thread
dangotbanned marked this conversation as resolved.
Outdated
Comment thread
dangotbanned marked this conversation as resolved.
Outdated

Adapted from [`convert_diagonal_concat`].

[`convert_diagonal_concat`]: https://github.com/pola-rs/polars/blob/c2412600210a21143835c9dfcb0a9182f462b619/crates/polars-plan/src/plans/conversion/dsl_to_ir/concat.rs#L10-L68
"""
schemas = [frame.collect_schema() for frame in frames]
# 1 - Take the first schema, collecting any fields from the remaining that introduce new names
# Subtle difference from `dict |= dict |= ...` as we preserve the first `DType`, rather than the last.
it_schemas = iter(schemas)
union = dict(next(it_schemas))
seen = union.keys()
for schema in it_schemas:
union.update((nm, dtype) for nm, dtype in schema.items() if nm not in seen)
return self._align_diagonal(frames, schemas, union)

def _align_diagonal(
self,
frames: Iterable[CompliantFrameT],
schemas: Iterable[IntoSchema],
union_schema: IntoSchema,
) -> Sequence[CompliantFrameT]:
# 2 - Align every frame, by adding null column(s) for each missing field in each schema.
# Even if all fields are present, we always reorder the columns to match between frames.
union_names = tuple(union_schema)
# Likely we'll have repeats between frames, so we can share exprs between inner loops
null_exprs: dict[str, CompliantExprT_co] = {}
missing_from_frame = deque[CompliantExprT_co]()
aligned = deque[CompliantFrameT]()
for frame, schema in zip(frames, schemas):
for name, dtype in union_schema.items():
if name not in schema:
if cached := null_exprs.get(name):
missing_from_frame.append(cached)
else:
null_expr = self.lit(None, dtype).alias(name)
missing_from_frame.append(null_expr)
null_exprs[name] = null_expr
result = frame
if missing_from_frame:
result = result.with_columns(*missing_from_frame)
missing_from_frame.clear()
aligned.append(result.simple_select(*union_names))
return aligned
Comment thread
FBruzzesi marked this conversation as resolved.
Outdated


class DepthTrackingNamespace(
CompliantNamespace[CompliantFrameT, DepthTrackingExprT],
Protocol[CompliantFrameT, DepthTrackingExprT],
Expand Down
28 changes: 20 additions & 8 deletions narwhals/_ibis/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import ibis
import ibis.expr.types as ir

from narwhals._compliant.namespace import AlignDiagonal
from narwhals._expression_parsing import (
combine_alias_output_names,
combine_evaluate_output_names,
Expand All @@ -26,7 +27,10 @@
from narwhals.typing import ConcatMethod, IntoDType, PythonLiteral


class IbisNamespace(SQLNamespace[IbisLazyFrame, IbisExpr, "ir.Table", "ir.Value"]):
class IbisNamespace(
SQLNamespace[IbisLazyFrame, IbisExpr, "ir.Table", "ir.Value"],
AlignDiagonal[IbisLazyFrame, IbisExpr],
):
_implementation: Implementation = Implementation.IBIS

def __init__(self, *, version: Version) -> None:
Expand Down Expand Up @@ -63,14 +67,22 @@ def _coalesce(self, *exprs: ir.Value) -> ir.Value:
def concat(
self, items: Iterable[IbisLazyFrame], *, how: ConcatMethod
) -> IbisLazyFrame:
frames: Sequence[IbisLazyFrame] = list(items)
Comment thread
dangotbanned marked this conversation as resolved.
Outdated
if how == "diagonal":
msg = "diagonal concat not supported for Ibis. Please join instead."
raise NotImplementedError(msg)

items = list(items)
native_items = [item.native for item in items]
schema = items[0].schema
if not all(x.schema == schema for x in items[1:]):
frames = self.align_diagonal(frames)
natives = (lf.native for lf in frames)
try:
result = ibis.union(*natives)
except ibis.IbisError:
first = frames[0].schema
if not all(x.schema == first for x in frames):
msg = "inputs should all have the same schema"
raise TypeError(msg) from None
raise
return frames[0]._with_native(result)
Comment thread
dangotbanned marked this conversation as resolved.
Outdated
native_items = [item.native for item in frames]
schema = frames[0].schema
if not all(x.schema == schema for x in frames[1:]):
msg = "inputs should all have the same schema"
raise TypeError(msg)
return self._lazyframe.from_native(ibis.union(*native_items), context=self)
Expand Down
30 changes: 25 additions & 5 deletions tests/frame/concat_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import datetime as dt
import re

import pytest

import narwhals as nw
from narwhals._utils import Implementation
from narwhals.exceptions import InvalidOperationError
from tests.utils import Constructor, ConstructorEager, assert_equal_data

Expand Down Expand Up @@ -61,11 +63,7 @@ def test_concat_vertical(constructor: Constructor) -> None:
nw.concat([df_left, df_left.select("d")], how="vertical").collect()


def test_concat_diagonal(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "ibis" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_concat_diagonal(constructor: Constructor) -> None:
Copy link
Copy Markdown
Member Author

@dangotbanned dangotbanned Jan 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(#3404 (comment))

I do think there's something here, I just wanna experiment a bit (and add more tests) ❤️

@FBruzzesi okay this led to discovering a bug (but not in any of the new code or your suggestion)

ibis is the only backend that doesn't guarantee the order of union.

I found that out by using the 3-tabled version of the test:

That fails for ibis, but it turns out "vertical" fails too:

Show test

def test_concat_vertical_bigger(constructor: Constructor) -> None:
    data_1 = {"a": [1, 2], "b": [3, 4], "c": [0, None]}
    data_2 = {"a": [5, 6], "b": [0, None], "c": [7, 8]}
    data_3 = {"a": [0, None], "b": [9, 10], "c": [11, 12]}
    expected = {
        "a": [1, 2, 5, 6, 0, None],
        "b": [3, 4, 0, None, 9, 10],
        "c": [0, None, 7, 8, 11, 12],
    }
    df_1 = nw.from_native(constructor(data_1)).lazy()
    df_2 = nw.from_native(constructor(data_2)).lazy()
    df_3 = nw.from_native(constructor(data_3)).lazy()
    result = nw.concat([df_1, df_2, df_3], how="vertical")
    assert_equal_data(result, expected)

Show error

E               AssertionError: Mismatch at index 0, key a: 0 != 1
E               Expected: {'a': [1, 2, 5, 6, 0, None], 'b': [3, 4, 0, None, 9, 10], 'c': [0, None, 7, 8, 11, 12]}
E               Got: {'a': [0, None, 5, 6, 1, 2], 'b': [9, 10, 0, None, 3, 4], 'c': [11, 12, 7, 8, 0, None]}

I suppose I'm stuck with testing two tables for now then 😂
I've added (a8e8388), but should probably follow this up with another issue.
(We (and polars) don't document that it is ordered, but we do test for it and polars.union was recently introduced for unordered)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ibis is the only backend that doesn't guarantee the order of union.

For a moment I thought order of columns, and I panicked sooo much 🤯

No guarantee in row order makes sense. You can add an index column and sort by such

data_1 = {"a": [1, 3], "b": [4, 6]}
data_2 = {"a": [100, 200], "z": ["x", "y"]}
expected = {
Expand All @@ -83,3 +81,25 @@ def test_concat_diagonal(

with pytest.raises(ValueError, match="No items"):
nw.concat([], how="diagonal")


def test_concat_diagonal_invalid(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
data_1 = {"a": [1, 3], "b": [4, 6]}
data_2 = {
"a": [dt.datetime(2000, 1, 1), dt.datetime(2000, 1, 2)],
"b": [4, 6],
"z": ["x", "y"],
}
df_1 = nw.from_native(constructor(data_1)).lazy()
bad_schema = nw.from_native(constructor(data_2)).lazy()
impl = df_1.implementation
request.applymarker(
pytest.mark.xfail(
impl not in {Implementation.IBIS, Implementation.POLARS},
reason=f"{impl!r} does not validate schemas for `concat(how='diagonal')",
)
)
Comment thread
dangotbanned marked this conversation as resolved.
with pytest.raises((InvalidOperationError, TypeError), match=r"same schema"):
nw.concat([df_1, bad_schema], how="diagonal").collect().to_dict(as_series=False)
Loading