Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions docs/integrations/litestar.rst
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,22 @@ or with class-based handler:
async with container() as request_container:
b = await request_container.get(B) # object with Scope.REQUEST
return {"key": "value"}


Guards
******

For ``Guard`` functions (which receive ``ASGIConnection`` instead of ``Request``), use the ``@inject_asgi`` decorator:

.. code-block:: python

from litestar import ASGIConnection, BaseRouteHandler
from dishka.integrations.litestar import FromDishka, inject_asgi

@inject_asgi
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can it be used somewhere else? inject_asgi sounds like something more generic than guards only

async def my_guard(
connection: ASGIConnection,
_: BaseRouteHandler,
config: FromDishka[Config],
) -> None:
...
35 changes: 29 additions & 6 deletions src/dishka/integrations/litestar.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
"FromDishka",
"LitestarProvider",
"inject",
"inject_asgi",
"inject_websocket",
"setup_dishka",
]

from collections.abc import Callable
from functools import wraps
from inspect import Parameter
from typing import ParamSpec, TypeVar, get_type_hints
from inspect import Parameter, Signature, signature
from typing import Any, ParamSpec, TypeVar, get_type_hints, overload

from litestar import Controller, Litestar, Request, Router, WebSocket
from litestar.enums import ScopeType
Expand All @@ -19,6 +20,7 @@
HTTPRouteHandler,
WebsocketListener,
)
from litestar.connection import ASGIConnection
from litestar.handlers.websocket_handlers import WebsocketListenerRouteHandler
from litestar.handlers.websocket_handlers._utils import ListenerHandler
from litestar.routes import BaseRoute
Expand All @@ -34,22 +36,38 @@
from dishka import Scope as DIScope
from dishka.integrations.base import wrap_injection

GuardDependencyT = TypeVar("GuardDependencyT")
P = ParamSpec("P")
T = TypeVar("T")


@overload
def inject(
func: Callable[[ASGIConnection, BaseRouteHandler, GuardDependencyT], T],
/,
) -> Callable[[ASGIConnection, BaseRouteHandler], T]: ...


@overload
def inject(func: Callable[P, T], /) -> Callable[P, T]: ...


def inject(func: Callable[P, T]) -> Callable[P, T]:
return _inject_wrapper(func, "request", Request)


def inject_asgi(func: Callable[P, T]) -> Callable[P, T]:
return _inject_wrapper(func, "connection", ASGIConnection)


def inject_websocket(func: Callable[P, T]) -> Callable[P, T]:
return _inject_wrapper(func, "socket", WebSocket)


def _inject_wrapper(
func: Callable[P, T],
param_name: str,
param_annotation: type[Request | WebSocket],
func: Callable[P, T],
param_name: str,
param_annotation: type[Request | WebSocket | ASGIConnection],
) -> Callable[P, T]:
hints = get_type_hints(func)

Expand All @@ -67,11 +85,16 @@ def _inject_wrapper(
kind=Parameter.KEYWORD_ONLY,
)]

if param_name == "connection":
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not see why user won't put param called connection at any other position in normal handler

container_getter = lambda args, _: args[0].state.dishka_container
else:
container_getter = lambda _, kwargs: kwargs[param_name].state.dishka_container

return wrap_injection(
func=func,
is_async=True,
additional_params=additional_params,
container_getter=lambda _, r: r[param_name].state.dishka_container,
container_getter=container_getter,
)


Expand Down
81 changes: 81 additions & 0 deletions tests/integrations/litestar/test_litestar_guards.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from unittest.mock import Mock

import pytest
from asgi_lifespan import LifespanManager
from litestar import Litestar, get
from litestar.connection import ASGIConnection
from litestar.handlers import BaseRouteHandler
from litestar.types import Guard
from litestar.testing import TestClient

from dishka import make_async_container
from dishka.integrations.litestar import (
DishkaRouter,
FromDishka,
inject_asgi,
setup_dishka,
)
from ..common import (
APP_DEP_VALUE,
AppDep,
AppProvider,
)


@asynccontextmanager
async def dishka_app(guard: Guard, provider: AppProvider) -> AsyncGenerator[TestClient, None]:
@get("/", guards=[guard])
async def endpoint() -> dict:
return {"status": "ok"}

app = Litestar([endpoint], debug=True)
container = make_async_container(provider)
setup_dishka(container, app)
async with LifespanManager(app):
yield TestClient(app)
await container.close()


@asynccontextmanager
async def dishka_auto_app(guard: Guard, provider: AppProvider) -> AsyncGenerator[TestClient, None]:
@get("/")
async def endpoint() -> dict:
return {"status": "ok"}

router = DishkaRouter("", route_handlers=[endpoint], guards=[guard])
app = Litestar([router], debug=True)
container = make_async_container(provider)
setup_dishka(container, app)
async with LifespanManager(app):
yield TestClient(app)
await container.close()


@inject_asgi
async def guard_with_app(
connection: ASGIConnection,
_: BaseRouteHandler,
a: FromDishka[AppDep],
mock: FromDishka[Mock],
) -> None:
mock(a)


@pytest.mark.asyncio
@pytest.mark.parametrize(
("app_factory", "guard"),
[
(dishka_app, guard_with_app),
],
)
async def test_guard_injects_app_dependency(
app_factory,
guard,
app_provider: AppProvider,
):
async with app_factory(guard, app_provider) as client:
response = client.get("/")
assert response.status_code == 200
app_provider.mock.assert_called_with(APP_DEP_VALUE)