diff --git a/docs/advanced/when.rst b/docs/advanced/when.rst index cb0b85873..8e3e6c46e 100644 --- a/docs/advanced/when.rst +++ b/docs/advanced/when.rst @@ -16,6 +16,7 @@ This can be achieved with "activation" approach. Key concepts here: * **Activator** or **activation function** - special function registered in provider and taking decision if marker is active or not. * **activation condition** - expression with marker objects set in dependency source dynamically associated with activators to select between multiple implementations or enable decorators +Activators can be called preliminary or multiple times, so avoid acquiring resources or doing heavy calculations, if necessary, move such things into factories or context data. .. note:: @@ -28,9 +29,9 @@ To set conditional activation you create special ``Marker`` objects and use them .. code-block:: python - from dishka import Provider, provide, Scope + from dishka import Marker, Provider, provide, Scope - class MyProvider(Provider) + class MyProvider(Provider): @provide(scope=Scope.APP) def base_impl(self) -> Cache: return NormalCacheImpl() @@ -49,9 +50,9 @@ It can be the same or another provider while you pass when creating a container. .. code-block:: python - from dishka import activate, Provider + from dishka import activate, Marker, Provider - class MyProvider(Provider) + class MyProvider(Provider): @activate(Marker("debug")) def is_debug(self) -> bool: return False @@ -60,7 +61,7 @@ This function can use other objects as well. For example, we can pass config usi .. code-block:: python - class MyProvider(Provider) + class MyProvider(Provider): config = from_context(Config, scope=Scope.APP) @activate(Marker("debug")) @@ -78,7 +79,7 @@ More general pattern is to create own marker type and register a single activato pass - class MyProvider(Provider) + class MyProvider(Provider): config = from_context(Config, scope=Scope.APP) @activate(EnvMarker) @@ -161,7 +162,7 @@ For example: from dishka import Provider, provide, Scope - class MyProvider(Provider) + class MyProvider(Provider): config = from_context(RedisConfig, scope=Scope.APP) @provide(scope=Scope.APP) @@ -184,4 +185,28 @@ In this case, * ``memcached_impl`` is not used because no factory for ``MemcachedConfig`` is provided * ``redis_impl`` is not used while it is registered as ``from_context`` but no real value is provided. -* ``base_impl`` is used as a default one, because none of later is active \ No newline at end of file +* ``base_impl`` is used as a default one, because none of later is active + + +Preliminary (static) evaluation and graph validation +------------------------------------------------------------ + +In certain cases activator can be called during graph building step, this allows avoid unnecessary calls in runtime and ignore errors on factories which are never called. + +Static evaluation is enabled only if activator a sync non-generator function with dependencies retrieved from root context or without dependencies at all. +For example, in the following code ``redis_impl`` is never called because ``RedisConfig`` is not passed, so it won't be validated at all. + + +.. code-block:: python + + from dishka import Provider, provide, Scope + + class MyProvider(Provider): + config = from_context(RedisConfig, scope=Scope.APP) + + @provide(when=Has(RedisConfig), scope=Scope.APP) + def redis_impl(self, config: RedisConfig) -> Cache: + return RedisCache(config) + + container = make_container(MyProvider, context={}) + diff --git a/src/dishka/async_container.py b/src/dishka/async_container.py index b4b6b7cde..ca66fae15 100644 --- a/src/dishka/async_container.py +++ b/src/dishka/async_container.py @@ -251,6 +251,7 @@ def _get_sync(self, key: CompilationKey) -> Any: self._cache, self._context, self, + self._has_sync, ) async def _get(self, key: CompilationKey) -> Any: @@ -296,6 +297,7 @@ async def _get_unlocked(self, key: CompilationKey) -> Any: self._cache, self._context, self, + self._has, ) async def close(self, exception: BaseException | None = None) -> None: @@ -350,6 +352,23 @@ async def _has(self, marker: CompilationKey) -> bool: self._cache, self._context, self, + self._has, + )) + + def _has_sync(self, marker: CompilationKey) -> bool: + compiled = self.registry.get_compiled_activation(marker) + if not compiled: + if not self.parent_container: + return False + return self.parent_container._has_sync(marker) # noqa: SLF001 + + return bool(compiled( + self._get_sync, + self._exits, + self._cache, + self._context, + self, + self._has_sync, )) def _has_context(self, marker: Any) -> bool: @@ -357,6 +376,10 @@ def _has_context(self, marker: Any) -> bool: class HasProvider(Provider): + """ + This provider is used only for direct access on Has/HasContext. + Basic implementation is inlined in code builder. + """ @activate(Has) async def has( self, @@ -392,9 +415,11 @@ def make_async_container( has_provider = HasProvider() builder = GraphBuilder( scopes=scopes, + start_scope=start_scope, container_key=CONTAINER_KEY, skip_validation=skip_validation, validation_settings=validation_settings, + root_context=context or {}, ) builder.add_multicomponent_providers(has_provider) builder.add_providers(*providers) diff --git a/src/dishka/code_tools/code_builder.py b/src/dishka/code_tools/code_builder.py index fa51f84d0..80bc88e0e 100644 --- a/src/dishka/code_tools/code_builder.py +++ b/src/dishka/code_tools/code_builder.py @@ -133,9 +133,10 @@ def call(self, func: str, *args: str, **kwargs: str) -> str: args_list.extend(f"{name}={value}" for name, value in kwargs.items()) if len(args_list) > MAX_ITEMS_PER_LINE: - args_str = ",\n".join(args_list) + sep = ",\n " + self.indent_str + " "*8 else: - args_str = ", ".join(args_list) + sep = ", " + args_str = sep.join(args_list) return f"{func}({args_str})" def await_(self, expr: str) -> str: @@ -239,16 +240,18 @@ def for_(self, name: str, expr: str) -> Iterator[None]: def list_literal(self, *items: str) -> str: if len(items) > MAX_ITEMS_PER_LINE: - items_str = "\n, ".join(items) + sep = ",\n " + self.indent_str + " "*8 else: - items_str = ", ".join(items) + sep = ", " + items_str = sep.join(items) return f"[{items_str}]" def tuple_literal(self, *items: str) -> str: if len(items) > MAX_ITEMS_PER_LINE: - items_str = "\n, ".join(items) + sep = ",\n " + self.indent_str + " "*8 else: - items_str = ", ".join(items) + sep = ", " + items_str = sep.join(items) return f"({items_str})" def compile(self, source_file_name: str) -> dict[str, Any]: diff --git a/src/dishka/code_tools/factory_compiler.py b/src/dishka/code_tools/factory_compiler.py index f10a4e2d5..a553f1f23 100644 --- a/src/dishka/code_tools/factory_compiler.py +++ b/src/dishka/code_tools/factory_compiler.py @@ -13,6 +13,8 @@ AndMarker, BaseMarker, BoolMarker, + Has, + HasContext, NotMarker, OrMarker, ) @@ -63,7 +65,7 @@ def make_getter(self) -> AbstractContextManager[None]: self.getter_name = self.getter_prefix + raw_provides_name return self.def_( self.getter_name, - ["getter", "exits", "cache", "context", "container"], + ["getter", "exits", "cache", "context", "container", "has"], ) def getter( @@ -82,7 +84,7 @@ def getter( return self.await_( self.call( factory, - "getter", "exits", "cache", "context", "container", + "getter", "exits", "cache", "context", "container", "has", ), ) return self.await_(self.call( @@ -101,7 +103,10 @@ def return_if_cached(self, factory: Factory) -> None: def assign_solved(self, expr: str) -> None: self.assign_local("solved", expr) - def when( + def _has_context(self, type_: str) -> str: + return f"(context is not None and {type_} in context)" + + def when( # noqa: PLR0911 self, marker: BaseMarker | None, component: Component | None, @@ -126,6 +131,14 @@ def when( ) case BoolMarker(False): return self.global_(marker.value) + case Has(): + key = DependencyKey(marker.value, component) + return self.await_(self.call( + "has", + self.global_(key.as_compilation_key()), + )) + case HasContext(): + return self._has_context(self.global_(marker.value)) case _: if component is None: raise TypeError( # noqa: TRY003 diff --git a/src/dishka/container.py b/src/dishka/container.py index 5381eab24..2ec0e4dcb 100644 --- a/src/dishka/container.py +++ b/src/dishka/container.py @@ -226,6 +226,7 @@ def _get_unlocked(self, key: CompilationKey) -> Any: self._cache, self._context, self, + self._has, ) def close(self, exception: BaseException | None = None) -> None: @@ -278,6 +279,7 @@ def _has(self, marker: CompilationKey) -> bool: self._cache, self._context, self, + self._has, )) def _has_context(self, marker: Any) -> bool: @@ -285,6 +287,10 @@ def _has_context(self, marker: Any) -> bool: class HasProvider(Provider): + """ + This provider is used only for direct access on Has/HasContext. + Basic implementation is inlined in code builder. + """ @activate(Has) def has( self, @@ -317,7 +323,9 @@ def make_container( context_provider = make_root_context_provider(providers, context, scopes) has_provider = HasProvider() builder = GraphBuilder( + root_context=context or {}, scopes=scopes, + start_scope=start_scope, container_key=CONTAINER_KEY, skip_validation=skip_validation, validation_settings=validation_settings, diff --git a/src/dishka/container_objects.py b/src/dishka/container_objects.py index c2c55e55e..ca527abad 100644 --- a/src/dishka/container_objects.py +++ b/src/dishka/container_objects.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from collections.abc import AsyncGenerator, Callable, Generator +from collections.abc import AsyncGenerator, Awaitable, Callable, Generator from typing import Any, Protocol, TypeAlias from dishka.entities.key import CompilationKey @@ -19,5 +19,6 @@ def __call__( cache: Any, context: Any, container: Any, + has: Callable[[CompilationKey], bool | Awaitable[bool]], ) -> Any: raise NotImplementedError diff --git a/src/dishka/dependency_source/activator.py b/src/dishka/dependency_source/activator.py index 33ac21842..78b243c8e 100644 --- a/src/dishka/dependency_source/activator.py +++ b/src/dishka/dependency_source/activator.py @@ -1,12 +1,22 @@ from typing import Any from dishka.entities.component import Component +from dishka.entities.factory_type import FactoryData from dishka.entities.key import DependencyKey, const_dependency_key from dishka.entities.marker import Marker from dishka.entities.scope import BaseScope from .factory import Factory +class StaticEvaluationUnavailable(Exception): # noqa: N818 + def __init__(self, factory: FactoryData) -> None: + self.factory = factory + + def __str__(self) -> str: + return (f"StaticEvaluationUnavailable({self.factory.provides}," + f" type={self.factory.type})") + + class Activator: __slots__ = ("factory", "marker", "marker_type") diff --git a/src/dishka/exceptions.py b/src/dishka/exceptions.py index a0aa608e6..6d4befa06 100644 --- a/src/dishka/exceptions.py +++ b/src/dishka/exceptions.py @@ -37,11 +37,14 @@ class NoContextValueError(DishkaError): class UnsupportedFactoryError(DishkaError): - def __init__(self, factory_type: FactoryData) -> None: - self.factory_type = factory_type + def __init__(self, factory_data: FactoryData) -> None: + self.factory_data = factory_data def __str__(self) -> str: - return f"Unsupported factory type {self.factory_type}." + name = get_source_name(self.factory_data) + return ( + f"Unsupported factory type {self.factory_data.type} at {name}" + ) class InvalidGraphError(DishkaError): diff --git a/src/dishka/graph_builder/activation.py b/src/dishka/graph_builder/activation.py new file mode 100644 index 000000000..911c44a8f --- /dev/null +++ b/src/dishka/graph_builder/activation.py @@ -0,0 +1,191 @@ +from collections.abc import Sequence +from functools import partial +from logging import getLogger +from typing import Any + +from dishka.container_objects import CompiledFactory +from dishka.dependency_source import Factory +from dishka.dependency_source.activator import StaticEvaluationUnavailable +from dishka.entities.factory_type import FactoryType +from dishka.entities.key import ( + CompilationKey, + DependencyKey, +) +from dishka.entities.marker import BoolMarker, Marker +from dishka.entities.scope import BaseScope +from dishka.exception_base import DishkaError +from dishka.registry import Registry +from dishka.text_rendering.name import get_source_name + +logger = getLogger(__name__) + + +class StaticRegistry(Registry): + def __init__( + self, + scope: BaseScope, + *, + has_fallback: bool, + container_key: DependencyKey, + is_root: bool, + ) -> None: + super().__init__( + scope, + has_fallback=has_fallback, + container_key=container_key, + ) + self.is_root = is_root + + def _is_static_allowed(self, factory: Factory) -> bool: + if factory.type in ( + FactoryType.VALUE, + FactoryType.ALIAS, + FactoryType.SELECTOR, + ): + return True + if self.is_root and factory.type == FactoryType.CONTEXT: + return True + if ( + isinstance(factory.provides.type_hint, Marker) + and factory.type is FactoryType.FACTORY + ): + return True + return False + + def _compile_factory(self, factory: Factory) -> CompiledFactory: + if not self._is_static_allowed(factory): + raise StaticEvaluationUnavailable(factory) + return super()._compile_factory(factory) + + def _compile_factory_async(self, factory: Factory) -> CompiledFactory: + if not self._is_static_allowed(factory): + raise StaticEvaluationUnavailable(factory) + return super()._compile_factory_async(factory) + + +def static_registry( + registry: Registry, + start_scope: BaseScope, +) -> StaticRegistry: + new = StaticRegistry( + registry.scope, + has_fallback=False, + container_key=registry.container_key, + is_root=registry.scope <= start_scope, + ) + new.factories = registry.factories + return new + + +class ActivationContainer: + def __init__( + self, + context: dict[Any, Any], + registries: dict[BaseScope, Registry], + container_key: DependencyKey, + ) -> None: + self._context = context + self._registries = registries + self._container_key = container_key.as_compilation_key() + + self._parent_scopes: dict[BaseScope, BaseScope | None] = {} + prev_scope = None + for scope in registries: + self._parent_scopes[scope] = prev_scope + prev_scope = scope + + def _get(self, dep: CompilationKey, scope: BaseScope) -> Any: + registry = self._registries[scope] + compiled = registry.get_compiled(dep) + if not compiled: + parent_scope = self._parent_scopes[scope] + if parent_scope is None: + return False + return self._get(dep, parent_scope) + return bool(compiled( + partial(self._get, scope=scope), + [], + {}, + self._context, + self, + partial(self._has, scope=scope), + )) + + def is_active(self, factory: Factory) -> bool: + marker = factory.provides.as_compilation_key() + if factory.scope is None: + error = f"{get_source_name(factory)} as not scope" + raise DishkaError(error) + registry = self._registries[factory.scope] + compiled = registry.get_compiled_activation(marker) + if not compiled: + raise RuntimeError + return bool(compiled( + partial(self._get, scope=factory.scope), + [], + {}, + self._context, + self, + partial(self._has, scope=factory.scope), + )) + + def _has(self, marker: CompilationKey, scope: BaseScope) -> bool: + if marker == self._container_key: + return True + registry = self._registries[scope] + compiled = registry.get_compiled_activation(marker) + if not compiled: + parent_scope = self._parent_scopes[scope] + if parent_scope is None: + return False + return self._has(marker, parent_scope) + return bool(compiled( + partial(self._get, scope=scope), + [], + {}, + self._context, + self, + partial(self._has, scope=scope), + )) + + +class StaticEvaluator: + def __init__( + self, + registries: Sequence[Registry], + context: dict[Any, Any], + container_key: DependencyKey, + scopes: type[BaseScope], + start_scope: BaseScope | None, + ) -> None: + if start_scope is None: + start_scope = next(s for s in scopes if not s.skip) + self.registries: dict[BaseScope, Registry] = { + registry.scope: static_registry(registry, start_scope) + for registry in registries + } + activation_container = ActivationContainer( + registries=self.registries, + container_key=container_key, + context=context, + ) + self.activation_container = activation_container + + def _eval_activation(self, factory: Factory) -> None: + try: + active = self.activation_container.is_active(factory) + except StaticEvaluationUnavailable as e: + logger.debug( + "Static evaluation for %s is not available: %s", + factory.provides, + e, + ) + return + if factory.when_override == factory.when_active: + factory.when_override = BoolMarker(active) + factory.when_active = BoolMarker(active) + + def evaluate_static(self) -> None: + for registry in self.registries.values(): + for factory in list(registry.factories.values()): + self._eval_activation(factory) diff --git a/src/dishka/graph_builder/builder.py b/src/dishka/graph_builder/builder.py index 9c4ca8ffb..6ebf97d08 100644 --- a/src/dishka/graph_builder/builder.py +++ b/src/dishka/graph_builder/builder.py @@ -1,7 +1,7 @@ import itertools from collections import defaultdict from collections.abc import Collection, Sequence -from typing import cast +from typing import Any, cast from dishka.dependency_source import ( Activator, @@ -29,6 +29,7 @@ from dishka.provider import BaseProvider, ProviderWrapper from dishka.registry import Registry from dishka.text_rendering.name import get_source_name +from .activation import StaticEvaluator from .moved_objects_tracker import MovedObjectsTracker from .uniter import ( CollectionGroupProcessor, @@ -42,11 +43,15 @@ def __init__( self, *, scopes: type[BaseScope], + start_scope: BaseScope | None, container_key: DependencyKey, skip_validation: bool, validation_settings: ValidationSettings, + root_context: dict[Any, Any], ) -> None: + self.root_context = root_context self.scopes = scopes + self.start_scope = start_scope self.container_key = container_key self.skip_validation = skip_validation self.validation_settings = validation_settings @@ -550,6 +555,13 @@ def build(self) -> Sequence[Registry]: self._get_activator_factories(fixed_factories, found_markers), ) registries = self._make_registries(fixed_factories) + StaticEvaluator( + registries, + self.root_context, + self.container_key, + self.scopes, + self.start_scope, + ).evaluate_static() if not self.skip_validation: GraphValidator(registries).validate() return registries diff --git a/src/dishka/graph_builder/validator.py b/src/dishka/graph_builder/validator.py index 50e6cdce4..1dc03f190 100644 --- a/src/dishka/graph_builder/validator.py +++ b/src/dishka/graph_builder/validator.py @@ -3,6 +3,7 @@ from dishka.dependency_source import Factory from dishka.entities.key import DependencyKey +from dishka.entities.marker import BoolMarker from dishka.exceptions import ( CycleDependenciesError, GraphMissingFactoryError, @@ -57,6 +58,12 @@ def _validate_key( def _validate_factory( self, factory: Factory, registry_index: int, ) -> None: + if ( + factory.when_active == BoolMarker(False) and + factory.when_override == BoolMarker(False) + ): + return # do not validate disabled factories + self.path[factory.provides] = factory if ( factory.provides in factory.kw_dependencies.values() or diff --git a/src/dishka/provider/make_alias.py b/src/dishka/provider/make_alias.py index 5e9b93c43..decd77b4f 100644 --- a/src/dishka/provider/make_alias.py +++ b/src/dishka/provider/make_alias.py @@ -7,7 +7,7 @@ ) from dishka.entities.component import Component from dishka.entities.key import hint_to_dependency_key -from dishka.entities.marker import Marker +from dishka.entities.marker import BaseMarker, Marker from dishka.exception_base import DishkaError from dishka.exceptions import WhenOverrideConflictError from .make_factory import calc_override @@ -34,7 +34,7 @@ def alias( cache: bool = True, component: Component | None = None, override: bool = False, - when: Marker | None = None, + when: BaseMarker | None = None, ) -> CompositeDependencySource: if component is provides is None: raise ValueError( # noqa: TRY003 diff --git a/src/dishka/provider/make_decorator.py b/src/dishka/provider/make_decorator.py index 93afdb645..0fb53485d 100644 --- a/src/dishka/provider/make_decorator.py +++ b/src/dishka/provider/make_decorator.py @@ -101,5 +101,12 @@ def decorate_on_instance( source: Callable[..., Any] | type, provides: Any, scope: BaseScope | None, + when: BaseMarker | None = None, ) -> CompositeDependencySource: - return _decorate(source, provides, scope=scope, is_in_class=False) + return _decorate( + source, + provides, + scope=scope, + is_in_class=False, + when=when, + ) diff --git a/src/dishka/provider/provider.py b/src/dishka/provider/provider.py index 38f40fe9a..81276b23c 100644 --- a/src/dishka/provider/provider.py +++ b/src/dishka/provider/provider.py @@ -218,6 +218,7 @@ def alias( cache: bool = True, component: Component | None = None, override: bool = False, + when: BaseMarker | None = None, ) -> CompositeDependencySource: composite = alias( source=source, @@ -225,6 +226,7 @@ def alias( cache=cache, component=component, override=override, + when=when, ) self._add_dependency_sources(composite.dependency_sources) return composite @@ -235,11 +237,13 @@ def decorate( *, provides: Any = None, scope: BaseScope | None = None, + when: BaseMarker | None = None, ) -> CompositeDependencySource: composite = decorate_on_instance( source=source, provides=provides, scope=scope, + when=when, ) self._add_dependency_sources(composite.dependency_sources) return composite diff --git a/src/dishka/registry.py b/src/dishka/registry.py index 5572c2c30..d0a1003f7 100644 --- a/src/dishka/registry.py +++ b/src/dishka/registry.py @@ -30,7 +30,7 @@ DependencyKey, compilation_to_dependency_key, ) -from .entities.marker import Marker, unpack_marker +from .entities.marker import Has, HasContext, Marker, unpack_marker from .entities.scope import BaseScope IGNORE_TYPES: Final = ( @@ -81,9 +81,9 @@ def __init__( self.child_registry = child_registry def add_factory( - self, - factory: Factory, - provides: DependencyKey | None = None, + self, + factory: Factory, + provides: DependencyKey | None = None, ) -> None: if provides is None: provides = factory.provides @@ -97,29 +97,40 @@ def add_factory( ) self.factories[origin_key] = factory - def collect_deps(self, factory: Factory) -> list[DependencyKey]: + def collect_deps( + self, + factory: Factory, + activation_only: bool, # noqa: FBT001 + ) -> list[DependencyKey]: + activation_deps = ( + DependencyKey(m, f.when_component) + for f in factory.when_dependencies + for m in unpack_marker(f.when_override) + if not isinstance(m, (Has, HasContext)) + ) + if activation_only: + return list(activation_deps) return list(itertools.chain( factory.dependencies, factory.kw_dependencies.values(), (f.provides for f in factory.when_dependencies), - ( - DependencyKey(m, f.when_component) - for f in factory.when_dependencies - for m in unpack_marker(f.when_override) - ), + activation_deps, ( DependencyKey(m, factory.when_component) for marker in (factory.when_active, factory.when_override) for m in unpack_marker(marker) - ), + if not isinstance(m, (Has, HasContext)) + ) + , )) def _compile_deps( self, factory: Factory, + activation_only: bool, # noqa: FBT001 ) -> dict[DependencyKey, CompiledFactory]: res = {} - for dep in self.collect_deps(factory): + for dep in self.collect_deps(factory, activation_only): compiled = self.get_compiled(dep.as_compilation_key()) if compiled is not None: res[dep] = compiled @@ -128,16 +139,18 @@ def _compile_deps( def _compile_deps_async( self, factory: Factory, + activation_only: bool, # noqa: FBT001 ) -> dict[DependencyKey, CompiledFactory]: res = {} - for dep in self.collect_deps(factory): + for dep in self.collect_deps(factory, activation_only): compiled = self.get_compiled_async(dep.as_compilation_key()) if compiled is not None: res[dep] = compiled return res def get_compiled( - self, dependency: CompilationKey, + self, + dependency: CompilationKey, ) -> CompiledFactory | None: try: return self.compiled[dependency] @@ -158,17 +171,21 @@ def get_compiled( self.compiled[dependency] = None return None - compiled = compile_factory( - factory=factory, - is_async=False, - compiled_deps=self._compile_deps(factory), - container_key=self.container_key, - ) + compiled = self._compile_factory(factory) self.compiled[dependency] = compiled return compiled + def _compile_factory(self, factory: Factory) -> CompiledFactory: + return compile_factory( + factory=factory, + is_async=False, + compiled_deps=self._compile_deps(factory, False), + container_key=self.container_key, + ) + def get_compiled_async( - self, dependency: CompilationKey, + self, + dependency: CompilationKey, ) -> CompiledFactory | None: try: return self.compiled_async[dependency] @@ -191,12 +208,20 @@ def get_compiled_async( compiled = compile_factory( factory=factory, is_async=True, - compiled_deps=self._compile_deps_async(factory), + compiled_deps=self._compile_deps_async(factory, False), container_key=self.container_key, ) self.compiled_async[dependency] = compiled return compiled + def _compile_factory_async(self, factory: Factory) -> CompiledFactory: + return compile_factory( + factory=factory, + is_async=True, + compiled_deps=self._compile_deps_async(factory, False), + container_key=self.container_key, + ) + def get_compiled_activation( self, dependency: CompilationKey, ) -> CompiledFactory | None: @@ -222,7 +247,7 @@ def get_compiled_activation( compiled = compile_activation( factory=factory, is_async=False, - compiled_deps=self._compile_deps(factory), + compiled_deps=self._compile_deps(factory, True), container_key=self.container_key, ) self.compiled_activation[dependency] = compiled @@ -252,7 +277,7 @@ def get_compiled_activation_async( compiled = compile_activation( factory=factory, is_async=True, - compiled_deps=self._compile_deps_async(factory), + compiled_deps=self._compile_deps_async(factory, True), container_key=self.container_key, ) self.compiled_activation_async[dependency] = compiled @@ -350,10 +375,10 @@ def _get_type_var_factory(self, dependency: DependencyKey) -> Factory: dependencies=[], kw_dependencies={}, provides=DependencyKey(type[typevar], dependency.component), - type_=FactoryType.FACTORY, + type_=FactoryType.VALUE, is_to_bind=False, cache=False, - source=lambda: typevar, + source=typevar, when_override=None, when_active=None, when_component=None, diff --git a/tests/unit/container/when/test_factory.py b/tests/unit/container/when/test_factory.py index ed0d81610..271220291 100644 --- a/tests/unit/container/when/test_factory.py +++ b/tests/unit/container/when/test_factory.py @@ -1,3 +1,5 @@ +from typing import NewType + import pytest from dishka import Marker, Provider, Scope, make_container @@ -5,6 +7,7 @@ from dishka.exceptions import ( ActivatorOverrideError, NoActivatorError, + NoFactoryError, WhenOverrideConflictError, ) @@ -96,3 +99,78 @@ def test_activator_override(): provider.activate(lambda: True, Marker("B")) with pytest.raises(ActivatorOverrideError): make_container(provider) + + +def provide_with_dep(a: float) -> str: + return str(a) + + +def test_has_no_dep_inactive(): + provider = Provider(scope=Scope.APP) + provider.provide(lambda: "a", provides=str) + provider.activate(lambda: False, Marker("B")) + provider.provide(provide_with_dep, provides=str, when=Marker("B")) + + c = make_container(provider) + assert c.get(str) == "a" + + +def test_has_no_dep_active(): + provider = Provider(scope=Scope.APP) + provider.provide(lambda: "a", provides=str) + provider.activate(lambda: True, Marker("B")) + provider.provide(provide_with_dep, provides=str, when=Marker("B")) + + with pytest.raises(NoFactoryError): + make_container(provider) + + +def activate_zero(value: int): + return value == 0 + + +def test_activation_with_param_static_inactive(): + provider = Provider(scope=Scope.APP) + provider.provide(lambda: "a", provides=str) + provider.activate(activate_zero, Marker("ZERO")) + provider.provide(provide_with_dep, provides=str, when=Marker("ZERO")) + c = make_container(provider, context={int: 1}) + assert c.get(str) == "a" + + +def test_activation_with_param_static_active_no_dep(): + provider = Provider(scope=Scope.APP) + provider.provide(lambda: "a", provides=str) + provider.activate(activate_zero, Marker("ZERO")) + provider.provide(provide_with_dep, provides=str, when=Marker("ZERO")) + with pytest.raises(NoFactoryError): + make_container(provider, context={int: 0}) + + +@pytest.mark.parametrize(("number", "string"), [ + (0, "b"), + (1, "a"), +]) +def test_activation_with_param_dynamic(number, string): + provider = Provider(scope=Scope.REQUEST) + provider.provide(lambda: "a", provides=str) + provider.provide(lambda: "b", provides=str, when=Marker("ZERO")) + provider.from_context(int) + provider.activate(activate_zero, Marker("ZERO")) + c = make_container(provider) + with c({int: number}) as request_c: + assert request_c.get(str) == string + + +def test_activation_with_selector_alias_inactive(): + int1 = NewType("int1", int) + int2 = NewType("int2", int) + provider = Provider(scope=Scope.APP) + provider.provide(lambda: "a", provides=str) + provider.provide(provide_with_dep, provides=str, when=Marker("ZERO")) + provider.activate(lambda: True, Marker("another")) + provider.alias(int1, provides=int) + provider.alias(int2, provides=int, when=Marker("another")) + provider.activate(activate_zero, Marker("ZERO")) + c = make_container(provider, context={int1: 1, int2: 2}) + assert c.get(str) == "a" diff --git a/tests/unit/container/when/test_has.py b/tests/unit/container/when/test_has.py index 8baab0ec8..12a525e71 100644 --- a/tests/unit/container/when/test_has.py +++ b/tests/unit/container/when/test_has.py @@ -34,6 +34,21 @@ async def test_has_async(*, register: bool, value: str): assert await c.get(str) == value +@pytest.mark.parametrize(("register", "value"), [ + (True, "b"), + (False, "a"), +]) +def test_has_async_sync(*, register: bool, value: str): + provider = Provider(scope=Scope.APP) + if register: + provider.provide(lambda: 42, provides=int) + provider.provide(lambda: "a", provides=str) + provider.provide(lambda: "b", provides=str, when=Has(int)) + + c = make_async_container(provider) + assert c.get_sync(str) == value + + @pytest.mark.parametrize(("register", "value"), [ (True, "b"), (False, "a"), @@ -48,3 +63,38 @@ def test_has_chained(*, register: bool, value: str): c = make_container(provider) assert c.get(str) == value + + +def provide_with_dep(a: int) -> str: + return str(a) + + +@pytest.mark.parametrize("scope", [Scope.RUNTIME, Scope.APP]) +def test_has_no_dep(scope): + provider = Provider(scope=scope) + provider.provide(lambda: "a", provides=str) + provider.provide(provide_with_dep, provides=str, when=Has(float)) + + c = make_container(provider) + assert c.get(str) == "a" + + +@pytest.mark.parametrize("scope", [Scope.RUNTIME, Scope.APP]) +def test_has_no_dep_nested_scope(scope): + provider = Provider(scope=scope) + provider.provide(lambda: "a", provides=str) + provider.provide(provide_with_dep, provides=str, when=Has(float)) + + c = make_container(provider) + with c() as request_c: + assert request_c.get(str) == "a" + + +def test_has_wrong_scope(): + provider = Provider(scope=Scope.APP) + provider.provide(lambda: "a", provides=str) + provider.provide(lambda: 1.0, provides=float, scope=Scope.STEP) + provider.provide(provide_with_dep, provides=str, when=Has(float)) + + c = make_container(provider) + assert c.get(str) == "a"