diff --git a/docs/conf.py b/docs/conf.py
index 0f413dfac..33ac09e1a 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -49,7 +49,7 @@
     numpy=("https://numpy.org/doc/stable", None),
     statsmodels=("https://www.statsmodels.org/stable", None),
     scipy=("https://docs.scipy.org/doc/scipy", None),
-    pandas=("https://pandas.pydata.org/pandas-docs/stable", None),
+    pandas=("https://pandas.pydata.org/docs", None),
     anndata=("https://anndata.readthedocs.io/en/stable", None),
     scanpy=("https://scanpy.readthedocs.io/en/stable", None),
     matplotlib=("https://matplotlib.org/stable", None),
diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py
index 3c5c86392..e3e7e1457 100644
--- a/src/squidpy/_utils.py
+++ b/src/squidpy/_utils.py
@@ -18,6 +18,7 @@
 import numpy as np
 import xarray as xr
 from spatialdata.models import Image2DModel, Labels2DModel
+from tqdm.auto import tqdm
 
 __all__ = ["singledispatchmethod", "Signal", "SigQueue", "NDArray", "NDArrayA"]
 
@@ -228,6 +229,46 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
     return wrapper
 
 
+def thread_map(
+    fn: Callable[..., Any],
+    items: Iterable[Any],
+    *,
+    n_jobs: int = 1,
+    show_progress_bar: bool = False,
+    unit: str = "item",
+    total: int | None = None,
+) -> list[Any]:
+    """Map *fn* over *items* using a thread pool with an optional progress bar.
+
+    Parameters
+    ----------
+    fn
+        Callable applied to each element of *items*.
+    items
+        Iterable of inputs passed one-by-one to *fn*.
+    n_jobs
+        Number of worker threads used by the pool.
+    show_progress_bar
+        Whether to display a ``tqdm`` progress bar.
+    unit
+        Label shown next to the ``tqdm`` counter.
+    total
+        Length hint passed to ``tqdm`` when *items* has no ``__len__``.
+
+    Returns
+    -------
+    list
+        Results in the same order as *items*.
+    """
+    from concurrent.futures import ThreadPoolExecutor
+
+    with ThreadPoolExecutor(max_workers=n_jobs) as pool:
+        it = pool.map(fn, items)
+        if show_progress_bar:
+            it = tqdm(it, total=total if total is not None else len(items), unit=unit)
+        return list(it)
+
+
 def _get_n_cores(n_cores: int | None) -> int:
     """
     Make number of cores a positive integer.
diff --git a/src/squidpy/gr/_sepal.py b/src/squidpy/gr/_sepal.py
index cce8ebca6..753057af3 100644
--- a/src/squidpy/gr/_sepal.py
+++ b/src/squidpy/gr/_sepal.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from collections.abc import Callable, Sequence
+from collections.abc import Sequence
 from typing import Literal
 
 import numpy as np
@@ -8,13 +8,13 @@
 from anndata import AnnData
 from numba import njit
 from scanpy import logging as logg
-from scipy.sparse import csr_matrix, isspmatrix_csr, spmatrix
+from scipy.sparse import csc_matrix, csr_matrix, issparse, isspmatrix_csr, spmatrix
 from sklearn.metrics import pairwise_distances
 from spatialdata import SpatialData
 
 from squidpy._constants._pkg_constants import Key
 from squidpy._docs import d, inject_docs
-from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize
+from squidpy._utils import NDArrayA, _get_n_cores, deprecated_params, thread_map
 from squidpy._validators import assert_non_empty_sequence
 from squidpy.gr._utils import (
     _assert_connectivity_key,
@@ -28,6 +28,7 @@
 
 @d.dedent
 @inject_docs(key=Key.obsp.spatial_conn())
+@deprecated_params({"backend": "1.10.0"})
 def sepal(
     adata: AnnData | SpatialData,
     max_neighs: Literal[4, 6],
@@ -41,7 +42,6 @@
     use_raw: bool = False,
     copy: bool = False,
     n_jobs: int | None = None,
-    backend: str = "loky",
     show_progress_bar: bool = True,
 ) -> pd.DataFrame | None:
     """
