diff --git a/src/dishka/code_tools/code_builder.py b/src/dishka/code_tools/code_builder.py index fa51f84d0..7b3bc2cea 100644 --- a/src/dishka/code_tools/code_builder.py +++ b/src/dishka/code_tools/code_builder.py @@ -28,6 +28,7 @@ def __init__(self, *, is_async: bool) -> None: self.async_str = "" self._is_async = is_async + self.is_async = is_async def _make_global_name(self, obj: Any, name: str | None = None) -> str: if name is None: @@ -175,7 +176,7 @@ def try_(self) -> AbstractContextManager[None]: def except_( self, - exception: type[Exception], + exception: type[BaseException], as_: str = "", ) -> AbstractContextManager[None]: name = self.global_(exception) diff --git a/src/dishka/code_tools/factory_compiler.py b/src/dishka/code_tools/factory_compiler.py index f10a4e2d5..dd0ee7b31 100644 --- a/src/dishka/code_tools/factory_compiler.py +++ b/src/dishka/code_tools/factory_compiler.py @@ -1,10 +1,11 @@ +import asyncio import contextlib from collections.abc import Callable, Iterator from contextlib import AbstractContextManager from typing import Any, TypeAlias, cast from dishka.code_tools.code_builder import CodeBuilder -from dishka.container_objects import CompiledFactory +from dishka.container_objects import CompiledFactory, _Pending from dishka.dependency_source import Factory from dishka.entities.component import Component from dishka.entities.factory_type import FactoryType @@ -89,14 +90,82 @@ def getter( "getter", self.global_(obj.as_compilation_key()), )) + def is_awaitable_dep( + self, + obj: DependencyKey, + compiled_deps: dict[DependencyKey, CompiledFactory], + ) -> bool: + if obj.is_const(): + return False + if obj.type_hint is DependencyKey: + return False + if obj == self.container_key: + return False + return True + + def getter_coro( + self, + obj: DependencyKey, + compiled_deps: dict[DependencyKey, CompiledFactory], + ) -> str: + if obj in compiled_deps: + factory = self.global_(compiled_deps[obj]) + return self.call( + factory, + "getter", "exits", "cache", "context", "container", + ) + return self.call( + "getter", self.global_(obj.as_compilation_key()), + ) + def cache(self, factory: Factory) -> None: if factory.cache and factory.type is not FactoryType.CONTEXT: self.assign_expr(f"cache[{self.cache_key}]", "solved") def return_if_cached(self, factory: Factory) -> None: if factory.cache and factory.type is not FactoryType.CONTEXT: - with self.if_(f"{self.cache_key} in cache"): - self.return_(f"cache[{self.cache_key}]") + if self.is_async: + pending_cls = self.global_(_Pending, "_Pending") + with self.if_(f"{self.cache_key} in cache"): + self.assign_local( + "_cached", f"cache[{self.cache_key}]", + ) + with self.if_( + f"isinstance(_cached, {pending_cls})", + ): + self.return_(self.await_("_cached")) + self.return_("_cached") + else: + with self.if_(f"{self.cache_key} in cache"): + self.return_(f"cache[{self.cache_key}]") + + @property + def _uses_pending(self) -> bool: + return self.is_async + + def place_pending(self, factory: Factory) -> None: + if not self._uses_pending: + return + if not factory.cache or factory.type is FactoryType.CONTEXT: + return + pending_cls = self.global_(_Pending, "_Pending") + self.assign_local("_pending", self.call(pending_cls)) + self.assign_expr(f"cache[{self.cache_key}]", "_pending") + + def pending_resolve(self, factory: Factory) -> None: + if not self._uses_pending: + return + if not factory.cache or factory.type is FactoryType.CONTEXT: + return + self.statement(self.call("_pending.set_result", "solved")) + + def pending_reject(self, factory: Factory) -> None: + if not self._uses_pending: + return + if not factory.cache or factory.type is FactoryType.CONTEXT: + return + self.statement(f"cache.pop({self.cache_key}, None)") + self.statement(self.call("_pending.set_exception", "_exc")) def assign_solved(self, expr: str) -> None: self.assign_local("solved", expr) @@ -383,6 +452,50 @@ def _select_when_dependency( return False +def _gather_async_deps( + builder: FactoryBuilder, + factory: Factory, + compiled_deps: dict[DependencyKey, CompiledFactory], +) -> dict[DependencyKey, str] | None: + """Generate asyncio.gather for 2+ awaitable deps. Returns dep->var map.""" + if not builder.is_async: + return None + if factory.when_dependencies: + return None + + all_deps: list[tuple[DependencyKey, str | None]] = [] + for dep in factory.dependencies: + all_deps.append((dep, None)) + for name, dep in factory.kw_dependencies.items(): + all_deps.append((dep, name)) + + awaitable_deps = [ + (dep, kw_name) + for dep, kw_name in all_deps + if builder.is_awaitable_dep(dep, compiled_deps) + ] + + if len(awaitable_deps) < 2: # noqa: PLR2004 + return None + + gather_fn = builder.global_(asyncio.gather, "asyncio_gather") + coros = [ + builder.getter_coro(dep, compiled_deps) + for dep, _ in awaitable_deps + ] + var_names = [f"_g{i}" for i in range(len(awaitable_deps))] + for var_name in var_names: + builder.locals.add(var_name) + targets = ", ".join(var_names) + gather_call = builder.call(gather_fn, *coros) + builder.statement(f"{targets} = await {gather_call}") + + result: dict[DependencyKey, str] = {} + for (dep, _), var_name in zip(awaitable_deps, var_names, strict=True): + result[dep] = var_name + return result + + def _make_body( builder: FactoryBuilder, factory: Factory, @@ -395,14 +508,23 @@ def _make_body( builder, factory, compiled_deps, ) if not has_default: + gathered = _gather_async_deps( + builder, factory, compiled_deps, + ) + + def _dep_expr(dep: DependencyKey) -> str: + if gathered and dep in gathered: + return gathered[dep] + return builder.getter(dep, compiled_deps) + source_call = builder.call( builder.global_(factory.source), *( - builder.getter(dep, compiled_deps) + _dep_expr(dep) for dep in factory.dependencies ), **{ - name: builder.getter(dep, compiled_deps) + name: _dep_expr(dep) for name, dep in factory.kw_dependencies.items() }, ) @@ -426,6 +548,14 @@ def _has_deps(factory: Factory) -> bool: ) +def _use_pending(factory: Factory, *, is_async: bool) -> bool: + return ( + is_async + and factory.cache + and factory.type is not FactoryType.CONTEXT + ) + + def compile_factory( *, factory: Factory, @@ -444,16 +574,32 @@ def compile_factory( container_key=container_key, ) builder.register_provides(factory.provides) + use_pending = _use_pending(factory, is_async=is_async) with builder.make_getter(): builder.return_if_cached(factory) - if _has_deps(factory): - with builder.handle_no_dep(factory): - _make_body(builder, factory, compiled_deps) + if use_pending: + builder.place_pending(factory) + with builder.try_(): + if _has_deps(factory): + with builder.handle_no_dep(factory): + _make_body(builder, factory, compiled_deps) + else: + _make_body(builder, factory, compiled_deps) + builder.cache(factory) + builder.pending_resolve(factory) + builder.return_("solved") + with builder.except_(BaseException, as_="_exc"): + builder.pending_reject(factory) + builder.raise_() else: - _make_body(builder, factory, compiled_deps) - builder.cache(factory) - builder.return_("solved") + if _has_deps(factory): + with builder.handle_no_dep(factory): + _make_body(builder, factory, compiled_deps) + else: + _make_body(builder, factory, compiled_deps) + builder.cache(factory) + builder.return_("solved") return builder.build_getter() diff --git a/src/dishka/container_objects.py b/src/dishka/container_objects.py index c2c55e55e..08d9f8211 100644 --- a/src/dishka/container_objects.py +++ b/src/dishka/container_objects.py @@ -1,9 +1,36 @@ +import asyncio from abc import abstractmethod from collections.abc import AsyncGenerator, Callable, Generator from typing import Any, Protocol, TypeAlias from dishka.entities.key import CompilationKey + +class _Pending: + """Sentinel placed in the cache while an async factory is being resolved. + + If a concurrent coroutine (from asyncio.gather) tries to resolve the same + dependency, it finds this sentinel and awaits the embedded Future instead + of creating a duplicate. + """ + + __slots__ = ("_future",) + + def __init__(self) -> None: + self._future = asyncio.get_running_loop().create_future() + + def set_result(self, result: Any) -> None: + if not self._future.done(): + self._future.set_result(result) + + def set_exception(self, exc: BaseException) -> None: + if not self._future.done(): + self._future.set_exception(exc) + + def __await__(self) -> Any: + return self._future.__await__() + + Exit: TypeAlias = tuple[ Generator[Any, Any, Any] | None, AsyncGenerator[Any, Any] | None, diff --git a/tests/unit/container/test_async_gather.py b/tests/unit/container/test_async_gather.py new file mode 100644 index 000000000..6adbabfba --- /dev/null +++ b/tests/unit/container/test_async_gather.py @@ -0,0 +1,730 @@ +import asyncio +from collections.abc import AsyncIterator +from unittest.mock import Mock + +import pytest + +from dishka import ( + Provider, + Scope, + make_async_container, + provide, +) + + +class IndependentDepsProvider(Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.1) + return 1 + + @provide(scope=Scope.APP) + async def get_float(self) -> float: + await asyncio.sleep(0.1) + return 2.0 + + @provide(scope=Scope.APP) + async def get_str(self, a: int, b: float) -> str: + return f"{a}-{b}" + + +@pytest.mark.asyncio +async def test_independent_deps_gathered(): + container = make_async_container(IndependentDepsProvider()) + start = asyncio.get_running_loop().time() + result = await container.get(str) + elapsed = asyncio.get_running_loop().time() - start + + assert result == "1-2.0" + # With gathering: ~0.1s. Without: ~0.2s. Use 0.15 as threshold. + assert elapsed < 0.15, ( + f"Took {elapsed:.3f}s, expected < 0.15s (deps should be gathered)" + ) + await container.close() + + +class SharedTransitiveDepsProvider(Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.05) + return 42 + + @provide(scope=Scope.APP) + async def get_float(self, v: int) -> float: + return float(v) + + @provide(scope=Scope.APP) + async def get_str(self, v: int) -> str: + return str(v) + + @provide(scope=Scope.APP) + async def get_bytes(self, a: float, b: str) -> bytes: + return f"{a}-{b}".encode() + + +@pytest.mark.asyncio +async def test_shared_transitive_deps_gathered(): + """Shared cached transitive deps are gathered via pending sentinel.""" + container = make_async_container(SharedTransitiveDepsProvider()) + result = await container.get(bytes) + assert result == b"42.0-42" + await container.close() + + +class ThreeIndependentProvider(Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.1) + return 1 + + @provide(scope=Scope.APP) + async def get_float(self) -> float: + await asyncio.sleep(0.1) + return 2.0 + + @provide(scope=Scope.APP) + async def get_bytes(self) -> bytes: + await asyncio.sleep(0.1) + return b"3" + + @provide(scope=Scope.APP) + async def get_str(self, a: int, b: float, c: bytes) -> str: + return f"{a}-{b}-{c!r}" + + +@pytest.mark.asyncio +async def test_three_independent_deps_gathered(): + container = make_async_container(ThreeIndependentProvider()) + start = asyncio.get_running_loop().time() + result = await container.get(str) + elapsed = asyncio.get_running_loop().time() - start + + assert result == "1-2.0-b'3'" + # With gathering: ~0.1s. Without: ~0.3s. + assert elapsed < 0.15, f"Took {elapsed:.3f}s, expected < 0.15s" + await container.close() + + +class MixedSyncAsyncProvider(Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.1) + return 1 + + @provide(scope=Scope.APP) + async def get_float(self) -> float: + await asyncio.sleep(0.1) + return 2.0 + + @provide(scope=Scope.APP) + def get_bytes(self) -> bytes: + return b"sync" + + @provide(scope=Scope.APP) + async def get_str(self, a: int, b: float, c: bytes) -> str: + return f"{a}-{b}-{c!r}" + + +@pytest.mark.asyncio +async def test_mixed_sync_async_deps(): + container = make_async_container(MixedSyncAsyncProvider()) + start = asyncio.get_running_loop().time() + result = await container.get(str) + elapsed = asyncio.get_running_loop().time() - start + + assert result == "1-2.0-b'sync'" + assert elapsed < 0.15, f"Took {elapsed:.3f}s, expected < 0.15s" + await container.close() + + +class CachedGatherProvider(Provider): + def __init__(self, mock: Mock): + super().__init__() + self.mock = mock + + @provide(scope=Scope.APP) + async def get_int(self) -> int: + self.mock() + await asyncio.sleep(0.05) + return 1 + + @provide(scope=Scope.APP) + async def get_float(self) -> float: + await asyncio.sleep(0.05) + return 2.0 + + @provide(scope=Scope.APP) + async def get_str(self, a: int, b: float) -> str: + return f"{a}-{b}" + + +@pytest.mark.asyncio +async def test_cached_deps_not_recreated(): + mock = Mock() + container = make_async_container(CachedGatherProvider(mock)) + + result1 = await container.get(str) + result2 = await container.get(str) + assert result1 == result2 == "1-2.0" + mock.assert_called_once() + await container.close() + + +class NoLockGatherProvider(Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.1) + return 1 + + @provide(scope=Scope.APP) + async def get_float(self) -> float: + await asyncio.sleep(0.1) + return 2.0 + + @provide(scope=Scope.APP) + async def get_str(self, a: int, b: float) -> str: + return f"{a}-{b}" + + +@pytest.mark.asyncio +async def test_gather_works_without_lock(): + container = make_async_container( + NoLockGatherProvider(), + lock_factory=None, + ) + start = asyncio.get_running_loop().time() + result = await container.get(str) + elapsed = asyncio.get_running_loop().time() - start + + assert result == "1-2.0" + assert elapsed < 0.15, f"Took {elapsed:.3f}s, expected < 0.15s" + await container.close() + + +class KwOnlyDepsProvider(Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.1) + return 10 + + @provide(scope=Scope.APP) + async def get_float(self) -> float: + await asyncio.sleep(0.1) + return 1.5 + + @provide(scope=Scope.APP) + async def get_str(self, *, a: int, b: float) -> str: + return f"{a}+{b}" + + +@pytest.mark.asyncio +async def test_keyword_only_deps_gathered(): + container = make_async_container(KwOnlyDepsProvider()) + start = asyncio.get_running_loop().time() + result = await container.get(str) + elapsed = asyncio.get_running_loop().time() - start + + assert result == "10+1.5" + assert elapsed < 0.15, f"Took {elapsed:.3f}s, expected < 0.15s" + await container.close() + + +class MixedScopeProvider(Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.05) + return 1 + + @provide(scope=Scope.REQUEST) + async def get_float(self) -> float: + await asyncio.sleep(0.05) + return 2.0 + + @provide(scope=Scope.REQUEST) + async def get_str(self, a: int, b: float) -> str: + return f"{a}-{b}" + + +@pytest.mark.asyncio +async def test_mixed_scopes(): + container = make_async_container(MixedScopeProvider()) + async with container(scope=Scope.REQUEST) as request_container: + result = await request_container.get(str) + assert result == "1-2.0" + await container.close() + + +class ErrorProvider(Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.05) + msg = "int factory failed" + raise ValueError(msg) + + @provide(scope=Scope.APP) + async def get_float(self) -> float: + await asyncio.sleep(0.05) + return 2.0 + + @provide(scope=Scope.APP) + async def get_str(self, a: int, b: float) -> str: + return f"{a}-{b}" + + +@pytest.mark.asyncio +async def test_error_propagation(): + container = make_async_container(ErrorProvider()) + with pytest.raises(ValueError, match="int factory failed"): + await container.get(str) + await container.close() + + +class ConcurrentAccessProvider(Provider): + def __init__(self, call_counts: dict): + super().__init__() + self.call_counts = call_counts + + @provide(scope=Scope.APP) + async def get_int(self) -> int: + self.call_counts["int"] = self.call_counts.get("int", 0) + 1 + await asyncio.sleep(0.05) + return 1 + + @provide(scope=Scope.APP) + async def get_float(self) -> float: + self.call_counts["float"] = self.call_counts.get("float", 0) + 1 + await asyncio.sleep(0.05) + return 2.0 + + @provide(scope=Scope.APP) + async def get_str(self, a: int, b: float) -> str: + return f"{a}-{b}" + + +@pytest.mark.asyncio +async def test_concurrent_container_access(): + call_counts: dict = {} + container = make_async_container(ConcurrentAccessProvider(call_counts)) + + results = await asyncio.gather( + container.get(str), + container.get(str), + container.get(str), + ) + for r in results: + assert r == "1-2.0" + await container.close() + + +class SingleAsyncDepProvider(Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.05) + return 42 + + @provide(scope=Scope.APP) + async def get_str(self, a: int) -> str: + return str(a) + + +@pytest.mark.asyncio +async def test_single_async_dep(): + container = make_async_container(SingleAsyncDepProvider()) + result = await container.get(str) + assert result == "42" + await container.close() + + +class MixedPosKwProvider(Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.1) + return 5 + + @provide(scope=Scope.APP) + async def get_float(self) -> float: + await asyncio.sleep(0.1) + return 1.5 + + @provide(scope=Scope.APP) + async def get_bytes(self) -> bytes: + await asyncio.sleep(0.1) + return b"data" + + @provide(scope=Scope.APP) + async def get_str(self, a: int, *, b: float, c: bytes) -> str: + return f"{a}-{b}-{c!r}" + + +@pytest.mark.asyncio +async def test_mixed_positional_and_keyword(): + container = make_async_container(MixedPosKwProvider()) + start = asyncio.get_running_loop().time() + result = await container.get(str) + elapsed = asyncio.get_running_loop().time() - start + + assert result == "5-1.5-b'data'" + assert elapsed < 0.15, f"Took {elapsed:.3f}s, expected < 0.15s" + await container.close() + + +class AllSyncInAsyncProvider(Provider): + @provide(scope=Scope.APP) + def get_int(self) -> int: + return 1 + + @provide(scope=Scope.APP) + def get_float(self) -> float: + return 2.0 + + @provide(scope=Scope.APP) + def get_str(self, a: int, b: float) -> str: + return f"{a}-{b}" + + +@pytest.mark.asyncio +async def test_all_sync_in_async_container(): + container = make_async_container(AllSyncInAsyncProvider()) + result = await container.get(str) + assert result == "1-2.0" + await container.close() + + +# --- Diamond pattern (shared transitive dep) tests --- + + +class DiamondSingleCreationProvider(Provider): + def __init__(self, call_counts: dict): + super().__init__() + self.call_counts = call_counts + + @provide(scope=Scope.APP) + async def get_int(self) -> int: + self.call_counts["int"] = self.call_counts.get("int", 0) + 1 + await asyncio.sleep(0.1) + return 42 + + @provide(scope=Scope.APP) + async def get_float(self, v: int) -> float: + self.call_counts["float"] = self.call_counts.get("float", 0) + 1 + return float(v) + + @provide(scope=Scope.APP) + async def get_str(self, v: int) -> str: + self.call_counts["str"] = self.call_counts.get("str", 0) + 1 + return str(v) + + @provide(scope=Scope.APP) + async def get_bytes(self, a: float, b: str) -> bytes: + return f"{a}-{b}".encode() + + +@pytest.mark.asyncio +async def test_diamond_single_creation(): + """Shared dep (int) is created once via pending sentinel.""" + call_counts: dict = {} + container = make_async_container( + DiamondSingleCreationProvider(call_counts), + ) + result = await container.get(bytes) + assert result == b"42.0-42" + assert call_counts["int"] == 1 + assert call_counts["float"] == 1 + assert call_counts["str"] == 1 + await container.close() + + +class DiamondTimingProvider(Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.1) + return 1 + + @provide(scope=Scope.APP) + async def get_float(self, v: int) -> float: + await asyncio.sleep(0.1) + return float(v) + + @provide(scope=Scope.APP) + async def get_str(self, v: int) -> str: + await asyncio.sleep(0.1) + return str(v) + + @provide(scope=Scope.APP) + async def get_bytes(self, a: float, b: str) -> bytes: + return f"{a}-{b}".encode() + + +@pytest.mark.asyncio +async def test_diamond_concurrent_timing(): + """Diamond deps (float, str) are gathered concurrently.""" + container = make_async_container(DiamondTimingProvider()) + start = asyncio.get_running_loop().time() + result = await container.get(bytes) + elapsed = asyncio.get_running_loop().time() - start + + assert result == b"1.0-1" + # Sequential: int(0.1) + float(0.1) + str(0.1) = 0.3s + # Gathered: int(0.1) + max(float(0.1), str(0.1)) = 0.2s + assert elapsed < 0.25, ( + f"Took {elapsed:.3f}s, expected < 0.25s (diamond should be gathered)" + ) + await container.close() + + +class DiamondErrorProvider(Provider): + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.05) + msg = "shared dep failed" + raise ValueError(msg) + + @provide(scope=Scope.APP) + async def get_float(self, v: int) -> float: + return float(v) + + @provide(scope=Scope.APP) + async def get_str(self, v: int) -> str: + return str(v) + + @provide(scope=Scope.APP) + async def get_bytes(self, a: float, b: str) -> bytes: + return f"{a}-{b}".encode() + + +@pytest.mark.asyncio +async def test_diamond_error_propagation(): + """Error in shared transitive dep propagates correctly through gather.""" + container = make_async_container(DiamondErrorProvider()) + with pytest.raises(ValueError, match="shared dep failed"): + await container.get(bytes) + await container.close() + + +class DiamondAsyncGeneratorProvider(Provider): + def __init__(self, call_counts: dict): + super().__init__() + self.call_counts = call_counts + + @provide(scope=Scope.APP) + async def get_int(self) -> AsyncIterator[int]: + self.call_counts["int"] = self.call_counts.get("int", 0) + 1 + await asyncio.sleep(0.05) + yield 42 + + @provide(scope=Scope.APP) + async def get_float(self, v: int) -> float: + return float(v) + + @provide(scope=Scope.APP) + async def get_str(self, v: int) -> str: + return str(v) + + @provide(scope=Scope.APP) + async def get_bytes(self, a: float, b: str) -> bytes: + return f"{a}-{b}".encode() + + +@pytest.mark.asyncio +async def test_diamond_async_generator_cached(): + """Shared async generator dep is created once, exit registered once.""" + call_counts: dict = {} + container = make_async_container( + DiamondAsyncGeneratorProvider(call_counts), + ) + result = await container.get(bytes) + assert result == b"42.0-42" + assert call_counts["int"] == 1 + await container.close() + + +# --- cache=False shared transitive dep tests --- +# These test that gathering works correctly when a shared transitive dep +# is uncached. Each branch should create its own independent instance. + + +class DiamondUncachedFactoryProvider(Provider): + def __init__(self, call_counts: dict): + super().__init__() + self.call_counts = call_counts + self._counter = 0 + + @provide(scope=Scope.APP, cache=False) + async def get_int(self) -> int: + self.call_counts["int"] = self.call_counts.get("int", 0) + 1 + self._counter += 1 + await asyncio.sleep(0.1) + return self._counter + + @provide(scope=Scope.APP) + async def get_float(self, v: int) -> float: + return float(v) + + @provide(scope=Scope.APP) + async def get_str(self, v: int) -> str: + return str(v) + + @provide(scope=Scope.APP) + async def get_bytes(self, a: float, b: str) -> bytes: + return f"{a}-{b}".encode() + + +@pytest.mark.asyncio +async def test_diamond_uncached_factory_creates_per_consumer(): + """cache=False shared dep is created independently per consumer branch.""" + call_counts: dict = {} + container = make_async_container( + DiamondUncachedFactoryProvider(call_counts), + ) + await container.get(bytes) + # Each branch (float, str) creates its own int independently + assert call_counts["int"] == 2 + await container.close() + + +@pytest.mark.asyncio +async def test_diamond_uncached_factory_concurrent_timing(): + """cache=False shared dep branches run concurrently when gathered.""" + call_counts: dict = {} + container = make_async_container( + DiamondUncachedFactoryProvider(call_counts), + ) + start = asyncio.get_running_loop().time() + await container.get(bytes) + elapsed = asyncio.get_running_loop().time() - start + # Sequential: int(0.1) + int(0.1) = 0.2s + # Gathered: max(int(0.1), int(0.1)) = 0.1s + assert elapsed < 0.15, ( + f"Took {elapsed:.3f}s, expected < 0.15s (should gather uncached)" + ) + assert call_counts["int"] == 2 + await container.close() + + +class DiamondUncachedAsyncGenProvider(Provider): + def __init__(self, call_counts: dict, closed: list): + super().__init__() + self.call_counts = call_counts + self.closed = closed + self._counter = 0 + + @provide(scope=Scope.APP, cache=False) + async def get_int(self) -> AsyncIterator[int]: + self.call_counts["int"] = self.call_counts.get("int", 0) + 1 + self._counter += 1 + val = self._counter + await asyncio.sleep(0.05) + yield val + self.closed.append(val) + + @provide(scope=Scope.APP) + async def get_float(self, v: int) -> float: + return float(v) + + @provide(scope=Scope.APP) + async def get_str(self, v: int) -> str: + return str(v) + + @provide(scope=Scope.APP) + async def get_bytes(self, a: float, b: str) -> bytes: + return f"{a}-{b}".encode() + + +@pytest.mark.asyncio +async def test_diamond_uncached_async_generator_per_consumer(): + """cache=False async generator: each branch gets own instance.""" + call_counts: dict = {} + closed: list = [] + container = make_async_container( + DiamondUncachedAsyncGenProvider(call_counts, closed), + ) + await container.get(bytes) + # Each branch creates its own generator + assert call_counts["int"] == 2 + await container.close() + # Both generators should have been closed + assert len(closed) == 2 + + +class DiamondUncachedErrorProvider(Provider): + @provide(scope=Scope.APP, cache=False) + async def get_int(self) -> int: + await asyncio.sleep(0.05) + msg = "uncached dep fails" + raise ValueError(msg) + + @provide(scope=Scope.APP) + async def get_float(self, v: int) -> float: + return float(v) + + @provide(scope=Scope.APP) + async def get_str(self, v: int) -> str: + return str(v) + + @provide(scope=Scope.APP) + async def get_bytes(self, a: float, b: str) -> bytes: + return f"{a}-{b}".encode() + + +@pytest.mark.asyncio +async def test_diamond_uncached_error_one_branch(): + """cache=False shared dep: error propagates correctly through gather.""" + container = make_async_container(DiamondUncachedErrorProvider()) + with pytest.raises(ValueError, match="uncached dep fails"): + await container.get(bytes) + await container.close() + + +class DiamondMixedCacheProvider(Provider): + def __init__(self, call_counts: dict): + super().__init__() + self.call_counts = call_counts + self._uncached_counter = 0 + + @provide(scope=Scope.APP) + async def get_int(self) -> int: + self.call_counts["int"] = self.call_counts.get("int", 0) + 1 + await asyncio.sleep(0.05) + return 42 + + @provide(scope=Scope.APP, cache=False) + async def get_complex(self) -> complex: + self.call_counts["complex"] = self.call_counts.get("complex", 0) + 1 + self._uncached_counter += 1 + await asyncio.sleep(0.1) + return complex(self._uncached_counter, 0) + + @provide(scope=Scope.APP) + async def get_float(self, v: int, c: complex) -> float: + return float(v) + c.real + + @provide(scope=Scope.APP) + async def get_str(self, v: int, c: complex) -> str: + return f"{v}+{c.real}" + + @provide(scope=Scope.APP) + async def get_bytes(self, a: float, b: str) -> bytes: + return f"{a}-{b}".encode() + + +@pytest.mark.asyncio +async def test_diamond_mixed_cached_uncached(): + """Mixed diamond: float and str branches run concurrently. + + int is cached (created once via pending sentinel), complex is uncached + (created per branch). Timing proves concurrency: sequential would take + int(0.05) + complex(0.1) + complex(0.1) = 0.25s; gathered takes + int(0.05) + max(complex(0.1), complex(0.1)) = 0.15s. + """ + call_counts: dict = {} + container = make_async_container(DiamondMixedCacheProvider(call_counts)) + start = asyncio.get_running_loop().time() + await container.get(bytes) + elapsed = asyncio.get_running_loop().time() - start + + assert call_counts["int"] == 1 # cached — created once + assert call_counts["complex"] == 2 # uncached — created per branch + assert elapsed < 0.2, ( + f"Took {elapsed:.3f}s, expected < 0.2s (should gather mixed)" + ) + await container.close() diff --git a/tests/unit/container/test_concurrency.py b/tests/unit/container/test_concurrency.py index 203e0efc2..60a2ebae9 100644 --- a/tests/unit/container/test_concurrency.py +++ b/tests/unit/container/test_concurrency.py @@ -87,3 +87,82 @@ async def test_cache_async(): await t2 int_getter.assert_called_once_with() + + +# --- Approach B: Pending dedup & gather tests --- + + +class PendingDedupProvider(Provider): + """Provider where the int factory is slow — used to test dedup.""" + + def __init__(self, mock: Mock): + super().__init__() + self.mock = mock + + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.05) + return self.mock() + + @provide(scope=Scope.APP) + def get_str(self, value: int) -> str: + return str(value) + + +@pytest.mark.repeat(10) +@pytest.mark.asyncio +async def test_pending_dedup_no_lock(): + """Concurrent gets without lock — pending dedup works.""" + int_getter = Mock(return_value=42) + provider = PendingDedupProvider(int_getter) + + container = make_async_container(provider, lock_factory=None) + t1 = asyncio.create_task(container.get(str)) + t2 = asyncio.create_task(container.get(str)) + r1, r2 = await asyncio.gather(t1, t2) + + int_getter.assert_called_once_with() + assert r1 == "42" + assert r2 == "42" + await container.close() + + +class FailingProvider(Provider): + def __init__(self, mock: Mock): + super().__init__() + self.mock = mock + + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await asyncio.sleep(0.05) + return self.mock() + + +@pytest.mark.asyncio +async def test_pending_exception_propagation(): + """If a pending dep fails, waiters get the exception; retry works.""" + call_count = 0 + + def failing_then_ok(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ValueError("boom") + return 42 + + mock = Mock(side_effect=failing_then_ok) + provider = FailingProvider(mock) + container = make_async_container(provider, lock_factory=None) + + # First concurrent calls — both should fail + t1 = asyncio.create_task(container.get(int)) + t2 = asyncio.create_task(container.get(int)) + + results = await asyncio.gather(t1, t2, return_exceptions=True) + assert all(isinstance(r, ValueError) for r in results) + assert mock.call_count == 1 # only one actual call + + # Retry should work (pending removed from cache) + result = await container.get(int) + assert result == 42 + await container.close()