From 453d76016bdc950805cd54fef8b940c8190e9db7 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Thu, 28 May 2026 15:23:08 +0200 Subject: [PATCH 01/51] plot more percentile in diurnal cycle --- .../diurnal_cycle_precip_high_percentiles.py | 177 ++++++++++++++++++ src/hirad/eval_precip.sh | 37 ++++ 2 files changed, 214 insertions(+) create mode 100644 src/hirad/eval/diurnal_cycle_precip_high_percentiles.py create mode 100644 src/hirad/eval_precip.sh diff --git a/src/hirad/eval/diurnal_cycle_precip_high_percentiles.py b/src/hirad/eval/diurnal_cycle_precip_high_percentiles.py new file mode 100644 index 00000000..8f0eb612 --- /dev/null +++ b/src/hirad/eval/diurnal_cycle_precip_high_percentiles.py @@ -0,0 +1,177 @@ +""" +Plots the diurnal cycle of the all-hour 99th, 99.9th, and 99.99th percentiles of +precipitation, a somewhat reliable measure of the precipitation intensity. + +Each hour, member and type is treated separately, to conserve memory... but if the +period is long, this can still be a lot of data and thus an OOM error can occur. +""" +import logging +from datetime import datetime +from pathlib import Path + +import hydra +import matplotlib.pyplot as plt +import numpy as np +import torch +import xarray as xr + +from hirad.eval.eval_utils import get_channel_indices, load_generation_setup, load_land_sea_mask, parse_eval_cli, resolve_ts_dir + + +def save_plot(hours, lines, labels, ylabel, title, out_path): + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + plt.figure(figsize=(8,4)) + for data, label in zip(lines, labels): + if isinstance(data, tuple): # (mean, std) + mean, std = data + lower = np.maximum(np.array(mean) - std, 0) + upper = np.array(mean) + std + line, = plt.plot(hours, mean, label=label) + plt.fill_between(hours, lower, upper, alpha=0.3, color=line.get_color()) + else: + plt.plot(hours, data, label=label) + plt.xlabel('Hour (UTC)') + plt.xticks(range(0,25,3)) + plt.xlim(0,24) + plt.ylabel(ylabel) + plt.title(title) + plt.grid(True) + plt.legend() + plt.tight_layout() + plt.savefig(out_path) + plt.close() + + +def main(cfg: dict): + # Setup logging + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info("Starting computation for diurnal cycle of high percentiles of precipitation") + try: + generation_dir, gen_cfg, times = load_generation_setup(cfg) + except ValueError as exc: + logger.error(str(exc)) + return + logger.info(f"Loaded {len(times)} timesteps to process") + + # Output root + out_root = Path(generation_dir) + + # Find channel indices + indices = get_channel_indices(gen_cfg) + tp_out = indices['output']['tp'] + tp_in = indices['input'].get('tp', tp_out) + logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") + + # Land-sea mask + land_mask = load_land_sea_mask(cfg.get("land_sea_mask_path"), cfg.get("height"), cfg.get("width")) + land_bool = land_mask.notnull().stack(space=('lat', 'lon')) + + percentile_configs = [ + (0.99, 'p99', '99th'), + (0.999, 'p999', '99.9th'), + (0.9999, 'p9999', '99.99th'), + ] + + # Storage for diurnal cycles: pct_mean[pct_key][mode], pct_std[pct_key]['prediction'] + pct_mean = {key: {} for _, key, _ in percentile_configs} + pct_std = {key: {} for _, key, _ in percentile_configs} + + # -- Process target and baseline -- + for mode in ['target', 'baseline', 'regression-prediction']: + logger.info(f"Processing mode: {mode}") + + data_list = [] + try: + for ts in times: + data = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode in ['target','regression-prediction'] else tp_in] * cfg.get("conv_factor") + data_list.append(data) + except: + logger.error(f"Error loading data for mode {mode}. Skipping.") + continue + + da = xr.DataArray( + np.stack(data_list, axis=0), + dims=['time', 'lat', 'lon'], + coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times], + 'lat': land_mask.coords['lat'], 'lon': land_mask.coords['lon']} + ) + + # Select only land pixels to avoid all-NaN slices in quantile + da_land = da.stack(space=('lat', 'lon')).isel(space=land_bool.values) + + for q, key, _ in percentile_configs: + hourly_pct = da_land.groupby('time.hour').quantile(q, dim='time') + pct_mean[key][mode] = hourly_pct.mean(dim='space') + + # -- Predictions: compute per hour per member, then mean+std across members -- + logger.info("Processing predictions") + + # Load all prediction data at once into xarray + pred_data_list = [] + for ts in times: + preds = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-predictions", weights_only=False) * cfg.get("conv_factor") # [n_members, n_channels, lat, lon] + tp_data = preds[:, tp_out] # [n_members, lat, lon] + tp_da = xr.DataArray(tp_data, dims=['member', 'lat', 'lon'], + coords={'lat': land_mask.coords['lat'], 'lon': land_mask.coords['lon']}) + pred_data_list.append(tp_da) + + pred_da = xr.concat(pred_data_list, dim='time') # [member, time, lat, lon] + pred_da = pred_da.assign_coords({ + 'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times] + }) + pred_da = pred_da.transpose('member', 'time', 'lat', 'lon') + + # Select only land pixels to avoid all-NaN slices in quantile + pred_da_land = pred_da.stack(space=('lat', 'lon')).isel(space=land_bool.values) + + for q, key, label in percentile_configs: + logger.info(f'Calculating {label} percentile for predictions') + hourly_pct_by_member = pred_da_land.groupby('time.hour').quantile(q, dim='time').mean(dim='space') + pct_mean[key]['prediction'] = hourly_pct_by_member.mean(dim='member') + pct_std[key]['prediction'] = hourly_pct_by_member.std(dim='member') + + # Prepare cyclic lists for plotting + def cycle_fn(x): + vals = x.values.tolist() + return vals + [vals[0]] + + logger.info("Preparing data for plotting") + hrs_c = list(range(24)) + [24] + output_path = out_root / cfg.get("results_dir_name", "evaluation_maps") + output_path.mkdir(parents=True, exist_ok=True) + + for _, key, label in percentile_configs: + m = pct_mean[key] + s = pct_std[key] + + lines = [] + plot_labels = [] + + if 'target' in m: + lines.append(cycle_fn(m['target'])) + plot_labels.append('Target') + if 'baseline' in m: + lines.append(cycle_fn(m['baseline'])) + plot_labels.append('Input') + if 'prediction' in m: + lines.append((cycle_fn(m['prediction']), cycle_fn(s['prediction']))) + plot_labels.append(f'CorrDiff {label} Pct ± Std') + if 'regression-prediction' in m: + lines.append(cycle_fn(m['regression-prediction'])) + plot_labels.append('Regression Prediction') + + fn = output_path / f'diurnal_cycle_precip_{key}_percentile.png' + save_plot( + hrs_c, + lines, + plot_labels, + 'Precipitation (mm/day)', + f'Diurnal Cycle of {label}-Percentile Precipitation', + fn + ) + logger.info(f"Plot saved: {fn}") + +if __name__ == '__main__': + main(parse_eval_cli()) \ No newline at end of file diff --git a/src/hirad/eval_precip.sh b/src/hirad/eval_precip.sh new file mode 100644 index 00000000..7d359c69 --- /dev/null +++ b/src/hirad/eval_precip.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +#SBATCH --job-name="eval_precip" + +### HARDWARE ### +#SBATCH --partition=normal +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=72 +#SBATCH --time=24:00:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=./logs/plots_precip_bias.log + +### ENVIRONMENT #### +#SBATCH -A c38 + +### CONFIG ### +CONFIG_NAME="src/hirad/conf/eval_real.yaml" + +srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -c " + pip install -e . + + # Diurnal cycle + # python src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py --config-name=${CONFIG_NAME} + # python src/hirad/eval/diurnal_cycle_precip_high_percentiles.py --config-name=${CONFIG_NAME} + + # Histograms + # python src/hirad/eval/hist.py --config-name=${CONFIG_NAME} + # python src/hirad/eval/probability_of_exceedance.py --config-name=${CONFIG_NAME} + + # Maps + # python src/hirad/eval/map_precip_stats.py --config-name=${CONFIG_NAME} +" \ No newline at end of file From abdb136c166c173bfd673f8951ef6c17d4e256be Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Thu, 28 May 2026 15:29:18 +0200 Subject: [PATCH 02/51] plot maps of diurnal cycle --- src/hirad/eval/diurnal_cycle_precip_maps.py | 152 ++++++++++++++++++++ src/hirad/eval_precip.sh | 1 + 2 files changed, 153 insertions(+) create mode 100644 src/hirad/eval/diurnal_cycle_precip_maps.py diff --git a/src/hirad/eval/diurnal_cycle_precip_maps.py b/src/hirad/eval/diurnal_cycle_precip_maps.py new file mode 100644 index 00000000..501db8a6 --- /dev/null +++ b/src/hirad/eval/diurnal_cycle_precip_maps.py @@ -0,0 +1,152 @@ +""" +Plots mean precipitation and wet-hour fraction maps for each hour of the diurnal cycle. + +For each hour (00-23 UTC), all timesteps with that hour are averaged into +a single spatial map, producing 24 maps per source per variable: + - mean precipitation (mm/h) + - wet-hour fraction (% of timesteps where precip > wet_threshold) +""" +import logging +from collections import defaultdict +from datetime import datetime +from pathlib import Path + +import numpy as np +import torch + +from hirad.eval.eval_utils import ( + get_channel_indices, + grid_cfg_from_cfg, + load_generation_setup, + parse_eval_cli, + resolve_ts_dir, +) +from hirad.eval.plotting import plot_map, plot_map_precipitation + + +def main(cfg: dict) -> None: + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info("Starting diurnal cycle precipitation map generation") + + grid_cfg = grid_cfg_from_cfg(cfg) + + try: + generation_dir, gen_cfg, times = load_generation_setup(cfg) + except ValueError as exc: + logger.error(str(exc)) + return + + out_root = Path(generation_dir) + output_path = out_root / cfg.get("results_dir_name", "evaluation_maps") / "diurnal_cycle_precip_maps" + output_path.mkdir(parents=True, exist_ok=True) + + indices = get_channel_indices(gen_cfg) + tp_out = indices['output']['tp'] + tp_in = indices['input'].get('tp', tp_out) + # conv_factor_hourly converts ERA5 accumulated precip (m) to mm/h + conv_factor = cfg.get("conv_factor_hourly", 1000) + wet_threshold = cfg.get("wet_threshold", 0.1) # mm/h + log_interval = cfg.get("log_interval", 24) + + logger.info(f"TP channel indices — output: {tp_out}, input: {tp_in}") + logger.info(f"Wet-hour threshold: {wet_threshold} mm/h") + logger.info(f"Processing {len(times)} timesteps") + + # Group timestep strings by UTC hour (no data loaded yet) + times_by_hour: dict[int, list[str]] = defaultdict(list) + for ts in times: + times_by_hour[datetime.strptime(ts, "%Y%m%d-%H%M").hour].append(ts) + + hours_present = sorted(times_by_hour) + logger.info(f"Hours present: {hours_present}") + + # Detect whether regression-prediction files exist (check first timestep) + first_ts = times[0] + first_dir = resolve_ts_dir(out_root, first_ts) / first_ts + has_regression = (first_dir / f"{first_ts}-regression-prediction").exists() + + sources = [ + ('target', 'Target', tp_out), + ('baseline', 'Input', tp_in), + ('predictions', 'CorrDiff Ensemble Mean', tp_out), + ] + if has_regression: + sources.append(('regression-prediction', 'Regression Prediction', tp_out)) + + for source_key, source_label, tp_idx in sources: + (output_path / source_key).mkdir(parents=True, exist_ok=True) + (output_path / f"{source_key}_wethour").mkdir(parents=True, exist_ok=True) + + H, W = cfg.get("height"), cfg.get("width") + + for hour in hours_present: + hour_times = times_by_hour[hour] + logger.info(f"Hour {hour:02d}:00 UTC — loading {len(hour_times)} timesteps") + + # Accumulators for this hour only + sums = {s[0]: np.zeros((H, W), dtype=np.float64) for s in sources} + wet_sums = {s[0]: np.zeros((H, W), dtype=np.float64) for s in sources} + count = 0 + + for idx, ts in enumerate(hour_times, 1): + ts_dir = resolve_ts_dir(out_root, ts) / ts + + target = np.asarray(torch.load(ts_dir / f"{ts}-target", weights_only=False)[tp_out]) * conv_factor + sums['target'] += target + wet_sums['target'] += (target > wet_threshold).astype(np.float64) + + baseline = np.asarray(torch.load(ts_dir / f"{ts}-baseline", weights_only=False)[tp_in]) * conv_factor + sums['baseline'] += baseline + wet_sums['baseline'] += (baseline > wet_threshold).astype(np.float64) + + preds = np.asarray(torch.load(ts_dir / f"{ts}-predictions", weights_only=False)[:, tp_out]) * conv_factor + pred_mean = preds.mean(axis=0) + sums['predictions'] += pred_mean + # wet-hour frequency: fraction of members that are wet, then average over timesteps + wet_sums['predictions'] += (preds > wet_threshold).mean(axis=0) + + if has_regression: + reg = np.asarray(torch.load(ts_dir / f"{ts}-regression-prediction", weights_only=False)[tp_out]) * conv_factor + sums['regression-prediction'] += reg + wet_sums['regression-prediction'] += (reg > wet_threshold).astype(np.float64) + + count += 1 + if idx % log_interval == 0 or idx == len(hour_times): + logger.info(f" Loaded {idx}/{len(hour_times)} ({ts})") + + # Plot and immediately discard the accumulators + for source_key, source_label, _ in sources: + mean_map = sums[source_key] / count + title = f"{source_label} — Mean Diurnal Precip {hour:02d}:00 UTC (n={count})" + out_file = str(output_path / source_key / f"diurnal_mean_precip_{source_key}_{hour:02d}h") + plot_map_precipitation( + mean_map, out_file, + title=title, + threshold=0.01, + rfac=1.0, + grid_cfg=grid_cfg, + ) + + wet_map = wet_sums[source_key] / count * 100.0 # percent + wh_title = f"{source_label} — Wet-Hour Fraction {hour:02d}:00 UTC (n={count})" + wh_out_file = str(output_path / f"{source_key}_wethour" / f"diurnal_wethour_{source_key}_{hour:02d}h") + plot_map( + wet_map, wh_out_file, + title=wh_title, + label="Wet-Hour Fraction [%]", + vmin=0, vmax=30, + cmap="PuBu", + extend="max", + grid_cfg=grid_cfg, + ) + + del sums, wet_sums + logger.info(f"Hour {hour:02d}:00 UTC — maps saved") + + logger.info("Diurnal cycle precipitation maps complete.") + + +if __name__ == '__main__': + main(parse_eval_cli()) diff --git a/src/hirad/eval_precip.sh b/src/hirad/eval_precip.sh index 7d359c69..ed040373 100644 --- a/src/hirad/eval_precip.sh +++ b/src/hirad/eval_precip.sh @@ -27,6 +27,7 @@ srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.to # Diurnal cycle # python src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py --config-name=${CONFIG_NAME} # python src/hirad/eval/diurnal_cycle_precip_high_percentiles.py --config-name=${CONFIG_NAME} + # python -m hirad.eval.diurnal_cycle_precip_maps --config-name=${CONFIG_NAME} # Histograms # python src/hirad/eval/hist.py --config-name=${CONFIG_NAME} From ade9bb2c57f92cbac98140c5596c3d84b5d32e61 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Thu, 28 May 2026 16:04:55 +0200 Subject: [PATCH 03/51] WIP: QQ variant --- src/hirad/eval/bias_by_percentile_precip.py | 401 ++++++++++++++++++++ src/hirad/eval_precip.sh | 5 +- 2 files changed, 405 insertions(+), 1 deletion(-) create mode 100644 src/hirad/eval/bias_by_percentile_precip.py diff --git a/src/hirad/eval/bias_by_percentile_precip.py b/src/hirad/eval/bias_by_percentile_precip.py new file mode 100644 index 00000000..b4458822 --- /dev/null +++ b/src/hirad/eval/bias_by_percentile_precip.py @@ -0,0 +1,401 @@ +""" +Plots the bias (prediction - target) as a function of percentile for precipitation. + +For each percentile level p, the bias is: + + bias(p) = quantile_pred(p) - quantile_target(p) + +Positive bias means the model over-predicts at that quantile; negative means +under-prediction. For ensemble predictions the per-member biases are averaged +and the ±1 sigma spread is shaded. +""" +import logging +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch + +from hirad.eval.eval_utils import ( + get_channel_indices, + load_generation_setup, + load_land_sea_mask, + parse_eval_cli, + resolve_ts_dir, +) + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + +def _hist_quantiles(hist_counts, bin_edges, frac_percentiles): + """Vectorised quantile estimation from a histogram via linear CDF interpolation. + + Parameters + ---------- + hist_counts : (N,) int array + bin_edges : (N+1,) float array + frac_percentiles : (P,) float array - fractional values in [0, 1] + + Returns + ------- + (P,) float array of estimated quantile values. + """ + cdf = np.cumsum(hist_counts) / max(hist_counts.sum(), 1) + return np.interp(frac_percentiles, cdf, bin_edges[1:]) + + +# --------------------------------------------------------------------------- +# plotting +# --------------------------------------------------------------------------- + +def save_bias_by_percentile_plot( + bias_data_dict: dict, + percentile_values: np.ndarray, + labels: list, + colors: list, + title: str, + xlabel: str, + ylabel: str, + out_path, +) -> None: + """Save a bias-by-percentile figure. + + Parameters + ---------- + bias_data_dict : dict mapping key → bias array (n_percentiles,) for single + datasets, or list/tuple of (n_percentiles,) arrays for ensembles. + """ + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + + fig, ax = plt.subplots(figsize=(10, 6)) + + # Convert to fractions in (0, 1) for logit scale + frac = percentile_values / 100.0 + + for (key, bias_data), label, color in zip(bias_data_dict.items(), labels, colors): + if isinstance(bias_data, (list, tuple)): + # Ensemble: plot member average ± 1 σ + arr = np.array(bias_data) # (n_members, n_percentiles) + mean_bias = arr.mean(axis=0) + std_bias = arr.std(axis=0) + ax.plot(frac, mean_bias, color=color, label=label, linewidth=2) + ax.fill_between( + frac, + mean_bias - std_bias, + mean_bias + std_bias, + color=color, + alpha=0.2, + ) + else: + ax.plot( + frac, bias_data, + color=color, label=label, linewidth=2, alpha=0.85, + ) + + ax.axhline(0.0, color='black', linewidth=0.8, linestyle='--', label='Zero bias') + + # Logit x-axis: compresses the centre and stretches both tails + _apply_logit_xaxis(ax, frac) + ax.set_ylim(-5, 5) + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + ax.set_title(title) + ax.legend() + plt.tight_layout() + plt.savefig(out_path, dpi=300, bbox_inches='tight') + plt.close() + + +def _apply_logit_xaxis(ax, frac: np.ndarray) -> None: + """Apply logit x-axis with labelled percentile ticks.""" + ax.set_xscale('logit') + ax.set_xlim(frac[0], frac[-1]) + tick_fracs = [0.01, 0.10, 0.25, 0.50, 0.75, 0.90, 0.99, 0.999, 0.9999] + tick_labels = ['1', '10', '25', '50', '75', '90', '99', '99.9', '99.99'] + ax.set_xticks(tick_fracs) + ax.set_xticklabels(tick_labels) + ax.grid(True, alpha=0.3, which='both') + + +def save_mae_by_percentile_plot( + mae_data_dict: dict, + percentile_values: np.ndarray, + labels: list, + colors: list, + title: str, + xlabel: str, + ylabel: str, + out_path, +) -> None: + """Save a MAE-by-percentile figure. + + For single datasets the MAE curve is plotted directly. For the ensemble + the mean absolute error across members is shown with ±1 σ shading. + """ + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + fig, ax = plt.subplots(figsize=(10, 6)) + frac = percentile_values / 100.0 + + for (key, mae_data), label, color in zip(mae_data_dict.items(), labels, colors): + if isinstance(mae_data, (list, tuple)): + arr = np.array(mae_data) # (n_members, n_percentiles), already absolute + mean_mae = arr.mean(axis=0) + std_mae = arr.std(axis=0) + ax.plot(frac, mean_mae, color=color, label=label, linewidth=2) + ax.fill_between( + frac, + np.maximum(mean_mae - std_mae, 0), + mean_mae + std_mae, + color=color, alpha=0.2, + ) + else: + ax.plot(frac, mae_data, color=color, label=label, linewidth=2, alpha=0.85) + + _apply_logit_xaxis(ax, frac) + ax.set_ylim(0, 5) + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + ax.set_title(title) + ax.legend() + plt.tight_layout() + plt.savefig(out_path, dpi=300, bbox_inches='tight') + plt.close() + + +def save_spread_by_percentile_plot( + spread: np.ndarray, + percentile_values: np.ndarray, + title: str, + xlabel: str, + ylabel: str, + out_path, +) -> None: + """Save an ensemble-spread-by-percentile figure. + + Spread is the inter-member standard deviation of the p-th quantile, + i.e. how much the ensemble members disagree at each percentile level. + """ + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + fig, ax = plt.subplots(figsize=(10, 6)) + frac = percentile_values / 100.0 + + ax.plot(frac, spread, color='green', linewidth=2, label='CorrDiff Ensemble') + _apply_logit_xaxis(ax, frac) + ax.set_ylim(0, 5) + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + ax.set_title(title) + ax.legend() + plt.tight_layout() + plt.savefig(out_path, dpi=300, bbox_inches='tight') + plt.close() + + +# --------------------------------------------------------------------------- +# main +# --------------------------------------------------------------------------- + +def main(cfg: dict) -> None: + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info("Starting bias-by-percentile computation for precipitation over land") + try: + generation_dir, gen_cfg, times = load_generation_setup(cfg) + except ValueError as exc: + logger.error(str(exc)) + return + logger.info(f"Loaded {len(times)} timesteps to process") + + out_root = Path(generation_dir) + + # Channel indices + indices = get_channel_indices(gen_cfg) + tp_out = indices['output']['tp'] + tp_in = indices['input'].get('tp', tp_out) + logger.info(f"TP channel indices – output: {tp_out}, input: {tp_in}") + + # Land-sea mask + land_mask = load_land_sea_mask( + cfg.get("land_sea_mask_path"), cfg.get("height"), cfg.get("width") + ) + + # Fine-grained log-spaced histogram bins for accurate tail quantiles + hist_bins = np.concatenate([ + np.array([0.0]), + np.logspace(-2, 3.2, 5000), # 0.01 → ~1585 mm/h + ]) + n_hist_bins = len(hist_bins) - 1 + + hist_counts: dict[str, np.ndarray] = {} + totals: dict[str, int] = {} + + # -- Target and deterministic baselines -- + for mode in ['target', 'baseline', 'regression-prediction']: + logger.info(f"Processing mode: {mode}") + mode_hist = np.zeros(n_hist_bins, dtype=np.int64) + mode_total = 0 + try: + for i, ts in enumerate(times): + if i % cfg.get("log_interval") == 0: + logger.info(f" Timestep {i + 1}/{len(times)}") + ch_idx = tp_out if mode in ('target', 'regression-prediction') else tp_in + data = ( + torch.load( + resolve_ts_dir(out_root, ts) / ts / f"{ts}-{mode}", + weights_only=False, + )[ch_idx] + * cfg.get("conv_factor_hourly") + * land_mask + ) + land_values = data.values[~np.isnan(data.values)] + mode_hist += np.histogram(land_values, bins=hist_bins)[0] + mode_total += len(land_values) + except Exception: + logger.warning(f" {mode} data not found, skipping") + continue + + hist_counts[mode] = mode_hist + totals[mode] = mode_total + logger.info(f" Processed {mode_total} land values for {mode}") + + # -- Ensemble predictions -- + logger.info("Processing predictions") + n_members: int | None = None + member_hist: list[np.ndarray] | None = None + member_totals: list[int] | None = None + + for i, ts in enumerate(times): + if i % cfg.get("log_interval") == 0: + logger.info(f" Timestep {i + 1}/{len(times)}") + preds = ( + torch.load( + resolve_ts_dir(out_root, ts) / ts / f"{ts}-predictions", + weights_only=False, + ) + * cfg.get("conv_factor_hourly") + ) # shape: (n_members, n_channels, lat, lon) + + if n_members is None: + n_members = preds.shape[0] + member_hist = [np.zeros(n_hist_bins, dtype=np.int64) for _ in range(n_members)] + member_totals = [0] * n_members + + for m in range(n_members): + land_values = (preds[m, tp_out] * land_mask).values + land_values = land_values[~np.isnan(land_values)] + member_hist[m] += np.histogram(land_values, bins=hist_bins)[0] + member_totals[m] += len(land_values) + + logger.info(f"Collected {n_members} ensemble members for predictions") + + # -- Build percentile grid -- + # Dense in the body, finer in the upper tail + percentile_values = np.unique(np.concatenate([ + np.linspace(1.0, 90.0, 90), + np.linspace(90.0, 99.0, 90), + np.linspace(99.0, 99.9, 45), + np.linspace(99.9, 99.99, 20), + ])) + frac_percentiles = percentile_values / 100.0 + + # -- Compute target quantiles -- + if 'target' not in hist_counts: + logger.error("No target data found; cannot compute bias.") + return + + target_quantiles = _hist_quantiles(hist_counts['target'], hist_bins, frac_percentiles) + + # -- Compute biases -- + bias_data: dict = {} + labels: list = [] + colors: list = [] + + for mode, label, color in [ + ('baseline', 'Input', 'orange'), + ('regression-prediction', 'Regression Prediction', 'red'), + ]: + if mode in hist_counts: + bias_data[mode] = _hist_quantiles(hist_counts[mode], hist_bins, frac_percentiles) - target_quantiles + labels.append(label) + colors.append(color) + + member_biases = [] + if member_hist is not None and n_members > 0: + member_quantiles = [ + _hist_quantiles(member_hist[m], hist_bins, frac_percentiles) + for m in range(n_members) + ] + member_biases = [q - target_quantiles for q in member_quantiles] + bias_data['predictions'] = member_biases + labels.append('CorrDiff Ensemble (mean ± 1σ)') + colors.append('green') + + output_path = out_root / cfg.get("results_dir_name", "evaluation_maps") + output_path.mkdir(parents=True, exist_ok=True) + + # -- Bias plot -- + fn = output_path / 'precipitation_bias_by_percentile.png' + save_bias_by_percentile_plot( + bias_data, + percentile_values, + labels, + colors, + title='Precipitation Bias by Percentile - Over Land (Pooled Data)', + xlabel='Percentile', + ylabel='Bias (Pred − Target) [mm/h]', + out_path=fn, + ) + logger.info(f"Bias-by-percentile plot saved: {fn}") + + # -- MAE plot -- + mae_data: dict = {} + mae_labels: list = [] + mae_colors: list = [] + for mode, label, color in [ + ('baseline', 'Input', 'orange'), + ('regression-prediction', 'Regression Prediction', 'red'), + ]: + if mode in hist_counts: + pred_q = _hist_quantiles(hist_counts[mode], hist_bins, frac_percentiles) + mae_data[mode] = np.abs(pred_q - target_quantiles) + mae_labels.append(label) + mae_colors.append(color) + if member_biases: + mae_data['predictions'] = [np.abs(b) for b in member_biases] + mae_labels.append('CorrDiff Ensemble (mean ± 1σ)') + mae_colors.append('green') + + fn_mae = output_path / 'precipitation_mae_by_percentile.png' + save_mae_by_percentile_plot( + mae_data, + percentile_values, + mae_labels, + mae_colors, + title='Precipitation MAE by Percentile - Over Land (Pooled Data)', + xlabel='Percentile', + ylabel='MAE [mm/h]', + out_path=fn_mae, + ) + logger.info(f"MAE-by-percentile plot saved: {fn_mae}") + + # -- Ensemble spread plot -- + if member_biases: + # std of member quantiles = std of member biases (target_quantiles is constant) + spread = np.std(member_biases, axis=0) + fn_spread = output_path / 'precipitation_spread_by_percentile.png' + save_spread_by_percentile_plot( + spread, + percentile_values, + title='Precipitation Ensemble Spread by Percentile - Over Land (Pooled Data)', + xlabel='Percentile', + ylabel='Spread (std of member quantiles) [mm/h]', + out_path=fn_spread, + ) + logger.info(f"Spread-by-percentile plot saved: {fn_spread}") + + +if __name__ == '__main__': + main(parse_eval_cli()) diff --git a/src/hirad/eval_precip.sh b/src/hirad/eval_precip.sh index ed040373..c95d90be 100644 --- a/src/hirad/eval_precip.sh +++ b/src/hirad/eval_precip.sh @@ -27,12 +27,15 @@ srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.to # Diurnal cycle # python src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py --config-name=${CONFIG_NAME} # python src/hirad/eval/diurnal_cycle_precip_high_percentiles.py --config-name=${CONFIG_NAME} - # python -m hirad.eval.diurnal_cycle_precip_maps --config-name=${CONFIG_NAME} # Histograms # python src/hirad/eval/hist.py --config-name=${CONFIG_NAME} # python src/hirad/eval/probability_of_exceedance.py --config-name=${CONFIG_NAME} + # QQ + # python -m hirad.eval.bias_by_percentile_precip --config-name=${CONFIG_NAME} + # Maps # python src/hirad/eval/map_precip_stats.py --config-name=${CONFIG_NAME} + # python -m hirad.eval.diurnal_cycle_precip_maps --config-name=${CONFIG_NAME} " \ No newline at end of file From 028092035e3568c51cfdddd9d22dd7a18402d128 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Thu, 28 May 2026 17:03:52 +0200 Subject: [PATCH 04/51] split temp into eval_temp.sh --- src/hirad/eval/diurnal_cycle_precip_p99.py | 165 ------------------ src/hirad/eval/diurnal_cycle_temp.py | 125 +++++++++++++ ...cle_temp_wind.py => diurnal_cycle_wind.py} | 50 +----- src/hirad/eval_temp.sh | 29 +++ src/hirad/eval_wind.sh | 4 +- 5 files changed, 161 insertions(+), 212 deletions(-) delete mode 100644 src/hirad/eval/diurnal_cycle_precip_p99.py create mode 100644 src/hirad/eval/diurnal_cycle_temp.py rename src/hirad/eval/{diurnal_cycle_temp_wind.py => diurnal_cycle_wind.py} (70%) create mode 100644 src/hirad/eval_temp.sh diff --git a/src/hirad/eval/diurnal_cycle_precip_p99.py b/src/hirad/eval/diurnal_cycle_precip_p99.py deleted file mode 100644 index ecd257cb..00000000 --- a/src/hirad/eval/diurnal_cycle_precip_p99.py +++ /dev/null @@ -1,165 +0,0 @@ -""" -Plots the diurnal cycle of the all-hour 99th percentile of -precipitation, a somewhat reliable measure of the precipitation intensity. - -Each hour, member and type is treaded separately, to conserve memory... but if the -period is long, this can still be a lot of data and thus an OOM error can occur. -""" -import logging -from datetime import datetime -from pathlib import Path - -import hydra -import matplotlib.pyplot as plt -import numpy as np -import torch -import xarray as xr - -from hirad.eval.eval_utils import get_channel_indices, load_generation_setup, load_land_sea_mask, parse_eval_cli, resolve_ts_dir - - -def save_plot(hours, lines, labels, ylabel, title, out_path): - Path(out_path).parent.mkdir(parents=True, exist_ok=True) - plt.figure(figsize=(8,4)) - for data, label in zip(lines, labels): - if isinstance(data, tuple): # (mean, std) - mean, std = data - lower = np.maximum(np.array(mean) - std, 0) - upper = np.array(mean) + std - line, = plt.plot(hours, mean, label=label) - plt.fill_between(hours, lower, upper, alpha=0.3, color=line.get_color()) - else: - plt.plot(hours, data, label=label) - plt.xlabel('Hour (UTC)') - plt.xticks(range(0,25,3)) - plt.xlim(0,24) - plt.ylabel(ylabel) - plt.title(title) - plt.grid(True) - plt.legend() - plt.tight_layout() - plt.savefig(out_path) - plt.close() - - -def main(cfg: dict): - # Setup logging - logging.basicConfig(level=logging.INFO) - logger = logging.getLogger(__name__) - - logger.info("Starting computation for diurnal cycle of 99th-percentile of precipitation") - try: - generation_dir, gen_cfg, times = load_generation_setup(cfg) - except ValueError as exc: - logger.error(str(exc)) - return - logger.info(f"Loaded {len(times)} timesteps to process") - - # Output root - out_root = Path(generation_dir) - - # Find channel indices - indices = get_channel_indices(gen_cfg) - tp_out = indices['output']['tp'] - tp_in = indices['input'].get('tp', tp_out) - logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") - - # Land-sea mask - land_mask = load_land_sea_mask(cfg.get("land_sea_mask_path"), cfg.get("height"), cfg.get("width")) - - # Storage for diurnal cycles - pct99_mean = {} - pct99_std = {} - - # -- Process target and baseline -- - for mode in ['target', 'baseline', 'regression-prediction']: - logger.info(f"Processing mode: {mode}") - - data_list = [] - try: - for ts in times: - data = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode in ['target','regression-prediction'] else tp_in] * cfg.get("conv_factor") - data_list.append(data) - except: - logger.error(f"Error loading data for mode {mode}. Skipping.") - continue - - da = xr.DataArray( - np.stack(data_list, axis=0), - dims=['time', 'lat', 'lon'], - coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times], - 'lat': land_mask.coords['lat'], 'lon': land_mask.coords['lon']} - ) - - # Select only land pixels to avoid all-NaN slices in quantile - land_bool = land_mask.notnull().stack(space=('lat', 'lon')) - da_land = da.stack(space=('lat', 'lon')).isel(space=land_bool.values) - - # Group by hour and compute 99th percentile over time, then spatial mean - hourly_p99 = da_land.groupby('time.hour').quantile(0.99, dim='time') - pct99_mean[mode] = hourly_p99.mean(dim='space') - - # -- Predictions: compute per hour per member, then mean+std across members -- - logger.info("Processing predictions") - - # Load all prediction data at once into xarray - pred_data_list = [] - for ts in times: - preds = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-predictions", weights_only=False) * cfg.get("conv_factor") # [n_members, n_channels, lat, lon] - tp_data = preds[:, tp_out] # [n_members, lat, lon] - tp_da = xr.DataArray(tp_data, dims=['member', 'lat', 'lon'], - coords={'lat': land_mask.coords['lat'], 'lon': land_mask.coords['lon']}) - pred_data_list.append(tp_da) - - pred_da = xr.concat(pred_data_list, dim='time') # [member, time, lat, lon] - pred_da = pred_da.assign_coords({ - 'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times] - }) - pred_da = pred_da.transpose('member', 'time', 'lat', 'lon') - - # Select only land pixels to avoid all-NaN slices in quantile - land_bool = land_mask.notnull().stack(space=('lat', 'lon')) - pred_da_land = pred_da.stack(space=('lat', 'lon')).isel(space=land_bool.values) - - logger.info('Calculating 99th percentile for predictions') - # Group by hour, compute 99th percentile across time, then spatial mean over land - hourly_p99_by_member = pred_da_land.groupby('time.hour').quantile(0.99, dim='time').mean(dim='space') - - # Store ensemble statistics as xarray DataArrays - pct99_mean['prediction'] = hourly_p99_by_member.mean(dim='member') - pct99_std['prediction'] = hourly_p99_by_member.std(dim='member') - - # Prepare cyclic lists for plotting - def cycle_fn(x): - return x.values.tolist() + [x.values.tolist()[0]] - - logger.info("Preparing data for plotting") - hrs_c = list(range(24)) + [0 + 24] - pct99_lines = [ - cycle_fn(pct99_mean['target']), - cycle_fn(pct99_mean['baseline']), - ( - cycle_fn(pct99_mean['prediction']), - cycle_fn(pct99_std['prediction']) - ) - ] - if 'regression-prediction' in pct99_mean: - pct99_lines.append(cycle_fn(pct99_mean['regression-prediction'])) - - # Plot combined diurnal 99th-percentile cycle - labels = ['Target', 'Input', 'CorrDiff 99th Pct ± Std', 'Regression Prediction'] if 'regression-prediction' in pct99_mean else ['Target', 'Input', 'CorrDiff 99th Pct ± Std'] - output_path = out_root / cfg.get("results_dir_name", "evaluation_maps") - output_path.mkdir(parents=True, exist_ok=True) - fn = output_path / 'diurnal_cycle_precip_99th_percentile.png' - save_plot( - hrs_c, - pct99_lines, - labels, - 'Precipitation (mm/day)', - 'Diurnal Cycle of 99th-Percentile Precipitation', - fn - ) - logger.info(f"Combined plot saved: {fn}") - -if __name__ == '__main__': - main(parse_eval_cli()) \ No newline at end of file diff --git a/src/hirad/eval/diurnal_cycle_temp.py b/src/hirad/eval/diurnal_cycle_temp.py new file mode 100644 index 00000000..940fb65b --- /dev/null +++ b/src/hirad/eval/diurnal_cycle_temp.py @@ -0,0 +1,125 @@ +import logging +from datetime import datetime +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch +import xarray as xr + +from hirad.eval.eval_utils import concat_and_group_diurnal, get_channel_indices, load_generation_setup, load_land_sea_mask, parse_eval_cli, resolve_ts_dir + +def main(cfg: dict): + # Initialize + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info("Starting computation for diurnal cycle of 2m temperature") + try: + generation_dir, gen_cfg, times = load_generation_setup(cfg) + except ValueError as exc: + logger.error(str(exc)) + return + datetimes = [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times] + logger.info(f"Loaded {len(times)} timesteps to process") + + # Indices for channels + indices = get_channel_indices(gen_cfg) + out_ch = indices['output'] + in_ch = indices['input'] + + # Temperature channel (try '2t' first, fallback to 't2m') + t2m_out = out_ch.get('2t', out_ch.get('t2m')) + t2m_in = in_ch.get('2t', in_ch.get('t2m', t2m_out)) + + # Output path + out_root = Path(generation_dir) + def load(ts, fn): + return torch.load(resolve_ts_dir(out_root, ts) / ts / fn, weights_only=False) + + # Land-sea mask + land_mask = load_land_sea_mask(cfg.get("land_sea_mask_path"), cfg.get("height"), cfg.get("width")) + + # Prepare lists to collect DataArrays + target_temp, baseline_temp, pred_temp, mean_pred_temp = [], [], [], [] + + def mean_over_land(data, dims, coords, time_coord): + da = xr.DataArray(data, dims=dims, coords=coords) * land_mask + return da.mean(dim=("lat","lon")).assign_coords(time=time_coord) + + # Loop over timestamps + for idx, ts in enumerate(times, 1): + dt = datetimes[idx-1] + + # Load data + target = load(ts, f"{ts}-target") + baseline = load(ts, f"{ts}-baseline") + predictions = load(ts, f"{ts}-predictions") + try: + regression_pred = load(ts, f"{ts}-regression-prediction") + except: + regression_pred = None + + # Process temperature (convert to Celsius) + target_temp.append(mean_over_land( + target[t2m_out] - 273.15, ("lat","lon"), land_mask.coords, dt)) + baseline_temp.append(mean_over_land( + baseline[t2m_in] - 273.15, ("lat","lon"), land_mask.coords, dt)) + pred_temp.append(mean_over_land( + predictions[:, t2m_out, :, :] - 273.15, ("member","lat","lon"), + {"member": np.arange(predictions.shape[0]), **land_mask.coords}, dt)) + if regression_pred is not None: + mean_pred_temp.append(mean_over_land( + regression_pred[t2m_out] - 273.15, ("lat","lon"), land_mask.coords, dt)) + + if idx % cfg.get("log_interval") == 0 or idx == len(times): + logger.info(f"Processed {idx}/{len(times)} timesteps ({ts})") + + # Compute diurnal means and stds + temp_target_mean, _ = concat_and_group_diurnal(target_temp) + temp_baseline_mean, _ = concat_and_group_diurnal(baseline_temp) + temp_pred_mean, temp_pred_std = concat_and_group_diurnal(pred_temp, is_member=True) + if mean_pred_temp: + temp_mean_pred_mean, _ = concat_and_group_diurnal(mean_pred_temp) + + def save_plot(hour, means, stds, labels, ylabel, title, out_path): + hrs = np.concatenate([hour.values, [24]]) + plt.figure(figsize=(8,4)) + for mean, std, label in zip(means, stds, labels): + vals = np.append(mean.values, mean.values[0]) + line, = plt.plot(hrs, vals, label=label) + if std is not None: + stdv = np.append(std.values, std.values[0]) + plt.fill_between(hrs, np.maximum(vals - stdv, 0), vals + stdv, color=line.get_color(), alpha=0.3) + plt.xlabel('Hour (UTC)') + plt.xticks(range(0,25,3)) + plt.xlim(0,24) + plt.ylabel(ylabel) + plt.title(title) + plt.grid(True) + plt.legend() + plt.tight_layout() + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + plt.savefig(out_path) + plt.close() + + data = [temp_target_mean, temp_baseline_mean, temp_pred_mean, temp_mean_pred_mean] if mean_pred_temp else [temp_target_mean, temp_baseline_mean, temp_pred_mean] + labels = ['Target', 'Input', 'CorrDiff ± Std(Members)', 'Regression Prediction'] if mean_pred_temp else ['Target', 'Input', 'CorrDiff ± Std(Members)'] + stds = [None, None, temp_pred_std, None] if mean_pred_temp else [None, None, temp_pred_std] + + output_path = out_root / cfg.get("results_dir_name", "evaluation_maps") + output_path.mkdir(parents=True, exist_ok=True) + save_plot( + temp_target_mean.hour, + data, + stds, + labels, + '2m Temperature [°C]', + 'Diurnal Cycle of 2m Temperature', + output_path / 'diurnal_cycle_2t.png' + ) + + logger.info("Plots saved.") + +if __name__ == '__main__': + main(parse_eval_cli()) diff --git a/src/hirad/eval/diurnal_cycle_temp_wind.py b/src/hirad/eval/diurnal_cycle_wind.py similarity index 70% rename from src/hirad/eval/diurnal_cycle_temp_wind.py rename to src/hirad/eval/diurnal_cycle_wind.py index c9b63f8d..ea056c4c 100644 --- a/src/hirad/eval/diurnal_cycle_temp_wind.py +++ b/src/hirad/eval/diurnal_cycle_wind.py @@ -2,7 +2,6 @@ from datetime import datetime from pathlib import Path -import hydra import matplotlib.pyplot as plt import numpy as np import torch @@ -15,7 +14,7 @@ def main(cfg: dict): logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) - logger.info("Starting computation for diurnal cycles of 2m temperature and windspeed") + logger.info("Starting computation for diurnal cycle of windspeed") try: generation_dir, gen_cfg, times = load_generation_setup(cfg) except ValueError as exc: @@ -28,15 +27,11 @@ def main(cfg: dict): indices = get_channel_indices(gen_cfg) out_ch = indices['output'] in_ch = indices['input'] - - # Temperature channel (try '2t' first, fallback to 't2m') - t2m_out = out_ch.get('2t', out_ch.get('t2m')) - t2m_in = in_ch.get('2t', in_ch.get('t2m', t2m_out)) - + # Wind channels u_out = out_ch['10u'] u_in = in_ch.get('10u', u_out) - v_out = out_ch['10v'] + v_out = out_ch['10v'] v_in = in_ch.get('10v', v_out) # Output path @@ -48,7 +43,6 @@ def load(ts, fn): land_mask = load_land_sea_mask(cfg.get("land_sea_mask_path"), cfg.get("height"), cfg.get("width")) # Prepare lists to collect DataArrays - target_temp, baseline_temp, pred_temp, mean_pred_temp = [], [], [], [] target_wind, baseline_wind, pred_wind, mean_pred_wind = [], [], [], [] def mean_over_land(data, dims, coords, time_coord): @@ -68,19 +62,6 @@ def mean_over_land(data, dims, coords, time_coord): except: regression_pred = None - # Process temperature (convert to Celsius) - target_temp.append(mean_over_land( - target[t2m_out] - 273.15, ("lat","lon"), land_mask.coords, dt)) - baseline_temp.append(mean_over_land( - baseline[t2m_in] - 273.15, ("lat","lon"), land_mask.coords, dt)) - pred_temp.append(mean_over_land( - predictions[:, t2m_out, :, :] - 273.15, ("member","lat","lon"), - {"member": np.arange(predictions.shape[0]), **land_mask.coords}, dt)) - if regression_pred is not None: - mean_pred_temp.append(mean_over_land( - regression_pred[t2m_out] - 273.15, ("lat","lon"), land_mask.coords, dt)) - - # Process wind speed target_wind.append(mean_over_land( np.hypot(target[u_out], target[v_out]), ("lat","lon"), land_mask.coords, dt)) @@ -97,12 +78,6 @@ def mean_over_land(data, dims, coords, time_coord): logger.info(f"Processed {idx}/{len(times)} timesteps ({ts})") # Compute diurnal means and stds - temp_target_mean, _ = concat_and_group_diurnal(target_temp) - temp_baseline_mean, _ = concat_and_group_diurnal(baseline_temp) - temp_pred_mean, temp_pred_std = concat_and_group_diurnal(pred_temp, is_member=True) - if mean_pred_temp: - temp_mean_pred_mean, _ = concat_and_group_diurnal(mean_pred_temp) - wind_target_mean, _ = concat_and_group_diurnal(target_wind) wind_baseline_mean, _ = concat_and_group_diurnal(baseline_wind) wind_pred_mean, wind_pred_std = concat_and_group_diurnal(pred_wind, is_member=True) @@ -130,27 +105,12 @@ def save_plot(hour, means, stds, labels, ylabel, title, out_path): plt.savefig(out_path) plt.close() - data = [temp_target_mean, temp_baseline_mean, temp_pred_mean, temp_mean_pred_mean] if mean_pred_temp else [temp_target_mean, temp_baseline_mean, temp_pred_mean] - labels = ['Target', 'Input', 'CorrDiff ± Std(Members)', 'Regression Prediction'] if mean_pred_temp else ['Target', 'Input', 'CorrDiff ± Std(Members)'] - stds = [None, None, temp_pred_std, None] if mean_pred_temp else [None, None, temp_pred_std] - - output_path = out_root / cfg.get("results_dir_name", "evaluation_maps") - output_path.mkdir(parents=True, exist_ok=True) - # Generate plots - save_plot( - temp_target_mean.hour, - data, - stds, - labels, - '2m Temperature [°C]', - 'Diurnal Cycle of 2m Temperature', - output_path / 'diurnal_cycle_2t.png' - ) - data = [wind_target_mean, wind_baseline_mean, wind_pred_mean, wind_mean_pred_mean] if mean_pred_wind else [wind_target_mean, wind_baseline_mean, wind_pred_mean] labels = ['Target', 'Input', 'CorrDiff ± Std(Members)', 'Regression Prediction'] if mean_pred_wind else ['Target', 'Input', 'CorrDiff ± Std(Members)'] stds = [None, None, wind_pred_std, None] if mean_pred_wind else [None, None, wind_pred_std] + output_path = out_root / cfg.get("results_dir_name", "evaluation_maps") + output_path.mkdir(parents=True, exist_ok=True) save_plot( wind_target_mean.hour, data, diff --git a/src/hirad/eval_temp.sh b/src/hirad/eval_temp.sh new file mode 100644 index 00000000..fe0108ea --- /dev/null +++ b/src/hirad/eval_temp.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +#SBATCH --job-name="eval_temp" + +### HARDWARE ### +#SBATCH --partition=normal +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=72 +#SBATCH --time=12:00:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=./logs/plots_temp.log + +### ENVIRONMENT #### +#SBATCH -A a161 + +### CONFIG ### +CONFIG_NAME="src/hirad/conf/eval_real.yaml" + +srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -c " + pip install -e . + + # Diurnal cycle of 2m temperature + # python src/hirad/eval/diurnal_cycle_temp.py --config-name=${CONFIG_NAME} +" diff --git a/src/hirad/eval_wind.sh b/src/hirad/eval_wind.sh index 88f3f221..7a9176c5 100644 --- a/src/hirad/eval_wind.sh +++ b/src/hirad/eval_wind.sh @@ -24,8 +24,8 @@ CONFIG_NAME="src/hirad/conf/eval_real.yaml" srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -c " pip install -e . - # Diurnal cycle - # python src/hirad/eval/diurnal_cycle_temp_wind.py --config-name=${CONFIG_NAME} + # Diurnal cycle of windspeed + # python src/hirad/eval/diurnal_cycle_wind.py --config-name=${CONFIG_NAME} # Probability of exceedance # python src/hirad/eval/probability_of_exceedance_wind.py --config-name=${CONFIG_NAME} From 0940f9a79618e65c17a055300edc6fa9ab945d55 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Thu, 28 May 2026 18:16:40 +0200 Subject: [PATCH 05/51] average at the end --- src/hirad/eval/bias_by_percentile_precip.py | 360 +++++++++++--------- 1 file changed, 207 insertions(+), 153 deletions(-) diff --git a/src/hirad/eval/bias_by_percentile_precip.py b/src/hirad/eval/bias_by_percentile_precip.py index b4458822..cfb703c2 100644 --- a/src/hirad/eval/bias_by_percentile_precip.py +++ b/src/hirad/eval/bias_by_percentile_precip.py @@ -1,13 +1,10 @@ """ -Plots the bias (prediction - target) as a function of percentile for precipitation. +Plots bias / MAE / spread as a function of percentile for precipitation, using a +local-then-averaged estimator. -For each percentile level p, the bias is: - - bias(p) = quantile_pred(p) - quantile_target(p) - -Positive bias means the model over-predicts at that quantile; negative means -under-prediction. For ensemble predictions the per-member biases are averaged -and the ±1 sigma spread is shaded. +For each grid point g (and ensemble member m), a histogram of precipitation is +built over time and the per-percentile quantile q_{g,m}(p) is estimated. +Spatial / member averaging is then applied to produce the plotted curves: """ import logging from pathlib import Path @@ -29,21 +26,66 @@ # helpers # --------------------------------------------------------------------------- -def _hist_quantiles(hist_counts, bin_edges, frac_percentiles): - """Vectorised quantile estimation from a histogram via linear CDF interpolation. +def _accumulate_per_point_hist(pp_counts: np.ndarray, values_land: np.ndarray, + bin_edges: np.ndarray, n_bins: int) -> None: + """In-place add of one timestep of land values into per-grid-point histograms. - Parameters - ---------- - hist_counts : (N,) int array - bin_edges : (N+1,) float array - frac_percentiles : (P,) float array - fractional values in [0, 1] + pp_counts : (n_land, n_bins) int32 – modified in place. + values_land : (n_land,) float – one value per land grid point. + """ + # searchsorted over the interior edges: result in [0, n_bins - 1] + bin_idx = np.searchsorted(bin_edges[1:-1], values_land, side='right') + n_land = pp_counts.shape[0] + # Fancy indexing with unique indices is fully vectorised. + pp_counts.reshape(-1)[np.arange(n_land) * n_bins + bin_idx] += 1 + + +def _per_point_quantiles(pp_counts: np.ndarray, bin_edges: np.ndarray, + frac_percentiles: np.ndarray) -> np.ndarray: + """Estimate per-row quantiles from per-grid-point histograms. - Returns - ------- - (P,) float array of estimated quantile values. + Returns (n_land, P) float32 array. Uses upper-bin-edge values (no in-bin + interpolation) — adequate given fine log-spaced bins. """ - cdf = np.cumsum(hist_counts) / max(hist_counts.sum(), 1) - return np.interp(frac_percentiles, cdf, bin_edges[1:]) + cdf = pp_counts.astype(np.float32, copy=True) + np.cumsum(cdf, axis=1, out=cdf) + totals = cdf[:, -1:].copy() + cdf /= np.maximum(totals, 1.0) + + edges_upper = bin_edges[1:].astype(np.float32) + n_land, n_bins = pp_counts.shape + out = np.empty((n_land, len(frac_percentiles)), dtype=np.float32) + + # Reusable bool buffer to avoid repeated allocation. + buf = np.empty(cdf.shape, dtype=bool) + for j, p in enumerate(frac_percentiles): + np.less(cdf, p, out=buf) + idx = buf.sum(axis=1) # first bin where cdf >= p + np.clip(idx, 0, n_bins - 1, out=idx) + out[:, j] = edges_upper[idx] + return out + + +def _build_per_point_histogram(load_fn, times: list, land_idx: np.ndarray, + hist_bins: np.ndarray, log_interval: int, + logger: logging.Logger, mode_name: str + ) -> np.ndarray | None: + """Stream timesteps through `load_fn(ts) -> (H*W,) float array` and return + (n_land, n_bins) int32 per-grid-point histogram, or None on failure.""" + n_land = land_idx.size + n_bins = len(hist_bins) - 1 + pp_counts = np.zeros((n_land, n_bins), dtype=np.int32) + try: + for i, ts in enumerate(times): + if i % log_interval == 0: + logger.info(f" [{mode_name}] timestep {i + 1}/{len(times)}") + flat = load_fn(ts) # (H*W,) float, no NaN + _accumulate_per_point_hist(pp_counts, flat[land_idx], hist_bins, n_bins) + except FileNotFoundError: + logger.warning(f" {mode_name} data not found, skipping") + return None + return pp_counts + # --------------------------------------------------------------------------- @@ -98,7 +140,10 @@ def save_bias_by_percentile_plot( # Logit x-axis: compresses the centre and stretches both tails _apply_logit_xaxis(ax, frac) - ax.set_ylim(-5, 5) + # Symlog: linear within ±linthresh, logarithmic beyond → "log away from zero" + ax.set_yscale('symlog', linthresh=0.1, linscale=0.3) + ax.set_ylim(-10, 10) + ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:g}')) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_title(title) @@ -154,7 +199,9 @@ def save_mae_by_percentile_plot( ax.plot(frac, mae_data, color=color, label=label, linewidth=2, alpha=0.85) _apply_logit_xaxis(ax, frac) - ax.set_ylim(0, 5) + ax.set_yscale('log') + ax.set_ylim(1e-3, 10) + ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:g}')) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_title(title) @@ -183,7 +230,9 @@ def save_spread_by_percentile_plot( ax.plot(frac, spread, color='green', linewidth=2, label='CorrDiff Ensemble') _apply_logit_xaxis(ax, frac) - ax.set_ylim(0, 5) + ax.set_yscale('log') + ax.set_ylim(1e-3, 10) + ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:g}')) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_title(title) @@ -201,7 +250,7 @@ def main(cfg: dict) -> None: logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) - logger.info("Starting bias-by-percentile computation for precipitation over land") + logger.info("Starting per-gridpoint bias-by-percentile computation for precipitation over land") try: generation_dir, gen_cfg, times = load_generation_setup(cfg) except ValueError as exc: @@ -215,84 +264,28 @@ def main(cfg: dict) -> None: indices = get_channel_indices(gen_cfg) tp_out = indices['output']['tp'] tp_in = indices['input'].get('tp', tp_out) - logger.info(f"TP channel indices – output: {tp_out}, input: {tp_in}") + logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") - # Land-sea mask - land_mask = load_land_sea_mask( + # Land-sea mask: build a boolean mask and a flat index list of land points + land_da = load_land_sea_mask( cfg.get("land_sea_mask_path"), cfg.get("height"), cfg.get("width") ) - - # Fine-grained log-spaced histogram bins for accurate tail quantiles + land_bool_2d = np.isfinite(land_da.values) # (H, W) + land_idx = np.flatnonzero(land_bool_2d.ravel()) # (n_land,) + n_land = land_idx.size + logger.info(f"{n_land} land grid points") + + # Log-spaced histogram bins. Coarser than the pooled version since each + # grid point only contributes ~ T samples (here T = {len(times)}). + n_bins = 500 hist_bins = np.concatenate([ np.array([0.0]), - np.logspace(-2, 3.2, 5000), # 0.01 → ~1585 mm/h + np.logspace(-2, 3.2, n_bins), # 0.01 -> ~1585 mm/h ]) - n_hist_bins = len(hist_bins) - 1 - - hist_counts: dict[str, np.ndarray] = {} - totals: dict[str, int] = {} - - # -- Target and deterministic baselines -- - for mode in ['target', 'baseline', 'regression-prediction']: - logger.info(f"Processing mode: {mode}") - mode_hist = np.zeros(n_hist_bins, dtype=np.int64) - mode_total = 0 - try: - for i, ts in enumerate(times): - if i % cfg.get("log_interval") == 0: - logger.info(f" Timestep {i + 1}/{len(times)}") - ch_idx = tp_out if mode in ('target', 'regression-prediction') else tp_in - data = ( - torch.load( - resolve_ts_dir(out_root, ts) / ts / f"{ts}-{mode}", - weights_only=False, - )[ch_idx] - * cfg.get("conv_factor_hourly") - * land_mask - ) - land_values = data.values[~np.isnan(data.values)] - mode_hist += np.histogram(land_values, bins=hist_bins)[0] - mode_total += len(land_values) - except Exception: - logger.warning(f" {mode} data not found, skipping") - continue - - hist_counts[mode] = mode_hist - totals[mode] = mode_total - logger.info(f" Processed {mode_total} land values for {mode}") + log_interval = cfg.get("log_interval", 24) + conv = cfg.get("conv_factor_hourly") - # -- Ensemble predictions -- - logger.info("Processing predictions") - n_members: int | None = None - member_hist: list[np.ndarray] | None = None - member_totals: list[int] | None = None - - for i, ts in enumerate(times): - if i % cfg.get("log_interval") == 0: - logger.info(f" Timestep {i + 1}/{len(times)}") - preds = ( - torch.load( - resolve_ts_dir(out_root, ts) / ts / f"{ts}-predictions", - weights_only=False, - ) - * cfg.get("conv_factor_hourly") - ) # shape: (n_members, n_channels, lat, lon) - - if n_members is None: - n_members = preds.shape[0] - member_hist = [np.zeros(n_hist_bins, dtype=np.int64) for _ in range(n_members)] - member_totals = [0] * n_members - - for m in range(n_members): - land_values = (preds[m, tp_out] * land_mask).values - land_values = land_values[~np.isnan(land_values)] - member_hist[m] += np.histogram(land_values, bins=hist_bins)[0] - member_totals[m] += len(land_values) - - logger.info(f"Collected {n_members} ensemble members for predictions") - - # -- Build percentile grid -- - # Dense in the body, finer in the upper tail + # -- Percentile grid (denser in the upper tail) -- percentile_values = np.unique(np.concatenate([ np.linspace(1.0, 90.0, 90), np.linspace(90.0, 99.0, 90), @@ -300,98 +293,159 @@ def main(cfg: dict) -> None: np.linspace(99.9, 99.99, 20), ])) frac_percentiles = percentile_values / 100.0 - - # -- Compute target quantiles -- - if 'target' not in hist_counts: + P = len(frac_percentiles) + + # --------------------------------------------------------------------- + # Loader helpers + # --------------------------------------------------------------------- + def _load_det(ts, mode): + ch_idx = tp_out if mode in ('target', 'regression-prediction') else tp_in + arr = torch.load( + resolve_ts_dir(out_root, ts) / ts / f"{ts}-{mode}", + weights_only=False, + )[ch_idx] + arr = np.asarray(getattr(arr, 'values', arr)) * conv + return arr.ravel() + + def _load_member(ts, m, preds_cache): + arr = preds_cache[m, tp_out] + arr = np.asarray(getattr(arr, 'values', arr)) * conv + return arr.ravel() + + # --------------------------------------------------------------------- + # Target: per-grid-point quantiles (needed by everything else) + # --------------------------------------------------------------------- + logger.info("Processing target") + target_pp_counts = _build_per_point_histogram( + lambda ts: _load_det(ts, 'target'), + times, land_idx, hist_bins, log_interval, logger, 'target', + ) + if target_pp_counts is None: logger.error("No target data found; cannot compute bias.") return + target_q = _per_point_quantiles(target_pp_counts, hist_bins, frac_percentiles) + del target_pp_counts + target_mean_q = target_q.mean(axis=0) - target_quantiles = _hist_quantiles(hist_counts['target'], hist_bins, frac_percentiles) - - # -- Compute biases -- + # --------------------------------------------------------------------- + # Deterministic modes + # --------------------------------------------------------------------- bias_data: dict = {} + mae_data: dict = {} labels: list = [] colors: list = [] + mae_labels: list = [] + mae_colors: list = [] for mode, label, color in [ ('baseline', 'Input', 'orange'), ('regression-prediction', 'Regression Prediction', 'red'), ]: - if mode in hist_counts: - bias_data[mode] = _hist_quantiles(hist_counts[mode], hist_bins, frac_percentiles) - target_quantiles - labels.append(label) - colors.append(color) - - member_biases = [] - if member_hist is not None and n_members > 0: - member_quantiles = [ - _hist_quantiles(member_hist[m], hist_bins, frac_percentiles) - for m in range(n_members) - ] - member_biases = [q - target_quantiles for q in member_quantiles] + logger.info(f"Processing {mode}") + pp = _build_per_point_histogram( + lambda ts, m=mode: _load_det(ts, m), + times, land_idx, hist_bins, log_interval, logger, mode, + ) + if pp is None: + continue + pred_q = _per_point_quantiles(pp, hist_bins, frac_percentiles) + del pp + bias_data[mode] = pred_q.mean(axis=0) - target_mean_q + mae_data[mode] = np.abs(pred_q - target_q).mean(axis=0) + labels.append(label) + colors.append(color) + mae_labels.append(label) + mae_colors.append(color) + + # --------------------------------------------------------------------- + # Ensemble predictions: process one timestep at a time, accumulating + # per-member per-grid-point histograms. Then collapse per-member to + # spatial-mean curves and online-aggregate ensemble statistics. + # --------------------------------------------------------------------- + logger.info("Processing predictions (per-member, per-grid-point)") + n_members: int | None = None + member_pp: list[np.ndarray] | None = None + + for i, ts in enumerate(times): + if i % log_interval == 0: + logger.info(f" [predictions] timestep {i + 1}/{len(times)}") + preds = torch.load( + resolve_ts_dir(out_root, ts) / ts / f"{ts}-predictions", + weights_only=False, + ) # (n_members, n_channels, H, W) + if n_members is None: + n_members = preds.shape[0] + member_pp = [ + np.zeros((n_land, n_bins), dtype=np.int32) for _ in range(n_members) + ] + logger.info(f" Detected {n_members} ensemble members") + for m in range(n_members): + arr = preds[m, tp_out] + flat = (np.asarray(getattr(arr, 'values', arr)) * conv).ravel() + _accumulate_per_point_hist(member_pp[m], flat[land_idx], hist_bins, n_bins) + + member_biases: list[np.ndarray] = [] + member_maes: list[np.ndarray] = [] + # Online aggregates for spread: E[std_m(q_{g,m})] over g. + # Need per-(g, p) std across members -> keep running sum and sum-of-squares + # of per-member per-point quantiles. + if member_pp is not None and n_members is not None and n_members > 0: + sum_q = np.zeros((n_land, P), dtype=np.float64) + sumsq_q = np.zeros((n_land, P), dtype=np.float64) + for m in range(n_members): + qm = _per_point_quantiles(member_pp[m], hist_bins, frac_percentiles) + member_pp[m] = None # free as we go + sum_q += qm + sumsq_q += qm.astype(np.float64) ** 2 + member_biases.append((qm.mean(axis=0) - target_mean_q).astype(np.float64)) + member_maes.append(np.abs(qm - target_q).mean(axis=0).astype(np.float64)) + mean_q = sum_q / n_members + var_q = np.maximum(sumsq_q / n_members - mean_q ** 2, 0.0) + # Per-gridpoint std across members, then spatial mean. + spread = np.sqrt(var_q).mean(axis=0) + bias_data['predictions'] = member_biases - labels.append('CorrDiff Ensemble (mean ± 1σ)') + mae_data['predictions'] = member_maes + labels.append('CorrDiff Ensemble (mean +/- 1 sigma)') colors.append('green') + mae_labels.append('CorrDiff Ensemble (mean +/- 1 sigma)') + mae_colors.append('green') + else: + spread = None - output_path = out_root / cfg.get("results_dir_name", "evaluation_maps") + # --------------------------------------------------------------------- + # Output + # --------------------------------------------------------------------- + output_path = out_root / cfg.get("results_dir_name", "evaluation_maps") / "by_percentile" output_path.mkdir(parents=True, exist_ok=True) - # -- Bias plot -- fn = output_path / 'precipitation_bias_by_percentile.png' save_bias_by_percentile_plot( - bias_data, - percentile_values, - labels, - colors, - title='Precipitation Bias by Percentile - Over Land (Pooled Data)', + bias_data, percentile_values, labels, colors, + title='Precipitation Bias by Percentile - Over Land (Per-Gridpoint)', xlabel='Percentile', - ylabel='Bias (Pred − Target) [mm/h]', + ylabel='Bias [mm/h]', out_path=fn, ) logger.info(f"Bias-by-percentile plot saved: {fn}") - # -- MAE plot -- - mae_data: dict = {} - mae_labels: list = [] - mae_colors: list = [] - for mode, label, color in [ - ('baseline', 'Input', 'orange'), - ('regression-prediction', 'Regression Prediction', 'red'), - ]: - if mode in hist_counts: - pred_q = _hist_quantiles(hist_counts[mode], hist_bins, frac_percentiles) - mae_data[mode] = np.abs(pred_q - target_quantiles) - mae_labels.append(label) - mae_colors.append(color) - if member_biases: - mae_data['predictions'] = [np.abs(b) for b in member_biases] - mae_labels.append('CorrDiff Ensemble (mean ± 1σ)') - mae_colors.append('green') - fn_mae = output_path / 'precipitation_mae_by_percentile.png' save_mae_by_percentile_plot( - mae_data, - percentile_values, - mae_labels, - mae_colors, - title='Precipitation MAE by Percentile - Over Land (Pooled Data)', + mae_data, percentile_values, mae_labels, mae_colors, + title='Precipitation MAE by Percentile - Over Land (Per-Gridpoint)', xlabel='Percentile', ylabel='MAE [mm/h]', out_path=fn_mae, ) logger.info(f"MAE-by-percentile plot saved: {fn_mae}") - # -- Ensemble spread plot -- - if member_biases: - # std of member quantiles = std of member biases (target_quantiles is constant) - spread = np.std(member_biases, axis=0) + if spread is not None: fn_spread = output_path / 'precipitation_spread_by_percentile.png' save_spread_by_percentile_plot( - spread, - percentile_values, - title='Precipitation Ensemble Spread by Percentile - Over Land (Pooled Data)', + spread, percentile_values, + title='Precipitation Ensemble Spread by Percentile - Over Land (Per-Gridpoint)', xlabel='Percentile', - ylabel='Spread (std of member quantiles) [mm/h]', + ylabel='Spread (mean over land of std across members) [mm/h]', out_path=fn_spread, ) logger.info(f"Spread-by-percentile plot saved: {fn_spread}") From ec8df12cb7311a5c20804eaabae8823ccdf7de61 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Tue, 2 Jun 2026 15:53:52 +0200 Subject: [PATCH 06/51] submit more jobs in parallel --- src/hirad/eval_precip.sh | 60 +++++++++++++++++++--------------------- src/hirad/eval_temp.sh | 45 +++++++++++++++--------------- src/hirad/eval_wind.sh | 51 +++++++++++++++++----------------- 3 files changed, 77 insertions(+), 79 deletions(-) diff --git a/src/hirad/eval_precip.sh b/src/hirad/eval_precip.sh index c12a11be..aacdae03 100644 --- a/src/hirad/eval_precip.sh +++ b/src/hirad/eval_precip.sh @@ -1,41 +1,39 @@ #!/bin/bash -#SBATCH --job-name="eval_precip" - -### HARDWARE ### -#SBATCH --partition=normal -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-node=1 -#SBATCH --cpus-per-task=72 -#SBATCH --time=12:00:00 -#SBATCH --no-requeue -#SBATCH --exclusive - -### OUTPUT ### -#SBATCH --output=./logs/plots_precip.log - -### ENVIRONMENT #### -#SBATCH -A a161 +set -euo pipefail ### CONFIG ### CONFIG_NAME="src/hirad/conf/eval_real.yaml" -srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -c " - pip install -e . - +CMDS=( # Diurnal cycle - # python src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py --config-name=${CONFIG_NAME} - # python src/hirad/eval/diurnal_cycle_precip_high_percentiles.py --config-name=${CONFIG_NAME} - + "python src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py" + "python src/hirad/eval/diurnal_cycle_precip_high_percentiles.py" # Histograms - # python src/hirad/eval/hist.py --config-name=${CONFIG_NAME} - # python src/hirad/eval/probability_of_exceedance.py --config-name=${CONFIG_NAME} - + "python src/hirad/eval/hist.py" + "python src/hirad/eval/probability_of_exceedance.py" # QQ - # python -m hirad.eval.bias_by_percentile_precip --config-name=${CONFIG_NAME} - + "python -m hirad.eval.bias_by_percentile_precip" # Maps - # python src/hirad/eval/map_precip_stats.py --config-name=${CONFIG_NAME} - # python -m hirad.eval.diurnal_cycle_precip_maps --config-name=${CONFIG_NAME} -" \ No newline at end of file + "python src/hirad/eval/map_precip_stats.py" + "python -m hirad.eval.diurnal_cycle_precip_maps" +) + +for cmd in "${CMDS[@]}"; do + name=$(basename "${cmd##* }" .py | tr '.' '_') + job_id=$(sbatch \ + --job-name="eval_precip_${name}" \ + --partition=normal \ + --nodes=1 \ + --ntasks-per-node=1 \ + --gpus-per-node=1 \ + --cpus-per-task=72 \ + --time=24:00:00 \ + --no-requeue \ + --exclusive \ + -A c38 \ + --output="./logs/plots_precip_${name}_%j.log" \ + --parsable \ + --wrap="srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -lc 'pip install -e . && ${cmd} --config-name=${CONFIG_NAME}'") + echo "Submitted ${name}: ${job_id}" +done \ No newline at end of file diff --git a/src/hirad/eval_temp.sh b/src/hirad/eval_temp.sh index fe0108ea..1b56acf9 100644 --- a/src/hirad/eval_temp.sh +++ b/src/hirad/eval_temp.sh @@ -1,29 +1,30 @@ #!/bin/bash -#SBATCH --job-name="eval_temp" - -### HARDWARE ### -#SBATCH --partition=normal -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-node=1 -#SBATCH --cpus-per-task=72 -#SBATCH --time=12:00:00 -#SBATCH --no-requeue -#SBATCH --exclusive - -### OUTPUT ### -#SBATCH --output=./logs/plots_temp.log - -### ENVIRONMENT #### -#SBATCH -A a161 +set -euo pipefail ### CONFIG ### CONFIG_NAME="src/hirad/conf/eval_real.yaml" -srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -c " - pip install -e . - +CMDS=( # Diurnal cycle of 2m temperature - # python src/hirad/eval/diurnal_cycle_temp.py --config-name=${CONFIG_NAME} -" + "python src/hirad/eval/diurnal_cycle_temp.py" +) + +for cmd in "${CMDS[@]}"; do + name=$(basename "${cmd##* }" .py | tr '.' '_') + job_id=$(sbatch \ + --job-name="eval_temp_${name}" \ + --partition=normal \ + --nodes=1 \ + --ntasks-per-node=1 \ + --gpus-per-node=1 \ + --cpus-per-task=72 \ + --time=12:00:00 \ + --no-requeue \ + --exclusive \ + -A c38 \ + --output="./logs/plots_temp_${name}_%j.log" \ + --parsable \ + --wrap="srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -lc 'pip install -e . && ${cmd} --config-name=${CONFIG_NAME}'") + echo "Submitted ${name}: ${job_id}" +done diff --git a/src/hirad/eval_wind.sh b/src/hirad/eval_wind.sh index 7a9176c5..6f1ce39d 100644 --- a/src/hirad/eval_wind.sh +++ b/src/hirad/eval_wind.sh @@ -1,35 +1,34 @@ #!/bin/bash -#SBATCH --job-name="eval_wind" - -### HARDWARE ### -#SBATCH --partition=normal -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-node=1 -#SBATCH --cpus-per-task=72 -#SBATCH --time=12:00:00 -#SBATCH --no-requeue -#SBATCH --exclusive - -### OUTPUT ### -#SBATCH --output=./logs/plots_wind.log - -### ENVIRONMENT #### -#SBATCH -A a161 +set -euo pipefail ### CONFIG ### CONFIG_NAME="src/hirad/conf/eval_real.yaml" -srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -c " - pip install -e . - +CMDS=( # Diurnal cycle of windspeed - # python src/hirad/eval/diurnal_cycle_wind.py --config-name=${CONFIG_NAME} - + "python src/hirad/eval/diurnal_cycle_wind.py" # Probability of exceedance - # python src/hirad/eval/probability_of_exceedance_wind.py --config-name=${CONFIG_NAME} - + "python src/hirad/eval/probability_of_exceedance_wind.py" # Maps - # python src/hirad/eval/map_wind_stats.py --config-name=${CONFIG_NAME} -" \ No newline at end of file + "python src/hirad/eval/map_wind_stats.py" +) + +for cmd in "${CMDS[@]}"; do + name=$(basename "${cmd##* }" .py | tr '.' '_') + job_id=$(sbatch \ + --job-name="eval_wind_${name}" \ + --partition=normal \ + --nodes=1 \ + --ntasks-per-node=1 \ + --gpus-per-node=1 \ + --cpus-per-task=72 \ + --time=24:00:00 \ + --no-requeue \ + --exclusive \ + -A c38 \ + --output="./logs/plots_wind_${name}_%j.log" \ + --parsable \ + --wrap="srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -lc 'pip install -e . && ${cmd} --config-name=${CONFIG_NAME}'") + echo "Submitted ${name}: ${job_id}" +done \ No newline at end of file From 6476df596878743a71e49f39445fdf27a89c1830 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Tue, 2 Jun 2026 16:42:17 +0200 Subject: [PATCH 07/51] manicure --- src/hirad/eval/bias_by_percentile_precip.py | 328 ++++++++++---------- 1 file changed, 164 insertions(+), 164 deletions(-) diff --git a/src/hirad/eval/bias_by_percentile_precip.py b/src/hirad/eval/bias_by_percentile_precip.py index cfb703c2..80a5a41f 100644 --- a/src/hirad/eval/bias_by_percentile_precip.py +++ b/src/hirad/eval/bias_by_percentile_precip.py @@ -4,8 +4,9 @@ For each grid point g (and ensemble member m), a histogram of precipitation is built over time and the per-percentile quantile q_{g,m}(p) is estimated. -Spatial / member averaging is then applied to produce the plotted curves: +Spatial / member averaging is then applied to produce the plotted curves. """ +import concurrent.futures import logging from pathlib import Path @@ -22,75 +23,129 @@ ) -# --------------------------------------------------------------------------- -# helpers -# --------------------------------------------------------------------------- +def _to_flat(arr, conv: float) -> np.ndarray: + """Convert a (possibly xarray) field to a flat float numpy array scaled by conv.""" + return (np.asarray(getattr(arr, 'values', arr)) * conv).ravel() -def _accumulate_per_point_hist(pp_counts: np.ndarray, values_land: np.ndarray, - bin_edges: np.ndarray, n_bins: int) -> None: - """In-place add of one timestep of land values into per-grid-point histograms. - pp_counts : (n_land, n_bins) int32 – modified in place. - values_land : (n_land,) float – one value per land grid point. +def _build_all_histograms( + times: list, + ts_dirs: dict, + tp_out: int, + tp_in: int, + conv: float, + land_idx: np.ndarray, + hist_bins: np.ndarray, + log_interval: int, + logger: logging.Logger, +) -> tuple: + """Single serial pass over all timesteps, building histograms for every mode. + + Returns ``(det_counts, member_counts, n_members)`` where *det_counts* maps + mode → array-or-None and *member_counts* is a list of arrays (one per member) + or None. """ - # searchsorted over the interior edges: result in [0, n_bins - 1] - bin_idx = np.searchsorted(bin_edges[1:-1], values_land, side='right') - n_land = pp_counts.shape[0] - # Fancy indexing with unique indices is fully vectorised. - pp_counts.reshape(-1)[np.arange(n_land) * n_bins + bin_idx] += 1 + n_land = land_idx.size + n_bins = len(hist_bins) - 1 + det_modes = ('target', 'baseline', 'regression-prediction') + interior_edges = hist_bins[1:-1] + row_offsets = np.arange(n_land, dtype=np.intp) * n_bins + + det_counts: dict = {m: np.zeros((n_land, n_bins), dtype=np.int32) for m in det_modes} + member_counts: list | None = None + n_members: int | None = None + skip_det: set = set() + skip_preds = False + + def accumulate(counts, arr): + vals = _to_flat(arr, conv)[land_idx] + bin_idx = np.searchsorted(interior_edges, vals, side='right') + counts.reshape(-1)[row_offsets + bin_idx] += 1 + + logger.info(f"Processing all modes in a single pass ({len(times)} timesteps)") + + for i, ts in enumerate(times): + if i % log_interval == 0: + logger.info(f" timestep {i + 1}/{len(times)}") + + ts_dir = ts_dirs[ts] + + for mode in det_modes: + if mode in skip_det: + continue + ch = tp_in if mode == 'baseline' else tp_out + try: + arr = torch.load(ts_dir / f"{ts}-{mode}", weights_only=False)[ch] + except FileNotFoundError: + logger.warning(f" [{mode}] file not found at {ts}, skipping mode") + skip_det.add(mode) + continue + accumulate(det_counts[mode], arr) + + if not skip_preds: + try: + preds = torch.load(ts_dir / f"{ts}-predictions", weights_only=False) + except FileNotFoundError: + logger.warning(f" [predictions] file not found at {ts}, skipping ensemble") + skip_preds = True + continue + + if n_members is None: + n_members = preds.shape[0] + member_counts = [ + np.zeros((n_land, n_bins), dtype=np.int32) + for _ in range(n_members) + ] + logger.info(f" Detected {n_members} ensemble members") + for m in range(n_members): + accumulate(member_counts[m], preds[m, tp_out]) + + for mode in skip_det: + det_counts[mode] = None + + return det_counts, member_counts, n_members def _per_point_quantiles(pp_counts: np.ndarray, bin_edges: np.ndarray, - frac_percentiles: np.ndarray) -> np.ndarray: + frac_percentiles: np.ndarray, + block_size: int = 8192) -> np.ndarray: """Estimate per-row quantiles from per-grid-point histograms. Returns (n_land, P) float32 array. Uses upper-bin-edge values (no in-bin interpolation) — adequate given fine log-spaced bins. - """ - cdf = pp_counts.astype(np.float32, copy=True) - np.cumsum(cdf, axis=1, out=cdf) - totals = cdf[:, -1:].copy() - cdf /= np.maximum(totals, 1.0) - edges_upper = bin_edges[1:].astype(np.float32) + Processes land points in blocks of *block_size* rows so the CDF working set + (~block_size × n_bins × 8 bytes) fits comfortably in L3 cache, making the + row-wise ``searchsorted`` cache-friendly and GIL-free (numpy releases the + GIL for large C-level operations, enabling true thread parallelism when + multiple calls run concurrently). + """ n_land, n_bins = pp_counts.shape - out = np.empty((n_land, len(frac_percentiles)), dtype=np.float32) + P = len(frac_percentiles) + result = np.empty((n_land, P), dtype=np.float32) + edges_upper = bin_edges[1:].astype(np.float32) + frac_f64 = frac_percentiles.astype(np.float64) - # Reusable bool buffer to avoid repeated allocation. - buf = np.empty(cdf.shape, dtype=bool) - for j, p in enumerate(frac_percentiles): - np.less(cdf, p, out=buf) - idx = buf.sum(axis=1) # first bin where cdf >= p - np.clip(idx, 0, n_bins - 1, out=idx) - out[:, j] = edges_upper[idx] - return out + for start in range(0, n_land, block_size): + end = min(start + block_size, n_land) + blk = pp_counts[start:end] + B = end - start + cdf = np.cumsum(blk, axis=1, dtype=np.float64) + totals = cdf[:, -1:] + cdf /= np.maximum(totals, 1.0) -def _build_per_point_histogram(load_fn, times: list, land_idx: np.ndarray, - hist_bins: np.ndarray, log_interval: int, - logger: logging.Logger, mode_name: str - ) -> np.ndarray | None: - """Stream timesteps through `load_fn(ts) -> (H*W,) float array` and return - (n_land, n_bins) int32 per-grid-point histogram, or None on failure.""" - n_land = land_idx.size - n_bins = len(hist_bins) - 1 - pp_counts = np.zeros((n_land, n_bins), dtype=np.int32) - try: - for i, ts in enumerate(times): - if i % log_interval == 0: - logger.info(f" [{mode_name}] timestep {i + 1}/{len(times)}") - flat = load_fn(ts) # (H*W,) float, no NaN - _accumulate_per_point_hist(pp_counts, flat[land_idx], hist_bins, n_bins) - except FileNotFoundError: - logger.warning(f" {mode_name} data not found, skipping") - return None - return pp_counts + offset = (np.arange(B, dtype=np.float64) * 2.0)[:, None] + cdf += offset + queries = frac_f64[None, :] + offset + idx = np.searchsorted(cdf.ravel(), queries.ravel(), side='left') + idx = idx.reshape(B, P) - (np.arange(B, dtype=np.intp)[:, None] * n_bins) + np.clip(idx, 0, n_bins - 1, out=idx) + result[start:end] = edges_upper[idx] + return result -# --------------------------------------------------------------------------- -# plotting -# --------------------------------------------------------------------------- def save_bias_by_percentile_plot( bias_data_dict: dict, @@ -112,14 +167,11 @@ def save_bias_by_percentile_plot( Path(out_path).parent.mkdir(parents=True, exist_ok=True) fig, ax = plt.subplots(figsize=(10, 6)) - - # Convert to fractions in (0, 1) for logit scale frac = percentile_values / 100.0 for (key, bias_data), label, color in zip(bias_data_dict.items(), labels, colors): if isinstance(bias_data, (list, tuple)): - # Ensemble: plot member average ± 1 σ - arr = np.array(bias_data) # (n_members, n_percentiles) + arr = np.array(bias_data) mean_bias = arr.mean(axis=0) std_bias = arr.std(axis=0) ax.plot(frac, mean_bias, color=color, label=label, linewidth=2) @@ -137,10 +189,7 @@ def save_bias_by_percentile_plot( ) ax.axhline(0.0, color='black', linewidth=0.8, linestyle='--', label='Zero bias') - - # Logit x-axis: compresses the centre and stretches both tails _apply_logit_xaxis(ax, frac) - # Symlog: linear within ±linthresh, logarithmic beyond → "log away from zero" ax.set_yscale('symlog', linthresh=0.1, linscale=0.3) ax.set_ylim(-10, 10) ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:g}')) @@ -185,7 +234,7 @@ def save_mae_by_percentile_plot( for (key, mae_data), label, color in zip(mae_data_dict.items(), labels, colors): if isinstance(mae_data, (list, tuple)): - arr = np.array(mae_data) # (n_members, n_percentiles), already absolute + arr = np.array(mae_data) mean_mae = arr.mean(axis=0) std_mae = arr.std(axis=0) ax.plot(frac, mean_mae, color=color, label=label, linewidth=2) @@ -242,15 +291,11 @@ def save_spread_by_percentile_plot( plt.close() -# --------------------------------------------------------------------------- -# main -# --------------------------------------------------------------------------- - def main(cfg: dict) -> None: logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) - logger.info("Starting per-gridpoint bias-by-percentile computation for precipitation over land") + logger.info("Starting bias-by-percentile computation for precipitation over land") try: generation_dir, gen_cfg, times = load_generation_setup(cfg) except ValueError as exc: @@ -260,32 +305,27 @@ def main(cfg: dict) -> None: out_root = Path(generation_dir) - # Channel indices indices = get_channel_indices(gen_cfg) tp_out = indices['output']['tp'] tp_in = indices['input'].get('tp', tp_out) logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") - # Land-sea mask: build a boolean mask and a flat index list of land points land_da = load_land_sea_mask( cfg.get("land_sea_mask_path"), cfg.get("height"), cfg.get("width") ) - land_bool_2d = np.isfinite(land_da.values) # (H, W) - land_idx = np.flatnonzero(land_bool_2d.ravel()) # (n_land,) + land_bool_2d = np.isfinite(land_da.values) + land_idx = np.flatnonzero(land_bool_2d.ravel()) n_land = land_idx.size logger.info(f"{n_land} land grid points") - # Log-spaced histogram bins. Coarser than the pooled version since each - # grid point only contributes ~ T samples (here T = {len(times)}). n_bins = 500 hist_bins = np.concatenate([ np.array([0.0]), - np.logspace(-2, 3.2, n_bins), # 0.01 -> ~1585 mm/h + np.logspace(-2, 3.2, n_bins), ]) log_interval = cfg.get("log_interval", 24) - conv = cfg.get("conv_factor_hourly") + conv = cfg.get("conv_factor_hourly", 1.0) - # -- Percentile grid (denser in the upper tail) -- percentile_values = np.unique(np.concatenate([ np.linspace(1.0, 90.0, 90), np.linspace(90.0, 99.0, 90), @@ -295,127 +335,86 @@ def main(cfg: dict) -> None: frac_percentiles = percentile_values / 100.0 P = len(frac_percentiles) - # --------------------------------------------------------------------- - # Loader helpers - # --------------------------------------------------------------------- - def _load_det(ts, mode): - ch_idx = tp_out if mode in ('target', 'regression-prediction') else tp_in - arr = torch.load( - resolve_ts_dir(out_root, ts) / ts / f"{ts}-{mode}", - weights_only=False, - )[ch_idx] - arr = np.asarray(getattr(arr, 'values', arr)) * conv - return arr.ravel() - - def _load_member(ts, m, preds_cache): - arr = preds_cache[m, tp_out] - arr = np.asarray(getattr(arr, 'values', arr)) * conv - return arr.ravel() - - # --------------------------------------------------------------------- - # Target: per-grid-point quantiles (needed by everything else) - # --------------------------------------------------------------------- - logger.info("Processing target") - target_pp_counts = _build_per_point_histogram( - lambda ts: _load_det(ts, 'target'), - times, land_idx, hist_bins, log_interval, logger, 'target', + logger.info(f"Resolving {len(times)} timestep directories ...") + ts_dirs = {ts: resolve_ts_dir(out_root, ts) / ts for ts in times} + + det_counts, member_counts, n_members = _build_all_histograms( + times, ts_dirs, tp_out, tp_in, conv, land_idx, hist_bins, + log_interval, logger, ) - if target_pp_counts is None: + + if det_counts.get('target') is None: logger.error("No target data found; cannot compute bias.") return - target_q = _per_point_quantiles(target_pp_counts, hist_bins, frac_percentiles) - del target_pp_counts + + det_mode_cfg = [ + ('baseline', 'Input', 'orange'), + ('regression-prediction', 'Regression Prediction', 'red'), + ] + active_det_modes = [m for m, _, _ in det_mode_cfg if det_counts.get(m) is not None] + has_ensemble = member_counts is not None and n_members is not None and n_members > 0 + + n_tasks = 1 + len(active_det_modes) + (n_members if has_ensemble else 0) + n_quant_workers = cfg.get("n_quant_workers", n_tasks) + + with concurrent.futures.ThreadPoolExecutor(max_workers=n_quant_workers) as pool: + def submit(counts): + return pool.submit(_per_point_quantiles, counts, hist_bins, frac_percentiles) + + fut_target = submit(det_counts['target']) + fut_det = {mode: submit(det_counts[mode]) for mode in active_det_modes} + fut_members = ( + [submit(member_counts[m]) for m in range(n_members)] + if has_ensemble else [] + ) + + target_q = fut_target.result() + del det_counts['target'] + det_results = {mode: fut_det[mode].result() for mode in active_det_modes} + for mode in active_det_modes: + del det_counts[mode] + member_qs = [f.result() for f in fut_members] + if has_ensemble: + for m in range(n_members): + member_counts[m] = None + target_mean_q = target_q.mean(axis=0) - # --------------------------------------------------------------------- - # Deterministic modes - # --------------------------------------------------------------------- bias_data: dict = {} mae_data: dict = {} labels: list = [] colors: list = [] - mae_labels: list = [] - mae_colors: list = [] - for mode, label, color in [ - ('baseline', 'Input', 'orange'), - ('regression-prediction', 'Regression Prediction', 'red'), - ]: - logger.info(f"Processing {mode}") - pp = _build_per_point_histogram( - lambda ts, m=mode: _load_det(ts, m), - times, land_idx, hist_bins, log_interval, logger, mode, - ) - if pp is None: + for mode, label, color in det_mode_cfg: + if mode not in det_results: continue - pred_q = _per_point_quantiles(pp, hist_bins, frac_percentiles) - del pp + pred_q = det_results.pop(mode) bias_data[mode] = pred_q.mean(axis=0) - target_mean_q mae_data[mode] = np.abs(pred_q - target_q).mean(axis=0) labels.append(label) colors.append(color) - mae_labels.append(label) - mae_colors.append(color) - - # --------------------------------------------------------------------- - # Ensemble predictions: process one timestep at a time, accumulating - # per-member per-grid-point histograms. Then collapse per-member to - # spatial-mean curves and online-aggregate ensemble statistics. - # --------------------------------------------------------------------- - logger.info("Processing predictions (per-member, per-grid-point)") - n_members: int | None = None - member_pp: list[np.ndarray] | None = None - - for i, ts in enumerate(times): - if i % log_interval == 0: - logger.info(f" [predictions] timestep {i + 1}/{len(times)}") - preds = torch.load( - resolve_ts_dir(out_root, ts) / ts / f"{ts}-predictions", - weights_only=False, - ) # (n_members, n_channels, H, W) - if n_members is None: - n_members = preds.shape[0] - member_pp = [ - np.zeros((n_land, n_bins), dtype=np.int32) for _ in range(n_members) - ] - logger.info(f" Detected {n_members} ensemble members") - for m in range(n_members): - arr = preds[m, tp_out] - flat = (np.asarray(getattr(arr, 'values', arr)) * conv).ravel() - _accumulate_per_point_hist(member_pp[m], flat[land_idx], hist_bins, n_bins) member_biases: list[np.ndarray] = [] member_maes: list[np.ndarray] = [] - # Online aggregates for spread: E[std_m(q_{g,m})] over g. - # Need per-(g, p) std across members -> keep running sum and sum-of-squares - # of per-member per-point quantiles. - if member_pp is not None and n_members is not None and n_members > 0: + spread = None + + if has_ensemble: sum_q = np.zeros((n_land, P), dtype=np.float64) sumsq_q = np.zeros((n_land, P), dtype=np.float64) - for m in range(n_members): - qm = _per_point_quantiles(member_pp[m], hist_bins, frac_percentiles) - member_pp[m] = None # free as we go + for qm in member_qs: sum_q += qm sumsq_q += qm.astype(np.float64) ** 2 member_biases.append((qm.mean(axis=0) - target_mean_q).astype(np.float64)) member_maes.append(np.abs(qm - target_q).mean(axis=0).astype(np.float64)) mean_q = sum_q / n_members var_q = np.maximum(sumsq_q / n_members - mean_q ** 2, 0.0) - # Per-gridpoint std across members, then spatial mean. spread = np.sqrt(var_q).mean(axis=0) bias_data['predictions'] = member_biases mae_data['predictions'] = member_maes labels.append('CorrDiff Ensemble (mean +/- 1 sigma)') colors.append('green') - mae_labels.append('CorrDiff Ensemble (mean +/- 1 sigma)') - mae_colors.append('green') - else: - spread = None - - # --------------------------------------------------------------------- - # Output - # --------------------------------------------------------------------- + output_path = out_root / cfg.get("results_dir_name", "evaluation_maps") / "by_percentile" output_path.mkdir(parents=True, exist_ok=True) @@ -431,7 +430,7 @@ def _load_member(ts, m, preds_cache): fn_mae = output_path / 'precipitation_mae_by_percentile.png' save_mae_by_percentile_plot( - mae_data, percentile_values, mae_labels, mae_colors, + mae_data, percentile_values, labels, colors, title='Precipitation MAE by Percentile - Over Land (Per-Gridpoint)', xlabel='Percentile', ylabel='MAE [mm/h]', @@ -451,5 +450,6 @@ def _load_member(ts, m, preds_cache): logger.info(f"Spread-by-percentile plot saved: {fn_spread}") + if __name__ == '__main__': main(parse_eval_cli()) From 6f7244713eb6fd84fdc54e93dbf97e4fcdd66b3b Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 3 Jun 2026 09:41:53 +0200 Subject: [PATCH 08/51] reduce memory --- src/hirad/eval/map_precip_stats.py | 39 ++++++++++++++---------------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/src/hirad/eval/map_precip_stats.py b/src/hirad/eval/map_precip_stats.py index 30e2f601..a4bcde9c 100644 --- a/src/hirad/eval/map_precip_stats.py +++ b/src/hirad/eval/map_precip_stats.py @@ -196,8 +196,8 @@ def main(cfg: dict): logger.warning(f"{mode} not available, skipping") continue - # Stack into (T, H, W) numpy array - mode_data = np.stack(data_list, axis=0).astype(np.float64) + # Stack into (T, H, W) numpy array. float32 to save memory. + mode_data = np.stack(data_list, axis=0).astype(np.float32) del data_list for stat_config in stat_configs: @@ -207,33 +207,30 @@ def main(cfg: dict): map_output_dir.mkdir(parents=True, exist_ok=True) plot_stat_map(result, str(map_output_dir / f'{mode}_{stat_config["stat_name"]}'), stat_config, label, grid_cfg) - # --- Predictions: load each file ONCE, distribute to all members --- + del mode_data + + # --- Predictions: process ONE member at a time to bound memory usage --- logger.info("Processing predictions mode...") sample_data = torch.load(resolve_ts_dir(out_root, times[0]) / times[0] / f"{times[0]}-predictions", weights_only=False) n_members = sample_data.shape[0] del sample_data logger.info(f"Found {n_members} ensemble members") - # Pre-allocate arrays for ALL members at once: (n_members, T, H, W) - # If memory is tight, we can do this in chunks. For 16 members × 2200 × 704 × 1088 × 4 bytes ≈ 107 GB - # Instead we can do cummulative statistics on the fly without storing all members in memory (like in map_wind_stats), but this works for now. - H, W = cfg.get("height"), cfg.get("width") - member_arrays = [np.empty((len(times), H, W), dtype=np.float32) for _ in range(n_members)] - - logger.info("Loading all prediction timesteps (single pass over files)...") - for i, ts in enumerate(times): - if i % log_interval == 0: - logger.info(f"Loading predictions timestep {i+1}/{len(times)}: {ts}") - pred_data = torch.load(out_root / ts / f"{ts}-predictions", weights_only=False) * conv_factor - for m in range(n_members): - member_arrays[m][i] = (pred_data[m, tp_out].numpy() if isinstance(pred_data, torch.Tensor) - else pred_data[m, tp_out]) - del pred_data + H: int = cfg["height"] + W: int = cfg["width"] + member_data = np.empty((len(times), H, W), dtype=np.float32) for member_idx in range(n_members): - logger.info(f"Computing statistics for prediction member {member_idx+1}/{n_members}") - member_data = member_arrays[member_idx].astype(np.float64) + logger.info(f"Loading prediction member {member_idx+1}/{n_members} (single pass over files)...") + for i, ts in enumerate(times): + if i % log_interval == 0: + logger.info(f"Loading predictions member {member_idx+1} timestep {i+1}/{len(times)}: {ts}") + pred_data = torch.load(resolve_ts_dir(out_root, ts) / ts / f"{ts}-predictions", weights_only=False) + member_slice = pred_data[member_idx, tp_out] + member_data[i] = (member_slice.numpy() if isinstance(member_slice, torch.Tensor) else member_slice) * conv_factor + del pred_data + logger.info(f"Computing statistics for prediction member {member_idx+1}/{n_members}") for stat_config in stat_configs: logger.info(f"Computing {stat_config['title_stat']} for member {member_idx+1}...") member_result = apply_statistic(member_data, times_dt, stat_config['type'], stat_config['param'], wet_threshold) @@ -243,7 +240,7 @@ def main(cfg: dict): member_label = f'CorrDiff Member {member_idx+1}' plot_stat_map(member_result, member_filename, stat_config, member_label, grid_cfg) - del member_arrays + del member_data logger.info("All precipitation statistics maps generated successfully") From 61508d9d799b5ead013b47fb2ad8c9ffaaa860e6 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 3 Jun 2026 10:14:31 +0200 Subject: [PATCH 09/51] speedup --- .../eval/probability_of_exceedance_wind.py | 273 ++++++++++-------- 1 file changed, 151 insertions(+), 122 deletions(-) diff --git a/src/hirad/eval/probability_of_exceedance_wind.py b/src/hirad/eval/probability_of_exceedance_wind.py index f280e2db..fa2d82da 100644 --- a/src/hirad/eval/probability_of_exceedance_wind.py +++ b/src/hirad/eval/probability_of_exceedance_wind.py @@ -1,15 +1,12 @@ """Probability of exceedance for wind speed and components.""" import logging +import time from pathlib import Path -import hydra import matplotlib.pyplot as plt import numpy as np import torch -import xarray as xr -from hirad.datasets import get_channels_from_strings, get_strings_from_channels -from hirad.utils.function_utils import get_time_from_range from hirad.eval.eval_utils import get_channel_indices, load_generation_setup, parse_eval_cli, resolve_ts_dir from hirad.eval.eval_utils import percentiles_from_histogram @@ -28,18 +25,21 @@ def compute_exceedance_probs(values, thresholds, use_abs=False): def update_exceedance_counts(counts, total, values, thresholds, use_abs=False): - """Update exceedance counts incrementally.""" + """Update exceedance counts incrementally via searchsorted (O(n log k)).""" data = np.abs(values) if use_abs else values - counts += (data[:, None] > thresholds[None, :]).sum(axis=0) - total += len(values) + # idx[i] = number of thresholds strictly less than data[i] (side='left') + # data[i] > thresholds[j] iff idx[i] > j + idx = np.searchsorted(thresholds, data, side='left') + bin_counts = np.bincount(idx, minlength=len(thresholds) + 1) + counts += len(data) - np.cumsum(bin_counts)[: len(thresholds)] + total += len(data) return counts, total def compute_percentiles(values, percentile_dict, use_abs=False): """Compute percentiles.""" data = np.abs(values) if use_abs else values - data_array = xr.DataArray(data) - return {key: data_array.quantile(p).item() for key, p in percentile_dict.items()} + return {key: float(np.quantile(data, p)) for key, p in percentile_dict.items()} def save_exceedance_plot(exceedance_data_dict, thresholds, labels, colors, title, ylabel, out_path, percentiles_data=None): @@ -119,6 +119,34 @@ def save_exceedance_plot(exceedance_data_dict, thresholds, labels, colors, title plt.close() +def _accumulate(exc_counts, h_counts, speed_vals, u_vals, v_vals, + thresholds, hist_interior, n_hist_bins): + """Update exceedance and histogram count arrays in-place.""" + abs_u = np.abs(u_vals) + abs_v = np.abs(v_vals) + n_thr = len(thresholds) + + # Exceedance counts: O(n log k) via searchsorted+bincount + for counts, data in [ + (exc_counts['speed'], speed_vals), + (exc_counts['u'], abs_u), + (exc_counts['v'], abs_v), + ]: + idx = np.searchsorted(thresholds, data, side='left') + bin_counts = np.bincount(idx, minlength=n_thr + 1) + counts += len(data) - np.cumsum(bin_counts)[:n_thr] + + # Histogram counts: searchsorted+bincount, no temporary bool array + for hcounts, data in [ + (h_counts['speed'], speed_vals), + (h_counts['u'], abs_u), + (h_counts['v'], abs_v), + ]: + hcounts += np.bincount( + np.searchsorted(hist_interior, data, side='right'), minlength=n_hist_bins + ) + + def main(cfg: dict): # Setup logging logging.basicConfig(level=logging.INFO) @@ -141,12 +169,13 @@ def main(cfg: dict): v10_out = indices['output'].get('10v') u10_in = indices['input'].get('10u', u10_out) v10_in = indices['input'].get('10v', v10_out) - + if u10_out is None or v10_out is None: logger.error("Wind components (10u, 10v) not found in dataset!") return - - logger.info(f"Wind component channel indices - output: 10u={u10_out}, 10v={v10_out}, input: 10u={u10_in}, 10v={v10_in}") + + logger.info(f"Wind component channel indices - output: 10u={u10_out}, 10v={v10_out}, " + f"input: 10u={u10_in}, 10v={v10_in}") # Define thresholds for exceedance calculation (same for all variables) thresholds = np.logspace(-1, 2, 200) # From 0.1 to ~100 m/s @@ -155,123 +184,123 @@ def main(cfg: dict): # Histogram bins for percentile estimation (fine-grained log-spaced) hist_bins = np.concatenate([ np.array([0.0]), - np.logspace(-1, 2.5, 5000) # From 0.1 to ~316 m/s + np.logspace(-1, 2.5, 5000) # From 0.1 to ~316 m/s ]) n_hist_bins = len(hist_bins) - 1 - - # Storage for exceedance counts (incremental computation) - exceedance_counts = { - 'speed': {}, 'u': {}, 'v': {} - } - totals = {'speed': {}, 'u': {}, 'v': {}} - hist_counts = { - 'speed': {}, 'u': {}, 'v': {} - } - - # -- Process target and baseline -- - for mode in ['target', 'baseline', 'regression-prediction']: - logger.info(f"Processing mode: {mode}") - - # Initialize counts - for var in ['speed', 'u', 'v']: + # Interior edges used by searchsorted+bincount (equivalent to np.histogram) + hist_interior = hist_bins[1:-1] + + det_modes = ('target', 'baseline', 'regression-prediction') + VARS = ('speed', 'u', 'v') + + # Storage for exceedance counts and histograms + exceedance_counts = {var: {} for var in VARS} + totals = {var: {} for var in VARS} + hist_counts = {var: {} for var in VARS} + for mode in det_modes: + for var in VARS: exceedance_counts[var][mode] = np.zeros(n_thresholds, dtype=np.int64) - totals[var][mode] = 0 - hist_counts[var][mode] = np.zeros(n_hist_bins, dtype=np.int64) - - try: - for i, ts in enumerate(times): - if i % cfg.get("log_interval") == 0: - logger.info(f"Processing timestep {i+1}/{len(times)}") - - data = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-{mode}", weights_only=False) - - # Extract wind components - if mode in ['target', 'regression-prediction']: - u = data[u10_out] - v = data[v10_out] - else: # baseline - u = data[u10_in] - v = data[v10_in] - - wind_speed = compute_wind_speed(u, v) - - # Get valid values - valid_mask = ~np.isnan(wind_speed) - speed_vals = wind_speed[valid_mask].flatten() - u_vals = u[valid_mask].flatten() - v_vals = v[valid_mask].flatten() - - # Update exceedance counts incrementally - exceedance_counts['speed'][mode], totals['speed'][mode] = update_exceedance_counts( - exceedance_counts['speed'][mode], totals['speed'][mode], speed_vals, thresholds, use_abs=False - ) - exceedance_counts['u'][mode], totals['u'][mode] = update_exceedance_counts( - exceedance_counts['u'][mode], totals['u'][mode], u_vals, thresholds, use_abs=True - ) - exceedance_counts['v'][mode], totals['v'][mode] = update_exceedance_counts( - exceedance_counts['v'][mode], totals['v'][mode], v_vals, thresholds, use_abs=True - ) - - # Collect samples for percentiles (subsample to save memory) - hist_counts['speed'][mode] += np.histogram(speed_vals, bins=hist_bins)[0] - hist_counts['u'][mode] += np.histogram(np.abs(u_vals), bins=hist_bins)[0] - hist_counts['v'][mode] += np.histogram(np.abs(v_vals), bins=hist_bins)[0] - - except Exception as e: - logger.warning(f"{mode} data not found or error occurred, skipping: {e}") - continue - - logger.info(f"Processed {totals['speed'][mode]} values for {mode}") - - # -- Process predictions: compute exceedance for each ensemble member -- - logger.info("Processing predictions") - - n_members = None - member_counts = {'speed': [], 'u': [], 'v': []} - member_totals = {'speed': [], 'u': [], 'v': []} - member_hist_counts = {'speed': [], 'u': [], 'v': []} - + totals[var][mode] = 0 + hist_counts[var][mode] = np.zeros(n_hist_bins, dtype=np.int64) + + n_members = None + member_counts = {var: [] for var in VARS} + member_totals = {var: [] for var in VARS} + member_hist_counts = {var: [] for var in VARS} + + skip_det = set() # modes whose files were missing + skip_preds = False + + log_interval = cfg.get("log_interval", 24) + + # Pre-resolve timestamp directories once (avoids 4× filesystem glob per timestep) + logger.info(f"Resolving {len(times)} timestep directories ...") + ts_dirs = {ts: resolve_ts_dir(out_root, ts) / ts for ts in times} + + # -- Single pass over all timesteps -- + t0 = time.perf_counter() for i, ts in enumerate(times): - if i % cfg.get("log_interval") == 0: - logger.info(f"Processing timestep {i+1}/{len(times)}") - - preds = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-predictions", weights_only=False) # [n_members, n_channels, lat, lon] - - if n_members is None: - n_members = preds.shape[0] - for var in ['speed', 'u', 'v']: - member_counts[var] = [np.zeros(n_thresholds, dtype=np.int64) for _ in range(n_members)] - member_totals[var] = [0 for _ in range(n_members)] - member_hist_counts[var] = [np.zeros(n_hist_bins, dtype=np.int64) for _ in range(n_members)] - - for member_idx in range(n_members): - u = preds[member_idx, u10_out] - v = preds[member_idx, v10_out] + if i % log_interval == 0: + elapsed = time.perf_counter() - t0 + logger.info(f"Processing timestep {i+1}/{len(times)} ({elapsed:.1f}s elapsed)") + + ts_dir = ts_dirs[ts] + + # Deterministic modes: target, baseline, regression-prediction + for mode in det_modes: + if mode in skip_det: + continue + try: + data = torch.load(ts_dir / f"{ts}-{mode}", weights_only=False) + except FileNotFoundError: + logger.warning(f"[{mode}] file not found at {ts}, skipping mode entirely") + skip_det.add(mode) + continue + + ch_u = u10_in if mode == 'baseline' else u10_out + ch_v = v10_in if mode == 'baseline' else v10_out + u = data[ch_u] + v = data[ch_v] + wind_speed = compute_wind_speed(u, v) - valid_mask = ~np.isnan(wind_speed) - speed_vals = wind_speed[valid_mask].flatten() - u_vals = u[valid_mask].flatten() - v_vals = v[valid_mask].flatten() - - # Update counts - member_counts['speed'][member_idx], member_totals['speed'][member_idx] = update_exceedance_counts( - member_counts['speed'][member_idx], member_totals['speed'][member_idx], speed_vals, thresholds, use_abs=False - ) - member_counts['u'][member_idx], member_totals['u'][member_idx] = update_exceedance_counts( - member_counts['u'][member_idx], member_totals['u'][member_idx], u_vals, thresholds, use_abs=True - ) - member_counts['v'][member_idx], member_totals['v'][member_idx] = update_exceedance_counts( - member_counts['v'][member_idx], member_totals['v'][member_idx], v_vals, thresholds, use_abs=True + speed_vals = wind_speed[valid_mask].ravel() + u_vals = u[valid_mask].ravel() + v_vals = v[valid_mask].ravel() + + _accumulate( + {var: exceedance_counts[var][mode] for var in VARS}, + {var: hist_counts[var][mode] for var in VARS}, + speed_vals, u_vals, v_vals, + thresholds, hist_interior, n_hist_bins, ) - - # Collect samples for percentiles - member_hist_counts['speed'][member_idx] += np.histogram(speed_vals, bins=hist_bins)[0] - member_hist_counts['u'][member_idx] += np.histogram(np.abs(u_vals), bins=hist_bins)[0] - member_hist_counts['v'][member_idx] += np.histogram(np.abs(v_vals), bins=hist_bins)[0] - - logger.info(f"Collected {n_members} ensemble members for predictions") - + totals['speed'][mode] += len(speed_vals) + totals['u'][mode] += len(u_vals) + totals['v'][mode] += len(v_vals) + + # Predictions (ensemble) + if not skip_preds: + try: + preds = torch.load(ts_dir / f"{ts}-predictions", weights_only=False) # [M, C, H, W] + except FileNotFoundError: + logger.warning(f"[predictions] file not found at {ts}, skipping ensemble entirely") + skip_preds = True + continue + + if n_members is None: + n_members = preds.shape[0] + logger.info(f"Detected {n_members} ensemble members") + for var in VARS: + member_counts[var] = [np.zeros(n_thresholds, dtype=np.int64) for _ in range(n_members)] + member_totals[var] = [0] * n_members + member_hist_counts[var] = [np.zeros(n_hist_bins, dtype=np.int64) for _ in range(n_members)] + + for m in range(n_members): + u = preds[m, u10_out] + v = preds[m, v10_out] + wind_speed = compute_wind_speed(u, v) + valid_mask = ~np.isnan(wind_speed) + speed_vals = wind_speed[valid_mask].ravel() + u_vals = u[valid_mask].ravel() + v_vals = v[valid_mask].ravel() + + _accumulate( + {var: member_counts[var][m] for var in VARS}, + {var: member_hist_counts[var][m] for var in VARS}, + speed_vals, u_vals, v_vals, + thresholds, hist_interior, n_hist_bins, + ) + member_totals['speed'][m] += len(speed_vals) + member_totals['u'][m] += len(u_vals) + member_totals['v'][m] += len(v_vals) + + total_elapsed = time.perf_counter() - t0 + logger.info(f"Single-pass loop completed in {total_elapsed:.1f}s for {len(times)} timesteps") + if n_members is not None: + logger.info(f"Collected {n_members} ensemble members for predictions") + else: + n_members = 0 # no predictions found; guard range() calls below + # Convert counts to probabilities exceedance_data = {'speed': {}, 'u': {}, 'v': {}} From c6084153e9a94a4c587d284a588d260b9fe4fd02 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 3 Jun 2026 10:20:32 +0200 Subject: [PATCH 10/51] cleanups --- src/hirad/eval/bias_by_percentile_precip.py | 59 +++++++++++++++------ 1 file changed, 44 insertions(+), 15 deletions(-) diff --git a/src/hirad/eval/bias_by_percentile_precip.py b/src/hirad/eval/bias_by_percentile_precip.py index 80a5a41f..c578d513 100644 --- a/src/hirad/eval/bias_by_percentile_precip.py +++ b/src/hirad/eval/bias_by_percentile_precip.py @@ -156,6 +156,7 @@ def save_bias_by_percentile_plot( xlabel: str, ylabel: str, out_path, + mean_q: np.ndarray = None, ) -> None: """Save a bias-by-percentile figure. @@ -188,8 +189,8 @@ def save_bias_by_percentile_plot( color=color, label=label, linewidth=2, alpha=0.85, ) - ax.axhline(0.0, color='black', linewidth=0.8, linestyle='--', label='Zero bias') - _apply_logit_xaxis(ax, frac) + ax.axhline(0.0, color='black', linewidth=0.8, linestyle='--') + _apply_logit_xaxis(ax, frac, mean_q) ax.set_yscale('symlog', linthresh=0.1, linscale=0.3) ax.set_ylim(-10, 10) ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:g}')) @@ -202,16 +203,30 @@ def save_bias_by_percentile_plot( plt.close() -def _apply_logit_xaxis(ax, frac: np.ndarray) -> None: +def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray = None) -> None: """Apply logit x-axis with labelled percentile ticks.""" ax.set_xscale('logit') - ax.set_xlim(frac[0], frac[-1]) - tick_fracs = [0.01, 0.10, 0.25, 0.50, 0.75, 0.90, 0.99, 0.999, 0.9999] - tick_labels = ['1', '10', '25', '50', '75', '90', '99', '99.9', '99.99'] + ax.set_xlim(0.5, frac[-1]) + tick_fracs = [0.50, 0.75, 0.90, 0.99, 0.999, 0.9999] + tick_labels = ['50', '75', '90', '99', '99.9', '99.99'] ax.set_xticks(tick_fracs) ax.set_xticklabels(tick_labels) ax.grid(True, alpha=0.3, which='both') + if mean_q is not None: + ax2 = ax.twiny() + ax2.set_xscale('logit') + ax2.set_xlim(0.5, frac[-1]) + # Place ticks at fixed "nice" mm/h values, positioning them by inverting mean_q + nice_mmh = np.array([0.01, 0.1, 1.0, 10.0, 100.0]) + tick_positions = np.interp(nice_mmh, mean_q, frac) + valid = (tick_positions > 0.5) & (tick_positions < frac[-1]) + tick_positions = tick_positions[valid] + tick_mmh = nice_mmh[valid] + ax2.set_xticks(tick_positions) + ax2.set_xticklabels([f'{v:g}' for v in tick_mmh]) + ax2.set_xlabel('Mean target [mm/h]') + def save_mae_by_percentile_plot( mae_data_dict: dict, @@ -222,6 +237,7 @@ def save_mae_by_percentile_plot( xlabel: str, ylabel: str, out_path, + mean_q: np.ndarray = None, ) -> None: """Save a MAE-by-percentile figure. @@ -247,9 +263,9 @@ def save_mae_by_percentile_plot( else: ax.plot(frac, mae_data, color=color, label=label, linewidth=2, alpha=0.85) - _apply_logit_xaxis(ax, frac) + _apply_logit_xaxis(ax, frac, mean_q) ax.set_yscale('log') - ax.set_ylim(1e-3, 10) + ax.set_ylim(1e-5, 100) ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:g}')) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) @@ -267,6 +283,7 @@ def save_spread_by_percentile_plot( xlabel: str, ylabel: str, out_path, + mean_q: np.ndarray = None, ) -> None: """Save an ensemble-spread-by-percentile figure. @@ -277,10 +294,10 @@ def save_spread_by_percentile_plot( fig, ax = plt.subplots(figsize=(10, 6)) frac = percentile_values / 100.0 - ax.plot(frac, spread, color='green', linewidth=2, label='CorrDiff Ensemble') - _apply_logit_xaxis(ax, frac) + ax.plot(frac, spread, color='green', linewidth=2) + _apply_logit_xaxis(ax, frac, mean_q) ax.set_yscale('log') - ax.set_ylim(1e-3, 10) + ax.set_ylim(1e-5, 100) ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:g}')) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) @@ -295,6 +312,15 @@ def main(cfg: dict) -> None: logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + plt.rcParams.update({ + 'font.size': 16, + 'axes.titlesize': 18, + 'axes.labelsize': 16, + 'xtick.labelsize': 14, + 'ytick.labelsize': 14, + 'legend.fontsize': 14, + }) + logger.info("Starting bias-by-percentile computation for precipitation over land") try: generation_dir, gen_cfg, times = load_generation_setup(cfg) @@ -421,20 +447,22 @@ def submit(counts): fn = output_path / 'precipitation_bias_by_percentile.png' save_bias_by_percentile_plot( bias_data, percentile_values, labels, colors, - title='Precipitation Bias by Percentile - Over Land (Per-Gridpoint)', + title='Precipitation Bias Over Land', xlabel='Percentile', ylabel='Bias [mm/h]', out_path=fn, + mean_q=target_mean_q, ) logger.info(f"Bias-by-percentile plot saved: {fn}") fn_mae = output_path / 'precipitation_mae_by_percentile.png' save_mae_by_percentile_plot( mae_data, percentile_values, labels, colors, - title='Precipitation MAE by Percentile - Over Land (Per-Gridpoint)', + title='Precipitation MAE Over Land', xlabel='Percentile', ylabel='MAE [mm/h]', out_path=fn_mae, + mean_q=target_mean_q, ) logger.info(f"MAE-by-percentile plot saved: {fn_mae}") @@ -442,10 +470,11 @@ def submit(counts): fn_spread = output_path / 'precipitation_spread_by_percentile.png' save_spread_by_percentile_plot( spread, percentile_values, - title='Precipitation Ensemble Spread by Percentile - Over Land (Per-Gridpoint)', + title='Precipitation Ensemble Spread Over Land', xlabel='Percentile', - ylabel='Spread (mean over land of std across members) [mm/h]', + ylabel='Ensemble Spread [mm/h]', out_path=fn_spread, + mean_q=target_mean_q, ) logger.info(f"Spread-by-percentile plot saved: {fn_spread}") From bff9615f4c5b21c65e1fcd9fb989f50faa342255 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 3 Jun 2026 10:22:56 +0200 Subject: [PATCH 11/51] fix issues --- src/hirad/eval/bias_by_percentile_precip.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/hirad/eval/bias_by_percentile_precip.py b/src/hirad/eval/bias_by_percentile_precip.py index c578d513..d624faf8 100644 --- a/src/hirad/eval/bias_by_percentile_precip.py +++ b/src/hirad/eval/bias_by_percentile_precip.py @@ -11,6 +11,7 @@ from pathlib import Path import matplotlib.pyplot as plt +import matplotlib.ticker as mticker import numpy as np import torch @@ -91,12 +92,13 @@ def accumulate(counts, arr): continue if n_members is None: - n_members = preds.shape[0] + n_members = int(preds.shape[0]) member_counts = [ np.zeros((n_land, n_bins), dtype=np.int32) for _ in range(n_members) ] logger.info(f" Detected {n_members} ensemble members") + assert n_members is not None and member_counts is not None for m in range(n_members): accumulate(member_counts[m], preds[m, tp_out]) @@ -156,7 +158,7 @@ def save_bias_by_percentile_plot( xlabel: str, ylabel: str, out_path, - mean_q: np.ndarray = None, + mean_q: np.ndarray | None = None, ) -> None: """Save a bias-by-percentile figure. @@ -193,7 +195,7 @@ def save_bias_by_percentile_plot( _apply_logit_xaxis(ax, frac, mean_q) ax.set_yscale('symlog', linthresh=0.1, linscale=0.3) ax.set_ylim(-10, 10) - ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:g}')) + ax.yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f'{x:g}')) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_title(title) @@ -203,7 +205,7 @@ def save_bias_by_percentile_plot( plt.close() -def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray = None) -> None: +def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) -> None: """Apply logit x-axis with labelled percentile ticks.""" ax.set_xscale('logit') ax.set_xlim(0.5, frac[-1]) @@ -237,7 +239,7 @@ def save_mae_by_percentile_plot( xlabel: str, ylabel: str, out_path, - mean_q: np.ndarray = None, + mean_q: np.ndarray | None = None, ) -> None: """Save a MAE-by-percentile figure. @@ -266,7 +268,7 @@ def save_mae_by_percentile_plot( _apply_logit_xaxis(ax, frac, mean_q) ax.set_yscale('log') ax.set_ylim(1e-5, 100) - ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:g}')) + ax.yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f'{x:g}')) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_title(title) @@ -283,7 +285,7 @@ def save_spread_by_percentile_plot( xlabel: str, ylabel: str, out_path, - mean_q: np.ndarray = None, + mean_q: np.ndarray | None = None, ) -> None: """Save an ensemble-spread-by-percentile figure. @@ -298,7 +300,7 @@ def save_spread_by_percentile_plot( _apply_logit_xaxis(ax, frac, mean_q) ax.set_yscale('log') ax.set_ylim(1e-5, 100) - ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:g}')) + ax.yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f'{x:g}')) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_title(title) @@ -337,7 +339,7 @@ def main(cfg: dict) -> None: logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") land_da = load_land_sea_mask( - cfg.get("land_sea_mask_path"), cfg.get("height"), cfg.get("width") + cfg.get("land_sea_mask_path"), cfg.get("height") or 352, cfg.get("width") or 544 ) land_bool_2d = np.isfinite(land_da.values) land_idx = np.flatnonzero(land_bool_2d.ravel()) From a14adf84fb7726ffaa372d0f4957690bf1e094ff Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 3 Jun 2026 14:26:25 +0200 Subject: [PATCH 12/51] fix --- src/hirad/eval/bias_by_percentile_precip.py | 34 ++++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/src/hirad/eval/bias_by_percentile_precip.py b/src/hirad/eval/bias_by_percentile_precip.py index d624faf8..d7ca0e38 100644 --- a/src/hirad/eval/bias_by_percentile_precip.py +++ b/src/hirad/eval/bias_by_percentile_precip.py @@ -251,7 +251,17 @@ def save_mae_by_percentile_plot( frac = percentile_values / 100.0 for (key, mae_data), label, color in zip(mae_data_dict.items(), labels, colors): - if isinstance(mae_data, (list, tuple)): + if isinstance(mae_data, tuple) and len(mae_data) == 2 and isinstance(mae_data[0], np.ndarray): + # (mean_mae, spread_mae) computed from per-land-point inter-member variance + mean_mae, std_mae = mae_data + ax.plot(frac, mean_mae, color=color, label=label, linewidth=2) + ax.fill_between( + frac, + np.maximum(mean_mae - std_mae, 0), + mean_mae + std_mae, + color=color, alpha=0.2, + ) + elif isinstance(mae_data, list): arr = np.array(mae_data) mean_mae = arr.mean(axis=0) std_mae = arr.std(axis=0) @@ -304,7 +314,6 @@ def save_spread_by_percentile_plot( ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_title(title) - ax.legend() plt.tight_layout() plt.savefig(out_path, dpi=300, bbox_inches='tight') plt.close() @@ -423,23 +432,32 @@ def submit(counts): colors.append(color) member_biases: list[np.ndarray] = [] - member_maes: list[np.ndarray] = [] spread = None + ensemble_mae: tuple[np.ndarray, np.ndarray] | None = None if has_ensemble: - sum_q = np.zeros((n_land, P), dtype=np.float64) + sum_q = np.zeros((n_land, P), dtype=np.float64) sumsq_q = np.zeros((n_land, P), dtype=np.float64) + sum_ae = np.zeros((n_land, P), dtype=np.float64) + sumsq_ae = np.zeros((n_land, P), dtype=np.float64) for qm in member_qs: - sum_q += qm + ae = np.abs(qm.astype(np.float64) - target_q.astype(np.float64)) + sum_q += qm sumsq_q += qm.astype(np.float64) ** 2 + sum_ae += ae + sumsq_ae += ae ** 2 member_biases.append((qm.mean(axis=0) - target_mean_q).astype(np.float64)) - member_maes.append(np.abs(qm - target_q).mean(axis=0).astype(np.float64)) mean_q = sum_q / n_members - var_q = np.maximum(sumsq_q / n_members - mean_q ** 2, 0.0) + var_q = np.maximum(sumsq_q / n_members - mean_q ** 2, 0.0) spread = np.sqrt(var_q).mean(axis=0) + mean_ae = sum_ae / n_members + var_ae = np.maximum(sumsq_ae / n_members - mean_ae ** 2, 0.0) + # spatial mean of per-land-point inter-member MAE spread + ensemble_mae = (mean_ae.mean(axis=0), np.sqrt(var_ae).mean(axis=0)) + bias_data['predictions'] = member_biases - mae_data['predictions'] = member_maes + mae_data['predictions'] = ensemble_mae labels.append('CorrDiff Ensemble (mean +/- 1 sigma)') colors.append('green') From e4cc942d5c7c83ea34613fd13da7e2c0e98b926a Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 3 Jun 2026 15:32:20 +0200 Subject: [PATCH 13/51] add Q-bias plot for temperature --- src/hirad/eval/bias_by_percentile_common.py | 464 ++++++++++++++++++ src/hirad/eval/bias_by_percentile_precip.py | 511 +++----------------- src/hirad/eval/bias_by_percentile_temp.py | 169 +++++++ src/hirad/eval_temp.sh | 2 + 4 files changed, 696 insertions(+), 450 deletions(-) create mode 100644 src/hirad/eval/bias_by_percentile_common.py create mode 100644 src/hirad/eval/bias_by_percentile_temp.py diff --git a/src/hirad/eval/bias_by_percentile_common.py b/src/hirad/eval/bias_by_percentile_common.py new file mode 100644 index 00000000..e86e71ea --- /dev/null +++ b/src/hirad/eval/bias_by_percentile_common.py @@ -0,0 +1,464 @@ +""" +Shared machinery for the *bias / MAE / spread by percentile* plots. + +Both the temperature and precipitation scripts build, for every grid point +``g`` (and ensemble member ``m``), a histogram of the field over time, estimate +the per-percentile quantile ``q_{g,m}(p)``, and then average spatially / across +members to produce the plotted curves. Everything that is identical between the +two variables lives here; each variable script only supplies a small +:class:`BiasByPercentileSpec` describing channels, binning, units and the +variable-specific plot styling. +""" +import concurrent.futures +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Callable + +import matplotlib.pyplot as plt +import matplotlib.ticker as mticker +import numpy as np +import torch + +from hirad.eval.eval_utils import ( + get_channel_indices, + load_generation_setup, + load_land_sea_mask, + resolve_ts_dir, +) + +# Deterministic modes plotted alongside the ensemble, and the ensemble styling. +DET_MODE_CFG = [ + ('baseline', 'Input', 'orange'), + ('regression-prediction', 'Regression Prediction', 'red'), +] +ENSEMBLE_LABEL = 'CorrDiff Ensemble (mean +/- 1 sigma)' +ENSEMBLE_COLOR = 'green' + +_RC_PARAMS = { + 'font.size': 16, + 'axes.titlesize': 18, + 'axes.labelsize': 16, + 'xtick.labelsize': 14, + 'ytick.labelsize': 14, + 'legend.fontsize': 14, +} + + +# --------------------------------------------------------------------------- # +# Data machinery +# --------------------------------------------------------------------------- # +def to_flat(arr, conv: float, offset: float = 0.0) -> np.ndarray: + """Convert a (possibly xarray) field to a flat float numpy array, scaled and shifted.""" + return (np.asarray(getattr(arr, 'values', arr)) * conv + offset).ravel() + + +def build_all_histograms( + times: list, + ts_dirs: dict, + out_channels: tuple, + in_channels: tuple, + reduce_fn: Callable, + conv: float, + offset: float, + land_idx: np.ndarray, + hist_bins: np.ndarray, + log_interval: int, + logger: logging.Logger, +) -> tuple: + """Single serial pass over all timesteps, building histograms for every mode. + + *out_channels* / *in_channels* are tuples of channel indices; the per-channel + fields (after scaling) are combined into the plotted scalar by *reduce_fn* + (e.g. identity for single-channel fields, ``hypot`` for wind speed). + + Returns ``(det_counts, member_counts, n_members)`` where *det_counts* maps + mode → array-or-None and *member_counts* is a list of arrays (one per member) + or None. + """ + n_land = land_idx.size + n_bins = len(hist_bins) - 1 + det_modes = ('target', 'baseline', 'regression-prediction') + interior_edges = hist_bins[1:-1] + row_offsets = np.arange(n_land, dtype=np.intp) * n_bins + + det_counts: dict = {m: np.zeros((n_land, n_bins), dtype=np.int32) for m in det_modes} + member_counts: list | None = None + n_members: int | None = None + skip_det: set = set() + skip_preds = False + + def accumulate(counts, channel_arrs): + flats = [to_flat(a, conv, offset) for a in channel_arrs] + vals = reduce_fn(flats)[land_idx] + bin_idx = np.searchsorted(interior_edges, vals, side='right') + np.clip(bin_idx, 0, n_bins - 1, out=bin_idx) + counts.reshape(-1)[row_offsets + bin_idx] += 1 + + logger.info(f"Processing all modes in a single pass ({len(times)} timesteps)") + + for i, ts in enumerate(times): + if i % log_interval == 0: + logger.info(f" timestep {i + 1}/{len(times)}") + + ts_dir = ts_dirs[ts] + + for mode in det_modes: + if mode in skip_det: + continue + chans = in_channels if mode == 'baseline' else out_channels + try: + loaded = torch.load(ts_dir / f"{ts}-{mode}", weights_only=False) + except FileNotFoundError: + logger.warning(f" [{mode}] file not found at {ts}, skipping mode") + skip_det.add(mode) + continue + accumulate(det_counts[mode], [loaded[c] for c in chans]) + + if not skip_preds: + try: + preds = torch.load(ts_dir / f"{ts}-predictions", weights_only=False) + except FileNotFoundError: + logger.warning(f" [predictions] file not found at {ts}, skipping ensemble") + skip_preds = True + continue + + if n_members is None: + n_members = int(preds.shape[0]) + member_counts = [ + np.zeros((n_land, n_bins), dtype=np.int32) + for _ in range(n_members) + ] + logger.info(f" Detected {n_members} ensemble members") + assert n_members is not None and member_counts is not None + for m in range(n_members): + accumulate(member_counts[m], [preds[m, c] for c in out_channels]) + + for mode in skip_det: + det_counts[mode] = None + + return det_counts, member_counts, n_members + + +def per_point_quantiles(pp_counts: np.ndarray, bin_edges: np.ndarray, + frac_percentiles: np.ndarray, + block_size: int = 8192) -> np.ndarray: + """Estimate per-row quantiles from per-grid-point histograms. + + Returns ``(n_land, P)`` float32 array. Uses upper-bin-edge values (no in-bin + interpolation) — adequate given fine bins. + + Land points are processed in blocks of *block_size* rows so the CDF working + set fits comfortably in L3 cache; numpy releases the GIL for the large + C-level operations, enabling true thread parallelism across concurrent calls. + """ + n_land, n_bins = pp_counts.shape + P = len(frac_percentiles) + result = np.empty((n_land, P), dtype=np.float32) + edges_upper = bin_edges[1:].astype(np.float32) + frac_f64 = frac_percentiles.astype(np.float64) + + for start in range(0, n_land, block_size): + end = min(start + block_size, n_land) + blk = pp_counts[start:end] + B = end - start + + cdf = np.cumsum(blk, axis=1, dtype=np.float64) + totals = cdf[:, -1:] + cdf /= np.maximum(totals, 1.0) + + offset = (np.arange(B, dtype=np.float64) * 2.0)[:, None] + cdf += offset + queries = frac_f64[None, :] + offset + + idx = np.searchsorted(cdf.ravel(), queries.ravel(), side='left') + idx = idx.reshape(B, P) - (np.arange(B, dtype=np.intp)[:, None] * n_bins) + np.clip(idx, 0, n_bins - 1, out=idx) + result[start:end] = edges_upper[idx] + + return result + + +def compute_quantiles( + det_counts: dict, + member_counts: list | None, + n_members: int | None, + active_det_modes: list, + has_ensemble: bool, + hist_bins: np.ndarray, + frac_percentiles: np.ndarray, + n_workers: int, +) -> tuple: + """Compute per-point quantiles for every mode in parallel, freeing counts as we go. + + Returns ``(target_q, det_results, member_qs)``. + """ + use_ensemble = has_ensemble and member_counts is not None and n_members is not None + + with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as pool: + def submit(counts): + return pool.submit(per_point_quantiles, counts, hist_bins, frac_percentiles) + + fut_target = submit(det_counts['target']) + fut_det = {mode: submit(det_counts[mode]) for mode in active_det_modes} + if use_ensemble: + assert member_counts is not None and n_members is not None + fut_members = [submit(member_counts[m]) for m in range(n_members)] + else: + fut_members = [] + + target_q = fut_target.result() + del det_counts['target'] + det_results = {mode: fut_det[mode].result() for mode in active_det_modes} + for mode in active_det_modes: + del det_counts[mode] + member_qs = [f.result() for f in fut_members] + if use_ensemble: + assert member_counts is not None and n_members is not None + for m in range(n_members): + member_counts[m] = None + + return target_q, det_results, member_qs + + +def _ensemble_stats(member_qs, target_q, target_mean_q, n_members, P, n_land, mae_kind): + """Compute ensemble spread, the MAE plot entry, and per-member bias curves. + + *mae_kind* selects how the MAE band is built: + - ``'member_list'`` → list of per-member MAE curves (band = mean ±σ across members). + - ``'spatial_spread'`` → ``(mean, std)`` from per-land-point inter-member AE spread. + """ + is_spatial = mae_kind == 'spatial_spread' + sum_q = np.zeros((n_land, P), dtype=np.float64) + sumsq_q = np.zeros((n_land, P), dtype=np.float64) + sum_ae = np.zeros((n_land, P), dtype=np.float64) + sumsq_ae = np.zeros((n_land, P), dtype=np.float64) + member_biases: list = [] + member_maes: list = [] + target_q_f = target_q.astype(np.float64) + + for qm in member_qs: + qm_f = qm.astype(np.float64) + sum_q += qm + sumsq_q += qm_f ** 2 + member_biases.append((qm.mean(axis=0) - target_mean_q).astype(np.float64)) + if is_spatial: + ae = np.abs(qm_f - target_q_f) + sum_ae += ae + sumsq_ae += ae ** 2 + else: + member_maes.append(np.abs(qm - target_q).mean(axis=0).astype(np.float64)) + + mean_q = sum_q / n_members + var_q = np.maximum(sumsq_q / n_members - mean_q ** 2, 0.0) + spread = np.sqrt(var_q).mean(axis=0) + + if is_spatial: + mean_ae = sum_ae / n_members + var_ae = np.maximum(sumsq_ae / n_members - mean_ae ** 2, 0.0) + mae_entry: object = (mean_ae.mean(axis=0), np.sqrt(var_ae).mean(axis=0)) + else: + mae_entry = member_maes + + return spread, mae_entry, member_biases + + +# --------------------------------------------------------------------------- # +# Plotting helpers +# --------------------------------------------------------------------------- # +def new_percentile_axes(percentile_values: np.ndarray): + """Create a figure/axes pair and return it together with the fractional x-values.""" + fig, ax = plt.subplots(figsize=(10, 6)) + return fig, ax, percentile_values / 100.0 + + +def plot_dict_curves(ax, frac, data_dict, labels, colors, lower_clip=None) -> list: + """Plot per-mode curves and return the arrays spanning the plotted range. + + Entry types per dict value: + - ndarray → a single line. + - list of ndarrays → member curves, drawn as mean ±1 σ shading. + - ``(mean, std)`` tuple → an explicit band, drawn as mean ±1 σ shading. + """ + all_vals = [] + for (_key, data), label, color in zip(data_dict.items(), labels, colors): + if isinstance(data, list): + arr = np.array(data) + mean = arr.mean(axis=0) + std = arr.std(axis=0) + elif isinstance(data, tuple): + mean, std = (np.asarray(data[0]), np.asarray(data[1])) + else: + ax.plot(frac, data, color=color, label=label, linewidth=2, alpha=0.85) + all_vals.append(np.asarray(data)) + continue + + lower = mean - std if lower_clip is None else np.maximum(mean - std, lower_clip) + upper = mean + std + ax.plot(frac, mean, color=color, label=label, linewidth=2) + ax.fill_between(frac, lower, upper, color=color, alpha=0.2) + all_vals.extend([lower, upper]) + return all_vals + + +def finalize_percentile_plot(ax, frac, apply_xaxis, mean_q, xlabel, ylabel, + title, out_path, legend: bool = True) -> None: + """Apply shared axis styling (via *apply_xaxis*) and write the figure to disk.""" + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + apply_xaxis(ax, frac, mean_q) + ax.yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f'{x:g}')) + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + ax.set_title(title) + if legend: + ax.legend() + plt.tight_layout() + plt.savefig(out_path, dpi=300, bbox_inches='tight') + plt.close() + + +# --------------------------------------------------------------------------- # +# Orchestration +# --------------------------------------------------------------------------- # +@dataclass +class BiasByPercentileSpec: + """Variable-specific configuration for :func:`run_bias_by_percentile`.""" + var_label: str # e.g. "2m temperature" — used in log messages + output_prefix: str # e.g. "temperature" — used in output file names + bias_title: str + mae_title: str + spread_title: str + bias_ylabel: str + mae_ylabel: str + spread_ylabel: str + percentile_values: np.ndarray + mae_kind: str # 'member_list' | 'spatial_spread' + resolve_channels: Callable[[dict], tuple] # indices -> (ch_out, ch_in); raises ValueError + make_hist_bins: Callable[[dict], np.ndarray] + read_scaling: Callable[[dict], tuple] # cfg -> (conv, offset) + save_bias: Callable + save_mae: Callable + save_spread: Callable + # Combines the (scaled) per-channel flat fields into the plotted scalar. + # Defaults to the single-channel identity; wind speed uses ``hypot``. + reduce_fn: Callable[[list], np.ndarray] = lambda flats: flats[0] + + +def run_bias_by_percentile(cfg: dict, spec: BiasByPercentileSpec) -> None: + """End-to-end driver shared by the temperature and precipitation scripts.""" + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(spec.output_prefix) + plt.rcParams.update(_RC_PARAMS) + + logger.info(f"Starting bias-by-percentile computation for {spec.var_label} over land") + try: + generation_dir, gen_cfg, times = load_generation_setup(cfg) + except ValueError as exc: + logger.error(str(exc)) + return + logger.info(f"Loaded {len(times)} timesteps to process") + + out_root = Path(generation_dir) + + indices = get_channel_indices(gen_cfg) + try: + ch_out, ch_in = spec.resolve_channels(indices) + except ValueError as exc: + logger.error(str(exc)) + return + # Channels may be a single index or a tuple (e.g. wind u/v); normalize to tuples. + out_channels = ch_out if isinstance(ch_out, tuple) else (ch_out,) + in_channels = ch_in if isinstance(ch_in, tuple) else (ch_in,) + logger.info(f"Channel indices - output: {out_channels}, input: {in_channels}") + + land_da = load_land_sea_mask( + cfg.get("land_sea_mask_path"), cfg.get("height") or 352, cfg.get("width") or 544 + ) + land_idx = np.flatnonzero(np.isfinite(land_da.values).ravel()) + n_land = land_idx.size + logger.info(f"{n_land} land grid points") + + hist_bins = spec.make_hist_bins(cfg) + log_interval = cfg.get("log_interval", 24) + conv, offset = spec.read_scaling(cfg) + + percentile_values = spec.percentile_values + frac_percentiles = percentile_values / 100.0 + P = len(frac_percentiles) + + logger.info(f"Resolving {len(times)} timestep directories ...") + ts_dirs = {ts: resolve_ts_dir(out_root, ts) / ts for ts in times} + + det_counts, member_counts, n_members = build_all_histograms( + times, ts_dirs, out_channels, in_channels, spec.reduce_fn, conv, offset, + land_idx, hist_bins, log_interval, logger, + ) + + if det_counts.get('target') is None: + logger.error("No target data found; cannot compute bias.") + return + + active_det_modes = [m for m, _, _ in DET_MODE_CFG if det_counts.get(m) is not None] + has_ensemble = member_counts is not None and n_members is not None and n_members > 0 + + n_tasks = 1 + len(active_det_modes) + (n_members if has_ensemble else 0) + n_quant_workers = cfg.get("n_quant_workers", n_tasks) + + target_q, det_results, member_qs = compute_quantiles( + det_counts, member_counts, n_members, active_det_modes, has_ensemble, + hist_bins, frac_percentiles, n_quant_workers, + ) + + target_mean_q = target_q.mean(axis=0) + + bias_data: dict = {} + mae_data: dict = {} + labels: list = [] + colors: list = [] + + for mode, label, color in DET_MODE_CFG: + if mode not in det_results: + continue + pred_q = det_results.pop(mode) + bias_data[mode] = pred_q.mean(axis=0) - target_mean_q + mae_data[mode] = np.abs(pred_q - target_q).mean(axis=0) + labels.append(label) + colors.append(color) + + spread = None + if has_ensemble: + spread, mae_entry, member_biases = _ensemble_stats( + member_qs, target_q, target_mean_q, n_members, P, n_land, spec.mae_kind, + ) + bias_data['predictions'] = member_biases + mae_data['predictions'] = mae_entry + labels.append(ENSEMBLE_LABEL) + colors.append(ENSEMBLE_COLOR) + + output_path = out_root / cfg.get("results_dir_name", "evaluation_maps") / "by_percentile" + output_path.mkdir(parents=True, exist_ok=True) + + fn = output_path / f'{spec.output_prefix}_bias_by_percentile.png' + spec.save_bias( + bias_data, percentile_values, labels, colors, + title=spec.bias_title, xlabel='Percentile', ylabel=spec.bias_ylabel, + out_path=fn, mean_q=target_mean_q, + ) + logger.info(f"Bias-by-percentile plot saved: {fn}") + + fn_mae = output_path / f'{spec.output_prefix}_mae_by_percentile.png' + spec.save_mae( + mae_data, percentile_values, labels, colors, + title=spec.mae_title, xlabel='Percentile', ylabel=spec.mae_ylabel, + out_path=fn_mae, mean_q=target_mean_q, + ) + logger.info(f"MAE-by-percentile plot saved: {fn_mae}") + + if spread is not None: + fn_spread = output_path / f'{spec.output_prefix}_spread_by_percentile.png' + spec.save_spread( + spread, percentile_values, + title=spec.spread_title, xlabel='Percentile', ylabel=spec.spread_ylabel, + out_path=fn_spread, mean_q=target_mean_q, + ) + logger.info(f"Spread-by-percentile plot saved: {fn_spread}") diff --git a/src/hirad/eval/bias_by_percentile_precip.py b/src/hirad/eval/bias_by_percentile_precip.py index d7ca0e38..7325fa65 100644 --- a/src/hirad/eval/bias_by_percentile_precip.py +++ b/src/hirad/eval/bias_by_percentile_precip.py @@ -1,212 +1,21 @@ """ Plots bias / MAE / spread as a function of percentile for precipitation, using a -local-then-averaged estimator. - -For each grid point g (and ensemble member m), a histogram of precipitation is -built over time and the per-percentile quantile q_{g,m}(p) is estimated. -Spatial / member averaging is then applied to produce the plotted curves. +local-then-averaged estimator (see :mod:`hirad.eval.bias_by_percentile_common`). """ -import concurrent.futures -import logging -from pathlib import Path - -import matplotlib.pyplot as plt -import matplotlib.ticker as mticker import numpy as np -import torch -from hirad.eval.eval_utils import ( - get_channel_indices, - load_generation_setup, - load_land_sea_mask, - parse_eval_cli, - resolve_ts_dir, +from hirad.eval.bias_by_percentile_common import ( + BiasByPercentileSpec, + finalize_percentile_plot, + new_percentile_axes, + plot_dict_curves, + run_bias_by_percentile, ) - - -def _to_flat(arr, conv: float) -> np.ndarray: - """Convert a (possibly xarray) field to a flat float numpy array scaled by conv.""" - return (np.asarray(getattr(arr, 'values', arr)) * conv).ravel() - - -def _build_all_histograms( - times: list, - ts_dirs: dict, - tp_out: int, - tp_in: int, - conv: float, - land_idx: np.ndarray, - hist_bins: np.ndarray, - log_interval: int, - logger: logging.Logger, -) -> tuple: - """Single serial pass over all timesteps, building histograms for every mode. - - Returns ``(det_counts, member_counts, n_members)`` where *det_counts* maps - mode → array-or-None and *member_counts* is a list of arrays (one per member) - or None. - """ - n_land = land_idx.size - n_bins = len(hist_bins) - 1 - det_modes = ('target', 'baseline', 'regression-prediction') - interior_edges = hist_bins[1:-1] - row_offsets = np.arange(n_land, dtype=np.intp) * n_bins - - det_counts: dict = {m: np.zeros((n_land, n_bins), dtype=np.int32) for m in det_modes} - member_counts: list | None = None - n_members: int | None = None - skip_det: set = set() - skip_preds = False - - def accumulate(counts, arr): - vals = _to_flat(arr, conv)[land_idx] - bin_idx = np.searchsorted(interior_edges, vals, side='right') - counts.reshape(-1)[row_offsets + bin_idx] += 1 - - logger.info(f"Processing all modes in a single pass ({len(times)} timesteps)") - - for i, ts in enumerate(times): - if i % log_interval == 0: - logger.info(f" timestep {i + 1}/{len(times)}") - - ts_dir = ts_dirs[ts] - - for mode in det_modes: - if mode in skip_det: - continue - ch = tp_in if mode == 'baseline' else tp_out - try: - arr = torch.load(ts_dir / f"{ts}-{mode}", weights_only=False)[ch] - except FileNotFoundError: - logger.warning(f" [{mode}] file not found at {ts}, skipping mode") - skip_det.add(mode) - continue - accumulate(det_counts[mode], arr) - - if not skip_preds: - try: - preds = torch.load(ts_dir / f"{ts}-predictions", weights_only=False) - except FileNotFoundError: - logger.warning(f" [predictions] file not found at {ts}, skipping ensemble") - skip_preds = True - continue - - if n_members is None: - n_members = int(preds.shape[0]) - member_counts = [ - np.zeros((n_land, n_bins), dtype=np.int32) - for _ in range(n_members) - ] - logger.info(f" Detected {n_members} ensemble members") - assert n_members is not None and member_counts is not None - for m in range(n_members): - accumulate(member_counts[m], preds[m, tp_out]) - - for mode in skip_det: - det_counts[mode] = None - - return det_counts, member_counts, n_members - - -def _per_point_quantiles(pp_counts: np.ndarray, bin_edges: np.ndarray, - frac_percentiles: np.ndarray, - block_size: int = 8192) -> np.ndarray: - """Estimate per-row quantiles from per-grid-point histograms. - - Returns (n_land, P) float32 array. Uses upper-bin-edge values (no in-bin - interpolation) — adequate given fine log-spaced bins. - - Processes land points in blocks of *block_size* rows so the CDF working set - (~block_size × n_bins × 8 bytes) fits comfortably in L3 cache, making the - row-wise ``searchsorted`` cache-friendly and GIL-free (numpy releases the - GIL for large C-level operations, enabling true thread parallelism when - multiple calls run concurrently). - """ - n_land, n_bins = pp_counts.shape - P = len(frac_percentiles) - result = np.empty((n_land, P), dtype=np.float32) - edges_upper = bin_edges[1:].astype(np.float32) - frac_f64 = frac_percentiles.astype(np.float64) - - for start in range(0, n_land, block_size): - end = min(start + block_size, n_land) - blk = pp_counts[start:end] - B = end - start - - cdf = np.cumsum(blk, axis=1, dtype=np.float64) - totals = cdf[:, -1:] - cdf /= np.maximum(totals, 1.0) - - offset = (np.arange(B, dtype=np.float64) * 2.0)[:, None] - cdf += offset - queries = frac_f64[None, :] + offset - - idx = np.searchsorted(cdf.ravel(), queries.ravel(), side='left') - idx = idx.reshape(B, P) - (np.arange(B, dtype=np.intp)[:, None] * n_bins) - np.clip(idx, 0, n_bins - 1, out=idx) - result[start:end] = edges_upper[idx] - - return result - - -def save_bias_by_percentile_plot( - bias_data_dict: dict, - percentile_values: np.ndarray, - labels: list, - colors: list, - title: str, - xlabel: str, - ylabel: str, - out_path, - mean_q: np.ndarray | None = None, -) -> None: - """Save a bias-by-percentile figure. - - Parameters - ---------- - bias_data_dict : dict mapping key → bias array (n_percentiles,) for single - datasets, or list/tuple of (n_percentiles,) arrays for ensembles. - """ - Path(out_path).parent.mkdir(parents=True, exist_ok=True) - - fig, ax = plt.subplots(figsize=(10, 6)) - frac = percentile_values / 100.0 - - for (key, bias_data), label, color in zip(bias_data_dict.items(), labels, colors): - if isinstance(bias_data, (list, tuple)): - arr = np.array(bias_data) - mean_bias = arr.mean(axis=0) - std_bias = arr.std(axis=0) - ax.plot(frac, mean_bias, color=color, label=label, linewidth=2) - ax.fill_between( - frac, - mean_bias - std_bias, - mean_bias + std_bias, - color=color, - alpha=0.2, - ) - else: - ax.plot( - frac, bias_data, - color=color, label=label, linewidth=2, alpha=0.85, - ) - - ax.axhline(0.0, color='black', linewidth=0.8, linestyle='--') - _apply_logit_xaxis(ax, frac, mean_q) - ax.set_yscale('symlog', linthresh=0.1, linscale=0.3) - ax.set_ylim(-10, 10) - ax.yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f'{x:g}')) - ax.set_xlabel(xlabel) - ax.set_ylabel(ylabel) - ax.set_title(title) - ax.legend() - plt.tight_layout() - plt.savefig(out_path, dpi=300, bbox_inches='tight') - plt.close() +from hirad.eval.eval_utils import parse_eval_cli def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) -> None: - """Apply logit x-axis with labelled percentile ticks.""" + """Apply logit x-axis with labelled percentile ticks (and a mm/h secondary axis).""" ax.set_xscale('logit') ax.set_xlim(0.5, frac[-1]) tick_fracs = [0.50, 0.75, 0.90, 0.99, 0.999, 0.9999] @@ -223,281 +32,83 @@ def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) - nice_mmh = np.array([0.01, 0.1, 1.0, 10.0, 100.0]) tick_positions = np.interp(nice_mmh, mean_q, frac) valid = (tick_positions > 0.5) & (tick_positions < frac[-1]) - tick_positions = tick_positions[valid] - tick_mmh = nice_mmh[valid] - ax2.set_xticks(tick_positions) - ax2.set_xticklabels([f'{v:g}' for v in tick_mmh]) + ax2.set_xticks(tick_positions[valid]) + ax2.set_xticklabels([f'{v:g}' for v in nice_mmh[valid]]) ax2.set_xlabel('Mean target [mm/h]') -def save_mae_by_percentile_plot( - mae_data_dict: dict, - percentile_values: np.ndarray, - labels: list, - colors: list, - title: str, - xlabel: str, - ylabel: str, - out_path, - mean_q: np.ndarray | None = None, -) -> None: - """Save a MAE-by-percentile figure. - - For single datasets the MAE curve is plotted directly. For the ensemble - the mean absolute error across members is shown with ±1 σ shading. - """ - Path(out_path).parent.mkdir(parents=True, exist_ok=True) - fig, ax = plt.subplots(figsize=(10, 6)) - frac = percentile_values / 100.0 +def save_bias_by_percentile_plot(bias_data_dict, percentile_values, labels, colors, + title, xlabel, ylabel, out_path, mean_q=None) -> None: + """Save a bias-by-percentile figure (symlog y-axis).""" + _, ax, frac = new_percentile_axes(percentile_values) + plot_dict_curves(ax, frac, bias_data_dict, labels, colors) + ax.axhline(0.0, color='black', linewidth=0.8, linestyle='--') + ax.set_yscale('symlog', linthresh=0.1, linscale=0.3) + ax.set_ylim(-10, 10) + finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, + xlabel, ylabel, title, out_path) - for (key, mae_data), label, color in zip(mae_data_dict.items(), labels, colors): - if isinstance(mae_data, tuple) and len(mae_data) == 2 and isinstance(mae_data[0], np.ndarray): - # (mean_mae, spread_mae) computed from per-land-point inter-member variance - mean_mae, std_mae = mae_data - ax.plot(frac, mean_mae, color=color, label=label, linewidth=2) - ax.fill_between( - frac, - np.maximum(mean_mae - std_mae, 0), - mean_mae + std_mae, - color=color, alpha=0.2, - ) - elif isinstance(mae_data, list): - arr = np.array(mae_data) - mean_mae = arr.mean(axis=0) - std_mae = arr.std(axis=0) - ax.plot(frac, mean_mae, color=color, label=label, linewidth=2) - ax.fill_between( - frac, - np.maximum(mean_mae - std_mae, 0), - mean_mae + std_mae, - color=color, alpha=0.2, - ) - else: - ax.plot(frac, mae_data, color=color, label=label, linewidth=2, alpha=0.85) - _apply_logit_xaxis(ax, frac, mean_q) +def save_mae_by_percentile_plot(mae_data_dict, percentile_values, labels, colors, + title, xlabel, ylabel, out_path, mean_q=None) -> None: + """Save a MAE-by-percentile figure (log y-axis).""" + _, ax, frac = new_percentile_axes(percentile_values) + plot_dict_curves(ax, frac, mae_data_dict, labels, colors, lower_clip=0) ax.set_yscale('log') ax.set_ylim(1e-5, 100) - ax.yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f'{x:g}')) - ax.set_xlabel(xlabel) - ax.set_ylabel(ylabel) - ax.set_title(title) - ax.legend() - plt.tight_layout() - plt.savefig(out_path, dpi=300, bbox_inches='tight') - plt.close() + finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, + xlabel, ylabel, title, out_path) -def save_spread_by_percentile_plot( - spread: np.ndarray, - percentile_values: np.ndarray, - title: str, - xlabel: str, - ylabel: str, - out_path, - mean_q: np.ndarray | None = None, -) -> None: - """Save an ensemble-spread-by-percentile figure. - - Spread is the inter-member standard deviation of the p-th quantile, - i.e. how much the ensemble members disagree at each percentile level. - """ - Path(out_path).parent.mkdir(parents=True, exist_ok=True) - fig, ax = plt.subplots(figsize=(10, 6)) - frac = percentile_values / 100.0 - +def save_spread_by_percentile_plot(spread, percentile_values, + title, xlabel, ylabel, out_path, mean_q=None) -> None: + """Save an ensemble-spread-by-percentile figure (log y-axis, no legend).""" + _, ax, frac = new_percentile_axes(percentile_values) ax.plot(frac, spread, color='green', linewidth=2) - _apply_logit_xaxis(ax, frac, mean_q) ax.set_yscale('log') ax.set_ylim(1e-5, 100) - ax.yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f'{x:g}')) - ax.set_xlabel(xlabel) - ax.set_ylabel(ylabel) - ax.set_title(title) - plt.tight_layout() - plt.savefig(out_path, dpi=300, bbox_inches='tight') - plt.close() + finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, + xlabel, ylabel, title, out_path, legend=False) -def main(cfg: dict) -> None: - logging.basicConfig(level=logging.INFO) - logger = logging.getLogger(__name__) - - plt.rcParams.update({ - 'font.size': 16, - 'axes.titlesize': 18, - 'axes.labelsize': 16, - 'xtick.labelsize': 14, - 'ytick.labelsize': 14, - 'legend.fontsize': 14, - }) - - logger.info("Starting bias-by-percentile computation for precipitation over land") - try: - generation_dir, gen_cfg, times = load_generation_setup(cfg) - except ValueError as exc: - logger.error(str(exc)) - return - logger.info(f"Loaded {len(times)} timesteps to process") - - out_root = Path(generation_dir) - - indices = get_channel_indices(gen_cfg) +def _resolve_channels(indices: dict) -> tuple: tp_out = indices['output']['tp'] tp_in = indices['input'].get('tp', tp_out) - logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") + return tp_out, tp_in - land_da = load_land_sea_mask( - cfg.get("land_sea_mask_path"), cfg.get("height") or 352, cfg.get("width") or 544 - ) - land_bool_2d = np.isfinite(land_da.values) - land_idx = np.flatnonzero(land_bool_2d.ravel()) - n_land = land_idx.size - logger.info(f"{n_land} land grid points") +def _make_hist_bins(cfg: dict) -> np.ndarray: n_bins = 500 - hist_bins = np.concatenate([ - np.array([0.0]), - np.logspace(-2, 3.2, n_bins), - ]) - log_interval = cfg.get("log_interval", 24) - conv = cfg.get("conv_factor_hourly", 1.0) - - percentile_values = np.unique(np.concatenate([ + return np.concatenate([np.array([0.0]), np.logspace(-2, 3.2, n_bins)]) + + +SPEC = BiasByPercentileSpec( + var_label='precipitation', + output_prefix='precipitation', + bias_title='Precipitation Bias Over Land', + mae_title='Precipitation MAE Over Land', + spread_title='Precipitation Ensemble Spread Over Land', + bias_ylabel='Bias [mm/h]', + mae_ylabel='MAE [mm/h]', + spread_ylabel='Ensemble Spread [mm/h]', + percentile_values=np.unique(np.concatenate([ np.linspace(1.0, 90.0, 90), np.linspace(90.0, 99.0, 90), np.linspace(99.0, 99.9, 45), np.linspace(99.9, 99.99, 20), - ])) - frac_percentiles = percentile_values / 100.0 - P = len(frac_percentiles) - - logger.info(f"Resolving {len(times)} timestep directories ...") - ts_dirs = {ts: resolve_ts_dir(out_root, ts) / ts for ts in times} - - det_counts, member_counts, n_members = _build_all_histograms( - times, ts_dirs, tp_out, tp_in, conv, land_idx, hist_bins, - log_interval, logger, - ) - - if det_counts.get('target') is None: - logger.error("No target data found; cannot compute bias.") - return - - det_mode_cfg = [ - ('baseline', 'Input', 'orange'), - ('regression-prediction', 'Regression Prediction', 'red'), - ] - active_det_modes = [m for m, _, _ in det_mode_cfg if det_counts.get(m) is not None] - has_ensemble = member_counts is not None and n_members is not None and n_members > 0 - - n_tasks = 1 + len(active_det_modes) + (n_members if has_ensemble else 0) - n_quant_workers = cfg.get("n_quant_workers", n_tasks) - - with concurrent.futures.ThreadPoolExecutor(max_workers=n_quant_workers) as pool: - def submit(counts): - return pool.submit(_per_point_quantiles, counts, hist_bins, frac_percentiles) - - fut_target = submit(det_counts['target']) - fut_det = {mode: submit(det_counts[mode]) for mode in active_det_modes} - fut_members = ( - [submit(member_counts[m]) for m in range(n_members)] - if has_ensemble else [] - ) - - target_q = fut_target.result() - del det_counts['target'] - det_results = {mode: fut_det[mode].result() for mode in active_det_modes} - for mode in active_det_modes: - del det_counts[mode] - member_qs = [f.result() for f in fut_members] - if has_ensemble: - for m in range(n_members): - member_counts[m] = None - - target_mean_q = target_q.mean(axis=0) - - bias_data: dict = {} - mae_data: dict = {} - labels: list = [] - colors: list = [] - - for mode, label, color in det_mode_cfg: - if mode not in det_results: - continue - pred_q = det_results.pop(mode) - bias_data[mode] = pred_q.mean(axis=0) - target_mean_q - mae_data[mode] = np.abs(pred_q - target_q).mean(axis=0) - labels.append(label) - colors.append(color) - - member_biases: list[np.ndarray] = [] - spread = None - ensemble_mae: tuple[np.ndarray, np.ndarray] | None = None - - if has_ensemble: - sum_q = np.zeros((n_land, P), dtype=np.float64) - sumsq_q = np.zeros((n_land, P), dtype=np.float64) - sum_ae = np.zeros((n_land, P), dtype=np.float64) - sumsq_ae = np.zeros((n_land, P), dtype=np.float64) - for qm in member_qs: - ae = np.abs(qm.astype(np.float64) - target_q.astype(np.float64)) - sum_q += qm - sumsq_q += qm.astype(np.float64) ** 2 - sum_ae += ae - sumsq_ae += ae ** 2 - member_biases.append((qm.mean(axis=0) - target_mean_q).astype(np.float64)) - mean_q = sum_q / n_members - var_q = np.maximum(sumsq_q / n_members - mean_q ** 2, 0.0) - spread = np.sqrt(var_q).mean(axis=0) - - mean_ae = sum_ae / n_members - var_ae = np.maximum(sumsq_ae / n_members - mean_ae ** 2, 0.0) - # spatial mean of per-land-point inter-member MAE spread - ensemble_mae = (mean_ae.mean(axis=0), np.sqrt(var_ae).mean(axis=0)) - - bias_data['predictions'] = member_biases - mae_data['predictions'] = ensemble_mae - labels.append('CorrDiff Ensemble (mean +/- 1 sigma)') - colors.append('green') - - output_path = out_root / cfg.get("results_dir_name", "evaluation_maps") / "by_percentile" - output_path.mkdir(parents=True, exist_ok=True) - - fn = output_path / 'precipitation_bias_by_percentile.png' - save_bias_by_percentile_plot( - bias_data, percentile_values, labels, colors, - title='Precipitation Bias Over Land', - xlabel='Percentile', - ylabel='Bias [mm/h]', - out_path=fn, - mean_q=target_mean_q, - ) - logger.info(f"Bias-by-percentile plot saved: {fn}") - - fn_mae = output_path / 'precipitation_mae_by_percentile.png' - save_mae_by_percentile_plot( - mae_data, percentile_values, labels, colors, - title='Precipitation MAE Over Land', - xlabel='Percentile', - ylabel='MAE [mm/h]', - out_path=fn_mae, - mean_q=target_mean_q, - ) - logger.info(f"MAE-by-percentile plot saved: {fn_mae}") + ])), + mae_kind='spatial_spread', + resolve_channels=_resolve_channels, + make_hist_bins=_make_hist_bins, + read_scaling=lambda cfg: (cfg.get("conv_factor_hourly", 1.0), 0.0), + save_bias=save_bias_by_percentile_plot, + save_mae=save_mae_by_percentile_plot, + save_spread=save_spread_by_percentile_plot, +) - if spread is not None: - fn_spread = output_path / 'precipitation_spread_by_percentile.png' - save_spread_by_percentile_plot( - spread, percentile_values, - title='Precipitation Ensemble Spread Over Land', - xlabel='Percentile', - ylabel='Ensemble Spread [mm/h]', - out_path=fn_spread, - mean_q=target_mean_q, - ) - logger.info(f"Spread-by-percentile plot saved: {fn_spread}") +def main(cfg: dict) -> None: + run_bias_by_percentile(cfg, SPEC) if __name__ == '__main__': diff --git a/src/hirad/eval/bias_by_percentile_temp.py b/src/hirad/eval/bias_by_percentile_temp.py new file mode 100644 index 00000000..8fd75f77 --- /dev/null +++ b/src/hirad/eval/bias_by_percentile_temp.py @@ -0,0 +1,169 @@ +""" +Plots bias / MAE / spread as a function of percentile for 2m temperature, using a +local-then-averaged estimator (see :mod:`hirad.eval.bias_by_percentile_common`). +""" +import numpy as np + +from hirad.eval.bias_by_percentile_common import ( + BiasByPercentileSpec, + finalize_percentile_plot, + new_percentile_axes, + plot_dict_curves, + run_bias_by_percentile, +) +from hirad.eval.eval_utils import parse_eval_cli + + +def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) -> None: + """Apply logit x-axis with labelled percentile ticks (and a °C secondary axis).""" + ax.set_xscale('logit') + ax.set_xlim(frac[0], frac[-1]) + tick_fracs = [0.0001, 0.001, 0.01, 0.1, 0.50, 0.90, 0.99, 0.999, 0.9999] + tick_labels = ['0.01', '0.1', '1', '10', '50', '90', '99', '99.9', '99.99'] + # only show ticks within our data range + valid_ticks = [(f, l) for f, l in zip(tick_fracs, tick_labels) + if frac[0] <= f <= frac[-1]] + ax.set_xticks([f for f, _ in valid_ticks]) + ax.set_xticklabels([l for _, l in valid_ticks]) + ax.grid(True, alpha=0.3, which='both') + + if mean_q is not None: + ax2 = ax.twiny() + ax2.set_xscale('logit') + ax2.set_xlim(frac[0], frac[-1]) + # Temperature is ~linear in percentile, but the axis is logit, so a fixed + # coarse list of temps bunches near the median while the tails get no + # labels. Use dense integer-degree candidates and greedily keep only those + # spaced far enough apart *on the axis* (in logit units): the compressed + # tails get fine 1° steps, the centre gets coarse steps, and the labels + # end up evenly spaced. + tick_positions, tick_temps = _even_temp_ticks(frac, mean_q) + ax2.set_xticks(tick_positions) + ax2.set_xticklabels([f'{v:g}' for v in tick_temps]) + ax2.set_xlabel('Mean target [°C]') + + +def _even_temp_ticks(frac: np.ndarray, mean_q: np.ndarray, + min_gap_frac: float = 0.06) -> tuple: + """Pick integer-degree temperature ticks evenly spaced along the logit axis. + + Candidates are every whole degree within the data range; we greedily keep a + tick only if it is at least *min_gap_frac* of the axis span (measured in + logit coordinates) from the previously kept one. Returns ``(positions, + temps)``. + """ + def _logit(p): + p = np.clip(p, 1e-9, 1 - 1e-9) + return np.log(p / (1.0 - p)) + + t_lo = int(np.ceil(mean_q[0])) + t_hi = int(np.floor(mean_q[-1])) + if t_hi <= t_lo: + return np.array([]), np.array([]) + + temps = np.arange(t_lo, t_hi + 1, dtype=float) + positions = np.interp(temps, mean_q, frac) + lp = _logit(positions) + min_gap = abs(_logit(frac[-1]) - _logit(frac[0])) * min_gap_frac + + keep = [0] + for i in range(1, len(lp)): + if lp[i] - lp[keep[-1]] >= min_gap: + keep.append(i) + keep = np.array(keep, dtype=np.intp) + return positions[keep], temps[keep] + + +def save_bias_by_percentile_plot(bias_data_dict, percentile_values, labels, colors, + title, xlabel, ylabel, out_path, mean_q=None) -> None: + """Save a bias-by-percentile figure (linear y-axis, data-driven limits).""" + _, ax, frac = new_percentile_axes(percentile_values) + all_vals = plot_dict_curves(ax, frac, bias_data_dict, labels, colors) + ax.axhline(0.0, color='black', linewidth=0.8, linestyle='--') + if all_vals: + vcat = np.concatenate([np.asarray(v).ravel() for v in all_vals]) + vmin, vmax = float(np.nanmin(vcat)), float(np.nanmax(vcat)) + margin = max(abs(vmax - vmin) * 0.1, 0.05) + ax.set_ylim(vmin - margin, vmax + margin) + finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, + xlabel, ylabel, title, out_path) + + +def save_mae_by_percentile_plot(mae_data_dict, percentile_values, labels, colors, + title, xlabel, ylabel, out_path, mean_q=None) -> None: + """Save a MAE-by-percentile figure (linear y-axis, data-driven limits).""" + _, ax, frac = new_percentile_axes(percentile_values) + all_vals = plot_dict_curves(ax, frac, mae_data_dict, labels, colors, lower_clip=0) + if all_vals: + ymax = float(max(np.nanmax(v) for v in all_vals)) * 1.1 + if ymax > 0: + ax.set_ylim(0, ymax) + finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, + xlabel, ylabel, title, out_path) + + +def save_spread_by_percentile_plot(spread, percentile_values, + title, xlabel, ylabel, out_path, mean_q=None) -> None: + """Save an ensemble-spread-by-percentile figure (linear y-axis).""" + _, ax, frac = new_percentile_axes(percentile_values) + ax.plot(frac, spread, color='green', linewidth=2, label='Ensemble spread') + ymax_spread = float(np.nanmax(spread)) * 1.1 + if ymax_spread > 0: + ax.set_ylim(0, ymax_spread) + finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, + xlabel, ylabel, title, out_path) + + +def _resolve_channels(indices: dict) -> tuple: + # Temperature channel: try '2t' first, then 't2m' + t2m_out = indices['output'].get('2t', indices['output'].get('t2m')) + t2m_in = indices['input'].get('2t', indices['input'].get('t2m', t2m_out)) + if t2m_out is None: + raise ValueError("Temperature channel (2t / t2m) not found in output channels.") + return t2m_out, t2m_in + + +def _make_hist_bins(cfg: dict) -> np.ndarray: + # Linear histogram bins in °C — fine enough to resolve sub-degree differences + n_bins = cfg.get("n_bins", 2000) + temp_min = cfg.get("temp_bin_min_celsius", -90.0) + temp_max = cfg.get("temp_bin_max_celsius", 65.0) + return np.linspace(temp_min, temp_max, n_bins + 1) + + +SPEC = BiasByPercentileSpec( + var_label='2m temperature', + output_prefix='temperature', + bias_title='2m Temperature Bias Over Land', + mae_title='2m Temperature MAE Over Land', + spread_title='2m Temperature Ensemble Spread Over Land', + bias_ylabel='Bias [°C]', + mae_ylabel='MAE [°C]', + spread_ylabel='Ensemble Spread [°C]', + percentile_values=np.unique(np.concatenate([ + np.linspace(0.01, 0.1, 10), + np.linspace(0.1, 1.0, 10), + np.linspace(1.0, 10.0, 10), + np.linspace(10.0, 90.0, 80), + np.linspace(90.0, 99.0, 90), + np.linspace(99.0, 99.9, 45), + np.linspace(99.9, 99.99, 20), + ])), + mae_kind='member_list', + resolve_channels=_resolve_channels, + make_hist_bins=_make_hist_bins, + # Default: convert Kelvin → °C (conv=1.0, offset=-273.15) + read_scaling=lambda cfg: (cfg.get("temp_conv_factor", 1.0), + cfg.get("temp_offset_celsius", -273.15)), + save_bias=save_bias_by_percentile_plot, + save_mae=save_mae_by_percentile_plot, + save_spread=save_spread_by_percentile_plot, +) + + +def main(cfg: dict) -> None: + run_bias_by_percentile(cfg, SPEC) + + +if __name__ == '__main__': + main(parse_eval_cli()) diff --git a/src/hirad/eval_temp.sh b/src/hirad/eval_temp.sh index 1b56acf9..41af460d 100644 --- a/src/hirad/eval_temp.sh +++ b/src/hirad/eval_temp.sh @@ -8,6 +8,8 @@ CONFIG_NAME="src/hirad/conf/eval_real.yaml" CMDS=( # Diurnal cycle of 2m temperature "python src/hirad/eval/diurnal_cycle_temp.py" + # QQ + "python -m hirad.eval.bias_by_percentile_temp" ) for cmd in "${CMDS[@]}"; do From be60d22eef201b93c07099362da6592c29d10169 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 3 Jun 2026 16:40:52 +0200 Subject: [PATCH 14/51] add the same plot for windspeed --- src/hirad/eval/bias_by_percentile_common.py | 61 +++++++++ src/hirad/eval/bias_by_percentile_temp.py | 43 +----- src/hirad/eval/bias_by_percentile_wind.py | 140 ++++++++++++++++++++ src/hirad/eval_wind.sh | 4 +- 4 files changed, 209 insertions(+), 39 deletions(-) create mode 100644 src/hirad/eval/bias_by_percentile_wind.py diff --git a/src/hirad/eval/bias_by_percentile_common.py b/src/hirad/eval/bias_by_percentile_common.py index e86e71ea..e090f872 100644 --- a/src/hirad/eval/bias_by_percentile_common.py +++ b/src/hirad/eval/bias_by_percentile_common.py @@ -272,6 +272,67 @@ def new_percentile_axes(percentile_values: np.ndarray): return fig, ax, percentile_values / 100.0 +def _nice_step(rough: float) -> float: + """Round *rough* up to the nearest 1/2/2.5/5 ×10ⁿ 'nice' step.""" + if rough <= 0: + return 1.0 + exp = np.floor(np.log10(rough)) + base = 10.0 ** exp + for mult in (1.0, 2.0, 2.5, 5.0, 10.0): + if rough <= mult * base: + return mult * base + return 10.0 * base + + +def _round_sig(x: float, sig: int = 2) -> float: + """Round *x* to *sig* significant figures (clean axis labels).""" + if x == 0 or not np.isfinite(x): + return 0.0 + digits = sig - int(np.floor(np.log10(abs(x)))) - 1 + return round(x, digits) + + +def even_value_ticks(frac: np.ndarray, mean_q: np.ndarray, + target_ticks: int = 9) -> tuple: + """Pick secondary-axis ticks evenly spaced along the logit axis. + + The plotted x-axis is logit in *percentile*, while the secondary axis labels + the *mean target value*. We place *target_ticks* positions evenly spaced in + logit coordinates (so they look uniform on the axis), including both corners + — the first and last points that actually have data — then label each with + the interpolated value rounded to two significant figures for clean, + readable tick labels. Returns ``(positions, values)``. + """ + def _logit(p): + p = np.clip(p, 1e-9, 1 - 1e-9) + return np.log(p / (1.0 - p)) + + def _expit(z): + return 1.0 / (1.0 + np.exp(-z)) + + v_lo, v_hi = float(mean_q[0]), float(mean_q[-1]) + if v_hi <= v_lo: + return np.array([]), np.array([]) + + lp_lo, lp_hi = _logit(frac[0]), _logit(frac[-1]) + # Evenly spaced sample positions along the logit axis (corners included). + sample_lp = np.linspace(lp_lo, lp_hi, max(target_ticks, 2)) + sample_pos = _expit(sample_lp) + sample_val = np.interp(sample_pos, frac, mean_q) + + # Round labels to two significant figures, keeping the true axis position. + rounded = np.array([_round_sig(v) for v in sample_val]) + + # Drop consecutive duplicate labels (can happen where the value saturates), + # always keeping the first occurrence so both corners survive. + keep = [0] + for i in range(1, len(rounded)): + if rounded[i] != rounded[keep[-1]]: + keep.append(i) + keep = np.array(keep, dtype=np.intp) + return sample_pos[keep], rounded[keep] + + def plot_dict_curves(ax, frac, data_dict, labels, colors, lower_clip=None) -> list: """Plot per-mode curves and return the arrays spanning the plotted range. diff --git a/src/hirad/eval/bias_by_percentile_temp.py b/src/hirad/eval/bias_by_percentile_temp.py index 8fd75f77..00dcf220 100644 --- a/src/hirad/eval/bias_by_percentile_temp.py +++ b/src/hirad/eval/bias_by_percentile_temp.py @@ -6,6 +6,7 @@ from hirad.eval.bias_by_percentile_common import ( BiasByPercentileSpec, + even_value_ticks, finalize_percentile_plot, new_percentile_axes, plot_dict_curves, @@ -31,49 +32,15 @@ def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) - ax2 = ax.twiny() ax2.set_xscale('logit') ax2.set_xlim(frac[0], frac[-1]) - # Temperature is ~linear in percentile, but the axis is logit, so a fixed - # coarse list of temps bunches near the median while the tails get no - # labels. Use dense integer-degree candidates and greedily keep only those - # spaced far enough apart *on the axis* (in logit units): the compressed - # tails get fine 1° steps, the centre gets coarse steps, and the labels - # end up evenly spaced. - tick_positions, tick_temps = _even_temp_ticks(frac, mean_q) + # The axis is logit in percentile, so a fixed list of round temps bunches + # near the median while the tails get no labels; even_value_ticks picks a + # nice step and spreads labels evenly across the whole logit axis. + tick_positions, tick_temps = even_value_ticks(frac, mean_q) ax2.set_xticks(tick_positions) ax2.set_xticklabels([f'{v:g}' for v in tick_temps]) ax2.set_xlabel('Mean target [°C]') -def _even_temp_ticks(frac: np.ndarray, mean_q: np.ndarray, - min_gap_frac: float = 0.06) -> tuple: - """Pick integer-degree temperature ticks evenly spaced along the logit axis. - - Candidates are every whole degree within the data range; we greedily keep a - tick only if it is at least *min_gap_frac* of the axis span (measured in - logit coordinates) from the previously kept one. Returns ``(positions, - temps)``. - """ - def _logit(p): - p = np.clip(p, 1e-9, 1 - 1e-9) - return np.log(p / (1.0 - p)) - - t_lo = int(np.ceil(mean_q[0])) - t_hi = int(np.floor(mean_q[-1])) - if t_hi <= t_lo: - return np.array([]), np.array([]) - - temps = np.arange(t_lo, t_hi + 1, dtype=float) - positions = np.interp(temps, mean_q, frac) - lp = _logit(positions) - min_gap = abs(_logit(frac[-1]) - _logit(frac[0])) * min_gap_frac - - keep = [0] - for i in range(1, len(lp)): - if lp[i] - lp[keep[-1]] >= min_gap: - keep.append(i) - keep = np.array(keep, dtype=np.intp) - return positions[keep], temps[keep] - - def save_bias_by_percentile_plot(bias_data_dict, percentile_values, labels, colors, title, xlabel, ylabel, out_path, mean_q=None) -> None: """Save a bias-by-percentile figure (linear y-axis, data-driven limits).""" diff --git a/src/hirad/eval/bias_by_percentile_wind.py b/src/hirad/eval/bias_by_percentile_wind.py new file mode 100644 index 00000000..46e537ea --- /dev/null +++ b/src/hirad/eval/bias_by_percentile_wind.py @@ -0,0 +1,140 @@ +""" +Plots bias / MAE / spread as a function of percentile for 10 m wind speed, using a +local-then-averaged estimator (see :mod:`hirad.eval.bias_by_percentile_common`). + +Wind speed is derived per grid point from the two surface wind components +(``10u``, ``10v``) as ``hypot(u, v)`` before histogramming. +""" +import numpy as np + +from hirad.eval.bias_by_percentile_common import ( + BiasByPercentileSpec, + even_value_ticks, + finalize_percentile_plot, + new_percentile_axes, + plot_dict_curves, + run_bias_by_percentile, +) +from hirad.eval.eval_utils import parse_eval_cli + + +def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) -> None: + """Apply logit x-axis with labelled percentile ticks (and an m/s secondary axis).""" + ax.set_xscale('logit') + ax.set_xlim(frac[0], frac[-1]) + tick_fracs = [0.0001, 0.001, 0.01, 0.1, 0.50, 0.90, 0.99, 0.999, 0.9999] + tick_labels = ['0.01', '0.1', '1', '10', '50', '90', '99', '99.9', '99.99'] + # only show ticks within our data range + valid_ticks = [(f, l) for f, l in zip(tick_fracs, tick_labels) + if frac[0] <= f <= frac[-1]] + ax.set_xticks([f for f, _ in valid_ticks]) + ax.set_xticklabels([l for _, l in valid_ticks]) + ax.grid(True, alpha=0.3, which='both') + + if mean_q is not None: + ax2 = ax.twiny() + ax2.set_xscale('logit') + ax2.set_xlim(frac[0], frac[-1]) + # Wind speed spans only a few m/s, so integer-only ticks leave the + # compressed tail unlabelled; even_value_ticks picks a nice sub-unit step + # and spreads labels evenly across the whole logit axis. + tick_positions, tick_speeds = even_value_ticks(frac, mean_q) + ax2.set_xticks(tick_positions) + ax2.set_xticklabels([f'{v:g}' for v in tick_speeds]) + ax2.set_xlabel('Mean target [m/s]') + + +def save_bias_by_percentile_plot(bias_data_dict, percentile_values, labels, colors, + title, xlabel, ylabel, out_path, mean_q=None) -> None: + """Save a bias-by-percentile figure (linear y-axis, data-driven limits).""" + _, ax, frac = new_percentile_axes(percentile_values) + all_vals = plot_dict_curves(ax, frac, bias_data_dict, labels, colors) + ax.axhline(0.0, color='black', linewidth=0.8, linestyle='--') + if all_vals: + vcat = np.concatenate([np.asarray(v).ravel() for v in all_vals]) + vmin, vmax = float(np.nanmin(vcat)), float(np.nanmax(vcat)) + margin = max(abs(vmax - vmin) * 0.1, 0.05) + ax.set_ylim(vmin - margin, vmax + margin) + finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, + xlabel, ylabel, title, out_path) + + +def save_mae_by_percentile_plot(mae_data_dict, percentile_values, labels, colors, + title, xlabel, ylabel, out_path, mean_q=None) -> None: + """Save a MAE-by-percentile figure (linear y-axis, data-driven limits).""" + _, ax, frac = new_percentile_axes(percentile_values) + all_vals = plot_dict_curves(ax, frac, mae_data_dict, labels, colors, lower_clip=0) + if all_vals: + ymax = float(max(np.nanmax(v) for v in all_vals)) * 1.1 + if ymax > 0: + ax.set_ylim(0, ymax) + finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, + xlabel, ylabel, title, out_path) + + +def save_spread_by_percentile_plot(spread, percentile_values, + title, xlabel, ylabel, out_path, mean_q=None) -> None: + """Save an ensemble-spread-by-percentile figure (linear y-axis).""" + _, ax, frac = new_percentile_axes(percentile_values) + ax.plot(frac, spread, color='green', linewidth=2, label='Ensemble spread') + ymax_spread = float(np.nanmax(spread)) * 1.1 + if ymax_spread > 0: + ax.set_ylim(0, ymax_spread) + finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, + xlabel, ylabel, title, out_path) + + +def _resolve_channels(indices: dict) -> tuple: + # Wind speed is derived from the two surface wind components. + u_out = indices['output'].get('10u') + v_out = indices['output'].get('10v') + if u_out is None or v_out is None: + raise ValueError("Wind components (10u / 10v) not found in output channels.") + u_in = indices['input'].get('10u', u_out) + v_in = indices['input'].get('10v', v_out) + return (u_out, v_out), (u_in, v_in) + + +def _make_hist_bins(cfg: dict) -> np.ndarray: + # Linear histogram bins in m/s — fine enough to resolve sub-m/s differences + n_bins = cfg.get("wind_n_bins", 1500) + speed_min = cfg.get("wind_bin_min_ms", 0.0) + speed_max = cfg.get("wind_bin_max_ms", 75.0) + return np.linspace(speed_min, speed_max, n_bins + 1) + + +SPEC = BiasByPercentileSpec( + var_label='10 m wind speed', + output_prefix='windspeed', + bias_title='10 m Wind Speed Bias Over Land', + mae_title='10 m Wind Speed MAE Over Land', + spread_title='10 m Wind Speed Ensemble Spread Over Land', + bias_ylabel='Bias [m/s]', + mae_ylabel='MAE [m/s]', + spread_ylabel='Ensemble Spread [m/s]', + percentile_values=np.unique(np.concatenate([ + np.linspace(0.01, 0.1, 10), + np.linspace(0.1, 1.0, 10), + np.linspace(1.0, 10.0, 10), + np.linspace(10.0, 90.0, 80), + np.linspace(90.0, 99.0, 90), + np.linspace(99.0, 99.9, 45), + np.linspace(99.9, 99.99, 20), + ])), + mae_kind='member_list', + resolve_channels=_resolve_channels, + make_hist_bins=_make_hist_bins, + read_scaling=lambda cfg: (cfg.get("wind_conv_factor", 1.0), 0.0), + reduce_fn=lambda flats: np.hypot(flats[0], flats[1]), + save_bias=save_bias_by_percentile_plot, + save_mae=save_mae_by_percentile_plot, + save_spread=save_spread_by_percentile_plot, +) + + +def main(cfg: dict) -> None: + run_bias_by_percentile(cfg, SPEC) + + +if __name__ == '__main__': + main(parse_eval_cli()) diff --git a/src/hirad/eval_wind.sh b/src/hirad/eval_wind.sh index 6f1ce39d..a90f6c0a 100644 --- a/src/hirad/eval_wind.sh +++ b/src/hirad/eval_wind.sh @@ -3,13 +3,15 @@ set -euo pipefail ### CONFIG ### -CONFIG_NAME="src/hirad/conf/eval_real.yaml" +CONFIG_NAME="src/hirad/conf/eval_real_tst.yaml" CMDS=( # Diurnal cycle of windspeed "python src/hirad/eval/diurnal_cycle_wind.py" # Probability of exceedance "python src/hirad/eval/probability_of_exceedance_wind.py" + # QQ + "python -m hirad.eval.bias_by_percentile_wind" # Maps "python src/hirad/eval/map_wind_stats.py" ) From 56987f5fb7ef7e9b1522ae381c4d1331d8cd8bc9 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 3 Jun 2026 16:43:42 +0200 Subject: [PATCH 15/51] cleanup --- src/hirad/eval/bias_by_percentile_common.py | 84 ++------------------- 1 file changed, 7 insertions(+), 77 deletions(-) diff --git a/src/hirad/eval/bias_by_percentile_common.py b/src/hirad/eval/bias_by_percentile_common.py index e090f872..c805ed87 100644 --- a/src/hirad/eval/bias_by_percentile_common.py +++ b/src/hirad/eval/bias_by_percentile_common.py @@ -1,14 +1,4 @@ -""" -Shared machinery for the *bias / MAE / spread by percentile* plots. - -Both the temperature and precipitation scripts build, for every grid point -``g`` (and ensemble member ``m``), a histogram of the field over time, estimate -the per-percentile quantile ``q_{g,m}(p)``, and then average spatially / across -members to produce the plotted curves. Everything that is identical between the -two variables lives here; each variable script only supplies a small -:class:`BiasByPercentileSpec` describing channels, binning, units and the -variable-specific plot styling. -""" +"""Shared machinery for the *bias / MAE / spread by percentile* plots.""" import concurrent.futures import logging from dataclasses import dataclass @@ -45,9 +35,6 @@ } -# --------------------------------------------------------------------------- # -# Data machinery -# --------------------------------------------------------------------------- # def to_flat(arr, conv: float, offset: float = 0.0) -> np.ndarray: """Convert a (possibly xarray) field to a flat float numpy array, scaled and shifted.""" return (np.asarray(getattr(arr, 'values', arr)) * conv + offset).ravel() @@ -66,16 +53,7 @@ def build_all_histograms( log_interval: int, logger: logging.Logger, ) -> tuple: - """Single serial pass over all timesteps, building histograms for every mode. - - *out_channels* / *in_channels* are tuples of channel indices; the per-channel - fields (after scaling) are combined into the plotted scalar by *reduce_fn* - (e.g. identity for single-channel fields, ``hypot`` for wind speed). - - Returns ``(det_counts, member_counts, n_members)`` where *det_counts* maps - mode → array-or-None and *member_counts* is a list of arrays (one per member) - or None. - """ + """Single serial pass over all timesteps, building histograms for every mode.""" n_land = land_idx.size n_bins = len(hist_bins) - 1 det_modes = ('target', 'baseline', 'regression-prediction') @@ -143,15 +121,7 @@ def accumulate(counts, channel_arrs): def per_point_quantiles(pp_counts: np.ndarray, bin_edges: np.ndarray, frac_percentiles: np.ndarray, block_size: int = 8192) -> np.ndarray: - """Estimate per-row quantiles from per-grid-point histograms. - - Returns ``(n_land, P)`` float32 array. Uses upper-bin-edge values (no in-bin - interpolation) — adequate given fine bins. - - Land points are processed in blocks of *block_size* rows so the CDF working - set fits comfortably in L3 cache; numpy releases the GIL for the large - C-level operations, enabling true thread parallelism across concurrent calls. - """ + """Estimate per-row quantiles from per-grid-point histograms.""" n_land, n_bins = pp_counts.shape P = len(frac_percentiles) result = np.empty((n_land, P), dtype=np.float32) @@ -189,10 +159,7 @@ def compute_quantiles( frac_percentiles: np.ndarray, n_workers: int, ) -> tuple: - """Compute per-point quantiles for every mode in parallel, freeing counts as we go. - - Returns ``(target_q, det_results, member_qs)``. - """ + """Compute per-point quantiles for every mode in parallel, freeing counts as we go.""" use_ensemble = has_ensemble and member_counts is not None and n_members is not None with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as pool: @@ -222,12 +189,7 @@ def submit(counts): def _ensemble_stats(member_qs, target_q, target_mean_q, n_members, P, n_land, mae_kind): - """Compute ensemble spread, the MAE plot entry, and per-member bias curves. - - *mae_kind* selects how the MAE band is built: - - ``'member_list'`` → list of per-member MAE curves (band = mean ±σ across members). - - ``'spatial_spread'`` → ``(mean, std)`` from per-land-point inter-member AE spread. - """ + """Compute ensemble spread, the MAE plot entry, and per-member bias curves.""" is_spatial = mae_kind == 'spatial_spread' sum_q = np.zeros((n_land, P), dtype=np.float64) sumsq_q = np.zeros((n_land, P), dtype=np.float64) @@ -263,27 +225,12 @@ def _ensemble_stats(member_qs, target_q, target_mean_q, n_members, P, n_land, ma return spread, mae_entry, member_biases -# --------------------------------------------------------------------------- # -# Plotting helpers -# --------------------------------------------------------------------------- # def new_percentile_axes(percentile_values: np.ndarray): """Create a figure/axes pair and return it together with the fractional x-values.""" fig, ax = plt.subplots(figsize=(10, 6)) return fig, ax, percentile_values / 100.0 -def _nice_step(rough: float) -> float: - """Round *rough* up to the nearest 1/2/2.5/5 ×10ⁿ 'nice' step.""" - if rough <= 0: - return 1.0 - exp = np.floor(np.log10(rough)) - base = 10.0 ** exp - for mult in (1.0, 2.0, 2.5, 5.0, 10.0): - if rough <= mult * base: - return mult * base - return 10.0 * base - - def _round_sig(x: float, sig: int = 2) -> float: """Round *x* to *sig* significant figures (clean axis labels).""" if x == 0 or not np.isfinite(x): @@ -294,15 +241,7 @@ def _round_sig(x: float, sig: int = 2) -> float: def even_value_ticks(frac: np.ndarray, mean_q: np.ndarray, target_ticks: int = 9) -> tuple: - """Pick secondary-axis ticks evenly spaced along the logit axis. - - The plotted x-axis is logit in *percentile*, while the secondary axis labels - the *mean target value*. We place *target_ticks* positions evenly spaced in - logit coordinates (so they look uniform on the axis), including both corners - — the first and last points that actually have data — then label each with - the interpolated value rounded to two significant figures for clean, - readable tick labels. Returns ``(positions, values)``. - """ + """Pick secondary-axis ticks evenly spaced along the logit axis.""" def _logit(p): p = np.clip(p, 1e-9, 1 - 1e-9) return np.log(p / (1.0 - p)) @@ -334,13 +273,7 @@ def _expit(z): def plot_dict_curves(ax, frac, data_dict, labels, colors, lower_clip=None) -> list: - """Plot per-mode curves and return the arrays spanning the plotted range. - - Entry types per dict value: - - ndarray → a single line. - - list of ndarrays → member curves, drawn as mean ±1 σ shading. - - ``(mean, std)`` tuple → an explicit band, drawn as mean ±1 σ shading. - """ + """Plot per-mode curves and return the arrays spanning the plotted range.""" all_vals = [] for (_key, data), label, color in zip(data_dict.items(), labels, colors): if isinstance(data, list): @@ -378,9 +311,6 @@ def finalize_percentile_plot(ax, frac, apply_xaxis, mean_q, xlabel, ylabel, plt.close() -# --------------------------------------------------------------------------- # -# Orchestration -# --------------------------------------------------------------------------- # @dataclass class BiasByPercentileSpec: """Variable-specific configuration for :func:`run_bias_by_percentile`.""" From 075103d00b12e50b39f0707c01d156bda562d512 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 3 Jun 2026 16:47:01 +0200 Subject: [PATCH 16/51] morecleanup --- src/hirad/eval/bias_by_percentile_common.py | 27 ++++++++------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/src/hirad/eval/bias_by_percentile_common.py b/src/hirad/eval/bias_by_percentile_common.py index c805ed87..c7bd143e 100644 --- a/src/hirad/eval/bias_by_percentile_common.py +++ b/src/hirad/eval/bias_by_percentile_common.py @@ -160,7 +160,7 @@ def compute_quantiles( n_workers: int, ) -> tuple: """Compute per-point quantiles for every mode in parallel, freeing counts as we go.""" - use_ensemble = has_ensemble and member_counts is not None and n_members is not None + members = member_counts if (has_ensemble and member_counts is not None and n_members is not None) else [] with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as pool: def submit(counts): @@ -168,11 +168,7 @@ def submit(counts): fut_target = submit(det_counts['target']) fut_det = {mode: submit(det_counts[mode]) for mode in active_det_modes} - if use_ensemble: - assert member_counts is not None and n_members is not None - fut_members = [submit(member_counts[m]) for m in range(n_members)] - else: - fut_members = [] + fut_members = [submit(c) for c in members] target_q = fut_target.result() del det_counts['target'] @@ -180,21 +176,19 @@ def submit(counts): for mode in active_det_modes: del det_counts[mode] member_qs = [f.result() for f in fut_members] - if use_ensemble: - assert member_counts is not None and n_members is not None - for m in range(n_members): - member_counts[m] = None + members.clear() return target_q, det_results, member_qs -def _ensemble_stats(member_qs, target_q, target_mean_q, n_members, P, n_land, mae_kind): +def _ensemble_stats(member_qs, target_q, target_mean_q, mae_kind): """Compute ensemble spread, the MAE plot entry, and per-member bias curves.""" is_spatial = mae_kind == 'spatial_spread' - sum_q = np.zeros((n_land, P), dtype=np.float64) - sumsq_q = np.zeros((n_land, P), dtype=np.float64) - sum_ae = np.zeros((n_land, P), dtype=np.float64) - sumsq_ae = np.zeros((n_land, P), dtype=np.float64) + n_members = len(member_qs) + sum_q = np.zeros(target_q.shape, dtype=np.float64) + sumsq_q = np.zeros(target_q.shape, dtype=np.float64) + sum_ae = np.zeros(target_q.shape, dtype=np.float64) + sumsq_ae = np.zeros(target_q.shape, dtype=np.float64) member_biases: list = [] member_maes: list = [] target_q_f = target_q.astype(np.float64) @@ -375,7 +369,6 @@ def run_bias_by_percentile(cfg: dict, spec: BiasByPercentileSpec) -> None: percentile_values = spec.percentile_values frac_percentiles = percentile_values / 100.0 - P = len(frac_percentiles) logger.info(f"Resolving {len(times)} timestep directories ...") ts_dirs = {ts: resolve_ts_dir(out_root, ts) / ts for ts in times} @@ -419,7 +412,7 @@ def run_bias_by_percentile(cfg: dict, spec: BiasByPercentileSpec) -> None: spread = None if has_ensemble: spread, mae_entry, member_biases = _ensemble_stats( - member_qs, target_q, target_mean_q, n_members, P, n_land, spec.mae_kind, + member_qs, target_q, target_mean_q, spec.mae_kind, ) bias_data['predictions'] = member_biases mae_data['predictions'] = mae_entry From 32162aeb1cc930004a795b9f06747402b08806ab Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 3 Jun 2026 16:52:40 +0200 Subject: [PATCH 17/51] fix config --- src/hirad/eval_wind.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hirad/eval_wind.sh b/src/hirad/eval_wind.sh index a90f6c0a..ec5b4966 100644 --- a/src/hirad/eval_wind.sh +++ b/src/hirad/eval_wind.sh @@ -3,7 +3,7 @@ set -euo pipefail ### CONFIG ### -CONFIG_NAME="src/hirad/conf/eval_real_tst.yaml" +CONFIG_NAME="src/hirad/conf/eval_real.yaml" CMDS=( # Diurnal cycle of windspeed From 4a0d79e56e210ba9d9d5984e4e85e19ba0d200aa Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 3 Jun 2026 19:57:42 +0200 Subject: [PATCH 18/51] extend x-axis --- src/hirad/eval/bias_by_percentile_temp.py | 7 ++++--- src/hirad/eval/bias_by_percentile_wind.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/hirad/eval/bias_by_percentile_temp.py b/src/hirad/eval/bias_by_percentile_temp.py index 00dcf220..67ad10ba 100644 --- a/src/hirad/eval/bias_by_percentile_temp.py +++ b/src/hirad/eval/bias_by_percentile_temp.py @@ -18,12 +18,13 @@ def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) -> None: """Apply logit x-axis with labelled percentile ticks (and a °C secondary axis).""" ax.set_xscale('logit') - ax.set_xlim(frac[0], frac[-1]) + xlim_right = frac[-1] + 1e-9 + ax.set_xlim(frac[0], xlim_right) tick_fracs = [0.0001, 0.001, 0.01, 0.1, 0.50, 0.90, 0.99, 0.999, 0.9999] tick_labels = ['0.01', '0.1', '1', '10', '50', '90', '99', '99.9', '99.99'] # only show ticks within our data range valid_ticks = [(f, l) for f, l in zip(tick_fracs, tick_labels) - if frac[0] <= f <= frac[-1]] + if frac[0] <= f <= xlim_right] ax.set_xticks([f for f, _ in valid_ticks]) ax.set_xticklabels([l for _, l in valid_ticks]) ax.grid(True, alpha=0.3, which='both') @@ -31,7 +32,7 @@ def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) - if mean_q is not None: ax2 = ax.twiny() ax2.set_xscale('logit') - ax2.set_xlim(frac[0], frac[-1]) + ax2.set_xlim(frac[0], xlim_right) # The axis is logit in percentile, so a fixed list of round temps bunches # near the median while the tails get no labels; even_value_ticks picks a # nice step and spreads labels evenly across the whole logit axis. diff --git a/src/hirad/eval/bias_by_percentile_wind.py b/src/hirad/eval/bias_by_percentile_wind.py index 46e537ea..74ce0725 100644 --- a/src/hirad/eval/bias_by_percentile_wind.py +++ b/src/hirad/eval/bias_by_percentile_wind.py @@ -21,12 +21,13 @@ def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) -> None: """Apply logit x-axis with labelled percentile ticks (and an m/s secondary axis).""" ax.set_xscale('logit') - ax.set_xlim(frac[0], frac[-1]) + xlim_right = frac[-1] + 1e-9 + ax.set_xlim(frac[0], xlim_right) tick_fracs = [0.0001, 0.001, 0.01, 0.1, 0.50, 0.90, 0.99, 0.999, 0.9999] tick_labels = ['0.01', '0.1', '1', '10', '50', '90', '99', '99.9', '99.99'] # only show ticks within our data range valid_ticks = [(f, l) for f, l in zip(tick_fracs, tick_labels) - if frac[0] <= f <= frac[-1]] + if frac[0] <= f <= xlim_right] ax.set_xticks([f for f, _ in valid_ticks]) ax.set_xticklabels([l for _, l in valid_ticks]) ax.grid(True, alpha=0.3, which='both') @@ -34,7 +35,7 @@ def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) - if mean_q is not None: ax2 = ax.twiny() ax2.set_xscale('logit') - ax2.set_xlim(frac[0], frac[-1]) + ax2.set_xlim(frac[0], xlim_right) # Wind speed spans only a few m/s, so integer-only ticks leave the # compressed tail unlabelled; even_value_ticks picks a nice sub-unit step # and spreads labels evenly across the whole logit axis. From 3bb134fd24b0f7682100835a78b0a12c3a7f38e0 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Thu, 4 Jun 2026 09:18:07 +0200 Subject: [PATCH 19/51] fix spread --- src/hirad/eval/bias_by_percentile_common.py | 47 ++++++++++----------- src/hirad/eval/bias_by_percentile_precip.py | 1 - src/hirad/eval/bias_by_percentile_temp.py | 1 - src/hirad/eval/bias_by_percentile_wind.py | 1 - 4 files changed, 23 insertions(+), 27 deletions(-) diff --git a/src/hirad/eval/bias_by_percentile_common.py b/src/hirad/eval/bias_by_percentile_common.py index c7bd143e..92bd61e7 100644 --- a/src/hirad/eval/bias_by_percentile_common.py +++ b/src/hirad/eval/bias_by_percentile_common.py @@ -181,42 +181,42 @@ def submit(counts): return target_q, det_results, member_qs -def _ensemble_stats(member_qs, target_q, target_mean_q, mae_kind): - """Compute ensemble spread, the MAE plot entry, and per-member bias curves.""" - is_spatial = mae_kind == 'spatial_spread' +def _ensemble_stats(member_qs, target_q): + """Compute ensemble spread, and the MAE / bias plot entries. + + All three quantities use the same local-then-averaged estimator: a statistic + is formed per grid point across ensemble members, then averaged over land. + """ n_members = len(member_qs) sum_q = np.zeros(target_q.shape, dtype=np.float64) sumsq_q = np.zeros(target_q.shape, dtype=np.float64) sum_ae = np.zeros(target_q.shape, dtype=np.float64) sumsq_ae = np.zeros(target_q.shape, dtype=np.float64) - member_biases: list = [] - member_maes: list = [] target_q_f = target_q.astype(np.float64) for qm in member_qs: qm_f = qm.astype(np.float64) sum_q += qm sumsq_q += qm_f ** 2 - member_biases.append((qm.mean(axis=0) - target_mean_q).astype(np.float64)) - if is_spatial: - ae = np.abs(qm_f - target_q_f) - sum_ae += ae - sumsq_ae += ae ** 2 - else: - member_maes.append(np.abs(qm - target_q).mean(axis=0).astype(np.float64)) + ae = np.abs(qm_f - target_q_f) + sum_ae += ae + sumsq_ae += ae ** 2 mean_q = sum_q / n_members var_q = np.maximum(sumsq_q / n_members - mean_q ** 2, 0.0) - spread = np.sqrt(var_q).mean(axis=0) + std_q = np.sqrt(var_q) + spread = std_q.mean(axis=0) + + mean_ae = sum_ae / n_members + var_ae = np.maximum(sumsq_ae / n_members - mean_ae ** 2, 0.0) + mae_entry = (mean_ae.mean(axis=0), np.sqrt(var_ae).mean(axis=0)) - if is_spatial: - mean_ae = sum_ae / n_members - var_ae = np.maximum(sumsq_ae / n_members - mean_ae ** 2, 0.0) - mae_entry: object = (mean_ae.mean(axis=0), np.sqrt(var_ae).mean(axis=0)) - else: - mae_entry = member_maes + # Bias band, local-then-averaged: per grid point the member-mean bias is + # (mean_q - target_q) and the member-std of bias equals std_q (target is + # constant across members); both are then averaged over land. + bias_entry = ((mean_q - target_q_f).mean(axis=0), std_q.mean(axis=0)) - return spread, mae_entry, member_biases + return spread, mae_entry, bias_entry def new_percentile_axes(percentile_values: np.ndarray): @@ -317,7 +317,6 @@ class BiasByPercentileSpec: mae_ylabel: str spread_ylabel: str percentile_values: np.ndarray - mae_kind: str # 'member_list' | 'spatial_spread' resolve_channels: Callable[[dict], tuple] # indices -> (ch_out, ch_in); raises ValueError make_hist_bins: Callable[[dict], np.ndarray] read_scaling: Callable[[dict], tuple] # cfg -> (conv, offset) @@ -411,10 +410,10 @@ def run_bias_by_percentile(cfg: dict, spec: BiasByPercentileSpec) -> None: spread = None if has_ensemble: - spread, mae_entry, member_biases = _ensemble_stats( - member_qs, target_q, target_mean_q, spec.mae_kind, + spread, mae_entry, bias_entry = _ensemble_stats( + member_qs, target_q, ) - bias_data['predictions'] = member_biases + bias_data['predictions'] = bias_entry mae_data['predictions'] = mae_entry labels.append(ENSEMBLE_LABEL) colors.append(ENSEMBLE_COLOR) diff --git a/src/hirad/eval/bias_by_percentile_precip.py b/src/hirad/eval/bias_by_percentile_precip.py index 7325fa65..d3197aee 100644 --- a/src/hirad/eval/bias_by_percentile_precip.py +++ b/src/hirad/eval/bias_by_percentile_precip.py @@ -97,7 +97,6 @@ def _make_hist_bins(cfg: dict) -> np.ndarray: np.linspace(99.0, 99.9, 45), np.linspace(99.9, 99.99, 20), ])), - mae_kind='spatial_spread', resolve_channels=_resolve_channels, make_hist_bins=_make_hist_bins, read_scaling=lambda cfg: (cfg.get("conv_factor_hourly", 1.0), 0.0), diff --git a/src/hirad/eval/bias_by_percentile_temp.py b/src/hirad/eval/bias_by_percentile_temp.py index 67ad10ba..345765cd 100644 --- a/src/hirad/eval/bias_by_percentile_temp.py +++ b/src/hirad/eval/bias_by_percentile_temp.py @@ -117,7 +117,6 @@ def _make_hist_bins(cfg: dict) -> np.ndarray: np.linspace(99.0, 99.9, 45), np.linspace(99.9, 99.99, 20), ])), - mae_kind='member_list', resolve_channels=_resolve_channels, make_hist_bins=_make_hist_bins, # Default: convert Kelvin → °C (conv=1.0, offset=-273.15) diff --git a/src/hirad/eval/bias_by_percentile_wind.py b/src/hirad/eval/bias_by_percentile_wind.py index 74ce0725..0641c99e 100644 --- a/src/hirad/eval/bias_by_percentile_wind.py +++ b/src/hirad/eval/bias_by_percentile_wind.py @@ -122,7 +122,6 @@ def _make_hist_bins(cfg: dict) -> np.ndarray: np.linspace(99.0, 99.9, 45), np.linspace(99.9, 99.99, 20), ])), - mae_kind='member_list', resolve_channels=_resolve_channels, make_hist_bins=_make_hist_bins, read_scaling=lambda cfg: (cfg.get("wind_conv_factor", 1.0), 0.0), From 464ac34ec898ac9dae9473fb949c09e185eca030 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Thu, 4 Jun 2026 12:21:30 +0200 Subject: [PATCH 20/51] add smoothed version for precip and FBI --- src/hirad/eval/bias_by_percentile_common.py | 69 ++++++++++++++++--- src/hirad/eval/bias_by_percentile_precip.py | 17 +++++ .../bias_by_percentile_precip_smoothed.py | 24 +++++++ src/hirad/eval/bias_by_percentile_temp.py | 17 +++++ src/hirad/eval/bias_by_percentile_wind.py | 17 +++++ 5 files changed, 134 insertions(+), 10 deletions(-) create mode 100644 src/hirad/eval/bias_by_percentile_precip_smoothed.py diff --git a/src/hirad/eval/bias_by_percentile_common.py b/src/hirad/eval/bias_by_percentile_common.py index 92bd61e7..36ad9dff 100644 --- a/src/hirad/eval/bias_by_percentile_common.py +++ b/src/hirad/eval/bias_by_percentile_common.py @@ -9,6 +9,7 @@ import matplotlib.ticker as mticker import numpy as np import torch +from scipy.ndimage import gaussian_filter from hirad.eval.eval_utils import ( get_channel_indices, @@ -40,6 +41,12 @@ def to_flat(arr, conv: float, offset: float = 0.0) -> np.ndarray: return (np.asarray(getattr(arr, 'values', arr)) * conv + offset).ravel() +def _smooth2d(arr, sigma: float) -> np.ndarray: + """Apply an isotropic Gaussian low-pass to a 2D field (grid-point sigma).""" + a = np.asarray(getattr(arr, 'values', arr), dtype=np.float32) + return gaussian_filter(a, sigma=sigma, mode='nearest') + + def build_all_histograms( times: list, ts_dirs: dict, @@ -52,6 +59,7 @@ def build_all_histograms( hist_bins: np.ndarray, log_interval: int, logger: logging.Logger, + smoothing_sigma: float | None = None, ) -> tuple: """Single serial pass over all timesteps, building histograms for every mode.""" n_land = land_idx.size @@ -67,6 +75,8 @@ def build_all_histograms( skip_preds = False def accumulate(counts, channel_arrs): + if smoothing_sigma is not None: + channel_arrs = [_smooth2d(a, smoothing_sigma) for a in channel_arrs] flats = [to_flat(a, conv, offset) for a in channel_arrs] vals = reduce_fn(flats)[land_idx] bin_idx = np.searchsorted(interior_edges, vals, side='right') @@ -182,9 +192,9 @@ def submit(counts): def _ensemble_stats(member_qs, target_q): - """Compute ensemble spread, and the MAE / bias plot entries. + """Compute ensemble spread, and the MAE / bias / FBI plot entries. - All three quantities use the same local-then-averaged estimator: a statistic + All quantities use the same local-then-averaged estimator: a statistic is formed per grid point across ensemble members, then averaged over land. """ n_members = len(member_qs) @@ -192,7 +202,11 @@ def _ensemble_stats(member_qs, target_q): sumsq_q = np.zeros(target_q.shape, dtype=np.float64) sum_ae = np.zeros(target_q.shape, dtype=np.float64) sumsq_ae = np.zeros(target_q.shape, dtype=np.float64) + sum_fbi = np.zeros(target_q.shape, dtype=np.float64) + sumsq_fbi = np.zeros(target_q.shape, dtype=np.float64) target_q_f = target_q.astype(np.float64) + # FBI is a quantile ratio q_pred / q_target; guard against zero targets. + target_q_safe = np.where(target_q_f != 0.0, target_q_f, np.nan) for qm in member_qs: qm_f = qm.astype(np.float64) @@ -201,6 +215,9 @@ def _ensemble_stats(member_qs, target_q): ae = np.abs(qm_f - target_q_f) sum_ae += ae sumsq_ae += ae ** 2 + fbi = qm_f / target_q_safe + sum_fbi += fbi + sumsq_fbi += fbi ** 2 mean_q = sum_q / n_members var_q = np.maximum(sumsq_q / n_members - mean_q ** 2, 0.0) @@ -216,7 +233,11 @@ def _ensemble_stats(member_qs, target_q): # constant across members); both are then averaged over land. bias_entry = ((mean_q - target_q_f).mean(axis=0), std_q.mean(axis=0)) - return spread, mae_entry, bias_entry + mean_fbi = sum_fbi / n_members + var_fbi = np.maximum(sumsq_fbi / n_members - mean_fbi ** 2, 0.0) + fbi_entry = (np.nanmean(mean_fbi, axis=0), np.nanmean(np.sqrt(var_fbi), axis=0)) + + return spread, mae_entry, bias_entry, fbi_entry def new_percentile_axes(percentile_values: np.ndarray): @@ -313,9 +334,11 @@ class BiasByPercentileSpec: bias_title: str mae_title: str spread_title: str + fbi_title: str bias_ylabel: str mae_ylabel: str spread_ylabel: str + fbi_ylabel: str percentile_values: np.ndarray resolve_channels: Callable[[dict], tuple] # indices -> (ch_out, ch_in); raises ValueError make_hist_bins: Callable[[dict], np.ndarray] @@ -323,6 +346,7 @@ class BiasByPercentileSpec: save_bias: Callable save_mae: Callable save_spread: Callable + save_fbi: Callable # Combines the (scaled) per-channel flat fields into the plotted scalar. # Defaults to the single-channel identity; wind speed uses ``hypot``. reduce_fn: Callable[[list], np.ndarray] = lambda flats: flats[0] @@ -366,6 +390,17 @@ def run_bias_by_percentile(cfg: dict, spec: BiasByPercentileSpec) -> None: log_interval = cfg.get("log_interval", 24) conv, offset = spec.read_scaling(cfg) + smoothing_km = cfg.get("smoothing_sigma_km") + grid_res_km = cfg.get("grid_res_km", 1.0) + smoothing_sigma = (smoothing_km / grid_res_km) if smoothing_km else None + if smoothing_sigma is not None: + logger.info( + f"Gaussian smoothing enabled: sigma = {smoothing_km} km " + f"/ {grid_res_km} km = {smoothing_sigma:.3g} grid points" + ) + suffix = f"_smoothed{int(smoothing_km)}km" if smoothing_km else "" + title_note = f" ({int(smoothing_km)} km smoothed)" if smoothing_km else "" + percentile_values = spec.percentile_values frac_percentiles = percentile_values / 100.0 @@ -375,6 +410,7 @@ def run_bias_by_percentile(cfg: dict, spec: BiasByPercentileSpec) -> None: det_counts, member_counts, n_members = build_all_histograms( times, ts_dirs, out_channels, in_channels, spec.reduce_fn, conv, offset, land_idx, hist_bins, log_interval, logger, + smoothing_sigma=smoothing_sigma, ) if det_counts.get('target') is None: @@ -393,9 +429,12 @@ def run_bias_by_percentile(cfg: dict, spec: BiasByPercentileSpec) -> None: ) target_mean_q = target_q.mean(axis=0) + # FBI is a per-point quantile ratio q_pred / q_target; guard zero targets. + target_q_safe = np.where(target_q != 0.0, target_q, np.nan) bias_data: dict = {} mae_data: dict = {} + fbi_data: dict = {} labels: list = [] colors: list = [] @@ -405,43 +444,53 @@ def run_bias_by_percentile(cfg: dict, spec: BiasByPercentileSpec) -> None: pred_q = det_results.pop(mode) bias_data[mode] = pred_q.mean(axis=0) - target_mean_q mae_data[mode] = np.abs(pred_q - target_q).mean(axis=0) + fbi_data[mode] = np.nanmean(pred_q / target_q_safe, axis=0) labels.append(label) colors.append(color) spread = None if has_ensemble: - spread, mae_entry, bias_entry = _ensemble_stats( + spread, mae_entry, bias_entry, fbi_entry = _ensemble_stats( member_qs, target_q, ) bias_data['predictions'] = bias_entry mae_data['predictions'] = mae_entry + fbi_data['predictions'] = fbi_entry labels.append(ENSEMBLE_LABEL) colors.append(ENSEMBLE_COLOR) output_path = out_root / cfg.get("results_dir_name", "evaluation_maps") / "by_percentile" output_path.mkdir(parents=True, exist_ok=True) - fn = output_path / f'{spec.output_prefix}_bias_by_percentile.png' + fn = output_path / f'{spec.output_prefix}_bias_by_percentile{suffix}.png' spec.save_bias( bias_data, percentile_values, labels, colors, - title=spec.bias_title, xlabel='Percentile', ylabel=spec.bias_ylabel, + title=spec.bias_title + title_note, xlabel='Percentile', ylabel=spec.bias_ylabel, out_path=fn, mean_q=target_mean_q, ) logger.info(f"Bias-by-percentile plot saved: {fn}") - fn_mae = output_path / f'{spec.output_prefix}_mae_by_percentile.png' + fn_mae = output_path / f'{spec.output_prefix}_mae_by_percentile{suffix}.png' spec.save_mae( mae_data, percentile_values, labels, colors, - title=spec.mae_title, xlabel='Percentile', ylabel=spec.mae_ylabel, + title=spec.mae_title + title_note, xlabel='Percentile', ylabel=spec.mae_ylabel, out_path=fn_mae, mean_q=target_mean_q, ) logger.info(f"MAE-by-percentile plot saved: {fn_mae}") + fn_fbi = output_path / f'{spec.output_prefix}_fbi_by_percentile{suffix}.png' + spec.save_fbi( + fbi_data, percentile_values, labels, colors, + title=spec.fbi_title + title_note, xlabel='Percentile', ylabel=spec.fbi_ylabel, + out_path=fn_fbi, mean_q=target_mean_q, + ) + logger.info(f"FBI-by-percentile plot saved: {fn_fbi}") + if spread is not None: - fn_spread = output_path / f'{spec.output_prefix}_spread_by_percentile.png' + fn_spread = output_path / f'{spec.output_prefix}_spread_by_percentile{suffix}.png' spec.save_spread( spread, percentile_values, - title=spec.spread_title, xlabel='Percentile', ylabel=spec.spread_ylabel, + title=spec.spread_title + title_note, xlabel='Percentile', ylabel=spec.spread_ylabel, out_path=fn_spread, mean_q=target_mean_q, ) logger.info(f"Spread-by-percentile plot saved: {fn_spread}") diff --git a/src/hirad/eval/bias_by_percentile_precip.py b/src/hirad/eval/bias_by_percentile_precip.py index d3197aee..a2b5f84a 100644 --- a/src/hirad/eval/bias_by_percentile_precip.py +++ b/src/hirad/eval/bias_by_percentile_precip.py @@ -71,6 +71,20 @@ def save_spread_by_percentile_plot(spread, percentile_values, xlabel, ylabel, title, out_path, legend=False) +def save_fbi_by_percentile_plot(fbi_data_dict, percentile_values, labels, colors, + title, xlabel, ylabel, out_path, mean_q=None) -> None: + """Save a frequency-bias-index-by-percentile figure (linear y-axis, ratio around 1).""" + _, ax, frac = new_percentile_axes(percentile_values) + all_vals = plot_dict_curves(ax, frac, fbi_data_dict, labels, colors, lower_clip=0) + ax.axhline(1.0, color='black', linewidth=0.8, linestyle='--') + if all_vals: + ymax = float(max(np.nanmax(v) for v in all_vals)) * 1.1 + if ymax > 1.0: + ax.set_ylim(0, ymax) + finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, + xlabel, ylabel, title, out_path) + + def _resolve_channels(indices: dict) -> tuple: tp_out = indices['output']['tp'] tp_in = indices['input'].get('tp', tp_out) @@ -88,9 +102,11 @@ def _make_hist_bins(cfg: dict) -> np.ndarray: bias_title='Precipitation Bias Over Land', mae_title='Precipitation MAE Over Land', spread_title='Precipitation Ensemble Spread Over Land', + fbi_title='Precipitation Frequency Bias Index Over Land', bias_ylabel='Bias [mm/h]', mae_ylabel='MAE [mm/h]', spread_ylabel='Ensemble Spread [mm/h]', + fbi_ylabel='FBI [-]', percentile_values=np.unique(np.concatenate([ np.linspace(1.0, 90.0, 90), np.linspace(90.0, 99.0, 90), @@ -103,6 +119,7 @@ def _make_hist_bins(cfg: dict) -> np.ndarray: save_bias=save_bias_by_percentile_plot, save_mae=save_mae_by_percentile_plot, save_spread=save_spread_by_percentile_plot, + save_fbi=save_fbi_by_percentile_plot, ) diff --git a/src/hirad/eval/bias_by_percentile_precip_smoothed.py b/src/hirad/eval/bias_by_percentile_precip_smoothed.py new file mode 100644 index 00000000..7cfe3b60 --- /dev/null +++ b/src/hirad/eval/bias_by_percentile_precip_smoothed.py @@ -0,0 +1,24 @@ +"""Smoothed precipitation bias / MAE / spread by percentile plots. + +Applies the same Gaussian low-pass filter to target and prediction fields before +the percentile histograms are accumulated, mitigating double-penalty effects in +the precipitation tails. +""" +from hirad.eval.bias_by_percentile_common import run_bias_by_percentile +from hirad.eval.bias_by_percentile_precip import SPEC +from hirad.eval.eval_utils import parse_eval_cli + + +SMOOTHING_SIGMA_KM = 20.0 +GRID_RES_KM = 1.0 + + +def main(cfg: dict) -> None: + cfg = dict(cfg) + cfg['smoothing_sigma_km'] = SMOOTHING_SIGMA_KM + cfg.setdefault('grid_res_km', GRID_RES_KM) + run_bias_by_percentile(cfg, SPEC) + + +if __name__ == '__main__': + main(parse_eval_cli()) \ No newline at end of file diff --git a/src/hirad/eval/bias_by_percentile_temp.py b/src/hirad/eval/bias_by_percentile_temp.py index 345765cd..48dff6ce 100644 --- a/src/hirad/eval/bias_by_percentile_temp.py +++ b/src/hirad/eval/bias_by_percentile_temp.py @@ -82,6 +82,20 @@ def save_spread_by_percentile_plot(spread, percentile_values, xlabel, ylabel, title, out_path) +def save_fbi_by_percentile_plot(fbi_data_dict, percentile_values, labels, colors, + title, xlabel, ylabel, out_path, mean_q=None) -> None: + """Save a frequency-bias-index-by-percentile figure (linear y-axis, ratio around 1).""" + _, ax, frac = new_percentile_axes(percentile_values) + all_vals = plot_dict_curves(ax, frac, fbi_data_dict, labels, colors, lower_clip=0) + ax.axhline(1.0, color='black', linewidth=0.8, linestyle='--') + if all_vals: + ymax = float(max(np.nanmax(v) for v in all_vals)) * 1.1 + if ymax > 1.0: + ax.set_ylim(0, ymax) + finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, + xlabel, ylabel, title, out_path) + + def _resolve_channels(indices: dict) -> tuple: # Temperature channel: try '2t' first, then 't2m' t2m_out = indices['output'].get('2t', indices['output'].get('t2m')) @@ -105,9 +119,11 @@ def _make_hist_bins(cfg: dict) -> np.ndarray: bias_title='2m Temperature Bias Over Land', mae_title='2m Temperature MAE Over Land', spread_title='2m Temperature Ensemble Spread Over Land', + fbi_title='2m Temperature Frequency Bias Index Over Land', bias_ylabel='Bias [°C]', mae_ylabel='MAE [°C]', spread_ylabel='Ensemble Spread [°C]', + fbi_ylabel='FBI [-]', percentile_values=np.unique(np.concatenate([ np.linspace(0.01, 0.1, 10), np.linspace(0.1, 1.0, 10), @@ -125,6 +141,7 @@ def _make_hist_bins(cfg: dict) -> np.ndarray: save_bias=save_bias_by_percentile_plot, save_mae=save_mae_by_percentile_plot, save_spread=save_spread_by_percentile_plot, + save_fbi=save_fbi_by_percentile_plot, ) diff --git a/src/hirad/eval/bias_by_percentile_wind.py b/src/hirad/eval/bias_by_percentile_wind.py index 0641c99e..71a845c3 100644 --- a/src/hirad/eval/bias_by_percentile_wind.py +++ b/src/hirad/eval/bias_by_percentile_wind.py @@ -85,6 +85,20 @@ def save_spread_by_percentile_plot(spread, percentile_values, xlabel, ylabel, title, out_path) +def save_fbi_by_percentile_plot(fbi_data_dict, percentile_values, labels, colors, + title, xlabel, ylabel, out_path, mean_q=None) -> None: + """Save a frequency-bias-index-by-percentile figure (linear y-axis, ratio around 1).""" + _, ax, frac = new_percentile_axes(percentile_values) + all_vals = plot_dict_curves(ax, frac, fbi_data_dict, labels, colors, lower_clip=0) + ax.axhline(1.0, color='black', linewidth=0.8, linestyle='--') + if all_vals: + ymax = float(max(np.nanmax(v) for v in all_vals)) * 1.1 + if ymax > 1.0: + ax.set_ylim(0, ymax) + finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, + xlabel, ylabel, title, out_path) + + def _resolve_channels(indices: dict) -> tuple: # Wind speed is derived from the two surface wind components. u_out = indices['output'].get('10u') @@ -110,9 +124,11 @@ def _make_hist_bins(cfg: dict) -> np.ndarray: bias_title='10 m Wind Speed Bias Over Land', mae_title='10 m Wind Speed MAE Over Land', spread_title='10 m Wind Speed Ensemble Spread Over Land', + fbi_title='10 m Wind Speed Frequency Bias Index Over Land', bias_ylabel='Bias [m/s]', mae_ylabel='MAE [m/s]', spread_ylabel='Ensemble Spread [m/s]', + fbi_ylabel='FBI [-]', percentile_values=np.unique(np.concatenate([ np.linspace(0.01, 0.1, 10), np.linspace(0.1, 1.0, 10), @@ -129,6 +145,7 @@ def _make_hist_bins(cfg: dict) -> np.ndarray: save_bias=save_bias_by_percentile_plot, save_mae=save_mae_by_percentile_plot, save_spread=save_spread_by_percentile_plot, + save_fbi=save_fbi_by_percentile_plot, ) From ac2d64a71a7e829f3ef4b60522a29976af6582ae Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Thu, 4 Jun 2026 16:16:13 +0200 Subject: [PATCH 21/51] update bins --- src/hirad/eval/bias_by_percentile_precip.py | 5 +++-- src/hirad/eval/bias_by_percentile_temp.py | 6 ++++-- src/hirad/eval/bias_by_percentile_wind.py | 6 ++++-- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/hirad/eval/bias_by_percentile_precip.py b/src/hirad/eval/bias_by_percentile_precip.py index a2b5f84a..5b16852f 100644 --- a/src/hirad/eval/bias_by_percentile_precip.py +++ b/src/hirad/eval/bias_by_percentile_precip.py @@ -18,8 +18,8 @@ def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) - """Apply logit x-axis with labelled percentile ticks (and a mm/h secondary axis).""" ax.set_xscale('logit') ax.set_xlim(0.5, frac[-1]) - tick_fracs = [0.50, 0.75, 0.90, 0.99, 0.999, 0.9999] - tick_labels = ['50', '75', '90', '99', '99.9', '99.99'] + tick_fracs = [0.50, 0.75, 0.90, 0.99, 0.999, 0.9999, 0.99999] + tick_labels = ['50', '75', '90', '99', '99.9', '99.99', '99.999'] ax.set_xticks(tick_fracs) ax.set_xticklabels(tick_labels) ax.grid(True, alpha=0.3, which='both') @@ -112,6 +112,7 @@ def _make_hist_bins(cfg: dict) -> np.ndarray: np.linspace(90.0, 99.0, 90), np.linspace(99.0, 99.9, 45), np.linspace(99.9, 99.99, 20), + np.linspace(99.99, 99.999, 10), ])), resolve_channels=_resolve_channels, make_hist_bins=_make_hist_bins, diff --git a/src/hirad/eval/bias_by_percentile_temp.py b/src/hirad/eval/bias_by_percentile_temp.py index 48dff6ce..5eb430e8 100644 --- a/src/hirad/eval/bias_by_percentile_temp.py +++ b/src/hirad/eval/bias_by_percentile_temp.py @@ -20,8 +20,8 @@ def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) - ax.set_xscale('logit') xlim_right = frac[-1] + 1e-9 ax.set_xlim(frac[0], xlim_right) - tick_fracs = [0.0001, 0.001, 0.01, 0.1, 0.50, 0.90, 0.99, 0.999, 0.9999] - tick_labels = ['0.01', '0.1', '1', '10', '50', '90', '99', '99.9', '99.99'] + tick_fracs = [0.00001, 0.0001, 0.001, 0.01, 0.1, 0.50, 0.90, 0.99, 0.999, 0.9999, 0.99999] + tick_labels = ['0.001', '0.01', '0.1', '1', '10', '50', '90', '99', '99.9', '99.99', '99.999'] # only show ticks within our data range valid_ticks = [(f, l) for f, l in zip(tick_fracs, tick_labels) if frac[0] <= f <= xlim_right] @@ -125,6 +125,7 @@ def _make_hist_bins(cfg: dict) -> np.ndarray: spread_ylabel='Ensemble Spread [°C]', fbi_ylabel='FBI [-]', percentile_values=np.unique(np.concatenate([ + np.linspace(0.001, 0.01, 10), np.linspace(0.01, 0.1, 10), np.linspace(0.1, 1.0, 10), np.linspace(1.0, 10.0, 10), @@ -132,6 +133,7 @@ def _make_hist_bins(cfg: dict) -> np.ndarray: np.linspace(90.0, 99.0, 90), np.linspace(99.0, 99.9, 45), np.linspace(99.9, 99.99, 20), + np.linspace(99.99, 99.999, 10), ])), resolve_channels=_resolve_channels, make_hist_bins=_make_hist_bins, diff --git a/src/hirad/eval/bias_by_percentile_wind.py b/src/hirad/eval/bias_by_percentile_wind.py index 71a845c3..841c9896 100644 --- a/src/hirad/eval/bias_by_percentile_wind.py +++ b/src/hirad/eval/bias_by_percentile_wind.py @@ -23,8 +23,8 @@ def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) - ax.set_xscale('logit') xlim_right = frac[-1] + 1e-9 ax.set_xlim(frac[0], xlim_right) - tick_fracs = [0.0001, 0.001, 0.01, 0.1, 0.50, 0.90, 0.99, 0.999, 0.9999] - tick_labels = ['0.01', '0.1', '1', '10', '50', '90', '99', '99.9', '99.99'] + tick_fracs = [0.00001, 0.0001, 0.001, 0.01, 0.1, 0.50, 0.90, 0.99, 0.999, 0.9999, 0.99999] + tick_labels = ['0.001', '0.01', '0.1', '1', '10', '50', '90', '99', '99.9', '99.99', '99.999'] # only show ticks within our data range valid_ticks = [(f, l) for f, l in zip(tick_fracs, tick_labels) if frac[0] <= f <= xlim_right] @@ -130,6 +130,7 @@ def _make_hist_bins(cfg: dict) -> np.ndarray: spread_ylabel='Ensemble Spread [m/s]', fbi_ylabel='FBI [-]', percentile_values=np.unique(np.concatenate([ + np.linspace(0.001, 0.01, 10), np.linspace(0.01, 0.1, 10), np.linspace(0.1, 1.0, 10), np.linspace(1.0, 10.0, 10), @@ -137,6 +138,7 @@ def _make_hist_bins(cfg: dict) -> np.ndarray: np.linspace(90.0, 99.0, 90), np.linspace(99.0, 99.9, 45), np.linspace(99.9, 99.99, 20), + np.linspace(99.99, 99.999, 10), ])), resolve_channels=_resolve_channels, make_hist_bins=_make_hist_bins, From 5add4f2d90b7446e6899b882906811b2a85d829a Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Thu, 4 Jun 2026 16:21:29 +0200 Subject: [PATCH 22/51] change formulation for the FBI --- src/hirad/eval/bias_by_percentile_common.py | 120 +++++++++++++++----- src/hirad/eval/bias_by_percentile_temp.py | 17 --- 2 files changed, 90 insertions(+), 47 deletions(-) diff --git a/src/hirad/eval/bias_by_percentile_common.py b/src/hirad/eval/bias_by_percentile_common.py index 36ad9dff..8514c649 100644 --- a/src/hirad/eval/bias_by_percentile_common.py +++ b/src/hirad/eval/bias_by_percentile_common.py @@ -159,6 +159,44 @@ def per_point_quantiles(pp_counts: np.ndarray, bin_edges: np.ndarray, return result +def per_point_exceedance(pp_counts: np.ndarray, bin_edges: np.ndarray, + thresholds: np.ndarray, + block_size: int = 8192) -> np.ndarray: + """Estimate per-row exceedance probabilities P(value > threshold). + + ``thresholds`` is ``(n_land, P)`` (one threshold per grid point and + percentile, e.g. the target's per-point quantiles). Returns an array of the + same shape with the fraction of this mode's mass strictly above each + threshold, read from the per-grid-point histogram CDF. This drives the + frequency bias index (FBI), a ratio of exceedance frequencies. + """ + n_land, P = thresholds.shape + n_bins = pp_counts.shape[1] + result = np.empty((n_land, P), dtype=np.float32) + edges_upper = bin_edges[1:].astype(np.float32) + + for start in range(0, n_land, block_size): + end = min(start + block_size, n_land) + blk = pp_counts[start:end] + B = end - start + + cdf = np.cumsum(blk, axis=1, dtype=np.float64) + totals = cdf[:, -1:] + cdf /= np.maximum(totals, 1.0) + + # Bin whose upper edge first reaches the threshold; the CDF up to and + # including that bin is the non-exceedance probability. + thr = thresholds[start:end] + bin_idx = np.empty((B, P), dtype=np.intp) + for j in range(P): + bin_idx[:, j] = np.searchsorted(edges_upper, thr[:, j], side='left') + np.clip(bin_idx, 0, n_bins - 1, out=bin_idx) + non_exc = np.take_along_axis(cdf, bin_idx, axis=1) + result[start:end] = np.clip(1.0 - non_exc, 0.0, 1.0) + + return result + + def compute_quantiles( det_counts: dict, member_counts: list | None, @@ -169,33 +207,48 @@ def compute_quantiles( frac_percentiles: np.ndarray, n_workers: int, ) -> tuple: - """Compute per-point quantiles for every mode in parallel, freeing counts as we go.""" + """Compute per-point quantiles for every mode in parallel, freeing counts as we go. + + Alongside the per-point quantiles, each prediction mode / member also gets + its per-point exceedance probability evaluated at the target's per-point + quantiles (the thresholds), which feeds the frequency bias index. + """ members = member_counts if (has_ensemble and member_counts is not None and n_members is not None) else [] with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as pool: - def submit(counts): + def submit_q(counts): return pool.submit(per_point_quantiles, counts, hist_bins, frac_percentiles) - fut_target = submit(det_counts['target']) - fut_det = {mode: submit(det_counts[mode]) for mode in active_det_modes} - fut_members = [submit(c) for c in members] + def submit_exc(counts, thresholds): + return pool.submit(per_point_exceedance, counts, hist_bins, thresholds) - target_q = fut_target.result() + # Target quantiles first: they are the thresholds for every exceedance. + target_q = per_point_quantiles(det_counts['target'], hist_bins, frac_percentiles) del det_counts['target'] - det_results = {mode: fut_det[mode].result() for mode in active_det_modes} + + fut_det_q = {mode: submit_q(det_counts[mode]) for mode in active_det_modes} + fut_det_exc = {mode: submit_exc(det_counts[mode], target_q) for mode in active_det_modes} + fut_member_q = [submit_q(c) for c in members] + fut_member_exc = [submit_exc(c, target_q) for c in members] + + det_results = {mode: fut_det_q[mode].result() for mode in active_det_modes} + det_exceedance = {mode: fut_det_exc[mode].result() for mode in active_det_modes} for mode in active_det_modes: del det_counts[mode] - member_qs = [f.result() for f in fut_members] + member_qs = [f.result() for f in fut_member_q] + member_exc = [f.result() for f in fut_member_exc] members.clear() - return target_q, det_results, member_qs + return target_q, det_results, det_exceedance, member_qs, member_exc -def _ensemble_stats(member_qs, target_q): +def _ensemble_stats(member_qs, member_exc, target_q, frac_percentiles): """Compute ensemble spread, and the MAE / bias / FBI plot entries. All quantities use the same local-then-averaged estimator: a statistic is formed per grid point across ensemble members, then averaged over land. + FBI is the frequency bias index: per grid point and percentile ``p`` it is + ``P(pred > q_target(p)) / P(target > q_target(p)) = P(pred > q_target(p)) / (1 - p)``. """ n_members = len(member_qs) sum_q = np.zeros(target_q.shape, dtype=np.float64) @@ -205,17 +258,17 @@ def _ensemble_stats(member_qs, target_q): sum_fbi = np.zeros(target_q.shape, dtype=np.float64) sumsq_fbi = np.zeros(target_q.shape, dtype=np.float64) target_q_f = target_q.astype(np.float64) - # FBI is a quantile ratio q_pred / q_target; guard against zero targets. - target_q_safe = np.where(target_q_f != 0.0, target_q_f, np.nan) + # Target exceedance probability P(target > q_target(p)) = 1 - p; guard p -> 1. + target_exc = np.maximum(1.0 - frac_percentiles.astype(np.float64), 1e-12)[None, :] - for qm in member_qs: + for qm, exc in zip(member_qs, member_exc): qm_f = qm.astype(np.float64) sum_q += qm sumsq_q += qm_f ** 2 ae = np.abs(qm_f - target_q_f) sum_ae += ae sumsq_ae += ae ** 2 - fbi = qm_f / target_q_safe + fbi = exc.astype(np.float64) / target_exc sum_fbi += fbi sumsq_fbi += fbi ** 2 @@ -334,11 +387,9 @@ class BiasByPercentileSpec: bias_title: str mae_title: str spread_title: str - fbi_title: str bias_ylabel: str mae_ylabel: str spread_ylabel: str - fbi_ylabel: str percentile_values: np.ndarray resolve_channels: Callable[[dict], tuple] # indices -> (ch_out, ch_in); raises ValueError make_hist_bins: Callable[[dict], np.ndarray] @@ -346,7 +397,10 @@ class BiasByPercentileSpec: save_bias: Callable save_mae: Callable save_spread: Callable - save_fbi: Callable + # FBI is optional: leave ``save_fbi`` as ``None`` to skip the plot entirely. + save_fbi: Callable | None = None + fbi_title: str | None = None + fbi_ylabel: str | None = None # Combines the (scaled) per-channel flat fields into the plotted scalar. # Defaults to the single-channel identity; wind speed uses ``hypot``. reduce_fn: Callable[[list], np.ndarray] = lambda flats: flats[0] @@ -423,14 +477,16 @@ def run_bias_by_percentile(cfg: dict, spec: BiasByPercentileSpec) -> None: n_tasks = 1 + len(active_det_modes) + (n_members if has_ensemble else 0) n_quant_workers = cfg.get("n_quant_workers", n_tasks) - target_q, det_results, member_qs = compute_quantiles( + target_q, det_results, det_exceedance, member_qs, member_exc = compute_quantiles( det_counts, member_counts, n_members, active_det_modes, has_ensemble, hist_bins, frac_percentiles, n_quant_workers, ) target_mean_q = target_q.mean(axis=0) - # FBI is a per-point quantile ratio q_pred / q_target; guard zero targets. - target_q_safe = np.where(target_q != 0.0, target_q, np.nan) + want_fbi = spec.save_fbi is not None + # FBI is the frequency bias index: P(pred > q_target(p)) / P(target > q_target(p)), + # where P(target > q_target(p)) = 1 - p by construction. Guard p -> 1. + target_exc = np.maximum(1.0 - frac_percentiles, 1e-12) bias_data: dict = {} mae_data: dict = {} @@ -442,20 +498,23 @@ def run_bias_by_percentile(cfg: dict, spec: BiasByPercentileSpec) -> None: if mode not in det_results: continue pred_q = det_results.pop(mode) + pred_exc = det_exceedance.pop(mode) bias_data[mode] = pred_q.mean(axis=0) - target_mean_q mae_data[mode] = np.abs(pred_q - target_q).mean(axis=0) - fbi_data[mode] = np.nanmean(pred_q / target_q_safe, axis=0) + if want_fbi: + fbi_data[mode] = np.nanmean(pred_exc / target_exc[None, :], axis=0) labels.append(label) colors.append(color) spread = None if has_ensemble: spread, mae_entry, bias_entry, fbi_entry = _ensemble_stats( - member_qs, target_q, + member_qs, member_exc, target_q, frac_percentiles, ) bias_data['predictions'] = bias_entry mae_data['predictions'] = mae_entry - fbi_data['predictions'] = fbi_entry + if want_fbi: + fbi_data['predictions'] = fbi_entry labels.append(ENSEMBLE_LABEL) colors.append(ENSEMBLE_COLOR) @@ -478,13 +537,14 @@ def run_bias_by_percentile(cfg: dict, spec: BiasByPercentileSpec) -> None: ) logger.info(f"MAE-by-percentile plot saved: {fn_mae}") - fn_fbi = output_path / f'{spec.output_prefix}_fbi_by_percentile{suffix}.png' - spec.save_fbi( - fbi_data, percentile_values, labels, colors, - title=spec.fbi_title + title_note, xlabel='Percentile', ylabel=spec.fbi_ylabel, - out_path=fn_fbi, mean_q=target_mean_q, - ) - logger.info(f"FBI-by-percentile plot saved: {fn_fbi}") + if want_fbi: + fn_fbi = output_path / f'{spec.output_prefix}_fbi_by_percentile{suffix}.png' + spec.save_fbi( + fbi_data, percentile_values, labels, colors, + title=spec.fbi_title + title_note, xlabel='Percentile', ylabel=spec.fbi_ylabel, + out_path=fn_fbi, mean_q=target_mean_q, + ) + logger.info(f"FBI-by-percentile plot saved: {fn_fbi}") if spread is not None: fn_spread = output_path / f'{spec.output_prefix}_spread_by_percentile{suffix}.png' diff --git a/src/hirad/eval/bias_by_percentile_temp.py b/src/hirad/eval/bias_by_percentile_temp.py index 5eb430e8..78935a33 100644 --- a/src/hirad/eval/bias_by_percentile_temp.py +++ b/src/hirad/eval/bias_by_percentile_temp.py @@ -82,20 +82,6 @@ def save_spread_by_percentile_plot(spread, percentile_values, xlabel, ylabel, title, out_path) -def save_fbi_by_percentile_plot(fbi_data_dict, percentile_values, labels, colors, - title, xlabel, ylabel, out_path, mean_q=None) -> None: - """Save a frequency-bias-index-by-percentile figure (linear y-axis, ratio around 1).""" - _, ax, frac = new_percentile_axes(percentile_values) - all_vals = plot_dict_curves(ax, frac, fbi_data_dict, labels, colors, lower_clip=0) - ax.axhline(1.0, color='black', linewidth=0.8, linestyle='--') - if all_vals: - ymax = float(max(np.nanmax(v) for v in all_vals)) * 1.1 - if ymax > 1.0: - ax.set_ylim(0, ymax) - finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, - xlabel, ylabel, title, out_path) - - def _resolve_channels(indices: dict) -> tuple: # Temperature channel: try '2t' first, then 't2m' t2m_out = indices['output'].get('2t', indices['output'].get('t2m')) @@ -119,11 +105,9 @@ def _make_hist_bins(cfg: dict) -> np.ndarray: bias_title='2m Temperature Bias Over Land', mae_title='2m Temperature MAE Over Land', spread_title='2m Temperature Ensemble Spread Over Land', - fbi_title='2m Temperature Frequency Bias Index Over Land', bias_ylabel='Bias [°C]', mae_ylabel='MAE [°C]', spread_ylabel='Ensemble Spread [°C]', - fbi_ylabel='FBI [-]', percentile_values=np.unique(np.concatenate([ np.linspace(0.001, 0.01, 10), np.linspace(0.01, 0.1, 10), @@ -143,7 +127,6 @@ def _make_hist_bins(cfg: dict) -> np.ndarray: save_bias=save_bias_by_percentile_plot, save_mae=save_mae_by_percentile_plot, save_spread=save_spread_by_percentile_plot, - save_fbi=save_fbi_by_percentile_plot, ) From a9e7b0fc39013423b840e3c0285fcca5c8283094 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Thu, 4 Jun 2026 16:40:35 +0200 Subject: [PATCH 23/51] add smoothed --- src/hirad/eval_precip.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/src/hirad/eval_precip.sh b/src/hirad/eval_precip.sh index aacdae03..0f4241a4 100644 --- a/src/hirad/eval_precip.sh +++ b/src/hirad/eval_precip.sh @@ -14,6 +14,7 @@ CMDS=( "python src/hirad/eval/probability_of_exceedance.py" # QQ "python -m hirad.eval.bias_by_percentile_precip" + "python -m hirad.eval.bias_by_percentile_precip_smoothed" # Maps "python src/hirad/eval/map_precip_stats.py" "python -m hirad.eval.diurnal_cycle_precip_maps" From 3e6436c99f5502d0a8decd398f4d1fde32a5307c Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Tue, 9 Jun 2026 10:20:51 +0200 Subject: [PATCH 24/51] smoothed wind plot --- .../eval/bias_by_percentile_wind_smoothed.py | 24 +++++++++++++++++++ src/hirad/eval_wind.sh | 1 + 2 files changed, 25 insertions(+) create mode 100644 src/hirad/eval/bias_by_percentile_wind_smoothed.py diff --git a/src/hirad/eval/bias_by_percentile_wind_smoothed.py b/src/hirad/eval/bias_by_percentile_wind_smoothed.py new file mode 100644 index 00000000..65020728 --- /dev/null +++ b/src/hirad/eval/bias_by_percentile_wind_smoothed.py @@ -0,0 +1,24 @@ +"""Smoothed 10 m wind speed bias / MAE / spread / FBI by percentile plots. + +Applies the same Gaussian low-pass filter to target and prediction fields before +the percentile histograms are accumulated, mitigating double-penalty effects in +the wind speed tails. +""" +from hirad.eval.bias_by_percentile_common import run_bias_by_percentile +from hirad.eval.bias_by_percentile_wind import SPEC +from hirad.eval.eval_utils import parse_eval_cli + + +SMOOTHING_SIGMA_KM = 20.0 +GRID_RES_KM = 1.0 + + +def main(cfg: dict) -> None: + cfg = dict(cfg) + cfg['smoothing_sigma_km'] = SMOOTHING_SIGMA_KM + cfg.setdefault('grid_res_km', GRID_RES_KM) + run_bias_by_percentile(cfg, SPEC) + + +if __name__ == '__main__': + main(parse_eval_cli()) diff --git a/src/hirad/eval_wind.sh b/src/hirad/eval_wind.sh index ec5b4966..8dc2624c 100644 --- a/src/hirad/eval_wind.sh +++ b/src/hirad/eval_wind.sh @@ -12,6 +12,7 @@ CMDS=( "python src/hirad/eval/probability_of_exceedance_wind.py" # QQ "python -m hirad.eval.bias_by_percentile_wind" + "python -m hirad.eval.bias_by_percentile_wind_smoothed" # Maps "python src/hirad/eval/map_wind_stats.py" ) From 3930d6977188fe96dee267a0c183f9a963ba7b51 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Tue, 9 Jun 2026 10:21:16 +0200 Subject: [PATCH 25/51] change x-axis --- src/hirad/eval/bias_by_percentile_precip.py | 25 ++++++++++++--------- src/hirad/eval/bias_by_percentile_wind.py | 25 +++++++++++---------- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/src/hirad/eval/bias_by_percentile_precip.py b/src/hirad/eval/bias_by_percentile_precip.py index 5b16852f..d13ab4f0 100644 --- a/src/hirad/eval/bias_by_percentile_precip.py +++ b/src/hirad/eval/bias_by_percentile_precip.py @@ -14,24 +14,26 @@ from hirad.eval.eval_utils import parse_eval_cli -def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) -> None: +def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None, + xlim_left: float = 0.5) -> None: """Apply logit x-axis with labelled percentile ticks (and a mm/h secondary axis).""" ax.set_xscale('logit') - ax.set_xlim(0.5, frac[-1]) - tick_fracs = [0.50, 0.75, 0.90, 0.99, 0.999, 0.9999, 0.99999] - tick_labels = ['50', '75', '90', '99', '99.9', '99.99', '99.999'] - ax.set_xticks(tick_fracs) - ax.set_xticklabels(tick_labels) + ax.set_xlim(xlim_left, frac[-1]) + all_tick_fracs = [0.10, 0.25, 0.50, 0.75, 0.90, 0.99, 0.999, 0.9999, 0.99999] + all_tick_labels = ['10', '25', '50', '75', '90', '99', '99.9', '99.99', '99.999'] + valid_ticks = [(f, l) for f, l in zip(all_tick_fracs, all_tick_labels) + if xlim_left <= f <= frac[-1]] + ax.set_xticks([f for f, _ in valid_ticks]) + ax.set_xticklabels([l for _, l in valid_ticks]) ax.grid(True, alpha=0.3, which='both') if mean_q is not None: ax2 = ax.twiny() ax2.set_xscale('logit') - ax2.set_xlim(0.5, frac[-1]) - # Place ticks at fixed "nice" mm/h values, positioning them by inverting mean_q + ax2.set_xlim(xlim_left, frac[-1]) nice_mmh = np.array([0.01, 0.1, 1.0, 10.0, 100.0]) tick_positions = np.interp(nice_mmh, mean_q, frac) - valid = (tick_positions > 0.5) & (tick_positions < frac[-1]) + valid = (tick_positions > xlim_left) & (tick_positions < frac[-1]) ax2.set_xticks(tick_positions[valid]) ax2.set_xticklabels([f'{v:g}' for v in nice_mmh[valid]]) ax2.set_xlabel('Mean target [mm/h]') @@ -81,8 +83,9 @@ def save_fbi_by_percentile_plot(fbi_data_dict, percentile_values, labels, colors ymax = float(max(np.nanmax(v) for v in all_vals)) * 1.1 if ymax > 1.0: ax.set_ylim(0, ymax) - finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, - xlabel, ylabel, title, out_path) + finalize_percentile_plot(ax, frac, + lambda ax_, frac_, mq: _apply_logit_xaxis(ax_, frac_, mq, xlim_left=0.10), + mean_q, xlabel, ylabel, title, out_path) def _resolve_channels(indices: dict) -> tuple: diff --git a/src/hirad/eval/bias_by_percentile_wind.py b/src/hirad/eval/bias_by_percentile_wind.py index 841c9896..4f77298e 100644 --- a/src/hirad/eval/bias_by_percentile_wind.py +++ b/src/hirad/eval/bias_by_percentile_wind.py @@ -18,16 +18,18 @@ from hirad.eval.eval_utils import parse_eval_cli -def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) -> None: +def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None, + xlim_left: float | None = None) -> None: """Apply logit x-axis with labelled percentile ticks (and an m/s secondary axis).""" ax.set_xscale('logit') + if xlim_left is None: + xlim_left = float(frac[0]) xlim_right = frac[-1] + 1e-9 - ax.set_xlim(frac[0], xlim_right) + ax.set_xlim(xlim_left, xlim_right) tick_fracs = [0.00001, 0.0001, 0.001, 0.01, 0.1, 0.50, 0.90, 0.99, 0.999, 0.9999, 0.99999] tick_labels = ['0.001', '0.01', '0.1', '1', '10', '50', '90', '99', '99.9', '99.99', '99.999'] - # only show ticks within our data range valid_ticks = [(f, l) for f, l in zip(tick_fracs, tick_labels) - if frac[0] <= f <= xlim_right] + if xlim_left <= f <= xlim_right] ax.set_xticks([f for f, _ in valid_ticks]) ax.set_xticklabels([l for _, l in valid_ticks]) ax.grid(True, alpha=0.3, which='both') @@ -35,13 +37,11 @@ def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) - if mean_q is not None: ax2 = ax.twiny() ax2.set_xscale('logit') - ax2.set_xlim(frac[0], xlim_right) - # Wind speed spans only a few m/s, so integer-only ticks leave the - # compressed tail unlabelled; even_value_ticks picks a nice sub-unit step - # and spreads labels evenly across the whole logit axis. + ax2.set_xlim(xlim_left, xlim_right) tick_positions, tick_speeds = even_value_ticks(frac, mean_q) - ax2.set_xticks(tick_positions) - ax2.set_xticklabels([f'{v:g}' for v in tick_speeds]) + valid = (tick_positions >= xlim_left) & (tick_positions <= xlim_right) + ax2.set_xticks(tick_positions[valid]) + ax2.set_xticklabels([f'{v:g}' for v in tick_speeds[valid]]) ax2.set_xlabel('Mean target [m/s]') @@ -95,8 +95,9 @@ def save_fbi_by_percentile_plot(fbi_data_dict, percentile_values, labels, colors ymax = float(max(np.nanmax(v) for v in all_vals)) * 1.1 if ymax > 1.0: ax.set_ylim(0, ymax) - finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, - xlabel, ylabel, title, out_path) + finalize_percentile_plot(ax, frac, + lambda ax_, frac_, mq: _apply_logit_xaxis(ax_, frac_, mq, xlim_left=0.10), + mean_q, xlabel, ylabel, title, out_path) def _resolve_channels(indices: dict) -> tuple: From ae7c300b195044e714800434294d897f3b3f307f Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Tue, 9 Jun 2026 10:21:43 +0200 Subject: [PATCH 26/51] reduce memory --- .../diurnal_cycle_precip_high_percentiles.py | 160 +++++++++++++----- 1 file changed, 113 insertions(+), 47 deletions(-) diff --git a/src/hirad/eval/diurnal_cycle_precip_high_percentiles.py b/src/hirad/eval/diurnal_cycle_precip_high_percentiles.py index 8f0eb612..5dc2ef62 100644 --- a/src/hirad/eval/diurnal_cycle_precip_high_percentiles.py +++ b/src/hirad/eval/diurnal_cycle_precip_high_percentiles.py @@ -17,7 +17,6 @@ from hirad.eval.eval_utils import get_channel_indices, load_generation_setup, load_land_sea_mask, parse_eval_cli, resolve_ts_dir - def save_plot(hours, lines, labels, ylabel, title, out_path): Path(out_path).parent.mkdir(parents=True, exist_ok=True) plt.figure(figsize=(8,4)) @@ -78,59 +77,126 @@ def main(cfg: dict): pct_mean = {key: {} for _, key, _ in percentile_configs} pct_std = {key: {} for _, key, _ in percentile_configs} - # -- Process target and baseline -- + conv_factor = cfg.get("conv_factor") + land_idx = land_bool.values # 1D boolean mask over flattened (lat, lon) + n_land = int(land_idx.sum()) + quantiles = np.array([q for q, _, _ in percentile_configs]) + logger.info(f"Land pixels: {n_land} / {land_idx.size} ({100 * n_land / land_idx.size:.1f}%)") + + # Group timesteps by hour-of-day so we never hold all timesteps in memory at once. + times_by_hour = {} + for ts in times: + hour = datetime.strptime(ts, "%Y%m%d-%H%M").hour + times_by_hour.setdefault(hour, []).append(ts) + sorted_hours = sorted(times_by_hour) + counts_per_hour = {h: len(times_by_hour[h]) for h in sorted_hours} + logger.info(f"Grouped {len(times)} timesteps into {len(sorted_hours)} hours; timesteps/hour: {counts_per_hour}") + + def land_values(arr): + """Flatten a (lat, lon) array and keep only land pixels as float32.""" + return np.asarray(arr, dtype=np.float32).reshape(-1)[land_idx] + + # -- Process target, baseline and regression-prediction -- + # For each hour we collect only land pixels across that hour's timesteps, then take + # the spatial-mean of the per-hour quantiles. Memory scales with one hour's data. for mode in ['target', 'baseline', 'regression-prediction']: logger.info(f"Processing mode: {mode}") - - data_list = [] - try: - for ts in times: - data = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode in ['target','regression-prediction'] else tp_in] * cfg.get("conv_factor") - data_list.append(data) - except: - logger.error(f"Error loading data for mode {mode}. Skipping.") + ch = tp_out if mode in ['target', 'regression-prediction'] else tp_in + + per_hour_pct = {key: [] for _, key, _ in percentile_configs} + failed = False + for hi, hour in enumerate(sorted_hours): + n_ts = counts_per_hour[hour] + logger.info( + f"[{mode}] hour {hour:02d} ({hi + 1}/{len(sorted_hours)}): " + f"loading {n_ts} timesteps" + ) + hour_vals = [] + try: + for ts in times_by_hour[hour]: + data = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-{mode}", weights_only=False)[ch] + hour_vals.append(land_values(data) * conv_factor) + except Exception as exc: + logger.error(f"Error loading data for mode {mode} at hour {hour:02d}: {exc!r}. Skipping mode.") + failed = True + break + + stacked = np.stack(hour_vals, axis=0) # [n_times_this_hour, n_land] + del hour_vals + logger.info( + f"[{mode}] hour {hour:02d}: stacked array {stacked.shape}" + ) + # quantile over time, then mean over space -> one value per quantile for this hour + q_vals = np.nanquantile(stacked, quantiles, axis=0) # [n_q, n_land] + del stacked + q_means = np.nanmean(q_vals, axis=1) # [n_q] + for (_, key, _), val in zip(percentile_configs, q_means): + per_hour_pct[key].append(val) + + if failed: continue - da = xr.DataArray( - np.stack(data_list, axis=0), - dims=['time', 'lat', 'lon'], - coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times], - 'lat': land_mask.coords['lat'], 'lon': land_mask.coords['lon']} - ) - - # Select only land pixels to avoid all-NaN slices in quantile - da_land = da.stack(space=('lat', 'lon')).isel(space=land_bool.values) - - for q, key, _ in percentile_configs: - hourly_pct = da_land.groupby('time.hour').quantile(q, dim='time') - pct_mean[key][mode] = hourly_pct.mean(dim='space') + for _, key, _ in percentile_configs: + pct_mean[key][mode] = xr.DataArray( + np.array(per_hour_pct[key]), dims=['hour'], coords={'hour': sorted_hours} + ) + logger.info(f"Finished mode: {mode}") # -- Predictions: compute per hour per member, then mean+std across members -- logger.info("Processing predictions") - - # Load all prediction data at once into xarray - pred_data_list = [] - for ts in times: - preds = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-predictions", weights_only=False) * cfg.get("conv_factor") # [n_members, n_channels, lat, lon] - tp_data = preds[:, tp_out] # [n_members, lat, lon] - tp_da = xr.DataArray(tp_data, dims=['member', 'lat', 'lon'], - coords={'lat': land_mask.coords['lat'], 'lon': land_mask.coords['lon']}) - pred_data_list.append(tp_da) - - pred_da = xr.concat(pred_data_list, dim='time') # [member, time, lat, lon] - pred_da = pred_da.assign_coords({ - 'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times] - }) - pred_da = pred_da.transpose('member', 'time', 'lat', 'lon') - - # Select only land pixels to avoid all-NaN slices in quantile - pred_da_land = pred_da.stack(space=('lat', 'lon')).isel(space=land_bool.values) - - for q, key, label in percentile_configs: - logger.info(f'Calculating {label} percentile for predictions') - hourly_pct_by_member = pred_da_land.groupby('time.hour').quantile(q, dim='time').mean(dim='space') - pct_mean[key]['prediction'] = hourly_pct_by_member.mean(dim='member') - pct_std[key]['prediction'] = hourly_pct_by_member.std(dim='member') + + pred_hour_mean = {key: [] for _, key, _ in percentile_configs} + pred_hour_std = {key: [] for _, key, _ in percentile_configs} + + for hi, hour in enumerate(sorted_hours): + n_ts = counts_per_hour[hour] + logger.info( + f"[predictions] hour {hour:02d} ({hi + 1}/{len(sorted_hours)}): " + f"loading {n_ts} timesteps" + ) + # Accumulate land pixels per member: member -> list of [n_land] arrays over this hour's times + member_vals = None + for ts in times_by_hour[hour]: + preds = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-predictions", weights_only=False) # [n_members, n_channels, lat, lon] + tp_data = np.asarray(preds[:, tp_out], dtype=np.float32) * conv_factor # [n_members, lat, lon] + n_members = tp_data.shape[0] + if member_vals is None: + member_vals = [[] for _ in range(n_members)] + flat = tp_data.reshape(n_members, -1)[:, land_idx] # [n_members, n_land] + for m in range(n_members): + member_vals[m].append(flat[m]) + del preds, tp_data, flat + + n_members = len(member_vals) + logger.info( + f"[predictions] hour {hour:02d}: {n_members} members x {n_ts} timesteps" + ) + + # Per-member: quantile over this hour's timesteps, then spatial mean -> [n_q] per member + per_member_q = [] + for vals in member_vals: + stacked = np.stack(vals, axis=0) # [n_times_this_hour, n_land] + q_vals = np.nanquantile(stacked, quantiles, axis=0) # [n_q, n_land] + del stacked + per_member_q.append(np.nanmean(q_vals, axis=1)) # [n_q] + per_member_arr = np.stack(per_member_q, axis=0) # [n_members, n_q] + del member_vals + + q_mean_over_members = per_member_arr.mean(axis=0) # [n_q] + q_std_over_members = per_member_arr.std(axis=0) # [n_q] + for i, (_, key, _) in enumerate(percentile_configs): + pred_hour_mean[key].append(q_mean_over_members[i]) + pred_hour_std[key].append(q_std_over_members[i]) + + logger.info("Finished predictions") + + for _, key, _ in percentile_configs: + pct_mean[key]['prediction'] = xr.DataArray( + np.array(pred_hour_mean[key]), dims=['hour'], coords={'hour': sorted_hours} + ) + pct_std[key]['prediction'] = xr.DataArray( + np.array(pred_hour_std[key]), dims=['hour'], coords={'hour': sorted_hours} + ) # Prepare cyclic lists for plotting def cycle_fn(x): From e12c4f704b51fbdf47ccd35afaf5005ac9783bed Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Tue, 9 Jun 2026 11:48:55 +0200 Subject: [PATCH 27/51] plot wethours for more tresholds --- .../diurnal_cycle_precip_mean_wet-hour.py | 88 +++++++++++++++---- 1 file changed, 73 insertions(+), 15 deletions(-) diff --git a/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py b/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py index 1bc62092..f8d2ce75 100644 --- a/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py +++ b/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py @@ -10,6 +10,9 @@ from hirad.eval.eval_utils import concat_and_group_diurnal, get_channel_indices, load_generation_setup, load_land_sea_mask, parse_eval_cli, resolve_ts_dir +ALLHOUR_THRESHOLDS = [0.1, 1.0, 10.0, 100.0] # mm/h + + def save_plot(hour, means, stds, labels, ylabel, title, out_path): hrs = np.concatenate([hour.values, [24]]) plt.figure(figsize=(8,4)) @@ -31,6 +34,37 @@ def save_plot(hour, means, stds, labels, ylabel, title, out_path): plt.savefig(out_path) plt.close() + +def save_allhour_wethours_plot(thresholds, mode_data, labels, title, out_path): + """Bar chart of all-hour wet-hour fraction (%) per threshold. + + mode_data: list of dicts {thr: (mean_pct, std_pct_or_None)} + """ + n_thr = len(thresholds) + n_modes = len(mode_data) + x = np.arange(n_thr) + width = 0.7 / n_modes + fig, ax = plt.subplots(figsize=(8, 5)) + for i, (data, label) in enumerate(zip(mode_data, labels)): + offset = (i - n_modes / 2 + 0.5) * width + means = [data[thr][0] for thr in thresholds] + stds = [data[thr][1] if data[thr][1] is not None else 0.0 for thr in thresholds] + ax.bar(x + offset, means, width, label=label, alpha=0.8) + if any(s > 0 for s in stds): + ax.errorbar(x + offset, means, yerr=stds, fmt='none', color='black', capsize=3) + ax.set_xticks(x) + ax.set_xticklabels([f'>{thr:g} mm/h' for thr in thresholds]) + ax.set_ylabel('Wet-Hour Fraction [%]') + ax.set_yscale('log') + ax.set_title(title) + ax.legend() + ax.grid(True, axis='y', alpha=0.3, which='both') + plt.tight_layout() + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + plt.savefig(out_path) + plt.close() + + def main(cfg: dict): # Setup logging logging.basicConfig(level=logging.INFO) @@ -61,7 +95,10 @@ def main(cfg: dict): # Prepare lists to collect DataArrays target_precip, baseline_precip, pred_precip, mean_pred_precip = [], [], [], [] - target_wet, baseline_wet, pred_wet, mean_pred_wet = [], [], [], [] + wet_target = {thr: [] for thr in ALLHOUR_THRESHOLDS} + wet_baseline = {thr: [] for thr in ALLHOUR_THRESHOLDS} + wet_pred = {thr: [] for thr in ALLHOUR_THRESHOLDS} + wet_regpred = {thr: [] for thr in ALLHOUR_THRESHOLDS} # Collect data for idx, ts in enumerate(times, 1): @@ -95,12 +132,13 @@ def main(cfg: dict): if mean_pred is not None: mean_pred_precip.append(da_mean_pred.mean(dim=("lat","lon")).assign_coords(time=dt)) - # Wet-hour fraction, i.e., freq(precip) > wet_threshold - target_wet.append(((da_target / 24 > cfg.get("wet_threshold")).mean().assign_coords(time=dt))) - baseline_wet.append(((da_baseline / 24 > cfg.get("wet_threshold")).mean().assign_coords(time=dt))) - pred_wet.append(((da_preds / 24> cfg.get("wet_threshold")).mean(dim=("lat","lon")).assign_coords(time=dt))) - if mean_pred is not None: - mean_pred_wet.append(((da_mean_pred / 24 > cfg.get("wet_threshold")).mean().assign_coords(time=dt))) + # Wet-hour fraction per threshold + for thr in ALLHOUR_THRESHOLDS: + wet_target[thr].append((da_target / 24 > thr).mean().assign_coords(time=dt)) + wet_baseline[thr].append((da_baseline / 24 > thr).mean().assign_coords(time=dt)) + wet_pred[thr].append((da_preds / 24 > thr).mean(dim=('lat', 'lon')).assign_coords(time=dt)) + if mean_pred is not None: + wet_regpred[thr].append((da_mean_pred / 24 > thr).mean().assign_coords(time=dt)) if idx % cfg.get("log_interval") == 0 or idx == len(times): logger.info(f"Processed {idx}/{len(times)} timesteps ({ts})") @@ -112,11 +150,11 @@ def main(cfg: dict): if mean_pred_precip: amount_mean_pred_mean, _ = concat_and_group_diurnal(mean_pred_precip) - wet_target_mean, _ = concat_and_group_diurnal(target_wet, scale=100.0) # scale to obtain percentages - wet_baseline_mean, _ = concat_and_group_diurnal(baseline_wet, scale=100.0) - wet_pred_mean, wet_pred_std = concat_and_group_diurnal(pred_wet, is_member=True, scale=100.0) - if mean_pred_wet: - wet_mean_pred_mean, _ = concat_and_group_diurnal(mean_pred_wet, scale=100.0) + wet_target_mean, _ = concat_and_group_diurnal(wet_target[0.1], scale=100.0) + wet_baseline_mean, _ = concat_and_group_diurnal(wet_baseline[0.1], scale=100.0) + wet_pred_mean, wet_pred_std = concat_and_group_diurnal(wet_pred[0.1], is_member=True, scale=100.0) + if wet_regpred[0.1]: + wet_mean_pred_mean, _ = concat_and_group_diurnal(wet_regpred[0.1], scale=100.0) # Generate plots output_path = out_root / cfg.get("results_dir_name", "evaluation_maps") @@ -132,14 +170,34 @@ def main(cfg: dict): ) save_plot( wet_target_mean.hour, - [wet_target_mean, wet_baseline_mean, wet_pred_mean, wet_mean_pred_mean] if mean_pred_wet else [wet_target_mean, wet_baseline_mean, wet_pred_mean], - [None, None, wet_pred_std, None] if mean_pred_wet else [None, None, wet_pred_std], - ['Target','Input','CorrDiff ± Std(Members)', 'Regression Prediction'] if mean_pred_wet else ['Target','Input','CorrDiff ± Std(Members)'], + [wet_target_mean, wet_baseline_mean, wet_pred_mean, wet_mean_pred_mean] if wet_regpred[0.1] else [wet_target_mean, wet_baseline_mean, wet_pred_mean], + [None, None, wet_pred_std, None] if wet_regpred[0.1] else [None, None, wet_pred_std], + ['Target','Input','CorrDiff ± Std(Members)', 'Regression Prediction'] if wet_regpred[0.1] else ['Target','Input','CorrDiff ± Std(Members)'], 'Wet-Hour Fraction [%]', 'Diurnal Cycle of Wet-Hours (>0.1 mm/h)', output_path / 'diurnal_cycle_precip_wethours.png' ) + # All-hour wet-hour fraction bar chart + pred_all = {thr: xr.concat(wet_pred[thr], dim='time').values.ravel() for thr in ALLHOUR_THRESHOLDS} + allhour_mode_data = [ + {thr: (float(xr.concat(wet_target[thr], dim='time').mean()) * 100, None) for thr in ALLHOUR_THRESHOLDS}, + {thr: (float(xr.concat(wet_baseline[thr], dim='time').mean()) * 100, None) for thr in ALLHOUR_THRESHOLDS}, + {thr: (float(pred_all[thr].mean()) * 100, float(pred_all[thr].std()) * 100) for thr in ALLHOUR_THRESHOLDS}, + ] + allhour_labels = ['Target', 'Input', 'CorrDiff ± Std(Members)'] + if any(wet_regpred[thr] for thr in ALLHOUR_THRESHOLDS): + allhour_mode_data.append( + {thr: (float(xr.concat(wet_regpred[thr], dim='time').mean()) * 100, None) for thr in ALLHOUR_THRESHOLDS} + ) + allhour_labels.append('Regression Prediction') + fn_allhour = output_path / 'allhour_wethours.png' + save_allhour_wethours_plot( + ALLHOUR_THRESHOLDS, allhour_mode_data, allhour_labels, + 'All-Hour Wet-Hour Fraction', fn_allhour, + ) + logger.info(f"All-hour wet-hour plot saved: {fn_allhour}") + logger.info("Plots saved.") if __name__ == '__main__': From ed915d86fc79ec1b72e79fcd398c36779eafd57b Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Tue, 9 Jun 2026 17:49:10 +0200 Subject: [PATCH 28/51] use logscale --- src/hirad/eval/bias_by_percentile_precip.py | 11 ++++++----- src/hirad/eval/bias_by_percentile_wind.py | 11 ++++++----- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/hirad/eval/bias_by_percentile_precip.py b/src/hirad/eval/bias_by_percentile_precip.py index d13ab4f0..82e9e803 100644 --- a/src/hirad/eval/bias_by_percentile_precip.py +++ b/src/hirad/eval/bias_by_percentile_precip.py @@ -75,14 +75,15 @@ def save_spread_by_percentile_plot(spread, percentile_values, def save_fbi_by_percentile_plot(fbi_data_dict, percentile_values, labels, colors, title, xlabel, ylabel, out_path, mean_q=None) -> None: - """Save a frequency-bias-index-by-percentile figure (linear y-axis, ratio around 1).""" + """Save a frequency-bias-index-by-percentile figure (log y-axis, ratio around 1).""" _, ax, frac = new_percentile_axes(percentile_values) - all_vals = plot_dict_curves(ax, frac, fbi_data_dict, labels, colors, lower_clip=0) + all_vals = plot_dict_curves(ax, frac, fbi_data_dict, labels, colors, lower_clip=1e-3) ax.axhline(1.0, color='black', linewidth=0.8, linestyle='--') + ax.set_yscale('log') if all_vals: - ymax = float(max(np.nanmax(v) for v in all_vals)) * 1.1 - if ymax > 1.0: - ax.set_ylim(0, ymax) + ymax = float(max(np.nanmax(v) for v in all_vals)) * 1.5 + ymin = float(min(np.nanmin(v) for v in all_vals)) / 1.5 + ax.set_ylim(max(ymin, 1e-3), max(ymax, 2.0)) finalize_percentile_plot(ax, frac, lambda ax_, frac_, mq: _apply_logit_xaxis(ax_, frac_, mq, xlim_left=0.10), mean_q, xlabel, ylabel, title, out_path) diff --git a/src/hirad/eval/bias_by_percentile_wind.py b/src/hirad/eval/bias_by_percentile_wind.py index 4f77298e..628b70a0 100644 --- a/src/hirad/eval/bias_by_percentile_wind.py +++ b/src/hirad/eval/bias_by_percentile_wind.py @@ -87,14 +87,15 @@ def save_spread_by_percentile_plot(spread, percentile_values, def save_fbi_by_percentile_plot(fbi_data_dict, percentile_values, labels, colors, title, xlabel, ylabel, out_path, mean_q=None) -> None: - """Save a frequency-bias-index-by-percentile figure (linear y-axis, ratio around 1).""" + """Save a frequency-bias-index-by-percentile figure (log y-axis, ratio around 1).""" _, ax, frac = new_percentile_axes(percentile_values) - all_vals = plot_dict_curves(ax, frac, fbi_data_dict, labels, colors, lower_clip=0) + all_vals = plot_dict_curves(ax, frac, fbi_data_dict, labels, colors, lower_clip=1e-3) ax.axhline(1.0, color='black', linewidth=0.8, linestyle='--') + ax.set_yscale('log') if all_vals: - ymax = float(max(np.nanmax(v) for v in all_vals)) * 1.1 - if ymax > 1.0: - ax.set_ylim(0, ymax) + ymax = float(max(np.nanmax(v) for v in all_vals)) * 1.5 + ymin = float(min(np.nanmin(v) for v in all_vals)) / 1.5 + ax.set_ylim(max(ymin, 1e-3), max(ymax, 2.0)) finalize_percentile_plot(ax, frac, lambda ax_, frac_, mq: _apply_logit_xaxis(ax_, frac_, mq, xlim_left=0.10), mean_q, xlabel, ylabel, title, out_path) From df5b880a7d4f97ac3bb905470428159a58a19df5 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 10 Jun 2026 09:01:10 +0200 Subject: [PATCH 29/51] plot more wethour tresholds --- .../diurnal_cycle_precip_mean_wet-hour.py | 82 +++++-------------- 1 file changed, 19 insertions(+), 63 deletions(-) diff --git a/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py b/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py index f8d2ce75..a9e38e70 100644 --- a/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py +++ b/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py @@ -35,36 +35,6 @@ def save_plot(hour, means, stds, labels, ylabel, title, out_path): plt.close() -def save_allhour_wethours_plot(thresholds, mode_data, labels, title, out_path): - """Bar chart of all-hour wet-hour fraction (%) per threshold. - - mode_data: list of dicts {thr: (mean_pct, std_pct_or_None)} - """ - n_thr = len(thresholds) - n_modes = len(mode_data) - x = np.arange(n_thr) - width = 0.7 / n_modes - fig, ax = plt.subplots(figsize=(8, 5)) - for i, (data, label) in enumerate(zip(mode_data, labels)): - offset = (i - n_modes / 2 + 0.5) * width - means = [data[thr][0] for thr in thresholds] - stds = [data[thr][1] if data[thr][1] is not None else 0.0 for thr in thresholds] - ax.bar(x + offset, means, width, label=label, alpha=0.8) - if any(s > 0 for s in stds): - ax.errorbar(x + offset, means, yerr=stds, fmt='none', color='black', capsize=3) - ax.set_xticks(x) - ax.set_xticklabels([f'>{thr:g} mm/h' for thr in thresholds]) - ax.set_ylabel('Wet-Hour Fraction [%]') - ax.set_yscale('log') - ax.set_title(title) - ax.legend() - ax.grid(True, axis='y', alpha=0.3, which='both') - plt.tight_layout() - Path(out_path).parent.mkdir(parents=True, exist_ok=True) - plt.savefig(out_path) - plt.close() - - def main(cfg: dict): # Setup logging logging.basicConfig(level=logging.INFO) @@ -150,12 +120,6 @@ def main(cfg: dict): if mean_pred_precip: amount_mean_pred_mean, _ = concat_and_group_diurnal(mean_pred_precip) - wet_target_mean, _ = concat_and_group_diurnal(wet_target[0.1], scale=100.0) - wet_baseline_mean, _ = concat_and_group_diurnal(wet_baseline[0.1], scale=100.0) - wet_pred_mean, wet_pred_std = concat_and_group_diurnal(wet_pred[0.1], is_member=True, scale=100.0) - if wet_regpred[0.1]: - wet_mean_pred_mean, _ = concat_and_group_diurnal(wet_regpred[0.1], scale=100.0) - # Generate plots output_path = out_root / cfg.get("results_dir_name", "evaluation_maps") output_path.mkdir(parents=True, exist_ok=True) @@ -168,35 +132,27 @@ def main(cfg: dict): 'Diurnal Cycle of Precip Amount', output_path / 'diurnal_cycle_precip_amount.png' ) - save_plot( - wet_target_mean.hour, - [wet_target_mean, wet_baseline_mean, wet_pred_mean, wet_mean_pred_mean] if wet_regpred[0.1] else [wet_target_mean, wet_baseline_mean, wet_pred_mean], - [None, None, wet_pred_std, None] if wet_regpred[0.1] else [None, None, wet_pred_std], - ['Target','Input','CorrDiff ± Std(Members)', 'Regression Prediction'] if wet_regpred[0.1] else ['Target','Input','CorrDiff ± Std(Members)'], - 'Wet-Hour Fraction [%]', - 'Diurnal Cycle of Wet-Hours (>0.1 mm/h)', - output_path / 'diurnal_cycle_precip_wethours.png' - ) - # All-hour wet-hour fraction bar chart - pred_all = {thr: xr.concat(wet_pred[thr], dim='time').values.ravel() for thr in ALLHOUR_THRESHOLDS} - allhour_mode_data = [ - {thr: (float(xr.concat(wet_target[thr], dim='time').mean()) * 100, None) for thr in ALLHOUR_THRESHOLDS}, - {thr: (float(xr.concat(wet_baseline[thr], dim='time').mean()) * 100, None) for thr in ALLHOUR_THRESHOLDS}, - {thr: (float(pred_all[thr].mean()) * 100, float(pred_all[thr].std()) * 100) for thr in ALLHOUR_THRESHOLDS}, - ] - allhour_labels = ['Target', 'Input', 'CorrDiff ± Std(Members)'] - if any(wet_regpred[thr] for thr in ALLHOUR_THRESHOLDS): - allhour_mode_data.append( - {thr: (float(xr.concat(wet_regpred[thr], dim='time').mean()) * 100, None) for thr in ALLHOUR_THRESHOLDS} + # Diurnal cycle of wet-hours, one plot per threshold + for thr in ALLHOUR_THRESHOLDS: + wet_target_mean, _ = concat_and_group_diurnal(wet_target[thr], scale=100.0) + wet_baseline_mean, _ = concat_and_group_diurnal(wet_baseline[thr], scale=100.0) + wet_pred_mean, wet_pred_std = concat_and_group_diurnal(wet_pred[thr], is_member=True, scale=100.0) + has_regpred = bool(wet_regpred[thr]) + if has_regpred: + wet_mean_pred_mean, _ = concat_and_group_diurnal(wet_regpred[thr], scale=100.0) + + fn_wet = output_path / f'diurnal_cycle_precip_wethours_{thr:g}mmh.png' + save_plot( + wet_target_mean.hour, + [wet_target_mean, wet_baseline_mean, wet_pred_mean, wet_mean_pred_mean] if has_regpred else [wet_target_mean, wet_baseline_mean, wet_pred_mean], + [None, None, wet_pred_std, None] if has_regpred else [None, None, wet_pred_std], + ['Target','Input','CorrDiff ± Std(Members)', 'Regression Prediction'] if has_regpred else ['Target','Input','CorrDiff ± Std(Members)'], + 'Wet-Hour Fraction [%]', + f'Diurnal Cycle of Wet-Hours (>{thr:g} mm/h)', + fn_wet, ) - allhour_labels.append('Regression Prediction') - fn_allhour = output_path / 'allhour_wethours.png' - save_allhour_wethours_plot( - ALLHOUR_THRESHOLDS, allhour_mode_data, allhour_labels, - 'All-Hour Wet-Hour Fraction', fn_allhour, - ) - logger.info(f"All-hour wet-hour plot saved: {fn_allhour}") + logger.info(f"Diurnal wet-hour plot saved: {fn_wet}") logger.info("Plots saved.") From d42ac924f3c06894e8c81f65859ea1fcd90b0026 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 10 Jun 2026 11:17:55 +0200 Subject: [PATCH 30/51] exclude input from smoothing as it is lower res --- src/hirad/eval/bias_by_percentile_common.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/hirad/eval/bias_by_percentile_common.py b/src/hirad/eval/bias_by_percentile_common.py index 8514c649..fa6ca64d 100644 --- a/src/hirad/eval/bias_by_percentile_common.py +++ b/src/hirad/eval/bias_by_percentile_common.py @@ -74,8 +74,8 @@ def build_all_histograms( skip_det: set = set() skip_preds = False - def accumulate(counts, channel_arrs): - if smoothing_sigma is not None: + def accumulate(counts, channel_arrs, smooth: bool = True): + if smoothing_sigma is not None and smooth: channel_arrs = [_smooth2d(a, smoothing_sigma) for a in channel_arrs] flats = [to_flat(a, conv, offset) for a in channel_arrs] vals = reduce_fn(flats)[land_idx] @@ -101,7 +101,8 @@ def accumulate(counts, channel_arrs): logger.warning(f" [{mode}] file not found at {ts}, skipping mode") skip_det.add(mode) continue - accumulate(det_counts[mode], [loaded[c] for c in chans]) + accumulate(det_counts[mode], [loaded[c] for c in chans], + smooth=mode != 'baseline') if not skip_preds: try: From 75c0165f74c40b3594cbb8e6c6f8c4ae0508d593 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 10 Jun 2026 11:24:12 +0200 Subject: [PATCH 31/51] change directory name --- src/hirad/eval/map_precip_stats.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/hirad/eval/map_precip_stats.py b/src/hirad/eval/map_precip_stats.py index a4bcde9c..66defe4d 100644 --- a/src/hirad/eval/map_precip_stats.py +++ b/src/hirad/eval/map_precip_stats.py @@ -128,11 +128,6 @@ def plot_stat_map(data, filename, stat_config, label, grid_cfg): ) -def _load_predictions_all_members(filepath, conv_factor): - """Load prediction file once and return (n_members, C, H, W) tensor.""" - return torch.load(filepath, weights_only=False) * conv_factor - - def main(cfg: dict): logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -203,7 +198,7 @@ def main(cfg: dict): for stat_config in stat_configs: logger.info(f"Computing {stat_config['title_stat']} for {mode}...") result = apply_statistic(mode_data, times_dt, stat_config['type'], stat_config['param'], wet_threshold) - map_output_dir = output_path / f"maps_{stat_config['stat_name']}" + map_output_dir = output_path / f"maps_precip_{stat_config['stat_name']}" map_output_dir.mkdir(parents=True, exist_ok=True) plot_stat_map(result, str(map_output_dir / f'{mode}_{stat_config["stat_name"]}'), stat_config, label, grid_cfg) @@ -234,7 +229,7 @@ def main(cfg: dict): for stat_config in stat_configs: logger.info(f"Computing {stat_config['title_stat']} for member {member_idx+1}...") member_result = apply_statistic(member_data, times_dt, stat_config['type'], stat_config['param'], wet_threshold) - map_output_dir = output_path / f"maps_{stat_config['stat_name']}" + map_output_dir = output_path / f"maps_precip_{stat_config['stat_name']}" map_output_dir.mkdir(parents=True, exist_ok=True) member_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_{stat_config["stat_name"]}') member_label = f'CorrDiff Member {member_idx+1}' From 5fad8b38605afad520d4ad2ad4d59aa9db458f4e Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 10 Jun 2026 11:34:15 +0200 Subject: [PATCH 32/51] plot temp maps --- src/hirad/eval/map_temp_stats.py | 316 +++++++++++++++++++++++++++++++ src/hirad/eval_temp.sh | 2 + 2 files changed, 318 insertions(+) create mode 100644 src/hirad/eval/map_temp_stats.py diff --git a/src/hirad/eval/map_temp_stats.py b/src/hirad/eval/map_temp_stats.py new file mode 100644 index 00000000..b0de9da8 --- /dev/null +++ b/src/hirad/eval/map_temp_stats.py @@ -0,0 +1,316 @@ +import logging +from datetime import datetime +from pathlib import Path + +import numpy as np +import torch +import xarray as xr +import numba + +from hirad.eval.eval_utils import get_channel_indices, grid_cfg_from_cfg, load_generation_setup, parse_eval_cli, resolve_ts_dir +from hirad.eval.plotting import plot_map + + +@numba.njit +def _longest_spell(x): + """Longest consecutive run of True values in a 1-D boolean array.""" + best = 0 + cur = 0 + for i in range(x.shape[0]): + if x[i]: + cur += 1 + if cur > best: + best = cur + else: + cur = 0 + return best + + +@numba.njit(parallel=True) +def _consecutive_spell_2d(condition_3d): + """ + condition_3d: bool array of shape (T, H, W). + Returns int array of shape (H, W) with longest spell per grid point. + """ + T, H, W = condition_3d.shape + out = np.empty((H, W), dtype=np.int64) + for i in numba.prange(H): + for j in range(W): + out[i, j] = _longest_spell(condition_3d[:, i, j]) + return out + + +def consecutive_spell(data_np, condition_fn): + """ + data_np: numpy array (T, H, W) + condition_fn: callable that takes the array and returns bool array of same shape + """ + cond = condition_fn(data_np) + return _consecutive_spell_2d(cond) + + +def apply_statistic(data_np, times_dt, stat_type, stat_param=None): + """ + Apply temperature statistic on array containing time sequence of 2m temperature maps. + data_np: (T, H, W) float array in degrees Celsius + times_dt: list of datetime objects (length T) + Returns: (H, W) numpy array + """ + if stat_type == 'mean': + return np.mean(data_np, axis=0) + + if stat_type == 'std': + return np.std(data_np, axis=0) + + if stat_type == 'max': + # TXx: maximum temperature + return np.max(data_np, axis=0) + + if stat_type == 'min': + # TNn: minimum temperature + return np.min(data_np, axis=0) + + if stat_type == 'quantile': + return np.quantile(data_np, stat_param, axis=0) + + # For daily-based indices, build daily aggregations using xarray + if stat_type in ('warm_days', 'frost_days', 'ice_days', 'tropical_nights', + 'dtr', 'warm_spell', 'cold_spell'): + da = xr.DataArray( + data_np, dims=['time', 'lat', 'lon'], + coords={'time': times_dt} + ) + daily_max = da.resample(time="1D").max("time").values # (D, H, W) + daily_min = da.resample(time="1D").min("time").values # (D, H, W) + D = daily_max.shape[0] + + if stat_type == 'warm_days': + # SU: fraction of days with daily max > 25 °C + return np.mean(daily_max > 25.0, axis=0) * 100.0 + + if stat_type == 'frost_days': + # FD: fraction of days with daily min < 0 °C + return np.mean(daily_min < 0.0, axis=0) * 100.0 + + if stat_type == 'ice_days': + # ID: fraction of days with daily max < 0 °C + return np.mean(daily_max < 0.0, axis=0) * 100.0 + + if stat_type == 'tropical_nights': + # TR: fraction of days with daily min > 20 °C + return np.mean(daily_min > 20.0, axis=0) * 100.0 + + if stat_type == 'dtr': + # DTR: mean diurnal temperature range + return np.mean(daily_max - daily_min, axis=0) + + if stat_type == 'warm_spell': + # WSDI-like: longest consecutive run of days with daily max > 25 °C + return consecutive_spell(daily_max, lambda x: x > 25.0) + + if stat_type == 'cold_spell': + # CSDI-like: longest consecutive run of days with daily min < 0 °C + return consecutive_spell(daily_min, lambda x: x < 0.0) + + raise ValueError(f"Unsupported temperature statistic type: {stat_type}") + + +def plot_temp_stat_map(data, filename, stat_config, label, grid_cfg): + """Plot a single temperature statistic map with appropriate styling.""" + stype = stat_config['type'] + title = f'{label}: {stat_config["title_stat"]}' + + if stype in ('mean', 'quantile', 'max', 'min'): + plot_map( + data, filename, + title=title, + label='Temperature [°C]', + vmin=-10, vmax=40, cmap='RdBu_r', extend='both', grid_cfg=grid_cfg + ) + elif stype == 'std': + plot_map( + data, filename, + title=title, + label='Std Dev [°C]', + vmin=0, vmax=10, cmap='plasma', extend='max', grid_cfg=grid_cfg + ) + elif stype == 'dtr': + plot_map( + data, filename, + title=title, + label='Diurnal Range [°C]', + vmin=0, vmax=20, cmap='YlOrRd', extend='max', grid_cfg=grid_cfg + ) + elif stype in ('warm_days', 'frost_days', 'ice_days', 'tropical_nights'): + plot_map( + data, filename, + title=title, + label='Frequency [% of days]', + vmin=0, vmax=100, cmap='OrRd', extend='neither', grid_cfg=grid_cfg + ) + elif stype == 'warm_spell': + plot_map( + data, filename, + title=title, + label='Days', + vmin=0, vmax=30, cmap='YlOrRd', extend='max', grid_cfg=grid_cfg + ) + elif stype == 'cold_spell': + plot_map( + data, filename, + title=title, + label='Days', + vmin=0, vmax=30, cmap='YlGnBu', extend='max', grid_cfg=grid_cfg + ) + else: + plot_map( + data, filename, + title=title, + label='Temperature [°C]', + vmin=None, vmax=None, cmap='RdBu_r', extend='both', grid_cfg=grid_cfg + ) + + +def main(cfg: dict): + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + grid_cfg = grid_cfg_from_cfg(cfg) + + logger.info("Starting 2m temperature statistics generation") + try: + generation_dir, gen_cfg, times = load_generation_setup(cfg) + except ValueError as exc: + logger.error(str(exc)) + return + logger.info(f"Processing {len(times)} timesteps") + + times_dt = [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times] + + out_root = Path(generation_dir) + output_path = out_root / cfg.get("results_dir_name", "evaluation_maps") + output_path.mkdir(parents=True, exist_ok=True) + + indices = get_channel_indices(gen_cfg) + out_ch = indices['output'] + in_ch = indices['input'] + + # Temperature channel: try '2t', fall back to 't2m' + t2m_out = out_ch.get('2t', out_ch.get('t2m')) + t2m_in = in_ch.get('2t', in_ch.get('t2m', t2m_out)) + + if t2m_out is None: + logger.error("No temperature channel ('2t' or 't2m') found in output channels. Aborting.") + return + + # Conversion from Kelvin to Celsius: value * conv_factor + conv_offset + conv_factor = cfg.get("conv_factor", 1.0) + conv_offset = cfg.get("conv_offset", -273.15) + log_interval = cfg.get("log_interval", 100) + + STATISTICS_CONFIG = { + 'mean': {'type': 'mean', 'title': 'Mean Temperature'}, + 'std': {'type': 'std', 'title': 'Temperature Variability (Std Dev)'}, + 'txx': {'type': 'max', 'title': 'Maximum Temperature (TXx)'}, + 'tnn': {'type': 'min', 'title': 'Minimum Temperature (TNn)'}, + 'p99.99': {'type': 'quantile', 'param': 0.9999, 'title': '99.99th Percentile Temperature'}, + 'p99.9': {'type': 'quantile', 'param': 0.999, 'title': '99.9th Percentile Temperature'}, + 'p99': {'type': 'quantile', 'param': 0.99, 'title': '99th Percentile Temperature'}, + 'p01': {'type': 'quantile', 'param': 0.01, 'title': '1st Percentile Temperature'}, + 'p0.1': {'type': 'quantile', 'param': 0.001, 'title': '0.1th Percentile Temperature'}, + 'p0.01': {'type': 'quantile', 'param': 0.0001, 'title': '0.01th Percentile Temperature'}, + 'warm_days': {'type': 'warm_days', 'title': 'Summer Days (daily max > 25°C)'}, + 'frost_days': {'type': 'frost_days', 'title': 'Frost Days (daily min < 0°C)'}, + 'ice_days': {'type': 'ice_days', 'title': 'Ice Days (daily max < 0°C)'}, + 'tropical_nights': {'type': 'tropical_nights', 'title': 'Tropical Nights (daily min > 20°C)'}, + 'dtr': {'type': 'dtr', 'title': 'Mean Diurnal Temperature Range (DTR)'}, + 'warm_spell': {'type': 'warm_spell', 'title': 'Warm Spell Duration (daily max > 25°C)'}, + 'cold_spell': {'type': 'cold_spell', 'title': 'Cold Spell Duration (daily min < 0°C)'}, + } + stat_configs = [ + {'stat_name': name, 'title_stat': config['title'], 'param': config.get('param'), **config} + for name, config in STATISTICS_CONFIG.items() + ] + + # --- Basic modes: target, baseline, regression-prediction --- + basic_modes = { + 'target': (t2m_out, 'Target'), + 'baseline': (t2m_in, 'Input'), + 'regression-prediction': (t2m_out, 'Regression Prediction'), + } + + for mode, (t2m_channel, label) in basic_modes.items(): + logger.info(f"Processing mode: {mode}") + data_list = [] + try: + for i, ts in enumerate(times): + if i % log_interval == 0: + logger.info(f"Loading {mode} timestep {i+1}/{len(times)}: {ts}") + raw = torch.load(resolve_ts_dir(out_root, ts) / ts / f"{ts}-{mode}", weights_only=False) + val = raw[t2m_channel] + arr = val.numpy() if isinstance(val, torch.Tensor) else val + data_list.append(arr * conv_factor + conv_offset) + except Exception: + logger.warning(f"{mode} not available, skipping") + continue + + mode_data = np.stack(data_list, axis=0).astype(np.float32) + del data_list + + for stat_config in stat_configs: + logger.info(f"Computing {stat_config['title_stat']} for {mode}...") + result = apply_statistic(mode_data, times_dt, stat_config['type'], stat_config['param']) + map_output_dir = output_path / f"maps_temp_{stat_config['stat_name']}" + map_output_dir.mkdir(parents=True, exist_ok=True) + plot_temp_stat_map(result, str(map_output_dir / f'{mode}_{stat_config["stat_name"]}'), stat_config, label, grid_cfg) + + del mode_data + + # --- Predictions: process ONE member at a time to bound memory usage --- + logger.info("Processing predictions mode...") + try: + sample_data = torch.load( + resolve_ts_dir(out_root, times[0]) / times[0] / f"{times[0]}-predictions", + weights_only=False + ) + n_members = sample_data.shape[0] + del sample_data + except Exception as exc: + logger.error(f"Could not load predictions: {exc}") + return + logger.info(f"Found {n_members} ensemble members") + + H: int = cfg["height"] + W: int = cfg["width"] + member_data = np.empty((len(times), H, W), dtype=np.float32) + + for member_idx in range(n_members): + logger.info(f"Loading prediction member {member_idx+1}/{n_members} (single pass over files)...") + for i, ts in enumerate(times): + if i % log_interval == 0: + logger.info(f"Loading predictions member {member_idx+1} timestep {i+1}/{len(times)}: {ts}") + pred_data = torch.load( + resolve_ts_dir(out_root, ts) / ts / f"{ts}-predictions", + weights_only=False + ) + member_slice = pred_data[member_idx, t2m_out] + arr = member_slice.numpy() if isinstance(member_slice, torch.Tensor) else member_slice + member_data[i] = arr * conv_factor + conv_offset + del pred_data + + logger.info(f"Computing statistics for prediction member {member_idx+1}/{n_members}") + for stat_config in stat_configs: + logger.info(f"Computing {stat_config['title_stat']} for member {member_idx+1}...") + member_result = apply_statistic(member_data, times_dt, stat_config['type'], stat_config['param']) + map_output_dir = output_path / f"maps_temp_{stat_config['stat_name']}" + map_output_dir.mkdir(parents=True, exist_ok=True) + member_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_{stat_config["stat_name"]}') + member_label = f'CorrDiff Member {member_idx+1}' + plot_temp_stat_map(member_result, member_filename, stat_config, member_label, grid_cfg) + + del member_data + logger.info("All 2m temperature statistics maps generated successfully") + + +if __name__ == '__main__': + main(parse_eval_cli()) diff --git a/src/hirad/eval_temp.sh b/src/hirad/eval_temp.sh index 41af460d..a9773eeb 100644 --- a/src/hirad/eval_temp.sh +++ b/src/hirad/eval_temp.sh @@ -10,6 +10,8 @@ CMDS=( "python src/hirad/eval/diurnal_cycle_temp.py" # QQ "python -m hirad.eval.bias_by_percentile_temp" + # Maps + "python src/hirad/eval/map_temp_stats.py" ) for cmd in "${CMDS[@]}"; do From 083d15a1b45b318c691b31e88e0b0f781922072e Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Thu, 11 Jun 2026 15:44:36 +0200 Subject: [PATCH 33/51] fix temp conversion --- src/hirad/eval/map_temp_stats.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/hirad/eval/map_temp_stats.py b/src/hirad/eval/map_temp_stats.py index b0de9da8..b8f1a969 100644 --- a/src/hirad/eval/map_temp_stats.py +++ b/src/hirad/eval/map_temp_stats.py @@ -203,9 +203,9 @@ def main(cfg: dict): logger.error("No temperature channel ('2t' or 't2m') found in output channels. Aborting.") return - # Conversion from Kelvin to Celsius: value * conv_factor + conv_offset - conv_factor = cfg.get("conv_factor", 1.0) - conv_offset = cfg.get("conv_offset", -273.15) + # Conversion from Kelvin to Celsius: value * temp_conv_factor + temp_conv_offset + conv_factor = cfg.get("temp_conv_factor", 1.0) + conv_offset = cfg.get("temp_conv_offset", -273.15) log_interval = cfg.get("log_interval", 100) STATISTICS_CONFIG = { From fb78d41ada63a0d74f72cbdcdeaa6da27ece9ffb Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Mon, 15 Jun 2026 21:13:05 +0200 Subject: [PATCH 34/51] refine temperature plot --- src/hirad/eval/map_temp_stats.py | 11 +++-------- src/hirad/eval/plotting.py | 25 +++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/src/hirad/eval/map_temp_stats.py b/src/hirad/eval/map_temp_stats.py index b8f1a969..51332f32 100644 --- a/src/hirad/eval/map_temp_stats.py +++ b/src/hirad/eval/map_temp_stats.py @@ -8,7 +8,7 @@ import numba from hirad.eval.eval_utils import get_channel_indices, grid_cfg_from_cfg, load_generation_setup, parse_eval_cli, resolve_ts_dir -from hirad.eval.plotting import plot_map +from hirad.eval.plotting import plot_map, plot_map_temperature @numba.njit @@ -121,12 +121,7 @@ def plot_temp_stat_map(data, filename, stat_config, label, grid_cfg): title = f'{label}: {stat_config["title_stat"]}' if stype in ('mean', 'quantile', 'max', 'min'): - plot_map( - data, filename, - title=title, - label='Temperature [°C]', - vmin=-10, vmax=40, cmap='RdBu_r', extend='both', grid_cfg=grid_cfg - ) + plot_map_temperature(data, filename, title=title, grid_cfg=grid_cfg) elif stype == 'std': plot_map( data, filename, @@ -153,7 +148,7 @@ def plot_temp_stat_map(data, filename, stat_config, label, grid_cfg): data, filename, title=title, label='Days', - vmin=0, vmax=30, cmap='YlOrRd', extend='max', grid_cfg=grid_cfg + vmin=0, vmax=60, cmap='YlOrRd', extend='max', grid_cfg=grid_cfg ) elif stype == 'cold_spell': plot_map( diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index 61e0ff0b..a462ed27 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -106,6 +106,31 @@ def plot_map_precipitation(values, filename, title='', threshold=0.01, rfac=1000 grid_cfg=grid_cfg, ) +def plot_map_temperature(values, filename, title='', grid_cfg=DEFAULT_GRID_CONFIG): + """Plot 2m temperature data with Meteoswiss-style colormap.""" + colors = [ + "#4C97FF", "#4CA8FF", "#00CCFF", "#DEE699", "#A6D473", + "#6BBF4D", "#33AB26", "#009900", "#33B300", "#66CC00", + "#99E600", "#CCFF00", "#FFFF00", "#FFCC00", "#FF9900", + "#FF6600", "#FF3300", "#FF0000", "#EB00EB", "#FF40FF", + ] + bounds = [-5, -4, -2, 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34] + + cmap = ListedColormap(colors) + norm = BoundaryNorm(bounds, ncolors=len(colors), clip=False) + + plot_map( + values, filename, + cmap=cmap, + norm=norm, + ticks=bounds, + title=title, + label='Temperature [°C]', + extend='both', + grid_cfg=grid_cfg, + ) + + def plot_map_wind_precip( u: np.ndarray, v: np.ndarray, From dc63a5555a406c7d888c8dac9bdf09ddfcb36378 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Tue, 16 Jun 2026 10:35:35 +0200 Subject: [PATCH 35/51] more colors --- src/hirad/eval/plotting.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index a462ed27..11204f63 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -109,12 +109,13 @@ def plot_map_precipitation(values, filename, title='', threshold=0.01, rfac=1000 def plot_map_temperature(values, filename, title='', grid_cfg=DEFAULT_GRID_CONFIG): """Plot 2m temperature data with Meteoswiss-style colormap.""" colors = [ - "#4C97FF", "#4CA8FF", "#00CCFF", "#DEE699", "#A6D473", - "#6BBF4D", "#33AB26", "#009900", "#33B300", "#66CC00", - "#99E600", "#CCFF00", "#FFFF00", "#FFCC00", "#FF9900", - "#FF6600", "#FF3300", "#FF0000", "#EB00EB", "#FF40FF", + "#1A33CC", "#3366FF", "#4C97FF", "#4CA8FF", "#00CCFF", + "#DEE699", "#A6D473", "#6BBF4D", "#33AB26", "#009900", + "#33B300", "#66CC00", "#99E600", "#CCFF00", "#FFFF00", + "#FFCC00", "#FF9900", "#FF6600", "#FF3300", "#FF0000", + "#EB00EB", "#FF40FF", "#FF80FF", "#FFBFFF", ] - bounds = [-5, -4, -2, 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34] + bounds = [-9, -7, -5, -4, -2, 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38] cmap = ListedColormap(colors) norm = BoundaryNorm(bounds, ncolors=len(colors), clip=False) From 1fcf563fb7e000276c68bb4dc60ee09ba090ca01 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Tue, 16 Jun 2026 10:35:44 +0200 Subject: [PATCH 36/51] also compute hot days --- src/hirad/eval/map_temp_stats.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/hirad/eval/map_temp_stats.py b/src/hirad/eval/map_temp_stats.py index 51332f32..b3ae12a9 100644 --- a/src/hirad/eval/map_temp_stats.py +++ b/src/hirad/eval/map_temp_stats.py @@ -74,7 +74,7 @@ def apply_statistic(data_np, times_dt, stat_type, stat_param=None): return np.quantile(data_np, stat_param, axis=0) # For daily-based indices, build daily aggregations using xarray - if stat_type in ('warm_days', 'frost_days', 'ice_days', 'tropical_nights', + if stat_type in ('warm_days', 'hot_days', 'frost_days', 'ice_days', 'tropical_nights', 'dtr', 'warm_spell', 'cold_spell'): da = xr.DataArray( data_np, dims=['time', 'lat', 'lon'], @@ -88,6 +88,10 @@ def apply_statistic(data_np, times_dt, stat_type, stat_param=None): # SU: fraction of days with daily max > 25 °C return np.mean(daily_max > 25.0, axis=0) * 100.0 + if stat_type == 'hot_days': + # HD: fraction of days with daily max > 35 °C + return np.mean(daily_max > 35.0, axis=0) * 100.0 + if stat_type == 'frost_days': # FD: fraction of days with daily min < 0 °C return np.mean(daily_min < 0.0, axis=0) * 100.0 @@ -136,19 +140,19 @@ def plot_temp_stat_map(data, filename, stat_config, label, grid_cfg): label='Diurnal Range [°C]', vmin=0, vmax=20, cmap='YlOrRd', extend='max', grid_cfg=grid_cfg ) - elif stype in ('warm_days', 'frost_days', 'ice_days', 'tropical_nights'): + elif stype in ('warm_days', 'hot_days', 'frost_days', 'ice_days', 'tropical_nights'): plot_map( data, filename, title=title, label='Frequency [% of days]', - vmin=0, vmax=100, cmap='OrRd', extend='neither', grid_cfg=grid_cfg + vmin=0, vmax=50, cmap='OrRd', extend='neither', grid_cfg=grid_cfg ) elif stype == 'warm_spell': plot_map( data, filename, title=title, label='Days', - vmin=0, vmax=60, cmap='YlOrRd', extend='max', grid_cfg=grid_cfg + vmin=0, vmax=92, cmap='YlOrRd', extend='max', grid_cfg=grid_cfg ) elif stype == 'cold_spell': plot_map( @@ -215,6 +219,7 @@ def main(cfg: dict): 'p0.1': {'type': 'quantile', 'param': 0.001, 'title': '0.1th Percentile Temperature'}, 'p0.01': {'type': 'quantile', 'param': 0.0001, 'title': '0.01th Percentile Temperature'}, 'warm_days': {'type': 'warm_days', 'title': 'Summer Days (daily max > 25°C)'}, + 'hot_days': {'type': 'hot_days', 'title': 'Hot Days (daily max > 35°C)'}, 'frost_days': {'type': 'frost_days', 'title': 'Frost Days (daily min < 0°C)'}, 'ice_days': {'type': 'ice_days', 'title': 'Ice Days (daily max < 0°C)'}, 'tropical_nights': {'type': 'tropical_nights', 'title': 'Tropical Nights (daily min > 20°C)'}, From 92ab9e639167e9ee5b6035b2692d5daec3988906 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Tue, 16 Jun 2026 12:25:03 +0200 Subject: [PATCH 37/51] change spacings of bins --- src/hirad/eval/bias_by_percentile_precip.py | 10 ++-------- src/hirad/eval/bias_by_percentile_temp.py | 14 ++------------ src/hirad/eval/bias_by_percentile_wind.py | 16 +++------------- src/hirad/eval/eval_utils.py | 9 +++++++++ 4 files changed, 16 insertions(+), 33 deletions(-) diff --git a/src/hirad/eval/bias_by_percentile_precip.py b/src/hirad/eval/bias_by_percentile_precip.py index 82e9e803..4097442b 100644 --- a/src/hirad/eval/bias_by_percentile_precip.py +++ b/src/hirad/eval/bias_by_percentile_precip.py @@ -11,7 +11,7 @@ plot_dict_curves, run_bias_by_percentile, ) -from hirad.eval.eval_utils import parse_eval_cli +from hirad.eval.eval_utils import make_percentile_values, parse_eval_cli def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None, @@ -111,13 +111,7 @@ def _make_hist_bins(cfg: dict) -> np.ndarray: mae_ylabel='MAE [mm/h]', spread_ylabel='Ensemble Spread [mm/h]', fbi_ylabel='FBI [-]', - percentile_values=np.unique(np.concatenate([ - np.linspace(1.0, 90.0, 90), - np.linspace(90.0, 99.0, 90), - np.linspace(99.0, 99.9, 45), - np.linspace(99.9, 99.99, 20), - np.linspace(99.99, 99.999, 10), - ])), + percentile_values=make_percentile_values(), resolve_channels=_resolve_channels, make_hist_bins=_make_hist_bins, read_scaling=lambda cfg: (cfg.get("conv_factor_hourly", 1.0), 0.0), diff --git a/src/hirad/eval/bias_by_percentile_temp.py b/src/hirad/eval/bias_by_percentile_temp.py index 78935a33..4d207b33 100644 --- a/src/hirad/eval/bias_by_percentile_temp.py +++ b/src/hirad/eval/bias_by_percentile_temp.py @@ -12,7 +12,7 @@ plot_dict_curves, run_bias_by_percentile, ) -from hirad.eval.eval_utils import parse_eval_cli +from hirad.eval.eval_utils import make_percentile_values, parse_eval_cli def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) -> None: @@ -108,17 +108,7 @@ def _make_hist_bins(cfg: dict) -> np.ndarray: bias_ylabel='Bias [°C]', mae_ylabel='MAE [°C]', spread_ylabel='Ensemble Spread [°C]', - percentile_values=np.unique(np.concatenate([ - np.linspace(0.001, 0.01, 10), - np.linspace(0.01, 0.1, 10), - np.linspace(0.1, 1.0, 10), - np.linspace(1.0, 10.0, 10), - np.linspace(10.0, 90.0, 80), - np.linspace(90.0, 99.0, 90), - np.linspace(99.0, 99.9, 45), - np.linspace(99.9, 99.99, 20), - np.linspace(99.99, 99.999, 10), - ])), + percentile_values=make_percentile_values(), resolve_channels=_resolve_channels, make_hist_bins=_make_hist_bins, # Default: convert Kelvin → °C (conv=1.0, offset=-273.15) diff --git a/src/hirad/eval/bias_by_percentile_wind.py b/src/hirad/eval/bias_by_percentile_wind.py index 628b70a0..58bf1961 100644 --- a/src/hirad/eval/bias_by_percentile_wind.py +++ b/src/hirad/eval/bias_by_percentile_wind.py @@ -15,7 +15,7 @@ plot_dict_curves, run_bias_by_percentile, ) -from hirad.eval.eval_utils import parse_eval_cli +from hirad.eval.eval_utils import make_percentile_values, parse_eval_cli def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None, @@ -130,18 +130,8 @@ def _make_hist_bins(cfg: dict) -> np.ndarray: bias_ylabel='Bias [m/s]', mae_ylabel='MAE [m/s]', spread_ylabel='Ensemble Spread [m/s]', - fbi_ylabel='FBI [-]', - percentile_values=np.unique(np.concatenate([ - np.linspace(0.001, 0.01, 10), - np.linspace(0.01, 0.1, 10), - np.linspace(0.1, 1.0, 10), - np.linspace(1.0, 10.0, 10), - np.linspace(10.0, 90.0, 80), - np.linspace(90.0, 99.0, 90), - np.linspace(99.0, 99.9, 45), - np.linspace(99.9, 99.99, 20), - np.linspace(99.99, 99.999, 10), - ])), + fbi_ylabel='FBI (exceedance)', + percentile_values=make_percentile_values(), resolve_channels=_resolve_channels, make_hist_bins=_make_hist_bins, read_scaling=lambda cfg: (cfg.get("wind_conv_factor", 1.0), 0.0), diff --git a/src/hirad/eval/eval_utils.py b/src/hirad/eval/eval_utils.py index 680f5d71..6742bf46 100644 --- a/src/hirad/eval/eval_utils.py +++ b/src/hirad/eval/eval_utils.py @@ -174,6 +174,15 @@ def get_channel_indices(gen_cfg: dict, channels=None) -> dict: } +def make_percentile_values(per_decade: int = 20) -> np.ndarray: + """Percentiles sampled equidistantly on the logit (log-exceedance) axis.""" + tail = np.logspace(-3, 1, 4 * per_decade + 1) # 0.001 ... 10 + lower = tail # low tail: 0.001 ... 10 + upper = 100.0 - tail[::-1] # high tail: 90 ... 99.999 + center = np.linspace(10.0, 90.0, 2 * per_decade + 1) + return np.unique(np.concatenate([lower, center, upper])) + + def resolve_ts_dir(out_root: Path, ts: str) -> Path: """Return the directory under *out_root* that contains the timestamp folder *ts*.""" if (out_root / ts).is_dir(): From 888d686512b1459a0436e22e695d0695c2077112 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Tue, 16 Jun 2026 12:27:46 +0200 Subject: [PATCH 38/51] also show 100 mm/h --- src/hirad/eval/bias_by_percentile_precip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hirad/eval/bias_by_percentile_precip.py b/src/hirad/eval/bias_by_percentile_precip.py index 4097442b..5f1f68d6 100644 --- a/src/hirad/eval/bias_by_percentile_precip.py +++ b/src/hirad/eval/bias_by_percentile_precip.py @@ -33,7 +33,7 @@ def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None, ax2.set_xlim(xlim_left, frac[-1]) nice_mmh = np.array([0.01, 0.1, 1.0, 10.0, 100.0]) tick_positions = np.interp(nice_mmh, mean_q, frac) - valid = (tick_positions > xlim_left) & (tick_positions < frac[-1]) + valid = (tick_positions > xlim_left) & (tick_positions <= frac[-1]) ax2.set_xticks(tick_positions[valid]) ax2.set_xticklabels([f'{v:g}' for v in nice_mmh[valid]]) ax2.set_xlabel('Mean target [mm/h]') From 993a3d4eae3c91473e77c40c02e3592a147b11f3 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Tue, 16 Jun 2026 17:39:06 +0200 Subject: [PATCH 39/51] fix label --- src/hirad/eval/bias_by_percentile_precip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hirad/eval/bias_by_percentile_precip.py b/src/hirad/eval/bias_by_percentile_precip.py index 5f1f68d6..15170cc3 100644 --- a/src/hirad/eval/bias_by_percentile_precip.py +++ b/src/hirad/eval/bias_by_percentile_precip.py @@ -110,7 +110,7 @@ def _make_hist_bins(cfg: dict) -> np.ndarray: bias_ylabel='Bias [mm/h]', mae_ylabel='MAE [mm/h]', spread_ylabel='Ensemble Spread [mm/h]', - fbi_ylabel='FBI [-]', + fbi_ylabel='FBI (exceedance)', percentile_values=make_percentile_values(), resolve_channels=_resolve_channels, make_hist_bins=_make_hist_bins, From daf3d9cc1600070ce6050d8fb96eee8b07b8b716 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Tue, 16 Jun 2026 17:43:16 +0200 Subject: [PATCH 40/51] add tick on the left --- src/hirad/eval/bias_by_percentile_precip.py | 11 +++++++++-- src/hirad/eval/bias_by_percentile_wind.py | 12 ++++++++++-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/hirad/eval/bias_by_percentile_precip.py b/src/hirad/eval/bias_by_percentile_precip.py index 15170cc3..3f2ab6e7 100644 --- a/src/hirad/eval/bias_by_percentile_precip.py +++ b/src/hirad/eval/bias_by_percentile_precip.py @@ -34,8 +34,15 @@ def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None, nice_mmh = np.array([0.01, 0.1, 1.0, 10.0, 100.0]) tick_positions = np.interp(nice_mmh, mean_q, frac) valid = (tick_positions > xlim_left) & (tick_positions <= frac[-1]) - ax2.set_xticks(tick_positions[valid]) - ax2.set_xticklabels([f'{v:g}' for v in nice_mmh[valid]]) + positions = list(tick_positions[valid]) + labels = [f'{v:g}' for v in nice_mmh[valid]] + # Ensure a labelled tick at the left-hand edge of the visible range. + if not positions or positions[0] > xlim_left: + left_mmh = np.interp(xlim_left, frac, mean_q) + positions.insert(0, xlim_left) + labels.insert(0, f'{left_mmh:.2g}') + ax2.set_xticks(positions) + ax2.set_xticklabels(labels) ax2.set_xlabel('Mean target [mm/h]') diff --git a/src/hirad/eval/bias_by_percentile_wind.py b/src/hirad/eval/bias_by_percentile_wind.py index 58bf1961..2220303e 100644 --- a/src/hirad/eval/bias_by_percentile_wind.py +++ b/src/hirad/eval/bias_by_percentile_wind.py @@ -9,6 +9,7 @@ from hirad.eval.bias_by_percentile_common import ( BiasByPercentileSpec, + _round_sig, even_value_ticks, finalize_percentile_plot, new_percentile_axes, @@ -40,8 +41,15 @@ def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None, ax2.set_xlim(xlim_left, xlim_right) tick_positions, tick_speeds = even_value_ticks(frac, mean_q) valid = (tick_positions >= xlim_left) & (tick_positions <= xlim_right) - ax2.set_xticks(tick_positions[valid]) - ax2.set_xticklabels([f'{v:g}' for v in tick_speeds[valid]]) + positions = list(tick_positions[valid]) + speeds = list(tick_speeds[valid]) + # Ensure a labelled tick at the left-hand edge of the visible range. + left_speed = _round_sig(float(np.interp(xlim_left, frac, mean_q))) + if not positions or positions[0] > xlim_left: + positions.insert(0, xlim_left) + speeds.insert(0, left_speed) + ax2.set_xticks(positions) + ax2.set_xticklabels([f'{v:g}' for v in speeds]) ax2.set_xlabel('Mean target [m/s]') From ce2f253c01b77bef5eb41286bbb9b2d595906fdb Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Tue, 16 Jun 2026 18:40:43 +0200 Subject: [PATCH 41/51] use bin centres --- src/hirad/eval/bias_by_percentile_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hirad/eval/bias_by_percentile_common.py b/src/hirad/eval/bias_by_percentile_common.py index fa6ca64d..2a71f86e 100644 --- a/src/hirad/eval/bias_by_percentile_common.py +++ b/src/hirad/eval/bias_by_percentile_common.py @@ -136,7 +136,7 @@ def per_point_quantiles(pp_counts: np.ndarray, bin_edges: np.ndarray, n_land, n_bins = pp_counts.shape P = len(frac_percentiles) result = np.empty((n_land, P), dtype=np.float32) - edges_upper = bin_edges[1:].astype(np.float32) + bin_centers = (0.5 * (bin_edges[:-1] + bin_edges[1:])).astype(np.float32) frac_f64 = frac_percentiles.astype(np.float64) for start in range(0, n_land, block_size): @@ -155,7 +155,7 @@ def per_point_quantiles(pp_counts: np.ndarray, bin_edges: np.ndarray, idx = np.searchsorted(cdf.ravel(), queries.ravel(), side='left') idx = idx.reshape(B, P) - (np.arange(B, dtype=np.intp)[:, None] * n_bins) np.clip(idx, 0, n_bins - 1, out=idx) - result[start:end] = edges_upper[idx] + result[start:end] = bin_centers[idx] return result From da16bc66ee20203b94246e0250f777beeb43b62a Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 17 Jun 2026 14:52:56 +0200 Subject: [PATCH 42/51] plot difference maps --- src/hirad/eval/eval_utils.py | 9 ++++ src/hirad/eval/map_precip_stats.py | 52 ++++++++++++++++++++++- src/hirad/eval/map_temp_stats.py | 54 +++++++++++++++++++++++- src/hirad/eval/map_wind_stats.py | 68 ++++++++++++++++++++++++++++-- src/hirad/eval/plotting.py | 34 +++++++++++++++ 5 files changed, 212 insertions(+), 5 deletions(-) diff --git a/src/hirad/eval/eval_utils.py b/src/hirad/eval/eval_utils.py index 6742bf46..8ec0ad3e 100644 --- a/src/hirad/eval/eval_utils.py +++ b/src/hirad/eval/eval_utils.py @@ -193,6 +193,15 @@ def resolve_ts_dir(out_root: Path, ts: str) -> Path: raise FileNotFoundError(f"Timestamp directory {ts} not found under {out_root}") +def signed_circular_difference(prediction: np.ndarray, target: np.ndarray, period: float = 360.0) -> np.ndarray: + """Return signed wrapped difference on a circular domain. + + For angles in degrees, this yields values in [-180, 180). + """ + half_period = period / 2.0 + return ((prediction - target + half_period) % period) - half_period + + def parse_eval_cli(allow_times: bool = False) -> dict: """Parse standard eval CLI args (``--config-name``) and return the loaded YAML config. diff --git a/src/hirad/eval/map_precip_stats.py b/src/hirad/eval/map_precip_stats.py index 66defe4d..685a8889 100644 --- a/src/hirad/eval/map_precip_stats.py +++ b/src/hirad/eval/map_precip_stats.py @@ -9,7 +9,7 @@ from hirad.eval.eval_utils import get_channel_indices, grid_cfg_from_cfg, load_generation_setup, parse_eval_cli, resolve_ts_dir from hirad.eval.plotting import ( - plot_map_precipitation, plot_map + plot_difference_map, plot_map, plot_map_precipitation ) @@ -128,6 +128,14 @@ def plot_stat_map(data, filename, stat_config, label, grid_cfg): ) +def _difference_label(stat_type): + if stat_type == 'weth_freq': + return 'Difference [%]' + if stat_type in ('cdd', 'cwd'): + return 'Difference [days]' + return 'Difference [mm/day]' + + def main(cfg: dict): logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -178,6 +186,8 @@ def main(cfg: dict): 'regression-prediction': (tp_out, 'Regression Prediction') } + mode_results = {} + for mode, (tp_channel, label) in basic_modes.items(): logger.info(f"Processing mode: {mode}") data_list = [] @@ -195,15 +205,40 @@ def main(cfg: dict): mode_data = np.stack(data_list, axis=0).astype(np.float32) del data_list + mode_results[mode] = {} for stat_config in stat_configs: logger.info(f"Computing {stat_config['title_stat']} for {mode}...") result = apply_statistic(mode_data, times_dt, stat_config['type'], stat_config['param'], wet_threshold) + mode_results[mode][stat_config['stat_name']] = result map_output_dir = output_path / f"maps_precip_{stat_config['stat_name']}" map_output_dir.mkdir(parents=True, exist_ok=True) plot_stat_map(result, str(map_output_dir / f'{mode}_{stat_config["stat_name"]}'), stat_config, label, grid_cfg) del mode_data + target_results = mode_results.get('target') + if target_results is None: + logger.warning("Target mode not available; skipping prediction-minus-target difference maps for basic modes") + else: + for mode, (_, label) in basic_modes.items(): + if mode == 'target' or mode not in mode_results: + continue + logger.info(f"Generating {mode} minus target difference maps") + for stat_config in stat_configs: + stat_name = stat_config['stat_name'] + if stat_name not in mode_results[mode] or stat_name not in target_results: + continue + diff = mode_results[mode][stat_name] - target_results[stat_name] + map_output_dir = output_path / f"maps_precip_{stat_name}" + map_output_dir.mkdir(parents=True, exist_ok=True) + plot_difference_map( + diff, + str(map_output_dir / f'{mode}_minus_target_{stat_name}'), + title=f'{label} - Target: {stat_config["title_stat"]} Difference', + label=_difference_label(stat_config['type']), + grid_cfg=grid_cfg, + ) + # --- Predictions: process ONE member at a time to bound memory usage --- logger.info("Processing predictions mode...") sample_data = torch.load(resolve_ts_dir(out_root, times[0]) / times[0] / f"{times[0]}-predictions", weights_only=False) @@ -214,6 +249,9 @@ def main(cfg: dict): H: int = cfg["height"] W: int = cfg["width"] member_data = np.empty((len(times), H, W), dtype=np.float32) + has_target_for_diff = target_results is not None + if not has_target_for_diff: + logger.warning("Target mode not available; skipping prediction-minus-target difference maps for members") for member_idx in range(n_members): logger.info(f"Loading prediction member {member_idx+1}/{n_members} (single pass over files)...") @@ -234,6 +272,18 @@ def main(cfg: dict): member_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_{stat_config["stat_name"]}') member_label = f'CorrDiff Member {member_idx+1}' plot_stat_map(member_result, member_filename, stat_config, member_label, grid_cfg) + if has_target_for_diff: + target_result = target_results.get(stat_config['stat_name']) + if target_result is not None: + diff = member_result - target_result + diff_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_minus_target_{stat_config["stat_name"]}') + plot_difference_map( + diff, + diff_filename, + title=f'CorrDiff Member {member_idx+1} - Target: {stat_config["title_stat"]} Difference', + label=_difference_label(stat_config['type']), + grid_cfg=grid_cfg, + ) del member_data logger.info("All precipitation statistics maps generated successfully") diff --git a/src/hirad/eval/map_temp_stats.py b/src/hirad/eval/map_temp_stats.py index b3ae12a9..a3efdd6a 100644 --- a/src/hirad/eval/map_temp_stats.py +++ b/src/hirad/eval/map_temp_stats.py @@ -8,7 +8,7 @@ import numba from hirad.eval.eval_utils import get_channel_indices, grid_cfg_from_cfg, load_generation_setup, parse_eval_cli, resolve_ts_dir -from hirad.eval.plotting import plot_map, plot_map_temperature +from hirad.eval.plotting import plot_difference_map, plot_map, plot_map_temperature @numba.njit @@ -170,6 +170,16 @@ def plot_temp_stat_map(data, filename, stat_config, label, grid_cfg): ) +def _difference_label(stat_type): + if stat_type in ('mean', 'quantile', 'max', 'min', 'std', 'dtr'): + return 'Difference [°C]' + if stat_type in ('warm_days', 'hot_days', 'frost_days', 'ice_days', 'tropical_nights'): + return 'Difference [% of days]' + if stat_type in ('warm_spell', 'cold_spell'): + return 'Difference [days]' + return 'Difference' + + def main(cfg: dict): logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -239,6 +249,8 @@ def main(cfg: dict): 'regression-prediction': (t2m_out, 'Regression Prediction'), } + mode_results = {} + for mode, (t2m_channel, label) in basic_modes.items(): logger.info(f"Processing mode: {mode}") data_list = [] @@ -257,15 +269,40 @@ def main(cfg: dict): mode_data = np.stack(data_list, axis=0).astype(np.float32) del data_list + mode_results[mode] = {} for stat_config in stat_configs: logger.info(f"Computing {stat_config['title_stat']} for {mode}...") result = apply_statistic(mode_data, times_dt, stat_config['type'], stat_config['param']) + mode_results[mode][stat_config['stat_name']] = result map_output_dir = output_path / f"maps_temp_{stat_config['stat_name']}" map_output_dir.mkdir(parents=True, exist_ok=True) plot_temp_stat_map(result, str(map_output_dir / f'{mode}_{stat_config["stat_name"]}'), stat_config, label, grid_cfg) del mode_data + target_results = mode_results.get('target') + if target_results is None: + logger.warning("Target mode not available; skipping prediction-minus-target difference maps for basic modes") + else: + for mode, (_, label) in basic_modes.items(): + if mode == 'target' or mode not in mode_results: + continue + logger.info(f"Generating {mode} minus target difference maps") + for stat_config in stat_configs: + stat_name = stat_config['stat_name'] + if stat_name not in mode_results[mode] or stat_name not in target_results: + continue + diff = mode_results[mode][stat_name] - target_results[stat_name] + map_output_dir = output_path / f"maps_temp_{stat_name}" + map_output_dir.mkdir(parents=True, exist_ok=True) + plot_difference_map( + diff, + str(map_output_dir / f'{mode}_minus_target_{stat_name}'), + title=f'{label} - Target: {stat_config["title_stat"]} Difference', + label=_difference_label(stat_config['type']), + grid_cfg=grid_cfg, + ) + # --- Predictions: process ONE member at a time to bound memory usage --- logger.info("Processing predictions mode...") try: @@ -283,6 +320,9 @@ def main(cfg: dict): H: int = cfg["height"] W: int = cfg["width"] member_data = np.empty((len(times), H, W), dtype=np.float32) + has_target_for_diff = target_results is not None + if not has_target_for_diff: + logger.warning("Target mode not available; skipping prediction-minus-target difference maps for members") for member_idx in range(n_members): logger.info(f"Loading prediction member {member_idx+1}/{n_members} (single pass over files)...") @@ -307,6 +347,18 @@ def main(cfg: dict): member_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_{stat_config["stat_name"]}') member_label = f'CorrDiff Member {member_idx+1}' plot_temp_stat_map(member_result, member_filename, stat_config, member_label, grid_cfg) + if has_target_for_diff: + target_result = target_results.get(stat_config['stat_name']) + if target_result is not None: + diff = member_result - target_result + diff_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_minus_target_{stat_config["stat_name"]}') + plot_difference_map( + diff, + diff_filename, + title=f'CorrDiff Member {member_idx+1} - Target: {stat_config["title_stat"]} Difference', + label=_difference_label(stat_config['type']), + grid_cfg=grid_cfg, + ) del member_data logger.info("All 2m temperature statistics maps generated successfully") diff --git a/src/hirad/eval/map_wind_stats.py b/src/hirad/eval/map_wind_stats.py index a6367b10..8428ca0e 100644 --- a/src/hirad/eval/map_wind_stats.py +++ b/src/hirad/eval/map_wind_stats.py @@ -7,8 +7,8 @@ from hirad.datasets import get_channels_from_strings, get_strings_from_channels from hirad.utils.function_utils import get_time_from_range -from hirad.eval.eval_utils import get_channel_indices, grid_cfg_from_cfg, load_generation_setup, parse_eval_cli, resolve_ts_dir -from hirad.eval.plotting import plot_map +from hirad.eval.eval_utils import get_channel_indices, grid_cfg_from_cfg, load_generation_setup, parse_eval_cli, resolve_ts_dir, signed_circular_difference +from hirad.eval.plotting import plot_difference_map, plot_map def compute_wind_speed(u, v): @@ -199,6 +199,20 @@ def plot_wind_stat_map(data, filename, stat_config, label, grid_cfg): ) +def _wind_difference_label(stat_type): + if stat_type in ['mean_speed', 'max_speed']: + return 'Difference [m/s]' + if stat_type == 'wind_power': + return 'Difference [m^3/s^3]' + if stat_type in ['calm_freq', 'light_breeze_freq', 'moderate_breeze_freq', 'strong_breeze_freq', 'gale_freq']: + return 'Difference [%]' + if stat_type in ['prevailing_direction', 'direction_variability']: + return 'Difference [degrees]' + if stat_type in ['mean_u', 'mean_v']: + return 'Difference [m/s]' + return 'Difference' + + def main(cfg: dict): logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -294,6 +308,7 @@ def main(cfg: dict): logger.info(f"Generating {len(stat_configs)} statistics for {len(basic_modes)} modes + predictions") log_interval = cfg.get("log_interval", 100) + mode_results = {} for mode, (wind_channels, label) in basic_modes.items(): logger.info(f"Processing mode: {mode}") @@ -323,11 +338,42 @@ def main(cfg: dict): label, grid_cfg ) - del results + mode_results[mode] = results except Exception as e: logger.error(f"Failed computing statistics for {mode}: {e}") continue + target_results = mode_results.get('target') + if target_results is None: + logger.warning("Target mode not available; skipping prediction-minus-target difference maps for basic modes") + else: + for mode, (_, label) in basic_modes.items(): + if mode == 'target' or mode not in mode_results: + continue + logger.info(f"Generating {mode} minus target difference maps") + for stat_config in stat_configs: + stat_key = stat_config['stat_name'] + if stat_key not in mode_results[mode] or stat_key not in target_results: + continue + if stat_config['type'] == 'prevailing_direction': + diff = signed_circular_difference(mode_results[mode][stat_key], target_results[stat_key]) + else: + diff = mode_results[mode][stat_key] - target_results[stat_key] + map_output_dir = output_path / f"maps_wind_{stat_key}" + map_output_dir.mkdir(parents=True, exist_ok=True) + plot_difference_map( + diff, + str(map_output_dir / f'{mode}_minus_target_{stat_key}'), + title=f'{label} - Target: {stat_config["title_stat"]} Difference', + label=_wind_difference_label(stat_config['type']), + grid_cfg=grid_cfg, + fixed_vmax=180.0 if stat_config['type'] == 'prevailing_direction' else None, + ) + + has_target_for_diff = target_results is not None + if not has_target_for_diff: + logger.warning("Target mode not available; skipping prediction-minus-target difference maps for members") + logger.info("Processing predictions mode...") try: @@ -469,6 +515,22 @@ def main(cfg: dict): map_output_dir.mkdir(parents=True, exist_ok=True) member_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_{stat_key}') plot_wind_stat_map(member_result, member_filename, stat_config, f'CorrDiff Member {member_idx+1}', grid_cfg) + if has_target_for_diff: + target_result = target_results.get(stat_key) + if target_result is not None: + if stype == 'prevailing_direction': + diff = signed_circular_difference(member_result, target_result) + else: + diff = member_result - target_result + diff_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_minus_target_{stat_key}') + plot_difference_map( + diff, + diff_filename, + title=f'CorrDiff Member {member_idx+1} - Target: {stat_config["title_stat"]} Difference', + label=_wind_difference_label(stype), + grid_cfg=grid_cfg, + fixed_vmax=180.0 if stype == 'prevailing_direction' else None, + ) del member_result except Exception as e: logger.error(f"Failed {stat_config['title_stat']} for member {member_idx+1}: {e}") diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index 11204f63..2d691093 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -1,4 +1,5 @@ import logging +from typing import Optional import cartopy.crs as ccrs import cartopy.feature as cfeature @@ -79,6 +80,39 @@ def plot_map(values: np.array, fig.savefig(f"{filename}.png", dpi=300, bbox_inches="tight") plt.close(fig) + +def compute_symmetric_vmax(values: np.ndarray, percentile: float = 99.0, fallback: float = 1.0) -> float: + """Return a robust symmetric colorbar half-range based on absolute values.""" + vmax = float(np.nanpercentile(np.abs(values), percentile)) + if vmax != vmax or vmax <= 0: + return fallback + return vmax + + +def plot_difference_map( + values: np.ndarray, + filename: str, + title: str = '', + label: str = 'Difference', + grid_cfg: GridConfig = DEFAULT_GRID_CONFIG, + cmap: str = 'RdBu_r', + percentile: float = 99.0, + fixed_vmax: Optional[float] = None, +): + """Plot a difference map with symmetric diverging bounds around zero.""" + vmax = fixed_vmax if fixed_vmax is not None else compute_symmetric_vmax(values, percentile=percentile) + plot_map( + values, + filename, + title=title, + label=label, + vmin=-vmax, + vmax=vmax, + cmap=cmap, + extend='both', + grid_cfg=grid_cfg, + ) + def plot_map_precipitation(values, filename, title='', threshold=0.01, rfac=1000.0, grid_cfg=DEFAULT_GRID_CONFIG): """Plot precipitation data with specific colormap and thresholds.""" # Scale and mask values below threshold From 82f70fa43f305eb3e201ab46f09d70cf9c7493fe Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 17 Jun 2026 15:01:02 +0200 Subject: [PATCH 43/51] add more levels --- src/hirad/eval/plotting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index 2d691093..f96a247e 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -147,9 +147,9 @@ def plot_map_temperature(values, filename, title='', grid_cfg=DEFAULT_GRID_CONFI "#DEE699", "#A6D473", "#6BBF4D", "#33AB26", "#009900", "#33B300", "#66CC00", "#99E600", "#CCFF00", "#FFFF00", "#FFCC00", "#FF9900", "#FF6600", "#FF3300", "#FF0000", - "#EB00EB", "#FF40FF", "#FF80FF", "#FFBFFF", + "#EB00EB", "#FF40FF", "#FF80FF", "#FFBFFF", "#FFE0FF", "#FFF5FF", ] - bounds = [-9, -7, -5, -4, -2, 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38] + bounds = [-9, -7, -5, -4, -2, 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42] cmap = ListedColormap(colors) norm = BoundaryNorm(bounds, ncolors=len(colors), clip=False) From db938c4a35a455a8c57d742a87bed1c1886e137e Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 17 Jun 2026 15:01:22 +0200 Subject: [PATCH 44/51] reduce range --- src/hirad/eval/map_temp_stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hirad/eval/map_temp_stats.py b/src/hirad/eval/map_temp_stats.py index a3efdd6a..250c806f 100644 --- a/src/hirad/eval/map_temp_stats.py +++ b/src/hirad/eval/map_temp_stats.py @@ -145,7 +145,7 @@ def plot_temp_stat_map(data, filename, stat_config, label, grid_cfg): data, filename, title=title, label='Frequency [% of days]', - vmin=0, vmax=50, cmap='OrRd', extend='neither', grid_cfg=grid_cfg + vmin=0, vmax=30, cmap='OrRd', extend='neither', grid_cfg=grid_cfg ) elif stype == 'warm_spell': plot_map( From 6ffa8402222bdc8748c8d73a764ed37c4100108e Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 17 Jun 2026 15:07:35 +0200 Subject: [PATCH 45/51] add hot spell --- src/hirad/eval/map_temp_stats.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/hirad/eval/map_temp_stats.py b/src/hirad/eval/map_temp_stats.py index 250c806f..36ae74c9 100644 --- a/src/hirad/eval/map_temp_stats.py +++ b/src/hirad/eval/map_temp_stats.py @@ -75,7 +75,7 @@ def apply_statistic(data_np, times_dt, stat_type, stat_param=None): # For daily-based indices, build daily aggregations using xarray if stat_type in ('warm_days', 'hot_days', 'frost_days', 'ice_days', 'tropical_nights', - 'dtr', 'warm_spell', 'cold_spell'): + 'dtr', 'warm_spell', 'hot_spell', 'cold_spell'): da = xr.DataArray( data_np, dims=['time', 'lat', 'lon'], coords={'time': times_dt} @@ -112,6 +112,10 @@ def apply_statistic(data_np, times_dt, stat_type, stat_param=None): # WSDI-like: longest consecutive run of days with daily max > 25 °C return consecutive_spell(daily_max, lambda x: x > 25.0) + if stat_type == 'hot_spell': + # longest consecutive run of days with daily max > 35 °C + return consecutive_spell(daily_max, lambda x: x > 35.0) + if stat_type == 'cold_spell': # CSDI-like: longest consecutive run of days with daily min < 0 °C return consecutive_spell(daily_min, lambda x: x < 0.0) @@ -154,6 +158,13 @@ def plot_temp_stat_map(data, filename, stat_config, label, grid_cfg): label='Days', vmin=0, vmax=92, cmap='YlOrRd', extend='max', grid_cfg=grid_cfg ) + elif stype == 'hot_spell': + plot_map( + data, filename, + title=title, + label='Days', + vmin=0, vmax=30, cmap='YlOrRd', extend='max', grid_cfg=grid_cfg + ) elif stype == 'cold_spell': plot_map( data, filename, @@ -175,7 +186,7 @@ def _difference_label(stat_type): return 'Difference [°C]' if stat_type in ('warm_days', 'hot_days', 'frost_days', 'ice_days', 'tropical_nights'): return 'Difference [% of days]' - if stat_type in ('warm_spell', 'cold_spell'): + if stat_type in ('warm_spell', 'hot_spell', 'cold_spell'): return 'Difference [days]' return 'Difference' @@ -235,6 +246,7 @@ def main(cfg: dict): 'tropical_nights': {'type': 'tropical_nights', 'title': 'Tropical Nights (daily min > 20°C)'}, 'dtr': {'type': 'dtr', 'title': 'Mean Diurnal Temperature Range (DTR)'}, 'warm_spell': {'type': 'warm_spell', 'title': 'Warm Spell Duration (daily max > 25°C)'}, + 'hot_spell': {'type': 'hot_spell', 'title': 'Hot Spell Duration (daily max > 35°C)'}, 'cold_spell': {'type': 'cold_spell', 'title': 'Cold Spell Duration (daily min < 0°C)'}, } stat_configs = [ From f30566d0607fe581adeb3d40055dd9f5f0213c51 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 17 Jun 2026 15:10:08 +0200 Subject: [PATCH 46/51] makw compatible with ETCCDI --- src/hirad/eval/map_temp_stats.py | 51 +++++++++++++++++++++++++++----- 1 file changed, 44 insertions(+), 7 deletions(-) diff --git a/src/hirad/eval/map_temp_stats.py b/src/hirad/eval/map_temp_stats.py index 36ae74c9..4dff7a10 100644 --- a/src/hirad/eval/map_temp_stats.py +++ b/src/hirad/eval/map_temp_stats.py @@ -49,6 +49,39 @@ def consecutive_spell(data_np, condition_fn): return _consecutive_spell_2d(cond) +@numba.njit +def _count_spell_days(x, min_run): + """Count total days belonging to runs of at least min_run consecutive True values.""" + n = x.shape[0] + total = 0 + i = 0 + while i < n: + if x[i]: + run_start = i + while i < n and x[i]: + i += 1 + run_len = i - run_start + if run_len >= min_run: + total += run_len + else: + i += 1 + return total + + +@numba.njit(parallel=True) +def _spell_days_2d(condition_3d, min_run): + """ + condition_3d: bool array of shape (T, H, W). + Returns int array of shape (H, W) with total spell days per grid point. + """ + T, H, W = condition_3d.shape + out = np.empty((H, W), dtype=np.int64) + for i in numba.prange(H): + for j in range(W): + out[i, j] = _count_spell_days(condition_3d[:, i, j], min_run) + return out + + def apply_statistic(data_np, times_dt, stat_type, stat_param=None): """ Apply temperature statistic on array containing time sequence of 2m temperature maps. @@ -109,16 +142,20 @@ def apply_statistic(data_np, times_dt, stat_type, stat_param=None): return np.mean(daily_max - daily_min, axis=0) if stat_type == 'warm_spell': - # WSDI-like: longest consecutive run of days with daily max > 25 °C - return consecutive_spell(daily_max, lambda x: x > 25.0) + # WSDI: total days in spells ≥6 consecutive days with TX > 90th percentile + p90 = np.percentile(daily_max, 90, axis=0) # (H, W) + cond = np.ascontiguousarray(daily_max > p90[np.newaxis, :, :]) + return _spell_days_2d(cond, 6).astype(np.float32) if stat_type == 'hot_spell': - # longest consecutive run of days with daily max > 35 °C + # longest consecutive run of days with daily max > 35 °C (custom, non-ETCCDI) return consecutive_spell(daily_max, lambda x: x > 35.0) if stat_type == 'cold_spell': - # CSDI-like: longest consecutive run of days with daily min < 0 °C - return consecutive_spell(daily_min, lambda x: x < 0.0) + # CSDI: total days in spells ≥6 consecutive days with TN < 10th percentile + p10 = np.percentile(daily_min, 10, axis=0) # (H, W) + cond = np.ascontiguousarray(daily_min < p10[np.newaxis, :, :]) + return _spell_days_2d(cond, 6).astype(np.float32) raise ValueError(f"Unsupported temperature statistic type: {stat_type}") @@ -245,9 +282,9 @@ def main(cfg: dict): 'ice_days': {'type': 'ice_days', 'title': 'Ice Days (daily max < 0°C)'}, 'tropical_nights': {'type': 'tropical_nights', 'title': 'Tropical Nights (daily min > 20°C)'}, 'dtr': {'type': 'dtr', 'title': 'Mean Diurnal Temperature Range (DTR)'}, - 'warm_spell': {'type': 'warm_spell', 'title': 'Warm Spell Duration (daily max > 25°C)'}, + 'warm_spell': {'type': 'warm_spell', 'title': 'WSDI: Warm Spell Duration (TX > 90th pct, ≥6 days)'}, 'hot_spell': {'type': 'hot_spell', 'title': 'Hot Spell Duration (daily max > 35°C)'}, - 'cold_spell': {'type': 'cold_spell', 'title': 'Cold Spell Duration (daily min < 0°C)'}, + 'cold_spell': {'type': 'cold_spell', 'title': 'CSDI: Cold Spell Duration (TN < 10th pct, ≥6 days)'}, } stat_configs = [ {'stat_name': name, 'title_stat': config['title'], 'param': config.get('param'), **config} From 8f93676ac8e37ca10e778f9f205366e2f9c8c120 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 17 Jun 2026 15:21:17 +0200 Subject: [PATCH 47/51] use 50th percentile consistently for precip --- src/hirad/eval/bias_by_percentile_precip.py | 24 ++++++++++----------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/hirad/eval/bias_by_percentile_precip.py b/src/hirad/eval/bias_by_percentile_precip.py index 3f2ab6e7..9af35e14 100644 --- a/src/hirad/eval/bias_by_percentile_precip.py +++ b/src/hirad/eval/bias_by_percentile_precip.py @@ -14,15 +14,14 @@ from hirad.eval.eval_utils import make_percentile_values, parse_eval_cli -def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None, - xlim_left: float = 0.5) -> None: +def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) -> None: """Apply logit x-axis with labelled percentile ticks (and a mm/h secondary axis).""" ax.set_xscale('logit') - ax.set_xlim(xlim_left, frac[-1]) - all_tick_fracs = [0.10, 0.25, 0.50, 0.75, 0.90, 0.99, 0.999, 0.9999, 0.99999] - all_tick_labels = ['10', '25', '50', '75', '90', '99', '99.9', '99.99', '99.999'] + ax.set_xlim(0.5, frac[-1]) + all_tick_fracs = [0.50, 0.75, 0.90, 0.99, 0.999, 0.9999, 0.99999] + all_tick_labels = ['50', '75', '90', '99', '99.9', '99.99', '99.999'] valid_ticks = [(f, l) for f, l in zip(all_tick_fracs, all_tick_labels) - if xlim_left <= f <= frac[-1]] + if f <= frac[-1]] ax.set_xticks([f for f, _ in valid_ticks]) ax.set_xticklabels([l for _, l in valid_ticks]) ax.grid(True, alpha=0.3, which='both') @@ -30,16 +29,16 @@ def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None, if mean_q is not None: ax2 = ax.twiny() ax2.set_xscale('logit') - ax2.set_xlim(xlim_left, frac[-1]) + ax2.set_xlim(0.5, frac[-1]) nice_mmh = np.array([0.01, 0.1, 1.0, 10.0, 100.0]) tick_positions = np.interp(nice_mmh, mean_q, frac) - valid = (tick_positions > xlim_left) & (tick_positions <= frac[-1]) + valid = (tick_positions > 0.5) & (tick_positions <= frac[-1]) positions = list(tick_positions[valid]) labels = [f'{v:g}' for v in nice_mmh[valid]] # Ensure a labelled tick at the left-hand edge of the visible range. - if not positions or positions[0] > xlim_left: - left_mmh = np.interp(xlim_left, frac, mean_q) - positions.insert(0, xlim_left) + if not positions or positions[0] > 0.5: + left_mmh = np.interp(0.5, frac, mean_q) + positions.insert(0, 0.5) labels.insert(0, f'{left_mmh:.2g}') ax2.set_xticks(positions) ax2.set_xticklabels(labels) @@ -91,8 +90,7 @@ def save_fbi_by_percentile_plot(fbi_data_dict, percentile_values, labels, colors ymax = float(max(np.nanmax(v) for v in all_vals)) * 1.5 ymin = float(min(np.nanmin(v) for v in all_vals)) / 1.5 ax.set_ylim(max(ymin, 1e-3), max(ymax, 2.0)) - finalize_percentile_plot(ax, frac, - lambda ax_, frac_, mq: _apply_logit_xaxis(ax_, frac_, mq, xlim_left=0.10), + finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, xlabel, ylabel, title, out_path) From 44f090b0716c5c028f2672060d9ab5017f5a6ef2 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 17 Jun 2026 15:37:20 +0200 Subject: [PATCH 48/51] show less --- src/hirad/eval/bias_by_percentile_precip.py | 13 +++++++------ src/hirad/eval/bias_by_percentile_temp.py | 13 +++++++------ src/hirad/eval/bias_by_percentile_wind.py | 8 ++++---- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/hirad/eval/bias_by_percentile_precip.py b/src/hirad/eval/bias_by_percentile_precip.py index 9af35e14..0e37ae2a 100644 --- a/src/hirad/eval/bias_by_percentile_precip.py +++ b/src/hirad/eval/bias_by_percentile_precip.py @@ -17,11 +17,12 @@ def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) -> None: """Apply logit x-axis with labelled percentile ticks (and a mm/h secondary axis).""" ax.set_xscale('logit') - ax.set_xlim(0.5, frac[-1]) - all_tick_fracs = [0.50, 0.75, 0.90, 0.99, 0.999, 0.9999, 0.99999] - all_tick_labels = ['50', '75', '90', '99', '99.9', '99.99', '99.999'] + xlim_right = min(float(frac[-1]), 0.9999) + ax.set_xlim(0.5, xlim_right) + all_tick_fracs = [0.50, 0.75, 0.90, 0.99, 0.999, 0.9999] + all_tick_labels = ['50', '75', '90', '99', '99.9', '99.99'] valid_ticks = [(f, l) for f, l in zip(all_tick_fracs, all_tick_labels) - if f <= frac[-1]] + if f <= xlim_right] ax.set_xticks([f for f, _ in valid_ticks]) ax.set_xticklabels([l for _, l in valid_ticks]) ax.grid(True, alpha=0.3, which='both') @@ -29,10 +30,10 @@ def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) - if mean_q is not None: ax2 = ax.twiny() ax2.set_xscale('logit') - ax2.set_xlim(0.5, frac[-1]) + ax2.set_xlim(0.5, xlim_right) nice_mmh = np.array([0.01, 0.1, 1.0, 10.0, 100.0]) tick_positions = np.interp(nice_mmh, mean_q, frac) - valid = (tick_positions > 0.5) & (tick_positions <= frac[-1]) + valid = (tick_positions > 0.5) & (tick_positions <= xlim_right) positions = list(tick_positions[valid]) labels = [f'{v:g}' for v in nice_mmh[valid]] # Ensure a labelled tick at the left-hand edge of the visible range. diff --git a/src/hirad/eval/bias_by_percentile_temp.py b/src/hirad/eval/bias_by_percentile_temp.py index 4d207b33..e7c9446c 100644 --- a/src/hirad/eval/bias_by_percentile_temp.py +++ b/src/hirad/eval/bias_by_percentile_temp.py @@ -18,13 +18,14 @@ def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) -> None: """Apply logit x-axis with labelled percentile ticks (and a °C secondary axis).""" ax.set_xscale('logit') - xlim_right = frac[-1] + 1e-9 - ax.set_xlim(frac[0], xlim_right) - tick_fracs = [0.00001, 0.0001, 0.001, 0.01, 0.1, 0.50, 0.90, 0.99, 0.999, 0.9999, 0.99999] - tick_labels = ['0.001', '0.01', '0.1', '1', '10', '50', '90', '99', '99.9', '99.99', '99.999'] + xlim_left = max(float(frac[0]), 0.0001) + xlim_right = min(float(frac[-1]), 0.9999) + ax.set_xlim(xlim_left, xlim_right) + tick_fracs = [0.0001, 0.001, 0.01, 0.1, 0.50, 0.90, 0.99, 0.999, 0.9999] + tick_labels = ['0.01', '0.1', '1', '10', '50', '90', '99', '99.9', '99.99'] # only show ticks within our data range valid_ticks = [(f, l) for f, l in zip(tick_fracs, tick_labels) - if frac[0] <= f <= xlim_right] + if xlim_left <= f <= xlim_right] ax.set_xticks([f for f, _ in valid_ticks]) ax.set_xticklabels([l for _, l in valid_ticks]) ax.grid(True, alpha=0.3, which='both') @@ -32,7 +33,7 @@ def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) - if mean_q is not None: ax2 = ax.twiny() ax2.set_xscale('logit') - ax2.set_xlim(frac[0], xlim_right) + ax2.set_xlim(xlim_left, xlim_right) # The axis is logit in percentile, so a fixed list of round temps bunches # near the median while the tails get no labels; even_value_ticks picks a # nice step and spreads labels evenly across the whole logit axis. diff --git a/src/hirad/eval/bias_by_percentile_wind.py b/src/hirad/eval/bias_by_percentile_wind.py index 2220303e..b36b6a7e 100644 --- a/src/hirad/eval/bias_by_percentile_wind.py +++ b/src/hirad/eval/bias_by_percentile_wind.py @@ -24,11 +24,11 @@ def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None, """Apply logit x-axis with labelled percentile ticks (and an m/s secondary axis).""" ax.set_xscale('logit') if xlim_left is None: - xlim_left = float(frac[0]) - xlim_right = frac[-1] + 1e-9 + xlim_left = max(float(frac[0]), 0.0001) + xlim_right = min(float(frac[-1]), 0.9999) ax.set_xlim(xlim_left, xlim_right) - tick_fracs = [0.00001, 0.0001, 0.001, 0.01, 0.1, 0.50, 0.90, 0.99, 0.999, 0.9999, 0.99999] - tick_labels = ['0.001', '0.01', '0.1', '1', '10', '50', '90', '99', '99.9', '99.99', '99.999'] + tick_fracs = [0.0001, 0.001, 0.01, 0.1, 0.50, 0.90, 0.99, 0.999, 0.9999] + tick_labels = ['0.01', '0.1', '1', '10', '50', '90', '99', '99.9', '99.99'] valid_ticks = [(f, l) for f, l in zip(tick_fracs, tick_labels) if xlim_left <= f <= xlim_right] ax.set_xticks([f for f, _ in valid_ticks]) From e0d4a5d3fec331970679dc6f5d29dabb3c09acb2 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 17 Jun 2026 15:57:54 +0200 Subject: [PATCH 49/51] consolidate logit axis --- src/hirad/eval/bias_by_percentile_common.py | 97 +++++++++++++++++++-- src/hirad/eval/bias_by_percentile_precip.py | 42 +++------ src/hirad/eval/bias_by_percentile_temp.py | 29 +----- src/hirad/eval/bias_by_percentile_wind.py | 36 +------- 4 files changed, 110 insertions(+), 94 deletions(-) diff --git a/src/hirad/eval/bias_by_percentile_common.py b/src/hirad/eval/bias_by_percentile_common.py index 2a71f86e..6f2c6208 100644 --- a/src/hirad/eval/bias_by_percentile_common.py +++ b/src/hirad/eval/bias_by_percentile_common.py @@ -3,7 +3,7 @@ import logging from dataclasses import dataclass from pathlib import Path -from typing import Callable +from typing import Callable, Optional import matplotlib.pyplot as plt import matplotlib.ticker as mticker @@ -294,12 +294,6 @@ def _ensemble_stats(member_qs, member_exc, target_q, frac_percentiles): return spread, mae_entry, bias_entry, fbi_entry -def new_percentile_axes(percentile_values: np.ndarray): - """Create a figure/axes pair and return it together with the fractional x-values.""" - fig, ax = plt.subplots(figsize=(10, 6)) - return fig, ax, percentile_values / 100.0 - - def _round_sig(x: float, sig: int = 2) -> float: """Round *x* to *sig* significant figures (clean axis labels).""" if x == 0 or not np.isfinite(x): @@ -341,6 +335,95 @@ def _expit(z): return sample_pos[keep], rounded[keep] +# Percentile tick presets for a logit x-axis, as (fraction, label) pairs. +LOGIT_PERCENTILE_TICKS_FINE = ( + (0.0001, '0.01'), (0.001, '0.1'), (0.01, '1'), (0.1, '10'), + (0.50, '50'), (0.90, '90'), (0.99, '99'), (0.999, '99.9'), (0.9999, '99.99'), +) +LOGIT_PERCENTILE_TICKS_TAIL = ( + (0.50, '50'), (0.75, '75'), (0.90, '90'), + (0.99, '99'), (0.999, '99.9'), (0.9999, '99.99'), +) + + +def apply_logit_percentile_xaxis( + ax, + frac: np.ndarray, + mean_q: Optional[np.ndarray] = None, + *, + xlim_left: Optional[float] = None, + percentile_ticks=LOGIT_PERCENTILE_TICKS_FINE, + secondary_label: Optional[str] = None, + secondary_values: Optional[np.ndarray] = None, +) -> None: + """Configure a logit percentile x-axis with an optional physical-unit top axis. + + Shared by the bias / MAE / spread / FBI "by percentile" plots. The primary + axis carries labelled percentile ticks on a logit scale; when ``mean_q`` is + supplied, a secondary top axis labelled in the variable's physical units is + added. + + Parameters + ---------- + ax : matplotlib Axes to configure. + frac : fractional percentile positions (0-1) used as the x-coordinates. + mean_q : per-percentile mean target value; enables the secondary top axis. + xlim_left : left x-limit; defaults to ``max(frac[0], 1e-4)``. + percentile_ticks : (fraction, label) pairs for the primary percentile axis. + secondary_label : axis label for the top axis (e.g. ``'Mean target [m/s]'``). + secondary_values : fixed "nice" physical values to place on the top axis + (e.g. precipitation rates). When omitted, ticks are spread evenly along + the logit axis via :func:`even_value_ticks`. + """ + ax.set_xscale('logit') + if xlim_left is None: + xlim_left = max(float(frac[0]), 0.0001) + xlim_right = min(float(frac[-1]), 0.9999) + ax.set_xlim(xlim_left, xlim_right) + + visible = [(f, l) for f, l in percentile_ticks if xlim_left <= f <= xlim_right] + ax.set_xticks([f for f, _ in visible]) + ax.set_xticklabels([l for _, l in visible]) + ax.grid(True, alpha=0.3, which='both') + + if mean_q is None: + return + + ax2 = ax.twiny() + ax2.set_xscale('logit') + ax2.set_xlim(xlim_left, xlim_right) + + if secondary_values is not None: + secondary_values = np.asarray(secondary_values, dtype=float) + positions = np.interp(secondary_values, mean_q, frac) + labels = [f'{v:g}' for v in secondary_values] + else: + positions, values = even_value_ticks(frac, mean_q) + labels = [f'{v:g}' for v in values] + + positions = np.asarray(positions, dtype=float) + keep = (positions >= xlim_left) & (positions <= xlim_right) + positions = list(positions[keep]) + labels = [lab for lab, k in zip(labels, keep) if k] + + # Always label the left-hand edge of the visible range. + if not positions or positions[0] > xlim_left: + edge_val = _round_sig(float(np.interp(xlim_left, frac, mean_q))) + positions.insert(0, xlim_left) + labels.insert(0, f'{edge_val:g}') + + ax2.set_xticks(positions) + ax2.set_xticklabels(labels) + if secondary_label: + ax2.set_xlabel(secondary_label) + + +def new_percentile_axes(percentile_values: np.ndarray): + """Create a figure/axes pair and return it together with the fractional x-values.""" + fig, ax = plt.subplots(figsize=(10, 6)) + return fig, ax, percentile_values / 100.0 + + def plot_dict_curves(ax, frac, data_dict, labels, colors, lower_clip=None) -> list: """Plot per-mode curves and return the arrays spanning the plotted range.""" all_vals = [] diff --git a/src/hirad/eval/bias_by_percentile_precip.py b/src/hirad/eval/bias_by_percentile_precip.py index 0e37ae2a..86610589 100644 --- a/src/hirad/eval/bias_by_percentile_precip.py +++ b/src/hirad/eval/bias_by_percentile_precip.py @@ -5,7 +5,9 @@ import numpy as np from hirad.eval.bias_by_percentile_common import ( + LOGIT_PERCENTILE_TICKS_TAIL, BiasByPercentileSpec, + apply_logit_percentile_xaxis, finalize_percentile_plot, new_percentile_axes, plot_dict_curves, @@ -14,36 +16,18 @@ from hirad.eval.eval_utils import make_percentile_values, parse_eval_cli +_PRECIP_SECONDARY_MMH = np.array([0.01, 0.1, 1.0, 10.0, 100.0]) + + def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) -> None: - """Apply logit x-axis with labelled percentile ticks (and a mm/h secondary axis).""" - ax.set_xscale('logit') - xlim_right = min(float(frac[-1]), 0.9999) - ax.set_xlim(0.5, xlim_right) - all_tick_fracs = [0.50, 0.75, 0.90, 0.99, 0.999, 0.9999] - all_tick_labels = ['50', '75', '90', '99', '99.9', '99.99'] - valid_ticks = [(f, l) for f, l in zip(all_tick_fracs, all_tick_labels) - if f <= xlim_right] - ax.set_xticks([f for f, _ in valid_ticks]) - ax.set_xticklabels([l for _, l in valid_ticks]) - ax.grid(True, alpha=0.3, which='both') - - if mean_q is not None: - ax2 = ax.twiny() - ax2.set_xscale('logit') - ax2.set_xlim(0.5, xlim_right) - nice_mmh = np.array([0.01, 0.1, 1.0, 10.0, 100.0]) - tick_positions = np.interp(nice_mmh, mean_q, frac) - valid = (tick_positions > 0.5) & (tick_positions <= xlim_right) - positions = list(tick_positions[valid]) - labels = [f'{v:g}' for v in nice_mmh[valid]] - # Ensure a labelled tick at the left-hand edge of the visible range. - if not positions or positions[0] > 0.5: - left_mmh = np.interp(0.5, frac, mean_q) - positions.insert(0, 0.5) - labels.insert(0, f'{left_mmh:.2g}') - ax2.set_xticks(positions) - ax2.set_xticklabels(labels) - ax2.set_xlabel('Mean target [mm/h]') + """Apply logit percentile x-axis (upper tail) with a mm/h secondary axis.""" + apply_logit_percentile_xaxis( + ax, frac, mean_q, + xlim_left=0.5, + percentile_ticks=LOGIT_PERCENTILE_TICKS_TAIL, + secondary_label='Mean target [mm/h]', + secondary_values=_PRECIP_SECONDARY_MMH, + ) def save_bias_by_percentile_plot(bias_data_dict, percentile_values, labels, colors, diff --git a/src/hirad/eval/bias_by_percentile_temp.py b/src/hirad/eval/bias_by_percentile_temp.py index e7c9446c..f8f63338 100644 --- a/src/hirad/eval/bias_by_percentile_temp.py +++ b/src/hirad/eval/bias_by_percentile_temp.py @@ -6,7 +6,7 @@ from hirad.eval.bias_by_percentile_common import ( BiasByPercentileSpec, - even_value_ticks, + apply_logit_percentile_xaxis, finalize_percentile_plot, new_percentile_axes, plot_dict_curves, @@ -16,31 +16,8 @@ def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) -> None: - """Apply logit x-axis with labelled percentile ticks (and a °C secondary axis).""" - ax.set_xscale('logit') - xlim_left = max(float(frac[0]), 0.0001) - xlim_right = min(float(frac[-1]), 0.9999) - ax.set_xlim(xlim_left, xlim_right) - tick_fracs = [0.0001, 0.001, 0.01, 0.1, 0.50, 0.90, 0.99, 0.999, 0.9999] - tick_labels = ['0.01', '0.1', '1', '10', '50', '90', '99', '99.9', '99.99'] - # only show ticks within our data range - valid_ticks = [(f, l) for f, l in zip(tick_fracs, tick_labels) - if xlim_left <= f <= xlim_right] - ax.set_xticks([f for f, _ in valid_ticks]) - ax.set_xticklabels([l for _, l in valid_ticks]) - ax.grid(True, alpha=0.3, which='both') - - if mean_q is not None: - ax2 = ax.twiny() - ax2.set_xscale('logit') - ax2.set_xlim(xlim_left, xlim_right) - # The axis is logit in percentile, so a fixed list of round temps bunches - # near the median while the tails get no labels; even_value_ticks picks a - # nice step and spreads labels evenly across the whole logit axis. - tick_positions, tick_temps = even_value_ticks(frac, mean_q) - ax2.set_xticks(tick_positions) - ax2.set_xticklabels([f'{v:g}' for v in tick_temps]) - ax2.set_xlabel('Mean target [°C]') + """Apply logit percentile x-axis with a °C secondary axis.""" + apply_logit_percentile_xaxis(ax, frac, mean_q, secondary_label='Mean target [°C]') def save_bias_by_percentile_plot(bias_data_dict, percentile_values, labels, colors, diff --git a/src/hirad/eval/bias_by_percentile_wind.py b/src/hirad/eval/bias_by_percentile_wind.py index b36b6a7e..bb5844cb 100644 --- a/src/hirad/eval/bias_by_percentile_wind.py +++ b/src/hirad/eval/bias_by_percentile_wind.py @@ -9,8 +9,7 @@ from hirad.eval.bias_by_percentile_common import ( BiasByPercentileSpec, - _round_sig, - even_value_ticks, + apply_logit_percentile_xaxis, finalize_percentile_plot, new_percentile_axes, plot_dict_curves, @@ -21,36 +20,9 @@ def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None, xlim_left: float | None = None) -> None: - """Apply logit x-axis with labelled percentile ticks (and an m/s secondary axis).""" - ax.set_xscale('logit') - if xlim_left is None: - xlim_left = max(float(frac[0]), 0.0001) - xlim_right = min(float(frac[-1]), 0.9999) - ax.set_xlim(xlim_left, xlim_right) - tick_fracs = [0.0001, 0.001, 0.01, 0.1, 0.50, 0.90, 0.99, 0.999, 0.9999] - tick_labels = ['0.01', '0.1', '1', '10', '50', '90', '99', '99.9', '99.99'] - valid_ticks = [(f, l) for f, l in zip(tick_fracs, tick_labels) - if xlim_left <= f <= xlim_right] - ax.set_xticks([f for f, _ in valid_ticks]) - ax.set_xticklabels([l for _, l in valid_ticks]) - ax.grid(True, alpha=0.3, which='both') - - if mean_q is not None: - ax2 = ax.twiny() - ax2.set_xscale('logit') - ax2.set_xlim(xlim_left, xlim_right) - tick_positions, tick_speeds = even_value_ticks(frac, mean_q) - valid = (tick_positions >= xlim_left) & (tick_positions <= xlim_right) - positions = list(tick_positions[valid]) - speeds = list(tick_speeds[valid]) - # Ensure a labelled tick at the left-hand edge of the visible range. - left_speed = _round_sig(float(np.interp(xlim_left, frac, mean_q))) - if not positions or positions[0] > xlim_left: - positions.insert(0, xlim_left) - speeds.insert(0, left_speed) - ax2.set_xticks(positions) - ax2.set_xticklabels([f'{v:g}' for v in speeds]) - ax2.set_xlabel('Mean target [m/s]') + """Apply logit percentile x-axis with an m/s secondary axis.""" + apply_logit_percentile_xaxis(ax, frac, mean_q, xlim_left=xlim_left, + secondary_label='Mean target [m/s]') def save_bias_by_percentile_plot(bias_data_dict, percentile_values, labels, colors, From c41f6468865242b07c99e2904eb557dcd0a74615 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 17 Jun 2026 22:11:18 +0200 Subject: [PATCH 50/51] limits --- src/hirad/eval/bias_by_percentile_precip.py | 3 +-- src/hirad/eval/bias_by_percentile_wind.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/hirad/eval/bias_by_percentile_precip.py b/src/hirad/eval/bias_by_percentile_precip.py index 86610589..822b254f 100644 --- a/src/hirad/eval/bias_by_percentile_precip.py +++ b/src/hirad/eval/bias_by_percentile_precip.py @@ -72,9 +72,8 @@ def save_fbi_by_percentile_plot(fbi_data_dict, percentile_values, labels, colors ax.axhline(1.0, color='black', linewidth=0.8, linestyle='--') ax.set_yscale('log') if all_vals: - ymax = float(max(np.nanmax(v) for v in all_vals)) * 1.5 ymin = float(min(np.nanmin(v) for v in all_vals)) / 1.5 - ax.set_ylim(max(ymin, 1e-3), max(ymax, 2.0)) + ax.set_ylim(max(ymin, 1e-3), 10.0) finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, xlabel, ylabel, title, out_path) diff --git a/src/hirad/eval/bias_by_percentile_wind.py b/src/hirad/eval/bias_by_percentile_wind.py index bb5844cb..9db2d515 100644 --- a/src/hirad/eval/bias_by_percentile_wind.py +++ b/src/hirad/eval/bias_by_percentile_wind.py @@ -73,9 +73,8 @@ def save_fbi_by_percentile_plot(fbi_data_dict, percentile_values, labels, colors ax.axhline(1.0, color='black', linewidth=0.8, linestyle='--') ax.set_yscale('log') if all_vals: - ymax = float(max(np.nanmax(v) for v in all_vals)) * 1.5 ymin = float(min(np.nanmin(v) for v in all_vals)) / 1.5 - ax.set_ylim(max(ymin, 1e-3), max(ymax, 2.0)) + ax.set_ylim(max(ymin, 1e-3), 10.0) finalize_percentile_plot(ax, frac, lambda ax_, frac_, mq: _apply_logit_xaxis(ax_, frac_, mq, xlim_left=0.10), mean_q, xlabel, ylabel, title, out_path) From 89194e64016487bff4ba4b6760a29fb3ed99f626 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 17 Jun 2026 22:11:26 +0200 Subject: [PATCH 51/51] colorscale --- src/hirad/eval/plotting.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index f96a247e..b1b0f699 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -143,13 +143,14 @@ def plot_map_precipitation(values, filename, title='', threshold=0.01, rfac=1000 def plot_map_temperature(values, filename, title='', grid_cfg=DEFAULT_GRID_CONFIG): """Plot 2m temperature data with Meteoswiss-style colormap.""" colors = [ - "#1A33CC", "#3366FF", "#4C97FF", "#4CA8FF", "#00CCFF", + "#3366FF", "#4C97FF", "#4CA8FF", "#00CCFF", "#DEE699", "#A6D473", "#6BBF4D", "#33AB26", "#009900", "#33B300", "#66CC00", "#99E600", "#CCFF00", "#FFFF00", "#FFCC00", "#FF9900", "#FF6600", "#FF3300", "#FF0000", "#EB00EB", "#FF40FF", "#FF80FF", "#FFBFFF", "#FFE0FF", "#FFF5FF", + "#D9D9D9", "#A6A6A6", "#737373", ] - bounds = [-9, -7, -5, -4, -2, 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42] + bounds = np.arange(-8, 49, 2) cmap = ListedColormap(colors) norm = BoundaryNorm(bounds, ncolors=len(colors), clip=False)