-
Notifications
You must be signed in to change notification settings - Fork 190
refactor: Simplify _integer_fits_in_decimal; disallow supercasting for Datetime and Duration with different time_unit's
#3526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dtypes/supertyping
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -15,7 +15,6 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from operator import attrgetter | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import TYPE_CHECKING, Any | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from narwhals._constants import MS_PER_SECOND, NS_PER_SECOND, US_PER_SECOND | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from narwhals._dispatch import just_dispatch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from narwhals._typing_compat import TypeVar | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from narwhals.dtypes._classes import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -58,7 +57,6 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing_extensions import TypeAlias, TypeIs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from narwhals.dtypes._classes import _Bits | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from narwhals.typing import TimeUnit | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _Fn = TypeVar("_Fn", bound=Callable[..., Any]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -95,6 +93,7 @@ def frozen_dtypes(*dtypes: type[DType]) -> FrozenDTypes: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - Pairwise comparisons, but order (of classes) is not important | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| DEC128_MAX_PREC = 38 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| SIGNED_INTEGER: DTypeGroup = frozenset((Int8, Int16, Int32, Int64, Int128)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| UNSIGNED_INTEGER: DTypeGroup = frozenset((UInt8, UInt16, UInt32, UInt64, UInt128)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -111,18 +110,6 @@ def frozen_dtypes(*dtypes: type[DType]) -> FrozenDTypes: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _TIME_UNIT_PER_SECOND: Mapping[TimeUnit, int] = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "s": 1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "ms": MS_PER_SECOND, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "us": US_PER_SECOND, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "ns": NS_PER_SECOND, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _key_fn_time_unit(obj: Datetime | Duration, /) -> int: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return _TIME_UNIT_PER_SECOND[obj.time_unit] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @lru_cache(maxsize=_CACHE_SIZE // 2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def dtype_eq(left: DType, right: DType, /) -> bool: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return left == right | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -191,13 +178,6 @@ def same_supertype(left: DType, right: DType, /) -> DType | None: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return left if dtype_eq(left, right) else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @same_supertype.register(Duration, DurationV1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @lru_cache(maxsize=_CACHE_SIZE * 2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def downcast_time_unit(left: SameTemporalT, right: SameTemporalT, /) -> SameTemporalT: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Return the operand with the lowest precision time unit.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return min(left, right, key=_key_fn_time_unit) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _struct_fields_union( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| left: Collection[Field], right: Collection[Field], / | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> Struct | None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -251,15 +231,6 @@ def list_supertype(left: List, right: List, /) -> List | None: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @same_supertype.register(Datetime, DatetimeV1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def datetime_supertype( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| left: SameDatetimeT, right: SameDatetimeT, / | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> SameDatetimeT | None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if left.time_zone != right.time_zone: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return downcast_time_unit(left, right) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @same_supertype.register(Enum) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def enum_supertype(left: Enum, right: Enum, /) -> Enum | None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return left if left.categories == right.categories else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -273,39 +244,24 @@ def decimal_supertype(left: Decimal, right: Decimal, /) -> Decimal: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return Decimal(precision=precision, scale=scale) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| DEC128_MAX_PREC = 38 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Precomputing powers of 10 up to 10^38 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| POW10_LIST = tuple(10**i for i in range(DEC128_MAX_PREC + 1)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| INT_MAX_MAP: Mapping[Int, int] = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| UInt8(): (2**8) - 1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| UInt16(): (2**16) - 1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| UInt32(): (2**32) - 1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| UInt64(): (2**64) - 1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Int8(): (2**7) - 1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Int16(): (2**15) - 1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Int32(): (2**31) - 1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Int64(): (2**63) - 1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _integer_fits_in_decimal(value: int, precision: int, scale: int) -> bool: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Scales an integer and checks if it fits the target precision.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # !NOTE: Indexing is safe since `scale <= precision <= 38` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return (precision == DEC128_MAX_PREC) or ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| value * POW10_LIST[scale] < POW10_LIST[precision] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return (precision == DEC128_MAX_PREC) or (value * (10**scale) < (10**precision)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
247
to
+250
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Sorry for not being clear here! Edit: ffs, reading this back really looks like AI. How big are they?AFAICT, these are all the possible inputs each parameter can have. value: Literal[127, ..., 340282366920938463463374607431768211455]
# ^^^ 8 hidden values
precision: Literal[1, ..., 38]
scale: Literal[1, ..., 38] # TIL: polars seems to allow `0`So this makes the worst-case 😨: (
34028236692093846346337460743176821145500000000000000000000000000000000000000
< 100000000000000000000000000000000000000
)SuggestionShow
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @cache | |
| def _integer_supertyping() -> Mapping[FrozenDTypes, type[Int | Float64]]: | |
| """Generate the supertype conversion table for all integer data type pairs.""" | |
| tps_int = SignedIntegerType.__subclasses__() | |
| tps_uint = UnsignedIntegerType.__subclasses__() | |
| get_bits: attrgetter[_Bits] = attrgetter("_bits") | |
| ints = ( | |
| (frozen_dtypes(lhs, rhs), max(lhs, rhs, key=get_bits)) | |
| for lhs, rhs in product(tps_int, repeat=2) | |
| ) | |
| uints = ( | |
| (frozen_dtypes(lhs, rhs), max(lhs, rhs, key=get_bits)) | |
| for lhs, rhs in product(tps_uint, repeat=2) | |
| ) | |
| # NOTE: `Float64` is here because `mypy` refuses to respect the last overload 😭 | |
| # https://github.com/python/typeshed/blob/a564787bf23386e57338b750bf4733f3c978b701/stdlib/typing.pyi#L776-L781 | |
| ubits_to_int: Mapping[_Bits, type[Int | Float64]] = {8: Int16, 16: Int32, 32: Int64} | |
| mixed = ( | |
| ( | |
| frozen_dtypes(int_, uint), | |
| int_ if int_._bits > uint._bits else ubits_to_int.get(uint._bits, Float64), | |
| ) | |
| for int_, uint in product(tps_int, tps_uint) | |
| ) | |
| return dict(chain(ints, uints, mixed)) | |
| @cache | |
| def _primitive_numeric_supertyping() -> Mapping[FrozenDTypes, type[Float]]: | |
| """Generate the supertype conversion table for all (integer, float) data type pairs.""" | |
| F32, F64 = Float32, Float64 # noqa: N806 | |
| small_int = (Int8, Int16, UInt8, UInt16) | |
| small_int_f32 = ((frozen_dtypes(tp, F32), F32) for tp in small_int) | |
| big_int_f32 = ((frozen_dtypes(tp, F32), F64) for tp in INTEGER.difference(small_int)) | |
| int_f64 = ((frozen_dtypes(tp, F64), F64) for tp in INTEGER) | |
| return dict(chain(small_int_f32, big_int_f32, int_f64)) |
Iff we need it, the result is cached and then we reuse from there 🙂
So the suggestion was just to move the global dict/list into a function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.

Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suppose this comment may be relevant again depending on (#3526 (comment))