Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ repos:
language: unsupported
types: [python]

- id: local-mypy
name: mypy check
entry: uv run mypy sqlmodel tests/test_select_typing.py
- id: local-ty
name: ty check
entry: uv run ty check sqlmodel
require_serial: true
language: unsupported
pass_filenames: false
Expand Down
12 changes: 1 addition & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ tests = [
"fastapi>=0.128.0",
"httpx==0.28.1",
"jinja2==3.1.6",
"mypy==1.19.1",
"pre-commit>=2.17.0,<5.0.0",
"pytest>=7.0.1,<10.0.0",
"ruff==0.15.5",
"ty>=0.0.9",
"typing-extensions==4.15.0",
]

Expand Down Expand Up @@ -125,16 +125,6 @@ exclude_lines = [
[tool.coverage.html]
show_contexts = true

[tool.mypy]
strict = true
exclude = "sqlmodel.sql._expression_select_gen"

[[tool.mypy.overrides]]
module = "docs_src.*"
disallow_incomplete_defs = false
disallow_untyped_defs = false
disallow_untyped_calls = false

[tool.ruff.lint]
select = [
"E", # pycodestyle errors
Expand Down
2 changes: 1 addition & 1 deletion scripts/generate_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Arg(BaseModel):
else:
t_type = f"_T{i}"
t_var = f"_TCCA[{t_type}]"
arg = Arg(name=f"__ent{i}", annotation=t_var)
arg = Arg(name=f"_ent{i}", annotation=t_var)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ty complaints with:

warning[invalid-legacy-positional-parameter]: Invalid use of the legacy convention for positional-only parameters
--> sqlmodel\sql_expression_select_gen.py:132:5
|
130 | @overload
131 | def select(
132 | entity_0: _TScalar_0,
| -------- Prior parameter here was positional-or-keyword
133 | __ent1: _TCCA[_T1],
| ^^^^^^ Parameter name begins with __ but will not be treated as positional-only
134 | ) -> Select[tuple[_TScalar_0, _T1]]: ...
|
info: A parameter can only be positional-only if it precedes all positional-or-keyword parameters
info: rule invalid-legacy-positional-parameter is enabled by default

So basically we can't have __var if there is a scalar parameter in front of this one.

As a quick fix, I changed all double underscores to singles, but it feels a bit like a hack. We could also suppress the ty warning, but that also feels wrong...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think just renaming it to single-underscored would be incorrect.
As I understand, double-underscored parameter names was a convention for position-only parameters before , / syntax was added in 3.8.

I suggest we update this to use names without leading underscores, and update template to add , / to each signature of select:

@overload
-def select(__ent0: _TCCA[_T0]) -> SelectOfScalar[_T0]: ...
+def select(__ent0: _TCCA[_T0], /) -> SelectOfScalar[_T0]: ...


@overload
-def select(__ent0: _TScalar_0) -> SelectOfScalar[_TScalar_0]: ...
+def select(__ent0: _TScalar_0, /) -> SelectOfScalar[_TScalar_0]: ...

# Generated overloads start

{% for signature in signatures %}

@overload
def select(
-    {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %}
+    {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %} /,
    ) -> Select[tuple[{%for ret in signature[1] %}{{ ret }} {% if not loop.last %}, {% endif %}{% endfor %}]]: ...

{% endfor %}

# Generated overloads end

Alternatively we can leave __ent and rename entity_.. parameters to __entity_..

ret_type = t_type
args.append(arg)
return_types.append(ret_type)
Expand Down
4 changes: 2 additions & 2 deletions scripts/lint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
set -e
set -x

mypy sqlmodel
mypy tests/test_select_typing.py
ty check sqlmodel
ty check tests/test_select_typing.py
ruff check sqlmodel tests docs_src scripts
ruff format sqlmodel tests docs_src scripts --check
28 changes: 14 additions & 14 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid
from typing_extensions import deprecated

from ._compat import ( # type: ignore[attr-defined]
from ._compat import (
PYDANTIC_MINOR_VERSION,
BaseConfig,
ModelMetaclass,
Expand Down Expand Up @@ -177,7 +177,7 @@ def __init__(
cascade_delete: bool | None = False,
passive_deletes: bool | Literal["all"] | None = False,
link_model: Any | None = None,
sa_relationship: RelationshipProperty | None = None, # type: ignore
sa_relationship: RelationshipProperty | None = None,
sa_relationship_args: Sequence[Any] | None = None,
sa_relationship_kwargs: Mapping[str, Any] | None = None,
) -> None:
Expand Down Expand Up @@ -398,7 +398,7 @@ def Field(
nullable: bool | UndefinedType = Undefined,
index: bool | UndefinedType = Undefined,
sa_type: type[Any] | UndefinedType = Undefined,
sa_column: Column | UndefinedType = Undefined, # type: ignore
sa_column: Column | UndefinedType = Undefined,
sa_column_args: Sequence[Any] | UndefinedType = Undefined,
sa_column_kwargs: Mapping[str, Any] | UndefinedType = Undefined,
schema_extra: dict[str, Any] | None = None,
Expand Down Expand Up @@ -525,17 +525,17 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
model_fields: ClassVar[dict[str, FieldInfo]]

# Replicate SQLAlchemy
def __setattr__(cls, name: str, value: Any) -> None:
def __setattr__(cls, key: str, value: Any) -> None:
if is_table_model_class(cls):
DeclarativeMeta.__setattr__(cls, name, value)
DeclarativeMeta.__setattr__(cls, key, value)
else:
super().__setattr__(name, value)
super().__setattr__(key, value)

def __delattr__(cls, name: str) -> None:
def __delattr__(cls, key: str) -> None:
if is_table_model_class(cls):
DeclarativeMeta.__delattr__(cls, name)
DeclarativeMeta.__delattr__(cls, key)
else:
super().__delattr__(name)
super().__delattr__(key)
Comment on lines -528 to +538
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ty said the original code violates the Liskov Substitution Principle.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ty is right about Liskov Substitution Principle violation, but I don't think we need to fix it this way.
object declares these methods with name parameter. So, using key here will not be correct.
This is actually minor issue as this is unlikely that these methods will be called with parameters passed by name.
So, I would revert changes and just ignore ty warning


# From Pydantic
def __new__(
Expand Down Expand Up @@ -649,7 +649,7 @@ def __init__(
# Plain forward references, for models not yet defined, are not
# handled well by SQLAlchemy without Mapped, so, wrap the
# annotations in Mapped here
cls.__annotations__[rel_name] = Mapped[ann] # type: ignore[valid-type]
cls.__annotations__[rel_name] = Mapped[ann]
relationship_to = get_relationship_to(
name=rel_name, rel_info=rel_info, annotation=ann
)
Expand Down Expand Up @@ -738,7 +738,7 @@ def get_sqlalchemy_type(field: Any) -> Any:
raise ValueError(f"{type_} has no matching SQLAlchemy type")


def get_column_from_field(field: Any) -> Column: # type: ignore
def get_column_from_field(field: Any) -> Column:
field_info = field
sa_column = _get_sqlmodel_field_value(field_info, "sa_column", Undefined)
if isinstance(sa_column, Column):
Expand Down Expand Up @@ -773,7 +773,7 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
assert isinstance(foreign_key, str)
assert isinstance(ondelete_value, (str, type(None))) # for typing
args.append(ForeignKey(foreign_key, ondelete=ondelete_value))
kwargs = {
kwargs: dict[str, Any] = {
"primary_key": primary_key,
"nullable": nullable,
"index": index,
Expand All @@ -797,7 +797,7 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
return Column(sa_type, *args, **kwargs)


class_registry = weakref.WeakValueDictionary() # type: ignore
class_registry = weakref.WeakValueDictionary()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we even need this class_registry?
It's not used and after removing it all tests pass


default_registry = registry()

Expand Down Expand Up @@ -850,7 +850,7 @@ def __setattr__(self, name: str, value: Any) -> None:
return
else:
# Set in SQLAlchemy, before Pydantic to trigger events and updates
if is_table_model_class(self.__class__) and is_instrumented(self, name): # type: ignore[no-untyped-call]
if is_table_model_class(self.__class__) and is_instrumented(self, name):
set_attribute(self, name, value)
# Set in Pydantic model to trigger possible validation changes, only for
# non relationship values
Expand Down
4 changes: 2 additions & 2 deletions sqlmodel/sql/_expression_select_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ def where(self, *whereclause: _ColumnExpressionArgument[bool] | bool) -> Self:
"""Return a new `Select` construct with the given expression added to
its `WHERE` clause, joined to the existing clause via `AND`, if any.
"""
return super().where(*whereclause) # type: ignore[arg-type]
return super().where(*whereclause)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's ty's mistake that it doesn't notice type mismatch here. But I checked - passing bool argument works in runtime


def having(self, *having: _ColumnExpressionArgument[bool] | bool) -> Self:
"""Return a new `Select` construct with the given expression added to
its `HAVING` clause, joined to the existing clause via `AND`, if any.
"""
return super().having(*having) # type: ignore[arg-type]
return super().having(*having)


class Select(SelectBase[_T]):
Expand Down
Loading