Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions docs/advanced/when.rst
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,20 @@ In case you want to activate some features when specific objects are available y
* it is activated
* if it actually presents in context while being registered as ``from_context``

``Has(T)`` implicitly registers ``T`` for graph validation.

Use ``from_context(T, ...)`` when ``Has(T)`` should become true only after a real
context value is passed.

The implicit registration only helps validation. ``Has(T)`` still stays false
until some real provider or real context value is available.


For example:

.. code-block:: python

from dishka import Provider, provide, Scope
from dishka import Provider, from_context, provide, Scope
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.. _when:

Conditional activation
============================================

There are some cases when you want to declare a factory or decorator in a provider but use only when a certain condition is met. For example:

* Apply decorators in debug mode
* Use cache if redis config provided in context
* Implement A/B testing with different implementations based on HTTP header
* Provide different identity provider classes based on available context objects: web request or queued messages.

This can be achieved with "activation" approach. Key concepts here:

* **Marker** - special object to distinguish which implementations should be used.
* **Activator** or **activation function** - special function registered in provider and taking decision if marker is active or not.
* **activation condition** - expression with marker objects set in dependency source dynamically associated with activators to select between multiple implementations or enable decorators


.. note::

    The activation feature makes the application harder to analyze and can also affect performance, so use it wisely.

Basic usage
---------------------------------

To set conditional activation you create special ``Marker`` objects and use them in ``when=`` condition inside ``provide``, ``decorate`` or ``alias``.

.. code-block:: python

-    from dishka import Provider, provide, Scope
+    from dishka import Marker, Provider, Scope, provide

-    class MyProvider(Provider)
+    class MyProvider(Provider):
        @provide(scope=Scope.APP)
        def base_impl(self) -> Cache:
            return NormalCacheImpl()

        @provide(when=Marker("debug"), scope=Scope.APP)
        def debug_impl(self) -> Cache:
            return DebugCacheImpl()

In this code you can see 2 factories providing same type ``Cache``.
The second one is used whenever ``Marker("debug")`` is treated as as active.
The base implementation will be used in all other cases as it has no condition set.
The overall rule is "last wins" like it worked with overriding.

Second step is to provide logic of marker activation. You write a function returning ``bool`` and register it in provider using ``@activate`` decorator.
It can be the same or another provider while you pass when creating a container.

.. code-block:: python

-    from dishka import activate, Provider
+    from dishka import Marker, Provider, activate

-    class MyProvider(Provider)
+    class MyProvider(Provider):
        @activate(Marker("debug"))
        def is_debug(self) -> bool:
            return False

This function can use other objects as well. For example, we can pass config using context

.. code-block:: python

-    class MyProvider(Provider)
+    class MyProvider(Provider):
        config = from_context(Config, scope=Scope.APP)

        @activate(Marker("debug"))
        def is_debug(self, config: Config) -> bool:
            return config.debug

Activation on marker type
--------------------------------

More general pattern is to create own marker type and register a single activator on all instances. You can request marker as an activator parameter.

.. code-block::

    class EnvMarker(Marker):
        pass

-    class MyProvider(Provider)
+    class MyProvider(Provider):
        config = from_context(Config, scope=Scope.APP)

        @activate(EnvMarker)
        def is_debug(self, marker: EnvMarker, config: Config) -> bool:
            return config.environment == marker.value


Combining markers
------------------------------------------

Markers support simple combination logic when used in ``when=`` using ``|`` (or), ``&`` (and) and ``~`` (not) operators

.. code-block:: python


        @provide(when=Marker("debug") | EnvMarker("preprod"))
        def debug_impl(self) -> Cache:
            return DebugCacheImpl()

        @provide(when=~Marker("debug") & EnvMarker("preprod"))
        def test_impl(self) -> Cache:
            return TestCacheImpl()


Provider-level activation
-------------------------

You can set ``when=`` on the entire provider to apply a condition to all factories, aliases, and decorators within it. This reduces boilerplate when all dependencies in a provider share the same activation condition.

.. code-block:: python

    from dishka import Marker, Provider, Scope, provide

    class DebugProvider(Provider):
        when = Marker("debug")
        scope = Scope.APP

        @provide
        def debug_cache(self) -> Cache:
            return DebugCacheImpl()

        @provide
        def debug_logger(self) -> Logger:
            return VerboseLogger()

The provider's ``when`` can also be set via constructor:

.. code-block:: python

    provider = DebugProvider(when=Marker("debug"))

When both provider and individual source have ``when=``, conditions are combined with AND logic:

