diff --git a/src/hirad/eval/bias_by_percentile_common.py b/src/hirad/eval/bias_by_percentile_common.py new file mode 100644 index 0000000..6f2c620 --- /dev/null +++ b/src/hirad/eval/bias_by_percentile_common.py @@ -0,0 +1,640 @@ +"""Shared machinery for the *bias / MAE / spread by percentile* plots.""" +import concurrent.futures +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Optional + +import matplotlib.pyplot as plt +import matplotlib.ticker as mticker +import numpy as np +import torch +from scipy.ndimage import gaussian_filter + +from hirad.eval.eval_utils import ( + get_channel_indices, + load_generation_setup, + load_land_sea_mask, + resolve_ts_dir, +) + +# Deterministic modes plotted alongside the ensemble, and the ensemble styling. +DET_MODE_CFG = [ + ('baseline', 'Input', 'orange'), + ('regression-prediction', 'Regression Prediction', 'red'), +] +ENSEMBLE_LABEL = 'CorrDiff Ensemble (mean +/- 1 sigma)' +ENSEMBLE_COLOR = 'green' + +_RC_PARAMS = { + 'font.size': 16, + 'axes.titlesize': 18, + 'axes.labelsize': 16, + 'xtick.labelsize': 14, + 'ytick.labelsize': 14, + 'legend.fontsize': 14, +} + + +def to_flat(arr, conv: float, offset: float = 0.0) -> np.ndarray: + """Convert a (possibly xarray) field to a flat float numpy array, scaled and shifted.""" + return (np.asarray(getattr(arr, 'values', arr)) * conv + offset).ravel() + + +def _smooth2d(arr, sigma: float) -> np.ndarray: + """Apply an isotropic Gaussian low-pass to a 2D field (grid-point sigma).""" + a = np.asarray(getattr(arr, 'values', arr), dtype=np.float32) + return gaussian_filter(a, sigma=sigma, mode='nearest') + + +def build_all_histograms( + times: list, + ts_dirs: dict, + out_channels: tuple, + in_channels: tuple, + reduce_fn: Callable, + conv: float, + offset: float, + land_idx: np.ndarray, + hist_bins: np.ndarray, + log_interval: int, + logger: logging.Logger, + smoothing_sigma: float | None = None, +) -> tuple: + """Single serial pass 
over all timesteps, building histograms for every mode.""" + n_land = land_idx.size + n_bins = len(hist_bins) - 1 + det_modes = ('target', 'baseline', 'regression-prediction') + interior_edges = hist_bins[1:-1] + row_offsets = np.arange(n_land, dtype=np.intp) * n_bins + + det_counts: dict = {m: np.zeros((n_land, n_bins), dtype=np.int32) for m in det_modes} + member_counts: list | None = None + n_members: int | None = None + skip_det: set = set() + skip_preds = False + + def accumulate(counts, channel_arrs, smooth: bool = True): + if smoothing_sigma is not None and smooth: + channel_arrs = [_smooth2d(a, smoothing_sigma) for a in channel_arrs] + flats = [to_flat(a, conv, offset) for a in channel_arrs] + vals = reduce_fn(flats)[land_idx] + bin_idx = np.searchsorted(interior_edges, vals, side='right') + np.clip(bin_idx, 0, n_bins - 1, out=bin_idx) + counts.reshape(-1)[row_offsets + bin_idx] += 1 + + logger.info(f"Processing all modes in a single pass ({len(times)} timesteps)") + + for i, ts in enumerate(times): + if i % log_interval == 0: + logger.info(f" timestep {i + 1}/{len(times)}") + + ts_dir = ts_dirs[ts] + + for mode in det_modes: + if mode in skip_det: + continue + chans = in_channels if mode == 'baseline' else out_channels + try: + loaded = torch.load(ts_dir / f"{ts}-{mode}", weights_only=False) + except FileNotFoundError: + logger.warning(f" [{mode}] file not found at {ts}, skipping mode") + skip_det.add(mode) + continue + accumulate(det_counts[mode], [loaded[c] for c in chans], + smooth=mode != 'baseline') + + if not skip_preds: + try: + preds = torch.load(ts_dir / f"{ts}-predictions", weights_only=False) + except FileNotFoundError: + logger.warning(f" [predictions] file not found at {ts}, skipping ensemble") + skip_preds = True + continue + + if n_members is None: + n_members = int(preds.shape[0]) + member_counts = [ + np.zeros((n_land, n_bins), dtype=np.int32) + for _ in range(n_members) + ] + logger.info(f" Detected {n_members} ensemble members") + assert 
def per_point_quantiles(pp_counts: np.ndarray, bin_edges: np.ndarray,
                        frac_percentiles: np.ndarray,
                        block_size: int = 8192) -> np.ndarray:
    """Estimate per-row quantiles from per-grid-point histograms.

    ``pp_counts`` is ``(n_rows, n_bins)``; ``bin_edges`` has ``n_bins + 1``
    entries; ``frac_percentiles`` are probabilities in [0, 1]. For every row
    and probability the centre of the first bin whose normalised CDF reaches
    the probability is returned (float32, shape ``(n_rows, P)``). Rows are
    processed in chunks of ``block_size``. A row with no counts has an
    all-zero CDF and therefore maps every positive probability to the last
    bin centre.
    """
    n_rows, n_bins = pp_counts.shape
    n_q = len(frac_percentiles)
    centers = (0.5 * (bin_edges[:-1] + bin_edges[1:])).astype(np.float32)
    probs = frac_percentiles.astype(np.float64)
    quantiles = np.empty((n_rows, n_q), dtype=np.float32)

    for lo in range(0, n_rows, block_size):
        hi = min(lo + block_size, n_rows)
        rows = hi - lo

        cum = np.cumsum(pp_counts[lo:hi], axis=1, dtype=np.float64)
        cum /= np.maximum(cum[:, -1:], 1.0)

        # Shift row r by 2*r: every CDF lives in [0, 1], so the flattened
        # array stays sorted and one searchsorted call serves all rows.
        shift = (np.arange(rows, dtype=np.float64) * 2.0)[:, None]
        cum += shift
        flat_hits = np.searchsorted(
            cum.ravel(), (probs[None, :] + shift).ravel(), side='left'
        )

        # Convert positions in the flattened array back to per-row bin indices.
        local = flat_hits.reshape(rows, n_q) - (
            np.arange(rows, dtype=np.intp)[:, None] * n_bins
        )
        np.clip(local, 0, n_bins - 1, out=local)
        quantiles[lo:hi] = centers[local]

    return quantiles
def per_point_exceedance(pp_counts: np.ndarray, bin_edges: np.ndarray,
                         thresholds: np.ndarray,
                         block_size: int = 8192) -> np.ndarray:
    """Estimate per-row exceedance probabilities from per-grid-point histograms.

    ``thresholds`` is ``(n_land, P)`` (one threshold per grid point and
    percentile, e.g. the target's per-point quantiles). Returns a float32
    array of the same shape. The resolution is one histogram bin: the bin whose
    upper edge first reaches the threshold is counted entirely as
    *non*-exceeding, so the result is the fraction of mass in strictly higher
    bins. This drives the frequency bias index (FBI), a ratio of exceedance
    frequencies.

    Rows without any counts have an all-zero CDF and therefore report an
    exceedance of 1.0. Thresholds above the last edge are clipped to the last
    bin (exceedance 0 for non-empty rows).
    """
    n_land, P = thresholds.shape
    n_bins = pp_counts.shape[1]
    result = np.empty((n_land, P), dtype=np.float32)
    edges_upper = bin_edges[1:].astype(np.float32)

    for start in range(0, n_land, block_size):
        end = min(start + block_size, n_land)

        cdf = np.cumsum(pp_counts[start:end], axis=1, dtype=np.float64)
        # np.maximum builds a fresh array, so dividing in place is safe even
        # though the totals are read from cdf itself.
        cdf /= np.maximum(cdf[:, -1:], 1.0)

        # The edges are shared by all rows, so one vectorized searchsorted over
        # the whole (B, P) threshold block replaces a Python loop over columns.
        bin_idx = np.searchsorted(edges_upper, thresholds[start:end], side='left')
        np.clip(bin_idx, 0, n_bins - 1, out=bin_idx)
        non_exc = np.take_along_axis(cdf, bin_idx, axis=1)
        result[start:end] = np.clip(1.0 - non_exc, 0.0, 1.0)

    return result
def compute_quantiles(
    det_counts: dict,
    member_counts: list | None,
    n_members: int | None,
    active_det_modes: list,
    has_ensemble: bool,
    hist_bins: np.ndarray,
    frac_percentiles: np.ndarray,
    n_workers: int,
) -> tuple:
    """Compute per-point quantiles and exceedances for every mode on a thread pool.

    The target quantiles are computed first (synchronously) because they are
    the thresholds for every exceedance job. Histogram arrays are released as
    soon as their results are in: the processed entries are deleted from
    ``det_counts`` and the ``member_counts`` list is emptied **in place**, so
    callers must not reuse either afterwards.

    Returns ``(target_q, det_results, det_exceedance, member_qs, member_exc)``
    where the two dicts are keyed by deterministic mode and the two lists are
    ordered by ensemble member (empty without an ensemble).
    """
    ensemble_ok = has_ensemble and member_counts is not None and n_members is not None
    members = member_counts if ensemble_ok else []

    with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as pool:
        target_q = per_point_quantiles(det_counts['target'], hist_bins, frac_percentiles)
        del det_counts['target']

        # Submission order: deterministic quantiles, deterministic exceedances,
        # member quantiles, member exceedances.
        q_jobs = {
            mode: pool.submit(per_point_quantiles, det_counts[mode],
                              hist_bins, frac_percentiles)
            for mode in active_det_modes
        }
        exc_jobs = {
            mode: pool.submit(per_point_exceedance, det_counts[mode],
                              hist_bins, target_q)
            for mode in active_det_modes
        }
        member_q_jobs = [
            pool.submit(per_point_quantiles, counts, hist_bins, frac_percentiles)
            for counts in members
        ]
        member_exc_jobs = [
            pool.submit(per_point_exceedance, counts, hist_bins, target_q)
            for counts in members
        ]

        det_results = {mode: q_jobs[mode].result() for mode in active_det_modes}
        det_exceedance = {mode: exc_jobs[mode].result() for mode in active_det_modes}
        for mode in active_det_modes:
            del det_counts[mode]

        member_qs = [job.result() for job in member_q_jobs]
        member_exc = [job.result() for job in member_exc_jobs]
        members.clear()

    return target_q, det_results, det_exceedance, member_qs, member_exc
def _ensemble_stats(member_qs, member_exc, target_q, frac_percentiles):
    """Compute ensemble spread plus the MAE / bias / FBI plot entries.

    Every quantity is "local then averaged": a statistic is formed per grid
    point across the ensemble members and only then averaged over the first
    (grid-point) axis. FBI for percentile ``p`` is the member's exceedance
    probability divided by ``1 - p`` (floored at ``1e-12``).

    Returns ``(spread, mae_entry, bias_entry, fbi_entry)``; ``spread`` is a
    per-percentile array and each entry is a ``(mean, std)`` pair of
    per-percentile arrays. Standard deviations are population values
    (``E[x^2] - E[x]^2``, clipped at zero).
    """
    count = len(member_qs)
    target64 = target_q.astype(np.float64)
    # P(target > q_target(p)) = 1 - p by construction; guard p -> 1.
    ref_exc = np.maximum(1.0 - frac_percentiles.astype(np.float64), 1e-12)[None, :]

    def _zeros():
        return np.zeros(target_q.shape, dtype=np.float64)

    def _mean_std(total, total_sq):
        mean = total / count
        variance = np.maximum(total_sq / count - mean ** 2, 0.0)
        return mean, np.sqrt(variance)

    q_sum, q_sq = _zeros(), _zeros()
    ae_sum, ae_sq = _zeros(), _zeros()
    fbi_sum, fbi_sq = _zeros(), _zeros()

    for member_q, member_p in zip(member_qs, member_exc):
        q64 = member_q.astype(np.float64)
        abs_err = np.abs(q64 - target64)
        ratio = member_p.astype(np.float64) / ref_exc

        q_sum += member_q
        q_sq += q64 ** 2
        ae_sum += abs_err
        ae_sq += abs_err ** 2
        fbi_sum += ratio
        fbi_sq += ratio ** 2

    q_mean, q_std = _mean_std(q_sum, q_sq)
    ae_mean, ae_std = _mean_std(ae_sum, ae_sq)
    fbi_mean, fbi_std = _mean_std(fbi_sum, fbi_sq)

    spread = q_std.mean(axis=0)
    mae_entry = (ae_mean.mean(axis=0), ae_std.mean(axis=0))
    # The target is constant across members, so the member-std of the bias is
    # simply the member-std of the quantile itself.
    bias_entry = ((q_mean - target64).mean(axis=0), q_std.mean(axis=0))
    fbi_entry = (np.nanmean(fbi_mean, axis=0), np.nanmean(fbi_std, axis=0))

    return spread, mae_entry, bias_entry, fbi_entry
def _round_sig(x: float, sig: int = 2) -> float:
    """Round *x* to *sig* significant figures; zero and non-finite input give 0.0."""
    if not np.isfinite(x) or x == 0:
        return 0.0
    exponent = int(np.floor(np.log10(abs(x))))
    return round(x, sig - exponent - 1)


def even_value_ticks(frac: np.ndarray, mean_q: np.ndarray,
                     target_ticks: int = 9) -> tuple:
    """Pick secondary-axis ticks evenly spaced along the logit axis.

    ``frac`` holds the fractional percentile positions and ``mean_q`` the
    value at each of them. At least two sample positions are spread uniformly
    in logit space between ``frac[0]`` and ``frac[-1]``; the value at each is
    linearly interpolated and rounded to two significant figures. Runs of
    identical consecutive labels keep only their first member.

    Returns ``(positions, labels)`` as two arrays, or two empty arrays when
    ``mean_q[-1] <= mean_q[0]``.
    """
    first_val, last_val = float(mean_q[0]), float(mean_q[-1])
    if last_val <= first_val:
        return np.array([]), np.array([])

    def _to_logit(p):
        clipped = np.clip(p, 1e-9, 1 - 1e-9)
        return np.log(clipped / (1.0 - clipped))

    n_samples = max(target_ticks, 2)
    logit_grid = np.linspace(_to_logit(frac[0]), _to_logit(frac[-1]), n_samples)
    positions = 1.0 / (1.0 + np.exp(-logit_grid))
    values = np.interp(positions, frac, mean_q)

    # Labels are rounded, but each tick stays at its true axis position.
    labels = np.array([_round_sig(v) for v in values])

    # A label survives when it differs from its predecessor; the first label
    # always survives, so the left corner is never lost.
    is_new = np.concatenate(([True], labels[1:] != labels[:-1]))
    keep = np.flatnonzero(is_new).astype(np.intp)
    return positions[keep], labels[keep]


# Percentile tick presets for a logit x-axis, as (fraction, label) pairs.
LOGIT_PERCENTILE_TICKS_FINE = (
    (0.0001, '0.01'), (0.001, '0.1'), (0.01, '1'), (0.1, '10'),
    (0.50, '50'), (0.90, '90'), (0.99, '99'), (0.999, '99.9'), (0.9999, '99.99'),
)
LOGIT_PERCENTILE_TICKS_TAIL = (
    (0.50, '50'), (0.75, '75'), (0.90, '90'),
    (0.99, '99'), (0.999, '99.9'), (0.9999, '99.99'),
)
def apply_logit_percentile_xaxis(
    ax,
    frac: np.ndarray,
    mean_q: Optional[np.ndarray] = None,
    *,
    xlim_left: Optional[float] = None,
    percentile_ticks=LOGIT_PERCENTILE_TICKS_FINE,
    secondary_label: Optional[str] = None,
    secondary_values: Optional[np.ndarray] = None,
) -> None:
    """Configure a logit percentile x-axis with an optional physical-unit top axis.

    Shared by the bias / MAE / spread / FBI "by percentile" plots. The primary
    axis carries labelled percentile ticks on a logit scale; when ``mean_q`` is
    supplied, a secondary top axis labelled in the variable's physical units is
    added.

    Parameters
    ----------
    ax : matplotlib Axes to configure.
    frac : fractional percentile positions (0-1) used as the x-coordinates.
    mean_q : per-percentile mean target value; enables the secondary top axis.
    xlim_left : left x-limit; defaults to ``max(frac[0], 1e-4)``. The right
        limit is always ``min(frac[-1], 0.9999)``.
    percentile_ticks : (fraction, label) pairs for the primary percentile axis;
        only pairs inside the x-limits are drawn.
    secondary_label : axis label for the top axis (e.g. ``'Mean target [m/s]'``).
    secondary_values : fixed "nice" physical values to place on the top axis
        (e.g. precipitation rates). Values outside the range spanned by
        ``mean_q`` are omitted. When the argument is omitted, ticks are spread
        evenly along the logit axis via :func:`even_value_ticks`.
    """
    ax.set_xscale('logit')
    if xlim_left is None:
        xlim_left = max(float(frac[0]), 0.0001)
    xlim_right = min(float(frac[-1]), 0.9999)
    ax.set_xlim(xlim_left, xlim_right)

    visible = [(f, lab) for f, lab in percentile_ticks if xlim_left <= f <= xlim_right]
    ax.set_xticks([f for f, _ in visible])
    ax.set_xticklabels([lab for _, lab in visible])
    ax.grid(True, alpha=0.3, which='both')

    if mean_q is None:
        return

    ax2 = ax.twiny()
    ax2.set_xscale('logit')
    ax2.set_xlim(xlim_left, xlim_right)

    if secondary_values is not None:
        secondary_values = np.asarray(secondary_values, dtype=float)
        # Inverse lookup value -> percentile (assumes mean_q is non-decreasing,
        # as np.interp requires). By default np.interp clamps out-of-range
        # values to the end points, which would pin e.g. a "100" label to the
        # right edge although the data never reach it; NaN marks them instead
        # and the range filter below drops them.
        positions = np.interp(secondary_values, mean_q, frac,
                              left=np.nan, right=np.nan)
        labels = [f'{v:g}' for v in secondary_values]
    else:
        positions, values = even_value_ticks(frac, mean_q)
        labels = [f'{v:g}' for v in values]

    positions = np.asarray(positions, dtype=float)
    # NaN compares False on both sides, so unplaceable ticks are removed here.
    keep = (positions >= xlim_left) & (positions <= xlim_right)
    positions = list(positions[keep])
    labels = [lab for lab, k in zip(labels, keep) if k]

    # Always label the left-hand edge of the visible range.
    if not positions or positions[0] > xlim_left:
        edge_val = _round_sig(float(np.interp(xlim_left, frac, mean_q)))
        positions.insert(0, xlim_left)
        labels.insert(0, f'{edge_val:g}')

    ax2.set_xticks(positions)
    ax2.set_xticklabels(labels)
    if secondary_label:
        ax2.set_xlabel(secondary_label)
def new_percentile_axes(percentile_values: np.ndarray):
    """Create a 10x6 figure/axes pair and return it with the fractional x-values.

    Returns ``(fig, ax, percentile_values / 100.0)``.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    return fig, ax, percentile_values / 100.0


def plot_dict_curves(ax, frac, data_dict, labels, colors, lower_clip=None) -> list:
    """Plot per-mode curves and return the arrays spanning the plotted range.

    Entries of ``data_dict`` are paired positionally with ``labels`` and
    ``colors`` and may be one of:

    * a ``list`` of per-member arrays: the mean is drawn with a +/- 1 std band;
    * a ``(mean, std)`` ``tuple``: drawn the same way;
    * anything else: drawn as a single line.

    ``lower_clip`` (if given) floors the lower edge of the band, e.g. ``0``
    for non-negative quantities or a small positive number on log axes. The
    returned list holds the plain curves and the band edges, which callers use
    to derive y-limits.
    """
    all_vals = []
    for (_key, data), label, color in zip(data_dict.items(), labels, colors):
        if isinstance(data, list):
            arr = np.array(data)
            mean = arr.mean(axis=0)
            std = arr.std(axis=0)
        elif isinstance(data, tuple):
            mean, std = np.asarray(data[0]), np.asarray(data[1])
        else:
            ax.plot(frac, data, color=color, label=label, linewidth=2, alpha=0.85)
            all_vals.append(np.asarray(data))
            continue

        lower = mean - std if lower_clip is None else np.maximum(mean - std, lower_clip)
        upper = mean + std
        ax.plot(frac, mean, color=color, label=label, linewidth=2)
        ax.fill_between(frac, lower, upper, color=color, alpha=0.2)
        all_vals.extend([lower, upper])
    return all_vals


def finalize_percentile_plot(ax, frac, apply_xaxis, mean_q, xlabel, ylabel,
                             title, out_path, legend: bool = True) -> None:
    """Apply shared axis styling (via *apply_xaxis*) and write the figure to disk.

    ``apply_xaxis(ax, frac, mean_q)`` configures the x-axis; the parent
    directory of ``out_path`` is created if needed. The figure that owns
    ``ax`` is saved at 300 dpi and then closed.
    """
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    apply_xaxis(ax, frac, mean_q)
    ax.yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f'{x:g}'))
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if legend:
        ax.legend()
    # Operate on the figure that owns `ax` rather than on pyplot's implicit
    # "current figure", which may be a different one if the caller has created
    # or activated another figure in the meantime.
    fig = ax.figure
    fig.tight_layout()
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
@dataclass
class BiasByPercentileSpec:
    """Variable-specific configuration for :func:`run_bias_by_percentile`.

    One instance describes how a single variable is read, histogrammed and
    plotted; the driver itself is variable-agnostic.
    """
    var_label: str                 # e.g. "2m temperature" — used in log messages
    output_prefix: str             # e.g. "temperature" — used in output file names
    # Plot titles; the driver appends a smoothing note to each when smoothing is on.
    bias_title: str
    mae_title: str
    spread_title: str
    # y-axis labels for the three mandatory plots.
    bias_ylabel: str
    mae_ylabel: str
    spread_ylabel: str
    # Percentiles in percent; the driver divides by 100 to obtain fractions.
    percentile_values: np.ndarray
    resolve_channels: Callable[[dict], tuple]   # indices -> (ch_out, ch_in); raises ValueError
    # Called with the run config; returns the histogram bin edges.
    make_hist_bins: Callable[[dict], np.ndarray]
    read_scaling: Callable[[dict], tuple]       # cfg -> (conv, offset)
    # Plot writers. The driver calls save_bias / save_mae (and save_fbi) as
    # (data, percentile_values, labels, colors, title=, xlabel=, ylabel=,
    # out_path=, mean_q=) and save_spread as (spread, percentile_values,
    # title=, xlabel=, ylabel=, out_path=, mean_q=).
    save_bias: Callable
    save_mae: Callable
    save_spread: Callable
    # FBI is optional: leave ``save_fbi`` as ``None`` to skip the plot entirely.
    save_fbi: Callable | None = None
    fbi_title: str | None = None
    fbi_ylabel: str | None = None
    # Combines the (scaled) per-channel flat fields into the plotted scalar.
    # Defaults to the single-channel identity; wind speed uses ``hypot``.
    reduce_fn: Callable[[list], np.ndarray] = lambda flats: flats[0]
