From a4868b23d31986033ffdeba5957d096c57a69c62 Mon Sep 17 00:00:00 2001 From: Louis-Frey Date: Mon, 11 May 2026 18:10:31 +0200 Subject: [PATCH 1/9] Extract verification loading and lead-time panel plotter into shared modules Move the verif-netCDF loading helpers (_ensure_unique_lead_time, _select_best_sources, the long-form DataFrame builder, subset_df) into src/verification/loading.py, and the per-axes metric-vs-lead-time plotter into src/plotting/metric_lead_time_panel.py. Update verification_plot_metrics.py and report_experiment_dashboard.py to import from the new locations; the sys.path hack in the dashboard script is no longer needed. No behavior change. --- src/plotting/metric_lead_time_panel.py | 38 ++++++ src/verification/loading.py | 78 +++++++++++ .../scripts/report_experiment_dashboard.py | 4 +- workflow/scripts/verification_plot_metrics.py | 122 +++--------------- 4 files changed, 137 insertions(+), 105 deletions(-) create mode 100644 src/plotting/metric_lead_time_panel.py create mode 100644 src/verification/loading.py diff --git a/src/plotting/metric_lead_time_panel.py b/src/plotting/metric_lead_time_panel.py new file mode 100644 index 00000000..fbdc503f --- /dev/null +++ b/src/plotting/metric_lead_time_panel.py @@ -0,0 +1,38 @@ +"""Per-axes plotting helper for verification metrics vs. lead time.""" +import pandas as pd +from matplotlib.axes import Axes + + +def plot_panel( + ax: Axes, + sub_df: pd.DataFrame, + *, + metric: str, + title: str | None = None, + xlabel: str | None = "Lead Time [h]", + ylabel: str | None = None, + show_legend: bool = True, +) -> None: + """Plot one metric-vs-lead-time panel onto `ax`. + + `sub_df` must already be filtered to a single (metric, param, region, season, + init_hour) combo and contain at least the columns: source, lead_time, value. + One line per source is drawn; sources whose name contains "analysis" are + forced to black. + """ + if ylabel is None: + ylabel = metric + for source, df in sub_df.groupby("source"): + df.plot( + x="lead_time", + y="value", + kind="line", + marker="o", + title=title, + xlabel=xlabel or "", + ylabel=ylabel or "", + label=source, + color="black" if "analysis" in source else None, + ax=ax, + legend=show_legend, + ) diff --git a/src/verification/loading.py b/src/verification/loading.py new file mode 100644 index 00000000..153e3ec3 --- /dev/null +++ b/src/verification/loading.py @@ -0,0 +1,78 @@ +"""Helpers for loading aggregated verification netCDFs into long-form DataFrames.""" +from pathlib import Path + +import pandas as pd +import xarray as xr + + +def _ensure_unique_lead_time(ds: xr.Dataset) -> xr.Dataset: + """Drop duplicate lead_time entries within a Dataset (keep first occurrence).""" + try: + idx = ds.get_index("lead_time") + except Exception: + idx = pd.Index(ds["lead_time"].values) + if getattr(idx, "has_duplicates", False): + keep = ~idx.duplicated(keep="first") + ds = ds.isel(lead_time=keep) + return ds + + +def _select_best_sources(dfs: list[xr.Dataset]) -> list[xr.Dataset]: + """For sources present in multiple datasets, keep the one with the most lead_times.""" + src_sets = [set(d.source.values.tolist()) for d in dfs] + all_sources = set().union(*src_sets) + + best: dict[str, int] = {} + for s in all_sources: + candidates = [] + for i, d in enumerate(dfs): + if s in d.source.values: + di = d.sel(source=s) + try: + n = pd.Index(di["lead_time"].values).unique().size + except Exception: + n = len(pd.unique(di["lead_time"].values)) + candidates.append((i, n)) + if candidates: + best_idx, _ = max(candidates, key=lambda t: t[1]) + best[s] = best_idx + + out = [] + for i, d in enumerate(dfs): + drop_src = [s for s, b in best.items() if b != i and s in d.source.values] + if drop_src: + d = d.drop_sel(source=drop_src) + out.append(d) + return out + + +def load_long_df(verif_files: list[Path]) -> pd.DataFrame: + """Open verification netCDFs and return a long-form DataFrame. + + Columns: source, lead_time (hours, float), region, season, init_hour, + param, metric, value. + """ + dfs = [xr.open_dataset(f) for f in verif_files] + dfs = [_ensure_unique_lead_time(d) for d in dfs] + dfs = _select_best_sources(dfs) + ds = xr.concat(dfs, dim="source", join="outer") + + nonspatial_vars = [d for d in ds.data_vars if "spatial" not in d] + df = ( + ds[nonspatial_vars].to_array("stack").to_dataframe(name="value").reset_index() + ) + df[["param", "metric"]] = df["stack"].str.split(".", n=1, expand=True) + df.drop(columns=["stack"], inplace=True) + df["lead_time"] = df["lead_time"].dt.total_seconds() / 3600 + return df + + +def subset_df(df: pd.DataFrame, **kwargs) -> pd.DataFrame: + """Return rows of `df` matching every column=value (or column in [values]) constraint.""" + mask = pd.Series([True] * len(df)) + for key, value in kwargs.items(): + if isinstance(value, (list, tuple, set)): + mask &= df[key].isin(value) + else: + mask &= df[key] == value + return df[mask] diff --git a/workflow/scripts/report_experiment_dashboard.py b/workflow/scripts/report_experiment_dashboard.py index 866d18a5..bf333537 100644 --- a/workflow/scripts/report_experiment_dashboard.py +++ b/workflow/scripts/report_experiment_dashboard.py @@ -1,14 +1,12 @@ import argparse import logging -import sys as _sys from pathlib import Path import jinja2 import xarray as xr -_sys.path.append(str(Path(__file__).parent)) -from verification_plot_metrics import _ensure_unique_lead_time, _select_best_sources from verification import decode_metric +from verification.loading import _ensure_unique_lead_time, _select_best_sources LOG = logging.getLogger(__name__) logging.basicConfig( diff --git a/workflow/scripts/verification_plot_metrics.py b/workflow/scripts/verification_plot_metrics.py index 92149d82..673654ba 100644 --- a/workflow/scripts/verification_plot_metrics.py +++ b/workflow/scripts/verification_plot_metrics.py @@ -5,9 +5,10 @@ from pathlib import Path import matplotlib.pyplot as plt -import pandas as pd -import xarray as xr + +from plotting.metric_lead_time_panel import plot_panel from verification import decode_metric +from verification.loading import load_long_df, subset_df LOG = logging.getLogger(__name__) logging.basicConfig( @@ -15,83 +16,10 @@ ) -def _ensure_unique_lead_time(ds: xr.Dataset) -> xr.Dataset: - """Drop duplicate lead_time entries within a Dataset (keep first occurrence).""" - try: - idx = ds.get_index("lead_time") - except Exception: - idx = pd.Index(ds["lead_time"].values) - if getattr(idx, "has_duplicates", False): - keep = ~idx.duplicated(keep="first") - ds = ds.isel(lead_time=keep) - return ds - - -def _select_best_sources(dfs: list[xr.Dataset]) -> list[xr.Dataset]: - """ - If the same 'source' exists in multiple datasets, keep it only from the dataset - that has the largest number of unique lead_time entries. Drop it from others. - """ - # Compute unique sources per dataset - src_sets = [set(d.source.values.tolist()) for d in dfs] - all_sources = set().union(*src_sets) - - # Decide best provider (dataset index) for each source - best = {} - for s in all_sources: - candidates = [] - for i, d in enumerate(dfs): - if s in d.source.values: - di = d.sel(source=s) - try: - n = pd.Index(di["lead_time"].values).unique().size - except Exception: - n = len(pd.unique(di["lead_time"].values)) - candidates.append((i, n)) - if candidates: - best_idx, _ = max(candidates, key=lambda t: t[1]) - best[s] = best_idx - - # Drop non-best occurrences - out = [] - for i, d in enumerate(dfs): - drop_src = [s for s, b in best.items() if b != i and s in d.source.values] - if drop_src: - d = d.drop_sel(source=drop_src) - out.append(d) - return out - - -def subset_df(df, **kwargs): - mask = pd.Series([True] * len(df)) - for key, value in kwargs.items(): - if isinstance(value, (list, tuple, set)): - mask &= df[key].isin(value) - else: - mask &= df[key] == value - return df[mask] - - def main(args: Namespace) -> None: """Main function to verify results from KENDA-1 data.""" - # remove duplicated but not identical values from analyses (rounding errors) - dfs = [xr.open_dataset(f) for f in args.verif_files] - # 1) Ensure each dataset has unique lead_time values - dfs = [_ensure_unique_lead_time(d) for d in dfs] - # 2) For sources present in multiple datasets, keep the one with most lead_times - dfs = _select_best_sources(dfs) - # 3) Concatenate by source; outer join to keep the union of lead_times - ds = xr.concat(dfs, dim="source", join="outer") - - # extract only non-spatial variables to pd.DataFrame - nonspatial_vars = [d for d in ds.data_vars if "spatial" not in d] - all_df = ( - ds[nonspatial_vars].to_array("stack").to_dataframe(name="value").reset_index() - ) - all_df[["param", "metric"]] = all_df["stack"].str.split(".", n=1, expand=True) - all_df.drop(columns=["stack"], inplace=True) - all_df["lead_time"] = all_df["lead_time"].dt.total_seconds() / 3600 + all_df = load_long_df(args.verif_files) metrics = all_df["metric"].unique() params = all_df["param"].unique() @@ -108,17 +36,14 @@ def main(args: Namespace) -> None: f"Processing region: {region}, metric: {metric}, param: {param}, season: {season}, init_hour: {init_hour}" ) - def _subset_df(df): - return subset_df( - df, - region=region, - metric=metric, - param=param, - season=season, - init_hour=init_hour, - ) - - sub_df = _subset_df(all_df).dropna() + sub_df = subset_df( + all_df, + region=region, + metric=metric, + param=param, + season=season, + init_hour=init_hour, + ).dropna() if sub_df.empty: continue @@ -126,19 +51,14 @@ def _subset_df(df): title = f"{metric} - {param} - {region}" title += f"- {season} - {init_hour}" if args.stratify else "" - for source, df in sub_df.groupby("source"): - df.plot( - x="lead_time", - y="value", - kind="line", - marker="o", - title=title, - xlabel="Lead Time [h]", - ylabel=decode_metric(metric), - label=source, - color="black" if "analysis" in source else None, - ax=ax, - ) + plot_panel( + ax, + sub_df, + metric=metric, + title=title, + ylabel=decode_metric(metric), + ) + args.output_dir.mkdir(parents=True, exist_ok=True) fn = f"{metric}_{param}" fn += f"_{season}_{init_hour}.png" if args.stratify else ".png" @@ -153,8 +73,6 @@ def _subset_df(df): type=Path, nargs="+", help="Paths to verification files.", - # "--verif_files", type=Path, nargs="+", help="Paths to verification files.", - # default = list(Path("output/data").glob("*/*/verif_aggregated.nc")), required=False ) parser.add_argument( "--stratify", From d2646f43ccc6da70e8faadbd965eea8022f5a95a Mon Sep 17 00:00:00 2001 From: Louis-Frey Date: Mon, 11 May 2026 18:12:13 +0200 Subject: [PATCH 2/9] Add multi-panel metric-vs-lead-time plots (MRB-860) New rule verification_metrics_multipanel_plot builds one PNG per named layout under results//multipanel/.png, driven by a JSON spec passed inline. Aggregator rule verification_metrics_multipanel_plot_all expands over every entry of the new optional config field multipanel_plots (no-op when absent, so existing configs are unaffected). Each layout specifies rows, cols, optional figsize and figure title, and a row-major list of panels (metric, param, optional region/season/init_hour/title/ylim). The script reuses load_long_df and plot_panel, draws all panels with sharex=True and independent y-axes, and emits a single deduped legend at the bottom of the figure. --- src/evalml/config.py | 70 ++++++++- workflow/rules/verification.smk | 47 ++++++ .../verification_plot_metrics_multipanel.py | 139 ++++++++++++++++++ 3 files changed, 255 insertions(+), 1 deletion(-) create mode 100644 workflow/scripts/verification_plot_metrics_multipanel.py diff --git a/src/evalml/config.py b/src/evalml/config.py index f1343d07..b65365a1 100644 --- a/src/evalml/config.py +++ b/src/evalml/config.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import Dict, List, Any, ClassVar, FrozenSet -from pydantic import BaseModel, Field, RootModel, field_validator +from pydantic import BaseModel, Field, RootModel, field_validator, model_validator PROJECT_ROOT = Path(__file__).parents[2] @@ -227,6 +227,66 @@ class Stratification(BaseModel): ) +class MultipanelPanelSpec(BaseModel): + """One panel inside a multi-panel metric-vs-lead-time figure.""" + + metric: str = Field(..., description="Metric name (e.g. 'rmse').") + param: str = Field(..., description="Parameter name (e.g. 'T_2M').") + region: str = Field( + "all", + description="Region to subset to. 'all' uses the unstratified aggregate.", + ) + season: str = Field( + "all", + description="Season to subset to. 'all' uses the unstratified aggregate.", + ) + init_hour: int = Field( + -999, + description="Init hour to subset to. -999 (sentinel) uses the unstratified aggregate.", + ) + title: str | None = Field( + None, + description="Panel title. Defaults to ' - '.", + ) + ylim: List[float] | None = Field( + None, + description="Optional [ymin, ymax] for this panel's y-axis.", + min_length=2, + max_length=2, + ) + + model_config = {"extra": "forbid"} + + +class MultipanelPlotSpec(BaseModel): + """Layout for a single multi-panel metric-vs-lead-time figure.""" + + rows: int = Field(..., ge=1, description="Number of subplot rows.") + cols: int = Field(..., ge=1, description="Number of subplot columns.") + figsize: List[float] | None = Field( + None, + description="Optional [width, height] in inches. Defaults to (4.5*cols, 3.5*rows).", + min_length=2, + max_length=2, + ) + title: str | None = Field(None, description="Optional figure-level title.") + panels: List[MultipanelPanelSpec] = Field( + ..., + description="Per-panel specs in row-major order. Length must equal rows*cols.", + ) + + model_config = {"extra": "forbid"} + + @model_validator(mode="after") + def _check_panel_count(self) -> "MultipanelPlotSpec": + expected = self.rows * self.cols + if len(self.panels) != expected: + raise ValueError( + f"panels has length {len(self.panels)}, expected rows*cols = {expected}" + ) + return self + + class Dashboard(BaseModel): """Settings for the dashboard""" @@ -351,6 +411,14 @@ def validate_threshold_operators( dashboard: Dashboard locations: Locations profile: Profile + multipanel_plots: Dict[str, MultipanelPlotSpec] = Field( + default_factory=dict, + description=( + "Optional named multi-panel metric-vs-lead-time figures. " + "Each entry produces one PNG under results//multipanel/.png " + "when the verification_metrics_multipanel_plot_all target is built." + ), + ) model_config = { "extra": "forbid", # fail on misspelled keys diff --git a/workflow/rules/verification.smk b/workflow/rules/verification.smk index 58b215e5..335e4cbc 100644 --- a/workflow/rules/verification.smk +++ b/workflow/rules/verification.smk @@ -1,6 +1,8 @@ # ----------------------------------------------------- # # VERIFICATION WORKFLOW # # ----------------------------------------------------- # +import json +import shlex from datetime import datetime import pandas as pd @@ -164,3 +166,48 @@ rule verification_metrics_plot: """ uv run {input.script} {input.verif} --output_dir {output} > {log} 2>&1 """ + + +def _multipanel_plots_cfg() -> dict: + return config.get("multipanel_plots") or {} + + +rule verification_metrics_multipanel_plot: + input: + "src/verification/__init__.py", + script="workflow/scripts/verification_plot_metrics_multipanel.py", + verif=list(EXPERIMENT_PARTICIPANTS.values()), + output: + OUT_ROOT / "results/{experiment}/multipanel/{plot_name}.png", + params: + spec_json=lambda wc: shlex.quote( + json.dumps(_multipanel_plots_cfg()[wc.plot_name]) + ), + log: + OUT_ROOT + / "logs/verification_metrics_multipanel_plot/{experiment}-{plot_name}.log", + resources: + cpus_per_task=4, + mem_mb=20_000, + runtime="20m", + shell: + """ + uv run {input.script} {input.verif} \ + --spec_json {params.spec_json} \ + --output {output} > {log} 2>&1 + """ + + +rule verification_metrics_multipanel_plot_all: + """Build every multipanel layout declared in `multipanel_plots` in the config. + + Invoke by rule name (no wildcards). No-op when the config has no + `multipanel_plots` section. + """ + localrule: True + input: + lambda wc: expand( + OUT_ROOT / "results/{experiment}/multipanel/{plot_name}.png", + experiment=[EXPERIMENT_NAME], + plot_name=list(_multipanel_plots_cfg().keys()), + ), diff --git a/workflow/scripts/verification_plot_metrics_multipanel.py b/workflow/scripts/verification_plot_metrics_multipanel.py new file mode 100644 index 00000000..258b2cc2 --- /dev/null +++ b/workflow/scripts/verification_plot_metrics_multipanel.py @@ -0,0 +1,139 @@ +"""Build a multi-panel metric-vs-lead-time figure from aggregated verification files. + +The panel layout (rows, cols, per-panel selectors) is supplied as a JSON spec +either inline (``--spec_json ''``) or as a path to a JSON file +(``--spec_path /path/to/spec.json``). The spec schema mirrors +``MultipanelPlotSpec`` in ``src/evalml/config.py``. +""" +import json +import logging +from argparse import ArgumentParser +from argparse import Namespace +from pathlib import Path + +import matplotlib.pyplot as plt + +from plotting.metric_lead_time_panel import plot_panel +from verification import decode_metric +from verification.loading import load_long_df, subset_df + +LOG = logging.getLogger(__name__) +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) + + +def _load_spec(args: Namespace) -> dict: + if args.spec_json: + return json.loads(args.spec_json) + return json.loads(args.spec_path.read_text()) + + +def main(args: Namespace) -> None: + spec = _load_spec(args) + rows = int(spec["rows"]) + cols = int(spec["cols"]) + panels = spec["panels"] + if len(panels) != rows * cols: + raise ValueError( + f"panels has length {len(panels)}, expected rows*cols = {rows * cols}" + ) + + all_df = load_long_df(args.verif_files) + + figsize = tuple(spec.get("figsize") or (4.5 * cols, 3.5 * rows)) + fig, axes = plt.subplots( + rows, cols, sharex=True, figsize=figsize, squeeze=False + ) + + legend_entries: dict[str, object] = {} + for idx, panel in enumerate(panels): + r, c = divmod(idx, cols) + ax = axes[r][c] + metric = panel["metric"] + param = panel["param"] + sub = subset_df( + all_df, + metric=metric, + param=param, + region=panel.get("region", "all"), + season=panel.get("season", "all"), + init_hour=panel.get("init_hour", -999), + ).dropna() + if sub.empty: + LOG.warning( + "No data for panel %d (metric=%s, param=%s, region=%s, season=%s, init_hour=%s)", + idx, metric, param, + panel.get("region", "all"), + panel.get("season", "all"), + panel.get("init_hour", -999), + ) + + is_bottom = r == rows - 1 + is_left = c == 0 + title = panel.get("title", f"{metric} - {param}") + plot_panel( + ax, + sub, + metric=metric, + title=title, + xlabel="Lead Time [h]" if is_bottom else None, + ylabel=decode_metric(metric) if is_left else None, + show_legend=False, + ) + if panel.get("ylim"): + ax.set_ylim(panel["ylim"]) + + handles, labels = ax.get_legend_handles_labels() + for handle, label in zip(handles, labels): + legend_entries.setdefault(label, handle) + + if spec.get("title"): + fig.suptitle(spec["title"]) + + if legend_entries: + fig.legend( + list(legend_entries.values()), + list(legend_entries.keys()), + loc="lower center", + ncol=min(len(legend_entries), 4), + bbox_to_anchor=(0.5, 0.0), + ) + + top = 0.95 if spec.get("title") else 0.98 + fig.tight_layout(rect=[0, 0.08, 1, top]) + + args.output.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(args.output, dpi=150) + plt.close(fig) + + +if __name__ == "__main__": + parser = ArgumentParser(description=__doc__) + parser.add_argument( + "verif_files", + type=Path, + nargs="+", + help="Paths to aggregated verification netCDFs.", + ) + spec_group = parser.add_mutually_exclusive_group(required=True) + spec_group.add_argument( + "--spec_json", + type=str, + default=None, + help="Inline JSON string describing the panel layout.", + ) + spec_group.add_argument( + "--spec_path", + type=Path, + default=None, + help="Path to a JSON file describing the panel layout.", + ) + parser.add_argument( + "--output", + type=Path, + required=True, + help="Output PNG path.", + ) + args = parser.parse_args() + main(args) From 72cf176aa859012cefef48abe6768b999b5981dd Mon Sep 17 00:00:00 2001 From: Louis-Frey Date: Mon, 11 May 2026 19:53:42 +0200 Subject: [PATCH 3/9] Build multipanel plots as part of experiment_all Expand experiment_all's inputs over the multipanel_plots config entries so `evalml experiment` produces the layouts declared in the YAML. No-op when the section is absent. --- workflow/Snakefile | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/workflow/Snakefile b/workflow/Snakefile index a2586cb4..34a2344b 100644 --- a/workflow/Snakefile +++ b/workflow/Snakefile @@ -133,6 +133,11 @@ rule experiment_all: rules.verification_metrics_plot.output, experiment=EXPERIMENT_NAME, ), + expand( + OUT_ROOT / "results/{experiment}/multipanel/{plot_name}.png", + experiment=[EXPERIMENT_NAME], + plot_name=list((config.get("multipanel_plots") or {}).keys()), + ), rule showcase_all: From 8929db9b4417afba882cabb089732ae0803bf66d Mon Sep 17 00:00:00 2001 From: Louis-Frey Date: Mon, 11 May 2026 19:54:50 +0200 Subject: [PATCH 4/9] Improve lead-time metric plots: units, panel labels, spacing * New src/plotting/units.py: PARAM_UNITS mapping + metric_units(metric, param) helper that yields '' for CORR/R2 and the param's canonical units otherwise. * plot_panel now accepts `param` and auto-builds the y-axis label as " []" when no explicit ylabel is given. It also accepts `panel_label` (e.g. "a)") rendered as a bold, left-aligned title at the same height as the centred title. * verification_plot_metrics.py passes `param` so single-panel plots pick up units automatically. * verification_plot_metrics_multipanel.py: numbers panels a), b), ..., in row-major order; replaces tight_layout with explicit subplots_adjust margins so inter-panel hspace/wspace are honoured and the bottom legend has guaranteed room. --- src/plotting/metric_lead_time_panel.py | 22 ++++++++++++- src/plotting/units.py | 26 +++++++++++++++ workflow/scripts/verification_plot_metrics.py | 3 +- .../verification_plot_metrics_multipanel.py | 32 +++++++++++++++++-- 4 files changed, 77 insertions(+), 6 deletions(-) create mode 100644 src/plotting/units.py diff --git a/src/plotting/metric_lead_time_panel.py b/src/plotting/metric_lead_time_panel.py index fbdc503f..3a144209 100644 --- a/src/plotting/metric_lead_time_panel.py +++ b/src/plotting/metric_lead_time_panel.py @@ -2,13 +2,25 @@ import pandas as pd from matplotlib.axes import Axes +from verification import decode_metric + +from .units import metric_units + + +def _default_ylabel(metric: str, param: str | None) -> str: + label = decode_metric(metric) + units = metric_units(metric, param) if param is not None else "" + return f"{label} [{units}]" if units else label + def plot_panel( ax: Axes, sub_df: pd.DataFrame, *, metric: str, + param: str | None = None, title: str | None = None, + panel_label: str | None = None, xlabel: str | None = "Lead Time [h]", ylabel: str | None = None, show_legend: bool = True, @@ -19,9 +31,15 @@ def plot_panel( init_hour) combo and contain at least the columns: source, lead_time, value. One line per source is drawn; sources whose name contains "analysis" are forced to black. + + If `ylabel` is None and `param` is provided, the y-axis label is built as + " []" via plotting.units.metric_units. + + `panel_label` (e.g. "a)") is rendered left-aligned at the same height as + the centred title. """ if ylabel is None: - ylabel = metric + ylabel = _default_ylabel(metric, param) for source, df in sub_df.groupby("source"): df.plot( x="lead_time", @@ -36,3 +54,5 @@ def plot_panel( ax=ax, legend=show_legend, ) + if panel_label: + ax.set_title(panel_label, loc="left", fontweight="bold") diff --git a/src/plotting/units.py b/src/plotting/units.py new file mode 100644 index 00000000..22d219d4 --- /dev/null +++ b/src/plotting/units.py @@ -0,0 +1,26 @@ +"""Canonical units for verification parameters and metrics. + +Storage units in the verification netCDFs (BIAS, RMSE, MAE, STDE, ... all +inherit these). Update the dict if a parameter's internal representation +changes. +""" + +PARAM_UNITS: dict[str, str] = { + "T_2M": "K", + "TD_2M": "K", + "PMSL": "Pa", + "PS": "Pa", + "TOT_PREC": "mm", + "U_10M": "m/s", + "V_10M": "m/s", + "SP_10M": "m/s", +} + +UNITLESS_METRICS: set[str] = {"CORR", "R2"} + + +def metric_units(metric: str, param: str) -> str: + """Return the canonical units of (metric, param), or '' if unitless/unknown.""" + if metric.upper() in UNITLESS_METRICS: + return "" + return PARAM_UNITS.get(param, "") diff --git a/workflow/scripts/verification_plot_metrics.py b/workflow/scripts/verification_plot_metrics.py index 673654ba..b287241c 100644 --- a/workflow/scripts/verification_plot_metrics.py +++ b/workflow/scripts/verification_plot_metrics.py @@ -7,7 +7,6 @@ import matplotlib.pyplot as plt from plotting.metric_lead_time_panel import plot_panel -from verification import decode_metric from verification.loading import load_long_df, subset_df LOG = logging.getLogger(__name__) @@ -55,8 +54,8 @@ def main(args: Namespace) -> None: ax, sub_df, metric=metric, + param=param, title=title, - ylabel=decode_metric(metric), ) args.output_dir.mkdir(parents=True, exist_ok=True) diff --git a/workflow/scripts/verification_plot_metrics_multipanel.py b/workflow/scripts/verification_plot_metrics_multipanel.py index 258b2cc2..e26876b4 100644 --- a/workflow/scripts/verification_plot_metrics_multipanel.py +++ b/workflow/scripts/verification_plot_metrics_multipanel.py @@ -7,6 +7,7 @@ """ import json import logging +import string from argparse import ArgumentParser from argparse import Namespace from pathlib import Path @@ -16,6 +17,16 @@ from plotting.metric_lead_time_panel import plot_panel from verification import decode_metric from verification.loading import load_long_df, subset_df +from plotting.units import metric_units + + +def _panel_label(idx: int) -> str: + """Return 'a)', 'b)', ..., 'z)', 'aa)', ... for the given 0-based index.""" + letters = string.ascii_lowercase + if idx < len(letters): + return f"{letters[idx]})" + a, b = divmod(idx, len(letters)) + return f"{letters[a - 1]}{letters[b]})" LOG = logging.getLogger(__name__) logging.basicConfig( @@ -72,13 +83,21 @@ def main(args: Namespace) -> None: is_bottom = r == rows - 1 is_left = c == 0 title = panel.get("title", f"{metric} - {param}") + units = metric_units(metric, param) + ylabel = ( + (f"{decode_metric(metric)} [{units}]" if units else decode_metric(metric)) + if is_left + else None + ) plot_panel( ax, sub, metric=metric, + param=param, title=title, + panel_label=_panel_label(idx), xlabel="Lead Time [h]" if is_bottom else None, - ylabel=decode_metric(metric) if is_left else None, + ylabel=ylabel, show_legend=False, ) if panel.get("ylim"): @@ -100,8 +119,15 @@ def main(args: Namespace) -> None: bbox_to_anchor=(0.5, 0.0), ) - top = 0.95 if spec.get("title") else 0.98 - fig.tight_layout(rect=[0, 0.08, 1, top]) + top = 0.92 if spec.get("title") else 0.96 + fig.subplots_adjust( + left=0.09, + right=0.97, + top=top, + bottom=0.13, + hspace=0.45, + wspace=0.28, + ) args.output.parent.mkdir(parents=True, exist_ok=True) fig.savefig(args.output, dpi=150) From e3d019ec4dd67e3ac2b055b9d0b9e0b7f6569ab2 Mon Sep 17 00:00:00 2001 From: Louis-Frey Date: Wed, 13 May 2026 11:46:05 +0200 Subject: [PATCH 5/9] Regenerate config JSON schema for multipanel_plots Pick up MultipanelPanelSpec, MultipanelPlotSpec, and the new multipanel_plots field on ConfigModel. Generated via `python src/evalml/config.py workflow/tools/config.schema.json`. --- workflow/tools/config.schema.json | 143 ++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) diff --git a/workflow/tools/config.schema.json b/workflow/tools/config.schema.json index d6153e87..a41ab8a7 100644 --- a/workflow/tools/config.schema.json +++ b/workflow/tools/config.schema.json @@ -445,6 +445,141 @@ "title": "Locations", "type": "object" }, + "MultipanelPanelSpec": { + "additionalProperties": false, + "description": "One panel inside a multi-panel metric-vs-lead-time figure.", + "properties": { + "metric": { + "description": "Metric name (e.g. 'rmse').", + "title": "Metric", + "type": "string" + }, + "param": { + "description": "Parameter name (e.g. 'T_2M').", + "title": "Param", + "type": "string" + }, + "region": { + "default": "all", + "description": "Region to subset to. 'all' uses the unstratified aggregate.", + "title": "Region", + "type": "string" + }, + "season": { + "default": "all", + "description": "Season to subset to. 'all' uses the unstratified aggregate.", + "title": "Season", + "type": "string" + }, + "init_hour": { + "default": -999, + "description": "Init hour to subset to. -999 (sentinel) uses the unstratified aggregate.", + "title": "Init Hour", + "type": "integer" + }, + "title": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Panel title. Defaults to ' - '.", + "title": "Title" + }, + "ylim": { + "anyOf": [ + { + "items": { + "type": "number" + }, + "maxItems": 2, + "minItems": 2, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Optional [ymin, ymax] for this panel's y-axis.", + "title": "Ylim" + } + }, + "required": [ + "metric", + "param" + ], + "title": "MultipanelPanelSpec", + "type": "object" + }, + "MultipanelPlotSpec": { + "additionalProperties": false, + "description": "Layout for a single multi-panel metric-vs-lead-time figure.", + "properties": { + "rows": { + "description": "Number of subplot rows.", + "minimum": 1, + "title": "Rows", + "type": "integer" + }, + "cols": { + "description": "Number of subplot columns.", + "minimum": 1, + "title": "Cols", + "type": "integer" + }, + "figsize": { + "anyOf": [ + { + "items": { + "type": "number" + }, + "maxItems": 2, + "minItems": 2, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Optional [width, height] in inches. Defaults to (4.5*cols, 3.5*rows).", + "title": "Figsize" + }, + "title": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Optional figure-level title.", + "title": "Title" + }, + "panels": { + "description": "Per-panel specs in row-major order. Length must equal rows*cols.", + "items": { + "$ref": "#/$defs/MultipanelPanelSpec" + }, + "title": "Panels", + "type": "array" + } + }, + "required": [ + "rows", + "cols", + "panels" + ], + "title": "MultipanelPlotSpec", + "type": "object" + }, "Profile": { "description": "Workflow execution profile.", "properties": { @@ -624,6 +759,14 @@ }, "profile": { "$ref": "#/$defs/Profile" + }, + "multipanel_plots": { + "additionalProperties": { + "$ref": "#/$defs/MultipanelPlotSpec" + }, + "description": "Optional named multi-panel metric-vs-lead-time figures. Each entry produces one PNG under results//multipanel/.png when the verification_metrics_multipanel_plot_all target is built.", + "title": "Multipanel Plots", + "type": "object" } }, "required": [ From 3406ae55863600e4bee40742ac464af46bd62fde Mon Sep 17 00:00:00 2001 From: Louis-Frey Date: Wed, 13 May 2026 11:46:51 +0200 Subject: [PATCH 6/9] Pin source-color mapping across dashboard and matplotlib plots Source -> color is now bijective and stable across the dashboard, the single-panel plots, and the multipanel plots. * src/plotting/source_colors.py: TABLEAU10 palette and a source_color_map() helper that assigns colors over the alphabetically-sorted full source list. Wraps past 10 sources; switch palettes if that becomes an issue in practice. * plot_panel grows an optional color_map arg. Both single-panel and multipanel scripts build one map from all_df["source"].unique() and pass it down. The matplotlib-only "analysis = black" override is dropped to match the dashboard. * resources/report/dashboard/script.js pins the color scale's domain to the full source list and uses Vega-Lite's "tableau10" scheme, so toggling sources in the UI no longer reshuffles colors. --- resources/report/dashboard/script.js | 7 ++++ src/plotting/metric_lead_time_panel.py | 11 ++++-- src/plotting/source_colors.py | 38 +++++++++++++++++++ workflow/scripts/verification_plot_metrics.py | 3 ++ .../verification_plot_metrics_multipanel.py | 5 ++- 5 files changed, 60 insertions(+), 4 deletions(-) create mode 100644 src/plotting/source_colors.py diff --git a/resources/report/dashboard/script.js b/resources/report/dashboard/script.js index 170ee050..f2936021 100644 --- a/resources/report/dashboard/script.js +++ b/resources/report/dashboard/script.js @@ -62,6 +62,12 @@ document.getElementById("param-select").addEventListener("change", updateChart); data = JSON.parse(document.getElementById("verif-data").textContent) header = document.getElementById("header-text").textContent.trim() +// Pin the source -> color mapping to the full, alphabetically-sorted source +// list so it stays bijective even when sources are toggled in the UI. Must +// match src/plotting/source_colors.py to keep the dashboard and the static +// matplotlib figures consistent. +const allSources = [...new Set(data.map(d => d.source))].sort(); + // Define base spec var spec = { "data": { "values": data }, @@ -106,6 +112,7 @@ var spec = { "color": { "field": "source", "type": "nominal", + "scale": { "scheme": "tableau10", "domain": allSources }, "legend": { "orient": "top", "title": "Data Source", "offset": 0, "padding": 10 } }, "shape": { diff --git a/src/plotting/metric_lead_time_panel.py b/src/plotting/metric_lead_time_panel.py index 3a144209..96137a8d 100644 --- a/src/plotting/metric_lead_time_panel.py +++ b/src/plotting/metric_lead_time_panel.py @@ -24,19 +24,24 @@ def plot_panel( xlabel: str | None = "Lead Time [h]", ylabel: str | None = None, show_legend: bool = True, + color_map: dict[str, str] | None = None, ) -> None: """Plot one metric-vs-lead-time panel onto `ax`. `sub_df` must already be filtered to a single (metric, param, region, season, init_hour) combo and contain at least the columns: source, lead_time, value. - One line per source is drawn; sources whose name contains "analysis" are - forced to black. + One line per source is drawn. If `ylabel` is None and `param` is provided, the y-axis label is built as " []" via plotting.units.metric_units. `panel_label` (e.g. "a)") is rendered left-aligned at the same height as the centred title. + + If `color_map` is given, each source's line is drawn in + ``color_map[source]``; sources missing from the map fall back to + matplotlib's default color cycle. Use ``plotting.source_colors.source_color_map`` + to build a map that matches the dashboard. """ if ylabel is None: ylabel = _default_ylabel(metric, param) @@ -50,7 +55,7 @@ def plot_panel( xlabel=xlabel or "", ylabel=ylabel or "", label=source, - color="black" if "analysis" in source else None, + color=(color_map or {}).get(source), ax=ax, legend=show_legend, ) diff --git a/src/plotting/source_colors.py b/src/plotting/source_colors.py new file mode 100644 index 00000000..8f6a48c5 --- /dev/null +++ b/src/plotting/source_colors.py @@ -0,0 +1,38 @@ +"""Stable source -> color mapping shared with the dashboard. + +The dashboard uses Vega-Lite's ``tableau10`` categorical scheme and pins its +``color.scale.domain`` to the alphabetically-sorted full source list so the +mapping stays bijective regardless of dashboard filters. The matplotlib plots +use the same palette and ordering so a given source has the same color in +every figure produced from a verification run. + +Both the dashboard and the matplotlib side wrap around when there are more +than ``len(TABLEAU10)`` sources, at which point two sources will share a +color. Switch palettes (e.g. to ``tableau20`` or a deterministic HSV ramp) +if that becomes a problem. +""" + +# Vega-Lite "tableau10" scheme: +# https://vega.github.io/vega/docs/schemes/#tableau10 +TABLEAU10: list[str] = [ + "#4c78a8", + "#f58518", + "#e45756", + "#72b7b2", + "#54a24b", + "#eeca3b", + "#b279a2", + "#ff9da6", + "#9d755d", + "#bab0ac", +] + + +def source_color_map(sources) -> dict[str, str]: + """Return ``{source: color}`` over unique sources, ordered alphabetically. + + Wraps around for more than ``len(TABLEAU10)`` sources, matching Vega-Lite's + behaviour for a categorical scale whose domain exceeds the scheme. + """ + ordered = sorted(set(sources)) + return {s: TABLEAU10[i % len(TABLEAU10)] for i, s in enumerate(ordered)} diff --git a/workflow/scripts/verification_plot_metrics.py b/workflow/scripts/verification_plot_metrics.py index b287241c..77e255bb 100644 --- a/workflow/scripts/verification_plot_metrics.py +++ b/workflow/scripts/verification_plot_metrics.py @@ -7,6 +7,7 @@ import matplotlib.pyplot as plt from plotting.metric_lead_time_panel import plot_panel +from plotting.source_colors import source_color_map from verification.loading import load_long_df, subset_df LOG = logging.getLogger(__name__) @@ -19,6 +20,7 @@ def main(args: Namespace) -> None: """Main function to verify results from KENDA-1 data.""" all_df = load_long_df(args.verif_files) + color_map = source_color_map(all_df["source"].unique()) metrics = all_df["metric"].unique() params = all_df["param"].unique() @@ -56,6 +58,7 @@ def main(args: Namespace) -> None: metric=metric, param=param, title=title, + color_map=color_map, ) args.output_dir.mkdir(parents=True, exist_ok=True) diff --git a/workflow/scripts/verification_plot_metrics_multipanel.py b/workflow/scripts/verification_plot_metrics_multipanel.py index e26876b4..4ef81bcb 100644 --- a/workflow/scripts/verification_plot_metrics_multipanel.py +++ b/workflow/scripts/verification_plot_metrics_multipanel.py @@ -15,9 +15,10 @@ import matplotlib.pyplot as plt from plotting.metric_lead_time_panel import plot_panel +from plotting.source_colors import source_color_map +from plotting.units import metric_units from verification import decode_metric from verification.loading import load_long_df, subset_df -from plotting.units import metric_units def _panel_label(idx: int) -> str: @@ -51,6 +52,7 @@ def main(args: Namespace) -> None: ) all_df = load_long_df(args.verif_files) + color_map = source_color_map(all_df["source"].unique()) figsize = tuple(spec.get("figsize") or (4.5 * cols, 3.5 * rows)) fig, axes = plt.subplots( @@ -99,6 +101,7 @@ def main(args: Namespace) -> None: xlabel="Lead Time [h]" if is_bottom else None, ylabel=ylabel, show_legend=False, + color_map=color_map, ) if panel.get("ylim"): ax.set_ylim(panel["ylim"]) From bc105abe81198dc3aa80065b816605c2f54e3031 Mon Sep 17 00:00:00 2001 From: Louis-Frey Date: Wed, 13 May 2026 11:48:03 +0200 Subject: [PATCH 7/9] Add unit tests for MultipanelPanelSpec / MultipanelPlotSpec Cover panel-level defaults (region/season/init_hour, title, ylim), the row*cols == len(panels) validator, the extra="forbid" guard on both models, and that ConfigModel.multipanel_plots defaults to an empty dict and round-trips a named layout. --- tests/unit/test_config.py | 67 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 66 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 40281a6e..f508094b 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -2,7 +2,7 @@ import pytest -from evalml.config import ConfigModel +from evalml.config import ConfigModel, MultipanelPanelSpec, MultipanelPlotSpec def test_example_forecasters_config(example_forecasters_config): @@ -91,3 +91,68 @@ def test_workflow_derives_baseline_id_from_root_stem(example_interpolators_confi "root": "/store_new/mch/msopr/ml/COSMO-E_hourly", "steps": "0/120/1", } + + +def _spec(rows, cols, panel_count=None): + n = panel_count if panel_count is not None else rows * cols + return { + "rows": rows, + "cols": cols, + "panels": [{"metric": "BIAS", "param": "T_2M"} for _ in range(n)], + } + + +def test_multipanel_panel_defaults(): + panel = MultipanelPanelSpec.model_validate({"metric": "BIAS", "param": "T_2M"}) + assert panel.region == "all" + assert panel.season == "all" + assert panel.init_hour == -999 + assert panel.title is None + assert panel.ylim is None + + +def test_multipanel_panel_forbids_extras(): + with pytest.raises(ValueError, match="Extra"): + MultipanelPanelSpec.model_validate( + {"metric": "BIAS", "param": "T_2M", "unknown": True} + ) + + +def test_multipanel_plot_accepts_matching_panel_count(): + spec = MultipanelPlotSpec.model_validate(_spec(2, 3)) + assert spec.rows == 2 + assert spec.cols == 3 + assert len(spec.panels) == 6 + + +def test_multipanel_plot_rejects_mismatched_panel_count(): + with pytest.raises(ValueError, match=r"rows\*cols"): + MultipanelPlotSpec.model_validate(_spec(2, 2, panel_count=3)) + + +def test_multipanel_plot_forbids_extras(): + bad = _spec(1, 1) + bad["unexpected"] = True + with pytest.raises(ValueError, match="Extra"): + MultipanelPlotSpec.model_validate(bad) + + +def test_multipanel_plot_rejects_zero_dim(): + with pytest.raises(ValueError): + MultipanelPlotSpec.model_validate(_spec(0, 1, panel_count=0)) + + +def test_configmodel_multipanel_plots_default(example_forecasters_config): + """`multipanel_plots` is optional and defaults to an empty dict.""" + cfg = ConfigModel.model_validate(example_forecasters_config) + assert cfg.multipanel_plots == {} + + +def test_configmodel_multipanel_plots_roundtrip(example_forecasters_config): + example_forecasters_config["multipanel_plots"] = { + "bias_overview": _spec(1, 2), + } + cfg = ConfigModel.model_validate(example_forecasters_config) + assert "bias_overview" in cfg.multipanel_plots + assert cfg.multipanel_plots["bias_overview"].rows == 1 + assert cfg.multipanel_plots["bias_overview"].cols == 2 From aaf66fdd2b019b41e84ddfd8c281c6050f4947ad Mon Sep 17 00:00:00 2001 From: Louis-Frey Date: Wed, 13 May 2026 11:52:23 +0200 Subject: [PATCH 8/9] Add example multipanel_plots config for MRB-860 Working config exercising the new multipanel_plots feature: stage_E_realch1 vs stage_E_icon_1km_cutoff_edges_subgrid_horography against ICON-CH1/CH2 baselines, with a BIAS-by-season and an RMSE-by-init-hour 2x2 layout. Serves as a copy-paste starting point for new layouts. --- config/multipanel_example.yaml | 117 +++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 config/multipanel_example.yaml diff --git a/config/multipanel_example.yaml b/config/multipanel_example.yaml new file mode 100644 index 00000000..08060ec3 --- /dev/null +++ b/config/multipanel_example.yaml @@ -0,0 +1,117 @@ +# yaml-language-server: $schema=../workflow/tools/config.schema.json +description: | + Evaluate skill of Stage E with/without cutoff edges trained with and without subgrid orography. + +dates: + start: 2025-01-01T06:00 + end: 2025-12-26T00:00 + frequency: 30h + +runs: + + - forecaster: + inference_resources: + slurm_partition: normal-shared + checkpoint: https://service.meteoswiss.ch/mlstore#/experiments/602/runs/fd63e17043014af59170c7beca516b95 + label: stage_E_realch1 + steps: 0/120/6 + config: resources/inference/configs/sgm-multidataset-forecaster-global-ich1-oper.yaml + extra_requirements: + - git+https://github.com/ecmwf/anemoi-inference.git@0.10.0 + + - forecaster: + inference_resources: + slurm_partition: normal-shared + checkpoint: https://service.meteoswiss.ch/mlstore#/experiments/602/runs/c30490b6ba064e4db03b430f3a2595ad + label: stage_E_icon_1km_cutoff_edges_subgrid_horography + steps: 0/120/6 + config: resources/inference/configs/sgm-multidataset-forecaster-global-ich1-oper.yaml + extra_requirements: + - git+https://github.com/ecmwf/anemoi-inference.git@b9aaee5df86614cad9d8d08b76876a4be4e980db + + # - forecaster: + # inference_resources: + # slurm_partition: normal-shared + # checkpoint: https://service.meteoswiss.ch/mlstore#/experiments/602/runs/57684b20f64f414b937cce10e5ceeb68 + # label: stage_E_realch1_new + # steps: 0/120/6 + # config: resources/inference/configs/sgm-multidataset-forecaster-global-ich1-oper.yaml + # extra_requirements: + # - git+https://github.com/ecmwf/anemoi-inference.git@b9aaee5df86614cad9d8d08b76876a4be4e980db + + # - forecaster: + # inference_resources: + # slurm_partition: normal-shared + # checkpoint: https://service.meteoswiss.ch/mlstore#/experiments/602/runs/2265ae18b04e4470ab89314a85a822ae + # label: stage_E_icon_1km_cutoff_edges_KNN_5_dec + # steps: 0/120/6 + # config: resources/inference/configs/sgm-multidataset-forecaster-global-ich1-oper.yaml + # extra_requirements: + # - git+https://github.com/ecmwf/anemoi-inference.git@b9aaee5df86614cad9d8d08b76876a4be4e980db + + +baselines: + - baseline: + baseline_id: ICON-CH1-EPS + label: ICON-CH1-ctrl + root: /scratch/mch/cmerker/ICON-CH1-EPS + steps: 0/33/6 + + - baseline: + baseline_id: ICON-CH2-EPS + label: ICON-CH2-ctrl + root: /scratch/mch/cmerker/ICON-CH2-EPS + steps: 0/120/6 + + +truth: + label: KENDA-CH1 + root: /store_new/mch/msopr/ml/datasets/mch-ich1-1km-2024-2025-1h-pl13-v1.0.zarr + +stratification: + regions: + - jura + root: /scratch/mch/bhendj/regions/Prognoseregionen_LV95_20220517 + +dashboard: + stratification: + - season + +locations: + output_root: ./output + +profile: + executor: slurm + global_resources: + gpus: 16 + default_resources: + slurm_partition: "postproc" + cpus_per_task: 1 + mem_mb_per_cpu: 1800 + runtime: "1h" + gpus: 0 + jobs: 50 + batch_rules: + plot_frame: 32 + +multipanel_plots: + bias_overview: + rows: 2 + cols: 2 + figsize: [12, 8] + title: "BIAS vs lead time" + panels: + - {metric: BIAS, param: T_2M, season: all, title: "T_2M — all"} + - {metric: BIAS, param: T_2M, season: JJA, title: "T_2M — JJA"} + - {metric: BIAS, param: PMSL, season: all, title: "PMSL — all"} + - {metric: BIAS, param: PMSL, season: JJA, title: "PMSL — JJA"} + rmse_overview: + rows: 2 + cols: 2 + figsize: [12, 8] + title: "RMSE vs lead time" + panels: + - {metric: RMSE, param: T_2M, init_hour: -999, title: "T_2M — 00 UTC"} + - {metric: RMSE, param: T_2M, init_hour: 12, title: "T_2M — 12 UTC"} + - {metric: RMSE, param: PMSL, init_hour: -999, title: "PMSL — 00 UTC"} + - {metric: RMSE, param: PMSL, init_hour: 12, title: "PMSL — 12 UTC"} From eb9f3d750fcac9650006cdebe5e90a35f5074930 Mon Sep 17 00:00:00 2001 From: Louis-Frey Date: Wed, 13 May 2026 13:24:48 +0200 Subject: [PATCH 9/9] Apply pre-commit fixes (trailing whitespace, ruff format) --- config/multipanel_example.yaml | 12 ++++++------ src/plotting/metric_lead_time_panel.py | 1 + src/verification/loading.py | 5 ++--- .../scripts/verification_plot_metrics_multipanel.py | 10 ++++++---- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/config/multipanel_example.yaml b/config/multipanel_example.yaml index 08060ec3..1b961786 100644 --- a/config/multipanel_example.yaml +++ b/config/multipanel_example.yaml @@ -10,7 +10,7 @@ dates: runs: - forecaster: - inference_resources: + inference_resources: slurm_partition: normal-shared checkpoint: https://service.meteoswiss.ch/mlstore#/experiments/602/runs/fd63e17043014af59170c7beca516b95 label: stage_E_realch1 @@ -20,7 +20,7 @@ runs: - git+https://github.com/ecmwf/anemoi-inference.git@0.10.0 - forecaster: - inference_resources: + inference_resources: slurm_partition: normal-shared checkpoint: https://service.meteoswiss.ch/mlstore#/experiments/602/runs/c30490b6ba064e4db03b430f3a2595ad label: stage_E_icon_1km_cutoff_edges_subgrid_horography @@ -30,7 +30,7 @@ runs: - git+https://github.com/ecmwf/anemoi-inference.git@b9aaee5df86614cad9d8d08b76876a4be4e980db # - forecaster: - # inference_resources: + # inference_resources: # slurm_partition: normal-shared # checkpoint: https://service.meteoswiss.ch/mlstore#/experiments/602/runs/57684b20f64f414b937cce10e5ceeb68 # label: stage_E_realch1_new @@ -40,7 +40,7 @@ runs: # - git+https://github.com/ecmwf/anemoi-inference.git@b9aaee5df86614cad9d8d08b76876a4be4e980db # - forecaster: - # inference_resources: + # inference_resources: # slurm_partition: normal-shared # checkpoint: https://service.meteoswiss.ch/mlstore#/experiments/602/runs/2265ae18b04e4470ab89314a85a822ae # label: stage_E_icon_1km_cutoff_edges_KNN_5_dec @@ -48,7 +48,7 @@ runs: # config: resources/inference/configs/sgm-multidataset-forecaster-global-ich1-oper.yaml # extra_requirements: # - git+https://github.com/ecmwf/anemoi-inference.git@b9aaee5df86614cad9d8d08b76876a4be4e980db - + baselines: - baseline: @@ -56,7 +56,7 @@ baselines: label: ICON-CH1-ctrl root: /scratch/mch/cmerker/ICON-CH1-EPS steps: 0/33/6 - + - baseline: baseline_id: ICON-CH2-EPS label: ICON-CH2-ctrl diff --git a/src/plotting/metric_lead_time_panel.py b/src/plotting/metric_lead_time_panel.py index 96137a8d..8749e6b1 100644 --- a/src/plotting/metric_lead_time_panel.py +++ b/src/plotting/metric_lead_time_panel.py @@ -1,4 +1,5 @@ """Per-axes plotting helper for verification metrics vs. lead time.""" + import pandas as pd from matplotlib.axes import Axes diff --git a/src/verification/loading.py b/src/verification/loading.py index 153e3ec3..8ef07780 100644 --- a/src/verification/loading.py +++ b/src/verification/loading.py @@ -1,4 +1,5 @@ """Helpers for loading aggregated verification netCDFs into long-form DataFrames.""" + from pathlib import Path import pandas as pd @@ -58,9 +59,7 @@ def load_long_df(verif_files: list[Path]) -> pd.DataFrame: ds = xr.concat(dfs, dim="source", join="outer") nonspatial_vars = [d for d in ds.data_vars if "spatial" not in d] - df = ( - ds[nonspatial_vars].to_array("stack").to_dataframe(name="value").reset_index() - ) + df = ds[nonspatial_vars].to_array("stack").to_dataframe(name="value").reset_index() df[["param", "metric"]] = df["stack"].str.split(".", n=1, expand=True) df.drop(columns=["stack"], inplace=True) df["lead_time"] = df["lead_time"].dt.total_seconds() / 3600 diff --git a/workflow/scripts/verification_plot_metrics_multipanel.py b/workflow/scripts/verification_plot_metrics_multipanel.py index 4ef81bcb..ba5c495e 100644 --- a/workflow/scripts/verification_plot_metrics_multipanel.py +++ b/workflow/scripts/verification_plot_metrics_multipanel.py @@ -5,6 +5,7 @@ (``--spec_path /path/to/spec.json``). The spec schema mirrors ``MultipanelPlotSpec`` in ``src/evalml/config.py``. """ + import json import logging import string @@ -29,6 +30,7 @@ def _panel_label(idx: int) -> str: a, b = divmod(idx, len(letters)) return f"{letters[a - 1]}{letters[b]})" + LOG = logging.getLogger(__name__) logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" @@ -55,9 +57,7 @@ def main(args: Namespace) -> None: color_map = source_color_map(all_df["source"].unique()) figsize = tuple(spec.get("figsize") or (4.5 * cols, 3.5 * rows)) - fig, axes = plt.subplots( - rows, cols, sharex=True, figsize=figsize, squeeze=False - ) + fig, axes = plt.subplots(rows, cols, sharex=True, figsize=figsize, squeeze=False) legend_entries: dict[str, object] = {} for idx, panel in enumerate(panels): @@ -76,7 +76,9 @@ def main(args: Namespace) -> None: if sub.empty: LOG.warning( "No data for panel %d (metric=%s, param=%s, region=%s, season=%s, init_hour=%s)", - idx, metric, param, + idx, + metric, + param, panel.get("region", "all"), panel.get("season", "all"), panel.get("init_hour", -999),