Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ This project adheres to [Semantic Versioning](http://semver.org/).

## Unreleased

### Fixed
- Bug was that function marked the axis to be connected, but the trace_kwargs still had unique axes [[#5427](https://github.com/plotly/plotly.py/issues/5427)]
- Change: change the keyword argument for the trace, so that when the graph is initialized, it uses the correct axis instead of the autogenerated one
- Note: The program generates a unique axis label for each subgraph, and then overwrites the label (under this fix)

### Fixed
- Fix issue where user-specified `color_continuous_scale` was ignored when template had `autocolorscale=True` [[#5439](https://github.com/plotly/plotly.py/pull/5439)], with thanks to @antonymilne for the contribution!
- Update tests to be compatible with numpy 2.4 [[#5522](https://github.com/plotly/plotly.py/pull/5522)], with thanks to @thunze for the contribution!
Expand Down
297 changes: 194 additions & 103 deletions plotly/_subplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@
# properties.
# Note that this set does not contain `xaxis`/`yaxis` because these behave a
# little differently.
from __future__ import annotations
import collections

from typing import Literal, Optional, Tuple, TypedDict, TYPE_CHECKING
if TYPE_CHECKING:
from plotly.graph_objects import Layout, XAxis

_single_subplot_types = {"scene", "geo", "polar", "ternary", "map", "mapbox"}
_subplot_types = set.union(_single_subplot_types, {"xy", "domain"})

Expand All @@ -31,6 +36,17 @@
"SubplotRef", ("subplot_type", "layout_keys", "trace_kwargs")
)

class SubplotSpec(TypedDict):
type : Literal['xy', 'scene', 'polar', 'ternary', 'map', 'mapbox', 'domain'] | str
secondary_y : bool
colspan : int
rowspan : int
# NOTE: that this is the dictionary as defined by the documentation, so the ambiguous name 'l' can't be changed without changing the documentation
l : float # noqa: E741
r : float
t : float
b : float


def _get_initial_max_subplot_ids():
max_subplot_ids = {subplot_type: 0 for subplot_type in _single_subplot_types}
Expand Down Expand Up @@ -746,19 +762,10 @@ def _check_hv_spacing(dimsize, spacing, name, dimvarname, dimname):
)
grid_ref[r][c] = subplot_refs

_configure_shared_axes(layout, grid_ref, specs, "x", shared_xaxes, row_dir, False)
_configure_shared_axes(layout, grid_ref, specs, "y", shared_yaxes, row_dir, False)

any_secondary_y = any(
spec["secondary_y"]
for spec_row in specs
for spec in spec_row
if spec is not None
)
if any_secondary_y:
_configure_shared_axes(
layout, grid_ref, specs, "y", shared_yaxes, row_dir, True
)
_configure_shared_axes(layout, grid_ref, specs, "x", shared_xaxes, row_dir)
_configure_shared_axes(layout, grid_ref, specs, "y", shared_yaxes, row_dir)


# Build inset reference
# ---------------------
Expand Down Expand Up @@ -889,99 +896,183 @@ def _check_hv_spacing(dimsize, spacing, name, dimvarname, dimname):

return figure


def _configure_shared_axes(
layout, grid_ref, specs, x_or_y, shared, row_dir, secondary_y
):
rows = len(grid_ref)
cols = len(grid_ref[0])

layout_key_ind = ["x", "y"].index(x_or_y)

if row_dir < 0:
rows_iter = range(rows - 1, -1, -1)
else:
rows_iter = range(rows)

if secondary_y:
cols_iter = range(cols - 1, -1, -1)
axis_index = 1
else:
cols_iter = range(cols)
axis_index = 0

def update_axis_matches(first_axis_id, subplot_ref, spec, remove_label):
if subplot_ref is None:
return first_axis_id

if x_or_y == "x":
span = spec["colspan"]
else:
span = spec["rowspan"]

if subplot_ref.subplot_type == "xy" and span == 1:
if first_axis_id is None:
first_axis_name = subplot_ref.layout_keys[layout_key_ind]
first_axis_id = first_axis_name.replace("axis", "")
else:
axis_name = subplot_ref.layout_keys[layout_key_ind]
axis_to_match = layout[axis_name]
axis_to_match.matches = first_axis_id
if remove_label:
axis_to_match.showticklabels = False

return first_axis_id

if shared == "columns" or (x_or_y == "x" and shared is True):
for c in cols_iter:
first_axis_id = None
ok_to_remove_label = x_or_y == "x"
for r in rows_iter:
if not grid_ref[r][c]:
continue
if axis_index >= len(grid_ref[r][c]):
continue
subplot_ref = grid_ref[r][c][axis_index]
spec = specs[r][c]
first_axis_id = update_axis_matches(
first_axis_id, subplot_ref, spec, ok_to_remove_label
)

elif shared == "rows" or (x_or_y == "y" and shared is True):
for r in rows_iter:
first_axis_id = None
ok_to_remove_label = x_or_y == "y"
for c in cols_iter:
if not grid_ref[r][c]:
continue
if axis_index >= len(grid_ref[r][c]):
continue
subplot_ref = grid_ref[r][c][axis_index]
spec = specs[r][c]
first_axis_id = update_axis_matches(
first_axis_id, subplot_ref, spec, ok_to_remove_label
)

elif shared == "all":
first_axis_id = None
for ri, r in enumerate(rows_iter):
for c in cols_iter:
if not grid_ref[r][c]:
continue
if axis_index >= len(grid_ref[r][c]):
continue
subplot_ref = grid_ref[r][c][axis_index]
spec = specs[r][c]

if x_or_y == "y":
ok_to_remove_label = c < cols - 1 if secondary_y else c > 0
else:
ok_to_remove_label = ri > 0 if row_dir > 0 else r < rows - 1

first_axis_id = update_axis_matches(
first_axis_id, subplot_ref, spec, ok_to_remove_label
)
layout : Layout,
grid_ref : Tuple[Tuple[SubplotRef]],
specs : Tuple[Tuple[SubplotSpec]],
x_or_y : Literal['x', 'y'],
shared : bool | Literal['rows', 'columns', 'all'],
row_direction : Literal[1, -1]
) -> None:
'''
Sets the axes to be shared, making them use the same axis

Parameters:
-----------
layout (go.Layout) : The layout of the figure to be updating
grid_ref (Tuple[Tuple[SubplotRef]]) : The grid of subplots within the figure; grid_ref[row][column] = subplot at that coordinate
specs (Tuple[Tuple[SubplotSpec]]) : The specifications of each of the subplots within the figure; specs[row][column] = specs of the subplot at that coordinate
x_or_y ('x' | 'y') : The axis to configure
shared ('rows' | 'columns' | 'all' | bool) : The sharing mode, (True is 'columns' mode, False means no sharing) ie share the axis with all subplots in the corresponding row, column, or entire figure
row_direction (1 | -1) : The directional that the rows go
'''

row_count : int = len(grid_ref)
column_count : int = len(grid_ref[0])

axis_index : int = 0 if x_or_y == 'x' else 1

def find_label_and_index(row_order : int | Tuple[int], column_order : int | Tuple[int], trace_layer : int) -> Optional[Tuple[str, Tuple[int, int]]]:
'''
Searches the grid through the row, column order provided (doing row, then column); will only check things that appear in those lists; ONLY WORKS WITH 2D CARTESIAN SUBPLOTS AKA 'xy' TYPE SUBPLOTS

Parameters:
-----------
row_order (int | Tuple[int]): If an int, will look only at the that row index, else it will look at all of the rows in the order of the iterable
column_order (int | Tuple[int]): If an int, will only look at that column index, else it will look at all of the columns in the order of the iterable
trace_layer (int) : Which axis of traces to look at [Since there can be multiple traces on one subplot ie the secondary_y traces are on layer 1]
Return:
-------
Returns (Label : str, (Row : int, Column : int)): returning the label found, and the row and column it was found at (uses x_or_y to determine which of the axes' labels to pull)
Return (None): No label was found
'''

# Turn them into lists with one element, so that both row_order and column_order are iterables
row_order : Tuple[int] = [row_order] if isinstance(row_order, int) else row_order
column_order : Tuple[int] = [column_order] if isinstance(column_order, int) else column_order


# Iterate through the rows and columns
for row in row_order:
for column in column_order:
if not grid_ref[row][column]:
continue

subplot_traces : Tuple[Optional[SubplotRef]] = grid_ref[row][column]
subplot_spec : SubplotSpec = specs[row][column]

span = subplot_spec['colspan'] if x_or_y == 'x' else subplot_spec['rowspan']
if subplot_spec['type'] != 'xy' or span != 1 or trace_layer >= len(subplot_traces):
continue

trace = subplot_traces[trace_layer]
if trace is None or trace.subplot_type != 'xy':
continue

label_name : str = trace.layout_keys[axis_index]
label : str = label_name.replace("axis", "")
return label, (row, column)
return None


def update_trace_axis(axis_label : str, row : int, column : int, trace_layer : int, can_reassign_axis : bool, can_hide_ticks : bool, can_match_axis : bool) -> None:
'''
Updates the specific subplot trace at the given row and column with the given label, and removes the label visibility if necessary; ONLY WORKS WITH 2D CARTESIAN SUBPLOTS AKA 'xy' TYPE SUBPLOTS

Parameters:
-----------
axis_label (str) : The label to make the axis match (uses the x_or_y value to determine which of the axes to change), if there is a subplot at the given location
row (int) : The row of the subplot within grid_ref to update
column (int) : The column of the subplot within grid_ref to update
trace_layer (int) : Which axis of traces to look at [Since there can be multiple traces on one subplot ie the secondary_y traces are on layer 1]
can_reassign_axis (bool): If True, can change the unique axis for the shared axis in the trace keywords, otherwise, will keep using the axis name it already has
can_hide_ticks (bool): If the function is allowed to hide the ticks (if True, it will hide the ticks, if False, it will leave the ticks as their current state)
can_match_axis (bool): If the axis should be marked as a match to the axis label
'''

if not grid_ref[row][column] or specs[row][column] is None:
return

subplot_traces : Tuple[Optional[SubplotRef]] = grid_ref[row][column]
subplot_spec : SubplotSpec = specs[row][column]

span = subplot_spec['colspan'] if x_or_y == 'x' else subplot_spec['rowspan']
if subplot_spec['type'] != 'xy' or span != 1 or trace_layer >= len(subplot_traces):
return

trace : Optional[SubplotRef] = subplot_traces[trace_layer]

if trace is None or trace.subplot_type != 'xy' or span != 1:
return

axis_name : str = trace.layout_keys[axis_index]
# axis_dimension : str = 'xaxis' if x_or_y == 'x' else 'yaxis'
axis : XAxis = layout[axis_name]

if can_match_axis:
axis.matches = axis_label

if can_hide_ticks:
axis.showticklabels = False

if can_reassign_axis:
# trace.trace_kwargs[axis_dimension] = axis_label
pass

def columns_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int):
for column in columns:
# Get the label used by all the rows in the column
label_data = find_label_and_index(rows, column, trace_layer)
if label_data is None:
continue
axis_label, (label_row, _) = label_data

# Set all of the values in the column
for row in rows:
subplot_spec : SubplotSpec = specs[row][column]
can_reassign_axis : bool = (x_or_y != 'y' or not subplot_spec["secondary_y"]) # Every subplot in the same column should share the same axis if in columns mode
can_match_axis : bool = (row != label_row)
can_hide_ticks : bool = can_match_axis and x_or_y == 'x' # Sharing column wise can only hide x-axis; still need all of the different y-axis across plots in the same columns

update_trace_axis(axis_label, row, column, trace_layer, can_reassign_axis, can_hide_ticks, can_match_axis)


def rows_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int):
for row in rows:
label_data = find_label_and_index(row, columns, trace_layer)
if label_data is None:
continue
axis_label, (_, label_column) = label_data

for column in columns:
spec : SubplotSpec = specs[row][column]
can_reassign_axis : bool = (x_or_y != 'y' or not spec['secondary_y'])
can_match_axis : bool = (column != label_column)
can_hide_ticks : bool = can_match_axis and x_or_y == 'y' # Sharing row wise can only hide y-axis; still need all of the different x-axis across plots in the same row

update_trace_axis(axis_label, row, column, trace_layer, can_reassign_axis, can_hide_ticks, can_match_axis)

def all_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int):
label_data = find_label_and_index(rows, columns, trace_layer)
if label_data is None:
return
axis_label, (label_row, label_column) = label_data

for row in rows:
for column in columns:
spec : SubplotSpec = specs[row][column]
can_reassign_axis : bool = (x_or_y != 'y' or not spec['secondary_y'])
can_match_axis : bool = (row != label_row or column != label_column)
can_hide_ticks : bool = not ((row == label_row and x_or_y == 'x') or (column == label_column and x_or_y == 'y')) # The x-axis is across the first row, and the y-axis is along the first column
update_trace_axis(axis_label, row, column, trace_layer, can_reassign_axis, can_hide_ticks, can_match_axis)


rows : Tuple[int] = tuple(range(row_count - 1, -1, -1)) if row_direction < 0 else tuple(range(row_count))
columns : Tuple[int] = tuple(range(column_count))
BASE_TRACE_LAYER = 0
SECOND_Y_LAYER = 1
match(shared, x_or_y):
case ('columns', _) | (True, 'x'): # If columns mode, or shared and x
columns_mode(rows, columns, BASE_TRACE_LAYER)
columns_mode(tuple(reversed(rows)), columns, SECOND_Y_LAYER)
case ('rows', _) | (True, 'y'): # If rows mode, or shared and y
rows_mode(rows, columns, BASE_TRACE_LAYER)
rows_mode(rows, tuple(reversed(columns)), SECOND_Y_LAYER)
case ('all', _): # If all mode
all_mode(rows, columns, BASE_TRACE_LAYER)
all_mode(tuple(reversed(rows)), tuple(reversed(columns)), SECOND_Y_LAYER)
case _: # If reached the other case
return

def _init_subplot_xy(layout, secondary_y, x_domain, y_domain, max_subplot_ids=None):
if max_subplot_ids is None:
Expand Down
10 changes: 6 additions & 4 deletions tests/test_core/test_subplots/test_make_subplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -1465,6 +1465,8 @@ def test_subplot_titles_shared_axes_rows_columns(self):
shared_xaxes="rows",
shared_yaxes="columns",
)
print(f'Expected {expected}')
print(f'Actual: {fig}')
self.assertEqual(fig.to_plotly_json(), expected.to_plotly_json())

def test_subplot_titles_irregular_layout(self):
Expand Down Expand Up @@ -1848,8 +1850,8 @@ def test_secondary_y_subplots(self):
fig.add_scatter(y=[0, 2, 4], name="Fifth", row=2, col=1)
fig.add_scatter(y=[2, 1, 3], name="Sixth", row=2, col=1, secondary_y=True)

fig.add_scatter(y=[2, 4, 0], name="Fifth", row=2, col=2)
fig.add_scatter(y=[2, 3, 6], name="Sixth", row=2, col=2, secondary_y=True)
fig.add_scatter(y=[2, 4, 0], name="Seventh", row=2, col=2)
fig.add_scatter(y=[2, 3, 6], name="Eighth", row=2, col=2, secondary_y=True)

fig.update_traces(uid=None)

Expand Down Expand Up @@ -1899,14 +1901,14 @@ def test_secondary_y_subplots(self):
"yaxis": "y6",
},
{
"name": "Fifth",
"name": "Seventh",
"type": "scatter",
"xaxis": "x4",
"y": [2, 4, 0],
"yaxis": "y7",
},
{
"name": "Sixth",
"name": "Eighth",
"type": "scatter",
"xaxis": "x4",
"y": [2, 3, 6],
Expand Down
Loading