From 785569486102c549dcfa047bdfa2286c98254066 Mon Sep 17 00:00:00 2001 From: vhertel Date: Fri, 29 May 2026 18:55:40 +0200 Subject: [PATCH] Add data reader for offgrid inference --- config/config_offgrid_forecasting.yml | 261 ++++++++++++++++++ config/evaluate/eval_config_offgrid.yml | 60 ++++ config/streams/era5_1deg_forecasting/era5.yml | 22 +- .../datasets/data_reader_offgrid.py | 217 +++++++++++++++ .../datasets/multi_stream_data_sampler.py | 22 +- 5 files changed, 573 insertions(+), 9 deletions(-) create mode 100644 config/config_offgrid_forecasting.yml create mode 100644 config/evaluate/eval_config_offgrid.yml create mode 100644 src/weathergen/datasets/data_reader_offgrid.py diff --git a/config/config_offgrid_forecasting.yml b/config/config_offgrid_forecasting.yml new file mode 100644 index 000000000..a8db8edd6 --- /dev/null +++ b/config/config_offgrid_forecasting.yml @@ -0,0 +1,261 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. 
+# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 16 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 + +healpix_level: 5 + +rope_2D: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: True +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +freeze_modules: "" +load_chkpt: {} + +norm_type: "LayerNorm" +qk_norm_type: null # if null, defaults to norm_type + +##################################### + +streams_directory: "./config/streams/era5_1deg_forecasting/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? 
+ + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking"] + + num_mini_epochs: 64 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 1979-01-01T00:00 + end_date: 2022-12-31T00:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 5e-5 + lr_final_decay: 2e-6 + lr_final: 0.0 + num_steps_warmup: 256 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.98125 # == 0.85 on 2 nodes x 4 gpus + beta2 : 0.9875 # == 0.90 on 2 nodes x 4 gpus + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + loss_fcts: { "mse": { }, }, + }, + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + }, + } + + forecast : + time_step: 06:00:00 + offset: 1 + num_steps: 3 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + samples_per_mini_epoch: 256 + shuffle: False + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T00:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters 
for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: False + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +test_config: + + offgrid_eval: + # absolute path to .npy file with shape (N, 2) [lat, lon] used for offgrid inference + grid: /e/scratch/weatherai/shared_work/offgrid-test/offgrid_regular.npy + # temporal spacing between offgrid samples, e.g. 6h + frequency: 6h + # TODO add support for geoinfos + geoinfos: null + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. 
+ # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/evaluate/eval_config_offgrid.yml b/config/evaluate/eval_config_offgrid.yml new file mode 100644 index 000000000..997970383 --- /dev/null +++ b/config/evaluate/eval_config_offgrid.yml @@ -0,0 +1,60 @@ +#optional: if commented out all is taken care of by the default settings +# NB. global options apply to all run_ids +global_plotting_options: + regions: ["global"] +# image_format : "png" #options: "png", "pdf", "svg", "eps", "jpg" .. +# animation_format: "gif" #options: "mp4", "gif" +# dpi_val : 300 +# fps: 2 +# ERA5: +# marker_size: 2 +# scale_marker_size: 1 +# marker: "o" +# alpha: 0.5 +# add_healpix_grid: false +# healpix_nside: 4 +# 2t: +# vmin: 250 +# vmax: 300 +# 10u: +# vmin: -40 +# vmax: 40 + +# max_workers: 36 # hard cap on parallel workers (I/O, plotting, scoring) + +evaluation: + metrics: ["rmse", "mae"] + regions: ["global", "nhem"] + summary_plots : true + ratio_plots : false + heat_maps : false + summary_dir: "./plots/" + plot_ensemble: "members" #supported: false, "std", "minmax", "members" + plot_score_maps: false #plot scores on a 2D maps. it slows down score computation + print_summary: false #print out score values on screen. 
it can be verbose + log_scale: false + add_grid: false + score_cards: false + bar_plots: false + + +default_streams: + ERA5: + channels: ["2t", "10u", "10v", "10ff", "q_1000"] + evaluation: + forecast_step: "all" + sample: "all" + ensemble: "all" #supported: "all", "mean", [0,1,2] + plotting: + sample: [0] + forecast_step: "all" #supported: "all", [1,2,3,...], "1-50" (equivalent of [1,2,3,...50]) + plot_maps: true + plot_histograms: false + plot_animations: true + +run_ids : + toqb9dv6: + label: "test offgrid regular" + + thb1rpuj: + label: "test offgrid synop" diff --git a/config/streams/era5_1deg_forecasting/era5.yml b/config/streams/era5_1deg_forecasting/era5.yml index 0bd70ae01..1d6135208 100644 --- a/config/streams/era5_1deg_forecasting/era5.yml +++ b/config/streams/era5_1deg_forecasting/era5.yml @@ -9,26 +9,26 @@ ERA5 : type : anemoi - filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2024-1h-v3-with-era51.zarr'] stream_id : 0 - source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] - target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] + source_exclude : ['z', 'w_10', 'w_50', 'w_100', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925', 'w_1000', 'skt', 'tcw', 'cp', 'tp', 'q_50', 'q_100'] + target_exclude : ['z', 'w_10', 'w_50', 'w_100', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925', 'w_1000', 'slor', 'sdor', 'tcw', 'cp', 'tp', 'q_50', 'q_100'] + # geoinfo_channels : ['z', 'lsm', 'slor', 'sdor', 'insolation', 'cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day'] loss_weight : 1. 
location_weight : cosine_latitude - masking_rate : 0.6 - masking_rate_none : 0.05 token_size : 8 tokenize_spacetime : True max_num_targets: 20000 + frequency : 06:00:00 embed : net : transformer num_tokens : 1 num_heads : 8 - dim_embed : 256 + dim_embed : 512 num_blocks : 2 embed_target_coords : net : linear - dim_embed : 256 + dim_embed : 512 target_readout : num_layers : 2 num_heads : 4 @@ -37,6 +37,7 @@ ERA5 : ens_size : 1 num_layers : 1 channel_weights : + q_10: 0.2 q_50: 0.2 q_100: 0.23 q_150: 0.26 @@ -50,6 +51,7 @@ ERA5 : q_850: 0.71 q_925: 0.75 q_1000: 0.8 + t_10: 0.2 t_50: 0.2 t_100: 0.23 t_150: 0.26 @@ -63,6 +65,7 @@ ERA5 : t_850: 0.71 t_925: 0.75 t_1000: 0.8 + u_10: 0.2 u_50: 0.2 u_100: 0.23 u_150: 0.26 @@ -76,6 +79,7 @@ ERA5 : u_850: 0.71 u_925: 0.75 u_1000: 0.8 + v_10: 0.2 v_50: 0.2 v_100: 0.23 v_150: 0.26 @@ -89,6 +93,7 @@ ERA5 : v_850: 0.71 v_925: 0.75 v_1000: 0.8 + z_10: 0.2 z_50: 0.2 z_100: 0.23 z_150: 0.26 @@ -101,4 +106,5 @@ ERA5 : z_700: 0.61 z_850: 0.71 z_925: 0.75 - z_1000: 0.8 \ No newline at end of file + z_1000: 0.8 + \ No newline at end of file diff --git a/src/weathergen/datasets/data_reader_offgrid.py b/src/weathergen/datasets/data_reader_offgrid.py new file mode 100644 index 000000000..539b51d97 --- /dev/null +++ b/src/weathergen/datasets/data_reader_offgrid.py @@ -0,0 +1,217 @@ +# (C) Copyright 2026 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import logging +from pathlib import Path +from typing import override + +import numpy as np +from numpy.typing import NDArray + +from weathergen.common.config import parse_timedelta +from weathergen.datasets.data_reader_base import ( + DataReaderBase, + DataReaderTimestep, + ReaderData, + TimeWindowHandler, + TIndex, + check_reader_data, +) + +_logger = logging.getLogger(__name__) + + +class DataReaderOffgrid(DataReaderTimestep): + """Offgrid datareader. + + 1) loads a template grid from a numpy file, + 2) uses a configured frequency, + 3) generates coords + datetimes for inference (data values are zero placeholders) + + Expected template file format: + - .npy file with shape (N, 2), columns [lat, lon] in degrees + + stream_info entries read directly by this class: + - name + (the sample spacing is passed via the ``frequency`` argument, for example "6h") + """ + + def __init__( + self, + tw_handler: TimeWindowHandler, + filename: Path, + frequency: str | int | float | np.timedelta64, + stream_info: dict, + ref_reader: DataReaderBase | None = None, + ) -> None: + """ + Construct data reader for offgrid inference + + Parameters + ---------- + tw_handler : + TimeWindowHandler defining the temporal window for inference + filename : + path to the .npy grid template with shape (N, 2), columns [lat, lon] + frequency : + temporal spacing of offgrid samples (parsed with parse_timedelta) + stream_info : + information about stream + ref_reader : + optional reference reader to inherit metadata from (e.g. 
source/target channels, normalization) + + Returns + ------- + None + """ + + # parse frequency into numpy timedelta64 + period = parse_timedelta(frequency) + + # initialize base class with time window and frequency info + super().__init__( + tw_handler, + stream_info, + data_start_time=tw_handler.t_start, + data_end_time=tw_handler.t_end, + period=period, + ) + + # load grid template + grid = np.load(filename) + if grid.ndim != 2 or grid.shape[1] != 2: + raise ValueError( + f"Template must be .npy with shape (N, 2) [lat, lon], got {grid.shape}" + ) + + # cache lats (reflected at the poles) and lons (wrapped into [-180, 180)) as float32 + self.latitudes = _clip_lat(grid[:, 0].astype(np.float32)) + self.longitudes = _clip_lon(grid[:, 1].astype(np.float32)) + self.n_points = len(self.latitudes) + + # number of whole periods that fit in the requested window (truncated, never negative) + self.len = max(0, int((tw_handler.t_end - tw_handler.t_start) / period)) + + # Optionally inherit stream/channel metadata from a reference reader + if ref_reader is not None: + # copy source channel indices and names from the reference reader + self.source_idx = np.asarray(ref_reader.source_idx, dtype=np.int64) + self.source_channels = list(ref_reader.source_channels) + + # copy target channel indices and names from the reference reader + self.target_idx = np.asarray(ref_reader.target_idx, dtype=np.int64) + self.target_channels = list(ref_reader.target_channels) + + # copy target channel weights from the reference reader + self.target_channel_weights = list(ref_reader.target_channel_weights) + + # copy normalization parameters into independent arrays + self.mean = np.array(ref_reader.mean, copy=True) + self.stdev = np.array(ref_reader.stdev, copy=True) + + # if not provided, initialize with empty metadata and neutral normalization + else: + # empty source channels (needed by base class) + self.source_idx: NDArray[np.int64] = np.array([], dtype=np.int64) + self.source_channels: list[str] = [] + + # empty target channels (needed by base class) + self.target_idx: NDArray[np.int64] = np.array([], dtype=np.int64) + self.target_channels: list[str] = [] + + # empty target 
channel weights + self.target_channel_weights: list[float] = [] + + # neutral normalization + self.mean = np.zeros(0, dtype=np.float32) + self.stdev = np.ones(0, dtype=np.float32) + + # TODO add support for geoinfos + self.geoinfo_channels: list[str] = [] + self.geoinfo_idx = np.array([], dtype=np.int64) + self.mean_geoinfo = np.zeros(0, dtype=np.float32) + self.stdev_geoinfo = np.ones(0, dtype=np.float32) + + ds_name = stream_info["name"] + _logger.info( + f"{ds_name}: offgrid reader active (source={len(self.source_channels)}, " + f"target={len(self.target_channels)}, geoinfo={len(self.geoinfo_channels)})." + ) + + @override + def init_empty(self) -> None: + super().init_empty() + self.len = 0 + self.latitudes = np.zeros(0, dtype=np.float32) + self.longitudes = np.zeros(0, dtype=np.float32) + self.n_points = 0 + + @override + def length(self) -> int: + return self.len + + @override + def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: + """ + Get data for window (for either source or target, through public interface) + + Parameters + ---------- + idx : TIndex + Index of temporal window + channels_idx : list[int] + Selection of channels (only its length is used here) + + Returns + ------- + ReaderData providing coords, empty geoinfos, zero-filled data, datetimes + """ + + (t_idxs, dtr) = self._get_dataset_idxs(idx) + + if self.len == 0 or len(t_idxs) == 0: + return ReaderData.empty( + num_data_fields=len(channels_idx), num_geo_fields=0 + ) + assert t_idxs[0] >= 0, "index must be non-negative" + + n_steps = len(t_idxs) + n_total = self.n_points * n_steps + + # repeat the full (n_points, 2) lat/lon block once per time step + latlon = np.stack([self.latitudes, self.longitudes], axis=1) + coords = np.tile(latlon, (n_steps, 1)) + + # no atmospheric data fields, using zeros as placeholder + data = np.zeros((n_total, len(channels_idx)), dtype=np.float32) + + # TODO add support for geoinfos + geoinfos = np.zeros((n_total, 0), dtype=np.float32) + + # compute absolute times for each step, then repeat per grid point (same row order as coords) + step_times = 
self.data_start_time + self.period * t_idxs + datetimes = np.repeat(step_times, self.n_points) + + rd = ReaderData( + coords=coords, + geoinfos=geoinfos, + data=data, + datetimes=datetimes + ) + check_reader_data(rd, dtr) + + return rd + +def _clip_lat(lats: NDArray) -> NDArray[np.float32]: + """Reflect latitudes that overshoot a pole back into [-90, 90] (exact for inputs within [-270, 270]).""" + return (2 * np.clip(lats, -90.0, 90.0) - lats).astype(np.float32) + +def _clip_lon(lons: NDArray) -> NDArray[np.float32]: + """Wrap longitudes periodically into the half-open range [-180, 180).""" + return ((lons + 180.0) % 360.0 - 180.0).astype(np.float32) \ No newline at end of file diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 49be6f8d5..3d2844877 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -23,6 +23,7 @@ TIndex, ) from weathergen.datasets.data_reader_fesom import DataReaderFesom +from weathergen.datasets.data_reader_offgrid import DataReaderOffgrid from weathergen.datasets.data_reader_obs import DataReaderObs from weathergen.datasets.masking import Masker from weathergen.datasets.stream_data import StreamData, spoof @@ -34,7 +35,7 @@ from weathergen.train.utils import Stage, get_batch_size_from_config from weathergen.utils.distributed import is_root -type AnyDataReader = DataReaderBase | DataReaderAnemoi | DataReaderObs +type AnyDataReader = DataReaderBase | DataReaderAnemoi | DataReaderObs | DataReaderOffgrid type StreamName = str logger = logging.getLogger(__name__) @@ -139,6 +140,14 @@ def __init__( self.repeat_data = cf.data_loading.get("repeat_data_in_mini_epoch", False) self.streams_datasets: dict[StreamName, list[AnyDataReader]] = {} + + # Setup for offgrid inference and evaluation + offgrid_eval = mode_cfg.get("offgrid_eval", {}) + # Path to .npy file containing offgrid coordinates + self.offgrid_template = offgrid_eval.get("grid", 
None) + # Frequency defined by config (offgrid_eval.frequency) with fallback to time_window_step + self.offgrid_frequency = offgrid_eval.get("frequency", mode_cfg.time_window_step) + for _, stream_info in enumerate(cf.streams): # list of sources for current stream self.streams_datasets[stream_info["name"]] = [] @@ -190,6 +199,17 @@ def __init__( ) ds = dataset(filename=filename, **kwargs) + # load offgrid dataset if specified + if self.offgrid_template is not None: + filename = pathlib.Path(str(self.offgrid_template)) + ds = DataReaderOffgrid( + tw_handler=self.time_window_handler, + filename=filename, + frequency=self.offgrid_frequency, + stream_info=stream_info, + ref_reader=ds, + ) + stream_info[str(self._stage) + "_source_channels"] = ds.source_channels stream_info[str(self._stage) + "_target_channels"] = ds.target_channels stream_info["target_channel_weights"] = (