diff --git a/.jretrievedwh-conf.prod.py b/.jretrievedwh-conf.prod.py new file mode 100644 index 00000000..abafe594 --- /dev/null +++ b/.jretrievedwh-conf.prod.py @@ -0,0 +1,48 @@ +import base64 +import json +import os +import urllib.request +from pathlib import Path + + +def _read_dotenv(path): + result = {} + try: + with open(path) as f: + for line in f: + line = line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + key, _, value = line.partition("=") + result[key.strip()] = value.strip().strip('"').strip("'") + except FileNotFoundError: + pass + return result + + +_client_id = os.environ.get("JRETRIEVE_CLIENT_ID") +_client_secret = os.environ.get("JRETRIEVE_CLIENT_SECRET") +if not _client_id or not _client_secret: + _dotenv = _read_dotenv(Path(os.environ.get("JRETRIEVE_CONF_DIR", ".")) / ".env") + _client_id = _client_id or _dotenv.get("JRETRIEVE_CLIENT_ID") + _client_secret = _client_secret or _dotenv.get("JRETRIEVE_CLIENT_SECRET") +if not _client_id or not _client_secret: + raise RuntimeError( + "jretrieve credentials not found. Set JRETRIEVE_CLIENT_ID and " + "JRETRIEVE_CLIENT_SECRET in the environment or in a .env file next " + "to this script." 
+ ) + +jretrieve_url = "https://service.meteoswiss.ch/jretrieve/api/v1" +with urllib.request.urlopen( + urllib.request.Request( + method="POST", + url="https://service.meteoswiss.ch/auth/realms/meteoswiss.ch/protocol/openid-connect/token", + data=b"grant_type=client_credentials", + headers={ + b"Authorization": b"Basic " + + base64.b64encode(f"{_client_id}:{_client_secret}".encode()) + }, + ) +) as f: + auth_header = "Bearer " + json.loads(f.read().decode())["access_token"] diff --git a/README.md b/README.md index 0d9c9a06..a764eba3 100644 --- a/README.md +++ b/README.md @@ -145,6 +145,31 @@ You can then run it with: evalml experiment path/to/experiment/config.yaml --report ``` +### Truth sources + +The `truth.root` value selects how the ground truth is loaded: + +- **Analysis Zarr** — a path ending in `.zarr` (anemoi analysis dataset). +- **DWH / jretrievedwh** — a `jretrievedwh:` marker string fetching surface observations + (e.g. SMN) live from the MeteoSwiss data warehouse. Variables are mapped to + ICON names in SI units (temperatures in K, pressure in Pa, precipitation as the hourly + sum); wind `U_10M`/`V_10M` are derived from speed + direction. + + Marker syntax (station selection is required; pick one of group/locations/bbox): + + ```yaml + truth: + label: SwissMetNet + root: jretrievedwh:1,2 # stn_group_id (default) + # root: jretrievedwh:locations=ARO,KLO,LUG # explicit nat_abbr list + # root: jretrievedwh:bbox=45.8,47.8,5.9,10.5 # minlat,maxlat,minlon,maxlon + ``` + + **Prerequisites:** `jretrievedwh.py` must be on `$PATH` (falls back to + `/oprusers/osm/opr.inn/bin/jretrievedwh.py`) and DWH credentials must be + available — see [Credentials setup](#credentials-setup). No data is + pre-downloaded — the obs are queried at verification time. + ## Installation @@ -175,6 +200,25 @@ valid for 30 days. Every training or evaluation run within this period automatic extends the token by another 30 days. 
It’s good practice to run the login command before executing the workflow to ensure your token is still valid. +### DWH / jretrieve credentials + +To use `jretrievedwh:` as a truth source, provide a client ID and secret for +the MeteoSwiss service account. Set them as environment variables: + +```bash +export JRETRIEVE_CLIENT_ID=your-client-id +export JRETRIEVE_CLIENT_SECRET=your-client-secret +``` + +or place them in a `.env` file at the project root (next to `.jretrievedwh-conf.prod.py`): + +``` +JRETRIEVE_CLIENT_ID=your-client-id +JRETRIEVE_CLIENT_SECRET=your-client-secret +``` + +The `.env` file is listed in `.gitignore` and is never committed. + ## Workspace setup By default, data produced by the workflow will be stored under `output/` in your working directory. diff --git a/config/varda-single-1.0.yaml b/config/varda-single-1.0.yaml index b9515033..9bbacab7 100644 --- a/config/varda-single-1.0.yaml +++ b/config/varda-single-1.0.yaml @@ -37,7 +37,12 @@ runs: truth: label: SwissMetNet - root: output/data/observations/peakweather + root: jretrievedwh:1,2 + # Verifies against SwissMetNet observations fetched live from the DWH via jretrievedwh + # (requires jretrievedwh.py on $PATH or at the hardcoded fallback path, plus DWH credentials). + # Other selectors: root: jretrievedwh:locations=ARO,KLO,LUG + # root: jretrievedwh:bbox=45.8,47.8,5.9,10.5 + # a ;stage=<name> suffix is parsed, but only the default prod stage is currently accepted experiment: params: diff --git a/pyproject.toml b/pyproject.toml index d3279e91..d8552f6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,6 @@ dependencies = [ "pyproj>=3.7.2", "marimo>=0.23.3", "geopandas>=0.14.0", - "peakweather", "pyzmq>=27.1.0", "scores>=2.0.0", "eccodes>=2.44,<2.48", @@ -60,5 +59,4 @@ packages = [ ] [tool.uv.sources] -peakweather = { git = "https://github.com/MeteoSwiss/PeakWeather.git" } eccodes-cosmo-resources-python = { git = "https://github.com/MeteoSwiss/eccodes-cosmo-resources-python" } diff --git a/src/data_input/__init__.py b/src/data_input/__init__.py index 
6c14005a..04d96cfc 100644 --- a/src/data_input/__init__.py +++ b/src/data_input/__init__.py @@ -5,6 +5,7 @@ import earthkit.data as ekd import numpy as np +import pandas as pd import xarray as xr from pyproj import Transformer @@ -19,16 +20,21 @@ ZERO_KELVIN = -273.15 # °C -def _select_valid_times(ds, times: np.datetime64): +def _select_valid_times(ds, times: np.datetime64, strict: bool = False): # (handle special case where some valid times are not in the dataset, e.g. at the end) times_np = np.asarray(times, dtype="datetime64[ns]") times_included = np.isin(times_np, ds.time.values) if times_included.all(): return ds.sel(time=times_np) elif times_included.any(): + missing = times_np[~times_included] + if strict: + raise ValueError( + f"Some valid times are not included in the dataset:\n{missing}" + ) LOG.warning( "Some valid times are not included in the dataset: \n%s", - times_np[~times_included], + missing, ) return ds.sel(time=times_np[times_included]) else: @@ -264,71 +270,134 @@ def load_forecast_data_from_grib(files: list[Path], params: list[str]) -> xr.Dat return ds -def load_obs_data_from_peakweather( - root, reftime: datetime, steps: list[int], params: list[str], freq: str = "1h" -) -> xr.Dataset: - """Load PeakWeather station observations into an xarray Dataset. 
+DWH_PARAM_MAP = { + "T_2M": "tre200s0", + "TD_2M": "tde200s0", + "PS": "prestas0", + "PMSL": "pp0qffs0", + "TOT_PREC": "rre150h0", + "FF_10M": "fkl010z0", + "SP_10M": "fkl010z0", + "DD_10M": "dkl010z0", + "VMAX_10M": "fkl010z1", +} +DWH_WIND_SPEED = "fkl010z0" +DWH_WIND_DIR = "dkl010z0" +DWH_CELSIUS_TO_KELVIN = {"tre200s0", "tde200s0"} +DWH_HPA_TO_PA = {"prestas0", "pp0qffs0"} + + +def _jretrieve_df_to_xarray(df, short_names, catalog) -> xr.Dataset: + """Pivot long-form jretrieve obs into a (time, values) cube aligned to the + catalog, NaN-filled for missing cells.""" + station_to_idx = {sid: i for i, sid in enumerate(catalog.station_id)} + if df.empty: + time_index = pd.DatetimeIndex([]) + else: + df = df.copy() + df["time"] = pd.to_datetime(df["termin"].astype(str), format="%Y%m%d%H%M%S") + time_index = pd.DatetimeIndex(sorted(df["time"].unique())) + n_t, n_s = len(time_index), catalog.n + coords = { + "time": ("time", time_index.values.astype("datetime64[ns]")), + "values": ("values", catalog.nat_abbr), + "latitude": ("values", catalog.latitude), + "longitude": ("values", catalog.longitude), + } + data_vars: dict[str, tuple] = {} + if df.empty: + for p in short_names: + data_vars[p] = (("time", "values"), np.full((n_t, n_s), np.nan, np.float32)) + else: + time_to_idx = {t: i for i, t in enumerate(time_index)} + df["_si"] = df["station"].map(station_to_idx) + df["_ti"] = df["time"].map(time_to_idx) + df = df.dropna(subset=["_si", "_ti"]) + df["_si"] = df["_si"].astype(int) + df["_ti"] = df["_ti"].astype(int) + for p in short_names: + arr = np.full((n_t, n_s), np.nan, dtype=np.float32) + if p in df.columns: + arr[df["_ti"].to_numpy(), df["_si"].to_numpy()] = df[p].to_numpy( + dtype=np.float32 + ) + data_vars[p] = (("time", "values"), arr) + return xr.Dataset(data_vars=data_vars, coords=coords) + - Returns a Dataset with dimensions `time` and `values`, values coordinates - (`lat`, `lon`), and variables renamed to ICON parameter names. 
- Temperatures are converted to Kelvin when present. +def load_obs_data_from_jretrieve( + root, reftime: datetime, steps: list[int], params: list[str] +) -> xr.Dataset: + """Load SwissMetNet (SMN) surface observations from the DWH via jretrievedwh. + + ``root`` is a marker string selecting stations, e.g. ``jretrievedwh:SwissMetNet`` + (default group), ``jretrievedwh:locations=ARO,KLO``, or + ``jretrievedwh:bbox=45.8,47.8,5.9,10.5`` (optionally ``;stage=devt``). Returns + a Dataset with dims (time, values), values=nat_abbr, latitude/longitude coords, + variables renamed to ICON names in SI units (T/TD in Kelvin, pressure in Pa). + Only the requested hourly valid times are kept. """ - from peakweather.dataset import PeakWeatherDataset - - param_names = { - "temperature": "T_2M", - "wind_u": "U_10M", - "wind_v": "V_10M", - "precipitation": "TOT_PREC", - "pressure": "PS", - "wind_gust": "VMAX_10M", - } - param_names = {k: v for k, v in param_names.items() if v in params} + from data_input import jretrieve as jr + + stations, stage, seq_type = jr.parse_selection(root) + jr.check_prerequisites(stage) + + want_uv = "U_10M" in params or "V_10M" in params + short_names: list[str] = [DWH_PARAM_MAP[p] for p in params if p in DWH_PARAM_MAP] + if want_uv: + short_names += [DWH_WIND_SPEED, DWH_WIND_DIR] + short_names = list(dict.fromkeys(short_names)) + if not short_names: + raise ValueError(f"No DWH parameter mapping for requested params: {params}") + start = reftime end = start + timedelta(hours=max(steps)) if len(steps) > 1: - end += timedelta(hours=steps[-1] - steps[-2]) # extend by 1 extra step - years = list(set([start.year, end.year])) - if "wind_u" in param_names or "wind_v" in param_names: - compute_uv = True - else: - compute_uv = False - pw = PeakWeatherDataset(root=root, years=years, freq=freq, compute_uv=compute_uv) - ds, mask = pw.get_observations( - parameters=[k for k in param_names.keys()], - first_date=f"{start:%Y-%m-%d %H:%M}", - last_date=f"{end:%Y-%m-%d 
%H:%M}", - return_mask=True, - ) - ds = ( - ds.stack(["nat_abbr", "name"], future_stack=True) - .to_xarray() - .to_dataset(dim="name") + end += timedelta(hours=steps[-1] - steps[-2]) + + catalog = jr.StationCatalog.from_meta( + jr.fetch_meta( + stations=stations, params=short_names, seq_type=seq_type, stage=stage + ) ) - mask = ( - mask.stack(["nat_abbr", "name"], future_stack=True) - .to_xarray() - .to_dataset(dim="name") + step_hours = (steps[1] - steps[0]) if len(steps) > 1 else 1 + df = jr.fetch_data( + stations=stations, + params=short_names, + start=start, + end=end, + increment_minutes=step_hours * 60, + seq_type=seq_type, + stage=stage, ) - ds = ds.where(mask) - ds = ds.rename({"datetime": "time", "nat_abbr": "values"}) - ds = ds.rename(param_names) - ds = ds.assign_coords(time=ds.indexes["time"].tz_convert("UTC").tz_localize(None)) - ds = ds.assign_coords(values=ds.indexes["values"]) - ds = ds.assign_coords(longitude=("values", pw.stations_table["longitude"])) - ds = ds.assign_coords(latitude=("values", pw.stations_table["latitude"])) - if "T_2M" in ds: - ds["T_2M"] = ds["T_2M"] - ZERO_KELVIN # convert to Kelvin - ds = ds.dropna("values", how="all") + raw = _jretrieve_df_to_xarray(df, short_names, catalog) + + out = xr.Dataset(coords=raw.coords) + for icon, short in DWH_PARAM_MAP.items(): + if icon in params and short in raw: + var = raw[short] + if short in DWH_CELSIUS_TO_KELVIN: + var = var - ZERO_KELVIN + elif short in DWH_HPA_TO_PA: + var = var * 100.0 + out[icon] = var + if want_uv and DWH_WIND_SPEED in raw and DWH_WIND_DIR in raw: + ff = raw[DWH_WIND_SPEED] + dd_rad = np.deg2rad(raw[DWH_WIND_DIR]) + if "U_10M" in params: + out["U_10M"] = -ff * np.sin(dd_rad) + if "V_10M" in params: + out["V_10M"] = -ff * np.cos(dd_rad) + out = out.dropna("values", how="all") times = np.datetime64(reftime) + np.asarray(steps, dtype="timedelta64[h]") - return _select_valid_times(ds, times) + return _select_valid_times(out, times, strict=True) def load_truth_data( root, 
reftime: datetime, steps: list[int], params: list[str] ) -> xr.Dataset: - """Load truth data from analysis Zarr or PeakWeather observations.""" + """Load truth data from an analysis Zarr dataset or DWH observations via jretrieve.""" if root.suffix == ".zarr": LOG.info("Loading ground truth from an analysis zarr dataset...") truth = load_analysis_data_from_zarr( @@ -342,9 +411,9 @@ def load_truth_data( if "y" in truth.dims and "x" in truth.dims else {"values": -1} ) - elif "peakweather" in str(root): - LOG.info("Loading ground truth from PeakWeather observations...") - truth = load_obs_data_from_peakweather( + elif "jretrieve" in str(root): + LOG.info("Loading ground truth from JRetrieve...") + truth = load_obs_data_from_jretrieve( root=root, reftime=reftime, steps=steps, diff --git a/src/data_input/jretrieve.py b/src/data_input/jretrieve.py new file mode 100644 index 00000000..cbf00ad9 --- /dev/null +++ b/src/data_input/jretrieve.py @@ -0,0 +1,301 @@ +"""Subprocess wrapper around ``jretrievedwh.py`` for retrieving SwissMetNet +(SMN) surface observations from the MeteoSwiss data warehouse (DWH). + +Requires ``jretrievedwh.py`` (resolved via $PATH, falling back to the hardcoded +path) and credentials set as ``JRETRIEVE_CLIENT_ID`` / +``JRETRIEVE_CLIENT_SECRET`` in the environment or in a ``.env`` file at the +project root. +""" + +from __future__ import annotations + +import logging +import os +import shutil +import subprocess +import time +from dataclasses import dataclass +from datetime import datetime +from io import StringIO +from pathlib import Path +from typing import Any, Sequence + +import numpy as np +import pandas as pd + +LOG = logging.getLogger(__name__) + +BINARY_NAME = "jretrievedwh.py" +HARDCODED_BINARY_PATH = "/oprusers/osm/opr.inn/bin/jretrievedwh.py" +DEFAULT_META_FIELDS: tuple[str, ...] 
= ("lat", "lon", "elev", "name", "nat_abbr") +DEFAULT_GROUP = "SwissMetNet" +CATALOG_TIME_RANGE_START = datetime(1900, 1, 1) +CATALOG_TIME_RANGE_END = datetime(2100, 12, 31, 23, 59) + + +class JretrieveError(RuntimeError): + """Raised when jretrievedwh.py fails or returns malformed output.""" + + +def _resolve_binary() -> str: + path = shutil.which(BINARY_NAME) + if path is not None: + return path + if os.path.isfile(HARDCODED_BINARY_PATH): + return HARDCODED_BINARY_PATH + raise JretrieveError( + f"{BINARY_NAME} not found on $PATH or at {HARDCODED_BINARY_PATH}." + ) + + +def _build_env(stage: str) -> dict[str, str]: + if stage != "prod": + raise ValueError(f"Only 'prod' stage is supported, got {stage!r}.") + conf_dir = str(Path(__file__).parents[2]) # project root + conf_name = ".jretrievedwh-conf.prod.py" + env = os.environ.copy() + env["JRETRIEVE_CONF_DIR"] = conf_dir + env["JRETRIEVE_CONF_NAME"] = conf_name + return env + + +def check_prerequisites(stage: str = "prod") -> None: + """Fail-fast validation that the jretrievedwh environment is usable. + + Checks the binary is reachable and credentials are available. Raises a + single ``JretrieveError`` listing *all* problems found, so a misconfigured + environment is reported up front rather than hours into a verification job. + """ + problems: list[str] = [] + if stage != "prod": + problems.append(f"Only 'prod' stage is supported, got {stage!r}.") + try: + _resolve_binary() + except JretrieveError as e: + problems.append(str(e)) + conf_path = Path(__file__).parents[2] / ".jretrievedwh-conf.prod.py" + if not conf_path.is_file(): + problems.append(f"jretrieve conf file not found: {conf_path}") + if problems: + raise JretrieveError( + "jretrievedwh prerequisites not met:\n - " + "\n - ".join(problems) + ) + + +def _fmt_time(dt: datetime) -> str: + return dt.strftime("%Y%m%d%H%M") + + +def _stations_to_argv(stations: dict[str, Any]) -> list[str]: + """Translate a station selection dict to jretrieve CLI args. 
+ + Exactly one of {group, locations, bbox} must be set. + """ + keys = [k for k in ("group", "locations", "bbox") if stations.get(k) is not None] + if len(keys) != 1: + raise ValueError( + f"stations must specify exactly one of group/locations/bbox, got {keys}" + ) + key = keys[0] + val = stations[key] + if key == "group": + return ["-a", f"stn_group_id,{val}"] + if key == "locations": + if isinstance(val, str): + val = [v for v in val.split(",") if v] + if not isinstance(val, Sequence): + raise ValueError("stations.locations must be a list of nat_abbr strings.") + return ["-i", "nat_abbr," + ",".join(str(v) for v in val)] + if key == "bbox": + if isinstance(val, str): + val = [v for v in val.split(",") if v] + if len(val) != 4: + raise ValueError("stations.bbox must be [minlat, maxlat, minlon, maxlon].") + return ["-l", ",".join(str(v) for v in val)] + raise AssertionError("unreachable") + + +def parse_selection(root: Any) -> tuple[dict[str, Any], str, str]: + """Parse a truth-root marker into (stations, stage, seq_type). 
+ + Examples (slash-free so they survive ``Path()`` normalisation): + ``jretrievedwh:SwissMetNet`` -> group + ``jretrievedwh:group=SwissMetNet;stage=devt`` + ``jretrievedwh:locations=ARO,KLO`` + ``jretrievedwh:bbox=45.8,47.8,5.9,10.5`` + """ + _, _, rest = str(root).partition(":") + rest = rest.strip() + stations: dict[str, Any] = {} + stage = "prod" + seq_type = "surface" + for i, part in enumerate([p for p in rest.split(";") if p]): + if "=" not in part: + if i == 0: + stations["group"] = part + continue + raise ValueError(f"Invalid jretrieve selector fragment: {part!r}") + key, _, value = part.partition("=") + key, value = key.strip(), value.strip() + if key in ("group", "locations", "bbox"): + stations[key] = value + elif key == "stage": + stage = value + elif key == "seq_type": + seq_type = value + else: + raise ValueError(f"Unknown jretrieve selector key: {key!r}") + if not stations: + stations = {"group": DEFAULT_GROUP} + return stations, stage, seq_type + + +def _run(argv: list[str], env: dict[str, str], timeout_s: int) -> str: + try: + proc = subprocess.run( + argv, + env=env, + capture_output=True, + text=True, + timeout=timeout_s, + check=False, + ) + except subprocess.TimeoutExpired as e: + raise JretrieveError( + f"jretrieve timed out after {timeout_s}s: {' '.join(argv)}" + ) from e + if proc.returncode != 0: + raise JretrieveError( + f"jretrieve exited with {proc.returncode}\nargv: {argv}\n" + f"stderr: {proc.stderr.strip()}\nstdout (head): {proc.stdout[:500]}" + ) + if proc.stdout.lstrip().startswith("ERROR"): + raise JretrieveError(f"jretrieve returned error: {proc.stdout.strip()[:500]}") + return proc.stdout + + +def _run_with_retry(argv, env, timeout_s, attempts=3) -> str: + last_err: Exception | None = None + for attempt in range(1, attempts + 1): + try: + return _run(argv, env=env, timeout_s=timeout_s) + except JretrieveError as e: + last_err = e + if attempt == attempts: + break + backoff = 2**attempt + LOG.warning( + "jretrieve attempt %d/%d 
failed (%s); retrying in %ds", + attempt, + attempts, + e, + backoff, + ) + time.sleep(backoff) + assert last_err is not None + raise last_err + + +def _parse_csv(csv_text: str) -> pd.DataFrame: + csv_text = csv_text.strip() + if not csv_text: + return pd.DataFrame() + return pd.read_csv(StringIO(csv_text), sep=";") + + +def fetch_meta( + *, + stations, + params, + seq_type="surface", + stage="prod", + meta_fields=DEFAULT_META_FIELDS, + timeout_s=300, +) -> pd.DataFrame: + """Fetch the station catalog (rows per station x parameter x period) over a + fixed wide time range so the response is deterministic.""" + if not params: + raise ValueError("params must be non-empty.") + argv = [ + _resolve_binary(), + "-s", + seq_type, + "-n", + ",".join(params), + "-t", + f"{_fmt_time(CATALOG_TIME_RANGE_START)},{_fmt_time(CATALOG_TIME_RANGE_END)}", + "--meta-info", + ",".join(meta_fields), + "--format", + "csv", + *_stations_to_argv(stations), + ] + LOG.info("jretrieve meta: %s", " ".join(argv)) + df = _parse_csv(_run_with_retry(argv, env=_build_env(stage), timeout_s=timeout_s)) + if df.empty: + raise JretrieveError("jretrieve meta-info returned no rows.") + return df + + +def fetch_data( + *, + stations, + params, + start, + end, + increment_minutes=60, + seq_type="surface", + stage="prod", + timeout_s=600, +) -> pd.DataFrame: + """Fetch observation data; columns: station (int), termin (YYYYMMDDhhmmss), + one column per requested short name.""" + if not params: + raise ValueError("params must be non-empty.") + argv = [ + _resolve_binary(), + "-s", + seq_type, + "-n", + ",".join(params), + "-t", + f"{_fmt_time(start)},{_fmt_time(end)},{int(increment_minutes)}", + "--format", + "csv", + *_stations_to_argv(stations), + ] + LOG.info("jretrieve data: %s", " ".join(argv)) + return _parse_csv(_run_with_retry(argv, env=_build_env(stage), timeout_s=timeout_s)) + + +@dataclass(frozen=True) +class StationCatalog: + """Stable, nat_abbr-sorted station catalog used as the cell axis.""" + + 
nat_abbr: np.ndarray + station_id: np.ndarray + latitude: np.ndarray + longitude: np.ndarray + elevation: np.ndarray + name: np.ndarray + + @property + def n(self) -> int: + return len(self.nat_abbr) + + @classmethod + def from_meta(cls, meta: pd.DataFrame) -> "StationCatalog": + per_station = ( + meta.sort_values(["nat_abbr", "parameter", "op_since"], kind="stable") + .drop_duplicates(subset=["station"], keep="first") + .sort_values("nat_abbr", kind="stable") + .reset_index(drop=True) + ) + return cls( + nat_abbr=per_station["nat_abbr"].to_numpy(dtype=object), + station_id=per_station["station"].to_numpy(dtype=np.int64), + latitude=per_station["latitude"].to_numpy(dtype=np.float64), + longitude=per_station["longitude"].to_numpy(dtype=np.float64), + elevation=per_station["elev"].to_numpy(dtype=np.float64), + name=per_station["stn_name"].to_numpy(dtype=object), + ) diff --git a/tests/unit/test_jretrieve.py b/tests/unit/test_jretrieve.py new file mode 100644 index 00000000..2509ebda --- /dev/null +++ b/tests/unit/test_jretrieve.py @@ -0,0 +1,169 @@ +from datetime import datetime +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest + +import data_input +from data_input import jretrieve as jr + + +def test_stations_to_argv_group(): + assert jr._stations_to_argv({"group": "1,2"}) == [ + "-a", + "stn_group_id,1,2", + ] + + +def test_stations_to_argv_locations_from_string(): + assert jr._stations_to_argv({"locations": "ARO,KLO"}) == ["-i", "nat_abbr,ARO,KLO"] + + +def test_stations_to_argv_bbox_from_string(): + assert jr._stations_to_argv({"bbox": "45.8,47.8,5.9,10.5"}) == [ + "-l", + "45.8,47.8,5.9,10.5", + ] + + +def test_stations_to_argv_rejects_ambiguous(): + with pytest.raises(ValueError, match="exactly one"): + jr._stations_to_argv({"group": "x", "bbox": "1,2,3,4"}) + + +def test_parse_selection_default_group(): + assert jr.parse_selection("jretrievedwh:") == ( + {"group": "SwissMetNet"}, + "prod", + "surface", + ) + assert 
jr.parse_selection("jretrievedwh:SwissMetNet") == ( + {"group": "SwissMetNet"}, + "prod", + "surface", + ) + + +def test_parse_selection_keyvalue_and_stage(): + assert jr.parse_selection("jretrievedwh:locations=ARO,KLO;stage=devt") == ( + {"locations": "ARO,KLO"}, + "devt", + "surface", + ) + + +def test_check_prerequisites_ok(monkeypatch, tmp_path): + conf = tmp_path / ".jretrievedwh-conf.prod.py" + conf.write_text("# conf\n") + monkeypatch.setattr(jr.shutil, "which", lambda name: "/opt/bin/jretrievedwh.py") + monkeypatch.setenv("OPR_HOME", str(tmp_path)) + jr.check_prerequisites("prod") # should not raise + + +def test_check_prerequisites_missing_binary(monkeypatch): + monkeypatch.setattr(jr.shutil, "which", lambda name: None) + monkeypatch.setattr(jr.os.path, "isfile", lambda p: False) + with pytest.raises(jr.JretrieveError, match=r"\$PATH"): + jr.check_prerequisites("prod") + + +def test_check_prerequisites_aggregates_all_problems(monkeypatch): + monkeypatch.setattr(jr.shutil, "which", lambda name: None) + monkeypatch.setattr(jr.os.path, "isfile", lambda p: False) + monkeypatch.setattr(Path, "is_file", lambda self: False) + with pytest.raises(jr.JretrieveError) as exc: + jr.check_prerequisites("prod") + msg = str(exc.value) + assert ( + "$PATH" in msg and "conf file not found" in msg + ) # both reported, not just the first + + +def _sample_meta(): + return pd.DataFrame( + { + "station": [2, 1, 1], + "op_since": [19900101000000, 19800101000000, 19800101000000], + "op_till": ["", "", ""], + "parameter": ["tre200s0", "fkl010z0", "tre200s0"], + "latitude": [47.48, 46.79, 46.79], + "longitude": [8.54, 9.68, 9.68], + "elev": [426.0, 1878.0, 1878.0], + "stn_name": ["Zurich", "Arosa", "Arosa"], + "nat_abbr": ["KLO", "ARO", "ARO"], + } + ) + + +def test_station_catalog_from_meta_collapses_and_sorts(): + cat = jr.StationCatalog.from_meta(_sample_meta()) + assert cat.n == 2 + assert list(cat.nat_abbr) == ["ARO", "KLO"] # sorted by nat_abbr + assert list(cat.station_id) == 
[1, 2] + np.testing.assert_allclose(cat.latitude, [46.79, 47.48]) + + +def test_load_obs_data_from_jretrieve(monkeypatch): + meta = pd.DataFrame( + { + "station": [1, 2], + "op_since": [19800101000000, 19900101000000], + "op_till": ["", ""], + "parameter": ["tre200s0", "tre200s0"], + "latitude": [46.79, 47.48], + "longitude": [9.68, 8.54], + "elev": [1878.0, 426.0], + "stn_name": ["Arosa", "Zurich"], + "nat_abbr": ["ARO", "KLO"], + } + ) + data = pd.DataFrame( + { + "station": [1, 1], + "termin": [20250115000000, 20250115010000], + "tre200s0": [10.0, 11.0], # degC + "fkl010z0": [3.0, 4.0], # m/s + "dkl010z0": [0.0, 90.0], # deg + } + ) + # Skip the live environment check so the test runs anywhere (incl. CI, + # which has no jretrievedwh.py / $OPR_HOME). + monkeypatch.setattr(jr, "check_prerequisites", lambda *a, **k: None) + monkeypatch.setattr(jr, "fetch_meta", lambda **kw: meta) + monkeypatch.setattr(jr, "fetch_data", lambda **kw: data) + + ds = data_input.load_obs_data_from_jretrieve( + "jretrievedwh:locations=ARO,KLO", + datetime(2025, 1, 15, 0, 0), + [0, 1], + ["T_2M", "U_10M", "V_10M"], + ) + + assert set(ds.dims) == {"time", "values"} + assert list(ds["values"].values) == ["ARO"] # KLO all-NaN -> dropped + assert set(ds.data_vars) == {"T_2M", "U_10M", "V_10M"} + np.testing.assert_allclose(ds["T_2M"].sel(values="ARO").values, [283.15, 284.15]) + # DD=0 -> U=0, V=-FF ; DD=90 -> U=-FF, V=0 + np.testing.assert_allclose( + ds["U_10M"].sel(values="ARO").values, [0.0, -4.0], atol=1e-5 + ) + np.testing.assert_allclose( + ds["V_10M"].sel(values="ARO").values, [-3.0, 0.0], atol=1e-5 + ) + np.testing.assert_allclose(ds["latitude"].values, [46.79]) + + +def test_load_truth_data_forwards_root(monkeypatch): + seen = {} + + def fake_loader(root, reftime, steps, params): + seen["root"] = root + return "SENTINEL" + + monkeypatch.setattr(data_input, "load_obs_data_from_jretrieve", fake_loader) + out = data_input.load_truth_data( + Path("jretrievedwh:SwissMetNet"), 
datetime(2025, 1, 15), [0], ["T_2M"] + ) + assert out == "SENTINEL" + assert "jretrievedwh:SwissMetNet" in str(seen["root"]) diff --git a/uv.lock b/uv.lock index 8a509618..1bb467be 100644 --- a/uv.lock +++ b/uv.lock @@ -20,6 +20,7 @@ dependencies = [ sdist = { url = "https://files.pythonhosted.org/packages/4c/d4/6585f3b6fdb75648bca294664af4becc8aa2fb3fb08f4e4e9fd27e10d773/adjusttext-1.3.0.tar.gz", hash = "sha256:4ab75cd4453af4828876ac3e964f2c49be642ea834f0c1f7449558d5f12cbca1", size = 15724, upload-time = "2024-10-31T16:45:36.101Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/53/1c/8feedd607cc14c5df9aef74fe3af9a99bf660743b842a9b5b1865326b4aa/adjustText-1.3.0-py3-none-any.whl", hash = "sha256:da23d7b24b6db5ffa039bb136bfa556207365e32f48ac74b07ad26dd485bc691", size = 13154, upload-time = "2024-10-31T16:45:35.227Z" }, + { url = "https://files.pythonhosted.org/packages/2d/80/7ad35ee5321a86b842f9e8516c8ae4c86f58db7b40e82ce9759f94517a50/adjusttext-1.3.0-py3-none-any.whl", hash = "sha256:bc6c118cd9d7caf6ae37f9355e51d840a2d7f64b4fb2956b8401de27c5af803b", size = 13264, upload-time = "2026-06-08T16:40:05.041Z" }, ] [[package]] @@ -1269,7 +1270,6 @@ dependencies = [ { name = "marimo" }, { name = "mlflow" }, { name = "netcdf4" }, - { name = "peakweather" }, { name = "pydantic" }, { name = "pyproj" }, { name = "pyzmq" }, @@ -1302,7 +1302,6 @@ requires-dist = [ { name = "marimo", specifier = ">=0.23.3" }, { name = "mlflow", specifier = ">=3.1.1" }, { name = "netcdf4", specifier = ">=1.7.2" }, - { name = "peakweather", git = "https://github.com/MeteoSwiss/PeakWeather.git" }, { name = "pydantic", specifier = ">=2.11.7" }, { name = "pyproj", specifier = ">=3.7.2" }, { name = "pyzmq", specifier = ">=27.1.0" }, @@ -3114,17 +3113,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0d/20/f2b98b18200c304f04f7839732298a786d121dba6b7cd79aa406c8c9000d/pdbufr-0.14.2-py3-none-any.whl", hash = 
"sha256:8d9eb74e65fe1b4b89ffe5e7ad8ee0e854ac2efbed5b64ac17cff287ea86a838", size = 52633, upload-time = "2026-02-26T17:03:45.16Z" }, ] -[[package]] -name = "peakweather" -version = "0.2.2" -source = { git = "https://github.com/MeteoSwiss/PeakWeather.git#65db3aab3dd8ca399de451577f4740a5f2f5c7ac" } -dependencies = [ - { name = "numpy" }, - { name = "pandas" }, - { name = "pyarrow" }, - { name = "tqdm" }, -] - [[package]] name = "pillow" version = "12.2.0" diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk index c75da889..88263076 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -355,6 +355,23 @@ def truth_hash(truth_config: dict) -> str: return generate_json_hash(cfg) +def truth_file_dep(_): + """Truth file dependency: a real path for zarr, but a live-query + marker (no input file) for jretrieve.""" + root = config["truth"]["root"] + return [] if "jretrieve" in str(root) else [root] + + +# Fail fast: when the truth source is the live DWH (jretrievedwh), verify its +# prerequisites at workflow-build time so a misconfigured environment is caught +# at launch, before any (expensive) inference job runs. 
+if "jretrieve" in str(config["truth"]["root"]): + from data_input.jretrieve import check_prerequisites, parse_selection + + _, _jretrieve_stage, _ = parse_selection(config["truth"]["root"]) + check_prerequisites(_jretrieve_stage) + + TRUTH_HASH = truth_hash(config["truth"]) REGIONS = parse_regions() SHOWCASE_REGIONS = parse_showcase_regions() diff --git a/workflow/rules/data.smk b/workflow/rules/data.smk index 67914a9f..eb41eff9 100644 --- a/workflow/rules/data.smk +++ b/workflow/rules/data.smk @@ -4,23 +4,6 @@ from pathlib import Path include: "common.smk" -if config["truth"]["root"].endswith("peakweather"): - output_peakweather_root = config["truth"]["root"] -else: - output_peakweather_root = OUT_ROOT / "data/observations/peakweather" - - -rule data_download_obs_from_peakweather: - output: - root=directory(output_peakweather_root), - localrule: True - run: - from peakweather.dataset import PeakWeatherDataset - - # Download the data from Huggingface - ds = PeakWeatherDataset(root=output.root) - - # Grid-definition files required by earthkit/eckit to decode the ICON-CH grids. 
# Automatic download/caching of these is currently broken (see README), so we # fetch them into eckit's default geo grid cache under $HOME, where it finds diff --git a/workflow/rules/plot.smk b/workflow/rules/plot.smk index 91ff1080..b855662e 100644 --- a/workflow/rules/plot.smk +++ b/workflow/rules/plot.smk @@ -27,8 +27,7 @@ rule plot_meteogram: input: script="workflow/scripts/plot_meteogram.py", inference_okfile=rules.inference_execute.output.okfile, - truth=config["truth"]["root"], - peakweather_dir=rules.data_download_obs_from_peakweather.output.root, + truth_dep=truth_file_dep, eckit_grids=rules.data_download_eckit_geo_grids.output, output: expand( @@ -44,6 +43,7 @@ rule plot_meteogram: runtime="60m", params: ana_label=lambda wc: config["truth"]["label"], + truth_root=config["truth"]["root"], fcst_grib=lambda wc: ( Path(OUT_ROOT) / f"data/runs/{wc.run_id}/{wc.init_time}/grib" ).resolve(), @@ -71,9 +71,8 @@ rule plot_meteogram: --forecast {params.fcst_grib:q} --forecast_steps {params.fcst_steps:q} --forecast_label {params.fcst_label:q} - --analysis {input.truth:q} + --analysis {params.truth_root:q} --analysis_label {params.ana_label:q} - --peakweather {input.peakweather_dir:q} --date {wildcards.init_time:q} --outdir {params.outdir:q} --param {wildcards.param:q} diff --git a/workflow/rules/verification.smk b/workflow/rules/verification.smk index 4bae88e2..07dff0cf 100644 --- a/workflow/rules/verification.smk +++ b/workflow/rules/verification.smk @@ -15,7 +15,7 @@ rule verification_metrics_baseline: "src/data_input/__init__.py", script="workflow/scripts/verification_metrics.py", forecast=lambda wc: BASELINE_CONFIGS[wc.baseline_id]["root"], - truth=config["truth"]["root"], + truth_dep=truth_file_dep, eckit_grids=rules.data_download_eckit_geo_grids.output, output: OUT_ROOT / f"data/baselines/{{baseline_id}}/{{init_time}}/verif_{TRUTH_HASH}.nc", @@ -29,6 +29,7 @@ rule verification_metrics_baseline: baseline_label=lambda wc: 
BASELINE_CONFIGS[wc.baseline_id].get("label"), baseline_steps=lambda wc: BASELINE_CONFIGS[wc.baseline_id]["steps"], member=lambda wc: BASELINE_CONFIGS[wc.baseline_id].get("member", "000"), + truth=config["truth"]["root"], truth_label=config["truth"]["label"], regions=REGIONS, experiment_params=",".join(EXPERIMENT_PARAMS), @@ -38,7 +39,7 @@ rule verification_metrics_baseline: export ECCODES_DEFINITION_PATH=$(realpath .venv/share/eccodes-cosmo-resources/definitions) uv run {input.script} \ --forecast {input.forecast} \ - --truth {input.truth} \ + --truth {params.truth} \ --reftime {wildcards.init_time} \ --steps "{params.baseline_steps}" \ --label "{params.baseline_label}" \ @@ -64,7 +65,7 @@ rule verification_metrics: "src/data_input/__init__.py", script="workflow/scripts/verification_metrics.py", inference_okfile=rules.inference_execute.output.okfile, - truth=config["truth"]["root"], + truth_dep=truth_file_dep, eckit_grids=rules.data_download_eckit_geo_grids.output, output: OUT_ROOT / f"data/runs/{{run_id}}/{{init_time}}/verif_{TRUTH_HASH}.nc", @@ -80,6 +81,7 @@ rule verification_metrics: params: fcst_label=lambda wc: RUN_CONFIGS[wc.run_id].get("label"), fcst_steps=lambda wc: RUN_CONFIGS[wc.run_id]["steps"], + truth=config["truth"]["root"], truth_label=config["truth"]["label"], regions=REGIONS, grib_out_dir=lambda wc: ( @@ -92,7 +94,7 @@ rule verification_metrics: export ECCODES_DEFINITION_PATH=$(realpath .venv/share/eccodes-cosmo-resources/definitions) uv run {input.script} \ --forecast {params.grib_out_dir} \ - --truth {input.truth} \ + --truth {params.truth} \ --reftime {wildcards.init_time} \ --steps "{params.fcst_steps}" \ --label "{params.fcst_label}" \ diff --git a/workflow/scripts/plot_meteogram.py b/workflow/scripts/plot_meteogram.py index 8e7ac30e..43248d73 100644 --- a/workflow/scripts/plot_meteogram.py +++ b/workflow/scripts/plot_meteogram.py @@ -4,13 +4,14 @@ from pathlib import Path import matplotlib.pyplot as plt -from peakweather import 
PeakWeatherDataset +import xarray as xr from data_input import ( parse_steps, load_forecast_data, load_truth_data, ) +from data_input import jretrieve as jr from verification.spatial import map_forecast_to_truth LOG = logging.getLogger(__name__) @@ -96,9 +97,6 @@ def main(): default="truth", help="Label for analysis line in plot legend.", ) - parser.add_argument( - "--peakweather", type=str, default=None, help="Path to PeakWeather dataset" - ) parser.add_argument("--date", type=str, default=None, help="reference datetime") parser.add_argument("--outdir", type=str, help="output directory") parser.add_argument("--param", type=str, help="parameter") @@ -124,7 +122,6 @@ def main(): "Mismatched baseline arguments: --baseline and --baseline_label " "must be provided the same number of times." ) - peakweather_dir = Path(args.peakweather) init_time = datetime.strptime(args.date, "%Y%m%d%H%M") outdir = Path(args.outdir) outdir.mkdir(parents=True, exist_ok=True) @@ -145,6 +142,28 @@ def main(): else: paramlist = [param] + # Load station metadata from DWH + LOG.info("Fetching station metadata from jretrieve (SwissMetNet catalog)") + _jr_stations, _jr_stage, _jr_seq_type = jr.parse_selection("jretrievedwh:1,2") + _catalog = jr.StationCatalog.from_meta( + jr.fetch_meta( + stations=_jr_stations, + params=["rre150h0"], + seq_type=_jr_seq_type, + stage=_jr_stage, + ) + ) + catalog_lookup = { + abbr: (lat, lon) + for abbr, lat, lon in zip( + _catalog.nat_abbr, _catalog.latitude, _catalog.longitude + ) + } + + LOG.info("Loading analysis data from %s", analysis_root) + analysis_ds = load_truth_data(analysis_root, init_time, forecast_steps, paramlist) + analysis_ds = preprocess_ds(analysis_ds, param) + # Load gridded data once — shared across all station plots LOG.info("Loading forecast data from %s", forecast_grib_dir) forecast_ds = load_forecast_data( @@ -152,11 +171,6 @@ def main(): ) forecast_ds = preprocess_ds(forecast_ds, param) - steps = [int(s) for s in 
forecast_ds["step"].dt.total_seconds().values / 3600] - LOG.info("Loading analysis data from %s", analysis_root) - analysis_ds = load_truth_data(analysis_root, init_time, steps, paramlist) - analysis_ds = preprocess_ds(analysis_ds, param) - baseline_ds_list = [] for root, step, label in zip(baseline_roots, baseline_steps, baseline_labels): LOG.info("Loading baseline '%s' from %s", label, root) @@ -164,12 +178,6 @@ def main(): preprocess_ds(load_forecast_data(root, init_time, step, paramlist), param) ) - # Load station metadata once - LOG.info("Loading station metadata from %s", peakweather_dir) - peakweather = PeakWeatherDataset(root=peakweather_dir) - stations_table = peakweather.stations_table - stations_table.index.names = ["values"] - param2plot = forecast_ds[param].attrs.get("parameter", {}) short = param2plot.get("shortName", "") units = param2plot.get("units", "") @@ -183,9 +191,14 @@ def main(): stations.index(station) + 1, len(stations), ) - station_ds = stations_table.to_xarray().sel(values=[station]) - station_ds = station_ds.set_coords(("latitude", "longitude", "station_name")) - station_ds = station_ds.drop_vars(list(station_ds.data_vars)) + lat, lon = catalog_lookup[station] + station_ds = xr.Dataset( + coords={ + "values": [station], + "latitude": ("values", [lat]), + "longitude": ("values", [lon]), + } + ) forecast_station_ds = map_forecast_to_truth(forecast_ds, station_ds) analysis_station_ds = map_forecast_to_truth(analysis_ds, station_ds)