diff --git a/docs/api.md b/docs/api.md index fc9922a31..f6328e328 100644 --- a/docs/api.md +++ b/docs/api.md @@ -86,6 +86,19 @@ import squidpy as sq tl.var_by_distance ``` +## Settings + +```{eval-rst} +.. currentmodule:: squidpy._backends._settings + +.. autoclass:: _Settings + :members: backend, use_backend, available_backends, get_backend + +.. data:: squidpy.settings + + The global settings instance. See :class:`~squidpy._backends._settings._Settings`. +``` + ## Datasets ```{eval-rst} .. module:: squidpy.datasets diff --git a/pyproject.toml b/pyproject.toml index 1e264c809..83a0ec3d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,12 @@ dependencies = [ "xarray>=2024.10", "zarr>=3", ] +optional-dependencies.cu12 = [ + "rapids-singlecell-cu12[rapids]", +] +optional-dependencies.cu13 = [ + "rapids-singlecell-cu13[rapids]", +] optional-dependencies.dev = [ "hatch>=1.9", "ipykernel", diff --git a/src/squidpy/__init__.py b/src/squidpy/__init__.py index 85b250d82..e0fb8b502 100644 --- a/src/squidpy/__init__.py +++ b/src/squidpy/__init__.py @@ -4,6 +4,7 @@ from importlib.metadata import PackageMetadata from squidpy import datasets, experimental, gr, im, pl, read, tl +from squidpy._backends import settings try: md: PackageMetadata = metadata.metadata(__name__) @@ -15,4 +16,4 @@ del metadata, md -__all__ = ["datasets", "experimental", "gr", "im", "pl", "read", "tl"] +__all__ = ["datasets", "experimental", "gr", "im", "pl", "read", "settings", "tl"] diff --git a/src/squidpy/_backends/__init__.py b/src/squidpy/_backends/__init__.py new file mode 100644 index 000000000..ccb817aaf --- /dev/null +++ b/src/squidpy/_backends/__init__.py @@ -0,0 +1,9 @@ +"""Pluggable backend dispatch system for squidpy.""" + +from __future__ import annotations + +from squidpy._backends._dispatch import dispatch +from squidpy._backends._registry import available_backend_names, get_backend +from squidpy._backends._settings import settings + +__all__ = ["dispatch", "get_backend", 
def _get_param_sets(
    func: Callable,
    adapter_method: Callable,
    func_name: str,
    backend_name: str,
) -> tuple[set, set, set, dict]:
    """Split host and adapter parameters into shared / cpu-only / gpu-only sets.

    Results are memoised in ``_sig_cache`` per host function and backend.

    Parameters
    ----------
    func
        The host (CPU) implementation.
    adapter_method
        The backend callable implementing the same operation.
    func_name
        Unused; kept so existing callers do not need to change.
    backend_name
        Name of the backend, used as the second half of the cache key.

    Returns
    -------
    ``(shared, cpu_only, gpu_only, host_defaults)``. The first three are sets
    of parameter names; ``host_defaults`` maps every cpu-only parameter that
    declares a default to that default value.
    """
    # Qualify with the module: ``__qualname__`` alone is identical for
    # same-named functions defined in different modules, and a collision would
    # hand one function the parameter sets computed for another.
    key = (f"{getattr(func, '__module__', None)}.{func.__qualname__}", backend_name)
    cached = _sig_cache.get(key)
    if cached is not None:
        return cached

    host_sig = inspect.signature(func)
    adapter_sig = inspect.signature(adapter_method)

    # ``backend`` is the dispatch kwarg itself and is never forwarded.
    not_routable = {"self", "args", "kwargs", "backend"}
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    def _routable_names(sig: inspect.Signature) -> set[str]:
        # Drop catch-all parameters by kind as well as by their conventional
        # names, so ``*extra`` / ``**options`` are not treated as real params.
        return {
            name
            for name, param in sig.parameters.items()
            if name not in not_routable and param.kind not in variadic
        }

    host_params = _routable_names(host_sig)
    adapter_params = _routable_names(adapter_sig)

    shared = host_params & adapter_params
    cpu_only = host_params - adapter_params
    gpu_only = adapter_params - host_params

    # Defaults of cpu-only params let the caller tell "left at default" apart
    # from "explicitly set to something the backend will ignore".
    host_defaults = {
        name: param.default
        for name, param in host_sig.parameters.items()
        if name in cpu_only and param.default is not inspect.Parameter.empty
    }

    result = (shared, cpu_only, gpu_only, host_defaults)
    _sig_cache[key] = result
    return result
def _extract_param_docs(docstring: str | None, param_names: set[str]) -> dict[str, str]:
    """Pull the numpydoc entries for selected parameters out of a docstring.

    An entry begins on a line indented exactly as deep as the ``Parameters``
    section's parameter names and runs over every following line that is blank
    or indented deeper. Lines at any other indentation are ignored.

    Parameters
    ----------
    docstring
        The numpydoc docstring to scan; ``None`` or empty yields ``{}``.
    param_names
        Names of the parameters whose entries should be returned.

    Returns
    -------
    Mapping of parameter name to its entry with the section's base
    indentation removed. Names that have no entry are simply absent.
    """
    if not docstring or not param_names:
        return {}

    doc_lines = docstring.split("\n")
    located = _find_section(doc_lines, "Parameters")
    if located is None:
        return {}
    first = located[1]
    total = len(doc_lines)

    def _indent_of(text: str) -> int:
        return len(text) - len(text.lstrip())

    def _starts_section(idx: int) -> bool:
        # A section header is a known title followed by a dashed underline.
        return (
            doc_lines[idx].strip() in _NUMPYDOC_SECTIONS
            and idx + 1 < total
            and doc_lines[idx + 1].strip().startswith("---")
        )

    # The Parameters block stops at the next section header, or at the end.
    last = next((idx for idx in range(first, total) if _starts_section(idx)), total)
    width = len(_detect_indent(doc_lines, first, last))

    extracted: dict[str, str] = {}
    cursor = first
    while cursor < last:
        header = doc_lines[cursor]
        # Only non-blank lines sitting exactly at the base indent open an entry.
        if not header.strip() or _indent_of(header) != width:
            cursor += 1
            continue

        # First word is the name; tolerate "a, b" lists and *args / **kwargs.
        name = header.split()[0].rstrip(",").lstrip("*")

        # The entry extends until the next non-blank line that is not indented deeper.
        stop = cursor + 1
        while stop < last:
            candidate = doc_lines[stop]
            if candidate.strip() and _indent_of(candidate) <= width:
                break
            stop += 1

        entry = doc_lines[cursor:stop]
        while entry and not entry[-1].strip():
            entry.pop()

        if name in param_names:
            extracted[name] = "\n".join(
                (text[width:] if len(text) >= width else text) if text.strip() else "" for text in entry
            )

        cursor = stop

    return extracted
+ """ + if not docstring: + return docstring or "" + + lines = docstring.split("\n") + section = _find_section(lines, "Parameters") + if section is None: + return docstring + + _, content_start = section + + # Find insertion point: just before the next section header + insert_idx = len(lines) + for i in range(content_start, len(lines)): + stripped = lines[i].strip() + if stripped in _NUMPYDOC_SECTIONS: + if i + 1 < len(lines) and lines[i + 1].strip().startswith("---"): + insert_idx = i + break + + indent = _detect_indent(lines, content_start, insert_idx) + body_indent = indent + " " + + # Build the extra parameter lines + extra_lines: list[str] = [] + for doc_block in extra_docs.values(): + for doc_line in doc_block.split("\n"): + if doc_line.strip(): + extra_lines.append(indent + doc_line) + else: + extra_lines.append("") + + # Always add the backend parameter doc + extra_lines.append(f"{indent}backend") + extra_lines.append(f"{body_indent}Backend to use. Use ``'cpu'`` for the default implementation or a") + extra_lines.append(f"{body_indent}registered backend name (e.g. ``'gpu'``). 
def update_signatures() -> None:
    """Merge backend-only parameters into dispatched function signatures and docstrings.

    For every function registered in ``_dispatched_functions`` this collects
    the parameters that registered backends accept but the host function does
    not, appends them as keyword-only parameters, makes sure a ``backend``
    parameter is present, and injects the matching parameter documentation.
    The result is written to the dispatch wrapper and, when an outer decorator
    wraps it, to that public function as well.
    """
    from squidpy._backends._registry import _backends

    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    for wrapper in _dispatched_functions:
        func = wrapper.__wrapped__
        func_name = func.__name__
        host_sig = inspect.signature(func)
        # Remember whether the host already declares ``backend`` before the
        # name is added to the exclusion set below.
        host_has_backend = "backend" in host_sig.parameters
        host_param_names = set(host_sig.parameters) | {"backend"}

        # Collect backend-only params and their docs from all backends
        gpu_params: list[inspect.Parameter] = []
        gpu_param_names: set[str] = set()
        adapter_docs: dict[str, str] = {}
        for backend in _backends.values():
            try:
                method = getattr(backend, func_name, None)
            except Exception:  # noqa: BLE001
                # Attribute lookup on a backend raised; skip that backend.
                continue
            if method is None:
                continue
            try:
                adapter_sig = inspect.signature(method)
            except (ValueError, TypeError):
                # Not introspectable; nothing to merge.
                continue

            new_names: set[str] = set()
            for name, param in adapter_sig.parameters.items():
                if name in host_param_names or name in gpu_param_names or name in {"self", "args", "kwargs"}:
                    continue
                # A ``*extra`` / ``**options`` catch-all is not a real
                # parameter and must not be advertised as keyword-only.
                if param.kind in variadic:
                    continue
                gpu_params.append(param.replace(kind=inspect.Parameter.KEYWORD_ONLY))
                gpu_param_names.add(name)
                new_names.add(name)

            # Extract docstrings for these new params from the adapter's docstring
            if new_names:
                adapter_docs.update(_extract_param_docs(method.__doc__, new_names))

        # --- Update signature ---
        # ``**kwargs`` has to stay last, so set it aside and re-append it.
        params: list[inspect.Parameter] = []
        var_kw = None
        for p in host_sig.parameters.values():
            if p.kind == inspect.Parameter.VAR_KEYWORD:
                var_kw = p
            else:
                params.append(p)

        params.extend(gpu_params)
        # Mirror ``_build_signature``: adding a second ``backend`` parameter
        # would make ``Signature.replace`` raise on the duplicate name.
        if not host_has_backend:
            params.append(
                inspect.Parameter("backend", inspect.Parameter.KEYWORD_ONLY, default="cpu", annotation=str)
            )
        if var_kw is not None:
            params.append(var_kw)

        merged_sig = host_sig.replace(parameters=params)
        merged_doc = _inject_param_docs(wrapper.__doc__, adapter_docs)

        # Update the dispatch wrapper
        wrapper.__signature__ = merged_sig
        wrapper.__doc__ = merged_doc

        # If an outer decorator (e.g. deprecated_params) copied __signature__
        # and __doc__ via functools.wraps, update it too.
        public_func = _find_public_func(wrapper)
        if public_func is not wrapper:
            public_func.__signature__ = merged_sig
            public_func.__doc__ = merged_doc
Install it or set squidpy.settings.backend = 'cpu'." + ) + + method = getattr(backend, func_name, None) + if method is None: + # Backend doesn't implement this function — fall back to CPU + return func(*args, **kwargs) + + shared, cpu_only, gpu_only, host_defaults = _get_param_sets(func, method, func_name, backend.name) + + # Route kwargs + adapter_kwargs: dict[str, Any] = {} + for key, value in kwargs.items(): + if key in shared or key in gpu_only: + adapter_kwargs[key] = value + elif key in cpu_only: + # Warn if non-default + if key not in host_defaults or value != host_defaults[key]: + warnings.warn( + f"{key!r} has no effect on backend {effective!r}.", + stacklevel=2, + ) + else: + # Unknown kwarg — let the adapter deal with it + adapter_kwargs[key] = value + + return method(*args, **adapter_kwargs) + + # Initial signature with just `backend` (before backends are discovered) + _build_signature(wrapper) + _dispatched_functions.append(wrapper) + + return wrapper # type: ignore[return-value] diff --git a/src/squidpy/_backends/_registry.py b/src/squidpy/_backends/_registry.py new file mode 100644 index 000000000..af95117d7 --- /dev/null +++ b/src/squidpy/_backends/_registry.py @@ -0,0 +1,130 @@ +"""Backend discovery via Python entrypoints.""" + +from __future__ import annotations + +import importlib.metadata +import logging +import warnings +from difflib import get_close_matches +from typing import Any + +logger = logging.getLogger(__name__) + +_backends: dict[str, Any] = {} # canonical_name -> instance +_alias_map: dict[str, str] = {} # alias -> canonical_name +_discovered = False + +# Trusted (verified) backends and their known aliases. +# Backends not in this list still work but emit a one-time warning on first use. +# To become trusted, submit a PR adding your backend here and pass the +# conformance test suite (squidpy.testing.backend_conformance). 
def _ensure_discovered() -> None:
    """Load every ``squidpy.backends`` entry point the first time it is needed.

    Each entry point is loaded, instantiated and stored under the instance's
    ``name``; its ``aliases`` (if any) are recorded in the alias map. An alias
    already owned by a different backend stays with its first owner and a
    warning is emitted. An entry point that fails to load is logged at debug
    level and skipped. If at least one backend was registered, the dispatched
    function signatures are refreshed afterwards.
    """
    global _discovered
    if _discovered:
        return
    # Set the flag before doing any work: every later call returns early above.
    _discovered = True

    for ep in importlib.metadata.entry_points(group="squidpy.backends"):
        try:
            instance = ep.load()()
            canonical = instance.name
            _backends[canonical] = instance

            # The canonical name always resolves to itself.
            _alias_map[canonical] = canonical
            for alias in getattr(instance, "aliases", []):
                if _alias_map.get(alias, canonical) == canonical:
                    _alias_map[alias] = canonical
                    continue
                # First claimant keeps the alias.
                owner = _alias_map[alias]
                warnings.warn(
                    f"Backend alias {alias!r} claimed by both {owner!r} and {canonical!r}. Using {owner!r}.",
                    stacklevel=2,
                )
        except Exception:  # noqa: BLE001
            logger.debug("Failed to load backend entrypoint %r", ep.name, exc_info=True)

    if not _backends:
        return

    # Imported here, not at module top, because _dispatch imports this module.
    from squidpy._backends._dispatch import update_signatures

    update_signatures()
def _suggest_backend(name: str) -> str:
    """Compose the "unknown backend" error text for ``name``.

    The message names the closest match among all registered and trusted
    names/aliases (when one is similar enough) and then lists what is
    currently available, or states that nothing is installed.
    """
    _ensure_discovered()

    candidates = sorted(_alias_map.keys() | _TRUSTED_ALIASES.keys())
    closest = get_close_matches(name, candidates, n=1, cutoff=0.4)

    parts = [f"Unknown backend {name!r}."]
    if closest:
        parts.append(f" Did you mean {closest[0]!r}?")

    installed = available_backend_names()
    parts.append(f" Available: {installed}." if installed else " No backends are currently installed.")
    return "".join(parts)
class _Settings:
    """Configuration object for squidpy's backend dispatch.

    One instance is created at module level and exposed as ``squidpy.settings``.
    The active backend name is kept in a :class:`contextvars.ContextVar`.

    Examples
    --------
    >>> import squidpy as sq
    >>> sq.settings.backend = "gpu"
    >>> sq.settings.available_backends()
    ['rapids_singlecell']
    """

    @property
    def backend(self) -> str:
        """Name of the backend currently in effect (``'cpu'`` unless changed).

        Assigning a registered backend name or one of its aliases
        (e.g. ``'gpu'``, ``'cuda'``) routes supported functions to that
        backend. Whatever alias is assigned, the canonical name is stored.

        Examples
        --------
        >>> sq.settings.backend = "gpu"
        >>> sq.settings.backend
        'rapids_singlecell'
        """
        return _backend_var.get()

    @backend.setter
    def backend(self, value: str) -> None:
        from squidpy._backends._registry import (
            TRUSTED_BACKENDS,
            _suggest_backend,
            check_trusted,
            get_backend,
            resolve_backend_name,
        )

        if value != "cpu":
            canonical = resolve_backend_name(value)

            # Name is neither registered nor a trusted alias.
            if canonical is None:
                raise ValueError(_suggest_backend(value))

            if get_backend(canonical) is None:
                # Trusted backends get an install hint; others a plain message.
                if canonical in TRUSTED_BACKENDS:
                    package = TRUSTED_BACKENDS[canonical]["package"]
                    raise ImportError(
                        f"Backend {value!r} ({canonical}) is not installed. Install it with: pip install {package}"
                    )
                raise ImportError(f"Backend {value!r} is not installed.")

            # Warns when the backend is loaded but not on the trusted list.
            check_trusted(canonical)

            # Persist the canonical name rather than the alias.
            value = canonical

        _backend_var.set(value)

    @contextmanager
    def use_backend(self, backend: str) -> Generator[None, None, None]:
        """Switch to ``backend`` for the duration of a ``with`` block.

        The previously active backend is restored on exit, including when the
        requested backend is rejected or the block raises.

        Parameters
        ----------
        backend
            Name or alias of the backend to activate inside the block.

        Examples
        --------
        >>> with sq.settings.use_backend("gpu"):
        ...     sq.gr.spatial_autocorr(adata)
        """
        # Re-setting the current value changes nothing for readers but yields
        # a token that can restore this exact state afterwards.
        current = self.backend
        token = _backend_var.set(current)
        try:
            self.backend = backend
            yield
        finally:
            _backend_var.reset(token)

    @staticmethod
    def available_backends() -> list[str]:
        """List the canonical names of all discovered backends, sorted.

        Examples
        --------
        >>> sq.settings.available_backends()
        ['rapids_singlecell']
        """
        from squidpy._backends._registry import _backends, _ensure_discovered

        _ensure_discovered()
        return sorted(_backends)

    @staticmethod
    def get_backend(name: str) -> Any | None:
        """Return the backend registered under ``name``.

        Parameters
        ----------
        name
            Canonical name or alias (e.g. ``'gpu'``, ``'cuda'``,
            ``'rapids_singlecell'``).

        Returns
        -------
        The backend instance, or ``None`` if not found.

        Examples
        --------
        >>> backend = sq.settings.get_backend("gpu")
        >>> backend.name
        'rapids_singlecell'
        """
        from squidpy._backends._registry import get_backend as _lookup

        return _lookup(name)
show_progress_bar Whether to show the progress bar or not.""" _channels = """\ diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index 242ec16f2..661494312 100644 --- a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -17,6 +17,7 @@ from scipy.sparse import csc_matrix from spatialdata import SpatialData +from squidpy._backends import dispatch from squidpy._constants._constants import ComplexPolicy, CorrAxis from squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs @@ -629,6 +630,7 @@ def prepare( @d.dedent +@dispatch def ligrec( adata: AnnData | SpatialData, cluster_key: str, diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 3163035db..17b437cdf 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -20,6 +20,7 @@ from spatialdata import SpatialData from statsmodels.stats.multitest import multipletests +from squidpy._backends import dispatch from squidpy._constants._constants import SpatialAutocorr from squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs @@ -45,6 +46,7 @@ @d.dedent @inject_docs(key=Key.obsp.spatial_conn(), sp=SpatialAutocorr) +@dispatch def spatial_autocorr( adata: AnnData | SpatialData, connectivity_key: str = Key.obsp.spatial_conn(), @@ -60,7 +62,6 @@ def spatial_autocorr( use_raw: bool = False, copy: bool = False, n_jobs: int | None = None, - backend: str = "loky", show_progress_bar: bool = True, ) -> pd.DataFrame | None: """ @@ -208,7 +209,7 @@ def extract_obsm(adata: AnnData, ixs: int | Sequence[int] | None) -> tuple[NDArr extractor=np.concatenate, use_ixs=True, n_jobs=n_jobs, - backend=backend, + backend="loky", show_progress_bar=show_progress_bar, )(mode=mode, g=g, vals=vals, seed=seed) else: @@ -341,7 +342,8 @@ def _co_occurrence_helper(v_x: NDArrayA, v_y: NDArrayA, v_radium: NDArrayA, labs @d.dedent -@deprecated_params({"n_splits": "1.10.0", "n_jobs": "1.10.0", "backend": "1.10.0", 
def validate_backend(
    backend_name: str,
    functions: Sequence[str] | None = None,
) -> dict[str, str]:
    """Run conformance tests against a backend.

    Each known function is run once on the CPU and once on the backend using
    the same synthetic dataset, and the results are compared. Functions the
    backend does not provide are reported as skipped, not failed.

    Parameters
    ----------
    backend_name
        Name or alias of the backend to test.
    functions
        Specific functions to test. None = test all known functions.

    Returns
    -------
    Dict mapping function name to result string (``"PASSED"`` or
    ``"SKIPPED (not implemented)"``).

    Raises
    ------
    AssertionError
        If the backend cannot be found or any test fails.
    """
    import squidpy as sq
    from squidpy._backends._registry import get_backend

    backend = get_backend(backend_name)
    if backend is None:
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; without it every function below would
        # be reported as skipped for a backend that does not exist.
        raise AssertionError(f"Backend {backend_name!r} not found")

    adata = _make_test_adata()

    # Build spatial graph (required for spatial_autocorr)
    sq.gr.spatial_neighbors(adata)

    all_tests = {
        "spatial_autocorr": _test_spatial_autocorr,
        "co_occurrence": _test_co_occurrence,
    }
    to_test = {k: v for k, v in all_tests.items() if functions is None or k in functions}

    results: dict[str, str] = {}
    for name, test_fn in to_test.items():
        if getattr(backend, name, None) is None:
            results[name] = "SKIPPED (not implemented)"
            continue
        # A failing test propagates its exception to the caller, so there is
        # no result to record for it here.
        test_fn(adata, backend_name)
        results[name] = "PASSED"

    return results
@pytest.fixture(autouse=True)
def _reset_state():
    """Give every test an empty registry with ``fake_gpu`` listed as trusted.

    Before the test: registries and the signature cache are emptied, discovery
    is marked as done, the backend is set to ``"cpu"`` and ``fake_gpu`` (with
    its aliases) is added to the trusted tables. After the test all of that is
    undone and the trusted tables are restored to their previous content.
    """
    fake_aliases = ["fake", "test-gpu"]

    def _wipe() -> None:
        for store in (_backends, _alias_map, _sig_cache):
            store.clear()

    _wipe()
    discovered_before = _registry._discovered
    # Pretend discovery already happened so no real entry points get loaded;
    # tests register their own fake backends explicitly.
    _registry._discovered = True
    settings.backend = "cpu"

    # Snapshot the trusted tables, then add the fake backend to them.
    trusted_before = dict(TRUSTED_BACKENDS)
    aliases_before = dict(_TRUSTED_ALIASES)
    TRUSTED_BACKENDS["fake_gpu"] = {
        "aliases": fake_aliases,
        "package": "fake-gpu-pkg",
    }
    for alias in ("fake_gpu", *fake_aliases):
        _TRUSTED_ALIASES[alias] = "fake_gpu"

    yield

    _wipe()
    _registry._discovered = discovered_before
    settings.backend = "cpu"
    TRUSTED_BACKENDS.clear()
    TRUSTED_BACKENDS.update(trusted_before)
    _TRUSTED_ALIASES.clear()
    _TRUSTED_ALIASES.update(aliases_before)
settings.backend == "fake_gpu" + + def test_untrusted_backend_warns(self): + # Remove fake_gpu from trusted list + TRUSTED_BACKENDS.pop("fake_gpu", None) + _TRUSTED_ALIASES.pop("fake_gpu", None) + _TRUSTED_ALIASES.pop("fake", None) + _TRUSTED_ALIASES.pop("test-gpu", None) + + _register_fake() + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + settings.backend = "fake" + assert len(w) == 1 + assert "not in squidpy's trusted backends list" in str(w[0].message) + + def test_available_backends_empty(self): + assert settings.available_backends() == [] + + def test_available_backends_with_registered(self): + _register_fake() + assert "fake_gpu" in settings.available_backends() + + def test_get_backend_returns_instance(self): + backend = _register_fake() + result = settings.get_backend("fake_gpu") + assert result is backend + + def test_get_backend_by_alias(self): + backend = _register_fake() + assert settings.get_backend("fake") is backend + assert settings.get_backend("test-gpu") is backend + + def test_get_backend_unknown_returns_none(self): + assert settings.get_backend("nonexistent") is None + + def test_get_backend_cpu_returns_none(self): + assert settings.get_backend("cpu") is None + + +class TestDispatch: + def test_cpu_path(self): + @dispatch + def my_func(x, n_jobs=None): + return f"cpu:{x}:{n_jobs}" + + assert my_func(42) == "cpu:42:None" + assert my_func(42, n_jobs=4) == "cpu:42:4" + + def test_gpu_dispatch(self): + _register_fake() + + @dispatch + def my_func(x, n_jobs=None): + return f"cpu:{x}" + + with settings.use_backend("fake"): + assert my_func(42) == "gpu:42:None" + + def test_gpu_specific_kwarg(self): + _register_fake() + + @dispatch + def my_func(x, n_jobs=None): + return f"cpu:{x}" + + with settings.use_backend("fake"): + assert my_func(42, gpu_param="hello") == "gpu:42:hello" + + def test_cpu_only_kwarg_warns(self): + _register_fake() + + @dispatch + def my_func(x, n_jobs=None): + return f"cpu:{x}" + + with 
settings.use_backend("fake"): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + my_func(42, n_jobs=4) + assert len(w) == 1 + assert "n_jobs" in str(w[0].message) + + def test_cpu_only_kwarg_default_silent(self): + _register_fake() + + @dispatch + def my_func(x, n_jobs=None): + return f"cpu:{x}" + + with settings.use_backend("fake"): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + my_func(42, n_jobs=None) + assert len(w) == 0 + + def test_gpu_kwarg_on_cpu_raises(self): + _register_fake() + + @dispatch + def my_func(x, n_jobs=None): + return f"cpu:{x}" + + with pytest.raises(TypeError, match="gpu_param"): + my_func(42, gpu_param="hello") + + def test_fallback_when_not_implemented(self): + _register_fake() + + @dispatch + def other_func(x): + return f"cpu:{x}" + + with settings.use_backend("fake"): + # FakeBackend doesn't have other_func -> CPU fallback + assert other_func(42) == "cpu:42" + + def test_per_function_override(self): + _register_fake() + + @dispatch + def my_func(x, n_jobs=None): + return f"cpu:{x}" + + settings.backend = "fake" + assert my_func(42, backend="cpu") == "cpu:42" + + def test_alias_resolution(self): + _register_fake() + + @dispatch + def my_func(x, n_jobs=None): + return f"cpu:{x}" + + with settings.use_backend("fake_gpu"): + assert my_func(42) == "gpu:42:None" + with settings.use_backend("fake"): + assert my_func(42) == "gpu:42:None" + with settings.use_backend("test-gpu"): + assert my_func(42) == "gpu:42:None" + + def test_backend_not_installed_raises(self): + _register_fake() + + @dispatch + def my_func(x): + return f"cpu:{x}" + + with pytest.raises(RuntimeError, match="not installed"): + my_func(42, backend="nonexistent_backend") + + +class TestDocstringMerging: + """Test numpydoc parameter extraction and injection.""" + + def test_extract_param_docs(self): + from squidpy._backends._dispatch import _extract_param_docs + + docstring = """\ +Run something. 
+ +Parameters +---------- +x + Input value. +gpu_param + GPU-specific parameter. + +Returns +------- +Result. +""" + result = _extract_param_docs(docstring, {"gpu_param"}) + assert "gpu_param" in result + assert "GPU-specific" in result["gpu_param"] + + def test_extract_skips_missing_params(self): + from squidpy._backends._dispatch import _extract_param_docs + + docstring = """\ +Parameters +---------- +x + Input. +""" + result = _extract_param_docs(docstring, {"nonexistent"}) + assert result == {} + + def test_extract_no_params_section(self): + from squidpy._backends._dispatch import _extract_param_docs + + result = _extract_param_docs("Just a docstring.", {"x"}) + assert result == {} + + def test_inject_param_docs(self): + from squidpy._backends._dispatch import _inject_param_docs + + docstring = """\ +Do something. + +Parameters +---------- +x + Input. + +Returns +------- +Result. +""" + result = _inject_param_docs(docstring, {"gpu_param": "gpu_param\n A GPU param."}) + assert "gpu_param" in result + assert "backend" in result + # backend doc should appear before Returns + lines = result.split("\n") + backend_idx = next(i for i, l in enumerate(lines) if "backend" in l.lower() and "Backend to use" not in l) + returns_idx = next(i for i, l in enumerate(lines) if l.strip() == "Returns") + assert backend_idx < returns_idx + + def test_inject_no_params_section_unchanged(self): + from squidpy._backends._dispatch import _inject_param_docs + + docstring = "Just a plain docstring." + assert _inject_param_docs(docstring, {"x": "x\n Param."}) == docstring + + def test_extract_handles_multiline_descriptions(self): + from squidpy._backends._dispatch import _extract_param_docs + + docstring = """\ +Parameters +---------- +multi + First line of description. + Second line continues here + with more detail. +other + Another param. 
+""" + result = _extract_param_docs(docstring, {"multi"}) + assert "multi" in result + assert "Second line" in result["multi"] + assert "with more detail" in result["multi"] + + +class TestLazyDiscovery: + """Test that backend discovery is lazy.""" + + def test_discovery_not_triggered_on_import(self): + """Verify _discovered stays False until a backend function is used.""" + _registry._discovered = False + _backends.clear() + _alias_map.clear() + + assert not _registry._discovered + + def test_discovery_triggered_by_get_backend(self): + """get_backend triggers lazy discovery.""" + _registry._discovered = False + from squidpy._backends._registry import get_backend + + get_backend("cpu") + assert _registry._discovered + + def test_discovery_triggered_by_settings_setter(self): + """Setting backend triggers lazy discovery.""" + _registry._discovered = False + _register_fake() + _registry._discovered = False + settings.backend = "fake" + assert _registry._discovered