Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# %%
import numpy as np

import ultraplot as uplt

rng = np.random.default_rng(21)
x = np.linspace(0, 5, 300)

layout = [[1, 2, 5], [3, 4, 5]]
# layout = [[1, 2], [4, 4]]
fig, axs = uplt.subplots(layout, journal="nat1")
for i, ax in enumerate(axs):
trend = (i + 1) * 0.2
y = np.exp(-0.4 * x) * np.sin(2 * x + i * 0.6) + trend
y += 0.05 * rng.standard_normal(x.size)
ax.plot(x, y, lw=2)
ax.fill_between(x, y - 0.15, y + 0.15, alpha=0.2)
ax.set_title(f"Condition {i + 1}")
# Share first 2 plots top left
axs[:2].format(
xlabel="Time (days)",
)
axs[1, :2].format(xlabel="Time 2 (days)")
axs[[-1]].format(xlabel="Time 3 (days)")
axs.format(
xlabel="Time (days)",
ylabel="Normalized response",
abc=True,
abcloc="ul",
suptitle="Spanning labels with shared axes",
grid=False,
)
axs.format(abc=1, abcloc="ol")
axs.format(xlabel="test")
fig.save("test.png")

fig.show()
uplt.show(block=1)
22 changes: 22 additions & 0 deletions ultraplot/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1877,6 +1877,10 @@ def _align_axis_label(self, x):
if ax in seen or pos not in ("bottom", "left"):
continue # already aligned or cannot align
axs = ax._get_span_axes(pos, panels=False) # returns panel or main axes
if self._has_share_label_groups(x) and any(
self._is_share_label_group_member(axi, x) for axi in axs
):
continue # explicit label groups override default spanning
if any(getattr(ax, "_share" + x) for ax in axs):
continue # nothing to align or axes have parents
seen.update(axs)
Expand Down Expand Up @@ -2523,6 +2527,18 @@ def format(
for cls, sig in paxes.Axes._format_signatures.items()
}
classes = set() # track used dictionaries

def _axis_has_share_label_text(ax, axis):
groups = self._share_label_groups.get(axis, {})
for group in groups.values():
if ax in group["axes"] and str(group.get("text", "")).strip():
return True
return False

def _axis_has_label_text(ax, axis):
text = ax.get_xlabel() if axis == "x" else ax.get_ylabel()
return bool(text and text.strip())

for number, ax in enumerate(axs):
number = number + 1 # number from 1
store_old_number = ax.number
Expand All @@ -2534,6 +2550,12 @@ def format(
for key, value in kw.items()
if isinstance(ax, cls) and not classes.add(cls)
}
if kw.get("xlabel") is not None and self._has_share_label_groups("x"):
if _axis_has_share_label_text(ax, "x") or _axis_has_label_text(ax, "x"):
kw.pop("xlabel", None)
if kw.get("ylabel") is not None and self._has_share_label_groups("y"):
if _axis_has_share_label_text(ax, "y") or _axis_has_label_text(ax, "y"):
kw.pop("ylabel", None)
ax.format(rc_kw=rc_kw, rc_mode=rc_mode, skip_figure=True, **kw, **kwargs)
ax.number = store_old_number
# Warn unused keyword argument(s)
Expand Down
4 changes: 4 additions & 0 deletions ultraplot/gridspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1767,6 +1767,10 @@ def format(self, **kwargs):
all_axes = set(self.figure._subplot_dict.values())
is_subset = bool(axes) and all_axes and set(axes) != all_axes
if len(self) > 1:
if not is_subset and share_xlabels is None and xlabel is not None:
self.figure._clear_share_label_groups(target="x")
if not is_subset and share_ylabels is None and ylabel is not None:
self.figure._clear_share_label_groups(target="y")
if share_xlabels is False:
self.figure._clear_share_label_groups(self, target="x")
if share_ylabels is False:
Expand Down
35 changes: 35 additions & 0 deletions ultraplot/tests/test_subplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,41 @@ def test_subset_share_xlabels_implicit_column():
uplt.close(fig)


def test_subset_share_xlabels_overridden_by_global_format():
fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False)
bottom = ax[1, :]
bottom.format(xlabel="Bottom-row X")
ax[0, 0].format(xlabel="Top-left X")
ax.format(xlabel="Global X")

fig.canvas.draw()

assert ax[0, 0].get_xlabel() == "Global X"
assert ax[0, 1].get_xlabel() == "Global X"
assert not any(
lab.get_text() == "Bottom-row X" for lab in fig._supxlabel_dict.values()
)

uplt.close(fig)


def test_full_grid_clears_share_label_groups():
fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False)
bottom = ax[1, :]
bottom.format(xlabel="Bottom-row X")
ax.format(xlabel="Global X")

fig.canvas.draw()

assert not fig._has_share_label_groups("x")
assert not any(
lab.get_text() == "Bottom-row X" for lab in fig._supxlabel_dict.values()
)
assert all(axi.get_xlabel() == "Global X" for axi in ax)

uplt.close(fig)


def test_subset_share_ylabels_implicit_row():
fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False)
top = ax[0, :]
Expand Down