Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 8 additions & 52 deletions narwhals/dtypes/_supertyping.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from operator import attrgetter
from typing import TYPE_CHECKING, Any

from narwhals._constants import MS_PER_SECOND, NS_PER_SECOND, US_PER_SECOND
from narwhals._dispatch import just_dispatch
from narwhals._typing_compat import TypeVar
from narwhals.dtypes._classes import (
Expand Down Expand Up @@ -58,7 +57,6 @@
from typing_extensions import TypeAlias, TypeIs

from narwhals.dtypes._classes import _Bits
from narwhals.typing import TimeUnit

_Fn = TypeVar("_Fn", bound=Callable[..., Any])

Expand Down Expand Up @@ -95,6 +93,7 @@ def frozen_dtypes(*dtypes: type[DType]) -> FrozenDTypes:
- Pairwise comparisons, but order (of classes) is not important
"""

DEC128_MAX_PREC = 38

SIGNED_INTEGER: DTypeGroup = frozenset((Int8, Int16, Int32, Int64, Int128))
UNSIGNED_INTEGER: DTypeGroup = frozenset((UInt8, UInt16, UInt32, UInt64, UInt128))
Expand All @@ -111,18 +110,6 @@ def frozen_dtypes(*dtypes: type[DType]) -> FrozenDTypes:
}


_TIME_UNIT_PER_SECOND: Mapping[TimeUnit, int] = {
"s": 1,
"ms": MS_PER_SECOND,
"us": US_PER_SECOND,
"ns": NS_PER_SECOND,
}


def _key_fn_time_unit(obj: Datetime | Duration, /) -> int:
return _TIME_UNIT_PER_SECOND[obj.time_unit]


@lru_cache(maxsize=_CACHE_SIZE // 2)
def dtype_eq(left: DType, right: DType, /) -> bool:
return left == right
Expand Down Expand Up @@ -191,13 +178,6 @@ def same_supertype(left: DType, right: DType, /) -> DType | None:
return left if dtype_eq(left, right) else None


@same_supertype.register(Duration, DurationV1)
@lru_cache(maxsize=_CACHE_SIZE * 2)
def downcast_time_unit(left: SameTemporalT, right: SameTemporalT, /) -> SameTemporalT:
"""Return the operand with the lowest precision time unit."""
return min(left, right, key=_key_fn_time_unit)


def _struct_fields_union(
left: Collection[Field], right: Collection[Field], /
) -> Struct | None:
Expand Down Expand Up @@ -251,15 +231,6 @@ def list_supertype(left: List, right: List, /) -> List | None:
return None


@same_supertype.register(Datetime, DatetimeV1)
def datetime_supertype(
left: SameDatetimeT, right: SameDatetimeT, /
) -> SameDatetimeT | None:
if left.time_zone != right.time_zone:
return None
return downcast_time_unit(left, right)


@same_supertype.register(Enum)
def enum_supertype(left: Enum, right: Enum, /) -> Enum | None:
return left if left.categories == right.categories else None
Expand All @@ -273,39 +244,24 @@ def decimal_supertype(left: Decimal, right: Decimal, /) -> Decimal:
return Decimal(precision=precision, scale=scale)


DEC128_MAX_PREC = 38
# Precomputing powers of 10 up to 10^38
POW10_LIST = tuple(10**i for i in range(DEC128_MAX_PREC + 1))
INT_MAX_MAP: Mapping[Int, int] = {
UInt8(): (2**8) - 1,
UInt16(): (2**16) - 1,
UInt32(): (2**32) - 1,
UInt64(): (2**64) - 1,
Int8(): (2**7) - 1,
Int16(): (2**15) - 1,
Int32(): (2**31) - 1,
Int64(): (2**63) - 1,
}


def _integer_fits_in_decimal(value: int, precision: int, scale: int) -> bool:
"""Scales an integer and checks if it fits the target precision."""
# !NOTE: Indexing is safe since `scale <= precision <= 38`
Copy link
Copy Markdown
Member

@dangotbanned dangotbanned Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# !NOTE: Indexing is safe since scale <= precision <= 38

I suppose this comment may be relevant again depending on (#3526 (comment))

return (precision == DEC128_MAX_PREC) or (
value * POW10_LIST[scale] < POW10_LIST[precision]
)
return (precision == DEC128_MAX_PREC) or (value * (10**scale) < (10**precision))
Comment on lines 247 to +250
Copy link
Copy Markdown
Member

@dangotbanned dangotbanned Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#3396 (comment)

(1) Could we be lazy-er?

I would prefer if we defer generating this until it is needed.

E.g. I'd expect _integer_supertyping and _primitive_numeric_supertyping to be more commonly used - but even they don't exist at module-import-time

Sorry for not being clear here!
I think you are right to pre-compute these numbers.

Edit: ffs, reading this back really looks like AI.
Gonna leave it as-is, but I promise I wrote that
😭

How big are they?

AFAICT, these are all the possible inputs each parameter can have.

value: Literal[127, ..., 340282366920938463463374607431768211455]
#                   ^^^ 8 hidden values
precision: Literal[1, ..., 38]
scale: Literal[1, ..., 38]  # TIL: polars seems to allow `0`

So this makes the worst-case 😨:

(
    34028236692093846346337460743176821145500000000000000000000000000000000000000
    < 100000000000000000000000000000000000000
)

Suggestion

Show _{integer,primitive_numeric}_supertyping

@cache
def _integer_supertyping() -> Mapping[FrozenDTypes, type[Int | Float64]]:
"""Generate the supertype conversion table for all integer data type pairs."""
tps_int = SignedIntegerType.__subclasses__()
tps_uint = UnsignedIntegerType.__subclasses__()
get_bits: attrgetter[_Bits] = attrgetter("_bits")
ints = (
(frozen_dtypes(lhs, rhs), max(lhs, rhs, key=get_bits))
for lhs, rhs in product(tps_int, repeat=2)
)
uints = (
(frozen_dtypes(lhs, rhs), max(lhs, rhs, key=get_bits))
for lhs, rhs in product(tps_uint, repeat=2)
)
# NOTE: `Float64` is here because `mypy` refuses to respect the last overload 😭
# https://github.com/python/typeshed/blob/a564787bf23386e57338b750bf4733f3c978b701/stdlib/typing.pyi#L776-L781
ubits_to_int: Mapping[_Bits, type[Int | Float64]] = {8: Int16, 16: Int32, 32: Int64}
mixed = (
(
frozen_dtypes(int_, uint),
int_ if int_._bits > uint._bits else ubits_to_int.get(uint._bits, Float64),
)
for int_, uint in product(tps_int, tps_uint)
)
return dict(chain(ints, uints, mixed))
@cache
def _primitive_numeric_supertyping() -> Mapping[FrozenDTypes, type[Float]]:
"""Generate the supertype conversion table for all (integer, float) data type pairs."""
F32, F64 = Float32, Float64 # noqa: N806
small_int = (Int8, Int16, UInt8, UInt16)
small_int_f32 = ((frozen_dtypes(tp, F32), F32) for tp in small_int)
big_int_f32 = ((frozen_dtypes(tp, F32), F64) for tp in INTEGER.difference(small_int))
int_f64 = ((frozen_dtypes(tp, F64), F64) for tp in INTEGER)
return dict(chain(small_int_f32, big_int_f32, int_f64))

I referenced these guys because they do work to generate a dictionary - which may never end up being used.

Iff we need it, the result is cached and then we reuse from there 🙂

So the suggestion was just to move the global dict/list into a function



def _decimal_integer_supertyping(decimal: Decimal, integer: Int) -> DType | None:
precision, scale = decimal.precision, decimal.scale

if integer in {UInt128(), Int128()}:
fits_orig_prec_scale = False
elif value := INT_MAX_MAP.get(integer, None):
else:
bits: int = integer._bits
if isinstance(integer, SignedIntegerType):
bits = bits - 1

value = (1 << bits) - 1
Comment on lines +258 to +263
Copy link
Copy Markdown
Member

@dangotbanned dangotbanned Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Literal math is so cool 😄

Really didn't expect it to work for numbers this large (or bit-shifting tbh)

image

fits_orig_prec_scale = _integer_fits_in_decimal(value, precision, scale)
else: # pragma: no cover
msg = "Unreachable integer type"
raise ValueError(msg)

precision = precision if fits_orig_prec_scale else DEC128_MAX_PREC
return Decimal(precision, scale)
Expand Down
16 changes: 8 additions & 8 deletions tests/dtypes/get_supertype_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def test_identical_dtype(dtype: DType) -> None:
@pytest.mark.parametrize(
("left", "right", "expected"),
[
(nw.Datetime("ns"), nw.Datetime("us"), nw.Datetime("us")),
(nw.Datetime("s"), nw.Datetime("us"), nw.Datetime("s")),
(nw.Datetime("ns"), nw.Datetime("us"), None),
(nw.Datetime("s"), nw.Datetime("us"), None),
(nw.Datetime("s"), nw.Datetime("s", "Africa/Accra"), None),
(nw.Datetime(time_zone="Asia/Kathmandu"), nw.Datetime(), None),
(
Expand Down Expand Up @@ -121,7 +121,7 @@ def test_same_class(left: DType, right: DType, expected: DType | None) -> None:
[
(
{"f0": nw.Duration("ms"), "f1": nw.Int64, "f2": nw.Int64},
{"f0": nw.Duration("us"), "f1": nw.Int64()},
{"f0": nw.Duration("ms"), "f1": nw.Int64()},
{"f0": nw.Duration("ms"), "f1": nw.Int64(), "f2": nw.Int64()},
),
(
Expand Down Expand Up @@ -330,7 +330,7 @@ def test_numeric_and_bool_promotion(numeric_dtype: NumericType) -> None:
("left", "right", "expected"),
[
(nw_v1.Datetime(), nw_v1.Datetime(), nw_v1.Datetime()),
(nw_v1.Datetime("ns"), nw_v1.Datetime("s"), nw_v1.Datetime("s")),
(nw_v1.Datetime("ns"), nw_v1.Datetime("s"), None),
(
nw_v1.Datetime(time_zone="Europe/Berlin"),
nw_v1.Datetime(time_zone="Europe/Berlin"),
Expand All @@ -339,13 +339,13 @@ def test_numeric_and_bool_promotion(numeric_dtype: NumericType) -> None:
(
nw_v1.Datetime(time_zone="Europe/Berlin"),
nw_v1.Datetime("ms", "Europe/Berlin"),
nw_v1.Datetime("ms", "Europe/Berlin"),
None,
),
(nw_v1.Datetime(time_zone="Europe/Berlin"), nw_v1.Datetime(), None),
(nw_v1.Datetime("s"), nw_v1.Datetime("s", "Africa/Accra"), None),
(nw_v1.Duration("ns"), nw_v1.Duration("ms"), nw_v1.Duration("ms")),
(nw_v1.Duration("ns"), nw_v1.Duration("ms"), None),
(nw_v1.Duration(), nw_v1.Duration(), nw_v1.Duration()),
(nw_v1.Duration("s"), nw_v1.Duration(), nw_v1.Duration("s")),
(nw_v1.Duration("s"), nw_v1.Duration(), None),
(nw_v1.Duration(), nw_v1.Datetime(), None),
(nw_v1.Enum(), nw_v1.Enum(), nw_v1.Enum()),
(nw_v1.Enum(), nw_v1.String(), nw_v1.String()),
Expand All @@ -356,7 +356,7 @@ def test_numeric_and_bool_promotion(numeric_dtype: NumericType) -> None:
),
(
nw.Struct({"f0": nw_v1.Duration("ms"), "f1": nw.Int64, "f2": nw.Int64}),
nw.Struct({"f0": nw_v1.Duration("us"), "f1": nw.Int64()}),
nw.Struct({"f0": nw_v1.Duration("ms"), "f1": nw.Int64()}),
nw.Struct({"f0": nw_v1.Duration("ms"), "f1": nw.Int64(), "f2": nw.Int64()}),
),
(
Expand Down
33 changes: 18 additions & 15 deletions utils/promotion-rules.md.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -142,39 +142,42 @@ https://github.com/narwhals-dev/narwhals/pull/3377

### Duration

Two `Duration` types always have a supertype, namely the type with the **less precise** (coarser) time unit.
For example:
Two `Duration` types have a supertype only if share the **same time unit** (hence are the same):

```python exec="1" session="promotion-rules" result="python"
st(nw.Duration('us'), nw.Duration('us'))
st(nw.Duration('us'), nw.Duration('ms'))
st(nw.Duration('s'), nw.Duration('ms'))
```

Time unit precision order (from coarsest to finest): `s` < `ms` < `us` < `ns`
!!! warning "Difference with Polars"

Polars promotes two `Duration` types with different time units to the **less precise** (coarser) one,
while other backends, such as pandas, promote to the **most precise** (finest) one.

Since these two behaviors are contradictory, Narwhals does not attempt to reconcile them and instead
returns no supertype when the time units differ.

### Datetime

Two `Datetime` types have a supertype only if they share the **same time zone**:
Two `Datetime` types have a supertype only if they are the same, hence
if they share both the **same time zone** and the **same time unit**:

```python exec="1" session="promotion-rules" result="python"
st(nw.Datetime('us'), nw.Datetime('ns'))
st(nw.Datetime('us'), nw.Datetime('us'))

tz = "Europe/Berlin"
print(f"{tz = !r}")
st(nw.Datetime(time_zone=tz), nw.Datetime(time_zone=tz))
```

The resulting time unit is the **less precise** (coarser) of the two as defined in the previous section on `Duration`.
!!! warning "Difference with Polars"

If they do not share the same time zone, no supertype exists:
Polars promotes two `Datetime` types with different time units to the **less precise** (coarser) one,
while other backends, such as pandas, promote to the **most precise** (finest) one.

Since these two behaviors are contradictory, Narwhals does not attempt to reconcile them and instead
returns no supertype when the time units differ.

```python exec="1" session="promotion-rules" result="python"
tz1 = "Europe/Berlin"
tz2 = "Europe/Paris"
print(f"{tz1 = !r}")
print(f"{tz2 = !r}")
st(nw.Datetime(time_zone=tz1), nw.Datetime(time_zone=tz2))
```

### Datetime and Date

Expand Down
Loading