From 3b49effb1569ef7ffe5e67c88b18347babc85e79 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 11 Jan 2026 16:13:50 +0100 Subject: [PATCH 01/20] WIP: Add relaxed versions for all but Dask --- narwhals/_arrow/namespace.py | 10 +++++ narwhals/_compliant/dataframe.py | 4 ++ narwhals/_compliant/namespace.py | 28 +++++++++++++- narwhals/_duckdb/namespace.py | 60 +++++++++++++++++++++++++----- narwhals/_ibis/namespace.py | 36 +++++++++++++----- narwhals/_pandas_like/namespace.py | 16 ++++++++ narwhals/_polars/namespace.py | 16 +++++--- narwhals/_spark_like/namespace.py | 50 ++++++++++++++++++++++--- narwhals/exceptions.py | 4 ++ narwhals/functions.py | 4 ++ narwhals/schema.py | 42 +++++++++++++++++++++ narwhals/typing.py | 8 +++- 12 files changed, 246 insertions(+), 32 deletions(-) diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 98282a575e..6cf16857d7 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -194,6 +194,16 @@ def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table: raise TypeError(msg) return pa.concat_tables(dfs) + def _concat_vertical_relaxed(self, dfs: Sequence[pa.Table], /) -> pa.Table: + from narwhals.schema import Schema, to_supertype + + out_schema = reduce( + lambda x, y: to_supertype(x, y), + (Schema.from_arrow(table.schema) for table in dfs), + ).to_arrow() + + return pa.concat_tables([table.cast(out_schema) for table in dfs]) + @property def selectors(self) -> ArrowSelectorNamespace: return ArrowSelectorNamespace.from_namespace(self) diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 3e7810616c..f91e6eb1c1 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -323,6 +323,10 @@ class EagerDataFrame( def _backend_version(self) -> tuple[int, ...]: return self._implementation._backend_version() + @property + def native(self) -> NativeDataFrameT: + return self._native_frame + def __narwhals_namespace__( self, ) -> EagerNamespace[ diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index 6f7062be35..0650defe88 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -1,6 +1,6 @@ from __future__ import annotations -from functools import partial +from functools import partial, reduce from typing import TYPE_CHECKING, Any, Protocol, overload from narwhals._compliant.typing import ( @@ -224,16 +224,42 @@ def _concat_horizontal( self, dfs: Sequence[NativeFrameT | Any], / ) -> NativeFrameT: ... def _concat_vertical(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ... + def _concat_vertical_relaxed( + self, dfs: Sequence[NativeFrameT], / + ) -> NativeFrameT: ... def concat( self, items: Iterable[EagerDataFrameT], *, how: ConcatMethod ) -> EagerDataFrameT: + items = tuple(items) dfs = [item.native for item in items] if how == "horizontal": native = self._concat_horizontal(dfs) elif how == "vertical": native = self._concat_vertical(dfs) + elif how == "vertical_relaxed": + native = self._concat_vertical_relaxed(dfs) elif how == "diagonal": native = self._concat_diagonal(dfs) + elif how == "diagonal_relaxed": + from narwhals.schema import Schema, combine_schemas, to_supertype + + schemas = tuple(Schema(item.collect_schema()) for item in items) + out_schema = reduce( + lambda x, y: to_supertype(*combine_schemas(x, y)), schemas + ) + aligned_items = tuple( + item.select( + *( + self.col(name).cast(dtype) + if name in schema + else self.lit(None, dtype=dtype) + for name, dtype in out_schema.items() + ) + ).native + for item, schema in zip(items, schemas) + ) + native = self._concat_vertical(aligned_items) + else: # pragma: no cover raise NotImplementedError return self._dataframe.from_native(native, context=self) diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index ebc5041e68..58f4f162a6 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -27,6 +27,7 @@ ) from narwhals._sql.namespace import SQLNamespace from narwhals._utils import Implementation +from narwhals.schema import Schema, combine_schemas, to_supertype if TYPE_CHECKING: from collections.abc import Iterable @@ -82,23 +83,62 @@ def _coalesce(self, *exprs: Expression) -> Expression: def concat( self, items: Iterable[DuckDBLazyFrame], *, how: ConcatMethod ) -> DuckDBLazyFrame: - native_items = [item._native_frame for item in items] - items = list(items) + items = tuple(items) first = items[0] - schema = first.schema - if how == "vertical" and not all(x.schema == schema for x in items[1:]): - msg = "inputs should all have the same schema" - raise TypeError(msg) + + if how == "vertical": + schema = first.schema + if not all(x.schema == schema for x in items[1:]): + msg = "inputs should all have the same schema" + raise TypeError(msg) + + res = reduce(lambda x, y: x.union(y), (item._native_frame for item in items)) + return first._with_native(res) + + if how == "vertical_relaxed": + schemas = (Schema(df.collect_schema()) for df in items) + out_schema = reduce(lambda x, y: to_supertype(x, y), schemas) + native_items = ( + item.select( + *(self.col(name).cast(dtype) for name, dtype in out_schema.items()) + )._native_frame + for item in items + ) + res = reduce(lambda x, y: x.union(y), native_items) + return first._with_native(res) + if how == "diagonal": - res = first.native - for _item in native_items[1:]: + res, *others = (item._native_frame for item in items) + for _item in others: + # TODO(unassigned): use relational API when available https://github.com/duckdb/duckdb/discussions/16996 + res = duckdb.sql(""" + from res select * union all by name from _item select * + """) + return first._with_native(res) + + if how == "diagonal_relaxed": + schemas = [Schema(df.collect_schema()) for df in items] + out_schema = reduce( + lambda x, y: to_supertype(*combine_schemas(x, y)), schemas + ) + res, *others = ( + item.select( + *( + self.col(name).cast(dtype) + if name in schema + else self.lit(None, dtype=dtype) + for name, dtype in out_schema.items() + ) + )._native_frame + for item, schema in zip(items, schemas) + ) + for _item in others: # TODO(unassigned): use relational API when available https://github.com/duckdb/duckdb/discussions/16996 res = duckdb.sql(""" from res select * union all by name from _item select * """) return first._with_native(res) - res = reduce(lambda x, y: x.union(y), native_items) - return first._with_native(res) + raise NotImplementedError def concat_str( self, *exprs: DuckDBExpr, separator: str, ignore_nulls: bool diff --git a/narwhals/_ibis/namespace.py b/narwhals/_ibis/namespace.py index d3116edd4f..87f2555b8f 100644 --- a/narwhals/_ibis/namespace.py +++ b/narwhals/_ibis/namespace.py @@ -18,6 +18,7 @@ from narwhals._ibis.utils import function, lit, narwhals_to_native_dtype from narwhals._sql.namespace import SQLNamespace from narwhals._utils import Implementation +from narwhals.schema import Schema, to_supertype if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -63,17 +64,34 @@ def _coalesce(self, *exprs: ir.Value) -> ir.Value: def concat( self, items: Iterable[IbisLazyFrame], *, how: ConcatMethod ) -> IbisLazyFrame: - if how == "diagonal": - msg = "diagonal concat not supported for Ibis. Please join instead." + if how in {"diagonal", "diagonal_relaxed"}: + msg = f"{how} concat not supported for Ibis. Please join instead." raise NotImplementedError(msg) - items = list(items) - native_items = [item.native for item in items] - schema = items[0].schema - if not all(x.schema == schema for x in items[1:]): - msg = "inputs should all have the same schema" - raise TypeError(msg) - return self._lazyframe.from_native(ibis.union(*native_items), context=self) + items = tuple(items) + + if how == "vertical": + schema = items[0].schema + if not all(x.schema == schema for x in items[1:]): + msg = "inputs should all have the same schema" + raise TypeError(msg) + + native_items = (item.native for item in items) + return self._lazyframe.from_native(ibis.union(*native_items), context=self) + + if how == "vertical_relaxed": + schemas = (Schema(df.collect_schema()) for df in items) + out_schema = reduce(lambda x, y: to_supertype(x, y), schemas) + + native_items = ( + item.select( + *(self.col(name).cast(dtype) for name, dtype in out_schema.items()) + )._native_frame + for item in items + ) + return self._lazyframe.from_native(ibis.union(*native_items), context=self) + + raise NotImplementedError def concat_str( self, *exprs: IbisExpr, separator: str, ignore_nulls: bool diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 96b7a22290..2124eb38e7 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -289,6 +289,22 @@ def _concat_vertical(self, dfs: Sequence[NativeDataFrameT], /) -> NativeDataFram return self._concat(dfs, axis=VERTICAL, copy=False) return self._concat(dfs, axis=VERTICAL) + def _concat_vertical_relaxed( + self, dfs: Sequence[NativeDataFrameT], / + ) -> NativeDataFrameT: + from narwhals.schema import Schema, to_supertype + + out_schema = reduce( + lambda x, y: to_supertype(x, y), + (Schema.from_pandas_like(frame.dtypes.to_dict()) for frame in dfs), + ).to_pandas( + # dtype_backend= # TODO(FBruzzesi): what should this be? + ) + to_concat = [frame.astype(out_schema) for frame in dfs] + if self._implementation.is_pandas() and self._backend_version < (3,): + return self._concat(to_concat, axis=VERTICAL, copy=False) + return self._concat(to_concat, axis=VERTICAL) + def concat_str( self, *exprs: PandasLikeExpr, separator: str, ignore_nulls: bool ) -> PandasLikeExpr: diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index ac8da364be..a8ddae6fe0 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -1,7 +1,7 @@ from __future__ import annotations import operator -from typing import TYPE_CHECKING, Any, Literal, cast, overload +from typing import TYPE_CHECKING, Any, cast, overload import polars as pl @@ -22,7 +22,14 @@ from narwhals._polars.dataframe import Method, PolarsDataFrame, PolarsLazyFrame from narwhals._polars.typing import FrameT from narwhals._utils import Version, _LimitedContext - from narwhals.typing import Into1DArray, IntoDType, IntoSchema, TimeUnit, _2DArray + from narwhals.typing import ( + ConcatMethod, + Into1DArray, + IntoDType, + IntoSchema, + TimeUnit, + _2DArray, + ) class PolarsNamespace: @@ -130,10 +137,7 @@ def any_horizontal(self, *exprs: PolarsExpr, ignore_nulls: bool) -> PolarsExpr: return self._expr(pl.any_horizontal(*(expr.native for expr in it)), self._version) def concat( - self, - items: Iterable[FrameT], - *, - how: Literal["vertical", "horizontal", "diagonal"], + self, items: Iterable[FrameT], *, how: ConcatMethod ) -> PolarsDataFrame | PolarsLazyFrame: result = pl.concat((item.native for item in items), how=how) if isinstance(result, pl.DataFrame): diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index c660b67298..6d9d5aeab3 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -19,6 +19,7 @@ true_divide, ) from narwhals._sql.namespace import SQLNamespace +from narwhals.schema import Schema, combine_schemas, to_supertype if TYPE_CHECKING: from collections.abc import Iterable @@ -146,10 +147,11 @@ def func(cols: Iterable[Column]) -> Column: def concat( self, items: Iterable[SparkLikeLazyFrame], *, how: ConcatMethod ) -> SparkLikeLazyFrame: - dfs = [item._native_frame for item in items] + items = tuple(items) if how == "vertical": - cols_0 = dfs[0].columns - for i, df in enumerate(dfs[1:], start=1): + native_items = [item._native_frame for item in items] + cols_0 = native_items[0].columns + for i, df in enumerate(native_items[1:], start=1): cols_current = df.columns if not ((len(cols_current) == len(cols_0)) and (cols_current == cols_0)): msg = ( @@ -160,7 +162,7 @@ def concat( raise TypeError(msg) return SparkLikeLazyFrame( - native_dataframe=reduce(lambda x, y: x.union(y), dfs), + native_dataframe=reduce(lambda x, y: x.union(y), native_items), version=self._version, implementation=self._implementation, ) @@ -168,11 +170,49 @@ def concat( if how == "diagonal": return SparkLikeLazyFrame( native_dataframe=reduce( - lambda x, y: x.unionByName(y, allowMissingColumns=True), dfs + lambda x, y: x.unionByName(y, allowMissingColumns=True), + (item._native_frame for item in items), ), version=self._version, implementation=self._implementation, ) + + if how == "vertical_relaxed": + schemas = (Schema(df.collect_schema()) for df in items) + out_schema = reduce(lambda x, y: to_supertype(x, y), schemas) + native_items = ( + item.select( + *(self.col(name).cast(dtype) for name, dtype in out_schema.items()) + )._native_frame + for item in items + ) + return SparkLikeLazyFrame( + native_dataframe=reduce(lambda x, y: x.union(y), native_items), + version=self._version, + implementation=self._implementation, + ) + + if how == "diagonal_relaxed": + schemas = [Schema(df.collect_schema()) for df in items] + out_schema = reduce( + lambda x, y: to_supertype(*combine_schemas(x, y)), schemas + ) + native_items = ( + item.select( + *( + self.col(name).cast(dtype) + if name in schema + else self.lit(None, dtype=dtype) + for name, dtype in out_schema.items() + ) + )._native_frame + for item, schema in zip(items, schemas) + ) + return SparkLikeLazyFrame( + native_dataframe=reduce(lambda x, y: x.union(y), native_items), + version=self._version, + implementation=self._implementation, + ) raise NotImplementedError def concat_str( diff --git a/narwhals/exceptions.py b/narwhals/exceptions.py index b6a445cb7e..27f88e5747 100644 --- a/narwhals/exceptions.py +++ b/narwhals/exceptions.py @@ -101,6 +101,10 @@ class UnsupportedDTypeError(NarwhalsError): """Exception raised when trying to convert to a DType which is not supported by the given backend.""" +class SchemaMismatchError(NarwhalsError): + """Exception raised when an unexpected schema mismatch causes an error.""" + + class NarwhalsUnstableWarning(UserWarning): """Warning issued when a method or function is considered unstable in the stable api.""" diff --git a/narwhals/functions.py b/narwhals/functions.py index 322da7019e..23e887ccfa 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -60,11 +60,15 @@ def concat(items: Iterable[FrameT], *, how: ConcatMethod = "vertical") -> FrameT how: concatenating strategy - vertical: Concatenate vertically. Column names must match. + - vertical_relaxed: Same as vertical, but additionally coerces columns to + their common supertype if they are mismatched (eg: Int32 → Int64). - horizontal: Concatenate horizontally. If lengths don't match, then missing rows are filled with null values. This is only supported when all inputs are (eager) DataFrames. - diagonal: Finds a union between the column schemas and fills missing column values with null. + - diagonal_relaxed: Same as diagonal, but additionally coerces columns to + their common supertype if they are mismatched (eg: Int32 → Int64). Raises: TypeError: The items to concatenate should either all be eager, or all lazy diff --git a/narwhals/schema.py b/narwhals/schema.py index b459be68fa..d6d602c0a8 100644 --- a/narwhals/schema.py +++ b/narwhals/schema.py @@ -21,6 +21,8 @@ is_pyarrow_data_type, is_pyarrow_schema, ) +from narwhals.dtypes._supertyping import get_supertype +from narwhals.exceptions import ComputeError, SchemaMismatchError if TYPE_CHECKING: from collections.abc import Iterable @@ -363,3 +365,43 @@ def _from_pandas_like( (name, native_to_narwhals_dtype(dtype, cls._version, impl, allow_object=True)) for name, dtype in schema.items() ) + + +def to_supertype(left: Schema, right: Schema) -> Schema: + # Adapted from polars https://github.com/pola-rs/polars/blob/c2412600210a21143835c9dfcb0a9182f462b619/crates/polars-core/src/schema/mod.rs#L83-L96 + if len(left) != len(right): + msg = "schema lengths differ" + raise ComputeError(msg) + + into_out_schema: Mapping[str, DType] = {} + for (lname, ltype), (rname, rtype) in zip(left.items(), right.items()): + if lname != rname: + msg = f"schema names differ: got {rname}, expected {lname}" + raise ComputeError(msg) + if promoted_dtype := get_supertype(left=ltype, right=rtype): + into_out_schema[lname] = promoted_dtype + else: + msg = f"failed to determine supertype of {ltype} and {rtype}" + raise SchemaMismatchError(msg) + + return Schema(into_out_schema) + + +def combine_schemas(left: Schema, right: Schema) -> tuple[Schema, Schema]: + """Extend both schemas with names and dtypes missing from the other. + + Returns a tuple of two schemas where each original schema is extended + with the columns that exist in the other schema but not in itself. + + The final order for both schemas is: left schema keys first (in order), + followed by keys missing from left (in the order they appear in right). + """ + left_names = set(left.keys()) + missing_in_left = [(name, right[name]) for name in right if name not in left_names] + + extended_left = Schema([*left.items(), *missing_in_left]) + # Reorder right to match: left keys first, then right-only keys + extended_right = Schema( + [(name, right.get(name, left[name])) for name in extended_left] + ) + return extended_left, extended_right diff --git a/narwhals/typing.py b/narwhals/typing.py index cf2852fc83..6c617da85d 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -192,14 +192,20 @@ def Binary(self) -> type[dtypes.Binary]: ... ClosedInterval: TypeAlias = Literal["left", "right", "none", "both"] """Define which sides of the interval are closed (inclusive).""" -ConcatMethod: TypeAlias = Literal["horizontal", "vertical", "diagonal"] +ConcatMethod: TypeAlias = Literal[ + "horizontal", "vertical", "vertical_relaxed", "diagonal", "diagonal_relaxed" +] """Concatenating strategy. - *"vertical"*: Concatenate vertically. Column names must match. +- *"vertical_relaxed"*: Same as vertical, but additionally coerces columns to their common + supertype if they are mismatched (eg: Int32 → Int64). - *"horizontal"*: Concatenate horizontally. If lengths don't match, then missing rows are filled with null values. - *"diagonal"*: Finds a union between the column schemas and fills missing column values with null. +- *"diagonal_relaxed"*: Same as diagonal, but additionally coerces columns to their common + supertype if they are mismatched (eg: Int32 → Int64). """ FillNullStrategy: TypeAlias = Literal["forward", "backward"] From 336f6528935b33990772f96fbc76ff00e0b27f3a Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 11 Jan 2026 16:40:18 +0100 Subject: [PATCH 02/20] WIP: Add unit tests --- narwhals/functions.py | 10 ++++- tests/frame/concat_test.py | 78 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 2 deletions(-) diff --git a/narwhals/functions.py b/narwhals/functions.py index 23e887ccfa..b4b887bc82 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -142,8 +142,14 @@ def concat(items: Iterable[FrameT], *, how: ConcatMethod = "vertical") -> FrameT raise ValueError(msg) items = tuple(items) validate_laziness(items) - if how not in {"horizontal", "vertical", "diagonal"}: # pragma: no cover - msg = "Only vertical, horizontal and diagonal concatenations are supported." + if how not in { + "horizontal", + "vertical", + "diagonal", + "vertical_relaxed", + "diagonal_relaxed", + }: # pragma: no cover + msg = "Only vertical, vertical_relaxed, horizontal, diagonal and diagonal_relaxed concatenations are supported." raise NotImplementedError(msg) first_item = items[0] if is_narwhals_lazyframe(first_item) and how == "horizontal": diff --git a/tests/frame/concat_test.py b/tests/frame/concat_test.py index 35ebd54d39..539d58b16f 100644 --- a/tests/frame/concat_test.py +++ b/tests/frame/concat_test.py @@ -1,13 +1,22 @@ from __future__ import annotations import re +from typing import TYPE_CHECKING, Any import pytest import narwhals as nw from narwhals.exceptions import InvalidOperationError +from narwhals.schema import Schema from tests.utils import Constructor, ConstructorEager, assert_equal_data +if TYPE_CHECKING: + from narwhals.typing import LazyFrameT + + +def _cast(frame: LazyFrameT, schema: Schema) -> LazyFrameT: + return frame.select(nw.col(name).cast(dtype) for name, dtype in schema.items()) + def test_concat_horizontal(constructor_eager: ConstructorEager) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]} @@ -83,3 +92,72 @@ def test_concat_diagonal( with pytest.raises(ValueError, match="No items"): nw.concat([], how="diagonal") + + +@pytest.mark.parametrize( + ("ldata", "lschema", "rdata", "rschema", "expected_data", "expected_schema"), + [ + ( + {"a": [1, 2, 3], "b": [True, False, None]}, + Schema({"a": nw.Int8(), "b": nw.Boolean()}), + {"a": [43, 2, 3], "b": [32, 1, None]}, + Schema({"a": nw.Int16(), "b": nw.Int64()}), + {"a": [1, 2, 3, 43, 2, 3], "b": [1, 0, None, 32, 1, None]}, + Schema({"a": nw.Int16(), "b": nw.Int64()}), + ), + ( + {"a": [1, 2], "b": [2, 1]}, + Schema({"a": nw.Int32(), "b": nw.Int32()}), + {"a": [1.0, 0.2], "b": [None, 0.1]}, + Schema({"a": nw.Float32(), "b": nw.Float32()}), + {"a": [1.0, 2.0, 1.0, 0.2], "b": [2.0, 1.0, None, 0.1]}, + Schema({"a": nw.Float64(), "b": nw.Float64()}), + ), + ], +) +def test_concat_vertically_relaxed( + constructor: Constructor, + ldata: dict[str, Any], + lschema: Schema, + rdata: dict[str, Any], + rschema: Schema, + expected_data: dict[str, Any], + expected_schema: Schema, +) -> None: + # Adapted from https://github.com/pola-rs/polars/blob/b0fdbd34d430d934bda9a4ca3f75e136223bd95b/py-polars/tests/unit/functions/test_concat.py#L64 + left = nw.from_native(constructor(ldata)).lazy().pipe(_cast, lschema) + right = nw.from_native(constructor(rdata)).lazy().pipe(_cast, rschema) + result = nw.concat([left, right], how="vertical_relaxed") + + assert result.collect_schema() == expected_schema + assert_equal_data(result.collect(), expected_data) + + result = nw.concat([right, left], how="vertical_relaxed") + assert result.collect_schema() == expected_schema + + +def test_concat_diagonal_relaxed(constructor: Constructor) -> None: + # Adapted from https://github.com/pola-rs/polars/blob/b0fdbd34d430d934bda9a4ca3f75e136223bd95b/py-polars/tests/unit/functions/test_concat.py#L265C1-L288C41 + schema1 = Schema({"a": nw.Int32(), "c": nw.Int64()}) + data1 = {"a": [1, 2], "c": [10, 20]} + df1 = nw.from_native(constructor(data1)).lazy().pipe(_cast, schema1) + + schema2 = Schema({"a": nw.Float64(), "b": nw.Float32()}) + data2 = {"a": [3.5, 4.5], "b": [30.1, 40.2]} + df2 = nw.from_native(constructor(data2)).lazy().pipe(_cast, schema2) + + schema3 = Schema({"a": nw.Int32(), "c": nw.Int32()}) + data3 = {"b": [5, 6], "c": [50, 60]} + df3 = nw.from_native(constructor(data3)).lazy().pipe(_cast, schema3) + + result = nw.concat([df1, df2, df3], how="diagonal_relaxed") + out_schema = result.collect_schema() + expected_schema = Schema({"a": nw.Float64(), "b": nw.Float64(), "c": nw.Int64()}) + assert out_schema == expected_schema + + expected_data = { + "a": [1.0, 2.0, 3.5, 4.5, None, None], + "c": [10, 20, None, None, 50, 60], + "b": [None, None, 30.1, 40.2, 5.0, 6.0], + } + assert_equal_data(result.collect(), expected_data) From 7d4cb37d8f245f1a7f05dd555ef4af26f66d4861 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 11 Jan 2026 17:38:04 +0100 Subject: [PATCH 03/20] fixup col name and pyarrow --- narwhals/_arrow/namespace.py | 18 ++++++++++++++++-- narwhals/_compliant/namespace.py | 25 +++++-------------------- narwhals/_duckdb/namespace.py | 4 ++-- narwhals/_pandas_like/namespace.py | 19 +++++++++++++++++-- narwhals/_spark_like/namespace.py | 4 ++-- narwhals/schema.py | 8 +++----- tests/frame/concat_test.py | 4 ++-- 7 files changed, 47 insertions(+), 35 deletions(-) diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 6cf16857d7..c42121e91c 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -19,6 +19,7 @@ combine_evaluate_output_names, ) from narwhals._utils import Implementation +from narwhals.schema import Schema, combine_schemas, to_supertype if TYPE_CHECKING: from collections.abc import Iterator, Sequence @@ -176,6 +177,21 @@ def _concat_diagonal(self, dfs: Sequence[pa.Table], /) -> pa.Table: return pa.concat_tables(dfs, promote_options="default") return pa.concat_tables(dfs, promote=True) # pragma: no cover + def _concat_diagonal_relaxed(self, dfs: Sequence[pa.Table], /) -> pa.Table: + native_schemas = tuple(table.schema for table in dfs) + out_schema = reduce( + lambda x, y: to_supertype(*combine_schemas(x, y)), + (Schema.from_arrow(pa_schema) for pa_schema in native_schemas), + ).to_arrow() + to_schemas = ( + pa.schema([out_schema.field(name) for name in native_schema.names]) + for native_schema in native_schemas + ) + to_concat = tuple( + table.cast(to_schema) for table, to_schema in zip(dfs, to_schemas) + ) + return self._concat_diagonal(to_concat) + def _concat_horizontal(self, dfs: Sequence[pa.Table], /) -> pa.Table: names = list(chain.from_iterable(df.column_names for df in dfs)) arrays = tuple(chain.from_iterable(df.itercolumns() for df in dfs)) @@ -195,8 +211,6 @@ def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table: return pa.concat_tables(dfs) def _concat_vertical_relaxed(self, dfs: Sequence[pa.Table], /) -> pa.Table: - from narwhals.schema import Schema, to_supertype - out_schema = reduce( lambda x, y: to_supertype(x, y), (Schema.from_arrow(table.schema) for table in dfs), diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index 0650defe88..5c3058085d 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -1,6 +1,6 @@ from __future__ import annotations -from functools import partial, reduce +from functools import partial from typing import TYPE_CHECKING, Any, Protocol, overload from narwhals._compliant.typing import ( @@ -220,6 +220,9 @@ def from_numpy( return self._series.from_numpy(data, context=self) def _concat_diagonal(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ... + def _concat_diagonal_relaxed( + self, dfs: Sequence[NativeFrameT], / + ) -> NativeFrameT: ... def _concat_horizontal( self, dfs: Sequence[NativeFrameT | Any], / ) -> NativeFrameT: ... @@ -241,25 +244,7 @@ def concat( elif how == "diagonal": native = self._concat_diagonal(dfs) elif how == "diagonal_relaxed": - from narwhals.schema import Schema, combine_schemas, to_supertype - - schemas = tuple(Schema(item.collect_schema()) for item in items) - out_schema = reduce( - lambda x, y: to_supertype(*combine_schemas(x, y)), schemas - ) - aligned_items = tuple( - item.select( - *( - self.col(name).cast(dtype) - if name in schema - else self.lit(None, dtype=dtype) - for name, dtype in out_schema.items() - ) - ).native - for item, schema in zip(items, schemas) - ) - native = self._concat_vertical(aligned_items) - + native = self._concat_diagonal_relaxed(dfs) else: # pragma: no cover raise NotImplementedError return self._dataframe.from_native(native, context=self) diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 58f4f162a6..e0bc3275cc 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -126,10 +126,10 @@ def concat( *( self.col(name).cast(dtype) if name in schema - else self.lit(None, dtype=dtype) + else self.lit(None, dtype=dtype).alias(name) for name, dtype in out_schema.items() ) - )._native_frame + ).native for item, schema in zip(items, schemas) ) for _item in others: diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 2124eb38e7..ac34903cb4 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -18,6 +18,7 @@ from narwhals._pandas_like.typing import NativeDataFrameT, NativeSeriesT from narwhals._pandas_like.utils import is_non_nullable_boolean from narwhals._utils import zip_strict +from narwhals.schema import Schema, combine_schemas, to_supertype if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -257,6 +258,22 @@ def _concat_diagonal(self, dfs: Sequence[NativeDataFrameT], /) -> NativeDataFram return self._concat(dfs, axis=VERTICAL, copy=False) return self._concat(dfs, axis=VERTICAL) + def _concat_diagonal_relaxed( + self, dfs: Sequence[NativeDataFrameT], / + ) -> NativeDataFrameT: + out_schema = reduce( + lambda x, y: to_supertype(*combine_schemas(x, y)), + (Schema.from_pandas_like(frame.dtypes.to_dict()) for frame in dfs), + ).to_pandas( + # dtype_backend= # TODO(FBruzzesi): what should this be? + ) + if self._implementation.is_pandas() and self._backend_version < (3,): + native_res = self._concat(dfs, axis=VERTICAL, copy=False) + else: + native_res = self._concat(dfs, axis=VERTICAL) + + return native_res.astype(out_schema) + def _concat_horizontal( self, dfs: Sequence[NativeDataFrameT | NativeSeriesT], / ) -> NativeDataFrameT: @@ -292,8 +309,6 @@ def _concat_vertical(self, dfs: Sequence[NativeDataFrameT], /) -> NativeDataFram def _concat_vertical_relaxed( self, dfs: Sequence[NativeDataFrameT], / ) -> NativeDataFrameT: - from narwhals.schema import Schema, to_supertype - out_schema = reduce( lambda x, y: to_supertype(x, y), (Schema.from_pandas_like(frame.dtypes.to_dict()) for frame in dfs), diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 6d9d5aeab3..07e0d48eb1 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -202,10 +202,10 @@ def concat( *( self.col(name).cast(dtype) if name in schema - else self.lit(None, dtype=dtype) + else self.lit(None, dtype=dtype).alias(name) for name, dtype in out_schema.items() ) - )._native_frame + ).native for item, schema in zip(items, schemas) ) return SparkLikeLazyFrame( diff --git a/narwhals/schema.py b/narwhals/schema.py index d6d602c0a8..37efb67847 100644 --- a/narwhals/schema.py +++ b/narwhals/schema.py @@ -397,11 +397,9 @@ def combine_schemas(left: Schema, right: Schema) -> tuple[Schema, Schema]: followed by keys missing from left (in the order they appear in right). """ left_names = set(left.keys()) - missing_in_left = [(name, right[name]) for name in right if name not in left_names] + missing_in_left = (kv for kv in right.items() if kv[0] not in left_names) - extended_left = Schema([*left.items(), *missing_in_left]) + extended_left = Schema((*left.items(), *missing_in_left)) # Reorder right to match: left keys first, then right-only keys - extended_right = Schema( - [(name, right.get(name, left[name])) for name in extended_left] - ) + extended_right = Schema((kv[0], right.get(*kv)) for kv in extended_left.items()) return extended_left, extended_right diff --git a/tests/frame/concat_test.py b/tests/frame/concat_test.py index 539d58b16f..e45ec88366 100644 --- a/tests/frame/concat_test.py +++ b/tests/frame/concat_test.py @@ -146,13 +146,13 @@ def test_concat_diagonal_relaxed(constructor: Constructor) -> None: data2 = {"a": [3.5, 4.5], "b": [30.1, 40.2]} df2 = nw.from_native(constructor(data2)).lazy().pipe(_cast, schema2) - schema3 = Schema({"a": nw.Int32(), "c": nw.Int32()}) + schema3 = Schema({"b": nw.Int32(), "c": nw.Int32()}) data3 = {"b": [5, 6], "c": [50, 60]} df3 = nw.from_native(constructor(data3)).lazy().pipe(_cast, schema3) result = nw.concat([df1, df2, df3], how="diagonal_relaxed") out_schema = result.collect_schema() - expected_schema = Schema({"a": nw.Float64(), "b": nw.Float64(), "c": nw.Int64()}) + expected_schema = Schema({"a": nw.Float64(), "c": nw.Int64(), "b": nw.Float64()}) assert out_schema == expected_schema expected_data = { From 0f178496e2d1433472c4745ac3f9a25e01dda48a Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 11 Jan 2026 17:50:39 +0100 Subject: [PATCH 04/20] minor standardization --- narwhals/_compliant/dataframe.py | 4 ---- narwhals/_compliant/namespace.py | 1 - narwhals/_dask/namespace.py | 17 +++++++++++++---- narwhals/_duckdb/namespace.py | 6 +++--- narwhals/_ibis/namespace.py | 4 ++-- narwhals/_pandas_like/namespace.py | 4 ++-- narwhals/_spark_like/namespace.py | 14 +++++++------- 7 files changed, 27 insertions(+), 23 deletions(-) diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index f91e6eb1c1..3e7810616c 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -323,10 +323,6 @@ class EagerDataFrame( def _backend_version(self) -> tuple[int, ...]: return self._implementation._backend_version() - @property - def native(self) -> NativeDataFrameT: - return self._native_frame - def __narwhals_namespace__( self, ) -> EagerNamespace[ diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index 5c3058085d..17c98261bf 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -233,7 +233,6 @@ def _concat_vertical_relaxed( def concat( self, items: Iterable[EagerDataFrameT], *, how: ConcatMethod ) -> EagerDataFrameT: - items = tuple(items) dfs = [item.native for item in items] if how == "horizontal": native = self._concat_horizontal(dfs) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index c4791a7e0d..7766c58c27 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -139,10 +139,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: def concat( self, items: Iterable[DaskLazyFrame], *, how: ConcatMethod ) -> DaskLazyFrame: - if not items: - msg = "No items to concatenate" # pragma: no cover - raise AssertionError(msg) - dfs = [i._native_frame for i in items] + dfs = tuple(item.native for item in items) cols_0 = dfs[0].columns if how == "vertical": for i, df in enumerate(dfs[1:], start=1): @@ -163,6 +160,18 @@ def concat( return DaskLazyFrame( dd.concat(dfs, axis=0, join="outer"), version=self._version ) + if how == "vertical_relaxed": + msg = "TODO" + raise NotImplementedError(msg) + return DaskLazyFrame( + dd.concat(dfs, axis=0, join="inner"), version=self._version + ) + if how == "diagonal_relaxed": + msg = "TODO" + raise NotImplementedError(msg) + return DaskLazyFrame( + dd.concat(dfs, axis=0, join="outer"), version=self._version + ) raise NotImplementedError diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index e0bc3275cc..3452c72a8b 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -92,7 +92,7 @@ def concat( msg = "inputs should all have the same schema" raise TypeError(msg) - res = reduce(lambda x, y: x.union(y), (item._native_frame for item in items)) + res = reduce(lambda x, y: x.union(y), (item.native for item in items)) return first._with_native(res) if how == "vertical_relaxed": @@ -101,14 +101,14 @@ def concat( native_items = ( item.select( *(self.col(name).cast(dtype) for name, dtype in out_schema.items()) - )._native_frame + ).native for item in items ) res = reduce(lambda x, y: x.union(y), native_items) return first._with_native(res) if how == "diagonal": - res, *others = (item._native_frame for item in items) + res, *others = (item.native for item in items) for _item in others: # TODO(unassigned): use relational API when available https://github.com/duckdb/duckdb/discussions/16996 res = duckdb.sql(""" diff --git a/narwhals/_ibis/namespace.py b/narwhals/_ibis/namespace.py index 87f2555b8f..41995a0323 100644 --- a/narwhals/_ibis/namespace.py +++ b/narwhals/_ibis/namespace.py @@ -80,13 +80,13 @@ def concat( return self._lazyframe.from_native(ibis.union(*native_items), context=self) if how == "vertical_relaxed": - schemas = (Schema(df.collect_schema()) for df in items) + schemas = (Schema(item.collect_schema()) for item in items) out_schema = reduce(lambda x, y: to_supertype(x, y), schemas) native_items = ( item.select( *(self.col(name).cast(dtype) for name, dtype in out_schema.items()) - )._native_frame + ).native for item in items ) return self._lazyframe.from_native(ibis.union(*native_items), context=self) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index ac34903cb4..3e08f36ce6 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -311,11 +311,11 @@ def _concat_vertical_relaxed( ) -> NativeDataFrameT: out_schema = reduce( lambda x, y: to_supertype(x, y), - (Schema.from_pandas_like(frame.dtypes.to_dict()) for frame in dfs), + (Schema.from_pandas_like(df.dtypes.to_dict()) for df in dfs), ).to_pandas( # dtype_backend= # TODO(FBruzzesi): what should this be? ) - to_concat = [frame.astype(out_schema) for frame in dfs] + to_concat = [df.astype(out_schema) for df in dfs] if self._implementation.is_pandas() and self._backend_version < (3,): return self._concat(to_concat, axis=VERTICAL, copy=False) return self._concat(to_concat, axis=VERTICAL) diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 07e0d48eb1..b95523e06d 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -149,9 +149,9 @@ def concat( ) -> SparkLikeLazyFrame: items = tuple(items) if how == "vertical": - native_items = [item._native_frame for item in items] - cols_0 = native_items[0].columns - for i, df in enumerate(native_items[1:], start=1): + first, *others = (item.native for item in items) + cols_0 = first.columns + for i, df in enumerate(others, start=1): cols_current = df.columns if not ((len(cols_current) == len(cols_0)) and (cols_current == cols_0)): msg = ( @@ -162,7 +162,7 @@ def concat( raise TypeError(msg) return SparkLikeLazyFrame( - native_dataframe=reduce(lambda x, y: x.union(y), native_items), + native_dataframe=reduce(lambda x, y: x.union(y), others, first), version=self._version, implementation=self._implementation, ) @@ -171,7 +171,7 @@ def concat( return SparkLikeLazyFrame( native_dataframe=reduce( lambda x, y: x.unionByName(y, allowMissingColumns=True), - (item._native_frame for item in items), + (item.native for item in items), ), version=self._version, implementation=self._implementation, @@ -183,7 +183,7 @@ def concat( native_items = ( item.select( *(self.col(name).cast(dtype) for name, dtype in out_schema.items()) - )._native_frame + ).native for item in items ) return SparkLikeLazyFrame( @@ -193,7 +193,7 @@ def concat( ) if how == "diagonal_relaxed": - schemas = [Schema(df.collect_schema()) for df in items] + schemas = tuple(Schema(item.collect_schema()) for item in items) out_schema = reduce( lambda x, y: to_supertype(*combine_schemas(x, y)), schemas ) From cebeeda5e6fbf101f95b8b19599801a4d53a1a58 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 11 Jan 2026 18:55:59 +0100 Subject: [PATCH 05/20] pandas-like promote_dtype_backend --- narwhals/_native.py | 6 ++++- narwhals/_pandas_like/namespace.py | 30 +++++++++++----------- narwhals/_pandas_like/utils.py | 40 +++++++++++++++++++++++++++--- 3 files changed, 58 insertions(+), 18 deletions(-) diff --git a/narwhals/_native.py b/narwhals/_native.py index 2e06ac7041..46a0260dfc 100644 --- a/narwhals/_native.py +++ b/narwhals/_native.py @@ -224,7 +224,10 @@ def rename(self, *args: Any, **kwds: Any) -> Self | Incomplete: """ -class _BasePandasLikeFrame(NativeDataFrame, _BasePandasLike, Protocol): ... +class _BasePandasLikeFrame(NativeDataFrame, _BasePandasLike, Protocol): + @property + def dtypes(self) -> _BasePandasLikeSeries: ... + def astype(self, dtype: Any, *args: Any, **kwargs: Any) -> Self: ... class _BasePandasLikeSeries(NativeSeries, _BasePandasLike, Protocol): @@ -240,6 +243,7 @@ def __init__( **kwargs: Any, ) -> None: ... def where(self, cond: Any, other: Any = ..., /) -> Self | Incomplete: ... + def to_dict(self) -> dict[str, Any]: ... class NativeDask(NativeLazyFrame, Protocol): diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 3e08f36ce6..3d33b189b4 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -16,7 +16,7 @@ from narwhals._pandas_like.selectors import PandasSelectorNamespace from narwhals._pandas_like.series import PandasLikeSeries from narwhals._pandas_like.typing import NativeDataFrameT, NativeSeriesT -from narwhals._pandas_like.utils import is_non_nullable_boolean +from narwhals._pandas_like.utils import is_non_nullable_boolean, promote_dtype_backend from narwhals._utils import zip_strict from narwhals.schema import Schema, combine_schemas, to_supertype @@ -261,17 +261,18 @@ def _concat_diagonal(self, dfs: Sequence[NativeDataFrameT], /) -> NativeDataFram def _concat_diagonal_relaxed( self, dfs: Sequence[NativeDataFrameT], / ) -> NativeDataFrameT: + dtypes = tuple(df.dtypes.to_dict() for df in dfs) + dtype_backend = promote_dtype_backend(dfs, self._implementation) out_schema = reduce( lambda x, y: to_supertype(*combine_schemas(x, y)), - (Schema.from_pandas_like(frame.dtypes.to_dict()) for frame in dfs), - ).to_pandas( - # dtype_backend= # TODO(FBruzzesi): what should this be? - ) - if self._implementation.is_pandas() and self._backend_version < (3,): - native_res = self._concat(dfs, axis=VERTICAL, copy=False) - else: - native_res = self._concat(dfs, axis=VERTICAL) + (Schema.from_pandas_like(dtype) for dtype in dtypes), + ).to_pandas(dtype_backend=dtype_backend.values()) + native_res = ( + self._concat(dfs, axis=VERTICAL, copy=False) + if self._implementation.is_pandas() and self._backend_version < (3,) + else self._concat(dfs, axis=VERTICAL) + ) return native_res.astype(out_schema) def _concat_horizontal( @@ -309,13 +310,14 @@ def _concat_vertical(self, dfs: Sequence[NativeDataFrameT], /) -> NativeDataFram def _concat_vertical_relaxed( self, dfs: Sequence[NativeDataFrameT], / ) -> NativeDataFrameT: + dtypes = tuple(df.dtypes.to_dict() for df in dfs) + dtype_backend = promote_dtype_backend(dfs, self._implementation) out_schema = reduce( lambda x, y: to_supertype(x, y), - (Schema.from_pandas_like(df.dtypes.to_dict()) for df in dfs), - ).to_pandas( - # dtype_backend= # TODO(FBruzzesi): what should this be? - ) - to_concat = [df.astype(out_schema) for df in dfs] + (Schema.from_pandas_like(dtype) for dtype in dtypes), + ).to_pandas(dtype_backend=dtype_backend.values()) + + to_concat = (df.astype(out_schema) for df in dfs) if self._implementation.is_pandas() and self._backend_version < (3,): return self._concat(to_concat, axis=VERTICAL, copy=False) return self._concat(to_concat, axis=VERTICAL) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 2777edb6f7..082279754c 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -1,8 +1,8 @@ from __future__ import annotations -import functools import operator import re +from functools import lru_cache, partial from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, cast import pandas as pd @@ -26,6 +26,7 @@ requires, ) from narwhals.exceptions import ShapeError +from narwhals.typing import DTypeBackend if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Mapping @@ -36,6 +37,7 @@ from typing_extensions import TypeAlias, TypeIs from narwhals._duration import IntervalUnit + from narwhals._native import NativePandasLikeDataFrame from narwhals._pandas_like.expr import PandasLikeExpr from narwhals._pandas_like.series import PandasLikeSeries from narwhals._pandas_like.typing import ( @@ -211,7 +213,7 @@ def rename( return cast("NativeNDFrameT", result) # type: ignore[redundant-cast] -@functools.lru_cache(maxsize=16) +@lru_cache(maxsize=16) def non_object_native_to_narwhals_dtype(native_dtype: Any, version: Version) -> DType: # noqa: C901, PLR0912 dtype = str(native_dtype) @@ -426,7 +428,7 @@ def iter_dtype_backends( return (get_dtype_backend(dtype, implementation) for dtype in dtypes) -@functools.lru_cache(maxsize=16) +@lru_cache(maxsize=16) def is_dtype_pyarrow(dtype: Any) -> TypeIs[pd.ArrowDtype]: return hasattr(pd, "ArrowDtype") and isinstance(dtype, pd.ArrowDtype) @@ -704,3 +706,35 @@ class PandasLikeSeriesNamespace(EagerSeriesNamespace["PandasLikeSeries", Any]): def make_group_by_kwargs(*, drop_null_keys: bool) -> dict[str, bool]: return {"sort": False, "as_index": True, "dropna": drop_null_keys, "observed": True} + + +_DTYPE_BACKEND_PRIORITY: dict[DTypeBackend, Literal[0, 1, 2]] = { + "pyarrow": 2, + "numpy_nullable": 1, + None: 0, +} + + +def promote_dtype_backend( + dataframes: Iterable[NativePandasLikeDataFrame], implementation: Implementation +) -> dict[str, DTypeBackend]: + """Promote dtype backends for each column based on priority rules. + + Priority: pyarrow > numpy_nullable > None + + Returns: + Dictionary mapping column names to the promoted dtype backend + """ + column_backends: dict[str, DTypeBackend] = {} + _get_dtype_backend_impl = partial(get_dtype_backend, implementation=implementation) + for df in dataframes: + for col in df.columns: + backend = _get_dtype_backend_impl(df[col].dtype) + current = column_backends.get(col) + if ( + current is None + or _DTYPE_BACKEND_PRIORITY[backend] > _DTYPE_BACKEND_PRIORITY[current] + ): + column_backends[col] = backend + + return column_backends From 942af32bc100fa59923252e66f3f6cabfa9964da Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 11 Jan 2026 19:11:12 +0100 Subject: [PATCH 06/20] dask and tests --- narwhals/_dask/namespace.py | 30 ++++++++++++++------- narwhals/_pandas_like/utils.py | 4 ++- tests/frame/concat_test.py | 48 ++++++++++++++++++++++++++++++---- 3 files changed, 67 insertions(+), 15 deletions(-) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 7766c58c27..3df63f8be0 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -22,7 +22,9 @@ combine_alias_output_names, combine_evaluate_output_names, ) +from narwhals._pandas_like.utils import promote_dtype_backend from narwhals._utils import Implementation, zip_strict +from narwhals.schema import Schema, combine_schemas, to_supertype if TYPE_CHECKING: from collections.abc import Iterable, Iterator @@ -139,7 +141,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: def concat( self, items: Iterable[DaskLazyFrame], *, how: ConcatMethod ) -> DaskLazyFrame: - dfs = tuple(item.native for item in items) + dfs = [item.native for item in items] cols_0 = dfs[0].columns if how == "vertical": for i, df in enumerate(dfs[1:], start=1): @@ -161,17 +163,27 @@ def concat( dd.concat(dfs, axis=0, join="outer"), version=self._version ) if how == "vertical_relaxed": - msg = "TODO" - raise NotImplementedError(msg) + dtypes = tuple(df.dtypes.to_dict() for df in dfs) + dtype_backend = promote_dtype_backend(dfs, self._implementation) + out_schema = reduce( + lambda x, y: to_supertype(x, y), + (Schema.from_pandas_like(dtype) for dtype in dtypes), + ).to_pandas(dtype_backend=dtype_backend.values()) + + to_concat = [df.astype(out_schema) for df in dfs] return DaskLazyFrame( - dd.concat(dfs, axis=0, join="inner"), version=self._version + dd.concat(to_concat, axis=0, join="inner"), version=self._version ) if how == "diagonal_relaxed": - msg = "TODO" - raise NotImplementedError(msg) - return DaskLazyFrame( - dd.concat(dfs, axis=0, join="outer"), version=self._version - ) + dtypes = tuple(df.dtypes.to_dict() for df in dfs) + dtype_backend = promote_dtype_backend(dfs, self._implementation) + out_schema = reduce( + lambda x, y: to_supertype(*combine_schemas(x, y)), + (Schema.from_pandas_like(dtype) for dtype in dtypes), + ).to_pandas(dtype_backend=dtype_backend.values()) + + native_res = dd.concat(dfs, axis=0, join="outer").astype(out_schema) + return DaskLazyFrame(native_res, version=self._version) raise NotImplementedError diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 082279754c..4f1c126654 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -32,6 +32,7 @@ from collections.abc import Iterable, Iterator, Mapping from types import ModuleType + from dask.dataframe import DataFrame as NativeDaskDataFrame from pandas._typing import Dtype as PandasDtype from pandas.core.dtypes.dtypes import BaseMaskedDtype from typing_extensions import TypeAlias, TypeIs @@ -716,7 +717,8 @@ def make_group_by_kwargs(*, drop_null_keys: bool) -> dict[str, bool]: def promote_dtype_backend( - dataframes: Iterable[NativePandasLikeDataFrame], implementation: Implementation + dataframes: Iterable[NativePandasLikeDataFrame] | Iterable[NativeDaskDataFrame], + implementation: Implementation, ) -> dict[str, DTypeBackend]: """Promote dtype backends for each column based on priority rules. diff --git a/tests/frame/concat_test.py b/tests/frame/concat_test.py index e45ec88366..8470817e76 100644 --- a/tests/frame/concat_test.py +++ b/tests/frame/concat_test.py @@ -114,6 +114,7 @@ def test_concat_diagonal( Schema({"a": nw.Float64(), "b": nw.Float64()}), ), ], + ids=["nullable-integer", "nullable-float"], ) def test_concat_vertically_relaxed( constructor: Constructor, @@ -123,8 +124,16 @@ def test_concat_vertically_relaxed( rschema: Schema, expected_data: dict[str, Any], expected_schema: Schema, + request: pytest.FixtureRequest, ) -> None: # Adapted from https://github.com/pola-rs/polars/blob/b0fdbd34d430d934bda9a4ca3f75e136223bd95b/py-polars/tests/unit/functions/test_concat.py#L64 + is_nullable_int = request.node.callspec.id.endswith("nullable-integer") + if is_nullable_int and any( + x in str(constructor) + for x in ("dask", "pandas_constructor", "modin_constructor", "cudf") + ): + reason = "Cannot convert non-finite values (NA or inf)" + request.applymarker(pytest.mark.xfail(reason=reason)) left = nw.from_native(constructor(ldata)).lazy().pipe(_cast, lschema) right = nw.from_native(constructor(rdata)).lazy().pipe(_cast, rschema) result = nw.concat([left, right], how="vertical_relaxed") @@ -136,23 +145,52 @@ def test_concat_vertically_relaxed( assert result.collect_schema() == expected_schema -def test_concat_diagonal_relaxed(constructor: Constructor) -> None: +@pytest.mark.parametrize( + ("schema1", "schema2", "schema3", "expected_schema"), + [ + ( + Schema({"a": nw.Int32(), "c": nw.Int64()}), + Schema({"a": nw.Float64(), "b": nw.Float32()}), + Schema({"b": nw.Int32(), "c": nw.Int32()}), + Schema({"a": nw.Float64(), "c": nw.Int64(), "b": nw.Float64()}), + ), + ( + Schema({"a": nw.Float32(), "c": nw.Float32()}), + Schema({"a": nw.Float64(), "b": nw.Float32()}), + Schema({"b": nw.Float32(), "c": nw.Float32()}), + Schema({"a": nw.Float64(), "c": nw.Float32(), "b": nw.Float32()}), + ), + ], + ids=["nullable-integer", "nullable-float"], +) +def test_concat_diagonal_relaxed( + constructor: Constructor, + schema1: Schema, + schema2: Schema, + schema3: Schema, + expected_schema: Schema, + request: pytest.FixtureRequest, +) -> None: # Adapted from https://github.com/pola-rs/polars/blob/b0fdbd34d430d934bda9a4ca3f75e136223bd95b/py-polars/tests/unit/functions/test_concat.py#L265C1-L288C41 - schema1 = Schema({"a": nw.Int32(), "c": nw.Int64()}) + is_nullable_int = request.node.callspec.id.endswith("nullable-integer") + if is_nullable_int and any( + x in str(constructor) + for x in ("dask", "pandas_constructor", "modin_constructor", "cudf") + ): + reason = "Cannot convert non-finite values (NA or inf)" + request.applymarker(pytest.mark.xfail(reason=reason)) + data1 = {"a": [1, 2], "c": [10, 20]} df1 = nw.from_native(constructor(data1)).lazy().pipe(_cast, schema1) - schema2 = Schema({"a": nw.Float64(), "b": nw.Float32()}) data2 = {"a": [3.5, 4.5], "b": [30.1, 40.2]} df2 = nw.from_native(constructor(data2)).lazy().pipe(_cast, schema2) - schema3 = Schema({"b": nw.Int32(), "c": nw.Int32()}) data3 = {"b": [5, 6], "c": [50, 60]} df3 = nw.from_native(constructor(data3)).lazy().pipe(_cast, schema3) result = nw.concat([df1, df2, df3], how="diagonal_relaxed") out_schema = result.collect_schema() - expected_schema = Schema({"a": nw.Float64(), "c": nw.Int64(), "b": nw.Float64()}) assert out_schema == expected_schema expected_data = { From 9743d24f3fbba7b7ea78c2f170bf39a14ee15296 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 11 Jan 2026 19:26:18 +0100 Subject: [PATCH 07/20] add to_supertype coverage --- tests/frame/schema_test.py | 56 +++++++++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/tests/frame/schema_test.py b/tests/frame/schema_test.py index d90916b029..3cfd5c38ee 100644 --- a/tests/frame/schema_test.py +++ b/tests/frame/schema_test.py @@ -8,7 +8,8 @@ import pytest import narwhals as nw -from narwhals.exceptions import PerformanceWarning +from narwhals.exceptions import ComputeError, PerformanceWarning, SchemaMismatchError +from narwhals.schema import to_supertype from tests.utils import PANDAS_VERSION, POLARS_VERSION, ConstructorPandasLike if TYPE_CHECKING: @@ -22,6 +23,7 @@ IntoArrowSchema, IntoPandasSchema, IntoPolarsSchema, + IntoSchema, ) from tests.utils import Constructor, ConstructorEager @@ -704,3 +706,55 @@ def test_schema_from_to_roundtrip() -> None: assert nw_schema_1 == nw_schema_2 assert nw_schema_2 == nw_schema_3 assert py_schema_1 == py_schema_2 + + +@pytest.mark.parametrize( + ("left", "right", "expected"), + [ + ( + {"a": nw.Int64(), "b": nw.String()}, + {"a": nw.Int64(), "b": nw.String()}, + {"a": nw.Int64(), "b": nw.String()}, + ), + ( + {"a": nw.Int32(), "b": nw.Float32()}, + {"a": nw.Int64(), "b": nw.Float64()}, + {"a": nw.Int64(), "b": nw.Float64()}, + ), + ({"a": nw.Int32()}, {"a": nw.Float64()}, {"a": nw.Float64()}), + ({"a": nw.Datetime("ns")}, {"a": nw.Datetime("us")}, {"a": nw.Datetime("us")}), + ], +) +def test_to_supertype(left: IntoSchema, right: IntoSchema, expected: IntoSchema) -> None: + result = to_supertype(nw.Schema(left), nw.Schema(right)) + assert result == expected + + +@pytest.mark.parametrize( + ("left", "right", "context"), + [ + ( + {"a": nw.Int64()}, + {"a": nw.Int64(), "b": nw.String()}, + pytest.raises(ComputeError, match="schema lengths differ"), + ), + ( + {"a": nw.Int64()}, + {"b": nw.Int64()}, + pytest.raises(ComputeError, match="schema names differ: got b, expected a"), + ), + ( + {"a": nw.String()}, + {"a": nw.Int64()}, + pytest.raises( + SchemaMismatchError, + match="failed to determine supertype of String and Int64", + ), + ), + ], +) +def test_to_supertype_exceptions( + left: IntoSchema, right: IntoSchema, context: pytest.RaisesExc +) -> None: + with context: + to_supertype(nw.Schema(left), nw.Schema(right)) From 3328f2ffe0391a1007c4f65a682675c8f35dd6c5 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 11 Jan 2026 19:52:03 +0100 Subject: [PATCH 08/20] skip ibis diagonal --- tests/frame/concat_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/frame/concat_test.py b/tests/frame/concat_test.py index 8470817e76..eb10b6a14a 100644 --- a/tests/frame/concat_test.py +++ b/tests/frame/concat_test.py @@ -180,6 +180,9 @@ def test_concat_diagonal_relaxed( reason = "Cannot convert non-finite values (NA or inf)" request.applymarker(pytest.mark.xfail(reason=reason)) + if "ibis" in str(constructor): + pytest.skip(reason="NotImplementedError") + data1 = {"a": [1, 2], "c": [10, 20]} df1 = nw.from_native(constructor(data1)).lazy().pipe(_cast, schema1) From 23d0cd4abd28e5318ca7d0dc645bd1658dec7d69 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 11 Jan 2026 21:20:45 +0000 Subject: [PATCH 09/20] fix(typing): Make `pyright` happier --- narwhals/_native.py | 3 --- narwhals/_pandas_like/namespace.py | 10 +++++++--- narwhals/_pandas_like/utils.py | 13 ++++++++++++- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/narwhals/_native.py b/narwhals/_native.py index 46a0260dfc..c722b61fe4 100644 --- a/narwhals/_native.py +++ b/narwhals/_native.py @@ -225,8 +225,6 @@ def rename(self, *args: Any, **kwds: Any) -> Self | Incomplete: class _BasePandasLikeFrame(NativeDataFrame, _BasePandasLike, Protocol): - @property - def dtypes(self) -> _BasePandasLikeSeries: ... def astype(self, dtype: Any, *args: Any, **kwargs: Any) -> Self: ... @@ -243,7 +241,6 @@ def __init__( **kwargs: Any, ) -> None: ... def where(self, cond: Any, other: Any = ..., /) -> Self | Incomplete: ... - def to_dict(self) -> dict[str, Any]: ... class NativeDask(NativeLazyFrame, Protocol): diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 3d33b189b4..92c6ebaa68 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -16,7 +16,11 @@ from narwhals._pandas_like.selectors import PandasSelectorNamespace from narwhals._pandas_like.series import PandasLikeSeries from narwhals._pandas_like.typing import NativeDataFrameT, NativeSeriesT -from narwhals._pandas_like.utils import is_non_nullable_boolean, promote_dtype_backend +from narwhals._pandas_like.utils import ( + is_non_nullable_boolean, + native_schema, + promote_dtype_backend, +) from narwhals._utils import zip_strict from narwhals.schema import Schema, combine_schemas, to_supertype @@ -261,7 +265,7 @@ def _concat_diagonal(self, dfs: Sequence[NativeDataFrameT], /) -> NativeDataFram def _concat_diagonal_relaxed( self, dfs: Sequence[NativeDataFrameT], / ) -> NativeDataFrameT: - dtypes = tuple(df.dtypes.to_dict() for df in dfs) + dtypes = tuple(native_schema(df) for df in dfs) dtype_backend = promote_dtype_backend(dfs, self._implementation) out_schema = reduce( lambda x, y: to_supertype(*combine_schemas(x, y)), @@ -310,7 +314,7 @@ def _concat_vertical(self, dfs: Sequence[NativeDataFrameT], /) -> NativeDataFram def _concat_vertical_relaxed( self, dfs: Sequence[NativeDataFrameT], / ) -> NativeDataFrameT: - dtypes = tuple(df.dtypes.to_dict() for df in dfs) + dtypes = tuple(native_schema(df) for df in dfs) dtype_backend = promote_dtype_backend(dfs, self._implementation) out_schema = reduce( lambda x, y: to_supertype(x, y), diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 4f1c126654..3088bcc2b6 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -47,7 +47,13 @@ NativeSeriesT, ) from narwhals.dtypes import DType - from narwhals.typing import DTypeBackend, IntoDType, TimeUnit, _1DArray + from narwhals.typing import ( + DTypeBackend, + IntoDType, + IntoPandasSchema, + TimeUnit, + _1DArray, + ) ExprT = TypeVar("ExprT", bound=PandasLikeExpr) UnitCurrent: TypeAlias = TimeUnit @@ -55,6 +61,7 @@ BinOpBroadcast: TypeAlias = Callable[[Any, int], Any] IntoRhs: TypeAlias = int +Incomplete: TypeAlias = Any PANDAS_LIKE_IMPLEMENTATION = { Implementation.PANDAS, @@ -740,3 +747,7 @@ def promote_dtype_backend( column_backends[col] = backend return column_backends + + +def native_schema(df: Incomplete) -> IntoPandasSchema: + return df.dtypes.to_dict() From e0ce9eb4244769c89e1ce5826e3e27a0fb68b4cc Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 11 Jan 2026 21:36:43 +0000 Subject: [PATCH 10/20] fix(typing): Pacify `mypy` for `pandas_like` --- narwhals/_native.py | 3 +-- narwhals/_pandas_like/namespace.py | 11 +++++++---- narwhals/_pandas_like/utils.py | 15 +++++++++++++++ 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/narwhals/_native.py b/narwhals/_native.py index c722b61fe4..2e06ac7041 100644 --- a/narwhals/_native.py +++ b/narwhals/_native.py @@ -224,8 +224,7 @@ def rename(self, *args: Any, **kwds: Any) -> Self | Incomplete: """ -class _BasePandasLikeFrame(NativeDataFrame, _BasePandasLike, Protocol): - def astype(self, dtype: Any, *args: Any, **kwargs: Any) -> Self: ... +class _BasePandasLikeFrame(NativeDataFrame, _BasePandasLike, Protocol): ... class _BasePandasLikeSeries(NativeSeries, _BasePandasLike, Protocol): diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 92c6ebaa68..0b88355047 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -17,7 +17,9 @@ from narwhals._pandas_like.series import PandasLikeSeries from narwhals._pandas_like.typing import NativeDataFrameT, NativeSeriesT from narwhals._pandas_like.utils import ( + cast_native, is_non_nullable_boolean, + iter_cast_native, native_schema, promote_dtype_backend, ) @@ -277,7 +279,7 @@ def _concat_diagonal_relaxed( if self._implementation.is_pandas() and self._backend_version < (3,) else self._concat(dfs, axis=VERTICAL) ) - return native_res.astype(out_schema) + return cast_native(native_res, out_schema) def _concat_horizontal( self, dfs: Sequence[NativeDataFrameT | NativeSeriesT], / @@ -321,10 +323,11 @@ def _concat_vertical_relaxed( (Schema.from_pandas_like(dtype) for dtype in dtypes), ).to_pandas(dtype_backend=dtype_backend.values()) - to_concat = (df.astype(out_schema) for df in dfs) if self._implementation.is_pandas() and self._backend_version < (3,): - return self._concat(to_concat, axis=VERTICAL, copy=False) - return self._concat(to_concat, axis=VERTICAL) + return self._concat( + iter_cast_native(dfs, out_schema), axis=VERTICAL, copy=False + ) + return self._concat(iter_cast_native(dfs, out_schema), axis=VERTICAL) def concat_str( self, *exprs: PandasLikeExpr, separator: str, ignore_nulls: bool diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 3088bcc2b6..3a3869164b 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -751,3 +751,18 @@ def promote_dtype_backend( def native_schema(df: Incomplete) -> IntoPandasSchema: return df.dtypes.to_dict() + + +def cast_native(df: NativeDataFrameT, schema: IntoPandasSchema) -> NativeDataFrameT: + df_: Incomplete = df + return cast("NativeDataFrameT", df_.astype(schema)) + + +def iter_cast_native( + dfs: Iterable[NativeDataFrameT], schema: IntoPandasSchema +) -> Iterator[NativeDataFrameT]: + if TYPE_CHECKING: + for df in dfs: + yield cast_native(df, schema) + else: + yield from (df.astype(schema) for df in dfs) From 4b484f477352878d8065e3cc1155ab530a879a4a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 11 Jan 2026 21:42:02 +0000 Subject: [PATCH 11/20] wow that was a useless error message! ```py error: Unsupported target for indexed assignment ("Mapping[str, DType]") ``` --- narwhals/schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/schema.py b/narwhals/schema.py index 37efb67847..8554cf646b 100644 --- a/narwhals/schema.py +++ b/narwhals/schema.py @@ -373,7 +373,7 @@ def to_supertype(left: Schema, right: Schema) -> Schema: msg = "schema lengths differ" raise ComputeError(msg) - into_out_schema: Mapping[str, DType] = {} + into_out_schema: dict[str, DType] = {} for (lname, ltype), (rname, rtype) in zip(left.items(), right.items()): if lname != rname: msg = f"schema names differ: got {rname}, expected {lname}" From 683c83587e03b60558123fdc8378d8b445899f94 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 11 Jan 2026 21:44:21 +0000 Subject: [PATCH 12/20] fix(typing): Tell `mypy` we have a wider type than the first assignment --- narwhals/_duckdb/namespace.py | 2 +- narwhals/_spark_like/namespace.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 3452c72a8b..4d8a95f2b4 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -96,7 +96,7 @@ def concat( return first._with_native(res) if how == "vertical_relaxed": - schemas = (Schema(df.collect_schema()) for df in items) + schemas: Iterable[Schema] = (Schema(df.collect_schema()) for df in items) out_schema = reduce(lambda x, y: to_supertype(x, y), schemas) native_items = ( item.select( diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index b95523e06d..8d4e63fcb0 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -178,7 +178,7 @@ def concat( ) if how == "vertical_relaxed": - schemas = (Schema(df.collect_schema()) for df in items) + schemas: Iterable[Schema] = (Schema(df.collect_schema()) for df in items) out_schema = reduce(lambda x, y: to_supertype(x, y), schemas) native_items = ( item.select( From 5bffccd79a9054764df765d31078d8b3ae0b4173 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 11 Jan 2026 22:29:13 +0000 Subject: [PATCH 13/20] perf: Avoid unnecessary `lambda`s MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit If the function already exists, we can use it instead 😄 --- narwhals/_arrow/namespace.py | 3 +-- narwhals/_dask/namespace.py | 3 +-- narwhals/_duckdb/namespace.py | 10 ++++------ narwhals/_ibis/namespace.py | 2 +- narwhals/_pandas_like/namespace.py | 3 +-- narwhals/_spark_like/namespace.py | 8 +++++--- 6 files changed, 13 insertions(+), 16 deletions(-) diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index c42121e91c..bc450265ad 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -212,8 +212,7 @@ def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table: def _concat_vertical_relaxed(self, dfs: Sequence[pa.Table], /) -> pa.Table: out_schema = reduce( - lambda x, y: to_supertype(x, y), - (Schema.from_arrow(table.schema) for table in dfs), + to_supertype, (Schema.from_arrow(table.schema) for table in dfs) ).to_arrow() return pa.concat_tables([table.cast(out_schema) for table in dfs]) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 3df63f8be0..5254744ea0 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -166,8 +166,7 @@ def concat( dtypes = tuple(df.dtypes.to_dict() for df in dfs) dtype_backend = promote_dtype_backend(dfs, self._implementation) out_schema = reduce( - lambda x, y: to_supertype(x, y), - (Schema.from_pandas_like(dtype) for dtype in dtypes), + to_supertype, (Schema.from_pandas_like(dtype) for dtype in dtypes) ).to_pandas(dtype_backend=dtype_backend.values()) to_concat = [df.astype(out_schema) for df in dfs] diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 4d8a95f2b4..41c2e8d299 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any import duckdb -from duckdb import CoalesceOperator, Expression +from duckdb import CoalesceOperator, DuckDBPyRelation, Expression from narwhals._duckdb.dataframe import DuckDBLazyFrame from narwhals._duckdb.expr import DuckDBExpr @@ -32,8 +32,6 @@ if TYPE_CHECKING: from collections.abc import Iterable - from duckdb import DuckDBPyRelation # noqa: F401 - from narwhals._compliant.window import WindowInputs from narwhals._utils import Version from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral @@ -92,19 +90,19 @@ def concat( msg = "inputs should all have the same schema" raise TypeError(msg) - res = reduce(lambda x, y: x.union(y), (item.native for item in items)) + res = reduce(DuckDBPyRelation.union, (item.native for item in items)) return first._with_native(res) if how == "vertical_relaxed": schemas: Iterable[Schema] = (Schema(df.collect_schema()) for df in items) - out_schema = reduce(lambda x, y: to_supertype(x, y), schemas) + out_schema = reduce(to_supertype, schemas) native_items = ( item.select( *(self.col(name).cast(dtype) for name, dtype in out_schema.items()) ).native for item in items ) - res = reduce(lambda x, y: x.union(y), native_items) + res = reduce(DuckDBPyRelation.union, native_items) return first._with_native(res) if how == "diagonal": diff --git a/narwhals/_ibis/namespace.py b/narwhals/_ibis/namespace.py index 41995a0323..fd5425e13a 100644 --- a/narwhals/_ibis/namespace.py +++ b/narwhals/_ibis/namespace.py @@ -81,7 +81,7 @@ def concat( if how == "vertical_relaxed": schemas = (Schema(item.collect_schema()) for item in items) - out_schema = reduce(lambda x, y: to_supertype(x, y), schemas) + out_schema = reduce(to_supertype, schemas) native_items = ( item.select( diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 0b88355047..293589471f 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -319,8 +319,7 @@ def _concat_vertical_relaxed( dtypes = tuple(native_schema(df) for df in dfs) dtype_backend = promote_dtype_backend(dfs, self._implementation) out_schema = reduce( - lambda x, y: to_supertype(x, y), - (Schema.from_pandas_like(dtype) for dtype in dtypes), + to_supertype, (Schema.from_pandas_like(dtype) for dtype in dtypes) ).to_pandas(dtype_backend=dtype_backend.values()) if self._implementation.is_pandas() and self._backend_version < (3,): diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 8d4e63fcb0..275f82589a 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -179,15 +179,16 @@ def concat( if how == "vertical_relaxed": schemas: Iterable[Schema] = (Schema(df.collect_schema()) for df in items) - out_schema = reduce(lambda x, y: to_supertype(x, y), schemas) + out_schema = reduce(to_supertype, schemas) native_items = ( item.select( *(self.col(name).cast(dtype) for name, dtype in out_schema.items()) ).native for item in items ) + union = items[0].native.__class__.union return SparkLikeLazyFrame( - native_dataframe=reduce(lambda x, y: x.union(y), native_items), + native_dataframe=reduce(union, native_items), version=self._version, implementation=self._implementation, ) @@ -208,8 +209,9 @@ def concat( ).native for item, schema in zip(items, schemas) ) + union = items[0].native.__class__.union return SparkLikeLazyFrame( - native_dataframe=reduce(lambda x, y: x.union(y), native_items), + native_dataframe=reduce(union, native_items), version=self._version, implementation=self._implementation, ) From 8fabb13e4f74be9236504b139334d6dbbb697684 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 12 Jan 2026 14:19:28 +0000 Subject: [PATCH 14/20] perf: Use a generator instead of intermediate `dict` https://github.com/narwhals-dev/narwhals/pull/3398#discussion_r2680392032 --- narwhals/schema.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/narwhals/schema.py b/narwhals/schema.py index 8554cf646b..3f5ea8650a 100644 --- a/narwhals/schema.py +++ b/narwhals/schema.py @@ -367,24 +367,30 @@ def _from_pandas_like( ) -def to_supertype(left: Schema, right: Schema) -> Schema: - # Adapted from polars https://github.com/pola-rs/polars/blob/c2412600210a21143835c9dfcb0a9182f462b619/crates/polars-core/src/schema/mod.rs#L83-L96 +def _supertype(left: DType, right: DType) -> DType: + if promoted_dtype := get_supertype(left, right): + return promoted_dtype + msg = f"failed to determine supertype of {left} and {right}" + raise SchemaMismatchError(msg) + + +def _ensure_names_match(left: Schema, right: Schema) -> tuple[Schema, Schema]: if len(left) != len(right): msg = "schema lengths differ" raise ComputeError(msg) + if left.names() != right.names(): + it = ((lname, rname) for (lname, rname) in zip(left, right) if lname != rname) + lname, rname = next(it) + msg = f"schema names differ: got {rname}, expected {lname}" + raise ComputeError(msg) + return left, right + - into_out_schema: dict[str, DType] = {} - for (lname, ltype), (rname, rtype) in zip(left.items(), right.items()): - if lname != rname: - msg = f"schema names differ: got {rname}, expected {lname}" - raise ComputeError(msg) - if promoted_dtype := get_supertype(left=ltype, right=rtype): - into_out_schema[lname] = promoted_dtype - else: - msg = f"failed to determine supertype of {ltype} and {rtype}" - raise SchemaMismatchError(msg) - - return Schema(into_out_schema) +def to_supertype(left: Schema, right: Schema) -> Schema: + # Adapted from polars https://github.com/pola-rs/polars/blob/c2412600210a21143835c9dfcb0a9182f462b619/crates/polars-core/src/schema/mod.rs#L83-L96 + left, right = _ensure_names_match(left, right) + it = zip(left.keys(), left.values(), right.values()) + return Schema((name, _supertype(ltype, rtype)) for (name, ltype, rtype) in it) def combine_schemas(left: Schema, right: Schema) -> tuple[Schema, Schema]: From c658320b7754e7ca0c66b2396247779895379ac8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 14 Jan 2026 21:57:36 +0000 Subject: [PATCH 15/20] perf: Optimize, rename `promote_dtype_backends` - Reuse the schemas we've already collected - And rename the variables so it is obvious - Skip checking if columns exist for the first schema - Use a generator inside `dict.update` (instead of many `__setitem__`s) - Use `iter_dtype_backends` instead of creating a new `partial` Unrelated to performance: Return `.values()` instead of a `dict` (no usage of the keys anywhere?) --- narwhals/_dask/namespace.py | 16 +++++------ narwhals/_pandas_like/namespace.py | 16 +++++------ narwhals/_pandas_like/utils.py | 46 +++++++++++++++--------------- 3 files changed, 37 insertions(+), 41 deletions(-) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 5254744ea0..4a98511a9a 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -22,7 +22,7 @@ combine_alias_output_names, combine_evaluate_output_names, ) -from narwhals._pandas_like.utils import promote_dtype_backend +from narwhals._pandas_like.utils import promote_dtype_backends from narwhals._utils import Implementation, zip_strict from narwhals.schema import Schema, combine_schemas, to_supertype @@ -163,23 +163,21 @@ def concat( dd.concat(dfs, axis=0, join="outer"), version=self._version ) if how == "vertical_relaxed": - dtypes = tuple(df.dtypes.to_dict() for df in dfs) - dtype_backend = promote_dtype_backend(dfs, self._implementation) + schemas = tuple(df.dtypes.to_dict() for df in dfs) out_schema = reduce( - to_supertype, (Schema.from_pandas_like(dtype) for dtype in dtypes) - ).to_pandas(dtype_backend=dtype_backend.values()) + to_supertype, (Schema.from_pandas_like(schema) for schema in schemas) + ).to_pandas(promote_dtype_backends(schemas, self._implementation)) to_concat = [df.astype(out_schema) for df in dfs] return DaskLazyFrame( dd.concat(to_concat, axis=0, join="inner"), version=self._version ) if how == "diagonal_relaxed": - dtypes = tuple(df.dtypes.to_dict() for df in dfs) - dtype_backend = promote_dtype_backend(dfs, self._implementation) + schemas = tuple(df.dtypes.to_dict() for df in dfs) out_schema = reduce( lambda x, y: to_supertype(*combine_schemas(x, y)), - (Schema.from_pandas_like(dtype) for dtype in dtypes), - ).to_pandas(dtype_backend=dtype_backend.values()) + (Schema.from_pandas_like(schema) for schema in schemas), + ).to_pandas(promote_dtype_backends(schemas, self._implementation)) native_res = dd.concat(dfs, axis=0, join="outer").astype(out_schema) return DaskLazyFrame(native_res, version=self._version) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 293589471f..ce0f7dd538 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -21,7 +21,7 @@ is_non_nullable_boolean, iter_cast_native, native_schema, - promote_dtype_backend, + promote_dtype_backends, ) from narwhals._utils import zip_strict from narwhals.schema import Schema, combine_schemas, to_supertype @@ -267,12 +267,11 @@ def _concat_diagonal(self, dfs: Sequence[NativeDataFrameT], /) -> NativeDataFram def _concat_diagonal_relaxed( self, dfs: Sequence[NativeDataFrameT], / ) -> NativeDataFrameT: - dtypes = tuple(native_schema(df) for df in dfs) - dtype_backend = promote_dtype_backend(dfs, self._implementation) + schemas = tuple(native_schema(df) for df in dfs) out_schema = reduce( lambda x, y: to_supertype(*combine_schemas(x, y)), - (Schema.from_pandas_like(dtype) for dtype in dtypes), - ).to_pandas(dtype_backend=dtype_backend.values()) + (Schema.from_pandas_like(schema) for schema in schemas), + ).to_pandas(promote_dtype_backends(schemas, self._implementation)) native_res = ( self._concat(dfs, axis=VERTICAL, copy=False) @@ -316,11 +315,10 @@ def _concat_vertical(self, dfs: Sequence[NativeDataFrameT], /) -> NativeDataFram def _concat_vertical_relaxed( self, dfs: Sequence[NativeDataFrameT], / ) -> NativeDataFrameT: - dtypes = tuple(native_schema(df) for df in dfs) - dtype_backend = promote_dtype_backend(dfs, self._implementation) + schemas = tuple(native_schema(df) for df in dfs) out_schema = reduce( - to_supertype, (Schema.from_pandas_like(dtype) for dtype in dtypes) - ).to_pandas(dtype_backend=dtype_backend.values()) + to_supertype, (Schema.from_pandas_like(schema) for schema in schemas) + ).to_pandas(promote_dtype_backends(schemas, self._implementation)) if self._implementation.is_pandas() and self._backend_version < (3,): return self._concat( diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 3a3869164b..7870446658 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -2,7 +2,7 @@ import operator import re -from functools import lru_cache, partial +from functools import lru_cache from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, cast import pandas as pd @@ -32,13 +32,11 @@ from collections.abc import Iterable, Iterator, Mapping from types import ModuleType - from dask.dataframe import DataFrame as NativeDaskDataFrame from pandas._typing import Dtype as PandasDtype from pandas.core.dtypes.dtypes import BaseMaskedDtype from typing_extensions import TypeAlias, TypeIs from narwhals._duration import IntervalUnit - from narwhals._native import NativePandasLikeDataFrame from narwhals._pandas_like.expr import PandasLikeExpr from narwhals._pandas_like.series import PandasLikeSeries from narwhals._pandas_like.typing import ( @@ -723,30 +721,32 @@ def make_group_by_kwargs(*, drop_null_keys: bool) -> dict[str, bool]: } -def promote_dtype_backend( - dataframes: Iterable[NativePandasLikeDataFrame] | Iterable[NativeDaskDataFrame], - implementation: Implementation, -) -> dict[str, DTypeBackend]: +def iter_names_dtype_backends( + schema: IntoPandasSchema, impl: Implementation, / +) -> Iterator[tuple[str, DTypeBackend]]: + yield from zip(schema, iter_dtype_backends(schema.values(), impl)) + + +def promote_dtype_backends( + schemas: Iterable[IntoPandasSchema], implementation: Implementation +) -> Iterable[DTypeBackend]: """Promote dtype backends for each column based on priority rules. Priority: pyarrow > numpy_nullable > None - - Returns: - Dictionary mapping column names to the promoted dtype backend """ - column_backends: dict[str, DTypeBackend] = {} - _get_dtype_backend_impl = partial(get_dtype_backend, implementation=implementation) - for df in dataframes: - for col in df.columns: - backend = _get_dtype_backend_impl(df[col].dtype) - current = column_backends.get(col) - if ( - current is None - or _DTYPE_BACKEND_PRIORITY[backend] > _DTYPE_BACKEND_PRIORITY[current] - ): - column_backends[col] = backend - - return column_backends + impl = implementation + it_schemas = iter(schemas) + col_backends = dict(iter_names_dtype_backends(next(it_schemas), impl)) + priority = _DTYPE_BACKEND_PRIORITY.__getitem__ + current = col_backends.__getitem__ + seen = col_backends.keys() + for schema in it_schemas: + col_backends.update( + (name, backend) + for name, backend in iter_names_dtype_backends(schema, impl) + if name not in seen or priority(backend) > priority(current(name)) + ) + return col_backends.values() def native_schema(df: Incomplete) -> IntoPandasSchema: From c832b72e5007d981434b7ca254fe8f98557d50ca Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Tue, 20 Jan 2026 14:45:23 +0100 Subject: [PATCH 16/20] combine_schemas -> merge_schemas --- narwhals/_arrow/namespace.py | 5 ++--- narwhals/_dask/namespace.py | 5 ++--- narwhals/_duckdb/namespace.py | 6 ++---- narwhals/_pandas_like/namespace.py | 5 ++--- narwhals/_spark_like/namespace.py | 6 ++---- narwhals/schema.py | 30 +++++++++++++++++++++++------- 6 files changed, 33 insertions(+), 24 deletions(-) diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index bc450265ad..42f98ec5e0 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -19,7 +19,7 @@ combine_evaluate_output_names, ) from narwhals._utils import Implementation -from narwhals.schema import Schema, combine_schemas, to_supertype +from narwhals.schema import Schema, merge_schemas, to_supertype if TYPE_CHECKING: from collections.abc import Iterator, Sequence @@ -180,8 +180,7 @@ def _concat_diagonal(self, dfs: Sequence[pa.Table], /) -> pa.Table: def _concat_diagonal_relaxed(self, dfs: Sequence[pa.Table], /) -> pa.Table: native_schemas = tuple(table.schema for table in dfs) out_schema = reduce( - lambda x, y: to_supertype(*combine_schemas(x, y)), - (Schema.from_arrow(pa_schema) for pa_schema in native_schemas), + merge_schemas, (Schema.from_arrow(pa_schema) for pa_schema in native_schemas) ).to_arrow() to_schemas = ( pa.schema([out_schema.field(name) for name in native_schema.names]) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 4a98511a9a..a49c1ae73a 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -24,7 +24,7 @@ ) from narwhals._pandas_like.utils import promote_dtype_backends from narwhals._utils import Implementation, zip_strict -from narwhals.schema import Schema, combine_schemas, to_supertype +from narwhals.schema import Schema, merge_schemas, to_supertype if TYPE_CHECKING: from collections.abc import Iterable, Iterator @@ -175,8 +175,7 @@ def concat( if how == "diagonal_relaxed": schemas = tuple(df.dtypes.to_dict() for df in dfs) out_schema = reduce( - lambda x, y: to_supertype(*combine_schemas(x, y)), - (Schema.from_pandas_like(schema) for schema in schemas), + merge_schemas, (Schema.from_pandas_like(schema) for schema in schemas) ).to_pandas(promote_dtype_backends(schemas, self._implementation)) native_res = dd.concat(dfs, axis=0, join="outer").astype(out_schema) diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 41c2e8d299..fb69bd9d46 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -27,7 +27,7 @@ ) from narwhals._sql.namespace import SQLNamespace from narwhals._utils import Implementation -from narwhals.schema import Schema, combine_schemas, to_supertype +from narwhals.schema import Schema, merge_schemas, to_supertype if TYPE_CHECKING: from collections.abc import Iterable @@ -116,9 +116,7 @@ def concat( if how == "diagonal_relaxed": schemas = [Schema(df.collect_schema()) for df in items] - out_schema = reduce( - lambda x, y: to_supertype(*combine_schemas(x, y)), schemas - ) + out_schema = reduce(merge_schemas, schemas) res, *others = ( item.select( *( diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index ce0f7dd538..47267af8c6 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -24,7 +24,7 @@ promote_dtype_backends, ) from narwhals._utils import zip_strict -from narwhals.schema import Schema, combine_schemas, to_supertype +from narwhals.schema import Schema, merge_schemas, to_supertype if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -269,8 +269,7 @@ def _concat_diagonal_relaxed( ) -> NativeDataFrameT: schemas = tuple(native_schema(df) for df in dfs) out_schema = reduce( - lambda x, y: to_supertype(*combine_schemas(x, y)), - (Schema.from_pandas_like(schema) for schema in schemas), + merge_schemas, (Schema.from_pandas_like(schema) for schema in schemas) ).to_pandas(promote_dtype_backends(schemas, self._implementation)) native_res = ( diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 275f82589a..923271a9d8 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -19,7 +19,7 @@ true_divide, ) from narwhals._sql.namespace import SQLNamespace -from narwhals.schema import Schema, combine_schemas, to_supertype +from narwhals.schema import Schema, merge_schemas, to_supertype if TYPE_CHECKING: from collections.abc import Iterable @@ -195,9 +195,7 @@ def concat( if how == "diagonal_relaxed": schemas = tuple(Schema(item.collect_schema()) for item in items) - out_schema = reduce( - lambda x, y: to_supertype(*combine_schemas(x, y)), schemas - ) + out_schema = reduce(merge_schemas, schemas) native_items = ( item.select( *( diff --git a/narwhals/schema.py b/narwhals/schema.py index 3f5ea8650a..aae3f5c370 100644 --- a/narwhals/schema.py +++ b/narwhals/schema.py @@ -393,14 +393,28 @@ def to_supertype(left: Schema, right: Schema) -> Schema: return Schema((name, _supertype(ltype, rtype)) for (name, ltype, rtype) in it) -def combine_schemas(left: Schema, right: Schema) -> tuple[Schema, Schema]: - """Extend both schemas with names and dtypes missing from the other. +def merge_schemas(left: Schema, right: Schema) -> Schema: + """Merge two schemas, combining columns and resolving types to their supertype. - Returns a tuple of two schemas where each original schema is extended - with the columns that exist in the other schema but not in itself. + This function merges two schemas by: - The final order for both schemas is: left schema keys first (in order), - followed by keys missing from left (in the order they appear in right). + 1. Taking all columns from the left schema (preserving order). + 2. Appending any columns from the right schema that are missing in the left. + 3. For columns present in both schemas, promoting to the common supertype. + + Arguments: + left: The primary schema whose column order takes precedence. + right: The secondary schema to merge. Columns missing in left are appended. + + Returns: + A new Schema with merged columns and supertypes resolved. + + Raises: + SchemaMismatchError: If no common supertype exists for a shared column. + + Note: + For columns present only in one schema, the type from that schema is used + (via `right.get(name, left_type)` fallback for right-only columns). """ left_names = set(left.keys()) missing_in_left = (kv for kv in right.items() if kv[0] not in left_names) @@ -408,4 +422,6 @@ def combine_schemas(left: Schema, right: Schema) -> tuple[Schema, Schema]: extended_left = Schema((*left.items(), *missing_in_left)) # Reorder right to match: left keys first, then right-only keys extended_right = Schema((kv[0], right.get(*kv)) for kv in extended_left.items()) - return extended_left, extended_right + + it = zip(extended_left.keys(), extended_left.values(), extended_right.values()) + return Schema((name, _supertype(ltype, rtype)) for (name, ltype, rtype) in it) From 4a1b946eeb3627812a0d44fd270f77837629da3c Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Fri, 30 Jan 2026 22:07:01 +0100 Subject: [PATCH 17/20] preserve unknown for lazy backends --- narwhals/_duckdb/namespace.py | 21 ++++++++------------- narwhals/_ibis/namespace.py | 9 +++------ narwhals/_spark_like/namespace.py | 12 ++++++------ narwhals/_utils.py | 20 +++++++++++++++++++- 4 files changed, 36 insertions(+), 26 deletions(-) diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index fb69bd9d46..adf3f558f1 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -26,7 +26,7 @@ combine_evaluate_output_names, ) from narwhals._sql.namespace import SQLNamespace -from narwhals._utils import Implementation +from narwhals._utils import Implementation, safe_cast from narwhals.schema import Schema, merge_schemas, to_supertype if TYPE_CHECKING: @@ -97,10 +97,7 @@ def concat( schemas: Iterable[Schema] = (Schema(df.collect_schema()) for df in items) out_schema = reduce(to_supertype, schemas) native_items = ( - item.select( - *(self.col(name).cast(dtype) for name, dtype in out_schema.items()) - ).native - for item in items + item.select(*safe_cast(self, out_schema)).native for item in items ) res = reduce(DuckDBPyRelation.union, native_items) return first._with_native(res) @@ -117,22 +114,20 @@ def concat( if how == "diagonal_relaxed": schemas = [Schema(df.collect_schema()) for df in items] out_schema = reduce(merge_schemas, schemas) - res, *others = ( + native_items = ( item.select( *( - self.col(name).cast(dtype) + self.col(name) if name in schema else self.lit(None, dtype=dtype).alias(name) for name, dtype in out_schema.items() ) - ).native + ) + .select(*safe_cast(self, out_schema)) + .native for item, schema in zip(items, schemas) ) - for _item in others: - # TODO(unassigned): use relational API when available https://github.com/duckdb/duckdb/discussions/16996 - res = duckdb.sql(""" - from res select * union all by name from _item select * - """) + res = reduce(DuckDBPyRelation.union, native_items) return first._with_native(res) raise NotImplementedError diff --git a/narwhals/_ibis/namespace.py b/narwhals/_ibis/namespace.py index a043058474..e267050e8c 100644 --- a/narwhals/_ibis/namespace.py +++ b/narwhals/_ibis/namespace.py @@ -18,7 +18,7 @@ from narwhals._ibis.selectors import IbisSelectorNamespace from narwhals._ibis.utils import function, lit, narwhals_to_native_dtype from narwhals._sql.namespace import SQLNamespace -from narwhals._utils import Implementation +from narwhals._utils import Implementation, safe_cast from narwhals.schema import Schema, to_supertype if TYPE_CHECKING: @@ -74,11 +74,8 @@ def concat( if how.endswith("relaxed"): schemas = (Schema(frame.collect_schema()) for frame in frames) - out_schema = reduce(to_supertype, schemas).items() - frames = [ - frame.select(*(self.col(name).cast(dtype) for name, dtype in out_schema)) - for frame in frames - ] + out_schema = reduce(to_supertype, schemas) + frames = [frame.select(*safe_cast(self, out_schema)) for frame in frames] try: result = ibis.union(*(lf.native for lf in frames)) except ibis.IbisError: diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 923271a9d8..5c9abcb866 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -19,6 +19,7 @@ true_divide, ) from narwhals._sql.namespace import SQLNamespace +from narwhals._utils import safe_cast from narwhals.schema import Schema, merge_schemas, to_supertype if TYPE_CHECKING: @@ -181,10 +182,7 @@ def concat( schemas: Iterable[Schema] = (Schema(df.collect_schema()) for df in items) out_schema = reduce(to_supertype, schemas) native_items = ( - item.select( - *(self.col(name).cast(dtype) for name, dtype in out_schema.items()) - ).native - for item in items + item.select(*safe_cast(self, out_schema)).native for item in items ) union = items[0].native.__class__.union return SparkLikeLazyFrame( @@ -199,12 +197,14 @@ def concat( native_items = ( item.select( *( - self.col(name).cast(dtype) + self.col(name) if name in schema else self.lit(None, dtype=dtype).alias(name) for name, dtype in out_schema.items() ) - ).native + ) + .select(*safe_cast(self, out_schema)) + .native for item, schema in zip(items, schemas) ) union = items[0].native.__class__.union diff --git a/narwhals/_utils.py b/narwhals/_utils.py index a02552e90a..a78c6adfac 100644 --- a/narwhals/_utils.py +++ b/narwhals/_utils.py @@ -68,7 +68,14 @@ TypeIs, ) - from narwhals._compliant import CompliantExprT, CompliantSeriesT, NativeSeriesT_co + from narwhals import dtypes + from narwhals._compliant import ( + CompliantExprT, + CompliantFrameT, + CompliantNamespace, + CompliantSeriesT, + NativeSeriesT_co, + ) from narwhals._compliant.any_namespace import NamespaceAccessor from narwhals._compliant.typing import ( Accessor, @@ -2148,3 +2155,14 @@ def __repr__(self) -> str: # pragma: no cover # Can be imported from types in Python 3.10 NoneType = type(None) + + +def safe_cast( + ns: CompliantNamespace[CompliantFrameT, CompliantExprT], + mapping: Mapping[str, dtypes.DType], +) -> Iterable[CompliantExprT]: + Unknown = ns._version.dtypes.Unknown() # noqa: N806 + return ( + ns.col(name) if dtype == Unknown else ns.col(name).cast(dtype) + for name, dtype in mapping.items() + ) From b205ddb50c2d391efd7191bfaaf24389f8eeeb2e Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sat, 31 Jan 2026 13:20:56 +0100 Subject: [PATCH 18/20] preserve original dtype if not supported by narwhals --- narwhals/_arrow/namespace.py | 42 ++++++++++++----- tests/frame/concat_test.py | 91 ++++++++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 11 deletions(-) diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 519b5b27fd..85e11ee2c0 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -1,7 +1,7 @@ from __future__ import annotations import operator -from functools import reduce +from functools import partial, reduce from itertools import chain from typing import TYPE_CHECKING, Literal @@ -18,7 +18,7 @@ combine_alias_output_names, combine_evaluate_output_names, ) -from narwhals._utils import Implementation +from narwhals._utils import Implementation, safe_cast from narwhals.schema import Schema, merge_schemas, to_supertype if TYPE_CHECKING: @@ -181,15 +181,27 @@ def _concat_diagonal_relaxed(self, dfs: Sequence[pa.Table], /) -> pa.Table: native_schemas = tuple(table.schema for table in dfs) out_schema = reduce( merge_schemas, (Schema.from_arrow(pa_schema) for pa_schema in native_schemas) - ).to_arrow() + ) to_schemas = ( - pa.schema([out_schema.field(name) for name in native_schema.names]) + { + name: dtype + for name, dtype in out_schema.items() + if name in native_schema.names + } for native_schema in native_schemas ) - to_concat = tuple( - table.cast(to_schema) for table, to_schema in zip(dfs, to_schemas) + version = self._version + to_compliant = partial( + self._dataframe, + version=version, + validate_backend_version=False, + validate_column_names=False, + ) + tables = tuple( + to_compliant(tbl).select(*safe_cast(self, to_schema)).native + for tbl, to_schema in zip(dfs, to_schemas) ) - return self._concat_diagonal(to_concat) + return self._concat_diagonal(tables) def _concat_horizontal(self, dfs: Sequence[pa.Table], /) -> pa.Table: names = list(chain.from_iterable(df.column_names for df in dfs)) @@ -210,11 +222,19 @@ def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table: return pa.concat_tables(dfs) def _concat_vertical_relaxed(self, dfs: Sequence[pa.Table], /) -> pa.Table: - out_schema = reduce( - to_supertype, (Schema.from_arrow(table.schema) for table in dfs) - ).to_arrow() + out_schema = reduce(to_supertype, (Schema.from_arrow(tbl.schema) for tbl in dfs)) + version = self._version + to_compliant = partial( + self._dataframe, + version=version, + validate_backend_version=False, + validate_column_names=False, + ) + tables = ( + to_compliant(tbl).select(*safe_cast(self, out_schema)).native for tbl in dfs + ) - return pa.concat_tables([table.cast(out_schema) for table in dfs]) + return pa.concat_tables(tables) @property def selectors(self) -> ArrowSelectorNamespace: diff --git a/tests/frame/concat_test.py b/tests/frame/concat_test.py index 7b91fe26c1..9367f55092 100644 --- a/tests/frame/concat_test.py +++ b/tests/frame/concat_test.py @@ -3,6 +3,7 @@ import datetime as dt import re from typing import TYPE_CHECKING, Any +from uuid import uuid4 as make_uuid import pytest @@ -259,3 +260,93 @@ def test_concat_diagonal_relaxed( "b": [None, None, 30.1, 40.2, 5.0, 6.0], } assert_equal_data(result.collect(), expected_data) + + +def test_pyarrow_concat_vertical_uuid() -> None: + # Test that concat vertical and vertical_relaxed preserves unsupported types like UUID + pa = pytest.importorskip("pyarrow") + + id1, id2, id3, id4 = [make_uuid() for _ in range(4)] + + data1 = {"id": pa.array([id1.bytes, id2.bytes], type=pa.uuid()), "a": [1, 2]} + data2 = {"id": pa.array([id3.bytes, id4.bytes], type=pa.uuid()), "a": [3.14, 4.2]} + frame1, frame2 = nw.from_native(pa.table(data1)), nw.from_native(pa.table(data2)) + + # Vertical + res_v = nw.concat([frame1, frame1], how="vertical").to_native() + + assert res_v.schema.field("id").type == pa.uuid() + assert res_v["id"].to_pylist() == [id1, id2, id1, id2] + assert res_v["a"].to_pylist() == [1, 2, 1, 2] + + # Vertical relaxed + res_vr = nw.concat([frame1, frame2], how="vertical_relaxed").to_native() + + assert res_vr.schema.field("id").type == pa.uuid() + assert res_vr["id"].to_pylist() == [id1, id2, id3, id4] + assert res_vr["a"].to_pylist() == [1.0, 2.0, 3.14, 4.2] + + +def test_concat_vertical_relaxed_duckdb_uuid() -> None: + # Test that concat vertical_relaxed preserves UUID type for DuckDB + duckdb = pytest.importorskip("duckdb") + + id1, id2, id3, id4 = [make_uuid() for _ in range(4)] + conn = duckdb.connect() + + rel1 = conn.sql( + f"SELECT '{id1}'::UUID as id, 1 as a UNION ALL SELECT '{id2}'::UUID, 2" + ) + rel2 = conn.sql( + f"SELECT '{id3}'::UUID as id, 3.14 as a UNION ALL SELECT '{id4}'::UUID, 4.2" + ) + + frame1, frame2 = nw.from_native(rel1), nw.from_native(rel2) + + # Vertical + res_v = nw.concat([frame1, frame1], how="vertical") + + assert res_v.to_native().types[0] == "UUID" + res_v_pa = res_v.collect().to_native() + assert res_v_pa["id"].to_pylist() == [str(v) for v in (id1, id2, id1, id2)] + assert res_v_pa["a"].to_pylist() == [1, 2, 1, 2] + + # Vertical relaxed + res_vr = nw.concat([frame1, frame2], how="vertical_relaxed") + + assert res_vr.to_native().types[0] == "UUID" + res_vr_pa = res_vr.collect().to_native() + assert res_vr_pa["id"].to_pylist() == [str(v) for v in (id1, id2, id3, id4)] + assert [float(v) for v in res_vr_pa["a"].to_pylist()] == [1.0, 2.0, 3.14, 4.2] + + +def test_concat_vertical_relaxed_ibis_uuid() -> None: + """Test that concat vertical_relaxed preserves UUID type for Ibis.""" + ibis = pytest.importorskip("ibis") + + id1, id2, id3, id4 = [make_uuid() for _ in range(4)] + + t1 = ibis.memtable( + {"id": [str(id1), str(id2)], "a": [1, 2]}, schema={"id": "uuid", "a": "int"} + ) + t2 = ibis.memtable( + {"id": [str(id3), str(id4)], "a": [3.14, 4.2]}, + schema={"id": "uuid", "a": "float"}, + ) + frame1, frame2 = nw.from_native(t1), nw.from_native(t2) + + # Vertical + res_v = nw.concat([frame1, frame1], how="vertical") + + assert res_v.to_native().schema()["id"] == ibis.dtype("uuid") + res_v_pa = res_v.collect().to_native() + assert res_v_pa["id"].to_pylist() == [str(v) for v in (id1, id2, id1, id2)] + assert res_v_pa["a"].to_pylist() == [1, 2, 1, 2] + + # Vertical relaxed + res_vr = nw.concat([frame1, frame2], how="vertical_relaxed") + + assert res_vr.to_native().schema()["id"] == ibis.dtype("uuid") + res_vr_pa = res_vr.collect().to_native() + assert res_vr_pa["id"].to_pylist() == [str(v) for v in (id1, id2, id3, id4)] + assert [float(v) for v in res_vr_pa["a"].to_pylist()] == [1.0, 2.0, 3.14, 4.2] From 61205dd06f840cece28645d91b9001df9ecda99d Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sat, 31 Jan 2026 13:28:28 +0100 Subject: [PATCH 19/20] require pyarrow 19 --- tests/frame/concat_test.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/frame/concat_test.py b/tests/frame/concat_test.py index 9367f55092..e6fd18e554 100644 --- a/tests/frame/concat_test.py +++ b/tests/frame/concat_test.py @@ -11,7 +11,13 @@ from narwhals._utils import Implementation from narwhals.exceptions import InvalidOperationError, NarwhalsError from narwhals.schema import Schema -from tests.utils import POLARS_VERSION, Constructor, ConstructorEager, assert_equal_data +from tests.utils import ( + POLARS_VERSION, + PYARROW_VERSION, + Constructor, + ConstructorEager, + assert_equal_data, +) if TYPE_CHECKING: from collections.abc import Iterator @@ -262,6 +268,7 @@ def test_concat_diagonal_relaxed( assert_equal_data(result.collect(), expected_data) +@pytest.mark.skipif(PYARROW_VERSION < (19, 0, 0)) def test_pyarrow_concat_vertical_uuid() -> None: # Test that concat vertical and vertical_relaxed preserves unsupported types like UUID pa = pytest.importorskip("pyarrow") From b3576dcc69b88664dd3e61f11119a9db4301359f Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sat, 31 Jan 2026 13:33:32 +0100 Subject: [PATCH 20/20] add reason --- tests/frame/concat_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/frame/concat_test.py b/tests/frame/concat_test.py index e6fd18e554..83d34d9c7c 100644 --- a/tests/frame/concat_test.py +++ b/tests/frame/concat_test.py @@ -268,7 +268,7 @@ def test_concat_diagonal_relaxed( assert_equal_data(result.collect(), expected_data) -@pytest.mark.skipif(PYARROW_VERSION < (19, 0, 0)) +@pytest.mark.skipif(PYARROW_VERSION < (19, 0, 0), reason="Too old for pyarrow.uuid type") def test_pyarrow_concat_vertical_uuid() -> None: # Test that concat vertical and vertical_relaxed preserves unsupported types like UUID pa = pytest.importorskip("pyarrow")