Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/squidpy/experimental/im/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@
detect_tissue,
)
from ._make_tiles import make_tiles, make_tiles_from_spots
from ._qc_image import qc_image
from ._qc_metrics import QCMetric

__all__ = [
"BackgroundDetectionParams",
"FelzenszwalbParams",
"QCMetric",
"WekaParams",
"detect_tissue",
"make_tiles",
"make_tiles_from_spots",
"qc_image",
]
169 changes: 17 additions & 152 deletions src/squidpy/experimental/im/_make_tiles.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
from __future__ import annotations

import itertools
from typing import Literal

import dask.array as da
import geopandas as gpd
import numpy as np
import pandas as pd
Expand All @@ -17,101 +13,16 @@
from spatialdata.transformations import get_transformation, set_transformation

from squidpy._utils import _yx_from_shape

from ._utils import _get_element_data
from squidpy.experimental.im._utils import (
TileGrid,
_get_element_data,
_get_mask_materialized,
_save_tile_grid_to_shapes,
)

__all__ = ["make_tiles", "make_tiles_from_spots"]


class _TileGrid:
"""Immutable tile grid definition with cached bounds and centroids."""

def __init__(
self,
H: int,
W: int,
tile_size: Literal["auto"] | tuple[int, int] = "auto",
target_tiles: int = 100,
offset_y: int = 0,
offset_x: int = 0,
):
self.H = H
self.W = W
if tile_size == "auto":
size = max(min(self.H // target_tiles, self.W // target_tiles), 100)
self.ty = int(size)
self.tx = int(size)
else:
self.ty = int(tile_size[0])
self.tx = int(tile_size[1])
self.offset_y = offset_y
self.offset_x = offset_x
# Calculate number of tiles needed to cover entire image, accounting for offset
# The grid starts at offset_y, offset_x (can be negative)
# We need tiles from min(0, offset_y) to at least H
# So total coverage needed is from min(0, offset_y) to H
grid_start_y = min(0, self.offset_y)
grid_start_x = min(0, self.offset_x)
total_h_needed = self.H - grid_start_y
total_w_needed = self.W - grid_start_x
self.tiles_y = (total_h_needed + self.ty - 1) // self.ty
self.tiles_x = (total_w_needed + self.tx - 1) // self.tx
# Cache immutable derived values
self._indices = np.array([[iy, ix] for iy in range(self.tiles_y) for ix in range(self.tiles_x)], dtype=int)
self._names = [f"tile_x{ix}_y{iy}" for iy in range(self.tiles_y) for ix in range(self.tiles_x)]
self._bounds = self._compute_bounds()
self._centroids_polys = self._compute_centroids_and_polygons()

def indices(self) -> np.ndarray:
return self._indices

def names(self) -> list[str]:
return self._names

def bounds(self) -> np.ndarray:
return self._bounds

def _compute_bounds(self) -> np.ndarray:
b: list[list[int]] = []
for iy, ix in itertools.product(range(self.tiles_y), range(self.tiles_x)):
y0 = iy * self.ty + self.offset_y
x0 = ix * self.tx + self.offset_x
y1 = ((iy + 1) * self.ty + self.offset_y) if iy < self.tiles_y - 1 else self.H
x1 = ((ix + 1) * self.tx + self.offset_x) if ix < self.tiles_x - 1 else self.W
# Clamp bounds to image dimensions
y0 = max(0, min(y0, self.H))
x0 = max(0, min(x0, self.W))
y1 = max(0, min(y1, self.H))
x1 = max(0, min(x1, self.W))
b.append([y0, x0, y1, x1])
return np.array(b, dtype=int)

def centroids_and_polygons(self) -> tuple[np.ndarray, list[Polygon]]:
return self._centroids_polys

def _compute_centroids_and_polygons(self) -> tuple[np.ndarray, list[Polygon]]:
cents: list[list[float]] = []
polys: list[Polygon] = []
for y0, x0, y1, x1 in self._bounds:
cy = (y0 + y1) / 2
cx = (x0 + x1) / 2
cents.append([cy, cx])
polys.append(Polygon([(x0, y0), (x1, y0), (x1, y1), (x0, y1), (x0, y0)]))
return np.array(cents, dtype=float), polys

def rechunk_and_pad(self, arr_yx: da.Array) -> da.Array:
if arr_yx.ndim != 2:
raise ValueError("Expected a 2D array shaped (y, x).")
pad_y = self.tiles_y * self.ty - int(arr_yx.shape[0])
pad_x = self.tiles_x * self.tx - int(arr_yx.shape[1])
a = arr_yx.rechunk((self.ty, self.tx))
return da.pad(a, ((0, pad_y), (0, pad_x)), mode="edge") if (pad_y > 0 or pad_x > 0) else a

def coarsen(self, arr_yx: da.Array, reduce: Literal["mean", "sum"] = "mean") -> da.Array:
reducer = np.mean if reduce == "mean" else np.sum
return da.coarsen(reducer, arr_yx, {0: self.ty, 1: self.tx}, trim_excess=False)


class _SpotTileGrid:
"""Tile container for Visium spots, used with ``_filter_tiles``."""

Expand Down Expand Up @@ -204,34 +115,12 @@ def _choose_label_scale_for_image(label_node: Labels2DModel, target_hw: tuple[in

def _save_tiles_to_shapes(
sdata: sd.SpatialData,
tg: _TileGrid,
tg: TileGrid,
image_key: str,
shapes_key: str,
) -> None:
"""Save a TileGrid to sdata.shapes as a GeoDataFrame."""
tile_indices = tg.indices()
pixel_bounds = tg.bounds()
_, polys = tg.centroids_and_polygons()

tile_gdf = gpd.GeoDataFrame(
{
"tile_id": tg.names(),
"tile_y": tile_indices[:, 0],
"tile_x": tile_indices[:, 1],
"pixel_y0": pixel_bounds[:, 0],
"pixel_x0": pixel_bounds[:, 1],
"pixel_y1": pixel_bounds[:, 2],
"pixel_x1": pixel_bounds[:, 3],
"geometry": polys,
},
geometry="geometry",
)

sdata.shapes[shapes_key] = ShapesModel.parse(tile_gdf)
# we know that a) the element exists and b) it has at least an Identity transformation
transformations = get_transformation(sdata.images[image_key], get_all=True)
set_transformation(sdata.shapes[shapes_key], transformations, set_all=True)
logger.info(f"Saved tile grid as 'sdata.shapes[\"{shapes_key}\"]'")
_save_tile_grid_to_shapes(sdata, tg, shapes_key, copy_transforms_from_key=image_key)


def _save_spot_tiles_to_shapes(
Expand Down Expand Up @@ -366,7 +255,7 @@ def make_tiles(
mask_key_for_grid = default_mask_key
else:
try:
from ._detect_tissue import detect_tissue
from squidpy.experimental.im._detect_tissue import detect_tissue

detect_tissue(
sdata,
Expand Down Expand Up @@ -411,7 +300,7 @@ def make_tiles(
classification_mask_key,
)
try:
from ._detect_tissue import detect_tissue
from squidpy.experimental.im._detect_tissue import detect_tissue

detect_tissue(
sdata,
Expand Down Expand Up @@ -558,7 +447,7 @@ def make_tiles_from_spots(
classification_mask_key,
)
try:
from ._detect_tissue import detect_tissue
from squidpy.experimental.im._detect_tissue import detect_tissue

detect_tissue(
sdata,
Expand Down Expand Up @@ -633,7 +522,7 @@ def make_tiles_from_spots(

def _filter_tiles(
sdata: sd.SpatialData,
tg: _TileGrid,
tg: TileGrid,
image_key: str | None,
*,
tissue_mask_key: str | None = None,
Expand Down Expand Up @@ -686,7 +575,7 @@ def _filter_tiles(
raise ValueError("tissue_mask_key must be provided when image_key is None.")
if mask_key not in sdata.labels:
raise KeyError(f"Tissue mask '{mask_key}' not found in sdata.labels.")
mask = _get_mask_from_labels(sdata, mask_key, scale)
mask = _get_mask_materialized(sdata, mask_key, scale)
H_mask, W_mask = mask.shape

# Check tissue coverage for each tile
Expand Down Expand Up @@ -751,7 +640,7 @@ def _make_tiles(
tile_size: tuple[int, int] = (224, 224),
center_grid_on_tissue: bool = False,
scale: str = "auto",
) -> _TileGrid:
) -> TileGrid:
"""Construct a tile grid for an image, optionally centered on a tissue mask."""
# Validate image key
if image_key not in sdata.images:
Expand All @@ -764,7 +653,7 @@ def _make_tiles(

# Path 1: Regular grid starting from top-left
if not center_grid_on_tissue or image_mask_key is None:
return _TileGrid(H, W, tile_size=tile_size)
return TileGrid(H, W, tile_size=tile_size)

# Path 2: Center grid on tissue mask centroid
if image_mask_key not in sdata.labels:
Expand Down Expand Up @@ -806,7 +695,7 @@ def _make_tiles(
mask_bool = mask > 0
if not mask_bool.any():
logger.warning("Mask is empty. Using regular grid starting from top-left.")
return _TileGrid(H, W, tile_size=tile_size)
return TileGrid(H, W, tile_size=tile_size)

# Calculate centroid using center of mass
y_coords, x_coords = np.where(mask_bool)
Expand All @@ -821,7 +710,7 @@ def _make_tiles(
offset_y = int(round(centroid_y - tile_center_y_standard))
offset_x = int(round(centroid_x - tile_center_x_standard))

return _TileGrid(H, W, tile_size=tile_size, offset_y=offset_y, offset_x=offset_x)
return TileGrid(H, W, tile_size=tile_size, offset_y=offset_y, offset_x=offset_x)


def _get_spot_coordinates(
Expand Down Expand Up @@ -877,27 +766,3 @@ def _derive_tile_size_from_spots(coords: np.ndarray) -> tuple[int, int]:
)
side = max(1, int(np.floor(row_spacing)))
return side, side


def _get_mask_from_labels(sdata: sd.SpatialData, mask_key: str, scale: str) -> np.ndarray:
"""Extract a 2D mask array from ``sdata.labels`` at the requested scale."""
if mask_key not in sdata.labels:
raise KeyError(f"Mask key '{mask_key}' not found in sdata.labels")

label_node = sdata.labels[mask_key]
mask_da = _get_element_data(label_node, scale, "label", mask_key)

if is_dask_collection(mask_da):
mask_da = mask_da.compute()

if isinstance(mask_da, xr.DataArray):
mask = np.asarray(mask_da.data)
else:
mask = np.asarray(mask_da)

if mask.ndim > 2:
mask = mask.squeeze()
if mask.ndim != 2:
raise ValueError(f"Expected 2D mask with shape (y, x), got shape {mask.shape}")

return mask
Loading