Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
c2cc978
feat(data_input): add jretrievedwh subprocess wrapper + station catalog
clairemerker Jun 12, 2026
fdf18c4
test(data_input): cover StationCatalog.from_meta
clairemerker Jun 12, 2026
49479e9
feat(data_input): implement load_obs_data_from_jretrieve
clairemerker Jun 12, 2026
e903490
feat(data_input): forward truth root marker to jretrieve loader
clairemerker Jun 12, 2026
9babcc4
feat(workflow): make truth input conditional for live jretrieve source
clairemerker Jun 12, 2026
d172fb5
docs: document jretrievedwh truth source config + prerequisites
clairemerker Jun 12, 2026
cea206c
style: apply ruff format to jretrieve source and tests
clairemerker Jun 12, 2026
7aa6ab9
docs: correct SwissMetNet abbreviation SNM -> SMN
clairemerker Jun 12, 2026
5cd0fa8
feat(jretrieve): fail-fast check for jretrievedwh.py on PATH + OPR_HOME
clairemerker Jun 12, 2026
ba88587
fix issue with meas_group/stn_group
jonasbhend Jun 12, 2026
2182e3b
fix inadvertent change to config and update readme
jonasbhend Jun 12, 2026
54da55e
test(jretrieve): mock check_prerequisites in loader test for CI
clairemerker Jun 12, 2026
5e0caec
remove peakweather
jonasbhend Jun 12, 2026
a2e4e99
remove truth hash, as this is redundant (and not used)
jonasbhend Jun 12, 2026
62c6cf8
only retrieve necessary timesteps
jonasbhend Jun 15, 2026
4c92452
fail with error if not all time steps are available
jonasbhend Jun 15, 2026
0add965
use dedicated varda credentials for jretrieve
jonasbhend Jun 15, 2026
ae0a8c7
update dependencies
jonasbhend Jun 15, 2026
065a5d9
fix failing test
jonasbhend Jun 15, 2026
ca4a223
Add SP_10M to list of DWH parameters
jonasbhend Jun 18, 2026
7d52c90
Merge branch 'main' into feat/jretrieve
dnerini Jun 18, 2026
86d3380
fix README
jonasbhend Jun 18, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions .jretrievedwh-conf.prod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import base64
import json
import os
import urllib.request
from pathlib import Path


