Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/dishka/code_tools/code_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(self, *, is_async: bool) -> None:
self.async_str = ""

self._is_async = is_async
self.is_async = is_async

def _make_global_name(self, obj: Any, name: str | None = None) -> str:
if name is None:
Expand Down Expand Up @@ -175,7 +176,7 @@ def try_(self) -> AbstractContextManager[None]:

def except_(
self,
exception: type[Exception],
exception: type[BaseException],
as_: str = "",
) -> AbstractContextManager[None]:
name = self.global_(exception)
Expand Down
168 changes: 157 additions & 11 deletions src/dishka/code_tools/factory_compiler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import asyncio
import contextlib
from collections.abc import Callable, Iterator
from contextlib import AbstractContextManager
from typing import Any, TypeAlias, cast

from dishka.code_tools.code_builder import CodeBuilder
from dishka.container_objects import CompiledFactory
from dishka.container_objects import CompiledFactory, _Pending
from dishka.dependency_source import Factory
from dishka.entities.component import Component
from dishka.entities.factory_type import FactoryType
Expand Down Expand Up @@ -89,14 +90,82 @@
"getter", self.global_(obj.as_compilation_key()),
))

def is_awaitable_dep(
self,
obj: DependencyKey,
compiled_deps: dict[DependencyKey, CompiledFactory],

Check warning on line 96 in src/dishka/code_tools/factory_compiler.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Remove the unused function parameter "compiled_deps".

See more on https://sonarcloud.io/project/issues?id=reagento_dishka&issues=AZ1SZRDewO6MK7c7lFIm&open=AZ1SZRDewO6MK7c7lFIm&pullRequest=704
) -> bool:
if obj.is_const():
return False
if obj.type_hint is DependencyKey:
return False
if obj == self.container_key:
return False
return True

def getter_coro(
self,
obj: DependencyKey,
compiled_deps: dict[DependencyKey, CompiledFactory],
) -> str:
if obj in compiled_deps:
factory = self.global_(compiled_deps[obj])
return self.call(
factory,
"getter", "exits", "cache", "context", "container",
)
return self.call(
"getter", self.global_(obj.as_compilation_key()),
)

def cache(self, factory: Factory) -> None:
if factory.cache and factory.type is not FactoryType.CONTEXT:
self.assign_expr(f"cache[{self.cache_key}]", "solved")

def return_if_cached(self, factory: Factory) -> None:
if factory.cache and factory.type is not FactoryType.CONTEXT:
with self.if_(f"{self.cache_key} in cache"):
self.return_(f"cache[{self.cache_key}]")
if self.is_async:
pending_cls = self.global_(_Pending, "_Pending")
with self.if_(f"{self.cache_key} in cache"):
self.assign_local(
"_cached", f"cache[{self.cache_key}]",
)
with self.if_(
f"isinstance(_cached, {pending_cls})",
):
self.return_(self.await_("_cached"))
self.return_("_cached")
else:
with self.if_(f"{self.cache_key} in cache"):
self.return_(f"cache[{self.cache_key}]")

@property
def _uses_pending(self) -> bool:
return self.is_async

def place_pending(self, factory: Factory) -> None:
if not self._uses_pending:
return
if not factory.cache or factory.type is FactoryType.CONTEXT:
return
pending_cls = self.global_(_Pending, "_Pending")
self.assign_local("_pending", self.call(pending_cls))
self.assign_expr(f"cache[{self.cache_key}]", "_pending")

def pending_resolve(self, factory: Factory) -> None:
if not self._uses_pending:
return
if not factory.cache or factory.type is FactoryType.CONTEXT:
return
self.statement(self.call("_pending.set_result", "solved"))

def pending_reject(self, factory: Factory) -> None:
if not self._uses_pending:
return
if not factory.cache or factory.type is FactoryType.CONTEXT:
return
self.statement(f"cache.pop({self.cache_key}, None)")
self.statement(self.call("_pending.set_exception", "_exc"))

def assign_solved(self, expr: str) -> None:
self.assign_local("solved", expr)
Expand Down Expand Up @@ -383,6 +452,50 @@
return False


def _gather_async_deps(
builder: FactoryBuilder,
factory: Factory,
compiled_deps: dict[DependencyKey, CompiledFactory],
) -> dict[DependencyKey, str] | None:
"""Generate asyncio.gather for 2+ awaitable deps. Returns dep->var map."""
if not builder.is_async:
return None
if factory.when_dependencies:
return None

