Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
a6e1438
remove parallelize from sepal
selmanozleyen Feb 20, 2026
e011ddc
preallocated buffers
selmanozleyen Feb 20, 2026
4aad2c0
Merge branch 'main' into feat/remove-parallelize-minimal
selmanozleyen Feb 20, 2026
d9e3438
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2026
357966a
Merge branch 'main' into feat/remove-parallelize-minimal
selmanozleyen Feb 24, 2026
8c6df25
add sparse batch support
selmanozleyen Feb 24, 2026
023461d
Merge branch 'feat/remove-parallelize-minimal' of https://github.com/…
selmanozleyen Feb 24, 2026
b0d0792
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 24, 2026
b9605dd
better default
selmanozleyen Feb 24, 2026
2df55b2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 24, 2026
c90f80b
Merge branch 'main' into feat/remove-parallelize-minimal
selmanozleyen Feb 26, 2026
e856524
lazy densify
selmanozleyen Feb 27, 2026
d81b098
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 27, 2026
542df63
ignore comment for fau
selmanozleyen Feb 27, 2026
2f48ba2
Merge branch 'main' into feat/remove-parallelize-minimal
selmanozleyen Mar 1, 2026
e7e5579
update
selmanozleyen Mar 1, 2026
f24c1ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 1, 2026
2d8a143
update docstrings
selmanozleyen Mar 2, 2026
b283de7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 2, 2026
46558fd
init
selmanozleyen Mar 2, 2026
b64ebc9
rng can't be none inside test
selmanozleyen Mar 2, 2026
deb31f3
Merge branch 'feat/ligrec-rng-threading' into feat/remove-parallelize…
selmanozleyen Mar 2, 2026
5bc7caa
Merge branch 'main' into feat/remove-parallelize-minimal
selmanozleyen Mar 8, 2026
ecf7147
Merge branch 'main' into feat/remove-parallelize-minimal
selmanozleyen Mar 11, 2026
1c34afa
checkout main
selmanozleyen Mar 11, 2026
9ee9462
remove reference file
selmanozleyen Mar 11, 2026
b7ca301
add progress bar option and deprecate warning for backend
selmanozleyen Mar 11, 2026
9c1d97a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2026
c384f9b
add deprecated params import
selmanozleyen Mar 11, 2026
3adf5ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2026
103e0b5
revert log message
selmanozleyen Mar 11, 2026
10a6913
Merge branch 'main' into feat/remove-parallelize-minimal
selmanozleyen Mar 11, 2026
65160ad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2026
417d3b7
Merge branch 'main' into feat/remove-parallelize-minimal
selmanozleyen Mar 12, 2026
7f8bbac
Merge branch 'main' into feat/remove-parallelize-minimal
selmanozleyen Mar 16, 2026
94e9804
update conf
selmanozleyen Mar 16, 2026
f4ffe34
now use threadpool util
selmanozleyen Mar 16, 2026
6749135
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 16, 2026
a7a5d95
Merge branch 'main' into feat/remove-parallelize-minimal
selmanozleyen Mar 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
numpy=("https://numpy.org/doc/stable", None),
statsmodels=("https://www.statsmodels.org/stable", None),
scipy=("https://docs.scipy.org/doc/scipy", None),
pandas=("https://pandas.pydata.org/pandas-docs/stable", None),
pandas=("https://pandas.pydata.org/docs", None),
anndata=("https://anndata.readthedocs.io/en/stable", None),
scanpy=("https://scanpy.readthedocs.io/en/stable", None),
matplotlib=("https://matplotlib.org/stable", None),
Expand Down
41 changes: 41 additions & 0 deletions src/squidpy/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np
import xarray as xr
from spatialdata.models import Image2DModel, Labels2DModel
from tqdm.auto import tqdm

__all__ = ["singledispatchmethod", "Signal", "SigQueue", "NDArray", "NDArrayA"]

Expand Down Expand Up @@ -228,6 +229,46 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return wrapper


