From d58ec9521bf80063bccd83fe29cfad2b26b80419 Mon Sep 17 00:00:00 2001 From: "Aryn Y." Date: Wed, 25 Mar 2026 23:23:07 +0500 Subject: [PATCH 1/6] feat(litestar): add ASGIConnection and guard dependency support Add support for injecting ASGIConnection as a parameter type, enabling more flexible handler signatures including guards. Add proper overload signatures to support guard dependencies with correct type inference. Refactor internal injection logic to dynamically detect the connection parameter type instead of hardcoding 'request' or 'socket'. Introduce _build_container_getter to properly resolve container from args/kwargs using function signature binding. --- src/dishka/integrations/litestar.py | 94 +++++++++++++++++++++-------- 1 file changed, 70 insertions(+), 24 deletions(-) diff --git a/src/dishka/integrations/litestar.py b/src/dishka/integrations/litestar.py index 573c78611..350f2ef9d 100644 --- a/src/dishka/integrations/litestar.py +++ b/src/dishka/integrations/litestar.py @@ -9,8 +9,8 @@ from collections.abc import Callable from functools import wraps -from inspect import Parameter -from typing import ParamSpec, TypeVar, get_type_hints +from inspect import Parameter, Signature, signature +from typing import Any, ParamSpec, TypeVar, get_type_hints, overload from litestar import Controller, Litestar, Request, Router, WebSocket from litestar.enums import ScopeType @@ -19,6 +19,8 @@ HTTPRouteHandler, WebsocketListener, ) +from litestar.connection import ASGIConnection + from litestar.handlers.websocket_handlers import WebsocketListenerRouteHandler from litestar.handlers.websocket_handlers._utils import ListenerHandler from litestar.routes import BaseRoute @@ -34,47 +36,91 @@ from dishka import Scope as DIScope from dishka.integrations.base import wrap_injection +GuardDependencyT = TypeVar("GuardDependencyT") P = ParamSpec("P") T = TypeVar("T") +@overload +def inject( + func: Callable[[ASGIConnection, BaseRouteHandler, GuardDependencyT], T], + /, + ) -> Callable[[ASGIConnection, BaseRouteHandler], T]: ... + + +@overload +def inject(func: Callable[P, T], /) -> Callable[P, T]: ... + + def inject(func: Callable[P, T]) -> Callable[P, T]: - return _inject_wrapper(func, "request", Request) + return _inject_wrapper(func) def inject_websocket(func: Callable[P, T]) -> Callable[P, T]: - return _inject_wrapper(func, "socket", WebSocket) + return inject(func) -def _inject_wrapper( - func: Callable[P, T], - param_name: str, - param_annotation: type[Request | WebSocket], -) -> Callable[P, T]: +def _inject_wrapper(func: Callable[P, T]) -> Callable[P, T]: hints = get_type_hints(func) - - request_param = next( - (name for name in hints if name == param_name), - None, - ) - - if request_param: - additional_params = [] - else: - additional_params = [Parameter( - name=param_name, - annotation=param_annotation, - kind=Parameter.KEYWORD_ONLY, - )] + func_signature = signature(func) + + param_name, param_annotation = _find_connection_parameter(hints) + + additional_params = [] + if param_name not in hints: + additional_params = [ + Parameter( + name=param_name, + annotation=param_annotation, + kind=Parameter.KEYWORD_ONLY, + ), + ] return wrap_injection( func=func, is_async=True, additional_params=additional_params, - container_getter=lambda _, r: r[param_name].state.dishka_container, + container_getter=_build_container_getter( + func_signature=func_signature, + param_name=param_name, + ), ) +def _find_connection_parameter( + hints: dict[str, Any], +) -> tuple[str, type[Request | WebSocket | ASGIConnection]]: + for param_name, param_annotation in ( + ("request", Request), + ("socket", WebSocket), + ("connection", ASGIConnection), + ): + if param_name in hints: + return param_name, param_annotation + + return "connection", ASGIConnection + + +def _build_container_getter( + *, + func_signature: Signature, + param_name: str, + ) -> Callable[[tuple[Any, ...], dict[str, Any]], AsyncContainer]: + def container_getter( + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> AsyncContainer: + if param_name in kwargs: + connection = kwargs[param_name] + else: + bound = func_signature.bind_partial(*args, **kwargs) + connection = bound.arguments[param_name] + + return connection.state.dishka_container + + return container_getter + + def _inject_based_on_handler_type( value: BaseRouteHandler, ) -> BaseRouteHandler: From 6245b47d85358fe920ce32dc854bedcb2ab6f8e6 Mon Sep 17 00:00:00 2001 From: "Aryn Y." Date: Thu, 26 Mar 2026 01:14:29 +0500 Subject: [PATCH 2/6] feat(litestar): add inject_asgi for guards (#699) Revert the dynamic parameter detection approach in favor of explicit inject functions for different handler types. Add `inject_asgi` function specifically for Litestar guards that receive `connection: ASGIConnection` as their first parameter. This provides a clean API for guard dependency injection: @inject_asgi async def my_guard( connection: ASGIConnection, _: BaseRouteHandler, config: FromDishka[Config], ) -> None: ... Fixes the KeyError when using @inject on guard functions that don't have 'request' in their kwargs. --- src/dishka/integrations/litestar.py | 87 +++++++++++------------------ 1 file changed, 32 insertions(+), 55 deletions(-) diff --git a/src/dishka/integrations/litestar.py b/src/dishka/integrations/litestar.py index 350f2ef9d..c80530113 100644 --- a/src/dishka/integrations/litestar.py +++ b/src/dishka/integrations/litestar.py @@ -3,6 +3,7 @@ "FromDishka", "LitestarProvider", "inject", + "inject_asgi", "inject_websocket", "setup_dishka", ] @@ -20,7 +21,6 @@ WebsocketListener, ) from litestar.connection import ASGIConnection - from litestar.handlers.websocket_handlers import WebsocketListenerRouteHandler from litestar.handlers.websocket_handlers._utils import ListenerHandler from litestar.routes import BaseRoute @@ -53,74 +53,51 @@ def inject(func: Callable[P, T], /) -> Callable[P, T]: ... def inject(func: Callable[P, T]) -> Callable[P, T]: - return _inject_wrapper(func) + return _inject_wrapper(func, "request", Request) + + +def inject_asgi(func: Callable[P, T]) -> Callable[P, T]: + return _inject_wrapper(func, "connection", ASGIConnection) def inject_websocket(func: Callable[P, T]) -> Callable[P, T]: - return inject(func) + return _inject_wrapper(func, "socket", WebSocket) -def _inject_wrapper(func: Callable[P, T]) -> Callable[P, T]: +def _inject_wrapper( + func: Callable[P, T], + param_name: str, + param_annotation: type[Request | WebSocket | ASGIConnection], +) -> Callable[P, T]: hints = get_type_hints(func) - func_signature = signature(func) - - param_name, param_annotation = _find_connection_parameter(hints) - - additional_params = [] - if param_name not in hints: - additional_params = [ - Parameter( - name=param_name, - annotation=param_annotation, - kind=Parameter.KEYWORD_ONLY, - ), - ] + + request_param = next( + (name for name in hints if name == param_name), + None, + ) + + if request_param: + additional_params = [] + else: + additional_params = [Parameter( + name=param_name, + annotation=param_annotation, + kind=Parameter.KEYWORD_ONLY, + )] + + if param_name == "connection": + container_getter = lambda args, _: args[0].state.dishka_container + else: + container_getter = lambda _, kwargs: kwargs[param_name].state.dishka_container return wrap_injection( func=func, is_async=True, additional_params=additional_params, - container_getter=_build_container_getter( - func_signature=func_signature, - param_name=param_name, - ), + container_getter=container_getter, ) -def _find_connection_parameter( - hints: dict[str, Any], -) -> tuple[str, type[Request | WebSocket | ASGIConnection]]: - for param_name, param_annotation in ( - ("request", Request), - ("socket", WebSocket), - ("connection", ASGIConnection), - ): - if param_name in hints: - return param_name, param_annotation - - return "connection", ASGIConnection - - -def _build_container_getter( - *, - func_signature: Signature, - param_name: str, - ) -> Callable[[tuple[Any, ...], dict[str, Any]], AsyncContainer]: - def container_getter( - args: tuple[Any, ...], - kwargs: dict[str, Any], - ) -> AsyncContainer: - if param_name in kwargs: - connection = kwargs[param_name] - else: - bound = func_signature.bind_partial(*args, **kwargs) - connection = bound.arguments[param_name] - - return connection.state.dishka_container - - return container_getter - - def _inject_based_on_handler_type( value: BaseRouteHandler, ) -> BaseRouteHandler: From c7fb3ffc3d47a14f99e40e7d06f83990ce2f8649 Mon Sep 17 00:00:00 2001 From: "Aryn Y." Date: Thu, 26 Mar 2026 01:15:43 +0500 Subject: [PATCH 3/6] docs(litestar): add guard usage with inject_asgi --- docs/integrations/litestar.rst | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/docs/integrations/litestar.rst b/docs/integrations/litestar.rst index 31f858a6f..9ec5b8a13 100644 --- a/docs/integrations/litestar.rst +++ b/docs/integrations/litestar.rst @@ -126,3 +126,22 @@ or with class-based handler: async with container() as request_container: b = await request_container.get(B) # object with Scope.REQUEST return {"key": "value"} + + +Guards +****** + +For ``Guard`` functions (which receive ``ASGIConnection`` instead of ``Request``), use the ``@inject_asgi`` decorator: + +>> code-block:: python + +>> from litestar import ASGIConnection, BaseRouteHandler +>> from dishka.integrations.litestar import FromDishka, inject_asgi +>> +>> @inject_asgi +>> async def my_guard( +>> connection: ASGIConnection, +>> _: BaseRouteHandler, +>> config: FromDishka[Config], +>> ) -> None: + >>> >>> ... \ No newline at end of file From a95868d976586e34b8b1edb1fd0c83681d744c80 Mon Sep 17 00:00:00 2001 From: "Aryn Y." Date: Thu, 26 Mar 2026 01:18:15 +0500 Subject: [PATCH 4/6] feat(litestar): add documentation for @inject_asgi decorator in Guards section --- docs/integrations/litestar.rst | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/docs/integrations/litestar.rst b/docs/integrations/litestar.rst index 31f858a6f..0a5880754 100644 --- a/docs/integrations/litestar.rst +++ b/docs/integrations/litestar.rst @@ -126,3 +126,22 @@ or with class-based handler: async with container() as request_container: b = await request_container.get(B) # object with Scope.REQUEST return {"key": "value"} + + +Guards +****** + +For ``Guard`` functions (which receive ``ASGIConnection`` instead of ``Request``), use the ``@inject_asgi`` decorator: + +.. code-block:: python + + from litestar import ASGIConnection, BaseRouteHandler + from dishka.integrations.litestar import FromDishka, inject_asgi + + @inject_asgi + async def my_guard( + connection: ASGIConnection, + _: BaseRouteHandler, + config: FromDishka[Config], + ) -> None: + ... From 545baa2522ef3a8bd239e1a25e4cfe0cd3ec0cfe Mon Sep 17 00:00:00 2001 From: "Aryn Y." Date: Thu, 26 Mar 2026 01:26:40 +0500 Subject: [PATCH 5/6] test(litestar): fix Guard import and simplify guard tests Fix Guard import path from litestar.guards to litestar.types. Remove auto-injection test case as guards require explicit @inject_asgi decorator. --- tests/integrations/litestar/test_litestar_guards.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/tests/integrations/litestar/test_litestar_guards.py b/tests/integrations/litestar/test_litestar_guards.py index 53100a9ae..964640a3c 100644 --- a/tests/integrations/litestar/test_litestar_guards.py +++ b/tests/integrations/litestar/test_litestar_guards.py @@ -7,7 +7,7 @@ from litestar import Litestar, get from litestar.connection import ASGIConnection from litestar.handlers import BaseRouteHandler -from litestar.guards import Guard +from litestar.types import Guard from litestar.testing import TestClient from dishka import make_async_container @@ -63,21 +63,11 @@ async def guard_with_app( mock(a) -async def auto_guard_with_app( - connection: ASGIConnection, - _: BaseRouteHandler, - a: FromDishka[AppDep], - mock: FromDishka[Mock], -) -> None: - mock(a) - - @pytest.mark.asyncio @pytest.mark.parametrize( ("app_factory", "guard"), [ (dishka_app, guard_with_app), - (dishka_auto_app, auto_guard_with_app), ], ) async def test_guard_injects_app_dependency( From 06ed97729ed7b0501f0c8ca8a819cfd83a59b50b Mon Sep 17 00:00:00 2001 From: "Aryn Y." Date: Thu, 26 Mar 2026 01:40:49 +0500 Subject: [PATCH 6/6] style: fix typing in litestar integration --- src/dishka/integrations/litestar.py | 49 ++++++++++--------- .../litestar/test_litestar_guards.py | 8 +-- 2 files changed, 30 insertions(+), 27 deletions(-) diff --git a/src/dishka/integrations/litestar.py b/src/dishka/integrations/litestar.py index c80530113..eddea3306 100644 --- a/src/dishka/integrations/litestar.py +++ b/src/dishka/integrations/litestar.py @@ -10,17 +10,17 @@ from collections.abc import Callable from functools import wraps -from inspect import Parameter, Signature, signature -from typing import Any, ParamSpec, TypeVar, get_type_hints, overload +from inspect import Parameter +from typing import Any, ParamSpec, TypeVar, get_type_hints from litestar import Controller, Litestar, Request, Router, WebSocket +from litestar.connection import ASGIConnection from litestar.enums import ScopeType from litestar.handlers import ( BaseRouteHandler, HTTPRouteHandler, WebsocketListener, ) -from litestar.connection import ASGIConnection from litestar.handlers.websocket_handlers import WebsocketListenerRouteHandler from litestar.handlers.websocket_handlers._utils import ListenerHandler from litestar.routes import BaseRoute @@ -41,23 +41,14 @@ T = TypeVar("T") -@overload -def inject( - func: Callable[[ASGIConnection, BaseRouteHandler, GuardDependencyT], T], - /, - ) -> Callable[[ASGIConnection, BaseRouteHandler], T]: ... - - -@overload -def inject(func: Callable[P, T], /) -> Callable[P, T]: ... - - def inject(func: Callable[P, T]) -> Callable[P, T]: return _inject_wrapper(func, "request", Request) -def inject_asgi(func: Callable[P, T]) -> Callable[P, T]: - return _inject_wrapper(func, "connection", ASGIConnection) +def inject_asgi( + func: Callable[[ASGIConnection, BaseRouteHandler, GuardDependencyT], T], +) -> Callable[[ASGIConnection, BaseRouteHandler], T]: + return _inject_wrapper(func, "connection", ASGIConnection) # type: ignore[invalid-return-type] def inject_websocket(func: Callable[P, T]) -> Callable[P, T]: @@ -79,16 +70,25 @@ def _inject_wrapper( if request_param: additional_params = [] else: - additional_params = [Parameter( - name=param_name, - annotation=param_annotation, - kind=Parameter.KEYWORD_ONLY, - )] + additional_params = [ + Parameter( + name=param_name, + annotation=param_annotation, + kind=Parameter.KEYWORD_ONLY, + ), + ] if param_name == "connection": - container_getter = lambda args, _: args[0].state.dishka_container + def container_getter( + args: tuple[ASGIConnection], _: Any, + ) -> AsyncContainer: + return args[0].state.dishka_container else: - container_getter = lambda _, kwargs: kwargs[param_name].state.dishka_container + def container_getter( + _: Any, + kwargs: dict[str, Request | WebSocket], + ) -> AsyncContainer: + return kwargs[param_name].state.dishka_container return wrap_injection( func=func, @@ -176,7 +176,8 @@ async def middleware(scope: Scope, receive: Receive, send: Send) -> None: di_scope = DIScope.SESSION async with request.app.state.dishka_container( - context, scope=di_scope, + context, + scope=di_scope, ) as request_container: request.state.dishka_container = request_container await app(scope, receive, send) diff --git a/tests/integrations/litestar/test_litestar_guards.py b/tests/integrations/litestar/test_litestar_guards.py index 964640a3c..42724e115 100644 --- a/tests/integrations/litestar/test_litestar_guards.py +++ b/tests/integrations/litestar/test_litestar_guards.py @@ -7,8 +7,8 @@ from litestar import Litestar, get from litestar.connection import ASGIConnection from litestar.handlers import BaseRouteHandler -from litestar.types import Guard from litestar.testing import TestClient +from litestar.types import Guard from dishka import make_async_container from dishka.integrations.litestar import ( @@ -25,7 +25,8 @@ @asynccontextmanager -async def dishka_app(guard: Guard, provider: AppProvider) -> AsyncGenerator[TestClient, None]: +async def dishka_app( + guard: Guard, provider: AppProvider) -> AsyncGenerator[TestClient, None]: @get("/", guards=[guard]) async def endpoint() -> dict: return {"status": "ok"} @@ -39,7 +40,8 @@ async def endpoint() -> dict: @asynccontextmanager -async def dishka_auto_app(guard: Guard, provider: AppProvider) -> AsyncGenerator[TestClient, None]: +async def dishka_auto_app( + guard: Guard, provider: AppProvider) -> AsyncGenerator[TestClient, None]: @get("/") async def endpoint() -> dict: return {"status": "ok"}