Skip to content
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
numpy=("https://numpy.org/doc/stable", None),
statsmodels=("https://www.statsmodels.org/stable", None),
scipy=("https://docs.scipy.org/doc/scipy", None),
pandas=("https://pandas.pydata.org/pandas-docs/stable", None),
pandas=("https://pandas.pydata.org/docs", None),
anndata=("https://anndata.readthedocs.io/en/stable", None),
scanpy=("https://scanpy.readthedocs.io/en/stable", None),
matplotlib=("https://matplotlib.org/stable", None),
Expand Down
41 changes: 41 additions & 0 deletions src/squidpy/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np
import xarray as xr
from spatialdata.models import Image2DModel, Labels2DModel
from tqdm.auto import tqdm

__all__ = ["singledispatchmethod", "Signal", "SigQueue", "NDArray", "NDArrayA"]

Expand Down Expand Up @@ -228,6 +229,46 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return wrapper


def thread_map(
fn: Callable[..., Any],
items: Iterable[Any],
*,
n_jobs: int = 1,
show_progress_bar: bool = False,
unit: str = "item",
total: int | None = None,
) -> list[Any]:
"""Map *fn* over *items* using a thread pool with an optional progress bar.

Parameters
----------
fn
Callable applied to each element of *items*.
items
Iterable of inputs passed one-by-one to *fn*.
n_jobs
Number of worker threads. ``1`` runs sequentially (no pool overhead).
show_progress_bar
Whether to display a ``tqdm`` progress bar.
unit
Label shown next to the ``tqdm`` counter.
total
Length hint passed to ``tqdm`` when *items* has no ``__len__``.

Returns
-------
list
Results in the same order as *items*.
"""
from concurrent.futures import ThreadPoolExecutor

with ThreadPoolExecutor(max_workers=n_jobs) as pool:
it = pool.map(fn, items)
if show_progress_bar and tqdm is not None:
it = tqdm(it, total=len(items), unit=unit)
return list(it)


def _get_n_cores(n_cores: int | None) -> int:
"""
Make number of cores a positive integer.
Expand Down
Loading
Loading