def thread_map(
    fn: Callable[..., Any],
    items: Iterable[Any],
    *,
    n_jobs: int = 1,
    show_progress_bar: bool = False,
    unit: str = "item",
    total: int | None = None,
) -> list[Any]:
    """Map *fn* over *items* using a thread pool with an optional progress bar.

    Parameters
    ----------
    fn
        Callable applied to each element of *items*.
    items
        Iterable of inputs passed one-by-one to *fn*.
    n_jobs
        Number of worker threads. ``1`` runs sequentially (no pool overhead).
    show_progress_bar
        Whether to display a ``tqdm`` progress bar.
    unit
        Label shown next to the ``tqdm`` counter.
    total
        Length hint passed to ``tqdm`` when *items* has no ``__len__``.

    Returns
    -------
    list
        Results in the same order as *items*.
    """
    from concurrent.futures import ThreadPoolExecutor

    # Resolve the progress-bar length up front: the explicit hint wins, then
    # ``len(items)`` when available, else ``None`` (tqdm shows a rate-only bar).
    # This must happen before the iterable is consumed by ``map``.
    if total is None:
        try:
            total = len(items)  # type: ignore[arg-type]
        except TypeError:
            total = None

    def _maybe_progress(it: Iterable[Any]) -> Iterable[Any]:
        # tqdm is imported at module level; wrap only when requested.
        return tqdm(it, total=total, unit=unit) if show_progress_bar else it

    if n_jobs == 1:
        # Sequential fast path: honor the documented "no pool overhead" contract.
        return [fn(item) for item in _maybe_progress(items)]

    with ThreadPoolExecutor(max_workers=n_jobs) as pool:
        # ``Executor.map`` yields results lazily and in input order; listing
        # inside the ``with`` drains it before the pool shuts down.
        return list(_maybe_progress(pool.map(fn, items)))