def run_bias_by_percentile(cfg: dict, spec: BiasByPercentileSpec) -> None:
    """End-to-end driver shared by the per-variable bias-by-percentile scripts.

    Steps: load the generation setup and channel indices, build per-land-point
    histograms for every mode in one pass, turn them into per-point quantiles
    and exceedances, aggregate to bias / MAE / FBI / spread curves, and hand
    those to the plot writers in ``spec``. Configuration problems reported as
    ``ValueError`` by the helpers, and a missing target, are logged and end
    the run without raising.

    Config keys read here: ``land_sea_mask_path``, ``height``, ``width``,
    ``log_interval``, ``smoothing_sigma_km``, ``grid_res_km``,
    ``n_quant_workers`` and ``results_dir_name``.
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(spec.output_prefix)
    plt.rcParams.update(_RC_PARAMS)

    logger.info(f"Starting bias-by-percentile computation for {spec.var_label} over land")
    try:
        generation_dir, gen_cfg, times = load_generation_setup(cfg)
    except ValueError as exc:
        logger.error(str(exc))
        return
    logger.info(f"Loaded {len(times)} timesteps to process")

    out_root = Path(generation_dir)

    indices = get_channel_indices(gen_cfg)
    try:
        ch_out, ch_in = spec.resolve_channels(indices)
    except ValueError as exc:
        logger.error(str(exc))
        return
    # Channels may be a single index or a tuple (e.g. wind u/v); normalize to tuples.
    out_channels = ch_out if isinstance(ch_out, tuple) else (ch_out,)
    in_channels = ch_in if isinstance(ch_in, tuple) else (ch_in,)
    logger.info(f"Channel indices - output: {out_channels}, input: {in_channels}")

    land_da = load_land_sea_mask(
        cfg.get("land_sea_mask_path"), cfg.get("height") or 352, cfg.get("width") or 544
    )
    # Points with a finite mask value are the ones that get histogrammed.
    land_idx = np.flatnonzero(np.isfinite(land_da.values).ravel())
    n_land = land_idx.size
    logger.info(f"{n_land} land grid points")

    hist_bins = spec.make_hist_bins(cfg)
    log_interval = cfg.get("log_interval", 24)
    conv, offset = spec.read_scaling(cfg)

    smoothing_km = cfg.get("smoothing_sigma_km")
    grid_res_km = cfg.get("grid_res_km", 1.0)
    smoothing_sigma = (smoothing_km / grid_res_km) if smoothing_km else None
    if smoothing_sigma is not None:
        logger.info(
            f"Gaussian smoothing enabled: sigma = {smoothing_km} km "
            f"/ {grid_res_km} km = {smoothing_sigma:.3g} grid points"
        )
    # ':g' keeps whole numbers unchanged ("20") but no longer truncates
    # fractional sigmas (2.5 km used to be labelled "2km").
    suffix = f"_smoothed{smoothing_km:g}km" if smoothing_km else ""
    title_note = f" ({smoothing_km:g} km smoothed)" if smoothing_km else ""

    percentile_values = spec.percentile_values
    frac_percentiles = percentile_values / 100.0

    logger.info(f"Resolving {len(times)} timestep directories ...")
    ts_dirs = {ts: resolve_ts_dir(out_root, ts) / ts for ts in times}

    det_counts, member_counts, n_members = build_all_histograms(
        times, ts_dirs, out_channels, in_channels, spec.reduce_fn, conv, offset,
        land_idx, hist_bins, log_interval, logger,
        smoothing_sigma=smoothing_sigma,
    )

    if det_counts.get('target') is None:
        logger.error("No target data found; cannot compute bias.")
        return

    active_det_modes = [m for m, _, _ in DET_MODE_CFG if det_counts.get(m) is not None]
    has_ensemble = member_counts is not None and n_members is not None and n_members > 0

    # One quantile task per mode / member by default.
    n_tasks = 1 + len(active_det_modes) + (n_members if has_ensemble else 0)
    n_quant_workers = cfg.get("n_quant_workers", n_tasks)

    target_q, det_results, det_exceedance, member_qs, member_exc = compute_quantiles(
        det_counts, member_counts, n_members, active_det_modes, has_ensemble,
        hist_bins, frac_percentiles, n_quant_workers,
    )

    target_mean_q = target_q.mean(axis=0)
    want_fbi = spec.save_fbi is not None
    # FBI is the frequency bias index: P(pred > q_target(p)) / P(target > q_target(p)),
    # where P(target > q_target(p)) = 1 - p by construction. Guard p -> 1.
    # NOTE(review): the numerator comes from a binned CDF that counts the whole
    # threshold bin as non-exceeding, so where much mass sits in one bin (e.g.
    # dry hours) even a perfect forecast may score FBI < 1 — confirm intended.
    target_exc = np.maximum(1.0 - frac_percentiles, 1e-12)

    bias_data: dict = {}
    mae_data: dict = {}
    fbi_data: dict = {}
    labels: list = []
    colors: list = []

    for mode, label, color in DET_MODE_CFG:
        if mode not in det_results:
            continue
        pred_q = det_results.pop(mode)
        pred_exc = det_exceedance.pop(mode)
        bias_data[mode] = pred_q.mean(axis=0) - target_mean_q
        mae_data[mode] = np.abs(pred_q - target_q).mean(axis=0)
        if want_fbi:
            fbi_data[mode] = np.nanmean(pred_exc / target_exc[None, :], axis=0)
        labels.append(label)
        colors.append(color)

    spread = None
    if has_ensemble:
        spread, mae_entry, bias_entry, fbi_entry = _ensemble_stats(
            member_qs, member_exc, target_q, frac_percentiles,
        )
        bias_data['predictions'] = bias_entry
        mae_data['predictions'] = mae_entry
        if want_fbi:
            fbi_data['predictions'] = fbi_entry
        labels.append(ENSEMBLE_LABEL)
        colors.append(ENSEMBLE_COLOR)

    output_path = out_root / cfg.get("results_dir_name", "evaluation_maps") / "by_percentile"
    output_path.mkdir(parents=True, exist_ok=True)

    fn = output_path / f'{spec.output_prefix}_bias_by_percentile{suffix}.png'
    spec.save_bias(
        bias_data, percentile_values, labels, colors,
        title=spec.bias_title + title_note, xlabel='Percentile', ylabel=spec.bias_ylabel,
        out_path=fn, mean_q=target_mean_q,
    )
    logger.info(f"Bias-by-percentile plot saved: {fn}")

    fn_mae = output_path / f'{spec.output_prefix}_mae_by_percentile{suffix}.png'
    spec.save_mae(
        mae_data, percentile_values, labels, colors,
        title=spec.mae_title + title_note, xlabel='Percentile', ylabel=spec.mae_ylabel,
        out_path=fn_mae, mean_q=target_mean_q,
    )
    logger.info(f"MAE-by-percentile plot saved: {fn_mae}")

    if want_fbi:
        fn_fbi = output_path / f'{spec.output_prefix}_fbi_by_percentile{suffix}.png'
        # fbi_title / fbi_ylabel default to None on the spec; fall back to
        # generic text instead of failing on `None + str` after all the
        # expensive work is done.
        fbi_title = spec.fbi_title or f'{spec.var_label} frequency bias index'
        spec.save_fbi(
            fbi_data, percentile_values, labels, colors,
            title=fbi_title + title_note, xlabel='Percentile',
            ylabel=spec.fbi_ylabel or 'FBI',
            out_path=fn_fbi, mean_q=target_mean_q,
        )
        logger.info(f"FBI-by-percentile plot saved: {fn_fbi}")

    if spread is not None:
        fn_spread = output_path / f'{spec.output_prefix}_spread_by_percentile{suffix}.png'
        spec.save_spread(
            spread, percentile_values,
            title=spec.spread_title + title_note, xlabel='Percentile', ylabel=spec.spread_ylabel,
            out_path=fn_spread, mean_q=target_mean_q,
        )
        logger.info(f"Spread-by-percentile plot saved: {fn_spread}")
# Fixed physical values (mm/h) shown on the secondary (top) axis.
_PRECIP_SECONDARY_MMH = np.array([0.01, 0.1, 1.0, 10.0, 100.0])


def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) -> None:
    """Apply the upper-tail logit percentile x-axis with a mm/h secondary axis.

    The visible range starts at the median (``xlim_left=0.5``).
    """
    axis_options = {
        'xlim_left': 0.5,
        'percentile_ticks': LOGIT_PERCENTILE_TICKS_TAIL,
        'secondary_label': 'Mean target [mm/h]',
        'secondary_values': _PRECIP_SECONDARY_MMH,
    }
    apply_logit_percentile_xaxis(ax, frac, mean_q, **axis_options)


def save_bias_by_percentile_plot(bias_data_dict, percentile_values, labels, colors,
                                 title, xlabel, ylabel, out_path, mean_q=None) -> None:
    """Save a bias-by-percentile figure (symlog y-axis fixed to [-10, 10])."""
    _fig, axes, x_frac = new_percentile_axes(percentile_values)
    plot_dict_curves(axes, x_frac, bias_data_dict, labels, colors)
    # Zero-bias reference line.
    axes.axhline(0.0, color='black', linewidth=0.8, linestyle='--')
    axes.set_yscale('symlog', linthresh=0.1, linscale=0.3)
    axes.set_ylim(-10, 10)
    finalize_percentile_plot(axes, x_frac, _apply_logit_xaxis, mean_q,
                             xlabel, ylabel, title, out_path)


def save_mae_by_percentile_plot(mae_data_dict, percentile_values, labels, colors,
                                title, xlabel, ylabel, out_path, mean_q=None) -> None:
    """Save a MAE-by-percentile figure (log y-axis fixed to [1e-5, 100])."""
    _fig, axes, x_frac = new_percentile_axes(percentile_values)
    # MAE is non-negative, so the lower band edge is floored at zero.
    plot_dict_curves(axes, x_frac, mae_data_dict, labels, colors, lower_clip=0)
    axes.set_yscale('log')
    axes.set_ylim(1e-5, 100)
    finalize_percentile_plot(axes, x_frac, _apply_logit_xaxis, mean_q,
                             xlabel, ylabel, title, out_path)
def save_spread_by_percentile_plot(spread, percentile_values,
                                   title, xlabel, ylabel, out_path, mean_q=None) -> None:
    """Save an ensemble-spread-by-percentile figure (log y-axis, no legend)."""
    _, ax, frac = new_percentile_axes(percentile_values)
    ax.plot(frac, spread, color='green', linewidth=2)
    ax.set_yscale('log')
    ax.set_ylim(1e-5, 100)
    finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q,
                             xlabel, ylabel, title, out_path, legend=False)


def save_fbi_by_percentile_plot(fbi_data_dict, percentile_values, labels, colors,
                                title, xlabel, ylabel, out_path, mean_q=None) -> None:
    """Save a frequency-bias-index-by-percentile figure (log y-axis, ratio around 1).

    The lower y-limit follows the data (smallest plotted value / 1.5) but is
    never below 1e-3; the upper limit is fixed at 10.
    """
    _, ax, frac = new_percentile_axes(percentile_values)
    # Bands are floored at 1e-3 so they stay drawable on the log axis.
    all_vals = plot_dict_curves(ax, frac, fbi_data_dict, labels, colors, lower_clip=1e-3)
    ax.axhline(1.0, color='black', linewidth=0.8, linestyle='--')
    ax.set_yscale('log')
    if all_vals:
        ymin = float(min(np.nanmin(v) for v in all_vals)) / 1.5
        ax.set_ylim(max(ymin, 1e-3), 10.0)
    finalize_percentile_plot(ax, frac, _apply_logit_xaxis,
                             mean_q, xlabel, ylabel, title, out_path)


def _resolve_channels(indices: dict) -> tuple:
    """Return ``(tp_out, tp_in)`` channel indices for total precipitation.

    The input index falls back to the output index when the input has no
    ``tp`` channel.

    Raises
    ------
    ValueError
        If ``tp`` is missing from the output channels. The shared driver only
        handles ``ValueError`` from channel resolution (the temperature
        resolver raises the same type), so a plain ``KeyError`` would escape
        as a traceback instead of a logged error.
    """
    tp_out = indices['output'].get('tp')
    if tp_out is None:
        raise ValueError("Precipitation channel (tp) not found in output channels.")
    tp_in = indices['input'].get('tp', tp_out)
    return tp_out, tp_in


def _make_hist_bins(cfg: dict) -> np.ndarray:
    """Return the precipitation bin edges; ``cfg`` is accepted but unused.

    The first bin spans 0 to 1e-2; the remaining edges are the 500
    log-spaced values from 1e-2 to 10**3.2, giving 501 edges (500 bins).
    """
    n_bins = 500
    return np.concatenate([np.array([0.0]), np.logspace(-2, 3.2, n_bins)])
save_spread=save_spread_by_percentile_plot, + save_fbi=save_fbi_by_percentile_plot, +) + + +def main(cfg: dict) -> None: + run_bias_by_percentile(cfg, SPEC) + + +if __name__ == '__main__': + main(parse_eval_cli()) diff --git a/src/hirad/eval/bias_by_percentile_precip_smoothed.py b/src/hirad/eval/bias_by_percentile_precip_smoothed.py new file mode 100644 index 0000000..7cfe3b6 --- /dev/null +++ b/src/hirad/eval/bias_by_percentile_precip_smoothed.py @@ -0,0 +1,24 @@ +"""Smoothed precipitation bias / MAE / spread by percentile plots. + +Applies the same Gaussian low-pass filter to target and prediction fields before +the percentile histograms are accumulated, mitigating double-penalty effects in +the precipitation tails. +""" +from hirad.eval.bias_by_percentile_common import run_bias_by_percentile +from hirad.eval.bias_by_percentile_precip import SPEC +from hirad.eval.eval_utils import parse_eval_cli + + +SMOOTHING_SIGMA_KM = 20.0 +GRID_RES_KM = 1.0 + + +def main(cfg: dict) -> None: + cfg = dict(cfg) + cfg['smoothing_sigma_km'] = SMOOTHING_SIGMA_KM + cfg.setdefault('grid_res_km', GRID_RES_KM) + run_bias_by_percentile(cfg, SPEC) + + +if __name__ == '__main__': + main(parse_eval_cli()) \ No newline at end of file diff --git a/src/hirad/eval/bias_by_percentile_temp.py b/src/hirad/eval/bias_by_percentile_temp.py new file mode 100644 index 0000000..f8f6333 --- /dev/null +++ b/src/hirad/eval/bias_by_percentile_temp.py @@ -0,0 +1,106 @@ +""" +Plots bias / MAE / spread as a function of percentile for 2m temperature, using a +local-then-averaged estimator (see :mod:`hirad.eval.bias_by_percentile_common`). 
+""" +import numpy as np + +from hirad.eval.bias_by_percentile_common import ( + BiasByPercentileSpec, + apply_logit_percentile_xaxis, + finalize_percentile_plot, + new_percentile_axes, + plot_dict_curves, + run_bias_by_percentile, +) +from hirad.eval.eval_utils import make_percentile_values, parse_eval_cli + + +def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None) -> None: + """Apply logit percentile x-axis with a °C secondary axis.""" + apply_logit_percentile_xaxis(ax, frac, mean_q, secondary_label='Mean target [°C]') + + +def save_bias_by_percentile_plot(bias_data_dict, percentile_values, labels, colors, + title, xlabel, ylabel, out_path, mean_q=None) -> None: + """Save a bias-by-percentile figure (linear y-axis, data-driven limits).""" + _, ax, frac = new_percentile_axes(percentile_values) + all_vals = plot_dict_curves(ax, frac, bias_data_dict, labels, colors) + ax.axhline(0.0, color='black', linewidth=0.8, linestyle='--') + if all_vals: + vcat = np.concatenate([np.asarray(v).ravel() for v in all_vals]) + vmin, vmax = float(np.nanmin(vcat)), float(np.nanmax(vcat)) + margin = max(abs(vmax - vmin) * 0.1, 0.05) + ax.set_ylim(vmin - margin, vmax + margin) + finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, + xlabel, ylabel, title, out_path) + + +def save_mae_by_percentile_plot(mae_data_dict, percentile_values, labels, colors, + title, xlabel, ylabel, out_path, mean_q=None) -> None: + """Save a MAE-by-percentile figure (linear y-axis, data-driven limits).""" + _, ax, frac = new_percentile_axes(percentile_values) + all_vals = plot_dict_curves(ax, frac, mae_data_dict, labels, colors, lower_clip=0) + if all_vals: + ymax = float(max(np.nanmax(v) for v in all_vals)) * 1.1 + if ymax > 0: + ax.set_ylim(0, ymax) + finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, + xlabel, ylabel, title, out_path) + + +def save_spread_by_percentile_plot(spread, percentile_values, + title, xlabel, ylabel, out_path, mean_q=None) -> None: + 
"""Save an ensemble-spread-by-percentile figure (linear y-axis).""" + _, ax, frac = new_percentile_axes(percentile_values) + ax.plot(frac, spread, color='green', linewidth=2, label='Ensemble spread') + ymax_spread = float(np.nanmax(spread)) * 1.1 + if ymax_spread > 0: + ax.set_ylim(0, ymax_spread) + finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, + xlabel, ylabel, title, out_path) + + +def _resolve_channels(indices: dict) -> tuple: + # Temperature channel: try '2t' first, then 't2m' + t2m_out = indices['output'].get('2t', indices['output'].get('t2m')) + t2m_in = indices['input'].get('2t', indices['input'].get('t2m', t2m_out)) + if t2m_out is None: + raise ValueError("Temperature channel (2t / t2m) not found in output channels.") + return t2m_out, t2m_in + + +def _make_hist_bins(cfg: dict) -> np.ndarray: + # Linear histogram bins in °C — fine enough to resolve sub-degree differences + n_bins = cfg.get("n_bins", 2000) + temp_min = cfg.get("temp_bin_min_celsius", -90.0) + temp_max = cfg.get("temp_bin_max_celsius", 65.0) + return np.linspace(temp_min, temp_max, n_bins + 1) + + +SPEC = BiasByPercentileSpec( + var_label='2m temperature', + output_prefix='temperature', + bias_title='2m Temperature Bias Over Land', + mae_title='2m Temperature MAE Over Land', + spread_title='2m Temperature Ensemble Spread Over Land', + bias_ylabel='Bias [°C]', + mae_ylabel='MAE [°C]', + spread_ylabel='Ensemble Spread [°C]', + percentile_values=make_percentile_values(), + resolve_channels=_resolve_channels, + make_hist_bins=_make_hist_bins, + # Default: convert Kelvin → °C (conv=1.0, offset=-273.15) + read_scaling=lambda cfg: (cfg.get("temp_conv_factor", 1.0), + cfg.get("temp_offset_celsius", -273.15)), + save_bias=save_bias_by_percentile_plot, + save_mae=save_mae_by_percentile_plot, + save_spread=save_spread_by_percentile_plot, +) + + +def main(cfg: dict) -> None: + run_bias_by_percentile(cfg, SPEC) + + +if __name__ == '__main__': + main(parse_eval_cli()) diff --git 
a/src/hirad/eval/bias_by_percentile_wind.py b/src/hirad/eval/bias_by_percentile_wind.py new file mode 100644 index 0000000..9db2d51 --- /dev/null +++ b/src/hirad/eval/bias_by_percentile_wind.py @@ -0,0 +1,130 @@ +""" +Plots bias / MAE / spread as a function of percentile for 10 m wind speed, using a +local-then-averaged estimator (see :mod:`hirad.eval.bias_by_percentile_common`). + +Wind speed is derived per grid point from the two surface wind components +(``10u``, ``10v``) as ``hypot(u, v)`` before histogramming. +""" +import numpy as np + +from hirad.eval.bias_by_percentile_common import ( + BiasByPercentileSpec, + apply_logit_percentile_xaxis, + finalize_percentile_plot, + new_percentile_axes, + plot_dict_curves, + run_bias_by_percentile, +) +from hirad.eval.eval_utils import make_percentile_values, parse_eval_cli + + +def _apply_logit_xaxis(ax, frac: np.ndarray, mean_q: np.ndarray | None = None, + xlim_left: float | None = None) -> None: + """Apply logit percentile x-axis with an m/s secondary axis.""" + apply_logit_percentile_xaxis(ax, frac, mean_q, xlim_left=xlim_left, + secondary_label='Mean target [m/s]') + + +def save_bias_by_percentile_plot(bias_data_dict, percentile_values, labels, colors, + title, xlabel, ylabel, out_path, mean_q=None) -> None: + """Save a bias-by-percentile figure (linear y-axis, data-driven limits).""" + _, ax, frac = new_percentile_axes(percentile_values) + all_vals = plot_dict_curves(ax, frac, bias_data_dict, labels, colors) + ax.axhline(0.0, color='black', linewidth=0.8, linestyle='--') + if all_vals: + vcat = np.concatenate([np.asarray(v).ravel() for v in all_vals]) + vmin, vmax = float(np.nanmin(vcat)), float(np.nanmax(vcat)) + margin = max(abs(vmax - vmin) * 0.1, 0.05) + ax.set_ylim(vmin - margin, vmax + margin) + finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, + xlabel, ylabel, title, out_path) + + +def save_mae_by_percentile_plot(mae_data_dict, percentile_values, labels, colors, + title, xlabel, ylabel, 
out_path, mean_q=None) -> None: + """Save a MAE-by-percentile figure (linear y-axis, data-driven limits).""" + _, ax, frac = new_percentile_axes(percentile_values) + all_vals = plot_dict_curves(ax, frac, mae_data_dict, labels, colors, lower_clip=0) + if all_vals: + ymax = float(max(np.nanmax(v) for v in all_vals)) * 1.1 + if ymax > 0: + ax.set_ylim(0, ymax) + finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, + xlabel, ylabel, title, out_path) + + +def save_spread_by_percentile_plot(spread, percentile_values, + title, xlabel, ylabel, out_path, mean_q=None) -> None: + """Save an ensemble-spread-by-percentile figure (linear y-axis).""" + _, ax, frac = new_percentile_axes(percentile_values) + ax.plot(frac, spread, color='green', linewidth=2, label='Ensemble spread') + ymax_spread = float(np.nanmax(spread)) * 1.1 + if ymax_spread > 0: + ax.set_ylim(0, ymax_spread) + finalize_percentile_plot(ax, frac, _apply_logit_xaxis, mean_q, + xlabel, ylabel, title, out_path) + + +def save_fbi_by_percentile_plot(fbi_data_dict, percentile_values, labels, colors, + title, xlabel, ylabel, out_path, mean_q=None) -> None: + """Save a frequency-bias-index-by-percentile figure (log y-axis, ratio around 1).""" + _, ax, frac = new_percentile_axes(percentile_values) + all_vals = plot_dict_curves(ax, frac, fbi_data_dict, labels, colors, lower_clip=1e-3) + ax.axhline(1.0, color='black', linewidth=0.8, linestyle='--') + ax.set_yscale('log') + if all_vals: + ymin = float(min(np.nanmin(v) for v in all_vals)) / 1.5 + ax.set_ylim(max(ymin, 1e-3), 10.0) + finalize_percentile_plot(ax, frac, + lambda ax_, frac_, mq: _apply_logit_xaxis(ax_, frac_, mq, xlim_left=0.10), + mean_q, xlabel, ylabel, title, out_path) + + +def _resolve_channels(indices: dict) -> tuple: + # Wind speed is derived from the two surface wind components. 
+ u_out = indices['output'].get('10u') + v_out = indices['output'].get('10v') + if u_out is None or v_out is None: + raise ValueError("Wind components (10u / 10v) not found in output channels.") + u_in = indices['input'].get('10u', u_out) + v_in = indices['input'].get('10v', v_out) + return (u_out, v_out), (u_in, v_in) + + +def _make_hist_bins(cfg: dict) -> np.ndarray: + # Linear histogram bins in m/s — fine enough to resolve sub-m/s differences + n_bins = cfg.get("wind_n_bins", 1500) + speed_min = cfg.get("wind_bin_min_ms", 0.0) + speed_max = cfg.get("wind_bin_max_ms", 75.0) + return np.linspace(speed_min, speed_max, n_bins + 1) + + +SPEC = BiasByPercentileSpec( + var_label='10 m wind speed', + output_prefix='windspeed', + bias_title='10 m Wind Speed Bias Over Land', + mae_title='10 m Wind Speed MAE Over Land', + spread_title='10 m Wind Speed Ensemble Spread Over Land', + fbi_title='10 m Wind Speed Frequency Bias Index Over Land', + bias_ylabel='Bias [m/s]', + mae_ylabel='MAE [m/s]', + spread_ylabel='Ensemble Spread [m/s]', + fbi_ylabel='FBI (exceedance)', + percentile_values=make_percentile_values(), + resolve_channels=_resolve_channels, + make_hist_bins=_make_hist_bins, + read_scaling=lambda cfg: (cfg.get("wind_conv_factor", 1.0), 0.0), + reduce_fn=lambda flats: np.hypot(flats[0], flats[1]), + save_bias=save_bias_by_percentile_plot, + save_mae=save_mae_by_percentile_plot, + save_spread=save_spread_by_percentile_plot, + save_fbi=save_fbi_by_percentile_plot, +) + + +def main(cfg: dict) -> None: + run_bias_by_percentile(cfg, SPEC) + + +if __name__ == '__main__': + main(parse_eval_cli()) diff --git a/src/hirad/eval/bias_by_percentile_wind_smoothed.py b/src/hirad/eval/bias_by_percentile_wind_smoothed.py new file mode 100644 index 0000000..6502072 --- /dev/null +++ b/src/hirad/eval/bias_by_percentile_wind_smoothed.py @@ -0,0 +1,24 @@ +"""Smoothed 10 m wind speed bias / MAE / spread / FBI by percentile plots. 
+ +Applies the same Gaussian low-pass filter to target and prediction fields before +the percentile histograms are accumulated, mitigating double-penalty effects in +the wind speed tails. +""" +from hirad.eval.bias_by_percentile_common import run_bias_by_percentile +from hirad.eval.bias_by_percentile_wind import SPEC +from hirad.eval.eval_utils import parse_eval_cli + + +SMOOTHING_SIGMA_KM = 20.0 +GRID_RES_KM = 1.0 + + +def main(cfg: dict) -> None: + cfg = dict(cfg) + cfg['smoothing_sigma_km'] = SMOOTHING_SIGMA_KM + cfg.setdefault('grid_res_km', GRID_RES_KM) + run_bias_by_percentile(cfg, SPEC) + + +if __name__ == '__main__': + main(parse_eval_cli()) diff --git a/src/hirad/eval/diurnal_cycle_precip_high_percentiles.py b/src/hirad/eval/diurnal_cycle_precip_high_percentiles.py new file mode 100644 index 0000000..5dc2ef6 --- /dev/null +++ b/src/hirad/eval/diurnal_cycle_precip_high_percentiles.py @@ -0,0 +1,243 @@ +""" +Plots the diurnal cycle of the all-hour 99th, 99.9th, and 99.99th percentiles of +precipitation, a somewhat reliable measure of the precipitation intensity. + +Each hour, member and type is treated separately, to conserve memory... but if the +period is long, this can still be a lot of data and thus an OOM error can occur. 
def save_plot(hours, lines, labels, ylabel, title, out_path):
    """Draw one diurnal-cycle figure and write it to *out_path*.

    Args:
        hours: x-coordinates, one per plotted value.
        lines: Sequence whose entries are either a plain sequence of values or
            a ``(mean, std)`` tuple; tuples get a shaded mean +/- std band whose
            lower edge is clipped at 0.
        labels: Legend label for each entry of *lines*.
        ylabel: y-axis label.
        title: Figure title.
        out_path: Destination file; parent directories are created.
    """
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    plt.figure(figsize=(8, 4))
    try:
        for data, label in zip(lines, labels):
            if isinstance(data, tuple):  # (mean, std)
                mean, std = data
                center = np.asarray(mean)
                spread = np.asarray(std)
                lower = np.maximum(center - spread, 0)
                upper = center + spread
                line, = plt.plot(hours, mean, label=label)
                plt.fill_between(hours, lower, upper, alpha=0.3, color=line.get_color())
            else:
                plt.plot(hours, data, label=label)
        plt.xlabel('Hour (UTC)')
        plt.xticks(range(0, 25, 3))
        plt.xlim(0, 24)
        plt.ylabel(ylabel)
        plt.title(title)
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.savefig(out_path)
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close()


