diff --git a/docs/advanced/when.rst b/docs/advanced/when.rst index cb0b85873..f321447ae 100644 --- a/docs/advanced/when.rst +++ b/docs/advanced/when.rst @@ -28,9 +28,9 @@ To set conditional activation you create special ``Marker`` objects and use them .. code-block:: python - from dishka import Provider, provide, Scope + from dishka import Marker, Provider, Scope, provide - class MyProvider(Provider) + class MyProvider(Provider): @provide(scope=Scope.APP) def base_impl(self) -> Cache: return NormalCacheImpl() @@ -49,9 +49,9 @@ It can be the same or another provider while you pass when creating a container. .. code-block:: python - from dishka import activate, Provider + from dishka import Marker, Provider, activate - class MyProvider(Provider) + class MyProvider(Provider): @activate(Marker("debug")) def is_debug(self) -> bool: return False @@ -60,7 +60,7 @@ This function can use other objects as well. For example, we can pass config usi .. code-block:: python - class MyProvider(Provider) + class MyProvider(Provider): config = from_context(Config, scope=Scope.APP) @activate(Marker("debug")) @@ -78,7 +78,7 @@ More general pattern is to create own marker type and register a single activato pass - class MyProvider(Provider) + class MyProvider(Provider): config = from_context(Config, scope=Scope.APP) @activate(EnvMarker) @@ -154,14 +154,21 @@ In case you want to activate some features when specific objects are available y * it is activated * if it actually presents in context while being registered as ``from_context`` +``Has(T)`` implicitly registers ``T`` for graph validation. Use +``from_context(T, ...)`` when ``Has(T)`` should become true only after a real +context value is passed. + +The implicit registration only helps validation. ``Has(T)`` still stays false +until some real provider or real context value is available. + For example: .. code-block:: python - from dishka import Provider, provide, Scope + from dishka import Has, Provider, Scope, from_context, make_container, provide - class MyProvider(Provider) + class MyProvider(Provider): config = from_context(RedisConfig, scope=Scope.APP) @provide(scope=Scope.APP) @@ -182,6 +189,6 @@ For example: In this case, -* ``memcached_impl`` is not used because no factory for ``MemcachedConfig`` is provided +* ``memcached_impl`` is not used because no real factory for ``MemcachedConfig`` is provided * ``redis_impl`` is not used while it is registered as ``from_context`` but no real value is provided. -* ``base_impl`` is used as a default one, because none of later is active \ No newline at end of file +* ``base_impl`` is used as a default one, because none of later is active diff --git a/src/dishka/graph_builder/validator.py b/src/dishka/graph_builder/validator.py index 50e6cdce4..841e497af 100644 --- a/src/dishka/graph_builder/validator.py +++ b/src/dishka/graph_builder/validator.py @@ -1,8 +1,19 @@ import itertools from collections.abc import Sequence +from enum import Enum from dishka.dependency_source import Factory +from dishka.entities.component import Component from dishka.entities.key import DependencyKey +from dishka.entities.marker import ( + AndMarker, + BaseMarker, + BoolMarker, + Has, + Marker, + NotMarker, + OrMarker, +) from dishka.exceptions import ( CycleDependenciesError, GraphMissingFactoryError, @@ -12,12 +23,115 @@ from dishka.registry import Registry +class MarkerValue(Enum): + TRUE = True + FALSE = False + UNKNOWN = None + + class GraphValidator: def __init__(self, registries: Sequence[Registry]) -> None: self.registries = registries self.path: dict[DependencyKey, Factory] = {} self.valid_keys: dict[DependencyKey, bool] = {} + def _get_factory( + self, + key: DependencyKey, + registry_index: int, + ) -> Factory | None: + for index in range(registry_index + 1): + factory = self.registries[index].get_factory(key) + if factory is not None: + return factory + return None + + def _has_reachable_factory( + self, + key: DependencyKey, + registry_index: int, + ) -> bool: + return self._get_factory(key, registry_index) is not None + + def _marker_component(self, component: Component | None) -> Component: + if component is None: + raise TypeError + return component + + def _invert_marker_value(self, value: MarkerValue) -> MarkerValue: + if value is MarkerValue.TRUE: + return MarkerValue.FALSE + if value is MarkerValue.FALSE: + return MarkerValue.TRUE + return MarkerValue.UNKNOWN + + def _and_marker_values( + self, + left: MarkerValue, + right: MarkerValue, + ) -> MarkerValue: + if MarkerValue.FALSE in (left, right): + return MarkerValue.FALSE + if MarkerValue.TRUE == left == right: + return MarkerValue.TRUE + return MarkerValue.UNKNOWN + + def _or_marker_values( + self, + left: MarkerValue, + right: MarkerValue, + ) -> MarkerValue: + if MarkerValue.TRUE in (left, right): + return MarkerValue.TRUE + if MarkerValue.FALSE == left == right: + return MarkerValue.FALSE + return MarkerValue.UNKNOWN + + def _eval_marker( + self, + marker: BaseMarker | None, + component: Component | None, + registry_index: int, + ) -> MarkerValue: + result = MarkerValue.UNKNOWN + match marker: + case None | BoolMarker(True): + result = MarkerValue.TRUE + case BoolMarker(False): + result = MarkerValue.FALSE + case AndMarker(): + result = self._and_marker_values( + self._eval_marker(marker.left, component, registry_index), + self._eval_marker(marker.right, component, registry_index), + ) + case OrMarker(): + result = self._or_marker_values( + self._eval_marker(marker.left, component, registry_index), + self._eval_marker(marker.right, component, registry_index), + ) + case NotMarker(): + result = self._invert_marker_value( + self._eval_marker( + marker.marker, + component, + registry_index, + ), + ) + case Has(): + key = DependencyKey( + marker.value, + self._marker_component(component), + ) + if self._has_reachable_factory(key, registry_index): + result = MarkerValue.UNKNOWN + else: + result = MarkerValue.FALSE + case Marker(): + result = MarkerValue.UNKNOWN + case _: + result = MarkerValue.UNKNOWN + return result + def _validate_key( self, key: DependencyKey, @@ -65,13 +179,19 @@ def _validate_factory( raise CycleDependenciesError([factory]) try: - for dep in itertools.chain( - factory.dependencies, - factory.kw_dependencies.values(), - ): - # ignore TypeVar and const parameters - if not dep.is_type_var() and not dep.is_const(): - self._validate_key(dep, registry_index) + when_active = self._eval_marker( + factory.when_active, + factory.when_component, + registry_index, + ) + if when_active is not MarkerValue.FALSE: + for dep in itertools.chain( + factory.dependencies, + factory.kw_dependencies.values(), + ): + # ignore TypeVar and const parameters + if not dep.is_type_var() and not dep.is_const(): + self._validate_key(dep, registry_index) except NoFactoryError as e: e.add_path(factory) raise diff --git a/tests/unit/container/when/test_has.py b/tests/unit/container/when/test_has.py index 8baab0ec8..7c311d1ad 100644 --- a/tests/unit/container/when/test_has.py +++ b/tests/unit/container/when/test_has.py @@ -1,6 +1,17 @@ +from typing import Any + import pytest -from dishka import Has, Provider, Scope, make_async_container, make_container +from dishka import ( + Has, + Provider, + Scope, + from_context, + make_async_container, + make_container, + provide, +) +from dishka.exceptions import GraphMissingFactoryError, NoActiveFactoryError @pytest.mark.parametrize(("register", "value"), [ @@ -48,3 +59,149 @@ def test_has_chained(*, register: bool, value: str): c = make_container(provider) assert c.get(str) == value + + +@pytest.mark.parametrize( + ("enable_conditional_provider", "successful"), [ + (False, False), + (True, True), + ], +) +def test_has_with_declared_context_dependency( + *, enable_conditional_provider: bool, successful: bool, +): + class StringProvider(Provider): + @provide(when=Has(int), scope=Scope.APP) + def setup(self, cfg: int) -> str: + return "ok" + + class IntProvider(Provider): + int_config_instance = provide( + source=lambda self: 42, provides=int, scope=Scope.APP, + ) + + providers: list[Any] = [StringProvider()] + if enable_conditional_provider: + providers.append(IntProvider()) + + container = make_container(*providers, context={}) + + if successful: + assert isinstance(container.get(str), str) + else: + with pytest.raises(NoActiveFactoryError): + container.get(str) + + +def test_not_has_with_missing_dependency_fails_validation(): + class StringProvider(Provider): + scope = Scope.APP + + @provide(when=~Has(int)) + def setup(self, cfg: int) -> str: + return "ok" + + with pytest.raises(GraphMissingFactoryError): + make_container(StringProvider(), context={}) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("is_async", "register_int", "value"), + [ + (False, True, "b"), + (False, False, "a"), + (True, True, "b"), + (True, False, "a"), + ], +) +async def test_provider_declare_method_does_not_make_has_active( + *, is_async: bool, register_int: bool, value: str, +): + provider = Provider(scope=Scope.APP) + provider.provide(lambda: "a", provides=str) + provider.provide(lambda: "b", provides=str, when=Has(int)) + + provider2 = Provider(scope=Scope.APP) + if register_int: + provider2.provide(lambda: 42, provides=int) + + if is_async: + container = make_async_container(provider, provider2, context={}) + assert await container.get(str) == value + else: + container = make_container(provider, provider2, context={}) + assert container.get(str) == value + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("is_async", "register_ctx", "value"), + [ + (False, True, "b"), + (False, False, "a"), + (True, True, "b"), + (True, False, "a"), + ], +) +async def test_from_context_requires_real_context_value_for_has( + *, is_async: bool, register_ctx: bool, value: str, +): + provider = Provider(scope=Scope.APP) + provider.from_context(int) + provider.provide(lambda: "a", provides=str) + provider.provide(lambda: "b", provides=str, when=Has(int)) + if register_ctx: + ctx = {int: 42} + else: + ctx = {} + + if is_async: + container = make_async_container(provider, context=ctx) + assert await container.get(str) == value + else: + container = make_container(provider, context=ctx) + assert container.get(str) == value + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("is_async", "register_ctx", "successful"), + [ + (False, True, True), + (False, False, False), + (True, True, True), + (True, False, False), + ], +) +async def test_from_context_keeps_has_runtime_dependent( + *, + is_async: bool, + register_ctx: bool, + successful: bool, +): + class StringProvider(Provider): + scope = Scope.APP + + cfg = from_context(int) + + @provide(when=Has(int)) + def setup(self, cfg: int) -> str: + return "ok" + + ctx = {int: 42} if register_ctx else {} + + if is_async: + container = make_async_container(StringProvider(), context=ctx) + if successful: + assert await container.get(str) == "ok" + else: + with pytest.raises(NoActiveFactoryError): + await container.get(str) + else: + container = make_container(StringProvider(), context=ctx) + if successful: + assert container.get(str) == "ok" + else: + with pytest.raises(NoActiveFactoryError): + container.get(str)