diff --git a/.jretrievedwh-conf.prod.py b/.jretrievedwh-conf.prod.py new file mode 100644 index 00000000..abafe594 --- /dev/null +++ b/.jretrievedwh-conf.prod.py @@ -0,0 +1,48 @@ +import base64 +import json +import os +import urllib.request +from pathlib import Path + + +def _read_dotenv(path): + result = {} + try: + with open(path) as f: + for line in f: + line = line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + key, _, value = line.partition("=") + result[key.strip()] = value.strip().strip('"').strip("'") + except FileNotFoundError: + pass + return result + + +_client_id = os.environ.get("JRETRIEVE_CLIENT_ID") +_client_secret = os.environ.get("JRETRIEVE_CLIENT_SECRET") +if not _client_id or not _client_secret: + _dotenv = _read_dotenv(Path(os.environ.get("JRETRIEVE_CONF_DIR", ".")) / ".env") + _client_id = _client_id or _dotenv.get("JRETRIEVE_CLIENT_ID") + _client_secret = _client_secret or _dotenv.get("JRETRIEVE_CLIENT_SECRET") +if not _client_id or not _client_secret: + raise RuntimeError( + "jretrieve credentials not found. Set JRETRIEVE_CLIENT_ID and " + "JRETRIEVE_CLIENT_SECRET in the environment or in a .env file next " + "to this script." 
+ ) + +jretrieve_url = "https://service.meteoswiss.ch/jretrieve/api/v1" +with urllib.request.urlopen( + urllib.request.Request( + method="POST", + url="https://service.meteoswiss.ch/auth/realms/meteoswiss.ch/protocol/openid-connect/token", + data=b"grant_type=client_credentials", + headers={ + b"Authorization": b"Basic " + + base64.b64encode(f"{_client_id}:{_client_secret}".encode()) + }, + ) +) as f: + auth_header = "Bearer " + json.loads(f.read().decode())["access_token"] diff --git a/README.md b/README.md index 0d9c9a06..a764eba3 100644 --- a/README.md +++ b/README.md @@ -145,6 +145,31 @@ You can then run it with: evalml experiment path/to/experiment/config.yaml --report ``` +### Truth sources + +The `truth.root` value selects how the ground truth is loaded: + +- **Analysis Zarr** — a path ending in `.zarr` (anemoi analysis dataset). +- **DWH / jretrievedwh** — a `jretrievedwh:` marker string fetching surface observations + (e.g. SMN) live from the MeteoSwiss data warehouse. Variables are mapped to + ICON names in SI units (temperatures in K, pressure in Pa, precipitation as the hourly + sum); wind `U_10M`/`V_10M` are derived from speed + direction. + + Marker syntax (station selection is required; pick one of group/locations/bbox): + + ```yaml + truth: + label: SwissMetNet + root: jretrievedwh:1,2 # stn_group_id (default) + # root: jretrievedwh:locations=ARO,KLO,LUG # explicit nat_abbr list + # root: jretrievedwh:bbox=45.8,47.8,5.9,10.5 # minlat,maxlat,minlon,maxlon + ``` + + **Prerequisites:** `jretrievedwh.py` must be on `$PATH` (falls back to + `/oprusers/osm/opr.inn/bin/jretrievedwh.py`) and DWH credentials must be + available — see [Credentials setup](#credentials-setup). No data is + pre-downloaded — the obs are queried at verification time. + ## Installation @@ -175,6 +200,25 @@ valid for 30 days. Every training or evaluation run within this period automatic extends the token by another 30 days. 
It’s good practice to run the login command before executing the workflow to ensure your token is still valid. +### DWH / jretrieve credentials + +To use `jretrievedwh:` as a truth source, provide a client ID and secret for +the MeteoSwiss service account. Set them as environment variables: + +```bash +export JRETRIEVE_CLIENT_ID=your-client-id +export JRETRIEVE_CLIENT_SECRET=your-client-secret +``` + +or place them in a `.env` file at the project root (next to `.jretrievedwh-conf.prod.py`): + +``` +JRETRIEVE_CLIENT_ID=your-client-id +JRETRIEVE_CLIENT_SECRET=your-client-secret +``` + +The `.env` file is listed in `.gitignore` and is never committed. + ## Workspace setup By default, data produced by the workflow will be stored under `output/` in your working directory. diff --git a/config/varda-single-1.0.yaml b/config/varda-single-1.0.yaml index b9515033..9bbacab7 100644 --- a/config/varda-single-1.0.yaml +++ b/config/varda-single-1.0.yaml @@ -37,7 +37,12 @@ runs: truth: label: SwissMetNet - root: output/data/observations/peakweather + root: jretrievedwh:1,2 + # Truth is taken from SwissMetNet observations in the DWH via jretrievedwh + # (requires jretrievedwh.py on $PATH and DWH credentials, see README). + # Other selectors: root: jretrievedwh:locations=ARO,KLO,LUG + # root: jretrievedwh:bbox=45.8,47.8,5.9,10.5 + # Only the prod DWH stage is supported (;stage=devt is rejected). experiment: params: diff --git a/pyproject.toml b/pyproject.toml index d3279e91..d8552f6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,6 @@ dependencies = [ "pyproj>=3.7.2", "marimo>=0.23.3", "geopandas>=0.14.0", - "peakweather", "pyzmq>=27.1.0", "scores>=2.0.0", "eccodes>=2.44,<2.48", @@ -60,5 +59,4 @@ packages = [ ] [tool.uv.sources] -peakweather = { git = "https://github.com/MeteoSwiss/PeakWeather.git" } eccodes-cosmo-resources-python = { git = "https://github.com/MeteoSwiss/eccodes-cosmo-resources-python" } diff --git a/src/data_input/__init__.py b/src/data_input/__init__.py index
6c14005a..8b6b9790 100644 --- a/src/data_input/__init__.py +++ b/src/data_input/__init__.py @@ -5,6 +5,7 @@ import earthkit.data as ekd import numpy as np +import pandas as pd import xarray as xr from pyproj import Transformer @@ -18,17 +19,33 @@ ZERO_KELVIN = -273.15 # °C +GRAVITY = 9.80665 # m/s² — standard gravity used to convert FIS geopotential to metres -def _select_valid_times(ds, times: np.datetime64): +ICON_CH1_GRID_NC = Path( + "/scratch/mch/jenkins/icon/pool/data/ICON/mch/grids/icon-1" + "/external_parameter_icon_grid_0001_R19B08_mch.nc" +) +ICON_CH2_GRID_NC = Path( + "/scratch/mch/jenkins/icon/pool/data/ICON/mch/grids/icon-2" + "/external_parameter_icon_grid_0002_R19B07_mch.nc" +) + + +def _select_valid_times(ds, times: np.datetime64, strict: bool = False): # (handle special case where some valid times are not in the dataset, e.g. at the end) times_np = np.asarray(times, dtype="datetime64[ns]") times_included = np.isin(times_np, ds.time.values) if times_included.all(): return ds.sel(time=times_np) elif times_included.any(): + missing = times_np[~times_included] + if strict: + raise ValueError( + f"Some valid times are not included in the dataset:\n{missing}" + ) LOG.warning( "Some valid times are not included in the dataset: \n%s", - times_np[~times_included], + missing, ) return ds.sel(time=times_np[times_included]) else: @@ -48,6 +65,35 @@ def parse_steps(steps: str) -> list[int]: return list(range(start, end + 1, step)) +def _load_icon_topography(grid_nc: Path) -> np.ndarray: + """Return topography_c [m] from an ICON external parameter file.""" + with xr.open_dataset(grid_nc) as ds: + return ds["topography_c"].values.astype(np.float32) + + +def _try_assign_model_elevation(ds: xr.Dataset) -> xr.Dataset: + """Attempt to attach model_elevation from a known ICON grid NC file. + + Matches by comparing the size of the ``values`` dimension to the number of + cells in each candidate grid file. 
Logs a warning and returns the dataset unchanged + when no match is found (e.g. non-ICON or custom grids). + """ + if "values" not in ds.dims: + return ds + n = ds.sizes["values"] + for grid_nc in (ICON_CH1_GRID_NC, ICON_CH2_GRID_NC): + if not grid_nc.exists(): + continue + topo = _load_icon_topography(grid_nc) + if len(topo) == n: + LOG.info("Assigned model_elevation from %s (%d cells)", grid_nc.name, n) + return ds.assign_coords(model_elevation=("values", topo)) + LOG.warning( + "Could not assign model_elevation: no ICON grid NC file matches values=%d", n + ) + return ds + + def load_analysis_data_from_zarr( root: Path, reftime: datetime, steps: list[int], params: list[str] ) -> xr.Dataset: @@ -64,6 +110,7 @@ def load_analysis_data_from_zarr( "PS": "sp", "PMSL": "msl", "TOT_PREC": "tp", + "FIS": "FIS", } tot_prec_string = "TOT_PREC_6H" if min(np.diff(steps)) == 6 else "TOT_PREC_1H" PARAMS_MAP_COSMO1 = { @@ -79,8 +126,12 @@ def load_analysis_data_from_zarr( # set 'variables' attr as dimension coordinate ds = ds.assign_coords({"variable": ds.attrs["variables"]}) + # Always include FIS so we can derive an elevation coordinate below + load_params = list(dict.fromkeys(params + ["FIS"])) # select variables and valid time, squeeze ensemble dimension - ds = ds.sel(variable=[PARAMS_MAP[p] for p in params]).squeeze("ensemble", drop=True) + ds = ds.sel(variable=[PARAMS_MAP[p] for p in load_params]).squeeze( + "ensemble", drop=True + ) # recover original 2D shape if len(ds.attrs["field_shape"]) == 2: @@ -111,6 +162,14 @@ def load_analysis_data_from_zarr( if "cell" in ds.dims: ds = ds.rename({"cell": "values"}) + # Derive elevation from FIS (surface geopotential, m²/s²) and assign as coordinate. + # FIS is constant in time, so drop the time dimension to get a purely spatial coord.
+ if "FIS" in ds: + elevation = ds["FIS"].isel(time=0, drop=True) / GRAVITY + ds = ds.assign_coords(elevation=elevation) + if "FIS" not in params: + ds = ds.drop_vars("FIS") + times = np.datetime64(reftime) + np.asarray(steps, dtype="timedelta64[h]") return _select_valid_times(ds, times) @@ -264,71 +323,135 @@ def load_forecast_data_from_grib(files: list[Path], params: list[str]) -> xr.Dat return ds -def load_obs_data_from_peakweather( - root, reftime: datetime, steps: list[int], params: list[str], freq: str = "1h" -) -> xr.Dataset: - """Load PeakWeather station observations into an xarray Dataset. +DWH_PARAM_MAP = { + "T_2M": "tre200s0", + "TD_2M": "tde200s0", + "PS": "prestas0", + "PMSL": "pp0qffs0", + "TOT_PREC": "rre150h0", + "FF_10M": "fkl010z0", + "SP_10M": "fkl010z0", + "DD_10M": "dkl010z0", + "VMAX_10M": "fkl010z1", +} +DWH_WIND_SPEED = "fkl010z0" +DWH_WIND_DIR = "dkl010z0" +DWH_CELSIUS_TO_KELVIN = {"tre200s0", "tde200s0"} +DWH_HPA_TO_PA = {"prestas0", "pp0qffs0"} + + +def _jretrieve_df_to_xarray(df, short_names, catalog) -> xr.Dataset: + """Pivot long-form jretrieve obs into a (time, values) cube aligned to the + catalog, NaN-filled for missing cells.""" + station_to_idx = {sid: i for i, sid in enumerate(catalog.station_id)} + if df.empty: + time_index = pd.DatetimeIndex([]) + else: + df = df.copy() + df["time"] = pd.to_datetime(df["termin"].astype(str), format="%Y%m%d%H%M%S") + time_index = pd.DatetimeIndex(sorted(df["time"].unique())) + n_t, n_s = len(time_index), catalog.n + coords = { + "time": ("time", time_index.values.astype("datetime64[ns]")), + "values": ("values", catalog.nat_abbr), + "latitude": ("values", catalog.latitude), + "longitude": ("values", catalog.longitude), + "elevation": ("values", catalog.elevation), + } + data_vars: dict[str, tuple] = {} + if df.empty: + for p in short_names: + data_vars[p] = (("time", "values"), np.full((n_t, n_s), np.nan, np.float32)) + else: + time_to_idx = {t: i for i, t in enumerate(time_index)} + df["_si"] 
= df["station"].map(station_to_idx) + df["_ti"] = df["time"].map(time_to_idx) + df = df.dropna(subset=["_si", "_ti"]) + df["_si"] = df["_si"].astype(int) + df["_ti"] = df["_ti"].astype(int) + for p in short_names: + arr = np.full((n_t, n_s), np.nan, dtype=np.float32) + if p in df.columns: + arr[df["_ti"].to_numpy(), df["_si"].to_numpy()] = df[p].to_numpy( + dtype=np.float32 + ) + data_vars[p] = (("time", "values"), arr) + return xr.Dataset(data_vars=data_vars, coords=coords) - Returns a Dataset with dimensions `time` and `values`, values coordinates - (`lat`, `lon`), and variables renamed to ICON parameter names. - Temperatures are converted to Kelvin when present. + +def load_obs_data_from_jretrieve( + root, reftime: datetime, steps: list[int], params: list[str] +) -> xr.Dataset: + """Load SwissMetNet (SMN) surface observations from the DWH via jretrievedwh. + + ``root`` is a marker string selecting stations, e.g. ``jretrievedwh:SwissMetNet`` + (default group), ``jretrievedwh:locations=ARO,KLO``, or + ``jretrievedwh:bbox=45.8,47.8,5.9,10.5`` (optionally ``;stage=devt``). Returns + a Dataset with dims (time, values), values=nat_abbr, latitude/longitude coords, + variables renamed to ICON names in SI units (T/TD in Kelvin, pressure in Pa). + Only the requested hourly valid times are kept. 
""" - from peakweather.dataset import PeakWeatherDataset - - param_names = { - "temperature": "T_2M", - "wind_u": "U_10M", - "wind_v": "V_10M", - "precipitation": "TOT_PREC", - "pressure": "PS", - "wind_gust": "VMAX_10M", - } - param_names = {k: v for k, v in param_names.items() if v in params} + from data_input import jretrieve as jr + + stations, stage, seq_type = jr.parse_selection(root) + jr.check_prerequisites(stage) + + want_uv = "U_10M" in params or "V_10M" in params + short_names: list[str] = [DWH_PARAM_MAP[p] for p in params if p in DWH_PARAM_MAP] + if want_uv: + short_names += [DWH_WIND_SPEED, DWH_WIND_DIR] + short_names = list(dict.fromkeys(short_names)) + if not short_names: + raise ValueError(f"No DWH parameter mapping for requested params: {params}") + start = reftime end = start + timedelta(hours=max(steps)) if len(steps) > 1: - end += timedelta(hours=steps[-1] - steps[-2]) # extend by 1 extra step - years = list(set([start.year, end.year])) - if "wind_u" in param_names or "wind_v" in param_names: - compute_uv = True - else: - compute_uv = False - pw = PeakWeatherDataset(root=root, years=years, freq=freq, compute_uv=compute_uv) - ds, mask = pw.get_observations( - parameters=[k for k in param_names.keys()], - first_date=f"{start:%Y-%m-%d %H:%M}", - last_date=f"{end:%Y-%m-%d %H:%M}", - return_mask=True, - ) - ds = ( - ds.stack(["nat_abbr", "name"], future_stack=True) - .to_xarray() - .to_dataset(dim="name") + end += timedelta(hours=steps[-1] - steps[-2]) + + catalog = jr.StationCatalog.from_meta( + jr.fetch_meta( + stations=stations, params=short_names, seq_type=seq_type, stage=stage + ) ) - mask = ( - mask.stack(["nat_abbr", "name"], future_stack=True) - .to_xarray() - .to_dataset(dim="name") + step_hours = (steps[1] - steps[0]) if len(steps) > 1 else 1 + df = jr.fetch_data( + stations=stations, + params=short_names, + start=start, + end=end, + increment_minutes=step_hours * 60, + seq_type=seq_type, + stage=stage, ) - ds = ds.where(mask) - ds = 
ds.rename({"datetime": "time", "nat_abbr": "values"}) - ds = ds.rename(param_names) - ds = ds.assign_coords(time=ds.indexes["time"].tz_convert("UTC").tz_localize(None)) - ds = ds.assign_coords(values=ds.indexes["values"]) - ds = ds.assign_coords(longitude=("values", pw.stations_table["longitude"])) - ds = ds.assign_coords(latitude=("values", pw.stations_table["latitude"])) - if "T_2M" in ds: - ds["T_2M"] = ds["T_2M"] - ZERO_KELVIN # convert to Kelvin - ds = ds.dropna("values", how="all") + raw = _jretrieve_df_to_xarray(df, short_names, catalog) + + out = xr.Dataset(coords=raw.coords) + for icon, short in DWH_PARAM_MAP.items(): + if icon in params and short in raw: + var = raw[short] + if short in DWH_CELSIUS_TO_KELVIN: + var = var - ZERO_KELVIN + elif short in DWH_HPA_TO_PA: + var = var * 100.0 + out[icon] = var + if want_uv and DWH_WIND_SPEED in raw and DWH_WIND_DIR in raw: + ff = raw[DWH_WIND_SPEED] + dd_rad = np.deg2rad(raw[DWH_WIND_DIR]) + if "U_10M" in params: + out["U_10M"] = -ff * np.sin(dd_rad) + if "V_10M" in params: + out["V_10M"] = -ff * np.cos(dd_rad) + out = out.dropna("values", how="all") times = np.datetime64(reftime) + np.asarray(steps, dtype="timedelta64[h]") - return _select_valid_times(ds, times) + return _select_valid_times(out, times, strict=True) def load_truth_data( root, reftime: datetime, steps: list[int], params: list[str] ) -> xr.Dataset: - """Load truth data from analysis Zarr or PeakWeather observations.""" + """Load truth data from an analysis Zarr dataset or DWH observations via jretrieve.""" if root.suffix == ".zarr": LOG.info("Loading ground truth from an analysis zarr dataset...") truth = load_analysis_data_from_zarr( @@ -342,9 +465,9 @@ def load_truth_data( if "y" in truth.dims and "x" in truth.dims else {"values": -1} ) - elif "peakweather" in str(root): - LOG.info("Loading ground truth from PeakWeather observations...") - truth = load_obs_data_from_peakweather( + elif "jretrieve" in str(root): + LOG.info("Loading ground truth 
from JRetrieve...") + truth = load_obs_data_from_jretrieve( root=root, reftime=reftime, steps=steps, @@ -725,13 +848,33 @@ def load_icon_baseline_from_grib( f"No ensemble members could be loaded for {reftime} from {root}" ) LOG.info("Ensemble mean computed over %d members.", n_loaded) - return acc / n_loaded + result = acc / n_loaded else: - return load_forecast_data_from_grib( + result = load_forecast_data_from_grib( files=_collect_icon_archive_files(root, reftime, steps, member_id=member), params=params, ) + # Attach model orography as model_elevation coordinate + if "ICON-CH1-EPS" in root.parts: + grid_nc = ICON_CH1_GRID_NC + elif "ICON-CH2-EPS" in root.parts: + grid_nc = ICON_CH2_GRID_NC + else: + grid_nc = None + if grid_nc is not None and grid_nc.exists() and "values" in result.dims: + topo = _load_icon_topography(grid_nc) + if result.sizes["values"] == len(topo): + result = result.assign_coords(model_elevation=("values", topo)) + else: + LOG.warning( + "model_elevation not assigned: values=%d but %s has %d cells", + result.sizes["values"], + grid_nc.name, + len(topo), + ) + return result + def load_forecast_data( root, reftime: datetime, steps: list[int], params: list[str], member: str = "000" @@ -747,11 +890,12 @@ def load_forecast_data( root = Path(root) if any(root.glob("*.grib")): LOG.info("Loading forecasts from GRIB files...") - return load_forecast_data_from_grib( + ds = load_forecast_data_from_grib( # NOTE: root is already for a specific reftime files=_collect_ml_grib_files(root, steps), params=params, ) + return _try_assign_model_elevation(ds) if "INCA" in root.parts: LOG.info("Loading INCA baseline from NetCDF files...") return load_INCA_baseline_from_netcdf(root, reftime, steps, params) diff --git a/src/data_input/jretrieve.py b/src/data_input/jretrieve.py new file mode 100644 index 00000000..cbf00ad9 --- /dev/null +++ b/src/data_input/jretrieve.py @@ -0,0 +1,301 @@ +"""Subprocess wrapper around ``jretrievedwh.py`` for retrieving SwissMetNet 
+(SMN) surface observations from the MeteoSwiss data warehouse (DWH). + +Requires ``jretrievedwh.py`` (resolved via $PATH or the hardcoded +fallback path) and credentials set as ``JRETRIEVE_CLIENT_ID`` / +``JRETRIEVE_CLIENT_SECRET`` in the environment or in a ``.env`` file in the +project root. +""" + +from __future__ import annotations + +import logging +import os +import shutil +import subprocess +import time +from dataclasses import dataclass +from datetime import datetime +from io import StringIO +from pathlib import Path +from typing import Any, Sequence + +import numpy as np +import pandas as pd + +LOG = logging.getLogger(__name__) + +BINARY_NAME = "jretrievedwh.py" +HARDCODED_BINARY_PATH = "/oprusers/osm/opr.inn/bin/jretrievedwh.py" +DEFAULT_META_FIELDS: tuple[str, ...] = ("lat", "lon", "elev", "name", "nat_abbr") +DEFAULT_GROUP = "SwissMetNet" +CATALOG_TIME_RANGE_START = datetime(1900, 1, 1) +CATALOG_TIME_RANGE_END = datetime(2100, 12, 31, 23, 59) + + +class JretrieveError(RuntimeError): + """Raised when jretrievedwh.py fails or returns malformed output.""" + + +def _resolve_binary() -> str: + path = shutil.which(BINARY_NAME) + if path is not None: + return path + if os.path.isfile(HARDCODED_BINARY_PATH): + return HARDCODED_BINARY_PATH + raise JretrieveError( + f"{BINARY_NAME} not found on $PATH or at {HARDCODED_BINARY_PATH}." + ) + + +def _build_env(stage: str) -> dict[str, str]: + if stage != "prod": + raise ValueError(f"Only 'prod' stage is supported, got {stage!r}.") + conf_dir = str(Path(__file__).parents[2]) # project root + conf_name = ".jretrievedwh-conf.prod.py" + env = os.environ.copy() + env["JRETRIEVE_CONF_DIR"] = conf_dir + env["JRETRIEVE_CONF_NAME"] = conf_name + return env + + +def check_prerequisites(stage: str = "prod") -> None: + """Fail-fast validation that the jretrievedwh environment is usable. + + Checks the stage is supported, the binary is reachable and the conf file exists.
Raises a + single ``JretrieveError`` listing *all* problems found, so a misconfigured + environment is reported up front rather than hours into a verification job. + """ + problems: list[str] = [] + if stage != "prod": + problems.append(f"Only 'prod' stage is supported, got {stage!r}.") + try: + _resolve_binary() + except JretrieveError as e: + problems.append(str(e)) + conf_path = Path(__file__).parents[2] / ".jretrievedwh-conf.prod.py" + if not conf_path.is_file(): + problems.append(f"jretrieve conf file not found: {conf_path}") + if problems: + raise JretrieveError( + "jretrievedwh prerequisites not met:\n - " + "\n - ".join(problems) + ) + + +def _fmt_time(dt: datetime) -> str: + return dt.strftime("%Y%m%d%H%M") + + +def _stations_to_argv(stations: dict[str, Any]) -> list[str]: + """Translate a station selection dict to jretrieve CLI args. + + Exactly one of {group, locations, bbox} must be set. + """ + keys = [k for k in ("group", "locations", "bbox") if stations.get(k) is not None] + if len(keys) != 1: + raise ValueError( + f"stations must specify exactly one of group/locations/bbox, got {keys}" + ) + key = keys[0] + val = stations[key] + if key == "group": + return ["-a", f"stn_group_id,{val}"] + if key == "locations": + if isinstance(val, str): + val = [v for v in val.split(",") if v] + if not isinstance(val, Sequence): + raise ValueError("stations.locations must be a list of nat_abbr strings.") + return ["-i", "nat_abbr," + ",".join(str(v) for v in val)] + if key == "bbox": + if isinstance(val, str): + val = [v for v in val.split(",") if v] + if len(val) != 4: + raise ValueError("stations.bbox must be [minlat, maxlat, minlon, maxlon].") + return ["-l", ",".join(str(v) for v in val)] + raise AssertionError("unreachable") + + +def parse_selection(root: Any) -> tuple[dict[str, Any], str, str]: + """Parse a truth-root marker into (stations, stage, seq_type). 
+ + Examples (slash-free so they survive ``Path()`` normalisation): + ``jretrievedwh:SwissMetNet`` -> group + ``jretrievedwh:group=SwissMetNet;stage=devt`` + ``jretrievedwh:locations=ARO,KLO`` + ``jretrievedwh:bbox=45.8,47.8,5.9,10.5`` + """ + _, _, rest = str(root).partition(":") + rest = rest.strip() + stations: dict[str, Any] = {} + stage = "prod" + seq_type = "surface" + for i, part in enumerate([p for p in rest.split(";") if p]): + if "=" not in part: + if i == 0: + stations["group"] = part + continue + raise ValueError(f"Invalid jretrieve selector fragment: {part!r}") + key, _, value = part.partition("=") + key, value = key.strip(), value.strip() + if key in ("group", "locations", "bbox"): + stations[key] = value + elif key == "stage": + stage = value + elif key == "seq_type": + seq_type = value + else: + raise ValueError(f"Unknown jretrieve selector key: {key!r}") + if not stations: + stations = {"group": DEFAULT_GROUP} + return stations, stage, seq_type + + +def _run(argv: list[str], env: dict[str, str], timeout_s: int) -> str: + try: + proc = subprocess.run( + argv, + env=env, + capture_output=True, + text=True, + timeout=timeout_s, + check=False, + ) + except subprocess.TimeoutExpired as e: + raise JretrieveError( + f"jretrieve timed out after {timeout_s}s: {' '.join(argv)}" + ) from e + if proc.returncode != 0: + raise JretrieveError( + f"jretrieve exited with {proc.returncode}\nargv: {argv}\n" + f"stderr: {proc.stderr.strip()}\nstdout (head): {proc.stdout[:500]}" + ) + if proc.stdout.lstrip().startswith("ERROR"): + raise JretrieveError(f"jretrieve returned error: {proc.stdout.strip()[:500]}") + return proc.stdout + + +def _run_with_retry(argv, env, timeout_s, attempts=3) -> str: + last_err: Exception | None = None + for attempt in range(1, attempts + 1): + try: + return _run(argv, env=env, timeout_s=timeout_s) + except JretrieveError as e: + last_err = e + if attempt == attempts: + break + backoff = 2**attempt + LOG.warning( + "jretrieve attempt %d/%d 
failed (%s); retrying in %ds", + attempt, + attempts, + e, + backoff, + ) + time.sleep(backoff) + assert last_err is not None + raise last_err + + +def _parse_csv(csv_text: str) -> pd.DataFrame: + csv_text = csv_text.strip() + if not csv_text: + return pd.DataFrame() + return pd.read_csv(StringIO(csv_text), sep=";") + + +def fetch_meta( + *, + stations, + params, + seq_type="surface", + stage="prod", + meta_fields=DEFAULT_META_FIELDS, + timeout_s=300, +) -> pd.DataFrame: + """Fetch the station catalog (rows per station x parameter x period) over a + fixed wide time range so the response is deterministic.""" + if not params: + raise ValueError("params must be non-empty.") + argv = [ + _resolve_binary(), + "-s", + seq_type, + "-n", + ",".join(params), + "-t", + f"{_fmt_time(CATALOG_TIME_RANGE_START)},{_fmt_time(CATALOG_TIME_RANGE_END)}", + "--meta-info", + ",".join(meta_fields), + "--format", + "csv", + *_stations_to_argv(stations), + ] + LOG.info("jretrieve meta: %s", " ".join(argv)) + df = _parse_csv(_run_with_retry(argv, env=_build_env(stage), timeout_s=timeout_s)) + if df.empty: + raise JretrieveError("jretrieve meta-info returned no rows.") + return df + + +def fetch_data( + *, + stations, + params, + start, + end, + increment_minutes=60, + seq_type="surface", + stage="prod", + timeout_s=600, +) -> pd.DataFrame: + """Fetch observation data; columns: station (int), termin (YYYYMMDDhhmmss), + one column per requested short name.""" + if not params: + raise ValueError("params must be non-empty.") + argv = [ + _resolve_binary(), + "-s", + seq_type, + "-n", + ",".join(params), + "-t", + f"{_fmt_time(start)},{_fmt_time(end)},{int(increment_minutes)}", + "--format", + "csv", + *_stations_to_argv(stations), + ] + LOG.info("jretrieve data: %s", " ".join(argv)) + return _parse_csv(_run_with_retry(argv, env=_build_env(stage), timeout_s=timeout_s)) + + +@dataclass(frozen=True) +class StationCatalog: + """Stable, nat_abbr-sorted station catalog used as the cell axis.""" + + 
nat_abbr: np.ndarray + station_id: np.ndarray + latitude: np.ndarray + longitude: np.ndarray + elevation: np.ndarray + name: np.ndarray + + @property + def n(self) -> int: + return len(self.nat_abbr) + + @classmethod + def from_meta(cls, meta: pd.DataFrame) -> "StationCatalog": + per_station = ( + meta.sort_values(["nat_abbr", "parameter", "op_since"], kind="stable") + .drop_duplicates(subset=["station"], keep="first") + .sort_values("nat_abbr", kind="stable") + .reset_index(drop=True) + ) + return cls( + nat_abbr=per_station["nat_abbr"].to_numpy(dtype=object), + station_id=per_station["station"].to_numpy(dtype=np.int64), + latitude=per_station["latitude"].to_numpy(dtype=np.float64), + longitude=per_station["longitude"].to_numpy(dtype=np.float64), + elevation=per_station["elev"].to_numpy(dtype=np.float64), + name=per_station["stn_name"].to_numpy(dtype=object), + ) diff --git a/src/evalml/config.py b/src/evalml/config.py index bbe94543..bf0db283 100644 --- a/src/evalml/config.py +++ b/src/evalml/config.py @@ -372,6 +372,10 @@ class ExperimentConfig(BaseModel): "Each dict maps operator keys (gt, ge, lt, le, eq, ne) to lists of threshold values." ), ) + lapse_rate_correction: bool = Field( + default=True, + description=("Apply standard-atmosphere lapse-rate correction to T_2M."), + ) dashboard: Dashboard = Field( ..., description="Settings for the experiment dashboard.", diff --git a/src/verification/__init__.py b/src/verification/__init__.py index 941e3433..4fa482de 100644 --- a/src/verification/__init__.py +++ b/src/verification/__init__.py @@ -21,6 +21,39 @@ LOG = logging.getLogger(__name__) +_T_LAPSE_RATE = 0.0065 # K/m — ICAO standard atmosphere +_LAPSE_RATE_PARAMS: dict[str, float] = {"T_2M": _T_LAPSE_RATE} + + +def apply_lapse_rate_correction( + fcst: xr.Dataset, + obs: xr.Dataset, + params: list[str], +) -> xr.Dataset: + """Correct T_2M in *fcst* to the elevation of *obs*.
+ + Requires *fcst* to carry a ``model_elevation`` coordinate (metres, from the + ICON external parameter file) and *obs* to carry an ``elevation`` coordinate + (metres, from station metadata or FIS geopotential). The function silently + returns *fcst* unchanged when either coordinate is absent so that pipelines + without elevation data are not broken. + + Formula applied per parameter: + T_corrected = T_forecast − Γ × (elevation_obs − model_elevation_fcst) + + A positive height difference (obs higher than forecast grid cell) lowers the + corrected value, consistent with the standard atmospheric lapse rate. + """ + if "model_elevation" not in fcst.coords or "elevation" not in obs.coords: + LOG.debug("Skipping lapse-rate correction: elevation coordinates missing.") + return fcst + dz = obs["elevation"] - fcst["model_elevation"] + fcst = fcst.copy() + for param, rate in _LAPSE_RATE_PARAMS.items(): + if param in params and param in fcst.data_vars: + fcst[param] = fcst[param] - rate * dz + return fcst + class AggregationMasks(abc.ABC): @abc.abstractmethod diff --git a/tests/unit/test_jretrieve.py b/tests/unit/test_jretrieve.py new file mode 100644 index 00000000..b63766d9 --- /dev/null +++ b/tests/unit/test_jretrieve.py @@ -0,0 +1,171 @@ +from datetime import datetime +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest + +import data_input +from data_input import jretrieve as jr + + +def test_stations_to_argv_group(): + assert jr._stations_to_argv({"group": "1,2"}) == [ + "-a", + "stn_group_id,1,2", + ] + + +def test_stations_to_argv_locations_from_string(): + assert jr._stations_to_argv({"locations": "ARO,KLO"}) == ["-i", "nat_abbr,ARO,KLO"] + + +def test_stations_to_argv_bbox_from_string(): + assert jr._stations_to_argv({"bbox": "45.8,47.8,5.9,10.5"}) == [ + "-l", + "45.8,47.8,5.9,10.5", + ] + + +def test_stations_to_argv_rejects_ambiguous(): + with pytest.raises(ValueError, match="exactly one"): + jr._stations_to_argv({"group": "x", 
"bbox": "1,2,3,4"}) + + +def test_parse_selection_default_group(): + assert jr.parse_selection("jretrievedwh:") == ( + {"group": "SwissMetNet"}, + "prod", + "surface", + ) + assert jr.parse_selection("jretrievedwh:SwissMetNet") == ( + {"group": "SwissMetNet"}, + "prod", + "surface", + ) + + +def test_parse_selection_keyvalue_and_stage(): + assert jr.parse_selection("jretrievedwh:locations=ARO,KLO;stage=devt") == ( + {"locations": "ARO,KLO"}, + "devt", + "surface", + ) + + +def test_check_prerequisites_ok(monkeypatch, tmp_path): + conf = tmp_path / ".jretrievedwh-conf.prod.py" + conf.write_text("# conf\n") + monkeypatch.setattr(jr.shutil, "which", lambda name: "/opt/bin/jretrievedwh.py") + monkeypatch.setenv("OPR_HOME", str(tmp_path)) + jr.check_prerequisites("prod") # should not raise + + +def test_check_prerequisites_missing_binary(monkeypatch): + monkeypatch.setattr(jr.shutil, "which", lambda name: None) + monkeypatch.setattr(jr.os.path, "isfile", lambda p: False) + with pytest.raises(jr.JretrieveError, match=r"\$PATH"): + jr.check_prerequisites("prod") + + +def test_check_prerequisites_aggregates_all_problems(monkeypatch): + monkeypatch.setattr(jr.shutil, "which", lambda name: None) + monkeypatch.setattr(jr.os.path, "isfile", lambda p: False) + monkeypatch.setattr(Path, "is_file", lambda self: False) + with pytest.raises(jr.JretrieveError) as exc: + jr.check_prerequisites("prod") + msg = str(exc.value) + assert ( + "$PATH" in msg and "conf file not found" in msg + ) # both reported, not just the first + + +def _sample_meta(): + return pd.DataFrame( + { + "station": [2, 1, 1], + "op_since": [19900101000000, 19800101000000, 19800101000000], + "op_till": ["", "", ""], + "parameter": ["tre200s0", "fkl010z0", "tre200s0"], + "latitude": [47.48, 46.79, 46.79], + "longitude": [8.54, 9.68, 9.68], + "elev": [426.0, 1878.0, 1878.0], + "stn_name": ["Zurich", "Arosa", "Arosa"], + "nat_abbr": ["KLO", "ARO", "ARO"], + } + ) + + +def 
test_station_catalog_from_meta_collapses_and_sorts(): + cat = jr.StationCatalog.from_meta(_sample_meta()) + assert cat.n == 2 + assert list(cat.nat_abbr) == ["ARO", "KLO"] # sorted by nat_abbr + assert list(cat.station_id) == [1, 2] + np.testing.assert_allclose(cat.latitude, [46.79, 47.48]) + + +def test_load_obs_data_from_jretrieve(monkeypatch): + meta = pd.DataFrame( + { + "station": [1, 2], + "op_since": [19800101000000, 19900101000000], + "op_till": ["", ""], + "parameter": ["tre200s0", "tre200s0"], + "latitude": [46.79, 47.48], + "longitude": [9.68, 8.54], + "elev": [1878.0, 426.0], + "stn_name": ["Arosa", "Zurich"], + "nat_abbr": ["ARO", "KLO"], + } + ) + data = pd.DataFrame( + { + "station": [1, 1], + "termin": [20250115000000, 20250115010000], + "tre200s0": [10.0, 11.0], # degC + "fkl010z0": [3.0, 4.0], # m/s + "dkl010z0": [0.0, 90.0], # deg + } + ) + # Skip the live environment check so the test runs anywhere (incl. CI, + # which has no jretrievedwh.py / $OPR_HOME). + monkeypatch.setattr(jr, "check_prerequisites", lambda *a, **k: None) + monkeypatch.setattr(jr, "fetch_meta", lambda **kw: meta) + monkeypatch.setattr(jr, "fetch_data", lambda **kw: data) + + ds = data_input.load_obs_data_from_jretrieve( + "jretrievedwh:locations=ARO,KLO", + datetime(2025, 1, 15, 0, 0), + [0, 1], + ["T_2M", "U_10M", "V_10M"], + ) + + assert set(ds.dims) == {"time", "values"} + assert list(ds["values"].values) == ["ARO"] # KLO all-NaN -> dropped + assert set(ds.data_vars) == {"T_2M", "U_10M", "V_10M"} + np.testing.assert_allclose(ds["T_2M"].sel(values="ARO").values, [283.15, 284.15]) + # DD=0 -> U=0, V=-FF ; DD=90 -> U=-FF, V=0 + np.testing.assert_allclose( + ds["U_10M"].sel(values="ARO").values, [0.0, -4.0], atol=1e-5 + ) + np.testing.assert_allclose( + ds["V_10M"].sel(values="ARO").values, [-3.0, 0.0], atol=1e-5 + ) + np.testing.assert_allclose(ds["latitude"].values, [46.79]) + assert "elevation" in ds.coords + np.testing.assert_allclose(ds["elevation"].values, [1878.0]) + + 
+def test_load_truth_data_forwards_root(monkeypatch): + seen = {} + + def fake_loader(root, reftime, steps, params): + seen["root"] = root + return "SENTINEL" + + monkeypatch.setattr(data_input, "load_obs_data_from_jretrieve", fake_loader) + out = data_input.load_truth_data( + Path("jretrievedwh:SwissMetNet"), datetime(2025, 1, 15), [0], ["T_2M"] + ) + assert out == "SENTINEL" + assert "jretrievedwh:SwissMetNet" in str(seen["root"]) diff --git a/tests/unit/test_verification.py b/tests/unit/test_verification.py index 400babe9..b549b640 100644 --- a/tests/unit/test_verification.py +++ b/tests/unit/test_verification.py @@ -8,7 +8,7 @@ sys.path.insert(0, str(Path(__file__).parents[2] / "workflow" / "scripts")) from verification_aggregation import aggregate_results -from verification import decode_metric +from verification import decode_metric, apply_lapse_rate_correction @pytest.mark.parametrize( @@ -97,3 +97,67 @@ def test_aggregate_results_n_samples_by_season_and_init_hour(): assert int(out["n_samples"].sel(season="DJF", init_hour=0)) == 2 assert int(out["n_samples"].sel(season="DJF", init_hour=12)) == 1 assert int(out["n_samples"].sel(season="JJA", init_hour=0)) == 1 + + +# --------------------------------------------------------------------------- +# apply_lapse_rate_correction +# --------------------------------------------------------------------------- + + +def _make_lapse_rate_datasets(fcst_elev, obs_elev, t2m=280.0, td2m=270.0): + """Return (fcst, obs) pair with elevation coordinates and T/TD data vars.""" + n = len(fcst_elev) + fcst = xr.Dataset( + { + "T_2M": (["step", "values"], np.full((3, n), t2m, dtype=np.float32)), + "TD_2M": (["step", "values"], np.full((3, n), td2m, dtype=np.float32)), + }, + coords={"model_elevation": ("values", np.array(fcst_elev, dtype=np.float32))}, + ) + obs = xr.Dataset( + coords={"elevation": ("values", np.array(obs_elev, dtype=np.float32))} + ) + return fcst, obs + + +def test_lapse_rate_correction_temperature(): + # Station 
500 m above forecast grid cell → T should decrease by 0.0065 * 500 = 3.25 K + fcst, obs = _make_lapse_rate_datasets(fcst_elev=[500.0], obs_elev=[1000.0]) + result = apply_lapse_rate_correction(fcst, obs, ["T_2M", "TD_2M"]) + np.testing.assert_allclose(result["T_2M"].values, 280.0 - 0.0065 * 500.0, atol=1e-4) + + +def test_lapse_rate_correction_dewpoint_unchanged(): + # TD_2M is not corrected — only T_2M gets the lapse-rate adjustment + fcst, obs = _make_lapse_rate_datasets(fcst_elev=[500.0], obs_elev=[1000.0]) + result = apply_lapse_rate_correction(fcst, obs, ["T_2M", "TD_2M"]) + np.testing.assert_array_equal(result["TD_2M"].values, fcst["TD_2M"].values) + + +def test_lapse_rate_correction_station_below_grid(): + # Station 300 m below forecast grid → T should increase by 0.0065 * 300 = 1.95 K + fcst, obs = _make_lapse_rate_datasets(fcst_elev=[800.0], obs_elev=[500.0]) + result = apply_lapse_rate_correction(fcst, obs, ["T_2M"]) + np.testing.assert_allclose(result["T_2M"].values, 280.0 + 0.0065 * 300.0, atol=1e-4) + + +def test_lapse_rate_correction_skipped_without_model_elevation(): + fcst, obs = _make_lapse_rate_datasets(fcst_elev=[500.0], obs_elev=[1000.0]) + fcst_no_elev = fcst.drop_vars("model_elevation") + result = apply_lapse_rate_correction(fcst_no_elev, obs, ["T_2M", "TD_2M"]) + np.testing.assert_array_equal(result["T_2M"].values, fcst_no_elev["T_2M"].values) + + +def test_lapse_rate_correction_skipped_without_obs_elevation(): + fcst, obs = _make_lapse_rate_datasets(fcst_elev=[500.0], obs_elev=[1000.0]) + obs_no_elev = obs.drop_vars("elevation") + result = apply_lapse_rate_correction(fcst, obs_no_elev, ["T_2M", "TD_2M"]) + np.testing.assert_array_equal(result["T_2M"].values, fcst["T_2M"].values) + + +def test_lapse_rate_correction_only_requested_params(): + # Pass only T_2M in params — TD_2M should not be corrected + fcst, obs = _make_lapse_rate_datasets(fcst_elev=[500.0], obs_elev=[1000.0]) + result = apply_lapse_rate_correction(fcst, obs, ["T_2M"]) + 
np.testing.assert_allclose(result["T_2M"].values, 280.0 - 0.0065 * 500.0, atol=1e-4) + np.testing.assert_array_equal(result["TD_2M"].values, fcst["TD_2M"].values) diff --git a/uv.lock b/uv.lock index 8a509618..1bb467be 100644 --- a/uv.lock +++ b/uv.lock @@ -20,6 +20,7 @@ dependencies = [ sdist = { url = "https://files.pythonhosted.org/packages/4c/d4/6585f3b6fdb75648bca294664af4becc8aa2fb3fb08f4e4e9fd27e10d773/adjusttext-1.3.0.tar.gz", hash = "sha256:4ab75cd4453af4828876ac3e964f2c49be642ea834f0c1f7449558d5f12cbca1", size = 15724, upload-time = "2024-10-31T16:45:36.101Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/53/1c/8feedd607cc14c5df9aef74fe3af9a99bf660743b842a9b5b1865326b4aa/adjustText-1.3.0-py3-none-any.whl", hash = "sha256:da23d7b24b6db5ffa039bb136bfa556207365e32f48ac74b07ad26dd485bc691", size = 13154, upload-time = "2024-10-31T16:45:35.227Z" }, + { url = "https://files.pythonhosted.org/packages/2d/80/7ad35ee5321a86b842f9e8516c8ae4c86f58db7b40e82ce9759f94517a50/adjusttext-1.3.0-py3-none-any.whl", hash = "sha256:bc6c118cd9d7caf6ae37f9355e51d840a2d7f64b4fb2956b8401de27c5af803b", size = 13264, upload-time = "2026-06-08T16:40:05.041Z" }, ] [[package]] @@ -1269,7 +1270,6 @@ dependencies = [ { name = "marimo" }, { name = "mlflow" }, { name = "netcdf4" }, - { name = "peakweather" }, { name = "pydantic" }, { name = "pyproj" }, { name = "pyzmq" }, @@ -1302,7 +1302,6 @@ requires-dist = [ { name = "marimo", specifier = ">=0.23.3" }, { name = "mlflow", specifier = ">=3.1.1" }, { name = "netcdf4", specifier = ">=1.7.2" }, - { name = "peakweather", git = "https://github.com/MeteoSwiss/PeakWeather.git" }, { name = "pydantic", specifier = ">=2.11.7" }, { name = "pyproj", specifier = ">=3.7.2" }, { name = "pyzmq", specifier = ">=27.1.0" }, @@ -3114,17 +3113,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0d/20/f2b98b18200c304f04f7839732298a786d121dba6b7cd79aa406c8c9000d/pdbufr-0.14.2-py3-none-any.whl", hash = 
"sha256:8d9eb74e65fe1b4b89ffe5e7ad8ee0e854ac2efbed5b64ac17cff287ea86a838", size = 52633, upload-time = "2026-02-26T17:03:45.16Z" }, ] -[[package]] -name = "peakweather" -version = "0.2.2" -source = { git = "https://github.com/MeteoSwiss/PeakWeather.git#65db3aab3dd8ca399de451577f4740a5f2f5c7ac" } -dependencies = [ - { name = "numpy" }, - { name = "pandas" }, - { name = "pyarrow" }, - { name = "tqdm" }, -] - [[package]] name = "pillow" version = "12.2.0" diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk index c75da889..61a7ff24 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -281,12 +281,12 @@ def collect_experiment_participants(): participants = {} for base in BASELINE_CONFIGS.keys(): participants[base] = ( - OUT_ROOT / f"data/baselines/{base}/verif_aggregated_{TRUTH_HASH}.nc" + OUT_ROOT / f"data/baselines/{base}/verif_aggregated_{VERIF_HASH}.nc" ) for exp in RUN_CONFIGS.keys(): if RUN_CONFIGS[exp].get("_is_candidate", False): participants[exp] = ( - OUT_ROOT / f"data/runs/{exp}/verif_aggregated_{TRUTH_HASH}.nc" + OUT_ROOT / f"data/runs/{exp}/verif_aggregated_{VERIF_HASH}.nc" ) return participants @@ -355,7 +355,43 @@ def truth_hash(truth_config: dict) -> str: return generate_json_hash(cfg) +def verif_hash(full_config: dict) -> str: + """Hash of all settings that affect verification outputs. + + Combines the truth source with verification-method settings so that + changing either (e.g. switching lapse_rate_correction on/off) produces + new output paths and unconditionally triggers a rerun. 
+ """ + truth_cfg = { + k: v for k, v in full_config["truth"].items() if k not in TRUTH_HASH_EXCLUDE + } + experiment_verif_cfg = { + "lapse_rate_correction": full_config.get("experiment", {}).get( + "lapse_rate_correction", True + ), + } + return generate_json_hash({"truth": truth_cfg, "verif": experiment_verif_cfg}) + + +def truth_file_dep(_): + """Truth file dependency: a real path for zarr, but a live-query + marker (no input file) for jretrieve.""" + root = config["truth"]["root"] + return [] if "jretrieve" in str(root) else [root] + + +# Fail fast: when the truth source is the live DWH (jretrievedwh), verify its +# prerequisites at workflow-build time so a misconfigured environment is caught +# at launch, before any (expensive) inference job runs. +if "jretrieve" in str(config["truth"]["root"]): + from data_input.jretrieve import check_prerequisites, parse_selection + + _, _jretrieve_stage, _ = parse_selection(config["truth"]["root"]) + check_prerequisites(_jretrieve_stage) + + TRUTH_HASH = truth_hash(config["truth"]) +VERIF_HASH = verif_hash(config) REGIONS = parse_regions() SHOWCASE_REGIONS = parse_showcase_regions() SHOWCASE_PARAMS = config.get("showcase", {}).get("params", ["T_2M", "SP_10M"]) diff --git a/workflow/rules/data.smk b/workflow/rules/data.smk index 67914a9f..eb41eff9 100644 --- a/workflow/rules/data.smk +++ b/workflow/rules/data.smk @@ -4,23 +4,6 @@ from pathlib import Path include: "common.smk" -if config["truth"]["root"].endswith("peakweather"): - output_peakweather_root = config["truth"]["root"] -else: - output_peakweather_root = OUT_ROOT / "data/observations/peakweather" - - -rule data_download_obs_from_peakweather: - output: - root=directory(output_peakweather_root), - localrule: True - run: - from peakweather.dataset import PeakWeatherDataset - - # Download the data from Huggingface - ds = PeakWeatherDataset(root=output.root) - - # Grid-definition files required by earthkit/eckit to decode the ICON-CH grids. 
# Automatic download/caching of these is currently broken (see README), so we # fetch them into eckit's default geo grid cache under $HOME, where it finds diff --git a/workflow/rules/plot.smk b/workflow/rules/plot.smk index 91ff1080..b855662e 100644 --- a/workflow/rules/plot.smk +++ b/workflow/rules/plot.smk @@ -27,8 +27,7 @@ rule plot_meteogram: input: script="workflow/scripts/plot_meteogram.py", inference_okfile=rules.inference_execute.output.okfile, - truth=config["truth"]["root"], - peakweather_dir=rules.data_download_obs_from_peakweather.output.root, + truth_dep=truth_file_dep, eckit_grids=rules.data_download_eckit_geo_grids.output, output: expand( @@ -44,6 +43,7 @@ rule plot_meteogram: runtime="60m", params: ana_label=lambda wc: config["truth"]["label"], + truth_root=config["truth"]["root"], fcst_grib=lambda wc: ( Path(OUT_ROOT) / f"data/runs/{wc.run_id}/{wc.init_time}/grib" ).resolve(), @@ -71,9 +71,8 @@ rule plot_meteogram: --forecast {params.fcst_grib:q} --forecast_steps {params.fcst_steps:q} --forecast_label {params.fcst_label:q} - --analysis {input.truth:q} + --analysis {params.truth_root:q} --analysis_label {params.ana_label:q} - --peakweather {input.peakweather_dir:q} --date {wildcards.init_time:q} --outdir {params.outdir:q} --param {wildcards.param:q} diff --git a/workflow/rules/verification.smk b/workflow/rules/verification.smk index 4bae88e2..b062983c 100644 --- a/workflow/rules/verification.smk +++ b/workflow/rules/verification.smk @@ -15,10 +15,10 @@ rule verification_metrics_baseline: "src/data_input/__init__.py", script="workflow/scripts/verification_metrics.py", forecast=lambda wc: BASELINE_CONFIGS[wc.baseline_id]["root"], - truth=config["truth"]["root"], + truth_dep=truth_file_dep, eckit_grids=rules.data_download_eckit_geo_grids.output, output: - OUT_ROOT / f"data/baselines/{{baseline_id}}/{{init_time}}/verif_{TRUTH_HASH}.nc", + OUT_ROOT / f"data/baselines/{{baseline_id}}/{{init_time}}/verif_{VERIF_HASH}.nc", log: OUT_ROOT / 
"logs/verification_metrics_baseline/{baseline_id}-{init_time}.log", resources: @@ -29,16 +29,22 @@ rule verification_metrics_baseline: baseline_label=lambda wc: BASELINE_CONFIGS[wc.baseline_id].get("label"), baseline_steps=lambda wc: BASELINE_CONFIGS[wc.baseline_id]["steps"], member=lambda wc: BASELINE_CONFIGS[wc.baseline_id].get("member", "000"), + truth=config["truth"]["root"], truth_label=config["truth"]["label"], regions=REGIONS, experiment_params=",".join(EXPERIMENT_PARAMS), threshold_dict=config["experiment"]["thresholds"], + lapse_rate_flag=( + "--lapse_rate_correction" + if config["experiment"].get("lapse_rate_correction", True) + else "" + ), shell: """ export ECCODES_DEFINITION_PATH=$(realpath .venv/share/eccodes-cosmo-resources/definitions) uv run {input.script} \ --forecast {input.forecast} \ - --truth {input.truth} \ + --truth {params.truth} \ --reftime {wildcards.init_time} \ --steps "{params.baseline_steps}" \ --label "{params.baseline_label}" \ @@ -47,6 +53,7 @@ rule verification_metrics_baseline: --params "{params.experiment_params}" \ --threshold_dict "{params.threshold_dict}" \ --member "{params.member}" \ + {params.lapse_rate_flag} \ --output {output} >{log} 2>&1 """ @@ -64,10 +71,10 @@ rule verification_metrics: "src/data_input/__init__.py", script="workflow/scripts/verification_metrics.py", inference_okfile=rules.inference_execute.output.okfile, - truth=config["truth"]["root"], + truth_dep=truth_file_dep, eckit_grids=rules.data_download_eckit_geo_grids.output, output: - OUT_ROOT / f"data/runs/{{run_id}}/{{init_time}}/verif_{TRUTH_HASH}.nc", + OUT_ROOT / f"data/runs/{{run_id}}/{{init_time}}/verif_{VERIF_HASH}.nc", log: OUT_ROOT / "logs/verification_metrics/{run_id}-{init_time}.log", resources: @@ -80,6 +87,7 @@ rule verification_metrics: params: fcst_label=lambda wc: RUN_CONFIGS[wc.run_id].get("label"), fcst_steps=lambda wc: RUN_CONFIGS[wc.run_id]["steps"], + truth=config["truth"]["root"], truth_label=config["truth"]["label"], regions=REGIONS, 
grib_out_dir=lambda wc: ( @@ -87,12 +95,17 @@ rule verification_metrics: ).resolve(), experiment_params=",".join(EXPERIMENT_PARAMS), threshold_dict=config["experiment"]["thresholds"], + lapse_rate_flag=( + "--lapse_rate_correction" + if config["experiment"].get("lapse_rate_correction", True) + else "" + ), shell: """ export ECCODES_DEFINITION_PATH=$(realpath .venv/share/eccodes-cosmo-resources/definitions) uv run {input.script} \ --forecast {params.grib_out_dir} \ - --truth {input.truth} \ + --truth {params.truth} \ --reftime {wildcards.init_time} \ --steps "{params.fcst_steps}" \ --label "{params.fcst_label}" \ @@ -100,6 +113,7 @@ rule verification_metrics: --regions "{params.regions}" \ --params "{params.experiment_params}" \ --threshold_dict "{params.threshold_dict}" \ + {params.lapse_rate_flag} \ --output {output} >{log} 2>&1 """ @@ -120,7 +134,7 @@ rule verification_metrics_aggregation: allow_missing=True, ), output: - OUT_ROOT / f"data/runs/{{run_id}}/verif_aggregated_{TRUTH_HASH}.nc", + OUT_ROOT / f"data/runs/{{run_id}}/verif_aggregated_{VERIF_HASH}.nc", log: OUT_ROOT / "logs/verification_metrics_aggregation/{run_id}.log", resources: @@ -142,7 +156,7 @@ use rule verification_metrics_aggregation as verification_metrics_aggregation_ba allow_missing=True, ), output: - OUT_ROOT / f"data/baselines/{{baseline_id}}/verif_aggregated_{TRUTH_HASH}.nc", + OUT_ROOT / f"data/baselines/{{baseline_id}}/verif_aggregated_{VERIF_HASH}.nc", log: OUT_ROOT / "logs/verification_metrics_aggregation_baseline/{baseline_id}.log", diff --git a/workflow/scripts/plot_meteogram.py b/workflow/scripts/plot_meteogram.py index 8e7ac30e..43248d73 100644 --- a/workflow/scripts/plot_meteogram.py +++ b/workflow/scripts/plot_meteogram.py @@ -4,13 +4,14 @@ from pathlib import Path import matplotlib.pyplot as plt -from peakweather import PeakWeatherDataset +import xarray as xr from data_input import ( parse_steps, load_forecast_data, load_truth_data, ) +from data_input import jretrieve as jr from 
verification.spatial import map_forecast_to_truth LOG = logging.getLogger(__name__) @@ -96,9 +97,6 @@ def main(): default="truth", help="Label for analysis line in plot legend.", ) - parser.add_argument( - "--peakweather", type=str, default=None, help="Path to PeakWeather dataset" - ) parser.add_argument("--date", type=str, default=None, help="reference datetime") parser.add_argument("--outdir", type=str, help="output directory") parser.add_argument("--param", type=str, help="parameter") @@ -124,7 +122,6 @@ def main(): "Mismatched baseline arguments: --baseline and --baseline_label " "must be provided the same number of times." ) - peakweather_dir = Path(args.peakweather) init_time = datetime.strptime(args.date, "%Y%m%d%H%M") outdir = Path(args.outdir) outdir.mkdir(parents=True, exist_ok=True) @@ -145,6 +142,28 @@ def main(): else: paramlist = [param] + # Load station metadata from DWH + LOG.info("Fetching station metadata from jretrieve (SwissMetNet catalog)") + _jr_stations, _jr_stage, _jr_seq_type = jr.parse_selection("jretrievedwh:1,2") + _catalog = jr.StationCatalog.from_meta( + jr.fetch_meta( + stations=_jr_stations, + params=["rre150h0"], + seq_type=_jr_seq_type, + stage=_jr_stage, + ) + ) + catalog_lookup = { + abbr: (lat, lon) + for abbr, lat, lon in zip( + _catalog.nat_abbr, _catalog.latitude, _catalog.longitude + ) + } + + LOG.info("Loading analysis data from %s", analysis_root) + analysis_ds = load_truth_data(analysis_root, init_time, forecast_steps, paramlist) + analysis_ds = preprocess_ds(analysis_ds, param) + # Load gridded data once — shared across all station plots LOG.info("Loading forecast data from %s", forecast_grib_dir) forecast_ds = load_forecast_data( @@ -152,11 +171,6 @@ def main(): ) forecast_ds = preprocess_ds(forecast_ds, param) - steps = [int(s) for s in forecast_ds["step"].dt.total_seconds().values / 3600] - LOG.info("Loading analysis data from %s", analysis_root) - analysis_ds = load_truth_data(analysis_root, init_time, steps, 
paramlist) - analysis_ds = preprocess_ds(analysis_ds, param) - baseline_ds_list = [] for root, step, label in zip(baseline_roots, baseline_steps, baseline_labels): LOG.info("Loading baseline '%s' from %s", label, root) @@ -164,12 +178,6 @@ def main(): preprocess_ds(load_forecast_data(root, init_time, step, paramlist), param) ) - # Load station metadata once - LOG.info("Loading station metadata from %s", peakweather_dir) - peakweather = PeakWeatherDataset(root=peakweather_dir) - stations_table = peakweather.stations_table - stations_table.index.names = ["values"] - param2plot = forecast_ds[param].attrs.get("parameter", {}) short = param2plot.get("shortName", "") units = param2plot.get("units", "") @@ -183,9 +191,14 @@ def main(): stations.index(station) + 1, len(stations), ) - station_ds = stations_table.to_xarray().sel(values=[station]) - station_ds = station_ds.set_coords(("latitude", "longitude", "station_name")) - station_ds = station_ds.drop_vars(list(station_ds.data_vars)) + lat, lon = catalog_lookup[station] + station_ds = xr.Dataset( + coords={ + "values": [station], + "latitude": ("values", [lat]), + "longitude": ("values", [lon]), + } + ) forecast_station_ds = map_forecast_to_truth(forecast_ds, station_ds) analysis_station_ds = map_forecast_to_truth(analysis_ds, station_ds) diff --git a/workflow/scripts/verification_metrics.py b/workflow/scripts/verification_metrics.py index 0281d1c0..d9eeb335 100644 --- a/workflow/scripts/verification_metrics.py +++ b/workflow/scripts/verification_metrics.py @@ -5,7 +5,7 @@ from pathlib import Path -from verification import verify # noqa: E402 +from verification import verify, apply_lapse_rate_correction # noqa: E402 from verification.spatial import map_forecast_to_truth # noqa: E402 from data_input import ( parse_steps, @@ -74,6 +74,9 @@ def main(args: ScriptConfig): fcst = map_forecast_to_truth(fcst, truth) truth = truth.sel(time=fcst["valid_time"]) + if args.lapse_rate_correction: + fcst = 
apply_lapse_rate_correction(fcst, truth, args.params) + + # compute metrics and statistics results = verify( fcst, @@ -161,6 +164,12 @@ def main(args: ScriptConfig): default="verif.nc", help="Output file to save the verification results (default: verif.nc).", ) + parser.add_argument( + "--lapse_rate_correction", + action="store_true", + default=False, + help="Apply standard-atmosphere lapse-rate correction to T_2M.", + ) args = parser.parse_args() main(args) diff --git a/workflow/tools/config.schema.json b/workflow/tools/config.schema.json index 03671aa1..208ecb38 100644 --- a/workflow/tools/config.schema.json +++ b/workflow/tools/config.schema.json @@ -274,6 +274,12 @@ "title": "Thresholds", "type": "object" }, + "lapse_rate_correction": { + "default": true, + "description": "Apply standard-atmosphere lapse-rate correction to T_2M.", + "title": "Lapse Rate Correction", + "type": "boolean" + }, "dashboard": { "$ref": "#/$defs/Dashboard", "description": "Settings for the experiment dashboard."