diff --git a/docs/integrations/litestar.rst b/docs/integrations/litestar.rst index 31f858a6f..0a5880754 100644 --- a/docs/integrations/litestar.rst +++ b/docs/integrations/litestar.rst @@ -126,3 +126,22 @@ or with class-based handler: async with container() as request_container: b = await request_container.get(B) # object with Scope.REQUEST return {"key": "value"} + + +Guards +****** + +For ``Guard`` functions (which receive ``ASGIConnection`` instead of ``Request``), use the ``@inject_asgi`` decorator: + +.. code-block:: python + + from litestar import ASGIConnection, BaseRouteHandler + from dishka.integrations.litestar import FromDishka, inject_asgi + + @inject_asgi + async def my_guard( + connection: ASGIConnection, + _: BaseRouteHandler, + config: FromDishka[Config], + ) -> None: + ... diff --git a/src/dishka/integrations/litestar.py b/src/dishka/integrations/litestar.py index 573c78611..eddea3306 100644 --- a/src/dishka/integrations/litestar.py +++ b/src/dishka/integrations/litestar.py @@ -3,6 +3,7 @@ "FromDishka", "LitestarProvider", "inject", + "inject_asgi", "inject_websocket", "setup_dishka", ] @@ -10,9 +11,10 @@ from collections.abc import Callable from functools import wraps from inspect import Parameter -from typing import ParamSpec, TypeVar, get_type_hints +from typing import Any, ParamSpec, TypeVar, get_type_hints from litestar import Controller, Litestar, Request, Router, WebSocket +from litestar.connection import ASGIConnection from litestar.enums import ScopeType from litestar.handlers import ( BaseRouteHandler, @@ -34,6 +36,7 @@ from dishka import Scope as DIScope from dishka.integrations.base import wrap_injection +GuardDependencyT = TypeVar("GuardDependencyT") P = ParamSpec("P") T = TypeVar("T") @@ -42,14 +45,20 @@ def inject(func: Callable[P, T]) -> Callable[P, T]: return _inject_wrapper(func, "request", Request) +def inject_asgi( + func: Callable[[ASGIConnection, BaseRouteHandler, GuardDependencyT], T], +) -> Callable[[ASGIConnection, BaseRouteHandler], T]: + return _inject_wrapper(func, "connection", ASGIConnection) # type: ignore[invalid-return-type] + + def inject_websocket(func: Callable[P, T]) -> Callable[P, T]: return _inject_wrapper(func, "socket", WebSocket) def _inject_wrapper( - func: Callable[P, T], - param_name: str, - param_annotation: type[Request | WebSocket], + func: Callable[P, T], + param_name: str, + param_annotation: type[Request | WebSocket | ASGIConnection], ) -> Callable[P, T]: hints = get_type_hints(func) @@ -61,17 +70,31 @@ def _inject_wrapper( if request_param: additional_params = [] else: - additional_params = [Parameter( - name=param_name, - annotation=param_annotation, - kind=Parameter.KEYWORD_ONLY, - )] + additional_params = [ + Parameter( + name=param_name, + annotation=param_annotation, + kind=Parameter.KEYWORD_ONLY, + ), + ] + + if param_name == "connection": + def container_getter( + args: tuple[ASGIConnection], _: Any, + ) -> AsyncContainer: + return args[0].state.dishka_container + else: + def container_getter( + _: Any, + kwargs: dict[str, Request | WebSocket], + ) -> AsyncContainer: + return kwargs[param_name].state.dishka_container return wrap_injection( func=func, is_async=True, additional_params=additional_params, - container_getter=lambda _, r: r[param_name].state.dishka_container, + container_getter=container_getter, ) @@ -153,7 +176,8 @@ async def middleware(scope: Scope, receive: Receive, send: Send) -> None: di_scope = DIScope.SESSION async with request.app.state.dishka_container( - context, scope=di_scope, + context, + scope=di_scope, ) as request_container: request.state.dishka_container = request_container await app(scope, receive, send) diff --git a/tests/integrations/litestar/test_litestar_guards.py b/tests/integrations/litestar/test_litestar_guards.py new file mode 100644 index 000000000..42724e115 --- /dev/null +++ b/tests/integrations/litestar/test_litestar_guards.py @@ -0,0 +1,83 @@ +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from unittest.mock import Mock + +import pytest +from asgi_lifespan import LifespanManager +from litestar import Litestar, get +from litestar.connection import ASGIConnection +from litestar.handlers import BaseRouteHandler +from litestar.testing import TestClient +from litestar.types import Guard + +from dishka import make_async_container +from dishka.integrations.litestar import ( + DishkaRouter, + FromDishka, + inject_asgi, + setup_dishka, +) +from ..common import ( + APP_DEP_VALUE, + AppDep, + AppProvider, +) + + +@asynccontextmanager +async def dishka_app( + guard: Guard, provider: AppProvider) -> AsyncGenerator[TestClient, None]: + @get("/", guards=[guard]) + async def endpoint() -> dict: + return {"status": "ok"} + + app = Litestar([endpoint], debug=True) + container = make_async_container(provider) + setup_dishka(container, app) + async with LifespanManager(app): + yield TestClient(app) + await container.close() + + +@asynccontextmanager +async def dishka_auto_app( + guard: Guard, provider: AppProvider) -> AsyncGenerator[TestClient, None]: + @get("/") + async def endpoint() -> dict: + return {"status": "ok"} + + router = DishkaRouter("", route_handlers=[endpoint], guards=[guard]) + app = Litestar([router], debug=True) + container = make_async_container(provider) + setup_dishka(container, app) + async with LifespanManager(app): + yield TestClient(app) + await container.close() + + +@inject_asgi +async def guard_with_app( + connection: ASGIConnection, + _: BaseRouteHandler, + a: FromDishka[AppDep], + mock: FromDishka[Mock], +) -> None: + mock(a) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("app_factory", "guard"), + [ + (dishka_app, guard_with_app), + ], +) +async def test_guard_injects_app_dependency( + app_factory, + guard, + app_provider: AppProvider, +): + async with app_factory(guard, app_provider) as client: + response = client.get("/") + assert response.status_code == 200 + app_provider.mock.assert_called_with(APP_DEP_VALUE)