From c2cc97810fb2ec77e82e63a0c6803dc2a9da8c64 Mon Sep 17 00:00:00 2001 From: clairemerker <34312518+clairemerker@users.noreply.github.com> Date: Fri, 12 Jun 2026 14:26:45 +0200 Subject: [PATCH 01/21] feat(data_input): add jretrievedwh subprocess wrapper + station catalog --- src/data_input/jretrieve.py | 245 +++++++++++++++++++++++++++++++++++ tests/unit/test_jretrieve.py | 37 ++++++ 2 files changed, 282 insertions(+) create mode 100644 src/data_input/jretrieve.py create mode 100644 tests/unit/test_jretrieve.py diff --git a/src/data_input/jretrieve.py b/src/data_input/jretrieve.py new file mode 100644 index 00000000..d5df7720 --- /dev/null +++ b/src/data_input/jretrieve.py @@ -0,0 +1,245 @@ +"""Subprocess wrapper around ``jretrievedwh.py`` for retrieving SwissMetNet +(SNM) surface observations from the MeteoSwiss data warehouse (DWH). + +Ported/adapted from MeteoSwiss/anemoi-plugins-meteoswiss (add-synop-dwh-source). +Requires ``jretrievedwh.py`` on $PATH and $OPR_HOME set with a readable +``.jretrievedwh-conf..py`` conf file. +""" + +from __future__ import annotations + +import logging +import os +import shutil +import subprocess +import time +from dataclasses import dataclass +from datetime import datetime +from io import StringIO +from typing import Any, Sequence + +import numpy as np +import pandas as pd + +LOG = logging.getLogger(__name__) + +BINARY_NAME = "jretrievedwh.py" +VALID_STAGES = {"prod", "depl", "devt"} +DEFAULT_META_FIELDS: tuple[str, ...] = ("lat", "lon", "elev", "name", "nat_abbr") +DEFAULT_GROUP = "SwissMetNet" +CATALOG_TIME_RANGE_START = datetime(1900, 1, 1) +CATALOG_TIME_RANGE_END = datetime(2100, 12, 31, 23, 59) + + +class JretrieveError(RuntimeError): + """Raised when jretrievedwh.py fails or returns malformed output.""" + + +def _resolve_binary() -> str: + path = shutil.which(BINARY_NAME) + if path is None: + raise JretrieveError( + f"{BINARY_NAME} not found in $PATH. 
" + "Make sure /oprusers/osm/opr.inn/bin (or equivalent) is on your PATH." + ) + return path + + +def _build_env(stage: str) -> dict[str, str]: + if stage not in VALID_STAGES: + raise ValueError(f"Invalid stage {stage!r}. Must be one of {sorted(VALID_STAGES)}.") + opr_home = os.environ.get("OPR_HOME") + if not opr_home: + raise JretrieveError("OPR_HOME is not set; cannot locate jretrieve conf file.") + conf_name = f".jretrievedwh-conf.{stage}.py" + conf_path = os.path.join(opr_home, conf_name) + if not os.path.isfile(conf_path): + raise JretrieveError(f"jretrieve conf file not found: {conf_path}") + if not os.access(conf_path, os.R_OK): + raise JretrieveError(f"jretrieve conf file not readable: {conf_path}") + env = os.environ.copy() + env["JRETRIEVE_CONF_DIR"] = opr_home + env["JRETRIEVE_CONF_NAME"] = conf_name + return env + + +def _fmt_time(dt: datetime) -> str: + return dt.strftime("%Y%m%d%H%M") + + +def _stations_to_argv(stations: dict[str, Any]) -> list[str]: + """Translate a station selection dict to jretrieve CLI args. + + Exactly one of {group, locations, bbox} must be set. 
+ """ + keys = [k for k in ("group", "locations", "bbox") if stations.get(k) is not None] + if len(keys) != 1: + raise ValueError( + f"stations must specify exactly one of group/locations/bbox, got {keys}" + ) + key = keys[0] + val = stations[key] + if key == "group": + return ["-a", f"stn_group,{val}"] + if key == "locations": + if isinstance(val, str): + val = [v for v in val.split(",") if v] + if not isinstance(val, Sequence): + raise ValueError("stations.locations must be a list of nat_abbr strings.") + return ["-i", "nat_abbr," + ",".join(str(v) for v in val)] + if key == "bbox": + if isinstance(val, str): + val = [v for v in val.split(",") if v] + if len(val) != 4: + raise ValueError("stations.bbox must be [minlat, maxlat, minlon, maxlon].") + return ["-l", ",".join(str(v) for v in val)] + raise AssertionError("unreachable") + + +def parse_selection(root: Any) -> tuple[dict[str, Any], str, str]: + """Parse a truth-root marker into (stations, stage, seq_type). + + Examples (slash-free so they survive ``Path()`` normalisation): + ``jretrievedwh:SwissMetNet`` -> group + ``jretrievedwh:group=SwissMetNet;stage=devt`` + ``jretrievedwh:locations=ARO,KLO`` + ``jretrievedwh:bbox=45.8,47.8,5.9,10.5`` + """ + _, _, rest = str(root).partition(":") + rest = rest.strip() + stations: dict[str, Any] = {} + stage = "prod" + seq_type = "surface" + for i, part in enumerate([p for p in rest.split(";") if p]): + if "=" not in part: + if i == 0: + stations["group"] = part + continue + raise ValueError(f"Invalid jretrieve selector fragment: {part!r}") + key, _, value = part.partition("=") + key, value = key.strip(), value.strip() + if key in ("group", "locations", "bbox"): + stations[key] = value + elif key == "stage": + stage = value + elif key == "seq_type": + seq_type = value + else: + raise ValueError(f"Unknown jretrieve selector key: {key!r}") + if not stations: + stations = {"group": DEFAULT_GROUP} + return stations, stage, seq_type + + +def _run(argv: list[str], env: 
dict[str, str], timeout_s: int) -> str: + try: + proc = subprocess.run( + argv, env=env, capture_output=True, text=True, timeout=timeout_s, check=False + ) + except subprocess.TimeoutExpired as e: + raise JretrieveError(f"jretrieve timed out after {timeout_s}s: {' '.join(argv)}") from e + if proc.returncode != 0: + raise JretrieveError( + f"jretrieve exited with {proc.returncode}\nargv: {argv}\n" + f"stderr: {proc.stderr.strip()}\nstdout (head): {proc.stdout[:500]}" + ) + if proc.stdout.lstrip().startswith("ERROR"): + raise JretrieveError(f"jretrieve returned error: {proc.stdout.strip()[:500]}") + return proc.stdout + + +def _run_with_retry(argv, env, timeout_s, attempts=3) -> str: + last_err: Exception | None = None + for attempt in range(1, attempts + 1): + try: + return _run(argv, env=env, timeout_s=timeout_s) + except JretrieveError as e: + last_err = e + if attempt == attempts: + break + backoff = 2**attempt + LOG.warning( + "jretrieve attempt %d/%d failed (%s); retrying in %ds", + attempt, attempts, e, backoff, + ) + time.sleep(backoff) + assert last_err is not None + raise last_err + + +def _parse_csv(csv_text: str) -> pd.DataFrame: + csv_text = csv_text.strip() + if not csv_text: + return pd.DataFrame() + return pd.read_csv(StringIO(csv_text), sep=";") + + +def fetch_meta( + *, stations, params, seq_type="surface", stage="prod", + meta_fields=DEFAULT_META_FIELDS, timeout_s=300, +) -> pd.DataFrame: + """Fetch the station catalog (rows per station x parameter x period) over a + fixed wide time range so the response is deterministic.""" + if not params: + raise ValueError("params must be non-empty.") + argv = [ + _resolve_binary(), "-s", seq_type, "-n", ",".join(params), + "-t", f"{_fmt_time(CATALOG_TIME_RANGE_START)},{_fmt_time(CATALOG_TIME_RANGE_END)}", + "--meta-info", ",".join(meta_fields), "--format", "csv", + *_stations_to_argv(stations), + ] + LOG.info("jretrieve meta: %s", " ".join(argv)) + df = _parse_csv(_run_with_retry(argv, env=_build_env(stage), 
timeout_s=timeout_s)) + if df.empty: + raise JretrieveError("jretrieve meta-info returned no rows.") + return df + + +def fetch_data( + *, stations, params, start, end, increment_minutes=60, + seq_type="surface", stage="prod", timeout_s=600, +) -> pd.DataFrame: + """Fetch observation data; columns: station (int), termin (YYYYMMDDhhmmss), + one column per requested short name.""" + if not params: + raise ValueError("params must be non-empty.") + argv = [ + _resolve_binary(), "-s", seq_type, "-n", ",".join(params), + "-t", f"{_fmt_time(start)},{_fmt_time(end)},{int(increment_minutes)}", + "--format", "csv", *_stations_to_argv(stations), + ] + LOG.info("jretrieve data: %s", " ".join(argv)) + return _parse_csv(_run_with_retry(argv, env=_build_env(stage), timeout_s=timeout_s)) + + +@dataclass(frozen=True) +class StationCatalog: + """Stable, nat_abbr-sorted station catalog used as the cell axis.""" + + nat_abbr: np.ndarray + station_id: np.ndarray + latitude: np.ndarray + longitude: np.ndarray + elevation: np.ndarray + name: np.ndarray + + @property + def n(self) -> int: + return len(self.nat_abbr) + + @classmethod + def from_meta(cls, meta: pd.DataFrame) -> "StationCatalog": + per_station = ( + meta.sort_values(["nat_abbr", "parameter", "op_since"], kind="stable") + .drop_duplicates(subset=["station"], keep="first") + .sort_values("nat_abbr", kind="stable") + .reset_index(drop=True) + ) + return cls( + nat_abbr=per_station["nat_abbr"].to_numpy(dtype=object), + station_id=per_station["station"].to_numpy(dtype=np.int64), + latitude=per_station["latitude"].to_numpy(dtype=np.float64), + longitude=per_station["longitude"].to_numpy(dtype=np.float64), + elevation=per_station["elev"].to_numpy(dtype=np.float64), + name=per_station["stn_name"].to_numpy(dtype=object), + ) diff --git a/tests/unit/test_jretrieve.py b/tests/unit/test_jretrieve.py new file mode 100644 index 00000000..2dda5a52 --- /dev/null +++ b/tests/unit/test_jretrieve.py @@ -0,0 +1,37 @@ +import numpy as np +import 
pandas as pd +import pytest + +from data_input import jretrieve as jr + + +def test_stations_to_argv_group(): + assert jr._stations_to_argv({"group": "SwissMetNet"}) == ["-a", "stn_group,SwissMetNet"] + + +def test_stations_to_argv_locations_from_string(): + assert jr._stations_to_argv({"locations": "ARO,KLO"}) == ["-i", "nat_abbr,ARO,KLO"] + + +def test_stations_to_argv_bbox_from_string(): + assert jr._stations_to_argv({"bbox": "45.8,47.8,5.9,10.5"}) == [ + "-l", "45.8,47.8,5.9,10.5", + ] + + +def test_stations_to_argv_rejects_ambiguous(): + with pytest.raises(ValueError, match="exactly one"): + jr._stations_to_argv({"group": "x", "bbox": "1,2,3,4"}) + + +def test_parse_selection_default_group(): + assert jr.parse_selection("jretrievedwh:") == ({"group": "SwissMetNet"}, "prod", "surface") + assert jr.parse_selection("jretrievedwh:SwissMetNet") == ( + {"group": "SwissMetNet"}, "prod", "surface", + ) + + +def test_parse_selection_keyvalue_and_stage(): + assert jr.parse_selection("jretrievedwh:locations=ARO,KLO;stage=devt") == ( + {"locations": "ARO,KLO"}, "devt", "surface", + ) From fdf18c4a7a86c1ab75ed165ca4914e0d0b6f9c54 Mon Sep 17 00:00:00 2001 From: clairemerker <34312518+clairemerker@users.noreply.github.com> Date: Fri, 12 Jun 2026 14:27:34 +0200 Subject: [PATCH 02/21] test(data_input): cover StationCatalog.from_meta --- tests/unit/test_jretrieve.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/unit/test_jretrieve.py b/tests/unit/test_jretrieve.py index 2dda5a52..f54b6057 100644 --- a/tests/unit/test_jretrieve.py +++ b/tests/unit/test_jretrieve.py @@ -35,3 +35,25 @@ def test_parse_selection_keyvalue_and_stage(): assert jr.parse_selection("jretrievedwh:locations=ARO,KLO;stage=devt") == ( {"locations": "ARO,KLO"}, "devt", "surface", ) + + +def _sample_meta(): + return pd.DataFrame({ + "station": [2, 1, 1], + "op_since": [19900101000000, 19800101000000, 19800101000000], + "op_till": ["", "", ""], + "parameter": ["tre200s0", 
"fkl010z0", "tre200s0"], + "latitude": [47.48, 46.79, 46.79], + "longitude": [8.54, 9.68, 9.68], + "elev": [426.0, 1878.0, 1878.0], + "stn_name": ["Zurich", "Arosa", "Arosa"], + "nat_abbr": ["KLO", "ARO", "ARO"], + }) + + +def test_station_catalog_from_meta_collapses_and_sorts(): + cat = jr.StationCatalog.from_meta(_sample_meta()) + assert cat.n == 2 + assert list(cat.nat_abbr) == ["ARO", "KLO"] # sorted by nat_abbr + assert list(cat.station_id) == [1, 2] + np.testing.assert_allclose(cat.latitude, [46.79, 47.48]) From 49479e9f483a9b11e82295f897a08b532a6a7cba Mon Sep 17 00:00:00 2001 From: clairemerker <34312518+clairemerker@users.noreply.github.com> Date: Fri, 12 Jun 2026 14:31:02 +0200 Subject: [PATCH 03/21] feat(data_input): implement load_obs_data_from_jretrieve --- src/data_input/__init__.py | 122 +++++++++++++++++++++++++++++++++++ tests/unit/test_jretrieve.py | 42 ++++++++++++ 2 files changed, 164 insertions(+) diff --git a/src/data_input/__init__.py b/src/data_input/__init__.py index 6c14005a..63dbc75e 100644 --- a/src/data_input/__init__.py +++ b/src/data_input/__init__.py @@ -5,6 +5,7 @@ import earthkit.data as ekd import numpy as np +import pandas as pd import xarray as xr from pyproj import Transformer @@ -325,6 +326,120 @@ def load_obs_data_from_peakweather( return _select_valid_times(ds, times) +DWH_PARAM_MAP = { + "T_2M": "tre200s0", + "TD_2M": "tde200s0", + "PS": "prestas0", + "PMSL": "pp0qffs0", + "TOT_PREC": "rre150h0", + "FF_10M": "fkl010z0", + "DD_10M": "dkl010z0", + "VMAX_10M": "fkl010z1", +} +DWH_WIND_SPEED = "fkl010z0" +DWH_WIND_DIR = "dkl010z0" +DWH_CELSIUS_TO_KELVIN = {"tre200s0", "tde200s0"} +DWH_HPA_TO_PA = {"prestas0", "pp0qffs0"} + + +def _jretrieve_df_to_xarray(df, short_names, catalog) -> xr.Dataset: + """Pivot long-form jretrieve obs into a (time, values) cube aligned to the + catalog, NaN-filled for missing cells.""" + station_to_idx = {sid: i for i, sid in enumerate(catalog.station_id)} + if df.empty: + time_index = 
pd.DatetimeIndex([]) + else: + df = df.copy() + df["time"] = pd.to_datetime(df["termin"].astype(str), format="%Y%m%d%H%M%S") + time_index = pd.DatetimeIndex(sorted(df["time"].unique())) + n_t, n_s = len(time_index), catalog.n + coords = { + "time": ("time", time_index.values.astype("datetime64[ns]")), + "values": ("values", catalog.nat_abbr), + "latitude": ("values", catalog.latitude), + "longitude": ("values", catalog.longitude), + } + data_vars: dict[str, tuple] = {} + if df.empty: + for p in short_names: + data_vars[p] = (("time", "values"), np.full((n_t, n_s), np.nan, np.float32)) + else: + time_to_idx = {t: i for i, t in enumerate(time_index)} + df["_si"] = df["station"].map(station_to_idx) + df["_ti"] = df["time"].map(time_to_idx) + df = df.dropna(subset=["_si", "_ti"]) + df["_si"] = df["_si"].astype(int) + df["_ti"] = df["_ti"].astype(int) + for p in short_names: + arr = np.full((n_t, n_s), np.nan, dtype=np.float32) + if p in df.columns: + arr[df["_ti"].to_numpy(), df["_si"].to_numpy()] = df[p].to_numpy( + dtype=np.float32 + ) + data_vars[p] = (("time", "values"), arr) + return xr.Dataset(data_vars=data_vars, coords=coords) + + +def load_obs_data_from_jretrieve( + root, reftime: datetime, steps: list[int], params: list[str] +) -> xr.Dataset: + """Load SwissMetNet (SNM) surface observations from the DWH via jretrievedwh. + + ``root`` is a marker string selecting stations, e.g. ``jretrievedwh:SwissMetNet`` + (default group), ``jretrievedwh:locations=ARO,KLO``, or + ``jretrievedwh:bbox=45.8,47.8,5.9,10.5`` (optionally ``;stage=devt``). Returns + a Dataset with dims (time, values), values=nat_abbr, latitude/longitude coords, + variables renamed to ICON names in SI units (T/TD in Kelvin, pressure in Pa). + Only the requested hourly valid times are kept. 
+ """ + from data_input import jretrieve as jr + + stations, stage, seq_type = jr.parse_selection(root) + + want_uv = "U_10M" in params or "V_10M" in params + short_names: list[str] = [DWH_PARAM_MAP[p] for p in params if p in DWH_PARAM_MAP] + if want_uv: + short_names += [DWH_WIND_SPEED, DWH_WIND_DIR] + short_names = list(dict.fromkeys(short_names)) + if not short_names: + raise ValueError(f"No DWH parameter mapping for requested params: {params}") + + start = reftime + end = start + timedelta(hours=max(steps)) + if len(steps) > 1: + end += timedelta(hours=steps[-1] - steps[-2]) + + catalog = jr.StationCatalog.from_meta( + jr.fetch_meta(stations=stations, params=short_names, seq_type=seq_type, stage=stage) + ) + df = jr.fetch_data( + stations=stations, params=short_names, start=start, end=end, + increment_minutes=60, seq_type=seq_type, stage=stage, + ) + raw = _jretrieve_df_to_xarray(df, short_names, catalog) + + out = xr.Dataset(coords=raw.coords) + for icon, short in DWH_PARAM_MAP.items(): + if icon in params and short in raw: + var = raw[short] + if short in DWH_CELSIUS_TO_KELVIN: + var = var - ZERO_KELVIN + elif short in DWH_HPA_TO_PA: + var = var * 100.0 + out[icon] = var + if want_uv and DWH_WIND_SPEED in raw and DWH_WIND_DIR in raw: + ff = raw[DWH_WIND_SPEED] + dd_rad = np.deg2rad(raw[DWH_WIND_DIR]) + if "U_10M" in params: + out["U_10M"] = -ff * np.sin(dd_rad) + if "V_10M" in params: + out["V_10M"] = -ff * np.cos(dd_rad) + + out = out.dropna("values", how="all") + times = np.datetime64(reftime) + np.asarray(steps, dtype="timedelta64[h]") + return _select_valid_times(out, times) + + def load_truth_data( root, reftime: datetime, steps: list[int], params: list[str] ) -> xr.Dataset: @@ -350,6 +465,13 @@ def load_truth_data( steps=steps, params=params, ) + elif "jretrieve" in str(root): + LOG.info("Loading ground truth from JRetrieve...") + truth = load_obs_data_from_jretrieve( + reftime=reftime, + steps=steps, + params=params, + ) else: raise 
ValueError(f"Unsupported truth root: {root}") return truth diff --git a/tests/unit/test_jretrieve.py b/tests/unit/test_jretrieve.py index f54b6057..23dbe057 100644 --- a/tests/unit/test_jretrieve.py +++ b/tests/unit/test_jretrieve.py @@ -1,7 +1,10 @@ +from datetime import datetime + import numpy as np import pandas as pd import pytest +import data_input from data_input import jretrieve as jr @@ -57,3 +60,42 @@ def test_station_catalog_from_meta_collapses_and_sorts(): assert list(cat.nat_abbr) == ["ARO", "KLO"] # sorted by nat_abbr assert list(cat.station_id) == [1, 2] np.testing.assert_allclose(cat.latitude, [46.79, 47.48]) + + +def test_load_obs_data_from_jretrieve(monkeypatch): + meta = pd.DataFrame({ + "station": [1, 2], + "op_since": [19800101000000, 19900101000000], + "op_till": ["", ""], + "parameter": ["tre200s0", "tre200s0"], + "latitude": [46.79, 47.48], + "longitude": [9.68, 8.54], + "elev": [1878.0, 426.0], + "stn_name": ["Arosa", "Zurich"], + "nat_abbr": ["ARO", "KLO"], + }) + data = pd.DataFrame({ + "station": [1, 1], + "termin": [20250115000000, 20250115010000], + "tre200s0": [10.0, 11.0], # degC + "fkl010z0": [3.0, 4.0], # m/s + "dkl010z0": [0.0, 90.0], # deg + }) + monkeypatch.setattr(jr, "fetch_meta", lambda **kw: meta) + monkeypatch.setattr(jr, "fetch_data", lambda **kw: data) + + ds = data_input.load_obs_data_from_jretrieve( + "jretrievedwh:locations=ARO,KLO", + datetime(2025, 1, 15, 0, 0), + [0, 1], + ["T_2M", "U_10M", "V_10M"], + ) + + assert set(ds.dims) == {"time", "values"} + assert list(ds["values"].values) == ["ARO"] # KLO all-NaN -> dropped + assert set(ds.data_vars) == {"T_2M", "U_10M", "V_10M"} + np.testing.assert_allclose(ds["T_2M"].sel(values="ARO").values, [283.15, 284.15]) + # DD=0 -> U=0, V=-FF ; DD=90 -> U=-FF, V=0 + np.testing.assert_allclose(ds["U_10M"].sel(values="ARO").values, [0.0, -4.0], atol=1e-5) + np.testing.assert_allclose(ds["V_10M"].sel(values="ARO").values, [-3.0, 0.0], atol=1e-5) + 
np.testing.assert_allclose(ds["latitude"].values, [46.79]) From e903490a269099618b2edac69d76cb48fa4b3a16 Mon Sep 17 00:00:00 2001 From: clairemerker <34312518+clairemerker@users.noreply.github.com> Date: Fri, 12 Jun 2026 14:32:24 +0200 Subject: [PATCH 04/21] feat(data_input): forward truth root marker to jretrieve loader --- src/data_input/__init__.py | 1 + tests/unit/test_jretrieve.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/src/data_input/__init__.py b/src/data_input/__init__.py index 63dbc75e..7f13e3f9 100644 --- a/src/data_input/__init__.py +++ b/src/data_input/__init__.py @@ -468,6 +468,7 @@ def load_truth_data( elif "jretrieve" in str(root): LOG.info("Loading ground truth from JRetrieve...") truth = load_obs_data_from_jretrieve( + root=root, reftime=reftime, steps=steps, params=params, diff --git a/tests/unit/test_jretrieve.py b/tests/unit/test_jretrieve.py index 23dbe057..349b971e 100644 --- a/tests/unit/test_jretrieve.py +++ b/tests/unit/test_jretrieve.py @@ -1,4 +1,5 @@ from datetime import datetime +from pathlib import Path import numpy as np import pandas as pd @@ -99,3 +100,18 @@ def test_load_obs_data_from_jretrieve(monkeypatch): np.testing.assert_allclose(ds["U_10M"].sel(values="ARO").values, [0.0, -4.0], atol=1e-5) np.testing.assert_allclose(ds["V_10M"].sel(values="ARO").values, [-3.0, 0.0], atol=1e-5) np.testing.assert_allclose(ds["latitude"].values, [46.79]) + + +def test_load_truth_data_forwards_root(monkeypatch): + seen = {} + + def fake_loader(root, reftime, steps, params): + seen["root"] = root + return "SENTINEL" + + monkeypatch.setattr(data_input, "load_obs_data_from_jretrieve", fake_loader) + out = data_input.load_truth_data( + Path("jretrievedwh:SwissMetNet"), datetime(2025, 1, 15), [0], ["T_2M"] + ) + assert out == "SENTINEL" + assert "jretrievedwh:SwissMetNet" in str(seen["root"]) From 9babcc4e82a347d73a1b420484e453819124b09c Mon Sep 17 00:00:00 2001 From: clairemerker 
<34312518+clairemerker@users.noreply.github.com> Date: Fri, 12 Jun 2026 14:35:35 +0200 Subject: [PATCH 05/21] feat(workflow): make truth input conditional for live jretrieve source --- workflow/rules/common.smk | 7 +++++++ workflow/rules/verification.smk | 10 ++++++---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk index 020ad956..d76e9c53 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -349,6 +349,13 @@ def truth_hash(truth_config: dict) -> str: return generate_json_hash(cfg) +def truth_file_dep(_): + """Truth file dependency: a real path for zarr/peakweather, but a live-query + marker (no input file) for jretrieve.""" + root = config["truth"]["root"] + return [] if "jretrieve" in str(root) else [root] + + TRUTH_HASH = truth_hash(config["truth"]) REGIONS = parse_regions() SHOWCASE_REGIONS = parse_showcase_regions() diff --git a/workflow/rules/verification.smk b/workflow/rules/verification.smk index ffadf40a..b10b37b5 100644 --- a/workflow/rules/verification.smk +++ b/workflow/rules/verification.smk @@ -15,7 +15,7 @@ rule verification_metrics_baseline: "src/data_input/__init__.py", script="workflow/scripts/verification_metrics.py", forecast=lambda wc: BASELINE_CONFIGS[wc.baseline_id]["root"], - truth=config["truth"]["root"], + truth_dep=truth_file_dep, eckit_grids=rules.data_download_eckit_geo_grids.output, output: OUT_ROOT / f"data/baselines/{{baseline_id}}/{{init_time}}/verif_{TRUTH_HASH}.nc", @@ -29,6 +29,7 @@ rule verification_metrics_baseline: baseline_label=lambda wc: BASELINE_CONFIGS[wc.baseline_id].get("label"), baseline_steps=lambda wc: BASELINE_CONFIGS[wc.baseline_id]["steps"], member=lambda wc: BASELINE_CONFIGS[wc.baseline_id].get("member", "000"), + truth=config["truth"]["root"], truth_label=config["truth"]["label"], regions=REGIONS, experiment_params=",".join(EXPERIMENT_PARAMS), @@ -38,7 +39,7 @@ rule verification_metrics_baseline: export 
ECCODES_DEFINITION_PATH=$(realpath .venv/share/eccodes-cosmo-resources/definitions) uv run {input.script} \ --forecast {input.forecast} \ - --truth {input.truth} \ + --truth {params.truth} \ --reftime {wildcards.init_time} \ --steps "{params.baseline_steps}" \ --label "{params.baseline_label}" \ @@ -64,7 +65,7 @@ rule verification_metrics: "src/data_input/__init__.py", script="workflow/scripts/verification_metrics.py", inference_okfile=rules.inference_execute.output.okfile, - truth=config["truth"]["root"], + truth_dep=truth_file_dep, eckit_grids=rules.data_download_eckit_geo_grids.output, output: OUT_ROOT / f"data/runs/{{run_id}}/{{init_time}}/verif_{TRUTH_HASH}.nc", @@ -80,6 +81,7 @@ rule verification_metrics: params: fcst_label=lambda wc: RUN_CONFIGS[wc.run_id].get("label"), fcst_steps=lambda wc: RUN_CONFIGS[wc.run_id]["steps"], + truth=config["truth"]["root"], truth_label=config["truth"]["label"], regions=REGIONS, grib_out_dir=lambda wc: ( @@ -92,7 +94,7 @@ rule verification_metrics: export ECCODES_DEFINITION_PATH=$(realpath .venv/share/eccodes-cosmo-resources/definitions) uv run {input.script} \ --forecast {params.grib_out_dir} \ - --truth {input.truth} \ + --truth {params.truth} \ --reftime {wildcards.init_time} \ --steps "{params.fcst_steps}" \ --label "{params.fcst_label}" \ From d172fb5df347a218798fd03296d89e9edf642d85 Mon Sep 17 00:00:00 2001 From: clairemerker <34312518+clairemerker@users.noreply.github.com> Date: Fri, 12 Jun 2026 14:36:53 +0200 Subject: [PATCH 06/21] docs: document jretrievedwh truth source config + prerequisites --- README.md | 27 +++++++++++++++++++++++++++ config/temporal-downscalers-ich1.yaml | 7 +++++++ 2 files changed, 34 insertions(+) diff --git a/README.md b/README.md index 0d9c9a06..1c48a03a 100644 --- a/README.md +++ b/README.md @@ -145,6 +145,33 @@ You can then run it with: evalml experiment path/to/experiment/config.yaml --report ``` +### Truth sources + +The `truth.root` value selects how the ground truth is loaded: + +- 
**Analysis Zarr** — a path ending in `.zarr` (anemoi analysis dataset). +- **PeakWeather** — a path containing `peakweather` (SwissMetNet station obs from Hugging Face). +- **DWH / jretrievedwh** — a `jretrievedwh:` marker string fetching SwissMetNet (SNM) + surface observations live from the MeteoSwiss data warehouse. Variables are mapped to + ICON names in SI units (temperatures in K, pressure in Pa, precipitation as the hourly + sum); wind `U_10M`/`V_10M` are derived from speed + direction. + + Marker syntax (station selection is optional and defaults to the `SwissMetNet` group; pick at most one of group/locations/bbox): + + ```yaml + truth: + label: SwissMetNet (DWH) + root: jretrievedwh:SwissMetNet # stn_group (default group) + # root: jretrievedwh:locations=ARO,KLO,LUG # explicit nat_abbr list + # root: jretrievedwh:bbox=45.8,47.8,5.9,10.5 # minlat,maxlat,minlon,maxlon + # append ;stage=devt to target a non-prod DWH stage (prod|depl|devt) + ``` + + **Prerequisites:** `jretrievedwh.py` must be on `$PATH` (e.g. + `/oprusers/osm/opr.inn/bin`) and `$OPR_HOME` set with a readable + `.jretrievedwh-conf.<stage>.py` conf file. No data is pre-downloaded — the obs are + queried at verification time. 
+ ## Installation diff --git a/config/temporal-downscalers-ich1.yaml b/config/temporal-downscalers-ich1.yaml index 18d83704..84bbef5f 100644 --- a/config/temporal-downscalers-ich1.yaml +++ b/config/temporal-downscalers-ich1.yaml @@ -37,6 +37,13 @@ runs: truth: label: SwissMetNet root: output/data/observations/peakweather + # To verify against SwissMetNet observations from the DWH via jretrievedwh, + # set instead (requires jretrievedwh.py on $PATH and $OPR_HOME set): + # label: SwissMetNet (DWH) + # root: jretrievedwh:SwissMetNet + # Other selectors: root: jretrievedwh:locations=ARO,KLO,LUG + # root: jretrievedwh:bbox=45.8,47.8,5.9,10.5 + # append ;stage=devt to target a non-prod DWH stage experiment: params: From cea206c2745ffd786df141288768e8d13adceaaa Mon Sep 17 00:00:00 2001 From: clairemerker <34312518+clairemerker@users.noreply.github.com> Date: Fri, 12 Jun 2026 14:39:40 +0200 Subject: [PATCH 07/21] style: apply ruff format to jretrieve source and tests --- src/data_input/__init__.py | 13 +++-- src/data_input/jretrieve.py | 67 ++++++++++++++++++------ tests/unit/test_jretrieve.py | 98 ++++++++++++++++++++++-------------- 3 files changed, 123 insertions(+), 55 deletions(-) diff --git a/src/data_input/__init__.py b/src/data_input/__init__.py index 7f13e3f9..5c733033 100644 --- a/src/data_input/__init__.py +++ b/src/data_input/__init__.py @@ -410,11 +410,18 @@ def load_obs_data_from_jretrieve( end += timedelta(hours=steps[-1] - steps[-2]) catalog = jr.StationCatalog.from_meta( - jr.fetch_meta(stations=stations, params=short_names, seq_type=seq_type, stage=stage) + jr.fetch_meta( + stations=stations, params=short_names, seq_type=seq_type, stage=stage + ) ) df = jr.fetch_data( - stations=stations, params=short_names, start=start, end=end, - increment_minutes=60, seq_type=seq_type, stage=stage, + stations=stations, + params=short_names, + start=start, + end=end, + increment_minutes=60, + seq_type=seq_type, + stage=stage, ) raw = _jretrieve_df_to_xarray(df, 
short_names, catalog) diff --git a/src/data_input/jretrieve.py b/src/data_input/jretrieve.py index d5df7720..b95dafbb 100644 --- a/src/data_input/jretrieve.py +++ b/src/data_input/jretrieve.py @@ -47,7 +47,9 @@ def _resolve_binary() -> str: def _build_env(stage: str) -> dict[str, str]: if stage not in VALID_STAGES: - raise ValueError(f"Invalid stage {stage!r}. Must be one of {sorted(VALID_STAGES)}.") + raise ValueError( + f"Invalid stage {stage!r}. Must be one of {sorted(VALID_STAGES)}." + ) opr_home = os.environ.get("OPR_HOME") if not opr_home: raise JretrieveError("OPR_HOME is not set; cannot locate jretrieve conf file.") @@ -134,10 +136,17 @@ def parse_selection(root: Any) -> tuple[dict[str, Any], str, str]: def _run(argv: list[str], env: dict[str, str], timeout_s: int) -> str: try: proc = subprocess.run( - argv, env=env, capture_output=True, text=True, timeout=timeout_s, check=False + argv, + env=env, + capture_output=True, + text=True, + timeout=timeout_s, + check=False, ) except subprocess.TimeoutExpired as e: - raise JretrieveError(f"jretrieve timed out after {timeout_s}s: {' '.join(argv)}") from e + raise JretrieveError( + f"jretrieve timed out after {timeout_s}s: {' '.join(argv)}" + ) from e if proc.returncode != 0: raise JretrieveError( f"jretrieve exited with {proc.returncode}\nargv: {argv}\n" @@ -160,7 +169,10 @@ def _run_with_retry(argv, env, timeout_s, attempts=3) -> str: backoff = 2**attempt LOG.warning( "jretrieve attempt %d/%d failed (%s); retrying in %ds", - attempt, attempts, e, backoff, + attempt, + attempts, + e, + backoff, ) time.sleep(backoff) assert last_err is not None @@ -175,17 +187,30 @@ def _parse_csv(csv_text: str) -> pd.DataFrame: def fetch_meta( - *, stations, params, seq_type="surface", stage="prod", - meta_fields=DEFAULT_META_FIELDS, timeout_s=300, + *, + stations, + params, + seq_type="surface", + stage="prod", + meta_fields=DEFAULT_META_FIELDS, + timeout_s=300, ) -> pd.DataFrame: """Fetch the station catalog (rows per station x 
parameter x period) over a fixed wide time range so the response is deterministic.""" if not params: raise ValueError("params must be non-empty.") argv = [ - _resolve_binary(), "-s", seq_type, "-n", ",".join(params), - "-t", f"{_fmt_time(CATALOG_TIME_RANGE_START)},{_fmt_time(CATALOG_TIME_RANGE_END)}", - "--meta-info", ",".join(meta_fields), "--format", "csv", + _resolve_binary(), + "-s", + seq_type, + "-n", + ",".join(params), + "-t", + f"{_fmt_time(CATALOG_TIME_RANGE_START)},{_fmt_time(CATALOG_TIME_RANGE_END)}", + "--meta-info", + ",".join(meta_fields), + "--format", + "csv", *_stations_to_argv(stations), ] LOG.info("jretrieve meta: %s", " ".join(argv)) @@ -196,17 +221,31 @@ def fetch_meta( def fetch_data( - *, stations, params, start, end, increment_minutes=60, - seq_type="surface", stage="prod", timeout_s=600, + *, + stations, + params, + start, + end, + increment_minutes=60, + seq_type="surface", + stage="prod", + timeout_s=600, ) -> pd.DataFrame: """Fetch observation data; columns: station (int), termin (YYYYMMDDhhmmss), one column per requested short name.""" if not params: raise ValueError("params must be non-empty.") argv = [ - _resolve_binary(), "-s", seq_type, "-n", ",".join(params), - "-t", f"{_fmt_time(start)},{_fmt_time(end)},{int(increment_minutes)}", - "--format", "csv", *_stations_to_argv(stations), + _resolve_binary(), + "-s", + seq_type, + "-n", + ",".join(params), + "-t", + f"{_fmt_time(start)},{_fmt_time(end)},{int(increment_minutes)}", + "--format", + "csv", + *_stations_to_argv(stations), ] LOG.info("jretrieve data: %s", " ".join(argv)) return _parse_csv(_run_with_retry(argv, env=_build_env(stage), timeout_s=timeout_s)) diff --git a/tests/unit/test_jretrieve.py b/tests/unit/test_jretrieve.py index 349b971e..adfd4116 100644 --- a/tests/unit/test_jretrieve.py +++ b/tests/unit/test_jretrieve.py @@ -10,7 +10,10 @@ def test_stations_to_argv_group(): - assert jr._stations_to_argv({"group": "SwissMetNet"}) == ["-a", "stn_group,SwissMetNet"] + assert 
jr._stations_to_argv({"group": "SwissMetNet"}) == [ + "-a", + "stn_group,SwissMetNet", + ] def test_stations_to_argv_locations_from_string(): @@ -19,7 +22,8 @@ def test_stations_to_argv_locations_from_string(): def test_stations_to_argv_bbox_from_string(): assert jr._stations_to_argv({"bbox": "45.8,47.8,5.9,10.5"}) == [ - "-l", "45.8,47.8,5.9,10.5", + "-l", + "45.8,47.8,5.9,10.5", ] @@ -29,59 +33,73 @@ def test_stations_to_argv_rejects_ambiguous(): def test_parse_selection_default_group(): - assert jr.parse_selection("jretrievedwh:") == ({"group": "SwissMetNet"}, "prod", "surface") + assert jr.parse_selection("jretrievedwh:") == ( + {"group": "SwissMetNet"}, + "prod", + "surface", + ) assert jr.parse_selection("jretrievedwh:SwissMetNet") == ( - {"group": "SwissMetNet"}, "prod", "surface", + {"group": "SwissMetNet"}, + "prod", + "surface", ) def test_parse_selection_keyvalue_and_stage(): assert jr.parse_selection("jretrievedwh:locations=ARO,KLO;stage=devt") == ( - {"locations": "ARO,KLO"}, "devt", "surface", + {"locations": "ARO,KLO"}, + "devt", + "surface", ) def _sample_meta(): - return pd.DataFrame({ - "station": [2, 1, 1], - "op_since": [19900101000000, 19800101000000, 19800101000000], - "op_till": ["", "", ""], - "parameter": ["tre200s0", "fkl010z0", "tre200s0"], - "latitude": [47.48, 46.79, 46.79], - "longitude": [8.54, 9.68, 9.68], - "elev": [426.0, 1878.0, 1878.0], - "stn_name": ["Zurich", "Arosa", "Arosa"], - "nat_abbr": ["KLO", "ARO", "ARO"], - }) + return pd.DataFrame( + { + "station": [2, 1, 1], + "op_since": [19900101000000, 19800101000000, 19800101000000], + "op_till": ["", "", ""], + "parameter": ["tre200s0", "fkl010z0", "tre200s0"], + "latitude": [47.48, 46.79, 46.79], + "longitude": [8.54, 9.68, 9.68], + "elev": [426.0, 1878.0, 1878.0], + "stn_name": ["Zurich", "Arosa", "Arosa"], + "nat_abbr": ["KLO", "ARO", "ARO"], + } + ) def test_station_catalog_from_meta_collapses_and_sorts(): cat = jr.StationCatalog.from_meta(_sample_meta()) assert cat.n == 2 - 
assert list(cat.nat_abbr) == ["ARO", "KLO"] # sorted by nat_abbr + assert list(cat.nat_abbr) == ["ARO", "KLO"] # sorted by nat_abbr assert list(cat.station_id) == [1, 2] np.testing.assert_allclose(cat.latitude, [46.79, 47.48]) def test_load_obs_data_from_jretrieve(monkeypatch): - meta = pd.DataFrame({ - "station": [1, 2], - "op_since": [19800101000000, 19900101000000], - "op_till": ["", ""], - "parameter": ["tre200s0", "tre200s0"], - "latitude": [46.79, 47.48], - "longitude": [9.68, 8.54], - "elev": [1878.0, 426.0], - "stn_name": ["Arosa", "Zurich"], - "nat_abbr": ["ARO", "KLO"], - }) - data = pd.DataFrame({ - "station": [1, 1], - "termin": [20250115000000, 20250115010000], - "tre200s0": [10.0, 11.0], # degC - "fkl010z0": [3.0, 4.0], # m/s - "dkl010z0": [0.0, 90.0], # deg - }) + meta = pd.DataFrame( + { + "station": [1, 2], + "op_since": [19800101000000, 19900101000000], + "op_till": ["", ""], + "parameter": ["tre200s0", "tre200s0"], + "latitude": [46.79, 47.48], + "longitude": [9.68, 8.54], + "elev": [1878.0, 426.0], + "stn_name": ["Arosa", "Zurich"], + "nat_abbr": ["ARO", "KLO"], + } + ) + data = pd.DataFrame( + { + "station": [1, 1], + "termin": [20250115000000, 20250115010000], + "tre200s0": [10.0, 11.0], # degC + "fkl010z0": [3.0, 4.0], # m/s + "dkl010z0": [0.0, 90.0], # deg + } + ) monkeypatch.setattr(jr, "fetch_meta", lambda **kw: meta) monkeypatch.setattr(jr, "fetch_data", lambda **kw: data) @@ -93,12 +111,16 @@ def test_load_obs_data_from_jretrieve(monkeypatch): ) assert set(ds.dims) == {"time", "values"} - assert list(ds["values"].values) == ["ARO"] # KLO all-NaN -> dropped + assert list(ds["values"].values) == ["ARO"] # KLO all-NaN -> dropped assert set(ds.data_vars) == {"T_2M", "U_10M", "V_10M"} np.testing.assert_allclose(ds["T_2M"].sel(values="ARO").values, [283.15, 284.15]) # DD=0 -> U=0, V=-FF ; DD=90 -> U=-FF, V=0 - np.testing.assert_allclose(ds["U_10M"].sel(values="ARO").values, [0.0, -4.0], atol=1e-5) - 
np.testing.assert_allclose(ds["V_10M"].sel(values="ARO").values, [-3.0, 0.0], atol=1e-5) + np.testing.assert_allclose( + ds["U_10M"].sel(values="ARO").values, [0.0, -4.0], atol=1e-5 + ) + np.testing.assert_allclose( + ds["V_10M"].sel(values="ARO").values, [-3.0, 0.0], atol=1e-5 + ) np.testing.assert_allclose(ds["latitude"].values, [46.79]) From 7aa6ab91ce1aafe65beaf9ee9c6bd77cb39d6b3e Mon Sep 17 00:00:00 2001 From: clairemerker <34312518+clairemerker@users.noreply.github.com> Date: Fri, 12 Jun 2026 14:57:27 +0200 Subject: [PATCH 08/21] docs: correct SwissMetNet abbreviation SNM -> SMN --- README.md | 2 +- src/data_input/__init__.py | 2 +- src/data_input/jretrieve.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 1c48a03a..655efd6a 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,7 @@ The `truth.root` value selects how the ground truth is loaded: - **Analysis Zarr** — a path ending in `.zarr` (anemoi analysis dataset). - **PeakWeather** — a path containing `peakweather` (SwissMetNet station obs from Hugging Face). -- **DWH / jretrievedwh** — a `jretrievedwh:` marker string fetching SwissMetNet (SNM) +- **DWH / jretrievedwh** — a `jretrievedwh:` marker string fetching SwissMetNet (SMN) surface observations live from the MeteoSwiss data warehouse. Variables are mapped to ICON names in SI units (temperatures in K, pressure in Pa, precipitation as the hourly sum); wind `U_10M`/`V_10M` are derived from speed + direction. diff --git a/src/data_input/__init__.py b/src/data_input/__init__.py index 5c733033..7a9280c3 100644 --- a/src/data_input/__init__.py +++ b/src/data_input/__init__.py @@ -383,7 +383,7 @@ def _jretrieve_df_to_xarray(df, short_names, catalog) -> xr.Dataset: def load_obs_data_from_jretrieve( root, reftime: datetime, steps: list[int], params: list[str] ) -> xr.Dataset: - """Load SwissMetNet (SNM) surface observations from the DWH via jretrievedwh. 
+ """Load SwissMetNet (SMN) surface observations from the DWH via jretrievedwh. ``root`` is a marker string selecting stations, e.g. ``jretrievedwh:SwissMetNet`` (default group), ``jretrievedwh:locations=ARO,KLO``, or diff --git a/src/data_input/jretrieve.py b/src/data_input/jretrieve.py index b95dafbb..a7b073b9 100644 --- a/src/data_input/jretrieve.py +++ b/src/data_input/jretrieve.py @@ -1,5 +1,5 @@ """Subprocess wrapper around ``jretrievedwh.py`` for retrieving SwissMetNet -(SNM) surface observations from the MeteoSwiss data warehouse (DWH). +(SMN) surface observations from the MeteoSwiss data warehouse (DWH). Ported/adapted from MeteoSwiss/anemoi-plugins-meteoswiss (add-synop-dwh-source). Requires ``jretrievedwh.py`` on $PATH and $OPR_HOME set with a readable From 5cd0fa8fe0595558ce0acf40bda999fa169fb684 Mon Sep 17 00:00:00 2001 From: clairemerker <34312518+clairemerker@users.noreply.github.com> Date: Fri, 12 Jun 2026 15:20:03 +0200 Subject: [PATCH 09/21] feat(jretrieve): fail-fast check for jretrievedwh.py on PATH + OPR_HOME Validate the DWH prerequisites (binary on PATH, OPR_HOME set, conf file readable) at workflow-build time so a misconfigured environment aborts at launch instead of hours into the run, and again at loader entry for the authoritative job environment. Errors aggregate all problems at once. 
--- src/data_input/__init__.py | 1 + src/data_input/jretrieve.py | 33 +++++++++++++++++++++++++++++++++ tests/unit/test_jretrieve.py | 26 ++++++++++++++++++++++++++ workflow/rules/common.smk | 10 ++++++++++ 4 files changed, 70 insertions(+) diff --git a/src/data_input/__init__.py b/src/data_input/__init__.py index 7a9280c3..ab088c49 100644 --- a/src/data_input/__init__.py +++ b/src/data_input/__init__.py @@ -395,6 +395,7 @@ def load_obs_data_from_jretrieve( from data_input import jretrieve as jr stations, stage, seq_type = jr.parse_selection(root) + jr.check_prerequisites(stage) want_uv = "U_10M" in params or "V_10M" in params short_names: list[str] = [DWH_PARAM_MAP[p] for p in params if p in DWH_PARAM_MAP] diff --git a/src/data_input/jretrieve.py b/src/data_input/jretrieve.py index a7b073b9..c2ccbc24 100644 --- a/src/data_input/jretrieve.py +++ b/src/data_input/jretrieve.py @@ -65,6 +65,39 @@ def _build_env(stage: str) -> dict[str, str]: return env +def check_prerequisites(stage: str = "prod") -> None: + """Fail-fast validation that the jretrievedwh environment is usable. + + Verifies the CLI is on $PATH, $OPR_HOME is set, and the conf file for + ``stage`` exists and is readable. Raises a single ``JretrieveError`` listing + *all* problems found, so a misconfigured environment is reported up front + (e.g. at workflow launch) instead of hours later inside the verification job. + """ + problems: list[str] = [] + if shutil.which(BINARY_NAME) is None: + problems.append( + f"{BINARY_NAME} not found in $PATH " + "(e.g. add /oprusers/osm/opr.inn/bin to $PATH)." + ) + opr_home = os.environ.get("OPR_HOME") + if not opr_home: + problems.append("$OPR_HOME is not set.") + elif stage not in VALID_STAGES: + problems.append( + f"Invalid stage {stage!r}; must be one of {sorted(VALID_STAGES)}." 
+ ) + else: + conf_path = os.path.join(opr_home, f".jretrievedwh-conf.{stage}.py") + if not os.path.isfile(conf_path): + problems.append(f"jretrieve conf file not found: {conf_path}") + elif not os.access(conf_path, os.R_OK): + problems.append(f"jretrieve conf file not readable: {conf_path}") + if problems: + raise JretrieveError( + "jretrievedwh prerequisites not met:\n - " + "\n - ".join(problems) + ) + + def _fmt_time(dt: datetime) -> str: return dt.strftime("%Y%m%d%H%M") diff --git a/tests/unit/test_jretrieve.py b/tests/unit/test_jretrieve.py index adfd4116..7a1d16a2 100644 --- a/tests/unit/test_jretrieve.py +++ b/tests/unit/test_jretrieve.py @@ -53,6 +53,32 @@ def test_parse_selection_keyvalue_and_stage(): ) +def test_check_prerequisites_ok(monkeypatch, tmp_path): + conf = tmp_path / ".jretrievedwh-conf.prod.py" + conf.write_text("# conf\n") + monkeypatch.setattr(jr.shutil, "which", lambda name: "/opt/bin/jretrievedwh.py") + monkeypatch.setenv("OPR_HOME", str(tmp_path)) + jr.check_prerequisites("prod") # should not raise + + +def test_check_prerequisites_missing_binary(monkeypatch, tmp_path): + conf = tmp_path / ".jretrievedwh-conf.prod.py" + conf.write_text("# conf\n") + monkeypatch.setattr(jr.shutil, "which", lambda name: None) + monkeypatch.setenv("OPR_HOME", str(tmp_path)) + with pytest.raises(jr.JretrieveError, match=r"\$PATH"): + jr.check_prerequisites("prod") + + +def test_check_prerequisites_aggregates_all_problems(monkeypatch): + monkeypatch.setattr(jr.shutil, "which", lambda name: None) + monkeypatch.delenv("OPR_HOME", raising=False) + with pytest.raises(jr.JretrieveError) as exc: + jr.check_prerequisites("prod") + msg = str(exc.value) + assert "$PATH" in msg and "$OPR_HOME" in msg # both reported, not just the first + + def _sample_meta(): return pd.DataFrame( { diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk index d76e9c53..8b2f8968 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -356,6 +356,16 @@ def 
truth_file_dep(_): return [] if "jretrieve" in str(root) else [root] +# Fail fast: when the truth source is the live DWH (jretrievedwh), verify its +# prerequisites at workflow-build time so a misconfigured environment is caught +# at launch, before any (expensive) inference job runs. +if "jretrieve" in str(config["truth"]["root"]): + from data_input.jretrieve import check_prerequisites, parse_selection + + _, _jretrieve_stage, _ = parse_selection(config["truth"]["root"]) + check_prerequisites(_jretrieve_stage) + + TRUTH_HASH = truth_hash(config["truth"]) REGIONS = parse_regions() SHOWCASE_REGIONS = parse_showcase_regions() From ba8858727b9170cf275e7d1789f84c5ff16acb2f Mon Sep 17 00:00:00 2001 From: Jonas Bhend Date: Fri, 12 Jun 2026 15:48:18 +0200 Subject: [PATCH 10/21] fix issue with meas_group/stn_group --- README.md | 2 +- config/temporal-downscalers-ich1.yaml | 28 ++++++--------------------- src/data_input/jretrieve.py | 2 +- tests/unit/test_jretrieve.py | 4 ++-- 4 files changed, 10 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 655efd6a..a8072472 100644 --- a/README.md +++ b/README.md @@ -161,7 +161,7 @@ The `truth.root` value selects how the ground truth is loaded: ```yaml truth: label: SwissMetNet (DWH) - root: jretrievedwh:SwissMetNet # stn_group (default group) + root: jretrievedwh:SMN # meas_group (default group) # root: jretrievedwh:locations=ARO,KLO,LUG # explicit nat_abbr list # root: jretrievedwh:bbox=45.8,47.8,5.9,10.5 # minlat,maxlat,minlon,maxlon # append ;stage=devt to target a non-prod DWH stage (prod|depl|devt) diff --git a/config/temporal-downscalers-ich1.yaml b/config/temporal-downscalers-ich1.yaml index 84bbef5f..3199ef1c 100644 --- a/config/temporal-downscalers-ich1.yaml +++ b/config/temporal-downscalers-ich1.yaml @@ -6,25 +6,9 @@ description: | dates: start: 2025-03-01T00:00 end: 2025-03-03T00:00 - frequency: 24h + frequency: 48h runs: - - temporal_downscaler: - checkpoint: 
/scratch/mch/miccatta/ICON_interpolator_checkpoints/checkpoint_stage-C-interpolator-n320-6hto1h-reduced-variables/f9279244ed6f4c458597bdcf335ab36f/inference-last.ckpt - label: Varda-Single - steps: 0/120/1 - config: resources/inference/configs/sgm-temporal-downscaler-global_trimedge_multi.yaml - extra_requirements: - - anemoi-datasets==0.5.35 - # - anemoi-inference==0.11.0 - forecaster: - checkpoint: https://service.meteoswiss.ch/mlstore#/experiments/602/runs/c30490b6ba064e4db03b430f3a2595ad - config: resources/inference/configs/sgm-multidataset-forecaster-global-ich1-oper.yaml - steps: 0/120/6 - - baseline: - label: INCA - root: /store_new/mch/msclim/INCA - steps: 0/6/1 - baseline: label: ICON-CH2-CTRL root: /store_new/mch/msopr/osm/ICON-CH2-EPS @@ -36,11 +20,11 @@ runs: truth: label: SwissMetNet - root: output/data/observations/peakweather + root: jretrievedwh:SMN # To verify against SwissMetNet observations from the DWH via jretrievedwh, # set instead (requires jretrievedwh.py on $PATH and $OPR_HOME set): - # label: SwissMetNet (DWH) - # root: jretrievedwh:SwissMetNet + # label: SwissMetNet + # root: jretrievedwh:SMN # Other selectors: root: jretrievedwh:locations=ARO,KLO,LUG # root: jretrievedwh:bbox=45.8,47.8,5.9,10.5 # append ;stage=devt to target a non-prod DWH stage @@ -76,7 +60,7 @@ experiment: # - init_hour - season scorecards: - enabled: true + enabled: false sections: nowcasting: baseline: INCA @@ -127,7 +111,7 @@ locations: output_root: output/ profile: - executor: slurm + executor: local global_resources: gpus: 16 default_resources: diff --git a/src/data_input/jretrieve.py b/src/data_input/jretrieve.py index c2ccbc24..9aa91202 100644 --- a/src/data_input/jretrieve.py +++ b/src/data_input/jretrieve.py @@ -115,7 +115,7 @@ def _stations_to_argv(stations: dict[str, Any]) -> list[str]: key = keys[0] val = stations[key] if key == "group": - return ["-a", f"stn_group,{val}"] + return ["-a", f"stn_group_id,{val}"] if key == "locations": if isinstance(val, 
str): val = [v for v in val.split(",") if v] diff --git a/tests/unit/test_jretrieve.py b/tests/unit/test_jretrieve.py index 7a1d16a2..dbdaed80 100644 --- a/tests/unit/test_jretrieve.py +++ b/tests/unit/test_jretrieve.py @@ -10,9 +10,9 @@ def test_stations_to_argv_group(): - assert jr._stations_to_argv({"group": "SwissMetNet"}) == [ + assert jr._stations_to_argv({"group": "1,2"}) == [ "-a", - "stn_group,SwissMetNet", + "stn_group_id,1,2", ] From 2182e3b047a22663c69d60aa315fb0666c6591c7 Mon Sep 17 00:00:00 2001 From: Jonas Bhend Date: Fri, 12 Jun 2026 16:15:40 +0200 Subject: [PATCH 11/21] fix inadvertent change to config and update readme --- README.md | 2 +- config/temporal-downscalers-ich1.yaml | 26 ++++++++++++++++++++------ 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index a8072472..816857f5 100644 --- a/README.md +++ b/README.md @@ -161,7 +161,7 @@ The `truth.root` value selects how the ground truth is loaded: ```yaml truth: label: SwissMetNet (DWH) - root: jretrievedwh:SMN # meas_group (default group) + root: jretrievedwh:1,2 # stn_group_id (default) # root: jretrievedwh:locations=ARO,KLO,LUG # explicit nat_abbr list # root: jretrievedwh:bbox=45.8,47.8,5.9,10.5 # minlat,maxlat,minlon,maxlon # append ;stage=devt to target a non-prod DWH stage (prod|depl|devt) diff --git a/config/temporal-downscalers-ich1.yaml b/config/temporal-downscalers-ich1.yaml index 3199ef1c..08aee647 100644 --- a/config/temporal-downscalers-ich1.yaml +++ b/config/temporal-downscalers-ich1.yaml @@ -6,9 +6,25 @@ description: | dates: start: 2025-03-01T00:00 end: 2025-03-03T00:00 - frequency: 48h + frequency: 24h runs: + - temporal_downscaler: + checkpoint: /scratch/mch/miccatta/ICON_interpolator_checkpoints/checkpoint_stage-C-interpolator-n320-6hto1h-reduced-variables/f9279244ed6f4c458597bdcf335ab36f/inference-last.ckpt + label: Varda-Single + steps: 0/120/1 + config: resources/inference/configs/sgm-temporal-downscaler-global_trimedge_multi.yaml + 
extra_requirements: + - anemoi-datasets==0.5.35 + # - anemoi-inference==0.11.0 + forecaster: + checkpoint: https://service.meteoswiss.ch/mlstore#/experiments/602/runs/c30490b6ba064e4db03b430f3a2595ad + config: resources/inference/configs/sgm-multidataset-forecaster-global-ich1-oper.yaml + steps: 0/120/6 + - baseline: + label: INCA + root: /store_new/mch/msclim/INCA + steps: 0/6/1 - baseline: label: ICON-CH2-CTRL root: /store_new/mch/msopr/osm/ICON-CH2-EPS @@ -20,11 +36,9 @@ runs: truth: label: SwissMetNet - root: jretrievedwh:SMN + root: jretrievedwh:1,2 # To verify against SwissMetNet observations from the DWH via jretrievedwh, # set instead (requires jretrievedwh.py on $PATH and $OPR_HOME set): - # label: SwissMetNet - # root: jretrievedwh:SMN # Other selectors: root: jretrievedwh:locations=ARO,KLO,LUG # root: jretrievedwh:bbox=45.8,47.8,5.9,10.5 # append ;stage=devt to target a non-prod DWH stage @@ -60,7 +74,7 @@ experiment: # - init_hour - season scorecards: - enabled: false + enabled: true sections: nowcasting: baseline: INCA @@ -111,7 +125,7 @@ locations: output_root: output/ profile: - executor: local + executor: slurm global_resources: gpus: 16 default_resources: From 54da55e713fa5c6e937ad4617baef507f1674f00 Mon Sep 17 00:00:00 2001 From: clairemerker <34312518+clairemerker@users.noreply.github.com> Date: Fri, 12 Jun 2026 16:24:49 +0200 Subject: [PATCH 12/21] test(jretrieve): mock check_prerequisites in loader test for CI The loader calls check_prerequisites(), which probes for jretrievedwh.py on $PATH and $OPR_HOME. GitHub CI has neither, so the test failed there while passing locally. Mock it like the other DWH calls so the test is environment independent. 
--- tests/unit/test_jretrieve.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/unit/test_jretrieve.py b/tests/unit/test_jretrieve.py index dbdaed80..6625084a 100644 --- a/tests/unit/test_jretrieve.py +++ b/tests/unit/test_jretrieve.py @@ -126,6 +126,9 @@ def test_load_obs_data_from_jretrieve(monkeypatch): "dkl010z0": [0.0, 90.0], # deg } ) + # Skip the live environment check so the test runs anywhere (incl. CI, + # which has no jretrievedwh.py / $OPR_HOME). + monkeypatch.setattr(jr, "check_prerequisites", lambda *a, **k: None) monkeypatch.setattr(jr, "fetch_meta", lambda **kw: meta) monkeypatch.setattr(jr, "fetch_data", lambda **kw: data) From 5e0caec0e25acd4beff238f833caa1a531f97367 Mon Sep 17 00:00:00 2001 From: Jonas Bhend Date: Fri, 12 Jun 2026 18:09:10 +0200 Subject: [PATCH 13/21] remove peakweather --- README.md | 7 ++- pyproject.toml | 2 - src/data_input/__init__.py | 71 +----------------------------- workflow/rules/common.smk | 2 +- workflow/rules/data.smk | 17 ------- workflow/rules/plot.smk | 10 ++--- workflow/scripts/plot_meteogram.py | 51 +++++++++++++-------- 7 files changed, 42 insertions(+), 118 deletions(-) diff --git a/README.md b/README.md index 816857f5..2a26b0a5 100644 --- a/README.md +++ b/README.md @@ -150,9 +150,8 @@ evalml experiment path/to/experiment/config.yaml --report The `truth.root` value selects how the ground truth is loaded: - **Analysis Zarr** — a path ending in `.zarr` (anemoi analysis dataset). -- **PeakWeather** — a path containing `peakweather` (SwissMetNet station obs from Hugging Face). -- **DWH / jretrievedwh** — a `jretrievedwh:` marker string fetching SwissMetNet (SMN) - surface observations live from the MeteoSwiss data warehouse. Variables are mapped to +- **DWH / jretrievedwh** — a `jretrievedwh:` marker string fetching surface observations + (e.g. SMN) live from the MeteoSwiss data warehouse. 
Variables are mapped to ICON names in SI units (temperatures in K, pressure in Pa, precipitation as the hourly sum); wind `U_10M`/`V_10M` are derived from speed + direction. @@ -160,7 +159,7 @@ The `truth.root` value selects how the ground truth is loaded: ```yaml truth: - label: SwissMetNet (DWH) + label: SwissMetNet root: jretrievedwh:1,2 # stn_group_id (default) # root: jretrievedwh:locations=ARO,KLO,LUG # explicit nat_abbr list # root: jretrievedwh:bbox=45.8,47.8,5.9,10.5 # minlat,maxlat,minlon,maxlon diff --git a/pyproject.toml b/pyproject.toml index d3279e91..d8552f6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,6 @@ dependencies = [ "pyproj>=3.7.2", "marimo>=0.23.3", "geopandas>=0.14.0", - "peakweather", "pyzmq>=27.1.0", "scores>=2.0.0", "eccodes>=2.44,<2.48", @@ -60,5 +59,4 @@ packages = [ ] [tool.uv.sources] -peakweather = { git = "https://github.com/MeteoSwiss/PeakWeather.git" } eccodes-cosmo-resources-python = { git = "https://github.com/MeteoSwiss/eccodes-cosmo-resources-python" } diff --git a/src/data_input/__init__.py b/src/data_input/__init__.py index ab088c49..ac486264 100644 --- a/src/data_input/__init__.py +++ b/src/data_input/__init__.py @@ -265,67 +265,6 @@ def load_forecast_data_from_grib(files: list[Path], params: list[str]) -> xr.Dat return ds -def load_obs_data_from_peakweather( - root, reftime: datetime, steps: list[int], params: list[str], freq: str = "1h" -) -> xr.Dataset: - """Load PeakWeather station observations into an xarray Dataset. - - Returns a Dataset with dimensions `time` and `values`, values coordinates - (`lat`, `lon`), and variables renamed to ICON parameter names. - Temperatures are converted to Kelvin when present. 
- """ - from peakweather.dataset import PeakWeatherDataset - - param_names = { - "temperature": "T_2M", - "wind_u": "U_10M", - "wind_v": "V_10M", - "precipitation": "TOT_PREC", - "pressure": "PS", - "wind_gust": "VMAX_10M", - } - param_names = {k: v for k, v in param_names.items() if v in params} - start = reftime - end = start + timedelta(hours=max(steps)) - if len(steps) > 1: - end += timedelta(hours=steps[-1] - steps[-2]) # extend by 1 extra step - years = list(set([start.year, end.year])) - if "wind_u" in param_names or "wind_v" in param_names: - compute_uv = True - else: - compute_uv = False - pw = PeakWeatherDataset(root=root, years=years, freq=freq, compute_uv=compute_uv) - ds, mask = pw.get_observations( - parameters=[k for k in param_names.keys()], - first_date=f"{start:%Y-%m-%d %H:%M}", - last_date=f"{end:%Y-%m-%d %H:%M}", - return_mask=True, - ) - ds = ( - ds.stack(["nat_abbr", "name"], future_stack=True) - .to_xarray() - .to_dataset(dim="name") - ) - mask = ( - mask.stack(["nat_abbr", "name"], future_stack=True) - .to_xarray() - .to_dataset(dim="name") - ) - ds = ds.where(mask) - ds = ds.rename({"datetime": "time", "nat_abbr": "values"}) - ds = ds.rename(param_names) - ds = ds.assign_coords(time=ds.indexes["time"].tz_convert("UTC").tz_localize(None)) - ds = ds.assign_coords(values=ds.indexes["values"]) - ds = ds.assign_coords(longitude=("values", pw.stations_table["longitude"])) - ds = ds.assign_coords(latitude=("values", pw.stations_table["latitude"])) - if "T_2M" in ds: - ds["T_2M"] = ds["T_2M"] - ZERO_KELVIN # convert to Kelvin - ds = ds.dropna("values", how="all") - - times = np.datetime64(reftime) + np.asarray(steps, dtype="timedelta64[h]") - return _select_valid_times(ds, times) - - DWH_PARAM_MAP = { "T_2M": "tre200s0", "TD_2M": "tde200s0", @@ -451,7 +390,7 @@ def load_obs_data_from_jretrieve( def load_truth_data( root, reftime: datetime, steps: list[int], params: list[str] ) -> xr.Dataset: - """Load truth data from analysis Zarr or PeakWeather 
observations.""" + """Load truth data from an analysis Zarr dataset or DWH observations via jretrieve.""" if root.suffix == ".zarr": LOG.info("Loading ground truth from an analysis zarr dataset...") truth = load_analysis_data_from_zarr( @@ -465,14 +404,6 @@ def load_truth_data( if "y" in truth.dims and "x" in truth.dims else {"values": -1} ) - elif "peakweather" in str(root): - LOG.info("Loading ground truth from PeakWeather observations...") - truth = load_obs_data_from_peakweather( - root=root, - reftime=reftime, - steps=steps, - params=params, - ) elif "jretrieve" in str(root): LOG.info("Loading ground truth from JRetrieve...") truth = load_obs_data_from_jretrieve( diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk index 8b2f8968..f5395378 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -350,7 +350,7 @@ def truth_hash(truth_config: dict) -> str: def truth_file_dep(_): - """Truth file dependency: a real path for zarr/peakweather, but a live-query + """Truth file dependency: a real path for zarr, but a live-query marker (no input file) for jretrieve.""" root = config["truth"]["root"] return [] if "jretrieve" in str(root) else [root] diff --git a/workflow/rules/data.smk b/workflow/rules/data.smk index 67914a9f..eb41eff9 100644 --- a/workflow/rules/data.smk +++ b/workflow/rules/data.smk @@ -4,23 +4,6 @@ from pathlib import Path include: "common.smk" -if config["truth"]["root"].endswith("peakweather"): - output_peakweather_root = config["truth"]["root"] -else: - output_peakweather_root = OUT_ROOT / "data/observations/peakweather" - - -rule data_download_obs_from_peakweather: - output: - root=directory(output_peakweather_root), - localrule: True - run: - from peakweather.dataset import PeakWeatherDataset - - # Download the data from Huggingface - ds = PeakWeatherDataset(root=output.root) - - # Grid-definition files required by earthkit/eckit to decode the ICON-CH grids. 
# Automatic download/caching of these is currently broken (see README), so we # fetch them into eckit's default geo grid cache under $HOME, where it finds diff --git a/workflow/rules/plot.smk b/workflow/rules/plot.smk index 91ff1080..5b5a7c97 100644 --- a/workflow/rules/plot.smk +++ b/workflow/rules/plot.smk @@ -27,14 +27,14 @@ rule plot_meteogram: input: script="workflow/scripts/plot_meteogram.py", inference_okfile=rules.inference_execute.output.okfile, - truth=config["truth"]["root"], - peakweather_dir=rules.data_download_obs_from_peakweather.output.root, + truth_dep=truth_file_dep, eckit_grids=rules.data_download_eckit_geo_grids.output, output: expand( OUT_ROOT - / "results/{{showcase}}/{{run_id}}/{{init_time}}/{{init_time}}_{{param}}_{sta}.png", + / "results/{{showcase}}/{{run_id}}/{{init_time}}/{{init_time}}_{{param}}_{sta}_{truth_hash}.png", sta=config["showcase"]["meteograms"]["stations"], + truth_hash=TRUTH_HASH, ), log: OUT_ROOT / "logs/{showcase}/{run_id}/{init_time}/plot_meteogram_{param}.log", @@ -44,6 +44,7 @@ rule plot_meteogram: runtime="60m", params: ana_label=lambda wc: config["truth"]["label"], + truth_root=config["truth"]["root"], fcst_grib=lambda wc: ( Path(OUT_ROOT) / f"data/runs/{wc.run_id}/{wc.init_time}/grib" ).resolve(), @@ -71,9 +72,8 @@ rule plot_meteogram: --forecast {params.fcst_grib:q} --forecast_steps {params.fcst_steps:q} --forecast_label {params.fcst_label:q} - --analysis {input.truth:q} + --analysis {params.truth_root:q} --analysis_label {params.ana_label:q} - --peakweather {input.peakweather_dir:q} --date {wildcards.init_time:q} --outdir {params.outdir:q} --param {wildcards.param:q} diff --git a/workflow/scripts/plot_meteogram.py b/workflow/scripts/plot_meteogram.py index 8e7ac30e..43248d73 100644 --- a/workflow/scripts/plot_meteogram.py +++ b/workflow/scripts/plot_meteogram.py @@ -4,13 +4,14 @@ from pathlib import Path import matplotlib.pyplot as plt -from peakweather import PeakWeatherDataset +import xarray as xr from data_input 
import ( parse_steps, load_forecast_data, load_truth_data, ) +from data_input import jretrieve as jr from verification.spatial import map_forecast_to_truth LOG = logging.getLogger(__name__) @@ -96,9 +97,6 @@ def main(): default="truth", help="Label for analysis line in plot legend.", ) - parser.add_argument( - "--peakweather", type=str, default=None, help="Path to PeakWeather dataset" - ) parser.add_argument("--date", type=str, default=None, help="reference datetime") parser.add_argument("--outdir", type=str, help="output directory") parser.add_argument("--param", type=str, help="parameter") @@ -124,7 +122,6 @@ def main(): "Mismatched baseline arguments: --baseline and --baseline_label " "must be provided the same number of times." ) - peakweather_dir = Path(args.peakweather) init_time = datetime.strptime(args.date, "%Y%m%d%H%M") outdir = Path(args.outdir) outdir.mkdir(parents=True, exist_ok=True) @@ -145,6 +142,28 @@ def main(): else: paramlist = [param] + # Load station metadata from DWH + LOG.info("Fetching station metadata from jretrieve (SwissMetNet catalog)") + _jr_stations, _jr_stage, _jr_seq_type = jr.parse_selection("jretrievedwh:1,2") + _catalog = jr.StationCatalog.from_meta( + jr.fetch_meta( + stations=_jr_stations, + params=["rre150h0"], + seq_type=_jr_seq_type, + stage=_jr_stage, + ) + ) + catalog_lookup = { + abbr: (lat, lon) + for abbr, lat, lon in zip( + _catalog.nat_abbr, _catalog.latitude, _catalog.longitude + ) + } + + LOG.info("Loading analysis data from %s", analysis_root) + analysis_ds = load_truth_data(analysis_root, init_time, forecast_steps, paramlist) + analysis_ds = preprocess_ds(analysis_ds, param) + # Load gridded data once — shared across all station plots LOG.info("Loading forecast data from %s", forecast_grib_dir) forecast_ds = load_forecast_data( @@ -152,11 +171,6 @@ def main(): ) forecast_ds = preprocess_ds(forecast_ds, param) - steps = [int(s) for s in forecast_ds["step"].dt.total_seconds().values / 3600] - LOG.info("Loading 
analysis data from %s", analysis_root) - analysis_ds = load_truth_data(analysis_root, init_time, steps, paramlist) - analysis_ds = preprocess_ds(analysis_ds, param) - baseline_ds_list = [] for root, step, label in zip(baseline_roots, baseline_steps, baseline_labels): LOG.info("Loading baseline '%s' from %s", label, root) @@ -164,12 +178,6 @@ def main(): preprocess_ds(load_forecast_data(root, init_time, step, paramlist), param) ) - # Load station metadata once - LOG.info("Loading station metadata from %s", peakweather_dir) - peakweather = PeakWeatherDataset(root=peakweather_dir) - stations_table = peakweather.stations_table - stations_table.index.names = ["values"] - param2plot = forecast_ds[param].attrs.get("parameter", {}) short = param2plot.get("shortName", "") units = param2plot.get("units", "") @@ -183,9 +191,14 @@ def main(): stations.index(station) + 1, len(stations), ) - station_ds = stations_table.to_xarray().sel(values=[station]) - station_ds = station_ds.set_coords(("latitude", "longitude", "station_name")) - station_ds = station_ds.drop_vars(list(station_ds.data_vars)) + lat, lon = catalog_lookup[station] + station_ds = xr.Dataset( + coords={ + "values": [station], + "latitude": ("values", [lat]), + "longitude": ("values", [lon]), + } + ) forecast_station_ds = map_forecast_to_truth(forecast_ds, station_ds) analysis_station_ds = map_forecast_to_truth(analysis_ds, station_ds) From a2e4e9910b8e332c494a85e706a2f1e308b7e469 Mon Sep 17 00:00:00 2001 From: Jonas Bhend Date: Fri, 12 Jun 2026 19:30:12 +0200 Subject: [PATCH 14/21] remove truth hash, as this is redundant (and not used) --- workflow/rules/plot.smk | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/workflow/rules/plot.smk b/workflow/rules/plot.smk index 5b5a7c97..b855662e 100644 --- a/workflow/rules/plot.smk +++ b/workflow/rules/plot.smk @@ -32,9 +32,8 @@ rule plot_meteogram: output: expand( OUT_ROOT - / 
"results/{{showcase}}/{{run_id}}/{{init_time}}/{{init_time}}_{{param}}_{sta}_{truth_hash}.png", + / "results/{{showcase}}/{{run_id}}/{{init_time}}/{{init_time}}_{{param}}_{sta}.png", sta=config["showcase"]["meteograms"]["stations"], - truth_hash=TRUTH_HASH, ), log: OUT_ROOT / "logs/{showcase}/{run_id}/{init_time}/plot_meteogram_{param}.log", From 62c6cf8f514baacba088c825e97bb9c2609cdaa3 Mon Sep 17 00:00:00 2001 From: Jonas Bhend Date: Mon, 15 Jun 2026 07:05:06 +0200 Subject: [PATCH 15/21] only retrieve necessary timesteps --- src/data_input/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/data_input/__init__.py b/src/data_input/__init__.py index ac486264..5e431533 100644 --- a/src/data_input/__init__.py +++ b/src/data_input/__init__.py @@ -354,12 +354,13 @@ def load_obs_data_from_jretrieve( stations=stations, params=short_names, seq_type=seq_type, stage=stage ) ) + step_hours = (steps[1] - steps[0]) if len(steps) > 1 else 1 df = jr.fetch_data( stations=stations, params=short_names, start=start, end=end, - increment_minutes=60, + increment_minutes=step_hours * 60, seq_type=seq_type, stage=stage, ) From 4c92452c81be471ace91212869c0755ae0b8e388 Mon Sep 17 00:00:00 2001 From: Jonas Bhend Date: Mon, 15 Jun 2026 07:08:57 +0200 Subject: [PATCH 16/21] fail with error if not all time steps are available --- src/data_input/__init__.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/data_input/__init__.py b/src/data_input/__init__.py index 5e431533..0317acfd 100644 --- a/src/data_input/__init__.py +++ b/src/data_input/__init__.py @@ -20,16 +20,21 @@ ZERO_KELVIN = -273.15 # °C -def _select_valid_times(ds, times: np.datetime64): +def _select_valid_times(ds, times: np.datetime64, strict: bool = False): # (handle special case where some valid times are not in the dataset, e.g. 
at the end) times_np = np.asarray(times, dtype="datetime64[ns]") times_included = np.isin(times_np, ds.time.values) if times_included.all(): return ds.sel(time=times_np) elif times_included.any(): + missing = times_np[~times_included] + if strict: + raise ValueError( + f"Some valid times are not included in the dataset:\n{missing}" + ) LOG.warning( "Some valid times are not included in the dataset: \n%s", - times_np[~times_included], + missing, ) return ds.sel(time=times_np[times_included]) else: @@ -385,7 +390,7 @@ def load_obs_data_from_jretrieve( out = out.dropna("values", how="all") times = np.datetime64(reftime) + np.asarray(steps, dtype="timedelta64[h]") - return _select_valid_times(out, times) + return _select_valid_times(out, times, strict=True) def load_truth_data( From 0add96573e2cd6e361cd0277329ea5e10eb1f085 Mon Sep 17 00:00:00 2001 From: Jonas Bhend Date: Mon, 15 Jun 2026 08:02:28 +0200 Subject: [PATCH 17/21] use dedicated varda credentials for jretrieve --- .jretrievedwh-conf.prod.py | 48 +++++++++++++++++++++++ README.md | 27 +++++++++++-- src/data_input/jretrieve.py | 76 +++++++++++++++---------------------- 3 files changed, 101 insertions(+), 50 deletions(-) create mode 100644 .jretrievedwh-conf.prod.py diff --git a/.jretrievedwh-conf.prod.py b/.jretrievedwh-conf.prod.py new file mode 100644 index 00000000..abafe594 --- /dev/null +++ b/.jretrievedwh-conf.prod.py @@ -0,0 +1,48 @@ +import base64 +import json +import os +import urllib.request +from pathlib import Path + + +def _read_dotenv(path): + result = {} + try: + with open(path) as f: + for line in f: + line = line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + key, _, value = line.partition("=") + result[key.strip()] = value.strip().strip('"').strip("'") + except FileNotFoundError: + pass + return result + + +_client_id = os.environ.get("JRETRIEVE_CLIENT_ID") +_client_secret = os.environ.get("JRETRIEVE_CLIENT_SECRET") +if not _client_id or not _client_secret: + 
_dotenv = _read_dotenv(Path(os.environ.get("JRETRIEVE_CONF_DIR", ".")) / ".env") + _client_id = _client_id or _dotenv.get("JRETRIEVE_CLIENT_ID") + _client_secret = _client_secret or _dotenv.get("JRETRIEVE_CLIENT_SECRET") +if not _client_id or not _client_secret: + raise RuntimeError( + "jretrieve credentials not found. Set JRETRIEVE_CLIENT_ID and " + "JRETRIEVE_CLIENT_SECRET in the environment or in a .env file next " + "to this script." + ) + +jretrieve_url = "https://service.meteoswiss.ch/jretrieve/api/v1" +with urllib.request.urlopen( + urllib.request.Request( + method="POST", + url="https://service.meteoswiss.ch/auth/realms/meteoswiss.ch/protocol/openid-connect/token", + data=b"grant_type=client_credentials", + headers={ + b"Authorization": b"Basic " + + base64.b64encode(f"{_client_id}:{_client_secret}".encode()) + }, + ) +) as f: + auth_header = "Bearer " + json.loads(f.read().decode())["access_token"] diff --git a/README.md b/README.md index 2a26b0a5..a9ba3440 100644 --- a/README.md +++ b/README.md @@ -166,10 +166,10 @@ The `truth.root` value selects how the ground truth is loaded: # append ;stage=devt to target a non-prod DWH stage (prod|depl|devt) ``` - **Prerequisites:** `jretrievedwh.py` must be on `$PATH` (e.g. - `/oprusers/osm/opr.inn/bin`) and `$OPR_HOME` set with a readable - `.jretrievedwh-conf..py` conf file. No data is pre-downloaded — the obs are - queried at verification time. + **Prerequisites:** `jretrievedwh.py` must be on `$PATH` (falls back to + `/oprusers/osm/opr.inn/bin/jretrievedwh.py`) and DWH credentials must be + available — see [Credentials setup](#credentials-setup). No data is + pre-downloaded — the obs are queried at verification time. ## Installation @@ -201,6 +201,25 @@ valid for 30 days. Every training or evaluation run within this period automatic extends the token by another 30 days. It’s good practice to run the login command before executing the workflow to ensure your token is still valid. 
+### DWH / jretrieve credentials + +To use `jretrievedwh:` as a truth source, provide a client ID and secret for +the MeteoSwiss service account. Set them as environment variables: + +```bash +export JRETRIEVE_CLIENT_ID=your-client-id +export JRETRIEVE_CLIENT_SECRET=your-client-secret +``` + +or place them in a `.env` file at the project root (next to `.jretrievedwh-conf.prod.py`): + +``` +JRETRIEVE_CLIENT_ID=your-client-id +JRETRIEVE_CLIENT_SECRET=your-client-secret +``` + +The `.env` file is listed in `.gitignore` and is never committed. + ## Workspace setup By default, data produced by the workflow will be stored under `output/` in your working directory. diff --git a/src/data_input/jretrieve.py b/src/data_input/jretrieve.py index 9aa91202..cbf00ad9 100644 --- a/src/data_input/jretrieve.py +++ b/src/data_input/jretrieve.py @@ -1,9 +1,10 @@ """Subprocess wrapper around ``jretrievedwh.py`` for retrieving SwissMetNet (SMN) surface observations from the MeteoSwiss data warehouse (DWH). -Ported/adapted from MeteoSwiss/anemoi-plugins-meteoswiss (add-synop-dwh-source). -Requires ``jretrievedwh.py`` on $PATH and $OPR_HOME set with a readable -``.jretrievedwh-conf..py`` conf file. +Requires ``jretrievedwh.py`` (resolved via $PATH or the hardcoded fallback +path) and credentials set as ``JRETRIEVE_CLIENT_ID`` / +``JRETRIEVE_CLIENT_SECRET`` in the environment or in a ``.env`` file at the +project root (next to ``.jretrievedwh-conf.prod.py``). """ from __future__ import annotations @@ -16,6 +17,7 @@ from dataclasses import dataclass from datetime import datetime from io import StringIO +from pathlib import Path from typing import Any, Sequence import numpy as np @@ -24,7 +26,7 @@ LOG = logging.getLogger(__name__) BINARY_NAME = "jretrievedwh.py" -VALID_STAGES = {"prod", "depl", "devt"} +HARDCODED_BINARY_PATH = "/oprusers/osm/opr.inn/bin/jretrievedwh.py" DEFAULT_META_FIELDS: tuple[str, ...]
= ("lat", "lon", "elev", "name", "nat_abbr") DEFAULT_GROUP = "SwissMetNet" CATALOG_TIME_RANGE_START = datetime(1900, 1, 1) @@ -37,30 +39,22 @@ class JretrieveError(RuntimeError): def _resolve_binary() -> str: path = shutil.which(BINARY_NAME) - if path is None: - raise JretrieveError( - f"{BINARY_NAME} not found in $PATH. " - "Make sure /oprusers/osm/opr.inn/bin (or equivalent) is on your PATH." - ) - return path + if path is not None: + return path + if os.path.isfile(HARDCODED_BINARY_PATH): + return HARDCODED_BINARY_PATH + raise JretrieveError( + f"{BINARY_NAME} not found on $PATH or at {HARDCODED_BINARY_PATH}." + ) def _build_env(stage: str) -> dict[str, str]: - if stage not in VALID_STAGES: - raise ValueError( - f"Invalid stage {stage!r}. Must be one of {sorted(VALID_STAGES)}." - ) - opr_home = os.environ.get("OPR_HOME") - if not opr_home: - raise JretrieveError("OPR_HOME is not set; cannot locate jretrieve conf file.") - conf_name = f".jretrievedwh-conf.{stage}.py" - conf_path = os.path.join(opr_home, conf_name) - if not os.path.isfile(conf_path): - raise JretrieveError(f"jretrieve conf file not found: {conf_path}") - if not os.access(conf_path, os.R_OK): - raise JretrieveError(f"jretrieve conf file not readable: {conf_path}") + if stage != "prod": + raise ValueError(f"Only 'prod' stage is supported, got {stage!r}.") + conf_dir = str(Path(__file__).parents[2]) # project root + conf_name = ".jretrievedwh-conf.prod.py" env = os.environ.copy() - env["JRETRIEVE_CONF_DIR"] = opr_home + env["JRETRIEVE_CONF_DIR"] = conf_dir env["JRETRIEVE_CONF_NAME"] = conf_name return env @@ -68,30 +62,20 @@ def _build_env(stage: str) -> dict[str, str]: def check_prerequisites(stage: str = "prod") -> None: """Fail-fast validation that the jretrievedwh environment is usable. - Verifies the CLI is on $PATH, $OPR_HOME is set, and the conf file for - ``stage`` exists and is readable. 
Raises a single ``JretrieveError`` listing - *all* problems found, so a misconfigured environment is reported up front - (e.g. at workflow launch) instead of hours later inside the verification job. + Checks the binary is reachable and credentials are available. Raises a + single ``JretrieveError`` listing *all* problems found, so a misconfigured + environment is reported up front rather than hours into a verification job. """ problems: list[str] = [] - if shutil.which(BINARY_NAME) is None: - problems.append( - f"{BINARY_NAME} not found in $PATH " - "(e.g. add /oprusers/osm/opr.inn/bin to $PATH)." - ) - opr_home = os.environ.get("OPR_HOME") - if not opr_home: - problems.append("$OPR_HOME is not set.") - elif stage not in VALID_STAGES: - problems.append( - f"Invalid stage {stage!r}; must be one of {sorted(VALID_STAGES)}." - ) - else: - conf_path = os.path.join(opr_home, f".jretrievedwh-conf.{stage}.py") - if not os.path.isfile(conf_path): - problems.append(f"jretrieve conf file not found: {conf_path}") - elif not os.access(conf_path, os.R_OK): - problems.append(f"jretrieve conf file not readable: {conf_path}") + if stage != "prod": + problems.append(f"Only 'prod' stage is supported, got {stage!r}.") + try: + _resolve_binary() + except JretrieveError as e: + problems.append(str(e)) + conf_path = Path(__file__).parents[2] / ".jretrievedwh-conf.prod.py" + if not conf_path.is_file(): + problems.append(f"jretrieve conf file not found: {conf_path}") if problems: raise JretrieveError( "jretrievedwh prerequisites not met:\n - " + "\n - ".join(problems) From ae0a8c767c87883540663db5391aff29150cbdd1 Mon Sep 17 00:00:00 2001 From: Jonas Bhend Date: Mon, 15 Jun 2026 08:44:25 +0200 Subject: [PATCH 18/21] update dependencies --- uv.lock | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/uv.lock b/uv.lock index 8a509618..1bb467be 100644 --- a/uv.lock +++ b/uv.lock @@ -20,6 +20,7 @@ dependencies = [ sdist = { url = 
"https://files.pythonhosted.org/packages/4c/d4/6585f3b6fdb75648bca294664af4becc8aa2fb3fb08f4e4e9fd27e10d773/adjusttext-1.3.0.tar.gz", hash = "sha256:4ab75cd4453af4828876ac3e964f2c49be642ea834f0c1f7449558d5f12cbca1", size = 15724, upload-time = "2024-10-31T16:45:36.101Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/53/1c/8feedd607cc14c5df9aef74fe3af9a99bf660743b842a9b5b1865326b4aa/adjustText-1.3.0-py3-none-any.whl", hash = "sha256:da23d7b24b6db5ffa039bb136bfa556207365e32f48ac74b07ad26dd485bc691", size = 13154, upload-time = "2024-10-31T16:45:35.227Z" }, + { url = "https://files.pythonhosted.org/packages/2d/80/7ad35ee5321a86b842f9e8516c8ae4c86f58db7b40e82ce9759f94517a50/adjusttext-1.3.0-py3-none-any.whl", hash = "sha256:bc6c118cd9d7caf6ae37f9355e51d840a2d7f64b4fb2956b8401de27c5af803b", size = 13264, upload-time = "2026-06-08T16:40:05.041Z" }, ] [[package]] @@ -1269,7 +1270,6 @@ dependencies = [ { name = "marimo" }, { name = "mlflow" }, { name = "netcdf4" }, - { name = "peakweather" }, { name = "pydantic" }, { name = "pyproj" }, { name = "pyzmq" }, @@ -1302,7 +1302,6 @@ requires-dist = [ { name = "marimo", specifier = ">=0.23.3" }, { name = "mlflow", specifier = ">=3.1.1" }, { name = "netcdf4", specifier = ">=1.7.2" }, - { name = "peakweather", git = "https://github.com/MeteoSwiss/PeakWeather.git" }, { name = "pydantic", specifier = ">=2.11.7" }, { name = "pyproj", specifier = ">=3.7.2" }, { name = "pyzmq", specifier = ">=27.1.0" }, @@ -3114,17 +3113,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0d/20/f2b98b18200c304f04f7839732298a786d121dba6b7cd79aa406c8c9000d/pdbufr-0.14.2-py3-none-any.whl", hash = "sha256:8d9eb74e65fe1b4b89ffe5e7ad8ee0e854ac2efbed5b64ac17cff287ea86a838", size = 52633, upload-time = "2026-02-26T17:03:45.16Z" }, ] -[[package]] -name = "peakweather" -version = "0.2.2" -source = { git = "https://github.com/MeteoSwiss/PeakWeather.git#65db3aab3dd8ca399de451577f4740a5f2f5c7ac" } -dependencies = [ - { name = "numpy" }, - 
{ name = "pandas" }, - { name = "pyarrow" }, - { name = "tqdm" }, -] - [[package]] name = "pillow" version = "12.2.0" From 065a5d90d16c35e3d732bfddae005e7d379459b9 Mon Sep 17 00:00:00 2001 From: Jonas Bhend Date: Mon, 15 Jun 2026 09:47:07 +0200 Subject: [PATCH 19/21] fix failing test --- tests/unit/test_jretrieve.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/unit/test_jretrieve.py b/tests/unit/test_jretrieve.py index 6625084a..2509ebda 100644 --- a/tests/unit/test_jretrieve.py +++ b/tests/unit/test_jretrieve.py @@ -61,22 +61,23 @@ def test_check_prerequisites_ok(monkeypatch, tmp_path): jr.check_prerequisites("prod") # should not raise -def test_check_prerequisites_missing_binary(monkeypatch, tmp_path): - conf = tmp_path / ".jretrievedwh-conf.prod.py" - conf.write_text("# conf\n") +def test_check_prerequisites_missing_binary(monkeypatch): monkeypatch.setattr(jr.shutil, "which", lambda name: None) - monkeypatch.setenv("OPR_HOME", str(tmp_path)) + monkeypatch.setattr(jr.os.path, "isfile", lambda p: False) with pytest.raises(jr.JretrieveError, match=r"\$PATH"): jr.check_prerequisites("prod") def test_check_prerequisites_aggregates_all_problems(monkeypatch): monkeypatch.setattr(jr.shutil, "which", lambda name: None) - monkeypatch.delenv("OPR_HOME", raising=False) + monkeypatch.setattr(jr.os.path, "isfile", lambda p: False) + monkeypatch.setattr(Path, "is_file", lambda self: False) with pytest.raises(jr.JretrieveError) as exc: jr.check_prerequisites("prod") msg = str(exc.value) - assert "$PATH" in msg and "$OPR_HOME" in msg # both reported, not just the first + assert ( + "$PATH" in msg and "conf file not found" in msg + ) # both reported, not just the first def _sample_meta(): From ca4a2230f40755e98817c686303431a7f518c9f8 Mon Sep 17 00:00:00 2001 From: Jonas Bhend Date: Thu, 18 Jun 2026 14:15:05 +0200 Subject: [PATCH 20/21] Add SP_10M to list of DWH parameters --- src/data_input/__init__.py | 1 + 1 file changed, 1 insertion(+) 
diff --git a/src/data_input/__init__.py b/src/data_input/__init__.py index 0317acfd..04d96cfc 100644 --- a/src/data_input/__init__.py +++ b/src/data_input/__init__.py @@ -277,6 +277,7 @@ def load_forecast_data_from_grib(files: list[Path], params: list[str]) -> xr.Dat "PMSL": "pp0qffs0", "TOT_PREC": "rre150h0", "FF_10M": "fkl010z0", + "SP_10M": "fkl010z0", "DD_10M": "dkl010z0", "VMAX_10M": "fkl010z1", } From 86d3380b633b6f342f5b64882879df25a90015bb Mon Sep 17 00:00:00 2001 From: Jonas Bhend Date: Thu, 18 Jun 2026 15:10:33 +0200 Subject: [PATCH 21/21] fix README --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index a9ba3440..a764eba3 100644 --- a/README.md +++ b/README.md @@ -163,7 +163,6 @@ The `truth.root` value selects how the ground truth is loaded: root: jretrievedwh:1,2 # stn_group_id (default) # root: jretrievedwh:locations=ARO,KLO,LUG # explicit nat_abbr list # root: jretrievedwh:bbox=45.8,47.8,5.9,10.5 # minlat,maxlat,minlon,maxlon - # append ;stage=devt to target a non-prod DWH stage (prod|depl|devt) ``` **Prerequisites:** `jretrievedwh.py` must be on `$PATH` (falls back to