diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e1cea5df..fd6f3aadf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,12 @@ and this project adheres to [Semantic Versioning][]. [keep a changelog]: https://keepachangelog.com/en/1.0.0/ [semantic versioning]: https://semver.org/spec/v2.0.0.html +## [Unreleased] + +### Fixes + + - Filter unused categories from clonotype network legend ([#680](https://github.com/scverse/scirpy/pull/680/)). + ## v0.23.0 ### Changes diff --git a/src/scirpy/pl/_clonotypes.py b/src/scirpy/pl/_clonotypes.py index 2976fa55a..e388d396f 100644 --- a/src/scirpy/pl/_clonotypes.py +++ b/src/scirpy/pl/_clonotypes.py @@ -617,9 +617,16 @@ def _aggregate_per_dot_continuous(values): **text_kwds, ) - # add legend for categorical colors + # add legend for categorical colors, showing only categories present in the plot if cat_colors is not None and show_legend: - for cat, color in cat_colors.items(): + used_colors = set() + if pie_colors is not None: + for pc in pie_colors: + used_colors.update(pc.keys()) + visible_cat_colors = ( + {cat: c for cat, c in cat_colors.items() if c in used_colors} if used_colors else cat_colors + ) + for cat, color in visible_cat_colors.items(): # use empty scatter to set labels legend_ax.scatter([], [], c=color, label=cat) legend_ax.legend( @@ -627,7 +634,7 @@ def _aggregate_per_dot_continuous(values): loc="center left", # bbox_to_anchor=(1, 0.5), fontsize=legend_fontsize, - ncol=(1 if len(cat_colors) <= 14 else 2 if len(cat_colors) <= 30 else 3), + ncol=(1 if len(visible_cat_colors) <= 14 else 2 if len(visible_cat_colors) <= 30 else 3), ) legend_ax.axis("off") diff --git a/src/scirpy/tests/test_plotting.py b/src/scirpy/tests/test_plotting.py index d6113076d..93330a05c 100644 --- a/src/scirpy/tests/test_plotting.py +++ b/src/scirpy/tests/test_plotting.py @@ -163,6 +163,36 @@ def test_clonotype_network_pie( assert isinstance(p, plt.Axes) +@pytest.mark.extra +def test_clonotype_network_pie_legend_filters_unused(adata_clonotype_network): + """Legend should only contain categories present in the plotted network. + + Regression test for https://github.com/scverse/scirpy/issues/679 + """ + adata = adata_clonotype_network + # Add an extra category that no cell in the network has, to ensure + # the legend filters it out. + obs_col = "receptor_type" + tmp_ad = adata.mod["airr"] if isinstance(adata, MuData) else adata + tmp_ad.obs[obs_col] = tmp_ad.obs[obs_col].cat.add_categories("extra_unused") + + fig = plt.gcf() + fig.clear() + p = pl.clonotype_network(adata, color=obs_col, show_legend=True) + assert isinstance(p, plt.Axes) + + # Collect legend labels from all axes in the figure + legend_labels = [] + for ax in fig.get_axes(): + legend = ax.get_legend() + if legend is not None: + legend_labels.extend([t.get_text() for t in legend.get_texts()]) + + assert "extra_unused" not in legend_labels, ( + f"Legend should not contain unused category 'extra_unused', got: {legend_labels}" + ) + + @pytest.mark.extra def test_logoplot(adata_cdr3): p = pl.logoplot_cdr3_motif(adata_cdr3, chains="VJ_1")