Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 220 additions & 41 deletions src/dishka/code_tools/factory_compiler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import asyncio
import contextlib
from collections.abc import Callable, Iterator
from contextlib import AbstractContextManager
from typing import Any, TypeAlias, cast

from dishka.code_tools.code_builder import CodeBuilder
from dishka.container_objects import CompiledFactory
from dishka.container_objects import CompiledFactory, _Pending
from dishka.dependency_source import Factory
from dishka.entities.component import Component
from dishka.entities.factory_type import FactoryType
Expand All @@ -27,11 +28,11 @@

class FactoryBuilder(CodeBuilder):
def __init__(
self,
*,
is_async: bool,
getter_prefix: str,
container_key: DependencyKey,
self,
*,
is_async: bool,
getter_prefix: str,
container_key: DependencyKey,
) -> None:
super().__init__(is_async=is_async)
self.provides_name = ""
Expand Down Expand Up @@ -82,21 +83,50 @@
return self.await_(
self.call(
factory,
"getter", "exits", "cache", "context", "container",
"getter",
"exits",
"cache",
"context",
"container",
),
)
return self.await_(self.call(
"getter", self.global_(obj.as_compilation_key()),
))
return self.await_(
self.call(
"getter",
self.global_(obj.as_compilation_key()),
)
)

def _is_pending_aware(self, factory: Factory) -> bool:
return (
self._is_async
and factory.cache
and factory.type is not FactoryType.CONTEXT
)

def cache(self, factory: Factory) -> None:
if factory.cache and factory.type is not FactoryType.CONTEXT:
self.assign_expr(f"cache[{self.cache_key}]", "solved")
if self._is_pending_aware(factory):
self.statement("_pending.set_result(solved)")

def return_if_cached(self, factory: Factory) -> None:
if factory.cache and factory.type is not FactoryType.CONTEXT:
with self.if_(f"{self.cache_key} in cache"):
self.return_(f"cache[{self.cache_key}]")
if self._is_pending_aware(factory):
pending_cls = self.global_(_Pending, "_Pending")
self.assign_local("_cached", f"cache[{self.cache_key}]")
with self.if_(f"isinstance(_cached, {pending_cls})"):
self.return_(self.await_("_cached"))
self.return_("_cached")
else:
self.return_(f"cache[{self.cache_key}]")

def place_pending(self, factory: Factory) -> None:
if self._is_pending_aware(factory):
pending_cls = self.global_(_Pending, "_Pending")
self.assign_local("_pending", self.call(pending_cls))
self.assign_expr(f"cache[{self.cache_key}]", "_pending")

def assign_solved(self, expr: str) -> None:
self.assign_local("solved", expr)
Expand Down Expand Up @@ -133,9 +163,30 @@
f" with marker {marker}",
)
return self.getter(
DependencyKey(marker, component), compiled_deps,
DependencyKey(marker, component),
compiled_deps,
)

def getter_coro(
self,
obj: DependencyKey,
compiled_deps: dict[DependencyKey, CompiledFactory],
) -> str:
if obj in compiled_deps:
factory = self.global_(compiled_deps[obj])
return self.call(
factory,
"getter",
"exits",
"cache",
"context",
"container",
)
return self.call(
"getter",
self.global_(obj.as_compilation_key()),
)

def build_getter(self) -> CompiledFactory:
name = f"<{self.getter_name}{'_async' if self.async_str else ''}>"
return cast(CompiledFactory, self.compile(name)[self.getter_name])
Expand Down Expand Up @@ -295,7 +346,9 @@
assigned = False
for variant in factory.when_dependencies:
condition = builder.when(
variant.when_override, variant.when_component, compiled_deps,
variant.when_override,
variant.when_component,
compiled_deps,
)
if condition:
if not assigned:
Expand Down Expand Up @@ -363,7 +416,9 @@
first = True
for variant in factory.when_dependencies:
condition = builder.when(
variant.when_override, factory.when_component, compiled_deps,
variant.when_override,
factory.when_component,
compiled_deps,
)
solved_value = builder.getter(variant.provides, compiled_deps)
if first and not condition:
Expand All @@ -383,46 +438,126 @@
return False


def _is_sync_dep(
dep: DependencyKey,
compiled_deps: dict[DependencyKey, CompiledFactory],

Check warning on line 443 in src/dishka/code_tools/factory_compiler.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Remove the unused function parameter "compiled_deps".

See more on https://sonarcloud.io/project/issues?id=reagento_dishka&issues=AZ1SZUMFTIwSMyARwclZ&open=AZ1SZUMFTIwSMyARwclZ&pullRequest=705
container_key: DependencyKey,
) -> bool:
return (
dep.is_const()
or dep.type_hint is DependencyKey
or dep == container_key
)


def _make_gathered_source_call(
builder: FactoryBuilder,
factory: Factory,
compiled_deps: dict[DependencyKey, CompiledFactory],
) -> str | None:
async_deps: list[tuple[int, int | str, DependencyKey]] = []
pos_exprs: dict[int, str] = {}
kw_exprs: dict[str, str] = {}

