Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 33 additions & 8 deletions docs/advanced/when.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ This can be achieved with "activation" approach. Key concepts here:
* **Activator** or **activation function** - special function registered in provider and taking decision if marker is active or not.
* **activation condition** - expression with marker objects set in dependency source dynamically associated with activators to select between multiple implementations or enable decorators

Activators can be called preliminary or multiple times, so avoid acquiring resources or doing heavy calculations, if necessary, move such things into factories or context data.

.. note::

Expand All @@ -28,9 +29,9 @@ To set conditional activation you create special ``Marker`` objects and use them

.. code-block:: python

from dishka import Provider, provide, Scope
from dishka import Marker, Provider, provide, Scope

class MyProvider(Provider)
class MyProvider(Provider):
@provide(scope=Scope.APP)
def base_impl(self) -> Cache:
return NormalCacheImpl()
Expand All @@ -49,9 +50,9 @@ It can be the same or another provider while you pass when creating a container.

.. code-block:: python

from dishka import activate, Provider
from dishka import activate, Marker, Provider

class MyProvider(Provider)
class MyProvider(Provider):
@activate(Marker("debug"))
def is_debug(self) -> bool:
return False
Expand All @@ -60,7 +61,7 @@ This function can use other objects as well. For example, we can pass config usi

.. code-block:: python

class MyProvider(Provider)
class MyProvider(Provider):
config = from_context(Config, scope=Scope.APP)

@activate(Marker("debug"))
Expand All @@ -78,7 +79,7 @@ More general pattern is to create own marker type and register a single activato
pass


class MyProvider(Provider)
class MyProvider(Provider):
config = from_context(Config, scope=Scope.APP)

@activate(EnvMarker)
Expand Down Expand Up @@ -161,7 +162,7 @@ For example:

from dishka import Provider, provide, Scope

class MyProvider(Provider)
class MyProvider(Provider):
config = from_context(RedisConfig, scope=Scope.APP)

@provide(scope=Scope.APP)
Expand All @@ -184,4 +185,28 @@ In this case,

* ``memcached_impl`` is not used because no factory for ``MemcachedConfig`` is provided
* ``redis_impl`` is not used while it is registered as ``from_context`` but no real value is provided.
* ``base_impl`` is used as a default one, because none of later is active
* ``base_impl`` is used as a default one, because none of later is active


Preliminary (static) evaluation and graph validation
------------------------------------------------------------

In certain cases activator can be called during graph building step, this allows avoid unnecessary calls in runtime and ignore errors on factories which are never called.

Static evaluation is enabled only if activator a sync non-generator function with dependencies retrieved from root context or without dependencies at all.
For example, in the following code ``redis_impl`` is never called because ``RedisConfig`` is not passed, so it won't be validated at all.


.. code-block:: python

from dishka import Provider, provide, Scope

class MyProvider(Provider):
config = from_context(RedisConfig, scope=Scope.APP)

@provide(when=Has(RedisConfig), scope=Scope.APP)
def redis_impl(self, config: RedisConfig) -> Cache:
return RedisCache(config)

container = make_container(MyProvider, context={})

25 changes: 25 additions & 0 deletions src/dishka/async_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def _get_sync(self, key: CompilationKey) -> Any:
self._cache,
self._context,
self,
self._has_sync,
)

async def _get(self, key: CompilationKey) -> Any:
Expand Down Expand Up @@ -296,6 +297,7 @@ async def _get_unlocked(self, key: CompilationKey) -> Any:
self._cache,
self._context,
self,
self._has,
)

async def close(self, exception: BaseException | None = None) -> None:
Expand Down Expand Up @@ -350,13 +352,34 @@ async def _has(self, marker: CompilationKey) -> bool:
self._cache,
self._context,
self,
self._has,
))

def _has_sync(self, marker: CompilationKey) -> bool:
compiled = self.registry.get_compiled_activation(marker)
if not compiled:
if not self.parent_container:
return False
return self.parent_container._has_sync(marker) # noqa: SLF001

return bool(compiled(
self._get_sync,
self._exits,
self._cache,
self._context,
self,
self._has_sync,
))

def _has_context(self, marker: Any) -> bool:
return self._context is not None and marker in self._context


