From 4509531c14beefa4cf8ebee57fff317b2723c4fb Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Thu, 15 Jan 2026 13:38:40 +1000 Subject: [PATCH 1/5] Fix legend span inference with panels Legend span inference used panel-inflated indices after prior legends added panel rows/cols, yielding invalid gridspec indices for list refs. Decode subplot indices to non-panel grid before computing span and add regression tests for multi-legend ordering. --- ultraplot/figure.py | 265 +++++++++++++-- ultraplot/tests/test_legend.py | 566 +++------------------------------ 2 files changed, 285 insertions(+), 546 deletions(-) diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 5d302f318..88a248828 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -313,7 +313,7 @@ The axes number used for a-b-c labeling. See `~ultraplot.axes.Axes.format` for details. By default this is incremented automatically based on the other subplots in the figure. Use e.g. ``number=None`` or ``number=False`` to ensure the subplot - has no a-b-c label. Note the number corresponding to ``a`` is ``1``, not ``0``. + has no a-b-c label. Note the number corresponding to `a` is ``1``, not ``0``. autoshare : bool, default: True Whether to automatically share the *x* and *y* axes with subplots spanning the same rows and columns based on the figure-wide `sharex` and `sharey` settings. @@ -2594,6 +2594,8 @@ def colorbar( """ # Backwards compatibility ax = kwargs.pop("ax", None) + ref = kwargs.pop("ref", None) + loc_ax = ref if ref is not None else ax cax = kwargs.pop("cax", None) if isinstance(values, maxes.Axes): cax = _not_none(cax_positional=values, cax=cax) @@ -2613,20 +2615,116 @@ def colorbar( with context._state_context(cax, _internal_call=True): # do not wrap pcolor cb = super().colorbar(mappable, cax=cax, **kwargs) # Axes panel colorbar - elif ax is not None: + elif loc_ax is not None: # Check if span parameters are provided has_span = _not_none(span, row, col, rows, cols) is not None + # Infer span from loc_ax if it is a list and no span provided + if ( + not has_span + and np.iterable(loc_ax) + and not isinstance(loc_ax, (str, maxes.Axes)) + ): + loc_trans = _translate_loc(loc, "colorbar", default=rc["colorbar.loc"]) + side = ( + loc_trans + if loc_trans in ("left", "right", "top", "bottom") + else None + ) + + if side: + r_min, r_max = float("inf"), float("-inf") + c_min, c_max = float("inf"), float("-inf") + valid_ax = False + for axi in loc_ax: + if not hasattr(axi, "get_subplotspec"): + continue + ss = axi.get_subplotspec() + if ss is None: + continue + ss = ss.get_topmost_subplotspec() + r1, r2, c1, c2 = ss._get_rows_columns() + gs = ss.get_gridspec() + if gs is not None: + try: + r1, r2 = gs._decode_indices(r1, r2, which="h") + c1, c2 = gs._decode_indices(c1, c2, which="w") + except ValueError: + pass + r_min = min(r_min, r1) + r_max = max(r_max, r2) + c_min = min(c_min, c1) + c_max = max(c_max, c2) + valid_ax = True + + if valid_ax: + if side in ("left", "right"): + rows = (r_min + 1, r_max + 1) + else: + cols = (c_min + 1, c_max + 1) + has_span = True + # Extract a single axes from array if span is provided # Otherwise, pass the array as-is for normal colorbar behavior - if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)): - try: - ax_single = next(iter(ax)) + if ( + has_span + and np.iterable(loc_ax) + and not isinstance(loc_ax, (str, maxes.Axes)) + ): + # Pick the best axis to anchor to based on the colorbar side + loc_trans = _translate_loc(loc, "colorbar", default=rc["colorbar.loc"]) + side = ( + loc_trans + if loc_trans in ("left", "right", "top", "bottom") + else None + ) - except (TypeError, StopIteration): - ax_single = ax + best_ax = None + best_coord = float("-inf") + + # If side is determined, search for the edge axis + if side: + for axi in loc_ax: + if not hasattr(axi, "get_subplotspec"): + continue + ss = axi.get_subplotspec() + if ss is None: + continue + ss = ss.get_topmost_subplotspec() + r1, r2, c1, c2 = ss._get_rows_columns() + gs = ss.get_gridspec() + if gs is not None: + try: + r1, r2 = gs._decode_indices(r1, r2, which="h") + c1, c2 = gs._decode_indices(c1, c2, which="w") + except ValueError: + pass + + if side == "right": + val = c2 # Maximize column index + elif side == "left": + val = -c1 # Minimize column index + elif side == "bottom": + val = r2 # Maximize row index + elif side == "top": + val = -r1 # Minimize row index + else: + val = 0 + + if val > best_coord: + best_coord = val + best_ax = axi + + # Fallback to first axis + if best_ax is None: + try: + ax_single = next(iter(loc_ax)) + except (TypeError, StopIteration): + ax_single = loc_ax + else: + ax_single = best_ax else: - ax_single = ax + ax_single = loc_ax # Pass span parameters through to axes colorbar cb = ax_single.colorbar( @@ -2700,27 +2798,150 @@ def legend( matplotlib.axes.Axes.legend """ ax = kwargs.pop("ax", None) + ref = kwargs.pop("ref", None) + loc_ax = ref if ref is not None else ax + # Axes panel legend - if ax is not None: + if loc_ax is not None: + content_ax = ax if ax is not None else loc_ax # Check if span parameters are provided has_span = _not_none(span, row, col, rows, cols) is not None - # Extract a single axes from array if span is provided - # Otherwise, pass the array as-is for normal legend behavior - # Automatically collect handles and labels from spanned axes if not provided - if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)): - # Auto-collect handles and labels if not explicitly provided - if handles is None and labels is None: - handles, labels = [], [] - for axi in ax: + + # Automatically collect handles and labels from content axes if not provided + # Case 1: content_ax is a list (we must auto-collect) + # Case 2: content_ax != loc_ax (we must auto-collect because loc_ax.legend won't find content_ax handles) + must_collect = ( + np.iterable(content_ax) + and not isinstance(content_ax, (str, maxes.Axes)) + ) or (content_ax is not loc_ax) + + if must_collect and handles is None and labels is None: + handles, labels = [], [] + # Handle list of axes + if np.iterable(content_ax) and not isinstance( + content_ax, (str, maxes.Axes) + ): + for axi in content_ax: h, l = axi.get_legend_handles_labels() handles.extend(h) labels.extend(l) - try: - ax_single = next(iter(ax)) - except (TypeError, StopIteration): - ax_single = ax + # Handle single axis + else: + handles, labels = content_ax.get_legend_handles_labels() + + # Infer span from loc_ax if it is a list and no span provided + if ( + not has_span + and np.iterable(loc_ax) + and not isinstance(loc_ax, (str, maxes.Axes)) + ): + loc_trans = _translate_loc(loc, "legend", default=rc["legend.loc"]) + side = ( + loc_trans + if loc_trans in ("left", "right", "top", "bottom") + else None + ) + + if side: + r_min, r_max = float("inf"), float("-inf") + c_min, c_max = float("inf"), float("-inf") + valid_ax = False + for axi in loc_ax: + if not hasattr(axi, "get_subplotspec"): + continue + ss = axi.get_subplotspec() + if ss is None: + continue + ss = ss.get_topmost_subplotspec() + r1, r2, c1, c2 = ss._get_rows_columns() + gs = ss.get_gridspec() + if gs is not None: + try: + r1, r2 = gs._decode_indices(r1, r2, which="h") + c1, c2 = gs._decode_indices(c1, c2, which="w") + except ValueError: + pass + r_min = min(r_min, r1) + r_max = max(r_max, r2) + c_min = min(c_min, c1) + c_max = max(c_max, c2) + valid_ax = True + + if valid_ax: + if side in ("left", "right"): + rows = (r_min + 1, r_max + 1) + else: + cols = (c_min + 1, c_max + 1) + has_span = True + + # Extract a single axes from array if span is provided (or if ref is a list) + # Otherwise, pass the array as-is for normal legend behavior (only if loc_ax is list) + if ( + has_span + and np.iterable(loc_ax) + and not isinstance(loc_ax, (str, maxes.Axes)) + ): + # Pick the best axis to anchor to based on the legend side + loc_trans = _translate_loc(loc, "legend", default=rc["legend.loc"]) + side = ( + loc_trans + if loc_trans in ("left", "right", "top", "bottom") + else None + ) + + best_ax = None + best_coord = float("-inf") + + # If side is determined, search for the edge axis + if side: + for axi in loc_ax: + if not hasattr(axi, "get_subplotspec"): + continue + ss = axi.get_subplotspec() + if ss is None: + continue + ss = ss.get_topmost_subplotspec() + r1, r2, c1, c2 = ss._get_rows_columns() + gs = ss.get_gridspec() + if gs is not None: + try: + r1, r2 = gs._decode_indices(r1, r2, which="h") + c1, c2 = gs._decode_indices(c1, c2, which="w") + except ValueError: + pass + + if side == "right": + val = c2 # Maximize column index + elif side == "left": + val = -c1 # Minimize column index + elif side == "bottom": + val = r2 # Maximize row index + elif side == "top": + val = -r1 # Minimize row index + else: + val = 0 + + if val > best_coord: + best_coord = val + best_ax = axi + + # Fallback to first axis if no best axis found (or side is None) + if best_ax is None: + try: + ax_single = next(iter(loc_ax)) + except (TypeError, StopIteration): + ax_single = loc_ax + else: + ax_single = best_ax + else: - ax_single = ax + ax_single = loc_ax + if isinstance(ax_single, list): + try: + ax_single = pgridspec.SubplotGrid(ax_single) + except ValueError: + ax_single = ax_single[0] + leg = ax_single.legend( handles, labels, diff --git a/ultraplot/tests/test_legend.py b/ultraplot/tests/test_legend.py index 6b984a55e..949195e0f 100644 --- a/ultraplot/tests/test_legend.py +++ b/ultraplot/tests/test_legend.py @@ -1,531 +1,49 @@ -#!/usr/bin/env python3 -""" -Test legends. -""" import numpy as np -import pandas as pd import pytest import ultraplot as uplt - - -@pytest.mark.mpl_image_compare -def test_auto_legend(rng): - """ - Test retrieval of legends from panels, insets, etc. - """ - fig, ax = uplt.subplots() - ax.line(rng.random((5, 3)), labels=list("abc")) - px = ax.panel_axes("right", share=False) - px.linex(rng.random((5, 3)), labels=list("xyz")) - # px.legend(loc='r') - ix = ax.inset_axes((-0.2, 0.8, 0.5, 0.5), zoom=False) - ix.line(rng.random((5, 2)), labels=list("pq")) - ax.legend(loc="b", order="F", edgecolor="red9", edgewidth=3) - return fig - - -@pytest.mark.mpl_image_compare -def test_singleton_legend(): - """ - Test behavior when singleton lists are passed. - Ensure this does not trigger centered-row legends. - """ - fig, ax = uplt.subplots() - h1 = ax.plot([0, 1, 2], label="a") - h2 = ax.plot([0, 1, 1], label="b") - ax.legend(loc="best") - ax.legend([h1, h2], loc="bottom") - return fig - - -@pytest.mark.mpl_image_compare -def test_centered_legends(rng): - """ - Test success of algorithm. - """ - # Basic centered legends - fig, axs = uplt.subplots(ncols=2, nrows=2, axwidth=2) - hs = axs[0].plot(rng.random((10, 6))) - locs = ["l", "t", "r", "uc", "ul", "ll"] - locs = ["l", "t", "uc", "ll"] - labels = ["a", "bb", "ccc", "ddddd", "eeeeeeee", "fffffffff"] - for ax, loc in zip(axs, locs): - ax.legend(hs, loc=loc, ncol=3, labels=labels, center=True) - - # Pass centered legends with keywords or list-of-list input. - fig, ax = uplt.subplots() - hs = ax.plot(rng.random((10, 5)), labels=list("abcde")) - ax.legend(hs, center=True, loc="b") - ax.legend(hs + hs[:1], loc="r", ncol=1) - ax.legend([hs[:2], hs[2:], hs[0]], loc="t") - return fig - - -@pytest.mark.mpl_image_compare -def test_manual_labels(): - """ - Test mixed auto and manual labels. Passing labels but no handles does nothing - This is breaking change but probably best. We should not be "guessing" the - order objects were drawn in then assigning labels to them. Similar to using - OO interface and rejecting pyplot "current axes" and "current figure". - """ - fig, ax = uplt.subplots() - (h1,) = ax.plot([0, 1, 2], label="label1") - (h2,) = ax.plot([0, 1, 1], label="label2") - for loc in ("best", "bottom"): - ax.legend([h1, h2], loc=loc, labels=[None, "override"]) - fig, ax = uplt.subplots() - ax.plot([0, 1, 2]) - ax.plot([0, 1, 1]) - for loc in ("best", "bottom"): - # ax.legend(loc=loc, labels=['a', 'b']) - ax.legend(["a", "b"], loc=loc) # same as above - return fig - - -@pytest.mark.mpl_image_compare -def test_contour_legend_with_label(rng): - """ - Support contour element labels. If has no label should trigger warning. - """ - figs = [] - label = "label" - - fig, axs = uplt.subplots(ncols=2) - ax = axs[0] - ax.contour(rng.random((5, 5)), color="k", label=label, legend="b") - ax = axs[1] - ax.pcolor(rng.random((5, 5)), label=label, legend="b") - return fig - - -@pytest.mark.mpl_image_compare -def test_contour_legend_without_label(rng): - """ - Support contour element labels. If has no label should trigger warning. - """ - label = None - fig, axs = uplt.subplots(ncols=2) - ax = axs[0] - ax.contour(rng.random((5, 5)), color="k", label=label, legend="b") - ax = axs[1] - ax.pcolor(rng.random((5, 5)), label=label, legend="b") - return fig - - -@pytest.mark.mpl_image_compare -def test_histogram_legend(rng): - """ - Support complex histogram legends. - """ - with uplt.rc.context({"inlineformat": "svg"}): - fig, ax = uplt.subplots() - res = ax.hist( - rng.random((500, 2)), 4, labels=("label", "other"), edgefix=True, legend="b" - ) - ax.legend( - res, loc="r", ncol=1 - ) # should issue warning after ignoring numpy arrays - df = pd.DataFrame( - {"length": [1.5, 0.5, 1.2, 0.9, 3], "width": [0.7, 0.2, 0.15, 0.2, 1.1]}, - index=["pig", "rabbit", "duck", "chicken", "horse"], - ) - fig, axs = uplt.subplots(ncols=3) - ax = axs[0] - res = ax.hist(df, bins=3, legend=True, lw=3) - ax.legend(loc="b") - for ax, meth in zip(axs[1:], ("bar", "area")): - hs = getattr(ax, meth)(df, legend="ul", lw=3) - ax.legend(hs, loc="b") - return fig - - -@pytest.mark.mpl_image_compare -def test_multiple_calls(rng): - """ - Test successive plotting additions to guides. - """ - fig, ax = uplt.subplots() - ax.pcolor(rng.random((10, 10)), colorbar="b") - ax.pcolor(rng.random((10, 5)), cmap="grays", colorbar="b") - ax.pcolor(rng.random((10, 5)), cmap="grays", colorbar="b") - - fig, ax = uplt.subplots() - data = rng.random((10, 5)) - for i in range(data.shape[1]): - ax.plot(data[:, i], colorbar="b", label=f"x{i}", colorbar_kw={"label": "hello"}) - return fig - - -@pytest.mark.mpl_image_compare -def test_tuple_handles(rng): - """ - Test tuple legend handles. - """ - from matplotlib import legend_handler - - fig, ax = uplt.subplots(refwidth=3, abc="A.", abcloc="ul", span=False) - patches = ax.fill_between(rng.random((10, 3)), stack=True) - lines = ax.line(1 + 0.5 * (rng.random((10, 3)) - 0.5).cumsum(axis=0)) - # ax.legend([(handles[0], lines[1])], ['joint label'], loc='bottom', queue=True) - for hs in (lines, patches): - ax.legend( - [tuple(hs[:3]) if len(hs) == 3 else hs], - ["joint label"], - loc="bottom", - queue=True, - ncol=1, - handlelength=4.5, - handleheight=1.5, - handler_map={tuple: legend_handler.HandlerTuple(pad=0, ndivide=3)}, - ) - return fig - - -@pytest.mark.mpl_image_compare -def test_legend_col_spacing(rng): - """ - Test legend column spacing. - """ - fig, ax = uplt.subplots() - ax.plot(rng.random(10), label="short") - ax.plot(rng.random(10), label="longer label") - ax.plot(rng.random(10), label="even longer label") - for idx in range(3): - spacing = f"{idx}em" - if idx == 2: - spacing = 3 - ax.legend(loc="bottom", ncol=3, columnspacing=spacing) - - with pytest.raises(ValueError): - ax.legend(loc="bottom", ncol=3, columnspacing="15x") - return fig - - -def test_sync_label_dict(rng): - """ - Legends are held within _legend_dict for which the key is a tuple of location and alignment. - - We need to ensure that the legend is updated in the dictionary when its location is changed. - """ - data = rng.random((2, 100)) - fig, ax = uplt.subplots() - ax.plot(*data, label="test") - leg = ax.legend(loc="lower right") - assert ("lower right", "center") in ax[0]._legend_dict, "Legend not found in dict" - leg.set_loc("upper left") - assert ("upper left", "center") in ax[ - 0 - ]._legend_dict, "Legend not found in dict after update" - assert leg is ax[0]._legend_dict[("upper left", "center")] - assert ("lower right", "center") not in ax[ - 0 - ]._legend_dict, "Old legend not removed from dict" - uplt.close(fig) - - -def test_external_mode_defers_on_the_fly_legend(): - """ - External mode should defer on-the-fly legend creation until explicitly requested. - """ - fig, ax = uplt.subplots() - ax.set_external(True) - (h,) = ax.plot([0, 1], label="a", legend="b") - - # No legend should have been created yet - assert getattr(ax[0], "legend_", None) is None - - # Explicit legend creation should include the plotted label - leg = ax.legend(h, loc="b") - labels = [t.get_text() for t in leg.get_texts()] - assert "a" in labels - uplt.close(fig) - - -def test_external_mode_mixing_context_manager(): - """ - Mixing external and internal plotting on the same axes: - - Inside ax.external(): on-the-fly legend is deferred - - Outside: UltraPlot-native plotting resumes as normal - - Final explicit ax.legend() consolidates both kinds of artists - """ - fig, ax = uplt.subplots() - - with ax.external(): - (ext,) = ax.plot([0, 1], label="ext", legend="b") # deferred - - (intr,) = ax.line([0, 1], label="int") # normal UL behavior - - leg = ax.legend([ext, intr], loc="b") - labels = {t.get_text() for t in leg.get_texts()} - assert {"ext", "int"}.issubset(labels) - uplt.close(fig) - - -def test_external_mode_toggle_enables_auto(): - """ - Toggling external mode back off should resume on-the-fly guide creation. - """ - fig, ax = uplt.subplots() - - ax.set_external(True) - (ha,) = ax.plot([0, 1], label="a", legend="b") - assert getattr(ax[0], "legend_", None) is None # deferred - - ax.set_external(False) - (hb,) = ax.plot([0, 1], label="b", legend="b") - # Now legend is queued for creation; verify it is registered in the outer legend dict - assert ("bottom", "center") in ax[0]._legend_dict - - # Ensure final legend contains both entries - leg = ax.legend([ha, hb], loc="b") - labels = {t.get_text() for t in leg.get_texts()} - assert {"a", "b"}.issubset(labels) - uplt.close(fig) - - -def test_synthetic_handles_filtered(): - """ - Synthetic-tagged helper artists must be ignored by legend parsing even when - explicitly passed as handles. - """ - fig, ax = uplt.subplots() - (h1,) = ax.plot([0, 1], label="visible") - (h2,) = ax.plot([1, 0], label="helper") - # Mark helper as synthetic; it should be filtered out from legend entries - setattr(h2, "_ultraplot_synthetic", True) - - leg = ax.legend([h1, h2], loc="best") - labels = [t.get_text() for t in leg.get_texts()] - assert "visible" in labels - assert "helper" not in labels - uplt.close(fig) - - -def test_fill_between_included_in_legend(): - """ - Legitimate fill_between/area handles must appear in legends (regression for - previously skipped FillBetweenPolyCollection). - """ - fig, ax = uplt.subplots() - x = np.arange(5) - y1 = np.zeros(5) - y2 = np.ones(5) - ax.fill_between(x, y1, y2, label="band") - - leg = ax.legend(loc="best") - labels = [t.get_text() for t in leg.get_texts()] - assert "band" in labels - uplt.close(fig) - - -def test_legend_span_bottom(): - """Test bottom legend with span parameter.""" - - fig, axs = uplt.subplots(nrows=2, ncols=3) - axs[0, 0].plot([], [], label="test") - - # Legend below row 1, spanning columns 1-2 - leg = fig.legend(ax=axs[0, :], span=(1, 2), loc="bottom") - - # Verify legend was created - assert leg is not None - - -def test_legend_span_top(): - """Test top legend with span parameter.""" - - fig, axs = uplt.subplots(nrows=2, ncols=3) - axs[0, 0].plot([], [], label="test") - - # Legend above row 2, spanning columns 2-3 - leg = fig.legend(ax=axs[1, :], cols=(2, 3), loc="top") - - assert leg is not None - - -def test_legend_span_right(): - """Test right legend with rows parameter.""" - - fig, axs = uplt.subplots(nrows=3, ncols=2) - axs[0, 0].plot([], [], label="test") - - # Legend right of column 1, spanning rows 1-2 - leg = fig.legend(ax=axs[:, 0], rows=(1, 2), loc="right") - - assert leg is not None - - -def test_legend_span_left(): - """Test left legend with rows parameter.""" - - fig, axs = uplt.subplots(nrows=3, ncols=2) - axs[0, 0].plot([], [], label="test") - - # Legend left of column 2, spanning rows 2-3 - leg = fig.legend(ax=axs[:, 1], rows=(2, 3), loc="left") - - assert leg is not None - - -def test_legend_span_validation_left_with_cols_error(): - """Test that LEFT legend raises error with cols parameter.""" - - fig, axs = uplt.subplots(nrows=3, ncols=2) - axs[0, 0].plot([], [], label="test") - - with pytest.raises(ValueError, match="left.*vertical.*use 'rows='.*not 'cols='"): - fig.legend(ax=axs[0, 0], cols=(1, 2), loc="left") - - -def test_legend_span_validation_right_with_cols_error(): - """Test that RIGHT legend raises error with cols parameter.""" - fig, axs = uplt.subplots(nrows=3, ncols=2) - axs[0, 0].plot([], [], label="test") - - with pytest.raises(ValueError, match="right.*vertical.*use 'rows='.*not 'cols='"): - fig.legend(ax=axs[0, 0], cols=(1, 2), loc="right") - - -def test_legend_span_validation_top_with_rows_error(): - """Test that TOP legend raises error with rows parameter.""" - fig, axs = uplt.subplots(nrows=2, ncols=3) - axs[0, 0].plot([], [], label="test") - - with pytest.raises(ValueError, match="top.*horizontal.*use 'cols='.*not 'rows='"): - fig.legend(ax=axs[0, 0], rows=(1, 2), loc="top") - - -def test_legend_span_validation_bottom_with_rows_error(): - """Test that BOTTOM legend raises error with rows parameter.""" - fig, axs = uplt.subplots(nrows=2, ncols=3) - axs[0, 0].plot([], [], label="test") - - with pytest.raises( - ValueError, match="bottom.*horizontal.*use 'cols='.*not 'rows='" - ): - fig.legend(ax=axs[0, 0], rows=(1, 2), loc="bottom") - - -def test_legend_span_validation_left_with_span_warns(): - """Test that LEFT legend with span parameter issues warning.""" - fig, axs = uplt.subplots(nrows=3, ncols=2) - axs[0, 0].plot([], [], label="test") - - with pytest.warns(match="left.*vertical.*prefer 'rows='"): - leg = fig.legend(ax=axs[0, 0], span=(1, 2), loc="left") - assert leg is not None - - -def test_legend_span_validation_right_with_span_warns(): - """Test that RIGHT legend with span parameter issues warning.""" - fig, axs = uplt.subplots(nrows=3, ncols=2) - axs[0, 0].plot([], [], label="test") - - with pytest.warns(match="right.*vertical.*prefer 'rows='"): - leg = fig.legend(ax=axs[0, 0], span=(1, 2), loc="right") - assert leg is not None - - -def test_legend_array_without_span(): - """Test that legend on array without span preserves original behavior.""" - fig, axs = uplt.subplots(nrows=2, ncols=2) - axs[0, 0].plot([], [], label="test") - - # Should create legend for all axes in the array - leg = fig.legend(ax=axs[:], loc="right") - assert leg is not None - - -def test_legend_array_with_span(): - """Test that legend on array with span uses first axis + span extent.""" - fig, axs = uplt.subplots(nrows=2, ncols=3) - axs[0, 0].plot([], [], label="test") - - # Should use first axis position with span extent - leg = fig.legend(ax=axs[0, :], span=(1, 2), loc="bottom") - assert leg is not None - - -def test_legend_row_without_span(): - """Test that legend on row without span spans entire row.""" - fig, axs = uplt.subplots(nrows=2, ncols=3) - axs[0, 0].plot([], [], label="test") - - # Should span all 3 columns - leg = fig.legend(ax=axs[0, :], loc="bottom") - assert leg is not None - - -def test_legend_column_without_span(): - """Test that legend on column without span spans entire column.""" - fig, axs = uplt.subplots(nrows=3, ncols=2) - axs[0, 0].plot([], [], label="test") - - # Should span all 3 rows - leg = fig.legend(ax=axs[:, 0], loc="right") - assert leg is not None - - -def test_legend_multiple_sides_with_span(): - """Test multiple legends on different sides with span control.""" +from ultraplot.axes import Axes as UAxes + + +def _decode_panel_span(panel_ax, axis): + ss = panel_ax.get_subplotspec().get_topmost_subplotspec() + r1, r2, c1, c2 = ss._get_rows_columns() + gs = ss.get_gridspec() + if axis == "rows": + r1, r2 = gs._decode_indices(r1, r2, which="h") + return int(r1), int(r2) + if axis == "cols": + c1, c2 = gs._decode_indices(c1, c2, which="w") + return int(c1), int(c2) + raise ValueError(f"Unknown axis {axis!r}.") + + +def _anchor_axis(ref): + if np.iterable(ref) and not isinstance(ref, (str, UAxes)): + return next(iter(ref)) + return ref + + +@pytest.mark.parametrize( + "first_loc, first_ref, second_loc, second_ref, span_axis", + [ + ("b", lambda axs: axs[0], "r", lambda axs: axs[:, 1], "rows"), + ("r", lambda axs: axs[:, 2], "b", lambda axs: axs[1, :], "cols"), + ("t", lambda axs: axs[2], "l", lambda axs: axs[:, 0], "rows"), + ("l", lambda axs: axs[:, 0], "t", lambda axs: axs[1, :], "cols"), + ], +) +def test_legend_span_inference_with_multi_panels( + first_loc, first_ref, second_loc, second_ref, span_axis +): fig, axs = uplt.subplots(nrows=3, ncols=3) - axs[0, 0].plot([], [], label="test") - - # Create legends on all 4 sides with different spans - leg_bottom = fig.legend(ax=axs[0, 0], span=(1, 2), loc="bottom") - leg_top = fig.legend(ax=axs[1, 0], span=(2, 3), loc="top") - leg_right = fig.legend(ax=axs[0, 0], rows=(1, 2), loc="right") - leg_left = fig.legend(ax=axs[0, 1], rows=(2, 3), loc="left") - - assert leg_bottom is not None - assert leg_top is not None - assert leg_right is not None - assert leg_left is not None - - -def test_legend_auto_collect_handles_labels_with_span(): - """Test automatic collection of handles and labels from multiple axes with span parameters.""" - - fig, axs = uplt.subplots(nrows=2, ncols=2) - - # Create different plots in each subplot with labels - axs[0, 0].plot([0, 1], [0, 1], label="line1") - axs[0, 1].plot([0, 1], [1, 0], label="line2") - axs[1, 0].scatter([0.5], [0.5], label="point1") - axs[1, 1].scatter([0.5], [0.5], label="point2") - - # Test automatic collection with span parameter (no explicit handles/labels) - leg = fig.legend(ax=axs[0, :], span=(1, 2), loc="bottom") - - # Verify legend was created and contains all handles/labels from both axes - assert leg is not None - assert len(leg.get_texts()) == 2 # Should have 2 labels (line1, line2) - - # Test with rows parameter - leg2 = fig.legend(ax=axs[:, 0], rows=(1, 2), loc="right") - assert leg2 is not None - assert len(leg2.get_texts()) == 2 # Should have 2 labels (line1, point1) - - -def test_legend_explicit_handles_labels_override_auto_collection(): - """Test that explicit handles/labels override auto-collection.""" - - fig, axs = uplt.subplots(nrows=1, ncols=2) - - # Create plots with labels - (h1,) = axs[0].plot([0, 1], [0, 1], label="auto_label1") - (h2,) = axs[1].plot([0, 1], [1, 0], label="auto_label2") + axs.plot([0, 1], [0, 1], label="line") - # Test with explicit handles/labels (should override auto-collection) - custom_handles = [h1] - custom_labels = ["custom_label"] - leg = fig.legend( - ax=axs, span=(1, 2), loc="bottom", handles=custom_handles, labels=custom_labels - ) + fig.legend(ref=first_ref(axs), loc=first_loc) + fig.legend(ref=second_ref(axs), loc=second_loc) - # Verify legend uses explicit handles/labels, not auto-collected ones - assert leg is not None - assert len(leg.get_texts()) == 1 - assert leg.get_texts()[0].get_text() == "custom_label" + side_map = {"l": "left", "r": "right", "t": "top", "b": "bottom"} + anchor = _anchor_axis(second_ref(axs)) + panel_ax = anchor._panel_dict[side_map[second_loc]][-1] + span = _decode_panel_span(panel_ax, span_axis) + assert span == (0, 2) From 2db8df74f3dacf240643c3da6c4ee9ec66a0c74f Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Thu, 15 Jan 2026 13:48:42 +1000 Subject: [PATCH 2/5] Restore tests --- ultraplot/tests/test_legend.py | 614 +++++++++++++++++++++++++++++++++ 1 file changed, 614 insertions(+) diff --git a/ultraplot/tests/test_legend.py b/ultraplot/tests/test_legend.py index 949195e0f..cbc67d352 100644 --- a/ultraplot/tests/test_legend.py +++ b/ultraplot/tests/test_legend.py @@ -1,10 +1,624 @@ +#!/usr/bin/env python3 +""" +Test legends. +""" import numpy as np +import pandas as pd import pytest import ultraplot as uplt from ultraplot.axes import Axes as UAxes +@pytest.mark.mpl_image_compare +def test_auto_legend(rng): + """ + Test retrieval of legends from panels, insets, etc. + """ + fig, ax = uplt.subplots() + ax.line(rng.random((5, 3)), labels=list("abc")) + px = ax.panel_axes("right", share=False) + px.linex(rng.random((5, 3)), labels=list("xyz")) + # px.legend(loc='r') + ix = ax.inset_axes((-0.2, 0.8, 0.5, 0.5), zoom=False) + ix.line(rng.random((5, 2)), labels=list("pq")) + ax.legend(loc="b", order="F", edgecolor="red9", edgewidth=3) + return fig + + +@pytest.mark.mpl_image_compare +def test_singleton_legend(): + """ + Test behavior when singleton lists are passed. + Ensure this does not trigger centered-row legends. + """ + fig, ax = uplt.subplots() + h1 = ax.plot([0, 1, 2], label="a") + h2 = ax.plot([0, 1, 1], label="b") + ax.legend(loc="best") + ax.legend([h1, h2], loc="bottom") + return fig + + +@pytest.mark.mpl_image_compare +def test_centered_legends(rng): + """ + Test success of algorithm. + """ + # Basic centered legends + fig, axs = uplt.subplots(ncols=2, nrows=2, axwidth=2) + hs = axs[0].plot(rng.random((10, 6))) + locs = ["l", "t", "r", "uc", "ul", "ll"] + locs = ["l", "t", "uc", "ll"] + labels = ["a", "bb", "ccc", "ddddd", "eeeeeeee", "fffffffff"] + for ax, loc in zip(axs, locs): + ax.legend(hs, loc=loc, ncol=3, labels=labels, center=True) + + # Pass centered legends with keywords or list-of-list input. + fig, ax = uplt.subplots() + hs = ax.plot(rng.random((10, 5)), labels=list("abcde")) + ax.legend(hs, center=True, loc="b") + ax.legend(hs + hs[:1], loc="r", ncol=1) + ax.legend([hs[:2], hs[2:], hs[0]], loc="t") + return fig + + +@pytest.mark.mpl_image_compare +def test_manual_labels(): + """ + Test mixed auto and manual labels. Passing labels but no handles does nothing + This is breaking change but probably best. We should not be "guessing" the + order objects were drawn in then assigning labels to them. Similar to using + OO interface and rejecting pyplot "current axes" and "current figure". + """ + fig, ax = uplt.subplots() + (h1,) = ax.plot([0, 1, 2], label="label1") + (h2,) = ax.plot([0, 1, 1], label="label2") + for loc in ("best", "bottom"): + ax.legend([h1, h2], loc=loc, labels=[None, "override"]) + fig, ax = uplt.subplots() + ax.plot([0, 1, 2]) + ax.plot([0, 1, 1]) + for loc in ("best", "bottom"): + # ax.legend(loc=loc, labels=['a', 'b']) + ax.legend(["a", "b"], loc=loc) # same as above + return fig + + +@pytest.mark.mpl_image_compare +def test_contour_legend_with_label(rng): + """ + Support contour element labels. If has no label should trigger warning. + """ + figs = [] + label = "label" + + fig, axs = uplt.subplots(ncols=2) + ax = axs[0] + ax.contour(rng.random((5, 5)), color="k", label=label, legend="b") + ax = axs[1] + ax.pcolor(rng.random((5, 5)), label=label, legend="b") + return fig + + +@pytest.mark.mpl_image_compare +def test_contour_legend_without_label(rng): + """ + Support contour element labels. If has no label should trigger warning. + """ + label = None + fig, axs = uplt.subplots(ncols=2) + ax = axs[0] + ax.contour(rng.random((5, 5)), color="k", label=label, legend="b") + ax = axs[1] + ax.pcolor(rng.random((5, 5)), label=label, legend="b") + return fig + + +@pytest.mark.mpl_image_compare +def test_histogram_legend(rng): + """ + Support complex histogram legends. + """ + with uplt.rc.context({"inlineformat": "svg"}): + fig, ax = uplt.subplots() + res = ax.hist( + rng.random((500, 2)), 4, labels=("label", "other"), edgefix=True, legend="b" + ) + ax.legend( + res, loc="r", ncol=1 + ) # should issue warning after ignoring numpy arrays + df = pd.DataFrame( + {"length": [1.5, 0.5, 1.2, 0.9, 3], "width": [0.7, 0.2, 0.15, 0.2, 1.1]}, + index=["pig", "rabbit", "duck", "chicken", "horse"], + ) + fig, axs = uplt.subplots(ncols=3) + ax = axs[0] + res = ax.hist(df, bins=3, legend=True, lw=3) + ax.legend(loc="b") + for ax, meth in zip(axs[1:], ("bar", "area")): + hs = getattr(ax, meth)(df, legend="ul", lw=3) + ax.legend(hs, loc="b") + return fig + + +@pytest.mark.mpl_image_compare +def test_multiple_calls(rng): + """ + Test successive plotting additions to guides. + """ + fig, ax = uplt.subplots() + ax.pcolor(rng.random((10, 10)), colorbar="b") + ax.pcolor(rng.random((10, 5)), cmap="grays", colorbar="b") + ax.pcolor(rng.random((10, 5)), cmap="grays", colorbar="b") + + fig, ax = uplt.subplots() + data = rng.random((10, 5)) + for i in range(data.shape[1]): + ax.plot(data[:, i], colorbar="b", label=f"x{i}", colorbar_kw={"label": "hello"}) + return fig + + +@pytest.mark.mpl_image_compare +def test_tuple_handles(rng): + """ + Test tuple legend handles. + """ + from matplotlib import legend_handler + + fig, ax = uplt.subplots(refwidth=3, abc="A.", abcloc="ul", span=False) + patches = ax.fill_between(rng.random((10, 3)), stack=True) + lines = ax.line(1 + 0.5 * (rng.random((10, 3)) - 0.5).cumsum(axis=0)) + # ax.legend([(handles[0], lines[1])], ['joint label'], loc='bottom', queue=True) + for hs in (lines, patches): + ax.legend( + [tuple(hs[:3]) if len(hs) == 3 else hs], + ["joint label"], + loc="bottom", + queue=True, + ncol=1, + handlelength=4.5, + handleheight=1.5, + handler_map={tuple: legend_handler.HandlerTuple(pad=0, ndivide=3)}, + ) + return fig + + +@pytest.mark.mpl_image_compare +def test_legend_col_spacing(rng): + """ + Test legend column spacing. + """ + fig, ax = uplt.subplots() + ax.plot(rng.random(10), label="short") + ax.plot(rng.random(10), label="longer label") + ax.plot(rng.random(10), label="even longer label") + for idx in range(3): + spacing = f"{idx}em" + if idx == 2: + spacing = 3 + ax.legend(loc="bottom", ncol=3, columnspacing=spacing) + + with pytest.raises(ValueError): + ax.legend(loc="bottom", ncol=3, columnspacing="15x") + return fig + + +def test_sync_label_dict(rng): + """ + Legends are held within _legend_dict for which the key is a tuple of location and alignment. + + We need to ensure that the legend is updated in the dictionary when its location is changed. + """ + data = rng.random((2, 100)) + fig, ax = uplt.subplots() + ax.plot(*data, label="test") + leg = ax.legend(loc="lower right") + assert ("lower right", "center") in ax[0]._legend_dict, "Legend not found in dict" + leg.set_loc("upper left") + assert ("upper left", "center") in ax[ + 0 + ]._legend_dict, "Legend not found in dict after update" + assert leg is ax[0]._legend_dict[("upper left", "center")] + assert ("lower right", "center") not in ax[ + 0 + ]._legend_dict, "Old legend not removed from dict" + uplt.close(fig) + + +def test_external_mode_defers_on_the_fly_legend(): + """ + External mode should defer on-the-fly legend creation until explicitly requested. + """ + fig, ax = uplt.subplots() + ax.set_external(True) + (h,) = ax.plot([0, 1], label="a", legend="b") + + # No legend should have been created yet + assert getattr(ax[0], "legend_", None) is None + + # Explicit legend creation should include the plotted label + leg = ax.legend(h, loc="b") + labels = [t.get_text() for t in leg.get_texts()] + assert "a" in labels + uplt.close(fig) + + +def test_external_mode_mixing_context_manager(): + """ + Mixing external and internal plotting on the same axes: + - Inside ax.external(): on-the-fly legend is deferred + - Outside: UltraPlot-native plotting resumes as normal + - Final explicit ax.legend() consolidates both kinds of artists + """ + fig, ax = uplt.subplots() + + with ax.external(): + (ext,) = ax.plot([0, 1], label="ext", legend="b") # deferred + + (intr,) = ax.line([0, 1], label="int") # normal UL behavior + + leg = ax.legend([ext, intr], loc="b") + labels = {t.get_text() for t in leg.get_texts()} + assert {"ext", "int"}.issubset(labels) + uplt.close(fig) + + +def test_external_mode_toggle_enables_auto(): + """ + Toggling external mode back off should resume on-the-fly guide creation. + """ + fig, ax = uplt.subplots() + + ax.set_external(True) + (ha,) = ax.plot([0, 1], label="a", legend="b") + assert getattr(ax[0], "legend_", None) is None # deferred + + ax.set_external(False) + (hb,) = ax.plot([0, 1], label="b", legend="b") + # Now legend is queued for creation; verify it is registered in the outer legend dict + assert ("bottom", "center") in ax[0]._legend_dict + + # Ensure final legend contains both entries + leg = ax.legend([ha, hb], loc="b") + labels = {t.get_text() for t in leg.get_texts()} + assert {"a", "b"}.issubset(labels) + uplt.close(fig) + + +def test_synthetic_handles_filtered(): + """ + Synthetic-tagged helper artists must be ignored by legend parsing even when + explicitly passed as handles. + """ + fig, ax = uplt.subplots() + (h1,) = ax.plot([0, 1], label="visible") + (h2,) = ax.plot([1, 0], label="helper") + # Mark helper as synthetic; it should be filtered out from legend entries + setattr(h2, "_ultraplot_synthetic", True) + + leg = ax.legend([h1, h2], loc="best") + labels = [t.get_text() for t in leg.get_texts()] + assert "visible" in labels + assert "helper" not in labels + uplt.close(fig) + + +def test_fill_between_included_in_legend(): + """ + Legitimate fill_between/area handles must appear in legends (regression for + previously skipped FillBetweenPolyCollection). + """ + fig, ax = uplt.subplots() + x = np.arange(5) + y1 = np.zeros(5) + y2 = np.ones(5) + ax.fill_between(x, y1, y2, label="band") + + leg = ax.legend(loc="best") + labels = [t.get_text() for t in leg.get_texts()] + assert "band" in labels + uplt.close(fig) + + +def test_legend_span_bottom(): + """Test bottom legend with span parameter.""" + + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + # Legend below row 1, spanning columns 1-2 + leg = fig.legend(ax=axs[0, :], span=(1, 2), loc="bottom") + + # Verify legend was created + assert leg is not None + + +def test_legend_span_top(): + """Test top legend with span parameter.""" + + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + # Legend above row 2, spanning columns 2-3 + leg = fig.legend(ax=axs[1, :], cols=(2, 3), loc="top") + + assert leg is not None + + +def test_legend_span_right(): + """Test right legend with rows parameter.""" + + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + # Legend right of column 1, spanning rows 1-2 + leg = fig.legend(ax=axs[:, 0], rows=(1, 2), loc="right") + + assert leg is not None + + +def test_legend_span_left(): + """Test left legend with rows parameter.""" + + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + # Legend left of column 2, spanning rows 2-3 + leg = fig.legend(ax=axs[:, 1], rows=(2, 3), loc="left") + + assert leg is not None + + +def test_legend_span_validation_left_with_cols_error(): + """Test that LEFT legend raises error with cols parameter.""" + + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + with pytest.raises(ValueError, match="left.*vertical.*use 'rows='.*not 'cols='"): + fig.legend(ax=axs[0, 0], cols=(1, 2), loc="left") + + +def test_legend_span_validation_right_with_cols_error(): + """Test that RIGHT legend raises error with cols parameter.""" + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + with pytest.raises(ValueError, match="right.*vertical.*use 'rows='.*not 'cols='"): + fig.legend(ax=axs[0, 0], cols=(1, 2), loc="right") + + +def test_legend_span_validation_top_with_rows_error(): + """Test that TOP legend raises error with rows parameter.""" + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + with pytest.raises(ValueError, match="top.*horizontal.*use 'cols='.*not 'rows='"): + fig.legend(ax=axs[0, 0], rows=(1, 2), loc="top") + + +def test_legend_span_validation_bottom_with_rows_error(): + """Test that BOTTOM legend raises error with rows parameter.""" + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + with pytest.raises( + ValueError, match="bottom.*horizontal.*use 'cols='.*not 'rows='" + ): + fig.legend(ax=axs[0, 0], rows=(1, 2), loc="bottom") + + +def test_legend_span_validation_left_with_span_warns(): + """Test that LEFT legend with span parameter issues warning.""" + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + with pytest.warns(match="left.*vertical.*prefer 'rows='"): + leg = fig.legend(ax=axs[0, 0], span=(1, 2), loc="left") + assert leg is not None + + +def test_legend_span_validation_right_with_span_warns(): + """Test that RIGHT legend with span parameter issues warning.""" + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + with pytest.warns(match="right.*vertical.*prefer 'rows='"): + leg = fig.legend(ax=axs[0, 0], span=(1, 2), loc="right") + assert leg is not None + + +def test_legend_array_without_span(): + """Test that legend on array without span preserves original behavior.""" + fig, axs = uplt.subplots(nrows=2, ncols=2) + axs[0, 0].plot([], [], label="test") + + # Should create legend for all axes in the array + leg = fig.legend(ax=axs[:], loc="right") + assert leg is not None + + +def test_legend_array_with_span(): + """Test that legend on array with span uses first axis + span extent.""" + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + # Should use first axis position with span extent + leg = fig.legend(ax=axs[0, :], span=(1, 2), loc="bottom") + assert leg is not None + + +def test_legend_row_without_span(): + """Test that legend on row without span spans entire row.""" + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + # Should span all 3 columns + leg = fig.legend(ax=axs[0, :], loc="bottom") + assert leg is not None + + +def test_legend_column_without_span(): + """Test that legend on column without span spans entire column.""" + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + # Should span all 3 rows + leg = fig.legend(ax=axs[:, 0], loc="right") + assert leg is not None + + +def test_legend_multiple_sides_with_span(): + """Test multiple legends on different sides with span control.""" + fig, axs = uplt.subplots(nrows=3, ncols=3) + axs[0, 0].plot([], [], label="test") + + # Create legends on all 4 sides with different spans + leg_bottom = fig.legend(ax=axs[0, 0], span=(1, 2), loc="bottom") + leg_top = fig.legend(ax=axs[1, 0], span=(2, 3), loc="top") + leg_right = fig.legend(ax=axs[0, 0], rows=(1, 2), loc="right") + leg_left = fig.legend(ax=axs[0, 1], rows=(2, 3), loc="left") + + assert leg_bottom is not None + assert leg_top is not None + assert leg_right is not None + assert leg_left is not None + + +def test_legend_auto_collect_handles_labels_with_span(): + """Test automatic collection of handles and labels from multiple axes with span parameters.""" + + fig, axs = uplt.subplots(nrows=2, ncols=2) + + # Create different plots in each subplot with labels + axs[0, 0].plot([0, 1], [0, 1], label="line1") + axs[0, 1].plot([0, 1], [1, 0], label="line2") + axs[1, 0].scatter([0.5], [0.5], label="point1") + axs[1, 1].scatter([0.5], [0.5], label="point2") + + # Test automatic collection with span parameter (no explicit handles/labels) + leg = fig.legend(ax=axs[0, :], span=(1, 2), loc="bottom") + + # Verify legend was created and contains all handles/labels from both axes + assert leg is not None + assert len(leg.get_texts()) == 2 # Should have 2 labels (line1, line2) + + # Test with rows parameter + leg2 = fig.legend(ax=axs[:, 0], rows=(1, 2), loc="right") + assert leg2 is not None + assert len(leg2.get_texts()) == 2 # Should have 2 labels (line1, point1) + + +def test_legend_explicit_handles_labels_override_auto_collection(): + """Test that explicit handles/labels override auto-collection.""" + + fig, axs = uplt.subplots(nrows=1, ncols=2) + + # Create plots with labels + (h1,) = axs[0].plot([0, 1], [0, 1], label="auto_label1") + (h2,) = axs[1].plot([0, 1], [1, 0], label="auto_label2") + + # Test with explicit handles/labels (should override auto-collection) + custom_handles = [h1] + custom_labels = ["custom_label"] + leg = fig.legend( + ax=axs, span=(1, 2), loc="bottom", handles=custom_handles, labels=custom_labels + ) + + # Verify legend uses explicit handles/labels, not auto-collected ones + assert leg is not None + assert len(leg.get_texts()) == 1 + assert leg.get_texts()[0].get_text() == "custom_label" + + +def test_legend_ref_argument(): + """Test using 'ref' to decouple legend location from content axes.""" + fig, axs = uplt.subplots(nrows=2, ncols=2) + axs[0, 0].plot([], [], label="line1") # Row 0 + axs[1, 0].plot([], [], label="line2") # Row 1 + + # Place legend below Row 0 (axs[0, :]) using content from Row 1 (axs[1, :]) + leg = fig.legend(ax=axs[1, :], ref=axs[0, :], loc="bottom") + + assert leg is not None + + # Should be a single legend because span is inferred from ref + assert not isinstance(leg, tuple) + + texts = [t.get_text() for t in leg.get_texts()] + assert "line2" in texts + assert "line1" not in texts + + +def test_legend_ref_argument_no_ax(): + """Test using 'ref' where 'ax' is implied to be 'ref'.""" + fig, axs = uplt.subplots(nrows=1, ncols=1) + axs[0].plot([], [], label="line1") + + # ref provided, ax=None. Should behave like ax=ref. + leg = fig.legend(ref=axs[0], loc="bottom") + assert leg is not None + + # Should be a single legend + assert not isinstance(leg, tuple) + + texts = [t.get_text() for t in leg.get_texts()] + assert "line1" in texts + + +def test_ref_with_explicit_handles(): + """Test using ref with explicit handles and labels.""" + fig, axs = uplt.subplots(ncols=2) + h = axs[0].plot([0, 1], [0, 1], label="line") + + # Place legend below both axes (ref=axs) using explicit handle + leg = fig.legend(handles=h, labels=["explicit"], ref=axs, loc="bottom") + + assert leg is not None + texts = [t.get_text() for t in leg.get_texts()] + assert texts == ["explicit"] + + +def test_ref_with_non_edge_location(): + """Test using ref with an inset location (should not infer span).""" + fig, axs = uplt.subplots(ncols=2) + axs[0].plot([0, 1], label="test") + + # ref=axs (list of 2). + # 'upper left' is inset. Should fallback to first axis. + leg = fig.legend(ref=axs, loc="upper left") + + assert leg is not None + if isinstance(leg, tuple): + leg = leg[0] + # Should be associated with axs[0] (or a panel of it? Inset is child of axes) + # leg.axes is the axes containing the legend. For inset, it's the parent axes? + # No, legend itself is an artist. leg.axes should be axs[0]. + assert leg.axes is axs[0] + + +def test_ref_with_single_axis(): + """Test using ref with a single axis object.""" + fig, axs = uplt.subplots(ncols=2) + axs[0].plot([0, 1], label="line") + + # ref=axs[1]. loc='bottom'. + leg = fig.legend(ref=axs[1], ax=axs[0], loc="bottom") + assert leg is not None + + +def test_ref_with_manual_axes_no_subplotspec(): + """Test using ref with axes that don't have subplotspec.""" + fig = uplt.figure() + ax1 = fig.add_axes([0.1, 0.1, 0.4, 0.4]) + ax2 = fig.add_axes([0.5, 0.1, 0.4, 0.4]) + ax1.plot([0, 1], [0, 1], label="line") + # ref=[ax1, ax2]. loc='upper right' (inset). + leg = fig.legend(ref=[ax1, ax2], loc="upper right") + assert leg is not None + + def _decode_panel_span(panel_ax, axis): ss = panel_ax.get_subplotspec().get_topmost_subplotspec() r1, r2, c1, c2 = ss._get_rows_columns() From f12e120ffc853c08d2c21ff96d650b0413d580e2 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Thu, 15 Jan 2026 14:00:19 +1000 Subject: [PATCH 3/5] Document legend span decode fallback Add a brief note that decoding panel indices can fail for panel or nested subplot specs, so we fall back to raw indices. --- ultraplot/figure.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 88a248828..e78870889 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -2650,6 +2650,7 @@ def colorbar( r1, r2 = gs._decode_indices(r1, r2, which="h") c1, c2 = gs._decode_indices(c1, c2, which="w") except ValueError: + # Non-panel decode can fail for panel or nested specs. pass r_min = min(r_min, r1) r_max = max(r_max, r2) @@ -2698,6 +2699,7 @@ def colorbar( r1, r2 = gs._decode_indices(r1, r2, which="h") c1, c2 = gs._decode_indices(c1, c2, which="w") except ValueError: + # Non-panel decode can fail for panel or nested specs. pass if side == "right": From 054399c9b51d9df80a7963d45f80d491dd448042 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Thu, 15 Jan 2026 14:10:11 +1000 Subject: [PATCH 4/5] Add legend span/selection regression tests Cover best-axis selection for left/right/top/bottom and the decode-index fallback path to raise coverage around Figure.legend panel inference. --- ultraplot/tests/test_legend.py | 43 ++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/ultraplot/tests/test_legend.py b/ultraplot/tests/test_legend.py index cbc67d352..4c7c943e8 100644 --- a/ultraplot/tests/test_legend.py +++ b/ultraplot/tests/test_legend.py @@ -661,3 +661,46 @@ def test_legend_span_inference_with_multi_panels( panel_ax = anchor._panel_dict[side_map[second_loc]][-1] span = _decode_panel_span(panel_ax, span_axis) assert span == (0, 2) + + +def test_legend_best_axis_selection_right_left(): + fig, axs = uplt.subplots(nrows=1, ncols=3) + axs.plot([0, 1], [0, 1], label="line") + ref = [axs[0, 0], axs[0, 2]] + + fig.legend(ref=ref, loc="r", rows=1) + assert len(axs[0, 2]._panel_dict["right"]) == 1 + assert len(axs[0, 0]._panel_dict["right"]) == 0 + + fig.legend(ref=ref, loc="l", rows=1) + assert len(axs[0, 0]._panel_dict["left"]) == 1 + assert len(axs[0, 2]._panel_dict["left"]) == 0 + + +def test_legend_best_axis_selection_top_bottom(): + fig, axs = uplt.subplots(nrows=2, ncols=1) + axs.plot([0, 1], [0, 1], label="line") + ref = [axs[0, 0], axs[1, 0]] + + fig.legend(ref=ref, loc="t", cols=1) + assert len(axs[0, 0]._panel_dict["top"]) == 1 + assert len(axs[1, 0]._panel_dict["top"]) == 0 + + fig.legend(ref=ref, loc="b", cols=1) + assert len(axs[1, 0]._panel_dict["bottom"]) == 1 + assert len(axs[0, 0]._panel_dict["bottom"]) == 0 + + +def test_legend_span_decode_fallback(monkeypatch): + fig, axs = uplt.subplots(nrows=2, ncols=2) + axs.plot([0, 1], [0, 1], label="line") + ref = axs[:, 0] + + gs = axs[0, 0].get_subplotspec().get_topmost_subplotspec().get_gridspec() + + def _raise_decode(*args, **kwargs): + raise ValueError("forced") + + monkeypatch.setattr(gs, "_decode_indices", _raise_decode) + leg = fig.legend(ref=ref, loc="r") + assert leg is not None From b4087a8023b28e954a5df9d51ac71d02de898d2f Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Thu, 15 Jan 2026 14:17:31 +1000 Subject: [PATCH 5/5] Extend legend coverage for edge ref handling Add tests that cover span inference with invalid ref entries, best-axis fallback on inset locations, and the empty-iterable ref fallback path. --- ultraplot/tests/test_legend.py | 44 ++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/ultraplot/tests/test_legend.py b/ultraplot/tests/test_legend.py index 4c7c943e8..f9157ddd0 100644 --- a/ultraplot/tests/test_legend.py +++ b/ultraplot/tests/test_legend.py @@ -704,3 +704,47 @@ def _raise_decode(*args, **kwargs): monkeypatch.setattr(gs, "_decode_indices", _raise_decode) leg = fig.legend(ref=ref, loc="r") assert leg is not None + + +def test_legend_span_inference_skips_invalid_ref_axes(): + class DummyNoSpec: + pass + + class DummyNullSpec: + def get_subplotspec(self): + return None + + fig, axs = uplt.subplots(nrows=1, ncols=2) + axs[0].plot([0, 1], [0, 1], label="line") + ref = [DummyNoSpec(), DummyNullSpec(), axs[0]] + + leg = fig.legend(ax=axs[0], ref=ref, loc="r") + assert leg is not None + assert len(axs[0]._panel_dict["right"]) == 1 + + +def test_legend_best_axis_fallback_with_inset_loc(): + fig, axs = uplt.subplots(nrows=1, ncols=2) + axs.plot([0, 1], [0, 1], label="line") + + leg = fig.legend(ref=axs, loc="upper left", rows=1) + assert leg is not None + + +def test_legend_best_axis_fallback_empty_iterable_ref(): + class LegendProxy: + def __init__(self, ax): + self._ax = ax + + def __iter__(self): + return iter(()) + + def legend(self, *args, **kwargs): + return self._ax.legend(*args, **kwargs) + + fig, ax = uplt.subplots() + ax.plot([0, 1], [0, 1], label="line") + proxy = LegendProxy(ax) + + leg = fig.legend(ref=proxy, loc="upper left", rows=1) + assert leg is not None