.. code-block:: python

    class FeatureProvider(Provider):
        when = Marker("prod")  # prerequisite
        scope = Scope.APP

        @provide(when=Has(RedisConfig))  # additional condition
        def redis_cache(self, config: RedisConfig) -> Cache:
            return RedisCache(config)
        # Effective: Marker("prod") & Has(RedisConfig)

The provider's ``when`` acts as a prerequisite; individual sources add further constraints. If a factory shouldn't inherit the provider's condition, move it to a different provider.

Checking graph elements
---------------------------------------

In case you want to activate some features when specific objects are available you can use ``Has`` marker. It checks whether

* requested class is registered in container with appropriate scope
* it is activated
* if it actually presents in context while being registered as ``from_context``

``Has(T)`` implicitly registers ``T`` for graph validation.

Use ``from_context(T, ...)`` when ``Has(T)`` should become true only after a real
context value is passed.

The implicit registration only helps validation. ``Has(T)`` still stays false
until some real provider or real context value is available.


For example:

.. code-block:: python

-    from dishka import Provider, from_context, provide, Scope
+    from dishka import Has, Provider, Scope, from_context, make_container, provide

-    class MyProvider(Provider)
+    class MyProvider(Provider):
        config = from_context(RedisConfig, scope=Scope.APP)

        @provide(scope=Scope.APP)
        def base_impl(self) -> Cache:
            return NormalCacheImpl()

        @provide(when=Has(RedisConfig), scope=Scope.APP)
        def redis_impl(self, config: RedisConfig) -> Cache:
            return RedisCache(config)

        @provide(when=Has(MemcachedConfig), scope=Scope.APP)
        def memcached_impl(self, config: MemcachedConfig) -> Cache:
            return MemcachedCache(config)


    container = make_container(MyProvider, context={})


In this case,

* ``memcached_impl`` is not used because no real factory for ``MemcachedConfig`` is provided
* ``redis_impl`` is not used while it is registered as ``from_context`` but no real value is provided.
* ``base_impl`` is used as a default one, because none of later is active


class MyProvider(Provider)
config = from_context(RedisConfig, scope=Scope.APP)
Expand All @@ -182,6 +190,6 @@ For example:

In this case,

* ``memcached_impl`` is not used because no factory for ``MemcachedConfig`` is provided
* ``memcached_impl`` is not used because no real factory for ``MemcachedConfig`` is provided
* ``redis_impl`` is not used while it is registered as ``from_context`` but no real value is provided.
* ``base_impl`` is used as a default one, because none of later is active
* ``base_impl`` is used as a default one, because none of later is active
45 changes: 43 additions & 2 deletions src/dishka/graph_builder/builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import itertools
from collections import defaultdict
from collections.abc import Collection, Sequence
from collections.abc import Collection, Iterator, Sequence
from typing import cast

