diff --git a/test.py b/test.py new file mode 100644 index 00000000..c546269f --- /dev/null +++ b/test.py @@ -0,0 +1,38 @@ +# %% +import numpy as np + +import ultraplot as uplt + +rng = np.random.default_rng(21) +x = np.linspace(0, 5, 300) + +layout = [[1, 2, 5], [3, 4, 5]] +# layout = [[1, 2], [4, 4]] +fig, axs = uplt.subplots(layout, journal="nat1") +for i, ax in enumerate(axs): + trend = (i + 1) * 0.2 + y = np.exp(-0.4 * x) * np.sin(2 * x + i * 0.6) + trend + y += 0.05 * rng.standard_normal(x.size) + ax.plot(x, y, lw=2) + ax.fill_between(x, y - 0.15, y + 0.15, alpha=0.2) + ax.set_title(f"Condition {i + 1}") +# Share first 2 plots top left +axs[:2].format( + xlabel="Time (days)", +) +axs[1, :2].format(xlabel="Time 2 (days)") +axs[[-1]].format(xlabel="Time 3 (days)") +axs.format( + xlabel="Time (days)", + ylabel="Normalized response", + abc=True, + abcloc="ul", + suptitle="Spanning labels with shared axes", + grid=False, +) +axs.format(abc=1, abcloc="ol") +axs.format(xlabel="test") +fig.save("test.png") + +fig.show() +uplt.show(block=1) diff --git a/ultraplot/figure.py b/ultraplot/figure.py index e7887088..b2612d6a 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -1877,6 +1877,10 @@ def _align_axis_label(self, x): if ax in seen or pos not in ("bottom", "left"): continue # already aligned or cannot align axs = ax._get_span_axes(pos, panels=False) # returns panel or main axes + if self._has_share_label_groups(x) and any( + self._is_share_label_group_member(axi, x) for axi in axs + ): + continue # explicit label groups override default spanning if any(getattr(ax, "_share" + x) for ax in axs): continue # nothing to align or axes have parents seen.update(axs) @@ -2523,6 +2527,18 @@ def format( for cls, sig in paxes.Axes._format_signatures.items() } classes = set() # track used dictionaries + + def _axis_has_share_label_text(ax, axis): + groups = self._share_label_groups.get(axis, {}) + for group in groups.values(): + if ax in group["axes"] and str(group.get("text", "")).strip(): + return True + return False + + def _axis_has_label_text(ax, axis): + text = ax.get_xlabel() if axis == "x" else ax.get_ylabel() + return bool(text and text.strip()) + for number, ax in enumerate(axs): number = number + 1 # number from 1 store_old_number = ax.number @@ -2534,6 +2550,12 @@ def format( for key, value in kw.items() if isinstance(ax, cls) and not classes.add(cls) } + if kw.get("xlabel") is not None and self._has_share_label_groups("x"): + if _axis_has_share_label_text(ax, "x") or _axis_has_label_text(ax, "x"): + kw.pop("xlabel", None) + if kw.get("ylabel") is not None and self._has_share_label_groups("y"): + if _axis_has_share_label_text(ax, "y") or _axis_has_label_text(ax, "y"): + kw.pop("ylabel", None) ax.format(rc_kw=rc_kw, rc_mode=rc_mode, skip_figure=True, **kw, **kwargs) ax.number = store_old_number # Warn unused keyword argument(s) diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 93a6343a..6f4c2d22 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -1767,6 +1767,10 @@ def format(self, **kwargs): all_axes = set(self.figure._subplot_dict.values()) is_subset = bool(axes) and all_axes and set(axes) != all_axes if len(self) > 1: + if not is_subset and share_xlabels is None and xlabel is not None: + self.figure._clear_share_label_groups(target="x") + if not is_subset and share_ylabels is None and ylabel is not None: + self.figure._clear_share_label_groups(target="y") if share_xlabels is False: self.figure._clear_share_label_groups(self, target="x") if share_ylabels is False: diff --git a/ultraplot/tests/test_subplots.py b/ultraplot/tests/test_subplots.py index eb42c79f..39eb61c3 100644 --- a/ultraplot/tests/test_subplots.py +++ b/ultraplot/tests/test_subplots.py @@ -335,6 +335,41 @@ def test_subset_share_xlabels_implicit_column(): uplt.close(fig) +def test_subset_share_xlabels_overridden_by_global_format(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + bottom = ax[1, :] + bottom.format(xlabel="Bottom-row X") + ax[0, 0].format(xlabel="Top-left X") + ax.format(xlabel="Global X") + + fig.canvas.draw() + + assert ax[0, 0].get_xlabel() == "Global X" + assert ax[0, 1].get_xlabel() == "Global X" + assert not any( + lab.get_text() == "Bottom-row X" for lab in fig._supxlabel_dict.values() + ) + + uplt.close(fig) + + +def test_full_grid_clears_share_label_groups(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + bottom = ax[1, :] + bottom.format(xlabel="Bottom-row X") + ax.format(xlabel="Global X") + + fig.canvas.draw() + + assert not fig._has_share_label_groups("x") + assert not any( + lab.get_text() == "Bottom-row X" for lab in fig._supxlabel_dict.values() + ) + assert all(axi.get_xlabel() == "Global X" for axi in ax) + + uplt.close(fig) + + def test_subset_share_ylabels_implicit_row(): fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) top = ax[0, :]