diff --git a/docs/conf.py b/docs/conf.py index 0f413dfac..33ac09e1a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -49,7 +49,7 @@ numpy=("https://numpy.org/doc/stable", None), statsmodels=("https://www.statsmodels.org/stable", None), scipy=("https://docs.scipy.org/doc/scipy", None), - pandas=("https://pandas.pydata.org/pandas-docs/stable", None), + pandas=("https://pandas.pydata.org/docs", None), anndata=("https://anndata.readthedocs.io/en/stable", None), scanpy=("https://scanpy.readthedocs.io/en/stable", None), matplotlib=("https://matplotlib.org/stable", None), diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py index 3c5c86392..e3e7e1457 100644 --- a/src/squidpy/_utils.py +++ b/src/squidpy/_utils.py @@ -18,6 +18,7 @@ import numpy as np import xarray as xr from spatialdata.models import Image2DModel, Labels2DModel +from tqdm.auto import tqdm __all__ = ["singledispatchmethod", "Signal", "SigQueue", "NDArray", "NDArrayA"] @@ -228,6 +229,46 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper +def thread_map( + fn: Callable[..., Any], + items: Iterable[Any], + *, + n_jobs: int = 1, + show_progress_bar: bool = False, + unit: str = "item", + total: int | None = None, +) -> list[Any]: + """Map *fn* over *items* using a thread pool with an optional progress bar. + + Parameters + ---------- + fn + Callable applied to each element of *items*. + items + Iterable of inputs passed one-by-one to *fn*. + n_jobs + Number of worker threads. ``1`` runs sequentially (no pool overhead). + show_progress_bar + Whether to display a ``tqdm`` progress bar. + unit + Label shown next to the ``tqdm`` counter. + total + Length hint passed to ``tqdm`` when *items* has no ``__len__``. + + Returns + ------- + list + Results in the same order as *items*. + """ + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor(max_workers=n_jobs) as pool: + it = pool.map(fn, items) + if show_progress_bar and tqdm is not None: + it = tqdm(it, total=len(items), unit=unit) + return list(it) + + def _get_n_cores(n_cores: int | None) -> int: """ Make number of cores a positive integer. diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index 242ec16f2..cf53511d3 100644 --- a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -5,7 +5,6 @@ from abc import ABC from collections import namedtuple from collections.abc import Iterable, Mapping, Sequence -from functools import partial from itertools import product from types import MappingProxyType from typing import TYPE_CHECKING, Any, Literal, TypeAlias @@ -13,14 +12,16 @@ import numpy as np import pandas as pd from anndata import AnnData +from numba import njit from scanpy import logging as logg from scipy.sparse import csc_matrix from spatialdata import SpatialData +from tqdm.auto import tqdm from squidpy._constants._constants import ComplexPolicy, CorrAxis from squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs -from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize +from squidpy._utils import NDArrayA, _get_n_cores, deprecated_params, thread_map from squidpy._validators import assert_positive, check_tuple_needles from squidpy.gr._utils import ( _assert_categorical_obs, @@ -41,102 +42,6 @@ TempResult = namedtuple("TempResult", ["means", "pvalues"]) -_template = """ -from __future__ import annotations - -from numba import njit, prange -import numpy as np - -@njit(parallel={parallel}, cache=False, fastmath=False) -def _test_{n_cls}_{ret_means}_{parallel}( - interactions: NDArrayA[np.uint32], - interaction_clusters: NDArrayA[np.uint32], - data: NDArrayA[np.float64], - clustering: NDArrayA[np.uint32], - mean: NDArrayA[np.float64], - mask: NDArrayA[np.bool_], - res: NDArrayA[np.float64], - {args} -) -> None: - - {init} - {loop} - {finalize} - - for i in prange(len(interactions)): - rec, lig = interactions[i] - for j in prange(len(interaction_clusters)): - c1, c2 = interaction_clusters[j] - m1, m2 = mean[rec, c1], mean[lig, c2] - - if np.isnan(res[i, j]): - continue - - if m1 > 0 and m2 > 0: - {set_means} - if mask[rec, c1] and mask[lig, c2]: - # both rec, lig are sufficiently expressed in c1, c2 - res[i, j] += (groups[c1, rec] + groups[c2, lig]) > (m1 + m2) - else: - res[i, j] = np.nan - else: - # res_means is initialized with 0s - res[i, j] = np.nan -""" - - -def _create_template(n_cls: int, return_means: bool = False, parallel: bool = True) -> str: - if n_cls <= 0: - raise ValueError(f"Expected number of clusters to be positive, found `{n_cls}`.") - - rng = range(n_cls) - init = "".join( - f""" - g{i} = np.zeros((data.shape[1],), dtype=np.float64); s{i} = 0""" - for i in rng - ) - init += """ - error = False - """ - - loop_body = """ - if cl == 0: - g0 += data[row] - s0 += 1""" - loop_body = loop_body + "".join( - f""" - elif cl == {i}: - g{i} += data[row] - s{i} += 1""" - for i in range(1, n_cls) - ) - loop = f""" - for row in prange(data.shape[0]): - cl = clustering[row] - {loop_body} - else: - error = True - """ - finalize = ", ".join(f"g{i} / s{i}" for i in rng) - finalize = f"groups = np.stack(({finalize}))" - - if return_means: - args = "res_means: NDArrayA, # [np.float64]" - set_means = "res_means[i, j] = (m1 + m2) / 2.0" - else: - args = set_means = "" - - return _template.format( - n_cls=n_cls, - parallel=bool(parallel), - ret_means=int(return_means), - args=args, - init=init, - loop=loop, - finalize=finalize, - set_means=set_means, - ) - def _fdr_correct( pvals: pd.DataFrame, @@ -325,8 +230,8 @@ def test( alpha: float = 0.05, copy: bool = False, key_added: str | None = None, - numba_parallel: bool | None = None, - **kwargs: Any, + n_jobs: int | None = None, + show_progress_bar: bool = True, ) -> Mapping[str, pd.DataFrame] | None: """ Perform the permutation test as described in :cite:`cellphonedb`. @@ -354,8 +259,6 @@ def test( key_added Key in :attr:`anndata.AnnData.uns` where the result is stored if ``copy = False``. If `None`, ``'{{cluster_key}}_ligrec'`` will be used. - %(numba_parallel)s - %(parallelize)s Returns ------- @@ -408,7 +311,7 @@ def test( # much faster than applymap (tested on 1M interactions) interactions_ = np.vectorize(lambda g: gene_mapper[g])(interactions.values) - n_jobs = _get_n_cores(kwargs.pop("n_jobs", None)) + n_jobs = _get_n_cores(n_jobs) start = logg.info( f"Running `{n_perms}` permutations on `{len(interactions)}` interactions " f"and `{len(clusters)}` cluster combinations using `{n_jobs}` core(s)" @@ -421,8 +324,7 @@ def test( n_perms=n_perms, seed=seed, n_jobs=n_jobs, - numba_parallel=numba_parallel, - **kwargs, + show_progress_bar=show_progress_bar, ) index = pd.MultiIndex.from_frame(interactions, names=[SOURCE, TARGET]) columns = pd.MultiIndex.from_tuples(clusters, names=["cluster_1", "cluster_2"]) @@ -453,6 +355,7 @@ def test( return res _save_data(self._adata, attr="uns", key=Key.uns.ligrec(cluster_key, key_added), data=res, time=start) + return None def _trim_data(self) -> None: """Subset genes :attr:`_data` to those present in interactions.""" @@ -629,6 +532,7 @@ def prepare( @d.dedent +@deprecated_params({"numba_parallel": "1.10.0", "backend": "1.10.0"}) def ligrec( adata: AnnData | SpatialData, cluster_key: str, @@ -641,7 +545,15 @@ def ligrec( copy: bool = False, key_added: str | None = None, gene_symbols: str | None = None, - **kwargs: Any, + n_perms: int = 1000, + seed: int | None = None, + clusters: Cluster_t | None = None, + alpha: float = 0.05, + n_jobs: int | None = None, + show_progress_bar: bool = True, + interactions_params: Mapping[str, Any] = MappingProxyType({}), + transmitter_params: Mapping[str, Any] = MappingProxyType({"categories": "ligand"}), + receiver_params: Mapping[str, Any] = MappingProxyType({"categories": "receptor"}), ) -> Mapping[str, pd.DataFrame] | None: """ %(PT_test.full_desc)s @@ -658,24 +570,74 @@ def ligrec( ------- %(ligrec_test_returns)s """ # noqa: D400 - if isinstance(adata, SpatialData): - adata = adata.table with _genesymbols(adata, key=gene_symbols, use_raw=use_raw, make_unique=False): return ( # type: ignore[no-any-return] PermutationTest(adata, use_raw=use_raw) - .prepare(interactions, complex_policy=complex_policy, **kwargs) + .prepare( + interactions, + complex_policy=complex_policy, + interactions_params=interactions_params, + transmitter_params=transmitter_params, + receiver_params=receiver_params, + ) .test( cluster_key=cluster_key, + clusters=clusters, + n_perms=n_perms, threshold=threshold, + seed=seed, corr_method=corr_method, corr_axis=corr_axis, + alpha=alpha, copy=copy, key_added=key_added, - **kwargs, + n_jobs=n_jobs, + show_progress_bar=show_progress_bar, ) ) +@njit(nogil=True, cache=True) +def _score_permutation( + data: NDArrayA, + perm: NDArrayA, + inv_counts: NDArrayA, + mean_obs: NDArrayA, + interactions: NDArrayA, + interaction_clusters: NDArrayA, + valid: NDArrayA, + local_counts: NDArrayA, +) -> None: + """Score a single permutation: compute group means and accumulate p-value counts.""" + n_cells = data.shape[0] + n_genes = data.shape[1] + n_cls = mean_obs.shape[0] + + groups = np.zeros((n_cls, n_genes), dtype=np.float64) + for cell in range(n_cells): + cl = perm[cell] + for g in range(n_genes): + groups[cl, g] += data[cell, g] + for k in range(n_cls): + inv_c = inv_counts[k] + for g in range(n_genes): + groups[k, g] *= inv_c + + n_inter = interactions.shape[0] + n_cpairs = interaction_clusters.shape[0] + for i in range(n_inter): + r = interactions[i, 0] + l = interactions[i, 1] + for j in range(n_cpairs): + if valid[i, j]: + a = interaction_clusters[j, 0] + b = interaction_clusters[j, 1] + shuf = groups[a, r] + groups[b, l] + obs = mean_obs[a, r] + mean_obs[b, l] + if shuf > obs: + local_counts[i, j] += 1 + + @d.dedent def _analysis( data: pd.DataFrame, @@ -685,14 +647,11 @@ def _analysis( n_perms: int = 1000, seed: int | None = None, n_jobs: int = 1, - numba_parallel: bool | None = None, - **kwargs: Any, + show_progress_bar: bool = True, ) -> TempResult: """ Run the analysis as described in :cite:`cellphonedb`. - This function runs the mean, percent and shuffled analysis. - Parameters ---------- data @@ -706,11 +665,9 @@ def _analysis( %(n_perms)s %(seed)s n_jobs - Number of parallel jobs to launch. - numba_parallel - Whether to use :func:`numba.prange` or not. If `None`, it's determined automatically. - kwargs - Keyword arguments for :func:`squidpy._utils.parallelize`, such as ``n_jobs`` or ``backend``. + Number of threads to use. + show_progress_bar + Whether to show the progress bar. Returns ------- @@ -719,145 +676,67 @@ def _analysis( - `'means'` - array of shape `(n_interactions, n_interaction_clusters)` containing the means. - `'pvalues'` - array of shape `(n_interactions, n_interaction_clusters)` containing the p-values. """ - - def extractor(res: Sequence[TempResult]) -> TempResult: - assert len(res) == n_jobs, f"Expected to find `{n_jobs}` results, found `{len(res)}`." - - meanss: list[NDArrayA] = [r.means for r in res if r.means is not None] - assert len(meanss) == 1, f"Only `1` job should've calculated the means, but found `{len(meanss)}`." - means = meanss[0] - if TYPE_CHECKING: - assert isinstance(means, np.ndarray) - - pvalues = np.sum([r.pvalues for r in res if r.pvalues is not None], axis=0) / float(n_perms) - assert means.shape == pvalues.shape, f"Means and p-values differ in shape: `{means.shape}`, `{pvalues.shape}`." - - return TempResult(means=means, pvalues=pvalues) - clustering = np.array(data["clusters"].values, dtype=np.int32) - # densify the data earlier to avoid concatenating sparse arrays - # with multiple fill values: '[0.0, nan]' (which leads to PerformanceWarning) data = data.astype({c: np.float64 for c in data.columns if c != "clusters"}) groups = data.groupby("clusters", observed=True) - mean = groups.mean().values.T # (n_genes, n_clusters) + mean_obs = groups.mean().values # (n_clusters, n_genes) # see https://github.com/scverse/squidpy/pull/991#issuecomment-2888506296 # for why we need to cast to int64 here mask = groups.apply( lambda c: ((c > 0).astype(np.int64).sum() / len(c)) >= threshold - ).values.T # (n_genes, n_clusters) - - # (n_cells, n_genes) - data = np.array(data[data.columns.difference(["clusters"])].values, dtype=np.float64, order="C") - # all 3 should be C contiguous - return parallelize( # type: ignore[no-any-return] - _analysis_helper, - np.arange(n_perms, dtype=np.int32).tolist(), - n_jobs=n_jobs, - unit="permutation", - extractor=extractor, - **kwargs, - )( - data, - mean, - mask, - interactions, - interaction_clusters=interaction_clusters, - clustering=clustering, - seed=seed, - numba_parallel=numba_parallel, - ) - - -def _analysis_helper( - perms: NDArrayA, - data: NDArrayA, - mean: NDArrayA, - mask: NDArrayA, - interactions: NDArrayA, - interaction_clusters: NDArrayA, - clustering: NDArrayA, - seed: int | None = None, - numba_parallel: bool | None = None, - queue: SigQueue | None = None, -) -> TempResult: - """ - Run the results of mean, percent and shuffled analysis. - - Parameters - ---------- - perms - Permutation indices. Only used to set the ``seed``. - data - Array of shape `(n_cells, n_genes)`. - mean - Array of shape `(n_genes, n_clusters)` representing mean expression per cluster. - mask - Array of shape `(n_genes, n_clusters)` containing `True` if the a gene within a cluster is - expressed at least in ``threshold`` percentage of cells. - interactions - Array of shape `(n_interactions, 2)`. - interaction_clusters - Array of shape `(n_interaction_clusters, 2)`. - clustering - Array of shape `(n_cells,)` containing the original clustering. - seed - Random seed for :class:`numpy.random.RandomState`. - numba_parallel - Whether to use :func:`numba.prange` or not. If `None`, it's determined automatically. - queue - Signalling queue to update progress bar. + ).values # (n_clusters, n_genes) + + counts = groups.size().values.astype(np.float64) + inv_counts = 1.0 / np.maximum(counts, 1) + + data_arr = np.array(data[data.columns.difference(["clusters"])].values, dtype=np.float64, order="C") + + interactions = np.array(interactions, dtype=np.int32) + interaction_clusters = np.array(interaction_clusters, dtype=np.int32) + rec = interactions[:, 0] + lig = interactions[:, 1] + c1 = interaction_clusters[:, 0] + c2 = interaction_clusters[:, 1] + + obs_score = mean_obs[c1, :][:, rec].T + mean_obs[c2, :][:, lig].T + nonzero = (mean_obs[c1, :][:, rec].T > 0) & (mean_obs[c2, :][:, lig].T > 0) + valid = nonzero & mask[c1, :][:, rec].T & mask[c2, :][:, lig].T + res_means = np.where(nonzero, obs_score / 2.0, 0.0) + + n_inter = len(rec) + n_cpairs = len(c1) + + base_chunk, remainder = divmod(n_perms, n_jobs) + chunk_sizes = np.full(n_jobs, base_chunk, dtype=np.int64) + chunk_sizes[:remainder] += 1 + + pbar = tqdm(total=n_perms, unit="permutation", disable=not show_progress_bar) + + def _worker(t: int) -> NDArrayA: + local_counts = np.zeros((n_inter, n_cpairs), dtype=np.int64) + rs = np.random.RandomState(None if seed is None else t + seed) + perm = clustering.copy() + for _ in range(chunk_sizes[t]): + rs.shuffle(perm) + _score_permutation( + data_arr, + perm, + inv_counts, + mean_obs, + interactions, + interaction_clusters, + valid, + local_counts, + ) + pbar.update(1) + return local_counts - Returns - ------- - Tuple of the following format: + thread_counts = thread_map(_worker, range(n_jobs), n_jobs=n_jobs) + pbar.close() - - `'means'` - array of shape `(n_interactions, n_interaction_clusters)` containing the true test - statistic. It is `None` if ``min(perms)!=0`` so that only 1 worker calculates it. - - `'pvalues'` - array of shape `(n_interactions, n_interaction_clusters)` containing `np.sum(T0 > T)` - where `T0` is the test statistic under null hypothesis and `T` is the true test statistic. - """ - rs = np.random.RandomState(None if seed is None else perms[0] + seed) - - clustering = clustering.copy() - n_cls = mean.shape[1] - return_means = np.min(perms) == 0 - - # ideally, these would be both sparse array, but there is no numba impl. (sparse.COO is read-only and very limited) - # keep it f64, because we're setting NaN - res = np.zeros((len(interactions), len(interaction_clusters)), dtype=np.float64) - numba_parallel = ( - (np.prod(res.shape) >= 2**20 or clustering.shape[0] >= 2**15) if numba_parallel is None else numba_parallel # type: ignore[assignment] - ) - - fn_key = f"_test_{n_cls}_{int(return_means)}_{bool(numba_parallel)}" - if fn_key not in globals(): - exec( - compile(_create_template(n_cls, return_means=return_means, parallel=numba_parallel), "", "exec"), # type: ignore[arg-type] - globals(), - ) - _test = globals()[fn_key] + pval_counts = np.sum(thread_counts, axis=0) + pvalues = pval_counts.astype(np.float64) / n_perms + pvalues[~valid] = np.nan - if return_means: - res_means: NDArrayA | None = np.zeros((len(interactions), len(interaction_clusters)), dtype=np.float64) - test = partial(_test, res_means=res_means) - else: - res_means = None - test = _test - - for _ in perms: - rs.shuffle(clustering) - error = test(interactions, interaction_clusters, data, clustering, mean, mask, res=res) - if error: - raise ValueError("In the execution of the numba function, an unhandled case was encountered. ") - # This is mainly to avoid a numba warning - # Otherwise, the numba function wouldn't be - # executed in parallel - # See: https://github.com/scverse/squidpy/issues/994 - if queue is not None: - queue.put(Signal.UPDATE) - - if queue is not None: - queue.put(Signal.FINISH) - - return TempResult(means=res_means, pvalues=res) + return TempResult(means=res_means, pvalues=pvalues) diff --git a/tests/_data/ligrec_pvalues_reference.pickle b/tests/_data/ligrec_pvalues_reference.pickle new file mode 100644 index 000000000..ab9489038 Binary files /dev/null and b/tests/_data/ligrec_pvalues_reference.pickle differ diff --git a/tests/conftest.py b/tests/conftest.py index 83d405d8d..45cb276c4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -257,10 +257,9 @@ def complexes(adata: AnnData) -> Sequence[tuple[str, str]]: @pytest.fixture(scope="session") -def ligrec_no_numba() -> Mapping[str, pd.DataFrame]: - with open("tests/_data/ligrec_no_numba.pickle", "rb") as fin: - data = pickle.load(fin) - return {"means": data[0], "pvalues": data[1], "metadata": data[2]} +def ligrec_pvalues_reference() -> Mapping[str, pd.DataFrame]: + with open("tests/_data/ligrec_pvalues_reference.pickle", "rb") as fin: + return pickle.load(fin) @pytest.fixture(scope="session") diff --git a/tests/graph/test_ligrec.py b/tests/graph/test_ligrec.py index 099ecb1b6..a8bbc4ccc 100644 --- a/tests/graph/test_ligrec.py +++ b/tests/graph/test_ligrec.py @@ -3,7 +3,6 @@ import sys from collections.abc import Mapping, Sequence from itertools import product -from time import time from typing import TYPE_CHECKING import numpy as np @@ -313,39 +312,6 @@ def test_reproducibility_cores(self, adata: AnnData, interactions: Interactions_ assert not np.allclose(r3["pvalues"], r1["pvalues"]) assert not np.allclose(r3["pvalues"], r2["pvalues"]) - def test_reproducibility_numba_parallel_off(self, adata: AnnData, interactions: Interactions_t): - t1 = time() - r1 = ligrec( - adata, - _CK, - interactions=interactions, - n_perms=25, - copy=True, - show_progress_bar=False, - seed=42, - numba_parallel=False, - ) - t1 = time() - t1 - - t2 = time() - r2 = ligrec( - adata, - _CK, - interactions=interactions, - n_perms=25, - copy=True, - show_progress_bar=False, - seed=42, - numba_parallel=True, - ) - t2 = time() - t2 - - assert r1 is not r2 - # for such a small data, overhead from parallelization is too high - assert t1 <= t2, (t1, t2) - np.testing.assert_allclose(r1["means"], r2["means"]) - np.testing.assert_allclose(r1["pvalues"], r2["pvalues"]) - def test_paul15_correct_means(self, paul15: AnnData, paul15_means: pd.DataFrame): res = ligrec( paul15, @@ -364,20 +330,22 @@ def test_paul15_correct_means(self, paul15: AnnData, paul15_means: pd.DataFrame) np.testing.assert_array_equal(res["means"].columns, paul15_means.columns) np.testing.assert_allclose(res["means"].values, paul15_means.values) - def test_reproducibility_numba_off( - self, adata: AnnData, interactions: Interactions_t, ligrec_no_numba: Mapping[str, pd.DataFrame] + def test_pvalues_reference( + self, adata: AnnData, interactions: Interactions_t, ligrec_pvalues_reference: Mapping[str, pd.DataFrame] ): r = ligrec( - adata, _CK, interactions=interactions, n_perms=5, copy=True, show_progress_bar=False, seed=42, n_jobs=1 + adata, _CK, interactions=interactions, n_perms=25, copy=True, show_progress_bar=False, seed=42, n_jobs=1 ) - np.testing.assert_array_equal(r["means"].index, ligrec_no_numba["means"].index) - np.testing.assert_array_equal(r["means"].columns, ligrec_no_numba["means"].columns) - np.testing.assert_array_equal(r["pvalues"].index, ligrec_no_numba["pvalues"].index) - np.testing.assert_array_equal(r["pvalues"].columns, ligrec_no_numba["pvalues"].columns) + np.testing.assert_array_equal(r["means"].index, ligrec_pvalues_reference["means"].index) + np.testing.assert_array_equal(r["means"].columns, ligrec_pvalues_reference["means"].columns) + np.testing.assert_array_equal(r["pvalues"].index, ligrec_pvalues_reference["pvalues"].index) + np.testing.assert_array_equal(r["pvalues"].columns, ligrec_pvalues_reference["pvalues"].columns) - np.testing.assert_allclose(r["means"], ligrec_no_numba["means"]) - np.testing.assert_allclose(r["pvalues"], ligrec_no_numba["pvalues"]) - np.testing.assert_array_equal(np.where(np.isnan(r["pvalues"])), np.where(np.isnan(ligrec_no_numba["pvalues"]))) + np.testing.assert_allclose(r["means"], ligrec_pvalues_reference["means"]) + np.testing.assert_allclose(r["pvalues"], ligrec_pvalues_reference["pvalues"]) + np.testing.assert_array_equal( + np.where(np.isnan(r["pvalues"])), np.where(np.isnan(ligrec_pvalues_reference["pvalues"])) + ) def test_logging(self, adata: AnnData, interactions: Interactions_t, capsys): s.logfile = sys.stderr @@ -420,7 +388,6 @@ def test_non_uniqueness(self, adata: AnnData, interactions: Interactions_t): copy=True, show_progress_bar=False, seed=42, - numba_parallel=False, ) assert len(res["pvalues"]) == len(expected)