diff --git a/src/crested/pl/design/_enhancer_design.py b/src/crested/pl/design/_enhancer_design.py index 1780e142..b6ea6438 100644 --- a/src/crested/pl/design/_enhancer_design.py +++ b/src/crested/pl/design/_enhancer_design.py @@ -48,6 +48,7 @@ def step_contribution_scores( sequence_labels: list | None = None, class_labels: list | None = None, zoom_n_bases: int | None = None, + x_shift: int = 0, ylim: tuple[float, float] | None = None, global_ylim: Literal["all", "per_design", "per_plot"] | None = "per_plot", method: Literal["mutagenesis", "mutagenesis_letters"] | None = None, @@ -90,6 +91,8 @@ def step_contribution_scores( highlight_kws Keywords to use for plotting changed basepairs with :meth:`~matplotlib.axes.Axes.axvspan`. Default is {'edgecolor': "red", 'facecolor': "none", 'linewidth' :0.5} + x_shift + Number of base pairs to shift left or right for visualizing specific subsets of the region. Only use when combined with zooming in. Default is zero. show Whether to show all plots or return the (list of) figure and axes instead. width @@ -195,6 +198,7 @@ def step_contribution_scores( sequence_labels=step_labels, # Sequence labels per step class_labels=None, zoom_n_bases=zoom_n_bases, + x_shift=x_shift, method=method, sharey=sharey, ylim=ylim, diff --git a/src/crested/pl/locus/_track.py b/src/crested/pl/locus/_track.py index 98f04d17..057fdae7 100644 --- a/src/crested/pl/locus/_track.py +++ b/src/crested/pl/locus/_track.py @@ -20,6 +20,7 @@ def track( coordinates: str | tuple | None = None, class_names: Sequence[str] | str | None = None, highlight_positions: list[tuple[int, int]] | None = None, + x_shift: int = 0, plot_kws: dict | None = None, highlight_kws: dict | None = None, ax: plt.Axes | None = None, @@ -51,6 +52,8 @@ def track( highlight_kws Keywords to use for plotting highlights with :meth:`~matplotlib.axes.Axes.axvspan`. Default is {'color': "green", 'alpha': 0.1} + x_shift + Number of base pairs or bins to shift left or right for visualizing specific subsets of the region. Only use when combined with `zoom_n_bases`. ax Axis to plot values on. If not supplied, creates a figure from scratch. width @@ -104,15 +107,19 @@ def track( @log_and_raise(ValueError) def _check_input_params(): if scores.ndim != 2: - raise ValueError("scores must be (length) or (length, classes)") + raise ValueError(f"scores must be (length) or (length, classes), so cannot be {scores.ndim} dimensions.") if class_idxs is not None: for cidx in class_idxs: - if cidx > scores.shape[0]: - raise ValueError(f"class_idxs {class_idxs} is beyond your input's number of classes ({n_classes}).") + if cidx > n_data_classes: + raise ValueError(f"class idx {cidx} from class_idxs is beyond your input's number of classes ({n_data_classes}).") if class_names is not None and cidx >= len(class_names): raise ValueError(f"class_idxs {cidx} is beyond the size of class_names ({len(class_names)}).") if ax is not None and n_classes > 1: raise ValueError("ax can only be set if plotting one class. Please pick one class in `class_idxs` or pass unidimensional data.") + if zoom_n_bases is not None: + temp_start_idx = n_bins//2 - zoom_n_bases//2 + x_shift + if temp_start_idx < 0 or (temp_start_idx + zoom_n_bases) > n_bins: + raise ValueError(f"x_shift {x_shift} with zoom_n_bases {zoom_n_bases} is shifting the zoom beyond the data limits ({n_bins} bins/bp).") # Remove singleton dimensions like single-sequence batch dims scores = scores.squeeze() @@ -132,6 +139,7 @@ def _check_input_params(): highlight_positions = [highlight_positions] n_bins = scores.shape[0] + n_data_classes = scores.shape[1] n_classes = len(class_idxs) _check_input_params() @@ -193,7 +201,7 @@ def _check_input_params(): start_idx = n_bins//2 - zoom_n_bases//2 if coordinates is not None: start_idx += start - ax.set_xlim(start_idx, start_idx+zoom_n_bases) + ax.set_xlim(start_idx+x_shift, start_idx+zoom_n_bases+x_shift) # Reverse x axis if negative strand info if coordinates is not None and strand == "-": ax.xaxis.set_inverted(True) diff --git a/tests/test_pl.py b/tests/test_pl.py index cd7a0d44..6a64f752 100644 --- a/tests/test_pl.py +++ b/tests/test_pl.py @@ -587,6 +587,7 @@ def test_locus_track_single(): scores=scores, class_idxs=None, zoom_n_bases=90, + x_shift=5, coordinates=range_values, highlight_positions=(10, 20), plot_kws={'alpha': 0.5},