diff --git a/.gitignore b/.gitignore index 4818e5c2e..b60d6aa4f 100644 --- a/.gitignore +++ b/.gitignore @@ -146,3 +146,4 @@ data pixi.lock _version.py +uv.lock diff --git a/docs/conf.py b/docs/conf.py index 0f413dfac..85080bd23 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -66,6 +66,7 @@ napari=("https://napari.org", None), spatialdata=("https://spatialdata.scverse.org/en/latest", None), shapely=("https://shapely.readthedocs.io/en/stable", None), + rapids_singlecell=("https://rapids-singlecell.readthedocs.io/en/latest", None), ) # Add any paths that contain templates here, relative to this directory. diff --git a/pyproject.toml b/pyproject.toml index 29c5461e2..2a4443ba7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,6 +103,12 @@ optional-dependencies.docs = [ "sphinxcontrib-bibtex>=2.3", "sphinxcontrib-spelling>=7.6.2", ] +optional-dependencies.gpu-cuda11 = [ + "rapids-singlecell[rapids11]>=0.13.5", +] +optional-dependencies.gpu-cuda12 = [ + "rapids-singlecell[rapids12]>=0.13.5", +] optional-dependencies.leiden = [ "leidenalg", "spatialleiden>=0.4", diff --git a/src/squidpy/__init__.py b/src/squidpy/__init__.py index 85b250d82..1aaa80053 100644 --- a/src/squidpy/__init__.py +++ b/src/squidpy/__init__.py @@ -4,6 +4,7 @@ from importlib.metadata import PackageMetadata from squidpy import datasets, experimental, gr, im, pl, read, tl +from squidpy._settings import settings try: md: PackageMetadata = metadata.metadata(__name__) @@ -15,4 +16,4 @@ del metadata, md -__all__ = ["datasets", "experimental", "gr", "im", "pl", "read", "tl"] +__all__ = ["datasets", "experimental", "gr", "im", "pl", "read", "tl", "settings"] diff --git a/src/squidpy/_docs.py b/src/squidpy/_docs.py index a4596df75..ad629e6f6 100644 --- a/src/squidpy/_docs.py +++ b/src/squidpy/_docs.py @@ -105,20 +105,36 @@ def decorator2(obj: Any) -> Any: _plotting_returns = """\ Nothing, just plots the figure and optionally saves the plot. """ -_parallelize = """\ +_CPU_ONLY = " Only available when ``device='cpu'``." + +_n_jobs = """\ n_jobs - Number of parallel jobs to use. + Number of parallel jobs to use. If ``None``, use all available cores. For ``backend="loky"``, the number of cores used by numba for each job spawned by the backend will be set to 1 in order to overcome the oversubscription issue in case you run numba in your function to parallelize. To set the absolute maximum number of threads in numba for your python program, set the environment variable: - ``NUMBA_NUM_THREADS`` before running the program. + ``NUMBA_NUM_THREADS`` before running the program.""" +_backend = """\ backend - Parallelization backend to use. See :class:`joblib.Parallel` for available options. + Parallelization backend to use. If ``None``, defaults to ``'loky'``. + See :class:`joblib.Parallel` for available options.""" +_show_progress_bar = """\ show_progress_bar - Whether to show the progress bar or not.""" + Whether to show the progress bar. If ``None``, uses ``scanpy.settings.verbosity``.""" + +_parallelize = f"{_n_jobs}\n{_backend}\n{_show_progress_bar}" +_parallelize_device = f"{_n_jobs}{_CPU_ONLY}\n{_backend}{_CPU_ONLY}\n{_show_progress_bar}{_CPU_ONLY}" +_seed_device = f"""\ +seed + Random seed for reproducibility.{_CPU_ONLY} +""" +_device_kwargs = """\ +device_kwargs + Additional keyword arguments passed to the GPU implementation when ``squidpy.settings.device`` + is set to ``'gpu'``. Must be ``None`` or empty when device is ``'cpu'``.""" _channels = """\ channels Channels for this feature is computed. If `None`, use all channels.""" @@ -379,6 +395,9 @@ def decorator2(obj: Any) -> Any: cat_plotting=_cat_plotting, plotting_returns=_plotting_returns, parallelize=_parallelize, + parallelize_device=_parallelize_device, + seed_device=_seed_device, + device_kwargs=_device_kwargs, channels=_channels, segment_kwargs=_segment_kwargs, ligrec_test_returns=_ligrec_test_returns, diff --git a/src/squidpy/_settings/__init__.py b/src/squidpy/_settings/__init__.py new file mode 100644 index 000000000..5c0dbd920 --- /dev/null +++ b/src/squidpy/_settings/__init__.py @@ -0,0 +1,8 @@ +"""Squidpy settings.""" + +from __future__ import annotations + +from squidpy._settings._dispatch import gpu_dispatch +from squidpy._settings._settings import DeviceType, settings + +__all__ = ["settings", "DeviceType", "gpu_dispatch"] diff --git a/src/squidpy/_settings/_dispatch.py b/src/squidpy/_settings/_dispatch.py new file mode 100644 index 000000000..0a824a23b --- /dev/null +++ b/src/squidpy/_settings/_dispatch.py @@ -0,0 +1,118 @@ +"""GPU dispatch decorator for squidpy.""" + +from __future__ import annotations + +import functools +import importlib +import re +from collections.abc import Callable, Mapping +from typing import Any, TypeVar + +from squidpy._settings._settings import settings + +__all__ = ["gpu_dispatch"] + +F = TypeVar("F", bound=Callable[..., Any]) + + +def _make_gpu_note(func_name: str, gpu_module: str, indent: str = "") -> str: + lines = [ + ".. note::", + " This function supports GPU acceleration via :doc:`rapids_singlecell `.", + f" See :func:`{gpu_module}.{func_name}` for the GPU implementation.", + ] + return "\n".join(indent + line for line in lines) + + +def _inject_gpu_note(doc: str | None, func_name: str, gpu_module: str) -> str | None: + """Inject GPU note into docstring before the Parameters section.""" + if doc is None: + return None + + # Find "Parameters\n ----------" and capture the indentation (spaces only, not newline) + match = re.search(r"\n([ \t]*)Parameters\s*\n\s*-+", doc) + if match: + indent = match.group(1) + gpu_note = _make_gpu_note(func_name, gpu_module, indent) + insert_pos = match.start() + return doc[:insert_pos] + "\n\n" + gpu_note + "\n" + doc[insert_pos:] + + # Fallback: append at the end + return doc + "\n\n" + _make_gpu_note(func_name, gpu_module) + + +@functools.cache +def _get_gpu_func(gpu_module: str, func_name: str) -> Callable[..., Any]: + """Get GPU function from module, with caching. + + Raises + ------ + ImportError + If the GPU module cannot be imported. + AttributeError + If the function does not exist in the GPU module. + """ + module = importlib.import_module(gpu_module) + return getattr(module, func_name) + + +def gpu_dispatch( + gpu_module: str = "rapids_singlecell.gr", + validate_args: Mapping[str, Callable[[Any], None]] | None = None, +) -> Callable[[F], F]: + """Decorator to dispatch to GPU implementation based on settings.device. + + When device is 'gpu', calls the GPU implementation from the specified module. + The ``device_kwargs`` parameter from the decorated function is merged into the + call for GPU-specific options. Arguments with ``None`` values are filtered out + to let the GPU function use its defaults. + + Parameters + ---------- + gpu_module + Module path containing the GPU implementation. + validate_args + Mapping of parameter names to validation functions. Each validator is called + with the parameter value before GPU dispatch and should raise ValueError + if the value is not supported on GPU. Validated arguments are removed from + kwargs before calling the GPU function. Only called when dispatching to GPU. + """ + _validate_args = validate_args or {} + + def decorator(func: F) -> F: + func_name = func.__name__ + + func.__doc__ = _inject_gpu_note(func.__doc__, func_name, gpu_module) + + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + if settings.device == "cpu": + device_kwargs = kwargs.pop("device_kwargs", None) + if device_kwargs is not None and len(device_kwargs) > 0: + raise ValueError( + "device_kwargs should not be provided when squidpy.settings.device='cpu'. " + "Set squidpy.settings.device='gpu' or use settings.use_device('gpu') context manager." + ) + return func(*args, **kwargs) + + # GPU path + # run validators and remove validated args + for param_name, validator in _validate_args.items(): + if param_name in kwargs: + validator(kwargs[param_name]) + kwargs.pop(param_name) + + # get GPU function + gpu_func = _get_gpu_func(gpu_module, func_name) + + # filter out None values to let GPU function use its defaults + kwargs = {k: v for k, v in kwargs.items() if v is not None} + # merge device_kwargs and call GPU function + device_kwargs = kwargs.pop("device_kwargs", None) or {} + kwargs.update(device_kwargs) + + return gpu_func(*args, **kwargs) + + return wrapper # type: ignore[return-value] + + return decorator diff --git a/src/squidpy/_settings/_settings.py b/src/squidpy/_settings/_settings.py new file mode 100644 index 000000000..0dd33cc03 --- /dev/null +++ b/src/squidpy/_settings/_settings.py @@ -0,0 +1,88 @@ +"""Squidpy global settings.""" + +from __future__ import annotations + +from contextlib import contextmanager +from contextvars import ContextVar, Token +from typing import TYPE_CHECKING, Literal, get_args + +if TYPE_CHECKING: + from collections.abc import Generator + +__all__ = ["settings", "DeviceType"] + +DeviceType = Literal["cpu", "gpu"] +GPU_UNAVAILABLE_MSG = ( + "GPU unavailable. Install: pip install squidpy[gpu-cuda12] or with [gpu-cuda11] for CUDA 11 support." +) +_device_var: ContextVar[DeviceType | None] = ContextVar("device", default=None) + + +def _check_gpu_available() -> bool: + """Check if GPU acceleration is available.""" + try: + import rapids_singlecell # noqa: F401 + + return True + except ImportError: + return False + + +class _SqSettings: + """Global configuration for squidpy. + + Attributes + ---------- + gpu_available + Whether GPU acceleration via rapids-singlecell is available. + device + Compute device. + Defaults to ``'gpu'`` if available, otherwise ``'cpu'``. + """ + + def __init__(self) -> None: + self.gpu_available: bool = _check_gpu_available() + + @property + def device(self) -> DeviceType: + """Compute device: ``'cpu'`` or ``'gpu'``. + + Defaults to ``'gpu'`` if rapids-singlecell is installed, otherwise ``'cpu'``. + Setting to ``'gpu'`` when GPU is unavailable raises a RuntimeError. + """ + value = _device_var.get() + if value is None: + return "gpu" if self.gpu_available else "cpu" + return value + + @device.setter + def device(self, value: DeviceType) -> None: + if value not in get_args(DeviceType): + raise ValueError(f"device must be one of {get_args(DeviceType)}, got {value!r}") + if value == "gpu" and not self.gpu_available: + raise RuntimeError(GPU_UNAVAILABLE_MSG) + _device_var.set(value) + + @contextmanager + def use_device(self, device: DeviceType) -> Generator[None, None, None]: + """Temporarily set the compute device within a context. + + Parameters + ---------- + device + The device to use. + + Examples + -------- + >>> with sq.settings.use_device("cpu"): + ... sq.gr.spatial_neighbors(adata) + """ + token: Token[DeviceType | None] = _device_var.set(_device_var.get()) + try: + self.device = device + yield + finally: + _device_var.reset(token) + + +settings = _SqSettings() diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py index 736c88172..ab34531c4 100644 --- a/src/squidpy/_utils.py +++ b/src/squidpy/_utils.py @@ -79,9 +79,9 @@ def parallelize( n_split: int | None = None, unit: str = "", use_ixs: bool = False, - backend: str = "loky", + backend: str | None = "loky", extractor: Callable[[Sequence[Any]], Any] | None = None, - show_progress_bar: bool = True, + show_progress_bar: bool | None = True, use_runner: bool = False, **_: Any, ) -> Any: @@ -119,6 +119,12 @@ def parallelize( ------- The result depending on ``callable``, ``extractor``. """ + # Apply defaults for None values (allows dispatch to pass through None) + if backend is None: + backend = "loky" + if show_progress_bar is None: + show_progress_bar = True + if show_progress_bar: try: import ipywidgets # noqa: F401 diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index f369759b1..b32494941 100644 --- a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -20,6 +20,7 @@ from squidpy._constants._constants import ComplexPolicy, CorrAxis from squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs +from squidpy._settings import gpu_dispatch from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize from squidpy.gr._utils import ( _assert_categorical_obs, @@ -328,7 +329,9 @@ def test( copy: bool = False, key_added: str | None = None, numba_parallel: bool | None = None, - **kwargs: Any, + n_jobs: int | None = None, + backend: str | None = None, + show_progress_bar: bool | None = None, ) -> Mapping[str, pd.DataFrame] | None: """ Perform the permutation test as described in :cite:`cellphonedb`. @@ -410,10 +413,10 @@ def test( # much faster than applymap (tested on 1M interactions) interactions_ = np.vectorize(lambda g: gene_mapper[g])(interactions.values) - n_jobs = _get_n_cores(kwargs.pop("n_jobs", None)) + n_jobs_ = _get_n_cores(n_jobs) start = logg.info( f"Running `{n_perms}` permutations on `{len(interactions)}` interactions " - f"and `{len(clusters)}` cluster combinations using `{n_jobs}` core(s)" + f"and `{len(clusters)}` cluster combinations using `{n_jobs_}` core(s)" ) res = _analysis( data, @@ -422,9 +425,10 @@ def test( threshold=threshold, n_perms=n_perms, seed=seed, - n_jobs=n_jobs, + n_jobs=n_jobs_, numba_parallel=numba_parallel, - **kwargs, + backend=backend, + show_progress_bar=show_progress_bar, ) res = { "means": _create_sparse_df( @@ -579,7 +583,6 @@ def prepare( interactions_params: Mapping[str, Any] = MappingProxyType({}), transmitter_params: Mapping[str, Any] = MappingProxyType({"categories": "ligand"}), receiver_params: Mapping[str, Any] = MappingProxyType({"categories": "receptor"}), - **_: Any, ) -> PermutationTest: """ %(PT_prepare.full_desc)s @@ -633,6 +636,7 @@ def prepare( @d.dedent +@gpu_dispatch() def ligrec( adata: AnnData | SpatialData, cluster_key: str, @@ -645,7 +649,20 @@ def ligrec( copy: bool = False, key_added: str | None = None, gene_symbols: str | None = None, - **kwargs: Any, + # prepare params + interactions_params: Mapping[str, Any] = MappingProxyType({}), + transmitter_params: Mapping[str, Any] = MappingProxyType({"categories": "ligand"}), + receiver_params: Mapping[str, Any] = MappingProxyType({"categories": "receptor"}), + # test params + clusters: Cluster_t | None = None, + n_perms: int = 1000, + seed: int | None = None, + alpha: float = 0.05, + numba_parallel: bool | None = None, + n_jobs: int | None = None, + backend: str | None = None, + show_progress_bar: bool | None = None, + device_kwargs: dict[str, Any] | None = None, ) -> Mapping[str, pd.DataFrame] | None: """ %(PT_test.full_desc)s @@ -657,25 +674,41 @@ def ligrec( %(PT_test.parameters)s gene_symbols Key in :attr:`anndata.AnnData.var` to use instead of :attr:`anndata.AnnData.var_names`. + %(device_kwargs)s Returns ------- %(ligrec_test_returns)s """ # noqa: D400 + del device_kwargs # handled by gpu_dispatch decorator if isinstance(adata, SpatialData): adata = adata.table + with _genesymbols(adata, key=gene_symbols, use_raw=use_raw, make_unique=False): return ( # type: ignore[no-any-return] PermutationTest(adata, use_raw=use_raw) - .prepare(interactions, complex_policy=complex_policy, **kwargs) + .prepare( + interactions, + complex_policy=complex_policy, + interactions_params=interactions_params, + transmitter_params=transmitter_params, + receiver_params=receiver_params, + ) .test( cluster_key=cluster_key, + clusters=clusters, + n_perms=n_perms, threshold=threshold, + seed=seed, corr_method=corr_method, corr_axis=corr_axis, + alpha=alpha, copy=copy, key_added=key_added, - **kwargs, + numba_parallel=numba_parallel, + n_jobs=n_jobs, + backend=backend, + show_progress_bar=show_progress_bar, ) ) @@ -690,7 +723,8 @@ def _analysis( seed: int | None = None, n_jobs: int = 1, numba_parallel: bool | None = None, - **kwargs: Any, + backend: str | None = None, + show_progress_bar: bool | None = None, ) -> TempResult: """ Run the analysis as described in :cite:`cellphonedb`. @@ -709,13 +743,9 @@ def _analysis( Percentage threshold for removing lowly expressed genes in clusters. %(n_perms)s %(seed)s - n_jobs - Number of parallel jobs to launch. numba_parallel Whether to use :func:`numba.prange` or not. If `None`, it's determined automatically. - kwargs - Keyword arguments for :func:`squidpy._utils.parallelize`, such as ``n_jobs`` or ``backend``. - + %(parallelize)s Returns ------- Tuple of the following format: @@ -757,7 +787,8 @@ def extractor(res: Sequence[TempResult]) -> TempResult: n_jobs=n_jobs, unit="permutation", extractor=extractor, - **kwargs, + backend=backend, + show_progress_bar=show_progress_bar, )( data, mean, diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 292c75994..4a5cc23f0 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -23,6 +23,7 @@ from squidpy._constants._constants import SpatialAutocorr from squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs +from squidpy._settings import gpu_dispatch from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize from squidpy.gr._utils import ( _assert_categorical_obs, @@ -35,6 +36,14 @@ __all__ = ["spatial_autocorr", "co_occurrence"] +def _validate_attr_for_gpu(value: Any) -> None: + """Validate attr parameter for GPU dispatch.""" + if value != "X": + raise ValueError( + f"attr={value!r} is not supported on GPU. Use `squidpy.settings.device = 'cpu'` to use other attributes." + ) + + it = nt.int32 ft = nt.float32 tt = nt.UniTuple @@ -45,6 +54,7 @@ @d.dedent @inject_docs(key=Key.obsp.spatial_conn(), sp=SpatialAutocorr) +@gpu_dispatch(gpu_module="rapids_singlecell.gr", validate_args={"attr": _validate_attr_for_gpu}) def spatial_autocorr( adata: AnnData | SpatialData, connectivity_key: str = Key.obsp.spatial_conn(), @@ -60,8 +70,9 @@ def spatial_autocorr( use_raw: bool = False, copy: bool = False, n_jobs: int | None = None, - backend: str = "loky", - show_progress_bar: bool = True, + backend: str | None = None, + show_progress_bar: bool | None = None, + device_kwargs: dict[str, Any] | None = None, ) -> pd.DataFrame | None: """ Calculate Global Autocorrelation Statistic (Moran’s I or Geary's C). @@ -104,9 +115,11 @@ def spatial_autocorr( Layer in :attr:`anndata.AnnData.layers` to use. If `None`, use :attr:`anndata.AnnData.X`. attr Which attribute of :class:`~anndata.AnnData` to access. See ``genes`` parameter for more information. - %(seed)s + Can be only 'X' when effective device is 'gpu'. + %(seed_device)s %(copy)s - %(parallelize)s + %(parallelize_device)s + %(device_kwargs)s Returns ------- @@ -128,6 +141,7 @@ def spatial_autocorr( - :attr:`anndata.AnnData.uns` ``['moranI']`` - the above mentioned dataframe, if ``mode = {sp.MORAN.s!r}``. - :attr:`anndata.AnnData.uns` ``['gearyC']`` - the above mentioned dataframe, if ``mode = {sp.GEARY.s!r}``. """ + del device_kwargs # device and use_sparse are handled by the gpu_dispatch decorator if isinstance(adata, SpatialData): adata = adata.table _assert_connectivity_key(adata, connectivity_key) @@ -342,16 +356,17 @@ def _co_occurrence_helper(v_x: NDArrayA, v_y: NDArrayA, v_radium: NDArrayA, labs @d.dedent +@gpu_dispatch() def co_occurrence( adata: AnnData | SpatialData, cluster_key: str, spatial_key: str = Key.obsm.spatial, interval: int | NDArrayA = 50, copy: bool = False, - n_splits: int | None = None, n_jobs: int | None = None, - backend: str = "loky", - show_progress_bar: bool = True, + backend: str | None = None, + show_progress_bar: bool | None = None, + device_kwargs: dict[str, Any] | None = None, ) -> tuple[NDArrayA, NDArrayA] | None: """ Compute co-occurrence probability of clusters. @@ -365,10 +380,8 @@ def co_occurrence( Distances interval at which co-occurrence is computed. If :class:`int`, uniformly spaced interval of the given size will be used. %(copy)s - n_splits - Number of splits in which to divide the spatial coordinates in - :attr:`anndata.AnnData.obsm` ``['{spatial_key}']``. - %(parallelize)s + %(parallelize_device)s + %(device_kwargs)s Returns ------- @@ -381,7 +394,7 @@ def co_occurrence( - :attr:`anndata.AnnData.uns` ``['{cluster_key}_co_occurrence']['interval']`` - the distance thresholds computed at ``interval``. """ - + del n_jobs, backend, show_progress_bar, device_kwargs # handled by gpu_dispatch decorator or unused on CPU if isinstance(adata, SpatialData): adata = adata.table _assert_categorical_obs(adata, key=cluster_key) @@ -405,10 +418,8 @@ def co_occurrence( spatial_y = spatial[:, 1] # Compute co-occurrence probabilities using the fast numba routine. + start = logg.info(f"Calculating co-occurrence probabilities for `{len(interval)}` intervals") out = _co_occurrence_helper(spatial_x, spatial_y, interval, labs) - start = logg.info( - f"Calculating co-occurrence probabilities for `{len(interval)}` intervals using `{n_jobs}` core(s) and `{n_splits}` splits" - ) if copy: logg.info("Finish", time=start) diff --git a/tests/graph/test_ppatterns.py b/tests/graph/test_ppatterns.py index 226fb2830..eb23e37f8 100644 --- a/tests/graph/test_ppatterns.py +++ b/tests/graph/test_ppatterns.py @@ -138,11 +138,11 @@ def test_co_occurrence(adata: AnnData): # @pytest.mark.parametrize(("ys", "xs"), [(10, 10), (None, None), (10, 20)]) -@pytest.mark.parametrize(("n_jobs", "n_splits"), [(1, 2), (2, 2)]) -def test_co_occurrence_reproducibility(adata: AnnData, n_jobs: int, n_splits: int): +@pytest.mark.parametrize("n_jobs", [1, 2]) +def test_co_occurrence_reproducibility(adata: AnnData, n_jobs: int): """Check co_occurrence reproducibility results.""" - arr_1, interval_1 = co_occurrence(adata, cluster_key="leiden", copy=True, n_jobs=n_jobs, n_splits=n_splits) - arr_2, interval_2 = co_occurrence(adata, cluster_key="leiden", copy=True, n_jobs=n_jobs, n_splits=n_splits) + arr_1, interval_1 = co_occurrence(adata, cluster_key="leiden", copy=True, n_jobs=n_jobs) + arr_2, interval_2 = co_occurrence(adata, cluster_key="leiden", copy=True, n_jobs=n_jobs) np.testing.assert_array_equal(sorted(interval_1), sorted(interval_2)) np.testing.assert_allclose(arr_1, arr_2) diff --git a/tests/test_gpu.py b/tests/test_gpu.py new file mode 100644 index 000000000..2ddbe4d42 --- /dev/null +++ b/tests/test_gpu.py @@ -0,0 +1,61 @@ +"""Tests for GPU functionality (skipped in CI without GPU). + +These tests verify GPU results match CPU results. Structure/correctness +of CPU outputs is tested elsewhere, so we only test equivalence here. +""" + +from __future__ import annotations + +import numpy as np +import pytest + +import squidpy as sq +from squidpy._settings import settings + +pytestmark = pytest.mark.skipif( + not settings.gpu_available, + reason="GPU tests require rapids-singlecell to be installed", +) + + +@pytest.fixture +def adata_filtered(adata): + """Filter adata to genes with non-zero variance (avoids NaN in GPU spatial_autocorr).""" + X = adata.X.toarray() if hasattr(adata.X, "toarray") else adata.X + gene_var = np.var(X, axis=0) + return adata[:, gene_var > 1e-6].copy() + + +class TestGPUvsCPU: + """Test that GPU and CPU produce equivalent results.""" + + def test_co_occurrence(self, adata_filtered): + """Test co_occurrence GPU vs CPU equivalence.""" + with settings.use_device("cpu"): + cpu_arr, cpu_interval = sq.gr.co_occurrence(adata_filtered, cluster_key="leiden", copy=True) + with settings.use_device("gpu"): + gpu_arr, gpu_interval = sq.gr.co_occurrence(adata_filtered, cluster_key="leiden", copy=True) + + np.testing.assert_allclose(cpu_interval, gpu_interval, rtol=1e-5) + np.testing.assert_allclose(cpu_arr, gpu_arr, rtol=1e-5) + + def test_spatial_autocorr(self, adata_filtered): + """Test spatial_autocorr GPU vs CPU equivalence.""" + sq.gr.spatial_neighbors(adata_filtered) + + with settings.use_device("cpu"): + cpu_result = sq.gr.spatial_autocorr(adata_filtered, mode="moran", copy=True) + with settings.use_device("gpu"): + gpu_result = sq.gr.spatial_autocorr(adata_filtered, mode="moran", copy=True) + + np.testing.assert_allclose(cpu_result["I"].values, gpu_result["I"].values, rtol=1e-3, equal_nan=True) + + def test_ligrec(self, adata_filtered): + """Test ligrec GPU vs CPU equivalence.""" + with settings.use_device("cpu"): + cpu_result = sq.gr.ligrec(adata_filtered, cluster_key="leiden", copy=True, n_perms=5) + with settings.use_device("gpu"): + gpu_result = sq.gr.ligrec(adata_filtered, cluster_key="leiden", copy=True, n_perms=5) + + # Compare means (deterministic) + np.testing.assert_allclose(cpu_result["means"].values, gpu_result["means"].values, rtol=1e-5, equal_nan=True) diff --git a/tests/test_settings.py b/tests/test_settings.py new file mode 100644 index 000000000..90f25e20e --- /dev/null +++ b/tests/test_settings.py @@ -0,0 +1,134 @@ +"""Tests for squidpy._settings module.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from squidpy._settings import gpu_dispatch, settings +from squidpy._settings._dispatch import _get_gpu_func +from squidpy._settings._settings import _device_var + + +@pytest.fixture(autouse=True) +def reset_device(): + """Reset device state before and after each test.""" + _device_var.set(None) + _get_gpu_func.cache_clear() + yield + _device_var.set(None) + _get_gpu_func.cache_clear() + + +class TestDeviceSettings: + """Test device property and use_device context manager.""" + + def test_invalid_device_raises(self): + """Test invalid device raises ValueError.""" + with pytest.raises(ValueError, match="device must be one of"): + settings.device = "invalid" + with pytest.raises(ValueError, match="device must be one of"): + with settings.use_device("invalid"): + pass + + @pytest.mark.skipif(settings.gpu_available, reason="GPU is available") + def test_gpu_without_rsc_raises(self): + """Test setting GPU without rapids-singlecell raises RuntimeError.""" + with pytest.raises(RuntimeError, match="GPU unavailable"): + settings.device = "gpu" + with pytest.raises(RuntimeError, match="GPU unavailable"): + with settings.use_device("gpu"): + pass + + +class TestGpuDispatch: + """Test the gpu_dispatch decorator.""" + + def test_cpu_path(self): + """Test CPU device calls original function.""" + calls = [] + + @gpu_dispatch() + def my_func(x, y): + calls.append((x, y)) + return x + y + + with settings.use_device("cpu"): + assert my_func(1, 2) == 3 + assert calls == [(1, 2)] + + def test_gpu_dispatch_and_device_kwargs(self): + """Test GPU dispatch with device_kwargs.""" + mock_module = MagicMock() + received = {} + + def gpu_my_func(x, use_sparse=False): + received.update({"x": x, "use_sparse": use_sparse}) + return "gpu_result" + + mock_module.my_func = gpu_my_func + + @gpu_dispatch(gpu_module="test_module") + def my_func(x, device_kwargs=None): + return "cpu_result" + + with patch.object(settings, "gpu_available", True): + with settings.use_device("gpu"): + with patch("importlib.import_module", return_value=mock_module): + # Basic dispatch + assert my_func(42) == "gpu_result" + # With device_kwargs + assert my_func(42, device_kwargs={"use_sparse": True}) == "gpu_result" + assert received["use_sparse"] is True + + def test_device_kwargs_error_on_cpu(self): + """Test device_kwargs raises error on CPU path.""" + + @gpu_dispatch() + def my_func(x, device_kwargs=None): + return x * 2 + + with settings.use_device("cpu"): + with pytest.raises(ValueError, match="device_kwargs should not be provided"): + my_func(5, device_kwargs={"use_sparse": True}) + + def test_validate_args(self): + """Test validate_args runs validators before GPU dispatch.""" + mock_module = MagicMock() + mock_module.my_func = MagicMock(return_value="gpu_result") + + @gpu_dispatch( + gpu_module="test_module", + validate_args={ + "attr": lambda v: (_ for _ in ()).throw(ValueError(f"attr={v!r} invalid")) if v != "X" else None + }, + ) + def my_func(x, attr="X"): + return "cpu_result" + + with patch.object(settings, "gpu_available", True): + with settings.use_device("gpu"): + with patch("importlib.import_module", return_value=mock_module): + assert my_func(42, attr="X") == "gpu_result" + with pytest.raises(ValueError, match="attr='obs' invalid"): + my_func(42, attr="obs") + + def test_preserves_metadata_and_docstring(self): + """Test decorator preserves function name and injects GPU note.""" + + @gpu_dispatch(gpu_module="custom.module") + def documented_func(x): + """Original docstring. + + Parameters + ---------- + x + Input value. + """ + return x + + assert documented_func.__name__ == "documented_func" + assert "Original docstring." in documented_func.__doc__ + assert "GPU acceleration" in documented_func.__doc__ + assert "custom.module.documented_func" in documented_func.__doc__