diff --git a/src/dishka/code_tools/factory_compiler.py b/src/dishka/code_tools/factory_compiler.py index f10a4e2d5..d099307ad 100644 --- a/src/dishka/code_tools/factory_compiler.py +++ b/src/dishka/code_tools/factory_compiler.py @@ -1,10 +1,11 @@ +import asyncio import contextlib from collections.abc import Callable, Iterator from contextlib import AbstractContextManager from typing import Any, TypeAlias, cast from dishka.code_tools.code_builder import CodeBuilder -from dishka.container_objects import CompiledFactory +from dishka.container_objects import CompiledFactory, _Pending from dishka.dependency_source import Factory from dishka.entities.component import Component from dishka.entities.factory_type import FactoryType @@ -27,11 +28,11 @@ class FactoryBuilder(CodeBuilder): def __init__( - self, - *, - is_async: bool, - getter_prefix: str, - container_key: DependencyKey, + self, + *, + is_async: bool, + getter_prefix: str, + container_key: DependencyKey, ) -> None: super().__init__(is_async=is_async) self.provides_name = "" @@ -82,21 +83,50 @@ def getter( return self.await_( self.call( factory, - "getter", "exits", "cache", "context", "container", + "getter", + "exits", + "cache", + "context", + "container", ), ) - return self.await_(self.call( - "getter", self.global_(obj.as_compilation_key()), - )) + return self.await_( + self.call( + "getter", + self.global_(obj.as_compilation_key()), + ) + ) + + def _is_pending_aware(self, factory: Factory) -> bool: + return ( + self._is_async + and factory.cache + and factory.type is not FactoryType.CONTEXT + ) def cache(self, factory: Factory) -> None: if factory.cache and factory.type is not FactoryType.CONTEXT: self.assign_expr(f"cache[{self.cache_key}]", "solved") + if self._is_pending_aware(factory): + self.statement("_pending.set_result(solved)") def return_if_cached(self, factory: Factory) -> None: if factory.cache and factory.type is not FactoryType.CONTEXT: with self.if_(f"{self.cache_key} in cache"): - self.return_(f"cache[{self.cache_key}]") + if self._is_pending_aware(factory): + pending_cls = self.global_(_Pending, "_Pending") + self.assign_local("_cached", f"cache[{self.cache_key}]") + with self.if_(f"isinstance(_cached, {pending_cls})"): + self.return_(self.await_("_cached")) + self.return_("_cached") + else: + self.return_(f"cache[{self.cache_key}]") + + def place_pending(self, factory: Factory) -> None: + if self._is_pending_aware(factory): + pending_cls = self.global_(_Pending, "_Pending") + self.assign_local("_pending", self.call(pending_cls)) + self.assign_expr(f"cache[{self.cache_key}]", "_pending") def assign_solved(self, expr: str) -> None: self.assign_local("solved", expr) @@ -133,9 +163,30 @@ def when( f" with marker {marker}", ) return self.getter( - DependencyKey(marker, component), compiled_deps, + DependencyKey(marker, component), + compiled_deps, ) + def getter_coro( + self, + obj: DependencyKey, + compiled_deps: dict[DependencyKey, CompiledFactory], + ) -> str: + if obj in compiled_deps: + factory = self.global_(compiled_deps[obj]) + return self.call( + factory, + "getter", + "exits", + "cache", + "context", + "container", + ) + return self.call( + "getter", + self.global_(obj.as_compilation_key()), + ) + def build_getter(self) -> CompiledFactory: name = f"<{self.getter_name}{'_async' if self.async_str else ''}>" return cast(CompiledFactory, self.compile(name)[self.getter_name]) @@ -295,7 +346,9 @@ def _collection_factory_body( assigned = False for variant in factory.when_dependencies: condition = builder.when( - variant.when_override, variant.when_component, compiled_deps, + variant.when_override, + variant.when_component, + compiled_deps, ) if condition: if not assigned: @@ -363,7 +416,9 @@ def _select_when_dependency( first = True for variant in factory.when_dependencies: condition = builder.when( - variant.when_override, factory.when_component, compiled_deps, + variant.when_override, + factory.when_component, + compiled_deps, ) solved_value = builder.getter(variant.provides, compiled_deps) if first and not condition: @@ -383,46 +438,126 @@ def _select_when_dependency( return False +def _is_sync_dep( + dep: DependencyKey, + compiled_deps: dict[DependencyKey, CompiledFactory], + container_key: DependencyKey, +) -> bool: + return ( + dep.is_const() + or dep.type_hint is DependencyKey + or dep == container_key + ) + + +def _make_gathered_source_call( + builder: FactoryBuilder, + factory: Factory, + compiled_deps: dict[DependencyKey, CompiledFactory], +) -> str | None: + async_deps: list[tuple[int, int | str, DependencyKey]] = [] + pos_exprs: dict[int, str] = {} + kw_exprs: dict[str, str] = {} + + gather_idx = 0 + for i, dep in enumerate(factory.dependencies): + if _is_sync_dep(dep, compiled_deps, builder.container_key): + pos_exprs[i] = builder.getter(dep, compiled_deps) + else: + async_deps.append((gather_idx, i, dep)) + gather_idx += 1 + + for name, dep in factory.kw_dependencies.items(): + if _is_sync_dep(dep, compiled_deps, builder.container_key): + kw_exprs[name] = builder.getter(dep, compiled_deps) + else: + async_deps.append((gather_idx, name, dep)) + gather_idx += 1 + + if len(async_deps) < 2: # noqa: PLR2004 + return None + + coro_exprs = [ + builder.getter_coro(dep, compiled_deps) for _, _, dep in async_deps + ] + gather_func = builder.global_(asyncio.gather, "asyncio_gather") + builder.assign_local( + "_deps", + builder.await_(builder.call(gather_func, *coro_exprs)), + ) + + for gather_i, key_or_idx, _ in async_deps: + if isinstance(key_or_idx, int): + pos_exprs[key_or_idx] = f"_deps[{gather_i}]" + else: + kw_exprs[key_or_idx] = f"_deps[{gather_i}]" + + ordered_pos = [pos_exprs[i] for i in range(len(factory.dependencies))] + return builder.call( + builder.global_(factory.source), + *ordered_pos, + **kw_exprs, + ) + + def _make_body( - builder: FactoryBuilder, - factory: Factory, - compiled_deps: dict[DependencyKey, CompiledFactory], + builder: FactoryBuilder, + factory: Factory, + compiled_deps: dict[DependencyKey, CompiledFactory], + *, + can_gather: bool = False, ) -> None: if factory.type is FactoryType.COLLECTION: _collection_factory_body(builder, factory, compiled_deps) else: has_default = _select_when_dependency( - builder, factory, compiled_deps, + builder, + factory, + compiled_deps, ) if not has_default: - source_call = builder.call( - builder.global_(factory.source), - *( - builder.getter(dep, compiled_deps) - for dep in factory.dependencies - ), - **{ - name: builder.getter(dep, compiled_deps) - for name, dep in factory.kw_dependencies.items() - }, - ) + source_call = None + if can_gather and builder.async_str: + source_call = _make_gathered_source_call( + builder, + factory, + compiled_deps, + ) + if source_call is None: + source_call = builder.call( + builder.global_(factory.source), + *( + builder.getter(dep, compiled_deps) + for dep in factory.dependencies + ), + **{ + name: builder.getter(dep, compiled_deps) + for name, dep in factory.kw_dependencies.items() + }, + ) body_generator = BODY_GENERATORS[factory.type] if factory.when_dependencies: # conditions generated with builder.else_(): body_generator( - builder, source_call, factory, compiled_deps, + builder, + source_call, + factory, + compiled_deps, ) else: # no options at all body_generator( - builder, source_call, factory, compiled_deps, + builder, + source_call, + factory, + compiled_deps, ) def _has_deps(factory: Factory) -> bool: return bool( - factory.dependencies or - factory.kw_dependencies or - factory.when_dependencies, + factory.dependencies + or factory.kw_dependencies + or factory.when_dependencies, ) @@ -432,6 +567,7 @@ def compile_factory( is_async: bool, compiled_deps: dict[DependencyKey, CompiledFactory], container_key: DependencyKey, + can_gather: bool = False, ) -> CompiledFactory: if not is_async and factory.type in ASYNC_TYPES: raise UnsupportedFactoryError(factory) @@ -445,15 +581,56 @@ def compile_factory( ) builder.register_provides(factory.provides) + pending_aware = ( + is_async and factory.cache and factory.type is not FactoryType.CONTEXT + ) + with builder.make_getter(): builder.return_if_cached(factory) - if _has_deps(factory): - with builder.handle_no_dep(factory): - _make_body(builder, factory, compiled_deps) + builder.place_pending(factory) + if pending_aware: + with builder.try_(): + if _has_deps(factory): + with builder.handle_no_dep(factory): + _make_body( + builder, + factory, + compiled_deps, + can_gather=can_gather, + ) + else: + _make_body( + builder, + factory, + compiled_deps, + can_gather=can_gather, + ) + builder.cache(factory) + builder.return_("solved") + with builder.except_(BaseException, as_="_exc"): + builder.statement( + f"cache.pop({builder.cache_key}, None)", + ) + builder.statement("_pending.set_exception(_exc)") + builder.raise_() else: - _make_body(builder, factory, compiled_deps) - builder.cache(factory) - builder.return_("solved") + if _has_deps(factory): + with builder.handle_no_dep(factory): + _make_body( + builder, + factory, + compiled_deps, + can_gather=can_gather, + ) + else: + _make_body( + builder, + factory, + compiled_deps, + can_gather=can_gather, + ) + builder.cache(factory) + builder.return_("solved") return builder.build_getter() @@ -473,7 +650,9 @@ def compile_activation( builder.register_provides(factory.provides) with builder.make_getter(): condition = builder.when( - factory.when_active, factory.when_component, compiled_deps, + factory.when_active, + factory.when_component, + compiled_deps, ) if not condition: builder.return_(builder.global_(True)) diff --git a/src/dishka/container_objects.py b/src/dishka/container_objects.py index c2c55e55e..72ba38ef4 100644 --- a/src/dishka/container_objects.py +++ b/src/dishka/container_objects.py @@ -1,9 +1,35 @@ +import asyncio from abc import abstractmethod from collections.abc import AsyncGenerator, Callable, Generator from typing import Any, Protocol, TypeAlias from dishka.entities.key import CompilationKey + +class _Pending: + """Sentinel placed in the cache while an async factory is being resolved. + + If a concurrent coroutine (from asyncio.gather) tries to resolve the same + dependency, it finds this sentinel and awaits the embedded Future instead + of creating a duplicate. + """ + + __slots__ = ("_future",) + + def __init__(self) -> None: + loop = asyncio.get_running_loop() + self._future: asyncio.Future[Any] = loop.create_future() + + def set_result(self, value: Any) -> None: + self._future.set_result(value) + + def set_exception(self, exc: BaseException) -> None: + self._future.set_exception(exc) + + def __await__(self) -> Generator[Any, None, Any]: # type: ignore[override] + return self._future.__await__() + + Exit: TypeAlias = tuple[ Generator[Any, Any, Any] | None, AsyncGenerator[Any, Any] | None, @@ -13,11 +39,11 @@ class CompiledFactory(Protocol): @abstractmethod def __call__( - self, - getter: Callable[[CompilationKey], Any] | None, - exits: list[Exit], - cache: Any, - context: Any, - container: Any, + self, + getter: Callable[[CompilationKey], Any] | None, + exits: list[Exit], + cache: Any, + context: Any, + container: Any, ) -> Any: raise NotImplementedError diff --git a/src/dishka/registry.py b/src/dishka/registry.py index 5572c2c30..5d103112b 100644 --- a/src/dishka/registry.py +++ b/src/dishka/registry.py @@ -63,12 +63,12 @@ class Registry: ) def __init__( - self, - scope: BaseScope, - *, - has_fallback: bool, - container_key: DependencyKey, - child_registry: "Registry | None" = None, + self, + scope: BaseScope, + *, + has_fallback: bool, + container_key: DependencyKey, + child_registry: "Registry | None" = None, ) -> None: self.scope = scope self.factories: dict[DependencyKey, Factory] = {} @@ -81,9 +81,9 @@ def __init__( self.child_registry = child_registry def add_factory( - self, - factory: Factory, - provides: DependencyKey | None = None, + self, + factory: Factory, + provides: DependencyKey | None = None, ) -> None: if provides is None: provides = factory.provides @@ -98,21 +98,23 @@ def add_factory( self.factories[origin_key] = factory def collect_deps(self, factory: Factory) -> list[DependencyKey]: - return list(itertools.chain( - factory.dependencies, - factory.kw_dependencies.values(), - (f.provides for f in factory.when_dependencies), - ( - DependencyKey(m, f.when_component) - for f in factory.when_dependencies - for m in unpack_marker(f.when_override) - ), - ( - DependencyKey(m, factory.when_component) - for marker in (factory.when_active, factory.when_override) - for m in unpack_marker(marker) - ), - )) + return list( + itertools.chain( + factory.dependencies, + factory.kw_dependencies.values(), + (f.provides for f in factory.when_dependencies), + ( + DependencyKey(m, f.when_component) + for f in factory.when_dependencies + for m in unpack_marker(f.when_override) + ), + ( + DependencyKey(m, factory.when_component) + for marker in (factory.when_active, factory.when_override) + for m in unpack_marker(marker) + ), + ) + ) def _compile_deps( self, @@ -137,7 +139,8 @@ def _compile_deps_async( return res def get_compiled( - self, dependency: CompilationKey, + self, + dependency: CompilationKey, ) -> CompiledFactory | None: try: return self.compiled[dependency] @@ -167,8 +170,42 @@ def get_compiled( self.compiled[dependency] = compiled return compiled + def _get_transitive_dep_keys( + self, + dep_key: DependencyKey, + visited: set[DependencyKey] | None = None, + ) -> set[DependencyKey]: + if visited is None: + visited = set() + if dep_key in visited: + return visited + factory = self.get_factory(dep_key) + if factory is None: + return visited + visited.add(dep_key) + for sub_dep in self.collect_deps(factory): + self._get_transitive_dep_keys(sub_dep, visited) + return visited + + def _can_gather_deps(self, factory: Factory) -> bool: + all_deps = list(factory.dependencies) + list( + factory.kw_dependencies.values(), + ) + resolvable = [ + d + for d in all_deps + if not d.is_const() + and d.type_hint is not DependencyKey + and d != self.container_key + ] + if len(resolvable) < 2: # noqa: PLR2004 + return False + + return True + def get_compiled_async( - self, dependency: CompilationKey, + self, + dependency: CompilationKey, ) -> CompiledFactory | None: try: return self.compiled_async[dependency] @@ -193,12 +230,14 @@ def get_compiled_async( is_async=True, compiled_deps=self._compile_deps_async(factory), container_key=self.container_key, + can_gather=self._can_gather_deps(factory), ) self.compiled_async[dependency] = compiled return compiled def get_compiled_activation( - self, dependency: CompilationKey, + self, + dependency: CompilationKey, ) -> CompiledFactory | None: try: return self.compiled_activation[dependency] @@ -229,7 +268,8 @@ def get_compiled_activation( return compiled def get_compiled_activation_async( - self, dependency: CompilationKey, + self, + dependency: CompilationKey, ) -> CompiledFactory | None: try: return self.compiled_activation_async[dependency] @@ -278,12 +318,9 @@ def get_factory(self, dependency: DependencyKey) -> Factory | None: ) factory = self.factories.get(origin_key) - if ( - not factory or - not is_broader_or_same_type( - factory.provides.type_hint, - dependency.type_hint, - ) + if not factory or not is_broader_or_same_type( + factory.provides.type_hint, + dependency.type_hint, ): return None factory = self._specialize_generic(factory, dependency) @@ -361,7 +398,9 @@ def _get_type_var_factory(self, dependency: DependencyKey) -> Factory: ) def _specialize_generic( - self, factory: Factory, dependency_key: DependencyKey, + self, + factory: Factory, + dependency_key: DependencyKey, ) -> Factory: params_replacement = get_typevar_replacement( factory.provides.type_hint, @@ -373,25 +412,29 @@ def _specialize_generic( if isinstance(hint, TypeVar): hint = params_replacement[hint] elif get_origin(hint) and (type_vars := get_type_vars(hint)): - hint = hint[tuple( - params_replacement[param] - for param in type_vars - )] - new_dependencies.append(DependencyKey( - hint, source_dependency.component, source_dependency.depth, - )) + hint = hint[ + tuple(params_replacement[param] for param in type_vars) + ] + new_dependencies.append( + DependencyKey( + hint, + source_dependency.component, + source_dependency.depth, + ) + ) new_kw_dependencies: dict[str, DependencyKey] = {} for name, source_dependency in factory.kw_dependencies.items(): hint = source_dependency.type_hint if isinstance(hint, TypeVar): hint = params_replacement[hint] elif get_origin(hint) and (type_vars := get_type_vars(hint)): - hint = hint[tuple( - params_replacement[param] - for param in type_vars - )] + hint = hint[ + tuple(params_replacement[param] for param in type_vars) + ] new_kw_dependencies[name] = DependencyKey( - hint, source_dependency.component, source_dependency.depth, + hint, + source_dependency.component, + source_dependency.depth, ) return Factory( source=factory.source, diff --git a/tests/unit/container/test_async_gather.py b/tests/unit/container/test_async_gather.py new file mode 100644 index 000000000..947e7a5ba --- /dev/null +++ b/tests/unit/container/test_async_gather.py @@ -0,0 +1,626 @@ +import asyncio +from collections.abc import AsyncIterator +from unittest.mock import Mock + +import pytest + +from dishka import ( + Provider, + Scope, + make_async_container, + provide, +) + +# --- Helpers --- + + +async def _timed_get(container, key): + """Get a dependency and return (result, elapsed_seconds).""" + loop = asyncio.get_running_loop() + start = loop.time() + result = await container.get(key) + return result, loop.time() - start + + +def _assert_fast(elapsed, threshold, label="deps should be gathered"): + assert elapsed < threshold, ( + f"Took {elapsed:.3f}s, expected < {threshold}s ({label})" + ) + + +# --- Diamond helper mixin --- +# Many tests need the same float→int, str→int, bytes→(float,str) diamond. + + +class DiamondConsumersMixin: + """Provides float(int), str(int), bytes(float, str) diamond consumers.""" + + @provide(scope=Scope.APP) + async def get_float(self, v: int) -> float: + return float(v) + + @provide(scope=Scope.APP) + async def get_str(self, v: int) -> str: + return str(v) + + @provide(scope=Scope.APP) + async def get_bytes(self, a: float, b: str) -> bytes: + return f"{a}-{b}".encode() + + +# --- Basic gathering tests --- + + +class IndependentDepsProvider(Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.1) + return 1 + + @provide(scope=Scope.APP) + async def get_float(self) -> float: + await asyncio.sleep(0.1) + return 2.0 + + @provide(scope=Scope.APP) + async def get_str(self, a: int, b: float) -> str: + return f"{a}-{b}" + + +@pytest.mark.asyncio +async def test_independent_deps_gathered(): + container = make_async_container(IndependentDepsProvider()) + result, elapsed = await _timed_get(container, str) + assert result == "1-2.0" + # With gathering: ~0.1s. Without: ~0.2s. Use 0.15 as threshold. + _assert_fast(elapsed, 0.15) + await container.close() + + +class SharedTransitiveDepsProvider(DiamondConsumersMixin, Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.05) + return 42 + + +@pytest.mark.asyncio +async def test_shared_transitive_deps_gathered(): + """Shared cached transitive deps are gathered via pending sentinel.""" + container = make_async_container(SharedTransitiveDepsProvider()) + result = await container.get(bytes) + assert result == b"42.0-42" + await container.close() + + +class ThreeIndependentProvider(Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.1) + return 1 + + @provide(scope=Scope.APP) + async def get_float(self) -> float: + await asyncio.sleep(0.1) + return 2.0 + + @provide(scope=Scope.APP) + async def get_bytes(self) -> bytes: + await asyncio.sleep(0.1) + return b"3" + + @provide(scope=Scope.APP) + async def get_str(self, a: int, b: float, c: bytes) -> str: + return f"{a}-{b}-{c!r}" + + +@pytest.mark.asyncio +async def test_three_independent_deps_gathered(): + container = make_async_container(ThreeIndependentProvider()) + result, elapsed = await _timed_get(container, str) + assert result == "1-2.0-b'3'" + # With gathering: ~0.1s. Without: ~0.3s. + _assert_fast(elapsed, 0.15) + await container.close() + + +class MixedSyncAsyncProvider(Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.1) + return 1 + + @provide(scope=Scope.APP) + async def get_float(self) -> float: + await asyncio.sleep(0.1) + return 2.0 + + @provide(scope=Scope.APP) + def get_bytes(self) -> bytes: + return b"sync" + + @provide(scope=Scope.APP) + async def get_str(self, a: int, b: float, c: bytes) -> str: + return f"{a}-{b}-{c!r}" + + +@pytest.mark.asyncio +async def test_mixed_sync_async_deps(): + container = make_async_container(MixedSyncAsyncProvider()) + result, elapsed = await _timed_get(container, str) + assert result == "1-2.0-b'sync'" + _assert_fast(elapsed, 0.15) + await container.close() + + +class CachedGatherProvider(Provider): + def __init__(self, mock: Mock): + super().__init__() + self.mock = mock + + @provide(scope=Scope.APP) + async def get_int(self) -> int: + self.mock() + await asyncio.sleep(0.05) + return 1 + + @provide(scope=Scope.APP) + async def get_float(self) -> float: + await asyncio.sleep(0.05) + return 2.0 + + @provide(scope=Scope.APP) + async def get_str(self, a: int, b: float) -> str: + return f"{a}-{b}" + + +@pytest.mark.asyncio +async def test_cached_deps_not_recreated(): + mock = Mock() + container = make_async_container(CachedGatherProvider(mock)) + + result1 = await container.get(str) + result2 = await container.get(str) + assert result1 == result2 == "1-2.0" + mock.assert_called_once() + await container.close() + + +@pytest.mark.asyncio +async def test_gather_works_without_lock(): + container = make_async_container( + IndependentDepsProvider(), + lock_factory=None, + ) + result, elapsed = await _timed_get(container, str) + assert result == "1-2.0" + _assert_fast(elapsed, 0.15) + await container.close() + + +class KwOnlyDepsProvider(Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.1) + return 10 + + @provide(scope=Scope.APP) + async def get_float(self) -> float: + await asyncio.sleep(0.1) + return 1.5 + + @provide(scope=Scope.APP) + async def get_str(self, *, a: int, b: float) -> str: + return f"{a}+{b}" + + +@pytest.mark.asyncio +async def test_keyword_only_deps_gathered(): + container = make_async_container(KwOnlyDepsProvider()) + result, elapsed = await _timed_get(container, str) + assert result == "10+1.5" + _assert_fast(elapsed, 0.15) + await container.close() + + +class MixedScopeProvider(Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.05) + return 1 + + @provide(scope=Scope.REQUEST) + async def get_float(self) -> float: + await asyncio.sleep(0.05) + return 2.0 + + @provide(scope=Scope.REQUEST) + async def get_str(self, a: int, b: float) -> str: + return f"{a}-{b}" + + +@pytest.mark.asyncio +async def test_mixed_scopes(): + container = make_async_container(MixedScopeProvider()) + async with container(scope=Scope.REQUEST) as request_container: + result = await request_container.get(str) + assert result == "1-2.0" + await container.close() + + +class ErrorProvider(Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.05) + msg = "int factory failed" + raise ValueError(msg) + + @provide(scope=Scope.APP) + async def get_float(self) -> float: + await asyncio.sleep(0.05) + return 2.0 + + @provide(scope=Scope.APP) + async def get_str(self, a: int, b: float) -> str: + return f"{a}-{b}" + + +@pytest.mark.asyncio +async def test_error_propagation(): + container = make_async_container(ErrorProvider()) + with pytest.raises(ValueError, match="int factory failed"): + await container.get(str) + await container.close() + + +class ConcurrentAccessProvider(Provider): + def __init__(self, call_counts: dict): + super().__init__() + self.call_counts = call_counts + + @provide(scope=Scope.APP) + async def get_int(self) -> int: + self.call_counts["int"] = self.call_counts.get("int", 0) + 1 + await asyncio.sleep(0.05) + return 1 + + @provide(scope=Scope.APP) + async def get_float(self) -> float: + self.call_counts["float"] = self.call_counts.get("float", 0) + 1 + await asyncio.sleep(0.05) + return 2.0 + + @provide(scope=Scope.APP) + async def get_str(self, a: int, b: float) -> str: + return f"{a}-{b}" + + +@pytest.mark.asyncio +async def test_concurrent_container_access(): + call_counts: dict = {} + container = make_async_container(ConcurrentAccessProvider(call_counts)) + + results = await asyncio.gather( + container.get(str), + container.get(str), + container.get(str), + ) + for r in results: + assert r == "1-2.0" + await container.close() + + +class SingleAsyncDepProvider(Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.05) + return 42 + + @provide(scope=Scope.APP) + async def get_str(self, a: int) -> str: + return str(a) + + +@pytest.mark.asyncio +async def test_single_async_dep(): + container = make_async_container(SingleAsyncDepProvider()) + result = await container.get(str) + assert result == "42" + await container.close() + + +class MixedPosKwProvider(Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.1) + return 5 + + @provide(scope=Scope.APP) + async def get_float(self) -> float: + await asyncio.sleep(0.1) + return 1.5 + + @provide(scope=Scope.APP) + async def get_bytes(self) -> bytes: + await asyncio.sleep(0.1) + return b"data" + + @provide(scope=Scope.APP) + async def get_str(self, a: int, *, b: float, c: bytes) -> str: + return f"{a}-{b}-{c!r}" + + +@pytest.mark.asyncio +async def test_mixed_positional_and_keyword(): + container = make_async_container(MixedPosKwProvider()) + result, elapsed = await _timed_get(container, str) + assert result == "5-1.5-b'data'" + _assert_fast(elapsed, 0.15) + await container.close() + + +class AllSyncInAsyncProvider(Provider): + @provide(scope=Scope.APP) + def get_int(self) -> int: + return 1 + + @provide(scope=Scope.APP) + def get_float(self) -> float: + return 2.0 + + @provide(scope=Scope.APP) + def get_str(self, a: int, b: float) -> str: + return f"{a}-{b}" + + +@pytest.mark.asyncio +async def test_all_sync_in_async_container(): + container = make_async_container(AllSyncInAsyncProvider()) + result = await container.get(str) + assert result == "1-2.0" + await container.close() + + +# --- Diamond pattern (shared transitive dep) tests --- + + +class DiamondSingleCreationProvider(DiamondConsumersMixin, Provider): + def __init__(self, call_counts: dict): + super().__init__() + self.call_counts = call_counts + + @provide(scope=Scope.APP) + async def get_int(self) -> int: + self.call_counts["int"] = self.call_counts.get("int", 0) + 1 + await asyncio.sleep(0.1) + return 42 + + +@pytest.mark.asyncio +async def test_diamond_single_creation(): + """Shared dep (int) is created once via pending sentinel.""" + call_counts: dict = {} + container = make_async_container( + DiamondSingleCreationProvider(call_counts), + ) + result = await container.get(bytes) + assert result == b"42.0-42" + assert call_counts["int"] == 1 + await container.close() + + +class DiamondTimingProvider(DiamondConsumersMixin, Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.1) + return 1 + + @provide(scope=Scope.APP) + async def get_float(self, v: int) -> float: + await asyncio.sleep(0.1) + return float(v) + + @provide(scope=Scope.APP) + async def get_str(self, v: int) -> str: + await asyncio.sleep(0.1) + return str(v) + + @provide(scope=Scope.APP) + async def get_bytes(self, a: float, b: str) -> bytes: + return f"{a}-{b}".encode() + + +@pytest.mark.asyncio +async def test_diamond_concurrent_timing(): + """Diamond deps (float, str) are gathered concurrently.""" + container = make_async_container(DiamondTimingProvider()) + result, elapsed = await _timed_get(container, bytes) + assert result == b"1.0-1" + # Sequential: int(0.1) + float(0.1) + str(0.1) = 0.3s + # Gathered: int(0.1) + max(float(0.1), str(0.1)) = 0.2s + _assert_fast(elapsed, 0.25, "diamond should be gathered") + await container.close() + + +class DiamondErrorProvider(DiamondConsumersMixin, Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.05) + msg = "shared dep failed" + raise ValueError(msg) + + +@pytest.mark.asyncio +async def test_diamond_error_propagation(): + """Error in shared transitive dep propagates correctly through gather.""" + container = make_async_container(DiamondErrorProvider()) + with pytest.raises(ValueError, match="shared dep failed"): + await container.get(bytes) + await container.close() + + +class DiamondAsyncGeneratorProvider(DiamondConsumersMixin, Provider): + def __init__(self, call_counts: dict): + super().__init__() + self.call_counts = call_counts + + @provide(scope=Scope.APP) + async def get_int(self) -> AsyncIterator[int]: + self.call_counts["int"] = self.call_counts.get("int", 0) + 1 + await asyncio.sleep(0.05) + yield 42 + + +@pytest.mark.asyncio +async def test_diamond_async_generator_cached(): + """Shared async generator dep is created once, exit registered once.""" + call_counts: dict = {} + container = make_async_container( + DiamondAsyncGeneratorProvider(call_counts), + ) + result = await container.get(bytes) + assert result == b"42.0-42" + assert call_counts["int"] == 1 + await container.close() + + +# --- cache=False shared transitive dep tests --- + + +class DiamondUncachedFactoryProvider(DiamondConsumersMixin, Provider): + def __init__(self, call_counts: dict): + super().__init__() + self.call_counts = call_counts + self._counter = 0 + + @provide(scope=Scope.APP, cache=False) + async def get_int(self) -> int: + self.call_counts["int"] = self.call_counts.get("int", 0) + 1 + self._counter += 1 + await asyncio.sleep(0.1) + return self._counter + + +@pytest.mark.asyncio +async def test_diamond_uncached_factory_creates_per_consumer(): + """cache=False shared dep is created independently per consumer branch.""" + call_counts: dict = {} + container = make_async_container( + DiamondUncachedFactoryProvider(call_counts), + ) + await container.get(bytes) + assert call_counts["int"] == 2 + await container.close() + + +@pytest.mark.asyncio +async def test_diamond_uncached_factory_concurrent_timing(): + """cache=False shared dep branches run concurrently when gathered.""" + call_counts: dict = {} + container = make_async_container( + DiamondUncachedFactoryProvider(call_counts), + ) + _result, elapsed = await _timed_get(container, bytes) + # Sequential: int(0.1) + int(0.1) = 0.2s + # Gathered: max(int(0.1), int(0.1)) = 0.1s + _assert_fast(elapsed, 0.15, "should gather uncached") + assert call_counts["int"] == 2 + await container.close() + + +class DiamondUncachedAsyncGenProvider(DiamondConsumersMixin, Provider): + def __init__(self, call_counts: dict, closed: list): + super().__init__() + self.call_counts = call_counts + self.closed = closed + self._counter = 0 + + @provide(scope=Scope.APP, cache=False) + async def get_int(self) -> AsyncIterator[int]: + self.call_counts["int"] = self.call_counts.get("int", 0) + 1 + self._counter += 1 + val = self._counter + await asyncio.sleep(0.05) + yield val + self.closed.append(val) + + +@pytest.mark.asyncio +async def test_diamond_uncached_async_generator_per_consumer(): + """cache=False async generator: each branch gets own instance.""" + call_counts: dict = {} + closed: list = [] + container = make_async_container( + DiamondUncachedAsyncGenProvider(call_counts, closed), + ) + await container.get(bytes) + assert call_counts["int"] == 2 + await container.close() + assert len(closed) == 2 + + +class DiamondUncachedErrorProvider(DiamondConsumersMixin, Provider): + @provide(scope=Scope.APP, cache=False) + async def get_int(self) -> int: + await asyncio.sleep(0.05) + msg = "uncached dep fails" + raise ValueError(msg) + + +@pytest.mark.asyncio +async def test_diamond_uncached_error_one_branch(): + """cache=False shared dep: error propagates correctly through gather.""" + container = make_async_container(DiamondUncachedErrorProvider()) + with pytest.raises(ValueError, match="uncached dep fails"): + await container.get(bytes) + await container.close() + + +class DiamondMixedCacheProvider(DiamondConsumersMixin, Provider): + def __init__(self, call_counts: dict): + super().__init__() + self.call_counts = call_counts + self._uncached_counter = 0 + + @provide(scope=Scope.APP) + async def get_int(self) -> int: + self.call_counts["int"] = self.call_counts.get("int", 0) + 1 + await asyncio.sleep(0.05) + return 42 + + @provide(scope=Scope.APP, cache=False) + async def get_complex(self) -> complex: + self.call_counts["complex"] = self.call_counts.get("complex", 0) + 1 + self._uncached_counter += 1 + await asyncio.sleep(0.1) + return complex(self._uncached_counter, 0) + + @provide(scope=Scope.APP) + async def get_float(self, v: int, c: complex) -> float: + return float(v) + c.real + + @provide(scope=Scope.APP) + async def get_str(self, v: int, c: complex) -> str: + return f"{v}+{c.real}" + + @provide(scope=Scope.APP) + async def get_bytes(self, a: float, b: str) -> bytes: + return f"{a}-{b}".encode() + + +@pytest.mark.asyncio +async def test_diamond_mixed_cached_uncached(): + """Mixed diamond: float and str branches run concurrently. + + int is cached (created once via pending sentinel), complex is uncached + (created per branch). Timing proves concurrency: sequential would take + int(0.05) + complex(0.1) + complex(0.1) = 0.25s; gathered takes + int(0.05) + max(complex(0.1), complex(0.1)) = 0.15s. + """ + call_counts: dict = {} + container = make_async_container(DiamondMixedCacheProvider(call_counts)) + _result, elapsed = await _timed_get(container, bytes) + assert call_counts["int"] == 1 # cached — created once + assert call_counts["complex"] == 2 # uncached — created per branch + _assert_fast(elapsed, 0.2, "should gather mixed") + await container.close() diff --git a/tests/unit/test_async_gather.py b/tests/unit/test_async_gather.py new file mode 100644 index 000000000..d40ae66d7 --- /dev/null +++ b/tests/unit/test_async_gather.py @@ -0,0 +1,317 @@ +from typing import Literal + +import pytest + +from dishka.code_tools.factory_compiler import _is_sync_dep +from dishka.dependency_source.factory import Factory +from dishka.entities.component import DEFAULT_COMPONENT +from dishka.entities.factory_type import FactoryType +from dishka.entities.key import DependencyKey +from dishka.entities.scope import Scope +from dishka.registry import Registry + + +def _make_key(type_hint, component=DEFAULT_COMPONENT): + return DependencyKey(type_hint, component) + + +def _make_factory( + provides, + dependencies=None, + kw_dependencies=None, + *, + cache=True, +): + return Factory( + source=lambda: None, + provides=provides, + dependencies=dependencies or [], + kw_dependencies=kw_dependencies or {}, + scope=Scope.APP, + type_=FactoryType.ASYNC_FACTORY, + is_to_bind=False, + cache=cache, + when_override=None, + when_active=None, + when_component=None, + when_dependencies=[], + ) + + +@pytest.fixture +def container_key(): + return _make_key(object) + + +@pytest.fixture +def registry(container_key): + return Registry( + scope=Scope.APP, + has_fallback=False, + container_key=container_key, + ) + + +# --- _is_sync_dep --- + + +def test_is_sync_dep_const(container_key): + dep = DependencyKey(Literal[42], DEFAULT_COMPONENT) + assert _is_sync_dep(dep, {}, container_key) is True + + +def test_is_sync_dep_dependency_key_type(container_key): + dep = DependencyKey(DependencyKey, DEFAULT_COMPONENT) + assert _is_sync_dep(dep, {}, container_key) is True + + +def test_is_sync_dep_container_key(container_key): + assert _is_sync_dep(container_key, {}, container_key) is True + + +def test_is_sync_dep_regular(container_key): + dep = _make_key(int) + assert _is_sync_dep(dep, {}, container_key) is False + + +# --- _get_transitive_dep_keys --- + + +def test_transitive_no_deps(registry): + key_a = _make_key(int) + registry.add_factory(_make_factory(key_a)) + + assert registry._get_transitive_dep_keys(key_a) == {key_a} + + +def test_transitive_chain(registry): + key_a = _make_key(int) + key_b = _make_key(float) + key_c = _make_key(str) + + registry.add_factory(_make_factory(key_c)) + registry.add_factory(_make_factory(key_b, dependencies=[key_c])) + registry.add_factory(_make_factory(key_a, dependencies=[key_b])) + + assert registry._get_transitive_dep_keys(key_a) == {key_a, key_b, key_c} + + +def test_transitive_cycle(registry): + key_a = _make_key(int) + key_b = _make_key(float) + + registry.add_factory(_make_factory(key_a, dependencies=[key_b])) + registry.add_factory(_make_factory(key_b, dependencies=[key_a])) + + assert registry._get_transitive_dep_keys(key_a) == {key_a, key_b} + + +def test_transitive_missing_factory(registry): + assert registry._get_transitive_dep_keys(_make_key(int)) == set() + + +def test_transitive_diamond(registry): + key_a = _make_key(int) + key_b = _make_key(float) + key_c = _make_key(str) + key_d = _make_key(bytes) + + registry.add_factory(_make_factory(key_d)) + registry.add_factory(_make_factory(key_b, dependencies=[key_d])) + registry.add_factory(_make_factory(key_c, dependencies=[key_d])) + registry.add_factory(_make_factory(key_a, dependencies=[key_b, key_c])) + + assert registry._get_transitive_dep_keys(key_a) == { + key_a, + key_b, + key_c, + key_d, + } + + +def test_transitive_kw_dependencies(registry): + key_a = _make_key(int) + key_b = _make_key(float) + + registry.add_factory(_make_factory(key_b)) + registry.add_factory(_make_factory(key_a, kw_dependencies={"x": key_b})) + + assert registry._get_transitive_dep_keys(key_a) == {key_a, key_b} + + +# --- _can_gather_deps --- + + +def test_can_gather_independent(registry): + key_a = _make_key(int) + key_b = _make_key(float) + key_target = _make_key(str) + + registry.add_factory(_make_factory(key_a)) + registry.add_factory(_make_factory(key_b)) + factory = _make_factory(key_target, dependencies=[key_a, key_b]) + registry.add_factory(factory) + + assert registry._can_gather_deps(factory) is True + + +def test_can_gather_shared_transitive_cached(registry): + key_shared = _make_key(bytes) + key_a = _make_key(int) + key_b = _make_key(float) + key_target = _make_key(str) + + registry.add_factory(_make_factory(key_shared)) + registry.add_factory(_make_factory(key_a, dependencies=[key_shared])) + registry.add_factory(_make_factory(key_b, dependencies=[key_shared])) + factory = _make_factory(key_target, dependencies=[key_a, key_b]) + registry.add_factory(factory) + + assert registry._can_gather_deps(factory) is True + + +def test_can_gather_shared_transitive_uncached(registry): + key_shared = _make_key(bytes) + key_a = _make_key(int) + key_b = _make_key(float) + key_target = _make_key(str) + + registry.add_factory(_make_factory(key_shared, cache=False)) + registry.add_factory(_make_factory(key_a, dependencies=[key_shared])) + registry.add_factory(_make_factory(key_b, dependencies=[key_shared])) + factory = _make_factory(key_target, dependencies=[key_a, key_b]) + registry.add_factory(factory) + + assert registry._can_gather_deps(factory) is True + + +def test_can_gather_single_dep(registry): + key_a = _make_key(int) + key_target = _make_key(str) + + registry.add_factory(_make_factory(key_a)) + factory = _make_factory(key_target, dependencies=[key_a]) + registry.add_factory(factory) + + assert registry._can_gather_deps(factory) is False + + +def test_can_gather_no_deps(registry): + key_target = _make_key(str) + factory = _make_factory(key_target) + registry.add_factory(factory) + + assert registry._can_gather_deps(factory) is False + + +def test_can_gather_const_excluded(registry): + key_a = _make_key(int) + const_dep = DependencyKey(Literal[42], DEFAULT_COMPONENT) + key_target = _make_key(str) + + registry.add_factory(_make_factory(key_a)) + factory = _make_factory(key_target, dependencies=[key_a, const_dep]) + registry.add_factory(factory) + + assert registry._can_gather_deps(factory) is False + + +def test_can_gather_container_key_excluded(container_key): + key_a = _make_key(int) + key_target = _make_key(str) + + reg = Registry( + scope=Scope.APP, + has_fallback=False, + container_key=container_key, + ) + reg.add_factory(_make_factory(key_a)) + factory = _make_factory(key_target, dependencies=[key_a, container_key]) + reg.add_factory(factory) + + assert reg._can_gather_deps(factory) is False + + +def test_can_gather_dependency_key_excluded(registry): + key_a = _make_key(int) + dk_dep = DependencyKey(DependencyKey, DEFAULT_COMPONENT) + key_target = _make_key(str) + + registry.add_factory(_make_factory(key_a)) + factory = _make_factory(key_target, dependencies=[key_a, dk_dep]) + registry.add_factory(factory) + + assert registry._can_gather_deps(factory) is False + + +def test_can_gather_three_independent(registry): + key_a = _make_key(int) + key_b = _make_key(float) + key_c = _make_key(bytes) + key_target = _make_key(str) + + registry.add_factory(_make_factory(key_a)) + registry.add_factory(_make_factory(key_b)) + registry.add_factory(_make_factory(key_c)) + factory = _make_factory( + key_target, + dependencies=[key_a, key_b, key_c], + ) + registry.add_factory(factory) + + assert registry._can_gather_deps(factory) is True + + +def test_can_gather_kw_deps(registry): + key_a = _make_key(int) + key_b = _make_key(float) + key_target = _make_key(str) + + registry.add_factory(_make_factory(key_a)) + registry.add_factory(_make_factory(key_b)) + factory = _make_factory( + key_target, + dependencies=[key_a], + kw_dependencies={"b": key_b}, + ) + registry.add_factory(factory) + + assert registry._can_gather_deps(factory) is True + + +def test_can_gather_mixed_pos_kw_shared_transitive_cached(registry): + key_shared = _make_key(bytes) + key_a = _make_key(int) + key_b = _make_key(float) + key_target = _make_key(str) + + registry.add_factory(_make_factory(key_shared)) + registry.add_factory(_make_factory(key_a, dependencies=[key_shared])) + registry.add_factory(_make_factory(key_b, dependencies=[key_shared])) + factory = _make_factory( + key_target, + dependencies=[key_a], + kw_dependencies={"b": key_b}, + ) + registry.add_factory(factory) + + assert registry._can_gather_deps(factory) is True + + +def test_can_gather_mixed_pos_kw_shared_transitive_uncached(registry): + key_shared = _make_key(bytes) + key_a = _make_key(int) + key_b = _make_key(float) + key_target = _make_key(str) + + registry.add_factory(_make_factory(key_shared, cache=False)) + registry.add_factory(_make_factory(key_a, dependencies=[key_shared])) + registry.add_factory(_make_factory(key_b, dependencies=[key_shared])) + factory = _make_factory( + key_target, + dependencies=[key_a], + kw_dependencies={"b": key_b}, + ) + registry.add_factory(factory) + + assert registry._can_gather_deps(factory) is True