def _spatial_mean_of_time_quantiles(samples, quantiles):
    """Return the spatial mean of per-pixel quantiles taken over time.

    Args:
        samples: Non-empty list of equally sized 1D arrays, one per timestep.
        quantiles: 1D array of quantile levels in ``[0, 1]``.

    Returns:
        1D array with one value per quantile level (NaNs are ignored).
    """
    stacked = np.stack(samples, axis=0)  # [n_times, n_pixels]
    q_vals = np.nanquantile(stacked, quantiles, axis=0)  # [n_q, n_pixels]
    del stacked  # free the large stack before the reduction result is returned
    return np.nanmean(q_vals, axis=1)  # [n_q]


def main(cfg: dict):
    """Plot the diurnal cycle of high precipitation percentiles over land.

    For every hour of day present in the data, the 99th / 99.9th / 99.99th
    percentile over that hour's timesteps is computed per land pixel and then
    averaged over space. This is done for target, baseline and
    regression-prediction (modes that fail to load are skipped) and per
    ensemble member for the predictions, whose mean and standard deviation
    across members are plotted. One PNG per percentile is written below
    ``<generation_dir>/<results_dir_name>``.

    Args:
        cfg: Evaluation config; reads ``conv_factor``, ``land_sea_mask_path``,
            ``height``, ``width`` and ``results_dir_name`` in addition to
            whatever ``load_generation_setup`` consumes.
    """
    # Setup logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    logger.info("Starting computation for diurnal cycle of high percentiles of precipitation")
    try:
        generation_dir, gen_cfg, times = load_generation_setup(cfg)
    except ValueError as exc:
        logger.error(str(exc))
        return
    logger.info(f"Loaded {len(times)} timesteps to process")

    # Fail early and clearly: a missing factor would otherwise surface as a
    # misleading "error loading data" for every mode and a TypeError afterwards.
    conv_factor = cfg.get("conv_factor")
    if conv_factor is None:
        logger.error("Config key 'conv_factor' is required to scale precipitation; aborting.")
        return

    # Output root
    out_root = Path(generation_dir)

    # Find channel indices
    indices = get_channel_indices(gen_cfg)
    tp_out = indices['output']['tp']
    tp_in = indices['input'].get('tp', tp_out)
    logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}")

    # Land-sea mask -> 1D boolean mask over the flattened (lat, lon) grid
    land_mask = load_land_sea_mask(cfg.get("land_sea_mask_path"), cfg.get("height"), cfg.get("width"))
    land_idx = land_mask.notnull().stack(space=('lat', 'lon')).values
    n_land = int(land_idx.sum())
    logger.info(f"Land pixels: {n_land} / {land_idx.size} ({100 * n_land / land_idx.size:.1f}%)")

    percentile_configs = [
        (0.99, 'p99', '99th'),
        (0.999, 'p999', '99.9th'),
        (0.9999, 'p9999', '99.99th'),
    ]
    pct_keys = [key for _, key, _ in percentile_configs]
    quantiles = np.array([q for q, _, _ in percentile_configs])

    # Storage for diurnal cycles: pct_mean[pct_key][mode], pct_std[pct_key]['prediction']
    pct_mean = {key: {} for key in pct_keys}
    pct_std = {key: {} for key in pct_keys}

    # Group timesteps by hour-of-day so we never hold all timesteps in memory at once.
    times_by_hour = {}
    for ts in times:
        hour = datetime.strptime(ts, "%Y%m%d-%H%M").hour
        times_by_hour.setdefault(hour, []).append(ts)
    sorted_hours = sorted(times_by_hour)
    if not sorted_hours:
        logger.error("No timesteps to process; aborting.")
        return
    counts_per_hour = {h: len(times_by_hour[h]) for h in sorted_hours}
    logger.info(f"Grouped {len(times)} timesteps into {len(sorted_hours)} hours; timesteps/hour: {counts_per_hour}")

    def load_file(ts, suffix):
        # NOTE: weights_only=False unpickles arbitrary objects -- only point
        # this script at generation directories you trust.
        return torch.load(resolve_ts_dir(out_root, ts) / ts / f"{ts}-{suffix}", weights_only=False)

    def land_values(arr):
        """Flatten a (lat, lon) array and keep only land pixels as float32."""
        return np.asarray(arr, dtype=np.float32).reshape(-1)[land_idx]

    # -- Process target, baseline and regression-prediction --
    # Memory scales with one hour's data: only that hour's land pixels are held.
    for mode in ['target', 'baseline', 'regression-prediction']:
        logger.info(f"Processing mode: {mode}")
        ch = tp_in if mode == 'baseline' else tp_out

        per_hour_pct = {key: [] for key in pct_keys}
        failed = False
        for hi, hour in enumerate(sorted_hours):
            n_ts = counts_per_hour[hour]
            logger.info(
                f"[{mode}] hour {hour:02d} ({hi + 1}/{len(sorted_hours)}): "
                f"loading {n_ts} timesteps"
            )
            hour_vals = []
            try:
                for ts in times_by_hour[hour]:
                    hour_vals.append(land_values(load_file(ts, mode)[ch]) * conv_factor)
            except Exception as exc:
                # Best effort by design: a mode that cannot be read is left out of the plot.
                logger.error(f"Error loading data for mode {mode} at hour {hour:02d}: {exc!r}. Skipping mode.")
                failed = True
                break

            logger.info(
                f"[{mode}] hour {hour:02d}: stacked array {(len(hour_vals), n_land)}"
            )
            q_means = _spatial_mean_of_time_quantiles(hour_vals, quantiles)
            del hour_vals
            for key, val in zip(pct_keys, q_means):
                per_hour_pct[key].append(val)

        if failed:
            continue

        for key in pct_keys:
            pct_mean[key][mode] = xr.DataArray(
                np.array(per_hour_pct[key]), dims=['hour'], coords={'hour': sorted_hours}
            )
        logger.info(f"Finished mode: {mode}")

    # -- Predictions: compute per hour per member, then mean+std across members --
    logger.info("Processing predictions")

    pred_hour_mean = {key: [] for key in pct_keys}
    pred_hour_std = {key: [] for key in pct_keys}

    for hi, hour in enumerate(sorted_hours):
        n_ts = counts_per_hour[hour]
        logger.info(
            f"[predictions] hour {hour:02d} ({hi + 1}/{len(sorted_hours)}): "
            f"loading {n_ts} timesteps"
        )
        # member -> list of [n_land] arrays over this hour's times
        member_vals = None
        for ts in times_by_hour[hour]:
            preds = load_file(ts, "predictions")  # indexed as [member, channel, ...]
            tp_data = np.asarray(preds[:, tp_out], dtype=np.float32) * conv_factor
            n_members = tp_data.shape[0]
            if member_vals is None:
                member_vals = [[] for _ in range(n_members)]
            flat = tp_data.reshape(n_members, -1)[:, land_idx]  # [n_members, n_land]
            for m in range(n_members):
                member_vals[m].append(flat[m])
            del preds, tp_data, flat

        n_members = len(member_vals)
        logger.info(
            f"[predictions] hour {hour:02d}: {n_members} members x {n_ts} timesteps"
        )

        # Per-member: quantile over this hour's timesteps, then spatial mean -> [n_q] per member
        per_member_arr = np.stack(
            [_spatial_mean_of_time_quantiles(vals, quantiles) for vals in member_vals],
            axis=0,
        )  # [n_members, n_q]
        del member_vals

        q_mean_over_members = per_member_arr.mean(axis=0)  # [n_q]
        q_std_over_members = per_member_arr.std(axis=0)  # [n_q]
        for i, key in enumerate(pct_keys):
            pred_hour_mean[key].append(q_mean_over_members[i])
            pred_hour_std[key].append(q_std_over_members[i])

    logger.info("Finished predictions")

    for key in pct_keys:
        pct_mean[key]['prediction'] = xr.DataArray(
            np.array(pred_hour_mean[key]), dims=['hour'], coords={'hour': sorted_hours}
        )
        pct_std[key]['prediction'] = xr.DataArray(
            np.array(pred_hour_std[key]), dims=['hour'], coords={'hour': sorted_hours}
        )

    # Prepare cyclic lists for plotting: repeat the first value at x = 24.
    def cycle_fn(x):
        vals = x.values.tolist()
        return vals + [vals[0]]

    logger.info("Preparing data for plotting")
    # Use the hours actually present so x and y always have the same length
    # (a hard-coded 0..24 axis breaks when some hours of day have no data).
    hrs_c = list(sorted_hours) + [24]
    output_path = out_root / cfg.get("results_dir_name", "evaluation_maps")
    output_path.mkdir(parents=True, exist_ok=True)

    for _, key, label in percentile_configs:
        m = pct_mean[key]
        s = pct_std[key]

        lines = []
        plot_labels = []

        if 'target' in m:
            lines.append(cycle_fn(m['target']))
            plot_labels.append('Target')
        if 'baseline' in m:
            lines.append(cycle_fn(m['baseline']))
            plot_labels.append('Input')
        if 'prediction' in m:
            lines.append((cycle_fn(m['prediction']), cycle_fn(s['prediction'])))
            plot_labels.append(f'CorrDiff {label} Pct ± Std')
        if 'regression-prediction' in m:
            lines.append(cycle_fn(m['regression-prediction']))
            plot_labels.append('Regression Prediction')

        fn = output_path / f'diurnal_cycle_precip_{key}_percentile.png'
        save_plot(
            hrs_c,
            lines,
            plot_labels,
            'Precipitation (mm/day)',
            f'Diurnal Cycle of {label}-Percentile Precipitation',
            fn
        )
        logger.info(f"Plot saved: {fn}")
