diff --git a/src/dishka/dependency_source/activator.py b/src/dishka/dependency_source/activator.py index 78b243c8e..3e0d516f8 100644 --- a/src/dishka/dependency_source/activator.py +++ b/src/dishka/dependency_source/activator.py @@ -82,6 +82,7 @@ def as_factory( }, type_=factory.type, cache=factory.cache, + validate_unconditional_when=factory.validate_unconditional_when, when_override=factory.when_override, when_active=factory.when_active, when_component=factory.when_component, diff --git a/src/dishka/dependency_source/alias.py b/src/dishka/dependency_source/alias.py index 01b5f7659..928ef8249 100644 --- a/src/dishka/dependency_source/alias.py +++ b/src/dishka/dependency_source/alias.py @@ -52,6 +52,7 @@ def as_factory( kw_dependencies={}, type_=FactoryType.ALIAS, cache=self.cache, + validate_unconditional_when=None, when_override=self.when_override, when_active=self.when_active, when_component=( diff --git a/src/dishka/dependency_source/context_var.py b/src/dishka/dependency_source/context_var.py index 3be794771..950e08e5d 100644 --- a/src/dishka/dependency_source/context_var.py +++ b/src/dishka/dependency_source/context_var.py @@ -31,7 +31,7 @@ def __init__( def as_factory( self, component: Component, ) -> Factory: - override = (BoolMarker(True) if self.override else None) + override = BoolMarker(True) if self.override else None if component == DEFAULT_COMPONENT: return Factory( @@ -43,6 +43,7 @@ def as_factory( kw_dependencies={}, type_=FactoryType.CONTEXT, cache=False, + validate_unconditional_when=None, when_override=override, when_active=HasContext(self.provides.type_hint), when_component=component, diff --git a/src/dishka/dependency_source/decorator.py b/src/dishka/dependency_source/decorator.py index 7b3750f62..98e90f5ed 100644 --- a/src/dishka/dependency_source/decorator.py +++ b/src/dishka/dependency_source/decorator.py @@ -69,6 +69,7 @@ def as_factory( }, type_=self.factory.type, cache=cache, + validate_unconditional_when=None, when_override=self.when, when_active=self.when, when_component=self.factory.when_component or component, diff --git a/src/dishka/dependency_source/factory.py b/src/dishka/dependency_source/factory.py index 638bb82d4..6e31191a7 100644 --- a/src/dishka/dependency_source/factory.py +++ b/src/dishka/dependency_source/factory.py @@ -35,6 +35,7 @@ class Factory(FactoryData): "dependencies", "is_to_bind", "kw_dependencies", + "validate_unconditional_when", "when_active", "when_component", "when_dependencies", @@ -51,6 +52,7 @@ def __init__( type_: FactoryType, is_to_bind: bool, cache: bool, + validate_unconditional_when: bool | None, when_override: BaseMarker | None, when_active: BaseMarker | None, when_component: Component | None, @@ -82,6 +84,7 @@ def __init__( self.kw_dependencies = kw_dependencies self.is_to_bind = is_to_bind self.cache = cache + self.validate_unconditional_when = validate_unconditional_when self.when_active = when_active self.when_component = when_component self.when_dependencies = when_dependencies @@ -106,6 +109,7 @@ def __get__(self, instance: Any, owner: Any) -> Factory: type_=self.type, is_to_bind=False, cache=self.cache, + validate_unconditional_when=self.validate_unconditional_when, when_override=when_override, when_active=when_active, when_component=self.when_component, @@ -127,6 +131,7 @@ def with_component(self, component: Component) -> Factory: is_to_bind=self.is_to_bind, cache=self.cache, type_=self.type, + validate_unconditional_when=self.validate_unconditional_when, when_override=self.when_override, when_active=self.when_active, when_component=( @@ -147,6 +152,7 @@ def with_scope(self, scope: BaseScope) -> Factory: is_to_bind=self.is_to_bind, cache=self.cache, type_=self.type, + validate_unconditional_when=self.validate_unconditional_when, when_override=self.when_override, when_active=self.when_active, when_component=self.when_component, @@ -171,6 +177,7 @@ def replace( is_to_bind=self.is_to_bind, cache=self.cache, type_=self.type, + validate_unconditional_when=self.validate_unconditional_when, when_override=coalesce(when_override, self.when_override), when_active=coalesce(when_active, self.when_active), when_component=coalesce(when_component, self.when_component), diff --git a/src/dishka/dependency_source/factory_union_mode.py b/src/dishka/dependency_source/factory_union_mode.py index 8ed0f9c25..5a56dc5cc 100644 --- a/src/dishka/dependency_source/factory_union_mode.py +++ b/src/dishka/dependency_source/factory_union_mode.py @@ -53,6 +53,7 @@ def as_factory(self) -> Factory | None: when_active=None, when_override=None, cache=self.cache, + validate_unconditional_when=None, when_component=self.provides.component, is_to_bind=False, type_=FactoryType.COLLECTION, diff --git a/src/dishka/entities/validation_settings.py b/src/dishka/entities/validation_settings.py index 506d64210..20ebfde4a 100644 --- a/src/dishka/entities/validation_settings.py +++ b/src/dishka/entities/validation_settings.py @@ -9,6 +9,8 @@ class ValidationSettings: implicit_override: bool = False # check if decorator was not applied to any factory nothing_decorated: bool = True + # validate factories without explicit `when` conditions at build time + validate_unconditional_when: bool = False DEFAULT_VALIDATION = ValidationSettings() @@ -16,4 +18,5 @@ class ValidationSettings: nothing_overridden=True, implicit_override=True, nothing_decorated=True, + validate_unconditional_when = True, ) diff --git a/src/dishka/graph_builder/activation.py b/src/dishka/graph_builder/activation.py index 4b6ad776f..e0fb6f257 100644 --- a/src/dishka/graph_builder/activation.py +++ b/src/dishka/graph_builder/activation.py @@ -166,6 +166,8 @@ def __init__( self.activation_container = activation_container def _eval_activation(self, factory: Factory) -> None: + if factory.when_active is None and factory.when_override is None: + return try: active = self.activation_container.is_active(factory) except StaticEvaluationUnavailable as e: diff --git a/src/dishka/graph_builder/builder.py b/src/dishka/graph_builder/builder.py index ee534e427..c95ffa063 100644 --- a/src/dishka/graph_builder/builder.py +++ b/src/dishka/graph_builder/builder.py @@ -563,5 +563,5 @@ def build(self) -> Sequence[Registry]: self.start_scope, ).evaluate_static() if not self.skip_validation: - GraphValidator(registries).validate() + GraphValidator(registries, self.validation_settings).validate() return registries diff --git a/src/dishka/graph_builder/uniter.py b/src/dishka/graph_builder/uniter.py index a6b056a2f..67fe390a9 100644 --- a/src/dishka/graph_builder/uniter.py +++ b/src/dishka/graph_builder/uniter.py @@ -83,6 +83,7 @@ def unite( type_=FactoryType.SELECTOR, kw_dependencies={}, source=None, + validate_unconditional_when=None, when_override=None, when_active=or_markers(*( factory.when_active diff --git a/src/dishka/graph_builder/validator.py b/src/dishka/graph_builder/validator.py index 1dc03f190..2b0d0e93f 100644 --- a/src/dishka/graph_builder/validator.py +++ b/src/dishka/graph_builder/validator.py @@ -4,6 +4,7 @@ from dishka.dependency_source import Factory from dishka.entities.key import DependencyKey from dishka.entities.marker import BoolMarker +from dishka.entities.validation_settings import ValidationSettings from dishka.exceptions import ( CycleDependenciesError, GraphMissingFactoryError, @@ -14,10 +15,33 @@ class GraphValidator: - def __init__(self, registries: Sequence[Registry]) -> None: + def __init__( + self, + registries: Sequence[Registry], + validation_settings: ValidationSettings, + ) -> None: self.registries = registries + self.validation_settings = validation_settings self.path: dict[DependencyKey, Factory] = {} self.valid_keys: dict[DependencyKey, bool] = {} + self.current_factory: Factory | None = None + + def _can_validate_now(self, factory: Factory) -> bool: + return self._is_resolved(factory.when_active) and self._is_resolved( + factory.when_override, + ) + + def _is_resolved(self, when: object) -> bool: + if when == BoolMarker(True): + return True + if when is None: + if self.current_factory is None: + return self.validation_settings.validate_unconditional_when + factory_override = self.current_factory.validate_unconditional_when + if factory_override is not None: + return factory_override + return self.validation_settings.validate_unconditional_when + return False def _validate_key( self, @@ -56,13 +80,13 @@ def _validate_key( ) def _validate_factory( - self, factory: Factory, registry_index: int, + self, + factory: Factory, + registry_index: int, ) -> None: - if ( - factory.when_active == BoolMarker(False) and - factory.when_override == BoolMarker(False) - ): - return # do not validate disabled factories + self.current_factory = factory + if not self._can_validate_now(factory): + return self.path[factory.provides] = factory if ( diff --git a/src/dishka/provider/make_factory.py b/src/dishka/provider/make_factory.py index f624d4d10..57d35c471 100644 --- a/src/dishka/provider/make_factory.py +++ b/src/dishka/provider/make_factory.py @@ -266,6 +266,7 @@ def _make_factory_by_class( cache: bool, override: bool, when: BaseMarker | None, + validate_unconditional_when: bool | None, ) -> Factory: if not provides: provides = source @@ -293,6 +294,7 @@ def _make_factory_by_class( provides=hint_to_dependency_key(provides), is_to_bind=False, cache=cache, + validate_unconditional_when=validate_unconditional_when, when_override=calc_override(when=when, override=override), when_active=when, when_component=None, @@ -328,6 +330,7 @@ def _make_factory_by_function( override: bool, check_self_name: bool, when: BaseMarker | None, + validate_unconditional_when: bool | None, ) -> Factory: # typing.cast is applied as unwrap takes a Callable object raw_source = unwrap(cast(Callable[..., Any], source)) @@ -369,6 +372,7 @@ def _make_factory_by_function( provides=hint_to_dependency_key(provides), is_to_bind=is_in_class, cache=cache, + validate_unconditional_when=validate_unconditional_when, when_override=calc_override(when=when, override=override), when_active=when, when_component=None, @@ -384,6 +388,7 @@ def _make_factory_by_static_method( cache: bool, override: bool, when: BaseMarker | None, + validate_unconditional_when: bool | None, ) -> Factory: if missing_hints := _params_without_hints(source, skip_self=False): raise MissingHintsError(source, missing_hints) @@ -413,6 +418,7 @@ def _make_factory_by_static_method( provides=hint_to_dependency_key(provides), is_to_bind=False, cache=cache, + validate_unconditional_when=validate_unconditional_when, when_override=calc_override(when=when, override=override), when_active=when, when_component=None, @@ -440,12 +446,13 @@ def _make_factory_by_other_callable( cache: bool, override: bool, when: BaseMarker | None, + validate_unconditional_when: bool | None, ) -> Factory: if _is_bound_method(source): to_check = source.__func__ # type: ignore[attr-defined] is_in_class = True else: - call_method = source.__call__ # type: ignore[operator] + call_method = source.__call__ # type: ignore[operator] if _is_bound_method(call_method): to_check = call_method.__func__ is_in_class = True @@ -461,6 +468,7 @@ def _make_factory_by_other_callable( override=override, check_self_name=False, when=when, + validate_unconditional_when=validate_unconditional_when, ) if factory.is_to_bind: dependencies = factory.dependencies[1:] # remove `self` @@ -476,6 +484,7 @@ def _make_factory_by_other_callable( provides=factory.provides, is_to_bind=False, cache=cache, + validate_unconditional_when=validate_unconditional_when, when_override=calc_override(when=when, override=override), when_active=when, when_component=None, @@ -509,6 +518,7 @@ def make_factory( is_in_class: bool, override: bool, when: BaseMarker | None = None, + validate_unconditional_when: bool | None, ) -> Factory: provides, source = _extract_source(provides, source) @@ -528,6 +538,7 @@ def make_factory( cache=cache, override=override, when=when, + validate_unconditional_when=validate_unconditional_when, ) elif isfunction(source) or isinstance(source, classmethod): return _make_factory_by_function( @@ -539,6 +550,7 @@ def make_factory( override=override, check_self_name=True, when=when, + validate_unconditional_when=validate_unconditional_when, ) elif isbuiltin(source): return _make_factory_by_function( @@ -550,6 +562,7 @@ def make_factory( override=override, check_self_name=False, when=when, + validate_unconditional_when=validate_unconditional_when, ) elif isinstance(source, staticmethod): return _make_factory_by_static_method( @@ -559,6 +572,7 @@ def make_factory( cache=cache, override=override, when=when, + validate_unconditional_when=validate_unconditional_when, ) elif callable(source) and not source_origin: return _make_factory_by_other_callable( @@ -568,6 +582,7 @@ def make_factory( cache=cache, override=override, when=when, + validate_unconditional_when=validate_unconditional_when, ) else: raise NotAFactoryError(source) @@ -583,16 +598,20 @@ def _provide( recursive: bool = False, override: bool = False, when: BaseMarker | None = None, + validate_unconditional_when: bool | None, ) -> CompositeDependencySource: if when and override: raise WhenOverrideConflictError composite = ensure_composite(source) factory = make_factory( - provides=provides, scope=scope, - source=composite.origin, cache=cache, + provides=provides, + scope=scope, + source=composite.origin, + cache=cache, is_in_class=is_in_class, override=override, when=when, + validate_unconditional_when=validate_unconditional_when, ) composite.dependency_sources.extend(unpack_factory(factory)) if not recursive: @@ -612,6 +631,7 @@ def _provide( is_in_class=is_in_class, override=override, when=when, + validate_unconditional_when=validate_unconditional_when, ) additional_sources.extend(additional.dependency_sources) composite.dependency_sources.extend(additional_sources) @@ -627,12 +647,18 @@ def provide_on_instance( recursive: bool = False, override: bool = False, when: BaseMarker | None = None, + validate_unconditional_when: bool | None, ) -> CompositeDependencySource: return _provide( - provides=provides, scope=scope, source=source, cache=cache, + provides=provides, + scope=scope, + source=source, + cache=cache, is_in_class=False, - recursive=recursive, override=override, + recursive=recursive, + override=override, when=when, + validate_unconditional_when=validate_unconditional_when, ) @@ -645,6 +671,7 @@ def provide( recursive: bool = False, override: bool = False, when: BaseMarker | None = None, + validate_unconditional_when: bool | None, ) -> Callable[[Callable[..., Any]], CompositeDependencySource]: ... @@ -659,6 +686,7 @@ def provide( recursive: bool = False, override: bool = False, when: BaseMarker | None = None, + validate_unconditional_when: bool | None, ) -> CompositeDependencySource: ... @@ -672,6 +700,7 @@ def provide( recursive: bool = False, override: bool = False, when: BaseMarker | None = None, + validate_unconditional_when: bool | None, ) -> CompositeDependencySource | Callable[ [Callable[..., Any]], CompositeDependencySource, ]: @@ -701,14 +730,14 @@ def provide( return _provide( provides=provides, scope=scope, source=source, cache=cache, is_in_class=True, recursive=recursive, override=override, - when=when, + when=when, validate_unconditional_when=validate_unconditional_when, ) def scoped(func: Callable[..., Any]) -> CompositeDependencySource: return _provide( provides=provides, scope=scope, source=func, cache=cache, is_in_class=True, recursive=recursive, override=override, - when=when, + when=when, validate_unconditional_when=validate_unconditional_when, ) return scoped @@ -723,6 +752,7 @@ def _provide_all( recursive: bool, override: bool = False, when: BaseMarker | None = None, + validate_unconditional_when: bool | None = None, ) -> CompositeDependencySource: composite = CompositeDependencySource(None) for single_provides in provides: @@ -735,6 +765,7 @@ def _provide_all( recursive=recursive, override=override, when=when, + validate_unconditional_when=validate_unconditional_when, ) composite.dependency_sources.extend(source.dependency_sources) return composite @@ -747,12 +778,13 @@ def provide_all( recursive: bool = False, override: bool = False, when: BaseMarker | None = None, + validate_unconditional_when: bool | None = None, ) -> CompositeDependencySource: return _provide_all( provides=provides, scope=scope, cache=cache, is_in_class=True, recursive=recursive, override=override, - when=when, + when=when, validate_unconditional_when=validate_unconditional_when, ) @@ -763,10 +795,11 @@ def provide_all_on_instance( recursive: bool = False, override: bool = False, when: BaseMarker | None = None, + validate_unconditional_when: bool | None = None, ) -> CompositeDependencySource: return _provide_all( provides=provides, scope=scope, cache=cache, is_in_class=False, recursive=recursive, override=override, - when=when, + when=when, validate_unconditional_when=validate_unconditional_when, ) diff --git a/src/dishka/provider/provider.py b/src/dishka/provider/provider.py index 81276b23c..9909a6546 100644 --- a/src/dishka/provider/provider.py +++ b/src/dishka/provider/provider.py @@ -173,6 +173,7 @@ def provide( recursive: bool = False, override: bool = False, when: BaseMarker | None = None, + validate_unconditional_when: bool | None = None, ) -> CompositeDependencySource: if scope is None: scope = self.scope @@ -184,6 +185,7 @@ def provide( recursive=recursive, override=override, when=when, + validate_unconditional_when=validate_unconditional_when, ) self._add_dependency_sources(composite.dependency_sources) return composite @@ -196,6 +198,7 @@ def provide_all( recursive: bool = False, override: bool = False, when: BaseMarker | None = None, + validate_unconditional_when: bool | None = None, ) -> CompositeDependencySource: if scope is None: scope = self.scope @@ -206,6 +209,7 @@ def provide_all( recursive=recursive, override=override, when=when, + validate_unconditional_when=validate_unconditional_when, ) self._add_dependency_sources(composite.dependency_sources) return composite diff --git a/src/dishka/provider/unpack_provides.py b/src/dishka/provider/unpack_provides.py index ce8d865b1..38ad77bf7 100644 --- a/src/dishka/provider/unpack_provides.py +++ b/src/dishka/provider/unpack_provides.py @@ -41,6 +41,7 @@ def unpack_factory(factory: Factory) -> Sequence[DependencySource]: scope=factory.scope, is_to_bind=factory.is_to_bind, cache=factory.cache, + validate_unconditional_when=factory.validate_unconditional_when, provides=hint_to_dependency_key( provides_first, ).with_component(factory.provides.component), diff --git a/src/dishka/registry.py b/src/dishka/registry.py index d0a1003f7..b9292ded5 100644 --- a/src/dishka/registry.py +++ b/src/dishka/registry.py @@ -378,6 +378,7 @@ def _get_type_var_factory(self, dependency: DependencyKey) -> Factory: type_=FactoryType.VALUE, is_to_bind=False, cache=False, + validate_unconditional_when=None, source=typevar, when_override=None, when_active=None, @@ -427,6 +428,7 @@ def _specialize_generic( type_=factory.type, scope=factory.scope, cache=factory.cache, + validate_unconditional_when=factory.validate_unconditional_when, when_override=factory.when_override, when_active=factory.when_active, when_component=factory.when_component, diff --git a/tests/unit/container/when/test_factory.py b/tests/unit/container/when/test_factory.py index 271220291..1d6e52d1f 100644 --- a/tests/unit/container/when/test_factory.py +++ b/tests/unit/container/when/test_factory.py @@ -2,7 +2,7 @@ import pytest -from dishka import Marker, Provider, Scope, make_container +from dishka import Marker, Provider, Scope, ValidationSettings, make_container from dishka.exception_base import InvalidMarkerError from dishka.exceptions import ( ActivatorOverrideError, @@ -12,6 +12,108 @@ ) +def is_zero(value: int) -> bool: + return value == 0 + + +def fallback() -> str: + return "a" + + +def needs_float(value: float) -> str: + return str(value) + + +@pytest.mark.parametrize( ("number", "expected", "raises"), [ + (1, "a", False), + (0, None, True), +]) +def test_unresolved_conditional_branch_is_validated_at_runtime( + *, + number: int, + expected: str | None, + raises: bool, +): + provider = Provider(scope=Scope.APP) + provider.activate(is_zero, Marker("ZERO")) + provider.provide(lambda: number, provides=int) + provider.provide(fallback, provides=str) + provider.provide(needs_float, provides=str, when=Marker("ZERO")) + container = make_container(provider) + if raises: + with pytest.raises(NoFactoryError): + container.get(str) + else: + assert container.get(str) == expected + + +@pytest.mark.parametrize("validate_unconditional_when", [False, True]) +def test_validate_unconditional_when_setting(*, validate_unconditional_when: bool,): + provider = Provider(scope=Scope.APP) + provider.provide(needs_float, provides=str) + + if validate_unconditional_when: + with pytest.raises(NoFactoryError): + make_container( + provider, + validation_settings=ValidationSettings( + validate_unconditional_when=True, + ), + ) + else: + container = make_container( + provider, + validation_settings=ValidationSettings( + validate_unconditional_when=False, + ), + ) + + with pytest.raises(NoFactoryError): + container.get(str) + + +@pytest.mark.parametrize( + ("global_setting", "factory_override", "build_fails"), + [ + (False, None, False), + (True, None, True), + (False, True, True), + (True, False, False), + ], +) +def test_validate_unconditional_when_factory_override( + *, + global_setting: bool, + factory_override: bool | None, + build_fails: bool, +): + provider = Provider(scope=Scope.APP) + provider.provide( + needs_float, + provides=str, + validate_unconditional_when=factory_override, + ) + + if build_fails: + with pytest.raises(NoFactoryError): + make_container( + provider, + validation_settings=ValidationSettings( + validate_unconditional_when=global_setting, + ), + ) + else: + container = make_container( + provider, + validation_settings=ValidationSettings( + validate_unconditional_when=global_setting, + ), + ) + + with pytest.raises(NoFactoryError): + container.get(str) + + @pytest.mark.parametrize(("value", "b_is_active"), [ ("a", False), ("b", True),