@@ -78,8 +78,8 @@
     use_raw
         Whether to access :attr:`anndata.AnnData.raw`.
     %(copy)s
-    %(parallelize)s
-
+    %(n_jobs)s
+    %(show_progress_bar)s
     Returns
     -------
     If ``copy = True``, returns a :class:`pandas.DataFrame` with the sepal scores.
@@ -126,24 +126,22 @@
     vals, genes = _extract_expression(adata, genes=genes, use_raw=use_raw, layer=layer)
     start = logg.info(f"Calculating sepal score for `{len(genes)}` genes using `{n_jobs}` core(s)")
 
-    score = parallelize(
-        _score_helper,
-        collection=np.arange(len(genes)).tolist(),
-        extractor=np.hstack,
-        use_ixs=False,
+    use_hex = max_neighs == 6
+
+    if issparse(vals):
+        vals = csc_matrix(vals)
+    score = _diffusion_genes(
+        vals,
+        use_hex,
+        n_iter,
+        sat,
+        sat_idx,
+        unsat,
+        unsat_idx,
+        dt,
+        thresh,
         n_jobs=n_jobs,
-        backend=backend,
         show_progress_bar=show_progress_bar,
-    )(
-        vals=vals,
-        max_neighs=max_neighs,
-        n_iter=n_iter,
-        sat=sat,
-        sat_idx=sat_idx,
-        unsat=unsat,
-        unsat_idx=unsat_idx,
-        dt=dt,
-        thresh=thresh,
     )
 
     key_added = "sepal_score"
@@ -160,10 +158,9 @@
     _save_data(adata, attr="uns", key=key_added, data=sepal_score, time=start)
 
 
-def _score_helper(
-    ixs: Sequence[int],
-    vals: spmatrix | NDArrayA,
-    max_neighs: int,
+def _diffusion_genes(
+    vals: NDArrayA | spmatrix,
+    use_hex: bool,
     n_iter: int,
     sat: NDArrayA,
     sat_idx: NDArrayA,
@@ -171,58 +168,70 @@
     unsat_idx: NDArrayA,
     dt: float,
     thresh: float,
-    queue: SigQueue | None = None,
+    n_jobs: int,
+    show_progress_bar: bool = True,
 ) -> NDArrayA:
-    if max_neighs == 4:
-        fun = _laplacian_rect
-    elif max_neighs == 6:
-        fun = _laplacian_hex
-    else:
-        raise NotImplementedError(f"Laplacian for `{max_neighs}` neighbors is not yet implemented.")
-
-    score = []
-    for i in ixs:
-        if isinstance(vals, spmatrix):
-            conc = vals[:, i].toarray().flatten()  # Safe to call toarray()
-        else:
-            conc = vals[:, i].copy()  # vals is assumed to be a NumPy array here
-
-        time_iter = _diffusion(conc, fun, n_iter, sat, sat_idx, unsat, unsat_idx, dt=dt, thresh=thresh)
-        score.append(dt * time_iter)
-
-        if queue is not None:
-            queue.put(Signal.UPDATE)
+    """Run diffusion for each gene column, parallelised across threads."""
 
-    if queue is not None:
-        queue.put(Signal.FINISH)
+    sparse = issparse(vals)
 
-    return np.array(score)
+    def _process_gene(i: int) -> float:
+        if sparse:
+            conc = np.ascontiguousarray(vals[:, i].toarray().ravel(), dtype=np.float64)
+        else:
+            conc = vals[:, i].copy()  # copy: _diffusion mutates conc in place
+        time_iter = _diffusion(
+            conc,
+            use_hex,
+            n_iter,
+            sat,
+            sat_idx,
+            unsat,
+            unsat_idx,
+            dt,
+            thresh,
+        )
+        return dt * time_iter
+
+    scores = thread_map(
+        _process_gene,
+        range(vals.shape[1]),
+        n_jobs=n_jobs,
+        show_progress_bar=show_progress_bar,
+        unit="gene",
+    )
+    return np.array(scores)
 
 
-@njit(fastmath=True)
+@njit(fastmath=True, nogil=True)
 def _diffusion(
     conc: NDArrayA,
-    laplacian: Callable[[NDArrayA, NDArrayA], float],
+    use_hex: bool,
     n_iter: int,
     sat: NDArrayA,
     sat_idx: NDArrayA,
     unsat: NDArrayA,
     unsat_idx: NDArrayA,
-    dt: float = 0.001,
-    thresh: float = 1e-8,
+    dt: float,
+    thresh: float,
 ) -> float:
     """Simulate diffusion process on a regular graph."""
-    sat_shape, conc_shape = sat.shape[0], conc.shape[0]
+    sat_shape = sat.shape[0]
+    n_cells = conc.shape[0]
     entropy_arr = np.zeros(n_iter)
-    prev_ent = 1.0
     nhood = np.zeros(sat_shape)
+    dcdt = np.zeros(n_cells)
+    prev_ent = 1.0
 
     for i in range(n_iter):
         for j in range(sat_shape):
             nhood[j] = np.sum(conc[sat_idx[j]])
-        d2 = laplacian(conc[sat], nhood)
+        if use_hex:
+            d2 = _laplacian_hex(conc[sat], nhood)
+        else:
+            d2 = _laplacian_rect(conc[sat], nhood)
 
-        dcdt = np.zeros(conc_shape)
+        dcdt[:] = 0.0
         dcdt[sat] = d2
         conc[sat] += dcdt[sat] * dt
         conc[unsat] += dcdt[unsat_idx] * dt
diff --git a/tests/graph/test_sepal.py b/tests/graph/test_sepal.py
index 8fb711f5e..a54f8d7ac 100644
--- a/tests/graph/test_sepal.py
+++ b/tests/graph/test_sepal.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import numba
 import numpy as np
 from anndata import AnnData
 from pandas.testing import assert_frame_equal
@@ -16,8 +17,14 @@ def test_sepal_seq_par(adata: AnnData):
     adata.var["highly_variable"] = rng.choice([True, False], size=adata.var_names.shape, p=[0.005, 0.995])
 
     sepal(adata, max_neighs=6)
-    df = sepal(adata, max_neighs=6, copy=True, n_jobs=1)
-    df_parallel = sepal(adata, max_neighs=6, copy=True, n_jobs=2)
+
+    prev_threads = numba.get_num_threads()
+    try:
+        numba.set_num_threads(1)
+        df = sepal(adata, max_neighs=6, copy=True)
+    finally:
+        numba.set_num_threads(prev_threads)
+    df_parallel = sepal(adata, max_neighs=6, copy=True)
 
     idx_df = df.index.values
     idx_adata = adata[:, adata.var.highly_variable.values].var_names.values
@@ -40,8 +47,13 @@ def test_sepal_square_seq_par(adata_squaregrid: AnnData):
     rng = np.random.default_rng(42)
     adata.var["highly_variable"] = rng.choice([True, False], size=adata.var_names.shape)
 
-    sepal(adata, max_neighs=4)
-    df_parallel = sepal(adata, copy=True, n_jobs=2, max_neighs=4)
+    prev_threads = numba.get_num_threads()
+    try:
+        numba.set_num_threads(1)
+        sepal(adata, max_neighs=4)
+    finally:
+        numba.set_num_threads(prev_threads)
+    df_parallel = sepal(adata, copy=True, max_neighs=4)
 
     idx_df = df_parallel.index.values
     idx_adata = adata[:, adata.var.highly_variable.values].var_names.values