def _read_dotenv(path):
    """Parse a minimal ``KEY=value`` dotenv file into a dict.

    Blank lines, lines starting with ``#`` and lines without ``=`` are
    ignored. A leading ``export `` (shell style) is accepted and dropped.
    Keys and values are whitespace-stripped, and any leading/trailing double
    and then single quotes are removed from the value. Only the first ``=``
    separates key from value, so values may themselves contain ``=``.

    Args:
        path: Location of the dotenv file (anything ``open`` accepts).

    Returns:
        Mapping of variable name to value; empty if the file does not exist.
    """
    result = {}
    try:
        # Read as UTF-8 explicitly so parsing does not depend on the locale
        # (a C/POSIX locale would otherwise choke on non-ASCII bytes).
        with open(path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith("#") or "=" not in line:
                    continue
                # Tolerate lines pasted from a shell snippet.
                if line.startswith("export "):
                    line = line[len("export "):].lstrip()
                key, _, value = line.partition("=")
                result[key.strip()] = value.strip().strip('"').strip("'")
    except FileNotFoundError:
        # A missing file is not an error: the caller falls back to other
        # credential sources and reports the problem itself.
        pass
    return result


# Resolve the client credentials. Environment variables take precedence; a
# .env file in $JRETRIEVE_CONF_DIR (default: current working directory) fills
# in whatever is still missing.
_dotenv_path = Path(os.environ.get("JRETRIEVE_CONF_DIR", ".")) / ".env"
_client_id = os.environ.get("JRETRIEVE_CLIENT_ID")
_client_secret = os.environ.get("JRETRIEVE_CLIENT_SECRET")
if not _client_id or not _client_secret:
    _dotenv = _read_dotenv(_dotenv_path)
    _client_id = _client_id or _dotenv.get("JRETRIEVE_CLIENT_ID")
    _client_secret = _client_secret or _dotenv.get("JRETRIEVE_CLIENT_SECRET")
if not _client_id or not _client_secret:
    # Report the path that was actually searched so a wrong working directory
    # or JRETRIEVE_CONF_DIR is easy to spot.
    raise RuntimeError(
        "jretrieve credentials not found. Set JRETRIEVE_CLIENT_ID and "
        "JRETRIEVE_CLIENT_SECRET in the environment or in the .env file at "
        f"{_dotenv_path.resolve()}."
    )

# NOTE(review): `jretrieve_url` and `auth_header` are presumably the settings
# jretrievedwh.py reads from this config file -- confirm before renaming.
jretrieve_url = "https://service.meteoswiss.ch/jretrieve/api/v1"

# Seconds to wait for the token endpoint before failing, so an unreachable
# auth service cannot hang the caller indefinitely.
_TOKEN_TIMEOUT_S = 30

# Exchange the client credentials (HTTP Basic) for a bearer token using the
# client_credentials grant of the OpenID Connect token endpoint.
_basic_credentials = base64.b64encode(
    f"{_client_id}:{_client_secret}".encode()
).decode("ascii")
with urllib.request.urlopen(
    urllib.request.Request(
        method="POST",
        url="https://service.meteoswiss.ch/auth/realms/meteoswiss.ch/protocol/openid-connect/token",
        data=b"grant_type=client_credentials",
        headers={"Authorization": f"Basic {_basic_credentials}"},
    ),
    timeout=_TOKEN_TIMEOUT_S,
) as f:
    auth_header = "Bearer " + json.loads(f.read().decode())["access_token"]
44 changes: 44 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,31 @@ You can then run it with:
evalml experiment path/to/experiment/config.yaml --report
```

### Truth sources

The `truth.root` value selects how the ground truth is loaded:

- **Analysis Zarr** — a path ending in `.zarr` (anemoi analysis dataset).
- **DWH / jretrievedwh** — a `jretrievedwh:` marker string fetching surface observations
(e.g. SMN) live from the MeteoSwiss data warehouse. Variables are mapped to
ICON names in SI units (temperatures in K, pressure in Pa, precipitation as the hourly
sum); wind `U_10M`/`V_10M` are derived from speed + direction.

Marker syntax (station selection is required; pick one of group/locations/bbox):

```yaml
truth:
label: SwissMetNet
root: jretrievedwh:1,2 # stn_group_id (default)
# root: jretrievedwh:locations=ARO,KLO,LUG # explicit nat_abbr list
# root: jretrievedwh:bbox=45.8,47.8,5.9,10.5 # minlat,maxlat,minlon,maxlon
```

**Prerequisites:** `jretrievedwh.py` must be on `$PATH` (falls back to

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it be possible to package this as a python package that we add as project dependency?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@clairemerker I think I lack a bit of insight on what this does and how much effort this would be. I agree though that the current solution is 'hacky'

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm... It's part of this repo: https://service.meteoswiss.ch/git/databrokerandcustomerinterfaces/jretrieve
So if anything we would need to get it from there? I wouldn't create our own package for it...

`/oprusers/osm/opr.inn/bin/jretrievedwh.py`) and DWH credentials must be
available — see [Credentials setup](#credentials-setup). No data is
pre-downloaded; the observations are queried at verification time.


## Installation

Expand Down Expand Up @@ -175,6 +200,25 @@ valid for 30 days. Every training or evaluation run within this period automatic
extends the token by another 30 days. It’s good practice to run the login command before
executing the workflow to ensure your token is still valid.

### DWH / jretrieve credentials

To use `jretrievedwh:` as a truth source, provide a client ID and secret for
the MeteoSwiss service account. Set them as environment variables:

```bash
export JRETRIEVE_CLIENT_ID=your-client-id
export JRETRIEVE_CLIENT_SECRET=your-client-secret
```

or place them in a `.env` file at the project root (next to `.jretrievedwh-conf.prod.py`):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it help to include a .env.template ?


```
JRETRIEVE_CLIENT_ID=your-client-id
JRETRIEVE_CLIENT_SECRET=your-client-secret
```

The `.env` file is listed in `.gitignore` and is never committed.

## Workspace setup

By default, data produced by the workflow will be stored under `output/` in your working directory.
Expand Down
7 changes: 6 additions & 1 deletion config/varda-single-1.0.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@ runs:

truth:
label: SwissMetNet
root: output/data/observations/peakweather
root: jretrievedwh:1,2
# SwissMetNet observations are fetched live from the DWH via jretrievedwh
# (requires jretrievedwh.py on $PATH and $OPR_HOME set); 1,2 are stn_group_ids.
# Other selectors: root: jretrievedwh:locations=ARO,KLO,LUG
#                  root: jretrievedwh:bbox=45.8,47.8,5.9,10.5
# Append ;stage=devt to target a non-prod DWH stage.

experiment:
params:
Expand Down
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ dependencies = [
"pyproj>=3.7.2",
"marimo>=0.23.3",
"geopandas>=0.14.0",
"peakweather",
"pyzmq>=27.1.0",
"scores>=2.0.0",
"eccodes>=2.44,<2.48",
Expand Down Expand Up @@ -60,5 +59,4 @@ packages = [
]

[tool.uv.sources]
peakweather = { git = "https://github.com/MeteoSwiss/PeakWeather.git" }
eccodes-cosmo-resources-python = { git = "https://github.com/MeteoSwiss/eccodes-cosmo-resources-python" }
181 changes: 125 additions & 56 deletions src/data_input/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import earthkit.data as ekd
import numpy as np
import pandas as pd
import xarray as xr
from pyproj import Transformer

Expand All @@ -19,16 +20,21 @@
ZERO_KELVIN = -273.15 # °C


def _select_valid_times(ds, times: np.datetime64):
def _select_valid_times(ds, times: np.datetime64, strict: bool = False):
# (handle special case where some valid times are not in the dataset, e.g. at the end)
times_np = np.asarray(times, dtype="datetime64[ns]")
times_included = np.isin(times_np, ds.time.values)
if times_included.all():
return ds.sel(time=times_np)
elif times_included.any():
missing = times_np[~times_included]
if strict:
raise ValueError(
f"Some valid times are not included in the dataset:\n{missing}"
)
LOG.warning(
"Some valid times are not included in the dataset: \n%s",
times_np[~times_included],
missing,
)
return ds.sel(time=times_np[times_included])
else:
Expand Down Expand Up @@ -264,71 +270,134 @@ def load_forecast_data_from_grib(files: list[Path], params: list[str]) -> xr.Dat
return ds


def load_obs_data_from_peakweather(
root, reftime: datetime, steps: list[int], params: list[str], freq: str = "1h"
) -> xr.Dataset:
"""Load PeakWeather station observations into an xarray Dataset.
DWH_PARAM_MAP = {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the global DWH_ variables are used only by load_obs_data_from_jretrieve; I suggest moving them to the top of that function

"T_2M": "tre200s0",
"TD_2M": "tde200s0",
"PS": "prestas0",
"PMSL": "pp0qffs0",
"TOT_PREC": "rre150h0",
"FF_10M": "fkl010z0",
"SP_10M": "fkl010z0",
"DD_10M": "dkl010z0",
"VMAX_10M": "fkl010z1",
}
# DWH short names for wind speed and direction. load_obs_data_from_jretrieve
# requests both whenever U_10M or V_10M is asked for and derives the wind
# components from them.
DWH_WIND_SPEED = "fkl010z0"
DWH_WIND_DIR = "dkl010z0"
# DWH short names whose values get ZERO_KELVIN subtracted on load (°C -> K).
DWH_CELSIUS_TO_KELVIN = {"tre200s0", "tde200s0"}
# DWH short names whose values are multiplied by 100 on load (hPa -> Pa).
DWH_HPA_TO_PA = {"prestas0", "pp0qffs0"}


def _jretrieve_df_to_xarray(df, short_names, catalog) -> xr.Dataset:
    """Reshape long-form jretrieve observations into a (time, values) Dataset.

    Every row of ``df`` is written to its (time, station) cell, with the
    ``values`` axis following the station order of ``catalog``. Cells without
    an observation, and requested parameters that are not columns of ``df``,
    stay NaN. Rows whose station is not in the catalog are dropped. An empty
    ``df`` produces a zero-length time axis.
    """
    station_pos = {sid: pos for pos, sid in enumerate(catalog.station_id)}
    has_rows = not df.empty

    if has_rows:
        obs = df.copy()  # leave the caller's frame untouched
        # "termin" carries the timestamp formatted as YYYYmmddHHMMSS.
        obs["time"] = pd.to_datetime(obs["termin"].astype(str), format="%Y%m%d%H%M%S")
        time_index = pd.DatetimeIndex(sorted(obs["time"].unique()))
    else:
        time_index = pd.DatetimeIndex([])

    dims = ("time", "values")
    shape = (len(time_index), catalog.n)
    # Start from all-NaN arrays; only cells with an observation get filled.
    cubes = {name: np.full(shape, np.nan, dtype=np.float32) for name in short_names}

    if has_rows:
        time_pos = {stamp: pos for pos, stamp in enumerate(time_index)}
        obs["_si"] = obs["station"].map(station_pos)
        obs["_ti"] = obs["time"].map(time_pos)
        # Unmapped stations/times come back as NaN from .map(); discard them.
        obs = obs.dropna(subset=["_si", "_ti"])
        rows = obs["_ti"].astype(int).to_numpy()
        cols = obs["_si"].astype(int).to_numpy()
        for name in short_names:
            if name in obs.columns:
                cubes[name][rows, cols] = obs[name].to_numpy(dtype=np.float32)

    return xr.Dataset(
        data_vars={name: (dims, cube) for name, cube in cubes.items()},
        coords={
            "time": ("time", time_index.values.astype("datetime64[ns]")),
            "values": ("values", catalog.nat_abbr),
            "latitude": ("values", catalog.latitude),
            "longitude": ("values", catalog.longitude),
        },
    )


Returns a Dataset with dimensions `time` and `values`, values coordinates
(`lat`, `lon`), and variables renamed to ICON parameter names.
Temperatures are converted to Kelvin when present.
def load_obs_data_from_jretrieve(
root, reftime: datetime, steps: list[int], params: list[str]
) -> xr.Dataset:
"""Load SwissMetNet (SMN) surface observations from the DWH via jretrievedwh.

``root`` is a marker string selecting stations, e.g. ``jretrievedwh:SwissMetNet``
(default group), ``jretrievedwh:locations=ARO,KLO``, or
``jretrievedwh:bbox=45.8,47.8,5.9,10.5`` (optionally ``;stage=devt``). Returns
a Dataset with dims (time, values), values=nat_abbr, latitude/longitude coords,
variables renamed to ICON names in SI units (T/TD in Kelvin, pressure in Pa).
Only the requested hourly valid times are kept.
"""
from peakweather.dataset import PeakWeatherDataset

param_names = {
"temperature": "T_2M",
"wind_u": "U_10M",
"wind_v": "V_10M",
"precipitation": "TOT_PREC",
"pressure": "PS",
"wind_gust": "VMAX_10M",
}
param_names = {k: v for k, v in param_names.items() if v in params}
from data_input import jretrieve as jr

stations, stage, seq_type = jr.parse_selection(root)
jr.check_prerequisites(stage)

want_uv = "U_10M" in params or "V_10M" in params
short_names: list[str] = [DWH_PARAM_MAP[p] for p in params if p in DWH_PARAM_MAP]
if want_uv:
short_names += [DWH_WIND_SPEED, DWH_WIND_DIR]
short_names = list(dict.fromkeys(short_names))
if not short_names:
raise ValueError(f"No DWH parameter mapping for requested params: {params}")

start = reftime
end = start + timedelta(hours=max(steps))
if len(steps) > 1:
end += timedelta(hours=steps[-1] - steps[-2]) # extend by 1 extra step
years = list(set([start.year, end.year]))
if "wind_u" in param_names or "wind_v" in param_names:
compute_uv = True
else:
compute_uv = False
pw = PeakWeatherDataset(root=root, years=years, freq=freq, compute_uv=compute_uv)
ds, mask = pw.get_observations(
parameters=[k for k in param_names.keys()],
first_date=f"{start:%Y-%m-%d %H:%M}",
last_date=f"{end:%Y-%m-%d %H:%M}",
return_mask=True,
)
ds = (
ds.stack(["nat_abbr", "name"], future_stack=True)
.to_xarray()
.to_dataset(dim="name")
end += timedelta(hours=steps[-1] - steps[-2])

catalog = jr.StationCatalog.from_meta(
jr.fetch_meta(
stations=stations, params=short_names, seq_type=seq_type, stage=stage
)
)
mask = (
mask.stack(["nat_abbr", "name"], future_stack=True)
.to_xarray()
.to_dataset(dim="name")
step_hours = (steps[1] - steps[0]) if len(steps) > 1 else 1
df = jr.fetch_data(
stations=stations,
params=short_names,
start=start,
end=end,
increment_minutes=step_hours * 60,
seq_type=seq_type,
stage=stage,
)
ds = ds.where(mask)
ds = ds.rename({"datetime": "time", "nat_abbr": "values"})
ds = ds.rename(param_names)
ds = ds.assign_coords(time=ds.indexes["time"].tz_convert("UTC").tz_localize(None))
ds = ds.assign_coords(values=ds.indexes["values"])
ds = ds.assign_coords(longitude=("values", pw.stations_table["longitude"]))
ds = ds.assign_coords(latitude=("values", pw.stations_table["latitude"]))
if "T_2M" in ds:
ds["T_2M"] = ds["T_2M"] - ZERO_KELVIN # convert to Kelvin
ds = ds.dropna("values", how="all")
raw = _jretrieve_df_to_xarray(df, short_names, catalog)

out = xr.Dataset(coords=raw.coords)
for icon, short in DWH_PARAM_MAP.items():
if icon in params and short in raw:
var = raw[short]
if short in DWH_CELSIUS_TO_KELVIN:
var = var - ZERO_KELVIN
elif short in DWH_HPA_TO_PA:
var = var * 100.0
out[icon] = var
if want_uv and DWH_WIND_SPEED in raw and DWH_WIND_DIR in raw:
ff = raw[DWH_WIND_SPEED]
dd_rad = np.deg2rad(raw[DWH_WIND_DIR])
if "U_10M" in params:
out["U_10M"] = -ff * np.sin(dd_rad)
if "V_10M" in params:
out["V_10M"] = -ff * np.cos(dd_rad)

out = out.dropna("values", how="all")
times = np.datetime64(reftime) + np.asarray(steps, dtype="timedelta64[h]")
return _select_valid_times(ds, times)
return _select_valid_times(out, times, strict=True)


def load_truth_data(
root, reftime: datetime, steps: list[int], params: list[str]
) -> xr.Dataset:
"""Load truth data from analysis Zarr or PeakWeather observations."""
"""Load truth data from an analysis Zarr dataset or DWH observations via jretrieve."""
if root.suffix == ".zarr":
LOG.info("Loading ground truth from an analysis zarr dataset...")
truth = load_analysis_data_from_zarr(
Expand All @@ -342,9 +411,9 @@ def load_truth_data(
if "y" in truth.dims and "x" in truth.dims
else {"values": -1}
)
elif "peakweather" in str(root):
LOG.info("Loading ground truth from PeakWeather observations...")
truth = load_obs_data_from_peakweather(
elif "jretrieve" in str(root):
LOG.info("Loading ground truth from JRetrieve...")
truth = load_obs_data_from_jretrieve(
root=root,
reftime=reftime,
steps=steps,
Expand Down
Loading
Loading