diff --git a/src/squidpy/pl/_spatial_utils.py b/src/squidpy/pl/_spatial_utils.py index 7088c8d36..ce8612fb4 100644 --- a/src/squidpy/pl/_spatial_utils.py +++ b/src/squidpy/pl/_spatial_utils.py @@ -466,7 +466,6 @@ def _set_color_source_vec( color_source_vector = adata.raw.obs_vector(value_to_plot) else: color_source_vector = adata.obs_vector(value_to_plot, layer=layer) - if not isinstance(color_source_vector.dtype, CategoricalDtype): return None, color_source_vector, False @@ -474,7 +473,6 @@ def _set_color_source_vec( categories = color_source_vector.categories if groups is not None: color_source_vector = color_source_vector.remove_categories(categories.difference(groups)) - color_map = _get_palette( adata, cluster_key=value_to_plot, diff --git a/src/squidpy/tl/_sliding_window.py b/src/squidpy/tl/_sliding_window.py index 431b3103e..58ee2e575 100644 --- a/src/squidpy/tl/_sliding_window.py +++ b/src/squidpy/tl/_sliding_window.py @@ -1,5 +1,7 @@ from __future__ import annotations +import math +import time from collections import defaultdict from itertools import product @@ -19,12 +21,17 @@ def sliding_window( adata: AnnData | SpatialData, library_key: str | None = None, - window_size: int | None = None, - overlap: int = 0, coord_columns: tuple[str, str] = ("globalX", "globalY"), - sliding_window_key: str = "sliding_window_assignment", + window_size: int | tuple[int, int] | None = None, spatial_key: str = "spatial", + sliding_window_key: str = "sliding_window_assignment", + overlap: int = 0, + max_n_cells: int = None, + split_line: str = "h", + n_splits: int = None, drop_partial_windows: bool = False, + square: bool = False, + window_size_per_library_key: str = "equal", copy: bool = False, ) -> pd.DataFrame | None: """ @@ -33,36 +40,53 @@ def sliding_window( Parameters ---------- %(adata)s - window_size: int - Size of the sliding window. %(library_key)s coord_columns: Tuple[str, str] Tuple of column names in `adata.obs` that specify the coordinates (x, y), e.i. 
('globalX', 'globalY') + window_size: int | Tuple[int, int] + Size of the sliding window. + %(spatial_key)s sliding_window_key: str Base name for sliding window columns. overlap: int Overlap size between consecutive windows. (0 = no overlap) - %(spatial_key)s + max_n_cells: int + If window_size is None, either 'n_splits' or 'max_n_cells' can be set. + max_n_cells sets an upper limit for the number of cells within each region. + split_line: str + If 'square' is False, this sets the orientation for rectangular regions. `h` : Horizontal, `v`: Vertical + n_splits: int + This can be used to split the entire region into a given number of splits. drop_partial_windows: bool If True, drop windows that are smaller than the window size at the borders. + square: bool + If True, the windows will be square. + window_size_per_library_key: str + If 'equal', the window size will be the same for all libraries. If 'different', the window size will be optimized + for each library based on the number of cells in the library. copy: bool If True, return the result, otherwise save it to the adata object. - Returns ------- If ``copy = True``, returns the sliding window annotation(s) as pandas dataframe Otherwise, stores the sliding window annotation(s) in .obs. """ + if overlap < 0: raise ValueError("Overlap must be non-negative.") - if isinstance(adata, SpatialData): adata = adata.table + assert max_n_cells is None or n_splits is None, ( + "You can specify only one from the parameters 'n_split' and 'max_n_cells' " + ) # we don't want to modify the original adata in case of copy=True if copy: adata = adata.copy() + if "sliding_window_assignment_colors" in adata.uns: + del adata.uns["sliding_window_assignment_colors"] + # extract coordinates of observations x_col, y_col = coord_columns if x_col in adata.obs and y_col in adata.obs: @@ -78,51 +102,102 @@ def sliding_window( f"Coordinates not found. Provide `{coord_columns}` in `adata.obs` or specify a suitable `spatial_key` in `adata.obsm`." 
) - # infer window size if not provided - if window_size is None: - coord_range = max( - coords[x_col].max() - coords[x_col].min(), - coords[y_col].max() - coords[y_col].min(), - ) - # mostly arbitrary choice, except that full integers usually generate windows with 1-2 cells at the borders - window_size = max(int(np.floor(coord_range // 3.95)), 1) - - if window_size <= 0: - raise ValueError("Window size must be larger than 0.") - if library_key is not None and library_key not in adata.obs: raise ValueError(f"Library key '{library_key}' not found in adata.obs") - libraries = [None] if library_key is None else adata.obs[library_key].unique() + if library_key is None: + library_key = "temp_fov" + adata.obs[library_key] = "fov1" + + libraries = adata.obs[library_key].unique() + + fovs_x_range = [ + (coords[adata.obs[library_key] == key][x_col].max(), coords[adata.obs[library_key] == key][x_col].min()) + for key in libraries + ] + fovs_y_range = [ + (coords[adata.obs[library_key] == key][y_col].max(), coords[adata.obs[library_key] == key][y_col].min()) + for key in libraries + ] + fovs_width = [i - j for (i, j) in fovs_x_range] + fovs_height = [i - j for (i, j) in fovs_y_range] + fovs_n_cell = [adata[adata.obs[library_key] == key].shape[0] for key in libraries] + fovs_area = [i * j for i, j in zip(fovs_width, fovs_height)] + fovs_density = [i / j for i, j in zip(fovs_n_cell, fovs_area)] + window_sizes = [] - # Create a DataFrame to store the sliding window assignments - sliding_window_df = pd.DataFrame(index=adata.obs.index) + if window_size is None: + if window_size_per_library_key == "equal": + if max_n_cells: + n_splits = max(2, int(min(fovs_n_cell) / max_n_cells)) + min_n_cells = max(int(0.2 * max_n_cells), 1) + elif n_splits is None: + n_splits = 2 + max_n_cells = int(min(fovs_n_cell) / n_splits) + min_n_cells = max(int(0.2 * max_n_cells), 1) + else: + max_n_cells = int(min(fovs_n_cell) / n_splits) + min_n_cells = max_n_cells - 1 - if sliding_window_key in 
adata.obs: - logg.warning(f"Overwriting existing column '{sliding_window_key}' in adata.obs.") + maximum_region_area = max_n_cells / max(fovs_density) + minimum_region_area = min_n_cells / max(fovs_density) - for lib in libraries: - if lib is not None: - lib_mask = adata.obs[library_key] == lib - lib_coords = coords.loc[lib_mask] + window_size = _optimize_tile_size( + min(fovs_width), min(fovs_height), minimum_region_area, maximum_region_area, square, split_line + ) + window_sizes = [window_size] * len(libraries) else: - lib_mask = np.ones(len(adata), dtype=bool) - lib_coords = coords + for i, lib in enumerate(libraries): + if max_n_cells: + n_splits = max(2, int(fovs_n_cell[i] / max_n_cells)) + min_n_cells = max(int(0.2 * max_n_cells), 1) + elif n_splits is None: + n_splits = 2 + max_n_cells = int(fovs_n_cell[i] / n_splits) + min_n_cells = max(int(0.2 * max_n_cells), 1) + else: + max_n_cells = int(fovs_n_cell[i] / n_splits) + min_n_cells = max_n_cells - 1 + + min_n_cells = int(fovs_n_cell[i] / n_splits) + minimum_region_area = min_n_cells / max(fovs_density) + maximum_region_area = fovs_area[i] / fovs_density[i] + window_sizes.append( + _optimize_tile_size( + fovs_width[i], fovs_height[i], minimum_region_area, maximum_region_area, square, split_line + ) + ) + else: + # assert split_line is None, logg.warning("'split' ignored as window_size is specified for square regions") + assert n_splits is None, logg.warning("'n_split' ignored as window_size is specified for square regions") + assert max_n_cells is None, logg.warning("'max_n_cells' ignored as window_size is specified") + if isinstance(window_size, (int, float)): + if window_size <= 0: + raise ValueError("Window size must be larger than 0.") + else: + window_size = (window_size, window_size) + elif isinstance(window_size, tuple): + for i in window_size: + if i <= 0: + raise ValueError("Window size must be larger than 0.") - min_x, max_x = lib_coords[x_col].min(), lib_coords[x_col].max() - min_y, max_y = 
lib_coords[y_col].min(), lib_coords[y_col].max() + window_sizes = [window_size] * len(libraries) + # Create a DataFrame to store the sliding window assignments + sliding_window_df = pd.DataFrame(index=adata.obs.index) + if sliding_window_key in adata.obs: + logg.warning(f"Overwriting existing column '{sliding_window_key}' in adata.obs.") + for i, lib in enumerate(libraries): + lib_mask = adata.obs[library_key] == lib + lib_coords = coords.loc[lib_mask] # precalculate windows windows = _calculate_window_corners( - min_x=min_x, - max_x=max_x, - min_y=min_y, - max_y=max_y, - window_size=window_size, + fovs_x_range[i], + fovs_y_range[i], + window_size=window_sizes[i], overlap=overlap, drop_partial_windows=drop_partial_windows, ) - lib_key = f"{lib}_" if lib is not None else "" # assign observations to windows @@ -132,6 +207,11 @@ def sliding_window( y_start = window["y_start"] y_end = window["y_end"] + if drop_partial_windows: + # Check if the window is within the bounds + if x_end > fovs_x_range[i][0] or y_end > fovs_y_range[i][0]: + continue # Skip windows that extend beyond the region + mask = ( (lib_coords[x_col] >= x_start) & (lib_coords[x_col] <= x_end) @@ -149,7 +229,7 @@ def sliding_window( ) obs_indices = lib_coords.index[mask] sliding_window_df.loc[obs_indices, sliding_window_key] = f"{lib_key}window_{idx}" - + sliding_window_df.loc[:, sliding_window_key].fillna("out_of_window_0", inplace=True) else: col_name = f"{sliding_window_key}_{lib_key}window_{idx}" sliding_window_df.loc[obs_indices, col_name] = True @@ -157,30 +237,28 @@ def sliding_window( if overlap == 0: # create categorical variable for ordered windows + # Ensure the column is a string type sliding_window_df[sliding_window_key] = pd.Categorical( sliding_window_df[sliding_window_key], ordered=True, categories=sorted( sliding_window_df[sliding_window_key].unique(), - key=lambda x: int(x.split("_")[-1]), + key=lambda x: int(str(x).split("_")[-1]), ), ) - sliding_window_df[x_col] = coords[x_col] - 
sliding_window_df[y_col] = coords[y_col] - if copy: return sliding_window_df - for col_name, col_data in sliding_window_df.items(): - _save_data(adata, attr="obs", key=col_name, data=col_data) + sliding_window_df = sliding_window_df.loc[adata.obs.index] + if "temp_fov" in adata.obs.columns: + del adata.obs["temp_fov"] + _save_data(adata, attr="obs", key=sliding_window_key, data=sliding_window_df[sliding_window_key]) def _calculate_window_corners( - min_x: int, - max_x: int, - min_y: int, - max_y: int, - window_size: int, + x_range: int, + y_range: int, + window_size: int = None, overlap: int = 0, drop_partial_windows: bool = False, ) -> pd.DataFrame: @@ -210,31 +288,104 @@ def _calculate_window_corners( ------- windows: pandas DataFrame with columns ['x_start', 'x_end', 'y_start', 'y_end'] """ + x_window_size, y_window_size = window_size + if overlap < 0: raise ValueError("Overlap must be non-negative.") - if overlap >= window_size: + if overlap >= x_window_size or overlap >= y_window_size: raise ValueError("Overlap must be less than the window size.") - x_step = window_size - overlap - y_step = window_size - overlap + max_x, min_x = x_range + max_y, min_y = y_range + + x_step = x_window_size - overlap + y_step = y_window_size - overlap - # Generate starting points - x_starts = np.arange(min_x, max_x, x_step) - y_starts = np.arange(min_y, max_y, y_step) + # Align min_x and min_y to ensure that the first window starts properly + aligned_min_x = min_x - (min_x % x_window_size) if min_x % x_window_size != 0 else min_x + aligned_min_y = min_y - (min_y % y_window_size) if min_y % y_window_size != 0 else min_y + + # Generate starting points starting from the aligned minimum values + x_starts = np.arange(aligned_min_x, max_x, x_step) + y_starts = np.arange(aligned_min_y, max_y, y_step) # Create all combinations of x and y starting points starts = list(product(x_starts, y_starts)) windows = pd.DataFrame(starts, columns=["x_start", "y_start"]) - windows["x_end"] = 
windows["x_start"] + window_size - windows["y_end"] = windows["y_start"] + window_size + windows["x_end"] = windows["x_start"] + x_window_size + windows["y_end"] = windows["y_start"] + y_window_size - # Adjust windows that extend beyond the bounds if not drop_partial_windows: windows["x_end"] = windows["x_end"].clip(upper=max_x) windows["y_end"] = windows["y_end"].clip(upper=max_y) else: valid_windows = (windows["x_end"] <= max_x) & (windows["y_end"] <= max_y) windows = windows[valid_windows] - windows = windows.reset_index(drop=True) return windows[["x_start", "x_end", "y_start", "y_end"]] + + +def _optimize_tile_size( + L: int, W: int, A_min: float | None = None, A_max: float | None = None, square: bool = False, split_line: str = "v" +) -> tuple: + """ + This function optimizes the tile size for covering a rectangle of dimensions LxW. + It returns a tuple (x, y) where x and y are the dimensions of the optimal tile. + + Parameters: + - L (int): Length of the rectangle. + - W (int): Width of the rectangle. + - A_min (int, optional): Minimum allowed area of each tile. If None, no minimum area limit is applied. + - A_max (int, optional): Maximum allowed area of each tile. If None, no maximum area limit is applied. + - square (bool, optional): If True, tiles will be square (x = y). + + Returns: + - tuple: (x, y) representing the optimal tile dimensions. 
+ """ + best_tile_size = None + min_uncovered_area = float("inf") + area = L * W + if square: + # Calculate square tiles + max_side = min(int(math.sqrt(A_max)), int(min(L, W))) if A_max else int(min(L, W)) + min_side = int(math.sqrt(A_min)) if A_min else 1 + # Try all square tile sizes from min_side to max_side + for side in range(min_side, max_side + 1): + if (A_min and side * side < A_min) or (A_max and side * side > A_max): + continue # Skip sizes that are out of the area limits + + # Calculate number of tiles that fit in the rectangle + num_tiles_x = max(L // side, 1) + num_tiles_y = max(W // side, 1) + uncovered_area = area - (num_tiles_x * num_tiles_y * side * side) + + # Track the best tile size + if uncovered_area < min_uncovered_area: + min_uncovered_area = uncovered_area + best_tile_size = (side, side) + else: + # For non-square tiles, optimize both dimensions independently + if split_line == "v": + max_tile_length = A_max / W if A_max else int(L) + max_tile_width = W + min_tile_length = A_min / W + min_tile_width = W + if split_line == "h": + max_tile_length = L + max_tile_width = A_max / L if A_max else 0 + min_tile_width = A_min / L + min_tile_length = L + # Try all combinations of width and height within the bounds + for width in range(int(min_tile_width), int(max_tile_width) + 1): + for height in range(int(min_tile_length), int(max_tile_length) + 1): + if (A_min and width * height < A_min) or (A_max and width * height > A_max): + continue # Skip sizes that are out of the area limits + # Calculate number of tiles that fit in the rectangle + num_tiles_x = max(L // width, 1) + num_tiles_y = max(W // height, 1) + uncovered_area = area - (num_tiles_x * num_tiles_y * width * height) + # Track the best tile size (minimizing uncovered area) + if uncovered_area < min_uncovered_area: + min_uncovered_area = uncovered_area + best_tile_size = (height, width) + return best_tile_size