Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ and this project adheres to [Semantic Versioning][].
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
[semantic versioning]: https://semver.org/spec/v2.0.0.html

## [Unreleased]

### Fixes

- Filter unused categories from clonotype network legend ([#680](https://github.com/scverse/scirpy/pull/680/)).

## v0.23.0

### Changes
Expand Down
13 changes: 10 additions & 3 deletions src/scirpy/pl/_clonotypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,17 +617,24 @@ def _aggregate_per_dot_continuous(values):
**text_kwds,
)

# add legend for categorical colors
# add legend for categorical colors, showing only categories present in the plot
if cat_colors is not None and show_legend:
for cat, color in cat_colors.items():
used_colors = set()
if pie_colors is not None:
for pc in pie_colors:
used_colors.update(pc.keys())
visible_cat_colors = (
{cat: c for cat, c in cat_colors.items() if c in used_colors} if used_colors else cat_colors
)
for cat, color in visible_cat_colors.items():
# use empty scatter to set labels
legend_ax.scatter([], [], c=color, label=cat)
legend_ax.legend(
frameon=False,
loc="center left",
# bbox_to_anchor=(1, 0.5),
fontsize=legend_fontsize,
ncol=(1 if len(cat_colors) <= 14 else 2 if len(cat_colors) <= 30 else 3),
ncol=(1 if len(visible_cat_colors) <= 14 else 2 if len(visible_cat_colors) <= 30 else 3),
)
legend_ax.axis("off")

Expand Down
30 changes: 30 additions & 0 deletions src/scirpy/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,36 @@ def test_clonotype_network_pie(
assert isinstance(p, plt.Axes)


@pytest.mark.extra
def test_clonotype_network_pie_legend_filters_unused(adata_clonotype_network):
"""Legend should only contain categories present in the plotted network.

Regression test for https://github.com/scverse/scirpy/issues/679
"""
adata = adata_clonotype_network
# Add an extra category that no cell in the network has, to ensure
# the legend filters it out.
obs_col = "receptor_type"
tmp_ad = adata.mod["airr"] if isinstance(adata, MuData) else adata
tmp_ad.obs[obs_col] = tmp_ad.obs[obs_col].cat.add_categories("extra_unused")

fig = plt.gcf()
fig.clear()
p = pl.clonotype_network(adata, color=obs_col, show_legend=True)
assert isinstance(p, plt.Axes)

# Collect legend labels from all axes in the figure
legend_labels = []
for ax in fig.get_axes():
legend = ax.get_legend()
if legend is not None:
legend_labels.extend([t.get_text() for t in legend.get_texts()])

assert "extra_unused" not in legend_labels, (
f"Legend should not contain unused category 'extra_unused', got: {legend_labels}"
)


@pytest.mark.extra
def test_logoplot(adata_cdr3):
p = pl.logoplot_cdr3_motif(adata_cdr3, chains="VJ_1")
Expand Down
Loading