class HasProvider(Provider):
"""
This provider is used only for direct access on Has/HasContext.
Basic implementation is inlined in code builder.
"""
@activate(Has)
async def has(
self,
Expand Down Expand Up @@ -392,9 +415,11 @@ def make_async_container(
has_provider = HasProvider()
builder = GraphBuilder(
scopes=scopes,
start_scope=start_scope,
container_key=CONTAINER_KEY,
skip_validation=skip_validation,
validation_settings=validation_settings,
root_context=context or {},
)
builder.add_multicomponent_providers(has_provider)
builder.add_providers(*providers)
Expand Down
15 changes: 9 additions & 6 deletions src/dishka/code_tools/code_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,10 @@ def call(self, func: str, *args: str, **kwargs: str) -> str:
args_list.extend(f"{name}={value}" for name, value in kwargs.items())

if len(args_list) > MAX_ITEMS_PER_LINE:
args_str = ",\n".join(args_list)
sep = ",\n " + self.indent_str + " "*8
else:
args_str = ", ".join(args_list)
sep = ", "
args_str = sep.join(args_list)
return f"{func}({args_str})"

def await_(self, expr: str) -> str:
Expand Down Expand Up @@ -239,16 +240,18 @@ def for_(self, name: str, expr: str) -> Iterator[None]:

def list_literal(self, *items: str) -> str:
if len(items) > MAX_ITEMS_PER_LINE:
items_str = "\n, ".join(items)
sep = ",\n " + self.indent_str + " "*8
else:
items_str = ", ".join(items)
sep = ", "
items_str = sep.join(items)
return f"[{items_str}]"

def tuple_literal(self, *items: str) -> str:
if len(items) > MAX_ITEMS_PER_LINE:
items_str = "\n, ".join(items)
sep = ",\n " + self.indent_str + " "*8
else:
items_str = ", ".join(items)
sep = ", "
items_str = sep.join(items)
return f"({items_str})"

def compile(self, source_file_name: str) -> dict[str, Any]:
Expand Down
19 changes: 16 additions & 3 deletions src/dishka/code_tools/factory_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
AndMarker,
BaseMarker,
BoolMarker,
Has,
HasContext,
NotMarker,
OrMarker,
)
Expand Down Expand Up @@ -63,7 +65,7 @@ def make_getter(self) -> AbstractContextManager[None]:
self.getter_name = self.getter_prefix + raw_provides_name
return self.def_(
self.getter_name,
["getter", "exits", "cache", "context", "container"],
["getter", "exits", "cache", "context", "container", "has"],
)

def getter(
Expand All @@ -82,7 +84,7 @@ def getter(
return self.await_(
self.call(
factory,
"getter", "exits", "cache", "context", "container",
"getter", "exits", "cache", "context", "container", "has",
),
)
return self.await_(self.call(
Expand All @@ -101,7 +103,10 @@ def return_if_cached(self, factory: Factory) -> None:
def assign_solved(self, expr: str) -> None:
self.assign_local("solved", expr)

def when(
def _has_context(self, type_: str) -> str:
return f"(context is not None and {type_} in context)"

def when( # noqa: PLR0911
self,
marker: BaseMarker | None,
component: Component | None,
Expand All @@ -126,6 +131,14 @@ def when(
)
case BoolMarker(False):
return self.global_(marker.value)
case Has():
key = DependencyKey(marker.value, component)
return self.await_(self.call(
"has",
self.global_(key.as_compilation_key()),
))
case HasContext():
return self._has_context(self.global_(marker.value))
case _:
if component is None:
raise TypeError( # noqa: TRY003
Expand Down
8 changes: 8 additions & 0 deletions src/dishka/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def _get_unlocked(self, key: CompilationKey) -> Any:
self._cache,
self._context,
self,
self._has,
)

def close(self, exception: BaseException | None = None) -> None:
Expand Down Expand Up @@ -278,13 +279,18 @@ def _has(self, marker: CompilationKey) -> bool:
self._cache,
self._context,
self,
self._has,
))

def _has_context(self, marker: Any) -> bool:
return self._context is not None and marker in self._context


class HasProvider(Provider):
"""
This provider is used only for direct access on Has/HasContext.
Basic implementation is inlined in code builder.
"""
@activate(Has)
def has(
self,
Expand Down Expand Up @@ -317,7 +323,9 @@ def make_container(
context_provider = make_root_context_provider(providers, context, scopes)
has_provider = HasProvider()
builder = GraphBuilder(
root_context=context or {},
scopes=scopes,
start_scope=start_scope,
container_key=CONTAINER_KEY,
skip_validation=skip_validation,
validation_settings=validation_settings,
Expand Down
3 changes: 2 additions & 1 deletion src/dishka/container_objects.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from collections.abc import AsyncGenerator, Callable, Generator
from collections.abc import AsyncGenerator, Awaitable, Callable, Generator
from typing import Any, Protocol, TypeAlias

from dishka.entities.key import CompilationKey
Expand All @@ -19,5 +19,6 @@ def __call__(
cache: Any,
context: Any,
container: Any,
has: Callable[[CompilationKey], bool | Awaitable[bool]],
) -> Any:
raise NotImplementedError
10 changes: 10 additions & 0 deletions src/dishka/dependency_source/activator.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
from typing import Any

from dishka.entities.component import Component
from dishka.entities.factory_type import FactoryData
from dishka.entities.key import DependencyKey, const_dependency_key
from dishka.entities.marker import Marker
from dishka.entities.scope import BaseScope
from .factory import Factory


class StaticEvaluationUnavailable(Exception): # noqa: N818
def __init__(self, factory: FactoryData) -> None:
self.factory = factory

def __str__(self) -> str:
return (f"StaticEvaluationUnavailable({self.factory.provides},"
f" type={self.factory.type})")


class Activator:
__slots__ = ("factory", "marker", "marker_type")

Expand Down
9 changes: 6 additions & 3 deletions src/dishka/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,14 @@ class NoContextValueError(DishkaError):


class UnsupportedFactoryError(DishkaError):
def __init__(self, factory_type: FactoryData) -> None:
self.factory_type = factory_type
def __init__(self, factory_data: FactoryData) -> None:
self.factory_data = factory_data

def __str__(self) -> str:
return f"Unsupported factory type {self.factory_type}."
name = get_source_name(self.factory_data)
return (
f"Unsupported factory type {self.factory_data.type} at {name}"
)


class InvalidGraphError(DishkaError):
Expand Down
Loading
Loading