From 9db7dc79b9def17b1e7784aa4f127d3c14e9a784 Mon Sep 17 00:00:00 2001 From: Tuukka Mustonen Date: Sat, 1 Mar 2025 13:20:51 +0200 Subject: [PATCH] WIP: Sketching --- src/dishka/integrations/base.py | 23 ++++++++++-- tests/integrations/fastapi/test_fastapi.py | 41 +++++++++++++++++++++- 2 files changed, 60 insertions(+), 4 deletions(-) diff --git a/src/dishka/integrations/base.py b/src/dishka/integrations/base.py index 24424b481..ba3a44dd4 100644 --- a/src/dishka/integrations/base.py +++ b/src/dishka/integrations/base.py @@ -1,3 +1,4 @@ +import inspect from collections.abc import Awaitable, Callable, Sequence from inspect import ( Parameter, @@ -96,9 +97,15 @@ def wrap_injection( func_signature = signature(func) dependencies = {} - for name, param in func_signature.parameters.items(): - hint = hints.get(name, Any) - dep = parse_dependency(param, hint) + for index, (name, param) in enumerate(func_signature.parameters.items()): + if name == "self" and index == 0: + # If it's a method in a class, by the time this is run the class + # hasn't been created yet, and inspection would fail. So, + # postpone it. + dep = DependencyKey(func, DEFAULT_COMPONENT) + else: + hint = hints.get(name, Any) + dep = parse_dependency(param, hint) if dep is None: continue dependencies[name] = dep @@ -185,6 +192,11 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: container = container_getter(args, kwargs) for param in additional_params: kwargs.pop(param.name) + + if (dep := dependencies.get("self")) and dep.type_hint == func: + klass = inspect._findclass(dep.type_hint) # noqa: SLF001 + dependencies["self"] = DependencyKey(klass, dep.component) + solved = { name: await container.get( dep.type_hint, component=dep.component, @@ -198,6 +210,11 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: container = container_getter(args, kwargs) for param in additional_params: kwargs.pop(param.name) + + if (dep := dependencies.get("self")) and dep.type_hint == func: + klass = inspect._findclass(dep.type_hint) # noqa: SLF001 + dependencies["self"] = DependencyKey(klass, dep.component) + solved = { name: await container.get( dep.type_hint, component=dep.component, diff --git a/tests/integrations/fastapi/test_fastapi.py b/tests/integrations/fastapi/test_fastapi.py index 162084f4b..a3a6212ac 100644 --- a/tests/integrations/fastapi/test_fastapi.py +++ b/tests/integrations/fastapi/test_fastapi.py @@ -7,7 +7,7 @@ from asgi_lifespan import LifespanManager from fastapi.testclient import TestClient -from dishka import make_async_container +from dishka import Scope, make_async_container, provide from dishka.integrations.fastapi import ( DishkaRoute, FromDishka, @@ -18,6 +18,7 @@ APP_DEP_VALUE, REQUEST_DEP_VALUE, AppDep, + AppMock, AppProvider, RequestDep, ) @@ -68,6 +69,44 @@ async def test_app_dependency(app_provider: AppProvider, app_factory): app_provider.app_released.assert_called() +class Wrapper: + def __init__( + self, + a: AppDep, + app_mock: AppMock, + ): + self.a = a + self.app_mock = app_mock + + async def get_with_app( + self, + a: FromDishka[AppDep], + app_mock: FromDishka[AppMock], + ) -> None: + assert self.a == a + assert self.app_mock == app_mock + app_mock(a) + + +class LocalProvider(AppProvider): + scope = Scope.APP + + wrapper = provide(Wrapper) + + +@pytest.mark.parametrize("app_factory", [ + dishka_app, dishka_auto_app, +]) +@pytest.mark.asyncio +async def test_app_dependency_class(app_factory): + app_provider = LocalProvider() + async with app_factory(Wrapper.get_with_app, app_provider) as client: + client.get("/") + app_provider.app_mock.assert_called_with(APP_DEP_VALUE) + app_provider.app_released.assert_not_called() + app_provider.app_released.assert_called() + + async def get_with_request( a: FromDishka[RequestDep], mock: FromDishka[Mock],