Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions docs/integrations/litestar.rst
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,22 @@ or with class-based handler:
async with container() as request_container:
b = await request_container.get(B) # object with Scope.REQUEST
return {"key": "value"}


Guards
******

For ``Guard`` functions (which receive ``ASGIConnection`` instead of ``Request``), use the ``@inject_asgi`` decorator:

.. code-block:: python

from litestar import ASGIConnection, BaseRouteHandler
from dishka.integrations.litestar import FromDishka, inject_asgi

@inject_asgi
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can it be used somewhere else? inject_asgi sounds like something more generic than guards only

async def my_guard(
connection: ASGIConnection,
_: BaseRouteHandler,
config: FromDishka[Config],
) -> None:
...
46 changes: 35 additions & 11 deletions src/dishka/integrations/litestar.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,18 @@
"FromDishka",
"LitestarProvider",
"inject",
"inject_asgi",
"inject_websocket",
"setup_dishka",
]

from collections.abc import Callable
from functools import wraps
from inspect import Parameter
from typing import ParamSpec, TypeVar, get_type_hints
from typing import Any, ParamSpec, TypeVar, get_type_hints

from litestar import Controller, Litestar, Request, Router, WebSocket
from litestar.connection import ASGIConnection
from litestar.enums import ScopeType
from litestar.handlers import (
BaseRouteHandler,
Expand All @@ -34,6 +36,7 @@
from dishka import Scope as DIScope
from dishka.integrations.base import wrap_injection

GuardDependencyT = TypeVar("GuardDependencyT")
P = ParamSpec("P")
T = TypeVar("T")

Expand All @@ -42,14 +45,20 @@ def inject(func: Callable[P, T]) -> Callable[P, T]:
return _inject_wrapper(func, "request", Request)


def inject_asgi(
func: Callable[[ASGIConnection, BaseRouteHandler, GuardDependencyT], T],
) -> Callable[[ASGIConnection, BaseRouteHandler], T]:
return _inject_wrapper(func, "connection", ASGIConnection) # type: ignore[invalid-return-type]


def inject_websocket(func: Callable[P, T]) -> Callable[P, T]:
return _inject_wrapper(func, "socket", WebSocket)


def _inject_wrapper(
func: Callable[P, T],
param_name: str,
param_annotation: type[Request | WebSocket],
func: Callable[P, T],
param_name: str,
param_annotation: type[Request | WebSocket | ASGIConnection],
) -> Callable[P, T]:
hints = get_type_hints(func)

Expand All @@ -61,17 +70,31 @@ def _inject_wrapper(
if request_param:
additional_params = []
else:
additional_params = [Parameter(
name=param_name,
annotation=param_annotation,
kind=Parameter.KEYWORD_ONLY,
)]
additional_params = [
Parameter(
name=param_name,
annotation=param_annotation,
kind=Parameter.KEYWORD_ONLY,
),
]

if param_name == "connection":
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not see why user won't put param called connection at any other position in normal handler

def container_getter(
args: tuple[ASGIConnection], _: Any,
) -> AsyncContainer:
return args[0].state.dishka_container
else:
def container_getter(
_: Any,
kwargs: dict[str, Request | WebSocket],
) -> AsyncContainer:
return kwargs[param_name].state.dishka_container

return wrap_injection(
func=func,
is_async=True,
additional_params=additional_params,
container_getter=lambda _, r: r[param_name].state.dishka_container,
container_getter=container_getter,
)


Expand Down Expand Up @@ -153,7 +176,8 @@ async def middleware(scope: Scope, receive: Receive, send: Send) -> None:
di_scope = DIScope.SESSION

async with request.app.state.dishka_container(
context, scope=di_scope,
context,
scope=di_scope,
) as request_container:
request.state.dishka_container = request_container
await app(scope, receive, send)
Expand Down
83 changes: 83 additions & 0 deletions tests/integrations/litestar/test_litestar_guards.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from unittest.mock import Mock

import pytest
from asgi_lifespan import LifespanManager
from litestar import Litestar, get
from litestar.connection import ASGIConnection
from litestar.handlers import BaseRouteHandler
from litestar.testing import TestClient
from litestar.types import Guard

from dishka import make_async_container
from dishka.integrations.litestar import (
DishkaRouter,
FromDishka,
inject_asgi,
setup_dishka,
)
from ..common import (
APP_DEP_VALUE,
AppDep,
AppProvider,
)


@asynccontextmanager
async def dishka_app(
guard: Guard, provider: AppProvider) -> AsyncGenerator[TestClient, None]:
@get("/", guards=[guard])
async def endpoint() -> dict:
return {"status": "ok"}

app = Litestar([endpoint], debug=True)
container = make_async_container(provider)
setup_dishka(container, app)
async with LifespanManager(app):
yield TestClient(app)
await container.close()


@asynccontextmanager
async def dishka_auto_app(
guard: Guard, provider: AppProvider) -> AsyncGenerator[TestClient, None]:
@get("/")
async def endpoint() -> dict:
return {"status": "ok"}

router = DishkaRouter("", route_handlers=[endpoint], guards=[guard])
app = Litestar([router], debug=True)
container = make_async_container(provider)
setup_dishka(container, app)
async with LifespanManager(app):
yield TestClient(app)
await container.close()


@inject_asgi
async def guard_with_app(
connection: ASGIConnection,
_: BaseRouteHandler,
a: FromDishka[AppDep],
mock: FromDishka[Mock],
) -> None:
mock(a)


@pytest.mark.asyncio
@pytest.mark.parametrize(
("app_factory", "guard"),
[
(dishka_app, guard_with_app),
],
)
async def test_guard_injects_app_dependency(
app_factory,
guard,
app_provider: AppProvider,
):
async with app_factory(guard, app_provider) as client:
response = client.get("/")
assert response.status_code == 200
app_provider.mock.assert_called_with(APP_DEP_VALUE)