def _accumulate_precip(sums, wet_sums, key, field, wet_threshold, ensemble=False):
    """Add one timestep of precipitation to the running per-source maps.

    Args:
        sums: Dict of accumulator arrays for the precipitation sum (updated in place).
        wet_sums: Dict of accumulator arrays for the wet fraction (updated in place).
        key: Source key selecting the accumulator in both dicts.
        field: Precipitation array for this timestep. With ``ensemble=True``
            the first axis is the member axis and is averaged out.
        wet_threshold: A pixel counts as wet when ``field > wet_threshold``.
        ensemble: Whether *field* carries a leading member axis.
    """
    wet = field > wet_threshold
    if ensemble:
        sums[key] += field.mean(axis=0)
        # fraction of members that are wet at this timestep
        wet_sums[key] += wet.mean(axis=0)
    else:
        sums[key] += field
        wet_sums[key] += wet.astype(np.float64)


def main(cfg: dict) -> None:
    """Write mean-precipitation and wet-hour-fraction maps per hour of day.

    Timesteps are grouped by their UTC hour; for every hour present, each
    source (target, baseline, ensemble predictions and, if the first timestep
    has it, regression-prediction) is averaged over that hour's timesteps and
    plotted below ``<generation_dir>/<results_dir_name>/diurnal_cycle_precip_maps``.

    Args:
        cfg: Evaluation config; reads ``height``, ``width``,
            ``conv_factor_hourly`` (default 1000), ``wet_threshold``
            (default 0.1), ``log_interval`` (default 24) and
            ``results_dir_name``.
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    logger.info("Starting diurnal cycle precipitation map generation")

    grid_cfg = grid_cfg_from_cfg(cfg)

    try:
        generation_dir, gen_cfg, times = load_generation_setup(cfg)
    except ValueError as exc:
        logger.error(str(exc))
        return
    if not times:
        logger.error("No timesteps to process; aborting.")
        return

    H, W = cfg.get("height"), cfg.get("width")
    if H is None or W is None:
        # np.zeros((None, None)) would otherwise fail with an obscure TypeError.
        logger.error("Config keys 'height' and 'width' are required; aborting.")
        return

    out_root = Path(generation_dir)
    output_path = out_root / cfg.get("results_dir_name", "evaluation_maps") / "diurnal_cycle_precip_maps"
    output_path.mkdir(parents=True, exist_ok=True)

    indices = get_channel_indices(gen_cfg)
    tp_out = indices['output']['tp']
    tp_in = indices['input'].get('tp', tp_out)
    # conv_factor_hourly converts ERA5 accumulated precip (m) to mm/h
    conv_factor = cfg.get("conv_factor_hourly", 1000)
    wet_threshold = cfg.get("wet_threshold", 0.1)  # mm/h
    # Guard against 0 / None, which would break the modulo below.
    log_interval = cfg.get("log_interval", 24) or 24

    logger.info(f"TP channel indices — output: {tp_out}, input: {tp_in}")
    logger.info(f"Wet-hour threshold: {wet_threshold} mm/h")
    logger.info(f"Processing {len(times)} timesteps")

    # Group timestep strings by UTC hour (no data loaded yet)
    times_by_hour: dict[int, list[str]] = defaultdict(list)
    for ts in times:
        times_by_hour[datetime.strptime(ts, "%Y%m%d-%H%M").hour].append(ts)

    hours_present = sorted(times_by_hour)
    logger.info(f"Hours present: {hours_present}")

    # Detect whether regression-prediction files exist (check first timestep only)
    first_ts = times[0]
    first_dir = resolve_ts_dir(out_root, first_ts) / first_ts
    has_regression = (first_dir / f"{first_ts}-regression-prediction").exists()

    sources = [
        ('target', 'Target', tp_out),
        ('baseline', 'Input', tp_in),
        ('predictions', 'CorrDiff Ensemble Mean', tp_out),
    ]
    if has_regression:
        sources.append(('regression-prediction', 'Regression Prediction', tp_out))

    for source_key, _, _ in sources:
        (output_path / source_key).mkdir(parents=True, exist_ok=True)
        (output_path / f"{source_key}_wethour").mkdir(parents=True, exist_ok=True)

    def load_scaled(path, index):
        # NOTE: weights_only=False unpickles arbitrary objects -- only point
        # this script at generation directories you trust.
        return np.asarray(torch.load(path, weights_only=False)[index]) * conv_factor

    for hour in hours_present:
        hour_times = times_by_hour[hour]
        logger.info(f"Hour {hour:02d}:00 UTC — loading {len(hour_times)} timesteps")

        # Accumulators for this hour only
        sums = {s[0]: np.zeros((H, W), dtype=np.float64) for s in sources}
        wet_sums = {s[0]: np.zeros((H, W), dtype=np.float64) for s in sources}
        count = 0

        for idx, ts in enumerate(hour_times, 1):
            ts_dir = resolve_ts_dir(out_root, ts) / ts

            target = load_scaled(ts_dir / f"{ts}-target", tp_out)
            _accumulate_precip(sums, wet_sums, 'target', target, wet_threshold)

            baseline = load_scaled(ts_dir / f"{ts}-baseline", tp_in)
            _accumulate_precip(sums, wet_sums, 'baseline', baseline, wet_threshold)

            # predictions are indexed [member, channel, ...]
            preds = load_scaled(ts_dir / f"{ts}-predictions", (slice(None), tp_out))
            _accumulate_precip(sums, wet_sums, 'predictions', preds, wet_threshold, ensemble=True)

            if has_regression:
                reg = load_scaled(ts_dir / f"{ts}-regression-prediction", tp_out)
                _accumulate_precip(sums, wet_sums, 'regression-prediction', reg, wet_threshold)

            count += 1
            if idx % log_interval == 0 or idx == len(hour_times):
                logger.info(f"  Loaded {idx}/{len(hour_times)} ({ts})")

        # Plot and immediately discard the accumulators
        for source_key, source_label, _ in sources:
            mean_map = sums[source_key] / count
            title = f"{source_label} — Mean Diurnal Precip {hour:02d}:00 UTC (n={count})"
            out_file = str(output_path / source_key / f"diurnal_mean_precip_{source_key}_{hour:02d}h")
            plot_map_precipitation(
                mean_map, out_file,
                title=title,
                threshold=0.01,
                rfac=1.0,
                grid_cfg=grid_cfg,
            )

            wet_map = wet_sums[source_key] / count * 100.0  # percent
            wh_title = f"{source_label} — Wet-Hour Fraction {hour:02d}:00 UTC (n={count})"
            wh_out_file = str(output_path / f"{source_key}_wethour" / f"diurnal_wethour_{source_key}_{hour:02d}h")
            plot_map(
                wet_map, wh_out_file,
                title=wh_title,
                label="Wet-Hour Fraction [%]",
                vmin=0, vmax=30,
                cmap="PuBu",
                extend="max",
                grid_cfg=grid_cfg,
            )

        del sums, wet_sums
        logger.info(f"Hour {hour:02d}:00 UTC — maps saved")

    logger.info("Diurnal cycle precipitation maps complete.")
plt.savefig(out_path) plt.close() + def main(cfg: dict): # Setup logging logging.basicConfig(level=logging.INFO) @@ -61,7 +65,10 @@ def main(cfg: dict): # Prepare lists to collect DataArrays target_precip, baseline_precip, pred_precip, mean_pred_precip = [], [], [], [] - target_wet, baseline_wet, pred_wet, mean_pred_wet = [], [], [], [] + wet_target = {thr: [] for thr in ALLHOUR_THRESHOLDS} + wet_baseline = {thr: [] for thr in ALLHOUR_THRESHOLDS} + wet_pred = {thr: [] for thr in ALLHOUR_THRESHOLDS} + wet_regpred = {thr: [] for thr in ALLHOUR_THRESHOLDS} # Collect data for idx, ts in enumerate(times, 1): @@ -95,12 +102,13 @@ def main(cfg: dict): if mean_pred is not None: mean_pred_precip.append(da_mean_pred.mean(dim=("lat","lon")).assign_coords(time=dt)) - # Wet-hour fraction, i.e., freq(precip) > wet_threshold - target_wet.append(((da_target / 24 > cfg.get("wet_threshold")).mean().assign_coords(time=dt))) - baseline_wet.append(((da_baseline / 24 > cfg.get("wet_threshold")).mean().assign_coords(time=dt))) - pred_wet.append(((da_preds / 24> cfg.get("wet_threshold")).mean(dim=("lat","lon")).assign_coords(time=dt))) - if mean_pred is not None: - mean_pred_wet.append(((da_mean_pred / 24 > cfg.get("wet_threshold")).mean().assign_coords(time=dt))) + # Wet-hour fraction per threshold + for thr in ALLHOUR_THRESHOLDS: + wet_target[thr].append((da_target / 24 > thr).mean().assign_coords(time=dt)) + wet_baseline[thr].append((da_baseline / 24 > thr).mean().assign_coords(time=dt)) + wet_pred[thr].append((da_preds / 24 > thr).mean(dim=('lat', 'lon')).assign_coords(time=dt)) + if mean_pred is not None: + wet_regpred[thr].append((da_mean_pred / 24 > thr).mean().assign_coords(time=dt)) if idx % cfg.get("log_interval") == 0 or idx == len(times): logger.info(f"Processed {idx}/{len(times)} timesteps ({ts})") @@ -112,12 +120,6 @@ def main(cfg: dict): if mean_pred_precip: amount_mean_pred_mean, _ = concat_and_group_diurnal(mean_pred_precip) - wet_target_mean, _ = 
concat_and_group_diurnal(target_wet, scale=100.0) # scale to obtain percentages - wet_baseline_mean, _ = concat_and_group_diurnal(baseline_wet, scale=100.0) - wet_pred_mean, wet_pred_std = concat_and_group_diurnal(pred_wet, is_member=True, scale=100.0) - if mean_pred_wet: - wet_mean_pred_mean, _ = concat_and_group_diurnal(mean_pred_wet, scale=100.0) - # Generate plots output_path = out_root / cfg.get("results_dir_name", "evaluation_maps") output_path.mkdir(parents=True, exist_ok=True) @@ -130,15 +132,27 @@ def main(cfg: dict): 'Diurnal Cycle of Precip Amount', output_path / 'diurnal_cycle_precip_amount.png' ) - save_plot( - wet_target_mean.hour, - [wet_target_mean, wet_baseline_mean, wet_pred_mean, wet_mean_pred_mean] if mean_pred_wet else [wet_target_mean, wet_baseline_mean, wet_pred_mean], - [None, None, wet_pred_std, None] if mean_pred_wet else [None, None, wet_pred_std], - ['Target','Input','CorrDiff ± Std(Members)', 'Regression Prediction'] if mean_pred_wet else ['Target','Input','CorrDiff ± Std(Members)'], - 'Wet-Hour Fraction [%]', - 'Diurnal Cycle of Wet-Hours (>0.1 mm/h)', - output_path / 'diurnal_cycle_precip_wethours.png' - ) + + # Diurnal cycle of wet-hours, one plot per threshold + for thr in ALLHOUR_THRESHOLDS: + wet_target_mean, _ = concat_and_group_diurnal(wet_target[thr], scale=100.0) + wet_baseline_mean, _ = concat_and_group_diurnal(wet_baseline[thr], scale=100.0) + wet_pred_mean, wet_pred_std = concat_and_group_diurnal(wet_pred[thr], is_member=True, scale=100.0) + has_regpred = bool(wet_regpred[thr]) + if has_regpred: + wet_mean_pred_mean, _ = concat_and_group_diurnal(wet_regpred[thr], scale=100.0) + + fn_wet = output_path / f'diurnal_cycle_precip_wethours_{thr:g}mmh.png' + save_plot( + wet_target_mean.hour, + [wet_target_mean, wet_baseline_mean, wet_pred_mean, wet_mean_pred_mean] if has_regpred else [wet_target_mean, wet_baseline_mean, wet_pred_mean], + [None, None, wet_pred_std, None] if has_regpred else [None, None, wet_pred_std], + 
['Target','Input','CorrDiff ± Std(Members)', 'Regression Prediction'] if has_regpred else ['Target','Input','CorrDiff ± Std(Members)'], + 'Wet-Hour Fraction [%]', + f'Diurnal Cycle of Wet-Hours (>{thr:g} mm/h)', + fn_wet, + ) + logger.info(f"Diurnal wet-hour plot saved: {fn_wet}") logger.info("Plots saved.") diff --git a/src/hirad/eval/diurnal_cycle_precip_p99.py b/src/hirad/eval/diurnal_cycle_precip_p99.py deleted file mode 100644 index ecd257c..0000000 --- a/src/hirad/eval/diurnal_cycle_precip_p99.py +++ /dev/null @@ -1,165 +0,0 @@ -""" -Plots the diurnal cycle of the all-hour 99th percentile of -precipitation, a somewhat reliable measure of the precipitation intensity. - -Each hour, member and type is treaded separately, to conserve memory... but if the -period is long, this can still be a lot of data and thus an OOM error can occur. -""" -import logging -from datetime import datetime -from pathlib import Path - -import hydra -import matplotlib.pyplot as plt -import numpy as np -import torch -import xarray as xr - -from hirad.eval.eval_utils import get_channel_indices, load_generation_setup, load_land_sea_mask, parse_eval_cli, resolve_ts_dir - - -def save_plot(hours, lines, labels, ylabel, title, out_path): - Path(out_path).parent.mkdir(parents=True, exist_ok=True) - plt.figure(figsize=(8,4)) - for data, label in zip(lines, labels): - if isinstance(data, tuple): # (mean, std) - mean, std = data - lower = np.maximum(np.array(mean) - std, 0) - upper = np.array(mean) + std - line, = plt.plot(hours, mean, label=label) - plt.fill_between(hours, lower, upper, alpha=0.3, color=line.get_color()) - else: - plt.plot(hours, data, label=label) - plt.xlabel('Hour (UTC)') - plt.xticks(range(0,25,3)) - plt.xlim(0,24) - plt.ylabel(ylabel) - plt.title(title) - plt.grid(True) - plt.legend() - plt.tight_layout() - plt.savefig(out_path) - plt.close() - - -def main(cfg: dict): - # Setup logging - logging.basicConfig(level=logging.INFO) - logger = logging.getLogger(__name__) - - 
logger.info("Starting computation for diurnal cycle of 99th-percentile of precipitation") - try: - generation_dir, gen_cfg, times = load_generation_setup(cfg) - except ValueError as exc: - logger.error(str(exc)) - return - logger.info(f"Loaded {len(times)} timesteps to process") - - # Output root - out_root = Path(generation_dir) - - # Find channel indices - indices = get_channel_indices(gen_cfg) - tp_out = indices['output']['tp'] - tp_in = indices['input'].get('tp', tp_out) - logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") - - # Land-sea mask - land_mask = load_land_sea_mask(cfg.get("land_sea_mask_path"), cfg.get("height"), cfg.get("width")) - - # Storage for diurnal cycles - pct99_mean = {} - pct99_std = {} - - # -- Process target and baseline -- - for mode in ['target', 'baseline', 'regression-prediction']: - logger.info(f"Processing mode: {mode}") - - data_list = [] - try: - for ts in times: - data = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode in ['target','regression-prediction'] else tp_in] * cfg.get("conv_factor") - data_list.append(data) - except: - logger.error(f"Error loading data for mode {mode}. 
Skipping.") - continue - - da = xr.DataArray( - np.stack(data_list, axis=0), - dims=['time', 'lat', 'lon'], - coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times], - 'lat': land_mask.coords['lat'], 'lon': land_mask.coords['lon']} - ) - - # Select only land pixels to avoid all-NaN slices in quantile - land_bool = land_mask.notnull().stack(space=('lat', 'lon')) - da_land = da.stack(space=('lat', 'lon')).isel(space=land_bool.values) - - # Group by hour and compute 99th percentile over time, then spatial mean - hourly_p99 = da_land.groupby('time.hour').quantile(0.99, dim='time') - pct99_mean[mode] = hourly_p99.mean(dim='space') - - # -- Predictions: compute per hour per member, then mean+std across members -- - logger.info("Processing predictions") - - # Load all prediction data at once into xarray - pred_data_list = [] - for ts in times: - preds = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-predictions", weights_only=False) * cfg.get("conv_factor") # [n_members, n_channels, lat, lon] - tp_data = preds[:, tp_out] # [n_members, lat, lon] - tp_da = xr.DataArray(tp_data, dims=['member', 'lat', 'lon'], - coords={'lat': land_mask.coords['lat'], 'lon': land_mask.coords['lon']}) - pred_data_list.append(tp_da) - - pred_da = xr.concat(pred_data_list, dim='time') # [member, time, lat, lon] - pred_da = pred_da.assign_coords({ - 'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times] - }) - pred_da = pred_da.transpose('member', 'time', 'lat', 'lon') - - # Select only land pixels to avoid all-NaN slices in quantile - land_bool = land_mask.notnull().stack(space=('lat', 'lon')) - pred_da_land = pred_da.stack(space=('lat', 'lon')).isel(space=land_bool.values) - - logger.info('Calculating 99th percentile for predictions') - # Group by hour, compute 99th percentile across time, then spatial mean over land - hourly_p99_by_member = pred_da_land.groupby('time.hour').quantile(0.99, dim='time').mean(dim='space') - - # Store ensemble statistics as xarray 
def main(cfg: dict):
    """Plot the diurnal cycle of the land-masked mean 2m temperature.

    For every timestep returned by ``load_generation_setup`` the target,
    baseline (input), ensemble predictions and, when the file exists, the
    regression prediction are loaded, shifted by -273.15 (Kelvin to degrees
    Celsius), multiplied by the land-sea mask and averaged over ``lat``/``lon``.
    The per-timestep means are then grouped by hour with
    ``concat_and_group_diurnal`` and drawn into a single figure saved as
    ``diurnal_cycle_2t.png`` under ``<generation_dir>/<results_dir_name>``.

    Args:
        cfg: Evaluation config. Keys read here: ``land_sea_mask_path``,
            ``height``, ``width``, ``log_interval`` (default 100) and
            ``results_dir_name`` (default ``"evaluation_maps"``).

    Returns:
        None. Returns early (after logging an error) when the generation
        setup cannot be loaded or no temperature channel is configured.
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    logger.info("Starting computation for diurnal cycle of 2m temperature")
    try:
        generation_dir, gen_cfg, times = load_generation_setup(cfg)
    except ValueError as exc:
        logger.error(str(exc))
        return
    datetimes = [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]
    logger.info(f"Loaded {len(times)} timesteps to process")

    # Channel indices; the temperature channel is named '2t' or 't2m'.
    indices = get_channel_indices(gen_cfg)
    out_ch = indices['output']
    in_ch = indices['input']
    t2m_out = out_ch.get('2t', out_ch.get('t2m'))
    t2m_in = in_ch.get('2t', in_ch.get('t2m', t2m_out))
    if t2m_out is None:
        # Indexing with None would silently add an axis instead of selecting
        # a channel, so stop here with a clear message.
        logger.error("No temperature channel ('2t' or 't2m') found in output channels. Aborting.")
        return

    out_root = Path(generation_dir)
    # Default mirrors the other eval scripts; without it `idx % None` raises.
    log_interval = cfg.get("log_interval", 100)

    def load(ts, fn):
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # point this at generation outputs you trust.
        return torch.load(resolve_ts_dir(out_root, ts) / ts / fn, weights_only=False)

    land_mask = load_land_sea_mask(cfg.get("land_sea_mask_path"), cfg.get("height"), cfg.get("width"))

    # Per-timestep land means, one DataArray per timestep and mode.
    target_temp, baseline_temp, pred_temp, mean_pred_temp = [], [], [], []

    def mean_over_land(data, dims, coords, time_coord):
        da = xr.DataArray(data, dims=dims, coords=coords) * land_mask
        return da.mean(dim=("lat", "lon")).assign_coords(time=time_coord)

    for idx, (ts, dt) in enumerate(zip(times, datetimes), 1):
        target = load(ts, f"{ts}-target")
        baseline = load(ts, f"{ts}-baseline")
        predictions = load(ts, f"{ts}-predictions")
        try:
            regression_pred = load(ts, f"{ts}-regression-prediction")
        except FileNotFoundError:
            # The regression prediction is optional; any other failure
            # (corrupt file, interrupt) should surface instead of being hidden.
            regression_pred = None

        # Kelvin -> degrees Celsius, then land mean.
        target_temp.append(mean_over_land(
            target[t2m_out] - 273.15, ("lat", "lon"), land_mask.coords, dt))
        baseline_temp.append(mean_over_land(
            baseline[t2m_in] - 273.15, ("lat", "lon"), land_mask.coords, dt))
        pred_temp.append(mean_over_land(
            predictions[:, t2m_out, :, :] - 273.15, ("member", "lat", "lon"),
            {"member": np.arange(predictions.shape[0]), **land_mask.coords}, dt))
        if regression_pred is not None:
            mean_pred_temp.append(mean_over_land(
                regression_pred[t2m_out] - 273.15, ("lat", "lon"), land_mask.coords, dt))

        if idx % log_interval == 0 or idx == len(times):
            logger.info(f"Processed {idx}/{len(times)} timesteps ({ts})")

    # Diurnal (hour-of-day) means; the ensemble also yields a member spread.
    temp_target_mean, _ = concat_and_group_diurnal(target_temp)
    temp_baseline_mean, _ = concat_and_group_diurnal(baseline_temp)
    temp_pred_mean, temp_pred_std = concat_and_group_diurnal(pred_temp, is_member=True)

    def save_plot(hour, means, stds, labels, ylabel, title, out_path):
        # Repeat hour 0 at x=24 so the curve closes the daily cycle.
        hrs = np.concatenate([hour.values, [24]])
        plt.figure(figsize=(8, 4))
        for mean, std, label in zip(means, stds, labels):
            vals = np.append(mean.values, mean.values[0])
            line, = plt.plot(hrs, vals, label=label)
            if std is not None:
                stdv = np.append(std.values, std.values[0])
                # No clamping at zero: temperatures in degrees Celsius can be
                # negative, so the band must be free to extend below 0.
                plt.fill_between(hrs, vals - stdv, vals + stdv, color=line.get_color(), alpha=0.3)
        plt.xlabel('Hour (UTC)')
        plt.xticks(range(0, 25, 3))
        plt.xlim(0, 24)
        plt.ylabel(ylabel)
        plt.title(title)
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        Path(out_path).parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(out_path)
        plt.close()

    data = [temp_target_mean, temp_baseline_mean, temp_pred_mean]
    labels = ['Target', 'Input', 'CorrDiff ± Std(Members)']
    stds = [None, None, temp_pred_std]
    if mean_pred_temp:
        temp_mean_pred_mean, _ = concat_and_group_diurnal(mean_pred_temp)
        data.append(temp_mean_pred_mean)
        labels.append('Regression Prediction')
        stds.append(None)

    output_path = out_root / cfg.get("results_dir_name", "evaluation_maps")
    output_path.mkdir(parents=True, exist_ok=True)
    save_plot(
        temp_target_mean.hour,
        data,
        stds,
        labels,
        '2m Temperature [°C]',
        'Diurnal Cycle of 2m Temperature',
        output_path / 'diurnal_cycle_2t.png'
    )

    logger.info("Plots saved.")
