From 6230aedd20651a2a570fe84f68746e14c5d4b2a5 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 11 Mar 2026 15:00:07 +0100 Subject: [PATCH 01/13] refactor ligrec: replace parallelize with threading + numba nogil --- src/squidpy/_utils.py | 31 ++ src/squidpy/gr/_ligrec.py | 387 +++++++------------- src/squidpy/gr/_ppatterns.py | 15 +- tests/_data/ligrec_pvalues_reference.pickle | Bin 0 -> 10103 bytes tests/conftest.py | 9 +- tests/graph/test_ligrec.py | 52 +-- tests/graph/test_ppatterns.py | 8 +- 7 files changed, 196 insertions(+), 306 deletions(-) create mode 100644 tests/_data/ligrec_pvalues_reference.pickle diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py index 736c88172..3c5c86392 100644 --- a/src/squidpy/_utils.py +++ b/src/squidpy/_utils.py @@ -278,6 +278,37 @@ def verbosity(level: int) -> Generator[None, None, None]: sc.settings.verbosity = verbosity +def deprecated_params( + params: dict[str, str], +) -> Callable[..., Any]: + """Decorator that warns when deprecated keyword arguments are passed. + + Parameters + ---------- + params + Mapping of deprecated parameter names to the version in which + they will be removed, e.g. ``{"n_jobs": "1.10.0"}``. + """ + + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + for k in list(kwargs): + if k in params: + warnings.warn( + f"Parameter `{k}` of `{func.__name__}()` is deprecated " + f"and has no effect. It will be removed in squidpy v{params[k]}.", + FutureWarning, + stacklevel=2, + ) + kwargs.pop(k) + return func(*args, **kwargs) + + return wrapper + + return decorator + + string_types = (bytes, str) diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index a4beecd8f..4446ac999 100644 --- a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -5,28 +5,32 @@ from abc import ABC from collections import namedtuple from collections.abc import Iterable, Mapping, Sequence -from functools import partial +from concurrent.futures import ThreadPoolExecutor from itertools import product from types import MappingProxyType from typing import TYPE_CHECKING, Any, Literal, TypeAlias +import numba import numpy as np import pandas as pd from anndata import AnnData +from numba import njit from scanpy import logging as logg from scipy.sparse import csc_matrix from spatialdata import SpatialData +from tqdm.auto import tqdm from squidpy._constants._constants import ComplexPolicy, CorrAxis from squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs -from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize +from squidpy._utils import NDArrayA, deprecated_params from squidpy.gr._utils import ( _assert_categorical_obs, _assert_positive, _check_tuple_needles, _genesymbols, _save_data, + extract_adata, ) __all__ = ["ligrec", "PermutationTest"] @@ -42,102 +46,6 @@ TempResult = namedtuple("TempResult", ["means", "pvalues"]) -_template = """ -from __future__ import annotations - -from numba import njit, prange -import numpy as np - -@njit(parallel={parallel}, cache=False, fastmath=False) -def _test_{n_cls}_{ret_means}_{parallel}( - interactions: NDArrayA[np.uint32], - interaction_clusters: NDArrayA[np.uint32], - data: NDArrayA[np.float64], - clustering: NDArrayA[np.uint32], - mean: NDArrayA[np.float64], - mask: NDArrayA[np.bool_], - res: NDArrayA[np.float64], - {args} -) -> None: - - {init} - {loop} - {finalize} - - for i in prange(len(interactions)): - rec, lig = interactions[i] - for j in prange(len(interaction_clusters)): - c1, c2 = interaction_clusters[j] - m1, m2 = mean[rec, c1], mean[lig, c2] - - if np.isnan(res[i, j]): - continue - - if m1 > 0 and m2 > 0: - {set_means} - if mask[rec, c1] and mask[lig, c2]: - # both rec, lig are sufficiently expressed in c1, c2 - res[i, j] += (groups[c1, rec] + groups[c2, lig]) > (m1 + m2) - else: - res[i, j] = np.nan - else: - # res_means is initialized with 0s - res[i, j] = np.nan -""" - - -def _create_template(n_cls: int, return_means: bool = False, parallel: bool = True) -> str: - if n_cls <= 0: - raise ValueError(f"Expected number of clusters to be positive, found `{n_cls}`.") - - rng = range(n_cls) - init = "".join( - f""" - g{i} = np.zeros((data.shape[1],), dtype=np.float64); s{i} = 0""" - for i in rng - ) - init += """ - error = False - """ - - loop_body = """ - if cl == 0: - g0 += data[row] - s0 += 1""" - loop_body = loop_body + "".join( - f""" - elif cl == {i}: - g{i} += data[row] - s{i} += 1""" - for i in range(1, n_cls) - ) - loop = f""" - for row in prange(data.shape[0]): - cl = clustering[row] - {loop_body} - else: - error = True - """ - finalize = ", ".join(f"g{i} / s{i}" for i in rng) - finalize = f"groups = np.stack(({finalize}))" - - if return_means: - args = "res_means: NDArrayA, # [np.float64]" - set_means = "res_means[i, j] = (m1 + m2) / 2.0" - else: - args = set_means = "" - - return _template.format( - n_cls=n_cls, - parallel=bool(parallel), - ret_means=int(return_means), - args=args, - init=init, - loop=loop, - finalize=finalize, - set_means=set_means, - ) - def _fdr_correct( pvals: pd.DataFrame, @@ -326,8 +234,8 @@ def test( alpha: float = 0.05, copy: bool = False, key_added: str | None = None, - numba_parallel: bool | None = None, - **kwargs: Any, + n_jobs: int | None = None, + show_progress_bar: bool = True, ) -> Mapping[str, pd.DataFrame] | None: """ Perform the permutation test as described in :cite:`cellphonedb`. @@ -355,8 +263,6 @@ def test( key_added Key in :attr:`anndata.AnnData.uns` where the result is stored if ``copy = False``. If `None`, ``'{{cluster_key}}_ligrec'`` will be used. - %(numba_parallel)s - %(parallelize)s Returns ------- @@ -409,7 +315,9 @@ def test( # much faster than applymap (tested on 1M interactions) interactions_ = np.vectorize(lambda g: gene_mapper[g])(interactions.values) - n_jobs = _get_n_cores(kwargs.pop("n_jobs", None)) + if n_jobs is None: + n_jobs = numba.get_num_threads() + n_jobs = max(1, min(n_jobs, n_perms)) start = logg.info( f"Running `{n_perms}` permutations on `{len(interactions)}` interactions " f"and `{len(clusters)}` cluster combinations using `{n_jobs}` core(s)" @@ -422,8 +330,7 @@ def test( n_perms=n_perms, seed=seed, n_jobs=n_jobs, - numba_parallel=numba_parallel, - **kwargs, + show_progress_bar=show_progress_bar, ) index = pd.MultiIndex.from_frame(interactions, names=[SOURCE, TARGET]) columns = pd.MultiIndex.from_tuples(clusters, names=["cluster_1", "cluster_2"]) @@ -454,6 +361,7 @@ def test( return res _save_data(self._adata, attr="uns", key=Key.uns.ligrec(cluster_key, key_added), data=res, time=start) + return None def _trim_data(self) -> None: """Subset genes :attr:`_data` to those present in interactions.""" @@ -630,6 +538,7 @@ def prepare( @d.dedent +@deprecated_params({"numba_parallel": "1.10.0"}) def ligrec( adata: AnnData | SpatialData, cluster_key: str, @@ -642,7 +551,17 @@ def ligrec( copy: bool = False, key_added: str | None = None, gene_symbols: str | None = None, - **kwargs: Any, + n_perms: int = 1000, + seed: int | None = None, + clusters: Cluster_t | None = None, + alpha: float = 0.05, + n_jobs: int | None = None, + show_progress_bar: bool = True, + interactions_params: Mapping[str, Any] = MappingProxyType({}), + transmitter_params: Mapping[str, Any] = MappingProxyType({"categories": "ligand"}), + receiver_params: Mapping[str, Any] = MappingProxyType({"categories": "receptor"}), + *, + table_key: str = "table", ) -> Mapping[str, pd.DataFrame] | None: """ %(PT_test.full_desc)s @@ -659,24 +578,75 @@ def ligrec( ------- %(ligrec_test_returns)s """ # noqa: D400 - if isinstance(adata, SpatialData): - adata = adata.table + adata = extract_adata(adata, table_key=table_key) with _genesymbols(adata, key=gene_symbols, use_raw=use_raw, make_unique=False): return ( # type: ignore[no-any-return] PermutationTest(adata, use_raw=use_raw) - .prepare(interactions, complex_policy=complex_policy, **kwargs) + .prepare( + interactions, + complex_policy=complex_policy, + interactions_params=interactions_params, + transmitter_params=transmitter_params, + receiver_params=receiver_params, + ) .test( cluster_key=cluster_key, + clusters=clusters, + n_perms=n_perms, threshold=threshold, + seed=seed, corr_method=corr_method, corr_axis=corr_axis, + alpha=alpha, copy=copy, key_added=key_added, - **kwargs, + n_jobs=n_jobs, + show_progress_bar=show_progress_bar, ) ) +@njit(nogil=True, cache=True) +def _score_permutation( + data: NDArrayA, + perm: NDArrayA, + inv_counts: NDArrayA, + mean_obs: NDArrayA, + interactions: NDArrayA, + interaction_clusters: NDArrayA, + valid: NDArrayA, + local_counts: NDArrayA, +) -> None: + """Score a single permutation: compute group means and accumulate p-value counts.""" + n_cells = data.shape[0] + n_genes = data.shape[1] + n_cls = mean_obs.shape[0] + + groups = np.zeros((n_cls, n_genes), dtype=np.float64) + for cell in range(n_cells): + cl = perm[cell] + for g in range(n_genes): + groups[cl, g] += data[cell, g] + for k in range(n_cls): + inv_c = inv_counts[k] + for g in range(n_genes): + groups[k, g] *= inv_c + + n_inter = interactions.shape[0] + n_cpairs = interaction_clusters.shape[0] + for i in range(n_inter): + r = interactions[i, 0] + l = interactions[i, 1] + for j in range(n_cpairs): + if valid[i, j]: + a = interaction_clusters[j, 0] + b = interaction_clusters[j, 1] + shuf = groups[a, r] + groups[b, l] + obs = mean_obs[a, r] + mean_obs[b, l] + if shuf > obs: + local_counts[i, j] += 1 + + @d.dedent def _analysis( data: pd.DataFrame, @@ -686,14 +656,11 @@ def _analysis( n_perms: int = 1000, seed: int | None = None, n_jobs: int = 1, - numba_parallel: bool | None = None, - **kwargs: Any, + show_progress_bar: bool = True, ) -> TempResult: """ Run the analysis as described in :cite:`cellphonedb`. - This function runs the mean, percent and shuffled analysis. - Parameters ---------- data @@ -707,11 +674,9 @@ def _analysis( %(n_perms)s %(seed)s n_jobs - Number of parallel jobs to launch. - numba_parallel - Whether to use :func:`numba.prange` or not. If `None`, it's determined automatically. - kwargs - Keyword arguments for :func:`squidpy._utils.parallelize`, such as ``n_jobs`` or ``backend``. + Number of threads to use. + show_progress_bar + Whether to show the progress bar. Returns ------- @@ -720,145 +685,67 @@ def _analysis( - `'means'` - array of shape `(n_interactions, n_interaction_clusters)` containing the means. - `'pvalues'` - array of shape `(n_interactions, n_interaction_clusters)` containing the p-values. """ - - def extractor(res: Sequence[TempResult]) -> TempResult: - assert len(res) == n_jobs, f"Expected to find `{n_jobs}` results, found `{len(res)}`." - - meanss: list[NDArrayA] = [r.means for r in res if r.means is not None] - assert len(meanss) == 1, f"Only `1` job should've calculated the means, but found `{len(meanss)}`." - means = meanss[0] - if TYPE_CHECKING: - assert isinstance(means, np.ndarray) - - pvalues = np.sum([r.pvalues for r in res if r.pvalues is not None], axis=0) / float(n_perms) - assert means.shape == pvalues.shape, f"Means and p-values differ in shape: `{means.shape}`, `{pvalues.shape}`." - - return TempResult(means=means, pvalues=pvalues) - clustering = np.array(data["clusters"].values, dtype=np.int32) - # densify the data earlier to avoid concatenating sparse arrays - # with multiple fill values: '[0.0, nan]' (which leads to PerformanceWarning) data = data.astype({c: np.float64 for c in data.columns if c != "clusters"}) groups = data.groupby("clusters", observed=True) - mean = groups.mean().values.T # (n_genes, n_clusters) + mean_obs = groups.mean().values # (n_clusters, n_genes) # see https://github.com/scverse/squidpy/pull/991#issuecomment-2888506296 # for why we need to cast to int64 here mask = groups.apply( lambda c: ((c > 0).astype(np.int64).sum() / len(c)) >= threshold - ).values.T # (n_genes, n_clusters) - - # (n_cells, n_genes) - data = np.array(data[data.columns.difference(["clusters"])].values, dtype=np.float64, order="C") - # all 3 should be C contiguous - return parallelize( # type: ignore[no-any-return] - _analysis_helper, - np.arange(n_perms, dtype=np.int32).tolist(), - n_jobs=n_jobs, - unit="permutation", - extractor=extractor, - **kwargs, - )( - data, - mean, - mask, - interactions, - interaction_clusters=interaction_clusters, - clustering=clustering, - seed=seed, - numba_parallel=numba_parallel, - ) - - -def _analysis_helper( - perms: NDArrayA, - data: NDArrayA, - mean: NDArrayA, - mask: NDArrayA, - interactions: NDArrayA, - interaction_clusters: NDArrayA, - clustering: NDArrayA, - seed: int | None = None, - numba_parallel: bool | None = None, - queue: SigQueue | None = None, -) -> TempResult: - """ - Run the results of mean, percent and shuffled analysis. - - Parameters - ---------- - perms - Permutation indices. Only used to set the ``seed``. - data - Array of shape `(n_cells, n_genes)`. - mean - Array of shape `(n_genes, n_clusters)` representing mean expression per cluster. - mask - Array of shape `(n_genes, n_clusters)` containing `True` if the a gene within a cluster is - expressed at least in ``threshold`` percentage of cells. - interactions - Array of shape `(n_interactions, 2)`. - interaction_clusters - Array of shape `(n_interaction_clusters, 2)`. - clustering - Array of shape `(n_cells,)` containing the original clustering. - seed - Random seed for :class:`numpy.random.RandomState`. - numba_parallel - Whether to use :func:`numba.prange` or not. If `None`, it's determined automatically. - queue - Signalling queue to update progress bar. + ).values # (n_clusters, n_genes) + + counts = groups.size().values.astype(np.float64) + inv_counts = 1.0 / np.maximum(counts, 1) + + data_arr = np.array(data[data.columns.difference(["clusters"])].values, dtype=np.float64, order="C") + + interactions = np.array(interactions, dtype=np.int32) + interaction_clusters = np.array(interaction_clusters, dtype=np.int32) + rec = interactions[:, 0] + lig = interactions[:, 1] + c1 = interaction_clusters[:, 0] + c2 = interaction_clusters[:, 1] + + obs_score = mean_obs[c1, :][:, rec].T + mean_obs[c2, :][:, lig].T + nonzero = (mean_obs[c1, :][:, rec].T > 0) & (mean_obs[c2, :][:, lig].T > 0) + valid = nonzero & mask[c1, :][:, rec].T & mask[c2, :][:, lig].T + res_means = np.where(nonzero, obs_score / 2.0, 0.0) + + n_inter = len(rec) + n_cpairs = len(c1) + + base_chunk, remainder = divmod(n_perms, n_jobs) + chunk_sizes = np.full(n_jobs, base_chunk, dtype=np.int64) + chunk_sizes[:remainder] += 1 + + thread_counts = np.zeros((n_jobs, n_inter, n_cpairs), dtype=np.int64) + pbar = tqdm(total=n_perms, unit="permutation", disable=not show_progress_bar) + + def _worker(t: int) -> None: + rs = np.random.RandomState(None if seed is None else t + seed) + perm = clustering.copy() + for _ in range(chunk_sizes[t]): + rs.shuffle(perm) + _score_permutation( + data_arr, + perm, + inv_counts, + mean_obs, + interactions, + interaction_clusters, + valid, + thread_counts[t], + ) + pbar.update(1) - Returns - ------- - Tuple of the following format: + with ThreadPoolExecutor(max_workers=n_jobs) as pool: + list(pool.map(_worker, range(n_jobs))) + pbar.close() - - `'means'` - array of shape `(n_interactions, n_interaction_clusters)` containing the true test - statistic. It is `None` if ``min(perms)!=0`` so that only 1 worker calculates it. - - `'pvalues'` - array of shape `(n_interactions, n_interaction_clusters)` containing `np.sum(T0 > T)` - where `T0` is the test statistic under null hypothesis and `T` is the true test statistic. - """ - rs = np.random.RandomState(None if seed is None else perms[0] + seed) - - clustering = clustering.copy() - n_cls = mean.shape[1] - return_means = np.min(perms) == 0 - - # ideally, these would be both sparse array, but there is no numba impl. (sparse.COO is read-only and very limited) - # keep it f64, because we're setting NaN - res = np.zeros((len(interactions), len(interaction_clusters)), dtype=np.float64) - numba_parallel = ( - (np.prod(res.shape) >= 2**20 or clustering.shape[0] >= 2**15) if numba_parallel is None else numba_parallel # type: ignore[assignment] - ) - - fn_key = f"_test_{n_cls}_{int(return_means)}_{bool(numba_parallel)}" - if fn_key not in globals(): - exec( - compile(_create_template(n_cls, return_means=return_means, parallel=numba_parallel), "", "exec"), # type: ignore[arg-type] - globals(), - ) - _test = globals()[fn_key] + pval_counts = thread_counts.sum(axis=0) + pvalues = pval_counts.astype(np.float64) / n_perms + pvalues[~valid] = np.nan - if return_means: - res_means: NDArrayA | None = np.zeros((len(interactions), len(interaction_clusters)), dtype=np.float64) - test = partial(_test, res_means=res_means) - else: - res_means = None - test = _test - - for _ in perms: - rs.shuffle(clustering) - error = test(interactions, interaction_clusters, data, clustering, mean, mask, res=res) - if error: - raise ValueError("In the execution of the numba function, an unhandled case was encountered. ") - # This is mainly to avoid a numba warning - # Otherwise, the numba function wouldn't be - # executed in parallel - # See: https://github.com/scverse/squidpy/issues/994 - if queue is not None: - queue.put(Signal.UPDATE) - - if queue is not None: - queue.put(Signal.FINISH) - - return TempResult(means=res_means, pvalues=res) + return TempResult(means=res_means, pvalues=pvalues) diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 292c75994..2080d73fa 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -23,7 +23,7 @@ from squidpy._constants._constants import SpatialAutocorr from squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs -from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize +from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, deprecated_params, parallelize from squidpy.gr._utils import ( _assert_categorical_obs, _assert_connectivity_key, @@ -342,16 +342,13 @@ def _co_occurrence_helper(v_x: NDArrayA, v_y: NDArrayA, v_radium: NDArrayA, labs @d.dedent +@deprecated_params({"n_splits": "1.10.0", "n_jobs": "1.10.0", "backend": "1.10.0", "show_progress_bar": "1.10.0"}) def co_occurrence( adata: AnnData | SpatialData, cluster_key: str, spatial_key: str = Key.obsm.spatial, interval: int | NDArrayA = 50, copy: bool = False, - n_splits: int | None = None, - n_jobs: int | None = None, - backend: str = "loky", - show_progress_bar: bool = True, ) -> tuple[NDArrayA, NDArrayA] | None: """ Compute co-occurrence probability of clusters. @@ -365,10 +362,6 @@ def co_occurrence( Distances interval at which co-occurrence is computed. If :class:`int`, uniformly spaced interval of the given size will be used. %(copy)s - n_splits - Number of splits in which to divide the spatial coordinates in - :attr:`anndata.AnnData.obsm` ``['{spatial_key}']``. - %(parallelize)s Returns ------- @@ -406,9 +399,7 @@ def co_occurrence( # Compute co-occurrence probabilities using the fast numba routine. out = _co_occurrence_helper(spatial_x, spatial_y, interval, labs) - start = logg.info( - f"Calculating co-occurrence probabilities for `{len(interval)}` intervals using `{n_jobs}` core(s) and `{n_splits}` splits" - ) + start = logg.info(f"Calculating co-occurrence probabilities for `{len(interval)}` intervals") if copy: logg.info("Finish", time=start) diff --git a/tests/_data/ligrec_pvalues_reference.pickle b/tests/_data/ligrec_pvalues_reference.pickle new file mode 100644 index 0000000000000000000000000000000000000000..ab94890388599e476d17cbd3d9fd8c6a3c602ae6 GIT binary patch literal 10103 zcmds7du$ZP9o~D{=f{EZ3p=8aDuo!T96oRwr=eMc2?BXV*;HlzJ1Lx~O= z@~}d!s`YYO-e7;;OP4=Nv3*b)9x|x4*x5XhOe;oG*3Ia!oRo(Y1IDXer>FY2Z)SD6 z7kV~lyGVK>Zbox`s5M{8Bu5hc+jT{XV<%{BEpd)78-_e;M$Hk~FqNoXqgG9yU0sXc zL60oFh{Re+iR6ItAhjCuCS~7IE4(q8-e@=2OG&HMoXiZ5j7BAU?ZX*8oyg8nD;Y|E zDru&TOn(|-h3uBpssygG6BA*KcKhP=33@keBR*0!M%C-A&_H^0M8PipLfLG>wmUf2INwQVDk$?i#pj;f}yv2loQF>*1!*bT`am&Et#QA?q2% zq-v|W2w|v;ogi+o{DYm7Q>!a$I zi6ifBx+ofpj=cEeWgq9h=k;Y1`+8UY{iulGtO|)O+iG>D&;FYkWkoJ=>$3#0#HpQX zhgBmDCUjl0S4D}CiD%+gIG#yBLQtY1Q&03Gr$hjreYjTQ$EZfsy6i9wYBOmq=A#10 zSvKC?!YStT7E-I;RJw1+cz64~o3=0C`p0A9?z1gN!%KfI{_N>LoVxi_XT|59|Ixem zU;IG)Z{XsUp5GrAE6z7H)-Cfd~9&%R4oe)q4jGa}@^>5_C>tcD~vet0Y&v7ZFoAr3H(YmsB*7rEtDT!$s~7a|qz5C>T{>)bpk*awhafA+dpKU;t7BT;XAL7w;d zw{sxk>Bq@HR>t#`;^^X-bhcUgpT z>*&MhoaH-vVjtvZxt2uSA)AfT5z#Mh^>sn{)XR}e`3NFpo+k)?wPN8F#NA>hAv&OQAT`@Ne_>fE7mW&~dxuZCnsdruu%P<{jebZSj3AocDL_}Km%#Qpbb z>&JGm_u{+u{o8!OE^M27Q>!NpWuNtt&OU$stPP~m-K^K3zUkj@8cCBoG_zFS)W17T zq`CYEApcYE2hC)mI|Q>(U$G7rl9sa1y1(I&kgvWL((0TwO}D{$y~TNkvyan?8I8*( z7z5Q(Qh7k)C+Pgs9&c%%qoL_$JWunc=pa<+$__<0aSm;TUl6o52v*41Bsc)ZLy2Q; z2p(KG)_%!xthIb!D>zY$c%sGy_}Z!ETLt&$;JyX!Td9Jl)L@di8~u951mi-A(F`7| z8JJZ722g4pP!^vB0f2TJe1qd_B7D%gFs#L}4nsEtoJ6|=K3G+JidAe1#R~PO2H?aF z-HQirWvffXCJ*h}8W^V3xxE#%D+oTk{CEZM3gTsNW<`bojIG@ZV~uyW3O;`zSSnyo z#i}w>Mml8-CmG z-bknw>9}R}s*aBK_3f-4>hA09jZrJu_vKh8H~@Y3d~rj}&LShY!c1ih@B|!?yqrBP z8$*z7dv-jN`yh`CT@f*AO0R{OO&?cOK-g zgttw)QtIt@e*GdF84W4$a^}Q+VhB!)$9E#q6Zj4b*J+UfpUBm}`i1J(4vd`4E<`&F z8kjo+>N`*4^+paPLIZqRu!4~PJ!gLsv}Ya^WAT+7u)~YTKH75lmqjdkWM^15P0HBaS)uDLMG7SI$Mr%R(kRajv)+`Hor$!V~jqx_zST{?$~PD7%Hl zwNOZVDHn@564z)`Oz-q)=#I?FgqXqDXu7zt80?pjGB`s87DGAwAkBcKY9@v>h75)s z5OS89hR(+9h<+3Y`v!)`Fg%W748r8tXqdPldJerD?HNS>Jq+K+@B<9{A?%{s4_R># z3frOy=PVx8;n6IgpU1oZfo{9_@DO4(ozTw-O#J0PJ1}l>i;qcndtujX=!}i literal 0 HcmV?d00001 diff --git a/tests/conftest.py b/tests/conftest.py index 83d405d8d..738fd8853 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -256,11 +256,12 @@ def complexes(adata: AnnData) -> Sequence[tuple[str, str]]: ] + + @pytest.fixture(scope="session") -def ligrec_no_numba() -> Mapping[str, pd.DataFrame]: - with open("tests/_data/ligrec_no_numba.pickle", "rb") as fin: - data = pickle.load(fin) - return {"means": data[0], "pvalues": data[1], "metadata": data[2]} +def ligrec_pvalues_reference() -> Mapping[str, pd.DataFrame]: + with open("tests/_data/ligrec_pvalues_reference.pickle", "rb") as fin: + return pickle.load(fin) @pytest.fixture(scope="session") diff --git a/tests/graph/test_ligrec.py b/tests/graph/test_ligrec.py index 099ecb1b6..4900f6fc6 100644 --- a/tests/graph/test_ligrec.py +++ b/tests/graph/test_ligrec.py @@ -3,7 +3,6 @@ import sys from collections.abc import Mapping, Sequence from itertools import product -from time import time from typing import TYPE_CHECKING import numpy as np @@ -313,39 +312,6 @@ def test_reproducibility_cores(self, adata: AnnData, interactions: Interactions_ assert not np.allclose(r3["pvalues"], r1["pvalues"]) assert not np.allclose(r3["pvalues"], r2["pvalues"]) - def test_reproducibility_numba_parallel_off(self, adata: AnnData, interactions: Interactions_t): - t1 = time() - r1 = ligrec( - adata, - _CK, - interactions=interactions, - n_perms=25, - copy=True, - show_progress_bar=False, - seed=42, - numba_parallel=False, - ) - t1 = time() - t1 - - t2 = time() - r2 = ligrec( - adata, - _CK, - interactions=interactions, - n_perms=25, - copy=True, - show_progress_bar=False, - seed=42, - numba_parallel=True, - ) - t2 = time() - t2 - - assert r1 is not r2 - # for such a small data, overhead from parallelization is too high - assert t1 <= t2, (t1, t2) - np.testing.assert_allclose(r1["means"], r2["means"]) - np.testing.assert_allclose(r1["pvalues"], r2["pvalues"]) - def test_paul15_correct_means(self, paul15: AnnData, paul15_means: pd.DataFrame): res = ligrec( paul15, @@ -379,6 +345,23 @@ def test_reproducibility_numba_off( np.testing.assert_allclose(r["pvalues"], ligrec_no_numba["pvalues"]) np.testing.assert_array_equal(np.where(np.isnan(r["pvalues"])), np.where(np.isnan(ligrec_no_numba["pvalues"]))) + def test_pvalues_reference( + self, adata: AnnData, interactions: Interactions_t, ligrec_pvalues_reference: Mapping[str, pd.DataFrame] + ): + r = ligrec( + adata, _CK, interactions=interactions, n_perms=25, copy=True, show_progress_bar=False, seed=42, n_jobs=1 + ) + np.testing.assert_array_equal(r["means"].index, ligrec_pvalues_reference["means"].index) + np.testing.assert_array_equal(r["means"].columns, ligrec_pvalues_reference["means"].columns) + np.testing.assert_array_equal(r["pvalues"].index, ligrec_pvalues_reference["pvalues"].index) + np.testing.assert_array_equal(r["pvalues"].columns, ligrec_pvalues_reference["pvalues"].columns) + + np.testing.assert_allclose(r["means"], ligrec_pvalues_reference["means"]) + np.testing.assert_allclose(r["pvalues"], ligrec_pvalues_reference["pvalues"]) + np.testing.assert_array_equal( + np.where(np.isnan(r["pvalues"])), np.where(np.isnan(ligrec_pvalues_reference["pvalues"])) + ) + def test_logging(self, adata: AnnData, interactions: Interactions_t, capsys): s.logfile = sys.stderr s.verbosity = 4 @@ -420,7 +403,6 @@ def test_non_uniqueness(self, adata: AnnData, interactions: Interactions_t): copy=True, show_progress_bar=False, seed=42, - numba_parallel=False, ) assert len(res["pvalues"]) == len(expected) diff --git a/tests/graph/test_ppatterns.py b/tests/graph/test_ppatterns.py index 226fb2830..01c2e3033 100644 --- a/tests/graph/test_ppatterns.py +++ b/tests/graph/test_ppatterns.py @@ -137,12 +137,10 @@ def test_co_occurrence(adata: AnnData): assert arr.shape[1] == arr.shape[0] == adata.obs["leiden"].unique().shape[0] -# @pytest.mark.parametrize(("ys", "xs"), [(10, 10), (None, None), (10, 20)]) -@pytest.mark.parametrize(("n_jobs", "n_splits"), [(1, 2), (2, 2)]) -def test_co_occurrence_reproducibility(adata: AnnData, n_jobs: int, n_splits: int): +def test_co_occurrence_reproducibility(adata: AnnData): """Check co_occurrence reproducibility results.""" - arr_1, interval_1 = co_occurrence(adata, cluster_key="leiden", copy=True, n_jobs=n_jobs, n_splits=n_splits) - arr_2, interval_2 = co_occurrence(adata, cluster_key="leiden", copy=True, n_jobs=n_jobs, n_splits=n_splits) + arr_1, interval_1 = co_occurrence(adata, cluster_key="leiden", copy=True) + arr_2, interval_2 = co_occurrence(adata, cluster_key="leiden", copy=True) np.testing.assert_array_equal(sorted(interval_1), sorted(interval_2)) np.testing.assert_allclose(arr_1, arr_2) From 2de8fd591438d19de0dfc247e59eded0c19e52f9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Mar 2026 14:01:10 +0000 Subject: [PATCH 02/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/conftest.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 738fd8853..45cb276c4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -256,8 +256,6 @@ def complexes(adata: AnnData) -> Sequence[tuple[str, str]]: ] - - @pytest.fixture(scope="session") def ligrec_pvalues_reference() -> Mapping[str, pd.DataFrame]: with open("tests/_data/ligrec_pvalues_reference.pickle", "rb") as fin: From 4cdf709c0c78cfe1622575b55f069f01ef3e448c Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 11 Mar 2026 15:16:11 +0100 Subject: [PATCH 03/13] undo extract_adata --- src/squidpy/gr/_ligrec.py | 4 ---- tests/graph/test_ppatterns.py | 8 +++++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index 4446ac999..e294fd9a2 100644 --- a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -30,7 +30,6 @@ _check_tuple_needles, _genesymbols, _save_data, - extract_adata, ) __all__ = ["ligrec", "PermutationTest"] @@ -560,8 +559,6 @@ def ligrec( interactions_params: Mapping[str, Any] = MappingProxyType({}), transmitter_params: Mapping[str, Any] = MappingProxyType({"categories": "ligand"}), receiver_params: Mapping[str, Any] = MappingProxyType({"categories": "receptor"}), - *, - table_key: str = "table", ) -> Mapping[str, pd.DataFrame] | None: """ %(PT_test.full_desc)s @@ -578,7 +575,6 @@ def ligrec( ------- %(ligrec_test_returns)s """ # noqa: D400 - adata = extract_adata(adata, table_key=table_key) with _genesymbols(adata, key=gene_symbols, use_raw=use_raw, make_unique=False): return ( # type: ignore[no-any-return] PermutationTest(adata, use_raw=use_raw) diff --git a/tests/graph/test_ppatterns.py b/tests/graph/test_ppatterns.py index 01c2e3033..226fb2830 100644 --- a/tests/graph/test_ppatterns.py +++ b/tests/graph/test_ppatterns.py @@ -137,10 +137,12 @@ def test_co_occurrence(adata: AnnData): assert arr.shape[1] == arr.shape[0] == adata.obs["leiden"].unique().shape[0] -def test_co_occurrence_reproducibility(adata: AnnData): +# @pytest.mark.parametrize(("ys", "xs"), [(10, 10), (None, None), (10, 20)]) +@pytest.mark.parametrize(("n_jobs", "n_splits"), [(1, 2), (2, 2)]) +def test_co_occurrence_reproducibility(adata: AnnData, n_jobs: int, n_splits: int): """Check co_occurrence reproducibility results.""" - arr_1, interval_1 = co_occurrence(adata, cluster_key="leiden", copy=True) - arr_2, interval_2 = co_occurrence(adata, cluster_key="leiden", copy=True) + arr_1, interval_1 = co_occurrence(adata, cluster_key="leiden", copy=True, n_jobs=n_jobs, n_splits=n_splits) + arr_2, interval_2 = co_occurrence(adata, cluster_key="leiden", copy=True, n_jobs=n_jobs, n_splits=n_splits) np.testing.assert_array_equal(sorted(interval_1), sorted(interval_2)) np.testing.assert_allclose(arr_1, arr_2) From 6ec9571b0f0c68953856fb2a99424aecb1d9c542 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 11 Mar 2026 15:16:54 +0100 Subject: [PATCH 04/13] undo changes --- src/squidpy/_utils.py | 31 ------------------------------- src/squidpy/gr/_ppatterns.py | 15 ++++++++++++--- 2 files changed, 12 insertions(+), 34 deletions(-) diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py index 3c5c86392..736c88172 100644 --- a/src/squidpy/_utils.py +++ b/src/squidpy/_utils.py @@ -278,37 +278,6 @@ def verbosity(level: int) -> Generator[None, None, None]: sc.settings.verbosity = verbosity -def deprecated_params( - params: dict[str, str], -) -> Callable[..., Any]: - """Decorator that warns when deprecated keyword arguments are passed. - - Parameters - ---------- - params - Mapping of deprecated parameter names to the version in which - they will be removed, e.g. ``{"n_jobs": "1.10.0"}``. - """ - - def decorator(func: Callable[..., Any]) -> Callable[..., Any]: - @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - for k in list(kwargs): - if k in params: - warnings.warn( - f"Parameter `{k}` of `{func.__name__}()` is deprecated " - f"and has no effect. It will be removed in squidpy v{params[k]}.", - FutureWarning, - stacklevel=2, - ) - kwargs.pop(k) - return func(*args, **kwargs) - - return wrapper - - return decorator - - string_types = (bytes, str) diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 2080d73fa..292c75994 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -23,7 +23,7 @@ from squidpy._constants._constants import SpatialAutocorr from squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs -from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, deprecated_params, parallelize +from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize from squidpy.gr._utils import ( _assert_categorical_obs, _assert_connectivity_key, @@ -342,13 +342,16 @@ def _co_occurrence_helper(v_x: NDArrayA, v_y: NDArrayA, v_radium: NDArrayA, labs @d.dedent -@deprecated_params({"n_splits": "1.10.0", "n_jobs": "1.10.0", "backend": "1.10.0", "show_progress_bar": "1.10.0"}) def co_occurrence( adata: AnnData | SpatialData, cluster_key: str, spatial_key: str = Key.obsm.spatial, interval: int | NDArrayA = 50, copy: bool = False, + n_splits: int | None = None, + n_jobs: int | None = None, + backend: str = "loky", + show_progress_bar: bool = True, ) -> tuple[NDArrayA, NDArrayA] | None: """ Compute co-occurrence probability of clusters. @@ -362,6 +365,10 @@ def co_occurrence( Distances interval at which co-occurrence is computed. If :class:`int`, uniformly spaced interval of the given size will be used. %(copy)s + n_splits + Number of splits in which to divide the spatial coordinates in + :attr:`anndata.AnnData.obsm` ``['{spatial_key}']``. + %(parallelize)s Returns ------- @@ -399,7 +406,9 @@ def co_occurrence( # Compute co-occurrence probabilities using the fast numba routine. out = _co_occurrence_helper(spatial_x, spatial_y, interval, labs) - start = logg.info(f"Calculating co-occurrence probabilities for `{len(interval)}` intervals") + start = logg.info( + f"Calculating co-occurrence probabilities for `{len(interval)}` intervals using `{n_jobs}` core(s) and `{n_splits}` splits" + ) if copy: logg.info("Finish", time=start) From d6324d5e1ef98ef8b2e0e3a6ab4212541d933aec Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 11 Mar 2026 15:19:09 +0100 Subject: [PATCH 05/13] checkout main --- src/squidpy/_utils.py | 31 +++++++++++++++++++++++++++++++ src/squidpy/gr/_ppatterns.py | 15 +++------------ tests/graph/test_ppatterns.py | 8 +++----- 3 files changed, 37 insertions(+), 17 deletions(-) diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py index 736c88172..3c5c86392 100644 --- a/src/squidpy/_utils.py +++ b/src/squidpy/_utils.py @@ -278,6 +278,37 @@ def verbosity(level: int) -> Generator[None, None, None]: sc.settings.verbosity = verbosity +def deprecated_params( + params: dict[str, str], +) -> Callable[..., Any]: + """Decorator that warns when deprecated keyword arguments are passed. + + Parameters + ---------- + params + Mapping of deprecated parameter names to the version in which + they will be removed, e.g. ``{"n_jobs": "1.10.0"}``. + """ + + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + for k in list(kwargs): + if k in params: + warnings.warn( + f"Parameter `{k}` of `{func.__name__}()` is deprecated " + f"and has no effect. It will be removed in squidpy v{params[k]}.", + FutureWarning, + stacklevel=2, + ) + kwargs.pop(k) + return func(*args, **kwargs) + + return wrapper + + return decorator + + string_types = (bytes, str) diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 292c75994..2080d73fa 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -23,7 +23,7 @@ from squidpy._constants._constants import SpatialAutocorr from squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs -from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize +from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, deprecated_params, parallelize from squidpy.gr._utils import ( _assert_categorical_obs, _assert_connectivity_key, @@ -342,16 +342,13 @@ def _co_occurrence_helper(v_x: NDArrayA, v_y: NDArrayA, v_radium: NDArrayA, labs @d.dedent +@deprecated_params({"n_splits": "1.10.0", "n_jobs": "1.10.0", "backend": "1.10.0", "show_progress_bar": "1.10.0"}) def co_occurrence( adata: AnnData | SpatialData, cluster_key: str, spatial_key: str = Key.obsm.spatial, interval: int | NDArrayA = 50, copy: bool = False, - n_splits: int | None = None, - n_jobs: int | None = None, - backend: str = "loky", - show_progress_bar: bool = True, ) -> tuple[NDArrayA, NDArrayA] | None: """ Compute co-occurrence probability of clusters. @@ -365,10 +362,6 @@ def co_occurrence( Distances interval at which co-occurrence is computed. If :class:`int`, uniformly spaced interval of the given size will be used. %(copy)s - n_splits - Number of splits in which to divide the spatial coordinates in - :attr:`anndata.AnnData.obsm` ``['{spatial_key}']``. - %(parallelize)s Returns ------- @@ -406,9 +399,7 @@ def co_occurrence( # Compute co-occurrence probabilities using the fast numba routine. out = _co_occurrence_helper(spatial_x, spatial_y, interval, labs) - start = logg.info( - f"Calculating co-occurrence probabilities for `{len(interval)}` intervals using `{n_jobs}` core(s) and `{n_splits}` splits" - ) + start = logg.info(f"Calculating co-occurrence probabilities for `{len(interval)}` intervals") if copy: logg.info("Finish", time=start) diff --git a/tests/graph/test_ppatterns.py b/tests/graph/test_ppatterns.py index 226fb2830..01c2e3033 100644 --- a/tests/graph/test_ppatterns.py +++ b/tests/graph/test_ppatterns.py @@ -137,12 +137,10 @@ def test_co_occurrence(adata: AnnData): assert arr.shape[1] == arr.shape[0] == adata.obs["leiden"].unique().shape[0] -# @pytest.mark.parametrize(("ys", "xs"), [(10, 10), (None, None), (10, 20)]) -@pytest.mark.parametrize(("n_jobs", "n_splits"), [(1, 2), (2, 2)]) -def test_co_occurrence_reproducibility(adata: AnnData, n_jobs: int, n_splits: int): +def test_co_occurrence_reproducibility(adata: AnnData): """Check co_occurrence reproducibility results.""" - arr_1, interval_1 = co_occurrence(adata, cluster_key="leiden", copy=True, n_jobs=n_jobs, n_splits=n_splits) - arr_2, interval_2 = co_occurrence(adata, cluster_key="leiden", copy=True, n_jobs=n_jobs, n_splits=n_splits) + arr_1, interval_1 = co_occurrence(adata, cluster_key="leiden", copy=True) + arr_2, interval_2 = co_occurrence(adata, cluster_key="leiden", copy=True) np.testing.assert_array_equal(sorted(interval_1), sorted(interval_2)) np.testing.assert_allclose(arr_1, arr_2) From df46993409dd51f7bd41b458fb33d31740f1de5e Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 11 Mar 2026 15:27:57 +0100 Subject: [PATCH 06/13] remove old test --- tests/graph/test_ligrec.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/tests/graph/test_ligrec.py b/tests/graph/test_ligrec.py index 4900f6fc6..a8bbc4ccc 100644 --- a/tests/graph/test_ligrec.py +++ b/tests/graph/test_ligrec.py @@ -330,21 +330,6 @@ def test_paul15_correct_means(self, paul15: AnnData, paul15_means: pd.DataFrame) np.testing.assert_array_equal(res["means"].columns, paul15_means.columns) np.testing.assert_allclose(res["means"].values, paul15_means.values) - def test_reproducibility_numba_off( - self, adata: AnnData, interactions: Interactions_t, ligrec_no_numba: Mapping[str, pd.DataFrame] - ): - r = ligrec( - adata, _CK, interactions=interactions, n_perms=5, copy=True, show_progress_bar=False, seed=42, n_jobs=1 - ) - np.testing.assert_array_equal(r["means"].index, ligrec_no_numba["means"].index) - np.testing.assert_array_equal(r["means"].columns, ligrec_no_numba["means"].columns) - np.testing.assert_array_equal(r["pvalues"].index, ligrec_no_numba["pvalues"].index) - np.testing.assert_array_equal(r["pvalues"].columns, ligrec_no_numba["pvalues"].columns) - - np.testing.assert_allclose(r["means"], ligrec_no_numba["means"]) - np.testing.assert_allclose(r["pvalues"], ligrec_no_numba["pvalues"]) - np.testing.assert_array_equal(np.where(np.isnan(r["pvalues"])), np.where(np.isnan(ligrec_no_numba["pvalues"]))) - def test_pvalues_reference( self, adata: AnnData, interactions: Interactions_t, ligrec_pvalues_reference: Mapping[str, pd.DataFrame] ): From 2f999dc036a77e236175b49aa78c439fc24c5d32 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 11 Mar 2026 15:30:07 +0100 Subject: [PATCH 07/13] also deprecate backend --- src/squidpy/gr/_ligrec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index e294fd9a2..f364dd6a1 100644 --- a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -537,7 +537,7 @@ def prepare( @d.dedent -@deprecated_params({"numba_parallel": "1.10.0"}) +@deprecated_params({"numba_parallel": "1.10.0", "backend": "1.10.0"}) def ligrec( adata: AnnData | SpatialData, cluster_key: str, From 3a62d7d9f85fa9ffbb71944bf3b06ee3074f8105 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 Mar 2026 18:45:01 +0000 Subject: [PATCH 08/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/gr/_ligrec.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index 40fe2d82a..4d5c10022 100644 --- a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -23,10 +23,8 @@ from squidpy._constants._constants import ComplexPolicy, CorrAxis from squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs - from squidpy._utils import NDArrayA, deprecated_params from squidpy._validators import assert_positive, check_tuple_needles - from squidpy.gr._utils import ( _assert_categorical_obs, _genesymbols, From 8f8a192f50e46cd146cdf2897a4976bd280e40a7 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 16 Mar 2026 21:57:36 +0300 Subject: [PATCH 09/13] update conf --- docs/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 0f413dfac..33ac09e1a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -49,7 +49,7 @@ numpy=("https://numpy.org/doc/stable", None), statsmodels=("https://www.statsmodels.org/stable", None), scipy=("https://docs.scipy.org/doc/scipy", None), - pandas=("https://pandas.pydata.org/pandas-docs/stable", None), + pandas=("https://pandas.pydata.org/docs", None), anndata=("https://anndata.readthedocs.io/en/stable", None), scanpy=("https://scanpy.readthedocs.io/en/stable", None), matplotlib=("https://matplotlib.org/stable", None), From 0c1add069906755115a8721b4396c3c0c8593eb7 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 17 Mar 2026 00:17:51 +0300 Subject: [PATCH 10/13] use _get_n_cores --- src/squidpy/gr/_ligrec.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index 4d5c10022..f6e506714 100644 --- a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -10,7 +10,6 @@ from types import MappingProxyType from typing import TYPE_CHECKING, Any, Literal, TypeAlias -import numba import numpy as np import pandas as pd from anndata import AnnData @@ -23,7 +22,7 @@ from squidpy._constants._constants import ComplexPolicy, CorrAxis from squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs -from squidpy._utils import NDArrayA, deprecated_params +from squidpy._utils import NDArrayA, _get_n_cores, deprecated_params from squidpy._validators import assert_positive, check_tuple_needles from squidpy.gr._utils import ( _assert_categorical_obs, @@ -313,9 +312,7 @@ def test( # much faster than applymap (tested on 1M interactions) interactions_ = np.vectorize(lambda g: gene_mapper[g])(interactions.values) - if n_jobs is None: - n_jobs = numba.get_num_threads() - n_jobs = max(1, min(n_jobs, n_perms)) + n_jobs = _get_n_cores(n_jobs) start = logg.info( f"Running `{n_perms}` permutations on `{len(interactions)}` interactions " f"and `{len(clusters)}` cluster combinations using `{n_jobs}` core(s)" From d8dc655082b8584933fffc468d3878aba57b9a5d Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 17 Mar 2026 00:35:54 +0300 Subject: [PATCH 11/13] use thread_map --- src/squidpy/_utils.py | 51 +++++++++++++++++++++++++++++++++++++++ src/squidpy/gr/_ligrec.py | 23 ++++++------------ 2 files changed, 59 insertions(+), 15 deletions(-) diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py index 3c5c86392..162062f8d 100644 --- a/src/squidpy/_utils.py +++ b/src/squidpy/_utils.py @@ -228,6 +228,57 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper +def thread_map( + fn: Callable[..., Any], + items: Iterable[Any], + *, + n_jobs: int = 1, + show_progress_bar: bool = False, + unit: str = "item", + total: int | None = None, +) -> list[Any]: + """Map *fn* over *items* using a thread pool with an optional progress bar. + + Parameters + ---------- + fn + Callable applied to each element of *items*. + items + Iterable of inputs passed one-by-one to *fn*. + n_jobs + Number of worker threads. ``1`` runs sequentially (no pool overhead). + show_progress_bar + Whether to display a ``tqdm`` progress bar. + unit + Label shown next to the ``tqdm`` counter. + total + Length hint passed to ``tqdm`` when *items* has no ``__len__``. + + Returns + ------- + list + Results in the same order as *items*. + """ + from concurrent.futures import ThreadPoolExecutor + + try: + import ipywidgets # noqa: F401 + from tqdm.auto import tqdm + except ImportError: + try: + from tqdm.std import tqdm + except ImportError: + tqdm = None # type: ignore[assignment] + + _total = total if total is not None else (len(items) if hasattr(items, "__len__") else None) + + with ThreadPoolExecutor(max_workers=n_jobs) as pool: + it = pool.map(fn, items) + if show_progress_bar and tqdm is not None: + it = tqdm(it, total=_total, unit=unit) + return list(it) + + def _get_n_cores(n_cores: int | None) -> int: """ Make number of cores a positive integer. diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index f6e506714..2f6926864 100644 --- a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -5,7 +5,6 @@ from abc import ABC from collections import namedtuple from collections.abc import Iterable, Mapping, Sequence -from concurrent.futures import ThreadPoolExecutor from itertools import product from types import MappingProxyType from typing import TYPE_CHECKING, Any, Literal, TypeAlias @@ -22,7 +21,7 @@ from squidpy._constants._constants import ComplexPolicy, CorrAxis from squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs -from squidpy._utils import NDArrayA, _get_n_cores, deprecated_params +from squidpy._utils import NDArrayA, _get_n_cores, deprecated_params, thread_map from squidpy._validators import assert_positive, check_tuple_needles from squidpy.gr._utils import ( _assert_categorical_obs, @@ -712,31 +711,25 @@ def _analysis( chunk_sizes = np.full(n_jobs, base_chunk, dtype=np.int64) chunk_sizes[:remainder] += 1 - thread_counts = np.zeros((n_jobs, n_inter, n_cpairs), dtype=np.int64) pbar = tqdm(total=n_perms, unit="permutation", disable=not show_progress_bar) - def _worker(t: int) -> None: + def _worker(t: int) -> NDArrayA: + local_counts = np.zeros((n_inter, n_cpairs), dtype=np.int64) rs = np.random.RandomState(None if seed is None else t + seed) perm = clustering.copy() for _ in range(chunk_sizes[t]): rs.shuffle(perm) _score_permutation( - data_arr, - perm, - inv_counts, - mean_obs, - interactions, - interaction_clusters, - valid, - thread_counts[t], + data_arr, perm, inv_counts, mean_obs, + interactions, interaction_clusters, valid, local_counts, ) pbar.update(1) + return local_counts - with ThreadPoolExecutor(max_workers=n_jobs) as pool: - list(pool.map(_worker, range(n_jobs))) + thread_counts = thread_map(_worker, range(n_jobs), n_jobs=n_jobs) pbar.close() - pval_counts = thread_counts.sum(axis=0) + pval_counts = np.sum(thread_counts, axis=0) pvalues = pval_counts.astype(np.float64) / n_perms pvalues[~valid] = np.nan From fb5c2e7d3f83f0c9eb7b0a37ae1e2be3d08f9866 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 Mar 2026 21:37:38 +0000 Subject: [PATCH 12/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/gr/_ligrec.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index 2f6926864..cf53511d3 100644 --- a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -720,8 +720,14 @@ def _worker(t: int) -> NDArrayA: for _ in range(chunk_sizes[t]): rs.shuffle(perm) _score_permutation( - data_arr, perm, inv_counts, mean_obs, - interactions, interaction_clusters, valid, local_counts, + data_arr, + perm, + inv_counts, + mean_obs, + interactions, + interaction_clusters, + valid, + local_counts, ) pbar.update(1) return local_counts From 165544dc13c97ea16d268927f6ee7d85186b06b3 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 17 Mar 2026 00:41:35 +0300 Subject: [PATCH 13/13] update threadmap --- src/squidpy/_utils.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py index 162062f8d..e3e7e1457 100644 --- a/src/squidpy/_utils.py +++ b/src/squidpy/_utils.py @@ -18,6 +18,7 @@ import numpy as np import xarray as xr from spatialdata.models import Image2DModel, Labels2DModel +from tqdm.auto import tqdm __all__ = ["singledispatchmethod", "Signal", "SigQueue", "NDArray", "NDArrayA"] @@ -261,21 +262,10 @@ def thread_map( """ from concurrent.futures import ThreadPoolExecutor - try: - import ipywidgets # noqa: F401 - from tqdm.auto import tqdm - except ImportError: - try: - from tqdm.std import tqdm - except ImportError: - tqdm = None # type: ignore[assignment] - - _total = total if total is not None else (len(items) if hasattr(items, "__len__") else None) - with ThreadPoolExecutor(max_workers=n_jobs) as pool: it = pool.map(fn, items) if show_progress_bar and tqdm is not None: - it = tqdm(it, total=_total, unit=unit) + it = tqdm(it, total=len(items), unit=unit) return list(it)