from dishka.dependency_source import (
Expand All @@ -14,7 +14,7 @@
from dishka.entities.component import Component
from dishka.entities.factory_type import FactoryData, FactoryType
from dishka.entities.key import DependencyKey
from dishka.entities.marker import unpack_marker
from dishka.entities.marker import BoolMarker, Has, Marker, unpack_marker
from dishka.entities.scope import BaseScope, InvalidScopes
from dishka.entities.validation_settings import ValidationSettings
from dishka.exception_base import InvalidMarkerError
Expand Down Expand Up @@ -537,11 +537,52 @@ def _fix_missing_scopes(
requester_scope=root_scope,
)

def _iter_factory_markers(self, factory: Factory) -> Iterator[Marker]:
yield from unpack_marker(factory.when_active)
yield from unpack_marker(factory.when_override)
for subfactory in factory.when_dependencies:
yield from self._iter_factory_markers(subfactory)

def _make_implicit_has_factory(
self,
key: DependencyKey,
) -> Factory:
return Factory(
scope=next(iter(self.scopes)),
source=key.type_hint,
provides=key,
is_to_bind=False,
dependencies=[],
kw_dependencies={},
type_=FactoryType.CONTEXT,
cache=False,
when_override=None,
when_active=BoolMarker(False),
when_component=key.component,
when_dependencies=[],
)

def _add_implicit_has_factories(
self,
factories: dict[DependencyKey, Factory],
) -> None:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe something like:

    def _make_implicit_has_factory(
        self,
        key: DependencyKey,
        root_scope: BaseScope,
    ) -> Factory:
        return Factory(
            scope=root_scope,
            source=key.type_hint,
            provides=key,
            is_to_bind=False,
            dependencies=[],
            kw_dependencies={},
            type_=FactoryType.CONTEXT,
            cache=False,
            when_override=None,
            when_active=BoolMarker(False),
            when_component=key.component,
            when_dependencies=[],
        )

    def _add_implicit_has_factories(
        self,
        factories: dict[DependencyKey, Factory],
    ) -> None:
        root_scope = next(iter(self.scopes))
        seen = set(factories)
        missing_keys: list[DependencyKey] = []
        stack = list(factories.values())

        while stack:
            factory = stack.pop()
            stack.extend(factory.when_dependencies)

            for marker in unpack_marker(factory.when_active):
                if isinstance(marker, Has):
                    key = DependencyKey(marker.value, factory.when_component)
                    if key not in seen:
                        seen.add(key)
                        missing_keys.append(key)

            for marker in unpack_marker(factory.when_override):
                if isinstance(marker, Has):
                    key = DependencyKey(marker.value, factory.when_component)
                    if key not in seen:
                        seen.add(key)
                        missing_keys.append(key)

        factories.update({
            key: self._make_implicit_has_factory(key, root_scope)
            for key in missing_keys
        })

_iter_factory_markers create extra function calls and yield from frames for each node in tree (while and stack.pop() cheaper)

Copy link
Copy Markdown
Author

@Stefanqn Stefanqn Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the review. Please take a look at my latest commit.

I updated src/dishka/graph_builder/builder.py to use the iterative stack traversal and pass root_scope explicitly into _make_implicit_has_factory(...), following your suggestion.

missing: dict[DependencyKey, Factory] = {}
for factory in tuple(factories.values()):
for marker in self._iter_factory_markers(factory):
if not isinstance(marker, Has):
continue
key = DependencyKey(marker.value, factory.when_component)
if key in factories or key in missing:
continue
missing[key] = self._make_implicit_has_factory(key)
factories.update(missing)

def build(self) -> Sequence[Registry]:
self._check_markers()
factories: dict[DependencyKey, Factory] = {
f.provides: f for f in self._collect_prepared_factories()
}
self._add_implicit_has_factories(factories)
self._fix_missing_scopes(factories)
fixed_factories = list(factories.values())

Expand Down
103 changes: 102 additions & 1 deletion tests/unit/container/when/test_has.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
from typing import Any

import pytest

from dishka import Has, Provider, Scope, make_async_container, make_container
from dishka import (
Has,
Provider,
Scope,
make_async_container,
make_container,
provide,
)
from dishka.exceptions import NoActiveFactoryError


@pytest.mark.parametrize(("register", "value"), [
Expand Down Expand Up @@ -48,3 +58,94 @@ def test_has_chained(*, register: bool, value: str):

c = make_container(provider)
assert c.get(str) == value


@pytest.mark.parametrize(
("enable_conditional_provider", "successful"), [
(False, False),
(True, True),
],
)
def test_has_with_declared_context_dependency(
*, enable_conditional_provider: bool, successful: bool,
):
class StringProvider(Provider):
@provide(when=Has(int), scope=Scope.APP)
def setup(self, cfg: int) -> str:
return "ok"

class IntProvider(Provider):
int_config_instance = provide(
source=lambda self: 42, provides=int, scope=Scope.APP,
)

providers: list[Any] = [StringProvider()]
if enable_conditional_provider:
providers.append(IntProvider())

container = make_container(*providers, context={})

if successful:
assert isinstance(container.get(str), str)
else:
with pytest.raises(NoActiveFactoryError):
container.get(str)


@pytest.mark.asyncio
@pytest.mark.parametrize(
("is_async", "register_int", "value"),
[
(False, True, "b"),
(False, False, "a"),
(True, True, "b"),
(True, False, "a"),
],
)
async def test_provider_declare_method_does_not_make_has_active(
*, is_async: bool, register_int: bool, value: str,
):
provider = Provider(scope=Scope.APP)
provider.provide(lambda: "a", provides=str)
provider.provide(lambda: "b", provides=str, when=Has(int))

provider2 = Provider(scope=Scope.APP)
if register_int:
provider2.provide(lambda: 42, provides=int)

if is_async:
container = make_async_container(provider, provider2, context={})
assert await container.get(str) == value
else:
container = make_container(provider, provider2, context={})
assert container.get(str) == value


@pytest.mark.asyncio
@pytest.mark.parametrize(
("is_async", "register_ctx", "value"),
[
(False, True, "b"),
(False, False, "a"),
(True, True, "b"),
(True, False, "a"),
],
)
async def test_from_context_requires_real_context_value_for_has(
*, is_async: bool, register_ctx: bool, value: str,
):
provider = Provider(scope=Scope.APP)
provider.from_context(int)
provider.provide(lambda: "a", provides=str)
provider.provide(lambda: "b", provides=str, when=Has(int))
if register_ctx:
ctx = {int: 42}
else:
ctx = {}

if is_async:
container = make_async_container(provider, context=ctx)
assert await container.get(str) == value
else:
container = make_container(provider, context=ctx)
assert container.get(str) == value