Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
"pylibraft",
"dask",
"cuvs",
"spatialdata",
]
default_role = "literal"
napoleon_google_docstring = False
Expand Down Expand Up @@ -126,6 +127,7 @@
"statsmodels": ("https://www.statsmodels.org/stable/", None),
"omnipath": ("https://omnipath.readthedocs.io/en/latest/", None),
"dask": ("https://docs.dask.org/en/stable/", None),
"spatialdata": ("https://spatialdata.scverse.org/en/stable/", None),
}

# List of patterns, relative to source directory, that match files and
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ dev = [
"pre-commit",
]

[project.entry-points."squidpy.backends"]
rapids_singlecell = "rapids_singlecell.squidpy_backend:RscSquidpyBackend"

[project.urls]
Documentation = "https://rapids-singlecell.readthedocs.io"
Source = "https://github.com/scverse/rapids_singlecell"
Expand Down
5 changes: 5 additions & 0 deletions src/rapids_singlecell/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
from scipy.sparse import csc_matrix as csc_matrix_cpu
from scipy.sparse import csr_matrix as csr_matrix_cpu

try:
from spatialdata import SpatialData
except ImportError:
SpatialData = None


def _meta_dense(dtype):
return cp.zeros([0], dtype=dtype)
Expand Down
30 changes: 30 additions & 0 deletions src/rapids_singlecell/squidpy_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Squidpy backend adapter for rapids_singlecell.

The dispatch decorator introspects the real RSC function signatures
(lazily imported on first access), so no need to duplicate them here.
"""

from __future__ import annotations

import importlib


class RscSquidpyBackend:
"""Backend adapter exposing rapids_singlecell GPU implementations to squidpy."""

name = "rapids_singlecell"
aliases = ["rapids-singlecell", "rsc", "cuda", "gpu"]

# squidpy function name -> module that implements it
_functions = {
"spatial_autocorr": "rapids_singlecell.squidpy_gpu",
"co_occurrence": "rapids_singlecell.squidpy_gpu",
"ligrec": "rapids_singlecell.squidpy_gpu",
}

def __getattr__(self, name: str):
if name in self._functions:
func = getattr(importlib.import_module(self._functions[name]), name)
setattr(self, name, func) # cache on instance
return func
raise AttributeError(f"{type(self).__name__!r} has no attribute {name!r}")
5 changes: 4 additions & 1 deletion src/rapids_singlecell/squidpy_gpu/_autocorr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from scipy import sparse
from statsmodels.stats.multitest import multipletests

from rapids_singlecell._compat import SpatialData
from rapids_singlecell.preprocessing._utils import _sparse_to_dense

from ._gearysc import _gearys_C_cupy
Expand Down Expand Up @@ -49,7 +50,7 @@ def _to_cupy(vals, *, use_sparse: bool, dtype):


def spatial_autocorr(
adata: AnnData,
adata: AnnData | SpatialData,
*,
connectivity_key: str = "spatial_connectivities",
genes: str | Sequence[str] | None = None,
Expand Down Expand Up @@ -118,6 +119,8 @@ def spatial_autocorr(
DataFrame containing the autocorrelation scores, p-values, and corrected p-values for each gene. \
If `copy` is False, the results are stored in `adata.uns` and None is returned.
"""
if SpatialData is not None and isinstance(adata, SpatialData):
adata = adata.table
if genes is None:
if "highly_variable" in adata.var:
genes = adata[:, adata.var["highly_variable"]].var_names.values
Expand Down
5 changes: 4 additions & 1 deletion src/rapids_singlecell/squidpy_gpu/_co_oc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
from cuml.metrics import pairwise_distances

from rapids_singlecell._compat import SpatialData
from rapids_singlecell._cuda import _cooc_cuda as _co
from rapids_singlecell._utils import (
_calculate_blocks_per_pair,
Expand All @@ -21,7 +22,7 @@


def co_occurrence(
adata: AnnData,
adata: AnnData | SpatialData,
cluster_key: str,
*,
spatial_key: str = "spatial",
Expand Down Expand Up @@ -65,6 +66,8 @@ def co_occurrence(
computed at ``interval``.
"""

if SpatialData is not None and isinstance(adata, SpatialData):
adata = adata.table
_assert_categorical_obs(adata, key=cluster_key)
_assert_spatial_basis(adata, key=spatial_key)
spatial = cp.array(adata.obsm[spatial_key]).astype(np.float32)
Expand Down
6 changes: 5 additions & 1 deletion src/rapids_singlecell/squidpy_gpu/_ligrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from cupyx.scipy.sparse import issparse as cpissparse
from scipy.sparse import csc_matrix, issparse

from rapids_singlecell._compat import SpatialData

from ._utils import _assert_categorical_obs, _create_sparse_df

SOURCE = "source"
Expand Down Expand Up @@ -118,7 +120,7 @@ def _check_tuple_needles(needles, haystack, *, msg: str, reraise: bool = True):


def ligrec(
adata: AnnData,
adata: AnnData | SpatialData,
cluster_key: str,
*,
clusters: list | None = None,
Expand Down Expand Up @@ -233,6 +235,8 @@ def ligrec(
interacting components was 0 or it didn't pass the threshold percentage of \
cells being expressed within a given cluster.
"""
if SpatialData is not None and isinstance(adata, SpatialData):
adata = adata.table
# Get and Check interactions
if interactions is None:
interactions = _get_interactions(
Expand Down
11 changes: 11 additions & 0 deletions tests/test_backend_conformance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Run squidpy's backend conformance suite against the RSC backend."""

from __future__ import annotations

from squidpy.testing.backend_conformance import validate_backend


def test_conformance():
results = validate_backend("rapids_singlecell")
for name, status in results.items():
assert status == "PASSED", f"{name}: {status}"
Loading