From bee6f83782d7f01fad51a987362885bd431f738d Mon Sep 17 00:00:00 2001 From: Simon Rampp Date: Mon, 12 May 2025 17:54:42 +0200 Subject: [PATCH 1/5] add project to query and grouped query --- ablate/queries/grouped_query.py | 33 +++++++++++++++++++++++++++++++-- ablate/queries/query.py | 25 +++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/ablate/queries/grouped_query.py b/ablate/queries/grouped_query.py index a49d954..ac17d00 100644 --- a/ablate/queries/grouped_query.py +++ b/ablate/queries/grouped_query.py @@ -2,14 +2,14 @@ from collections import defaultdict from copy import deepcopy -from typing import TYPE_CHECKING, Callable, Dict, List, Literal +from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Union from ablate.core.types import GroupedRun, Run if TYPE_CHECKING: # pragma: no cover from .query import Query # noqa: TC004 - from .selectors import AbstractMetric + from .selectors import AbstractMetric, AbstractParam class GroupedQuery: @@ -75,6 +75,35 @@ def sort(self, key: AbstractMetric, ascending: bool = False) -> GroupedQuery: ] ) + def project( + self, selectors: Union[AbstractParam, List[AbstractParam]] + ) -> GroupedQuery: + """Project the parameter space of the grouped runs in the grouped query to a + subset of parameters only including the specified selectors. + + This function is intended to be used for reducing the dimensionality of the + parameter space and therefore operates on a deep copy of the grouped runs in the + grouped query. + + Args: + selectors: Selector or list of selectors to project the grouped runs by. + + Returns: + A new grouped query with the projected grouped runs. + """ + if not isinstance(selectors, list): + selectors = [selectors] + + names = {s.name for s in selectors} + projected: List[GroupedRun] = [] + + for group in deepcopy(self._grouped): + for run in group.runs: + run.params = {k: v for k, v in run.params.items() if k in names} + projected.append(group) + + return GroupedQuery(projected) + def head(self, n: int) -> Query: """Get the first n runs inside each grouped run. diff --git a/ablate/queries/query.py b/ablate/queries/query.py index 87306c5..af6c45f 100644 --- a/ablate/queries/query.py +++ b/ablate/queries/query.py @@ -64,6 +64,31 @@ def sort(self, key: AbstractMetric, ascending: bool = False) -> Query: """ return Query(sorted(self._runs[:], key=key, reverse=not ascending)) + def project(self, selectors: Union[AbstractParam, List[AbstractParam]]) -> Query: + """Project the parameter space of the runs in the query to a subset of + parameters only including the specified selectors. + + This function is intended to be used for reducing the dimensionality of the + parameter space and therefore operates on a deep copy of the runs in the query. + + Args: + selectors: Selector or list of selectors to project the runs by. + + Returns: + A new query with the projected runs. + """ + if not isinstance(selectors, list): + selectors = [selectors] + + names = {s.name for s in selectors} + projected: List[Run] = [] + + for run in deepcopy(self._runs): + run.params = {k: v for k, v in run.params.items() if k in names} + projected.append(run) + + return Query(projected) + def groupby( self, selectors: Union[AbstractParam, List[AbstractParam]], From 7dd394bb135c2142b203f7568ee21e763ee64a7e Mon Sep 17 00:00:00 2001 From: Simon Rampp Date: Mon, 12 May 2025 17:55:34 +0200 Subject: [PATCH 2/5] add tests --- tests/queries/test_grouped_query.py | 8 ++++++++ tests/queries/test_query.py | 7 +++++++ 2 files changed, 15 insertions(+) diff --git a/tests/queries/test_grouped_query.py b/tests/queries/test_grouped_query.py index 5093415..3f31fe5 100644 --- a/tests/queries/test_grouped_query.py +++ b/tests/queries/test_grouped_query.py @@ -111,3 +111,11 @@ def test_copy_and_deepcopy(grouped: GroupedQuery) -> None: assert all( dr is not gr for dr, gr in zip(deep._grouped, grouped._grouped, strict=False) ) + + +def test_grouped_query_project_reduces_param_space(grouped: GroupedQuery) -> None: + grouped = grouped.project(Param("model")) + for group in grouped._grouped: + for run in group.runs: + assert set(run.params.keys()) == {"model"} + assert set(run.metrics.keys()) == {"accuracy"} diff --git a/tests/queries/test_query.py b/tests/queries/test_query.py index 9ba3597..86ec73b 100644 --- a/tests/queries/test_query.py +++ b/tests/queries/test_query.py @@ -93,3 +93,10 @@ def test_query_deepcopy(runs: List[Run]) -> None: def test_query_len(runs: List[Run]) -> None: assert len(Query(runs)) == 3 + + +def test_project_reduces_parameter_space(runs: List[Run]) -> None: + q = Query(runs).project([Param("model")]) + for run in q.all(): + assert set(run.params.keys()) == {"model"} + assert set(run.metrics.keys()) == {"accuracy"} From 98fe382162677eb5ed1d1a943db36a6d2b2c32a5 Mon Sep 17 00:00:00 2001 From: Simon Rampp Date: Mon, 12 May 2025 17:55:47 +0200 Subject: [PATCH 3/5] add assertion to over if needed --- ablate/queries/grouped_query.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ablate/queries/grouped_query.py b/ablate/queries/grouped_query.py index ac17d00..056b9bc 100644 --- a/ablate/queries/grouped_query.py +++ b/ablate/queries/grouped_query.py @@ -224,8 +224,10 @@ def aggregate( case "last": return self.tail(1) case "best": + assert over is not None return self.topk(over, 1) case "worst": + assert over is not None return self.bottomk(over, 1) case "mean": return Query([self._mean_run(g) for g in self._grouped]) From b65b8ba7d59b7cd26b3ceeb33ba8003b4cb23cfe Mon Sep 17 00:00:00 2001 From: Simon Rampp Date: Mon, 12 May 2025 17:57:05 +0200 Subject: [PATCH 4/5] simplify render metric plot utility function interface --- ablate/exporters/markdown_exporter.py | 6 +----- ablate/exporters/utils.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/ablate/exporters/markdown_exporter.py b/ablate/exporters/markdown_exporter.py index 1cbf12a..4db50e9 100644 --- a/ablate/exporters/markdown_exporter.py +++ b/ablate/exporters/markdown_exporter.py @@ -57,11 +57,7 @@ def render_figure(self, block: AbstractFigureBlock, runs: List[Run]) -> str: if not isinstance(block, MetricPlot): raise NotImplementedError(f"Unsupported figure block: '{type(block)}'.") - filename = render_metric_plot( - block.build(runs), - self.assets_dir, - type(block).__name__, - ) + filename = render_metric_plot(block, runs, self.assets_dir) if filename is None: return ( f"*No data available for {', '.join(m.label for m in block.metrics)}*" diff --git a/ablate/exporters/utils.py b/ablate/exporters/utils.py index e710e68..08cd170 100644 --- a/ablate/exporters/utils.py +++ b/ablate/exporters/utils.py @@ -1,11 +1,12 @@ import hashlib from pathlib import Path +from typing import List import matplotlib.pyplot as plt -import pandas as pd import seaborn as sns -from ablate.blocks import H1, H2, H3, H4, H5, H6 +from ablate.blocks import H1, H2, H3, H4, H5, H6, MetricPlot +from ablate.core.types.runs import Run HEADING_LEVELS = {H1: 1, H2: 2, H3: 3, H4: 4, H5: 5, H6: 6} @@ -19,11 +20,12 @@ def apply_default_plot_style() -> None: def render_metric_plot( - df: pd.DataFrame, + block: MetricPlot, + runs: List[Run], output_dir: Path, - name_prefix: str, ) -> str | None: apply_default_plot_style() + df = block.build(runs) if df.empty: return None @@ -38,11 +40,11 @@ def render_metric_plot( ) ax.set_xlabel("Step") ax.set_ylabel("Value") - ax.legend(title="Run", loc="best", frameon=False) + ax.legend(title=block.identifier.label, loc="best", frameon=False) plt.tight_layout() h = hashlib.md5(df.to_csv(index=False).encode("utf-8")).hexdigest()[:12] - filename = f"{name_prefix}_{h}.png" + filename = f"{type(block).__name__}_{h}.png" fig.savefig(output_dir / filename) plt.close(fig) return filename From f51115188363c52e2e2008f050de49642ba96624 Mon Sep 17 00:00:00 2001 From: Simon Rampp Date: Mon, 12 May 2025 18:00:08 +0200 Subject: [PATCH 5/5] update tests --- tests/queries/test_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/queries/test_query.py b/tests/queries/test_query.py index 86ec73b..7ed459f 100644 --- a/tests/queries/test_query.py +++ b/tests/queries/test_query.py @@ -96,7 +96,7 @@ def test_query_len(runs: List[Run]) -> None: def test_project_reduces_parameter_space(runs: List[Run]) -> None: - q = Query(runs).project([Param("model")]) + q = Query(runs).project(Param("model")) for run in q.all(): assert set(run.params.keys()) == {"model"} assert set(run.metrics.keys()) == {"accuracy"}