def _get_n_cores(n_cores: int | None) -> int:
"""
Make number of cores a positive integer.
Expand Down
123 changes: 66 additions & 57 deletions src/squidpy/gr/_sepal.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from __future__ import annotations

from collections.abc import Callable, Sequence
from collections.abc import Sequence
from typing import Literal

import numpy as np
import pandas as pd
from anndata import AnnData
from numba import njit
from scanpy import logging as logg
from scipy.sparse import csr_matrix, isspmatrix_csr, spmatrix
from scipy.sparse import csc_matrix, csr_matrix, issparse, isspmatrix_csr, spmatrix
from sklearn.metrics import pairwise_distances
from spatialdata import SpatialData

from squidpy._constants._pkg_constants import Key
from squidpy._docs import d, inject_docs
from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize
from squidpy._utils import NDArrayA, _get_n_cores, deprecated_params, thread_map
from squidpy._validators import assert_non_empty_sequence
from squidpy.gr._utils import (
_assert_connectivity_key,
Expand All @@ -28,6 +28,7 @@

@d.dedent
@inject_docs(key=Key.obsp.spatial_conn())
@deprecated_params({"backend": "1.10.0"})
def sepal(
adata: AnnData | SpatialData,
max_neighs: Literal[4, 6],
Expand All @@ -41,7 +42,6 @@ def sepal(
use_raw: bool = False,
copy: bool = False,
n_jobs: int | None = None,
backend: str = "loky",
show_progress_bar: bool = True,
) -> pd.DataFrame | None:
"""
Expand Down Expand Up @@ -78,8 +78,8 @@ def sepal(
use_raw
Whether to access :attr:`anndata.AnnData.raw`.
%(copy)s
%(parallelize)s

%(n_jobs)s
%(show_progress_bar)s
Returns
-------
If ``copy = True``, returns a :class:`pandas.DataFrame` with the sepal scores.
Expand Down Expand Up @@ -126,24 +126,22 @@ def sepal(
vals, genes = _extract_expression(adata, genes=genes, use_raw=use_raw, layer=layer)
start = logg.info(f"Calculating sepal score for `{len(genes)}` genes using `{n_jobs}` core(s)")

score = parallelize(
_score_helper,
collection=np.arange(len(genes)).tolist(),
extractor=np.hstack,
use_ixs=False,
use_hex = max_neighs == 6

if issparse(vals):
vals = csc_matrix(vals)
score = _diffusion_genes(
vals,
use_hex,
n_iter,
sat,
sat_idx,
unsat,
unsat_idx,
dt,
thresh,
n_jobs=n_jobs,
backend=backend,
show_progress_bar=show_progress_bar,
)(
vals=vals,
max_neighs=max_neighs,
n_iter=n_iter,
sat=sat,
sat_idx=sat_idx,
unsat=unsat,
unsat_idx=unsat_idx,
dt=dt,
thresh=thresh,
)

key_added = "sepal_score"
Expand All @@ -160,69 +158,80 @@ def sepal(
_save_data(adata, attr="uns", key=key_added, data=sepal_score, time=start)


def _score_helper(
ixs: Sequence[int],
vals: spmatrix | NDArrayA,
max_neighs: int,
def _diffusion_genes(
vals: NDArrayA | spmatrix,
use_hex: bool,
n_iter: int,
sat: NDArrayA,
sat_idx: NDArrayA,
unsat: NDArrayA,
unsat_idx: NDArrayA,
dt: float,
thresh: float,
queue: SigQueue | None = None,
n_jobs: int,
show_progress_bar: bool = True,
) -> NDArrayA:
if max_neighs == 4:
fun = _laplacian_rect
elif max_neighs == 6:
fun = _laplacian_hex
else:
raise NotImplementedError(f"Laplacian for `{max_neighs}` neighbors is not yet implemented.")

score = []
for i in ixs:
if isinstance(vals, spmatrix):
conc = vals[:, i].toarray().flatten() # Safe to call toarray()
else:
conc = vals[:, i].copy() # vals is assumed to be a NumPy array here

time_iter = _diffusion(conc, fun, n_iter, sat, sat_idx, unsat, unsat_idx, dt=dt, thresh=thresh)
score.append(dt * time_iter)

if queue is not None:
queue.put(Signal.UPDATE)
"""Run diffusion for each gene column, parallelised across threads."""

if queue is not None:
queue.put(Signal.FINISH)
sparse = issparse(vals)

return np.array(score)
def _process_gene(i: int) -> float:
if sparse:
conc = np.ascontiguousarray(vals[:, i].toarray().ravel(), dtype=np.float64)
else:
conc = vals[:, i]
time_iter = _diffusion(
conc,
use_hex,
n_iter,
sat,
sat_idx,
unsat,
unsat_idx,
dt,
thresh,
)
return dt * time_iter

scores = thread_map(
_process_gene,
range(vals.shape[1]),
n_jobs=n_jobs,
show_progress_bar=show_progress_bar,
unit="gene",
)
return np.array(scores)


@njit(fastmath=True)
@njit(fastmath=True, nogil=True)
def _diffusion(
conc: NDArrayA,
laplacian: Callable[[NDArrayA, NDArrayA], float],
use_hex: bool,
n_iter: int,
sat: NDArrayA,
sat_idx: NDArrayA,
unsat: NDArrayA,
unsat_idx: NDArrayA,
dt: float = 0.001,
thresh: float = 1e-8,
dt: float,
thresh: float,
) -> float:
"""Simulate diffusion process on a regular graph."""
sat_shape, conc_shape = sat.shape[0], conc.shape[0]
sat_shape = sat.shape[0]
n_cells = conc.shape[0]
entropy_arr = np.zeros(n_iter)
prev_ent = 1.0
nhood = np.zeros(sat_shape)
dcdt = np.zeros(n_cells)
prev_ent = 1.0

for i in range(n_iter):
for j in range(sat_shape):
nhood[j] = np.sum(conc[sat_idx[j]])
d2 = laplacian(conc[sat], nhood)
if use_hex:
d2 = _laplacian_hex(conc[sat], nhood)
else:
d2 = _laplacian_rect(conc[sat], nhood)

dcdt = np.zeros(conc_shape)
dcdt[:] = 0.0
dcdt[sat] = d2
conc[sat] += dcdt[sat] * dt
conc[unsat] += dcdt[unsat_idx] * dt
Expand Down
20 changes: 16 additions & 4 deletions tests/graph/test_sepal.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import numba
import numpy as np
from anndata import AnnData
from pandas.testing import assert_frame_equal
Expand All @@ -16,8 +17,14 @@ def test_sepal_seq_par(adata: AnnData):
adata.var["highly_variable"] = rng.choice([True, False], size=adata.var_names.shape, p=[0.005, 0.995])

sepal(adata, max_neighs=6)
df = sepal(adata, max_neighs=6, copy=True, n_jobs=1)
df_parallel = sepal(adata, max_neighs=6, copy=True, n_jobs=2)

prev_threads = numba.get_num_threads()
try:
numba.set_num_threads(1)
df = sepal(adata, max_neighs=6, copy=True)
finally:
numba.set_num_threads(prev_threads)
df_parallel = sepal(adata, max_neighs=6, copy=True)

idx_df = df.index.values
idx_adata = adata[:, adata.var.highly_variable.values].var_names.values
Expand All @@ -40,8 +47,13 @@ def test_sepal_square_seq_par(adata_squaregrid: AnnData):
rng = np.random.default_rng(42)
adata.var["highly_variable"] = rng.choice([True, False], size=adata.var_names.shape)

sepal(adata, max_neighs=4)
df_parallel = sepal(adata, copy=True, n_jobs=2, max_neighs=4)
prev_threads = numba.get_num_threads()
try:
numba.set_num_threads(1)
sepal(adata, max_neighs=4)
finally:
numba.set_num_threads(prev_threads)
df_parallel = sepal(adata, copy=True, max_neighs=4)

idx_df = df_parallel.index.values
idx_adata = adata[:, adata.var.highly_variable.values].var_names.values
Expand Down
Loading