diff --git a/src/scanpy/plotting/_anndata.py b/src/scanpy/plotting/_anndata.py index 65d63cc1d6..8929532c0d 100755 --- a/src/scanpy/plotting/_anndata.py +++ b/src/scanpy/plotting/_anndata.py @@ -2,6 +2,7 @@ from __future__ import annotations +import math from collections import OrderedDict from collections.abc import Collection, Mapping, Sequence from itertools import pairwise, product @@ -51,7 +52,6 @@ from matplotlib.axes import Axes from matplotlib.colors import Colormap, ListedColormap, Normalize from numpy.typing import NDArray - from seaborn import FacetGrid from seaborn.matrix import ClusterGrid from .._utils import Empty @@ -753,13 +753,14 @@ def violin( # noqa: PLR0912, PLR0913, PLR0915 xlabel: str = "", ylabel: str | Sequence[str] | None = None, rotation: float | None = None, + ncols: int | None = None, show: bool | None = None, ax: Axes | None = None, # deprecated save: bool | str | None = None, scale: DensityNorm | Empty = _empty, **kwds, -) -> Axes | FacetGrid | None: +) -> Axes | Sequence[Axes] | None: """Violin plot. Wraps :func:`seaborn.violinplot` for :class:`~anndata.AnnData`. @@ -802,17 +803,21 @@ def violin( # noqa: PLR0912, PLR0913, PLR0915 Label of the x axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown. ylabel - Label of the y axis. If `None` and `groupby` is `None`, defaults - to `'value'`. If `None` and `groubpy` is not `None`, defaults to `keys`. + Label of the y axis. rotation Rotation of xtick labels. + ncols + Number of columns for arranging multiple plots. + If `None`, all panels are placed in a single row. {show_save_ax} **kwds Are passed to :func:`~seaborn.violinplot`. Returns ------- - A :class:`~matplotlib.axes.Axes` object if `ax` is `None` else `None`. + Axes or list of Axes + If `show=False`, returns the `matplotlib` Axes object(s) used for + plotting. If `show=True`, returns `None`. Examples -------- @@ -873,13 +878,19 @@ def violin( # noqa: PLR0912, PLR0913, PLR0915 del scale if isinstance(ylabel, str | NoneType): - ylabel = [ylabel] * (1 if groupby is None else len(keys)) + ylabel = "" if ylabel is None else ylabel + if groupby is None and multi_panel: + ylabel = [ylabel] * len(keys) + else: + ylabel = [ylabel] * (1 if groupby is None else len(keys)) + if groupby is None: - if len(ylabel) != 1: - msg = f"Expected number of y-labels to be `1`, found `{len(ylabel)}`." + expected = len(keys) if multi_panel else 1 + if len(ylabel) != expected: + msg = f"Expected {expected} y-labels, got {len(ylabel)}." raise ValueError(msg) elif len(ylabel) != len(keys): - msg = f"Expected number of y-labels to be `{len(keys)}`, found `{len(ylabel)}`." + msg = f"Expected {len(keys)} y-labels, got {len(ylabel)}." raise ValueError(msg) if groupby is not None: @@ -911,56 +922,62 @@ def violin( # noqa: PLR0912, PLR0913, PLR0915 x = groupby ys = keys - if multi_panel and groupby is None and len(ys) == 1: - # This is a quick and dirty way for adapting scales across several - # keys if groupby is None. - y = ys[0] - - g: sns.axisgrid.FacetGrid = sns.catplot( - y=y, - data=obs_tidy, - kind="violin", - density_norm=density_norm, - col=x, - col_order=keys, - sharey=False, - cut=0, - inner=None, - **kwds, - ) + # set default violin parameters + kwds.setdefault("cut", 0) + kwds.setdefault("inner") - if stripplot: - grouped_df = obs_tidy.groupby(x, observed=True) - for ax_id, key in zip(range(g.axes.shape[1]), keys, strict=True): - sns.stripplot( - y=y, - data=grouped_df.get_group(key), - jitter=jitter, - size=size, - color="black", - ax=g.axes[0, ax_id], - ) - if log: - g.set(yscale="log") - g.set_titles(col_template="{col_name}").set_xlabels("") - if rotation is not None: - for ax_base in g.axes[0]: - ax_base.tick_params(axis="x", labelrotation=rotation) - else: - # set by default the violin plot cut=0 to limit the extend - # of the violin plot (see stacked_violin code) for more info. - kwds.setdefault("cut", 0) - kwds.setdefault("inner") + if ax is None: + panels = keys if multi_panel else ["x"] if groupby is None else keys - if ax is None: + if ncols is not None and len(panels) > 1: + n_panels = len(panels) + n_rows = math.ceil(n_panels / ncols) + _fig, axs = plt.subplots(n_rows, ncols) + axs = axs.flatten()[:n_panels] + else: axs, _, _, _ = setup_axes( ax, - panels=["x"] if groupby is None else keys, + panels=panels, show_ticks=True, right_margin=0.3, ) - else: - axs = [ax] + else: + axs = [ax] + + if len(axs) > 1: + axs[0].figure.subplots_adjust(hspace=0.5, wspace=0.4) + + if groupby is None and multi_panel: + for ax_base, key, ylab in zip(axs, keys, ylabel, strict=True): + sns.violinplot( + y=key, + data=obs_df, + orient="vertical", + density_norm=density_norm, + ax=ax_base, + **kwds, + ) + + if stripplot: + sns.stripplot( + y=key, + data=obs_df, + jitter=jitter, + color="black", + size=size, + ax=ax_base, + ) + + ax_base.set_xlabel("") + ax_base.set_title(str(key).replace("_", " ")) + if ylab is not None: + ax_base.set_ylabel(ylab) + if log: + ax_base.set_yscale("log") + if rotation is not None: + ax_base.tick_params(axis="x", labelrotation=rotation) + + else: for ax_base, y, ylab in zip(axs, ys, ylabel, strict=True): sns.violinplot( x=x, @@ -972,6 +989,7 @@ def violin( # noqa: PLR0912, PLR0913, PLR0915 ax=ax_base, **kwds, ) + if stripplot: sns.stripplot( x=x, @@ -983,6 +1001,11 @@ def violin( # noqa: PLR0912, PLR0913, PLR0915 size=size, ax=ax_base, ) + + if multi_panel or groupby is not None: + ax_base.set_title(str(y).replace("_", " ")) + else: + ax_base.set_title("") if xlabel == "" and groupby is not None and rotation is None: xlabel = groupby.replace("_", " ") ax_base.set_xlabel(xlabel) @@ -996,8 +1019,6 @@ def violin( # noqa: PLR0912, PLR0913, PLR0915 _utils.savefig_or_show("violin", show=show, save=save) if show: return None - if multi_panel and groupby is None and len(ys) == 1: - return g if len(axs) == 1: return axs[0] return axs