+if __name__ == '__main__': + main(parse_eval_cli()) diff --git a/src/hirad/eval/diurnal_cycle_temp_wind.py b/src/hirad/eval/diurnal_cycle_wind.py similarity index 70% rename from src/hirad/eval/diurnal_cycle_temp_wind.py rename to src/hirad/eval/diurnal_cycle_wind.py index c9b63f8..ea056c4 100644 --- a/src/hirad/eval/diurnal_cycle_temp_wind.py +++ b/src/hirad/eval/diurnal_cycle_wind.py @@ -2,7 +2,6 @@ from datetime import datetime from pathlib import Path -import hydra import matplotlib.pyplot as plt import numpy as np import torch @@ -15,7 +14,7 @@ def main(cfg: dict): logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) - logger.info("Starting computation for diurnal cycles of 2m temperature and windspeed") + logger.info("Starting computation for diurnal cycle of windspeed") try: generation_dir, gen_cfg, times = load_generation_setup(cfg) except ValueError as exc: @@ -28,15 +27,11 @@ def main(cfg: dict): indices = get_channel_indices(gen_cfg) out_ch = indices['output'] in_ch = indices['input'] - - # Temperature channel (try '2t' first, fallback to 't2m') - t2m_out = out_ch.get('2t', out_ch.get('t2m')) - t2m_in = in_ch.get('2t', in_ch.get('t2m', t2m_out)) - + # Wind channels u_out = out_ch['10u'] u_in = in_ch.get('10u', u_out) - v_out = out_ch['10v'] + v_out = out_ch['10v'] v_in = in_ch.get('10v', v_out) # Output path @@ -48,7 +43,6 @@ def load(ts, fn): land_mask = load_land_sea_mask(cfg.get("land_sea_mask_path"), cfg.get("height"), cfg.get("width")) # Prepare lists to collect DataArrays - target_temp, baseline_temp, pred_temp, mean_pred_temp = [], [], [], [] target_wind, baseline_wind, pred_wind, mean_pred_wind = [], [], [], [] def mean_over_land(data, dims, coords, time_coord): @@ -68,19 +62,6 @@ def mean_over_land(data, dims, coords, time_coord): except: regression_pred = None - # Process temperature (convert to Celsius) - target_temp.append(mean_over_land( - target[t2m_out] - 273.15, ("lat","lon"), land_mask.coords, dt)) - 
baseline_temp.append(mean_over_land( - baseline[t2m_in] - 273.15, ("lat","lon"), land_mask.coords, dt)) - pred_temp.append(mean_over_land( - predictions[:, t2m_out, :, :] - 273.15, ("member","lat","lon"), - {"member": np.arange(predictions.shape[0]), **land_mask.coords}, dt)) - if regression_pred is not None: - mean_pred_temp.append(mean_over_land( - regression_pred[t2m_out] - 273.15, ("lat","lon"), land_mask.coords, dt)) - - # Process wind speed target_wind.append(mean_over_land( np.hypot(target[u_out], target[v_out]), ("lat","lon"), land_mask.coords, dt)) @@ -97,12 +78,6 @@ def mean_over_land(data, dims, coords, time_coord): logger.info(f"Processed {idx}/{len(times)} timesteps ({ts})") # Compute diurnal means and stds - temp_target_mean, _ = concat_and_group_diurnal(target_temp) - temp_baseline_mean, _ = concat_and_group_diurnal(baseline_temp) - temp_pred_mean, temp_pred_std = concat_and_group_diurnal(pred_temp, is_member=True) - if mean_pred_temp: - temp_mean_pred_mean, _ = concat_and_group_diurnal(mean_pred_temp) - wind_target_mean, _ = concat_and_group_diurnal(target_wind) wind_baseline_mean, _ = concat_and_group_diurnal(baseline_wind) wind_pred_mean, wind_pred_std = concat_and_group_diurnal(pred_wind, is_member=True) @@ -130,27 +105,12 @@ def save_plot(hour, means, stds, labels, ylabel, title, out_path): plt.savefig(out_path) plt.close() - data = [temp_target_mean, temp_baseline_mean, temp_pred_mean, temp_mean_pred_mean] if mean_pred_temp else [temp_target_mean, temp_baseline_mean, temp_pred_mean] - labels = ['Target', 'Input', 'CorrDiff ± Std(Members)', 'Regression Prediction'] if mean_pred_temp else ['Target', 'Input', 'CorrDiff ± Std(Members)'] - stds = [None, None, temp_pred_std, None] if mean_pred_temp else [None, None, temp_pred_std] - - output_path = out_root / cfg.get("results_dir_name", "evaluation_maps") - output_path.mkdir(parents=True, exist_ok=True) - # Generate plots - save_plot( - temp_target_mean.hour, - data, - stds, - labels, - '2m 
def make_percentile_values(per_decade: int = 20) -> np.ndarray:
    """Return sorted, de-duplicated percentile levels with log-spaced tails.

    The low tail is log-spaced from 0.001 to 10 with ``per_decade`` steps per
    decade, the high tail is its mirror image (90 ... 99.999), and the range
    10 ... 90 in between is sampled linearly with ``2 * per_decade + 1`` points.
    """
    n_tail = 4 * per_decade + 1
    n_mid = 2 * per_decade + 1

    low_tail = np.logspace(-3, 1, n_tail)   # 0.001 ... 10
    high_tail = 100.0 - low_tail[::-1]      # 90 ... 99.999
    mid = np.linspace(10.0, 90.0, n_mid)    # linear centre

    # np.unique sorts and removes the shared end points 10 and 90.
    return np.unique(np.concatenate((low_tail, mid, high_tail)))
