diff --git a/config/forecasters-ich1-oper-fixed.yaml b/config/forecasters-ich1-oper-fixed.yaml index 167c92d4..6a822367 100644 --- a/config/forecasters-ich1-oper-fixed.yaml +++ b/config/forecasters-ich1-oper-fixed.yaml @@ -70,6 +70,8 @@ experiment: # - init_hour # - region - season + scoremaps: + enabled: false locations: output_root: output/ diff --git a/config/forecasters-ich1-oper.yaml b/config/forecasters-ich1-oper.yaml index bc726218..65d96ae2 100644 --- a/config/forecasters-ich1-oper.yaml +++ b/config/forecasters-ich1-oper.yaml @@ -67,6 +67,8 @@ experiment: # - init_hour # - region - season + scoremaps: + enabled: false locations: output_root: output/ diff --git a/config/forecasters-ich1.yaml b/config/forecasters-ich1.yaml index f290dd79..3ac7184a 100644 --- a/config/forecasters-ich1.yaml +++ b/config/forecasters-ich1.yaml @@ -79,6 +79,8 @@ experiment: # - init_hour # - region - season + scoremaps: + enabled: false locations: output_root: output/ diff --git a/config/varda-single-1.0.yaml b/config/varda-single-1.0.yaml index b9515033..c97e8d25 100644 --- a/config/varda-single-1.0.yaml +++ b/config/varda-single-1.0.yaml @@ -99,6 +99,8 @@ experiment: - "V_10M:RMSE,R2,ETS" - "T_2M:RMSE,R2,ETS" - "TOT_PREC:RMSE,R2,ETS" + scoremaps: + enabled: false showcase: params: diff --git a/src/data_input/__init__.py b/src/data_input/__init__.py index 6c14005a..16a4cf2b 100644 --- a/src/data_input/__init__.py +++ b/src/data_input/__init__.py @@ -180,6 +180,12 @@ def _discover_icon_member_ids( def load_from_grib_file(file: str | list[str], sel_kwargs): + # Coerce Path objects to str: earthkit-data unwraps a single-element list + # into one File source without converting, and then fails on non-str paths. 
+ if isinstance(file, (list, tuple)): + file = [str(f) for f in file] + else: + file = str(file) fieldlist = ekd.from_source("file", file, lazily=True).to_fieldlist() return fieldlist_to_xarray(fieldlist.sel(**sel_kwargs)) @@ -212,13 +218,40 @@ def fieldlist_to_xarray(fieldlist) -> xr.Dataset: return ds -def _tot_prec_handling(tp: xr.DataArray) -> xr.DataArray: +def _tot_prec_handling( + tp: xr.DataArray, requested_steps: list[int] | None = None +) -> xr.DataArray: _full_step_coord = tp["step"] # step coordinate before .diff() # anemoi-inference sometimes omits step 0 from the GRIB even with - # accumulate_from_start_of_forecast enabled. If missing, earthkit-data - # will fill it with NaNs following the `allow_holes=True` flag. - if tp[{"step": 0}].isnull().all(): + # accumulate_from_start_of_forecast enabled: the field may be absent from + # the step coordinate entirely, or present but NaN-filled by earthkit-data + # (allow_holes=True). With cumulative-from-start data the accumulation at + # the initial condition is identically zero, so synthesise it — but only + # when step 0 was actually requested (`requested_steps`); for window loads + # like [18, 24] the first step is real data and must not be treated as an + # initial condition. + if requested_steps is not None: + if 0 in requested_steps: + step0_idx = np.where(tp["step"].values == np.timedelta64(0, "ns"))[0] + if step0_idx.size == 0: + LOG.warning( + "Step 0 of TOT_PREC is missing from the GRIB, prepending " + "zeroes assuming accumulate_from_start_of_forecast is " + "enabled." + ) + zero = xr.zeros_like(tp.isel(step=[0])) + zero = zero.assign_coords(step=[np.timedelta64(0, "ns")]) + tp = xr.concat([zero, tp], dim="step") + elif tp[{"step": int(step0_idx[0])}].isnull().all(): + LOG.warning( + "Step 0 of TOT_PREC is all-NaN, filling with zeroes " + "assuming accumulate_from_start_of_forecast is enabled." 
+ ) + tp[{"step": int(step0_idx[0])}] = 0.0 + elif tp[{"step": 0}].isnull().all(): + # Legacy path for callers that do not pass the requested steps: treat + # the first loaded step positionally as the initial condition. LOG.warning( "Step 0 of TOT_PREC is missing, filling with zeroes " "assuming accumulate_from_start_of_forecast is enabled." @@ -228,6 +261,12 @@ def _tot_prec_handling(tp: xr.DataArray) -> xr.DataArray: # Disaggregate TOT_PREC from cumulative-from-start (expected when the # accumulate_from_start_of_forecast post-processor is enabled in # anemoi-inference) to per-step accumulations. + if tp.sizes["step"] < 2: + raise ValueError( + "Cannot de-accumulate TOT_PREC: only a single step was loaded and " + "step 0 was not requested/synthesised, so no accumulation window " + "can be formed. Request the preceding step as well." + ) LOG.info( "Disaggregating TOT_PREC from cumulative-from-start to per-step accumulations." ) @@ -254,12 +293,27 @@ def _tot_prec_handling(tp: xr.DataArray) -> xr.DataArray: return tp -def load_forecast_data_from_grib(files: list[Path], params: list[str]) -> xr.Dataset: - """Load forecast data from a list of GRIB files.""" +def load_forecast_data_from_grib( + files: list[Path], params: list[str], steps: list[int] | None = None +) -> xr.Dataset: + """Load forecast data from a list of GRIB files (internal helper). + + External callers should use :func:`load_forecast_data`, which derives + `files` from `steps` and routes by source. This helper is the shared + low-level loader for the ML-grib and ICON-archive paths. + + `files` and `steps` are complementary, not redundant: + - `files` are the GRIB files that exist on disk (one per lead time). + - `steps` are the *requested* lead times, forwarded to the TOT_PREC + de-accumulation. 
They cannot be inferred from `files` alone: when step 0 + is requested, anemoi-inference omits the TOT_PREC step-0 field entirely + (no file exists), so it is synthesised as zero to form the first + accumulation window. `steps` carries that intent. + """ ds = load_from_grib_file(files, {"parameter.variable": params}) if "TOT_PREC" in ds.data_vars: - ds["TOT_PREC"] = _tot_prec_handling(ds["TOT_PREC"]) + ds["TOT_PREC"] = _tot_prec_handling(ds["TOT_PREC"], requested_steps=steps) return ds @@ -713,6 +767,7 @@ def load_icon_baseline_from_grib( root, reftime, steps, member_id=mid ), params=params, + steps=steps, ) if "number" in ds.dims: ds = ds.isel(number=0, drop=True) @@ -730,6 +785,7 @@ def load_icon_baseline_from_grib( return load_forecast_data_from_grib( files=_collect_icon_archive_files(root, reftime, steps, member_id=member), params=params, + steps=steps, ) @@ -751,6 +807,7 @@ def load_forecast_data( # NOTE: root is already for a specific reftime files=_collect_ml_grib_files(root, steps), params=params, + steps=steps, ) if "INCA" in root.parts: LOG.info("Loading INCA baseline from NetCDF files...") diff --git a/src/evalml/cli.py b/src/evalml/cli.py index 51a9ed45..f3df3bf4 100644 --- a/src/evalml/cli.py +++ b/src/evalml/cli.py @@ -146,7 +146,7 @@ def execute_workflow( if report and not dry_run: command += ["--report-after-run", "--report", str(report)] - command.append(target) + command += [target] command += list(extra_smk_args) if not verbose: command += ["--quiet", "rules"] # reduce verobosity of snakemake output @@ -165,7 +165,15 @@ def cli(): ) @workflow_options def experiment( - configfile, cores, verbose, dry_run, unlock, report, dag, rulegraph, extra_smk_args + configfile, + cores, + verbose, + dry_run, + unlock, + report, + dag, + rulegraph, + extra_smk_args, ): execute_workflow( configfile, diff --git a/src/evalml/config.py b/src/evalml/config.py index bbe94543..ddfd0de7 100644 --- a/src/evalml/config.py +++ b/src/evalml/config.py @@ -1,7 +1,7 @@ from 
pathlib import Path from typing import Dict, List, Any, ClassVar, FrozenSet, Optional -from pydantic import BaseModel, Field, RootModel, field_validator +from pydantic import BaseModel, Field, RootModel, field_validator, model_validator PROJECT_ROOT = Path(__file__).parents[2] @@ -219,6 +219,47 @@ class BaselineItem(BaseModel): baseline: BaselineConfig +class ScoreMapsConfig(BaseModel): + """Parameters controlling which score map plots are produced.""" + + enabled: bool = Field( + default=False, + description="Whether to produce score maps (computationally intensive).", + ) + params: List[str] = Field( + default=["T_2M"], + description=( + "List of parameters to plot. Supported values: T_2M, TD_2M, U_10M, V_10M, " + "PS, PMSL, TOT_PREC (native), and SP_10M (derived wind speed from U_10M/V_10M)." + ), + ) + leadtimes: List[int] = Field( + default=[6, 24], + description="List of lead times (hours) to plot.", + ) + scores: List[str] = Field( + default=["BIAS"], + description="List of verification scores to plot. Supported: BIAS, RMSE, MAE.", + ) + regions: List[str] = Field( + default=["switzerland"], + description="List of regions to plot (e.g. switzerland, centraleurope).", + ) + seasons: List[str] = Field( + default=["all"], + description="List of seasons to plot ('all', 'DJF', 'MAM', 'JJA', 'SON').", + ) + init_hours: List[str] = Field( + default=["all"], + description=( + "List of initialization hours to plot. Use 'all' for the unstratified " + "view, or zero-padded hour strings like '00', '06', '12', '18'." + ), + ) + + model_config = {"extra": "forbid"} + + class DomainConfig(BaseModel): """A custom map domain defined by name, extent, and projection.""" @@ -380,6 +421,10 @@ class ExperimentConfig(BaseModel): default=None, description="Scorecard generation configuration. Omit or set enabled: false to disable.", ) + scoremaps: Optional[ScoreMapsConfig] = Field( + default=None, + description="Score map plot configuration. 
Omit or set enabled: false to disable.", + ) @field_validator("thresholds") @classmethod @@ -502,6 +547,24 @@ class ConfigModel(BaseModel): description="Settings for the showcase workflow.", ) + @model_validator(mode="after") + def validate_scoremap_leadtimes(self) -> "ConfigModel": + sm = self.experiment.scoremaps + if sm is None or not sm.enabled: + return self + requested = set(sm.leadtimes) + for item in self.runs: + steps = getattr(item, next(iter(item.model_fields))).steps + start, end, step = map(int, steps.split("/")) + producible = set(range(start, end + 1, step)) + unsupported = requested - producible + if unsupported: + raise ValueError( + f"scoremaps.leadtimes contains {sorted(unsupported)} h which are not " + f"produced by participant with steps '{steps}'." + ) + return self + model_config = { "extra": "forbid", # fail on misspelled keys "populate_by_name": True, diff --git a/src/plotting/__init__.py b/src/plotting/__init__.py index 163bf74b..524834c6 100644 --- a/src/plotting/__init__.py +++ b/src/plotting/__init__.py @@ -39,7 +39,7 @@ def get_projection(name: str) -> "ccrs.Projection": "projection": _PROJECTIONS["orthographic"], }, "centraleurope": { - "extent": [-2.6, 19.5, 40.2, 52.3], + "extent": [-1.5, 18, 41.5, 51], "projection": _PROJECTIONS["orthographic"], }, "icon-ch": { diff --git a/src/plotting/colormap_defaults.py b/src/plotting/colormap_defaults.py index 88c065e6..1f1a4013 100644 --- a/src/plotting/colormap_defaults.py +++ b/src/plotting/colormap_defaults.py @@ -4,6 +4,7 @@ from matplotlib import pyplot as plt import warnings from .colormap_loader import load_ncl_colormap +import numpy as np def _fallback(): @@ -11,6 +12,19 @@ def _fallback(): return {"cmap": plt.get_cmap("viridis"), "norm": None, "units": ""} +# Sequential Reds shared by the error-magnitude scores (RMSE, MAE, STDE), which +# are all non-negative. 
Defined once per parameter under the generic +# "{param}.score.map" key; the score-map lookup falls back to it for any score +# without a dedicated "{param}.{score}.map" entry. To give a single score its +# own colours, add that explicit key — it takes precedence over this fallback. +_SCORE_REDS = {"cmap": plt.get_cmap("Reds", 6), "levels": [0, 0.5, 1, 1.5, 2, 2.5, 3]} +_SCORE_REDS_PA = { + "cmap": plt.get_cmap("Reds", 7), + "levels": [0, 50, 100, 150, 200, 250, 300, 350], +} +_SCORE_REDS_PRECIP = {"cmap": plt.get_cmap("Reds", 6), "levels": [0, 1, 1.5, 2, 3, 4]} + + _CMAP_DEFAULTS = { "SP": { "cmap": plt.get_cmap("coolwarm", 11), @@ -129,6 +143,63 @@ def _fallback(): 120.0, ], }, + # Sequential Reds for error-magnitude scores (RMSE, MAE, STDE): error is + # non-negative, larger ⇒ darker. Levels start at 0 so saturation maps + # directly to error magnitude; discrete levels make absolute values readable + # from the colour bar. Defined once per parameter under "{param}.score.map"; + # the lookup uses these for any score lacking a dedicated entry. The precip + # levels are a bit on the bright side, but kept consistent with the rest. + "U_10M.score.map": _SCORE_REDS | {"units": "m/s"}, + "V_10M.score.map": _SCORE_REDS | {"units": "m/s"}, + "SP_10M.score.map": _SCORE_REDS | {"units": "m/s"}, + "TD_2M.score.map": _SCORE_REDS | {"units": "°C"}, + "T_2M.score.map": _SCORE_REDS | {"units": "°C"}, + "PMSL.score.map": _SCORE_REDS_PA | {"units": "Pa"}, + "PS.score.map": _SCORE_REDS_PA | {"units": "Pa"}, + "TOT_PREC.score.map": _SCORE_REDS_PRECIP | {"units": "mm"}, + # Bias: + # diverging colour scheme for the Bias to reflect the nature of the data (can be positive or negative, symmetric). + # Red-Blue colour scheme for all variables except precipitation, where a Brown-Green scheme is more suggestive. 
+ "U_10M.BIAS.map": { + "cmap": plt.get_cmap("RdBu_r", 9), + "levels": np.arange(start=-2.25, stop=2.26, step=0.5), + } + | {"units": "m/s"}, + "V_10M.BIAS.map": { + "cmap": plt.get_cmap("RdBu_r", 9), + "levels": np.arange(start=-2.25, stop=2.26, step=0.5), + } + | {"units": "m/s"}, + "SP_10M.BIAS.map": { + "cmap": plt.get_cmap("RdBu_r", 9), + "levels": np.arange(start=-2.25, stop=2.26, step=0.5), + } + | {"units": "m/s"}, + "TD_2M.BIAS.map": { + "cmap": plt.get_cmap("RdBu_r", 11), + "levels": np.arange(start=-2.75, stop=2.76, step=0.5), + } + | {"units": "°C"}, + "T_2M.BIAS.map": { + "cmap": plt.get_cmap("RdBu_r", 11), + "levels": np.arange(start=-2.75, stop=2.76, step=0.5), + } + | {"units": "°C"}, + "PMSL.BIAS.map": { + "cmap": plt.get_cmap("RdBu_r", 11), + "levels": np.arange(start=-110, stop=111, step=20), + } + | {"units": "Pa"}, + "PS.BIAS.map": { + "cmap": plt.get_cmap("RdBu_r", 11), + "levels": np.arange(start=-110, stop=111, step=20), + } + | {"units": "Pa"}, + "TOT_PREC.BIAS.map": { + "cmap": plt.get_cmap("BrBG", 9), + "levels": [-1, -0.5, -0.25, -0.1, 0.1, 0.25, 0.5, 1], + } + | {"units": "mm"}, } CMAP_DEFAULTS = defaultdict(_fallback, _CMAP_DEFAULTS) diff --git a/tests/unit/test_data_input.py b/tests/unit/test_data_input.py new file mode 100644 index 00000000..912bad68 --- /dev/null +++ b/tests/unit/test_data_input.py @@ -0,0 +1,55 @@ +"""Unit tests for data_input TOT_PREC de-accumulation.""" + +import numpy as np +import pytest +import xarray as xr + +from data_input import _tot_prec_handling + + +def _cumulative_tp(steps_h, values): + """Build a cumulative-from-start TOT_PREC DataArray over `steps_h` (hours).""" + step = np.array([np.timedelta64(h, "h") for h in steps_h]).astype("timedelta64[ns]") + data = np.array(values, dtype=np.float64)[:, np.newaxis] * np.ones((1, 4)) + return xr.DataArray( + data, dims=("step", "values"), coords={"step": step}, name="TOT_PREC" + ) + + +def test_tot_prec_missing_step0_synthesised_when_requested(): + """Step 0 
requested but absent from the GRIB -> zero IC is synthesised.""" + tp = _cumulative_tp([6], [3.0]) + out = _tot_prec_handling(tp, requested_steps=[0, 6]) + np.testing.assert_allclose(out.sel(step=np.timedelta64(6, "h")).values, 3.0) + + +def test_tot_prec_full_range_missing_step0(): + """Full-range load with missing step 0 keeps the first lead time.""" + tp = _cumulative_tp([6, 12, 18], [3.0, 5.0, 5.5]) + out = _tot_prec_handling(tp, requested_steps=[0, 6, 12, 18]) + np.testing.assert_allclose(out.sel(step=np.timedelta64(6, "h")).values, 3.0) + np.testing.assert_allclose(out.sel(step=np.timedelta64(12, "h")).values, 2.0) + np.testing.assert_allclose(out.sel(step=np.timedelta64(18, "h")).values, 0.5) + + +def test_tot_prec_window_without_step0_untouched(): + """A [18, 24] window must not be treated as starting at an IC.""" + tp = _cumulative_tp([18, 24], [5.0, 7.0]) + out = _tot_prec_handling(tp, requested_steps=[18, 24]) + np.testing.assert_allclose(out.sel(step=np.timedelta64(24, "h")).values, 2.0) + # First window step has no preceding accumulation -> NaN after reindex. 
+ assert np.isnan(out.sel(step=np.timedelta64(18, "h")).values).all() + + +def test_tot_prec_step0_present_but_nan_is_zero_filled(): + """Step 0 present as an all-NaN hole (earthkit allow_holes) -> zero-filled.""" + tp = _cumulative_tp([0, 6], [np.nan, 3.0]) + out = _tot_prec_handling(tp, requested_steps=[0, 6]) + np.testing.assert_allclose(out.sel(step=np.timedelta64(6, "h")).values, 3.0) + + +def test_tot_prec_single_step_without_step0_raises(): + """A single loaded step with step 0 not requested cannot be de-accumulated.""" + tp = _cumulative_tp([6], [3.0]) + with pytest.raises(ValueError, match="Cannot de-accumulate TOT_PREC"): + _tot_prec_handling(tp, requested_steps=[6]) diff --git a/workflow/Snakefile b/workflow/Snakefile index 1e9ea971..550f88ce 100644 --- a/workflow/Snakefile +++ b/workflow/Snakefile @@ -130,6 +130,9 @@ onerror: # ----------------------------------------------------- +SCOREMAPS_CONFIGS = config.get("experiment", {}).get("scoremaps") or {} + + rule experiment_all: """Target rule for experiment workflow.""" input: @@ -158,6 +161,36 @@ rule experiment_all: ) else [] ), + ( + expand( + rules.plot_scoremaps.output, + run_id=list(CANDIDATES), + param=SCOREMAPS_CONFIGS["params"], + leadtime=SCOREMAPS_CONFIGS["leadtimes"], + score=SCOREMAPS_CONFIGS["scores"], + region=SCOREMAPS_CONFIGS["regions"], + season=SCOREMAPS_CONFIGS["seasons"], + init_hour=SCOREMAPS_CONFIGS["init_hours"], + experiment=EXPERIMENT_NAME, + ) + if SCOREMAPS_CONFIGS.get("enabled", False) + else [] + ), + ( + expand( + rules.plot_scoremaps_baseline.output, + baseline_id=list(BASELINES), + param=SCOREMAPS_CONFIGS["params"], + leadtime=SCOREMAPS_CONFIGS["leadtimes"], + score=SCOREMAPS_CONFIGS["scores"], + region=SCOREMAPS_CONFIGS["regions"], + season=SCOREMAPS_CONFIGS["seasons"], + init_hour=SCOREMAPS_CONFIGS["init_hours"], + experiment=EXPERIMENT_NAME, + ) + if SCOREMAPS_CONFIGS.get("enabled", False) + else [] + ), rule showcase_all: diff --git a/workflow/rules/common.smk 
b/workflow/rules/common.smk index c75da889..c8c05787 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -371,3 +371,39 @@ _scorecard = config.get("experiment", {}).get("scorecards") or {} SCORECARD_CONFIGS = ( _scorecard.get("sections", {}) if _scorecard.get("enabled", True) else {} ) + + +# Period-accumulated params verify a [lead - period, lead] window, so they have +# no value at lead times shorter than one step spacing (e.g. no 0h precip map). +# Short and canonical names both appear across the workflow (showcases vs maps). +ACCUMULATED_PARAMS = {"TOT_PREC", "tp"} + + +def resolve_leadtimes(steps_spec, requested="all", param=None): + """Lead times to compute for a single participant. + + A run or baseline produces only the lead times in its own ``steps`` spec + (``start/stop/step``, hours). This returns those of the ``requested`` + selection that the participant actually produces — the literal ``"all"`` + (every produced lead time) or an explicit list of ints — so a 36h lead is + never requested of an ICON-CH1 baseline (steps ``0/33/6``), nor a >120h + lead of ICON-CH2. Explicitly requested lead times the participant cannot + produce are skipped with a warning. For accumulated ``param``s, lead times + shorter than one step spacing are dropped (no accumulation window). 
+ """ + start, end, step = map(int, steps_spec.split("/")) + supported = set(range(start, end + 1, step)) + wanted = supported if requested == "all" else set(requested) + + unsupported = sorted(wanted - supported) + if unsupported: + logging.getLogger("snakemake").warning( + "Skipping lead time(s) %sh: not produced by forecast steps '%s'.", + unsupported, + steps_spec, + ) + + valid = wanted & supported + if param in ACCUMULATED_PARAMS: + valid = {lt for lt in valid if lt >= step} + return sorted(valid) diff --git a/workflow/rules/plot.smk b/workflow/rules/plot.smk index 91ff1080..eb5bde22 100644 --- a/workflow/rules/plot.smk +++ b/workflow/rules/plot.smk @@ -132,12 +132,9 @@ rule plot_forecast_frame: def get_leadtimes(wc): - """Get all lead times from the run config.""" - start, end, step = map(int, RUN_CONFIGS[wc.run_id]["steps"].split("/")) - # skip lead time 0 for diagnostic variables - if wc.param in ["tp", "TOT_PREC"] and start == 0: - start += step - return [f"{i}" for i in range(start, end + 1, step)] + """Get all lead times the run produces (accumulated params skip lead 0).""" + leadtimes = resolve_leadtimes(RUN_CONFIGS[wc.run_id]["steps"], param=wc.param) + return [str(lt) for lt in leadtimes] rule make_forecast_animation: @@ -164,3 +161,50 @@ rule make_forecast_animation: """ convert -delay {params.delay} -loop 0 {input} {output} """ + + +rule plot_scoremaps: + # localrule: True + input: + script="workflow/scripts/plot_scoremaps.mo.py", + verif_file=OUT_ROOT + / f"data/runs/{{run_id}}/scoremaps/{{param}}_{{leadtime}}_{TRUTH_HASH}.nc", + output: + OUT_ROOT + / "results/{experiment}/scoremaps/runs/{run_id}/{param}_{score}_{region}_{season}_{init_hour}_{leadtime}.png", + log: + OUT_ROOT + / "logs/plot_scoremaps/{experiment}/{run_id}-{param}-{score}-{region}-{season}-{init_hour}-{leadtime}.log", + wildcard_constraints: + leadtime=r"\d+", # only digits + init_hour=r"all|\d{1,2}", + resources: + slurm_partition="postproc", + cpus_per_task=1, + runtime="10m", + 
shell: + """ + export ECCODES_DEFINITION_PATH=$(realpath .venv/share/eccodes-cosmo-resources/definitions) + uv run python {input.script} \ + --input {input.verif_file} --outfn {output[0]} --region {wildcards.region} \ + --param {wildcards.param} --leadtime {wildcards.leadtime} --score {wildcards.score} \ + --season {wildcards.season} --init_hour {wildcards.init_hour} >{log} 2>&1 + # interactive editing (needs to set localrule: True and use only one core) + # marimo edit {input.script} -- \ + # --input {input.verif_file} --outfn {output[0]} --region {wildcards.region} \ + # --param {wildcards.param} --leadtime {wildcards.leadtime} --score {wildcards.score} \ + # --season {wildcards.season} --init_hour {wildcards.init_hour} + """ + + +use rule plot_scoremaps as plot_scoremaps_baseline with: + input: + script="workflow/scripts/plot_scoremaps.mo.py", + verif_file=OUT_ROOT + / f"data/baselines/{{baseline_id}}/scoremaps/{{param}}_{{leadtime}}_{TRUTH_HASH}.nc", + output: + OUT_ROOT + / "results/{experiment}/scoremaps/baselines/{baseline_id}/{param}_{score}_{region}_{season}_{init_hour}_{leadtime}.png", + log: + OUT_ROOT + / "logs/plot_scoremaps/{experiment}/{baseline_id}-{param}-{score}-{region}-{season}-{init_hour}-{leadtime}.log", diff --git a/workflow/rules/verification.smk b/workflow/rules/verification.smk index 4bae88e2..876afe63 100644 --- a/workflow/rules/verification.smk +++ b/workflow/rules/verification.smk @@ -169,3 +169,84 @@ rule verification_metrics_plot: """ uv run {input.script} {input.verif} --output_dir {output} >{log} 2>&1 """ + + +rule verification_scoremaps: + input: + "src/verification/__init__.py", + "src/data_input/__init__.py", + script="workflow/scripts/verification_scoremaps.py", + inference_okfiles=lambda wc: expand( + rules.inference_execute.output.okfile, + init_time=_restrict_reftimes_to_hours(REFTIMES), + allow_missing=True, + ), + truth=config["truth"]["root"], + output: + OUT_ROOT + / 
f"data/runs/{{run_id}}/scoremaps/{{param}}_{{leadtime}}_{TRUTH_HASH}.nc", + log: + OUT_ROOT + / f"logs/verification_scoremaps/{{run_id}}-{TRUTH_HASH}-{{param}}-{{leadtime}}.log", + resources: + cpus_per_task=2, + mem_mb=50_000, + runtime="60m", + # wildcard_constraints: + # run_id="^" # to avoid ambiguity with run_baseline_verif + # TODO: implement logic to use experiment name instead of run_id as wildcard + params: + fcst_label=lambda wc: RUN_CONFIGS[wc.run_id].get("label"), + fcst_steps=lambda wc: RUN_CONFIGS[wc.run_id]["steps"], + truth_label=config["truth"]["label"], + reftimes=" ".join(t.strftime("%Y%m%d%H%M") for t in REFTIMES), + run_root=lambda wc: (Path(OUT_ROOT) / f"data/runs/{wc.run_id}").resolve(), + shell: + """ + export ECCODES_DEFINITION_PATH=$(realpath .venv/share/eccodes-cosmo-resources/definitions) + uv run {input.script} \ + --run_root {params.run_root} \ + --reftimes {params.reftimes} \ + --truth {input.truth} \ + --step {wildcards.leadtime} \ + --steps "{params.fcst_steps}" \ + --param {wildcards.param} \ + --output {output} >{log} 2>&1 + """ + + +rule verification_scoremaps_baseline: + input: + "src/verification/__init__.py", + "src/data_input/__init__.py", + script="workflow/scripts/verification_scoremaps.py", + forecast=lambda wc: BASELINE_CONFIGS[wc.baseline_id]["root"], + truth=config["truth"]["root"], + eckit_grids=rules.data_download_eckit_geo_grids.output, + output: + OUT_ROOT + / f"data/baselines/{{baseline_id}}/scoremaps/{{param}}_{{leadtime}}_{TRUTH_HASH}.nc", + log: + OUT_ROOT + / f"logs/verification_scoremaps_baseline/{{baseline_id}}-{TRUTH_HASH}-{{param}}-{{leadtime}}.log", + resources: + cpus_per_task=24, + mem_mb=50_000, + runtime="60m", + params: + baseline_steps=lambda wc: BASELINE_CONFIGS[wc.baseline_id]["steps"], + member=lambda wc: BASELINE_CONFIGS[wc.baseline_id].get("member", "000"), + reftimes=" ".join(t.strftime("%Y%m%d%H%M") for t in REFTIMES), + shell: + """ + export ECCODES_DEFINITION_PATH=$(realpath 
.venv/share/eccodes-cosmo-resources/definitions) + uv run {input.script} \ + --baseline_root {input.forecast} \ + --reftimes {params.reftimes} \ + --truth {input.truth} \ + --step {wildcards.leadtime} \ + --steps "{params.baseline_steps}" \ + --param {wildcards.param} \ + --member "{params.member}" \ + --output {output} >{log} 2>&1 + """ diff --git a/workflow/scripts/plot_scoremaps.mo.py b/workflow/scripts/plot_scoremaps.mo.py new file mode 100644 index 00000000..5a19f9f3 --- /dev/null +++ b/workflow/scripts/plot_scoremaps.mo.py @@ -0,0 +1,231 @@ +import marimo + +__generated_with = "0.19.4" +app = marimo.App(width="medium") + + +@app.cell +def _(): + import logging + from argparse import ArgumentParser + from pathlib import Path + + import earthkit.plots as ekp + import numpy as np + import xarray as xr + + from plotting import DOMAINS + from plotting import StatePlotter + from plotting.colormap_defaults import CMAP_DEFAULTS + + return ( + ArgumentParser, + CMAP_DEFAULTS, + DOMAINS, + Path, + StatePlotter, + ekp, + logging, + np, + xr, + ) + + +@app.cell +def _(logging): + LOG = logging.getLogger(__name__) + LOG_FMT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + logging.basicConfig(level=logging.INFO, format=LOG_FMT) + return (LOG,) + + +@app.cell +def _(ArgumentParser, Path, np): + parser = ArgumentParser() + + parser.add_argument( + "--input", + type=str, + default=None, + help="Directory to .nc data containing the error fields", + ) + parser.add_argument("--outfn", type=str, help="output filename") + parser.add_argument("--leadtime", type=str, help="leadtime") + parser.add_argument("--param", type=str, help="parameter") + parser.add_argument("--region", type=str, help="name of region") + parser.add_argument( + "--score", + type=str, + help="Evaluation Score. 
So far Bias, RMSE, MAE or STDE are implemented.", + ) + parser.add_argument("--season", type=str, default="all", help="season filter") + parser.add_argument( + "--init_hour", type=str, default="all", help="initialization hour filter" + ) + + args = parser.parse_args() + verif_file = Path(args.input) + outfn = Path(args.outfn) + lead_time = args.leadtime + param = args.param + region = args.region + season = args.season + init_hour = args.init_hour + score = args.score + + if isinstance(init_hour, str): + if init_hour == "all": + init_hour = -999 + else: + try: + init_hour = int(init_hour) + except ValueError as exc: + raise ValueError("init_hour must be 'all' or an integer hour") from exc + + lead_time = np.timedelta64(lead_time, "h") + return ( + init_hour, + lead_time, + outfn, + param, + region, + score, + season, + verif_file, + ) + + +@app.cell +def _(LOG, init_hour, param, score, season, verif_file, xr): + ds = xr.open_dataset(verif_file) + LOG.info("Opened dataset: %s", ds) + var = f"{param}.{score}" + LOG.info( + "Selecting variable '%s' for season '%s', init_hour=%s", var, season, init_hour + ) + ds = ds[var].sel(season=season, init_hour=init_hour) + LOG.info( + "Selected DataArray: dims=%s, shape=%s, dtype=%s", ds.dims, ds.shape, ds.dtype + ) + LOG.info( + "Value range: min=%.4g, max=%.4g, n_nan=%d", + float(ds.min()), + float(ds.max()), + int(ds.isnull().sum()), + ) + return (ds,) + + +@app.cell +def _(CMAP_DEFAULTS, ekp): + def get_style(param, score, units_override=None): + """Get style and colormap settings for the plot. + + earthkit-plots >= 1.0 expects ``Style.colors`` to be a list of + colours; Matplotlib ``Colormap`` objects from CMAP_DEFAULTS are + sampled into one colour per level interval. + """ + from matplotlib import colors as mcolors + + # Prefer a score-specific colormap; otherwise fall back to the generic + # per-parameter score colormap (shared by RMSE/MAE/STDE), and finally to + # the parameter's field colormap. 
+ score_key = f"{param}.{score}.map" + generic_score_key = f"{param}.score.map" + if score_key in CMAP_DEFAULTS: + cfg = CMAP_DEFAULTS[score_key] + elif generic_score_key in CMAP_DEFAULTS: + cfg = CMAP_DEFAULTS[generic_score_key] + else: + cfg = CMAP_DEFAULTS.get(param, {}) + units = units_override if units_override is not None else cfg.get("units", "") + levels = cfg.get("bounds", cfg.get("levels", None)) + colors = cfg.get("colors", None) + cmap = cfg.get("cmap", None) + if colors is None and cmap is not None: + n = len(levels) - 1 if levels is not None else getattr(cmap, "N", 256) + colors = [mcolors.to_hex(cmap(i / max(n - 1, 1))) for i in range(n)] + return { + "style": ekp.styles.Style( + levels=levels, + extend="both", + units=units, + colors=colors, + ), + } + + return (get_style,) + + +@app.cell +def _( + DOMAINS, + LOG, + StatePlotter, + ds, + get_style, + init_hour, + lead_time, + np, + outfn, + param, + region, + score, + season, +): + # plot individual fields + + plotter = StatePlotter( + ds["longitude"].values.ravel(), + ds["latitude"].values.ravel(), + outfn.parent, + ) + fig = plotter.init_geoaxes( + nrows=1, + ncols=1, + projection=DOMAINS[region]["projection"], + bbox=DOMAINS[region]["extent"], + name=region, + size=(6, 6), + ) + subplot = fig.add_map(row=0, column=0) + + plot_vals = ds.values.ravel() + + style_kwargs = get_style(param, score) + LOG.info("style_kwargs: %s", style_kwargs) + + if np.all(np.isnan(plot_vals)): + LOG.warning( + "All values are NaN for %s %s season=%s — plotting empty map.", + param, + score, + season, + ) + import matplotlib.patches as mpatches + + subplot.ax.set_facecolor("#cccccc") + subplot.standard_layers() + grey_patch = mpatches.Patch(color="#cccccc", label="No data") + subplot.ax.legend(handles=[grey_patch], loc="lower left", fontsize=8) + else: + plotter.plot_field(subplot, plot_vals, **style_kwargs) + + # black coast lines and country borders for better visibility + # grey is hardly visible, especially when 
the shading colours are intense. + subplot.coastlines(edgecolor="black", linewidth=1.0, zorder=5) + subplot.borders(edgecolor="black", linewidth=0.5, zorder=5) + + init_hour_lbl = "all" if init_hour == -999 else f"{init_hour:02d}" + fig.title( + f"{score} of {param}, Season: {season}, " + f"Init hour: {init_hour_lbl}, Lead Time: {lead_time}" + ) + + fig.save(outfn, bbox_inches="tight", dpi=200) + LOG.info(f"saved: {outfn}") + return + + +if __name__ == "__main__": + app.run() diff --git a/workflow/scripts/verification_scoremaps.py b/workflow/scripts/verification_scoremaps.py new file mode 100644 index 00000000..e1a8dd56 --- /dev/null +++ b/workflow/scripts/verification_scoremaps.py @@ -0,0 +1,711 @@ +"""Compute spatial maps of temporally-aggregated forecast errors. + +For a fixed lead time and variable, iterates over all initialisation times +(discovered under a run directory, or taken from --reftimes for baselines), +loads the corresponding forecast field and the matching truth slice from a +reference zarr, maps the forecast onto the truth grid, and accumulates running +error statistics without ever holding the full time series in memory. The +final BIAS / RMSE / MAE / STDE maps are written to a NetCDF file. + +Forecasts load through data_input.load_forecast_data, which routes by source: +ML run directories (GRIB files), INCA (NetCDF archive), or otherwise the ICON +operational GRIB archive. Baselines (--baseline_root) use the latter two paths; +init times are not discovered from the archive but taken from --reftimes. Every +configured initialisation must be available across forecast and truth — a missing +one is a hard error, never a silent skip — so that run and baseline maps are +always computed over an identical sample. + +Design note: one Snakemake job per (run, param, lead time), each loading only the +step(s) it needs. 
We deliberately do not load all lead times at once: per-job +memory and output disk scale with N_leadtimes x grid size, which is infeasible at +interpolator (1 h) and nowcasting (10 min) resolutions; that cost is independent +of GRIB read speed, so it does not improve as loading gets faster. For TOT_PREC +the loader (data_input._tot_prec_handling) de-accumulates over the requested +[step - period, step] window, so we just select the target step. + +Usage +----- + uv run workflow/scripts/verification_scoremaps.py \\ + output/data/runs/ \\ + --truth /path/to/truth.zarr \\ + --step 24 \\ + --param T_2M +""" + +import logging +from argparse import ArgumentParser, Namespace +from datetime import datetime, timedelta +from pathlib import Path + +import numpy as np +import xarray as xr + +from data_input import load_forecast_data, parse_steps +from verification.spatial import map_forecast_to_truth + +LOG = logging.getLogger(__name__) +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) + +DATETIME_FMT = "%Y%m%d%H%M" + +SEASONS = ["DJF", "MAM", "JJA", "SON", "all"] +# Init hour buckets. -999 is the "all" sentinel (matches verification_aggregation.py). +INIT_HOURS = [0, 6, 12, 18, -999] + + +def _season_of(dt: datetime) -> str: + """Return the meteorological season string for a given datetime.""" + month = dt.month + if month in (12, 1, 2): + return "DJF" + if month in (3, 4, 5): + return "MAM" + if month in (6, 7, 8): + return "JJA" + return "SON" + + +# Maps from standard parameter names to zarr variable names. +# COSMO-2e zarrs use short CF names; COSMO-1e zarrs keep the COSMO names. +_PARAMS_MAP_CO2 = { + "T_2M": "2t", + "TD_2M": "2d", + "U_10M": "10u", + "V_10M": "10v", + "PS": "sp", + "PMSL": "msl", + "TOT_PREC": "tp", +} +# Derived variables and the components they require. 
def _params_map(truth_root: Path, accum_h: int | None = None) -> dict[str, str]:
    """Return the canonical-name -> truth-zarr-variable-name mapping.

    If the final path component of *truth_root* contains ``"co2"`` the
    COSMO-2e short-CF-name table is returned as is. Otherwise every canonical
    name maps to itself, except ``TOT_PREC``, which maps to the period
    accumulation variable ``TOT_PREC_<N>H``. ``<N>`` is *accum_h* (hours);
    when *accum_h* is ``None`` or ``0`` the 6-hourly variable ``TOT_PREC_6H``
    is used.
    """
    if "co2" in truth_root.name:
        return _PARAMS_MAP_CO2

    precip_name = "TOT_PREC_6H" if not accum_h else f"TOT_PREC_{accum_h}H"
    mapping = {}
    for canonical in _PARAMS_MAP_CO2:
        # Only the TOT_PREC key contains the substring; all others map to
        # themselves.
        mapping[canonical] = canonical.replace("TOT_PREC", precip_name)
    return mapping
def _open_zarr_component(
    root: Path, param: str, accum_h: int | None = None
) -> xr.DataArray:
    """Lazily open one native variable of the truth zarr as a DataArray.

    The store's ``dates`` become the ``time`` index, the requested variable is
    selected through the ``variables`` attribute, and the ``ensemble``
    dimension is squeezed away. When the ``field_shape`` attribute has two
    entries the flat ``cell`` dimension is unstacked into ``(y, x)``.
    ``latitude``/``longitude`` coordinates are attached when the store
    provides ``latitudes``/``longitudes``. ``accum_h`` is forwarded to
    ``_params_map`` to pick the precip accumulation variable.
    """
    native_name = _params_map(root, accum_h)[param]

    store = xr.open_zarr(root, consolidated=False).set_index(time="dates")

    # Grab lat/lon before the variable selection (they live on cell only).
    lat = store["latitudes"] if "latitudes" in store else None
    lon = store["longitudes"] if "longitudes" in store else None

    store = store.assign_coords(variable=store.attrs["variables"])
    store = store.sel(variable=native_name).squeeze("ensemble", drop=True)

    # A two-entry field_shape means the flat cell dimension encodes a regular
    # (ny, nx) grid: rebuild the 2-D layout through a (y, x) multi-index.
    flat_dim = "cell"
    is_gridded = len(store.attrs["field_shape"]) == 2
    if is_gridded:
        ny, nx = store.attrs["field_shape"]
        rows, cols = np.unravel_index(np.arange(ny * nx), (ny, nx))
        store = store.assign_coords(y=(flat_dim, rows), x=(flat_dim, cols))
        store = store.set_index(**{flat_dim: ("y", "x")}).unstack(flat_dim)

    field = store["data"].rename(param).drop_vars("variable", errors="ignore")

    # Truth zarrs hold precip in m (anemoi convention) whereas the forecast
    # loaders deliver mm (kg m-2) since MRB-820; apply the same conversion
    # here until this opener is consolidated into data_input. Stays lazy.
    if param in _ACCUMULATED_PARAMS:
        field = field * 1000

    # Without both coordinate arrays there is nothing to attach.
    if lat is None or lon is None:
        return field

    # Full coordinate names match the forecast loader and
    # map_forecast_to_truth, which key on `latitude`/`longitude`.
    if is_gridded:
        # lat/lon are still on the original flat index: reshape to the grid.
        return field.assign_coords(
            latitude=(["y", "x"], lat.values.reshape(ny, nx)),
            longitude=(["y", "x"], lon.values.reshape(ny, nx)),
        )
    return field.assign_coords(
        latitude=(flat_dim, lat.values),
        longitude=(flat_dim, lon.values),
    )
def main(args: Namespace) -> None:
    """Compute stratified error maps for one (param, lead time) and save them.

    Opens the truth zarr lazily, builds the list of initialisations (run
    directories under ``args.run_root``, or ``args.reftimes`` for baselines),
    then for each initialisation loads the forecast, selects the target step,
    maps it onto the truth grid and adds the error to running sums kept per
    (season, init_hour) bucket. BIAS, RMSE, MAE, STDE and the sample count N
    are derived from those sums and written to ``args.output`` as NetCDF.

    Initialisations whose hour is not listed in ``INIT_HOURS`` have no
    dedicated init-hour bucket; they are accumulated into the ``-999`` ("all")
    init-hour bucket only.

    Args:
        args: Parsed command-line arguments. The attributes read here are
            ``param``, ``step``, ``steps``, ``run_root``, ``baseline_root``,
            ``reftimes``, ``member``, ``truth`` and ``output``.

    Raises:
        ValueError: If ``--steps`` is missing or unusable for an accumulated
            parameter, the lead time is shorter than the accumulation period,
            an INCA baseline is combined with a period other than 1h, a
            configured initialisation has no run output, a required truth
            valid time is missing, or no initialisation was processed.
        RuntimeError: If loading a forecast or mapping it onto the truth grid
            fails for any initialisation.
    """
    LOG.info("=" * 60)
    LOG.info("Spatial verification param=%s step=%dh", args.param, args.step)
    LOG.info("Run root : %s", args.run_root)
    LOG.info("Truth : %s", args.truth)
    LOG.info("Output : %s", args.output)
    LOG.info("=" * 60)

    # Accumulated params (TOT_PREC) are stored cumulative-from-start, while the
    # truth is a period accumulation whose length equals the verification step
    # spacing (e.g. 6h for steps "0/120/6"). Derive that period so we can (a)
    # request the matching [step - period, step] window from the forecast loader
    # and (b) read the matching TOT_PREC_<N>H truth variable. We do not
    # assume a fixed period; it follows the configured --steps.
    accum_h: int | None = None
    if args.param in _ACCUMULATED_PARAMS:
        if not args.steps:
            raise ValueError(
                f"--steps is required for accumulated param '{args.param}' "
                "(used to derive the accumulation period)."
            )
        spacing = np.diff(parse_steps(args.steps))
        if spacing.size == 0:
            raise ValueError(
                f"Cannot derive an accumulation period from --steps '{args.steps}'."
            )
        accum_h = int(spacing.min())
        if args.step < accum_h:
            raise ValueError(
                f"Lead time {args.step}h is smaller than the {accum_h}h "
                f"accumulation period; cannot form a [step - period, step] "
                f"window for '{args.param}'."
            )
        req_steps = [args.step - accum_h, args.step]
        LOG.info("Accumulation period: %dh (forecast window %s)", accum_h, req_steps)

        # INCA delivers native 1h precip sums and (unlike the GRIB paths, where
        # the cumulative-from-start diff adapts to the requested window) cannot
        # re-aggregate to a coarser period: the value at the target step would
        # stay a 1h sum while the truth read is TOT_PREC_<N>H — a silent
        # mismatch. Re-aggregation in the loader is a planned follow-up.
        if args.baseline_root and "INCA" in args.baseline_root.parts and accum_h != 1:
            raise ValueError(
                f"INCA provides native 1h accumulations only, but the step "
                f"spacing of --steps '{args.steps}' implies a {accum_h}h "
                f"accumulation period for '{args.param}'. Use 1h-spaced steps "
                f"for INCA score maps."
            )
    else:
        req_steps = [args.step]

    # Open the truth zarr once; individual time slices are loaded on demand.
    truth_da = open_truth_zarr(args.truth, args.param, accum_h)
    # Normalise to datetime64[ns] so membership checks work regardless of zarr precision.
    truth_da = truth_da.assign_coords(
        time=truth_da.time.values.astype("datetime64[ns]")
    )
    # Rename flat spatial dim to 'values' if the zarr uses 'cell'.
    if "cell" in truth_da.dims:
        truth_da = truth_da.rename({"cell": "values"})
    truth_times = set(
        truth_da.time.values
    )  # keep as datetime64, tolist() yields ints for ns precision
    LOG.info("Truth opened lazily: %s", truth_da)

    if args.baseline_root:
        # The operational archive is too large to enumerate up front, so the
        # experiment's configured init times define the work list. Every one of
        # them must be available: a baseline init missing from the archive is a
        # hard error at load time (in the loop below), not a silent skip, so the
        # baseline map covers the same sample as the run maps.
        init_items = [
            (rt, None)
            for rt in sorted(datetime.strptime(s, DATETIME_FMT) for s in args.reftimes)
        ]
        LOG.info("Using %d baseline init times from --reftimes", len(init_items))
    else:
        init_items = iter_init_dirs(args.run_root)
        LOG.info("Found %d init time directories", len(init_items))

        # Restrict to the experiment's configured init times, and require that
        # every configured init was actually discovered: a missing run output
        # directory must fail rather than silently shrink the sample.
        if args.reftimes:
            wanted = {datetime.strptime(s, DATETIME_FMT) for s in args.reftimes}
            discovered = {rt for rt, _ in init_items}
            missing = sorted(wanted - discovered)
            if missing:
                raise ValueError(
                    f"{len(missing)} configured initialisation(s) have no GRIB "
                    f"output under {args.run_root}: "
                    f"{[m.strftime(DATETIME_FMT) for m in missing]}. All configured "
                    "initialisations must be available so that run and baseline "
                    "score maps are computed over an identical sample; blacklist "
                    "genuinely-absent dates in the experiment config."
                )
            init_items = [(rt, d) for rt, d in init_items if rt in wanted]
            LOG.info("Matched all %d configured init times", len(init_items))

    step_td = timedelta(hours=args.step)

    # Every configured init must have a matching truth slice; a gap here would
    # otherwise silently drop the init from the map. Check up front so the full
    # set of missing valid times is reported at once rather than one per run.
    required_valid_times = {
        np.datetime64(rt + step_td).astype("datetime64[ns]") for rt, _ in init_items
    }
    missing_truth = sorted(required_valid_times - truth_times)
    if missing_truth:
        raise ValueError(
            f"Truth is missing {len(missing_truth)} required valid time(s) for "
            f"param={args.param}, step={args.step}h (e.g. "
            f"{[str(t) for t in missing_truth[:5]]}). All configured "
            "initialisations must be available so that run and baseline score "
            "maps are computed over an identical sample; blacklist genuinely-"
            "absent dates in the experiment config."
        )

    # Running accumulators keyed by (season, init_hour) – initialised on the
    # first successfully processed sample so that we can infer the spatial
    # shape from the data. Each entry is a numpy array over the spatial
    # dimension(s).
    bucket_keys = [(s, h) for s in SEASONS for h in INIT_HOURS]
    accum_n: dict[tuple[str, int], np.ndarray | None] = {k: None for k in bucket_keys}
    accum_sum_e: dict[tuple[str, int], np.ndarray | None] = {
        k: None for k in bucket_keys
    }
    accum_sum_se: dict[tuple[str, int], np.ndarray | None] = {
        k: None for k in bucket_keys
    }
    accum_sum_ae: dict[tuple[str, int], np.ndarray | None] = {
        k: None for k in bucket_keys
    }
    ref_truth_slice: xr.DataArray | None = None  # kept for output coordinates

    n_ok = 0

    for reftime, grib_dir in init_items:
        valid_time = np.datetime64(reftime + step_td).astype("datetime64[ns]")

        LOG.info(
            "Processing reftime=%s valid=%s",
            reftime.strftime(DATETIME_FMT),
            valid_time,
        )

        # Verbose diagnostics below are only logged for the first sample.
        first_iter = n_ok == 0

        # --- load forecast ---
        fct_params = (
            list(_DERIVED[args.param]) if args.param in _DERIVED else [args.param]
        )

        try:
            # For accumulated params (TOT_PREC) req_steps is the [step - period,
            # step] window; for GRIB sources (runs and the ICON archive) the
            # loader de-accumulates the cumulative-from-start field over the
            # requested steps (diff over `step`), so the target step holds the
            # period accumulation; INCA returns native 1h sums, matching the
            # period because accum_h == 1 is enforced above. Instantaneous
            # params request a single step. The target step is selected just
            # below.
            #
            # data_input._tot_prec_handling receives the requested steps and
            # synthesises a zero initial condition when step 0 is requested but
            # absent from the GRIB (anemoi-inference omits TOT_PREC at step 0),
            # which makes the first-lead-time window [0, period] work for ML
            # runs. Windows not containing step 0 are never zero-filled.
            src_root = args.baseline_root if args.baseline_root else grib_dir
            fcst = load_forecast_data(
                src_root, reftime, req_steps, fct_params, member=args.member
            )
        except Exception as exc:
            raise RuntimeError(
                f"Could not load forecast for initialisation "
                f"{reftime.strftime(DATETIME_FMT)} (lead time {args.step}h) from "
                f"{src_root}: {exc}. All configured initialisations must be "
                "available so that run and baseline score maps are computed over "
                "an identical sample; blacklist genuinely-absent dates in the "
                "experiment config."
            ) from exc

        # Select the target step. The earthkit loader returns forecasts over the
        # requested steps with a `step` (timedelta64) dimension; for TOT_PREC the
        # loader has already de-accumulated over the window, so the target step
        # holds the period accumulation, and for instantaneous params only the
        # single requested step is present.
        if "step" in fcst.dims:
            fcst = fcst.sel(step=np.timedelta64(args.step, "h"))

        # Compute derived variable if needed.
        if args.param in _DERIVED:
            fcst = fcst.assign({args.param: _compute_derived(fcst, args.param)})

        if first_iter:
            LOG.info("fcst (after step selection): %s", fcst)
            fcst_raw = fcst[args.param].values if args.param in fcst else None
            if fcst_raw is not None:
                n_nan_fcst = int(np.isnan(fcst_raw).sum())
                LOG.info(
                    "fcst[%s]: shape=%s, min=%.4g, max=%.4g, n_nan=%d",
                    args.param,
                    fcst_raw.shape,
                    float(np.nanmin(fcst_raw))
                    if n_nan_fcst < fcst_raw.size
                    else float("nan"),
                    float(np.nanmax(fcst_raw))
                    if n_nan_fcst < fcst_raw.size
                    else float("nan"),
                    n_nan_fcst,
                )

        # --- load truth slice ---
        truth_slice = truth_da.sel(time=valid_time).compute()
        # For derived variables truth_da is already the derived DataArray,
        # so wrap it in a Dataset for map_forecast_to_truth compatibility.
        truth_ds = (
            truth_slice.to_dataset(name=args.param)
            if isinstance(truth_slice, xr.DataArray)
            else truth_slice
        )

        if first_iter:
            truth_raw = truth_slice.values
            n_nan_truth = int(np.isnan(truth_raw).sum())
            LOG.info(
                "truth_slice[%s]: shape=%s, min=%.4g, max=%.4g, n_nan=%d",
                args.param,
                truth_raw.shape,
                float(np.nanmin(truth_raw))
                if n_nan_truth < truth_raw.size
                else float("nan"),
                float(np.nanmax(truth_raw))
                if n_nan_truth < truth_raw.size
                else float("nan"),
                n_nan_truth,
            )

        # --- map forecast onto truth grid ---
        try:
            fcst_mapped = map_forecast_to_truth(fcst, truth_ds)
        except Exception as exc:
            raise RuntimeError(
                f"Spatial mapping failed for initialisation "
                f"{reftime.strftime(DATETIME_FMT)} (lead time {args.step}h): {exc}."
            ) from exc

        fcst_param = fcst_mapped[args.param]
        # Squeeze size-1 non-spatial dims so the error array is purely spatial.
        # The earthkit loader keeps `number` (ensemble), `z` (vertical) and
        # `forecast_reference_time` as size-1 dims for a deterministic surface run.
        for dim in ["eps", "ensemble", "number", "z", "forecast_reference_time"]:
            if dim in fcst_param.dims and fcst_param.sizes[dim] == 1:
                fcst_param = fcst_param.squeeze(dim, drop=True)
        fcst_vals = fcst_param.values
        truth_vals = truth_slice.values
        error = fcst_vals - truth_vals  # shape: spatial dims of truth

        if first_iter:
            n_nan_mapped = int(np.isnan(fcst_vals).sum())
            LOG.info(
                "fcst_mapped[%s]: shape=%s, min=%.4g, max=%.4g, n_nan=%d",
                args.param,
                fcst_vals.shape,
                float(np.nanmin(fcst_vals))
                if n_nan_mapped < fcst_vals.size
                else float("nan"),
                float(np.nanmax(fcst_vals))
                if n_nan_mapped < fcst_vals.size
                else float("nan"),
                n_nan_mapped,
            )
            n_nan_err = int(np.isnan(error).sum())
            LOG.info(
                "error: shape=%s, min=%.4g, max=%.4g, n_nan=%d / %d",
                error.shape,
                float(np.nanmin(error)) if n_nan_err < error.size else float("nan"),
                float(np.nanmax(error)) if n_nan_err < error.size else float("nan"),
                n_nan_err,
                error.size,
            )

        n_nan_error = int(np.isnan(error).sum())
        if n_nan_error == error.size:
            LOG.warning(
                "reftime=%s: error is all-NaN (%d points) — nothing accumulated.",
                reftime.strftime(DATETIME_FMT),
                error.size,
            )

        # --- initialise accumulators on first valid sample ---
        if accum_n[("all", -999)] is None:
            for k in bucket_keys:
                accum_n[k] = np.zeros(error.shape, dtype=np.int64)
                accum_sum_e[k] = np.zeros(error.shape, dtype=np.float64)
                accum_sum_se[k] = np.zeros(error.shape, dtype=np.float64)
                accum_sum_ae[k] = np.zeros(error.shape, dtype=np.float64)
            ref_truth_slice = truth_slice

        # --- accumulate into matching (season, init_hour) buckets, plus the
        #     "all" rows/cols on each axis (NaN-safe) ---
        season = _season_of(reftime)
        ih = reftime.hour
        # Buckets exist only for the hours listed in INIT_HOURS. An init at any
        # other hour (e.g. 03 UTC) has no dedicated bucket, so it is counted in
        # the -999 ("all") init-hour bucket only instead of raising a KeyError
        # on a bucket key that was never created.
        hour_buckets = (ih, -999) if ih in INIT_HOURS else (-999,)
        valid = ~np.isnan(error)
        for s in (season, "all"):
            for h in hour_buckets:
                accum_n[(s, h)][valid] += 1
                accum_sum_e[(s, h)][valid] += error[valid]
                accum_sum_se[(s, h)][valid] += error[valid] ** 2
                accum_sum_ae[(s, h)][valid] += np.abs(error[valid])
        n_ok += 1

    LOG.info("Finished: %d init times processed", n_ok)

    if n_ok == 0:
        raise ValueError(
            "No initialisations were processed — nothing to write. Check that "
            "--reftimes is non-empty."
        )

    # --- compute aggregate maps per (season, init_hour), then stack ---
    # Keep only coordinates that live entirely on the spatial dims of the
    # reference truth slice (the scalar `time` coordinate is dropped).
    spatial_coords = {
        c: ref_truth_slice[c]
        for c in ref_truth_slice.coords
        if set(ref_truth_slice[c].dims).issubset(set(ref_truth_slice.dims))
        and c != "time"
    }
    spatial_dims = list(ref_truth_slice.dims)
    out_dims = ["season", "init_hour"] + spatial_dims
    out_coords = {"season": SEASONS, "init_hour": INIT_HOURS, **spatial_coords}

    def _strat_da(compute_fn) -> xr.DataArray:
        """Stack per-(season, init_hour) arrays into a (season, init_hour, *spatial) DataArray."""
        out_shape = (len(SEASONS), len(INIT_HOURS)) + ref_truth_slice.shape
        arr = np.empty(out_shape, dtype=np.float32)
        for i, s in enumerate(SEASONS):
            for j, h in enumerate(INIT_HOURS):
                n = accum_n[(s, h)]
                # Points with n == 0 divide by zero inside compute_fn; the
                # np.where in each lambda replaces those results with NaN.
                with np.errstate(invalid="ignore", divide="ignore"):
                    arr[i, j] = compute_fn(n, s, h).astype(np.float32)
        return xr.DataArray(arr, dims=out_dims, coords=out_coords)

    out = xr.Dataset(
        {
            f"{args.param}.BIAS": _strat_da(
                lambda n, s, h: np.where(n > 0, accum_sum_e[(s, h)] / n, np.nan)
            ),
            f"{args.param}.RMSE": _strat_da(
                lambda n, s, h: np.where(
                    n > 0, np.sqrt(accum_sum_se[(s, h)] / n), np.nan
                )
            ),
            f"{args.param}.MAE": _strat_da(
                lambda n, s, h: np.where(n > 0, accum_sum_ae[(s, h)] / n, np.nan)
            ),
            # STDE = sqrt(E[e^2] - E[e]^2), clipped at zero to absorb small
            # negative round-off.
            f"{args.param}.STDE": _strat_da(
                lambda n, s, h: np.where(
                    n > 0,
                    np.sqrt(
                        np.maximum(
                            accum_sum_se[(s, h)] / n - (accum_sum_e[(s, h)] / n) ** 2,
                            0.0,
                        )
                    ),
                    np.nan,
                )
            ),
            f"{args.param}.N": _strat_da(lambda n, s, h: np.where(n > 0, n, np.nan)),
        },
        attrs={
            "param": args.param,
            "step_h": args.step,
            # Accumulation period of the verified quantity (accumulated params
            # only) — lets consumers tell a 1h INCA map from a 6h ICON map.
            "accum_h": accum_h if accum_h is not None else "n/a",
            "member": args.member,
            "source": str(args.baseline_root if args.baseline_root else args.run_root),
            "n_processed": n_ok,
        },
    )

    LOG.info("Output dataset:\n%s", out)
    args.output.parent.mkdir(parents=True, exist_ok=True)
    out.to_netcdf(args.output)
    LOG.info("Saved to %s", args.output)
Required for " + "accumulated params (TOT_PREC): the accumulation period is the step " + "spacing, the forecast is accumulated over [step - period, step], and " + "the matching TOT_PREC_H truth variable is read. Ignored for " + "instantaneous params." + ), + ) + parser.add_argument( + "--reftimes", + nargs="+", + default=None, + help=( + "List of init times (YYYYMMDDHHMM). For runs: optional restriction of " + "the discovered init-time directories. For baselines: required; " + "defines the init times to load from the archive." + ), + ) + parser.add_argument( + "--output", + type=Path, + required=True, + help="Output NetCDF file.", + ) + args = parser.parse_args() + + if bool(args.run_root) == bool(args.baseline_root): + parser.error("Exactly one of --run_root or --baseline_root must be provided.") + if args.baseline_root and not args.reftimes: + parser.error( + "--reftimes is required with --baseline_root: init times cannot be " + "discovered from the operational archive." + ) + + main(args) diff --git a/workflow/tools/config.schema.json b/workflow/tools/config.schema.json index 03671aa1..37c95a59 100644 --- a/workflow/tools/config.schema.json +++ b/workflow/tools/config.schema.json @@ -289,6 +289,18 @@ ], "default": null, "description": "Scorecard generation configuration. Omit or set enabled: false to disable." + }, + "scoremaps": { + "anyOf": [ + { + "$ref": "#/$defs/ScoreMapsConfig" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Score map plot configuration. Omit or set enabled: false to disable." 
} }, "required": [ @@ -594,6 +606,87 @@ "title": "Profile", "type": "object" }, + "ScoreMapsConfig": { + "additionalProperties": false, + "description": "Parameters controlling which score map plots are produced.", + "properties": { + "enabled": { + "default": false, + "description": "Whether to produce score maps (computationally intensive).", + "title": "Enabled", + "type": "boolean" + }, + "params": { + "default": [ + "T_2M" + ], + "description": "List of parameters to plot. Supported values: T_2M, TD_2M, U_10M, V_10M, PS, PMSL, TOT_PREC (native), and SP_10M (derived wind speed from U_10M/V_10M).", + "items": { + "type": "string" + }, + "title": "Params", + "type": "array" + }, + "leadtimes": { + "default": [ + 6, + 24 + ], + "description": "List of lead times (hours) to plot.", + "items": { + "type": "integer" + }, + "title": "Leadtimes", + "type": "array" + }, + "scores": { + "default": [ + "BIAS" + ], + "description": "List of verification scores to plot. Supported: BIAS, RMSE, MAE.", + "items": { + "type": "string" + }, + "title": "Scores", + "type": "array" + }, + "regions": { + "default": [ + "switzerland" + ], + "description": "List of regions to plot (e.g. switzerland, centraleurope).", + "items": { + "type": "string" + }, + "title": "Regions", + "type": "array" + }, + "seasons": { + "default": [ + "all" + ], + "description": "List of seasons to plot ('all', 'DJF', 'MAM', 'JJA', 'SON').", + "items": { + "type": "string" + }, + "title": "Seasons", + "type": "array" + }, + "init_hours": { + "default": [ + "all" + ], + "description": "List of initialization hours to plot. Use 'all' for the unstratified view, or zero-padded hour strings like '00', '06', '12', '18'.", + "items": { + "type": "string" + }, + "title": "Init Hours", + "type": "array" + } + }, + "title": "ScoreMapsConfig", + "type": "object" + }, "ScorecardConfig": { "additionalProperties": false, "description": "Configuration for a single named scorecard.",