gather_idx = 0
for i, dep in enumerate(factory.dependencies):
if _is_sync_dep(dep, compiled_deps, builder.container_key):
pos_exprs[i] = builder.getter(dep, compiled_deps)
else:
async_deps.append((gather_idx, i, dep))
gather_idx += 1

for name, dep in factory.kw_dependencies.items():
if _is_sync_dep(dep, compiled_deps, builder.container_key):
kw_exprs[name] = builder.getter(dep, compiled_deps)
else:
async_deps.append((gather_idx, name, dep))
gather_idx += 1

if len(async_deps) < 2: # noqa: PLR2004
return None

coro_exprs = [
builder.getter_coro(dep, compiled_deps) for _, _, dep in async_deps
]
gather_func = builder.global_(asyncio.gather, "asyncio_gather")
builder.assign_local(
"_deps",
builder.await_(builder.call(gather_func, *coro_exprs)),
)

for gather_i, key_or_idx, _ in async_deps:
if isinstance(key_or_idx, int):
pos_exprs[key_or_idx] = f"_deps[{gather_i}]"
else:
kw_exprs[key_or_idx] = f"_deps[{gather_i}]"

ordered_pos = [pos_exprs[i] for i in range(len(factory.dependencies))]
return builder.call(
builder.global_(factory.source),
*ordered_pos,
**kw_exprs,
)


def _make_body(
builder: FactoryBuilder,
factory: Factory,
compiled_deps: dict[DependencyKey, CompiledFactory],
builder: FactoryBuilder,
factory: Factory,
compiled_deps: dict[DependencyKey, CompiledFactory],
*,
can_gather: bool = False,
) -> None:
if factory.type is FactoryType.COLLECTION:
_collection_factory_body(builder, factory, compiled_deps)
else:
has_default = _select_when_dependency(
builder, factory, compiled_deps,
builder,
factory,
compiled_deps,
)
if not has_default:
source_call = builder.call(
builder.global_(factory.source),
*(
builder.getter(dep, compiled_deps)
for dep in factory.dependencies
),
**{
name: builder.getter(dep, compiled_deps)
for name, dep in factory.kw_dependencies.items()
},
)
source_call = None
if can_gather and builder.async_str:
source_call = _make_gathered_source_call(
builder,
factory,
compiled_deps,
)
if source_call is None:
source_call = builder.call(
builder.global_(factory.source),
*(
builder.getter(dep, compiled_deps)
for dep in factory.dependencies
),
**{
name: builder.getter(dep, compiled_deps)
for name, dep in factory.kw_dependencies.items()
},
)
body_generator = BODY_GENERATORS[factory.type]
if factory.when_dependencies: # conditions generated
with builder.else_():
body_generator(
builder, source_call, factory, compiled_deps,
builder,
source_call,
factory,
compiled_deps,
)
else: # no options at all
body_generator(
builder, source_call, factory, compiled_deps,
builder,
source_call,
factory,
compiled_deps,
)


def _has_deps(factory: Factory) -> bool:
return bool(
factory.dependencies or
factory.kw_dependencies or
factory.when_dependencies,
factory.dependencies
or factory.kw_dependencies
or factory.when_dependencies,
)


Expand All @@ -432,6 +567,7 @@
is_async: bool,
compiled_deps: dict[DependencyKey, CompiledFactory],
container_key: DependencyKey,
can_gather: bool = False,
) -> CompiledFactory:
if not is_async and factory.type in ASYNC_TYPES:
raise UnsupportedFactoryError(factory)
Expand All @@ -445,15 +581,56 @@
)
builder.register_provides(factory.provides)

pending_aware = (
is_async and factory.cache and factory.type is not FactoryType.CONTEXT
)

with builder.make_getter():
builder.return_if_cached(factory)
if _has_deps(factory):
with builder.handle_no_dep(factory):
_make_body(builder, factory, compiled_deps)
builder.place_pending(factory)
if pending_aware:
with builder.try_():
if _has_deps(factory):
with builder.handle_no_dep(factory):
_make_body(
builder,
factory,
compiled_deps,
can_gather=can_gather,
)
else:
_make_body(
builder,
factory,
compiled_deps,
can_gather=can_gather,
)
builder.cache(factory)
builder.return_("solved")
with builder.except_(BaseException, as_="_exc"):
builder.statement(
f"cache.pop({builder.cache_key}, None)",
)
builder.statement("_pending.set_exception(_exc)")
builder.raise_()
else:
_make_body(builder, factory, compiled_deps)
builder.cache(factory)
builder.return_("solved")
if _has_deps(factory):
with builder.handle_no_dep(factory):
_make_body(
builder,
factory,
compiled_deps,
can_gather=can_gather,
)
else:
_make_body(
builder,
factory,
compiled_deps,
can_gather=can_gather,
)
builder.cache(factory)
builder.return_("solved")

return builder.build_getter()

Expand All @@ -473,7 +650,9 @@
builder.register_provides(factory.provides)
with builder.make_getter():
condition = builder.when(
factory.when_active, factory.when_component, compiled_deps,
factory.when_active,
factory.when_component,
compiled_deps,
)
if not condition:
builder.return_(builder.global_(True))
Expand Down
Loading