Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/crested/pl/design/_enhancer_design.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def step_contribution_scores(
sequence_labels: list | None = None,
class_labels: list | None = None,
zoom_n_bases: int | None = None,
x_shift: int = 0,
ylim: tuple[float, float] | None = None,
global_ylim: Literal["all", "per_design", "per_plot"] | None = "per_plot",
method: Literal["mutagenesis", "mutagenesis_letters"] | None = None,
Expand Down Expand Up @@ -90,6 +91,8 @@ def step_contribution_scores(
highlight_kws
Keywords to use for plotting changed basepairs with :meth:`~matplotlib.axes.Axes.axvspan`.
Default is {'edgecolor': "red", 'facecolor': "none", 'linewidth' :0.5}
x_shift
Number of base pairs to shift left or right for visualizing specific subsets of the region. Only use when combined with zooming in. Default is zero.
show
Whether to show all plots or return the (list of) figure and axes instead.
width
Expand Down Expand Up @@ -195,6 +198,7 @@ def step_contribution_scores(
sequence_labels=step_labels, # Sequence labels per step
class_labels=None,
zoom_n_bases=zoom_n_bases,
x_shift=x_shift,
method=method,
sharey=sharey,
ylim=ylim,
Expand Down
16 changes: 12 additions & 4 deletions src/crested/pl/locus/_track.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def track(
coordinates: str | tuple | None = None,
class_names: Sequence[str] | str | None = None,
highlight_positions: list[tuple[int, int]] | None = None,
x_shift: int = 0,
plot_kws: dict | None = None,
highlight_kws: dict | None = None,
ax: plt.Axes | None = None,
Expand Down Expand Up @@ -51,6 +52,8 @@ def track(
highlight_kws
Keywords to use for plotting highlights with :meth:`~matplotlib.axes.Axes.axvspan`.
Default is {'color': "green", 'alpha': 0.1}
x_shift
Number of base pairs or bins to shift left or right for visualizing specific subsets of the region. Only use when combined with `zoom_n_bases`.
ax
Axis to plot values on. If not supplied, creates a figure from scratch.
width
Expand Down Expand Up @@ -104,15 +107,19 @@ def track(
@log_and_raise(ValueError)
def _check_input_params():
if scores.ndim != 2:
raise ValueError("scores must be (length) or (length, classes)")
raise ValueError(f"scores must be (length) or (length, classes), so cannot be {scores.ndim} dimensions.")
if class_idxs is not None:
for cidx in class_idxs:
if cidx > scores.shape[0]:
raise ValueError(f"class_idxs {class_idxs} is beyond your input's number of classes ({n_classes}).")
if cidx > n_data_classes:
raise ValueError(f"class idx {cidx} from class_idxs is beyond your input's number of classes ({n_data_classes}).")
if class_names is not None and cidx >= len(class_names):
raise ValueError(f"class_idxs {cidx} is beyond the size of class_names ({len(class_names)}).")
if ax is not None and n_classes > 1:
raise ValueError("ax can only be set if plotting one class. Please pick one class in `class_idxs` or pass unidimensional data.")
if zoom_n_bases is not None:
temp_start_idx = n_bins//2 - zoom_n_bases//2 + x_shift
if temp_start_idx < 0 or (temp_start_idx + zoom_n_bases) > n_bins:
raise ValueError(f"x_shift {x_shift} with zoom_n_bases {zoom_n_bases} is shifting the zoom beyond the data limits ({n_bins} bins/bp).")

# Remove singleton dimensions like single-sequence batch dims
scores = scores.squeeze()
Expand All @@ -132,6 +139,7 @@ def _check_input_params():
highlight_positions = [highlight_positions]

n_bins = scores.shape[0]
n_data_classes = scores.shape[1]
n_classes = len(class_idxs)

_check_input_params()
Expand Down Expand Up @@ -193,7 +201,7 @@ def _check_input_params():
start_idx = n_bins//2 - zoom_n_bases//2
if coordinates is not None:
start_idx += start
ax.set_xlim(start_idx, start_idx+zoom_n_bases)
ax.set_xlim(start_idx+x_shift, start_idx+zoom_n_bases+x_shift)
# Reverse x axis if negative strand info
if coordinates is not None and strand == "-":
ax.xaxis.set_inverted(True)
Expand Down
1 change: 1 addition & 0 deletions tests/test_pl.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ def test_locus_track_single():
scores=scores,
class_idxs=None,
zoom_n_bases=90,
x_shift=5,
coordinates=range_values,
highlight_positions=(10, 20),
plot_kws={'alpha': 0.5},
Expand Down
Loading