diff --git a/src/scanpy/plotting/_tools/scatterplots.py b/src/scanpy/plotting/_tools/scatterplots.py index 30356d8b04..840c13a3eb 100644 --- a/src/scanpy/plotting/_tools/scatterplots.py +++ b/src/scanpy/plotting/_tools/scatterplots.py @@ -294,9 +294,9 @@ def embedding( # noqa: PLR0912, PLR0913, PLR0915 elif sort_order and color_type == "cat": # Null points go on bottom order = np.argsort(~pd.isnull(color_source_vector), kind="stable") - # Set orders - if isinstance(size, np.ndarray): - size = np.array(size)[order] + # Set orders — use a local to avoid cumulative reordering across + # subplots when multiple color keys are given. + _size = np.array(size)[order] if isinstance(size, np.ndarray) else size color_source_vector = color_source_vector[order] color_vector = color_vector[order] coords = basis_values[:, dims][order, :] @@ -348,10 +348,10 @@ def embedding( # noqa: PLR0912, PLR0913, PLR0915 ) else: scatter = ( - partial(ax.scatter, s=size, plotnonfinite=True) + partial(ax.scatter, s=_size, plotnonfinite=True) if scale_factor is None else partial( - circles, s=size, ax=ax, scale_factor=scale_factor + circles, s=_size, ax=ax, scale_factor=scale_factor ) # size in circles is radius ) @@ -366,7 +366,7 @@ def embedding( # noqa: PLR0912, PLR0913, PLR0915 # with some transparency. bg_width, gap_width = outline_width - point = np.sqrt(size) + point = np.sqrt(_size) gap_size = (point + (point * gap_width) * 2) ** 2 bg_size = (np.sqrt(gap_size) + (point * bg_width) * 2) ** 2 # the default black and white colors can be changes using