all_deps: list[tuple[DependencyKey, str | None]] = []
for dep in factory.dependencies:
all_deps.append((dep, None))
for name, dep in factory.kw_dependencies.items():
all_deps.append((dep, name))

awaitable_deps = [
(dep, kw_name)
for dep, kw_name in all_deps
if builder.is_awaitable_dep(dep, compiled_deps)
]

if len(awaitable_deps) < 2: # noqa: PLR2004
return None

gather_fn = builder.global_(asyncio.gather, "asyncio_gather")
coros = [
builder.getter_coro(dep, compiled_deps)
for dep, _ in awaitable_deps
]
var_names = [f"_g{i}" for i in range(len(awaitable_deps))]
for var_name in var_names:
builder.locals.add(var_name)
targets = ", ".join(var_names)
gather_call = builder.call(gather_fn, *coros)
builder.statement(f"{targets} = await {gather_call}")

result: dict[DependencyKey, str] = {}
for (dep, _), var_name in zip(awaitable_deps, var_names, strict=True):
result[dep] = var_name
return result


def _make_body(
builder: FactoryBuilder,
factory: Factory,
Expand All @@ -395,14 +508,23 @@
builder, factory, compiled_deps,
)
if not has_default:
gathered = _gather_async_deps(
builder, factory, compiled_deps,
)

def _dep_expr(dep: DependencyKey) -> str:
if gathered and dep in gathered:
return gathered[dep]
return builder.getter(dep, compiled_deps)

source_call = builder.call(
builder.global_(factory.source),
*(
builder.getter(dep, compiled_deps)
_dep_expr(dep)
for dep in factory.dependencies
),
**{
name: builder.getter(dep, compiled_deps)
name: _dep_expr(dep)
for name, dep in factory.kw_dependencies.items()
},
)
Expand All @@ -426,6 +548,14 @@
)


def _use_pending(factory: Factory, *, is_async: bool) -> bool:
return (
is_async
and factory.cache
and factory.type is not FactoryType.CONTEXT
)


def compile_factory(
*,
factory: Factory,
Expand All @@ -444,16 +574,32 @@
container_key=container_key,
)
builder.register_provides(factory.provides)
use_pending = _use_pending(factory, is_async=is_async)

with builder.make_getter():
builder.return_if_cached(factory)
if _has_deps(factory):
with builder.handle_no_dep(factory):
_make_body(builder, factory, compiled_deps)
if use_pending:
builder.place_pending(factory)
with builder.try_():
if _has_deps(factory):
with builder.handle_no_dep(factory):
_make_body(builder, factory, compiled_deps)
else:
_make_body(builder, factory, compiled_deps)
builder.cache(factory)
builder.pending_resolve(factory)
builder.return_("solved")
with builder.except_(BaseException, as_="_exc"):
builder.pending_reject(factory)
builder.raise_()
else:
_make_body(builder, factory, compiled_deps)
builder.cache(factory)
builder.return_("solved")
if _has_deps(factory):
with builder.handle_no_dep(factory):
_make_body(builder, factory, compiled_deps)
else:
_make_body(builder, factory, compiled_deps)
builder.cache(factory)
builder.return_("solved")

return builder.build_getter()

Expand Down
27 changes: 27 additions & 0 deletions src/dishka/container_objects.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,36 @@
import asyncio
from abc import abstractmethod
from collections.abc import AsyncGenerator, Callable, Generator
from typing import Any, Protocol, TypeAlias

from dishka.entities.key import CompilationKey


class _Pending:
"""Sentinel placed in the cache while an async factory is being resolved.

If a concurrent coroutine (from asyncio.gather) tries to resolve the same
dependency, it finds this sentinel and awaits the embedded Future instead
of creating a duplicate.
"""

__slots__ = ("_future",)

def __init__(self) -> None:
self._future = asyncio.get_running_loop().create_future()

def set_result(self, result: Any) -> None:
if not self._future.done():
self._future.set_result(result)

def set_exception(self, exc: BaseException) -> None:
if not self._future.done():
self._future.set_exception(exc)

def __await__(self) -> Any:
return self._future.__await__()


Exit: TypeAlias = tuple[
Generator[Any, Any, Any] | None,
AsyncGenerator[Any, Any] | None,
Expand Down
Loading