+ """ + half_period = period / 2.0 + return ((prediction - target + half_period) % period) - half_period + + def parse_eval_cli(allow_times: bool = False) -> dict: """Parse standard eval CLI args (``--config-name``) and return the loaded YAML config. diff --git a/src/hirad/eval/map_precip_stats.py b/src/hirad/eval/map_precip_stats.py index 30e2f60..685a888 100644 --- a/src/hirad/eval/map_precip_stats.py +++ b/src/hirad/eval/map_precip_stats.py @@ -9,7 +9,7 @@ from hirad.eval.eval_utils import get_channel_indices, grid_cfg_from_cfg, load_generation_setup, parse_eval_cli, resolve_ts_dir from hirad.eval.plotting import ( - plot_map_precipitation, plot_map + plot_difference_map, plot_map, plot_map_precipitation ) @@ -128,9 +128,12 @@ def plot_stat_map(data, filename, stat_config, label, grid_cfg): ) -def _load_predictions_all_members(filepath, conv_factor): - """Load prediction file once and return (n_members, C, H, W) tensor.""" - return torch.load(filepath, weights_only=False) * conv_factor +def _difference_label(stat_type): + if stat_type == 'weth_freq': + return 'Difference [%]' + if stat_type in ('cdd', 'cwd'): + return 'Difference [days]' + return 'Difference [mm/day]' def main(cfg: dict): @@ -183,6 +186,8 @@ def main(cfg: dict): 'regression-prediction': (tp_out, 'Regression Prediction') } + mode_results = {} + for mode, (tp_channel, label) in basic_modes.items(): logger.info(f"Processing mode: {mode}") data_list = [] @@ -196,54 +201,91 @@ def main(cfg: dict): logger.warning(f"{mode} not available, skipping") continue - # Stack into (T, H, W) numpy array - mode_data = np.stack(data_list, axis=0).astype(np.float64) + # Stack into (T, H, W) numpy array. float32 to save memory. 
+ mode_data = np.stack(data_list, axis=0).astype(np.float32) del data_list + mode_results[mode] = {} for stat_config in stat_configs: logger.info(f"Computing {stat_config['title_stat']} for {mode}...") result = apply_statistic(mode_data, times_dt, stat_config['type'], stat_config['param'], wet_threshold) - map_output_dir = output_path / f"maps_{stat_config['stat_name']}" + mode_results[mode][stat_config['stat_name']] = result + map_output_dir = output_path / f"maps_precip_{stat_config['stat_name']}" map_output_dir.mkdir(parents=True, exist_ok=True) plot_stat_map(result, str(map_output_dir / f'{mode}_{stat_config["stat_name"]}'), stat_config, label, grid_cfg) - # --- Predictions: load each file ONCE, distribute to all members --- + del mode_data + + target_results = mode_results.get('target') + if target_results is None: + logger.warning("Target mode not available; skipping prediction-minus-target difference maps for basic modes") + else: + for mode, (_, label) in basic_modes.items(): + if mode == 'target' or mode not in mode_results: + continue + logger.info(f"Generating {mode} minus target difference maps") + for stat_config in stat_configs: + stat_name = stat_config['stat_name'] + if stat_name not in mode_results[mode] or stat_name not in target_results: + continue + diff = mode_results[mode][stat_name] - target_results[stat_name] + map_output_dir = output_path / f"maps_precip_{stat_name}" + map_output_dir.mkdir(parents=True, exist_ok=True) + plot_difference_map( + diff, + str(map_output_dir / f'{mode}_minus_target_{stat_name}'), + title=f'{label} - Target: {stat_config["title_stat"]} Difference', + label=_difference_label(stat_config['type']), + grid_cfg=grid_cfg, + ) + + # --- Predictions: process ONE member at a time to bound memory usage --- logger.info("Processing predictions mode...") sample_data = torch.load(resolve_ts_dir(out_root, times[0]) / times[0] / f"{times[0]}-predictions", weights_only=False) n_members = sample_data.shape[0] del sample_data 
logger.info(f"Found {n_members} ensemble members") - # Pre-allocate arrays for ALL members at once: (n_members, T, H, W) - # If memory is tight, we can do this in chunks. For 16 members × 2200 × 704 × 1088 × 4 bytes ≈ 107 GB - # Instead we can do cummulative statistics on the fly without storing all members in memory (like in map_wind_stats), but this works for now. - H, W = cfg.get("height"), cfg.get("width") - member_arrays = [np.empty((len(times), H, W), dtype=np.float32) for _ in range(n_members)] - - logger.info("Loading all prediction timesteps (single pass over files)...") - for i, ts in enumerate(times): - if i % log_interval == 0: - logger.info(f"Loading predictions timestep {i+1}/{len(times)}: {ts}") - pred_data = torch.load(out_root / ts / f"{ts}-predictions", weights_only=False) * conv_factor - for m in range(n_members): - member_arrays[m][i] = (pred_data[m, tp_out].numpy() if isinstance(pred_data, torch.Tensor) - else pred_data[m, tp_out]) - del pred_data + H: int = cfg["height"] + W: int = cfg["width"] + member_data = np.empty((len(times), H, W), dtype=np.float32) + has_target_for_diff = target_results is not None + if not has_target_for_diff: + logger.warning("Target mode not available; skipping prediction-minus-target difference maps for members") for member_idx in range(n_members): - logger.info(f"Computing statistics for prediction member {member_idx+1}/{n_members}") - member_data = member_arrays[member_idx].astype(np.float64) + logger.info(f"Loading prediction member {member_idx+1}/{n_members} (single pass over files)...") + for i, ts in enumerate(times): + if i % log_interval == 0: + logger.info(f"Loading predictions member {member_idx+1} timestep {i+1}/{len(times)}: {ts}") + pred_data = torch.load(resolve_ts_dir(out_root, ts) / ts / f"{ts}-predictions", weights_only=False) + member_slice = pred_data[member_idx, tp_out] + member_data[i] = (member_slice.numpy() if isinstance(member_slice, torch.Tensor) else member_slice) * conv_factor + del 
pred_data + logger.info(f"Computing statistics for prediction member {member_idx+1}/{n_members}") for stat_config in stat_configs: logger.info(f"Computing {stat_config['title_stat']} for member {member_idx+1}...") member_result = apply_statistic(member_data, times_dt, stat_config['type'], stat_config['param'], wet_threshold) - map_output_dir = output_path / f"maps_{stat_config['stat_name']}" + map_output_dir = output_path / f"maps_precip_{stat_config['stat_name']}" map_output_dir.mkdir(parents=True, exist_ok=True) member_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_{stat_config["stat_name"]}') member_label = f'CorrDiff Member {member_idx+1}' plot_stat_map(member_result, member_filename, stat_config, member_label, grid_cfg) - - del member_arrays + if has_target_for_diff: + target_result = target_results.get(stat_config['stat_name']) + if target_result is not None: + diff = member_result - target_result + diff_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_minus_target_{stat_config["stat_name"]}') + plot_difference_map( + diff, + diff_filename, + title=f'CorrDiff Member {member_idx+1} - Target: {stat_config["title_stat"]} Difference', + label=_difference_label(stat_config['type']), + grid_cfg=grid_cfg, + ) + + del member_data logger.info("All precipitation statistics maps generated successfully") diff --git a/src/hirad/eval/map_temp_stats.py b/src/hirad/eval/map_temp_stats.py new file mode 100644 index 0000000..4dff7a1 --- /dev/null +++ b/src/hirad/eval/map_temp_stats.py @@ -0,0 +1,417 @@ +import logging +from datetime import datetime +from pathlib import Path + +import numpy as np +import torch +import xarray as xr +import numba + +from hirad.eval.eval_utils import get_channel_indices, grid_cfg_from_cfg, load_generation_setup, parse_eval_cli, resolve_ts_dir +from hirad.eval.plotting import plot_difference_map, plot_map, plot_map_temperature + + +@numba.njit +def _longest_spell(x): + """Longest consecutive run of True 
def apply_statistic(data_np, times_dt, stat_type, stat_param=None):
    """
    Apply a temperature statistic to a time sequence of 2m temperature maps.

    data_np: (T, H, W) float array in degrees Celsius
    times_dt: list of datetime objects (length T); only used by the
        daily-based indices, which resample the series to 1-day bins
    stat_type: one of 'mean', 'std', 'max', 'min', 'quantile', 'warm_days',
        'hot_days', 'frost_days', 'ice_days', 'tropical_nights', 'dtr',
        'warm_spell', 'hot_spell', 'cold_spell'
    stat_param: quantile level passed to np.quantile; required for 'quantile'
    Returns: (H, W) numpy array
    Raises: ValueError for an unknown stat_type or a 'quantile' request
        without stat_param
    """
    if stat_type == 'mean':
        return np.mean(data_np, axis=0)

    if stat_type == 'std':
        return np.std(data_np, axis=0)

    if stat_type == 'max':
        # TXx: maximum temperature
        return np.max(data_np, axis=0)

    if stat_type == 'min':
        # TNn: minimum temperature
        return np.min(data_np, axis=0)

    if stat_type == 'quantile':
        if stat_param is None:
            # Fail with a clear message instead of an obscure numpy error.
            raise ValueError("stat_param (quantile level) is required for stat_type='quantile'")
        return np.quantile(data_np, stat_param, axis=0)

    # Daily-based indices: only resample the aggregate(s) the statistic needs,
    # since each resample walks the full (T, H, W) array.
    needs_daily_max = stat_type in ('warm_days', 'hot_days', 'ice_days', 'dtr',
                                    'warm_spell', 'hot_spell')
    needs_daily_min = stat_type in ('frost_days', 'tropical_nights', 'dtr', 'cold_spell')

    if needs_daily_max or needs_daily_min:
        da = xr.DataArray(
            data_np, dims=['time', 'lat', 'lon'],
            coords={'time': times_dt}
        )
        # NOTE(review): presumably days without any sample come back as NaN
        # from resample and then count as "condition not met" below - confirm
        # the evaluation period has no gaps.
        daily_max = da.resample(time="1D").max("time").values if needs_daily_max else None  # (D, H, W)
        daily_min = da.resample(time="1D").min("time").values if needs_daily_min else None  # (D, H, W)

        if stat_type == 'warm_days':
            # SU: percentage of days with daily max > 25 °C
            return np.mean(daily_max > 25.0, axis=0) * 100.0

        if stat_type == 'hot_days':
            # HD: percentage of days with daily max > 35 °C
            return np.mean(daily_max > 35.0, axis=0) * 100.0

        if stat_type == 'frost_days':
            # FD: percentage of days with daily min < 0 °C
            return np.mean(daily_min < 0.0, axis=0) * 100.0

        if stat_type == 'ice_days':
            # ID: percentage of days with daily max < 0 °C
            return np.mean(daily_max < 0.0, axis=0) * 100.0

        if stat_type == 'tropical_nights':
            # TR: percentage of days with daily min > 20 °C
            return np.mean(daily_min > 20.0, axis=0) * 100.0

        if stat_type == 'dtr':
            # DTR: mean diurnal temperature range
            return np.mean(daily_max - daily_min, axis=0)

        if stat_type == 'warm_spell':
            # WSDI: total days in spells ≥6 consecutive days with TX > 90th percentile
            p90 = np.percentile(daily_max, 90, axis=0)  # (H, W)
            cond = np.ascontiguousarray(daily_max > p90[np.newaxis, :, :])
            return _spell_days_2d(cond, 6).astype(np.float32)

        if stat_type == 'hot_spell':
            # Longest consecutive run of days with daily max > 35 °C (custom,
            # non-ETCCDI). Cast so all spell indices share one dtype.
            return consecutive_spell(daily_max, lambda x: x > 35.0).astype(np.float32)

        if stat_type == 'cold_spell':
            # CSDI: total days in spells ≥6 consecutive days with TN < 10th percentile
            p10 = np.percentile(daily_min, 10, axis=0)  # (H, W)
            cond = np.ascontiguousarray(daily_min < p10[np.newaxis, :, :])
            return _spell_days_2d(cond, 6).astype(np.float32)

    raise ValueError(f"Unsupported temperature statistic type: {stat_type}")
data, filename, + title=title, + label='Temperature [°C]', + vmin=None, vmax=None, cmap='RdBu_r', extend='both', grid_cfg=grid_cfg + ) + + +def _difference_label(stat_type): + if stat_type in ('mean', 'quantile', 'max', 'min', 'std', 'dtr'): + return 'Difference [°C]' + if stat_type in ('warm_days', 'hot_days', 'frost_days', 'ice_days', 'tropical_nights'): + return 'Difference [% of days]' + if stat_type in ('warm_spell', 'hot_spell', 'cold_spell'): + return 'Difference [days]' + return 'Difference' + + +def main(cfg: dict): + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + grid_cfg = grid_cfg_from_cfg(cfg) + + logger.info("Starting 2m temperature statistics generation") + try: + generation_dir, gen_cfg, times = load_generation_setup(cfg) + except ValueError as exc: + logger.error(str(exc)) + return + logger.info(f"Processing {len(times)} timesteps") + + times_dt = [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times] + + out_root = Path(generation_dir) + output_path = out_root / cfg.get("results_dir_name", "evaluation_maps") + output_path.mkdir(parents=True, exist_ok=True) + + indices = get_channel_indices(gen_cfg) + out_ch = indices['output'] + in_ch = indices['input'] + + # Temperature channel: try '2t', fall back to 't2m' + t2m_out = out_ch.get('2t', out_ch.get('t2m')) + t2m_in = in_ch.get('2t', in_ch.get('t2m', t2m_out)) + + if t2m_out is None: + logger.error("No temperature channel ('2t' or 't2m') found in output channels. 
Aborting.") + return + + # Conversion from Kelvin to Celsius: value * temp_conv_factor + temp_conv_offset + conv_factor = cfg.get("temp_conv_factor", 1.0) + conv_offset = cfg.get("temp_conv_offset", -273.15) + log_interval = cfg.get("log_interval", 100) + + STATISTICS_CONFIG = { + 'mean': {'type': 'mean', 'title': 'Mean Temperature'}, + 'std': {'type': 'std', 'title': 'Temperature Variability (Std Dev)'}, + 'txx': {'type': 'max', 'title': 'Maximum Temperature (TXx)'}, + 'tnn': {'type': 'min', 'title': 'Minimum Temperature (TNn)'}, + 'p99.99': {'type': 'quantile', 'param': 0.9999, 'title': '99.99th Percentile Temperature'}, + 'p99.9': {'type': 'quantile', 'param': 0.999, 'title': '99.9th Percentile Temperature'}, + 'p99': {'type': 'quantile', 'param': 0.99, 'title': '99th Percentile Temperature'}, + 'p01': {'type': 'quantile', 'param': 0.01, 'title': '1st Percentile Temperature'}, + 'p0.1': {'type': 'quantile', 'param': 0.001, 'title': '0.1th Percentile Temperature'}, + 'p0.01': {'type': 'quantile', 'param': 0.0001, 'title': '0.01th Percentile Temperature'}, + 'warm_days': {'type': 'warm_days', 'title': 'Summer Days (daily max > 25°C)'}, + 'hot_days': {'type': 'hot_days', 'title': 'Hot Days (daily max > 35°C)'}, + 'frost_days': {'type': 'frost_days', 'title': 'Frost Days (daily min < 0°C)'}, + 'ice_days': {'type': 'ice_days', 'title': 'Ice Days (daily max < 0°C)'}, + 'tropical_nights': {'type': 'tropical_nights', 'title': 'Tropical Nights (daily min > 20°C)'}, + 'dtr': {'type': 'dtr', 'title': 'Mean Diurnal Temperature Range (DTR)'}, + 'warm_spell': {'type': 'warm_spell', 'title': 'WSDI: Warm Spell Duration (TX > 90th pct, ≥6 days)'}, + 'hot_spell': {'type': 'hot_spell', 'title': 'Hot Spell Duration (daily max > 35°C)'}, + 'cold_spell': {'type': 'cold_spell', 'title': 'CSDI: Cold Spell Duration (TN < 10th pct, ≥6 days)'}, + } + stat_configs = [ + {'stat_name': name, 'title_stat': config['title'], 'param': config.get('param'), **config} + for name, config in 
STATISTICS_CONFIG.items() + ] + + # --- Basic modes: target, baseline, regression-prediction --- + basic_modes = { + 'target': (t2m_out, 'Target'), + 'baseline': (t2m_in, 'Input'), + 'regression-prediction': (t2m_out, 'Regression Prediction'), + } + + mode_results = {} + + for mode, (t2m_channel, label) in basic_modes.items(): + logger.info(f"Processing mode: {mode}") + data_list = [] + try: + for i, ts in enumerate(times): + if i % log_interval == 0: + logger.info(f"Loading {mode} timestep {i+1}/{len(times)}: {ts}") + raw = torch.load(resolve_ts_dir(out_root, ts) / ts / f"{ts}-{mode}", weights_only=False) + val = raw[t2m_channel] + arr = val.numpy() if isinstance(val, torch.Tensor) else val + data_list.append(arr * conv_factor + conv_offset) + except Exception: + logger.warning(f"{mode} not available, skipping") + continue + + mode_data = np.stack(data_list, axis=0).astype(np.float32) + del data_list + + mode_results[mode] = {} + for stat_config in stat_configs: + logger.info(f"Computing {stat_config['title_stat']} for {mode}...") + result = apply_statistic(mode_data, times_dt, stat_config['type'], stat_config['param']) + mode_results[mode][stat_config['stat_name']] = result + map_output_dir = output_path / f"maps_temp_{stat_config['stat_name']}" + map_output_dir.mkdir(parents=True, exist_ok=True) + plot_temp_stat_map(result, str(map_output_dir / f'{mode}_{stat_config["stat_name"]}'), stat_config, label, grid_cfg) + + del mode_data + + target_results = mode_results.get('target') + if target_results is None: + logger.warning("Target mode not available; skipping prediction-minus-target difference maps for basic modes") + else: + for mode, (_, label) in basic_modes.items(): + if mode == 'target' or mode not in mode_results: + continue + logger.info(f"Generating {mode} minus target difference maps") + for stat_config in stat_configs: + stat_name = stat_config['stat_name'] + if stat_name not in mode_results[mode] or stat_name not in target_results: + continue + diff = 
mode_results[mode][stat_name] - target_results[stat_name] + map_output_dir = output_path / f"maps_temp_{stat_name}" + map_output_dir.mkdir(parents=True, exist_ok=True) + plot_difference_map( + diff, + str(map_output_dir / f'{mode}_minus_target_{stat_name}'), + title=f'{label} - Target: {stat_config["title_stat"]} Difference', + label=_difference_label(stat_config['type']), + grid_cfg=grid_cfg, + ) + + # --- Predictions: process ONE member at a time to bound memory usage --- + logger.info("Processing predictions mode...") + try: + sample_data = torch.load( + resolve_ts_dir(out_root, times[0]) / times[0] / f"{times[0]}-predictions", + weights_only=False + ) + n_members = sample_data.shape[0] + del sample_data + except Exception as exc: + logger.error(f"Could not load predictions: {exc}") + return + logger.info(f"Found {n_members} ensemble members") + + H: int = cfg["height"] + W: int = cfg["width"] + member_data = np.empty((len(times), H, W), dtype=np.float32) + has_target_for_diff = target_results is not None + if not has_target_for_diff: + logger.warning("Target mode not available; skipping prediction-minus-target difference maps for members") + + for member_idx in range(n_members): + logger.info(f"Loading prediction member {member_idx+1}/{n_members} (single pass over files)...") + for i, ts in enumerate(times): + if i % log_interval == 0: + logger.info(f"Loading predictions member {member_idx+1} timestep {i+1}/{len(times)}: {ts}") + pred_data = torch.load( + resolve_ts_dir(out_root, ts) / ts / f"{ts}-predictions", + weights_only=False + ) + member_slice = pred_data[member_idx, t2m_out] + arr = member_slice.numpy() if isinstance(member_slice, torch.Tensor) else member_slice + member_data[i] = arr * conv_factor + conv_offset + del pred_data + + logger.info(f"Computing statistics for prediction member {member_idx+1}/{n_members}") + for stat_config in stat_configs: + logger.info(f"Computing {stat_config['title_stat']} for member {member_idx+1}...") + member_result = 
apply_statistic(member_data, times_dt, stat_config['type'], stat_config['param']) + map_output_dir = output_path / f"maps_temp_{stat_config['stat_name']}" + map_output_dir.mkdir(parents=True, exist_ok=True) + member_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_{stat_config["stat_name"]}') + member_label = f'CorrDiff Member {member_idx+1}' + plot_temp_stat_map(member_result, member_filename, stat_config, member_label, grid_cfg) + if has_target_for_diff: + target_result = target_results.get(stat_config['stat_name']) + if target_result is not None: + diff = member_result - target_result + diff_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_minus_target_{stat_config["stat_name"]}') + plot_difference_map( + diff, + diff_filename, + title=f'CorrDiff Member {member_idx+1} - Target: {stat_config["title_stat"]} Difference', + label=_difference_label(stat_config['type']), + grid_cfg=grid_cfg, + ) + + del member_data + logger.info("All 2m temperature statistics maps generated successfully") + + +if __name__ == '__main__': + main(parse_eval_cli()) diff --git a/src/hirad/eval/map_wind_stats.py b/src/hirad/eval/map_wind_stats.py index a6367b1..8428ca0 100644 --- a/src/hirad/eval/map_wind_stats.py +++ b/src/hirad/eval/map_wind_stats.py @@ -7,8 +7,8 @@ from hirad.datasets import get_channels_from_strings, get_strings_from_channels from hirad.utils.function_utils import get_time_from_range -from hirad.eval.eval_utils import get_channel_indices, grid_cfg_from_cfg, load_generation_setup, parse_eval_cli, resolve_ts_dir -from hirad.eval.plotting import plot_map +from hirad.eval.eval_utils import get_channel_indices, grid_cfg_from_cfg, load_generation_setup, parse_eval_cli, resolve_ts_dir, signed_circular_difference +from hirad.eval.plotting import plot_difference_map, plot_map def compute_wind_speed(u, v): @@ -199,6 +199,20 @@ def plot_wind_stat_map(data, filename, stat_config, label, grid_cfg): ) +def _wind_difference_label(stat_type): + 
if stat_type in ['mean_speed', 'max_speed']: + return 'Difference [m/s]' + if stat_type == 'wind_power': + return 'Difference [m^3/s^3]' + if stat_type in ['calm_freq', 'light_breeze_freq', 'moderate_breeze_freq', 'strong_breeze_freq', 'gale_freq']: + return 'Difference [%]' + if stat_type in ['prevailing_direction', 'direction_variability']: + return 'Difference [degrees]' + if stat_type in ['mean_u', 'mean_v']: + return 'Difference [m/s]' + return 'Difference' + + def main(cfg: dict): logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -294,6 +308,7 @@ def main(cfg: dict): logger.info(f"Generating {len(stat_configs)} statistics for {len(basic_modes)} modes + predictions") log_interval = cfg.get("log_interval", 100) + mode_results = {} for mode, (wind_channels, label) in basic_modes.items(): logger.info(f"Processing mode: {mode}") @@ -323,11 +338,42 @@ def main(cfg: dict): label, grid_cfg ) - del results + mode_results[mode] = results except Exception as e: logger.error(f"Failed computing statistics for {mode}: {e}") continue + target_results = mode_results.get('target') + if target_results is None: + logger.warning("Target mode not available; skipping prediction-minus-target difference maps for basic modes") + else: + for mode, (_, label) in basic_modes.items(): + if mode == 'target' or mode not in mode_results: + continue + logger.info(f"Generating {mode} minus target difference maps") + for stat_config in stat_configs: + stat_key = stat_config['stat_name'] + if stat_key not in mode_results[mode] or stat_key not in target_results: + continue + if stat_config['type'] == 'prevailing_direction': + diff = signed_circular_difference(mode_results[mode][stat_key], target_results[stat_key]) + else: + diff = mode_results[mode][stat_key] - target_results[stat_key] + map_output_dir = output_path / f"maps_wind_{stat_key}" + map_output_dir.mkdir(parents=True, exist_ok=True) + plot_difference_map( + diff, + str(map_output_dir / 
f'{mode}_minus_target_{stat_key}'), + title=f'{label} - Target: {stat_config["title_stat"]} Difference', + label=_wind_difference_label(stat_config['type']), + grid_cfg=grid_cfg, + fixed_vmax=180.0 if stat_config['type'] == 'prevailing_direction' else None, + ) + + has_target_for_diff = target_results is not None + if not has_target_for_diff: + logger.warning("Target mode not available; skipping prediction-minus-target difference maps for members") + logger.info("Processing predictions mode...") try: @@ -469,6 +515,22 @@ def main(cfg: dict): map_output_dir.mkdir(parents=True, exist_ok=True) member_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_{stat_key}') plot_wind_stat_map(member_result, member_filename, stat_config, f'CorrDiff Member {member_idx+1}', grid_cfg) + if has_target_for_diff: + target_result = target_results.get(stat_key) + if target_result is not None: + if stype == 'prevailing_direction': + diff = signed_circular_difference(member_result, target_result) + else: + diff = member_result - target_result + diff_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_minus_target_{stat_key}') + plot_difference_map( + diff, + diff_filename, + title=f'CorrDiff Member {member_idx+1} - Target: {stat_config["title_stat"]} Difference', + label=_wind_difference_label(stype), + grid_cfg=grid_cfg, + fixed_vmax=180.0 if stype == 'prevailing_direction' else None, + ) del member_result except Exception as e: logger.error(f"Failed {stat_config['title_stat']} for member {member_idx+1}: {e}") diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index 61e0ff0..b1b0f69 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -1,4 +1,5 @@ import logging +from typing import Optional import cartopy.crs as ccrs import cartopy.feature as cfeature @@ -79,6 +80,39 @@ def plot_map(values: np.array, fig.savefig(f"{filename}.png", dpi=300, bbox_inches="tight") plt.close(fig) + +def compute_symmetric_vmax(values: 
np.ndarray, percentile: float = 99.0, fallback: float = 1.0) -> float: + """Return a robust symmetric colorbar half-range based on absolute values.""" + vmax = float(np.nanpercentile(np.abs(values), percentile)) + if vmax != vmax or vmax <= 0: + return fallback + return vmax + + +def plot_difference_map( + values: np.ndarray, + filename: str, + title: str = '', + label: str = 'Difference', + grid_cfg: GridConfig = DEFAULT_GRID_CONFIG, + cmap: str = 'RdBu_r', + percentile: float = 99.0, + fixed_vmax: Optional[float] = None, +): + """Plot a difference map with symmetric diverging bounds around zero.""" + vmax = fixed_vmax if fixed_vmax is not None else compute_symmetric_vmax(values, percentile=percentile) + plot_map( + values, + filename, + title=title, + label=label, + vmin=-vmax, + vmax=vmax, + cmap=cmap, + extend='both', + grid_cfg=grid_cfg, + ) + def plot_map_precipitation(values, filename, title='', threshold=0.01, rfac=1000.0, grid_cfg=DEFAULT_GRID_CONFIG): """Plot precipitation data with specific colormap and thresholds.""" # Scale and mask values below threshold @@ -106,6 +140,33 @@ def plot_map_precipitation(values, filename, title='', threshold=0.01, rfac=1000 grid_cfg=grid_cfg, ) +def plot_map_temperature(values, filename, title='', grid_cfg=DEFAULT_GRID_CONFIG): + """Plot 2m temperature data with Meteoswiss-style colormap.""" + colors = [ + "#3366FF", "#4C97FF", "#4CA8FF", "#00CCFF", + "#DEE699", "#A6D473", "#6BBF4D", "#33AB26", "#009900", + "#33B300", "#66CC00", "#99E600", "#CCFF00", "#FFFF00", + "#FFCC00", "#FF9900", "#FF6600", "#FF3300", "#FF0000", + "#EB00EB", "#FF40FF", "#FF80FF", "#FFBFFF", "#FFE0FF", "#FFF5FF", + "#D9D9D9", "#A6A6A6", "#737373", + ] + bounds = np.arange(-8, 49, 2) + + cmap = ListedColormap(colors) + norm = BoundaryNorm(bounds, ncolors=len(colors), clip=False) + + plot_map( + values, filename, + cmap=cmap, + norm=norm, + ticks=bounds, + title=title, + label='Temperature [°C]', + extend='both', + grid_cfg=grid_cfg, + ) + + def 
plot_map_wind_precip( u: np.ndarray, v: np.ndarray, diff --git a/src/hirad/eval/probability_of_exceedance_wind.py b/src/hirad/eval/probability_of_exceedance_wind.py index f280e2d..fa2d82d 100644 --- a/src/hirad/eval/probability_of_exceedance_wind.py +++ b/src/hirad/eval/probability_of_exceedance_wind.py @@ -1,15 +1,12 @@ """Probability of exceedance for wind speed and components.""" import logging +import time from pathlib import Path -import hydra import matplotlib.pyplot as plt import numpy as np import torch -import xarray as xr -from hirad.datasets import get_channels_from_strings, get_strings_from_channels -from hirad.utils.function_utils import get_time_from_range from hirad.eval.eval_utils import get_channel_indices, load_generation_setup, parse_eval_cli, resolve_ts_dir from hirad.eval.eval_utils import percentiles_from_histogram @@ -28,18 +25,21 @@ def compute_exceedance_probs(values, thresholds, use_abs=False): def update_exceedance_counts(counts, total, values, thresholds, use_abs=False): - """Update exceedance counts incrementally.""" + """Update exceedance counts incrementally via searchsorted (O(n log k)).""" data = np.abs(values) if use_abs else values - counts += (data[:, None] > thresholds[None, :]).sum(axis=0) - total += len(values) + # idx[i] = number of thresholds strictly less than data[i] (side='left') + # data[i] > thresholds[j] iff idx[i] > j + idx = np.searchsorted(thresholds, data, side='left') + bin_counts = np.bincount(idx, minlength=len(thresholds) + 1) + counts += len(data) - np.cumsum(bin_counts)[: len(thresholds)] + total += len(data) return counts, total def compute_percentiles(values, percentile_dict, use_abs=False): """Compute percentiles.""" data = np.abs(values) if use_abs else values - data_array = xr.DataArray(data) - return {key: data_array.quantile(p).item() for key, p in percentile_dict.items()} + return {key: float(np.quantile(data, p)) for key, p in percentile_dict.items()} def save_exceedance_plot(exceedance_data_dict, 
thresholds, labels, colors, title, ylabel, out_path, percentiles_data=None): @@ -119,6 +119,34 @@ def save_exceedance_plot(exceedance_data_dict, thresholds, labels, colors, title plt.close() +def _accumulate(exc_counts, h_counts, speed_vals, u_vals, v_vals, + thresholds, hist_interior, n_hist_bins): + """Update exceedance and histogram count arrays in-place.""" + abs_u = np.abs(u_vals) + abs_v = np.abs(v_vals) + n_thr = len(thresholds) + + # Exceedance counts: O(n log k) via searchsorted+bincount + for counts, data in [ + (exc_counts['speed'], speed_vals), + (exc_counts['u'], abs_u), + (exc_counts['v'], abs_v), + ]: + idx = np.searchsorted(thresholds, data, side='left') + bin_counts = np.bincount(idx, minlength=n_thr + 1) + counts += len(data) - np.cumsum(bin_counts)[:n_thr] + + # Histogram counts: searchsorted+bincount, no temporary bool array + for hcounts, data in [ + (h_counts['speed'], speed_vals), + (h_counts['u'], abs_u), + (h_counts['v'], abs_v), + ]: + hcounts += np.bincount( + np.searchsorted(hist_interior, data, side='right'), minlength=n_hist_bins + ) + + def main(cfg: dict): # Setup logging logging.basicConfig(level=logging.INFO) @@ -141,12 +169,13 @@ def main(cfg: dict): v10_out = indices['output'].get('10v') u10_in = indices['input'].get('10u', u10_out) v10_in = indices['input'].get('10v', v10_out) - + if u10_out is None or v10_out is None: logger.error("Wind components (10u, 10v) not found in dataset!") return - - logger.info(f"Wind component channel indices - output: 10u={u10_out}, 10v={v10_out}, input: 10u={u10_in}, 10v={v10_in}") + + logger.info(f"Wind component channel indices - output: 10u={u10_out}, 10v={v10_out}, " + f"input: 10u={u10_in}, 10v={v10_in}") # Define thresholds for exceedance calculation (same for all variables) thresholds = np.logspace(-1, 2, 200) # From 0.1 to ~100 m/s @@ -155,123 +184,123 @@ def main(cfg: dict): # Histogram bins for percentile estimation (fine-grained log-spaced) hist_bins = np.concatenate([ np.array([0.0]), - 
np.logspace(-1, 2.5, 5000) # From 0.1 to ~316 m/s + np.logspace(-1, 2.5, 5000) # From 0.1 to ~316 m/s ]) n_hist_bins = len(hist_bins) - 1 - - # Storage for exceedance counts (incremental computation) - exceedance_counts = { - 'speed': {}, 'u': {}, 'v': {} - } - totals = {'speed': {}, 'u': {}, 'v': {}} - hist_counts = { - 'speed': {}, 'u': {}, 'v': {} - } - - # -- Process target and baseline -- - for mode in ['target', 'baseline', 'regression-prediction']: - logger.info(f"Processing mode: {mode}") - - # Initialize counts - for var in ['speed', 'u', 'v']: + # Interior edges used by searchsorted+bincount (equivalent to np.histogram) + hist_interior = hist_bins[1:-1] + + det_modes = ('target', 'baseline', 'regression-prediction') + VARS = ('speed', 'u', 'v') + + # Storage for exceedance counts and histograms + exceedance_counts = {var: {} for var in VARS} + totals = {var: {} for var in VARS} + hist_counts = {var: {} for var in VARS} + for mode in det_modes: + for var in VARS: exceedance_counts[var][mode] = np.zeros(n_thresholds, dtype=np.int64) - totals[var][mode] = 0 - hist_counts[var][mode] = np.zeros(n_hist_bins, dtype=np.int64) - - try: - for i, ts in enumerate(times): - if i % cfg.get("log_interval") == 0: - logger.info(f"Processing timestep {i+1}/{len(times)}") - - data = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-{mode}", weights_only=False) - - # Extract wind components - if mode in ['target', 'regression-prediction']: - u = data[u10_out] - v = data[v10_out] - else: # baseline - u = data[u10_in] - v = data[v10_in] - - wind_speed = compute_wind_speed(u, v) - - # Get valid values - valid_mask = ~np.isnan(wind_speed) - speed_vals = wind_speed[valid_mask].flatten() - u_vals = u[valid_mask].flatten() - v_vals = v[valid_mask].flatten() - - # Update exceedance counts incrementally - exceedance_counts['speed'][mode], totals['speed'][mode] = update_exceedance_counts( - exceedance_counts['speed'][mode], totals['speed'][mode], speed_vals, thresholds, use_abs=False 
- ) - exceedance_counts['u'][mode], totals['u'][mode] = update_exceedance_counts( - exceedance_counts['u'][mode], totals['u'][mode], u_vals, thresholds, use_abs=True - ) - exceedance_counts['v'][mode], totals['v'][mode] = update_exceedance_counts( - exceedance_counts['v'][mode], totals['v'][mode], v_vals, thresholds, use_abs=True - ) - - # Collect samples for percentiles (subsample to save memory) - hist_counts['speed'][mode] += np.histogram(speed_vals, bins=hist_bins)[0] - hist_counts['u'][mode] += np.histogram(np.abs(u_vals), bins=hist_bins)[0] - hist_counts['v'][mode] += np.histogram(np.abs(v_vals), bins=hist_bins)[0] - - except Exception as e: - logger.warning(f"{mode} data not found or error occurred, skipping: {e}") - continue - - logger.info(f"Processed {totals['speed'][mode]} values for {mode}") - - # -- Process predictions: compute exceedance for each ensemble member -- - logger.info("Processing predictions") - - n_members = None - member_counts = {'speed': [], 'u': [], 'v': []} - member_totals = {'speed': [], 'u': [], 'v': []} - member_hist_counts = {'speed': [], 'u': [], 'v': []} - + totals[var][mode] = 0 + hist_counts[var][mode] = np.zeros(n_hist_bins, dtype=np.int64) + + n_members = None + member_counts = {var: [] for var in VARS} + member_totals = {var: [] for var in VARS} + member_hist_counts = {var: [] for var in VARS} + + skip_det = set() # modes whose files were missing + skip_preds = False + + log_interval = cfg.get("log_interval", 24) + + # Pre-resolve timestamp directories once (avoids 4× filesystem glob per timestep) + logger.info(f"Resolving {len(times)} timestep directories ...") + ts_dirs = {ts: resolve_ts_dir(out_root, ts) / ts for ts in times} + + # -- Single pass over all timesteps -- + t0 = time.perf_counter() for i, ts in enumerate(times): - if i % cfg.get("log_interval") == 0: - logger.info(f"Processing timestep {i+1}/{len(times)}") - - preds = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-predictions", weights_only=False) # 
[n_members, n_channels, lat, lon] - - if n_members is None: - n_members = preds.shape[0] - for var in ['speed', 'u', 'v']: - member_counts[var] = [np.zeros(n_thresholds, dtype=np.int64) for _ in range(n_members)] - member_totals[var] = [0 for _ in range(n_members)] - member_hist_counts[var] = [np.zeros(n_hist_bins, dtype=np.int64) for _ in range(n_members)] - - for member_idx in range(n_members): - u = preds[member_idx, u10_out] - v = preds[member_idx, v10_out] + if i % log_interval == 0: + elapsed = time.perf_counter() - t0 + logger.info(f"Processing timestep {i+1}/{len(times)} ({elapsed:.1f}s elapsed)") + + ts_dir = ts_dirs[ts] + + # Deterministic modes: target, baseline, regression-prediction + for mode in det_modes: + if mode in skip_det: + continue + try: + data = torch.load(ts_dir / f"{ts}-{mode}", weights_only=False) + except FileNotFoundError: + logger.warning(f"[{mode}] file not found at {ts}, skipping mode entirely") + skip_det.add(mode) + continue + + ch_u = u10_in if mode == 'baseline' else u10_out + ch_v = v10_in if mode == 'baseline' else v10_out + u = data[ch_u] + v = data[ch_v] + wind_speed = compute_wind_speed(u, v) - valid_mask = ~np.isnan(wind_speed) - speed_vals = wind_speed[valid_mask].flatten() - u_vals = u[valid_mask].flatten() - v_vals = v[valid_mask].flatten() - - # Update counts - member_counts['speed'][member_idx], member_totals['speed'][member_idx] = update_exceedance_counts( - member_counts['speed'][member_idx], member_totals['speed'][member_idx], speed_vals, thresholds, use_abs=False - ) - member_counts['u'][member_idx], member_totals['u'][member_idx] = update_exceedance_counts( - member_counts['u'][member_idx], member_totals['u'][member_idx], u_vals, thresholds, use_abs=True - ) - member_counts['v'][member_idx], member_totals['v'][member_idx] = update_exceedance_counts( - member_counts['v'][member_idx], member_totals['v'][member_idx], v_vals, thresholds, use_abs=True + speed_vals = wind_speed[valid_mask].ravel() + u_vals = 
u[valid_mask].ravel() + v_vals = v[valid_mask].ravel() + + _accumulate( + {var: exceedance_counts[var][mode] for var in VARS}, + {var: hist_counts[var][mode] for var in VARS}, + speed_vals, u_vals, v_vals, + thresholds, hist_interior, n_hist_bins, ) - - # Collect samples for percentiles - member_hist_counts['speed'][member_idx] += np.histogram(speed_vals, bins=hist_bins)[0] - member_hist_counts['u'][member_idx] += np.histogram(np.abs(u_vals), bins=hist_bins)[0] - member_hist_counts['v'][member_idx] += np.histogram(np.abs(v_vals), bins=hist_bins)[0] - - logger.info(f"Collected {n_members} ensemble members for predictions") - + totals['speed'][mode] += len(speed_vals) + totals['u'][mode] += len(u_vals) + totals['v'][mode] += len(v_vals) + + # Predictions (ensemble) + if not skip_preds: + try: + preds = torch.load(ts_dir / f"{ts}-predictions", weights_only=False) # [M, C, H, W] + except FileNotFoundError: + logger.warning(f"[predictions] file not found at {ts}, skipping ensemble entirely") + skip_preds = True + continue + + if n_members is None: + n_members = preds.shape[0] + logger.info(f"Detected {n_members} ensemble members") + for var in VARS: + member_counts[var] = [np.zeros(n_thresholds, dtype=np.int64) for _ in range(n_members)] + member_totals[var] = [0] * n_members + member_hist_counts[var] = [np.zeros(n_hist_bins, dtype=np.int64) for _ in range(n_members)] + + for m in range(n_members): + u = preds[m, u10_out] + v = preds[m, v10_out] + wind_speed = compute_wind_speed(u, v) + valid_mask = ~np.isnan(wind_speed) + speed_vals = wind_speed[valid_mask].ravel() + u_vals = u[valid_mask].ravel() + v_vals = v[valid_mask].ravel() + + _accumulate( + {var: member_counts[var][m] for var in VARS}, + {var: member_hist_counts[var][m] for var in VARS}, + speed_vals, u_vals, v_vals, + thresholds, hist_interior, n_hist_bins, + ) + member_totals['speed'][m] += len(speed_vals) + member_totals['u'][m] += len(u_vals) + member_totals['v'][m] += len(v_vals) + + total_elapsed = 
time.perf_counter() - t0 + logger.info(f"Single-pass loop completed in {total_elapsed:.1f}s for {len(times)} timesteps") + if n_members is not None: + logger.info(f"Collected {n_members} ensemble members for predictions") + else: + n_members = 0 # no predictions found; guard range() calls below + # Convert counts to probabilities exceedance_data = {'speed': {}, 'u': {}, 'v': {}} diff --git a/src/hirad/eval_precip.sh b/src/hirad/eval_precip.sh index 4630d49..0f4241a 100644 --- a/src/hirad/eval_precip.sh +++ b/src/hirad/eval_precip.sh @@ -1,37 +1,40 @@ #!/bin/bash -#SBATCH --job-name="eval_precip" - -### HARDWARE ### -#SBATCH --partition=normal -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-node=1 -#SBATCH --cpus-per-task=72 -#SBATCH --time=12:00:00 -#SBATCH --no-requeue -#SBATCH --exclusive - -### OUTPUT ### -#SBATCH --output=./logs/plots_precip.log - -### ENVIRONMENT #### -#SBATCH -A a161 +set -euo pipefail ### CONFIG ### CONFIG_NAME="src/hirad/conf/eval_real.yaml" -srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -c " - pip install -e . 
- +CMDS=( # Diurnal cycle - # python src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py --config-name=${CONFIG_NAME} - # python src/hirad/eval/diurnal_cycle_precip_p99.py --config-name=${CONFIG_NAME} - + "python src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py" + "python src/hirad/eval/diurnal_cycle_precip_high_percentiles.py" # Histograms - # python src/hirad/eval/hist.py --config-name=${CONFIG_NAME} - # python src/hirad/eval/probability_of_exceedance.py --config-name=${CONFIG_NAME} - + "python src/hirad/eval/hist.py" + "python src/hirad/eval/probability_of_exceedance.py" + # QQ + "python -m hirad.eval.bias_by_percentile_precip" + "python -m hirad.eval.bias_by_percentile_precip_smoothed" # Maps - # python src/hirad/eval/map_precip_stats.py --config-name=${CONFIG_NAME} -" \ No newline at end of file + "python src/hirad/eval/map_precip_stats.py" + "python -m hirad.eval.diurnal_cycle_precip_maps" +) + +for cmd in "${CMDS[@]}"; do + name=$(basename "${cmd##* }" .py | tr '.' '_') + job_id=$(sbatch \ + --job-name="eval_precip_${name}" \ + --partition=normal \ + --nodes=1 \ + --ntasks-per-node=1 \ + --gpus-per-node=1 \ + --cpus-per-task=72 \ + --time=24:00:00 \ + --no-requeue \ + --exclusive \ + -A c38 \ + --output="./logs/plots_precip_${name}_%j.log" \ + --parsable \ + --wrap="srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -lc 'pip install -e . 
&& ${cmd} --config-name=${CONFIG_NAME}'") + echo "Submitted ${name}: ${job_id}" +done \ No newline at end of file diff --git a/src/hirad/eval_temp.sh b/src/hirad/eval_temp.sh new file mode 100644 index 0000000..a9773ee --- /dev/null +++ b/src/hirad/eval_temp.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +set -euo pipefail + +### CONFIG ### +CONFIG_NAME="src/hirad/conf/eval_real.yaml" + +CMDS=( + # Diurnal cycle of 2m temperature + "python src/hirad/eval/diurnal_cycle_temp.py" + # QQ + "python -m hirad.eval.bias_by_percentile_temp" + # Maps + "python src/hirad/eval/map_temp_stats.py" +) + +for cmd in "${CMDS[@]}"; do + name=$(basename "${cmd##* }" .py | tr '.' '_') + job_id=$(sbatch \ + --job-name="eval_temp_${name}" \ + --partition=normal \ + --nodes=1 \ + --ntasks-per-node=1 \ + --gpus-per-node=1 \ + --cpus-per-task=72 \ + --time=12:00:00 \ + --no-requeue \ + --exclusive \ + -A c38 \ + --output="./logs/plots_temp_${name}_%j.log" \ + --parsable \ + --wrap="srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -lc 'pip install -e . && ${cmd} --config-name=${CONFIG_NAME}'") + echo "Submitted ${name}: ${job_id}" +done diff --git a/src/hirad/eval_wind.sh b/src/hirad/eval_wind.sh index 88f3f22..8dc2624 100644 --- a/src/hirad/eval_wind.sh +++ b/src/hirad/eval_wind.sh @@ -1,35 +1,37 @@ #!/bin/bash -#SBATCH --job-name="eval_wind" - -### HARDWARE ### -#SBATCH --partition=normal -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-node=1 -#SBATCH --cpus-per-task=72 -#SBATCH --time=12:00:00 -#SBATCH --no-requeue -#SBATCH --exclusive - -### OUTPUT ### -#SBATCH --output=./logs/plots_wind.log - -### ENVIRONMENT #### -#SBATCH -A a161 +set -euo pipefail ### CONFIG ### CONFIG_NAME="src/hirad/conf/eval_real.yaml" -srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -c " - pip install -e . 
- - # Diurnal cycle - # python src/hirad/eval/diurnal_cycle_temp_wind.py --config-name=${CONFIG_NAME} - +CMDS=( + # Diurnal cycle of windspeed + "python src/hirad/eval/diurnal_cycle_wind.py" # Probability of exceedance - # python src/hirad/eval/probability_of_exceedance_wind.py --config-name=${CONFIG_NAME} - + "python src/hirad/eval/probability_of_exceedance_wind.py" + # QQ + "python -m hirad.eval.bias_by_percentile_wind" + "python -m hirad.eval.bias_by_percentile_wind_smoothed" # Maps - # python src/hirad/eval/map_wind_stats.py --config-name=${CONFIG_NAME} -" \ No newline at end of file + "python src/hirad/eval/map_wind_stats.py" +) + +for cmd in "${CMDS[@]}"; do + name=$(basename "${cmd##* }" .py | tr '.' '_') + job_id=$(sbatch \ + --job-name="eval_wind_${name}" \ + --partition=normal \ + --nodes=1 \ + --ntasks-per-node=1 \ + --gpus-per-node=1 \ + --cpus-per-task=72 \ + --time=24:00:00 \ + --no-requeue \ + --exclusive \ + -A c38 \ + --output="./logs/plots_wind_${name}_%j.log" \ + --parsable \ + --wrap="srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -lc 'pip install -e . && ${cmd} --config-name=${CONFIG_NAME}'") + echo "Submitted ${name}: ${job_id}" +done \ No newline at end of file