From 3f1bb7d3dc4d24602a58e51315e39ad782b06c0b Mon Sep 17 00:00:00 2001 From: sophiex <24638638+sophie-xhonneux@users.noreply.github.com> Date: Thu, 30 Oct 2025 17:27:56 +0000 Subject: [PATCH 001/344] Abstract class for target/aux computation Implemented Identity class TODO: implement EMATeacher --- .../train/target_and_aux_module_base.py | 30 +++++++++++++++++++ .../train/target_and_aux_ssl_teacher.py | 17 +++++++++++ src/weathergen/train/trainer.py | 13 ++++++-- src/weathergen/train/trainer_base.py | 14 +++++++++ 4 files changed, 72 insertions(+), 2 deletions(-) create mode 100644 src/weathergen/train/target_and_aux_module_base.py create mode 100644 src/weathergen/train/target_and_aux_ssl_teacher.py diff --git a/src/weathergen/train/target_and_aux_module_base.py b/src/weathergen/train/target_and_aux_module_base.py new file mode 100644 index 000000000..37812e8b6 --- /dev/null +++ b/src/weathergen/train/target_and_aux_module_base.py @@ -0,0 +1,30 @@ +from typing import Any + +class TargetAndAuxModuleBase: + def __init__(self, model, rng, **kwargs): + pass + + def update_state_pre_backward(self, istep, batch, model, **kwargs) -> None: + pass + + def update_state_post_opt_step(self, istep, batch, model, **kwargs) -> None: + pass + + def compute(self, *args, **kwargs) -> tuple[Any, Any]: + pass + + +class IdentityTargetAndAux(TargetAndAuxModuleBase): + def __init__(self, model, rng, config): + return + + def update_state_pre_backward(self, istep, batch, model, **kwargs): + return + + def update_state_post_opt_step(self, istep, batch, model, **kwargs): + return + + def compute(self, istep, batch, model): + return batch[0], None + + diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py new file mode 100644 index 000000000..3992b5a00 --- /dev/null +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -0,0 +1,17 @@ +from weathergen.train.target_and_aux_module_base import TargetAndAuxModuleBase, IdentityTargetAndAux + + +class EMATeacher(TargetAndAuxModuleBase): + def __init__(self): + pass + + +# should be moved to its own file so as to prevent cyclical imports +def get_target_and_aux_calculator(config, model, rng, **kwargs): + target_and_aux_calc = config.get("target_and_aux_calc", None) + if target_and_aux_calc is None or target_and_aux_calc == "identity": + return IdentityTargetAndAux(model, rng, config) + elif target_and_aux_calc == "EMATeacher": + return EMATeacher(model, rng, kwargs["ema_model"]) + else: + raise NotImplemented(f"{target_and_aux_calc} is not implemented") diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 3d847a671..dd61dac93 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -46,7 +46,7 @@ from weathergen.model.utils import freeze_weights from weathergen.train.loss_calculator import LossCalculator from weathergen.train.lr_scheduler import LearningRateScheduler -from weathergen.train.trainer_base import TrainerBase +from weathergen.train.trainer_base import TrainerBase, get_target_and_aux_calculator from weathergen.utils.distributed import all_gather_vlen, ddp_average, is_root from weathergen.utils.train_logger import TRAIN, VAL, Stage, TrainLogger from weathergen.utils.utils import get_dtype @@ -323,6 +323,8 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None): is_model_sharded=(cf.with_ddp and cf.with_fsdp), ) + self.target_and_aux_calculator = get_target_and_aux_calculator(cf, self.model, None, ema_model = self.ema_model) + # if with_fsdp then parameter count is unreliable if (is_root() and not cf.with_fsdp) or not cf.with_ddp: self.model.print_num_parameters() @@ -591,19 +593,24 @@ def train(self, epoch): preds, posteriors = self.model( self.model_params, batch, cf.forecast_offset, forecast_steps ) + + targets, aux_outputs = self.target_and_aux_calculator.compute(bidx, batch, self.model) loss_values = self.loss_calculator.compute_loss( preds=preds, - streams_data=batch[0], + streams_data=batch[0], # should additionally take targets? ) if cf.latent_noise_kl_weight > 0.0: kl = torch.cat([posterior.kl() for posterior in posteriors]) loss_values.loss += cf.latent_noise_kl_weight * kl.mean() + self.target_and_aux_calculator.update_state_pre_backward(bidx, batch, self.model) + # backward pass self.optimizer.zero_grad() self.grad_scaler.scale(loss_values.loss).backward() # loss_values.loss.backward() + # gradient clipping self.grad_scaler.unscale_(self.optimizer) total_norm = torch.nn.utils.clip_grad_norm_( @@ -622,6 +629,8 @@ def train(self, epoch): self.grad_scaler.update() # self.optimizer.step() + self.target_and_aux_calculator.update_state_post_opt_step(bidx, batch, self.model) + # update learning rate self.lr_scheduler.step() diff --git a/src/weathergen/train/trainer_base.py b/src/weathergen/train/trainer_base.py index 684b3b54b..db6b4aee5 100644 --- a/src/weathergen/train/trainer_base.py +++ b/src/weathergen/train/trainer_base.py @@ -20,6 +20,9 @@ from weathergen.train.utils import str_to_tensor, tensor_to_str from weathergen.utils.distributed import is_root +from weathergen.train.target_and_aux_module_base import IdentityTargetAndAux +from weathergen.train.target_and_aux_ssl_teacher import EMATeacher + PORT = 1345 @@ -167,3 +170,14 @@ def get_perf(self): perf_mem /= len(self.device_handles) return perf_gpu, perf_mem + + +# should be moved to its own file so as to prevent cyclical imports +def get_target_and_aux_calculator(config, model, rng, **kwargs): + target_and_aux_calc = config.get("target_and_aux_calc", None) + if target_and_aux_calc is None or target_and_aux_calc == "identity": + return IdentityTargetAndAux(model, rng, config) + elif target_and_aux_calc == "EMATeacher": + return EMATeacher(model, rng, kwargs["ema_model"]) + else: + raise NotImplemented(f"{target_and_aux_calc} is not implemented") From 03ed1483c72c2b290dc96ccbeeeeda1dd4719c62 Mon Sep 17 00:00:00 2001 From: sophiex <24638638+sophie-xhonneux@users.noreply.github.com> Date: Fri, 31 Oct 2025 16:49:30 +0000 Subject: [PATCH 002/344] Start implementing the EMA Teacher The big question on the EMA teacher side to me is how to allow for a fleixble teacher and student architecture that can differ We updated some APIs of the abstract base class to allow the ema_model forward, subject to change given the loss calculator, which is imho the second big question mark --- src/weathergen/model/ema.py | 4 +- .../train/target_and_aux_module_base.py | 8 +++- .../train/target_and_aux_ssl_teacher.py | 41 ++++++++++++------- src/weathergen/train/trainer.py | 11 +++-- 4 files changed, 43 insertions(+), 21 deletions(-) diff --git a/src/weathergen/model/ema.py b/src/weathergen/model/ema.py index 7acbbf9f0..207362b4f 100644 --- a/src/weathergen/model/ema.py +++ b/src/weathergen/model/ema.py @@ -44,7 +44,7 @@ def reset(self): self.ema_model.to_empty(device="cuda") maybe_sharded_sd = self.original_model.state_dict() # this copies correctly tested in pdb - mkeys, ukeys = self.ema_model.load_state_dict(maybe_sharded_sd, strict=True, assign=False) + mkeys, ukeys = self.ema_model.load_state_dict(maybe_sharded_sd, strict=False, assign=False) @torch.no_grad() def update(self, cur_step, batch_size): @@ -53,7 +53,7 @@ def update(self, cur_step, batch_size): halflife_steps = min(halflife_steps, cur_step / 1e3 * self.rampup_ratio) beta = 0.5 ** (batch_size / max(halflife_steps * 1e3, 1e-6)) for p_net, p_ema in zip( - self.original_model.parameters(), self.ema_model.parameters(), strict=True + self.original_model.parameters(), self.ema_model.parameters(), strict=False ): p_ema.lerp_(p_net, 1 - beta) diff --git a/src/weathergen/train/target_and_aux_module_base.py b/src/weathergen/train/target_and_aux_module_base.py index 37812e8b6..c04d66ef7 100644 --- a/src/weathergen/train/target_and_aux_module_base.py +++ b/src/weathergen/train/target_and_aux_module_base.py @@ -4,6 +4,9 @@ class TargetAndAuxModuleBase: def __init__(self, model, rng, **kwargs): pass + def reset(self): + pass + def update_state_pre_backward(self, istep, batch, model, **kwargs) -> None: pass @@ -18,13 +21,16 @@ class IdentityTargetAndAux(TargetAndAuxModuleBase): def __init__(self, model, rng, config): return + def reset(self): + return + def update_state_pre_backward(self, istep, batch, model, **kwargs): return def update_state_post_opt_step(self, istep, batch, model, **kwargs): return - def compute(self, istep, batch, model): + def compute(self, istep, batch, *args, **kwargs): return batch[0], None diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py index 3992b5a00..979c3bd8a 100644 --- a/src/weathergen/train/target_and_aux_ssl_teacher.py +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -1,17 +1,30 @@ -from weathergen.train.target_and_aux_module_base import TargetAndAuxModuleBase, IdentityTargetAndAux +from typing import Any + +from weathergen.train.target_and_aux_module_base import TargetAndAuxModuleBase class EMATeacher(TargetAndAuxModuleBase): - def __init__(self): - pass - - -# should be moved to its own file so as to prevent cyclical imports -def get_target_and_aux_calculator(config, model, rng, **kwargs): - target_and_aux_calc = config.get("target_and_aux_calc", None) - if target_and_aux_calc is None or target_and_aux_calc == "identity": - return IdentityTargetAndAux(model, rng, config) - elif target_and_aux_calc == "EMATeacher": - return EMATeacher(model, rng, kwargs["ema_model"]) - else: - raise NotImplemented(f"{target_and_aux_calc} is not implemented") + def __init__(self, model, rng, ema_model, batch_size, **kwargs): + # One of the issues is that the teacher model may have a different architecture + # to the student, e.g. JEPA. So we need quite a flexible way to instantiate the + # the teacher. Because of the device sharding etc that requires quite a bit of + # massaging we assume that the teacher creates the EMA model correctly. However, + # note that you cannot assume that model.state_dict equals ema_model.state_dict + self.ema_model = ema_model + self.batch_size = batch_size + + def reset(self, batch_size = None): + self.ema_model.reset() + if batch_size is not None: + self.batch_size = batch_size + + def update_state_pre_backward(self, istep, batch, model, **kwargs) -> None: + return + + def update_state_post_opt_step(self, istep, batch, model, **kwargs) -> None: + self.ema_model.update(istep, self.batch_size) + + def compute(self, bidx, batch, model_params, model, forecast_offset, forecast_steps) -> tuple[Any, Any]: + return self.ema_model.forward_eval(model_params, batch, forecast_offset, forecast_steps), None + + diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index dd61dac93..fb564b99e 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -323,7 +323,9 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None): is_model_sharded=(cf.with_ddp and cf.with_fsdp), ) - self.target_and_aux_calculator = get_target_and_aux_calculator(cf, self.model, None, ema_model = self.ema_model) + self.target_and_aux_calculator = get_target_and_aux_calculator( + cf, self.model, None, ema_model=self.ema_model + ) # if with_fsdp then parameter count is unreliable if (is_root() and not cf.with_fsdp) or not cf.with_ddp: @@ -594,10 +596,12 @@ def train(self, epoch): self.model_params, batch, cf.forecast_offset, forecast_steps ) - targets, aux_outputs = self.target_and_aux_calculator.compute(bidx, batch, self.model) + targets, aux_outputs = self.target_and_aux_calculator.compute( + bidx, batch, self.model_params, self.model, cf.forecast_offset, forecast_steps + ) loss_values = self.loss_calculator.compute_loss( preds=preds, - streams_data=batch[0], # should additionally take targets? + streams_data=batch[0], # should additionally take targets? ) if cf.latent_noise_kl_weight > 0.0: kl = torch.cat([posterior.kl() for posterior in posteriors]) @@ -610,7 +614,6 @@ def train(self, epoch): self.grad_scaler.scale(loss_values.loss).backward() # loss_values.loss.backward() - # gradient clipping self.grad_scaler.unscale_(self.optimizer) total_norm = torch.nn.utils.clip_grad_norm_( From 28d9b2264c7c5f51b6c642b32ded43e39f0f16e3 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Tue, 4 Nov 2025 09:38:28 +0100 Subject: [PATCH 003/344] adding loss calculator base class --- src/weathergen/train/loss_calculator.py | 230 +++++++++++++++++++ src/weathergen/train/loss_calculator_base.py | 98 ++++++++ 2 files changed, 328 insertions(+) create mode 100644 src/weathergen/train/loss_calculator_base.py diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index f457d6454..1d54d49c9 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -17,6 +17,7 @@ import weathergen.train.loss as losses from weathergen.train.loss import stat_loss_fcts +from weathergen.train.loss_calculator_base import LossCalculatorBase, LossValues from weathergen.utils.train_logger import TRAIN, VAL, Stage _logger = logging.getLogger(__name__) @@ -318,3 +319,232 @@ def compute_loss( # Return all computed loss components encapsulated in a ModelLoss dataclass return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all) + + +class LossCalculatorPhysical(LossCalculatorBase): + """ + Manages and computes the overall loss for a WeatherGenerator model during + training and validation stages. + + This class handles the initialization and application of various loss functions, + applies channel-specific weights, constructs masks for missing data, and + aggregates losses across different data streams, channels, and forecast steps. + It provides both the main loss for backpropagation and detailed loss metrics for logging. + """ + + def __init__( + self, + cf: DictConfig, + stage: Stage, + device: str, + ): + LossCalculatorBase.__init__(self) + self.cf = cf + self.stage = stage + self.device = device + + # Dynamically load loss functions based on configuration and stage + loss_fcts = cf.loss_fcts if stage == TRAIN else cf.loss_fcts_val + self.loss_fcts = [ + [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] + for name, w in loss_fcts + ] + + def _get_weights(self, stream_info): + """ + Get weights for current stream + """ + + device = self.device + + # Determine stream and channel loss weights based on the current stage + if self.stage == TRAIN: + # set loss_weights to 1. when not specified + stream_info_loss_weight = stream_info.get("loss_weight", 1.0) + weights_channels = ( + torch.tensor(stream_info["target_channel_weights"]).to( + device=device, non_blocking=True + ) + if "target_channel_weights" in stream_info + else None + ) + elif self.stage == VAL: + # in validation mode, always unweighted loss + stream_info_loss_weight = 1.0 + weights_channels = None + + return stream_info_loss_weight, weights_channels + + def _get_fstep_weights(self, forecast_steps): + timestep_weight_config = self.cf.get("timestep_weight") + if timestep_weight_config is None: + return [1.0 for _ in range(forecast_steps)] + weights_timestep_fct = getattr(losses, timestep_weight_config[0]) + return weights_timestep_fct(forecast_steps, timestep_weight_config[1]) + + def _get_location_weights(self, stream_info, stream_data, forecast_offset, fstep): + location_weight_type = stream_info.get("location_weight", None) + if location_weight_type is None: + return None + weights_locations_fct = getattr(losses, location_weight_type) + weights_locations = weights_locations_fct(stream_data, forecast_offset, fstep) + weights_locations = weights_locations.to(device=self.device, non_blocking=True) + + return weights_locations + + def _get_substep_masks(self, stream_info, fstep, stream_data): + """ + Find substeps and create corresponding masks (reused across loss functions) + """ + + tok_spacetime = stream_info.get("tokenize_spacetime", None) + target_times = stream_data.target_times_raw[self.cf.forecast_offset + fstep] + target_times_unique = np.unique(target_times) if tok_spacetime else [target_times] + substep_masks = [] + for t in target_times_unique: + # find substep + mask_t = torch.tensor(t == target_times).to(self.device, non_blocking=True) + substep_masks.append(mask_t) + + return substep_masks + + def compute_loss( + self, + preds: list[list[Tensor]], + streams_data: list[list[any]], + ) -> LossValues: + """ + Computes the total loss for a given batch of predictions and corresponding + stream data. + + The computed loss is: + + Mean_{stream}( Mean_{fsteps}( Mean_{loss_fcts}( loss_fct( target, pred, weigths) ))) + + This method orchestrates the calculation of the overall loss by iterating through + different data streams, forecast steps, channels, and configured loss functions. + It applies weighting, handles NaN values through masking, and accumulates + detailed loss metrics for logging. + + Args: + preds: A nested list of prediction tensors. The outer list represents forecast steps, + the inner list represents streams. Each tensor contains predictions for that + step and stream. + streams_data: A nested list representing the input batch data. The outer list is for + batch items, the inner list for streams. Each element provides an object + (e.g., dataclass instance) containing target data and metadata. + + Returns: + A ModelLoss dataclass instance containing: + - loss: The loss for back-propagation. + - losses_all: A dictionary mapping stream names to a tensor of per-channel and + per-loss-function losses, normalized by non-empty targets/forecast steps. + - stddev_all: A dictionary mapping stream names to a tensor of mean standard deviations + of predictions for channels with statistical loss functions, normalized. + """ + + # gradient loss + loss = torch.tensor(0.0, device=self.device, requires_grad=True) + # counter for non-empty targets + ctr_streams = 0 + + # initialize dictionaries for detailed loss tracking and standard deviation statistics + # create tensor for each stream + losses_all: dict[str, Tensor] = { + st.name: torch.zeros( + (len(st[str(self.stage) + "_target_channels"]), len(self.loss_fcts)), + device=self.device, + ) + for st in self.cf.streams + } + stddev_all: dict[str, Tensor] = { + st.name: torch.zeros(len(stat_loss_fcts), device=self.device) for st in self.cf.streams + } + + # TODO: iterate over batch dimension + i_batch = 0 + for i_stream_info, stream_info in enumerate(self.cf.streams): + # extract target tokens for current stream from the specified forecast offset onwards + targets = streams_data[i_batch][i_stream_info].target_tokens[self.cf.forecast_offset :] + + stream_data = streams_data[i_batch][i_stream_info] + + fstep_loss_weights = self._get_fstep_weights(len(targets)) + + loss_fsteps = torch.tensor(0.0, device=self.device, requires_grad=True) + ctr_fsteps = 0 + + stream_is_spoof = streams_data[i_batch][i_stream_info].is_spoof() + if stream_is_spoof: + spoof_weight = torch.tensor(0.0, device=self.device, requires_grad=False) + else: + spoof_weight = torch.tensor(1.0, device=self.device, requires_grad=False) + + for fstep, (target, fstep_weight) in enumerate( + zip(targets, fstep_loss_weights, strict=False) + ): + # skip if either target or prediction has no data points + pred = preds[fstep][i_stream_info] + if not (target.shape[0] > 0 and pred.shape[0] > 0): + continue + + # reshape prediction tensor to match target's dimensions: extract data/coords and + # remove token dimension if it exists. + # expected final shape of pred is [ensemble_size, num_samples, num_channels]. + pred = pred.reshape([pred.shape[0], *target.shape]) + assert pred.shape[1] > 0 + + # get weigths for current streams + stream_loss_weight, weights_channels = self._get_weights(stream_info) + + # get weights for locations + weights_locations = self._get_location_weights( + stream_info, stream_data, self.cf.forecast_offset, fstep + ) + + # get masks for sub-time steps + substep_masks = self._get_substep_masks(stream_info, fstep, stream_data) + + # accumulate loss from different loss functions + loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) + ctr_loss_fcts = 0 + for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts): + # loss for current loss function + loss_lfct, loss_lfct_chs = LossCalculator._loss_per_loss_function( + loss_fct, + stream_info, + target, + pred, + substep_masks, + weights_channels, + weights_locations, + ) + losses_all[stream_info.name][:, i_lfct] += spoof_weight * loss_lfct_chs + + # Add the weighted and normalized loss from this loss function to the total + # batch loss + loss_fstep = loss_fstep + ( + loss_fct_weight * loss_lfct * stream_loss_weight * fstep_weight + ) + ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0 + + loss_fsteps = loss_fsteps + (loss_fstep / ctr_loss_fcts if ctr_loss_fcts > 0 else 0) + ctr_fsteps += 1 if ctr_loss_fcts > 0 else 0 + + loss = loss + ((spoof_weight * loss_fsteps) / (ctr_fsteps if ctr_fsteps > 0 else 1.0)) + ctr_streams += 1 if ctr_fsteps > 0 and not stream_is_spoof else 0 + + # normalize by forecast step + losses_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 + stddev_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 + + # replace channels without information by nan to exclude from further computations + losses_all[stream_info.name][losses_all[stream_info.name] == 0.0] = torch.nan + stddev_all[stream_info.name][stddev_all[stream_info.name] == 0.0] = torch.nan + + # normalize by all targets and forecast steps that were non-empty + # (with each having an expected loss of 1 for an uninitalized neural net) + loss = loss / ctr_streams + + # Return all computed loss components encapsulated in a ModelLoss dataclass + return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all) diff --git a/src/weathergen/train/loss_calculator_base.py b/src/weathergen/train/loss_calculator_base.py new file mode 100644 index 000000000..43091978c --- /dev/null +++ b/src/weathergen/train/loss_calculator_base.py @@ -0,0 +1,98 @@ +# ruff: noqa: T201 + +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import dataclasses + +import torch +from torch import Tensor + +from weathergen.common.config import Config +from weathergen.utils.train_logger import Stage + + +@dataclasses.dataclass +class LossValues: + """ + A dataclass to encapsulate the various loss components computed by the LossCalculator. + + This provides a structured way to return the primary loss used for optimization, + along with detailed per-stream/per-channel/per-loss-function losses for logging, + and standard deviations for ensemble scenarios. + """ + + # The primary scalar loss value for optimization. + loss: Tensor + # Dictionaries containing detailed loss values for each stream, channel, and loss function, as + # well as standard deviations when operating with ensembles (e.g., when training with CRPS). + losses_all: dict[str, Tensor] + stddev_all: dict[str, Tensor] + + +class LossCalculatorBase: + def __init__(self): + """ + Initializes the LossCalculator. + + This sets up the configuration, the operational stage (training or validation), + the device for tensor operations, and initializes the list of loss functions + based on the provided configuration. + + Args: + cf: The OmegaConf DictConfig object containing model and training configurations. + It should specify 'loss_fcts' for training and 'loss_fcts_val' for validation. + stage: The current operational stage, either TRAIN or VAL. + This dictates which set of loss functions (training or validation) will be used. + device: The computation device, such as 'cpu' or 'cuda:0', where tensors will reside. + """ + self.cf: Config | None = None + self.stage: Stage + self.loss_fcts = [] + + @staticmethod + def _loss_per_loss_function( + loss_fct, + target: torch.Tensor, + pred: torch.Tensor, + substep_masks: list[torch.Tensor], + weights_channels: torch.Tensor, + weights_locations: torch.Tensor, + ): + """ + Compute loss for given loss function + """ + + loss_lfct = torch.tensor(0.0, device=target.device, requires_grad=True) + losses_chs = torch.zeros(target.shape[-1], device=target.device, dtype=torch.float32) + + ctr_substeps = 0 + for mask_t in substep_masks: + assert mask_t.sum() == len(weights_locations) if weights_locations is not None else True + + loss, loss_chs = loss_fct( + target[mask_t], pred[:, mask_t], weights_channels, weights_locations + ) + + # accumulate loss + loss_lfct = loss_lfct + loss + losses_chs = losses_chs + loss_chs.detach() if len(loss_chs) > 0 else losses_chs + ctr_substeps += 1 if loss > 0.0 else 0 + + # normalize over forecast steps in window + losses_chs /= ctr_substeps if ctr_substeps > 0 else 1.0 + + # TODO: substep weight + loss_lfct = loss_lfct / (ctr_substeps if ctr_substeps > 0 else 1.0) + + return loss_lfct, losses_chs + + # def _get_weights(self, stream_info): + + # def _update_weights(self, stream_info): From 192beb67d423343e90fa8ee8822632a810cb0a0c Mon Sep 17 00:00:00 2001 From: sophiex <24638638+sophie-xhonneux@users.noreply.github.com> Date: Tue, 4 Nov 2025 12:30:12 +0000 Subject: [PATCH 004/344] Option for constructing teacher model flexibly --- src/weathergen/model/model.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 13c462a6f..e6e096a70 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -14,6 +14,8 @@ import warnings from pathlib import Path +import copy + import astropy_healpix as hp import astropy_healpix.healpy import numpy as np @@ -907,3 +909,15 @@ def predict( preds_tokens += [checkpoint(self.pred_heads[ii], tc_tokens, use_reentrant=False)] return preds_tokens + +def get_model(cf: Config, sources_size, targets_num_channels, targets_coords_size, **kwargs): + if cf["training_mode"] == "student-teacher-pretrain": + student = Model(cf, sources_size, targets_num_channels, targets_coords_size).create() + teacher_cf = copy.deepcopy(cf) + for key, val in teacher_cf["teacher_model"].items(): + teacher_cf[key] = val + teacher = Model(cf, sources_size, targets_num_channels, targets_coords_size).create() + return student, teacher + elif cf["training_mode"] == "forecasting": + model = Model(cf, sources_size, targets_num_channels, targets_coords_size).create() + return model, None From aac7e29a1e1b376c1bca568d70c4f7c3b0526f7c Mon Sep 17 00:00:00 2001 From: sophiex <24638638+sophie-xhonneux@users.noreply.github.com> Date: Wed, 5 Nov 2025 10:19:39 +0000 Subject: [PATCH 005/344] Extract get batch size util function Easier to read and as batchsize gets more complicated in SSL this will be a useful abstraction --- src/weathergen/model/model.py | 24 +++++++------ .../train/target_and_aux_ssl_teacher.py | 3 ++ src/weathergen/train/trainer.py | 36 ++++++++++++++----- src/weathergen/train/trainer_base.py | 6 ++-- src/weathergen/utils/utils.py | 5 +++ 5 files changed, 51 insertions(+), 23 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index e6e096a70..13cbc8c6b 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -910,14 +910,16 @@ def predict( return preds_tokens -def get_model(cf: Config, sources_size, targets_num_channels, targets_coords_size, **kwargs): - if cf["training_mode"] == "student-teacher-pretrain": - student = Model(cf, sources_size, targets_num_channels, targets_coords_size).create() - teacher_cf = copy.deepcopy(cf) - for key, val in teacher_cf["teacher_model"].items(): - teacher_cf[key] = val - teacher = Model(cf, sources_size, targets_num_channels, targets_coords_size).create() - return student, teacher - elif cf["training_mode"] == "forecasting": - model = Model(cf, sources_size, targets_num_channels, targets_coords_size).create() - return model, None + +def get_model(student_or_teacher, cf: Config, sources_size, targets_num_channels, targets_coords_size, **kwargs): + if student_or_teacher == "student" or student_or_teacher == "teacher": + return Model(cf, sources_size, targets_num_channels, targets_coords_size).create() + else: + if cf["training_mode"] == "masking": # TODO implement mode "student-teacher-pretrain": + teacher_cf = copy.deepcopy(cf) + for key, val in teacher_cf["teacher_model"].items(): + teacher_cf[key] = val + teacher = Model(cf, sources_size, targets_num_channels, targets_coords_size).create() + return teacher + else: + raise NotImplementedError(f"The training mode {cf['training_mode']} is not implemented.") diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py index 979c3bd8a..dbaead8d7 100644 --- a/src/weathergen/train/target_and_aux_ssl_teacher.py +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -13,6 +13,9 @@ def __init__(self, model, rng, ema_model, batch_size, **kwargs): self.ema_model = ema_model self.batch_size = batch_size + self.reset() + import pdb; pdb.set_trace() + def reset(self, batch_size = None): self.ema_model.reset() if batch_size is not None: diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index fb564b99e..2de6a5f17 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -42,14 +42,14 @@ ) from weathergen.model.ema import EMAModel from weathergen.model.layers import MLP -from weathergen.model.model import Model, ModelParams +from weathergen.model.model import get_model, Model, ModelParams from weathergen.model.utils import freeze_weights from weathergen.train.loss_calculator import LossCalculator from weathergen.train.lr_scheduler import LearningRateScheduler from weathergen.train.trainer_base import TrainerBase, get_target_and_aux_calculator from weathergen.utils.distributed import all_gather_vlen, ddp_average, is_root from weathergen.utils.train_logger import TRAIN, VAL, Stage, TrainLogger -from weathergen.utils.utils import get_dtype +from weathergen.utils.utils import get_dtype, get_batch_size from weathergen.utils.validation_io import write_output logger = logging.getLogger(__name__) @@ -157,13 +157,15 @@ def inference(self, cf, devices, run_id_trained, epoch): self.validate(epoch=0) logger.info(f"Finished inference run with id: {cf.run_id}") - def init_model_and_shard(self, cf, devices): + def init_model_and_shard(self, cf, student_or_teacher, devices): sources_size = self.dataset.get_sources_size() targets_num_channels = self.dataset.get_targets_num_channels() targets_coords_size = self.dataset.get_targets_coords_size() with torch.device("meta"): - model = Model(cf, sources_size, targets_num_channels, targets_coords_size).create() + model = get_model( + student_or_teacher, cf, sources_size, targets_num_channels, targets_coords_size + ) for name, module in model.named_modules(): name = module.name if hasattr(module, "name") else name @@ -294,7 +296,7 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None): self.dataset_val, **loader_params, sampler=None ) - self.model, self.model_params = self.init_model_and_shard(cf, devices) + self.model, self.model_params = self.init_model_and_shard(cf, "student", devices) if run_id_contd is None: self.model.to_empty(device="cuda") @@ -313,8 +315,20 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None): self.validate_with_ema = cf.get("validate_with_ema", False) self.ema_model = None + # validate_with_ema is incompatible with student-teacher + self.validate_with_ema = False # TODO remove for testing only if self.validate_with_ema: - meta_ema_model = self.init_model_and_shard(cf, devices)[0] + meta_ema_model = self.init_model_and_shard(cf, "student", devices)[0] + self.ema_model = EMAModel( + self.model, + meta_ema_model, + halflife_steps=cf.get("ema_halflife_in_thousands", 1e-3), + rampup_ratio=cf.get("ema_ramp_up_ratio", 0.09), + is_model_sharded=(cf.with_ddp and cf.with_fsdp), + ) + elif cf["training_mode"] == "masking": # "student-teacher-pretrain": + meta_ema_model = self.init_model_and_shard(cf, "teacher", devices)[0] + cf["target_and_aux_calc"] = "EMATeacher" self.ema_model = EMAModel( self.model, meta_ema_model, @@ -324,7 +338,11 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None): ) self.target_and_aux_calculator = get_target_and_aux_calculator( - cf, self.model, None, ema_model=self.ema_model + cf, + self.model, + None, + ema_model=self.ema_model, + batch_size=get_batch_size(cf, self.world_size_original), ) # if with_fsdp then parameter count is unreliable @@ -640,8 +658,8 @@ def train(self, epoch): # EMA update if self.validate_with_ema: self.ema_model.update( - self.cf.istep * self.world_size_original * self.cf.batch_size_per_gpu, - self.world_size_original * self.cf.batch_size_per_gpu, + self.cf.istep * get_batch_size(self.cf, self.world_size_original), + get_batch_size(self.cf, self.world_size_original), ) self.loss_unweighted_hist += [loss_values.losses_all] diff --git a/src/weathergen/train/trainer_base.py b/src/weathergen/train/trainer_base.py index db6b4aee5..c75e3ce7e 100644 --- a/src/weathergen/train/trainer_base.py +++ b/src/weathergen/train/trainer_base.py @@ -173,11 +173,11 @@ def get_perf(self): # should be moved to its own file so as to prevent cyclical imports -def get_target_and_aux_calculator(config, model, rng, **kwargs): +def get_target_and_aux_calculator(config, model, rng, batch_size, **kwargs): target_and_aux_calc = config.get("target_and_aux_calc", None) if target_and_aux_calc is None or target_and_aux_calc == "identity": return IdentityTargetAndAux(model, rng, config) elif target_and_aux_calc == "EMATeacher": - return EMATeacher(model, rng, kwargs["ema_model"]) + return EMATeacher(model, rng, kwargs["ema_model"], batch_size) else: - raise NotImplemented(f"{target_and_aux_calc} is not implemented") + raise NotImplementedError(f"{target_and_aux_calc} is not implemented") diff --git a/src/weathergen/utils/utils.py b/src/weathergen/utils/utils.py index 5deba9287..1e0fed42c 100644 --- a/src/weathergen/utils/utils.py +++ b/src/weathergen/utils/utils.py @@ -9,6 +9,7 @@ import torch +from weathergen.common.config import Config def get_dtype(value: str) -> torch.dtype: """ @@ -24,3 +25,7 @@ def get_dtype(value: str) -> torch.dtype: raise NotImplementedError( f"Dtype {value} is not recognized, choose either, bf16, fp16, or fp32" ) + + +def get_batch_size(cf: Config, world_size: int) -> int: + return world_size * cf.batch_size_per_gpu From 145d18a3adfc9ffc30bda89c318ea80069129cd1 Mon Sep 17 00:00:00 2001 From: sophiex <24638638+sophie-xhonneux@users.noreply.github.com> Date: Wed, 5 Nov 2025 10:50:27 +0000 Subject: [PATCH 006/344] Fix mismatched dtypes in the target computation It runs so far. Next steps: - Route all the config options - Start writing the loss functions to understand the state requirements --- src/weathergen/train/target_and_aux_ssl_teacher.py | 1 - src/weathergen/train/trainer.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py index dbaead8d7..ba008ae50 100644 --- a/src/weathergen/train/target_and_aux_ssl_teacher.py +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -14,7 +14,6 @@ def __init__(self, model, rng, ema_model, batch_size, **kwargs): self.batch_size = batch_size self.reset() - import pdb; pdb.set_trace() def reset(self, batch_size = None): self.ema_model.reset() diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 2de6a5f17..4453cf33b 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -614,9 +614,9 @@ def train(self, epoch): self.model_params, batch, cf.forecast_offset, forecast_steps ) - targets, aux_outputs = self.target_and_aux_calculator.compute( - bidx, batch, self.model_params, self.model, cf.forecast_offset, forecast_steps - ) + targets, aux_outputs = self.target_and_aux_calculator.compute( + bidx, batch, self.model_params, self.model, cf.forecast_offset, forecast_steps + ) loss_values = self.loss_calculator.compute_loss( preds=preds, streams_data=batch[0], # should additionally take targets? From f1e71321ffedc5c4a272a36b8d76231affd41204 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Wed, 5 Nov 2025 18:12:01 +0100 Subject: [PATCH 007/344] abstract loss calc structure --- config/default_config.yml | 5 +- src/weathergen/train/loss_calculator.py | 524 ++---------------- src/weathergen/train/loss_calculator_base.py | 5 + .../train/loss_calculator_classes.py | 285 ++++++++++ src/weathergen/train/trainer.py | 6 +- 5 files changed, 329 insertions(+), 496 deletions(-) create mode 100644 src/weathergen/train/loss_calculator_classes.py diff --git a/config/default_config.yml b/config/default_config.yml index 620f5c4ae..26f9382c2 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -70,9 +70,12 @@ latent_noise_use_additive_noise: False latent_noise_deterministic_latents: True loss_fcts: - - + - - "mse" - 1.0 + # - + # - "latent:mse" + # - 0.3 loss_fcts_val: - - "mse" diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index 1d54d49c9..ffd018f22 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -1,3 +1,5 @@ +# ruff: noqa: T201 + # (C) Copyright 2025 WeatherGenerator contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 @@ -7,49 +9,21 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import dataclasses import logging -import numpy as np -import torch from omegaconf import DictConfig -from torch import Tensor -import weathergen.train.loss as losses -from weathergen.train.loss import stat_loss_fcts -from weathergen.train.loss_calculator_base import LossCalculatorBase, LossValues -from weathergen.utils.train_logger import TRAIN, VAL, Stage +from weathergen.train.loss_calculator_base import LossValues +from weathergen.train.loss_calculator_classes import LossCalculatorLatent, LossCalculatorPhysical +from weathergen.utils.train_logger import TRAIN, Stage _logger = logging.getLogger(__name__) -@dataclasses.dataclass -class LossValues: - """ - A dataclass to encapsulate the various loss components computed by the LossCalculator. - - This provides a structured way to return the primary loss used for optimization, - along with detailed per-stream/per-channel/per-loss-function losses for logging, - and standard deviations for ensemble scenarios. - """ - - # The primary scalar loss value for optimization. - loss: Tensor - # Dictionaries containing detailed loss values for each stream, channel, and loss function, as - # well as standard deviations when operating with ensembles (e.g., when training with CRPS). - losses_all: dict[str, Tensor] - stddev_all: dict[str, Tensor] - - class LossCalculator: """ Manages and computes the overall loss for a WeatherGenerator model during training and validation stages. - - This class handles the initialization and application of various loss functions, - applies channel-specific weights, constructs masks for missing data, and - aggregates losses across different data streams, channels, and forecast steps. - It provides both the main loss for backpropagation and detailed loss metrics for logging. """ def __init__( @@ -76,475 +50,39 @@ def __init__( self.stage = stage self.device = device - # Dynamically load loss functions based on configuration and stage loss_fcts = cf.loss_fcts if stage == TRAIN else cf.loss_fcts_val - self.loss_fcts = [ - [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] - for name, w in loss_fcts - ] - - def _get_weights(self, stream_info): - """ - Get weights for current stream - """ - - device = self.device - - # Determine stream and channel loss weights based on the current stage - if self.stage == TRAIN: - # set loss_weights to 1. when not specified - stream_info_loss_weight = stream_info.get("loss_weight", 1.0) - weights_channels = ( - torch.tensor(stream_info["target_channel_weights"]).to( - device=device, non_blocking=True - ) - if "target_channel_weights" in stream_info - else None - ) - elif self.stage == VAL: - # in validation mode, always unweighted loss - stream_info_loss_weight = 1.0 - weights_channels = None - - return stream_info_loss_weight, weights_channels - - def _get_fstep_weights(self, forecast_steps): - timestep_weight_config = self.cf.get("timestep_weight") - if timestep_weight_config is None: - return [1.0 for _ in range(forecast_steps)] - weights_timestep_fct = getattr(losses, timestep_weight_config[0]) - return weights_timestep_fct(forecast_steps, timestep_weight_config[1]) - - def _get_location_weights(self, stream_info, stream_data, forecast_offset, fstep): - location_weight_type = stream_info.get("location_weight", None) - if location_weight_type is None: - return None - weights_locations_fct = getattr(losses, location_weight_type) - weights_locations = weights_locations_fct(stream_data, forecast_offset, fstep) - weights_locations = weights_locations.to(device=self.device, non_blocking=True) - - return weights_locations - - def _get_substep_masks(self, stream_info, fstep, stream_data): - """ - Find substeps and create corresponding masks (reused across loss functions) - """ - - tok_spacetime = stream_info.get("tokenize_spacetime", None) - target_times = stream_data.target_times_raw[self.cf.forecast_offset + fstep] - target_times_unique = np.unique(target_times) if tok_spacetime else [target_times] - substep_masks = [] - for t in target_times_unique: - # find substep - mask_t = torch.tensor(t == target_times).to(self.device, non_blocking=True) - substep_masks.append(mask_t) - - return substep_masks - - @staticmethod - def _loss_per_loss_function( - loss_fct, - stream_info, - target: torch.Tensor, - pred: torch.Tensor, - substep_masks: list[torch.Tensor], - weights_channels: torch.Tensor, - weights_locations: torch.Tensor, - ): - """ - Compute loss for given loss function - """ - - loss_lfct = torch.tensor(0.0, device=target.device, requires_grad=True) - losses_chs = torch.zeros(target.shape[-1], device=target.device, dtype=torch.float32) - - ctr_substeps = 0 - for mask_t in substep_masks: - assert mask_t.sum() == len(weights_locations) if weights_locations is not None else True - - loss, loss_chs = loss_fct( - target[mask_t], pred[:, mask_t], weights_channels, weights_locations - ) - - # accumulate loss - loss_lfct = loss_lfct + loss - losses_chs = losses_chs + loss_chs.detach() if len(loss_chs) > 0 else losses_chs - ctr_substeps += 1 if loss > 0.0 else 0 - - # normalize over forecast steps in window - losses_chs /= ctr_substeps if ctr_substeps > 0 else 1.0 - - # TODO: substep weight - loss_lfct = loss_lfct / (ctr_substeps if ctr_substeps > 0 else 1.0) - - return loss_lfct, losses_chs - - def compute_loss( - self, - preds: list[list[Tensor]], - streams_data: list[list[any]], - ) -> LossValues: - """ - Computes the total loss for a given batch of predictions and corresponding - stream data. - - The computed loss is: - - Mean_{stream}( Mean_{fsteps}( Mean_{loss_fcts}( loss_fct( target, pred, weigths) ))) - - This method orchestrates the calculation of the overall loss by iterating through - different data streams, forecast steps, channels, and configured loss functions. - It applies weighting, handles NaN values through masking, and accumulates - detailed loss metrics for logging. - - Args: - preds: A nested list of prediction tensors. The outer list represents forecast steps, - the inner list represents streams. Each tensor contains predictions for that - step and stream. - streams_data: A nested list representing the input batch data. The outer list is for - batch items, the inner list for streams. Each element provides an object - (e.g., dataclass instance) containing target data and metadata. - - Returns: - A ModelLoss dataclass instance containing: - - loss: The loss for back-propagation. - - losses_all: A dictionary mapping stream names to a tensor of per-channel and - per-loss-function losses, normalized by non-empty targets/forecast steps. - - stddev_all: A dictionary mapping stream names to a tensor of mean standard deviations - of predictions for channels with statistical loss functions, normalized. - """ - - # gradient loss - loss = torch.tensor(0.0, device=self.device, requires_grad=True) - # counter for non-empty targets - ctr_streams = 0 - - # initialize dictionaries for detailed loss tracking and standard deviation statistics - # create tensor for each stream - losses_all: dict[str, Tensor] = { - st.name: torch.zeros( - (len(st[str(self.stage) + "_target_channels"]), len(self.loss_fcts)), - device=self.device, - ) - for st in self.cf.streams - } - stddev_all: dict[str, Tensor] = { - st.name: torch.zeros(len(stat_loss_fcts), device=self.device) for st in self.cf.streams - } - - # TODO: iterate over batch dimension - i_batch = 0 - for i_stream_info, stream_info in enumerate(self.cf.streams): - # extract target tokens for current stream from the specified forecast offset onwards - targets = streams_data[i_batch][i_stream_info].target_tokens[self.cf.forecast_offset :] - - stream_data = streams_data[i_batch][i_stream_info] - - fstep_loss_weights = self._get_fstep_weights(len(targets)) - - loss_fsteps = torch.tensor(0.0, device=self.device, requires_grad=True) - ctr_fsteps = 0 - - stream_is_spoof = streams_data[i_batch][i_stream_info].is_spoof() - if stream_is_spoof: - spoof_weight = torch.tensor(0.0, device=self.device, requires_grad=False) - else: - spoof_weight = torch.tensor(1.0, device=self.device, requires_grad=False) - - for fstep, (target, fstep_weight) in enumerate( - zip(targets, fstep_loss_weights, strict=False) - ): - # skip if either target or prediction has no data points - pred = preds[fstep][i_stream_info] - if not (target.shape[0] > 0 and pred.shape[0] > 0): - continue - - # reshape prediction tensor to match target's dimensions: extract data/coords and - # remove token dimension if it exists. - # expected final shape of pred is [ensemble_size, num_samples, num_channels]. - pred = pred.reshape([pred.shape[0], *target.shape]) - assert pred.shape[1] > 0 - - # get weigths for current streams - stream_loss_weight, weights_channels = self._get_weights(stream_info) - - # get weights for locations - weights_locations = self._get_location_weights( - stream_info, stream_data, self.cf.forecast_offset, fstep - ) - - # get masks for sub-time steps - substep_masks = self._get_substep_masks(stream_info, fstep, stream_data) - - # accumulate loss from different loss functions - loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) - ctr_loss_fcts = 0 - for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts): - # loss for current loss function - loss_lfct, loss_lfct_chs = LossCalculator._loss_per_loss_function( - loss_fct, - stream_info, - target, - pred, - substep_masks, - weights_channels, - weights_locations, - ) - losses_all[stream_info.name][:, i_lfct] += spoof_weight * loss_lfct_chs - - # Add the weighted and normalized loss from this loss function to the total - # batch loss - loss_fstep = loss_fstep + ( - loss_fct_weight * loss_lfct * stream_loss_weight * fstep_weight - ) - ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0 - - loss_fsteps = loss_fsteps + (loss_fstep / ctr_loss_fcts if ctr_loss_fcts > 0 else 0) - ctr_fsteps += 1 if ctr_loss_fcts > 0 else 0 - - loss = loss + ((spoof_weight * loss_fsteps) / (ctr_fsteps if ctr_fsteps > 0 else 1.0)) - ctr_streams += 1 if ctr_fsteps > 0 and not stream_is_spoof else 0 - - # normalize by forecast step - losses_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 - stddev_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 - - # replace channels without information by nan to exclude from further computations - losses_all[stream_info.name][losses_all[stream_info.name] == 0.0] = torch.nan - stddev_all[stream_info.name][stddev_all[stream_info.name] == 0.0] = torch.nan - # normalize by all targets and forecast steps that were non-empty - # (with each having an expected loss of 1 for an uninitalized neural net) - loss = loss / ctr_streams - - # Return all computed loss components encapsulated in a ModelLoss dataclass - return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all) - - -class LossCalculatorPhysical(LossCalculatorBase): - """ - Manages and computes the overall loss for a WeatherGenerator model during - training and validation stages. - - This class handles the initialization and application of various loss functions, - applies channel-specific weights, constructs masks for missing data, and - aggregates losses across different data streams, channels, and forecast steps. - It provides both the main loss for backpropagation and detailed loss metrics for logging. - """ - - def __init__( - self, - cf: DictConfig, - stage: Stage, - device: str, - ): - LossCalculatorBase.__init__(self) - self.cf = cf - self.stage = stage - self.device = device - - # Dynamically load loss functions based on configuration and stage - loss_fcts = cf.loss_fcts if stage == TRAIN else cf.loss_fcts_val - self.loss_fcts = [ - [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] - for name, w in loss_fcts + loss_fcts_physical = [[name, w] for name, w in loss_fcts if name.split(":")[0] != "latent"] + loss_fcts_latent = [ + [name.split(":")[1], w] for name, w in loss_fcts if name.split(":")[0] == "latent" ] - def _get_weights(self, stream_info): - """ - Get weights for current stream - """ - - device = self.device - - # Determine stream and channel loss weights based on the current stage - if self.stage == TRAIN: - # set loss_weights to 1. when not specified - stream_info_loss_weight = stream_info.get("loss_weight", 1.0) - weights_channels = ( - torch.tensor(stream_info["target_channel_weights"]).to( - device=device, non_blocking=True - ) - if "target_channel_weights" in stream_info - else None - ) - elif self.stage == VAL: - # in validation mode, always unweighted loss - stream_info_loss_weight = 1.0 - weights_channels = None - - return stream_info_loss_weight, weights_channels - - def _get_fstep_weights(self, forecast_steps): - timestep_weight_config = self.cf.get("timestep_weight") - if timestep_weight_config is None: - return [1.0 for _ in range(forecast_steps)] - weights_timestep_fct = getattr(losses, timestep_weight_config[0]) - return weights_timestep_fct(forecast_steps, timestep_weight_config[1]) - - def _get_location_weights(self, stream_info, stream_data, forecast_offset, fstep): - location_weight_type = stream_info.get("location_weight", None) - if location_weight_type is None: - return None - weights_locations_fct = getattr(losses, location_weight_type) - weights_locations = weights_locations_fct(stream_data, forecast_offset, fstep) - weights_locations = weights_locations.to(device=self.device, non_blocking=True) - - return weights_locations - - def _get_substep_masks(self, stream_info, fstep, stream_data): - """ - Find substeps and create corresponding masks (reused across loss functions) - """ + calculator_configs = [] - tok_spacetime = stream_info.get("tokenize_spacetime", None) - target_times = stream_data.target_times_raw[self.cf.forecast_offset + fstep] - target_times_unique = np.unique(target_times) if tok_spacetime else [target_times] - substep_masks = [] - for t in target_times_unique: - # find substep - mask_t = torch.tensor(t == target_times).to(self.device, non_blocking=True) - substep_masks.append(mask_t) + if loss_fcts_physical: + calculator_configs.append((LossCalculatorPhysical, loss_fcts_physical, "physical")) + if loss_fcts_latent: + calculator_configs.append((LossCalculatorLatent, loss_fcts_latent, "latent")) - return substep_masks + self.loss_calculators = [ + (Cls(cf=cf, loss_fcts=losses, stage=stage, device=self.device), type) + for (Cls, losses, type) in calculator_configs + ] def compute_loss( self, - preds: list[list[Tensor]], - streams_data: list[list[any]], - ) -> LossValues: - """ - Computes the total loss for a given batch of predictions and corresponding - stream data. - - The computed loss is: - - Mean_{stream}( Mean_{fsteps}( Mean_{loss_fcts}( loss_fct( target, pred, weigths) ))) - - This method orchestrates the calculation of the overall loss by iterating through - different data streams, forecast steps, channels, and configured loss functions. - It applies weighting, handles NaN values through masking, and accumulates - detailed loss metrics for logging. - - Args: - preds: A nested list of prediction tensors. The outer list represents forecast steps, - the inner list represents streams. Each tensor contains predictions for that - step and stream. - streams_data: A nested list representing the input batch data. The outer list is for - batch items, the inner list for streams. Each element provides an object - (e.g., dataclass instance) containing target data and metadata. - - Returns: - A ModelLoss dataclass instance containing: - - loss: The loss for back-propagation. - - losses_all: A dictionary mapping stream names to a tensor of per-channel and - per-loss-function losses, normalized by non-empty targets/forecast steps. - - stddev_all: A dictionary mapping stream names to a tensor of mean standard deviations - of predictions for channels with statistical loss functions, normalized. - """ - - # gradient loss - loss = torch.tensor(0.0, device=self.device, requires_grad=True) - # counter for non-empty targets - ctr_streams = 0 - - # initialize dictionaries for detailed loss tracking and standard deviation statistics - # create tensor for each stream - losses_all: dict[str, Tensor] = { - st.name: torch.zeros( - (len(st[str(self.stage) + "_target_channels"]), len(self.loss_fcts)), - device=self.device, - ) - for st in self.cf.streams - } - stddev_all: dict[str, Tensor] = { - st.name: torch.zeros(len(stat_loss_fcts), device=self.device) for st in self.cf.streams - } - - # TODO: iterate over batch dimension - i_batch = 0 - for i_stream_info, stream_info in enumerate(self.cf.streams): - # extract target tokens for current stream from the specified forecast offset onwards - targets = streams_data[i_batch][i_stream_info].target_tokens[self.cf.forecast_offset :] - - stream_data = streams_data[i_batch][i_stream_info] - - fstep_loss_weights = self._get_fstep_weights(len(targets)) - - loss_fsteps = torch.tensor(0.0, device=self.device, requires_grad=True) - ctr_fsteps = 0 - - stream_is_spoof = streams_data[i_batch][i_stream_info].is_spoof() - if stream_is_spoof: - spoof_weight = torch.tensor(0.0, device=self.device, requires_grad=False) - else: - spoof_weight = torch.tensor(1.0, device=self.device, requires_grad=False) - - for fstep, (target, fstep_weight) in enumerate( - zip(targets, fstep_loss_weights, strict=False) - ): - # skip if either target or prediction has no data points - pred = preds[fstep][i_stream_info] - if not (target.shape[0] > 0 and pred.shape[0] > 0): - continue - - # reshape prediction tensor to match target's dimensions: extract data/coords and - # remove token dimension if it exists. - # expected final shape of pred is [ensemble_size, num_samples, num_channels]. - pred = pred.reshape([pred.shape[0], *target.shape]) - assert pred.shape[1] > 0 - - # get weigths for current streams - stream_loss_weight, weights_channels = self._get_weights(stream_info) - - # get weights for locations - weights_locations = self._get_location_weights( - stream_info, stream_data, self.cf.forecast_offset, fstep - ) - - # get masks for sub-time steps - substep_masks = self._get_substep_masks(stream_info, fstep, stream_data) - - # accumulate loss from different loss functions - loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) - ctr_loss_fcts = 0 - for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts): - # loss for current loss function - loss_lfct, loss_lfct_chs = LossCalculator._loss_per_loss_function( - loss_fct, - stream_info, - target, - pred, - substep_masks, - weights_channels, - weights_locations, - ) - losses_all[stream_info.name][:, i_lfct] += spoof_weight * loss_lfct_chs - - # Add the weighted and normalized loss from this loss function to the total - # batch loss - loss_fstep = loss_fstep + ( - loss_fct_weight * loss_lfct * stream_loss_weight * fstep_weight - ) - ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0 - - loss_fsteps = loss_fsteps + (loss_fstep / ctr_loss_fcts if ctr_loss_fcts > 0 else 0) - ctr_fsteps += 1 if ctr_loss_fcts > 0 else 0 - - loss = loss + ((spoof_weight * loss_fsteps) / (ctr_fsteps if ctr_fsteps > 0 else 1.0)) - ctr_streams += 1 if ctr_fsteps > 0 and not stream_is_spoof else 0 - - # normalize by forecast step - losses_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 - stddev_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 - - # replace channels without information by nan to exclude from further computations - losses_all[stream_info.name][losses_all[stream_info.name] == 0.0] = torch.nan - stddev_all[stream_info.name][stddev_all[stream_info.name] == 0.0] = torch.nan - - # normalize by all targets and forecast steps that were non-empty - # (with each having an expected loss of 1 for an uninitalized neural net) - loss = loss / ctr_streams - - # Return all computed loss components encapsulated in a ModelLoss dataclass + preds: dict, + targets: dict, + ): + loss_values = {} + loss = 0 + for calculator, type in self.loss_calculators: + loss_values[type] = calculator.compute_loss(preds=preds[type], targets=targets[type]) + loss += loss_values[type].loss + + losses_all = {} + stddev_all = {} + for _, v in loss_values.items(): + losses_all.update(v.losses_all) + stddev_all.update(v.stddev_all) return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all) diff --git a/src/weathergen/train/loss_calculator_base.py b/src/weathergen/train/loss_calculator_base.py index 43091978c..720116baa 100644 --- a/src/weathergen/train/loss_calculator_base.py +++ b/src/weathergen/train/loss_calculator_base.py @@ -17,6 +17,11 @@ from weathergen.common.config import Config from weathergen.utils.train_logger import Stage +# @dataclasses.dataclass +# class InputOutputStructure: + +# targets.latent + @dataclasses.dataclass class LossValues: diff --git a/src/weathergen/train/loss_calculator_classes.py b/src/weathergen/train/loss_calculator_classes.py new file mode 100644 index 000000000..62ed45a7d --- /dev/null +++ b/src/weathergen/train/loss_calculator_classes.py @@ -0,0 +1,285 @@ +# ruff: noqa: T201 + +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import numpy as np +import torch +from omegaconf import DictConfig +from torch import Tensor + +import weathergen.train.loss as losses +from weathergen.train.loss import stat_loss_fcts +from weathergen.train.loss_calculator_base import LossCalculatorBase, LossValues +from weathergen.utils.train_logger import TRAIN, VAL, Stage + +_logger = logging.getLogger(__name__) + + +class LossCalculatorPhysical(LossCalculatorBase): + """ + Manages and computes the overall loss for a WeatherGenerator model during + training and validation stages. + + This class handles the initialization and application of various loss functions, + applies channel-specific weights, constructs masks for missing data, and + aggregates losses across different data streams, channels, and forecast steps. + It provides both the main loss for backpropagation and detailed loss metrics for logging. + """ + + def __init__( + self, + cf: DictConfig, + loss_fcts: list, + stage: Stage, + device: str, + ): + LossCalculatorBase.__init__(self) + self.cf = cf + self.stage = stage + self.device = device + + # Dynamically load loss functions based on configuration and stage + + self.loss_fcts = [ + [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] + for name, w in loss_fcts + ] + + def _get_weights(self, stream_info): + """ + Get weights for current stream + """ + + device = self.device + + # Determine stream and channel loss weights based on the current stage + if self.stage == TRAIN: + # set loss_weights to 1. when not specified + stream_info_loss_weight = stream_info.get("loss_weight", 1.0) + weights_channels = ( + torch.tensor(stream_info["target_channel_weights"]).to( + device=device, non_blocking=True + ) + if "target_channel_weights" in stream_info + else None + ) + elif self.stage == VAL: + # in validation mode, always unweighted loss + stream_info_loss_weight = 1.0 + weights_channels = None + + return stream_info_loss_weight, weights_channels + + def _get_fstep_weights(self, forecast_steps): + timestep_weight_config = self.cf.get("timestep_weight") + if timestep_weight_config is None: + return [1.0 for _ in range(forecast_steps)] + weights_timestep_fct = getattr(losses, timestep_weight_config[0]) + return weights_timestep_fct(forecast_steps, timestep_weight_config[1]) + + def _get_location_weights(self, stream_info, stream_data, forecast_offset, fstep): + location_weight_type = stream_info.get("location_weight", None) + if location_weight_type is None: + return None + weights_locations_fct = getattr(losses, location_weight_type) + weights_locations = weights_locations_fct(stream_data, forecast_offset, fstep) + weights_locations = weights_locations.to(device=self.device, non_blocking=True) + + return weights_locations + + def _get_substep_masks(self, stream_info, fstep, stream_data): + """ + Find substeps and create corresponding masks (reused across loss functions) + """ + + tok_spacetime = stream_info.get("tokenize_spacetime", None) + target_times = stream_data.target_times_raw[self.cf.forecast_offset + fstep] + target_times_unique = np.unique(target_times) if tok_spacetime else [target_times] + substep_masks = [] + for t in target_times_unique: + # find substep + mask_t = torch.tensor(t == target_times).to(self.device, non_blocking=True) + substep_masks.append(mask_t) + + return substep_masks + + def compute_loss( + self, + preds: list[list[Tensor]], + targets: list[list[any]], + ) -> LossValues: + """ + Computes the total loss for a given batch of predictions and corresponding + stream data. + + The computed loss is: + + Mean_{stream}( Mean_{fsteps}( Mean_{loss_fcts}( loss_fct( target, pred, weigths) ))) + + This method orchestrates the calculation of the overall loss by iterating through + different data streams, forecast steps, channels, and configured loss functions. + It applies weighting, handles NaN values through masking, and accumulates + detailed loss metrics for logging. + + Args: + preds: A nested list of prediction tensors. The outer list represents forecast steps, + the inner list represents streams. Each tensor contains predictions for that + step and stream. + streams_data: A nested list representing the input batch data. The outer list is for + batch items, the inner list for streams. Each element provides an object + (e.g., dataclass instance) containing target data and metadata. + + Returns: + A ModelLoss dataclass instance containing: + - loss: The loss for back-propagation. + - losses_all: A dictionary mapping stream names to a tensor of per-channel and + per-loss-function losses, normalized by non-empty targets/forecast steps. + - stddev_all: A dictionary mapping stream names to a tensor of mean standard deviations + of predictions for channels with statistical loss functions, normalized. + """ + + streams_data = targets + + # gradient loss + loss = torch.tensor(0.0, device=self.device, requires_grad=True) + # counter for non-empty targets + ctr_streams = 0 + + # initialize dictionaries for detailed loss tracking and standard deviation statistics + # create tensor for each stream + losses_all: dict[str, Tensor] = { + st.name: torch.zeros( + (len(st[str(self.stage) + "_target_channels"]), len(self.loss_fcts)), + device=self.device, + ) + for st in self.cf.streams + } + stddev_all: dict[str, Tensor] = { + st.name: torch.zeros(len(stat_loss_fcts), device=self.device) for st in self.cf.streams + } + + # TODO: iterate over batch dimension + i_batch = 0 + for i_stream_info, stream_info in enumerate(self.cf.streams): + # extract target tokens for current stream from the specified forecast offset onwards + targets = streams_data[i_batch][i_stream_info].target_tokens[self.cf.forecast_offset :] + + stream_data = streams_data[i_batch][i_stream_info] + + fstep_loss_weights = self._get_fstep_weights(len(targets)) + + loss_fsteps = torch.tensor(0.0, device=self.device, requires_grad=True) + ctr_fsteps = 0 + + stream_is_spoof = streams_data[i_batch][i_stream_info].is_spoof() + if stream_is_spoof: + spoof_weight = torch.tensor(0.0, device=self.device, requires_grad=False) + else: + spoof_weight = torch.tensor(1.0, device=self.device, requires_grad=False) + + for fstep, (target, fstep_weight) in enumerate( + zip(targets, fstep_loss_weights, strict=False) + ): + # skip if either target or prediction has no data points + pred = preds[fstep][i_stream_info] + if not (target.shape[0] > 0 and pred.shape[0] > 0): + continue + + # reshape prediction tensor to match target's dimensions: extract data/coords and + # remove token dimension if it exists. + # expected final shape of pred is [ensemble_size, num_samples, num_channels]. + pred = pred.reshape([pred.shape[0], *target.shape]) + assert pred.shape[1] > 0 + + # get weigths for current streams + stream_loss_weight, weights_channels = self._get_weights(stream_info) + + # get weights for locations + weights_locations = self._get_location_weights( + stream_info, stream_data, self.cf.forecast_offset, fstep + ) + + # get masks for sub-time steps + substep_masks = self._get_substep_masks(stream_info, fstep, stream_data) + + # accumulate loss from different loss functions + loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) + ctr_loss_fcts = 0 + for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts): + # loss for current loss function + loss_lfct, loss_lfct_chs = self._loss_per_loss_function( + loss_fct, + target, + pred, + substep_masks, + weights_channels, + weights_locations, + ) + losses_all[stream_info.name][:, i_lfct] += spoof_weight * loss_lfct_chs + + # Add the weighted and normalized loss from this loss function to the total + # batch loss + loss_fstep = loss_fstep + ( + loss_fct_weight * loss_lfct * stream_loss_weight * fstep_weight + ) + ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0 + + loss_fsteps = loss_fsteps + (loss_fstep / ctr_loss_fcts if ctr_loss_fcts > 0 else 0) + ctr_fsteps += 1 if ctr_loss_fcts > 0 else 0 + + loss = loss + ((spoof_weight * loss_fsteps) / (ctr_fsteps if ctr_fsteps > 0 else 1.0)) + ctr_streams += 1 if ctr_fsteps > 0 and not stream_is_spoof else 0 + + # normalize by forecast step + losses_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 + stddev_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 + + # replace channels without information by nan to exclude from further computations + losses_all[stream_info.name][losses_all[stream_info.name] == 0.0] = torch.nan + stddev_all[stream_info.name][stddev_all[stream_info.name] == 0.0] = torch.nan + + # normalize by all targets and forecast steps that were non-empty + # (with each having an expected loss of 1 for an uninitalized neural net) + loss = loss / ctr_streams + + # Return all computed loss components encapsulated in a ModelLoss dataclass + return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all) + + +class LossCalculatorLatent(LossCalculatorBase): + """ + Manages and computes the overall loss for a WeatherGenerator model during + training and validation stages. + + This class handles the initialization and application of various loss functions, + applies channel-specific weights, constructs masks for missing data, and + aggregates losses across different data streams, channels, and forecast steps. + It provides both the main loss for backpropagation and detailed loss metrics for logging. + """ + + def __init__( + self, + cf: DictConfig, + loss_fcts: list, + stage: Stage, + device: str, + ): + LossCalculatorBase.__init__(self) + self.cf = cf + self.stage = stage + self.device = device + + # Dynamically load loss functions based on configuration and stage + self.loss_fcts = [ + [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] + for name, w in loss_fcts + ] diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 3d847a671..3c31daed6 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -588,12 +588,14 @@ def train(self, epoch): dtype=self.mixed_precision_dtype, enabled=cf.with_mixed_precision, ): - preds, posteriors = self.model( + predictions, posteriors = self.model( self.model_params, batch, cf.forecast_offset, forecast_steps ) + targets = {"physical": batch[0]} + preds = {"physical": predictions, "latent": posteriors} loss_values = self.loss_calculator.compute_loss( preds=preds, - streams_data=batch[0], + targets=targets, ) if cf.latent_noise_kl_weight > 0.0: kl = torch.cat([posterior.kl() for posterior in posteriors]) From e822e12928605c11482ede836b409decca1d658b Mon Sep 17 00:00:00 2001 From: Jubeku Date: Thu, 6 Nov 2025 16:45:38 +0100 Subject: [PATCH 008/344] add abstract method to loss calculator base class --- src/weathergen/train/loss_calculator_base.py | 53 ++++--------------- .../train/loss_calculator_classes.py | 37 +++++++++++++ 2 files changed, 47 insertions(+), 43 deletions(-) diff --git a/src/weathergen/train/loss_calculator_base.py b/src/weathergen/train/loss_calculator_base.py index 720116baa..13ad0394d 100644 --- a/src/weathergen/train/loss_calculator_base.py +++ b/src/weathergen/train/loss_calculator_base.py @@ -10,8 +10,8 @@ # nor does it submit to any jurisdiction. import dataclasses +from abc import abstractmethod -import torch from torch import Tensor from weathergen.common.config import Config @@ -44,11 +44,7 @@ class LossValues: class LossCalculatorBase: def __init__(self): """ - Initializes the LossCalculator. - - This sets up the configuration, the operational stage (training or validation), - the device for tensor operations, and initializes the list of loss functions - based on the provided configuration. + Base class for loss calculators. Args: cf: The OmegaConf DictConfig object containing model and training configurations. @@ -61,43 +57,14 @@ def __init__(self): self.stage: Stage self.loss_fcts = [] - @staticmethod - def _loss_per_loss_function( - loss_fct, - target: torch.Tensor, - pred: torch.Tensor, - substep_masks: list[torch.Tensor], - weights_channels: torch.Tensor, - weights_locations: torch.Tensor, - ): + @abstractmethod + def compute_loss( + self, + preds: dict, + targets: dict, + ) -> LossValues: """ - Compute loss for given loss function + Computes loss given predictions and targets and returns values of LossValues dataclass. """ - loss_lfct = torch.tensor(0.0, device=target.device, requires_grad=True) - losses_chs = torch.zeros(target.shape[-1], device=target.device, dtype=torch.float32) - - ctr_substeps = 0 - for mask_t in substep_masks: - assert mask_t.sum() == len(weights_locations) if weights_locations is not None else True - - loss, loss_chs = loss_fct( - target[mask_t], pred[:, mask_t], weights_channels, weights_locations - ) - - # accumulate loss - loss_lfct = loss_lfct + loss - losses_chs = losses_chs + loss_chs.detach() if len(loss_chs) > 0 else losses_chs - ctr_substeps += 1 if loss > 0.0 else 0 - - # normalize over forecast steps in window - losses_chs /= ctr_substeps if ctr_substeps > 0 else 1.0 - - # TODO: substep weight - loss_lfct = loss_lfct / (ctr_substeps if ctr_substeps > 0 else 1.0) - - return loss_lfct, losses_chs - - # def _get_weights(self, stream_info): - - # def _update_weights(self, stream_info): + raise NotImplementedError() diff --git a/src/weathergen/train/loss_calculator_classes.py b/src/weathergen/train/loss_calculator_classes.py index 62ed45a7d..6edab4bf4 100644 --- a/src/weathergen/train/loss_calculator_classes.py +++ b/src/weathergen/train/loss_calculator_classes.py @@ -112,6 +112,43 @@ def _get_substep_masks(self, stream_info, fstep, stream_data): return substep_masks + @staticmethod + def _loss_per_loss_function( + loss_fct, + target: torch.Tensor, + pred: torch.Tensor, + substep_masks: list[torch.Tensor], + weights_channels: torch.Tensor, + weights_locations: torch.Tensor, + ): + """ + Compute loss for given loss function + """ + + loss_lfct = torch.tensor(0.0, device=target.device, requires_grad=True) + losses_chs = torch.zeros(target.shape[-1], device=target.device, dtype=torch.float32) + + ctr_substeps = 0 + for mask_t in substep_masks: + assert mask_t.sum() == len(weights_locations) if weights_locations is not None else True + + loss, loss_chs = loss_fct( + target[mask_t], pred[:, mask_t], weights_channels, weights_locations + ) + + # accumulate loss + loss_lfct = loss_lfct + loss + losses_chs = losses_chs + loss_chs.detach() if len(loss_chs) > 0 else losses_chs + ctr_substeps += 1 if loss > 0.0 else 0 + + # normalize over forecast steps in window + losses_chs /= ctr_substeps if ctr_substeps > 0 else 1.0 + + # TODO: substep weight + loss_lfct = loss_lfct / (ctr_substeps if ctr_substeps > 0 else 1.0) + + return loss_lfct, losses_chs + def compute_loss( self, preds: list[list[Tensor]], From d24ef486279fa784a494d3e94b407c0ba2604a09 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Thu, 6 Nov 2025 17:21:11 +0100 Subject: [PATCH 009/344] add latent loss class --- src/weathergen/train/loss_calculator.py | 2 + .../train/loss_calculator_classes.py | 69 +++++++++++++++++-- 2 files changed, 64 insertions(+), 7 deletions(-) diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index ffd018f22..b2bdcc88e 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -80,6 +80,8 @@ def compute_loss( loss_values[type] = calculator.compute_loss(preds=preds[type], targets=targets[type]) loss += loss_values[type].loss + # Bring all loss values together + # TODO: keys should tell what type of loss was used, e.g loss_mse.latent.loss_2t losses_all = {} stddev_all = {} for _, v in loss_values.items(): diff --git a/src/weathergen/train/loss_calculator_classes.py b/src/weathergen/train/loss_calculator_classes.py index 6edab4bf4..13b9ae76a 100644 --- a/src/weathergen/train/loss_calculator_classes.py +++ b/src/weathergen/train/loss_calculator_classes.py @@ -294,13 +294,7 @@ def compute_loss( class LossCalculatorLatent(LossCalculatorBase): """ - Manages and computes the overall loss for a WeatherGenerator model during - training and validation stages. - - This class handles the initialization and application of various loss functions, - applies channel-specific weights, constructs masks for missing data, and - aggregates losses across different data streams, channels, and forecast steps. - It provides both the main loss for backpropagation and detailed loss metrics for logging. + Calculates loss in latent space. """ def __init__( @@ -320,3 +314,64 @@ def __init__( [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] for name, w in loss_fcts ] + + def _loss_per_loss_function( + self, + loss_fct, + target: torch.Tensor, + pred: torch.Tensor, + ): + """ + Compute loss for given loss function + """ + + loss_val = loss_fct(target=target, ens=None, mu=pred) + + return loss_val + + def compute_loss( + self, + preds: list[list[Tensor]], + targets: list[list[any]], + ) -> LossValues: + losses_all: Tensor = torch.zeros( + len(self.loss_fcts), + device=self.device, + ) + + loss_fsteps_lat = torch.tensor(0.0, device=self.device, requires_grad=True) + ctr_fsteps_lat = 0 + # TODO: KCT, do we need the below per fstep? + for fstep in range( + 1, len(preds) + ): # the first entry in tokens_all is the source itself, so skip it + loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) + ctr_loss_fcts = 0 + # if forecast_offset==0, then the timepoints correspond. Otherwise targets don't encode the source timestep, so we don't need to skip + fstep_targs = fstep if self.cf.forecast_offset == 0 else fstep - 1 + for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts_lat): + loss_lfct = self._loss_per_loss_function( + loss_fct, + stream_info=None, + target=targets[fstep_targs], + pred=preds[fstep], + ) + + losses_all[i_lfct] += loss_lfct # TODO: break into fsteps + + # Add the weighted and normalized loss from this loss function to the total + # batch loss + loss_fstep = loss_fstep + (loss_fct_weight * loss_lfct) + ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0 + + loss_fsteps_lat = loss_fsteps_lat + ( + loss_fstep / ctr_loss_fcts if ctr_loss_fcts > 0 else 0 + ) + ctr_fsteps_lat += 1 if ctr_loss_fcts > 0 else 0 + + loss = loss_fsteps_lat / (ctr_fsteps_lat if ctr_fsteps_lat > 0 else 1.0) + + losses_all /= ctr_fsteps_lat if ctr_fsteps_lat > 0 else 1.0 + losses_all[losses_all == 0.0] = torch.nan + + return LossValues(loss=loss, losses_all=losses_all) From c259c20421ce559d0bc1a4530ada37a196198b2d Mon Sep 17 00:00:00 2001 From: Jubeku Date: Fri, 7 Nov 2025 16:15:33 +0100 Subject: [PATCH 010/344] update loss calc config and rename files --- config/default_config.yml | 11 ++++-- .../weathergen/evaluate/export_inference.py | 6 ++- src/weathergen/train/loss_calculator.py | 32 +++++++-------- ...s_calculator_classes.py => loss_module.py} | 39 ++++++++++++++----- ...calculator_base.py => loss_module_base.py} | 7 +--- 5 files changed, 56 insertions(+), 39 deletions(-) rename src/weathergen/train/{loss_calculator_classes.py => loss_module.py} (94%) rename src/weathergen/train/{loss_calculator_base.py => loss_module_base.py} (95%) diff --git a/config/default_config.yml b/config/default_config.yml index 26f9382c2..e2de3ff21 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -73,9 +73,6 @@ loss_fcts: - - "mse" - 1.0 - # - - # - "latent:mse" - # - 0.3 loss_fcts_val: - - "mse" @@ -97,6 +94,14 @@ ema_halflife_in_thousands: 1e-3 # training mode: "forecast" or "masking" (masked token modeling) # for "masking" to train with auto-encoder mode, forecast_offset should be 0 training_mode: "masking" +training_mode_config: {"losses": {LossPhysical: [['mse', 1.0]],} + } +# training_mode_config: {"loss": {LossPhysical: [['mse', 0.7]], +# LossLatent: [['mse', 0.3]], +# LossStudentTeacher: [{'iBOT': {}, 'JEPA': {options}}],} +# } +validation_mode_config: {"losses": {LossPhysical: [['mse', 1.0]],} + } # masking rate when training mode is "masking"; ignored in foreacast mode masking_rate: 0.6 # sample the masking rate (with normal distribution centered at masking_rate) diff --git a/packages/evaluate/src/weathergen/evaluate/export_inference.py b/packages/evaluate/src/weathergen/evaluate/export_inference.py index 2c0cb4243..4e0bd7d6e 100755 --- a/packages/evaluate/src/weathergen/evaluate/export_inference.py +++ b/packages/evaluate/src/weathergen/evaluate/export_inference.py @@ -61,6 +61,7 @@ def detect_grid_type(input_data_array: xr.DataArray) -> str: # Otherwise it's Gaussian (irregular spacing or reduced grid) return "gaussian" + def find_pl(all_variables: list) -> tuple[dict[str, list[str]], list[int]]: """ Find all the pressure levels for each variable using regex and returns a dictionary @@ -90,6 +91,7 @@ def find_pl(all_variables: list) -> tuple[dict[str, list[str]], list[int]]: pl = list(set(pl)) return var_dict, pl + def reshape_dataset_adaptive(input_data_array: xr.DataArray) -> xr.Dataset: """ Reshape dataset while preserving grid structure (regular or Gaussian). @@ -176,8 +178,6 @@ def add_gaussian_grid_metadata(ds: xr.Dataset, grid_info: dict | None = None) -> return ds - - def add_conventions(stream: str, run_id: str, ds: xr.Dataset) -> xr.Dataset: """ Add CF conventions to the dataset attributes. @@ -201,6 +201,7 @@ def add_conventions(stream: str, run_id: str, ds: xr.Dataset) -> xr.Dataset: ds.attrs["Conventions"] = "CF-1.12" return ds + def cf_parser_gaussian_aware(config: OmegaConf, ds: xr.Dataset) -> xr.Dataset: """ Modified CF parser that handles both regular and Gaussian grids. @@ -323,6 +324,7 @@ def cf_parser_gaussian_aware(config: OmegaConf, ds: xr.Dataset) -> xr.Dataset: return dataset + def output_filename( prefix: str, run_id: str, diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index b2bdcc88e..dfd582ec8 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -13,8 +13,8 @@ from omegaconf import DictConfig -from weathergen.train.loss_calculator_base import LossValues -from weathergen.train.loss_calculator_classes import LossCalculatorLatent, LossCalculatorPhysical +import weathergen.train.loss_module as LossModule +from weathergen.train.loss_module_base import LossValues from weathergen.utils.train_logger import TRAIN, Stage _logger = logging.getLogger(__name__) @@ -50,23 +50,17 @@ def __init__( self.stage = stage self.device = device - loss_fcts = cf.loss_fcts if stage == TRAIN else cf.loss_fcts_val + calculator_configs = ( + cf.training_mode_config.losses if stage == TRAIN else cf.validation_mode_config.losses + ) - loss_fcts_physical = [[name, w] for name, w in loss_fcts if name.split(":")[0] != "latent"] - loss_fcts_latent = [ - [name.split(":")[1], w] for name, w in loss_fcts if name.split(":")[0] == "latent" + calculator_configs = [ + (getattr(LossModule, Cls), losses) for (Cls, losses) in calculator_configs.items() ] - calculator_configs = [] - - if loss_fcts_physical: - calculator_configs.append((LossCalculatorPhysical, loss_fcts_physical, "physical")) - if loss_fcts_latent: - calculator_configs.append((LossCalculatorLatent, loss_fcts_latent, "latent")) - self.loss_calculators = [ - (Cls(cf=cf, loss_fcts=losses, stage=stage, device=self.device), type) - for (Cls, losses, type) in calculator_configs + Cls(cf=cf, loss_fcts=losses, stage=stage, device=self.device) + for (Cls, losses) in calculator_configs ] def compute_loss( @@ -76,12 +70,12 @@ def compute_loss( ): loss_values = {} loss = 0 - for calculator, type in self.loss_calculators: - loss_values[type] = calculator.compute_loss(preds=preds[type], targets=targets[type]) - loss += loss_values[type].loss + for calculator in self.loss_calculators: + loss_values[calculator.name] = calculator.compute_loss(preds=preds, targets=targets) + loss += loss_values[calculator.name].loss # Bring all loss values together - # TODO: keys should tell what type of loss was used, e.g loss_mse.latent.loss_2t + # TODO: make sure keys are explicit, e.g loss_mse.latent.loss_2t losses_all = {} stddev_all = {} for _, v in loss_values.items(): diff --git a/src/weathergen/train/loss_calculator_classes.py b/src/weathergen/train/loss_module.py similarity index 94% rename from src/weathergen/train/loss_calculator_classes.py rename to src/weathergen/train/loss_module.py index 13b9ae76a..2b345c3fe 100644 --- a/src/weathergen/train/loss_calculator_classes.py +++ b/src/weathergen/train/loss_module.py @@ -18,13 +18,13 @@ import weathergen.train.loss as losses from weathergen.train.loss import stat_loss_fcts -from weathergen.train.loss_calculator_base import LossCalculatorBase, LossValues +from weathergen.train.loss_module_base import LossModuleBase, LossValues from weathergen.utils.train_logger import TRAIN, VAL, Stage _logger = logging.getLogger(__name__) -class LossCalculatorPhysical(LossCalculatorBase): +class LossPhysical(LossModuleBase): """ Manages and computes the overall loss for a WeatherGenerator model during training and validation stages. @@ -42,13 +42,13 @@ def __init__( stage: Stage, device: str, ): - LossCalculatorBase.__init__(self) + LossModuleBase.__init__(self) self.cf = cf self.stage = stage self.device = device + self.name = "LossPhysical" # Dynamically load loss functions based on configuration and stage - self.loss_fcts = [ [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] for name, w in loss_fcts @@ -151,8 +151,8 @@ def _loss_per_loss_function( def compute_loss( self, - preds: list[list[Tensor]], - targets: list[list[any]], + preds: dict, + targets: dict, ) -> LossValues: """ Computes the total loss for a given batch of predictions and corresponding @@ -184,7 +184,8 @@ def compute_loss( of predictions for channels with statistical loss functions, normalized. """ - streams_data = targets + preds = preds["physical"] + streams_data = targets["physical"] # gradient loss loss = torch.tensor(0.0, device=self.device, requires_grad=True) @@ -292,7 +293,7 @@ def compute_loss( return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all) -class LossCalculatorLatent(LossCalculatorBase): +class LossLatent(LossModuleBase): """ Calculates loss in latent space. """ @@ -304,10 +305,11 @@ def __init__( stage: Stage, device: str, ): - LossCalculatorBase.__init__(self) + LossModuleBase.__init__(self) self.cf = cf self.stage = stage self.device = device + self.name = "LossLatent" # Dynamically load loss functions based on configuration and stage self.loss_fcts = [ @@ -375,3 +377,22 @@ def compute_loss( losses_all[losses_all == 0.0] = torch.nan return LossValues(loss=loss, losses_all=losses_all) + + +class LossStudentTeacher(LossModuleBase): + """ + Calculates loss in latent space. + """ + + def __init__( + self, + cf: DictConfig, + loss_fcts: list, + stage: Stage, + device: str, + ): + self.name = "LossStudentTeacher" + raise NotImplementedError() + + def compute_loss(self, preds, targets): + return super().compute_loss(preds, targets) diff --git a/src/weathergen/train/loss_calculator_base.py b/src/weathergen/train/loss_module_base.py similarity index 95% rename from src/weathergen/train/loss_calculator_base.py rename to src/weathergen/train/loss_module_base.py index 13ad0394d..de66bda28 100644 --- a/src/weathergen/train/loss_calculator_base.py +++ b/src/weathergen/train/loss_module_base.py @@ -17,11 +17,6 @@ from weathergen.common.config import Config from weathergen.utils.train_logger import Stage -# @dataclasses.dataclass -# class InputOutputStructure: - -# targets.latent - @dataclasses.dataclass class LossValues: @@ -41,7 +36,7 @@ class LossValues: stddev_all: dict[str, Tensor] -class LossCalculatorBase: +class LossModuleBase: def __init__(self): """ Base class for loss calculators. From a19ee1658f65d1e0074ddc137ac58e70f66e6622 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Tue, 11 Nov 2025 15:41:29 +0100 Subject: [PATCH 011/344] restructure loss modules --- src/weathergen/train/loss_calculator.py | 6 +- src/weathergen/train/loss_modules/__init__.py | 5 + .../train/{ => loss_modules}/loss.py | 0 .../{ => loss_modules}/loss_module_base.py | 0 .../train/loss_modules/loss_module_latent.py | 112 ++++++++++++++++++ .../loss_module_physical.py} | 111 +---------------- .../train/loss_modules/loss_module_ssl.py | 38 ++++++ 7 files changed, 161 insertions(+), 111 deletions(-) create mode 100644 src/weathergen/train/loss_modules/__init__.py rename src/weathergen/train/{ => loss_modules}/loss.py (100%) rename src/weathergen/train/{ => loss_modules}/loss_module_base.py (100%) create mode 100644 src/weathergen/train/loss_modules/loss_module_latent.py rename src/weathergen/train/{loss_module.py => loss_modules/loss_module_physical.py} (77%) create mode 100644 src/weathergen/train/loss_modules/loss_module_ssl.py diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index dfd582ec8..2eda80fce 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -13,8 +13,8 @@ from omegaconf import DictConfig -import weathergen.train.loss_module as LossModule -from weathergen.train.loss_module_base import LossValues +import weathergen.train.loss_modules as LossModules +from weathergen.train.loss_modules.loss_module_base import LossValues from weathergen.utils.train_logger import TRAIN, Stage _logger = logging.getLogger(__name__) @@ -55,7 +55,7 @@ def __init__( ) calculator_configs = [ - (getattr(LossModule, Cls), losses) for (Cls, losses) in calculator_configs.items() + (getattr(LossModules, Cls), losses) for (Cls, losses) in calculator_configs.items() ] self.loss_calculators = [ diff --git a/src/weathergen/train/loss_modules/__init__.py b/src/weathergen/train/loss_modules/__init__.py new file mode 100644 index 000000000..7f5fc906d --- /dev/null +++ b/src/weathergen/train/loss_modules/__init__.py @@ -0,0 +1,5 @@ +from .loss_module_latent import LossLatent +from .loss_module_physical import LossPhysical +from .loss_module_ssl import LossStudentTeacher + +__all__ = [LossLatent, LossPhysical, LossStudentTeacher] diff --git a/src/weathergen/train/loss.py b/src/weathergen/train/loss_modules/loss.py similarity index 100% rename from src/weathergen/train/loss.py rename to src/weathergen/train/loss_modules/loss.py diff --git a/src/weathergen/train/loss_module_base.py b/src/weathergen/train/loss_modules/loss_module_base.py similarity index 100% rename from src/weathergen/train/loss_module_base.py rename to src/weathergen/train/loss_modules/loss_module_base.py diff --git a/src/weathergen/train/loss_modules/loss_module_latent.py b/src/weathergen/train/loss_modules/loss_module_latent.py new file mode 100644 index 000000000..6daf472bb --- /dev/null +++ b/src/weathergen/train/loss_modules/loss_module_latent.py @@ -0,0 +1,112 @@ +# ruff: noqa: T201 + +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import torch +from omegaconf import DictConfig +from torch import Tensor + +import weathergen.train.loss_modules.loss as losses +from weathergen.train.loss_modules.loss_module_base import LossModuleBase, LossValues +from weathergen.utils.train_logger import Stage + +_logger = logging.getLogger(__name__) + + +class LossLatent(LossModuleBase): + """ + Calculates loss in latent space. + """ + + def __init__( + self, + cf: DictConfig, + loss_fcts: list, + stage: Stage, + device: str, + ): + LossModuleBase.__init__(self) + self.cf = cf + self.stage = stage + self.device = device + self.name = "LossLatent" + + # Dynamically load loss functions based on configuration and stage + self.loss_fcts = [ + [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] + for name, w in loss_fcts + ] + + def _loss_per_loss_function( + self, + loss_fct, + target: torch.Tensor, + pred: torch.Tensor, + ): + """ + Compute loss for given loss function + """ + + loss_val = loss_fct(target=target, ens=None, mu=pred) + + return loss_val + + def compute_loss( + self, + preds: list[list[Tensor]], + targets: list[list[any]], + ) -> LossValues: + return super().compute_loss(preds, targets) + + ### FROM KEREM's PR + # losses_all: Tensor = torch.zeros( + # len(self.loss_fcts), + # device=self.device, + # ) + + # loss_fsteps_lat = torch.tensor(0.0, device=self.device, requires_grad=True) + # ctr_fsteps_lat = 0 + # # TODO: KCT, do we need the below per fstep? + # for fstep in range( + # 1, len(preds) + # ): # the first entry in tokens_all is the source itself, so skip it + # loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) + # ctr_loss_fcts = 0 + # # if forecast_offset==0, then the timepoints correspond. + # # Otherwise targets don't encode the source timestep, so we don't need to skip + # fstep_targs = fstep if self.cf.forecast_offset == 0 else fstep - 1 + # for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts_lat): + # loss_lfct = self._loss_per_loss_function( + # loss_fct, + # stream_info=None, + # target=targets[fstep_targs], + # pred=preds[fstep], + # ) + + # losses_all[i_lfct] += loss_lfct # TODO: break into fsteps + + # # Add the weighted and normalized loss from this loss function to the total + # # batch loss + # loss_fstep = loss_fstep + (loss_fct_weight * loss_lfct) + # ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0 + + # loss_fsteps_lat = loss_fsteps_lat + ( + # loss_fstep / ctr_loss_fcts if ctr_loss_fcts > 0 else 0 + # ) + # ctr_fsteps_lat += 1 if ctr_loss_fcts > 0 else 0 + + # loss = loss_fsteps_lat / (ctr_fsteps_lat if ctr_fsteps_lat > 0 else 1.0) + + # losses_all /= ctr_fsteps_lat if ctr_fsteps_lat > 0 else 1.0 + # losses_all[losses_all == 0.0] = torch.nan + + # return LossValues(loss=loss, losses_all=losses_all) diff --git a/src/weathergen/train/loss_module.py b/src/weathergen/train/loss_modules/loss_module_physical.py similarity index 77% rename from src/weathergen/train/loss_module.py rename to src/weathergen/train/loss_modules/loss_module_physical.py index 2b345c3fe..db4917550 100644 --- a/src/weathergen/train/loss_module.py +++ b/src/weathergen/train/loss_modules/loss_module_physical.py @@ -16,9 +16,9 @@ from omegaconf import DictConfig from torch import Tensor -import weathergen.train.loss as losses -from weathergen.train.loss import stat_loss_fcts -from weathergen.train.loss_module_base import LossModuleBase, LossValues +import weathergen.train.loss_modules.loss as losses +from weathergen.train.loss_modules.loss import stat_loss_fcts +from weathergen.train.loss_modules.loss_module_base import LossModuleBase, LossValues from weathergen.utils.train_logger import TRAIN, VAL, Stage _logger = logging.getLogger(__name__) @@ -291,108 +291,3 @@ def compute_loss( # Return all computed loss components encapsulated in a ModelLoss dataclass return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all) - - -class LossLatent(LossModuleBase): - """ - Calculates loss in latent space. - """ - - def __init__( - self, - cf: DictConfig, - loss_fcts: list, - stage: Stage, - device: str, - ): - LossModuleBase.__init__(self) - self.cf = cf - self.stage = stage - self.device = device - self.name = "LossLatent" - - # Dynamically load loss functions based on configuration and stage - self.loss_fcts = [ - [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] - for name, w in loss_fcts - ] - - def _loss_per_loss_function( - self, - loss_fct, - target: torch.Tensor, - pred: torch.Tensor, - ): - """ - Compute loss for given loss function - """ - - loss_val = loss_fct(target=target, ens=None, mu=pred) - - return loss_val - - def compute_loss( - self, - preds: list[list[Tensor]], - targets: list[list[any]], - ) -> LossValues: - losses_all: Tensor = torch.zeros( - len(self.loss_fcts), - device=self.device, - ) - - loss_fsteps_lat = torch.tensor(0.0, device=self.device, requires_grad=True) - ctr_fsteps_lat = 0 - # TODO: KCT, do we need the below per fstep? - for fstep in range( - 1, len(preds) - ): # the first entry in tokens_all is the source itself, so skip it - loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) - ctr_loss_fcts = 0 - # if forecast_offset==0, then the timepoints correspond. Otherwise targets don't encode the source timestep, so we don't need to skip - fstep_targs = fstep if self.cf.forecast_offset == 0 else fstep - 1 - for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts_lat): - loss_lfct = self._loss_per_loss_function( - loss_fct, - stream_info=None, - target=targets[fstep_targs], - pred=preds[fstep], - ) - - losses_all[i_lfct] += loss_lfct # TODO: break into fsteps - - # Add the weighted and normalized loss from this loss function to the total - # batch loss - loss_fstep = loss_fstep + (loss_fct_weight * loss_lfct) - ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0 - - loss_fsteps_lat = loss_fsteps_lat + ( - loss_fstep / ctr_loss_fcts if ctr_loss_fcts > 0 else 0 - ) - ctr_fsteps_lat += 1 if ctr_loss_fcts > 0 else 0 - - loss = loss_fsteps_lat / (ctr_fsteps_lat if ctr_fsteps_lat > 0 else 1.0) - - losses_all /= ctr_fsteps_lat if ctr_fsteps_lat > 0 else 1.0 - losses_all[losses_all == 0.0] = torch.nan - - return LossValues(loss=loss, losses_all=losses_all) - - -class LossStudentTeacher(LossModuleBase): - """ - Calculates loss in latent space. - """ - - def __init__( - self, - cf: DictConfig, - loss_fcts: list, - stage: Stage, - device: str, - ): - self.name = "LossStudentTeacher" - raise NotImplementedError() - - def compute_loss(self, preds, targets): - return super().compute_loss(preds, targets) diff --git a/src/weathergen/train/loss_modules/loss_module_ssl.py b/src/weathergen/train/loss_modules/loss_module_ssl.py new file mode 100644 index 000000000..240a2e27d --- /dev/null +++ b/src/weathergen/train/loss_modules/loss_module_ssl.py @@ -0,0 +1,38 @@ +# ruff: noqa: T201 + +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +from omegaconf import DictConfig + +from weathergen.train.loss_modules.loss_module_base import LossModuleBase +from weathergen.utils.train_logger import Stage + +_logger = logging.getLogger(__name__) + + +class LossStudentTeacher(LossModuleBase): + """ + Calculates loss in latent space. + """ + + def __init__( + self, + cf: DictConfig, + loss_fcts: list, + stage: Stage, + device: str, + ): + self.name = "LossStudentTeacher" + raise NotImplementedError() + + def compute_loss(self, preds, targets): + return super().compute_loss(preds, targets) From bf3e128b28c6042feb53a45e4f1a481cd72fa1a1 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Tue, 11 Nov 2025 16:09:20 +0100 Subject: [PATCH 012/344] add ModelOutput dataclass --- src/weathergen/model/model.py | 16 ++++++++- .../loss_modules/loss_module_physical.py | 2 +- src/weathergen/train/trainer.py | 33 +++++++------------ 3 files changed, 28 insertions(+), 23 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 000f36735..ba5e2bb89 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -9,6 +9,7 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import dataclasses import logging import math import warnings @@ -42,6 +43,16 @@ logger = logging.getLogger(__name__) +@dataclasses.dataclass +class ModelOutput: + """ + A dataclass to encapsulate the model output and give a clear API. + """ + + physical: dict[str, torch.Tensor] + latent: dict[str, torch.Tensor] + + class ModelParams(torch.nn.Module): """Creation of query and embedding parameters of the model.""" @@ -653,7 +664,10 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca ) ] - return preds_all, posteriors + latents = {} + latents["posteriors"] = posteriors + + return ModelOutput(physical=preds_all, latent=latents) ######################################### def embed_cells(self, model_params: ModelParams, streams_data) -> torch.Tensor: diff --git a/src/weathergen/train/loss_modules/loss_module_physical.py b/src/weathergen/train/loss_modules/loss_module_physical.py index db4917550..54d30acc1 100644 --- a/src/weathergen/train/loss_modules/loss_module_physical.py +++ b/src/weathergen/train/loss_modules/loss_module_physical.py @@ -184,7 +184,7 @@ def compute_loss( of predictions for channels with statistical loss functions, normalized. """ - preds = preds["physical"] + preds = preds.physical streams_data = targets["physical"] # gradient loss diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 3c31daed6..6070e9263 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -588,17 +588,14 @@ def train(self, epoch): dtype=self.mixed_precision_dtype, enabled=cf.with_mixed_precision, ): - predictions, posteriors = self.model( - self.model_params, batch, cf.forecast_offset, forecast_steps - ) + output = self.model(self.model_params, batch, cf.forecast_offset, forecast_steps) targets = {"physical": batch[0]} - preds = {"physical": predictions, "latent": posteriors} loss_values = self.loss_calculator.compute_loss( - preds=preds, + preds=output, targets=targets, ) if cf.latent_noise_kl_weight > 0.0: - kl = torch.cat([posterior.kl() for posterior in posteriors]) + kl = torch.cat([posterior.kl() for posterior in output.latent]) loss_values.loss += cf.latent_noise_kl_weight * kl.mean() # backward pass @@ -681,17 +678,17 @@ def validate(self, epoch): if self.ema_model is None else self.ema_model.forward_eval ) - preds, _ = model_forward( + output = model_forward( self.model_params, batch, cf.forecast_offset, forecast_steps ) - - # compute loss and log output + targets = {"physical": batch[0]} + # compute loss + loss_values = self.loss_calculator_val.compute_loss( + preds=output, + targets=targets, + ) + # log output if bidx < cf.log_validation: - loss_values = self.loss_calculator_val.compute_loss( - preds=preds, - streams_data=batch[0], - ) - # TODO: Move _prepare_logging into write_validation by passing streams_data ( preds_all, @@ -700,7 +697,7 @@ def validate(self, epoch): targets_times_all, targets_lens, ) = self._prepare_logging( - preds=preds, + preds=output, forecast_offset=cf.forecast_offset, forecast_steps=cf.forecast_steps, streams_data=batch[0], @@ -718,12 +715,6 @@ def validate(self, epoch): targets_lens, ) - else: - loss_values = self.loss_calculator_val.compute_loss( - preds=preds, - streams_data=batch[0], - ) - self.loss_unweighted_hist += [loss_values.losses_all] self.loss_model_hist += [loss_values.loss.item()] self.stdev_unweighted_hist += [loss_values.stddev_all] From 711f29bedb7bc5e26dfabe9d1428298cd9fd2208 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Tue, 11 Nov 2025 17:29:01 +0000 Subject: [PATCH 013/344] First draft of diffusion model --- src/weathergen/model/diffusion.py | 133 ++++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 src/weathergen/model/diffusion.py diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py new file mode 100644 index 000000000..c51233f58 --- /dev/null +++ b/src/weathergen/model/diffusion.py @@ -0,0 +1,133 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import torch + +from weathergen.model.engines import ForecastingEngine + + +class DiffusionForecastEngine(torch.nn.Module): + # Adopted from https://github.com/NVlabs/edm/blob/main/training/loss.py#L72 + + def __init__( + self, + cf, + stage: str, + forecast_engine: ForecastingEngine, + sigma_min: float = 0.002, + sigma_max: float = 80, + sigma_data: float = 0.5, + rho: float = 7, + p_mean: float = -1.2, + p_std: float = 1.2, + ): + super().__init__() + self.stage = stage + self.net = forecast_engine + self.preconditioner = Preconditioner() + + # Parameters + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.sigma_data = sigma_data + self.rho = rho + self.p_mean = p_mean + self.p_std = p_std + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.stage == "train": + return self.forward_train(y=x) + else: + return self.inference(x=x) + + def forward_train(self, y) -> torch.Tensor: + # Determine noise level -- potentially move to "preprocessing" + noise = torch.randn(y.shape, device=y.device) + sigma = (noise * self.p_std + self.p_mean).exp() + n = torch.randn_like(y) * sigma + + # Compute conditionings + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (sigma**2 + self.sigma_data**2).sqrt + c_noise = sigma.log() / 4 + + # Add noise, precondition input, and feed through network + x = y + n + x = self.preconditioner.precondition(x) + net_out = self.net(c_in * x, c_noise) + y_hat = c_skip * y + c_out * net_out # Eq. (7) + # return y_hat + + """ + F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs) + assert F_x.dtype == dtype + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + """ + + # Compute loss -- move this to a separate loss calculator + weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 # Table 1 + loss = weight * ((y_hat - y) ** 2) + + def inference( + self, + x: torch.Tensor, + num_steps: int = 30, + ) -> torch.Tensor: + # Forward pass of the diffusion model during inference + # https://github.com/NVlabs/edm/blob/main/generate.py + + # Time step discretization. + step_indices = torch.arange(num_steps, dtype=torch.float64, device=x.device) + t_steps = ( + self.sigma_max ** (1 / self.rho) + + step_indices + / (num_steps - 1) + * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)) + ) ** self.rho + t_steps = torch.cat( + [self.net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])] + ) # t_N = 0 + + # Main sampling loop. + x_next = x * t_steps[0] + for i, (t_cur, t_next) in enumerate( + zip(t_steps[:-1], t_steps[1:], strict=False) + ): # 0, ..., N-1 + x_cur = x_next + + # Increase noise temporarily. (Stochastic sampling?) + # gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 + # t_hat = self.net.round_sigma(t_cur + gamma * t_cur) + # x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * s_noise * torch.randn_like(x_cur) + x_hat = x_cur + t_hat = t_cur + + # Euler step. + denoised = self.net(x_hat, t_hat) + d_cur = (x_hat - denoised) / t_hat + x_next = x_hat + (t_next - t_hat) * d_cur + + # Apply 2nd order correction. + if i < num_steps - 1: + denoised = self.net(x_next, t_next) + d_prime = (x_next - denoised) / t_next + x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + + return x_next + + +class Preconditioner: + # Preconditioner, e.g., to concatenate previous frames to the input + def __init_(self): + pass + + def precondition(self, x): + return x From 81bd6eb2977006cb4e4851b6e5b95e1b45718646 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 12 Nov 2025 09:38:53 +0100 Subject: [PATCH 014/344] NOT WORKING: initial draft for index-based masking. Implemented for random and healpix masking. Open issues with _coords_local, centroids and probably other things. --- src/weathergen/datasets/masking.py | 44 ++++++ src/weathergen/datasets/tokenizer_masking.py | 58 ++++---- src/weathergen/datasets/tokenizer_utils.py | 137 ++++++++++++++++++- 3 files changed, 210 insertions(+), 29 deletions(-) diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index 5334ca428..cb6eb90e3 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -139,6 +139,50 @@ def _select_strategy(self): # Non-combination strategy, return as is return self.masking_strategy + def mask_source_idxs( + self, + idxs_cells, + idxs_cells_lens, + rdata, + ) -> torch.Tensor: + """ + + Return: + torch.Tensor[bool] of length num_tokens that determines masking for each token + """ + + mask_tokens, mask_channels = None, None + + num_tokens = torch.tensor([len(t) for t in idxs_cells_lens]).sum().item() + + # If there are no tokens, return empty lists. + if num_tokens == 0: + return (mask_tokens, mask_channels) + + # Clean strategy selection + self.current_strategy = self._select_strategy() + + # Set the masking rate. + rate = self._get_sampling_rate() + + if self.current_strategy == "random": + mask_tokens = self.rng.uniform(0, 1, num_tokens) < rate + elif self.current_strategy == "healpix": + # TODO: currently only for fixed level + num_cells = len(idxs_cells_lens) + mask_cells = self.rng.uniform(0, 1, num_cells) < rate + # translate cell mask to token mask, replicating using number of tokens per cell + mask_tokens = [ + (torch.ones(2, dtype=torch.bool) * (1 if m else 0)).to(torch.bool) + for idxs_cell, m in zip(idxs_cells_lens, mask_cells, strict=False) + ] + else: + assert False, f"Unsupported masking strategy: {self.current_strategy}" + + self.perm_sel = mask_tokens + + return (mask_tokens, mask_channels) + def mask_source( self, tokenized_data: list[torch.Tensor], diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 548b52124..860dbc07a 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -19,6 +19,9 @@ arc_alpha, encode_times_source, encode_times_target, + tokenize_apply_mask, + tokenize_space, + tokenize_spacetime, tokenize_window_space, tokenize_window_spacetime, ) @@ -47,20 +50,9 @@ def batchify_source( normalize_coords, # dataset ): token_size = stream_info["token_size"] + stream_id = stream_info["stream_id"] + assert token_size is not None, "stream did not specify token_size" is_diagnostic = stream_info.get("diagnostic", False) - tokenize_spacetime = stream_info.get("tokenize_spacetime", False) - - tokenize_window = partial( - tokenize_window_spacetime if tokenize_spacetime else tokenize_window_space, - time_win=time_win, - token_size=token_size, - hl=self.hl_source, - hpy_verts_rots=self.hpy_verts_rots_source[-1], - n_coords=normalize_coords, - enc_time=encode_times_source, - ) - - self.token_size = token_size # return empty if there is no data or we are in diagnostic mode if is_diagnostic or rdata.data.shape[1] == 0 or len(rdata.data) < 2: @@ -69,24 +61,36 @@ def batchify_source( source_centroids = [torch.tensor([])] return (source_tokens_cells, source_tokens_lens, source_centroids) - # tokenize all data first - tokenized_data = tokenize_window( - 0, - rdata.coords, - rdata.geoinfos, - rdata.data, - rdata.datetimes, - ) + # create tokenization index + tok = tokenize_spacetime if stream_info.get("tokenize_spacetime", False) else tokenize_space + idxs_cells, idxs_cells_lens = tok(rdata, token_size, self.hl_source, pad_tokens=True) - tokenized_data = [ - torch.stack(c) if len(c) > 0 else torch.tensor([]) for c in tokenized_data - ] + (mask_tokens, mask_channels) = self.masker.mask_source_idxs( + idxs_cells, idxs_cells_lens, rdata + ) - # Use the masker to get source tokens and the selection mask for the target - source_tokens_cells = self.masker.mask_source( - tokenized_data, rdata.coords, rdata.geoinfos, rdata.data + source_tokens_cells = tokenize_apply_mask( + idxs_cells, + idxs_cells_lens, + mask_tokens, + mask_channels, + stream_id, + rdata, + time_win, + self.hpy_verts_rots_source[-1], + normalize_coords, + encode_times_source, ) + # tokenized_data = [ + # torch.stack(c) if len(c) > 0 else torch.tensor([]) for c in tokenized_data + # ] + + # # Use the masker to get source tokens and the selection mask for the target + # source_tokens_cells = self.masker.mask_source( + # tokenized_data, rdata.coords, rdata.geoinfos, rdata.data + # ) + source_tokens_lens = torch.tensor([len(s) for s in source_tokens_cells], dtype=torch.int32) if source_tokens_lens.sum() > 0: source_centroids = self.compute_source_centroids(source_tokens_cells) diff --git a/src/weathergen/datasets/tokenizer_utils.py b/src/weathergen/datasets/tokenizer_utils.py index c15ece48f..354323875 100644 --- a/src/weathergen/datasets/tokenizer_utils.py +++ b/src/weathergen/datasets/tokenizer_utils.py @@ -6,6 +6,7 @@ from astropy_healpix.healpy import ang2pix from torch import Tensor +from weathergen.common.io import IOReaderData from weathergen.datasets.utils import ( r3tos2, s2tor3, @@ -163,11 +164,143 @@ def hpy_splits( # extract length and flatten nested list idxs_ord_lens = [[len(a) for a in aa] for aa in idxs_ord] - idxs_ord = [torch.cat([idxs for idxs in iidxs]) for iidxs in idxs_ord] + # idxs_ord = [torch.cat([idxs for idxs in iidxs]) for iidxs in idxs_ord] return idxs_ord, idxs_ord_lens, posr3 +def tokenize_space( + rdata, + token_size, + hl, + pad_tokens=True, +): + """Process one window into tokens""" + + # len(source)==1 would require special case handling that is not worth the effort + if len(rdata.data) < 2: + return + + # idx_ord_lens is length is number of tokens per healpix cell + idxs_ord, idxs_ord_lens, _ = hpy_splits(rdata.coords, hl, token_size, pad_tokens) + + return idxs_ord, idxs_ord_lens + + +def tokenize_spacetime( + rdata, + token_size, + hl, + pad_tokens=True, +): + """Tokenize respecting an intrinsic time step in the data, i.e. each time step is tokenized + separately + """ + + num_healpix_cells = 12 * 4**hl + idxs_cells = [[] for _ in range(num_healpix_cells)] + idxs_cells_lens = [[] for _ in range(num_healpix_cells)] + + t_unique = np.unique(rdata.datetimes) + for _, t in enumerate(t_unique): + mask = t == rdata.datetimes + rdata_cur = IOReaderData( + rdata.coords[mask], rdata.geoinfos[mask], rdata.data[mask], rdata.datetimes[mask] + ) + idxs_cur, idxs_cur_lens = tokenize_space(rdata_cur, token_size, hl, pad_tokens) + + idxs_cells = [t + list(tc) for t, tc in zip(idxs_cells, idxs_cur, strict=True)] + idxs_cells_lens = [t + tc for t, tc in zip(idxs_cells_lens, idxs_cur_lens, strict=True)] + + return idxs_cells, idxs_cells_lens + + +def tokenize_apply_mask( + idxs_cells, + idxs_cells_lens, + mask_tokens, + mask_channels, + stream_id, + rdata, + time_win, + hpy_verts_rots, + n_coords: CoordNormalizer, + enc_time, +): + # convert to token level, forgetting about cells + idxs_tokens = [i for t in idxs_cells for i in t] + idxs_lens = [i for t in idxs_cells_lens for i in t] + + # filter tokens using mask to obtain flat per data point index list + idxs_data = torch.cat([t for t, m in zip(idxs_tokens, mask_tokens, strict=True) if m]) + # filter list of token lens using mask and obtain flat list for splitting + idxs_data_lens = torch.tensor([t for t, m in zip(idxs_lens, mask_tokens, strict=True) if m]) + + # pad with zero at the begining; idxs_cells -> idxs_tokens -> idxs_data has been prepared so + # that the zero-index is used to add the padding to the tokens to ensure fixed size + times_enc = enc_time(rdata.datetimes, time_win) + datetimes_enc_padded = torch.cat([torch.zeros_like(times_enc[0]).unsqueeze(0), times_enc]) + geoinfos_padded = torch.cat([torch.zeros_like(rdata.geoinfos[0]).unsqueeze(0), rdata.geoinfos]) + coords_padded = torch.cat([torch.zeros_like(rdata.coords[0]).unsqueeze(0), rdata.coords]) + data_padded = torch.cat([torch.zeros_like(rdata.data[0]).unsqueeze(0), rdata.data]) + + # apply mask + datetimes = datetimes_enc_padded[idxs_data] + geoinfos = geoinfos_padded[idxs_data] + coords = coords_padded[idxs_data] + data = data_padded[idxs_data] + + # TODO, TODO, TODO: fix _coords_local + # _coords_local + coords_local = torch.cat((coords, torch.zeros_like(coords[:, 0]).unsqueeze(1)), 1) + + # create tensor that contains all info + tokens = torch.cat((datetimes, coords_local, geoinfos, data), 1) + + # split up tensor into tokens + idxs_data_lens = idxs_data_lens.tolist() + tokens_cells = torch.split(tokens, idxs_data_lens) + + # # R^3 coords + # thetas = ((90.0 - coords[:, 0]) / 180.0) * np.pi + # phis = ((coords[:, 1] + 180.0) / 360.0) * 2.0 * np.pi + # posr3 = s2tor3(thetas, phis) + + # # convert to local coordinates + # # TODO: avoid that padded lists are rotated, which means potentially a lot of zeros + # coords_local = _coords_local(posr3, hpy_verts_rots, idxs_cells, n_coords) + + # # reorder based on cells (except for coords_local) and then cat along + # # (time,coords,geoinfos,source) dimension and then split based on cells + # tokens_cells = [ + # ( + # list( + # torch.split( + # torch.cat( + # ( + # torch.full([len(idxs), 1], stream_id, dtype=torch.float32), + # times_enc_padded[idxs], + # coords_local[i], + # geoinfos_padded[idxs], + # source_padded[idxs], + # ), + # 1, + # ), + # idxs_lens, + # ) + # ) + # if idxs_lens[0] > 0 + # else [] + # ) + # for i, (idxs, idxs_lens) in enumerate(zip(idxs_cells, idxs_cells_lens, strict=True)) + # ] + + return tokens_cells + + +#################################################################################################### + + def tokenize_window_space( stream_id: float, coords: torch.tensor, @@ -291,7 +424,7 @@ def _coords_local( # int32 should be enough idxs_ords_lens = torch.tensor(idxs_ords_lens_l, dtype=torch.int32) # concat all indices - idxs_ords_c = torch.cat(idxs_ord) + idxs_ords_c = torch.cat([torch.tensor(i) for i in idxs_ord]) # Copy the rotation matrices for each healpix cell # num_points x 3 x 3 rots = torch.repeat_interleave(hpy_verts_rots, idxs_ords_lens, dim=0) From f367bb421e11200972505ad34a659e9c27a95ef7 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Wed, 12 Nov 2025 13:29:30 +0000 Subject: [PATCH 015/344] Minor modifications --- src/weathergen/model/diffusion.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index c51233f58..bb26ecced 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -17,10 +17,9 @@ class DiffusionForecastEngine(torch.nn.Module): def __init__( self, - cf, stage: str, forecast_engine: ForecastingEngine, - sigma_min: float = 0.002, + sigma_min: float = 0.002, # Adapt to GenCast? sigma_max: float = 80, sigma_data: float = 0.5, rho: float = 7, @@ -47,7 +46,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.inference(x=x) def forward_train(self, y) -> torch.Tensor: - # Determine noise level -- potentially move to "preprocessing" + # Determine noise level -- move to "preprocessing" noise = torch.randn(y.shape, device=y.device) sigma = (noise * self.p_std + self.p_mean).exp() n = torch.randn_like(y) * sigma @@ -63,14 +62,8 @@ def forward_train(self, y) -> torch.Tensor: x = self.preconditioner.precondition(x) net_out = self.net(c_in * x, c_noise) y_hat = c_skip * y + c_out * net_out # Eq. (7) - # return y_hat - """ - F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs) - assert F_x.dtype == dtype - D_x = c_skip * x + c_out * F_x.to(torch.float32) - return D_x - """ + return y_hat # Compute loss -- move this to a separate loss calculator weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 # Table 1 @@ -126,7 +119,7 @@ def inference( class Preconditioner: # Preconditioner, e.g., to concatenate previous frames to the input - def __init_(self): + def __init__(self): pass def precondition(self, x): From 1cc168c87ae060f611dcb1001b975fd64e7cb321 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Wed, 12 Nov 2025 13:34:22 +0000 Subject: [PATCH 016/344] Linter --- src/weathergen/model/diffusion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index bb26ecced..8379c2836 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -66,8 +66,8 @@ def forward_train(self, y) -> torch.Tensor: return y_hat # Compute loss -- move this to a separate loss calculator - weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 # Table 1 - loss = weight * ((y_hat - y) ** 2) + # weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 # Table 1 + # loss = weight * ((y_hat - y) ** 2) def inference( self, From 48934c25dd47edc889976890eb6ec050366d0cff Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Wed, 12 Nov 2025 16:03:01 +0000 Subject: [PATCH 017/344] Copyright attribution to EDM --- NOTICE | 14 ++++++++++++++ src/weathergen/model/diffusion.py | 7 +++++++ 2 files changed, 21 insertions(+) create mode 100644 NOTICE diff --git a/NOTICE b/NOTICE new file mode 100644 index 000000000..ddd243b23 --- /dev/null +++ b/NOTICE @@ -0,0 +1,14 @@ +======================================================================= +NVLABS/EDM (Elucidating the Design of Diffusion Models) + +This software incorporates code from the 'edm' repository. + +Original Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +The source code is available at: +https://github.com/NVlabs/edm + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 8379c2836..d7ef68ca7 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -7,6 +7,13 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +# ---------------------------------------------------------------------------- +# Third-Party Attribution: NVLABS/EDM (Elucidating the Design of Diffusion Models) +# This file incorporates code originally from the 'NVlabs/edm' repository. +# +# Original Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# ---------------------------------------------------------------------------- + import torch from weathergen.model.engines import ForecastingEngine From 51f437f593112712d92fcd7683255dd8281d75f7 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 13 Nov 2025 07:04:23 +0100 Subject: [PATCH 018/344] NOT WORKING: Finished src, target still to be done. --- src/weathergen/datasets/masking.py | 10 +- src/weathergen/datasets/tokenizer_masking.py | 66 ++-- src/weathergen/datasets/tokenizer_utils.py | 382 ++++++++++++++----- src/weathergen/datasets/utils.py | 245 +----------- 4 files changed, 353 insertions(+), 350 deletions(-) diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index cb6eb90e3..b8f56b023 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -54,10 +54,6 @@ def __init__(self, cf: Config): # number of healpix cells self.healpix_num_cells = 12 * (4**self.healpix_level_data) - # Initialize the mask, set to None initially, - # until it is generated in mask_source. - self.perm_sel: list[np.typing.NDArray] = None - # Per-batch strategy tracking self.same_strategy_per_batch = self.masking_strategy_config.get( "same_strategy_per_batch", False @@ -167,6 +163,10 @@ def mask_source_idxs( if self.current_strategy == "random": mask_tokens = self.rng.uniform(0, 1, num_tokens) < rate + elif self.current_strategy == "forecast": + mask_tokens = np.zeros( + num_tokens, + ) elif self.current_strategy == "healpix": # TODO: currently only for fixed level num_cells = len(idxs_cells_lens) @@ -179,8 +179,6 @@ def mask_source_idxs( else: assert False, f"Unsupported masking strategy: {self.current_strategy}" - self.perm_sel = mask_tokens - return (mask_tokens, mask_channels) def mask_source( diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 860dbc07a..b9426a9db 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -19,15 +19,14 @@ arc_alpha, encode_times_source, encode_times_target, + get_target_coords_local_ffast, tokenize_apply_mask, + tokenize_apply_mask_target, tokenize_space, tokenize_spacetime, tokenize_window_space, tokenize_window_spacetime, ) -from weathergen.datasets.utils import ( - get_target_coords_local_ffast, -) class TokenizerMasking(Tokenizer): @@ -42,6 +41,9 @@ def reset_rng(self, rng) -> None: self.masker.reset_rng(rng) self.rng = rng + self.mask_tokens = None + self.mask_channels = None + def batchify_source( self, stream_info: dict, @@ -78,24 +80,17 @@ def batchify_source( rdata, time_win, self.hpy_verts_rots_source[-1], - normalize_coords, encode_times_source, ) - # tokenized_data = [ - # torch.stack(c) if len(c) > 0 else torch.tensor([]) for c in tokenized_data - # ] - - # # Use the masker to get source tokens and the selection mask for the target - # source_tokens_cells = self.masker.mask_source( - # tokenized_data, rdata.coords, rdata.geoinfos, rdata.data - # ) - source_tokens_lens = torch.tensor([len(s) for s in source_tokens_cells], dtype=torch.int32) - if source_tokens_lens.sum() > 0: - source_centroids = self.compute_source_centroids(source_tokens_cells) - else: - source_centroids = torch.tensor([]) + # if source_tokens_lens.sum() > 0: + # source_centroids = self.compute_source_centroids(source_tokens_cells) + # else: + # TODO: remove completely? + source_centroids = [torch.tensor([])] + + self.mask_tokens, self.mask_channels = mask_tokens, mask_channels return (source_tokens_cells, source_tokens_lens, source_centroids) @@ -107,16 +102,43 @@ def batchify_target( time_win: tuple, ): token_size = stream_info["token_size"] - tokenize_spacetime = stream_info.get("tokenize_spacetime", False) max_num_targets = stream_info.get("max_num_targets", -1) - - target_tokens, target_coords = torch.tensor([]), torch.tensor([]) - target_tokens_lens = torch.zeros([self.num_healpix_cells_target], dtype=torch.int32) + stream_id = stream_info["stream_id"] # target is empty - if len(self.masker.perm_sel) == 0: + if len(self.mask_tokens) == 0: + target_tokens, target_coords = torch.tensor([]), torch.tensor([]) + target_tokens_lens = torch.zeros([self.num_healpix_cells_target], dtype=torch.int32) return (target_tokens, target_coords, torch.tensor([]), torch.tensor([])) + # create tokenization index + tok = tokenize_spacetime if stream_info.get("tokenize_spacetime", False) else tokenize_space + idxs_cells, idxs_cells_lens = tok(rdata, token_size, self.hl_source, pad_tokens=False) + + mask_tokens = ~self.mask_tokens + # mask_channels = ~self.mask_channels if self.mask_channels is not None + # else self.mask_channels + mask_channels = self.mask_channels + + coords_token_cells = tokenize_apply_mask_target( + self.hl_target, + idxs_cells, + idxs_cells_lens, + mask_tokens, + mask_channels, + stream_id, + rdata, + time_win, + self.hpy_verts_rots_target, + self.hpy_verts_local_target, + self.hpy_nctrs_target, + encode_times_target, + ) + + import code + + code.interact(local=locals()) + # identity function def id(arg): return arg diff --git a/src/weathergen/datasets/tokenizer_utils.py b/src/weathergen/datasets/tokenizer_utils.py index 354323875..9bb1ec3e9 100644 --- a/src/weathergen/datasets/tokenizer_utils.py +++ b/src/weathergen/datasets/tokenizer_utils.py @@ -1,5 +1,3 @@ -from collections.abc import Callable - import numpy as np import pandas as pd import torch @@ -8,12 +6,12 @@ from weathergen.common.io import IOReaderData from weathergen.datasets.utils import ( + locs_to_cell_coords_ctrs, + locs_to_ctr_coords, r3tos2, s2tor3, ) -CoordNormalizer = Callable[[torch.Tensor], torch.Tensor] - # on some clusters our numpy version is pinned to be 1.x.x where the np.argsort does not # the stable=True argument numpy_argsort_args = {"stable": True} if int(np.__version__.split(".")[0]) >= 2 else {} @@ -29,6 +27,13 @@ def arc_alpha(sin_alpha, cos_alpha): return t +def theta_phi_to_standard_coords(coords): + thetas = ((90.0 - coords[:, 0]) / 180.0) * np.pi + phis = ((coords[:, 1] + 180.0) / 360.0) * 2.0 * np.pi + + return thetas, phis + + def encode_times_source(times, time_win) -> torch.tensor: """Encode times in the format used for source @@ -103,8 +108,7 @@ def hpy_cell_splits(coords: torch.tensor, hl: int): phis : phis in rad posr3 : (thetas,phis) as position in R3 """ - thetas = ((90.0 - coords[:, 0]) / 180.0) * np.pi - phis = ((coords[:, 1] + 180.0) / 360.0) * 2.0 * np.pi + thetas, phis = theta_phi_to_standard_coords(coords) # healpix cells for all points hpy_idxs = ang2pix(2**hl, thetas, phis, nest=True) posr3 = s2tor3(thetas, phis) @@ -154,9 +158,11 @@ def hpy_splits( # helper variables to split according to cells # pad to token size *and* offset by +1 to account for the index 0 that is added for the padding + offset = 1 if pad_tokens else 0 + int32 = torch.int32 idxs_ord = [ torch.split( - torch.cat((torch.from_numpy(np.take(idxs, ts) + 1), torch.zeros(r, dtype=torch.int32))), + torch.cat((torch.from_numpy(np.take(idxs, ts) + offset), torch.zeros(r, dtype=int32))), token_size, ) for idxs, ts, r in zip(hpy_idxs_ord_split, thetas_sorted, rem, strict=True) @@ -224,80 +230,274 @@ def tokenize_apply_mask( rdata, time_win, hpy_verts_rots, - n_coords: CoordNormalizer, enc_time, ): + """ + Apply masking to the data. + + Conceptually, the data is a matrix with the rows corresponding to data points / tokens and + the cols the channels. Thereby mask_tokens acts on the rows, grouped according to the tokens as + specified in idxs_cells and mask_channels acts on the columns. + + """ + # convert to token level, forgetting about cells idxs_tokens = [i for t in idxs_cells for i in t] idxs_lens = [i for t in idxs_cells_lens for i in t] - # filter tokens using mask to obtain flat per data point index list - idxs_data = torch.cat([t for t, m in zip(idxs_tokens, mask_tokens, strict=True) if m]) - # filter list of token lens using mask and obtain flat list for splitting - idxs_data_lens = torch.tensor([t for t, m in zip(idxs_lens, mask_tokens, strict=True) if m]) - - # pad with zero at the begining; idxs_cells -> idxs_tokens -> idxs_data has been prepared so - # that the zero-index is used to add the padding to the tokens to ensure fixed size - times_enc = enc_time(rdata.datetimes, time_win) - datetimes_enc_padded = torch.cat([torch.zeros_like(times_enc[0]).unsqueeze(0), times_enc]) - geoinfos_padded = torch.cat([torch.zeros_like(rdata.geoinfos[0]).unsqueeze(0), rdata.geoinfos]) - coords_padded = torch.cat([torch.zeros_like(rdata.coords[0]).unsqueeze(0), rdata.coords]) - data_padded = torch.cat([torch.zeros_like(rdata.data[0]).unsqueeze(0), rdata.data]) - - # apply mask - datetimes = datetimes_enc_padded[idxs_data] - geoinfos = geoinfos_padded[idxs_data] - coords = coords_padded[idxs_data] - data = data_padded[idxs_data] - - # TODO, TODO, TODO: fix _coords_local - # _coords_local - coords_local = torch.cat((coords, torch.zeros_like(coords[:, 0]).unsqueeze(1)), 1) + # apply spatial masking on a per token level + if mask_tokens is not None: + # filter tokens using mask to obtain flat per data point index list + idxs_data = torch.cat([t for t, m in zip(idxs_tokens, mask_tokens, strict=True) if m]) + # filter list of token lens using mask and obtain flat list for splitting + idxs_data_lens = torch.tensor([t for t, m in zip(idxs_lens, mask_tokens, strict=True) if m]) + + # pad with zero at the begining of the conceptual 2D data tensor: + # idxs_cells -> idxs_tokens -> idxs_data has been prepared so + # that the zero-index is used to add the padding to the tokens to ensure fixed size + times_enc = enc_time(rdata.datetimes, time_win) + zeros_like = torch.zeros_like + datetimes_enc_padded = torch.cat([zeros_like(times_enc[0]).unsqueeze(0), times_enc]) + geoinfos_padded = torch.cat([zeros_like(rdata.geoinfos[0]).unsqueeze(0), rdata.geoinfos]) + coords_padded = torch.cat([zeros_like(rdata.coords[0]).unsqueeze(0), rdata.coords]) + data_padded = torch.cat([zeros_like(rdata.data[0]).unsqueeze(0), rdata.data]) + + # apply mask + datetimes = datetimes_enc_padded[idxs_data] + geoinfos = geoinfos_padded[idxs_data] + coords = coords_padded[idxs_data] + data = data_padded[idxs_data] + + if mask_channels is not None: + assert False, "to be implemented" + # data = data_padded[ : channel_mask] + + # local coords + num_tokens_per_cell = [len(idxs) for idxs in idxs_cells_lens] + mask_tokens_per_cell = torch.split(torch.from_numpy(mask_tokens), num_tokens_per_cell) + masked_points_per_cell = torch.tensor( + [ + torch.tensor([len(t) for t, m in zip(tt, mm, strict=False) if m]).sum() + for tt, mm in zip(idxs_cells, mask_tokens_per_cell, strict=False) + ] + ).to(dtype=torch.int32) + coords_local = get_source_coords_local(coords, hpy_verts_rots, masked_points_per_cell) # create tensor that contains all info tokens = torch.cat((datetimes, coords_local, geoinfos, data), 1) # split up tensor into tokens + # TODO: idxs_data_lens is currently only defined when mask_tokens is not None idxs_data_lens = idxs_data_lens.tolist() tokens_cells = torch.split(tokens, idxs_data_lens) - # # R^3 coords - # thetas = ((90.0 - coords[:, 0]) / 180.0) * np.pi - # phis = ((coords[:, 1] + 180.0) / 360.0) * 2.0 * np.pi - # posr3 = s2tor3(thetas, phis) - - # # convert to local coordinates - # # TODO: avoid that padded lists are rotated, which means potentially a lot of zeros - # coords_local = _coords_local(posr3, hpy_verts_rots, idxs_cells, n_coords) - - # # reorder based on cells (except for coords_local) and then cat along - # # (time,coords,geoinfos,source) dimension and then split based on cells - # tokens_cells = [ - # ( - # list( - # torch.split( - # torch.cat( - # ( - # torch.full([len(idxs), 1], stream_id, dtype=torch.float32), - # times_enc_padded[idxs], - # coords_local[i], - # geoinfos_padded[idxs], - # source_padded[idxs], - # ), - # 1, - # ), - # idxs_lens, - # ) - # ) - # if idxs_lens[0] > 0 - # else [] - # ) - # for i, (idxs, idxs_lens) in enumerate(zip(idxs_cells, idxs_cells_lens, strict=True)) - # ] - return tokens_cells +def tokenize_apply_mask_target( + hl, + idxs_cells, + idxs_cells_lens, + mask_tokens, + mask_channels, + stream_id, + rdata, + time_win, + hpy_verts_rots, + hpy_verts_local, + hpy_nctrs, + enc_time, +): + """ + Apply masking to the data. + + Conceptually, the data is a matrix with the rows corresponding to data points / tokens and + the cols the channels. Thereby mask_tokens acts on the rows, grouped according to the tokens as + specified in idxs_cells and mask_channels acts on the columns. + + """ + + # convert to token level, forgetting about cells + idxs_tokens = [i for t in idxs_cells for i in t] + idxs_lens = [i for t in idxs_cells_lens for i in t] + + # apply spatial masking on a per token level + if mask_tokens is not None: + # filter tokens using mask to obtain flat per data point index list + idxs_data = torch.cat([t for t, m in zip(idxs_tokens, mask_tokens, strict=True) if m]) + # filter list of token lens using mask and obtain flat list for splitting + idxs_data_lens = torch.tensor([t for t, m in zip(idxs_lens, mask_tokens, strict=True) if m]) + + # apply mask + datetimes = enc_time(rdata.datetimes[idxs_data], time_win) + geoinfos = rdata.geoinfos[idxs_data] + coords = rdata.coords[idxs_data] + data = rdata.data[idxs_data] + + if mask_channels is not None: + assert False, "to be implemented" + # data = data_padded[ : channel_mask] + + num_tokens_per_cell = [len(idxs) for idxs in idxs_cells_lens] + mask_tokens_per_cell = torch.split(torch.from_numpy(mask_tokens), num_tokens_per_cell) + masked_points_per_cell = torch.tensor( + [ + torch.tensor([len(t) for t, m in zip(tt, mm, strict=False) if m]).sum() + for tt, mm in zip(idxs_cells, mask_tokens_per_cell, strict=False) + ] + ).to(dtype=torch.int32) + + # compute encoding of target coordinates used in prediction network + if torch.tensor(idxs_lens).sum() > 0: + coords_local = get_target_coords_local( + hl, + masked_points_per_cell, + coords, + geoinfos, + datetimes, + hpy_verts_rots, + hpy_verts_local, + hpy_nctrs, + ) + coords_local.requires_grad = False + coords_local = list(coords_local.split(idxs_data_lens.tolist())) + else: + coords_local = torch.tensor([]) + + return coords_local + + +def get_source_coords_local( + coords: Tensor, + hpy_verts_rots: Tensor, + masked_points_per_cell, +) -> list[Tensor]: + """Compute simple local coordinates for a set of 3D positions on the unit sphere.""" + + # remove padding from coords + posr3 = s2tor3(*theta_phi_to_standard_coords(coords)) + posr3[0, 0] = 0.0 + posr3[0, 1] = 0.0 + posr3[0, 2] = 0.0 + + rots = torch.repeat_interleave(hpy_verts_rots, masked_points_per_cell, dim=0) + # BMM only works for b x n x m and b x m x 1 + # adding a dummy dimension to posr3 + vec_rot = torch.bmm(rots, posr3.unsqueeze(-1)).squeeze(-1) + vec_scaled = r3tos2(vec_rot).to(torch.float32) + + # TODO: vec_scaled are small -> should they be normalized/rescaled? + + return vec_scaled + + +def get_target_coords_local( + hlc, + masked_points_per_cell, + coords, + target_geoinfos, + target_times, + verts_rots, + verts_local, + nctrs, +): + """Generate local coordinates for target coords w.r.t healpix cell vertices and + and for healpix cell vertices themselves + """ + + # target_coords_lens = [len(t) for t in target_coords] + # tcs, target_coords = tcs_optimized(target_coords) + target_coords = s2tor3(*theta_phi_to_standard_coords(coords)) + tcs = torch.split(target_coords, masked_points_per_cell.tolist()) + + if target_coords.shape[0] == 0: + return torch.tensor([]) + # target_geoinfos = torch.cat(target_geoinfos) + # target_times = torch.cat(target_times) + + verts00_rots, verts10_rots, verts11_rots, verts01_rots, vertsmm_rots = verts_rots + + a = torch.zeros( + [ + *target_coords.shape[:-1], + 1 + target_geoinfos.shape[1] + target_times.shape[1] + 5 * (3 * 5) + 3 * 8, + ] + ) + # TODO: properly set stream_id, implicitly zero at the moment + geoinfo_offset = 1 + a[..., geoinfo_offset : geoinfo_offset + target_times.shape[1]] = target_times + geoinfo_offset += target_times.shape[1] + a[..., geoinfo_offset : geoinfo_offset + target_geoinfos.shape[1]] = target_geoinfos + geoinfo_offset += target_geoinfos.shape[1] + + ref = torch.tensor([1.0, 0.0, 0.0]) + + tcs_lens = torch.tensor([tt.shape[0] for tt in tcs], dtype=torch.int32) + tcs_lens_mask = tcs_lens > 0 + tcs_lens = tcs_lens[tcs_lens_mask] + + vls = torch.cat( + [ + vl.repeat([tt, 1, 1]) + for tt, vl in zip(tcs_lens, verts_local[tcs_lens_mask], strict=False) + ], + 0, + ) + vls = vls.transpose(0, 1) + + zi = 0 + a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - locs_to_cell_coords_ctrs( + verts00_rots, tcs + ) + + zi = 3 + a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + vls.shape[-1])] = vls[0] + + zi = 15 + a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - locs_to_cell_coords_ctrs( + verts10_rots, tcs + ) + + zi = 18 + a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + vls.shape[-1])] = vls[1] + + zi = 30 + a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - locs_to_cell_coords_ctrs( + verts11_rots, tcs + ) + + zi = 33 + a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + vls.shape[-1])] = vls[2] + + zi = 45 + a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - locs_to_cell_coords_ctrs( + verts01_rots, tcs + ) + + zi = 48 + a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + vls.shape[-1])] = vls[3] + + zi = 60 + a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - locs_to_cell_coords_ctrs( + vertsmm_rots, tcs + ) + + zi = 63 + a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + vls.shape[-1])] = vls[4] + + tcs_ctrs = torch.cat([ref - torch.cat(locs_to_ctr_coords(c, tcs)) for c in nctrs], -1) + zi = 75 + a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + (3 * 8))] = tcs_ctrs + # a = add_local_vert_coords_ctrs2( vertsmm_rots, nctrs, tcs, a, 99, geoinfo_offset) + + # remaining geoinfos (zenith angle etc) + # zi=99+3*8; + zi = 99 + a[..., (geoinfo_offset + zi) :] = target_coords[..., (geoinfo_offset + 2) :] + + return a + + #################################################################################################### @@ -311,7 +511,7 @@ def tokenize_window_space( token_size, hl, hpy_verts_rots, - n_coords: CoordNormalizer, + n_coords, enc_time, pad_tokens=True, local_coords=True, @@ -413,29 +613,29 @@ def tokenize_window_spacetime( return tokens_cells -def _coords_local( - posr3: Tensor, hpy_verts_rots: Tensor, idxs_ord: list[Tensor], n_coords: CoordNormalizer -) -> list[Tensor]: - """Compute simple local coordinates for a set of 3D positions on the unit sphere.""" - fp32 = torch.float32 - posr3 = torch.cat([torch.zeros_like(posr3[0]).unsqueeze(0), posr3]) # prepend zero - - idxs_ords_lens_l = [len(idxs) for idxs in idxs_ord] - # int32 should be enough - idxs_ords_lens = torch.tensor(idxs_ords_lens_l, dtype=torch.int32) - # concat all indices - idxs_ords_c = torch.cat([torch.tensor(i) for i in idxs_ord]) - # Copy the rotation matrices for each healpix cell - # num_points x 3 x 3 - rots = torch.repeat_interleave(hpy_verts_rots, idxs_ords_lens, dim=0) - # BMM only works for b x n x m and b x m x 1 - # adding a dummy dimension to posr3 - # numpoints x 3 x 1 - posr3_sel = posr3[idxs_ords_c].unsqueeze(-1) - vec_rot = torch.bmm(rots, posr3_sel) - vec_rot = vec_rot.squeeze(-1) - vec_scaled = n_coords(r3tos2(vec_rot).to(fp32)) - # split back to ragged list - # num_points x 2 - coords_local = torch.split(vec_scaled, idxs_ords_lens_l, dim=0) - return list(coords_local) +# def _coords_local( +# posr3: Tensor, hpy_verts_rots: Tensor, idxs_ord: list[Tensor], n_coords: CoordNormalizer +# ) -> list[Tensor]: +# """Compute simple local coordinates for a set of 3D positions on the unit sphere.""" +# fp32 = torch.float32 +# posr3 = torch.cat([torch.zeros_like(posr3[0]).unsqueeze(0), posr3]) # prepend zero + +# idxs_ords_lens_l = [len(idxs) for idxs in idxs_ord] +# # int32 should be enough +# idxs_ords_lens = torch.tensor(idxs_ords_lens_l, dtype=torch.int32) +# # concat all indices +# idxs_ords_c = torch.cat([torch.tensor(i) for i in idxs_ord]) +# # Copy the rotation matrices for each healpix cell +# # num_points x 3 x 3 +# rots = torch.repeat_interleave(hpy_verts_rots, idxs_ords_lens, dim=0) +# # BMM only works for b x n x m and b x m x 1 +# # adding a dummy dimension to posr3 +# # numpoints x 3 x 1 +# posr3_sel = posr3[idxs_ords_c].unsqueeze(-1) +# vec_rot = torch.bmm(rots, posr3_sel) +# vec_rot = vec_rot.squeeze(-1) +# vec_scaled = n_coords(r3tos2(vec_rot).to(fp32)) +# # split back to ragged list +# # num_points x 2 +# coords_local = torch.split(vec_scaled, idxs_ords_lens_l, dim=0) +# return list(coords_local) diff --git a/src/weathergen/datasets/utils.py b/src/weathergen/datasets/utils.py index b5d2279b8..194249e28 100644 --- a/src/weathergen/datasets/utils.py +++ b/src/weathergen/datasets/utils.py @@ -7,7 +7,6 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import warnings import astropy_healpix as hp import numpy as np @@ -268,230 +267,7 @@ def add_local_vert_coords_ctrs2(verts_local, tcs_lens, a, zi, geoinfo_offset): #################################################################################################### -# def add_local_vert_coords_ctrs3( ctrs, verts, tcs, a, zi, geoinfo_offset) : - -# ref = torch.tensor( [1., 0., 0.]) - -# local_locs = [ -# torch.matmul(R, s.transpose( -1, -2)).transpose( -2, -1) -# for i,(R,s) in enumerate(zip(healpix_centers_rots,locs)) if len(s)>0 -# ] -# aa = locs_to_cell_coords_ctrs( ctrs, verts.transpose(0,1)) -# aa = ref - torch.cat( [aaa.unsqueeze(0).repeat( [*tt.shape[:-1],1,1]) -# if len(tt)>0 else torch.tensor([]) -# for tt,aaa in zip(tcs,aa)] -# if tt>, 0 ) -# aa = aa.flatten(1,2) -# a[...,(geoinfo_offset+zi):(geoinfo_offset+zi+aa.shape[-1])] = aa -# return a - - -#################################################################################################### -def get_target_coords_local(hlc, target_coords, geoinfo_offset): - """Generate local coordinates for target coords w.r.t healpix cell vertices and - and for healpix cell vertices themselves - """ - - # target_coords_lens = [len(t) for t in target_coords] - tcs = [ - ( - s2tor3( - torch.deg2rad(90.0 - t[..., geoinfo_offset].to(torch.float64)), - torch.deg2rad(180.0 + t[..., geoinfo_offset + 1].to(torch.float64)), - ) - if len(t) > 0 - else torch.tensor([]) - ) - for t in target_coords - ] - target_coords = torch.cat(target_coords) - if target_coords.shape[0] == 0: - return torch.tensor([]) - - verts00 = healpix_verts(hlc, 0.0, 0.0) - verts10 = healpix_verts(hlc, 1.0, 0.0) - verts11 = healpix_verts(hlc, 1.0, 1.0) - verts01 = healpix_verts(hlc, 0.0, 1.0) - vertsmm = healpix_verts(hlc, 0.5, 0.5) - - a = torch.zeros( - [*target_coords.shape[:-1], (target_coords.shape[-1] - 2) + 5 * (3 * 5) + 3 * 8] - ) - a[..., :geoinfo_offset] = target_coords[..., :geoinfo_offset] - ref = torch.tensor([1.0, 0.0, 0.0]) - - zi = 0 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - torch.cat( - locs_to_cell_coords(hlc, tcs, 0.0, 0.0) - ) - a = add_local_vert_coords(hlc, a, verts10, tcs, 3, 0.0, 0.0, geoinfo_offset) - a = add_local_vert_coords(hlc, a, verts11, tcs, 6, 0.0, 0.0, geoinfo_offset) - a = add_local_vert_coords(hlc, a, verts01, tcs, 9, 0.0, 0.0, geoinfo_offset) - a = add_local_vert_coords(hlc, a, vertsmm, tcs, 12, 0.0, 0.0, geoinfo_offset) - - zi = 15 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - torch.cat( - locs_to_cell_coords(hlc, tcs, 1.0, 0.0) - ) - a = add_local_vert_coords(hlc, a, verts00, tcs, 18, 1.0, 0.0, geoinfo_offset) - a = add_local_vert_coords(hlc, a, verts11, tcs, 21, 1.0, 0.0, geoinfo_offset) - a = add_local_vert_coords(hlc, a, verts01, tcs, 24, 1.0, 0.0, geoinfo_offset) - a = add_local_vert_coords(hlc, a, vertsmm, tcs, 27, 1.0, 0.0, geoinfo_offset) - - zi = 30 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - torch.cat( - locs_to_cell_coords(hlc, tcs, 1.0, 1.0) - ) - a = add_local_vert_coords(hlc, a, verts00, tcs, 33, 1.0, 1.0, geoinfo_offset) - a = add_local_vert_coords(hlc, a, verts10, tcs, 36, 1.0, 1.0, geoinfo_offset) - a = add_local_vert_coords(hlc, a, verts01, tcs, 39, 1.0, 1.0, geoinfo_offset) - a = add_local_vert_coords(hlc, a, vertsmm, tcs, 42, 1.0, 1.0, geoinfo_offset) - - zi = 45 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - torch.cat( - locs_to_cell_coords(hlc, tcs, 0.0, 1.0) - ) - a = add_local_vert_coords(hlc, a, verts00, tcs, 48, 0.0, 1.0, geoinfo_offset) - a = add_local_vert_coords(hlc, a, verts11, tcs, 51, 0.0, 1.0, geoinfo_offset) - a = add_local_vert_coords(hlc, a, verts10, tcs, 54, 0.0, 1.0, geoinfo_offset) - # a = add_local_vert_coords( hlc, a, verts10, tcs, 51, 0.0, 1.0, geoinfo_offset) - # a = add_local_vert_coords( hlc, a, verts01, tcs, 54, 0.0, 1.0, geoinfo_offset) - a = add_local_vert_coords(hlc, a, vertsmm, tcs, 57, 0.0, 1.0, geoinfo_offset) - - zi = 60 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - torch.cat( - locs_to_cell_coords(hlc, tcs, 0.5, 0.5) - ) - a = add_local_vert_coords(hlc, a, verts00, tcs, 63, 0.5, 0.5, geoinfo_offset) - a = add_local_vert_coords(hlc, a, verts10, tcs, 66, 0.5, 0.5, geoinfo_offset) - a = add_local_vert_coords(hlc, a, verts11, tcs, 69, 0.5, 0.5, geoinfo_offset) - a = add_local_vert_coords(hlc, a, verts01, tcs, 72, 0.5, 0.5, geoinfo_offset) - - # add centroids to neighboring cells wrt to cell center - num_healpix_cells = 12 * 4**hlc - with warnings.catch_warnings(action="ignore"): - temp = hp.neighbours(np.arange(num_healpix_cells), 2**hlc, order="nested").transpose() - # fix missing nbors with references to self - for i, row in enumerate(temp): - temp[i][row == -1] = i - # coords of centers of all centers - lons, lats = hp.healpix_to_lonlat( - np.arange(0, num_healpix_cells), 2**hlc, dx=0.5, dy=0.5, order="nested" - ) - ctrs = s2tor3(torch.from_numpy(np.pi / 2.0 - lats.value), torch.from_numpy(lons.value)) - ctrs = ctrs[temp.flatten()].reshape((num_healpix_cells, 8, 3)).transpose(1, 0) - # local coords with respect to all neighboring centers - tcs_ctrs = torch.cat([ref - torch.cat(locs_to_ctr_coords(c, tcs)) for c in ctrs], -1) - zi = 75 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + (3 * 8))] = tcs_ctrs - - # remaining geoinfos (zenith angle etc) - zi = 99 - a[..., (geoinfo_offset + zi) :] = target_coords[..., (geoinfo_offset + 2) :] - - return a - - -#################################################################################################### -# TODO: remove this function, it is dead code that will fail immediately -def get_target_coords_local_fast(hlc, target_coords, geoinfo_offset): - """Generate local coordinates for target coords w.r.t healpix cell vertices and - and for healpix cell vertices themselves - """ - - # target_coords_lens = [len(t) for t in target_coords] - tcs = [ - ( - s2tor3( - torch.deg2rad(90.0 - t[..., geoinfo_offset].to(torch.float64)), - torch.deg2rad(180.0 + t[..., geoinfo_offset + 1].to(torch.float64)), - ) - if len(t) > 0 - else torch.tensor([]) - ) - for t in target_coords - ] - target_coords = torch.cat(target_coords) - if target_coords.shape[0] == 0: - return torch.tensor([]) - - verts00, verts00_rots = healpix_verts_rots(hlc, 0.0, 0.0) - verts10, verts10_rots = healpix_verts_rots(hlc, 1.0, 0.0) - verts11, verts11_rots = healpix_verts_rots(hlc, 1.0, 1.0) - verts01, verts01_rots = healpix_verts_rots(hlc, 0.0, 1.0) - vertsmm, vertsmm_rots = healpix_verts_rots(hlc, 0.5, 0.5) - - a = torch.zeros( - [*target_coords.shape[:-1], (target_coords.shape[-1] - 2) + 5 * (3 * 5) + 3 * 8] - ) - # a = torch.zeros( [*target_coords.shape[:-1], - # (target_coords.shape[-1]-2) + 5*(3*5) + 3*8]) - # a = torch.zeros( [*target_coords.shape[:-1], 148]) - # #(target_coords.shape[-1]-2) + 5*(3*5) + 3*8]) - a[..., :geoinfo_offset] = target_coords[..., :geoinfo_offset] - ref = torch.tensor([1.0, 0.0, 0.0]) - - zi = 0 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - torch.cat( - locs_to_cell_coords_ctrs(verts00_rots, tcs) - ) - verts = torch.stack([verts10, verts11, verts01, vertsmm]) - a = add_local_vert_coords_ctrs2(verts00_rots, verts, tcs, a, 3, geoinfo_offset) - - zi = 15 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - torch.cat( - locs_to_cell_coords_ctrs(verts10_rots, tcs) - ) - verts = torch.stack([verts00, verts11, verts01, vertsmm]) - a = add_local_vert_coords_ctrs2(verts10_rots, verts, tcs, a, 18, geoinfo_offset) - - zi = 30 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - torch.cat( - locs_to_cell_coords_ctrs(verts11_rots, tcs) - ) - verts = torch.stack([verts00, verts10, verts01, vertsmm]) - a = add_local_vert_coords_ctrs2(verts11_rots, verts, tcs, a, 33, geoinfo_offset) - - zi = 45 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - torch.cat( - locs_to_cell_coords_ctrs(verts01_rots, tcs) - ) - verts = torch.stack([verts00, verts11, verts10, vertsmm]) - a = add_local_vert_coords_ctrs2(verts01_rots, verts, tcs, a, 48, geoinfo_offset) - - zi = 60 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - torch.cat( - locs_to_cell_coords_ctrs(vertsmm_rots, tcs) - ) - verts = torch.stack([verts00, verts10, verts11, verts01]) - a = add_local_vert_coords_ctrs2(vertsmm_rots, verts, tcs, a, 63, geoinfo_offset) - - # add local coords wrt to center of neighboring cells - # (since the neighbors are used in the prediction) - num_healpix_cells = 12 * 4**hlc - with warnings.catch_warnings(action="ignore"): - temp = hp.neighbours(np.arange(num_healpix_cells), 2**hlc, order="nested").transpose() - # fix missing nbors with references to self - for i, row in enumerate(temp): - temp[i][row == -1] = i - nctrs = vertsmm[temp.flatten()].reshape((num_healpix_cells, 8, 3)).transpose(1, 0) - # local coords with respect to all neighboring centers - tcs_ctrs = torch.cat([ref - torch.cat(locs_to_ctr_coords(c, tcs)) for c in nctrs], -1) - zi = 75 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + (3 * 8))] = tcs_ctrs - # a = add_local_vert_coords_ctrs2( vertsmm_rots, nctrs, tcs, a, 99, geoinfo_offset) - - # remaining geoinfos (zenith angle etc) - # zi=99+3*8; - zi = 99 - # assert target_coords.shape[-1] + zi < a.shape[-1] - a[..., (geoinfo_offset + zi) :] = target_coords[..., (geoinfo_offset + 2) :] - - return a - - -#################################################################################################### -def tcs_optimized(target_coords: list[torch.Tensor]) -> tuple[list[torch.Tensor], torch.Tensor]: +def tcs_optimized(stacked_coords: torch.Tensor) -> tuple[list[torch.Tensor], torch.Tensor]: """ Args: target_coords: List of 2D coordinate tensors, each with shape [N, 2] @@ -501,9 +277,6 @@ def tcs_optimized(target_coords: list[torch.Tensor]) -> tuple[list[torch.Tensor] concatenated_coords: All original coords concatenated """ - # Concatenate all tensors - stacked_coords = torch.cat(target_coords, dim=0) # [total_points, 2] - # Single vectorized coordinate transformation theta_all = torch.deg2rad(90.0 - stacked_coords[..., 0]) phi_all = torch.deg2rad(180.0 + stacked_coords[..., 1]) @@ -514,19 +287,29 @@ def tcs_optimized(target_coords: list[torch.Tensor]) -> tuple[list[torch.Tensor] # Split back to original structure sizes = [t.shape[0] for t in target_coords] # Get original tensor sizes tcs = list(torch.split(transformed_all, sizes, dim=0)) # Split back to list - return tcs, stacked_coords + + return stacked_coords #################################################################################################### def get_target_coords_local_ffast( - hlc, target_coords, target_geoinfos, target_times, verts_rots, verts_local, nctrs + hlc, + masked_points_per_cell, + coords, + target_geoinfos, + target_times, + verts_rots, + verts_local, + nctrs, ): """Generate local coordinates for target coords w.r.t healpix cell vertices and and for healpix cell vertices themselves """ # target_coords_lens = [len(t) for t in target_coords] - tcs, target_coords = tcs_optimized(target_coords) + # tcs, target_coords = tcs_optimized(target_coords) + target_coords = s2tor3(*theta_phi_to_standard_coords(coords)) + tcs = torch.split(masked_points_per_cell) if target_coords.shape[0] == 0: return torch.tensor([]) From 6046694e0776ff344a69b4ebccf300e859934a5a Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Thu, 13 Nov 2025 17:03:42 +0000 Subject: [PATCH 019/344] Adapt diffusion model to expected data structure --- src/weathergen/model/diffusion.py | 46 ++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index d7ef68ca7..503327afe 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -15,16 +15,38 @@ # ---------------------------------------------------------------------------- import torch +from dataclass import dataclass from weathergen.model.engines import ForecastingEngine +@dataclass +class BatchData: + """ + Mock function for the data that will be provided to the diffusion model. Will change. + """ + + model_samples: dict + target_samples: dict + + def _get_sample_len(self): + return len(list(self.model_samples.keys())) + + def _get_input_data(self, t: int): + return self.model_samples[t]["data"] + + def _get_target_data(self, t: int): + return self.target_samples[t]["data"] + + def _get_target_metadata(self, t: int): + return self.target_samples[t]["metadata"] + + class DiffusionForecastEngine(torch.nn.Module): # Adopted from https://github.com/NVlabs/edm/blob/main/training/loss.py#L72 def __init__( self, - stage: str, forecast_engine: ForecastingEngine, sigma_min: float = 0.002, # Adapt to GenCast? sigma_max: float = 80, @@ -34,7 +56,6 @@ def __init__( p_std: float = 1.2, ): super().__init__() - self.stage = stage self.net = forecast_engine self.preconditioner = Preconditioner() @@ -46,16 +67,15 @@ def __init__( self.p_mean = p_mean self.p_std = p_std - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.stage == "train": - return self.forward_train(y=x) - else: - return self.inference(x=x) + def forward(self, data: BatchData) -> torch.Tensor: + # Retrieve conditionings [0:-1], target [-1], and noise from data object + cond = [data._get_input_data[t] for t in range(data._get_sample_len() - 1)] + y = data._get_target_data[-1] + eta = data._get_target_metadata[-1] - def forward_train(self, y) -> torch.Tensor: - # Determine noise level -- move to "preprocessing" - noise = torch.randn(y.shape, device=y.device) - sigma = (noise * self.p_std + self.p_mean).exp() + # Compute sigma (noise level) from eta + #noise = torch.randn(y.shape, device=y.device) + sigma = (eta * self.p_std + self.p_mean).exp() n = torch.randn_like(y) * sigma # Compute conditionings @@ -66,7 +86,7 @@ def forward_train(self, y) -> torch.Tensor: # Add noise, precondition input, and feed through network x = y + n - x = self.preconditioner.precondition(x) + x = self.preconditioner.precondition(x, cond) net_out = self.net(c_in * x, c_noise) y_hat = c_skip * y + c_out * net_out # Eq. (7) @@ -129,5 +149,5 @@ class Preconditioner: def __init__(self): pass - def precondition(self, x): + def precondition(self, x, c): return x From f66c9fa1eccd7063487e92f96e8cf21b259dbcff Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Thu, 13 Nov 2025 17:13:45 +0000 Subject: [PATCH 020/344] Corrected data retrieval to only access model_samples and not target_samples --- src/weathergen/model/diffusion.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 503327afe..21f5c9b9c 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -34,6 +34,9 @@ def _get_sample_len(self): def _get_input_data(self, t: int): return self.model_samples[t]["data"] + + def _get_input_metadata(self, t: int): + return self.model_samples[t]["metadata"] def _get_target_data(self, t: int): return self.target_samples[t]["data"] @@ -68,10 +71,11 @@ def __init__( self.p_std = p_std def forward(self, data: BatchData) -> torch.Tensor: - # Retrieve conditionings [0:-1], target [-1], and noise from data object - cond = [data._get_input_data[t] for t in range(data._get_sample_len() - 1)] - y = data._get_target_data[-1] - eta = data._get_target_metadata[-1] + # Retrieve conditionings [0:-1], target [-1], and noise from data object. + # The data retrieval ignores batch and stream dimension for now (has to be adapted). + cond = [data._get_input_data(t) for t in range(data._get_sample_len() - 1)] + y = data._get_input_data(-1) + eta = data._get_input_metadata(-1) # Compute sigma (noise level) from eta #noise = torch.randn(y.shape, device=y.device) From 7e48c39e009e9015b4759d2c1263b4894b8ace60 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Thu, 13 Nov 2025 17:14:57 +0000 Subject: [PATCH 021/344] Minor correction --- src/weathergen/model/diffusion.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 21f5c9b9c..3082f15aa 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -29,19 +29,19 @@ class BatchData: model_samples: dict target_samples: dict - def _get_sample_len(self): + def get_sample_len(self): return len(list(self.model_samples.keys())) - def _get_input_data(self, t: int): + def get_input_data(self, t: int): return self.model_samples[t]["data"] - def _get_input_metadata(self, t: int): + def get_input_metadata(self, t: int): return self.model_samples[t]["metadata"] - def _get_target_data(self, t: int): + def get_target_data(self, t: int): return self.target_samples[t]["data"] - def _get_target_metadata(self, t: int): + def get_target_metadata(self, t: int): return self.target_samples[t]["metadata"] @@ -73,9 +73,9 @@ def __init__( def forward(self, data: BatchData) -> torch.Tensor: # Retrieve conditionings [0:-1], target [-1], and noise from data object. # The data retrieval ignores batch and stream dimension for now (has to be adapted). - cond = [data._get_input_data(t) for t in range(data._get_sample_len() - 1)] - y = data._get_input_data(-1) - eta = data._get_input_metadata(-1) + cond = [data.get_input_data(t) for t in range(data.get_sample_len() - 1)] + y = data.get_input_data(-1) + eta = data.get_input_metadata(-1) # Compute sigma (noise level) from eta #noise = torch.randn(y.shape, device=y.device) From e4a9cc0017506da320cf1a50801bfc8b09b3d0a2 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 13 Nov 2025 18:58:28 +0100 Subject: [PATCH 022/344] Masking target is working in principle but errors when feeding data to the model. --- .../datasets/multi_stream_data_sampler.py | 2 +- src/weathergen/datasets/stream_data.py | 8 +- src/weathergen/datasets/tokenizer_masking.py | 228 +++++++++--------- src/weathergen/datasets/tokenizer_utils.py | 14 +- 4 files changed, 127 insertions(+), 125 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index d1f67ce3e..16deddcfc 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -419,7 +419,7 @@ def __iter__(self): stream_data.target_is_spoof = True # preprocess data for model input - (tt_cells, tc, tt_c, tt_t) = self.tokenizer.batchify_target( + (tt_cells, tt_t, tt_c, tc) = self.tokenizer.batchify_target( stream_info, self.sampling_rate_target, readerdata_to_torch(rdata), diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index 450d5e96d..33c2cbf10 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -161,7 +161,7 @@ def add_source( self.source_raw = ss_raw self.source_tokens_lens = ss_lens - self.source_tokens_cells = torch.cat(ss_cells) + self.source_tokens_cells = torch.stack(ss_cells) self.source_centroids = torch.cat(ss_centroids) idx = torch.isnan(self.source_tokens_cells) @@ -199,10 +199,10 @@ def add_target( None """ - self.target_tokens[fstep] = torch.cat(targets) + self.target_tokens[fstep] = targets self.target_coords[fstep] = torch.cat(target_coords) - self.target_times_raw[fstep] = np.concatenate(times_raw) - self.target_coords_raw[fstep] = torch.cat(target_coords_raw) + self.target_times_raw[fstep] = times_raw + self.target_coords_raw[fstep] = target_coords_raw tc = target_coords self.target_coords_lens[fstep] = torch.tensor( diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index b9426a9db..a600e4590 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -7,25 +7,20 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -from functools import partial -import numpy as np import torch from weathergen.common.io import IOReaderData from weathergen.datasets.masking import Masker from weathergen.datasets.tokenizer import Tokenizer from weathergen.datasets.tokenizer_utils import ( - arc_alpha, encode_times_source, encode_times_target, - get_target_coords_local_ffast, + # get_target_coords_local_ffast, tokenize_apply_mask, tokenize_apply_mask_target, tokenize_space, tokenize_spacetime, - tokenize_window_space, - tokenize_window_spacetime, ) @@ -116,11 +111,11 @@ def batchify_target( idxs_cells, idxs_cells_lens = tok(rdata, token_size, self.hl_source, pad_tokens=False) mask_tokens = ~self.mask_tokens - # mask_channels = ~self.mask_channels if self.mask_channels is not None + # mask_channels = ~self.mask_channels if self.mask_channels is not None # else self.mask_channels mask_channels = self.mask_channels - coords_token_cells = tokenize_apply_mask_target( + data, datetimes, coords, tokens_coords_local = tokenize_apply_mask_target( self.hl_target, idxs_cells, idxs_cells_lens, @@ -135,112 +130,117 @@ def batchify_target( encode_times_target, ) - import code - - code.interact(local=locals()) - - # identity function - def id(arg): - return arg - - # set tokenization function, no normalization of coords - tokenize_window = partial( - tokenize_window_spacetime if tokenize_spacetime else tokenize_window_space, - time_win=time_win, - token_size=token_size, - hl=self.hl_source, - hpy_verts_rots=self.hpy_verts_rots_source[-1], - n_coords=id, - enc_time=encode_times_target, - pad_tokens=False, - local_coords=False, - ) - - # tokenize - target_tokens_cells = tokenize_window( - 0, - rdata.coords, - rdata.geoinfos, - rdata.data, - rdata.datetimes, - ) - - target_tokens = self.masker.mask_target( - target_tokens_cells, rdata.coords, rdata.geoinfos, rdata.data - ) - - target_tokens_lens = [len(t) for t in target_tokens] - total_target = sum(target_tokens_lens) - - # sampling the number of targets according to per-stream sampling_rate_target - # otherwise take global sampling_rate_target from config - sampling_rate_target = stream_info.get("sampling_rate_target", sampling_rate_target) - - samples = (torch.empty(total_target).uniform_() < sampling_rate_target).split( - target_tokens_lens - ) - target_tokens = [ - (tokens[samples]) for tokens, samples in zip(target_tokens, samples, strict=False) - ] - target_tokens_lens = [len(t) for t in target_tokens] - - if torch.tensor(target_tokens_lens).sum() == 0: - return (torch.tensor([]), torch.tensor([]), torch.tensor([]), torch.tensor([])) - - tt_lin = torch.cat(target_tokens) - tt_lens = target_tokens_lens - - if max_num_targets > 0: - target_tokens = self.sample_tensors_uniform_vectorized( - target_tokens, torch.tensor(tt_lens), max_num_targets - ) - - tt_lin = torch.cat(target_tokens) - target_tokens_lens = [len(t) for t in target_tokens] - tt_lens = target_tokens_lens - - # TODO: can we avoid setting the offsets here manually? - # TODO: ideally we would not have recover it; but using tokenize_window seems necessary for - # consistency -> split tokenize_window in two parts with the cat only happening in the - # second - offset = 6 - # offset of 1 : stream_id - target_times = torch.split(tt_lin[..., 1:offset], tt_lens) - target_coords = torch.split(tt_lin[..., offset : offset + rdata.coords.shape[-1]], tt_lens) - offset += rdata.coords.shape[-1] - target_geoinfos = torch.split( - tt_lin[..., offset : offset + rdata.geoinfos.shape[-1]], tt_lens - ) - offset += rdata.geoinfos.shape[-1] - target_tokens = torch.split(tt_lin[..., offset:], tt_lens) - - offset = 6 - target_coords_raw = torch.split( - tt_lin[:, offset : offset + rdata.coords.shape[-1]], tt_lens - ) - # recover absolute time from relatives in encoded ones - # TODO: avoid recover; see TODO above - deltas_sec = ( - arc_alpha(tt_lin[..., 1] - 0.5, tt_lin[..., 2] - 0.5) / (2.0 * np.pi) * (12 * 3600) - ) - deltas_sec = deltas_sec.numpy().astype("timedelta64[s]") - target_times_raw = np.split(time_win[0] + deltas_sec, np.cumsum(tt_lens)[:-1]) - - # compute encoding of target coordinates used in prediction network - if torch.tensor(tt_lens).sum() > 0: - target_coords = get_target_coords_local_ffast( - self.hl_target, - target_coords, - target_geoinfos, - target_times, - self.hpy_verts_rots_target, - self.hpy_verts_local_target, - self.hpy_nctrs_target, - ) - target_coords.requires_grad = False - target_coords = list(target_coords.split(tt_lens)) - - return (target_tokens, target_coords, target_coords_raw, target_times_raw) + # # target_tokens, target_coords, target_coords_raw, target_times_raw) + + # import code + + # code.interact(local=locals()) + + # # identity function + # def id(arg): + # return arg + + # # set tokenization function, no normalization of coords + # tokenize_window = partial( + # tokenize_window_spacetime if tokenize_spacetime else tokenize_window_space, + # time_win=time_win, + # token_size=token_size, + # hl=self.hl_source, + # hpy_verts_rots=self.hpy_verts_rots_source[-1], + # n_coords=id, + # enc_time=encode_times_target, + # pad_tokens=False, + # local_coords=False, + # ) + + # # tokenize + # target_tokens_cells = tokenize_window( + # 0, + # rdata.coords, + # rdata.geoinfos, + # rdata.data, + # rdata.datetimes, + # ) + + # target_tokens = self.masker.mask_target( + # target_tokens_cells, rdata.coords, rdata.geoinfos, rdata.data + # ) + + # target_tokens_lens = [len(t) for t in target_tokens] + # total_target = sum(target_tokens_lens) + + # # sampling the number of targets according to per-stream sampling_rate_target + # # otherwise take global sampling_rate_target from config + # sampling_rate_target = stream_info.get("sampling_rate_target", sampling_rate_target) + + # samples = (torch.empty(total_target).uniform_() < sampling_rate_target).split( + # target_tokens_lens + # ) + # target_tokens = [ + # (tokens[samples]) for tokens, samples in zip(target_tokens, samples, strict=False) + # ] + # target_tokens_lens = [len(t) for t in target_tokens] + + # if torch.tensor(target_tokens_lens).sum() == 0: + # return (torch.tensor([]), torch.tensor([]), torch.tensor([]), torch.tensor([])) + + # tt_lin = torch.cat(target_tokens) + # tt_lens = target_tokens_lens + + # if max_num_targets > 0: + # target_tokens = self.sample_tensors_uniform_vectorized( + # target_tokens, torch.tensor(tt_lens), max_num_targets + # ) + + # tt_lin = torch.cat(target_tokens) + # target_tokens_lens = [len(t) for t in target_tokens] + # tt_lens = target_tokens_lens + + # # TODO: can we avoid setting the offsets here manually? + # # TODO: ideally we would not have recover it; but using tokenize_window seems necessary for + # # consistency -> split tokenize_window in two parts with the cat only happening in the + # # second + # offset = 6 + # # offset of 1 : stream_id + # target_times = torch.split(tt_lin[..., 1:offset], tt_lens) + # target_coords = torch.split(tt_lin[..., offset : offset + rdata.coords.shape[-1]], tt_lens) + # offset += rdata.coords.shape[-1] + # target_geoinfos = torch.split( + # tt_lin[..., offset : offset + rdata.geoinfos.shape[-1]], tt_lens + # ) + # offset += rdata.geoinfos.shape[-1] + # target_tokens = torch.split(tt_lin[..., offset:], tt_lens) + + # offset = 6 + # target_coords_raw = torch.split( + # tt_lin[:, offset : offset + rdata.coords.shape[-1]], tt_lens + # ) + # # recover absolute time from relatives in encoded ones + # # TODO: avoid recover; see TODO above + # deltas_sec = ( + # arc_alpha(tt_lin[..., 1] - 0.5, tt_lin[..., 2] - 0.5) / (2.0 * np.pi) * (12 * 3600) + # ) + # deltas_sec = deltas_sec.numpy().astype("timedelta64[s]") + # target_times_raw = np.split(time_win[0] + deltas_sec, np.cumsum(tt_lens)[:-1]) + + # # compute encoding of target coordinates used in prediction network + # if torch.tensor(tt_lens).sum() > 0: + # target_coords = get_target_coords_local_ffast( + # self.hl_target, + # target_coords, + # target_geoinfos, + # target_times, + # self.hpy_verts_rots_target, + # self.hpy_verts_local_target, + # self.hpy_nctrs_target, + # ) + # target_coords.requires_grad = False + # target_coords = list(target_coords.split(tt_lens)) + + # return (target_tokens, target_coords, target_coords_raw, ) + # data, tokens_coords_local, datetimes + # # target_tokens, target_coords, target_coords_raw, target_times_raw) + return (data, datetimes, coords, tokens_coords_local) def sample_tensors_uniform_vectorized( self, tensor_list: list, lengths: list, max_total_points: int diff --git a/src/weathergen/datasets/tokenizer_utils.py b/src/weathergen/datasets/tokenizer_utils.py index 9bb1ec3e9..cca53bf65 100644 --- a/src/weathergen/datasets/tokenizer_utils.py +++ b/src/weathergen/datasets/tokenizer_utils.py @@ -284,7 +284,8 @@ def tokenize_apply_mask( coords_local = get_source_coords_local(coords, hpy_verts_rots, masked_points_per_cell) # create tensor that contains all info - tokens = torch.cat((datetimes, coords_local, geoinfos, data), 1) + stream_ids = torch.full([len(datetimes), 1], stream_id, dtype=torch.float32) + tokens = torch.cat((stream_ids, datetimes, coords_local, geoinfos, data), 1) # split up tensor into tokens # TODO: idxs_data_lens is currently only defined when mask_tokens is not None @@ -329,7 +330,8 @@ def tokenize_apply_mask_target( idxs_data_lens = torch.tensor([t for t, m in zip(idxs_lens, mask_tokens, strict=True) if m]) # apply mask - datetimes = enc_time(rdata.datetimes[idxs_data], time_win) + datetimes = rdata.datetimes[idxs_data] + datetimes_enc = enc_time(datetimes, time_win) geoinfos = rdata.geoinfos[idxs_data] coords = rdata.coords[idxs_data] data = rdata.data[idxs_data] @@ -354,17 +356,17 @@ def tokenize_apply_mask_target( masked_points_per_cell, coords, geoinfos, - datetimes, + datetimes_enc, hpy_verts_rots, hpy_verts_local, hpy_nctrs, ) coords_local.requires_grad = False - coords_local = list(coords_local.split(idxs_data_lens.tolist())) + tokens_coords_local = list(coords_local.split(idxs_data_lens.tolist())) else: - coords_local = torch.tensor([]) + tokens_coords_local = torch.tensor([]) - return coords_local + return data, datetimes, coords, tokens_coords_local def get_source_coords_local( From a5814052a0d4e286080bd4b41b2766ef4a41aed5 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 13 Nov 2025 23:17:29 +0100 Subject: [PATCH 023/344] Working version for ERA5, NPP-ATMS. Problems with SYNOP with empty cell handling --- .../datasets/multi_stream_data_sampler.py | 4 +- src/weathergen/datasets/stream_data.py | 18 +- src/weathergen/datasets/tokenizer_masking.py | 125 +------------ src/weathergen/datasets/tokenizer_utils.py | 165 ++---------------- src/weathergen/datasets/utils.py | 133 -------------- 5 files changed, 26 insertions(+), 419 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 16deddcfc..63a0e5d63 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -419,14 +419,14 @@ def __iter__(self): stream_data.target_is_spoof = True # preprocess data for model input - (tt_cells, tt_t, tt_c, tc) = self.tokenizer.batchify_target( + (tt_cells, tt_t, tt_c, tc, tc_l) = self.tokenizer.batchify_target( stream_info, self.sampling_rate_target, readerdata_to_torch(rdata), (time_win_target.start, time_win_target.end), ) - stream_data.add_target(fstep, tt_cells, tc, tt_c, tt_t) + stream_data.add_target(fstep, tt_cells, tc, tc_l, tt_c, tt_t) # merge inputs for sources and targets for current stream streams_data += [stream_data] diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index 33c2cbf10..55d79c34c 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -90,7 +90,6 @@ def to_device(self, device: str) -> None: self.target_coords = [t.to(device, non_blocking=True) for t in self.target_coords] self.target_tokens = [t.to(device, non_blocking=True) for t in self.target_tokens] - self.target_tokens_lens = [t.to(device, non_blocking=True) for t in self.target_tokens_lens] self.source_idxs_embed = self.source_idxs_embed.to(device, non_blocking=True) self.source_idxs_embed_pe = self.source_idxs_embed_pe.to(device, non_blocking=True) @@ -131,7 +130,6 @@ def add_empty_target(self, fstep: int) -> None: """ self.target_tokens[fstep] += [torch.tensor([], dtype=torch.int32)] - self.target_tokens_lens[fstep] += [torch.zeros([self.healpix_cells], dtype=torch.int32)] self.target_coords[fstep] += [torch.zeros((0, 105)) for _ in range(self.healpix_cells)] self.target_coords_lens[fstep] += [torch.zeros([self.healpix_cells], dtype=torch.int32)] self.target_coords_raw[fstep] += [torch.tensor([]) for _ in range(self.healpix_cells)] @@ -172,6 +170,7 @@ def add_target( fstep: int, targets: list, target_coords: torch.tensor, + target_coords_per_cell: torch.tensor, target_coords_raw: torch.tensor, times_raw: torch.tensor, ) -> None: @@ -200,20 +199,11 @@ def add_target( """ self.target_tokens[fstep] = targets - self.target_coords[fstep] = torch.cat(target_coords) + self.target_coords[fstep] = target_coords + self.target_coords_lens[fstep] = target_coords_per_cell self.target_times_raw[fstep] = times_raw self.target_coords_raw[fstep] = target_coords_raw - tc = target_coords - self.target_coords_lens[fstep] = torch.tensor( - [len(f) for f in tc] if len(tc) > 1 else self.target_coords_lens[fstep], - dtype=torch.int, - ) - self.target_tokens_lens[fstep] = torch.tensor( - [len(f) for f in targets] if len(targets) > 1 else self.target_tokens_lens[fstep], - dtype=torch.int, - ) - def target_empty(self) -> bool: """ Test if target for stream is empty @@ -229,7 +219,7 @@ def target_empty(self) -> bool: """ # cat over forecast steps - return torch.cat(self.target_tokens_lens).sum() == 0 + return torch.cat(self.target_coords_lens).sum() == 0 def source_empty(self) -> bool: """ diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index a600e4590..4d6618477 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -66,7 +66,7 @@ def batchify_source( idxs_cells, idxs_cells_lens, rdata ) - source_tokens_cells = tokenize_apply_mask( + source_tokens_cells, source_tokens_lens = tokenize_apply_mask( idxs_cells, idxs_cells_lens, mask_tokens, @@ -78,7 +78,6 @@ def batchify_source( encode_times_source, ) - source_tokens_lens = torch.tensor([len(s) for s in source_tokens_cells], dtype=torch.int32) # if source_tokens_lens.sum() > 0: # source_centroids = self.compute_source_centroids(source_tokens_cells) # else: @@ -97,25 +96,24 @@ def batchify_target( time_win: tuple, ): token_size = stream_info["token_size"] - max_num_targets = stream_info.get("max_num_targets", -1) stream_id = stream_info["stream_id"] # target is empty if len(self.mask_tokens) == 0: - target_tokens, target_coords = torch.tensor([]), torch.tensor([]) - target_tokens_lens = torch.zeros([self.num_healpix_cells_target], dtype=torch.int32) - return (target_tokens, target_coords, torch.tensor([]), torch.tensor([])) + out = torch.tensor([]) + return (out, out, out, out, out) # create tokenization index tok = tokenize_spacetime if stream_info.get("tokenize_spacetime", False) else tokenize_space idxs_cells, idxs_cells_lens = tok(rdata, token_size, self.hl_source, pad_tokens=False) mask_tokens = ~self.mask_tokens + # TODO # mask_channels = ~self.mask_channels if self.mask_channels is not None # else self.mask_channels mask_channels = self.mask_channels - data, datetimes, coords, tokens_coords_local = tokenize_apply_mask_target( + data, datetimes, coords, coords_local, coords_per_cell = tokenize_apply_mask_target( self.hl_target, idxs_cells, idxs_cells_lens, @@ -130,117 +128,10 @@ def batchify_target( encode_times_target, ) - # # target_tokens, target_coords, target_coords_raw, target_times_raw) + # TODO, TODO, TODO: max_num_targets + # max_num_targets = stream_info.get("max_num_targets", -1) - # import code - - # code.interact(local=locals()) - - # # identity function - # def id(arg): - # return arg - - # # set tokenization function, no normalization of coords - # tokenize_window = partial( - # tokenize_window_spacetime if tokenize_spacetime else tokenize_window_space, - # time_win=time_win, - # token_size=token_size, - # hl=self.hl_source, - # hpy_verts_rots=self.hpy_verts_rots_source[-1], - # n_coords=id, - # enc_time=encode_times_target, - # pad_tokens=False, - # local_coords=False, - # ) - - # # tokenize - # target_tokens_cells = tokenize_window( - # 0, - # rdata.coords, - # rdata.geoinfos, - # rdata.data, - # rdata.datetimes, - # ) - - # target_tokens = self.masker.mask_target( - # target_tokens_cells, rdata.coords, rdata.geoinfos, rdata.data - # ) - - # target_tokens_lens = [len(t) for t in target_tokens] - # total_target = sum(target_tokens_lens) - - # # sampling the number of targets according to per-stream sampling_rate_target - # # otherwise take global sampling_rate_target from config - # sampling_rate_target = stream_info.get("sampling_rate_target", sampling_rate_target) - - # samples = (torch.empty(total_target).uniform_() < sampling_rate_target).split( - # target_tokens_lens - # ) - # target_tokens = [ - # (tokens[samples]) for tokens, samples in zip(target_tokens, samples, strict=False) - # ] - # target_tokens_lens = [len(t) for t in target_tokens] - - # if torch.tensor(target_tokens_lens).sum() == 0: - # return (torch.tensor([]), torch.tensor([]), torch.tensor([]), torch.tensor([])) - - # tt_lin = torch.cat(target_tokens) - # tt_lens = target_tokens_lens - - # if max_num_targets > 0: - # target_tokens = self.sample_tensors_uniform_vectorized( - # target_tokens, torch.tensor(tt_lens), max_num_targets - # ) - - # tt_lin = torch.cat(target_tokens) - # target_tokens_lens = [len(t) for t in target_tokens] - # tt_lens = target_tokens_lens - - # # TODO: can we avoid setting the offsets here manually? - # # TODO: ideally we would not have recover it; but using tokenize_window seems necessary for - # # consistency -> split tokenize_window in two parts with the cat only happening in the - # # second - # offset = 6 - # # offset of 1 : stream_id - # target_times = torch.split(tt_lin[..., 1:offset], tt_lens) - # target_coords = torch.split(tt_lin[..., offset : offset + rdata.coords.shape[-1]], tt_lens) - # offset += rdata.coords.shape[-1] - # target_geoinfos = torch.split( - # tt_lin[..., offset : offset + rdata.geoinfos.shape[-1]], tt_lens - # ) - # offset += rdata.geoinfos.shape[-1] - # target_tokens = torch.split(tt_lin[..., offset:], tt_lens) - - # offset = 6 - # target_coords_raw = torch.split( - # tt_lin[:, offset : offset + rdata.coords.shape[-1]], tt_lens - # ) - # # recover absolute time from relatives in encoded ones - # # TODO: avoid recover; see TODO above - # deltas_sec = ( - # arc_alpha(tt_lin[..., 1] - 0.5, tt_lin[..., 2] - 0.5) / (2.0 * np.pi) * (12 * 3600) - # ) - # deltas_sec = deltas_sec.numpy().astype("timedelta64[s]") - # target_times_raw = np.split(time_win[0] + deltas_sec, np.cumsum(tt_lens)[:-1]) - - # # compute encoding of target coordinates used in prediction network - # if torch.tensor(tt_lens).sum() > 0: - # target_coords = get_target_coords_local_ffast( - # self.hl_target, - # target_coords, - # target_geoinfos, - # target_times, - # self.hpy_verts_rots_target, - # self.hpy_verts_local_target, - # self.hpy_nctrs_target, - # ) - # target_coords.requires_grad = False - # target_coords = list(target_coords.split(tt_lens)) - - # return (target_tokens, target_coords, target_coords_raw, ) - # data, tokens_coords_local, datetimes - # # target_tokens, target_coords, target_coords_raw, target_times_raw) - return (data, datetimes, coords, tokens_coords_local) + return (data, datetimes, coords, coords_local, coords_per_cell) def sample_tensors_uniform_vectorized( self, tensor_list: list, lengths: list, max_total_points: int diff --git a/src/weathergen/datasets/tokenizer_utils.py b/src/weathergen/datasets/tokenizer_utils.py index cca53bf65..24e5ec4f6 100644 --- a/src/weathergen/datasets/tokenizer_utils.py +++ b/src/weathergen/datasets/tokenizer_utils.py @@ -170,7 +170,6 @@ def hpy_splits( # extract length and flatten nested list idxs_ord_lens = [[len(a) for a in aa] for aa in idxs_ord] - # idxs_ord = [torch.cat([idxs for idxs in iidxs]) for iidxs in idxs_ord] return idxs_ord, idxs_ord_lens, posr3 @@ -215,8 +214,13 @@ def tokenize_spacetime( ) idxs_cur, idxs_cur_lens = tokenize_space(rdata_cur, token_size, hl, pad_tokens) - idxs_cells = [t + list(tc) for t, tc in zip(idxs_cells, idxs_cur, strict=True)] - idxs_cells_lens = [t + tc for t, tc in zip(idxs_cells_lens, idxs_cur_lens, strict=True)] + idxs_cells = [ + t + list(tc) if tc_l[0] > 0 else t + for t, tc, tc_l in zip(idxs_cells, idxs_cur, idxs_cur_lens, strict=True) + ] + idxs_cells_lens = [ + t + tc for t, tc in zip(idxs_cells_lens, idxs_cur_lens, strict=True) if len(tc) > 0 + ] return idxs_cells, idxs_cells_lens @@ -275,6 +279,7 @@ def tokenize_apply_mask( # local coords num_tokens_per_cell = [len(idxs) for idxs in idxs_cells_lens] mask_tokens_per_cell = torch.split(torch.from_numpy(mask_tokens), num_tokens_per_cell) + tokens_per_cell = torch.tensor([t.sum() for t in mask_tokens_per_cell]) masked_points_per_cell = torch.tensor( [ torch.tensor([len(t) for t, m in zip(tt, mm, strict=False) if m]).sum() @@ -283,7 +288,7 @@ def tokenize_apply_mask( ).to(dtype=torch.int32) coords_local = get_source_coords_local(coords, hpy_verts_rots, masked_points_per_cell) - # create tensor that contains all info + # create tensor that contains all data stream_ids = torch.full([len(datetimes), 1], stream_id, dtype=torch.float32) tokens = torch.cat((stream_ids, datetimes, coords_local, geoinfos, data), 1) @@ -292,7 +297,7 @@ def tokenize_apply_mask( idxs_data_lens = idxs_data_lens.tolist() tokens_cells = torch.split(tokens, idxs_data_lens) - return tokens_cells + return tokens_cells, tokens_per_cell def tokenize_apply_mask_target( @@ -326,8 +331,6 @@ def tokenize_apply_mask_target( if mask_tokens is not None: # filter tokens using mask to obtain flat per data point index list idxs_data = torch.cat([t for t, m in zip(idxs_tokens, mask_tokens, strict=True) if m]) - # filter list of token lens using mask and obtain flat list for splitting - idxs_data_lens = torch.tensor([t for t, m in zip(idxs_lens, mask_tokens, strict=True) if m]) # apply mask datetimes = rdata.datetimes[idxs_data] @@ -362,11 +365,10 @@ def tokenize_apply_mask_target( hpy_nctrs, ) coords_local.requires_grad = False - tokens_coords_local = list(coords_local.split(idxs_data_lens.tolist())) else: - tokens_coords_local = torch.tensor([]) + coords_local = torch.tensor([]) - return data, datetimes, coords, tokens_coords_local + return data, datetimes, coords, coords_local, masked_points_per_cell def get_source_coords_local( @@ -498,146 +500,3 @@ def get_target_coords_local( a[..., (geoinfo_offset + zi) :] = target_coords[..., (geoinfo_offset + 2) :] return a - - -#################################################################################################### - - -def tokenize_window_space( - stream_id: float, - coords: torch.tensor, - geoinfos, - source, - times, - time_win, - token_size, - hl, - hpy_verts_rots, - n_coords, - enc_time, - pad_tokens=True, - local_coords=True, -): - """Process one window into tokens""" - - # len(source)==1 would require special case handling that is not worth the effort - if len(source) < 2: - return - - # idx_ord_lens is length is number of tokens per healpix cell - idxs_ord, idxs_ord_lens, posr3 = hpy_splits(coords, hl, token_size, pad_tokens) - - # pad with zero at the beggining for token size padding - times_enc = enc_time(times, time_win) - times_enc_padded = torch.cat([torch.zeros_like(times_enc[0]).unsqueeze(0), times_enc]) - geoinfos_padded = torch.cat([torch.zeros_like(geoinfos[0]).unsqueeze(0), geoinfos]) - source_padded = torch.cat([torch.zeros_like(source[0]).unsqueeze(0), source]) - - # convert to local coordinates - # TODO: avoid that padded lists are rotated, which means potentially a lot of zeros - if local_coords: - coords_local = _coords_local(posr3, hpy_verts_rots, idxs_ord, n_coords) - else: - coords_local = torch.cat([torch.zeros_like(coords[0]).unsqueeze(0), coords]) - coords_local = [coords_local[idxs] for idxs in idxs_ord] - - # reorder based on cells (except for coords_local) and then cat along - # (time,coords,geoinfos,source) dimension and then split based on cells - tokens_cells = [ - ( - list( - torch.split( - torch.cat( - ( - torch.full([len(idxs), 1], stream_id, dtype=torch.float32), - times_enc_padded[idxs], - coords_local[i], - geoinfos_padded[idxs], - source_padded[idxs], - ), - 1, - ), - idxs_lens, - ) - ) - if idxs_lens[0] > 0 - else [] - ) - for i, (idxs, idxs_lens) in enumerate(zip(idxs_ord, idxs_ord_lens, strict=True)) - ] - - return tokens_cells - - -def tokenize_window_spacetime( - stream_id, - coords, - geoinfos, - source, - times, - time_win, - token_size, - hl, - hpy_verts_rots, - n_coords, - enc_time, - pad_tokens=True, - local_coords=True, -): - """Tokenize respecting an intrinsic time step in the data, i.e. each time step is tokenized - separately - """ - - num_healpix_cells = 12 * 4**hl - tokens_cells = [[] for _ in range(num_healpix_cells)] - - t_unique = np.unique(times) - for _, t in enumerate(t_unique): - mask = t == times - tokens_cells_cur = tokenize_window_space( - stream_id, - coords[mask], - geoinfos[mask], - source[mask], - times[mask], - time_win, - token_size, - hl, - hpy_verts_rots, - n_coords, - enc_time, - pad_tokens, - local_coords, - ) - - tokens_cells = [t + tc for t, tc in zip(tokens_cells, tokens_cells_cur, strict=True)] - - return tokens_cells - - -# def _coords_local( -# posr3: Tensor, hpy_verts_rots: Tensor, idxs_ord: list[Tensor], n_coords: CoordNormalizer -# ) -> list[Tensor]: -# """Compute simple local coordinates for a set of 3D positions on the unit sphere.""" -# fp32 = torch.float32 -# posr3 = torch.cat([torch.zeros_like(posr3[0]).unsqueeze(0), posr3]) # prepend zero - -# idxs_ords_lens_l = [len(idxs) for idxs in idxs_ord] -# # int32 should be enough -# idxs_ords_lens = torch.tensor(idxs_ords_lens_l, dtype=torch.int32) -# # concat all indices -# idxs_ords_c = torch.cat([torch.tensor(i) for i in idxs_ord]) -# # Copy the rotation matrices for each healpix cell -# # num_points x 3 x 3 -# rots = torch.repeat_interleave(hpy_verts_rots, idxs_ords_lens, dim=0) -# # BMM only works for b x n x m and b x m x 1 -# # adding a dummy dimension to posr3 -# # numpoints x 3 x 1 -# posr3_sel = posr3[idxs_ords_c].unsqueeze(-1) -# vec_rot = torch.bmm(rots, posr3_sel) -# vec_rot = vec_rot.squeeze(-1) -# vec_scaled = n_coords(r3tos2(vec_rot).to(fp32)) -# # split back to ragged list -# # num_points x 2 -# coords_local = torch.split(vec_scaled, idxs_ords_lens_l, dim=0) -# return list(coords_local) diff --git a/src/weathergen/datasets/utils.py b/src/weathergen/datasets/utils.py index 194249e28..98d5a044e 100644 --- a/src/weathergen/datasets/utils.py +++ b/src/weathergen/datasets/utils.py @@ -266,139 +266,6 @@ def add_local_vert_coords_ctrs2(verts_local, tcs_lens, a, zi, geoinfo_offset): return a -#################################################################################################### -def tcs_optimized(stacked_coords: torch.Tensor) -> tuple[list[torch.Tensor], torch.Tensor]: - """ - Args: - target_coords: List of 2D coordinate tensors, each with shape [N, 2] - - Returns: - tcs: List of transformed coordinates - concatenated_coords: All original coords concatenated - """ - - # Single vectorized coordinate transformation - theta_all = torch.deg2rad(90.0 - stacked_coords[..., 0]) - phi_all = torch.deg2rad(180.0 + stacked_coords[..., 1]) - - # Transform all coordinates - transformed_all = s2tor3(theta_all, phi_all) # [total_points, 3] - - # Split back to original structure - sizes = [t.shape[0] for t in target_coords] # Get original tensor sizes - tcs = list(torch.split(transformed_all, sizes, dim=0)) # Split back to list - - return stacked_coords - - -#################################################################################################### -def get_target_coords_local_ffast( - hlc, - masked_points_per_cell, - coords, - target_geoinfos, - target_times, - verts_rots, - verts_local, - nctrs, -): - """Generate local coordinates for target coords w.r.t healpix cell vertices and - and for healpix cell vertices themselves - """ - - # target_coords_lens = [len(t) for t in target_coords] - # tcs, target_coords = tcs_optimized(target_coords) - target_coords = s2tor3(*theta_phi_to_standard_coords(coords)) - tcs = torch.split(masked_points_per_cell) - - if target_coords.shape[0] == 0: - return torch.tensor([]) - target_geoinfos = torch.cat(target_geoinfos) - target_times = torch.cat(target_times) - - verts00_rots, verts10_rots, verts11_rots, verts01_rots, vertsmm_rots = verts_rots - - a = torch.zeros( - [ - *target_coords.shape[:-1], - 1 + target_geoinfos.shape[1] + target_times.shape[1] + 5 * (3 * 5) + 3 * 8, - ] - ) - # TODO: properly set stream_id, implicitly zero at the moment - geoinfo_offset = 1 - a[..., geoinfo_offset : geoinfo_offset + target_times.shape[1]] = target_times - geoinfo_offset += target_times.shape[1] - a[..., geoinfo_offset : geoinfo_offset + target_geoinfos.shape[1]] = target_geoinfos - geoinfo_offset += target_geoinfos.shape[1] - - ref = torch.tensor([1.0, 0.0, 0.0]) - - tcs_lens = torch.tensor([tt.shape[0] for tt in tcs], dtype=torch.int32) - tcs_lens_mask = tcs_lens > 0 - tcs_lens = tcs_lens[tcs_lens_mask] - - vls = torch.cat( - [ - vl.repeat([tt, 1, 1]) - for tt, vl in zip(tcs_lens, verts_local[tcs_lens_mask], strict=False) - ], - 0, - ) - vls = vls.transpose(0, 1) - - zi = 0 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - locs_to_cell_coords_ctrs( - verts00_rots, tcs - ) - - zi = 3 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + vls.shape[-1])] = vls[0] - - zi = 15 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - locs_to_cell_coords_ctrs( - verts10_rots, tcs - ) - - zi = 18 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + vls.shape[-1])] = vls[1] - - zi = 30 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - locs_to_cell_coords_ctrs( - verts11_rots, tcs - ) - - zi = 33 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + vls.shape[-1])] = vls[2] - - zi = 45 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - locs_to_cell_coords_ctrs( - verts01_rots, tcs - ) - - zi = 48 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + vls.shape[-1])] = vls[3] - - zi = 60 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - locs_to_cell_coords_ctrs( - vertsmm_rots, tcs - ) - - zi = 63 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + vls.shape[-1])] = vls[4] - - tcs_ctrs = torch.cat([ref - torch.cat(locs_to_ctr_coords(c, tcs)) for c in nctrs], -1) - zi = 75 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + (3 * 8))] = tcs_ctrs - # a = add_local_vert_coords_ctrs2( vertsmm_rots, nctrs, tcs, a, 99, geoinfo_offset) - - # remaining geoinfos (zenith angle etc) - # zi=99+3*8; - zi = 99 - a[..., (geoinfo_offset + zi) :] = target_coords[..., (geoinfo_offset + 2) :] - - return a - - def compute_offsets_scatter_embed(batch: StreamData) -> StreamData: """ Compute auxiliary information for scatter operation that changes from stream-centric to From 9229e48d5a3488268f0eee703af6363a843e485b Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 13 Nov 2025 23:19:21 +0100 Subject: [PATCH 024/344] Minor cleanup --- src/weathergen/datasets/tokenizer_masking.py | 4 +--- src/weathergen/datasets/tokenizer_utils.py | 3 +-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 4d6618477..bf1cfcefc 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -66,7 +66,7 @@ def batchify_source( idxs_cells, idxs_cells_lens, rdata ) - source_tokens_cells, source_tokens_lens = tokenize_apply_mask( + source_tokens_cells, source_tokens_lens = tokenize_apply_mask_source( idxs_cells, idxs_cells_lens, mask_tokens, @@ -96,7 +96,6 @@ def batchify_target( time_win: tuple, ): token_size = stream_info["token_size"] - stream_id = stream_info["stream_id"] # target is empty if len(self.mask_tokens) == 0: @@ -119,7 +118,6 @@ def batchify_target( idxs_cells_lens, mask_tokens, mask_channels, - stream_id, rdata, time_win, self.hpy_verts_rots_target, diff --git a/src/weathergen/datasets/tokenizer_utils.py b/src/weathergen/datasets/tokenizer_utils.py index 24e5ec4f6..191638bb9 100644 --- a/src/weathergen/datasets/tokenizer_utils.py +++ b/src/weathergen/datasets/tokenizer_utils.py @@ -225,7 +225,7 @@ def tokenize_spacetime( return idxs_cells, idxs_cells_lens -def tokenize_apply_mask( +def tokenize_apply_mask_source( idxs_cells, idxs_cells_lens, mask_tokens, @@ -306,7 +306,6 @@ def tokenize_apply_mask_target( idxs_cells_lens, mask_tokens, mask_channels, - stream_id, rdata, time_win, hpy_verts_rots, From db6f2858d4884f8859a6dd9b433d7824a4cb4e7f Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 13 Nov 2025 23:26:31 +0100 Subject: [PATCH 025/344] Fixed linting --- src/weathergen/datasets/tokenizer_masking.py | 3 +- src/weathergen/datasets/tokenizer_utils.py | 143 +++++++++++++++++++ src/weathergen/datasets/utils.py | 106 ++++++++++++++ 3 files changed, 250 insertions(+), 2 deletions(-) diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index bf1cfcefc..8db9bca18 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -16,8 +16,7 @@ from weathergen.datasets.tokenizer_utils import ( encode_times_source, encode_times_target, - # get_target_coords_local_ffast, - tokenize_apply_mask, + tokenize_apply_mask_source, tokenize_apply_mask_target, tokenize_space, tokenize_spacetime, diff --git a/src/weathergen/datasets/tokenizer_utils.py b/src/weathergen/datasets/tokenizer_utils.py index 191638bb9..286ef260e 100644 --- a/src/weathergen/datasets/tokenizer_utils.py +++ b/src/weathergen/datasets/tokenizer_utils.py @@ -499,3 +499,146 @@ def get_target_coords_local( a[..., (geoinfo_offset + zi) :] = target_coords[..., (geoinfo_offset + 2) :] return a + + +#################################################################################################### + + +def tokenize_window_space( + stream_id: float, + coords: torch.tensor, + geoinfos, + source, + times, + time_win, + token_size, + hl, + hpy_verts_rots, + n_coords, + enc_time, + pad_tokens=True, + local_coords=True, +): + """Process one window into tokens""" + + # len(source)==1 would require special case handling that is not worth the effort + if len(source) < 2: + return + + # idx_ord_lens is length is number of tokens per healpix cell + idxs_ord, idxs_ord_lens, posr3 = hpy_splits(coords, hl, token_size, pad_tokens) + + # pad with zero at the beggining for token size padding + times_enc = enc_time(times, time_win) + times_enc_padded = torch.cat([torch.zeros_like(times_enc[0]).unsqueeze(0), times_enc]) + geoinfos_padded = torch.cat([torch.zeros_like(geoinfos[0]).unsqueeze(0), geoinfos]) + source_padded = torch.cat([torch.zeros_like(source[0]).unsqueeze(0), source]) + + # convert to local coordinates + # TODO: avoid that padded lists are rotated, which means potentially a lot of zeros + if local_coords: + coords_local = _coords_local(posr3, hpy_verts_rots, idxs_ord, n_coords) + else: + coords_local = torch.cat([torch.zeros_like(coords[0]).unsqueeze(0), coords]) + coords_local = [coords_local[idxs] for idxs in idxs_ord] + + # reorder based on cells (except for coords_local) and then cat along + # (time,coords,geoinfos,source) dimension and then split based on cells + tokens_cells = [ + ( + list( + torch.split( + torch.cat( + ( + torch.full([len(idxs), 1], stream_id, dtype=torch.float32), + times_enc_padded[idxs], + coords_local[i], + geoinfos_padded[idxs], + source_padded[idxs], + ), + 1, + ), + idxs_lens, + ) + ) + if idxs_lens[0] > 0 + else [] + ) + for i, (idxs, idxs_lens) in enumerate(zip(idxs_ord, idxs_ord_lens, strict=True)) + ] + + return tokens_cells + + +def tokenize_window_spacetime( + stream_id, + coords, + geoinfos, + source, + times, + time_win, + token_size, + hl, + hpy_verts_rots, + n_coords, + enc_time, + pad_tokens=True, + local_coords=True, +): + """Tokenize respecting an intrinsic time step in the data, i.e. each time step is tokenized + separately + """ + + num_healpix_cells = 12 * 4**hl + tokens_cells = [[] for _ in range(num_healpix_cells)] + + t_unique = np.unique(times) + for _, t in enumerate(t_unique): + mask = t == times + tokens_cells_cur = tokenize_window_space( + stream_id, + coords[mask], + geoinfos[mask], + source[mask], + times[mask], + time_win, + token_size, + hl, + hpy_verts_rots, + n_coords, + enc_time, + pad_tokens, + local_coords, + ) + + tokens_cells = [t + tc for t, tc in zip(tokens_cells, tokens_cells_cur, strict=True)] + + return tokens_cells + + +# def _coords_local( +# posr3: Tensor, hpy_verts_rots: Tensor, idxs_ord: list[Tensor], n_coords: CoordNormalizer +# ) -> list[Tensor]: +# """Compute simple local coordinates for a set of 3D positions on the unit sphere.""" +# fp32 = torch.float32 +# posr3 = torch.cat([torch.zeros_like(posr3[0]).unsqueeze(0), posr3]) # prepend zero + +# idxs_ords_lens_l = [len(idxs) for idxs in idxs_ord] +# # int32 should be enough +# idxs_ords_lens = torch.tensor(idxs_ords_lens_l, dtype=torch.int32) +# # concat all indices +# idxs_ords_c = torch.cat([torch.tensor(i) for i in idxs_ord]) +# # Copy the rotation matrices for each healpix cell +# # num_points x 3 x 3 +# rots = torch.repeat_interleave(hpy_verts_rots, idxs_ords_lens, dim=0) +# # BMM only works for b x n x m and b x m x 1 +# # adding a dummy dimension to posr3 +# # numpoints x 3 x 1 +# posr3_sel = posr3[idxs_ords_c].unsqueeze(-1) +# vec_rot = torch.bmm(rots, posr3_sel) +# vec_rot = vec_rot.squeeze(-1) +# vec_scaled = n_coords(r3tos2(vec_rot).to(fp32)) +# # split back to ragged list +# # num_points x 2 +# coords_local = torch.split(vec_scaled, idxs_ords_lens_l, dim=0) +# return list(coords_local) diff --git a/src/weathergen/datasets/utils.py b/src/weathergen/datasets/utils.py index 98d5a044e..6637cfa19 100644 --- a/src/weathergen/datasets/utils.py +++ b/src/weathergen/datasets/utils.py @@ -266,6 +266,112 @@ def add_local_vert_coords_ctrs2(verts_local, tcs_lens, a, zi, geoinfo_offset): return a +#################################################################################################### +def get_target_coords_local_ffast( + hlc, + masked_points_per_cell, + coords, + target_geoinfos, + target_times, + verts_rots, + verts_local, + nctrs, +): + """Generate local coordinates for target coords w.r.t healpix cell vertices and + and for healpix cell vertices themselves + """ + + target_coords = s2tor3(*theta_phi_to_standard_coords(coords)) + tcs = torch.split(masked_points_per_cell) + + if target_coords.shape[0] == 0: + return torch.tensor([]) + target_geoinfos = torch.cat(target_geoinfos) + target_times = torch.cat(target_times) + + verts00_rots, verts10_rots, verts11_rots, verts01_rots, vertsmm_rots = verts_rots + + a = torch.zeros( + [ + *target_coords.shape[:-1], + 1 + target_geoinfos.shape[1] + target_times.shape[1] + 5 * (3 * 5) + 3 * 8, + ] + ) + # TODO: properly set stream_id, implicitly zero at the moment + geoinfo_offset = 1 + a[..., geoinfo_offset : geoinfo_offset + target_times.shape[1]] = target_times + geoinfo_offset += target_times.shape[1] + a[..., geoinfo_offset : geoinfo_offset + target_geoinfos.shape[1]] = target_geoinfos + geoinfo_offset += target_geoinfos.shape[1] + + ref = torch.tensor([1.0, 0.0, 0.0]) + + tcs_lens = torch.tensor([tt.shape[0] for tt in tcs], dtype=torch.int32) + tcs_lens_mask = tcs_lens > 0 + tcs_lens = tcs_lens[tcs_lens_mask] + + vls = torch.cat( + [ + vl.repeat([tt, 1, 1]) + for tt, vl in zip(tcs_lens, verts_local[tcs_lens_mask], strict=False) + ], + 0, + ) + vls = vls.transpose(0, 1) + + zi = 0 + a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - locs_to_cell_coords_ctrs( + verts00_rots, tcs + ) + + zi = 3 + a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + vls.shape[-1])] = vls[0] + + zi = 15 + a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - locs_to_cell_coords_ctrs( + verts10_rots, tcs + ) + + zi = 18 + a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + vls.shape[-1])] = vls[1] + + zi = 30 + a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - locs_to_cell_coords_ctrs( + verts11_rots, tcs + ) + + zi = 33 + a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + vls.shape[-1])] = vls[2] + + zi = 45 + a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - locs_to_cell_coords_ctrs( + verts01_rots, tcs + ) + + zi = 48 + a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + vls.shape[-1])] = vls[3] + + zi = 60 + a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - locs_to_cell_coords_ctrs( + vertsmm_rots, tcs + ) + + zi = 63 + a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + vls.shape[-1])] = vls[4] + + tcs_ctrs = torch.cat([ref - torch.cat(locs_to_ctr_coords(c, tcs)) for c in nctrs], -1) + zi = 75 + a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + (3 * 8))] = tcs_ctrs + # a = add_local_vert_coords_ctrs2( vertsmm_rots, nctrs, tcs, a, 99, geoinfo_offset) + + # remaining geoinfos (zenith angle etc) + # zi=99+3*8; + zi = 99 + a[..., (geoinfo_offset + zi) :] = target_coords[..., (geoinfo_offset + 2) :] + + return a + + def compute_offsets_scatter_embed(batch: StreamData) -> StreamData: """ Compute auxiliary information for scatter operation that changes from stream-centric to From 7866ff7a18a900cb5bd5c653af11ea926221c2ad Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Fri, 14 Nov 2025 07:13:11 +0000 Subject: [PATCH 026/344] Restructuring and correcting forward pass during inference --- src/weathergen/model/diffusion.py | 47 ++++++++++++++++++------------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 3082f15aa..a7102a0d9 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -34,13 +34,13 @@ def get_sample_len(self): def get_input_data(self, t: int): return self.model_samples[t]["data"] - + def get_input_metadata(self, t: int): return self.model_samples[t]["metadata"] def get_target_data(self, t: int): return self.target_samples[t]["data"] - + def get_target_metadata(self, t: int): return self.target_samples[t]["metadata"] @@ -71,34 +71,43 @@ def __init__( self.p_std = p_std def forward(self, data: BatchData) -> torch.Tensor: + """ + Model forward call during training. Unpacks the conditioning c = [x_{t-k}, ..., x_{t}], the + target y = x_{t+1}, and the random noise eta from the data, computes the diffusion noise + level sigma, and feeds the noisy target along with the conditioning and sigma through the + model to return a denoised prediction. + """ # Retrieve conditionings [0:-1], target [-1], and noise from data object. - # The data retrieval ignores batch and stream dimension for now (has to be adapted). - cond = [data.get_input_data(t) for t in range(data.get_sample_len() - 1)] + # TOOD: The data retrieval ignores batch and stream dimension for now (has to be adapted). + c = [data.get_input_data(t) for t in range(data.get_sample_len() - 1)] y = data.get_input_data(-1) eta = data.get_input_metadata(-1) # Compute sigma (noise level) from eta - #noise = torch.randn(y.shape, device=y.device) + # noise = torch.randn(y.shape, device=y.device) # now eta from MultiStreamDataSampler sigma = (eta * self.p_std + self.p_mean).exp() n = torch.randn_like(y) * sigma + return self.denoise(x=y + n, c=c, sigma=sigma) + + # Compute loss -- move this to a separate loss calculator + # weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 # Table 1 + # loss = weight * ((y_hat - y) ** 2) - # Compute conditionings + def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float) -> torch.Tensor: + """ + The actual diffusion step, where the model removes noise from the input x under + consideration of a conditioning c (e.g., previous time steps) and the current diffusion + noise level sigma. + """ + # Compute scaling conditionings c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() c_in = 1 / (sigma**2 + self.sigma_data**2).sqrt c_noise = sigma.log() / 4 - # Add noise, precondition input, and feed through network - x = y + n - x = self.preconditioner.precondition(x, cond) - net_out = self.net(c_in * x, c_noise) - y_hat = c_skip * y + c_out * net_out # Eq. (7) - - return y_hat - - # Compute loss -- move this to a separate loss calculator - # weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 # Table 1 - # loss = weight * ((y_hat - y) ** 2) + # Precondition input and feed through network + x = self.preconditioner.precondition(x, c) + return c_skip * x + c_out * self.net(c_in * x, c_noise) # Eq. (7) in EDM paper def inference( self, @@ -127,7 +136,7 @@ def inference( ): # 0, ..., N-1 x_cur = x_next - # Increase noise temporarily. (Stochastic sampling?) + # Increase noise temporarily. (Stochastic sampling; not used for now) # gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 # t_hat = self.net.round_sigma(t_cur + gamma * t_cur) # x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * s_noise * torch.randn_like(x_cur) @@ -135,7 +144,7 @@ def inference( t_hat = t_cur # Euler step. - denoised = self.net(x_hat, t_hat) + denoised = self.denoise(x=x_hat, c=None, sigma=t_hat) # c to be discussed d_cur = (x_hat - denoised) / t_hat x_next = x_hat + (t_next - t_hat) * d_cur From ec3812356e5f87de1fb631e68624757c15de929e Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 14 Nov 2025 08:27:21 +0100 Subject: [PATCH 027/344] Fixed remaining problems that occured for NPP-ATMS and SYNOP. TODO: - Forecast still needs to be adapted - Some more cleanup of variable naming, return values etc --- src/weathergen/datasets/tokenizer_masking.py | 2 + src/weathergen/datasets/tokenizer_utils.py | 44 +++++++++----------- 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 8db9bca18..a4770b54d 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -77,6 +77,8 @@ def batchify_source( encode_times_source, ) + # import code; code.interact( local=locals()) + # if source_tokens_lens.sum() > 0: # source_centroids = self.compute_source_centroids(source_tokens_cells) # else: diff --git a/src/weathergen/datasets/tokenizer_utils.py b/src/weathergen/datasets/tokenizer_utils.py index 286ef260e..97017eecb 100644 --- a/src/weathergen/datasets/tokenizer_utils.py +++ b/src/weathergen/datasets/tokenizer_utils.py @@ -94,8 +94,8 @@ def encode_times_target(times, time_win) -> torch.tensor: time_tensor[..., 3] = np.cos(time_tensor[..., 3] / (12.0 * 3600.0) * 2.0 * np.pi) time_tensor[..., 4] = np.sin(time_tensor[..., 4] / (12.0 * 3600.0) * 2.0 * np.pi) - # We add + 0.5 as in ERA5 very often we otherwise get 0 as the first time and to prevent too - # many zeros in the input, where we cannot learn anything we add an offset + # We add + 0.5 as for datasets with regular time steps we otherwise very often get 0 as the + # first time and to prevent too many zeros in the input return time_tensor + 0.5 @@ -106,12 +106,10 @@ def hpy_cell_splits(coords: torch.tensor, hl: int): hpy_idxs_ord_split : list of per cell indices into thetas,phis,posr3 thetas : thetas in rad phis : phis in rad - posr3 : (thetas,phis) as position in R3 """ thetas, phis = theta_phi_to_standard_coords(coords) # healpix cells for all points hpy_idxs = ang2pix(2**hl, thetas, phis, nest=True) - posr3 = s2tor3(thetas, phis) # extract information to split according to cells by first sorting and then finding split idxs hpy_idxs_ord = np.argsort(hpy_idxs, **numpy_argsort_args) @@ -124,7 +122,7 @@ def hpy_cell_splits(coords: torch.tensor, hl: int): for b, x in zip(np.unique(np.unique(hpy_idxs[hpy_idxs_ord])), hpy_idxs_ord_temp, strict=True): hpy_idxs_ord_split[b] = x - return (hpy_idxs_ord_split, thetas, phis, posr3) + return (hpy_idxs_ord_split, thetas, phis) def hpy_splits( @@ -138,11 +136,10 @@ def hpy_splits( idxs_ord : flat list of indices (to data points) per healpix cell idxs_ord_lens : lens of lists per cell (so that data[idxs_ord].split( idxs_ord_lens) provides per cell data) - posr3 : R^3 positions of coords """ # list of data points per healpix cell - (hpy_idxs_ord_split, thetas, phis, posr3) = hpy_cell_splits(coords, hl) + (hpy_idxs_ord_split, thetas, phis) = hpy_cell_splits(coords, hl) # if token_size is exceeed split based on latitude # TODO: split by hierarchically traversing healpix scheme @@ -161,17 +158,23 @@ def hpy_splits( offset = 1 if pad_tokens else 0 int32 = torch.int32 idxs_ord = [ - torch.split( - torch.cat((torch.from_numpy(np.take(idxs, ts) + offset), torch.zeros(r, dtype=int32))), - token_size, + list( + torch.split( + torch.cat( + (torch.from_numpy(np.take(idxs, ts) + offset), torch.zeros(r, dtype=int32)) + ), + token_size, + ) ) + if len(idxs) > 0 + else [] for idxs, ts, r in zip(hpy_idxs_ord_split, thetas_sorted, rem, strict=True) ] # extract length and flatten nested list idxs_ord_lens = [[len(a) for a in aa] for aa in idxs_ord] - return idxs_ord, idxs_ord_lens, posr3 + return idxs_ord, idxs_ord_lens def tokenize_space( @@ -182,12 +185,8 @@ def tokenize_space( ): """Process one window into tokens""" - # len(source)==1 would require special case handling that is not worth the effort - if len(rdata.data) < 2: - return - # idx_ord_lens is length is number of tokens per healpix cell - idxs_ord, idxs_ord_lens, _ = hpy_splits(rdata.coords, hl, token_size, pad_tokens) + idxs_ord, idxs_ord_lens = hpy_splits(rdata.coords, hl, token_size, pad_tokens) return idxs_ord, idxs_ord_lens @@ -208,19 +207,16 @@ def tokenize_spacetime( t_unique = np.unique(rdata.datetimes) for _, t in enumerate(t_unique): + # data for current time step mask = t == rdata.datetimes rdata_cur = IOReaderData( rdata.coords[mask], rdata.geoinfos[mask], rdata.data[mask], rdata.datetimes[mask] ) idxs_cur, idxs_cur_lens = tokenize_space(rdata_cur, token_size, hl, pad_tokens) - idxs_cells = [ - t + list(tc) if tc_l[0] > 0 else t - for t, tc, tc_l in zip(idxs_cells, idxs_cur, idxs_cur_lens, strict=True) - ] - idxs_cells_lens = [ - t + tc for t, tc in zip(idxs_cells_lens, idxs_cur_lens, strict=True) if len(tc) > 0 - ] + # collect data for all time steps + idxs_cells = [t + tc for t, tc in zip(idxs_cells, idxs_cur, strict=True)] + idxs_cells_lens = [t + tc_l for t, tc_l in zip(idxs_cells_lens, idxs_cur_lens, strict=True)] return idxs_cells, idxs_cells_lens @@ -491,10 +487,8 @@ def get_target_coords_local( tcs_ctrs = torch.cat([ref - torch.cat(locs_to_ctr_coords(c, tcs)) for c in nctrs], -1) zi = 75 a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + (3 * 8))] = tcs_ctrs - # a = add_local_vert_coords_ctrs2( vertsmm_rots, nctrs, tcs, a, 99, geoinfo_offset) # remaining geoinfos (zenith angle etc) - # zi=99+3*8; zi = 99 a[..., (geoinfo_offset + zi) :] = target_coords[..., (geoinfo_offset + 2) :] From 0634105d34361d084f83399816c89d6f5b097c4b Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 14 Nov 2025 09:59:13 +0100 Subject: [PATCH 028/344] Enabled support for forecast. Cleaned up some bits and pieces. --- src/weathergen/datasets/masking.py | 41 ++++- .../datasets/multi_stream_data_sampler.py | 17 +- src/weathergen/datasets/tokenizer_forecast.py | 149 ---------------- src/weathergen/datasets/tokenizer_masking.py | 26 +-- src/weathergen/datasets/tokenizer_utils.py | 164 ++---------------- src/weathergen/datasets/utils.py | 106 ----------- 6 files changed, 62 insertions(+), 441 deletions(-) delete mode 100644 src/weathergen/datasets/tokenizer_forecast.py diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index b8f56b023..8ce2e039b 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -140,20 +140,20 @@ def mask_source_idxs( idxs_cells, idxs_cells_lens, rdata, - ) -> torch.Tensor: + ) -> (torch.Tensor, torch.Tensor): """ Return: torch.Tensor[bool] of length num_tokens that determines masking for each token """ - mask_tokens, mask_channels = None, None + self.mask_tokens, self.mask_channels = None, None num_tokens = torch.tensor([len(t) for t in idxs_cells_lens]).sum().item() # If there are no tokens, return empty lists. if num_tokens == 0: - return (mask_tokens, mask_channels) + return (self.mask_tokens, self.mask_channels) # Clean strategy selection self.current_strategy = self._select_strategy() @@ -162,24 +162,47 @@ def mask_source_idxs( rate = self._get_sampling_rate() if self.current_strategy == "random": - mask_tokens = self.rng.uniform(0, 1, num_tokens) < rate + self.mask_tokens = self.rng.uniform(0, 1, num_tokens) < rate elif self.current_strategy == "forecast": - mask_tokens = np.zeros( - num_tokens, - ) + self.mask_tokens = np.ones(num_tokens, dtype=np.bool) elif self.current_strategy == "healpix": # TODO: currently only for fixed level num_cells = len(idxs_cells_lens) mask_cells = self.rng.uniform(0, 1, num_cells) < rate # translate cell mask to token mask, replicating using number of tokens per cell - mask_tokens = [ + self.mask_tokens = [ (torch.ones(2, dtype=torch.bool) * (1 if m else 0)).to(torch.bool) for idxs_cell, m in zip(idxs_cells_lens, mask_cells, strict=False) ] else: assert False, f"Unsupported masking strategy: {self.current_strategy}" - return (mask_tokens, mask_channels) + return (self.mask_tokens, self.mask_channels) + + def mask_targets_idxs( + self, + idxs_cells, + idxs_cells_lens, + rdata, + ) -> (torch.Tensor, torch.Tensor): + # mask_source_idxs is + assert (self.mask_tokens is not None) or (self.mask_tokens is not None) + + # TODO: better handling of if statement + if self.current_strategy == "forecast": + num_tokens = torch.tensor([len(t) for t in idxs_cells_lens]).sum().item() + self.mask_tokens = np.ones(num_tokens, dtype=np.bool) + else: + # masking strategies: target is complement of source + # TODO: ensure/enforce that forecast_offset==0 + if self.mask_tokens is not None: + self.mask_tokens = ~self.mask_tokens + if self.mask_channels is not None: + self.mask_channels = ~self.mask_channels + + # TODO: self.mask_tokens seems brittle in terms of naming + + return (self.mask_tokens, self.mask_channels) def mask_source( self, diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 63a0e5d63..c3fda1ac1 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -26,7 +26,6 @@ from weathergen.datasets.icon_dataset import IconDataset from weathergen.datasets.masking import Masker from weathergen.datasets.stream_data import StreamData, spoof -from weathergen.datasets.tokenizer_forecast import TokenizerForecast from weathergen.datasets.tokenizer_masking import TokenizerMasking from weathergen.datasets.utils import ( compute_idxs_predict, @@ -224,17 +223,8 @@ def __init__( self.healpix_level: int = cf.healpix_level self.num_healpix_cells: int = 12 * 4**self.healpix_level - if cf.training_mode == "forecast": - self.tokenizer = TokenizerForecast(cf.healpix_level) - elif cf.training_mode == "masking": - masker = Masker(cf) - self.tokenizer = TokenizerMasking(cf.healpix_level, masker) - assert self.forecast_offset == 0, "masked token modeling requires auto-encoder training" - msg = "masked token modeling does not support self.input_window_steps > 1; " - msg += "increase window length" - assert self.input_window_steps == 1, msg - else: - assert False, f"Unsupported training mode: {cf.training_mode}" + masker = Masker(cf) + self.tokenizer = TokenizerMasking(cf.healpix_level, masker) self.epoch = 0 @@ -387,10 +377,9 @@ def __iter__(self): stream_info, readerdata_to_torch(rdata), (time_win_source.start, time_win_source.end), - stream_ds[0].normalize_coords, ) - # TODO: rdata only be collected in validation mode + # collect data for stream stream_data.add_source(rdata, ss_lens, ss_cells, ss_centroids) # target diff --git a/src/weathergen/datasets/tokenizer_forecast.py b/src/weathergen/datasets/tokenizer_forecast.py deleted file mode 100644 index d54831265..000000000 --- a/src/weathergen/datasets/tokenizer_forecast.py +++ /dev/null @@ -1,149 +0,0 @@ -# (C) Copyright 2025 WeatherGenerator contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -from functools import partial - -import numpy as np -import torch - -from weathergen.common.io import IOReaderData -from weathergen.datasets.tokenizer import Tokenizer -from weathergen.datasets.tokenizer_utils import ( - encode_times_source, - encode_times_target, - hpy_cell_splits, - tokenize_window_space, - tokenize_window_spacetime, -) -from weathergen.datasets.utils import ( - get_target_coords_local_ffast, -) - - -class TokenizerForecast(Tokenizer): - def reset_rng(self, rng) -> None: - """ - Reset rng after epoch to ensure proper randomization - """ - self.rng = rng - - def batchify_source( - self, - stream_info: dict, - rdata: IOReaderData, - time_win: tuple, - normalize_coords, - ): - token_size = stream_info["token_size"] - is_diagnostic = stream_info.get("diagnostic", False) - tokenize_spacetime = stream_info.get("tokenize_spacetime", False) - - tokenize_window = partial( - tokenize_window_spacetime if tokenize_spacetime else tokenize_window_space, - time_win=time_win, - token_size=token_size, - hl=self.hl_source, - hpy_verts_rots=self.hpy_verts_rots_source[-1], - n_coords=normalize_coords, - enc_time=encode_times_source, - ) - - source_tokens_cells = [torch.tensor([])] - source_centroids = [torch.tensor([])] - source_tokens_lens = torch.zeros([self.num_healpix_cells_source], dtype=torch.int32) - - if is_diagnostic or rdata.data.shape[1] == 0 or len(rdata.data) < 2: - return (source_tokens_cells, source_tokens_lens, source_centroids) - - # TODO: properly set stream_id; don't forget to normalize - source_tokens_cells = tokenize_window( - 0, - rdata.coords, - rdata.geoinfos, - rdata.data, - rdata.datetimes, - ) - - source_tokens_cells = [ - torch.stack(c) if len(c) > 0 else torch.tensor([]) for c in source_tokens_cells - ] - - source_tokens_lens = torch.tensor([len(s) for s in source_tokens_cells], dtype=torch.int32) - if source_tokens_lens.sum() > 0: - source_centroids = self.compute_source_centroids(source_tokens_cells) - - return (source_tokens_cells, source_tokens_lens, source_centroids) - - def batchify_target( - self, - stream_info: dict, - sampling_rate_target: float, - rdata: IOReaderData, - time_win: tuple, - ): - target_tokens = torch.zeros([self.num_healpix_cells_target], dtype=torch.int32) - target_coords = torch.zeros([self.num_healpix_cells_target], dtype=torch.int32) - target_tokens_lens = torch.zeros([self.num_healpix_cells_target], dtype=torch.int32) - - sampling_rate_target = stream_info.get("sampling_rate_target", sampling_rate_target) - if sampling_rate_target < 1.0: - mask = self.rng.uniform(0.0, 1.0, rdata.data.shape[0]) < sampling_rate_target - rdata.coords = rdata.coords[mask] - rdata.geoinfos = rdata.geoinfos[mask] - rdata.data = rdata.data[mask] - rdata.datetimes = rdata.datetimes[mask] - - # TODO: currently treated as empty to avoid special case handling - if len(rdata.data) < 2: - return (target_tokens, target_coords, torch.tensor([]), torch.tensor([])) - - # compute indices for each cell - hpy_idxs_ord_split, _, _, _ = hpy_cell_splits(rdata.coords, self.hl_target) - - # TODO: expose parameter - with_perm_target = True - if with_perm_target: - hpy_idxs_ord_split = [ - idx[self.rng.permutation(len(idx))[: int(len(idx))]] for idx in hpy_idxs_ord_split - ] - - # helper variables to split according to cells - idxs_ord = np.concatenate(hpy_idxs_ord_split) - ll = np.cumsum(np.array([len(a) for a in hpy_idxs_ord_split]))[:-1] - - # compute encoding of time - times_reordered = rdata.datetimes[idxs_ord] - times_reordered_enc = encode_times_target(times_reordered, time_win) - - # reorder and split all relevant information based on cells - target_tokens = np.split(rdata.data[idxs_ord], ll) - coords_reordered = rdata.coords[idxs_ord] - target_coords = np.split(coords_reordered, ll) - target_coords_raw = np.split(coords_reordered, ll) - target_geoinfos = np.split(rdata.geoinfos[idxs_ord], ll) - target_times_raw = np.split(times_reordered, ll) - target_times = np.split(times_reordered_enc, ll) - - target_tokens_lens = torch.tensor([len(s) for s in target_tokens], dtype=torch.int32) - - # compute encoding of target coordinates used in prediction network - if target_tokens_lens.sum() > 0: - target_coords = get_target_coords_local_ffast( - self.hl_target, - target_coords, - target_geoinfos, - target_times, - self.hpy_verts_rots_target, - self.hpy_verts_local_target, - self.hpy_nctrs_target, - ) - target_coords.requires_grad = False - target_coords = list(target_coords.split(target_tokens_lens.tolist())) - - return (target_tokens, target_coords, target_coords_raw, target_times_raw) diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index a4770b54d..6d95cbb86 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -35,15 +35,11 @@ def reset_rng(self, rng) -> None: self.masker.reset_rng(rng) self.rng = rng - self.mask_tokens = None - self.mask_channels = None - def batchify_source( self, stream_info: dict, rdata: IOReaderData, time_win: tuple, - normalize_coords, # dataset ): token_size = stream_info["token_size"] stream_id = stream_info["stream_id"] @@ -77,16 +73,12 @@ def batchify_source( encode_times_source, ) - # import code; code.interact( local=locals()) - # if source_tokens_lens.sum() > 0: # source_centroids = self.compute_source_centroids(source_tokens_cells) # else: # TODO: remove completely? source_centroids = [torch.tensor([])] - self.mask_tokens, self.mask_channels = mask_tokens, mask_channels - return (source_tokens_cells, source_tokens_lens, source_centroids) def batchify_target( @@ -98,20 +90,18 @@ def batchify_target( ): token_size = stream_info["token_size"] - # target is empty - if len(self.mask_tokens) == 0: - out = torch.tensor([]) - return (out, out, out, out, out) - # create tokenization index tok = tokenize_spacetime if stream_info.get("tokenize_spacetime", False) else tokenize_space idxs_cells, idxs_cells_lens = tok(rdata, token_size, self.hl_source, pad_tokens=False) - mask_tokens = ~self.mask_tokens - # TODO - # mask_channels = ~self.mask_channels if self.mask_channels is not None - # else self.mask_channels - mask_channels = self.mask_channels + (mask_tokens, mask_channels) = self.masker.mask_targets_idxs( + idxs_cells, idxs_cells_lens, rdata + ) + # mask_tokens = ~self.mask_tokens + # # TODO + # # mask_channels = ~self.mask_channels if self.mask_channels is not None + # # else self.mask_channels + # mask_channels = self.mask_channels data, datetimes, coords, coords_local, coords_per_cell = tokenize_apply_mask_target( self.hl_target, diff --git a/src/weathergen/datasets/tokenizer_utils.py b/src/weathergen/datasets/tokenizer_utils.py index 97017eecb..7c5d056ac 100644 --- a/src/weathergen/datasets/tokenizer_utils.py +++ b/src/weathergen/datasets/tokenizer_utils.py @@ -248,7 +248,14 @@ def tokenize_apply_mask_source( # apply spatial masking on a per token level if mask_tokens is not None: # filter tokens using mask to obtain flat per data point index list - idxs_data = torch.cat([t for t, m in zip(idxs_tokens, mask_tokens, strict=True) if m]) + idxs_data = [t for t, m in zip(idxs_tokens, mask_tokens, strict=True) if m] + + if len(idxs_data) == 0: + tokens_cells = [] + tokens_per_cell = torch.zeros(len(idxs_cells_lens), dtype=torch.int32) + return tokens_cells, tokens_per_cell + + idxs_data = torch.cat(idxs_data) # filter list of token lens using mask and obtain flat list for splitting idxs_data_lens = torch.tensor([t for t, m in zip(idxs_lens, mask_tokens, strict=True) if m]) @@ -325,7 +332,17 @@ def tokenize_apply_mask_target( # apply spatial masking on a per token level if mask_tokens is not None: # filter tokens using mask to obtain flat per data point index list - idxs_data = torch.cat([t for t, m in zip(idxs_tokens, mask_tokens, strict=True) if m]) + idxs_data = [t for t, m in zip(idxs_tokens, mask_tokens, strict=True) if m] + + if len(idxs_data) == 0: + do = torch.zeros([0, rdata.data.shape[-1]]) + coords = torch.zeros([0, rdata.coords.shape[-1]]) + dt = np.array([], dtype=np.datetime64) + masked_points_per_cell = torch.zeros(len(idxs_cells_lens), dtype=torch.int32) + # data, datetimes, coords, coords_local, masked_points_per_cell + return do, dt, coords, coords, masked_points_per_cell + + idxs_data = torch.cat(idxs_data) # apply mask datetimes = rdata.datetimes[idxs_data] @@ -493,146 +510,3 @@ def get_target_coords_local( a[..., (geoinfo_offset + zi) :] = target_coords[..., (geoinfo_offset + 2) :] return a - - -#################################################################################################### - - -def tokenize_window_space( - stream_id: float, - coords: torch.tensor, - geoinfos, - source, - times, - time_win, - token_size, - hl, - hpy_verts_rots, - n_coords, - enc_time, - pad_tokens=True, - local_coords=True, -): - """Process one window into tokens""" - - # len(source)==1 would require special case handling that is not worth the effort - if len(source) < 2: - return - - # idx_ord_lens is length is number of tokens per healpix cell - idxs_ord, idxs_ord_lens, posr3 = hpy_splits(coords, hl, token_size, pad_tokens) - - # pad with zero at the beggining for token size padding - times_enc = enc_time(times, time_win) - times_enc_padded = torch.cat([torch.zeros_like(times_enc[0]).unsqueeze(0), times_enc]) - geoinfos_padded = torch.cat([torch.zeros_like(geoinfos[0]).unsqueeze(0), geoinfos]) - source_padded = torch.cat([torch.zeros_like(source[0]).unsqueeze(0), source]) - - # convert to local coordinates - # TODO: avoid that padded lists are rotated, which means potentially a lot of zeros - if local_coords: - coords_local = _coords_local(posr3, hpy_verts_rots, idxs_ord, n_coords) - else: - coords_local = torch.cat([torch.zeros_like(coords[0]).unsqueeze(0), coords]) - coords_local = [coords_local[idxs] for idxs in idxs_ord] - - # reorder based on cells (except for coords_local) and then cat along - # (time,coords,geoinfos,source) dimension and then split based on cells - tokens_cells = [ - ( - list( - torch.split( - torch.cat( - ( - torch.full([len(idxs), 1], stream_id, dtype=torch.float32), - times_enc_padded[idxs], - coords_local[i], - geoinfos_padded[idxs], - source_padded[idxs], - ), - 1, - ), - idxs_lens, - ) - ) - if idxs_lens[0] > 0 - else [] - ) - for i, (idxs, idxs_lens) in enumerate(zip(idxs_ord, idxs_ord_lens, strict=True)) - ] - - return tokens_cells - - -def tokenize_window_spacetime( - stream_id, - coords, - geoinfos, - source, - times, - time_win, - token_size, - hl, - hpy_verts_rots, - n_coords, - enc_time, - pad_tokens=True, - local_coords=True, -): - """Tokenize respecting an intrinsic time step in the data, i.e. each time step is tokenized - separately - """ - - num_healpix_cells = 12 * 4**hl - tokens_cells = [[] for _ in range(num_healpix_cells)] - - t_unique = np.unique(times) - for _, t in enumerate(t_unique): - mask = t == times - tokens_cells_cur = tokenize_window_space( - stream_id, - coords[mask], - geoinfos[mask], - source[mask], - times[mask], - time_win, - token_size, - hl, - hpy_verts_rots, - n_coords, - enc_time, - pad_tokens, - local_coords, - ) - - tokens_cells = [t + tc for t, tc in zip(tokens_cells, tokens_cells_cur, strict=True)] - - return tokens_cells - - -# def _coords_local( -# posr3: Tensor, hpy_verts_rots: Tensor, idxs_ord: list[Tensor], n_coords: CoordNormalizer -# ) -> list[Tensor]: -# """Compute simple local coordinates for a set of 3D positions on the unit sphere.""" -# fp32 = torch.float32 -# posr3 = torch.cat([torch.zeros_like(posr3[0]).unsqueeze(0), posr3]) # prepend zero - -# idxs_ords_lens_l = [len(idxs) for idxs in idxs_ord] -# # int32 should be enough -# idxs_ords_lens = torch.tensor(idxs_ords_lens_l, dtype=torch.int32) -# # concat all indices -# idxs_ords_c = torch.cat([torch.tensor(i) for i in idxs_ord]) -# # Copy the rotation matrices for each healpix cell -# # num_points x 3 x 3 -# rots = torch.repeat_interleave(hpy_verts_rots, idxs_ords_lens, dim=0) -# # BMM only works for b x n x m and b x m x 1 -# # adding a dummy dimension to posr3 -# # numpoints x 3 x 1 -# posr3_sel = posr3[idxs_ords_c].unsqueeze(-1) -# vec_rot = torch.bmm(rots, posr3_sel) -# vec_rot = vec_rot.squeeze(-1) -# vec_scaled = n_coords(r3tos2(vec_rot).to(fp32)) -# # split back to ragged list -# # num_points x 2 -# coords_local = torch.split(vec_scaled, idxs_ords_lens_l, dim=0) -# return list(coords_local) diff --git a/src/weathergen/datasets/utils.py b/src/weathergen/datasets/utils.py index 6637cfa19..98d5a044e 100644 --- a/src/weathergen/datasets/utils.py +++ b/src/weathergen/datasets/utils.py @@ -266,112 +266,6 @@ def add_local_vert_coords_ctrs2(verts_local, tcs_lens, a, zi, geoinfo_offset): return a -#################################################################################################### -def get_target_coords_local_ffast( - hlc, - masked_points_per_cell, - coords, - target_geoinfos, - target_times, - verts_rots, - verts_local, - nctrs, -): - """Generate local coordinates for target coords w.r.t healpix cell vertices and - and for healpix cell vertices themselves - """ - - target_coords = s2tor3(*theta_phi_to_standard_coords(coords)) - tcs = torch.split(masked_points_per_cell) - - if target_coords.shape[0] == 0: - return torch.tensor([]) - target_geoinfos = torch.cat(target_geoinfos) - target_times = torch.cat(target_times) - - verts00_rots, verts10_rots, verts11_rots, verts01_rots, vertsmm_rots = verts_rots - - a = torch.zeros( - [ - *target_coords.shape[:-1], - 1 + target_geoinfos.shape[1] + target_times.shape[1] + 5 * (3 * 5) + 3 * 8, - ] - ) - # TODO: properly set stream_id, implicitly zero at the moment - geoinfo_offset = 1 - a[..., geoinfo_offset : geoinfo_offset + target_times.shape[1]] = target_times - geoinfo_offset += target_times.shape[1] - a[..., geoinfo_offset : geoinfo_offset + target_geoinfos.shape[1]] = target_geoinfos - geoinfo_offset += target_geoinfos.shape[1] - - ref = torch.tensor([1.0, 0.0, 0.0]) - - tcs_lens = torch.tensor([tt.shape[0] for tt in tcs], dtype=torch.int32) - tcs_lens_mask = tcs_lens > 0 - tcs_lens = tcs_lens[tcs_lens_mask] - - vls = torch.cat( - [ - vl.repeat([tt, 1, 1]) - for tt, vl in zip(tcs_lens, verts_local[tcs_lens_mask], strict=False) - ], - 0, - ) - vls = vls.transpose(0, 1) - - zi = 0 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - locs_to_cell_coords_ctrs( - verts00_rots, tcs - ) - - zi = 3 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + vls.shape[-1])] = vls[0] - - zi = 15 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - locs_to_cell_coords_ctrs( - verts10_rots, tcs - ) - - zi = 18 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + vls.shape[-1])] = vls[1] - - zi = 30 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - locs_to_cell_coords_ctrs( - verts11_rots, tcs - ) - - zi = 33 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + vls.shape[-1])] = vls[2] - - zi = 45 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - locs_to_cell_coords_ctrs( - verts01_rots, tcs - ) - - zi = 48 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + vls.shape[-1])] = vls[3] - - zi = 60 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + 3)] = ref - locs_to_cell_coords_ctrs( - vertsmm_rots, tcs - ) - - zi = 63 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + vls.shape[-1])] = vls[4] - - tcs_ctrs = torch.cat([ref - torch.cat(locs_to_ctr_coords(c, tcs)) for c in nctrs], -1) - zi = 75 - a[..., (geoinfo_offset + zi) : (geoinfo_offset + zi + (3 * 8))] = tcs_ctrs - # a = add_local_vert_coords_ctrs2( vertsmm_rots, nctrs, tcs, a, 99, geoinfo_offset) - - # remaining geoinfos (zenith angle etc) - # zi=99+3*8; - zi = 99 - a[..., (geoinfo_offset + zi) :] = target_coords[..., (geoinfo_offset + 2) :] - - return a - - def compute_offsets_scatter_embed(batch: StreamData) -> StreamData: """ Compute auxiliary information for scatter operation that changes from stream-centric to From cab9fbe9a6745fa6b9cc4b9d7288e69bd0e2bf77 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Fri, 14 Nov 2025 10:41:18 +0100 Subject: [PATCH 029/344] mv streams_data declaration under if condition --- src/weathergen/train/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 695150e3b..8cf2c067a 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -692,7 +692,7 @@ def validate(self, epoch): ) targets = {"physical": batch[0]} - streams_data: list[list[StreamData]] = batch[0] + # compute loss loss_values = self.loss_calculator_val.compute_loss( preds=output, @@ -701,6 +701,7 @@ def validate(self, epoch): # log output if bidx < cf.log_validation: # TODO: Move _prepare_logging into write_validation by passing streams_data + streams_data: list[list[StreamData]] = batch[0] ( preds_all, targets_all, From 20da55574f91eef716abcb55882131e29880e3d3 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Fri, 14 Nov 2025 12:07:10 +0100 Subject: [PATCH 030/344] add weight to loss config, add toy loss class LossPhysicalTwo --- config/default_config.yml | 6 +- src/weathergen/train/loss_calculator.py | 40 +-- src/weathergen/train/loss_modules/__init__.py | 4 +- .../train/loss_modules/loss_module_base.py | 2 +- .../loss_modules/loss_module_physical.py | 269 ++++++++++++++++++ 5 files changed, 299 insertions(+), 22 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 3da835e5a..e99d9f423 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -94,13 +94,15 @@ ema_halflife_in_thousands: 1e-3 # training mode: "forecast" or "masking" (masked token modeling) # for "masking" to train with auto-encoder mode, forecast_offset should be 0 training_mode: "masking" -training_mode_config: {"losses": {LossPhysical: [['mse', 1.0]],} +training_mode_config: {"losses": {LossPhysical: {weight: 0.7, loss_fcts: [['mse', 1.0]]}, + LossPhysicalTwo: {weight: 0.3, loss_fcts: [['mse', 1.0]]}, + } } # training_mode_config: {"loss": {LossPhysical: [['mse', 0.7]], # LossLatent: [['mse', 0.3]], # LossStudentTeacher: [{'iBOT': {}, 'JEPA': {options}}],} # } -validation_mode_config: {"losses": {LossPhysical: [['mse', 1.0]],} +validation_mode_config: {"losses": {LossPhysical: {weight: 1.0, loss_fcts: [['mse', 1.0]]},} } # masking rate when training mode is "masking"; ignored in foreacast mode masking_rate: 0.6 diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index 2eda80fce..fbfaebdb0 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -9,9 +9,11 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import dataclasses import logging from omegaconf import DictConfig +from torch import Tensor import weathergen.train.loss_modules as LossModules from weathergen.train.loss_modules.loss_module_base import LossValues @@ -20,6 +22,18 @@ _logger = logging.getLogger(__name__) +@dataclasses.dataclass +class LossTerms: + """ + A dataclass which combines the LossValues of all loss modules + """ + + # The primary scalar loss value for optimization. + loss: Tensor + # Dictionary containing the LossValues of each loss module. + loss_terms: dict[str, LossValues] + + class LossCalculator: """ Manages and computes the overall loss for a WeatherGenerator model during @@ -53,14 +67,13 @@ def __init__( calculator_configs = ( cf.training_mode_config.losses if stage == TRAIN else cf.validation_mode_config.losses ) - calculator_configs = [ - (getattr(LossModules, Cls), losses) for (Cls, losses) in calculator_configs.items() + (getattr(LossModules, Cls), config) for (Cls, config) in calculator_configs.items() ] self.loss_calculators = [ - Cls(cf=cf, loss_fcts=losses, stage=stage, device=self.device) - for (Cls, losses) in calculator_configs + (config.weight, Cls(cf=cf, loss_fcts=config.loss_fcts, stage=stage, device=self.device)) + for (Cls, config) in calculator_configs ] def compute_loss( @@ -68,17 +81,10 @@ def compute_loss( preds: dict, targets: dict, ): - loss_values = {} + loss_terms = {} loss = 0 - for calculator in self.loss_calculators: - loss_values[calculator.name] = calculator.compute_loss(preds=preds, targets=targets) - loss += loss_values[calculator.name].loss - - # Bring all loss values together - # TODO: make sure keys are explicit, e.g loss_mse.latent.loss_2t - losses_all = {} - stddev_all = {} - for _, v in loss_values.items(): - losses_all.update(v.losses_all) - stddev_all.update(v.stddev_all) - return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all) + for weight, calculator in self.loss_calculators: + loss_terms[calculator.name] = calculator.compute_loss(preds=preds, targets=targets) + loss += weight * loss_terms[calculator.name].loss + + return LossTerms(loss=loss, loss_terms=loss_terms) diff --git a/src/weathergen/train/loss_modules/__init__.py b/src/weathergen/train/loss_modules/__init__.py index 7f5fc906d..43be4dfe1 100644 --- a/src/weathergen/train/loss_modules/__init__.py +++ b/src/weathergen/train/loss_modules/__init__.py @@ -1,5 +1,5 @@ from .loss_module_latent import LossLatent -from .loss_module_physical import LossPhysical +from .loss_module_physical import LossPhysical, LossPhysicalTwo from .loss_module_ssl import LossStudentTeacher -__all__ = [LossLatent, LossPhysical, LossStudentTeacher] +__all__ = [LossLatent, LossPhysical, LossPhysicalTwo, LossStudentTeacher] diff --git a/src/weathergen/train/loss_modules/loss_module_base.py b/src/weathergen/train/loss_modules/loss_module_base.py index de66bda28..8e6ad3b5d 100644 --- a/src/weathergen/train/loss_modules/loss_module_base.py +++ b/src/weathergen/train/loss_modules/loss_module_base.py @@ -21,7 +21,7 @@ @dataclasses.dataclass class LossValues: """ - A dataclass to encapsulate the various loss components computed by the LossCalculator. + A dataclass to encapsulate the loss components returned by each loss module. This provides a structured way to return the primary loss used for optimization, along with detailed per-stream/per-channel/per-loss-function losses for logging, diff --git a/src/weathergen/train/loss_modules/loss_module_physical.py b/src/weathergen/train/loss_modules/loss_module_physical.py index 54d30acc1..1e900f25f 100644 --- a/src/weathergen/train/loss_modules/loss_module_physical.py +++ b/src/weathergen/train/loss_modules/loss_module_physical.py @@ -291,3 +291,272 @@ def compute_loss( # Return all computed loss components encapsulated in a ModelLoss dataclass return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all) + + +class LossPhysicalTwo(LossModuleBase): + """ + Manages and computes the overall loss for a WeatherGenerator model during + training and validation stages. + + This class handles the initialization and application of various loss functions, + applies channel-specific weights, constructs masks for missing data, and + aggregates losses across different data streams, channels, and forecast steps. + It provides both the main loss for backpropagation and detailed loss metrics for logging. + """ + + def __init__( + self, + cf: DictConfig, + loss_fcts: list, + stage: Stage, + device: str, + ): + LossModuleBase.__init__(self) + self.cf = cf + self.stage = stage + self.device = device + self.name = "LossPhysicalTwo" + + # Dynamically load loss functions based on configuration and stage + self.loss_fcts = [ + [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] + for name, w in loss_fcts + ] + + def _get_weights(self, stream_info): + """ + Get weights for current stream + """ + + device = self.device + + # Determine stream and channel loss weights based on the current stage + if self.stage == TRAIN: + # set loss_weights to 1. when not specified + stream_info_loss_weight = stream_info.get("loss_weight", 1.0) + weights_channels = ( + torch.tensor(stream_info["target_channel_weights"]).to( + device=device, non_blocking=True + ) + if "target_channel_weights" in stream_info + else None + ) + elif self.stage == VAL: + # in validation mode, always unweighted loss + stream_info_loss_weight = 1.0 + weights_channels = None + + return stream_info_loss_weight, weights_channels + + def _get_fstep_weights(self, forecast_steps): + timestep_weight_config = self.cf.get("timestep_weight") + if timestep_weight_config is None: + return [1.0 for _ in range(forecast_steps)] + weights_timestep_fct = getattr(losses, timestep_weight_config[0]) + return weights_timestep_fct(forecast_steps, timestep_weight_config[1]) + + def _get_location_weights(self, stream_info, stream_data, forecast_offset, fstep): + location_weight_type = stream_info.get("location_weight", None) + if location_weight_type is None: + return None + weights_locations_fct = getattr(losses, location_weight_type) + weights_locations = weights_locations_fct(stream_data, forecast_offset, fstep) + weights_locations = weights_locations.to(device=self.device, non_blocking=True) + + return weights_locations + + def _get_substep_masks(self, stream_info, fstep, stream_data): + """ + Find substeps and create corresponding masks (reused across loss functions) + """ + + tok_spacetime = stream_info.get("tokenize_spacetime", None) + target_times = stream_data.target_times_raw[self.cf.forecast_offset + fstep] + target_times_unique = np.unique(target_times) if tok_spacetime else [target_times] + substep_masks = [] + for t in target_times_unique: + # find substep + mask_t = torch.tensor(t == target_times).to(self.device, non_blocking=True) + substep_masks.append(mask_t) + + return substep_masks + + @staticmethod + def _loss_per_loss_function( + loss_fct, + target: torch.Tensor, + pred: torch.Tensor, + substep_masks: list[torch.Tensor], + weights_channels: torch.Tensor, + weights_locations: torch.Tensor, + ): + """ + Compute loss for given loss function + """ + + loss_lfct = torch.tensor(0.0, device=target.device, requires_grad=True) + losses_chs = torch.zeros(target.shape[-1], device=target.device, dtype=torch.float32) + + ctr_substeps = 0 + for mask_t in substep_masks: + assert mask_t.sum() == len(weights_locations) if weights_locations is not None else True + + loss, loss_chs = loss_fct( + target[mask_t], pred[:, mask_t], weights_channels, weights_locations + ) + + # accumulate loss + loss_lfct = loss_lfct + loss + losses_chs = losses_chs + loss_chs.detach() if len(loss_chs) > 0 else losses_chs + ctr_substeps += 1 if loss > 0.0 else 0 + + # normalize over forecast steps in window + losses_chs /= ctr_substeps if ctr_substeps > 0 else 1.0 + + # TODO: substep weight + loss_lfct = loss_lfct / (ctr_substeps if ctr_substeps > 0 else 1.0) + + return loss_lfct, losses_chs + + def compute_loss( + self, + preds: dict, + targets: dict, + ) -> LossValues: + """ + Computes the total loss for a given batch of predictions and corresponding + stream data. + + The computed loss is: + + Mean_{stream}( Mean_{fsteps}( Mean_{loss_fcts}( loss_fct( target, pred, weigths) ))) + + This method orchestrates the calculation of the overall loss by iterating through + different data streams, forecast steps, channels, and configured loss functions. + It applies weighting, handles NaN values through masking, and accumulates + detailed loss metrics for logging. + + Args: + preds: A nested list of prediction tensors. The outer list represents forecast steps, + the inner list represents streams. Each tensor contains predictions for that + step and stream. + streams_data: A nested list representing the input batch data. The outer list is for + batch items, the inner list for streams. Each element provides an object + (e.g., dataclass instance) containing target data and metadata. + + Returns: + A ModelLoss dataclass instance containing: + - loss: The loss for back-propagation. + - losses_all: A dictionary mapping stream names to a tensor of per-channel and + per-loss-function losses, normalized by non-empty targets/forecast steps. + - stddev_all: A dictionary mapping stream names to a tensor of mean standard deviations + of predictions for channels with statistical loss functions, normalized. + """ + + preds = preds.physical + streams_data = targets["physical"] + + # gradient loss + loss = torch.tensor(0.0, device=self.device, requires_grad=True) + # counter for non-empty targets + ctr_streams = 0 + + # initialize dictionaries for detailed loss tracking and standard deviation statistics + # create tensor for each stream + losses_all: dict[str, Tensor] = { + st.name: torch.zeros( + (len(st[str(self.stage) + "_target_channels"]), len(self.loss_fcts)), + device=self.device, + ) + for st in self.cf.streams + } + stddev_all: dict[str, Tensor] = { + st.name: torch.zeros(len(stat_loss_fcts), device=self.device) for st in self.cf.streams + } + + # TODO: iterate over batch dimension + i_batch = 0 + for i_stream_info, stream_info in enumerate(self.cf.streams): + # extract target tokens for current stream from the specified forecast offset onwards + targets = streams_data[i_batch][i_stream_info].target_tokens[self.cf.forecast_offset :] + + stream_data = streams_data[i_batch][i_stream_info] + + fstep_loss_weights = self._get_fstep_weights(len(targets)) + + loss_fsteps = torch.tensor(0.0, device=self.device, requires_grad=True) + ctr_fsteps = 0 + + stream_is_spoof = streams_data[i_batch][i_stream_info].is_spoof() + if stream_is_spoof: + spoof_weight = torch.tensor(0.0, device=self.device, requires_grad=False) + else: + spoof_weight = torch.tensor(1.0, device=self.device, requires_grad=False) + + for fstep, (target, fstep_weight) in enumerate( + zip(targets, fstep_loss_weights, strict=False) + ): + # skip if either target or prediction has no data points + pred = preds[fstep][i_stream_info] + if not (target.shape[0] > 0 and pred.shape[0] > 0): + continue + + # reshape prediction tensor to match target's dimensions: extract data/coords and + # remove token dimension if it exists. + # expected final shape of pred is [ensemble_size, num_samples, num_channels]. + pred = pred.reshape([pred.shape[0], *target.shape]) + assert pred.shape[1] > 0 + + # get weigths for current streams + stream_loss_weight, weights_channels = self._get_weights(stream_info) + + # get weights for locations + weights_locations = self._get_location_weights( + stream_info, stream_data, self.cf.forecast_offset, fstep + ) + + # get masks for sub-time steps + substep_masks = self._get_substep_masks(stream_info, fstep, stream_data) + + # accumulate loss from different loss functions + loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) + ctr_loss_fcts = 0 + for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts): + # loss for current loss function + loss_lfct, loss_lfct_chs = self._loss_per_loss_function( + loss_fct, + target, + pred, + substep_masks, + weights_channels, + weights_locations, + ) + losses_all[stream_info.name][:, i_lfct] += spoof_weight * loss_lfct_chs + + # Add the weighted and normalized loss from this loss function to the total + # batch loss + loss_fstep = loss_fstep + ( + loss_fct_weight * loss_lfct * stream_loss_weight * fstep_weight + ) + ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0 + + loss_fsteps = loss_fsteps + (loss_fstep / ctr_loss_fcts if ctr_loss_fcts > 0 else 0) + ctr_fsteps += 1 if ctr_loss_fcts > 0 else 0 + + loss = loss + ((spoof_weight * loss_fsteps) / (ctr_fsteps if ctr_fsteps > 0 else 1.0)) + ctr_streams += 1 if ctr_fsteps > 0 and not stream_is_spoof else 0 + + # normalize by forecast step + losses_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 + stddev_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 + + # replace channels without information by nan to exclude from further computations + losses_all[stream_info.name][losses_all[stream_info.name] == 0.0] = torch.nan + stddev_all[stream_info.name][stddev_all[stream_info.name] == 0.0] = torch.nan + + # normalize by all targets and forecast steps that were non-empty + # (with each having an expected loss of 1 for an uninitalized neural net) + loss = loss / ctr_streams + + # Return all computed loss components encapsulated in a ModelLoss dataclass + return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all) From ce6c735bd9c420078d0f2dbfbc23c403185a11b5 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 14 Nov 2025 16:56:51 +0100 Subject: [PATCH 031/344] Removing centroids options for embedding that was unused and should not be used. --- src/weathergen/datasets/masking.py | 12 ++- .../datasets/multi_stream_data_sampler.py | 8 +- src/weathergen/datasets/stream_data.py | 10 +-- src/weathergen/datasets/tokenizer_masking.py | 19 ++--- src/weathergen/model/embeddings.py | 28 ++----- src/weathergen/model/engines.py | 81 ++++++++++++------- 6 files changed, 78 insertions(+), 80 deletions(-) diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index 8ce2e039b..e3b0b8095 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -137,6 +137,7 @@ def _select_strategy(self): def mask_source_idxs( self, + stream_info, idxs_cells, idxs_cells_lens, rdata, @@ -155,7 +156,7 @@ def mask_source_idxs( if num_tokens == 0: return (self.mask_tokens, self.mask_channels) - # Clean strategy selection + # clean strategy selection self.current_strategy = self._select_strategy() # Set the masking rate. @@ -163,8 +164,10 @@ def mask_source_idxs( if self.current_strategy == "random": self.mask_tokens = self.rng.uniform(0, 1, num_tokens) < rate + elif self.current_strategy == "forecast": self.mask_tokens = np.ones(num_tokens, dtype=np.bool) + elif self.current_strategy == "healpix": # TODO: currently only for fixed level num_cells = len(idxs_cells_lens) @@ -174,13 +177,17 @@ def mask_source_idxs( (torch.ones(2, dtype=torch.bool) * (1 if m else 0)).to(torch.bool) for idxs_cell, m in zip(idxs_cells_lens, mask_cells, strict=False) ] + elif self.current_strategy == "cropping" or self.current_strategy == "causal": + pass + else: - assert False, f"Unsupported masking strategy: {self.current_strategy}" + assert False, f"Unsupported masking strategy: {self.current_strategy}." return (self.mask_tokens, self.mask_channels) def mask_targets_idxs( self, + stream_info, idxs_cells, idxs_cells_lens, rdata, @@ -192,6 +199,7 @@ def mask_targets_idxs( if self.current_strategy == "forecast": num_tokens = torch.tensor([len(t) for t in idxs_cells_lens]).sum().item() self.mask_tokens = np.ones(num_tokens, dtype=np.bool) + else: # masking strategies: target is complement of source # TODO: ensure/enforce that forecast_offset==0 diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index c3fda1ac1..1331abc0c 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -206,8 +206,6 @@ def __init__( self.shuffle = shuffle # TODO: remove options that are no longer supported self.input_window_steps = cf.input_window_steps - self.embed_local_coords = cf.embed_local_coords - self.embed_centroids_local_coords = cf.embed_centroids_local_coords self.sampling_rate_target = cf.sampling_rate_target self.batch_size = batch_size @@ -352,6 +350,8 @@ def __iter__(self): streams_data: list[StreamData] = [] + # tokenizer.generate_masks_for_sample() + # for all streams for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): stream_data = StreamData( @@ -373,14 +373,14 @@ def __iter__(self): stream_data.source_is_spoof = True # preprocess data for model input - (ss_cells, ss_lens, ss_centroids) = self.tokenizer.batchify_source( + (ss_cells, ss_lens) = self.tokenizer.batchify_source( stream_info, readerdata_to_torch(rdata), (time_win_source.start, time_win_source.end), ) # collect data for stream - stream_data.add_source(rdata, ss_lens, ss_cells, ss_centroids) + stream_data.add_source(rdata, ss_lens, ss_cells) # target diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index 55d79c34c..b8051d81b 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -62,7 +62,6 @@ def __init__(self, idx: int, forecast_steps: int, healpix_cells: int) -> None: self.source_tokens_cells = [] # length of source tokens per cell (without padding) self.source_tokens_lens = [] - self.source_centroids = [] # unprocessed source (for logging) self.source_raw = [] # auxiliary data for scatter operation that changes from stream-centric to cell-centric @@ -85,7 +84,6 @@ def to_device(self, device: str) -> None: """ self.source_tokens_cells = self.source_tokens_cells.to(device, non_blocking=True) - self.source_centroids = self.source_centroids.to(device, non_blocking=True) self.source_tokens_lens = self.source_tokens_lens.to(device, non_blocking=True) self.target_coords = [t.to(device, non_blocking=True) for t in self.target_coords] @@ -113,7 +111,6 @@ def add_empty_source(self, source: IOReaderData) -> None: self.source_raw += [source] self.source_tokens_lens += [torch.ones([self.healpix_cells], dtype=torch.int32)] self.source_tokens_cells += [torch.tensor([])] - self.source_centroids += [torch.tensor([])] def add_empty_target(self, fstep: int) -> None: """ @@ -137,9 +134,7 @@ def add_empty_target(self, fstep: int) -> None: np.array([], dtype="datetime64[ns]") for _ in range(self.healpix_cells) ] - def add_source( - self, ss_raw: IOReaderData, ss_lens: torch.tensor, ss_cells: list, ss_centroids: list - ) -> None: + def add_source(self, ss_raw: IOReaderData, ss_lens: torch.tensor, ss_cells: list) -> None: """ Add data for source for one input. @@ -149,8 +144,6 @@ def add_source( ss_lens : torch.tensor( number of healpix cells ) ss_cells : list( number of healpix cells ) [ torch.tensor( tokens per cell, token size, number of channels) ] - ss_centroids : list(number of healpix cells ) - [ torch.tensor( for source , 5) ] Returns ------- @@ -160,7 +153,6 @@ def add_source( self.source_raw = ss_raw self.source_tokens_lens = ss_lens self.source_tokens_cells = torch.stack(ss_cells) - self.source_centroids = torch.cat(ss_centroids) idx = torch.isnan(self.source_tokens_cells) self.source_tokens_cells[idx] = self.mask_value diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 6d95cbb86..2f959a2c7 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -57,8 +57,10 @@ def batchify_source( tok = tokenize_spacetime if stream_info.get("tokenize_spacetime", False) else tokenize_space idxs_cells, idxs_cells_lens = tok(rdata, token_size, self.hl_source, pad_tokens=True) + # select strategy from XXX depending on stream and if student or teacher + (mask_tokens, mask_channels) = self.masker.mask_source_idxs( - idxs_cells, idxs_cells_lens, rdata + stream_info, idxs_cells, idxs_cells_lens, rdata ) source_tokens_cells, source_tokens_lens = tokenize_apply_mask_source( @@ -73,13 +75,7 @@ def batchify_source( encode_times_source, ) - # if source_tokens_lens.sum() > 0: - # source_centroids = self.compute_source_centroids(source_tokens_cells) - # else: - # TODO: remove completely? - source_centroids = [torch.tensor([])] - - return (source_tokens_cells, source_tokens_lens, source_centroids) + return (source_tokens_cells, source_tokens_lens) def batchify_target( self, @@ -95,13 +91,8 @@ def batchify_target( idxs_cells, idxs_cells_lens = tok(rdata, token_size, self.hl_source, pad_tokens=False) (mask_tokens, mask_channels) = self.masker.mask_targets_idxs( - idxs_cells, idxs_cells_lens, rdata + stream_info, idxs_cells, idxs_cells_lens, rdata ) - # mask_tokens = ~self.mask_tokens - # # TODO - # # mask_channels = ~self.mask_channels if self.mask_channels is not None - # # else self.mask_channels - # mask_channels = self.mask_channels data, datetimes, coords, coords_local, coords_per_cell = tokenize_apply_mask_target( self.hl_target, diff --git a/src/weathergen/model/embeddings.py b/src/weathergen/model/embeddings.py index c9a7b456c..0925c0c50 100644 --- a/src/weathergen/model/embeddings.py +++ b/src/weathergen/model/embeddings.py @@ -32,7 +32,6 @@ def __init__( num_heads, dropout_rate=0.0, norm_type="LayerNorm", - embed_size_centroids=64, unembed_mode="full", stream_name="stream_embed", ): @@ -57,7 +56,6 @@ def __init__( self.dim_out = dim_out self.num_blocks = num_blocks self.num_heads = num_heads - self.embed_size_centroids = embed_size_centroids self.unembed_mode = unembed_mode norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm @@ -90,14 +88,11 @@ def __init__( self.ln_final = norm(num_channels * self.dim_embed, eps=1e-03) self.unembed = torch.nn.Linear( num_channels * self.dim_embed, - self.num_tokens * self.dim_out - embed_size_centroids, + self.num_tokens * self.dim_out, ) elif self.unembed_mode == "block": - # modify embed_size_centroids to ensure no additional padding is needed - rem = (self.num_tokens * self.dim_out - embed_size_centroids) % num_channels - embed_size_centroids += rem - dim_out = (self.num_tokens * self.dim_out - embed_size_centroids) // num_channels + dim_out = (self.num_tokens * self.dim_out) // num_channels self.unembed = torch.nn.ModuleList( [torch.nn.Linear(dim_embed, dim_out) for _ in range(num_channels)] # [ @@ -116,7 +111,6 @@ def __init__( raise ValueError(f"Unknown unembed mode: {unembed_mode}") elif mode == "columns": - assert embed_size_centroids == 0 self.embed = torch.nn.Linear(self.dim_in, self.dim_embed) assert self.unembed_mode == "block" # only supported mode at the moment @@ -125,7 +119,7 @@ def __init__( self.out_pad = torch.nn.Parameter(torch.zeros(self.pad), requires_grad=False) self.unembed = torch.nn.Linear( self.dim_embed, - self.num_tokens * ((self.dim_out - embed_size_centroids) // token_size), + self.num_tokens * (self.dim_out // token_size), ) self.ln_final = norm(dim_out, eps=1e-6) @@ -140,9 +134,8 @@ def __init__( raise ValueError(f"Unknown mode: {mode}") self.dropout_final = torch.nn.Dropout(0.1) - self.embed_centroids = torch.nn.Linear(5, embed_size_centroids) - def forward_channels(self, x_in, centroids): + def forward_channels(self, x_in): peh = positional_encoding_harmonic # embed provided input data @@ -163,11 +156,6 @@ def forward_channels(self, x_in, centroids): else: raise ValueError(f"Unknown unembed mode: {self.unembed_mode}") - # append centroids - if self.embed_size_centroids > 0: - out = torch.cat([out, self.embed_centroids(centroids)], -1) - # if self.embed_size_centroids==0 and self.dim_out is not divisible by #channels with - # unembed_mode block then we need to pad to have the expected output shape if out.shape[-1] < self.dim_out: out = torch.nn.functional.pad(out, [0, self.dim_out - out.shape[-1]], value=0.0) # final reshape @@ -175,7 +163,7 @@ def forward_channels(self, x_in, centroids): return out - def forward_columns(self, x_in, centroids): + def forward_columns(self, x_in): # embed provided input data x = positional_encoding_harmonic(checkpoint(self.embed, x_in, use_reentrant=False)) @@ -192,11 +180,11 @@ def forward_columns(self, x_in, centroids): return out.to(torch.float16) - def forward(self, x_in, centroids): + def forward(self, x_in): if self.mode == "channels": - return self.forward_channels(x_in, centroids) + return self.forward_channels(x_in) elif self.mode == "columns": - return self.forward_columns(x_in, centroids) + return self.forward_columns(x_in) else: raise ValueError(f"Unknown mode {self.mode}") diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 7359d1403..f9fa1598c 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -47,7 +47,7 @@ def __init__(self, cf: Config, sources_size) -> None: for i, si in enumerate(self.cf.streams): stream_name = si.get("name", i) - if si.get("diagnostic", False) or self.sources_size[i] == 0: + if "diagnostic" in si and si["diagnostic"]: self.embeds.append(torch.nn.Identity()) continue @@ -64,7 +64,6 @@ def __init__(self, cf: Config, sources_size) -> None: num_heads=si["embed"]["num_heads"], dropout_rate=self.cf.embed_dropout_rate, norm_type=self.cf.norm_type, - embed_size_centroids=self.cf.embed_size_centroids, unembed_mode=self.cf.embed_unembed_mode, stream_name=stream_name, ) @@ -107,7 +106,7 @@ def forward(self, streams_data, pe_embed, dtype, device): # create full scatter index # (there's no broadcasting which is likely highly inefficient) idxs = idxs.unsqueeze(1).repeat((1, self.cf.ae_local_dim_embed)) - x_embed = embed(s.source_tokens_cells, s.source_centroids).flatten(0, 1) + x_embed = embed(s.source_tokens_cells).flatten(0, 1) # there's undocumented limitation in flash_attn that will make embed fail if # #tokens is too large; code below is a work around # x_embed = torch.cat( @@ -197,32 +196,35 @@ def __init__(self, cf: Config) -> None: attention_dtype=get_dtype(self.cf.attention_dtype), ) ) - self.ae_adapter.append( - MLP( - self.cf.ae_global_dim_embed, - self.cf.ae_global_dim_embed, - with_residual=True, - dropout_rate=self.cf.ae_adapter_dropout_rate, - norm_type=self.cf.norm_type, - norm_eps=self.cf.mlp_norm_eps, + + ae_adapter_num_blocks = cf.get("ae_adapter_num_blocks", 2) + for _ in range(ae_adapter_num_blocks - 1): + self.ae_adapter.append( + MLP( + self.cf.ae_global_dim_embed, + self.cf.ae_global_dim_embed, + with_residual=True, + dropout_rate=self.cf.ae_adapter_dropout_rate, + norm_type=self.cf.norm_type, + norm_eps=self.cf.mlp_norm_eps, + ) ) - ) - self.ae_adapter.append( - MultiCrossAttentionHeadVarlenSlicedQ( - self.cf.ae_global_dim_embed, - self.cf.ae_local_dim_embed, - num_slices_q=self.cf.ae_local_num_queries, - dim_head_proj=self.cf.ae_adapter_embed, - num_heads=self.cf.ae_adapter_num_heads, - with_residual=self.cf.ae_adapter_with_residual, - with_qk_lnorm=self.cf.ae_adapter_with_qk_lnorm, - dropout_rate=self.cf.ae_adapter_dropout_rate, - with_flash=self.cf.with_flash_attention, - norm_type=self.cf.norm_type, - norm_eps=self.cf.norm_eps, - attention_dtype=get_dtype(self.cf.attention_dtype), + self.ae_adapter.append( + MultiCrossAttentionHeadVarlenSlicedQ( + self.cf.ae_global_dim_embed, + self.cf.ae_local_dim_embed, + num_slices_q=self.cf.ae_local_num_queries, + dim_head_proj=self.cf.ae_adapter_embed, + num_heads=self.cf.ae_adapter_num_heads, + with_residual=self.cf.ae_adapter_with_residual, + with_qk_lnorm=self.cf.ae_adapter_with_qk_lnorm, + dropout_rate=self.cf.ae_adapter_dropout_rate, + with_flash=self.cf.with_flash_attention, + norm_type=self.cf.norm_type, + norm_eps=self.cf.norm_eps, + attention_dtype=get_dtype(self.cf.attention_dtype), + ) ) - ) def forward(self, tokens_c, tokens_global_c, q_cells_lens_c, cell_lens_c, use_reentrant): for block in self.ae_adapter: @@ -299,6 +301,10 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: ) ) + self.ae_global_blocks.append( + torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) + ) + def forward(self, tokens, use_reentrant): for block in self.ae_global_blocks: tokens = checkpoint(block, tokens, use_reentrant=use_reentrant) @@ -333,7 +339,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: with_qk_lnorm=self.cf.fe_with_qk_lnorm, with_flash=self.cf.with_flash_attention, norm_type=self.cf.norm_type, - dim_aux=1, + dim_aux=(1 if cf.forecast_with_step_conditioning else 0), norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), ) @@ -349,7 +355,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: with_qk_lnorm=self.cf.fe_with_qk_lnorm, with_flash=self.cf.with_flash_attention, norm_type=self.cf.norm_type, - dim_aux=1, + dim_aux=(1 if cf.forecast_with_step_conditioning else 0), norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), ) @@ -367,6 +373,10 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: ) ) + self.fe_blocks.append( + torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) + ) + def init_weights_final(m): if isinstance(m, torch.nn.Linear): torch.nn.init.normal_(m.weight, mean=0, std=0.001) @@ -377,11 +387,20 @@ def init_weights_final(m): block.apply(init_weights_final) def forward(self, tokens, fstep): + # predict residual to last time step if requested + forecast_residual = self.cf.get("forecast_residual", False) + if forecast_residual: + tokens_in = tokens + + # aux_info is forecast step, if not disabled with cf.forecast_with_step_conditioning aux_info = torch.tensor([fstep], dtype=torch.float32, device="cuda") for block in self.fe_blocks: - tokens = checkpoint(block, tokens, aux_info, use_reentrant=False) + if type(block) is torch.nn.LayerNorm: + tokens = block(tokens) + else: + tokens = checkpoint(block, tokens, aux_info, use_reentrant=False) - return tokens + return tokens if not forecast_residual else (tokens_in + tokens) class EnsPredictionHead(torch.nn.Module): From 8fa544dc9db1d7216ca56558ec378a3bf5dc8e9d Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 14 Nov 2025 20:43:57 +0100 Subject: [PATCH 032/344] Removed unused parameters --- config/default_config.yml | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index fb01e5aba..b5e2eef4d 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -1,9 +1,6 @@ streams_directory: "./config/streams/era5_1deg/" embed_orientation: "channels" -embed_local_coords: True -embed_centroids_local_coords: False -embed_size_centroids: 0 embed_unembed_mode: "block" embed_dropout_rate: 0.1 @@ -42,7 +39,7 @@ pred_mlp_adaln: True # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -forecast_offset : 0 +forecast_offset : 1 forecast_delta_hrs: 0 forecast_steps: 0 forecast_policy: null @@ -91,18 +88,16 @@ validate_with_ema: True ema_ramp_up_ratio: 0.09 ema_halflife_in_thousands: 1e-3 -# training mode: "forecast" or "masking" (masked token modeling) -# for "masking" to train with auto-encoder mode, forecast_offset should be 0 -training_mode: "masking" +# include a masking strategy here, currently only supporting "random", "block", "healpix", "channel", "causal" and "combination" +masking_strategy: "forecast" + +# include a masking strategy here, currently only supporting "random", "block", "healpix", "channel", "causal" and "combination" +masking_strategy: "forecast" # masking rate when training mode is "masking"; ignored in foreacast mode masking_rate: 0.6 # sample the masking rate (with normal distribution centered at masking_rate) # note that a sampled masking rate leads to varying requirements masking_rate_sampling: True -# sample a subset of all target points, useful e.g. to reduce memory requirements (also can specify per-stream) -sampling_rate_target: 1.0 -# include a masking strategy here, currently only supporting "random", "block", "healpix", "channel", "causal" and "combination" -masking_strategy: "random" # masking_strategy_config is a dictionary of additional parameters for the masking strategy # required for "healpix" and "channel" masking strategies # "healpix": requires healpix mask level to be specified with `hl_mask` @@ -135,12 +130,13 @@ norm_type: "LayerNorm" nn_module: "te" log_grad_norms: False -start_date: 197901010000 +# start_date: 197901010000 +start_date: 201401010000 end_date: 202012310000 start_date_val: 202101010000 end_date_val: 202201010000 -len_hrs: 6 -step_hrs: 6 +len_hrs: 3 +step_hrs: 3 input_window_steps: 1 val_initial: False From d7b326ba72a31d52939b5434861a1bf2b29cd8f9 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Fri, 14 Nov 2025 23:52:13 +0100 Subject: [PATCH 033/344] fixed trainer for multiple terms in losses_all, still need to fix logging --- src/weathergen/train/trainer.py | 56 +++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 8cf2c067a..63b1d07d5 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -640,17 +640,23 @@ def train(self, epoch): self.world_size_original * self.cf.batch_size_per_gpu, ) - self.loss_unweighted_hist += [loss_values.losses_all] + if bidx == 0: + self.loss_unweighted_hist = {k: [] for k in loss_values.loss_terms.keys()} + self.stdev_unweighted_hist = {k: [] for k in loss_values.loss_terms.keys()} + self.loss_model_hist = [] + for name, loss_terms in loss_values.loss_terms.items(): + self.loss_unweighted_hist[name].append(loss_terms.losses_all) + self.stdev_unweighted_hist[name].append(loss_terms.stddev_all) self.loss_model_hist += [loss_values.loss.item()] - self.stdev_unweighted_hist += [loss_values.stddev_all] perf_gpu, perf_mem = self.get_perf() self.perf_gpu = ddp_average(torch.tensor([perf_gpu], device=self.device)).item() self.perf_mem = ddp_average(torch.tensor([perf_mem], device=self.device)).item() - self._log_terminal(bidx, epoch, TRAIN) - if bidx % self.train_log_freq.metrics == 0: - self._log(TRAIN) + # NEED TO FIX LOGGING + # self._log_terminal(bidx, epoch, TRAIN) + # if bidx % self.train_log_freq.metrics == 0: + # self._log(TRAIN) # save model checkpoint (with designation _latest) if bidx % self.train_log_freq.checkpoint == 0 and bidx > 0: @@ -665,7 +671,6 @@ def validate(self, epoch): self.model.eval() dataset_val_iter = iter(self.data_loader_validation) - self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist = [], [], [] with torch.no_grad(): # print progress bar but only in interactive mode, i.e. when without ddp @@ -730,14 +735,22 @@ def validate(self, epoch): sample_idxs, ) - self.loss_unweighted_hist += [loss_values.losses_all] + self.loss_unweighted_hist += [loss_values.loss_terms] + self.loss_model_hist += [loss_values.loss.item()] + if bidx == 0: + self.loss_unweighted_hist = {k: [] for k in loss_values.loss_terms.keys()} + self.stdev_unweighted_hist = {k: [] for k in loss_values.loss_terms.keys()} + self.loss_model_hist = [] + for name, loss_terms in loss_values.loss_terms.items(): + self.loss_unweighted_hist[name].append(loss_terms.losses_all) + self.stdev_unweighted_hist[name].append(loss_terms.stddev_all) self.loss_model_hist += [loss_values.loss.item()] - self.stdev_unweighted_hist += [loss_values.stddev_all] pbar.update(self.cf.batch_size_validation_per_gpu) - self._log_terminal(bidx, epoch, VAL) - self._log(VAL) + # NEED TO FIX LOGGING + # self._log_terminal(bidx, epoch, VAL) + # self._log(VAL) # avoid that there is a systematic bias in the validation subset self.dataset_val.advance() @@ -961,21 +974,24 @@ def _prepare_losses_for_logging( stddev_all (dict[str, torch.Tensor]): Dictionary mapping each stream name to its per-channel standard deviation tensor. """ - losses_all: dict[str, Tensor] = {} - stddev_all: dict[str, Tensor] = {} + losses_all: dict[dict[str, Tensor]] = {} + stddev_all: dict[dict[str, Tensor]] = {} # Make list of losses into a tensor. This is individual tensor per rank real_loss = torch.tensor(self.loss_model_hist, device=self.device) # Gather all tensors from all ranks into a list and stack them into one tensor again real_loss = torch.cat(all_gather_vlen(real_loss)) - for stream in self.cf.streams: # Loop over all streams - stream_hist = [losses_all[stream.name] for losses_all in self.loss_unweighted_hist] - stream_all = torch.stack(stream_hist).to(torch.float64) - losses_all[stream.name] = torch.cat(all_gather_vlen(stream_all)) - stream_hist = [stddev_all[stream.name] for stddev_all in self.stdev_unweighted_hist] - stream_all = torch.stack(stream_hist).to(torch.float64) - stddev_all[stream.name] = torch.cat(all_gather_vlen(stream_all)) + for name in self.loss_unweighted_hist.keys(): + losses_all[name] = {} + stddev_all[name] = {} + for stream in self.cf.streams: # Loop over all streams + stream_hist = [losses[stream.name] for losses in self.loss_unweighted_hist[name]] + stream_all = torch.stack(stream_hist).to(torch.float64) + losses_all[name][stream.name] = torch.cat(all_gather_vlen(stream_all)) + stream_hist = [stddevs[stream.name] for stddevs in self.stdev_unweighted_hist[name]] + stream_all = torch.stack(stream_hist).to(torch.float64) + stddev_all[name][stream.name] = torch.cat(all_gather_vlen(stream_all)) return real_loss, losses_all, stddev_all @@ -1010,7 +1026,7 @@ def _log(self, stage: Stage): self.perf_mem, ) - self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist = [], [], [] + self.loss_unweighted_hist, self.loss_model_hist = [], [] def _get_tensor_item(self, tensor): """ From 5d127bfeded12e1a55fbc4edfe614e5c0ba445e6 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Sun, 16 Nov 2025 17:01:08 +0100 Subject: [PATCH 034/344] Inversion of target output ordering to match input one in forcast mode. Unclear how to deal with it with MTM --- src/weathergen/datasets/masking.py | 7 ++++++- src/weathergen/datasets/multi_stream_data_sampler.py | 4 ++-- src/weathergen/datasets/stream_data.py | 5 +++++ src/weathergen/datasets/tokenizer_masking.py | 4 ++-- 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index e3b0b8095..6d91b569d 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -194,12 +194,17 @@ def mask_targets_idxs( ) -> (torch.Tensor, torch.Tensor): # mask_source_idxs is assert (self.mask_tokens is not None) or (self.mask_tokens is not None) + idxs_ord_inv = torch.tensor([], dtype=torch.int64) # TODO: better handling of if statement if self.current_strategy == "forecast": num_tokens = torch.tensor([len(t) for t in idxs_cells_lens]).sum().item() self.mask_tokens = np.ones(num_tokens, dtype=np.bool) + # inverse map for reordering to output data points in same order as input + idxs_ord = torch.cat([t for tt in idxs_cells for t in tt]) + idxs_ord_inv = torch.argsort(idxs_ord) + else: # masking strategies: target is complement of source # TODO: ensure/enforce that forecast_offset==0 @@ -210,7 +215,7 @@ def mask_targets_idxs( # TODO: self.mask_tokens seems brittle in terms of naming - return (self.mask_tokens, self.mask_channels) + return (self.mask_tokens, self.mask_channels, idxs_ord_inv) def mask_source( self, diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 1331abc0c..55e4124c3 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -408,14 +408,14 @@ def __iter__(self): stream_data.target_is_spoof = True # preprocess data for model input - (tt_cells, tt_t, tt_c, tc, tc_l) = self.tokenizer.batchify_target( + (tt_cells, tt_t, tt_c, tc, tc_l, idxs_inv) = self.tokenizer.batchify_target( stream_info, self.sampling_rate_target, readerdata_to_torch(rdata), (time_win_target.start, time_win_target.end), ) - stream_data.add_target(fstep, tt_cells, tc, tc_l, tt_c, tt_t) + stream_data.add_target(fstep, tt_cells, tc, tc_l, tt_c, tt_t, idxs_inv) # merge inputs for sources and targets for current stream streams_data += [stream_data] diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index b8051d81b..18f2ac046 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -57,6 +57,7 @@ def __init__(self, idx: int, forecast_steps: int, healpix_cells: int) -> None: self.target_tokens_lens = [ torch.tensor([0 for _ in range(self.healpix_cells)]) for _ in range(forecast_steps + 1) ] + self.idxs_inv = [torch.tensor([], dtype=torch.int64) for _ in range(forecast_steps + 1)] # source tokens per cell self.source_tokens_cells = [] @@ -165,6 +166,7 @@ def add_target( target_coords_per_cell: torch.tensor, target_coords_raw: torch.tensor, times_raw: torch.tensor, + idxs_inv: torch.tensor, ) -> None: """ Add data for target for one input. @@ -184,6 +186,8 @@ def add_target( target_times : list( number of healpix cells) [ torch.tensor( points per cell) ] absolute target times + idxs_inv: + Indices to reorder targets back to order in input Returns ------- @@ -195,6 +199,7 @@ def add_target( self.target_coords_lens[fstep] = target_coords_per_cell self.target_times_raw[fstep] = times_raw self.target_coords_raw[fstep] = target_coords_raw + self.idxs_inv[fstep] = idxs_inv def target_empty(self) -> bool: """ diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 2f959a2c7..dcf81d394 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -90,7 +90,7 @@ def batchify_target( tok = tokenize_spacetime if stream_info.get("tokenize_spacetime", False) else tokenize_space idxs_cells, idxs_cells_lens = tok(rdata, token_size, self.hl_source, pad_tokens=False) - (mask_tokens, mask_channels) = self.masker.mask_targets_idxs( + (mask_tokens, mask_channels, idxs_ord_inv) = self.masker.mask_targets_idxs( stream_info, idxs_cells, idxs_cells_lens, rdata ) @@ -111,7 +111,7 @@ def batchify_target( # TODO, TODO, TODO: max_num_targets # max_num_targets = stream_info.get("max_num_targets", -1) - return (data, datetimes, coords, coords_local, coords_per_cell) + return (data, datetimes, coords, coords_local, coords_per_cell, idxs_ord_inv) def sample_tensors_uniform_vectorized( self, tensor_list: list, lengths: list, max_total_points: int From 3ffdc6093aacebfd073b71ee23e15ac92fb2ac09 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Mon, 17 Nov 2025 12:02:04 +0100 Subject: [PATCH 035/344] fix _log_terminal --- src/weathergen/train/trainer.py | 44 +++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 63b1d07d5..b1e5ad240 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -654,7 +654,7 @@ def train(self, epoch): self.perf_mem = ddp_average(torch.tensor([perf_mem], device=self.device)).item() # NEED TO FIX LOGGING - # self._log_terminal(bidx, epoch, TRAIN) + self._log_terminal(bidx, epoch, TRAIN) # if bidx % self.train_log_freq.metrics == 0: # self._log(TRAIN) @@ -749,7 +749,7 @@ def validate(self, epoch): pbar.update(self.cf.batch_size_validation_per_gpu) # NEED TO FIX LOGGING - # self._log_terminal(bidx, epoch, VAL) + self._log_terminal(bidx, epoch, VAL) # self._log(VAL) # avoid that there is a systematic bias in the validation subset @@ -982,16 +982,20 @@ def _prepare_losses_for_logging( # Gather all tensors from all ranks into a list and stack them into one tensor again real_loss = torch.cat(all_gather_vlen(real_loss)) - for name in self.loss_unweighted_hist.keys(): - losses_all[name] = {} - stddev_all[name] = {} + for calc_name in self.loss_unweighted_hist.keys(): + losses_all[calc_name] = {} + stddev_all[calc_name] = {} for stream in self.cf.streams: # Loop over all streams - stream_hist = [losses[stream.name] for losses in self.loss_unweighted_hist[name]] + stream_hist = [ + losses[stream.name] for losses in self.loss_unweighted_hist[calc_name] + ] stream_all = torch.stack(stream_hist).to(torch.float64) - losses_all[name][stream.name] = torch.cat(all_gather_vlen(stream_all)) - stream_hist = [stddevs[stream.name] for stddevs in self.stdev_unweighted_hist[name]] + losses_all[calc_name][stream.name] = torch.cat(all_gather_vlen(stream_all)) + stream_hist = [ + stddevs[stream.name] for stddevs in self.stdev_unweighted_hist[calc_name] + ] stream_all = torch.stack(stream_hist).to(torch.float64) - stddev_all[name][stream.name] = torch.cat(all_gather_vlen(stream_all)) + stddev_all[calc_name][stream.name] = torch.cat(all_gather_vlen(stream_all)) return real_loss, losses_all, stddev_all @@ -1059,11 +1063,12 @@ def _log_terminal(self, bidx: int, epoch: int, stage: Stage): logger.info( f"validation ({self.cf.run_id}) : {epoch:03d} : {avg_loss.nanmean().item()}" ) - for _, st in enumerate(self.cf.streams): - logger.info( - "{}".format(st["name"]) - + f" : {losses_all[st['name']].nanmean():0.4E} \t", - ) + for calc_name, losses in losses_all.items(): + for _, st in enumerate(self.cf.streams): + logger.info( + f"{calc_name}.{st['name']}" + + f" : {losses[st['name']].nanmean():0.4E} \t", + ) logger.info("\n") elif stage == TRAIN: @@ -1080,11 +1085,12 @@ def _log_terminal(self, bidx: int, epoch: int, stage: Stage): pstr += f"s/sec={(print_freq * self.cf.batch_size_per_gpu) / dt:.3f})" logger.info(pstr) logger.info("\t") - for _, st in enumerate(self.cf.streams): - logger.info( - "{}".format(st["name"]) - + f" : {losses_all[st['name']].nanmean():0.4E} \t", - ) + for calc_name, losses in losses_all.items(): + for _, st in enumerate(self.cf.streams): + logger.info( + f"{calc_name}.{st['name']}" + + f" : {losses[st['name']].nanmean():0.4E} \t", + ) logger.info("\n") self.t_start = time.time() From debbb8fdf8d162b9a771e36a10459e47e96d6247 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Mon, 17 Nov 2025 12:28:07 +0100 Subject: [PATCH 036/344] Changes to prepare_logging to apply index inversion --- src/weathergen/train/trainer.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 3d847a671..b24f944c9 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -518,6 +518,14 @@ def _prepare_logging( ] for fstep in range(forecast_offset, forecast_offset + forecast_steps + 1) ] + # inverse indices + idxs_inv_rt = [ + [ + torch.cat([t[i].idxs_inv[fstep] for t in streams_data]) + for i in range(len(self.cf.streams)) + ] + for fstep in range(forecast_offset, forecast_offset + forecast_steps + 1) + ] # assert len(targets_rt) == len(preds) and len(preds) == len(self.cf.streams) fsteps = len(targets_rt) @@ -533,6 +541,7 @@ def _prepare_logging( for fstep in range(len(targets_rt)): for i_strm, target in enumerate(targets_rt[fstep]): pred = preds[fstep][i_strm] + idxs_inv = idxs_inv_rt[fstep][i_strm] if not (target.shape[0] > 0 and pred.shape[0] > 0): continue @@ -548,6 +557,15 @@ def _prepare_logging( targets_lens[fstep][i_strm] += [target.shape[0]] dn_data = self.dataset_val.denormalize_target_channels + # reorder so that output order of target points matches input when reading + # (tokenization and masking changes this order) + # TODO: does this work with batch_size > 1 + if len(idxs_inv) > 0: + pred = pred[:, idxs_inv] + target = target[idxs_inv] + targets_coords_raw[fstep][i_strm] = targets_coords_raw[fstep][i_strm][idxs_inv] + targets_times_raw[fstep][i_strm] = targets_times_raw[fstep][i_strm][idxs_inv] + f32 = torch.float32 preds_all[fstep][i_strm] += [ np.asarray(dn_data(i_strm, pred.to(f32)).detach().cpu()) From ae5a2e6d574ef3391d878573bcab781a4fb3c3b4 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Mon, 17 Nov 2025 11:54:18 +0000 Subject: [PATCH 037/344] added file with ModelBatch and SampleMetadata dataclasses --- src/weathergen/datasets/inputs_metadata.py | 124 +++++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 src/weathergen/datasets/inputs_metadata.py diff --git a/src/weathergen/datasets/inputs_metadata.py b/src/weathergen/datasets/inputs_metadata.py new file mode 100644 index 000000000..5a30b9462 --- /dev/null +++ b/src/weathergen/datasets/inputs_metadata.py @@ -0,0 +1,124 @@ +""" +Data structures for student-teacher multi-view training. + +Provides clean separation between: + - Model data (StreamData objects containing tensors) + - View metadata (spatial masks, strategies, relationships) +""" + +from dataclasses import dataclass, field +import numpy as np +from typing import Optional +from weathergen.datasets.stream_data import StreamData +import torch + + +# TODO: Add a store for a random number for diffusion +# TODO: GetTimestep to get the timestep +# TODO: GetData: get the streamdata +# TODO: GetMetaData: then this gets the right rn for the timestep! + +@dataclass +class SampleMetadata: + """ + Metadata describing how a view was generated. + + This captures the spatial selection (which cells/tokens were kept), + the strategy used (random, healpix, etc.), and hierarchical parameters. + + Attributes: + view_id: Unique identifier (e.g., "teacher_global", "student_local_0") + keep_mask: Boolean array [num_healpix_cells] at data level indicating kept cells + strategy: Name of selection strategy ("random", "healpix_level_2", etc.) + healpix_level: HEALPix level for hierarchical selection (None if not applicable) + rate: Fraction of data kept (e.g., 0.5 = 50% kept); None if fixed count + parent_view_id: ID of the parent view this is a subset of (None for teacher) + """ + view_id: str + keep_mask: np.ndarray # [num_cells] bool at data level + strategy: str # e.g., "random", "healpix_level_2" + healpix_level: Optional[int] + rate: Optional[float] + parent_view_id: Optional[str] = None # For students: which teacher they belong to + + +# TODO: This doesn't handle the masking case, and we probably want it to, +# where the model_inputs are the correct data for the masked source (and target?). Or target becomes the target? +# Also should this model batch contain the source_cell_lens and target_coords_idx? +# Every sample is n different [streams]...each view is a different dictionary corresponding to one model input +# to get epsilon in there... +# batches is for parallelism, but needs to all be in a tensor... [b, n, dim_embedding]? [b x n, dim_embedding] + + +# NOTE: this only stores the student source_cell_lens and target_coords_idx, +# because the teacher ones are already provided separately in (model_batches, source_cell_lens, target_coords_idx, forecast_dt) + # ^^^^^^ teacher ones ^^^^^^ +# However, we should probably store them all here for consistency. This needs changes to the model, so not done now. +# The forecast_dt is provided separately? + +@dataclass +class ModelBatch: + """ + Container for all data and metadata for one training batch. + + - In forecast/masking: model_inputs=[streams_data], targets=[] + - In student_teacher: model_inputs=[student_views], targets=[teacher_streams] + + Attributes: + model_inputs: List of student views, each containing StreamData for all streams + targets: List containing teacher view with StreamData for all streams + view_metadata: List of ViewMetadata describing each view (teacher + students) + batch_info: Optional dict with batch-level info (sample indices, forecast steps, etc.) + student_source_cell_lens: List of source cell lengths for each student view + student_target_coords_idx: List of target coordinate indices for each student view + """ + # TODO: for DINO we want two global views per-dataset sample + # TODO: we want the global' view in student, perhaps as the first, + # with some metadata saying it is a second global view + + model_inputs: list[list[any]] # [n_students][n_streams] + targets: list[list[any]] # [1][n_streams] (teacher) + view_metadata: list[ViewMetadata] + batch_info: Optional[dict] = field(default_factory=dict) + + # Offsets for student views (populated when needed for future student-teacher training) + # TODO: rename to model_input...source_cell/target_coords... NOTE: then there is a problem for target + student_source_cell_lens: Optional[list] = None # [n_students] each is a tensor + student_target_coords_idx: Optional[list] = None # [n_students] each is a list of lists + + # TODO fix this ridiculous naming + # Placeholders for having ModelBatch giving the full (StreamData, source_cell_lens, target_coords_idx) + teacher_source_cell_lens: torch.Tensor | None = None + teacher_target_coords_idx: list | None = None + + # TODO: add the timestep as an optional int for the model_inputs when we have multiple timesteps for the diffusion model... + # TODO add the forecast_dt as an optional int ? + + def to_device(self, device): + """Move all StreamData objects to the specified device.""" + for student_view in self.model_inputs: + for stream_data in student_view: + stream_data.to_device(device) + + for teacher_batch in self.targets: + for stream_data in teacher_batch: + stream_data.to_device(device) + + # Move student offsets if they exist + if self.student_source_cell_lens is not None: + self.student_source_cell_lens = [ + lens.to(device) if isinstance(lens, torch.Tensor) else lens + for lens in self.student_source_cell_lens + ] + + if self.student_target_coords_idx is not None: + # This is list[list[list[tensor]]], need to move all tensors + self.student_target_coords_idx = [ + [ + [t.to(device) if isinstance(t, torch.Tensor) else t for t in stream] + for stream in student_idx + ] + for student_idx in self.student_target_coords_idx + ] + + return self \ No newline at end of file From 7f3c71891344508fc017b39ce9d3ca9d6a8e9735 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Mon, 17 Nov 2025 14:51:01 +0100 Subject: [PATCH 038/344] Updating config to working version --- config/default_config.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index b5e2eef4d..680a6b7ab 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -44,6 +44,7 @@ forecast_delta_hrs: 0 forecast_steps: 0 forecast_policy: null forecast_att_dense_rate: 1.0 +forecast_with_step_conditioning: False fe_num_blocks: 0 fe_num_heads: 16 fe_dropout_rate: 0.1 @@ -88,13 +89,12 @@ validate_with_ema: True ema_ramp_up_ratio: 0.09 ema_halflife_in_thousands: 1e-3 -# include a masking strategy here, currently only supporting "random", "block", "healpix", "channel", "causal" and "combination" -masking_strategy: "forecast" - # include a masking strategy here, currently only supporting "random", "block", "healpix", "channel", "causal" and "combination" masking_strategy: "forecast" # masking rate when training mode is "masking"; ignored in foreacast mode masking_rate: 0.6 +# +sampling_rate_target: 1.0 # sample the masking rate (with normal distribution centered at masking_rate) # note that a sampled masking rate leads to varying requirements masking_rate_sampling: True @@ -135,8 +135,8 @@ start_date: 201401010000 end_date: 202012310000 start_date_val: 202101010000 end_date_val: 202201010000 -len_hrs: 3 -step_hrs: 3 +len_hrs: 6 +step_hrs: 6 input_window_steps: 1 val_initial: False From 694d948381a8fc678387215128134bebb249d9a4 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Mon, 17 Nov 2025 14:53:27 +0000 Subject: [PATCH 039/344] Encapsulated encoder and target encoding for latent diffusion model loss --- config/default_config.yml | 23 +++++++++++++++++-- src/weathergen/model/model.py | 2 +- .../train/target_and_aux_diffusion.py | 18 +++++++++++++++ .../train/target_and_aux_module_base.py | 6 +++++ src/weathergen/train/trainer.py | 7 ++++-- src/weathergen/train/trainer_base.py | 10 ++++---- 6 files changed, 57 insertions(+), 9 deletions(-) create mode 100644 src/weathergen/train/target_and_aux_diffusion.py diff --git a/config/default_config.yml b/config/default_config.yml index fb01e5aba..b46f457d1 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -91,9 +91,28 @@ validate_with_ema: True ema_ramp_up_ratio: 0.09 ema_halflife_in_thousands: 1e-3 -# training mode: "forecast" or "masking" (masked token modeling) +# training mode: "forecast" or "masking" (masked token modeling) or "student-teacher" # for "masking" to train with auto-encoder mode, forecast_offset should be 0 -training_mode: "masking" +training_mode: "forecast" +training_mode_config: { + "losses" : { + # LossLatentSSLStudentTeacher: { + # "iBOT": {'weight': 0.5, "out_dim": 65536, "n_register_tokens": 4, "student_temp": 0.1,"teacher_temp": 0.1, + # "teacher_style": "softmax_center", "center_momentum": 0.9}, + # "DINO": {'weight': 0.5, "out_dim": 65536, "n_register_tokens": 4, "student_temp": 0.1,"teacher_temp": 0.1, + # "teacher_style": "softmax_center", "center_momentum": 0.9}, + # "JEPA": {'weight': 0.5, "out_dim": 2048, "n_register_tokens": 4} } + LossLatentDiffusionForecastEngine: { + "MSE": {'weight': 1.0} + } + }, + "shared_heads": False, + "target_and_aux_calc": "DiffusionLatentTargetEncoder", + "teacher_model": {} +} +# training_mode_config: {"losses": {LossPhysical: [['mse', 1.0]],} +# } +validation_mode_config: {"losses": {LossPhysical: [['mse', 1.0]],}} # masking rate when training mode is "masking"; ignored in foreacast mode masking_rate: 0.6 # sample the masking rate (with normal distribution centered at masking_rate) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index ceead1b8d..f6ea79d8f 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -613,7 +613,7 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca (streams_data, _, target_coords_idxs) = batch - tokens, posteriors = self.encode(self, model_params=model_params, batch=batch) + tokens, posteriors = self.encode(model_params=model_params, batch=batch) # roll-out in latent space preds_all = [] diff --git a/src/weathergen/train/target_and_aux_diffusion.py b/src/weathergen/train/target_and_aux_diffusion.py new file mode 100644 index 000000000..620697474 --- /dev/null +++ b/src/weathergen/train/target_and_aux_diffusion.py @@ -0,0 +1,18 @@ +from typing import Any + +import torch + +from weathergen.train.target_and_aux_module_base import TargetAndAuxModuleBase + + +class DiffusionLatentTargetEncoder(TargetAndAuxModuleBase): + def __init__(self, model): + # Todo: make sure this is a frozen clone or forward without gradients in compute() + self.model = model + + def compute( + self, bidx, batch, model_params, model, forecast_offset, forecast_steps + ) -> tuple[Any, Any]: + with torch.no_grad(): + tokens, posteriors = self.model.encode(model_params=model_params, batch=batch) + return tokens, posteriors diff --git a/src/weathergen/train/target_and_aux_module_base.py b/src/weathergen/train/target_and_aux_module_base.py index 7cca5f7bc..224facd75 100644 --- a/src/weathergen/train/target_and_aux_module_base.py +++ b/src/weathergen/train/target_and_aux_module_base.py @@ -17,6 +17,9 @@ def update_state_post_opt_step(self, istep, batch, model, **kwargs) -> None: def compute(self, *args, **kwargs) -> tuple[Any, Any]: pass + def to_device(self, device): + pass + class IdentityTargetAndAux(TargetAndAuxModuleBase): def __init__(self, model, rng, config): @@ -33,3 +36,6 @@ def update_state_post_opt_step(self, istep, batch, model, **kwargs): def compute(self, istep, batch, *args, **kwargs): return batch[0], None + + def to_device(self, device): + return diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index ba6d442b3..bee6c288c 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -8,6 +8,7 @@ # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + import itertools import logging import re @@ -43,14 +44,14 @@ ) from weathergen.model.ema import EMAModel from weathergen.model.layers import MLP -from weathergen.model.model import get_model, Model, ModelParams +from weathergen.model.model import Model, ModelParams from weathergen.model.utils import freeze_weights from weathergen.train.loss_calculator import LossCalculator from weathergen.train.lr_scheduler import LearningRateScheduler from weathergen.train.trainer_base import TrainerBase, get_target_and_aux_calculator from weathergen.utils.distributed import all_gather_vlen, ddp_average, is_root from weathergen.utils.train_logger import TRAIN, VAL, Stage, TrainLogger -from weathergen.utils.utils import get_dtype, get_batch_size +from weathergen.utils.utils import get_batch_size, get_dtype from weathergen.utils.validation_io import write_output logger = logging.getLogger(__name__) @@ -352,6 +353,8 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None): batch_size=get_batch_size(cf, self.world_size_original), ) + self.target_and_aux_calculator.to_device(self.device) + # if with_fsdp then parameter count is unreliable if is_root() and not cf.with_fsdp and not cf.with_ddp: self.model.print_num_parameters() diff --git a/src/weathergen/train/trainer_base.py b/src/weathergen/train/trainer_base.py index c75e3ce7e..d9d3d5b33 100644 --- a/src/weathergen/train/trainer_base.py +++ b/src/weathergen/train/trainer_base.py @@ -17,11 +17,11 @@ import torch.multiprocessing from weathergen.common.config import Config -from weathergen.train.utils import str_to_tensor, tensor_to_str -from weathergen.utils.distributed import is_root - +from weathergen.train.target_and_aux_diffusion import DiffusionLatentTargetEncoder from weathergen.train.target_and_aux_module_base import IdentityTargetAndAux from weathergen.train.target_and_aux_ssl_teacher import EMATeacher +from weathergen.train.utils import str_to_tensor, tensor_to_str +from weathergen.utils.distributed import is_root PORT = 1345 @@ -174,10 +174,12 @@ def get_perf(self): # should be moved to its own file so as to prevent cyclical imports def get_target_and_aux_calculator(config, model, rng, batch_size, **kwargs): - target_and_aux_calc = config.get("target_and_aux_calc", None) + target_and_aux_calc = config.get("training_mode_config", None).get("target_and_aux_calc", None) if target_and_aux_calc is None or target_and_aux_calc == "identity": return IdentityTargetAndAux(model, rng, config) elif target_and_aux_calc == "EMATeacher": return EMATeacher(model, rng, kwargs["ema_model"], batch_size) + elif target_and_aux_calc == "DiffusionLatentTargetEncoder": + return DiffusionLatentTargetEncoder(model) else: raise NotImplementedError(f"{target_and_aux_calc} is not implemented") From beb4d6f07eeb0fa84e07116485a8d043695fd6cf Mon Sep 17 00:00:00 2001 From: Jubeku Date: Mon, 17 Nov 2025 17:26:25 +0100 Subject: [PATCH 040/344] fix logging --- config/default_config.yml | 2 +- .../loss_modules/loss_module_physical.py | 62 +++++++++++----- src/weathergen/train/trainer.py | 73 +++++++++++-------- src/weathergen/utils/train_logger.py | 49 ++++++------- 4 files changed, 107 insertions(+), 79 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index e99d9f423..c101521da 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -94,7 +94,7 @@ ema_halflife_in_thousands: 1e-3 # training mode: "forecast" or "masking" (masked token modeling) # for "masking" to train with auto-encoder mode, forecast_offset should be 0 training_mode: "masking" -training_mode_config: {"losses": {LossPhysical: {weight: 0.7, loss_fcts: [['mse', 1.0]]}, +training_mode_config: {"losses": {LossPhysical: {weight: 0.7, loss_fcts: [['mse', 0.8], ['mse', 0.2]]}, LossPhysicalTwo: {weight: 0.3, loss_fcts: [['mse', 1.0]]}, } } diff --git a/src/weathergen/train/loss_modules/loss_module_physical.py b/src/weathergen/train/loss_modules/loss_module_physical.py index 1e900f25f..817360706 100644 --- a/src/weathergen/train/loss_modules/loss_module_physical.py +++ b/src/weathergen/train/loss_modules/loss_module_physical.py @@ -50,7 +50,7 @@ def __init__( # Dynamically load loss functions based on configuration and stage self.loss_fcts = [ - [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] + [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w, name] for name, w in loss_fcts ] @@ -195,14 +195,19 @@ def compute_loss( # initialize dictionaries for detailed loss tracking and standard deviation statistics # create tensor for each stream losses_all: dict[str, Tensor] = { - st.name: torch.zeros( - (len(st[str(self.stage) + "_target_channels"]), len(self.loss_fcts)), + f"{self.name}.{st.name}.{loss_fct_name}": torch.zeros( + (len(st[str(self.stage) + "_target_channels"])), device=self.device, ) for st in self.cf.streams + for _, _, loss_fct_name in self.loss_fcts } stddev_all: dict[str, Tensor] = { - st.name: torch.zeros(len(stat_loss_fcts), device=self.device) for st in self.cf.streams + f"{self.name}.{st.name}.{loss_fct_name}": torch.zeros( + len(stat_loss_fcts), device=self.device + ) + for st in self.cf.streams + for _, _, loss_fct_name in self.loss_fcts } # TODO: iterate over batch dimension @@ -252,7 +257,7 @@ def compute_loss( # accumulate loss from different loss functions loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) ctr_loss_fcts = 0 - for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts): + for loss_fct, loss_fct_weight, loss_fct_name in self.loss_fcts: # loss for current loss function loss_lfct, loss_lfct_chs = self._loss_per_loss_function( loss_fct, @@ -262,7 +267,9 @@ def compute_loss( weights_channels, weights_locations, ) - losses_all[stream_info.name][:, i_lfct] += spoof_weight * loss_lfct_chs + losses_all[f"{self.name}.{stream_info.name}.{loss_fct_name}"] += ( + spoof_weight * loss_lfct_chs + ) # Add the weighted and normalized loss from this loss function to the total # batch loss @@ -278,12 +285,16 @@ def compute_loss( ctr_streams += 1 if ctr_fsteps > 0 and not stream_is_spoof else 0 # normalize by forecast step - losses_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 - stddev_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 + if ctr_fsteps > 0: + for _, _, loss_fct_name in self.loss_fcts: + losses_all[f"{self.name}.{stream_info.name}.{loss_fct_name}"] /= ctr_fsteps + stddev_all[f"{self.name}.{stream_info.name}.{loss_fct_name}"] /= ctr_fsteps # replace channels without information by nan to exclude from further computations - losses_all[stream_info.name][losses_all[stream_info.name] == 0.0] = torch.nan - stddev_all[stream_info.name][stddev_all[stream_info.name] == 0.0] = torch.nan + for _, _, loss_fct_name in self.loss_fcts: + key = f"{self.name}.{stream_info.name}.{loss_fct_name}" + losses_all[key][losses_all[key] == 0.0] = torch.nan + stddev_all[key][stddev_all[key] == 0.0] = torch.nan # normalize by all targets and forecast steps that were non-empty # (with each having an expected loss of 1 for an uninitalized neural net) @@ -319,7 +330,7 @@ def __init__( # Dynamically load loss functions based on configuration and stage self.loss_fcts = [ - [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] + [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w, name] for name, w in loss_fcts ] @@ -464,14 +475,19 @@ def compute_loss( # initialize dictionaries for detailed loss tracking and standard deviation statistics # create tensor for each stream losses_all: dict[str, Tensor] = { - st.name: torch.zeros( - (len(st[str(self.stage) + "_target_channels"]), len(self.loss_fcts)), + f"{self.name}.{st.name}.{loss_fct_name}": torch.zeros( + (len(st[str(self.stage) + "_target_channels"])), device=self.device, ) for st in self.cf.streams + for _, _, loss_fct_name in self.loss_fcts } stddev_all: dict[str, Tensor] = { - st.name: torch.zeros(len(stat_loss_fcts), device=self.device) for st in self.cf.streams + f"{self.name}.{st.name}.{loss_fct_name}": torch.zeros( + len(stat_loss_fcts), device=self.device + ) + for st in self.cf.streams + for _, _, loss_fct_name in self.loss_fcts } # TODO: iterate over batch dimension @@ -521,7 +537,7 @@ def compute_loss( # accumulate loss from different loss functions loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) ctr_loss_fcts = 0 - for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts): + for loss_fct, loss_fct_weight, loss_fct_name in self.loss_fcts: # loss for current loss function loss_lfct, loss_lfct_chs = self._loss_per_loss_function( loss_fct, @@ -531,7 +547,9 @@ def compute_loss( weights_channels, weights_locations, ) - losses_all[stream_info.name][:, i_lfct] += spoof_weight * loss_lfct_chs + losses_all[f"{self.name}.{stream_info.name}.{loss_fct_name}"] += ( + spoof_weight * loss_lfct_chs + ) # Add the weighted and normalized loss from this loss function to the total # batch loss @@ -547,12 +565,16 @@ def compute_loss( ctr_streams += 1 if ctr_fsteps > 0 and not stream_is_spoof else 0 # normalize by forecast step - losses_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 - stddev_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 + if ctr_fsteps > 0: + for _, _, loss_fct_name in self.loss_fcts: + losses_all[f"{self.name}.{stream_info.name}.{loss_fct_name}"] /= ctr_fsteps + stddev_all[f"{self.name}.{stream_info.name}.{loss_fct_name}"] /= ctr_fsteps # replace channels without information by nan to exclude from further computations - losses_all[stream_info.name][losses_all[stream_info.name] == 0.0] = torch.nan - stddev_all[stream_info.name][stddev_all[stream_info.name] == 0.0] = torch.nan + for _, _, loss_fct_name in self.loss_fcts: + key = f"{self.name}.{stream_info.name}.{loss_fct_name}" + losses_all[key][losses_all[key] == 0.0] = torch.nan + stddev_all[key][stddev_all[key] == 0.0] = torch.nan # normalize by all targets and forecast steps that were non-empty # (with each having an expected loss of 1 for an uninitalized neural net) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index b1e5ad240..d580816ef 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -641,12 +641,22 @@ def train(self, epoch): ) if bidx == 0: - self.loss_unweighted_hist = {k: [] for k in loss_values.loss_terms.keys()} - self.stdev_unweighted_hist = {k: [] for k in loss_values.loss_terms.keys()} + self.loss_unweighted_hist = { + calc_name: {loss_name: []} + for calc_name, calc_terms in loss_values.loss_terms.items() + for loss_name in calc_terms.losses_all.keys() + } + self.stdev_unweighted_hist = { + calc_name: {loss_name: []} + for calc_name, calc_terms in loss_values.loss_terms.items() + for loss_name in calc_terms.stddev_all.keys() + } self.loss_model_hist = [] - for name, loss_terms in loss_values.loss_terms.items(): - self.loss_unweighted_hist[name].append(loss_terms.losses_all) - self.stdev_unweighted_hist[name].append(loss_terms.stddev_all) + for calc_name, loss_terms in loss_values.loss_terms.items(): + for loss_name, losses_all in loss_terms.losses_all.items(): + self.loss_unweighted_hist[calc_name][loss_name].append(losses_all) + for loss_name, stddev_all in loss_terms.stddev_all.items(): + self.stdev_unweighted_hist[calc_name][loss_name].append(stddev_all) self.loss_model_hist += [loss_values.loss.item()] perf_gpu, perf_mem = self.get_perf() @@ -655,12 +665,23 @@ def train(self, epoch): # NEED TO FIX LOGGING self._log_terminal(bidx, epoch, TRAIN) - # if bidx % self.train_log_freq.metrics == 0: - # self._log(TRAIN) + if bidx % self.train_log_freq.metrics == 0: + self._log(TRAIN) # save model checkpoint (with designation _latest) if bidx % self.train_log_freq.checkpoint == 0 and bidx > 0: self.save_model(-1) + self.loss_unweighted_hist = { + calc_name: {loss_name: []} + for calc_name, calc_terms in loss_values.loss_terms.items() + for loss_name in calc_terms.losses_all.keys() + } + self.stdev_unweighted_hist = { + calc_name: {loss_name: []} + for calc_name, calc_terms in loss_values.loss_terms.items() + for loss_name in calc_terms.stddev_all.keys() + } + self.loss_model_hist = [] self.cf.istep += 1 @@ -750,7 +771,7 @@ def validate(self, epoch): # NEED TO FIX LOGGING self._log_terminal(bidx, epoch, VAL) - # self._log(VAL) + self._log(VAL) # avoid that there is a systematic bias in the validation subset self.dataset_val.advance() @@ -982,20 +1003,16 @@ def _prepare_losses_for_logging( # Gather all tensors from all ranks into a list and stack them into one tensor again real_loss = torch.cat(all_gather_vlen(real_loss)) - for calc_name in self.loss_unweighted_hist.keys(): + for calc_name, loss_terms in self.loss_unweighted_hist.items(): losses_all[calc_name] = {} + for loss_name, losses in loss_terms.items(): + losses = torch.stack(losses).to(torch.float64) + losses_all[calc_name][loss_name] = torch.cat(all_gather_vlen(losses)) + for calc_name, stddev_terms in self.stdev_unweighted_hist.items(): stddev_all[calc_name] = {} - for stream in self.cf.streams: # Loop over all streams - stream_hist = [ - losses[stream.name] for losses in self.loss_unweighted_hist[calc_name] - ] - stream_all = torch.stack(stream_hist).to(torch.float64) - losses_all[calc_name][stream.name] = torch.cat(all_gather_vlen(stream_all)) - stream_hist = [ - stddevs[stream.name] for stddevs in self.stdev_unweighted_hist[calc_name] - ] - stream_all = torch.stack(stream_hist).to(torch.float64) - stddev_all[calc_name][stream.name] = torch.cat(all_gather_vlen(stream_all)) + for stddev_name, stddevs in stddev_terms.items(): + stddevs = torch.stack(stddevs).to(torch.float64) + stddev_all[calc_name][stddev_name] = torch.cat(all_gather_vlen(stddevs)) return real_loss, losses_all, stddev_all @@ -1030,8 +1047,6 @@ def _log(self, stage: Stage): self.perf_mem, ) - self.loss_unweighted_hist, self.loss_model_hist = [], [] - def _get_tensor_item(self, tensor): """ When using FSDP2, tensor is a DTensor and we need full_tensor().item() instead of .item(), @@ -1063,11 +1078,10 @@ def _log_terminal(self, bidx: int, epoch: int, stage: Stage): logger.info( f"validation ({self.cf.run_id}) : {epoch:03d} : {avg_loss.nanmean().item()}" ) - for calc_name, losses in losses_all.items(): - for _, st in enumerate(self.cf.streams): + for _, losses in losses_all.items(): + for loss_name, loss in losses.items(): logger.info( - f"{calc_name}.{st['name']}" - + f" : {losses[st['name']].nanmean():0.4E} \t", + f"{loss_name}" + f" : {loss.nanmean():0.4E} \t", ) logger.info("\n") @@ -1085,11 +1099,10 @@ def _log_terminal(self, bidx: int, epoch: int, stage: Stage): pstr += f"s/sec={(print_freq * self.cf.batch_size_per_gpu) / dt:.3f})" logger.info(pstr) logger.info("\t") - for calc_name, losses in losses_all.items(): - for _, st in enumerate(self.cf.streams): + for _, losses in losses_all.items(): + for loss_name, loss in losses.items(): logger.info( - f"{calc_name}.{st['name']}" - + f" : {losses[st['name']].nanmean():0.4E} \t", + f"{loss_name}" + f" : {loss.nanmean():0.4E} \t", ) logger.info("\n") diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index 4e9229a72..1df6930d4 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -119,22 +119,17 @@ def add_train( log_vals += [avg_loss.nanmean().item()] log_vals += [lr] - for st in self.cf.streams: - loss = losses_all[st["name"]] - stddev = stddev_all[st["name"]] - - for j, (lf_name, _) in enumerate(self.cf.loss_fcts): - metrics[_key_loss(st["name"], lf_name)] = loss[:, :, j].nanmean().item() - + st = self.cf.streams[0] + for _, loss_terms in losses_all.items(): + for loss_name, losses in loss_terms.items(): + metrics[f"{loss_name}.loss_avg"] = losses[:, :].nanmean().item() for k, ch_n in enumerate(st.train_target_channels): - metrics[_key_loss_chn(st["name"], lf_name, ch_n)] = ( - loss[:, k, j].nanmean().item() - ) - log_vals += [loss[:, :, j].nanmean().item()] - - metrics[_key_stddev(st["name"])] = stddev.nanmean().item() - - log_vals += [stddev.nanmean().item()] + metrics[f"{loss_name}.loss_{ch_n}"] = losses[:, k].nanmean().item() + log_vals += [losses[:, :].nanmean().item()] + for _, stddev_terms in stddev_all.items(): + for loss_name, stddev in stddev_terms.items(): + metrics[f"{loss_name}.stddev_avg"] = stddev.nanmean().item() + log_vals += [stddev.nanmean().item()] with open(self.path_run / f"{self.cf.run_id}_train_log.txt", "ab") as f: np.savetxt(f, log_vals) @@ -161,19 +156,17 @@ def add_val( log_vals: list[float] = [int(datetime.datetime.now().strftime("%Y%m%d%H%M%S"))] log_vals += [samples] - for st in self.cf.streams: - loss = losses_all[st["name"]] - stddev = stddev_all[st["name"]] - for j, (lf_name, _) in enumerate(self.cf.loss_fcts_val): - metrics[_key_loss(st["name"], lf_name)] = loss[:, :, j].nanmean().item() - for k, ch_n in enumerate(st.val_target_channels): - metrics[_key_loss_chn(st["name"], lf_name, ch_n)] = ( - loss[:, k, j].nanmean().item() - ) - log_vals += [loss[:, :, j].nanmean().item()] - - metrics[_key_stddev(st["name"])] = stddev.nanmean().item() - log_vals += [stddev.nanmean().item()] + st = self.cf.streams[0] + for _, loss_terms in losses_all.items(): + for loss_name, losses in loss_terms.items(): + metrics[f"{loss_name}.loss_avg"] = losses[:, :].nanmean().item() + for k, ch_n in enumerate(st.train_target_channels): + metrics[f"{loss_name}.loss_{ch_n}"] = losses[:, k].nanmean().item() + log_vals += [losses[:, :].nanmean().item()] + for _, stddev_terms in stddev_all.items(): + for loss_name, stddev in stddev_terms.items(): + metrics[f"{loss_name}.stddev_avg"] = stddev.nanmean().item() + log_vals += [stddev.nanmean().item()] self.log_metrics("val", metrics) with open(self.path_run / (self.cf.run_id + "_val_log.txt"), "ab") as f: From 761e26302b83a7814b4c66533ee510ff58443e5d Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Mon, 17 Nov 2025 18:13:57 +0000 Subject: [PATCH 041/344] update ViewMetadata spec --- src/weathergen/datasets/inputs_metadata.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/weathergen/datasets/inputs_metadata.py b/src/weathergen/datasets/inputs_metadata.py index 5a30b9462..33d88a3eb 100644 --- a/src/weathergen/datasets/inputs_metadata.py +++ b/src/weathergen/datasets/inputs_metadata.py @@ -19,7 +19,7 @@ # TODO: GetMetaData: then this gets the right rn for the timestep! @dataclass -class SampleMetadata: +class ViewMetadata: """ Metadata describing how a view was generated. @@ -34,12 +34,14 @@ class SampleMetadata: rate: Fraction of data kept (e.g., 0.5 = 50% kept); None if fixed count parent_view_id: ID of the parent view this is a subset of (None for teacher) """ + + loss_type: str # DINO, JEPA... ? + strategy: str # "cropping", "masking", "forecasting", "forecasting_diffusion" + strategy_config: dict # rate: 0.5 etc., healpix_level: int etc., overlap: "disjoint" etc., view_id: str - keep_mask: np.ndarray # [num_cells] bool at data level - strategy: str # e.g., "random", "healpix_level_2" - healpix_level: Optional[int] - rate: Optional[float] parent_view_id: Optional[str] = None # For students: which teacher they belong to + keep_mask: np.ndarray # [num_cells] bool at data level + # TODO: This doesn't handle the masking case, and we probably want it to, From 047b29947673dc7749e15870483e71940b2118fb Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Mon, 17 Nov 2025 18:19:56 +0000 Subject: [PATCH 042/344] draft changes to allow global local view generation in masker and tokenizer_masking. generate the mask, otherwise using batchify_source and batchify_target as before, with the capacity to remember what mask we have now when it comes to generating the targets. Update to inputs_metadata structure but not put in to practice --- config/default_config.yml | 2 +- src/weathergen/datasets/masking.py | 105 +++++++++++++++++ .../datasets/multi_stream_data_sampler.py | 2 +- src/weathergen/datasets/tokenizer_masking.py | 106 +++++++++++++++++- src/weathergen/datasets/view_builder.py | 97 ++++++++++++++++ 5 files changed, 304 insertions(+), 8 deletions(-) create mode 100644 src/weathergen/datasets/view_builder.py diff --git a/config/default_config.yml b/config/default_config.yml index 680a6b7ab..b12cc4e84 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -156,4 +156,4 @@ run_id: ??? train_log_freq: terminal: 10 metrics: 20 - checkpoint: 250 + checkpoint: 250 \ No newline at end of file diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index 6d91b569d..dfb4002be 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -141,6 +141,7 @@ def mask_source_idxs( idxs_cells, idxs_cells_lens, rdata, + keep_mask: np.ndarray | None = None, ) -> (torch.Tensor, torch.Tensor): """ @@ -156,6 +157,28 @@ def mask_source_idxs( if num_tokens == 0: return (self.mask_tokens, self.mask_channels) + # If an explicit keep_mask is provided we bypass strategy selection and directly + # construct the token-level mask from it. keep_mask expresses cells to KEEP (True=keep). + # Otherwise fall back to the configured strategy logic. + if keep_mask is not None: + assert len(keep_mask) == len(idxs_cells_lens), ( + f"keep_mask length {len(keep_mask)} does not match number of cells {len(idxs_cells_lens)}" + ) + # build token level mask: for each cell replicate the keep flag across its tokens + token_level_flags: list[np.ndarray] = [] + for km, lens_cell in zip(keep_mask, idxs_cells_lens, strict=True): + num_tokens_cell = len(lens_cell) + if num_tokens_cell == 0: + continue + token_level_flags.append( + np.ones(num_tokens_cell, dtype=bool) if km else np.zeros(num_tokens_cell, dtype=bool) + ) + if token_level_flags: + self.mask_tokens = np.concatenate(token_level_flags) + else: + self.mask_tokens = np.array([], dtype=bool) + return (self.mask_tokens, self.mask_channels) + # clean strategy selection self.current_strategy = self._select_strategy() @@ -600,3 +623,85 @@ def _generate_causal_mask( ] return full_mask + + # --------------------------------------------------------------------- + # Cell-level keep mask generation (teacher/student view selection) + # --------------------------------------------------------------------- + def generate_cell_keep_mask( + self, + num_cells: int, + strategy: str | None = None, + rate: float | None = None, + masking_strategy_config: dict | None = None, + constraint_keep_mask: np.ndarray | None = None, + ) -> np.ndarray: + """Generate a boolean keep mask at data healpix level (True = keep cell). + + Parameters + ---------- + num_cells : int + Number of cells at data level (should equal 12 * 4**healpix_level). + strategy : str | None + Cell selection strategy: currently supports 'random' and 'healpix'. Uses + instance default if None. + rate : float | None + Fraction of parent cells (healpix) or data cells (random) to keep. Falls back + to instance masking_rate if None. + masking_strategy_config : dict | None + Optional override of strategy config (e.g., {'hl_mask': 3}). + constraint_keep_mask : np.ndarray | None + Optional boolean mask of allowed cells (True = allowed). Selection will be + limited to these cells. For subset/disjoint relationships. + + Returns + ------- + np.ndarray + Boolean array of shape [num_cells] where True indicates the cell is kept. + """ + strat = strategy or self.masking_strategy + cfg = masking_strategy_config or self.masking_strategy_config + keep_rate = rate if rate is not None else self.masking_rate + + # sample rate if requested (only if explicit rate not provided) + if rate is None and self.masking_rate_sampling: + keep_rate = self._get_sampling_rate() + + assert 0.0 <= keep_rate <= 1.0, f"keep_rate out of bounds: {keep_rate}" + assert num_cells == self.healpix_num_cells, ( + f"num_cells={num_cells} inconsistent with configured healpix level ({self.healpix_num_cells})." + ) + + if strat not in {"random", "healpix"}: + raise NotImplementedError( + f"Cell selection strategy '{strat}' not supported for keep mask generation." + ) + + if strat == "random": + base_mask = self.rng.uniform(0, 1, num_cells) < keep_rate + else: # healpix hierarchical selection + hl_data = self.healpix_level_data + hl_mask = cfg.get("hl_mask") + assert hl_mask is not None and hl_mask < hl_data, ( + "For healpix keep mask generation, cfg['hl_mask'] must be set and < data level.") + num_parent_cells = 12 * (4**hl_mask) + level_diff = hl_data - hl_mask + num_children_per_parent = 4**level_diff + # number of parents to KEEP + num_parents_to_keep = int(np.round(keep_rate * num_parent_cells)) + if num_parents_to_keep == 0: + base_mask = np.zeros(num_cells, dtype=bool) + else: + parent_ids = self.rng.choice(num_parent_cells, num_parents_to_keep, replace=False) + child_offsets = np.arange(num_children_per_parent) + child_indices = ( + parent_ids[:, None] * num_children_per_parent + child_offsets + ).reshape(-1) + base_mask = np.zeros(num_cells, dtype=bool) + base_mask[child_indices] = True + + # apply constraint if provided (only keep those cells within allowed) + if constraint_keep_mask is not None: + assert constraint_keep_mask.shape[0] == num_cells, "constraint_keep_mask wrong shape" + base_mask = base_mask & constraint_keep_mask + + return base_mask diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 55e4124c3..4b68e8f4d 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -373,7 +373,7 @@ def __iter__(self): stream_data.source_is_spoof = True # preprocess data for model input - (ss_cells, ss_lens) = self.tokenizer.batchify_source( + ss_cells, ss_lens, _mask_state = self.tokenizer.batchify_source( stream_info, readerdata_to_torch(rdata), (time_win_source.start, time_win_source.end), diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index dcf81d394..9c103126b 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -13,6 +13,8 @@ from weathergen.common.io import IOReaderData from weathergen.datasets.masking import Masker from weathergen.datasets.tokenizer import Tokenizer +from weathergen.datasets.view_builder import build_views_for_stream +from weathergen.datasets.inputs_metadata import ViewMetadata from weathergen.datasets.tokenizer_utils import ( encode_times_source, encode_times_target, @@ -27,6 +29,8 @@ class TokenizerMasking(Tokenizer): def __init__(self, healpix_level: int, masker: Masker): super().__init__(healpix_level) self.masker = masker + # cache last built view metadata per stream invocation (optional downstream use) + self._last_view_metadata: list[ViewMetadata] | None = None def reset_rng(self, rng) -> None: """ @@ -40,6 +44,7 @@ def batchify_source( stream_info: dict, rdata: IOReaderData, time_win: tuple, + keep_mask: torch.Tensor | None = None, ): token_size = stream_info["token_size"] stream_id = stream_info["stream_id"] @@ -50,8 +55,8 @@ def batchify_source( if is_diagnostic or rdata.data.shape[1] == 0 or len(rdata.data) < 2: source_tokens_cells = [torch.tensor([])] source_tokens_lens = torch.zeros([self.num_healpix_cells_source], dtype=torch.int32) - source_centroids = [torch.tensor([])] - return (source_tokens_cells, source_tokens_lens, source_centroids) + mask_state = {"strategy": self.masker.current_strategy, "mask_tokens": None, "mask_channels": None} + return (source_tokens_cells, source_tokens_lens, mask_state) # create tokenization index tok = tokenize_spacetime if stream_info.get("tokenize_spacetime", False) else tokenize_space @@ -59,9 +64,16 @@ def batchify_source( # select strategy from XXX depending on stream and if student or teacher - (mask_tokens, mask_channels) = self.masker.mask_source_idxs( - stream_info, idxs_cells, idxs_cells_lens, rdata - ) + # Optional per-cell keep_mask (boolean) converts to numpy for Masker override. + if keep_mask is not None: + keep_np = keep_mask.cpu().numpy().astype(bool) + (mask_tokens, mask_channels) = self.masker.mask_source_idxs( + stream_info, idxs_cells, idxs_cells_lens, rdata, keep_mask=keep_np + ) + else: + (mask_tokens, mask_channels) = self.masker.mask_source_idxs( + stream_info, idxs_cells, idxs_cells_lens, rdata + ) source_tokens_cells, source_tokens_lens = tokenize_apply_mask_source( idxs_cells, @@ -75,7 +87,15 @@ def batchify_source( encode_times_source, ) - return (source_tokens_cells, source_tokens_lens) + # capture per-view mask state to later produce consistent targets + mask_state = { + "strategy": self.masker.current_strategy, + "mask_tokens": mask_tokens, + "mask_channels": mask_channels, + } + return (source_tokens_cells, source_tokens_lens, mask_state) + + # batchify_target_for_view now unified into batchify_target via optional mask_state def batchify_target( self, @@ -83,6 +103,7 @@ def batchify_target( sampling_rate_target: float, rdata: IOReaderData, time_win: tuple, + mask_state: dict | None = None, ): token_size = stream_info["token_size"] @@ -90,6 +111,12 @@ def batchify_target( tok = tokenize_spacetime if stream_info.get("tokenize_spacetime", False) else tokenize_space idxs_cells, idxs_cells_lens = tok(rdata, token_size, self.hl_source, pad_tokens=False) + # Apply per-view mask state if provided + if mask_state is not None: + self.masker.current_strategy = mask_state.get("strategy", self.masker.masking_strategy) + self.masker.mask_tokens = mask_state.get("mask_tokens") + self.masker.mask_channels = mask_state.get("mask_channels") + (mask_tokens, mask_channels, idxs_ord_inv) = self.masker.mask_targets_idxs( stream_info, idxs_cells, idxs_cells_lens, rdata ) @@ -113,6 +140,73 @@ def batchify_target( return (data, datetimes, coords, coords_local, coords_per_cell, idxs_ord_inv) + + # ------------------------------------------------------------------ + # Per-stream view construction (teacher + students) for student-teacher + # ------------------------------------------------------------------ + def build_stream_views( + self, + stream_info: dict, + rdata: IOReaderData, + time_win: tuple, + training_cfg: dict | None = None, + ): + """Construct teacher and student views for a single stream. + + Parameters + ---------- + stream_info : dict + Stream configuration dictionary. + rdata : IOReaderData + Combined reader data for this stream. + time_win : tuple + (start, end) datetime window. + training_cfg : dict | None + cf.training_config section; if absent or mode != 'student_teacher', fallback to single view. + + Returns + ------- + teacher : tuple | None + (tokens_cells, tokens_lens) for teacher or None when not student_teacher. + students : list + List of (tokens_cells, tokens_lens) for each student view (or single masking view). + view_metadata : list[ViewMetadata] | None + Metadata for teacher + students when in student_teacher mode. + """ + if training_cfg is None or training_cfg.get("training_mode") != "student_teacher": + # Standard masking path: single view only (treated as 'student' for uniformity) + scells, slens, _mask_state = self.batchify_source(stream_info, rdata, time_win) + self._last_view_metadata = None + return None, [(scells, slens, _mask_state)], None + + teacher_cfg = training_cfg.get("teacher_model_input", {}) + student_cfg = training_cfg.get("model_input", {}) + relationship = student_cfg.get("relationship", "subset") + + num_cells = self.num_healpix_cells_source + teacher_keep_mask, student_keep_masks, view_meta = build_views_for_stream( + self.masker, num_cells, teacher_cfg, student_cfg, relationship + ) + + # Convert keep masks to torch tensors for downstream masking override + teacher_keep_mask_t = torch.from_numpy(teacher_keep_mask) + student_keep_masks_t = [torch.from_numpy(m) for m in student_keep_masks] + + # Teacher tokens + t_cells, t_lens, t_mask_state = self.batchify_source( + stream_info, rdata, time_win, keep_mask=teacher_keep_mask_t + ) + # Student tokens + student_tokens = [ + self.batchify_source(stream_info, rdata, time_win, keep_mask=km) + for km in student_keep_masks_t + ] + # add mask_state inside each tuple + student_tokens = [(cells, lens, mstate) for (cells, lens, mstate) in student_tokens] + + self._last_view_metadata = view_meta + return (t_cells, t_lens, t_mask_state), student_tokens, view_meta + def sample_tensors_uniform_vectorized( self, tensor_list: list, lengths: list, max_total_points: int ): diff --git a/src/weathergen/datasets/view_builder.py b/src/weathergen/datasets/view_builder.py new file mode 100644 index 000000000..cd7228b49 --- /dev/null +++ b/src/weathergen/datasets/view_builder.py @@ -0,0 +1,97 @@ +import numpy as np +from typing import Tuple, List +from weathergen.datasets.masking import Masker +from weathergen.datasets.inputs_metadata import ViewMetadata + + +def build_views_for_stream( + masker: Masker, + num_cells: int, + teacher_cfg: dict, + student_cfg: dict, + relationship: str = "subset", +) -> Tuple[np.ndarray, List[np.ndarray], List[ViewMetadata]]: + """ + + Per-stream view construction: teacher + N student keep masks. + + Parameters + ---------- + masker : Masker + Instance providing RNG and healpix-level info. + num_cells : int + Number of healpix cells at data level. + teacher_cfg : dict + Config: {strategy, rate|keep_m, hl_mask, masking_strategy_config, rate_sampling}. + student_cfg : dict + Config: {masking_strategy, rate, num_views, hl_mask, masking_strategy_config, rate_sampling}. + relationship : str + One of {'subset','disjoint','independent'}. Determines derivation of student masks. + + Returns + ------- + teacher_keep_mask : np.ndarray + Boolean keep mask for teacher view. + student_keep_masks : list[np.ndarray] + Boolean keep masks for each student view. + metadata : list[ViewMetadata] + Metadata objects (teacher first, then students). + + """ + strat_teacher = teacher_cfg.get("strategy", "random") + rate_teacher = teacher_cfg.get("rate") + t_cfg_extra = teacher_cfg.get("masking_strategy_config") + + teacher_keep_mask = masker.generate_cell_keep_mask( + num_cells=num_cells, + strategy=strat_teacher, + rate=rate_teacher, + masking_strategy_config=t_cfg_extra, + ) + + # Student base masks + num_views = student_cfg.get("num_views", 1) + strat_student = student_cfg.get("masking_strategy", student_cfg.get("strategy", "random")) + rate_student = student_cfg.get("rate") + s_cfg_extra = student_cfg.get("masking_strategy_config") + + student_keep_masks: List[np.ndarray] = [] + for v in range(num_views): + base = masker.generate_cell_keep_mask( + num_cells=num_cells, + strategy=strat_student, + rate=rate_student, + masking_strategy_config=s_cfg_extra, + ) + if relationship == "subset": + keep = base & teacher_keep_mask + elif relationship == "disjoint": + keep = base & (~teacher_keep_mask) + else: # independent + keep = base + student_keep_masks.append(keep) + + metadata: List[ViewMetadata] = [] + metadata.append( + ViewMetadata( + view_id="teacher_global", + keep_mask=teacher_keep_mask, + strategy=strat_teacher, + healpix_level=masker.healpix_level_data, + rate=rate_teacher, + parent_view_id=None, + ) + ) + for i, m in enumerate(student_keep_masks): + metadata.append( + ViewMetadata( + view_id=f"student_local_{i}", + keep_mask=m, + strategy=strat_student, + healpix_level=masker.healpix_level_data, + rate=rate_student, + parent_view_id="teacher_global", + ) + ) + + return teacher_keep_mask, student_keep_masks, metadata From 7d5c3005c9f7455b5c8559c6a883074ee199cc27 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Mon, 17 Nov 2025 18:22:33 +0000 Subject: [PATCH 043/344] draft of training_config in default_config --- config/default_config.yml | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/config/default_config.yml b/config/default_config.yml index b12cc4e84..1f0a25810 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -108,6 +108,35 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"], "same_strategy_per_batch": false } +# Student-teacher configuration (only used when training_mode == "student_teacher") +training_config: + # when this is "masking", we are basically only using the model_input subconfig + training_mode: "student_teacher" # "masking", "student_teacher", "forecast" + + model_input: + masking_strategy: "healpix" # "random", "healpix". Masking strategy to use for model input for masking, and local (student) views when doing student-teacher + rate: 0.5 # Masking rate to use for model input + num_views: 4 # if student-teacher, the number of local (student) views to generate + masking_strategy_config: {"strategies": ["random", "healpix", "channel"], + "probabilities": [0.34, 0.33, 0.33], + "hl_mask": 0, "mode": "per_cell", + "same_strategy_per_batch": false + } + + teacher_model_input: + strategy: "healpix" # Strategy for teacher (global) view: "random", "healpix" + rate: 0.1 # Fraction of data to keep in global view (alternative: use "keep_m" for absolute count) + # keep_m: 100 # Alternative to rate: keep exactly this many parent cells + rate_sampling: true # randomly sample the rate per batch + masking_strategy_config: {"strategies": ["random", "healpix", "channel"], + "probabilities": [0.34, 0.33, 0.33], + "hl_mask": 4, "mode": "per_cell", + "same_strategy_per_batch": false + } + + + + num_epochs: 32 samples_per_epoch: 4096 samples_per_validation: 512 From c7332802b4b7864dbc7ae491794c5b45cdd073ec Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Mon, 17 Nov 2025 18:32:40 +0000 Subject: [PATCH 044/344] change view_metadata to dict in ModelInput --- src/weathergen/datasets/inputs_metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/datasets/inputs_metadata.py b/src/weathergen/datasets/inputs_metadata.py index 33d88a3eb..2143163a1 100644 --- a/src/weathergen/datasets/inputs_metadata.py +++ b/src/weathergen/datasets/inputs_metadata.py @@ -80,7 +80,7 @@ class ModelBatch: model_inputs: list[list[any]] # [n_students][n_streams] targets: list[list[any]] # [1][n_streams] (teacher) - view_metadata: list[ViewMetadata] + view_metadata: dict[str, ViewMetadata] # perhaps dict, teacher_metadata : ViewMetadata, student_metadata: list[ViewMetadata] batch_info: Optional[dict] = field(default_factory=dict) # Offsets for student views (populated when needed for future student-teacher training) From a934f9740a53053768d8e59013896faab80a7773 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Tue, 18 Nov 2025 09:58:19 +0100 Subject: [PATCH 045/344] NOT WORKING: updating class to handle multiple input steps and improving overall structure --- .../datasets/multi_stream_data_sampler.py | 201 ++++++++++++------ 1 file changed, 131 insertions(+), 70 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 55e4124c3..cb986a33c 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -34,6 +34,7 @@ ) from weathergen.utils.distributed import is_root from weathergen.utils.train_logger import Stage +from weathergen.datasets.inputs_metadata import ModelBatch type AnyDataReader = DataReaderBase | DataReaderAnemoi | DataReaderObs @@ -99,6 +100,8 @@ def __init__( self.mask_value = 0.0 self._stage = stage + self.num_input_steps = cf.get( "num_input_steps", 1) + self.len_hrs: int = cf.len_hrs self.step_hrs: int = cf.step_hrs self.time_window_handler = TimeWindowHandler(start_date, end_date, cf.len_hrs, cf.step_hrs) @@ -302,16 +305,131 @@ def reset(self): self.tokenizer.reset_rng(self.rng) - ################################################### def denormalize_source_channels(self, stream_id, data) -> torch.Tensor: # TODO: with multiple ds per stream we need to distinguish these here return self.streams_datasets[stream_id][0].denormalize_source_channels(data) - ################################################### def denormalize_target_channels(self, stream_id, data) -> torch.Tensor: # TODO: with multiple ds per stream we need to distinguish these here return self.streams_datasets[stream_id][0].denormalize_target_channels(data) + def _build_stream_data_source( + self, + stream_data : StreamData, + base_idx: TIndex, + forecast_dt: int, + # view_meta: ViewMetadata, + stream_info: dict, + stream_ds: list, + ) -> StreamData: + """ + Return one batch of data + Build a StreamData object for a single view (teacher or student). + + Args: + stream_data : + base_idx: Time index for this sample + forecast_dt: Number of forecast steps + view_meta: ViewMetadata describing spatial mask + stream_info: Stream configuration dict + stream_ds: List of dataset readers for this stream + + Returns: + StreamData with source and targets masked according to view_meta + """ + + # iterate overall input steps + for step, idx in enumerate( range( base_idx, base_idx-self.num_input_steps, -1)) : + + # TODO: check that we are not out of bounds when we go back in time + + time_win_source = self.time_window_handler.window(idx) + + # collect all targets for current stream + rdata: IOReaderData = collect_datasources(stream_ds, idx, "source") + + if rdata.is_empty(): + # work around for https://github.com/pytorch/pytorch/issues/158719 + # create non-empty mean data instead of empty tensor + rdata = spoof( + self.healpix_level, + time_win_source.start, + stream_ds[0].get_geoinfo_size(), + stream_ds[0].mean[stream_ds[0].source_idx], + ) + stream_data.source_is_spoof = True + + # preprocess data for model input + (ss_cells, ss_lens) = self.tokenizer.batchify_source( + stream_info, + readerdata_to_torch(rdata), + (time_win_source.start, time_win_source.end), + ) + + # collect data for stream + stream_data.add_source( step, rdata, ss_lens, ss_cells) + + return stream_data + + def _build_stream_data_target( + self, + stream_data : StreamData, + idx: TIndex, + forecast_dt: int, + # view_meta: ViewMetadata, + stream_info: dict, + stream_ds: list, + ) -> StreamData : + + # collect for all forecast steps + for fstep in range( + self.forecast_offset, self.forecast_offset + forecast_dt + 1 + ): + step_forecast_dt = idx + (self.forecast_delta_hrs * fstep) // self.step_hrs + time_win_target = self.time_window_handler.window(step_forecast_dt) + + # collect all targets for current stream + rdata: IOReaderData = collect_datasources( + stream_ds, step_forecast_dt, "target" + ) + + if rdata.is_empty(): + # work around for https://github.com/pytorch/pytorch/issues/158719 + # create non-empty mean data instead of empty tensor + rdata = spoof( + self.healpix_level, + time_win_target.start, + stream_ds[0].get_geoinfo_size(), + stream_ds[0].mean[stream_ds[0].target_idx], + ) + stream_data.target_is_spoof = True + + # preprocess data for model input + (tt_cells, tt_t, tt_c, tc, tc_l, idxs_inv) = self.tokenizer.batchify_target( + stream_info, + self.sampling_rate_target, + readerdata_to_torch(rdata), + (time_win_target.start, time_win_target.end), + ) + + stream_data.add_target(fstep, tt_cells, tc, tc_l, tt_c, tt_t, idxs_inv) + + return stream_data + + def _preprocess_model_data( self, batch) : + + # aggregated lens of tokens per cell across input batch samples + source_cell_lens = compute_source_cell_lens(batch, self.num_input_steps) + + # compute offsets for scatter computation after embedding + batch = compute_offsets_scatter_embed(batch, self.num_input_steps) + + # compute offsets and auxiliary data needed for prediction computation + # (info is not per stream so separate data structure) + target_coords_idx = compute_idxs_predict(self.forecast_offset + forecast_dt, batch) + + return batch, source_cell_lens, target_coords_idx + ################################################### def __iter__(self): """ @@ -354,70 +472,22 @@ def __iter__(self): # for all streams for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): + stream_data = StreamData( idx, forecast_dt + self.forecast_offset, self.num_healpix_cells ) - - # collect all targets for current stream - rdata: IOReaderData = collect_datasources(stream_ds, idx, "source") - - if rdata.is_empty(): - # work around for https://github.com/pytorch/pytorch/issues/158719 - # create non-empty mean data instead of empty tensor - rdata = spoof( - self.healpix_level, - time_win_source.start, - stream_ds[0].get_geoinfo_size(), - stream_ds[0].mean[stream_ds[0].source_idx], - ) - stream_data.source_is_spoof = True - - # preprocess data for model input - (ss_cells, ss_lens) = self.tokenizer.batchify_source( - stream_info, - readerdata_to_torch(rdata), - (time_win_source.start, time_win_source.end), + + # collect source data for current stream + stream_data = self._build_stream_data_source( + stream_data, idx, forecast_dt, stream_info, stream_ds ) - # collect data for stream - stream_data.add_source(rdata, ss_lens, ss_cells) - - # target - - # collect for all forecast steps - for fstep in range( - self.forecast_offset, self.forecast_offset + forecast_dt + 1 - ): - step_forecast_dt = idx + (self.forecast_delta_hrs * fstep) // self.step_hrs - time_win_target = self.time_window_handler.window(step_forecast_dt) - - # collect all targets for current stream - rdata: IOReaderData = collect_datasources( - stream_ds, step_forecast_dt, "target" - ) - - if rdata.is_empty(): - # work around for https://github.com/pytorch/pytorch/issues/158719 - # create non-empty mean data instead of empty tensor - rdata = spoof( - self.healpix_level, - time_win_target.start, - stream_ds[0].get_geoinfo_size(), - stream_ds[0].mean[stream_ds[0].target_idx], - ) - stream_data.target_is_spoof = True - - # preprocess data for model input - (tt_cells, tt_t, tt_c, tc, tc_l, idxs_inv) = self.tokenizer.batchify_target( - stream_info, - self.sampling_rate_target, - readerdata_to_torch(rdata), - (time_win_target.start, time_win_target.end), - ) - - stream_data.add_target(fstep, tt_cells, tc, tc_l, tt_c, tt_t, idxs_inv) + # collect target data for current stream + stream_data = self._build_stream_data_target( + stream_data, idx, forecast_dt, stream_info, stream_ds + ) - # merge inputs for sources and targets for current stream + # add data for current stream streams_data += [stream_data] # Reset masking strategy for next batch item @@ -428,17 +498,8 @@ def __iter__(self): if not (all(s.empty() or s.target_empty() for s in streams_data)): batch += [streams_data] - # aggregated lens of tokens per cell - source_cell_lens = compute_source_cell_lens(batch) - - # compute offsets for scatter computation after embedding - batch = compute_offsets_scatter_embed(batch) - - # compute offsets and auxiliary data needed for prediction computation - # (info is not per stream so separate data structure) - target_coords_idx = compute_idxs_predict(self.forecast_offset + forecast_dt, batch) + batch, source_cell_lens, target_coords_idx = self._preprocess_model_data( batch) - assert len(batch) == self.batch_size yield (batch, source_cell_lens, target_coords_idx, forecast_dt) ################################################### From 086aacbc00fc0658db5b0a934004182e45be2cdd Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Tue, 18 Nov 2025 09:27:32 +0000 Subject: [PATCH 046/344] Linter --- src/weathergen/train/target_and_aux_ssl_teacher.py | 14 ++++++++------ src/weathergen/utils/utils.py | 3 ++- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py index ba008ae50..d0d6a443c 100644 --- a/src/weathergen/train/target_and_aux_ssl_teacher.py +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -6,7 +6,7 @@ class EMATeacher(TargetAndAuxModuleBase): def __init__(self, model, rng, ema_model, batch_size, **kwargs): # One of the issues is that the teacher model may have a different architecture - # to the student, e.g. JEPA. So we need quite a flexible way to instantiate the + # to the student, e.g. JEPA. So we need quite a flexible way to instantiate the # the teacher. Because of the device sharding etc that requires quite a bit of # massaging we assume that the teacher creates the EMA model correctly. However, # note that you cannot assume that model.state_dict equals ema_model.state_dict @@ -15,7 +15,7 @@ def __init__(self, model, rng, ema_model, batch_size, **kwargs): self.reset() - def reset(self, batch_size = None): + def reset(self, batch_size=None): self.ema_model.reset() if batch_size is not None: self.batch_size = batch_size @@ -26,7 +26,9 @@ def update_state_pre_backward(self, istep, batch, model, **kwargs) -> None: def update_state_post_opt_step(self, istep, batch, model, **kwargs) -> None: self.ema_model.update(istep, self.batch_size) - def compute(self, bidx, batch, model_params, model, forecast_offset, forecast_steps) -> tuple[Any, Any]: - return self.ema_model.forward_eval(model_params, batch, forecast_offset, forecast_steps), None - - + def compute( + self, bidx, batch, model_params, model, forecast_offset, forecast_steps + ) -> tuple[Any, Any]: + return self.ema_model.forward_eval( + model_params, batch, forecast_offset, forecast_steps + ), None diff --git a/src/weathergen/utils/utils.py b/src/weathergen/utils/utils.py index 1e0fed42c..c84f2d298 100644 --- a/src/weathergen/utils/utils.py +++ b/src/weathergen/utils/utils.py @@ -11,6 +11,7 @@ from weathergen.common.config import Config + def get_dtype(value: str) -> torch.dtype: """ changes the conf value to a torch dtype @@ -28,4 +29,4 @@ def get_dtype(value: str) -> torch.dtype: def get_batch_size(cf: Config, world_size: int) -> int: - return world_size * cf.batch_size_per_gpu + return world_size * cf.batch_size_per_gpu From c3b5c3bc27178de4f04d6fb603ce7054bb68e40a Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Tue, 18 Nov 2025 12:02:17 +0100 Subject: [PATCH 047/344] Added basic support for multi-step sources. --- .../datasets/multi_stream_data_sampler.py | 49 ++++---- src/weathergen/datasets/stream_data.py | 34 ++--- src/weathergen/datasets/utils.py | 118 ++++++++++-------- src/weathergen/model/engines.py | 11 +- src/weathergen/model/model.py | 4 +- src/weathergen/train/trainer.py | 18 +-- 6 files changed, 125 insertions(+), 109 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index cb986a33c..266c9e0c7 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -34,7 +34,6 @@ ) from weathergen.utils.distributed import is_root from weathergen.utils.train_logger import Stage -from weathergen.datasets.inputs_metadata import ModelBatch type AnyDataReader = DataReaderBase | DataReaderAnemoi | DataReaderObs @@ -100,7 +99,7 @@ def __init__( self.mask_value = 0.0 self._stage = stage - self.num_input_steps = cf.get( "num_input_steps", 1) + self.num_input_steps = cf.get("num_input_steps", 1) self.len_hrs: int = cf.len_hrs self.step_hrs: int = cf.step_hrs @@ -315,7 +314,7 @@ def denormalize_target_channels(self, stream_id, data) -> torch.Tensor: def _build_stream_data_source( self, - stream_data : StreamData, + stream_data: StreamData, base_idx: TIndex, forecast_dt: int, # view_meta: ViewMetadata, @@ -325,22 +324,21 @@ def _build_stream_data_source( """ Return one batch of data Build a StreamData object for a single view (teacher or student). - + Args: - stream_data : + stream_data : base_idx: Time index for this sample forecast_dt: Number of forecast steps view_meta: ViewMetadata describing spatial mask stream_info: Stream configuration dict stream_ds: List of dataset readers for this stream - + Returns: StreamData with source and targets masked according to view_meta """ # iterate overall input steps - for step, idx in enumerate( range( base_idx, base_idx-self.num_input_steps, -1)) : - + for step, idx in enumerate(range(base_idx, base_idx - self.num_input_steps, -1)): # TODO: check that we are not out of bounds when we go back in time time_win_source = self.time_window_handler.window(idx) @@ -360,38 +358,33 @@ def _build_stream_data_source( stream_data.source_is_spoof = True # preprocess data for model input - (ss_cells, ss_lens) = self.tokenizer.batchify_source( + (ss_cells, ss_lens, mask_state) = self.tokenizer.batchify_source( stream_info, readerdata_to_torch(rdata), (time_win_source.start, time_win_source.end), ) # collect data for stream - stream_data.add_source( step, rdata, ss_lens, ss_cells) + stream_data.add_source(step, rdata, ss_lens, ss_cells) return stream_data - def _build_stream_data_target( + def _build_stream_data_target( self, - stream_data : StreamData, + stream_data: StreamData, idx: TIndex, forecast_dt: int, # view_meta: ViewMetadata, stream_info: dict, stream_ds: list, - ) -> StreamData : - + ) -> StreamData: # collect for all forecast steps - for fstep in range( - self.forecast_offset, self.forecast_offset + forecast_dt + 1 - ): + for fstep in range(self.forecast_offset, self.forecast_offset + forecast_dt + 1): step_forecast_dt = idx + (self.forecast_delta_hrs * fstep) // self.step_hrs time_win_target = self.time_window_handler.window(step_forecast_dt) # collect all targets for current stream - rdata: IOReaderData = collect_datasources( - stream_ds, step_forecast_dt, "target" - ) + rdata: IOReaderData = collect_datasources(stream_ds, step_forecast_dt, "target") if rdata.is_empty(): # work around for https://github.com/pytorch/pytorch/issues/158719 @@ -416,8 +409,7 @@ def _build_stream_data_target( return stream_data - def _preprocess_model_data( self, batch) : - + def _preprocess_model_data(self, batch, forecast_dt): # aggregated lens of tokens per cell across input batch samples source_cell_lens = compute_source_cell_lens(batch, self.num_input_steps) @@ -428,7 +420,7 @@ def _preprocess_model_data( self, batch) : # (info is not per stream so separate data structure) target_coords_idx = compute_idxs_predict(self.forecast_offset + forecast_dt, batch) - return batch, source_cell_lens, target_coords_idx + return batch, source_cell_lens, target_coords_idx ################################################### def __iter__(self): @@ -472,18 +464,17 @@ def __iter__(self): # for all streams for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): - stream_data = StreamData( idx, forecast_dt + self.forecast_offset, self.num_healpix_cells ) - + # collect source data for current stream - stream_data = self._build_stream_data_source( + stream_data = self._build_stream_data_source( stream_data, idx, forecast_dt, stream_info, stream_ds ) # collect target data for current stream - stream_data = self._build_stream_data_target( + stream_data = self._build_stream_data_target( stream_data, idx, forecast_dt, stream_info, stream_ds ) @@ -498,7 +489,9 @@ def __iter__(self): if not (all(s.empty() or s.target_empty() for s in streams_data)): batch += [streams_data] - batch, source_cell_lens, target_coords_idx = self._preprocess_model_data( batch) + batch, source_cell_lens, target_coords_idx = self._preprocess_model_data( + batch, forecast_dt + ) yield (batch, source_cell_lens, target_coords_idx, forecast_dt) diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index 18f2ac046..f5dcefee4 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -67,8 +67,8 @@ def __init__(self, idx: int, forecast_steps: int, healpix_cells: int) -> None: self.source_raw = [] # auxiliary data for scatter operation that changes from stream-centric to cell-centric # processing after embedding - self.source_idxs_embed = torch.tensor([]) - self.source_idxs_embed_pe = torch.tensor([]) + self.source_idxs_embed = [torch.tensor([])] + self.source_idxs_embed_pe = [torch.tensor([])] def to_device(self, device: str) -> None: """ @@ -84,14 +84,15 @@ def to_device(self, device: str) -> None: None """ - self.source_tokens_cells = self.source_tokens_cells.to(device, non_blocking=True) - self.source_tokens_lens = self.source_tokens_lens.to(device, non_blocking=True) + dv = device + self.source_tokens_cells = [s.to(dv, non_blocking=True) for s in self.source_tokens_cells] + self.source_tokens_lens = [s.to(dv, non_blocking=True) for s in self.source_tokens_lens] - self.target_coords = [t.to(device, non_blocking=True) for t in self.target_coords] - self.target_tokens = [t.to(device, non_blocking=True) for t in self.target_tokens] + self.target_coords = [t.to(dv, non_blocking=True) for t in self.target_coords] + self.target_tokens = [t.to(dv, non_blocking=True) for t in self.target_tokens] - self.source_idxs_embed = self.source_idxs_embed.to(device, non_blocking=True) - self.source_idxs_embed_pe = self.source_idxs_embed_pe.to(device, non_blocking=True) + self.source_idxs_embed = [s.to(dv, non_blocking=True) for s in self.source_idxs_embed] + self.source_idxs_embed_pe = [s.to(dv, non_blocking=True) for s in self.source_idxs_embed_pe] return self @@ -135,7 +136,9 @@ def add_empty_target(self, fstep: int) -> None: np.array([], dtype="datetime64[ns]") for _ in range(self.healpix_cells) ] - def add_source(self, ss_raw: IOReaderData, ss_lens: torch.tensor, ss_cells: list) -> None: + def add_source( + self, step: int, ss_raw: IOReaderData, ss_lens: torch.tensor, ss_cells: list + ) -> None: """ Add data for source for one input. @@ -151,12 +154,13 @@ def add_source(self, ss_raw: IOReaderData, ss_lens: torch.tensor, ss_cells: list None """ - self.source_raw = ss_raw - self.source_tokens_lens = ss_lens - self.source_tokens_cells = torch.stack(ss_cells) + # TODO: use step + self.source_raw += [ss_raw] + self.source_tokens_lens += [ss_lens] + self.source_tokens_cells += [torch.stack(ss_cells)] - idx = torch.isnan(self.source_tokens_cells) - self.source_tokens_cells[idx] = self.mask_value + idx = torch.isnan(self.source_tokens_cells[-1]) + self.source_tokens_cells[-1][idx] = self.mask_value def add_target( self, @@ -232,7 +236,7 @@ def source_empty(self) -> bool: True if target is empty for stream, else False """ - return self.source_tokens_lens.sum() == 0 + return torch.tensor([s.sum() for s in self.source_tokens_lens]).sum() == 0 def empty(self): """ diff --git a/src/weathergen/datasets/utils.py b/src/weathergen/datasets/utils.py index 98d5a044e..3d92e5a66 100644 --- a/src/weathergen/datasets/utils.py +++ b/src/weathergen/datasets/utils.py @@ -266,7 +266,7 @@ def add_local_vert_coords_ctrs2(verts_local, tcs_lens, a, zi, geoinfo_offset): return a -def compute_offsets_scatter_embed(batch: StreamData) -> StreamData: +def compute_offsets_scatter_embed(batch: StreamData, num_input_steps: int) -> StreamData: """ Compute auxiliary information for scatter operation that changes from stream-centric to cell-centric computations @@ -283,46 +283,52 @@ def compute_offsets_scatter_embed(batch: StreamData) -> StreamData: """ # collect source_tokens_lens for all stream datas - source_tokens_lens = torch.stack( - [ - torch.stack( - [ - s.source_tokens_lens if len(s.source_tokens_lens) > 0 else torch.tensor([]) - for s in stl_b - ] - ) - for stl_b in batch - ] - ) - - # precompute index sets for scatter operation after embed - offsets_base = source_tokens_lens.sum(1).sum(0).cumsum(0) - offsets = torch.cat([torch.zeros(1, dtype=torch.int32), offsets_base[:-1]]) - offsets_pe = torch.zeros_like(offsets) - - for ib, sb in enumerate(batch): - for itype, s in enumerate(sb): - if not s.source_empty(): - s.source_idxs_embed = torch.cat( - [ - torch.arange(offset, offset + token_len, dtype=torch.int64) - for offset, token_len in zip( - offsets, source_tokens_lens[ib, itype], strict=False - ) - ] - ) - s.source_idxs_embed_pe = torch.cat( + source_tokens_lens = [ + torch.stack( + [ + torch.stack( [ - torch.arange(offset, offset + token_len, dtype=torch.int32) - for offset, token_len in zip( - offsets_pe, source_tokens_lens[ib][itype], strict=False - ) + s.source_tokens_lens[i] + if len(s.source_tokens_lens[i]) > 0 + else torch.tensor([]) + for s in stl_b ] ) + for stl_b in batch + ] + ) + for i in range(num_input_steps) + ] - # advance offsets - offsets += source_tokens_lens[ib][itype] - offsets_pe += source_tokens_lens[ib][itype] + # precompute index sets for scatter operation after embed + offsets_base = [s.sum(1).sum(0).cumsum(0) for s in source_tokens_lens] + offsets = [torch.cat([torch.zeros(1, dtype=torch.int32), o[:-1]]) for o in offsets_base] + offsets_pe = [torch.zeros_like(o) for o in offsets] + + for i_s in range(num_input_steps): + for ib, sb in enumerate(batch): # batch items + for itype, s in enumerate(sb): # streams, i.e. here we have StreamData object + if not s.source_empty(): + s.source_idxs_embed[i_s] = torch.cat( + [ + torch.arange(offset, offset + token_len, dtype=torch.int64) + for offset, token_len in zip( + offsets[i_s], source_tokens_lens[i_s][ib, itype], strict=False + ) + ] + ) + s.source_idxs_embed_pe[i_s] = torch.cat( + [ + torch.arange(offset, offset + token_len, dtype=torch.int32) + for offset, token_len in zip( + offsets_pe[i_s], source_tokens_lens[i_s][ib][itype], strict=False + ) + ] + ) + + # advance offsets + offsets[i_s] += source_tokens_lens[i_s][ib][itype] + offsets_pe[i_s] += source_tokens_lens[i_s][ib][itype] return batch @@ -372,14 +378,16 @@ def compute_idxs_predict(forecast_dt: int, batch: StreamData) -> list: return tcs_lens_merged -def compute_source_cell_lens(batch: StreamData) -> torch.tensor: +def compute_source_cell_lens( + batch: list[list[StreamData]], num_input_steps: int +) -> list[torch.tensor]: """ Compute auxiliary information for varlen attention for local assimilation Parameters ---------- batch : - StreamData information for current batch + StreamData information for current batch for each batch item and each stream Returns ------- @@ -388,18 +396,24 @@ def compute_source_cell_lens(batch: StreamData) -> torch.tensor: """ # precompute for processing in the model (with varlen flash attention) - source_cell_lens_raw = torch.stack( - [ - torch.stack( - [ - s.source_tokens_lens if len(s.source_tokens_lens) > 0 else torch.tensor([]) - for s in stl_b - ] - ) - for stl_b in batch - ] - ) - source_cell_lens = torch.sum(source_cell_lens_raw, 1).flatten().to(torch.int32) - source_cell_lens = torch.cat([torch.zeros(1, dtype=torch.int32), source_cell_lens]) + source_cell_lens_raw = [ + torch.stack( + [ + torch.stack( + [ + s.source_tokens_lens[i] + if len(s.source_tokens_lens[i]) > 0 + else torch.tensor([]) + for s in stl_b + ] + ) + for stl_b in batch + ] + ) + for i in range(num_input_steps) + ] + + source_cell_lens = [torch.sum(c, 1).flatten().to(torch.int32) for c in source_cell_lens_raw] + source_cell_lens = [torch.cat([torch.zeros(1, dtype=torch.int32), c]) for c in source_cell_lens] return source_cell_lens diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index f9fa1598c..1f799e223 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -80,11 +80,14 @@ def __init__(self, cf: Config, sources_size) -> None: raise ValueError("Unsupported embedding network type") def forward(self, streams_data, pe_embed, dtype, device): + istep = 0 source_tokens_lens = torch.stack( [ torch.stack( [ - s.source_tokens_lens if len(s.source_tokens_lens) > 0 else torch.tensor([]) + s.source_tokens_lens[istep] + if len(s.source_tokens_lens[istep]) > 0 + else torch.tensor([]) for s in stl_b ] ) @@ -100,13 +103,13 @@ def forward(self, streams_data, pe_embed, dtype, device): for _, sb in enumerate(streams_data): for _, (s, embed) in enumerate(zip(sb, self.embeds, strict=False)): if not s.source_empty(): - idxs = s.source_idxs_embed.to(device) - idxs_pe = s.source_idxs_embed_pe.to(device) + idxs = s.source_idxs_embed[istep].to(device) + idxs_pe = s.source_idxs_embed_pe[istep].to(device) # create full scatter index # (there's no broadcasting which is likely highly inefficient) idxs = idxs.unsqueeze(1).repeat((1, self.cf.ae_local_dim_embed)) - x_embed = embed(s.source_tokens_cells).flatten(0, 1) + x_embed = embed(s.source_tokens_cells[istep]).flatten(0, 1) # there's undocumented limitation in flash_attn that will make embed fail if # #tokens is too large; code below is a work around # x_embed = torch.cat( diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 000f36735..10bc207d0 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -726,7 +726,9 @@ def assimilate_local( # work around to bug in flash attention for hl>=5 - cell_lens = cell_lens[1:] + istep = 0 + + cell_lens = cell_lens[istep][1:] clen = self.num_healpix_cells // (2 if self.cf.healpix_level <= 5 else 8) tokens_global_all = [] posteriors = [] diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index b24f944c9..09dc05e34 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -557,14 +557,14 @@ def _prepare_logging( targets_lens[fstep][i_strm] += [target.shape[0]] dn_data = self.dataset_val.denormalize_target_channels - # reorder so that output order of target points matches input when reading - # (tokenization and masking changes this order) - # TODO: does this work with batch_size > 1 - if len(idxs_inv) > 0: - pred = pred[:, idxs_inv] - target = target[idxs_inv] - targets_coords_raw[fstep][i_strm] = targets_coords_raw[fstep][i_strm][idxs_inv] - targets_times_raw[fstep][i_strm] = targets_times_raw[fstep][i_strm][idxs_inv] + # # reorder so that output order of target points matches input when reading + # # (tokenization and masking changes this order) + # # TODO: does this work with batch_size > 1 + # if len(idxs_inv) > 0: + # pred = pred[:, idxs_inv] + # target = target[idxs_inv] + # targets_coords_raw[fstep][i_strm] = targets_coords_raw[fstep][i_strm][idxs_inv] + # targets_times_raw[fstep][i_strm] = targets_times_raw[fstep][i_strm][idxs_inv] f32 = torch.float32 preds_all[fstep][i_strm] += [ @@ -759,7 +759,7 @@ def batch_to_device(self, batch): # forecast_steps is dropped here from the batch return ( [[d.to_device(self.device) for d in db] for db in batch[0]], - batch[1].to(self.device), + [b.to(self.device) for b in batch[1]], [[b.to(self.device) for b in bf] for bf in batch[2]], ) From 668912d4addea4d2a98f553f34a3df75473da709 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Tue, 18 Nov 2025 13:47:40 +0100 Subject: [PATCH 048/344] Partially enabled correct handling of multiple input steps. --- src/weathergen/model/engines.py | 83 +++++++++++++++------------------ src/weathergen/model/model.py | 10 ++-- 2 files changed, 44 insertions(+), 49 deletions(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 1f799e223..613d45540 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -79,52 +79,43 @@ def __init__(self, cf: Config, sources_size) -> None: else: raise ValueError("Unsupported embedding network type") - def forward(self, streams_data, pe_embed, dtype, device): - istep = 0 - source_tokens_lens = torch.stack( - [ - torch.stack( - [ - s.source_tokens_lens[istep] - if len(s.source_tokens_lens[istep]) > 0 - else torch.tensor([]) - for s in stl_b - ] - ) - for stl_b in streams_data - ] - ) - offsets_base = source_tokens_lens.sum(1).sum(0).cumsum(0) - - tokens_all = torch.empty( - (int(offsets_base[-1]), self.cf.ae_local_dim_embed), dtype=dtype, device=device - ) - - for _, sb in enumerate(streams_data): - for _, (s, embed) in enumerate(zip(sb, self.embeds, strict=False)): - if not s.source_empty(): - idxs = s.source_idxs_embed[istep].to(device) - idxs_pe = s.source_idxs_embed_pe[istep].to(device) - - # create full scatter index - # (there's no broadcasting which is likely highly inefficient) - idxs = idxs.unsqueeze(1).repeat((1, self.cf.ae_local_dim_embed)) - x_embed = embed(s.source_tokens_cells[istep]).flatten(0, 1) - # there's undocumented limitation in flash_attn that will make embed fail if - # #tokens is too large; code below is a work around - # x_embed = torch.cat( - # [ - # embed(s_c, c_c).flatten(0, 1) - # for s_c, c_c in zip( - # torch.split(s.source_tokens_cells, 49152), - # torch.split(s.source_centroids, 49152), - # ) - # ] - # ) - - # scatter write to reorder from per stream to per cell ordering - tokens_all.scatter_(0, idxs, x_embed + pe_embed[idxs_pe]) - return tokens_all + def forward(self, streams_data, source_cell_lens, pe_embed, dtype, device): + num_step_input = len(source_cell_lens) + + offsets_base = [torch.cumsum(s[1:], 0) for s in source_cell_lens] + + tokens_all = [ + torch.empty((int(ob[-1]), self.cf.ae_local_dim_embed), dtype=dtype, device=device) + for ob in offsets_base + ] + + for istep in range(num_step_input): + for _, sb in enumerate(streams_data): + for _, (s, embed) in enumerate(zip(sb, self.embeds, strict=False)): + if not s.source_empty(): + idxs = s.source_idxs_embed[istep].to(device) + idxs_pe = s.source_idxs_embed_pe[istep].to(device) + + # create full scatter index + # (there's no broadcasting which is likely highly inefficient) + idxs = idxs.unsqueeze(1).repeat((1, self.cf.ae_local_dim_embed)) + x_embed = embed(s.source_tokens_cells[istep]).flatten(0, 1) + # there's undocumented limitation in flash_attn that will make embed fail if + # #tokens is too large; code below is a work around + # x_embed = torch.cat( + # [ + # embed(s_c, c_c).flatten(0, 1) + # for s_c, c_c in zip( + # torch.split(s.source_tokens_cells, 49152), + # torch.split(s.source_centroids, 49152), + # ) + # ] + # ) + + # scatter write to reorder from per stream to per cell ordering + tokens_all[istep].scatter_(0, idxs, x_embed + pe_embed[idxs_pe]) + + return tokens_all[0] class LocalAssimilationEngine(torch.nn.Module): diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 10bc207d0..672989cf9 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -613,7 +613,7 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca (streams_data, source_cell_lens, target_coords_idxs) = batch # embed - tokens = self.embed_cells(model_params, streams_data) + tokens = self.embed_cells(model_params, streams_data, source_cell_lens) # local assimilation engine and adapter tokens, posteriors = self.assimilate_local(model_params, tokens, source_cell_lens) @@ -656,7 +656,9 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca return preds_all, posteriors ######################################### - def embed_cells(self, model_params: ModelParams, streams_data) -> torch.Tensor: + def embed_cells( + self, model_params: ModelParams, streams_data, source_cell_lens + ) -> torch.Tensor: """Embeds input data for each stream separately and rearranges it to cell-wise order Args: model_params : Query and embedding parameters @@ -666,7 +668,9 @@ def embed_cells(self, model_params: ModelParams, streams_data) -> torch.Tensor: """ device = next(self.parameters()).device - tokens_all = self.embed_engine(streams_data, model_params.pe_embed, self.dtype, device) + tokens_all = self.embed_engine( + streams_data, source_cell_lens, model_params.pe_embed, self.dtype, device + ) return tokens_all From 33394ffb3818902a9e830994846a7d24c0c44541 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Tue, 18 Nov 2025 14:53:25 +0100 Subject: [PATCH 049/344] initialize loss as torch tensor with grad --- src/weathergen/train/loss_calculator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index fbfaebdb0..f2deaacc8 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -12,6 +12,7 @@ import dataclasses import logging +import torch from omegaconf import DictConfig from torch import Tensor @@ -82,9 +83,9 @@ def compute_loss( targets: dict, ): loss_terms = {} - loss = 0 + loss = torch.tensor(0.0, requires_grad=True) for weight, calculator in self.loss_calculators: loss_terms[calculator.name] = calculator.compute_loss(preds=preds, targets=targets) - loss += weight * loss_terms[calculator.name].loss + loss = loss + weight * loss_terms[calculator.name].loss return LossTerms(loss=loss, loss_terms=loss_terms) From bda52d8f8d54b288b7076f4f9384c054ac535e80 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Tue, 18 Nov 2025 14:54:50 +0100 Subject: [PATCH 050/344] remove level in hist losses dict --- config/default_config.yml | 6 +- src/weathergen/train/loss_modules/loss.py | 60 ++++++++++++++++ src/weathergen/train/trainer.py | 86 +++++++++++------------ src/weathergen/utils/train_logger.py | 36 +++++----- 4 files changed, 121 insertions(+), 67 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index c101521da..e13bb9a11 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -94,7 +94,7 @@ ema_halflife_in_thousands: 1e-3 # training mode: "forecast" or "masking" (masked token modeling) # for "masking" to train with auto-encoder mode, forecast_offset should be 0 training_mode: "masking" -training_mode_config: {"losses": {LossPhysical: {weight: 0.7, loss_fcts: [['mse', 0.8], ['mse', 0.2]]}, +training_mode_config: {"losses": {LossPhysical: {weight: 0.7, loss_fcts: [['mse', 0.8], ['mae', 0.2]]}, LossPhysicalTwo: {weight: 0.3, loss_fcts: [['mse', 1.0]]}, } } @@ -124,8 +124,8 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"], } num_epochs: 32 -samples_per_epoch: 4096 -samples_per_validation: 512 +samples_per_epoch: 32 +samples_per_validation: 8 shuffle: True lr_scaling_policy: "sqrt" diff --git a/src/weathergen/train/loss_modules/loss.py b/src/weathergen/train/loss_modules/loss.py index 406cd051c..e09928dd1 100644 --- a/src/weathergen/train/loss_modules/loss.py +++ b/src/weathergen/train/loss_modules/loss.py @@ -186,6 +186,66 @@ def mse_channel_location_weighted( return loss, loss_chs +def mae( + target: torch.Tensor, + pred: torch.Tensor, + weights_channels: torch.Tensor | None, + weights_points: torch.Tensor | None, +): + """ + Compute weighted MAE loss for one window or step + + The function implements: + + loss = Mean_{channels}( weight_channels * Mean_{data_pts}( (target - pred) * weights_points )) + + Geometrically, + + ------------------------ - + | | | | + | | | | + | | | | + | target - pred | x |wp| + | | | | + | | | | + | | | | + ------------------------ - + x + ------------------------ + | wc | + ------------------------ + + where wp = weights_points and wc = weights_channels and "x" denotes row/col-wise multiplication. + + The computations are: + 1. weight the rows of (target - pred) by wp = weights_points + 2. take the mean over the row + 3. weight the collapsed cols by wc = weights_channels + 4. take the mean over the channel-weighted cols + + Params: + target : shape ( num_data_points , num_channels ) + target : shape ( ens_dim , num_data_points , num_channels) + weights_channels : shape = (num_channels,) + weights_points : shape = (num_data_points) + + Return: + loss : weight loss for gradient computation + loss_chs : losses per channel with location weighting but no channel weighting + """ + + mask_nan = ~torch.isnan(target) + pred = pred[0] if pred.shape[0] == 0 else pred.mean(0) + + diff2 = torch.where(mask_nan, target, 0) - torch.where(mask_nan, pred, 0) + if weights_points is not None: + diff2 = (diff2.transpose(1, 0) * weights_points).transpose(1, 0) + loss_chs = diff2.mean(0) + loss = torch.mean(loss_chs * weights_channels if weights_channels is not None else loss_chs) + + return loss, loss_chs + + def cosine_latitude(stream_data, forecast_offset, fstep, min_value=1e-3, max_value=1.0): latitudes_radian = stream_data.target_coords_raw[forecast_offset + fstep][:, 0] * np.pi / 180 return (max_value - min_value) * np.cos(latitudes_radian) + min_value diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index d580816ef..0eab2c289 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -582,9 +582,6 @@ def train(self, epoch): self.optimizer.zero_grad() - # Unweighted loss, real weighted loss, std for losses that need it - self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist = [], [], [] - # training loop self.t_start = time.time() for bidx, batch in enumerate(dataset_iter): @@ -642,28 +639,27 @@ def train(self, epoch): if bidx == 0: self.loss_unweighted_hist = { - calc_name: {loss_name: []} - for calc_name, calc_terms in loss_values.loss_terms.items() + loss_name: [] + for _, calc_terms in loss_values.loss_terms.items() for loss_name in calc_terms.losses_all.keys() } self.stdev_unweighted_hist = { - calc_name: {loss_name: []} - for calc_name, calc_terms in loss_values.loss_terms.items() + loss_name: [] + for _, calc_terms in loss_values.loss_terms.items() for loss_name in calc_terms.stddev_all.keys() } self.loss_model_hist = [] - for calc_name, loss_terms in loss_values.loss_terms.items(): + for _, loss_terms in loss_values.loss_terms.items(): for loss_name, losses_all in loss_terms.losses_all.items(): - self.loss_unweighted_hist[calc_name][loss_name].append(losses_all) + self.loss_unweighted_hist[loss_name].append(losses_all) for loss_name, stddev_all in loss_terms.stddev_all.items(): - self.stdev_unweighted_hist[calc_name][loss_name].append(stddev_all) + self.stdev_unweighted_hist[loss_name].append(stddev_all) self.loss_model_hist += [loss_values.loss.item()] perf_gpu, perf_mem = self.get_perf() self.perf_gpu = ddp_average(torch.tensor([perf_gpu], device=self.device)).item() self.perf_mem = ddp_average(torch.tensor([perf_mem], device=self.device)).item() - # NEED TO FIX LOGGING self._log_terminal(bidx, epoch, TRAIN) if bidx % self.train_log_freq.metrics == 0: self._log(TRAIN) @@ -672,13 +668,13 @@ def train(self, epoch): if bidx % self.train_log_freq.checkpoint == 0 and bidx > 0: self.save_model(-1) self.loss_unweighted_hist = { - calc_name: {loss_name: []} - for calc_name, calc_terms in loss_values.loss_terms.items() + loss_name: [] + for _, calc_terms in loss_values.loss_terms.items() for loss_name in calc_terms.losses_all.keys() } self.stdev_unweighted_hist = { - calc_name: {loss_name: []} - for calc_name, calc_terms in loss_values.loss_terms.items() + loss_name: [] + for _, calc_terms in loss_values.loss_terms.items() for loss_name in calc_terms.stddev_all.keys() } self.loss_model_hist = [] @@ -756,15 +752,23 @@ def validate(self, epoch): sample_idxs, ) - self.loss_unweighted_hist += [loss_values.loss_terms] - self.loss_model_hist += [loss_values.loss.item()] if bidx == 0: - self.loss_unweighted_hist = {k: [] for k in loss_values.loss_terms.keys()} - self.stdev_unweighted_hist = {k: [] for k in loss_values.loss_terms.keys()} + self.loss_unweighted_hist = { + loss_name: [] + for _, calc_terms in loss_values.loss_terms.items() + for loss_name in calc_terms.losses_all.keys() + } + self.stdev_unweighted_hist = { + loss_name: [] + for _, calc_terms in loss_values.loss_terms.items() + for loss_name in calc_terms.stddev_all.keys() + } self.loss_model_hist = [] - for name, loss_terms in loss_values.loss_terms.items(): - self.loss_unweighted_hist[name].append(loss_terms.losses_all) - self.stdev_unweighted_hist[name].append(loss_terms.stddev_all) + for _, loss_terms in loss_values.loss_terms.items(): + for loss_name, losses_all in loss_terms.losses_all.items(): + self.loss_unweighted_hist[loss_name].append(losses_all) + for loss_name, stddev_all in loss_terms.stddev_all.items(): + self.stdev_unweighted_hist[loss_name].append(stddev_all) self.loss_model_hist += [loss_values.loss.item()] pbar.update(self.cf.batch_size_validation_per_gpu) @@ -995,24 +999,20 @@ def _prepare_losses_for_logging( stddev_all (dict[str, torch.Tensor]): Dictionary mapping each stream name to its per-channel standard deviation tensor. """ - losses_all: dict[dict[str, Tensor]] = {} - stddev_all: dict[dict[str, Tensor]] = {} + losses_all: dict[str, Tensor] = {} + stddev_all: dict[str, Tensor] = {} # Make list of losses into a tensor. This is individual tensor per rank real_loss = torch.tensor(self.loss_model_hist, device=self.device) # Gather all tensors from all ranks into a list and stack them into one tensor again real_loss = torch.cat(all_gather_vlen(real_loss)) - for calc_name, loss_terms in self.loss_unweighted_hist.items(): - losses_all[calc_name] = {} - for loss_name, losses in loss_terms.items(): - losses = torch.stack(losses).to(torch.float64) - losses_all[calc_name][loss_name] = torch.cat(all_gather_vlen(losses)) - for calc_name, stddev_terms in self.stdev_unweighted_hist.items(): - stddev_all[calc_name] = {} - for stddev_name, stddevs in stddev_terms.items(): - stddevs = torch.stack(stddevs).to(torch.float64) - stddev_all[calc_name][stddev_name] = torch.cat(all_gather_vlen(stddevs)) + for loss_name, loss_values in self.loss_unweighted_hist.items(): + loss_values = torch.stack(loss_values).to(torch.float64) + losses_all[loss_name] = torch.cat(all_gather_vlen(loss_values)) + for stddev_name, stddev_values in self.stdev_unweighted_hist.items(): + stddev_values = torch.stack(stddev_values).to(torch.float64) + stddev_all[stddev_name] = torch.cat(all_gather_vlen(stddev_values)) return real_loss, losses_all, stddev_all @@ -1078,11 +1078,10 @@ def _log_terminal(self, bidx: int, epoch: int, stage: Stage): logger.info( f"validation ({self.cf.run_id}) : {epoch:03d} : {avg_loss.nanmean().item()}" ) - for _, losses in losses_all.items(): - for loss_name, loss in losses.items(): - logger.info( - f"{loss_name}" + f" : {loss.nanmean():0.4E} \t", - ) + for loss_name, loss_values in losses_all.items(): + logger.info( + f"{loss_name}" + f" : {loss_values.nanmean():0.4E} \t", + ) logger.info("\n") elif stage == TRAIN: @@ -1099,11 +1098,10 @@ def _log_terminal(self, bidx: int, epoch: int, stage: Stage): pstr += f"s/sec={(print_freq * self.cf.batch_size_per_gpu) / dt:.3f})" logger.info(pstr) logger.info("\t") - for _, losses in losses_all.items(): - for loss_name, loss in losses.items(): - logger.info( - f"{loss_name}" + f" : {loss.nanmean():0.4E} \t", - ) + for loss_name, loss_values in losses_all.items(): + logger.info( + f"{loss_name}" + f" : {loss_values.nanmean():0.4E} \t", + ) logger.info("\n") self.t_start = time.time() diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index 1df6930d4..2df60651a 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -120,16 +120,14 @@ def add_train( log_vals += [lr] st = self.cf.streams[0] - for _, loss_terms in losses_all.items(): - for loss_name, losses in loss_terms.items(): - metrics[f"{loss_name}.loss_avg"] = losses[:, :].nanmean().item() - for k, ch_n in enumerate(st.train_target_channels): - metrics[f"{loss_name}.loss_{ch_n}"] = losses[:, k].nanmean().item() - log_vals += [losses[:, :].nanmean().item()] - for _, stddev_terms in stddev_all.items(): - for loss_name, stddev in stddev_terms.items(): - metrics[f"{loss_name}.stddev_avg"] = stddev.nanmean().item() - log_vals += [stddev.nanmean().item()] + for loss_name, loss_values in losses_all.items(): + metrics[f"{loss_name}.loss_avg"] = loss_values[:, :].nanmean().item() + for k, ch_n in enumerate(st.train_target_channels): + metrics[f"{loss_name}.loss_{ch_n}"] = loss_values[:, k].nanmean().item() + log_vals += [loss_values[:, :].nanmean().item()] + for loss_name, stddev_values in stddev_all.items(): + metrics[f"{loss_name}.stddev_avg"] = stddev_values.nanmean().item() + log_vals += [stddev_values.nanmean().item()] with open(self.path_run / f"{self.cf.run_id}_train_log.txt", "ab") as f: np.savetxt(f, log_vals) @@ -157,16 +155,14 @@ def add_val( log_vals += [samples] st = self.cf.streams[0] - for _, loss_terms in losses_all.items(): - for loss_name, losses in loss_terms.items(): - metrics[f"{loss_name}.loss_avg"] = losses[:, :].nanmean().item() - for k, ch_n in enumerate(st.train_target_channels): - metrics[f"{loss_name}.loss_{ch_n}"] = losses[:, k].nanmean().item() - log_vals += [losses[:, :].nanmean().item()] - for _, stddev_terms in stddev_all.items(): - for loss_name, stddev in stddev_terms.items(): - metrics[f"{loss_name}.stddev_avg"] = stddev.nanmean().item() - log_vals += [stddev.nanmean().item()] + for loss_name, loss_values in losses_all.items(): + metrics[f"{loss_name}.loss_avg"] = loss_values[:, :].nanmean().item() + for k, ch_n in enumerate(st.train_target_channels): + metrics[f"{loss_name}.loss_{ch_n}"] = loss_values[:, k].nanmean().item() + log_vals += [loss_values[:, :].nanmean().item()] + for loss_name, stddev_values in stddev_all.items(): + metrics[f"{loss_name}.stddev_avg"] = stddev_values.nanmean().item() + log_vals += [stddev_values.nanmean().item()] self.log_metrics("val", metrics) with open(self.path_run / (self.cf.run_id + "_val_log.txt"), "ab") as f: From 053dddd72c06a48b0eea6fb226c3b3b33d64876f Mon Sep 17 00:00:00 2001 From: Jubeku Date: Tue, 18 Nov 2025 14:57:05 +0100 Subject: [PATCH 051/344] rename loss.py to loss_functions.py --- src/weathergen/train/loss_modules/loss.py | 257 ------------------ .../train/loss_modules/loss_module_latent.py | 2 +- .../loss_modules/loss_module_physical.py | 2 +- 3 files changed, 2 insertions(+), 259 deletions(-) delete mode 100644 src/weathergen/train/loss_modules/loss.py diff --git a/src/weathergen/train/loss_modules/loss.py b/src/weathergen/train/loss_modules/loss.py deleted file mode 100644 index e09928dd1..000000000 --- a/src/weathergen/train/loss_modules/loss.py +++ /dev/null @@ -1,257 +0,0 @@ -# (C) Copyright 2025 WeatherGenerator contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import numpy as np -import torch - -stat_loss_fcts = ["stats", "kernel_crps"] # Names of loss functions that need std computed - - -def gaussian(x, mu=0.0, std_dev=1.0): - # unnormalized Gaussian where maximum is one - return torch.exp(-0.5 * (x - mu) * (x - mu) / (std_dev * std_dev)) - - -def normalized_gaussian(x, mu=0.0, std_dev=1.0): - return (1 / (std_dev * np.sqrt(2.0 * np.pi))) * torch.exp( - -0.5 * (x - mu) * (x - mu) / (std_dev * std_dev) - ) - - -def erf(x, mu=0.0, std_dev=1.0): - c1 = torch.sqrt(torch.tensor(0.5 * np.pi)) - c2 = torch.sqrt(1.0 / torch.tensor(std_dev * std_dev)) - c3 = torch.sqrt(torch.tensor(2.0)) - val = c1 * (1.0 / c2 - std_dev * torch.special.erf((mu - x) / (c3 * std_dev))) - return val - - -def gaussian_crps(target, ens, mu, stddev): - # see Eq. A2 in S. Rasp and S. Lerch. Neural networks for postprocessing ensemble weather - # forecasts. Monthly Weather Review, 146(11):3885 – 3900, 2018. - c1 = np.sqrt(1.0 / np.pi) - t1 = 2.0 * erf((target - mu) / stddev) - 1.0 - t2 = 2.0 * normalized_gaussian((target - mu) / stddev) - val = stddev * ((target - mu) / stddev * t1 + t2 - c1) - return torch.mean(val) # + torch.mean( torch.sqrt( stddev) ) - - -def stats(target, ens, mu, stddev): - diff = gaussian(target, mu, stddev) - 1.0 - return torch.mean(diff * diff) + torch.mean(torch.sqrt(stddev)) - - -def stats_normalized(target, ens, mu, stddev): - a = normalized_gaussian(target, mu, stddev) - max = 1 / (np.sqrt(2 * np.pi) * stddev) - d = a - max - return torch.mean(d * d) + torch.mean(torch.sqrt(stddev)) - - -def stats_normalized_erf(target, ens, mu, stddev): - delta = -torch.abs(target - mu) - d = 0.5 + torch.special.erf(delta / (np.sqrt(2.0) * stddev)) - return torch.mean(d * d) # + torch.mean( torch.sqrt( stddev) ) - - -def mse(target, ens, mu, *kwargs): - return torch.nn.functional.mse_loss(target, mu) - - -def mse_ens(target, ens, mu, stddev): - mse_loss = torch.nn.functional.mse_loss - return torch.stack([mse_loss(target, mem) for mem in ens], 0).mean() - - -def kernel_crps( - targets, - preds, - weights_channels: torch.Tensor | None, - weights_points: torch.Tensor | None, - fair=True, -): - """ - Compute kernel CRPS - - Params: - target : shape ( num_data_points , num_channels ) - pred : shape ( ens_dim , num_data_points , num_channels) - weights_channels : shape = (num_channels,) - weights_points : shape = (num_data_points) - - Returns: - loss: scalar - overall weighted CRPS - loss_chs: [C] - per-channel CRPS (location-weighted, not channel-weighted) - """ - - ens_size = preds.shape[0] - assert ens_size > 1, "Ensemble size has to be greater than 1 for kernel CRPS." - assert len(preds.shape) == 3, "if data has batch dimension, remove unsqueeze() below" - - # replace NaN by 0 - mask_nan = ~torch.isnan(targets) - targets = torch.where(mask_nan, targets, 0) - preds = torch.where(mask_nan, preds, 0) - - # permute to enable/simply broadcasting and contractions below - preds = preds.permute([2, 1, 0]).unsqueeze(0).to(torch.float32) - targets = targets.permute([1, 0]).unsqueeze(0).to(torch.float32) - - mae = torch.mean(torch.abs(targets[..., None] - preds), dim=-1) - - ens_n = -1.0 / (ens_size * (ens_size - 1)) if fair else -1.0 / (ens_size**2) - abs = torch.abs - ens_var = torch.zeros(size=preds.shape[:-1], device=preds.device) - # loop to reduce memory usage - for i in range(ens_size): - ens_var += torch.sum(ens_n * abs(preds[..., i].unsqueeze(-1) - preds[..., i + 1 :]), dim=-1) - - kcrps_locs_chs = mae + ens_var - - # apply point weighting - if weights_points is not None: - kcrps_locs_chs = kcrps_locs_chs * weights_points - # apply channel weighting - kcrps_chs = torch.mean(torch.mean(kcrps_locs_chs, 0), -1) - if weights_channels is not None: - kcrps_chs = kcrps_chs * weights_channels - - return torch.mean(kcrps_chs), kcrps_chs - - -def mse_channel_location_weighted( - target: torch.Tensor, - pred: torch.Tensor, - weights_channels: torch.Tensor | None, - weights_points: torch.Tensor | None, -): - """ - Compute weighted MSE loss for one window or step - - The function implements: - - loss = Mean_{channels}( weight_channels * Mean_{data_pts}( (target - pred) * weights_points )) - - Geometrically, - - ------------------------ - - | | | | - | | | | - | | | | - | target - pred | x |wp| - | | | | - | | | | - | | | | - ------------------------ - - x - ------------------------ - | wc | - ------------------------ - - where wp = weights_points and wc = weights_channels and "x" denotes row/col-wise multiplication. - - The computations are: - 1. weight the rows of (target - pred) by wp = weights_points - 2. take the mean over the row - 3. weight the collapsed cols by wc = weights_channels - 4. take the mean over the channel-weighted cols - - Params: - target : shape ( num_data_points , num_channels ) - target : shape ( ens_dim , num_data_points , num_channels) - weights_channels : shape = (num_channels,) - weights_points : shape = (num_data_points) - - Return: - loss : weight loss for gradient computation - loss_chs : losses per channel with location weighting but no channel weighting - """ - - mask_nan = ~torch.isnan(target) - pred = pred[0] if pred.shape[0] == 0 else pred.mean(0) - - diff2 = torch.square(torch.where(mask_nan, target, 0) - torch.where(mask_nan, pred, 0)) - if weights_points is not None: - diff2 = (diff2.transpose(1, 0) * weights_points).transpose(1, 0) - loss_chs = diff2.mean(0) - loss = torch.mean(loss_chs * weights_channels if weights_channels is not None else loss_chs) - - return loss, loss_chs - - -def mae( - target: torch.Tensor, - pred: torch.Tensor, - weights_channels: torch.Tensor | None, - weights_points: torch.Tensor | None, -): - """ - Compute weighted MAE loss for one window or step - - The function implements: - - loss = Mean_{channels}( weight_channels * Mean_{data_pts}( (target - pred) * weights_points )) - - Geometrically, - - ------------------------ - - | | | | - | | | | - | | | | - | target - pred | x |wp| - | | | | - | | | | - | | | | - ------------------------ - - x - ------------------------ - | wc | - ------------------------ - - where wp = weights_points and wc = weights_channels and "x" denotes row/col-wise multiplication. - - The computations are: - 1. weight the rows of (target - pred) by wp = weights_points - 2. take the mean over the row - 3. weight the collapsed cols by wc = weights_channels - 4. take the mean over the channel-weighted cols - - Params: - target : shape ( num_data_points , num_channels ) - target : shape ( ens_dim , num_data_points , num_channels) - weights_channels : shape = (num_channels,) - weights_points : shape = (num_data_points) - - Return: - loss : weight loss for gradient computation - loss_chs : losses per channel with location weighting but no channel weighting - """ - - mask_nan = ~torch.isnan(target) - pred = pred[0] if pred.shape[0] == 0 else pred.mean(0) - - diff2 = torch.where(mask_nan, target, 0) - torch.where(mask_nan, pred, 0) - if weights_points is not None: - diff2 = (diff2.transpose(1, 0) * weights_points).transpose(1, 0) - loss_chs = diff2.mean(0) - loss = torch.mean(loss_chs * weights_channels if weights_channels is not None else loss_chs) - - return loss, loss_chs - - -def cosine_latitude(stream_data, forecast_offset, fstep, min_value=1e-3, max_value=1.0): - latitudes_radian = stream_data.target_coords_raw[forecast_offset + fstep][:, 0] * np.pi / 180 - return (max_value - min_value) * np.cos(latitudes_radian) + min_value - - -def gamma_decay(forecast_steps, gamma): - fsteps = np.arange(forecast_steps) - weights = gamma**fsteps - return weights * (len(fsteps) / np.sum(weights)) diff --git a/src/weathergen/train/loss_modules/loss_module_latent.py b/src/weathergen/train/loss_modules/loss_module_latent.py index 6daf472bb..c680382ee 100644 --- a/src/weathergen/train/loss_modules/loss_module_latent.py +++ b/src/weathergen/train/loss_modules/loss_module_latent.py @@ -15,7 +15,7 @@ from omegaconf import DictConfig from torch import Tensor -import weathergen.train.loss_modules.loss as losses +import weathergen.train.loss_modules.loss_functions as losses from weathergen.train.loss_modules.loss_module_base import LossModuleBase, LossValues from weathergen.utils.train_logger import Stage diff --git a/src/weathergen/train/loss_modules/loss_module_physical.py b/src/weathergen/train/loss_modules/loss_module_physical.py index 817360706..7bef87064 100644 --- a/src/weathergen/train/loss_modules/loss_module_physical.py +++ b/src/weathergen/train/loss_modules/loss_module_physical.py @@ -16,7 +16,7 @@ from omegaconf import DictConfig from torch import Tensor -import weathergen.train.loss_modules.loss as losses +import weathergen.train.loss_modules.loss_functions as losses from weathergen.train.loss_modules.loss import stat_loss_fcts from weathergen.train.loss_modules.loss_module_base import LossModuleBase, LossValues from weathergen.utils.train_logger import TRAIN, VAL, Stage From d094ad0e28cef105a76fab709f8b9a2e00002946 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Tue, 18 Nov 2025 15:09:11 +0100 Subject: [PATCH 052/344] rename loss.py to loss_functions.py --- src/weathergen/train/loss_modules/loss_module_physical.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/train/loss_modules/loss_module_physical.py b/src/weathergen/train/loss_modules/loss_module_physical.py index 7bef87064..1e9cceb1b 100644 --- a/src/weathergen/train/loss_modules/loss_module_physical.py +++ b/src/weathergen/train/loss_modules/loss_module_physical.py @@ -17,7 +17,7 @@ from torch import Tensor import weathergen.train.loss_modules.loss_functions as losses -from weathergen.train.loss_modules.loss import stat_loss_fcts +from weathergen.train.loss_modules.loss_functions import stat_loss_fcts from weathergen.train.loss_modules.loss_module_base import LossModuleBase, LossValues from weathergen.utils.train_logger import TRAIN, VAL, Stage From 8b4cbef35f26c723b491ee492ed39a38a43e19ff Mon Sep 17 00:00:00 2001 From: Jubeku Date: Tue, 18 Nov 2025 15:10:00 +0100 Subject: [PATCH 053/344] return loss with grads seperately to trainer --- src/weathergen/train/loss_calculator.py | 5 +---- src/weathergen/train/trainer.py | 12 ++++++------ 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index f2deaacc8..d462b3c1b 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -14,7 +14,6 @@ import torch from omegaconf import DictConfig -from torch import Tensor import weathergen.train.loss_modules as LossModules from weathergen.train.loss_modules.loss_module_base import LossValues @@ -29,8 +28,6 @@ class LossTerms: A dataclass which combines the LossValues of all loss modules """ - # The primary scalar loss value for optimization. - loss: Tensor # Dictionary containing the LossValues of each loss module. loss_terms: dict[str, LossValues] @@ -88,4 +85,4 @@ def compute_loss( loss_terms[calculator.name] = calculator.compute_loss(preds=preds, targets=targets) loss = loss + weight * loss_terms[calculator.name].loss - return LossTerms(loss=loss, loss_terms=loss_terms) + return loss, LossTerms(loss_terms=loss_terms) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 0eab2c289..295fac732 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -596,17 +596,17 @@ def train(self, epoch): ): output = self.model(self.model_params, batch, cf.forecast_offset, forecast_steps) targets = {"physical": batch[0]} - loss_values = self.loss_calculator.compute_loss( + loss, loss_values = self.loss_calculator.compute_loss( preds=output, targets=targets, ) if cf.latent_noise_kl_weight > 0.0: kl = torch.cat([posterior.kl() for posterior in output.latent]) - loss_values.loss += cf.latent_noise_kl_weight * kl.mean() + loss += cf.latent_noise_kl_weight * kl.mean() # backward pass self.optimizer.zero_grad() - self.grad_scaler.scale(loss_values.loss).backward() + self.grad_scaler.scale(loss).backward() # loss_values.loss.backward() # gradient clipping @@ -654,7 +654,7 @@ def train(self, epoch): self.loss_unweighted_hist[loss_name].append(losses_all) for loss_name, stddev_all in loss_terms.stddev_all.items(): self.stdev_unweighted_hist[loss_name].append(stddev_all) - self.loss_model_hist += [loss_values.loss.item()] + self.loss_model_hist += [loss.item()] perf_gpu, perf_mem = self.get_perf() self.perf_gpu = ddp_average(torch.tensor([perf_gpu], device=self.device)).item() @@ -716,7 +716,7 @@ def validate(self, epoch): targets = {"physical": batch[0]} # compute loss - loss_values = self.loss_calculator_val.compute_loss( + loss, loss_values = self.loss_calculator_val.compute_loss( preds=output, targets=targets, ) @@ -769,7 +769,7 @@ def validate(self, epoch): self.loss_unweighted_hist[loss_name].append(losses_all) for loss_name, stddev_all in loss_terms.stddev_all.items(): self.stdev_unweighted_hist[loss_name].append(stddev_all) - self.loss_model_hist += [loss_values.loss.item()] + self.loss_model_hist += [loss.item()] pbar.update(self.cf.batch_size_validation_per_gpu) From dd6f85aeca3a44866af1c3202a424747a231fadf Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Tue, 18 Nov 2025 15:30:22 +0100 Subject: [PATCH 054/344] Added mode and refactored get_sample_data into separate function. --- .../datasets/multi_stream_data_sampler.py | 53 ++++++++++++------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 266c9e0c7..d3529781d 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -314,6 +314,7 @@ def denormalize_target_channels(self, stream_id, data) -> torch.Tensor: def _build_stream_data_source( self, + mode: str, stream_data: StreamData, base_idx: TIndex, forecast_dt: int, @@ -371,6 +372,7 @@ def _build_stream_data_source( def _build_stream_data_target( self, + mode: str, stream_data: StreamData, idx: TIndex, forecast_dt: int, @@ -409,6 +411,35 @@ def _build_stream_data_target( return stream_data + def _get_sample_data(self, mode: str, idx: int, forecast_dt: int): + """ + + mode : {student, teacher, mtm} + """ + + streams_data: list[StreamData] = [] + + # for all streams + for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): + stream_data = StreamData( + idx, forecast_dt + self.forecast_offset, self.num_healpix_cells + ) + + # collect source data for current stream + stream_data = self._build_stream_data_source( + mode, stream_data, idx, forecast_dt, stream_info, stream_ds + ) + + # collect target data for current stream + stream_data = self._build_stream_data_target( + mode, stream_data, idx, forecast_dt, stream_info, stream_ds + ) + + # add data for current stream + streams_data += [stream_data] + + return streams_data + def _preprocess_model_data(self, batch, forecast_dt): # aggregated lens of tokens per cell across input batch samples source_cell_lens = compute_source_cell_lens(batch, self.num_input_steps) @@ -458,28 +489,10 @@ def __iter__(self): if hasattr(self.tokenizer, "masker"): self.tokenizer.masker.set_batch_strategy() - streams_data: list[StreamData] = [] + mode = "teacher" # tokenizer.generate_masks_for_sample() - - # for all streams - for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): - stream_data = StreamData( - idx, forecast_dt + self.forecast_offset, self.num_healpix_cells - ) - - # collect source data for current stream - stream_data = self._build_stream_data_source( - stream_data, idx, forecast_dt, stream_info, stream_ds - ) - - # collect target data for current stream - stream_data = self._build_stream_data_target( - stream_data, idx, forecast_dt, stream_info, stream_ds - ) - - # add data for current stream - streams_data += [stream_data] + streams_data = self._get_sample_data(mode, idx, forecast_dt) # Reset masking strategy for next batch item if hasattr(self.tokenizer, "masker"): From d0ef572f67cff47832ed60e4953a35bcaa4c1059 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Tue, 18 Nov 2025 16:05:15 +0100 Subject: [PATCH 055/344] modify log names --- src/weathergen/utils/train_logger.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index 2df60651a..e0cbe49de 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -121,12 +121,12 @@ def add_train( st = self.cf.streams[0] for loss_name, loss_values in losses_all.items(): - metrics[f"{loss_name}.loss_avg"] = loss_values[:, :].nanmean().item() + metrics[f"loss.{loss_name}.loss_avg"] = loss_values[:, :].nanmean().item() for k, ch_n in enumerate(st.train_target_channels): - metrics[f"{loss_name}.loss_{ch_n}"] = loss_values[:, k].nanmean().item() + metrics[f"loss.{loss_name}.{ch_n}"] = loss_values[:, k].nanmean().item() log_vals += [loss_values[:, :].nanmean().item()] for loss_name, stddev_values in stddev_all.items(): - metrics[f"{loss_name}.stddev_avg"] = stddev_values.nanmean().item() + metrics[f"loss.{loss_name}.stddev_avg"] = stddev_values.nanmean().item() log_vals += [stddev_values.nanmean().item()] with open(self.path_run / f"{self.cf.run_id}_train_log.txt", "ab") as f: @@ -156,12 +156,12 @@ def add_val( st = self.cf.streams[0] for loss_name, loss_values in losses_all.items(): - metrics[f"{loss_name}.loss_avg"] = loss_values[:, :].nanmean().item() + metrics[f"loss.{loss_name}.loss_avg"] = loss_values[:, :].nanmean().item() for k, ch_n in enumerate(st.train_target_channels): - metrics[f"{loss_name}.loss_{ch_n}"] = loss_values[:, k].nanmean().item() + metrics[f"loss.{loss_name}.{ch_n}"] = loss_values[:, k].nanmean().item() log_vals += [loss_values[:, :].nanmean().item()] for loss_name, stddev_values in stddev_all.items(): - metrics[f"{loss_name}.stddev_avg"] = stddev_values.nanmean().item() + metrics[f"loss.{loss_name}.stddev_avg"] = stddev_values.nanmean().item() log_vals += [stddev_values.nanmean().item()] self.log_metrics("val", metrics) From c6805c4fc9fc6b47021daa43ffea26d9138c10f4 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Tue, 18 Nov 2025 16:09:49 +0100 Subject: [PATCH 056/344] add loss_functions.py --- .../train/loss_modules/loss_functions.py | 257 ++++++++++++++++++ 1 file changed, 257 insertions(+) create mode 100644 src/weathergen/train/loss_modules/loss_functions.py diff --git a/src/weathergen/train/loss_modules/loss_functions.py b/src/weathergen/train/loss_modules/loss_functions.py new file mode 100644 index 000000000..e09928dd1 --- /dev/null +++ b/src/weathergen/train/loss_modules/loss_functions.py @@ -0,0 +1,257 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import numpy as np +import torch + +stat_loss_fcts = ["stats", "kernel_crps"] # Names of loss functions that need std computed + + +def gaussian(x, mu=0.0, std_dev=1.0): + # unnormalized Gaussian where maximum is one + return torch.exp(-0.5 * (x - mu) * (x - mu) / (std_dev * std_dev)) + + +def normalized_gaussian(x, mu=0.0, std_dev=1.0): + return (1 / (std_dev * np.sqrt(2.0 * np.pi))) * torch.exp( + -0.5 * (x - mu) * (x - mu) / (std_dev * std_dev) + ) + + +def erf(x, mu=0.0, std_dev=1.0): + c1 = torch.sqrt(torch.tensor(0.5 * np.pi)) + c2 = torch.sqrt(1.0 / torch.tensor(std_dev * std_dev)) + c3 = torch.sqrt(torch.tensor(2.0)) + val = c1 * (1.0 / c2 - std_dev * torch.special.erf((mu - x) / (c3 * std_dev))) + return val + + +def gaussian_crps(target, ens, mu, stddev): + # see Eq. A2 in S. Rasp and S. Lerch. Neural networks for postprocessing ensemble weather + # forecasts. Monthly Weather Review, 146(11):3885 – 3900, 2018. + c1 = np.sqrt(1.0 / np.pi) + t1 = 2.0 * erf((target - mu) / stddev) - 1.0 + t2 = 2.0 * normalized_gaussian((target - mu) / stddev) + val = stddev * ((target - mu) / stddev * t1 + t2 - c1) + return torch.mean(val) # + torch.mean( torch.sqrt( stddev) ) + + +def stats(target, ens, mu, stddev): + diff = gaussian(target, mu, stddev) - 1.0 + return torch.mean(diff * diff) + torch.mean(torch.sqrt(stddev)) + + +def stats_normalized(target, ens, mu, stddev): + a = normalized_gaussian(target, mu, stddev) + max = 1 / (np.sqrt(2 * np.pi) * stddev) + d = a - max + return torch.mean(d * d) + torch.mean(torch.sqrt(stddev)) + + +def stats_normalized_erf(target, ens, mu, stddev): + delta = -torch.abs(target - mu) + d = 0.5 + torch.special.erf(delta / (np.sqrt(2.0) * stddev)) + return torch.mean(d * d) # + torch.mean( torch.sqrt( stddev) ) + + +def mse(target, ens, mu, *kwargs): + return torch.nn.functional.mse_loss(target, mu) + + +def mse_ens(target, ens, mu, stddev): + mse_loss = torch.nn.functional.mse_loss + return torch.stack([mse_loss(target, mem) for mem in ens], 0).mean() + + +def kernel_crps( + targets, + preds, + weights_channels: torch.Tensor | None, + weights_points: torch.Tensor | None, + fair=True, +): + """ + Compute kernel CRPS + + Params: + target : shape ( num_data_points , num_channels ) + pred : shape ( ens_dim , num_data_points , num_channels) + weights_channels : shape = (num_channels,) + weights_points : shape = (num_data_points) + + Returns: + loss: scalar - overall weighted CRPS + loss_chs: [C] - per-channel CRPS (location-weighted, not channel-weighted) + """ + + ens_size = preds.shape[0] + assert ens_size > 1, "Ensemble size has to be greater than 1 for kernel CRPS." + assert len(preds.shape) == 3, "if data has batch dimension, remove unsqueeze() below" + + # replace NaN by 0 + mask_nan = ~torch.isnan(targets) + targets = torch.where(mask_nan, targets, 0) + preds = torch.where(mask_nan, preds, 0) + + # permute to enable/simply broadcasting and contractions below + preds = preds.permute([2, 1, 0]).unsqueeze(0).to(torch.float32) + targets = targets.permute([1, 0]).unsqueeze(0).to(torch.float32) + + mae = torch.mean(torch.abs(targets[..., None] - preds), dim=-1) + + ens_n = -1.0 / (ens_size * (ens_size - 1)) if fair else -1.0 / (ens_size**2) + abs = torch.abs + ens_var = torch.zeros(size=preds.shape[:-1], device=preds.device) + # loop to reduce memory usage + for i in range(ens_size): + ens_var += torch.sum(ens_n * abs(preds[..., i].unsqueeze(-1) - preds[..., i + 1 :]), dim=-1) + + kcrps_locs_chs = mae + ens_var + + # apply point weighting + if weights_points is not None: + kcrps_locs_chs = kcrps_locs_chs * weights_points + # apply channel weighting + kcrps_chs = torch.mean(torch.mean(kcrps_locs_chs, 0), -1) + if weights_channels is not None: + kcrps_chs = kcrps_chs * weights_channels + + return torch.mean(kcrps_chs), kcrps_chs + + +def mse_channel_location_weighted( + target: torch.Tensor, + pred: torch.Tensor, + weights_channels: torch.Tensor | None, + weights_points: torch.Tensor | None, +): + """ + Compute weighted MSE loss for one window or step + + The function implements: + + loss = Mean_{channels}( weight_channels * Mean_{data_pts}( (target - pred) * weights_points )) + + Geometrically, + + ------------------------ - + | | | | + | | | | + | | | | + | target - pred | x |wp| + | | | | + | | | | + | | | | + ------------------------ - + x + ------------------------ + | wc | + ------------------------ + + where wp = weights_points and wc = weights_channels and "x" denotes row/col-wise multiplication. + + The computations are: + 1. weight the rows of (target - pred) by wp = weights_points + 2. take the mean over the row + 3. weight the collapsed cols by wc = weights_channels + 4. take the mean over the channel-weighted cols + + Params: + target : shape ( num_data_points , num_channels ) + target : shape ( ens_dim , num_data_points , num_channels) + weights_channels : shape = (num_channels,) + weights_points : shape = (num_data_points) + + Return: + loss : weight loss for gradient computation + loss_chs : losses per channel with location weighting but no channel weighting + """ + + mask_nan = ~torch.isnan(target) + pred = pred[0] if pred.shape[0] == 0 else pred.mean(0) + + diff2 = torch.square(torch.where(mask_nan, target, 0) - torch.where(mask_nan, pred, 0)) + if weights_points is not None: + diff2 = (diff2.transpose(1, 0) * weights_points).transpose(1, 0) + loss_chs = diff2.mean(0) + loss = torch.mean(loss_chs * weights_channels if weights_channels is not None else loss_chs) + + return loss, loss_chs + + +def mae( + target: torch.Tensor, + pred: torch.Tensor, + weights_channels: torch.Tensor | None, + weights_points: torch.Tensor | None, +): + """ + Compute weighted MAE loss for one window or step + + The function implements: + + loss = Mean_{channels}( weight_channels * Mean_{data_pts}( (target - pred) * weights_points )) + + Geometrically, + + ------------------------ - + | | | | + | | | | + | | | | + | target - pred | x |wp| + | | | | + | | | | + | | | | + ------------------------ - + x + ------------------------ + | wc | + ------------------------ + + where wp = weights_points and wc = weights_channels and "x" denotes row/col-wise multiplication. + + The computations are: + 1. weight the rows of (target - pred) by wp = weights_points + 2. take the mean over the row + 3. weight the collapsed cols by wc = weights_channels + 4. take the mean over the channel-weighted cols + + Params: + target : shape ( num_data_points , num_channels ) + target : shape ( ens_dim , num_data_points , num_channels) + weights_channels : shape = (num_channels,) + weights_points : shape = (num_data_points) + + Return: + loss : weight loss for gradient computation + loss_chs : losses per channel with location weighting but no channel weighting + """ + + mask_nan = ~torch.isnan(target) + pred = pred[0] if pred.shape[0] == 0 else pred.mean(0) + + diff2 = torch.where(mask_nan, target, 0) - torch.where(mask_nan, pred, 0) + if weights_points is not None: + diff2 = (diff2.transpose(1, 0) * weights_points).transpose(1, 0) + loss_chs = diff2.mean(0) + loss = torch.mean(loss_chs * weights_channels if weights_channels is not None else loss_chs) + + return loss, loss_chs + + +def cosine_latitude(stream_data, forecast_offset, fstep, min_value=1e-3, max_value=1.0): + latitudes_radian = stream_data.target_coords_raw[forecast_offset + fstep][:, 0] * np.pi / 180 + return (max_value - min_value) * np.cos(latitudes_radian) + min_value + + +def gamma_decay(forecast_steps, gamma): + fsteps = np.arange(forecast_steps) + weights = gamma**fsteps + return weights * (len(fsteps) / np.sum(weights)) From 7ac9e6b1b4a2f76476cd1d3c2695d31c99283044 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Tue, 18 Nov 2025 16:25:52 +0100 Subject: [PATCH 057/344] rm loss_fcts in default config --- config/default_config.yml | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 53c9ac560..ea2777a86 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -69,15 +69,6 @@ latent_noise_saturate_encodings: 5 latent_noise_use_additive_noise: False latent_noise_deterministic_latents: True -loss_fcts: - - - - "mse" - - 1.0 -loss_fcts_val: - - - - "mse" - - 1.0 - batch_size_per_gpu: 1 batch_size_validation_per_gpu: 1 @@ -124,8 +115,8 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"], } num_mini_epochs: 32 -samples_per_mini_epoch: 4096 -samples_per_validation: 512 +samples_per_mini_epoch: 32 +samples_per_validation: 8 shuffle: True From 85fa139192fa930b2be7425fe879a349caac7ee9 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Tue, 18 Nov 2025 16:28:46 +0100 Subject: [PATCH 058/344] Comments --- src/weathergen/datasets/multi_stream_data_sampler.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index d3529781d..2de3780d7 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -415,6 +415,8 @@ def _get_sample_data(self, mode: str, idx: int, forecast_dt: int): """ mode : {student, teacher, mtm} + idx : + forecast_dt : """ streams_data: list[StreamData] = [] @@ -441,6 +443,8 @@ def _get_sample_data(self, mode: str, idx: int, forecast_dt: int): return streams_data def _preprocess_model_data(self, batch, forecast_dt): + """ """ + # aggregated lens of tokens per cell across input batch samples source_cell_lens = compute_source_cell_lens(batch, self.num_input_steps) @@ -483,8 +487,6 @@ def __iter__(self): idx: TIndex = self.perms[idx_raw % self.perms.shape[0]] idx_raw += 1 - time_win_source = self.time_window_handler.window(idx) - # Sample masking strategy once per batch item if hasattr(self.tokenizer, "masker"): self.tokenizer.masker.set_batch_strategy() @@ -502,6 +504,7 @@ def __iter__(self): if not (all(s.empty() or s.target_empty() for s in streams_data)): batch += [streams_data] + # compute batch, source_cell_lens, target_coords_idx = self._preprocess_model_data( batch, forecast_dt ) From c1580c4b7a314638be352fd0dbf3a6c6aee4dc44 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Tue, 18 Nov 2025 16:30:44 +0100 Subject: [PATCH 059/344] Renaming --- src/weathergen/datasets/{inputs_metadata.py => batch.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/weathergen/datasets/{inputs_metadata.py => batch.py} (100%) diff --git a/src/weathergen/datasets/inputs_metadata.py b/src/weathergen/datasets/batch.py similarity index 100% rename from src/weathergen/datasets/inputs_metadata.py rename to src/weathergen/datasets/batch.py From 3c26ddc1c1d8e0525d815a76cf56f5e97fb56997 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Tue, 18 Nov 2025 17:32:00 +0000 Subject: [PATCH 060/344] updated default config training_config to allow student-teacher --- config/default_config.yml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 1f0a25810..c0007da6e 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -109,23 +109,26 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"], } # Student-teacher configuration (only used when training_mode == "student_teacher") +# TODO: adapt so that the masking or forecast config entry also sits here training_config: # when this is "masking", we are basically only using the model_input subconfig training_mode: "student_teacher" # "masking", "student_teacher", "forecast" + model_input: masking_strategy: "healpix" # "random", "healpix". Masking strategy to use for model input for masking, and local (student) views when doing student-teacher rate: 0.5 # Masking rate to use for model input num_views: 4 # if student-teacher, the number of local (student) views to generate - masking_strategy_config: {"strategies": ["random", "healpix", "channel"], + masking_strategy_config: {"strategies": ["random", "healpix", "channel"], # will be used with masking is moved under here "probabilities": [0.34, 0.33, 0.33], "hl_mask": 0, "mode": "per_cell", "same_strategy_per_batch": false } + relationship: "subset" # "independent", "subset", "disjoint". Relationship of student views to teacher view. teacher_model_input: strategy: "healpix" # Strategy for teacher (global) view: "random", "healpix" - rate: 0.1 # Fraction of data to keep in global view (alternative: use "keep_m" for absolute count) + rate: 0.5 # Fraction of data to keep in global view (alternative: use "keep_m" for absolute count) # keep_m: 100 # Alternative to rate: keep exactly this many parent cells rate_sampling: true # randomly sample the rate per batch masking_strategy_config: {"strategies": ["random", "healpix", "channel"], From 66cf9cdc8a918e2fc56aff86ba800e1aaa1d1597 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Tue, 18 Nov 2025 17:33:08 +0000 Subject: [PATCH 061/344] added stream id to era5 config --- config/streams/era5_1deg/era5.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/config/streams/era5_1deg/era5.yml b/config/streams/era5_1deg/era5.yml index bb2234c4e..eace84bfe 100644 --- a/config/streams/era5_1deg/era5.yml +++ b/config/streams/era5_1deg/era5.yml @@ -10,6 +10,7 @@ ERA5 : type : anemoi filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + stream_id : 0 source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] loss_weight : 1. From 36ea28765dd1c0a9a3d88d3d120d5f9500f317c7 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Tue, 18 Nov 2025 17:33:53 +0000 Subject: [PATCH 062/344] slight restructure of ViewMetadata --- src/weathergen/datasets/inputs_metadata.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/weathergen/datasets/inputs_metadata.py b/src/weathergen/datasets/inputs_metadata.py index 2143163a1..c42a2f5f0 100644 --- a/src/weathergen/datasets/inputs_metadata.py +++ b/src/weathergen/datasets/inputs_metadata.py @@ -34,15 +34,20 @@ class ViewMetadata: rate: Fraction of data kept (e.g., 0.5 = 50% kept); None if fixed count parent_view_id: ID of the parent view this is a subset of (None for teacher) """ - - loss_type: str # DINO, JEPA... ? - strategy: str # "cropping", "masking", "forecasting", "forecasting_diffusion" - strategy_config: dict # rate: 0.5 etc., healpix_level: int etc., overlap: "disjoint" etc., + # Core identifiers and selection description view_id: str - parent_view_id: Optional[str] = None # For students: which teacher they belong to keep_mask: np.ndarray # [num_cells] bool at data level + strategy: str # e.g. "random", "healpix", "channel" + + # Hierarchical/quantitative description of selection + healpix_level: Optional[int] = None + rate: Optional[float] = None + parent_view_id: Optional[str] = None # For students: which teacher they belong to + + # Optional extras for future/other training paradigms + loss_type: Optional[str] = None # e.g. DINO, JEPA + strategy_config: Optional[dict] = None # e.g. {rate: 0.5, hl_mask: 3, overlap: "disjoint"} - # TODO: This doesn't handle the masking case, and we probably want it to, # where the model_inputs are the correct data for the masked source (and target?). Or target becomes the target? From 11ad4e659ca0c33b0b3e533171562ea50986715a Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Tue, 18 Nov 2025 17:34:19 +0000 Subject: [PATCH 063/344] basic if statement to yield the student and teacher views --- .../datasets/multi_stream_data_sampler.py | 74 ++++++++++++++++--- 1 file changed, 63 insertions(+), 11 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index d3529781d..6492452c0 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -27,6 +27,7 @@ from weathergen.datasets.masking import Masker from weathergen.datasets.stream_data import StreamData, spoof from weathergen.datasets.tokenizer_masking import TokenizerMasking +from weathergen.datasets.view_builder import build_views_for_stream from weathergen.datasets.utils import ( compute_idxs_predict, compute_offsets_scatter_embed, @@ -223,6 +224,8 @@ def __init__( self.healpix_level: int = cf.healpix_level self.num_healpix_cells: int = 12 * 4**self.healpix_level + self.training_cfg = cf.get("training_config", None) + masker = Masker(cf) self.tokenizer = TokenizerMasking(cf.healpix_level, masker) @@ -321,7 +324,8 @@ def _build_stream_data_source( # view_meta: ViewMetadata, stream_info: dict, stream_ds: list, - ) -> StreamData: + keep_mask: torch.Tensor | None = None, + ) -> tuple[StreamData, dict | None]: """ Return one batch of data Build a StreamData object for a single view (teacher or student). @@ -359,16 +363,19 @@ def _build_stream_data_source( stream_data.source_is_spoof = True # preprocess data for model input + # keep_mask used in student-teacher training to + # link the source mask between student and teacher (ss_cells, ss_lens, mask_state) = self.tokenizer.batchify_source( stream_info, readerdata_to_torch(rdata), (time_win_source.start, time_win_source.end), + keep_mask=keep_mask, ) # collect data for stream stream_data.add_source(step, rdata, ss_lens, ss_cells) - return stream_data + return stream_data, mask_state def _build_stream_data_target( self, @@ -379,6 +386,7 @@ def _build_stream_data_target( # view_meta: ViewMetadata, stream_info: dict, stream_ds: list, + mask_state: dict | None = None, ) -> StreamData: # collect for all forecast steps for fstep in range(self.forecast_offset, self.forecast_offset + forecast_dt + 1): @@ -400,21 +408,24 @@ def _build_stream_data_target( stream_data.target_is_spoof = True # preprocess data for model input + # carry around mask_state from source tokenizer call (tt_cells, tt_t, tt_c, tc, tc_l, idxs_inv) = self.tokenizer.batchify_target( stream_info, self.sampling_rate_target, readerdata_to_torch(rdata), (time_win_target.start, time_win_target.end), + mask_state=mask_state, ) stream_data.add_target(fstep, tt_cells, tc, tc_l, tt_c, tt_t, idxs_inv) return stream_data - def _get_sample_data(self, mode: str, idx: int, forecast_dt: int): + def _get_sample_data(self, mode: str, idx: int, forecast_dt: int, keep_mask: torch.Tensor | None = None): """ - + mode : {student, teacher, mtm} + TODO: these modes are not being used now. """ streams_data: list[StreamData] = [] @@ -426,13 +437,13 @@ def _get_sample_data(self, mode: str, idx: int, forecast_dt: int): ) # collect source data for current stream - stream_data = self._build_stream_data_source( - mode, stream_data, idx, forecast_dt, stream_info, stream_ds + stream_data, mask_state = self._build_stream_data_source( + mode, stream_data, idx, forecast_dt, stream_info, stream_ds, keep_mask=keep_mask ) - # collect target data for current stream + # collect target data for current stream (aligned with source mask_state) stream_data = self._build_stream_data_target( - mode, stream_data, idx, forecast_dt, stream_info, stream_ds + mode, stream_data, idx, forecast_dt, stream_info, stream_ds, mask_state=mask_state ) # add data for current stream @@ -489,10 +500,49 @@ def __iter__(self): if hasattr(self.tokenizer, "masker"): self.tokenizer.masker.set_batch_strategy() - mode = "teacher" + # TODO: ideally update this student-teacher if-else to a more general + # view-based data sampling + if self.training_cfg.get("training_mode") == "student_teacher": + # Build teacher + student masks once per batch item + teacher_cfg = self.training_cfg.get("teacher_model_input", {}) + student_cfg = self.training_cfg.get("model_input", {}) + relationship = student_cfg.get("relationship") + + # use build_views_for_stream utility to create student and teacher masks + t_keep_np, s_keeps_np, _meta = build_views_for_stream( + self.tokenizer.masker, + self.num_healpix_cells, + teacher_cfg=teacher_cfg, + student_cfg=student_cfg, + relationship=relationship, + ) - # tokenizer.generate_masks_for_sample() - streams_data = self._get_sample_data(mode, idx, forecast_dt) + # Convert to torch.bool + def to_bool_tensor(arr): + if arr is None: + return None + return torch.from_numpy(np.asarray(arr, dtype=bool)).to(torch.bool) + + t_keep_t = to_bool_tensor(t_keep_np) + s_keep_t_list = [to_bool_tensor(m) for m in (s_keeps_np or [])] + + # Teacher view + streams_data = self._get_sample_data("teacher", idx, forecast_dt, keep_mask=t_keep_t) + + # Students (build but do not change yielded batch shape yet) + # For each student view (set in the config) build separate StreamData + student_streams_data = [] + for s_keep_t in s_keep_t_list: + # do not do anything with this, just it is here. + student_stream_data = self._get_sample_data("student", idx, forecast_dt, keep_mask=s_keep_t) + student_streams_data.append(student_stream_data) + + # TODO: to pass around the correct source_cell_lens and target coords + # Somehow coming from _preprocess_model_data + + else: + # Standard masking/forecast path + streams_data = self._get_sample_data("masking", idx, forecast_dt) # Reset masking strategy for next batch item if hasattr(self.tokenizer, "masker"): @@ -502,6 +552,8 @@ def __iter__(self): if not (all(s.empty() or s.target_empty() for s in streams_data)): batch += [streams_data] + # TODO: link into ModelBatch + batch, source_cell_lens, target_coords_idx = self._preprocess_model_data( batch, forecast_dt ) From 2536cecb314c58a3245bfd2367b5426517841c78 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Tue, 18 Nov 2025 17:40:26 +0000 Subject: [PATCH 064/344] correct imports with new batch.py --- src/weathergen/datasets/tokenizer_masking.py | 2 +- src/weathergen/datasets/view_builder.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 9c103126b..e364c5582 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -14,7 +14,7 @@ from weathergen.datasets.masking import Masker from weathergen.datasets.tokenizer import Tokenizer from weathergen.datasets.view_builder import build_views_for_stream -from weathergen.datasets.inputs_metadata import ViewMetadata +from weathergen.datasets.batch import ViewMetadata from weathergen.datasets.tokenizer_utils import ( encode_times_source, encode_times_target, diff --git a/src/weathergen/datasets/view_builder.py b/src/weathergen/datasets/view_builder.py index cd7228b49..21306826c 100644 --- a/src/weathergen/datasets/view_builder.py +++ b/src/weathergen/datasets/view_builder.py @@ -1,7 +1,7 @@ import numpy as np from typing import Tuple, List from weathergen.datasets.masking import Masker -from weathergen.datasets.inputs_metadata import ViewMetadata +from weathergen.datasets.batch import ViewMetadata def build_views_for_stream( From 31dc658a445eff27d833d856b971b68dc66d1e2a Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Wed, 19 Nov 2025 11:04:29 +0000 Subject: [PATCH 065/344] created function for _get_student_teacher_sample_data which returns the streams_data of the teacher and multiple streams_datas for the student views. --- .../datasets/multi_stream_data_sampler.py | 79 +++++++++++-------- 1 file changed, 47 insertions(+), 32 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 53a46dc8f..aab0c336d 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -453,6 +453,52 @@ def _get_sample_data(self, mode: str, idx: int, forecast_dt: int, keep_mask: tor return streams_data + def _get_student_teacher_sample_data(self, idx: int, forecast_dt: int): + """ + Return one batch of data + Build a StreamData object for a single view (teacher or student). + + Args: + idx: Time index for this sample + forecast_dt: Number of forecast steps + """ + + teacher_cfg = self.training_cfg.get("teacher_model_input", {}) + student_cfg = self.training_cfg.get("model_input", {}) + relationship = student_cfg.get("relationship") + + # use build_views_for_stream utility to create student and teacher masks + t_keep_np, s_keeps_np, _meta = build_views_for_stream( + self.tokenizer.masker, + self.num_healpix_cells, + teacher_cfg=teacher_cfg, + student_cfg=student_cfg, + relationship=relationship, + ) + + # Convert to torch.bool + def to_bool_tensor(arr): + if arr is None: + return None + return torch.from_numpy(np.asarray(arr, dtype=bool)).to(torch.bool) + + t_keep_t = to_bool_tensor(t_keep_np) + s_keep_t_list = [to_bool_tensor(m) for m in (s_keeps_np or [])] + + # Teacher view + streams_data = self._get_sample_data("teacher", idx, forecast_dt, keep_mask=t_keep_t) + + # Students (build but do not change yielded batch shape yet) + # For each student view (set in the config) build separate StreamData + student_streams_data = [] + for s_keep_t in s_keep_t_list: + # do not do anything with this, just it is here. + student_stream_data = self._get_sample_data("student", idx, forecast_dt, keep_mask=s_keep_t) + student_streams_data.append(student_stream_data) + + return streams_data, student_streams_data + + def _preprocess_model_data(self, batch, forecast_dt): """ """ @@ -505,39 +551,8 @@ def __iter__(self): # TODO: ideally update this student-teacher if-else to a more general # view-based data sampling if self.training_cfg.get("training_mode") == "student_teacher": - # Build teacher + student masks once per batch item - teacher_cfg = self.training_cfg.get("teacher_model_input", {}) - student_cfg = self.training_cfg.get("model_input", {}) - relationship = student_cfg.get("relationship") - - # use build_views_for_stream utility to create student and teacher masks - t_keep_np, s_keeps_np, _meta = build_views_for_stream( - self.tokenizer.masker, - self.num_healpix_cells, - teacher_cfg=teacher_cfg, - student_cfg=student_cfg, - relationship=relationship, - ) - # Convert to torch.bool - def to_bool_tensor(arr): - if arr is None: - return None - return torch.from_numpy(np.asarray(arr, dtype=bool)).to(torch.bool) - - t_keep_t = to_bool_tensor(t_keep_np) - s_keep_t_list = [to_bool_tensor(m) for m in (s_keeps_np or [])] - - # Teacher view - streams_data = self._get_sample_data("teacher", idx, forecast_dt, keep_mask=t_keep_t) - - # Students (build but do not change yielded batch shape yet) - # For each student view (set in the config) build separate StreamData - student_streams_data = [] - for s_keep_t in s_keep_t_list: - # do not do anything with this, just it is here. - student_stream_data = self._get_sample_data("student", idx, forecast_dt, keep_mask=s_keep_t) - student_streams_data.append(student_stream_data) + streams_data, student_streams_data = self._get_student_teacher_sample_data(idx, forecast_dt) # TODO: to pass around the correct source_cell_lens and target coords # Somehow coming from _preprocess_model_data From a824bfccf0086112ef4f646da90ce3ac4580a284 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 19 Nov 2025 12:23:47 +0100 Subject: [PATCH 066/344] Not working draft for restructuring --- .../datasets/multi_stream_data_sampler.py | 175 +++++++++++++++--- 1 file changed, 148 insertions(+), 27 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 2de3780d7..c007bc8dd 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -35,6 +35,8 @@ from weathergen.utils.distributed import is_root from weathergen.utils.train_logger import Stage +from weathergen.datasets.batch import ModelBatch + type AnyDataReader = DataReaderBase | DataReaderAnemoi | DataReaderObs logger = logging.getLogger(__name__) @@ -312,15 +314,19 @@ def denormalize_target_channels(self, stream_id, data) -> torch.Tensor: # TODO: with multiple ds per stream we need to distinguish these here return self.streams_datasets[stream_id][0].denormalize_target_channels(data) - def _build_stream_data_source( + + + + + + def _build_stream_data_input( self, mode: str, stream_data: StreamData, base_idx: TIndex, - forecast_dt: int, # view_meta: ViewMetadata, stream_info: dict, - stream_ds: list, + source_data: list ) -> StreamData: """ Return one batch of data @@ -345,7 +351,7 @@ def _build_stream_data_source( time_win_source = self.time_window_handler.window(idx) # collect all targets for current stream - rdata: IOReaderData = collect_datasources(stream_ds, idx, "source") + rdata = source_data[step] if rdata.is_empty(): # work around for https://github.com/pytorch/pytorch/issues/158719 @@ -359,14 +365,15 @@ def _build_stream_data_source( stream_data.source_is_spoof = True # preprocess data for model input - (ss_cells, ss_lens, mask_state) = self.tokenizer.batchify_source( + (source_cells, source_cells_lens, mask_state) = self.tokenizer.get_source( + mode, stream_info, readerdata_to_torch(rdata), (time_win_source.start, time_win_source.end), ) # collect data for stream - stream_data.add_source(step, rdata, ss_lens, ss_cells) + stream_data.add_source(step, rdata, source_cells, source_cells_lens) return stream_data @@ -375,18 +382,20 @@ def _build_stream_data_target( mode: str, stream_data: StreamData, idx: TIndex, - forecast_dt: int, # view_meta: ViewMetadata, stream_info: dict, - stream_ds: list, + forecast_dt: int, + target_data: list, ) -> StreamData: + # collect for all forecast steps - for fstep in range(self.forecast_offset, self.forecast_offset + forecast_dt + 1): + dt = self.forecast_offset + forecast_dt + for step, fstep in enumerate(range(self.forecast_offset, dt + 1)): step_forecast_dt = idx + (self.forecast_delta_hrs * fstep) // self.step_hrs time_win_target = self.time_window_handler.window(step_forecast_dt) # collect all targets for current stream - rdata: IOReaderData = collect_datasources(stream_ds, step_forecast_dt, "target") + rdata = target_data[step] if rdata.is_empty(): # work around for https://github.com/pytorch/pytorch/issues/158719 @@ -400,41 +409,152 @@ def _build_stream_data_target( stream_data.target_is_spoof = True # preprocess data for model input - (tt_cells, tt_t, tt_c, tc, tc_l, idxs_inv) = self.tokenizer.batchify_target( - stream_info, - self.sampling_rate_target, - readerdata_to_torch(rdata), - (time_win_target.start, time_win_target.end), + if "input" in mode : + (tc, tc_l) = self.tokenizer.get_target_coords( + stream_info, + self.sampling_rate_target, + readerdata_to_torch(rdata), + (time_win_target.start, time_win_target.end), + ) + stream_data.add_target_coords(fstep, tc, tc_l) + else : + (tt_cells, tt_t, tt_c, idxs_inv) = self.tokenizer.get_target_values( + stream_info, + self.sampling_rate_target, + readerdata_to_torch(rdata), + (time_win_target.start, time_win_target.end), + ) + stream_data.add_target_values(fstep, tt_cells tt_c, tt_t, idxs_inv) + + return stream_data + + + def _build_stream_data( + self, + mode: str, + stream_data: StreamData, + base_idx: TIndex, + forecast_dt: int, + # view_meta: ViewMetadata, + stream_info: dict, + source_data: list + target_data: list + ) -> StreamData: + """ + Return one batch of data + Build a StreamData object for a single view (teacher or student). + + Args: + mode : {student, teacher, physical} + stream_data : + base_idx: Time index for this sample + forecast_dt: Number of forecast steps + view_meta: ViewMetadata describing spatial mask + stream_info: Stream configuration dict + stream_ds: List of dataset readers for this stream + + Returns: + StreamData with source and targets masked according to view_meta + """ + + stream_data = self._build_stream_data_input( + mode, stream_data, base_idx, stream_info, input_data + ) + + # physical space + if "physical" in mode: + stream_data = self._build_stream_data_target( + mode, stream_data, base_idx, stream_info, stream_ds, forecast_dt ) - stream_data.add_target(fstep, tt_cells, tc, tc_l, tt_c, tt_t, idxs_inv) + def _get_data( self, base_idx, forecast_dt, stream_ds) : + """ - return stream_data + """ + + # source data: iterate overall input steps + input_data = [] + for step, idx in enumerate(range( base_idx - self.num_input_steps, base_idx + 1)): + # TODO: check that we are not out of bounds when we go back in time + source_data += [ collect_datasources(stream_ds, idx, "source") ] + + # target data: collect for all forecast steps + output_data = [] + for fstep in range(self.forecast_offset, self.forecast_offset + forecast_dt + 1): + step_forecast_dt = base_idx + (self.forecast_delta_hrs * fstep) // self.step_hrs + # collect all targets for current stream + target_data += [ collect_datasources(stream_ds, step_forecast_dt, "target") ] + + return (input_data, target_data) + + + def _tokenize_data( self, stream_info, input_data, output_data) : + """ + Tokenize data (to amortize over the different views that are generated) + + """ + + tok_spacetime = stream_info.get("tokenize_spacetime", False) + tok = tokenize_spacetime if tok_spacetime else tokenize_space + hl = self.healpix_level + token_size = stream_info["token_size"] + + input_tokens = [] + for rdata in input_data: + idxs_cells, idxs_cells_lens = tok(rdata, token_size, hl, pad_tokens=True) + input_tokens += [(idxs_cells, idxs_cells_lens)] + + output_tokens = [] + for rdata in output_data: + idxs_cells, idxs_cells_lens = tok(rdata, token_size, hl, pad_tokens=False) + output_tokens += [(idxs_cells, idxs_cells_lens)] - def _get_sample_data(self, mode: str, idx: int, forecast_dt: int): + # TODO: target_coords are expensive + + return (input_tokens, output_tokens) + + def _get_sample(self, mode: str, idx: int, forecast_dt: int): """ - mode : {student, teacher, mtm} + modes : + ('student', 'teacher') + ('physical_input', 'physical_target') idx : forecast_dt : """ + dt = forecast_dt + self.forecast_offset streams_data: list[StreamData] = [] # for all streams for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): - stream_data = StreamData( - idx, forecast_dt + self.forecast_offset, self.num_healpix_cells - ) + + # input_data and output_data is conceptually consecutive but differs + # in source and target channels; overlap in one window when self.forecast_offset=0 + (input_data, output_data) = self._get_data( idx, forecast_dt, stream_ds) + + # tokenize windows + (input_tokens, output_tokens) = self._get_tokens( stream_info, input_data, output_data) + + # source_input + # target_input + # source_output + # target_output + + # get + masks = build_views_for_stream( modes, input_tokens, output_tokens) # collect source data for current stream - stream_data = self._build_stream_data_source( - mode, stream_data, idx, forecast_dt, stream_info, stream_ds + stream_data_source = StreamData( idx, dt, self.num_healpix_cells) + stream_data_source = self._build_stream_data( + "teacher", stream_data, idx, forecast_dt, stream_info, input_tokens, output_tokens ) # collect target data for current stream - stream_data = self._build_stream_data_target( - mode, stream_data, idx, forecast_dt, stream_info, stream_ds + # stream_data_target can contain network input + stream_data_target = StreamData( idx, dt, self.num_healpix_cells) + stream_data_target = self._build_stream_data( + "student", stream_data, idx, forecast_dt, stream_info, input_tokens, output_tokens ) # add data for current stream @@ -442,6 +562,7 @@ def _get_sample_data(self, mode: str, idx: int, forecast_dt: int): return streams_data + def _preprocess_model_data(self, batch, forecast_dt): """ """ @@ -494,7 +615,7 @@ def __iter__(self): mode = "teacher" # tokenizer.generate_masks_for_sample() - streams_data = self._get_sample_data(mode, idx, forecast_dt) + streams_data = self._get_sample( mode, idx, forecast_dt) # Reset masking strategy for next batch item if hasattr(self.tokenizer, "masker"): From 81cf929464a61a4e2806514a328f7f66ffae7dbd Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 19 Nov 2025 15:58:57 +0100 Subject: [PATCH 067/344] Changes for better student teacher structure --- src/weathergen/datasets/tokenizer_masking.py | 31 +++++++++++++------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index e364c5582..525cb1e61 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -11,18 +11,16 @@ import torch from weathergen.common.io import IOReaderData +from weathergen.datasets.batch import ViewMetadata from weathergen.datasets.masking import Masker from weathergen.datasets.tokenizer import Tokenizer -from weathergen.datasets.view_builder import build_views_for_stream -from weathergen.datasets.batch import ViewMetadata from weathergen.datasets.tokenizer_utils import ( encode_times_source, encode_times_target, tokenize_apply_mask_source, tokenize_apply_mask_target, - tokenize_space, - tokenize_spacetime, ) +from weathergen.datasets.view_builder import build_views_for_stream class TokenizerMasking(Tokenizer): @@ -43,6 +41,7 @@ def batchify_source( self, stream_info: dict, rdata: IOReaderData, + idxs_cells_data, time_win: tuple, keep_mask: torch.Tensor | None = None, ): @@ -55,12 +54,17 @@ def batchify_source( if is_diagnostic or rdata.data.shape[1] == 0 or len(rdata.data) < 2: source_tokens_cells = [torch.tensor([])] source_tokens_lens = torch.zeros([self.num_healpix_cells_source], dtype=torch.int32) - mask_state = {"strategy": self.masker.current_strategy, "mask_tokens": None, "mask_channels": None} + mask_state = { + "strategy": self.masker.current_strategy, + "mask_tokens": None, + "mask_channels": None, + } return (source_tokens_cells, source_tokens_lens, mask_state) - # create tokenization index - tok = tokenize_spacetime if stream_info.get("tokenize_spacetime", False) else tokenize_space - idxs_cells, idxs_cells_lens = tok(rdata, token_size, self.hl_source, pad_tokens=True) + # # create tokenization index + # tok = tokenize_spacetime if stream_info.get("tokenize_spacetime", False) else tokenize_space + # idxs_cells, idxs_cells_lens = tok(rdata, token_size, self.hl_source, pad_tokens=True) + (idxs_cells, idxs_cells_lens) = idxs_cells_data # select strategy from XXX depending on stream and if student or teacher @@ -93,6 +97,8 @@ def batchify_source( "mask_tokens": mask_tokens, "mask_channels": mask_channels, } + self.mask_state = mask_state + return (source_tokens_cells, source_tokens_lens, mask_state) # batchify_target_for_view now unified into batchify_target via optional mask_state @@ -102,14 +108,18 @@ def batchify_target( stream_info: dict, sampling_rate_target: float, rdata: IOReaderData, + token_data, time_win: tuple, mask_state: dict | None = None, ): token_size = stream_info["token_size"] # create tokenization index - tok = tokenize_spacetime if stream_info.get("tokenize_spacetime", False) else tokenize_space - idxs_cells, idxs_cells_lens = tok(rdata, token_size, self.hl_source, pad_tokens=False) + # tok = tokenize_spacetime if stream_info.get("tokenize_spacetime", False) else tokenize_space + # idxs_cells, idxs_cells_lens = tok(rdata, token_size, self.hl_source, pad_tokens=False) + (idxs_cells, idxs_cells_lens) = token_data + + mask_state = self.mask_state # Apply per-view mask state if provided if mask_state is not None: @@ -140,7 +150,6 @@ def batchify_target( return (data, datetimes, coords, coords_local, coords_per_cell, idxs_ord_inv) - # ------------------------------------------------------------------ # Per-stream view construction (teacher + students) for student-teacher # ------------------------------------------------------------------ From 46147d4125a638b097c75ec635290dfbe0f69ec1 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 19 Nov 2025 17:01:29 +0100 Subject: [PATCH 068/344] More refactoring --- .../datasets/multi_stream_data_sampler.py | 101 ++++-------------- src/weathergen/datasets/tokenizer.py | 1 + src/weathergen/datasets/tokenizer_masking.py | 52 +++++++-- 3 files changed, 69 insertions(+), 85 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 0fb7f607d..fcd4f3692 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -27,10 +27,6 @@ from weathergen.datasets.masking import Masker from weathergen.datasets.stream_data import StreamData, spoof from weathergen.datasets.tokenizer_masking import TokenizerMasking -from weathergen.datasets.tokenizer_utils import ( - tokenize_space, - tokenize_spacetime, -) from weathergen.datasets.utils import ( compute_idxs_predict, compute_offsets_scatter_embed, @@ -327,7 +323,6 @@ def _build_stream_data_input( mode: str, stream_data: StreamData, base_idx: TIndex, - # view_meta: ViewMetadata, stream_info: dict, input_data: list, input_tokens: list, @@ -365,7 +360,7 @@ def _build_stream_data_input( # preprocess data for model input (source_cells, source_cells_lens, mask_state) = self.tokenizer.batchify_source( stream_info, - readerdata_to_torch(rdata), + rdata, token_data, (time_win_source.start, time_win_source.end), keep_mask=mask, @@ -376,17 +371,16 @@ def _build_stream_data_input( return stream_data, mask_state - def _build_stream_data_target( + def _build_stream_data_output( self, mode: str, stream_data: StreamData, idx: TIndex, - # view_meta: ViewMetadata, stream_info: dict, forecast_dt: int, output_data: list, output_tokens: list, - mask: dict | None = None, + mask_state: dict | None = None, ) -> StreamData: # collect for all forecast steps dt = self.forecast_offset + forecast_dt @@ -398,48 +392,20 @@ def _build_stream_data_target( rdata = output_data[step] token_data = output_tokens[step] - # if rdata.is_empty(): - # # work around for https://github.com/pytorch/pytorch/issues/158719 - # # create non-empty mean data instead of empty tensor - # rdata = spoof( - # self.healpix_level, - # time_win_target.start, - # stream_ds[0].get_geoinfo_size(), - # stream_ds[0].mean[stream_ds[0].target_idx], - # ) - # TODO: # stream_data.target_is_spoof = True (tt_cells, tt_t, tt_c, tc, tc_l, idxs_inv) = self.tokenizer.batchify_target( stream_info, self.sampling_rate_target, - readerdata_to_torch(rdata), + rdata, token_data, (time_win_target.start, time_win_target.end), - mask, + mask_state, ) stream_data.add_target(fstep, tt_cells, tc, tc_l, tt_c, tt_t, idxs_inv) - # # preprocess data for model input - # if "input" in mode : - # (tc, tc_l) = self.tokenizer.get_target_coords( - # stream_info, - # self.sampling_rate_target, - # readerdata_to_torch(rdata), - # token_data, - # (time_win_target.start, time_win_target.end), - # ) - # stream_data.add_target_coords(fstep, tc, tc_l) - # else : - # (tt_cells, tt_t, tt_c, idxs_inv) = self.tokenizer.get_target_values( - # stream_info, - # self.sampling_rate_target, - # readerdata_to_torch(rdata), - # token_data, - # (time_win_target.start, time_win_target.end), - # ) - # stream_data.add_target_values(fstep, tt_cells, tt_c, tt_t, idxs_inv) + # TODO: separate target_coords and target_value computation? return stream_data @@ -476,18 +442,19 @@ def _build_stream_data( dt = self.forecast_offset + forecast_dt stream_data = StreamData(base_idx, dt, self.num_healpix_cells) - stream_data, _ = self._build_stream_data_input( + stream_data, mask_state = self._build_stream_data_input( mode, stream_data, base_idx, stream_info, input_data, input_tokens, + mask, ) # physical space if "physical" in mode: - stream_data = self._build_stream_data_target( + stream_data = self._build_stream_data_output( mode, stream_data, base_idx, @@ -495,7 +462,7 @@ def _build_stream_data( forecast_dt, output_data, output_tokens, - mask, + mask_state, ) return stream_data @@ -513,9 +480,10 @@ def _get_data_windows(self, base_idx, forecast_dt, stream_ds): if rdata.is_empty(): # work around for https://github.com/pytorch/pytorch/issues/158719 # create non-empty mean data instead of empty tensor + time_win = self.time_window_handler.window(idx) rdata = spoof( self.healpix_level, - time_win_source.start, + time_win.start, stream_ds[0].get_geoinfo_size(), stream_ds[0].mean[stream_ds[0].source_idx], ) @@ -533,9 +501,10 @@ def _get_data_windows(self, base_idx, forecast_dt, stream_ds): if rdata.is_empty(): # work around for https://github.com/pytorch/pytorch/issues/158719 # create non-empty mean data instead of empty tensor + time_win = self.time_window_handler.window(idx) rdata = spoof( self.healpix_level, - time_win_source.start, + time_win.start, stream_ds[0].get_geoinfo_size(), stream_ds[0].mean[stream_ds[0].source_idx], ) @@ -545,36 +514,6 @@ def _get_data_windows(self, base_idx, forecast_dt, stream_ds): return (input_data, output_data) - def _get_tokens_windows(self, stream_info, input_data, output_data): - """ - Tokenize data (to amortize over the different views that are generated) - - """ - - # TODO: move to tokenizer - - tok_spacetime = stream_info.get("tokenize_spacetime", False) - tok = tokenize_spacetime if tok_spacetime else tokenize_space - hl = self.healpix_level - token_size = stream_info["token_size"] - - input_tokens = [] - for rdata in input_data: - idxs_cells, idxs_cells_lens = tok( - readerdata_to_torch(rdata), token_size, hl, pad_tokens=True - ) - input_tokens += [(idxs_cells, idxs_cells_lens)] - - output_tokens = [] - for rdata in output_data: - idxs_cells, idxs_cells_lens = tok( - readerdata_to_torch(rdata), token_size, hl, pad_tokens=False - ) - output_tokens += [(idxs_cells, idxs_cells_lens)] - - # TODO: target_coords are expensive - - return (input_tokens, output_tokens) def _get_sample(self, mode: str, idx: int, forecast_dt: int): """ @@ -592,12 +531,14 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # for all streams for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): + # input_data and output_data is conceptually consecutive but differs # in source and target channels; overlap in one window when self.forecast_offset=0 (input_data, output_data) = self._get_data_windows(idx, forecast_dt, stream_ds) # tokenize windows - (input_tokens, output_tokens) = self._get_tokens_windows( + # input_tokens = [ (cells_idx, cells_idx_lens), ... ] of time steps + (input_tokens, output_tokens) = self.tokenizer.get_tokens_windows( stream_info, input_data, output_data ) @@ -611,6 +552,7 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): t_keep_t, s_keep_t_list = self._get_student_teacher_masks(idx, forecast_dt) # collect source data for current stream + # TODO: list over teacher views stream_data_source = self._build_stream_data( "physical", idx, @@ -620,11 +562,12 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): output_data, input_tokens, output_tokens, - t_keep_t, + t_keep_t[0], ) # collect target data for current stream # stream_data_target can contain network input + # TODO: list over student views stream_data_target = self._build_stream_data( "student", idx, @@ -637,6 +580,8 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): s_keep_t_list[0], ) + # TODO: build batch + # add data for current stream # streams_data += [( stream_data_source , stream_data_target)] streams_data += [stream_data_source] @@ -672,7 +617,7 @@ def to_bool_tensor(arr): return None return torch.from_numpy(np.asarray(arr, dtype=bool)).to(torch.bool) - t_keep_t = to_bool_tensor(t_keep_np) + t_keep_t = [to_bool_tensor(t_keep_np)] s_keep_t_list = [to_bool_tensor(m) for m in (s_keeps_np or [])] # # Teacher view diff --git a/src/weathergen/datasets/tokenizer.py b/src/weathergen/datasets/tokenizer.py index a059d6b77..722bb5454 100644 --- a/src/weathergen/datasets/tokenizer.py +++ b/src/weathergen/datasets/tokenizer.py @@ -27,6 +27,7 @@ class Tokenizer: def __init__(self, healpix_level: int): ref = torch.tensor([1.0, 0.0, 0.0]) + self.healpix_level = healpix_level self.hl_source = healpix_level self.hl_target = healpix_level diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 525cb1e61..2ded389ba 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -19,9 +19,25 @@ encode_times_target, tokenize_apply_mask_source, tokenize_apply_mask_target, + tokenize_space, + tokenize_spacetime, ) from weathergen.datasets.view_builder import build_views_for_stream +def readerdata_to_torch(rdata: IOReaderData) -> IOReaderData: + """ + Convert data, coords, and geoinfos to torch tensor + """ + if type(rdata.coords) is not torch.Tensor: + rdata.coords = torch.tensor(rdata.coords) + if type(rdata.geoinfos) is not torch.Tensor: + rdata.geoinfos = torch.tensor(rdata.geoinfos) + if type(rdata.data) is not torch.Tensor: + rdata.data = torch.tensor(rdata.data) + + return rdata + + class TokenizerMasking(Tokenizer): def __init__(self, healpix_level: int, masker: Masker): @@ -37,6 +53,35 @@ def reset_rng(self, rng) -> None: self.masker.reset_rng(rng) self.rng = rng + def get_tokens_windows(self, stream_info, input_data, output_data): + """ + Tokenize data (to amortize over the different views that are generated) + + """ + + tok_spacetime = stream_info.get("tokenize_spacetime", False) + tok = tokenize_spacetime if tok_spacetime else tokenize_space + hl = self.healpix_level + token_size = stream_info["token_size"] + + input_tokens = [] + for rdata in input_data: + idxs_cells, idxs_cells_lens = tok( + readerdata_to_torch(rdata), token_size, hl, pad_tokens=True + ) + input_tokens += [(idxs_cells, idxs_cells_lens)] + + output_tokens = [] + for rdata in output_data: + idxs_cells, idxs_cells_lens = tok( + readerdata_to_torch(rdata), token_size, hl, pad_tokens=False + ) + output_tokens += [(idxs_cells, idxs_cells_lens)] + + # TODO: precompute target_coords -> expensive + + return (input_tokens, output_tokens) + def batchify_source( self, stream_info: dict, @@ -62,8 +107,6 @@ def batchify_source( return (source_tokens_cells, source_tokens_lens, mask_state) # # create tokenization index - # tok = tokenize_spacetime if stream_info.get("tokenize_spacetime", False) else tokenize_space - # idxs_cells, idxs_cells_lens = tok(rdata, token_size, self.hl_source, pad_tokens=True) (idxs_cells, idxs_cells_lens) = idxs_cells_data # select strategy from XXX depending on stream and if student or teacher @@ -97,7 +140,6 @@ def batchify_source( "mask_tokens": mask_tokens, "mask_channels": mask_channels, } - self.mask_state = mask_state return (source_tokens_cells, source_tokens_lens, mask_state) @@ -115,12 +157,8 @@ def batchify_target( token_size = stream_info["token_size"] # create tokenization index - # tok = tokenize_spacetime if stream_info.get("tokenize_spacetime", False) else tokenize_space - # idxs_cells, idxs_cells_lens = tok(rdata, token_size, self.hl_source, pad_tokens=False) (idxs_cells, idxs_cells_lens) = token_data - mask_state = self.mask_state - # Apply per-view mask state if provided if mask_state is not None: self.masker.current_strategy = mask_state.get("strategy", self.masker.masking_strategy) From 1e70f5c3d2731cc57c9b535c058275abab08e035 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 19 Nov 2025 17:09:20 +0100 Subject: [PATCH 069/344] More refactoring and cleanup --- .../datasets/multi_stream_data_sampler.py | 101 +++++++++--------- src/weathergen/datasets/tokenizer_masking.py | 21 ++-- 2 files changed, 55 insertions(+), 67 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index fcd4f3692..c5208d8a9 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -41,20 +41,6 @@ logger = logging.getLogger(__name__) -def readerdata_to_torch(rdata: IOReaderData) -> IOReaderData: - """ - Convert data, coords, and geoinfos to torch tensor - """ - if type(rdata.coords) is not torch.Tensor: - rdata.coords = torch.tensor(rdata.coords) - if type(rdata.geoinfos) is not torch.Tensor: - rdata.geoinfos = torch.tensor(rdata.geoinfos) - if type(rdata.data) is not torch.Tensor: - rdata.data = torch.tensor(rdata.data) - - return rdata - - def collect_datasources(stream_datasets: list, idx: int, type: str) -> IOReaderData: """ Utility function to collect all sources / targets from streams list @@ -382,6 +368,10 @@ def _build_stream_data_output( output_tokens: list, mask_state: dict | None = None, ) -> StreamData: + """ + + """ + # collect for all forecast steps dt = self.forecast_offset + forecast_dt for step, fstep in enumerate(range(self.forecast_offset, dt + 1)): @@ -468,7 +458,10 @@ def _build_stream_data( return stream_data def _get_data_windows(self, base_idx, forecast_dt, stream_ds): - """ """ + """ + + + """ # source data: iterate overall input steps input_data = [] @@ -538,53 +531,57 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # tokenize windows # input_tokens = [ (cells_idx, cells_idx_lens), ... ] of time steps - (input_tokens, output_tokens) = self.tokenizer.get_tokens_windows( - stream_info, input_data, output_data - ) + input_tokens = self.tokenizer.get_tokens_windows( stream_info, input_data, True) + output_tokens = self.tokenizer.get_tokens_windows( stream_info, output_data, False) - # source_input - # target_input - # source_output - # target_output - - # get - # masks = build_views_for_stream( modes, input_tokens, output_tokens) + # get/coordinate masks t_keep_t, s_keep_t_list = self._get_student_teacher_masks(idx, forecast_dt) + # stream_data_target can contain network input + # loop over student views + stream_data_source = [] + for mask in s_keep_t_list : + stream_data_source += [ self._build_stream_data( + "physical", + idx, + forecast_dt, + stream_info, + input_data, + output_data, + input_tokens, + output_tokens, + mask, + ) + ] + # collect source data for current stream - # TODO: list over teacher views - stream_data_source = self._build_stream_data( - "physical", - idx, - forecast_dt, - stream_info, - input_data, - output_data, - input_tokens, - output_tokens, - t_keep_t[0], - ) + # loop over teacher views + stream_data_target = [] + for mask in t_keep_t : + stream_data_target += [ self._build_stream_data( + "physical", + idx, + forecast_dt, + stream_info, + input_data, + output_data, + input_tokens, + output_tokens, + t_keep_t[0], + ) + ] - # collect target data for current stream - # stream_data_target can contain network input - # TODO: list over student views - stream_data_target = self._build_stream_data( - "student", - idx, - forecast_dt, - stream_info, - input_data, - output_data, - input_tokens, - output_tokens, - s_keep_t_list[0], - ) # TODO: build batch + # source_input + # target_input + # source_output + # target_output + # add data for current stream # streams_data += [( stream_data_source , stream_data_target)] - streams_data += [stream_data_source] + streams_data += [stream_data_source[0]] return streams_data diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 2ded389ba..dea736da2 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -53,7 +53,7 @@ def reset_rng(self, rng) -> None: self.masker.reset_rng(rng) self.rng = rng - def get_tokens_windows(self, stream_info, input_data, output_data): + def get_tokens_windows(self, stream_info, data, pad_tokens): """ Tokenize data (to amortize over the different views that are generated) @@ -64,23 +64,14 @@ def get_tokens_windows(self, stream_info, input_data, output_data): hl = self.healpix_level token_size = stream_info["token_size"] - input_tokens = [] - for rdata in input_data: + tokens = [] + for rdata in data: idxs_cells, idxs_cells_lens = tok( - readerdata_to_torch(rdata), token_size, hl, pad_tokens=True + readerdata_to_torch(rdata), token_size, hl, pad_tokens ) - input_tokens += [(idxs_cells, idxs_cells_lens)] + tokens += [(idxs_cells, idxs_cells_lens)] - output_tokens = [] - for rdata in output_data: - idxs_cells, idxs_cells_lens = tok( - readerdata_to_torch(rdata), token_size, hl, pad_tokens=False - ) - output_tokens += [(idxs_cells, idxs_cells_lens)] - - # TODO: precompute target_coords -> expensive - - return (input_tokens, output_tokens) + return tokens def batchify_source( self, From 1235aab60562910471d69cc63e6c425210fd3dc0 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 19 Nov 2025 17:47:40 +0100 Subject: [PATCH 070/344] More refactoring. Code working again. --- .../datasets/multi_stream_data_sampler.py | 71 ++++++++------ src/weathergen/datasets/stream_data.py | 74 ++++++++++++++ src/weathergen/datasets/tokenizer_masking.py | 97 ++++++++++++++++++- 3 files changed, 212 insertions(+), 30 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index c5208d8a9..9a1e0ea0b 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -304,7 +304,7 @@ def denormalize_target_channels(self, stream_id, data) -> torch.Tensor: # TODO: with multiple ds per stream we need to distinguish these here return self.streams_datasets[stream_id][0].denormalize_target_channels(data) - def _build_stream_data_input( + def _build_stream_input_data( self, mode: str, stream_data: StreamData, @@ -330,6 +330,8 @@ def _build_stream_data_input( StreamData with source and targets masked according to view_meta """ + # source input data + # iterate overall input steps for step, idx in enumerate(range(base_idx, base_idx - self.num_input_steps, -1)): # TODO: check that we are not out of bounds when we go back in time @@ -341,10 +343,10 @@ def _build_stream_data_input( token_data = input_tokens[step] # TODO: - # stream_data.source_is_spoof = True + # stream_data.source_is_spoof = rdata.is_spoof # preprocess data for model input - (source_cells, source_cells_lens, mask_state) = self.tokenizer.batchify_source( + (source_cells, source_cells_lens, mask_state) = self.tokenizer.get_source( stream_info, rdata, token_data, @@ -357,7 +359,7 @@ def _build_stream_data_input( return stream_data, mask_state - def _build_stream_data_output( + def _build_stream_output_data( self, mode: str, stream_data: StreamData, @@ -382,20 +384,31 @@ def _build_stream_data_output( rdata = output_data[step] token_data = output_tokens[step] - # TODO: - # stream_data.target_is_spoof = True + # stream_data.target_is_spoof = rdata.spoof - (tt_cells, tt_t, tt_c, tc, tc_l, idxs_inv) = self.tokenizer.batchify_target( - stream_info, - self.sampling_rate_target, - rdata, - token_data, - (time_win_target.start, time_win_target.end), - mask_state, - ) - stream_data.add_target(fstep, tt_cells, tc, tc_l, tt_c, tt_t, idxs_inv) + if "target_coords" in mode : + + (tc, tc_l) = self.tokenizer.get_target_coords( + stream_info, + self.sampling_rate_target, + rdata, + token_data, + (time_win_target.start, time_win_target.end), + mask_state, + ) + stream_data.add_target_coords(fstep, tc, tc_l) + + if "target_values" in mode : - # TODO: separate target_coords and target_value computation? + (tt_cells, tt_t, tt_c, idxs_inv) = self.tokenizer.get_target_values( + stream_info, + self.sampling_rate_target, + rdata, + token_data, + (time_win_target.start, time_win_target.end), + mask_state, + ) + stream_data.add_target_values(fstep, tt_cells, tt_c, tt_t, idxs_inv) return stream_data @@ -432,7 +445,7 @@ def _build_stream_data( dt = self.forecast_offset + forecast_dt stream_data = StreamData(base_idx, dt, self.num_healpix_cells) - stream_data, mask_state = self._build_stream_data_input( + stream_data, mask_state = self._build_stream_input_data( mode, stream_data, base_idx, @@ -443,17 +456,20 @@ def _build_stream_data( ) # physical space + mode_target = "" if "physical" in mode: - stream_data = self._build_stream_data_output( - mode, - stream_data, - base_idx, - stream_info, - forecast_dt, - output_data, - output_tokens, - mask_state, - ) + mode_target = "target_coords target_values" + + stream_data = self._build_stream_output_data( + mode_target, + stream_data, + base_idx, + stream_info, + forecast_dt, + output_data, + output_tokens, + mask_state, + ) return stream_data @@ -578,7 +594,6 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # source_output # target_output - # add data for current stream # streams_data += [( stream_data_source , stream_data_target)] streams_data += [stream_data_source[0]] diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index f5dcefee4..0e9166370 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -205,6 +205,80 @@ def add_target( self.target_coords_raw[fstep] = target_coords_raw self.idxs_inv[fstep] = idxs_inv + def add_target_values( + self, + fstep: int, + targets: list, + target_coords_raw: torch.tensor, + times_raw: torch.tensor, + idxs_inv: torch.tensor, + ) -> None: + """ + Add data for target for one input. + + Parameters + ---------- + fstep : int + forecast step + targets : torch.tensor( number of healpix cells ) + [ torch.tensor( num tokens, channels) ] + Target data for loss computation + targets_lens : torch.tensor( number of healpix cells) + length of targets per cell + target_coords : list( number of healpix cells) + [ torch.tensor( points per cell, 105) ] + target coordinates + target_times : list( number of healpix cells) + [ torch.tensor( points per cell) ] + absolute target times + idxs_inv: + Indices to reorder targets back to order in input + + Returns + ------- + None + """ + + self.target_tokens[fstep] = targets + self.target_times_raw[fstep] = times_raw + self.target_coords_raw[fstep] = target_coords_raw + self.idxs_inv[fstep] = idxs_inv + + def add_target_coords( + self, + fstep: int, + target_coords: torch.tensor, + target_coords_per_cell: torch.tensor, + ) -> None: + """ + Add data for target for one input. + + Parameters + ---------- + fstep : int + forecast step + targets : torch.tensor( number of healpix cells ) + [ torch.tensor( num tokens, channels) ] + Target data for loss computation + targets_lens : torch.tensor( number of healpix cells) + length of targets per cell + target_coords : list( number of healpix cells) + [ torch.tensor( points per cell, 105) ] + target coordinates + target_times : list( number of healpix cells) + [ torch.tensor( points per cell) ] + absolute target times + idxs_inv: + Indices to reorder targets back to order in input + + Returns + ------- + None + """ + + self.target_coords[fstep] = target_coords + self.target_coords_lens[fstep] = target_coords_per_cell + def target_empty(self) -> bool: """ Test if target for stream is empty diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index dea736da2..85870154c 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -73,7 +73,7 @@ def get_tokens_windows(self, stream_info, data, pad_tokens): return tokens - def batchify_source( + def get_source( self, stream_info: dict, rdata: IOReaderData, @@ -136,7 +136,7 @@ def batchify_source( # batchify_target_for_view now unified into batchify_target via optional mask_state - def batchify_target( + def get_target( self, stream_info: dict, sampling_rate_target: float, @@ -179,6 +179,99 @@ def batchify_target( return (data, datetimes, coords, coords_local, coords_per_cell, idxs_ord_inv) + + def get_target_coords( + self, + stream_info: dict, + sampling_rate_target: float, + rdata: IOReaderData, + token_data, + time_win: tuple, + mask_state: dict | None = None, + ): + token_size = stream_info["token_size"] + + # create tokenization index + (idxs_cells, idxs_cells_lens) = token_data + + # Apply per-view mask state if provided + if mask_state is not None: + self.masker.current_strategy = mask_state.get("strategy", self.masker.masking_strategy) + self.masker.mask_tokens = mask_state.get("mask_tokens") + self.masker.mask_channels = mask_state.get("mask_channels") + + (mask_tokens, mask_channels, idxs_ord_inv) = self.masker.mask_targets_idxs( + stream_info, idxs_cells, idxs_cells_lens, rdata + ) + + # TODO: split up + _, _, _, coords_local, coords_per_cell = tokenize_apply_mask_target( + self.hl_target, + idxs_cells, + idxs_cells_lens, + mask_tokens, + mask_channels, + rdata, + time_win, + self.hpy_verts_rots_target, + self.hpy_verts_local_target, + self.hpy_nctrs_target, + encode_times_target, + ) + + # TODO, TODO, TODO: max_num_targets + # max_num_targets = stream_info.get("max_num_targets", -1) + + return (coords_local, coords_per_cell) + + + def get_target_values( + self, + stream_info: dict, + sampling_rate_target: float, + rdata: IOReaderData, + token_data, + time_win: tuple, + mask_state: dict | None = None, + ): + token_size = stream_info["token_size"] + + # create tokenization index + (idxs_cells, idxs_cells_lens) = token_data + + # Apply per-view mask state if provided + if mask_state is not None: + self.masker.current_strategy = mask_state.get("strategy", self.masker.masking_strategy) + self.masker.mask_tokens = mask_state.get("mask_tokens") + self.masker.mask_channels = mask_state.get("mask_channels") + + (mask_tokens, mask_channels, idxs_ord_inv) = self.masker.mask_targets_idxs( + stream_info, idxs_cells, idxs_cells_lens, rdata + ) + + data, datetimes, coords, _, _ = tokenize_apply_mask_target( + self.hl_target, + idxs_cells, + idxs_cells_lens, + mask_tokens, + mask_channels, + rdata, + time_win, + self.hpy_verts_rots_target, + self.hpy_verts_local_target, + self.hpy_nctrs_target, + encode_times_target, + ) + + # TODO, TODO, TODO: max_num_targets + # max_num_targets = stream_info.get("max_num_targets", -1) + + # TODO: shuffeling + + return (data, datetimes, coords, idxs_ord_inv) + + + # ------------------------------------------------------------------ # Per-stream view construction (teacher + students) for student-teacher # ------------------------------------------------------------------ From 4613f7af3823b4aa34e2cc6472efffe5068a35cd Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 19 Nov 2025 17:58:10 +0100 Subject: [PATCH 071/344] Cleaned up parametrization --- .../datasets/multi_stream_data_sampler.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 9a1e0ea0b..8ec3ab4df 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -455,13 +455,8 @@ def _build_stream_data( mask, ) - # physical space - mode_target = "" - if "physical" in mode: - mode_target = "target_coords target_values" - stream_data = self._build_stream_output_data( - mode_target, + mode, stream_data, base_idx, stream_info, @@ -496,7 +491,7 @@ def _get_data_windows(self, base_idx, forecast_dt, stream_ds): stream_ds[0].get_geoinfo_size(), stream_ds[0].mean[stream_ds[0].source_idx], ) - # stream_data.source_is_spoof = True + # rdata.is_spoof = True input_data += [rdata] @@ -517,7 +512,7 @@ def _get_data_windows(self, base_idx, forecast_dt, stream_ds): stream_ds[0].get_geoinfo_size(), stream_ds[0].mean[stream_ds[0].source_idx], ) - # stream_data.target_is_spoof = True + # rdata.is_spoof = True output_data += [rdata] @@ -558,7 +553,7 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): stream_data_source = [] for mask in s_keep_t_list : stream_data_source += [ self._build_stream_data( - "physical", + "target_coords target_values", idx, forecast_dt, stream_info, @@ -575,7 +570,7 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): stream_data_target = [] for mask in t_keep_t : stream_data_target += [ self._build_stream_data( - "physical", + "target_values", idx, forecast_dt, stream_info, @@ -587,7 +582,6 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): ) ] - # TODO: build batch # source_input # target_input From 9fe94f591b17b7a606f2ee6ff82faac524674610 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 19 Nov 2025 19:30:48 +0100 Subject: [PATCH 072/344] Changes necessary for spoofing flag per IOReaderData --- packages/common/src/weathergen/common/io.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 2dba8b727..42c44d653 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -49,6 +49,7 @@ class IOReaderData: geoinfos: NDArray[DType] data: NDArray[DType] datetimes: NDArray[NPDT64] + is_spoof: bool=False def is_empty(self): """ @@ -90,6 +91,7 @@ def combine(cls, others: list["IOReaderData"]) -> "IOReaderData": geoinfos = np.zeros((0, other.geoinfos.shape[1]), dtype=other.geoinfos.dtype) data = np.zeros((0, other.data.shape[1]), dtype=other.data.dtype) datetimes = np.array([], dtype=other.datetimes.dtype) + is_spoof = True for other in others: n_datapoints = len(other.data) @@ -101,8 +103,9 @@ def combine(cls, others: list["IOReaderData"]) -> "IOReaderData": geoinfos = np.concatenate([geoinfos, other.geoinfos]) data = np.concatenate([data, other.data]) datetimes = np.concatenate([datetimes, other.datetimes]) + is_spoof = is_spoof and other.is_spoof - return cls(coords, geoinfos, data, datetimes) + return cls(coords, geoinfos, data, datetimes, is_spoof) @dataclasses.dataclass From ed26c02fdf6c7b777016c05da1783875ad892684 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 19 Nov 2025 19:57:23 +0100 Subject: [PATCH 073/344] Changes to have spoofing on a per data reader sample --- src/weathergen/datasets/data_reader_base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/weathergen/datasets/data_reader_base.py b/src/weathergen/datasets/data_reader_base.py index 2b1bc9509..2850c9d20 100644 --- a/src/weathergen/datasets/data_reader_base.py +++ b/src/weathergen/datasets/data_reader_base.py @@ -199,6 +199,7 @@ class ReaderData: geoinfos: NDArray[DType] data: NDArray[DType] datetimes: NDArray[NPDT64] + is_spoof: bool=False @staticmethod def empty(num_data_fields: int, num_geo_fields: int) -> "ReaderData": @@ -215,6 +216,7 @@ def empty(num_data_fields: int, num_geo_fields: int) -> "ReaderData": geoinfos=np.zeros((0, num_geo_fields), dtype=np.float32), data=np.zeros((0, num_data_fields), dtype=np.float32), datetimes=np.zeros((0,), dtype=np.datetime64), + is_spoof=False ) def is_empty(self): From 6d685c0c51749544ac43f324a269284406917d2b Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 19 Nov 2025 19:57:46 +0100 Subject: [PATCH 074/344] Moved _get_student_teacher_masks() so that masks are generated for all streams first. --- .../datasets/multi_stream_data_sampler.py | 91 +++++++++---------- 1 file changed, 43 insertions(+), 48 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 8ec3ab4df..71392211b 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -342,8 +342,7 @@ def _build_stream_input_data( rdata = input_data[step] token_data = input_tokens[step] - # TODO: - # stream_data.source_is_spoof = rdata.is_spoof + stream_data.source_is_spoof = rdata.is_spoof # preprocess data for model input (source_cells, source_cells_lens, mask_state) = self.tokenizer.get_source( @@ -384,7 +383,7 @@ def _build_stream_output_data( rdata = output_data[step] token_data = output_tokens[step] - # stream_data.target_is_spoof = rdata.spoof + stream_data.target_is_spoof = rdata.is_spoof if "target_coords" in mode : @@ -491,7 +490,7 @@ def _get_data_windows(self, base_idx, forecast_dt, stream_ds): stream_ds[0].get_geoinfo_size(), stream_ds[0].mean[stream_ds[0].source_idx], ) - # rdata.is_spoof = True + rdata.is_spoof = True input_data += [rdata] @@ -512,7 +511,7 @@ def _get_data_windows(self, base_idx, forecast_dt, stream_ds): stream_ds[0].get_geoinfo_size(), stream_ds[0].mean[stream_ds[0].source_idx], ) - # rdata.is_spoof = True + rdata.is_spoof = True output_data += [rdata] @@ -533,9 +532,15 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): dt = forecast_dt + self.forecast_offset streams_data: list[StreamData] = [] + # get/coordinate masks + masks_streams = self._get_student_teacher_masks( idx, forecast_dt) + # for all streams for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): + name = stream_info["name"] + (target_masks, source_masks) = masks_streams[name] + # input_data and output_data is conceptually consecutive but differs # in source and target channels; overlap in one window when self.forecast_offset=0 (input_data, output_data) = self._get_data_windows(idx, forecast_dt, stream_ds) @@ -545,14 +550,12 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): input_tokens = self.tokenizer.get_tokens_windows( stream_info, input_data, True) output_tokens = self.tokenizer.get_tokens_windows( stream_info, output_data, False) - # get/coordinate masks - t_keep_t, s_keep_t_list = self._get_student_teacher_masks(idx, forecast_dt) - # stream_data_target can contain network input # loop over student views - stream_data_source = [] - for mask in s_keep_t_list : - stream_data_source += [ self._build_stream_data( + stream_data_source = { } + for mask in source_masks : + stream_data_source[name] = self._build_stream_data( + # stream_data_source += [ self._build_stream_data( "target_coords target_values", idx, forecast_dt, @@ -563,13 +566,12 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): output_tokens, mask, ) - ] # collect source data for current stream # loop over teacher views - stream_data_target = [] - for mask in t_keep_t : - stream_data_target += [ self._build_stream_data( + stream_data_target = {} + for mask in target_masks : + stream_data_target[name] = self._build_stream_data( "target_values", idx, forecast_dt, @@ -578,9 +580,8 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): output_data, input_tokens, output_tokens, - t_keep_t[0], + mask, ) - ] # TODO: build batch # source_input @@ -590,7 +591,8 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # add data for current stream # streams_data += [( stream_data_source , stream_data_target)] - streams_data += [stream_data_source[0]] + # streams_data += [stream_data_source[0]] + streams_data += [v for k,v in stream_data_source.items()] return streams_data @@ -604,41 +606,34 @@ def _get_student_teacher_masks(self, idx: int, forecast_dt: int): forecast_dt: Number of forecast steps """ - teacher_cfg = self.training_cfg.get("teacher_model_input", {}) - student_cfg = self.training_cfg.get("model_input", {}) - relationship = student_cfg.get("relationship") - - # use build_views_for_stream utility to create student and teacher masks - t_keep_np, s_keeps_np, _meta = build_views_for_stream( - self.tokenizer.masker, - self.num_healpix_cells, - teacher_cfg=teacher_cfg, - student_cfg=student_cfg, - relationship=relationship, - ) + masks = {} + for stream_info in self.streams : + + teacher_cfg = self.training_cfg.get("teacher_model_input", {}) + student_cfg = self.training_cfg.get("model_input", {}) + relationship = student_cfg.get("relationship") - # Convert to torch.bool - def to_bool_tensor(arr): - if arr is None: - return None - return torch.from_numpy(np.asarray(arr, dtype=bool)).to(torch.bool) + # use build_views_for_stream utility to create student and teacher masks + t_keep_np, s_keeps_np, _meta = build_views_for_stream( + self.tokenizer.masker, + self.num_healpix_cells, + teacher_cfg=teacher_cfg, + student_cfg=student_cfg, + relationship=relationship, + ) - t_keep_t = [to_bool_tensor(t_keep_np)] - s_keep_t_list = [to_bool_tensor(m) for m in (s_keeps_np or [])] + # Convert to torch.bool + def to_bool_tensor(arr): + if arr is None: + return None + return torch.from_numpy(np.asarray(arr, dtype=bool)).to(torch.bool) - # # Teacher view - # streams_data = self._get_sample_data("teacher", idx, forecast_dt, keep_mask=t_keep_t) + t_keep_t = [to_bool_tensor(t_keep_np)] + s_keep_t_list = [to_bool_tensor(m) for m in (s_keeps_np or [])] - # # Students (build but do not change yielded batch shape yet) - # # For each student view (set in the config) build separate StreamData - # student_streams_data = [] - # for s_keep_t in s_keep_t_list: - # # do not do anything with this, just it is here. - # student_stream_data = self._get_sample_data("student", idx, forecast_dt, keep_mask=s_keep_t) - # student_streams_data.append(student_stream_data) + masks[ stream_info["name"] ] = (t_keep_t, s_keep_t_list) - # streams_data, student_streams_data - return t_keep_t, s_keep_t_list + return masks def _preprocess_model_data(self, batch, forecast_dt): """ """ From 848880b52f86d0337c0f24625a197a4fa448cc4b Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 19 Nov 2025 20:06:41 +0100 Subject: [PATCH 075/344] Renaming and minor clean up. --- .../datasets/multi_stream_data_sampler.py | 26 ++++++++----------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 71392211b..85b13be40 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -304,7 +304,7 @@ def denormalize_target_channels(self, stream_id, data) -> torch.Tensor: # TODO: with multiple ds per stream we need to distinguish these here return self.streams_datasets[stream_id][0].denormalize_target_channels(data) - def _build_stream_input_data( + def _build_stream_data_input( self, mode: str, stream_data: StreamData, @@ -358,7 +358,7 @@ def _build_stream_input_data( return stream_data, mask_state - def _build_stream_output_data( + def _build_stream_data_output( self, mode: str, stream_data: StreamData, @@ -444,7 +444,7 @@ def _build_stream_data( dt = self.forecast_offset + forecast_dt stream_data = StreamData(base_idx, dt, self.num_healpix_cells) - stream_data, mask_state = self._build_stream_input_data( + stream_data, mask_state = self._build_stream_data_input( mode, stream_data, base_idx, @@ -454,7 +454,7 @@ def _build_stream_data( mask, ) - stream_data = self._build_stream_output_data( + stream_data = self._build_stream_data_output( mode, stream_data, base_idx, @@ -469,7 +469,8 @@ def _build_stream_data( def _get_data_windows(self, base_idx, forecast_dt, stream_ds): """ - + Collect all data needed for current stream to potentially amortize costs by + generating multiple samples """ @@ -533,7 +534,7 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): streams_data: list[StreamData] = [] # get/coordinate masks - masks_streams = self._get_student_teacher_masks( idx, forecast_dt) + masks_streams = self._get_source_target_masks( idx, forecast_dt) # for all streams for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): @@ -546,16 +547,15 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): (input_data, output_data) = self._get_data_windows(idx, forecast_dt, stream_ds) # tokenize windows - # input_tokens = [ (cells_idx, cells_idx_lens), ... ] of time steps + # *_tokens = [ (cells_idx, cells_idx_lens), ... ] with length = #time_steps input_tokens = self.tokenizer.get_tokens_windows( stream_info, input_data, True) output_tokens = self.tokenizer.get_tokens_windows( stream_info, output_data, False) - # stream_data_target can contain network input + # collect source data for current stream # loop over student views stream_data_source = { } for mask in source_masks : stream_data_source[name] = self._build_stream_data( - # stream_data_source += [ self._build_stream_data( "target_coords target_values", idx, forecast_dt, @@ -567,8 +567,7 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): mask, ) - # collect source data for current stream - # loop over teacher views + # stream_data_target can contain network input stream_data_target = {} for mask in target_masks : stream_data_target[name] = self._build_stream_data( @@ -590,13 +589,11 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # target_output # add data for current stream - # streams_data += [( stream_data_source , stream_data_target)] - # streams_data += [stream_data_source[0]] streams_data += [v for k,v in stream_data_source.items()] return streams_data - def _get_student_teacher_masks(self, idx: int, forecast_dt: int): + def _get_source_target_masks(self, idx: int, forecast_dt: int): """ Return one batch of data Build a StreamData object for a single view (teacher or student). @@ -690,7 +687,6 @@ def __iter__(self): mode = "student_teacher" - # tokenizer.generate_masks_for_sample() streams_data = self._get_sample(mode, idx, forecast_dt) # Reset masking strategy for next batch item From 1b1654c0efbd43fe0b5d9c553ce59723cf1805e4 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 19 Nov 2025 22:32:05 +0100 Subject: [PATCH 076/344] Added basic support for use of ModelBatch class to define rough structure and interface. --- src/weathergen/datasets/batch.py | 203 ++++++++++++------ .../datasets/multi_stream_data_sampler.py | 20 +- 2 files changed, 148 insertions(+), 75 deletions(-) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index c42a2f5f0..3e518c095 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -7,25 +7,26 @@ """ from dataclasses import dataclass, field -import numpy as np -from typing import Optional + from weathergen.datasets.stream_data import StreamData -import torch +import numpy as np +import torch # TODO: Add a store for a random number for diffusion # TODO: GetTimestep to get the timestep # TODO: GetData: get the streamdata # TODO: GetMetaData: then this gets the right rn for the timestep! + @dataclass class ViewMetadata: """ Metadata describing how a view was generated. - + This captures the spatial selection (which cells/tokens were kept), the strategy used (random, healpix, etc.), and hierarchical parameters. - + Attributes: view_id: Unique identifier (e.g., "teacher_global", "student_local_0") keep_mask: Boolean array [num_healpix_cells] at data level indicating kept cells @@ -34,43 +35,61 @@ class ViewMetadata: rate: Fraction of data kept (e.g., 0.5 = 50% kept); None if fixed count parent_view_id: ID of the parent view this is a subset of (None for teacher) """ + # Core identifiers and selection description view_id: str - keep_mask: np.ndarray # [num_cells] bool at data level - strategy: str # e.g. "random", "healpix", "channel" - + keep_mask: np.ndarray # [num_cells] bool at data level + strategy: str # e.g. "random", "healpix", "channel" + # Hierarchical/quantitative description of selection - healpix_level: Optional[int] = None - rate: Optional[float] = None - parent_view_id: Optional[str] = None # For students: which teacher they belong to - + healpix_level: int | None = None + rate: float | None = None + parent_view_id: str | None = None # For students: which teacher they belong to + # Optional extras for future/other training paradigms - loss_type: Optional[str] = None # e.g. DINO, JEPA - strategy_config: Optional[dict] = None # e.g. {rate: 0.5, hl_mask: 3, overlap: "disjoint"} - + loss_type: str | None = None # e.g. DINO, JEPA + strategy_config: dict | None = None # e.g. {rate: 0.5, hl_mask: 3, overlap: "disjoint"} + # TODO: This doesn't handle the masking case, and we probably want it to, # where the model_inputs are the correct data for the masked source (and target?). Or target becomes the target? # Also should this model batch contain the source_cell_lens and target_coords_idx? -# Every sample is n different [streams]...each view is a different dictionary corresponding to one model input +# Every sample is n different [streams]...each view is a different dictionary corresponding to one model input # to get epsilon in there... # batches is for parallelism, but needs to all be in a tensor... [b, n, dim_embedding]? [b x n, dim_embedding] -# NOTE: this only stores the student source_cell_lens and target_coords_idx, +# NOTE: this only stores the student source_cell_lens and target_coords_idx, # because the teacher ones are already provided separately in (model_batches, source_cell_lens, target_coords_idx, forecast_dt) - # ^^^^^^ teacher ones ^^^^^^ +# ^^^^^^ teacher ones ^^^^^^ # However, we should probably store them all here for consistency. This needs changes to the model, so not done now. -# The forecast_dt is provided separately? +# The forecast_dt is provided separately? + + +class Sample : + + meta_info : dict + streams_data : dict + + def __init__( self, streams) : + + # TODO: can we pass this right away? + self.meta_info = {} + + self.streams_data = {} + for stream_info in streams : + self.streams_data[stream_info["name"]] = None + + def add_stream_data( self, stream_name, stream_data) : + assert self.streams_data.get( stream_name, -1) != -1, "stream name does not exist" + self.streams_data[stream_name] = stream_data + -@dataclass class ModelBatch: """ Container for all data and metadata for one training batch. - - In forecast/masking: model_inputs=[streams_data], targets=[] - - In student_teacher: model_inputs=[student_views], targets=[teacher_streams] - + Attributes: model_inputs: List of student views, each containing StreamData for all streams targets: List containing teacher view with StreamData for all streams @@ -79,53 +98,95 @@ class ModelBatch: student_source_cell_lens: List of source cell lengths for each student view student_target_coords_idx: List of target coordinate indices for each student view """ + # TODO: for DINO we want two global views per-dataset sample - # TODO: we want the global' view in student, perhaps as the first, + # TODO: we want the global' view in student, perhaps as the first, # with some metadata saying it is a second global view - - model_inputs: list[list[any]] # [n_students][n_streams] - targets: list[list[any]] # [1][n_streams] (teacher) - view_metadata: dict[str, ViewMetadata] # perhaps dict, teacher_metadata : ViewMetadata, student_metadata: list[ViewMetadata] - batch_info: Optional[dict] = field(default_factory=dict) - - # Offsets for student views (populated when needed for future student-teacher training) - # TODO: rename to model_input...source_cell/target_coords... NOTE: then there is a problem for target - student_source_cell_lens: Optional[list] = None # [n_students] each is a tensor - student_target_coords_idx: Optional[list] = None # [n_students] each is a list of lists - - # TODO fix this ridiculous naming - # Placeholders for having ModelBatch giving the full (StreamData, source_cell_lens, target_coords_idx) - teacher_source_cell_lens: torch.Tensor | None = None - teacher_target_coords_idx: list | None = None - - # TODO: add the timestep as an optional int for the model_inputs when we have multiple timesteps for the diffusion model... - # TODO add the forecast_dt as an optional int ? - - def to_device(self, device): - """Move all StreamData objects to the specified device.""" - for student_view in self.model_inputs: - for stream_data in student_view: - stream_data.to_device(device) - - for teacher_batch in self.targets: - for stream_data in teacher_batch: - stream_data.to_device(device) - - # Move student offsets if they exist - if self.student_source_cell_lens is not None: - self.student_source_cell_lens = [ - lens.to(device) if isinstance(lens, torch.Tensor) else lens - for lens in self.student_source_cell_lens - ] - - if self.student_target_coords_idx is not None: - # This is list[list[list[tensor]]], need to move all tensors - self.student_target_coords_idx = [ - [ - [t.to(device) if isinstance(t, torch.Tensor) else t for t in stream] - for stream in student_idx - ] - for student_idx in self.student_target_coords_idx - ] - - return self \ No newline at end of file + + source_samples : list[Sample] + + targt_samples : list[Sample] + + + def __init__( self, streams, num_source_samples, num_target_samples) : + """ + + """ + + self.source_samples = [Sample(streams) for _ in range(num_source_samples)] + self.target_samples = [Sample(streams) for _ in range(num_target_samples)] + + def add_source_stream( self, sample_idx : int, stream_name : str, stream_data : StreamData) : + """ + + """ + self.source_samples[sample_idx].add_stream_data( stream_name, stream_data) + + def add_target_stream( self, sample_idx : int, stream_name : str, stream_data : StreamData) : + """ + + """ + self.target_samples[sample_idx].add_stream_data( stream_name, stream_data) + + def len_source( self) : + """ + + """ + return len(self.source_samples) + + def len_target( self) : + """ + + """ + return len(self.target_samples) + + + # model_inputs: list[list[any]] # [n_students][n_streams] + # targets: list[list[any]] # [1][n_streams] (teacher) + # view_metadata: dict[ + # str, ViewMetadata + # ] # perhaps dict, teacher_metadata : ViewMetadata, student_metadata: list[ViewMetadata] + # batch_info: dict | None = field(default_factory=dict) + + + # # Offsets for student views (populated when needed for future student-teacher training) + # # TODO: rename to model_input...source_cell/target_coords... NOTE: then there is a problem for target + # student_source_cell_lens: list | None = None # [n_students] each is a tensor + # student_target_coords_idx: list | None = None # [n_students] each is a list of lists + + # # TODO fix this ridiculous naming + # # Placeholders for having ModelBatch giving the full (StreamData, source_cell_lens, target_coords_idx) + # teacher_source_cell_lens: torch.Tensor | None = None + # teacher_target_coords_idx: list | None = None + + # # TODO: add the timestep as an optional int for the model_inputs when we have multiple timesteps for the diffusion model... + # # TODO add the forecast_dt as an optional int ? + + # def to_device(self, device): + # """Move all StreamData objects to the specified device.""" + # for student_view in self.model_inputs: + # for stream_data in student_view: + # stream_data.to_device(device) + + # for teacher_batch in self.targets: + # for stream_data in teacher_batch: + # stream_data.to_device(device) + + # # Move student offsets if they exist + # if self.student_source_cell_lens is not None: + # self.student_source_cell_lens = [ + # lens.to(device) if isinstance(lens, torch.Tensor) else lens + # for lens in self.student_source_cell_lens + # ] + + # if self.student_target_coords_idx is not None: + # # This is list[list[list[tensor]]], need to move all tensors + # self.student_target_coords_idx = [ + # [ + # [t.to(device) if isinstance(t, torch.Tensor) else t for t in stream] + # for stream in student_idx + # ] + # for student_idx in self.student_target_coords_idx + # ] + + # return self diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 85b13be40..a010f7d77 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -33,6 +33,7 @@ compute_source_cell_lens, ) from weathergen.datasets.view_builder import build_views_for_stream +from weathergen.datasets.batch import ModelBatch from weathergen.utils.distributed import is_root from weathergen.utils.train_logger import Stage @@ -536,6 +537,11 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # get/coordinate masks masks_streams = self._get_source_target_masks( idx, forecast_dt) + # TODO: these params come from config? + num_source_samples = 8 + num_target_samples = 2 + batch = ModelBatch( self.streams, num_source_samples, num_target_samples) + # for all streams for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): @@ -554,8 +560,9 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # collect source data for current stream # loop over student views stream_data_source = { } - for mask in source_masks : - stream_data_source[name] = self._build_stream_data( + for sidx, mask in enumerate(source_masks) : + # stream_data_source[name] = self._build_stream_data( + sdata = self._build_stream_data( "target_coords target_values", idx, forecast_dt, @@ -566,11 +573,14 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): output_tokens, mask, ) + stream_data_source[name] = sdata + batch.add_source_stream( sidx, name, sdata) # stream_data_target can contain network input stream_data_target = {} - for mask in target_masks : - stream_data_target[name] = self._build_stream_data( + for sidx, mask in enumerate(target_masks) : + # stream_data_target[name] = self._build_stream_data( + sdata = self._build_stream_data( "target_values", idx, forecast_dt, @@ -581,6 +591,8 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): output_tokens, mask, ) + stream_data_target[name] = sdata + batch.add_target_stream( sidx, name, sdata) # TODO: build batch # source_input From c1d32fba7eee50bb1172a7b87e0e939c60157785 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 20 Nov 2025 08:20:21 +0100 Subject: [PATCH 077/344] linting --- packages/common/src/weathergen/common/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 42c44d653..0e85bb391 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -49,7 +49,7 @@ class IOReaderData: geoinfos: NDArray[DType] data: NDArray[DType] datetimes: NDArray[NPDT64] - is_spoof: bool=False + is_spoof: bool = False def is_empty(self): """ From 6a96065fddee4e6aebb26db5dbccc3b7a3cc19b3 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 20 Nov 2025 08:20:42 +0100 Subject: [PATCH 078/344] Linting --- src/weathergen/datasets/data_reader_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/datasets/data_reader_base.py b/src/weathergen/datasets/data_reader_base.py index 2850c9d20..6d1971b9a 100644 --- a/src/weathergen/datasets/data_reader_base.py +++ b/src/weathergen/datasets/data_reader_base.py @@ -199,7 +199,7 @@ class ReaderData: geoinfos: NDArray[DType] data: NDArray[DType] datetimes: NDArray[NPDT64] - is_spoof: bool=False + is_spoof: bool = False @staticmethod def empty(num_data_fields: int, num_geo_fields: int) -> "ReaderData": @@ -216,7 +216,7 @@ def empty(num_data_fields: int, num_geo_fields: int) -> "ReaderData": geoinfos=np.zeros((0, num_geo_fields), dtype=np.float32), data=np.zeros((0, num_data_fields), dtype=np.float32), datetimes=np.zeros((0,), dtype=np.datetime64), - is_spoof=False + is_spoof=False, ) def is_empty(self): From 3bca490ebebb0ff93b647474d05e77dc154cfae6 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 20 Nov 2025 08:21:13 +0100 Subject: [PATCH 079/344] linting --- src/weathergen/datasets/masking.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index dfb4002be..f6c4f60fc 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -141,7 +141,7 @@ def mask_source_idxs( idxs_cells, idxs_cells_lens, rdata, - keep_mask: np.ndarray | None = None, + keep_mask: np.typing.NDArray | None = None, ) -> (torch.Tensor, torch.Tensor): """ @@ -162,16 +162,18 @@ def mask_source_idxs( # Otherwise fall back to the configured strategy logic. if keep_mask is not None: assert len(keep_mask) == len(idxs_cells_lens), ( - f"keep_mask length {len(keep_mask)} does not match number of cells {len(idxs_cells_lens)}" + "keep_mask length does not match number of cells." ) # build token level mask: for each cell replicate the keep flag across its tokens - token_level_flags: list[np.ndarray] = [] + token_level_flags: list[np.typing.NDArray] = [] for km, lens_cell in zip(keep_mask, idxs_cells_lens, strict=True): num_tokens_cell = len(lens_cell) if num_tokens_cell == 0: continue token_level_flags.append( - np.ones(num_tokens_cell, dtype=bool) if km else np.zeros(num_tokens_cell, dtype=bool) + np.ones(num_tokens_cell, dtype=bool) + if km + else np.zeros(num_tokens_cell, dtype=bool) ) if token_level_flags: self.mask_tokens = np.concatenate(token_level_flags) @@ -633,8 +635,8 @@ def generate_cell_keep_mask( strategy: str | None = None, rate: float | None = None, masking_strategy_config: dict | None = None, - constraint_keep_mask: np.ndarray | None = None, - ) -> np.ndarray: + constraint_keep_mask: np.typing.NDArray | None = None, + ) -> np.typing.NDArray: """Generate a boolean keep mask at data healpix level (True = keep cell). Parameters @@ -668,7 +670,7 @@ def generate_cell_keep_mask( assert 0.0 <= keep_rate <= 1.0, f"keep_rate out of bounds: {keep_rate}" assert num_cells == self.healpix_num_cells, ( - f"num_cells={num_cells} inconsistent with configured healpix level ({self.healpix_num_cells})." + "num_cells inconsistent with configured healpix level." ) if strat not in {"random", "healpix"}: @@ -682,7 +684,8 @@ def generate_cell_keep_mask( hl_data = self.healpix_level_data hl_mask = cfg.get("hl_mask") assert hl_mask is not None and hl_mask < hl_data, ( - "For healpix keep mask generation, cfg['hl_mask'] must be set and < data level.") + "For healpix keep mask generation, cfg['hl_mask'] must be set and < data level." + ) num_parent_cells = 12 * (4**hl_mask) level_diff = hl_data - hl_mask num_children_per_parent = 4**level_diff From 5d5e999ed820c353a9200805e100b4d2105e5655 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 20 Nov 2025 08:21:31 +0100 Subject: [PATCH 080/344] Linting problems but removed unused ViewMetaData dependence --- src/weathergen/datasets/tokenizer_masking.py | 22 ++++---------------- 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 85870154c..d27178381 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -11,7 +11,6 @@ import torch from weathergen.common.io import IOReaderData -from weathergen.datasets.batch import ViewMetadata from weathergen.datasets.masking import Masker from weathergen.datasets.tokenizer import Tokenizer from weathergen.datasets.tokenizer_utils import ( @@ -24,6 +23,7 @@ ) from weathergen.datasets.view_builder import build_views_for_stream + def readerdata_to_torch(rdata: IOReaderData) -> IOReaderData: """ Convert data, coords, and geoinfos to torch tensor @@ -38,13 +38,10 @@ def readerdata_to_torch(rdata: IOReaderData) -> IOReaderData: return rdata - class TokenizerMasking(Tokenizer): def __init__(self, healpix_level: int, masker: Masker): super().__init__(healpix_level) self.masker = masker - # cache last built view metadata per stream invocation (optional downstream use) - self._last_view_metadata: list[ViewMetadata] | None = None def reset_rng(self, rng) -> None: """ @@ -81,9 +78,7 @@ def get_source( time_win: tuple, keep_mask: torch.Tensor | None = None, ): - token_size = stream_info["token_size"] stream_id = stream_info["stream_id"] - assert token_size is not None, "stream did not specify token_size" is_diagnostic = stream_info.get("diagnostic", False) # return empty if there is no data or we are in diagnostic mode @@ -145,7 +140,7 @@ def get_target( time_win: tuple, mask_state: dict | None = None, ): - token_size = stream_info["token_size"] + # TODO: remove # create tokenization index (idxs_cells, idxs_cells_lens) = token_data @@ -179,7 +174,6 @@ def get_target( return (data, datetimes, coords, coords_local, coords_per_cell, idxs_ord_inv) - def get_target_coords( self, stream_info: dict, @@ -189,8 +183,6 @@ def get_target_coords( time_win: tuple, mask_state: dict | None = None, ): - token_size = stream_info["token_size"] - # create tokenization index (idxs_cells, idxs_cells_lens) = token_data @@ -224,7 +216,6 @@ def get_target_coords( return (coords_local, coords_per_cell) - def get_target_values( self, stream_info: dict, @@ -234,8 +225,6 @@ def get_target_values( time_win: tuple, mask_state: dict | None = None, ): - token_size = stream_info["token_size"] - # create tokenization index (idxs_cells, idxs_cells_lens) = token_data @@ -270,8 +259,6 @@ def get_target_values( return (data, datetimes, coords, idxs_ord_inv) - - # ------------------------------------------------------------------ # Per-stream view construction (teacher + students) for student-teacher # ------------------------------------------------------------------ @@ -293,7 +280,8 @@ def build_stream_views( time_win : tuple (start, end) datetime window. training_cfg : dict | None - cf.training_config section; if absent or mode != 'student_teacher', fallback to single view. + cf.training_config section; if absent or mode != 'student_teacher', fallback to + single view. Returns ------- @@ -301,8 +289,6 @@ def build_stream_views( (tokens_cells, tokens_lens) for teacher or None when not student_teacher. students : list List of (tokens_cells, tokens_lens) for each student view (or single masking view). - view_metadata : list[ViewMetadata] | None - Metadata for teacher + students when in student_teacher mode. """ if training_cfg is None or training_cfg.get("training_mode") != "student_teacher": # Standard masking path: single view only (treated as 'student' for uniformity) From e8ccb8d55f094cda620b3af80fae5c9d3526025b Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 20 Nov 2025 08:22:26 +0100 Subject: [PATCH 081/344] Added required reflexivity between source and target samples to Batch --- src/weathergen/datasets/batch.py | 189 ++++++++---------- .../datasets/multi_stream_data_sampler.py | 50 ++--- 2 files changed, 108 insertions(+), 131 deletions(-) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index 3e518c095..80b735b4d 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -6,12 +6,11 @@ - View metadata (spatial masks, strategies, relationships) """ -from dataclasses import dataclass, field - -from weathergen.datasets.stream_data import StreamData +from dataclasses import dataclass import numpy as np -import torch + +from weathergen.datasets.stream_data import StreamData # TODO: Add a store for a random number for diffusion # TODO: GetTimestep to get the timestep @@ -38,7 +37,7 @@ class ViewMetadata: # Core identifiers and selection description view_id: str - keep_mask: np.ndarray # [num_cells] bool at data level + keep_mask: np.typing.NDArray # [num_cells] bool at data level strategy: str # e.g. "random", "healpix", "channel" # Hierarchical/quantitative description of selection @@ -51,142 +50,126 @@ class ViewMetadata: strategy_config: dict | None = None # e.g. {rate: 0.5, hl_mask: 3, overlap: "disjoint"} -# TODO: This doesn't handle the masking case, and we probably want it to, -# where the model_inputs are the correct data for the masked source (and target?). Or target becomes the target? -# Also should this model batch contain the source_cell_lens and target_coords_idx? -# Every sample is n different [streams]...each view is a different dictionary corresponding to one model input -# to get epsilon in there... -# batches is for parallelism, but needs to all be in a tensor... [b, n, dim_embedding]? [b x n, dim_embedding] +class SampleMetaData: + # masking strategy + masking_strategy: str + # parameters for masking strategy + masking_params: dict -# NOTE: this only stores the student source_cell_lens and target_coords_idx, -# because the teacher ones are already provided separately in (model_batches, source_cell_lens, target_coords_idx, forecast_dt) -# ^^^^^^ teacher ones ^^^^^^ -# However, we should probably store them all here for consistency. This needs changes to the model, so not done now. -# The forecast_dt is provided separately? +class Sample: + # keys: stream name, values: SampleMetaData + meta_info: dict -class Sample : - - meta_info : dict - streams_data : dict - - def __init__( self, streams) : + # data for all streams + # keys: stream_name, values: StreamData + streams_data: dict + def __init__(self, streams: dict) -> None: # TODO: can we pass this right away? self.meta_info = {} self.streams_data = {} - for stream_info in streams : + for stream_info in streams: self.streams_data[stream_info["name"]] = None - def add_stream_data( self, stream_name, stream_data) : - assert self.streams_data.get( stream_name, -1) != -1, "stream name does not exist" + def add_stream_data(self, stream_name: str, stream_data: StreamData) -> None: + """ + Add data for stream @stream_name to sample + """ + assert self.streams_data.get(stream_name, -1) != -1, "stream name does not exist" self.streams_data[stream_name] = stream_data class ModelBatch: """ Container for all data and metadata for one training batch. + """ + # source samples (for model) + source_samples: list[Sample] - Attributes: - model_inputs: List of student views, each containing StreamData for all streams - targets: List containing teacher view with StreamData for all streams - view_metadata: List of ViewMetadata describing each view (teacher + students) - batch_info: Optional dict with batch-level info (sample indices, forecast steps, etc.) - student_source_cell_lens: List of source cell lengths for each student view - student_target_coords_idx: List of target coordinate indices for each student view - """ + # target samples (for TargetAuxCalculator) + targt_samples: list[Sample] - # TODO: for DINO we want two global views per-dataset sample - # TODO: we want the global' view in student, perhaps as the first, - # with some metadata saying it is a second global view + # index of corresponding target (for source samples) or source (for target samples) + # these are in 1-to-1 corresponding for classical training modes (MTM, forecasting) but + # can be more complex for strategies like student-teacher training + source_matching_idx: np.typing.NDArray[np.int32] + target_matching_idx: np.typing.NDArray[np.int32] - source_samples : list[Sample] + def __init__(self, streams, num_source_samples: int, num_target_samples: int) -> None: + """ """ - targt_samples : list[Sample] + self.source_samples = [Sample(streams) for _ in range(num_source_samples)] + self.target_samples = [Sample(streams) for _ in range(num_target_samples)] + self.source_target_matching_idxs = np.full(num_source_samples, -1, dtype=np.int32) + self.target_source_matching_idxs = np.full(num_target_samples, -1, dtype=np.int32) - def __init__( self, streams, num_source_samples, num_target_samples) : + def add_source_stream( + self, + source_sample_idx: int, + target_sample_idx: int, + stream_name: str, + stream_data: StreamData, + ) -> None: """ - + Add data for one stream to sample @source_sample_idx + """ + self.source_samples[source_sample_idx].add_stream_data(stream_name, stream_data) + + assert target_sample_idx < len(self.target_samples), "invalid value for target_sample_idx" + self.source_target_matching_idxs[source_sample_idx] = target_sample_idx + + def add_target_stream( + self, + target_sample_idx: int, + source_sample_idx: int, + stream_name: str, + stream_data: StreamData, + ) -> None: """ + Add data for one stream to sample @target_sample_idx + """ + self.target_samples[target_sample_idx].add_stream_data(stream_name, stream_data) - self.source_samples = [Sample(streams) for _ in range(num_source_samples)] - self.target_samples = [Sample(streams) for _ in range(num_target_samples)] + assert source_sample_idx < len(self.source_samples), "invalid value for source_sample_idx" + self.target_source_matching_idxs[target_sample_idx] = source_sample_idx - def add_source_stream( self, sample_idx : int, stream_name : str, stream_data : StreamData) : + def len_sources(self) -> int: """ - + Number of source samples """ - self.source_samples[sample_idx].add_stream_data( stream_name, stream_data) + return len(self.source_samples) - def add_target_stream( self, sample_idx : int, stream_name : str, stream_data : StreamData) : + def len_targets(self) -> int: """ - + Number of target samples """ - self.target_samples[sample_idx].add_stream_data( stream_name, stream_data) + return len(self.target_samples) - def len_source( self) : + def get_source_sample(self, idx: int) -> Sample: """ - + Get a source sample """ - return len(self.source_samples) + return self.source_samples[idx] - def len_target( self) : + def get_target_sample(self, idx: int) -> Sample: """ - + Get a target sample """ - return len(self.target_samples) + return self.target_samples[idx] + def get_source_idx_for_target(self, target_idx: int) -> int: + """ + Get index of source sample for a given target sample index + """ + return int(self.source_target_matching_idxs[target_idx]) - # model_inputs: list[list[any]] # [n_students][n_streams] - # targets: list[list[any]] # [1][n_streams] (teacher) - # view_metadata: dict[ - # str, ViewMetadata - # ] # perhaps dict, teacher_metadata : ViewMetadata, student_metadata: list[ViewMetadata] - # batch_info: dict | None = field(default_factory=dict) - - - # # Offsets for student views (populated when needed for future student-teacher training) - # # TODO: rename to model_input...source_cell/target_coords... NOTE: then there is a problem for target - # student_source_cell_lens: list | None = None # [n_students] each is a tensor - # student_target_coords_idx: list | None = None # [n_students] each is a list of lists - - # # TODO fix this ridiculous naming - # # Placeholders for having ModelBatch giving the full (StreamData, source_cell_lens, target_coords_idx) - # teacher_source_cell_lens: torch.Tensor | None = None - # teacher_target_coords_idx: list | None = None - - # # TODO: add the timestep as an optional int for the model_inputs when we have multiple timesteps for the diffusion model... - # # TODO add the forecast_dt as an optional int ? - - # def to_device(self, device): - # """Move all StreamData objects to the specified device.""" - # for student_view in self.model_inputs: - # for stream_data in student_view: - # stream_data.to_device(device) - - # for teacher_batch in self.targets: - # for stream_data in teacher_batch: - # stream_data.to_device(device) - - # # Move student offsets if they exist - # if self.student_source_cell_lens is not None: - # self.student_source_cell_lens = [ - # lens.to(device) if isinstance(lens, torch.Tensor) else lens - # for lens in self.student_source_cell_lens - # ] - - # if self.student_target_coords_idx is not None: - # # This is list[list[list[tensor]]], need to move all tensors - # self.student_target_coords_idx = [ - # [ - # [t.to(device) if isinstance(t, torch.Tensor) else t for t in stream] - # for stream in student_idx - # ] - # for student_idx in self.student_target_coords_idx - # ] - - # return self + def get_target_idx_for_source(self, source_idx: int) -> int: + """ + Get index of target sample for a given source sample index + """ + return int(self.source_target_matching_idxs[source_idx]) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index a010f7d77..080e056f6 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -14,6 +14,7 @@ import torch from weathergen.common.io import IOReaderData +from weathergen.datasets.batch import ModelBatch from weathergen.datasets.data_reader_anemoi import DataReaderAnemoi from weathergen.datasets.data_reader_base import ( DataReaderBase, @@ -33,7 +34,6 @@ compute_source_cell_lens, ) from weathergen.datasets.view_builder import build_views_for_stream -from weathergen.datasets.batch import ModelBatch from weathergen.utils.distributed import is_root from weathergen.utils.train_logger import Stage @@ -370,9 +370,7 @@ def _build_stream_data_output( output_tokens: list, mask_state: dict | None = None, ) -> StreamData: - """ - - """ + """ """ # collect for all forecast steps dt = self.forecast_offset + forecast_dt @@ -386,8 +384,7 @@ def _build_stream_data_output( stream_data.target_is_spoof = rdata.is_spoof - if "target_coords" in mode : - + if "target_coords" in mode: (tc, tc_l) = self.tokenizer.get_target_coords( stream_info, self.sampling_rate_target, @@ -398,8 +395,7 @@ def _build_stream_data_output( ) stream_data.add_target_coords(fstep, tc, tc_l) - if "target_values" in mode : - + if "target_values" in mode: (tt_cells, tt_t, tt_c, idxs_inv) = self.tokenizer.get_target_values( stream_info, self.sampling_rate_target, @@ -469,15 +465,15 @@ def _build_stream_data( return stream_data def _get_data_windows(self, base_idx, forecast_dt, stream_ds): - """ + """ Collect all data needed for current stream to potentially amortize costs by generating multiple samples - + """ # source data: iterate overall input steps input_data = [] - for step, idx in enumerate(range(base_idx - self.num_input_steps, base_idx + 1)): + for idx in range(base_idx - self.num_input_steps, base_idx + 1): # TODO: check that we are not out of bounds when we go back in time rdata = collect_datasources(stream_ds, idx, "source") @@ -519,7 +515,6 @@ def _get_data_windows(self, base_idx, forecast_dt, stream_ds): return (input_data, output_data) - def _get_sample(self, mode: str, idx: int, forecast_dt: int): """ @@ -531,20 +526,18 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): TODO: these modes are not being used now. """ - dt = forecast_dt + self.forecast_offset streams_data: list[StreamData] = [] # get/coordinate masks - masks_streams = self._get_source_target_masks( idx, forecast_dt) + masks_streams = self._get_source_target_masks(idx, forecast_dt) # TODO: these params come from config? num_source_samples = 8 - num_target_samples = 2 - batch = ModelBatch( self.streams, num_source_samples, num_target_samples) + num_target_samples = 8 + batch = ModelBatch(self.streams, num_source_samples, num_target_samples) # for all streams for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): - name = stream_info["name"] (target_masks, source_masks) = masks_streams[name] @@ -554,13 +547,13 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # tokenize windows # *_tokens = [ (cells_idx, cells_idx_lens), ... ] with length = #time_steps - input_tokens = self.tokenizer.get_tokens_windows( stream_info, input_data, True) - output_tokens = self.tokenizer.get_tokens_windows( stream_info, output_data, False) + input_tokens = self.tokenizer.get_tokens_windows(stream_info, input_data, True) + output_tokens = self.tokenizer.get_tokens_windows(stream_info, output_data, False) # collect source data for current stream # loop over student views - stream_data_source = { } - for sidx, mask in enumerate(source_masks) : + stream_data_source = {} + for sidx, mask in enumerate(source_masks): # stream_data_source[name] = self._build_stream_data( sdata = self._build_stream_data( "target_coords target_values", @@ -574,11 +567,12 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): mask, ) stream_data_source[name] = sdata - batch.add_source_stream( sidx, name, sdata) + # TODO: set target sample correctly + batch.add_source_stream(sidx, sidx, name, sdata) # stream_data_target can contain network input stream_data_target = {} - for sidx, mask in enumerate(target_masks) : + for sidx, mask in enumerate(target_masks): # stream_data_target[name] = self._build_stream_data( sdata = self._build_stream_data( "target_values", @@ -592,7 +586,8 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): mask, ) stream_data_target[name] = sdata - batch.add_target_stream( sidx, name, sdata) + # TODO: set target sample correctly + batch.add_target_stream(sidx, sidx, name, sdata) # TODO: build batch # source_input @@ -601,7 +596,7 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # target_output # add data for current stream - streams_data += [v for k,v in stream_data_source.items()] + streams_data += [v for k, v in stream_data_source.items()] return streams_data @@ -616,8 +611,7 @@ def _get_source_target_masks(self, idx: int, forecast_dt: int): """ masks = {} - for stream_info in self.streams : - + for stream_info in self.streams: teacher_cfg = self.training_cfg.get("teacher_model_input", {}) student_cfg = self.training_cfg.get("model_input", {}) relationship = student_cfg.get("relationship") @@ -640,7 +634,7 @@ def to_bool_tensor(arr): t_keep_t = [to_bool_tensor(t_keep_np)] s_keep_t_list = [to_bool_tensor(m) for m in (s_keeps_np or [])] - masks[ stream_info["name"] ] = (t_keep_t, s_keep_t_list) + masks[stream_info["name"]] = (t_keep_t, s_keep_t_list) return masks From d18cf868f29f8cae9b51362ea7b9d338b52d4de7 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 20 Nov 2025 08:26:40 +0100 Subject: [PATCH 082/344] Added todo --- src/weathergen/datasets/batch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index 80b735b4d..1478237bd 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -81,6 +81,7 @@ def add_stream_data(self, stream_name: str, stream_data: StreamData) -> None: assert self.streams_data.get(stream_name, -1) != -1, "stream name does not exist" self.streams_data[stream_name] = stream_data + # TODO: complete interface, e.g get_stream class ModelBatch: """ From b2be982bd67487323e8a887fb340cabf5b17207e Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Thu, 20 Nov 2025 13:07:47 +0000 Subject: [PATCH 083/344] fix typo in ModelBatch --- src/weathergen/datasets/batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index 1478237bd..550a19598 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -92,7 +92,7 @@ class ModelBatch: source_samples: list[Sample] # target samples (for TargetAuxCalculator) - targt_samples: list[Sample] + target_samples: list[Sample] # index of corresponding target (for source samples) or source (for target samples) # these are in 1-to-1 corresponding for classical training modes (MTM, forecasting) but From b34b6da569c6e176647999a57f144fbbb91b129c Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Thu, 20 Nov 2025 13:09:19 +0000 Subject: [PATCH 084/344] collect num_source_samples and num_target_samples, add loop over teacher masks hence allowing multiple teacher views, and add source_target_idx to keep track of which student belongs to which teacher --- .../datasets/multi_stream_data_sampler.py | 72 +++++++++++++------ 1 file changed, 52 insertions(+), 20 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 080e056f6..8d6d818d5 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -531,15 +531,19 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # get/coordinate masks masks_streams = self._get_source_target_masks(idx, forecast_dt) - # TODO: these params come from config? - num_source_samples = 8 - num_target_samples = 8 + # Determine number of views direct from config (teacher & student views) + teacher_cfg = self.training_cfg.get("teacher_model_input", {}) if self.training_cfg else {} + student_cfg = self.training_cfg.get("model_input", {}) if self.training_cfg else {} + num_target_samples = int(teacher_cfg.get("num_views", 1)) + num_source_samples = int(teacher_cfg.get("num_views", 1)) * int(student_cfg.get("num_views", 1)) # per teacher + batch = ModelBatch(self.streams, num_source_samples, num_target_samples) # for all streams for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): name = stream_info["name"] - (target_masks, source_masks) = masks_streams[name] + + (target_masks, source_masks, student_to_teacher) = masks_streams[name] # input_data and output_data is conceptually consecutive but differs # in source and target channels; overlap in one window when self.forecast_offset=0 @@ -567,12 +571,21 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): mask, ) stream_data_source[name] = sdata - # TODO: set target sample correctly - batch.add_source_stream(sidx, sidx, name, sdata) + # TODO: check this is correct + # Map each student (source) to its teacher (target) + t_idx = int(student_to_teacher[sidx]) + batch.add_source_stream(sidx, t_idx, name, sdata) # stream_data_target can contain network input stream_data_target = {} - for sidx, mask in enumerate(target_masks): + # Not sure if this is the neatest approach... + # choose a student per teacher for reverse mapping + # pick the first student index that maps to this teacher + rep_source_for_teacher = {} + for s_idx, t_idx in enumerate(student_to_teacher): + rep_source_for_teacher.setdefault(int(t_idx), s_idx) + + for t_idx, mask in enumerate(target_masks): # stream_data_target[name] = self._build_stream_data( sdata = self._build_stream_data( "target_values", @@ -586,8 +599,10 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): mask, ) stream_data_target[name] = sdata - # TODO: set target sample correctly - batch.add_target_stream(sidx, sidx, name, sdata) + # TODO: check target sample here + # Map target to one representative source (not 1:N here) + rep_source_idx = int(rep_source_for_teacher.get(t_idx, 0)) + batch.add_target_stream(t_idx, rep_source_idx, name, sdata) # TODO: build batch # source_input @@ -616,14 +631,8 @@ def _get_source_target_masks(self, idx: int, forecast_dt: int): student_cfg = self.training_cfg.get("model_input", {}) relationship = student_cfg.get("relationship") - # use build_views_for_stream utility to create student and teacher masks - t_keep_np, s_keeps_np, _meta = build_views_for_stream( - self.tokenizer.masker, - self.num_healpix_cells, - teacher_cfg=teacher_cfg, - student_cfg=student_cfg, - relationship=relationship, - ) + # number of teacher views + num_teacher_views = int(teacher_cfg.get("num_views", 1)) # Convert to torch.bool def to_bool_tensor(arr): @@ -631,11 +640,34 @@ def to_bool_tensor(arr): return None return torch.from_numpy(np.asarray(arr, dtype=bool)).to(torch.bool) - t_keep_t = [to_bool_tensor(t_keep_np)] - s_keep_t_list = [to_bool_tensor(m) for m in (s_keeps_np or [])] + # renaming here + target_masks: list[torch.Tensor] = [] + source_masks: list[torch.Tensor] = [] + student_to_teacher: list[int] = [] + + # add a loop over num_teacher_views, generate students for each teacher + for t_idx in range(num_teacher_views): + # Build one teacher and its student views + t_keep_np, s_keeps_np, _meta = build_views_for_stream( + self.tokenizer.masker, + self.num_healpix_cells, + teacher_cfg=teacher_cfg, + student_cfg=student_cfg, + relationship=relationship, + ) + + # append teacher mask + t_tensor = to_bool_tensor(t_keep_np) + target_masks.append(t_tensor) - masks[stream_info["name"]] = (t_keep_t, s_keep_t_list) + # this teacher's students and mapping + for s_np in (s_keeps_np or []): + source_masks.append(to_bool_tensor(s_np)) + # append 0, 1, ... depending on which teacher we did + student_to_teacher.append(len(target_masks) - 1) + masks[stream_info["name"]] = (target_masks, source_masks, student_to_teacher) + return masks def _preprocess_model_data(self, batch, forecast_dt): From 87ad45f308b66765a1e017208f8b00a98520ac13 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Thu, 20 Nov 2025 13:10:34 +0000 Subject: [PATCH 085/344] add teacher num_views parameter to config --- config/default_config.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/config/default_config.yml b/config/default_config.yml index c0007da6e..347908668 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -129,6 +129,7 @@ training_config: teacher_model_input: strategy: "healpix" # Strategy for teacher (global) view: "random", "healpix" rate: 0.5 # Fraction of data to keep in global view (alternative: use "keep_m" for absolute count) + num_views: 2 # number of teacher views to generate # keep_m: 100 # Alternative to rate: keep exactly this many parent cells rate_sampling: true # randomly sample the rate per batch masking_strategy_config: {"strategies": ["random", "healpix", "channel"], From 9b702c599d0ac427d342c709ec058007c3ee99b1 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 20 Nov 2025 14:34:34 +0100 Subject: [PATCH 086/344] Re-enabling inversion of targert ordering. --- src/weathergen/train/trainer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 09dc05e34..98d3a86c3 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -557,14 +557,14 @@ def _prepare_logging( targets_lens[fstep][i_strm] += [target.shape[0]] dn_data = self.dataset_val.denormalize_target_channels - # # reorder so that output order of target points matches input when reading - # # (tokenization and masking changes this order) - # # TODO: does this work with batch_size > 1 - # if len(idxs_inv) > 0: - # pred = pred[:, idxs_inv] - # target = target[idxs_inv] - # targets_coords_raw[fstep][i_strm] = targets_coords_raw[fstep][i_strm][idxs_inv] - # targets_times_raw[fstep][i_strm] = targets_times_raw[fstep][i_strm][idxs_inv] + # reorder so that output order of target points matches input when reading + # (tokenization and masking changes this order) + # TODO: does this work with batch_size > 1 + if len(idxs_inv) > 0: + pred = pred[:, idxs_inv] + target = target[idxs_inv] + targets_coords_raw[fstep][i_strm] = targets_coords_raw[fstep][i_strm][idxs_inv] + targets_times_raw[fstep][i_strm] = targets_times_raw[fstep][i_strm][idxs_inv] f32 = torch.float32 preds_all[fstep][i_strm] += [ From 1806ae5feafa1663838610c4d5d8a134dc4319d4 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Thu, 20 Nov 2025 16:28:30 +0000 Subject: [PATCH 087/344] tidy up, remove unused build_stream_views in tokenizer_masking --- src/weathergen/datasets/tokenizer_masking.py | 65 -------------------- 1 file changed, 65 deletions(-) diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index d27178381..43d1785a1 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -259,71 +259,6 @@ def get_target_values( return (data, datetimes, coords, idxs_ord_inv) - # ------------------------------------------------------------------ - # Per-stream view construction (teacher + students) for student-teacher - # ------------------------------------------------------------------ - def build_stream_views( - self, - stream_info: dict, - rdata: IOReaderData, - time_win: tuple, - training_cfg: dict | None = None, - ): - """Construct teacher and student views for a single stream. - - Parameters - ---------- - stream_info : dict - Stream configuration dictionary. - rdata : IOReaderData - Combined reader data for this stream. - time_win : tuple - (start, end) datetime window. - training_cfg : dict | None - cf.training_config section; if absent or mode != 'student_teacher', fallback to - single view. - - Returns - ------- - teacher : tuple | None - (tokens_cells, tokens_lens) for teacher or None when not student_teacher. - students : list - List of (tokens_cells, tokens_lens) for each student view (or single masking view). - """ - if training_cfg is None or training_cfg.get("training_mode") != "student_teacher": - # Standard masking path: single view only (treated as 'student' for uniformity) - scells, slens, _mask_state = self.batchify_source(stream_info, rdata, time_win) - self._last_view_metadata = None - return None, [(scells, slens, _mask_state)], None - - teacher_cfg = training_cfg.get("teacher_model_input", {}) - student_cfg = training_cfg.get("model_input", {}) - relationship = student_cfg.get("relationship", "subset") - - num_cells = self.num_healpix_cells_source - teacher_keep_mask, student_keep_masks, view_meta = build_views_for_stream( - self.masker, num_cells, teacher_cfg, student_cfg, relationship - ) - - # Convert keep masks to torch tensors for downstream masking override - teacher_keep_mask_t = torch.from_numpy(teacher_keep_mask) - student_keep_masks_t = [torch.from_numpy(m) for m in student_keep_masks] - - # Teacher tokens - t_cells, t_lens, t_mask_state = self.batchify_source( - stream_info, rdata, time_win, keep_mask=teacher_keep_mask_t - ) - # Student tokens - student_tokens = [ - self.batchify_source(stream_info, rdata, time_win, keep_mask=km) - for km in student_keep_masks_t - ] - # add mask_state inside each tuple - student_tokens = [(cells, lens, mstate) for (cells, lens, mstate) in student_tokens] - - self._last_view_metadata = view_meta - return (t_cells, t_lens, t_mask_state), student_tokens, view_meta - def sample_tensors_uniform_vectorized( self, tensor_list: list, lengths: list, max_total_points: int ): From 647e4b23a38123a046e9537db2aecd2e12d74ca5 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Thu, 20 Nov 2025 18:31:45 +0000 Subject: [PATCH 088/344] multiple idxs for each teacher, need to confirm for not student case, and updated ModelBatch for this --- src/weathergen/datasets/batch.py | 10 +++++++--- .../datasets/multi_stream_data_sampler.py | 18 ++++++------------ 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index 550a19598..ae69dcd5f 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -107,7 +107,8 @@ def __init__(self, streams, num_source_samples: int, num_target_samples: int) -> self.target_samples = [Sample(streams) for _ in range(num_target_samples)] self.source_target_matching_idxs = np.full(num_source_samples, -1, dtype=np.int32) - self.target_source_matching_idxs = np.full(num_target_samples, -1, dtype=np.int32) + # self.target_source_matching_idxs = np.full(num_target_samples, -1, dtype=np.int32) + self.target_source_matching_idxs = [[] for _ in range(num_target_samples)] def add_source_stream( self, @@ -127,7 +128,7 @@ def add_source_stream( def add_target_stream( self, target_sample_idx: int, - source_sample_idx: int, + source_sample_idx: int | list[int], stream_name: str, stream_data: StreamData, ) -> None: @@ -136,7 +137,10 @@ def add_target_stream( """ self.target_samples[target_sample_idx].add_stream_data(stream_name, stream_data) - assert source_sample_idx < len(self.source_samples), "invalid value for source_sample_idx" + if isinstance(source_sample_idx, int): + assert source_sample_idx < len(self.source_samples), "invalid value for source_sample_idx" + else: + assert all(idx < len(self.source_samples) for idx in source_sample_idx), "invalid value for source_sample_idx" self.target_source_matching_idxs[target_sample_idx] = source_sample_idx def len_sources(self) -> int: diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 8d6d818d5..f002697ea 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -571,19 +571,13 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): mask, ) stream_data_source[name] = sdata - # TODO: check this is correct + # TODO: seb check this # Map each student (source) to its teacher (target) - t_idx = int(student_to_teacher[sidx]) + t_idx = student_to_teacher[sidx] batch.add_source_stream(sidx, t_idx, name, sdata) # stream_data_target can contain network input stream_data_target = {} - # Not sure if this is the neatest approach... - # choose a student per teacher for reverse mapping - # pick the first student index that maps to this teacher - rep_source_for_teacher = {} - for s_idx, t_idx in enumerate(student_to_teacher): - rep_source_for_teacher.setdefault(int(t_idx), s_idx) for t_idx, mask in enumerate(target_masks): # stream_data_target[name] = self._build_stream_data( @@ -599,10 +593,10 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): mask, ) stream_data_target[name] = sdata - # TODO: check target sample here - # Map target to one representative source (not 1:N here) - rep_source_idx = int(rep_source_for_teacher.get(t_idx, 0)) - batch.add_target_stream(t_idx, rep_source_idx, name, sdata) + # TODO: seb to check + # Map target to all source students + student_indices = [s_idx for s_idx, tid in enumerate(student_to_teacher) if tid == t_idx] + batch.add_target_stream(t_idx, student_indices, name, sdata) # TODO: build batch # source_input From 91c3d7a8e1d91be477701d13823e3811d6fda956 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Fri, 21 Nov 2025 12:53:31 +0000 Subject: [PATCH 089/344] add max_num_targets to era5 --- config/streams/era5_1deg/era5.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/streams/era5_1deg/era5.yml b/config/streams/era5_1deg/era5.yml index eace84bfe..effc76111 100644 --- a/config/streams/era5_1deg/era5.yml +++ b/config/streams/era5_1deg/era5.yml @@ -18,7 +18,7 @@ ERA5 : masking_rate_none : 0.05 token_size : 8 tokenize_spacetime : True - max_num_targets: -1 + max_num_targets: -1 embed : net : transformer num_tokens : 1 From 1a418bfb01192c5fa279f2560eb9928efc0050f0 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Fri, 21 Nov 2025 12:54:33 +0000 Subject: [PATCH 090/344] add max_num_samples functionality to tokenizer_masking and pass through in multi_stream_data_sampler. coords_per_cell is a bit nasty --- .../datasets/multi_stream_data_sampler.py | 6 +- src/weathergen/datasets/tokenizer_masking.py | 70 +++++++++++++++++-- 2 files changed, 68 insertions(+), 8 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index f002697ea..68d8346c2 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -383,9 +383,11 @@ def _build_stream_data_output( token_data = output_tokens[step] stream_data.target_is_spoof = rdata.is_spoof + # None, or returned by get_target_coords + target_selection = None if "target_coords" in mode: - (tc, tc_l) = self.tokenizer.get_target_coords( + (tc, tc_l, target_selection) = self.tokenizer.get_target_coords( stream_info, self.sampling_rate_target, rdata, @@ -403,6 +405,7 @@ def _build_stream_data_output( token_data, (time_win_target.start, time_win_target.end), mask_state, + target_selection, ) stream_data.add_target_values(fstep, tt_cells, tt_c, tt_t, idxs_inv) @@ -596,6 +599,7 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # TODO: seb to check # Map target to all source students student_indices = [s_idx for s_idx, tid in enumerate(student_to_teacher) if tid == t_idx] + # print("Student indices", student_indices) batch.add_target_stream(t_idx, student_indices, name, sdata) # TODO: build batch diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 43d1785a1..297604451 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -8,6 +8,7 @@ # nor does it submit to any jurisdiction. +import numpy as np import torch from weathergen.common.io import IOReaderData @@ -211,10 +212,33 @@ def get_target_coords( encode_times_target, ) - # TODO, TODO, TODO: max_num_targets - # max_num_targets = stream_info.get("max_num_targets", -1) - - return (coords_local, coords_per_cell) + selection = self._select_target_subset(stream_info, coords_local.shape[0]) + + if selection is not None and coords_local.numel() > 0: + # use nice index_select method + coords_local = coords_local.index_select(0, selection.to(coords_local.device)) + + # coords_per_cell is trickier + if selection is not None and coords_per_cell.numel() > 0: + total_points = int(coords_per_cell.sum().item()) + if total_points == 0: + coords_per_cell = torch.zeros_like(coords_per_cell) + else: + cell_ids = torch.repeat_interleave( + torch.arange(coords_per_cell.shape[0], dtype=torch.long), + coords_per_cell.to(torch.long), + ) + if cell_ids.numel() == 0: + coords_per_cell = torch.zeros_like(coords_per_cell) + else: + new_counts = torch.bincount( + cell_ids[selection.to(cell_ids.device)], + minlength=coords_per_cell.shape[0], + ) + coords_per_cell = new_counts.to(dtype=coords_per_cell.dtype) + + # pass the selection back for use in get_target_values + return (coords_local, coords_per_cell, selection) def get_target_values( self, @@ -224,6 +248,7 @@ def get_target_values( token_data, time_win: tuple, mask_state: dict | None = None, + selection: torch.Tensor | None = None, ): # create tokenization index (idxs_cells, idxs_cells_lens) = token_data @@ -252,13 +277,44 @@ def get_target_values( encode_times_target, ) - # TODO, TODO, TODO: max_num_targets - # max_num_targets = stream_info.get("max_num_targets", -1) + if selection is None: + selection = self._select_target_subset(stream_info, data.shape[0]) + + if selection is not None and data.numel() > 0: + device_sel = selection.to(data.device) + data = data.index_select(0, device_sel) + coords = coords.index_select(0, device_sel) + if idxs_ord_inv.numel() > 0: + idxs_ord_inv = idxs_ord_inv.index_select(0, device_sel) + + # datetimes is numpy here + np_sel = selection.cpu().numpy() + datetimes = datetimes[np_sel] - # TODO: shuffeling + # TODO: shuffling + # selection not passed on, we call get_target_coords first return (data, datetimes, coords, idxs_ord_inv) + def _select_target_subset( + self, + stream_info: dict, + num_points: int, + ) -> torch.Tensor | None: + max_num_targets = stream_info.get("max_num_targets", -1) + + if max_num_targets is None or max_num_targets <= 0 or num_points <= max_num_targets: + return None + + rng = getattr(self, "rng", None) + if rng is None: + rng = np.random.default_rng() + self.rng = rng + + selected = np.sort(rng.choice(num_points, max_num_targets, replace=False)) + + return torch.from_numpy(selected).to(torch.long) + def sample_tensors_uniform_vectorized( self, tensor_list: list, lengths: list, max_total_points: int ): From 4df1788218565232fb8e9cd7ff12555907aa87bc Mon Sep 17 00:00:00 2001 From: Julian Kuehnert Date: Fri, 21 Nov 2025 14:01:36 +0100 Subject: [PATCH 091/344] Latent diffusion loss (#1322) * first commit for latent diffusion loss * first run with latent diffusion loss * rename latent diffusion module * rename latent diffusion module * rename latent diff module in init --- config/default_config.yml | 8 +- src/weathergen/model/model.py | 4 +- src/weathergen/train/loss_modules/__init__.py | 4 +- .../train/loss_modules/loss_functions.py | 2 +- .../train/loss_modules/loss_module_latent.py | 112 ---------------- .../loss_module_latent_diffusion.py | 121 ++++++++++++++++++ src/weathergen/train/trainer.py | 7 +- src/weathergen/utils/train_logger.py | 10 +- 8 files changed, 139 insertions(+), 129 deletions(-) delete mode 100644 src/weathergen/train/loss_modules/loss_module_latent.py create mode 100644 src/weathergen/train/loss_modules/loss_module_latent_diffusion.py diff --git a/config/default_config.yml b/config/default_config.yml index fb115dde5..532ad2075 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -42,12 +42,12 @@ pred_mlp_adaln: True # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -forecast_offset : 0 +forecast_offset : 1 forecast_delta_hrs: 0 -forecast_steps: 0 -forecast_policy: null +forecast_steps: 1 +forecast_policy: "fixed" forecast_att_dense_rate: 1.0 -fe_num_blocks: 0 +fe_num_blocks: 8 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 6fcf3ff88..7d19aa8b3 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -584,6 +584,8 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca # roll-out in latent space preds_all = [] + latents = {} + latents["preds"] = [] for fstep in range(forecast_offset, forecast_offset + forecast_steps): # prediction preds_all += [ @@ -603,6 +605,7 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca tokens = tokens + torch.randn_like(tokens) * torch.norm(tokens) * noise_std tokens = self.forecast(model_params, tokens, fstep) + latents["preds"] += [tokens] # prediction for final step preds_all += [ @@ -615,7 +618,6 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca ) ] - latents = {} latents["posteriors"] = posteriors return ModelOutput(physical=preds_all, latent=latents) diff --git a/src/weathergen/train/loss_modules/__init__.py b/src/weathergen/train/loss_modules/__init__.py index 43be4dfe1..9df30d23a 100644 --- a/src/weathergen/train/loss_modules/__init__.py +++ b/src/weathergen/train/loss_modules/__init__.py @@ -1,5 +1,5 @@ -from .loss_module_latent import LossLatent +from .loss_module_latent_diffusion import LossLatentDiffusion from .loss_module_physical import LossPhysical, LossPhysicalTwo from .loss_module_ssl import LossStudentTeacher -__all__ = [LossLatent, LossPhysical, LossPhysicalTwo, LossStudentTeacher] +__all__ = [LossLatentDiffusion, LossPhysical, LossPhysicalTwo, LossStudentTeacher] diff --git a/src/weathergen/train/loss_modules/loss_functions.py b/src/weathergen/train/loss_modules/loss_functions.py index e09928dd1..4cbcaf40f 100644 --- a/src/weathergen/train/loss_modules/loss_functions.py +++ b/src/weathergen/train/loss_modules/loss_functions.py @@ -61,7 +61,7 @@ def stats_normalized_erf(target, ens, mu, stddev): return torch.mean(d * d) # + torch.mean( torch.sqrt( stddev) ) -def mse(target, ens, mu, *kwargs): +def mse(target, mu, *kwargs): return torch.nn.functional.mse_loss(target, mu) diff --git a/src/weathergen/train/loss_modules/loss_module_latent.py b/src/weathergen/train/loss_modules/loss_module_latent.py deleted file mode 100644 index c680382ee..000000000 --- a/src/weathergen/train/loss_modules/loss_module_latent.py +++ /dev/null @@ -1,112 +0,0 @@ -# ruff: noqa: T201 - -# (C) Copyright 2025 WeatherGenerator contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging - -import torch -from omegaconf import DictConfig -from torch import Tensor - -import weathergen.train.loss_modules.loss_functions as losses -from weathergen.train.loss_modules.loss_module_base import LossModuleBase, LossValues -from weathergen.utils.train_logger import Stage - -_logger = logging.getLogger(__name__) - - -class LossLatent(LossModuleBase): - """ - Calculates loss in latent space. - """ - - def __init__( - self, - cf: DictConfig, - loss_fcts: list, - stage: Stage, - device: str, - ): - LossModuleBase.__init__(self) - self.cf = cf - self.stage = stage - self.device = device - self.name = "LossLatent" - - # Dynamically load loss functions based on configuration and stage - self.loss_fcts = [ - [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] - for name, w in loss_fcts - ] - - def _loss_per_loss_function( - self, - loss_fct, - target: torch.Tensor, - pred: torch.Tensor, - ): - """ - Compute loss for given loss function - """ - - loss_val = loss_fct(target=target, ens=None, mu=pred) - - return loss_val - - def compute_loss( - self, - preds: list[list[Tensor]], - targets: list[list[any]], - ) -> LossValues: - return super().compute_loss(preds, targets) - - ### FROM KEREM's PR - # losses_all: Tensor = torch.zeros( - # len(self.loss_fcts), - # device=self.device, - # ) - - # loss_fsteps_lat = torch.tensor(0.0, device=self.device, requires_grad=True) - # ctr_fsteps_lat = 0 - # # TODO: KCT, do we need the below per fstep? - # for fstep in range( - # 1, len(preds) - # ): # the first entry in tokens_all is the source itself, so skip it - # loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) - # ctr_loss_fcts = 0 - # # if forecast_offset==0, then the timepoints correspond. - # # Otherwise targets don't encode the source timestep, so we don't need to skip - # fstep_targs = fstep if self.cf.forecast_offset == 0 else fstep - 1 - # for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts_lat): - # loss_lfct = self._loss_per_loss_function( - # loss_fct, - # stream_info=None, - # target=targets[fstep_targs], - # pred=preds[fstep], - # ) - - # losses_all[i_lfct] += loss_lfct # TODO: break into fsteps - - # # Add the weighted and normalized loss from this loss function to the total - # # batch loss - # loss_fstep = loss_fstep + (loss_fct_weight * loss_lfct) - # ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0 - - # loss_fsteps_lat = loss_fsteps_lat + ( - # loss_fstep / ctr_loss_fcts if ctr_loss_fcts > 0 else 0 - # ) - # ctr_fsteps_lat += 1 if ctr_loss_fcts > 0 else 0 - - # loss = loss_fsteps_lat / (ctr_fsteps_lat if ctr_fsteps_lat > 0 else 1.0) - - # losses_all /= ctr_fsteps_lat if ctr_fsteps_lat > 0 else 1.0 - # losses_all[losses_all == 0.0] = torch.nan - - # return LossValues(loss=loss, losses_all=losses_all) diff --git a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py new file mode 100644 index 000000000..498594968 --- /dev/null +++ b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py @@ -0,0 +1,121 @@ +# ruff: noqa: T201 + +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import torch +from omegaconf import DictConfig +from torch import Tensor + +import weathergen.train.loss_modules.loss_functions as losses +from weathergen.train.loss_modules.loss_module_base import LossModuleBase, LossValues +from weathergen.utils.train_logger import Stage + +_logger = logging.getLogger(__name__) + + +class LossLatentDiffusion(LossModuleBase): + """ + Calculates loss in latent space. + """ + + def __init__( + self, + cf: DictConfig, + loss_fcts: list, + stage: Stage, + device: str, + ): + LossModuleBase.__init__(self) + self.cf = cf + self.stage = stage + self.device = device + self.name = "LossLatentDiff" + + # Dynamically load loss functions based on configuration and stage + self.loss_fcts = [[getattr(losses, name), w, name] for name, w in loss_fcts] + + def _get_fstep_weights(self, forecast_steps): + timestep_weight_config = self.cf.get("timestep_weight") + if timestep_weight_config is None: + return [1.0 for _ in range(forecast_steps)] + weights_timestep_fct = getattr(losses, timestep_weight_config[0]) + return weights_timestep_fct(forecast_steps, timestep_weight_config[1]) + + def _loss_per_loss_function( + self, + loss_fct, + target: torch.Tensor, + pred: torch.Tensor, + ): + """ + Compute loss for given loss function + """ + + loss_val = loss_fct(target=target, mu=pred) + + return loss_val + + def compute_loss( + self, + preds: dict, + targets: dict, + ) -> LossValues: + losses_all: dict[str, Tensor] = { + f"{self.name}.{loss_fct_name}": torch.zeros( + 1, + device=self.device, + ) + for _, _, loss_fct_name in self.loss_fcts + } + + preds = preds.latent["preds"] + targets = targets["targets"] + fsteps = len(targets) + + fstep_loss_weights = self._get_fstep_weights(fsteps) + + loss_fsteps = torch.tensor(0.0, device=self.device, requires_grad=True) + ctr_fsteps = 0 + for target, pred, fstep_loss_weight in zip(targets, preds, fstep_loss_weights, strict=True): + # the first entry in tokens_all is the source itself, so skip it + loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) + ctr_loss_fcts = 0 + # if forecast_offset==0, then the timepoints correspond. + # Otherwise targets don't encode the source timestep, so we don't need to skip + for loss_fct, loss_fct_weight, loss_fct_name in self.loss_fcts: + loss_lfct = self._loss_per_loss_function( + loss_fct, + target=target, + pred=pred, + ) + + losses_all[f"{self.name}.{loss_fct_name}"] += loss_lfct # TODO: break into fsteps + + # Add the weighted and normalized loss from this loss function to the total + # batch loss + loss_fstep = loss_fstep + (loss_fct_weight * loss_lfct) + ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0 + + loss_fsteps = loss_fsteps + ( + loss_fstep * fstep_loss_weight / ctr_loss_fcts if ctr_loss_fcts > 0 else 0 + ) + ctr_fsteps += 1 if ctr_loss_fcts > 0 else 0 + + loss = loss_fsteps / (ctr_fsteps if ctr_fsteps > 0 else 1.0) + + for _, loss_values in losses_all.items(): + loss_values /= ctr_fsteps if ctr_fsteps > 0 else 1.0 + loss_values[loss_values == 0.0] = torch.nan + + return LossValues( + loss=loss, losses_all=losses_all, stddev_all={"latent": torch.tensor(torch.nan)} + ) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 94fc53fef..c2eb52bfe 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -621,11 +621,8 @@ def train(self, mini_epoch): targets, aux_outputs = self.target_and_aux_calculator.compute( bidx, batch, self.model_params, self.model, cf.forecast_offset, forecast_steps ) - - loss, loss_values = self.loss_calculator.compute_loss( - preds=output, - targets=targets, - ) + targets = {"targets": [targets], "aux_outputs": aux_outputs} + loss, loss_values = self.loss_calculator.compute_loss(preds=output, targets=targets) if cf.latent_noise_kl_weight > 0.0: kl = torch.cat([posterior.kl() for posterior in output.latent]) loss += cf.latent_noise_kl_weight * kl.mean() diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index 8e004e316..5f34e4e11 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -122,8 +122,9 @@ def add_train( st = self.cf.streams[0] for loss_name, loss_values in losses_all.items(): metrics[f"loss.{loss_name}.loss_avg"] = loss_values[:, :].nanmean().item() - for k, ch_n in enumerate(st.train_target_channels): - metrics[f"loss.{loss_name}.{ch_n}"] = loss_values[:, k].nanmean().item() + if "Physical" in loss_name: + for k, ch_n in enumerate(st.train_target_channels): + metrics[f"loss.{loss_name}.{ch_n}"] = loss_values[:, k].nanmean().item() log_vals += [loss_values[:, :].nanmean().item()] for loss_name, stddev_values in stddev_all.items(): metrics[f"loss.{loss_name}.stddev_avg"] = stddev_values.nanmean().item() @@ -157,8 +158,9 @@ def add_val( st = self.cf.streams[0] for loss_name, loss_values in losses_all.items(): metrics[f"loss.{loss_name}.loss_avg"] = loss_values[:, :].nanmean().item() - for k, ch_n in enumerate(st.train_target_channels): - metrics[f"loss.{loss_name}.{ch_n}"] = loss_values[:, k].nanmean().item() + if "Physical" in loss_name: + for k, ch_n in enumerate(st.train_target_channels): + metrics[f"loss.{loss_name}.{ch_n}"] = loss_values[:, k].nanmean().item() log_vals += [loss_values[:, :].nanmean().item()] for loss_name, stddev_values in stddev_all.items(): metrics[f"loss.{loss_name}.stddev_avg"] = stddev_values.nanmean().item() From 63b2b632bb77c0c02cb38df66035d0e3cafe43a3 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Fri, 21 Nov 2025 13:06:48 +0000 Subject: [PATCH 092/344] Build latent diffusion forecast engine --- src/weathergen/model/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 7d19aa8b3..df80147e4 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -24,6 +24,7 @@ from torch.utils.checkpoint import checkpoint from weathergen.common.config import Config +from weathergen.model.diffusion import DiffusionForecastEngine from weathergen.model.engines import ( EmbeddingEngine, EnsPredictionHead, @@ -333,6 +334,7 @@ def create(self) -> "Model": ) self.forecast_engine = ForecastingEngine(cf, self.num_healpix_cells) + self.forecast_engine = DiffusionForecastEngine(forecast_engine=self.forecast_engine) ############### # embed coordinates yielding one query token for each target token From dbffbea3a67c7883ab27abff2d2fae02c9ae4caa Mon Sep 17 00:00:00 2001 From: Jubeku Date: Fri, 21 Nov 2025 14:34:06 +0100 Subject: [PATCH 093/344] fix training dataflow with diffusion FE --- config/default_config.yml | 1 + src/weathergen/model/diffusion.py | 24 +++++++++++++++++------- src/weathergen/model/model.py | 3 +++ 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 532ad2075..b30a80bea 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -51,6 +51,7 @@ fe_num_blocks: 8 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True +fe_diffusion_model: True impute_latent_noise_std: 0.0 # 1e-4 healpix_level: 5 diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index a7102a0d9..eb7ae6076 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -14,13 +14,15 @@ # Original Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # ---------------------------------------------------------------------------- + +import dataclasses + import torch -from dataclass import dataclass from weathergen.model.engines import ForecastingEngine -@dataclass +@dataclasses.dataclass class BatchData: """ Mock function for the data that will be provided to the diffusion model. Will change. @@ -70,7 +72,7 @@ def __init__( self.p_mean = p_mean self.p_std = p_std - def forward(self, data: BatchData) -> torch.Tensor: + def forward(self, tokens: torch.Tensor, fstep: int) -> torch.Tensor: """ Model forward call during training. Unpacks the conditioning c = [x_{t-k}, ..., x_{t}], the target y = x_{t+1}, and the random noise eta from the data, computes the diffusion noise @@ -79,9 +81,13 @@ def forward(self, data: BatchData) -> torch.Tensor: """ # Retrieve conditionings [0:-1], target [-1], and noise from data object. # TOOD: The data retrieval ignores batch and stream dimension for now (has to be adapted). - c = [data.get_input_data(t) for t in range(data.get_sample_len() - 1)] - y = data.get_input_data(-1) - eta = data.get_input_metadata(-1) + # c = [data.get_input_data(t) for t in range(data.get_sample_len() - 1)] + # y = data.get_input_data(-1) + # eta = data.get_input_metadata(-1) + + c = 1 + y = tokens + eta = torch.randn(1).to(device=tokens.device) # Compute sigma (noise level) from eta # noise = torch.randn(y.shape, device=y.device) # now eta from MultiStreamDataSampler @@ -102,11 +108,15 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float) -> torch.Tenso # Compute scaling conditionings c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() - c_in = 1 / (sigma**2 + self.sigma_data**2).sqrt + c_in = 1 / (sigma**2 + self.sigma_data**2).sqrt() c_noise = sigma.log() / 4 # Precondition input and feed through network x = self.preconditioner.precondition(x, c) + # return c_skip * x + c_out * self.net(c_in * x, c_noise) # Eq. (7) in EDM paper + + fstep = 0 + aux_info = torch.tensor([fstep], dtype=torch.float32, device="cuda") return c_skip * x + c_out * self.net(c_in * x, c_noise) # Eq. (7) in EDM paper def inference( diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 7d19aa8b3..93981afb4 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -24,6 +24,7 @@ from torch.utils.checkpoint import checkpoint from weathergen.common.config import Config +from weathergen.model.diffusion import DiffusionForecastEngine from weathergen.model.engines import ( EmbeddingEngine, EnsPredictionHead, @@ -333,6 +334,8 @@ def create(self) -> "Model": ) self.forecast_engine = ForecastingEngine(cf, self.num_healpix_cells) + if cf.fe_diffusion_model: + self.forecast_engine = DiffusionForecastEngine(forecast_engine=self.forecast_engine) ############### # embed coordinates yielding one query token for each target token From f8c9369368d6447772fabb3f071db3ff0878e17e Mon Sep 17 00:00:00 2001 From: Jubeku Date: Fri, 21 Nov 2025 14:41:12 +0100 Subject: [PATCH 094/344] update validation loop --- src/weathergen/model/diffusion.py | 4 ---- src/weathergen/train/trainer.py | 11 +++++++++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index eb7ae6076..9a24aa07a 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -113,10 +113,6 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float) -> torch.Tenso # Precondition input and feed through network x = self.preconditioner.precondition(x, c) - # return c_skip * x + c_out * self.net(c_in * x, c_noise) # Eq. (7) in EDM paper - - fstep = 0 - aux_info = torch.tensor([fstep], dtype=torch.float32, device="cuda") return c_skip * x + c_out * self.net(c_in * x, c_noise) # Eq. (7) in EDM paper def inference( diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index c2eb52bfe..26b3eac5f 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -739,8 +739,15 @@ def validate(self, mini_epoch): output = model_forward( self.model_params, batch, cf.forecast_offset, forecast_steps ) - - targets = {"physical": batch[0]} + targets, aux_outputs = self.target_and_aux_calculator.compute( + bidx, + batch, + self.model_params, + self.model, + cf.forecast_offset, + forecast_steps, + ) + targets = {"targets": [targets], "aux_outputs": aux_outputs} # compute loss loss, loss_values = self.loss_calculator_val.compute_loss( From ece1dd0781b8c4ed34222ab4140033b0c9925a40 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Fri, 21 Nov 2025 16:22:27 +0000 Subject: [PATCH 095/344] move build_views_for_stream into masker --- src/weathergen/datasets/masking.py | 71 +++++++++++++- .../datasets/multi_stream_data_sampler.py | 4 +- src/weathergen/datasets/tokenizer_masking.py | 1 - src/weathergen/datasets/view_builder.py | 97 ------------------- 4 files changed, 70 insertions(+), 103 deletions(-) delete mode 100644 src/weathergen/datasets/view_builder.py diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index f6c4f60fc..98b99a177 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -1,9 +1,11 @@ import logging +from typing import List, Tuple import numpy as np import torch from weathergen.common.config import Config +from weathergen.datasets.batch import ViewMetadata _logger = logging.getLogger(__name__) @@ -142,7 +144,7 @@ def mask_source_idxs( idxs_cells_lens, rdata, keep_mask: np.typing.NDArray | None = None, - ) -> (torch.Tensor, torch.Tensor): + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: """ Return: @@ -216,7 +218,7 @@ def mask_targets_idxs( idxs_cells, idxs_cells_lens, rdata, - ) -> (torch.Tensor, torch.Tensor): + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: # mask_source_idxs is assert (self.mask_tokens is not None) or (self.mask_tokens is not None) idxs_ord_inv = torch.tensor([], dtype=torch.int64) @@ -626,6 +628,71 @@ def _generate_causal_mask( return full_mask + def build_views_for_stream( + self, + num_cells: int, + teacher_cfg: dict, + student_cfg: dict, + relationship: str = "subset", + ) -> Tuple[np.ndarray, List[np.ndarray], List[ViewMetadata]]: + """Construct teacher/student keep masks for a stream.""" + + strat_teacher = teacher_cfg.get("strategy", "random") + rate_teacher = teacher_cfg.get("rate") + t_cfg_extra = teacher_cfg.get("masking_strategy_config") + + teacher_keep_mask = self.generate_cell_keep_mask( + num_cells=num_cells, + strategy=strat_teacher, + rate=rate_teacher, + masking_strategy_config=t_cfg_extra, + ) + + num_views = student_cfg.get("num_views", 1) + strat_student = student_cfg.get("masking_strategy", student_cfg.get("strategy", "random")) + rate_student = student_cfg.get("rate") + s_cfg_extra = student_cfg.get("masking_strategy_config") + + student_keep_masks: List[np.ndarray] = [] + for _ in range(num_views): + base = self.generate_cell_keep_mask( + num_cells=num_cells, + strategy=strat_student, + rate=rate_student, + masking_strategy_config=s_cfg_extra, + ) + if relationship == "subset": + keep = base & teacher_keep_mask + elif relationship == "disjoint": + keep = base & (~teacher_keep_mask) + else: + keep = base + student_keep_masks.append(keep) + + metadata: List[ViewMetadata] = [ + ViewMetadata( + view_id="teacher_global", + keep_mask=teacher_keep_mask, + strategy=strat_teacher, + healpix_level=self.healpix_level_data, + rate=rate_teacher, + parent_view_id=None, + ) + ] + for idx, mask in enumerate(student_keep_masks): + metadata.append( + ViewMetadata( + view_id=f"student_local_{idx}", + keep_mask=mask, + strategy=strat_student, + healpix_level=self.healpix_level_data, + rate=rate_student, + parent_view_id="teacher_global", + ) + ) + + return teacher_keep_mask, student_keep_masks, metadata + # --------------------------------------------------------------------- # Cell-level keep mask generation (teacher/student view selection) # --------------------------------------------------------------------- diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 68d8346c2..3e74c3049 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -33,7 +33,6 @@ compute_offsets_scatter_embed, compute_source_cell_lens, ) -from weathergen.datasets.view_builder import build_views_for_stream from weathergen.utils.distributed import is_root from weathergen.utils.train_logger import Stage @@ -646,8 +645,7 @@ def to_bool_tensor(arr): # add a loop over num_teacher_views, generate students for each teacher for t_idx in range(num_teacher_views): # Build one teacher and its student views - t_keep_np, s_keeps_np, _meta = build_views_for_stream( - self.tokenizer.masker, + t_keep_np, s_keeps_np, _meta = self.tokenizer.masker.build_views_for_stream( self.num_healpix_cells, teacher_cfg=teacher_cfg, student_cfg=student_cfg, diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 297604451..fc177bdf8 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -22,7 +22,6 @@ tokenize_space, tokenize_spacetime, ) -from weathergen.datasets.view_builder import build_views_for_stream def readerdata_to_torch(rdata: IOReaderData) -> IOReaderData: diff --git a/src/weathergen/datasets/view_builder.py b/src/weathergen/datasets/view_builder.py deleted file mode 100644 index 21306826c..000000000 --- a/src/weathergen/datasets/view_builder.py +++ /dev/null @@ -1,97 +0,0 @@ -import numpy as np -from typing import Tuple, List -from weathergen.datasets.masking import Masker -from weathergen.datasets.batch import ViewMetadata - - -def build_views_for_stream( - masker: Masker, - num_cells: int, - teacher_cfg: dict, - student_cfg: dict, - relationship: str = "subset", -) -> Tuple[np.ndarray, List[np.ndarray], List[ViewMetadata]]: - """ - - Per-stream view construction: teacher + N student keep masks. - - Parameters - ---------- - masker : Masker - Instance providing RNG and healpix-level info. - num_cells : int - Number of healpix cells at data level. - teacher_cfg : dict - Config: {strategy, rate|keep_m, hl_mask, masking_strategy_config, rate_sampling}. - student_cfg : dict - Config: {masking_strategy, rate, num_views, hl_mask, masking_strategy_config, rate_sampling}. - relationship : str - One of {'subset','disjoint','independent'}. Determines derivation of student masks. - - Returns - ------- - teacher_keep_mask : np.ndarray - Boolean keep mask for teacher view. - student_keep_masks : list[np.ndarray] - Boolean keep masks for each student view. - metadata : list[ViewMetadata] - Metadata objects (teacher first, then students). - - """ - strat_teacher = teacher_cfg.get("strategy", "random") - rate_teacher = teacher_cfg.get("rate") - t_cfg_extra = teacher_cfg.get("masking_strategy_config") - - teacher_keep_mask = masker.generate_cell_keep_mask( - num_cells=num_cells, - strategy=strat_teacher, - rate=rate_teacher, - masking_strategy_config=t_cfg_extra, - ) - - # Student base masks - num_views = student_cfg.get("num_views", 1) - strat_student = student_cfg.get("masking_strategy", student_cfg.get("strategy", "random")) - rate_student = student_cfg.get("rate") - s_cfg_extra = student_cfg.get("masking_strategy_config") - - student_keep_masks: List[np.ndarray] = [] - for v in range(num_views): - base = masker.generate_cell_keep_mask( - num_cells=num_cells, - strategy=strat_student, - rate=rate_student, - masking_strategy_config=s_cfg_extra, - ) - if relationship == "subset": - keep = base & teacher_keep_mask - elif relationship == "disjoint": - keep = base & (~teacher_keep_mask) - else: # independent - keep = base - student_keep_masks.append(keep) - - metadata: List[ViewMetadata] = [] - metadata.append( - ViewMetadata( - view_id="teacher_global", - keep_mask=teacher_keep_mask, - strategy=strat_teacher, - healpix_level=masker.healpix_level_data, - rate=rate_teacher, - parent_view_id=None, - ) - ) - for i, m in enumerate(student_keep_masks): - metadata.append( - ViewMetadata( - view_id=f"student_local_{i}", - keep_mask=m, - strategy=strat_student, - healpix_level=masker.healpix_level_data, - rate=rate_student, - parent_view_id="teacher_global", - ) - ) - - return teacher_keep_mask, student_keep_masks, metadata From b9a60f3b8c36c6d578a94d2d12179cabce84f4f5 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Fri, 21 Nov 2025 18:38:40 +0000 Subject: [PATCH 096/344] tidy up, remove unused arguments, types --- src/weathergen/datasets/batch.py | 14 ++++----- src/weathergen/datasets/masking.py | 16 +++++----- .../datasets/multi_stream_data_sampler.py | 30 +++++-------------- src/weathergen/datasets/stream_data.py | 26 ++++++++-------- src/weathergen/datasets/tokenizer_masking.py | 10 +++---- 5 files changed, 41 insertions(+), 55 deletions(-) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index ae69dcd5f..c66b7874b 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -10,14 +10,15 @@ import numpy as np +from weathergen.common.config import Config from weathergen.datasets.stream_data import StreamData # TODO: Add a store for a random number for diffusion # TODO: GetTimestep to get the timestep -# TODO: GetData: get the streamdata # TODO: GetMetaData: then this gets the right rn for the timestep! +# NOTE: TO BE DECPRECATED @dataclass class ViewMetadata: """ @@ -29,7 +30,6 @@ class ViewMetadata: Attributes: view_id: Unique identifier (e.g., "teacher_global", "student_local_0") keep_mask: Boolean array [num_healpix_cells] at data level indicating kept cells - strategy: Name of selection strategy ("random", "healpix_level_2", etc.) healpix_level: HEALPix level for hierarchical selection (None if not applicable) rate: Fraction of data kept (e.g., 0.5 = 50% kept); None if fixed count parent_view_id: ID of the parent view this is a subset of (None for teacher) @@ -47,7 +47,7 @@ class ViewMetadata: # Optional extras for future/other training paradigms loss_type: str | None = None # e.g. DINO, JEPA - strategy_config: dict | None = None # e.g. {rate: 0.5, hl_mask: 3, overlap: "disjoint"} + strategy_config: Config | None = None # e.g. {rate: 0.5, hl_mask: 3, overlap: "disjoint"} class SampleMetaData: @@ -55,7 +55,7 @@ class SampleMetaData: masking_strategy: str # parameters for masking strategy - masking_params: dict + masking_params: Config class Sample: @@ -64,7 +64,7 @@ class Sample: # data for all streams # keys: stream_name, values: StreamData - streams_data: dict + streams_data: dict[str, StreamData | None] def __init__(self, streams: dict) -> None: # TODO: can we pass this right away? @@ -97,8 +97,8 @@ class ModelBatch: # index of corresponding target (for source samples) or source (for target samples) # these are in 1-to-1 corresponding for classical training modes (MTM, forecasting) but # can be more complex for strategies like student-teacher training - source_matching_idx: np.typing.NDArray[np.int32] - target_matching_idx: np.typing.NDArray[np.int32] + source_target_matching_idxs: np.typing.NDArray[np.int32] + target_source_matching_idxs: np.typing.NDArray[np.int32] def __init__(self, streams, num_source_samples: int, num_target_samples: int) -> None: """ """ diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index 98b99a177..0548ce388 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -139,12 +139,10 @@ def _select_strategy(self): def mask_source_idxs( self, - stream_info, idxs_cells, idxs_cells_lens, - rdata, keep_mask: np.typing.NDArray | None = None, - ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Return: @@ -214,11 +212,9 @@ def mask_source_idxs( def mask_targets_idxs( self, - stream_info, idxs_cells, idxs_cells_lens, - rdata, - ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # mask_source_idxs is assert (self.mask_tokens is not None) or (self.mask_tokens is not None) idxs_ord_inv = torch.tensor([], dtype=torch.int64) @@ -634,8 +630,12 @@ def build_views_for_stream( teacher_cfg: dict, student_cfg: dict, relationship: str = "subset", - ) -> Tuple[np.ndarray, List[np.ndarray], List[ViewMetadata]]: - """Construct teacher/student keep masks for a stream.""" + ) -> Tuple[np.typing.NDArray, List[np.typing.NDArray], List[ViewMetadata]]: + """ + Construct teacher/student keep masks for a stream. + ViewMetadata likely to be deprecated, + but information can be piped here for now. + """ strat_teacher = teacher_cfg.get("strategy", "random") rate_teacher = teacher_cfg.get("rate") diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 3e74c3049..ba89b9401 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -68,7 +68,7 @@ def collect_datasources(stream_datasets: list, idx: int, type: str) -> IOReaderD class MultiStreamDataSampler(torch.utils.data.IterableDataset): - ################################################### + def __init__( self, cf, @@ -220,14 +220,12 @@ def __init__( self.epoch = 0 - ################################################### def advance(self): """ Advance epoch (this is applied to the template for the worker processes) """ self.epoch += 1 - ################################################### def get_sources_size(self): return [ 0 @@ -239,15 +237,12 @@ def get_sources_size(self): for ds in self.streams_datasets ] - ################################################### def get_sources_num_channels(self): return [ds[0].get_source_num_channels() for ds in self.streams_datasets] - ################################################### def get_targets_num_channels(self): return [ds[0].get_target_num_channels() for ds in self.streams_datasets] - ################################################### def get_targets_coords_size(self): # TODO: avoid hard coding magic values # +6 at the end for stram_id and time encoding @@ -255,7 +250,6 @@ def get_targets_coords_size(self): (ds[0].get_geoinfo_size() + (5 * (3 * 5)) + 3 * 8) + 6 for ds in self.streams_datasets ] - ################################################### def reset(self): # initialize the random number generator: self.data_loader_rng_seed is set to a DDP-unique # value in worker_workset() @@ -369,7 +363,8 @@ def _build_stream_data_output( output_tokens: list, mask_state: dict | None = None, ) -> StreamData: - """ """ + """ + """ # collect for all forecast steps dt = self.forecast_offset + forecast_dt @@ -415,7 +410,6 @@ def _build_stream_data( mode: str, base_idx: TIndex, forecast_dt: int, - # view_meta: ViewMetadata, stream_info: dict, input_data: list, output_data: list, @@ -428,11 +422,10 @@ def _build_stream_data( Build a StreamData object for a single view (teacher or student). Args: - mode : {student, teacher, physical} + mode : stream_data : base_idx: Time index for this sample forecast_dt: Number of forecast steps - view_meta: ViewMetadata describing spatial mask stream_info: Stream configuration dict stream_ds: List of dataset readers for this stream @@ -531,7 +524,7 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): streams_data: list[StreamData] = [] # get/coordinate masks - masks_streams = self._get_source_target_masks(idx, forecast_dt) + masks_streams = self._get_source_target_masks() # Determine number of views direct from config (teacher & student views) teacher_cfg = self.training_cfg.get("teacher_model_input", {}) if self.training_cfg else {} @@ -612,14 +605,10 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): return streams_data - def _get_source_target_masks(self, idx: int, forecast_dt: int): + def _get_source_target_masks(self): """ - Return one batch of data - Build a StreamData object for a single view (teacher or student). - - Args: - idx: Time index for this sample - forecast_dt: Number of forecast steps + Generate source and target masks for all streams + according to the student-teacher configuration """ masks = {} @@ -681,7 +670,6 @@ def _preprocess_model_data(self, batch, forecast_dt): return batch, source_cell_lens, target_coords_idx - ################################################### def __iter__(self): """ Return one batch of data @@ -740,11 +728,9 @@ def __iter__(self): yield (batch, source_cell_lens, target_coords_idx, forecast_dt) - ################################################### def __len__(self): return self.len - ################################################### def worker_workset(self): local_start, local_end = self.rank * self.len, (self.rank + 1) * self.len diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index 0e9166370..19cf94b18 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -137,7 +137,7 @@ def add_empty_target(self, fstep: int) -> None: ] def add_source( - self, step: int, ss_raw: IOReaderData, ss_lens: torch.tensor, ss_cells: list + self, step: int, ss_raw: IOReaderData, ss_lens: torch.Tensor, ss_cells: list ) -> None: """ Add data for source for one input. @@ -145,9 +145,9 @@ def add_source( Parameters ---------- ss_raw : IOReaderData( dataclass containing coords, geoinfos, data, and datetimes ) - ss_lens : torch.tensor( number of healpix cells ) + ss_lens : torch.Tensor( number of healpix cells ) ss_cells : list( number of healpix cells ) - [ torch.tensor( tokens per cell, token size, number of channels) ] + [ torch.Tensor( tokens per cell, token size, number of channels) ] Returns ------- @@ -166,11 +166,11 @@ def add_target( self, fstep: int, targets: list, - target_coords: torch.tensor, - target_coords_per_cell: torch.tensor, - target_coords_raw: torch.tensor, - times_raw: torch.tensor, - idxs_inv: torch.tensor, + target_coords: torch.Tensor, + target_coords_per_cell: torch.Tensor, + target_coords_raw: torch.Tensor, + times_raw: torch.Tensor, + idxs_inv: torch.Tensor, ) -> None: """ Add data for target for one input. @@ -209,9 +209,9 @@ def add_target_values( self, fstep: int, targets: list, - target_coords_raw: torch.tensor, - times_raw: torch.tensor, - idxs_inv: torch.tensor, + target_coords_raw: torch.Tensor, + times_raw: torch.Tensor, + idxs_inv: torch.Tensor, ) -> None: """ Add data for target for one input. @@ -247,8 +247,8 @@ def add_target_values( def add_target_coords( self, fstep: int, - target_coords: torch.tensor, - target_coords_per_cell: torch.tensor, + target_coords: torch.Tensor, + target_coords_per_cell: torch.Tensor, ) -> None: """ Add data for target for one input. diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index fc177bdf8..72b314570 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -101,11 +101,11 @@ def get_source( if keep_mask is not None: keep_np = keep_mask.cpu().numpy().astype(bool) (mask_tokens, mask_channels) = self.masker.mask_source_idxs( - stream_info, idxs_cells, idxs_cells_lens, rdata, keep_mask=keep_np + idxs_cells, idxs_cells_lens, keep_mask=keep_np ) else: (mask_tokens, mask_channels) = self.masker.mask_source_idxs( - stream_info, idxs_cells, idxs_cells_lens, rdata + idxs_cells, idxs_cells_lens, ) source_tokens_cells, source_tokens_lens = tokenize_apply_mask_source( @@ -152,7 +152,7 @@ def get_target( self.masker.mask_channels = mask_state.get("mask_channels") (mask_tokens, mask_channels, idxs_ord_inv) = self.masker.mask_targets_idxs( - stream_info, idxs_cells, idxs_cells_lens, rdata + idxs_cells, idxs_cells_lens, ) data, datetimes, coords, coords_local, coords_per_cell = tokenize_apply_mask_target( @@ -193,7 +193,7 @@ def get_target_coords( self.masker.mask_channels = mask_state.get("mask_channels") (mask_tokens, mask_channels, idxs_ord_inv) = self.masker.mask_targets_idxs( - stream_info, idxs_cells, idxs_cells_lens, rdata + idxs_cells, idxs_cells_lens, ) # TODO: split up @@ -259,7 +259,7 @@ def get_target_values( self.masker.mask_channels = mask_state.get("mask_channels") (mask_tokens, mask_channels, idxs_ord_inv) = self.masker.mask_targets_idxs( - stream_info, idxs_cells, idxs_cells_lens, rdata + idxs_cells, idxs_cells_lens, ) data, datetimes, coords, _, _ = tokenize_apply_mask_target( From 2905cb0b77a40c501716989b381ce606642b7965 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Sat, 22 Nov 2025 13:59:37 +0000 Subject: [PATCH 097/344] fix masking for NPP-ATMS by correctly selecting final timestep mask and aligning between source and target. working for num_input_steps = 1, broken for > 1, compute_offsets_scatter_embed not working --- .../datasets/multi_stream_data_sampler.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index ba89b9401..21e323f15 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -326,6 +326,11 @@ def _build_stream_data_input( # source input data + # Fornow, keep only mask state of the final timestep + # (correspondsing to base_idx, first of the loop below) + # to ensure alignment with the target data for MTM/S-T. + final_mask_state = None + # iterate overall input steps for step, idx in enumerate(range(base_idx, base_idx - self.num_input_steps, -1)): # TODO: check that we are not out of bounds when we go back in time @@ -333,8 +338,9 @@ def _build_stream_data_input( time_win_source = self.time_window_handler.window(idx) # collect all targets for current stream - rdata = input_data[step] - token_data = input_tokens[step] + # do we want this to be ascending or descending in time? + rdata = input_data[-(step+1)] + token_data = input_tokens[-(step+1)] stream_data.source_is_spoof = rdata.is_spoof @@ -347,10 +353,13 @@ def _build_stream_data_input( keep_mask=mask, ) + if step == 0: + final_mask_state = mask_state + # collect data for stream stream_data.add_source(step, rdata, source_cells_lens, source_cells) - return stream_data, mask_state + return stream_data, final_mask_state def _build_stream_data_output( self, @@ -591,7 +600,6 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # TODO: seb to check # Map target to all source students student_indices = [s_idx for s_idx, tid in enumerate(student_to_teacher) if tid == t_idx] - # print("Student indices", student_indices) batch.add_target_stream(t_idx, student_indices, name, sdata) # TODO: build batch From b193a50f5aff2d22585d31f38ab2170f71787ba6 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Mon, 24 Nov 2025 17:13:37 +0100 Subject: [PATCH 098/344] updated configs so code runs. Note default config to be overhauled still --- config/default_config.yml | 5 ++++- config/streams/era5_nppatms_synop/era5.yml | 1 + config/streams/era5_nppatms_synop/npp_atms.yml | 1 + config/streams/era5_nppatms_synop/synop.yml | 1 + 4 files changed, 7 insertions(+), 1 deletion(-) diff --git a/config/default_config.yml b/config/default_config.yml index ce1ff3f5e..e467f3aa7 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -39,7 +39,7 @@ pred_mlp_adaln: True # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -forecast_offset : 1 +forecast_offset : 0 forecast_delta_hrs: 0 forecast_steps: 0 forecast_policy: null @@ -88,6 +88,9 @@ training_mode_config: {"losses": {LossPhysical: {weight: 1.0, loss_fcts: [['mse' } validation_mode_config: {"losses": {LossPhysical: {weight: 1.0, loss_fcts: [['mse', 1.0]]},} } + +# masking +masking_strategy: "random" # obviously TODO # masking rate when training mode is "masking"; ignored in foreacast mode masking_rate: 0.6 # diff --git a/config/streams/era5_nppatms_synop/era5.yml b/config/streams/era5_nppatms_synop/era5.yml index c51eb6e33..90d0b9790 100644 --- a/config/streams/era5_nppatms_synop/era5.yml +++ b/config/streams/era5_nppatms_synop/era5.yml @@ -9,6 +9,7 @@ ERA5 : type : anemoi + stream_id : 0 filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] loss_weight : 1. source_exclude : ['w_', 'skt', 'sp', 'tcw', 'cp', 'tp'] diff --git a/config/streams/era5_nppatms_synop/npp_atms.yml b/config/streams/era5_nppatms_synop/npp_atms.yml index 583c1b4b2..75302f443 100644 --- a/config/streams/era5_nppatms_synop/npp_atms.yml +++ b/config/streams/era5_nppatms_synop/npp_atms.yml @@ -9,6 +9,7 @@ NPPATMS : type : obs + stream_id : 1 filenames : ['observations-ea-ofb-0001-2012-2023-npp-atms-radiances-v2.zarr'] loss_weight : 1.0 token_size : 32 diff --git a/config/streams/era5_nppatms_synop/synop.yml b/config/streams/era5_nppatms_synop/synop.yml index 97a575019..ce9adfa44 100644 --- a/config/streams/era5_nppatms_synop/synop.yml +++ b/config/streams/era5_nppatms_synop/synop.yml @@ -5,6 +5,7 @@ SurfaceCombined : type : obs + stream_id : 2 filenames : ['observations-ea-ofb-0001-1979-2023-combined-surface-v2.zarr'] loss_weight : 1.0 masking_rate : 0.6 From fa24fc1d9adad699234162723b3a5aa57c6ce4f7 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Tue, 25 Nov 2025 16:36:52 +0100 Subject: [PATCH 099/344] very hacky first pass of full masking_strategy_config for the student and teacher views. Much to fix up --- config/default_config.yml | 20 +++++----------- src/weathergen/datasets/batch.py | 24 ++++++++++++++++++- .../datasets/multi_stream_data_sampler.py | 24 +++++++++++++++---- src/weathergen/train/trainer.py | 20 ++++++++++++++++ 4 files changed, 68 insertions(+), 20 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index e467f3aa7..6c54d1796 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -90,7 +90,7 @@ validation_mode_config: {"losses": {LossPhysical: {weight: 1.0, loss_fcts: [['ms } # masking -masking_strategy: "random" # obviously TODO +masking_strategy: "dog" # obviously TODO # masking rate when training mode is "masking"; ignored in foreacast mode masking_rate: 0.6 # @@ -117,26 +117,18 @@ training_config: model_input: masking_strategy: "healpix" # "random", "healpix". Masking strategy to use for model input for masking, and local (student) views when doing student-teacher - rate: 0.5 # Masking rate to use for model input + rate: 0.4 # Masking rate to use for model input num_views: 4 # if student-teacher, the number of local (student) views to generate - masking_strategy_config: {"strategies": ["random", "healpix", "channel"], # will be used with masking is moved under here - "probabilities": [0.34, 0.33, 0.33], - "hl_mask": 0, "mode": "per_cell", - "same_strategy_per_batch": false - } + hl_mask : 4 # healpix level to use for healpix masking strategy relationship: "subset" # "independent", "subset", "disjoint". Relationship of student views to teacher view. teacher_model_input: strategy: "healpix" # Strategy for teacher (global) view: "random", "healpix" - rate: 0.5 # Fraction of data to keep in global view (alternative: use "keep_m" for absolute count) + rate: 0.8 # Fraction of data to keep in global view (alternative: use "keep_m" for absolute count) num_views: 2 # number of teacher views to generate + hl_mask : 0 # healpix level to use for healpix masking strategy # keep_m: 100 # Alternative to rate: keep exactly this many parent cells rate_sampling: true # randomly sample the rate per batch - masking_strategy_config: {"strategies": ["random", "healpix", "channel"], - "probabilities": [0.34, 0.33, 0.33], - "hl_mask": 4, "mode": "per_cell", - "same_strategy_per_batch": false - } @@ -174,7 +166,7 @@ input_window_steps: 1 val_initial: False -loader_num_workers: 8 +loader_num_workers: 0 log_validation: 0 streams_output: ["ERA5"] diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index c66b7874b..e7047d55a 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -52,7 +52,7 @@ class ViewMetadata: class SampleMetaData: # masking strategy - masking_strategy: str + # masking_strategy: str # parameters for masking strategy masking_params: Config @@ -81,8 +81,21 @@ def add_stream_data(self, stream_name: str, stream_data: StreamData) -> None: assert self.streams_data.get(stream_name, -1) != -1, "stream name does not exist" self.streams_data[stream_name] = stream_data + def add_meta_info(self, stream_name: str, meta_info: SampleMetaData) -> None: + """ + Add metadata for stream @stream_name to sample + """ + self.meta_info[stream_name] = meta_info + # TODO: complete interface, e.g get_stream + def get_stream_data(self, stream_name: str) -> StreamData: + """ + Get data for stream @stream_name from sample + """ + assert self.streams_data.get(stream_name, -1) != -1, "stream name does not exist" + return self.streams_data[stream_name] + class ModelBatch: """ Container for all data and metadata for one training batch. @@ -116,12 +129,17 @@ def add_source_stream( target_sample_idx: int, stream_name: str, stream_data: StreamData, + source_meta_info: SampleMetaData, ) -> None: """ Add data for one stream to sample @source_sample_idx """ self.source_samples[source_sample_idx].add_stream_data(stream_name, stream_data) + # add the meta_info + self.source_samples[source_sample_idx].add_meta_info(stream_name, source_meta_info) + + assert target_sample_idx < len(self.target_samples), "invalid value for target_sample_idx" self.source_target_matching_idxs[source_sample_idx] = target_sample_idx @@ -131,12 +149,16 @@ def add_target_stream( source_sample_idx: int | list[int], stream_name: str, stream_data: StreamData, + target_meta_info: SampleMetaData, ) -> None: """ Add data for one stream to sample @target_sample_idx """ self.target_samples[target_sample_idx].add_stream_data(stream_name, stream_data) + # add the meta_info -- for target we have different + self.target_samples[target_sample_idx].add_meta_info(stream_name, target_meta_info) + if isinstance(source_sample_idx, int): assert source_sample_idx < len(self.source_samples), "invalid value for source_sample_idx" else: diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 809d76a86..da2dc84ca 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -578,10 +578,16 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): mask, ) stream_data_source[name] = sdata + + # source meta info... + # source_meta_info = SampleMetaData(... + + source_meta_info = student_cfg + # TODO: seb check this # Map each student (source) to its teacher (target) t_idx = student_to_teacher[sidx] - batch.add_source_stream(sidx, t_idx, name, sdata) + batch.add_source_stream(sidx, t_idx, name, sdata, source_meta_info) # stream_data_target can contain network input stream_data_target = {} @@ -600,10 +606,16 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): mask, ) stream_data_target[name] = sdata + + # get teacher config info + teacher_meta_info = teacher_cfg + # TODO: seb to check # Map target to all source students student_indices = [s_idx for s_idx, tid in enumerate(student_to_teacher) if tid == t_idx] - batch.add_target_stream(t_idx, student_indices, name, sdata) + batch.add_target_stream(t_idx, student_indices, name, sdata, teacher_meta_info) + + # import pdb; pdb.set_trace() # TODO: build batch # source_input @@ -614,7 +626,7 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # add data for current stream streams_data += [v for k, v in stream_data_source.items()] - return streams_data + return streams_data, batch def _get_source_target_masks(self): """ @@ -720,7 +732,7 @@ def __iter__(self): mode = "student_teacher" - streams_data = self._get_sample(mode, idx, forecast_dt) + streams_data, student_teacher_batch = self._get_sample(mode, idx, forecast_dt) # Reset masking strategy for next batch item if hasattr(self.tokenizer, "masker"): @@ -737,7 +749,9 @@ def __iter__(self): batch, forecast_dt ) - yield (batch, source_cell_lens, target_coords_idx, forecast_dt) + import pdb; pdb.set_trace() + + yield (batch, source_cell_lens, target_coords_idx, forecast_dt), student_teacher_batch def __len__(self): return self.len diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 21d0b91fb..a8f02a247 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -612,6 +612,26 @@ def train(self, mini_epoch): # training loop self.t_start = time.time() for bidx, batch in enumerate(dataset_iter): + + # make existing pipeline work: + batch = batch[0] + + ################################################################ + # SOPH: student teacher access path here: + # student_teacher_data = batch[1] + # access student views: + #all_student_views = student_teacher_data.source_samples + #student_sample_1 = student_teacher_data.source_samples[0] + #student_sample_1_stream_data = student_teacher_data.source_samples[0].streams_data # dict, {stream: stream data} of first student view + # e.g. target tokens of ERA5 stream of first student view: + # target_tokens_of_student_sample_1_ERA5_stream_data = student_teacher_batch.source_samples[0].streams_data["ERA5"].target_tokens + + # access metadata of the student views, this is currently shared, very hacky, to fix. + #metadata_student_view = student_teacher_batch.source_samples[0].meta_info + + # You will also need the source_cell_lens, target_coords_idx, these are not being passed through for the views yet. + ################################################################ + forecast_steps = batch[-1] batch = self.batch_to_device(batch) From 4f8f62b3df3338bbe63fc3a69f4cc39be19062c1 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Tue, 25 Nov 2025 18:56:56 +0100 Subject: [PATCH 100/344] instructions for sophie --- src/weathergen/train/trainer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index a8f02a247..8a1e8d5d9 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -613,9 +613,6 @@ def train(self, mini_epoch): self.t_start = time.time() for bidx, batch in enumerate(dataset_iter): - # make existing pipeline work: - batch = batch[0] - ################################################################ # SOPH: student teacher access path here: # student_teacher_data = batch[1] @@ -631,6 +628,11 @@ def train(self, mini_epoch): # You will also need the source_cell_lens, target_coords_idx, these are not being passed through for the views yet. ################################################################ + + # make existing pipeline work: + batch = batch[0] + + forecast_steps = batch[-1] batch = self.batch_to_device(batch) From c0df0bf4ee947e90887866b2895dd266a43309f3 Mon Sep 17 00:00:00 2001 From: Moritz Hauschulz <60788263+moritzhauschulz@users.noreply.github.com> Date: Wed, 26 Nov 2025 10:46:26 +0000 Subject: [PATCH 101/344] Issue1279 noise conditioning (#1337) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * initial commit [draft] * adapt noise conditioner to make it closer to DiT * adapt dimensionalities – code runs with default config * lint * Updated Copyright * Updated Copyright * fixes round 1 --- NOTICE | 26 +++++++ .../common/src/weathergen/common/config.py | 2 +- packages/dashboard/atmo_eval.py | 4 +- src/weathergen/model/attention.py | 51 ++++++++++++-- src/weathergen/model/diffusion.py | 70 +++++++++++++++++-- src/weathergen/model/engines.py | 16 ++++- src/weathergen/model/layers.py | 69 +++++++++++++++++- src/weathergen/utils/validation_io.py | 2 +- 8 files changed, 221 insertions(+), 19 deletions(-) diff --git a/NOTICE b/NOTICE index ddd243b23..d657eede5 100644 --- a/NOTICE +++ b/NOTICE @@ -12,3 +12,29 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +======================================================================= +google-deepmind/graphcast (several associated papers) + +This software incorporates code from the 'google-deepmind/graphcast' repository, with adaptations. + +Original Copyright 2024 DeepMind Technologies Limited. + +The source code is available at: +https://github.com/google-deepmind/graphcast + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 + +======================================================================= +facebookresearch/DiT (Scalable Diffusion Models with Transformers (DiT)) + +This software incorporates code from the 'facebookresearch/DiT' repository, with adaptations. + +The source code is available at: +https://github.com/facebookresearch/DiT + +The code and model weights are licensed under CC-BY-NC. +See https://raw.githubusercontent.com/facebookresearch/DiT/refs/heads/main/LICENSE.txt for details. diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index c32732ba7..1b9e1928d 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -225,7 +225,7 @@ def load_config( # use OmegaConf.unsafe_merge if too slow c = OmegaConf.merge(base_config, private_config, *overwrite_configs) assert isinstance(c, Config) - + # Ensure the config has mini-epoch notation if hasattr(c, "samples_per_epoch"): c.samples_per_mini_epoch = c.samples_per_epoch diff --git a/packages/dashboard/atmo_eval.py b/packages/dashboard/atmo_eval.py index a98b32268..3dc077f6e 100644 --- a/packages/dashboard/atmo_eval.py +++ b/packages/dashboard/atmo_eval.py @@ -77,7 +77,9 @@ def get_score_step_48h(score_col: str) -> pl.DataFrame: .sort("start_time") .filter(pl.col(score_col).is_not_null()) ) - _logger.info(f"Getting score data for {score_col} at 48h (step={step_48h}): len={len(score_data)}") + _logger.info( + f"Getting score data for {score_col} at 48h (step={step_48h}): len={len(score_data)}" + ) # Iterate over the runs to get the metric at step 48h scores_dt: list[float | None] = [] diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 39ed1c041..f63554367 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -13,6 +13,7 @@ from flash_attn import flash_attn_func, flash_attn_varlen_func from torch.nn.attention.flex_attention import create_block_mask, flex_attention +from weathergen.model.layers import LinearNormConditioning from weathergen.model.norms import AdaLayerNorm, RMSNorm @@ -197,6 +198,7 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + with_noise_conditioning=False, # should only be True for diffusion model ): super(MultiSelfAttentionHeadLocal, self).__init__() @@ -242,11 +244,29 @@ def mask_block_local(batch, head, idx_q, idx_kv): # compile for efficiency self.flex_attention = torch.compile(flex_attention, dynamic=False) - def forward(self, x, ada_ln_aux=None): + self.noise_conditioning = None + if with_noise_conditioning: + self.noise_conditioning = LinearNormConditioning(dim_embed, dtype=self.dtype) + + def forward(self, *args): + # NOTE: Hotfix to accomodate TargetPredictionEngineClassic forward pass for attn. block, MLP... + x = args[0] + if len(args) == 2: + ada_ln_aux = args[1] + elif len(args) > 2: + ada_ln_aux = args[-1] + emb = args[1] if self.noise_conditioning else None + else: + ada_ln_aux = None + emb = None + if self.with_residual: x_in = x x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux) + if self.noise_conditioning: + x, gate = self.noise_conditioning(x, emb) + # project onto heads s = [x.shape[0], x.shape[1], self.num_heads, -1] qs = self.lnorm_q(self.proj_heads_q(x).reshape(s)).to(self.dtype).permute([0, 2, 1, 3]) @@ -257,7 +277,7 @@ def forward(self, x, ada_ln_aux=None): out = self.proj_out(self.dropout(outs.flatten(-2, -1))) if self.with_residual: - out = x_in + out + out = x_in + out * gate if self.noise_conditioning else x_in + out return out @@ -487,6 +507,7 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + with_noise_conditioning=False, # should only be True for diffusion model ): super(MultiSelfAttentionHead, self).__init__() @@ -527,11 +548,33 @@ def __init__( self.att = self.attention self.softmax = torch.nn.Softmax(dim=-1) - def forward(self, x, ada_ln_aux=None): + self.noise_conditioning = None + if with_noise_conditioning: + # NOTE: noise_emb_dim currently hard-coded + self.noise_conditioning = LinearNormConditioning( + latent_space_dim=dim_embed, noise_emb_dim=512, dtype=self.dtype + ) + + def forward(self, *args): + # NOTE: Hotfix to accomodate TargetPredictionEngineClassic forward pass for attn. block, MLP... + x = args[0] + if len(args) == 2: + ada_ln_aux = args[1] + elif len(args) > 2: + ada_ln_aux = args[-1] + emb = args[1] if self.noise_conditioning else None + else: + ada_ln_aux = None + emb = None + if self.with_residual: x_in = x x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux) + if self.noise_conditioning: + assert emb is not None, "Need noise embedding if using noise conditioning" + x, gate = self.noise_conditioning(x, emb) + # project onto heads and q,k,v and # ensure these are 4D tensors as required for flash attention s = [*([x.shape[0], 1] if len(x.shape) == 2 else x.shape[:-1]), self.num_heads, -1] @@ -547,7 +590,7 @@ def forward(self, x, ada_ln_aux=None): out = self.proj_out(outs.flatten(-2, -1)) if self.with_residual: - out = out + x_in + out = out + x_in * gate if self.noise_conditioning else out + x_in return out diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 9a24aa07a..02f247739 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -14,11 +14,17 @@ # Original Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- +# Third-Party Attribution: facebookresearch/DiT (Scalable Diffusion Models with Transformers (DiT)) +# This file incorporates code originally from the 'facebookresearch/DiT' repository, with adaptations. +# +# The original code is licensed under CC-BY-NC. +# ---------------------------------------------------------------------------- -import dataclasses +import dataclasses +import math import torch - from weathergen.model.engines import ForecastingEngine @@ -53,6 +59,8 @@ class DiffusionForecastEngine(torch.nn.Module): def __init__( self, forecast_engine: ForecastingEngine, + frequency_embedding_dim: int = 256, # TODO: determine suitable dimension + embedding_dim: int = 512, # TODO: determine suitable dimension sigma_min: float = 0.002, # Adapt to GenCast? sigma_max: float = 80, sigma_data: float = 0.5, @@ -63,6 +71,9 @@ def __init__( super().__init__() self.net = forecast_engine self.preconditioner = Preconditioner() + self.noise_embedder = NoiseEmbedder( + embedding_dim=embedding_dim, frequency_embedding_dim=frequency_embedding_dim + ) # Parameters self.sigma_min = sigma_min @@ -93,13 +104,13 @@ def forward(self, tokens: torch.Tensor, fstep: int) -> torch.Tensor: # noise = torch.randn(y.shape, device=y.device) # now eta from MultiStreamDataSampler sigma = (eta * self.p_std + self.p_mean).exp() n = torch.randn_like(y) * sigma - return self.denoise(x=y + n, c=c, sigma=sigma) + return self.denoise(x=y + n, c=c, sigma=sigma, fstep=fstep) # Compute loss -- move this to a separate loss calculator # weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 # Table 1 # loss = weight * ((y_hat - y) ** 2) - def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float) -> torch.Tensor: + def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int) -> torch.Tensor: """ The actual diffusion step, where the model removes noise from the input x under consideration of a conditioning c (e.g., previous time steps) and the current diffusion @@ -111,13 +122,17 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float) -> torch.Tenso c_in = 1 / (sigma**2 + self.sigma_data**2).sqrt() c_noise = sigma.log() / 4 + # Embed noise level + noise_emb = self.noise_embedder(c_noise) + # Precondition input and feed through network x = self.preconditioner.precondition(x, c) - return c_skip * x + c_out * self.net(c_in * x, c_noise) # Eq. (7) in EDM paper + return c_skip * x + c_out * self.net(c_in * x, fstep=fstep, noise_emb=noise_emb) # Eq. (7) in EDM paper def inference( self, x: torch.Tensor, + fstep: int, num_steps: int = 30, ) -> torch.Tensor: # Forward pass of the diffusion model during inference @@ -150,13 +165,13 @@ def inference( t_hat = t_cur # Euler step. - denoised = self.denoise(x=x_hat, c=None, sigma=t_hat) # c to be discussed + denoised = self.denoise(x=x_hat, c=None, sigma=t_hat, fstep=fstep) # c to be discussed d_cur = (x_hat - denoised) / t_hat x_next = x_hat + (t_next - t_hat) * d_cur # Apply 2nd order correction. if i < num_steps - 1: - denoised = self.net(x_next, t_next) + denoised = self.denoise(x=x_next, c=None, sigma=t_next, fstep=fstep) d_prime = (x_next - denoised) / t_next x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) @@ -170,3 +185,44 @@ def __init__(self): def precondition(self, x, c): return x + + +# NOTE: Adapted from DiT codebase: +class NoiseEmbedder(torch.nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, embedding_dim: int, frequency_embedding_dim: int, dtype=torch.bfloat16): + super().__init__() + self.dtype = dtype + self.mlp = torch.nn.Sequential( + torch.nn.Linear(frequency_embedding_dim, embedding_dim, bias=True), + torch.nn.SiLU(), + torch.nn.Linear(embedding_dim, embedding_dim, bias=True), + ) + self.frequency_embedding_dim = frequency_embedding_dim + + def timestep_embedding(self, t: float, max_period: int=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + half = self.frequency_embedding_dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=self.dtype) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if self.frequency_embedding_dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t: float): + t_freq = self.timestep_embedding(t) + t_emb = self.mlp(t_freq) + return t_emb diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 7359d1403..c1acf74f3 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -336,6 +336,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: dim_aux=1, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + with_noise_conditioning=self.cf.fe_diffusion_model, ) ) else: @@ -352,6 +353,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: dim_aux=1, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + with_noise_conditioning=self.cf.fe_diffusion_model, ) ) # Add MLP block @@ -364,6 +366,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_type=self.cf.norm_type, dim_aux=1, norm_eps=self.cf.mlp_norm_eps, + with_noise_conditioning=self.cf.fe_diffusion_model, ) ) @@ -376,10 +379,17 @@ def init_weights_final(m): for block in self.fe_blocks: block.apply(init_weights_final) - def forward(self, tokens, fstep): + def forward(self, tokens, fstep, noise_emb=None): aux_info = torch.tensor([fstep], dtype=torch.float32, device="cuda") - for block in self.fe_blocks: - tokens = checkpoint(block, tokens, aux_info, use_reentrant=False) + if self.cf.fe_diffusion_model: + assert noise_emb is not None, ( + "Noise embedding must be provided for diffusion forecast engine" + ) + for block in self.fe_blocks: + tokens = checkpoint(block, tokens, noise_emb, aux_info, use_reentrant=False) + else: + for block in self.fe_blocks: + tokens = checkpoint(block, tokens, aux_info, use_reentrant=False) return tokens diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index 1f7b8df5d..8ab5156b5 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -7,6 +7,21 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +# ---------------------------------------------------------------------------- +# Third-Party Attribution: facebookresearch/DiT (Scalable Diffusion Models with Transformers (DiT)) +# This file incorporates code originally from the 'facebookresearch/DiT' repository, with adaptations. +# +# The original code is licensed under CC-BY-NC. +# ---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- +# Third-Party Attribution: google-deepmind/graphcast (several associated papers) +# This file incorporates code originally from the 'google-deepmind/graphcast' repository, with adaptations. +# +# The original code is licensed under Apache 2.0. Original Copyright 2024 DeepMind Technologies Limited. +# ---------------------------------------------------------------------------- + + import torch import torch.nn as nn @@ -43,6 +58,7 @@ def __init__( dim_aux=None, norm_eps=1e-5, name: str | None = None, + with_noise_conditioning=False, ): """Constructor""" @@ -55,6 +71,7 @@ def __init__( self.with_residual = with_residual self.with_aux = dim_aux is not None + self.with_noise_conditioning = with_noise_conditioning dim_hidden = int(dim_in * hidden_factor) self.layers = torch.nn.ModuleList() @@ -68,6 +85,11 @@ def __init__( else AdaLayerNorm(dim_in, dim_aux, norm_eps=norm_eps) ) + if with_noise_conditioning: + self.noise_conditioning = LinearNormConditioning( + dim_in + ) # TODO: chech if should pass some dtype? + self.layers.append(torch.nn.Linear(dim_in, dim_hidden)) self.layers.append(nonlin()) self.layers.append(torch.nn.Dropout(p=dropout_rate)) @@ -79,11 +101,20 @@ def __init__( self.layers.append(torch.nn.Linear(dim_hidden, dim_out)) + # TODO: expanded args, must check dependencies (previously aux = args[-1]) def forward(self, *args): - x, x_in, aux = args[0], args[0], args[-1] + x, x_in = args[0], args[0] + if len(args) == 2: + aux = args[1] + elif len(args) > 2: + aux = args[-1] + noise_emb = args[1] if self.with_noise_conditioning else None for i, layer in enumerate(self.layers): - x = layer(x, aux) if (i == 0 and self.with_aux) else layer(x) + if isinstance(layer, LinearNormConditioning): + x = layer(x, noise_emb) # noise embedding + else: + x = layer(x, aux) if (i == 0 and self.with_aux) else layer(x) if self.with_residual: if x.shape[-1] == x_in.shape[-1]: @@ -93,3 +124,37 @@ def forward(self, *args): x = x + x_in.repeat([*[1 for _ in x.shape[:-1]], x.shape[-1] // x_in.shape[-1]]) return x + + +# NOTE: Inspired by GenCast/DiT. +class LinearNormConditioning(torch.nn.Module): + """Module for norm conditioning, adapted from GenCast with additional gate parameter from DiT. + + Conditions the normalization of `inputs` by applying a linear layer to the + `norm_conditioning` which produces the scale and offset for each channel. + """ + + def __init__(self, latent_space_dim: int, noise_emb_dim: int = 512, dtype=torch.bfloat16): + super().__init__() + self.dtype = dtype + + self.conditional_linear_layer = torch.nn.Linear( + in_features=noise_emb_dim, + out_features=3 * latent_space_dim, + ) + # Optional: initialize weights similar to TruncatedNormal(stddev=1e-8) + torch.nn.init.normal_(self.conditional_linear_layer.weight, std=1e-8) + torch.nn.init.zeros_(self.conditional_linear_layer.bias) + + def forward(self, inputs, noise_emb): + conditional_scale_offset = self.conditional_linear_layer(noise_emb.to(self.dtype)) + scale_minus_one, offset, gate = torch.chunk(conditional_scale_offset, 3, dim=-1) + scale = scale_minus_one + 1.0 + + # Reshape scale and offset for broadcasting if needed + while scale.dim() < inputs.dim(): + scale = scale.unsqueeze(1) + offset = offset.unsqueeze(1) + return (inputs * scale + offset).to( + self.dtype + ), gate # TODO: check if to(self.dtype) needed here diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 355be0e51..f8a5a1cc5 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -30,7 +30,7 @@ def write_output( sample_idxs, ): stream_names = [stream.name for stream in cf.streams] - analysis_streams_output = cf.get( 'analysis_streams_output', None) + analysis_streams_output = cf.get("analysis_streams_output", None) if cf.streams_output is not None: output_stream_names = cf.streams_output elif analysis_streams_output is not None: # --- to be removed at some point --- From c27156cab8a8aafb563ce3bd3a827944548cd072 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Wed, 26 Nov 2025 12:35:03 +0100 Subject: [PATCH 102/344] add SampleMetaData integration and functionality, and update masker to use SampleMetadata. Pass through source_cell_lens and target_coords_idx to student_teacher_batch in iter, and hence pass through to trainer. source_cell_lens and target_coords_idx are now part of Sample, which is itself the components of ModelBatch. To tidy --- src/weathergen/datasets/batch.py | 19 ++++++- src/weathergen/datasets/masking.py | 27 +++------- .../datasets/multi_stream_data_sampler.py | 50 +++++++++++++++---- 3 files changed, 65 insertions(+), 31 deletions(-) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index e7047d55a..583ebfcdf 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -9,6 +9,7 @@ from dataclasses import dataclass import numpy as np +import torch from weathergen.common.config import Config from weathergen.datasets.stream_data import StreamData @@ -50,13 +51,13 @@ class ViewMetadata: strategy_config: Config | None = None # e.g. {rate: 0.5, hl_mask: 3, overlap: "disjoint"} +@dataclass class SampleMetaData: # masking strategy # masking_strategy: str # parameters for masking strategy - masking_params: Config - + masking_params: Config | dict class Sample: # keys: stream name, values: SampleMetaData @@ -66,6 +67,10 @@ class Sample: # keys: stream_name, values: StreamData streams_data: dict[str, StreamData | None] + # perhaps this should be a dict too? + source_cell_lens: list[torch.Tensor] | None + target_coords_idx: list[torch.Tensor] | None + def __init__(self, streams: dict) -> None: # TODO: can we pass this right away? self.meta_info = {} @@ -74,6 +79,9 @@ def __init__(self, streams: dict) -> None: for stream_info in streams: self.streams_data[stream_info["name"]] = None + self.source_cell_lens: list[torch.Tensor] | None = None + self.target_coords_idx: list[torch.Tensor] | None = None + def add_stream_data(self, stream_name: str, stream_data: StreamData) -> None: """ Add data for stream @stream_name to sample @@ -87,6 +95,13 @@ def add_meta_info(self, stream_name: str, meta_info: SampleMetaData) -> None: """ self.meta_info[stream_name] = meta_info + def set_preprocessed(self, source_cell_lens, target_coords_idx): + """ + Set preprocessed data for sample + """ + self.source_cell_lens = source_cell_lens + self.target_coords_idx = target_coords_idx + # TODO: complete interface, e.g get_stream def get_stream_data(self, stream_name: str) -> StreamData: diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index 3e96bb562..332bba688 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -5,7 +5,7 @@ import torch from weathergen.common.config import Config -from weathergen.datasets.batch import ViewMetadata +from weathergen.datasets.batch import SampleMetaData _logger = logging.getLogger(__name__) @@ -630,11 +630,10 @@ def build_views_for_stream( teacher_cfg: dict, student_cfg: dict, relationship: str = "subset", - ) -> Tuple[np.typing.NDArray, List[np.typing.NDArray], List[ViewMetadata]]: + ) -> Tuple[np.typing.NDArray, List[np.typing.NDArray], List[SampleMetaData]]: """ Construct teacher/student keep masks for a stream. - ViewMetadata likely to be deprecated, - but information can be piped here for now. + SampleMetaData is currently just a dict with the masking params used. """ strat_teacher = teacher_cfg.get("strategy", "random") @@ -669,25 +668,15 @@ def build_views_for_stream( keep = base student_keep_masks.append(keep) - metadata: List[ViewMetadata] = [ - ViewMetadata( - view_id="teacher_global", - keep_mask=teacher_keep_mask, - strategy=strat_teacher, - healpix_level=self.healpix_level_data, - rate=rate_teacher, - parent_view_id=None, + metadata: List[SampleMetaData] = [ + SampleMetaData( + masking_params=teacher_cfg, ) ] for idx, mask in enumerate(student_keep_masks): metadata.append( - ViewMetadata( - view_id=f"student_local_{idx}", - keep_mask=mask, - strategy=strat_student, - healpix_level=self.healpix_level_data, - rate=rate_student, - parent_view_id="teacher_global", + SampleMetaData( + masking_params=student_cfg, ) ) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index da2dc84ca..2d0e69dff 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -14,7 +14,7 @@ import torch from weathergen.common.io import IOReaderData -from weathergen.datasets.batch import ModelBatch +from weathergen.datasets.batch import ModelBatch, Sample, SampleMetaData from weathergen.datasets.data_reader_anemoi import DataReaderAnemoi from weathergen.datasets.data_reader_base import ( DataReaderBase, @@ -550,7 +550,7 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): name = stream_info["name"] - (target_masks, source_masks, student_to_teacher) = masks_streams[name] + (target_masks, source_masks, student_to_teacher, target_metadata_list, source_metadata_list) = masks_streams[name] # input_data and output_data is conceptually consecutive but differs # in source and target channels; overlap in one window when self.forecast_offset=0 @@ -582,12 +582,16 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # source meta info... # source_meta_info = SampleMetaData(... - source_meta_info = student_cfg + #print("metadata:", metadata) + #print("How many elements in metadata?", len(metadata)) + #print("current sidx:", sidx) + + source_metadata = source_metadata_list[sidx] # first is teacher # TODO: seb check this # Map each student (source) to its teacher (target) t_idx = student_to_teacher[sidx] - batch.add_source_stream(sidx, t_idx, name, sdata, source_meta_info) + batch.add_source_stream(sidx, t_idx, name, sdata, source_metadata) # stream_data_target can contain network input stream_data_target = {} @@ -608,12 +612,12 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): stream_data_target[name] = sdata # get teacher config info - teacher_meta_info = teacher_cfg + target_metadata = target_metadata_list[t_idx] # TODO: seb to check # Map target to all source students student_indices = [s_idx for s_idx, tid in enumerate(student_to_teacher) if tid == t_idx] - batch.add_target_stream(t_idx, student_indices, name, sdata, teacher_meta_info) + batch.add_target_stream(t_idx, student_indices, name, sdata, target_metadata) # import pdb; pdb.set_trace() @@ -653,11 +657,13 @@ def to_bool_tensor(arr): target_masks: list[torch.Tensor] = [] source_masks: list[torch.Tensor] = [] student_to_teacher: list[int] = [] + target_metadata: list[SampleMetaData] = [] + source_metadata: list[SampleMetaData] = [] # add a loop over num_teacher_views, generate students for each teacher for t_idx in range(num_teacher_views): # Build one teacher and its student views - t_keep_np, s_keeps_np, _meta = self.tokenizer.masker.build_views_for_stream( + t_keep_np, s_keeps_np, metadata = self.tokenizer.masker.build_views_for_stream( self.num_healpix_cells, teacher_cfg=teacher_cfg, student_cfg=student_cfg, @@ -667,14 +673,16 @@ def to_bool_tensor(arr): # append teacher mask t_tensor = to_bool_tensor(t_keep_np) target_masks.append(t_tensor) + target_metadata.append(metadata[0]) # TODO: first is teacher # this teacher's students and mapping - for s_np in (s_keeps_np or []): + for s_np, metadata in zip(s_keeps_np or [], metadata[1:], strict=True): source_masks.append(to_bool_tensor(s_np)) # append 0, 1, ... depending on which teacher we did + source_metadata.append(metadata) student_to_teacher.append(len(target_masks) - 1) - masks[stream_info["name"]] = (target_masks, source_masks, student_to_teacher) + masks[stream_info["name"]] = (target_masks, source_masks, student_to_teacher, target_metadata, source_metadata) return masks @@ -693,6 +701,21 @@ def _preprocess_model_data(self, batch, forecast_dt): return batch, source_cell_lens, target_coords_idx + def _preprocess_single_view(self, sample: Sample, forecast_dt: int): + """ """ + streams = [sd for sd in sample.streams_data.values() if sd is not None] + if not streams: + sample.set_preprocessed([], []) + return + _, scl, tci = self._preprocess_model_data([streams], forecast_dt) + sample.set_preprocessed(scl, tci) + + def _preprocess_model_batch_views(self, model_batch: ModelBatch, forecast_dt: int): + for sample in model_batch.source_samples: + self._preprocess_single_view(sample, forecast_dt) + for sample in model_batch.target_samples: + self._preprocess_single_view(sample, forecast_dt) + def __iter__(self): """ Return one batch of data @@ -744,12 +767,19 @@ def __iter__(self): # TODO: link into ModelBatch + print("Batch size:", len(batch)) + print("What is batch at this point?", batch) + + # import pdb; pdb.set_trace() + # compute batch, source_cell_lens, target_coords_idx = self._preprocess_model_data( batch, forecast_dt ) - import pdb; pdb.set_trace() + self._preprocess_model_batch_views(student_teacher_batch, forecast_dt) + + # import pdb; pdb.set_trace() yield (batch, source_cell_lens, target_coords_idx, forecast_dt), student_teacher_batch From e0d73461902b1b6639a9ebfa0349e7d5f035dc81 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Wed, 26 Nov 2025 14:31:52 +0100 Subject: [PATCH 103/344] remove prints, pdb --- src/weathergen/datasets/multi_stream_data_sampler.py | 2 -- src/weathergen/train/trainer.py | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 2d0e69dff..6682d7c68 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -767,8 +767,6 @@ def __iter__(self): # TODO: link into ModelBatch - print("Batch size:", len(batch)) - print("What is batch at this point?", batch) # import pdb; pdb.set_trace() diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 8a1e8d5d9..e0b1f09de 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -613,6 +613,8 @@ def train(self, mini_epoch): self.t_start = time.time() for bidx, batch in enumerate(dataset_iter): + # import pdb; pdb.set_trace() + ################################################################ # SOPH: student teacher access path here: # student_teacher_data = batch[1] @@ -632,8 +634,6 @@ def train(self, mini_epoch): # make existing pipeline work: batch = batch[0] - - forecast_steps = batch[-1] batch = self.batch_to_device(batch) From a09a7373ba6594fc5b9972f5896080b96c8b0988 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Wed, 26 Nov 2025 16:16:16 +0100 Subject: [PATCH 104/344] linting --- src/weathergen/model/attention.py | 4 ++-- src/weathergen/model/diffusion.py | 11 ++++++++--- src/weathergen/model/layers.py | 10 ++++++---- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index f63554367..5ff5e5bf7 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -249,7 +249,7 @@ def mask_block_local(batch, head, idx_q, idx_kv): self.noise_conditioning = LinearNormConditioning(dim_embed, dtype=self.dtype) def forward(self, *args): - # NOTE: Hotfix to accomodate TargetPredictionEngineClassic forward pass for attn. block, MLP... + # NOTE: Hotfix for TargetPredictionEngineClassic forward pass for attn. block, MLP... x = args[0] if len(args) == 2: ada_ln_aux = args[1] @@ -556,7 +556,7 @@ def __init__( ) def forward(self, *args): - # NOTE: Hotfix to accomodate TargetPredictionEngineClassic forward pass for attn. block, MLP... + # NOTE: Hotfix for TargetPredictionEngineClassic forward pass for attn. block, MLP... x = args[0] if len(args) == 2: ada_ln_aux = args[1] diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 02f247739..9b34f8d48 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -16,7 +16,8 @@ # ---------------------------------------------------------------------------- # Third-Party Attribution: facebookresearch/DiT (Scalable Diffusion Models with Transformers (DiT)) -# This file incorporates code originally from the 'facebookresearch/DiT' repository, with adaptations. +# This file incorporates code originally from the 'facebookresearch/DiT' repository, +# with adaptations. # # The original code is licensed under CC-BY-NC. # ---------------------------------------------------------------------------- @@ -24,7 +25,9 @@ import dataclasses import math + import torch + from weathergen.model.engines import ForecastingEngine @@ -127,7 +130,9 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int) -> # Precondition input and feed through network x = self.preconditioner.precondition(x, c) - return c_skip * x + c_out * self.net(c_in * x, fstep=fstep, noise_emb=noise_emb) # Eq. (7) in EDM paper + return c_skip * x + c_out * self.net( + c_in * x, fstep=fstep, noise_emb=noise_emb + ) # Eq. (7) in EDM paper def inference( self, @@ -203,7 +208,7 @@ def __init__(self, embedding_dim: int, frequency_embedding_dim: int, dtype=torch ) self.frequency_embedding_dim = frequency_embedding_dim - def timestep_embedding(self, t: float, max_period: int=10000): + def timestep_embedding(self, t: float, max_period: int = 10000): """ Create sinusoidal timestep embeddings. :param t: a 1-D Tensor of N indices, one per batch element. diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index 8ab5156b5..e85acc6c7 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -9,20 +9,22 @@ # ---------------------------------------------------------------------------- # Third-Party Attribution: facebookresearch/DiT (Scalable Diffusion Models with Transformers (DiT)) -# This file incorporates code originally from the 'facebookresearch/DiT' repository, with adaptations. +# This file incorporates code originally from the 'facebookresearch/DiT' repository, +# with adaptations. # # The original code is licensed under CC-BY-NC. # ---------------------------------------------------------------------------- # ---------------------------------------------------------------------------- # Third-Party Attribution: google-deepmind/graphcast (several associated papers) -# This file incorporates code originally from the 'google-deepmind/graphcast' repository, with adaptations. +# This file incorporates code originally from the 'google-deepmind/graphcast' repository, +# with adaptations. # -# The original code is licensed under Apache 2.0. Original Copyright 2024 DeepMind Technologies Limited. +# The original code is licensed under Apache 2.0. +# Original Copyright 2024 DeepMind Technologies Limited. # ---------------------------------------------------------------------------- - import torch import torch.nn as nn From 705cb0a119d5db93789f354ee5f4fa7a504b407f Mon Sep 17 00:00:00 2001 From: Jubeku Date: Wed, 26 Nov 2025 17:42:32 +0100 Subject: [PATCH 105/344] fix ddp --- config/default_config.yml | 3 +- src/weathergen/model/model.py | 47 +++++++++++-------- src/weathergen/model/model_interface.py | 6 ++- .../loss_module_latent_diffusion.py | 4 +- .../train/target_and_aux_diffusion.py | 8 +++- 5 files changed, 45 insertions(+), 23 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 03339ac55..84730d5f7 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -59,7 +59,7 @@ healpix_level: 5 with_mixed_precision: True with_flash_attention: True compile_model: False -with_fsdp: True +with_fsdp: False attention_dtype: bf16 mlp_norm_eps: 1e-5 norm_eps: 1e-4 @@ -129,6 +129,7 @@ samples_per_mini_epoch: 4096 samples_per_validation: 512 shuffle: True +mixed_precision_dtype: bf16 lr_scaling_policy: "sqrt" lr_start: 1e-6 diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 03e3e3ba0..237beb8be 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -562,7 +562,14 @@ def rename_old_state_dict(self, params: dict) -> dict: return new_params ######################################### - def forward(self, model_params: ModelParams, batch, forecast_offset: int, forecast_steps: int): + def forward( + self, + model_params: ModelParams, + batch, + forecast_offset: int, + forecast_steps: int, + encode_only: bool = False, + ): """Performs the forward pass of the model to generate forecasts Tokens are processed through the model components, which were defined in the create method. @@ -583,6 +590,8 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca (streams_data, _, target_coords_idxs) = batch tokens, posteriors = self.encode(model_params=model_params, batch=batch) + if encode_only: + return tokens, posteriors # roll-out in latent space preds_all = [] @@ -590,15 +599,15 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca latents["preds"] = [] for fstep in range(forecast_offset, forecast_offset + forecast_steps): # prediction - preds_all += [ - self.predict( - model_params, - fstep, - tokens, - streams_data, - target_coords_idxs, - ) - ] + # preds_all += [ + # self.predict( + # model_params, + # fstep, + # tokens, + # streams_data, + # target_coords_idxs, + # ) + # ] if self.training: # Impute noise to the latent state @@ -610,15 +619,15 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca latents["preds"] += [tokens] # prediction for final step - preds_all += [ - self.predict( - model_params, - forecast_offset + forecast_steps, - tokens, - streams_data, - target_coords_idxs, - ) - ] + # preds_all += [ + # self.predict( + # model_params, + # forecast_offset + forecast_steps, + # tokens, + # streams_data, + # target_coords_idxs, + # ) + # ] latents["posteriors"] = posteriors diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index ad636de45..1acf1f5bd 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -106,7 +106,11 @@ def init_model_and_shard( if isinstance(module, modules_to_shard): fully_shard(module, **fsdp_kwargs) - for module in model.forecast_engine.fe_blocks.modules(): + if cf.fe_diffusion_model: + model_fe_blocks = model.forecast_engine.net.fe_blocks + else: + model_fe_blocks = model.forecast_engine.fe_blocks + for module in model_fe_blocks.modules(): if isinstance(module, modules_to_shard): fully_shard(module, **fsdp_kwargs) diff --git a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py index 793f06302..e72ddf346 100644 --- a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py +++ b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py @@ -117,5 +117,7 @@ def compute_loss( loss_values[loss_values == 0.0] = torch.nan return LossValues( - loss=loss, losses_all=losses_all, stddev_all={"latent": torch.tensor(torch.nan)} + loss=loss, + losses_all=losses_all, + stddev_all={"latent": torch.tensor(torch.nan).to(self.device)}, ) diff --git a/src/weathergen/train/target_and_aux_diffusion.py b/src/weathergen/train/target_and_aux_diffusion.py index 18a2e86e3..569dd6920 100644 --- a/src/weathergen/train/target_and_aux_diffusion.py +++ b/src/weathergen/train/target_and_aux_diffusion.py @@ -14,5 +14,11 @@ def compute( self, bidx, batch, model_params, model, forecast_offset, forecast_steps ) -> tuple[Any, Any]: with torch.no_grad(): - tokens, posteriors = self.model.encode(model_params=model_params, batch=batch) + tokens, posteriors = self.model( + model_params=model_params, + batch=batch, + forecast_offset=None, + forecast_steps=None, + encode_only=True, + ) return {"latent": [tokens]}, posteriors From 3e989c40ca448cfc0003e660170b763d8fbe7141 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Wed, 26 Nov 2025 18:17:03 +0100 Subject: [PATCH 106/344] load encoder weights, fixed for multi-gpu --- config/default_config.yml | 4 ++- src/weathergen/model/model_interface.py | 44 +++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/config/default_config.yml b/config/default_config.yml index 84730d5f7..7450844ec 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -10,7 +10,7 @@ embed_dropout_rate: 0.1 target_cell_local_prediction: True ae_local_dim_embed: 1024 -ae_local_num_blocks: 2 +ae_local_num_blocks: 0 ae_local_num_heads: 16 ae_local_dropout_rate: 0.1 ae_local_with_qk_lnorm: True @@ -53,6 +53,8 @@ fe_dropout_rate: 0.1 fe_with_qk_lnorm: True fe_diffusion_model: True impute_latent_noise_std: 0.0 # 1e-4 +chkpt_encoder_weights: "./models/whkujigw/whkujigw_epoch00063.chkpt" + healpix_level: 5 diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index 1acf1f5bd..18d86f44f 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -157,6 +157,50 @@ def init_model_and_shard( logger.info(f"Continuing run with id={run_id_contd} at mini_epoch {mini_epoch_contd}.") model = load_model(cf, model, device, run_id_contd, mini_epoch_contd) + # ------------------------------------------------------------------------------------------ + # LOAD AND FREEZE ENCODER WEIGHTS + # ONLY FOR EXPERIMENTATION, TO BE REMOVED + params = torch.load( + cf.chkpt_encoder_weights, + map_location=torch.device("cpu"), + mmap=True, + weights_only=True, + ) + encoder_modules = [ + "embed_engine", + "ae_local_engine", + "ae_local_global_engine", + "ae_global_engine", + ] + + # Load encoder weights + params_temp = {} + for name in params.keys(): + if any(e_module in name for e_module in encoder_modules): + if cf.with_ddp: + params_temp[f"module.{name}"] = params[name] + else: + params_temp[name] = params[name] + params = params_temp + mkeys, ukeys = model.load_state_dict(params, strict=False) + + # Freeze encoder weights + for name, module in model.named_modules(): + if any(e_module in name for e_module in encoder_modules): + for p in module.parameters(): + p.requires_grad = False + + model = model.to(f"cuda:{cf.local_rank}") + + # warn about difference in checkpoint and model + if len(mkeys) == 0 and len(ukeys) == 0: + logger.info(f"Checkpoint {cf.chkpt_encoder_weights} loaded successfully with all weights.") + if len(mkeys) > 0: + logger.warning(f"Missing keys when loading model: {mkeys}") + if len(ukeys) > 0: + logger.warning(f"Unused keys when loading model: {ukeys}") + # ------------------------------------------------------------------------------------------ + # model params model_params = ModelParams(cf).create(cf) model_params.reset_parameters(cf) From 5cb5f056a397a5b55d422d7bb5d74d8ae44c89e3 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Wed, 26 Nov 2025 22:27:29 +0100 Subject: [PATCH 107/344] fix parameter counting in case of diff FE --- src/weathergen/model/model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 237beb8be..5a5cc7622 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -487,7 +487,11 @@ def print_num_parameters(self) -> None: num_params_q_cells = np.prod(self.q_cells.shape) if self.q_cells.requires_grad else 0 num_params_ae_adapater = get_num_parameters(self.ae_local_global_engine.ae_adapter) - num_params_fe = get_num_parameters(self.forecast_engine.fe_blocks) + num_params_fe = get_num_parameters( + self.forecast_engine.net.fe_blocks + if cf.fe_diffusion_model + else self.forecast_engine.fe_blocks + ) num_params_pred_adapter = [get_num_parameters(kv) for kv in self.pred_adapter_kv] num_params_embed_tcs = [get_num_parameters(etc) for etc in self.embed_target_coords] From 6d909d604e86597895960aada4fb7cacbb427a31 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Thu, 27 Nov 2025 11:32:32 +0100 Subject: [PATCH 108/344] add mask to SampleMetaData and add forecast_dt to Sample so it is accessible. Can specify the loss in the default config with student-teacher views --- src/weathergen/datasets/batch.py | 32 +++++++++++++------ .../datasets/multi_stream_data_sampler.py | 14 ++++++-- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index 583ebfcdf..fe65cb0d5 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -59,6 +59,8 @@ class SampleMetaData: # parameters for masking strategy masking_params: Config | dict + mask: torch.Tensor | None = None + class Sample: # keys: stream name, values: SampleMetaData meta_info: dict @@ -66,9 +68,11 @@ class Sample: # data for all streams # keys: stream_name, values: StreamData streams_data: dict[str, StreamData | None] - - # perhaps this should be a dict too? + forecast_dt: int | None + + # these two live in ModelBatch as they are flattened! source_cell_lens: list[torch.Tensor] | None + # this should be a dict also lives in ModelBatch target_coords_idx: list[torch.Tensor] | None def __init__(self, streams: dict) -> None: @@ -82,6 +86,8 @@ def __init__(self, streams: dict) -> None: self.source_cell_lens: list[torch.Tensor] | None = None self.target_coords_idx: list[torch.Tensor] | None = None + self.forecast_dt: int | None = None + def add_stream_data(self, stream_name: str, stream_data: StreamData) -> None: """ Add data for stream @stream_name to sample @@ -102,6 +108,12 @@ def set_preprocessed(self, source_cell_lens, target_coords_idx): self.source_cell_lens = source_cell_lens self.target_coords_idx = target_coords_idx + def set_forecast_dt(self, forecast_dt: int) -> None: + """ + Set forecast_dt for sample + """ + self.forecast_dt = forecast_dt + # TODO: complete interface, e.g get_stream def get_stream_data(self, stream_name: str) -> StreamData: @@ -125,8 +137,8 @@ class ModelBatch: # index of corresponding target (for source samples) or source (for target samples) # these are in 1-to-1 corresponding for classical training modes (MTM, forecasting) but # can be more complex for strategies like student-teacher training - source_target_matching_idxs: np.typing.NDArray[np.int32] - target_source_matching_idxs: np.typing.NDArray[np.int32] + source2target_matching_idxs: np.typing.NDArray[np.int32] + target2source_matching_idxs: np.typing.NDArray[np.int32] def __init__(self, streams, num_source_samples: int, num_target_samples: int) -> None: """ """ @@ -134,9 +146,9 @@ def __init__(self, streams, num_source_samples: int, num_target_samples: int) -> self.source_samples = [Sample(streams) for _ in range(num_source_samples)] self.target_samples = [Sample(streams) for _ in range(num_target_samples)] - self.source_target_matching_idxs = np.full(num_source_samples, -1, dtype=np.int32) + self.source2target_matching_idxs = np.full(num_source_samples, -1, dtype=np.int32) # self.target_source_matching_idxs = np.full(num_target_samples, -1, dtype=np.int32) - self.target_source_matching_idxs = [[] for _ in range(num_target_samples)] + self.target2source_matching_idxs = [[] for _ in range(num_target_samples)] def add_source_stream( self, @@ -156,7 +168,7 @@ def add_source_stream( assert target_sample_idx < len(self.target_samples), "invalid value for target_sample_idx" - self.source_target_matching_idxs[source_sample_idx] = target_sample_idx + self.source2target_matching_idxs[source_sample_idx] = target_sample_idx def add_target_stream( self, @@ -178,7 +190,7 @@ def add_target_stream( assert source_sample_idx < len(self.source_samples), "invalid value for source_sample_idx" else: assert all(idx < len(self.source_samples) for idx in source_sample_idx), "invalid value for source_sample_idx" - self.target_source_matching_idxs[target_sample_idx] = source_sample_idx + self.target2source_matching_idxs[target_sample_idx] = source_sample_idx def len_sources(self) -> int: """ @@ -208,10 +220,10 @@ def get_source_idx_for_target(self, target_idx: int) -> int: """ Get index of source sample for a given target sample index """ - return int(self.source_target_matching_idxs[target_idx]) + return int(self.target2source_matching_idxs[target_idx]) def get_target_idx_for_source(self, source_idx: int) -> int: """ Get index of target sample for a given source sample index """ - return int(self.source_target_matching_idxs[source_idx]) + return int(self.source2target_matching_idxs[source_idx]) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 6682d7c68..36837f36c 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -577,6 +577,7 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): output_tokens, mask, ) + stream_data_source[name] = sdata # source meta info... @@ -588,10 +589,15 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): source_metadata = source_metadata_list[sidx] # first is teacher + # also want to add the mask to the metadata + source_metadata.mask = mask + # TODO: seb check this # Map each student (source) to its teacher (target) t_idx = student_to_teacher[sidx] batch.add_source_stream(sidx, t_idx, name, sdata, source_metadata) + batch.source_samples[sidx].set_forecast_dt(forecast_dt) + # stream_data_target can contain network input stream_data_target = {} @@ -614,13 +620,15 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # get teacher config info target_metadata = target_metadata_list[t_idx] + # also want to add the mask to the metadata + target_metadata.mask = mask + # TODO: seb to check # Map target to all source students student_indices = [s_idx for s_idx, tid in enumerate(student_to_teacher) if tid == t_idx] batch.add_target_stream(t_idx, student_indices, name, sdata, target_metadata) + batch.target_samples[t_idx].set_forecast_dt(forecast_dt) - # import pdb; pdb.set_trace() - # TODO: build batch # source_input # target_input @@ -697,6 +705,8 @@ def _preprocess_model_data(self, batch, forecast_dt): # compute offsets and auxiliary data needed for prediction computation # (info is not per stream so separate data structure) + + ##### target_coords_idx we probably don't need for the targets ##### target_coords_idx = compute_idxs_predict(self.forecast_offset + forecast_dt, batch) return batch, source_cell_lens, target_coords_idx From 26f7b5bfb42db38f1552edd0fb155deebbcf7519 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Thu, 27 Nov 2025 15:33:22 +0100 Subject: [PATCH 109/344] add diffusion forecast option for the data sampling, and with noise_level_rn in the metadata. The Trainer needs to be copied from Sophies branch, currently we only get so far --- src/weathergen/datasets/batch.py | 42 +--- .../datasets/multi_stream_data_sampler.py | 227 ++++++++++++------ src/weathergen/train/trainer.py | 110 ++++++--- 3 files changed, 245 insertions(+), 134 deletions(-) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index fe65cb0d5..55674f81c 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -18,39 +18,6 @@ # TODO: GetTimestep to get the timestep # TODO: GetMetaData: then this gets the right rn for the timestep! - -# NOTE: TO BE DECPRECATED -@dataclass -class ViewMetadata: - """ - Metadata describing how a view was generated. - - This captures the spatial selection (which cells/tokens were kept), - the strategy used (random, healpix, etc.), and hierarchical parameters. - - Attributes: - view_id: Unique identifier (e.g., "teacher_global", "student_local_0") - keep_mask: Boolean array [num_healpix_cells] at data level indicating kept cells - healpix_level: HEALPix level for hierarchical selection (None if not applicable) - rate: Fraction of data kept (e.g., 0.5 = 50% kept); None if fixed count - parent_view_id: ID of the parent view this is a subset of (None for teacher) - """ - - # Core identifiers and selection description - view_id: str - keep_mask: np.typing.NDArray # [num_cells] bool at data level - strategy: str # e.g. "random", "healpix", "channel" - - # Hierarchical/quantitative description of selection - healpix_level: int | None = None - rate: float | None = None - parent_view_id: str | None = None # For students: which teacher they belong to - - # Optional extras for future/other training paradigms - loss_type: str | None = None # e.g. DINO, JEPA - strategy_config: Config | None = None # e.g. {rate: 0.5, hl_mask: 3, overlap: "disjoint"} - - @dataclass class SampleMetaData: # masking strategy @@ -61,6 +28,8 @@ class SampleMetaData: mask: torch.Tensor | None = None + noise_level_rn: float | None = None + class Sample: # keys: stream name, values: SampleMetaData meta_info: dict @@ -69,10 +38,11 @@ class Sample: # keys: stream_name, values: StreamData streams_data: dict[str, StreamData | None] forecast_dt: int | None - - # these two live in ModelBatch as they are flattened! - source_cell_lens: list[torch.Tensor] | None + + # TODO: + # these two need to live in ModelBatch as they are flattened! # this should be a dict also lives in ModelBatch + source_cell_lens: list[torch.Tensor] | None target_coords_idx: list[torch.Tensor] | None def __init__(self, streams: dict) -> None: diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 36837f36c..4f9fdc661 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -329,7 +329,7 @@ def _build_stream_data_input( # source input data - # Fornow, keep only mask state of the final timestep + # For now, keep only mask state of the final timestep # (correspondsing to base_idx, first of the loop below) # to ensure alignment with the target data for MTM/S-T. final_mask_state = None @@ -356,6 +356,7 @@ def _build_stream_data_input( keep_mask=mask, ) + # for masked autoencoding, we want the mask state that overlaps with the target if step == 0: final_mask_state = mask_state @@ -533,39 +534,147 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): TODO: these modes are not being used now. """ - streams_data: list[StreamData] = [] + if mode == "student_teacher": + + streams_data: list[StreamData] = [] - # get/coordinate masks - masks_streams = self._get_source_target_masks() + # get/coordinate masks + masks_streams = self._get_source_target_masks() - # Determine number of views direct from config (teacher & student views) - teacher_cfg = self.training_cfg.get("teacher_model_input", {}) if self.training_cfg else {} - student_cfg = self.training_cfg.get("model_input", {}) if self.training_cfg else {} - num_target_samples = int(teacher_cfg.get("num_views", 1)) - num_source_samples = int(teacher_cfg.get("num_views", 1)) * int(student_cfg.get("num_views", 1)) # per teacher - - batch = ModelBatch(self.streams, num_source_samples, num_target_samples) + # Determine number of views direct from config (teacher & student views) + teacher_cfg = self.training_cfg.get("teacher_model_input", {}) if self.training_cfg else {} + student_cfg = self.training_cfg.get("model_input", {}) if self.training_cfg else {} + num_target_samples = int(teacher_cfg.get("num_views", 1)) + num_source_samples = int(teacher_cfg.get("num_views", 1)) * int(student_cfg.get("num_views", 1)) # per teacher + + batch = ModelBatch(self.streams, num_source_samples, num_target_samples) + + # for all streams + for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): + name = stream_info["name"] + + (target_masks, source_masks, student_to_teacher, target_metadata_list, source_metadata_list) = masks_streams[name] + + # input_data and output_data is conceptually consecutive but differs + # in source and target channels; overlap in one window when self.forecast_offset=0 + (input_data, output_data) = self._get_data_windows(idx, forecast_dt, stream_ds) + + # tokenize windows + # *_tokens = [ (cells_idx, cells_idx_lens), ... ] with length = #time_steps + input_tokens = self.tokenizer.get_tokens_windows(stream_info, input_data, True) + output_tokens = self.tokenizer.get_tokens_windows(stream_info, output_data, False) + + # collect source data for current stream + # loop over student views + stream_data_source = {} + for sidx, mask in enumerate(source_masks): + # stream_data_source[name] = self._build_stream_data( + sdata = self._build_stream_data( + "target_coords target_values", + idx, + forecast_dt, + stream_info, + input_data, + output_data, + input_tokens, + output_tokens, + mask, + ) + + stream_data_source[name] = sdata + + + # source meta info... + # source_meta_info = SampleMetaData(... + + source_metadata = source_metadata_list[sidx] # first is teacher + + # also want to add the mask to the metadata + source_metadata.mask = mask + + # TODO: seb check this + # Map each student (source) to its teacher (target) + t_idx = student_to_teacher[sidx] + batch.add_source_stream(sidx, t_idx, name, sdata, source_metadata) + # num_input_steps? + batch.source_samples[sidx].set_forecast_dt(forecast_dt) + + + # stream_data_target can contain network input + stream_data_target = {} + + for t_idx, mask in enumerate(target_masks): + # stream_data_target[name] = self._build_stream_data( + sdata = self._build_stream_data( + "target_values", + idx, + forecast_dt, + stream_info, + input_data, + output_data, + input_tokens, + output_tokens, + mask, + ) + stream_data_target[name] = sdata + + # get teacher config info + target_metadata = target_metadata_list[t_idx] + + # also want to add the mask to the metadata + target_metadata.mask = mask + + # TODO: seb to check + # Map target to all source students + student_indices = [s_idx for s_idx, tid in enumerate(student_to_teacher) if tid == t_idx] + batch.add_target_stream(t_idx, student_indices, name, sdata, target_metadata) + batch.target_samples[t_idx].set_forecast_dt(forecast_dt) + + # TODO: build batch + # source_input + # target_input + # source_output + # target_output - # for all streams - for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): - name = stream_info["name"] + # add data for current stream + streams_data += [v for k, v in stream_data_source.items()] - (target_masks, source_masks, student_to_teacher, target_metadata_list, source_metadata_list) = masks_streams[name] + elif mode == "diffusion_forecast": - # input_data and output_data is conceptually consecutive but differs - # in source and target channels; overlap in one window when self.forecast_offset=0 - (input_data, output_data) = self._get_data_windows(idx, forecast_dt, stream_ds) + streams_data: list[StreamData] = [] - # tokenize windows - # *_tokens = [ (cells_idx, cells_idx_lens), ... ] with length = #time_steps - input_tokens = self.tokenizer.get_tokens_windows(stream_info, input_data, True) - output_tokens = self.tokenizer.get_tokens_windows(stream_info, output_data, False) + # get/coordinate masks + masks_streams = self._get_source_target_masks() - # collect source data for current stream - # loop over student views - stream_data_source = {} - for sidx, mask in enumerate(source_masks): - # stream_data_source[name] = self._build_stream_data( + # Determine number of views direct from config (teacher & student views) + teacher_cfg = self.training_cfg.get("teacher_model_input", {}) if self.training_cfg else {} + student_cfg = self.training_cfg.get("model_input", {}) if self.training_cfg else {} + num_target_samples = int(teacher_cfg.get("num_views", 1)) + num_source_samples = int(teacher_cfg.get("num_views", 1)) * int(student_cfg.get("num_views", 1)) # per teacher + + batch = ModelBatch(self.streams, num_source_samples, num_target_samples) + + + # for all streams + for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): + name = stream_info["name"] + + source_metadata = SampleMetaData(masking_params=student_cfg) + target_metadata = SampleMetaData(masking_params=teacher_cfg) + + # input_data and output_data is conceptually consecutive but differs + # in source and target channels; overlap in one window when self.forecast_offset=0 + (input_data, output_data) = self._get_data_windows(idx, forecast_dt, stream_ds) + + # tokenize windows + # *_tokens = [ (cells_idx, cells_idx_lens), ... ] with length = #time_steps + input_tokens = self.tokenizer.get_tokens_windows(stream_info, input_data, True) + output_tokens = self.tokenizer.get_tokens_windows(stream_info, output_data, False) + + # collect source data for current stream + # loop over student views + stream_data_source = {} + # stream_data_source[name] = self._build_stream_data( sdata = self._build_stream_data( "target_coords target_values", idx, @@ -575,34 +684,24 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): output_data, input_tokens, output_tokens, - mask, + mask=None, ) stream_data_source[name] = sdata - # source meta info... - # source_meta_info = SampleMetaData(... + source_metadata = source_metadata - #print("metadata:", metadata) - #print("How many elements in metadata?", len(metadata)) - #print("current sidx:", sidx) + # add a ramdom number for diffusion timestep + source_metadata.noise_level_rn = self.rng.normal(0.0, 1.0) - source_metadata = source_metadata_list[sidx] # first is teacher - - # also want to add the mask to the metadata - source_metadata.mask = mask - - # TODO: seb check this # Map each student (source) to its teacher (target) - t_idx = student_to_teacher[sidx] - batch.add_source_stream(sidx, t_idx, name, sdata, source_metadata) - batch.source_samples[sidx].set_forecast_dt(forecast_dt) + batch.add_source_stream(0, 0, name, sdata, source_metadata) + # num_input_steps? + batch.source_samples[0].set_forecast_dt(forecast_dt) + # stream_data_target can contain network input + stream_data_target = {} - # stream_data_target can contain network input - stream_data_target = {} - - for t_idx, mask in enumerate(target_masks): # stream_data_target[name] = self._build_stream_data( sdata = self._build_stream_data( "target_values", @@ -613,30 +712,28 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): output_data, input_tokens, output_tokens, - mask, + mask=None, ) stream_data_target[name] = sdata # get teacher config info - target_metadata = target_metadata_list[t_idx] + target_metadata = target_metadata - # also want to add the mask to the metadata - target_metadata.mask = mask + # TODO: handle this for different number of source timesteps + target_metadata.noise_level_rn = source_metadata.noise_level_rn - # TODO: seb to check # Map target to all source students - student_indices = [s_idx for s_idx, tid in enumerate(student_to_teacher) if tid == t_idx] - batch.add_target_stream(t_idx, student_indices, name, sdata, target_metadata) - batch.target_samples[t_idx].set_forecast_dt(forecast_dt) - - # TODO: build batch - # source_input - # target_input - # source_output - # target_output + batch.add_target_stream(0, 0, name, sdata, target_metadata) + batch.target_samples[0].set_forecast_dt(forecast_dt) + + # TODO: build batch + # source_input + # target_input + # source_output + # target_output - # add data for current stream - streams_data += [v for k, v in stream_data_source.items()] + # add data for current stream + streams_data += [v for k, v in stream_data_source.items()] return streams_data, batch @@ -763,7 +860,7 @@ def __iter__(self): # # view-based data sampling # if self.training_cfg.get("training_mode") == "student_teacher": - mode = "student_teacher" + mode = self.training_cfg.get("training_mode") streams_data, student_teacher_batch = self._get_sample(mode, idx, forecast_dt) @@ -778,8 +875,6 @@ def __iter__(self): # TODO: link into ModelBatch - # import pdb; pdb.set_trace() - # compute batch, source_cell_lens, target_coords_idx = self._preprocess_model_data( batch, forecast_dt @@ -787,8 +882,6 @@ def __iter__(self): self._preprocess_model_batch_views(student_teacher_batch, forecast_dt) - # import pdb; pdb.set_trace() - yield (batch, source_cell_lens, target_coords_idx, forecast_dt), student_teacher_batch def __len__(self): diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index e0b1f09de..cb96ef9f7 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -613,45 +613,93 @@ def train(self, mini_epoch): self.t_start = time.time() for bidx, batch in enumerate(dataset_iter): - # import pdb; pdb.set_trace() - - ################################################################ - # SOPH: student teacher access path here: - # student_teacher_data = batch[1] - # access student views: - #all_student_views = student_teacher_data.source_samples - #student_sample_1 = student_teacher_data.source_samples[0] - #student_sample_1_stream_data = student_teacher_data.source_samples[0].streams_data # dict, {stream: stream data} of first student view - # e.g. target tokens of ERA5 stream of first student view: - # target_tokens_of_student_sample_1_ERA5_stream_data = student_teacher_batch.source_samples[0].streams_data["ERA5"].target_tokens - - # access metadata of the student views, this is currently shared, very hacky, to fix. - #metadata_student_view = student_teacher_batch.source_samples[0].meta_info - - # You will also need the source_cell_lens, target_coords_idx, these are not being passed through for the views yet. - ################################################################ - - # make existing pipeline work: - batch = batch[0] - - forecast_steps = batch[-1] - batch = self.batch_to_device(batch) - - # evaluate model + # NOTE: we are still returning legacy batch structure and the new batch together. + + # Julian and Matthias: + # here we can access data as follows: + # batch[-1] is the new ModelBatch object, see the structure in batch.py + # batch[-1].source_samples is a list of Sample objects for the source data, timesteps + # batch[-1].target_samples is a list of Sample objects for the target data, timesteps + # batch[-1].meta_info is a dictionary with metadata info per sample + # batch[-1].meta_info["ERA5"] etc. + # here we have the noise_level_rn + # batch[-1].source_samples[0].meta_info["ERA5"].noise_level_rn == batch[-1].target_samples[0].meta_info["ERA5"].noise_level_rn + # for the same timestep, this needs to be fixed for when we have more source timesteps, and perhaps with bigger batch sizes? + # Each Sample object has: + # .streams_data: a dictionary of StreamData objects per stream name + # .source_cell_lens: list of tensors with lengths of source cells per stream # to be changed to be in ModelBatch + # .target_coords_idx: list of tensors with target coordinate indices per stream # to be changed to be in ModelBatch + + ###### Legacy batch after batch.to_device: + # (Pdb++) batch[0] + # [[]] + # (Pdb++) batch[1] + # [tensor([0, 1, 1, ..., 0, 0, 0], device='cuda:0', dtype=torch.int32)] + # (Pdb++) batch[2] + # [[tensor([0, 0, 0, ..., 4, 4, 4], device='cuda:0', dtype=torch.int32)]] + + # TODO: access from new ModelBatch + forecast_steps = batch[0][-1] + #batch = self.batch_to_device(batch) + + ### After to_device, then the original is: + with torch.autocast( device_type=f"cuda:{cf.local_rank}", dtype=self.mixed_precision_dtype, enabled=cf.with_mixed_precision, ): - output = self.model(self.model_params, batch, cf.forecast_offset, forecast_steps) - targets = {"physical": batch[0]} + outputs = [] + for view in batch[-1].source_samples: + # TODO remove when ModelBatch and Sample get a to_device() + streams_data = [[view.streams_data['ERA5']]] + streams_data = [[d.to_device(self.device) for d in db] for db in streams_data] + source_cell_lens = view.source_cell_lens + source_cell_lens = [b.to(self.device) for b in source_cell_lens] + target_coords_idxs = view.target_coords_idx + target_coords_idxs = [[b.to(self.device) for b in bf] for bf in target_coords_idxs] + outputs.append(self.model( + self.model_params, (streams_data, source_cell_lens, target_coords_idxs), cf.forecast_offset, forecast_steps + )) + + targets_and_auxs = [] + for view in batch[-1].target_samples: + # TODO remove when ModelBatch and Sample get a to_device() + streams_data = [[view.streams_data['ERA5']]] + streams_data = [[d.to_device(self.device) for d in db] for db in streams_data] + source_cell_lens = view.source_cell_lens + source_cell_lens = [b.to(self.device) for b in source_cell_lens] + target_coords_idxs = view.target_coords_idx + target_coords_idxs = [[b.to(self.device) for b in bf] for bf in target_coords_idxs] + targets_and_auxs.append(self.target_and_aux_calculator.compute( + self.cf.istep, + (streams_data, source_cell_lens, target_coords_idxs), + self.model_params, + self.model, + cf.forecast_offset, + forecast_steps, + )) + targets, aux = zip(*targets_and_auxs) loss, loss_values = self.loss_calculator.compute_loss( - preds=output, + preds=outputs, targets=targets, + view_metadata=(batch[-1].source2target_matching_idxs, + [sample.meta_info for sample in batch[-1].source_samples], + batch[-1].target2source_matching_idxs, + [sample.meta_info for sample in batch[-1].target_samples] + ), + ) + # TODO re-enable this, need to think on how to make it compatible with + # student-teacher training + # if cf.latent_noise_kl_weight > 0.0: + # kl = torch.cat([posterior.kl() for posterior in output.latent["posteriors"]]) + # loss_values.loss += cf.latent_noise_kl_weight * kl.mean() + + self.target_and_aux_calculator.update_state_pre_backward( + self.cf.istep, batch, self.model ) - if cf.latent_noise_kl_weight > 0.0: - kl = torch.cat([posterior.kl() for posterior in output.latent]) - loss += cf.latent_noise_kl_weight * kl.mean() + + self.target_and_aux_calculator.update_state_pre_backward(bidx, batch, self.model) # backward pass self.optimizer.zero_grad() From b7cfb21022261fa60128345e6fa6b90c103b4023 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Thu, 27 Nov 2025 15:47:41 +0100 Subject: [PATCH 110/344] move diff parameters to config, add noise weight calc in latent loss --- config/default_config.yml | 11 ++++++++ src/weathergen/model/diffusion.py | 27 +++++++++---------- src/weathergen/model/model.py | 4 ++- .../loss_module_latent_diffusion.py | 15 ++++++++++- 4 files changed, 40 insertions(+), 17 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 7450844ec..b25f83189 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -53,9 +53,20 @@ fe_dropout_rate: 0.1 fe_with_qk_lnorm: True fe_diffusion_model: True impute_latent_noise_std: 0.0 # 1e-4 +# Diffusion related parameters +frequency_embedding_dim: 256 +embedding_dim: 512 +sigma_min: 0.002 +sigma_max: 50000 +sigma_data: 0.5 +rho: 7 +p_mean: -1.2 +p_std: 1.2 +# Encoder weights chkpt_encoder_weights: "./models/whkujigw/whkujigw_epoch00063.chkpt" + healpix_level: 5 with_mixed_precision: True diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 9b34f8d48..0315b93ca 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -28,6 +28,7 @@ import torch +from weathergen.common.config import Config from weathergen.model.engines import ForecastingEngine @@ -62,29 +63,25 @@ class DiffusionForecastEngine(torch.nn.Module): def __init__( self, forecast_engine: ForecastingEngine, - frequency_embedding_dim: int = 256, # TODO: determine suitable dimension - embedding_dim: int = 512, # TODO: determine suitable dimension - sigma_min: float = 0.002, # Adapt to GenCast? - sigma_max: float = 80, - sigma_data: float = 0.5, - rho: float = 7, - p_mean: float = -1.2, - p_std: float = 1.2, + cf: Config, ): super().__init__() + self.cf = cf self.net = forecast_engine self.preconditioner = Preconditioner() + self.frequency_embedding_dim = self.cf.frequency_embedding_dim + self.embedding_dim = self.cf.embedding_dim self.noise_embedder = NoiseEmbedder( - embedding_dim=embedding_dim, frequency_embedding_dim=frequency_embedding_dim + embedding_dim=self.embedding_dim, frequency_embedding_dim=self.frequency_embedding_dim ) # Parameters - self.sigma_min = sigma_min - self.sigma_max = sigma_max - self.sigma_data = sigma_data - self.rho = rho - self.p_mean = p_mean - self.p_std = p_std + self.sigma_min = self.cf.sigma_min + self.sigma_max = self.cf.sigma_max + self.sigma_data = self.cf.sigma_data + self.rho = self.cf.rho + self.p_mean = self.cf.p_mean + self.p_std = self.cf.p_std def forward(self, tokens: torch.Tensor, fstep: int) -> torch.Tensor: """ diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 5a5cc7622..f541876bf 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -334,7 +334,9 @@ def create(self) -> "Model": self.forecast_engine = ForecastingEngine(cf, self.num_healpix_cells) if cf.fe_diffusion_model: - self.forecast_engine = DiffusionForecastEngine(forecast_engine=self.forecast_engine) + self.forecast_engine = DiffusionForecastEngine( + forecast_engine=self.forecast_engine, cf=cf + ) ############### # embed coordinates yielding one query token for each target token diff --git a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py index e72ddf346..b68147ca6 100644 --- a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py +++ b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py @@ -40,9 +40,18 @@ def __init__( self.device = device self.name = "LossLatentDiff" + self.sigma_data = self.cf.sigma_data + self.rho = self.cf.rho + self.p_mean = self.cf.p_mean + self.p_std = self.cf.p_std + # Dynamically load loss functions based on configuration and stage self.loss_fcts = [[getattr(losses, name), w, name] for name, w in loss_fcts] + def _get_noise_weight(self, eta): + sigma = (eta * self.p_std + self.p_mean).exp() + return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + def _get_fstep_weights(self, forecast_steps): timestep_weight_config = self.cf.get("timestep_weight") if timestep_weight_config is None: @@ -55,12 +64,13 @@ def _loss_per_loss_function( loss_fct, target: torch.Tensor, pred: torch.Tensor, + noise_weight: torch.Tensor = 1.0, ): """ Compute loss for given loss function """ - loss_val = loss_fct(target=target, mu=pred) + loss_val = noise_weight * loss_fct(target=target, mu=pred) return loss_val @@ -81,6 +91,8 @@ def compute_loss( targets = targets["latent"] fsteps = len(targets) + eta = torch.randn(1) + noise_weight = self._get_noise_weight(eta).to(device=preds[0].device) fstep_loss_weights = self._get_fstep_weights(fsteps) loss_fsteps = torch.tensor(0.0, device=self.device, requires_grad=True) @@ -96,6 +108,7 @@ def compute_loss( loss_fct, target=target, pred=pred, + noise_weight=noise_weight, ) losses_all[f"{self.name}.{loss_fct_name}"] += loss_lfct # TODO: break into fsteps From 3e4de7aa589595156ce8bb6c085724854ae57329 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 28 Nov 2025 08:07:51 +0100 Subject: [PATCH 111/344] Linting --- src/weathergen/train/trainer.py | 59 ++++++++++++++++++++------------- 1 file changed, 36 insertions(+), 23 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 935aac962..e21e84cf0 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -494,7 +494,6 @@ def train(self, mini_epoch): # training loop self.t_start = time.time() for bidx, batch in enumerate(dataset_iter): - # NOTE: we are still returning legacy batch structure and the new batch together. # Julian and Matthias: @@ -505,13 +504,17 @@ def train(self, mini_epoch): # batch[-1].meta_info is a dictionary with metadata info per sample # batch[-1].meta_info["ERA5"] etc. # here we have the noise_level_rn - # batch[-1].source_samples[0].meta_info["ERA5"].noise_level_rn == batch[-1].target_samples[0].meta_info["ERA5"].noise_level_rn - # for the same timestep, this needs to be fixed for when we have more source timesteps, and perhaps with bigger batch sizes? + # batch[-1].source_samples[0].meta_info["ERA5"].noise_level_rn + # == batch[-1].target_samples[0].meta_info["ERA5"].noise_level_rn + # for the same timestep, this needs to be fixed for when we have more source timesteps, + # and perhaps with bigger batch sizes? # Each Sample object has: # .streams_data: a dictionary of StreamData objects per stream name - # .source_cell_lens: list of tensors with lengths of source cells per stream # to be changed to be in ModelBatch - # .target_coords_idx: list of tensors with target coordinate indices per stream # to be changed to be in ModelBatch - + # .source_cell_lens: list of tensors with lengths of source cells per stream # to be + # changed to be in ModelBatch + # .target_coords_idx: list of tensors with target coordinate indices per stream # to + # be changed to be in ModelBatch + ###### Legacy batch after batch.to_device: # (Pdb++) batch[0] # [[]] @@ -522,7 +525,7 @@ def train(self, mini_epoch): # TODO: access from new ModelBatch forecast_steps = batch[0][-1] - #batch = self.batch_to_device(batch) + # batch = self.batch_to_device(batch) ### After to_device, then the original is: @@ -534,33 +537,44 @@ def train(self, mini_epoch): outputs = [] for view in batch[-1].source_samples: # TODO remove when ModelBatch and Sample get a to_device() - streams_data = [[view.streams_data['ERA5']]] + streams_data = [[view.streams_data["ERA5"]]] streams_data = [[d.to_device(self.device) for d in db] for db in streams_data] source_cell_lens = view.source_cell_lens source_cell_lens = [b.to(self.device) for b in source_cell_lens] target_coords_idxs = view.target_coords_idx - target_coords_idxs = [[b.to(self.device) for b in bf] for bf in target_coords_idxs] - outputs.append(self.model( - self.model_params, (streams_data, source_cell_lens, target_coords_idxs), cf.forecast_offset, forecast_steps - )) + target_coords_idxs = [ + [b.to(self.device) for b in bf] for bf in target_coords_idxs + ] + outputs.append( + self.model( + self.model_params, + (streams_data, source_cell_lens, target_coords_idxs), + cf.forecast_offset, + forecast_steps, + ) + ) targets_and_auxs = [] for view in batch[-1].target_samples: # TODO remove when ModelBatch and Sample get a to_device() - streams_data = [[view.streams_data['ERA5']]] + streams_data = [[view.streams_data["ERA5"]]] streams_data = [[d.to_device(self.device) for d in db] for db in streams_data] source_cell_lens = view.source_cell_lens source_cell_lens = [b.to(self.device) for b in source_cell_lens] target_coords_idxs = view.target_coords_idx - target_coords_idxs = [[b.to(self.device) for b in bf] for bf in target_coords_idxs] - targets_and_auxs.append(self.target_and_aux_calculator.compute( - self.cf.istep, - (streams_data, source_cell_lens, target_coords_idxs), - self.model_params, - self.model, - cf.forecast_offset, - forecast_steps, - )) + target_coords_idxs = [ + [b.to(self.device) for b in bf] for bf in target_coords_idxs + ] + targets_and_auxs.append( + self.target_and_aux_calculator.compute( + self.cf.istep, + (streams_data, source_cell_lens, target_coords_idxs), + self.model_params, + self.model, + cf.forecast_offset, + forecast_steps, + ) + ) # targets, aux = zip(*targets_and_auxs) loss, loss_values = self.loss_calculator.compute_loss( preds=outputs[0], @@ -583,7 +597,6 @@ def train(self, mini_epoch): # targets=target_aux_output, # ) - # TODO re-enable this, need to think on how to make it compatible with # TODO: CL, this should become a regular loss term # student-teacher training From 8ef3a4c7607c546a7094caf1285b4dc6503b7341 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 28 Nov 2025 08:08:04 +0100 Subject: [PATCH 112/344] Simplified and clarified handling of default target_aux_calcualtor --- src/weathergen/model/model_interface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index fb269c71e..6260439e8 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -264,8 +264,8 @@ def get_target_aux_calculator(cf: Config, dataset, model, device, **kwargs): target_aux = None - target_and_aux_calc = cf.get("target_and_aux_calc", None) - if target_and_aux_calc is None or target_and_aux_calc == "identity": + target_and_aux_calc = cf.get("target_and_aux_calc", "physical") + if target_and_aux_calc == "physical": target_aux = PhysicalTargetAndAux(cf, model) elif target_and_aux_calc == "EMATeacher": From d8998a98e072507cb081ea7ee5facbf635d84c4a Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 28 Nov 2025 08:08:38 +0100 Subject: [PATCH 113/344] Linting --- src/weathergen/model/engines.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 8d42c109a..45b98c135 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -94,13 +94,12 @@ def forward(self, streams_data, source_cell_lens, pe_embed, dtype, device): # TODO: what is this list dimension??? Where should the istep index be??? for _, sb in enumerate(streams_data): for stream_name, s_data in zip(self.stream_names, sb, strict=True): - # embedding network embed = self.embeds[stream_name] - + # skip empty stream if not s_data.source_empty(): - continue + continue idxs = s_data.source_idxs_embed.to(device) idxs_pe = s_data.source_idxs_embed_pe.to(device) From 652500afefb020d256f290e31a9377869ccf56a2 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 28 Nov 2025 08:08:53 +0100 Subject: [PATCH 114/344] Linting --- src/weathergen/datasets/tokenizer_masking.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 72a308a49..7c6bb9071 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -105,7 +105,8 @@ def get_source( ) else: (mask_tokens, mask_channels) = self.masker.mask_source_idxs( - idxs_cells, idxs_cells_lens, + idxs_cells, + idxs_cells_lens, ) source_tokens_cells, source_tokens_lens = tokenize_apply_mask_source( @@ -152,7 +153,8 @@ def get_target( self.masker.mask_channels = mask_state.get("mask_channels") (mask_tokens, mask_channels, idxs_ord_inv) = self.masker.mask_targets_idxs( - idxs_cells, idxs_cells_lens, + idxs_cells, + idxs_cells_lens, ) data, datetimes, coords, coords_local, coords_per_cell = tokenize_apply_mask_target( @@ -193,7 +195,8 @@ def get_target_coords( self.masker.mask_channels = mask_state.get("mask_channels") (mask_tokens, mask_channels, idxs_ord_inv) = self.masker.mask_targets_idxs( - idxs_cells, idxs_cells_lens, + idxs_cells, + idxs_cells_lens, ) # TODO: split up @@ -259,7 +262,8 @@ def get_target_values( self.masker.mask_channels = mask_state.get("mask_channels") (mask_tokens, mask_channels, idxs_ord_inv) = self.masker.mask_targets_idxs( - idxs_cells, idxs_cells_lens, + idxs_cells, + idxs_cells_lens, ) data, datetimes, coords, _, _ = tokenize_apply_mask_target( From 03166a202e9af571a93160b5d5fc1f65e6aa9b5d Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 28 Nov 2025 08:09:10 +0100 Subject: [PATCH 115/344] Linting --- .../datasets/multi_stream_data_sampler.py | 85 ++++++++++++------- 1 file changed, 52 insertions(+), 33 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 4f9fdc661..086fb9c08 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -68,7 +68,6 @@ def collect_datasources(stream_datasets: list, idx: int, type: str) -> IOReaderD class MultiStreamDataSampler(torch.utils.data.IterableDataset): - def __init__( self, cf, @@ -342,8 +341,8 @@ def _build_stream_data_input( # collect all targets for current stream # do we want this to be ascending or descending in time? - rdata = input_data[-(step+1)] - token_data = input_tokens[-(step+1)] + rdata = input_data[-(step + 1)] + token_data = input_tokens[-(step + 1)] stream_data.source_is_spoof = rdata.is_spoof @@ -376,8 +375,7 @@ def _build_stream_data_output( output_tokens: list, mask_state: dict | None = None, ) -> StreamData: - """ - """ + """ """ # collect for all forecast steps dt = self.forecast_offset + forecast_dt @@ -435,7 +433,7 @@ def _build_stream_data( Build a StreamData object for a single view (teacher or student). Args: - mode : + mode : stream_data : base_idx: Time index for this sample forecast_dt: Number of forecast steps @@ -534,26 +532,35 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): TODO: these modes are not being used now. """ - if mode == "student_teacher": - + if mode == "masking" or mode == "student_teacher": streams_data: list[StreamData] = [] # get/coordinate masks masks_streams = self._get_source_target_masks() # Determine number of views direct from config (teacher & student views) - teacher_cfg = self.training_cfg.get("teacher_model_input", {}) if self.training_cfg else {} + teacher_cfg = ( + self.training_cfg.get("teacher_model_input", {}) if self.training_cfg else {} + ) student_cfg = self.training_cfg.get("model_input", {}) if self.training_cfg else {} num_target_samples = int(teacher_cfg.get("num_views", 1)) - num_source_samples = int(teacher_cfg.get("num_views", 1)) * int(student_cfg.get("num_views", 1)) # per teacher - + num_source_samples = int(teacher_cfg.get("num_views", 1)) * int( + student_cfg.get("num_views", 1) + ) # per teacher + batch = ModelBatch(self.streams, num_source_samples, num_target_samples) # for all streams for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): name = stream_info["name"] - (target_masks, source_masks, student_to_teacher, target_metadata_list, source_metadata_list) = masks_streams[name] + ( + target_masks, + source_masks, + student_to_teacher, + target_metadata_list, + source_metadata_list, + ) = masks_streams[name] # input_data and output_data is conceptually consecutive but differs # in source and target channels; overlap in one window when self.forecast_offset=0 @@ -583,7 +590,6 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): stream_data_source[name] = sdata - # source meta info... # source_meta_info = SampleMetaData(... @@ -592,14 +598,13 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # also want to add the mask to the metadata source_metadata.mask = mask - # TODO: seb check this + # TODO: seb check this # Map each student (source) to its teacher (target) t_idx = student_to_teacher[sidx] batch.add_source_stream(sidx, t_idx, name, sdata, source_metadata) # num_input_steps? batch.source_samples[sidx].set_forecast_dt(forecast_dt) - # stream_data_target can contain network input stream_data_target = {} @@ -617,7 +622,7 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): mask, ) stream_data_target[name] = sdata - + # get teacher config info target_metadata = target_metadata_list[t_idx] @@ -626,34 +631,39 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # TODO: seb to check # Map target to all source students - student_indices = [s_idx for s_idx, tid in enumerate(student_to_teacher) if tid == t_idx] + student_indices = [ + s_idx for s_idx, tid in enumerate(student_to_teacher) if tid == t_idx + ] batch.add_target_stream(t_idx, student_indices, name, sdata, target_metadata) batch.target_samples[t_idx].set_forecast_dt(forecast_dt) - + # TODO: build batch # source_input # target_input # source_output # target_output + # TOOD: remove # add data for current stream streams_data += [v for k, v in stream_data_source.items()] elif mode == "diffusion_forecast": - streams_data: list[StreamData] = [] # get/coordinate masks masks_streams = self._get_source_target_masks() # Determine number of views direct from config (teacher & student views) - teacher_cfg = self.training_cfg.get("teacher_model_input", {}) if self.training_cfg else {} + teacher_cfg = ( + self.training_cfg.get("teacher_model_input", {}) if self.training_cfg else {} + ) student_cfg = self.training_cfg.get("model_input", {}) if self.training_cfg else {} num_target_samples = int(teacher_cfg.get("num_views", 1)) - num_source_samples = int(teacher_cfg.get("num_views", 1)) * int(student_cfg.get("num_views", 1)) # per teacher - - batch = ModelBatch(self.streams, num_source_samples, num_target_samples) + num_source_samples = int(teacher_cfg.get("num_views", 1)) * int( + student_cfg.get("num_views", 1) + ) # per teacher + batch = ModelBatch(self.streams, num_source_samples, num_target_samples) # for all streams for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): @@ -674,7 +684,7 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # collect source data for current stream # loop over student views stream_data_source = {} - # stream_data_source[name] = self._build_stream_data( + # stream_data_source[name] = self._build_stream_data( sdata = self._build_stream_data( "target_coords target_values", idx, @@ -715,7 +725,7 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): mask=None, ) stream_data_target[name] = sdata - + # get teacher config info target_metadata = target_metadata @@ -725,16 +735,20 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # Map target to all source students batch.add_target_stream(0, 0, name, sdata, target_metadata) batch.target_samples[0].set_forecast_dt(forecast_dt) - + # TODO: build batch # source_input # target_input # source_output # target_output + # TOOD: remove # add data for current stream streams_data += [v for k, v in stream_data_source.items()] + else: + assert False, "Mode not implemented" + return streams_data, batch def _get_source_target_masks(self): @@ -766,7 +780,7 @@ def to_bool_tensor(arr): source_metadata: list[SampleMetaData] = [] # add a loop over num_teacher_views, generate students for each teacher - for t_idx in range(num_teacher_views): + for _ in range(num_teacher_views): # Build one teacher and its student views t_keep_np, s_keeps_np, metadata = self.tokenizer.masker.build_views_for_stream( self.num_healpix_cells, @@ -781,14 +795,20 @@ def to_bool_tensor(arr): target_metadata.append(metadata[0]) # TODO: first is teacher # this teacher's students and mapping - for s_np, metadata in zip(s_keeps_np or [], metadata[1:], strict=True): + for s_np, md in zip(s_keeps_np or [], metadata[1:], strict=True): source_masks.append(to_bool_tensor(s_np)) # append 0, 1, ... depending on which teacher we did - source_metadata.append(metadata) + source_metadata.append(md) student_to_teacher.append(len(target_masks) - 1) - masks[stream_info["name"]] = (target_masks, source_masks, student_to_teacher, target_metadata, source_metadata) - + masks[stream_info["name"]] = ( + target_masks, + source_masks, + student_to_teacher, + target_metadata, + source_metadata, + ) + return masks def _preprocess_model_data(self, batch, forecast_dt): @@ -802,7 +822,7 @@ def _preprocess_model_data(self, batch, forecast_dt): # compute offsets and auxiliary data needed for prediction computation # (info is not per stream so separate data structure) - + ##### target_coords_idx we probably don't need for the targets ##### target_coords_idx = compute_idxs_predict(self.forecast_offset + forecast_dt, batch) @@ -874,7 +894,6 @@ def __iter__(self): # TODO: link into ModelBatch - # compute batch, source_cell_lens, target_coords_idx = self._preprocess_model_data( batch, forecast_dt From e41a5751d8ff6fde35d548aa80bba996e9133e2d Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 28 Nov 2025 08:09:28 +0100 Subject: [PATCH 116/344] Linting --- src/weathergen/datasets/masking.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index 332bba688..885e98527 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -1,5 +1,4 @@ import logging -from typing import List, Tuple import numpy as np import torch @@ -142,7 +141,7 @@ def mask_source_idxs( idxs_cells, idxs_cells_lens, keep_mask: np.typing.NDArray | None = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: """ Return: @@ -214,7 +213,7 @@ def mask_targets_idxs( self, idxs_cells, idxs_cells_lens, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # mask_source_idxs is assert (self.mask_tokens is not None) or (self.mask_tokens is not None) idxs_ord_inv = torch.tensor([], dtype=torch.int64) @@ -630,7 +629,7 @@ def build_views_for_stream( teacher_cfg: dict, student_cfg: dict, relationship: str = "subset", - ) -> Tuple[np.typing.NDArray, List[np.typing.NDArray], List[SampleMetaData]]: + ) -> tuple[np.typing.NDArray, list[np.typing.NDArray], list[SampleMetaData]]: """ Construct teacher/student keep masks for a stream. SampleMetaData is currently just a dict with the masking params used. @@ -652,7 +651,7 @@ def build_views_for_stream( rate_student = student_cfg.get("rate") s_cfg_extra = student_cfg.get("masking_strategy_config") - student_keep_masks: List[np.ndarray] = [] + student_keep_masks: list[np.typing.NDArray] = [] for _ in range(num_views): base = self.generate_cell_keep_mask( num_cells=num_cells, @@ -668,12 +667,12 @@ def build_views_for_stream( keep = base student_keep_masks.append(keep) - metadata: List[SampleMetaData] = [ + metadata: list[SampleMetaData] = [ SampleMetaData( masking_params=teacher_cfg, ) ] - for idx, mask in enumerate(student_keep_masks): + for _, _ in enumerate(student_keep_masks): metadata.append( SampleMetaData( masking_params=student_cfg, From 0db8b6242f82c191b9055a2dfa406deeac5cd4da Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 28 Nov 2025 08:09:41 +0100 Subject: [PATCH 117/344] Linting --- src/weathergen/datasets/batch.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index 55674f81c..25f8fc3aa 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -18,6 +18,7 @@ # TODO: GetTimestep to get the timestep # TODO: GetMetaData: then this gets the right rn for the timestep! + @dataclass class SampleMetaData: # masking strategy @@ -30,6 +31,7 @@ class SampleMetaData: noise_level_rn: float | None = None + class Sample: # keys: stream name, values: SampleMetaData meta_info: dict @@ -93,6 +95,7 @@ def get_stream_data(self, stream_name: str) -> StreamData: assert self.streams_data.get(stream_name, -1) != -1, "stream name does not exist" return self.streams_data[stream_name] + class ModelBatch: """ Container for all data and metadata for one training batch. @@ -135,7 +138,6 @@ def add_source_stream( # add the meta_info self.source_samples[source_sample_idx].add_meta_info(stream_name, source_meta_info) - assert target_sample_idx < len(self.target_samples), "invalid value for target_sample_idx" self.source2target_matching_idxs[source_sample_idx] = target_sample_idx @@ -157,9 +159,13 @@ def add_target_stream( self.target_samples[target_sample_idx].add_meta_info(stream_name, target_meta_info) if isinstance(source_sample_idx, int): - assert source_sample_idx < len(self.source_samples), "invalid value for source_sample_idx" + assert source_sample_idx < len(self.source_samples), ( + "invalid value for source_sample_idx" + ) else: - assert all(idx < len(self.source_samples) for idx in source_sample_idx), "invalid value for source_sample_idx" + assert all(idx < len(self.source_samples) for idx in source_sample_idx), ( + "invalid value for source_sample_idx" + ) self.target2source_matching_idxs[target_sample_idx] = source_sample_idx def len_sources(self) -> int: From 47750a5b2503d2706e36f7d2db593d44ac7f7c75 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 28 Nov 2025 08:10:09 +0100 Subject: [PATCH 118/344] Restoring masking as training_mode in default_config --- config/default_config.yml | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index a27ddcdce..fbc68224f 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -116,20 +116,19 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"], # TODO: adapt so that the masking or forecast config entry also sits here training_config: # when this is "masking", we are basically only using the model_input subconfig - training_mode: "student_teacher" # "masking", "student_teacher", "forecast" - + training_mode: "masking" # "masking", "student_teacher", "forecast" model_input: masking_strategy: "healpix" # "random", "healpix". Masking strategy to use for model input for masking, and local (student) views when doing student-teacher rate: 0.4 # Masking rate to use for model input - num_views: 4 # if student-teacher, the number of local (student) views to generate + num_views: 1 # if student-teacher, the number of local (student) views to generate hl_mask : 4 # healpix level to use for healpix masking strategy relationship: "subset" # "independent", "subset", "disjoint". Relationship of student views to teacher view. teacher_model_input: strategy: "healpix" # Strategy for teacher (global) view: "random", "healpix" - rate: 0.8 # Fraction of data to keep in global view (alternative: use "keep_m" for absolute count) - num_views: 2 # number of teacher views to generate + rate: 0.4 # Fraction of data to keep in global view (alternative: use "keep_m" for absolute count) + num_views: 1 # number of teacher views to generate hl_mask : 0 # healpix level to use for healpix masking strategy # keep_m: 100 # Alternative to rate: keep exactly this many parent cells rate_sampling: true # randomly sample the rate per batch From bc8d23e2f332590fa215b3e3d81fa81333f50161 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 28 Nov 2025 08:18:01 +0100 Subject: [PATCH 119/344] More linting --- src/weathergen/train/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index e21e84cf0..53e66e21f 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -504,13 +504,13 @@ def train(self, mini_epoch): # batch[-1].meta_info is a dictionary with metadata info per sample # batch[-1].meta_info["ERA5"] etc. # here we have the noise_level_rn - # batch[-1].source_samples[0].meta_info["ERA5"].noise_level_rn + # batch[-1].source_samples[0].meta_info["ERA5"].noise_level_rn # == batch[-1].target_samples[0].meta_info["ERA5"].noise_level_rn - # for the same timestep, this needs to be fixed for when we have more source timesteps, + # for the same timestep, this needs to be fixed for when we have more source timesteps, # and perhaps with bigger batch sizes? # Each Sample object has: # .streams_data: a dictionary of StreamData objects per stream name - # .source_cell_lens: list of tensors with lengths of source cells per stream # to be + # .source_cell_lens: list of tensors with lengths of source cells per stream # to be # changed to be in ModelBatch # .target_coords_idx: list of tensors with target coordinate indices per stream # to # be changed to be in ModelBatch From 62899591a31e405a8502dd68547a3dd51693d1d8 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 28 Nov 2025 08:36:38 +0100 Subject: [PATCH 120/344] Removed duplicate lines due to mergeing --- src/weathergen/train/trainer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 53e66e21f..b8078bff2 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -608,10 +608,6 @@ def train(self, mini_epoch): self.cf.istep, batch, self.model ) - self.target_and_aux_calculator.update_state_pre_backward(bidx, batch, self.model) - - self.target_and_aux_calculator.update_state_pre_backward(bidx, batch, self.model) - # backward pass self.optimizer.zero_grad() self.grad_scaler.scale(loss).backward() From d526dfca0c6845ffb80a22da49ab85281e1b86ff Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 28 Nov 2025 08:37:02 +0100 Subject: [PATCH 121/344] Restored masking as training mode. Not working due to NaN in prediction --- .../datasets/multi_stream_data_sampler.py | 118 +++++++++++++++++- 1 file changed, 117 insertions(+), 1 deletion(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 086fb9c08..4f89b6727 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -532,7 +532,123 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): TODO: these modes are not being used now. """ - if mode == "masking" or mode == "student_teacher": + if mode == "masking": + streams_data: list[StreamData] = [] + + # get/coordinate masks + masks_streams = self._get_source_target_masks() + + # Determine number of views direct from config (teacher & student views) + teacher_cfg = ( + self.training_cfg.get("teacher_model_input", {}) if self.training_cfg else {} + ) + student_cfg = self.training_cfg.get("model_input", {}) if self.training_cfg else {} + num_target_samples = int(teacher_cfg.get("num_views", 1)) + num_source_samples = int(teacher_cfg.get("num_views", 1)) * int( + student_cfg.get("num_views", 1) + ) # per teacher + + batch = ModelBatch(self.streams, num_source_samples, num_target_samples) + + # for all streams + for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): + name = stream_info["name"] + + # TODO: data class for this or something similar + ( + target_masks, + source_masks, + student_to_teacher, + target_metadata_list, + source_metadata_list, + ) = masks_streams[name] + + # input_data and output_data is conceptually consecutive but differs + # in source and target channels; overlap in one window when self.forecast_offset=0 + (input_data, output_data) = self._get_data_windows(idx, forecast_dt, stream_ds) + + # tokenize windows + # *_tokens = [ (cells_idx, cells_idx_lens), ... ] with length = #time_steps + input_tokens = self.tokenizer.get_tokens_windows(stream_info, input_data, True) + output_tokens = self.tokenizer.get_tokens_windows(stream_info, output_data, False) + + # collect source data for current stream + # loop over student views + stream_data_source = {} + for sidx, mask in enumerate(source_masks): + # stream_data_source[name] = self._build_stream_data( + sdata = self._build_stream_data( + "target_coords target_values", + idx, + forecast_dt, + stream_info, + input_data, + output_data, + input_tokens, + output_tokens, + mask, + ) + + stream_data_source[name] = sdata + + # source meta info... + # source_meta_info = SampleMetaData(... + + source_metadata = source_metadata_list[sidx] # first is teacher + + # also want to add the mask to the metadata + source_metadata.mask = mask + + # TODO: seb check this + # Map each student (source) to its teacher (target) + t_idx = student_to_teacher[sidx] + batch.add_source_stream(sidx, t_idx, name, sdata, source_metadata) + # num_input_steps? + batch.source_samples[sidx].set_forecast_dt(forecast_dt) + + # stream_data_target can contain network input + stream_data_target = {} + + for t_idx, mask in enumerate(source_masks): + # stream_data_target[name] = self._build_stream_data( + sdata = self._build_stream_data( + "target_values", + idx, + forecast_dt, + stream_info, + input_data, + output_data, + input_tokens, + output_tokens, + mask, + ) + stream_data_target[name] = sdata + + # get teacher config info + target_metadata = target_metadata_list[t_idx] + + # also want to add the mask to the metadata + target_metadata.mask = mask + + # TODO: seb to check + # Map target to all source students + student_indices = [ + s_idx for s_idx, tid in enumerate(student_to_teacher) if tid == t_idx + ] + batch.add_target_stream(t_idx, student_indices, name, sdata, target_metadata) + batch.target_samples[t_idx].set_forecast_dt(forecast_dt) + + # TODO: build batch + # source_input + # target_input + # source_output + # target_output + + # TOOD: remove + # add data for current stream + streams_data += [v for k, v in stream_data_source.items()] + + elif mode == "student_teacher": streams_data: list[StreamData] = [] # get/coordinate masks From 657094a2a5ec0466ab4c9b7c24487574df9d18fa Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 28 Nov 2025 08:59:39 +0100 Subject: [PATCH 122/344] Fixed problem in engines introduced in recent commits merging develop. This fixes masking training --- src/weathergen/model/engines.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 45b98c135..8a26c492b 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -89,7 +89,8 @@ def forward(self, streams_data, source_cell_lens, pe_embed, dtype, device): for ob in offsets_base ] - # iterate over all forecast steps and + # TODO: handling of input steps should be done using encoder + # iterate over all input steps and streams for istep in range(num_step_input): # TODO: what is this list dimension??? Where should the istep index be??? for _, sb in enumerate(streams_data): @@ -98,11 +99,11 @@ def forward(self, streams_data, source_cell_lens, pe_embed, dtype, device): embed = self.embeds[stream_name] # skip empty stream - if not s_data.source_empty(): + if s_data.source_empty(): continue - idxs = s_data.source_idxs_embed.to(device) - idxs_pe = s_data.source_idxs_embed_pe.to(device) + idxs = s_data.source_idxs_embed[istep].to(device) + idxs_pe = s_data.source_idxs_embed_pe[istep].to(device) # create full scatter index # (there's no broadcasting which is likely highly inefficient) From 1a37dd1ff76aaae1b262618b5d773f82023be82b Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Fri, 28 Nov 2025 10:31:43 +0100 Subject: [PATCH 123/344] remove unused mask generation in diffusion_forecast --- src/weathergen/datasets/multi_stream_data_sampler.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 4f89b6727..84729a585 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -766,9 +766,6 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): elif mode == "diffusion_forecast": streams_data: list[StreamData] = [] - # get/coordinate masks - masks_streams = self._get_source_target_masks() - # Determine number of views direct from config (teacher & student views) teacher_cfg = ( self.training_cfg.get("teacher_model_input", {}) if self.training_cfg else {} From 0d44f40a1aeedb998848bb0730749473599eb441 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Fri, 28 Nov 2025 10:50:59 +0100 Subject: [PATCH 124/344] remove duplicate key in config --- config/default_config.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/config/default_config.yml b/config/default_config.yml index f2ade2dd1..4def74c0e 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -162,7 +162,6 @@ samples_per_mini_epoch: 4096 samples_per_validation: 512 shuffle: True -mixed_precision_dtype: bf16 lr_scaling_policy: "sqrt" lr_start: 1e-6 From caadb379540390e77d6d77190792e65f48616b58 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Fri, 28 Nov 2025 10:54:47 +0100 Subject: [PATCH 125/344] add back masking_rate dog --- config/default_config.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/config/default_config.yml b/config/default_config.yml index 4def74c0e..2fe50d3bd 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -117,6 +117,8 @@ validation_mode_config: { "shared_heads": False, "teacher_model": {} } +# masking +masking_strategy: "dog" # obviously TODO # masking rate when training mode is "masking"; ignored in foreacast mode masking_rate: 0.6 # From 680f577e5c1f1636e240a99a1168f2f3746e7709 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Fri, 28 Nov 2025 11:22:29 +0100 Subject: [PATCH 126/344] update config with new training mode --- config/default_config.yml | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 2fe50d3bd..edc176c3b 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -140,22 +140,16 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"], # TODO: adapt so that the masking or forecast config entry also sits here training_config: # when this is "masking", we are basically only using the model_input subconfig - training_mode: "masking" # "masking", "student_teacher", "forecast" + training_mode: "diffusion_forecast" # "masking", "student_teacher", "forecast" model_input: - masking_strategy: "healpix" # "random", "healpix". Masking strategy to use for model input for masking, and local (student) views when doing student-teacher - rate: 0.4 # Masking rate to use for model input - num_views: 1 # if student-teacher, the number of local (student) views to generate - hl_mask : 4 # healpix level to use for healpix masking strategy - relationship: "subset" # "independent", "subset", "disjoint". Relationship of student views to teacher view. - + num_input_steps: 1 # fake, read from standard cd.num_input_steps, default 1 # NOTE: only 1 works for now + loss: "LatentDiffusionLoss" # place holder + teacher_model_input: - strategy: "healpix" # Strategy for teacher (global) view: "random", "healpix" - rate: 0.4 # Fraction of data to keep in global view (alternative: use "keep_m" for absolute count) - num_views: 1 # number of teacher views to generate - hl_mask : 0 # healpix level to use for healpix masking strategy - # keep_m: 100 # Alternative to rate: keep exactly this many parent cells - rate_sampling: true # randomly sample the rate per batch + forecast_offset: 0 # fake, still read from usual place + num_forecast_steps: 1 # fake, still read from usual place + loss: "LatentDiffusionLoss" # placeholder From bb717313dc3c99d93fec8df77006d28d701f72f1 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Fri, 28 Nov 2025 12:25:35 +0100 Subject: [PATCH 127/344] forecast_diffusion running with new data batch --- config/default_config.yml | 2 +- .../datasets/multi_stream_data_sampler.py | 4 +-- src/weathergen/model/diffusion.py | 33 +++---------------- src/weathergen/model/engines.py | 13 +++----- .../loss_module_latent_diffusion.py | 18 +++++----- .../train/target_and_aux_diffusion.py | 7 ++-- 6 files changed, 28 insertions(+), 49 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index edc176c3b..bfc97c34b 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -118,7 +118,7 @@ validation_mode_config: { "teacher_model": {} } # masking -masking_strategy: "dog" # obviously TODO +masking_strategy: "random" # obviously TODO # masking rate when training mode is "masking"; ignored in foreacast mode masking_rate: 0.6 # diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 84729a585..b1a4049f3 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -859,8 +859,8 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # add data for current stream streams_data += [v for k, v in stream_data_source.items()] - else: - assert False, "Mode not implemented" + else: + assert False, "Mode not implemented" return streams_data, batch diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index c9a969538..b64ea6fd5 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -23,7 +23,6 @@ # ---------------------------------------------------------------------------- -import dataclasses import math import torch @@ -32,31 +31,6 @@ from weathergen.model.engines import ForecastingEngine -@dataclasses.dataclass -class BatchData: - """ - Mock function for the data that will be provided to the diffusion model. Will change. - """ - - model_samples: dict - target_samples: dict - - def get_sample_len(self): - return len(list(self.model_samples.keys())) - - def get_input_data(self, t: int): - return self.model_samples[t]["data"] - - def get_input_metadata(self, t: int): - return self.model_samples[t]["metadata"] - - def get_target_data(self, t: int): - return self.target_samples[t]["data"] - - def get_target_metadata(self, t: int): - return self.target_samples[t]["metadata"] - - class DiffusionForecastEngine(torch.nn.Module): # Adopted from https://github.com/NVlabs/edm/blob/main/training/loss.py#L72 @@ -96,14 +70,17 @@ def forward(self, tokens: torch.Tensor, fstep: int, metadata: dict) -> torch.Ten # y = data.get_input_data(-1) # eta = data.get_input_metadata(-1) - c = 1 + c = 1 # TODO: add correct preconditioning (e.g., sample/s in previous time step) y = tokens - eta = metadata.noise_level_rn.to(device=tokens.device) + eta = torch.tensor([metadata.noise_level_rn], device=tokens.device) + # eta = torch.randn(1).to(device=tokens.device) + # eta = torch.tensor([metadata.noise_level_rn]).to(device=tokens.device) # Compute sigma (noise level) from eta # noise = torch.randn(y.shape, device=y.device) # now eta from MultiStreamDataSampler sigma = (eta * self.p_std + self.p_mean).exp() n = torch.randn_like(y) * sigma + return self.denoise(x=y + n, c=c, sigma=sigma, fstep=fstep) # Compute loss -- move this to a separate loss calculator diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 7a3110a7f..d51fc65cf 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -342,7 +342,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: with_qk_lnorm=self.cf.fe_with_qk_lnorm, with_flash=self.cf.with_flash_attention, norm_type=self.cf.norm_type, - dim_aux=(1 if cf.forecast_with_step_conditioning else 0), + dim_aux=(1 if cf.forecast_with_step_conditioning else None), norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), with_noise_conditioning=self.cf.fe_diffusion_model, @@ -359,7 +359,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: with_qk_lnorm=self.cf.fe_with_qk_lnorm, with_flash=self.cf.with_flash_attention, norm_type=self.cf.norm_type, - dim_aux=(1 if cf.forecast_with_step_conditioning else 0), + dim_aux=(1 if cf.forecast_with_step_conditioning else None), norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), with_noise_conditioning=self.cf.fe_diffusion_model, @@ -373,16 +373,12 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: with_residual=True, dropout_rate=self.cf.fe_dropout_rate, norm_type=self.cf.norm_type, - dim_aux=1, + dim_aux=(1 if cf.forecast_with_step_conditioning else None), norm_eps=self.cf.mlp_norm_eps, with_noise_conditioning=self.cf.fe_diffusion_model, ) ) - self.fe_blocks.append( - torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) - ) - def init_weights_final(m): if isinstance(m, torch.nn.Linear): torch.nn.init.normal_(m.weight, mean=0, std=0.001) @@ -399,7 +395,8 @@ def forward(self, tokens, fstep, noise_emb=None): tokens_in = tokens # aux_info is forecast step, if not disabled with cf.forecast_with_step_conditioning - aux_info = torch.tensor([fstep], dtype=torch.float32, device="cuda") + # aux_info = torch.tensor([fstep], dtype=torch.float32, device="cuda") + aux_info = None if self.cf.fe_diffusion_model: assert noise_emb is not None, ( "Noise embedding must be provided for diffusion forecast engine" diff --git a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py index b68147ca6..a233835c2 100644 --- a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py +++ b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py @@ -87,17 +87,19 @@ def compute_loss( for _, _, loss_fct_name in self.loss_fcts } - preds = preds.latent["preds"] - targets = targets["latent"] - fsteps = len(targets) + pred_tokens_all = preds.latent["preds"] + target_tokens_all = targets.latent + eta = torch.tensor([targets.aux_outputs["noise_level_rn"]], device=self.device) + fsteps = len(target_tokens_all) - eta = torch.randn(1) - noise_weight = self._get_noise_weight(eta).to(device=preds[0].device) + noise_weight = self._get_noise_weight(eta) fstep_loss_weights = self._get_fstep_weights(fsteps) loss_fsteps = torch.tensor(0.0, device=self.device, requires_grad=True) ctr_fsteps = 0 - for target, pred, fstep_loss_weight in zip(targets, preds, fstep_loss_weights, strict=True): + for target_tokens, pred_tokens, fstep_loss_weight in zip( + target_tokens_all, pred_tokens_all, fstep_loss_weights, strict=True + ): # the first entry in tokens_all is the source itself, so skip it loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) ctr_loss_fcts = 0 @@ -106,8 +108,8 @@ def compute_loss( for loss_fct, loss_fct_weight, loss_fct_name in self.loss_fcts: loss_lfct = self._loss_per_loss_function( loss_fct, - target=target, - pred=pred, + target=target_tokens, + pred=pred_tokens, noise_weight=noise_weight, ) diff --git a/src/weathergen/train/target_and_aux_diffusion.py b/src/weathergen/train/target_and_aux_diffusion.py index 569dd6920..f86a2c71d 100644 --- a/src/weathergen/train/target_and_aux_diffusion.py +++ b/src/weathergen/train/target_and_aux_diffusion.py @@ -2,7 +2,7 @@ import torch -from weathergen.train.target_and_aux_module_base import TargetAndAuxModuleBase +from weathergen.train.target_and_aux_module_base import TargetAndAuxModuleBase, TargetAuxOutput class DiffusionLatentTargetEncoder(TargetAndAuxModuleBase): @@ -13,6 +13,7 @@ def __init__(self, model): def compute( self, bidx, batch, model_params, model, forecast_offset, forecast_steps ) -> tuple[Any, Any]: + (_, _, _, metadata) = batch with torch.no_grad(): tokens, posteriors = self.model( model_params=model_params, @@ -21,4 +22,6 @@ def compute( forecast_steps=None, encode_only=True, ) - return {"latent": [tokens]}, posteriors + return TargetAuxOutput( + physical=None, latent=[tokens], aux_outputs={"noise_level_rn": metadata.noise_level_rn} + ) From 6ea07e726aa2b39ee2e001f4b22d92e549ed8a52 Mon Sep 17 00:00:00 2001 From: Seb Hickman <56727418+shmh40@users.noreply.github.com> Date: Fri, 28 Nov 2025 11:34:41 +0000 Subject: [PATCH 128/344] restore masking_strategy to random Had placeholder for testing, now back to "random" for masking strategy in the base level of default_config --- config/default_config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/default_config.yml b/config/default_config.yml index fbc68224f..a10e52de2 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -94,7 +94,7 @@ validation_mode_config: {"losses": {LossPhysical: {weight: 1.0, loss_fcts: [['ms } # masking -masking_strategy: "dog" # obviously TODO +masking_strategy: "random" # TODO # masking rate when training mode is "masking"; ignored in foreacast mode masking_rate: 0.6 # From 4281aff1f20e0cd1f14b654c9a8bb136eb1e559a Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Fri, 28 Nov 2025 12:40:24 +0100 Subject: [PATCH 129/344] restore loader_num_workers to 8 --- config/default_config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/default_config.yml b/config/default_config.yml index a10e52de2..8d6ae4026 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -169,7 +169,7 @@ input_window_steps: 1 val_initial: False -loader_num_workers: 0 +loader_num_workers: 8 log_validation: 0 streams_output: ["ERA5"] From 950e5b461898c06e7612129a4b3734877c7e81f4 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Fri, 28 Nov 2025 13:27:36 +0100 Subject: [PATCH 130/344] set loader_num_workers to 8 --- config/default_config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/default_config.yml b/config/default_config.yml index bfc97c34b..c44e5e91a 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -187,7 +187,7 @@ input_window_steps: 1 val_initial: False -loader_num_workers: 0 +loader_num_workers: 8 log_validation: 0 streams_output: ["ERA5"] From 15b46e9a5c78928bca8310187deab26133507783 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Fri, 28 Nov 2025 13:30:54 +0100 Subject: [PATCH 131/344] fix indentation of else: assert False in _get_sample msds --- src/weathergen/datasets/multi_stream_data_sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 84729a585..b1a4049f3 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -859,8 +859,8 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # add data for current stream streams_data += [v for k, v in stream_data_source.items()] - else: - assert False, "Mode not implemented" + else: + assert False, "Mode not implemented" return streams_data, batch From 76270aad7665cae6fbddce48bf6f696a005c8ab7 Mon Sep 17 00:00:00 2001 From: Moritz Hauschulz <60788263+moritzhauschulz@users.noreply.github.com> Date: Fri, 28 Nov 2025 13:12:06 +0000 Subject: [PATCH 132/344] [1269] Noise generation in diffusion inference (#1374) * noise generation in diffusion inference * lint --------- Co-authored-by: Matthias Karlbauer --- src/weathergen/model/diffusion.py | 9 +++++++-- src/weathergen/model/model.py | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index b64ea6fd5..26f667d24 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -36,11 +36,14 @@ class DiffusionForecastEngine(torch.nn.Module): def __init__( self, + cf: Config, + num_healpix_cells: int, forecast_engine: ForecastingEngine, cf: Config, ): super().__init__() self.cf = cf + self.num_healpix_cells = num_healpix_cells self.net = forecast_engine self.preconditioner = Preconditioner() self.frequency_embedding_dim = self.cf.frequency_embedding_dim @@ -110,15 +113,17 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int) -> def inference( self, - x: torch.Tensor, fstep: int, num_steps: int = 30, ) -> torch.Tensor: # Forward pass of the diffusion model during inference # https://github.com/NVlabs/edm/blob/main/generate.py + # Sample noise (assuming single batch element for now) + x = torch.randn(1, self.num_healpix_cells, self.cf.ae_global_dim_embed).to(device="cuda") + # Time step discretization. - step_indices = torch.arange(num_steps, dtype=torch.float64, device=x.device) + step_indices = torch.arange(num_steps, dtype=torch.float64, device="cuda") t_steps = ( self.sigma_max ** (1 / self.rho) + step_indices diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index a8174bfb4..1f91656a6 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -337,7 +337,7 @@ def create(self) -> "Model": self.forecast_engine = ForecastingEngine(cf, self.num_healpix_cells) if cf.fe_diffusion_model: self.forecast_engine = DiffusionForecastEngine( - forecast_engine=self.forecast_engine, cf=cf + cf, self.num_healpix_cells, forecast_engine=self.forecast_engine ) ############### From b662bf26806b8647bcb41645962bd3c36754fe91 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Fri, 28 Nov 2025 14:54:44 +0000 Subject: [PATCH 133/344] Made pre-trained encoder weights optional --- config/default_config.yml | 9 +-- src/weathergen/model/diffusion.py | 8 +-- src/weathergen/model/model_interface.py | 77 +++++++++++++------------ 3 files changed, 43 insertions(+), 51 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index c44e5e91a..332d7be82 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -60,10 +60,8 @@ sigma_data: 0.5 rho: 7 p_mean: -1.2 p_std: 1.2 -# Encoder weights -chkpt_encoder_weights: "./models/whkujigw/whkujigw_epoch00063.chkpt" - - +# Encoder weights (set to null to not load a pretrained encoder) +chkpt_encoder_weights: "./models/whkujigw/whkujigw_latest.chkpt" healpix_level: 5 @@ -151,8 +149,6 @@ training_config: num_forecast_steps: 1 # fake, still read from usual place loss: "LatentDiffusionLoss" # placeholder - - num_mini_epochs: 32 samples_per_mini_epoch: 4096 samples_per_validation: 512 @@ -204,7 +200,6 @@ train_log_freq: metrics: 20 checkpoint: 250 - # Tags for experiment tracking # These tags will be logged in MLFlow along with completed runs for train, eval, val # The tags are free-form, with the following rules: diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 26f667d24..7689e1440 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -34,13 +34,7 @@ class DiffusionForecastEngine(torch.nn.Module): # Adopted from https://github.com/NVlabs/edm/blob/main/training/loss.py#L72 - def __init__( - self, - cf: Config, - num_healpix_cells: int, - forecast_engine: ForecastingEngine, - cf: Config, - ): + def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: ForecastingEngine): super().__init__() self.cf = cf self.num_healpix_cells = num_healpix_cells diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index e93713473..13b9ec2e3 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -160,45 +160,48 @@ def init_model_and_shard( # ------------------------------------------------------------------------------------------ # LOAD AND FREEZE ENCODER WEIGHTS # ONLY FOR EXPERIMENTATION, TO BE REMOVED - params = torch.load( - cf.chkpt_encoder_weights, - map_location=torch.device("cpu"), - mmap=True, - weights_only=True, - ) - encoder_modules = [ - "embed_engine", - "ae_local_engine", - "ae_local_global_engine", - "ae_global_engine", - ] - - # Load encoder weights - params_temp = {} - for name in params.keys(): - if any(e_module in name for e_module in encoder_modules): - if cf.with_ddp: - params_temp[f"module.{name}"] = params[name] - else: - params_temp[name] = params[name] - params = params_temp - mkeys, ukeys = model.load_state_dict(params, strict=False) - - # Freeze encoder weights - for name, module in model.named_modules(): - if any(e_module in name for e_module in encoder_modules): - for p in module.parameters(): - p.requires_grad = False + if cf.chkpt_encoder_weights: + params = torch.load( + cf.chkpt_encoder_weights, + map_location=torch.device("cpu"), + mmap=True, + weights_only=True, + ) + encoder_modules = [ + "embed_engine", + "ae_local_engine", + "ae_local_global_engine", + "ae_global_engine", + ] + + # Load encoder weights + params_temp = {} + for name in params.keys(): + if any(e_module in name for e_module in encoder_modules): + if cf.with_ddp: + params_temp[f"module.{name}"] = params[name] + else: + params_temp[name] = params[name] + params = params_temp + mkeys, ukeys = model.load_state_dict(params, strict=False) - model = model.to(f"cuda:{cf.local_rank}") + # Freeze encoder weights + for name, module in model.named_modules(): + if any(e_module in name for e_module in encoder_modules): + for p in module.parameters(): + p.requires_grad = False - # warn about difference in checkpoint and model - if len(mkeys) == 0 and len(ukeys) == 0: - logger.info(f"Checkpoint {cf.chkpt_encoder_weights} loaded successfully with all weights.") - if len(mkeys) > 0: - logger.warning(f"Missing keys when loading model: {mkeys}") - if len(ukeys) > 0: - logger.warning(f"Unused keys when loading model: {ukeys}") + model = model.to(f"cuda:{cf.local_rank}") + + # warn about difference in checkpoint and model + if len(mkeys) == 0 and len(ukeys) == 0: + logger.info( + f"Checkpoint {cf.chkpt_encoder_weights} loaded successfully with all weights." + ) + if len(mkeys) > 0: + logger.warning(f"Missing keys when loading model: {mkeys}") + if len(ukeys) > 0: + logger.warning(f"Unused keys when loading model: {ukeys}") # ------------------------------------------------------------------------------------------ # model params From 3b55ef56432d90e0a65fddbbb66e39db14dc3ac0 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Tue, 2 Dec 2025 09:42:35 +0000 Subject: [PATCH 134/344] Update validation to new data structure --- src/weathergen/train/trainer.py | 101 ++++++++++++++++++++++++-------- 1 file changed, 78 insertions(+), 23 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 69b86a7bf..1044a4209 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -725,34 +725,87 @@ def validate(self, mini_epoch): total=len(self.data_loader_validation), disable=self.cf.with_ddp ) as pbar: for bidx, batch in enumerate(dataset_val_iter): - forecast_steps = batch[-1] - batch = self.batch_to_device(batch) - # evaluate model + forecast_steps = batch[0][-1] with torch.autocast( device_type=f"cuda:{cf.local_rank}", dtype=self.mixed_precision_dtype, enabled=cf.with_mixed_precision, - ): - model_forward = ( - self.model.forward - if self.ema_model is None - else self.ema_model.forward_eval - ) - output = model_forward( - self.model_params, batch, cf.forecast_offset, forecast_steps - ) - target_aux_output = self.target_and_aux_calculator.compute( - bidx, - batch, - self.model_params, - self.model, - cf.forecast_offset, - forecast_steps, - ) + ): + outputs = [] + for view in batch[-1].source_samples: + # TODO remove when ModelBatch and Sample get a to_device() + streams_data = [[view.streams_data["ERA5"]]] + streams_data = [[d.to_device(self.device) for d in db] for db in streams_data] + source_cell_lens = view.source_cell_lens + source_cell_lens = [b.to(self.device) for b in source_cell_lens] + target_coords_idxs = view.target_coords_idx + target_coords_idxs = [ + [b.to(self.device) for b in bf] for bf in target_coords_idxs + ] + outputs.append( + self.model( + self.model_params, + ( + streams_data, + source_cell_lens, + target_coords_idxs, + view.meta_info["ERA5"], + ), + cf.forecast_offset, + forecast_steps, + ) + ) + + targets_and_auxs = [] + for view in batch[-1].target_samples: + # TODO remove when ModelBatch and Sample get a to_device() + streams_data = [[view.streams_data["ERA5"]]] + streams_data = [[d.to_device(self.device) for d in db] for db in streams_data] + source_cell_lens = view.source_cell_lens + source_cell_lens = [b.to(self.device) for b in source_cell_lens] + target_coords_idxs = view.target_coords_idx + target_coords_idxs = [ + [b.to(self.device) for b in bf] for bf in target_coords_idxs + ] + targets_and_auxs.append( + self.target_and_aux_calculator.compute( + self.cf.istep, + ( + streams_data, + source_cell_lens, + target_coords_idxs, + view.meta_info["ERA5"], + ), + self.model_params, + self.model, + cf.forecast_offset, + forecast_steps, + ) + ) + # OLD + # forecast_steps = batch[-1] + # batch = self.batch_to_device(batch) + # + # model_forward = ( + # self.model.forward + # if self.ema_model is None + # else self.ema_model.forward_eval + # ) + # output = model_forward( + # self.model_params, batch, cf.forecast_offset, forecast_steps + # ) + # target_aux_output = self.target_and_aux_calculator.compute( + # bidx, + # batch, + # self.model_params, + # self.model, + # cf.forecast_offset, + # forecast_steps, + # ) loss, loss_values = self.loss_calculator_val.compute_loss( - preds=output, - targets=target_aux_output, + preds=outputs[0], + targets=targets_and_auxs[0], ) # log output @@ -766,7 +819,7 @@ def validate(self, mini_epoch): targets_times_all, targets_lens, ) = self._prepare_logging( - preds=output, + preds=outputs, forecast_offset=cf.forecast_offset, forecast_steps=cf.forecast_steps, streams_data=streams_data, @@ -820,6 +873,8 @@ def batch_to_device(self, batch): self.device_type = torch.accelerator.current_accelerator() self.device = torch.device(f"{self.device_type}:{self.cf.local_rank}") # forecast_steps is dropped here from the batch + for i, b in enumerate(batch[0]): + print(f"{i}th b before to_device: {b}") return ( [[d.to_device(self.device) for d in db] for db in batch[0]], [b.to(self.device) for b in batch[1]], From 2b2c9778551bfd6f4236cbca42b7b102ca7b0d13 Mon Sep 17 00:00:00 2001 From: Tim Hunter Date: Tue, 2 Dec 2025 17:03:41 +0100 Subject: [PATCH 135/344] linter warnings --- src/weathergen/datasets/masking.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index 9dde6af78..f27104c29 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -49,6 +49,9 @@ def __init__(self, cf: Config): # masking_strategy_config is a dictionary that can hold any additional parameters self.healpix_level_data = cf.healpix_level self.masking_strategy_config = cf.get("masking_strategy_config", {}) + self.perm_sel = None + self.mask_tokens = None + self.mask_channels = None self.mask_value = 0.0 self.dim_time_enc = 6 From c8a2aadc26a1799a2af3505c4649316c8f701e45 Mon Sep 17 00:00:00 2001 From: Tim Hunter Date: Tue, 2 Dec 2025 17:06:56 +0100 Subject: [PATCH 136/344] commenting tests --- .../datasets/tokenizer_utils_test.py | 64 ------------------- src/weathergen/datasets/utils_test.py | 57 ++++++++--------- 2 files changed, 28 insertions(+), 93 deletions(-) delete mode 100644 src/weathergen/datasets/tokenizer_utils_test.py diff --git a/src/weathergen/datasets/tokenizer_utils_test.py b/src/weathergen/datasets/tokenizer_utils_test.py deleted file mode 100644 index 322ca87eb..000000000 --- a/src/weathergen/datasets/tokenizer_utils_test.py +++ /dev/null @@ -1,64 +0,0 @@ -import torch -from torch import Tensor, tensor - -from weathergen.datasets.tokenizer_utils import CoordNormalizer, _coords_local, r3tos2 - -_pos3r = tensor( - [ - [-1.2492e-02, -1.0921e-09, 9.9992e-01], - [-1.1881e-02, 9.9992e-01, -3.8603e-03], - [-1.0106e-02, -7.3428e-03, 9.9992e-01], - [-7.3428e-03, -1.0106e-02, 9.9992e-01], - [-3.8603e-03, -1.1881e-02, 9.9992e-01], - [1.4897e-10, -1.2492e-02, 9.9992e-01], - [3.8603e-03, -1.1881e-02, 9.9992e-01], - [7.3428e-03, -1.0106e-02, 9.9992e-01], - [1.0106e-02, -7.3428e-03, 9.9992e-01], - [1.1881e-02, -3.8603e-03, 9.9992e-01], - [1.2492e-02, 0.0000e00, 9.9992e-01], - [1.1881e-02, 3.8603e-03, 9.9992e-01], - [1.0106e-02, 7.3428e-03, 9.9992e-01], - [7.3428e-03, 1.0106e-02, 9.9992e-01], - [3.8603e-03, 1.1881e-02, 9.9992e-01], - [-5.4606e-10, 1.2492e-02, 9.9992e-01], - [-3.8603e-03, 1.1881e-02, 9.9992e-01], - [-7.3428e-03, 1.0106e-02, 9.9992e-01], - [-1.0106e-02, 7.3428e-03, 9.9992e-01], - ] -) - -_idxs_ord = [ - tensor([6, 4, 5, 7, 0, 0, 0, 0]), - tensor([1, 2, 3, 8, 0, 0, 0, 0]), - tensor([9, 10, 11, 0, 0, 0, 0, 0]), -] - -_hpy_verts_rots = tensor( - [ - [[0.7070, 0.7070, 0.0208], [-0.7070, 0.7072, -0.0086], [-0.0208, -0.0086, 0.9997]], - [[0.6889, 0.7236, 0.0417], [-0.7236, 0.6900, -0.0179], [-0.0417, -0.0179, 0.9990]], - [[0.7236, 0.6889, 0.0417], [-0.6889, 0.7246, -0.0167], [-0.0417, -0.0167, 0.9990]], - ] -) - - -def simple_coords_local( - posr3: Tensor, hpy_verts_rots: Tensor, idxs_ord: list[Tensor], n_coords: CoordNormalizer -) -> list[Tensor]: - fp32 = torch.float32 - posr3 = torch.cat([torch.zeros_like(posr3[0]).unsqueeze(0), posr3]) # prepend zero - """Compute simple local coordinates for a set of 3D positions on the unit sphere.""" - return [ - n_coords(r3tos2(torch.matmul(R, posr3[idxs].transpose(1, 0)).transpose(1, 0)).to(fp32)) - for R, idxs in zip(hpy_verts_rots, idxs_ord, strict=True) - ] - - -def test_coords_local(): - n_coords = lambda x: x - coords_local = simple_coords_local(_pos3r, _hpy_verts_rots, _idxs_ord, n_coords) - coords_local_ref = _coords_local(_pos3r, _hpy_verts_rots, _idxs_ord, n_coords) - torch.testing.assert_close(coords_local, coords_local_ref, atol=1e-6, rtol=0) - - -test_coords_local() diff --git a/src/weathergen/datasets/utils_test.py b/src/weathergen/datasets/utils_test.py index 56f59cd07..2e614937e 100644 --- a/src/weathergen/datasets/utils_test.py +++ b/src/weathergen/datasets/utils_test.py @@ -5,7 +5,6 @@ locs_to_cell_coords_ctrs, locs_to_ctr_coords, s2tor3, - tcs_optimized, vecs_to_rots, ) @@ -73,36 +72,36 @@ def test_locs_to_cell_coords_ctrs(): ) -def _tcs_simpled(target_coords: list[Tensor]) -> tuple[list[Tensor], Tensor]: - tcs = [ - ( - s2tor3( - torch.deg2rad(90.0 - t[..., 0]), - torch.deg2rad(180.0 + t[..., 1]), - ) - if len(t) > 0 - else torch.tensor([]) - ) - for t in target_coords - ] - cat_target_coords = torch.cat(target_coords) - return tcs, cat_target_coords +# def _tcs_simpled(target_coords: list[Tensor]) -> tuple[list[Tensor], Tensor]: +# tcs = [ +# ( +# s2tor3( +# torch.deg2rad(90.0 - t[..., 0]), +# torch.deg2rad(180.0 + t[..., 1]), +# ) +# if len(t) > 0 +# else torch.tensor([]) +# ) +# for t in target_coords +# ] +# cat_target_coords = torch.cat(target_coords) +# return tcs, cat_target_coords -def test_tcs(): - target_coords = [ - tensor( - [[2.3377, -135.0000], [1.4026, -135.4545], [1.4026, -134.5455], [0.4675, -135.0000]] - ), - tensor( - [[3.2727, -133.6082], [2.3377, -134.0816], [2.3377, -133.1633], [1.4026, -133.6364]] - ), - ] - tcs_ref, cat_tcs_ref = _tcs_simpled(target_coords) - tcs_opt, cat_tcs_opt = tcs_optimized(target_coords) - assert len(tcs_ref) == len(tcs_opt) - torch.testing.assert_close(cat_tcs_ref, cat_tcs_opt) - torch.testing.assert_close(tcs_ref, tcs_opt, atol=1e-8, rtol=1e-5) +# def test_tcs(): +# target_coords = [ +# tensor( +# [[2.3377, -135.0000], [1.4026, -135.4545], [1.4026, -134.5455], [0.4675, -135.0000]] +# ), +# tensor( +# [[3.2727, -133.6082], [2.3377, -134.0816], [2.3377, -133.1633], [1.4026, -133.6364]] +# ), +# ] +# tcs_ref, cat_tcs_ref = _tcs_simpled(target_coords) +# tcs_opt, cat_tcs_opt = tcs_optimized(target_coords) +# assert len(tcs_ref) == len(tcs_opt) +# torch.testing.assert_close(cat_tcs_ref, cat_tcs_opt) +# torch.testing.assert_close(tcs_ref, tcs_opt, atol=1e-8, rtol=1e-5) def _locs_to_ctr_coords(ctrs_r3, locs: list[torch.Tensor]) -> list[torch.Tensor]: From 2599ec28eba9c573425202f239377d1857d686c0 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 3 Dec 2025 00:10:13 +0100 Subject: [PATCH 137/344] Restructured code so that mask generation and application is cleanly separated --- src/weathergen/datasets/batch.py | 9 +- src/weathergen/datasets/masking.py | 494 +++++------------- .../datasets/multi_stream_data_sampler.py | 288 +++------- src/weathergen/datasets/tokenizer_masking.py | 153 +++--- 4 files changed, 293 insertions(+), 651 deletions(-) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index 25f8fc3aa..190de38ad 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -21,16 +21,11 @@ @dataclass class SampleMetaData: - # masking strategy - # masking_strategy: str - - # parameters for masking strategy - masking_params: Config | dict + # sample parameters (masking) + params: Config | dict mask: torch.Tensor | None = None - noise_level_rn: float | None = None - class Sample: # keys: stream name, values: SampleMetaData diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index 885e98527..b7c22ec6d 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -1,4 +1,5 @@ import logging +from dataclasses import dataclass import numpy as np import torch @@ -9,6 +10,18 @@ _logger = logging.getLogger(__name__) +# Convert to torch.bool +def to_bool_tensor(arr): + return torch.from_numpy(np.asarray(arr)).to(torch.bool) + + +@dataclass +class MaskingStrategy: + strategy: str + config: dict + num_samples: int + + class Masker: """Class to generate masks for token sequences and apply them. This class supports different masking strategies and combinations. @@ -136,300 +149,6 @@ def _select_strategy(self): # Non-combination strategy, return as is return self.masking_strategy - def mask_source_idxs( - self, - idxs_cells, - idxs_cells_lens, - keep_mask: np.typing.NDArray | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - - Return: - torch.Tensor[bool] of length num_tokens that determines masking for each token - """ - - self.mask_tokens, self.mask_channels = None, None - - num_tokens = torch.tensor([len(t) for t in idxs_cells_lens]).sum().item() - - # If there are no tokens, return empty lists. - if num_tokens == 0: - return (self.mask_tokens, self.mask_channels) - - # If an explicit keep_mask is provided we bypass strategy selection and directly - # construct the token-level mask from it. keep_mask expresses cells to KEEP (True=keep). - # Otherwise fall back to the configured strategy logic. - if keep_mask is not None: - assert len(keep_mask) == len(idxs_cells_lens), ( - "keep_mask length does not match number of cells." - ) - # build token level mask: for each cell replicate the keep flag across its tokens - token_level_flags: list[np.typing.NDArray] = [] - for km, lens_cell in zip(keep_mask, idxs_cells_lens, strict=True): - num_tokens_cell = len(lens_cell) - if num_tokens_cell == 0: - continue - token_level_flags.append( - np.ones(num_tokens_cell, dtype=bool) - if km - else np.zeros(num_tokens_cell, dtype=bool) - ) - if token_level_flags: - self.mask_tokens = np.concatenate(token_level_flags) - else: - self.mask_tokens = np.array([], dtype=bool) - return (self.mask_tokens, self.mask_channels) - - # clean strategy selection - self.current_strategy = self._select_strategy() - - # Set the masking rate. - rate = self._get_sampling_rate() - - if self.current_strategy == "random": - self.mask_tokens = self.rng.uniform(0, 1, num_tokens) < rate - - elif self.current_strategy == "forecast": - self.mask_tokens = np.ones(num_tokens, dtype=np.bool) - - elif self.current_strategy == "healpix": - # TODO: currently only for fixed level - num_cells = len(idxs_cells_lens) - mask_cells = self.rng.uniform(0, 1, num_cells) < rate - # translate cell mask to token mask, replicating using number of tokens per cell - self.mask_tokens = [ - (torch.ones(2, dtype=torch.bool) * (1 if m else 0)).to(torch.bool) - for idxs_cell, m in zip(idxs_cells_lens, mask_cells, strict=False) - ] - elif self.current_strategy == "cropping" or self.current_strategy == "causal": - pass - - else: - assert False, f"Unsupported masking strategy: {self.current_strategy}." - - return (self.mask_tokens, self.mask_channels) - - def mask_targets_idxs( - self, - idxs_cells, - idxs_cells_lens, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # mask_source_idxs is - assert (self.mask_tokens is not None) or (self.mask_tokens is not None) - idxs_ord_inv = torch.tensor([], dtype=torch.int64) - - # TODO: better handling of if statement - if self.current_strategy == "forecast": - num_tokens = torch.tensor([len(t) for t in idxs_cells_lens]).sum().item() - self.mask_tokens = np.ones(num_tokens, dtype=np.bool) - - # inverse map for reordering to output data points in same order as input - idxs_ord = torch.cat([t for tt in idxs_cells for t in tt]) - idxs_ord_inv = torch.argsort(idxs_ord) - - else: - # masking strategies: target is complement of source - # TODO: ensure/enforce that forecast_offset==0 - if self.mask_tokens is not None: - self.mask_tokens = ~self.mask_tokens - if self.mask_channels is not None: - self.mask_channels = ~self.mask_channels - - # TODO: self.mask_tokens seems brittle in terms of naming - - return (self.mask_tokens, self.mask_channels, idxs_ord_inv) - - def mask_source( - self, - tokenized_data: list[torch.Tensor], - coords: torch.Tensor, - geoinfos: torch.Tensor, - source: torch.Tensor, - ) -> list[torch.Tensor]: - """ - Receives tokenized data, generates a mask, and returns the source data (unmasked) - and the permutation selection mask (perm_sel) to be used for the target. - - Args: - tokenized_data (list[torch.Tensor]): A list of tensors, where each tensor - represents the tokens for a cell. - - Returns: - list[torch.Tensor]: The unmasked tokens (model input). - """ - - token_lens = [len(t) for t in tokenized_data] - num_tokens = sum(token_lens) - - # If there are no tokens, return empty lists. - if num_tokens == 0: - return tokenized_data - - # Clean strategy selection - self.current_strategy = self._select_strategy() - - # Set the masking rate. - rate = self._get_sampling_rate() - - if rate == 0.0: - _logger.warning( - "masking_rate is 0. This will result in empty target. The sample will be skipped. " - + "If this occurs repeatedtly the masking settings likely need to be revised." - ) - - # Handle the special case where all tokens are masked - if rate == 1.0: - token_lens = [len(t) for t in tokenized_data] - self.perm_sel = [np.ones(tl, dtype=bool) for tl in token_lens] - source_data = [data[~p] for data, p in zip(tokenized_data, self.perm_sel, strict=True)] - return source_data - - # Implementation of different masking strategies. - # Generate a flat boolean mask for random, block, or healpix masking at cell level. - # Generate a 3D mask to apply to each cell for channel masking. - - if self.current_strategy == "random": - flat_mask = self.rng.uniform(0, 1, num_tokens) < rate - - elif self.current_strategy == "block": - flat_mask = np.zeros(num_tokens, dtype=bool) - block_size = int(np.round(rate * num_tokens)) - if block_size > 0 and num_tokens > 0: - start_index = self.rng.integers(0, max(1, num_tokens - block_size + 1)) - flat_mask[start_index : start_index + block_size] = True - - elif self.current_strategy == "healpix": - flat_mask = self._generate_healpix_mask(token_lens, rate) - - elif self.current_strategy == "channel": - mask = self._generate_channel_mask(tokenized_data, rate, coords, geoinfos, source) - - elif self.current_strategy == "causal": - mask = self._generate_causal_mask(tokenized_data, rate, coords, geoinfos, source) - - else: - assert False, f"Unknown masking strategy: {self.current_strategy}" - - # apply mask - - # if masking_strategy is channel, we need to handle the masking differently, - # since p is not 1D Boolean for the list of cells, but 3D to mask the channels in each cell. - if self.current_strategy == "channel": - self.perm_sel = mask - # In the source_data we will set the channels that are masked to 0.0. - source_data = [] - for data, p in zip(tokenized_data, self.perm_sel, strict=True): - if len(data) > 0: - data[p] = self.mask_value - source_data.append(data) - else: - source_data.append(data) - - elif self.current_strategy == "causal": - # Only select unmasked timesteps - self.perm_sel = mask - source_data = [] - for data, p in zip(tokenized_data, self.perm_sel, strict=True): - source_data.append(data[~p] if len(data) > 0 else data) - - else: - # Split the flat mask to match the structure of the tokenized data (list of lists) - # This will be perm_sel, as a class attribute, used to mask the target data. - split_indices = np.cumsum(token_lens)[:-1] - self.perm_sel = np.split(flat_mask, split_indices) - - # Apply the mask to get the source data (where mask is False) - source_data = [data[~p] for data, p in zip(tokenized_data, self.perm_sel, strict=True)] - - return source_data - - def mask_target( - self, - target_tokenized_data: list[list[torch.Tensor]], - coords: torch.Tensor, - geoinfos: torch.Tensor, - source: torch.Tensor, - ) -> list[torch.Tensor]: - """ - Applies the permutation selection mask to - the tokenized data to create the target data. - Handles cases where a cell has no target - tokens by returning an empty tensor of the correct shape. - - Args: - target_tokens_cells (list[list[torch.Tensor]]): List of lists of tensors for each cell. - coords (torch.Tensor): Coordinates tensor, used to determine feature dimension. - geoinfos (torch.Tensor): Geoinfos tensor, used to determine feature dimension. - source (torch.Tensor): Source tensor, used to determine feature dimension. - - Returns: - list[torch.Tensor]: The target data with masked tokens, one tensor per cell. - """ - - # check that self.perm_sel is set, and not None with an assert statement - assert self.perm_sel is not None, "Masker.perm_sel must be set before calling mask_target." - - # Pre-calculate the total feature dimension of a token to create - # correctly shaped empty tensors. - - feature_dim = self.dim_time_enc + coords.shape[-1] + geoinfos.shape[-1] + source.shape[-1] - - processed_target_tokens = [] - - # process all tokens used for embedding - for cc, pp in zip(target_tokenized_data, self.perm_sel, strict=True): - if len(cc) == 0: # Skip if there's no target data - pass - - if self.current_strategy == "channel": - # If masking strategy is channel, handle target tokens differently. - # We don't have Booleans per cell, instead per channel per cell, - # we set the unmasked channels to NaN so not in loss. - selected_tensors = [] - for c, p in zip(cc, pp, strict=True): - # slightly complicated as the first dimension of c varies with data in the cell. - # do not mask the first 8 channels, - # and set unmasked channels to nan - c[:, (self.dim_time_enc + coords.shape[-1] + geoinfos.shape[-1]) :][ - :, ~p[0, (self.dim_time_enc + coords.shape[-1] + geoinfos.shape[-1]) :] - ] = torch.nan - selected_tensors.append(c) - - elif self.current_strategy == "causal": - # select only the target times where mask is True - if len(cc) == len(pp): - selected_tensors = [c for i, c in enumerate(cc) if pp[i]] - elif len(pp) == 0: - selected_tensors = cc - else: # If length of target and mask doesn't match, create new mask - ratio = np.sum(pp) / len(pp) # Ratio of masked tokens in source - indx = max(1, int(ratio * len(cc))) # Get the same for target - selected_tensors = cc[-indx:] - - elif self.current_strategy == "healpix": - selected_tensors = ( - cc if len(pp) > 0 and pp[0] else [] - ) # All tokens inside healpix cell have the same mask - - elif self.current_strategy == "random": - # For random masking, we simply select the tensors where the mask is True. - # When there's no mask it's assumed to be False. This is done via strict=False - selected_tensors = [c for c, p in zip(cc, pp, strict=False) if p] - else: - raise NotImplementedError( - f"Masking strategy {self.current_strategy} is not supported." - ) - - # Append the selected tensors to the processed_target_tokens list. - if selected_tensors: - processed_target_tokens.append(torch.cat(selected_tensors)) - else: - processed_target_tokens.append( - torch.empty(0, feature_dim, dtype=coords.dtype, device=coords.device) - ) - - return processed_target_tokens - def _get_sampling_rate(self): """ Get the sampling, if requested by sampling it itself @@ -623,74 +342,136 @@ def _generate_causal_mask( return full_mask - def build_views_for_stream( + def build_samples_for_stream( self, + training_mode: str, num_cells: int, - teacher_cfg: dict, - student_cfg: dict, - relationship: str = "subset", + target_cfg: dict, + source_cfg: dict, ) -> tuple[np.typing.NDArray, list[np.typing.NDArray], list[SampleMetaData]]: """ Construct teacher/student keep masks for a stream. SampleMetaData is currently just a dict with the masking params used. """ - strat_teacher = teacher_cfg.get("strategy", "random") - rate_teacher = teacher_cfg.get("rate") - t_cfg_extra = teacher_cfg.get("masking_strategy_config") + # get source and target configs; target defaults to source config + + source_num_samples = source_cfg.get("num_samples", 1) + source_strategy = source_cfg.get("masking_strategy", source_cfg.get("strategy", "random")) + source_masking_params = source_cfg.get("masking_strategy_config") + relationship = source_cfg.get("relationship", "complement") - teacher_keep_mask = self.generate_cell_keep_mask( - num_cells=num_cells, - strategy=strat_teacher, - rate=rate_teacher, - masking_strategy_config=t_cfg_extra, + if target_cfg is not None: + target_num_samples = target_cfg.get("num_samples", 1) + target_strategy = target_cfg.get("strategy", "random") + target_masking_params = target_cfg.get("masking_strategy_config") + else: + target_strategy = source_strategy + target_num_samples = source_num_samples + target_masking_params = source_masking_params + # # do other relationships make sense + # assert relationship == "complement" + + assert source_num_samples % target_num_samples == 0, ( + "number of source samples has to be multiple of target samples" ) - num_views = student_cfg.get("num_views", 1) - strat_student = student_cfg.get("masking_strategy", student_cfg.get("strategy", "random")) - rate_student = student_cfg.get("rate") - s_cfg_extra = student_cfg.get("masking_strategy_config") - - student_keep_masks: list[np.typing.NDArray] = [] - for _ in range(num_views): - base = self.generate_cell_keep_mask( - num_cells=num_cells, - strategy=strat_student, - rate=rate_student, - masking_strategy_config=s_cfg_extra, - ) + # translate settings into sampling masks + + # iterate over all target samples + target_masks: list[np.typing.NDArray] = [] + target_metadata: list[SampleMetaData] = [] + for _ in range(target_num_samples): + target_masks += [ + self._get_mask( + num_cells=num_cells, + strategy=target_strategy, + target_mask=None, + masking_strategy_config=target_masking_params, + ) + ] + target_metadata += [SampleMetaData(params=target_cfg)] + + # iterate over all source samples + source_masks: list[np.typing.NDArray] = [] + source_metadata: list[SampleMetaData] = [] + source_target_mapping = np.zeros(source_num_samples, dtype=np.int32) + for it in range(source_num_samples): + source_masks += [ + self._get_mask( + num_cells=num_cells, + strategy=source_strategy, + masking_strategy_config=source_masking_params, + target_mask=target_masks[it % target_num_samples], + relationship=relationship, + ) + ] + source_metadata += [SampleMetaData(params=target_cfg)] + source_target_mapping[it] = it % target_num_samples + + return ( + (target_masks, target_metadata), + (source_masks, source_metadata), + source_target_mapping, + ) + + def _get_mask( + self, + num_cells: int, + strategy: str | None = None, + rate: float | None = None, + masking_strategy_config: dict | None = None, + target_mask: np.typing.NDArray | None = None, + relationship: str = "subset", + ) -> np.typing.NDArray: + """Get effective mask, combining with target mask if specified. + + Parameters + ---------- + num_cells : int + Number of cells at data level (should equal 12 * 4**healpix_level). + strategy : str | None + Cell selection strategy: currently supports 'random' and 'healpix'. Uses + instance default if None. + rate : float | None + Fraction of parent cells (healpix) or data cells (random) to keep. Falls back + to instance masking_rate if None. + masking_strategy_config : dict | None + Optional override of strategy config (e.g., {'hl_mask': 3}). + constraint_keep_mask : np.ndarray | None + Optional boolean mask of allowed cells (True = allowed). Selection will be + limited to these cells. For subset/disjoint relationships. + + Returns + ------- + np.ndarray + Boolean array of shape [num_cells] where True indicates the cell is kept. + """ + + # handle cases where mask is directly derived from target_mask + if target_mask is not None: + if relationship == "complement": + mask = ~target_mask + return mask + + # get mask + mask = self._generate_cell_mask(num_cells, strategy, rate, masking_strategy_config) + + # handle cases where mask needs to be combined with target_mask + if target_mask is not None: if relationship == "subset": - keep = base & teacher_keep_mask + mask = mask & target_mask elif relationship == "disjoint": - keep = base & (~teacher_keep_mask) - else: - keep = base - student_keep_masks.append(keep) + mask = mask & (~target_mask) - metadata: list[SampleMetaData] = [ - SampleMetaData( - masking_params=teacher_cfg, - ) - ] - for _, _ in enumerate(student_keep_masks): - metadata.append( - SampleMetaData( - masking_params=student_cfg, - ) - ) + return mask - return teacher_keep_mask, student_keep_masks, metadata - - # --------------------------------------------------------------------- - # Cell-level keep mask generation (teacher/student view selection) - # --------------------------------------------------------------------- - def generate_cell_keep_mask( + def _generate_cell_mask( self, num_cells: int, strategy: str | None = None, rate: float | None = None, masking_strategy_config: dict | None = None, - constraint_keep_mask: np.typing.NDArray | None = None, ) -> np.typing.NDArray: """Generate a boolean keep mask at data healpix level (True = keep cell). @@ -715,6 +496,9 @@ def generate_cell_keep_mask( np.ndarray Boolean array of shape [num_cells] where True indicates the cell is kept. """ + + # get config for mask + strat = strategy or self.masking_strategy cfg = masking_strategy_config or self.masking_strategy_config keep_rate = rate if rate is not None else self.masking_rate @@ -733,9 +517,15 @@ def generate_cell_keep_mask( f"Cell selection strategy '{strat}' not supported for keep mask generation." ) + # generate cell mask + if strat == "random": - base_mask = self.rng.uniform(0, 1, num_cells) < keep_rate - else: # healpix hierarchical selection + mask = self.rng.uniform(0, 1, num_cells) < keep_rate + + elif strat == "forecast" or strat == "causal": + mask = np.ones(num_cells, dtype=np.bool) + + elif strat == "healpix": hl_data = self.healpix_level_data hl_mask = cfg.get("hl_mask") assert hl_mask is not None and hl_mask < hl_data, ( @@ -747,19 +537,19 @@ def generate_cell_keep_mask( # number of parents to KEEP num_parents_to_keep = int(np.round(keep_rate * num_parent_cells)) if num_parents_to_keep == 0: - base_mask = np.zeros(num_cells, dtype=bool) + mask = np.zeros(num_cells, dtype=bool) else: parent_ids = self.rng.choice(num_parent_cells, num_parents_to_keep, replace=False) child_offsets = np.arange(num_children_per_parent) child_indices = ( parent_ids[:, None] * num_children_per_parent + child_offsets ).reshape(-1) - base_mask = np.zeros(num_cells, dtype=bool) - base_mask[child_indices] = True + mask = np.zeros(num_cells, dtype=bool) + mask[child_indices] = True + + else: + assert False, "Unknown strategy." - # apply constraint if provided (only keep those cells within allowed) - if constraint_keep_mask is not None: - assert constraint_keep_mask.shape[0] == num_cells, "constraint_keep_mask wrong shape" - base_mask = base_mask & constraint_keep_mask + mask = to_bool_tensor(mask) - return base_mask + return mask diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index b1a4049f3..ddc8d7a16 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -311,8 +311,7 @@ def _build_stream_data_input( mask: torch.Tensor | None = None, ) -> tuple[StreamData, dict | None]: """ - Return one batch of data - Build a StreamData object for a single view (teacher or student). + Build model network input Args: stream_data : @@ -326,13 +325,6 @@ def _build_stream_data_input( StreamData with source and targets masked according to view_meta """ - # source input data - - # For now, keep only mask state of the final timestep - # (correspondsing to base_idx, first of the loop below) - # to ensure alignment with the target data for MTM/S-T. - final_mask_state = None - # iterate overall input steps for step, idx in enumerate(range(base_idx, base_idx - self.num_input_steps, -1)): # TODO: check that we are not out of bounds when we go back in time @@ -352,17 +344,13 @@ def _build_stream_data_input( rdata, token_data, (time_win_source.start, time_win_source.end), - keep_mask=mask, + mask, ) - # for masked autoencoding, we want the mask state that overlaps with the target - if step == 0: - final_mask_state = mask_state - # collect data for stream stream_data.add_source(step, rdata, source_cells_lens, source_cells) - return stream_data, final_mask_state + return stream_data def _build_stream_data_output( self, @@ -373,9 +361,12 @@ def _build_stream_data_output( forecast_dt: int, output_data: list, output_tokens: list, - mask_state: dict | None = None, + target_mask, ) -> StreamData: - """ """ + """ + Generate stream data for output + + """ # collect for all forecast steps dt = self.forecast_offset + forecast_dt @@ -388,17 +379,15 @@ def _build_stream_data_output( token_data = output_tokens[step] stream_data.target_is_spoof = rdata.is_spoof - # None, or returned by get_target_coords - target_selection = None if "target_coords" in mode: - (tc, tc_l, target_selection) = self.tokenizer.get_target_coords( + (tc, tc_l) = self.tokenizer.get_target_coords( stream_info, self.sampling_rate_target, rdata, token_data, (time_win_target.start, time_win_target.end), - mask_state, + target_mask, ) stream_data.add_target_coords(fstep, tc, tc_l) @@ -409,8 +398,7 @@ def _build_stream_data_output( rdata, token_data, (time_win_target.start, time_win_target.end), - mask_state, - target_selection, + target_mask, ) stream_data.add_target_values(fstep, tt_cells, tt_c, tt_t, idxs_inv) @@ -426,7 +414,8 @@ def _build_stream_data( output_data: list, input_tokens: list, output_tokens: list, - mask, + target_mask, + source_mask, ) -> StreamData: """ Return one batch of data @@ -447,14 +436,14 @@ def _build_stream_data( dt = self.forecast_offset + forecast_dt stream_data = StreamData(base_idx, dt, self.num_healpix_cells) - stream_data, mask_state = self._build_stream_data_input( + stream_data = self._build_stream_data_input( mode, stream_data, base_idx, stream_info, input_data, input_tokens, - mask, + source_mask, ) stream_data = self._build_stream_data_output( @@ -465,7 +454,7 @@ def _build_stream_data( forecast_dt, output_data, output_tokens, - mask_state, + target_mask, ) return stream_data @@ -532,20 +521,20 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): TODO: these modes are not being used now. """ - if mode == "masking": - streams_data: list[StreamData] = [] + # get/coordinate masks + masks_streams = self._get_source_target_masks(mode) - # get/coordinate masks - masks_streams = self._get_source_target_masks() + if mode == "masking" or mode == "student_teacher": + streams_data: list[StreamData] = [] # Determine number of views direct from config (teacher & student views) - teacher_cfg = ( - self.training_cfg.get("teacher_model_input", {}) if self.training_cfg else {} - ) - student_cfg = self.training_cfg.get("model_input", {}) if self.training_cfg else {} - num_target_samples = int(teacher_cfg.get("num_views", 1)) - num_source_samples = int(teacher_cfg.get("num_views", 1)) * int( - student_cfg.get("num_views", 1) + target_cfg = self.training_cfg.get("target_input", {}) if self.training_cfg else {} + target_cfg = target_cfg if target_cfg is not None else {} + source_cfg = self.training_cfg.get("model_input", {}) if self.training_cfg else {} + # TODO: handle this cleaner (maybe enforce earlier that teacher_cfg is dict) + num_target_samples = int(target_cfg.get("num_samples", 1)) + num_source_samples = int(target_cfg.get("num_samples", 1)) * int( + source_cfg.get("num_samples", 1) ) # per teacher batch = ModelBatch(self.streams, num_source_samples, num_target_samples) @@ -575,7 +564,9 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # collect source data for current stream # loop over student views stream_data_source = {} - for sidx, mask in enumerate(source_masks): + for sidx, (target_mask, source_mask) in enumerate( + zip(target_masks, source_masks, strict=False) + ): # stream_data_source[name] = self._build_stream_data( sdata = self._build_stream_data( "target_coords target_values", @@ -586,30 +577,25 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): output_data, input_tokens, output_tokens, - mask, + target_mask, + source_mask, ) stream_data_source[name] = sdata - # source meta info... - # source_meta_info = SampleMetaData(... - - source_metadata = source_metadata_list[sidx] # first is teacher - - # also want to add the mask to the metadata - source_metadata.mask = mask - - # TODO: seb check this # Map each student (source) to its teacher (target) t_idx = student_to_teacher[sidx] - batch.add_source_stream(sidx, t_idx, name, sdata, source_metadata) + batch.add_source_stream(sidx, t_idx, name, sdata, source_metadata_list[sidx]) # num_input_steps? batch.source_samples[sidx].set_forecast_dt(forecast_dt) # stream_data_target can contain network input stream_data_target = {} - for t_idx, mask in enumerate(source_masks): + # for t_idx, mask in enumerate(source_masks): + for sidx, (target_mask, source_mask) in enumerate( + zip(target_masks, source_masks, strict=False) + ): # stream_data_target[name] = self._build_stream_data( sdata = self._build_stream_data( "target_values", @@ -620,20 +606,23 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): output_data, input_tokens, output_tokens, - mask, + target_mask, + source_mask, ) stream_data_target[name] = sdata # get teacher config info + # TODO, TODO, TODO: is this correct? + t_idx = sidx target_metadata = target_metadata_list[t_idx] # also want to add the mask to the metadata - target_metadata.mask = mask + target_metadata.mask = None # target_mask # TODO: seb to check # Map target to all source students student_indices = [ - s_idx for s_idx, tid in enumerate(student_to_teacher) if tid == t_idx + s_idx for s_idx, tid in enumerate(student_to_teacher) if tid == sidx ] batch.add_target_stream(t_idx, student_indices, name, sdata, target_metadata) batch.target_samples[t_idx].set_forecast_dt(forecast_dt) @@ -648,132 +637,16 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # add data for current stream streams_data += [v for k, v in stream_data_source.items()] - elif mode == "student_teacher": - streams_data: list[StreamData] = [] - - # get/coordinate masks - masks_streams = self._get_source_target_masks() - - # Determine number of views direct from config (teacher & student views) - teacher_cfg = ( - self.training_cfg.get("teacher_model_input", {}) if self.training_cfg else {} - ) - student_cfg = self.training_cfg.get("model_input", {}) if self.training_cfg else {} - num_target_samples = int(teacher_cfg.get("num_views", 1)) - num_source_samples = int(teacher_cfg.get("num_views", 1)) * int( - student_cfg.get("num_views", 1) - ) # per teacher - - batch = ModelBatch(self.streams, num_source_samples, num_target_samples) - - # for all streams - for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): - name = stream_info["name"] - - ( - target_masks, - source_masks, - student_to_teacher, - target_metadata_list, - source_metadata_list, - ) = masks_streams[name] - - # input_data and output_data is conceptually consecutive but differs - # in source and target channels; overlap in one window when self.forecast_offset=0 - (input_data, output_data) = self._get_data_windows(idx, forecast_dt, stream_ds) - - # tokenize windows - # *_tokens = [ (cells_idx, cells_idx_lens), ... ] with length = #time_steps - input_tokens = self.tokenizer.get_tokens_windows(stream_info, input_data, True) - output_tokens = self.tokenizer.get_tokens_windows(stream_info, output_data, False) - - # collect source data for current stream - # loop over student views - stream_data_source = {} - for sidx, mask in enumerate(source_masks): - # stream_data_source[name] = self._build_stream_data( - sdata = self._build_stream_data( - "target_coords target_values", - idx, - forecast_dt, - stream_info, - input_data, - output_data, - input_tokens, - output_tokens, - mask, - ) - - stream_data_source[name] = sdata - - # source meta info... - # source_meta_info = SampleMetaData(... - - source_metadata = source_metadata_list[sidx] # first is teacher - - # also want to add the mask to the metadata - source_metadata.mask = mask - - # TODO: seb check this - # Map each student (source) to its teacher (target) - t_idx = student_to_teacher[sidx] - batch.add_source_stream(sidx, t_idx, name, sdata, source_metadata) - # num_input_steps? - batch.source_samples[sidx].set_forecast_dt(forecast_dt) - - # stream_data_target can contain network input - stream_data_target = {} - - for t_idx, mask in enumerate(target_masks): - # stream_data_target[name] = self._build_stream_data( - sdata = self._build_stream_data( - "target_values", - idx, - forecast_dt, - stream_info, - input_data, - output_data, - input_tokens, - output_tokens, - mask, - ) - stream_data_target[name] = sdata - - # get teacher config info - target_metadata = target_metadata_list[t_idx] - - # also want to add the mask to the metadata - target_metadata.mask = mask - - # TODO: seb to check - # Map target to all source students - student_indices = [ - s_idx for s_idx, tid in enumerate(student_to_teacher) if tid == t_idx - ] - batch.add_target_stream(t_idx, student_indices, name, sdata, target_metadata) - batch.target_samples[t_idx].set_forecast_dt(forecast_dt) - - # TODO: build batch - # source_input - # target_input - # source_output - # target_output - - # TOOD: remove - # add data for current stream - streams_data += [v for k, v in stream_data_source.items()] elif mode == "diffusion_forecast": streams_data: list[StreamData] = [] # Determine number of views direct from config (teacher & student views) - teacher_cfg = ( - self.training_cfg.get("teacher_model_input", {}) if self.training_cfg else {} - ) + teacher_cfg = self.training_cfg.get("target_input", {}) if self.training_cfg else {} student_cfg = self.training_cfg.get("model_input", {}) if self.training_cfg else {} - num_target_samples = int(teacher_cfg.get("num_views", 1)) - num_source_samples = int(teacher_cfg.get("num_views", 1)) * int( - student_cfg.get("num_views", 1) + num_target_samples = int(teacher_cfg.get("num_samples", 1)) + num_source_samples = int(teacher_cfg.get("num_samples", 1)) * int( + student_cfg.get("num_samples", 1) ) # per teacher batch = ModelBatch(self.streams, num_source_samples, num_target_samples) @@ -842,6 +715,9 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # get teacher config info target_metadata = target_metadata + # TODO: + # target.mask = + # TODO: handle this for different number of source timesteps target_metadata.noise_level_rn = source_metadata.noise_level_rn @@ -864,62 +740,30 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): return streams_data, batch - def _get_source_target_masks(self): + def _get_source_target_masks(self, training_mode): """ Generate source and target masks for all streams - according to the student-teacher configuration """ masks = {} for stream_info in self.streams: - teacher_cfg = self.training_cfg.get("teacher_model_input", {}) - student_cfg = self.training_cfg.get("model_input", {}) - relationship = student_cfg.get("relationship") - - # number of teacher views - num_teacher_views = int(teacher_cfg.get("num_views", 1)) - - # Convert to torch.bool - def to_bool_tensor(arr): - if arr is None: - return None - return torch.from_numpy(np.asarray(arr, dtype=bool)).to(torch.bool) - - # renaming here - target_masks: list[torch.Tensor] = [] - source_masks: list[torch.Tensor] = [] - student_to_teacher: list[int] = [] - target_metadata: list[SampleMetaData] = [] - source_metadata: list[SampleMetaData] = [] - - # add a loop over num_teacher_views, generate students for each teacher - for _ in range(num_teacher_views): - # Build one teacher and its student views - t_keep_np, s_keeps_np, metadata = self.tokenizer.masker.build_views_for_stream( - self.num_healpix_cells, - teacher_cfg=teacher_cfg, - student_cfg=student_cfg, - relationship=relationship, - ) - - # append teacher mask - t_tensor = to_bool_tensor(t_keep_np) - target_masks.append(t_tensor) - target_metadata.append(metadata[0]) # TODO: first is teacher - - # this teacher's students and mapping - for s_np, md in zip(s_keeps_np or [], metadata[1:], strict=True): - source_masks.append(to_bool_tensor(s_np)) - # append 0, 1, ... depending on which teacher we did - source_metadata.append(md) - student_to_teacher.append(len(target_masks) - 1) + target_cfg = self.training_cfg.get("target_input", {}) + source_cfg = self.training_cfg.get("model_input", {}) + + # Build one teacher and its student views + target_data, source_data, mapping = self.tokenizer.masker.build_samples_for_stream( + training_mode, + self.num_healpix_cells, + target_cfg=target_cfg, + source_cfg=source_cfg, + ) masks[stream_info["name"]] = ( - target_masks, - source_masks, - student_to_teacher, - target_metadata, - source_metadata, + target_data[0], + source_data[0], + mapping, + target_data[1], + source_data[1], ) return masks diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 7c6bb9071..edcff081b 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -70,13 +70,43 @@ def get_tokens_windows(self, stream_info, data, pad_tokens): return tokens + def cell_to_token_mask(self, idxs_cells, idxs_cells_lens, mask): + """ """ + + mask_tokens, mask_channels = None, None + num_tokens = torch.tensor([len(t) for t in idxs_cells_lens]).sum().item() + + # If there are no tokens, return empty lists. + if num_tokens == 0: + return (mask_tokens, mask_channels) + + # TODO, TODO, TODO: use np.repeat + # https://stackoverflow.com/questions/26038778/repeat-each-values-of-an-array-different-times + # build token level mask: for each cell replicate the keep flag across its tokens + token_level_flags: list[np.typing.NDArray] = [] + for km, lens_cell in zip(mask, idxs_cells_lens, strict=True): + num_tokens_cell = len(lens_cell) + if num_tokens_cell == 0: + continue + token_level_flags.append( + np.ones(num_tokens_cell, dtype=bool) + if km + else np.zeros(num_tokens_cell, dtype=bool) + ) + if token_level_flags: + mask_tokens = np.concatenate(token_level_flags) + else: + mask_tokens = np.array([], dtype=bool) + + return (mask_tokens, mask_channels) + def get_source( self, stream_info: dict, rdata: IOReaderData, idxs_cells_data, time_win: tuple, - keep_mask: torch.Tensor | None = None, + cell_mask: torch.Tensor, ): stream_id = stream_info["stream_id"] is_diagnostic = stream_info.get("diagnostic", False) @@ -92,22 +122,14 @@ def get_source( } return (source_tokens_cells, source_tokens_lens, mask_state) - # # create tokenization index + # create tokenization index (idxs_cells, idxs_cells_lens) = idxs_cells_data # select strategy from XXX depending on stream and if student or teacher - # Optional per-cell keep_mask (boolean) converts to numpy for Masker override. - if keep_mask is not None: - keep_np = keep_mask.cpu().numpy().astype(bool) - (mask_tokens, mask_channels) = self.masker.mask_source_idxs( - idxs_cells, idxs_cells_lens, keep_mask=keep_np - ) - else: - (mask_tokens, mask_channels) = self.masker.mask_source_idxs( - idxs_cells, - idxs_cells_lens, - ) + (mask_tokens, mask_channels) = self.cell_to_token_mask( + idxs_cells, idxs_cells_lens, cell_mask + ) source_tokens_cells, source_tokens_lens = tokenize_apply_mask_source( idxs_cells, @@ -183,20 +205,14 @@ def get_target_coords( rdata: IOReaderData, token_data, time_win: tuple, - mask_state: dict | None = None, + cell_mask, + # mask_state: dict | None = None, ): # create tokenization index (idxs_cells, idxs_cells_lens) = token_data - # Apply per-view mask state if provided - if mask_state is not None: - self.masker.current_strategy = mask_state.get("strategy", self.masker.masking_strategy) - self.masker.mask_tokens = mask_state.get("mask_tokens") - self.masker.mask_channels = mask_state.get("mask_channels") - - (mask_tokens, mask_channels, idxs_ord_inv) = self.masker.mask_targets_idxs( - idxs_cells, - idxs_cells_lens, + (mask_tokens, mask_channels) = self.cell_to_token_mask( + idxs_cells, idxs_cells_lens, cell_mask ) # TODO: split up @@ -214,33 +230,33 @@ def get_target_coords( encode_times_target, ) - selection = self._select_target_subset(stream_info, coords_local.shape[0]) - - if selection is not None and coords_local.numel() > 0: - # use nice index_select method - coords_local = coords_local.index_select(0, selection.to(coords_local.device)) - - # coords_per_cell is trickier - if selection is not None and coords_per_cell.numel() > 0: - total_points = int(coords_per_cell.sum().item()) - if total_points == 0: - coords_per_cell = torch.zeros_like(coords_per_cell) - else: - cell_ids = torch.repeat_interleave( - torch.arange(coords_per_cell.shape[0], dtype=torch.long), - coords_per_cell.to(torch.long), - ) - if cell_ids.numel() == 0: - coords_per_cell = torch.zeros_like(coords_per_cell) - else: - new_counts = torch.bincount( - cell_ids[selection.to(cell_ids.device)], - minlength=coords_per_cell.shape[0], - ) - coords_per_cell = new_counts.to(dtype=coords_per_cell.dtype) + # selection = self._select_target_subset(stream_info, coords_local.shape[0]) + + # if selection is not None and coords_local.numel() > 0: + # # use nice index_select method + # coords_local = coords_local.index_select(0, selection.to(coords_local.device)) + + # # coords_per_cell is trickier + # if selection is not None and coords_per_cell.numel() > 0: + # total_points = int(coords_per_cell.sum().item()) + # if total_points == 0: + # coords_per_cell = torch.zeros_like(coords_per_cell) + # else: + # cell_ids = torch.repeat_interleave( + # torch.arange(coords_per_cell.shape[0], dtype=torch.long), + # coords_per_cell.to(torch.long), + # ) + # if cell_ids.numel() == 0: + # coords_per_cell = torch.zeros_like(coords_per_cell) + # else: + # new_counts = torch.bincount( + # cell_ids[selection.to(cell_ids.device)], + # minlength=coords_per_cell.shape[0], + # ) + # coords_per_cell = new_counts.to(dtype=coords_per_cell.dtype) # pass the selection back for use in get_target_values - return (coords_local, coords_per_cell, selection) + return (coords_local, coords_per_cell) def get_target_values( self, @@ -249,21 +265,15 @@ def get_target_values( rdata: IOReaderData, token_data, time_win: tuple, - mask_state: dict | None = None, - selection: torch.Tensor | None = None, + cell_mask, + # mask_state: dict | None = None, + # selection: torch.Tensor | None = None, ): # create tokenization index (idxs_cells, idxs_cells_lens) = token_data - # Apply per-view mask state if provided - if mask_state is not None: - self.masker.current_strategy = mask_state.get("strategy", self.masker.masking_strategy) - self.masker.mask_tokens = mask_state.get("mask_tokens") - self.masker.mask_channels = mask_state.get("mask_channels") - - (mask_tokens, mask_channels, idxs_ord_inv) = self.masker.mask_targets_idxs( - idxs_cells, - idxs_cells_lens, + (mask_tokens, mask_channels) = self.cell_to_token_mask( + idxs_cells, idxs_cells_lens, cell_mask ) data, datetimes, coords, _, _ = tokenize_apply_mask_target( @@ -280,22 +290,25 @@ def get_target_values( encode_times_target, ) - if selection is None: - selection = self._select_target_subset(stream_info, data.shape[0]) + # if selection is None: + # selection = self._select_target_subset(stream_info, data.shape[0]) - if selection is not None and data.numel() > 0: - device_sel = selection.to(data.device) - data = data.index_select(0, device_sel) - coords = coords.index_select(0, device_sel) - if idxs_ord_inv.numel() > 0: - idxs_ord_inv = idxs_ord_inv.index_select(0, device_sel) + # if selection is not None and data.numel() > 0: + # device_sel = selection.to(data.device) + # data = data.index_select(0, device_sel) + # coords = coords.index_select(0, device_sel) + # if idxs_ord_inv.numel() > 0: + # idxs_ord_inv = idxs_ord_inv.index_select(0, device_sel) - # datetimes is numpy here - np_sel = selection.cpu().numpy() - datetimes = datetimes[np_sel] + # # datetimes is numpy here + # np_sel = selection.cpu().numpy() + # datetimes = datetimes[np_sel] # TODO: shuffling + # TODO: idxs_ord_inv + idxs_ord_inv = None + # selection not passed on, we call get_target_coords first return (data, datetimes, coords, idxs_ord_inv) From c8a26d7183edccf697287fb9b521605bc3825425 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 3 Dec 2025 00:11:37 +0100 Subject: [PATCH 138/344] Commit --- config/default_config.yml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 8d6ae4026..91d374100 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -121,17 +121,17 @@ training_config: model_input: masking_strategy: "healpix" # "random", "healpix". Masking strategy to use for model input for masking, and local (student) views when doing student-teacher rate: 0.4 # Masking rate to use for model input - num_views: 1 # if student-teacher, the number of local (student) views to generate + num_samples: 1 # if student-teacher, the number of local (student) views to generate hl_mask : 4 # healpix level to use for healpix masking strategy - relationship: "subset" # "independent", "subset", "disjoint". Relationship of student views to teacher view. + relationship: "complement" # "independent", "subset", "disjoint". Relationship of student views to teacher view. - teacher_model_input: - strategy: "healpix" # Strategy for teacher (global) view: "random", "healpix" - rate: 0.4 # Fraction of data to keep in global view (alternative: use "keep_m" for absolute count) - num_views: 1 # number of teacher views to generate - hl_mask : 0 # healpix level to use for healpix masking strategy - # keep_m: 100 # Alternative to rate: keep exactly this many parent cells - rate_sampling: true # randomly sample the rate per batch + # target_input: + # strategy: "healpix" # Strategy for teacher (global) view: "random", "healpix" + # rate: 0.4 # Fraction of data to keep in global view (alternative: use "keep_m" for absolute count) + # num_samples: 1 # number of teacher views to generate + # hl_mask : 0 # healpix level to use for healpix masking strategy + # # keep_m: 100 # Alternative to rate: keep exactly this many parent cells + # rate_sampling: true # randomly sample the rate per batch From 23e02679c9cd1110639aa74ff38799d99ff6a770 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 3 Dec 2025 00:11:48 +0100 Subject: [PATCH 139/344] Update --- uv.lock | 76 +++++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 61 insertions(+), 15 deletions(-) diff --git a/uv.lock b/uv.lock index 4cdcbdcc5..8d5e11878 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = "==3.12.*" resolution-markers = [ "platform_machine == 'aarch64' and sys_platform == 'linux'", @@ -513,7 +513,7 @@ wheels = [ [[package]] name = "earthkit-data" -version = "0.14.4" +version = "0.18.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cfgrib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, @@ -536,9 +536,9 @@ dependencies = [ { name = "tqdm", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "xarray", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/43/99/d41bf77e4769ed4146ad1d66b94b05d8c17ec168bc93f145b82bcfd40c40/earthkit_data-0.14.4.tar.gz", hash = "sha256:d3d5d7b920b57a4abdbfc3add56bf167bb2d1eec151b6f6d36abea766b06929a", size = 4851784, upload-time = "2025-06-06T17:00:24.357Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c5/a4/a78a78258093ea85f11bf2b5b90274403f0c88fe82c2b53070f4ab0d4bdb/earthkit_data-0.18.2.tar.gz", hash = "sha256:fbbb9ade7898b872456913af70dea2f680734cd414747dd368739804794670df", size = 5554363, upload-time = "2025-11-18T19:35:09.109Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c4/a3/50f356541db8b74565359b604412b4392bc4e053fbc8a45e3f14327517de/earthkit_data-0.14.4-py3-none-any.whl", hash = "sha256:e56a0fe22f13648ac0becb2e075d8ee41da351c072afced5f16154a0da1d4083", size = 369012, upload-time = "2025-06-06T17:00:22.195Z" }, + { url = "https://files.pythonhosted.org/packages/57/cb/d6d435c7ce7782fa3c7aaf260f779cab80f6944c13a1546a0a3aed797b69/earthkit_data-0.18.2-py3-none-any.whl", hash = "sha256:0c61b5f61c7decb921ff3543f9c73b4988b6f2c88d6e8b68ee1ee34bee9d3573", size = 389574, upload-time = "2025-11-18T19:35:07.334Z" }, ] [[package]] @@ -570,33 +570,57 @@ wheels = [ [[package]] name = "earthkit-utils" -version = "0.0.1" +version = "0.1.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "array-api-compat", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/51/d7/91ce33376f48a0dc5993008eebbb12415853fb87361ac849a834db505a35/earthkit_utils-0.0.1.tar.gz", hash = "sha256:8bb41d9b6c8cfc1e0d330cf9801183301e7febd03b6c87082ce3d52d129939e3", size = 19791, upload-time = "2025-04-04T16:00:49.79Z" } +sdist = { url = "https://files.pythonhosted.org/packages/06/50/8e4d6a75a11db4d6c86d5d921dfb3b89c71bc19231323f01c26fde449a43/earthkit_utils-0.1.2.tar.gz", hash = "sha256:f0e6059c6fc40cc0c7f76ac52e3725ed1b2d837e43f5218f946b2a7dc012ff82", size = 23716, upload-time = "2025-09-10T12:41:42.529Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6e/9b/568d8e53ea6804084284015d79a757c11fa41f907b56c523c9e11a7ee679/earthkit_utils-0.0.1-py3-none-any.whl", hash = "sha256:1732ac37d9c4c97f56b733526052c047c49854438a7fc35ed775e2c204a7d825", size = 15431, upload-time = "2025-04-04T16:00:48.607Z" }, + { url = "https://files.pythonhosted.org/packages/54/86/1ffe2e8af8dbfc2c691fe5653b109e2a587e5ab797a837cd5f7c86dd0dac/earthkit_utils-0.1.2-py3-none-any.whl", hash = "sha256:150cf68ce5228dadec1b50bc6f8ff0b68d69b702a4019056b3bd0149b1ed8236", size = 21458, upload-time = "2025-09-10T12:41:41.531Z" }, ] [[package]] name = "eccodes" -version = "2.41.0" +version = "2.44.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "attrs", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "cffi", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "eccodeslib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "findlibs", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "numpy", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8a/c2/e2f8d98dea0b2d8d77c99edb72b5831286abfaf80d94edf13ad127f6979a/eccodes-2.41.0.tar.gz", hash = "sha256:f3e209f5da5a7fcee4942295db4ee7888e077bd2e0342e6170ec5fedb9b29840", size = 2268345, upload-time = "2025-04-10T10:18:00.637Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1a/2a/9242d0a83de707ed401906a34bfe1d9a3af616abf498580ef73a6e8cebd5/eccodes-2.44.0.tar.gz", hash = "sha256:8aba9316749349e64db7d075100bff8e24a892814e3529132ec97b6d787eb8f4", size = 2310714, upload-time = "2025-10-03T14:02:37.462Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f2/a8/4d3b00f09440b269da208831b450a77e150ecfd1ac3981ca83d984ede4bd/eccodes-2.44.0-cp312-cp312-win_amd64.whl", hash = "sha256:20864247343bf88df88eafbf811fa90c290c45ed32d24f046238bd0f1684e16e", size = 7247248, upload-time = "2025-10-03T14:02:05.837Z" }, + { url = "https://files.pythonhosted.org/packages/dd/b8/9d15cea1f63fb2e1e14fda4160c355e6187e69b71b848c05faaae08b2e6c/eccodes-2.44.0-py3-none-any.whl", hash = "sha256:c3f11041bde7c3f53767c5bbed608c43695f257c09c58bb4de24bcd9cdae4e3a", size = 83465, upload-time = "2025-10-03T14:02:36.181Z" }, +] + +[[package]] +name = "eccodeslib" +version = "2.44.0.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "eckitlib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "fckitlib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] wheels = [ - { url = "https://files.pythonhosted.org/packages/35/d5/7803aa1bbff4161b147c11cd6531d421a2ad38a0bb2fd29a7265fb369c3d/eccodes-2.41.0-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:56803ac38e5b50377054cb2b944c8982d6fcfca0c25b4d04fe1ef94ba81b202c", size = 6557422, upload-time = "2025-04-10T10:14:50.202Z" }, - { url = "https://files.pythonhosted.org/packages/d1/17/bf8f714f5dd483d0da11515dbcb1b4f0992e900abef540c318a93b55edb9/eccodes-2.41.0-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:ae3a0f62a4b1107fe9f8362d58e4f452891ccbcc6aaecca5437724223f71a974", size = 6659630, upload-time = "2025-04-10T10:17:05.746Z" }, - { url = "https://files.pythonhosted.org/packages/45/4a/7a45f8fc7d8f2047b023befd17155fa7d2d1274feda9796b1e69b68b7033/eccodes-2.41.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:bb0182f7ce3c907671860ab554430be053f37f227789a7a27b2d06118fe48c97", size = 7431350, upload-time = "2025-04-10T10:13:40.244Z" }, - { url = "https://files.pythonhosted.org/packages/31/f7/f48a3ab347941243064060e3b0139aeb8d9414f1775f54239bed7fe66a29/eccodes-2.41.0-cp312-cp312-win_amd64.whl", hash = "sha256:81ca80d251be3fa66c42c020538cd67b12ed6e7c79e1e10299dc36dc07d28678", size = 6239631, upload-time = "2025-04-10T10:15:36.13Z" }, - { url = "https://files.pythonhosted.org/packages/bd/42/ac29e37149f36807e8f979707f5ae0d466d4a2c4b340597e2177809a016b/eccodes-2.41.0-py3-none-any.whl", hash = "sha256:f3f4444757aac6a249cc47947dee5660309d48854ebfc5e6ca8515374398e1bf", size = 44012, upload-time = "2025-04-10T10:17:59.189Z" }, + { url = "https://files.pythonhosted.org/packages/61/21/555e76b8dfa2ac050df8e638e9b91c6e671c3e2ba0abc2213e8df84d1e5c/eccodeslib-2.44.0.7-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:f28cebbfae6594ed393214f59b828b55238bdd2c61e4f533e96098c2e19bb47f", size = 8926805, upload-time = "2025-11-25T11:59:59.543Z" }, + { url = "https://files.pythonhosted.org/packages/1c/b8/e50cfc8588a85f31568ef02f6913b42d44e36c476cd1aaf61f2489e6749b/eccodeslib-2.44.0.7-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:e191ce8d33fce4c796fe6ffb57652e0faa19e61b3ae8e9d0adacb50fc824d77b", size = 8723732, upload-time = "2025-11-25T11:59:18.999Z" }, + { url = "https://files.pythonhosted.org/packages/55/7f/a81915d7693e8d46df61b44d5bbc1717c8b41deaf3084831b369191ee24c/eccodeslib-2.44.0.7-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:c1730012671b8c6a70001fc9f6fa4a557ca8d0888c2f76eae81ab6f978190cad", size = 20983542, upload-time = "2025-11-25T12:01:17.72Z" }, + { url = "https://files.pythonhosted.org/packages/b9/ae/a8f0fc3468e7d0e3cbcf7d2d51d55c53a785f7e3440f9b4546a0994b29b9/eccodeslib-2.44.0.7-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:61774756ea652e3bcea436e8bd6dcbbd3e044f0a41b39d748e110fada48ffdbe", size = 20853439, upload-time = "2025-11-25T12:04:15.242Z" }, +] + +[[package]] +name = "eckitlib" +version = "1.32.3.7" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9f/02/55468294aa6fb836d1a4d3d18459fad467e2f622df980e59181da2ed80a4/eckitlib-1.32.3.7-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:2344c7250b28f3cad2110ee6703c2a58714ed2d012c66fc6b38edb39eb567cd7", size = 2925833, upload-time = "2025-11-25T12:00:05.094Z" }, + { url = "https://files.pythonhosted.org/packages/91/19/33ba5777745f1f237ee6a549fb585afc6dde6f51672ea269d0285237214e/eckitlib-1.32.3.7-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:3964fee0a5cf828886957d87d053d42f29081f4d8b0f3b9f78fcc4a2401f6335", size = 3028987, upload-time = "2025-11-25T11:59:25.777Z" }, + { url = "https://files.pythonhosted.org/packages/7e/42/51dbb879c0e4b3a70dfa3463c24c41aca5097a2cddc68accacd1f7b572e8/eckitlib-1.32.3.7-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:de323af51b6560b22de2fb0ec4ad98c5f318975688526d04351b13428ad72de6", size = 43683895, upload-time = "2025-11-25T12:01:26.934Z" }, + { url = "https://files.pythonhosted.org/packages/40/88/2e751d24663b15a50e8aec49332020cb5e3c1305e6dc229e8cf396f92809/eckitlib-1.32.3.7-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3825984742c687db0feed540c8ba4bb26472859711e78c560711eba3fe6d12cf", size = 44585482, upload-time = "2025-11-25T12:04:28.06Z" }, ] [[package]] @@ -670,6 +694,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/90/2b/0817a2b257fe88725c25589d89aec060581aabf668707a8d03b2e9e0cb2a/fastjsonschema-2.21.1-py3-none-any.whl", hash = "sha256:c9e5b7e908310918cf494a434eeb31384dd84a98b57a30bcb1f535015b554667", size = 23924, upload-time = "2024-12-02T10:55:07.599Z" }, ] +[[package]] +name = "fckitlib" +version = "0.14.1.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "eckitlib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/6a/07/921d9adf99b4cb0983f4327f32e76718e88e1fbc78eb253e6a33ce1004e4/fckitlib-0.14.1.7-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:cd3213c33656e6bb7652cbbf4dac7b466d294f42188ef0f4c7f69aab124c006e", size = 411476, upload-time = "2025-11-25T12:00:07.712Z" }, + { url = "https://files.pythonhosted.org/packages/e8/e8/3339b155d2486a3710bf59274259bee846325b7bad5aaa269565e2b76838/fckitlib-0.14.1.7-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:073c97897e032e51ff64028fe4602e50a8092c40563d147f06641f0f7b4f8f23", size = 417158, upload-time = "2025-11-25T11:59:29.004Z" }, + { url = "https://files.pythonhosted.org/packages/bd/46/06c9fd28b580a8fc59f7a889a7710fdd4afe6f029325ac908d687bdbc3eb/fckitlib-0.14.1.7-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:6af3c23fa3aecf8dc45d7dc610ed542b807ae81e73f2e164c74e90ebbcc0252e", size = 1342966, upload-time = "2025-11-25T12:01:31.618Z" }, + { url = "https://files.pythonhosted.org/packages/f7/8f/d85d55b3582e168a0221a71b2f54c28b02b1d7ce78b37926cc6019da7945/fckitlib-0.14.1.7-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:d70bea78e3248780b4b49fce4683ebbf9419e2b5cfd9a9fdd9512ead8627aa3d", size = 12761273, upload-time = "2025-11-25T12:04:35.044Z" }, +] + [[package]] name = "filelock" version = "3.18.0" @@ -962,7 +1000,7 @@ name = "jinja2" version = "3.1.6" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "markupsafe", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "markupsafe", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/df/bf/f7da0350254c0ed7c72f3e33cef02e048281fec7ecec5f032d4aac52226b/jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d", size = 245115, upload-time = "2025-03-05T20:05:02.478Z" } wheels = [ @@ -2850,6 +2888,10 @@ version = "0.1.0" source = { editable = "packages/evaluate" } dependencies = [ { name = "cartopy", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "earthkit-data", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "eccodes", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "eccodeslib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "eckitlib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "omegaconf", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "panel", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "plotly", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, @@ -2870,6 +2912,10 @@ dev = [ [package.metadata] requires-dist = [ { name = "cartopy", specifier = ">=0.24.1" }, + { name = "earthkit-data", specifier = "==0.18.2" }, + { name = "eccodes", specifier = "==2.44.0" }, + { name = "eccodeslib", specifier = "==2.44.0.7" }, + { name = "eckitlib", specifier = "==1.32.3.7" }, { name = "omegaconf" }, { name = "panel" }, { name = "plotly", specifier = ">=6.2.0" }, From 9f5e49ce11a7eea8031a4ec963f3d595eb88e4d5 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 3 Dec 2025 00:20:25 +0100 Subject: [PATCH 140/344] Fixed uv.lock --- uv.lock | 81 ++++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 60 insertions(+), 21 deletions(-) diff --git a/uv.lock b/uv.lock index d64abc499..e7c15991d 100644 --- a/uv.lock +++ b/uv.lock @@ -553,12 +553,51 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/c3/253a89ee03fc9b9682f1541728eb66db7db22148cd94f89ab22528cd1e1b/deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a", size = 11178, upload-time = "2020-04-20T14:23:36.581Z" }, ] +[[package]] +name = "dill" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/12/80/630b4b88364e9a8c8c5797f4602d0f76ef820909ee32f0bacb9f90654042/dill-0.4.0.tar.gz", hash = "sha256:0633f1d2df477324f53a895b02c901fb961bdbf65a17122586ea7019292cbcf0", size = 186976, upload-time = "2025-04-16T00:41:48.867Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/50/3d/9373ad9c56321fdab5b41197068e1d8c25883b3fea29dd361f9b55116869/dill-0.4.0-py3-none-any.whl", hash = "sha256:44f54bf6412c2c8464c14e8243eb163690a9800dbe2c367330883b19c7561049", size = 119668, upload-time = "2025-04-16T00:41:47.671Z" }, +] + +[[package]] +name = "donfig" +version = "0.8.1.post1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyyaml", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/25/71/80cc718ff6d7abfbabacb1f57aaa42e9c1552bfdd01e64ddd704e4a03638/donfig-0.8.1.post1.tar.gz", hash = "sha256:3bef3413a4c1c601b585e8d297256d0c1470ea012afa6e8461dc28bfb7c23f52", size = 19506, upload-time = "2024-05-23T14:14:31.513Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl", hash = "sha256:2a3175ce74a06109ff9307d90a230f81215cbac9a751f4d1c6194644b8204f9d", size = 21592, upload-time = "2024-05-23T14:13:55.283Z" }, +] + [[package]] name = "earthkit-data" version = "0.18.2" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "cfgrib", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "dask", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "deprecation", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "earthkit-meteo", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "earthkit-utils", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "eccodes", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "entrypoints", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "filelock", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "jinja2", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "jsonschema", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "lru-dict", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "markdown", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "multiurl", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "netcdf4", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "pandas", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "pdbufr", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "pyyaml", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "tqdm", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "xarray", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/c5/a4/a78a78258093ea85f11bf2b5b90274403f0c88fe82c2b53070f4ab0d4bdb/earthkit_data-0.18.2.tar.gz", hash = "sha256:fbbb9ade7898b872456913af70dea2f680734cd414747dd368739804794670df", size = 5554363, upload-time = "2025-11-18T19:35:09.109Z" } wheels = [ @@ -609,11 +648,11 @@ name = "eccodes" version = "2.44.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "attrs", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, - { name = "cffi", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, - { name = "eccodeslib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, - { name = "findlibs", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, - { name = "numpy", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "attrs", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "cffi", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "eccodeslib", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "findlibs", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "numpy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/1a/2a/9242d0a83de707ed401906a34bfe1d9a3af616abf498580ef73a6e8cebd5/eccodes-2.44.0.tar.gz", hash = "sha256:8aba9316749349e64db7d075100bff8e24a892814e3529132ec97b6d787eb8f4", size = 2310714, upload-time = "2025-10-03T14:02:37.462Z" } wheels = [ @@ -626,8 +665,8 @@ name = "eccodeslib" version = "2.44.0.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "eckitlib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, - { name = "fckitlib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "eckitlib", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "fckitlib", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/61/21/555e76b8dfa2ac050df8e638e9b91c6e671c3e2ba0abc2213e8df84d1e5c/eccodeslib-2.44.0.7-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:f28cebbfae6594ed393214f59b828b55238bdd2c61e4f533e96098c2e19bb47f", size = 8926805, upload-time = "2025-11-25T11:59:59.543Z" }, @@ -726,7 +765,7 @@ name = "fckitlib" version = "0.14.1.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "eckitlib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "eckitlib", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/6a/07/921d9adf99b4cb0983f4327f32e76718e88e1fbc78eb253e6a33ce1004e4/fckitlib-0.14.1.7-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:cd3213c33656e6bb7652cbbf4dac7b466d294f42188ef0f4c7f69aab124c006e", size = 411476, upload-time = "2025-11-25T12:00:07.712Z" }, @@ -1064,7 +1103,7 @@ name = "jinja2" version = "3.1.6" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "markupsafe", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "markupsafe", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/df/bf/f7da0350254c0ed7c72f3e33cef02e048281fec7ecec5f032d4aac52226b/jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d", size = 245115, upload-time = "2025-03-05T20:05:02.478Z" } wheels = [ @@ -3054,18 +3093,18 @@ name = "weathergen-evaluate" version = "0.1.0" source = { editable = "packages/evaluate" } dependencies = [ - { name = "cartopy", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, - { name = "earthkit-data", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, - { name = "eccodes", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, - { name = "eccodeslib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, - { name = "eckitlib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, - { name = "omegaconf", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, - { name = "panel", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, - { name = "plotly", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, - { name = "weathergen-common", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, - { name = "weathergen-metrics", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, - { name = "xhistogram", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, - { name = "xskillscore", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "cartopy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "earthkit-data", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "eccodes", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "eccodeslib", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "eckitlib", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "omegaconf", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "panel", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "plotly", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "weathergen-common", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "weathergen-metrics", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "xhistogram", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "xskillscore", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, ] [package.dev-dependencies] From 3641e1f81cf0b5384a33591fb7e157ed3e7b667a Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 3 Dec 2025 00:20:42 +0100 Subject: [PATCH 141/344] Fix for integration test --- integration_tests/streams/era5_small.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/integration_tests/streams/era5_small.yml b/integration_tests/streams/era5_small.yml index 2a06eb7df..6910a427a 100644 --- a/integration_tests/streams/era5_small.yml +++ b/integration_tests/streams/era5_small.yml @@ -9,6 +9,7 @@ ERA5 : type : anemoi + stream_id: 0 filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] loss_weight : 1. source_exclude : ['w_', 'skt', 'sp', 'tcw', 'cp', 'tp'] @@ -36,4 +37,4 @@ ERA5 : num_heads : 2 pred_head : ens_size : 1 - num_layers : 1 \ No newline at end of file + num_layers : 1 From 9a1a6a94de20a36a3808066f5b78e16cba47d4b5 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 3 Dec 2025 13:12:52 +0100 Subject: [PATCH 142/344] Re-enabled multi-source training --- src/weathergen/datasets/multi_stream_data_sampler.py | 1 - src/weathergen/train/trainer.py | 10 ++++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 3f7d08935..e72985950 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -641,7 +641,6 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # add data for current stream streams_data += [v for k, v in stream_data_source.items()] - elif mode == "diffusion_forecast": streams_data: list[StreamData] = [] diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index cad092787..58995055f 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -560,8 +560,9 @@ def train(self, mini_epoch): outputs = [] for view in batch[-1].source_samples: # TODO remove when ModelBatch and Sample get a to_device() - streams_data = [[view.streams_data["ERA5"]]] - streams_data = [[d.to_device(self.device) for d in db] for db in streams_data] + streams_data = [ + [v.to_device(self.device) for k, v in view.streams_data.items()] + ] source_cell_lens = view.source_cell_lens source_cell_lens = [b.to(self.device) for b in source_cell_lens] target_coords_idxs = view.target_coords_idx @@ -580,8 +581,9 @@ def train(self, mini_epoch): targets_and_auxs = [] for view in batch[-1].target_samples: # TODO remove when ModelBatch and Sample get a to_device() - streams_data = [[view.streams_data["ERA5"]]] - streams_data = [[d.to_device(self.device) for d in db] for db in streams_data] + streams_data = [ + [v.to_device(self.device) for k, v in view.streams_data.items()] + ] source_cell_lens = view.source_cell_lens source_cell_lens = [b.to(self.device) for b in source_cell_lens] target_coords_idxs = view.target_coords_idx From 402b8de1ebd06b3ec32b5e84388d0ad8e00ae8a8 Mon Sep 17 00:00:00 2001 From: Julian Kuehnert Date: Wed, 3 Dec 2025 17:11:15 +0100 Subject: [PATCH 143/344] 1390 - Adapt forward pass of new batch object (#1391) * Add to device to ModelBatch, etc & adapt model TODO adapt validate and inference TODO test forecasting and multiple stream because predict changed substantially * Rename view to sample and fix validate * Revert predict function and fix inference * Fix invalid access with mask * Linting * Fixed handling of target_idxs and other minor issues --------- Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com> Co-authored-by: Christian Lessig --- packages/common/src/weathergen/common/io.py | 2 +- src/weathergen/datasets/batch.py | 32 ++++++++- .../datasets/multi_stream_data_sampler.py | 6 +- src/weathergen/datasets/utils.py | 32 ++++----- src/weathergen/model/engines.py | 60 ++++++++-------- src/weathergen/model/model.py | 25 ++++--- .../loss_modules/loss_module_physical.py | 8 ++- src/weathergen/train/trainer.py | 71 +++++++++++-------- 8 files changed, 139 insertions(+), 97 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 2d10f157c..4978bc777 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -651,7 +651,7 @@ def _extract_sources( source: IOReaderData = self.sources[sample][stream_idx] assert source.data.shape[1] == len(channels), ( - "Number of source channel names does not align with source data" + f"Number of source channel names {len(channels)} does not align with source data." ) source_dataset = OutputDataset( diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index 190de38ad..c4b68624b 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -29,7 +29,7 @@ class SampleMetaData: class Sample: # keys: stream name, values: SampleMetaData - meta_info: dict + meta_info: dict[str | SampleMetaData] # data for all streams # keys: stream_name, values: StreamData @@ -40,6 +40,7 @@ class Sample: # these two need to live in ModelBatch as they are flattened! # this should be a dict also lives in ModelBatch source_cell_lens: list[torch.Tensor] | None + # TODO why is this a list of lists in practice, but the type says list of tensors? target_coords_idx: list[torch.Tensor] | None def __init__(self, streams: dict) -> None: @@ -55,6 +56,27 @@ def __init__(self, streams: dict) -> None: self.forecast_dt: int | None = None + def to_device(self, device) -> None: + if self.source_cell_lens is not None: + # iterate over forecast steps + self.source_cell_lens = [t.to(device) for t in self.source_cell_lens] + + if self.target_coords_idx is not None: + target_coords_idx_new = {} + for k, v in self.target_coords_idx.items(): + # iterate over forecast steps + target_coords_idx_new[k] = [vv.to(device) for vv in v] + self.target_coords_idx = target_coords_idx_new + + for key in self.meta_info.keys(): + self.meta_info[key].mask = ( + self.meta_info[key].mask.to(device) if self.meta_info[key].mask else None + ) + + for key, val in self.streams_data.items(): + if val is not None: + self.streams_data[key] = val.to_device(device) + def add_stream_data(self, stream_name: str, stream_data: StreamData) -> None: """ Add data for stream @stream_name to sample @@ -105,6 +127,7 @@ class ModelBatch: # index of corresponding target (for source samples) or source (for target samples) # these are in 1-to-1 corresponding for classical training modes (MTM, forecasting) but # can be more complex for strategies like student-teacher training + # TODO @CL and @SHickman can we make these tensors? source2target_matching_idxs: np.typing.NDArray[np.int32] target2source_matching_idxs: np.typing.NDArray[np.int32] @@ -118,6 +141,13 @@ def __init__(self, streams, num_source_samples: int, num_target_samples: int) -> # self.target_source_matching_idxs = np.full(num_target_samples, -1, dtype=np.int32) self.target2source_matching_idxs = [[] for _ in range(num_target_samples)] + def to_device(self, device): + for sample in self.source_samples: + sample.to_device(device) + + for sample in self.target_samples: + sample.to_device(device) + def add_source_stream( self, source_sample_idx: int, diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index e72985950..4b906dfe0 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -783,8 +783,10 @@ def _preprocess_model_data(self, batch, forecast_dt): # compute offsets and auxiliary data needed for prediction computation # (info is not per stream so separate data structure) - ##### target_coords_idx we probably don't need for the targets ##### - target_coords_idx = compute_idxs_predict(self.forecast_offset + forecast_dt, batch) + # TODO: only use when targets are predicted with decoders + target_coords_idx = compute_idxs_predict( + self.forecast_offset + forecast_dt, batch, self.streams + ) return batch, source_cell_lens, target_coords_idx diff --git a/src/weathergen/datasets/utils.py b/src/weathergen/datasets/utils.py index 3d92e5a66..726936bb5 100644 --- a/src/weathergen/datasets/utils.py +++ b/src/weathergen/datasets/utils.py @@ -333,7 +333,7 @@ def compute_offsets_scatter_embed(batch: StreamData, num_input_steps: int) -> St return batch -def compute_idxs_predict(forecast_dt: int, batch: StreamData) -> list: +def compute_idxs_predict(forecast_dt: int, batch: StreamData, streams: list[dict]) -> list: """ Compute auxiliary information for prediction @@ -353,26 +353,24 @@ def compute_idxs_predict(forecast_dt: int, batch: StreamData) -> list: target_coords_lens = [[s.target_coords_lens for s in sb] for sb in batch] # target coords idxs - tcs_lens_merged = [] + tcs_lens_merged = {} pad = torch.zeros(1, dtype=torch.int32) for ii in range(len(batch[0])): # generate len lists for varlen attention (per batch list for local, per-cell attention and # global - tcs_lens_merged += [ - [ - torch.cat( - [ - pad, - torch.cat( - [ - target_coords_lens[i_b][ii][fstep] - for i_b in range(len(target_coords_lens)) - ] - ), - ] - ).to(torch.int32) - for fstep in range(forecast_dt + 1) - ] + tcs_lens_merged[streams[ii]["name"]] = [ + torch.cat( + [ + pad, + torch.cat( + [ + target_coords_lens[i_b][ii][fstep] + for i_b in range(len(target_coords_lens)) + ] + ), + ] + ).to(torch.int32) + for fstep in range(forecast_dt + 1) ] return tcs_lens_merged diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index e347219ac..248f61fe3 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -92,37 +92,35 @@ def forward(self, streams_data, source_cell_lens, pe_embed, dtype, device): # TODO: handling of input steps should be done using encoder # iterate over all input steps and streams for istep in range(num_step_input): - # TODO: what is this list dimension??? Where should the istep index be??? - for _, sb in enumerate(streams_data): - for stream_name, s_data in zip(self.stream_names, sb, strict=True): - # embedding network - embed = self.embeds[stream_name] - - # skip empty stream - if s_data.source_empty(): - continue - - idxs = s_data.source_idxs_embed[istep].to(device) - idxs_pe = s_data.source_idxs_embed_pe[istep].to(device) - - # create full scatter index - # (there's no broadcasting which is likely highly inefficient) - idxs = idxs.unsqueeze(1).repeat((1, self.cf.ae_local_dim_embed)) - x_embed = embed(s_data.source_tokens_cells[istep]).flatten(0, 1) - # there's undocumented limitation in flash_attn that will make embed fail if - # #tokens is too large; code below is a work around - # x_embed = torch.cat( - # [ - # embed(s_c, c_c).flatten(0, 1) - # for s_c, c_c in zip( - # torch.split(s.source_tokens_cells, 49152), - # torch.split(s.source_centroids, 49152), - # ) - # ] - # ) - - # scatter write to reorder from per stream to per cell ordering - tokens_all[istep].scatter_(0, idxs, x_embed + pe_embed[idxs_pe]) + for stream_name, s_data in streams_data.items(): + # embedding network + embed = self.embeds[stream_name] + + # skip empty stream + if s_data.source_empty(): + continue + + idxs = s_data.source_idxs_embed[istep].to(device) + idxs_pe = s_data.source_idxs_embed_pe[istep].to(device) + + # create full scatter index + # (there's no broadcasting which is likely highly inefficient) + idxs = idxs.unsqueeze(1).repeat((1, self.cf.ae_local_dim_embed)) + x_embed = embed(s_data.source_tokens_cells[istep]).flatten(0, 1) + # there's undocumented limitation in flash_attn that will make embed fail if + # #tokens is too large; code below is a work around + # x_embed = torch.cat( + # [ + # embed(s_c, c_c).flatten(0, 1) + # for s_c, c_c in zip( + # torch.split(s.source_tokens_cells, 49152), + # torch.split(s.source_centroids, 49152), + # ) + # ] + # ) + + # scatter write to reorder from per stream to per cell ordering + tokens_all[istep].scatter_(0, idxs, x_embed + pe_embed[idxs_pe]) return tokens_all[0] diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index c34c2f1bc..e97c99d25 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -605,7 +605,6 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca Returns: A list containing all prediction results """ - (streams_data, source_cell_lens, target_coords_idxs) = batch # embed @@ -644,7 +643,8 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca model_params, forecast_offset + forecast_steps, tokens, - streams_data, + # TODO We add the batch dimension back and thus wrap stream_data in a list + [streams_data], target_coords_idxs, ) ] @@ -857,6 +857,7 @@ def predict( fstep : Number of forecast steps tokens : Tokens from global assimilation engine streams_data : Used to initialize target coordinates tokens and index information + List of StreamData len(streams_data) == batch_size_per_gpu target_coords_idxs : Indices of target coordinates Returns: Prediction output tokens in physical representation for each target_coords. @@ -872,8 +873,7 @@ def predict( # pair with tokens from assimilation engine to obtain target tokens preds_tokens = [] - for idx, stream_name in enumerate(self.stream_names): - si = self.cf.streams[idx] + for stream_name in self.stream_names: tte = self.target_token_engines[stream_name] tte_kv = self.pred_adapter_kv[stream_name] tc_embed = self.embed_target_coords[stream_name] @@ -887,18 +887,18 @@ def predict( [ checkpoint( tc_embed, - streams_data[i_b][idx].target_coords[fstep], + streams_data[i_b][stream_name].target_coords[fstep], use_reentrant=False, ) - if len(streams_data[i_b][idx].target_coords[fstep].shape) > 1 - else streams_data[i_b][idx].target_coords[fstep] - for i_b in range(len(streams_data)) + if len(streams_data[i_b][stream_name].target_coords[fstep].shape) > 1 + else streams_data[i_b][stream_name].target_coords[fstep] + for i_b in range(len(streams_data)) # i_b is the index over the batch dimension ] ) # skip when coordinate embeddings yields nan (i.e. the coord embedding network diverged) if torch.isnan(tc_tokens).any(): - nn = si["name"] + nn = stream_name if is_root(): logger.warning( ( @@ -919,10 +919,13 @@ def predict( assert isinstance(tte_kv, torch.nn.Identity) # lens for varlen attention - tcs_lens = target_coords_idxs[idx][fstep] + tcs_lens = target_coords_idxs[stream_name][fstep] # coord information for learnable layer norm tcs_aux = torch.cat( - [streams_data[i_b][idx].target_coords[fstep] for i_b in range(len(streams_data))] + [ + streams_data[i_b][stream_name].target_coords[fstep] + for i_b in range(len(streams_data)) + ] ) tc_tokens = tte( diff --git a/src/weathergen/train/loss_modules/loss_module_physical.py b/src/weathergen/train/loss_modules/loss_module_physical.py index 156f0c97a..0f523b409 100644 --- a/src/weathergen/train/loss_modules/loss_module_physical.py +++ b/src/weathergen/train/loss_modules/loss_module_physical.py @@ -212,18 +212,20 @@ def compute_loss( # TODO: iterate over batch dimension i_batch = 0 + streams_data = [streams_data] for i_stream_info, stream_info in enumerate(self.cf.streams): + stream_name = stream_info["name"] # extract target tokens for current stream from the specified forecast offset onwards - targets = streams_data[i_batch][i_stream_info].target_tokens[self.cf.forecast_offset :] + targets = streams_data[i_batch][stream_name].target_tokens[self.cf.forecast_offset :] - stream_data = streams_data[i_batch][i_stream_info] + stream_data = streams_data[i_batch][stream_name] fstep_loss_weights = self._get_fstep_weights(len(targets)) loss_fsteps = torch.tensor(0.0, device=self.device, requires_grad=True) ctr_fsteps = 0 - stream_is_spoof = streams_data[i_batch][i_stream_info].is_spoof() + stream_is_spoof = streams_data[i_batch][stream_name].is_spoof() if stream_is_spoof: spoof_weight = torch.tensor(0.0, device=self.device, requires_grad=False) else: diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 58995055f..ca32efe06 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -463,7 +463,7 @@ def _prepare_logging( continue for i_strm, target in enumerate(targets_rt[fstep]): - pred = preds[fstep][i_strm] + pred = preds.physical[fstep][i_strm] idxs_inv = idxs_inv_rt[fstep][i_strm] if not (target.shape[0] > 0 and pred.shape[0] > 0): @@ -558,42 +558,31 @@ def train(self, mini_epoch): enabled=cf.with_mixed_precision, ): outputs = [] - for view in batch[-1].source_samples: - # TODO remove when ModelBatch and Sample get a to_device() - streams_data = [ - [v.to_device(self.device) for k, v in view.streams_data.items()] - ] - source_cell_lens = view.source_cell_lens - source_cell_lens = [b.to(self.device) for b in source_cell_lens] - target_coords_idxs = view.target_coords_idx - target_coords_idxs = [ - [b.to(self.device) for b in bf] for bf in target_coords_idxs - ] + batch[-1].to_device(self.device) + for sample in batch[-1].source_samples: outputs.append( self.model( self.model_params, - (streams_data, source_cell_lens, target_coords_idxs), + ( + sample.streams_data, + sample.source_cell_lens, + sample.target_coords_idx, + ), cf.forecast_offset, forecast_steps, ) ) targets_and_auxs = [] - for view in batch[-1].target_samples: - # TODO remove when ModelBatch and Sample get a to_device() - streams_data = [ - [v.to_device(self.device) for k, v in view.streams_data.items()] - ] - source_cell_lens = view.source_cell_lens - source_cell_lens = [b.to(self.device) for b in source_cell_lens] - target_coords_idxs = view.target_coords_idx - target_coords_idxs = [ - [b.to(self.device) for b in bf] for bf in target_coords_idxs - ] + for sample in batch[-1].target_samples: targets_and_auxs.append( self.target_and_aux_calculator.compute( self.cf.istep, - (streams_data, source_cell_lens, target_coords_idxs), + ( + sample.streams_data, + sample.source_cell_lens, + sample.target_coords_idx, + ), self.model_params, self.model, cf.forecast_offset, @@ -738,8 +727,10 @@ def validate(self, mini_epoch): total=len(self.data_loader_validation), disable=self.cf.with_ddp ) as pbar: for bidx, batch in enumerate(dataset_val_iter): - forecast_steps = batch[-1] - batch = self.batch_to_device(batch) + forecast_steps = batch[0][-1] + old_batch = batch[0] + batch = batch[-1] + batch.to_device(self.device) # evaluate model with torch.autocast( @@ -752,12 +743,25 @@ def validate(self, mini_epoch): if self.ema_model is None else self.ema_model.forward_eval ) + sample = batch.source_samples[0] output = model_forward( - self.model_params, batch, cf.forecast_offset, forecast_steps + self.model_params, + ( + sample.streams_data, + sample.source_cell_lens, + sample.target_coords_idx, + ), + cf.forecast_offset, + forecast_steps, ) + sample = batch.target_samples[0] target_aux_output = self.target_and_aux_calculator.compute( bidx, - batch, + ( + sample.streams_data, + sample.source_cell_lens, + sample.target_coords_idx, + ), self.model_params, self.model, cf.forecast_offset, @@ -771,7 +775,12 @@ def validate(self, mini_epoch): # log output if bidx < cf.log_validation: # TODO: Move _prepare_logging into write_validation by passing streams_data - streams_data: list[list[StreamData]] = batch[0] + # TODO right now we hardcode ERA5 which obviously is bad, but not sure + # how this logging function is supposed to change + streams_data: list[list[StreamData]] = old_batch[0] + import pdb + + pdb.set_trace() ( preds_all, targets_all, @@ -791,7 +800,7 @@ def validate(self, mini_epoch): self.cf, mini_epoch, bidx, - sources, + sources[0], preds_all, targets_all, targets_coords_all, From 2cd397170bbf89b77da06bb92db8e12548828386 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 3 Dec 2025 23:56:41 +0100 Subject: [PATCH 144/344] Completed migration to new batch class by removing reference to old list of lists --- src/weathergen/datasets/batch.py | 45 +++++++--- .../datasets/multi_stream_data_sampler.py | 63 ++++---------- src/weathergen/train/trainer.py | 83 +++++++++---------- 3 files changed, 91 insertions(+), 100 deletions(-) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index c4b68624b..d2292f12b 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -34,7 +34,6 @@ class Sample: # data for all streams # keys: stream_name, values: StreamData streams_data: dict[str, StreamData | None] - forecast_dt: int | None # TODO: # these two need to live in ModelBatch as they are flattened! @@ -54,8 +53,6 @@ def __init__(self, streams: dict) -> None: self.source_cell_lens: list[torch.Tensor] | None = None self.target_coords_idx: list[torch.Tensor] | None = None - self.forecast_dt: int | None = None - def to_device(self, device) -> None: if self.source_cell_lens is not None: # iterate over forecast steps @@ -77,6 +74,12 @@ def to_device(self, device) -> None: if val is not None: self.streams_data[key] = val.to_device(device) + def is_empty(self) -> bool: + """ + Check if sample is empty + """ + return np.all(np.array([s.empty() for _, s in self.streams_data.items()])) + def add_stream_data(self, stream_name: str, stream_data: StreamData) -> None: """ Add data for stream @stream_name to sample @@ -97,14 +100,6 @@ def set_preprocessed(self, source_cell_lens, target_coords_idx): self.source_cell_lens = source_cell_lens self.target_coords_idx = target_coords_idx - def set_forecast_dt(self, forecast_dt: int) -> None: - """ - Set forecast_dt for sample - """ - self.forecast_dt = forecast_dt - - # TODO: complete interface, e.g get_stream - def get_stream_data(self, stream_name: str) -> StreamData: """ Get data for stream @stream_name from sample @@ -131,7 +126,11 @@ class ModelBatch: source2target_matching_idxs: np.typing.NDArray[np.int32] target2source_matching_idxs: np.typing.NDArray[np.int32] - def __init__(self, streams, num_source_samples: int, num_target_samples: int) -> None: + forecast_dt: int | None + + def __init__( + self, streams, num_source_samples: int, num_target_samples: int, forecast_dt: int + ) -> None: """ """ self.source_samples = [Sample(streams) for _ in range(num_source_samples)] @@ -141,6 +140,8 @@ def __init__(self, streams, num_source_samples: int, num_target_samples: int) -> # self.target_source_matching_idxs = np.full(num_target_samples, -1, dtype=np.int32) self.target2source_matching_idxs = [[] for _ in range(num_target_samples)] + self.forecast_dt = forecast_dt + def to_device(self, device): for sample in self.source_samples: sample.to_device(device) @@ -193,6 +194,26 @@ def add_target_stream( ) self.target2source_matching_idxs[target_sample_idx] = source_sample_idx + def is_empty(self): + """ + Check if batch is empty + """ + source_empty = np.all(np.array([s.is_empty() for s in self.source_samples])) + target_empty = np.all(np.array([s.is_empty() for s in self.target_samples])) + return source_empty or target_empty + + def set_forecast_dt(self, forecast_dt: int) -> None: + """ + Set forecast_dt for sample + """ + self.forecast_dt = forecast_dt + + def get_forecast_dt(self) -> int: + """ + Get forecast_dt + """ + return self.forecast_dt + def len_sources(self) -> int: """ Number of source samples diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 4b906dfe0..02eaae797 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -514,7 +514,7 @@ def _get_data_windows(self, base_idx, forecast_dt, stream_ds): return (input_data, output_data) - def _get_sample(self, mode: str, idx: int, forecast_dt: int): + def _get_sample(self, idx: int, forecast_dt: int): """ modes : @@ -525,12 +525,12 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): TODO: these modes are not being used now. """ + mode = self.training_cfg.get("training_mode") + # get/coordinate masks masks_streams = self._get_source_target_masks(mode) if mode == "masking" or mode == "student_teacher": - streams_data: list[StreamData] = [] - # Determine number of views direct from config (teacher & student views) target_cfg = self.training_cfg.get("target_input", {}) if self.training_cfg else {} target_cfg = target_cfg if target_cfg is not None else {} @@ -541,7 +541,7 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): source_cfg.get("num_samples", 1) ) # per teacher - batch = ModelBatch(self.streams, num_source_samples, num_target_samples) + batch = ModelBatch(self.streams, num_source_samples, num_target_samples, forecast_dt) # for all streams for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): @@ -591,7 +591,6 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): t_idx = student_to_teacher[sidx] batch.add_source_stream(sidx, t_idx, name, sdata, source_metadata_list[sidx]) # num_input_steps? - batch.source_samples[sidx].set_forecast_dt(forecast_dt) # stream_data_target can contain network input stream_data_target = {} @@ -629,7 +628,6 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): s_idx for s_idx, tid in enumerate(student_to_teacher) if tid == sidx ] batch.add_target_stream(t_idx, student_indices, name, sdata, target_metadata) - batch.target_samples[t_idx].set_forecast_dt(forecast_dt) # TODO: build batch # source_input @@ -637,13 +635,7 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # source_output # target_output - # TOOD: remove - # add data for current stream - streams_data += [v for k, v in stream_data_source.items()] - elif mode == "diffusion_forecast": - streams_data: list[StreamData] = [] - # Determine number of views direct from config (teacher & student views) teacher_cfg = self.training_cfg.get("target_input", {}) if self.training_cfg else {} student_cfg = self.training_cfg.get("model_input", {}) if self.training_cfg else {} @@ -695,8 +687,6 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # Map each student (source) to its teacher (target) batch.add_source_stream(0, 0, name, sdata, source_metadata) - # num_input_steps? - batch.source_samples[0].set_forecast_dt(forecast_dt) # stream_data_target can contain network input stream_data_target = {} @@ -726,7 +716,6 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # Map target to all source students batch.add_target_stream(0, 0, name, sdata, target_metadata) - batch.target_samples[0].set_forecast_dt(forecast_dt) # TODO: build batch # source_input @@ -734,14 +723,10 @@ def _get_sample(self, mode: str, idx: int, forecast_dt: int): # source_output # target_output - # TOOD: remove - # add data for current stream - streams_data += [v for k, v in stream_data_source.items()] - else: assert False, "Mode not implemented" - return streams_data, batch + return batch def _get_source_target_masks(self, training_mode): """ @@ -790,7 +775,7 @@ def _preprocess_model_data(self, batch, forecast_dt): return batch, source_cell_lens, target_coords_idx - def _preprocess_single_view(self, sample: Sample, forecast_dt: int): + def _preprocess_model_batch_sample(self, sample: Sample, forecast_dt: int): """ """ streams = [sd for sd in sample.streams_data.values() if sd is not None] if not streams: @@ -799,11 +784,11 @@ def _preprocess_single_view(self, sample: Sample, forecast_dt: int): _, scl, tci = self._preprocess_model_data([streams], forecast_dt) sample.set_preprocessed(scl, tci) - def _preprocess_model_batch_views(self, model_batch: ModelBatch, forecast_dt: int): + def _preprocess_model_batch(self, model_batch: ModelBatch, forecast_dt: int): for sample in model_batch.source_samples: - self._preprocess_single_view(sample, forecast_dt) + self._preprocess_model_batch_sample(sample, forecast_dt) for sample in model_batch.target_samples: - self._preprocess_single_view(sample, forecast_dt) + self._preprocess_model_batch_sample(sample, forecast_dt) def __iter__(self): """ @@ -829,8 +814,7 @@ def __iter__(self): # use while loop due to the scattered nature of the data in time and to # ensure batches are not empty - batch = [] - while len(batch) < self.batch_size: + while True: idx: TIndex = self.perms[idx_raw % self.perms.shape[0]] idx_raw += 1 @@ -838,32 +822,21 @@ def __iter__(self): if hasattr(self.tokenizer, "masker"): self.tokenizer.masker.set_batch_strategy() - # # TODO: ideally update this student-teacher if-else to a more general - # # view-based data sampling - # if self.training_cfg.get("training_mode") == "student_teacher": - - mode = self.training_cfg.get("training_mode") - - streams_data, student_teacher_batch = self._get_sample(mode, idx, forecast_dt) + batch = self._get_sample(idx, forecast_dt) # Reset masking strategy for next batch item if hasattr(self.tokenizer, "masker"): self.tokenizer.masker.reset_batch_strategy() - # skip completely empty batch item or when all targets are empty -> no grad - if not (all(s.empty() or s.target_empty() for s in streams_data)): - batch += [streams_data] - - # TODO: link into ModelBatch - - # compute - batch, source_cell_lens, target_coords_idx = self._preprocess_model_data( - batch, forecast_dt - ) + # # skip completely empty batch item or when all targets are empty -> no grad + if not batch.is_empty(): + break + else: + logger.warning("Skipping empty batch.") - self._preprocess_model_batch_views(student_teacher_batch, forecast_dt) + self._preprocess_model_batch(batch, forecast_dt) - yield (batch, source_cell_lens, target_coords_idx, forecast_dt), student_teacher_batch + yield batch def __len__(self): return self.len diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index ca32efe06..ff6679399 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -26,7 +26,6 @@ import weathergen.common.config as config from weathergen.common.config import Config from weathergen.datasets.multi_stream_data_sampler import MultiStreamDataSampler -from weathergen.datasets.stream_data import StreamData from weathergen.model.ema import EMAModel from weathergen.model.model_interface import ( get_target_aux_calculator, @@ -38,7 +37,6 @@ from weathergen.utils.distributed import all_gather_vlen, ddp_average, is_root from weathergen.utils.train_logger import TRAIN, VAL, Stage, TrainLogger from weathergen.utils.utils import get_batch_size, get_dtype -from weathergen.utils.validation_io import write_output logger = logging.getLogger(__name__) @@ -547,10 +545,11 @@ def train(self, mini_epoch): # [[tensor([0, 0, 0, ..., 4, 4, 4], device='cuda:0', dtype=torch.int32)]] # TODO: access from new ModelBatch - forecast_steps = batch[0][-1] + forecast_steps = batch.get_forecast_dt() # batch = self.batch_to_device(batch) ### After to_device, then the original is: + batch.to_device(self.device) with torch.autocast( device_type=f"cuda:{cf.local_rank}", @@ -558,8 +557,7 @@ def train(self, mini_epoch): enabled=cf.with_mixed_precision, ): outputs = [] - batch[-1].to_device(self.device) - for sample in batch[-1].source_samples: + for sample in batch.source_samples: outputs.append( self.model( self.model_params, @@ -574,7 +572,7 @@ def train(self, mini_epoch): ) targets_and_auxs = [] - for sample in batch[-1].target_samples: + for sample in batch.target_samples: targets_and_auxs.append( self.target_and_aux_calculator.compute( self.cf.istep, @@ -727,9 +725,7 @@ def validate(self, mini_epoch): total=len(self.data_loader_validation), disable=self.cf.with_ddp ) as pbar: for bidx, batch in enumerate(dataset_val_iter): - forecast_steps = batch[0][-1] - old_batch = batch[0] - batch = batch[-1] + forecast_steps = batch.get_forecast_dt() batch.to_device(self.device) # evaluate model @@ -774,40 +770,41 @@ def validate(self, mini_epoch): # log output if bidx < cf.log_validation: - # TODO: Move _prepare_logging into write_validation by passing streams_data - # TODO right now we hardcode ERA5 which obviously is bad, but not sure - # how this logging function is supposed to change - streams_data: list[list[StreamData]] = old_batch[0] - import pdb - - pdb.set_trace() - ( - preds_all, - targets_all, - targets_coords_all, - targets_times_all, - targets_lens, - ) = self._prepare_logging( - preds=output, - forecast_offset=cf.forecast_offset, - forecast_steps=cf.forecast_steps, - streams_data=streams_data, - ) - sources = [[item.source_raw for item in stream] for stream in streams_data] - # sample idx should be the same across streams => select first - sample_idxs = [item.sample_idx for item in streams_data[0]] - write_output( - self.cf, - mini_epoch, - bidx, - sources[0], - preds_all, - targets_all, - targets_coords_all, - targets_times_all, - targets_lens, - sample_idxs, - ) + logger.warning("logging of data currently not implemented") + # # TODO: Move _prepare_logging into write_validation by passing streams_data + # # TODO right now we hardcode ERA5 which obviously is bad, but not sure + # # how this logging function is supposed to change + # streams_data: list[list[StreamData]] = old_batch[0] + # import pdb + + # pdb.set_trace() + # ( + # preds_all, + # targets_all, + # targets_coords_all, + # targets_times_all, + # targets_lens, + # ) = self._prepare_logging( + # preds=output, + # forecast_offset=cf.forecast_offset, + # forecast_steps=cf.forecast_steps, + # streams_data=streams_data, + # ) + # sources = [[item.source_raw for item in stream] for stream in streams_data] + # # sample idx should be the same across streams => select first + # sample_idxs = [item.sample_idx for item in streams_data[0]] + # write_output( + # self.cf, + # mini_epoch, + # bidx, + # sources[0], + # preds_all, + # targets_all, + # targets_coords_all, + # targets_times_all, + # targets_lens, + # sample_idxs, + # ) # Collecting loss statistics for later inspection if bidx == 0: From 51754fa5fe9e9057aaa193cc86b02d93b2f4ce7b Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 4 Dec 2025 00:00:20 +0100 Subject: [PATCH 145/344] Fixed missing non_blocking=True in to_device() --- src/weathergen/datasets/batch.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index d2292f12b..ffcf600fc 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -56,23 +56,25 @@ def __init__(self, streams: dict) -> None: def to_device(self, device) -> None: if self.source_cell_lens is not None: # iterate over forecast steps - self.source_cell_lens = [t.to(device) for t in self.source_cell_lens] + self.source_cell_lens = [t.to(device, non_blocking=True) for t in self.source_cell_lens] if self.target_coords_idx is not None: target_coords_idx_new = {} for k, v in self.target_coords_idx.items(): # iterate over forecast steps - target_coords_idx_new[k] = [vv.to(device) for vv in v] + target_coords_idx_new[k] = [vv.to(device, non_blocking=True) for vv in v] self.target_coords_idx = target_coords_idx_new for key in self.meta_info.keys(): self.meta_info[key].mask = ( - self.meta_info[key].mask.to(device) if self.meta_info[key].mask else None + self.meta_info[key].mask.to(device, non_blocking=True) + if self.meta_info[key].mask + else None ) for key, val in self.streams_data.items(): if val is not None: - self.streams_data[key] = val.to_device(device) + self.streams_data[key] = val.to_device(device, non_blocking=True) def is_empty(self) -> bool: """ From 69b53a6d2a4a533ed99ec89eeffa64e90744c701 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 4 Dec 2025 00:00:42 +0100 Subject: [PATCH 146/344] Removed old comments --- src/weathergen/train/trainer.py | 48 ++------------------------------- 1 file changed, 2 insertions(+), 46 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index ff6679399..d54ebadf2 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -515,40 +515,6 @@ def train(self, mini_epoch): # training loop self.t_start = time.time() for bidx, batch in enumerate(dataset_iter): - # NOTE: we are still returning legacy batch structure and the new batch together. - - # Julian and Matthias: - # here we can access data as follows: - # batch[-1] is the new ModelBatch object, see the structure in batch.py - # batch[-1].source_samples is a list of Sample objects for the source data, timesteps - # batch[-1].target_samples is a list of Sample objects for the target data, timesteps - # batch[-1].meta_info is a dictionary with metadata info per sample - # batch[-1].meta_info["ERA5"] etc. - # here we have the noise_level_rn - # batch[-1].source_samples[0].meta_info["ERA5"].noise_level_rn - # == batch[-1].target_samples[0].meta_info["ERA5"].noise_level_rn - # for the same timestep, this needs to be fixed for when we have more source timesteps, - # and perhaps with bigger batch sizes? - # Each Sample object has: - # .streams_data: a dictionary of StreamData objects per stream name - # .source_cell_lens: list of tensors with lengths of source cells per stream # to be - # changed to be in ModelBatch - # .target_coords_idx: list of tensors with target coordinate indices per stream # to - # be changed to be in ModelBatch - - ###### Legacy batch after batch.to_device: - # (Pdb++) batch[0] - # [[]] - # (Pdb++) batch[1] - # [tensor([0, 1, 1, ..., 0, 0, 0], device='cuda:0', dtype=torch.int32)] - # (Pdb++) batch[2] - # [[tensor([0, 0, 0, ..., 4, 4, 4], device='cuda:0', dtype=torch.int32)]] - - # TODO: access from new ModelBatch - forecast_steps = batch.get_forecast_dt() - # batch = self.batch_to_device(batch) - - ### After to_device, then the original is: batch.to_device(self.device) with torch.autocast( @@ -567,7 +533,7 @@ def train(self, mini_epoch): sample.target_coords_idx, ), cf.forecast_offset, - forecast_steps, + batch.get_forecast_dt(), ) ) @@ -584,7 +550,7 @@ def train(self, mini_epoch): self.model_params, self.model, cf.forecast_offset, - forecast_steps, + batch.get_forecast_dt(), ) ) # targets, aux = zip(*targets_and_auxs) @@ -599,16 +565,6 @@ def train(self, mini_epoch): # ), ) - # OLD - # output = self.model(self.model_params, batch, cf.forecast_offset, forecast_steps) - # target_aux_output = self.target_and_aux_calculator.compute( - # bidx, batch, self.model_params, self.model, cf.forecast_offset, forecast_steps - # ) - # loss, loss_values = self.loss_calculator.compute_loss( - # preds=output, - # targets=target_aux_output, - # ) - # TODO re-enable this, need to think on how to make it compatible with # TODO: CL, this should become a regular loss term # student-teacher training From 59510dda2b80d14cab1de7a95d11bcd79a44581a Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 4 Dec 2025 00:01:50 +0100 Subject: [PATCH 147/344] Fixed problem with non_blocking=True --- src/weathergen/datasets/batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index ffcf600fc..b9852c030 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -74,7 +74,7 @@ def to_device(self, device) -> None: for key, val in self.streams_data.items(): if val is not None: - self.streams_data[key] = val.to_device(device, non_blocking=True) + self.streams_data[key] = val.to_device(device) def is_empty(self) -> bool: """ From b69b743986fc7b49170a3768b17107ac5d962a56 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 4 Dec 2025 13:30:41 +0100 Subject: [PATCH 148/344] Cleaned up comments and return values a bit --- .../datasets/multi_stream_data_sampler.py | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 02eaae797..36d6d4a7b 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -567,11 +567,9 @@ def _get_sample(self, idx: int, forecast_dt: int): # collect source data for current stream # loop over student views - stream_data_source = {} for sidx, (target_mask, source_mask) in enumerate( zip(target_masks, source_masks, strict=False) ): - # stream_data_source[name] = self._build_stream_data( sdata = self._build_stream_data( "target_coords target_values", idx, @@ -585,12 +583,9 @@ def _get_sample(self, idx: int, forecast_dt: int): source_mask, ) - stream_data_source[name] = sdata - # Map each student (source) to its teacher (target) t_idx = student_to_teacher[sidx] batch.add_source_stream(sidx, t_idx, name, sdata, source_metadata_list[sidx]) - # num_input_steps? # stream_data_target can contain network input stream_data_target = {} @@ -785,18 +780,20 @@ def _preprocess_model_batch_sample(self, sample: Sample, forecast_dt: int): sample.set_preprocessed(scl, tci) def _preprocess_model_batch(self, model_batch: ModelBatch, forecast_dt: int): + """ + Perform necessary pre-processing of model batch + """ for sample in model_batch.source_samples: self._preprocess_model_batch_sample(sample, forecast_dt) for sample in model_batch.target_samples: self._preprocess_model_batch_sample(sample, forecast_dt) - def __iter__(self): + def __iter__(self) -> ModelBatch: """ Return one batch of data - Return : list[list[StreamData]] - len : number of batch items - len[*] : number of streams + Return : + batch of data """ iter_start, iter_end = self.worker_workset() logger.info(f"iter_start={iter_start}, iter_end={iter_end}, len={self.len}") @@ -819,14 +816,14 @@ def __iter__(self): idx_raw += 1 # Sample masking strategy once per batch item - if hasattr(self.tokenizer, "masker"): - self.tokenizer.masker.set_batch_strategy() + # TODO: still needed? + self.tokenizer.masker.set_batch_strategy() batch = self._get_sample(idx, forecast_dt) # Reset masking strategy for next batch item - if hasattr(self.tokenizer, "masker"): - self.tokenizer.masker.reset_batch_strategy() + # TODO: still needed? + self.tokenizer.masker.reset_batch_strategy() # # skip completely empty batch item or when all targets are empty -> no grad if not batch.is_empty(): From d36367a93bb5305a27fc7b9b581b5ed91904c144 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 4 Dec 2025 13:31:55 +0100 Subject: [PATCH 149/344] Changed args to embedding --- src/weathergen/model/engines.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 248f61fe3..927f95119 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -79,10 +79,10 @@ def __init__(self, cf: Config, sources_size, stream_names: list[str]) -> None: else: raise ValueError("Unsupported embedding network type") - def forward(self, streams_data, source_cell_lens, pe_embed, dtype, device): - num_step_input = len(source_cell_lens) - - offsets_base = [torch.cumsum(s[1:], 0) for s in source_cell_lens] + # TODO: remove device from arg list + def forward(self, sample, pe_embed, dtype, device): + num_step_input = len(sample.source_cell_lens) + offsets_base = [torch.cumsum(s[1:], 0) for s in sample.source_cell_lens] tokens_all = [ torch.empty((int(ob[-1]), self.cf.ae_local_dim_embed), dtype=dtype, device=device) @@ -92,7 +92,7 @@ def forward(self, streams_data, source_cell_lens, pe_embed, dtype, device): # TODO: handling of input steps should be done using encoder # iterate over all input steps and streams for istep in range(num_step_input): - for stream_name, s_data in streams_data.items(): + for stream_name, s_data in sample.streams_data.items(): # embedding network embed = self.embeds[stream_name] From 3f52a8d858cc2395e0c1355986938ba2c6c96512 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 4 Dec 2025 13:32:14 +0100 Subject: [PATCH 150/344] Changed core functions to take sample as arg --- src/weathergen/model/model.py | 33 +++++++++++++++------------------ 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index e97c99d25..c66c5fa78 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -588,7 +588,7 @@ def rename_old_state_dict(self, params: dict) -> dict: return new_params ######################################### - def forward(self, model_params: ModelParams, batch, forecast_offset: int, forecast_steps: int): + def forward(self, model_params: ModelParams, sample, forecast_offset: int, forecast_steps: int): """Performs the forward pass of the model to generate forecasts Tokens are processed through the model components, which were defined in the create method. @@ -605,13 +605,12 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca Returns: A list containing all prediction results """ - (streams_data, source_cell_lens, target_coords_idxs) = batch # embed - tokens = self.embed_cells(model_params, streams_data, source_cell_lens) + tokens = self.embed_cells(model_params, sample) # local assimilation engine and adapter - tokens, posteriors = self.assimilate_local(model_params, tokens, source_cell_lens) + tokens, posteriors = self.assimilate_local(model_params, tokens, sample) tokens = self.assimilate_global(model_params, tokens) @@ -624,8 +623,7 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca model_params, fstep, tokens, - streams_data, - target_coords_idxs, + sample, ) ] @@ -644,8 +642,7 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca forecast_offset + forecast_steps, tokens, # TODO We add the batch dimension back and thus wrap stream_data in a list - [streams_data], - target_coords_idxs, + sample, ) ] @@ -655,27 +652,23 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca return ModelOutput(physical=preds_all, latent=latents) ######################################### - def embed_cells( - self, model_params: ModelParams, streams_data, source_cell_lens - ) -> torch.Tensor: + def embed_cells(self, model_params: ModelParams, sample) -> torch.Tensor: """Embeds input data for each stream separately and rearranges it to cell-wise order Args: model_params : Query and embedding parameters streams_data : Used to initialize first tokens for pre-processing - Returns: + Returns:uv Tokens for local assimilation """ device = next(self.parameters()).device - tokens_all = self.embed_engine( - streams_data, source_cell_lens, model_params.pe_embed, self.dtype, device - ) + tokens_all = self.embed_engine(sample, model_params.pe_embed, self.dtype, device) return tokens_all ######################################### def assimilate_local( - self, model_params: ModelParams, tokens: torch.Tensor, cell_lens: torch.Tensor + self, model_params: ModelParams, tokens: torch.Tensor, sample: torch.Tensor ) -> torch.Tensor: """Processes embedded tokens locally and prepares them for the global assimilation Args: @@ -687,6 +680,7 @@ def assimilate_local( Tokens for global assimilation """ + cell_lens = sample.source_cell_lens batch_size = ( self.cf.batch_size_per_gpu if self.training else self.cf.batch_size_validation_per_gpu ) @@ -846,8 +840,7 @@ def predict( model_params: ModelParams, fstep: int, tokens: torch.Tensor, - streams_data, - target_coords_idxs, + sample, ) -> list[torch.Tensor]: """Predict outputs at the specific target coordinates based on the input weather state and pre-training task and projects the latent space representation back to physical space. @@ -863,6 +856,10 @@ def predict( Prediction output tokens in physical representation for each target_coords. """ + # add list which represents batch samples + streams_data = [sample.streams_data] + target_coords_idxs = sample.target_coords_idx + batch_size = ( self.cf.batch_size_per_gpu if self.training else self.cf.batch_size_validation_per_gpu ) From 90652194904963d9e04e5e3dfb0f59afd5d38b84 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 4 Dec 2025 13:33:42 +0100 Subject: [PATCH 151/344] Changed that model takes sample as input --- src/weathergen/train/trainer.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index d54ebadf2..2ee71a547 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -527,11 +527,7 @@ def train(self, mini_epoch): outputs.append( self.model( self.model_params, - ( - sample.streams_data, - sample.source_cell_lens, - sample.target_coords_idx, - ), + sample, cf.forecast_offset, batch.get_forecast_dt(), ) @@ -698,11 +694,7 @@ def validate(self, mini_epoch): sample = batch.source_samples[0] output = model_forward( self.model_params, - ( - sample.streams_data, - sample.source_cell_lens, - sample.target_coords_idx, - ), + sample, cf.forecast_offset, forecast_steps, ) From 12bae1517615f48a9663bb6e9ed2d9f995b11b21 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 4 Dec 2025 15:01:07 +0100 Subject: [PATCH 152/344] Fixes for diffusion --- src/weathergen/datasets/masking.py | 73 ++++++++++--------- .../datasets/multi_stream_data_sampler.py | 9 +-- 2 files changed, 38 insertions(+), 44 deletions(-) diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index c555adc3d..9333520ac 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -362,19 +362,17 @@ def build_samples_for_stream( source_num_samples = source_cfg.get("num_samples", 1) source_strategy = source_cfg.get("masking_strategy", source_cfg.get("strategy", "random")) - source_masking_params = source_cfg.get("masking_strategy_config") + source_masking_params = source_cfg.get("masking_strategy_config", {}) relationship = source_cfg.get("relationship", "complement") if target_cfg is not None: target_num_samples = target_cfg.get("num_samples", 1) target_strategy = target_cfg.get("strategy", "random") - target_masking_params = target_cfg.get("masking_strategy_config") + target_masking_params = target_cfg.get("masking_strategy_config", {}) else: target_strategy = source_strategy target_num_samples = source_num_samples target_masking_params = source_masking_params - # # do other relationships make sense - # assert relationship == "complement" assert source_num_samples % target_num_samples == 0, ( "number of source samples has to be multiple of target samples" @@ -386,31 +384,29 @@ def build_samples_for_stream( target_masks: list[np.typing.NDArray] = [] target_metadata: list[SampleMetaData] = [] for _ in range(target_num_samples): - target_masks += [ - self._get_mask( - num_cells=num_cells, - strategy=target_strategy, - target_mask=None, - masking_strategy_config=target_masking_params, - ) - ] - target_metadata += [SampleMetaData(params=target_cfg)] + target_mask, mask_params = self._get_mask( + num_cells=num_cells, + strategy=target_strategy, + target_mask=None, + masking_strategy_config=target_masking_params, + ) + target_masks += [target_mask] + target_metadata += [SampleMetaData(params={**target_cfg, **mask_params})] # iterate over all source samples source_masks: list[np.typing.NDArray] = [] source_metadata: list[SampleMetaData] = [] source_target_mapping = np.zeros(source_num_samples, dtype=np.int32) for it in range(source_num_samples): - source_masks += [ - self._get_mask( - num_cells=num_cells, - strategy=source_strategy, - masking_strategy_config=source_masking_params, - target_mask=target_masks[it % target_num_samples], - relationship=relationship, - ) - ] - source_metadata += [SampleMetaData(params=target_cfg)] + source_mask, mask_params = self._get_mask( + num_cells=num_cells, + strategy=source_strategy, + masking_strategy_config=source_masking_params, + target_mask=target_masks[it % target_num_samples], + relationship=relationship, + ) + source_masks += [source_mask] + source_metadata += [SampleMetaData(params={**target_cfg, **mask_params})] source_target_mapping[it] = it % target_num_samples return ( @@ -427,7 +423,7 @@ def _get_mask( masking_strategy_config: dict | None = None, target_mask: np.typing.NDArray | None = None, relationship: str = "subset", - ) -> np.typing.NDArray: + ) -> (np.typing.NDArray, dict): """Get effective mask, combining with target mask if specified. Parameters @@ -450,16 +446,18 @@ def _get_mask( ------- np.ndarray Boolean array of shape [num_cells] where True indicates the cell is kept. + dict + Parameters describing the masking that was applied """ # handle cases where mask is directly derived from target_mask if target_mask is not None: if relationship == "complement": mask = ~target_mask - return mask + return mask, {} # get mask - mask = self._generate_cell_mask(num_cells, strategy, rate, masking_strategy_config) + mask, params = self._generate_cell_mask(num_cells, strategy, rate, masking_strategy_config) # handle cases where mask needs to be combined with target_mask if target_mask is not None: @@ -468,7 +466,7 @@ def _get_mask( elif relationship == "disjoint": mask = mask & (~target_mask) - return mask + return (mask, params) def _generate_cell_mask( self, @@ -476,7 +474,7 @@ def _generate_cell_mask( strategy: str | None = None, rate: float | None = None, masking_strategy_config: dict | None = None, - ) -> np.typing.NDArray: + ) -> (np.typing.NDArray, dict): """Generate a boolean keep mask at data healpix level (True = keep cell). Parameters @@ -501,6 +499,9 @@ def _generate_cell_mask( Boolean array of shape [num_cells] where True indicates the cell is kept. """ + # params describing the masking + masking_params = {} + # get config for mask strat = strategy or self.masking_strategy @@ -516,19 +517,17 @@ def _generate_cell_mask( "num_cells inconsistent with configured healpix level." ) - if strat not in {"random", "healpix"}: - raise NotImplementedError( - f"Cell selection strategy '{strat}' not supported for keep mask generation." - ) - # generate cell mask if strat == "random": mask = self.rng.uniform(0, 1, num_cells) < keep_rate - elif strat == "forecast" or strat == "causal": + elif "forecast" in strat or strat == "causal": mask = np.ones(num_cells, dtype=np.bool) + if "diffusion" in masking_strategy_config: + masking_params["noise_level_rn"] = self.rng.normal(0.0, 1.0) + elif strat == "healpix": hl_data = self.healpix_level_data hl_mask = cfg.get("hl_mask") @@ -552,8 +551,10 @@ def _generate_cell_mask( mask[child_indices] = True else: - assert False, "Unknown strategy." + raise NotImplementedError( + f"Cell selection strategy '{strat}' not supported for keep mask generation." + ) mask = to_bool_tensor(mask) - return mask + return (mask, masking_params) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 36d6d4a7b..90f7c9e37 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -639,7 +639,7 @@ def _get_sample(self, idx: int, forecast_dt: int): student_cfg.get("num_samples", 1) ) # per teacher - batch = ModelBatch(self.streams, num_source_samples, num_target_samples) + batch = ModelBatch(self.streams, num_source_samples, num_target_samples, forecast_dt) # for all streams for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): @@ -659,8 +659,6 @@ def _get_sample(self, idx: int, forecast_dt: int): # collect source data for current stream # loop over student views - stream_data_source = {} - # stream_data_source[name] = self._build_stream_data( sdata = self._build_stream_data( "target_coords target_values", idx, @@ -673,8 +671,6 @@ def _get_sample(self, idx: int, forecast_dt: int): mask=None, ) - stream_data_source[name] = sdata - source_metadata = source_metadata # add a ramdom number for diffusion timestep @@ -683,9 +679,6 @@ def _get_sample(self, idx: int, forecast_dt: int): # Map each student (source) to its teacher (target) batch.add_source_stream(0, 0, name, sdata, source_metadata) - # stream_data_target can contain network input - stream_data_target = {} - # stream_data_target[name] = self._build_stream_data( sdata = self._build_stream_data( "target_values", From 7745e47377893b073442965f694abce58ed8894b Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 4 Dec 2025 16:35:19 +0100 Subject: [PATCH 153/344] Switched to lists of model / target stratgies --- src/weathergen/datasets/masking.py | 91 ++++++------- .../datasets/multi_stream_data_sampler.py | 122 ++---------------- 2 files changed, 50 insertions(+), 163 deletions(-) diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index 9333520ac..38bd25b69 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -1,5 +1,4 @@ import logging -from dataclasses import dataclass import numpy as np import torch @@ -15,13 +14,6 @@ def to_bool_tensor(arr): return torch.from_numpy(np.asarray(arr)).to(torch.bool) -@dataclass -class MaskingStrategy: - strategy: str - config: dict - num_samples: int - - class Masker: """Class to generate masks for token sequences and apply them. This class supports different masking strategies and combinations. @@ -347,67 +339,56 @@ def _generate_causal_mask( return full_mask def build_samples_for_stream( - self, - training_mode: str, - num_cells: int, - target_cfg: dict, - source_cfg: dict, + self, training_mode: str, num_cells: int, training_cfg: dict ) -> tuple[np.typing.NDArray, list[np.typing.NDArray], list[SampleMetaData]]: """ Construct teacher/student keep masks for a stream. SampleMetaData is currently just a dict with the masking params used. """ - # get source and target configs; target defaults to source config - - source_num_samples = source_cfg.get("num_samples", 1) - source_strategy = source_cfg.get("masking_strategy", source_cfg.get("strategy", "random")) - source_masking_params = source_cfg.get("masking_strategy_config", {}) - relationship = source_cfg.get("relationship", "complement") - - if target_cfg is not None: - target_num_samples = target_cfg.get("num_samples", 1) - target_strategy = target_cfg.get("strategy", "random") - target_masking_params = target_cfg.get("masking_strategy_config", {}) - else: - target_strategy = source_strategy - target_num_samples = source_num_samples - target_masking_params = source_masking_params - - assert source_num_samples % target_num_samples == 0, ( - "number of source samples has to be multiple of target samples" - ) + target_cfgs = training_cfg.get("target_input", []) + source_cfgs = training_cfg.get("model_input", []) - # translate settings into sampling masks + # target and source are assumed identical when target is not specified + if len(target_cfgs) == 0: + target_cfgs = source_cfgs # iterate over all target samples target_masks: list[np.typing.NDArray] = [] target_metadata: list[SampleMetaData] = [] - for _ in range(target_num_samples): - target_mask, mask_params = self._get_mask( - num_cells=num_cells, - strategy=target_strategy, - target_mask=None, - masking_strategy_config=target_masking_params, - ) - target_masks += [target_mask] - target_metadata += [SampleMetaData(params={**target_cfg, **mask_params})] + # different strategies + for target_cfg in target_cfgs: + # different samples/view per strategy + for _ in range(target_cfg.get("num_samples", 1)): + target_mask, mask_params = self._get_mask( + num_cells=num_cells, + strategy=target_cfg.get("strategy"), + target_mask=None, + masking_strategy_config=target_cfg.get("masking_strategy_config", {}), + ) + target_masks += [target_mask] + target_metadata += [SampleMetaData(params={**target_cfg, **mask_params})] # iterate over all source samples source_masks: list[np.typing.NDArray] = [] source_metadata: list[SampleMetaData] = [] - source_target_mapping = np.zeros(source_num_samples, dtype=np.int32) - for it in range(source_num_samples): - source_mask, mask_params = self._get_mask( - num_cells=num_cells, - strategy=source_strategy, - masking_strategy_config=source_masking_params, - target_mask=target_masks[it % target_num_samples], - relationship=relationship, - ) - source_masks += [source_mask] - source_metadata += [SampleMetaData(params={**target_cfg, **mask_params})] - source_target_mapping[it] = it % target_num_samples + source_target_mapping = [] + # different strategies + for i_source, source_cfg in enumerate(source_cfgs): + # samples per strategy + for _ in range(source_cfg.get("num_samples", 1)): + source_mask, mask_params = self._get_mask( + num_cells=num_cells, + strategy=source_cfg.get("strategy"), + masking_strategy_config=source_cfg.get("masking_strategy_config", {}), + target_mask=target_masks[i_source], + relationship=source_cfg.get("relationship", "independent"), + ) + source_masks += [source_mask] + source_metadata += [SampleMetaData(params={**target_cfg, **mask_params})] + source_target_mapping += [i_source] + + source_target_mapping = np.array(source_target_mapping, dtype=np.int32) return ( (target_masks, target_metadata), @@ -525,7 +506,7 @@ def _generate_cell_mask( elif "forecast" in strat or strat == "causal": mask = np.ones(num_cells, dtype=np.bool) - if "diffusion" in masking_strategy_config: + if "diffusion_rn" in masking_strategy_config: masking_params["noise_level_rn"] = self.rng.normal(0.0, 1.0) elif strat == "healpix": diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 90f7c9e37..2bec1f472 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -14,7 +14,7 @@ import torch from weathergen.common.io import IOReaderData -from weathergen.datasets.batch import ModelBatch, Sample, SampleMetaData +from weathergen.datasets.batch import ModelBatch, Sample from weathergen.datasets.data_reader_anemoi import DataReaderAnemoi from weathergen.datasets.data_reader_base import ( DataReaderBase, @@ -528,19 +528,10 @@ def _get_sample(self, idx: int, forecast_dt: int): mode = self.training_cfg.get("training_mode") # get/coordinate masks - masks_streams = self._get_source_target_masks(mode) + # TODO: should also return number of views + masks_streams, num_source_samples, num_target_samples = self._get_source_target_masks(mode) if mode == "masking" or mode == "student_teacher": - # Determine number of views direct from config (teacher & student views) - target_cfg = self.training_cfg.get("target_input", {}) if self.training_cfg else {} - target_cfg = target_cfg if target_cfg is not None else {} - source_cfg = self.training_cfg.get("model_input", {}) if self.training_cfg else {} - # TODO: handle this cleaner (maybe enforce earlier that teacher_cfg is dict) - num_target_samples = int(target_cfg.get("num_samples", 1)) - num_source_samples = int(target_cfg.get("num_samples", 1)) * int( - source_cfg.get("num_samples", 1) - ) # per teacher - batch = ModelBatch(self.streams, num_source_samples, num_target_samples, forecast_dt) # for all streams @@ -624,93 +615,6 @@ def _get_sample(self, idx: int, forecast_dt: int): ] batch.add_target_stream(t_idx, student_indices, name, sdata, target_metadata) - # TODO: build batch - # source_input - # target_input - # source_output - # target_output - - elif mode == "diffusion_forecast": - # Determine number of views direct from config (teacher & student views) - teacher_cfg = self.training_cfg.get("target_input", {}) if self.training_cfg else {} - student_cfg = self.training_cfg.get("model_input", {}) if self.training_cfg else {} - num_target_samples = int(teacher_cfg.get("num_samples", 1)) - num_source_samples = int(teacher_cfg.get("num_samples", 1)) * int( - student_cfg.get("num_samples", 1) - ) # per teacher - - batch = ModelBatch(self.streams, num_source_samples, num_target_samples, forecast_dt) - - # for all streams - for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): - name = stream_info["name"] - - source_metadata = SampleMetaData(masking_params=student_cfg) - target_metadata = SampleMetaData(masking_params=teacher_cfg) - - # input_data and output_data is conceptually consecutive but differs - # in source and target channels; overlap in one window when self.forecast_offset=0 - (input_data, output_data) = self._get_data_windows(idx, forecast_dt, stream_ds) - - # tokenize windows - # *_tokens = [ (cells_idx, cells_idx_lens), ... ] with length = #time_steps - input_tokens = self.tokenizer.get_tokens_windows(stream_info, input_data, True) - output_tokens = self.tokenizer.get_tokens_windows(stream_info, output_data, False) - - # collect source data for current stream - # loop over student views - sdata = self._build_stream_data( - "target_coords target_values", - idx, - forecast_dt, - stream_info, - input_data, - output_data, - input_tokens, - output_tokens, - mask=None, - ) - - source_metadata = source_metadata - - # add a ramdom number for diffusion timestep - source_metadata.noise_level_rn = self.rng.normal(0.0, 1.0) - - # Map each student (source) to its teacher (target) - batch.add_source_stream(0, 0, name, sdata, source_metadata) - - # stream_data_target[name] = self._build_stream_data( - sdata = self._build_stream_data( - "target_values", - idx, - forecast_dt, - stream_info, - input_data, - output_data, - input_tokens, - output_tokens, - mask=None, - ) - stream_data_target[name] = sdata - - # get teacher config info - target_metadata = target_metadata - - # TODO: - # target.mask = - - # TODO: handle this for different number of source timesteps - target_metadata.noise_level_rn = source_metadata.noise_level_rn - - # Map target to all source students - batch.add_target_stream(0, 0, name, sdata, target_metadata) - - # TODO: build batch - # source_input - # target_input - # source_output - # target_output - else: assert False, "Mode not implemented" @@ -723,17 +627,12 @@ def _get_source_target_masks(self, training_mode): masks = {} for stream_info in self.streams: - target_cfg = self.training_cfg.get("target_input", {}) - source_cfg = self.training_cfg.get("model_input", {}) - - # Build one teacher and its student views + # Build source and target sample masks target_data, source_data, mapping = self.tokenizer.masker.build_samples_for_stream( - training_mode, - self.num_healpix_cells, - target_cfg=target_cfg, - source_cfg=source_cfg, + training_mode, self.num_healpix_cells, self.training_cfg ) + # TODO: avoid the unpacking here masks[stream_info["name"]] = ( target_data[0], source_data[0], @@ -742,7 +641,14 @@ def _get_source_target_masks(self, training_mode): source_data[1], ) - return masks + # Determine number of views direct from config (teacher & student views) + source_cfgs = self.training_cfg.get("model_input") + target_cfgs = self.training_cfg.get("target_input", source_cfgs) + target_cfgs = target_cfgs if target_cfgs is not None else source_cfgs + num_target_samples = np.array([sc.get("num_samples", 1) for sc in source_cfgs]).sum().item() + num_source_samples = np.array([sc.get("num_samples", 1) for sc in target_cfgs]).sum().item() + + return masks, num_source_samples, num_target_samples def _preprocess_model_data(self, batch, forecast_dt): """ """ From bf17bfe0deeb1df59ce860d4d4da26dfabc7dfb4 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 4 Dec 2025 16:53:51 +0100 Subject: [PATCH 154/344] Updated config --- config/default_config.yml | 58 ++++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 5c7dc25d8..f53499f0c 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -1,4 +1,5 @@ streams_directory: "./config/streams/era5_1deg/" +# streams_directory: "./config/streams/era5_nppatms_synop/" embed_orientation: "channels" embed_unembed_mode: "block" @@ -47,13 +48,13 @@ pred_mlp_adaln: True # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -forecast_offset : 0 +forecast_offset : 1 forecast_delta_hrs: 0 -forecast_steps: 0 -forecast_policy: null +forecast_steps: 2 +forecast_policy: "fixed" forecast_att_dense_rate: 1.0 -forecast_with_step_conditioning: False -fe_num_blocks: 0 +forecast_with_step_conditioning: True # False +fe_num_blocks: 6 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True @@ -114,11 +115,7 @@ masking_rate_sampling: True # required for "healpix" and "channel" masking strategies # "healpix": requires healpix mask level to be specified with `hl_mask` # "channel": requires "mode" to be specified, "per_cell" or "global", -masking_strategy_config: {"strategies": ["random", "healpix", "channel"], - "probabilities": [0.34, 0.33, 0.33], - "hl_mask": 3, "mode": "per_cell", - "same_strategy_per_batch": false - } + # Student-teacher configuration (only used when training_mode == "student_teacher") # TODO: adapt so that the masking or forecast config entry also sits here @@ -127,25 +124,36 @@ training_config: training_mode: "masking" # "masking", "student_teacher", "forecast" model_input: - masking_strategy: "healpix" # "random", "healpix". Masking strategy to use for model input for masking, and local (student) views when doing student-teacher - rate: 0.4 # Masking rate to use for model input - num_samples: 1 # if student-teacher, the number of local (student) views to generate - hl_mask : 4 # healpix level to use for healpix masking strategy - relationship: "complement" # "independent", "subset", "disjoint". Relationship of student views to teacher view. - - # target_input: - # strategy: "healpix" # Strategy for teacher (global) view: "random", "healpix" - # rate: 0.4 # Fraction of data to keep in global view (alternative: use "keep_m" for absolute count) - # num_samples: 1 # number of teacher views to generate - # hl_mask : 0 # healpix level to use for healpix masking strategy - # # keep_m: 100 # Alternative to rate: keep exactly this many parent cells - # rate_sampling: true # randomly sample the rate per batch + - masking_strategy: "forecast" # "random", "healpix". Masking strategy to use for model input for masking, and local (student) views when doing student-teacher + num_samples: 1 # if student-teacher, the number of local (student) views to generate + masking_strategy_config : { diffusion_rn : True, rate : 0.4 } + # masking_strategy_config: {"strategies": ["random", "healpix", "channel"], + # "probabilities": [0.34, 0.33, 0.33], + # "hl_mask": 3, "mode": "per_cell", + # "same_strategy_per_batch": false + # } + # relationship: "independent" #, "subset", "disjoint". Relationship of student views to teacher view. + relationship: "indepenendent" # "independent", "subset", "disjoint". Relationship of student views to teacher view. + loss : LossPhysical + - masking_strategy: "masking" + num_samples: 1 # if student-teacher, the number of local (student) views to generate + masking_strategy_config : { diffusion_rn : True, rate : 0.4 } + relationship: "complement" # "independent", "subset", "disjoint". Relationship of student views to teacher view. + loss : LossPhysical + # target_input: + # masking_strategy: "cropping" # Strategy for teacher (global) view: "random", "healpix" + # rate: 0.1 # Fraction of data to keep in global view (alternative: use "keep_m" for absolute count) + # num_samples: 2 # number of teacher views to generate + # hl_mask : 0 # healpix level to use for healpix masking strategy + # # keep_m: 100 # Alternative to rate: keep exactly this many parent cells + # rate_sampling: true # randomly sample the rate per batch + losses: num_mini_epochs: 32 -samples_per_mini_epoch: 4096 -samples_per_validation: 512 +samples_per_mini_epoch: 64 #4096 +samples_per_validation: 32 #512 shuffle: True From 89f770ec098d44be55274252eac9722300d89dfd Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 5 Dec 2025 09:22:43 +0100 Subject: [PATCH 155/344] Changed to per masking strategy loss terms --- src/weathergen/train/loss_calculator.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index 6a752bfa6..8de59c25d 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -64,11 +64,20 @@ def __init__( self.stage = stage self.device = device - calculator_configs = ( - cf.training_mode_config.losses if stage == TRAIN else cf.validation_mode_config.losses - ) + training_config = cf.get("training_config") + loss_configs = [(t.num_samples, t.loss) for t in training_config.model_input] + + calculator_configs = [] + for num_samples, lc in loss_configs: + for _ in range(num_samples): + calculator_configs += ( + lc.training if stage == TRAIN else lc.get("validation", lc.training) + ) + calculator_configs = [ - (getattr(LossModules, Cls), config) for (Cls, config) in calculator_configs.items() + (getattr(LossModules, Cls), config) + for t in calculator_configs + for (Cls, config) in t.items() ] self.loss_calculators = [ From a93fdb3bc53dbad9ef700dd0a38d0e033796f0a0 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 5 Dec 2025 09:23:26 +0100 Subject: [PATCH 156/344] Removed old masking options. Still needs to be fully cleaned up --- config/default_config.yml | 81 +++++---- src/weathergen/datasets/masking.py | 171 ++++++++---------- .../datasets/multi_stream_data_sampler.py | 13 +- src/weathergen/datasets/tokenizer_masking.py | 4 +- 4 files changed, 128 insertions(+), 141 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index f53499f0c..23b519e24 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -90,30 +90,30 @@ validate_with_ema: True ema_ramp_up_ratio: 0.09 ema_halflife_in_thousands: 1e-3 -# training mode: "forecast" or "masking" (masked token modeling) -# for "masking" to train with auto-encoder mode, forecast_offset should be 0 -training_mode: "masking" -training_mode_config: {"losses": {LossPhysical: {weight: 0.7, loss_fcts: [['mse', 0.8], ['mae', 0.2]]},} - } -# training_mode_config: {"loss": {LossPhysical: [['mse', 0.7]], -# LossLatent: [['mse', 0.3]], -# LossStudentTeacher: [{'iBOT': {}, 'JEPA': {options}}],} +# # training mode: "forecast" or "masking" (masked token modeling) +# # for "masking" to train with auto-encoder mode, forecast_offset should be 0 +# # training_mode: "masking" +# training_mode_config: {"losses": {LossPhysical: {weight: 0.7, loss_fcts: [['mse', 0.8], ['mae', 0.2]]},} # } -validation_mode_config: {"losses": {LossPhysical: {weight: 1.0, loss_fcts: [['mse', 1.0]]},} - } - -# masking -masking_strategy: "random" # TODO -# masking rate when training mode is "masking"; ignored in foreacast mode -masking_rate: 0.6 -# -sampling_rate_target: 1.0 -# sample the masking rate (with normal distribution centered at masking_rate) -# note that a sampled masking rate leads to varying requirements -masking_rate_sampling: True -# masking_strategy_config is a dictionary of additional parameters for the masking strategy -# required for "healpix" and "channel" masking strategies -# "healpix": requires healpix mask level to be specified with `hl_mask` +# # training_mode_config: {"loss": {LossPhysical: [['mse', 0.7]], +# # LossLatent: [['mse', 0.3]], +# # LossStudentTeacher: [{'iBOT': {}, 'JEPA': {options}}],} +# # } +# validation_mode_config: {"losses": {LossPhysical: {weight: 1.0, loss_fcts: [['mse', 1.0]]},} +# } + +# # masking +# masking_strategy: "random" # TODO +# # masking rate when training mode is "masking"; ignored in foreacast mode +# masking_rate: 0.6 +# # +# sampling_rate_target: 1.0 +# # sample the masking rate (with normal distribution centered at masking_rate) +# # note that a sampled masking rate leads to varying requirements +# masking_rate_sampling: True +# # masking_strategy_config is a dictionary of additional parameters for the masking strategy +# # required for "healpix" and "channel" masking strategies +# # "healpix": requires healpix mask level to be specified with `hl_mask` # "channel": requires "mode" to be specified, "per_cell" or "global", @@ -134,22 +134,33 @@ training_config: # } # relationship: "independent" #, "subset", "disjoint". Relationship of student views to teacher view. relationship: "indepenendent" # "independent", "subset", "disjoint". Relationship of student views to teacher view. - loss : LossPhysical - - masking_strategy: "masking" + loss : + training : + - LossPhysical: {weight: 0.7, loss_fcts: [['mse', 0.8], ['mae', 0.2]]} + validation : + - LossPhysical: {weight: 1.0, loss_fcts: [['mse', 0.8]]} + - masking_strategy: "random" num_samples: 1 # if student-teacher, the number of local (student) views to generate masking_strategy_config : { diffusion_rn : True, rate : 0.4 } relationship: "complement" # "independent", "subset", "disjoint". Relationship of student views to teacher view. - loss : LossPhysical + loss : + training : + - LossPhysical: {weight: 1.0, loss_fcts: [['mse', 1.0]]} - # target_input: - # masking_strategy: "cropping" # Strategy for teacher (global) view: "random", "healpix" - # rate: 0.1 # Fraction of data to keep in global view (alternative: use "keep_m" for absolute count) - # num_samples: 2 # number of teacher views to generate - # hl_mask : 0 # healpix level to use for healpix masking strategy - # # keep_m: 100 # Alternative to rate: keep exactly this many parent cells - # rate_sampling: true # randomly sample the rate per batch - - losses: + target_input: + - masking_strategy: "random" + num_samples: 1 # if student-teacher, the number of local (student) views to generate + masking_strategy_config : { rate : 0.4 } + - masking_strategy: "random" + num_samples: 1 # if student-teacher, the number of local (student) views to generate + masking_strategy_config : { rate : 0.4 } + + # masking_strategy: "cropping" # Strategy for teacher (global) view: "random", "healpix" + # rate: 0.1 # Fraction of data to keep in global view (alternative: use "keep_m" for absolute count) + # num_samples: 2 # number of teacher views to generate + # hl_mask : 0 # healpix level to use for healpix masking strategy + # # keep_m: 100 # Alternative to rate: keep exactly this many parent cells + # rate_sampling: true # randomly sample the rate per batch num_mini_epochs: 32 samples_per_mini_epoch: 64 #4096 diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index 38bd25b69..203ba36c6 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -47,59 +47,59 @@ class Masker: def __init__(self, cf: Config): self.rng = None - self.masking_rate = cf.masking_rate - self.masking_strategy = cf.masking_strategy - self.current_strategy = cf.masking_strategy # Current strategy in use - self.masking_rate_sampling = cf.masking_rate_sampling - # masking_strategy_config is a dictionary that can hold any additional parameters - self.healpix_level_data = cf.healpix_level - self.masking_strategy_config = cf.get("masking_strategy_config", {}) - self.perm_sel = None - self.mask_tokens = None - self.mask_channels = None + # self.masking_rate = cf.masking_rate + # self.masking_strategy = cf.masking_strategy + # self.current_strategy = cf.masking_strategy # Current strategy in use + # self.masking_rate_sampling = cf.masking_rate_sampling + # # masking_strategy_config is a dictionary that can hold any additional parameters + # self.healpix_level_data = cf.healpix_level + # self.masking_strategy_config = cf.get("masking_strategy_config", {}) + # self.perm_sel = None + # self.mask_tokens = None + # self.mask_channels = None self.mask_value = 0.0 self.dim_time_enc = 6 # number of healpix cells - self.healpix_num_cells = 12 * (4**self.healpix_level_data) - - # Per-batch strategy tracking - self.same_strategy_per_batch = self.masking_strategy_config.get( - "same_strategy_per_batch", False - ) - self.batch_strategy_set = False - - # Check for required masking_strategy_config at construction time - if self.current_strategy == "healpix": - hl_data = self.healpix_level_data - hl_mask = self.masking_strategy_config.get("hl_mask") - assert hl_data is not None and hl_mask is not None, ( - "If HEALPix masking, hl_mask must be given in masking_strategy_config." - ) - assert hl_mask < hl_data, "hl_mask must be less than hl_data for HEALPix masking." - - if self.current_strategy == "channel": - # Ensure that masking_strategy_config contains either 'global' or 'per_cell' - assert self.masking_strategy_config.get("mode") in [ - "global", - "per_cell", - ], "masking_strategy_config must contain 'mode' key with value 'global' or 'per_cell'." - - # check all streams that source and target channels are identical - for stream in cf.streams: - # check explicit includes - source_include = stream.get("source_include", []) - target_include = stream.get("target_include", []) - assert set(source_include) == set(target_include), ( - "Source and target channels not identical. Required for masking_mode=channel" - ) - # check excludes - source_exclude = stream.get("source_exclude", []) - target_exclude = stream.get("target_exclude", []) - assert set(source_exclude) == set(target_exclude), ( - "Source and target channels not identical. Required for masking_mode=channel" - ) + self.healpix_num_cells = 12 * (4**cf.healpix_level) + + # # Per-batch strategy tracking + # self.same_strategy_per_batch = self.masking_strategy_config.get( + # "same_strategy_per_batch", False + # ) + # self.batch_strategy_set = False + + # # Check for required masking_strategy_config at construction time + # if self.current_strategy == "healpix": + # hl_data = self.healpix_level_data + # hl_mask = self.masking_strategy_config.get("hl_mask") + # assert hl_data is not None and hl_mask is not None, ( + # "If HEALPix masking, hl_mask must be given in masking_strategy_config." + # ) + # assert hl_mask < hl_data, "hl_mask must be less than hl_data for HEALPix masking." + + # if self.current_strategy == "channel": + # # Ensure that masking_strategy_config contains either 'global' or 'per_cell' + # assert self.masking_strategy_config.get("mode") in [ + # "global", + # "per_cell", + # ], "masking_strategy_config must contain 'mode' key with value 'global' or 'per_cell'." + + # # check all streams that source and target channels are identical + # for stream in cf.streams: + # # check explicit includes + # source_include = stream.get("source_include", []) + # target_include = stream.get("target_include", []) + # assert set(source_include) == set(target_include), ( + # "Source and target channels not identical. Required for masking_mode=channel" + # ) + # # check excludes + # source_exclude = stream.get("source_exclude", []) + # target_exclude = stream.get("target_exclude", []) + # assert set(source_exclude) == set(target_exclude), ( + # "Source and target channels not identical. Required for masking_mode=channel" + # ) def reset_rng(self, rng) -> None: """ @@ -107,25 +107,25 @@ def reset_rng(self, rng) -> None: """ self.rng = rng - def set_batch_strategy(self): - """ - Set strategy for this batch. - Only relevant with combination and same_strategy_per_batch. - """ - if self.masking_strategy == "combination" and self.same_strategy_per_batch: - self.current_strategy = self.rng.choice( - self.masking_strategy_config["strategies"], - p=self.masking_strategy_config["probabilities"], - ) - self.batch_strategy_set = True - - def reset_batch_strategy(self): - """ - Reset for next batch. - """ - if self.masking_strategy == "combination" and self.same_strategy_per_batch: - self.current_strategy = None - self.batch_strategy_set = False + # def set_batch_strategy(self): + # """ + # Set strategy for this batch. + # Only relevant with combination and same_strategy_per_batch. + # """ + # if self.masking_strategy == "combination" and self.same_strategy_per_batch: + # self.current_strategy = self.rng.choice( + # self.masking_strategy_config["strategies"], + # p=self.masking_strategy_config["probabilities"], + # ) + # self.batch_strategy_set = True + + # def reset_batch_strategy(self): + # """ + # Reset for next batch. + # """ + # if self.masking_strategy == "combination" and self.same_strategy_per_batch: + # self.current_strategy = None + # self.batch_strategy_set = False def _select_strategy(self): """ @@ -362,7 +362,7 @@ def build_samples_for_stream( for _ in range(target_cfg.get("num_samples", 1)): target_mask, mask_params = self._get_mask( num_cells=num_cells, - strategy=target_cfg.get("strategy"), + strategy=target_cfg.get("masking_strategy"), target_mask=None, masking_strategy_config=target_cfg.get("masking_strategy_config", {}), ) @@ -379,7 +379,7 @@ def build_samples_for_stream( for _ in range(source_cfg.get("num_samples", 1)): source_mask, mask_params = self._get_mask( num_cells=num_cells, - strategy=source_cfg.get("strategy"), + strategy=source_cfg.get("masking_strategy"), masking_strategy_config=source_cfg.get("masking_strategy_config", {}), target_mask=target_masks[i_source], relationship=source_cfg.get("relationship", "independent"), @@ -400,7 +400,6 @@ def _get_mask( self, num_cells: int, strategy: str | None = None, - rate: float | None = None, masking_strategy_config: dict | None = None, target_mask: np.typing.NDArray | None = None, relationship: str = "subset", @@ -414,9 +413,6 @@ def _get_mask( strategy : str | None Cell selection strategy: currently supports 'random' and 'healpix'. Uses instance default if None. - rate : float | None - Fraction of parent cells (healpix) or data cells (random) to keep. Falls back - to instance masking_rate if None. masking_strategy_config : dict | None Optional override of strategy config (e.g., {'hl_mask': 3}). constraint_keep_mask : np.ndarray | None @@ -438,7 +434,7 @@ def _get_mask( return mask, {} # get mask - mask, params = self._generate_cell_mask(num_cells, strategy, rate, masking_strategy_config) + mask, params = self._generate_cell_mask(num_cells, strategy, masking_strategy_config) # handle cases where mask needs to be combined with target_mask if target_mask is not None: @@ -450,11 +446,7 @@ def _get_mask( return (mask, params) def _generate_cell_mask( - self, - num_cells: int, - strategy: str | None = None, - rate: float | None = None, - masking_strategy_config: dict | None = None, + self, num_cells: int, strategy: str, masking_strategy_config: dict ) -> (np.typing.NDArray, dict): """Generate a boolean keep mask at data healpix level (True = keep cell). @@ -465,9 +457,6 @@ def _generate_cell_mask( strategy : str | None Cell selection strategy: currently supports 'random' and 'healpix'. Uses instance default if None. - rate : float | None - Fraction of parent cells (healpix) or data cells (random) to keep. Falls back - to instance masking_rate if None. masking_strategy_config : dict | None Optional override of strategy config (e.g., {'hl_mask': 3}). constraint_keep_mask : np.ndarray | None @@ -485,13 +474,13 @@ def _generate_cell_mask( # get config for mask - strat = strategy or self.masking_strategy - cfg = masking_strategy_config or self.masking_strategy_config - keep_rate = rate if rate is not None else self.masking_rate + cfg = masking_strategy_config + keep_rate = cfg.get("rate", None) + assert keep_rate is not None, 'No sampling rate "rate" specified.' # sample rate if requested (only if explicit rate not provided) - if rate is None and self.masking_rate_sampling: - keep_rate = self._get_sampling_rate() + # if rate is None and self.masking_rate_sampling: + # keep_rate = self._get_sampling_rate() assert 0.0 <= keep_rate <= 1.0, f"keep_rate out of bounds: {keep_rate}" assert num_cells == self.healpix_num_cells, ( @@ -500,16 +489,16 @@ def _generate_cell_mask( # generate cell mask - if strat == "random": + if strategy == "random": mask = self.rng.uniform(0, 1, num_cells) < keep_rate - elif "forecast" in strat or strat == "causal": + elif "forecast" in strategy or strategy == "causal": mask = np.ones(num_cells, dtype=np.bool) if "diffusion_rn" in masking_strategy_config: masking_params["noise_level_rn"] = self.rng.normal(0.0, 1.0) - elif strat == "healpix": + elif strategy == "healpix": hl_data = self.healpix_level_data hl_mask = cfg.get("hl_mask") assert hl_mask is not None and hl_mask < hl_data, ( @@ -533,7 +522,7 @@ def _generate_cell_mask( else: raise NotImplementedError( - f"Cell selection strategy '{strat}' not supported for keep mask generation." + f"Cell selection strategy '{strategy}' not supported for keep mask generation." ) mask = to_bool_tensor(mask) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 2bec1f472..48f541640 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -199,7 +199,6 @@ def __init__( self.shuffle = shuffle # TODO: remove options that are no longer supported self.input_window_steps = cf.input_window_steps - self.sampling_rate_target = cf.sampling_rate_target self.batch_size = batch_size @@ -387,7 +386,6 @@ def _build_stream_data_output( if "target_coords" in mode: (tc, tc_l) = self.tokenizer.get_target_coords( stream_info, - self.sampling_rate_target, rdata, token_data, (time_win_target.start, time_win_target.end), @@ -398,7 +396,6 @@ def _build_stream_data_output( if "target_values" in mode: (tt_cells, tt_t, tt_c, idxs_inv) = self.tokenizer.get_target_values( stream_info, - self.sampling_rate_target, rdata, token_data, (time_win_target.start, time_win_target.end), @@ -714,17 +711,9 @@ def __iter__(self) -> ModelBatch: idx: TIndex = self.perms[idx_raw % self.perms.shape[0]] idx_raw += 1 - # Sample masking strategy once per batch item - # TODO: still needed? - self.tokenizer.masker.set_batch_strategy() - batch = self._get_sample(idx, forecast_dt) - # Reset masking strategy for next batch item - # TODO: still needed? - self.tokenizer.masker.reset_batch_strategy() - - # # skip completely empty batch item or when all targets are empty -> no grad + # skip completely empty batch item or when all targets are empty -> no grad if not batch.is_empty(): break else: diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index d42ec6324..bed7c16dd 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -147,7 +147,7 @@ def get_source( # capture per-view mask state to later produce consistent targets mask_state = { - "strategy": self.masker.current_strategy, + "strategy": None, # self.masker.current_strategy, "mask_tokens": mask_tokens, "mask_channels": mask_channels, } @@ -203,7 +203,6 @@ def get_target( def get_target_coords( self, stream_info: dict, - sampling_rate_target: float, rdata: IOReaderData, token_data, time_win: tuple, @@ -263,7 +262,6 @@ def get_target_coords( def get_target_values( self, stream_info: dict, - sampling_rate_target: float, rdata: IOReaderData, token_data, time_win: tuple, From 454dffb9c2eb0ece49a47c1e8019eac099948f7a Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 5 Dec 2025 17:39:21 +0100 Subject: [PATCH 157/344] More robust handling of empty streams --- src/weathergen/datasets/batch.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index b9852c030..71990f53f 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -80,7 +80,9 @@ def is_empty(self) -> bool: """ Check if sample is empty """ - return np.all(np.array([s.empty() for _, s in self.streams_data.items()])) + return np.all( + np.array([s.empty() if s is not None else True for _, s in self.streams_data.items()]) + ) def add_stream_data(self, stream_name: str, stream_data: StreamData) -> None: """ @@ -200,8 +202,12 @@ def is_empty(self): """ Check if batch is empty """ - source_empty = np.all(np.array([s.is_empty() for s in self.source_samples])) - target_empty = np.all(np.array([s.is_empty() for s in self.target_samples])) + source_empty = np.all( + np.array([s.is_empty() if s is not None else True for s in self.source_samples]) + ) + target_empty = np.all( + np.array([s.is_empty() if s is not None else True for s in self.target_samples]) + ) return source_empty or target_empty def set_forecast_dt(self, forecast_dt: int) -> None: From 5cbbaa36cf18164060b5f0a925ef94b9c110230d Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 5 Dec 2025 17:40:24 +0100 Subject: [PATCH 158/344] Fixed incorrect handling of empty target_coords_idx --- src/weathergen/datasets/multi_stream_data_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 48f541640..be6ce47a6 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -670,7 +670,7 @@ def _preprocess_model_batch_sample(self, sample: Sample, forecast_dt: int): """ """ streams = [sd for sd in sample.streams_data.values() if sd is not None] if not streams: - sample.set_preprocessed([], []) + sample.set_preprocessed([], {}) return _, scl, tci = self._preprocess_model_data([streams], forecast_dt) sample.set_preprocessed(scl, tci) From 9c7474171cb4d93cda38a5e6baac7ce5dd4afcbf Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 5 Dec 2025 17:40:52 +0100 Subject: [PATCH 159/344] Fixed problem when number of model and target samples is different --- src/weathergen/datasets/masking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index 203ba36c6..1e6d96320 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -381,7 +381,7 @@ def build_samples_for_stream( num_cells=num_cells, strategy=source_cfg.get("masking_strategy"), masking_strategy_config=source_cfg.get("masking_strategy_config", {}), - target_mask=target_masks[i_source], + target_mask=target_masks[i_source % len(target_masks)], relationship=source_cfg.get("relationship", "independent"), ) source_masks += [source_mask] From 085b55fd1bb7d4b32858317fe522f57c64a46e45 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 5 Dec 2025 17:41:20 +0100 Subject: [PATCH 160/344] Example for config with non-trivial model and target inputs --- config/default_config.yml | 63 +++++++++++++++++++++++++++++---------- 1 file changed, 47 insertions(+), 16 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 23b519e24..65262fe01 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -124,7 +124,7 @@ training_config: training_mode: "masking" # "masking", "student_teacher", "forecast" model_input: - - masking_strategy: "forecast" # "random", "healpix". Masking strategy to use for model input for masking, and local (student) views when doing student-teacher + - masking_strategy: "random" # "random", "healpix". Masking strategy to use for model input for masking, and local (student) views when doing student-teacher num_samples: 1 # if student-teacher, the number of local (student) views to generate masking_strategy_config : { diffusion_rn : True, rate : 0.4 } # masking_strategy_config: {"strategies": ["random", "healpix", "channel"], @@ -133,34 +133,65 @@ training_config: # "same_strategy_per_batch": false # } # relationship: "independent" #, "subset", "disjoint". Relationship of student views to teacher view. - relationship: "indepenendent" # "independent", "subset", "disjoint". Relationship of student views to teacher view. + relationship: "subset" # "independent", "subset", "disjoint". Relationship of student views to teacher view. + # loss : ibot loss : training : - LossPhysical: {weight: 0.7, loss_fcts: [['mse', 0.8], ['mae', 0.2]]} validation : - LossPhysical: {weight: 1.0, loss_fcts: [['mse', 0.8]]} - masking_strategy: "random" - num_samples: 1 # if student-teacher, the number of local (student) views to generate + num_samples: 2 # if student-teacher, the number of local (student) views to generate masking_strategy_config : { diffusion_rn : True, rate : 0.4 } - relationship: "complement" # "independent", "subset", "disjoint". Relationship of student views to teacher view. + relationship: "subset" # "independent", "subset", "disjoint". Relationship of student views to teacher view. + # loss : dino loss : training : - LossPhysical: {weight: 1.0, loss_fcts: [['mse', 1.0]]} target_input: - - masking_strategy: "random" - num_samples: 1 # if student-teacher, the number of local (student) views to generate - masking_strategy_config : { rate : 0.4 } - - masking_strategy: "random" - num_samples: 1 # if student-teacher, the number of local (student) views to generate - masking_strategy_config : { rate : 0.4 } + - masking_strategy: "random" # Strategy for teacher (global) view: "random", "healpix" + num_samples: 1 # number of teacher views to generate + masking_strategy_config : { diffusion_rn : True, rate : 0.4, rate_sampling: true } + + # model_input: + # - masking_strategy: "forecast" # "random", "healpix". Masking strategy to use for model input for masking, and local (student) views when doing student-teacher + # num_samples: 1 # if student-teacher, the number of local (student) views to generate + # masking_strategy_config : { diffusion_rn : True, rate : 0.4 } + # # masking_strategy_config: {"strategies": ["random", "healpix", "channel"], + # # "probabilities": [0.34, 0.33, 0.33], + # # "hl_mask": 3, "mode": "per_cell", + # # "same_strategy_per_batch": false + # # } + # # relationship: "independent" #, "subset", "disjoint". Relationship of student views to teacher view. + # relationship: "indepenendent" # "independent", "subset", "disjoint". Relationship of student views to teacher view. + # loss : + # training : + # - LossPhysical: {weight: 0.7, loss_fcts: [['mse', 0.8], ['mae', 0.2]]} + # validation : + # - LossPhysical: {weight: 1.0, loss_fcts: [['mse', 0.8]]} + # - masking_strategy: "random" + # num_samples: 1 # if student-teacher, the number of local (student) views to generate + # masking_strategy_config : { diffusion_rn : True, rate : 0.4 } + # relationship: "complement" # "independent", "subset", "disjoint". Relationship of student views to teacher view. + # loss : + # training : + # - LossPhysical: {weight: 1.0, loss_fcts: [['mse', 1.0]]} + + # target_input: + # - masking_strategy: "random" + # num_samples: 1 # if student-teacher, the number of local (student) views to generate + # masking_strategy_config : { rate : 0.4 } + # - masking_strategy: "random" + # num_samples: 1 # if student-teacher, the number of local (student) views to generate + # masking_strategy_config : { rate : 0.4 } - # masking_strategy: "cropping" # Strategy for teacher (global) view: "random", "healpix" - # rate: 0.1 # Fraction of data to keep in global view (alternative: use "keep_m" for absolute count) - # num_samples: 2 # number of teacher views to generate - # hl_mask : 0 # healpix level to use for healpix masking strategy - # # keep_m: 100 # Alternative to rate: keep exactly this many parent cells - # rate_sampling: true # randomly sample the rate per batch + # # masking_strategy: "cropping" # Strategy for teacher (global) view: "random", "healpix" + # # rate: 0.1 # Fraction of data to keep in global view (alternative: use "keep_m" for absolute count) + # # num_samples: 2 # number of teacher views to generate + # # hl_mask : 0 # healpix level to use for healpix masking strategy + # # # keep_m: 100 # Alternative to rate: keep exactly this many parent cells + # # rate_sampling: true # randomly sample the rate per batch num_mini_epochs: 32 samples_per_mini_epoch: 64 #4096 From 4dac76d429cc91c806c7ba7b8a856a61e6065e1a Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 5 Dec 2025 20:16:59 +0100 Subject: [PATCH 161/344] Fixed bug in total sample counting --- src/weathergen/datasets/multi_stream_data_sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index be6ce47a6..319544e53 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -642,8 +642,8 @@ def _get_source_target_masks(self, training_mode): source_cfgs = self.training_cfg.get("model_input") target_cfgs = self.training_cfg.get("target_input", source_cfgs) target_cfgs = target_cfgs if target_cfgs is not None else source_cfgs - num_target_samples = np.array([sc.get("num_samples", 1) for sc in source_cfgs]).sum().item() - num_source_samples = np.array([sc.get("num_samples", 1) for sc in target_cfgs]).sum().item() + num_source_samples = np.array([sc.get("num_samples", 1) for sc in source_cfgs]).sum().item() + num_target_samples = np.array([sc.get("num_samples", 1) for sc in target_cfgs]).sum().item() return masks, num_source_samples, num_target_samples From fe2f63a8fd4010751e8362df07ce8842b463be14 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 5 Dec 2025 20:17:16 +0100 Subject: [PATCH 162/344] Re-enabled missing healpix level --- src/weathergen/datasets/masking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index 1e6d96320..fe8687cf7 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -52,7 +52,6 @@ def __init__(self, cf: Config): # self.current_strategy = cf.masking_strategy # Current strategy in use # self.masking_rate_sampling = cf.masking_rate_sampling # # masking_strategy_config is a dictionary that can hold any additional parameters - # self.healpix_level_data = cf.healpix_level # self.masking_strategy_config = cf.get("masking_strategy_config", {}) # self.perm_sel = None # self.mask_tokens = None @@ -62,6 +61,7 @@ def __init__(self, cf: Config): self.dim_time_enc = 6 # number of healpix cells + self.healpix_level_data = cf.healpix_level self.healpix_num_cells = 12 * (4**cf.healpix_level) # # Per-batch strategy tracking From b9195bbbcb391b631b513b4dc5046fc980dd19f7 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Sat, 6 Dec 2025 16:49:40 +0100 Subject: [PATCH 163/344] Fixed incorrect handling of masking and student_teacher modes. Follow up fixes required to handle partially filler source/target streams (because source has no target values, eg). --- src/weathergen/datasets/batch.py | 2 +- .../datasets/multi_stream_data_sampler.py | 225 +++++++++--------- src/weathergen/datasets/stream_data.py | 4 +- src/weathergen/datasets/utils.py | 7 +- 4 files changed, 121 insertions(+), 117 deletions(-) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index 71990f53f..661dd255f 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -68,7 +68,7 @@ def to_device(self, device) -> None: for key in self.meta_info.keys(): self.meta_info[key].mask = ( self.meta_info[key].mask.to(device, non_blocking=True) - if self.meta_info[key].mask + if self.meta_info[key].mask is not None else None ) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 319544e53..dee65a82c 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -328,30 +328,31 @@ def _build_stream_data_input( StreamData with source and targets masked according to view_meta """ - # iterate overall input steps - for step, idx in enumerate(range(base_idx, base_idx - self.num_input_steps, -1)): - # TODO: check that we are not out of bounds when we go back in time + if "network_input" in mode: + # iterate overall input steps + for step, idx in enumerate(range(base_idx, base_idx - self.num_input_steps, -1)): + # TODO: check that we are not out of bounds when we go back in time - time_win_source = self.time_window_handler.window(idx) + time_win_source = self.time_window_handler.window(idx) - # collect all targets for current stream - # do we want this to be ascending or descending in time? - rdata = input_data[-(step + 1)] - token_data = input_tokens[-(step + 1)] - - stream_data.source_is_spoof = rdata.is_spoof - - # preprocess data for model input - (source_cells, source_cells_lens, mask_state) = self.tokenizer.get_source( - stream_info, - rdata, - token_data, - (time_win_source.start, time_win_source.end), - mask, - ) + # collect all targets for current stream + # do we want this to be ascending or descending in time? + rdata = input_data[-(step + 1)] + token_data = input_tokens[-(step + 1)] + + stream_data.source_is_spoof = rdata.is_spoof + + # preprocess data for model input + (source_cells, source_cells_lens, mask_state) = self.tokenizer.get_source( + stream_info, + rdata, + token_data, + (time_win_source.start, time_win_source.end), + mask, + ) - # collect data for stream - stream_data.add_source(step, rdata, source_cells_lens, source_cells) + # collect data for stream + stream_data.add_source(step, rdata, source_cells_lens, source_cells) return stream_data @@ -407,7 +408,7 @@ def _build_stream_data_output( def _build_stream_data( self, - mode: str, + modes: str, base_idx: TIndex, forecast_dt: int, stream_info: dict, @@ -415,21 +416,25 @@ def _build_stream_data( output_data: list, input_tokens: list, output_tokens: list, - target_mask, - source_mask, + output_mask, + input_mask, ) -> StreamData: """ Return one batch of data Build a StreamData object for a single view (teacher or student). Args: - mode : + modes : stream_data : base_idx: Time index for this sample forecast_dt: Number of forecast steps stream_info: Stream configuration dict stream_ds: List of dataset readers for this stream + output_mask : mask for output/prediction/target + input_mask : mask for network input (can be source or target) + + Returns: StreamData with source and targets masked according to view_meta """ @@ -438,24 +443,24 @@ def _build_stream_data( stream_data = StreamData(base_idx, dt, self.num_healpix_cells) stream_data = self._build_stream_data_input( - mode, + modes, stream_data, base_idx, stream_info, input_data, input_tokens, - source_mask, + input_mask, ) stream_data = self._build_stream_data_output( - mode, + modes, stream_data, base_idx, stream_info, forecast_dt, output_data, output_tokens, - target_mask, + output_mask, ) return stream_data @@ -511,7 +516,7 @@ def _get_data_windows(self, base_idx, forecast_dt, stream_ds): return (input_data, output_data) - def _get_sample(self, idx: int, forecast_dt: int): + def _get_batch(self, idx: int, forecast_dt: int): """ modes : @@ -528,92 +533,86 @@ def _get_sample(self, idx: int, forecast_dt: int): # TODO: should also return number of views masks_streams, num_source_samples, num_target_samples = self._get_source_target_masks(mode) - if mode == "masking" or mode == "student_teacher": - batch = ModelBatch(self.streams, num_source_samples, num_target_samples, forecast_dt) - - # for all streams - for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): - name = stream_info["name"] - - # TODO: data class for this or something similar - ( - target_masks, - source_masks, - student_to_teacher, - target_metadata_list, - source_metadata_list, - ) = masks_streams[name] - - # input_data and output_data is conceptually consecutive but differs - # in source and target channels; overlap in one window when self.forecast_offset=0 - (input_data, output_data) = self._get_data_windows(idx, forecast_dt, stream_ds) - - # tokenize windows - # *_tokens = [ (cells_idx, cells_idx_lens), ... ] with length = #time_steps - input_tokens = self.tokenizer.get_tokens_windows(stream_info, input_data, True) - output_tokens = self.tokenizer.get_tokens_windows(stream_info, output_data, False) - - # collect source data for current stream - # loop over student views - for sidx, (target_mask, source_mask) in enumerate( - zip(target_masks, source_masks, strict=False) - ): - sdata = self._build_stream_data( - "target_coords target_values", - idx, - forecast_dt, - stream_info, - input_data, - output_data, - input_tokens, - output_tokens, - target_mask, - source_mask, - ) + if mode == "masking": + source_select = ["network_input", "target_coords"] + target_select = ["target_values"] + elif mode == "student_teacher": + source_select = ["network_input"] + target_select = ["network_input"] + else: + raise NotImplementedError(f"Unsupported training mode {mode}.") + + batch = ModelBatch(self.streams, num_source_samples, num_target_samples, forecast_dt) + + # for all streams + for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): + stream_name = stream_info["name"] + + # TODO: data class for this or something similar + ( + target_masks, + source_masks, + student_to_teacher, + target_metadata_list, + source_metadata_list, + ) = masks_streams[stream_name] + + # input_data and output_data is conceptually consecutive but differs + # in source and target channels; overlap in one window when self.forecast_offset=0 + (input_data, output_data) = self._get_data_windows(idx, forecast_dt, stream_ds) + + # tokenize windows + # *_tokens = [ (cells_idx, cells_idx_lens), ... ] with length = #time_steps + input_tokens = self.tokenizer.get_tokens_windows(stream_info, input_data, True) + output_tokens = self.tokenizer.get_tokens_windows(stream_info, output_data, False) + + # collect source data for current stream + # loop over student views + for sidx, source_mask in enumerate(source_masks): + sdata = self._build_stream_data( + source_select, + idx, + forecast_dt, + stream_info, + input_data, + output_data, + input_tokens, + output_tokens, + target_masks[student_to_teacher[sidx]], + source_mask, + ) - # Map each student (source) to its teacher (target) - t_idx = student_to_teacher[sidx] - batch.add_source_stream(sidx, t_idx, name, sdata, source_metadata_list[sidx]) - - # stream_data_target can contain network input - stream_data_target = {} - - # for t_idx, mask in enumerate(source_masks): - for sidx, (target_mask, source_mask) in enumerate( - zip(target_masks, source_masks, strict=False) - ): - # stream_data_target[name] = self._build_stream_data( - sdata = self._build_stream_data( - "target_values", - idx, - forecast_dt, - stream_info, - input_data, - output_data, - input_tokens, - output_tokens, - target_mask, - source_mask, - ) - stream_data_target[name] = sdata + # also want to add the mask to the metadata + source_metadata = source_metadata_list[sidx] + source_metadata.mask = source_mask - # get teacher config info - # TODO, TODO, TODO: is this correct? - t_idx = sidx - target_metadata = target_metadata_list[t_idx] + # map each source to its target + t_idx = student_to_teacher[sidx] + batch.add_source_stream(sidx, t_idx, stream_name, sdata, source_metadata) - # also want to add the mask to the metadata - target_metadata.mask = None # target_mask + for sidx, target_mask in enumerate(target_masks): + sdata = self._build_stream_data( + target_select, + idx, + forecast_dt, + stream_info, + input_data, + output_data, + input_tokens, + output_tokens, + target_mask, + target_mask, + ) - # TODO: seb to check - # Map target to all source students - student_indices = [ - s_idx for s_idx, tid in enumerate(student_to_teacher) if tid == sidx - ] - batch.add_target_stream(t_idx, student_indices, name, sdata, target_metadata) + # get target config info + target_metadata = target_metadata_list[sidx] + target_metadata.mask = target_mask - else: - assert False, "Mode not implemented" + # find indices of all sources for current target + student_indices = [ + s_idx for s_idx, tid in enumerate(student_to_teacher) if tid == sidx + ] + batch.add_target_stream(sidx, student_indices, stream_name, sdata, target_metadata) return batch @@ -638,7 +637,7 @@ def _get_source_target_masks(self, training_mode): source_data[1], ) - # Determine number of views direct from config (teacher & student views) + # Determine number of samples directly from config (teacher and student views) source_cfgs = self.training_cfg.get("model_input") target_cfgs = self.training_cfg.get("target_input", source_cfgs) target_cfgs = target_cfgs if target_cfgs is not None else source_cfgs @@ -711,7 +710,7 @@ def __iter__(self) -> ModelBatch: idx: TIndex = self.perms[idx_raw % self.perms.shape[0]] idx_raw += 1 - batch = self._get_sample(idx, forecast_dt) + batch = self._get_batch(idx, forecast_dt) # skip completely empty batch item or when all targets are empty -> no grad if not batch.is_empty(): diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index 19cf94b18..c8727fe9b 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -294,7 +294,9 @@ def target_empty(self) -> bool: """ # cat over forecast steps - return torch.cat(self.target_coords_lens).sum() == 0 + target_coords_empty = torch.cat(self.target_coords_lens).sum() == 0 + target_tokens_empty = torch.cat(self.target_tokens).sum() == 0 + return target_coords_empty and target_tokens_empty def source_empty(self) -> bool: """ diff --git a/src/weathergen/datasets/utils.py b/src/weathergen/datasets/utils.py index 726936bb5..4e5abaddb 100644 --- a/src/weathergen/datasets/utils.py +++ b/src/weathergen/datasets/utils.py @@ -289,7 +289,7 @@ def compute_offsets_scatter_embed(batch: StreamData, num_input_steps: int) -> St torch.stack( [ s.source_tokens_lens[i] - if len(s.source_tokens_lens[i]) > 0 + if (len(s.source_tokens_lens) > 0) and (len(s.source_tokens_lens[i]) > 0) else torch.tensor([]) for s in stl_b ] @@ -305,6 +305,9 @@ def compute_offsets_scatter_embed(batch: StreamData, num_input_steps: int) -> St offsets = [torch.cat([torch.zeros(1, dtype=torch.int32), o[:-1]]) for o in offsets_base] offsets_pe = [torch.zeros_like(o) for o in offsets] + if torch.cat(offsets_base).shape[0] == 0: + return batch + for i_s in range(num_input_steps): for ib, sb in enumerate(batch): # batch items for itype, s in enumerate(sb): # streams, i.e. here we have StreamData object @@ -400,7 +403,7 @@ def compute_source_cell_lens( torch.stack( [ s.source_tokens_lens[i] - if len(s.source_tokens_lens[i]) > 0 + if (len(s.source_tokens_lens) > 0) and (len(s.source_tokens_lens[i]) > 0) else torch.tensor([]) for s in stl_b ] From 43f9b01d9b6eb435e6faa4a86d7b1cc870a51b1a Mon Sep 17 00:00:00 2001 From: kctezcan Date: Sat, 6 Dec 2025 17:20:05 +0100 Subject: [PATCH 164/344] An encoder formed by embedding + local assimilation + global assimilation (#1397) * initial changes * more changes * removed extra print parameters statement * changed names for backward checkpoint loading * added encoder. to module names in sharding * adding encoder. to embed_engine * added back the conditions for param printong * lint * forecast config * switch back to MTM config * lint --- src/weathergen/datasets/utils_test.py | 3 +- src/weathergen/model/encoder.py | 290 ++++++++++++++++++++++++ src/weathergen/model/model.py | 279 +++-------------------- src/weathergen/model/model_interface.py | 8 +- 4 files changed, 322 insertions(+), 258 deletions(-) create mode 100644 src/weathergen/model/encoder.py diff --git a/src/weathergen/datasets/utils_test.py b/src/weathergen/datasets/utils_test.py index 2e614937e..9794b6b9c 100644 --- a/src/weathergen/datasets/utils_test.py +++ b/src/weathergen/datasets/utils_test.py @@ -1,10 +1,9 @@ import torch -from torch import Tensor, tensor +from torch import tensor from weathergen.datasets.utils import ( locs_to_cell_coords_ctrs, locs_to_ctr_coords, - s2tor3, vecs_to_rots, ) diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py new file mode 100644 index 000000000..aad1e7e3a --- /dev/null +++ b/src/weathergen/model/encoder.py @@ -0,0 +1,290 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import torch +from astropy_healpix import healpy + +from weathergen.common.config import Config +from weathergen.model.engines import ( + EmbeddingEngine, + GlobalAssimilationEngine, + Local2GlobalAssimilationEngine, + LocalAssimilationEngine, + QueryAggregationEngine, +) + +# from weathergen.model.model import ModelParams +from weathergen.model.parametrised_prob_dist import LatentInterpolator +from weathergen.utils.utils import get_dtype + + +class EncoderModule(torch.nn.Module): + name: "EncoderModule" + + def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coords_size) -> None: + """ + Initialize the EmbeddingEngine with the configuration. + + :param cf: Configuration object containing parameters for the engine. + :param sources_size: List of source sizes for each stream. + :param stream_names: Ordered list of stream identifiers aligned with cf.streams. + """ + super(EncoderModule, self).__init__() + self.cf = cf + + self.healpix_level = cf.healpix_level + self.num_healpix_cells = 12 * 4**self.healpix_level + + self.cf = cf + self.dtype = get_dtype(self.cf.attention_dtype) + self.sources_size = sources_size + self.targets_num_channels = targets_num_channels + self.targets_coords_size = targets_coords_size + + self.ae_aggregation_engine: QueryAggregationEngine | None = None + self.ae_global_engine: GlobalAssimilationEngine | None = None + self.ae_local_engine: LocalAssimilationEngine | None = None + self.ae_local_global_engine: Local2GlobalAssimilationEngine | None = None + self.embed_engine: EmbeddingEngine | None = None + self.interpolate_latents: LatentInterpolator | None = None + + ############## + # embedding engine + # determine stream names once so downstream components use consistent keys + self.stream_names = [str(stream_cfg["name"]) for stream_cfg in cf.streams] + # separate embedding networks for differnt observation types + self.embed_engine = EmbeddingEngine(cf, self.sources_size, self.stream_names) + + ############## + # local assimilation engine + self.ae_local_engine = LocalAssimilationEngine(cf) + + if cf.latent_noise_kl_weight > 0.0: + self.interpolate_latents = LatentInterpolator( + gamma=cf.latent_noise_gamma, + dim=cf.ae_local_dim_embed, + use_additive_noise=cf.latent_noise_use_additive_noise, + deterministic=cf.latent_noise_deterministic_latents, + ) + + ############## + # local -> global assimilation engine adapter + self.ae_local_global_engine = Local2GlobalAssimilationEngine(cf) + + ############## + # learnable queries + if cf.ae_local_queries_per_cell: + s = (self.num_healpix_cells, cf.ae_local_num_queries, cf.ae_global_dim_embed) + q_cells = torch.rand(s, requires_grad=True) / cf.ae_global_dim_embed + # add meta data + q_cells[:, :, -8:-6] = ( + (torch.arange(self.num_healpix_cells) / self.num_healpix_cells) + .unsqueeze(1) + .unsqueeze(1) + .repeat((1, cf.ae_local_num_queries, 2)) + ) + theta, phi = healpy.pix2ang( + nside=2**self.healpix_level, ipix=torch.arange(self.num_healpix_cells) + ) + q_cells[:, :, -6:-3] = ( + torch.cos(theta).unsqueeze(1).unsqueeze(1).repeat((1, cf.ae_local_num_queries, 3)) + ) + q_cells[:, :, -3:] = ( + torch.sin(phi).unsqueeze(1).unsqueeze(1).repeat((1, cf.ae_local_num_queries, 3)) + ) + q_cells[:, :, -9] = torch.arange(cf.ae_local_num_queries) + q_cells[:, :, -10] = torch.arange(cf.ae_local_num_queries) + else: + s = (1, cf.ae_local_num_queries, cf.ae_global_dim_embed) + q_cells = torch.rand(s, requires_grad=True) / cf.ae_global_dim_embed + self.q_cells = torch.nn.Parameter(q_cells, requires_grad=True) + + ############## + # query aggregation engine + self.ae_aggregation_engine = QueryAggregationEngine(cf, self.num_healpix_cells) + + ############## + # global assimilation engine + self.ae_global_engine = GlobalAssimilationEngine(cf, self.num_healpix_cells) + + def forward(self, model_params, sample): + # embed + tokens = self.embed_cells(model_params, sample) + + # local assimilation engine and adapter + tokens, posteriors = self.assimilate_local(model_params, tokens, sample) + + tokens = self.assimilate_global(tokens) + + return tokens, posteriors + + ######################################### + def embed_cells(self, model_params, sample) -> torch.Tensor: + """Embeds input data for each stream separately and rearranges it to cell-wise order + Args: + model_params : Query and embedding parameters + streams_data : Used to initialize first tokens for pre-processing + Returns:uv + Tokens for local assimilation + """ + + device = next(self.parameters()).device + tokens_all = self.embed_engine(sample, model_params.pe_embed, self.dtype, device) + + return tokens_all + + ######################################### + def assimilate_local( + self, model_params, tokens: torch.Tensor, sample: torch.Tensor + ) -> torch.Tensor: + """Processes embedded tokens locally and prepares them for the global assimilation + Args: + model_params : Query and embedding parameters + tokens : Input tokens to be processed by local assimilation + cell_lens : Used to identify range of tokens to use from generated tokens in cell + embedding + Returns: + Tokens for global assimilation + """ + + cell_lens = sample.source_cell_lens + batch_size = ( + self.cf.batch_size_per_gpu if self.training else self.cf.batch_size_validation_per_gpu + ) + + s = self.q_cells.shape + # print( f'{np.prod(np.array(tokens.shape))} :: {np.prod(np.array(s))}' + # + ':: {np.prod(np.array(tokens.shape))/np.prod(np.array(s))}') + # TODO: test if positional encoding is needed here + if self.cf.ae_local_queries_per_cell: + tokens_global = (self.q_cells + model_params.pe_global).repeat(batch_size, 1, 1) + else: + tokens_global = ( + self.q_cells.repeat(self.num_healpix_cells, 1, 1) + model_params.pe_global + ) + q_cells_lens = torch.cat( + [model_params.q_cells_lens[0].unsqueeze(0)] + + [model_params.q_cells_lens[1:] for _ in range(batch_size)] + ) + + # local assimilation model + # for block in self.ae_local_blocks: + # tokens = checkpoint(block, tokens, cell_lens, use_reentrant=False) + + # if self.cf.latent_noise_kl_weight > 0.0: + # tokens, posteriors = self.interpolate_latents.interpolate_with_noise( + # tokens, sampling=self.training + # ) + # else: + # tokens, posteriors = tokens, 0.0 + + # for block in self.ae_adapter: + # tokens_global = checkpoint( + # block, + # tokens_global, + # tokens, + # q_cells_lens, + # cell_lens, + # use_reentrant=False, + # ) + + # work around to bug in flash attention for hl>=5 + + istep = 0 + + cell_lens = cell_lens[istep][1:] + clen = self.num_healpix_cells // (2 if self.cf.healpix_level <= 5 else 8) + tokens_global_unmasked_all = [] + posteriors = [] + zero_pad = torch.zeros(1, device=tokens.device, dtype=torch.int32) + for i in range((cell_lens.shape[0]) // clen): + # make sure we properly catch all elements in last chunk + i_end = (i + 1) * clen if i < (cell_lens.shape[0] // clen) - 1 else cell_lens.shape[0] + l0, l1 = ( + (0 if i == 0 else cell_lens[: i * clen].cumsum(0)[-1]), + cell_lens[:i_end].cumsum(0)[-1], + ) + + tokens_c = tokens[l0:l1] + tokens_global_c = tokens_global[i * clen : i_end] + cell_lens_c = torch.cat([zero_pad, cell_lens[i * clen : i_end]]) + q_cells_lens_c = q_cells_lens[: cell_lens_c.shape[0]] + + # local assimilation model + tokens_c = self.ae_local_engine(tokens_c, cell_lens_c, use_reentrant=False) + + if self.cf.latent_noise_kl_weight > 0.0: + tokens_c, posteriors_c = self.interpolate_latents.interpolate_with_noise( + tokens_c, sampling=self.training + ) + posteriors += [posteriors_c] + else: + tokens_c, posteriors = tokens_c, 0.0 + + # create mask for global tokens, without first element (used for padding) + mask_c = cell_lens_c[1:].to(torch.bool) + tokens_global_unmasked_c = tokens_global_c[mask_c] + q_cells_lens_unmasked_c = torch.cat([zero_pad, q_cells_lens_c[1:][mask_c]]) + cell_lens_unmasked_c = torch.cat([zero_pad, cell_lens_c[1:][mask_c]]) + + if l0 == l1 or tokens_c.shape[0] == 0: + tokens_global_unmasked_all += [tokens_global_unmasked_c] + continue + + # local to global adapter engine + tokens_global_unmasked_c = self.ae_local_global_engine( + tokens_c, + tokens_global_unmasked_c, + q_cells_lens_unmasked_c, + cell_lens_unmasked_c, + use_reentrant=False, + ) + + tokens_global_unmasked_all += [tokens_global_unmasked_c] + + tokens_global_unmasked = torch.cat(tokens_global_unmasked_all) + + # query aggregation engine on the query tokens in unmasked cells + # (applying this here assumes batch_size=1) + # permute to use ae_local_num_queries as the batchsize and no_of_tokens + # as seq len for flash attention + tokens_global_unmasked = torch.permute(tokens_global_unmasked, [1, 0, 2]) + tokens_global_unmasked = self.ae_aggregation_engine( + tokens_global_unmasked, use_reentrant=False + ) + tokens_global_unmasked = torch.permute(tokens_global_unmasked, [1, 0, 2]) + + # create mask from cell lens + mask = cell_lens.to(torch.bool) + + # fill empty tensor using mask for positions of unmasked tokens + tokens_global[mask] = tokens_global_unmasked.to(tokens_global.dtype) + + # recover batch dimension and build global token list + tokens_global = ( + tokens_global.reshape([batch_size, self.num_healpix_cells, s[-2], s[-1]]) + + model_params.pe_global + ).flatten(1, 2) + + return tokens_global, posteriors + + ######################################### + def assimilate_global(self, tokens: torch.Tensor) -> torch.Tensor: + """Performs transformer based global assimilation in latent space + Args: + model_params : Query and embedding parameters (never used) + tokens : Input tokens to be pre-processed by global assimilation + Returns: + Latent representation of the model + """ + + # global assimilation engine and adapter + tokens = self.ae_global_engine(tokens, use_reentrant=False) + + return tokens diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index c66c5fa78..00efbe6d8 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -19,23 +19,17 @@ import numpy as np import torch import torch.nn as nn -from astropy_healpix import healpy from torch.utils.checkpoint import checkpoint from weathergen.common.config import Config +from weathergen.model.encoder import EncoderModule from weathergen.model.engines import ( - EmbeddingEngine, EnsPredictionHead, ForecastingEngine, - GlobalAssimilationEngine, - Local2GlobalAssimilationEngine, - LocalAssimilationEngine, - QueryAggregationEngine, TargetPredictionEngine, TargetPredictionEngineClassic, ) from weathergen.model.layers import MLP, NamedLinear -from weathergen.model.parametrised_prob_dist import LatentInterpolator from weathergen.model.utils import get_num_parameters from weathergen.utils.distributed import is_root from weathergen.utils.utils import get_dtype @@ -269,14 +263,8 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord self.targets_num_channels = targets_num_channels self.targets_coords_size = targets_coords_size - self.ae_aggregation_engine: QueryAggregationEngine | None = None - self.ae_global_engine: GlobalAssimilationEngine | None = None - self.ae_local_engine: LocalAssimilationEngine | None = None - self.ae_local_global_engine: Local2GlobalAssimilationEngine | None = None - self.embed_engine: EmbeddingEngine | None = None self.embed_target_coords = None self.forecast_engine: ForecastingEngine | None = None - self.interpolate_latents: LatentInterpolator | None = None self.pred_adapter_kv = None self.pred_heads = None self.q_cells: torch.Tensor | None = None @@ -288,62 +276,9 @@ def create(self) -> "Model": """Create each individual module of the model""" cf = self.cf - # determine stream names once so downstream components use consistent keys - self.stream_names = [str(stream_cfg["name"]) for stream_cfg in cf.streams] - # separate embedding networks for differnt observation types - self.embed_engine = EmbeddingEngine(cf, self.sources_size, self.stream_names) - - ############## - # local assimilation engine - self.ae_local_engine = LocalAssimilationEngine(cf) - - if cf.latent_noise_kl_weight > 0.0: - self.interpolate_latents = LatentInterpolator( - gamma=cf.latent_noise_gamma, - dim=cf.ae_local_dim_embed, - use_additive_noise=cf.latent_noise_use_additive_noise, - deterministic=cf.latent_noise_deterministic_latents, - ) - - ############## - # local -> global assimilation engine adapter - self.ae_local_global_engine = Local2GlobalAssimilationEngine(cf) - - ############## - # learnable queries - if cf.ae_local_queries_per_cell: - s = (self.num_healpix_cells, cf.ae_local_num_queries, cf.ae_global_dim_embed) - q_cells = torch.rand(s, requires_grad=True) / cf.ae_global_dim_embed - # add meta data - q_cells[:, :, -8:-6] = ( - (torch.arange(self.num_healpix_cells) / self.num_healpix_cells) - .unsqueeze(1) - .unsqueeze(1) - .repeat((1, cf.ae_local_num_queries, 2)) - ) - theta, phi = healpy.pix2ang( - nside=2**self.healpix_level, ipix=torch.arange(self.num_healpix_cells) - ) - q_cells[:, :, -6:-3] = ( - torch.cos(theta).unsqueeze(1).unsqueeze(1).repeat((1, cf.ae_local_num_queries, 3)) - ) - q_cells[:, :, -3:] = ( - torch.sin(phi).unsqueeze(1).unsqueeze(1).repeat((1, cf.ae_local_num_queries, 3)) - ) - q_cells[:, :, -9] = torch.arange(cf.ae_local_num_queries) - q_cells[:, :, -10] = torch.arange(cf.ae_local_num_queries) - else: - s = (1, cf.ae_local_num_queries, cf.ae_global_dim_embed) - q_cells = torch.rand(s, requires_grad=True) / cf.ae_global_dim_embed - self.q_cells = torch.nn.Parameter(q_cells, requires_grad=True) - - ############## - # query aggregation engine - self.ae_aggregation_engine = QueryAggregationEngine(cf, self.num_healpix_cells) - - ############## - # global assimilation engine - self.ae_global_engine = GlobalAssimilationEngine(cf, self.num_healpix_cells) + self.encoder = EncoderModule( + cf, self.sources_size, self.targets_num_channels, self.targets_coords_size + ) ############### # forecasting engine @@ -366,6 +301,9 @@ def create(self) -> "Model": self.pred_adapter_kv = torch.nn.ModuleDict() self.pred_heads = torch.nn.ModuleDict() + # determine stream names once so downstream components use consistent keys + self.stream_names = [str(stream_cfg["name"]) for stream_cfg in cf.streams] + for i_obs, si in enumerate(cf.streams): stream_name = self.stream_names[i_obs] @@ -494,16 +432,19 @@ def print_num_parameters(self) -> None: cf = self.cf num_params_embed = [ - get_num_parameters(self.embed_engine.embeds[name]) for name in self.stream_names + get_num_parameters(self.encoder.embed_engine.embeds[name]) for name in self.stream_names ] num_params_total = get_num_parameters(self) - num_params_ae_local = get_num_parameters(self.ae_local_engine.ae_local_blocks) - num_params_ae_global = get_num_parameters(self.ae_global_engine.ae_global_blocks) + num_params_ae_local = get_num_parameters(self.encoder.ae_local_engine.ae_local_blocks) + num_params_ae_global = get_num_parameters(self.encoder.ae_global_engine.ae_global_blocks) + + num_params_q_cells = ( + np.prod(self.encoder.q_cells.shape) if self.encoder.q_cells.requires_grad else 0 + ) + num_params_ae_adapater = get_num_parameters(self.encoder.ae_local_global_engine.ae_adapter) - num_params_q_cells = np.prod(self.q_cells.shape) if self.q_cells.requires_grad else 0 - num_params_ae_adapater = get_num_parameters(self.ae_local_global_engine.ae_adapter) num_params_ae_aggregation = get_num_parameters( - self.ae_aggregation_engine.ae_aggregation_blocks + self.encoder.ae_aggregation_engine.ae_aggregation_blocks ) num_params_fe = get_num_parameters(self.forecast_engine.fe_blocks) @@ -559,11 +500,16 @@ def rename_old_state_dict(self, params: dict) -> dict: new_params : Dictionary with (renamed) model parameters """ params_cleanup = { - "embeds": "embed_engine.embeds", # EmbeddingEngine - "ae_local_blocks": "ae_local_engine.ae_local_blocks", # LocalAssimilationEngine - "ae_adapter": "ae_local_global_engine.ae_adapter", # Local2GlobalAssimilationEngine - "ae_global_blocks": "ae_global_engine.ae_global_blocks", # GlobalAssimilationEngine - "fe_blocks": "forecast_engine.fe_blocks", # ForecastingEngine + # EmbeddingEngine + "embeds": "encoder.embed_engine.embeds", + # LocalAssimilationEngine + "ae_local_blocks": "encoder.ae_local_engine.ae_local_blocks", + # Local2GlobalAssimilationEngine + "ae_adapter": "encoder.ae_local_global_engine.ae_adapter", + # GlobalAssimilationEngine + "ae_global_blocks": "encoder.ae_global_engine.ae_global_blocks", + # ForecastingEngine + "fe_blocks": "forecast_engine.fe_blocks", } new_params = {} @@ -606,13 +552,7 @@ def forward(self, model_params: ModelParams, sample, forecast_offset: int, forec A list containing all prediction results """ - # embed - tokens = self.embed_cells(model_params, sample) - - # local assimilation engine and adapter - tokens, posteriors = self.assimilate_local(model_params, tokens, sample) - - tokens = self.assimilate_global(model_params, tokens) + tokens, posteriors = self.encoder(model_params, sample) # roll-out in latent space preds_all = [] @@ -651,171 +591,6 @@ def forward(self, model_params: ModelParams, sample, forecast_offset: int, forec return ModelOutput(physical=preds_all, latent=latents) - ######################################### - def embed_cells(self, model_params: ModelParams, sample) -> torch.Tensor: - """Embeds input data for each stream separately and rearranges it to cell-wise order - Args: - model_params : Query and embedding parameters - streams_data : Used to initialize first tokens for pre-processing - Returns:uv - Tokens for local assimilation - """ - - device = next(self.parameters()).device - tokens_all = self.embed_engine(sample, model_params.pe_embed, self.dtype, device) - - return tokens_all - - ######################################### - def assimilate_local( - self, model_params: ModelParams, tokens: torch.Tensor, sample: torch.Tensor - ) -> torch.Tensor: - """Processes embedded tokens locally and prepares them for the global assimilation - Args: - model_params : Query and embedding parameters - tokens : Input tokens to be processed by local assimilation - cell_lens : Used to identify range of tokens to use from generated tokens in cell - embedding - Returns: - Tokens for global assimilation - """ - - cell_lens = sample.source_cell_lens - batch_size = ( - self.cf.batch_size_per_gpu if self.training else self.cf.batch_size_validation_per_gpu - ) - - s = self.q_cells.shape - # print( f'{np.prod(np.array(tokens.shape))} :: {np.prod(np.array(s))}' - # + ':: {np.prod(np.array(tokens.shape))/np.prod(np.array(s))}') - # TODO: test if positional encoding is needed here - if self.cf.ae_local_queries_per_cell: - tokens_global = (self.q_cells + model_params.pe_global).repeat(batch_size, 1, 1) - else: - tokens_global = ( - self.q_cells.repeat(self.num_healpix_cells, 1, 1) + model_params.pe_global - ) - q_cells_lens = torch.cat( - [model_params.q_cells_lens[0].unsqueeze(0)] - + [model_params.q_cells_lens[1:] for _ in range(batch_size)] - ) - - # local assimilation model - # for block in self.ae_local_blocks: - # tokens = checkpoint(block, tokens, cell_lens, use_reentrant=False) - - # if self.cf.latent_noise_kl_weight > 0.0: - # tokens, posteriors = self.interpolate_latents.interpolate_with_noise( - # tokens, sampling=self.training - # ) - # else: - # tokens, posteriors = tokens, 0.0 - - # for block in self.ae_adapter: - # tokens_global = checkpoint( - # block, - # tokens_global, - # tokens, - # q_cells_lens, - # cell_lens, - # use_reentrant=False, - # ) - - # work around to bug in flash attention for hl>=5 - - istep = 0 - - cell_lens = cell_lens[istep][1:] - clen = self.num_healpix_cells // (2 if self.cf.healpix_level <= 5 else 8) - tokens_global_unmasked_all = [] - posteriors = [] - zero_pad = torch.zeros(1, device=tokens.device, dtype=torch.int32) - for i in range((cell_lens.shape[0]) // clen): - # make sure we properly catch all elements in last chunk - i_end = (i + 1) * clen if i < (cell_lens.shape[0] // clen) - 1 else cell_lens.shape[0] - l0, l1 = ( - (0 if i == 0 else cell_lens[: i * clen].cumsum(0)[-1]), - cell_lens[:i_end].cumsum(0)[-1], - ) - - tokens_c = tokens[l0:l1] - tokens_global_c = tokens_global[i * clen : i_end] - cell_lens_c = torch.cat([zero_pad, cell_lens[i * clen : i_end]]) - q_cells_lens_c = q_cells_lens[: cell_lens_c.shape[0]] - - # local assimilation model - tokens_c = self.ae_local_engine(tokens_c, cell_lens_c, use_reentrant=False) - - if self.cf.latent_noise_kl_weight > 0.0: - tokens_c, posteriors_c = self.interpolate_latents.interpolate_with_noise( - tokens_c, sampling=self.training - ) - posteriors += [posteriors_c] - else: - tokens_c, posteriors = tokens_c, 0.0 - - # create mask for global tokens, without first element (used for padding) - mask_c = cell_lens_c[1:].to(torch.bool) - tokens_global_unmasked_c = tokens_global_c[mask_c] - q_cells_lens_unmasked_c = torch.cat([zero_pad, q_cells_lens_c[1:][mask_c]]) - cell_lens_unmasked_c = torch.cat([zero_pad, cell_lens_c[1:][mask_c]]) - - if l0 == l1 or tokens_c.shape[0] == 0: - tokens_global_unmasked_all += [tokens_global_unmasked_c] - continue - - # local to global adapter engine - tokens_global_unmasked_c = self.ae_local_global_engine( - tokens_c, - tokens_global_unmasked_c, - q_cells_lens_unmasked_c, - cell_lens_unmasked_c, - use_reentrant=False, - ) - - tokens_global_unmasked_all += [tokens_global_unmasked_c] - - tokens_global_unmasked = torch.cat(tokens_global_unmasked_all) - - # query aggregation engine on the query tokens in unmasked cells - # (applying this here assumes batch_size=1) - # permute to use ae_local_num_queries as the batchsize and no_of_tokens - # as seq len for flash attention - tokens_global_unmasked = torch.permute(tokens_global_unmasked, [1, 0, 2]) - tokens_global_unmasked = self.ae_aggregation_engine( - tokens_global_unmasked, use_reentrant=False - ) - tokens_global_unmasked = torch.permute(tokens_global_unmasked, [1, 0, 2]) - - # create mask from cell lens - mask = cell_lens.to(torch.bool) - - # fill empty tensor using mask for positions of unmasked tokens - tokens_global[mask] = tokens_global_unmasked.to(tokens_global.dtype) - - # recover batch dimension and build global token list - tokens_global = ( - tokens_global.reshape([batch_size, self.num_healpix_cells, s[-2], s[-1]]) - + model_params.pe_global - ).flatten(1, 2) - - return tokens_global, posteriors - - ######################################### - def assimilate_global(self, model_params: ModelParams, tokens: torch.Tensor) -> torch.Tensor: - """Performs transformer based global assimilation in latent space - Args: - model_params : Query and embedding parameters (never used) - tokens : Input tokens to be pre-processed by global assimilation - Returns: - Latent representation of the model - """ - - # global assimilation engine and adapter - tokens = self.ae_global_engine(tokens, use_reentrant=False) - - return tokens - ######################################### def forecast(self, model_params: ModelParams, tokens: torch.Tensor, fstep: int) -> torch.Tensor: """Advances latent space representation in time diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index 6260439e8..18731d15b 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -93,15 +93,15 @@ def init_model_and_shard( MultiSelfAttentionHeadVarlen, ) - for module in model.ae_local_engine.ae_local_blocks.modules(): + for module in model.encoder.ae_local_engine.ae_local_blocks.modules(): if isinstance(module, modules_to_shard): fully_shard(module, **fsdp_kwargs) - for module in model.ae_local_global_engine.ae_adapter.modules(): + for module in model.encoder.ae_local_global_engine.ae_adapter.modules(): if isinstance(module, modules_to_shard): fully_shard(module, **fsdp_kwargs) - for module in model.ae_global_engine.ae_global_blocks.modules(): + for module in model.encoder.ae_global_engine.ae_global_blocks.modules(): if isinstance(module, modules_to_shard): fully_shard(module, **fsdp_kwargs) @@ -137,7 +137,7 @@ def init_model_and_shard( # functions in the embedding engine as forward functions. Thus, yielding a crash # because the input tensors are not converted to DTensors. This seems to primarily # occur during validation. - for embed in model.embed_engine.embeds.values(): + for embed in model.encoder.embed_engine.embeds.values(): torch.distributed.fsdp.register_fsdp_forward_method(embed, "forward_channels") torch.distributed.fsdp.register_fsdp_forward_method(embed, "forward_columns") From 4d27a952166246401ddf75213ae04dff3f835c1c Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Sat, 6 Dec 2025 18:52:47 +0100 Subject: [PATCH 165/344] Formatting --- src/weathergen/model/encoder.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index aad1e7e3a..732a77857 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -54,14 +54,12 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord self.embed_engine: EmbeddingEngine | None = None self.interpolate_latents: LatentInterpolator | None = None - ############## # embedding engine # determine stream names once so downstream components use consistent keys self.stream_names = [str(stream_cfg["name"]) for stream_cfg in cf.streams] # separate embedding networks for differnt observation types self.embed_engine = EmbeddingEngine(cf, self.sources_size, self.stream_names) - ############## # local assimilation engine self.ae_local_engine = LocalAssimilationEngine(cf) @@ -73,11 +71,9 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord deterministic=cf.latent_noise_deterministic_latents, ) - ############## # local -> global assimilation engine adapter self.ae_local_global_engine = Local2GlobalAssimilationEngine(cf) - ############## # learnable queries if cf.ae_local_queries_per_cell: s = (self.num_healpix_cells, cf.ae_local_num_queries, cf.ae_global_dim_embed) @@ -105,11 +101,9 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord q_cells = torch.rand(s, requires_grad=True) / cf.ae_global_dim_embed self.q_cells = torch.nn.Parameter(q_cells, requires_grad=True) - ############## # query aggregation engine self.ae_aggregation_engine = QueryAggregationEngine(cf, self.num_healpix_cells) - ############## # global assimilation engine self.ae_global_engine = GlobalAssimilationEngine(cf, self.num_healpix_cells) @@ -124,7 +118,6 @@ def forward(self, model_params, sample): return tokens, posteriors - ######################################### def embed_cells(self, model_params, sample) -> torch.Tensor: """Embeds input data for each stream separately and rearranges it to cell-wise order Args: @@ -139,7 +132,6 @@ def embed_cells(self, model_params, sample) -> torch.Tensor: return tokens_all - ######################################### def assimilate_local( self, model_params, tokens: torch.Tensor, sample: torch.Tensor ) -> torch.Tensor: @@ -274,7 +266,6 @@ def assimilate_local( return tokens_global, posteriors - ######################################### def assimilate_global(self, tokens: torch.Tensor) -> torch.Tensor: """Performs transformer based global assimilation in latent space Args: From 9cf040e2a4c045d83b86ee30073308f64df5b978 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Sat, 6 Dec 2025 18:53:05 +0100 Subject: [PATCH 166/344] Fix source-target matching problem. --- src/weathergen/datasets/masking.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index fe8687cf7..3f8226092 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -386,7 +386,8 @@ def build_samples_for_stream( ) source_masks += [source_mask] source_metadata += [SampleMetaData(params={**target_cfg, **mask_params})] - source_target_mapping += [i_source] + # TODO: proper correspondence between source and target + source_target_mapping += [i_source % len(target_masks)] source_target_mapping = np.array(source_target_mapping, dtype=np.int32) From 5fca790861bd7ab8f1cc9e57b3492a8eb0f29354 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Sun, 7 Dec 2025 11:10:49 +0100 Subject: [PATCH 167/344] Enabled multiple input steps. Fixed various robustness that arose through this. This commit also changes the number of forecast steps that are taken. The old loop was at least one step too far. Unclear why the problem occurred now. --- config/default_config.yml | 31 ++++--- src/weathergen/datasets/batch.py | 26 ++---- .../datasets/multi_stream_data_sampler.py | 33 ++++--- src/weathergen/datasets/stream_data.py | 92 +++++++------------ src/weathergen/datasets/utils.py | 10 +- src/weathergen/model/encoder.py | 35 ++++--- src/weathergen/model/engines.py | 82 ++++++++--------- src/weathergen/model/model.py | 10 +- src/weathergen/train/trainer.py | 4 +- 9 files changed, 152 insertions(+), 171 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 65262fe01..6ccbf6f97 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -1,5 +1,5 @@ -streams_directory: "./config/streams/era5_1deg/" -# streams_directory: "./config/streams/era5_nppatms_synop/" +# streams_directory: "./config/streams/era5_1deg/" +streams_directory: "./config/streams/era5_nppatms_synop/" embed_orientation: "channels" embed_unembed_mode: "block" @@ -133,26 +133,27 @@ training_config: # "same_strategy_per_batch": false # } # relationship: "independent" #, "subset", "disjoint". Relationship of student views to teacher view. - relationship: "subset" # "independent", "subset", "disjoint". Relationship of student views to teacher view. + relationship: "complement" # "independent", "subset", "disjoint". Relationship of student views to teacher view. + num_steps_input: 2 # loss : ibot loss : training : - LossPhysical: {weight: 0.7, loss_fcts: [['mse', 0.8], ['mae', 0.2]]} validation : - LossPhysical: {weight: 1.0, loss_fcts: [['mse', 0.8]]} - - masking_strategy: "random" - num_samples: 2 # if student-teacher, the number of local (student) views to generate - masking_strategy_config : { diffusion_rn : True, rate : 0.4 } - relationship: "subset" # "independent", "subset", "disjoint". Relationship of student views to teacher view. - # loss : dino - loss : - training : - - LossPhysical: {weight: 1.0, loss_fcts: [['mse', 1.0]]} + # - masking_strategy: "random" + # num_samples: 2 # if student-teacher, the number of local (student) views to generate + # masking_strategy_config : { diffusion_rn : True, rate : 0.4 } + # relationship: "subset" # "independent", "subset", "disjoint". Relationship of student views to teacher view. + # # loss : dino + # loss : + # training : + # - LossPhysical: {weight: 1.0, loss_fcts: [['mse', 1.0]]} - target_input: - - masking_strategy: "random" # Strategy for teacher (global) view: "random", "healpix" - num_samples: 1 # number of teacher views to generate - masking_strategy_config : { diffusion_rn : True, rate : 0.4, rate_sampling: true } + # target_input: + # - masking_strategy: "random" # Strategy for teacher (global) view: "random", "healpix" + # num_samples: 1 # number of teacher views to generate + # masking_strategy_config : { diffusion_rn : True, rate : 0.4, rate_sampling: true } # model_input: # - masking_strategy: "forecast" # "random", "healpix". Masking strategy to use for model input for masking, and local (student) views when doing student-teacher diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index 661dd255f..608b2d11a 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -111,6 +111,11 @@ def get_stream_data(self, stream_name: str) -> StreamData: assert self.streams_data.get(stream_name, -1) != -1, "stream name does not exist" return self.streams_data[stream_name] + def get_forecast_steps(self) -> int: + for _, sdata in self.streams_data.items(): + forecast_dt = sdata.get_forecast_steps() + return forecast_dt + class ModelBatch: """ @@ -130,22 +135,15 @@ class ModelBatch: source2target_matching_idxs: np.typing.NDArray[np.int32] target2source_matching_idxs: np.typing.NDArray[np.int32] - forecast_dt: int | None - - def __init__( - self, streams, num_source_samples: int, num_target_samples: int, forecast_dt: int - ) -> None: + def __init__(self, streams, num_source_samples: int, num_target_samples: int) -> None: """ """ self.source_samples = [Sample(streams) for _ in range(num_source_samples)] self.target_samples = [Sample(streams) for _ in range(num_target_samples)] self.source2target_matching_idxs = np.full(num_source_samples, -1, dtype=np.int32) - # self.target_source_matching_idxs = np.full(num_target_samples, -1, dtype=np.int32) self.target2source_matching_idxs = [[] for _ in range(num_target_samples)] - self.forecast_dt = forecast_dt - def to_device(self, device): for sample in self.source_samples: sample.to_device(device) @@ -210,18 +208,6 @@ def is_empty(self): ) return source_empty or target_empty - def set_forecast_dt(self, forecast_dt: int) -> None: - """ - Set forecast_dt for sample - """ - self.forecast_dt = forecast_dt - - def get_forecast_dt(self) -> int: - """ - Get forecast_dt - """ - return self.forecast_dt - def len_sources(self) -> int: """ Number of source samples diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index dee65a82c..25982dd44 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -88,8 +88,6 @@ def __init__( self.mask_value = 0.0 self._stage = stage - self.num_input_steps = cf.get("num_input_steps", 1) - self.len_hrs: int = cf.len_hrs self.step_hrs: int = cf.step_hrs self.time_window_handler = TimeWindowHandler(start_date, end_date, cf.len_hrs, cf.step_hrs) @@ -309,6 +307,7 @@ def _build_stream_data_input( stream_data: StreamData, base_idx: TIndex, stream_info: dict, + num_steps_input: int, input_data: list, input_tokens: list, mask: torch.Tensor | None = None, @@ -330,7 +329,7 @@ def _build_stream_data_input( if "network_input" in mode: # iterate overall input steps - for step, idx in enumerate(range(base_idx, base_idx - self.num_input_steps, -1)): + for step, idx in enumerate(range(base_idx, base_idx - num_steps_input, -1)): # TODO: check that we are not out of bounds when we go back in time time_win_source = self.time_window_handler.window(idx) @@ -412,6 +411,7 @@ def _build_stream_data( base_idx: TIndex, forecast_dt: int, stream_info: dict, + num_steps_input: int, input_data: list, output_data: list, input_tokens: list, @@ -440,13 +440,14 @@ def _build_stream_data( """ dt = self.forecast_offset + forecast_dt - stream_data = StreamData(base_idx, dt, self.num_healpix_cells) + stream_data = StreamData(base_idx, num_steps_input, dt, self.num_healpix_cells) stream_data = self._build_stream_data_input( modes, stream_data, base_idx, stream_info, + num_steps_input, input_data, input_tokens, input_mask, @@ -465,7 +466,7 @@ def _build_stream_data( return stream_data - def _get_data_windows(self, base_idx, forecast_dt, stream_ds): + def _get_data_windows(self, base_idx, forecast_dt, num_steps_input_max, stream_ds): """ Collect all data needed for current stream to potentially amortize costs by generating multiple samples @@ -474,7 +475,7 @@ def _get_data_windows(self, base_idx, forecast_dt, stream_ds): # source data: iterate overall input steps input_data = [] - for idx in range(base_idx - self.num_input_steps, base_idx + 1): + for idx in range(base_idx - num_steps_input_max, base_idx + 1): # TODO: check that we are not out of bounds when we go back in time rdata = collect_datasources(stream_ds, idx, "source") @@ -528,6 +529,8 @@ def _get_batch(self, idx: int, forecast_dt: int): """ mode = self.training_cfg.get("training_mode") + source_cfgs = self.training_cfg.get("model_input") + target_cfgs = self.training_cfg.get("target_input", source_cfgs) # get/coordinate masks # TODO: should also return number of views @@ -542,7 +545,7 @@ def _get_batch(self, idx: int, forecast_dt: int): else: raise NotImplementedError(f"Unsupported training mode {mode}.") - batch = ModelBatch(self.streams, num_source_samples, num_target_samples, forecast_dt) + batch = ModelBatch(self.streams, num_source_samples, num_target_samples) # for all streams for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): @@ -559,7 +562,10 @@ def _get_batch(self, idx: int, forecast_dt: int): # input_data and output_data is conceptually consecutive but differs # in source and target channels; overlap in one window when self.forecast_offset=0 - (input_data, output_data) = self._get_data_windows(idx, forecast_dt, stream_ds) + # max number of input steps + i_max = np.array([sc.get("num_steps_input", 1) for sc in source_cfgs]).max().item() + self.num_steps_input = i_max + (input_data, output_data) = self._get_data_windows(idx, forecast_dt, i_max, stream_ds) # tokenize windows # *_tokens = [ (cells_idx, cells_idx_lens), ... ] with length = #time_steps @@ -574,6 +580,7 @@ def _get_batch(self, idx: int, forecast_dt: int): idx, forecast_dt, stream_info, + source_cfgs[sidx].get("num_steps_input", 1), input_data, output_data, input_tokens, @@ -596,6 +603,7 @@ def _get_batch(self, idx: int, forecast_dt: int): idx, forecast_dt, stream_info, + target_cfgs[sidx].get("num_steps_input", 1), input_data, output_data, input_tokens, @@ -642,18 +650,21 @@ def _get_source_target_masks(self, training_mode): target_cfgs = self.training_cfg.get("target_input", source_cfgs) target_cfgs = target_cfgs if target_cfgs is not None else source_cfgs num_source_samples = np.array([sc.get("num_samples", 1) for sc in source_cfgs]).sum().item() - num_target_samples = np.array([sc.get("num_samples", 1) for sc in target_cfgs]).sum().item() + num_target_samples = np.array([tc.get("num_samples", 1) for tc in target_cfgs]).sum().item() return masks, num_source_samples, num_target_samples def _preprocess_model_data(self, batch, forecast_dt): """ """ + # TODO, TODO, TODO: cleanup + num_steps_input = self.num_steps_input + # aggregated lens of tokens per cell across input batch samples - source_cell_lens = compute_source_cell_lens(batch, self.num_input_steps) + source_cell_lens = compute_source_cell_lens(batch, num_steps_input) # compute offsets for scatter computation after embedding - batch = compute_offsets_scatter_embed(batch, self.num_input_steps) + batch = compute_offsets_scatter_embed(batch, num_steps_input) # compute offsets and auxiliary data needed for prediction computation # (info is not per stream so separate data structure) diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index c8727fe9b..8c5827817 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -20,7 +20,7 @@ class StreamData: for one stream. """ - def __init__(self, idx: int, forecast_steps: int, healpix_cells: int) -> None: + def __init__(self, idx: int, input_steps: int, forecast_steps: int, healpix_cells: int) -> None: """ StreamData object @@ -38,6 +38,7 @@ def __init__(self, idx: int, forecast_steps: int, healpix_cells: int) -> None: self.mask_value = 0.0 + self.input_steps = input_steps self.forecast_steps = forecast_steps self.healpix_cells = healpix_cells @@ -60,15 +61,15 @@ def __init__(self, idx: int, forecast_steps: int, healpix_cells: int) -> None: self.idxs_inv = [torch.tensor([], dtype=torch.int64) for _ in range(forecast_steps + 1)] # source tokens per cell - self.source_tokens_cells = [] + self.source_tokens_cells = [None for _ in range(self.input_steps)] # length of source tokens per cell (without padding) - self.source_tokens_lens = [] + self.source_tokens_lens = [[] for _ in range(self.input_steps)] # unprocessed source (for logging) - self.source_raw = [] + self.source_raw = [None for _ in range(self.input_steps)] # auxiliary data for scatter operation that changes from stream-centric to cell-centric # processing after embedding - self.source_idxs_embed = [torch.tensor([])] - self.source_idxs_embed_pe = [torch.tensor([])] + self.source_idxs_embed = [None for _ in range(self.input_steps)] + self.source_idxs_embed_pe = [None for _ in range(self.input_steps)] def to_device(self, device: str) -> None: """ @@ -85,56 +86,22 @@ def to_device(self, device: str) -> None: """ dv = device - self.source_tokens_cells = [s.to(dv, non_blocking=True) for s in self.source_tokens_cells] - self.source_tokens_lens = [s.to(dv, non_blocking=True) for s in self.source_tokens_lens] - self.target_coords = [t.to(dv, non_blocking=True) for t in self.target_coords] self.target_tokens = [t.to(dv, non_blocking=True) for t in self.target_tokens] - self.source_idxs_embed = [s.to(dv, non_blocking=True) for s in self.source_idxs_embed] - self.source_idxs_embed_pe = [s.to(dv, non_blocking=True) for s in self.source_idxs_embed_pe] - - return self - - def add_empty_source(self, source: IOReaderData) -> None: - """ - Add an empty source for an input. - - Parameters - ---------- - None - - Returns - ------- - None - """ - - source = spoof(source) - self.source_raw += [source] - self.source_tokens_lens += [torch.ones([self.healpix_cells], dtype=torch.int32)] - self.source_tokens_cells += [torch.tensor([])] - - def add_empty_target(self, fstep: int) -> None: - """ - Add an empty target for an input. - - Parameters - ---------- - fstep : int - forecast step + # move to device if source data is present + if not np.array([s is None for s in self.source_tokens_cells]).all(): + self.source_tokens_cells = [ + s.to(dv, non_blocking=True) for s in self.source_tokens_cells + ] + self.source_tokens_lens = [s.to(dv, non_blocking=True) for s in self.source_tokens_lens] - Returns - ------- - None - """ + self.source_idxs_embed = [s.to(dv, non_blocking=True) for s in self.source_idxs_embed] + self.source_idxs_embed_pe = [ + s.to(dv, non_blocking=True) for s in self.source_idxs_embed_pe + ] - self.target_tokens[fstep] += [torch.tensor([], dtype=torch.int32)] - self.target_coords[fstep] += [torch.zeros((0, 105)) for _ in range(self.healpix_cells)] - self.target_coords_lens[fstep] += [torch.zeros([self.healpix_cells], dtype=torch.int32)] - self.target_coords_raw[fstep] += [torch.tensor([]) for _ in range(self.healpix_cells)] - self.target_times_raw[fstep] += [ - np.array([], dtype="datetime64[ns]") for _ in range(self.healpix_cells) - ] + return self def add_source( self, step: int, ss_raw: IOReaderData, ss_lens: torch.Tensor, ss_cells: list @@ -154,13 +121,14 @@ def add_source( None """ - # TODO: use step - self.source_raw += [ss_raw] - self.source_tokens_lens += [ss_lens] - self.source_tokens_cells += [torch.stack(ss_cells)] + assert step < self.input_steps - idx = torch.isnan(self.source_tokens_cells[-1]) - self.source_tokens_cells[-1][idx] = self.mask_value + self.source_raw[step] = ss_raw + self.source_tokens_lens[step] = ss_lens + self.source_tokens_cells[step] = torch.stack(ss_cells) + + idx = torch.isnan(self.source_tokens_cells[step]) + self.source_tokens_cells[step][idx] = self.mask_value def add_target( self, @@ -312,7 +280,9 @@ def source_empty(self) -> bool: True if target is empty for stream, else False """ - return torch.tensor([s.sum() for s in self.source_tokens_lens]).sum() == 0 + return ( + torch.tensor([s.sum() if len(s) > 0 else 0 for s in self.source_tokens_lens]).sum() == 0 + ) def empty(self): """ @@ -336,6 +306,12 @@ def is_spoof(self) -> bool: """ return self.source_is_spoof or self.target_is_spoof + def get_forecast_steps(self) -> int: + """ + Get number of forecast steps + """ + return self.forecast_steps + def spoof(healpix_level: int, datetime, geoinfo_size, mean_of_data) -> IOReaderData: """ diff --git a/src/weathergen/datasets/utils.py b/src/weathergen/datasets/utils.py index 4e5abaddb..1740aac2c 100644 --- a/src/weathergen/datasets/utils.py +++ b/src/weathergen/datasets/utils.py @@ -266,7 +266,7 @@ def add_local_vert_coords_ctrs2(verts_local, tcs_lens, a, zi, geoinfo_offset): return a -def compute_offsets_scatter_embed(batch: StreamData, num_input_steps: int) -> StreamData: +def compute_offsets_scatter_embed(batch: StreamData, num_steps_input: int) -> StreamData: """ Compute auxiliary information for scatter operation that changes from stream-centric to cell-centric computations @@ -297,7 +297,7 @@ def compute_offsets_scatter_embed(batch: StreamData, num_input_steps: int) -> St for stl_b in batch ] ) - for i in range(num_input_steps) + for i in range(num_steps_input) ] # precompute index sets for scatter operation after embed @@ -308,7 +308,7 @@ def compute_offsets_scatter_embed(batch: StreamData, num_input_steps: int) -> St if torch.cat(offsets_base).shape[0] == 0: return batch - for i_s in range(num_input_steps): + for i_s in range(num_steps_input): for ib, sb in enumerate(batch): # batch items for itype, s in enumerate(sb): # streams, i.e. here we have StreamData object if not s.source_empty(): @@ -380,7 +380,7 @@ def compute_idxs_predict(forecast_dt: int, batch: StreamData, streams: list[dict def compute_source_cell_lens( - batch: list[list[StreamData]], num_input_steps: int + batch: list[list[StreamData]], num_steps_input: int ) -> list[torch.tensor]: """ Compute auxiliary information for varlen attention for local assimilation @@ -411,7 +411,7 @@ def compute_source_cell_lens( for stl_b in batch ] ) - for i in range(num_input_steps) + for i in range(num_steps_input) ] source_cell_lens = [torch.sum(c, 1).flatten().to(torch.int32) for c in source_cell_lens_raw] diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index 732a77857..833491bfa 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -108,17 +108,28 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord self.ae_global_engine = GlobalAssimilationEngine(cf, self.num_healpix_cells) def forward(self, model_params, sample): - # embed - tokens = self.embed_cells(model_params, sample) + """ + Encoder forward + """ + + num_steps_input = len(sample.source_cell_lens) - # local assimilation engine and adapter - tokens, posteriors = self.assimilate_local(model_params, tokens, sample) + tokens, posteriors = [], [] + for input_step in range(num_steps_input): + # embed from physical space + toks = self.embed_cells(input_step, model_params, sample) - tokens = self.assimilate_global(tokens) + # local assimilation engine and adapter + toks, posts = self.assimilate_local(input_step, model_params, toks, sample) + + toks = self.assimilate_global(toks) + + tokens += [toks] + posteriors += [posts] return tokens, posteriors - def embed_cells(self, model_params, sample) -> torch.Tensor: + def embed_cells(self, input_step, model_params, sample) -> torch.Tensor: """Embeds input data for each stream separately and rearranges it to cell-wise order Args: model_params : Query and embedding parameters @@ -127,13 +138,13 @@ def embed_cells(self, model_params, sample) -> torch.Tensor: Tokens for local assimilation """ - device = next(self.parameters()).device - tokens_all = self.embed_engine(sample, model_params.pe_embed, self.dtype, device) + dev = next(self.parameters()).device + tokens_all = self.embed_engine(input_step, sample, model_params.pe_embed, self.dtype, dev) return tokens_all def assimilate_local( - self, model_params, tokens: torch.Tensor, sample: torch.Tensor + self, input_step: int, model_params, tokens: torch.Tensor, sample: torch.Tensor ) -> torch.Tensor: """Processes embedded tokens locally and prepares them for the global assimilation Args: @@ -145,7 +156,7 @@ def assimilate_local( Tokens for global assimilation """ - cell_lens = sample.source_cell_lens + cell_lens = sample.source_cell_lens[input_step] batch_size = ( self.cf.batch_size_per_gpu if self.training else self.cf.batch_size_validation_per_gpu ) @@ -188,9 +199,7 @@ def assimilate_local( # work around to bug in flash attention for hl>=5 - istep = 0 - - cell_lens = cell_lens[istep][1:] + cell_lens = cell_lens[1:] clen = self.num_healpix_cells // (2 if self.cf.healpix_level <= 5 else 8) tokens_global_unmasked_all = [] posteriors = [] diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 927f95119..bb19d249d 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -80,49 +80,45 @@ def __init__(self, cf: Config, sources_size, stream_names: list[str]) -> None: raise ValueError("Unsupported embedding network type") # TODO: remove device from arg list - def forward(self, sample, pe_embed, dtype, device): - num_step_input = len(sample.source_cell_lens) - offsets_base = [torch.cumsum(s[1:], 0) for s in sample.source_cell_lens] - - tokens_all = [ - torch.empty((int(ob[-1]), self.cf.ae_local_dim_embed), dtype=dtype, device=device) - for ob in offsets_base - ] - - # TODO: handling of input steps should be done using encoder - # iterate over all input steps and streams - for istep in range(num_step_input): - for stream_name, s_data in sample.streams_data.items(): - # embedding network - embed = self.embeds[stream_name] - - # skip empty stream - if s_data.source_empty(): - continue - - idxs = s_data.source_idxs_embed[istep].to(device) - idxs_pe = s_data.source_idxs_embed_pe[istep].to(device) - - # create full scatter index - # (there's no broadcasting which is likely highly inefficient) - idxs = idxs.unsqueeze(1).repeat((1, self.cf.ae_local_dim_embed)) - x_embed = embed(s_data.source_tokens_cells[istep]).flatten(0, 1) - # there's undocumented limitation in flash_attn that will make embed fail if - # #tokens is too large; code below is a work around - # x_embed = torch.cat( - # [ - # embed(s_c, c_c).flatten(0, 1) - # for s_c, c_c in zip( - # torch.split(s.source_tokens_cells, 49152), - # torch.split(s.source_centroids, 49152), - # ) - # ] - # ) - - # scatter write to reorder from per stream to per cell ordering - tokens_all[istep].scatter_(0, idxs, x_embed + pe_embed[idxs_pe]) - - return tokens_all[0] + def forward(self, input_step, sample, pe_embed, dtype, device): + offsets_base = torch.cumsum(sample.source_cell_lens[input_step][1:], 0) + + tokens_all = torch.empty( + (int(offsets_base[-1]), self.cf.ae_local_dim_embed), dtype=dtype, device=device + ) + + # iterate over all streams + for stream_name, s_data in sample.streams_data.items(): + # embedding network + embed = self.embeds[stream_name] + + # skip empty stream + if s_data.source_empty(): + continue + + idxs = s_data.source_idxs_embed[input_step].to(device) + idxs_pe = s_data.source_idxs_embed_pe[input_step].to(device) + + # create full scatter index + # (there's no broadcasting which is likely highly inefficient) + idxs = idxs.unsqueeze(1).repeat((1, self.cf.ae_local_dim_embed)) + x_embed = embed(s_data.source_tokens_cells[input_step]).flatten(0, 1) + # there's undocumented limitation in flash_attn that will make embed fail if + # #tokens is too large; code below is a work around + # x_embed = torch.cat( + # [ + # embed(s_c, c_c).flatten(0, 1) + # for s_c, c_c in zip( + # torch.split(s.source_tokens_cells, 49152), + # torch.split(s.source_centroids, 49152), + # ) + # ] + # ) + + # scatter write to reorder from per stream to per cell ordering + tokens_all.scatter_(0, idxs, x_embed + pe_embed[idxs_pe]) + + return tokens_all class LocalAssimilationEngine(torch.nn.Module): diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 00efbe6d8..1394d4e98 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -535,7 +535,7 @@ def rename_old_state_dict(self, params: dict) -> dict: ######################################### def forward(self, model_params: ModelParams, sample, forecast_offset: int, forecast_steps: int): - """Performs the forward pass of the model to generate forecasts + """Forward pass of the model Tokens are processed through the model components, which were defined in the create method. Args: @@ -554,9 +554,12 @@ def forward(self, model_params: ModelParams, sample, forecast_offset: int, forec tokens, posteriors = self.encoder(model_params, sample) + # collapse along input step dimension + tokens = torch.stack(tokens, 0).sum(0) + # roll-out in latent space preds_all = [] - for fstep in range(forecast_offset, forecast_offset + forecast_steps): + for fstep in range(forecast_offset, forecast_steps): # prediction preds_all += [ self.predict( @@ -579,9 +582,8 @@ def forward(self, model_params: ModelParams, sample, forecast_offset: int, forec preds_all += [ self.predict( model_params, - forecast_offset + forecast_steps, + forecast_steps, tokens, - # TODO We add the batch dimension back and thus wrap stream_data in a list sample, ) ] diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 2ee71a547..39efd2b3f 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -529,7 +529,7 @@ def train(self, mini_epoch): self.model_params, sample, cf.forecast_offset, - batch.get_forecast_dt(), + sample.get_forecast_steps(), ) ) @@ -546,7 +546,7 @@ def train(self, mini_epoch): self.model_params, self.model, cf.forecast_offset, - batch.get_forecast_dt(), + sample.get_forecast_steps(), ) ) # targets, aux = zip(*targets_and_auxs) From 47e81fac62e9add97a478076020c38895822dabc Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Sun, 7 Dec 2025 11:24:30 +0100 Subject: [PATCH 168/344] Linting --- src/weathergen/datasets/multi_stream_data_sampler.py | 3 +++ src/weathergen/model/model.py | 1 + 2 files changed, 4 insertions(+) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 25982dd44..2de60e740 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -197,6 +197,9 @@ def __init__( self.shuffle = shuffle # TODO: remove options that are no longer supported self.input_window_steps = cf.input_window_steps + # TODO, TODO, TODO: this needs to be stream specific and should not be an attribute + # current implementation needs to be cleaned up when batch_size > 1 is enabled + self.num_steps_input = -1 self.batch_size = batch_size diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 1394d4e98..d76d19d26 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -264,6 +264,7 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord self.targets_coords_size = targets_coords_size self.embed_target_coords = None + self.encoder: EncoderModule | None = None self.forecast_engine: ForecastingEngine | None = None self.pred_adapter_kv = None self.pred_heads = None From e0f6cc40b403a606f053ca031418b379f7a7eb78 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Tue, 9 Dec 2025 13:54:25 +0100 Subject: [PATCH 169/344] Missing update to validation() --- src/weathergen/train/trainer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 39efd2b3f..8c0ebe9bf 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -677,7 +677,6 @@ def validate(self, mini_epoch): total=len(self.data_loader_validation), disable=self.cf.with_ddp ) as pbar: for bidx, batch in enumerate(dataset_val_iter): - forecast_steps = batch.get_forecast_dt() batch.to_device(self.device) # evaluate model @@ -696,7 +695,7 @@ def validate(self, mini_epoch): self.model_params, sample, cf.forecast_offset, - forecast_steps, + sample.get_forecast_dt(), ) sample = batch.target_samples[0] target_aux_output = self.target_and_aux_calculator.compute( @@ -709,7 +708,7 @@ def validate(self, mini_epoch): self.model_params, self.model, cf.forecast_offset, - forecast_steps, + sample.get_forecast_dt(), ) loss, loss_values = self.loss_calculator_val.compute_loss( preds=output, From 8f097ec14c0980edfffdcc284022e647aefbfff9 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Tue, 9 Dec 2025 13:54:43 +0100 Subject: [PATCH 170/344] Improved robustness through sanity checking of arguments --- src/weathergen/datasets/masking.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index 3f8226092..e68eb369d 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -403,7 +403,7 @@ def _get_mask( strategy: str | None = None, masking_strategy_config: dict | None = None, target_mask: np.typing.NDArray | None = None, - relationship: str = "subset", + relationship: str | None = None, ) -> (np.typing.NDArray, dict): """Get effective mask, combining with target mask if specified. @@ -428,6 +428,12 @@ def _get_mask( Parameters describing the masking that was applied """ + if strategy == "forecast": + if relationship is not None: + assert relationship == "independent", ( + "strategy forecast requires relationship independent " + ) + # handle cases where mask is directly derived from target_mask if target_mask is not None: if relationship == "complement": From 6b6451132434516753ad6c358d3956eeaa457fa4 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Tue, 9 Dec 2025 13:55:10 +0100 Subject: [PATCH 171/344] Improved handling of corner cases --- src/weathergen/datasets/stream_data.py | 4 ++-- src/weathergen/datasets/tokenizer_utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index 8c5827817..f13ddc595 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -68,8 +68,8 @@ def __init__(self, idx: int, input_steps: int, forecast_steps: int, healpix_cell self.source_raw = [None for _ in range(self.input_steps)] # auxiliary data for scatter operation that changes from stream-centric to cell-centric # processing after embedding - self.source_idxs_embed = [None for _ in range(self.input_steps)] - self.source_idxs_embed_pe = [None for _ in range(self.input_steps)] + self.source_idxs_embed = [torch.tensor([]) for _ in range(self.input_steps)] + self.source_idxs_embed_pe = [torch.tensor([]) for _ in range(self.input_steps)] def to_device(self, device: str) -> None: """ diff --git a/src/weathergen/datasets/tokenizer_utils.py b/src/weathergen/datasets/tokenizer_utils.py index 7c5d056ac..fc61b1b02 100644 --- a/src/weathergen/datasets/tokenizer_utils.py +++ b/src/weathergen/datasets/tokenizer_utils.py @@ -251,7 +251,7 @@ def tokenize_apply_mask_source( idxs_data = [t for t, m in zip(idxs_tokens, mask_tokens, strict=True) if m] if len(idxs_data) == 0: - tokens_cells = [] + tokens_cells = [torch.tensor([])] tokens_per_cell = torch.zeros(len(idxs_cells_lens), dtype=torch.int32) return tokens_cells, tokens_per_cell From 303f48a1bb0fcd52bcfe85565213482031d34e99 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Tue, 9 Dec 2025 14:42:32 +0100 Subject: [PATCH 172/344] - Fixed incorrect call to get_forecast_steps() in validation - Fixed interface of target_aux_calculator --- src/weathergen/train/target_and_aux_module_base.py | 6 +++--- src/weathergen/train/trainer.py | 11 +++-------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/weathergen/train/target_and_aux_module_base.py b/src/weathergen/train/target_and_aux_module_base.py index 022898e55..45ede53a8 100644 --- a/src/weathergen/train/target_and_aux_module_base.py +++ b/src/weathergen/train/target_and_aux_module_base.py @@ -27,7 +27,7 @@ def update_state_pre_backward(self, istep, batch, model, **kwargs) -> None: def update_state_post_opt_step(self, istep, batch, model, **kwargs) -> None: pass - def compute(self, *args, **kwargs) -> TargetAuxOutput: + def compute(self, sample, *args, **kwargs) -> TargetAuxOutput: pass def to_device(self, device): @@ -47,8 +47,8 @@ def update_state_pre_backward(self, istep, batch, model, **kwargs): def update_state_post_opt_step(self, istep, batch, model, **kwargs): return - def compute(self, istep, batch, *args, **kwargs) -> TargetAuxOutput: - return TargetAuxOutput(physical=batch[0], latent=None, aux_outputs=None) + def compute(self, sample, *args, **kwargs) -> TargetAuxOutput: + return TargetAuxOutput(physical=sample.streams_data, latent=None, aux_outputs=None) def to_device(self, device): return diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 8c0ebe9bf..89de6bbbd 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -537,12 +537,7 @@ def train(self, mini_epoch): for sample in batch.target_samples: targets_and_auxs.append( self.target_and_aux_calculator.compute( - self.cf.istep, - ( - sample.streams_data, - sample.source_cell_lens, - sample.target_coords_idx, - ), + sample, self.model_params, self.model, cf.forecast_offset, @@ -695,7 +690,7 @@ def validate(self, mini_epoch): self.model_params, sample, cf.forecast_offset, - sample.get_forecast_dt(), + sample.get_forecast_steps(), ) sample = batch.target_samples[0] target_aux_output = self.target_and_aux_calculator.compute( @@ -708,7 +703,7 @@ def validate(self, mini_epoch): self.model_params, self.model, cf.forecast_offset, - sample.get_forecast_dt(), + sample.get_forecast_steps(), ) loss, loss_values = self.loss_calculator_val.compute_loss( preds=output, From 729910659f6f8939075cb6b8aeb26dc6c74cc872 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Tue, 9 Dec 2025 14:59:30 +0100 Subject: [PATCH 173/344] More fixed to validation --- src/weathergen/train/trainer.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 89de6bbbd..b767ecdf8 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -694,12 +694,7 @@ def validate(self, mini_epoch): ) sample = batch.target_samples[0] target_aux_output = self.target_and_aux_calculator.compute( - bidx, - ( - sample.streams_data, - sample.source_cell_lens, - sample.target_coords_idx, - ), + sample, self.model_params, self.model, cf.forecast_offset, From 45189a476f4004fddc993abc9e3b1b33d40b5409 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Tue, 9 Dec 2025 21:25:45 +0100 Subject: [PATCH 174/344] Adding stream_id --- integration_tests/streams_multi/era5_small.yml | 3 ++- integration_tests/streams_multi/npp_atms.yml | 1 + integration_tests/streams_multi/synop.yml | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/integration_tests/streams_multi/era5_small.yml b/integration_tests/streams_multi/era5_small.yml index b0463596a..04d47ca99 100644 --- a/integration_tests/streams_multi/era5_small.yml +++ b/integration_tests/streams_multi/era5_small.yml @@ -1,5 +1,6 @@ ERA5: type: anemoi + stream_id: 0 filenames: ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] loss_weight: 1.0 source_exclude: ['w_', 'skt', 'sp', 'tcw', 'cp', 'tp'] @@ -27,4 +28,4 @@ ERA5: num_heads: 2 pred_head: ens_size: 1 - num_layers: 1 \ No newline at end of file + num_layers: 1 diff --git a/integration_tests/streams_multi/npp_atms.yml b/integration_tests/streams_multi/npp_atms.yml index f7d852d0f..6affb1da1 100644 --- a/integration_tests/streams_multi/npp_atms.yml +++ b/integration_tests/streams_multi/npp_atms.yml @@ -1,5 +1,6 @@ NPPATMS : type : obs + stream_id: 1 filenames : ['observations-ea-ofb-0001-2012-2023-npp-atms-radiances-v2.zarr'] loss_weight : 1.0 # masking_rate : 0.6 diff --git a/integration_tests/streams_multi/synop.yml b/integration_tests/streams_multi/synop.yml index 1fe0b8d56..461bde9ab 100644 --- a/integration_tests/streams_multi/synop.yml +++ b/integration_tests/streams_multi/synop.yml @@ -1,5 +1,6 @@ SurfaceCombined : type : obs + stream_id: 2 filenames : ['observations-ea-ofb-0001-1979-2023-combined-surface-v2.zarr'] loss_weight : 1.0 masking_rate : 0.6 From 5bed79223c142e2645d05a62158b106cf87685d3 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 10 Dec 2025 19:32:29 +0100 Subject: [PATCH 175/344] Cleaned up ModelOutput class to have proper access functions and a better structure --- src/weathergen/model/model.py | 120 +++++++----------- .../loss_modules/loss_module_physical.py | 4 +- 2 files changed, 47 insertions(+), 77 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index d76d19d26..2c5a4ab25 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -9,7 +9,6 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import dataclasses import logging import math import warnings @@ -22,6 +21,7 @@ from torch.utils.checkpoint import checkpoint from weathergen.common.config import Config +from weathergen.datasets.batch import Sample from weathergen.model.encoder import EncoderModule from weathergen.model.engines import ( EnsPredictionHead, @@ -36,15 +36,42 @@ logger = logging.getLogger(__name__) +type StreamName = str + -@dataclasses.dataclass class ModelOutput: """ - A dataclass to encapsulate the model output and give a clear API. + Representation of model output """ - physical: dict[str, torch.Tensor] - latent: dict[str, torch.Tensor] + physical: list[dict[StreamName, torch.Tensor]] + latent: list[dict[StreamName, torch.Tensor]] + + def __init__(self, forecast_steps: int) -> None: + self.physical = [{} for _ in range(forecast_steps)] + self.latent = [{} for _ in range(forecast_steps)] + + def add_physical_prediction( + self, fstep: int, stream_name: StreamName, pred: torch.Tensor + ) -> None: + self.physical[fstep][stream_name] = pred + + def add_latent_prediction( + self, fstep: int, stream_name: StreamName, pred: torch.Tensor + ) -> None: + self.latent[fstep][stream_name] = pred + + def get_physical_prediction(self, fstep: int, stream_name: StreamName | None = None): + pred = self.physical[fstep] + if stream_name is not None: + pred = pred[stream_name] + return pred + + def get_latent_prediction(self, fstep: int, stream_name: StreamName | None = None): + pred = self.latent[fstep] + if stream_name is not None: + pred = pred[stream_name] + return pred class ModelParams(torch.nn.Module): @@ -490,50 +517,6 @@ def print_num_parameters(self) -> None: ] print("-----------------") - ######################################### - def rename_old_state_dict(self, params: dict) -> dict: - """Checks if model from checkpoint is from the old model version and if so renames - the parameters accordingly to the new model version. - - Args: - params : Dictionary with (old) model parameters from checkpoint - Returns: - new_params : Dictionary with (renamed) model parameters - """ - params_cleanup = { - # EmbeddingEngine - "embeds": "encoder.embed_engine.embeds", - # LocalAssimilationEngine - "ae_local_blocks": "encoder.ae_local_engine.ae_local_blocks", - # Local2GlobalAssimilationEngine - "ae_adapter": "encoder.ae_local_global_engine.ae_adapter", - # GlobalAssimilationEngine - "ae_global_blocks": "encoder.ae_global_engine.ae_global_blocks", - # ForecastingEngine - "fe_blocks": "forecast_engine.fe_blocks", - } - - new_params = {} - - for k, v in params.items(): - new_k = k - prefix = "" - - # Strip "module." (prefix for DataParallel or DistributedDataParallel) - if new_k.startswith("module."): - prefix = "module." - new_k = new_k[len(prefix) :] - - first_w, rest = new_k.split(".", 1) if "." in new_k else (new_k, "") - # Only check first word (root level modules) to avoid false matches. - if first_w in params_cleanup: - new_k = params_cleanup[first_w] + "." + rest - - new_k = prefix + new_k - new_params[new_k] = v - - return new_params - ######################################### def forward(self, model_params: ModelParams, sample, forecast_offset: int, forecast_steps: int): """Forward pass of the model @@ -553,23 +536,17 @@ def forward(self, model_params: ModelParams, sample, forecast_offset: int, forec A list containing all prediction results """ + output = ModelOutput(forecast_steps + 1) + tokens, posteriors = self.encoder(model_params, sample) # collapse along input step dimension tokens = torch.stack(tokens, 0).sum(0) # roll-out in latent space - preds_all = [] for fstep in range(forecast_offset, forecast_steps): # prediction - preds_all += [ - self.predict( - model_params, - fstep, - tokens, - sample, - ) - ] + output = self.predict(model_params, fstep, tokens, sample, output) if self.training: # Impute noise to the latent state @@ -580,19 +557,12 @@ def forward(self, model_params: ModelParams, sample, forecast_offset: int, forec tokens = self.forecast(model_params, tokens, fstep) # prediction for final step - preds_all += [ - self.predict( - model_params, - forecast_steps, - tokens, - sample, - ) - ] + output = self.predict(model_params, forecast_steps, tokens, sample, output) - latents = {} - latents["posteriors"] = posteriors + # TODO: set properly + output.latents = posteriors - return ModelOutput(physical=preds_all, latent=latents) + return output ######################################### def forecast(self, model_params: ModelParams, tokens: torch.Tensor, fstep: int) -> torch.Tensor: @@ -618,7 +588,8 @@ def predict( model_params: ModelParams, fstep: int, tokens: torch.Tensor, - sample, + sample: Sample, + output: ModelOutput, ) -> list[torch.Tensor]: """Predict outputs at the specific target coordinates based on the input weather state and pre-training task and projects the latent space representation back to physical space. @@ -647,7 +618,7 @@ def predict( tokens_stream = tokens_stream[model_params.hp_nbours.flatten()].flatten(0, 1) # pair with tokens from assimilation engine to obtain target tokens - preds_tokens = [] + preds_tokens = {} for stream_name in self.stream_names: tte = self.target_token_engines[stream_name] tte_kv = self.pred_adapter_kv[stream_name] @@ -712,8 +683,7 @@ def predict( ) # final prediction head to map back to physical space - preds_tokens += [ - checkpoint(self.pred_heads[stream_name], tc_tokens, use_reentrant=False) - ] + pred = checkpoint(self.pred_heads[stream_name], tc_tokens, use_reentrant=False) + output.add_physical_prediction(fstep, stream_name, pred) - return preds_tokens + return output diff --git a/src/weathergen/train/loss_modules/loss_module_physical.py b/src/weathergen/train/loss_modules/loss_module_physical.py index 0f523b409..a62ae8735 100644 --- a/src/weathergen/train/loss_modules/loss_module_physical.py +++ b/src/weathergen/train/loss_modules/loss_module_physical.py @@ -213,7 +213,7 @@ def compute_loss( # TODO: iterate over batch dimension i_batch = 0 streams_data = [streams_data] - for i_stream_info, stream_info in enumerate(self.cf.streams): + for stream_info in self.cf.streams: stream_name = stream_info["name"] # extract target tokens for current stream from the specified forecast offset onwards targets = streams_data[i_batch][stream_name].target_tokens[self.cf.forecast_offset :] @@ -235,7 +235,7 @@ def compute_loss( zip(targets, fstep_loss_weights, strict=False) ): # skip if either target or prediction has no data points - pred = preds[fstep][i_stream_info] + pred = preds[fstep + self.cf.forecast_offset].get(stream_name, torch.tensor([])) if not (target.shape[0] > 0 and pred.shape[0] > 0): continue From 06f2e063c451cbac5fd8e9a02103bb81144709f3 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 10 Dec 2025 19:33:06 +0100 Subject: [PATCH 176/344] Switched to use dict to internally represent streams_datasets --- .../datasets/multi_stream_data_sampler.py | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 2de60e740..5bc57c448 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -37,6 +37,7 @@ from weathergen.utils.train_logger import Stage type AnyDataReader = DataReaderBase | DataReaderAnemoi | DataReaderObs +type StreamName = str logger = logging.getLogger(__name__) @@ -112,9 +113,10 @@ def __init__( self.len = 100000000 - self.streams_datasets: list[list[AnyDataReader]] = [] + self.streams_datasets: dict[StreamName, list[AnyDataReader]] = {} for _, stream_info in enumerate(cf.streams): - self.streams_datasets.append([]) + # list of sources for current stream + self.streams_datasets[stream_info["name"]] = [] for fname in stream_info["filenames"]: kwargs = { @@ -171,7 +173,6 @@ def __init__( if len(ds) > 0: self.len = min(self.len, len(ds) - (self.len_hrs * (fsm + 1)) // self.step_hrs) - # MODIFIES config !!! stream_info[str(self._stage) + "_source_channels"] = ds.source_channels stream_info[str(self._stage) + "_target_channels"] = ds.target_channels stream_info["target_channel_weights"] = ( @@ -180,7 +181,7 @@ def __init__( else [1.0 for _ in ds.target_channels] ) - self.streams_datasets[-1] += [ds] + self.streams_datasets[stream_info["name"]] += [ds] index_range = self.time_window_handler.get_index_range() self.len = int(index_range.end - index_range.start) @@ -239,20 +240,21 @@ def get_sources_size(self): + ds[0].get_geoinfo_size() + ds[0].get_coords_size() + self.tokenizer.get_size_time_embedding() - for ds in self.streams_datasets + for _, ds in self.streams_datasets.items() ] def get_sources_num_channels(self): - return [ds[0].get_source_num_channels() for ds in self.streams_datasets] + return [ds[0].get_source_num_channels() for _, ds in self.streams_datasets.items()] def get_targets_num_channels(self): - return [ds[0].get_target_num_channels() for ds in self.streams_datasets] + return [ds[0].get_target_num_channels() for _, ds in self.streams_datasets.items()] def get_targets_coords_size(self): # TODO: avoid hard coding magic values # +6 at the end for stram_id and time encoding return [ - (ds[0].get_geoinfo_size() + (5 * (3 * 5)) + 3 * 8) + 6 for ds in self.streams_datasets + (ds[0].get_geoinfo_size() + (5 * (3 * 5)) + 3 * 8) + 6 + for _, ds in self.streams_datasets.items() ] def reset(self): @@ -296,13 +298,13 @@ def reset(self): self.tokenizer.reset_rng(self.rng) - def denormalize_source_channels(self, stream_id, data) -> torch.Tensor: + def denormalize_source_channels(self, stream_name, data) -> torch.Tensor: # TODO: with multiple ds per stream we need to distinguish these here - return self.streams_datasets[stream_id][0].denormalize_source_channels(data) + return self.streams_datasets[stream_name][0].denormalize_source_channels(data) - def denormalize_target_channels(self, stream_id, data) -> torch.Tensor: + def denormalize_target_channels(self, stream_name, data) -> torch.Tensor: # TODO: with multiple ds per stream we need to distinguish these here - return self.streams_datasets[stream_id][0].denormalize_target_channels(data) + return self.streams_datasets[stream_name][0].denormalize_target_channels(data) def _build_stream_data_input( self, @@ -542,7 +544,7 @@ def _get_batch(self, idx: int, forecast_dt: int): if mode == "masking": source_select = ["network_input", "target_coords"] target_select = ["target_values"] - elif mode == "student_teacher": + elif mode == "student_teacher" or mode == "latent_loss": source_select = ["network_input"] target_select = ["network_input"] else: @@ -551,9 +553,9 @@ def _get_batch(self, idx: int, forecast_dt: int): batch = ModelBatch(self.streams, num_source_samples, num_target_samples) # for all streams - for stream_info, stream_ds in zip(self.streams, self.streams_datasets, strict=True): - stream_name = stream_info["name"] - + for stream_info, (stream_name, stream_ds) in zip( + self.streams, self.streams_datasets.items(), strict=True + ): # TODO: data class for this or something similar ( target_masks, From ad5a19cdfb7091defc4bb95f7def91f6f9d31e8d Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 10 Dec 2025 23:19:50 +0100 Subject: [PATCH 177/344] Improving robustness of interface of ModelOutput class --- src/weathergen/model/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 2c5a4ab25..261656fcd 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -64,13 +64,13 @@ def add_latent_prediction( def get_physical_prediction(self, fstep: int, stream_name: StreamName | None = None): pred = self.physical[fstep] if stream_name is not None: - pred = pred[stream_name] + pred = pred.get(stream_name, None) return pred def get_latent_prediction(self, fstep: int, stream_name: StreamName | None = None): pred = self.latent[fstep] if stream_name is not None: - pred = pred[stream_name] + pred = pred.get(stream_name, None) return pred From 4f8abbb3c7b2d1ca6ffb3adc4f8a2f57861f2156 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 10 Dec 2025 23:20:25 +0100 Subject: [PATCH 178/344] Re-enabling model output --- src/weathergen/train/trainer.py | 196 +------------------------- src/weathergen/utils/validation_io.py | 75 ++++++++-- 2 files changed, 65 insertions(+), 206 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index b767ecdf8..53adf5410 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -10,13 +10,10 @@ # nor does it submit to any jurisdiction. import logging import time -from typing import Any import numpy as np -import omegaconf import torch import tqdm -from numpy.typing import NDArray from omegaconf import OmegaConf from torch import Tensor @@ -37,6 +34,7 @@ from weathergen.utils.distributed import all_gather_vlen, ddp_average, is_root from weathergen.utils.train_logger import TRAIN, VAL, Stage, TrainLogger from weathergen.utils.utils import get_batch_size, get_dtype +from weathergen.utils.validation_io import write_output logger = logging.getLogger(__name__) @@ -350,159 +348,6 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): # log final model self.save_model(cf.num_mini_epochs) - ########################################### - def _prepare_logging( - self, - preds: list[list[Tensor]], - forecast_offset: int, - forecast_steps: int, - streams_data: list[list[Any]], - ): - """Collects and denormalizes prediction and target data for logging. - - This function processes target and prediction tensors, extracts relevant - coordinates and timestamps, denormalizes the data, and organizes it - into a structured format suitable for logging or further analysis. It - handles potential empty tensors and NaN values. - - Args: - preds: A list of lists, where the outer list - corresponds to forecast steps, and the inner list contains prediction - tensors for each observation stream. Each prediction tensor is - expected to be in the normalized latent or observation space, - depending on the model's output. - targets: A list of lists, where the outer list - corresponds to forecast steps, and the inner list contains target - tensors for each observation stream. Each target tensor is expected - to be in the normalized observation space. - forecast_offset: The starting offset for the forecast steps - relative to the original data. - forecast_steps: The number of forecast steps to consider. - streams_data: A list of lists, where each inner list - contains data objects (e.g., `BatchItem` instances) for each stream - at a specific time step. These objects are expected to have - `target_coords_raw` and `target_times_raw` attributes. - - Returns: - tuple: A tuple containing: - - preds_all: Denormalized - predictions, organized by forecast step and observation stream. - - targets_all: Denormalized - targets, organized by forecast step and observation stream. - - targets_coords_raw: Raw target coordinates, - extracted and concatenated for each forecast step and stream. - - targets_times_raw: Raw target timestamps, - extracted and concatenated for each forecast step and stream. - - targets_lens: A list of lists, where each - inner list contains the original lengths (shape[0]) of the target - tensors before any filtering. - """ - - # handle case when forecast_steps is a list - if type(forecast_steps) is omegaconf.listconfig.ListConfig: - forecast_range = np.array(forecast_steps) - else: - forecast_range = np.arange(forecast_offset, forecast_offset + forecast_steps + 1) - - #''' - # TODO: Remove this function and port functionality to write_validation(), which then - # extracts preds_all, targets_all,... itself directly from stream_data. - # TODO: Undo list resorting - # The following list operations realize a reshaping of the original tensors in streams_data - # from shape [batch_sample][stream][fstep] into shape [fstep][stream][batch_sample]. When - # removing the reshaping, make sure to index the tensors starting at forecast_offset, e.g., - # target_times_raw = streams_data[i_batch][i_strm].target_times_raw[forecast_offset+fstep], - # when iterating over batch, stream, and fsteps. - targets_rt = [ - [ - torch.cat([t[i].target_tokens[fstep] for t in streams_data]) - for i in range(len(self.cf.streams)) - ] - for fstep in forecast_range - ] - # TODO: Undo list resorting - targets_coords_raw = [ - [ - torch.cat([t[i].target_coords_raw[fstep] for t in streams_data]) - for i in range(len(self.cf.streams)) - ] - for fstep in forecast_range - ] - # TODO: Undo list resorting - targets_times_raw = [ - [ - np.concatenate([t[i].target_times_raw[fstep] for t in streams_data]) - for i in range(len(self.cf.streams)) - ] - for fstep in forecast_range - ] - # inverse indices - idxs_inv_rt = [ - [ - torch.cat([t[i].idxs_inv[fstep] for t in streams_data]) - for i in range(len(self.cf.streams)) - ] - for fstep in range(forecast_offset, forecast_offset + forecast_steps + 1) - ] - - # assert len(targets_rt) == len(preds) and len(preds) == len(self.cf.streams) - fsteps = len(targets_rt) - preds_all: list[list[list[NDArray]]] = [ - [[] for _ in self.cf.streams] for _ in range(fsteps) - ] - targets_all: list[list[list[NDArray]]] = [ - [[] for _ in self.cf.streams] for _ in range(fsteps) - ] - targets_lens: list[list[list[int]]] = [[[] for _ in self.cf.streams] for _ in range(fsteps)] - - # TODO: iterate over batches here in future, and change loop order to batch, stream, fstep - for fstep in range(len(targets_rt)): - if len(preds.physical[fstep]) == 0: - continue - - for i_strm, target in enumerate(targets_rt[fstep]): - pred = preds.physical[fstep][i_strm] - idxs_inv = idxs_inv_rt[fstep][i_strm] - - if not (target.shape[0] > 0 and pred.shape[0] > 0): - continue - - # extract data/coords and remove token dimension if it exists - pred = pred.reshape([pred.shape[0], *target.shape]) - assert pred.shape[1] > 0 - - mask_nan = ~torch.isnan(target) - if pred[:, mask_nan].shape[1] == 0: - continue - - targets_lens[fstep][i_strm] += [target.shape[0]] - dn_data = self.dataset_val.denormalize_target_channels - - # reorder so that output order of target points matches input when reading - # (tokenization and masking changes this order) - # TODO: does this work with batch_size > 1 - if len(idxs_inv) > 0: - pred = pred[:, idxs_inv] - target = target[idxs_inv] - targets_coords_raw[fstep][i_strm] = targets_coords_raw[fstep][i_strm][idxs_inv] - targets_times_raw[fstep][i_strm] = targets_times_raw[fstep][i_strm][idxs_inv] - - f32 = torch.float32 - preds_all[fstep][i_strm] += [ - np.asarray(dn_data(i_strm, pred.to(f32)).detach().cpu()) - ] - targets_all[fstep][i_strm] += [ - np.asarray(dn_data(i_strm, target.to(f32)).detach().cpu()) - ] - - return ( - preds_all, - targets_all, - targets_coords_raw, - targets_times_raw, - targets_lens, - ) - def train(self, mini_epoch): cf = self.cf self.model.train() @@ -707,41 +552,10 @@ def validate(self, mini_epoch): # log output if bidx < cf.log_validation: - logger.warning("logging of data currently not implemented") - # # TODO: Move _prepare_logging into write_validation by passing streams_data - # # TODO right now we hardcode ERA5 which obviously is bad, but not sure - # # how this logging function is supposed to change - # streams_data: list[list[StreamData]] = old_batch[0] - # import pdb - - # pdb.set_trace() - # ( - # preds_all, - # targets_all, - # targets_coords_all, - # targets_times_all, - # targets_lens, - # ) = self._prepare_logging( - # preds=output, - # forecast_offset=cf.forecast_offset, - # forecast_steps=cf.forecast_steps, - # streams_data=streams_data, - # ) - # sources = [[item.source_raw for item in stream] for stream in streams_data] - # # sample idx should be the same across streams => select first - # sample_idxs = [item.sample_idx for item in streams_data[0]] - # write_output( - # self.cf, - # mini_epoch, - # bidx, - # sources[0], - # preds_all, - # targets_all, - # targets_coords_all, - # targets_times_all, - # targets_lens, - # sample_idxs, - # ) + dn_data = self.dataset_val.denormalize_target_channels + write_output( + self.cf, mini_epoch, bidx, dn_data, batch, output, target_aux_output + ) # Collecting loss statistics for later inspection if bidx == 0: diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index f8a5a1cc5..27f8888a8 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -9,6 +9,8 @@ import logging +import torch + import weathergen.common.config as config import weathergen.common.io as io from weathergen.common.io import TimeRange @@ -17,18 +19,63 @@ _logger = logging.getLogger(__name__) -def write_output( - cf, - mini_epoch, - batch_idx, - sources, - preds_all, - targets_all, - targets_coords_all, - targets_times_all, - targets_lens, - sample_idxs, -): +def write_output(cf, mini_epoch, batch_idx, dn_data, batch, model_output, target_aux_output): + """ + Interface for writing model output + """ + + fp32 = torch.float32 + preds_all, targets_all, targets_coords_all, targets_times_all = [], [], [], [] + for fstep in range(cf.forecast_offset, cf.forecast_steps + 2): + preds_all += [[]] + targets_all += [[]] + targets_coords_all += [[]] + targets_times_all += [[]] + for stream_info in cf.streams: + # predictions + pred = model_output.get_physical_prediction(fstep, stream_info["name"]).to(fp32) + target = target_aux_output.physical[stream_info["name"]].target_tokens[fstep].to(fp32) + + if not (target.shape[0] > 0 and pred.shape[0] > 0): + continue + + # extract data/coords and remove token dimension if it exists + pred = pred.reshape([pred.shape[0], *target.shape]) + assert pred.shape[1] > 0 + + # TODO: the inner lists here should not be needed + preds_all[-1] += [[dn_data(stream_info["name"], pred).detach().cpu().numpy()]] + targets_all[-1] += [[dn_data(stream_info["name"], target).detach().cpu().numpy()]] + + sname = stream_info["name"] + targets_coords_all[-1] += [target_aux_output.physical[sname].target_coords_raw[fstep]] + targets_times_all[-1] += [target_aux_output.physical[sname].target_times_raw[fstep]] + + # # TODO: re-enable + # if len(idxs_inv) > 0: + # pred = pred[:, idxs_inv] + # target = target[idxs_inv] + # targets_coords_raw[fstep][i_strm] = targets_coords_raw[fstep][i_strm][idxs_inv] + # targets_times_raw[fstep][i_strm] = targets_times_raw[fstep][i_strm][idxs_inv] + + # TODO: remove + targets_lens = [[[t[0].shape[0]] for t in tt] for tt in targets_all] + + sources = [] + for sample in batch.source_samples: + sources += [[]] + for _, stream_data in sample.streams_data.items(): + # TODO: support multiple input steps + sources[-1] += [stream_data.source_raw[0]] + + sample_idxs = [ + [sdata.sample_idx for _, sdata in sample.streams_data.items()] + for sample in batch.source_samples + ] + sample_idxs = [s[0].item() for s in sample_idxs] + + # more prep work + stream_names = [stream.name for stream in cf.streams] analysis_streams_output = cf.get("analysis_streams_output", None) if cf.streams_output is not None: @@ -54,9 +101,7 @@ def write_output( # => calculate global sample indices for this batch by offsetting by sample_start sample_start = batch_idx * cf.batch_size_validation_per_gpu - assert len(stream_names) == len(targets_all[0]), "data does not match number of streams" - assert len(stream_names) == len(preds_all[0]), "data does not match number of streams" - assert len(stream_names) == len(sources[0]), "data does not match number of streams" + # write output start_date = str_to_datetime64(cf.start_date_val) end_date = str_to_datetime64(cf.end_date_val) From d36716c351eb2b740a67b0a9e477ac1759059c92 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 11 Dec 2025 08:06:24 +0100 Subject: [PATCH 179/344] Ruff --- src/weathergen/model/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 261656fcd..ce6dd8206 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -560,7 +560,7 @@ def forward(self, model_params: ModelParams, sample, forecast_offset: int, forec output = self.predict(model_params, forecast_steps, tokens, sample, output) # TODO: set properly - output.latents = posteriors + output.latent = posteriors return output From b8d95b2181fb6016b95abafce750608f890935e1 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 11 Dec 2025 08:06:34 +0100 Subject: [PATCH 180/344] Minor clean-ups and additional comments --- src/weathergen/utils/validation_io.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 27f8888a8..c9b16df36 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -24,6 +24,7 @@ def write_output(cf, mini_epoch, batch_idx, dn_data, batch, model_output, target Interface for writing model output """ + # collect all target / prediction-related information fp32 = torch.float32 preds_all, targets_all, targets_coords_all, targets_times_all = [], [], [], [] for fstep in range(cf.forecast_offset, cf.forecast_steps + 2): @@ -61,6 +62,7 @@ def write_output(cf, mini_epoch, batch_idx, dn_data, batch, model_output, target # TODO: remove targets_lens = [[[t[0].shape[0]] for t in tt] for tt in targets_all] + # collect source information sources = [] for sample in batch.source_samples: sources += [[]] @@ -77,11 +79,8 @@ def write_output(cf, mini_epoch, batch_idx, dn_data, batch, model_output, target # more prep work stream_names = [stream.name for stream in cf.streams] - analysis_streams_output = cf.get("analysis_streams_output", None) if cf.streams_output is not None: output_stream_names = cf.streams_output - elif analysis_streams_output is not None: # --- to be removed at some point --- - output_stream_names = analysis_streams_output # --- to be removed at some point --- else: output_stream_names = None @@ -97,8 +96,7 @@ def write_output(cf, mini_epoch, batch_idx, dn_data, batch, model_output, target geoinfo_channels = [[] for _ in cf.streams] # TODO obtain channels - # assume: is batch size guarnteed and constant: - # => calculate global sample indices for this batch by offsetting by sample_start + # calculate global sample indices for this batch by offsetting by sample_start sample_start = batch_idx * cf.batch_size_validation_per_gpu # write output From 081d90ab8c682bf2da8c63db8e81ae14b6197ccd Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 11 Dec 2025 09:25:53 +0100 Subject: [PATCH 181/344] Minor cleanups --- src/weathergen/train/trainer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 53adf5410..1d240cf42 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -351,7 +351,6 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): def train(self, mini_epoch): cf = self.cf self.model.train() - # torch.autograd.set_detect_anomaly(True) dataset_iter = iter(self.data_loader) @@ -433,7 +432,6 @@ def train(self, mini_epoch): # optimizer step self.grad_scaler.step(self.optimizer) self.grad_scaler.update() - # self.optimizer.step() self.target_and_aux_calculator.update_state_post_opt_step(bidx, batch, self.model) From 6b8fe83f8f93d2eee7e5119a13a00e74b7e1a540 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 11 Dec 2025 09:26:06 +0100 Subject: [PATCH 182/344] Cleaned up handling of masks and masking metadata --- src/weathergen/datasets/masking.py | 36 +++++++++++-------- .../datasets/multi_stream_data_sampler.py | 30 ++++------------ 2 files changed, 29 insertions(+), 37 deletions(-) diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index e68eb369d..24f118341 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -9,6 +9,22 @@ _logger = logging.getLogger(__name__) +class MaskData: + masks: list[np.typing.NDArray] = [] + metadata: list[SampleMetaData] = [] + + def __init__(self): + self.masks = [] + self.metadata = [] + + def __len__(self): + return len(self.masks) + + def add_mask(self, mask, params, cfg): + self.masks += [mask] + self.metadata += [SampleMetaData(params={**cfg, **params})] + + # Convert to torch.bool def to_bool_tensor(arr): return torch.from_numpy(np.asarray(arr)).to(torch.bool) @@ -354,8 +370,7 @@ def build_samples_for_stream( target_cfgs = source_cfgs # iterate over all target samples - target_masks: list[np.typing.NDArray] = [] - target_metadata: list[SampleMetaData] = [] + target_masks = MaskData() # different strategies for target_cfg in target_cfgs: # different samples/view per strategy @@ -366,12 +381,10 @@ def build_samples_for_stream( target_mask=None, masking_strategy_config=target_cfg.get("masking_strategy_config", {}), ) - target_masks += [target_mask] - target_metadata += [SampleMetaData(params={**target_cfg, **mask_params})] + target_masks.add_mask(target_mask, mask_params, target_cfg) # iterate over all source samples - source_masks: list[np.typing.NDArray] = [] - source_metadata: list[SampleMetaData] = [] + source_masks = MaskData() source_target_mapping = [] # different strategies for i_source, source_cfg in enumerate(source_cfgs): @@ -381,21 +394,16 @@ def build_samples_for_stream( num_cells=num_cells, strategy=source_cfg.get("masking_strategy"), masking_strategy_config=source_cfg.get("masking_strategy_config", {}), - target_mask=target_masks[i_source % len(target_masks)], + target_mask=target_masks.masks[i_source % len(target_masks)], relationship=source_cfg.get("relationship", "independent"), ) - source_masks += [source_mask] - source_metadata += [SampleMetaData(params={**target_cfg, **mask_params})] + source_masks.add_mask(source_mask, mask_params, source_cfg) # TODO: proper correspondence between source and target source_target_mapping += [i_source % len(target_masks)] source_target_mapping = np.array(source_target_mapping, dtype=np.int32) - return ( - (target_masks, target_metadata), - (source_masks, source_metadata), - source_target_mapping, - ) + return (target_masks, source_masks, source_target_mapping) def _get_mask( self, diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 5bc57c448..b8006b538 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -556,14 +556,7 @@ def _get_batch(self, idx: int, forecast_dt: int): for stream_info, (stream_name, stream_ds) in zip( self.streams, self.streams_datasets.items(), strict=True ): - # TODO: data class for this or something similar - ( - target_masks, - source_masks, - student_to_teacher, - target_metadata_list, - source_metadata_list, - ) = masks_streams[stream_name] + (target_masks, source_masks, student_to_teacher) = masks_streams[stream_name] # input_data and output_data is conceptually consecutive but differs # in source and target channels; overlap in one window when self.forecast_offset=0 @@ -579,7 +572,7 @@ def _get_batch(self, idx: int, forecast_dt: int): # collect source data for current stream # loop over student views - for sidx, source_mask in enumerate(source_masks): + for sidx, source_mask in enumerate(source_masks.masks): sdata = self._build_stream_data( source_select, idx, @@ -590,19 +583,19 @@ def _get_batch(self, idx: int, forecast_dt: int): output_data, input_tokens, output_tokens, - target_masks[student_to_teacher[sidx]], + target_masks.masks[student_to_teacher[sidx]], source_mask, ) # also want to add the mask to the metadata - source_metadata = source_metadata_list[sidx] + source_metadata = source_masks.metadata[sidx] source_metadata.mask = source_mask # map each source to its target t_idx = student_to_teacher[sidx] batch.add_source_stream(sidx, t_idx, stream_name, sdata, source_metadata) - for sidx, target_mask in enumerate(target_masks): + for sidx, target_mask in enumerate(target_masks.masks): sdata = self._build_stream_data( target_select, idx, @@ -618,7 +611,7 @@ def _get_batch(self, idx: int, forecast_dt: int): ) # get target config info - target_metadata = target_metadata_list[sidx] + target_metadata = target_masks.metadata[sidx] target_metadata.mask = target_mask # find indices of all sources for current target @@ -637,19 +630,10 @@ def _get_source_target_masks(self, training_mode): masks = {} for stream_info in self.streams: # Build source and target sample masks - target_data, source_data, mapping = self.tokenizer.masker.build_samples_for_stream( + masks[stream_info["name"]] = self.tokenizer.masker.build_samples_for_stream( training_mode, self.num_healpix_cells, self.training_cfg ) - # TODO: avoid the unpacking here - masks[stream_info["name"]] = ( - target_data[0], - source_data[0], - mapping, - target_data[1], - source_data[1], - ) - # Determine number of samples directly from config (teacher and student views) source_cfgs = self.training_cfg.get("model_input") target_cfgs = self.training_cfg.get("target_input", source_cfgs) From 5a8ad49a5cc699b6b2fabff2d84ee22c81c11a41 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Thu, 11 Dec 2025 10:27:26 +0000 Subject: [PATCH 183/344] Resolved bugs when updating data structure --- config/default_config.yml | 8 ++++---- src/weathergen/train/target_and_aux_diffusion.py | 2 +- src/weathergen/train/trainer.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 98f953ee2..2af2fd218 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -1,5 +1,5 @@ -streams_directory: "./config/streams/era5_1deg/" -# streams_directory: "./config/streams/era5_nppatms_synop/" +# streams_directory: "./config/streams/era5_1deg/" +streams_directory: "./config/streams/era5_nppatms_synop/" embed_orientation: "channels" embed_unembed_mode: "block" @@ -156,7 +156,7 @@ training_config: # } # relationship: "independent" #, "subset", "disjoint". Relationship of student views to teacher view. relationship: "independent" # "independent", "subset", "disjoint". Relationship of student views to teacher view. - num_steps_input: 2 + num_steps_input: 1 # loss : ibot loss : training : @@ -251,7 +251,7 @@ input_window_steps: 1 val_initial: False -loader_num_workers: 0 +loader_num_workers: 8 log_validation: 0 streams_output: ["ERA5"] diff --git a/src/weathergen/train/target_and_aux_diffusion.py b/src/weathergen/train/target_and_aux_diffusion.py index ac2436980..accc5017c 100644 --- a/src/weathergen/train/target_and_aux_diffusion.py +++ b/src/weathergen/train/target_and_aux_diffusion.py @@ -17,5 +17,5 @@ def compute(self, sample: Sample, model_params: ModelParams, **kwargs) -> tuple[ with torch.no_grad(): tokens, posteriors = self.encoder(model_params=model_params, sample=sample) return TargetAuxOutput( - physical=None, latent=[tokens], aux_outputs={"noise_level_rn": noise_level_rn} + physical=None, latent=tokens, aux_outputs={"noise_level_rn": noise_level_rn} ) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index efa5b5448..30308fbd0 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -546,7 +546,7 @@ def train(self, mini_epoch): ) ) # targets, aux = zip(*targets_and_auxs) - breakpoint() + loss, loss_values = self.loss_calculator.compute_loss( preds=outputs[0], targets=targets_and_auxs[0], From f768046e76966a2b0a3e9604ee06e883066add19 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Thu, 11 Dec 2025 11:17:34 +0000 Subject: [PATCH 184/344] Linter --- src/weathergen/model/diffusion.py | 5 +---- src/weathergen/model/engines.py | 2 +- src/weathergen/model/model.py | 9 ++++----- src/weathergen/train/target_and_aux_diffusion.py | 4 +++- src/weathergen/train/trainer.py | 2 +- 5 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 90874c3fa..13172d3f7 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -56,10 +56,7 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast self.p_std = self.cf.p_std def forward( - self, - tokens: torch.Tensor, - fstep: int, - meta_info: dict[str, SampleMetaData] + self, tokens: torch.Tensor, fstep: int, meta_info: dict[str, SampleMetaData] ) -> torch.Tensor: """ Model forward call during training. Unpacks the conditioning c = [x_{t-k}, ..., x_{t}], the diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 23ea26822..313e23dd1 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -459,7 +459,7 @@ def forward( tokens: torch.Tensor, fstep: int, meta_info: SampleMetaData = None, - noise_emb: torch.Tensor = None + noise_emb: torch.Tensor = None, ) -> torch.Tensor: # predict residual to last time step if requested forecast_residual = self.cf.get("forecast_residual", False) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index b49ea542a..464f2cf4e 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -556,7 +556,6 @@ def forward(self, model_params: ModelParams, sample, forecast_offset: int, forec preds_all = [] latents = {"preds": []} for fstep in range(forecast_offset, forecast_steps): - if self.training: # Impute noise to the latent state noise_std = self.cf.get("impute_latent_noise_std", 0.0) @@ -568,7 +567,7 @@ def forward(self, model_params: ModelParams, sample, forecast_offset: int, forec # Decode tokens into physical space output = self.predict(model_params, fstep, tokens, sample, output) - + if len(preds_all) == 0: # Decode tokens when no forecasting is involved output = self.predict(model_params, forecast_steps, tokens, sample, output) @@ -582,9 +581,9 @@ def forward(self, model_params: ModelParams, sample, forecast_offset: int, forec def forecast( self, model_params: ModelParams, - tokens: torch.Tensor, - fstep: int, - meta_info: dict[str, SampleMetaData] = None + tokens: torch.Tensor, + fstep: int, + meta_info: dict[str, SampleMetaData] = None, ) -> torch.Tensor: """Advances latent space representation in time diff --git a/src/weathergen/train/target_and_aux_diffusion.py b/src/weathergen/train/target_and_aux_diffusion.py index accc5017c..21a1e4f5c 100644 --- a/src/weathergen/train/target_and_aux_diffusion.py +++ b/src/weathergen/train/target_and_aux_diffusion.py @@ -13,7 +13,9 @@ def __init__(self, model): self.encoder = model.encoder def compute(self, sample: Sample, model_params: ModelParams, **kwargs) -> tuple[Any, Any]: - noise_level_rn = sample.meta_info["ERA5"].params["noise_level_rn"] # TODO: adjust for multiple streams + noise_level_rn = sample.meta_info["ERA5"].params[ + "noise_level_rn" + ] # TODO: adjust for multiple streams with torch.no_grad(): tokens, posteriors = self.encoder(model_params=model_params, sample=sample) return TargetAuxOutput( diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 6bd1928fa..998d1be6e 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -390,7 +390,7 @@ def train(self, mini_epoch): ) ) # targets, aux = zip(*target_aux_outputs) - + loss, loss_values = self.loss_calculator.compute_loss( preds=outputs[0], targets=target_aux_outputs[0], From ca9e605991d0c6e7a0ca6d9c6255e664fa32d406 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 11 Dec 2025 16:49:08 +0100 Subject: [PATCH 185/344] Current working version of default_config --- config/default_config.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 6ccbf6f97..e2e3c017c 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -1,5 +1,5 @@ -# streams_directory: "./config/streams/era5_1deg/" -streams_directory: "./config/streams/era5_nppatms_synop/" +streams_directory: "./config/streams/era5_1deg/" +# streams_directory: "./config/streams/era5_nppatms_synop/" embed_orientation: "channels" embed_unembed_mode: "block" @@ -195,7 +195,7 @@ training_config: # # rate_sampling: true # randomly sample the rate per batch num_mini_epochs: 32 -samples_per_mini_epoch: 64 #4096 +samples_per_mini_epoch: 512 #4096 samples_per_validation: 32 #512 shuffle: True @@ -226,9 +226,9 @@ len_hrs: 6 step_hrs: 6 input_window_steps: 1 -val_initial: False +val_initial: False #True -loader_num_workers: 8 +loader_num_workers: 12 log_validation: 0 streams_output: ["ERA5"] From f8b1ca60a131d8a9d8509e6c0b041e3867b9db20 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 11 Dec 2025 18:01:40 +0100 Subject: [PATCH 186/344] Fixed problem with branches with old code and incomplete cleanup --- src/weathergen/model/model.py | 69 +++++++++++++++++------------------ 1 file changed, 33 insertions(+), 36 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index ce6dd8206..f40ab4a8d 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -618,7 +618,6 @@ def predict( tokens_stream = tokens_stream[model_params.hp_nbours.flatten()].flatten(0, 1) # pair with tokens from assimilation engine to obtain target tokens - preds_tokens = {} for stream_name in self.stream_names: tte = self.target_token_engines[stream_name] tte_kv = self.pred_adapter_kv[stream_name] @@ -644,46 +643,44 @@ def predict( # skip when coordinate embeddings yields nan (i.e. the coord embedding network diverged) if torch.isnan(tc_tokens).any(): - nn = stream_name - if is_root(): - logger.warning( - ( - f"Skipping prediction for {nn} because", - f" of {torch.isnan(tc_tokens).sum()} NaN in tc_tokens.", - ) + logger.warning( + ( + f"Skipping prediction for {stream_name} because", + f" of {torch.isnan(tc_tokens).sum()} NaN in tc_tokens.", ) - preds_tokens += [torch.tensor([], device=tc_tokens.device)] - continue + ) + pred = torch.tensor([], device=tc_tokens.device) # skip empty lengths - if tc_tokens.shape[0] == 0: - preds_tokens += [torch.tensor([], device=tc_tokens.device)] - continue - - # TODO: how to support tte_kv efficiently, - # generate 1-ring neighborhoods here or on a per stream basis - assert isinstance(tte_kv, torch.nn.Identity) - - # lens for varlen attention - tcs_lens = target_coords_idxs[stream_name][fstep] - # coord information for learnable layer norm - tcs_aux = torch.cat( - [ - streams_data[i_b][stream_name].target_coords[fstep] - for i_b in range(len(streams_data)) - ] - ) + elif tc_tokens.shape[0] == 0: + pred = torch.tensor([], device=tc_tokens.device) - tc_tokens = tte( - latent=tokens_stream, - output=tc_tokens, - latent_lens=model_params.tokens_lens, - output_lens=tcs_lens, - coordinates=tcs_aux, - ) + else: + # TODO: how to support tte_kv efficiently, + # generate 1-ring neighborhoods here or on a per stream basis + assert isinstance(tte_kv, torch.nn.Identity) + + # lens for varlen attention + tcs_lens = target_coords_idxs[stream_name][fstep] + # coord information for learnable layer norm + tcs_aux = torch.cat( + [ + streams_data[i_b][stream_name].target_coords[fstep] + for i_b in range(len(streams_data)) + ] + ) + + tc_tokens = tte( + latent=tokens_stream, + output=tc_tokens, + latent_lens=model_params.tokens_lens, + output_lens=tcs_lens, + coordinates=tcs_aux, + ) + + # final prediction head to map back to physical space + pred = checkpoint(self.pred_heads[stream_name], tc_tokens, use_reentrant=False) - # final prediction head to map back to physical space - pred = checkpoint(self.pred_heads[stream_name], tc_tokens, use_reentrant=False) output.add_physical_prediction(fstep, stream_name, pred) return output From 003b0cfe193b5d663eb4d6a479742e7b198d27d1 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 11 Dec 2025 18:02:10 +0100 Subject: [PATCH 187/344] Updated to test convergence of integration test. --- integration_tests/small_multi_stream.yaml | 9 ++++----- integration_tests/small_multi_stream_test.py | 12 ++++++------ 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/integration_tests/small_multi_stream.yaml b/integration_tests/small_multi_stream.yaml index 6257ec00b..06ab73480 100644 --- a/integration_tests/small_multi_stream.yaml +++ b/integration_tests/small_multi_stream.yaml @@ -4,12 +4,11 @@ model_path: "./models" loss_fcts: [["mse", 1.0]] loss_fcts_val: [["mse", 1.0]] num_mini_epochs: 1 -samples_per_mini_epoch: 16 +samples_per_mini_epoch: 128 samples_per_validation: 4 -lr_steps: 4 -lr_steps_warmup: 2 +lr_steps_warmup: 4 lr_steps_cooldown: 2 -loader_num_workers: 1 +loader_num_workers: 8 # forecast_offset: 0 forecast_offset : 1 @@ -32,4 +31,4 @@ train_log: start_date: 201210010000 # need to customize the starting date for NPPATMS -# otherwise Integration Test can fails (nan values in the losses) \ No newline at end of file +# otherwise Integration Test can fails (nan values in the losses) diff --git a/integration_tests/small_multi_stream_test.py b/integration_tests/small_multi_stream_test.py index f388c13ff..3bd07ff56 100644 --- a/integration_tests/small_multi_stream_test.py +++ b/integration_tests/small_multi_stream_test.py @@ -183,14 +183,14 @@ def assert_stream_losses_below_threshold(run_id, stage="train"): # Thresholds for train and val thresholds = { "train": { - "ERA5": 2.0, - "NPPATMS": 2.0, - "SurfaceCombined": 2.0, + "ERA5": 0.2, + "NPPATMS": 0.4, + "SurfaceCombined": 0.6, }, "val": { - "ERA5": 1.5, - "NPPATMS": 1.5, - "SurfaceCombined": 1.5, + "ERA5": 0.2, + "NPPATMS": 0.3, + "SurfaceCombined": 0.5, }, } From f38e6d2b7470f128db6608b3e56da28cb9ef5b0d Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 11 Dec 2025 18:13:18 +0100 Subject: [PATCH 188/344] Updated settings --- integration_tests/small_multi_stream_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/integration_tests/small_multi_stream_test.py b/integration_tests/small_multi_stream_test.py index 3bd07ff56..728f9775e 100644 --- a/integration_tests/small_multi_stream_test.py +++ b/integration_tests/small_multi_stream_test.py @@ -184,13 +184,13 @@ def assert_stream_losses_below_threshold(run_id, stage="train"): thresholds = { "train": { "ERA5": 0.2, - "NPPATMS": 0.4, - "SurfaceCombined": 0.6, + "NPPATMS": 0.5, + "SurfaceCombined": 0.7, }, "val": { "ERA5": 0.2, - "NPPATMS": 0.3, - "SurfaceCombined": 0.5, + "NPPATMS": 0.4, + "SurfaceCombined": 0.6, }, } From 7e7ff8ebcd9cb50612b2e604d0e6e6dc2686c930 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 12 Dec 2025 10:41:36 +0100 Subject: [PATCH 189/344] Clessig/ypd/dev/1353 add tokens latent state finalization (#1452) * Add LatentState * Add class and register tokens for LatentState, adjust everything accordingly * Add option in config file + minor changes * Add pos.emb. for register tokens + remove class tokens + minor fixes * Minor fix * Changed empty to zeros pe_register * Ruffed * Clean-up and fixed positional encoding * Fixing things that got lost during last merge --------- Co-authored-by: Yura Perugachi Diaz Co-authored-by: Yura Perugachi Diaz --- config/default_config.yml | 2 ++ src/weathergen/model/encoder.py | 10 +++++++++ src/weathergen/model/engines.py | 11 ++++++++++ src/weathergen/model/model.py | 38 +++++++++++++++++++-------------- 4 files changed, 45 insertions(+), 16 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index e2e3c017c..d78a4de54 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -194,6 +194,8 @@ training_config: # # # keep_m: 100 # Alternative to rate: keep exactly this many parent cells # # rate_sampling: true # randomly sample the rate per batch +num_register_tokens: 16 + num_mini_epochs: 32 samples_per_mini_epoch: 512 #4096 samples_per_validation: 32 #512 diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index 833491bfa..e4a1866d9 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -21,6 +21,7 @@ # from weathergen.model.model import ModelParams from weathergen.model.parametrised_prob_dist import LatentInterpolator +from weathergen.model.positional_encoding import positional_encoding_harmonic from weathergen.utils.utils import get_dtype @@ -60,6 +61,9 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord # separate embedding networks for differnt observation types self.embed_engine = EmbeddingEngine(cf, self.sources_size, self.stream_names) + assert cf.ae_global_att_dense_rate == 1.0, "Local attention not adapted for register tokens" + self.num_register_tokens = cf.num_register_tokens + # local assimilation engine self.ae_local_engine = LocalAssimilationEngine(cf) @@ -273,6 +277,12 @@ def assimilate_local( + model_params.pe_global ).flatten(1, 2) + # create register tokens and prepend to latent spatial tokens + tokens_global_register = positional_encoding_harmonic( + self.q_cells.repeat(batch_size, self.num_register_tokens, 1) + ) + tokens_global = torch.cat([tokens_global_register, tokens_global], dim=1) + return tokens_global, posteriors def assimilate_global(self, tokens: torch.Tensor) -> torch.Tensor: diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index bb19d249d..c8617c899 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -6,6 +6,7 @@ # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import dataclasses import torch import torch.nn as nn @@ -819,3 +820,13 @@ def forward(self, latent, output, latent_lens, output_lens, coordinates): else output ) return output + + +@dataclasses.dataclass +class LatentState: + """ + A dataclass to encapsulate the latent state + """ + + register_tokens: torch.Tensor + latent_tokens: torch.Tensor diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index f40ab4a8d..6863c6a60 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -26,6 +26,7 @@ from weathergen.model.engines import ( EnsPredictionHead, ForecastingEngine, + LatentState, TargetPredictionEngine, TargetPredictionEngineClassic, ) @@ -45,21 +46,19 @@ class ModelOutput: """ physical: list[dict[StreamName, torch.Tensor]] - latent: list[dict[StreamName, torch.Tensor]] + latent: list[torch.Tensor | None] def __init__(self, forecast_steps: int) -> None: self.physical = [{} for _ in range(forecast_steps)] - self.latent = [{} for _ in range(forecast_steps)] + self.latent = [None for _ in range(forecast_steps)] def add_physical_prediction( self, fstep: int, stream_name: StreamName, pred: torch.Tensor ) -> None: self.physical[fstep][stream_name] = pred - def add_latent_prediction( - self, fstep: int, stream_name: StreamName, pred: torch.Tensor - ) -> None: - self.latent[fstep][stream_name] = pred + def add_latent_prediction(self, fstep: int, pred: torch.Tensor) -> None: + self.latent[fstep] = pred def get_physical_prediction(self, fstep: int, stream_name: StreamName | None = None): pred = self.physical[fstep] @@ -67,11 +66,8 @@ def get_physical_prediction(self, fstep: int, stream_name: StreamName | None = N pred = pred.get(stream_name, None) return pred - def get_latent_prediction(self, fstep: int, stream_name: StreamName | None = None): - pred = self.latent[fstep] - if stream_name is not None: - pred = pred.get(stream_name, None) - return pred + def get_latent_prediction(self, fstep: int): + return self.latent[fstep] class ModelParams(torch.nn.Module): @@ -443,6 +439,8 @@ def create(self) -> "Model": stream_name=stream_name, ) + self.num_register_tokens = cf.num_register_tokens + return self def reset_parameters(self): @@ -543,10 +541,20 @@ def forward(self, model_params: ModelParams, sample, forecast_offset: int, forec # collapse along input step dimension tokens = torch.stack(tokens, 0).sum(0) + # latents for output + latent_state = LatentState( + register_tokens=tokens[:, : self.num_register_tokens].clone(), + latent_tokens=tokens[:, self.num_register_tokens :].clone(), + ) + output.add_latent_prediction(0, {"posteriors": posteriors, "latent_state": latent_state}) + + # forecasting + # roll-out in latent space for fstep in range(forecast_offset, forecast_steps): # prediction - output = self.predict(model_params, fstep, tokens, sample, output) + tokens_latent = tokens[:, self.num_register_tokens :] + output = self.predict(model_params, fstep, tokens_latent, sample, output) if self.training: # Impute noise to the latent state @@ -557,10 +565,8 @@ def forward(self, model_params: ModelParams, sample, forecast_offset: int, forec tokens = self.forecast(model_params, tokens, fstep) # prediction for final step - output = self.predict(model_params, forecast_steps, tokens, sample, output) - - # TODO: set properly - output.latent = posteriors + tokens_latent = tokens[:, self.num_register_tokens :] + output = self.predict(model_params, forecast_steps, tokens_latent, sample, output) return output From 31a0b969f58a4bd69326ef2bf286c9536311f840 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 12 Dec 2025 10:50:47 +0100 Subject: [PATCH 190/344] Ruffed --- src/weathergen/model/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 6863c6a60..7e88579c4 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -46,7 +46,7 @@ class ModelOutput: """ physical: list[dict[StreamName, torch.Tensor]] - latent: list[torch.Tensor | None] + latent: list[torch.Tensor] def __init__(self, forecast_steps: int) -> None: self.physical = [{} for _ in range(forecast_steps)] @@ -295,6 +295,8 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord self.stream_names: list[str] = None self.target_token_engines = None + self.num_register_tokens = cf.num_register_tokens + ######################################### def create(self) -> "Model": """Create each individual module of the model""" @@ -439,8 +441,6 @@ def create(self) -> "Model": stream_name=stream_name, ) - self.num_register_tokens = cf.num_register_tokens - return self def reset_parameters(self): From 4fe90d7622969cea2a4901c1158bcadaf9d4cef3 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 12 Dec 2025 10:56:10 +0100 Subject: [PATCH 191/344] Adding sanity check for register tokens --- src/weathergen/model/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 7e88579c4..226820df6 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -295,6 +295,7 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord self.stream_names: list[str] = None self.target_token_engines = None + assert cf.forecast_att_dense_rate == 1.0, "Local attention not adapted for register tokens" self.num_register_tokens = cf.num_register_tokens ######################################### From 458e652da10e208a4308066836179102b996c0b1 Mon Sep 17 00:00:00 2001 From: Julian Kuehnert Date: Wed, 14 Jan 2026 16:09:57 +0000 Subject: [PATCH 192/344] debug target_aux, loss_module, engines, etc --- config/default_config.yml | 5 ++-- src/weathergen/model/engines.py | 2 -- src/weathergen/model/model.py | 8 +++--- src/weathergen/model/model_interface.py | 2 +- .../loss_module_latent_diffusion.py | 25 +++++++++++-------- .../train/target_and_aux_diffusion.py | 21 ++++++++++------ 6 files changed, 36 insertions(+), 27 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index e610b42c0..a0a651cd2 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -179,8 +179,9 @@ training_config: eps : 2e-08 losses : { - "physical": { - type: LossPhysical, + "latent_diff": { + type: LossLatentDiffusion, + target_and_aux_calc: DiffusionLatentTargetEncoder, loss_fcts: { "mse": { }, }, }, } diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 358152e65..c4d76eba7 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -8,8 +8,6 @@ # nor does it submit to any jurisdiction. import dataclasses -import dataclasses - import torch import torch.nn as nn from omegaconf import OmegaConf diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 813743671..23ec5420b 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -21,9 +21,8 @@ from torch.utils.checkpoint import checkpoint from weathergen.common.config import Config -from weathergen.datasets.batch import Sample, SampleMetaData -from weathergen.model.diffusion import DiffusionForecastEngine from weathergen.datasets.batch import ModelBatch +from weathergen.model.diffusion import DiffusionForecastEngine from weathergen.model.encoder import EncoderModule from weathergen.model.engines import ( BilinearDecoder, @@ -270,7 +269,6 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord self.embed_target_coords = None self.encoder: EncoderModule | None = None self.forecast_engine: ForecastingEngine | None = None - self.forecast_offset = cf.forecast_offset self.pred_heads = None self.q_cells: torch.Tensor | None = None @@ -610,7 +608,9 @@ def forward( return output ######################################### - def forecast(self, model_params: ModelParams, tokens: torch.Tensor, fstep: int, meta_info = None) -> torch.Tensor: + def forecast( + self, model_params: ModelParams, tokens: torch.Tensor, fstep: int, meta_info=None + ) -> torch.Tensor: """Advances latent space representation in time Args: diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index fd4520fcd..02dabb143 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -334,7 +334,7 @@ def get_target_aux_calculator( if target_and_aux_calc == "Physical": target_aux = PhysicalTargetAndAux(loss_cfg, model) elif target_and_aux_calc == "DiffusionLatentTargetEncoder": - target_aux = DiffusionLatentTargetEncoder(model) + target_aux = DiffusionLatentTargetEncoder(model) elif target_and_aux_calc == "EMATeacher": meta_ema_model, _ = init_model_and_shard( cf, diff --git a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py index 50dec277f..5eb04ecd9 100644 --- a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py +++ b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py @@ -15,7 +15,7 @@ from omegaconf import DictConfig from torch import Tensor -import weathergen.train.loss_modules.loss_functions as losses +import weathergen.train.loss_modules.loss_functions as loss_fns from weathergen.train.loss_modules.loss_module_base import LossModuleBase, LossValues from weathergen.utils.train_logger import Stage @@ -30,9 +30,10 @@ class LossLatentDiffusion(LossModuleBase): def __init__( self, cf: DictConfig, - loss_fcts: list, + mode_cfg: DictConfig, stage: Stage, device: str, + **loss_fcts: dict, ): LossModuleBase.__init__(self) self.cf = cf @@ -46,7 +47,14 @@ def __init__( self.p_std = self.cf.p_std # Dynamically load loss functions based on configuration and stage - self.loss_fcts = [[getattr(losses, name), w, name] for name, w in loss_fcts] + self.loss_fcts = [ + [ + getattr(loss_fns, name), + params.get("weight", 1.0), + name, + ] + for name, params in loss_fcts.items() + ] def _get_noise_weight(self, eta): sigma = (eta * self.p_std + self.p_mean).exp() @@ -56,7 +64,7 @@ def _get_fstep_weights(self, forecast_steps): timestep_weight_config = self.cf.get("timestep_weight") if timestep_weight_config is None: return [1.0 for _ in range(forecast_steps)] - weights_timestep_fct = getattr(losses, timestep_weight_config[0]) + weights_timestep_fct = getattr(loss_fns, timestep_weight_config[0]) return weights_timestep_fct(forecast_steps, timestep_weight_config[1]) def _loss_per_loss_function( @@ -70,16 +78,11 @@ def _loss_per_loss_function( Compute loss for given loss function """ - loss_val = noise_weight * loss_fct(target=target, mu=pred) + loss_val = noise_weight * loss_fct(target=target, pred=pred) return loss_val - def compute_loss( - self, - preds: dict, - targets: dict, - **kwargs - ) -> LossValues: + def compute_loss(self, preds: dict, targets: dict, **kwargs) -> LossValues: losses_all: dict[str, Tensor] = { f"{self.name}.{loss_fct_name}": torch.zeros( 1, diff --git a/src/weathergen/train/target_and_aux_diffusion.py b/src/weathergen/train/target_and_aux_diffusion.py index 198c3cd22..1daffd624 100644 --- a/src/weathergen/train/target_and_aux_diffusion.py +++ b/src/weathergen/train/target_and_aux_diffusion.py @@ -12,18 +12,25 @@ def __init__(self, model): # Todo: make sure this is a frozen clone or forward without gradients in compute() self.encoder = model.encoder - def compute(self, batch: ModelBatch, model_params: ModelParams, model: torch.nn.Module, **kwargs) -> tuple[Any, Any]: + def compute( + self, + istep: int, + batch: ModelBatch, + model_params: ModelParams, + model: torch.nn.Module, + *args, + **kwargs, + ) -> tuple[Any, Any]: + noise_level_rn = ( + batch.target_samples[0].meta_info["ERA5"].params["noise_level_rn"] + ) # TODO: adjust for multiple streams - noise_level_rn = batch.target_samples[0].meta_info["ERA5"].params[ - "noise_level_rn" - ] # TODO: adjust for multiple streams - with torch.no_grad(): tokens, posteriors = self.encoder(model_params=model_params, batch=batch) - + return TargetAuxOutput( num_forecast_steps=batch.get_forecast_steps(), physical=None, latent=tokens, - aux_outputs={"noise_level_rn": noise_level_rn} + aux_outputs={"noise_level_rn": noise_level_rn}, ) From 61dce3936f8c32c67606f3b19cd8306a4a96976c Mon Sep 17 00:00:00 2001 From: Julian Kuehnert Date: Wed, 14 Jan 2026 16:23:45 +0000 Subject: [PATCH 193/344] debug, diffusion_rn and batch.sample --- config/default_config.yml | 1 + src/weathergen/model/model.py | 5 +++-- src/weathergen/train/target_and_aux_diffusion.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index a0a651cd2..2a3067e69 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -190,6 +190,7 @@ training_config: "forecasting" : { # masking strategy: "random", "healpix", "forecast" masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True} }, } diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 23ec5420b..61d88da32 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -21,8 +21,9 @@ from torch.utils.checkpoint import checkpoint from weathergen.common.config import Config -from weathergen.datasets.batch import ModelBatch +from weathergen.datasets.batch import Sample, SampleMetaData from weathergen.model.diffusion import DiffusionForecastEngine +from weathergen.datasets.batch import ModelBatch from weathergen.model.encoder import EncoderModule from weathergen.model.engines import ( BilinearDecoder, @@ -591,7 +592,7 @@ def forward( if noise_std > 0.0: tokens = tokens + torch.randn_like(tokens) * torch.norm(tokens) * noise_std - tokens = self.forecast(model_params, tokens, fstep, batch.source_samples[0].meta_info) + tokens = self.forecast(model_params, tokens, fstep, batch.samples[0].meta_info) # safe latent prediction latent_state = LatentState( diff --git a/src/weathergen/train/target_and_aux_diffusion.py b/src/weathergen/train/target_and_aux_diffusion.py index 1daffd624..129944df6 100644 --- a/src/weathergen/train/target_and_aux_diffusion.py +++ b/src/weathergen/train/target_and_aux_diffusion.py @@ -22,7 +22,7 @@ def compute( **kwargs, ) -> tuple[Any, Any]: noise_level_rn = ( - batch.target_samples[0].meta_info["ERA5"].params["noise_level_rn"] + batch.samples[0].meta_info["ERA5"].params["noise_level_rn"] ) # TODO: adjust for multiple streams with torch.no_grad(): From ea4d76c7048abf208f85a86a4a285ebc2494e6d9 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Thu, 15 Jan 2026 09:14:45 +0000 Subject: [PATCH 194/344] Corrected latent token retrieval in loss calculation --- config/default_config.yml | 5 +++-- .../train/loss_modules/loss_module_latent_diffusion.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 2a3067e69..dc97f125f 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -142,7 +142,7 @@ data_loading : training_config: # training_mode: "masking", "student_teacher", "latent_loss" - training_mode: ["masking"] + training_mode: ["student_teacher"] num_mini_epochs: 32 samples_per_mini_epoch: 4096 @@ -181,6 +181,7 @@ training_config: losses : { "latent_diff": { type: LossLatentDiffusion, + weight: 1.0, target_and_aux_calc: DiffusionLatentTargetEncoder, loss_fcts: { "mse": { }, }, }, @@ -196,7 +197,7 @@ training_config: forecast : time_step: 06:00:00 - num_steps: 2 + num_steps: 1 policy: "fixed" diff --git a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py index 5eb04ecd9..3ee876e48 100644 --- a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py +++ b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py @@ -91,8 +91,9 @@ def compute_loss(self, preds: dict, targets: dict, **kwargs) -> LossValues: for _, _, loss_fct_name in self.loss_fcts } - pred_tokens_all = [pl["latent_state"].latent_tokens for pl in preds.latent if pl] - target_tokens_all = [targets.latent] # TODO: remove extra list + + pred_tokens_all = [pl["latent_state"].z_pre_norm for pl in preds.latent if pl] + target_tokens_all = targets.latent eta = torch.tensor([targets.aux_outputs["noise_level_rn"]], device=self.device) fsteps = len(target_tokens_all) From b8757349c90bc3ee9695e6319fc914efbea4f136 Mon Sep 17 00:00:00 2001 From: Julian Kuehnert Date: Thu, 15 Jan 2026 11:00:50 +0000 Subject: [PATCH 195/344] working training loop on single sample --- config/default_config.yml | 14 +++++++------- src/weathergen/model/model.py | 3 +-- .../train/loss_modules/loss_functions.py | 4 ++-- .../loss_modules/loss_module_latent_diffusion.py | 8 ++++---- 4 files changed, 14 insertions(+), 15 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index dc97f125f..9842f4c41 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -50,8 +50,8 @@ pred_adapter_kv: False pred_self_attention: True pred_dyadic_dims: False pred_mlp_adaln: True -num_class_tokens: 1 -num_register_tokens: 7 +num_class_tokens: 0 +num_register_tokens: 0 # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder @@ -59,7 +59,7 @@ fe_num_blocks: 6 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True -fe_diffusion_model: True +fe_diffusion_model: False fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) @@ -135,7 +135,7 @@ data_loading : num_workers: 12 rng_seed: ??? - repeat_data_in_mini_epoch : False + repeat_data_in_mini_epoch : True # config for training @@ -148,13 +148,13 @@ training_config: samples_per_mini_epoch: 4096 shuffle: True - start_date: 1979-01-01T00:00 - end_date: 2022-12-31T00:00 + start_date: 2012-06-01T00:00 + end_date: 2012-06-01T18:00 time_window_step: 06:00:00 time_window_len: 06:00:00 - window_offset_prediction : 1 + window_offset_prediction : 0 learning_rate_scheduling : lr_start: 1e-6 diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 61d88da32..390c6b44d 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -21,9 +21,8 @@ from torch.utils.checkpoint import checkpoint from weathergen.common.config import Config -from weathergen.datasets.batch import Sample, SampleMetaData -from weathergen.model.diffusion import DiffusionForecastEngine from weathergen.datasets.batch import ModelBatch +from weathergen.model.diffusion import DiffusionForecastEngine from weathergen.model.encoder import EncoderModule from weathergen.model.engines import ( BilinearDecoder, diff --git a/src/weathergen/train/loss_modules/loss_functions.py b/src/weathergen/train/loss_modules/loss_functions.py index ccf626529..3a14c4ffa 100644 --- a/src/weathergen/train/loss_modules/loss_functions.py +++ b/src/weathergen/train/loss_modules/loss_functions.py @@ -206,8 +206,8 @@ def lp_loss( def mse( target: torch.Tensor, pred: torch.Tensor, - weights_channels: torch.Tensor | None, - weights_points: torch.Tensor | None, + weights_channels: torch.Tensor | None = None, + weights_points: torch.Tensor | None = None, ): """ Computes the mean squared error (mse). diff --git a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py index 3ee876e48..71e03d365 100644 --- a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py +++ b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py @@ -78,9 +78,10 @@ def _loss_per_loss_function( Compute loss for given loss function """ - loss_val = noise_weight * loss_fct(target=target, pred=pred) + loss, loss_chs = loss_fct(target=target, pred=pred) + loss = noise_weight * loss - return loss_val + return loss def compute_loss(self, preds: dict, targets: dict, **kwargs) -> LossValues: losses_all: dict[str, Tensor] = { @@ -91,8 +92,7 @@ def compute_loss(self, preds: dict, targets: dict, **kwargs) -> LossValues: for _, _, loss_fct_name in self.loss_fcts } - - pred_tokens_all = [pl["latent_state"].z_pre_norm for pl in preds.latent if pl] + pred_tokens_all = [pl["latent_state"].patch_tokens for pl in preds.latent if pl] target_tokens_all = targets.latent eta = torch.tensor([targets.aux_outputs["noise_level_rn"]], device=self.device) From c91d5c932e6c0609c2258a93f3c51078e70b742e Mon Sep 17 00:00:00 2001 From: Jubeku Date: Thu, 15 Jan 2026 14:53:43 +0100 Subject: [PATCH 196/344] update config to fit forecast checkpoint --- config/default_config.yml | 12 ++++++------ config/streams/era5_1deg/era5.yml | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 9842f4c41..9ed1ac100 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -11,7 +11,7 @@ embed_orientation: "channels" embed_unembed_mode: "block" embed_dropout_rate: 0.1 -ae_local_dim_embed: 1024 +ae_local_dim_embed: 2048 ae_local_num_blocks: 0 ae_local_num_heads: 16 ae_local_dropout_rate: 0.1 @@ -26,7 +26,7 @@ ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 ae_global_dim_embed: 2048 -ae_global_num_blocks: 8 +ae_global_num_blocks: 4 ae_global_num_heads: 32 ae_global_dropout_rate: 0.1 ae_global_with_qk_lnorm: True @@ -37,7 +37,7 @@ ae_global_block_factor: 64 ae_global_mlp_hidden_factor: 2 ae_global_trailing_layer_norm: False -ae_aggregation_num_blocks: 2 +ae_aggregation_num_blocks: 0 ae_aggregation_num_heads: 32 ae_aggregation_dropout_rate: 0.1 ae_aggregation_with_qk_lnorm: True @@ -55,12 +55,12 @@ num_register_tokens: 0 # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -fe_num_blocks: 6 +fe_num_blocks: 16 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True fe_diffusion_model: False -fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) forecast_att_dense_rate: 1.0 @@ -75,7 +75,7 @@ rho: 7 p_mean: -1.2 p_std: 1.2 # Encoder weights (set to null to not load a pretrained encoder) -chkpt_encoder_weights: "./models/whkujigw/whkujigw_latest.chkpt" +chkpt_encoder_weights: "./models/dhb9q2yo/dhb9q2yo_chkpt00126.chkpt" healpix_level: 5 diff --git a/config/streams/era5_1deg/era5.yml b/config/streams/era5_1deg/era5.yml index 429682d85..ee1e2f19e 100644 --- a/config/streams/era5_1deg/era5.yml +++ b/config/streams/era5_1deg/era5.yml @@ -24,11 +24,11 @@ ERA5 : net : transformer num_tokens : 1 num_heads : 8 - dim_embed : 256 + dim_embed : 512 num_blocks : 2 embed_target_coords : net : linear - dim_embed : 256 + dim_embed : 512 target_readout : type : 'obs_value' # token or obs_value num_layers : 2 From 91d633bcc461d071494129702696daa69a712257 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Mon, 19 Jan 2026 18:12:11 +0100 Subject: [PATCH 197/344] reset default config --- config/default_config.yml | 50 +++++++++++++-------------------------- 1 file changed, 17 insertions(+), 33 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index f8545057a..1b403cd26 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -11,8 +11,8 @@ embed_orientation: "channels" embed_unembed_mode: "block" embed_dropout_rate: 0.1 -ae_local_dim_embed: 2048 -ae_local_num_blocks: 0 +ae_local_dim_embed: 1024 +ae_local_num_blocks: 2 ae_local_num_heads: 16 ae_local_dropout_rate: 0.1 ae_local_with_qk_lnorm: True @@ -26,7 +26,7 @@ ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 ae_global_dim_embed: 2048 -ae_global_num_blocks: 4 +ae_global_num_blocks: 8 ae_global_num_heads: 32 ae_global_dropout_rate: 0.1 ae_global_with_qk_lnorm: True @@ -37,7 +37,7 @@ ae_global_block_factor: 64 ae_global_mlp_hidden_factor: 2 ae_global_trailing_layer_norm: False -ae_aggregation_num_blocks: 0 +ae_aggregation_num_blocks: 2 ae_aggregation_num_heads: 32 ae_aggregation_dropout_rate: 0.1 ae_aggregation_with_qk_lnorm: True @@ -50,39 +50,26 @@ pred_adapter_kv: False pred_self_attention: True pred_dyadic_dims: False pred_mlp_adaln: True -num_class_tokens: 0 -num_register_tokens: 0 +num_class_tokens: 1 +num_register_tokens: 7 # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -fe_num_blocks: 16 +fe_num_blocks: 6 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True -fe_diffusion_model: False -fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) forecast_att_dense_rate: 1.0 -with_step_conditioning: True # False -# Diffusion related parameters -frequency_embedding_dim: 256 -embedding_dim: 512 -sigma_min: 0.002 -sigma_max: 50000 -sigma_data: 0.5 -rho: 7 -p_mean: -1.2 -p_std: 1.2 -# Encoder weights (set to null to not load a pretrained encoder) -chkpt_encoder_weights: "./models/dhb9q2yo/dhb9q2yo_chkpt00126.chkpt" healpix_level: 5 with_mixed_precision: True with_flash_attention: True compile_model: False -with_fsdp: False +with_fsdp: True attention_dtype: bf16 mixed_precision_dtype: bf16 mlp_norm_eps: 1e-5 @@ -137,7 +124,7 @@ data_loading : num_workers: 12 rng_seed: ??? - repeat_data_in_mini_epoch : True + repeat_data_in_mini_epoch : False # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. @@ -149,19 +136,19 @@ data_loading : training_config: # training_mode: "masking", "student_teacher", "latent_loss" - training_mode: ["student_teacher"] + training_mode: ["masking"] num_mini_epochs: 32 samples_per_mini_epoch: 4096 shuffle: True - start_date: 2012-06-01T00:00 - end_date: 2012-06-01T18:00 + start_date: 1979-01-01T00:00 + end_date: 2022-12-31T00:00 time_window_step: 06:00:00 time_window_len: 06:00:00 - window_offset_prediction : 0 + window_offset_prediction : 1 learning_rate_scheduling : lr_start: 1e-6 @@ -186,10 +173,8 @@ training_config: eps : 2e-08 losses : { - "latent_diff": { - type: LossLatentDiffusion, - weight: 1.0, - target_and_aux_calc: DiffusionLatentTargetEncoder, + "physical": { + type: LossPhysical, loss_fcts: { "mse": { }, }, }, } @@ -198,13 +183,12 @@ training_config: "forecasting" : { # masking strategy: "random", "healpix", "forecast" masking_strategy: "forecast", - masking_strategy_config: {diffusion_rn: True} }, } forecast : time_step: 06:00:00 - num_steps: 1 + num_steps: 2 policy: "fixed" From bbdb3a1da90d0cd99a7c16e90e8f7da1a9edfa7b Mon Sep 17 00:00:00 2001 From: Jubeku Date: Mon, 19 Jan 2026 18:24:12 +0100 Subject: [PATCH 198/344] modify default config for diffusion --- config/default_config.yml | 51 ++++++++++++++++++++++++++------------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 1b403cd26..f9f12b631 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -11,8 +11,8 @@ embed_orientation: "channels" embed_unembed_mode: "block" embed_dropout_rate: 0.1 -ae_local_dim_embed: 1024 -ae_local_num_blocks: 2 +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 ae_local_num_heads: 16 ae_local_dropout_rate: 0.1 ae_local_with_qk_lnorm: True @@ -26,7 +26,7 @@ ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 ae_global_dim_embed: 2048 -ae_global_num_blocks: 8 +ae_global_num_blocks: 4 ae_global_num_heads: 32 ae_global_dropout_rate: 0.1 ae_global_with_qk_lnorm: True @@ -37,7 +37,7 @@ ae_global_block_factor: 64 ae_global_mlp_hidden_factor: 2 ae_global_trailing_layer_norm: False -ae_aggregation_num_blocks: 2 +ae_aggregation_num_blocks: 0 ae_aggregation_num_heads: 32 ae_aggregation_dropout_rate: 0.1 ae_aggregation_with_qk_lnorm: True @@ -50,19 +50,32 @@ pred_adapter_kv: False pred_self_attention: True pred_dyadic_dims: False pred_mlp_adaln: True -num_class_tokens: 1 -num_register_tokens: 7 +num_class_tokens: 0 +num_register_tokens: 0 # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -fe_num_blocks: 6 +fe_num_blocks: 16 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True -fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_diffusion_model: False +fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) forecast_att_dense_rate: 1.0 +with_step_conditioning: True # False +# Diffusion related parameters +frequency_embedding_dim: 256 +embedding_dim: 512 +sigma_min: 0.002 +sigma_max: 50000 +sigma_data: 0.5 +rho: 7 +p_mean: -1.2 +p_std: 1.2 +# Encoder weights (set to null to not load a pretrained encoder) +# chkpt_encoder_weights: "./models/dhb9q2yo/dhb9q2yo_chkpt00126.chkpt" healpix_level: 5 @@ -81,7 +94,8 @@ latent_noise_saturate_encodings: 5 latent_noise_use_additive_noise: False latent_noise_deterministic_latents: True -freeze_modules: "" + +freeze_modules: ".*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*|" norm_type: "LayerNorm" @@ -124,7 +138,7 @@ data_loading : num_workers: 12 rng_seed: ??? - repeat_data_in_mini_epoch : False + repeat_data_in_mini_epoch : True # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. @@ -136,19 +150,19 @@ data_loading : training_config: # training_mode: "masking", "student_teacher", "latent_loss" - training_mode: ["masking"] + training_mode: ["student_teacher"] num_mini_epochs: 32 samples_per_mini_epoch: 4096 shuffle: True - start_date: 1979-01-01T00:00 - end_date: 2022-12-31T00:00 + start_date: 2012-06-01T00:00 + end_date: 2012-06-01T18:00 time_window_step: 06:00:00 time_window_len: 06:00:00 - window_offset_prediction : 1 + window_offset_prediction : 0 learning_rate_scheduling : lr_start: 1e-6 @@ -173,8 +187,10 @@ training_config: eps : 2e-08 losses : { - "physical": { - type: LossPhysical, + "latent_diff": { + type: LossLatentDiffusion, + weight: 1.0, + target_and_aux_calc: DiffusionLatentTargetEncoder, loss_fcts: { "mse": { }, }, }, } @@ -183,12 +199,13 @@ training_config: "forecasting" : { # masking strategy: "random", "healpix", "forecast" masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True} }, } forecast : time_step: 06:00:00 - num_steps: 2 + num_steps: 1 policy: "fixed" From 43b21c4e21081bdc7162e83d3e93e719279e8ac4 Mon Sep 17 00:00:00 2001 From: Julian Kuehnert Date: Mon, 19 Jan 2026 17:56:03 +0000 Subject: [PATCH 199/344] adding encoder loading to model interface --- config/default_config.yml | 4 +- config/streams/era5_1deg/era5.yml | 4 +- src/weathergen/model/model_interface.py | 133 ++++++++++++++++++------ 3 files changed, 105 insertions(+), 36 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index f9f12b631..fbf117709 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -59,7 +59,7 @@ fe_num_blocks: 16 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True -fe_diffusion_model: False +fe_diffusion_model: True fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) @@ -76,6 +76,8 @@ p_mean: -1.2 p_std: 1.2 # Encoder weights (set to null to not load a pretrained encoder) # chkpt_encoder_weights: "./models/dhb9q2yo/dhb9q2yo_chkpt00126.chkpt" +chkpt_encoder_weights: y8wuhr2t #"dhb9q2yo" +chkpt_encoder_mini_epoch: -1 #126 healpix_level: 5 diff --git a/config/streams/era5_1deg/era5.yml b/config/streams/era5_1deg/era5.yml index ee1e2f19e..429682d85 100644 --- a/config/streams/era5_1deg/era5.yml +++ b/config/streams/era5_1deg/era5.yml @@ -24,11 +24,11 @@ ERA5 : net : transformer num_tokens : 1 num_heads : 8 - dim_embed : 512 + dim_embed : 256 num_blocks : 2 embed_target_coords : net : linear - dim_embed : 512 + dim_embed : 256 target_readout : type : 'obs_value' # token or obs_value num_layers : 2 diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index 664a8b824..f425c1df8 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -163,12 +163,12 @@ def init_model_and_shard( # LOAD AND FREEZE ENCODER WEIGHTS # ONLY FOR EXPERIMENTATION, TO BE REMOVED if cf.chkpt_encoder_weights: - params = torch.load( - cf.chkpt_encoder_weights, - map_location=torch.device("cpu"), - mmap=True, - weights_only=True, - ) + if is_root(): + logger.info( + f"Loading chkpt from run_id={cf.chkpt_encoder_weights}"\ + f" at mini_epoch {cf.chkpt_encoder_mini_epoch}." + ) + encoder_modules = [ "embed_engine", "ae_local_engine", @@ -176,34 +176,15 @@ def init_model_and_shard( "ae_global_engine", ] - # Load encoder weights - params_temp = {} - for name in params.keys(): - if any(e_module in name for e_module in encoder_modules): - if cf.with_ddp: - params_temp[f"module.{name}"] = params[name] - else: - params_temp[name] = params[name] - params = params_temp - mkeys, ukeys = model.load_state_dict(params, strict=False) - - # Freeze encoder weights - for name, module in model.named_modules(): - if any(e_module in name for e_module in encoder_modules): - for p in module.parameters(): - p.requires_grad = False - - model = model.to(f"cuda:{cf.local_rank}") + model = load_encoder( + cf, + model, + encoder_modules, + device, + cf.chkpt_encoder_weights, + cf.chkpt_encoder_mini_epoch, + ) - # warn about difference in checkpoint and model - if len(mkeys) == 0 and len(ukeys) == 0: - logger.info( - f"Checkpoint {cf.chkpt_encoder_weights} loaded successfully with all weights." - ) - if len(mkeys) > 0: - logger.warning(f"Missing keys when loading model: {mkeys}") - if len(ukeys) > 0: - logger.warning(f"Unused keys when loading model: {ukeys}") # ------------------------------------------------------------------------------------------ # model params @@ -213,7 +194,93 @@ def init_model_and_shard( return model, model_params +def load_encoder(cf, model, encoder_modules, device, run_id: str, mini_epoch=-1): + """Loads model state from checkpoint and checks for missing and unused keys. + Args: + run_id : model_id of the trained model + mini_epoch : The mini_epoch to load. Default (-1) is the latest mini_epoch + """ + + path_run = Path(cf.model_path) / run_id + mini_epoch_id = ( + f"chkpt{mini_epoch:05d}" if mini_epoch != -1 and mini_epoch is not None else "latest" + ) + filename = f"{run_id}_{mini_epoch_id}.chkpt" + + if not (path_run / filename).exists(): + mini_epoch_id = f"epoch{mini_epoch:05d}" + filename = f"{run_id}_{mini_epoch_id}.chkpt" + if is_root(): + logger.info(path_run / filename) + + params = torch.load( + path_run / filename, map_location=torch.device("cpu"), mmap=True, weights_only=True + ) + + is_model_sharded = cf.with_ddp and cf.with_fsdp + if is_model_sharded: + meta_sharded_sd = model.state_dict() + maybe_sharded_sd = {} + for param_name, full_tensor in params.items(): + if any(e_module in param_name for e_module in encoder_modules): + sharded_meta_param = meta_sharded_sd.get(param_name) + sharded_tensor = distribute_tensor( + full_tensor, + sharded_meta_param.device_mesh, + sharded_meta_param.placements, + ) + # maybe_sharded_sd[param_name.replace("module.", "")] = nn.Parameter(sharded_tensor) + maybe_sharded_sd[param_name] = torch.nn.Parameter(sharded_tensor) + # choose `assign=True` for sharded model since we cannot call `copy_` on meta tensor + mkeys, ukeys = model.load_state_dict(maybe_sharded_sd, strict=False, assign=True) + + if is_root(): + if len(mkeys) > 0: + logger.warning(f"Missing keys when loading model: {mkeys}") + if len(ukeys) > 0: + logger.warning(f"Unused keys when loading model: {mkeys}") + + # # new network parts (e.g. for fine-tuning) + # if mkeys: + # # Get the unique parent modules for the missing parameters + # new_modules_to_init = {key.rsplit(".", 1)[0] for key in mkeys} + + # # Find the highest-level "root" new modules to avoid redundant initializations + # root_new_modules = set() + # for path in sorted(list(new_modules_to_init)): + # if not any(path.startswith(root + ".") for root in root_new_modules): + # root_new_modules.add(path) + + # # Get all modules for quick lookup and initialize the new ones + # all_modules = dict(model.named_modules()) + # for path in root_new_modules: + # if is_root(): + # logger.info(f"Initializing new module not found in checkpoint: {path}") + # module_to_init = all_modules[path] + # module_to_init.to_empty(device="cuda") + # module_to_init.reset_parameters() + + else: + if not cf.with_ddp: + params_temp = {} + for k in params.keys(): + if any(e_module in k for e_module in encoder_modules): + params_temp[k.replace("module.", "")] = params[k] + params = params_temp + + mkeys, ukeys = model.load_state_dict(params, strict=False) + model = model.to(device) + + # warn about difference in checkpoint and model + if len(mkeys) == 0 and len(ukeys) == 0: + logger.info(f"Checkpoint {filename} loaded successfully with all weights matching.") + if len(mkeys) > 0: + logger.warning(f"Missing keys when loading model: {mkeys}") + if len(ukeys) > 0: + logger.warning(f"Unused keys when loading model: {mkeys}") + return model + def load_model(cf, model, device, run_id: str, mini_epoch=-1): """Loads model state from checkpoint and checks for missing and unused keys. Args: From 52b6bb15c60f0daf222467eac439016aa10cc2a3 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Tue, 20 Jan 2026 11:09:17 +0100 Subject: [PATCH 200/344] setting checkpoint to null temporarily --- config/default_config.yml | 4 ++-- config/streams/era5_1deg/era5.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index fbf117709..bdef79149 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -76,8 +76,8 @@ p_mean: -1.2 p_std: 1.2 # Encoder weights (set to null to not load a pretrained encoder) # chkpt_encoder_weights: "./models/dhb9q2yo/dhb9q2yo_chkpt00126.chkpt" -chkpt_encoder_weights: y8wuhr2t #"dhb9q2yo" -chkpt_encoder_mini_epoch: -1 #126 +chkpt_encoder_weights: null #ß"dhb9q2yo" +chkpt_encoder_mini_epoch: 126 healpix_level: 5 diff --git a/config/streams/era5_1deg/era5.yml b/config/streams/era5_1deg/era5.yml index 429682d85..ee1e2f19e 100644 --- a/config/streams/era5_1deg/era5.yml +++ b/config/streams/era5_1deg/era5.yml @@ -24,11 +24,11 @@ ERA5 : net : transformer num_tokens : 1 num_heads : 8 - dim_embed : 256 + dim_embed : 512 num_blocks : 2 embed_target_coords : net : linear - dim_embed : 256 + dim_embed : 512 target_readout : type : 'obs_value' # token or obs_value num_layers : 2 From 0f7d4e5b6c87490eec515b9e0a3efba0191604f4 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Tue, 20 Jan 2026 12:54:34 +0100 Subject: [PATCH 201/344] rm activation checkpoint around diff forecast engine --- src/weathergen/model/engines.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index fac227f71..f39f70ac1 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -488,9 +488,9 @@ def forward( ) for block in self.fe_blocks: if isinstance(block, torch.nn.LayerNorm): - tokens = checkpoint(block, tokens, use_reentrant=False) + tokens = block(tokens) else: - tokens = checkpoint(block, tokens, noise_emb, aux_info, use_reentrant=False) + tokens = block(tokens, noise_emb, aux_info) else: for block in self.fe_blocks: if isinstance(block, torch.nn.LayerNorm): From a51f70632810147d4bc51aebb1c3fc268db62d2c Mon Sep 17 00:00:00 2001 From: Belkis Asma SEMCHEDDINE Date: Fri, 23 Jan 2026 13:24:28 +0100 Subject: [PATCH 202/344] [Diff] sbAsma/issue1279 noise conditioning (#1358) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * initial commit [draft] * adapt noise conditioner to make it closer to DiT * adapt dimensionalities – code runs with default config * lint * fix: add conditional prediction mode handling This commit resolves architectural incompatibilities when integrating diffusion-based forecast engines: 1. FSDP Sharding: DiffusionForecastEngine wraps ForecastingEngine as `self.net`, but trainer code assumed direct `fe_blocks` access. Fixed by: - Adding fe_diffusion_model conditional check in init_model_and_shard() - Routing to model.forecast_engine.net.fe_blocks for diffusion mode 2. Model Initialization: Reordered ForecastingEngine creation to handle both standard and diffusion-wrapped variants with proper fallback. 3. Target Format Handling: Autoencoder mode uses different target structure than diffusion mode. Added conditional formatting: - Diffusion: targets = {"targets": [targets], "aux_outputs": aux} - Autoencoder: targets = {"physical": batch[0]} 4. Config Updates: added file config/diffusion_config.yml for diffusion model config * added forecast engine argument * removed unecessary logging * reverting back to the previous config * replaced getattr by get * modification of forecasting engine initialization --------- Co-authored-by: moritzhauschulz Co-authored-by: Matthias Karlbauer --- src/weathergen/model/model.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 36c04e2ce..d4ff2cbb8 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -292,11 +292,14 @@ def create(self) -> "Model": cf, self.sources_size, self.targets_num_channels, self.targets_coords_size ) + # Initialize forecasting engine: standard or diffusion-wrapped mode_cfg = cf.training_config self.forecast_engine = ForecastingEngine(cf, mode_cfg, self.num_healpix_cells) - if cf.fe_diffusion_model: + if cf.get("fe_diffusion_model", False): self.forecast_engine = DiffusionForecastEngine( - cf, self.num_healpix_cells, forecast_engine=self.forecast_engine + forecast_engine=ForecastingEngine( + cf, self.num_healpix_cells, forecast_engine=self.forecast_engine + ) ) # embed coordinates yielding one query token for each target token From 47566beba7078db2e8344f38a38da4faa362924f Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Fri, 23 Jan 2026 13:06:48 +0000 Subject: [PATCH 203/344] Correct forecast engine initialization --- src/weathergen/model/model.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index d4ff2cbb8..0bb139c07 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -297,9 +297,7 @@ def create(self) -> "Model": self.forecast_engine = ForecastingEngine(cf, mode_cfg, self.num_healpix_cells) if cf.get("fe_diffusion_model", False): self.forecast_engine = DiffusionForecastEngine( - forecast_engine=ForecastingEngine( - cf, self.num_healpix_cells, forecast_engine=self.forecast_engine - ) + cf, self.num_healpix_cells, forecast_engine=self.forecast_engine ) # embed coordinates yielding one query token for each target token From 3ce80f0208dd6698f4fa470f410a8815ed74b824 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Sun, 8 Feb 2026 21:16:11 +0100 Subject: [PATCH 204/344] code runs... --- src/weathergen/model/model.py | 3 ++- src/weathergen/train/loss_calculator.py | 2 ++ .../loss_modules/loss_module_latent_diffusion.py | 2 +- src/weathergen/train/target_and_aux_diffusion.py | 16 +++++++++++----- 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 86f6d0074..7a45953ed 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -596,7 +596,8 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: for step in batch.get_output_idxs(): # apply forecasting engine (if present) if self.forecast_engine: - tokens = self.forecast_engine(tokens, step) + # print(batch.samples[0].meta_info) + tokens = self.forecast_engine(tokens, step, batch.samples[0].meta_info) # decoder predictions output = self.predict_decoders(model_params, step, tokens, batch, output) diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index 2f81940a3..02fec9292 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -88,6 +88,8 @@ def compute_loss( loss = torch.tensor(0.0, requires_grad=True) for loss_term_name, calc_term in self.loss_calculators.items(): target = targets_and_aux[loss_term_name] + print(f'available targets for targets_and_aux.keys(): {targets_and_aux.keys()}') + print(f'Computing loss for {loss_term_name}') for weight, calculator in calc_term: if weight > 0.0: loss_values = calculator.compute_loss( diff --git a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py index 71e03d365..41b679807 100644 --- a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py +++ b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py @@ -93,7 +93,7 @@ def compute_loss(self, preds: dict, targets: dict, **kwargs) -> LossValues: } pred_tokens_all = [pl["latent_state"].patch_tokens for pl in preds.latent if pl] - target_tokens_all = targets.latent + target_tokens_all = [latent["diffusion_latent"] for latent in targets.latent if latent] eta = torch.tensor([targets.aux_outputs["noise_level_rn"]], device=self.device) fsteps = len(target_tokens_all) diff --git a/src/weathergen/train/target_and_aux_diffusion.py b/src/weathergen/train/target_and_aux_diffusion.py index 129944df6..f2195dac4 100644 --- a/src/weathergen/train/target_and_aux_diffusion.py +++ b/src/weathergen/train/target_and_aux_diffusion.py @@ -28,9 +28,15 @@ def compute( with torch.no_grad(): tokens, posteriors = self.encoder(model_params=model_params, batch=batch) - return TargetAuxOutput( - num_forecast_steps=batch.get_forecast_steps(), - physical=None, - latent=tokens, - aux_outputs={"noise_level_rn": noise_level_rn}, + target_aux_output = TargetAuxOutput( + batch.get_output_len(), + output_idxs=batch.get_output_idxs(), ) + + #TODO: currently hard-coding 0 + target_aux_output.add_latent_target(0, 'diffusion_latent', tokens) + + #TODO: write function in TargetAuxOutput class + target_aux_output.aux_outputs={"noise_level_rn": noise_level_rn} + + return target_aux_output From a144867a681f5d0658997a0d6d9ed7134bfe936d Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 18 Feb 2026 09:21:35 +0100 Subject: [PATCH 205/344] remove some debugging code --- config/config_diffusion.yml | 279 ++++++++++++++++++ config/default_config.yml | 58 ++-- src/weathergen/model/diffusion.py | 21 ++ src/weathergen/model/model.py | 14 +- src/weathergen/train/loss_calculator.py | 2 - .../train/target_and_aux_diffusion.py | 19 +- src/weathergen/train/trainer_base.py | 3 - 7 files changed, 341 insertions(+), 55 deletions(-) create mode 100644 config/config_diffusion.yml diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml new file mode 100644 index 000000000..51bef729a --- /dev/null +++ b/config/config_diffusion.yml @@ -0,0 +1,279 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 16 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_diffusion_model: True +fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 0.0 # 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 +with_step_conditioning: True # False +# Diffusion related parameters +frequency_embedding_dim: 256 +embedding_dim: 512 +sigma_min: 0.002 +sigma_max: 50000 +sigma_data: 0.5 +rho: 7 +p_mean: -1.2 +p_std: 1.2 +# Encoder weights (set to null to not load a pretrained encoder) +# chkpt_encoder_weights: "./models/dhb9q2yo/dhb9q2yo_chkpt00126.chkpt" +chkpt_encoder_weights: "dhb9q2yo" +chkpt_encoder_mini_epoch: 126 + +healpix_level: 5 + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: True +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + + +freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +load_chkpt: {} + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/era5_1deg/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_log_freq: + terminal: 10 + metrics: 20 + checkpoint: 250 + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : True + + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["student_teacher"] # ["student_teacher", "physical_loss"] + + num_mini_epochs: 1 + samples_per_mini_epoch: 64 + shuffle: True + + start_date: 2012-06-01T00:00 + end_date: 2012-06-01T18:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-5 + lr_max: 5e-5 + lr_final_decay: 1e-6 + lr_final: 0.0 + num_steps_warmup: 512 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.975 + beta2 : 0.9875 + eps : 2e-08 + + losses : { + "latent_diff": { + type: LossLatentDiffusion, + weight: 1.0, + target_and_aux_calc: DiffusionLatentTargetEncoder, + loss_fcts: { "mse": { }, }, + }, + # "physical": { + # type: LossPhysical, + # target_and_aux_calc: Physical, + # weight: 0, + # loss_fcts: { "mse": { }, }, + # }, + # } + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True} + }, + } + + forecast : + time_step: 06:00:00 + num_steps: 1 + offset: 0 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + samples_per_mini_epoch: 64 + shuffle: False + + start_date: 2012-06-01T00:00 + end_date: 2012-06-01T18:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: False + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/default_config.yml b/config/default_config.yml index 730396d71..d951680e2 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -11,8 +11,8 @@ embed_orientation: "channels" embed_unembed_mode: "block" embed_dropout_rate: 0.1 -ae_local_dim_embed: 2048 -ae_local_num_blocks: 0 +ae_local_dim_embed: 512 #1024 +ae_local_num_blocks: 2 ae_local_num_heads: 16 ae_local_dropout_rate: 0.1 ae_local_with_qk_lnorm: True @@ -25,8 +25,8 @@ ae_adapter_with_qk_lnorm: True ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 -ae_global_dim_embed: 2048 -ae_global_num_blocks: 4 +ae_global_dim_embed: 512 #1024 #2048 +ae_global_num_blocks: 2 ae_global_num_heads: 32 ae_global_dropout_rate: 0.1 ae_global_with_qk_lnorm: True @@ -37,7 +37,7 @@ ae_global_block_factor: 64 ae_global_mlp_hidden_factor: 2 ae_global_trailing_layer_norm: False -ae_aggregation_num_blocks: 0 +ae_aggregation_num_blocks: 2 ae_aggregation_num_heads: 32 ae_aggregation_dropout_rate: 0.1 ae_aggregation_with_qk_lnorm: True @@ -50,34 +50,20 @@ pred_adapter_kv: False pred_self_attention: True pred_dyadic_dims: False pred_mlp_adaln: True -num_class_tokens: 0 -num_register_tokens: 0 +num_class_tokens: 1 +num_register_tokens: 7 # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -fe_num_blocks: 16 +fe_num_blocks: 6 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True -fe_diffusion_model: True -fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_diffusion_model: False +fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) forecast_att_dense_rate: 1.0 -with_step_conditioning: True # False -# Diffusion related parameters -frequency_embedding_dim: 256 -embedding_dim: 512 -sigma_min: 0.002 -sigma_max: 50000 -sigma_data: 0.5 -rho: 7 -p_mean: -1.2 -p_std: 1.2 -# Encoder weights (set to null to not load a pretrained encoder) -# chkpt_encoder_weights: "./models/dhb9q2yo/dhb9q2yo_chkpt00126.chkpt" -chkpt_encoder_weights: null #ß"dhb9q2yo" -chkpt_encoder_mini_epoch: 126 healpix_level: 5 @@ -96,8 +82,7 @@ latent_noise_saturate_encodings: 5 latent_noise_use_additive_noise: False latent_noise_deterministic_latents: True - -freeze_modules: ".*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*|" +freeze_modules: "" load_chkpt: {} norm_type: "LayerNorm" @@ -141,7 +126,7 @@ data_loading : num_workers: 12 rng_seed: ??? - repeat_data_in_mini_epoch : True + repeat_data_in_mini_epoch : False # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. @@ -153,14 +138,14 @@ data_loading : training_config: # training_mode: "masking", "student_teacher", "latent_loss" - training_mode: ["student_teacher"] + training_mode: ["masking"] num_mini_epochs: 32 samples_per_mini_epoch: 4096 shuffle: True - start_date: 2012-06-01T00:00 - end_date: 2012-06-01T18:00 + start_date: 1979-01-01T00:00 + end_date: 2022-12-31T00:00 time_window_step: 06:00:00 time_window_len: 06:00:00 @@ -188,10 +173,8 @@ training_config: eps : 2e-08 losses : { - "latent_diff": { - type: LossLatentDiffusion, - weight: 1.0, - target_and_aux_calc: DiffusionLatentTargetEncoder, + "physical": { + type: LossPhysical, loss_fcts: { "mse": { }, }, }, } @@ -200,14 +183,13 @@ training_config: "forecasting" : { # masking strategy: "random", "healpix", "forecast" masking_strategy: "forecast", - masking_strategy_config: {diffusion_rn: True} }, } forecast : time_step: 06:00:00 - num_steps: 1 - offset: 0 + num_steps: 2 + offset: 1 policy: "fixed" @@ -269,4 +251,4 @@ wgtags: # *** Experiment-specific tags *** # All extra tags (including lists, dictionaries, etc.) are treated # as strings by mlflow, so treat all extra tags as simple string key: value pairs. - grid: null + grid: null \ No newline at end of file diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 13172d3f7..02002807e 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -23,6 +23,7 @@ # ---------------------------------------------------------------------------- +import logging import math import torch @@ -31,6 +32,8 @@ from weathergen.datasets.batch import SampleMetaData from weathergen.model.engines import ForecastingEngine +logger = logging.getLogger(__name__) + class DiffusionForecastEngine(torch.nn.Module): # Adopted from https://github.com/NVlabs/edm/blob/main/training/loss.py#L72 @@ -54,6 +57,7 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast self.rho = self.cf.rho self.p_mean = self.cf.p_mean self.p_std = self.cf.p_std + self.cur_token = None # TODO: re move after single sample experiments def forward( self, tokens: torch.Tensor, fstep: int, meta_info: dict[str, SampleMetaData] @@ -70,6 +74,23 @@ def forward( # y = data.get_input_data(-1) # eta = data.get_input_metadata(-1) + # TODO: remove after single sample experiments + if self.cur_token is not None: + logger.info("checking single sampling") + assert self.cur_token[0].shape == tokens[0].shape, ( + "first token shape was different between iterations " + "– violates single sample overfitting with difference" + ) + assert torch.equal(self.cur_token[0], tokens[0]), ( + f"first token was different between iterations " + f"– violates single sample overfitting {self.cur_token[0] - tokens[0]}" + ) + assert torch.equal(self.cur_token, tokens), ( + f"tokens were different between iterations " + f"– violates single sample overfitting {self.cur_token - tokens}" + ) + self.cur_token = tokens + c = 1 # TODO: add correct preconditioning (e.g., sample/s in previous time step) y = tokens # TODO: add correct eta from meta_info diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 7a45953ed..f61f6d435 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -510,8 +510,13 @@ def print_num_parameters(self) -> None: num_params_latent_heads += get_num_parameters(self.latent_pre_norm) num_params_fe = ( - get_num_parameters(self.forecast_engine.net.fe_blocks - if cf.fe_diffusion_model else self.forecast_engine.fe_blocks) if self.forecast_engine else 0 + get_num_parameters( + self.forecast_engine.net.fe_blocks + if cf.fe_diffusion_model + else self.forecast_engine.fe_blocks + ) + if self.forecast_engine + else 0 ) mdict = self.embed_target_coords @@ -654,6 +659,7 @@ def predict_decoders( Prediction output tokens in physical representation for each target_coords. """ # Empty dicts evaluate to False in python + # breakpoint() if not self.pred_heads: return output @@ -671,6 +677,8 @@ def predict_decoders( ) tokens_nbors_lens[0] = 0 + # breakpoint() + # pair with tokens from assimilation engine to obtain target tokens for stream_name in self.stream_names: # extract target coords for current stream and fstep and convert to one tensor @@ -681,6 +689,7 @@ def predict_decoders( t_coords_lens = [len(t) for t in t_coords] t_coords = torch.cat(t_coords) + # breakpoint() if len(t_coords) == 0: continue @@ -732,6 +741,7 @@ def predict_decoders( # recover batch dimension (ragged, so as list) pred = torch.split(pred, t_coords_lens, dim=1) + # breakpoint() output.add_physical_prediction(step, stream_name, pred) return output diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index 02fec9292..2f81940a3 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -88,8 +88,6 @@ def compute_loss( loss = torch.tensor(0.0, requires_grad=True) for loss_term_name, calc_term in self.loss_calculators.items(): target = targets_and_aux[loss_term_name] - print(f'available targets for targets_and_aux.keys(): {targets_and_aux.keys()}') - print(f'Computing loss for {loss_term_name}') for weight, calculator in calc_term: if weight > 0.0: loss_values = calculator.compute_loss( diff --git a/src/weathergen/train/target_and_aux_diffusion.py b/src/weathergen/train/target_and_aux_diffusion.py index f2195dac4..96246e20a 100644 --- a/src/weathergen/train/target_and_aux_diffusion.py +++ b/src/weathergen/train/target_and_aux_diffusion.py @@ -4,10 +4,12 @@ from weathergen.datasets.batch import ModelBatch from weathergen.model.model import ModelParams -from weathergen.train.target_and_aux_module_base import TargetAndAuxModuleBase, TargetAuxOutput +from weathergen.train.target_and_aux_module_base import ( + PhysicalTargetAndAux, +) -class DiffusionLatentTargetEncoder(TargetAndAuxModuleBase): +class DiffusionLatentTargetEncoder(PhysicalTargetAndAux): def __init__(self, model): # Todo: make sure this is a frozen clone or forward without gradients in compute() self.encoder = model.encoder @@ -28,15 +30,12 @@ def compute( with torch.no_grad(): tokens, posteriors = self.encoder(model_params=model_params, batch=batch) - target_aux_output = TargetAuxOutput( - batch.get_output_len(), - output_idxs=batch.get_output_idxs(), - ) + target_aux_output = super().compute(istep, batch, model_params, model) - #TODO: currently hard-coding 0 - target_aux_output.add_latent_target(0, 'diffusion_latent', tokens) + # TODO: currently hard-coding 0 + target_aux_output.add_latent_target(0, "diffusion_latent", tokens) - #TODO: write function in TargetAuxOutput class - target_aux_output.aux_outputs={"noise_level_rn": noise_level_rn} + # TODO: write function in TargetAuxOutput class + target_aux_output.aux_outputs = {"noise_level_rn": noise_level_rn} return target_aux_output diff --git a/src/weathergen/train/trainer_base.py b/src/weathergen/train/trainer_base.py index 67be78270..d0359892a 100644 --- a/src/weathergen/train/trainer_base.py +++ b/src/weathergen/train/trainer_base.py @@ -16,9 +16,6 @@ import torch.multiprocessing from weathergen.common.config import Config -from weathergen.train.target_and_aux_diffusion import DiffusionLatentTargetEncoder -from weathergen.train.target_and_aux_module_base import PhysicalTargetAndAux -from weathergen.train.target_and_aux_ssl_teacher import EMATeacher from weathergen.train.utils import str_to_tensor, tensor_to_str from weathergen.utils.distributed import is_root From 63b3f78f3556f9e6f53ca2348424876173e0c590 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 18 Feb 2026 12:42:26 +0100 Subject: [PATCH 206/344] adjusted diffusion config --- config/config_diffusion.yml | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 51bef729a..a7c1daea4 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -193,14 +193,7 @@ training_config: weight: 1.0, target_and_aux_calc: DiffusionLatentTargetEncoder, loss_fcts: { "mse": { }, }, - }, - # "physical": { - # type: LossPhysical, - # target_and_aux_calc: Physical, - # weight: 0, - # loss_fcts: { "mse": { }, }, - # }, - # } + } } model_input: { From 83bb4c9b6eb0482f1ba6a3e83a42963b40d0113c Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 18 Feb 2026 22:56:18 +0100 Subject: [PATCH 207/344] fixed inference --- src/weathergen/model/diffusion.py | 1 + src/weathergen/run_train.py | 2 -- src/weathergen/train/target_and_aux_diffusion.py | 12 +++++++++--- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 02002807e..8dac95888 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -130,6 +130,7 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int) -> c_in * x, fstep=fstep, noise_emb=noise_emb ) # Eq. (7) in EDM paper + def inference( self, fstep: int, diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index e91501274..a6bab58b5 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -173,12 +173,10 @@ def run_train(args): """ cli_overwrite = config.from_cli_arglist(args.options) - cf = config.load_merge_configs( args.private_config, None, None, args.base_config, *args.config, cli_overwrite ) cf = config.set_run_id(cf, args.run_id, False) - cf.data_loading.rng_seed = int(time.time()) mp_method = cf.general.get("multiprocessing_method", "fork") devices = Trainer.init_torch(multiprocessing_method=mp_method) diff --git a/src/weathergen/train/target_and_aux_diffusion.py b/src/weathergen/train/target_and_aux_diffusion.py index 96246e20a..e4f71f28f 100644 --- a/src/weathergen/train/target_and_aux_diffusion.py +++ b/src/weathergen/train/target_and_aux_diffusion.py @@ -5,11 +5,12 @@ from weathergen.datasets.batch import ModelBatch from weathergen.model.model import ModelParams from weathergen.train.target_and_aux_module_base import ( - PhysicalTargetAndAux, + TargetAndAuxModuleBase, + TargetAuxOutput, ) -class DiffusionLatentTargetEncoder(PhysicalTargetAndAux): +class DiffusionLatentTargetEncoder(TargetAndAuxModuleBase): def __init__(self, model): # Todo: make sure this is a frozen clone or forward without gradients in compute() self.encoder = model.encoder @@ -27,11 +28,16 @@ def compute( batch.samples[0].meta_info["ERA5"].params["noise_level_rn"] ) # TODO: adjust for multiple streams + #TODO: check if there are scenarios where the encoder needs to be set to eval with torch.no_grad(): tokens, posteriors = self.encoder(model_params=model_params, batch=batch) + #NOTE: must not set to train afterwards unless it was already in train - target_aux_output = super().compute(istep, batch, model_params, model) + output_idxs = batch.get_output_idxs() + assert len(output_idxs) > 0 + target_aux_output = TargetAuxOutput(batch.get_output_len(), output_idxs) + # TODO: currently hard-coding 0 target_aux_output.add_latent_target(0, "diffusion_latent", tokens) From bb3bbe51f4fac271b5880854ab1e90ec9cf1d125 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 18 Feb 2026 22:56:39 +0100 Subject: [PATCH 208/344] actually fiex inference (via config) --- config/config_diffusion.yml | 12 ++++- config/evaluate/eval_config_diffusion.yml | 62 +++++++++++++++++++++++ 2 files changed, 73 insertions(+), 1 deletion(-) create mode 100644 config/evaluate/eval_config_diffusion.yml diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index a7c1daea4..78b6a5cbf 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -153,7 +153,7 @@ data_loading : training_config: # training_mode: "masking", "student_teacher", "latent_loss" - training_mode: ["student_teacher"] # ["student_teacher", "physical_loss"] + training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] num_mini_epochs: 1 samples_per_mini_epoch: 64 @@ -188,6 +188,14 @@ training_config: eps : 2e-08 losses : { + "physical": { + type: LossPhysical, + weight: 1.0, + loss_fcts: { + "mse": {}, + }, + target_and_aux_calc: "Physical", + }, "latent_diff": { type: LossLatentDiffusion, weight: 1.0, @@ -204,6 +212,8 @@ training_config: }, } + + forecast : time_step: 06:00:00 num_steps: 1 diff --git a/config/evaluate/eval_config_diffusion.yml b/config/evaluate/eval_config_diffusion.yml new file mode 100644 index 000000000..15a567af2 --- /dev/null +++ b/config/evaluate/eval_config_diffusion.yml @@ -0,0 +1,62 @@ +#optional: if commented out all is taken care of by the default settings +# NB. global options apply to all run_ids +#global_plotting_options: +# region: ["belgium", "global"] +# image_format : "png" #options: "png", "pdf", "svg", "eps", "jpg" .. +# dpi_val : 300 +# fps: 2 +# ERA5: +# marker_size: 2 +# scale_marker_size: 1 +# marker: "o" +# # alpha: 0.5 +# 2t: +# vmin: 250 +# vmax: 300 +# 10u: +# vmin: -40 +# vmax: 40 + +evaluation: + metrics : ["rmse", "mae"] + regions: ["global", "nhem"] + summary_plots : true + ratio_plots : false + heat_maps : false + summary_dir: "./plots/" + plot_ensemble: "members" #supported: false, "std", "minmax", "members" + plot_score_maps: false #plot scores on a 2D maps. it slows down score computation + print_summary: false #print out score values on screen. it can be verbose + log_scale: false + add_grid: false + score_cards: false + bar_plots: false + num_processes: 0 #options: int, "auto", 0 means no parallelism (default) + # baseline: "ar40mckx" + + +default_streams: + ERA5: + channels: ["2t", "10u"] #, "10v", "z_500", "t_850", "u_850", "v_850", "q_850", ] + evaluation: + forecast_step: "all" + sample: "all" + ensemble: "all" #supported: "all", "mean", [0,1,2] + plotting: + sample: [1] + forecast_step: [1] #supported: "all", [1,2,3,...], "1-50" (equivalent of [1,2,3,...50]) + ensemble: [0] #supported: "all", "mean", [0,1,2] + plot_maps: true + plot_target: true + plot_histograms: true + plot_animations: true + + +run_ids : + kuia5xr0: + label: "debugging model g0vdqua7" + results_base_dir : "../results/" + #NEW: if "streams" is not specified, the default streams are used + + + \ No newline at end of file From b5ee07135ede17581f9e9ba9e3b8e35bfbd672d1 Mon Sep 17 00:00:00 2001 From: Matthias Date: Thu, 19 Feb 2026 18:07:19 +0100 Subject: [PATCH 209/344] Plot maps during training at validation time --- config/config_diffusion.yml | 16 ++-- src/weathergen/datasets/masking.py | 3 +- src/weathergen/model/diffusion.py | 6 +- src/weathergen/utils/validation_io.py | 111 ++++++++++++++++++++++++++ 4 files changed, 121 insertions(+), 15 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 78b6a5cbf..f02916370 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -155,7 +155,7 @@ training_config: # training_mode: "masking", "student_teacher", "latent_loss" training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] - num_mini_epochs: 1 + num_mini_epochs: 25 samples_per_mini_epoch: 64 shuffle: True @@ -167,10 +167,10 @@ training_config: learning_rate_scheduling : lr_start: 1e-5 - lr_max: 5e-5 + lr_max: 1e-4 lr_final_decay: 1e-6 lr_final: 0.0 - num_steps_warmup: 512 + num_steps_warmup: 128 num_steps_cooldown: 512 policy_warmup: "cosine" policy_decay: "constant" @@ -190,7 +190,7 @@ training_config: losses : { "physical": { type: LossPhysical, - weight: 1.0, + weight: 0.0, loss_fcts: { "mse": {}, }, @@ -212,8 +212,6 @@ training_config: }, } - - forecast : time_step: 06:00:00 num_steps: 1 @@ -224,7 +222,7 @@ training_config: # validation config; full validation config is merge of training and validation config validation_config: - samples_per_mini_epoch: 64 + samples_per_mini_epoch: 1 shuffle: False start_date: 2012-06-01T00:00 @@ -239,7 +237,7 @@ validation_config: # parameters for validation samples that are written to disk output : { # number of samples that are written - num_samples: 0, + num_samples: 1, # write samples in normalized model space normalized_samples: False, # output streams to write; default all @@ -247,7 +245,7 @@ validation_config: } # run validation before training starts (mainly for model development) - validate_before_training: False + validate_before_training: True # test config; full test config is merge of validation and test config diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index f84111541..844e1fe31 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -464,7 +464,8 @@ def _generate_cell_mask( mask = np.ones(num_cells, dtype=np.bool) if "diffusion_rn" in masking_strategy_config: - masking_params["noise_level_rn"] = self.rng.normal(0.0, 1.0) + # masking_params["noise_level_rn"] = self.rng.normal(0.0, 1.0) + masking_params["noise_level_rn"] = 0.1 elif strategy == "healpix": # prepare healpix-based masking diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 8dac95888..ce00ca262 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -76,7 +76,7 @@ def forward( # TODO: remove after single sample experiments if self.cur_token is not None: - logger.info("checking single sampling") + # logger.info("checking single sampling") assert self.cur_token[0].shape == tokens[0].shape, ( "first token shape was different between iterations " "– violates single sample overfitting with difference" @@ -105,10 +105,6 @@ def forward( return self.denoise(x=y + n, c=c, sigma=sigma, fstep=fstep) - # Compute loss -- move this to a separate loss calculator - # weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 # Table 1 - # loss = weight * ((y_hat - y) ** 2) - def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int) -> torch.Tensor: """ The actual diffusion step, where the model removes noise from the input x under diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 08a866c45..5d96a094a 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -8,18 +8,58 @@ # nor does it submit to any jurisdiction. import logging +from pathlib import Path import numpy as np import torch +import xarray as xr import weathergen.common.config as config import weathergen.common.io as io from weathergen.common.io import TimeRange, zarrio_writer from weathergen.datasets.data_reader_base import TimeWindowHandler +from weathergen.evaluate.plotting.plotter import Plotter _logger = logging.getLogger(__name__) +def _normalize_channel_name(name: str) -> str: + return str(name).lower().replace("_", "").replace(" ", "") + + +def _resolve_channel_names(stream_info, raw_channels): + if not raw_channels: + return raw_channels + if isinstance(raw_channels[0], str): + return list(raw_channels) + + channel_names = None + if hasattr(stream_info, "val_target_channels") and stream_info.val_target_channels: + if isinstance(stream_info.val_target_channels[0], str): + channel_names = list(stream_info.val_target_channels) + + if channel_names is None: + target_weights = getattr(stream_info, "target_channel_weights", None) + if isinstance(target_weights, dict): + channel_names = list(target_weights.keys()) + + if channel_names is None: + channel_weights = getattr(stream_info, "channel_weights", None) + if isinstance(channel_weights, dict): + channel_names = list(channel_weights.keys()) + + if channel_names is None: + return [f"ch{idx}" for idx in raw_channels] + + resolved = [] + for idx in raw_channels: + if 0 <= int(idx) < len(channel_names): + resolved.append(channel_names[int(idx)]) + else: + resolved.append(f"ch{idx}") + return resolved + + def write_output( cf, val_cfg, batch_size, mini_epoch, batch_idx, dn_data, batch, model_output, target_aux_out ): @@ -159,3 +199,74 @@ def write_output( with zarrio_writer(config.get_path_results(cf, mini_epoch)) as zio: for subset in data.items(): zio.write_zarr(subset) + + # Prepare prediction data for Plotter (scatter plot expects lat/lon coords on ipoint). + base_plot_dir = config.get_path_run(cf) / "plots" / "validation" + base_plot_dir.mkdir(parents=True, exist_ok=True) + plotter = Plotter({"image_format": "png", "dpi_val": 150}, base_plot_dir) + headline_channels = {"2t", "z500", "q850", "10u", "10v"} + + t_idx = 0 + for stream_idx, stream_info in enumerate(cf.streams): + stream_name = stream_info["name"] + preds_stream = preds_all[t_idx][stream_idx] + coords_stream = targets_coords_all[t_idx][stream_idx] + + if preds_stream.size == 0 or coords_stream.size == 0: + _logger.warning(f"No prediction data to plot for stream {stream_name}.") + continue + + # Expected shape is (ens, ipoint, channel). Select first ensemble if present. + if preds_stream.ndim == 3: + preds_stream = preds_stream[0] + elif preds_stream.ndim != 2: + _logger.warning( + f"Unsupported prediction shape {preds_stream.shape} for stream {stream_name}." + ) + continue + + lat = coords_stream[:, 0] + lon = coords_stream[:, 1] + channels = _resolve_channel_names(stream_info, target_channels[stream_idx]) + + da = xr.DataArray( + preds_stream, + dims=("ipoint", "channel"), + coords={ + "ipoint": np.arange(preds_stream.shape[0]), + "channel": channels, + "lat": ("ipoint", lat), + "lon": ("ipoint", lon), + }, + ) + + plotter.stream = stream_name + plotter.run_id = config.get_run_id_from_config(cf) + plotter.fstep = forecast_offset + + selected_channels = [ + ch for ch in channels if _normalize_channel_name(ch) in headline_channels + ] + if not selected_channels: + _logger.warning( + f"No headline channels available for plotting stream {stream_name}." + ) + continue + + for varname in selected_channels: + data = da.sel(channel=varname).dropna(dim="ipoint") + channel_dir = base_plot_dir / varname + channel_dir.mkdir(parents=True, exist_ok=True) + epoch_tag = f"epoch_{mini_epoch:03d}" + plot_name = plotter.scatter_plot( + data, + channel_dir, + varname=varname, + regionname="global", + tag=epoch_tag, + title=f"{stream_name} - {varname} (fstep {forecast_offset})", + ) + src = channel_dir / f"{plot_name}.{plotter.image_format}" + dst = channel_dir / f"{epoch_tag}.{plotter.image_format}" + if src != dst and src.exists(): + src.replace(dst) \ No newline at end of file From 55b69c2c19745d4abc7422dfa455bf2b9c6a2e93 Mon Sep 17 00:00:00 2001 From: Matthias Date: Fri, 20 Feb 2026 12:47:45 +0100 Subject: [PATCH 210/344] Intermediate state. Single sample overfitting works --- config/config_diffusion.yml | 17 ++++++++------- src/weathergen/datasets/masking.py | 4 ++-- src/weathergen/model/diffusion.py | 9 ++------ src/weathergen/model/engines.py | 28 ++++++++++++++++++++----- src/weathergen/model/model.py | 1 - src/weathergen/model/model_interface.py | 8 ++++++- src/weathergen/train/trainer.py | 2 -- 7 files changed, 43 insertions(+), 26 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index f02916370..a02219e03 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -70,10 +70,10 @@ frequency_embedding_dim: 256 embedding_dim: 512 sigma_min: 0.002 sigma_max: 50000 -sigma_data: 0.5 +sigma_data: 1.7 # 0.5 rho: 7 -p_mean: -1.2 -p_std: 1.2 +p_mean: -1.2 # -1.2 +p_std: 1.2 # 1.2 # Encoder weights (set to null to not load a pretrained encoder) # chkpt_encoder_weights: "./models/dhb9q2yo/dhb9q2yo_chkpt00126.chkpt" chkpt_encoder_weights: "dhb9q2yo" @@ -166,11 +166,11 @@ training_config: time_window_len: 06:00:00 learning_rate_scheduling : - lr_start: 1e-5 - lr_max: 1e-4 + lr_start: 5e-5 # 1e-6? + lr_max: 1e-4 # 5e-5? lr_final_decay: 1e-6 lr_final: 0.0 - num_steps_warmup: 128 + num_steps_warmup: 64 num_steps_cooldown: 512 policy_warmup: "cosine" policy_decay: "constant" @@ -208,8 +208,9 @@ training_config: "forecasting" : { # masking strategy: "random", "healpix", "forecast" masking_strategy: "forecast", - masking_strategy_config: {diffusion_rn: True} - }, + masking_strategy_config: {diffusion_rn: True}, + # num_samples: 2 + } } forecast : diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index 844e1fe31..dbf866f32 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -464,8 +464,8 @@ def _generate_cell_mask( mask = np.ones(num_cells, dtype=np.bool) if "diffusion_rn" in masking_strategy_config: - # masking_params["noise_level_rn"] = self.rng.normal(0.0, 1.0) - masking_params["noise_level_rn"] = 0.1 + masking_params["noise_level_rn"] = self.rng.normal(0.0, 1.0) + # masking_params["noise_level_rn"] = 1.0 elif strategy == "healpix": # prepare healpix-based masking diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index ce00ca262..a0d8670ba 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -122,10 +122,7 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int) -> # Precondition input and feed through network x = self.preconditioner.precondition(x, c) - return c_skip * x + c_out * self.net( - c_in * x, fstep=fstep, noise_emb=noise_emb - ) # Eq. (7) in EDM paper - + return c_skip * x + c_out * self.net(c_in * x, fstep=fstep, noise_emb=noise_emb) # Eq. (7) in EDM paper def inference( self, @@ -146,9 +143,7 @@ def inference( / (num_steps - 1) * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)) ) ** self.rho - t_steps = torch.cat( - [self.net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])] - ) # t_N = 0 + t_steps = torch.cat([self.net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 # Main sampling loop. x_next = x * t_steps[0] diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 9ff74d96a..b15988e01 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -454,12 +454,30 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = self.fe_blocks.append( torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) ) - + + # def init_weights_final(m): + # if isinstance(m, torch.nn.Linear): + # torch.nn.init.normal_(m.weight, mean=0, std=0.001) + # if m.bias is not None: + # torch.nn.init.normal_(m.bias, mean=0, std=0.001) + + # self.fe_blocks.append( + # MLP( + # self.cf.ae_global_dim_embed, + # self.cf.ae_global_dim_embed, + # with_residual=True, + # dropout_rate=self.cf.fe_dropout_rate, + # norm_type=self.cf.norm_type, + # dim_aux=dim_aux, + # norm_eps=self.cf.mlp_norm_eps, + # with_noise_conditioning=self.cf.fe_diffusion_model, + # ) + # ) def init_weights_final(m): if isinstance(m, torch.nn.Linear): - torch.nn.init.normal_(m.weight, mean=0, std=0.001) + torch.nn.init.normal_(m.weight, mean=0, std=0.1) if m.bias is not None: - torch.nn.init.normal_(m.bias, mean=0, std=0.001) + torch.nn.init.normal_(m.bias, mean=0, std=0.1) for block in self.fe_blocks: block.apply(init_weights_final) @@ -479,8 +497,8 @@ def forward( noise_std = self.cf.get("fe_impute_latent_noise_std", 0.0) if noise_std > 0.0: tokens = tokens + torch.randn_like(tokens) * torch.norm(tokens) * noise_std - - # predict residual to last time step if requested + + # predict residual to last time step if requested forecast_residual = self.cf.get("forecast_residual", False) if forecast_residual: tokens_in = tokens diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index f61f6d435..500daa64f 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -601,7 +601,6 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: for step in batch.get_output_idxs(): # apply forecasting engine (if present) if self.forecast_engine: - # print(batch.samples[0].meta_info) tokens = self.forecast_engine(tokens, step, batch.samples[0].meta_info) # decoder predictions diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index 3110d2411..c773ddc3c 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -203,13 +203,19 @@ def load_model(cf, model, device, run_id: str, mini_epoch=-1): meta_sharded_sd = model.state_dict() maybe_sharded_sd = {} for param_name, full_tensor in params.items(): + # Skip loading forecast engine weights from checkpoint + if param_name.startswith("forecast_engine."): + logger.info(f"Skipping loading of forecast engine parameter: {param_name}. It will be initialized from scratch.") + continue sharded_meta_param = meta_sharded_sd.get(param_name) + if sharded_meta_param is None: + logger.warning(f"Parameter '{param_name}' not found in model state_dict. Available keys: {list(meta_sharded_sd.keys())}") + raise RuntimeError(f"sharded_meta_param is None for '{param_name}'. Checkpoint/model mismatch or missing weights. If you intend to skip loading this parameter, add logic to skip it.") sharded_tensor = distribute_tensor( full_tensor, sharded_meta_param.device_mesh, sharded_meta_param.placements, ) - # maybe_sharded_sd[param_name.replace("module.", "")] = nn.Parameter(sharded_tensor) maybe_sharded_sd[param_name] = torch.nn.Parameter(sharded_tensor) # choose `assign=True` for sharded model since we cannot call `copy_` on meta tensor mkeys, ukeys = model.load_state_dict(maybe_sharded_sd, strict=False, assign=True) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 99f8faad5..c7884b7f8 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -266,8 +266,6 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): else: self.validate_with_ema = False self.ema_model = None - # validate_with_ema is incompatible with student-teacher - self.validate_with_ema = False # TODO remove for testing only if self.validate_with_ema: meta_ema_model, _ = init_model_and_shard( cf, From be6cb24ded506248452169e81f37fd415ac9c14e Mon Sep 17 00:00:00 2001 From: Matthias Date: Fri, 20 Feb 2026 16:05:32 +0100 Subject: [PATCH 211/344] Successful single-sample overfitting on one GPU --- config/config_diffusion.yml | 4 +- src/weathergen/model/diffusion.py | 6 +- src/weathergen/model/engines.py | 138 ++++++++++++------------ src/weathergen/model/model_interface.py | 8 +- 4 files changed, 73 insertions(+), 83 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 4f1208479..f0f9b7bb0 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -55,7 +55,7 @@ num_register_tokens: 0 # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -fe_num_blocks: 16 +fe_num_blocks: 1 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True @@ -223,7 +223,7 @@ training_config: # validation config; full validation config is merge of training and validation config validation_config: - samples_per_mini_epoch: 16 + samples_per_mini_epoch: 15 shuffle: False start_date: 2012-06-01T00:00 diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index c2de0de23..319c5509b 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -93,13 +93,9 @@ def forward( c = 1 # TODO: add correct preconditioning (e.g., sample/s in previous time step) y = tokens - # TODO: add correct eta from meta_info eta = torch.tensor([meta_info["ERA5"].params["noise_level_rn"]], device=tokens.device) - # eta = torch.randn(1).to(device=tokens.device) - # eta = torch.tensor([metadata.noise_level_rn]).to(device=tokens.device) - # Compute sigma (noise level) from eta - # noise = torch.randn(y.shape, device=y.device) # now eta from MultiStreamDataSampler + # Compute sigma (noise level) from eta and create noise tensor sigma = (eta * self.p_std + self.p_mean).exp() n = torch.randn_like(y) * sigma diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 76236fb35..e5eedd2e3 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -404,75 +404,75 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = self.num_healpix_cells = num_healpix_cells self.fe_blocks = torch.nn.ModuleList() - # global_rate = int(1 / self.cf.forecast_att_dense_rate) - # if mode_cfg.get("forecast", {}).get("policy") is not None: - # for i in range(self.cf.fe_num_blocks): - # # Alternate between global and local attention - # if (i % global_rate == 0) or i + 1 == self.cf.ae_global_num_blocks: - # self.fe_blocks.append( - # MultiSelfAttentionHead( - # self.cf.ae_global_dim_embed, - # num_heads=self.cf.fe_num_heads, - # dropout_rate=self.cf.fe_dropout_rate, - # with_qk_lnorm=self.cf.fe_with_qk_lnorm, - # with_flash=self.cf.with_flash_attention, - # norm_type=self.cf.norm_type, - # dim_aux=dim_aux, - # norm_eps=self.cf.norm_eps, - # attention_dtype=get_dtype(self.cf.attention_dtype), - # with_noise_conditioning=self.cf.fe_diffusion_model, - # with_2d_rope=self.cf.get("rope_2D", False), - # ) - # ) - # else: - # self.fe_blocks.append( - # MultiSelfAttentionHeadLocal( - # self.cf.ae_global_dim_embed, - # num_heads=self.cf.fe_num_heads, - # qkv_len=self.num_healpix_cells * self.cf.ae_local_num_queries, - # block_factor=self.cf.ae_global_block_factor, - # dropout_rate=self.cf.fe_dropout_rate, - # with_qk_lnorm=self.cf.fe_with_qk_lnorm, - # with_flash=self.cf.with_flash_attention, - # norm_type=self.cf.norm_type, - # dim_aux=dim_aux, - # norm_eps=self.cf.norm_eps, - # attention_dtype=get_dtype(self.cf.attention_dtype), - # with_noise_conditioning=self.cf.fe_diffusion_model, - # with_2d_rope=self.cf.get("rope_2D", False), - # ) - # ) - # # Add MLP block - # self.fe_blocks.append( - # MLP( - # self.cf.ae_global_dim_embed, - # self.cf.ae_global_dim_embed, - # with_residual=True, - # dropout_rate=self.cf.fe_dropout_rate, - # norm_type=self.cf.norm_type, - # dim_aux=dim_aux, - # norm_eps=self.cf.mlp_norm_eps, - # with_noise_conditioning=self.cf.fe_diffusion_model, - # ) - # ) - # # Optionally, add LayerNorm after i-th layer - # if i in self.cf.get("fe_layer_norm_after_blocks", []): - # self.fe_blocks.append( - # torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) - # ) - - self.fe_blocks.append( - MLP( - self.cf.ae_global_dim_embed, - self.cf.ae_global_dim_embed, - with_residual=True, - dropout_rate=self.cf.fe_dropout_rate, - norm_type=self.cf.norm_type, - dim_aux=dim_aux, - norm_eps=self.cf.mlp_norm_eps, - with_noise_conditioning=self.cf.fe_diffusion_model, - ) - ) + global_rate = int(1 / self.cf.forecast_att_dense_rate) + if mode_cfg.get("forecast", {}).get("policy") is not None: + for i in range(self.cf.fe_num_blocks): + # Alternate between global and local attention + if (i % global_rate == 0) or i + 1 == self.cf.ae_global_num_blocks: + self.fe_blocks.append( + MultiSelfAttentionHead( + self.cf.ae_global_dim_embed, + num_heads=self.cf.fe_num_heads, + dropout_rate=self.cf.fe_dropout_rate, + with_qk_lnorm=self.cf.fe_with_qk_lnorm, + with_flash=self.cf.with_flash_attention, + norm_type=self.cf.norm_type, + dim_aux=dim_aux, + norm_eps=self.cf.norm_eps, + attention_dtype=get_dtype(self.cf.attention_dtype), + with_noise_conditioning=self.cf.fe_diffusion_model, + with_2d_rope=self.cf.get("rope_2D", False), + ) + ) + else: + self.fe_blocks.append( + MultiSelfAttentionHeadLocal( + self.cf.ae_global_dim_embed, + num_heads=self.cf.fe_num_heads, + qkv_len=self.num_healpix_cells * self.cf.ae_local_num_queries, + block_factor=self.cf.ae_global_block_factor, + dropout_rate=self.cf.fe_dropout_rate, + with_qk_lnorm=self.cf.fe_with_qk_lnorm, + with_flash=self.cf.with_flash_attention, + norm_type=self.cf.norm_type, + dim_aux=dim_aux, + norm_eps=self.cf.norm_eps, + attention_dtype=get_dtype(self.cf.attention_dtype), + with_noise_conditioning=self.cf.fe_diffusion_model, + with_2d_rope=self.cf.get("rope_2D", False), + ) + ) + # Add MLP block + self.fe_blocks.append( + MLP( + self.cf.ae_global_dim_embed, + self.cf.ae_global_dim_embed, + with_residual=True, + dropout_rate=self.cf.fe_dropout_rate, + norm_type=self.cf.norm_type, + dim_aux=dim_aux, + norm_eps=self.cf.mlp_norm_eps, + with_noise_conditioning=self.cf.fe_diffusion_model, + ) + ) + # Optionally, add LayerNorm after i-th layer + if i in self.cf.get("fe_layer_norm_after_blocks", []): + self.fe_blocks.append( + torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) + ) + + # self.fe_blocks.append( + # MLP( + # self.cf.ae_global_dim_embed, + # self.cf.ae_global_dim_embed, + # with_residual=True, + # dropout_rate=self.cf.fe_dropout_rate, + # norm_type=self.cf.norm_type, + # dim_aux=dim_aux, + # norm_eps=self.cf.mlp_norm_eps, + # with_noise_conditioning=self.cf.fe_diffusion_model, + # ) + # ) def init_weights_final(m): if isinstance(m, torch.nn.Linear): torch.nn.init.normal_(m.weight, mean=0, std=0.1) diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index c773ddc3c..3110d2411 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -203,19 +203,13 @@ def load_model(cf, model, device, run_id: str, mini_epoch=-1): meta_sharded_sd = model.state_dict() maybe_sharded_sd = {} for param_name, full_tensor in params.items(): - # Skip loading forecast engine weights from checkpoint - if param_name.startswith("forecast_engine."): - logger.info(f"Skipping loading of forecast engine parameter: {param_name}. It will be initialized from scratch.") - continue sharded_meta_param = meta_sharded_sd.get(param_name) - if sharded_meta_param is None: - logger.warning(f"Parameter '{param_name}' not found in model state_dict. Available keys: {list(meta_sharded_sd.keys())}") - raise RuntimeError(f"sharded_meta_param is None for '{param_name}'. Checkpoint/model mismatch or missing weights. If you intend to skip loading this parameter, add logic to skip it.") sharded_tensor = distribute_tensor( full_tensor, sharded_meta_param.device_mesh, sharded_meta_param.placements, ) + # maybe_sharded_sd[param_name.replace("module.", "")] = nn.Parameter(sharded_tensor) maybe_sharded_sd[param_name] = torch.nn.Parameter(sharded_tensor) # choose `assign=True` for sharded model since we cannot call `copy_` on meta tensor mkeys, ukeys = model.load_state_dict(maybe_sharded_sd, strict=False, assign=True) From 2c63c7e1907a934d680ded69d3c89b04ac2c0a98 Mon Sep 17 00:00:00 2001 From: Matthias Date: Fri, 20 Feb 2026 16:07:23 +0100 Subject: [PATCH 212/344] Minor config change --- config/config_diffusion.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index f0f9b7bb0..010713dbc 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -223,7 +223,7 @@ training_config: # validation config; full validation config is merge of training and validation config validation_config: - samples_per_mini_epoch: 15 + samples_per_mini_epoch: 16 shuffle: False start_date: 2012-06-01T00:00 From 4414fe62513c20c15e1e78c38504cd2b85a19745 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Sat, 21 Feb 2026 16:38:17 +0100 Subject: [PATCH 213/344] Adding missing reset() function for FSDP --- src/weathergen/model/encoder.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index b15f3ce86..49d860ce1 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -326,3 +326,6 @@ def assimilate_local( ).flatten(1, 2) return tokens_global, posteriors + + def reset_parameters(self): + return From c9177777d6ab926b085810154455e6d7ee841ed0 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Sat, 21 Feb 2026 16:39:04 +0100 Subject: [PATCH 214/344] Linting --- src/weathergen/model/attention.py | 1 - src/weathergen/model/diffusion.py | 14 +++++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 8b07f7496..7fedc61d2 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -561,7 +561,6 @@ def __init__( latent_space_dim=dim_embed, noise_emb_dim=512, dtype=self.dtype ) - def forward(self, x, coords=None, emb=None, ada_ln_aux=None): if self.with_residual: x_in = x diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 319c5509b..e7bda33c7 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -60,7 +60,11 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast self.cur_token = None # TODO: re move after single sample experiments def forward( - self, tokens: torch.Tensor, fstep: int, meta_info: dict[str, SampleMetaData], coords: torch.Tensor = None + self, + tokens: torch.Tensor, + fstep: int, + meta_info: dict[str, SampleMetaData], + coords: torch.Tensor = None, ) -> torch.Tensor: """ Model forward call during training. Unpacks the conditioning c = [x_{t-k}, ..., x_{t}], the @@ -118,7 +122,9 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int) -> # Precondition input and feed through network x = self.preconditioner.precondition(x, c) - return c_skip * x + c_out * self.net(c_in * x, fstep=fstep, noise_emb=noise_emb) # Eq. (7) in EDM paper + return c_skip * x + c_out * self.net( + c_in * x, fstep=fstep, noise_emb=noise_emb + ) # Eq. (7) in EDM paper def inference( self, @@ -139,7 +145,9 @@ def inference( / (num_steps - 1) * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)) ) ** self.rho - t_steps = torch.cat([self.net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 + t_steps = torch.cat( + [self.net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])] + ) # t_N = 0 # Main sampling loop. x_next = x * t_steps[0] From 4ae7c13c53d1eb5714bc066e6f4ebd3ddff7c8cf Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Sat, 21 Feb 2026 16:39:18 +0100 Subject: [PATCH 215/344] Linting --- src/weathergen/model/engines.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index e5eedd2e3..ff2ddd2a4 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -490,7 +490,6 @@ def forward( noise_emb: torch.Tensor = None, coords: torch.Tensor = None, ) -> torch.Tensor: - # aux_info is forecast step, if not disabled with cf.forecast_with_step_conditioning # aux_info = torch.tensor([fstep], dtype=torch.float32, device="cuda") if self.training: @@ -498,7 +497,7 @@ def forward( noise_std = self.cf.get("fe_impute_latent_noise_std", 0.0) if noise_std > 0.0: tokens = tokens + torch.randn_like(tokens) * torch.norm(tokens) * noise_std - + # predict residual to last time step if requested forecast_residual = self.cf.get("forecast_residual", False) if forecast_residual: From 268d34fdbd33a6e0c600c83cdb87749a9b707284 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Sat, 21 Feb 2026 16:39:33 +0100 Subject: [PATCH 216/344] Linting --- src/weathergen/model/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 7c67f6f1b..9d4dc4d49 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -656,7 +656,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: tokens, step, meta_info=batch.samples[0].meta_info, - coords=model_params.rope_coords + coords=model_params.rope_coords, ) # decoder predictions From 6a487d989cf9e9ac2076fd03c1c5471d9e03b5ae Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Sat, 21 Feb 2026 16:41:20 +0100 Subject: [PATCH 217/344] Workding on FSDP --- src/weathergen/model/model_interface.py | 38 +++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index 3110d2411..a5a3d3c6a 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -180,7 +180,7 @@ def init_model_and_shard( return model, model_params - + def load_model(cf, model, device, run_id: str, mini_epoch=-1): """Loads model state from checkpoint and checks for missing and unused keys. Args: @@ -200,10 +200,30 @@ def load_model(cf, model, device, run_id: str, mini_epoch=-1): is_model_sharded = cf.with_ddp and cf.with_fsdp if is_model_sharded: + # model_has_prefix_module = list(model.state_dict().keys())[0].split(".")[0] == "module" + # params_has_prefix_module = list(params.keys())[0].split(".")[0] == "module" + # if model_has_prefix_module and not params_has_prefix_module: + # # add "module." prefix + # params_temp = {} + # for k in params.keys(): + # params_temp["module." + k] = params[k] + # params = params_temp + # elif not model_has_prefix_module and params_has_prefix_module: + # # remove "module." prefix + # params_temp = {} + # for k in params.keys(): + # params_temp[k.replace("module.", "")] = params[k] + # params = params_temp + meta_sharded_sd = model.state_dict() maybe_sharded_sd = {} for param_name, full_tensor in params.items(): sharded_meta_param = meta_sharded_sd.get(param_name) + if ( + sharded_meta_param is None + or type(sharded_meta_param) is not torch.distributed.tensor.DTensor + ): + continue sharded_tensor = distribute_tensor( full_tensor, sharded_meta_param.device_mesh, @@ -309,8 +329,22 @@ def get_target_aux_calculator( # create target_and_aux_calc if target_and_aux_calc == "Physical": target_aux = PhysicalTargetAndAux(loss_cfg, model) + elif target_and_aux_calc == "DiffusionLatentTargetEncoder": - target_aux = DiffusionLatentTargetEncoder(model) + model, _ = init_model_and_shard( + cf, + dataset, + cf.get("load_chkpt", {}).get("run_id", None), + cf.get("load_chkpt", {}).get("epoch", -1), + "student", + device, + with_ddp=False, + with_fsdp=False, + overrides=target_and_aux_calc_params.get("model_param_overrides", {}), + ) + target_aux = DiffusionLatentTargetEncoder( + model, is_model_sharded=(cf.with_ddp and cf.with_fsdp) + ) elif target_and_aux_calc == "EMATeacher": # work around for problems with FSDP2 assert not cf.with_fsdp, "EMATeacher not supported with FSDP(2) at the moment" From 351e8f91d9671efac05d81d2d8c1364fe844f223 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Sat, 21 Feb 2026 16:41:38 +0100 Subject: [PATCH 218/344] Working on FSDP --- .../train/target_and_aux_diffusion.py | 37 ++++++++++++++++--- 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/src/weathergen/train/target_and_aux_diffusion.py b/src/weathergen/train/target_and_aux_diffusion.py index e4f71f28f..ecae6b42e 100644 --- a/src/weathergen/train/target_and_aux_diffusion.py +++ b/src/weathergen/train/target_and_aux_diffusion.py @@ -11,9 +11,34 @@ class DiffusionLatentTargetEncoder(TargetAndAuxModuleBase): - def __init__(self, model): + def __init__(self, encoder, is_model_sharded=True): # Todo: make sure this is a frozen clone or forward without gradients in compute() - self.encoder = model.encoder + self.encoder = encoder + + self.is_model_sharded = is_model_sharded + # Build a name → param map once + self.src_params = dict(self.encoder.named_parameters()) + + self.reset() + + @torch.no_grad() + def reset(self): + """ + This function resets the EMAModel to be the same as the Model. + + It operates via the state_dict to be able to deal with sharded tensors in case + FSDP2 is used. + """ + self.encoder.to_empty(device="cuda") + for p in self.encoder.parameters(): + p.requires_grad = False + maybe_sharded_sd = self.encoder.state_dict() + mkeys, ukeys = self.encoder.load_state_dict(maybe_sharded_sd, strict=False, assign=False) + self.encoder.eval() + + def update_state_post_opt_step(self, istep, batch, model, **kwargs) -> None: + if self.is_model_sharded: + self.encoder.reshard() def compute( self, @@ -28,16 +53,16 @@ def compute( batch.samples[0].meta_info["ERA5"].params["noise_level_rn"] ) # TODO: adjust for multiple streams - #TODO: check if there are scenarios where the encoder needs to be set to eval + # TODO: check if there are scenarios where the encoder needs to be set to eval with torch.no_grad(): - tokens, posteriors = self.encoder(model_params=model_params, batch=batch) - #NOTE: must not set to train afterwards unless it was already in train + tokens, posteriors = self.encoder.encoder(model_params=model_params, batch=batch) + # NOTE: must not set to train afterwards unless it was already in train output_idxs = batch.get_output_idxs() assert len(output_idxs) > 0 target_aux_output = TargetAuxOutput(batch.get_output_len(), output_idxs) - + # TODO: currently hard-coding 0 target_aux_output.add_latent_target(0, "diffusion_latent", tokens) From fbc7cd1dcdc7b018162092db0333b4baa1c84c78 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Sat, 21 Feb 2026 16:41:52 +0100 Subject: [PATCH 219/344] Linting --- src/weathergen/utils/validation_io.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 5d96a094a..f5ff0d096 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -8,7 +8,6 @@ # nor does it submit to any jurisdiction. import logging -from pathlib import Path import numpy as np import torch @@ -248,9 +247,7 @@ def write_output( ch for ch in channels if _normalize_channel_name(ch) in headline_channels ] if not selected_channels: - _logger.warning( - f"No headline channels available for plotting stream {stream_name}." - ) + _logger.warning(f"No headline channels available for plotting stream {stream_name}.") continue for varname in selected_channels: @@ -269,4 +266,4 @@ def write_output( src = channel_dir / f"{plot_name}.{plotter.image_format}" dst = channel_dir / f"{epoch_tag}.{plotter.image_format}" if src != dst and src.exists(): - src.replace(dst) \ No newline at end of file + src.replace(dst) From 873f7b35059428cfd0d32bee13123ba868dd8785 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Mon, 23 Feb 2026 17:02:03 +0100 Subject: [PATCH 220/344] minor config changes --- config/config_diffusion.yml | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 010713dbc..057b2595b 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -59,7 +59,7 @@ fe_num_blocks: 1 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True -fe_diffusion_model: False +fe_diffusion_model: True fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) @@ -72,12 +72,9 @@ sigma_min: 0.002 sigma_max: 50000 sigma_data: 1.7 # 0.5, but std of tokens is 1.7 rho: 7 -p_mean: 3.0 # -1.2 +p_mean: -1.2 #3.0 # -1.2 p_std: 1.2 # 1.2 -# Encoder weights (set to null to not load a pretrained encoder) -# chkpt_encoder_weights: "./models/dhb9q2yo/dhb9q2yo_chkpt00126.chkpt" -chkpt_encoder_weights: "dhb9q2yo" -chkpt_encoder_mini_epoch: 126 + healpix_level: 5 @@ -98,7 +95,7 @@ latent_noise_deterministic_latents: True freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" -load_chkpt: {} +load_chkpt: {'run_id': 'aev85iny', 'epoch': -1} norm_type: "LayerNorm" @@ -166,8 +163,8 @@ training_config: time_window_len: 06:00:00 learning_rate_scheduling : - lr_start: 5e-5 # 1e-6? - lr_max: 1e-4 # 5e-5? + lr_start: 1e-6 #5e-5 + lr_max: 5e-5 #1e-4 lr_final_decay: 1e-6 lr_final: 0.0 num_steps_warmup: 64 From 7149866e41b1b96a5d79553fac93998492d44de2 Mon Sep 17 00:00:00 2001 From: Matthias Date: Tue, 24 Feb 2026 18:00:32 +0100 Subject: [PATCH 221/344] Activating diffusion model --- config/config_diffusion.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 010713dbc..d42909abf 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -59,7 +59,7 @@ fe_num_blocks: 1 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True -fe_diffusion_model: False +fe_diffusion_model: True fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) @@ -72,7 +72,7 @@ sigma_min: 0.002 sigma_max: 50000 sigma_data: 1.7 # 0.5, but std of tokens is 1.7 rho: 7 -p_mean: 3.0 # -1.2 +p_mean: 1.2 # -1.2 p_std: 1.2 # 1.2 # Encoder weights (set to null to not load a pretrained encoder) # chkpt_encoder_weights: "./models/dhb9q2yo/dhb9q2yo_chkpt00126.chkpt" From 610334cf6d92b07849c7f02f80a90c486911ce76 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 25 Feb 2026 12:53:02 +0100 Subject: [PATCH 222/344] temp set MLP --- config/config_diffusion.yml | 2 +- src/weathergen/model/engines.py | 136 ++++++++++++++++---------------- 2 files changed, 69 insertions(+), 69 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 057b2595b..105ba8710 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -206,7 +206,7 @@ training_config: # masking strategy: "random", "healpix", "forecast" masking_strategy: "forecast", masking_strategy_config: {diffusion_rn: True}, - num_samples: 3 + num_samples: 1 } } diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index ff2ddd2a4..e0312c75a 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -405,74 +405,74 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = self.fe_blocks = torch.nn.ModuleList() global_rate = int(1 / self.cf.forecast_att_dense_rate) - if mode_cfg.get("forecast", {}).get("policy") is not None: - for i in range(self.cf.fe_num_blocks): - # Alternate between global and local attention - if (i % global_rate == 0) or i + 1 == self.cf.ae_global_num_blocks: - self.fe_blocks.append( - MultiSelfAttentionHead( - self.cf.ae_global_dim_embed, - num_heads=self.cf.fe_num_heads, - dropout_rate=self.cf.fe_dropout_rate, - with_qk_lnorm=self.cf.fe_with_qk_lnorm, - with_flash=self.cf.with_flash_attention, - norm_type=self.cf.norm_type, - dim_aux=dim_aux, - norm_eps=self.cf.norm_eps, - attention_dtype=get_dtype(self.cf.attention_dtype), - with_noise_conditioning=self.cf.fe_diffusion_model, - with_2d_rope=self.cf.get("rope_2D", False), - ) - ) - else: - self.fe_blocks.append( - MultiSelfAttentionHeadLocal( - self.cf.ae_global_dim_embed, - num_heads=self.cf.fe_num_heads, - qkv_len=self.num_healpix_cells * self.cf.ae_local_num_queries, - block_factor=self.cf.ae_global_block_factor, - dropout_rate=self.cf.fe_dropout_rate, - with_qk_lnorm=self.cf.fe_with_qk_lnorm, - with_flash=self.cf.with_flash_attention, - norm_type=self.cf.norm_type, - dim_aux=dim_aux, - norm_eps=self.cf.norm_eps, - attention_dtype=get_dtype(self.cf.attention_dtype), - with_noise_conditioning=self.cf.fe_diffusion_model, - with_2d_rope=self.cf.get("rope_2D", False), - ) - ) - # Add MLP block - self.fe_blocks.append( - MLP( - self.cf.ae_global_dim_embed, - self.cf.ae_global_dim_embed, - with_residual=True, - dropout_rate=self.cf.fe_dropout_rate, - norm_type=self.cf.norm_type, - dim_aux=dim_aux, - norm_eps=self.cf.mlp_norm_eps, - with_noise_conditioning=self.cf.fe_diffusion_model, - ) - ) - # Optionally, add LayerNorm after i-th layer - if i in self.cf.get("fe_layer_norm_after_blocks", []): - self.fe_blocks.append( - torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) - ) - - # self.fe_blocks.append( - # MLP( - # self.cf.ae_global_dim_embed, - # self.cf.ae_global_dim_embed, - # with_residual=True, - # dropout_rate=self.cf.fe_dropout_rate, - # norm_type=self.cf.norm_type, - # dim_aux=dim_aux, - # norm_eps=self.cf.mlp_norm_eps, - # with_noise_conditioning=self.cf.fe_diffusion_model, - # ) - # ) + # if mode_cfg.get("forecast", {}).get("policy") is not None: + # for i in range(self.cf.fe_num_blocks): + # # Alternate between global and local attention + # if (i % global_rate == 0) or i + 1 == self.cf.ae_global_num_blocks: + # self.fe_blocks.append( + # MultiSelfAttentionHead( + # self.cf.ae_global_dim_embed, + # num_heads=self.cf.fe_num_heads, + # dropout_rate=self.cf.fe_dropout_rate, + # with_qk_lnorm=self.cf.fe_with_qk_lnorm, + # with_flash=self.cf.with_flash_attention, + # norm_type=self.cf.norm_type, + # dim_aux=dim_aux, + # norm_eps=self.cf.norm_eps, + # attention_dtype=get_dtype(self.cf.attention_dtype), + # with_noise_conditioning=self.cf.fe_diffusion_model, + # with_2d_rope=self.cf.get("rope_2D", False), + # ) + # ) + # else: + # self.fe_blocks.append( + # MultiSelfAttentionHeadLocal( + # self.cf.ae_global_dim_embed, + # num_heads=self.cf.fe_num_heads, + # qkv_len=self.num_healpix_cells * self.cf.ae_local_num_queries, + # block_factor=self.cf.ae_global_block_factor, + # dropout_rate=self.cf.fe_dropout_rate, + # with_qk_lnorm=self.cf.fe_with_qk_lnorm, + # with_flash=self.cf.with_flash_attention, + # norm_type=self.cf.norm_type, + # dim_aux=dim_aux, + # norm_eps=self.cf.norm_eps, + # attention_dtype=get_dtype(self.cf.attention_dtype), + # with_noise_conditioning=self.cf.fe_diffusion_model, + # with_2d_rope=self.cf.get("rope_2D", False), + # ) + # ) + # # Add MLP block + # self.fe_blocks.append( + # MLP( + # self.cf.ae_global_dim_embed, + # self.cf.ae_global_dim_embed, + # with_residual=True, + # dropout_rate=self.cf.fe_dropout_rate, + # norm_type=self.cf.norm_type, + # dim_aux=dim_aux, + # norm_eps=self.cf.mlp_norm_eps, + # with_noise_conditioning=self.cf.fe_diffusion_model, + # ) + # ) + # # Optionally, add LayerNorm after i-th layer + # if i in self.cf.get("fe_layer_norm_after_blocks", []): + # self.fe_blocks.append( + # torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) + # ) + + self.fe_blocks.append( + MLP( + self.cf.ae_global_dim_embed, + self.cf.ae_global_dim_embed, + with_residual=True, + dropout_rate=self.cf.fe_dropout_rate, + norm_type=self.cf.norm_type, + dim_aux=dim_aux, + norm_eps=self.cf.mlp_norm_eps, + with_noise_conditioning=self.cf.fe_diffusion_model, + ) + ) def init_weights_final(m): if isinstance(m, torch.nn.Linear): torch.nn.init.normal_(m.weight, mean=0, std=0.1) From c1cf8f5914481976bc408709e3ad0678bdd9127a Mon Sep 17 00:00:00 2001 From: Matthias Date: Fri, 27 Feb 2026 12:46:51 +0100 Subject: [PATCH 223/344] Combined physical and latent loss experiments --- config/config_diffusion.yml | 12 +-- src/weathergen/datasets/masking.py | 17 ++-- src/weathergen/model/engines.py | 128 ++++++++++++-------------- src/weathergen/model/model.py | 11 +++ src/weathergen/utils/validation_io.py | 41 ++++++++- 5 files changed, 122 insertions(+), 87 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 105ba8710..031b9693e 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -55,7 +55,7 @@ num_register_tokens: 0 # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -fe_num_blocks: 1 +fe_num_blocks: 2 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True @@ -70,9 +70,9 @@ frequency_embedding_dim: 256 embedding_dim: 512 sigma_min: 0.002 sigma_max: 50000 -sigma_data: 1.7 # 0.5, but std of tokens is 1.7 +sigma_data: 0.5 rho: 7 -p_mean: -1.2 #3.0 # -1.2 +p_mean: 0.0 # -1.2 p_std: 1.2 # 1.2 @@ -152,7 +152,7 @@ training_config: # training_mode: "masking", "student_teacher", "latent_loss" training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] - num_mini_epochs: 25 + num_mini_epochs: 150 samples_per_mini_epoch: 66 shuffle: True @@ -187,7 +187,7 @@ training_config: losses : { "physical": { type: LossPhysical, - weight: 0.0, + weight: 0.1, loss_fcts: { "mse": {}, }, @@ -195,7 +195,7 @@ training_config: }, "latent_diff": { type: LossLatentDiffusion, - weight: 1.0, + weight: 0.9, target_and_aux_calc: DiffusionLatentTargetEncoder, loss_fcts: { "mse": { }, }, } diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index f84111541..f2b8d6621 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -27,17 +27,22 @@ def __len__(self): return len(self.masks) def add_mask(self, mask, params, cfg, losses, idx, correspondence, relationship): + # TODO: REVERT TO ORIGINAL CODE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. + # If noise_level_rn exists in params, also add it to global_params for easier downstream access + global_params = { + "idx": idx, + "correspondence": correspondence, + "loss": losses, + "relationship": relationship, + } + if "noise_level_rn" in params: + global_params["noise_level_rn"] = params["noise_level_rn"] self.masks += [mask] self.metadata += [ SampleMetaData( params={**cfg, **params}, mask=mask, - global_params={ - "idx": idx, - "correspondence": correspondence, - "loss": losses, - "relationship": relationship, - }, + global_params=global_params, ) ] diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index e0312c75a..f9ec73a98 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -405,79 +405,67 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = self.fe_blocks = torch.nn.ModuleList() global_rate = int(1 / self.cf.forecast_att_dense_rate) - # if mode_cfg.get("forecast", {}).get("policy") is not None: - # for i in range(self.cf.fe_num_blocks): - # # Alternate between global and local attention - # if (i % global_rate == 0) or i + 1 == self.cf.ae_global_num_blocks: - # self.fe_blocks.append( - # MultiSelfAttentionHead( - # self.cf.ae_global_dim_embed, - # num_heads=self.cf.fe_num_heads, - # dropout_rate=self.cf.fe_dropout_rate, - # with_qk_lnorm=self.cf.fe_with_qk_lnorm, - # with_flash=self.cf.with_flash_attention, - # norm_type=self.cf.norm_type, - # dim_aux=dim_aux, - # norm_eps=self.cf.norm_eps, - # attention_dtype=get_dtype(self.cf.attention_dtype), - # with_noise_conditioning=self.cf.fe_diffusion_model, - # with_2d_rope=self.cf.get("rope_2D", False), - # ) - # ) - # else: - # self.fe_blocks.append( - # MultiSelfAttentionHeadLocal( - # self.cf.ae_global_dim_embed, - # num_heads=self.cf.fe_num_heads, - # qkv_len=self.num_healpix_cells * self.cf.ae_local_num_queries, - # block_factor=self.cf.ae_global_block_factor, - # dropout_rate=self.cf.fe_dropout_rate, - # with_qk_lnorm=self.cf.fe_with_qk_lnorm, - # with_flash=self.cf.with_flash_attention, - # norm_type=self.cf.norm_type, - # dim_aux=dim_aux, - # norm_eps=self.cf.norm_eps, - # attention_dtype=get_dtype(self.cf.attention_dtype), - # with_noise_conditioning=self.cf.fe_diffusion_model, - # with_2d_rope=self.cf.get("rope_2D", False), - # ) - # ) - # # Add MLP block - # self.fe_blocks.append( - # MLP( - # self.cf.ae_global_dim_embed, - # self.cf.ae_global_dim_embed, - # with_residual=True, - # dropout_rate=self.cf.fe_dropout_rate, - # norm_type=self.cf.norm_type, - # dim_aux=dim_aux, - # norm_eps=self.cf.mlp_norm_eps, - # with_noise_conditioning=self.cf.fe_diffusion_model, - # ) - # ) - # # Optionally, add LayerNorm after i-th layer - # if i in self.cf.get("fe_layer_norm_after_blocks", []): - # self.fe_blocks.append( - # torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) - # ) - - self.fe_blocks.append( - MLP( - self.cf.ae_global_dim_embed, - self.cf.ae_global_dim_embed, - with_residual=True, - dropout_rate=self.cf.fe_dropout_rate, - norm_type=self.cf.norm_type, - dim_aux=dim_aux, - norm_eps=self.cf.mlp_norm_eps, - with_noise_conditioning=self.cf.fe_diffusion_model, - ) - ) + if mode_cfg.get("forecast", {}).get("policy") is not None: + for i in range(self.cf.fe_num_blocks): + # Alternate between global and local attention + if (i % global_rate == 0) or i + 1 == self.cf.ae_global_num_blocks: + self.fe_blocks.append( + MultiSelfAttentionHead( + self.cf.ae_global_dim_embed, + num_heads=self.cf.fe_num_heads, + dropout_rate=self.cf.fe_dropout_rate, + with_qk_lnorm=self.cf.fe_with_qk_lnorm, + with_flash=self.cf.with_flash_attention, + norm_type=self.cf.norm_type, + dim_aux=dim_aux, + norm_eps=self.cf.norm_eps, + attention_dtype=get_dtype(self.cf.attention_dtype), + with_noise_conditioning=self.cf.fe_diffusion_model, + with_2d_rope=self.cf.get("rope_2D", False), + ) + ) + else: + self.fe_blocks.append( + MultiSelfAttentionHeadLocal( + self.cf.ae_global_dim_embed, + num_heads=self.cf.fe_num_heads, + qkv_len=self.num_healpix_cells * self.cf.ae_local_num_queries, + block_factor=self.cf.ae_global_block_factor, + dropout_rate=self.cf.fe_dropout_rate, + with_qk_lnorm=self.cf.fe_with_qk_lnorm, + with_flash=self.cf.with_flash_attention, + norm_type=self.cf.norm_type, + dim_aux=dim_aux, + norm_eps=self.cf.norm_eps, + attention_dtype=get_dtype(self.cf.attention_dtype), + with_noise_conditioning=self.cf.fe_diffusion_model, + with_2d_rope=self.cf.get("rope_2D", False), + ) + ) + # Add MLP block + self.fe_blocks.append( + MLP( + self.cf.ae_global_dim_embed, + self.cf.ae_global_dim_embed, + with_residual=True, + dropout_rate=self.cf.fe_dropout_rate, + norm_type=self.cf.norm_type, + dim_aux=dim_aux, + norm_eps=self.cf.mlp_norm_eps, + with_noise_conditioning=self.cf.fe_diffusion_model, + ) + ) + # Optionally, add LayerNorm after i-th layer + if i in self.cf.get("fe_layer_norm_after_blocks", []): + self.fe_blocks.append( + torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) + ) + def init_weights_final(m): if isinstance(m, torch.nn.Linear): - torch.nn.init.normal_(m.weight, mean=0, std=0.1) + torch.nn.init.normal_(m.weight, mean=0, std=0.001) if m.bias is not None: - torch.nn.init.normal_(m.bias, mean=0, std=0.1) + torch.nn.init.normal_(m.bias, mean=0, std=0.001) for block in self.fe_blocks: block.apply(init_weights_final) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 9d4dc4d49..625987f34 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -648,6 +648,13 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # collapse along input step dimension tokens = tokens.reshape(shape).sum(axis=1) + # Normalize tokens + # TODO: REMOVE THIS LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. + t_mean = tokens.mean() + t_std = tokens.std() + tokens = (tokens - t_mean) / (t_std + 1e-6) + tokens = torch.clamp(tokens, -5.0, 5.0) + # roll-out in latent space, iterate and generate output over requested output steps for step in batch.get_output_idxs(): # apply forecasting engine (if present) @@ -659,6 +666,10 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: coords=model_params.rope_coords, ) + # Un-normalize tokens + # TODO: REMOVE THIS AS ABOVE. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. + tokens = tokens * (t_std + 1e-6) + t_mean + # decoder predictions output = self.predict_decoders(model_params, step, tokens, batch, output) # latent predictions (raw and with SSL heads) diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index f5ff0d096..de08cf4d6 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -21,11 +21,16 @@ _logger = logging.getLogger(__name__) +# TODO: REMOVE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. +i = 0 + +# TODO: REMOVE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. def _normalize_channel_name(name: str) -> str: return str(name).lower().replace("_", "").replace(" ", "") +# TODO: REMOVE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. def _resolve_channel_names(stream_info, raw_channels): if not raw_channels: return raw_channels @@ -65,7 +70,9 @@ def write_output( """ Interface for writing model output """ - + # TODO: REMOVE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. + global i + # TODO: how to handle multiple physical loss terms outputs_physical = [ loss_name @@ -90,7 +97,8 @@ def write_output( targets_coords_all += [[]] targets_times_all += [[]] targets_lens += [[]] - for stream_info in cf.streams: + noise_levels = [] # TODO: REMOVE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. + for stream_idx, stream_info in enumerate(cf.streams): sname = stream_info["name"] # predictions preds = model_output.get_physical_prediction(t_idx, sname) @@ -99,6 +107,18 @@ def write_output( preds_s, targets_s, t_coords_s, t_times_s = [], [], [], [] targets_lens[-1] += [[]] + # TODO: REMOVE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. + # Try to extract noise_level_rn from batch metadata if present + noise_level = None + if hasattr(batch, "target_samples"): + # Try to get from first sample for this stream + samples = batch.target_samples.get_samples() + if samples and hasattr(samples[0], "meta_info") and sname in samples[0].meta_info: + meta = samples[0].meta_info[sname] + if meta and hasattr(meta, "global_params") and meta.global_params: + noise_level = meta.global_params.get("noise_level_rn", None) + noise_levels.append(noise_level) + # handle forcing streams or if sample is empty if preds is None: # preds are empty so create copy of target and add ensemble dimension @@ -199,11 +219,14 @@ def write_output( for subset in data.items(): zio.write_zarr(subset) + # TODO: REMOVE EVERYTHING BELOW THIS LINE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. + # Prepare prediction data for Plotter (scatter plot expects lat/lon coords on ipoint). base_plot_dir = config.get_path_run(cf) / "plots" / "validation" base_plot_dir.mkdir(parents=True, exist_ok=True) plotter = Plotter({"image_format": "png", "dpi_val": 150}, base_plot_dir) - headline_channels = {"2t", "z500", "q850", "10u", "10v"} + # headline_channels = {"2t", "z500", "q850", "10u", "10v"} + headline_channels = {"2t"} t_idx = 0 for stream_idx, stream_info in enumerate(cf.streams): @@ -254,16 +277,24 @@ def write_output( data = da.sel(channel=varname).dropna(dim="ipoint") channel_dir = base_plot_dir / varname channel_dir.mkdir(parents=True, exist_ok=True) - epoch_tag = f"epoch_{mini_epoch:03d}" + epoch_tag = f"epoch_{mini_epoch:03d}_{i%3}" + # Add noise_level_rn to title if present for this stream + noise_level = noise_levels[stream_idx] + if noise_level is not None: + title = f"{stream_name} - {varname} (fstep {forecast_offset}) | noise_level_rn={noise_level:.4f}" + else: + title = f"{stream_name} - {varname} (fstep {forecast_offset})" + plot_name = plotter.scatter_plot( data, channel_dir, varname=varname, regionname="global", tag=epoch_tag, - title=f"{stream_name} - {varname} (fstep {forecast_offset})", + title=title, ) src = channel_dir / f"{plot_name}.{plotter.image_format}" dst = channel_dir / f"{epoch_tag}.{plotter.image_format}" if src != dst and src.exists(): src.replace(dst) + i += 1 From 506089cd96dc93b0097cb5388fadbb5edb0d8d05 Mon Sep 17 00:00:00 2001 From: Matthias Date: Fri, 27 Feb 2026 14:36:59 +0100 Subject: [PATCH 224/344] Adjust diffusion config to 3 samples per GPU --- config/config_diffusion.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 031b9693e..29474bfa2 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -206,7 +206,7 @@ training_config: # masking strategy: "random", "healpix", "forecast" masking_strategy: "forecast", masking_strategy_config: {diffusion_rn: True}, - num_samples: 1 + num_samples: 3 } } From ffe89c27eea1a4c9df6f17d70de2d31b7b7da9f4 Mon Sep 17 00:00:00 2001 From: Matthias Date: Fri, 27 Feb 2026 15:28:39 +0100 Subject: [PATCH 225/344] Pull plot_train from develop --- src/weathergen/utils/plot_training.py | 232 +++++++++++++++----------- src/weathergen/utils/train_logger.py | 16 +- 2 files changed, 146 insertions(+), 102 deletions(-) diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index ce3f2f51a..a44f1623e 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -257,14 +257,15 @@ def plot_loss_avg(plot_dir: Path, runs_ids, runs_data, runs_active, stage=TRAIN, y_vals = np.array(run_data.train["loss_avg_mean"]) elif stage == VAL: x_vals = np.array(run_data.val["num_samples"]) - # y_vals = np.array(run_data.val["LossLatentSSLStudentTeacher.loss_avg"]) y_vals = np.array(run_data.val["loss_avg_mean"]) else: assert False + mask = np.logical_and(~np.isnan(x_vals), ~np.isnan(y_vals)) + plt.plot( - x_vals, - y_vals, + x_vals[mask], + y_vals[mask], color=colors[i_run % len(colors)], ) # legend_str += [ run_id + " : " + runs_ids[run_id][1]] @@ -300,6 +301,7 @@ def plot_loss_per_stream( stream_names: list[str], plot_dir: Path, errs: list[str], + channels: list[str], x_axis: str = "samples", x_type: str = "step", x_scale_log: bool = False, @@ -336,92 +338,105 @@ def plot_loss_per_stream( prop_cycle = plt.rcParams["axes.prop_cycle"] colors = prop_cycle.by_key()["color"] + ["r", "g", "b", "k", "m", "y"] - for stream_name in stream_names: - _fig = plt.figure(figsize=(10, 7), dpi=300) - - legend_strs = [] - min_val = np.finfo(np.float32).max - max_val = 0.0 - for mode in modes: - legend_strs += [[]] - for err in errs: - linestyle = "-" if mode == "train" else ("--x" if len(modes) > 1 else "-x") - linestyle = ":" if "stddev" in err else linestyle - alpha = 1.0 - if "train" in modes and "val" in modes: - alpha = 0.35 if "train" in mode else alpha - - for j, run_data in enumerate(runs_data): - run_data_mode = run_data.by_mode(mode) - if run_data_mode.is_empty(): - continue - # find the col of the request x-axis (e.g. samples) - x_col = next(filter(lambda c: x_axis in c, run_data_mode.columns)) - # find the cols of the requested metric (e.g. mse) for all streams - data_cols = filter( - lambda c: err in c and stream_name.lower() in c.lower(), - run_data_mode.columns, - ) - - for col in data_cols: - x_vals = np.array(run_data_mode[x_col]) - y_data = np.array(run_data_mode[col]) - - plt.plot( - x_vals, - y_data, - linestyle, - color=colors[j % len(colors)], - alpha=alpha, - ) - legend_strs[-1] += [ - ("R" if runs_active[j] else "X") - + " : " - + run_data.run_id - + " : " - + runs_ids[run_data.run_id][1] - + ": " - + col - ] - - # skip all-nan slices - if (~np.isnan(y_data)).sum() > 0: - min_val = np.min([min_val, np.nanmin(y_data)]) - max_val = np.max([max_val, np.nanmax(y_data)]) - - # TODO: ensure that legend is plotted with full opacity - legend_str = legend_strs[0] - if len(legend_str) < 1: + for channel in channels: + for stream_name in stream_names: + _fig = plt.figure(figsize=(10, 7), dpi=300) + + legend_strs = [] + min_val = np.finfo(np.float32).max + max_val = 0.0 + for mode in modes: + legend_strs += [[]] + for err in errs: + linestyle = "-" if mode == "train" else ("--x" if len(modes) > 1 else "-x") + linestyle = ":" if "stddev" in err else linestyle + alpha = 1.0 + if "train" in modes and "val" in modes: + alpha = 0.35 if "train" in mode else alpha + + for j, run_data in enumerate(runs_data): + run_data_mode = run_data.by_mode(mode) + if run_data_mode.is_empty(): + continue + # find the col of the request x-axis (e.g. samples) + x_col = next(filter(lambda c: x_axis in c, run_data_mode.columns)) + # find the cols of the requested metric (e.g. mse) and channel + # for all streams + data_cols = [] + for col in run_data_mode.columns: + col_split = col.split(".") + if len(col_split) < 4: + if col == stream_name.lower(): + data_cols += [col] + elif ( + col_split[1].lower() == stream_name.lower() + and col_split[2].lower() == err.lower() + and col_split[3] == channel + ): + data_cols += [col] + + for col in data_cols: + x_vals = np.array(run_data_mode[x_col]) + y_data = np.array(run_data_mode[col]) + mask = np.logical_and(~np.isnan(x_vals), ~np.isnan(y_data)) + + plt.plot( + x_vals[mask], + y_data[mask], + linestyle, + color=colors[j % len(colors)], + alpha=alpha, + ) + legend_strs[-1] += [ + ("R" if runs_active[j] else "X") + + " : " + + run_data.run_id + + " : " + + runs_ids[run_data.run_id][1] + + ": " + + col + ] + + # skip all-nan slices + if (~np.isnan(y_data)).sum() > 0: + min_val = np.min([min_val, np.nanmin(y_data)]) + max_val = np.max([max_val, np.nanmax(y_data)]) + + # TODO: ensure that legend is plotted with full opacity + legend_str = legend_strs[0] + if len(legend_str) < 1: + plt.close() + _logger.warning(f"Could not find any data for stream: {stream_name}") + continue + + # no valid data found + if (min_val >= max_val) or np.isnan(min_val) or np.isnan(max_val): + continue + + legend = plt.legend(legend_str, loc="upper right" if not x_scale_log else "lower left") + for line in legend.get_lines(): + line.set(alpha=1.0) + plt.grid(True, which="both", ls="-") + # cap at 1.0 in case of divergence of run (through normalziation, max should be + # around 1.0) + # plt.ylim([0.95 * min_val, (None if max_val < 2.0 else min(1.1, 1.025 * max_val))]) + plt.ylim([0.95 * min_val, 1.025 * max_val]) + plt.yscale("log") + if x_scale_log: + plt.xscale("log") + plt.title(stream_name + ": " + channel + " (" + ", ".join(modes) + ")") + plt.ylabel("loss") + plt.xlabel(x_axis if x_type == "step" else "rel. time [h]") + plt.tight_layout() + rstr = "".join([f"{r}_" for r in runs_ids]) + + # save the plot + plt_fname = plot_dir / "{}{}{}_{}.png".format( + rstr, "".join([f"{m}_" for m in modes]), stream_name, channel + ) + _logger.info(f"Saving loss per stream plot to '{plt_fname}'") + plt.savefig(plt_fname) plt.close() - _logger.warning(f"Could not find any data for stream: {stream_name}") - continue - - # no valid data found - if (min_val >= max_val) or np.isnan(min_val) or np.isnan(max_val): - continue - - legend = plt.legend(legend_str) # , loc="upper right" if not x_scale_log else "lower left") - for line in legend.get_lines(): - line.set(alpha=1.0) - plt.grid(True, which="both", ls="-") - plt.yscale("log") - # cap at 1.0 in case of divergence of run (through normalziation, max should be around 1.0) - plt.ylim([0.95 * min_val, (None if max_val < 2.0 else min(1.1, 1.025 * max_val))]) - if x_scale_log: - plt.xscale("log") - plt.title(stream_name) - plt.ylabel("loss") - plt.xlabel(x_axis if x_type == "step" else "rel. time [h]") - plt.tight_layout() - rstr = "".join([f"{r}_" for r in runs_ids]) - - # save the plot - plt_fname = plot_dir / "{}{}{}.png".format( - rstr, "".join([f"{m}_" for m in modes]), stream_name - ) - _logger.info(f"Saving loss per stream plot to '{plt_fname}'") - plt.savefig(plt_fname) - plt.close() #################################################################################################### @@ -431,6 +446,7 @@ def plot_loss_per_run( run_desc: str, run_data: Metrics, stream_names: list[str], + channels: list[str] | None, plot_dir: Path, errs: list[str] | None = None, x_axis: str = "samples", @@ -486,6 +502,13 @@ def plot_loss_per_run( x_col = [c for _, c in enumerate(run_data_mode.columns) if x_axis in c][0] # find the cols of the requested metric (e.g. mse) for all streams data_cols = [c for _, c in enumerate(run_data_mode.columns) if err in c] + data_cols = [] + for col in run_data_mode.columns: + col_split = col.split(".") + if len(col_split) < 4: + continue + if col_split[2].lower() == err.lower() and col_split[3] == channels: + data_cols += [col] data_cols = list(data_cols) @@ -591,13 +614,20 @@ def plot_train(args=None): help="List of streams to plot", ) parser.add_argument( - "--errors", - "-e", - dest="errors", - default=["loss_avg"], + "--channels", + dest="channels", + default=["avg"], + type=str, + nargs="+", + help="List of channels to plot", + ) + parser.add_argument( + "--metrics", + dest="metrics", + default=["mse"], type=str, nargs="+", - help="List of errors to plot", + help="List of metrics (e.g. mse) to plot", ) parser.add_argument( "--x_type", @@ -654,7 +684,10 @@ def plot_train(args=None): # read logged data - runs_data = [TrainLogger.read(run_id, model_path=model_base_dir) for run_id in runs_ids] + runs_data = [ + TrainLogger.read(run_id, model_path=model_base_dir, cols_patterns=streams) + for run_id in runs_ids + ] # determine which runs are still alive (as a process, though they might hang internally) ret = subprocess.run(["squeue"], capture_output=True) @@ -678,7 +711,8 @@ def plot_train(args=None): runs_data, runs_active, streams, - errs=args.errors, + errs=args.metrics, + channels=args.channels, x_type=args.x_type, x_scale_log=x_scale_log, plot_dir=out_dir, @@ -689,7 +723,8 @@ def plot_train(args=None): runs_data, runs_active, streams, - errs=args.errors, + errs=args.metrics, + channels=args.channels, x_type=args.x_type, x_scale_log=x_scale_log, plot_dir=out_dir, @@ -700,7 +735,8 @@ def plot_train(args=None): runs_data, runs_active, streams, - errs=args.errors, + errs=args.metrics, + channels=args.channels, x_type=args.x_type, x_scale_log=x_scale_log, plot_dir=out_dir, @@ -714,6 +750,7 @@ def plot_train(args=None): runs_ids[run_id], run_data, get_stream_names(run_id, model_path=model_base_dir), # limit to available streams + channels=args.channels, plot_dir=out_dir, ) plot_loss_per_run( @@ -722,6 +759,7 @@ def plot_train(args=None): runs_ids[run_id], run_data, get_stream_names(run_id, model_path=model_base_dir), # limit to available streams + channels=args.channels, plot_dir=out_dir, ) diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index ba91c53b1..c2d85f82f 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -128,7 +128,12 @@ def add_logs( ####################################### @staticmethod - def read(run_id: str, model_path: str = None, mini_epoch: int | None = None) -> Metrics: + def read( + run_id: str, + model_path: str = None, + mini_epoch: int | None = None, + cols_patterns: list[str] | None = None, + ) -> Metrics: """ Read data for run_id """ @@ -154,7 +159,7 @@ def read(run_id: str, model_path: str = None, mini_epoch: int | None = None) -> cols1, cols_train = get_loss_terms_per_stream(cf.streams, training_cfg) cols_train += ["dtime", "samples", "mse", "lr"] cols1 += [_weathergen_timestamp, "num_samples", "loss_avg_mean", "learning_rate"] - cols1_patterns = ["loss_avg"] + cols1_patterns = ["loss_avg"] + cols_patterns # read training log data try: @@ -202,8 +207,8 @@ def read(run_id: str, model_path: str = None, mini_epoch: int | None = None) -> ) cols2, cols_val = get_loss_terms_per_stream(cf.streams, validation_cfg) cols_val = ["dtime", "samples"] - cols2 = [_weathergen_timestamp, "num_samples"] - cols2_patterns = ["loss_avg"] + cols2 += [_weathergen_timestamp, "num_samples"] + cols2_patterns = ["loss_avg"] + cols_patterns # read validation log data try: @@ -272,7 +277,8 @@ def read_metrics( df = read_metrics_file(metrics_path) if cols_patterns is not None: - cols += [col for col in df.columns if "loss_avg" in col] + for col_pattern in cols_patterns: + cols += [col for col in df.columns if col_pattern in col] if stage is not None: df = df.filter(pl.col("stage") == stage) From f7a42f69f9e2978b47b6ca203a76da2757b2d807 Mon Sep 17 00:00:00 2001 From: Matthias Date: Mon, 2 Mar 2026 08:22:12 +0100 Subject: [PATCH 226/344] Latent size downscaling mlps for diffusion --- src/weathergen/model/engines.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index f9ec73a98..fd033f25b 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -404,6 +404,22 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = self.num_healpix_cells = num_healpix_cells self.fe_blocks = torch.nn.ModuleList() + downscale_factor = 8 + + self.fe_blocks.append( + MLP( + self.cf.ae_global_dim_embed, + self.cf.ae_global_dim_embed // downscale_factor, + with_residual=False, + dropout_rate=self.cf.fe_dropout_rate, + norm_type=self.cf.norm_type, + dim_aux=dim_aux, + norm_eps=self.cf.mlp_norm_eps, + with_noise_conditioning=False, + ) + ) + self.cf.ae_global_dim_embed = self.cf.ae_global_dim_embed // downscale_factor + global_rate = int(1 / self.cf.forecast_att_dense_rate) if mode_cfg.get("forecast", {}).get("policy") is not None: for i in range(self.cf.fe_num_blocks): @@ -461,6 +477,20 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) ) + self.fe_blocks.append( + MLP( + self.cf.ae_global_dim_embed, + self.cf.ae_global_dim_embed * downscale_factor, + with_residual=False, + dropout_rate=self.cf.fe_dropout_rate, + norm_type=self.cf.norm_type, + dim_aux=dim_aux, + norm_eps=self.cf.mlp_norm_eps, + with_noise_conditioning=False, + ) + ) + self.cf.ae_global_dim_embed = self.cf.ae_global_dim_embed * downscale_factor + def init_weights_final(m): if isinstance(m, torch.nn.Linear): torch.nn.init.normal_(m.weight, mean=0, std=0.001) From 1db5fe6d02edd3098f4b46901d2c12ce0ce7f6c0 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Mon, 2 Mar 2026 09:23:17 +0100 Subject: [PATCH 227/344] Improve support for latent losses --- src/weathergen/utils/plot_training.py | 2 +- src/weathergen/utils/train_logger.py | 32 ++++----------------------- 2 files changed, 5 insertions(+), 29 deletions(-) diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index a44f1623e..01f860b7a 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -366,7 +366,7 @@ def plot_loss_per_stream( for col in run_data_mode.columns: col_split = col.split(".") if len(col_split) < 4: - if col == stream_name.lower(): + if stream_name in col: data_cols += [col] elif ( col_split[1].lower() == stream_name.lower() diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index c2d85f82f..0e34a1363 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -24,10 +24,9 @@ import weathergen.common.config as config # from weathergen.train.trainer import cfg_keys_to_filter -from weathergen.train.utils import Stage, cfg_keys_to_filter, flatten_dict, get_active_stage_config +from weathergen.train.utils import Stage, flatten_dict from weathergen.utils.distributed import ddp_average from weathergen.utils.metrics import get_train_metrics_path, read_metrics_file -from weathergen.utils.utils import is_stream_forcing _weathergen_timestamp = "weathergen.timestamp" _weathergen_reltime = "weathergen.reltime" @@ -155,10 +154,8 @@ def read( # training # define cols for training - training_cfg = get_active_stage_config(cf.training_config, {}, cfg_keys_to_filter) - cols1, cols_train = get_loss_terms_per_stream(cf.streams, training_cfg) - cols_train += ["dtime", "samples", "mse", "lr"] - cols1 += [_weathergen_timestamp, "num_samples", "loss_avg_mean", "learning_rate"] + cols_train = ["dtime", "samples", "mse", "lr"] + cols1 = [_weathergen_timestamp, "num_samples", "loss_avg_mean", "learning_rate"] cols1_patterns = ["loss_avg"] + cols_patterns # read training log data @@ -200,14 +197,9 @@ def read( log_train_df = read_metrics(cf, run_id, "train", cols1, cols1_patterns, result_dir_base) - # validation # define cols for validation - validation_cfg = get_active_stage_config( - training_cfg, cf.get("validation_config", {}), cfg_keys_to_filter - ) - cols2, cols_val = get_loss_terms_per_stream(cf.streams, validation_cfg) cols_val = ["dtime", "samples"] - cols2 += [_weathergen_timestamp, "num_samples"] + cols2 = [_weathergen_timestamp, "num_samples"] cols2_patterns = ["loss_avg"] + cols_patterns # read validation log data @@ -332,22 +324,6 @@ def clean_name(s: str) -> str: return "".join(c for c in s if c.isalnum() or c == "-" or c == "_") -def get_loss_terms_per_stream(streams, stage_config): - """ - Extract per stream loss terms - """ - cols, cols_stage = [], [] - for si in streams: - if is_stream_forcing(si): - continue - for _, loss_config in stage_config.get("losses", {}).items(): - if loss_config.get("type", "LossPhysical") == "LossPhysical": - for lname, _ in loss_config.loss_fcts.items(): - cols += [_key_loss(si["name"], lname)] - cols_stage += [_clean_stream_name(si["name"]) + lname] - return cols, cols_stage - - def _clean_stream_name(stream_name: str) -> str: return stream_name.replace(",", "").replace("/", "_").replace(" ", "_") + ", " From b9feb92b4ad5c27e5fb61b9e4d5d2cfd73efde4b Mon Sep 17 00:00:00 2001 From: Matthias Date: Mon, 2 Mar 2026 11:50:57 +0100 Subject: [PATCH 228/344] Fix to support models trained on older code versions --- packages/common/src/weathergen/common/config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index ee9c0f851..f5f5ec53f 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -329,10 +329,10 @@ def _check_logging(config: Config) -> Config: Apply fixes to log frequency config. """ config = config.copy() - if config.get("train_logging") is None: # TODO remove this for next version - config.train_logging = OmegaConf.create( - {"checkpoint": 250, "terminal": 10, "metrics": config.train_logging.log_interval} - ) + # if config.get("train_logging") is None: # TODO remove this for next version + # config.train_logging = OmegaConf.create( + # {"checkpoint": 250, "terminal": 10, "metrics": config.train_logging.log_interval} + # ) return config From 9e205ab0308031aec59507cbcbbda0fea8453144 Mon Sep 17 00:00:00 2001 From: Matthias Date: Tue, 3 Mar 2026 18:42:13 +0100 Subject: [PATCH 229/344] Repair code and update to load more recent pre-trained model --- config/config_diffusion.yml | 7 ++-- src/weathergen/model/engines.py | 42 ++++++------------- src/weathergen/model/model_interface.py | 1 + src/weathergen/train/loss_calculator.py | 2 +- .../loss_module_latent_diffusion.py | 2 +- 5 files changed, 19 insertions(+), 35 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 29474bfa2..c699bf3f8 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -81,7 +81,7 @@ healpix_level: 5 with_mixed_precision: True with_flash_attention: True compile_model: False -with_fsdp: True +with_fsdp: False attention_dtype: bf16 mixed_precision_dtype: bf16 mlp_norm_eps: 1e-5 @@ -95,7 +95,7 @@ latent_noise_deterministic_latents: True freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" -load_chkpt: {'run_id': 'aev85iny', 'epoch': -1} +load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} norm_type: "LayerNorm" @@ -128,10 +128,11 @@ general: run_history: [] # logging frequency in the training loop (in number of batches) -train_log_freq: +train_logging: terminal: 10 metrics: 20 checkpoint: 250 + log_grad_norms: False # parameters for data loading data_loading : diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 74cac7354..871fd366b 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -405,22 +405,6 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = self.num_healpix_cells = num_healpix_cells self.fe_blocks = torch.nn.ModuleList() - downscale_factor = 8 - - self.fe_blocks.append( - MLP( - self.cf.ae_global_dim_embed, - self.cf.ae_global_dim_embed // downscale_factor, - with_residual=False, - dropout_rate=self.cf.fe_dropout_rate, - norm_type=self.cf.norm_type, - dim_aux=dim_aux, - norm_eps=self.cf.mlp_norm_eps, - with_noise_conditioning=False, - ) - ) - self.cf.ae_global_dim_embed = self.cf.ae_global_dim_embed // downscale_factor - global_rate = int(1 / self.cf.forecast_att_dense_rate) if mode_cfg.get("forecast", {}).get("policy") is not None: for i in range(self.cf.fe_num_blocks): @@ -478,20 +462,18 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) ) - self.fe_blocks.append( - MLP( - self.cf.ae_global_dim_embed, - self.cf.ae_global_dim_embed * downscale_factor, - with_residual=False, - dropout_rate=self.cf.fe_dropout_rate, - norm_type=self.cf.norm_type, - dim_aux=dim_aux, - norm_eps=self.cf.mlp_norm_eps, - with_noise_conditioning=False, - ) - ) - self.cf.ae_global_dim_embed = self.cf.ae_global_dim_embed * downscale_factor - + # self.fe_blocks.append( + # MLP( + # self.cf.ae_global_dim_embed, + # self.cf.ae_global_dim_embed, + # with_residual=False, + # dropout_rate=self.cf.fe_dropout_rate, + # norm_type=self.cf.norm_type, + # dim_aux=dim_aux, + # norm_eps=self.cf.mlp_norm_eps, + # with_noise_conditioning=False, + # ) + # ) def init_weights_final(m): if isinstance(m, torch.nn.Linear): torch.nn.init.normal_(m.weight, mean=0, std=0.001) diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index a5a3d3c6a..bc51585ba 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -76,6 +76,7 @@ def init_model_and_shard( find_unused_parameters=cf.get("ddp_find_unused_parameters", True), gradient_as_bucket_view=True, bucket_cap_mb=512, + static_graph=cf.get("ddp_static_graph", True) ) elif with_ddp and with_fsdp: diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index 2f81940a3..0baf707ba 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -85,7 +85,7 @@ def compute_loss( ): losses_all = defaultdict(dict) stddev_all = defaultdict(dict) - loss = torch.tensor(0.0, requires_grad=True) + loss = torch.tensor(0.0, device=self.device, requires_grad=True) for loss_term_name, calc_term in self.loss_calculators.items(): target = targets_and_aux[loss_term_name] for weight, calculator in calc_term: diff --git a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py index 41b679807..d1b05c1e8 100644 --- a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py +++ b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py @@ -95,7 +95,7 @@ def compute_loss(self, preds: dict, targets: dict, **kwargs) -> LossValues: pred_tokens_all = [pl["latent_state"].patch_tokens for pl in preds.latent if pl] target_tokens_all = [latent["diffusion_latent"] for latent in targets.latent if latent] - eta = torch.tensor([targets.aux_outputs["noise_level_rn"]], device=self.device) + eta = torch.tensor([targets.aux_outputs["noise_level_rn"]], device=self.device, dtype=torch.float32) fsteps = len(target_tokens_all) noise_weight = self._get_noise_weight(eta) From 31b93a057238d866b37518cb785ada12dc192110 Mon Sep 17 00:00:00 2001 From: Moritz Hauschulz <60788263+moritzhauschulz@users.noreply.github.com> Date: Mon, 9 Mar 2026 12:08:03 +0000 Subject: [PATCH 230/344] Fixes that enable basic single sample overfitting (#2003) * init commit * basic single sampling works * some linting --- config/config_diffusion.yml | 8 ++-- config/runs_plot_train.yml | 37 +++++++++++++++++++ src/weathergen/model/diffusion.py | 2 + src/weathergen/model/model.py | 2 +- src/weathergen/model/model_interface.py | 2 +- .../loss_module_latent_diffusion.py | 6 ++- .../train/target_and_aux_diffusion.py | 9 ++++- src/weathergen/utils/validation_io.py | 13 ++++--- 8 files changed, 65 insertions(+), 14 deletions(-) create mode 100644 config/runs_plot_train.yml diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index c699bf3f8..9c83e208e 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -117,7 +117,7 @@ general: # local_rank, # with_ddp, # data_path_*, - # model_path, + # model_path, # run_path, # path_shared_ @@ -156,7 +156,7 @@ training_config: num_mini_epochs: 150 samples_per_mini_epoch: 66 shuffle: True - + start_date: 2012-06-01T00:00 end_date: 2012-06-01T18:00 @@ -164,8 +164,8 @@ training_config: time_window_len: 06:00:00 learning_rate_scheduling : - lr_start: 1e-6 #5e-5 - lr_max: 5e-5 #1e-4 + lr_start: 1e-5 #5e-5 + lr_max: 1e-4 #1e-4 lr_final_decay: 1e-6 lr_final: 0.0 num_steps_warmup: 64 diff --git a/config/runs_plot_train.yml b/config/runs_plot_train.yml new file mode 100644 index 000000000..8613cc92f --- /dev/null +++ b/config/runs_plot_train.yml @@ -0,0 +1,37 @@ +train : + plot : + # crn7ov5y: + # slurm_id: 0 + # description: "crn7ov5y: no noise, lat 1, phys 0, start_lr=1e-5, max_lr=1e-4" + + # h3of5mec: + # slurm_id: 0 + # description: "h3of5mec: lat 0.9, phys 0.1, start_lr=1e-5, max_lr=1e-4, fe=2 blocks" + + # p0q1oz52: + # slurm_id: 0 + # description: "p0q1oz52: lat 0.9, phys 0.1, start_lr=1e-5, max_lr=5e-5, fe=2 blocks" + + # aabi87jc: + # slurm_id: 0 + # description: "aabi87jc: lat 0.9, phys 0.1, start_lr=1e-5, max_lr=2.5e-5, fe=2 blocks" + + # s9fsjudp: + # slurm_id: 0 + # description: "s9fsjudp: lat 0.9, phys 0.1, start_lr=1e-5, max_lr=1e-5, fe=2 blocks" + + # gbq1pxc9: + # slurm_id: 0 + # description: "gbq1pxc9: lat 0.9, phys 0.1, start_lr=1e-5, max_lr=1e-4, fe=MLP" + + # yyv2m7ir: + # slurm_id: 0 + # description: "yyv2m7ir: lat 0.9, phys 0.1, start_lr=1e-5, max_lr=5e-5, fe=MLP" + + # b5c60g4a: + # slurm_id: 0 + # description: "b5c60g4a: lat 0.9, phys 0.1, start_lr=1e-5, max_lr=2.5e-5, fe=MLP" + + # dkzr6lfq: + # slurm_id: 0 + # description: "dkzr6lfq: lat 0.9, phys 0.1, start_lr=1e-6, max_lr=1e-5, fe=MLP" \ No newline at end of file diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index e7bda33c7..66083738d 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -100,6 +100,7 @@ def forward( eta = torch.tensor([meta_info["ERA5"].params["noise_level_rn"]], device=tokens.device) # Compute sigma (noise level) from eta and create noise tensor + sigma = (eta * self.p_std + self.p_mean).exp() n = torch.randn_like(y) * sigma @@ -122,6 +123,7 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int) -> # Precondition input and feed through network x = self.preconditioner.precondition(x, c) + return c_skip * x + c_out * self.net( c_in * x, fstep=fstep, noise_emb=noise_emb ) # Eq. (7) in EDM paper diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index f8c4ed2c1..dc1a8d9fb 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -651,7 +651,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: t_mean = tokens.mean() t_std = tokens.std() tokens = (tokens - t_mean) / (t_std + 1e-6) - tokens = torch.clamp(tokens, -5.0, 5.0) + tokens = torch.clamp(tokens, -100.0, 100.0) # roll-out in latent space, iterate and generate output over requested output steps for step in batch.get_output_idxs(): diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index bc51585ba..ca2a2d725 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -76,7 +76,7 @@ def init_model_and_shard( find_unused_parameters=cf.get("ddp_find_unused_parameters", True), gradient_as_bucket_view=True, bucket_cap_mb=512, - static_graph=cf.get("ddp_static_graph", True) + static_graph=cf.get("ddp_static_graph", True), ) elif with_ddp and with_fsdp: diff --git a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py index d1b05c1e8..ed076aee2 100644 --- a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py +++ b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py @@ -92,10 +92,12 @@ def compute_loss(self, preds: dict, targets: dict, **kwargs) -> LossValues: for _, _, loss_fct_name in self.loss_fcts } - pred_tokens_all = [pl["latent_state"].patch_tokens for pl in preds.latent if pl] + pred_tokens_all = [pl["latent_state"].z_pre_norm for pl in preds.latent if pl] target_tokens_all = [latent["diffusion_latent"] for latent in targets.latent if latent] - eta = torch.tensor([targets.aux_outputs["noise_level_rn"]], device=self.device, dtype=torch.float32) + eta = torch.tensor( + [targets.aux_outputs["noise_level_rn"]], device=self.device, dtype=torch.float32 + ) fsteps = len(target_tokens_all) noise_weight = self._get_noise_weight(eta) diff --git a/src/weathergen/train/target_and_aux_diffusion.py b/src/weathergen/train/target_and_aux_diffusion.py index ecae6b42e..7cd2f9ea3 100644 --- a/src/weathergen/train/target_and_aux_diffusion.py +++ b/src/weathergen/train/target_and_aux_diffusion.py @@ -4,6 +4,7 @@ from weathergen.datasets.batch import ModelBatch from weathergen.model.model import ModelParams +from weathergen.model.utils import apply_fct_to_blocks, freeze_weights, set_to_eval from weathergen.train.target_and_aux_module_base import ( TargetAndAuxModuleBase, TargetAuxOutput, @@ -15,11 +16,14 @@ def __init__(self, encoder, is_model_sharded=True): # Todo: make sure this is a frozen clone or forward without gradients in compute() self.encoder = encoder + apply_fct_to_blocks(self.encoder, ".*", freeze_weights) + apply_fct_to_blocks(self.encoder, ".*", set_to_eval) + self.is_model_sharded = is_model_sharded # Build a name → param map once self.src_params = dict(self.encoder.named_parameters()) - self.reset() + # self.reset() @torch.no_grad() def reset(self): @@ -29,6 +33,8 @@ def reset(self): It operates via the state_dict to be able to deal with sharded tensors in case FSDP2 is used. """ + # TODO: This needs fixing, might need to use apply_fct_to_blocks as in init() + self.encoder.to_empty(device="cuda") for p in self.encoder.parameters(): p.requires_grad = False @@ -55,6 +61,7 @@ def compute( # TODO: check if there are scenarios where the encoder needs to be set to eval with torch.no_grad(): + self.encoder.encoder.eval() # NOTE: might be redundant tokens, posteriors = self.encoder.encoder(model_params=model_params, batch=batch) # NOTE: must not set to train afterwards unless it was already in train diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 93db38871..1290a8660 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -72,7 +72,7 @@ def write_output( """ # TODO: REMOVE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. global i - + # TODO: how to handle multiple physical loss terms outputs_physical = [ loss_name @@ -97,7 +97,7 @@ def write_output( targets_coords_all += [[]] targets_times_all += [[]] targets_lens += [[]] - noise_levels = [] # TODO: REMOVE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. + # noise_levels = [] # TODO: REMOVE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. for stream_idx, stream_info in enumerate(cf.streams): sname = stream_info["name"] @@ -277,14 +277,17 @@ def write_output( data = da.sel(channel=varname).dropna(dim="ipoint") channel_dir = base_plot_dir / varname channel_dir.mkdir(parents=True, exist_ok=True) - epoch_tag = f"epoch_{mini_epoch:03d}_{i%3}" + epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}" # Add noise_level_rn to title if present for this stream - noise_level = noise_levels[stream_idx] + # noise_level = noise_levels[stream_idx] + noise_level = ( + None # TODO: REMOVE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. + ) if noise_level is not None: title = f"{stream_name} - {varname} (fstep {forecast_offset}) | noise_level_rn={noise_level:.4f}" else: title = f"{stream_name} - {varname} (fstep {forecast_offset})" - + plot_name = plotter.scatter_plot( data, channel_dir, From 44c8816094bd5daa72f8d02ef640c03364ee33af Mon Sep 17 00:00:00 2001 From: Julian Kuehnert Date: Tue, 10 Mar 2026 10:23:35 +0100 Subject: [PATCH 231/344] Update diffusion to develop (#2022) * Improve support for latent losses (#1963) * Revert 2D rope to false by default (#1967) Set to True by accident * Implementation of DataReaderMesh (#1840) * First implementation of DataReaderMesh * Move to datareaders extra * ruff * ruff2 * Undo ruff * undo auto-linting * correct typo in eval config (#1971) * Added all-physical-streams option and x/y axis limits (#1972) * Added all-physical-streams option and x/y axis limits * Fix * Changed flag for all streams * Removed old code * moved metric parsing to eval_from_config (#1977) Co-authored-by: buschow1 * Fixed integration test (#1980) * [1974][model] Add fallback to config loading (#1985) * Add fallback to config loading * Adjust error message to be not misleading * Homegenize naming convention * Introduce bias/diff maps and animations (#1912) * Introduce bias/diff maps and animations * minor correction * Changes based on review * Introduce "plot_bias" in evaluation configuration (#1986) * Fixed index ordering to not have shuffled output (#1982) * Fixed idxs_inv to revert data point shuffeling * Fixed output handling * Handling of empty data case, addressing reviewer comment * [1893][eval] csvreader cleanup (#1906) * refactor csvreader * check if dataarray size is 0 * fix and use original logic for empty data * linting fixes * revert assertions back * [1890][eval] Move MergeReader to own module (#1892) * move mergereader * use assertions only * implement scoring for the sub-steps within the forecast window (#1896) * work in progress * working for forecast_step * working version * restore no valid times option * lint * Rename scale_z_channels to _scale_z_channels * fix 1 sample bug * Remove points_per_sample from ReaderOutput Remove points_per_sample from ReaderOutput return. * remove n_point_per_sample * fix lead time coord in compute_scores * lint * fix integration test * Fix integration test single stream (#1996) * fix test single * change yml extension and minor fixes --------- Co-authored-by: cosi1 Co-authored-by: cosi1 * [1907][eval] clean up wegen_reader.py (#1911) * clean up wegen_reader.py * remove exception * consistent reader naming * add blank line * use assertions only * make names consistent * Merge branch 'develop' into 1907-wegenreader-cleanup * revert is_regular --------- Co-authored-by: iluise <72020169+iluise@users.noreply.github.com> Co-authored-by: Ilaria Luise * [1888][eval] Refactor Reader class (#1889) * refactor Reader * use assertion only * fix npp atms --------- Co-authored-by: iluise <72020169+iluise@users.noreply.github.com> Co-authored-by: Ilaria Luise * [1975][model] Load model path from private repo instead of json (#1998) * Load model path from private repo instead of json * Lint * Script to compute spatial autocorrelation of structured/unstructured datasets (#1955) * standalone script to compute spatial autocorrelation of variables in a structured or unstructured dataset * remove commits that should be in pr 1951 * lint * addressed comments * removed last failure returning 500km default, and moved to packages science * updated a note * rename autocorrelation script * update example usage * Correct EMA halflife_steps calculation with rampup_ratio (#2001) Corrected rampup calculation: https://github.com/NVlabs/edm2/blob/4bf8162f601bcc09472ce8a32dd0cbe8889dc8fc/training/phema.py#L145 Co-authored-by: Wael * Reduce verbosity of output during inference and evaluation (#2006) * Fix incorrect length in validation progress bar * Removing too verbose output * [1766][1743][1332] lint and unit-test fix (#1802) * [1766][1742] fix lint and unit-test * [1766] fix linter * [1766] lint local and global consistent * [1332] add script to detect bad functions (getattr) * code quality: lint and bad functions * [1766] disable some checks * [1877] Script to populate PR labels from linked issues (#1878) * script * branch * more dirs * typo * enable * Fixed bug in linear embedding (#2012) * Adding forecast_steps feature to plot_train (#2010) * Adding forecast_steps feature to plot_train * Renamed arguement to conform to hyphen convention * Added forecast step to filename --------- Co-authored-by: Seb Hickman <56727418+shmh40@users.noreply.github.com> --------- Co-authored-by: Christian Lessig Co-authored-by: Seb Hickman <56727418+shmh40@users.noreply.github.com> Co-authored-by: Kacper Nowak Co-authored-by: Till Hauer Co-authored-by: s6sebusc <49226935+s6sebusc@users.noreply.github.com> Co-authored-by: buschow1 Co-authored-by: Matthias Karlbauer Co-authored-by: Savvas Melidonis <79579567+SavvasMel@users.noreply.github.com> Co-authored-by: Michael Tarnawa <18899420+mtar@users.noreply.github.com> Co-authored-by: iluise <72020169+iluise@users.noreply.github.com> Co-authored-by: pierluigicosi <91318382+pierluigicosi@users.noreply.github.com> Co-authored-by: cosi1 Co-authored-by: cosi1 Co-authored-by: Ilaria Luise Co-authored-by: Wael Co-authored-by: Simone Norberti <63310821+simone99n@users.noreply.github.com> Co-authored-by: Timothy Hunter --- .github/workflows/pr_assign_labels_cron.yml | 128 ++ config/config_jepa.yml | 2 +- config/evaluate/eval_config.yml | 7 +- integration_tests/small1.yaml | 30 - integration_tests/small1.yml | 219 ++++ integration_tests/small1_test.py | 37 +- integration_tests/small_multi_stream_test.py | 6 +- .../common/src/weathergen/common/config.py | 25 +- packages/common/src/weathergen/common/io.py | 7 +- .../weathergen/evaluate/export/cf_utils.py | 2 + .../evaluate/export/parsers/netcdf_parser.py | 2 + .../evaluate/export/parsers/quaver_parser.py | 2 + .../src/weathergen/evaluate/io/csv_reader.py | 178 +-- .../src/weathergen/evaluate/io/io_reader.py | 138 +- .../weathergen/evaluate/io/merge_reader.py | 360 ++++++ .../weathergen/evaluate/io/wegen_reader.py | 927 ++++++------- .../weathergen/evaluate/plotting/plotter.py | 23 +- .../src/weathergen/evaluate/run_evaluation.py | 12 +- .../src/weathergen/evaluate/utils/regions.py | 1 - .../src/weathergen/evaluate/utils/utils.py | 166 ++- .../readers_extra/data_reader_iconart.py | 2 + .../readers_extra/data_reader_mesh.py | 542 ++++++++ .../src/weathergen/readers_extra/registry.py | 4 + .../compute_spatial_autocorrelation.py | 1142 +++++++++++++++++ pyproject.toml | 8 +- scripts/actions.sh | 39 +- src/weathergen/datasets/tokenizer_masking.py | 18 +- src/weathergen/model/ema.py | 2 +- src/weathergen/model/embeddings.py | 3 +- src/weathergen/model/model_interface.py | 5 +- src/weathergen/train/loss_calculator.py | 1 + .../loss_modules/loss_module_physical.py | 1 + .../train/target_and_aux_module_base.py | 4 +- src/weathergen/train/trainer.py | 4 +- src/weathergen/utils/better_abc.py | 2 + src/weathergen/utils/plot_training.py | 111 +- src/weathergen/utils/validation_io.py | 23 +- 37 files changed, 3302 insertions(+), 881 deletions(-) create mode 100644 .github/workflows/pr_assign_labels_cron.yml delete mode 100644 integration_tests/small1.yaml create mode 100644 integration_tests/small1.yml create mode 100644 packages/evaluate/src/weathergen/evaluate/io/merge_reader.py create mode 100644 packages/readers_extra/src/weathergen/readers_extra/data_reader_mesh.py create mode 100644 packages/science/compute_spatial_autocorrelation.py diff --git a/.github/workflows/pr_assign_labels_cron.yml b/.github/workflows/pr_assign_labels_cron.yml new file mode 100644 index 000000000..6af133490 --- /dev/null +++ b/.github/workflows/pr_assign_labels_cron.yml @@ -0,0 +1,128 @@ +name: Sync PR labels to issues + +on: + schedule: + - cron: '0 * * * *' # Every hour + workflow_dispatch: # Allow manual triggering + push: + branches: + # Trying on the dev branch. + - "tjh/dev/1877-pr-labels" +permissions: + pull-requests: write +# contents: write +# issues: write + +jobs: + sync-labels: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Sync PR labels + uses: actions/github-script@v7 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + // The labels that will be transferred. + // Operating on a whitelist to keep the list clean. + const LABEL_WHITELIST = ["bug", "data", + "documentation", "eval", + "model", "model:inference", "model:pretrain", "model:rollout", + "performance", "science"]; + const LABEL_PATHS = { + "packages/common": ["infra"], + "packages/dashboard": ["infra"], + "packages/evaluate": ["eval"], + "packages/metrics": ["infra"], + "packages/readers_extra": ["data"], + "src/weathergen/model": ["model"], + }; + + const { data: pullRequests } = await github.rest.pulls.list({ + owner: context.repo.owner, + repo: context.repo.repo, + state: "open", + per_page: 100, + }); + + for (const pr of pullRequests) { + console.log(`Processing PR #${pr.number}: ${pr.title}`); + + const labelsToAdd = new Set(); + + // --- Rule 1: Skip issue-label sync if PR already has labels --- + if (pr.labels.length > 0) { + console.log(` PR #${pr.number} already has labels, skipping issue-label sync.`); + } else { + // --- Rule 2: Collect labels from linked issues via GraphQL --- + const { repository } = await github.graphql(` + query($owner: String!, $repo: String!, $pr: Int!) { + repository(owner: $owner, name: $repo) { + pullRequest(number: $pr) { + closingIssuesReferences(first: 50) { + nodes { + number + labels(first: 20) { + nodes { + name + } + } + } + } + } + } + } + `, { + owner: context.repo.owner, + repo: context.repo.repo, + pr: pr.number, + }); + + const linkedIssues = repository.pullRequest.closingIssuesReferences.nodes; + const issueNumbers = linkedIssues.map(i => i.number); + console.log(` Found linked issues: ${issueNumbers.join(", ") || "none"}`); + + for (const issue of linkedIssues) { + for (const label of issue.labels.nodes) { + if (LABEL_WHITELIST.includes(label.name)) { + labelsToAdd.add(label.name); + } + } + } + } + + // --- Rule 3: Check if PR touches any path in LABEL_PATHS --- + const files = await github.paginate(github.rest.pulls.listFiles, { + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: pr.number, + per_page: 100, + }); + + for (const [path, labels] of Object.entries(LABEL_PATHS)) { + const normalizedPath = path.replace(/\/?$/, "/"); // ensure trailing slash + const touches = files.some(f => + f.filename.startsWith(normalizedPath) || f.filename === path + ); + + if (touches) { + console.log(` PR #${pr.number} touches "${path}", adding labels: ${labels.join(", ")}`); + for (const label of labels) labelsToAdd.add(label); + } + } + + // --- Apply labels --- + if (labelsToAdd.size > 0) { + console.log(` Adding labels to PR #${pr.number}: ${[...labelsToAdd].join(", ")}`); + await github.rest.issues.addLabels({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: pr.number, + labels: [...labelsToAdd], + }); + } else { + console.log(` No labels to add for PR #${pr.number}.`); + } + } \ No newline at end of file diff --git a/config/config_jepa.yml b/config/config_jepa.yml index 4f2a83b56..e0cadfe86 100644 --- a/config/config_jepa.yml +++ b/config/config_jepa.yml @@ -69,7 +69,7 @@ healpix_level: 5 # Use 2D RoPE instead of traditional global positional encoding # When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) # When False: uses traditional pe_global positional encoding -rope_2D: True +rope_2D: False with_mixed_precision: True with_flash_attention: True diff --git a/config/evaluate/eval_config.yml b/config/evaluate/eval_config.yml index b498103ba..bfa309a45 100644 --- a/config/evaluate/eval_config.yml +++ b/config/evaluate/eval_config.yml @@ -1,7 +1,7 @@ #optional: if commented out all is taken care of by the default settings # NB. global options apply to all run_ids #global_plotting_options: -# region: ["belgium", "global"] +# regions: ["europe", "global"] # image_format : "png" #options: "png", "pdf", "svg", "eps", "jpg" .. # dpi_val : 300 # fps: 2 @@ -47,6 +47,7 @@ default_streams: forecast_step: [1,3, 2] #supported: "all", [1,2,3,...], "1-50" (equivalent of [1,2,3,...50]) ensemble: [0,2,5] #supported: "all", "mean", [0,1,2] plot_maps: true + plot_bias: false plot_target: false plot_histograms: true plot_animations: true @@ -59,6 +60,7 @@ default_streams: sample: [2, 3, 0] forecast_step: [1,3, 4, 5] plot_maps: true + plot_bias: false plot_target: false plot_histograms: true plot_animations: true @@ -96,6 +98,7 @@ run_ids : forecast_step: [1,3, 2] ensemble: "mean" plot_maps: true + plot_bias: false plot_target: false plot_histograms: true plot_animations: true @@ -192,4 +195,4 @@ run_ids : metrics: - fbi: thresh: 280 - - rmse \ No newline at end of file + - rmse diff --git a/integration_tests/small1.yaml b/integration_tests/small1.yaml deleted file mode 100644 index deee14328..000000000 --- a/integration_tests/small1.yaml +++ /dev/null @@ -1,30 +0,0 @@ -streams_directory: "./integration_tests/streams/" -run_path: "./results" -model_path: "./models" -loss_fcts: [["mse", 1.0]] -loss_fcts_val: [["mse", 1.0]] -num_mini_epochs: 1 -samples_per_mini_epoch: 128 -samples_per_validation: 32 -lr_steps_warmup: 4 -lr_steps_cooldown: 2 -loader_num_workers: 8 - -# forecast_offset: 0 -forecast_offset : 1 -forecast_steps: 2 -forecast_policy: "fixed" -forecast_freeze_model: False -forecast_att_dense_rate: 1.0 -fe_num_blocks: 2 -fe_num_heads: 16 -fe_dropout_rate: 0.1 -fe_with_qk_lnorm: True -fe_layer_norm_after_blocks: [] -impute_latent_noise_std: 0.0 - -healpix_level: 4 -training_mode: "forecast" - -train_logging: - log_interval: 1 diff --git a/integration_tests/small1.yml b/integration_tests/small1.yml new file mode 100644 index 000000000..415aa8548 --- /dev/null +++ b/integration_tests/small1.yml @@ -0,0 +1,219 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 512 #1024 +ae_local_num_blocks: 2 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 512 #1024 #2048 +ae_global_num_blocks: 2 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 2 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 1 +num_register_tokens: 7 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 2 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 0.0 # 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 + +healpix_level: 4 + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: True +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +freeze_modules: "" + +norm_type: "LayerNorm" + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + + +################ + +streams_directory: "./integration_tests/streams/" +model_path: "./models" +results_path: "./results" + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + +# parameters for data loading +data_loading : + + num_workers: 2 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + +# config for training +training_config: + + training_mode: ["masking"] + + num_mini_epochs: 1 + samples_per_mini_epoch: 48 + shuffle: True + + start_date: 2014-01-01T00:00 + end_date: 2020-12-31T00:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + window_offset_prediction : 1 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 0.00005 + lr_final_decay: 1e-6 + lr_final: 0.0 + num_steps_warmup: 4 + num_steps_cooldown: 4 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.975 + beta2 : 0.9875 + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + loss_fcts: { "mse": { }, }, + }, + } + + model_input: { + "forecasting" : { + masking_strategy: "forecast", + } + } + + forecast : + time_step: 06:00:00 + num_steps: 2 + policy: "fixed" + offset: 1 + + +# validation config; full validation config is merge of training and validation config +validation_config: + + samples_per_mini_epoch: 32 + shuffle: False + + start_date: 2021-10-10T00:00 + end_date: 2022-10-11T00:00 + + output: + streams: ["ERA5"] + + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + +test_config: + output: + num_samples: 2 + + +# TODO: read latent from here +inference_config: + output: + streams: ["ERA5"] \ No newline at end of file diff --git a/integration_tests/small1_test.py b/integration_tests/small1_test.py index b4845d157..65ffec429 100644 --- a/integration_tests/small1_test.py +++ b/integration_tests/small1_test.py @@ -51,14 +51,14 @@ def test_train(setup, test_run_id): main( [ - "inference", - f"--config={WEATHERGEN_HOME}/integration_tests/small1.yaml", + "train", + f"--base-config={WEATHERGEN_HOME}/integration_tests/small1.yml", "--run-id", test_run_id, ] ) - infer_with_missing(test_run_id) + infer(test_run_id) evaluate_results(test_run_id) assert_missing_metrics_file(test_run_id) assert_train_loss_below_threshold(test_run_id) @@ -70,28 +70,7 @@ def infer(run_id): logger.info("run inference") main( [ - "-start", - "2022-10-10", - "-end", - "2022-10-11", - "--samples", - "10", - "--mini-epoch", - "0", - "--from-run-id", - run_id, - "--run-id", - run_id, - "--config", - f"{WEATHERGEN_HOME}/integration_tests/small1.yaml", - ] - ) - - -def infer_with_missing(run_id): - logger.info("run inference") - main( - [ + "inference", "-start", "2021-10-10", "-end", @@ -105,7 +84,7 @@ def infer_with_missing(run_id): "--run-id", run_id, "--config", - f"{WEATHERGEN_HOME}/integration_tests/small1.yaml", + f"{WEATHERGEN_HOME}/integration_tests/small1.yml", ] ) @@ -155,7 +134,7 @@ def evaluate_results(run_id): def load_metrics(run_id): """Helper function to load metrics""" - file_path = get_train_metrics_path(base_path=WEATHERGEN_HOME / "results", run_id=run_id) + file_path = get_train_metrics_path(base_path=WEATHERGEN_HOME / "results" / run_id, run_id=run_id) if not os.path.exists(file_path): raise FileNotFoundError(f"Metrics file not found for run_id: {run_id}") with open(file_path) as f: @@ -165,7 +144,7 @@ def load_metrics(run_id): def assert_missing_metrics_file(run_id): """Test that a missing metrics file raises FileNotFoundError.""" - file_path = get_train_metrics_path(base_path=WEATHERGEN_HOME / "results", run_id=run_id) + file_path = get_train_metrics_path(base_path=WEATHERGEN_HOME / "results"/ run_id, run_id=run_id) assert os.path.exists(file_path), f"Metrics file does not exist for run_id: {run_id}" metrics = load_metrics(run_id) logger.info(f"Loaded metrics for run_id: {run_id}: {metrics}") @@ -208,4 +187,4 @@ def assert_val_loss_below_threshold(run_id): assert loss_metric is not None, f"'{loss_avg_name}' metric is missing in metrics file" # Check that the loss does not explode in a single mini_epoch # This is meant to be a quick test, not a convergence test - assert loss_metric < 0.25, f"'{loss_avg_name}' is {loss_metric}, expected to be below 0.25" + assert loss_metric < 0.2, f"'{loss_avg_name}' is {loss_metric}, expected to be below 0.2" \ No newline at end of file diff --git a/integration_tests/small_multi_stream_test.py b/integration_tests/small_multi_stream_test.py index 18feb1a93..92141f100 100644 --- a/integration_tests/small_multi_stream_test.py +++ b/integration_tests/small_multi_stream_test.py @@ -114,7 +114,7 @@ def evaluate_multi_stream_results(run_id): }, "evaluation": { "regions": ["global"], - "metrics": ["rmse", "l1", "mse"], + "metrics": ["rmse", "mae"], "verbose": True, "summary_plots": True, "summary_dir": "./plots/", @@ -169,7 +169,7 @@ def evaluate_multi_stream_results(run_id): def load_metrics(run_id): """Helper function to load metrics""" - file_path = get_train_metrics_path(base_path=WEATHERGEN_HOME / "results", run_id=run_id) + file_path = get_train_metrics_path(base_path=WEATHERGEN_HOME / "results" / run_id, run_id=run_id) if not file_path.is_file(): raise FileNotFoundError(f"Metrics file not found for run_id: {run_id}") with open(file_path) as f: @@ -179,7 +179,7 @@ def load_metrics(run_id): def assert_metrics_file_exists(run_id): """Test that the metrics file exists and can be loaded.""" - file_path = get_train_metrics_path(base_path=WEATHERGEN_HOME / "results", run_id=run_id) + file_path = get_train_metrics_path(base_path=WEATHERGEN_HOME / "results" / run_id, run_id=run_id) assert file_path.is_file(), f"Metrics file does not exist for run_id: {run_id}" metrics = load_metrics(run_id) logger.info(f"Loaded metrics for run_id: {run_id}: {metrics}") diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index f5f5ec53f..4be7e92a5 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -234,12 +234,25 @@ def load_run_config(run_id: str, mini_epoch: int | None, model_path: str | None) else: path = Path(model_path) / run_id - fname = path / _get_model_config_file_read_name(run_id, mini_epoch) - assert fname.exists(), ( - "The fallback path to the model does not exist. Please provide a `model_path`.", - fname, - ) - _logger.info(f"Loading config from specified run_id and mini_epoch: {fname}") + config_path_with_epoch = path / _get_model_config_file_read_name(run_id, mini_epoch) + config_path_without_epoch = path / _get_model_config_file_read_name(run_id, None) + + if config_path_with_epoch.exists(): + fname = config_path_with_epoch + _logger.info(f"Loading config from specified run_id and mini_epoch: {fname}") + elif config_path_without_epoch.exists(): + fname = config_path_without_epoch + _logger.info( + f"Config for mini_epoch {mini_epoch} not found. " + f"Falling back to config without mini_epoch: {fname}" + ) + else: + raise FileNotFoundError( + f"Could not find model config for run_id '{run_id}' " + f"(mini_epoch={mini_epoch}) in '{path}'. " + f"Tried: '{config_path_with_epoch.name}' and '{config_path_without_epoch.name}'. " + f"Please check run_id and mini_epoch." + ) with fname.open() as f: json_str = f.read() diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 2243cdee8..5d493370c 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -23,7 +23,6 @@ import zarr from numpy import datetime64 from numpy.typing import NDArray -from tqdm import tqdm from zarr.errors import ZarrUserWarning from zarr.storage import LocalStore, ZipStore @@ -409,9 +408,7 @@ def __exit__(self, exc_type, exc_value, exc_tb): def write_zarr(self, item: OutputItem): """Write one output item to the zarr store.""" group = self._get_group(item.key, create=True) - for dataset in tqdm(item.datasets): # pyrefly: ignore[not-iterable] - # pyrefly doesn't recognize that tqdm makes item.datasets iterable - # until fixed, ignore the warning here + for dataset in item.datasets: if dataset is not None: self._write_dataset(group, dataset) @@ -530,7 +527,7 @@ def forecast_steps(self) -> list[int]: class ZipZarrIO(ZarrIO): def __enter__(self) -> typing.Self: - _logger.info(f"Opening zipstore, read-only: {self.read_only}") + _logger.debug(f"Opening zipstore, read-only: {self.read_only}") self._store = ZipStore(self._store_path, mode=self._mode, read_only=self.read_only) if self.read_only: self.data_root = zarr.open_group(store=self._store, mode=self._mode) diff --git a/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py b/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py index 201ffa168..1367e4b3c 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py @@ -1,3 +1,5 @@ +# pylint: disable=bad-builtin + import logging from pathlib import Path diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py index fe7655fbe..49c77a481 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py @@ -1,3 +1,5 @@ +# pylint: disable=bad-builtin + import logging from pathlib import Path from typing import Any diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/quaver_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/quaver_parser.py index 95e58f87b..cdf9ed640 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/quaver_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/quaver_parser.py @@ -1,3 +1,5 @@ +# pylint: disable=bad-builtin + import logging from pathlib import Path diff --git a/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py b/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py index 1c819f78b..c04e214b1 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py @@ -9,7 +9,6 @@ # Standard library import logging -import re from pathlib import Path # Third-party @@ -35,25 +34,28 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non Parameters ---------- - eval_cfg : - config with plotting and evaluation options for that run id + eval_cfg : dict + Configuration containing plotting and evaluation options for the given run ID. run_id : str - run id of the model - private_paths: - list of private paths for the supported HPC + Run identifier of the model. + private_paths : dict or None, optional + Dictionary of private paths for the target HPC system. Defaults to None. """ super().__init__(eval_cfg, run_id, private_paths) self.metrics_dir = Path(self.eval_cfg.get("metrics_dir")) self.metrics_base_dir = self.metrics_dir - # for backward compatibility allow metric_dir to be specified in the run config + # for backward compatibility allow metric_dir to be specified + # in the run config assert self.metrics_dir is not None, "metrics_dir folder must be provided in the config." self.stream = list(eval_cfg.streams.keys()) + assert self.stream is not None, "stream must be provided in the config." assert len(self.stream) == 1, "CsvReader only supports one stream." + self.stream = self.stream[0] self.channels = eval_cfg.streams.get(self.stream).get("channels") @@ -62,11 +64,10 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non self.data = pd.DataFrame() # parameter,level,number,score,step,date,domain_name,value - for channel_file in (self.metrics_dir / self.run_id).iterdir(): + metrics_run_dir = self.metrics_dir / self.run_id + for channel_file in metrics_run_dir.iterdir(): data = pd.read_csv(channel_file) - if data.empty: - continue - else: + if not data.empty: self.data = pd.concat([self.data, data], ignore_index=True) self.data = self.data.dropna(subset=["step", "level"]) @@ -77,55 +78,78 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non if "level" in self.data.columns else self.data["parameter"].astype(str) ) - self.data["step"] = pd.to_timedelta(self.data["step"]) / np.timedelta64(1, "h") - self.data["step"] = self.data["step"].astype(int) - + self.data["step"] = (pd.to_timedelta(self.data["step"]) / np.timedelta64(1, "h")).astype( + int + ) self.samples = [0] - - self.forecast_steps = sorted(self.data.step.dropna().unique().tolist()) - self.npoints_per_sample = [0] + self.forecast_steps = sorted(self.data["step"].dropna().unique().tolist()) self.epoch = [0] def get_samples(self) -> set[int]: - """get set of samples for the retrieved scores (initialisation times)""" + """ + Get set of samples for the retrieved scores (initialisation times). + + Returns + ------- + samples: set[int] + A set containing the sample indices. + """ return set(self.samples) # Placeholder implementation def get_forecast_steps(self) -> set[int]: - """get set of forecast steps""" + """ + Get set of forecast steps. + + Returns + ------- + fsteps: set[int] + A set containing the forecast step values. + """ return set(self.forecast_steps) # Placeholder implementation # TODO: get this from config def get_channels(self, stream: str | None = None) -> list[str]: - """get set of channels + """Get the list of available channels for a given stream. + Parameters ---------- - stream : - Stream name. + stream : str + The name of the stream for which to retrieve channels. + Returns ------- - List of channels. + list[str] + A list of channels available in the stream. """ assert stream == self.stream, "streams do not match in CSVReader." return list(self.channels) # Placeholder implementation def get_values( self, region: str, metric: str, forecast_steps: list[int], channels: list[str] - ) -> xr.DataArray: + ) -> xr.DataArray | None: """ - Get score values in the right format + Retrieve metric values for the specified region, metric, forecast steps and channels. + + Parameters ---------- - region : - Region name. - metric : - Metric name. - forecast_steps : - List of forecast steps. - channels : - List of channels. + region : str + The name of the region to filter by. + metric : str + The name of the metric to filter by. + forecast_steps : list[int] + A list of forecast step values to include. + channels : list[str] + A list of channel names to include. + Returns ------- - The metric DataArray. + da: xr.DataArray or None + An xarray DataArray containing the metric values with dimensions for sample, + forecast_step, lead_time, channel, and metric. The DataArray includes attributes + ``npoints_per_sample`` and the metric name as a coordinate. + If no data was found for the specified region, metric, forecast steps, and channels, + None is returned instead. """ metric_name = _metric_quaver_convention(metric) region_name = _region_quaver_convention(region) @@ -137,14 +161,20 @@ def get_values( & (self.data["channel"].isin(channels)) ] + if data.empty: + _logger.warning( + f"No values were found for region '{region}', metric '{metric}', " + f"forecast steps '{forecast_steps}', and channels '{channels}'" + ) + return None + + # convert to DataArray data = data.copy() data["sample"] = data["date"].astype("category").cat.codes data["forecast_step"] = data["step"].astype("category").cat.codes - data = data.rename(columns={"step": "lead_time"}) - data = data.rename(columns={"score": "metric"}) - - data = data[["sample", "forecast_step", "lead_time", "channel", "metric", "value"]] - df = data.set_index(["sample", "forecast_step", "channel", "metric"]) + data = data.rename(columns={"step": "lead_time", "score": "metric"}) + cols = ["sample", "forecast_step", "lead_time", "channel", "metric", "value"] + df = data[cols].set_index(["sample", "forecast_step", "channel", "metric"]) da = df["value"].to_xarray() lead_time_map = ( @@ -157,31 +187,28 @@ def get_values( lead_time=("forecast_step", lead_time_map.loc[da.forecast_step.values].values) ) - da.attrs["npoints_per_sample"] = self.npoints_per_sample da["metric"] = [metric] return da - def load_scores(self, stream: str, regions: str, metrics: str) -> xr.DataArray: + def load_scores(self, stream: str, regions: list[str], metrics: list[str]) -> tuple[dict, None]: """ Load the existing scores for a given run, stream and metric. Parameters ---------- - reader : - Reader object containing all info for a specific run_id - stream : + stream : str Stream name. - regions : - Regions name. - metrics : - Metrics name. + regions : list[str] + List of region names. + metrics : list[str] + List of metric names. Returns ------- - The metric DataArray. + scores: tuple[dict,None] + Dictionary of local scores keyed by metric/region/stream/run_id. """ - available_data = self.check_availability(stream, mode="evaluation") channels = available_data.channels fsteps = available_data.fsteps @@ -190,13 +217,12 @@ def load_scores(self, stream: str, regions: str, metrics: str) -> xr.DataArray: local_scores = {} for metric in metrics: + local_scores[metric] = {} for region in regions: - # fill it only for matching metric data = self.get_values( region=region, metric=metric, forecast_steps=fsteps, channels=channels ) - - if data.size == 0: + if data is None: data = xr.DataArray( np.full( (len(samples), len(fsteps), len(channels), 1), @@ -206,18 +232,16 @@ def load_scores(self, stream: str, regions: str, metrics: str) -> xr.DataArray: dims=("sample", "forecast_step", "channel", "metric"), coords={ "sample": samples, - "lead_time": fsteps, + "lead_time": ("forecast_step", fsteps), "forecast_step": range(len(fsteps)), "channel": channels, "metric": [metric], }, - attrs={"npoints_per_sample": self.npoints_per_sample}, ) - local_scores.setdefault(metric, {}).setdefault(region, {}).setdefault(stream, {})[ - self.run_id - ] = data - + local_scores[metric].setdefault(region, {})[stream] = { + self.run_id: data, + } return local_scores, None @@ -227,10 +251,12 @@ def _metric_quaver_convention(metric: str) -> str: Parameters ---------- - metric : + metric : str Original metric name. + Returns ------- + metric: str Metric name in Quaver convention. """ metric_mapping = { @@ -247,12 +273,15 @@ def _metric_quaver_convention(metric: str) -> str: def _region_quaver_convention(region: str) -> str: """ Convert region name to Quaver convention if needed. + Parameters ---------- - region : + region : str Original region name. + Returns ------- + region: str Region name in Quaver convention. """ region_mapping = { @@ -261,30 +290,3 @@ def _region_quaver_convention(region: str) -> str: # Add more mappings as needed } return region_mapping.get(region, region) - - -##### Helper function for CSVReader #### -def _rename_channels(data) -> pd.DataFrame: - """ - The scores downloaded from Quaver have a different convention. Need renaming. - Rename channel names to include underscore between letters and digits. - E.g., 'z500' -> 'z_500', 't850' -> 't_850', '2t' -> '2t', '10ff' -> '10ff' - - Parameters - ---------- - name : - Original channel name. - - Returns - ------- - Dataset with renamed channel names. - """ - for name in list(data.index): - # If it starts with digits (surface vars like 2t, 10ff) → leave unchanged - if re.match(r"^\d", name): - continue - - # Otherwise, insert underscore between letters and digits - data = data.rename(index={name: re.sub(r"([a-zA-Z])(\d+)", r"\1_\2", name)}) - - return data diff --git a/packages/evaluate/src/weathergen/evaluate/io/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io/io_reader.py index f53bd0a31..405cfa949 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/io_reader.py @@ -10,6 +10,7 @@ # Standard library import logging import re +from abc import ABC, abstractmethod from dataclasses import dataclass # Third-party @@ -29,13 +30,10 @@ class ReaderOutput: Dictionary of xarray Datasets for targets, indexed by forecast step. prediction : dict[str, xr.Dataset] Dictionary of xarray Datasets for predictions, indexed by forecast step. - points_per_sample : xr.DataArray | None - xarray DataArray containing the number of points per sample, if `return_counts` is True """ target: dict[str, xr.Dataset] prediction: dict[str, xr.Dataset] - points_per_sample: xr.DataArray | None @dataclass @@ -52,6 +50,8 @@ class DataAvailability: List of forecast steps requested samples: List of samples requested + ensemle: + List of ensemble member identifiers """ score_availability: bool @@ -61,33 +61,30 @@ class DataAvailability: ensemble: list[str] | None = None -class Reader: +class Reader(ABC): def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] | None = None): """ Generic data reader class. Parameters ---------- - eval_cfg : - config with plotting and evaluation options for that run id - run_id : - run id of the model + eval_cfg : dict + Config with plotting and evaluation options for that run id. + run_id : str + Run identifier of the model private_paths: - dictionary of private paths for the supported HPC + Dictionary of private paths for the supported HPC """ self.eval_cfg = eval_cfg self.run_id = run_id self.private_paths = private_paths - self.streams = eval_cfg.streams.keys() + self.streams = list(eval_cfg.streams.keys()) # TODO: propagate it to the other functions using global plotting opts self.global_plotting_options = eval_cfg.get("global_plotting_options", {}) - # If results_base_dir and model_base_dir are not provided, default paths are used - self.model_base_dir = self.eval_cfg.get("model_base_dir", None) - - self.results_base_dir = self.eval_cfg.get( - "results_base_dir", None - ) # base directory where results will be stored + # Default paths if not provided + self.model_base_dir = eval_cfg.get("model_base_dir") + self.results_base_dir = eval_cfg.get("results_base_dir") def get_stream(self, stream: str): """ @@ -105,39 +102,44 @@ def get_stream(self, stream: str): """ return self.eval_cfg.streams.get(stream, {}) + @abstractmethod def get_samples(self) -> set[int]: """Placeholder implementation of sample getter. Override in subclass.""" - return set() + pass + @abstractmethod def get_forecast_steps(self) -> set[int]: """Placeholder implementation forecast step getter. Override in subclass.""" - return set() + pass # TODO: get this from config + @abstractmethod def get_channels(self, stream: str | None = None) -> list[str]: """Placeholder implementation channel names getter. Override in subclass.""" - return list() + pass + @abstractmethod def get_ensemble(self, stream: str | None = None) -> list[str]: """Placeholder implementation ensemble member names getter. Override in subclass.""" - return list() + pass - def is_regular(self, stream: str) -> bool: + def is_gridded_data(self, stream: str) -> bool: """ Placeholder implementation to check if lat/lon are regularly spaced. Override in subclass. """ return True + @abstractmethod def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray: """Placeholder to load pre-computed scores for a given run, stream, metric""" - return None + pass def check_availability( self, stream: str, available_data: dict | None = None, - mode: str = "", + mode: str = "evaluation", ) -> DataAvailability: """ Check if requested channels, forecast steps and samples are @@ -145,15 +147,18 @@ def check_availability( ii) available in the source file (e.g. the Zarr file, return error otherwise) Additionally, if channels, forecast steps or samples is None/'all', it will i) set the variable to all available vars in source file - ii) return True only if the respective variable contains the same indeces in metric file - and source file (return False otherwise) + ii) return True only if the respective variable contains the same indices in + metric file and source file (return False otherwise) Parameters ---------- - stream : + stream : str The stream considered. - available_data : - The available data loaded from metric file. + available_data : dict or None + Available data loaded from metric file. + mode : str + Mode string. Can be 'evaluation' or 'plotting'. + Returns ------- DataAvailability @@ -161,9 +166,10 @@ def check_availability( - channels: list of channels or None if 'all' - fsteps: list of forecast steps or None if 'all' - samples: list of samples or None if 'all' + - ensemble: list of ensembleor None if 'all' """ - # fill info for requested channels, fsteps, samples + # Fill requested info for channels, fsteps, samples, ensemble requested_data = self._get_channels_fsteps_samples(stream, mode) channels = requested_data.channels @@ -177,7 +183,7 @@ def check_availability( "ensemble": set(ensemble) if ensemble is not None else None, } - # fill info from available metric file (if provided) + # Extract available info from metric file (if provided) available = { "channel": ( set(available_data["channel"].values.ravel()) @@ -196,12 +202,12 @@ def check_availability( ), "ensemble": ( set(available_data["ens"].values.ravel()) - if available_data is not None and "ens" in available_data.coords + if (available_data is not None and "ens" in available_data.coords) else set() ), } - # fill info from reader + # Extract actual reader data (from source) reader_data = { "fstep": set(int(f) for f in self.get_forecast_steps()), "sample": set(int(s) for s in self.get_samples()), @@ -211,6 +217,7 @@ def check_availability( check_score = True corrected = False + for name in ["channel", "fstep", "sample", "ensemble"]: if requested[name] is None: # Default to all in Zarr @@ -218,8 +225,8 @@ def check_availability( # If file with metrics exists, must exactly match if available_data is not None and reader_data[name] != available[name]: _logger.info( - f"Requested all {name}s for {mode}, but previous config was a " - "strict subset. Recomputation required." + f"Requested all {name}s for {mode}, but previous config " + "was a strict subset. Recomputation required." ) check_score = False @@ -227,8 +234,10 @@ def check_availability( if not requested[name] <= reader_data[name]: missing = requested[name] - reader_data[name] + # Special handling for ensemble mean if name == "ensemble" and "mean" in missing: missing.remove("mean") + if missing: _logger.info( f"Requested {name}(s) {missing} is unavailable. " @@ -241,8 +250,8 @@ def check_availability( if available_data is not None and not requested[name] <= available[name]: missing = requested[name] - available[name] _logger.info( - f"{name.capitalize()}(s) {missing} missing in previous evaluation." - "Recomputation required." + f"{name.capitalize()}(s) {missing} missing in previous " + "evaluation. Recomputation required." ) check_score = False @@ -281,40 +290,53 @@ def _get_channels_fsteps_samples(self, stream: str, mode: str) -> DataAvailabili - fsteps: list of forecast steps or None if 'all' - samples: list of samples or None if 'all' """ - assert mode == "plotting" or mode == "evaluation", ( - "get_channels_fsteps_samples:: Mode should be either 'plotting' or 'evaluation'" + + # Helper function to process range strings like '1-3' into lists [1,2,3] + def _parse_range_list(value, name): + if isinstance(value, str) and value != "all": + assert re.match(r"^\d+-\d+$", value), ( + f"String format for {name} in config must be " + f"'digit-digit' or 'all'. " + f"Got '{value}'." + ) + start, end = map(int, value.split("-")) + return list(range(start, end + 1)) + return value + + # Normalize None vs "all" + def normalize(val): + return ( + None + if (val == "all" or val is None) + else list(val) + if isinstance(val, list) + else val + ) + + assert mode in ("plotting", "evaluation"), ( + f"Mode must be either 'plotting' or 'evaluation'. Got '{mode}' instead." ) stream_cfg = self.get_stream(stream) - assert stream_cfg.get(mode, False), "Mode does not exist in stream config. Please add it." + assert stream_cfg.get(mode, False), ( + f"Mode '{mode}' does not exist in stream config for '{stream}'. Please add it." + ) samples = stream_cfg[mode].get("sample", None) fsteps = stream_cfg[mode].get("forecast_step", None) channels = stream_cfg.get("channels", None) ensemble = stream_cfg[mode].get("ensemble", None) + if ensemble == "mean": ensemble = ["mean"] - if isinstance(fsteps, str) and fsteps != "all": - assert re.match(r"^\d+-\d+$", fsteps), ( - "String format for forecast_step in config must be 'digit-digit' or 'all'" - ) - fsteps = list(range(int(fsteps.split("-")[0]), int(fsteps.split("-")[1]) + 1)) - if isinstance(samples, str) and samples != "all": - assert re.match(r"^\d+-\d+$", samples), ( - "String format for sample in config must be 'digit-digit' or 'all'" - ) - samples = list(range(int(samples.split("-")[0]), int(samples.split("-")[1]) + 1)) - if isinstance(ensemble, str) and ensemble not in {"all", "mean"}: - assert re.match(r"^\d+-\d+$", ensemble), ( - "String format for sample in config must be 'digit-digit' or 'all'" - ) - ensemble = list(range(int(ensemble.split("-")[0]), int(ensemble.split("-")[1]) + 1)) + fsteps = _parse_range_list(fsteps, "forecast_step") + samples = _parse_range_list(samples, "sample") return DataAvailability( score_availability=True, - channels=None if (channels == "all" or channels is None) else list(channels), - fsteps=None if (fsteps == "all" or fsteps is None) else list(fsteps), - samples=None if (samples == "all" or samples is None) else list(samples), - ensemble=None if (ensemble == "all" or ensemble is None) else list(ensemble), + channels=normalize(channels), + fsteps=normalize(fsteps), + samples=normalize(samples), + ensemble=normalize(ensemble), ) diff --git a/packages/evaluate/src/weathergen/evaluate/io/merge_reader.py b/packages/evaluate/src/weathergen/evaluate/io/merge_reader.py new file mode 100644 index 000000000..f3defd890 --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/io/merge_reader.py @@ -0,0 +1,360 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +# Standard library +import logging +from pathlib import Path + +# Third-party +import xarray as xr + +# Local application / package +from weathergen.evaluate.io.io_reader import Reader, ReaderOutput +from weathergen.evaluate.io.wegen_reader import WeatherGenJsonReader, WeatherGenZarrReader +from weathergen.evaluate.utils.utils import merge + +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + + +class WeatherGenMergeReader(Reader): + def __init__( + self, + eval_cfg: dict, + run_id: str, + private_paths: dict | None = None, + regions: list[str] | None = None, + metrics: list[str] | None = None, + reader_type: str = "zarr", + ): + """ + Data reader class for merging WeatherGenerator model outputs stored in Zarr or JSON format. + + Parameters + ---------- + eval_cfg: dict + config with plotting and evaluation options for that run id + run_id: str + run id of the model + private_paths: dict + dictionary of private paths for the supported HPC + regions: list[str] + names of predefined bounding box for a region (only used for WeatherGenJsonReader) + metrics: list[str] + names of the metric scores to compute (only used for WeatherGenJsonReader) + reader_type: str + The type of the internal reader. If zarr, WeatherGenZarrReader is used, + WeatherGenJsonReader otherwise. Default: zarr + """ + super().__init__(eval_cfg, run_id, private_paths) + self.run_ids = eval_cfg.get("merge_run_ids", []) + self.metrics_dir = Path(eval_cfg.get("merge_metrics_dir")) + self.mini_epoch = eval_cfg.get("mini_epoch", 0) + + assert self.run_ids, ( + f"'merge_run_ids' must be non-empty in eval_cfg, but got: {self.run_ids}" + ) + + _logger.info(f"Initialising merge reader with {len(self.run_ids)} run(s): {self.run_ids}") + + self.readers: list[Reader] = [] + + for i, run_id in enumerate(self.run_ids): + _logger.debug( + f"Creating internal reader {i + 1}/{len(self.run_ids)} for run_id '{run_id}' ..." + ) + if reader_type == "zarr": + reader = WeatherGenZarrReader(self.eval_cfg, run_id, self.private_paths) + else: + reader = WeatherGenJsonReader( + self.eval_cfg, run_id, self.private_paths, regions, metrics + ) + self.readers.append(reader) + _logger.debug(f"Instantiated reader for run_id '{run_id}' with {reader_type}") + + _logger.info(f"Instantiated {len(self.readers)} internal readers of type {reader_type}.") + + def get_data( + self, + stream: str, + samples: list[int] | None = None, + fsteps: list[str] | None = None, + channels: list[str] | None = None, + ensemble: list[str] | None = None, + return_counts: bool = False, + ) -> ReaderOutput: + """ + Retrieve prediction and target data for a given run from the Zarr store. + + Parameters + ---------- + cfg : + Configuration dictionary containing all information for the evaluation. + + results_dir : Path + Directory where the inference results are stored. + Expected scheme `/`. + stream : + Stream name to retrieve data for. + samples : + List of sample indices to retrieve. If None, all samples are retrieved. + fsteps : + List of forecast steps to retrieve. If None, all forecast steps are retrieved. + channels : + List of channel names to retrieve. If None, all channels are retrieved. + return_counts : + If True, also return the number of points per sample. + Returns + ------- + ReaderOutput + A dataclass containing: + - target: Dictionary of xarray DataArrays for targets, indexed by forecast step. + - prediction: Dictionary of xarray DataArrays for predictions, indexed by forecast + step. + - points_per_sample: xarray DataArray containing the number of points per sample, + if `return_counts` is True. + """ + + da_tars_merge, da_preds_merge, fsteps_merge = [], [], [] + + points_per_sample = None + + for reader in self.readers: + da_tars, da_preds, da_fsteps = [], [], [] + _logger.info(f"MERGE READERS: Processing run_id {reader.run_id}...") + + out = reader.get_data( + stream, + samples, + fsteps, + channels, + ensemble="mean", + ) + + for fstep in out.target.keys(): + _logger.debug(f"MERGE READERS: Processing fstep {fstep}...") + + da_tars.append(out.target[fstep]) + da_preds.append(out.prediction[fstep]) + da_fsteps.append(fstep) + + if return_counts: + if points_per_sample is None: + points_per_sample = out.points_per_sample + else: + points_per_sample += out.points_per_sample + + da_tars_merge.append(da_tars) + da_preds_merge.append(da_preds) + fsteps_merge.append(da_fsteps) + + da_tars_merge = self._concat_over_ens(da_tars_merge, fsteps_merge) + da_preds_merge = self._concat_over_ens(da_preds_merge, fsteps_merge) + + return ReaderOutput( + target=da_tars_merge, prediction=da_preds_merge, points_per_sample=points_per_sample + ) + + def _concat_over_ens(self, da_merge, fsteps_merge): + """ + Parameters + ---------- + da_merge : list[list[xr.DataArray]] + Outer list over readers, inner list over forecast steps. + fsteps_merge : list[list[int]] + Forecast steps per reader (must be identical across readers). + + Returns + ------- + dict[int, xr.DataArray] + DataArrays concatenated over new 'ens' dimension, keyed by fstep. + """ + n_readers = len(da_merge) + + # use fsteps from first reader as reference + fsteps = fsteps_merge[0] + + da_ens = {} + for k, fstep in enumerate(fsteps): + da_list = [da_merge[i][k] for i in range(n_readers)] + da_ens[fstep] = xr.concat(da_list, dim="ens").assign_coords(ens=range(n_readers)) + + return da_ens + + def load_scores( + self, stream: str, regions: list[str], metrics: list[str] + ) -> xr.DataArray | None: + """ + Load the pre-computed scores for a given run, stream and metric and epoch. + + Parameters + ---------- + reader : + Reader object containing all info for a specific run_id + stream : + Stream name. + regions : + Region names. + metrics : + Metric names. + Returns + ------- + xr.DataArray + The metric DataArray. + missing_metrics: + dictionary of missing regions and metrics that need to be recomputed. + """ + local_scores = {} + missing_metrics = {} + + if isinstance(self.readers[0], WeatherGenZarrReader): + # TODO: implement this properly. Not it is skipping loading scores + for region in regions: + for metric, parameters in metrics.items(): + # all other cases: recompute scores + missing_metrics.setdefault(region, {}).update({metric: parameters}) + else: + local_scores, missing_metrics = self._load_scores_json(stream, regions, metrics) + return local_scores, missing_metrics + + def get_climatology_filename(self, stream: str) -> str | None: + """ + Get the climatology filename for a given stream from the inference configuration. + Parameters + ---------- + stream : + Name of the data stream. + Returns + ------- + Climatology filename if specified, otherwise None. + """ + for reader in self.readers: + clim_data_path = reader.get_climatology_filename(stream) + if clim_data_path: + return clim_data_path + return None + + def get_stream(self, stream: str): + """ + returns the dictionary associated to a particular stream. + Returns an empty dictionary if the stream does not exist in the Zarr file. + + Parameters + ---------- + stream: + the stream name + + Returns + ------- + The config dictionary associated to that stream + """ + stream_dict = self.eval_cfg.streams.get(stream, {}) + return stream_dict + + def get_samples(self) -> set[int]: + """Get the set of sample indices from the Zarr file.""" + samples = [] + for reader in self.readers: + samples.append(reader.get_samples()) + return set.intersection(*map(set, samples)) + + def get_forecast_steps(self) -> set[int]: + """Get the set of forecast steps from the Zarr file.""" + forecast_steps = [] + for reader in self.readers: + forecast_steps.append(reader.get_forecast_steps()) + return set.intersection(*map(set, forecast_steps)) + + def get_channels(self, stream: str) -> list[str]: + """ + Get the list of channels for a given stream from the config. + + Parameters + ---------- + stream : + The name of the stream to get channels for. + + Returns + ------- + A list of channel names. + """ + all_channels = [] + + for reader in self.readers: + all_channels.append(reader.get_channels(stream)) + + return set.intersection(*map(set, all_channels)) + + def get_ensemble(self, stream: str | None = None) -> list[str]: + """Get the list of ensemble member names for a given stream from the config. + Parameters + ---------- + stream : + The name of the stream to get channels for. + + Returns + ------- + A range of ensemble members equal to the number of merged readers. + """ + _logger.debug(f"Getting ensembles for stream {stream}...") + all_ensembles = [] + for reader in self.readers: + all_ensembles.append(reader.get_ensemble(stream)) + + assert all(e == ["0"] or e == [0] or e == {0} for e in all_ensembles), ( + "Merging readers with multiple ensemble members is not supported yet." + ) + return set(range(len(self.readers))) + + # TODO: improve this + def is_gridded_data(self, stream: str) -> bool: + """Check if the latitude and longitude coordinates are regularly spaced for a given stream. + Parameters + ---------- + stream : + The name of the stream to get channels for. + + Returns + ------- + True if the stream is regularly spaced. False otherwise. + """ + _logger.debug(f"Checking regular spacing for stream {stream}...") + return all(reader.is_gridded_data(stream) for reader in self.readers) + + def _load_scores_json(self, stream, regions, metrics): + "Concatenate the scores of all JSON readers" + + local_scores = {} + missing_metrics = {} + + # deep merge dicts + for reader in self.readers: + scores, missing = reader.load_scores(stream, regions, metrics) + merge(local_scores, scores) + merge(missing_metrics, missing) + + # merge runs into one with all scores concatenated + for metric in local_scores.keys(): + for region in local_scores[metric].keys(): + for stream in local_scores[metric][region].keys(): + assert len(local_scores[metric][region][stream].keys()) == len(self.run_ids), ( + f"Not all run ids are distinct or have the requested precomputed " + f"scores for metric: {metric}, region: {region}, stream: {stream}" + ) + + scores = ( + local_scores[metric][region][stream].pop(run_id) for run_id in self.run_ids + ) + + local_scores[metric][region][stream].setdefault( + self.run_id, + xr.concat(scores, dim="ens").assign_coords(ens=range(len(self.readers))), + ) + + return local_scores, missing_metrics diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index e9010667e..4c75fef32 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -29,7 +29,6 @@ from weathergen.evaluate.io.io_reader import Reader, ReaderOutput from weathergen.evaluate.scores.score_utils import to_list from weathergen.evaluate.utils.derived_channels import DeriveChannels -from weathergen.evaluate.utils.utils import merge _logger = logging.getLogger(__name__) _logger.setLevel(logging.INFO) @@ -42,6 +41,7 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non # TODO: remove backwards compatibility to "epoch" in Feb. 2026 self.mini_epoch = eval_cfg.get("mini_epoch", 0) self.rank = eval_cfg.get("rank", 0) + # Load model configuration and set (run-id specific) directories self.inference_cfg = self.get_inference_config() @@ -61,25 +61,25 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non self.step_hrs = self.inference_cfg.get("step_hrs", 1) - self.results_dir, self.runplot_dir = ( - Path(self.results_base_dir), - Path(self.runplot_base_dir), - ) # for backward compatibility allow metric_dir to be specified in the run config + self.results_dir = Path(self.results_base_dir) + self.runplot_dir = Path(self.runplot_base_dir) self.metrics_dir = Path( self.eval_cfg.get("metrics_dir", self.metrics_base_dir / "evaluation") ) def get_inference_config(self): """ - load the config associated to the inference run (different from the eval_cfg which - contains plot and evaluaiton options.) + Load the config associated to the inference run (different from the + eval_cfg which contains plot and evaluation options.) Returns ------- - dict - configuration file from the inference run + config: dict + Configuration file from the inference run """ + config = {} + if self.private_paths: _logger.info( f"Loading config for run {self.run_id} from private paths: {self.private_paths}" @@ -91,47 +91,49 @@ def get_inference_config(self): ) config = load_run_config(self.run_id, self.mini_epoch, self.model_base_dir) - if type(config) not in [dict, oc.DictConfig]: + if not isinstance(config, dict | oc.DictConfig): _logger.warning("Model config not found. inference config will be empty.") config = {} + return config def get_climatology_filename(self, stream: str) -> str | None: """ - Get the climatology filename for a given stream from the inference configuration. + Get the climatology filename for a given stream from the inference + configuration. + Parameters ---------- - stream : + stream : str Name of the data stream. + Returns ------- - Climatology filename if specified, otherwise None. + path: str | None + Full climatology path if available, otherwise None. """ - stream_dict = self.get_stream(stream) clim_data_path = stream_dict.get("climatology_path", None) if not clim_data_path: clim_base_dir = self.inference_cfg.get("data_path_aux", None) - clim_fn = next( ( item.get("climatology_filename") - for item in self.inference_cfg["streams"] + for item in self.inference_cfg.get("streams", []) if item.get("name") == stream ), None, ) - if clim_base_dir and clim_fn: - clim_data_path = Path(clim_base_dir).join(clim_fn) + clim_data_path = Path(clim_base_dir) / clim_fn else: _logger.warning( f"No climatology path specified for stream {stream}. Setting climatology to " "NaN. Add 'climatology_path' to evaluation config to use metrics like ACC." ) - return clim_data_path + return str(clim_data_path) if clim_data_path else None def get_channels(self, stream: str) -> list[str]: """ @@ -139,11 +141,12 @@ def get_channels(self, stream: str) -> list[str]: Parameters ---------- - stream : + stream : str The name of the stream to get channels for. Returns ------- + all_channels: list[str] A list of channel names. """ _logger.debug(f"Getting channels for stream {stream}...") @@ -153,36 +156,36 @@ def get_channels(self, stream: str) -> list[str]: def load_scores( self, stream: str, regions: list[str], metrics: dict[str, object] - ) -> xr.DataArray | None: + ) -> tuple[dict, dict]: """ - Load multiple pre-computed scores for a given run, stream and metric and epoch. + Load multiple pre-computed scores for a given run, stream and metric + and epoch. Parameters ---------- - reader : - Reader object containing all info for a specific run_id - stream : + stream : str Stream name. - regions : + regions : list[str] Region names. - metrics : + metrics : list[str] Metric names. Returns ------- - xr.DataArray - The metric DataArray. - computable_metrics: - dictionary of regions and metrics that can be recomputed - (empty for JSONreader). + tuple[dict, dict] + - local_scores: dictionary of available scores. + - recomputable_missing_metrics: dictionary of regions and metrics + that must be recomputed (empty for JSON reader). """ - local_scores = {} missing_metrics = {} for region in regions: for metric, parameters in metrics.items(): score = self.load_single_score(stream, region, metric, parameters) - if score is not None: + if score is None: + # all other cases: recompute scores + missing_metrics.setdefault(region, {}).update({metric: parameters}) + else: available_data = self.check_availability(stream, score, mode="evaluation") if available_data.score_availability: score = score.sel( @@ -193,11 +196,7 @@ def load_scores( local_scores.setdefault(metric, {}).setdefault(region, {}).setdefault( stream, {} )[self.run_id] = score - continue - # all other cases: recompute scores - missing_metrics.setdefault(region, {}).update({metric: parameters}) - continue recomputable_missing_metrics = self.get_recomputable_metrics(missing_metrics) return local_scores, recomputable_missing_metrics @@ -205,7 +204,12 @@ def load_single_score( self, stream: str, region: str, metric: str, parameters: dict | None = None ) -> xr.DataArray | None: """ - Load a single pre-computed score for a given run, stream and metric + Load a single pre-computed score for a given run, stream and metric. + + Returns + ------- + score: xr.DataArray or None + DataArray of the score if found, else None. """ if parameters is None: parameters = {} @@ -214,6 +218,7 @@ def load_single_score( / f"{self.run_id}_{stream}_{region}_{metric}_chkpt{self.mini_epoch:05d}.json" ) _logger.debug(f"Looking for: {score_path}") + score = None if score_path.exists(): with open(score_path) as f: @@ -226,8 +231,20 @@ def load_single_score( break return score - def get_recomputable_metrics(self, metrics): - """determine whether given metrics can be re-computed.""" + def get_recomputable_metrics(self, metrics: dict) -> dict: + """ + Determine which metrics can be recomputed. + + Parameters + ---------- + metrics : dict + Dictionary mapping regions to missing metrics. + + Returns + ------- + metrics: dict + Same as input + """ return metrics def get_inference_stream_attr(self, stream_name: str, key: str, default=None): @@ -236,16 +253,15 @@ def get_inference_stream_attr(self, stream_name: str, key: str, default=None): Parameters: ------------ - config: - The full configuration dictionary. - stream_name: + stream_name: str The name of the stream (e.g. 'ERA5'). - key: + key: str The key to look up (e.g. 'tokenize_spacetime'). default: Optional Value to return if not found (default: None). Returns: + ------------ The parameter value if found, otherwise the default. """ for stream in self.inference_cfg.get("streams", []): @@ -254,7 +270,7 @@ def get_inference_stream_attr(self, stream_name: str, key: str, default=None): return default -class WeatherGenJSONReader(WeatherGenReader): +class WeatherGenJsonReader(WeatherGenReader): def __init__( self, eval_cfg: dict, @@ -264,13 +280,15 @@ def __init__( metrics: dict[str, object] | None = None, ): super().__init__(eval_cfg, run_id, private_paths) - # goes looking for the coordinates available for all streams, regions, metrics - streams = list(self.eval_cfg.streams.keys()) + self.common_coords: dict = self._compute_common_coords(regions, metrics) + + def _compute_common_coords(self, regions: list[str], metrics: list[str]) -> dict: + # Find common coordinates across streams, regions, metrics. + streams = list(self.streams) coord_names = ["sample", "forecast_step", "ens"] - all_coords = {name: [] for name in coord_names} # collect all available coordinates - provenance = { - name: defaultdict(list) for name in coord_names - } # remember who had which coords, so we can warn about it later. + all_coords = {name: [] for name in coord_names} + provenance = {name: defaultdict(list) for name in coord_names} + for stream in streams: for region in regions: for metric, parameters in metrics.items(): @@ -281,15 +299,21 @@ def __init__( all_coords[name].append(vals) for val in vals: provenance[name][val].append((stream, region, metric)) - self.common_coords = {name: set.intersection(*all_coords[name]) for name in coord_names} - # issue warnings for skipped coords + + common_coords = {name: set.intersection(*all_coords[name]) for name in coord_names} + + # Warn about any skipped coordinates for name in coord_names: - skipped = set.union(*all_coords[name]) - self.common_coords[name] + skipped = set.union(*all_coords[name]) - common_coords[name] if skipped: - message = [f"Some {name}(s) were not common among streams, regions and metrics:"] + msg_lines = [ + f"Some {name}(s) were not common across streams, regions, and metrics:" + ] for val in skipped: - message.append(f" {val} only in {provenance[name][val]}") - _logger.warning("\n".join(message)) + msg_lines.append(f" {val} only present in {provenance[name][val]}") + _logger.warning("\n".join(msg_lines)) + + return common_coords def get_samples(self) -> set[int]: return self.common_coords["sample"] @@ -303,7 +327,7 @@ def get_ensemble(self, stream: str | None = None) -> list[str]: def get_data(self, *args, **kwargs): # TODO this should not be needed, the reader should not even be created if this is the case # it can still happen when a particular score was available for a different channel - raise ValueError(f"Missing JSON data for run {self.run_id}.") + assert False, f"Missing JSON data for run {self.run_id}." def get_recomputable_metrics(self, metrics): _logger.info( @@ -318,19 +342,21 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non super().__init__(eval_cfg, run_id, private_paths) zarr_ext = self.inference_cfg.get("zarr_store", "zarr") - # for backwards compatibility assume zarr store is local i.e. .zarr format + # For backwards compatibility, assume zarr store is local (.zarr format). fname_zarr = self.results_dir.joinpath( f"validation_chkpt{self.mini_epoch:05d}_rank{self.rank:04d}.{zarr_ext}" ) - if fname_zarr.exists(): - if (zarr_ext == "zarr" and fname_zarr.is_dir()) or ( - zarr_ext == "zip" and fname_zarr.is_file() - ): - self.fname_zarr = fname_zarr - else: - _logger.error(f"Zarr file {fname_zarr} does not exist.") - raise FileNotFoundError(f"Zarr file {fname_zarr} does not exist") + + assert fname_zarr.exists(), f"Zarr file {fname_zarr} does not exist." + + assert (zarr_ext == "zarr" and fname_zarr.is_dir()) or ( + zarr_ext == "zip" and fname_zarr.is_file() + ), ( + f"Zarr file {fname_zarr} has unexpected format. ({zarr_ext}). " + f"Expected directory for 'zarr' or file for 'zip'." + ) + self.fname_zarr = fname_zarr def get_data( self, @@ -339,7 +365,6 @@ def get_data( fsteps: list[str] | None = None, channels: list[str] | None = None, ensemble: list[str] | None = None, - return_counts: bool = False, ) -> ReaderOutput: """ Retrieve prediction and target data for a given run from the Zarr store. @@ -359,59 +384,37 @@ def get_data( List of forecast steps to retrieve. If None, all forecast steps are retrieved. channels : List of channel names to retrieve. If None, all channels are retrieved. - return_counts : - If True, also return the number of points per sample. Returns ------- - ReaderOutput + out: ReaderOutput A dataclass containing: - target: Dictionary of xarray DataArrays for targets, indexed by forecast step. - prediction: Dictionary of xarray DataArrays for predictions, indexed by forecast step. - - points_per_sample: xarray DataArray containing the number of points per sample, - if `return_counts` is True. """ - # get type of zarr store - - with zarrio_reader(self.fname_zarr) as zio: - stream_cfg = self.get_stream(stream) - all_channels = self.get_channels(stream) - _logger.info(f"RUN {self.run_id}: Processing stream {stream}...") + stream_cfg = self.get_stream(stream) + all_channels = self.get_channels(stream) + _logger.info(f"RUN {self.run_id}: Processing stream {stream}...") - fsteps = self.get_forecast_steps() if fsteps is None else fsteps + fsteps = self.get_forecast_steps() if fsteps is None else fsteps - # TODO: Avoid conversion of fsteps and sample to integers (as obtained from the ZarrIO) - fsteps = sorted([int(fstep) for fstep in fsteps]) - samples = samples or sorted([int(sample) for sample in self.get_samples()]) - channels = channels or stream_cfg.get("channels", all_channels) - channels = to_list(channels) + # TODO: Avoid conversion of fsteps and sample to integers (as obtained from the ZarrIO) + fsteps = sorted([int(fstep) for fstep in fsteps]) + samples = samples or sorted([int(sample) for sample in self.get_samples()]) + channels = channels or stream_cfg.get("channels", all_channels) + channels = to_list(channels) - ensemble = ensemble or self.get_ensemble(stream) - ensemble = to_list(ensemble) + ensemble = ensemble or self.get_ensemble(stream) + ensemble = to_list(ensemble) - dc = DeriveChannels( - all_channels, - channels, - stream_cfg, - ) - - da_tars, da_preds = [], [] - - if return_counts: - points_per_sample = xr.DataArray( - np.full((len(fsteps), len(samples)), np.nan), - coords={"forecast_step": fsteps, "sample": samples}, - dims=("forecast_step", "sample"), - name=f"points_per_sample_{stream}", - ) - else: - points_per_sample = None + da_tars, da_preds = [], [] - fsteps_final = [] + fsteps_final = [] + with zarrio_reader(self.fname_zarr) as zio: for fstep in fsteps: _logger.info(f"RUN {self.run_id} - {stream}: Processing fstep {fstep}...") - da_tars_fs, da_preds_fs, pps = [], [], [] + da_tars_fs, da_preds_fs, valid_times_fs = [], [], [] for sample in tqdm(samples, desc=f"Processing {self.run_id} - {stream} - {fstep}"): out = zio.get_data(sample, stream, fstep) @@ -426,7 +429,6 @@ def get_data( target, pred = out.target.as_xarray(), out.prediction.as_xarray() npoints = len(target.ipoint) - pps.append(npoints) if npoints == 0: _logger.info( @@ -442,8 +444,15 @@ def get_data( _logger.debug(f"Selecting ensemble members {ensemble}.") pred = pred.sel(ens=ensemble) - da_tars_fs.append(target.squeeze()) - da_preds_fs.append(pred.squeeze()) + pred = pred.squeeze() + target = target.squeeze() + + if self.is_gridded_data(stream): + vt_list = np.unique(target.valid_time.values).tolist() + valid_times_fs.append(vt_list) + + da_tars_fs.append(target.persist()) + da_preds_fs.append(pred.persist()) if not da_tars_fs: _logger.info( @@ -451,7 +460,7 @@ def get_data( ) continue - fsteps_final.append(fstep) + fsteps_final.append(valid_times_fs if valid_times_fs else fstep) _logger.debug( f"Concatenating targets and predictions for stream {stream}, " @@ -459,114 +468,59 @@ def get_data( ) # faster processing - if self.is_regular(stream): + if self.is_gridded_data(stream): # Efficient concatenation for regular grid - da_preds_fs = _force_consistent_grids(da_preds_fs) - da_tars_fs = _force_consistent_grids(da_tars_fs) + da_preds_fs = _split_by_valid_time(da_preds_fs) + da_tars_fs = _split_by_valid_time(da_tars_fs) - # add lead time coordinate - da_tars_fs = self.add_lead_time_coord(da_tars_fs) - da_preds_fs = self.add_lead_time_coord(da_preds_fs) + da_tars_fs = _force_consistent_grids(da_tars_fs) + da_preds_fs = _force_consistent_grids(da_preds_fs) else: # Irregular (scatter) case. concatenate over ipoint - da_tars_fs = xr.concat(da_tars_fs, dim="ipoint") - da_preds_fs = xr.concat(da_preds_fs, dim="ipoint") - - if len(samples) == 1: - _logger.debug("Repeating sample coordinate for single-sample case.") - for da in (da_tars_fs, da_preds_fs): - da.assign_coords( - sample=( - "ipoint", - np.repeat(da.sample.values, da.sizes["ipoint"]), - ) - ) - - if set(channels) != set(all_channels): - _logger.debug( - f"Restricting targets and predictions to channels {channels} " - f"for stream {stream}..." + da_tars_fs = xr.concat( + da_tars_fs, dim="ipoint", coords="different", compat="equals" ) - - da_tars_fs, da_preds_fs, channels = dc.get_derived_channels( - da_tars_fs, da_preds_fs + da_preds_fs = xr.concat( + da_preds_fs, dim="ipoint", coords="different", compat="equals" ) - da_tars_fs = da_tars_fs.sel(channel=channels) - da_preds_fs = da_preds_fs.sel(channel=channels) - - # apply z scaling if needed - da_tars_fs = self.scale_z_channels(da_tars_fs, stream) - da_preds_fs = self.scale_z_channels(da_preds_fs, stream) - da_tars.append(da_tars_fs) da_preds.append(da_preds_fs) - if return_counts: - points_per_sample.loc[{"forecast_step": fstep}] = np.array(pps) # Safer than a list - da_tars = {fstep: da for fstep, da in zip(fsteps_final, da_tars, strict=True)} - da_preds = {fstep: da for fstep, da in zip(fsteps_final, da_preds, strict=True)} - - return ReaderOutput( - target=da_tars, prediction=da_preds, points_per_sample=points_per_sample - ) - - ######## reader utils ######## - - def add_lead_time_coord(self, da: xr.DataArray, sample_dim="sample") -> xr.DataArray: - """ - Add lead_time coordinate computed as: - valid_time - source_interval_end - - lead_time has dims (sample, ipoint) and dtype timedelta64[ns]. + da_tars_dict, da_preds_dict = {}, {} + i = 1 - Parameters - ---------- - da : - Input DataArray - sample_dim : - The name of the sample dimension (default is "sample") which should be kept. - Collapse over the others. - Returns - ------- - Returns a Dataset with an added lead_time coordinate. - """ - - vt = da["valid_time"] - sis = da["source_interval_start"] - - vt_reduced = vt.min(dim=[d for d in vt.dims if d != sample_dim]) + for _, (fstep, da_t, da_p) in enumerate( + zip(fsteps_final, da_tars, da_preds, strict=True) + ): + if isinstance(fstep, list): # regular grid with lead times (1 or multiple) + for t, p in zip(da_t, da_p, strict=True): + t, p = _select_channels(t, p, stream, channels, stream_cfg) - lead_time = vt_reduced - sis + # But we also want to have a common forecast_step coordinate for all + # substeps to be able to apply the same metrics. + t = t.assign_coords(forecast_step=i) + p = p.assign_coords(forecast_step=i) - return da.assign_coords(lead_time=lead_time) + # TODO: move somewhere else into another loop maybe. but 2 loops is slow? + t = _add_lead_time_coord(t) + p = _add_lead_time_coord(p) - def scale_z_channels(self, data: xr.DataArray, stream: str) -> xr.DataArray: - """ - Check scale all channels. + p = _scale_z_channels(p, stream) + t = _scale_z_channels(t, stream) - Parameters - ---------- - data : - Input dataset - stream : - Stream name. - Returns - ------- - Returns a Dataset where channels have been scaled if needed - """ - if stream is None or not str(stream).startswith("ERA5"): - return data + da_tars_dict[i] = t + da_preds_dict[i] = p + i += 1 + else: + da_t, da_p = _select_channels(da_t, da_p, stream, channels, stream_cfg) + da_tars_dict[int(fstep)] = da_t + da_preds_dict[int(fstep)] = da_p - channels_z = [ch for ch in np.atleast_1d(data.channel.values) if str(ch).startswith("z_")] - factor = 9.80665 + return ReaderOutput(target=da_tars_dict, prediction=da_preds_dict) - if channels_z: - channels = data.channel.astype(str) - mask = channels.str.startswith("z_") - data = data.where(~mask, data / factor) - return data + ######## reader utils ######## def get_stream(self, stream: str): """ @@ -599,6 +553,17 @@ def get_forecast_steps(self) -> set[int]: with zarrio_reader(self.fname_zarr) as zio: return set(int(f) for f in zio.forecast_steps) + def get_forecast_substep_valid_times(self, stream: str) -> set[str]: + """Get the set of forecast times from the Zarr file.""" + if not self.is_gridded_data(stream): + _logger.warning(f"Stream {stream} is not gridded. Forecast times cannot be retrieved.") + return set() + + with zarrio_reader(self.fname_zarr) as zio: + dummy = zio.get_data(0, stream, zio.forecast_steps[0]) + unique_lead = np.unique(dummy.valid_time.data) + return set(str(lt) for lt in unique_lead) + def get_ensemble(self, stream: str | None = None) -> list[str]: """Get the list of ensemble member names for a given stream from the config. Parameters @@ -617,8 +582,7 @@ def get_ensemble(self, stream: str | None = None) -> list[str]: dummy = zio.get_data(0, stream, zio.forecast_steps[0]) return list(dummy.prediction.as_xarray().coords["ens"].values) - # TODO: improve this - def is_regular(self, stream: str) -> bool: + def is_gridded_data(self, stream: str) -> bool: """Check if the latitude and longitude coordinates are regularly spaced for a given stream. Parameters ---------- @@ -653,366 +617,301 @@ def is_regular(self, stream: str) -> bool: ): _logger.debug("Latitude and/or longitude coordinates are not regularly spaced.") return False - - _logger.debug("Latitude and longitude coordinates are regularly spaced.") - return True + else: + _logger.debug("Latitude and longitude coordinates are regularly spaced.") + return True ################### Helper functions ######################## -def _force_consistent_grids(ref: list[xr.DataArray]) -> xr.DataArray: +def _select_channels( + da_tar: xr.DataArray, da_pred: xr.DataArray, stream, channels, stream_cfg +) -> tuple[xr.DataArray, xr.DataArray]: """ - Force all samples to share the same ipoint order. + Preprocess the data by scaling z channels if needed and adding lead_time coordinate. Parameters ---------- - ref: - Input dataset + da_tar : + Input DataArray to preprocess. + da_pred : + Input DataArray to preprocess. + stream: + Stream name, used to determine if z channels need to be scaled. + channels: + List of channels to select. + stream_cfg: + Stream configuration dictionary, used to determine if derived channels need to be computed. Returns ------- - Returns a Dataset where all samples have the same lat lon and ipoint ordering + Data arrays with selected channels and added derived channels if applicable. """ + # Ensure channel is a dimension, not a scalar coordinate (can happen after squeeze) + if "channel" not in da_tar.dims: + da_tar = da_tar.expand_dims("channel") + if "channel" not in da_pred.dims: + da_pred = da_pred.expand_dims("channel") - # Pick first sample as reference - ref_lat = ref[0].lat - ref_lon = ref[0].lon - - sort_idx = np.lexsort((ref_lon.values, ref_lat.values)) - npoints = sort_idx.size - aligned = [] - samples = [] - for i, a in enumerate(ref): - a_sorted = a.isel(ipoint=sort_idx) - samples.append(a_sorted.sample.values) - a_sorted = a_sorted.assign_coords( - ipoint=np.arange(npoints), - lat=("ipoint", ref_lat.values[sort_idx]), - lon=("ipoint", ref_lon.values[sort_idx]), - ) - - if "sample" not in a_sorted.dims: - a_sorted = a_sorted.expand_dims(sample=[i]) - - aligned.append(a_sorted) + assert da_pred.channel.values.tolist() == da_tar.channel.values.tolist(), ( + "Channels in prediction and target do not match." + ) - return xr.concat(aligned, dim="sample").assign_coords({"sample": samples}) + all_channels = da_tar.channel.values.tolist() + if set(channels) != set(all_channels): + _logger.debug( + f"Restricting targets and predictions to channels {channels} for stream {stream}..." + ) -class WeatherGenMergeReader(Reader): - def __init__( - self, - eval_cfg: dict, - run_id: str, - private_paths: dict | None = None, - regions: list[str] | None = None, - metrics: list[str] | None = None, - reader_type: str = "zarr", - ): - """ - Data reader class for merging WeatherGenerator model outputs stored in Zarr or JSON format. - - Parameters - ---------- - eval_cfg: dict - config with plotting and evaluation options for that run id - run_id: str - run id of the model - private_paths: dict - dictionary of private paths for the supported HPC - regions: list[str] - names of predefined bounding box for a region - metrics: list[str] - names of the metric scores to compute - reader_type: str - The type of the internal reader. If zarr, WeatherGenZarrReader is used, - WeatherGenJSONReader otherwise. Default: zarr - """ - super().__init__(eval_cfg, run_id, private_paths) - self.run_ids = eval_cfg.get("merge_run_ids", []) - self.metrics_dir = Path(eval_cfg.get("merge_metrics_dir")) - self.mini_epoch = eval_cfg.get("mini_epoch", 0) - - self.readers = [] - - _logger.info(f"MERGE READERS: {self.run_ids} ...") + dc = DeriveChannels( + all_channels, + channels, + stream_cfg, + ) - for run_id in self.run_ids: - if reader_type == "zarr": - reader = WeatherGenZarrReader(self.eval_cfg, run_id, self.private_paths) - else: - reader = WeatherGenJSONReader( - self.eval_cfg, run_id, self.private_paths, regions, metrics - ) - self.readers.append(reader) + da_tar, da_pred, channels = dc.get_derived_channels(da_tar, da_pred) - def get_data( - self, - stream: str, - samples: list[int] | None = None, - fsteps: list[str] | None = None, - channels: list[str] | None = None, - ensemble: list[str] | None = None, - return_counts: bool = False, - ) -> ReaderOutput: - """ - Retrieve prediction and target data for a given run from the Zarr store. + da_tar = da_tar.sel(channel=channels) + da_pred = da_pred.sel(channel=channels) - Parameters - ---------- - cfg : - Configuration dictionary containing all information for the evaluation. + return da_pred, da_tar - results_dir : Path - Directory where the inference results are stored. - Expected scheme `/`. - stream : - Stream name to retrieve data for. - samples : - List of sample indices to retrieve. If None, all samples are retrieved. - fsteps : - List of forecast steps to retrieve. If None, all forecast steps are retrieved. - channels : - List of channel names to retrieve. If None, all channels are retrieved. - return_counts : - If True, also return the number of points per sample. - Returns - ------- - ReaderOutput - A dataclass containing: - - target: Dictionary of xarray DataArrays for targets, indexed by forecast step. - - prediction: Dictionary of xarray DataArrays for predictions, indexed by forecast step. - - points_per_sample: xarray DataArray containing the number of points per sample, - if `return_counts` is True. - """ - da_tars_merge, da_preds_merge, fsteps_merge = [], [], [] +def _scale_z_channels(data: xr.DataArray, stream: str) -> xr.DataArray: + """ + Check scale all channels. - points_per_sample = None + Parameters + ---------- + data : + Input dataset + stream : + Stream name. + Returns + ------- + Returns a Dataset where channels have been scaled if needed + """ + if stream is None or not str(stream).startswith("ERA5"): + return data - for reader in self.readers: - da_tars, da_preds, da_fsteps = [], [], [] - _logger.info(f"MERGE READERS: Processing run_id {reader.run_id}...") + channels_z = [ch for ch in np.atleast_1d(data.channel.values) if str(ch).startswith("z_")] + factor = 9.80665 - out = reader.get_data( - stream, - samples, - fsteps, - channels, - ensemble="mean", - ) + if channels_z: + channels = data.channel.astype(str) + mask = channels.str.startswith("z_") + data = data.where(~mask, data / factor) + return data - for fstep in out.target.keys(): - _logger.debug(f"MERGE READERS: Processing fstep {fstep}...") - da_tars.append(out.target[fstep]) - da_preds.append(out.prediction[fstep]) - da_fsteps.append(fstep) +def _split_by_valid_time(arrays: list[xr.DataArray]) -> list[xr.DataArray]: + """ + Split arrays by valid_time and stack by sample, creating separate + arrays for each unique lead_time. - if return_counts: - if points_per_sample is None: - points_per_sample = out.points_per_sample - else: - points_per_sample += out.points_per_sample + Lead_time is calculated as: valid_time - source_interval_start - da_tars_merge.append(da_tars) - da_preds_merge.append(da_preds) - fsteps_merge.append(da_fsteps) + Parameters + ---------- + arrays : list[xr.DataArray] + List of DataArrays, each containing multiple valid_times per sample - da_tars_merge = self._concat_over_ens(da_tars_merge, fsteps_merge) - da_preds_merge = self._concat_over_ens(da_preds_merge, fsteps_merge) + Returns + ------- + list[xr.DataArray] + List of DataArrays, one per unique lead_time, with samples + stacked along 'sample' dimension + """ + # Pre-compute all lead times and build index in single pass + lead_time_groups = {} # lead_time -> list of (arr_idx, ipoint_indices) - return ReaderOutput( - target=da_tars_merge, - prediction=da_preds_merge, - points_per_sample=points_per_sample, + unique_valid_times = [np.unique(da.valid_time.values) for da in arrays] + if len(unique_valid_times) == len(arrays) and all(len(uvt) == 1 for uvt in unique_valid_times): + _logger.debug( + "All arrays have a single unique valid_time. Skipping splitting by valid_time." ) + return arrays + + for arr_idx, da in tqdm(enumerate(arrays), total=len(arrays), desc="Splitting by valid time"): + vt = da.valid_time.values + sis = da.source_interval_start.values + + # Calculate lead_time once + if vt.ndim > 1: + lead_times = vt - (sis[:, np.newaxis] if sis.ndim == 1 else sis) + # Flatten and get unique lead times with their ipoint indices + valid_mask = ~np.isnat(lead_times) + for i in range(lead_times.shape[0]): + row_leads = lead_times[i][valid_mask[i]] + row_ipoints = np.where(valid_mask[i])[0] + for lead, ipoint in zip(row_leads, row_ipoints, strict=False): + lead_time_groups.setdefault(lead, []).append((arr_idx, i, ipoint)) + else: + lead_times = vt - sis + valid_mask = ~np.isnat(lead_times) + valid_leads = lead_times[valid_mask] + valid_ipoints = np.where(valid_mask)[0] + for lead, ipoint in zip(valid_leads, valid_ipoints, strict=False): + lead_time_groups.setdefault(lead, []).append((arr_idx, 0, ipoint)) + + # Get reference grid from first array for alignment + ref_lat = arrays[0].lat.values + ref_lon = arrays[0].lon.values + ref_sort_idx = np.lexsort((ref_lon, ref_lat)) + ref_lat_sorted = ref_lat[ref_sort_idx] + ref_lon_sorted = ref_lon[ref_sort_idx] + + # Process each lead time + sorted_leads = sorted(lead_time_groups.keys()) + out = [] + + for forecast_step, lead in enumerate(sorted_leads, start=1): + # Group by array index to minimize selections + array_groups = {} + for arr_idx, sample_idx, ipoint in lead_time_groups[lead]: + array_groups.setdefault(arr_idx, {}).setdefault(sample_idx, []).append(ipoint) + + per_sample = [] + for arr_idx, sample_dict in array_groups.items(): + da = arrays[arr_idx] + + for sample_idx, ipoint_list in sample_dict.items(): + # Single selection operation + ipoint_arr = np.array(ipoint_list) + da_subset = da.isel(ipoint=ipoint_arr) + + # Align to reference grid + sort_idx = np.lexsort((da_subset.lon.values, da_subset.lat.values)) + da_subset = da_subset.isel(ipoint=sort_idx).assign_coords( + ipoint=np.arange(len(ipoint_arr)), + lat=("ipoint", ref_lat_sorted[: len(ipoint_arr)]), + lon=("ipoint", ref_lon_sorted[: len(ipoint_arr)]), + ) - def _concat_over_ens(self, da_merge, fsteps_merge): - """ - Parameters - ---------- - da_merge : list[list[xr.DataArray]] - Outer list over readers, inner list over forecast steps. - fsteps_merge : list[list[int]] - Forecast steps per reader (must be identical across readers). - - Returns - ------- - dict[int, xr.DataArray] - DataArrays concatenated over new 'ens' dimension, keyed by fstep. - """ - n_readers = len(da_merge) - - # use fsteps from first reader as reference - fsteps = fsteps_merge[0] - - da_ens = {} - for k, fstep in enumerate(fsteps): - da_list = [da_merge[i][k] for i in range(n_readers)] - da_ens[fstep] = xr.concat(da_list, dim="ens").assign_coords(ens=range(n_readers)) - - return da_ens + # Ensure sample dimension + if "sample" not in da_subset.dims: + sample_val = da.sample.values.item() if da.sample.ndim == 0 else sample_idx + da_subset = da_subset.expand_dims(sample=[sample_val]) - def load_scores( - self, stream: str, regions: list[str], metrics: list[str] - ) -> xr.DataArray | None: - """ - Load the pre-computed scores for a given run, stream and metric and epoch. + per_sample.append(da_subset) - Parameters - ---------- - reader : - Reader object containing all info for a specific run_id - stream : - Stream name. - regions : - Region names. - metrics : - Metric names. - Returns - ------- - xr.DataArray - The metric DataArray. - missing_metrics: - dictionary of missing regions and metrics that need to be recomputed. - """ - local_scores = {} - missing_metrics = {} + if per_sample: + # Single concat operation + combined = xr.concat(per_sample, dim="sample", coords="different", compat="equals") + combined = combined.assign_coords( + ipoint=np.arange(combined.sizes["ipoint"]), forecast_step=forecast_step + ) + out.append(combined) - if isinstance(self.readers[0], WeatherGenZarrReader): - # TODO: implement this properly. Not it is skipping loading scores - for region in regions: - for metric, parameters in metrics.items(): - # all other cases: recompute scores - missing_metrics.setdefault(region, {}).update({metric: parameters}) - else: # JsonReader - # deep merge dicts - for reader in self.readers: - scores, missing = reader.load_scores(stream, regions, metrics) - merge(local_scores, scores) - merge(missing_metrics, missing) - - # merge runs into one with all scores concatenated - for metric in local_scores.keys(): - for region in local_scores[metric].keys(): - for stream in local_scores[metric][region].keys(): - scores = ( - local_scores[metric][region][stream].pop(run_id) - for run_id in self.run_ids - ) - local_scores[metric][region][stream].setdefault( - self.run_id, - xr.concat(scores, dim="ens").assign_coords( - ens=range(len(self.readers)) - ), - ) + return out - return local_scores, missing_metrics - def get_climatology_filename(self, stream: str) -> str | None: - """ - Get the climatology filename for a given stream from the inference configuration. - Parameters - ---------- - stream : - Name of the data stream. - Returns - ------- - Climatology filename if specified, otherwise None. - """ - for reader in self.readers: - clim_data_path = reader.get_climatology_filename(stream) - if clim_data_path: - return clim_data_path - return None +def _add_lead_time_coord(da: xr.DataArray, sample_dim="sample") -> xr.DataArray: + """ + Add lead_time coordinate computed as: + valid_time - source_interval_start - def get_stream(self, stream: str): - """ - returns the dictionary associated to a particular stream. - Returns an empty dictionary if the stream does not exist in the Zarr file. + lead_time has dims (sample, ipoint) and dtype timedelta64[ns]. - Parameters - ---------- - stream: - the stream name + Parameters + ---------- + da : + Input DataArray + sample_dim : + The name of the sample dimension (default is "sample") which should be kept. + Collapse over the others. + Returns + ------- + Returns a DataArray with the lead_time coordinate added. - Returns - ------- - The config dictionary associated to that stream - """ - stream_dict = self.eval_cfg.streams.get(stream, {}) - return stream_dict + NB. Need to be used AFTER splitting by valid_time and stacking by sample, + so that all valid_times within a sample are the same and we can assign a + single lead_time per sample. - def get_samples(self) -> set[int]: - """Get the set of sample indices from the Zarr file.""" - samples = [] - for reader in self.readers: - samples.append(reader.get_samples()) - return set.intersection(*map(set, samples)) + """ + vt = da["valid_time"].values + sis = da["source_interval_start"].values + # Compute lead_time: valid_time - source_interval_start + if vt.ndim > 1: + sis_expanded = sis[:, np.newaxis] if sis.ndim == 1 else sis + lead_time_values = vt - sis_expanded + # Get unique lead_time per sample, verify consistency + lead_times = [ + np.unique(lead_time_values[i][~np.isnat(lead_time_values[i])]) + for i in range(lead_time_values.shape[0]) + ] + if any(len(lt) != 1 for lt in lead_times): + raise ValueError( + "Inconsistent lead_time values within samples for " + f"forecast_step {da.forecast_step.values}" + ) + lead_time_per_sample = np.array([lt[0] for lt in lead_times]) + else: + lead_time_values = vt - sis + lead_time_per_sample = np.unique(lead_time_values[~np.isnat(lead_time_values)]) + + # Verify all samples have same lead_time for this forecast_step + unique_lead = np.unique(lead_time_per_sample) + if len(unique_lead) != 1: + raise ValueError( + "Multiple lead_time values across samples for " + f"forecast_step {da.forecast_step.values}: {unique_lead}" + ) - def get_forecast_steps(self) -> set[int]: - """Get the set of forecast steps from the Zarr file.""" - forecast_steps = [] - for reader in self.readers: - forecast_steps.append(reader.get_forecast_steps()) - return set.intersection(*map(set, forecast_steps)) + da = da.assign_coords(lead_time=unique_lead[0]) + return da - def get_channels(self, stream: str) -> list[str]: - """ - Get the list of channels for a given stream from the config. - Parameters - ---------- - stream : - The name of the stream to get channels for. +def _force_consistent_grids(ref: list[xr.DataArray]) -> xr.DataArray: + """ + Force all samples to share the same ipoint order. - Returns - ------- - A list of channel names. - """ - all_channels = [] + This function aligns the spatial ordering (lat/lon/ipoint) of all samples + to that of the first sample, ensuring consistent spatial coordinates for + subsequent concatenation. It is essential for regular-grid (gridded) data + where spatial order matters but may differ across samples. - for reader in self.readers: - all_channels.append(reader.get_channels(stream)) + Parameters + ---------- + ref: list[xr.DataArray] + List of xarray DataArrays, each representing one sample. Must have at least one element. - return set.intersection(*map(set, all_channels)) + Returns + ------- + xr.DataArray + A concatenated DataArray across the 'sample' dimension, where each sample's ipoint indices + have been reordered to match the sorted lat/lon order of the first sample. + + Notes + ----- + - All input DataArrays must share identical lat/lon values + (though possibly in different orders). + - Enforces consistent ipoint indexing after alignment (0..N-1). + - Preserves and aligns all other coordinates and data variables. + """ + assert len(ref) > 0, "_force_consistent_grids requires at least one input DataArray in 'ref'." - def get_ensemble(self, stream: str | None = None) -> list[str]: - """Get the list of ensemble member names for a given stream from the config. - Parameters - ---------- - stream : - The name of the stream to get channels for. + # Pick first sample as reference + ref_lat = ref[0].lat + ref_lon = ref[0].lon - Returns - ------- - A range of ensemble members equal to the number of merged readers. - """ - _logger.debug(f"Getting ensembles for stream {stream}...") - all_ensembles = [] - for reader in self.readers: - all_ensembles.append(reader.get_ensemble(stream)) + sort_idx = np.lexsort((ref_lon.values, ref_lat.values)) + npoints = sort_idx.size + aligned = [] + samples = [] + for i, a in enumerate(ref): + a_sorted = a.isel(ipoint=sort_idx) + samples.append(a_sorted.sample.values) + a_sorted = a_sorted.assign_coords( + ipoint=np.arange(npoints), + lat=("ipoint", ref_lat.values[sort_idx]), + lon=("ipoint", ref_lon.values[sort_idx]), + ) - if all(e == ["0"] or e == [0] or e == {0} for e in all_ensembles): - return set(range(len(self.readers))) - else: - raise NotImplementedError( - "Merging readers with multiple ensemble members is not supported yet." - ) - return + if "sample" not in a_sorted.dims: + a_sorted = a_sorted.expand_dims(sample=[i]) - # TODO: improve this - def is_regular(self, stream: str) -> bool: - """Check if the latitude and longitude coordinates are regularly spaced for a given stream. - Parameters - ---------- - stream : - The name of the stream to get channels for. + aligned.append(a_sorted) - Returns - ------- - True if the stream is regularly spaced. False otherwise. - """ - _logger.debug(f"Checking regular spacing for stream {stream}...") - return all(reader.is_regular(stream) for reader in self.readers) + return aligned # xr.concat(aligned, dim="sample") diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index f13148f09..2adba17bf 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -841,7 +841,6 @@ def plot( x_dim: str = "lead_time", y_dim: str = "value", print_summary: bool = False, - plot_ensemble: str | bool = False, ) -> None: """ Plot a line graph comparing multiple datasets. @@ -893,7 +892,11 @@ def plot( parts = ["compare", tag] name = "_".join(filter(None, parts)) - self._plot_base(fig, name, x_dim, y_dim, print_summary) + + # TODO: generalise this for other x_dims by instroducing a "units" + # entry in the function if needed + xunits = "hr" if x_dim == "lead_time" else None + self._plot_base(fig, name, x_dim, y_dim, print_summary, xunits=xunits) def _plot_base( self, @@ -905,6 +908,8 @@ def _plot_base( line: float | None = None, vlines: bool = False, title: str | None = None, + xunits: str | None = None, + yunits: str | None = None, ) -> None: """ Apply labels, title, legend, save and optionally print summary. @@ -926,12 +931,22 @@ def _plot_base( If True, draw vertical lines to separate each group of variables. title: Title for the plot. + xunits: + Units for the x-axis. + yunits: + Units for the y-axis. Returns ------- None """ - plt.xlabel("".join(c if c.isalnum() else " " for c in x_dim)) - plt.ylabel("".join(c if c.isalnum() else " " for c in y_dim)) + + plt.xlabel( + "".join(c if c.isalnum() else " " for c in x_dim) + (f" [{xunits}]" if xunits else "") + ) + plt.ylabel( + "".join(c if c.isalnum() else " " for c in y_dim) + (f" [{yunits}]" if yunits else "") + ) + plt.title(title if title is not None else " ".join(c if c.isalnum() else " " for c in name)) plt.legend(frameon=False) diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index 35d02cf96..2b898d302 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -28,9 +28,9 @@ from weathergen.common.logger import init_loggers from weathergen.common.platform_env import get_platform_env from weathergen.evaluate.io.csv_reader import CsvReader +from weathergen.evaluate.io.merge_reader import WeatherGenMergeReader from weathergen.evaluate.io.wegen_reader import ( - WeatherGenJSONReader, - WeatherGenMergeReader, + WeatherGenJsonReader, WeatherGenReader, WeatherGenZarrReader, ) @@ -154,8 +154,6 @@ def evaluate_from_args(argl: list[str], log_queue: mp.Queue) -> None: _logger.info(f"MLFlow client set up: {mlflow_client}") cf = OmegaConf.load(config) - with open_dict(cf): - cf.evaluation.metrics = parse_metric_params(cf.evaluation.metrics) assert isinstance(cf, DictConfig) evaluate_from_config(cf, mlflow_client, log_queue) @@ -173,7 +171,7 @@ def get_reader( elif reader_type == "csv": reader = CsvReader(run, run_id, private_paths) elif reader_type == "json": - reader = WeatherGenJSONReader(run, run_id, private_paths, region, metric) + reader = WeatherGenJsonReader(run, run_id, private_paths, region, metric) elif reader_type == "merge": reader = WeatherGenMergeReader(run, run_id, private_paths) elif reader_type == "jsonmerge": @@ -224,12 +222,12 @@ def _process_stream( plot_score_maps: Bool to define if the score maps need to be plotted or not. """ - type_ = run.get("type", "zarr") reader = get_reader(type_, run, run_id, private_paths, regions, metrics) stream_dict = reader.get_stream(stream) if not stream_dict: + _logger.info(f"No evaluation config for {run_id} - {stream}. Skipping.") return run_id, stream, {} # Parallel plotting @@ -274,6 +272,8 @@ def evaluate_from_config( cfg: Configuration input stored as dictionary. """ + with open_dict(cfg): + cfg.evaluation.metrics = parse_metric_params(cfg.evaluation.metrics) runs = cfg.run_ids _logger.info(f"Detected {len(runs)} runs") private_paths = cfg.get("private_paths") diff --git a/packages/evaluate/src/weathergen/evaluate/utils/regions.py b/packages/evaluate/src/weathergen/evaluate/utils/regions.py index 1be9c5ea3..b2893c314 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/regions.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/regions.py @@ -100,7 +100,6 @@ def apply_mask( & (lon >= self.lon_min) & (lon <= self.lon_max) ) - return data.sel({data_dim: mask}) @classmethod diff --git a/packages/evaluate/src/weathergen/evaluate/utils/utils.py b/packages/evaluate/src/weathergen/evaluate/utils/utils.py index fd685f316..179e98fbd 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/utils.py @@ -107,8 +107,8 @@ def calc_scores_per_stream( samples = available_data.samples channels = available_data.channels ensemble = available_data.ensemble - is_regular = reader.is_regular(stream) - group_by_coord = None if is_regular else "sample" + is_gridded_data = reader.is_gridded_data(stream) + group_by_coord = None if is_gridded_data else "sample" output_data = reader.get_data( stream, @@ -116,11 +116,11 @@ def calc_scores_per_stream( samples=samples, channels=channels, ensemble=ensemble, - return_counts=True, ) + da_preds = output_data.prediction da_tars = output_data.target - + fsteps = sorted(list(da_preds.keys())) aligned_clim_data = get_climatology(reader, da_tars, stream) for region in regions: @@ -129,7 +129,7 @@ def calc_scores_per_stream( _logger.info( f"RUN {reader.run_id} - {stream}: Calculating scores for region {region}" - f" and metrics {metrics}..." + f" and metrics {list(metrics.keys())}..." ) metric_stream = xr.DataArray( np.full( @@ -145,13 +145,21 @@ def calc_scores_per_stream( }, ) - lead_time_map = {} + if "lead_time" in da_preds[fsteps[0]].coords: + metric_stream = metric_stream.assign_coords( + lead_time=("forecast_step", np.full(len(fsteps), -1, dtype=int)) + ) + # Store metric-specific attributes that get lost during concat # Key: (fstep, metric) -> attrs dict all_metric_attrs = {} - for (fstep, tars), (_, preds) in zip(da_tars.items(), da_preds.items(), strict=False): - if preds.ipoint.size == 0: + for (fstep, tars), (_, preds) in tqdm( + zip(da_tars.items(), da_preds.items(), strict=False), + total=len(da_tars), + desc=f"Computing scores for {reader.run_id} - stream {stream} and region {region}", + ): + if preds.sizes.get("ipoint") == 0: _logger.warning( f"No data for stream {stream} at fstep {fstep} in region {region}. Skipping." ) @@ -216,7 +224,6 @@ def calc_scores_per_stream( combined_metrics = combined_metrics.assign_coords(metric=valid_metric_names) combined_metrics = combined_metrics.compute() - for coord in ["channel", "sample", "ens"]: combined_metrics = scalar_coord_to_dim(combined_metrics, coord) @@ -237,56 +244,48 @@ def calc_scores_per_stream( # Skip coordinates that are already dimensions (no need to restore) if coord_name in combined_metrics.dims or coord_name in metric_stream.dims: continue - - # Only restore coordinates whose dimensions exist in metric_stream - # (e.g., skip coords with 'quantile' dim if metric_stream doesn't have it) - coord_dims = combined_metrics.coords[coord_name].dims - if not all(dim in metric_stream.dims for dim in coord_dims): - _logger.debug( - f"Skipping coordinate '{coord_name}' with incompatible " - f"dimensions {coord_dims} (metric_stream has {metric_stream.dims})" + if coord_name == "lead_time": + metric_stream.coords["lead_time"].loc[{"forecast_step": int(fstep)}] = ( + combined_metrics.coords["lead_time"] + .values.astype("timedelta64[h]") + .astype(int) ) - continue - - # Initialize coordinate in metric_stream if it doesn't exist yet - if coord_name not in metric_stream.coords: - coord_shape = tuple(len(metric_stream.coords[dim]) for dim in coord_dims) - metric_stream = metric_stream.assign_coords( - { - coord_name: xr.DataArray( - np.full(coord_shape, "", dtype=object), - dims=coord_dims, - coords={dim: metric_stream.coords[dim] for dim in coord_dims}, - ) - } - ) - - # Build indexers to select the right location in metric_stream - indexers = {dim: criteria[dim] for dim in coord_dims if dim in criteria} - metric_stream.coords[coord_name].loc[indexers] = combined_metrics.coords[coord_name] + else: + # Only restore coordinates whose dimensions exist in metric_stream + # (e.g., skip coords with 'quantile' dim if metric_stream doesn't have it) + coord_dims = combined_metrics.coords[coord_name].dims + if not all(dim in metric_stream.dims for dim in coord_dims): + _logger.debug( + f"Skipping coordinate '{coord_name}' with incompatible " + f"dimensions {coord_dims} (metric_stream has {metric_stream.dims})" + ) + continue + + # Initialize coordinate in metric_stream if it doesn't exist yet + if coord_name not in metric_stream.coords: + coord_shape = tuple(len(metric_stream.coords[dim]) for dim in coord_dims) + metric_stream = metric_stream.assign_coords( + { + coord_name: xr.DataArray( + np.full(coord_shape, "", dtype=object), + dims=coord_dims, + coords={dim: metric_stream.coords[dim] for dim in coord_dims}, + ) + } + ) - lead_time_map[fstep] = ( - np.unique(combined_metrics.lead_time.values.astype("timedelta64[h]")) - if "lead_time" in combined_metrics.coords - else None - ) + # Build indexers to select the right location in metric_stream + indexers = {dim: criteria[dim] for dim in coord_dims if dim in criteria} + metric_stream.coords[coord_name].loc[indexers] = combined_metrics.coords[ + coord_name + ] - if is_regular and plot_score_maps: + if is_gridded_data and plot_score_maps: _logger.info(f"Plotting scores on a map {stream} - forecast step: {fstep}...") _plot_score_maps_per_stream( reader, map_dir, stream, region, score_data, metrics, fstep ) - if all(lead_time_map[f] is not None for f in lead_time_map): - lead_time_values = np.array( - [lead_time_map[f].astype(int) for f in metric_stream.forecast_step.values] - ).squeeze() - - if lead_time_values.shape == metric_stream.forecast_step.shape: - metric_stream = metric_stream.assign_coords( - lead_time=("forecast_step", lead_time_values) - ) - _logger.info(f"Scores for run {reader.run_id} - {stream} calculated successfully.") _logger.debug(f"all_metric_attrs keys: {list(all_metric_attrs.keys())}") @@ -442,6 +441,10 @@ def plot_data(reader: Reader, stream: str, global_plotting_opts: dict) -> None: if not isinstance(plot_maps, bool): raise TypeError("plot_maps must be a boolean.") + plot_bias = plot_settings.get("plot_bias", True) + if not isinstance(plot_bias, bool): + raise TypeError("plot_bias must be a boolean.") + plot_target = plot_settings.get("plot_target", True) if not isinstance(plot_target, bool): raise TypeError("plot_target must be a boolean.") @@ -476,6 +479,9 @@ def plot_data(reader: Reader, stream: str, global_plotting_opts: dict) -> None: maps_config = common_ranges( da_tars, da_preds, available_data.channels, global_plotting_opts[stream] ) + bias_config = bias_ranges( + da_tars, da_preds, available_data.channels, global_plotting_opts[stream] + ) for (fstep, tars), (_, preds) in zip(da_tars.items(), da_preds.items(), strict=False): plot_chs = list(np.atleast_1d(tars.channel.values)) @@ -493,6 +499,12 @@ def plot_data(reader: Reader, stream: str, global_plotting_opts: dict) -> None: plotter.create_maps_per_sample( tars, plot_chs, data_selection, "targets", maps_config ) + + if plot_bias: + plotter.create_maps_per_sample( + preds - tars, plot_chs, data_selection, "bias", bias_config + ) + for ens in available_data.ensemble: preds_ens = ( preds.sel(ens=ens) if "ens" in preds.dims and ens != "mean" else preds @@ -520,7 +532,8 @@ def plot_data(reader: Reader, stream: str, global_plotting_opts: dict) -> None: plotter.animation(plot_samples, plot_fsteps, plot_chs, data_selection, preds_name) if plot_target: plotter.animation(plot_samples, plot_fsteps, plot_chs, data_selection, "targets") - + if plot_bias: + plotter.animation(plot_samples, plot_fsteps, plot_chs, data_selection, "bias") return @@ -653,7 +666,7 @@ def common_ranges( data_tars: list[dict], data_preds: list[dict], plot_chs: list[str], - maps_config: oc.dictconfig.DictConfig, + global_plotting_opts_stream: oc.dictconfig.DictConfig, ) -> oc.dictconfig.DictConfig: """ Calculate common ranges per stream and variables. @@ -666,14 +679,15 @@ def common_ranges( the (prediction) list of dictionaries with the forecasteps and respective xarray plot_chs: the variables to be plotted as given by the configuration file - maps_config: - the global plotting configuration + global_plotting_opts_stream: + the global plotting configuration for the stream as given by the configuration file, which + may or may not include predefined ranges for some variables. Returns ------- maps_config : - the global plotting configuration with the ranges added and included for each variable (and - for each stream). + the global plotting configuration with the ranges added and included for each variable. """ + maps_config = global_plotting_opts_stream.copy() for var in plot_chs: if var in maps_config: if not isinstance(maps_config[var].get("vmax"), (int | float)): @@ -697,6 +711,44 @@ def common_ranges( return maps_config +def bias_ranges( + data_tars: dict, + data_preds: dict, + plot_chs: list[str], + global_plotting_opts_stream: oc.dictconfig.DictConfig, +) -> oc.dictconfig.DictConfig: + """ + Calculate symmetric bias ranges (preds - tars) per variable. + + Parameters + ---------- + data_tars : + Dictionary mapping forecast steps to target xarray DataArrays. + data_preds : + Dictionary mapping forecast steps to prediction xarray DataArrays. + plot_chs : + List of variable (channel) names to compute bias ranges for. + global_plotting_opts_stream : + The global plotting configuration for the stream, used as the base config. + + Returns + ------- + oc.dictconfig.DictConfig + Per-variable symmetric ranges (vmin = -abs_max, vmax = abs_max) for bias. + """ + bias_config = global_plotting_opts_stream.copy() + for var in plot_chs: + bias_vals = [ + (p - t).sel(channel=var).values + for t, p in zip(data_tars.values(), data_preds.values(), strict=False) + ] + abs_max = float( + max(abs(np.concatenate(bias_vals).max()), abs(np.concatenate(bias_vals).min())) + ) + bias_config.update({var: {"vmax": abs_max, "vmin": -abs_max}}) + return bias_config + + def calc_val(x: xr.DataArray, bound: str) -> list[float]: """ Calculate the maximum or minimum value per variable for all forecasteps. diff --git a/packages/readers_extra/src/weathergen/readers_extra/data_reader_iconart.py b/packages/readers_extra/src/weathergen/readers_extra/data_reader_iconart.py index 748a3499d..39d37ec92 100644 --- a/packages/readers_extra/src/weathergen/readers_extra/data_reader_iconart.py +++ b/packages/readers_extra/src/weathergen/readers_extra/data_reader_iconart.py @@ -1,3 +1,5 @@ +# pylint: disable=bad-builtin + # (C) Copyright 2025 WeatherGenerator contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 diff --git a/packages/readers_extra/src/weathergen/readers_extra/data_reader_mesh.py b/packages/readers_extra/src/weathergen/readers_extra/data_reader_mesh.py new file mode 100644 index 000000000..f0b541bfa --- /dev/null +++ b/packages/readers_extra/src/weathergen/readers_extra/data_reader_mesh.py @@ -0,0 +1,542 @@ +# pylint: disable=bad-builtin + +import json +import logging +from pathlib import Path +from typing import override + +import dask +import dask.array as da +import fsspec +import numpy as np +import xarray as xr +from numpy.typing import NDArray + +from weathergen.datasets.data_reader_base import ( + DataReaderTimestep, + DTRange, + ReaderData, + TimeWindowHandler, + TIndex, +) + +_logger = logging.getLogger(__name__) + +# Small epsilon to handle time boundary exclusivity +t_epsilon = np.timedelta64(1, "ms") +MIN_PATCH_POINTS = 1024 + + +class DataReaderMesh(DataReaderTimestep): + """ + A data reader for unstructured mesh data accessed via Virtual Zarr. + Features: + - Separate Source and Target files. + - Persistence of State time indexing (forward fill). + - Robust Multi-Node/Worker support (Fork-safe, Dask-safe). + - Dynamic Patching (local) OR Global Sparse Sampling. + """ + + def __init__( + self, + tw_handler: TimeWindowHandler, + filename: Path, + stream_info: dict, + ) -> None: + self.filename_source = Path(filename) + # Check for separate target file + if "target_file" in stream_info: + self.filename_target = Path(stream_info["target_file"]) + else: + self.filename_target = self.filename_source + + self._stream_info = stream_info + self.roi = stream_info.get("roi") + self.patch_size_deg = stream_info.get("patch_size_deg") + self.sample_points = stream_info.get("sample_points") + + self._dask_arrays = {} + + # 'patch' = contiguous geographic square + # 'global_sparse' = random points scattered over the whole available area + self.sampling_mode = stream_info.get("sampling_mode", "patch") + + # --- WARNING FOR MIXED FILES + GLOBAL SAMPLING --- + if self.sampling_mode == "global_sparse" and self.filename_source != self.filename_target: + _logger.warning( + f"[Stream {stream_info.get('name')}] GLOBAL SPARSE SAMPLING" + "enabled with DIFFERENT Source and Target files!" + ) + _logger.warning( + " -> This assumes perfect row-by-row index alignment between the two meshes." + ) + _logger.warning( + " -> If the meshes have different node orderings," + " this will produce SILENT DATA CORRUPTION." + ) + + self._initialized = False + self.ds_source = None + self.ds_target = None + self.mapper_src = None + self.mapper_trg = None + + if not self.filename_source.exists(): + _logger.warning(f"Source file {self.filename_source} not found. Stream skipped.") + self.init_empty() + super().__init__(tw_handler, stream_info, None, None, None) + return + + # --- PROBE METADATA (Source & Target) --- + self.col_map = {} + self.stats_means = {} + self.stats_vars = {} + + # 1. Probe Source + meta_src = self._probe_file(self.filename_source, is_source=True) + if not meta_src: + return + + # 2. Probe Target (if different) + if self.filename_target != self.filename_source: + meta_trg = self._probe_file(self.filename_target, is_source=False) + if not meta_trg: + return + self.col_map.update(meta_trg["col_map"]) + self.stats_means.update(meta_trg["means"]) + self.stats_vars.update(meta_trg["vars"]) + + # Unpack Source Metadata + ds_time_values = meta_src["time"] + self._len_cached = len(ds_time_values) + self._time_values_cached = ds_time_values + + data_start_time = np.datetime64(ds_time_values[0], "ns") + # Calc period from first two steps if possible, else default to something safe + if len(ds_time_values) > 1: + period = np.datetime64(ds_time_values[1], "ns") - data_start_time + else: + # Fallback for single-step datasets + period = np.timedelta64(24, "h") + + data_end_time = np.datetime64(ds_time_values[-1], "ns") + + self.lats = meta_src["lats"] + self.lons = meta_src["lons"] + self.spatial_indices = meta_src["indices"] + self.coords = meta_src["coords"] + + # Parse ROI from config for consistency + if self.roi: + self.roi_min_lon, self.roi_min_lat, self.roi_max_lon, self.roi_max_lat = self.roi + else: + self.roi_min_lon, self.roi_min_lat, self.roi_max_lon, self.roi_max_lat = ( + -180.0, + -90.0, + 180.0, + 90.0, + ) + + self.available_channels = list(self.col_map.keys()) + + super().__init__(tw_handler, stream_info, data_start_time, data_end_time, period) + + self.source_idx = self._select_channels("source") + self.target_idx = self._select_channels("target") + self.geoinfo_idx = [] + self.geoinfo_channels = [] + + self.source_channels = [self.available_channels[i] for i in self.source_idx] + self.target_channels = [self.available_channels[i] for i in self.target_idx] + + self._init_stats_arrays() + + def _probe_file(self, filepath, is_source=True): + """Helper to open a file, extract meta, and close it immediately.""" + mapper = fsspec.get_mapper("reference://", fo=str(filepath), remote_protocol="file") + try: + with xr.open_dataset(mapper, engine="zarr", chunks={}, consolidated=False) as ds: + if "time" not in ds.coords: + all_vars = list(ds.coords) + list(ds.data_vars) + time_candidates = [v for v in all_vars if "time" in v.lower()] + if time_candidates: + target = time_candidates[0] + if target in ds.data_vars: + ds = ds.set_coords(target) + if target != "time": + ds = ds.rename({target: "time"}) + if "time" in ds.dims and "time" not in ds.indexes: + ds = ds.assign_coords(time=ds["time"].values) + + if "time" not in ds.coords: + _logger.error(f"No time coordinate in {filepath}.") + if is_source: + self.init_empty() + super().__init__( + self._stream_info.get("tw_handler"), self._stream_info, None, None, None + ) + return None + + meta = { + "time": ds.time.values, + "col_map": self._parse_attr(ds.attrs, "weathergen_col_map"), + "means": self._parse_attr(ds.attrs, "weathergen_means"), + "vars": self._parse_attr(ds.attrs, "weathergen_vars"), + } + + if is_source: + self.col_map.update(meta["col_map"]) + self.stats_means.update(meta["means"]) + self.stats_vars.update(meta["vars"]) + + lats = ( + ds["lat"].values.astype(np.float32) + if "lat" in ds + else ds["lat_c"].values.astype(np.float32) + ) + lons = ( + ds["lon"].values.astype(np.float32) + if "lon" in ds + else ds["lon_c"].values.astype(np.float32) + ) + + lats = np.nan_to_num(lats, nan=0.0) + lons = np.nan_to_num(lons, nan=0.0) + if np.any(lats > 90.0): + lats = lats - 90.0 + lats = np.clip(lats, -90.0, 90.0) + lons = ((lons + 180.0) % 360.0) - 180.0 + + if self.roi: + min_lon, min_lat, max_lon, max_lat = self.roi + if min_lon > max_lon: + mask = (lons >= min_lon) | (lons <= max_lon) + else: + mask = (lons >= min_lon) & (lons <= max_lon) + mask &= (lats >= min_lat) & (lats <= max_lat) + spatial_indices = np.where(mask)[0] + lats = lats[spatial_indices] + lons = lons[spatial_indices] + else: + spatial_indices = np.arange(len(lats)) + + meta["lats"] = lats + meta["lons"] = lons + meta["indices"] = spatial_indices + meta["coords"] = np.stack([lats, lons], axis=1) + + return meta + except Exception as e: + _logger.error(f"Failed to probe {filepath}: {e}") + return None + + def _lazy_init(self): + if self._initialized: + return + + self.mapper_src = fsspec.get_mapper( + "reference://", fo=str(self.filename_source), remote_protocol="file" + ) + import warnings + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message=".*separate the stored chunks.*") + self.ds_source = xr.open_dataset( + self.mapper_src, engine="zarr", chunks={}, decode_times=True, consolidated=False + ) + + if self.filename_target != self.filename_source: + self.mapper_trg = fsspec.get_mapper( + "reference://", fo=str(self.filename_target), remote_protocol="file" + ) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message=".*separate the stored chunks.*") + self.ds_target = xr.open_dataset( + self.mapper_trg, engine="zarr", chunks={}, decode_times=True, consolidated=False + ) + else: + self.ds_target = self.ds_source + + self._dask_arrays = {} + for ch in self.source_channels: + var = self.col_map[ch]["var"] + if var in self.ds_source: + self._dask_arrays[ch] = self.ds_source[var].data + + for ch in self.target_channels: + var = self.col_map[ch]["var"] + if var in self.ds_target: + self._dask_arrays[ch] = self.ds_target[var].data + + self._initialized = True + + def _get_persistent_time_idxs(self, idx: TIndex) -> tuple[NDArray, DTRange]: + dtr = self.time_window_handler.window(idx) + if dtr.end < self.data_start_time or dtr.start > self.data_end_time: + return (np.array([], dtype=np.int64), dtr) + + delta_start = dtr.start - self.data_start_time + start_idx = int(delta_start / self.period) + + delta_end = dtr.end - self.data_start_time - t_epsilon + end_idx = int(delta_end / self.period) + + start_idx = max(0, start_idx) + end_idx = min(len(self._time_values_cached) - 1, end_idx) + + return (np.arange(start_idx, end_idx + 1, dtype=np.int64), dtr) + + @override + def get_source(self, idx: TIndex) -> ReaderData: + return self._fetch_data(idx, self.source_channels, is_source=True) + + @override + def get_target(self, idx: TIndex) -> ReaderData: + return self._fetch_data(idx, self.target_channels, is_source=False) + + def _fetch_data(self, idx: TIndex, channels: list[str], is_source: bool) -> ReaderData: + self._lazy_init() + (t_idxs, dtr) = self._get_persistent_time_idxs(idx) + + if len(t_idxs) == 0 or not channels: + return ReaderData.empty(len(channels), 0) + + channel_indices = [self.available_channels.index(c) for c in channels] + start_t, end_t = t_idxs[0], t_idxs[-1] + 1 + n_steps = len(t_idxs) + + # Setup RNG + local_seed = int(idx) + 12345 + patch_rng = np.random.default_rng(local_seed) + + # --- STRATEGY SELECTION --- + if self.sampling_mode == "global_sparse": + # --- GLOBAL SPARSE SAMPLING --- + total_points = len(self.spatial_indices) + target_n = self.sample_points if self.sample_points else 4096 + + # Simple random choice from the full available set (defined by ROI in init) + # This is deterministic because patch_rng is seeded with idx + indices_local = patch_rng.choice( + total_points, size=min(target_n, total_points), replace=False + ) + + patch_coords_base = self.coords[indices_local] + final_disk_indices = self.spatial_indices[indices_local] + + # Note: For scattered points, we force fancy indexing by passing rel_indices=None + # to _load_block to avoid reading the whole file array. + use_contiguous_read = False + + elif self.patch_size_deg: + # --- PATCH SAMPLING --- + lat_range = max(0.0, (self.roi_max_lat - self.roi_min_lat) - self.patch_size_deg) + lon_range = max(0.0, (self.roi_max_lon - self.roi_min_lon) - self.patch_size_deg) + + patch_indices_local = np.array([]) + attempts = 0 + + while len(patch_indices_local) < MIN_PATCH_POINTS and attempts < 100: + lat_0 = self.roi_min_lat + patch_rng.random() * lat_range + lon_0 = self.roi_min_lon + patch_rng.random() * lon_range + + mask = ( + (self.lats >= lat_0) + & (self.lats < lat_0 + self.patch_size_deg) + & (self.lons >= lon_0) + & (self.lons < lon_0 + self.patch_size_deg) + ) + patch_indices_local = np.where(mask)[0] + attempts += 1 + + if len(patch_indices_local) < MIN_PATCH_POINTS: + # Fallback to random points if patch is too sparse + req_points = min(MIN_PATCH_POINTS, len(self.lats)) + patch_indices_local = patch_rng.choice( + len(self.lats), size=req_points, replace=False + ) + + patch_coords_base = self.coords[patch_indices_local] + final_disk_indices = self.spatial_indices[patch_indices_local] + use_contiguous_read = True + + else: + # --- FULL ROI / FULL GRID --- + final_disk_indices = self.spatial_indices + patch_coords_base = self.coords + use_contiguous_read = True + + # Load Data + ds_ref = self.ds_source if is_source else self.ds_target + + if use_contiguous_read: + # Optimized Contiguous Read + disk_start, disk_stop = np.min(final_disk_indices), np.max(final_disk_indices) + 1 + rel_indices = final_disk_indices - disk_start + data_block = self._load_block_from_ds( + ds_ref, + channel_indices, + start_t, + end_t, + n_steps, + slice(disk_start, disk_stop), + rel_indices, + ) + else: + # Scattered Read (Global Sparse) -> Pass raw indices, rel_indices=None + data_block = self._load_block_from_ds( + ds_ref, channel_indices, start_t, end_t, n_steps, final_disk_indices, None + ) + + if data_block.size > 0: + d_max = np.nanmax(np.abs(data_block)) + if d_max > 1e10: + data_block[np.abs(data_block) > 1e10] = np.nan + + coords_flat = np.tile(patch_coords_base, (n_steps, 1)) + dt_values = self._time_values_cached[start_t:end_t] + dt_flat = np.repeat(dt_values, patch_coords_base.shape[0]) + + rdata = ReaderData( + coords=coords_flat, + geoinfos=np.zeros((len(data_block), 0), dtype=np.float32), + data=data_block, + datetimes=dt_flat, + ) + return rdata + + def _load_block_from_ds(self, ds, indices, start_t, end_t, n_steps, disk_indices, rel_indices): + """ + Loads data using either contiguous slicing (fastest for patches) + or fancy indexing (memory efficient for sparse global). + """ + # Calculate output size + if rel_indices is not None: + num_points = len(rel_indices) + else: + num_points = len(disk_indices) # disk_indices is the list of points + + if not indices: + return np.zeros((n_steps * num_points, 0), dtype=np.float32) + + output_block = np.zeros((n_steps * num_points, len(indices)), dtype=np.float32) + + with dask.config.set(scheduler="single-threaded"): + for i, idx in enumerate(indices): + ch_name = self.available_channels[idx] + if ch_name not in self._dask_arrays: + info = self.col_map[ch_name] + if info["var"] in ds: + self._dask_arrays[ch_name] = ds[info["var"]].data + else: + continue + + info = self.col_map[ch_name] + base_arr = self._dask_arrays[ch_name] + dims = ds[info["var"]].dims + + sliced = base_arr + if info["sel"]: + sls = [slice(None)] * sliced.ndim + for d, val in info["sel"].items(): + if d in dims: + sls[dims.index(d)] = val + sliced = sliced[tuple(sls)] + + # --- STRATEGY SELECTION --- + if rel_indices is not None: + # STRATEGY A: Contiguous Read + Memory Filter (Best for Patches) + if "time" in dims: + sliced = sliced[start_t:end_t, disk_indices] + else: + sliced = da.repeat(da.expand_dims(sliced[disk_indices], 0), n_steps, axis=0) + + # Dask computes the slice, then we filter in RAM + chunk = sliced.compute().astype(np.float32) + chunk = chunk[:, rel_indices] + else: + # STRATEGY B: Fancy Indexing (Best for Global Sparse) + # disk_indices is a list of integers here + if "time" in dims: + # Slice time range, then select specific points + sliced = sliced[start_t:end_t] + sliced = sliced[:, disk_indices] + else: + sliced = sliced[disk_indices] + sliced = da.repeat(da.expand_dims(sliced, 0), n_steps, axis=0) + + chunk = sliced.compute().astype(np.float32) + + chunk[~np.isfinite(chunk)] = np.nan + output_block[:, i] = chunk.reshape(-1) + + return output_block + + @override + def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: + raise NotImplementedError( + "DataReaderMesh._get should not be called directly. Use get_source or get_target." + ) + + @override + def init_empty(self) -> None: + super().init_empty() + self._len_cached = 0 + + @override + def length(self) -> int: + return getattr(self, "_len_cached", 0) + + def _parse_attr(self, attrs, key): + val = attrs.get(key, {}) + return json.loads(val) if isinstance(val, str) else val + + def _select_channels(self, type_key: str) -> list[int]: + select = self._stream_info.get(type_key) + exclude = self._stream_info.get(f"{type_key}_exclude", []) + return [ + i + for i, ch in enumerate(self.available_channels) + if (not select or any(s in ch for s in select)) and not any(e in ch for e in exclude) + ] + + def _init_stats_arrays(self): + self.mean = np.zeros(len(self.available_channels), dtype=np.float32) + self.stdev = np.ones(len(self.available_channels), dtype=np.float32) + for i, ch in enumerate(self.available_channels): + mu = self.stats_means.get(ch, 0.0) + var = self.stats_vars.get(ch, 1.0) + if mu is None or np.isnan(mu) or np.isinf(mu): + mu = 0.0 + if var is None or np.isnan(var) or np.isinf(var) or var < 1e-7: + var = 1.0 + self.mean[i] = mu + self.stdev[i] = np.sqrt(var) + self.mean_geoinfo = np.zeros(0, dtype=np.float32) + self.stdev_geoinfo = np.ones(0, dtype=np.float32) + + @override + def normalize_source_channels(self, source: np.typing.NDArray) -> np.typing.NDArray: + norm = (source.astype(np.float64) - self.mean[self.source_idx]) / self.stdev[ + self.source_idx + ] + return np.nan_to_num(norm, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32) + + @override + def normalize_target_channels(self, target: np.typing.NDArray) -> np.typing.NDArray: + norm = (target.astype(np.float64) - self.mean[self.target_idx]) / self.stdev[ + self.target_idx + ] + return np.nan_to_num(norm, nan=np.nan, posinf=np.nan, neginf=np.nan).astype(np.float32) + + @override + def denormalize_source_channels(self, source: np.typing.NDArray) -> np.typing.NDArray: + return (source * self.stdev[self.source_idx]) + self.mean[self.source_idx] + + @override + def denormalize_target_channels(self, data: np.typing.NDArray) -> np.typing.NDArray: + return (data * self.stdev[self.target_idx]) + self.mean[self.target_idx] + + @override + def normalize_geoinfos(self, geoinfos: np.typing.NDArray) -> np.typing.NDArray: + return geoinfos diff --git a/packages/readers_extra/src/weathergen/readers_extra/registry.py b/packages/readers_extra/src/weathergen/readers_extra/registry.py index 39953a25e..7e0b6f3d8 100644 --- a/packages/readers_extra/src/weathergen/readers_extra/registry.py +++ b/packages/readers_extra/src/weathergen/readers_extra/registry.py @@ -20,5 +20,9 @@ def get_extra_reader(stream_type: str) -> object | None: from weathergen.readers_extra.data_reader_cams import DataReaderCams return DataReaderCams + case "mesh": + from weathergen.readers_extra.data_reader_mesh import DataReaderMesh + + return DataReaderMesh case _: return None diff --git a/packages/science/compute_spatial_autocorrelation.py b/packages/science/compute_spatial_autocorrelation.py new file mode 100644 index 000000000..5f7ea11e1 --- /dev/null +++ b/packages/science/compute_spatial_autocorrelation.py @@ -0,0 +1,1142 @@ +#!/usr/bin/env python3 +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +"""Compute spatial autocorrelation per variable and suggest per-stream masking configs. + +This script analyses a dataset to determine the spatial correlation length of +each variable, maps that to an appropriate HEALPix masking level (``hl_mask``), +and groups variables by similar correlation scale. The output is a summary +table plus YAML snippets ready to be used as ``masking_override`` blocks in +stream config files. + +Example usage: + + uv run python packages/science/compute_spatial_autocorrelation.py \\ + --dataset /path/to/data.zarr \\ + --type anemoi \\ or obs, or less supported options below + --channels z_500 z_850 t_500 t_850 q_700 tp \\ defaults to all vars + --n-time-samples 100 \\ + --n-sample-pairs 100000 \\ + --correlation-multiplier 0.5 \\ + + then see further optional args below for controlling the output. +""" + +from __future__ import annotations + +import argparse +import logging +from dataclasses import dataclass, field +from pathlib import Path + +import numpy as np +from numpy.typing import NDArray + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + +EARTH_RADIUS_KM = 6371.0 + + +@dataclass +class DatasetInfo: + """Minimal container for data loaded from a dataset.""" + + lats: NDArray[np.float64] | None # [n_points] radians + lons: NDArray[np.float64] | None # [n_points] radians + data: dict[str, NDArray] # var_name -> [n_times, n_points] + period_hours: float | None = None + lats_ragged: list[NDArray[np.float64]] | None = None # per-time [n_points_t] + lons_ragged: list[NDArray[np.float64]] | None = None # per-time [n_points_t] + data_ragged: dict[str, list[NDArray]] | None = None # var_name -> list[n_points_t] + + +@dataclass +class VarResult: + """Autocorrelation analysis result for a single variable.""" + + name: str + l_corr_km: float + hl_mask: int + bin_centers_km: NDArray = field(repr=False) + bin_correlations: NDArray = field(repr=False) + + +# --------------------------------------------------------------------------- +# Data loaders +# --------------------------------------------------------------------------- + + +def load_anemoi( + path: str | Path, + n_time_samples: int, + channels: list[str] | None, + seed: int, +) -> DatasetInfo: + """Load data from an anemoi-format zarr dataset.""" + import anemoi.datasets as anemoi_datasets + + ds = anemoi_datasets.open_dataset(path) + rng = np.random.default_rng(seed) + + all_vars = list(ds.variables) + if channels is None: + channels = all_vars + + # Map channel names to indices + var_indices = {} + for ch in channels: + if ch not in all_vars: + logger.warning(f"Channel '{ch}' not found in dataset, skipping. Available: {all_vars}") + continue + var_indices[ch] = all_vars.index(ch) + + if not var_indices: + raise ValueError(f"No valid channels found. Available: {all_vars}") + + n_times_total = len(ds) + n_samples = min(n_time_samples, n_times_total) + time_indices = rng.choice(n_times_total, size=n_samples, replace=False) + time_indices.sort() + + lats = np.deg2rad(ds.latitudes) + lons = np.deg2rad(ds.longitudes) + + data = {} + for ch, idx in var_indices.items(): + values = [] + for ti in time_indices: + row = ds[int(ti)] # common shapes: [n_vars, n_points], [n_vars, 1, n_points] + row = np.asarray(row) + if row.ndim == 1: + values.append(row) + continue + + var_axis = None + for axis, dim in enumerate(row.shape): + if dim == len(all_vars): + var_axis = axis + break + + if var_axis is None: + raise ValueError( + "Could not locate variable axis in anemoi sample with shape " + f"{row.shape} (n_vars={len(all_vars)})" + ) + + var_slice = np.take(row, idx, axis=var_axis) + var_slice = np.squeeze(var_slice) + if var_slice.ndim > 1: + var_slice = var_slice.reshape(-1) + values.append(var_slice) + + data[ch] = np.stack(values, axis=0) + + period = None + if hasattr(ds, "frequency"): + freq = np.timedelta64(ds.frequency) + period = freq / np.timedelta64(1, "h") + + return DatasetInfo(lats=lats, lons=lons, data=data, period_hours=period) + + +def load_zarr_columnar( + path: str | Path, + lat_col: str, + lon_col: str, + data_cols: list[str] | None, + n_time_samples: int, + seed: int, + max_points_per_time: int | None = 50_000, +) -> DatasetInfo: + """Load from a zarr store with named lat/lon columns.""" + import zarr + + store = zarr.open(path, mode="r") + rng = np.random.default_rng(seed) + + if lat_col in store and lon_col in store: + lats = np.deg2rad(np.asarray(store[lat_col])) + lons = np.deg2rad(np.asarray(store[lon_col])) + + if data_cols is None: + skip = {lat_col, lon_col, "time", "datetime", "date"} + data_cols = [k for k in store.keys() if k not in skip] + + # Determine time dimension + first_arr = np.asarray(store[data_cols[0]]) + if first_arr.ndim == 1: + # No time dimension + data = {col: np.asarray(store[col])[np.newaxis, :] for col in data_cols} + else: + n_times_total = first_arr.shape[0] + n_samples = min(n_time_samples, n_times_total) + time_indices = rng.choice(n_times_total, size=n_samples, replace=False) + time_indices.sort() + data = {col: np.asarray(store[col])[time_indices] for col in data_cols} + + return DatasetInfo(lats=lats, lons=lons, data=data) + + # Observation-style zarr with a single data table and column metadata + if "data" not in store: + raise ValueError("Zarr store does not contain lat/lon arrays or a 'data' table.") + + data_arr = store["data"] + colnames = list(data_arr.attrs.get("colnames", [])) + if not colnames: + raise ValueError("Zarr 'data' array missing 'colnames' metadata.") + + if lat_col in colnames: + lat_idx = colnames.index(lat_col) + else: + lat_idx = int(data_arr.attrs.get("lat_idx", [None])[0]) + if lon_col in colnames: + lon_idx = colnames.index(lon_col) + else: + lon_idx = int(data_arr.attrs.get("lon_idx", [None])[0]) + + if lat_idx is None or lon_idx is None: + raise ValueError("Could not determine lat/lon column indices for observation table.") + + if data_cols is None: + data_idxs = data_arr.attrs.get("data_idxs") + if data_idxs is None: + skip = {lat_col, lon_col, "time", "datetime", "date"} + data_idxs = [i for i, name in enumerate(colnames) if name not in skip] + data_cols = [colnames[i] for i in data_idxs] + data_indices = list(data_idxs) + else: + missing = [c for c in data_cols if c not in colnames] + if missing: + raise ValueError(f"Requested columns not found in data table: {missing}") + data_indices = [colnames.index(c) for c in data_cols] + + idx_key = next((k for k in store.keys() if k.startswith("idx_")), None) + if idx_key is None: + raise ValueError("Observation zarr missing time index array (idx_*).") + + idx = np.asarray(store[idx_key]) + n_rows = data_arr.shape[0] + n_times_total = len(idx) + end_idx = np.concatenate([idx[1:], np.array([n_rows], dtype=idx.dtype)]) + counts = end_idx - idx + valid_times = np.where(counts >= 2)[0] + if len(valid_times) == 0: + raise ValueError("No valid time slices found in observation zarr table.") + + n_samples = min(n_time_samples, len(valid_times)) + time_indices = rng.choice(valid_times, size=n_samples, replace=False) + time_indices.sort() + + lats_list: list[NDArray[np.float64]] = [] + lons_list: list[NDArray[np.float64]] = [] + data_list: dict[str, list[NDArray]] = {col: [] for col in data_cols} + + for ti in time_indices: + start = int(idx[ti]) + end = int(end_idx[ti]) + if end <= start: + continue + rows = np.asarray(data_arr[start:end]) + if rows.ndim != 2: + raise ValueError(f"Observation rows expected 2D, got {rows.shape}") + + lats = np.deg2rad(rows[:, lat_idx]) + lons = np.deg2rad(rows[:, lon_idx]) + + if max_points_per_time is not None and len(lats) > max_points_per_time: + sample = rng.choice(len(lats), size=max_points_per_time, replace=False) + lats = lats[sample] + lons = lons[sample] + rows = rows[sample] + + lats_list.append(lats) + lons_list.append(lons) + for col, cidx in zip(data_cols, data_indices, strict=False): + data_list[col].append(rows[:, cidx]) + + if not lats_list: + raise ValueError("No valid time slices found in observation zarr table.") + + return DatasetInfo( + lats=None, + lons=None, + data={}, + lats_ragged=lats_list, + lons_ragged=lons_list, + data_ragged=data_list, + ) + + +def load_xarray( + path: str | Path, + lat_var: str, + lon_var: str, + data_vars: list[str] | None, + n_time_samples: int, + seed: int, +) -> DatasetInfo: + """Load from a netCDF/xarray-compatible file.""" + import xarray as xr + + ds = xr.open_dataset(path) + rng = np.random.default_rng(seed) + + lats_raw = ds[lat_var].values + lons_raw = ds[lon_var].values + + if data_vars is None: + skip = {lat_var, lon_var} + data_vars = [v for v in ds.data_vars if v not in skip] + + # Handle gridded data: flatten spatial dims + sample_var = ds[data_vars[0]] + dims = sample_var.dims + + # Find time dimension + time_dim = None + for d in dims: + if "time" in d.lower(): + time_dim = d + break + + if time_dim is not None: + n_times_total = ds.sizes[time_dim] + n_samples = min(n_time_samples, n_times_total) + time_indices = rng.choice(n_times_total, size=n_samples, replace=False) + time_indices.sort() + ds_sub = ds.isel({time_dim: time_indices}) + else: + ds_sub = ds + + # If lat/lon are 1D coordinate arrays, create a meshgrid + if lats_raw.ndim == 1 and lons_raw.ndim == 1: + lon_grid, lat_grid = np.meshgrid(lons_raw, lats_raw) + lats_flat = np.deg2rad(lat_grid.ravel()) + lons_flat = np.deg2rad(lon_grid.ravel()) + else: + lats_flat = np.deg2rad(lats_raw.ravel()) + lons_flat = np.deg2rad(lons_raw.ravel()) + + data = {} + for var in data_vars: + arr = ds_sub[var].values + # Flatten spatial dims, keep time + if time_dim is not None: + spatial_size = np.prod(arr.shape[1:]) + data[var] = arr.reshape(arr.shape[0], spatial_size) + else: + data[var] = arr.ravel()[np.newaxis, :] + + return DatasetInfo(lats=lats_flat, lons=lons_flat, data=data) + + +# --------------------------------------------------------------------------- +# Anomaly / detrending helpers +# --------------------------------------------------------------------------- + + +def _standardize_structured(data: NDArray) -> NDArray: + """Per-point temporal standardization: remove time-mean, divide by time-std. + + This removes the climatological spatial pattern (latitude gradients, land-sea + contrast, orographic effects) so that the autocorrelation reflects the + correlation structure of *weather anomalies* rather than the smooth background + climate. Without this, fields with strong gradients (tp, q, 2t) show + artificially long correlation lengths. + """ + time_mean = np.nanmean(data, axis=0) # [n_points] + anomalies = data - time_mean[None, :] + time_std = np.nanstd(data, axis=0) # [n_points] + time_std = np.where(time_std < 1e-12, 1.0, time_std) + return anomalies / time_std[None, :] + + +def _standardize_ragged(data_list: list[NDArray]) -> list[NDArray]: + """Per-snapshot spatial standardization for ragged unstructured data. + + Since each time slice may have different observation locations, we cannot + compute a per-point temporal mean. Instead, we remove the spatial mean and + normalise by the spatial std within each snapshot. This removes the gross + large-scale gradient for each time step. + """ + result: list[NDArray] = [] + for values in data_list: + values = np.asarray(values, dtype=np.float64) + mean_val = np.nanmean(values) + std_val = np.nanstd(values) + if std_val < 1e-12: + result.append(values - mean_val) + else: + result.append((values - mean_val) / std_val) + return result + + +# --------------------------------------------------------------------------- +# Spatial autocorrelation +# --------------------------------------------------------------------------- + + +def haversine_km(lat1: NDArray, lon1: NDArray, lat2: NDArray, lon2: NDArray) -> NDArray: + """Vectorized haversine distance in km. Inputs in radians.""" + dlat = lat2 - lat1 + dlon = lon2 - lon1 + a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2 + return 2 * EARTH_RADIUS_KM * np.arcsin(np.sqrt(np.clip(a, 0, 1))) + + +def compute_spatial_autocorr( + data: NDArray, + lats: NDArray, + lons: NDArray, + max_lag_km: float = 3000.0, + n_bins: int = 50, + n_sample_pairs: int = 100_000, + seed: int = 42, +) -> tuple[float, NDArray, NDArray]: + """Estimate the spatial correlation length of a variable on a fixed grid. + + The algorithm randomly samples pairs of grid points at different + distances and computes the Pearson correlation of the variable's + anomaly values as a function of great-circle (haversine) distance. + Specifically: + + 1. Draw ``n_sample_pairs`` random (time, point_i, point_j) triples. + 2. Compute the haversine distance for each pair and discard pairs + beyond ``max_lag_km``. + 3. Bin remaining pairs by distance into ``n_bins`` equal-width bins. + 4. For each bin, compute the Pearson correlation: + ρ(d) = Cov(X_i, X_j) / Var(X), where the variance is global. + 5. Fit an exponential decay ρ(d) ≈ exp(-d / L_corr) to the binned + correlations to estimate the correlation length L_corr in km. + + Parameters + ---------- + data : array [n_times, n_points] + lats, lons : arrays [n_points], in radians + max_lag_km : maximum lag distance for binning + n_bins : number of distance bins + n_sample_pairs : number of random point pairs to sample + seed : RNG seed + + Returns + ------- + l_corr_km : estimated correlation length in km + bin_centers : distance bin centers in km + bin_corr : binned correlation values + """ + rng = np.random.default_rng(seed) + data = np.asarray(data) + if data.ndim == 1: + data = data[np.newaxis, :] + elif data.ndim != 2: + raise ValueError( + f"Expected data with shape [n_times, n_points], got array with shape {data.shape}" + ) + + n_points_expected = len(lats) + if data.shape[1] != n_points_expected and data.shape[0] == n_points_expected: + data = data.T + + if data.shape[1] != n_points_expected: + raise ValueError( + "Data spatial dimension does not match lat/lon length: " + f"data.shape={data.shape}, n_points={n_points_expected}" + ) + + n_times, n_points = data.shape + + # Sample random pairs of (time, point_i, point_j) + time_indices = rng.integers(0, n_times, size=n_sample_pairs) + idx_i = rng.integers(0, n_points, size=n_sample_pairs) + idx_j = rng.integers(0, n_points, size=n_sample_pairs) + + # Remove self-pairs + valid = idx_i != idx_j + time_indices = time_indices[valid] + idx_i = idx_i[valid] + idx_j = idx_j[valid] + + # Compute distances + distances = haversine_km(lats[idx_i], lons[idx_i], lats[idx_j], lons[idx_j]) + + # Filter by max distance + in_range = distances <= max_lag_km + distances = distances[in_range] + time_indices = time_indices[in_range] + idx_i = idx_i[in_range] + idx_j = idx_j[in_range] + + if len(distances) < 100: + raise ValueError( + "Too few valid point pairs for autocorrelation estimation " + f"({len(distances)} pairs). The dataset may be too small or too sparse." + ) + + # Get values for all pairs + vals_i = data[time_indices, idx_i] + vals_j = data[time_indices, idx_j] + + # Remove pairs with NaN + nan_mask = np.isnan(vals_i) | np.isnan(vals_j) + if nan_mask.any(): + keep = ~nan_mask + distances = distances[keep] + vals_i = vals_i[keep] + vals_j = vals_j[keep] + + if len(distances) < 100: + raise ValueError( + "Too few non-NaN point pairs for autocorrelation estimation " + f"({len(distances)} pairs). The variable may contain too many NaNs." + ) + + # Bin by distance and compute correlation per bin + bin_edges = np.linspace(0, max_lag_km, n_bins + 1) + bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) + bin_indices = np.digitize(distances, bin_edges) - 1 + bin_indices = np.clip(bin_indices, 0, n_bins - 1) + + # Compute per-bin Pearson correlation via E[XY] - E[X]E[Y] + global_mean = np.nanmean(np.concatenate([vals_i, vals_j])) + global_var = np.nanvar(np.concatenate([vals_i, vals_j])) + + if global_var < 1e-12: + raise ValueError( + "Near-zero variance in sampled data — the variable is effectively constant " + "and its correlation length is undefined." + ) + + bin_corr = np.full(n_bins, np.nan) + bin_counts = np.zeros(n_bins, dtype=int) + + for b in range(n_bins): + mask = bin_indices == b + count = mask.sum() + bin_counts[b] = count + if count < 30: + continue + vi = vals_i[mask] + vj = vals_j[mask] + cov = np.mean((vi - global_mean) * (vj - global_mean)) + bin_corr[b] = cov / global_var + + # Fit exponential decay: corr(d) = exp(-d/L) + l_corr_km = _fit_correlation_length(bin_centers, bin_corr, bin_counts) + + return l_corr_km, bin_centers, bin_corr + + +def compute_spatial_autocorr_unstructured( + data_list: list[NDArray], + lats_list: list[NDArray], + lons_list: list[NDArray], + max_lag_km: float = 3000.0, + n_bins: int = 50, + n_sample_pairs: int = 100_000, + seed: int = 42, +) -> tuple[float, NDArray, NDArray]: + """Compute spatial autocorrelation for ragged (unstructured) per-time observations. + + Unlike ``compute_spatial_autocorr``, this function handles datasets where + each time step can have a different number of observation points at + different locations (e.g. SYNOP, radiosonde). The algorithm: + + 1. Weight-sample time steps proportionally to the number of possible + pairs, so denser time steps contribute more. + 2. For each sampled time step, draw random point pairs from that + snapshot and compute haversine distances. + 3. Bin all (distance, value_i, value_j) tuples into distance bins and + compute the Pearson correlation per bin, identical to the structured + version. + 4. Fit an exponential decay to estimate the correlation length L_corr. + """ + rng = np.random.default_rng(seed) + if not (len(data_list) == len(lats_list) == len(lons_list)): + raise ValueError("Ragged data, lat, and lon lists must have the same length.") + + n_times = len(data_list) + if n_times == 0: + raise ValueError("No time samples available for autocorrelation estimation.") + + sizes = np.array([len(lats) for lats in lats_list], dtype=int) + valid_times = np.where(sizes >= 2)[0] + if len(valid_times) == 0: + raise ValueError( + "All time slices have fewer than 2 observation points — " + "cannot compute pairwise autocorrelation." + ) + + weights = sizes[valid_times] * (sizes[valid_times] - 1) + if weights.sum() == 0: + raise ValueError("Too few valid observation pairs for autocorrelation estimation.") + + time_samples = rng.choice( + valid_times, size=n_sample_pairs, replace=True, p=weights / weights.sum() + ) + + distances_list: list[NDArray] = [] + vals_i_list: list[NDArray] = [] + vals_j_list: list[NDArray] = [] + + for t in np.unique(time_samples): + count = int(np.sum(time_samples == t)) + lats = lats_list[t] + lons = lons_list[t] + vals = np.asarray(data_list[t]) + if vals.ndim != 1: + vals = vals.reshape(-1) + n_points = len(vals) + if n_points < 2: + continue + + idx_i = rng.integers(0, n_points, size=count) + idx_j = rng.integers(0, n_points, size=count) + same = idx_i == idx_j + while same.any(): + idx_j[same] = rng.integers(0, n_points, size=int(same.sum())) + same = idx_i == idx_j + + distances = haversine_km(lats[idx_i], lons[idx_i], lats[idx_j], lons[idx_j]) + distances_list.append(distances) + vals_i_list.append(vals[idx_i]) + vals_j_list.append(vals[idx_j]) + + if not distances_list: + raise ValueError( + "No valid observation pairs were generated — " + "the dataset may be too sparse for autocorrelation estimation." + ) + + distances = np.concatenate(distances_list) + vals_i = np.concatenate(vals_i_list) + vals_j = np.concatenate(vals_j_list) + + in_range = distances <= max_lag_km + distances = distances[in_range] + vals_i = vals_i[in_range] + vals_j = vals_j[in_range] + + if len(distances) < 100: + raise ValueError( + "Too few valid point pairs for autocorrelation estimation " + f"({len(distances)} pairs). The observation dataset may be too sparse." + ) + + nan_mask = np.isnan(vals_i) | np.isnan(vals_j) + if nan_mask.any(): + keep = ~nan_mask + distances = distances[keep] + vals_i = vals_i[keep] + vals_j = vals_j[keep] + + if len(distances) < 100: + raise ValueError( + "Too few non-NaN observation pairs for autocorrelation estimation " + f"({len(distances)} pairs). The variable may contain too many NaNs." + ) + + bin_edges = np.linspace(0, max_lag_km, n_bins + 1) + bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) + bin_indices = np.digitize(distances, bin_edges) - 1 + bin_indices = np.clip(bin_indices, 0, n_bins - 1) + + global_mean = np.nanmean(np.concatenate([vals_i, vals_j])) + global_var = np.nanvar(np.concatenate([vals_i, vals_j])) + + if global_var < 1e-12: + raise ValueError( + "Near-zero variance in sampled observation data — the variable is effectively " + "constant and its correlation length is undefined." + ) + + bin_corr = np.full(n_bins, np.nan) + bin_counts = np.zeros(n_bins, dtype=int) + + for b in range(n_bins): + mask = bin_indices == b + count = int(mask.sum()) + bin_counts[b] = count + if count < 30: + continue + vi = vals_i[mask] + vj = vals_j[mask] + cov = np.mean((vi - global_mean) * (vj - global_mean)) + bin_corr[b] = cov / global_var + + l_corr_km = _fit_correlation_length(bin_centers, bin_corr, bin_counts) + + return l_corr_km, bin_centers, bin_corr + + +def _fit_correlation_length(bin_centers: NDArray, bin_corr: NDArray, bin_counts: NDArray) -> float: + """Fit correlation length from binned correlation data. + + The correlation length L_corr characterises the spatial scale over which + a variable's anomalies are significantly correlated. Because different + atmospheric variables have very different correlation structures (e.g. + smooth geopotential vs. noisy precipitation), no single estimator is + universally reliable. We therefore try three methods in order of + decreasing robustness: + + Strategy (in priority order): + + 1. **1/e threshold crossing** – the most model-free estimator. We find + the distance at which the binned correlation drops below 1/e ≈ 0.37. + This works well for any monotonically decaying correlation function, + including non-exponential shapes common in precipitation and wind + fields. It is preferred because it makes no parametric assumptions. + + 2. **Integrated correlation scale** – L_eff = ∫ max(ρ(r), 0) dr. When + the correlation function never cleanly crosses the 1/e level (e.g. + it starts below 1/e due to noise, or has a plateau), this integral + measure provides a robust single-number summary. It is less + sensitive to bin noise than a threshold crossing but can under- + estimate L_corr if the function has a long positive tail. + + 3. **Log-linear (exponential) fit** – weighted least-squares regression + of log(ρ) vs. distance, i.e. fitting ρ(d) = exp(-d/L). This is + used as a last resort because the exponential model can overestimate + L_corr when the true correlation function has a steep initial drop + followed by a slow tail (common for moisture variables). + + If all three methods fail (e.g. too few valid bins), a ``ValueError`` + is raised rather than returning a potentially misleading default. + """ + min_bin_count = 30 + + valid = (~np.isnan(bin_corr)) & (bin_counts >= min_bin_count) & (bin_corr > 0.01) + valid_centers = bin_centers[valid] + valid_corr = bin_corr[valid] + + if len(valid_centers) < 3: + raise ValueError( + "Too few valid distance bins for correlation length estimation " + f"({len(valid_centers)} valid bins). The data may be too noisy or too sparse." + ) + + # --- Method 1: 1/e threshold crossing (most robust) --- + threshold = 1.0 / np.e + for i in range(len(valid_corr) - 1): + if valid_corr[i] >= threshold > valid_corr[i + 1]: + frac = (valid_corr[i] - threshold) / (valid_corr[i] - valid_corr[i + 1]) + l_corr = valid_centers[i] + frac * (valid_centers[i + 1] - valid_centers[i]) + if 10.0 < l_corr < 20000.0: + return l_corr + + # --- Method 2: Integrated correlation scale --- + # L_eff = ∫ max(ρ(r), 0) dr (trapezoidal integration over valid bins) + positive_corr = np.maximum(valid_corr, 0.0) + if len(valid_centers) >= 2: + l_eff = float(np.trapezoid(positive_corr, valid_centers)) + if 10.0 < l_eff < 20000.0: + return l_eff + + # --- Method 3: Log-linear (exponential) fit --- + log_corr = np.log(valid_corr) + try: + weights = np.sqrt(bin_counts[valid].astype(float)) + coeffs = np.polyfit(valid_centers, log_corr, 1, w=weights) + slope = coeffs[0] + if slope < -1e-8: + l_corr = -1.0 / slope + if 10.0 < l_corr < 20000.0: + return l_corr + except (np.linalg.LinAlgError, ValueError): + pass + + raise ValueError( + "All three correlation length estimation methods failed " + "(1/e crossing, integrated scale, log-linear fit). " + "The correlation function may be too noisy or non-monotonic." + ) + + +# --------------------------------------------------------------------------- +# hl_mask mapping and grouping +# --------------------------------------------------------------------------- + + +def correlation_length_to_hl_mask( + l_corr_km: float, + healpix_level: int, + multiplier: float = 1.5, +) -> int: + """Map a correlation length to the appropriate HEALPix masking level. + + Finds the finest HEALPix level where the cell size exceeds + ``l_corr_km * multiplier``. + + Parameters + ---------- + l_corr_km : spatial correlation length in km + healpix_level : the training grid HEALPix level + multiplier : how much larger mask blocks should be vs. correlation length + + Returns + ------- + hl_mask : integer HEALPix level for masking (0 to healpix_level) + """ + target_km = l_corr_km * multiplier + # HEALPix cell size at level l: approx 4000 / 2^l km + # (12 base pixels, area = 4*pi*R^2/Npix, side ~ sqrt(area)) + for hl in range(healpix_level, -1, -1): + n_pix = 12 * 4**hl + cell_area_km2 = (4 * np.pi * EARTH_RADIUS_KM**2) / n_pix + cell_size_km = np.sqrt(cell_area_km2) + if cell_size_km >= target_km: + return hl + return 0 + + +def group_by_hl_mask(var_results: dict[str, VarResult]) -> dict[int, list[str]]: + """Group variables by their recommended hl_mask level. + + Returns + ------- + dict mapping hl_mask -> list of variable names + """ + groups: dict[int, list[str]] = {} + for name, result in var_results.items(): + groups.setdefault(result.hl_mask, []).append(name) + return dict(sorted(groups.items())) + + +def group_by_hl_mask_for_multiplier( + var_results: dict[str, VarResult], + healpix_level: int, + multiplier: float, +) -> dict[int, list[str]]: + """Group variables by hl_mask for a given correlation multiplier.""" + groups: dict[int, list[str]] = {} + for name, result in var_results.items(): + hl = correlation_length_to_hl_mask(result.l_corr_km, healpix_level, multiplier) + groups.setdefault(hl, []).append(name) + return dict(sorted(groups.items())) + + +# --------------------------------------------------------------------------- +# Output formatting +# --------------------------------------------------------------------------- + +# Approximate cell sizes for display +_HL_CELL_SIZES = { + 0: "~4000", + 1: "~2000", + 2: "~1000", + 3: "~500", + 4: "~250", + 5: "~125", +} + +_HL_CORR_RANGES = { + 0: "2000+ km", + 1: "1000-2000 km", + 2: "500-1000 km", + 3: "250-500 km", + 4: "100-250 km", + 5: "<100 km", +} + + +def format_results_table(var_results: dict[str, VarResult]) -> str: + """Format a human-readable results table.""" + lines = [] + lines.append(f"{'Variable':<20s} {'L_corr (km)':>12s} {'hl_mask':>7s}") + lines.append("-" * 20 + " " + "-" * 12 + " " + "-" * 7) + for name, r in sorted(var_results.items(), key=lambda x: -x[1].l_corr_km): + lines.append(f"{name:<20s} {r.l_corr_km:>12.0f} {r.hl_mask:>7d}") + return "\n".join(lines) + + +def format_groupings(groups: dict[int, list[str]]) -> str: + """Format stream grouping suggestions.""" + lines = [] + for hl, vars_list in sorted(groups.items()): + corr_range = _HL_CORR_RANGES.get(hl, "unknown") + vars_str = ", ".join(vars_list) + lines.append(f"Stream group hl_mask={hl} (L_corr {corr_range}): {vars_str}") + return "\n".join(lines) + + +def generate_yaml_snippets(groups: dict[int, list[str]]) -> str: + """Generate YAML masking_override snippets for each group.""" + lines = [] + for hl, vars_list in sorted(groups.items()): + vars_str = ", ".join(vars_list) + lines.append(f"# Stream group for: {vars_str}") + lines.append(f"# Recommended hl_mask: {hl}") + lines.append("masking_override:") + lines.append(" model_input:") + lines.append(" masking_strategy_config:") + lines.append(f" hl_mask: {hl}") + lines.append(" target_input:") + lines.append(" masking_strategy_config:") + lines.append(f" hl_mask: {max(0, hl)}") + lines.append("") + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# Main pipeline +# --------------------------------------------------------------------------- + + +def analyse_dataset( + dataset_path: str | Path, + dataset_type: str, + channels: list[str] | None = None, + n_time_samples: int = 100, + n_sample_pairs: int = 100_000, + correlation_multiplier: float = 1.5, + healpix_level: int = 5, + seed: int = 42, + lat_col: str = "lat", + lon_col: str = "lon", + lat_var: str = "latitude", + lon_var: str = "longitude", + max_points_per_time: int | None = 50_000, + detrend: bool = True, +) -> tuple[dict[str, VarResult], dict[int, list[str]]]: + """Run the full analysis pipeline. + + Parameters + ---------- + detrend : bool + If True (default), remove the climatological spatial pattern before + computing autocorrelation. This prevents large-scale gradients + (latitude, orography) from inflating correlation lengths. + + Returns + ------- + var_results : per-variable analysis results + groups : hl_mask -> list of variable names + """ + # Load data + # NOTE: We use lightweight standalone loaders instead of the training + # DataReaders (DataReaderAnemoi / DataReaderObs). The analysis needs + # per-variable [n_times, n_points] arrays, whereas the readers return + # all channels flattened into [n_times*n_points, n_channels] which would + # need to be reshaped back. Reusing them would also require constructing + # a TimeWindowHandler to samples times. + logger.info(f"Loading dataset from {dataset_path} (type={dataset_type})") + if dataset_type == "anemoi": + ds_info = load_anemoi(dataset_path, n_time_samples, channels, seed) + elif dataset_type in ("fesom", "obs"): + ds_info = load_zarr_columnar( + dataset_path, + lat_col, + lon_col, + channels, + n_time_samples, + seed, + max_points_per_time=max_points_per_time, + ) + elif dataset_type in ("xarray", "cams", "eobs", "iconart", "iconesm"): + ds_info = load_xarray(dataset_path, lat_var, lon_var, channels, n_time_samples, seed) + else: + raise ValueError(f"Unsupported dataset type: {dataset_type}") + + if ds_info.data_ragged is not None: + n_times = len(ds_info.lats_ragged or []) + avg_points = int(np.mean([len(x) for x in ds_info.lats_ragged or [0]])) + logger.info( + f"Loaded {len(ds_info.data_ragged)} variables, {n_times} time samples, " + f"avg {avg_points} points/time (ragged)" + ) + else: + logger.info( + f"Loaded {len(ds_info.data)} variables, " + f"{next(iter(ds_info.data.values())).shape[0]} time samples, " + f"{len(ds_info.lats)} spatial points" + ) + + # Standardize to anomalies if requested + if detrend: + logger.info("Detrending: computing anomaly autocorrelation (climatology removed)") + if ds_info.data_ragged is not None: + for var_name in ds_info.data_ragged: + ds_info.data_ragged[var_name] = _standardize_ragged(ds_info.data_ragged[var_name]) + else: + for var_name in ds_info.data: + ds_info.data[var_name] = _standardize_structured(ds_info.data[var_name]) + else: + logger.info("No detrending: computing raw-field autocorrelation") + + # Compute autocorrelation per variable + var_results: dict[str, VarResult] = {} + if ds_info.data_ragged is not None: + assert ds_info.lats_ragged is not None + assert ds_info.lons_ragged is not None + data_items = ds_info.data_ragged.items() + else: + data_items = ds_info.data.items() + + for var_name, var_data in data_items: + logger.info(f"Computing autocorrelation for '{var_name}'...") + if ds_info.data_ragged is not None: + l_corr, bin_centers, bin_corr = compute_spatial_autocorr_unstructured( + var_data, + ds_info.lats_ragged, + ds_info.lons_ragged, + n_sample_pairs=n_sample_pairs, + seed=seed, + ) + else: + l_corr, bin_centers, bin_corr = compute_spatial_autocorr( + var_data, + ds_info.lats, + ds_info.lons, + n_sample_pairs=n_sample_pairs, + seed=seed, + ) + hl = correlation_length_to_hl_mask(l_corr, healpix_level, correlation_multiplier) + var_results[var_name] = VarResult( + name=var_name, + l_corr_km=l_corr, + hl_mask=hl, + bin_centers_km=bin_centers, + bin_correlations=bin_corr, + ) + logger.info(f" {var_name}: L_corr={l_corr:.0f} km -> hl_mask={hl}") + + groups = group_by_hl_mask(var_results) + return var_results, groups + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def main(argv: list[str] | None = None) -> None: + parser = argparse.ArgumentParser( + description="Compute spatial autocorrelation per variable and suggest masking configs.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--dataset", required=True, help="Path to dataset") + parser.add_argument( + "--type", + required=True, + choices=["anemoi", "fesom", "obs", "xarray", "cams", "eobs", "iconart", "iconesm"], + help="Dataset type", + ) + parser.add_argument( + "--channels", + nargs="*", + default=None, + help="Variables to analyse. If omitted, all variables in the dataset are analysed.", + ) + parser.add_argument("--n-time-samples", type=int, default=100, help="Timesteps to sample") + parser.add_argument("--n-sample-pairs", type=int, default=100_000, help="Point pairs to sample") + parser.add_argument( + "--correlation-multiplier", + type=float, + default=1.5, + help="Multiplier for mapping L_corr -> mask block size", + ) + parser.add_argument( + "--correlation-multipliers", + type=float, + nargs="*", + default=None, + help="Optional list of multipliers to print separate suggestions", + ) + parser.add_argument("--healpix-level", type=int, default=5, help="Training grid HEALPix level") + parser.add_argument("--output", default=None, help="Output YAML file path") + parser.add_argument("--seed", type=int, default=42, help="RNG seed") + # Extra args for non-anemoi types + parser.add_argument("--lat-col", default="lat", help="Latitude column name (zarr)") + parser.add_argument("--lon-col", default="lon", help="Longitude column name (zarr)") + parser.add_argument("--lat-var", default="latitude", help="Latitude variable name (xarray)") + parser.add_argument("--lon-var", default="longitude", help="Longitude variable name (xarray)") + parser.add_argument( + "--max-points-per-time", + type=int, + default=50_000, + help="Max points per time slice for unstructured observations", + ) + parser.add_argument( + "--no-detrend", + action="store_true", + help="Disable anomaly standardization (use raw fields for autocorrelation)", + ) + + args = parser.parse_args(argv) + + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + var_results, groups = analyse_dataset( + dataset_path=args.dataset, + dataset_type=args.type, + channels=args.channels, + n_time_samples=args.n_time_samples, + n_sample_pairs=args.n_sample_pairs, + correlation_multiplier=args.correlation_multiplier, + healpix_level=args.healpix_level, + seed=args.seed, + lat_col=args.lat_col, + lon_col=args.lon_col, + lat_var=args.lat_var, + lon_var=args.lon_var, + max_points_per_time=args.max_points_per_time, + detrend=not args.no_detrend, + ) + + # Output results to stdout + def _write(msg: str = "") -> None: + import sys + + sys.stdout.write(msg + "\n") + + mode = "anomaly" if not args.no_detrend else "raw" + _write("\n" + "=" * 50) + _write(f"Per-variable autocorrelation analysis (mode={mode})") + _write("=" * 50) + _write(format_results_table(var_results)) + _write() + # When --correlation-multipliers is given (e.g. 1.0 1.5 2.0), we print + # separate grouping tables and YAML snippets for each multiplier value so + # the user can compare how aggressively variables are grouped. Otherwise + # we use the single --correlation-multiplier value. + multipliers = args.correlation_multipliers or [args.correlation_multiplier] + yaml_sections: list[str] = [] + + for multiplier in multipliers: + if multiplier == args.correlation_multiplier: + groups_for_multiplier = groups + else: + groups_for_multiplier = group_by_hl_mask_for_multiplier( + var_results, args.healpix_level, multiplier + ) + + _write("=" * 50) + _write(f"Suggested stream groupings (multiplier={multiplier:g})") + _write("=" * 50) + _write(format_groupings(groups_for_multiplier)) + _write() + + yaml_output = generate_yaml_snippets(groups_for_multiplier) + _write("=" * 50) + _write(f"YAML masking_override snippets (multiplier={multiplier:g})") + _write("=" * 50) + _write(yaml_output) + yaml_sections.append(f"# Multiplier: {multiplier:g}\n{yaml_output.strip()}") + + if args.output: + Path(args.output).write_text("\n\n".join(yaml_sections) + "\n") + _write(f"\nYAML written to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 1dfdef15d..049e1790f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -147,11 +147,15 @@ line-ending = "lf" # Some useful warnings are not yet in ruff [tool.pylint.main] -ignore = ["tests"] +ignore = ["tests", ".venv", ".git", ".tox", ".ruff_cache"] +load-plugins = ["pylint.extensions.bad_builtin"] + +[tool.pylint.bad_builtins] +bad-functions = ["getattr", "setattr"] [tool.pylint.messages_control] disable = ["all"] -enable = ["W0201"] +enable = ["W0201", "W0141"] [tool.uv] diff --git a/scripts/actions.sh b/scripts/actions.sh index f0f4852ec..a3dc5e1c0 100755 --- a/scripts/actions.sh +++ b/scripts/actions.sh @@ -19,28 +19,24 @@ case "$1" in lint) ( cd "$SCRIPT_DIR" || exit 1 - uv run --no-project --with "ruff==0.12.2" ruff format --target-version py312 \ - src/ scripts/ packages/ \ - && \ uv run --no-project --with "ruff==0.12.2" \ - ruff check --target-version py312 \ - --fix \ - src/ scripts/ packages/ + ruff format --target-version py312 src/ scripts/ packages/ \ + && \ + uv run --no-project --with "ruff==0.12.2" \ + ruff check --target-version py312 --fix src/ scripts/ packages/ ) ;; lint-check) ( cd "$SCRIPT_DIR" || exit 1 - uv run --no-project --with "ruff==0.12.2" ruff format --target-version py312 \ - -n \ - src/ scripts/ packages/ \ - && \ uv run --no-project --with "ruff==0.12.2" \ - ruff check --target-version py312 \ - src/ scripts/ packages/ \ - && \ + ruff format --target-version py312 -n src/ scripts/ packages/ \ + && \ + uv run --no-project --with "ruff==0.12.2" \ + ruff check --target-version py312 src/ scripts/ packages/ \ + && \ uv run --no-project --with "pylint==4.0.3" \ - pylint src/ packages/ + pylint src/ packages/ ) ;; type-check) @@ -50,7 +46,6 @@ case "$1" in # weathergen-common uv sync --project packages/common --no-install-workspace - uv pip list uv run --project packages/common --frozen pyrefly check packages/common # Fail for errors on weathergen-common: if [ $? -ne 0 ]; then @@ -60,7 +55,6 @@ case "$1" in # weathergen-metrics uv sync --project packages/metrics --no-install-workspace - uv pip list uv run --project packages/metrics --frozen pyrefly check packages/metrics # Fail for errors on weathergen-metrics: if [ $? -ne 0 ]; then @@ -70,13 +64,11 @@ case "$1" in # weathergen-evaluate uv sync --project packages/evaluate --no-install-workspace --package weathergen-evaluate - uv pip list uv run --project packages/evaluate --frozen pyrefly check packages/evaluate # weathergen (root) # Install the whole workspace. It also needs the extra cpu option for the right version of pytorch. uv sync --all-packages --extra cpu --no-install-workspace - uv pip list uv run --all-packages pyrefly check src echo "Type checking completed." ) @@ -84,7 +76,6 @@ case "$1" in unit-test) ( cd "$SCRIPT_DIR" || exit 1 - uv sync --extra cpu uv run --extra cpu pytest src/ ) ;; @@ -176,9 +167,11 @@ case "$1" in ) ;; *) - # Automatically extract all options from the case statement - options=$(grep -oP '^\s*\K[\w-]+(?=\))' "$0" | tr '\n' '|' | sed 's/|$//') - echo "Usage: $0 {$options}" - exit 1 + ( + # Automatically extract all options from the case statement + options=$(grep -oP '^\s*\K[\w-]+(?=\))' "$0" | tr '\n' '|' | sed 's/|$//') + echo "Usage: $0 {$options}" + exit 1 + ) ;; esac diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 6dfe71c89..65d1897cf 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -213,19 +213,11 @@ def get_target_values( encode_times_target, ) - # if selection is not None and data.numel() > 0: - # device_sel = selection.to(data.device) - # data = data.index_select(0, device_sel) - # coords = coords.index_select(0, device_sel) - # if idxs_ord_inv.numel() > 0: - # idxs_ord_inv = idxs_ord_inv.index_select(0, device_sel) - - # # datetimes is numpy here - # np_sel = selection.cpu().numpy() - # datetimes = datetimes[np_sel] - - # TODO: idxs_ord_inv idxs_ord_inv = None + if data.numel() > 0: + # flatten per-token indices into one flat list + idxs_flat = torch.cat([idxs for idxs_cell in idxs_cells for idxs in idxs_cell]) + # compute indices for inversion + _, idxs_ord_inv = torch.sort(idxs_flat) - # selection not passed on, we call get_target_coords first return (data, datetimes, coords, idxs_ord_inv) diff --git a/src/weathergen/model/ema.py b/src/weathergen/model/ema.py index b070feda1..e652d4391 100644 --- a/src/weathergen/model/ema.py +++ b/src/weathergen/model/ema.py @@ -77,7 +77,7 @@ def get_current_beta(self, cur_step: int) -> float: """ halflife_steps = self.halflife_steps if self.rampup_ratio is not None: - halflife_steps = min(halflife_steps, cur_step / self.rampup_ratio) + halflife_steps = min(halflife_steps, cur_step * self.rampup_ratio) beta = 0.5 ** (self.batch_size / max(halflife_steps, 1e-6)) return beta diff --git a/src/weathergen/model/embeddings.py b/src/weathergen/model/embeddings.py index aaf43ef0c..e65590857 100644 --- a/src/weathergen/model/embeddings.py +++ b/src/weathergen/model/embeddings.py @@ -199,7 +199,6 @@ def __init__(self, dim_in, dim_out, stream_name="stream_embed"): self.layer = torch.nn.Linear(dim_in, dim_out) def forward(self, x): - # x = checkpoint( self.layer, x.flatten( -2, -1), use_reentrant=True) - x = self.layer(x.flatten(-2, -1)) + x = checkpoint(self.layer, x.flatten(-2, -1), use_reentrant=False).unsqueeze(0) return x diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index ca2a2d725..be933935a 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -11,7 +11,6 @@ import itertools import logging -from pathlib import Path import omegaconf import torch @@ -21,7 +20,7 @@ ) from torch.distributed.tensor import distribute_tensor -from weathergen.common.config import Config, merge_configs +from weathergen.common.config import Config, get_path_model, merge_configs from weathergen.model.attention import ( MultiCrossAttentionHeadVarlen, MultiCrossAttentionHeadVarlenSlicedQ, @@ -189,7 +188,7 @@ def load_model(cf, model, device, run_id: str, mini_epoch=-1): mini_epoch : The mini_epoch to load. Default (-1) is the latest mini_epoch """ - path_run = Path(cf.model_path) / run_id + path_run = get_path_model(run_id=run_id) mini_epoch_id = ( f"chkpt{mini_epoch:05d}" if mini_epoch != -1 and mini_epoch is not None else "latest" ) diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index 0baf707ba..c3627c97e 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -1,3 +1,4 @@ +# pylint: disable=bad-builtin # ruff: noqa: T201 # (C) Copyright 2025 WeatherGenerator contributors. diff --git a/src/weathergen/train/loss_modules/loss_module_physical.py b/src/weathergen/train/loss_modules/loss_module_physical.py index 116dce926..d913b70e4 100644 --- a/src/weathergen/train/loss_modules/loss_module_physical.py +++ b/src/weathergen/train/loss_modules/loss_module_physical.py @@ -1,3 +1,4 @@ +# pylint: disable=bad-builtin # ruff: noqa: T201 # (C) Copyright 2025 WeatherGenerator contributors. diff --git a/src/weathergen/train/target_and_aux_module_base.py b/src/weathergen/train/target_and_aux_module_base.py index bb39d1b17..465065753 100644 --- a/src/weathergen/train/target_and_aux_module_base.py +++ b/src/weathergen/train/target_and_aux_module_base.py @@ -109,11 +109,12 @@ def compute(self, bidx, batch, model_params, model) -> TargetAuxOutput: # collect targets for all forecast steps for step in output_idxs: targets_cur, target_times_cur, target_coords_cur, meta_data = [], [], [], [] - is_spoof = [] + is_spoof, idxs_inv = [], [] for sample in batch.samples: targets_cur += [sample.streams_data[stream_name].target_tokens[step]] target_times_cur += [sample.streams_data[stream_name].target_times_raw[step]] target_coords_cur += [sample.streams_data[stream_name].target_coords_raw[step]] + idxs_inv += [sample.streams_data[stream_name].idxs_inv[step]] meta_data += [sample.meta_info] is_spoof += [sample.streams_data[stream_name].is_spoof()] @@ -123,6 +124,7 @@ def compute(self, bidx, batch, model_params, model) -> TargetAuxOutput: "target_coords": target_coords_cur, "target_metda_data": meta_data, "is_spoof": is_spoof, + "idxs_inv": idxs_inv, } targets.add_physical_target(step, stream_name, targets_step) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index eb34a673d..431b9d1e0 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -555,7 +555,9 @@ def validate(self, mini_epoch, mode_cfg, batch_size): with torch.no_grad(): # print progress bar but only in interactive mode, i.e. when without ddp - with tqdm.tqdm(total=mode_cfg.samples_per_mini_epoch, disable=self.cf.with_ddp) as pbar: + with tqdm.tqdm( + total=len(self.data_loader_validation), disable=self.cf.with_ddp + ) as pbar: for bidx, batch in enumerate(dataset_val_iter): if cf.data_loading.get("memory_pinning", False): # pin memory for faster CPU-GPU transfer diff --git a/src/weathergen/utils/better_abc.py b/src/weathergen/utils/better_abc.py index e322927d4..347400294 100644 --- a/src/weathergen/utils/better_abc.py +++ b/src/weathergen/utils/better_abc.py @@ -1,3 +1,5 @@ +# pylint: disable=bad-builtin + """ Coding recipe for abstract fields in Python. diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index 01f860b7a..03a7b9646 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -20,7 +20,7 @@ import yaml import weathergen.common.config as config -from weathergen.train.utils import TRAIN, VAL +from weathergen.train.utils import TRAIN from weathergen.utils.train_logger import Metrics, TrainLogger _logger = logging.getLogger(__name__) @@ -252,14 +252,9 @@ def plot_loss_avg(plot_dir: Path, runs_ids, runs_data, runs_active, stage=TRAIN, legend_str = [] for i_run, (run_id, run_data) in enumerate(zip(runs_ids, runs_data, strict=False)): - if stage == TRAIN: - x_vals = np.array(run_data.train["num_samples"]) - y_vals = np.array(run_data.train["loss_avg_mean"]) - elif stage == VAL: - x_vals = np.array(run_data.val["num_samples"]) - y_vals = np.array(run_data.val["loss_avg_mean"]) - else: - assert False + run_data_stage = run_data.train if stage == TRAIN else run_data.val + x_vals = np.array(run_data_stage["num_samples"]) + y_vals = np.array(run_data_stage["loss_avg_mean"]) mask = np.logical_and(~np.isnan(x_vals), ~np.isnan(y_vals)) @@ -302,8 +297,11 @@ def plot_loss_per_stream( plot_dir: Path, errs: list[str], channels: list[str], + forecast_steps: list[int], x_axis: str = "samples", x_type: str = "step", + x_lim: list[float] | None = None, + y_lim: list[float] | None = None, x_scale_log: bool = False, ): """ @@ -368,12 +366,21 @@ def plot_loss_per_stream( if len(col_split) < 4: if stream_name in col: data_cols += [col] - elif ( - col_split[1].lower() == stream_name.lower() - and col_split[2].lower() == err.lower() - and col_split[3] == channel - ): - data_cols += [col] + elif len(col_split) == 4: + if ( + col_split[1].lower() == stream_name.lower() + and col_split[2].lower() == err.lower() + and col_split[3] == channel + ): + data_cols += [col] + elif len(col_split) == 5: + if ( + col_split[1].lower() == stream_name.lower() + and col_split[2].lower() == err.lower() + and col_split[3] == channel + and int(col_split[4]) in forecast_steps + ): + data_cols += [col] for col in data_cols: x_vals = np.array(run_data_mode[x_col]) @@ -417,10 +424,14 @@ def plot_loss_per_stream( for line in legend.get_lines(): line.set(alpha=1.0) plt.grid(True, which="both", ls="-") - # cap at 1.0 in case of divergence of run (through normalziation, max should be - # around 1.0) - # plt.ylim([0.95 * min_val, (None if max_val < 2.0 else min(1.1, 1.025 * max_val))]) - plt.ylim([0.95 * min_val, 1.025 * max_val]) + + if y_lim is not None: + plt.ylim(y_lim) + else: + plt.ylim([0.95 * min_val, 1.025 * max_val]) + if x_lim is not None: + plt.xlim(x_lim) + plt.yscale("log") if x_scale_log: plt.xscale("log") @@ -431,8 +442,12 @@ def plot_loss_per_stream( rstr = "".join([f"{r}_" for r in runs_ids]) # save the plot - plt_fname = plot_dir / "{}{}{}_{}.png".format( - rstr, "".join([f"{m}_" for m in modes]), stream_name, channel + plt_fname = plot_dir / "{}{}fs_{}{}_{}.png".format( + rstr, + "".join([f"{m}_" for m in modes]), + "".join([f"{fs}_" for fs in forecast_steps]), + stream_name, + channel, ) _logger.info(f"Saving loss per stream plot to '{plt_fname}'") plt.savefig(plt_fname) @@ -621,6 +636,14 @@ def plot_train(args=None): nargs="+", help="List of channels to plot", ) + parser.add_argument( + "--forecast-steps", + dest="forecast_steps", + default=[0, 1], + type=int, + nargs="+", + help="List of channels to plot", + ) parser.add_argument( "--metrics", dest="metrics", @@ -629,6 +652,22 @@ def plot_train(args=None): nargs="+", help="List of metrics (e.g. mse) to plot", ) + parser.add_argument( + "--per-stream-x-lim", + dest="per_stream_x_lim", + default=None, + type=float, + nargs="+", + help="x-lim for per-stream plots", + ) + parser.add_argument( + "--per-stream-y-lim", + dest="per_stream_y_lim", + default=None, + type=float, + nargs="+", + help="x-lim for per-stream plots", + ) parser.add_argument( "--x_type", "-x", @@ -682,6 +721,27 @@ def plot_train(args=None): if args.delete == "True": clean_plot_folder(out_dir) + # collect all physical streams from all run_ids if requested + if "all" in streams: + for run_id in runs_ids: + # Load config from given model_path if provided, otherwise use path from private config + if model_base_dir: + cf = config.load_run_config( + run_id=run_id, mini_epoch=None, model_path=model_base_dir + ) + else: + cf = config.load_merge_configs( + private_home=None, + from_run_id=run_id, + mini_epoch=None, + ) + for stream_info in cf.streams: + streams += [stream_info["name"]] + # ensure items are unique + streams = list(set(streams)) + # remove "all" key that is a special flag and not an actual stream name + streams.remove("all") + # read logged data runs_data = [ @@ -713,8 +773,11 @@ def plot_train(args=None): streams, errs=args.metrics, channels=args.channels, + forecast_steps=args.forecast_steps, x_type=args.x_type, x_scale_log=x_scale_log, + x_lim=args.per_stream_x_lim, + y_lim=args.per_stream_y_lim, plot_dir=out_dir, ) plot_loss_per_stream( @@ -725,8 +788,11 @@ def plot_train(args=None): streams, errs=args.metrics, channels=args.channels, + forecast_steps=args.forecast_steps, x_type=args.x_type, x_scale_log=x_scale_log, + x_lim=args.per_stream_x_lim, + y_lim=args.per_stream_y_lim, plot_dir=out_dir, ) plot_loss_per_stream( @@ -737,8 +803,11 @@ def plot_train(args=None): streams, errs=args.metrics, channels=args.channels, + forecast_steps=args.forecast_steps, x_type=args.x_type, x_scale_log=x_scale_log, + x_lim=args.per_stream_x_lim, + y_lim=args.per_stream_y_lim, plot_dir=out_dir, ) diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 1290a8660..71f2dccf7 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -125,14 +125,24 @@ def write_output( preds = [target.clone().unsqueeze(0) for target in targets] for i_batch, (pred, target) in enumerate(zip(preds, targets, strict=True)): + target_data = target_aux_out.physical[t_idx][sname] + t_coords = target_data["target_coords"][i_batch] + t_times = target_data["target_times"][i_batch] + + idxs_inv = target_aux_out.physical[t_idx][sname]["idxs_inv"][i_batch] + if idxs_inv is not None: + pred = pred[:, idxs_inv] + target = target[idxs_inv] + t_coords = t_coords[idxs_inv] + t_times = t_times[idxs_inv] + # denormalize data if requested and map to storage format preds_s += [dn_data(sname, pred).detach().to(fp32).cpu().numpy()] targets_s += [dn_data(sname, target).detach().to(fp32).cpu().numpy()] # extract original target coords and times from target data - target_data = target_aux_out.physical[t_idx][sname] - t_coords_s += [target_data["target_coords"][i_batch].cpu().numpy()] - t_times_s += [target_data["target_times"][i_batch].astype("datetime64[ns]")] + t_coords_s += [t_coords.cpu().numpy()] + t_times_s += [t_times.astype("datetime64[ns]")] targets_lens[-1] += [[]] targets_lens[-1][-1] += [t.shape[0] for t in targets_s] @@ -142,13 +152,6 @@ def write_output( targets_coords_all[-1] += [np.concatenate(t_coords_s)] targets_times_all[-1] += [np.concatenate(t_times_s)] - # # TODO: re-enable - # if len(idxs_inv) > 0: - # pred = pred[:, idxs_inv] - # target = target[idxs_inv] - # targets_coords_raw[t_idx][i_strm] = targets_coords_raw[t_idx][i_strm][idxs_inv] - # targets_times_raw[t_idx][i_strm] = targets_times_raw[t_idx][i_strm][idxs_inv] - if len(preds_all) == 0 or np.array([p.shape[1] for pp in preds_all for p in pp]).sum() == 0: _logger.warning("Writing no data since predictions are empty.") return From 6b0fbb02481c1e70e05a3e338afde2e340a93d23 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Tue, 10 Mar 2026 22:37:46 +0100 Subject: [PATCH 232/344] add noise distribution plotting --- config/config_diffusion.yml | 4 +- src/weathergen/model/model.py | 10 +- src/weathergen/train/trainer.py | 48 ++++++ src/weathergen/utils/plot_diffusion_noise.py | 170 +++++++++++++++++++ 4 files changed, 225 insertions(+), 7 deletions(-) create mode 100644 src/weathergen/utils/plot_diffusion_noise.py diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 9c83e208e..3a3a3bced 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -72,8 +72,8 @@ sigma_min: 0.002 sigma_max: 50000 sigma_data: 0.5 rho: 7 -p_mean: 0.0 # -1.2 -p_std: 1.2 # 1.2 +p_mean: -1.2 +p_std: 1.2 healpix_level: 5 diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index dc1a8d9fb..ea7970259 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -648,10 +648,10 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # Normalize tokens # TODO: REMOVE THIS LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. - t_mean = tokens.mean() - t_std = tokens.std() - tokens = (tokens - t_mean) / (t_std + 1e-6) - tokens = torch.clamp(tokens, -100.0, 100.0) + # t_mean = tokens.mean() + # t_std = tokens.std() + # tokens = (tokens - t_mean) / (t_std + 1e-6) + # tokens = torch.clamp(tokens, -100.0, 100.0) # roll-out in latent space, iterate and generate output over requested output steps for step in batch.get_output_idxs(): @@ -666,7 +666,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # Un-normalize tokens # TODO: REMOVE THIS AS ABOVE. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. - tokens = tokens * (t_std + 1e-6) + t_mean + # tokens = tokens * (t_std + 1e-6) + t_mean # decoder predictions output = self.predict_decoders(model_params, step, tokens, batch, output) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 431b9d1e0..0b47d9910 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -45,6 +45,7 @@ get_target_idxs_from_cfg, ) from weathergen.utils.distributed import is_root +from weathergen.utils.plot_diffusion_noise import plot_noise_vs_tokens from weathergen.utils.train_logger import TrainLogger, prepare_losses_for_logging from weathergen.utils.utils import get_dtype from weathergen.utils.validation_io import write_output @@ -553,6 +554,10 @@ def validate(self, mini_epoch, mode_cfg, batch_size): num_samples_write = mode_cfg.get("output", {}).get("num_samples", 0) * batch_size + # Collect encoded tokens for diffusion noise diagnostic plot + is_diffusion = cf.get("fe_diffusion_model", False) + collected_tokens = [] + with torch.no_grad(): # print progress bar but only in interactive mode, i.e. when without ddp with tqdm.tqdm( @@ -598,6 +603,14 @@ def validate(self, mini_epoch, mode_cfg, batch_size): metadata=extract_batch_metadata(batch), ) + # Collect encoded tokens (z_pre_norm) for diffusion diagnostics + if is_diffusion and preds.latent and preds.latent[0]: + latent_state = preds.latent[0].get("latent_state") + if latent_state is not None and latent_state.z_pre_norm is not None: + collected_tokens.append( + latent_state.z_pre_norm.detach().float().cpu().numpy().flatten() + ) + # log output if bidx < num_samples_write: # denormalization function for data @@ -627,9 +640,44 @@ def validate(self, mini_epoch, mode_cfg, batch_size): self._log_terminal(0, mini_epoch, VAL) self._log(VAL) + # Plot diffusion noise vs encoded token distribution (root rank only) + if is_diffusion and collected_tokens and is_root(): + self._plot_diffusion_noise_vs_tokens(mini_epoch, collected_tokens) + # avoid that there is a systematic bias in the validation subset self.dataset_val.advance() + def _plot_diffusion_noise_vs_tokens( + self, mini_epoch: int, collected_tokens: list[np.ndarray] + ) -> None: + """Generate diffusion noise vs encoded token distribution diagnostic plot. + + Called at the end of validation when fe_diffusion_model is True. + + Args: + mini_epoch: Current mini epoch (used in filename). + collected_tokens: List of flattened token arrays from validation batches. + """ + token_values = np.concatenate(collected_tokens) + p_mean = self.cf.get("p_mean", -1.2) + p_std = self.cf.get("p_std", 1.2) + sigma_data = self.cf.get("sigma_data", 0.5) + + output_dir = config.get_path_run(self.cf) + output_path = output_dir / f"diffusion_noise_vs_tokens_epoch{mini_epoch:05d}.png" + + logger.info( + f"Plotting diffusion noise vs tokens: p_mean={p_mean}, p_std={p_std}, " + f"token_mean={token_values.mean():.4f}, token_std={token_values.std():.4f}" + ) + plot_noise_vs_tokens( + p_mean=p_mean, + p_std=p_std, + token_values=token_values, + sigma_data=sigma_data, + output_path=output_path, + ) + def _get_full_model_state_dict(self): maybe_sharded_sd = ( self.model.state_dict() if self.ema_model is None else self.ema_model.state_dict() diff --git a/src/weathergen/utils/plot_diffusion_noise.py b/src/weathergen/utils/plot_diffusion_noise.py new file mode 100644 index 000000000..cfa3c19bf --- /dev/null +++ b/src/weathergen/utils/plot_diffusion_noise.py @@ -0,0 +1,170 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +""" +Plotting utilities for comparing the diffusion noise distribution (sigma) +with the distribution of encoded tokens from the model encoder. + +The noise level sigma is derived from the EDM parameterization (Karras et al.): + eta ~ N(0, 1) + sigma = exp(eta * p_std + p_mean) + +So log(sigma) ~ N(p_mean, p_std^2), i.e. sigma follows a log-normal distribution. + +These functions are called from the Trainer during validation to produce diagnostic plots. +""" + +import logging +from pathlib import Path + +import matplotlib + +matplotlib.use("Agg") # non-interactive backend for HPC / headless environments +import matplotlib.pyplot as plt +import numpy as np + +logger = logging.getLogger(__name__) + + +def sample_sigma(p_mean: float, p_std: float, n_samples: int = 100_000) -> np.ndarray: + """Sample sigma values from the diffusion noise distribution. + + Args: + p_mean: Mean of log(sigma) distribution. + p_std: Std of log(sigma) distribution. + n_samples: Number of samples to draw. + + Returns: + Array of sigma values. + """ + eta = np.random.standard_normal(n_samples) + return np.exp(eta * p_std + p_mean) + + +def compute_loss_weight(sigma: np.ndarray, sigma_data: float = 0.5) -> np.ndarray: + """Compute the EDM loss weighting lambda(sigma). + + lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2 + """ + return (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2 + + +def plot_noise_vs_tokens( + p_mean: float, + p_std: float, + token_values: np.ndarray, + sigma_data: float = 0.5, + n_samples: int = 200_000, + output_path: str | Path | None = None, +) -> plt.Figure: + """Plot noise distribution compared with the encoded token value distribution. + + Produces a 2x2 figure: + - Panel 1: Encoded token value distribution (histogram + mean/std lines) + - Panel 2: |token| distribution overlaid with the sigma distribution + - Panel 3: log(sigma) distribution vs log(|token|) distribution + - Panel 4: Noise-to-signal ratio: sigma / token_std + + Args: + p_mean: p_mean hyperparameter from config. + p_std: p_std hyperparameter from config. + token_values: Flattened numpy array of encoded token values (from encoder output). + sigma_data: sigma_data hyperparameter. + n_samples: Number of sigma samples for the noise distribution. + output_path: If set, save figure to this path. + + Returns: + The matplotlib Figure. + """ + fig, axes = plt.subplots(2, 2, figsize=(14, 10)) + + token_std = float(np.std(token_values)) + token_mean = float(np.mean(token_values)) + token_abs_mean = float(np.mean(np.abs(token_values))) + sigma = sample_sigma(p_mean, p_std, n_samples) + label_noise = f"sigma (p_mean={p_mean}, p_std={p_std})" + + # --- Panel 1: Token value distribution --- + ax = axes[0, 0] + ax.hist( + token_values, bins=300, density=True, alpha=0.7, color="steelblue", label="Token values" + ) + ax.axvline(token_mean, color="red", ls="--", lw=1.5, label=f"mean={token_mean:.3f}") + ax.axvline(token_mean + token_std, color="orange", ls="--", lw=1, label=f"std={token_std:.3f}") + ax.axvline(token_mean - token_std, color="orange", ls="--", lw=1) + ax.set_xlabel("Token value") + ax.set_ylabel("Density") + ax.set_title(f"Encoded Token Distribution (n={len(token_values):,})") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # --- Panel 2: |token| distribution vs sigma distribution --- + ax = axes[0, 1] + abs_tokens = np.abs(token_values) + ax.hist( + abs_tokens, + bins=300, + density=True, + alpha=0.5, + color="steelblue", + label=f"|tokens| (mean={token_abs_mean:.3f})", + ) + sigma_clipped = sigma[sigma < np.percentile(abs_tokens, 99.5)] + ax.hist(sigma_clipped, bins=200, density=True, alpha=0.4, color="coral", label=label_noise) + ax.set_xlabel("Magnitude") + ax.set_ylabel("Density") + ax.set_title("|Token values| vs sigma magnitude") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # --- Panel 3: log-scale comparison --- + ax = axes[1, 0] + ax.hist( + np.log10(np.abs(token_values) + 1e-12), + bins=200, + density=True, + alpha=0.5, + color="steelblue", + label="log10(|tokens|)", + ) + ax.hist(np.log10(sigma), bins=200, density=True, alpha=0.4, color="coral", label=label_noise) + ax.axvline( + np.log10(sigma_data), color="k", ls="--", lw=1.5, label=f"sigma_data={sigma_data}" + ) + ax.set_xlabel("log10 scale") + ax.set_ylabel("Density") + ax.set_title("log10(|tokens|) vs log10(sigma)") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # --- Panel 4: Noise magnitude relative to token std --- + ax = axes[1, 1] + ratio = sigma / (token_std + 1e-12) + ax.hist(np.log10(ratio), bins=200, density=True, alpha=0.6, color="coral", label=label_noise) + ax.axvline(0, color="k", ls="--", lw=1.5, label="sigma = token_std") + ax.set_xlabel("log10(sigma / token_std)") + ax.set_ylabel("Density") + ax.set_title(f"Noise / Token scale ratio (token_std={token_std:.3f})") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + fig.suptitle( + f"Diffusion Noise vs Encoded Tokens | p_mean={p_mean}, p_std={p_std}," + f" sigma_data={sigma_data}", + fontsize=13, + ) + fig.tight_layout() + + if output_path: + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(output_path), dpi=150, bbox_inches="tight") + logger.info(f"Saved diffusion noise vs tokens plot to {output_path}") + + plt.close(fig) + return fig From 1a04f3354060005f3c931d7d991c70505972e84f Mon Sep 17 00:00:00 2001 From: Jubeku Date: Wed, 11 Mar 2026 13:42:34 +0100 Subject: [PATCH 233/344] plot noise distribution and decoded noised tokens --- config/config_diffusion.yml | 4 +- src/weathergen/model/diffusion.py | 3 + src/weathergen/model/model.py | 35 ++++++++- src/weathergen/utils/plot_diffusion_noise.py | 4 +- src/weathergen/utils/validation_io.py | 81 ++++++++++++++++++++ 5 files changed, 121 insertions(+), 6 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 3a3a3bced..550fb512d 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -72,7 +72,7 @@ sigma_min: 0.002 sigma_max: 50000 sigma_data: 0.5 rho: 7 -p_mean: -1.2 +p_mean: 1.2 p_std: 1.2 @@ -207,7 +207,7 @@ training_config: # masking strategy: "random", "healpix", "forecast" masking_strategy: "forecast", masking_strategy_config: {diffusion_rn: True}, - num_samples: 3 + num_samples: 1 } } diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 66083738d..4269e32fe 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -58,6 +58,7 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast self.p_mean = self.cf.p_mean self.p_std = self.cf.p_std self.cur_token = None # TODO: re move after single sample experiments + self._noised_tokens: torch.Tensor | None = None def forward( self, @@ -104,6 +105,8 @@ def forward( sigma = (eta * self.p_std + self.p_mean).exp() n = torch.randn_like(y) * sigma + self._noised_tokens = y + n + return self.denoise(x=y + n, c=c, sigma=sigma, fstep=fstep) def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int) -> torch.Tensor: diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index ea7970259..3e739f67f 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -53,16 +53,29 @@ class ModelOutput: physical: list[dict[StreamName, torch.Tensor]] latent: list[dict[str, torch.Tensor | LatentState]] + noised_physical: list[dict[StreamName, torch.Tensor]] def __init__(self, len_output: int) -> None: self.physical = [{} for _ in range(len_output)] self.latent = [{} for _ in range(len_output)] + self.noised_physical = [{} for _ in range(len_output)] def add_physical_prediction( self, fstep: int, stream_name: StreamName, pred: torch.Tensor ) -> None: self.physical[fstep][stream_name] = pred + def add_noised_physical_prediction( + self, fstep: int, stream_name: StreamName, pred: torch.Tensor + ) -> None: + self.noised_physical[fstep][stream_name] = pred + + def get_noised_physical_prediction(self, fstep: int, stream_name: StreamName | None = None): + pred = self.noised_physical[fstep] + if stream_name is not None: + pred = pred.get(stream_name, None) + return pred + def add_latent_prediction(self, fstep: int, latent_name: str, pred: torch.Tensor) -> None: self.latent[fstep][latent_name] = pred @@ -670,6 +683,22 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # decoder predictions output = self.predict_decoders(model_params, step, tokens, batch, output) + + # decode noised tokens for visualization (diffusion models only, eval mode) + if ( + not self.training + and isinstance(self.forecast_engine, DiffusionForecastEngine) + and self.forecast_engine._noised_tokens is not None + ): + output = self.predict_decoders( + model_params, + step, + self.forecast_engine._noised_tokens, + batch, + output, + noised=True, + ) + # latent predictions (raw and with SSL heads) output = self.predict_latent(model_params, step, tokens, batch, output) @@ -705,6 +734,7 @@ def predict_decoders( tokens: torch.Tensor, batch: ModelBatch, output: ModelOutput, + noised: bool = False, ) -> ModelOutput: """ Compute decoder-based predictions @@ -806,6 +836,9 @@ def predict_decoders( # recover batch dimension (ragged, so as list) pred = torch.split(pred, t_coords_lens, dim=1) # breakpoint() - output.add_physical_prediction(step, stream_name, pred) + if noised: + output.add_noised_physical_prediction(step, stream_name, pred) + else: + output.add_physical_prediction(step, stream_name, pred) return output diff --git a/src/weathergen/utils/plot_diffusion_noise.py b/src/weathergen/utils/plot_diffusion_noise.py index cfa3c19bf..a9fbe59be 100644 --- a/src/weathergen/utils/plot_diffusion_noise.py +++ b/src/weathergen/utils/plot_diffusion_noise.py @@ -134,9 +134,7 @@ def plot_noise_vs_tokens( label="log10(|tokens|)", ) ax.hist(np.log10(sigma), bins=200, density=True, alpha=0.4, color="coral", label=label_noise) - ax.axvline( - np.log10(sigma_data), color="k", ls="--", lw=1.5, label=f"sigma_data={sigma_data}" - ) + ax.axvline(np.log10(sigma_data), color="k", ls="--", lw=1.5, label=f"sigma_data={sigma_data}") ax.set_xlabel("log10 scale") ax.set_ylabel("Density") ax.set_title("log10(|tokens|) vs log10(sigma)") diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 71f2dccf7..bea091c60 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -85,6 +85,7 @@ def write_output( # collect all target / prediction-related information fp32 = torch.float32 preds_all, targets_all, targets_coords_all, targets_times_all = [], [], [], [] + noised_preds_all = [] # decoded noised tokens (diffusion models only) timestep_idxs = [0] if len(batch.get_output_idxs()) == 0 else batch.get_output_idxs() forecast_offset = timestep_idxs[0] @@ -96,6 +97,7 @@ def write_output( targets_all += [[]] targets_coords_all += [[]] targets_times_all += [[]] + noised_preds_all += [[]] targets_lens += [[]] # noise_levels = [] # TODO: REMOVE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. for stream_idx, stream_info in enumerate(cf.streams): @@ -152,6 +154,19 @@ def write_output( targets_coords_all[-1] += [np.concatenate(t_coords_s)] targets_times_all[-1] += [np.concatenate(t_times_s)] + # collect decoded noised tokens (diffusion models only) + noised_preds = model_output.get_noised_physical_prediction(t_idx, sname) + if noised_preds is not None: + noised_s = [] + for i_batch, npred in enumerate(noised_preds): + idxs_inv = target_aux_out.physical[t_idx][sname]["idxs_inv"][i_batch] + if idxs_inv is not None: + npred = npred[:, idxs_inv] + noised_s += [dn_data(sname, npred).detach().to(fp32).cpu().numpy()] + noised_preds_all[-1] += [np.concatenate(noised_s, axis=1)] + else: + noised_preds_all[-1] += [np.array([])] + if len(preds_all) == 0 or np.array([p.shape[1] for pp in preds_all for p in pp]).sum() == 0: _logger.warning("Writing no data since predictions are empty.") return @@ -303,4 +318,70 @@ def write_output( dst = channel_dir / f"{epoch_tag}.{plotter.image_format}" if src != dst and src.exists(): src.replace(dst) + + # Plot decoded noised tokens (diffusion models only) + has_noised = any( + noised_preds_all[t_idx][s_idx].size > 0 + for s_idx in range(len(cf.streams)) + if noised_preds_all[t_idx][s_idx].ndim >= 2 + ) + if has_noised: + for stream_idx, stream_info in enumerate(cf.streams): + stream_name = stream_info["name"] + noised_stream = noised_preds_all[t_idx][stream_idx] + coords_stream = targets_coords_all[t_idx][stream_idx] + + if noised_stream.size == 0 or coords_stream.size == 0: + continue + + if noised_stream.ndim == 3: + noised_stream = noised_stream[0] + elif noised_stream.ndim != 2: + continue + + lat = coords_stream[:, 0] + lon = coords_stream[:, 1] + channels = _resolve_channel_names(stream_info, target_channels[stream_idx]) + + da_noised = xr.DataArray( + noised_stream, + dims=("ipoint", "channel"), + coords={ + "ipoint": np.arange(noised_stream.shape[0]), + "channel": channels, + "lat": ("ipoint", lat), + "lon": ("ipoint", lon), + }, + ) + + plotter.stream = stream_name + plotter.run_id = config.get_run_id_from_config(cf) + plotter.fstep = forecast_offset + + selected_channels = [ + ch for ch in channels if _normalize_channel_name(ch) in headline_channels + ] + if not selected_channels: + continue + + for varname in selected_channels: + data = da_noised.sel(channel=varname).dropna(dim="ipoint") + channel_dir = base_plot_dir / varname / "noised" + channel_dir.mkdir(parents=True, exist_ok=True) + epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}_noised" + title = f"{stream_name} - {varname} (fstep {forecast_offset}) [noised input]" + + plot_name = plotter.scatter_plot( + data, + channel_dir, + varname=varname, + regionname="global", + tag=epoch_tag, + title=title, + ) + src = channel_dir / f"{plot_name}.{plotter.image_format}" + dst = channel_dir / f"{epoch_tag}.{plotter.image_format}" + if src != dst and src.exists(): + src.replace(dst) + i += 1 From 3f731ad7b1b90918a43f9f7e337af9b4fe17a113 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Thu, 12 Mar 2026 11:34:10 +0100 Subject: [PATCH 234/344] fix noise level in validation to p_mean --- config/config_diffusion.yml | 2 +- src/weathergen/model/diffusion.py | 8 ++++++-- src/weathergen/model/model.py | 10 +++++----- src/weathergen/train/target_and_aux_diffusion.py | 11 ++++++++--- 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 550fb512d..003645780 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -72,7 +72,7 @@ sigma_min: 0.002 sigma_max: 50000 sigma_data: 0.5 rho: 7 -p_mean: 1.2 +p_mean: -1.2 p_std: 1.2 diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 4269e32fe..df2e21f4a 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -98,10 +98,14 @@ def forward( c = 1 # TODO: add correct preconditioning (e.g., sample/s in previous time step) y = tokens - eta = torch.tensor([meta_info["ERA5"].params["noise_level_rn"]], device=tokens.device) - # Compute sigma (noise level) from eta and create noise tensor + if self.training: + eta = torch.tensor([meta_info["ERA5"].params["noise_level_rn"]], device=tokens.device) + else: + # During validation, fix sigma to exp(p_mean) by setting eta to the mean of N(0,1) + eta = torch.zeros(1, device=tokens.device) + # Compute sigma (noise level) from eta and create noise tensor sigma = (eta * self.p_std + self.p_mean).exp() n = torch.randn_like(y) * sigma diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 3e739f67f..024aad1d8 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -661,10 +661,10 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # Normalize tokens # TODO: REMOVE THIS LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. - # t_mean = tokens.mean() - # t_std = tokens.std() - # tokens = (tokens - t_mean) / (t_std + 1e-6) - # tokens = torch.clamp(tokens, -100.0, 100.0) + t_mean = tokens.mean() + t_std = tokens.std() + tokens = (tokens - t_mean) / (t_std + 1e-6) * cf.p_mean + tokens = torch.clamp(tokens, -100.0, 100.0) # roll-out in latent space, iterate and generate output over requested output steps for step in batch.get_output_idxs(): @@ -679,7 +679,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # Un-normalize tokens # TODO: REMOVE THIS AS ABOVE. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. - # tokens = tokens * (t_std + 1e-6) + t_mean + tokens = tokens * (t_std + 1e-6) / cf.p_mean + t_mean # decoder predictions output = self.predict_decoders(model_params, step, tokens, batch, output) diff --git a/src/weathergen/train/target_and_aux_diffusion.py b/src/weathergen/train/target_and_aux_diffusion.py index 7cd2f9ea3..3ec99d140 100644 --- a/src/weathergen/train/target_and_aux_diffusion.py +++ b/src/weathergen/train/target_and_aux_diffusion.py @@ -55,9 +55,14 @@ def compute( *args, **kwargs, ) -> tuple[Any, Any]: - noise_level_rn = ( - batch.samples[0].meta_info["ERA5"].params["noise_level_rn"] - ) # TODO: adjust for multiple streams + # During validation (model in eval mode), fix noise level to 0.0 + # so that sigma = exp(p_mean), consistent with DiffusionForecastEngine + if model.training: + noise_level_rn = ( + batch.samples[0].meta_info["ERA5"].params["noise_level_rn"] + ) # TODO: adjust for multiple streams + else: + noise_level_rn = 0.0 # TODO: check if there are scenarios where the encoder needs to be set to eval with torch.no_grad(): From 100b5c20648d5391d18de14b43b3e2812cc6d18b Mon Sep 17 00:00:00 2001 From: Jubeku Date: Thu, 12 Mar 2026 12:04:34 +0100 Subject: [PATCH 235/344] rm noise and token distribution plotting --- src/weathergen/model/model.py | 4 +- src/weathergen/train/trainer.py | 48 ------ src/weathergen/utils/plot_diffusion_noise.py | 168 ------------------- src/weathergen/utils/validation_io.py | 2 +- 4 files changed, 3 insertions(+), 219 deletions(-) delete mode 100644 src/weathergen/utils/plot_diffusion_noise.py diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 024aad1d8..49cbf9bf6 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -663,7 +663,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # TODO: REMOVE THIS LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. t_mean = tokens.mean() t_std = tokens.std() - tokens = (tokens - t_mean) / (t_std + 1e-6) * cf.p_mean + tokens = (tokens - t_mean) / (t_std + 1e-6) * self.cf.sigma_data tokens = torch.clamp(tokens, -100.0, 100.0) # roll-out in latent space, iterate and generate output over requested output steps @@ -679,7 +679,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # Un-normalize tokens # TODO: REMOVE THIS AS ABOVE. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. - tokens = tokens * (t_std + 1e-6) / cf.p_mean + t_mean + tokens = tokens * (t_std + 1e-6) / self.cf.sigma_data + t_mean # decoder predictions output = self.predict_decoders(model_params, step, tokens, batch, output) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 0b47d9910..431b9d1e0 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -45,7 +45,6 @@ get_target_idxs_from_cfg, ) from weathergen.utils.distributed import is_root -from weathergen.utils.plot_diffusion_noise import plot_noise_vs_tokens from weathergen.utils.train_logger import TrainLogger, prepare_losses_for_logging from weathergen.utils.utils import get_dtype from weathergen.utils.validation_io import write_output @@ -554,10 +553,6 @@ def validate(self, mini_epoch, mode_cfg, batch_size): num_samples_write = mode_cfg.get("output", {}).get("num_samples", 0) * batch_size - # Collect encoded tokens for diffusion noise diagnostic plot - is_diffusion = cf.get("fe_diffusion_model", False) - collected_tokens = [] - with torch.no_grad(): # print progress bar but only in interactive mode, i.e. when without ddp with tqdm.tqdm( @@ -603,14 +598,6 @@ def validate(self, mini_epoch, mode_cfg, batch_size): metadata=extract_batch_metadata(batch), ) - # Collect encoded tokens (z_pre_norm) for diffusion diagnostics - if is_diffusion and preds.latent and preds.latent[0]: - latent_state = preds.latent[0].get("latent_state") - if latent_state is not None and latent_state.z_pre_norm is not None: - collected_tokens.append( - latent_state.z_pre_norm.detach().float().cpu().numpy().flatten() - ) - # log output if bidx < num_samples_write: # denormalization function for data @@ -640,44 +627,9 @@ def validate(self, mini_epoch, mode_cfg, batch_size): self._log_terminal(0, mini_epoch, VAL) self._log(VAL) - # Plot diffusion noise vs encoded token distribution (root rank only) - if is_diffusion and collected_tokens and is_root(): - self._plot_diffusion_noise_vs_tokens(mini_epoch, collected_tokens) - # avoid that there is a systematic bias in the validation subset self.dataset_val.advance() - def _plot_diffusion_noise_vs_tokens( - self, mini_epoch: int, collected_tokens: list[np.ndarray] - ) -> None: - """Generate diffusion noise vs encoded token distribution diagnostic plot. - - Called at the end of validation when fe_diffusion_model is True. - - Args: - mini_epoch: Current mini epoch (used in filename). - collected_tokens: List of flattened token arrays from validation batches. - """ - token_values = np.concatenate(collected_tokens) - p_mean = self.cf.get("p_mean", -1.2) - p_std = self.cf.get("p_std", 1.2) - sigma_data = self.cf.get("sigma_data", 0.5) - - output_dir = config.get_path_run(self.cf) - output_path = output_dir / f"diffusion_noise_vs_tokens_epoch{mini_epoch:05d}.png" - - logger.info( - f"Plotting diffusion noise vs tokens: p_mean={p_mean}, p_std={p_std}, " - f"token_mean={token_values.mean():.4f}, token_std={token_values.std():.4f}" - ) - plot_noise_vs_tokens( - p_mean=p_mean, - p_std=p_std, - token_values=token_values, - sigma_data=sigma_data, - output_path=output_path, - ) - def _get_full_model_state_dict(self): maybe_sharded_sd = ( self.model.state_dict() if self.ema_model is None else self.ema_model.state_dict() diff --git a/src/weathergen/utils/plot_diffusion_noise.py b/src/weathergen/utils/plot_diffusion_noise.py deleted file mode 100644 index a9fbe59be..000000000 --- a/src/weathergen/utils/plot_diffusion_noise.py +++ /dev/null @@ -1,168 +0,0 @@ -# (C) Copyright 2025 WeatherGenerator contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -""" -Plotting utilities for comparing the diffusion noise distribution (sigma) -with the distribution of encoded tokens from the model encoder. - -The noise level sigma is derived from the EDM parameterization (Karras et al.): - eta ~ N(0, 1) - sigma = exp(eta * p_std + p_mean) - -So log(sigma) ~ N(p_mean, p_std^2), i.e. sigma follows a log-normal distribution. - -These functions are called from the Trainer during validation to produce diagnostic plots. -""" - -import logging -from pathlib import Path - -import matplotlib - -matplotlib.use("Agg") # non-interactive backend for HPC / headless environments -import matplotlib.pyplot as plt -import numpy as np - -logger = logging.getLogger(__name__) - - -def sample_sigma(p_mean: float, p_std: float, n_samples: int = 100_000) -> np.ndarray: - """Sample sigma values from the diffusion noise distribution. - - Args: - p_mean: Mean of log(sigma) distribution. - p_std: Std of log(sigma) distribution. - n_samples: Number of samples to draw. - - Returns: - Array of sigma values. - """ - eta = np.random.standard_normal(n_samples) - return np.exp(eta * p_std + p_mean) - - -def compute_loss_weight(sigma: np.ndarray, sigma_data: float = 0.5) -> np.ndarray: - """Compute the EDM loss weighting lambda(sigma). - - lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2 - """ - return (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2 - - -def plot_noise_vs_tokens( - p_mean: float, - p_std: float, - token_values: np.ndarray, - sigma_data: float = 0.5, - n_samples: int = 200_000, - output_path: str | Path | None = None, -) -> plt.Figure: - """Plot noise distribution compared with the encoded token value distribution. - - Produces a 2x2 figure: - - Panel 1: Encoded token value distribution (histogram + mean/std lines) - - Panel 2: |token| distribution overlaid with the sigma distribution - - Panel 3: log(sigma) distribution vs log(|token|) distribution - - Panel 4: Noise-to-signal ratio: sigma / token_std - - Args: - p_mean: p_mean hyperparameter from config. - p_std: p_std hyperparameter from config. - token_values: Flattened numpy array of encoded token values (from encoder output). - sigma_data: sigma_data hyperparameter. - n_samples: Number of sigma samples for the noise distribution. - output_path: If set, save figure to this path. - - Returns: - The matplotlib Figure. - """ - fig, axes = plt.subplots(2, 2, figsize=(14, 10)) - - token_std = float(np.std(token_values)) - token_mean = float(np.mean(token_values)) - token_abs_mean = float(np.mean(np.abs(token_values))) - sigma = sample_sigma(p_mean, p_std, n_samples) - label_noise = f"sigma (p_mean={p_mean}, p_std={p_std})" - - # --- Panel 1: Token value distribution --- - ax = axes[0, 0] - ax.hist( - token_values, bins=300, density=True, alpha=0.7, color="steelblue", label="Token values" - ) - ax.axvline(token_mean, color="red", ls="--", lw=1.5, label=f"mean={token_mean:.3f}") - ax.axvline(token_mean + token_std, color="orange", ls="--", lw=1, label=f"std={token_std:.3f}") - ax.axvline(token_mean - token_std, color="orange", ls="--", lw=1) - ax.set_xlabel("Token value") - ax.set_ylabel("Density") - ax.set_title(f"Encoded Token Distribution (n={len(token_values):,})") - ax.legend(fontsize=8) - ax.grid(True, alpha=0.3) - - # --- Panel 2: |token| distribution vs sigma distribution --- - ax = axes[0, 1] - abs_tokens = np.abs(token_values) - ax.hist( - abs_tokens, - bins=300, - density=True, - alpha=0.5, - color="steelblue", - label=f"|tokens| (mean={token_abs_mean:.3f})", - ) - sigma_clipped = sigma[sigma < np.percentile(abs_tokens, 99.5)] - ax.hist(sigma_clipped, bins=200, density=True, alpha=0.4, color="coral", label=label_noise) - ax.set_xlabel("Magnitude") - ax.set_ylabel("Density") - ax.set_title("|Token values| vs sigma magnitude") - ax.legend(fontsize=8) - ax.grid(True, alpha=0.3) - - # --- Panel 3: log-scale comparison --- - ax = axes[1, 0] - ax.hist( - np.log10(np.abs(token_values) + 1e-12), - bins=200, - density=True, - alpha=0.5, - color="steelblue", - label="log10(|tokens|)", - ) - ax.hist(np.log10(sigma), bins=200, density=True, alpha=0.4, color="coral", label=label_noise) - ax.axvline(np.log10(sigma_data), color="k", ls="--", lw=1.5, label=f"sigma_data={sigma_data}") - ax.set_xlabel("log10 scale") - ax.set_ylabel("Density") - ax.set_title("log10(|tokens|) vs log10(sigma)") - ax.legend(fontsize=8) - ax.grid(True, alpha=0.3) - - # --- Panel 4: Noise magnitude relative to token std --- - ax = axes[1, 1] - ratio = sigma / (token_std + 1e-12) - ax.hist(np.log10(ratio), bins=200, density=True, alpha=0.6, color="coral", label=label_noise) - ax.axvline(0, color="k", ls="--", lw=1.5, label="sigma = token_std") - ax.set_xlabel("log10(sigma / token_std)") - ax.set_ylabel("Density") - ax.set_title(f"Noise / Token scale ratio (token_std={token_std:.3f})") - ax.legend(fontsize=8) - ax.grid(True, alpha=0.3) - - fig.suptitle( - f"Diffusion Noise vs Encoded Tokens | p_mean={p_mean}, p_std={p_std}," - f" sigma_data={sigma_data}", - fontsize=13, - ) - fig.tight_layout() - - if output_path: - Path(output_path).parent.mkdir(parents=True, exist_ok=True) - fig.savefig(str(output_path), dpi=150, bbox_inches="tight") - logger.info(f"Saved diffusion noise vs tokens plot to {output_path}") - - plt.close(fig) - return fig diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index bea091c60..67c766e23 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -244,7 +244,7 @@ def write_output( base_plot_dir.mkdir(parents=True, exist_ok=True) plotter = Plotter({"image_format": "png", "dpi_val": 150}, base_plot_dir) # headline_channels = {"2t", "z500", "q850", "10u", "10v"} - headline_channels = {"2t"} + headline_channels = {"2t", "q850"} t_idx = 0 for stream_idx, stream_info in enumerate(cf.streams): From 89d5c855f5773bdea5d0e28bedb10e55a890a381 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Thu, 12 Mar 2026 13:55:57 +0100 Subject: [PATCH 236/344] add multple fixed val noise levels --- config/config_diffusion.yml | 5 + src/weathergen/model/diffusion.py | 6 +- .../loss_module_latent_diffusion.py | 3 +- .../train/target_and_aux_diffusion.py | 7 +- src/weathergen/train/trainer.py | 193 +++++++++++------- 5 files changed, 132 insertions(+), 82 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 003645780..170ca1992 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -221,6 +221,11 @@ training_config: # validation config; full validation config is merge of training and validation config validation_config: + # Noise levels (eta values in standard normal space) at which to evaluate the + # diffusion model during validation. sigma = exp(eta * p_std + p_mean). + # Each value produces a separate validation pass with independently logged metrics. + validation_noise_levels: [0.0, -1.0, 1.0] + samples_per_mini_epoch: 16 shuffle: False diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index df2e21f4a..9d0f6fba0 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -59,6 +59,7 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast self.p_std = self.cf.p_std self.cur_token = None # TODO: re move after single sample experiments self._noised_tokens: torch.Tensor | None = None + self._fixed_noise_level: float | None = None def forward( self, @@ -102,8 +103,9 @@ def forward( if self.training: eta = torch.tensor([meta_info["ERA5"].params["noise_level_rn"]], device=tokens.device) else: - # During validation, fix sigma to exp(p_mean) by setting eta to the mean of N(0,1) - eta = torch.zeros(1, device=tokens.device) + # During validation, use fixed noise level (default: 0.0 = mean of noise distribution) + noise_level = self._fixed_noise_level if self._fixed_noise_level is not None else 0.0 + eta = torch.tensor([noise_level], device=tokens.device) # Compute sigma (noise level) from eta and create noise tensor sigma = (eta * self.p_std + self.p_mean).exp() diff --git a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py index ed076aee2..7f8ab0357 100644 --- a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py +++ b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py @@ -100,7 +100,8 @@ def compute_loss(self, preds: dict, targets: dict, **kwargs) -> LossValues: ) fsteps = len(target_tokens_all) - noise_weight = self._get_noise_weight(eta) + # During validation, use unweighted loss (no noise-level scaling) + noise_weight = 1.0 if self.stage == "val" else self._get_noise_weight(eta) fstep_loss_weights = self._get_fstep_weights(fsteps) loss_fsteps = torch.tensor(0.0, device=self.device, requires_grad=True) diff --git a/src/weathergen/train/target_and_aux_diffusion.py b/src/weathergen/train/target_and_aux_diffusion.py index 3ec99d140..a857bce0b 100644 --- a/src/weathergen/train/target_and_aux_diffusion.py +++ b/src/weathergen/train/target_and_aux_diffusion.py @@ -20,6 +20,7 @@ def __init__(self, encoder, is_model_sharded=True): apply_fct_to_blocks(self.encoder, ".*", set_to_eval) self.is_model_sharded = is_model_sharded + self._fixed_noise_level: float | None = None # Build a name → param map once self.src_params = dict(self.encoder.named_parameters()) @@ -55,14 +56,14 @@ def compute( *args, **kwargs, ) -> tuple[Any, Any]: - # During validation (model in eval mode), fix noise level to 0.0 - # so that sigma = exp(p_mean), consistent with DiffusionForecastEngine + # During validation (model in eval mode), use fixed noise level + # so that sigma = exp(eta * p_std + p_mean) is deterministic if model.training: noise_level_rn = ( batch.samples[0].meta_info["ERA5"].params["noise_level_rn"] ) # TODO: adjust for multiple streams else: - noise_level_rn = 0.0 + noise_level_rn = self._fixed_noise_level if self._fixed_noise_level is not None else 0.0 # TODO: check if there are scenarios where the encoder needs to be set to eval with torch.no_grad(): diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 431b9d1e0..45b3bbe54 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -543,93 +543,131 @@ def train(self, mini_epoch): def validate(self, mini_epoch, mode_cfg, batch_size): """ - Perform validation / test computation as specified by mode_cfg + Perform validation / test computation as specified by mode_cfg. + + For diffusion models, runs separate validation passes for each noise level + specified in ``validation_noise_levels`` (defaults to ``[0.0]``). + Losses are logged with a per-noise-level suffix so they can be compared. """ cf = self.cf self.model.eval() - dataset_val_iter = iter(self.data_loader_validation) - - num_samples_write = mode_cfg.get("output", {}).get("num_samples", 0) * batch_size - - with torch.no_grad(): - # print progress bar but only in interactive mode, i.e. when without ddp - with tqdm.tqdm( - total=len(self.data_loader_validation), disable=self.cf.with_ddp - ) as pbar: - for bidx, batch in enumerate(dataset_val_iter): - if cf.data_loading.get("memory_pinning", False): - # pin memory for faster CPU-GPU transfer - batch = batch.pin_memory() - - batch.to_device(self.device) - - # evaluate model - with torch.autocast( - device_type=f"cuda:{cf.local_rank}", - dtype=self.mixed_precision_dtype, - enabled=cf.with_mixed_precision, - ): - if self.ema_model is None: - preds = self.model( - self.model_params, - batch.get_source_samples(), - ) - else: - preds = self.ema_model.forward_eval( - self.model_params, - batch.get_source_samples(), - ) + is_diffusion = cf.get("fe_diffusion_model", False) + noise_levels = list(mode_cfg.get("validation_noise_levels", [0.0])) + if not is_diffusion: + noise_levels = [0.0] - targets_and_auxs = {} - for loss_name, target_aux in self.target_and_aux_calculators_val.items(): - target_idxs = get_target_idxs_from_cfg(mode_cfg, loss_name) - targets_and_auxs[loss_name] = target_aux.compute( - self.cf.general.istep, - batch.get_target_samples(target_idxs), - self.model_params, - self.model, - ) + for noise_idx, noise_level in enumerate(noise_levels): + if is_diffusion: + self._set_validation_noise_level(noise_level) - _ = self.loss_calculator_val.compute_loss( - preds=preds, - targets_and_aux=targets_and_auxs, - metadata=extract_batch_metadata(batch), - ) + stage_suffix = f"_eta{noise_level:.2f}" if len(noise_levels) > 1 else "" + write_samples = noise_idx == 0 - # log output - if bidx < num_samples_write: - # denormalization function for data - denormalize_data_fct = ( - (lambda x0, x1: x1) - if mode_cfg.get("output", {}).get("normalized_samples", False) - else self.dataset_val.denormalize_target_channels - ) - # write output - write_output( - self.cf, - mode_cfg, - batch_size, - mini_epoch, - bidx, - denormalize_data_fct, - batch, - preds, - targets_and_auxs, + dataset_val_iter = iter(self.data_loader_validation) + num_samples_write = ( + mode_cfg.get("output", {}).get("num_samples", 0) * batch_size + if write_samples + else 0 + ) + + with torch.no_grad(): + # print progress bar but only in interactive mode, i.e. when without ddp + with tqdm.tqdm( + total=len(self.data_loader_validation), disable=self.cf.with_ddp + ) as pbar: + for bidx, batch in enumerate(dataset_val_iter): + if cf.data_loading.get("memory_pinning", False): + # pin memory for faster CPU-GPU transfer + batch = batch.pin_memory() + + batch.to_device(self.device) + + # evaluate model + with torch.autocast( + device_type=f"cuda:{cf.local_rank}", + dtype=self.mixed_precision_dtype, + enabled=cf.with_mixed_precision, + ): + if self.ema_model is None: + preds = self.model( + self.model_params, + batch.get_source_samples(), + ) + else: + preds = self.ema_model.forward_eval( + self.model_params, + batch.get_source_samples(), + ) + + targets_and_auxs = {} + for loss_name, target_aux in self.target_and_aux_calculators_val.items(): + target_idxs = get_target_idxs_from_cfg(mode_cfg, loss_name) + targets_and_auxs[loss_name] = target_aux.compute( + self.cf.general.istep, + batch.get_target_samples(target_idxs), + self.model_params, + self.model, + ) + + _ = self.loss_calculator_val.compute_loss( + preds=preds, + targets_and_aux=targets_and_auxs, + metadata=extract_batch_metadata(batch), ) - pbar.update(batch_size) + # log output + if bidx < num_samples_write: + # denormalization function for data + denormalize_data_fct = ( + (lambda x0, x1: x1) + if mode_cfg.get("output", {}).get("normalized_samples", False) + else self.dataset_val.denormalize_target_channels + ) + # write output + write_output( + self.cf, + mode_cfg, + batch_size, + mini_epoch, + bidx, + denormalize_data_fct, + batch, + preds, + targets_and_auxs, + ) + + pbar.update(batch_size) - if (bidx * batch_size) > mode_cfg.samples_per_mini_epoch: - break + if (bidx * batch_size) > mode_cfg.samples_per_mini_epoch: + break - self._log_terminal(0, mini_epoch, VAL) - self._log(VAL) + self._log_terminal(0, mini_epoch, VAL, stage_suffix=stage_suffix) + self._log(VAL, stage_suffix=stage_suffix) + + # reset fixed noise level + if is_diffusion: + self._set_validation_noise_level(None) # avoid that there is a systematic bias in the validation subset self.dataset_val.advance() + def _set_validation_noise_level(self, noise_level: float | None): + """Set fixed noise level on diffusion components for validation. + + Args: + noise_level: The eta value (standard normal space) to fix for validation. + sigma = exp(eta * p_std + p_mean). None resets to default (0.0). + """ + if hasattr(self.model, "forecast_engine") and hasattr( + self.model.forecast_engine, "_fixed_noise_level" + ): + self.model.forecast_engine._fixed_noise_level = noise_level + for calc in self.target_and_aux_calculators_val.values(): + if hasattr(calc, "_fixed_noise_level"): + calc._fixed_noise_level = noise_level + def _get_full_model_state_dict(self): maybe_sharded_sd = ( self.model.state_dict() if self.ema_model is None else self.ema_model.state_dict() @@ -704,13 +742,15 @@ def save_model(self, mini_epoch: int, name=None): # save config config.save(self.cf, mini_epoch) - def _log(self, stage: Stage): + def _log(self, stage: Stage, stage_suffix: str = ""): """ Logs training or validation metrics. Args: stage: Stage Is it's VAL, logs are treated as validation logs. If TRAIN, logs are treated as training logs + stage_suffix: Optional suffix appended to the logged stage name + (e.g. "_eta0.00" for per-noise-level validation). Notes: - This method only executes logging on the main process (rank 0). @@ -724,15 +764,16 @@ def _log(self, stage: Stage): ) samples = self.cf.general.istep * self.get_batch_size_total(self.batch_size_per_gpu) + log_stage = f"{stage}{stage_suffix}" if stage_suffix else stage if is_root(): # plain logger if stage == VAL: - self.train_logger.add_logs(stage, samples, losses_all, stddev_all) + self.train_logger.add_logs(log_stage, samples, losses_all, stddev_all) elif self.cf.general.istep >= 0: self.train_logger.add_logs( - stage, + log_stage, samples, losses_all, stddev_all, @@ -764,7 +805,7 @@ def _log_instant_grad_norms(self, stage: Stage): if is_root(): self.train_logger.log_metrics(stage, grad_norms) - def _log_terminal(self, bidx: int, mini_epoch: int, stage: Stage): + def _log_terminal(self, bidx: int, mini_epoch: int, stage: Stage, stage_suffix: str = ""): print_freq = self.train_logging.terminal if bidx % print_freq == 0 and bidx > 0 or stage == VAL: # compute from last iteration @@ -778,7 +819,7 @@ def _log_terminal(self, bidx: int, mini_epoch: int, stage: Stage): if is_root(): if stage == VAL: logger.info( - f"""validation ({self.cf.general.run_id}) : {mini_epoch:03d} : + f"""validation{stage_suffix} ({self.cf.general.run_id}) : {mini_epoch:03d} : {np.nanmean(avg_loss)}""" ) From 79dcc902cd657efd34128cf6f839548fc8393f65 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Fri, 13 Mar 2026 10:03:27 +0100 Subject: [PATCH 237/344] enable multiple fixed val noise levels --- config/config_diffusion.yml | 4 +-- src/weathergen/train/trainer.py | 51 ++++++++++++++++++++++----- src/weathergen/utils/validation_io.py | 47 +++++++++++++++++------- 3 files changed, 78 insertions(+), 24 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 170ca1992..eb5b9055e 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -72,7 +72,7 @@ sigma_min: 0.002 sigma_max: 50000 sigma_data: 0.5 rho: 7 -p_mean: -1.2 +p_mean: 1.2 p_std: 1.2 @@ -224,7 +224,7 @@ validation_config: # Noise levels (eta values in standard normal space) at which to evaluate the # diffusion model during validation. sigma = exp(eta * p_std + p_mean). # Each value produces a separate validation pass with independently logged metrics. - validation_noise_levels: [0.0, -1.0, 1.0] + validation_noise_levels: [0.03, 0.3, 3.0] samples_per_mini_epoch: 16 shuffle: False diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 45b3bbe54..a3f02b3fa 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -10,6 +10,7 @@ # nor does it submit to any jurisdiction. import copy import logging +import re import time import numpy as np @@ -558,19 +559,21 @@ def validate(self, mini_epoch, mode_cfg, batch_size): if not is_diffusion: noise_levels = [0.0] + # Accumulate losses across noise levels with suffixed keys so they are + # logged as a single "val" entry (e.g. LossLatentDiff.LossLatentDiff.mse.eta0.03) + all_losses: dict[str, list] = {} + all_stddev: dict[str, list] = {} + for noise_idx, noise_level in enumerate(noise_levels): if is_diffusion: self._set_validation_noise_level(noise_level) - stage_suffix = f"_eta{noise_level:.2f}" if len(noise_levels) > 1 else "" - write_samples = noise_idx == 0 + eta_str = re.sub(r'e[+]?0*(?=\d)', 'e', re.sub(r'e-0*(?=\d)', 'e-', f'{noise_level:.0e}')) + loss_suffix = f".eta{eta_str}" if len(noise_levels) > 1 else "" + stage_suffix = f"_eta{eta_str}" if len(noise_levels) > 1 else "" dataset_val_iter = iter(self.data_loader_validation) - num_samples_write = ( - mode_cfg.get("output", {}).get("num_samples", 0) * batch_size - if write_samples - else 0 - ) + num_samples_write = mode_cfg.get("output", {}).get("num_samples", 0) * batch_size with torch.no_grad(): # print progress bar but only in interactive mode, i.e. when without ddp @@ -625,7 +628,7 @@ def validate(self, mini_epoch, mode_cfg, batch_size): if mode_cfg.get("output", {}).get("normalized_samples", False) else self.dataset_val.denormalize_target_channels ) - # write output + # write output (zarr only for first noise level, plots for all) write_output( self.cf, mode_cfg, @@ -636,6 +639,8 @@ def validate(self, mini_epoch, mode_cfg, batch_size): batch, preds, targets_and_auxs, + noise_level=noise_level if is_diffusion and len(noise_levels) > 1 else None, + write_zarr=(noise_idx == 0), ) pbar.update(batch_size) @@ -643,8 +648,28 @@ def validate(self, mini_epoch, mode_cfg, batch_size): if (bidx * batch_size) > mode_cfg.samples_per_mini_epoch: break + # Terminal logging per noise level for progress visibility self._log_terminal(0, mini_epoch, VAL, stage_suffix=stage_suffix) - self._log(VAL, stage_suffix=stage_suffix) + + # Extract losses for this noise level, suffix keys, and accumulate + loss_calc = self.loss_calculator_val + _, losses_level, stddev_level = prepare_losses_for_logging( + loss_calc.loss_hist, + loss_calc.losses_unweighted_hist, + loss_calc.stddev_unweighted_hist, + ) + for key, value in losses_level.items(): + all_losses[f"{key}{loss_suffix}"] = value + for key, value in stddev_level.items(): + all_stddev[f"{key}{loss_suffix}"] = value + loss_calc.loss_hist = [] + loss_calc.losses_unweighted_hist = [] + loss_calc.stddev_unweighted_hist = [] + + # Log all noise levels as a single "val" entry with suffixed loss keys + samples = self.cf.general.istep * self.get_batch_size_total(self.batch_size_per_gpu) + if is_root(): + self.train_logger.add_logs(VAL, samples, all_losses, all_stddev) # reset fixed noise level if is_diffusion: @@ -660,10 +685,18 @@ def _set_validation_noise_level(self, noise_level: float | None): noise_level: The eta value (standard normal space) to fix for validation. sigma = exp(eta * p_std + p_mean). None resets to default (0.0). """ + # Set on the base model if hasattr(self.model, "forecast_engine") and hasattr( self.model.forecast_engine, "_fixed_noise_level" ): self.model.forecast_engine._fixed_noise_level = noise_level + # Also set on the EMA model (separate model copy used during validation) + if self.ema_model is not None: + ema_net = self.ema_model.ema_model + if hasattr(ema_net, "forecast_engine") and hasattr( + ema_net.forecast_engine, "_fixed_noise_level" + ): + ema_net.forecast_engine._fixed_noise_level = noise_level for calc in self.target_and_aux_calculators_val.values(): if hasattr(calc, "_fixed_noise_level"): calc._fixed_noise_level = noise_level diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 67c766e23..bce38306b 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -8,6 +8,7 @@ # nor does it submit to any jurisdiction. import logging +import re import numpy as np import torch @@ -65,10 +66,21 @@ def _resolve_channel_names(stream_info, raw_channels): def write_output( - cf, val_cfg, batch_size, mini_epoch, batch_idx, dn_data, batch, model_output, target_aux_out + cf, val_cfg, batch_size, mini_epoch, batch_idx, dn_data, batch, model_output, target_aux_out, + noise_level=None, + write_zarr=True, ): """ Interface for writing model output + + Parameters + ---------- + noise_level : float | None + Fixed diffusion noise level (eta) used for this validation pass. + When not None the value is embedded in plot filenames and titles. + write_zarr : bool + Whether to write zarr output. Default True. Set to False to only + generate plots without writing zarr data. """ # TODO: REMOVE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. global i @@ -233,9 +245,10 @@ def write_output( sample_start, forecast_offset, ) - with zarrio_writer(config.get_path_results(cf, mini_epoch)) as zio: - for subset in data.items(): - zio.write_zarr(subset) + if write_zarr: + with zarrio_writer(config.get_path_results(cf, mini_epoch)) as zio: + for subset in data.items(): + zio.write_zarr(subset) # TODO: REMOVE EVERYTHING BELOW THIS LINE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. @@ -295,14 +308,14 @@ def write_output( data = da.sel(channel=varname).dropna(dim="ipoint") channel_dir = base_plot_dir / varname channel_dir.mkdir(parents=True, exist_ok=True) - epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}" - # Add noise_level_rn to title if present for this stream - # noise_level = noise_levels[stream_idx] - noise_level = ( - None # TODO: REMOVE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. - ) if noise_level is not None: - title = f"{stream_name} - {varname} (fstep {forecast_offset}) | noise_level_rn={noise_level:.4f}" + eta_str = re.sub(r'e[+]?0*(?=\d)', 'e', re.sub(r'e-0*(?=\d)', 'e-', f'{noise_level:.0e}')) + else: + eta_str = None + eta_tag = f"_eta{eta_str}" if eta_str is not None else "" + epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}{eta_tag}" + if eta_str is not None: + title = f"{stream_name} - {varname} (fstep {forecast_offset}) | eta={eta_str}" else: title = f"{stream_name} - {varname} (fstep {forecast_offset})" @@ -368,8 +381,16 @@ def write_output( data = da_noised.sel(channel=varname).dropna(dim="ipoint") channel_dir = base_plot_dir / varname / "noised" channel_dir.mkdir(parents=True, exist_ok=True) - epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}_noised" - title = f"{stream_name} - {varname} (fstep {forecast_offset}) [noised input]" + if noise_level is not None: + eta_str = re.sub(r'e[+]?0*(?=\d)', 'e', re.sub(r'e-0*(?=\d)', 'e-', f'{noise_level:.0e}')) + else: + eta_str = None + eta_tag = f"_eta{eta_str}" if eta_str is not None else "" + epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}{eta_tag}_noised" + if eta_str is not None: + title = f"{stream_name} - {varname} (fstep {forecast_offset}) [noised input] | eta={eta_str}" + else: + title = f"{stream_name} - {varname} (fstep {forecast_offset}) [noised input]" plot_name = plotter.scatter_plot( data, From 162a8606ecd5137d48096fbf32646f74df0767bf Mon Sep 17 00:00:00 2001 From: Moritz Hauschulz <60788263+moritzhauschulz@users.noreply.github.com> Date: Sun, 15 Mar 2026 19:41:52 +0000 Subject: [PATCH 238/344] Mh/single sample diffusion plotting update (#2049) * grad accumulation first try * better memory * revert grad accumulation * Jk/mk/mh/diffusion single sample fix val (#2045) * Improve support for latent losses (#1963) * Revert 2D rope to false by default (#1967) Set to True by accident * Implementation of DataReaderMesh (#1840) * First implementation of DataReaderMesh * Move to datareaders extra * ruff * ruff2 * Undo ruff * undo auto-linting * correct typo in eval config (#1971) * Added all-physical-streams option and x/y axis limits (#1972) * Added all-physical-streams option and x/y axis limits * Fix * Changed flag for all streams * Removed old code * moved metric parsing to eval_from_config (#1977) Co-authored-by: buschow1 * Fixed integration test (#1980) * [1974][model] Add fallback to config loading (#1985) * Add fallback to config loading * Adjust error message to be not misleading * Homegenize naming convention * Introduce bias/diff maps and animations (#1912) * Introduce bias/diff maps and animations * minor correction * Changes based on review * Introduce "plot_bias" in evaluation configuration (#1986) * Fixed index ordering to not have shuffled output (#1982) * Fixed idxs_inv to revert data point shuffeling * Fixed output handling * Handling of empty data case, addressing reviewer comment * [1893][eval] csvreader cleanup (#1906) * refactor csvreader * check if dataarray size is 0 * fix and use original logic for empty data * linting fixes * revert assertions back * [1890][eval] Move MergeReader to own module (#1892) * move mergereader * use assertions only * implement scoring for the sub-steps within the forecast window (#1896) * work in progress * working for forecast_step * working version * restore no valid times option * lint * Rename scale_z_channels to _scale_z_channels * fix 1 sample bug * Remove points_per_sample from ReaderOutput Remove points_per_sample from ReaderOutput return. * remove n_point_per_sample * fix lead time coord in compute_scores * lint * fix integration test * Fix integration test single stream (#1996) * fix test single * change yml extension and minor fixes --------- Co-authored-by: cosi1 Co-authored-by: cosi1 * [1907][eval] clean up wegen_reader.py (#1911) * clean up wegen_reader.py * remove exception * consistent reader naming * add blank line * use assertions only * make names consistent * Merge branch 'develop' into 1907-wegenreader-cleanup * revert is_regular --------- Co-authored-by: iluise <72020169+iluise@users.noreply.github.com> Co-authored-by: Ilaria Luise * [1888][eval] Refactor Reader class (#1889) * refactor Reader * use assertion only * fix npp atms --------- Co-authored-by: iluise <72020169+iluise@users.noreply.github.com> Co-authored-by: Ilaria Luise * [1975][model] Load model path from private repo instead of json (#1998) * Load model path from private repo instead of json * Lint * Script to compute spatial autocorrelation of structured/unstructured datasets (#1955) * standalone script to compute spatial autocorrelation of variables in a structured or unstructured dataset * remove commits that should be in pr 1951 * lint * addressed comments * removed last failure returning 500km default, and moved to packages science * updated a note * rename autocorrelation script * update example usage * Correct EMA halflife_steps calculation with rampup_ratio (#2001) Corrected rampup calculation: https://github.com/NVlabs/edm2/blob/4bf8162f601bcc09472ce8a32dd0cbe8889dc8fc/training/phema.py#L145 Co-authored-by: Wael * Reduce verbosity of output during inference and evaluation (#2006) * Fix incorrect length in validation progress bar * Removing too verbose output * [1766][1743][1332] lint and unit-test fix (#1802) * [1766][1742] fix lint and unit-test * [1766] fix linter * [1766] lint local and global consistent * [1332] add script to detect bad functions (getattr) * code quality: lint and bad functions * [1766] disable some checks * [1877] Script to populate PR labels from linked issues (#1878) * script * branch * more dirs * typo * enable * Fixed bug in linear embedding (#2012) * Adding forecast_steps feature to plot_train (#2010) * Adding forecast_steps feature to plot_train * Renamed arguement to conform to hyphen convention * Added forecast step to filename --------- Co-authored-by: Seb Hickman <56727418+shmh40@users.noreply.github.com> * add noise distribution plotting * plot noise distribution and decoded noised tokens * fix noise level in validation to p_mean * rm noise and token distribution plotting --------- Co-authored-by: Christian Lessig Co-authored-by: Seb Hickman <56727418+shmh40@users.noreply.github.com> Co-authored-by: Kacper Nowak Co-authored-by: Till Hauer Co-authored-by: s6sebusc <49226935+s6sebusc@users.noreply.github.com> Co-authored-by: buschow1 Co-authored-by: Matthias Karlbauer Co-authored-by: Savvas Melidonis <79579567+SavvasMel@users.noreply.github.com> Co-authored-by: Michael Tarnawa <18899420+mtar@users.noreply.github.com> Co-authored-by: iluise <72020169+iluise@users.noreply.github.com> Co-authored-by: pierluigicosi <91318382+pierluigicosi@users.noreply.github.com> Co-authored-by: cosi1 Co-authored-by: cosi1 Co-authored-by: Ilaria Luise Co-authored-by: Wael Co-authored-by: Simone Norberti <63310821+simone99n@users.noreply.github.com> Co-authored-by: Timothy Hunter * implement zero-weight physical loss * plot config file * deletion * merged with noised plotting * revert some changes * revert change for fixed eta * revert changes * revert more changes * fixed noised plotting * config changes --------- Co-authored-by: Julian Kuehnert Co-authored-by: Christian Lessig Co-authored-by: Seb Hickman <56727418+shmh40@users.noreply.github.com> Co-authored-by: Kacper Nowak Co-authored-by: Till Hauer Co-authored-by: s6sebusc <49226935+s6sebusc@users.noreply.github.com> Co-authored-by: buschow1 Co-authored-by: Matthias Karlbauer Co-authored-by: Savvas Melidonis <79579567+SavvasMel@users.noreply.github.com> Co-authored-by: Michael Tarnawa <18899420+mtar@users.noreply.github.com> Co-authored-by: iluise <72020169+iluise@users.noreply.github.com> Co-authored-by: pierluigicosi <91318382+pierluigicosi@users.noreply.github.com> Co-authored-by: cosi1 Co-authored-by: cosi1 Co-authored-by: Ilaria Luise Co-authored-by: Wael Co-authored-by: Simone Norberti <63310821+simone99n@users.noreply.github.com> Co-authored-by: Timothy Hunter --- config/config_diffusion.yml | 4 +- ...rain.yml => runs_plot_train_diffusion.yml} | 6 +- src/weathergen/train/loss_calculator.py | 12 +- src/weathergen/utils/validation_io.py | 244 +++++++++++------- 4 files changed, 163 insertions(+), 103 deletions(-) rename config/{runs_plot_train.yml => runs_plot_train_diffusion.yml} (88%) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index eb5b9055e..9d4074e88 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -72,7 +72,7 @@ sigma_min: 0.002 sigma_max: 50000 sigma_data: 0.5 rho: 7 -p_mean: 1.2 +p_mean: -1.2 p_std: 1.2 @@ -224,7 +224,7 @@ validation_config: # Noise levels (eta values in standard normal space) at which to evaluate the # diffusion model during validation. sigma = exp(eta * p_std + p_mean). # Each value produces a separate validation pass with independently logged metrics. - validation_noise_levels: [0.03, 0.3, 3.0] + validation_noise_levels: [0.03, 0.3, 1, 3.0] samples_per_mini_epoch: 16 shuffle: False diff --git a/config/runs_plot_train.yml b/config/runs_plot_train_diffusion.yml similarity index 88% rename from config/runs_plot_train.yml rename to config/runs_plot_train_diffusion.yml index 8613cc92f..9ac76fd22 100644 --- a/config/runs_plot_train.yml +++ b/config/runs_plot_train_diffusion.yml @@ -1,8 +1,8 @@ train : plot : - # crn7ov5y: - # slurm_id: 0 - # description: "crn7ov5y: no noise, lat 1, phys 0, start_lr=1e-5, max_lr=1e-4" + gfaszkoh: + slurm_id: 0 + description: "gfaszkoh: EDM settings, single sample, lat 1, phys 0, start_lr=1e-5, max_lr=1e-4" # h3of5mec: # slurm_id: 0 diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index c3627c97e..bf84881a7 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -90,14 +90,14 @@ def compute_loss( for loss_term_name, calc_term in self.loss_calculators.items(): target = targets_and_aux[loss_term_name] for weight, calculator in calc_term: + loss_values = calculator.compute_loss( + preds=preds, targets=target, metadata=metadata + ) if weight > 0.0: - loss_values = calculator.compute_loss( - preds=preds, targets=target, metadata=metadata - ) loss = loss + weight * loss_values.loss - losses_all[calculator.name] = loss_values.losses_all - losses_all[calculator.name]["loss_avg"] = loss_values.loss - stddev_all[calculator.name] = loss_values.stddev_all + losses_all[calculator.name] = loss_values.losses_all + losses_all[calculator.name]["loss_avg"] = loss_values.loss + stddev_all[calculator.name] = loss_values.stddev_all # Keep histories for logging self.loss_hist += [loss.detach()] diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index bce38306b..deff62042 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -8,6 +8,7 @@ # nor does it submit to any jurisdiction. import logging +from math import exp import re import numpy as np @@ -66,7 +67,7 @@ def _resolve_channel_names(stream_info, raw_channels): def write_output( - cf, val_cfg, batch_size, mini_epoch, batch_idx, dn_data, batch, model_output, target_aux_out, + cf, val_cfg, batch_size, mini_epoch, batch_idx, dn_data, batch, model_output, target_aux_out, noise_level=None, write_zarr=True, ): @@ -111,12 +112,12 @@ def write_output( targets_times_all += [[]] noised_preds_all += [[]] targets_lens += [[]] - # noise_levels = [] # TODO: REMOVE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. for stream_idx, stream_info in enumerate(cf.streams): sname = stream_info["name"] # handle spoof data: do not write since it might corrupt validation (spoofing invisible # there) + if target_aux_out.physical[t_idx][sname]["is_spoof"][0]: preds = model_output.get_physical_prediction(t_idx, sname) preds_shape = preds[0].shape @@ -129,7 +130,7 @@ def write_output( else: preds = model_output.get_physical_prediction(t_idx, sname) targets = target_aux_out.physical[t_idx][sname]["target"] - + preds_s, targets_s, t_coords_s, t_times_s = [], [], [], [] # handle forcing streams or if sample is empty @@ -157,6 +158,7 @@ def write_output( # extract original target coords and times from target data t_coords_s += [t_coords.cpu().numpy()] t_times_s += [t_times.astype("datetime64[ns]")] + targets_lens[-1] += [[]] targets_lens[-1][-1] += [t.shape[0] for t in targets_s] @@ -250,6 +252,10 @@ def write_output( for subset in data.items(): zio.write_zarr(subset) + + # Free arrays no longer needed after zarr writing + del targets_all, targets_times_all, targets_lens, sources, data + # TODO: REMOVE EVERYTHING BELOW THIS LINE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. # Prepare prediction data for Plotter (scatter plot expects lat/lon coords on ipoint). @@ -278,59 +284,87 @@ def write_output( ) continue - lat = coords_stream[:, 0] - lon = coords_stream[:, 1] channels = _resolve_channel_names(stream_info, target_channels[stream_idx]) - - da = xr.DataArray( - preds_stream, - dims=("ipoint", "channel"), - coords={ - "ipoint": np.arange(preds_stream.shape[0]), - "channel": channels, - "lat": ("ipoint", lat), - "lon": ("ipoint", lon), - }, - ) - - plotter.stream = stream_name - plotter.run_id = config.get_run_id_from_config(cf) - plotter.fstep = forecast_offset - selected_channels = [ ch for ch in channels if _normalize_channel_name(ch) in headline_channels ] if not selected_channels: _logger.warning(f"No headline channels available for plotting stream {stream_name}.") + del preds_stream, coords_stream continue - for varname in selected_channels: - data = da.sel(channel=varname).dropna(dim="ipoint") - channel_dir = base_plot_dir / varname - channel_dir.mkdir(parents=True, exist_ok=True) - if noise_level is not None: - eta_str = re.sub(r'e[+]?0*(?=\d)', 'e', re.sub(r'e-0*(?=\d)', 'e-', f'{noise_level:.0e}')) - else: - eta_str = None - eta_tag = f"_eta{eta_str}" if eta_str is not None else "" - epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}{eta_tag}" - if eta_str is not None: - title = f"{stream_name} - {varname} (fstep {forecast_offset}) | eta={eta_str}" - else: - title = f"{stream_name} - {varname} (fstep {forecast_offset})" - - plot_name = plotter.scatter_plot( - data, - channel_dir, - varname=varname, - regionname="global", - tag=epoch_tag, - title=title, - ) - src = channel_dir / f"{plot_name}.{plotter.image_format}" - dst = channel_dir / f"{epoch_tag}.{plotter.image_format}" - if src != dst and src.exists(): - src.replace(dst) + # Build a channel index map so we can slice numpy arrays directly + # instead of constructing a full xarray DataArray for all channels. + ch_to_col = {ch: idx for idx, ch in enumerate(channels)} + + lat = coords_stream[:, 0] + lon = coords_stream[:, 1] + + plotter.stream = stream_name + plotter.run_id = config.get_run_id_from_config(cf) + plotter.fstep = forecast_offset + + num_samples = len(preds) + len_per_sample = preds_stream.shape[0] // num_samples + + for sample in range(num_samples): + s_start = sample * len_per_sample + s_end = (sample + 1) * len_per_sample + + for varname in selected_channels: + col = ch_to_col[varname] + vals = preds_stream[s_start:s_end, col] + sample_lat = lat[s_start:s_end] + sample_lon = lon[s_start:s_end] + + # Drop NaN points + valid = ~np.isnan(vals) + vals = vals[valid] + sample_lat = sample_lat[valid] + sample_lon = sample_lon[valid] + + sample_da = xr.DataArray( + vals, + dims=("ipoint",), + coords={ + "ipoint": np.arange(len(vals)), + "lat": ("ipoint", sample_lat), + "lon": ("ipoint", sample_lon), + }, + ) + + channel_dir = base_plot_dir / varname + channel_dir.mkdir(parents=True, exist_ok=True) + epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}_{sample}" + # Add noise_level_rn to title if present for this stream + if noise_level is not None: + eta_str = re.sub(r'e[+]?0*(?=\d)', 'e', re.sub(r'e-0*(?=\d)', 'e-', f'{noise_level:.0e}')) + else: + eta_str = None + eta_tag = f"_eta{eta_str}" if eta_str is not None else "" + epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}{eta_tag}" + + if noise_level is not None: + title = f"{stream_name} - {varname} (fstep {forecast_offset}) | sample {sample + 1} | noise_level={eta_str}" + else: + title = f"{stream_name} - {varname} (fstep {forecast_offset}) | sample {sample + 1}" + + plot_name = plotter.scatter_plot( + sample_da, + channel_dir, + varname=varname, + regionname="global", + tag=epoch_tag, + title=title, + ) + src = channel_dir / f"{plot_name}.{plotter.image_format}" + dst = channel_dir / f"{epoch_tag}.{plotter.image_format}" + if src != dst and src.exists(): + src.replace(dst) + + del sample_da, vals, sample_lat, sample_lon, valid + + del preds_stream, coords_stream # Plot decoded noised tokens (diffusion models only) has_noised = any( @@ -352,57 +386,83 @@ def write_output( elif noised_stream.ndim != 2: continue - lat = coords_stream[:, 0] - lon = coords_stream[:, 1] channels = _resolve_channel_names(stream_info, target_channels[stream_idx]) - - da_noised = xr.DataArray( - noised_stream, - dims=("ipoint", "channel"), - coords={ - "ipoint": np.arange(noised_stream.shape[0]), - "channel": channels, - "lat": ("ipoint", lat), - "lon": ("ipoint", lon), - }, - ) - - plotter.stream = stream_name - plotter.run_id = config.get_run_id_from_config(cf) - plotter.fstep = forecast_offset - selected_channels = [ ch for ch in channels if _normalize_channel_name(ch) in headline_channels ] if not selected_channels: + del noised_stream, coords_stream continue - for varname in selected_channels: - data = da_noised.sel(channel=varname).dropna(dim="ipoint") - channel_dir = base_plot_dir / varname / "noised" - channel_dir.mkdir(parents=True, exist_ok=True) - if noise_level is not None: - eta_str = re.sub(r'e[+]?0*(?=\d)', 'e', re.sub(r'e-0*(?=\d)', 'e-', f'{noise_level:.0e}')) - else: - eta_str = None - eta_tag = f"_eta{eta_str}" if eta_str is not None else "" - epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}{eta_tag}_noised" - if eta_str is not None: - title = f"{stream_name} - {varname} (fstep {forecast_offset}) [noised input] | eta={eta_str}" - else: - title = f"{stream_name} - {varname} (fstep {forecast_offset}) [noised input]" + ch_to_col = {ch: idx for idx, ch in enumerate(channels)} - plot_name = plotter.scatter_plot( - data, - channel_dir, - varname=varname, - regionname="global", - tag=epoch_tag, - title=title, - ) - src = channel_dir / f"{plot_name}.{plotter.image_format}" - dst = channel_dir / f"{epoch_tag}.{plotter.image_format}" - if src != dst and src.exists(): - src.replace(dst) + lat = coords_stream[:, 0] + lon = coords_stream[:, 1] + + plotter.stream = stream_name + plotter.run_id = config.get_run_id_from_config(cf) + plotter.fstep = forecast_offset + + num_samples = len(preds) + len_per_sample = noised_stream.shape[0] // num_samples + + for sample in range(num_samples): + s_start = sample * len_per_sample + s_end = (sample + 1) * len_per_sample + + for varname in selected_channels: + col = ch_to_col[varname] + vals = noised_stream[s_start:s_end, col] + sample_lat = lat[s_start:s_end] + sample_lon = lon[s_start:s_end] + + # Drop NaN points + valid = ~np.isnan(vals) + vals = vals[valid] + sample_lat = sample_lat[valid] + sample_lon = sample_lon[valid] + + sample_da = xr.DataArray( + vals, + dims=("ipoint",), + coords={ + "ipoint": np.arange(len(vals)), + "lat": ("ipoint", sample_lat), + "lon": ("ipoint", sample_lon), + }, + ) + + channel_dir = base_plot_dir / varname / "noised" + channel_dir.mkdir(parents=True, exist_ok=True) + epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}_{sample}_noised" + + if noise_level is not None: + eta_str = re.sub(r'e[+]?0*(?=\d)', 'e', re.sub(r'e-0*(?=\d)', 'e-', f'{noise_level:.0e}')) + else: + eta_str = None + eta_tag = f"_eta{eta_str}" if eta_str is not None else "" + epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}{eta_tag}" + + if noise_level is not None: + title = f"{stream_name} - {varname} (fstep {forecast_offset}) | noised sample {sample + 1} | noise_level={eta_str}" + else: + title = f"{stream_name} - {varname} (fstep {forecast_offset}) | noised sample {sample + 1}" + + plot_name = plotter.scatter_plot( + sample_da, + channel_dir, + varname=varname, + regionname="global", + tag=epoch_tag, + title=title, + ) + src = channel_dir / f"{plot_name}.{plotter.image_format}" + dst = channel_dir / f"{epoch_tag}.{plotter.image_format}" + if src != dst and src.exists(): + src.replace(dst) + + del sample_da, vals, sample_lat, sample_lon, valid + + del noised_stream, coords_stream i += 1 From 7ed3b8de69e12cf1adbe4406909448fa3903ed18 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Sun, 15 Mar 2026 20:50:05 +0100 Subject: [PATCH 239/344] not write zarr outputs, update plot_training --- src/weathergen/train/trainer.py | 2 +- src/weathergen/utils/plot_training.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index a3f02b3fa..9ec8330b7 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -640,7 +640,7 @@ def validate(self, mini_epoch, mode_cfg, batch_size): preds, targets_and_auxs, noise_level=noise_level if is_diffusion and len(noise_levels) > 1 else None, - write_zarr=(noise_idx == 0), + write_zarr=False, #(noise_idx == 0), ) pbar.update(batch_size) diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index 03a7b9646..ea5d960b4 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -366,7 +366,7 @@ def plot_loss_per_stream( if len(col_split) < 4: if stream_name in col: data_cols += [col] - elif len(col_split) == 4: + elif col_split[3] == "avg": if ( col_split[1].lower() == stream_name.lower() and col_split[2].lower() == err.lower() From 3d5263362f46453445ccd3ef4153f540dd6fec81 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Mon, 16 Mar 2026 15:22:37 +0100 Subject: [PATCH 240/344] update val noise levels --- config/config_diffusion.yml | 8 ++++---- src/weathergen/utils/plot_training.py | 4 +++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 9d4074e88..765d3a78a 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -55,12 +55,12 @@ num_register_tokens: 0 # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -fe_num_blocks: 2 +fe_num_blocks: 6 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True fe_diffusion_model: True -fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) forecast_att_dense_rate: 1.0 @@ -153,7 +153,7 @@ training_config: # training_mode: "masking", "student_teacher", "latent_loss" training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] - num_mini_epochs: 150 + num_mini_epochs: 512 samples_per_mini_epoch: 66 shuffle: True @@ -224,7 +224,7 @@ validation_config: # Noise levels (eta values in standard normal space) at which to evaluate the # diffusion model during validation. sigma = exp(eta * p_std + p_mean). # Each value produces a separate validation pass with independently logged metrics. - validation_noise_levels: [0.03, 0.3, 1, 3.0] + validation_noise_levels: [0.3, 0.5, 0.75, 1.0, 1.5] samples_per_mini_epoch: 16 shuffle: False diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index ea5d960b4..87a546541 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -363,7 +363,9 @@ def plot_loss_per_stream( data_cols = [] for col in run_data_mode.columns: col_split = col.split(".") - if len(col_split) < 4: + if col == stream_name: + data_cols += [col] + elif len(col_split) < 4: if stream_name in col: data_cols += [col] elif col_split[3] == "avg": From 594412bebf65a013061f8e5c7980176eda4fa14b Mon Sep 17 00:00:00 2001 From: Jubeku Date: Thu, 19 Mar 2026 14:29:04 +0100 Subject: [PATCH 241/344] avoid rounding of validation noise levels for logging --- src/weathergen/train/trainer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 9ec8330b7..ce78885c4 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -10,8 +10,8 @@ # nor does it submit to any jurisdiction. import copy import logging -import re import time +from decimal import Decimal import numpy as np import torch @@ -568,7 +568,9 @@ def validate(self, mini_epoch, mode_cfg, batch_size): if is_diffusion: self._set_validation_noise_level(noise_level) - eta_str = re.sub(r'e[+]?0*(?=\d)', 'e', re.sub(r'e-0*(?=\d)', 'e-', f'{noise_level:.0e}')) + _d = Decimal(str(noise_level)).normalize() + _sign, _digits, _exp = _d.as_tuple() + eta_str = f"{'-' if _sign else ''}{''.join(map(str, _digits))}e{_exp}" loss_suffix = f".eta{eta_str}" if len(noise_levels) > 1 else "" stage_suffix = f"_eta{eta_str}" if len(noise_levels) > 1 else "" From 366c6b50556e6f2b6aab349f995c493d05bed566 Mon Sep 17 00:00:00 2001 From: Matthias Date: Mon, 23 Mar 2026 09:23:56 +0100 Subject: [PATCH 242/344] ERA5 distribution setup --- config/config_diffusion.yml | 17 +- config/runs_plot_train_diffusion.yml | 37 ---- src/weathergen/model/diffusion.py | 15 -- src/weathergen/utils/validation_io.py | 275 ++++++++++++-------------- 4 files changed, 135 insertions(+), 209 deletions(-) delete mode 100644 config/runs_plot_train_diffusion.yml diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 765d3a78a..3c977d11c 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -55,7 +55,7 @@ num_register_tokens: 0 # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -fe_num_blocks: 6 +fe_num_blocks: 4 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True @@ -158,7 +158,7 @@ training_config: shuffle: True start_date: 2012-06-01T00:00 - end_date: 2012-06-01T18:00 + end_date: 2012-09-01T18:00 time_window_step: 06:00:00 time_window_len: 06:00:00 @@ -188,7 +188,7 @@ training_config: losses : { "physical": { type: LossPhysical, - weight: 0.1, + weight: 0.0, loss_fcts: { "mse": {}, }, @@ -196,7 +196,7 @@ training_config: }, "latent_diff": { type: LossLatentDiffusion, - weight: 0.9, + weight: 1.0, target_and_aux_calc: DiffusionLatentTargetEncoder, loss_fcts: { "mse": { }, }, } @@ -224,13 +224,14 @@ validation_config: # Noise levels (eta values in standard normal space) at which to evaluate the # diffusion model during validation. sigma = exp(eta * p_std + p_mean). # Each value produces a separate validation pass with independently logged metrics. - validation_noise_levels: [0.3, 0.5, 0.75, 1.0, 1.5] + # validation_noise_levels: [0.3, 0.5, 0.75, 1.0, 1.5] + validation_noise_levels: [0.3, 1.5] - samples_per_mini_epoch: 16 - shuffle: False + samples_per_mini_epoch: 1 + shuffle: True # TODO: Set back to False start_date: 2012-06-01T00:00 - end_date: 2012-06-01T18:00 + end_date: 2012-07-01T18:00 # whether to track the exponential moving average of weights for validation validate_with_ema: diff --git a/config/runs_plot_train_diffusion.yml b/config/runs_plot_train_diffusion.yml deleted file mode 100644 index 9ac76fd22..000000000 --- a/config/runs_plot_train_diffusion.yml +++ /dev/null @@ -1,37 +0,0 @@ -train : - plot : - gfaszkoh: - slurm_id: 0 - description: "gfaszkoh: EDM settings, single sample, lat 1, phys 0, start_lr=1e-5, max_lr=1e-4" - - # h3of5mec: - # slurm_id: 0 - # description: "h3of5mec: lat 0.9, phys 0.1, start_lr=1e-5, max_lr=1e-4, fe=2 blocks" - - # p0q1oz52: - # slurm_id: 0 - # description: "p0q1oz52: lat 0.9, phys 0.1, start_lr=1e-5, max_lr=5e-5, fe=2 blocks" - - # aabi87jc: - # slurm_id: 0 - # description: "aabi87jc: lat 0.9, phys 0.1, start_lr=1e-5, max_lr=2.5e-5, fe=2 blocks" - - # s9fsjudp: - # slurm_id: 0 - # description: "s9fsjudp: lat 0.9, phys 0.1, start_lr=1e-5, max_lr=1e-5, fe=2 blocks" - - # gbq1pxc9: - # slurm_id: 0 - # description: "gbq1pxc9: lat 0.9, phys 0.1, start_lr=1e-5, max_lr=1e-4, fe=MLP" - - # yyv2m7ir: - # slurm_id: 0 - # description: "yyv2m7ir: lat 0.9, phys 0.1, start_lr=1e-5, max_lr=5e-5, fe=MLP" - - # b5c60g4a: - # slurm_id: 0 - # description: "b5c60g4a: lat 0.9, phys 0.1, start_lr=1e-5, max_lr=2.5e-5, fe=MLP" - - # dkzr6lfq: - # slurm_id: 0 - # description: "dkzr6lfq: lat 0.9, phys 0.1, start_lr=1e-6, max_lr=1e-5, fe=MLP" \ No newline at end of file diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 9d0f6fba0..0a1834be2 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -80,21 +80,6 @@ def forward( # y = data.get_input_data(-1) # eta = data.get_input_metadata(-1) - # TODO: remove after single sample experiments - if self.cur_token is not None: - # logger.info("checking single sampling") - assert self.cur_token[0].shape == tokens[0].shape, ( - "first token shape was different between iterations " - "– violates single sample overfitting with difference" - ) - assert torch.equal(self.cur_token[0], tokens[0]), ( - f"first token was different between iterations " - f"– violates single sample overfitting {self.cur_token[0] - tokens[0]}" - ) - assert torch.equal(self.cur_token, tokens), ( - f"tokens were different between iterations " - f"– violates single sample overfitting {self.cur_token - tokens}" - ) self.cur_token = tokens c = 1 # TODO: add correct preconditioning (e.g., sample/s in previous time step) diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index deff62042..9d17d79c5 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -11,15 +11,16 @@ from math import exp import re +import cartopy.crs as ccrs +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt import numpy as np import torch -import xarray as xr import weathergen.common.config as config import weathergen.common.io as io from weathergen.common.io import TimeRange, zarrio_writer from weathergen.datasets.data_reader_base import TimeWindowHandler -from weathergen.evaluate.plotting.plotter import Plotter _logger = logging.getLogger(__name__) @@ -252,25 +253,40 @@ def write_output( for subset in data.items(): zio.write_zarr(subset) + # Extract a representative date per batch sample from target times before + # they are freed. Use the first non-NaT timestamp found in t_idx=0. + sample_dates: list[str] = [] + if len(targets_times_all) > 0: + for stream_times in targets_times_all[0]: + if stream_times.size > 0: + valid_times = stream_times[~np.isnat(stream_times)] + if valid_times.size > 0: + sample_dates.append(str(valid_times[0].astype("datetime64[h]"))) + break # Free arrays no longer needed after zarr writing del targets_all, targets_times_all, targets_lens, sources, data # TODO: REMOVE EVERYTHING BELOW THIS LINE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. - # Prepare prediction data for Plotter (scatter plot expects lat/lon coords on ipoint). + # Prepare prediction data for plotting (scatter plot expects lat/lon coords on ipoint). base_plot_dir = config.get_path_run(cf) / "plots" / "validation" base_plot_dir.mkdir(parents=True, exist_ok=True) - plotter = Plotter({"image_format": "png", "dpi_val": 150}, base_plot_dir) + dpi_val = 150 + image_format = "png" # headline_channels = {"2t", "z500", "q850", "10u", "10v"} headline_channels = {"2t", "q850"} - t_idx = 0 + t_idx = np.random.randint(0, len(preds_all)) # TODO: loop over all time steps once plotting is set up for stream_idx, stream_info in enumerate(cf.streams): stream_name = stream_info["name"] preds_stream = preds_all[t_idx][stream_idx] coords_stream = targets_coords_all[t_idx][stream_idx] + # Check for noised data for this stream + noised_stream = noised_preds_all[t_idx][stream_idx] + has_noised = noised_stream.size > 0 and noised_stream.ndim >= 2 + if preds_stream.size == 0 or coords_stream.size == 0: _logger.warning(f"No prediction data to plot for stream {stream_name}.") continue @@ -284,6 +300,12 @@ def write_output( ) continue + if has_noised: + if noised_stream.ndim == 3: + noised_stream = noised_stream[0] + elif noised_stream.ndim != 2: + has_noised = False + channels = _resolve_channel_names(stream_info, target_channels[stream_idx]) selected_channels = [ ch for ch in channels if _normalize_channel_name(ch) in headline_channels @@ -300,169 +322,124 @@ def write_output( lat = coords_stream[:, 0] lon = coords_stream[:, 1] - plotter.stream = stream_name - plotter.run_id = config.get_run_id_from_config(cf) - plotter.fstep = forecast_offset + run_id = config.get_run_id_from_config(cf) num_samples = len(preds) len_per_sample = preds_stream.shape[0] // num_samples + noised_len_per_sample = noised_stream.shape[0] // num_samples if has_noised else 0 + + if noise_level is not None: + eta_str = re.sub(r'e[+]?0*(?=\d)', 'e', re.sub(r'e-0*(?=\d)', 'e-', f'{noise_level:.0e}')) + else: + eta_str = None + eta_tag = f"_eta{eta_str}" if eta_str is not None else "" for sample in range(num_samples): s_start = sample * len_per_sample s_end = (sample + 1) * len_per_sample + ns_start = sample * noised_len_per_sample if has_noised else 0 + ns_end = (sample + 1) * noised_len_per_sample if has_noised else 0 + + # Extract sample date from target times + sample_date_str = sample_dates[0] if len(sample_dates) > 0 else "" for varname in selected_channels: col = ch_to_col[varname] - vals = preds_stream[s_start:s_end, col] - sample_lat = lat[s_start:s_end] - sample_lon = lon[s_start:s_end] - - # Drop NaN points - valid = ~np.isnan(vals) - vals = vals[valid] - sample_lat = sample_lat[valid] - sample_lon = sample_lon[valid] - - sample_da = xr.DataArray( - vals, - dims=("ipoint",), - coords={ - "ipoint": np.arange(len(vals)), - "lat": ("ipoint", sample_lat), - "lon": ("ipoint", sample_lon), - }, - ) - channel_dir = base_plot_dir / varname - channel_dir.mkdir(parents=True, exist_ok=True) - epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}_{sample}" - # Add noise_level_rn to title if present for this stream - if noise_level is not None: - eta_str = re.sub(r'e[+]?0*(?=\d)', 'e', re.sub(r'e-0*(?=\d)', 'e-', f'{noise_level:.0e}')) + # --- denoised data --- + den_vals = preds_stream[s_start:s_end, col] + den_lat = lat[s_start:s_end] + den_lon = lon[s_start:s_end] + den_valid = ~np.isnan(den_vals) + den_vals, den_lat, den_lon = den_vals[den_valid], den_lat[den_valid], den_lon[den_valid] + + # --- noised data --- + if has_noised: + noi_vals = noised_stream[ns_start:ns_end, col] + noi_lat = lat[ns_start:ns_end] + noi_lon = lon[ns_start:ns_end] + noi_valid = ~np.isnan(noi_vals) + noi_vals, noi_lat, noi_lon = noi_vals[noi_valid], noi_lat[noi_valid], noi_lon[noi_valid] else: - eta_str = None - eta_tag = f"_eta{eta_str}" if eta_str is not None else "" - epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}{eta_tag}" - - if noise_level is not None: - title = f"{stream_name} - {varname} (fstep {forecast_offset}) | sample {sample + 1} | noise_level={eta_str}" + noi_vals = noi_lat = noi_lon = None + + # Shared colour scale across both panels + all_vals = np.concatenate([den_vals] + ([noi_vals] if noi_vals is not None else [])) + vmin, vmax = float(np.nanmin(all_vals)), float(np.nanmax(all_vals)) + norm = mcolors.Normalize(vmin=vmin, vmax=vmax) + cmap = plt.get_cmap("coolwarm") + + ncols = 2 if has_noised else 1 + proj = ccrs.Robinson() + fig, axes = plt.subplots( + 1, ncols, figsize=(8 * ncols, 5), dpi=dpi_val, + subplot_kw={"projection": proj}, + ) + if ncols == 1: + axes = [axes] + + # Left panel: noised (or skip if not available) + if has_noised: + ax_noised = axes[0] + ax_noised.coastlines() + ax_noised.set_global() + sc_n = ax_noised.scatter( + noi_lon, noi_lat, c=noi_vals, norm=norm, cmap=cmap, + s=2.0, marker="o", transform=ccrs.PlateCarree(), linewidths=0.0, + ) + if eta_str is not None: + ax_noised.set_title(f"Noised | {varname} | eta={eta_str}", fontsize=9.5) + else: + ax_noised.set_title(f"Noised | {varname}", fontsize=9.5) + ax_noised.gridlines(draw_labels=False, linestyle="--", color="black", linewidth=1) + ax_denoised = axes[1] else: - title = f"{stream_name} - {varname} (fstep {forecast_offset}) | sample {sample + 1}" - - plot_name = plotter.scatter_plot( - sample_da, - channel_dir, - varname=varname, - regionname="global", - tag=epoch_tag, - title=title, + ax_denoised = axes[0] + + # Right panel (or only panel): denoised + ax_denoised.coastlines() + ax_denoised.set_global() + sc_d = ax_denoised.scatter( + den_lon, den_lat, c=den_vals, norm=norm, cmap=cmap, + s=2.0, marker="o", transform=ccrs.PlateCarree(), linewidths=0.0, ) - src = channel_dir / f"{plot_name}.{plotter.image_format}" - dst = channel_dir / f"{epoch_tag}.{plotter.image_format}" - if src != dst and src.exists(): - src.replace(dst) - - del sample_da, vals, sample_lat, sample_lon, valid - - del preds_stream, coords_stream - - # Plot decoded noised tokens (diffusion models only) - has_noised = any( - noised_preds_all[t_idx][s_idx].size > 0 - for s_idx in range(len(cf.streams)) - if noised_preds_all[t_idx][s_idx].ndim >= 2 - ) - if has_noised: - for stream_idx, stream_info in enumerate(cf.streams): - stream_name = stream_info["name"] - noised_stream = noised_preds_all[t_idx][stream_idx] - coords_stream = targets_coords_all[t_idx][stream_idx] + if eta_str is not None: + ax_denoised.set_title(f"Denoised | {varname} | eta={eta_str}", fontsize=9.5) + else: + ax_denoised.set_title(f"Denoised | {varname}", fontsize=9.5) + ax_denoised.gridlines(draw_labels=False, linestyle="--", color="black", linewidth=1) - if noised_stream.size == 0 or coords_stream.size == 0: - continue + # Shared colourbar + fig.colorbar( + sc_d, ax=axes, orientation="horizontal", + label=f"Variable: {varname}", fraction=0.05, pad=0.07, + ) - if noised_stream.ndim == 3: - noised_stream = noised_stream[0] - elif noised_stream.ndim != 2: - continue - - channels = _resolve_channel_names(stream_info, target_channels[stream_idx]) - selected_channels = [ - ch for ch in channels if _normalize_channel_name(ch) in headline_channels - ] - if not selected_channels: - del noised_stream, coords_stream - continue - - ch_to_col = {ch: idx for idx, ch in enumerate(channels)} - - lat = coords_stream[:, 0] - lon = coords_stream[:, 1] - - plotter.stream = stream_name - plotter.run_id = config.get_run_id_from_config(cf) - plotter.fstep = forecast_offset - - num_samples = len(preds) - len_per_sample = noised_stream.shape[0] // num_samples - - for sample in range(num_samples): - s_start = sample * len_per_sample - s_end = (sample + 1) * len_per_sample - - for varname in selected_channels: - col = ch_to_col[varname] - vals = noised_stream[s_start:s_end, col] - sample_lat = lat[s_start:s_end] - sample_lon = lon[s_start:s_end] - - # Drop NaN points - valid = ~np.isnan(vals) - vals = vals[valid] - sample_lat = sample_lat[valid] - sample_lon = sample_lon[valid] - - sample_da = xr.DataArray( - vals, - dims=("ipoint",), - coords={ - "ipoint": np.arange(len(vals)), - "lat": ("ipoint", sample_lat), - "lon": ("ipoint", sample_lon), - }, + date_part = f" | {sample_date_str}" if sample_date_str else "" + if eta_str is not None: + fig.suptitle( + f"{stream_name} - {varname} (fstep {forecast_offset}) | sample {sample + 1}{date_part} | eta={eta_str}", + fontsize=11, ) - - channel_dir = base_plot_dir / varname / "noised" - channel_dir.mkdir(parents=True, exist_ok=True) - epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}_{sample}_noised" - - if noise_level is not None: - eta_str = re.sub(r'e[+]?0*(?=\d)', 'e', re.sub(r'e-0*(?=\d)', 'e-', f'{noise_level:.0e}')) - else: - eta_str = None - eta_tag = f"_eta{eta_str}" if eta_str is not None else "" - epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}{eta_tag}" - - if noise_level is not None: - title = f"{stream_name} - {varname} (fstep {forecast_offset}) | noised sample {sample + 1} | noise_level={eta_str}" - else: - title = f"{stream_name} - {varname} (fstep {forecast_offset}) | noised sample {sample + 1}" - - plot_name = plotter.scatter_plot( - sample_da, - channel_dir, - varname=varname, - regionname="global", - tag=epoch_tag, - title=title, + else: + fig.suptitle( + f"{stream_name} - {varname} (fstep {forecast_offset}) | sample {sample + 1}{date_part}", + fontsize=11, ) - src = channel_dir / f"{plot_name}.{plotter.image_format}" - dst = channel_dir / f"{epoch_tag}.{plotter.image_format}" - if src != dst and src.exists(): - src.replace(dst) - del sample_da, vals, sample_lat, sample_lon, valid + channel_dir = base_plot_dir / varname + channel_dir.mkdir(parents=True, exist_ok=True) + epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}{eta_tag}" + fname = channel_dir / f"{epoch_tag}.{image_format}" + fig.savefig(fname, bbox_inches="tight") + plt.close(fig) - del noised_stream, coords_stream + del den_vals, den_lat, den_lon, den_valid + if has_noised: + del noi_vals, noi_lat, noi_lon, noi_valid + + del preds_stream, coords_stream + if has_noised: + del noised_stream i += 1 From bad90739519a6a621ec3e9edde8097f9e10da363 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Mon, 23 Mar 2026 12:20:17 +0100 Subject: [PATCH 243/344] Update diffusion config to normal train/val split --- config/config_diffusion.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 3c977d11c..8cc190eaa 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -157,8 +157,8 @@ training_config: samples_per_mini_epoch: 66 shuffle: True - start_date: 2012-06-01T00:00 - end_date: 2012-09-01T18:00 + start_date: 1979-06-01T00:00 + end_date: 2022-12-31T18:00 time_window_step: 06:00:00 time_window_len: 06:00:00 @@ -230,8 +230,8 @@ validation_config: samples_per_mini_epoch: 1 shuffle: True # TODO: Set back to False - start_date: 2012-06-01T00:00 - end_date: 2012-07-01T18:00 + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T18:00 # whether to track the exponential moving average of weights for validation validate_with_ema: From 1d4f30b58e71849b99773f143a2e43719619d7af Mon Sep 17 00:00:00 2001 From: Matthias Date: Wed, 25 Mar 2026 21:48:08 +0100 Subject: [PATCH 244/344] Testing inference --- src/weathergen/model/diffusion.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 9d0f6fba0..6dd63cfd7 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -97,6 +97,11 @@ def forward( ) self.cur_token = tokens + # print("input tokens statistics") + # print("mean", tokens.mean(), "std", tokens.std(), "max", tokens.max(), "min", tokens.min()) + + # return self.inference(fstep=fstep, num_steps=100) + c = 1 # TODO: add correct preconditioning (e.g., sample/s in previous time step) y = tokens @@ -130,6 +135,8 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int) -> # Embed noise level noise_emb = self.noise_embedder(c_noise) + # print("sigma", sigma) + # Precondition input and feed through network x = self.preconditioner.precondition(x, c) @@ -146,7 +153,10 @@ def inference( # https://github.com/NVlabs/edm/blob/main/generate.py # Sample noise (assuming single batch element for now) - x = torch.randn(1, self.num_healpix_cells, self.cf.ae_global_dim_embed).to(device="cuda") + x = torch.randn(1, self.num_healpix_cells, self.cf.ae_global_dim_embed).to(device="cuda") * 0.1 + # x = self.cur_token + x + print("initial noise statistics") + print("mean", x.mean(), "std", x.std(), "max", x.max(), "min", x.min()) # Time step discretization. step_indices = torch.arange(num_steps, dtype=torch.float64, device="cuda") @@ -156,15 +166,16 @@ def inference( / (num_steps - 1) * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)) ) ** self.rho - t_steps = torch.cat( - [self.net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])] - ) # t_N = 0 + t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 # Main sampling loop. x_next = x * t_steps[0] for i, (t_cur, t_next) in enumerate( zip(t_steps[:-1], t_steps[1:], strict=False) ): # 0, ..., N-1 + t_cur = torch.tensor([t_cur], device="cuda").float() + t_next = torch.tensor([t_next], device="cuda").float() + x_cur = x_next # Increase noise temporarily. (Stochastic sampling; not used for now) From 57a300c6c0fa4c5a4a7930668c39987400bc80ab Mon Sep 17 00:00:00 2001 From: Matthias Date: Wed, 25 Mar 2026 22:07:31 +0100 Subject: [PATCH 245/344] Train more samples --- config/config_diffusion.yml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 765d3a78a..136819aa5 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -69,7 +69,7 @@ with_step_conditioning: True # False frequency_embedding_dim: 256 embedding_dim: 512 sigma_min: 0.002 -sigma_max: 50000 +sigma_max: 80 sigma_data: 0.5 rho: 7 p_mean: -1.2 @@ -153,8 +153,8 @@ training_config: # training_mode: "masking", "student_teacher", "latent_loss" training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] - num_mini_epochs: 512 - samples_per_mini_epoch: 66 + num_mini_epochs: 128 + samples_per_mini_epoch: 512 shuffle: True start_date: 2012-06-01T00:00 @@ -224,7 +224,8 @@ validation_config: # Noise levels (eta values in standard normal space) at which to evaluate the # diffusion model during validation. sigma = exp(eta * p_std + p_mean). # Each value produces a separate validation pass with independently logged metrics. - validation_noise_levels: [0.3, 0.5, 0.75, 1.0, 1.5] + # validation_noise_levels: [0.3, 0.5, 0.75, 1.0, 1.5] + validation_noise_levels: [1.0, 2.0, 4.0, 8.0, 16.0] samples_per_mini_epoch: 16 shuffle: False From 3857b2d863dbc81805cbe5ebb7f7b1bd2ec423b9 Mon Sep 17 00:00:00 2001 From: Matthias Date: Thu, 26 Mar 2026 07:44:47 +0100 Subject: [PATCH 246/344] Untrack runs_plot_train --- config/runs_plot_train_diffusion.yml | 37 ---------------------------- 1 file changed, 37 deletions(-) delete mode 100644 config/runs_plot_train_diffusion.yml diff --git a/config/runs_plot_train_diffusion.yml b/config/runs_plot_train_diffusion.yml deleted file mode 100644 index 9ac76fd22..000000000 --- a/config/runs_plot_train_diffusion.yml +++ /dev/null @@ -1,37 +0,0 @@ -train : - plot : - gfaszkoh: - slurm_id: 0 - description: "gfaszkoh: EDM settings, single sample, lat 1, phys 0, start_lr=1e-5, max_lr=1e-4" - - # h3of5mec: - # slurm_id: 0 - # description: "h3of5mec: lat 0.9, phys 0.1, start_lr=1e-5, max_lr=1e-4, fe=2 blocks" - - # p0q1oz52: - # slurm_id: 0 - # description: "p0q1oz52: lat 0.9, phys 0.1, start_lr=1e-5, max_lr=5e-5, fe=2 blocks" - - # aabi87jc: - # slurm_id: 0 - # description: "aabi87jc: lat 0.9, phys 0.1, start_lr=1e-5, max_lr=2.5e-5, fe=2 blocks" - - # s9fsjudp: - # slurm_id: 0 - # description: "s9fsjudp: lat 0.9, phys 0.1, start_lr=1e-5, max_lr=1e-5, fe=2 blocks" - - # gbq1pxc9: - # slurm_id: 0 - # description: "gbq1pxc9: lat 0.9, phys 0.1, start_lr=1e-5, max_lr=1e-4, fe=MLP" - - # yyv2m7ir: - # slurm_id: 0 - # description: "yyv2m7ir: lat 0.9, phys 0.1, start_lr=1e-5, max_lr=5e-5, fe=MLP" - - # b5c60g4a: - # slurm_id: 0 - # description: "b5c60g4a: lat 0.9, phys 0.1, start_lr=1e-5, max_lr=2.5e-5, fe=MLP" - - # dkzr6lfq: - # slurm_id: 0 - # description: "dkzr6lfq: lat 0.9, phys 0.1, start_lr=1e-6, max_lr=1e-5, fe=MLP" \ No newline at end of file From cfce7a2b596d8d5b10739be06957f1dfab249411 Mon Sep 17 00:00:00 2001 From: Matthias Date: Fri, 27 Mar 2026 10:28:58 +0100 Subject: [PATCH 247/344] Enable DDP training --- config/config_diffusion.yml | 9 +++++---- src/weathergen/model/diffusion.py | 27 ++++++++++++++++++------- src/weathergen/model/engines.py | 4 ++-- src/weathergen/model/model.py | 10 ++++----- src/weathergen/model/model_interface.py | 8 ++++++++ src/weathergen/train/trainer.py | 10 +++++---- src/weathergen/utils/validation_io.py | 18 +++++++++++------ 7 files changed, 58 insertions(+), 28 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 136819aa5..ee6c3fac7 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -154,7 +154,7 @@ training_config: training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] num_mini_epochs: 128 - samples_per_mini_epoch: 512 + samples_per_mini_epoch: 1024 shuffle: True start_date: 2012-06-01T00:00 @@ -188,7 +188,7 @@ training_config: losses : { "physical": { type: LossPhysical, - weight: 0.1, + weight: 0.0, loss_fcts: { "mse": {}, }, @@ -196,7 +196,7 @@ training_config: }, "latent_diff": { type: LossLatentDiffusion, - weight: 0.9, + weight: 1.0, target_and_aux_calc: DiffusionLatentTargetEncoder, loss_fcts: { "mse": { }, }, } @@ -225,7 +225,8 @@ validation_config: # diffusion model during validation. sigma = exp(eta * p_std + p_mean). # Each value produces a separate validation pass with independently logged metrics. # validation_noise_levels: [0.3, 0.5, 0.75, 1.0, 1.5] - validation_noise_levels: [1.0, 2.0, 4.0, 8.0, 16.0] + validation_noise_levels: [2.0, 3.0, 3.2, 3.5, 4.0] + # validation_noise_levels: [1.0, 2.0, 4.0, 8.0, 16.0] samples_per_mini_epoch: 16 shuffle: False diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 6dd63cfd7..fde75eeda 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -95,11 +95,10 @@ def forward( f"tokens were different between iterations " f"– violates single sample overfitting {self.cur_token - tokens}" ) - self.cur_token = tokens + self.cur_token = tokens.detach() # print("input tokens statistics") # print("mean", tokens.mean(), "std", tokens.std(), "max", tokens.max(), "min", tokens.min()) - # return self.inference(fstep=fstep, num_steps=100) c = 1 # TODO: add correct preconditioning (e.g., sample/s in previous time step) @@ -116,7 +115,7 @@ def forward( sigma = (eta * self.p_std + self.p_mean).exp() n = torch.randn_like(y) * sigma - self._noised_tokens = y + n + self._noised_tokens = (y + n).detach() return self.denoise(x=y + n, c=c, sigma=sigma, fstep=fstep) @@ -153,10 +152,22 @@ def inference( # https://github.com/NVlabs/edm/blob/main/generate.py # Sample noise (assuming single batch element for now) - x = torch.randn(1, self.num_healpix_cells, self.cf.ae_global_dim_embed).to(device="cuda") * 0.1 - # x = self.cur_token + x - print("initial noise statistics") - print("mean", x.mean(), "std", x.std(), "max", x.max(), "min", x.min()) + x = torch.randn(1, self.num_healpix_cells, self.cf.ae_global_dim_embed).to(device="cuda") * 1.0 + + # eta = torch.tensor([1.0], device="cuda").float() # 1.0 (good), 2.0 (okay), 2.2 (max), 2.5 (hard) + # sigma = (eta * self.p_std + self.p_mean).exp() + # print("sigma", sigma) + # n = torch.randn_like(x).to(device="cuda") * sigma + # x = self.cur_token + n + + x = self.cur_token * 0.05 + x + + # breakpoint() + + + # return self.denoise(x=x, c=None, sigma=sigma, fstep=fstep) + # print("initial noise statistics") + # print("mean", x.mean(), "std", x.std(), "max", x.max(), "min", x.min()) # Time step discretization. step_indices = torch.arange(num_steps, dtype=torch.float64, device="cuda") @@ -176,6 +187,8 @@ def inference( t_cur = torch.tensor([t_cur], device="cuda").float() t_next = torch.tensor([t_next], device="cuda").float() + print(i, t_cur.item()) + x_cur = x_next # Increase noise temporarily. (Stochastic sampling; not used for now) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 871fd366b..250bda96c 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -511,9 +511,9 @@ def forward( ) for block in self.fe_blocks: if isinstance(block, torch.nn.LayerNorm): - tokens = block(tokens) + tokens = checkpoint(block, tokens, use_reentrant=False) else: - tokens = block(tokens, coords, noise_emb, aux_info) + tokens = checkpoint(block, tokens, coords, noise_emb, aux_info, use_reentrant=False) else: for block in self.fe_blocks: if isinstance(block, torch.nn.LayerNorm): diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 49cbf9bf6..66a962c69 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -661,10 +661,10 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # Normalize tokens # TODO: REMOVE THIS LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. - t_mean = tokens.mean() - t_std = tokens.std() - tokens = (tokens - t_mean) / (t_std + 1e-6) * self.cf.sigma_data - tokens = torch.clamp(tokens, -100.0, 100.0) + # t_mean = tokens.mean() + # t_std = tokens.std() + # tokens = (tokens - t_mean) / (t_std + 1e-6) * self.cf.sigma_data + # tokens = torch.clamp(tokens, -100.0, 100.0) # roll-out in latent space, iterate and generate output over requested output steps for step in batch.get_output_idxs(): @@ -679,7 +679,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # Un-normalize tokens # TODO: REMOVE THIS AS ABOVE. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. - tokens = tokens * (t_std + 1e-6) / self.cf.sigma_data + t_mean + # tokens = tokens * (t_std + 1e-6) / self.cf.sigma_data + t_mean # decoder predictions output = self.predict_decoders(model_params, step, tokens, batch, output) diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index be933935a..b4b3c54e7 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -342,6 +342,14 @@ def get_target_aux_calculator( with_fsdp=False, overrides=target_and_aux_calc_params.get("model_param_overrides", {}), ) + # Free components not needed by DiffusionLatentTargetEncoder (only uses the encoder) + for attr in ("forecast_engine", "pred_heads", "target_token_engines", + "embed_target_coords", "latent_heads", "latent_pre_norm"): + if hasattr(model, attr) and getattr(model, attr) is not None: + delattr(model, attr) + setattr(model, attr, None) + torch.cuda.empty_cache() + target_aux = DiffusionLatentTargetEncoder( model, is_model_sharded=(cf.with_ddp and cf.with_fsdp) ) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index ce78885c4..24c5f1990 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -687,14 +687,16 @@ def _set_validation_noise_level(self, noise_level: float | None): noise_level: The eta value (standard normal space) to fix for validation. sigma = exp(eta * p_std + p_mean). None resets to default (0.0). """ + # Unwrap DDP/FSDP to access the underlying model + base_model = getattr(self.model, "module", self.model) # Set on the base model - if hasattr(self.model, "forecast_engine") and hasattr( - self.model.forecast_engine, "_fixed_noise_level" + if hasattr(base_model, "forecast_engine") and hasattr( + base_model.forecast_engine, "_fixed_noise_level" ): - self.model.forecast_engine._fixed_noise_level = noise_level + base_model.forecast_engine._fixed_noise_level = noise_level # Also set on the EMA model (separate model copy used during validation) if self.ema_model is not None: - ema_net = self.ema_model.ema_model + ema_net = getattr(self.ema_model.ema_model, "module", self.ema_model.ema_model) if hasattr(ema_net, "forecast_engine") and hasattr( ema_net.forecast_engine, "_fixed_noise_level" ): diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index deff62042..30074e207 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -338,7 +338,7 @@ def write_output( epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}_{sample}" # Add noise_level_rn to title if present for this stream if noise_level is not None: - eta_str = re.sub(r'e[+]?0*(?=\d)', 'e', re.sub(r'e-0*(?=\d)', 'e-', f'{noise_level:.0e}')) + eta_str = str(noise_level) else: eta_str = None eta_tag = f"_eta{eta_str}" if eta_str is not None else "" @@ -359,8 +359,11 @@ def write_output( ) src = channel_dir / f"{plot_name}.{plotter.image_format}" dst = channel_dir / f"{epoch_tag}.{plotter.image_format}" - if src != dst and src.exists(): - src.replace(dst) + if src != dst: + try: + src.replace(dst) + except (FileNotFoundError, OSError): + pass # another rank already renamed or removed the file del sample_da, vals, sample_lat, sample_lon, valid @@ -437,7 +440,7 @@ def write_output( epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}_{sample}_noised" if noise_level is not None: - eta_str = re.sub(r'e[+]?0*(?=\d)', 'e', re.sub(r'e-0*(?=\d)', 'e-', f'{noise_level:.0e}')) + eta_str = str(noise_level) else: eta_str = None eta_tag = f"_eta{eta_str}" if eta_str is not None else "" @@ -458,8 +461,11 @@ def write_output( ) src = channel_dir / f"{plot_name}.{plotter.image_format}" dst = channel_dir / f"{epoch_tag}.{plotter.image_format}" - if src != dst and src.exists(): - src.replace(dst) + if src != dst: + try: + src.replace(dst) + except (FileNotFoundError, OSError): + pass # another rank already renamed or removed the file del sample_da, vals, sample_lat, sample_lon, valid From c2029c52c6511ae2869ab6ae7f5751124cfd6355 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Fri, 27 Mar 2026 15:44:56 +0100 Subject: [PATCH 248/344] inter commit --- config/config_diffusion.yml | 3 +- src/weathergen/datasets/batch.py | 7 ++ src/weathergen/datasets/masking.py | 4 + src/weathergen/model/attention.py | 7 +- src/weathergen/model/diffusion.py | 65 +++++++++++++- src/weathergen/model/engines.py | 3 +- src/weathergen/model/model.py | 13 ++- src/weathergen/model/norms.py | 86 +++++-------------- .../train/target_and_aux_diffusion.py | 1 + src/weathergen/train/trainer.py | 2 + src/weathergen/utils/validation_io.py | 1 + 11 files changed, 119 insertions(+), 73 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 8cc190eaa..b951153b9 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -66,6 +66,7 @@ fe_impute_latent_noise_std: 0.0 # 1e-4 forecast_att_dense_rate: 1.0 with_step_conditioning: True # False # Diffusion related parameters +diffusion_conditioning_embed_dim: 4 frequency_embedding_dim: 256 embedding_dim: 512 sigma_min: 0.002 @@ -137,7 +138,7 @@ train_logging: # parameters for data loading data_loading : - num_workers: 12 + num_workers: 0 #12 rng_seed: ??? repeat_data_in_mini_epoch : True diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index d106feb08..0685957f0 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -25,6 +25,11 @@ class SampleMetaData: global_params: dict | None = None + def add_global_params(self, params: dict) -> None: + if self.global_params is None: + self.global_params = {} + self.global_params.update(params) + class Sample: # keys: stream name, values: SampleMetaData @@ -91,6 +96,7 @@ def add_meta_info(self, stream_name: str, meta_info: SampleMetaData) -> None: """ Add metadata for stream @stream_name to sample """ + print(meta_info.__dict__) self.meta_info[stream_name] = meta_info def get_stream_data(self, stream_name: str) -> StreamData: @@ -274,6 +280,7 @@ def add_source_stream( """ Add data for one stream to sample @source_sample_idx """ + self.source_samples.samples[source_sample_idx].add_stream_data(stream_name, stream_data) # add the meta_info diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index f2b8d6621..abe959cc1 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -37,6 +37,8 @@ def add_mask(self, mask, params, cfg, losses, idx, correspondence, relationship) } if "noise_level_rn" in params: global_params["noise_level_rn"] = params["noise_level_rn"] + print(mask) + print(params) self.masks += [mask] self.metadata += [ SampleMetaData( @@ -471,6 +473,8 @@ def _generate_cell_mask( if "diffusion_rn" in masking_strategy_config: masking_params["noise_level_rn"] = self.rng.normal(0.0, 1.0) + + elif strategy == "healpix": # prepare healpix-based masking keep_rate = self._get_sampling_rate(masking_strategy_config) diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 9ce6ccc75..004be6e84 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -14,7 +14,7 @@ from torch.nn.attention.flex_attention import create_block_mask, flex_attention from weathergen.model.layers import LinearNormConditioning -from weathergen.model.norms import AdaLayerNorm, RMSNorm +from weathergen.model.norms import AdaLayerNorm, AdaLayerNormFinal, RMSNorm from weathergen.model.positional_encoding import rotary_pos_emb_2d """ @@ -539,7 +539,8 @@ def __init__( norm = RMSNorm if dim_aux is not None: - self.lnorm = AdaLayerNorm(dim_embed, dim_aux, norm_eps=norm_eps) + self.lnorm = AdaLayerNorm(dim_embed, dim_aux, norm_eps=norm_eps) #should be initialised to zero + self.lnorm_final = AdaLayerNormFinal(dim_embed, dim_aux, norm_eps=norm_eps) #should be initialised to zero else: self.lnorm = norm(dim_embed, eps=norm_eps) self.proj_heads_q = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False) @@ -597,7 +598,7 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None): out = self.proj_out(outs.flatten(-2, -1)) if self.with_residual: - out = out + x_in * gate if self.noise_conditioning else out + x_in + out = x_in + out * gate if self.noise_conditioning else out + x_in return out diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 0a1834be2..b14092815 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -25,6 +25,7 @@ import logging import math +import numpy as np import torch @@ -49,6 +50,7 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast self.noise_embedder = NoiseEmbedder( embedding_dim=self.embedding_dim, frequency_embedding_dim=self.frequency_embedding_dim ) + self.datetime_embedder = DateTimeEncoder() # Parameters self.sigma_min = self.cf.sigma_min @@ -82,7 +84,7 @@ def forward( self.cur_token = tokens - c = 1 # TODO: add correct preconditioning (e.g., sample/s in previous time step) + c = torch.tensor([meta_info["ERA5"].params["datetime"]], device=tokens.device) # TODO: add correct preconditioning (e.g., sample/s in previous time step, datetime encoding, etc.) y = tokens if self.training: @@ -116,10 +118,11 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int) -> noise_emb = self.noise_embedder(c_noise) # Precondition input and feed through network - x = self.preconditioner.precondition(x, c) + x = self.preconditioner.precondition(x, c) #currently does nothing + c = self.datetime_embedder(c) return c_skip * x + c_out * self.net( - c_in * x, fstep=fstep, noise_emb=noise_emb + c_in * x, fstep=fstep, noise_emb=noise_emb, ada_ln_aux=c ) # Eq. (7) in EDM paper def inference( @@ -221,3 +224,59 @@ def forward(self, t: float): t_freq = self.timestep_embedding(t) t_emb = self.mlp(t_freq) return t_emb + +class DateTimeEncoder(torch.nn.Module): + """ + Encodes timestamp(s) in seconds since Unix epoch into a 4D vector: + [time_of_day_sin, time_of_day_cos, day_of_year_sin, day_of_year_cos] + + Input shape: scalar or any tensor shape (...) + Output shape: (..., 4) + """ + + def __init__(self): + super().__init__() + + def forward(self, timestamp: torch.Tensor | np.ndarray) -> torch.Tensor: + """ + Encode datetime64 timestamp into a 4D vector: + [time_of_day_sin, time_of_day_cos, day_of_year_sin, day_of_year_cos] + + Input: np.datetime64 or torch.Tensor containing datetime64 values + Output: (..., 4) shaped tensor + """ + # Convert to numpy if needed + if isinstance(timestamp, torch.Tensor): + timestamp = timestamp.detach().cpu().numpy() + + # Ensure datetime64[s] precision + timestamp = timestamp.astype('datetime64[s]') + orig_shape = timestamp.shape + timestamp_flat = timestamp.reshape(-1) + + two_pi = 2.0 * np.pi + + # --- Time of day from seconds since epoch --- + ts_int64 = timestamp_flat.astype('int64') # seconds since Unix epoch + seconds_in_day = 86400.0 + time_of_day = (ts_int64 % int(seconds_in_day)) / seconds_in_day + tod_sin = np.sin(two_pi * time_of_day).astype(np.float32) + tod_cos = np.cos(two_pi * time_of_day).astype(np.float32) + + # --- Day of year --- + day_np = timestamp_flat.astype('datetime64[D]') + year_start = day_np.astype('datetime64[Y]').astype('datetime64[D]') + next_year_start = (day_np.astype('datetime64[Y]') + np.timedelta64(1, 'Y')).astype('datetime64[D]') + + day_of_year_0 = (day_np - year_start).astype(np.int64) + days_in_year = (next_year_start - year_start).astype(np.int64) + doy_frac = day_of_year_0.astype(np.float32) / days_in_year.astype(np.float32) + + doy_sin = np.sin(two_pi * doy_frac).astype(np.float32) + doy_cos = np.cos(two_pi * doy_frac).astype(np.float32) + + # Stack and convert to tensor + out = np.stack([tod_sin, tod_cos, doy_sin, doy_cos], axis=-1) + out = torch.from_numpy(out).float() + + return out.reshape(*orig_shape, 4) \ No newline at end of file diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 871fd366b..dbc8c36fe 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -489,6 +489,7 @@ def forward( fstep: int, meta_info: SampleMetaData = None, noise_emb: torch.Tensor = None, + ada_ln_aux: torch.Tensor = None, coords: torch.Tensor = None, ) -> torch.Tensor: # aux_info is forecast step, if not disabled with cf.forecast_with_step_conditioning @@ -513,7 +514,7 @@ def forward( if isinstance(block, torch.nn.LayerNorm): tokens = block(tokens) else: - tokens = block(tokens, coords, noise_emb, aux_info) + tokens = block(tokens, coords, noise_emb, ada_ln_aux) else: for block in self.fe_blocks: if isinstance(block, torch.nn.LayerNorm): diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 49cbf9bf6..95f7107c1 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -389,11 +389,13 @@ def create(self) -> "Model": mode_cfg = cf.training_config self.forecast_engine = None if cf.fe_num_blocks > 0: - self.forecast_engine = ForecastingEngine(cf, mode_cfg, self.num_healpix_cells) if cf.get("fe_diffusion_model", False): + self.forecast_engine = ForecastingEngine(cf, mode_cfg, self.num_healpix_cells, dim_aux=self.cf.diffusion_conditioning_embed_dim) self.forecast_engine = DiffusionForecastEngine( cf, self.num_healpix_cells, forecast_engine=self.forecast_engine ) + else: + self.forecast_engine = ForecastingEngine(cf, mode_cfg, self.num_healpix_cells) # embed coordinates yielding one query token for each target token dropout_rate = cf.embed_dropout_rate @@ -666,10 +668,19 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: tokens = (tokens - t_mean) / (t_std + 1e-6) * self.cf.sigma_data tokens = torch.clamp(tokens, -100.0, 100.0) + breakpoint() + # roll-out in latent space, iterate and generate output over requested output steps for step in batch.get_output_idxs(): # apply forecasting engine (if present) if self.forecast_engine: + + #TODO: move this to the appropriate place in batch consturction + batch.samples[0].meta_info['ERA5'].add_global_params({'datetime': batch.samples[0].streams_data['ERA5'].source_raw[0].datetimes[0]}) + print(f'added {batch.samples[0].streams_data['ERA5'].source_raw[0].datetimes[0]}') + # add_global_params({'datetime': batch.samples[0].streams_data[self.stream_names[0]].source_raw[0].datetimes[0]}) + + breakpoint() tokens = self.forecast_engine( tokens, step, diff --git a/src/weathergen/model/norms.py b/src/weathergen/model/norms.py index 4ecbfa80a..cbb5f7ba2 100644 --- a/src/weathergen/model/norms.py +++ b/src/weathergen/model/norms.py @@ -87,76 +87,34 @@ def forward(self, x: torch.Tensor, aux: torch.Tensor | None = None) -> torch.Ten x = self.norm(x) * (1 + scale) + shift return x - - -def modulate(x, shift, scale): - return x * (1 + scale) + shift - - -class SwiGLU(nn.Module): - def __init__(self): - super(SwiGLU, self).__init__() - - def forward(self, x): - x1, x2 = x.chunk(2, dim=-1) - return x2 * F.silu(x1) - - -class AdaLayerNormLayer(torch.nn.Module): + +class AdaLayerNormFinal(torch.nn.Module): """ - AdaLayerNorm for embedding auxiliary information as done in DiT (Peebles & Xie) with zero - initialisation https://arxiv.org/pdf/2212.09748 - - This module thus wraps a layer (e.g. self-attention or feedforward nn) and applies LayerNorm - followed by scale and shift before the layer and a final scaling after the layer as well as the - final residual layer. - - layer is a function that takes 2 arguments the first the latent and the second is the - conditioning signal + AdaLayerNorm from DiT for the final output gate only, i.e. only scale """ - + def __init__( - self, - dim, - dim_aux, - layer, - norm_eps: float = 1e-6, - dropout_rate: float = 0.0, + self, dim_embed_x, dim_aux, norm_elementwise_affine: bool = False, norm_eps: float = 1e-5 ): super().__init__() - self.dim = dim - self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_aux, 3 * dim, bias=True)) - - self.ln = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps) - self.layer = layer - - # Initialize weights to zero for modulation and gating layers - self.initialise_weights() - - def initialise_weights(self): - nn.init.zeros_(self.adaLN_modulation[-1].weight) - nn.init.zeros_(self.adaLN_modulation[-1].bias) - - def forward(self, x: torch.Tensor, c: torch.Tensor, x_lens, **kwargs) -> torch.Tensor: - # the -1 in torch.repeat_interleave(..) is because x_lens is designed for use with flash - # attention and thus has a spurious 0 at the beginning to satisfy the flash attention api - shift, scale, gate = self.adaLN_modulation(c)[torch.repeat_interleave(x_lens) - 1].chunk( - 3, dim=1 - ) - kwargs["x_lens"] = x_lens - return ( - gate - * self.layer( - modulate( - self.ln(x), - shift, - scale, - ), - **kwargs, - ) - + x - ) + # simple 2-layer MLP for embedding auxiliary information + self.embed_aux = torch.nn.ModuleList() + self.embed_aux.append(torch.nn.Linear(dim_aux, 4 * dim_aux)) + self.embed_aux.append(torch.nn.SiLU()) + self.embed_aux.append(torch.nn.Linear(4 * dim_aux, dim_embed_x)) + + self.norm = torch.nn.LayerNorm(dim_embed_x, norm_eps, norm_elementwise_affine) + + def forward(self, x: torch.Tensor, aux: torch.Tensor | None = None) -> torch.Tensor: + for block in self.embed_aux: + aux = block(aux) + scale = aux + + x = self.norm(x) * (1 + scale) + + return x + class SaturateEncodings(nn.Module): diff --git a/src/weathergen/train/target_and_aux_diffusion.py b/src/weathergen/train/target_and_aux_diffusion.py index a857bce0b..eb0f905fb 100644 --- a/src/weathergen/train/target_and_aux_diffusion.py +++ b/src/weathergen/train/target_and_aux_diffusion.py @@ -58,6 +58,7 @@ def compute( ) -> tuple[Any, Any]: # During validation (model in eval mode), use fixed noise level # so that sigma = exp(eta * p_std + p_mean) is deterministic + if model.training: noise_level_rn = ( batch.samples[0].meta_info["ERA5"].params["noise_level_rn"] diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index ce78885c4..3c77bdbe8 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -434,6 +434,8 @@ def train(self, mini_epoch): batch.to_device(self.device) + breakpoint() + with torch.autocast( device_type=f"cuda:{cf.local_rank}", dtype=self.mixed_precision_dtype, diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 9d17d79c5..64df4905a 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -94,6 +94,7 @@ def write_output( if loss_term.type == "LossPhysical" ] assert len(outputs_physical) == 1 + breakpoint() target_aux_out = target_aux_out[outputs_physical[0]] # collect all target / prediction-related information From 30fed1b614e8490c9ce4ff72eb15c2579d750965 Mon Sep 17 00:00:00 2001 From: Matthias Date: Mon, 30 Mar 2026 13:42:27 +0200 Subject: [PATCH 249/344] Add z500 only configs --- config/config_diffusion.yml | 8 +- config/config_diffusion_tiny.yml | 278 ++++++++++++++++++ .../streams/era5_1deg_diffusion_tiny/era5.yml | 38 +++ src/weathergen/model/attention.py | 3 +- src/weathergen/model/diffusion.py | 2 +- 5 files changed, 323 insertions(+), 6 deletions(-) create mode 100644 config/config_diffusion_tiny.yml create mode 100644 config/streams/era5_1deg_diffusion_tiny/era5.yml diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index ee6c3fac7..19c42f396 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -69,8 +69,8 @@ with_step_conditioning: True # False frequency_embedding_dim: 256 embedding_dim: 512 sigma_min: 0.002 -sigma_max: 80 -sigma_data: 0.5 +sigma_max: 80 # 170 +sigma_data: 0.5 # 1.7 rho: 7 p_mean: -1.2 p_std: 1.2 @@ -153,7 +153,7 @@ training_config: # training_mode: "masking", "student_teacher", "latent_loss" training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] - num_mini_epochs: 128 + num_mini_epochs: 512 samples_per_mini_epoch: 1024 shuffle: True @@ -225,7 +225,7 @@ validation_config: # diffusion model during validation. sigma = exp(eta * p_std + p_mean). # Each value produces a separate validation pass with independently logged metrics. # validation_noise_levels: [0.3, 0.5, 0.75, 1.0, 1.5] - validation_noise_levels: [2.0, 3.0, 3.2, 3.5, 4.0] + validation_noise_levels: [2.0, 3.0, 3.2, 3.5, 4.0, 5.0] # validation_noise_levels: [1.0, 2.0, 4.0, 8.0, 16.0] samples_per_mini_epoch: 16 diff --git a/config/config_diffusion_tiny.yml b/config/config_diffusion_tiny.yml new file mode 100644 index 000000000..4e34d6ce2 --- /dev/null +++ b/config/config_diffusion_tiny.yml @@ -0,0 +1,278 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 128 +ae_local_num_blocks: 0 +ae_local_num_heads: 8 +ae_local_dropout_rate: 0.0 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 8 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.0 + +ae_global_dim_embed: 128 +ae_global_num_blocks: 4 +ae_global_num_heads: 8 +ae_global_dropout_rate: 0.0 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 4 +ae_aggregation_dropout_rate: 0.0 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 4 +fe_num_heads: 4 +fe_dropout_rate: 0.0 +fe_with_qk_lnorm: True +fe_diffusion_model: True +fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 0.0 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 +with_step_conditioning: True # False +# Diffusion related parameters +frequency_embedding_dim: 256 +embedding_dim: 512 +sigma_min: 0.002 +sigma_max: 50000 +sigma_data: 0.5 +rho: 7 +p_mean: 0.0 # -1.2 +p_std: 1.2 # 1.2 + +healpix_level: 3 + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: False +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + + +freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +load_chkpt: {'run_id': 'lgrasnq6', 'epoch': -1} + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/era5_1deg_diffusion_tiny/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + log_grad_norms: False + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : True + + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] + + num_mini_epochs: 150 + samples_per_mini_epoch: 66 + shuffle: True + + start_date: 2012-06-01T00:00 + end_date: 2012-06-01T18:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 #5e-5 + lr_max: 5e-5 #1e-4 + lr_final_decay: 1e-6 + lr_final: 0.0 + num_steps_warmup: 64 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.975 + beta2 : 0.9875 + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + weight: 0.1, + loss_fcts: { + "mse": {}, + }, + target_and_aux_calc: "Physical", + }, + "latent_diff": { + type: LossLatentDiffusion, + weight: 0.9, + target_and_aux_calc: DiffusionLatentTargetEncoder, + loss_fcts: { "mse": { }, }, + } + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_samples: 3 + } + } + + forecast : + time_step: 06:00:00 + num_steps: 1 + offset: 0 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + samples_per_mini_epoch: 16 + shuffle: False + + start_date: 2012-06-01T00:00 + end_date: 2012-06-01T18:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 1, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: True + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/streams/era5_1deg_diffusion_tiny/era5.yml b/config/streams/era5_1deg_diffusion_tiny/era5.yml new file mode 100644 index 000000000..96b3aa6a1 --- /dev/null +++ b/config/streams/era5_1deg_diffusion_tiny/era5.yml @@ -0,0 +1,38 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5 : + type : anemoi + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + stream_id : 0 + source : ["z_500"] + target : ["z_500"] + loss_weight : 1. + location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 8 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 4 + dim_embed : 32 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 32 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 \ No newline at end of file diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 9ce6ccc75..0462bd3b5 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -597,7 +597,8 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None): out = self.proj_out(outs.flatten(-2, -1)) if self.with_residual: - out = out + x_in * gate if self.noise_conditioning else out + x_in + # out = out + x_in * gate if self.noise_conditioning else out + x_in + out = x_in + out * gate if self.noise_conditioning else x_in + out return out diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index fde75eeda..84b687c7d 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -160,7 +160,7 @@ def inference( # n = torch.randn_like(x).to(device="cuda") * sigma # x = self.cur_token + n - x = self.cur_token * 0.05 + x + x = self.cur_token * 0.01 + x # breakpoint() From 96a074be81dfd2cdfa2563b6392f0427a374e71d Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Tue, 31 Mar 2026 10:48:38 +0200 Subject: [PATCH 250/344] Overfitting to constant random noise fails --- config/config_diffusion_tiny.yml | 35 +++++++++++-------- src/weathergen/datasets/masking.py | 1 + src/weathergen/model/diffusion.py | 32 ++++++++++------- src/weathergen/model/engines.py | 11 ++++-- src/weathergen/model/layers.py | 14 +++++--- .../loss_module_latent_diffusion.py | 17 ++++++++- src/weathergen/utils/validation_io.py | 3 +- 7 files changed, 77 insertions(+), 36 deletions(-) diff --git a/config/config_diffusion_tiny.yml b/config/config_diffusion_tiny.yml index 4e34d6ce2..f4bef3bcd 100644 --- a/config/config_diffusion_tiny.yml +++ b/config/config_diffusion_tiny.yml @@ -60,7 +60,7 @@ fe_num_heads: 4 fe_dropout_rate: 0.0 fe_with_qk_lnorm: True fe_diffusion_model: True -fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # currently fixed to 1.0 (due to limitations with flex_attention and triton) forecast_att_dense_rate: 1.0 @@ -69,11 +69,11 @@ with_step_conditioning: True # False frequency_embedding_dim: 256 embedding_dim: 512 sigma_min: 0.002 -sigma_max: 50000 -sigma_data: 0.5 +sigma_max: 80 # 170 +sigma_data: 0.7855 # 0.5 # 1.7 rho: 7 -p_mean: 0.0 # -1.2 -p_std: 1.2 # 1.2 +p_mean: -1.2 +p_std: 1.2 healpix_level: 3 @@ -152,8 +152,8 @@ training_config: # training_mode: "masking", "student_teacher", "latent_loss" training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] - num_mini_epochs: 150 - samples_per_mini_epoch: 66 + num_mini_epochs: 128 + samples_per_mini_epoch: 1024 shuffle: True start_date: 2012-06-01T00:00 @@ -175,8 +175,8 @@ training_config: parallel_scaling_policy: "sqrt" optimizer: - grad_clip: 1.0 - weight_decay: 0.1 + grad_clip: 10.0 # 1.0 + weight_decay: 0.0 # 0.1 log_grad_norms: False adamw : # parameters are scaled by number of DDP workers @@ -187,7 +187,7 @@ training_config: losses : { "physical": { type: LossPhysical, - weight: 0.1, + weight: 0.0, loss_fcts: { "mse": {}, }, @@ -195,7 +195,7 @@ training_config: }, "latent_diff": { type: LossLatentDiffusion, - weight: 0.9, + weight: 1.0, target_and_aux_calc: DiffusionLatentTargetEncoder, loss_fcts: { "mse": { }, }, } @@ -206,7 +206,7 @@ training_config: # masking strategy: "random", "healpix", "forecast" masking_strategy: "forecast", masking_strategy_config: {diffusion_rn: True}, - num_samples: 3 + num_samples: 1 } } @@ -220,7 +220,14 @@ training_config: # validation config; full validation config is merge of training and validation config validation_config: - samples_per_mini_epoch: 16 + # Noise levels (eta values in standard normal space) at which to evaluate the + # diffusion model during validation. sigma = exp(eta * p_std + p_mean). + # Each value produces a separate validation pass with independently logged metrics. + # validation_noise_levels: [0.3, 0.5, 0.75, 1.0, 1.5] + validation_noise_levels: [1.0, 2.0, 3.0, 4.0] + # validation_noise_levels: [1.0, 2.0, 4.0, 8.0, 16.0] + + samples_per_mini_epoch: 8 shuffle: False start_date: 2012-06-01T00:00 @@ -243,7 +250,7 @@ validation_config: } # run validation before training starts (mainly for model development) - validate_before_training: True + validate_before_training: False # test config; full test config is merge of validation and test config diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index f2b8d6621..6be8db6c5 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -470,6 +470,7 @@ def _generate_cell_mask( if "diffusion_rn" in masking_strategy_config: masking_params["noise_level_rn"] = self.rng.normal(0.0, 1.0) + # masking_params["noise_level_rn"] = self.rng.uniform(-1.0, 5.0) elif strategy == "healpix": # prepare healpix-based masking diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 84b687c7d..5e9e19772 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -61,6 +61,8 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast self._noised_tokens: torch.Tensor | None = None self._fixed_noise_level: float | None = None + self._noise = None + def forward( self, tokens: torch.Tensor, @@ -97,6 +99,8 @@ def forward( ) self.cur_token = tokens.detach() + # return tokens + # print("input tokens statistics") # print("mean", tokens.mean(), "std", tokens.std(), "max", tokens.max(), "min", tokens.min()) # return self.inference(fstep=fstep, num_steps=100) @@ -117,7 +121,12 @@ def forward( self._noised_tokens = (y + n).detach() - return self.denoise(x=y + n, c=c, sigma=sigma, fstep=fstep) + # return self.denoise(x=y + n, c=c, sigma=sigma, fstep=fstep) + #n = torch.ones_like(y) + if self._noise is None: + self._noise = torch.randn_like(y) + n = self._noise + return self.denoise(x=n, c=c, sigma=sigma, fstep=fstep) def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int) -> torch.Tensor: """ @@ -125,23 +134,23 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int) -> consideration of a conditioning c (e.g., previous time steps) and the current diffusion noise level sigma. """ - # Compute scaling conditionings - c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) - c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() - c_in = 1 / (sigma**2 + self.sigma_data**2).sqrt() + # # Compute scaling conditionings (EDM Eq. 7 — disabled for direct prediction) + # c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + # c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + # c_in = 1 / (sigma**2 + self.sigma_data**2).sqrt() c_noise = sigma.log() / 4 # Embed noise level noise_emb = self.noise_embedder(c_noise) - # print("sigma", sigma) - # Precondition input and feed through network x = self.preconditioner.precondition(x, c) - return c_skip * x + c_out * self.net( - c_in * x, fstep=fstep, noise_emb=noise_emb - ) # Eq. (7) in EDM paper + # Direct prediction: network outputs denoised estimate directly + return self.net(x, fstep=fstep, noise_emb=noise_emb) + # return c_skip * x + c_out * self.net( + # c_in * x, fstep=fstep, noise_emb=noise_emb + # ) # Eq. (7) in EDM paper def inference( self, @@ -160,8 +169,7 @@ def inference( # n = torch.randn_like(x).to(device="cuda") * sigma # x = self.cur_token + n - x = self.cur_token * 0.01 + x - + # x = self.cur_token * 0.01 + x # breakpoint() diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 250bda96c..a64906f9e 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -471,14 +471,19 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = # norm_type=self.cf.norm_type, # dim_aux=dim_aux, # norm_eps=self.cf.mlp_norm_eps, - # with_noise_conditioning=False, + # with_noise_conditioning=True, # ) # ) + # def init_weights_final(m): + # if isinstance(m, torch.nn.Linear): + # torch.nn.init.normal_(m.weight, mean=0, std=0.001) + # if m.bias is not None: + # torch.nn.init.normal_(m.bias, mean=0, std=0.001) def init_weights_final(m): if isinstance(m, torch.nn.Linear): - torch.nn.init.normal_(m.weight, mean=0, std=0.001) + torch.nn.init.normal_(m.weight, mean=0, std=0.1) if m.bias is not None: - torch.nn.init.normal_(m.bias, mean=0, std=0.001) + torch.nn.init.normal_(m.bias, mean=0, std=0.1) for block in self.fe_blocks: block.apply(init_weights_final) diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index e85acc6c7..4b968c7de 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -110,15 +110,19 @@ def forward(self, *args): aux = args[1] elif len(args) > 2: aux = args[-1] - noise_emb = args[1] if self.with_noise_conditioning else None + noise_emb = args[2] if self.with_noise_conditioning else None + gate = None for i, layer in enumerate(self.layers): - if isinstance(layer, LinearNormConditioning): - x = layer(x, noise_emb) # noise embedding - else: - x = layer(x, aux) if (i == 0 and self.with_aux) else layer(x) + x = layer(x, aux) if (i == 0 and self.with_aux) else layer(x) + # Apply noise conditioning after layer norm (first layer), mirroring + # the AdaLN-Zero pattern used in MultiSelfAttentionHead + if i == 0 and self.with_noise_conditioning: + x, gate = self.noise_conditioning(x, noise_emb) if self.with_residual: + if gate is not None: + x = x * gate if x.shape[-1] == x_in.shape[-1]: x = x_in + x else: diff --git a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py index 7f8ab0357..dcdee3059 100644 --- a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py +++ b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py @@ -56,6 +56,8 @@ def __init__( for name, params in loss_fcts.items() ] + self.random_target = None + def _get_noise_weight(self, eta): sigma = (eta * self.p_std + self.p_mean).exp() return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 @@ -101,7 +103,8 @@ def compute_loss(self, preds: dict, targets: dict, **kwargs) -> LossValues: fsteps = len(target_tokens_all) # During validation, use unweighted loss (no noise-level scaling) - noise_weight = 1.0 if self.stage == "val" else self._get_noise_weight(eta) + # noise_weight = 1.0 if self.stage == "val" else self._get_noise_weight(eta) + noise_weight = 1.0 fstep_loss_weights = self._get_fstep_weights(fsteps) loss_fsteps = torch.tensor(0.0, device=self.device, requires_grad=True) @@ -115,6 +118,18 @@ def compute_loss(self, preds: dict, targets: dict, **kwargs) -> LossValues: # if forecast_offset==0, then the timepoints correspond. # Otherwise targets don't encode the source timestep, so we don't need to skip for loss_fct, loss_fct_weight, loss_fct_name in self.loss_fcts: + + + + # Try random fixed target + if self.random_target is None: + # self.random_target = torch.randn_like(target_tokens) * 1.0 + 1 + self.random_target = torch.ones_like(target_tokens) + target_tokens = self.random_target + + print("pred std", pred_tokens.std().item(), "pred mean", pred_tokens.mean().item()) + print("trgt std", target_tokens.std().item(), "trgt mean", target_tokens.mean().item(), ) + loss_lfct = self._loss_per_loss_function( loss_fct, target=target_tokens, diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 30074e207..f6dd1fd47 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -263,7 +263,8 @@ def write_output( base_plot_dir.mkdir(parents=True, exist_ok=True) plotter = Plotter({"image_format": "png", "dpi_val": 150}, base_plot_dir) # headline_channels = {"2t", "z500", "q850", "10u", "10v"} - headline_channels = {"2t", "q850"} + # headline_channels = {"2t", "q850"} + headline_channels = {"z500"} t_idx = 0 for stream_idx, stream_info in enumerate(cf.streams): From 2490a31071cf57b2d2349ea853c34c51c347b570 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Tue, 31 Mar 2026 12:03:46 +0200 Subject: [PATCH 251/344] implement date-time conditioning data flow --- config/config_diffusion.yml | 13 +-- src/weathergen/datasets/batch.py | 2 +- .../datasets/multi_stream_data_sampler.py | 31 ++++++- src/weathergen/model/attention.py | 28 ++++-- src/weathergen/model/blocks.py | 12 +-- src/weathergen/model/diffusion.py | 88 ++++++++++++------- src/weathergen/model/engines.py | 7 +- src/weathergen/model/model.py | 10 +-- src/weathergen/model/norms.py | 33 +++++++ src/weathergen/train/trainer.py | 2 - src/weathergen/utils/validation_io.py | 2 +- 11 files changed, 162 insertions(+), 66 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index b951153b9..822ef0713 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -60,13 +60,14 @@ fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True fe_diffusion_model: True +fe_diffusion_model_conditioning: "date_time" # options: "date_time", "forecast_step", "none" fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) forecast_att_dense_rate: 1.0 with_step_conditioning: True # False # Diffusion related parameters -diffusion_conditioning_embed_dim: 4 +diffusion_conditioning_embed_dim: 32 # Multi-frequency calendar embedding (8 frequencies × 4 components) frequency_embedding_dim: 256 embedding_dim: 512 sigma_min: 0.002 @@ -138,7 +139,7 @@ train_logging: # parameters for data loading data_loading : - num_workers: 0 #12 + num_workers: 12 rng_seed: ??? repeat_data_in_mini_epoch : True @@ -159,7 +160,7 @@ training_config: shuffle: True start_date: 1979-06-01T00:00 - end_date: 2022-12-31T18:00 + end_date: 1979-06-01T18:00 time_window_step: 06:00:00 time_window_len: 06:00:00 @@ -231,9 +232,9 @@ validation_config: samples_per_mini_epoch: 1 shuffle: True # TODO: Set back to False - start_date: 2023-10-01T00:00 - end_date: 2023-12-31T18:00 - + start_date: 1979-06-01T00:00 + end_date: 1979-06-01T18:00 + # whether to track the exponential moving average of weights for validation validate_with_ema: enabled : True diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index 0685957f0..63092d614 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -96,7 +96,7 @@ def add_meta_info(self, stream_name: str, meta_info: SampleMetaData) -> None: """ Add metadata for stream @stream_name to sample """ - print(meta_info.__dict__) + self.meta_info[stream_name] = meta_info def get_stream_data(self, stream_name: str) -> StreamData: diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index cf545703d..11bd6ad61 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -92,6 +92,8 @@ def __init__( self.rank = cf.rank self.world_size = cf.world_size + self.diffusion_model_conditioning = cf.fe_diffusion_model_conditioning + self.healpix_level: int = cf.healpix_level self.num_healpix_cells: int = 12 * 4**self.healpix_level @@ -587,6 +589,7 @@ def _get_source_target_masks(self, training_mode): self.mode_cfg, stream_info, ) + # identical for all streams num_target_samples = len(masks[stream_info["name"]][0]) num_source_samples = len(masks[stream_info["name"]][1]) @@ -623,7 +626,7 @@ def _get_batch(self, idx: int, num_forecast_steps: int): # get/coordinate masks masks_streams, num_source_samples, num_target_samples = self._get_source_target_masks(mode) - + source_select, target_select = [], [] if "masking" in mode: source_select += ["network_input", "target_coords"] @@ -707,6 +710,15 @@ def _get_batch(self, idx: int, num_forecast_steps: int): input_mask=target_mask, ) target_metadata = target_masks.metadata[tidx] + # Add output timestamp to metadata - use actual target times from data + if target_metadata.params is None: + target_metadata.params = {} + # Get first target step's times (using self.output_offset as the first output step index) + if self.diffusion_model_conditioning == "date_time": + target_times_array = sdata.target_times_raw[self.output_offset] + target_metadata.params['timestamp'] = ( + target_times_array[0] if len(target_times_array) > 0 else None + ) # also want to add the mask to the metadata target_metadata.mask = target_mask # Map target to all source students @@ -720,6 +732,23 @@ def _get_batch(self, idx: int, num_forecast_steps: int): target_in_steps = 1 if len(target_in_steps) == 0 else target_in_steps.max().item() batch = self._preprocess_model_batch(batch, source_in_steps, target_in_steps) + #add target times in source for diffusion model date/time conditioning + if self.diffusion_model_conditioning == "date_time": + #TODO: Might need upgrading fro num_samples > 1 + + # Assert singular source and target samples + assert len(batch.source_samples.samples) == 1, "Only single source sample supported for diffusion model conditioning." + assert len(batch.target_samples.samples) == 1, "Only single target sample supported for diffusion model conditioning." + + source_sample = batch.source_samples.samples[0] + target_sample = batch.target_samples.samples[0] + + # Copy target timestamps to source metadata for all streams + for stream_name in [s["name"] for s in self.streams]: + if stream_name in target_sample.meta_info and stream_name in source_sample.meta_info: + target_timestamp = target_sample.meta_info[stream_name].params.get('timestamp') + source_sample.meta_info[stream_name].params['timestamp'] = target_timestamp + return batch def __iter__(self) -> ModelBatch: diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 004be6e84..6e8a07e3b 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -234,6 +234,7 @@ def __init__( if dim_aux is not None: self.lnorm = AdaLayerNorm(dim_embed, dim_aux, norm_eps=norm_eps) + self.lnorm_final = AdaLayerNormFinal(dim_embed, dim_aux, norm_eps=norm_eps) else: self.lnorm = norm(dim_embed, eps=norm_eps) self.proj_heads_q = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False) @@ -268,7 +269,12 @@ def mask_block_local(batch, head, idx_q, idx_kv): def forward(self, x, coords=None, emb=None, ada_ln_aux=None): if self.with_residual: x_in = x - x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux) + + # Handle ada_ln_aux conditioning + if ada_ln_aux is None: + x = self.lnorm(x) + else: + x = self.lnorm(x, ada_ln_aux) if self.noise_conditioning: x, gate = self.noise_conditioning(x, emb) @@ -287,8 +293,12 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None): outs = self.flex_attention(qs, ks, vs, block_mask=self.block_mask).transpose(1, 2) out = self.proj_out(self.dropout(outs.flatten(-2, -1))) + + if ada_ln_aux is not None: + out = self.lnorm_final(out, ada_ln_aux) + if self.with_residual: - out = x_in + out * gate if self.noise_conditioning else x_in + out + out = x_in + out * gate if self.noise_conditioning is not None else x_in + out return out @@ -566,16 +576,20 @@ def __init__( if with_noise_conditioning: # NOTE: noise_emb_dim currently hard-coded self.noise_conditioning = LinearNormConditioning( - latent_space_dim=dim_embed, noise_emb_dim=512, dtype=self.dtype + latent_space_dim=dim_embed, dtype=self.dtype ) def forward(self, x, coords=None, emb=None, ada_ln_aux=None): if self.with_residual: x_in = x - x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux) + + # Handle ada_ln_aux conditioning + if ada_ln_aux is None: + x = self.lnorm(x) + else: + x = self.lnorm(x, ada_ln_aux) if self.noise_conditioning: - assert emb is not None, "Need noise embedding if using noise conditioning" x, gate = self.noise_conditioning(x, emb) # project onto heads and q,k,v and @@ -597,6 +611,10 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None): outs = flash_attn_func(qs, ks, vs, softcap=self.softcap, dropout_p=dropout_rate) out = self.proj_out(outs.flatten(-2, -1)) + + if ada_ln_aux is not None: + out = self.lnorm_final(out, ada_ln_aux) + if self.with_residual: out = x_in + out * gate if self.noise_conditioning else out + x_in diff --git a/src/weathergen/model/blocks.py b/src/weathergen/model/blocks.py index 061928f64..1f815ffdd 100644 --- a/src/weathergen/model/blocks.py +++ b/src/weathergen/model/blocks.py @@ -15,7 +15,7 @@ MultiSelfAttentionHeadVarlen, ) from weathergen.model.layers import MLP -from weathergen.model.norms import AdaLayerNormLayer +from weathergen.model.norms import AdaLayerNorm from weathergen.utils.utils import get_dtype @@ -37,7 +37,7 @@ def __init__(self, dim, dim_aux, with_adanorm, num_heads, dropout_rate, **kwargs **kwargs["attention_kwargs"], ) if self.with_adanorm: - self.mhsa_block = AdaLayerNormLayer(dim, dim_aux, self.mhsa, dropout_rate) + self.mhsa_block = AdaLayerNorm(dim, dim_aux, self.mhsa, dropout_rate) else: self.ln_sa = nn.LayerNorm(dim, eps=kwargs["attention_kwargs"]["norm_eps"]) self.mhsa_block = lambda x, _, **kwargs: self.mhsa(self.ln_sa(x), **kwargs) + x @@ -53,7 +53,7 @@ def __init__(self, dim, dim_aux, with_adanorm, num_heads, dropout_rate, **kwargs ) if self.with_adanorm: self.mlp_fn = lambda x, **kwargs: self.mlp(x) - self.mlp_block = AdaLayerNormLayer(dim, dim_aux, self.mlp_fn, dropout_rate) + self.mlp_block = AdaLayerNorm(dim, dim_aux, self.mlp_fn, dropout_rate) else: self.ln_mlp = nn.LayerNorm(norm_eps=kwargs["attention_kwargs"]["norm_eps"]) self.mlp_block = lambda x, _, **kwargs: self.mlp(self.ln_mlp(x), None, **kwargs) + x @@ -114,7 +114,7 @@ def __init__( **kwargs["attention_kwargs"], ) if self.with_adanorm: - self.mhsa_block = AdaLayerNormLayer(dim_q, dim_aux, self.mhsa, dropout_rate) + self.mhsa_block = AdaLayerNorm(dim_q, dim_aux, self.mhsa, dropout_rate) else: self.ln_sa = nn.LayerNorm(dim_q, eps=kwargs["attention_kwargs"]["norm_eps"]) self.mhsa_block = lambda x, _, **kwargs: self.mhsa(self.ln_sa(x), **kwargs) + x @@ -127,7 +127,7 @@ def __init__( **kwargs["attention_kwargs"], ) if self.with_adanorm: - self.cross_attn_block = AdaLayerNormLayer(dim_q, dim_aux, self.cross_attn, dropout_rate) + self.cross_attn_block = AdaLayerNorm(dim_q, dim_aux, self.cross_attn, dropout_rate) else: self.ln_ca = nn.LayerNorm(dim_q, eps=kwargs["attention_kwargs"]["norm_eps"]) self.cross_attn_block = ( @@ -145,7 +145,7 @@ def __init__( ) if self.with_adanorm: self.mlp_fn = lambda x, **kwargs: self.mlp(x) - self.mlp_block = AdaLayerNormLayer(dim_q, dim_aux, self.mlp_fn, dropout_rate) + self.mlp_block = AdaLayerNorm(dim_q, dim_aux, self.mlp_fn, dropout_rate) else: self.ln_mlp = nn.LayerNorm(dim_q, eps=kwargs["attention_kwargs"]["norm_eps"]) self.mlp_block = lambda x, _, **kwargs: self.mlp(self.ln_mlp(x)) + x diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index b14092815..18a2e7ef2 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -84,7 +84,7 @@ def forward( self.cur_token = tokens - c = torch.tensor([meta_info["ERA5"].params["datetime"]], device=tokens.device) # TODO: add correct preconditioning (e.g., sample/s in previous time step, datetime encoding, etc.) + c = meta_info["ERA5"].params["timestamp"] # TODO: add correct preconditioning (e.g., sample/s in previous time step, datetime encoding, etc.) y = tokens if self.training: @@ -100,6 +100,8 @@ def forward( self._noised_tokens = y + n + print(f"date was: {c}") + return self.denoise(x=y + n, c=c, sigma=sigma, fstep=fstep) def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int) -> torch.Tensor: @@ -119,8 +121,8 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int) -> # Precondition input and feed through network x = self.preconditioner.precondition(x, c) #currently does nothing - c = self.datetime_embedder(c) - + c = self.datetime_embedder(c).to(x.device) + return c_skip * x + c_out * self.net( c_in * x, fstep=fstep, noise_emb=noise_emb, ada_ln_aux=c ) # Eq. (7) in EDM paper @@ -227,56 +229,80 @@ def forward(self, t: float): class DateTimeEncoder(torch.nn.Module): """ - Encodes timestamp(s) in seconds since Unix epoch into a 4D vector: - [time_of_day_sin, time_of_day_cos, day_of_year_sin, day_of_year_cos] - + Encodes timestamp(s) into multi-frequency sinusoidal calendar embeddings. + + Inspired by cBottle (Climate in a Bottle) with k=1..8 frequency scales. + Captures seasonal (day-of-year) and diurnal (time-of-day) cycles at multiple timescales. + Input shape: scalar or any tensor shape (...) - Output shape: (..., 4) + Output shape: (..., 32) — 8 frequencies × 4 components (cos/sin per signal) + + Output structure for k=1..8: + [cos(2πk·doy/365.25), sin(2πk·doy/365.25), cos(k·t), sin(k·t)] + where: + - doy = day of year (0-365.25) + - t = 2π·seconds_of_day/86400 (time of day in radians, UTC) """ def __init__(self): super().__init__() + self.num_frequencies = 8 - def forward(self, timestamp: torch.Tensor | np.ndarray) -> torch.Tensor: + def forward(self, timestamp: np.ndarray) -> torch.Tensor: """ - Encode datetime64 timestamp into a 4D vector: - [time_of_day_sin, time_of_day_cos, day_of_year_sin, day_of_year_cos] + Encode numpy datetime64 timestamps into 32D multi-frequency calendar embeddings. - Input: np.datetime64 or torch.Tensor containing datetime64 values - Output: (..., 4) shaped tensor + Args: + timestamp: np.datetime64 scalar or array of timestamps + + Returns: + torch.Tensor of shape (..., 32) containing multi-frequency embeddings """ - # Convert to numpy if needed - if isinstance(timestamp, torch.Tensor): - timestamp = timestamp.detach().cpu().numpy() - - # Ensure datetime64[s] precision - timestamp = timestamp.astype('datetime64[s]') + + # TODO: Consider adding local time encoding (e.g., using longitude) + orig_shape = timestamp.shape timestamp_flat = timestamp.reshape(-1) two_pi = 2.0 * np.pi - # --- Time of day from seconds since epoch --- + # --- Extract time components --- ts_int64 = timestamp_flat.astype('int64') # seconds since Unix epoch seconds_in_day = 86400.0 - time_of_day = (ts_int64 % int(seconds_in_day)) / seconds_in_day - tod_sin = np.sin(two_pi * time_of_day).astype(np.float32) - tod_cos = np.cos(two_pi * time_of_day).astype(np.float32) + seconds_of_day = (ts_int64 % int(seconds_in_day)) / seconds_in_day # [0, 1) - # --- Day of year --- + # --- Extract day of year --- day_np = timestamp_flat.astype('datetime64[D]') year_start = day_np.astype('datetime64[Y]').astype('datetime64[D]') next_year_start = (day_np.astype('datetime64[Y]') + np.timedelta64(1, 'Y')).astype('datetime64[D]') - day_of_year_0 = (day_np - year_start).astype(np.int64) - days_in_year = (next_year_start - year_start).astype(np.int64) - doy_frac = day_of_year_0.astype(np.float32) / days_in_year.astype(np.float32) + day_of_year_0 = (day_np - year_start).astype(np.int64) # [0, 365] or [0, 366] + days_in_year = (next_year_start - year_start).astype(np.int64) # 365 or 366 + doy_frac = day_of_year_0.astype(np.float32) / days_in_year.astype(np.float32) # [0, 1) - doy_sin = np.sin(two_pi * doy_frac).astype(np.float32) - doy_cos = np.cos(two_pi * doy_frac).astype(np.float32) + # --- Multi-frequency sinusoidal embeddings --- + # Build output for all 8 frequency scales + embeddings = [] + for k in range(1, self.num_frequencies + 1): + k_float = float(k) + + # Day-of-year components: cos(2π·k·doy/365.25), sin(2π·k·doy/365.25) + doy_phase = two_pi * k_float * doy_frac + doy_cos = np.cos(doy_phase).astype(np.float32) + doy_sin = np.sin(doy_phase).astype(np.float32) + + # Time-of-day components: cos(k·t), sin(k·t) where t = 2π·seconds_of_day + tot_phase = k_float * two_pi * seconds_of_day + tot_cos = np.cos(tot_phase).astype(np.float32) + tot_sin = np.sin(tot_phase).astype(np.float32) + + embeddings.append(doy_cos) + embeddings.append(doy_sin) + embeddings.append(tot_cos) + embeddings.append(tot_sin) - # Stack and convert to tensor - out = np.stack([tod_sin, tod_cos, doy_sin, doy_cos], axis=-1) + # Stack all components: (N, 32) + out = np.stack(embeddings, axis=-1) out = torch.from_numpy(out).float() - return out.reshape(*orig_shape, 4) \ No newline at end of file + return out.reshape(*orig_shape, self.num_frequencies * 4) \ No newline at end of file diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index dbc8c36fe..ae89a7f26 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -505,10 +505,9 @@ def forward( if forecast_residual: tokens_in = tokens - aux_info = None if self.cf.fe_diffusion_model: - assert noise_emb is not None, ( - "Noise embedding must be provided for diffusion forecast engine" + assert ada_ln_aux is not None, ( + "Conditioning (noise and other) must be provided for diffusion forecast engine" ) for block in self.fe_blocks: if isinstance(block, torch.nn.LayerNorm): @@ -520,7 +519,7 @@ def forward( if isinstance(block, torch.nn.LayerNorm): tokens = checkpoint(block, tokens, use_reentrant=False) else: - tokens = checkpoint(block, tokens, coords, aux_info, use_reentrant=False) + tokens = checkpoint(block, tokens, coords, ada_ln_aux, use_reentrant=False) return tokens if not forecast_residual else (tokens_in + tokens) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 95f7107c1..003a0edc6 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -668,19 +668,11 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: tokens = (tokens - t_mean) / (t_std + 1e-6) * self.cf.sigma_data tokens = torch.clamp(tokens, -100.0, 100.0) - breakpoint() - # roll-out in latent space, iterate and generate output over requested output steps for step in batch.get_output_idxs(): # apply forecasting engine (if present) if self.forecast_engine: - - #TODO: move this to the appropriate place in batch consturction - batch.samples[0].meta_info['ERA5'].add_global_params({'datetime': batch.samples[0].streams_data['ERA5'].source_raw[0].datetimes[0]}) - print(f'added {batch.samples[0].streams_data['ERA5'].source_raw[0].datetimes[0]}') - # add_global_params({'datetime': batch.samples[0].streams_data[self.stream_names[0]].source_raw[0].datetimes[0]}) - - breakpoint() + tokens = self.forecast_engine( tokens, step, diff --git a/src/weathergen/model/norms.py b/src/weathergen/model/norms.py index cbb5f7ba2..78187087f 100644 --- a/src/weathergen/model/norms.py +++ b/src/weathergen/model/norms.py @@ -114,6 +114,39 @@ def forward(self, x: torch.Tensor, aux: torch.Tensor | None = None) -> torch.Ten x = self.norm(x) * (1 + scale) return x + +# NOTE: Inspired by GenCast/DiT. +class LinearNormConditioning(torch.nn.Module): + """Module for norm conditioning, adapted from GenCast with additional gate parameter from DiT. + + Conditions the normalization of `inputs` by applying a linear layer to the + `norm_conditioning` which produces the scale and offset for each channel. + """ + + def __init__(self, latent_space_dim: int, noise_emb_dim: int = 512, dtype=torch.bfloat16): + super().__init__() + self.dtype = dtype + + self.conditional_linear_layer = torch.nn.Linear( + in_features=noise_emb_dim, + out_features=3 * latent_space_dim, + ) + # Optional: initialize weights similar to TruncatedNormal(stddev=1e-8) + torch.nn.init.normal_(self.conditional_linear_layer.weight, std=1e-8) + torch.nn.init.zeros_(self.conditional_linear_layer.bias) + + def forward(self, inputs, noise_emb): + conditional_scale_offset = self.conditional_linear_layer(noise_emb.to(self.dtype)) + scale_minus_one, offset, gate = torch.chunk(conditional_scale_offset, 3, dim=-1) + scale = scale_minus_one + 1.0 + + # Reshape scale and offset for broadcasting if needed + while scale.dim() < inputs.dim(): + scale = scale.unsqueeze(1) + offset = offset.unsqueeze(1) + return (inputs * scale + offset).to( + self.dtype + ), gate # TODO: check if to(self.dtype) needed here diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 3c77bdbe8..ce78885c4 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -434,8 +434,6 @@ def train(self, mini_epoch): batch.to_device(self.device) - breakpoint() - with torch.autocast( device_type=f"cuda:{cf.local_rank}", dtype=self.mixed_precision_dtype, diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 64df4905a..712566683 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -94,7 +94,7 @@ def write_output( if loss_term.type == "LossPhysical" ] assert len(outputs_physical) == 1 - breakpoint() + target_aux_out = target_aux_out[outputs_physical[0]] # collect all target / prediction-related information From 77e248c5aab938f57d40abd8164d05e26f15a124 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Tue, 31 Mar 2026 12:37:12 +0200 Subject: [PATCH 252/344] change config --- config/config_diffusion.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 822ef0713..c7ba41557 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -55,7 +55,7 @@ num_register_tokens: 0 # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -fe_num_blocks: 4 +fe_num_blocks: 2 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True From 2153713e5b4bd1669c3580964c12c6e2b31f6b43 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Tue, 31 Mar 2026 13:57:59 +0200 Subject: [PATCH 253/344] apply PR review changes --- src/weathergen/datasets/batch.py | 4 ++++ src/weathergen/datasets/masking.py | 2 -- .../datasets/multi_stream_data_sampler.py | 14 ++++++-------- src/weathergen/model/attention.py | 5 +++-- src/weathergen/model/diffusion.py | 2 +- src/weathergen/model/model.py | 9 +++++---- 6 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index 63092d614..2ce085d65 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -30,6 +30,10 @@ def add_global_params(self, params: dict) -> None: self.global_params = {} self.global_params.update(params) + def add_params(self, params: dict) -> None: + if self.params is None: + self.params = {} + self.params.update(params) class Sample: # keys: stream name, values: SampleMetaData diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index abe959cc1..0a989178e 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -37,8 +37,6 @@ def add_mask(self, mask, params, cfg, losses, idx, correspondence, relationship) } if "noise_level_rn" in params: global_params["noise_level_rn"] = params["noise_level_rn"] - print(mask) - print(params) self.masks += [mask] self.metadata += [ SampleMetaData( diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 11bd6ad61..4f8163701 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -626,7 +626,7 @@ def _get_batch(self, idx: int, num_forecast_steps: int): # get/coordinate masks masks_streams, num_source_samples, num_target_samples = self._get_source_target_masks(mode) - + source_select, target_select = [], [] if "masking" in mode: source_select += ["network_input", "target_coords"] @@ -710,15 +710,13 @@ def _get_batch(self, idx: int, num_forecast_steps: int): input_mask=target_mask, ) target_metadata = target_masks.metadata[tidx] - # Add output timestamp to metadata - use actual target times from data - if target_metadata.params is None: - target_metadata.params = {} + # Get first target step's times (using self.output_offset as the first output step index) if self.diffusion_model_conditioning == "date_time": target_times_array = sdata.target_times_raw[self.output_offset] - target_metadata.params['timestamp'] = ( + target_metadata.add_params({'timestamp': ( target_times_array[0] if len(target_times_array) > 0 else None - ) + )}) # also want to add the mask to the metadata target_metadata.mask = target_mask # Map target to all source students @@ -747,8 +745,8 @@ def _get_batch(self, idx: int, num_forecast_steps: int): for stream_name in [s["name"] for s in self.streams]: if stream_name in target_sample.meta_info and stream_name in source_sample.meta_info: target_timestamp = target_sample.meta_info[stream_name].params.get('timestamp') - source_sample.meta_info[stream_name].params['timestamp'] = target_timestamp - + source_sample.meta_info[stream_name].add_params({'timestamp': target_timestamp}) + return batch def __iter__(self) -> ModelBatch: diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 6e8a07e3b..8c32d14c5 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -277,6 +277,7 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None): x = self.lnorm(x, ada_ln_aux) if self.noise_conditioning: + assert emb is not None, "Need noise embedding if using noise conditioning" x, gate = self.noise_conditioning(x, emb) # project onto heads @@ -298,7 +299,7 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None): out = self.lnorm_final(out, ada_ln_aux) if self.with_residual: - out = x_in + out * gate if self.noise_conditioning is not None else x_in + out + out = x_in + out * gate if self.noise_conditioning else x_in + out return out @@ -574,7 +575,6 @@ def __init__( self.noise_conditioning = None if with_noise_conditioning: - # NOTE: noise_emb_dim currently hard-coded self.noise_conditioning = LinearNormConditioning( latent_space_dim=dim_embed, dtype=self.dtype ) @@ -590,6 +590,7 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None): x = self.lnorm(x, ada_ln_aux) if self.noise_conditioning: + assert emb is not None, "Need noise embedding if using noise conditioning" x, gate = self.noise_conditioning(x, emb) # project onto heads and q,k,v and diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 18a2e7ef2..e68726bc6 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -100,7 +100,7 @@ def forward( self._noised_tokens = y + n - print(f"date was: {c}") + logger.info(f"Conditioning on date: {c}") return self.denoise(x=y + n, c=c, sigma=sigma, fstep=fstep) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 003a0edc6..5328ce1d3 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -389,13 +389,14 @@ def create(self) -> "Model": mode_cfg = cf.training_config self.forecast_engine = None if cf.fe_num_blocks > 0: - if cf.get("fe_diffusion_model", False): + if cf.get("diffusion_conditioning_embed_dim", None) is None: self.forecast_engine = ForecastingEngine(cf, mode_cfg, self.num_healpix_cells, dim_aux=self.cf.diffusion_conditioning_embed_dim) - self.forecast_engine = DiffusionForecastEngine( - cf, self.num_healpix_cells, forecast_engine=self.forecast_engine - ) else: self.forecast_engine = ForecastingEngine(cf, mode_cfg, self.num_healpix_cells) + if cf.get("fe_diffusion_model", False): + self.forecast_engine = DiffusionForecastEngine( + cf, self.num_healpix_cells, forecast_engine=self.forecast_engine + ) # embed coordinates yielding one query token for each target token dropout_rate = cf.embed_dropout_rate From e8664abec9d4db18cb79b2178fc52a31a446a358 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Tue, 31 Mar 2026 23:37:56 +0200 Subject: [PATCH 254/344] Finish SwiGLU implementation --- config/default_config.yml | 1 + src/weathergen/model/blocks.py | 8 +++++- src/weathergen/model/embeddings.py | 3 +++ src/weathergen/model/engines.py | 39 ++++++++++++++++++++++++--- src/weathergen/model/layers.py | 34 ++++++++++++++++------- src/weathergen/model/model.py | 4 +++ src/weathergen/train/teacher_utils.py | 9 +++++-- 7 files changed, 82 insertions(+), 16 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index d62b575ce..dd922d17c 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -10,6 +10,7 @@ embed_orientation: "channels" embed_unembed_mode: "block" embed_dropout_rate: 0.1 +mlp_type: mlp ae_local_dim_embed: 1024 ae_local_num_blocks: 2 diff --git a/src/weathergen/model/blocks.py b/src/weathergen/model/blocks.py index 061928f64..17fe729d7 100644 --- a/src/weathergen/model/blocks.py +++ b/src/weathergen/model/blocks.py @@ -25,7 +25,7 @@ class SelfAttentionBlock(nn.Module): layer norm with a FFN. """ - def __init__(self, dim, dim_aux, with_adanorm, num_heads, dropout_rate, **kwargs): + def __init__(self, dim, dim_aux, with_adanorm, num_heads, dropout_rate, mlp_type="mlp", **kwargs): super().__init__() self.with_adanorm = with_adanorm @@ -48,6 +48,7 @@ def __init__(self, dim, dim_aux, with_adanorm, num_heads, dropout_rate, **kwargs dim_out=dim, hidden_factor=4, dropout_rate=0.1, + mlp_type=mlp_type, nonlin=approx_gelu, with_residual=False, ) @@ -98,6 +99,7 @@ def __init__( with_mlp, num_heads, dropout_rate, + mlp_type="mlp", **kwargs, ): super().__init__() @@ -140,6 +142,7 @@ def __init__( dim_in=dim_q, dim_out=dim_q, hidden_factor=4, + mlp_type=mlp_type, nonlin=approx_gelu, with_residual=False, ) @@ -189,6 +192,7 @@ def __init__( attention_kwargs, tr_dim_head_proj, tr_mlp_hidden_factor, + tr_mlp_type, tro_type, mlp_norm_eps=1e-6, ): @@ -198,6 +202,7 @@ def __init__( self.tro_type = tro_type self.tr_dim_head_proj = tr_dim_head_proj self.tr_mlp_hidden_factor = tr_mlp_hidden_factor + self.tr_mlp_type = tr_mlp_type self.block = nn.ModuleList() @@ -244,6 +249,7 @@ def __init__( with_residual=True, hidden_factor=self.tr_mlp_hidden_factor, dropout_rate=0.1, # Assuming dropout_rate is 0.1 + mlp_type=self.tr_mlp_type, norm_type=self.cf.norm_type, dim_aux=(dim_aux if self.cf.pred_mlp_adaln else None), norm_eps=self.cf.mlp_norm_eps, diff --git a/src/weathergen/model/embeddings.py b/src/weathergen/model/embeddings.py index 90fbcf714..8e6a84cc1 100644 --- a/src/weathergen/model/embeddings.py +++ b/src/weathergen/model/embeddings.py @@ -32,6 +32,7 @@ def __init__( num_heads, dropout_rate=0.0, norm_type="LayerNorm", + mlp_type="mlp", unembed_mode="full", stream_name="stream_embed", ): @@ -56,6 +57,7 @@ def __init__( self.dim_out = dim_out self.num_blocks = num_blocks self.num_heads = num_heads + self.mlp_type = mlp_type self.unembed_mode = unembed_mode norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm @@ -77,6 +79,7 @@ def __init__( self.dim_embed, hidden_factor=2, dropout_rate=dropout_rate, + mlp_type=self.mlp_type, with_residual=True, ) ) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 4419c2955..b7c274d06 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -66,6 +66,7 @@ def __init__(self, cf: Config, sources_size) -> None: num_heads=si["embed"]["num_heads"], dropout_rate=self.cf.embed_dropout_rate, norm_type=self.cf.norm_type, + mlp_type=self.cf.get("mlp_type", "mlp"), unembed_mode=self.cf.embed_unembed_mode, stream_name=stream_name, ) @@ -157,6 +158,7 @@ def __init__(self, cf: Config) -> None: self.cf.ae_local_dim_embed, with_residual=True, dropout_rate=self.cf.ae_local_dropout_rate, + mlp_type=self.cf.get("mlp_type", "mlp"), norm_type=self.cf.norm_type, norm_eps=self.cf.mlp_norm_eps, ) @@ -206,6 +208,7 @@ def __init__(self, cf: Config) -> None: self.cf.ae_global_dim_embed, with_residual=True, dropout_rate=self.cf.ae_adapter_dropout_rate, + mlp_type=self.cf.get("mlp_type", "mlp"), norm_type=self.cf.norm_type, norm_eps=self.cf.mlp_norm_eps, ) @@ -300,6 +303,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: with_residual=True, dropout_rate=self.cf.ae_aggregation_dropout_rate, hidden_factor=self.cf.ae_aggregation_mlp_hidden_factor, + mlp_type=self.cf.get("mlp_type", "mlp"), norm_type=self.cf.norm_type, norm_eps=self.cf.mlp_norm_eps, ) @@ -374,6 +378,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: with_residual=True, dropout_rate=self.cf.ae_global_dropout_rate, hidden_factor=self.cf.ae_global_mlp_hidden_factor, + mlp_type=self.cf.get("mlp_type", "mlp"), norm_type=self.cf.norm_type, norm_eps=self.cf.mlp_norm_eps, ) @@ -448,6 +453,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = self.cf.ae_global_dim_embed, with_residual=True, dropout_rate=self.cf.fe_dropout_rate, + mlp_type=self.cf.get("mlp_type", "mlp"), norm_type=self.cf.norm_type, dim_aux=dim_aux, norm_eps=self.cf.mlp_norm_eps, @@ -547,6 +553,7 @@ def __init__( dim_coord_in, tr_dim_head_proj, tr_mlp_hidden_factor, + tr_mlp_type, softcap, stream_config: dict, ): @@ -568,6 +575,7 @@ def __init__( self.dim_coord_in = dim_coord_in self.tr_dim_head_proj = tr_dim_head_proj self.tr_mlp_hidden_factor = tr_mlp_hidden_factor + self.tr_mlp_type = tr_mlp_type self.softcap = softcap self.tte = torch.nn.ModuleList() @@ -615,6 +623,7 @@ def __init__( with_residual=True, hidden_factor=self.tr_mlp_hidden_factor, dropout_rate=0.1, # Assuming dropout_rate is 0.1 + mlp_type=self.tr_mlp_type, norm_type=self.cf.norm_type, dim_aux=(self.dim_coord_in if self.cf.pred_mlp_adaln else None), norm_eps=self.cf.mlp_norm_eps, @@ -652,8 +661,9 @@ def __init__( dim_coord_in, tr_dim_head_proj, tr_mlp_hidden_factor, + tr_mlp_type, softcap, - stream_name: str, + stream_config: dict, ): """ Initialize the TargetPredictionEngine with the configuration. @@ -676,13 +686,14 @@ def __init__( LayerNorm that does not scale after the layer is applied """ super(TargetPredictionEngine, self).__init__() - self.name = f"TargetPredictionEngine_{stream_name}" + self.name = f"TargetPredictionEngine_{stream_config['name']}" self.cf = cf self.dims_embed = dims_embed self.dim_coord_in = dim_coord_in self.tr_dim_head_proj = tr_dim_head_proj self.tr_mlp_hidden_factor = tr_mlp_hidden_factor + self.tr_mlp_type = tr_mlp_type self.softcap = softcap # For backwards compatibility @@ -721,6 +732,7 @@ def __init__( with_self_attn=False, with_adanorm=False, with_mlp=False, + mlp_type=self.tr_mlp_type, attention_kwargs=attention_kwargs, ) ) @@ -733,6 +745,7 @@ def __init__( attention_kwargs=attention_kwargs, with_adanorm=True, dropout_rate=0.1, + mlp_type=self.tr_mlp_type, ) ) elif self.cf.decoder_type == "CrossAttentionConditioning": @@ -746,6 +759,7 @@ def __init__( with_adanorm=False, with_mlp=True, dropout_rate=0.1, + mlp_type=self.tr_mlp_type, attention_kwargs=attention_kwargs, ) ) @@ -760,6 +774,7 @@ def __init__( with_adanorm=True, with_mlp=True, dropout_rate=0.1, + mlp_type=self.tr_mlp_type, attention_kwargs=attention_kwargs, ) ) @@ -775,6 +790,7 @@ def __init__( attention_kwargs=attention_kwargs, tr_dim_head_proj=tr_dim_head_proj, tr_mlp_hidden_factor=tr_mlp_hidden_factor, + tr_mlp_type=tr_mlp_type, mlp_norm_eps=self.cf.mlp_norm_eps, ) ) @@ -892,6 +908,7 @@ def __init__( hidden_factor=4, with_residual=True, dropout_rate=dropout_rate, + mlp_type=loss_conf.get("mlp_type", self.global_cf.get("mlp_type", "mlp")), norm_type=self.global_cf.norm_type, # dim_aux=dim_aux, norm_eps=self.global_cf.mlp_norm_eps, @@ -931,7 +948,15 @@ def forward(self, x: LatentState): class LatentPredictionHeadMLP(nn.Module): - def __init__(self, name, in_dim: int, loss_conf, use_class_token: bool, use_patch_token: bool): + def __init__( + self, + name, + in_dim: int, + loss_conf, + use_class_token: bool, + use_patch_token: bool, + default_mlp_type: str = "mlp", + ): super().__init__() self.name = name @@ -946,7 +971,13 @@ def __init__(self, name, in_dim: int, loss_conf, use_class_token: bool, use_patc self.use_patch_token = use_patch_token # Create an MLP block - self.blocks = MLP(in_dim, out_dim, num_layers, hidden_factor) + self.blocks = MLP( + in_dim, + out_dim, + num_layers, + hidden_factor, + mlp_type=loss_conf.get("mlp_type", default_mlp_type), + ) def forward(self, x: LatentState): outputs = [] diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index 1f7b8df5d..dee5a4757 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -6,12 +6,10 @@ # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. - - import torch import torch.nn as nn -from weathergen.model.norms import AdaLayerNorm, RMSNorm +from weathergen.model.norms import AdaLayerNorm, RMSNorm, SwiGLU class NamedLinear(torch.nn.Module): @@ -42,6 +40,7 @@ def __init__( norm_type="LayerNorm", dim_aux=None, norm_eps=1e-5, + mlp_type="mlp", name: str | None = None, ): """Constructor""" @@ -55,8 +54,16 @@ def __init__( self.with_residual = with_residual self.with_aux = dim_aux is not None + self.mlp_type = mlp_type.lower() dim_hidden = int(dim_in * hidden_factor) + if self.mlp_type not in {"mlp", "swiglu"}: + raise ValueError(f"Unsupported mlp_type: {mlp_type}") + + if self.mlp_type == "swiglu": + # Align with the standard LLaMA-style SwiGLU hidden-width rule. + dim_hidden = max(1, int(2 * dim_hidden / 3)) + self.layers = torch.nn.ModuleList() norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm @@ -68,15 +75,24 @@ def __init__( else AdaLayerNorm(dim_in, dim_aux, norm_eps=norm_eps) ) - self.layers.append(torch.nn.Linear(dim_in, dim_hidden)) - self.layers.append(nonlin()) - self.layers.append(torch.nn.Dropout(p=dropout_rate)) - - for _ in range(num_layers - 2): - self.layers.append(torch.nn.Linear(dim_hidden, dim_hidden)) + if self.mlp_type == "swiglu": + self.layers.append(torch.nn.Linear(dim_in, 2 * dim_hidden)) + self.layers.append(SwiGLU()) + self.layers.append(torch.nn.Dropout(p=dropout_rate)) + for _ in range(num_layers - 2): + self.layers.append(torch.nn.Linear(dim_hidden, 2 * dim_hidden)) + self.layers.append(SwiGLU()) + self.layers.append(torch.nn.Dropout(p=dropout_rate)) + else: + self.layers.append(torch.nn.Linear(dim_in, dim_hidden)) self.layers.append(nonlin()) self.layers.append(torch.nn.Dropout(p=dropout_rate)) + for _ in range(num_layers - 2): + self.layers.append(torch.nn.Linear(dim_hidden, dim_hidden)) + self.layers.append(nonlin()) + self.layers.append(torch.nn.Dropout(p=dropout_rate)) + self.layers.append(torch.nn.Linear(dim_hidden, dim_out)) def forward(self, *args): diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index e57a75658..fbbbe197d 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -346,6 +346,7 @@ def _create_latent_pred_head( loss_cfg, use_class_token=use_class_token, use_patch_token=use_patch_token, + default_mlp_type=global_cfg.get("mlp_type", "mlp"), ) elif loss_cfg["head"].lower() == "transformer": return LatentPredictionHeadTransformer( @@ -412,6 +413,7 @@ def create(self) -> "Model": tr_mlp_hidden_factor = ( tr["mlp_hidden_factor"] if "mlp_hidden_factor" in tr else 2 ) + tr_mlp_type = tr.get("mlp_type", cf.get("mlp_type", "mlp")) tr_dim_head_proj = tr["dim_head_proj"] if "dim_head_proj" in tr else None softcap = tr["softcap"] if "softcap" in tr else 0.0 @@ -439,6 +441,7 @@ def create(self) -> "Model": hidden_factor=8, with_residual=False, dropout_rate=dropout_rate, + mlp_type=self.cf.get("mlp_type", "mlp"), norm_eps=self.cf.mlp_norm_eps, name=f"embed_target_coords_{stream_name}", ) @@ -465,6 +468,7 @@ def create(self) -> "Model": dim_coord_in, tr_dim_head_proj, tr_mlp_hidden_factor, + tr_mlp_type, softcap, stream_config=si, ) diff --git a/src/weathergen/train/teacher_utils.py b/src/weathergen/train/teacher_utils.py index c026960b5..224505f5d 100644 --- a/src/weathergen/train/teacher_utils.py +++ b/src/weathergen/train/teacher_utils.py @@ -43,7 +43,12 @@ def _create_teacher_heads( if head_type == "mlp": return LatentPredictionHeadMLP( - f"{name}-head", dim_embed, loss_conf, use_class_token, use_patch_token + f"{name}-head", + dim_embed, + loss_conf, + use_class_token, + use_patch_token, + default_mlp_type=(cf.get("mlp_type", "mlp") if cf is not None else "mlp"), ) elif head_type == "transformer": if cf is None: @@ -88,7 +93,7 @@ def prepare_encoder_teacher(model: nn.Module, training_cfg, override_cfg) -> Non elif name in ("iBOT", "DINO"): head_type = conf.get("head", "mlp").lower() model.latent_heads[name] = _create_teacher_heads( - name, head_type, teacher_dim_embed, conf + name, head_type, teacher_dim_embed, conf, cf=override_cfg ) else: logger.warning(f"Unknown SSL loss type {name!r} in teacher setup, skipping.") From c1ead6228be719a8df1deec78c6b7c11bb973cc0 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Tue, 31 Mar 2026 23:46:06 +0200 Subject: [PATCH 255/344] Implement XSA --- config/default_config.yml | 1 + src/weathergen/model/attention.py | 30 ++++++++++++++++++++++++++++++ src/weathergen/model/blocks.py | 16 +++++++++++++++- src/weathergen/model/embeddings.py | 3 +++ src/weathergen/model/engines.py | 13 +++++++++++++ 5 files changed, 62 insertions(+), 1 deletion(-) diff --git a/config/default_config.yml b/config/default_config.yml index dd922d17c..031ec51be 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -11,6 +11,7 @@ embed_orientation: "channels" embed_unembed_mode: "block" embed_dropout_rate: 0.1 mlp_type: mlp +use_xsa: False ae_local_dim_embed: 1024 ae_local_num_blocks: 2 diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index f815aba2a..64f63a4d8 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -24,6 +24,16 @@ """ +def _apply_xsa(attn_out: torch.Tensor, self_values: torch.Tensor) -> torch.Tensor: + attn_out_float = attn_out.float() + self_values_float = self_values.float() + denom = self_values_float.pow(2).sum(dim=-1, keepdim=True).clamp_min( + torch.finfo(self_values_float.dtype).eps + ) + proj = (attn_out_float * self_values_float).sum(dim=-1, keepdim=True) / denom + return (attn_out_float - (proj * self_values_float)).to(attn_out.dtype) + + class MultiSelfAttentionHeadVarlen(torch.nn.Module): def __init__( self, @@ -40,6 +50,7 @@ def __init__( norm_eps=1e-5, attention_dtype=torch.bfloat16, with_2d_rope=False, + use_xsa=False, ): super(MultiSelfAttentionHeadVarlen, self).__init__() @@ -49,6 +60,7 @@ def __init__( self.softcap = softcap self.with_residual = with_residual self.with_2d_rope = with_2d_rope + self.use_xsa = use_xsa assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj @@ -112,6 +124,9 @@ def forward(self, x, x_lens, ada_ln_aux=None, coords=None): dropout_p=dropout_rate, ) + if self.use_xsa: + outs = _apply_xsa(outs, vs) + out = self.proj_out(outs.flatten(-2, -1)) if self.with_residual: @@ -134,6 +149,7 @@ def __init__( softcap=0.0, norm_eps=1e-5, attention_dtype=torch.bfloat16, + use_xsa=False, ): super(MultiSelfAttentionHeadVarlenFlex, self).__init__() @@ -141,6 +157,7 @@ def __init__( self.with_flash = with_flash self.softcap = softcap self.with_residual = with_residual + self.use_xsa = use_xsa assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj @@ -188,6 +205,9 @@ def forward(self, x, x_lens=None): outs = self.compiled_flex_attention(qs, ks, vs).transpose(1, 2).squeeze() + if self.use_xsa: + outs = _apply_xsa(outs, vs.transpose(1, 2).squeeze()) + out = self.dropout(self.proj_out(outs.flatten(-2, -1))) if self.with_residual: out = out + x_in @@ -213,6 +233,7 @@ def __init__( norm_eps=1e-5, attention_dtype=torch.bfloat16, with_2d_rope=False, + use_xsa=False, ): super(MultiSelfAttentionHeadLocal, self).__init__() @@ -221,6 +242,7 @@ def __init__( self.softcap = softcap self.with_residual = with_residual self.with_2d_rope = with_2d_rope + self.use_xsa = use_xsa assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj @@ -277,6 +299,9 @@ def forward(self, x, coords=None, ada_ln_aux=None): outs = self.flex_attention(qs, ks, vs, block_mask=self.block_mask).transpose(1, 2) + if self.use_xsa: + outs = _apply_xsa(outs, vs.transpose(1, 2)) + out = self.proj_out(self.dropout(outs.flatten(-2, -1))) if self.with_residual: out = x_in + out @@ -510,6 +535,7 @@ def __init__( norm_eps=1e-5, attention_dtype=torch.bfloat16, with_2d_rope=False, + use_xsa=False, ): super(MultiSelfAttentionHead, self).__init__() @@ -519,6 +545,7 @@ def __init__( self.dropout_rate = dropout_rate self.with_residual = with_residual self.with_2d_rope = with_2d_rope + self.use_xsa = use_xsa assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj @@ -574,6 +601,9 @@ def forward(self, x, coords=None, ada_ln_aux=None): # ordering of tensors (seq, heads, embed) (which differs from torch's flash attention implt) outs = flash_attn_func(qs, ks, vs, softcap=self.softcap, dropout_p=dropout_rate) + if self.use_xsa: + outs = _apply_xsa(outs, vs) + out = self.proj_out(outs.flatten(-2, -1)) if self.with_residual: out = out + x_in diff --git a/src/weathergen/model/blocks.py b/src/weathergen/model/blocks.py index 17fe729d7..5e86deac8 100644 --- a/src/weathergen/model/blocks.py +++ b/src/weathergen/model/blocks.py @@ -25,7 +25,17 @@ class SelfAttentionBlock(nn.Module): layer norm with a FFN. """ - def __init__(self, dim, dim_aux, with_adanorm, num_heads, dropout_rate, mlp_type="mlp", **kwargs): + def __init__( + self, + dim, + dim_aux, + with_adanorm, + num_heads, + dropout_rate, + mlp_type="mlp", + use_xsa=False, + **kwargs, + ): super().__init__() self.with_adanorm = with_adanorm @@ -34,6 +44,7 @@ def __init__(self, dim, dim_aux, with_adanorm, num_heads, dropout_rate, mlp_type dim_embed=dim, num_heads=num_heads, with_residual=False, + use_xsa=use_xsa, **kwargs["attention_kwargs"], ) if self.with_adanorm: @@ -100,6 +111,7 @@ def __init__( num_heads, dropout_rate, mlp_type="mlp", + use_xsa=False, **kwargs, ): super().__init__() @@ -113,6 +125,7 @@ def __init__( dim_embed=dim_q, num_heads=num_heads, with_residual=False, + use_xsa=use_xsa, **kwargs["attention_kwargs"], ) if self.with_adanorm: @@ -238,6 +251,7 @@ def __init__( dim_aux=dim_aux, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + use_xsa=self.cf.get("use_xsa", False), ) ) diff --git a/src/weathergen/model/embeddings.py b/src/weathergen/model/embeddings.py index 8e6a84cc1..7b9137cff 100644 --- a/src/weathergen/model/embeddings.py +++ b/src/weathergen/model/embeddings.py @@ -33,6 +33,7 @@ def __init__( dropout_rate=0.0, norm_type="LayerNorm", mlp_type="mlp", + use_xsa=False, unembed_mode="full", stream_name="stream_embed", ): @@ -58,6 +59,7 @@ def __init__( self.num_blocks = num_blocks self.num_heads = num_heads self.mlp_type = mlp_type + self.use_xsa = use_xsa self.unembed_mode = unembed_mode norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm @@ -71,6 +73,7 @@ def __init__( dropout_rate=dropout_rate, with_qk_lnorm=True, with_flash=True, + use_xsa=self.use_xsa, ) ) self.layers.append( diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index b7c274d06..8ae07f5bb 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -67,6 +67,7 @@ def __init__(self, cf: Config, sources_size) -> None: dropout_rate=self.cf.embed_dropout_rate, norm_type=self.cf.norm_type, mlp_type=self.cf.get("mlp_type", "mlp"), + use_xsa=self.cf.get("use_xsa", False), unembed_mode=self.cf.embed_unembed_mode, stream_name=stream_name, ) @@ -147,6 +148,7 @@ def __init__(self, cf: Config) -> None: dropout_rate=self.cf.ae_local_dropout_rate, with_qk_lnorm=self.cf.ae_local_with_qk_lnorm, with_flash=self.cf.with_flash_attention, + use_xsa=self.cf.get("use_xsa", False), norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), @@ -273,6 +275,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: dropout_rate=self.cf.ae_aggregation_dropout_rate, with_qk_lnorm=self.cf.ae_aggregation_with_qk_lnorm, with_flash=self.cf.with_flash_attention, + use_xsa=self.cf.get("use_xsa", False), norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), @@ -290,6 +293,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: dropout_rate=self.cf.ae_aggregation_dropout_rate, with_qk_lnorm=self.cf.ae_aggregation_with_qk_lnorm, with_flash=self.cf.with_flash_attention, + use_xsa=self.cf.get("use_xsa", False), norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), @@ -348,6 +352,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: dropout_rate=self.cf.ae_global_dropout_rate, with_qk_lnorm=self.cf.ae_global_with_qk_lnorm, with_flash=self.cf.with_flash_attention, + use_xsa=self.cf.get("use_xsa", False), norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), @@ -364,6 +369,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: dropout_rate=self.cf.ae_global_dropout_rate, with_qk_lnorm=self.cf.ae_global_with_qk_lnorm, with_flash=self.cf.with_flash_attention, + use_xsa=self.cf.get("use_xsa", False), norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), @@ -422,6 +428,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = dropout_rate=self.cf.fe_dropout_rate, with_qk_lnorm=self.cf.fe_with_qk_lnorm, with_flash=self.cf.with_flash_attention, + use_xsa=self.cf.get("use_xsa", False), norm_type=self.cf.norm_type, dim_aux=dim_aux, norm_eps=self.cf.norm_eps, @@ -439,6 +446,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = dropout_rate=self.cf.fe_dropout_rate, with_qk_lnorm=self.cf.fe_with_qk_lnorm, with_flash=self.cf.with_flash_attention, + use_xsa=self.cf.get("use_xsa", False), norm_type=self.cf.norm_type, dim_aux=dim_aux, norm_eps=self.cf.norm_eps, @@ -608,6 +616,7 @@ def __init__( dropout_rate=0.1, # Assuming dropout_rate is 0.1 with_qk_lnorm=True, with_flash=self.cf.with_flash_attention, + use_xsa=self.cf.get("use_xsa", False), norm_type=self.cf.norm_type, dim_aux=self.dim_coord_in, norm_eps=self.cf.norm_eps, @@ -746,6 +755,7 @@ def __init__( with_adanorm=True, dropout_rate=0.1, mlp_type=self.tr_mlp_type, + use_xsa=self.cf.get("use_xsa", False), ) ) elif self.cf.decoder_type == "CrossAttentionConditioning": @@ -760,6 +770,7 @@ def __init__( with_mlp=True, dropout_rate=0.1, mlp_type=self.tr_mlp_type, + use_xsa=self.cf.get("use_xsa", False), attention_kwargs=attention_kwargs, ) ) @@ -775,6 +786,7 @@ def __init__( with_mlp=True, dropout_rate=0.1, mlp_type=self.tr_mlp_type, + use_xsa=self.cf.get("use_xsa", False), attention_kwargs=attention_kwargs, ) ) @@ -894,6 +906,7 @@ def __init__( dropout_rate=dropout_rate, with_qk_lnorm=with_qk_lnorm, with_flash=self.global_cf.with_flash_attention, + use_xsa=self.global_cf.get("use_xsa", False), norm_type=self.global_cf.norm_type, # dim_aux=dim_aux, norm_eps=self.global_cf.norm_eps, From 76aeadaa6512585d0330c0e8275be67fae594743 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 1 Apr 2026 10:13:55 +0200 Subject: [PATCH 256/344] fixed MLP implementation --- src/weathergen/model/attention.py | 2 ++ src/weathergen/model/engines.py | 1 + src/weathergen/model/layers.py | 33 +++++++++++++++++++++++++------ src/weathergen/model/model.py | 5 ++++- 4 files changed, 34 insertions(+), 7 deletions(-) diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 8c32d14c5..39b5e3b73 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -280,6 +280,7 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None): assert emb is not None, "Need noise embedding if using noise conditioning" x, gate = self.noise_conditioning(x, emb) + # project onto heads s = [x.shape[0], x.shape[1], self.num_heads, -1] qs = self.lnorm_q(self.proj_heads_q(x).reshape(s)).to(self.dtype).permute([0, 2, 1, 3]) @@ -593,6 +594,7 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None): assert emb is not None, "Need noise embedding if using noise conditioning" x, gate = self.noise_conditioning(x, emb) + # project onto heads and q,k,v and # ensure these are 4D tensors as required for flash attention s = [*([x.shape[0], 1] if len(x.shape) == 2 else x.shape[:-1]), self.num_heads, -1] diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index ae89a7f26..befeefb21 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -449,6 +449,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = self.cf.ae_global_dim_embed, self.cf.ae_global_dim_embed, with_residual=True, + post_layer_norm=cf.fe_diffusion_model_conditioning in ["date_time"], dropout_rate=self.cf.fe_dropout_rate, norm_type=self.cf.norm_type, dim_aux=dim_aux, diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index e85acc6c7..bcc6e4952 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -28,7 +28,7 @@ import torch import torch.nn as nn -from weathergen.model.norms import AdaLayerNorm, RMSNorm +from weathergen.model.norms import AdaLayerNorm, AdaLayerNormFinal, RMSNorm class NamedLinear(torch.nn.Module): @@ -53,6 +53,7 @@ def __init__( num_layers=2, hidden_factor=2, pre_layer_norm=True, + post_layer_norm=False, dropout_rate=0.0, nonlin=torch.nn.GELU, with_residual=False, @@ -90,7 +91,7 @@ def __init__( if with_noise_conditioning: self.noise_conditioning = LinearNormConditioning( dim_in - ) # TODO: chech if should pass some dtype? + ) # TODO: check if should pass some dtype? self.layers.append(torch.nn.Linear(dim_in, dim_hidden)) self.layers.append(nonlin()) @@ -103,22 +104,42 @@ def __init__( self.layers.append(torch.nn.Linear(dim_hidden, dim_out)) + if post_layer_norm: + self.layers.append( + norm(dim_out, eps=norm_eps) + if dim_aux is None + else AdaLayerNormFinal(dim_out, dim_aux, norm_eps=norm_eps) + ) + # TODO: expanded args, must check dependencies (previously aux = args[-1]) def forward(self, *args): x, x_in = args[0], args[0] + if len(args) < 2 and self.with_aux: + raise ValueError("Auxiliary input required but not provided") if len(args) == 2: aux = args[1] elif len(args) > 2: aux = args[-1] - noise_emb = args[1] if self.with_noise_conditioning else None + noise_emb = args[2] if self.with_noise_conditioning else None + gate = None for i, layer in enumerate(self.layers): - if isinstance(layer, LinearNormConditioning): - x = layer(x, noise_emb) # noise embedding + if i == 0 and self.with_aux: + if isinstance(layer, (AdaLayerNorm)): + x = layer(x, aux) + if self.with_noise_conditioning: + x, gate = self.noise_conditioning(x, noise_emb) else: - x = layer(x, aux) if (i == 0 and self.with_aux) else layer(x) + if i == 0 and self.with_noise_conditioning: + x, gate = self.noise_conditioning(x, noise_emb) + if isinstance(layer, (AdaLayerNormFinal)): + x = layer(x, aux) + else: + x = layer(x) if self.with_residual: + if gate is not None: + x = x * gate if x.shape[-1] == x_in.shape[-1]: x = x_in + x else: diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 5328ce1d3..8219511a7 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -389,7 +389,10 @@ def create(self) -> "Model": mode_cfg = cf.training_config self.forecast_engine = None if cf.fe_num_blocks > 0: - if cf.get("diffusion_conditioning_embed_dim", None) is None: + if cf.get("fe_diffusion_model_conditioning", None) in ["date_time"]: + assert cf.diffusion_conditioning_embed_dim is not None, ( + "Diffusion conditioning embedding dimension must be specified when using diffusion model conditioning" + ) self.forecast_engine = ForecastingEngine(cf, mode_cfg, self.num_healpix_cells, dim_aux=self.cf.diffusion_conditioning_embed_dim) else: self.forecast_engine = ForecastingEngine(cf, mode_cfg, self.num_healpix_cells) From 8658f6955ca4bbcdab914b9451e484c63cc083c3 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 1 Apr 2026 10:25:02 +0200 Subject: [PATCH 257/344] re-added adalayernormlayer --- src/weathergen/model/norms.py | 79 ++++++++++++++++++++++++++++++++++- 1 file changed, 78 insertions(+), 1 deletion(-) diff --git a/src/weathergen/model/norms.py b/src/weathergen/model/norms.py index 78187087f..bebee67c0 100644 --- a/src/weathergen/model/norms.py +++ b/src/weathergen/model/norms.py @@ -60,10 +60,30 @@ def forward(self, x): output = self._norm(x.float()).type_as(x) return output * self.weight +def modulate(x, shift, scale): + return x * (1 + scale) + shift + + +class SwiGLU(nn.Module): + def __init__(self): + super(SwiGLU, self).__init__() + + def forward(self, x): + x1, x2 = x.chunk(2, dim=-1) + return x2 * F.silu(x1) + class AdaLayerNorm(torch.nn.Module): """ - AdaLayerNorm for embedding auxiliary information + AdaLayerNorm for embedding auxiliary information as done in DiT (Peebles & Xie) with zero + initialisation https://arxiv.org/pdf/2212.09748 + + This module thus wraps a layer (e.g. self-attention or feedforward nn) and applies LayerNorm + followed by scale and shift before the layer and a final scaling after the layer as well as the + final residual layer. + + layer is a function that takes 2 arguments the first the latent and the second is the + conditioning signal """ def __init__( @@ -115,6 +135,63 @@ def forward(self, x: torch.Tensor, aux: torch.Tensor | None = None) -> torch.Ten return x +class AdaLayerNormLayer(torch.nn.Module): + """ + AdaLayerNorm for embedding auxiliary information as done in DiT (Peebles & Xie) with zero + initialisation https://arxiv.org/pdf/2212.09748 + + This module thus wraps a layer (e.g. self-attention or feedforward nn) and applies LayerNorm + followed by scale and shift before the layer and a final scaling after the layer as well as the + final residual layer. + + layer is a function that takes 2 arguments the first the latent and the second is the + conditioning signal + """ + + def __init__( + self, + dim, + dim_aux, + layer, + norm_eps: float = 1e-6, + dropout_rate: float = 0.0, + ): + super().__init__() + + self.dim = dim + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_aux, 3 * dim, bias=True)) + + self.ln = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps) + self.layer = layer + + # Initialize weights to zero for modulation and gating layers + self.initialise_weights() + + def initialise_weights(self): + nn.init.zeros_(self.adaLN_modulation[-1].weight) + nn.init.zeros_(self.adaLN_modulation[-1].bias) + + def forward(self, x: torch.Tensor, c: torch.Tensor, x_lens, **kwargs) -> torch.Tensor: + # the -1 in torch.repeat_interleave(..) is because x_lens is designed for use with flash + # attention and thus has a spurious 0 at the beginning to satisfy the flash attention api + shift, scale, gate = self.adaLN_modulation(c)[torch.repeat_interleave(x_lens) - 1].chunk( + 3, dim=1 + ) + kwargs["x_lens"] = x_lens + return ( + gate + * self.layer( + modulate( + self.ln(x), + shift, + scale, + ), + **kwargs, + ) + + x + ) + + # NOTE: Inspired by GenCast/DiT. class LinearNormConditioning(torch.nn.Module): """Module for norm conditioning, adapted from GenCast with additional gate parameter from DiT. From 352f1ab670e8eac1082be363e9620b9892419787 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 1 Apr 2026 10:29:11 +0200 Subject: [PATCH 258/344] change norms --- src/weathergen/model/norms.py | 35 ++++++++++++++--------------------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/src/weathergen/model/norms.py b/src/weathergen/model/norms.py index bebee67c0..076efa0e7 100644 --- a/src/weathergen/model/norms.py +++ b/src/weathergen/model/norms.py @@ -60,30 +60,10 @@ def forward(self, x): output = self._norm(x.float()).type_as(x) return output * self.weight -def modulate(x, shift, scale): - return x * (1 + scale) + shift - - -class SwiGLU(nn.Module): - def __init__(self): - super(SwiGLU, self).__init__() - - def forward(self, x): - x1, x2 = x.chunk(2, dim=-1) - return x2 * F.silu(x1) - class AdaLayerNorm(torch.nn.Module): """ - AdaLayerNorm for embedding auxiliary information as done in DiT (Peebles & Xie) with zero - initialisation https://arxiv.org/pdf/2212.09748 - - This module thus wraps a layer (e.g. self-attention or feedforward nn) and applies LayerNorm - followed by scale and shift before the layer and a final scaling after the layer as well as the - final residual layer. - - layer is a function that takes 2 arguments the first the latent and the second is the - conditioning signal + AdaLayerNorm for embedding auxiliary information """ def __init__( @@ -135,6 +115,19 @@ def forward(self, x: torch.Tensor, aux: torch.Tensor | None = None) -> torch.Ten return x +def modulate(x, shift, scale): + return x * (1 + scale) + shift + + +class SwiGLU(nn.Module): + def __init__(self): + super(SwiGLU, self).__init__() + + def forward(self, x): + x1, x2 = x.chunk(2, dim=-1) + return x2 * F.silu(x1) + + class AdaLayerNormLayer(torch.nn.Module): """ AdaLayerNorm for embedding auxiliary information as done in DiT (Peebles & Xie) with zero From 6c03b353f1bde33dc676bdd8eeabd02b9d39572e Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 1 Apr 2026 11:02:31 +0200 Subject: [PATCH 259/344] updated adanorm, added new diffusion forward --- src/weathergen/model/blocks.py | 12 ++--- src/weathergen/model/diffusion.py | 90 +++++++++++++++++++++++++++++-- src/weathergen/model/norms.py | 52 +++++++++--------- 3 files changed, 116 insertions(+), 38 deletions(-) diff --git a/src/weathergen/model/blocks.py b/src/weathergen/model/blocks.py index 1f815ffdd..061928f64 100644 --- a/src/weathergen/model/blocks.py +++ b/src/weathergen/model/blocks.py @@ -15,7 +15,7 @@ MultiSelfAttentionHeadVarlen, ) from weathergen.model.layers import MLP -from weathergen.model.norms import AdaLayerNorm +from weathergen.model.norms import AdaLayerNormLayer from weathergen.utils.utils import get_dtype @@ -37,7 +37,7 @@ def __init__(self, dim, dim_aux, with_adanorm, num_heads, dropout_rate, **kwargs **kwargs["attention_kwargs"], ) if self.with_adanorm: - self.mhsa_block = AdaLayerNorm(dim, dim_aux, self.mhsa, dropout_rate) + self.mhsa_block = AdaLayerNormLayer(dim, dim_aux, self.mhsa, dropout_rate) else: self.ln_sa = nn.LayerNorm(dim, eps=kwargs["attention_kwargs"]["norm_eps"]) self.mhsa_block = lambda x, _, **kwargs: self.mhsa(self.ln_sa(x), **kwargs) + x @@ -53,7 +53,7 @@ def __init__(self, dim, dim_aux, with_adanorm, num_heads, dropout_rate, **kwargs ) if self.with_adanorm: self.mlp_fn = lambda x, **kwargs: self.mlp(x) - self.mlp_block = AdaLayerNorm(dim, dim_aux, self.mlp_fn, dropout_rate) + self.mlp_block = AdaLayerNormLayer(dim, dim_aux, self.mlp_fn, dropout_rate) else: self.ln_mlp = nn.LayerNorm(norm_eps=kwargs["attention_kwargs"]["norm_eps"]) self.mlp_block = lambda x, _, **kwargs: self.mlp(self.ln_mlp(x), None, **kwargs) + x @@ -114,7 +114,7 @@ def __init__( **kwargs["attention_kwargs"], ) if self.with_adanorm: - self.mhsa_block = AdaLayerNorm(dim_q, dim_aux, self.mhsa, dropout_rate) + self.mhsa_block = AdaLayerNormLayer(dim_q, dim_aux, self.mhsa, dropout_rate) else: self.ln_sa = nn.LayerNorm(dim_q, eps=kwargs["attention_kwargs"]["norm_eps"]) self.mhsa_block = lambda x, _, **kwargs: self.mhsa(self.ln_sa(x), **kwargs) + x @@ -127,7 +127,7 @@ def __init__( **kwargs["attention_kwargs"], ) if self.with_adanorm: - self.cross_attn_block = AdaLayerNorm(dim_q, dim_aux, self.cross_attn, dropout_rate) + self.cross_attn_block = AdaLayerNormLayer(dim_q, dim_aux, self.cross_attn, dropout_rate) else: self.ln_ca = nn.LayerNorm(dim_q, eps=kwargs["attention_kwargs"]["norm_eps"]) self.cross_attn_block = ( @@ -145,7 +145,7 @@ def __init__( ) if self.with_adanorm: self.mlp_fn = lambda x, **kwargs: self.mlp(x) - self.mlp_block = AdaLayerNorm(dim_q, dim_aux, self.mlp_fn, dropout_rate) + self.mlp_block = AdaLayerNormLayer(dim_q, dim_aux, self.mlp_fn, dropout_rate) else: self.ln_mlp = nn.LayerNorm(dim_q, eps=kwargs["attention_kwargs"]["norm_eps"]) self.mlp_block = lambda x, _, **kwargs: self.mlp(self.ln_mlp(x)) + x diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index e68726bc6..e6fe6a039 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -64,6 +64,68 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast self._fixed_noise_level: float | None = None def forward( + self, + tokens: torch.Tensor = None, + fstep: int = None, + meta_info: dict[str, SampleMetaData] = None, + coords: torch.Tensor = None, + num_steps: int = 30, + ) -> torch.Tensor: + """ + Forward pass that routes to training_forward or inference_forward based on model status. + + During training: + - calls training_forward with tokens, fstep, meta_info, coords + - extracts datetime conditioning from meta_info and passes through datetime embedder + - adds noise to target and returns denoised prediction + + During inference: + - calls inference_forward with fstep, num_steps, and meta_info + - generates samples via iterative diffusion steps with conditional temporal modulation + + Args: + tokens: Training tensor of shape (B, H, D) - required during training + fstep: Forecast step index - required for both modes + meta_info: Sample metadata dict containing timestamps - required for both modes + coords: Optional coordinate tensor + num_steps: Number of diffusion steps for inference (default: 30) + + Returns: + torch.Tensor: Model output (denoised prediction during training, + or generated sample during inference) + + Raises: + ValueError: If required arguments are missing for current mode + """ + if self.training: + if tokens is None or fstep is None or meta_info is None: + raise ValueError( + f"During training, tokens, fstep, and meta_info are required. " + f"Got tokens={tokens is not None}, fstep={fstep}, meta_info={meta_info is not None}" + ) + return self.training_forward( + tokens=tokens, + fstep=fstep, + meta_info=meta_info, + coords=coords, + ) + else: + #NOTE: temporary for analysing denoising + return self.training_forward( + tokens=tokens, + fstep=fstep, + meta_info=meta_info, + coords=coords, + ) + # if fstep is None: + # raise ValueError(f"During inference, fstep is required. Got fstep={fstep}") + # return self.inference_forward( + # fstep=fstep, + # num_steps=num_steps, + # meta_info=meta_info, + # ) + + def training_forward( self, tokens: torch.Tensor, fstep: int, @@ -127,13 +189,31 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int) -> c_in * x, fstep=fstep, noise_emb=noise_emb, ada_ln_aux=c ) # Eq. (7) in EDM paper - def inference( + def inference_forward( self, fstep: int, num_steps: int = 30, + meta_info: dict[str, SampleMetaData] = None, ) -> torch.Tensor: - # Forward pass of the diffusion model during inference - # https://github.com/NVlabs/edm/blob/main/generate.py + """ + Forward pass of the diffusion model during inference. + + Iteratively denoises a random sample using the learned score function, + with optional temporal conditioning extracted from meta_info. + https://github.com/NVlabs/edm/blob/main/generate.py + + Args: + fstep: Forecast step index for the network + num_steps: Number of diffusion denoising steps (default: 30) + meta_info: Optional sample metadata dict containing timestamps for temporal conditioning + + Returns: + torch.Tensor: Generated sample of shape (1, num_healpix_cells, ae_global_dim_embed) + """ + # Extract conditioning from meta_info (same as training_forward) + c = None + if meta_info is not None: + c = meta_info["ERA5"].params["timestamp"] # Sample noise (assuming single batch element for now) x = torch.randn(1, self.num_healpix_cells, self.cf.ae_global_dim_embed).to(device="cuda") @@ -165,13 +245,13 @@ def inference( t_hat = t_cur # Euler step. - denoised = self.denoise(x=x_hat, c=None, sigma=t_hat, fstep=fstep) # c to be discussed + denoised = self.denoise(x=x_hat, c=c, sigma=t_hat, fstep=fstep) d_cur = (x_hat - denoised) / t_hat x_next = x_hat + (t_next - t_hat) * d_cur # Apply 2nd order correction. if i < num_steps - 1: - denoised = self.denoise(x=x_next, c=None, sigma=t_next, fstep=fstep) + denoised = self.denoise(x=x_next, c=c, sigma=t_next, fstep=fstep) d_prime = (x_next - denoised) / t_next x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) diff --git a/src/weathergen/model/norms.py b/src/weathergen/model/norms.py index 076efa0e7..0c27aac53 100644 --- a/src/weathergen/model/norms.py +++ b/src/weathergen/model/norms.py @@ -63,7 +63,8 @@ def forward(self, x): class AdaLayerNorm(torch.nn.Module): """ - AdaLayerNorm for embedding auxiliary information + AdaLayerNorm for embedding auxiliary information. + Produces scale and shift for adaptive layer norm. """ def __init__( @@ -71,26 +72,26 @@ def __init__( ): super().__init__() - # simple 2-layer MLP for embedding auxiliary information - self.embed_aux = torch.nn.ModuleList() - self.embed_aux.append(torch.nn.Linear(dim_aux, 4 * dim_aux)) - self.embed_aux.append(torch.nn.SiLU()) - self.embed_aux.append(torch.nn.Linear(4 * dim_aux, 2 * dim_embed_x)) - + # MLP for embedding auxiliary information (matches DiT style) self.norm = torch.nn.LayerNorm(dim_embed_x, norm_eps, norm_elementwise_affine) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(dim_aux, 2 * dim_embed_x, bias=True) + ) + + # Initialize weights to zero for stable training (DiT style) + nn.init.zeros_(self.adaLN_modulation[-1].weight) + nn.init.zeros_(self.adaLN_modulation[-1].bias) def forward(self, x: torch.Tensor, aux: torch.Tensor | None = None) -> torch.Tensor: - for block in self.embed_aux: - aux = block(aux) - scale, shift = aux.split(aux.shape[-1] // 2, dim=-1) - - x = self.norm(x) * (1 + scale) + shift + shift, scale = self.adaLN_modulation(aux).chunk(2, dim=-1) + return modulate(self.norm(x), shift, scale) - return x class AdaLayerNormFinal(torch.nn.Module): """ - AdaLayerNorm from DiT for the final output gate only, i.e. only scale + AdaLayerNorm for gating only (scale only, no shift). + Used for final output gating as in DiT. """ def __init__( @@ -98,22 +99,19 @@ def __init__( ): super().__init__() - # simple 2-layer MLP for embedding auxiliary information - self.embed_aux = torch.nn.ModuleList() - self.embed_aux.append(torch.nn.Linear(dim_aux, 4 * dim_aux)) - self.embed_aux.append(torch.nn.SiLU()) - self.embed_aux.append(torch.nn.Linear(4 * dim_aux, dim_embed_x)) - self.norm = torch.nn.LayerNorm(dim_embed_x, norm_eps, norm_elementwise_affine) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(dim_aux, dim_embed_x, bias=True) + ) + + # Initialize weights to zero for stable training (DiT style) + nn.init.zeros_(self.adaLN_modulation[-1].weight) + nn.init.zeros_(self.adaLN_modulation[-1].bias) def forward(self, x: torch.Tensor, aux: torch.Tensor | None = None) -> torch.Tensor: - for block in self.embed_aux: - aux = block(aux) - scale = aux - - x = self.norm(x) * (1 + scale) - - return x + scale = self.adaLN_modulation(aux) + return modulate(self.norm(x), shift=0, scale=scale) def modulate(x, shift, scale): return x * (1 + scale) + shift From 546847f67a292f4b22eb1c9b1c1bb7533a7ad098 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 1 Apr 2026 11:22:05 +0200 Subject: [PATCH 260/344] fix conditioning during inference --- src/weathergen/model/diffusion.py | 34 +++++++++++++++++++------------ 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index e6fe6a039..0b5da8d30 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -111,19 +111,19 @@ def forward( ) else: #NOTE: temporary for analysing denoising - return self.training_forward( - tokens=tokens, - fstep=fstep, - meta_info=meta_info, - coords=coords, - ) - # if fstep is None: - # raise ValueError(f"During inference, fstep is required. Got fstep={fstep}") - # return self.inference_forward( + # return self.training_forward( + # tokens=tokens, # fstep=fstep, - # num_steps=num_steps, # meta_info=meta_info, + # coords=coords, # ) + if fstep is None: + raise ValueError(f"During inference, fstep is required. Got fstep={fstep}") + return self.inference_forward( + fstep=fstep, + num_steps=num_steps, + meta_info=meta_info, + ) def training_forward( self, @@ -162,7 +162,7 @@ def training_forward( self._noised_tokens = y + n - logger.info(f"Conditioning on date: {c}") + return self.denoise(x=y + n, c=c, sigma=sigma, fstep=fstep) @@ -172,6 +172,7 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int) -> consideration of a conditioning c (e.g., previous time steps) and the current diffusion noise level sigma. """ + # Compute scaling conditionings c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() @@ -227,8 +228,11 @@ def inference_forward( * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)) ) ** self.rho t_steps = torch.cat( - [self.net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])] + [t_steps, torch.zeros_like(t_steps[:1])] ) # t_N = 0 + # t_steps = torch.cat( + # [self.net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])] + # ) # t_N = 0 # Main sampling loop. x_next = x * t_steps[0] @@ -286,12 +290,16 @@ def __init__(self, embedding_dim: int, frequency_embedding_dim: int, dtype=torch def timestep_embedding(self, t: float, max_period: int = 10000): """ Create sinusoidal timestep embeddings. - :param t: a 1-D Tensor of N indices, one per batch element. + :param t: a scalar or 1-D Tensor of N indices, one per batch element. These may be fractional. :param dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. :return: an (N, D) Tensor of positional embeddings. """ + # Ensure t is at least 1D + if t.dim() == 0: + t = t.unsqueeze(0) + half = self.frequency_embedding_dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=self.dtype) / half From 922dd95dd78678efdae66e4535a6e15238e38c85 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 1 Apr 2026 12:14:05 +0200 Subject: [PATCH 261/344] disable inference --- src/weathergen/model/diffusion.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 0b5da8d30..85352354c 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -110,20 +110,20 @@ def forward( coords=coords, ) else: - #NOTE: temporary for analysing denoising - # return self.training_forward( - # tokens=tokens, - # fstep=fstep, - # meta_info=meta_info, - # coords=coords, - # ) - if fstep is None: - raise ValueError(f"During inference, fstep is required. Got fstep={fstep}") - return self.inference_forward( + # NOTE: temporary for analysing denoising + return self.training_forward( + tokens=tokens, fstep=fstep, - num_steps=num_steps, meta_info=meta_info, + coords=coords, ) + # if fstep is None: + # raise ValueError(f"During inference, fstep is required. Got fstep={fstep}") + # return self.inference_forward( + # fstep=fstep, + # num_steps=num_steps, + # meta_info=meta_info, + # ) def training_forward( self, From 11f968750132722bf8ce6f147e32b9236469c054 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 1 Apr 2026 12:14:22 +0200 Subject: [PATCH 262/344] adjust eta rendering --- src/weathergen/utils/validation_io.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 712566683..6e2eb4427 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -330,7 +330,8 @@ def write_output( noised_len_per_sample = noised_stream.shape[0] // num_samples if has_noised else 0 if noise_level is not None: - eta_str = re.sub(r'e[+]?0*(?=\d)', 'e', re.sub(r'e-0*(?=\d)', 'e-', f'{noise_level:.0e}')) + # Format with .1e to preserve one decimal place in mantissa, then clean up exponent notation + eta_str = re.sub(r'e[+]?0*(?=\d)', 'e', re.sub(r'e-0*(?=\d)', 'e-', f'{noise_level:.1e}')) else: eta_str = None eta_tag = f"_eta{eta_str}" if eta_str is not None else "" From f0e77d7009a1e1d293e1bad182bf672ec4a3fa52 Mon Sep 17 00:00:00 2001 From: sophiex <24638638+sophie-xhonneux@users.noreply.github.com> Date: Thu, 2 Apr 2026 14:22:25 +0200 Subject: [PATCH 263/344] New best --- config/config_forecasting.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/config/config_forecasting.yml b/config/config_forecasting.yml index 4f1ff1499..56ab2f4a0 100644 --- a/config/config_forecasting.yml +++ b/config/config_forecasting.yml @@ -67,6 +67,8 @@ forecast_att_dense_rate: 1.0 healpix_level: 5 rope_2D: False +mlp_type: swiglu +use_xsa: True with_mixed_precision: True with_flash_attention: True From 594a119869d05d4137db5bee5184d82ea663d2c1 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Thu, 2 Apr 2026 14:47:57 +0200 Subject: [PATCH 264/344] Fix plotting reset --- src/weathergen/utils/plot_training.py | 35 ++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index 8e13cd242..2cdcb9591 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -153,6 +153,33 @@ def clean_plot_folder(plot_dir: Path): image.unlink() +#################################################################################################### +def _adjust_reset_x_axis(x_vals, x_col: str) -> np.ndarray: + """ + Keep sample-based x-axes monotonic when chained jobs append metrics with a reset counter. + """ + adjusted_x_vals = np.array(x_vals, dtype=np.float64, copy=True) + + if adjusted_x_vals.size < 2 or "sample" not in x_col.lower(): + return adjusted_x_vals + + offset = 0.0 + prev_raw = np.nan + prev_adjusted = np.nan + for idx, raw_val in enumerate(adjusted_x_vals): + if np.isnan(raw_val): + continue + if not np.isnan(prev_raw) and raw_val < prev_raw: + offset = prev_adjusted + + adjusted_val = raw_val + offset + adjusted_x_vals[idx] = adjusted_val + prev_raw = raw_val + prev_adjusted = adjusted_val + + return adjusted_x_vals + + #################################################################################################### def get_stream_names(run_id: str, model_path: Path | None = "./model"): """ @@ -213,7 +240,7 @@ def plot_lr( x_col = next(filter(lambda c: x_axis in c, run_data.train.columns)) data_cols = list(filter(lambda c: "learning_rate" in c, run_data.train.columns)) - x_vals = run_data.train[x_col] + x_vals = _adjust_reset_x_axis(run_data.train[x_col], x_col) y_vals = np.array(run_data.train[data_cols]) mask = y_vals > 1000.0 y_vals[mask] = 0.0 # np.nan @@ -257,7 +284,7 @@ def plot_loss_avg(plot_dir: Path, runs_ids, runs_data, runs_active, stage=TRAIN, legend_str = [] for i_run, (run_id, run_data) in enumerate(zip(runs_ids, runs_data, strict=False)): run_data_stage = run_data.train if stage == TRAIN else run_data.val - x_vals = np.array(run_data_stage["num_samples"]) + x_vals = _adjust_reset_x_axis(run_data_stage["num_samples"], "num_samples") y_vals = np.array(run_data_stage["loss_avg_mean"]) mask = np.logical_and(~np.isnan(x_vals), ~np.isnan(y_vals)) @@ -387,7 +414,7 @@ def plot_loss_per_stream( data_cols += [col] for col in data_cols: - x_vals = np.array(run_data_mode[x_col]) + x_vals = _adjust_reset_x_axis(run_data_mode[x_col], x_col) y_data = np.array(run_data_mode[col]) mask = np.logical_and(~np.isnan(x_vals), ~np.isnan(y_data)) @@ -551,7 +578,7 @@ def plot_loss_per_run( if run_data_mode[col].shape[0] == 0: continue - x_vals = np.array(run_data_mode[x_col]) + x_vals = _adjust_reset_x_axis(run_data_mode[x_col], x_col) y_data = np.array(run_data_mode[col]) plt.plot( From 92af4bf4f7c3e026b7cb6b6b62dca87c1482e55c Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Fri, 3 Apr 2026 17:01:03 +0200 Subject: [PATCH 265/344] Inference success with z500 d128 model --- config/config_diffusion.yml | 10 +- config/config_diffusion_tiny.yml | 2 +- src/weathergen/datasets/masking.py | 1 - src/weathergen/model/diffusion.py | 34 ++-- src/weathergen/model/engines.py | 164 +++++++++--------- src/weathergen/model/layers.py | 15 +- src/weathergen/model/model.py | 5 - .../loss_module_latent_diffusion.py | 15 +- src/weathergen/utils/validation_io.py | 3 +- 9 files changed, 118 insertions(+), 131 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 19c42f396..d270e3281 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -75,9 +75,13 @@ rho: 7 p_mean: -1.2 p_std: 1.2 - healpix_level: 5 +# Use 2D RoPE instead of traditional global positional encoding +# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) +# When False: uses traditional pe_global positional encoding +rope_2D: True + with_mixed_precision: True with_flash_attention: True compile_model: False @@ -225,8 +229,8 @@ validation_config: # diffusion model during validation. sigma = exp(eta * p_std + p_mean). # Each value produces a separate validation pass with independently logged metrics. # validation_noise_levels: [0.3, 0.5, 0.75, 1.0, 1.5] - validation_noise_levels: [2.0, 3.0, 3.2, 3.5, 4.0, 5.0] - # validation_noise_levels: [1.0, 2.0, 4.0, 8.0, 16.0] + # validation_noise_levels: [2.0, 3.0, 3.2, 3.5, 4.0, 5.0] + validation_noise_levels: [1.0, 2.0, 3.0, 4.0] samples_per_mini_epoch: 16 shuffle: False diff --git a/config/config_diffusion_tiny.yml b/config/config_diffusion_tiny.yml index b87587f0b..947a9f76c 100644 --- a/config/config_diffusion_tiny.yml +++ b/config/config_diffusion_tiny.yml @@ -59,7 +59,7 @@ fe_num_blocks: 4 fe_num_heads: 4 fe_dropout_rate: 0.0 fe_with_qk_lnorm: True -fe_diffusion_model: False +fe_diffusion_model: True fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # currently fixed to 1.0 (due to limitations with flex_attention and triton) diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index 5c1cf86a0..7705f7803 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -561,7 +561,6 @@ def _generate_cell_mask( if "diffusion_rn" in masking_strategy_config: masking_params["noise_level_rn"] = self.rng.normal(0.0, 1.0) - # masking_params["noise_level_rn"] = self.rng.uniform(-1.0, 5.0) elif strategy == "healpix": # prepare healpix-based masking diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 8860d00bc..b47040068 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -99,11 +99,7 @@ def forward( ) self.cur_token = tokens.detach() - # return tokens - - # print("input tokens statistics") - # print("mean", tokens.mean(), "std", tokens.std(), "max", tokens.max(), "min", tokens.min()) - # return self.inference(fstep=fstep, num_steps=100) + # return self.inference(fstep=fstep, num_steps=50, coords=coords) c = 1 # TODO: add correct preconditioning (e.g., sample/s in previous time step) y = tokens @@ -121,12 +117,7 @@ def forward( self._noised_tokens = (y + n).detach() - # return self.denoise(x=y + n, c=c, sigma=sigma, fstep=fstep) - n = torch.ones_like(y) - # if self._noise is None: - # self._noise = torch.randn_like(y) - # n = self._noise - return self.denoise(x=n, c=c, sigma=sigma, fstep=fstep, coords=coords) + return self.denoise(x=y + n, c=c, sigma=sigma, fstep=fstep, coords=coords) def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int, coords: torch.Tensor = None) -> torch.Tensor: """ @@ -134,10 +125,10 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int, co consideration of a conditioning c (e.g., previous time steps) and the current diffusion noise level sigma. """ - # # Compute scaling conditionings (EDM Eq. 7 — disabled for direct prediction) - # c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) - # c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() - # c_in = 1 / (sigma**2 + self.sigma_data**2).sqrt() + # Compute scaling conditionings (EDM Eq. 7 — disabled for direct prediction) + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (sigma**2 + self.sigma_data**2).sqrt() c_noise = sigma.log() / 4 # Embed noise level @@ -147,15 +138,16 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int, co x = self.preconditioner.precondition(x, c) # Direct prediction: network outputs denoised estimate directly - return self.net(x, fstep=fstep, coords=coords, noise_emb=noise_emb) - # return c_skip * x + c_out * self.net( - # c_in * x, fstep=fstep, noise_emb=noise_emb - # ) # Eq. (7) in EDM paper + # return self.net(x, fstep=fstep, coords=coords, noise_emb=noise_emb) + return c_skip * x + c_out * self.net( + c_in * x, fstep=fstep, coords=coords, noise_emb=noise_emb + ) # Eq. (7) in EDM paper def inference( self, fstep: int, num_steps: int = 30, + coords: torch.Tensor = None, ) -> torch.Tensor: # Forward pass of the diffusion model during inference # https://github.com/NVlabs/edm/blob/main/generate.py @@ -207,13 +199,13 @@ def inference( t_hat = t_cur # Euler step. - denoised = self.denoise(x=x_hat, c=None, sigma=t_hat, fstep=fstep) # c to be discussed + denoised = self.denoise(x=x_hat, c=None, sigma=t_hat, fstep=fstep, coords=coords) # c to be discussed d_cur = (x_hat - denoised) / t_hat x_next = x_hat + (t_next - t_hat) * d_cur # Apply 2nd order correction. if i < num_steps - 1: - denoised = self.denoise(x=x_next, c=None, sigma=t_next, fstep=fstep) + denoised = self.denoise(x=x_next, c=None, sigma=t_next, fstep=fstep, coords=coords) d_prime = (x_next - denoised) / t_next x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 51cee3d53..78694477f 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -405,87 +405,90 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = self.num_healpix_cells = num_healpix_cells self.fe_blocks = torch.nn.ModuleList() - # global_rate = int(1 / self.cf.forecast_att_dense_rate) - # if mode_cfg.get("forecast", {}).get("policy") is not None: - # for i in range(self.cf.fe_num_blocks): - # # Alternate between global and local attention - # if (i % global_rate == 0) or i + 1 == self.cf.fe_num_blocks: - # self.fe_blocks.append( - # MultiSelfAttentionHead( - # self.cf.ae_global_dim_embed, - # num_heads=self.cf.fe_num_heads, - # dropout_rate=self.cf.fe_dropout_rate, - # with_qk_lnorm=self.cf.fe_with_qk_lnorm, - # with_flash=self.cf.with_flash_attention, - # norm_type=self.cf.norm_type, - # dim_aux=dim_aux, - # norm_eps=self.cf.norm_eps, - # attention_dtype=get_dtype(self.cf.attention_dtype), - # with_noise_conditioning=self.cf.fe_diffusion_model, - # with_2d_rope=self.cf.get("rope_2D", False), - # ) - # ) - # else: - # self.fe_blocks.append( - # MultiSelfAttentionHeadLocal( - # self.cf.ae_global_dim_embed, - # num_heads=self.cf.fe_num_heads, - # qkv_len=self.num_healpix_cells * self.cf.ae_local_num_queries, - # block_factor=self.cf.ae_global_block_factor, - # dropout_rate=self.cf.fe_dropout_rate, - # with_qk_lnorm=self.cf.fe_with_qk_lnorm, - # with_flash=self.cf.with_flash_attention, - # norm_type=self.cf.norm_type, - # dim_aux=dim_aux, - # norm_eps=self.cf.norm_eps, - # attention_dtype=get_dtype(self.cf.attention_dtype), - # with_noise_conditioning=self.cf.fe_diffusion_model, - # with_2d_rope=self.cf.get("rope_2D", False), - # ) - # ) - # # Add MLP block - # self.fe_blocks.append( - # MLP( - # self.cf.ae_global_dim_embed, - # self.cf.ae_global_dim_embed, - # with_residual=False, - # dropout_rate=self.cf.fe_dropout_rate, - # norm_type=self.cf.norm_type, - # dim_aux=dim_aux, - # norm_eps=self.cf.mlp_norm_eps, - # with_noise_conditioning=self.cf.fe_diffusion_model, - # ) - # ) - # # Optionally, add LayerNorm after i-th layer - # if i in self.cf.get("fe_layer_norm_after_blocks", []): - # self.fe_blocks.append( - # torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) - # ) - - self.fe_blocks.append( - MLP( - self.cf.ae_global_dim_embed + 2, - self.cf.ae_global_dim_embed, - num_layers=12, - with_residual=False, - pre_layer_norm=False, # TODO: REMOVE AGAIN - dropout_rate=self.cf.fe_dropout_rate, - norm_type=self.cf.norm_type, - dim_aux=dim_aux, - norm_eps=self.cf.mlp_norm_eps, - with_noise_conditioning=False, # TODO: SWITCH BACK TO TRUE - ) - ) - # def init_weights_final(m): - # if isinstance(m, torch.nn.Linear): - # torch.nn.init.normal_(m.weight, mean=0, std=0.001) - # if m.bias is not None: - # torch.nn.init.normal_(m.bias, mean=0, std=0.001) + # self.position_layer = torch.nn.Linear(2, self.cf.ae_global_dim_embed) + + global_rate = int(1 / self.cf.forecast_att_dense_rate) + if mode_cfg.get("forecast", {}).get("policy") is not None: + for i in range(self.cf.fe_num_blocks): + # Alternate between global and local attention + if (i % global_rate == 0) or i + 1 == self.cf.fe_num_blocks: + self.fe_blocks.append( + MultiSelfAttentionHead( + self.cf.ae_global_dim_embed, + num_heads=self.cf.fe_num_heads, + dropout_rate=self.cf.fe_dropout_rate, + with_qk_lnorm=self.cf.fe_with_qk_lnorm, + with_flash=self.cf.with_flash_attention, + norm_type=self.cf.norm_type, + dim_aux=dim_aux, + norm_eps=self.cf.norm_eps, + attention_dtype=get_dtype(self.cf.attention_dtype), + with_noise_conditioning=self.cf.fe_diffusion_model, + with_2d_rope=self.cf.get("rope_2D", False), + ) + ) + else: + self.fe_blocks.append( + MultiSelfAttentionHeadLocal( + self.cf.ae_global_dim_embed, + num_heads=self.cf.fe_num_heads, + qkv_len=self.num_healpix_cells * self.cf.ae_local_num_queries, + block_factor=self.cf.ae_global_block_factor, + dropout_rate=self.cf.fe_dropout_rate, + with_qk_lnorm=self.cf.fe_with_qk_lnorm, + with_flash=self.cf.with_flash_attention, + norm_type=self.cf.norm_type, + dim_aux=dim_aux, + norm_eps=self.cf.norm_eps, + attention_dtype=get_dtype(self.cf.attention_dtype), + with_noise_conditioning=self.cf.fe_diffusion_model, + with_2d_rope=self.cf.get("rope_2D", False), + ) + ) + # Add MLP block + self.fe_blocks.append( + MLP( + self.cf.ae_global_dim_embed, + self.cf.ae_global_dim_embed, + num_layers=2, + with_residual=True, + dropout_rate=self.cf.fe_dropout_rate, + norm_type=self.cf.norm_type, + dim_aux=dim_aux, + norm_eps=self.cf.mlp_norm_eps, + with_noise_conditioning=self.cf.fe_diffusion_model + ) + ) + # Optionally, add LayerNorm after i-th layer + if i in self.cf.get("fe_layer_norm_after_blocks", []): + self.fe_blocks.append( + torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) + ) + + # self.fe_blocks.append( + # MLP( + # self.cf.ae_global_dim_embed, + # self.cf.ae_global_dim_embed, + # num_layers=12, + # with_residual=True, + # pre_layer_norm=True, # TODO: REMOVE AGAIN + # dropout_rate=self.cf.fe_dropout_rate, + # norm_type=self.cf.norm_type, + # dim_aux=dim_aux, + # norm_eps=self.cf.mlp_norm_eps, + # with_noise_conditioning=True, # TODO: SWITCH BACK TO TRUE + # ) + # ) def init_weights_final(m): if isinstance(m, torch.nn.Linear): - torch.nn.init.normal_(m.weight, mean=0, std=0.1) + torch.nn.init.normal_(m.weight, mean=0, std=0.001) if m.bias is not None: - torch.nn.init.normal_(m.bias, mean=0, std=0.1) + torch.nn.init.normal_(m.bias, mean=0, std=0.001) + # def init_weights_final(m): + # if isinstance(m, torch.nn.Linear): + # torch.nn.init.normal_(m.weight, mean=0, std=0.1) + # if m.bias is not None: + # torch.nn.init.normal_(m.bias, mean=0, std=0.1) for block in self.fe_blocks: block.apply(init_weights_final) @@ -520,10 +523,13 @@ def forward( if isinstance(block, torch.nn.LayerNorm): tokens = checkpoint(block, tokens, use_reentrant=False) else: + # if isinstance(block, MLP): + # # tokens = torch.concat([tokens, coords], dim=-1) if coords is not None else tokens + # # TODO: REMOVE + # tokens = tokens + self.position_layer(coords) # Assuming args[1] contains positional information tokens = checkpoint(block, tokens, coords, noise_emb, aux_info, use_reentrant=False) else: for block in self.fe_blocks: - tokens = torch.concat([tokens, coords], dim=-1) if coords is not None else tokens if isinstance(block, torch.nn.LayerNorm): tokens = checkpoint(block, tokens, use_reentrant=False) else: diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index 4b968c7de..9497ffa36 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -60,7 +60,7 @@ def __init__( dim_aux=None, norm_eps=1e-5, name: str | None = None, - with_noise_conditioning=False, + with_noise_conditioning=False ): """Constructor""" @@ -114,11 +114,14 @@ def forward(self, *args): gate = None for i, layer in enumerate(self.layers): - x = layer(x, aux) if (i == 0 and self.with_aux) else layer(x) - # Apply noise conditioning after layer norm (first layer), mirroring - # the AdaLN-Zero pattern used in MultiSelfAttentionHead - if i == 0 and self.with_noise_conditioning: - x, gate = self.noise_conditioning(x, noise_emb) + if i == 0 and self.with_aux: + x = layer(x, aux) + if i == 0 and self.with_noise_conditioning: + x, gate = self.noise_conditioning(x, noise_emb) + else: + if i == 0 and self.with_noise_conditioning: + x, gate = self.noise_conditioning(x, noise_emb) + x = layer(x) if self.with_residual: if gate is not None: diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index ffc2c77b4..754289d16 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -737,11 +737,6 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # collapse along input step dimension tokens = tokens.reshape(shape).sum(axis=1) - tokens = torch.ones_like(tokens) - # if self._noise is None: - # self._noise = torch.randn_like(tokens) - # tokens = self._noise - # Normalize tokens # TODO: REMOVE THIS LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. # t_mean = tokens.mean() diff --git a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py index 0cc8e1344..a89c9d6b8 100644 --- a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py +++ b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py @@ -103,8 +103,7 @@ def compute_loss(self, preds: dict, targets: dict, **kwargs) -> LossValues: fsteps = len(target_tokens_all) # During validation, use unweighted loss (no noise-level scaling) - # noise_weight = 1.0 if self.stage == "val" else self._get_noise_weight(eta) - noise_weight = 1.0 + noise_weight = 1.0 if self.stage == "val" else self._get_noise_weight(eta) fstep_loss_weights = self._get_fstep_weights(fsteps) loss_fsteps = torch.tensor(0.0, device=self.device, requires_grad=True) @@ -118,18 +117,6 @@ def compute_loss(self, preds: dict, targets: dict, **kwargs) -> LossValues: # if forecast_offset==0, then the timepoints correspond. # Otherwise targets don't encode the source timestep, so we don't need to skip for loss_fct, loss_fct_weight, loss_fct_name in self.loss_fcts: - - - - # Try random fixed target - if self.random_target is None: - self.random_target = torch.randn_like(target_tokens) * 1.0 + 0.0 - # self.random_target = torch.ones_like(target_tokens) - target_tokens = self.random_target - - print("pred std", pred_tokens.std().item(), "pred mean", pred_tokens.mean().item()) - print("trgt std", target_tokens.std().item(), "trgt mean", target_tokens.mean().item(), ) - loss_lfct = self._loss_per_loss_function( loss_fct, target=target_tokens, diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 76aa60d75..9284ed2be 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -264,7 +264,8 @@ def write_output( plotter = Plotter({"image_format": "png", "dpi_val": 150}, base_plot_dir) # headline_channels = {"2t", "z500", "q850", "10u", "10v"} # headline_channels = {"2t", "q850"} - headline_channels = {"z500"} + # headline_channels = {"z500"} + headline_channels = {"2t", "z500"} t_idx = 0 for stream_idx, stream_info in enumerate(cf.streams): From 9eec75f022ce69d4c0ad8d026950d0cb70b6a5fd Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Mon, 6 Apr 2026 11:55:29 +0200 Subject: [PATCH 266/344] Multi-sample small working --- config/config_diffusion.yml | 8 ++- config/config_diffusion_tiny.yml | 68 ++++++++++++++++--- .../streams/era5_1deg_diffusion_tiny/era5.yml | 6 +- config/streams/era5_1deg_forecasting/era5.yml | 9 +-- 4 files changed, 73 insertions(+), 18 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index d270e3281..527803eb0 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -70,7 +70,7 @@ frequency_embedding_dim: 256 embedding_dim: 512 sigma_min: 0.002 sigma_max: 80 # 170 -sigma_data: 0.5 # 1.7 +sigma_data: 0.5 # 1.7, 157.38 rho: 7 p_mean: -1.2 p_std: 1.2 @@ -99,13 +99,15 @@ latent_noise_deterministic_latents: True freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" -load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} +# load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} +load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} norm_type: "LayerNorm" ##################################### -streams_directory: "./config/streams/era5_1deg/" +# streams_directory: "./config/streams/era5_1deg/" +streams_directory: "./config/streams/era5_1deg_forecasting_z500/" streams: ??? # type of zarr_store diff --git a/config/config_diffusion_tiny.yml b/config/config_diffusion_tiny.yml index 947a9f76c..262534721 100644 --- a/config/config_diffusion_tiny.yml +++ b/config/config_diffusion_tiny.yml @@ -7,19 +7,67 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +# # z500 small +# embed_orientation: "channels" +# embed_unembed_mode: "block" +# embed_dropout_rate: 0.1 + +# ae_local_dim_embed: 128 +# ae_local_num_blocks: 0 +# ae_local_num_heads: 8 +# ae_local_dropout_rate: 0.0 +# ae_local_with_qk_lnorm: True + +# ae_local_num_queries: 1 +# ae_local_queries_per_cell: False +# ae_adapter_num_heads: 8 +# ae_adapter_embed: 128 +# ae_adapter_with_qk_lnorm: True +# ae_adapter_with_residual: True +# ae_adapter_dropout_rate: 0.0 + +# ae_global_dim_embed: 128 +# ae_global_num_blocks: 4 +# ae_global_num_heads: 8 +# ae_global_dropout_rate: 0.0 +# ae_global_with_qk_lnorm: True +# # TODO: switching to < 1 triggers triton-related issues. +# # See https://github.com/ecmwf/WeatherGenerator/issues/1050 +# ae_global_att_dense_rate: 1.0 +# ae_global_block_factor: 64 +# ae_global_mlp_hidden_factor: 2 +# ae_global_trailing_layer_norm: False + +# ae_aggregation_num_blocks: 0 +# ae_aggregation_num_heads: 4 +# ae_aggregation_dropout_rate: 0.0 +# ae_aggregation_with_qk_lnorm: True +# ae_aggregation_att_dense_rate: 1.0 +# ae_aggregation_block_factor: 64 +# ae_aggregation_mlp_hidden_factor: 2 + +# decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +# pred_adapter_kv: False +# pred_self_attention: True +# pred_dyadic_dims: False +# pred_mlp_adaln: True +# num_class_tokens: 0 +# num_register_tokens: 0 + +# # multi-var small embed_orientation: "channels" embed_unembed_mode: "block" embed_dropout_rate: 0.1 ae_local_dim_embed: 128 ae_local_num_blocks: 0 -ae_local_num_heads: 8 +ae_local_num_heads: 4 ae_local_dropout_rate: 0.0 ae_local_with_qk_lnorm: True ae_local_num_queries: 1 ae_local_queries_per_cell: False -ae_adapter_num_heads: 8 +ae_adapter_num_heads: 16 ae_adapter_embed: 128 ae_adapter_with_qk_lnorm: True ae_adapter_with_residual: True @@ -27,13 +75,13 @@ ae_adapter_dropout_rate: 0.0 ae_global_dim_embed: 128 ae_global_num_blocks: 4 -ae_global_num_heads: 8 +ae_global_num_heads: 4 ae_global_dropout_rate: 0.0 ae_global_with_qk_lnorm: True # TODO: switching to < 1 triggers triton-related issues. # See https://github.com/ecmwf/WeatherGenerator/issues/1050 ae_global_att_dense_rate: 1.0 -ae_global_block_factor: 64 +ae_global_block_factor: 8 ae_global_mlp_hidden_factor: 2 ae_global_trailing_layer_norm: False @@ -42,7 +90,7 @@ ae_aggregation_num_heads: 4 ae_aggregation_dropout_rate: 0.0 ae_aggregation_with_qk_lnorm: True ae_aggregation_att_dense_rate: 1.0 -ae_aggregation_block_factor: 64 +ae_aggregation_block_factor: 8 ae_aggregation_mlp_hidden_factor: 2 decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear @@ -99,13 +147,15 @@ latent_noise_deterministic_latents: True freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" -load_chkpt: {'run_id': 'lgrasnq6', 'epoch': -1} +# load_chkpt: {'run_id': 'lgrasnq6', 'epoch': -1} # z500 small +load_chkpt: {'run_id': 'hhz27wy0', 'epoch': -1} # multi-var small norm_type: "LayerNorm" ##################################### -streams_directory: "./config/streams/era5_1deg_diffusion_tiny/" +# streams_directory: "./config/streams/era5_1deg_diffusion_tiny/" # z500 small +streams_directory: "./config/streams/era5_1deg_forecasting/" # multi-var small streams: ??? # type of zarr_store @@ -157,7 +207,7 @@ training_config: # training_mode: "masking", "student_teacher", "latent_loss" training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] - num_mini_epochs: 128 + num_mini_epochs: 256 samples_per_mini_epoch: 1024 shuffle: True @@ -255,7 +305,7 @@ validation_config: } # run validation before training starts (mainly for model development) - validate_before_training: False + validate_before_training: True # test config; full test config is merge of validation and test config diff --git a/config/streams/era5_1deg_diffusion_tiny/era5.yml b/config/streams/era5_1deg_diffusion_tiny/era5.yml index 96b3aa6a1..40da38c09 100644 --- a/config/streams/era5_1deg_diffusion_tiny/era5.yml +++ b/config/streams/era5_1deg_diffusion_tiny/era5.yml @@ -11,8 +11,10 @@ ERA5 : type : anemoi filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] stream_id : 0 - source : ["z_500"] - target : ["z_500"] + # source : ["z_500"] + # target : ["z_500"] + source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] + target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] loss_weight : 1. location_weight : cosine_latitude masking_rate : 0.6 diff --git a/config/streams/era5_1deg_forecasting/era5.yml b/config/streams/era5_1deg_forecasting/era5.yml index 0bd70ae01..39f67714f 100644 --- a/config/streams/era5_1deg_forecasting/era5.yml +++ b/config/streams/era5_1deg_forecasting/era5.yml @@ -19,16 +19,17 @@ ERA5 : masking_rate_none : 0.05 token_size : 8 tokenize_spacetime : True - max_num_targets: 20000 + # max_num_targets: 20000 + max_num_targets: -1 embed : net : transformer num_tokens : 1 - num_heads : 8 - dim_embed : 256 + num_heads : 4 + dim_embed : 32 num_blocks : 2 embed_target_coords : net : linear - dim_embed : 256 + dim_embed : 32 target_readout : num_layers : 2 num_heads : 4 From b93b79231f9e48915e4338f40f0430c858f7e30e Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Tue, 7 Apr 2026 14:01:47 +0200 Subject: [PATCH 267/344] Successful inference with 5% noise or full noise with small model and rope --- config/config_diffusion.yml | 13 +- config/config_diffusion_tiny.yml | 2 +- config/config_forecasting_z500.yml | 251 ++++++++++++++++++ config/streams/era5_1deg_forecasting/era5.yml | 136 +++++----- .../era5_1deg_forecasting_z500/era5.yml | 38 +++ src/weathergen/model/diffusion.py | 4 +- src/weathergen/model/encoder.py | 4 + 7 files changed, 373 insertions(+), 75 deletions(-) create mode 100644 config/config_forecasting_z500.yml create mode 100644 config/streams/era5_1deg_forecasting_z500/era5.yml diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 527803eb0..b481f9e0a 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -99,15 +99,20 @@ latent_noise_deterministic_latents: True freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" -# load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} -load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} +load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 +# load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} # z500 d2048 hl5, sigma_data=159.08 +# load_chkpt: {'run_id': 'tvkicam9', 'epoch': -1} # z500 d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'q9grso75', 'epoch': -1} # z500 d2048 hl3, sigma_data=39.2936 +# load_chkpt: {'run_id': 'wvpb76ai', 'epoch': -1} # multi-var d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'ae4wlc5m', 'epoch': -1} # multi-var d2048 hl3, sigma_data= + norm_type: "LayerNorm" ##################################### -# streams_directory: "./config/streams/era5_1deg/" -streams_directory: "./config/streams/era5_1deg_forecasting_z500/" +streams_directory: "./config/streams/era5_1deg_forecasting/" +# streams_directory: "./config/streams/era5_1deg_forecasting_z500/" streams: ??? # type of zarr_store diff --git a/config/config_diffusion_tiny.yml b/config/config_diffusion_tiny.yml index 262534721..c20e442d0 100644 --- a/config/config_diffusion_tiny.yml +++ b/config/config_diffusion_tiny.yml @@ -148,7 +148,7 @@ latent_noise_deterministic_latents: True freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" # load_chkpt: {'run_id': 'lgrasnq6', 'epoch': -1} # z500 small -load_chkpt: {'run_id': 'hhz27wy0', 'epoch': -1} # multi-var small +load_chkpt: {'run_id': 'hhz27wy0', 'epoch': -1} # multi-var small, sigma_data=0.7855 norm_type: "LayerNorm" diff --git a/config/config_forecasting_z500.yml b/config/config_forecasting_z500.yml new file mode 100644 index 000000000..5dc0dbe43 --- /dev/null +++ b/config/config_forecasting_z500.yml @@ -0,0 +1,251 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 16 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 + +healpix_level: 3 + +rope_2D: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: True +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +freeze_modules: "" +load_chkpt: {} + +norm_type: "LayerNorm" + +##################################### + +# streams_directory: "./config/streams/era5_1deg_forecasting_z500/" +streams_directory: "./config/streams/era5_1deg_forecasting/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking"] + + num_mini_epochs: 64 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 1979-01-01T00:00 + end_date: 2022-12-31T00:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 5e-5 + lr_final_decay: 2e-6 + lr_final: 0.0 + num_steps_warmup: 256 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.98125 # == 0.85 on 2 nodes x 4 gpus + beta2 : 0.9875 # == 0.90 on 2 nodes x 4 gpus + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + loss_fcts: { "mse": { }, }, + }, + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + }, + } + + forecast : + time_step: 06:00:00 + offset: 1 + num_steps: 3 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + samples_per_mini_epoch: 256 + shuffle: False + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T00:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: False + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/streams/era5_1deg_forecasting/era5.yml b/config/streams/era5_1deg_forecasting/era5.yml index 39f67714f..09f88b654 100644 --- a/config/streams/era5_1deg_forecasting/era5.yml +++ b/config/streams/era5_1deg_forecasting/era5.yml @@ -25,11 +25,11 @@ ERA5 : net : transformer num_tokens : 1 num_heads : 4 - dim_embed : 32 + dim_embed : 512 num_blocks : 2 embed_target_coords : net : linear - dim_embed : 32 + dim_embed : 512 target_readout : num_layers : 2 num_heads : 4 @@ -37,69 +37,69 @@ ERA5 : pred_head : ens_size : 1 num_layers : 1 - channel_weights : - q_50: 0.2 - q_100: 0.23 - q_150: 0.26 - q_200: 0.29 - q_250: 0.33 - q_300: 0.36 - q_400: 0.42 - q_500: 0.48 - q_600: 0.55 - q_700: 0.61 - q_850: 0.71 - q_925: 0.75 - q_1000: 0.8 - t_50: 0.2 - t_100: 0.23 - t_150: 0.26 - t_200: 0.29 - t_250: 0.33 - t_300: 0.36 - t_400: 0.42 - t_500: 0.48 - t_600: 0.55 - t_700: 0.61 - t_850: 0.71 - t_925: 0.75 - t_1000: 0.8 - u_50: 0.2 - u_100: 0.23 - u_150: 0.26 - u_200: 0.29 - u_250: 0.33 - u_300: 0.36 - u_400: 0.42 - u_500: 0.48 - u_600: 0.55 - u_700: 0.61 - u_850: 0.71 - u_925: 0.75 - u_1000: 0.8 - v_50: 0.2 - v_100: 0.23 - v_150: 0.26 - v_200: 0.29 - v_250: 0.33 - v_300: 0.36 - v_400: 0.42 - v_500: 0.48 - v_600: 0.55 - v_700: 0.61 - v_850: 0.71 - v_925: 0.75 - v_1000: 0.8 - z_50: 0.2 - z_100: 0.23 - z_150: 0.26 - z_200: 0.29 - z_250: 0.33 - z_300: 0.36 - z_400: 0.42 - z_500: 0.48 - z_600: 0.55 - z_700: 0.61 - z_850: 0.71 - z_925: 0.75 - z_1000: 0.8 \ No newline at end of file + # channel_weights : + # q_50: 0.2 + # q_100: 0.23 + # q_150: 0.26 + # q_200: 0.29 + # q_250: 0.33 + # q_300: 0.36 + # q_400: 0.42 + # q_500: 0.48 + # q_600: 0.55 + # q_700: 0.61 + # q_850: 0.71 + # q_925: 0.75 + # q_1000: 0.8 + # t_50: 0.2 + # t_100: 0.23 + # t_150: 0.26 + # t_200: 0.29 + # t_250: 0.33 + # t_300: 0.36 + # t_400: 0.42 + # t_500: 0.48 + # t_600: 0.55 + # t_700: 0.61 + # t_850: 0.71 + # t_925: 0.75 + # t_1000: 0.8 + # u_50: 0.2 + # u_100: 0.23 + # u_150: 0.26 + # u_200: 0.29 + # u_250: 0.33 + # u_300: 0.36 + # u_400: 0.42 + # u_500: 0.48 + # u_600: 0.55 + # u_700: 0.61 + # u_850: 0.71 + # u_925: 0.75 + # u_1000: 0.8 + # v_50: 0.2 + # v_100: 0.23 + # v_150: 0.26 + # v_200: 0.29 + # v_250: 0.33 + # v_300: 0.36 + # v_400: 0.42 + # v_500: 0.48 + # v_600: 0.55 + # v_700: 0.61 + # v_850: 0.71 + # v_925: 0.75 + # v_1000: 0.8 + # z_50: 0.2 + # z_100: 0.23 + # z_150: 0.26 + # z_200: 0.29 + # z_250: 0.33 + # z_300: 0.36 + # z_400: 0.42 + # z_500: 0.48 + # z_600: 0.55 + # z_700: 0.61 + # z_850: 0.71 + # z_925: 0.75 + # z_1000: 0.8 \ No newline at end of file diff --git a/config/streams/era5_1deg_forecasting_z500/era5.yml b/config/streams/era5_1deg_forecasting_z500/era5.yml new file mode 100644 index 000000000..5240bf8d1 --- /dev/null +++ b/config/streams/era5_1deg_forecasting_z500/era5.yml @@ -0,0 +1,38 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5 : + type : anemoi + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + stream_id : 0 + source : ['z_500'] + target : ['z_500'] + loss_weight : 1. + location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 8 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 4 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 256 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index b47040068..7af893ca4 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -99,7 +99,7 @@ def forward( ) self.cur_token = tokens.detach() - # return self.inference(fstep=fstep, num_steps=50, coords=coords) + # return self.inference(fstep=fstep, num_steps=10, coords=coords) c = 1 # TODO: add correct preconditioning (e.g., sample/s in previous time step) y = tokens @@ -161,7 +161,7 @@ def inference( # n = torch.randn_like(x).to(device="cuda") * sigma # x = self.cur_token + n - # x = self.cur_token * 0.01 + x + x = self.cur_token * 0.05 + x # breakpoint() diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index ebfac1ab8..77ba3f3da 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -112,6 +112,8 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord # global assimilation engine self.ae_global_engine = GlobalAssimilationEngine(cf, self.num_healpix_cells) + # self.ln = torch.nn.LayerNorm(cf.ae_local_dim_embed, elementwise_affine=False) + def forward(self, model_params, batch): """ Encoder forward @@ -132,6 +134,8 @@ def forward(self, model_params, batch): use_reentrant=False, ) + # tokens_global = self.ln(tokens_global) + return tokens_global, posteriors def interpolate_latents(self, tokens: torch.Tensor) -> (torch.Tensor, torch.Tensor): From 720681f37dac30ee79fd72302a88f8b6a7f5029e Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Tue, 7 Apr 2026 18:09:34 +0200 Subject: [PATCH 268/344] Minor diffusion config update --- config/config_diffusion.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index b481f9e0a..9fd4b4f4d 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -70,7 +70,7 @@ frequency_embedding_dim: 256 embedding_dim: 512 sigma_min: 0.002 sigma_max: 80 # 170 -sigma_data: 0.5 # 1.7, 157.38 +sigma_data: 1.7 # 0.5 # 1.7, 157.38 rho: 7 p_mean: -1.2 p_std: 1.2 @@ -176,7 +176,7 @@ training_config: learning_rate_scheduling : lr_start: 1e-5 #5e-5 - lr_max: 1e-4 #1e-4 + lr_max: 5e-5 #1e-4 lr_final_decay: 1e-6 lr_final: 0.0 num_steps_warmup: 64 From 2c48e3184e942bfa07cda194545b9a42df64585b Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Wed, 8 Apr 2026 10:17:04 +0200 Subject: [PATCH 269/344] Config for 128-dim hl5 z500 --- config/config_diffusion_tiny.yml | 6 +++--- config/streams/era5_1deg_diffusion_tiny/era5.yml | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/config/config_diffusion_tiny.yml b/config/config_diffusion_tiny.yml index c20e442d0..065e76d2b 100644 --- a/config/config_diffusion_tiny.yml +++ b/config/config_diffusion_tiny.yml @@ -123,7 +123,7 @@ rho: 7 p_mean: -1.2 p_std: 1.2 -healpix_level: 3 +healpix_level: 5 # Use 2D RoPE instead of traditional global positional encoding # When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) @@ -154,8 +154,8 @@ norm_type: "LayerNorm" ##################################### -# streams_directory: "./config/streams/era5_1deg_diffusion_tiny/" # z500 small -streams_directory: "./config/streams/era5_1deg_forecasting/" # multi-var small +streams_directory: "./config/streams/era5_1deg_diffusion_tiny/" # z500 small +# streams_directory: "./config/streams/era5_1deg_forecasting/" # multi-var small streams: ??? # type of zarr_store diff --git a/config/streams/era5_1deg_diffusion_tiny/era5.yml b/config/streams/era5_1deg_diffusion_tiny/era5.yml index 40da38c09..efa163d42 100644 --- a/config/streams/era5_1deg_diffusion_tiny/era5.yml +++ b/config/streams/era5_1deg_diffusion_tiny/era5.yml @@ -11,10 +11,10 @@ ERA5 : type : anemoi filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] stream_id : 0 - # source : ["z_500"] - # target : ["z_500"] - source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] - target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] + source : ["z_500"] + target : ["z_500"] + # source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] + # target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] loss_weight : 1. location_weight : cosine_latitude masking_rate : 0.6 From 900e220e9f443b5379cb999f57fafadf92d3051a Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 8 Apr 2026 15:55:18 +0200 Subject: [PATCH 270/344] config changes --- config/config_diffusion.yml | 2 +- config/runs_plot_train.yml | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 config/runs_plot_train.yml diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index c7ba41557..a4583ca31 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -227,7 +227,7 @@ validation_config: # diffusion model during validation. sigma = exp(eta * p_std + p_mean). # Each value produces a separate validation pass with independently logged metrics. # validation_noise_levels: [0.3, 0.5, 0.75, 1.0, 1.5] - validation_noise_levels: [0.3, 1.5] + validation_noise_levels: [0.3, 1.5, 3.0] samples_per_mini_epoch: 1 shuffle: True # TODO: Set back to False diff --git a/config/runs_plot_train.yml b/config/runs_plot_train.yml new file mode 100644 index 000000000..6bd2a91bc --- /dev/null +++ b/config/runs_plot_train.yml @@ -0,0 +1,5 @@ +train : + plot : + h8wnm1kt: + slurm_id: 0 + description: "first conditioning experiment" \ No newline at end of file From d1f2a08de63137f60cc29d40d393902b64c47ec9 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Wed, 8 Apr 2026 16:20:58 +0200 Subject: [PATCH 271/344] Inference diagnostic tools --- src/weathergen/model/diffusion.py | 229 ++++++++++++++++++++++++++---- 1 file changed, 203 insertions(+), 26 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 7af893ca4..e69759ab4 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -28,7 +28,7 @@ import torch -from weathergen.common.config import Config +from weathergen.common.config import Config, get_path_run from weathergen.datasets.batch import SampleMetaData from weathergen.model.engines import ForecastingEngine @@ -99,7 +99,7 @@ def forward( ) self.cur_token = tokens.detach() - # return self.inference(fstep=fstep, num_steps=10, coords=coords) + return self.inference(fstep=fstep, num_steps=15, coords=coords) c = 1 # TODO: add correct preconditioning (e.g., sample/s in previous time step) y = tokens @@ -146,39 +146,57 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int, co def inference( self, fstep: int, - num_steps: int = 30, + num_steps: int = 50, coords: torch.Tensor = None, ) -> torch.Tensor: - # Forward pass of the diffusion model during inference + # Forward pass of the diffusion model during inference (Heun sampler) # https://github.com/NVlabs/edm/blob/main/generate.py - # Sample noise (assuming single batch element for now) - x = torch.randn(1, self.num_healpix_cells, self.cf.ae_global_dim_embed).to(device="cuda") * 1.0 - - # eta = torch.tensor([1.0], device="cuda").float() # 1.0 (good), 2.0 (okay), 2.2 (max), 2.5 (hard) - # sigma = (eta * self.p_std + self.p_mean).exp() - # print("sigma", sigma) - # n = torch.randn_like(x).to(device="cuda") * sigma - # x = self.cur_token + n - - x = self.cur_token * 0.05 + x - # breakpoint() - - - # return self.denoise(x=x, c=None, sigma=sigma, fstep=fstep) - # print("initial noise statistics") - # print("mean", x.mean(), "std", x.std(), "max", x.max(), "min", x.min()) + # Sample pure noise (assuming single batch element for now) + torch.manual_seed(42) + x = torch.randn(1, self.num_healpix_cells, self.cf.ae_global_dim_embed).to(device="cuda") + x = self.cur_token * 1.0 + x * 0.1 + + # --- Training-aligned sigma bounds --- + # Training noise: sigma = exp(eta * p_std + p_mean), eta ~ N(0,1). + # The network only learns to denoise reliably within the training distribution. + # - sigma_max_eff: cap at 99.7th percentile = exp(p_mean + 3*p_std) + # Beyond this, the denoiser is in untrained territory → garbage predictions + # that poison the entire ODE trajectory. + # - sigma_min_eff: floor at a level where the network still contributes. + # With EDM preconditioning, c_skip = sigma_data^2/(sigma^2+sigma_data^2). + # At sigma << sigma_data, c_skip → 1, meaning the output ≈ input (skip + # connection dominates) and the network can no longer correct errors. + # We stop at sigma_min = max(config value, sigma_data * 0.01), which gives + # c_skip ≈ 0.9999 — still some network contribution, and avoids the + # numerical instability of dividing by near-zero sigma in the ODE. + sigma_max_train = math.exp(self.p_mean + 3.0 * self.p_std) + sigma_max_eff = min(self.sigma_max, sigma_max_train) + sigma_min_eff = max(self.sigma_min, self.sigma_data * 0.01) + logger.info( + f"Inference sigma schedule: " + f"sigma_max_eff={sigma_max_eff:.4f} (config={self.sigma_max}, train 3σ={sigma_max_train:.4f}), " + f"sigma_min_eff={sigma_min_eff:.4f} (config={self.sigma_min}), " + f"sigma_data={self.sigma_data}, rho={self.rho}, num_steps={num_steps}" + ) - # Time step discretization. + # --- Time step discretization (EDM Eq. 5) with training-aligned bounds --- step_indices = torch.arange(num_steps, dtype=torch.float64, device="cuda") t_steps = ( - self.sigma_max ** (1 / self.rho) + sigma_max_eff ** (1 / self.rho) + step_indices / (num_steps - 1) - * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)) + * (sigma_min_eff ** (1 / self.rho) - sigma_max_eff ** (1 / self.rho)) ) ** self.rho t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 + # --- Per-step tracking for diagnostics --- + track = { + "sigma": [], "x_std": [], "denoised_std": [], + "l2_to_target": [], "cosine_to_target": [], + "c_skip": [], "x": [x.cpu()] + } + # Main sampling loop. x_next = x * t_steps[0] for i, (t_cur, t_next) in enumerate( @@ -187,8 +205,6 @@ def inference( t_cur = torch.tensor([t_cur], device="cuda").float() t_next = torch.tensor([t_next], device="cuda").float() - print(i, t_cur.item()) - x_cur = x_next # Increase noise temporarily. (Stochastic sampling; not used for now) @@ -199,7 +215,7 @@ def inference( t_hat = t_cur # Euler step. - denoised = self.denoise(x=x_hat, c=None, sigma=t_hat, fstep=fstep, coords=coords) # c to be discussed + denoised = self.denoise(x=x_hat, c=None, sigma=t_hat, fstep=fstep, coords=coords) d_cur = (x_hat - denoised) / t_hat x_next = x_hat + (t_next - t_hat) * d_cur @@ -209,8 +225,169 @@ def inference( d_prime = (x_next - denoised) / t_next x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + # --- Record diagnostics --- + with torch.no_grad(): + s = t_cur.item() + track["sigma"].append(s) + track["c_skip"].append(self.sigma_data**2 / (s**2 + self.sigma_data**2)) + track["x_std"].append(x_next.std().item()) + track["denoised_std"].append(denoised.std().item()) + track["x"].append(x_next.cpu()) + if self.cur_token is not None: + flat_d = denoised.reshape(-1).float() + flat_t = self.cur_token.reshape(-1).float() + track["l2_to_target"].append((flat_d - flat_t).norm().item()) + track["cosine_to_target"].append( + torch.nn.functional.cosine_similarity(flat_d.unsqueeze(0), flat_t.unsqueeze(0)).item() + ) + track["x"].append(self.cur_token.cpu()) + + self._plot_sampling_diagnostics(track, num_steps) return x_next + def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: + """Save a diagnostic plot of the sampling trajectory.""" + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + steps = list(range(len(track["sigma"]))) + has_target = len(track["l2_to_target"]) > 0 + n_plots = 5 if has_target else 3 + + fig, axes = plt.subplots(n_plots, 1, figsize=(10, 3 * n_plots), sharex=True) + + # 1) Sigma schedule + axes[0].semilogy(steps, track["sigma"], "o-", markersize=3) + axes[0].set_ylabel("sigma (noise level)") + axes[0].set_title( + f"Sampling diagnostics | sigma_max_eff={track['sigma'][0]:.2f}, " + f"sigma_data={self.sigma_data}, steps={num_steps}" + ) + axes[0].axhline(self.sigma_data, color="grey", ls="--", lw=0.8, label=f"sigma_data={self.sigma_data}") + axes[0].legend(fontsize=8) + axes[0].grid(True, alpha=0.3) + + # 2) c_skip — skip connection weight (EDM preconditioning) + axes[1].plot(steps, track["c_skip"], "o-", markersize=3, color="tab:orange") + axes[1].set_ylabel("c_skip") + axes[1].set_title("c_skip = σ_data² / (σ² + σ_data²) — 1.0 means output ≈ input (no correction)") + axes[1].axhline(0.5, color="grey", ls="--", lw=0.8, label="c_skip=0.5 (σ=σ_data)") + axes[1].set_ylim(-0.05, 1.05) + axes[1].legend(fontsize=8) + axes[1].grid(True, alpha=0.3) + + # 3) Std of x_next and denoised estimate + axes[2].plot(steps, track["x_std"], "o-", markersize=3, label="x (noisy state)") + axes[2].plot(steps, track["denoised_std"], "s-", markersize=3, label="denoised estimate") + if self.cur_token is not None: + target_std = self.cur_token.std().item() + axes[2].axhline(target_std, color="grey", ls="--", lw=0.8, label=f"target std={target_std:.3f}") + axes[2].set_ylabel("std") + axes[2].legend(fontsize=8) + axes[2].grid(True, alpha=0.3) + + if has_target: + # 4) L2 error to target + axes[3].plot(steps, track["l2_to_target"], "o-", markersize=3, color="tab:red") + axes[3].set_ylabel("L2 error to target") + axes[3].grid(True, alpha=0.3) + + # 5) Cosine similarity to target + axes[4].plot(steps, track["cosine_to_target"], "o-", markersize=3, color="tab:green") + axes[4].set_ylabel("cosine sim to target") + axes[4].set_ylim(-1.05, 1.05) + axes[4].axhline(1.0, color="grey", ls="--", lw=0.8) + axes[4].grid(True, alpha=0.3) + + axes[-1].set_xlabel("sampling step") + fig.tight_layout() + + out_dir = get_path_run(self.cf) + out_dir.mkdir(exist_ok=True, parents=True) + out_path_base = out_dir / "plots" / "validation" / "plots" + out_path_base.mkdir(exist_ok=True, parents=True) + fig.savefig(out_path_base / "sampling_diagnostics.png", dpi=150) + plt.close(fig) + logger.info(f"Saved sampling diagnostics to {out_path_base / 'sampling_diagnostics.png'}") + + vmin, vmax = track["x"][-1].min().item(), track["x"][-1].max().item() + for s_idx, x in enumerate(track["x"]): + fig, axes2 = plt.subplots(1, 2, figsize=(12, 5)) + + im0 = axes2[0].imshow(x[0].t().cpu(), aspect="auto", vmin=vmin, vmax=vmax, cmap="seismic") + plt.colorbar(im0, ax=axes2[0]) + axes2[0].set_title(f"Sample at step {s_idx}") + axes2[0].set_xlabel("embedding dim") + axes2[0].set_ylabel("healpix cell") + + diff = (x[0].cpu() - track["x"][-1][0].cpu()).t() + im1 = axes2[1].imshow(diff, aspect="auto", cmap="bwr") + plt.colorbar(im1, ax=axes2[1]) + axes2[1].set_title("Difference to target") + axes2[1].set_xlabel("embedding dim") + + fig.tight_layout() + plt.savefig(out_path_base / f"sample_{s_idx:05d}.png", dpi=100) + plt.close(fig) + logger.info(f"Saved sample visualization to {out_path_base / f'sample_{s_idx:05d}.png'}") + + + # # --- OLD inference (before training-aligned sigma & diagnostics) --- + # def inference( + # self, + # fstep: int, + # num_steps: int = 30, + # coords: torch.Tensor = None, + # ) -> torch.Tensor: + # # Forward pass of the diffusion model during inference + # # https://github.com/NVlabs/edm/blob/main/generate.py + # + # # Sample noise (assuming single batch element for now) + # torch.manual_seed(42) + # x = torch.randn(1, self.num_healpix_cells, self.cf.ae_global_dim_embed).to(device="cuda") * 1.0 + # + # x = self.cur_token * 0.0 + x + # + # # Time step discretization. + # step_indices = torch.arange(num_steps, dtype=torch.float64, device="cuda") + # t_steps = ( + # self.sigma_max ** (1 / self.rho) + # + step_indices + # / (num_steps - 1) + # * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)) + # ) ** self.rho + # t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 + # + # # Main sampling loop. + # x_next = x * t_steps[0] + # for i, (t_cur, t_next) in enumerate( + # zip(t_steps[:-1], t_steps[1:], strict=False) + # ): # 0, ..., N-1 + # t_cur = torch.tensor([t_cur], device="cuda").float() + # t_next = torch.tensor([t_next], device="cuda").float() + # + # print(i, t_cur.item()) + # + # x_cur = x_next + # + # x_hat = x_cur + # t_hat = t_cur + # + # # Euler step. + # denoised = self.denoise(x=x_hat, c=None, sigma=t_hat, fstep=fstep, coords=coords) + # d_cur = (x_hat - denoised) / t_hat + # x_next = x_hat + (t_next - t_hat) * d_cur + # + # # Apply 2nd order correction. + # if i < num_steps - 1: + # denoised = self.denoise(x=x_next, c=None, sigma=t_next, fstep=fstep, coords=coords) + # d_prime = (x_next - denoised) / t_next + # x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + # + # return x_next + # # --- END OLD inference --- + class Preconditioner: # Preconditioner, e.g., to concatenate previous frames to the input From 4eaf333d4eb54debafdba0d53f9d53c0b2e183c1 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Wed, 8 Apr 2026 18:21:58 +0200 Subject: [PATCH 272/344] Refined inference diagnostics --- config/config_diffusion.yml | 9 ++- .../era5_1deg_forecasting_z500/era5.yml | 4 +- src/weathergen/model/diffusion.py | 79 ++++++++++--------- 3 files changed, 48 insertions(+), 44 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 9fd4b4f4d..91981c491 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -99,10 +99,13 @@ latent_noise_deterministic_latents: True freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" -load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 +# load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 # load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} # z500 d2048 hl5, sigma_data=159.08 # load_chkpt: {'run_id': 'tvkicam9', 'epoch': -1} # z500 d2048 hl3 enc-lnorm, sigma_data=1.0 # load_chkpt: {'run_id': 'q9grso75', 'epoch': -1} # z500 d2048 hl3, sigma_data=39.2936 +# load_chkpt: {'run_id': 'qxivdyqz', 'epoch': -1} # z500 d2048 hl5 enc-lnorm, sigma_data=1.0 +load_chkpt: {'run_id': 'h8x1qgz3', 'epoch': -1} # z500 d128 hl5, sigma_data=12.93 +# load_chkpt: {'run_id': '', 'epoch': -1} # z500 d128 hl5 enc-lnorm, sigma_data=1.0 # load_chkpt: {'run_id': 'wvpb76ai', 'epoch': -1} # multi-var d2048 hl3 enc-lnorm, sigma_data=1.0 # load_chkpt: {'run_id': 'ae4wlc5m', 'epoch': -1} # multi-var d2048 hl3, sigma_data= @@ -111,8 +114,8 @@ norm_type: "LayerNorm" ##################################### -streams_directory: "./config/streams/era5_1deg_forecasting/" -# streams_directory: "./config/streams/era5_1deg_forecasting_z500/" +# streams_directory: "./config/streams/era5_1deg_forecasting/" +streams_directory: "./config/streams/era5_1deg_forecasting_z500/" streams: ??? # type of zarr_store diff --git a/config/streams/era5_1deg_forecasting_z500/era5.yml b/config/streams/era5_1deg_forecasting_z500/era5.yml index 5240bf8d1..f1659fb21 100644 --- a/config/streams/era5_1deg_forecasting_z500/era5.yml +++ b/config/streams/era5_1deg_forecasting_z500/era5.yml @@ -24,11 +24,11 @@ ERA5 : net : transformer num_tokens : 1 num_heads : 4 - dim_embed : 256 + dim_embed : 32 num_blocks : 2 embed_target_coords : net : linear - dim_embed : 256 + dim_embed : 32 target_readout : num_layers : 2 num_heads : 4 diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index e69759ab4..8d9cc4514 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -99,7 +99,7 @@ def forward( ) self.cur_token = tokens.detach() - return self.inference(fstep=fstep, num_steps=15, coords=coords) + # return self.inference(fstep=fstep, num_steps=10, coords=coords) c = 1 # TODO: add correct preconditioning (e.g., sample/s in previous time step) y = tokens @@ -155,7 +155,7 @@ def inference( # Sample pure noise (assuming single batch element for now) torch.manual_seed(42) x = torch.randn(1, self.num_healpix_cells, self.cf.ae_global_dim_embed).to(device="cuda") - x = self.cur_token * 1.0 + x * 0.1 + # x = self.cur_token * 1.0 + x * 0.1 # --- Training-aligned sigma bounds --- # Training noise: sigma = exp(eta * p_std + p_mean), eta ~ N(0,1). @@ -234,12 +234,13 @@ def inference( track["denoised_std"].append(denoised.std().item()) track["x"].append(x_next.cpu()) if self.cur_token is not None: - flat_d = denoised.reshape(-1).float() - flat_t = self.cur_token.reshape(-1).float() - track["l2_to_target"].append((flat_d - flat_t).norm().item()) - track["cosine_to_target"].append( - torch.nn.functional.cosine_similarity(flat_d.unsqueeze(0), flat_t.unsqueeze(0)).item() - ) + # flat_d = denoised.reshape(-1).float() + # flat_t = self.cur_token.reshape(-1).float() + # track["l2_to_target"].append((flat_d - flat_t).norm().item()) + # track["cosine_to_target"].append( + # torch.nn.functional.cosine_similarity(flat_d.unsqueeze(0), flat_t.unsqueeze(0)).item() + # ) + track["l2_to_target"].append((x_next - self.cur_token).norm().item()) track["x"].append(self.cur_token.cpu()) self._plot_sampling_diagnostics(track, num_steps) @@ -250,10 +251,11 @@ def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt + import matplotlib.colors as mcolors steps = list(range(len(track["sigma"]))) has_target = len(track["l2_to_target"]) > 0 - n_plots = 5 if has_target else 3 + n_plots = 3 fig, axes = plt.subplots(n_plots, 1, figsize=(10, 3 * n_plots), sharex=True) @@ -268,37 +270,28 @@ def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: axes[0].legend(fontsize=8) axes[0].grid(True, alpha=0.3) - # 2) c_skip — skip connection weight (EDM preconditioning) - axes[1].plot(steps, track["c_skip"], "o-", markersize=3, color="tab:orange") - axes[1].set_ylabel("c_skip") - axes[1].set_title("c_skip = σ_data² / (σ² + σ_data²) — 1.0 means output ≈ input (no correction)") - axes[1].axhline(0.5, color="grey", ls="--", lw=0.8, label="c_skip=0.5 (σ=σ_data)") - axes[1].set_ylim(-0.05, 1.05) - axes[1].legend(fontsize=8) - axes[1].grid(True, alpha=0.3) - # 3) Std of x_next and denoised estimate - axes[2].plot(steps, track["x_std"], "o-", markersize=3, label="x (noisy state)") - axes[2].plot(steps, track["denoised_std"], "s-", markersize=3, label="denoised estimate") + axes[1].plot(steps, track["x_std"], "o-", markersize=3, label="x (noisy state)") + axes[1].plot(steps, track["denoised_std"], "s-", markersize=3, label="denoised estimate") if self.cur_token is not None: target_std = self.cur_token.std().item() - axes[2].axhline(target_std, color="grey", ls="--", lw=0.8, label=f"target std={target_std:.3f}") - axes[2].set_ylabel("std") - axes[2].legend(fontsize=8) - axes[2].grid(True, alpha=0.3) + axes[1].axhline(target_std, color="grey", ls="--", lw=0.8, label=f"target std={target_std:.3f}") + axes[1].set_ylabel("std") + axes[1].legend(fontsize=8) + axes[1].grid(True, alpha=0.3) if has_target: # 4) L2 error to target - axes[3].plot(steps, track["l2_to_target"], "o-", markersize=3, color="tab:red") - axes[3].set_ylabel("L2 error to target") - axes[3].grid(True, alpha=0.3) + axes[2].plot(steps, track["l2_to_target"], "o-", markersize=3, color="tab:red") + axes[2].set_ylabel("L2 error to target") + axes[2].grid(True, alpha=0.3) - # 5) Cosine similarity to target - axes[4].plot(steps, track["cosine_to_target"], "o-", markersize=3, color="tab:green") - axes[4].set_ylabel("cosine sim to target") - axes[4].set_ylim(-1.05, 1.05) - axes[4].axhline(1.0, color="grey", ls="--", lw=0.8) - axes[4].grid(True, alpha=0.3) + # # 5) Cosine similarity to target + # axes[4].plot(steps, track["cosine_to_target"], "o-", markersize=3, color="tab:green") + # axes[4].set_ylabel("cosine sim to target") + # axes[4].set_ylim(-1.05, 1.05) + # axes[4].axhline(1.0, color="grey", ls="--", lw=0.8) + # axes[4].grid(True, alpha=0.3) axes[-1].set_xlabel("sampling step") fig.tight_layout() @@ -315,17 +308,25 @@ def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: for s_idx, x in enumerate(track["x"]): fig, axes2 = plt.subplots(1, 2, figsize=(12, 5)) - im0 = axes2[0].imshow(x[0].t().cpu(), aspect="auto", vmin=vmin, vmax=vmax, cmap="seismic") + abs_max = max(abs(vmin), abs(vmax)) * 0.1 + # im0 = axes2[0].imshow(x[0].t().cpu(), aspect="auto", cmap="seismic", + # norm=mcolors.SymLogNorm(linthresh=1e-2, vmin=-abs_max, vmax=abs_max)) + im0 = axes2[0].imshow(x[0].t().cpu(), aspect="auto", cmap="seismic", vmin=vmin, vmax=vmax) plt.colorbar(im0, ax=axes2[0]) - axes2[0].set_title(f"Sample at step {s_idx}") - axes2[0].set_xlabel("embedding dim") - axes2[0].set_ylabel("healpix cell") + if s_idx == len(track["x"]) - 1: + axes2[0].set_title(f"Target") + else: + axes2[0].set_title(f"Sample at step {s_idx}") + axes2[0].set_xlabel("healpix cell") + axes2[0].set_ylabel("embedding dim") diff = (x[0].cpu() - track["x"][-1][0].cpu()).t() - im1 = axes2[1].imshow(diff, aspect="auto", cmap="bwr") + # im1 = axes2[1].imshow(diff, aspect="auto", cmap="bwr", + # norm=mcolors.SymLogNorm(linthresh=1e-2, vmin=-0.2, vmax=0.2)) + im1 = axes2[1].imshow(diff, aspect="auto", cmap="bwr", vmin=-1, vmax=1) plt.colorbar(im1, ax=axes2[1]) axes2[1].set_title("Difference to target") - axes2[1].set_xlabel("embedding dim") + axes2[1].set_xlabel("healpix cell") fig.tight_layout() plt.savefig(out_path_base / f"sample_{s_idx:05d}.png", dpi=100) From f746eab4ca6568f3842b9357dfb8c69d5c52144d Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Wed, 8 Apr 2026 18:53:46 +0200 Subject: [PATCH 273/344] Minor adjustments --- src/weathergen/model/diffusion.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 8d9cc4514..0e9c85366 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -99,7 +99,7 @@ def forward( ) self.cur_token = tokens.detach() - # return self.inference(fstep=fstep, num_steps=10, coords=coords) + return self.inference(fstep=fstep, num_steps=10, coords=coords) c = 1 # TODO: add correct preconditioning (e.g., sample/s in previous time step) y = tokens @@ -155,7 +155,7 @@ def inference( # Sample pure noise (assuming single batch element for now) torch.manual_seed(42) x = torch.randn(1, self.num_healpix_cells, self.cf.ae_global_dim_embed).to(device="cuda") - # x = self.cur_token * 1.0 + x * 0.1 + # x = self.cur_token * 0.025 + x # --- Training-aligned sigma bounds --- # Training noise: sigma = exp(eta * p_std + p_mean), eta ~ N(0,1). @@ -244,6 +244,7 @@ def inference( track["x"].append(self.cur_token.cpu()) self._plot_sampling_diagnostics(track, num_steps) + # self._plot_sampling_process(track, num_steps) return x_next def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: @@ -304,6 +305,7 @@ def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: plt.close(fig) logger.info(f"Saved sampling diagnostics to {out_path_base / 'sampling_diagnostics.png'}") + def _plot_sampling_process(self, track: dict, num_steps: int) -> None: vmin, vmax = track["x"][-1].min().item(), track["x"][-1].max().item() for s_idx, x in enumerate(track["x"]): fig, axes2 = plt.subplots(1, 2, figsize=(12, 5)) From 45790cd7393a56e8734b557eb6e64ca702d5e931 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Wed, 8 Apr 2026 21:55:16 +0200 Subject: [PATCH 274/344] Minor adjustments --- src/weathergen/model/diffusion.py | 33 ++++++++++++++----------------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 0e9c85366..3af1938c3 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -99,7 +99,7 @@ def forward( ) self.cur_token = tokens.detach() - return self.inference(fstep=fstep, num_steps=10, coords=coords) + # return self.inference(fstep=fstep, num_steps=10, coords=coords) c = 1 # TODO: add correct preconditioning (e.g., sample/s in previous time step) y = tokens @@ -155,7 +155,7 @@ def inference( # Sample pure noise (assuming single batch element for now) torch.manual_seed(42) x = torch.randn(1, self.num_healpix_cells, self.cf.ae_global_dim_embed).to(device="cuda") - # x = self.cur_token * 0.025 + x + # x = self.cur_token * 0.05 + x # --- Training-aligned sigma bounds --- # Training noise: sigma = exp(eta * p_std + p_mean), eta ~ N(0,1). @@ -234,17 +234,11 @@ def inference( track["denoised_std"].append(denoised.std().item()) track["x"].append(x_next.cpu()) if self.cur_token is not None: - # flat_d = denoised.reshape(-1).float() - # flat_t = self.cur_token.reshape(-1).float() - # track["l2_to_target"].append((flat_d - flat_t).norm().item()) - # track["cosine_to_target"].append( - # torch.nn.functional.cosine_similarity(flat_d.unsqueeze(0), flat_t.unsqueeze(0)).item() - # ) track["l2_to_target"].append((x_next - self.cur_token).norm().item()) track["x"].append(self.cur_token.cpu()) self._plot_sampling_diagnostics(track, num_steps) - # self._plot_sampling_process(track, num_steps) + self._plot_sampling_process(track, num_steps) return x_next def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: @@ -271,7 +265,7 @@ def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: axes[0].legend(fontsize=8) axes[0].grid(True, alpha=0.3) - # 3) Std of x_next and denoised estimate + # 2) Std of x_next and denoised estimate axes[1].plot(steps, track["x_std"], "o-", markersize=3, label="x (noisy state)") axes[1].plot(steps, track["denoised_std"], "s-", markersize=3, label="denoised estimate") if self.cur_token is not None: @@ -282,18 +276,11 @@ def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: axes[1].grid(True, alpha=0.3) if has_target: - # 4) L2 error to target + # 3) L2 error to target axes[2].plot(steps, track["l2_to_target"], "o-", markersize=3, color="tab:red") axes[2].set_ylabel("L2 error to target") axes[2].grid(True, alpha=0.3) - # # 5) Cosine similarity to target - # axes[4].plot(steps, track["cosine_to_target"], "o-", markersize=3, color="tab:green") - # axes[4].set_ylabel("cosine sim to target") - # axes[4].set_ylim(-1.05, 1.05) - # axes[4].axhline(1.0, color="grey", ls="--", lw=0.8) - # axes[4].grid(True, alpha=0.3) - axes[-1].set_xlabel("sampling step") fig.tight_layout() @@ -306,6 +293,16 @@ def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: logger.info(f"Saved sampling diagnostics to {out_path_base / 'sampling_diagnostics.png'}") def _plot_sampling_process(self, track: dict, num_steps: int) -> None: + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + import matplotlib.colors as mcolors + + out_dir = get_path_run(self.cf) + out_dir.mkdir(exist_ok=True, parents=True) + out_path_base = out_dir / "plots" / "validation" / "plots" + out_path_base.mkdir(exist_ok=True, parents=True) + vmin, vmax = track["x"][-1].min().item(), track["x"][-1].max().item() for s_idx, x in enumerate(track["x"]): fig, axes2 = plt.subplots(1, 2, figsize=(12, 5)) From bd42849056868f5f74ed1a07b458c18b4ff9321a Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Thu, 9 Apr 2026 12:49:46 +0200 Subject: [PATCH 275/344] diffusion adjustment --- src/weathergen/model/diffusion.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 85352354c..2f5359578 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -111,19 +111,21 @@ def forward( ) else: # NOTE: temporary for analysing denoising - return self.training_forward( - tokens=tokens, - fstep=fstep, - meta_info=meta_info, - coords=coords, - ) - # if fstep is None: - # raise ValueError(f"During inference, fstep is required. Got fstep={fstep}") - # return self.inference_forward( + # return self.training_forward( + # tokens=tokens, # fstep=fstep, - # num_steps=num_steps, # meta_info=meta_info, + # coords=coords, # ) + if fstep is None: + raise ValueError(f"During inference, fstep is required. Got fstep={fstep}") + + return self.inference_forward( + tokens=tokens, # TODO: remove after single sample experiments + fstep=fstep, + num_steps=num_steps, + meta_info=meta_info, + ) def training_forward( self, @@ -192,6 +194,7 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int) -> def inference_forward( self, + tokens, fstep: int, num_steps: int = 30, meta_info: dict[str, SampleMetaData] = None, @@ -211,6 +214,7 @@ def inference_forward( Returns: torch.Tensor: Generated sample of shape (1, num_healpix_cells, ae_global_dim_embed) """ + # Extract conditioning from meta_info (same as training_forward) c = None if meta_info is not None: @@ -219,6 +223,8 @@ def inference_forward( # Sample noise (assuming single batch element for now) x = torch.randn(1, self.num_healpix_cells, self.cf.ae_global_dim_embed).to(device="cuda") + x = tokens * 0.66 + x * 0.33 #NOTE: for debugging only! + # Time step discretization. step_indices = torch.arange(num_steps, dtype=torch.float64, device="cuda") t_steps = ( @@ -240,6 +246,7 @@ def inference_forward( zip(t_steps[:-1], t_steps[1:], strict=False) ): # 0, ..., N-1 x_cur = x_next + print(f"Step {i+1}/{num_steps}: t_cur={t_cur.item():.4f}, t_next={t_next.item():.4f}") # Increase noise temporarily. (Stochastic sampling; not used for now) # gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 From ba913de38a1a70011a94844d602777e489761074 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Fri, 10 Apr 2026 14:56:44 +0200 Subject: [PATCH 276/344] Config edits for 512 dim model --- config/config_diffusion.yml | 15 ++++++++++----- config/config_diffusion_tiny.yml | 7 +++++-- config/streams/era5_1deg_diffusion_tiny/era5.yml | 8 ++++---- src/weathergen/model/diffusion.py | 6 +++--- src/weathergen/model/model.py | 2 ++ 5 files changed, 24 insertions(+), 14 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 91981c491..ddda013fe 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -99,23 +99,28 @@ latent_noise_deterministic_latents: True freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" -# load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 +load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 # load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} # z500 d2048 hl5, sigma_data=159.08 # load_chkpt: {'run_id': 'tvkicam9', 'epoch': -1} # z500 d2048 hl3 enc-lnorm, sigma_data=1.0 # load_chkpt: {'run_id': 'q9grso75', 'epoch': -1} # z500 d2048 hl3, sigma_data=39.2936 # load_chkpt: {'run_id': 'qxivdyqz', 'epoch': -1} # z500 d2048 hl5 enc-lnorm, sigma_data=1.0 -load_chkpt: {'run_id': 'h8x1qgz3', 'epoch': -1} # z500 d128 hl5, sigma_data=12.93 +# load_chkpt: {'run_id': 'h8x1qgz3', 'epoch': -1} # z500 d128 hl5, sigma_data=12.93 # load_chkpt: {'run_id': '', 'epoch': -1} # z500 d128 hl5 enc-lnorm, sigma_data=1.0 # load_chkpt: {'run_id': 'wvpb76ai', 'epoch': -1} # multi-var d2048 hl3 enc-lnorm, sigma_data=1.0 -# load_chkpt: {'run_id': 'ae4wlc5m', 'epoch': -1} # multi-var d2048 hl3, sigma_data= +# load_chkpt: {'run_id': 'ae4wlc5m', 'epoch': -1} # multi-var d2048 hl3, sigma_data=2.7047 +# load_chkpt: {'run_id': 'r45iwyns', 'epoch': -1} # multi-var d512 hl3, sigma_data=1.1785 +# load_chkpt: {'run_id': 'ydka6uql', 'epoch': -1} # multi-var d512 hl4, sigma_data=0.827 +# load_chkpt: {'run_id': 'qf9yoimd', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 +# load_chkpt: {'run_id': 'v8kd6xc1', 'epoch': -1} # multi-var d512 hl5 nopos, sigma_data=0.6481 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d2512 hl5 enc-lnorm, sigma_data=1.0 norm_type: "LayerNorm" ##################################### -# streams_directory: "./config/streams/era5_1deg_forecasting/" -streams_directory: "./config/streams/era5_1deg_forecasting_z500/" +streams_directory: "./config/streams/era5_1deg_forecasting/" +# streams_directory: "./config/streams/era5_1deg_forecasting_z500/" streams: ??? # type of zarr_store diff --git a/config/config_diffusion_tiny.yml b/config/config_diffusion_tiny.yml index 065e76d2b..ba306d483 100644 --- a/config/config_diffusion_tiny.yml +++ b/config/config_diffusion_tiny.yml @@ -128,7 +128,7 @@ healpix_level: 5 # Use 2D RoPE instead of traditional global positional encoding # When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) # When False: uses traditional pe_global positional encoding -rope_2D: True +rope_2D: False with_mixed_precision: True with_flash_attention: True @@ -148,7 +148,10 @@ latent_noise_deterministic_latents: True freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" # load_chkpt: {'run_id': 'lgrasnq6', 'epoch': -1} # z500 small -load_chkpt: {'run_id': 'hhz27wy0', 'epoch': -1} # multi-var small, sigma_data=0.7855 +# load_chkpt: {'run_id': 'hhz27wy0', 'epoch': -1} # multi-var small, sigma_data=0.7855 +# load_chkpt: {'run_id': 'xpwjhaf4', 'epoch': -1} # z500 d128 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'ml89b5r6', 'epoch': -1} # multi-var d128 hl5, sigma_data=0.2415 +load_chkpt: {'run_id': 'a3n1pdkl', 'epoch': -1} # multi-var d128 hl5, nopos, sigma_data=0.2507 norm_type: "LayerNorm" diff --git a/config/streams/era5_1deg_diffusion_tiny/era5.yml b/config/streams/era5_1deg_diffusion_tiny/era5.yml index efa163d42..40da38c09 100644 --- a/config/streams/era5_1deg_diffusion_tiny/era5.yml +++ b/config/streams/era5_1deg_diffusion_tiny/era5.yml @@ -11,10 +11,10 @@ ERA5 : type : anemoi filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] stream_id : 0 - source : ["z_500"] - target : ["z_500"] - # source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] - # target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] + # source : ["z_500"] + # target : ["z_500"] + source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] + target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] loss_weight : 1. location_weight : cosine_latitude masking_rate : 0.6 diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 3af1938c3..ed6d3cd61 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -99,7 +99,7 @@ def forward( ) self.cur_token = tokens.detach() - # return self.inference(fstep=fstep, num_steps=10, coords=coords) + return self.inference(fstep=fstep, num_steps=4, coords=coords) c = 1 # TODO: add correct preconditioning (e.g., sample/s in previous time step) y = tokens @@ -155,7 +155,7 @@ def inference( # Sample pure noise (assuming single batch element for now) torch.manual_seed(42) x = torch.randn(1, self.num_healpix_cells, self.cf.ae_global_dim_embed).to(device="cuda") - # x = self.cur_token * 0.05 + x + # x = self.cur_token * 0.02 + x # --- Training-aligned sigma bounds --- # Training noise: sigma = exp(eta * p_std + p_mean), eta ~ N(0,1). @@ -238,7 +238,7 @@ def inference( track["x"].append(self.cur_token.cpu()) self._plot_sampling_diagnostics(track, num_steps) - self._plot_sampling_process(track, num_steps) + # self._plot_sampling_process(track, num_steps) return x_next def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 754289d16..0269c69f2 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -737,6 +737,8 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # collapse along input step dimension tokens = tokens.reshape(shape).sum(axis=1) + # breakpoint() + # Normalize tokens # TODO: REMOVE THIS LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. # t_mean = tokens.mean() From f02471557202b856614048a783668fb55d60224d Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Mon, 13 Apr 2026 12:33:55 +0200 Subject: [PATCH 277/344] Clean up for reproduction and hand-over --- config/config_diffusion.yml | 26 ++-- config/streams/era5_1deg_forecasting/era5.yml | 4 +- src/weathergen/model/diffusion.py | 115 +++--------------- src/weathergen/model/engines.py | 25 ---- src/weathergen/model/model.py | 2 - 5 files changed, 29 insertions(+), 143 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index ddda013fe..ed6017969 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -11,7 +11,7 @@ embed_orientation: "channels" embed_unembed_mode: "block" embed_dropout_rate: 0.1 -ae_local_dim_embed: 2048 +ae_local_dim_embed: 512 ae_local_num_blocks: 0 ae_local_num_heads: 16 ae_local_dropout_rate: 0.1 @@ -25,7 +25,7 @@ ae_adapter_with_qk_lnorm: True ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 -ae_global_dim_embed: 2048 +ae_global_dim_embed: 512 ae_global_num_blocks: 4 ae_global_num_heads: 32 ae_global_dropout_rate: 0.1 @@ -69,8 +69,8 @@ with_step_conditioning: True # False frequency_embedding_dim: 256 embedding_dim: 512 sigma_min: 0.002 -sigma_max: 80 # 170 -sigma_data: 1.7 # 0.5 # 1.7, 157.38 +sigma_max: 80 +sigma_data: 0.5789 rho: 7 p_mean: -1.2 p_std: 1.2 @@ -99,7 +99,7 @@ latent_noise_deterministic_latents: True freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" -load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 +# load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 # load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} # z500 d2048 hl5, sigma_data=159.08 # load_chkpt: {'run_id': 'tvkicam9', 'epoch': -1} # z500 d2048 hl3 enc-lnorm, sigma_data=1.0 # load_chkpt: {'run_id': 'q9grso75', 'epoch': -1} # z500 d2048 hl3, sigma_data=39.2936 @@ -110,9 +110,11 @@ load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_da # load_chkpt: {'run_id': 'ae4wlc5m', 'epoch': -1} # multi-var d2048 hl3, sigma_data=2.7047 # load_chkpt: {'run_id': 'r45iwyns', 'epoch': -1} # multi-var d512 hl3, sigma_data=1.1785 # load_chkpt: {'run_id': 'ydka6uql', 'epoch': -1} # multi-var d512 hl4, sigma_data=0.827 -# load_chkpt: {'run_id': 'qf9yoimd', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 +load_chkpt: {'run_id': 'qf9yoimd', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 # load_chkpt: {'run_id': 'v8kd6xc1', 'epoch': -1} # multi-var d512 hl5 nopos, sigma_data=0.6481 -# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d2512 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 +# load_chkpt: {'run_id': 'mivw6jda', 'epoch': -1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 norm_type: "LayerNorm" @@ -172,8 +174,8 @@ training_config: # training_mode: "masking", "student_teacher", "latent_loss" training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] - num_mini_epochs: 512 - samples_per_mini_epoch: 1024 + num_mini_epochs: 128 + samples_per_mini_epoch: 4096 shuffle: True start_date: 2012-06-01T00:00 @@ -183,8 +185,8 @@ training_config: time_window_len: 06:00:00 learning_rate_scheduling : - lr_start: 1e-5 #5e-5 - lr_max: 5e-5 #1e-4 + lr_start: 1e-6 #5e-5 + lr_max: 1e-5 #1e-4 lr_final_decay: 1e-6 lr_final: 0.0 num_steps_warmup: 64 @@ -243,8 +245,6 @@ validation_config: # Noise levels (eta values in standard normal space) at which to evaluate the # diffusion model during validation. sigma = exp(eta * p_std + p_mean). # Each value produces a separate validation pass with independently logged metrics. - # validation_noise_levels: [0.3, 0.5, 0.75, 1.0, 1.5] - # validation_noise_levels: [2.0, 3.0, 3.2, 3.5, 4.0, 5.0] validation_noise_levels: [1.0, 2.0, 3.0, 4.0] samples_per_mini_epoch: 16 diff --git a/config/streams/era5_1deg_forecasting/era5.yml b/config/streams/era5_1deg_forecasting/era5.yml index 09f88b654..569922fb9 100644 --- a/config/streams/era5_1deg_forecasting/era5.yml +++ b/config/streams/era5_1deg_forecasting/era5.yml @@ -25,11 +25,11 @@ ERA5 : net : transformer num_tokens : 1 num_heads : 4 - dim_embed : 512 + dim_embed : 256 num_blocks : 2 embed_target_coords : net : linear - dim_embed : 512 + dim_embed : 256 target_readout : num_layers : 2 num_heads : 4 diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index ed6d3cd61..54fde2008 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -99,7 +99,7 @@ def forward( ) self.cur_token = tokens.detach() - return self.inference(fstep=fstep, num_steps=4, coords=coords) + # return self.inference(fstep=fstep, num_steps=10, coords=coords) c = 1 # TODO: add correct preconditioning (e.g., sample/s in previous time step) y = tokens @@ -138,7 +138,6 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int, co x = self.preconditioner.precondition(x, c) # Direct prediction: network outputs denoised estimate directly - # return self.net(x, fstep=fstep, coords=coords, noise_emb=noise_emb) return c_skip * x + c_out * self.net( c_in * x, fstep=fstep, coords=coords, noise_emb=noise_emb ) # Eq. (7) in EDM paper @@ -155,8 +154,19 @@ def inference( # Sample pure noise (assuming single batch element for now) torch.manual_seed(42) x = torch.randn(1, self.num_healpix_cells, self.cf.ae_global_dim_embed).to(device="cuda") - # x = self.cur_token * 0.02 + x + ### OLD WAY OF COMPUTING SIGMA SCHEDULE + # # Time step discretization. + # step_indices = torch.arange(num_steps, dtype=torch.float64, device="cuda") + # t_steps = ( + # self.sigma_max ** (1 / self.rho) + # + step_indices + # / (num_steps - 1) + # * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)) + # ) ** self.rho + # t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 + + ### NEW WAY OF COMPUTING SIGMA SCHEDULE WITH TRAINING-ALIGNED BOUNDS AND DIAGNOSTICS # --- Training-aligned sigma bounds --- # Training noise: sigma = exp(eta * p_std + p_mean), eta ~ N(0,1). # The network only learns to denoise reliably within the training distribution. @@ -236,9 +246,8 @@ def inference( if self.cur_token is not None: track["l2_to_target"].append((x_next - self.cur_token).norm().item()) track["x"].append(self.cur_token.cpu()) - self._plot_sampling_diagnostics(track, num_steps) - # self._plot_sampling_process(track, num_steps) + return x_next def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: @@ -292,102 +301,6 @@ def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: plt.close(fig) logger.info(f"Saved sampling diagnostics to {out_path_base / 'sampling_diagnostics.png'}") - def _plot_sampling_process(self, track: dict, num_steps: int) -> None: - import matplotlib - matplotlib.use("Agg") - import matplotlib.pyplot as plt - import matplotlib.colors as mcolors - - out_dir = get_path_run(self.cf) - out_dir.mkdir(exist_ok=True, parents=True) - out_path_base = out_dir / "plots" / "validation" / "plots" - out_path_base.mkdir(exist_ok=True, parents=True) - - vmin, vmax = track["x"][-1].min().item(), track["x"][-1].max().item() - for s_idx, x in enumerate(track["x"]): - fig, axes2 = plt.subplots(1, 2, figsize=(12, 5)) - - abs_max = max(abs(vmin), abs(vmax)) * 0.1 - # im0 = axes2[0].imshow(x[0].t().cpu(), aspect="auto", cmap="seismic", - # norm=mcolors.SymLogNorm(linthresh=1e-2, vmin=-abs_max, vmax=abs_max)) - im0 = axes2[0].imshow(x[0].t().cpu(), aspect="auto", cmap="seismic", vmin=vmin, vmax=vmax) - plt.colorbar(im0, ax=axes2[0]) - if s_idx == len(track["x"]) - 1: - axes2[0].set_title(f"Target") - else: - axes2[0].set_title(f"Sample at step {s_idx}") - axes2[0].set_xlabel("healpix cell") - axes2[0].set_ylabel("embedding dim") - - diff = (x[0].cpu() - track["x"][-1][0].cpu()).t() - # im1 = axes2[1].imshow(diff, aspect="auto", cmap="bwr", - # norm=mcolors.SymLogNorm(linthresh=1e-2, vmin=-0.2, vmax=0.2)) - im1 = axes2[1].imshow(diff, aspect="auto", cmap="bwr", vmin=-1, vmax=1) - plt.colorbar(im1, ax=axes2[1]) - axes2[1].set_title("Difference to target") - axes2[1].set_xlabel("healpix cell") - - fig.tight_layout() - plt.savefig(out_path_base / f"sample_{s_idx:05d}.png", dpi=100) - plt.close(fig) - logger.info(f"Saved sample visualization to {out_path_base / f'sample_{s_idx:05d}.png'}") - - - # # --- OLD inference (before training-aligned sigma & diagnostics) --- - # def inference( - # self, - # fstep: int, - # num_steps: int = 30, - # coords: torch.Tensor = None, - # ) -> torch.Tensor: - # # Forward pass of the diffusion model during inference - # # https://github.com/NVlabs/edm/blob/main/generate.py - # - # # Sample noise (assuming single batch element for now) - # torch.manual_seed(42) - # x = torch.randn(1, self.num_healpix_cells, self.cf.ae_global_dim_embed).to(device="cuda") * 1.0 - # - # x = self.cur_token * 0.0 + x - # - # # Time step discretization. - # step_indices = torch.arange(num_steps, dtype=torch.float64, device="cuda") - # t_steps = ( - # self.sigma_max ** (1 / self.rho) - # + step_indices - # / (num_steps - 1) - # * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)) - # ) ** self.rho - # t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 - # - # # Main sampling loop. - # x_next = x * t_steps[0] - # for i, (t_cur, t_next) in enumerate( - # zip(t_steps[:-1], t_steps[1:], strict=False) - # ): # 0, ..., N-1 - # t_cur = torch.tensor([t_cur], device="cuda").float() - # t_next = torch.tensor([t_next], device="cuda").float() - # - # print(i, t_cur.item()) - # - # x_cur = x_next - # - # x_hat = x_cur - # t_hat = t_cur - # - # # Euler step. - # denoised = self.denoise(x=x_hat, c=None, sigma=t_hat, fstep=fstep, coords=coords) - # d_cur = (x_hat - denoised) / t_hat - # x_next = x_hat + (t_next - t_hat) * d_cur - # - # # Apply 2nd order correction. - # if i < num_steps - 1: - # denoised = self.denoise(x=x_next, c=None, sigma=t_next, fstep=fstep, coords=coords) - # d_prime = (x_next - denoised) / t_next - # x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) - # - # return x_next - # # --- END OLD inference --- - class Preconditioner: # Preconditioner, e.g., to concatenate previous frames to the input diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 78694477f..7c42f9a27 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -405,8 +405,6 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = self.num_healpix_cells = num_healpix_cells self.fe_blocks = torch.nn.ModuleList() - # self.position_layer = torch.nn.Linear(2, self.cf.ae_global_dim_embed) - global_rate = int(1 / self.cf.forecast_att_dense_rate) if mode_cfg.get("forecast", {}).get("policy") is not None: for i in range(self.cf.fe_num_blocks): @@ -465,30 +463,11 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) ) - # self.fe_blocks.append( - # MLP( - # self.cf.ae_global_dim_embed, - # self.cf.ae_global_dim_embed, - # num_layers=12, - # with_residual=True, - # pre_layer_norm=True, # TODO: REMOVE AGAIN - # dropout_rate=self.cf.fe_dropout_rate, - # norm_type=self.cf.norm_type, - # dim_aux=dim_aux, - # norm_eps=self.cf.mlp_norm_eps, - # with_noise_conditioning=True, # TODO: SWITCH BACK TO TRUE - # ) - # ) def init_weights_final(m): if isinstance(m, torch.nn.Linear): torch.nn.init.normal_(m.weight, mean=0, std=0.001) if m.bias is not None: torch.nn.init.normal_(m.bias, mean=0, std=0.001) - # def init_weights_final(m): - # if isinstance(m, torch.nn.Linear): - # torch.nn.init.normal_(m.weight, mean=0, std=0.1) - # if m.bias is not None: - # torch.nn.init.normal_(m.bias, mean=0, std=0.1) for block in self.fe_blocks: block.apply(init_weights_final) @@ -523,10 +502,6 @@ def forward( if isinstance(block, torch.nn.LayerNorm): tokens = checkpoint(block, tokens, use_reentrant=False) else: - # if isinstance(block, MLP): - # # tokens = torch.concat([tokens, coords], dim=-1) if coords is not None else tokens - # # TODO: REMOVE - # tokens = tokens + self.position_layer(coords) # Assuming args[1] contains positional information tokens = checkpoint(block, tokens, coords, noise_emb, aux_info, use_reentrant=False) else: for block in self.fe_blocks: diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 0269c69f2..754289d16 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -737,8 +737,6 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # collapse along input step dimension tokens = tokens.reshape(shape).sum(axis=1) - # breakpoint() - # Normalize tokens # TODO: REMOVE THIS LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. # t_mean = tokens.mean() From 9c0904359d457337ee954a4c33a034d323513059 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Mon, 13 Apr 2026 16:01:24 +0200 Subject: [PATCH 278/344] merged conditioning with Matze's branch --- config/config_diffusion.yml | 10 +++++----- src/weathergen/model/attention.py | 1 - src/weathergen/model/diffusion.py | 23 ++++++++++++++--------- src/weathergen/model/engines.py | 17 +++++++++-------- 4 files changed, 28 insertions(+), 23 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 24244fed0..a9f49f273 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -60,7 +60,7 @@ fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True fe_diffusion_model: True -fe_diffusion_model_conditioning: "date_time" # options: "date_time", "forecast_step", "none" +fe_diffusion_model_conditioning: None # options: "date_time" fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) @@ -180,8 +180,8 @@ training_config: samples_per_mini_epoch: 4096 shuffle: True - start_date: 1979-06-01T00:00 - end_date: 1979-06-01T18:00 + start_date: 2012-06-01T00:00 + end_date: 2012-06-01T18:00 time_window_step: 06:00:00 time_window_len: 06:00:00 @@ -252,8 +252,8 @@ validation_config: samples_per_mini_epoch: 1 shuffle: True # TODO: Set back to False - start_date: 1979-06-01T00:00 - end_date: 1979-06-01T18:00 + start_date: 2012-06-01T00:00 + end_date: 2012-06-01T18:00 # whether to track the exponential moving average of weights for validation validate_with_ema: diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 39b5e3b73..1f3df7242 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -591,7 +591,6 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None): x = self.lnorm(x, ada_ln_aux) if self.noise_conditioning: - assert emb is not None, "Need noise embedding if using noise conditioning" x, gate = self.noise_conditioning(x, emb) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index ea935e1e3..10013b181 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -123,10 +123,10 @@ def forward( raise ValueError(f"During inference, fstep is required. Got fstep={fstep}") return self.inference_forward( - tokens=tokens, # TODO: remove after single sample experiments fstep=fstep, num_steps=num_steps, meta_info=meta_info, + coords=coords, ) def training_forward( @@ -150,7 +150,11 @@ def training_forward( self.cur_token = tokens - c = meta_info["ERA5"].params["timestamp"] # TODO: add correct preconditioning (e.g., sample/s in previous time step, datetime encoding, etc.) + if self.cf.fe_diffusion_model_conditioning == "date_time": + c = meta_info["ERA5"].params["timestamp"] # TODO: add correct preconditioning (e.g., sample/s in previous time step, datetime encoding, etc.) + else: + c = None + y = tokens if self.training: @@ -188,7 +192,8 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int, co # Precondition input and feed through network x = self.preconditioner.precondition(x, c) #currently does nothing - c = self.datetime_embedder(c).to(x.device) + if self.cf.fe_diffusion_model_conditioning == "date_time": + c = self.datetime_embedder(c).to(x.device) return c_skip * x + c_out * self.net( c_in * x, fstep=fstep, coords=coords, noise_emb=noise_emb, ada_ln_aux=c @@ -196,11 +201,10 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int, co def inference_forward( self, - tokens, fstep: int, num_steps: int = 50, - coords: torch.Tensor = None, meta_info: dict[str, SampleMetaData] = None, + coords: torch.Tensor = None, ) -> torch.Tensor: """ Forward pass of the diffusion model during inference. @@ -213,18 +217,19 @@ def inference_forward( fstep: Forecast step index for the network num_steps: Number of diffusion denoising steps (default: 30) meta_info: Optional sample metadata dict containing timestamps for temporal conditioning - + coords: Optional coordinate tensor for spatial conditioning Returns: torch.Tensor: Generated sample of shape (1, num_healpix_cells, ae_global_dim_embed) """ # Extract conditioning from meta_info (same as training_forward) c = None - if meta_info is not None: + + if self.cf.fe_diffusion_model_conditioning == "date_time": c = meta_info["ERA5"].params["timestamp"] # Sample pure noise (assuming single batch element for now) - torch.manual_seed(42) + # torch.manual_seed(42) x = torch.randn(1, self.num_healpix_cells, self.cf.ae_global_dim_embed).to(device="cuda") ### OLD WAY OF COMPUTING SIGMA SCHEDULE @@ -323,7 +328,7 @@ def inference_forward( track["x"].append(x_next.cpu()) if self.cur_token is not None: track["l2_to_target"].append((x_next - self.cur_token).norm().item()) - track["x"].append(self.cur_token.cpu()) + track["x"].append(self.cur_token.cpu()) self._plot_sampling_diagnostics(track, num_steps) return x_next diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index de9a62fd7..7a5b98e3a 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -496,18 +496,19 @@ def forward( tokens_in = tokens if self.cf.fe_diffusion_model: - assert ada_ln_aux is not None, ( - "Conditioning (noise and other) must be provided for diffusion forecast engine" - ) for block in self.fe_blocks: + breakpoint() if isinstance(block, torch.nn.LayerNorm): tokens = checkpoint(block, tokens, use_reentrant=False) else: - # if isinstance(block, MLP): - # # tokens = torch.concat([tokens, coords], dim=-1) if coords is not None else tokens - # # TODO: REMOVE - # tokens = tokens + self.position_layer(coords) # Assuming args[1] contains positional information - tokens = checkpoint(block, tokens, coords, noise_emb, ada_ln_aux, use_reentrant=False) + if self.cf.fe_diffusion_model_conditioning in ["date_time"]: + # Assuming ada_ln_aux contains the date_time embedding in this case + assert ada_ln_aux is not None, "ada_ln_aux must be provided for diffusion model conditioning" + tokens = checkpoint(block, tokens, coords, noise_emb, ada_ln_aux, use_reentrant=False) + else: + assert ada_ln_aux is None, "ada_ln_aux should not be provided when diffusion model conditioning is disabled" + assert noise_emb is not None, "noise_emb must be provided for diffusion model conditioning" + tokens = checkpoint(block, tokens, coords, noise_emb, use_reentrant=False) else: for block in self.fe_blocks: if isinstance(block, torch.nn.LayerNorm): From 1b06241f7820cf4baeb6b70f78b7195abcc325bd Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Mon, 13 Apr 2026 17:32:13 +0200 Subject: [PATCH 279/344] update num blocks --- config/config_diffusion.yml | 4 ++-- src/weathergen/model/diffusion.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index a9f49f273..f103d4a42 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -55,7 +55,7 @@ num_register_tokens: 0 # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -fe_num_blocks: 2 +fe_num_blocks: 6 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True @@ -249,7 +249,7 @@ validation_config: # Each value produces a separate validation pass with independently logged metrics. validation_noise_levels: [1.0, 2.0, 3.0, 4.0] - samples_per_mini_epoch: 1 + samples_per_mini_epoch: 16 shuffle: True # TODO: Set back to False start_date: 2012-06-01T00:00 diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 10013b181..b7db378cc 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -154,7 +154,7 @@ def training_forward( c = meta_info["ERA5"].params["timestamp"] # TODO: add correct preconditioning (e.g., sample/s in previous time step, datetime encoding, etc.) else: c = None - + y = tokens if self.training: @@ -298,7 +298,6 @@ def inference_forward( t_next = torch.tensor([t_next], device="cuda").float() x_cur = x_next - print(f"Step {i+1}/{num_steps}: t_cur={t_cur.item():.4f}, t_next={t_next.item():.4f}") # Increase noise temporarily. (Stochastic sampling; not used for now) # gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 From 394fba322409d2c505ca3cd741824c11e8124120 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Mon, 13 Apr 2026 18:39:05 +0200 Subject: [PATCH 280/344] incorporate feedback from matze --- src/weathergen/model/diffusion.py | 13 +- src/weathergen/model/engines.py | 1 - src/weathergen/utils/validation_io.py | 185 ++++++++------------------ 3 files changed, 63 insertions(+), 136 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index b7db378cc..f5c6a31f5 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -113,12 +113,13 @@ def forward( ) else: # NOTE: temporary for analysing denoising - # return self.training_forward( - # tokens=tokens, - # fstep=fstep, - # meta_info=meta_info, - # coords=coords, - # ) + return self.training_forward( + tokens=tokens, + fstep=fstep, + meta_info=meta_info, + coords=coords, + ) + if fstep is None: raise ValueError(f"During inference, fstep is required. Got fstep={fstep}") diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 7a5b98e3a..191d3bdf4 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -497,7 +497,6 @@ def forward( if self.cf.fe_diffusion_model: for block in self.fe_blocks: - breakpoint() if isinstance(block, torch.nn.LayerNorm): tokens = checkpoint(block, tokens, use_reentrant=False) else: diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 380845c88..6a621044f 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -11,16 +11,15 @@ from math import exp import re -import cartopy.crs as ccrs -import matplotlib.colors as mcolors -import matplotlib.pyplot as plt import numpy as np import torch +import xarray as xr import weathergen.common.config as config import weathergen.common.io as io from weathergen.common.io import TimeRange, zarrio_writer from weathergen.datasets.data_reader_base import TimeWindowHandler +from weathergen.evaluate.plotting.plotter import Plotter _logger = logging.getLogger(__name__) @@ -94,7 +93,6 @@ def write_output( if loss_term.type == "LossPhysical" ] assert len(outputs_physical) == 1 - target_aux_out = target_aux_out[outputs_physical[0]] # collect all target / prediction-related information @@ -254,42 +252,27 @@ def write_output( for subset in data.items(): zio.write_zarr(subset) - # Extract a representative date per batch sample from target times before - # they are freed. Use the first non-NaT timestamp found in t_idx=0. - sample_dates: list[str] = [] - if len(targets_times_all) > 0: - for stream_times in targets_times_all[0]: - if stream_times.size > 0: - valid_times = stream_times[~np.isnat(stream_times)] - if valid_times.size > 0: - sample_dates.append(str(valid_times[0].astype("datetime64[h]"))) - break # Free arrays no longer needed after zarr writing del targets_all, targets_times_all, targets_lens, sources, data # TODO: REMOVE EVERYTHING BELOW THIS LINE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. - # Prepare prediction data for plotting (scatter plot expects lat/lon coords on ipoint). + # Prepare prediction data for Plotter (scatter plot expects lat/lon coords on ipoint). base_plot_dir = config.get_path_run(cf) / "plots" / "validation" base_plot_dir.mkdir(parents=True, exist_ok=True) - dpi_val = 150 - image_format = "png" + plotter = Plotter({"image_format": "png", "dpi_val": 150}, base_plot_dir) # headline_channels = {"2t", "z500", "q850", "10u", "10v"} # headline_channels = {"2t", "q850"} # headline_channels = {"z500"} headline_channels = {"2t", "z500"} - t_idx = np.random.randint(0, len(preds_all)) # TODO: loop over all time steps once plotting is set up + t_idx = 0 for stream_idx, stream_info in enumerate(cf.streams): stream_name = stream_info["name"] preds_stream = preds_all[t_idx][stream_idx] coords_stream = targets_coords_all[t_idx][stream_idx] - # Check for noised data for this stream - noised_stream = noised_preds_all[t_idx][stream_idx] - has_noised = noised_stream.size > 0 and noised_stream.ndim >= 2 - if preds_stream.size == 0 or coords_stream.size == 0: _logger.warning(f"No prediction data to plot for stream {stream_name}.") continue @@ -303,12 +286,6 @@ def write_output( ) continue - if has_noised: - if noised_stream.ndim == 3: - noised_stream = noised_stream[0] - elif noised_stream.ndim != 2: - has_noised = False - channels = _resolve_channel_names(stream_info, target_channels[stream_idx]) selected_channels = [ ch for ch in channels if _normalize_channel_name(ch) in headline_channels @@ -325,122 +302,72 @@ def write_output( lat = coords_stream[:, 0] lon = coords_stream[:, 1] - run_id = config.get_run_id_from_config(cf) + plotter.stream = stream_name + plotter.run_id = config.get_run_id_from_config(cf) + plotter.fstep = forecast_offset num_samples = len(preds) len_per_sample = preds_stream.shape[0] // num_samples - noised_len_per_sample = noised_stream.shape[0] // num_samples if has_noised else 0 - - if noise_level is not None: - # Format with .1e to preserve one decimal place in mantissa, then clean up exponent notation - eta_str = re.sub(r'e[+]?0*(?=\d)', 'e', re.sub(r'e-0*(?=\d)', 'e-', f'{noise_level:.1e}')) - else: - eta_str = None - eta_tag = f"_eta{eta_str}" if eta_str is not None else "" for sample in range(num_samples): s_start = sample * len_per_sample s_end = (sample + 1) * len_per_sample - ns_start = sample * noised_len_per_sample if has_noised else 0 - ns_end = (sample + 1) * noised_len_per_sample if has_noised else 0 - - # Extract sample date from target times - sample_date_str = sample_dates[0] if len(sample_dates) > 0 else "" for varname in selected_channels: col = ch_to_col[varname] - - # --- denoised data --- - den_vals = preds_stream[s_start:s_end, col] - den_lat = lat[s_start:s_end] - den_lon = lon[s_start:s_end] - den_valid = ~np.isnan(den_vals) - den_vals, den_lat, den_lon = den_vals[den_valid], den_lat[den_valid], den_lon[den_valid] - - # --- noised data --- - if has_noised: - noi_vals = noised_stream[ns_start:ns_end, col] - noi_lat = lat[ns_start:ns_end] - noi_lon = lon[ns_start:ns_end] - noi_valid = ~np.isnan(noi_vals) - noi_vals, noi_lat, noi_lon = noi_vals[noi_valid], noi_lat[noi_valid], noi_lon[noi_valid] - else: - noi_vals = noi_lat = noi_lon = None - - # Shared colour scale across both panels - all_vals = np.concatenate([den_vals] + ([noi_vals] if noi_vals is not None else [])) - vmin, vmax = float(np.nanmin(all_vals)), float(np.nanmax(all_vals)) - norm = mcolors.Normalize(vmin=vmin, vmax=vmax) - cmap = plt.get_cmap("coolwarm") - - ncols = 2 if has_noised else 1 - proj = ccrs.Robinson() - fig, axes = plt.subplots( - 1, ncols, figsize=(8 * ncols, 5), dpi=dpi_val, - subplot_kw={"projection": proj}, + vals = preds_stream[s_start:s_end, col] + sample_lat = lat[s_start:s_end] + sample_lon = lon[s_start:s_end] + + # Drop NaN points + valid = ~np.isnan(vals) + vals = vals[valid] + sample_lat = sample_lat[valid] + sample_lon = sample_lon[valid] + + sample_da = xr.DataArray( + vals, + dims=("ipoint",), + coords={ + "ipoint": np.arange(len(vals)), + "lat": ("ipoint", sample_lat), + "lon": ("ipoint", sample_lon), + }, ) - if ncols == 1: - axes = [axes] - - # Left panel: noised (or skip if not available) - if has_noised: - ax_noised = axes[0] - ax_noised.coastlines() - ax_noised.set_global() - sc_n = ax_noised.scatter( - noi_lon, noi_lat, c=noi_vals, norm=norm, cmap=cmap, - s=2.0, marker="o", transform=ccrs.PlateCarree(), linewidths=0.0, - ) - if eta_str is not None: - ax_noised.set_title(f"Noised | {varname} | eta={eta_str}", fontsize=9.5) - else: - ax_noised.set_title(f"Noised | {varname}", fontsize=9.5) - ax_noised.gridlines(draw_labels=False, linestyle="--", color="black", linewidth=1) - ax_denoised = axes[1] - else: - ax_denoised = axes[0] - - # Right panel (or only panel): denoised - ax_denoised.coastlines() - ax_denoised.set_global() - sc_d = ax_denoised.scatter( - den_lon, den_lat, c=den_vals, norm=norm, cmap=cmap, - s=2.0, marker="o", transform=ccrs.PlateCarree(), linewidths=0.0, - ) - if eta_str is not None: - ax_denoised.set_title(f"Denoised | {varname} | eta={eta_str}", fontsize=9.5) - else: - ax_denoised.set_title(f"Denoised | {varname}", fontsize=9.5) - ax_denoised.gridlines(draw_labels=False, linestyle="--", color="black", linewidth=1) - - # Shared colourbar - fig.colorbar( - sc_d, ax=axes, orientation="horizontal", - label=f"Variable: {varname}", fraction=0.05, pad=0.07, - ) - - date_part = f" | {sample_date_str}" if sample_date_str else "" - if eta_str is not None: - fig.suptitle( - f"{stream_name} - {varname} (fstep {forecast_offset}) | sample {sample + 1}{date_part} | eta={eta_str}", - fontsize=11, - ) - else: - fig.suptitle( - f"{stream_name} - {varname} (fstep {forecast_offset}) | sample {sample + 1}{date_part}", - fontsize=11, - ) channel_dir = base_plot_dir / varname channel_dir.mkdir(parents=True, exist_ok=True) + epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}_{sample}" + # Add noise_level_rn to title if present for this stream + if noise_level is not None: + eta_str = str(noise_level) + else: + eta_str = None + eta_tag = f"_eta{eta_str}" if eta_str is not None else "" epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}{eta_tag}" - fname = channel_dir / f"{epoch_tag}.{image_format}" - fig.savefig(fname, bbox_inches="tight") - plt.close(fig) + + if noise_level is not None: + title = f"{stream_name} - {varname} (fstep {forecast_offset}) | sample {sample + 1} | noise_level={eta_str}" + else: + title = f"{stream_name} - {varname} (fstep {forecast_offset}) | sample {sample + 1}" + + plot_name = plotter.scatter_plot( + sample_da, + channel_dir, + varname=varname, + regionname="global", + tag=epoch_tag, + title=title, + ) + src = channel_dir / f"{plot_name}.{plotter.image_format}" + dst = channel_dir / f"{epoch_tag}.{plotter.image_format}" + if src != dst: + try: + src.replace(dst) + except (FileNotFoundError, OSError): + pass # another rank already renamed or removed the file - del den_vals, den_lat, den_lon, den_valid - if has_noised: - del noi_vals, noi_lat, noi_lon, noi_valid + del sample_da, vals, sample_lat, sample_lon, valid del preds_stream, coords_stream @@ -546,4 +473,4 @@ def write_output( del noised_stream, coords_stream - i += 1 + i += 1 \ No newline at end of file From 1f8fc094fdf776f3b072ac968eeb9bfb90848a47 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Wed, 15 Apr 2026 17:51:56 +0200 Subject: [PATCH 281/344] LayerNorm config --- config/config_diffusion.yml | 6 +++--- src/weathergen/model/encoder.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index ed6017969..101c95a06 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -110,9 +110,9 @@ freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_to # load_chkpt: {'run_id': 'ae4wlc5m', 'epoch': -1} # multi-var d2048 hl3, sigma_data=2.7047 # load_chkpt: {'run_id': 'r45iwyns', 'epoch': -1} # multi-var d512 hl3, sigma_data=1.1785 # load_chkpt: {'run_id': 'ydka6uql', 'epoch': -1} # multi-var d512 hl4, sigma_data=0.827 -load_chkpt: {'run_id': 'qf9yoimd', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 +# load_chkpt: {'run_id': 'qf9yoimd', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 # load_chkpt: {'run_id': 'v8kd6xc1', 'epoch': -1} # multi-var d512 hl5 nopos, sigma_data=0.6481 -# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 +load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 # load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 # load_chkpt: {'run_id': 'mivw6jda', 'epoch': -1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 @@ -270,7 +270,7 @@ validation_config: } # run validation before training starts (mainly for model development) - validate_before_training: True + validate_before_training: False # test config; full test config is merge of validation and test config diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index 77ba3f3da..c103b083a 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -112,7 +112,7 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord # global assimilation engine self.ae_global_engine = GlobalAssimilationEngine(cf, self.num_healpix_cells) - # self.ln = torch.nn.LayerNorm(cf.ae_local_dim_embed, elementwise_affine=False) + self.ln = torch.nn.LayerNorm(cf.ae_local_dim_embed, elementwise_affine=False) def forward(self, model_params, batch): """ @@ -134,7 +134,7 @@ def forward(self, model_params, batch): use_reentrant=False, ) - # tokens_global = self.ln(tokens_global) + tokens_global = self.ln(tokens_global) return tokens_global, posteriors From c3b676d16034ac4b4f761eeb11c5ba12db816b37 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Fri, 17 Apr 2026 10:32:26 +0200 Subject: [PATCH 282/344] quickfixes --- config/config_diffusion.yml | 6 +-- src/weathergen/model/attention.py | 12 +++--- src/weathergen/model/diffusion.py | 17 ++++----- src/weathergen/model/layers.py | 16 ++++---- src/weathergen/model/norms.py | 61 +++++++++++++++++++++++-------- 5 files changed, 70 insertions(+), 42 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index f103d4a42..131f97087 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -72,7 +72,7 @@ frequency_embedding_dim: 256 embedding_dim: 512 sigma_min: 0.002 sigma_max: 80 -sigma_data: 0.5789 +sigma_data: 1.0 rho: 7 p_mean: -1.2 p_std: 1.2 @@ -112,7 +112,7 @@ freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_to # load_chkpt: {'run_id': 'ae4wlc5m', 'epoch': -1} # multi-var d2048 hl3, sigma_data=2.7047 # load_chkpt: {'run_id': 'r45iwyns', 'epoch': -1} # multi-var d512 hl3, sigma_data=1.1785 # load_chkpt: {'run_id': 'ydka6uql', 'epoch': -1} # multi-var d512 hl4, sigma_data=0.827 -load_chkpt: {'run_id': 'qf9yoimd', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 +load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 # load_chkpt: {'run_id': 'v8kd6xc1', 'epoch': -1} # multi-var d512 hl5 nopos, sigma_data=0.6481 # load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 # load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 @@ -250,7 +250,7 @@ validation_config: validation_noise_levels: [1.0, 2.0, 3.0, 4.0] samples_per_mini_epoch: 16 - shuffle: True # TODO: Set back to False + shuffle: False # TODO: Set back to False start_date: 2012-06-01T00:00 end_date: 2012-06-01T18:00 diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 1f3df7242..d5403db05 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -234,7 +234,7 @@ def __init__( if dim_aux is not None: self.lnorm = AdaLayerNorm(dim_embed, dim_aux, norm_eps=norm_eps) - self.lnorm_final = AdaLayerNormFinal(dim_embed, dim_aux, norm_eps=norm_eps) + # self.lnorm_final = AdaLayerNormFinal(dim_embed, dim_aux, norm_eps=norm_eps) else: self.lnorm = norm(dim_embed, eps=norm_eps) self.proj_heads_q = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False) @@ -296,8 +296,8 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None): out = self.proj_out(self.dropout(outs.flatten(-2, -1))) - if ada_ln_aux is not None: - out = self.lnorm_final(out, ada_ln_aux) + # if ada_ln_aux is not None: + # out = self.lnorm_final(out, ada_ln_aux) if self.with_residual: out = x_in + out * gate if self.noise_conditioning else x_in + out @@ -552,7 +552,7 @@ def __init__( if dim_aux is not None: self.lnorm = AdaLayerNorm(dim_embed, dim_aux, norm_eps=norm_eps) #should be initialised to zero - self.lnorm_final = AdaLayerNormFinal(dim_embed, dim_aux, norm_eps=norm_eps) #should be initialised to zero + # self.lnorm_final = AdaLayerNormFinal(dim_embed, dim_aux, norm_eps=norm_eps) #should be initialised to zero else: self.lnorm = norm(dim_embed, eps=norm_eps) self.proj_heads_q = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False) @@ -614,8 +614,8 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None): out = self.proj_out(outs.flatten(-2, -1)) - if ada_ln_aux is not None: - out = self.lnorm_final(out, ada_ln_aux) + # if ada_ln_aux is not None: + # out = self.lnorm_final(out, ada_ln_aux) if self.with_residual: out = x_in + out * gate if self.noise_conditioning else out + x_in diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index f5c6a31f5..d28e64834 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -71,7 +71,7 @@ def forward( fstep: int = None, meta_info: dict[str, SampleMetaData] = None, coords: torch.Tensor = None, - num_steps: int = 30, + num_steps: int = 10, ) -> torch.Tensor: """ Forward pass that routes to training_forward or inference_forward based on model status. @@ -113,12 +113,12 @@ def forward( ) else: # NOTE: temporary for analysing denoising - return self.training_forward( - tokens=tokens, - fstep=fstep, - meta_info=meta_info, - coords=coords, - ) + # return self.training_forward( + # tokens=tokens, + # fstep=fstep, + # meta_info=meta_info, + # coords=coords, + # ) if fstep is None: raise ValueError(f"During inference, fstep is required. Got fstep={fstep}") @@ -149,7 +149,7 @@ def training_forward( # y = data.get_input_data(-1) # eta = data.get_input_metadata(-1) - self.cur_token = tokens + self.cur_token = tokens.detach() if self.cf.fe_diffusion_model_conditioning == "date_time": c = meta_info["ERA5"].params["timestamp"] # TODO: add correct preconditioning (e.g., sample/s in previous time step, datetime encoding, etc.) @@ -170,7 +170,6 @@ def training_forward( n = torch.randn_like(y) * sigma self._noised_tokens = (y + n).detach() - self._noised_tokens = y + n diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index 17a89c2fc..9e5a48790 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -104,12 +104,12 @@ def __init__( self.layers.append(torch.nn.Linear(dim_hidden, dim_out)) - if post_layer_norm: - self.layers.append( - norm(dim_out, eps=norm_eps) - if dim_aux is None - else AdaLayerNormFinal(dim_out, dim_aux, norm_eps=norm_eps) - ) + # if post_layer_norm: + # self.layers.append( + # norm(dim_out, eps=norm_eps) + # if dim_aux is None + # else AdaLayerNormFinal(dim_out, dim_aux, norm_eps=norm_eps) + # ) # TODO: expanded args, must check dependencies (previously aux = args[-1]) def forward(self, *args): @@ -134,14 +134,12 @@ def forward(self, *args): else: if i == 0 and self.with_noise_conditioning: x, gate = self.noise_conditioning(x, noise_emb) - if isinstance(layer, (AdaLayerNormFinal)): + if self.with_aux and isinstance(layer, (AdaLayerNormFinal)): x = layer(x, aux) else: x = layer(x) if self.with_residual: - if gate is not None: - x = x * gate if gate is not None: x = x * gate if x.shape[-1] == x_in.shape[-1]: diff --git a/src/weathergen/model/norms.py b/src/weathergen/model/norms.py index 0c27aac53..b62000d09 100644 --- a/src/weathergen/model/norms.py +++ b/src/weathergen/model/norms.py @@ -59,12 +59,10 @@ def forward(self, x): """ output = self._norm(x.float()).type_as(x) return output * self.weight - - + class AdaLayerNorm(torch.nn.Module): """ - AdaLayerNorm for embedding auxiliary information. - Produces scale and shift for adaptive layer norm. + AdaLayerNorm for embedding auxiliary information """ def __init__( @@ -72,20 +70,51 @@ def __init__( ): super().__init__() - # MLP for embedding auxiliary information (matches DiT style) + # simple 2-layer MLP for embedding auxiliary information + self.embed_aux = torch.nn.ModuleList() + self.embed_aux.append(torch.nn.Linear(dim_aux, 4 * dim_aux)) + self.embed_aux.append(torch.nn.SiLU()) + self.embed_aux.append(torch.nn.Linear(4 * dim_aux, 2 * dim_embed_x)) + self.norm = torch.nn.LayerNorm(dim_embed_x, norm_eps, norm_elementwise_affine) - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(dim_aux, 2 * dim_embed_x, bias=True) - ) - - # Initialize weights to zero for stable training (DiT style) - nn.init.zeros_(self.adaLN_modulation[-1].weight) - nn.init.zeros_(self.adaLN_modulation[-1].bias) def forward(self, x: torch.Tensor, aux: torch.Tensor | None = None) -> torch.Tensor: - shift, scale = self.adaLN_modulation(aux).chunk(2, dim=-1) - return modulate(self.norm(x), shift, scale) + for block in self.embed_aux: + aux = block(aux) + scale, shift = aux.split(aux.shape[-1] // 2, dim=-1) + + x = self.norm(x) * (1 + scale) + shift + + return x + +# TODO: Check if want to overall AdaLayernorm implementation as below... +# class AdaLayerNorm(torch.nn.Module): +# """ +# AdaLayerNorm for embedding auxiliary information. +# Produces scale and shift for adaptive layer norm. +# """ + +# def __init__( +# self, dim_embed_x, dim_aux, norm_elementwise_affine: bool = False, norm_eps: float = 1e-5 +# ): +# super().__init__() + +# breakpoint() + +# # MLP for embedding auxiliary information (matches DiT style) +# self.norm = torch.nn.LayerNorm(dim_embed_x, norm_eps, norm_elementwise_affine) +# self.adaLN_modulation = nn.Sequential( +# nn.SiLU(), +# nn.Linear(dim_aux, 2 * dim_embed_x, bias=True) +# ) + +# # Initialize weights to zero for stable training (DiT style) +# nn.init.zeros_(self.adaLN_modulation[-1].weight) +# nn.init.zeros_(self.adaLN_modulation[-1].bias) + +# def forward(self, x: torch.Tensor, aux: torch.Tensor | None = None) -> torch.Tensor: +# shift, scale = self.adaLN_modulation(aux).chunk(2, dim=-1) +# return modulate(self.norm(x), shift, scale) class AdaLayerNormFinal(torch.nn.Module): @@ -99,6 +128,8 @@ def __init__( ): super().__init__() + breakpoint() + self.norm = torch.nn.LayerNorm(dim_embed_x, norm_eps, norm_elementwise_affine) self.adaLN_modulation = nn.Sequential( nn.SiLU(), From 38ac54f2b89def6bfa2540358e6378a327c3613c Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Fri, 17 Apr 2026 11:50:19 +0200 Subject: [PATCH 283/344] add layernorm --- src/weathergen/model/diffusion.py | 6 +++--- src/weathergen/model/encoder.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index d28e64834..21e5b57b4 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -418,9 +418,9 @@ def timestep_embedding(self, t: float, max_period: int = 10000): :param max_period: controls the minimum frequency of the embeddings. :return: an (N, D) Tensor of positional embeddings. """ - # Ensure t is at least 1D - if t.dim() == 0: - t = t.unsqueeze(0) + # Ensure t is 1D + if t.ndim == 0: + t = t.view(1) half = self.frequency_embedding_dim // 2 freqs = torch.exp( diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index 77ba3f3da..a1caefe1f 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -112,7 +112,7 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord # global assimilation engine self.ae_global_engine = GlobalAssimilationEngine(cf, self.num_healpix_cells) - # self.ln = torch.nn.LayerNorm(cf.ae_local_dim_embed, elementwise_affine=False) + self.ln = torch.nn.LayerNorm(cf.ae_local_dim_embed, elementwise_affine=False) def forward(self, model_params, batch): """ From 6b7fe2bad7d06945a21b0a989c52370a341973cf Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Fri, 17 Apr 2026 15:43:48 +0200 Subject: [PATCH 284/344] now running with new layer structure --- src/weathergen/model/attention.py | 92 ++++++++++++++++------------- src/weathergen/model/diffusion.py | 12 ++-- src/weathergen/model/encoder.py | 2 +- src/weathergen/model/engines.py | 10 ++-- src/weathergen/model/layers.py | 86 ++++++++++++++------------- src/weathergen/model/norms.py | 97 +++++++++++++------------------ 6 files changed, 148 insertions(+), 151 deletions(-) diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index d5403db05..a63e0e734 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -14,7 +14,7 @@ from torch.nn.attention.flex_attention import create_block_mask, flex_attention from weathergen.model.layers import LinearNormConditioning -from weathergen.model.norms import AdaLayerNorm, AdaLayerNormFinal, RMSNorm +from weathergen.model.norms import AdaLayerNorm, AdaLNZero, RMSNorm from weathergen.model.positional_encoding import rotary_pos_emb_2d """ @@ -213,8 +213,9 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, - with_noise_conditioning=False, # should only be True for diffusion model with_2d_rope=False, + is_dit=False, + dit_is_cond=False, ): super(MultiSelfAttentionHeadLocal, self).__init__() @@ -223,6 +224,7 @@ def __init__( self.softcap = softcap self.with_residual = with_residual self.with_2d_rope = with_2d_rope + self.dtype = attention_dtype assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj @@ -232,11 +234,22 @@ def __init__( else: norm = RMSNorm - if dim_aux is not None: + self.is_dit = is_dit + self.dit_is_cond = dit_is_cond + + if is_dit: + if dit_is_cond: + assert dim_aux is not None, "For DIT, need to provide dim_aux for ada layer norm" + assert with_residual, "DIT attention should always have residual connection" + self.lnorm = AdaLNZero(dim_embed, dim_aux, norm_eps=norm_eps) if dim_aux is not None else norm(dim_embed, eps=norm_eps) + self.noise_conditioning = LinearNormConditioning( + latent_space_dim=dim_embed, dtype=attention_dtype + ) + elif dim_aux is not None: self.lnorm = AdaLayerNorm(dim_embed, dim_aux, norm_eps=norm_eps) - # self.lnorm_final = AdaLayerNormFinal(dim_embed, dim_aux, norm_eps=norm_eps) else: self.lnorm = norm(dim_embed, eps=norm_eps) + self.proj_heads_q = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False) self.proj_heads_k = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False) self.proj_heads_v = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False) @@ -262,24 +275,20 @@ def mask_block_local(batch, head, idx_q, idx_kv): # compile for efficiency self.flex_attention = torch.compile(flex_attention, dynamic=False) - self.noise_conditioning = None - if with_noise_conditioning: - self.noise_conditioning = LinearNormConditioning(dim_embed, dtype=self.dtype) - def forward(self, x, coords=None, emb=None, ada_ln_aux=None): if self.with_residual: x_in = x # Handle ada_ln_aux conditioning - if ada_ln_aux is None: - x = self.lnorm(x) + if self.is_dit: + if self.dit_is_cond: + x, cond_gate = self.lnorm(x, ada_ln_aux) + else: + cond_gate = 1 + x, noise_gate = self.noise_conditioning(x, emb) + gate = cond_gate * noise_gate else: - x = self.lnorm(x, ada_ln_aux) - - if self.noise_conditioning: - assert emb is not None, "Need noise embedding if using noise conditioning" - x, gate = self.noise_conditioning(x, emb) - + x = self.lnorm(x, ada_ln_aux) if ada_ln_aux is not None else self.lnorm(x) # project onto heads s = [x.shape[0], x.shape[1], self.num_heads, -1] @@ -296,11 +305,8 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None): out = self.proj_out(self.dropout(outs.flatten(-2, -1))) - # if ada_ln_aux is not None: - # out = self.lnorm_final(out, ada_ln_aux) - if self.with_residual: - out = x_in + out * gate if self.noise_conditioning else x_in + out + out = x_in + out * gate if self.is_dit else x_in + out return out @@ -530,8 +536,9 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, - with_noise_conditioning=False, # should only be True for diffusion model with_2d_rope=False, + is_dit = False, # should only be True for diffusion model + dit_is_cond = False, # whether the attention is used for conditioning in the diffusion model (as opposed to denoising). Should only be True for cross attention layers in the diffusion model, and will control whether ada_ln_aux is applied to the input or output of the attention layer ): super(MultiSelfAttentionHead, self).__init__() @@ -545,16 +552,28 @@ def __init__( assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj + if norm_type == "LayerNorm": norm = partial(torch.nn.LayerNorm, elementwise_affine=False, eps=norm_eps) else: norm = RMSNorm - if dim_aux is not None: - self.lnorm = AdaLayerNorm(dim_embed, dim_aux, norm_eps=norm_eps) #should be initialised to zero - # self.lnorm_final = AdaLayerNormFinal(dim_embed, dim_aux, norm_eps=norm_eps) #should be initialised to zero + self.is_dit = is_dit + self.dit_is_cond = dit_is_cond + + if is_dit: + if dit_is_cond: + assert dim_aux is not None, "For DIT, need to provide dim_aux for ada layer norm" + assert with_residual, "DIT attention should always have residual connection" + self.lnorm = AdaLNZero(dim_embed, dim_aux, norm_eps=norm_eps) if dim_aux is not None else norm(dim_embed, eps=norm_eps) + self.noise_conditioning = LinearNormConditioning( + latent_space_dim=dim_embed + ) #TODO: Do I need to pass dtype? + elif dim_aux is not None: + self.lnorm = AdaLayerNorm(dim_embed, dim_aux, norm_eps=norm_eps) else: self.lnorm = norm(dim_embed, eps=norm_eps) + self.proj_heads_q = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False) self.proj_heads_k = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False) self.proj_heads_v = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False) @@ -574,24 +593,20 @@ def __init__( self.att = self.attention self.softmax = torch.nn.Softmax(dim=-1) - self.noise_conditioning = None - if with_noise_conditioning: - self.noise_conditioning = LinearNormConditioning( - latent_space_dim=dim_embed, dtype=self.dtype - ) - def forward(self, x, coords=None, emb=None, ada_ln_aux=None): if self.with_residual: x_in = x # Handle ada_ln_aux conditioning - if ada_ln_aux is None: - x = self.lnorm(x) + if self.is_dit: + if self.dit_is_cond: + x, cond_gate = self.lnorm(x, ada_ln_aux) + else: + cond_gate = 1 + x, noise_gate = self.noise_conditioning(x, emb) + gate = cond_gate * noise_gate else: - x = self.lnorm(x, ada_ln_aux) - - if self.noise_conditioning: - x, gate = self.noise_conditioning(x, emb) + x = self.lnorm(x, ada_ln_aux) if ada_ln_aux is not None else self.lnorm(x) # project onto heads and q,k,v and @@ -614,11 +629,8 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None): out = self.proj_out(outs.flatten(-2, -1)) - # if ada_ln_aux is not None: - # out = self.lnorm_final(out, ada_ln_aux) - if self.with_residual: - out = x_in + out * gate if self.noise_conditioning else out + x_in + out = x_in + out * gate if self.is_dit else out + x_in return out diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 21e5b57b4..216d142b1 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -113,12 +113,12 @@ def forward( ) else: # NOTE: temporary for analysing denoising - # return self.training_forward( - # tokens=tokens, - # fstep=fstep, - # meta_info=meta_info, - # coords=coords, - # ) + return self.training_forward( + tokens=tokens, + fstep=fstep, + meta_info=meta_info, + coords=coords, + ) if fstep is None: raise ValueError(f"During inference, fstep is required. Got fstep={fstep}") diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index a1caefe1f..c103b083a 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -134,7 +134,7 @@ def forward(self, model_params, batch): use_reentrant=False, ) - # tokens_global = self.ln(tokens_global) + tokens_global = self.ln(tokens_global) return tokens_global, posteriors diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 191d3bdf4..c6fc101cd 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -421,8 +421,9 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = dim_aux=dim_aux, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), - with_noise_conditioning=self.cf.fe_diffusion_model, with_2d_rope=self.cf.get("rope_2D", False), + is_dit=self.cf.fe_diffusion_model, + dit_is_cond=self.cf.fe_diffusion_model_conditioning in ["date_time"], ) ) else: @@ -439,8 +440,9 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = dim_aux=dim_aux, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), - with_noise_conditioning=self.cf.fe_diffusion_model, with_2d_rope=self.cf.get("rope_2D", False), + is_dit=self.cf.fe_diffusion_model, + dit_is_cond=self.cf.fe_diffusion_model_conditioning in ["date_time"], ) ) # Add MLP block @@ -450,12 +452,12 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = self.cf.ae_global_dim_embed, num_layers=2, with_residual=True, - post_layer_norm=cf.fe_diffusion_model_conditioning in ["date_time"], dropout_rate=self.cf.fe_dropout_rate, norm_type=self.cf.norm_type, dim_aux=dim_aux, norm_eps=self.cf.mlp_norm_eps, - with_noise_conditioning=self.cf.fe_diffusion_model + is_dit=self.cf.fe_diffusion_model, + dit_is_cond=self.cf.fe_diffusion_model_conditioning in ["date_time"], ) ) # Optionally, add LayerNorm after i-th layer diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index 9e5a48790..55913ee2c 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -28,7 +28,7 @@ import torch import torch.nn as nn -from weathergen.model.norms import AdaLayerNorm, AdaLayerNormFinal, RMSNorm +from weathergen.model.norms import AdaLNZero, AdaLayerNorm, RMSNorm class NamedLinear(torch.nn.Module): @@ -53,7 +53,6 @@ def __init__( num_layers=2, hidden_factor=2, pre_layer_norm=True, - post_layer_norm=False, dropout_rate=0.0, nonlin=torch.nn.GELU, with_residual=False, @@ -61,7 +60,8 @@ def __init__( dim_aux=None, norm_eps=1e-5, name: str | None = None, - with_noise_conditioning=False + is_dit=False, + dit_is_cond=False, ): """Constructor""" @@ -74,25 +74,28 @@ def __init__( self.with_residual = with_residual self.with_aux = dim_aux is not None - self.with_noise_conditioning = with_noise_conditioning + self.is_dit = is_dit + self.dit_is_cond = dit_is_cond dim_hidden = int(dim_in * hidden_factor) self.layers = torch.nn.ModuleList() norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm - if pre_layer_norm: - self.layers.append( - norm(dim_in, eps=norm_eps) - if dim_aux is None - else AdaLayerNorm(dim_in, dim_aux, norm_eps=norm_eps) - ) - - if with_noise_conditioning: - self.noise_conditioning = LinearNormConditioning( - dim_in - ) # TODO: check if should pass some dtype? - + if is_dit: + if dit_is_cond: + assert dim_aux is not None, "For DIT, need to provide dim_aux for ada layer norm" + assert with_residual, "DIT attention should always have residual connection" + self.lnorm = AdaLNZero(dim_in, dim_aux, norm_eps=norm_eps) if dim_aux is not None else norm(dim_in, eps=norm_eps) + self.noise_conditioning = LinearNormConditioning(dim_in) + elif dim_aux is not None: + self.lnorm = AdaLayerNorm(dim_in, dim_aux, norm_eps=norm_eps) + else: + self.lnorm = norm(dim_in, eps=norm_eps) + + #TODO: The below should be consolidated – implementing in layer list for backward compatibility + if not is_dit: + self.layers.append(self.lnorm) self.layers.append(torch.nn.Linear(dim_in, dim_hidden)) self.layers.append(nonlin()) self.layers.append(torch.nn.Dropout(p=dropout_rate)) @@ -104,12 +107,6 @@ def __init__( self.layers.append(torch.nn.Linear(dim_hidden, dim_out)) - # if post_layer_norm: - # self.layers.append( - # norm(dim_out, eps=norm_eps) - # if dim_aux is None - # else AdaLayerNormFinal(dim_out, dim_aux, norm_eps=norm_eps) - # ) # TODO: expanded args, must check dependencies (previously aux = args[-1]) def forward(self, *args): @@ -117,30 +114,34 @@ def forward(self, *args): if len(args) < 2 and self.with_aux: raise ValueError("Auxiliary input required but not provided") if len(args) == 2: - aux = args[1] + ada_ln_aux = args[1] elif len(args) > 2: - aux = args[-1] - noise_emb = args[2] if self.with_noise_conditioning else None - noise_emb = args[2] if self.with_noise_conditioning else None - - gate = None - gate = None - for i, layer in enumerate(self.layers): - if i == 0 and self.with_aux: - if isinstance(layer, (AdaLayerNorm)): - x = layer(x, aux) - if self.with_noise_conditioning: - x, gate = self.noise_conditioning(x, noise_emb) + ada_ln_aux = args[-1] + noise_emb = args[2] if self.is_dit else None + noise_emb = args[2] if self.is_dit else None + + if self.is_dit: + if self.dit_is_cond: + assert ada_ln_aux is not None, "Need auxiliary input for conditional DIT" + x, cond_gate = self.lnorm(x, ada_ln_aux) else: - if i == 0 and self.with_noise_conditioning: - x, gate = self.noise_conditioning(x, noise_emb) - if self.with_aux and isinstance(layer, (AdaLayerNormFinal)): - x = layer(x, aux) - else: - x = layer(x) + cond_gate = 1 + assert noise_emb is not None, "Need noise embedding for noise conditioning in DIT" + x, noise_gate = self.noise_conditioning(x, noise_emb) + gate = cond_gate * noise_gate + # elif self.dim_aux is not None: + # x = self.lnorm(x, ada_ln_aux) + # else: + # x = self.lnorm(x, ada_ln_aux) if ada_ln_aux is not None else self.lnorm(x) + + for layer in self.layers: + if isinstance(layer, AdaLayerNorm): + x = layer(x, ada_ln_aux) + else: + x = layer(x) if self.with_residual: - if gate is not None: + if self.is_dit: x = x * gate if x.shape[-1] == x_in.shape[-1]: x = x_in + x @@ -183,3 +184,4 @@ def forward(self, inputs, noise_emb): return (inputs * scale + offset).to( self.dtype ), gate # TODO: check if to(self.dtype) needed here + diff --git a/src/weathergen/model/norms.py b/src/weathergen/model/norms.py index b62000d09..404575e27 100644 --- a/src/weathergen/model/norms.py +++ b/src/weathergen/model/norms.py @@ -59,6 +59,45 @@ def forward(self, x): """ output = self._norm(x.float()).type_as(x) return output * self.weight + + +class AdaLNZero(torch.nn.Module): + """DiT-style adaptive layer norm with zero initialization for diffusion models.""" + + def __init__(self, dim_embed: int, dim_aux: int, norm_eps: float = 1e-5): + super().__init__() + self.norm = torch.nn.LayerNorm(dim_embed, elementwise_affine=False, eps=norm_eps) + self.gate_proj = torch.nn.Linear(dim_aux, dim_embed) + self.shift_proj = torch.nn.Linear(dim_aux, dim_embed) + + with torch.no_grad(): + self.gate_proj.weight.zero_() + self.gate_proj.bias.zero_() + self.shift_proj.weight.zero_() + self.shift_proj.bias.zero_() + + def forward( + self, x: torch.Tensor, aux: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """Returns (x_normalized_and_scaled, gate_signal) for residual modulation.""" + x_norm = self.norm(x) + + if aux.dim() == 0: + aux = aux.unsqueeze(0) + + gate_params = self.gate_proj(aux) + shift_params = self.shift_proj(aux) + + while gate_params.dim() < x_norm.dim(): + gate_params = gate_params.unsqueeze(-2) + shift_params = shift_params.unsqueeze(-2) + + gate = 1 + gate_params + x_out = gate * x_norm + shift_params + gate_signal = gate.mean(dim=-1, keepdim=True) + + return x_out, gate_signal + class AdaLayerNorm(torch.nn.Module): """ @@ -87,67 +126,9 @@ def forward(self, x: torch.Tensor, aux: torch.Tensor | None = None) -> torch.Ten return x -# TODO: Check if want to overall AdaLayernorm implementation as below... -# class AdaLayerNorm(torch.nn.Module): -# """ -# AdaLayerNorm for embedding auxiliary information. -# Produces scale and shift for adaptive layer norm. -# """ - -# def __init__( -# self, dim_embed_x, dim_aux, norm_elementwise_affine: bool = False, norm_eps: float = 1e-5 -# ): -# super().__init__() - -# breakpoint() - -# # MLP for embedding auxiliary information (matches DiT style) -# self.norm = torch.nn.LayerNorm(dim_embed_x, norm_eps, norm_elementwise_affine) -# self.adaLN_modulation = nn.Sequential( -# nn.SiLU(), -# nn.Linear(dim_aux, 2 * dim_embed_x, bias=True) -# ) - -# # Initialize weights to zero for stable training (DiT style) -# nn.init.zeros_(self.adaLN_modulation[-1].weight) -# nn.init.zeros_(self.adaLN_modulation[-1].bias) - -# def forward(self, x: torch.Tensor, aux: torch.Tensor | None = None) -> torch.Tensor: -# shift, scale = self.adaLN_modulation(aux).chunk(2, dim=-1) -# return modulate(self.norm(x), shift, scale) - - -class AdaLayerNormFinal(torch.nn.Module): - """ - AdaLayerNorm for gating only (scale only, no shift). - Used for final output gating as in DiT. - """ - - def __init__( - self, dim_embed_x, dim_aux, norm_elementwise_affine: bool = False, norm_eps: float = 1e-5 - ): - super().__init__() - - breakpoint() - - self.norm = torch.nn.LayerNorm(dim_embed_x, norm_eps, norm_elementwise_affine) - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(dim_aux, dim_embed_x, bias=True) - ) - - # Initialize weights to zero for stable training (DiT style) - nn.init.zeros_(self.adaLN_modulation[-1].weight) - nn.init.zeros_(self.adaLN_modulation[-1].bias) - - def forward(self, x: torch.Tensor, aux: torch.Tensor | None = None) -> torch.Tensor: - scale = self.adaLN_modulation(aux) - return modulate(self.norm(x), shift=0, scale=scale) - def modulate(x, shift, scale): return x * (1 + scale) + shift - class SwiGLU(nn.Module): def __init__(self): super(SwiGLU, self).__init__() From ddcd9166844d00c8ae92f9ee947fab3dbe1ceb41 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Fri, 17 Apr 2026 16:07:47 +0200 Subject: [PATCH 285/344] minor additions --- src/weathergen/model/attention.py | 2 +- src/weathergen/model/norms.py | 36 ------------------------------- 2 files changed, 1 insertion(+), 37 deletions(-) diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index a63e0e734..d592a9461 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -567,7 +567,7 @@ def __init__( assert with_residual, "DIT attention should always have residual connection" self.lnorm = AdaLNZero(dim_embed, dim_aux, norm_eps=norm_eps) if dim_aux is not None else norm(dim_embed, eps=norm_eps) self.noise_conditioning = LinearNormConditioning( - latent_space_dim=dim_embed + latent_space_dim=dim_embed, dtype=attention_dtype ) #TODO: Do I need to pass dtype? elif dim_aux is not None: self.lnorm = AdaLayerNorm(dim_embed, dim_aux, norm_eps=norm_eps) diff --git a/src/weathergen/model/norms.py b/src/weathergen/model/norms.py index 404575e27..4fb6484a2 100644 --- a/src/weathergen/model/norms.py +++ b/src/weathergen/model/norms.py @@ -194,42 +194,6 @@ def forward(self, x: torch.Tensor, c: torch.Tensor, x_lens, **kwargs) -> torch.T + x ) - -# NOTE: Inspired by GenCast/DiT. -class LinearNormConditioning(torch.nn.Module): - """Module for norm conditioning, adapted from GenCast with additional gate parameter from DiT. - - Conditions the normalization of `inputs` by applying a linear layer to the - `norm_conditioning` which produces the scale and offset for each channel. - """ - - def __init__(self, latent_space_dim: int, noise_emb_dim: int = 512, dtype=torch.bfloat16): - super().__init__() - self.dtype = dtype - - self.conditional_linear_layer = torch.nn.Linear( - in_features=noise_emb_dim, - out_features=3 * latent_space_dim, - ) - # Optional: initialize weights similar to TruncatedNormal(stddev=1e-8) - torch.nn.init.normal_(self.conditional_linear_layer.weight, std=1e-8) - torch.nn.init.zeros_(self.conditional_linear_layer.bias) - - def forward(self, inputs, noise_emb): - conditional_scale_offset = self.conditional_linear_layer(noise_emb.to(self.dtype)) - scale_minus_one, offset, gate = torch.chunk(conditional_scale_offset, 3, dim=-1) - scale = scale_minus_one + 1.0 - - # Reshape scale and offset for broadcasting if needed - while scale.dim() < inputs.dim(): - scale = scale.unsqueeze(1) - offset = offset.unsqueeze(1) - return (inputs * scale + offset).to( - self.dtype - ), gate # TODO: check if to(self.dtype) needed here - - - class SaturateEncodings(nn.Module): """A common alternative to a KL regularisation prevent outliers in the latent space when learning an auto-encoder for latent generative model, an example value for the scale factor is 5 From a3cf6b631f52ffbd1dde4dc3ff6c613fbd894636 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Sat, 18 Apr 2026 21:42:23 +0200 Subject: [PATCH 286/344] Config and plot noised/denoised side by side --- config/config_diffusion.yml | 17 +- src/weathergen/model/diffusion.py | 30 ++-- src/weathergen/model/model.py | 3 + src/weathergen/utils/validation_io.py | 236 ++++++++++---------------- 4 files changed, 112 insertions(+), 174 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 101c95a06..85a5d29d4 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -70,7 +70,7 @@ frequency_embedding_dim: 256 embedding_dim: 512 sigma_min: 0.002 sigma_max: 80 -sigma_data: 0.5789 +sigma_data: 0.63 rho: 7 p_mean: -1.2 p_std: 1.2 @@ -111,6 +111,7 @@ freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_to # load_chkpt: {'run_id': 'r45iwyns', 'epoch': -1} # multi-var d512 hl3, sigma_data=1.1785 # load_chkpt: {'run_id': 'ydka6uql', 'epoch': -1} # multi-var d512 hl4, sigma_data=0.827 # load_chkpt: {'run_id': 'qf9yoimd', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 +# load_chkpt: {'run_id': 'd43eu78a', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.63 # load_chkpt: {'run_id': 'v8kd6xc1', 'epoch': -1} # multi-var d512 hl5 nopos, sigma_data=0.6481 load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 # load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 @@ -160,7 +161,7 @@ data_loading : num_workers: 12 rng_seed: ??? - repeat_data_in_mini_epoch : True + repeat_data_in_mini_epoch : False # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. @@ -178,8 +179,8 @@ training_config: samples_per_mini_epoch: 4096 shuffle: True - start_date: 2012-06-01T00:00 - end_date: 2012-06-01T18:00 + start_date: 1979-01-01T00:00 + end_date: 2022-12-31T18:00 time_window_step: 06:00:00 time_window_len: 06:00:00 @@ -248,10 +249,10 @@ validation_config: validation_noise_levels: [1.0, 2.0, 3.0, 4.0] samples_per_mini_epoch: 16 - shuffle: False + shuffle: True - start_date: 2012-06-01T00:00 - end_date: 2012-06-01T18:00 + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T18:00 # whether to track the exponential moving average of weights for validation validate_with_ema: @@ -270,7 +271,7 @@ validation_config: } # run validation before training starts (mainly for model development) - validate_before_training: False + validate_before_training: True # test config; full test config is merge of validation and test config diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 54fde2008..26660b4d5 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -82,21 +82,21 @@ def forward( # y = data.get_input_data(-1) # eta = data.get_input_metadata(-1) - # TODO: remove after single sample experiments - if self.cur_token is not None: - # logger.info("checking single sampling") - assert self.cur_token[0].shape == tokens[0].shape, ( - "first token shape was different between iterations " - "– violates single sample overfitting with difference" - ) - assert torch.equal(self.cur_token[0], tokens[0]), ( - f"first token was different between iterations " - f"– violates single sample overfitting {self.cur_token[0] - tokens[0]}" - ) - assert torch.equal(self.cur_token, tokens), ( - f"tokens were different between iterations " - f"– violates single sample overfitting {self.cur_token - tokens}" - ) + # # TODO: remove after single sample experiments + # if self.cur_token is not None: + # # logger.info("checking single sampling") + # assert self.cur_token[0].shape == tokens[0].shape, ( + # "first token shape was different between iterations " + # "– violates single sample overfitting with difference" + # ) + # assert torch.equal(self.cur_token[0], tokens[0]), ( + # f"first token was different between iterations " + # f"– violates single sample overfitting {self.cur_token[0] - tokens[0]}" + # ) + # assert torch.equal(self.cur_token, tokens), ( + # f"tokens were different between iterations " + # f"– violates single sample overfitting {self.cur_token - tokens}" + # ) self.cur_token = tokens.detach() # return self.inference(fstep=fstep, num_steps=10, coords=coords) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 754289d16..8e4181428 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -737,6 +737,9 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # collapse along input step dimension tokens = tokens.reshape(shape).sum(axis=1) + # print(tokens.std().item()) + # breakpoint() + # Normalize tokens # TODO: REMOVE THIS LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. # t_mean = tokens.mean() diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 9284ed2be..a730aac05 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -11,6 +11,8 @@ from math import exp import re +import matplotlib.pyplot as plt +import cartopy.crs as ccrs import numpy as np import torch import xarray as xr @@ -254,7 +256,7 @@ def write_output( # Free arrays no longer needed after zarr writing - del targets_all, targets_times_all, targets_lens, sources, data + del targets_all, targets_lens, sources, data # TODO: REMOVE EVERYTHING BELOW THIS LINE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. @@ -271,7 +273,9 @@ def write_output( for stream_idx, stream_info in enumerate(cf.streams): stream_name = stream_info["name"] preds_stream = preds_all[t_idx][stream_idx] + noised_stream = noised_preds_all[t_idx][stream_idx] coords_stream = targets_coords_all[t_idx][stream_idx] + times_stream = targets_times_all[t_idx][stream_idx] if preds_stream.size == 0 or coords_stream.size == 0: _logger.warning(f"No prediction data to plot for stream {stream_name}.") @@ -286,26 +290,26 @@ def write_output( ) continue + has_noised = ( + noised_stream.size > 0 and noised_stream.ndim >= 2 + ) + if has_noised and noised_stream.ndim == 3: + noised_stream = noised_stream[0] + channels = _resolve_channel_names(stream_info, target_channels[stream_idx]) selected_channels = [ ch for ch in channels if _normalize_channel_name(ch) in headline_channels ] if not selected_channels: _logger.warning(f"No headline channels available for plotting stream {stream_name}.") - del preds_stream, coords_stream continue - # Build a channel index map so we can slice numpy arrays directly - # instead of constructing a full xarray DataArray for all channels. ch_to_col = {ch: idx for idx, ch in enumerate(channels)} lat = coords_stream[:, 0] lon = coords_stream[:, 1] - plotter.stream = stream_name - plotter.run_id = config.get_run_id_from_config(cf) - plotter.fstep = forecast_offset - + run_id = config.get_run_id_from_config(cf) num_samples = len(preds) len_per_sample = preds_stream.shape[0] // num_samples @@ -313,164 +317,94 @@ def write_output( s_start = sample * len_per_sample s_end = (sample + 1) * len_per_sample + # Extract sample date from target times + sample_times = times_stream[s_start:s_end] + sample_date = np.unique(sample_times) + if len(sample_date) > 0 and not np.isnat(sample_date[0]): + date_str = str(sample_date[0].astype("datetime64[h]")) + else: + date_str = "unknown date" + for varname in selected_channels: col = ch_to_col[varname] - vals = preds_stream[s_start:s_end, col] + pred_vals = preds_stream[s_start:s_end, col] sample_lat = lat[s_start:s_end] sample_lon = lon[s_start:s_end] - # Drop NaN points - valid = ~np.isnan(vals) - vals = vals[valid] - sample_lat = sample_lat[valid] - sample_lon = sample_lon[valid] - - sample_da = xr.DataArray( - vals, - dims=("ipoint",), - coords={ - "ipoint": np.arange(len(vals)), - "lat": ("ipoint", sample_lat), - "lon": ("ipoint", sample_lon), - }, - ) + # Drop NaN points (use pred mask for both panels) + valid = ~np.isnan(pred_vals) + pred_vals = pred_vals[valid] + plot_lat = sample_lat[valid] + plot_lon = sample_lon[valid] channel_dir = base_plot_dir / varname channel_dir.mkdir(parents=True, exist_ok=True) - epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}_{sample}" - # Add noise_level_rn to title if present for this stream - if noise_level is not None: - eta_str = str(noise_level) - else: - eta_str = None + + eta_str = str(noise_level) if noise_level is not None else None eta_tag = f"_eta{eta_str}" if eta_str is not None else "" epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}{eta_tag}" - - if noise_level is not None: - title = f"{stream_name} - {varname} (fstep {forecast_offset}) | sample {sample + 1} | noise_level={eta_str}" - else: - title = f"{stream_name} - {varname} (fstep {forecast_offset}) | sample {sample + 1}" - - plot_name = plotter.scatter_plot( - sample_da, - channel_dir, - varname=varname, - regionname="global", - tag=epoch_tag, - title=title, - ) - src = channel_dir / f"{plot_name}.{plotter.image_format}" - dst = channel_dir / f"{epoch_tag}.{plotter.image_format}" - if src != dst: - try: - src.replace(dst) - except (FileNotFoundError, OSError): - pass # another rank already renamed or removed the file - - del sample_da, vals, sample_lat, sample_lon, valid - del preds_stream, coords_stream - - # Plot decoded noised tokens (diffusion models only) - has_noised = any( - noised_preds_all[t_idx][s_idx].size > 0 - for s_idx in range(len(cf.streams)) - if noised_preds_all[t_idx][s_idx].ndim >= 2 - ) - if has_noised: - for stream_idx, stream_info in enumerate(cf.streams): - stream_name = stream_info["name"] - noised_stream = noised_preds_all[t_idx][stream_idx] - coords_stream = targets_coords_all[t_idx][stream_idx] - - if noised_stream.size == 0 or coords_stream.size == 0: - continue - - if noised_stream.ndim == 3: - noised_stream = noised_stream[0] - elif noised_stream.ndim != 2: - continue - - channels = _resolve_channel_names(stream_info, target_channels[stream_idx]) - selected_channels = [ - ch for ch in channels if _normalize_channel_name(ch) in headline_channels - ] - if not selected_channels: - del noised_stream, coords_stream - continue - - ch_to_col = {ch: idx for idx, ch in enumerate(channels)} - - lat = coords_stream[:, 0] - lon = coords_stream[:, 1] - - plotter.stream = stream_name - plotter.run_id = config.get_run_id_from_config(cf) - plotter.fstep = forecast_offset - - num_samples = len(preds) - len_per_sample = noised_stream.shape[0] // num_samples - - for sample in range(num_samples): - s_start = sample * len_per_sample - s_end = (sample + 1) * len_per_sample - - for varname in selected_channels: - col = ch_to_col[varname] - vals = noised_stream[s_start:s_end, col] - sample_lat = lat[s_start:s_end] - sample_lon = lon[s_start:s_end] - - # Drop NaN points - valid = ~np.isnan(vals) - vals = vals[valid] - sample_lat = sample_lat[valid] - sample_lon = sample_lon[valid] - - sample_da = xr.DataArray( - vals, - dims=("ipoint",), - coords={ - "ipoint": np.arange(len(vals)), - "lat": ("ipoint", sample_lat), - "lon": ("ipoint", sample_lon), - }, + # Determine number of panels + ncols = 2 if has_noised else 1 + proj = ccrs.Robinson() + fig, axes = plt.subplots( + 1, ncols, figsize=(8 * ncols, 5), + subplot_kw={"projection": proj}, dpi=150, + ) + if ncols == 1: + axes = [axes] + + # Shared color limits across panels + vmin, vmax = np.nanmin(pred_vals), np.nanmax(pred_vals) + + # Panel 1: noised (if available) + if has_noised: + noised_vals = noised_stream[s_start:s_end, col][valid] + vmin = min(vmin, np.nanmin(noised_vals)) + vmax = max(vmax, np.nanmax(noised_vals)) + ax_noised = axes[0] + ax_noised.coastlines() + ax_noised.set_global() + sc_n = ax_noised.scatter( + plot_lon, plot_lat, c=noised_vals, + vmin=vmin, vmax=vmax, cmap="coolwarm", + s=4.0, marker="o", transform=ccrs.PlateCarree(), linewidths=0.0, ) + ax_noised.set_title("Noised", fontsize=10) + ax_denoised = axes[1] + else: + ax_denoised = axes[0] + + # Panel 2 (or only panel): denoised prediction + ax_denoised.coastlines() + ax_denoised.set_global() + sc_d = ax_denoised.scatter( + plot_lon, plot_lat, c=pred_vals, + vmin=vmin, vmax=vmax, cmap="coolwarm", + s=4.0, marker="o", transform=ccrs.PlateCarree(), linewidths=0.0, + ) + ax_denoised.set_title("Denoised", fontsize=10) + + # Shared colorbar + fig.colorbar(sc_d, ax=axes, orientation="horizontal", + label=varname, shrink=0.6, pad=0.05) + + # Suptitle with date + eta_info = f" | noise_level={eta_str}" if eta_str else "" + fig.suptitle( + f"{stream_name} - {varname} (fstep {forecast_offset})" + f" | sample {sample + 1} | {date_str}{eta_info}", + fontsize=11, + ) - channel_dir = base_plot_dir / varname / "noised" - channel_dir.mkdir(parents=True, exist_ok=True) - epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}_{sample}_noised" + fname = channel_dir / f"{epoch_tag}_{sample}.{plotter.image_format}" + fig.savefig(fname, bbox_inches="tight") + plt.close(fig) - if noise_level is not None: - eta_str = str(noise_level) - else: - eta_str = None - eta_tag = f"_eta{eta_str}" if eta_str is not None else "" - epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}{eta_tag}" - - if noise_level is not None: - title = f"{stream_name} - {varname} (fstep {forecast_offset}) | noised sample {sample + 1} | noise_level={eta_str}" - else: - title = f"{stream_name} - {varname} (fstep {forecast_offset}) | noised sample {sample + 1}" - - plot_name = plotter.scatter_plot( - sample_da, - channel_dir, - varname=varname, - regionname="global", - tag=epoch_tag, - title=title, - ) - src = channel_dir / f"{plot_name}.{plotter.image_format}" - dst = channel_dir / f"{epoch_tag}.{plotter.image_format}" - if src != dst: - try: - src.replace(dst) - except (FileNotFoundError, OSError): - pass # another rank already renamed or removed the file + del pred_vals, plot_lat, plot_lon, valid - del sample_da, vals, sample_lat, sample_lon, valid + del preds_stream, coords_stream - del noised_stream, coords_stream + del targets_times_all i += 1 From 1c8623c6b9c08c8b9274a03f453b0c32dd3501da Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Mon, 20 Apr 2026 08:51:26 +0200 Subject: [PATCH 287/344] Remove fixed seed from inference --- src/weathergen/model/diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 26660b4d5..a4614a9f0 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -152,7 +152,7 @@ def inference( # https://github.com/NVlabs/edm/blob/main/generate.py # Sample pure noise (assuming single batch element for now) - torch.manual_seed(42) + # torch.manual_seed(42) x = torch.randn(1, self.num_healpix_cells, self.cf.ae_global_dim_embed).to(device="cuda") ### OLD WAY OF COMPUTING SIGMA SCHEDULE From ee301e2183dcf1dda6dd00281cf34e7d2202f9f1 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Mon, 20 Apr 2026 11:17:26 +0200 Subject: [PATCH 288/344] inter changes --- config/config_diffusion.yml | 14 +++++------ config/runs_plot_train.yml | 48 +++++++++++++++++++++++++++++++++---- 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 131f97087..b36dfb32c 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -176,19 +176,19 @@ training_config: # training_mode: "masking", "student_teacher", "latent_loss" training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] - num_mini_epochs: 128 - samples_per_mini_epoch: 4096 + num_mini_epochs: 512 + samples_per_mini_epoch: 1028 shuffle: True - start_date: 2012-06-01T00:00 - end_date: 2012-06-01T18:00 + start_date: 1992-06-01T00:00 + end_date: 2012-06-01T00:00 time_window_step: 06:00:00 time_window_len: 06:00:00 learning_rate_scheduling : lr_start: 1e-6 #5e-5 - lr_max: 1e-5 #1e-4 + lr_max: 5e-6 #1e-4 lr_final_decay: 1e-6 lr_final: 0.0 num_steps_warmup: 64 @@ -252,8 +252,8 @@ validation_config: samples_per_mini_epoch: 16 shuffle: False # TODO: Set back to False - start_date: 2012-06-01T00:00 - end_date: 2012-06-01T18:00 + start_date: 1992-06-01T00:00 + end_date: 2012-06-01T00:00 # whether to track the exponential moving average of weights for validation validate_with_ema: diff --git a/config/runs_plot_train.yml b/config/runs_plot_train.yml index 6bd2a91bc..ef3203d96 100644 --- a/config/runs_plot_train.yml +++ b/config/runs_plot_train.yml @@ -1,5 +1,45 @@ -train : - plot : - h8wnm1kt: +train: + plot: + it6wj130: + slurm_id: 376473 + description: "single samples, lr_start=1e-6, lr_max=1e-5" + p54uvzl2: + slurm_id: 376808 + description: "single samples, lr_start=1e-5, lr_max=1e-5" + lq5djr4m: + slurm_id: 376811 + description: "single samples, lr_start=1e-6, lr_max=1e-6" + wcruesg4: + slurm_id: 376816 + description: "single samples, lr_start=1e-7, lr_max=1e-6" + k3qh6elp: + slurm_id: 377059 + description: "single samples, lr_start=1e-6, lr_max=5e-6" + # w8hp1c2g: + # slurm_id: 376855 + # description: "20y distribution, lr_start=1e-5, lr_max=1e-5" + # ss9z2rqi: + # slurm_id: 376858 + # description: "20y distribution, lr_start=1e-6, lr_max=1e-6" + # g9iqgz0d: + # slurm_id: 376860 + # description: "20y distribution, lr_start=1e-7, lr_max=1e-6" + # q5u9p8xo: + # slurm_id: 377063 + # description: "20y distribution, lr_start=1e-6, lr_max=5e-6" + # f8e97mqx: + # slurm_id: 376862 + # description: "20y distribution, lr_start=1e-6, lr_max=1e-5" + xqgy519d: slurm_id: 0 - description: "first conditioning experiment" \ No newline at end of file + description: "Old Matze Baseline 1e-5 (single samples)" + rj2xksg0: + slurm_id: 0 + description: "Old Matze Baseline 1e-5 (single samples)" + bbosl5wy: + slurm_id: 0 + description: "New Matze Baseline (20y)" + y0l8egdr: + slurm_id: 0 + description: "New Matze Baseline (single samples)" + From 1002b40528110feca26a37845d0c73d7a91d9c91 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Mon, 20 Apr 2026 12:50:46 +0200 Subject: [PATCH 289/344] some fixes --- config/config_diffusion.yml | 12 +++++------ config/runs_plot_train.yml | 36 +++++++++++++++++++------------ src/weathergen/model/attention.py | 2 ++ src/weathergen/model/engines.py | 2 +- src/weathergen/model/layers.py | 1 + 5 files changed, 32 insertions(+), 21 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index b36dfb32c..e351fa290 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -180,15 +180,15 @@ training_config: samples_per_mini_epoch: 1028 shuffle: True - start_date: 1992-06-01T00:00 - end_date: 2012-06-01T00:00 + start_date: 2012-06-01T00:00 + end_date: 2012-06-01T18:00 time_window_step: 06:00:00 time_window_len: 06:00:00 learning_rate_scheduling : lr_start: 1e-6 #5e-5 - lr_max: 5e-6 #1e-4 + lr_max: 5e-5 #1e-4 lr_final_decay: 1e-6 lr_final: 0.0 num_steps_warmup: 64 @@ -252,8 +252,8 @@ validation_config: samples_per_mini_epoch: 16 shuffle: False # TODO: Set back to False - start_date: 1992-06-01T00:00 - end_date: 2012-06-01T00:00 + start_date: 2012-06-01T00:00 + end_date: 2012-06-01T18:00 # whether to track the exponential moving average of weights for validation validate_with_ema: @@ -272,7 +272,7 @@ validation_config: } # run validation before training starts (mainly for model development) - validate_before_training: True + validate_before_training: False # test config; full test config is merge of validation and test config diff --git a/config/runs_plot_train.yml b/config/runs_plot_train.yml index ef3203d96..e91fb71be 100644 --- a/config/runs_plot_train.yml +++ b/config/runs_plot_train.yml @@ -1,20 +1,28 @@ + + train: plot: - it6wj130: - slurm_id: 376473 + cigywmh2: + slurm_id: 380678 description: "single samples, lr_start=1e-6, lr_max=1e-5" - p54uvzl2: - slurm_id: 376808 - description: "single samples, lr_start=1e-5, lr_max=1e-5" - lq5djr4m: - slurm_id: 376811 - description: "single samples, lr_start=1e-6, lr_max=1e-6" - wcruesg4: - slurm_id: 376816 - description: "single samples, lr_start=1e-7, lr_max=1e-6" - k3qh6elp: - slurm_id: 377059 - description: "single samples, lr_start=1e-6, lr_max=5e-6" + kxe5zfla: + slurm_id: 380680 + description: "single samples, lr_start=1e-6, lr_max=5e-5" + # it6wj130: + # slurm_id: 376473 + # description: "single samples, lr_start=1e-6, lr_max=1e-5" + # p54uvzl2: + # slurm_id: 376808 + # description: "single samples, lr_start=1e-5, lr_max=1e-5" + # lq5djr4m: + # slurm_id: 376811 + # description: "single samples, lr_start=1e-6, lr_max=1e-6" + # wcruesg4: + # slurm_id: 376816 + # description: "single samples, lr_start=1e-7, lr_max=1e-6" + # k3qh6elp: + # slurm_id: 377059 + # description: "single samples, lr_start=1e-6, lr_max=5e-6" # w8hp1c2g: # slurm_id: 376855 # description: "20y distribution, lr_start=1e-5, lr_max=1e-5" diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index d592a9461..837c0f003 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -284,6 +284,7 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None): if self.dit_is_cond: x, cond_gate = self.lnorm(x, ada_ln_aux) else: + x = self.lnorm(x) cond_gate = 1 x, noise_gate = self.noise_conditioning(x, emb) gate = cond_gate * noise_gate @@ -602,6 +603,7 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None): if self.dit_is_cond: x, cond_gate = self.lnorm(x, ada_ln_aux) else: + x = self.lnorm(x) cond_gate = 1 x, noise_gate = self.noise_conditioning(x, emb) gate = cond_gate * noise_gate diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index c6fc101cd..5a7b34fe6 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -498,6 +498,7 @@ def forward( tokens_in = tokens if self.cf.fe_diffusion_model: + assert noise_emb is not None, "noise_emb must be provided for diffusion model conditioning" for block in self.fe_blocks: if isinstance(block, torch.nn.LayerNorm): tokens = checkpoint(block, tokens, use_reentrant=False) @@ -508,7 +509,6 @@ def forward( tokens = checkpoint(block, tokens, coords, noise_emb, ada_ln_aux, use_reentrant=False) else: assert ada_ln_aux is None, "ada_ln_aux should not be provided when diffusion model conditioning is disabled" - assert noise_emb is not None, "noise_emb must be provided for diffusion model conditioning" tokens = checkpoint(block, tokens, coords, noise_emb, use_reentrant=False) else: for block in self.fe_blocks: diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index 55913ee2c..0ec0bdf38 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -125,6 +125,7 @@ def forward(self, *args): assert ada_ln_aux is not None, "Need auxiliary input for conditional DIT" x, cond_gate = self.lnorm(x, ada_ln_aux) else: + x = self.lnorm(x) cond_gate = 1 assert noise_emb is not None, "Need noise embedding for noise conditioning in DIT" x, noise_gate = self.noise_conditioning(x, noise_emb) From 0a454137b4be529291f85128eb44edff4dfc77f8 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Tue, 21 Apr 2026 00:13:43 +0200 Subject: [PATCH 290/344] config changes --- config/config_diffusion.yml | 2 +- config/runs_plot_train.yml | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index e351fa290..6ef023676 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -60,7 +60,7 @@ fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True fe_diffusion_model: True -fe_diffusion_model_conditioning: None # options: "date_time" +fe_diffusion_model_conditioning: "date_time" # options: "date_time" fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) diff --git a/config/runs_plot_train.yml b/config/runs_plot_train.yml index e91fb71be..43150ae51 100644 --- a/config/runs_plot_train.yml +++ b/config/runs_plot_train.yml @@ -2,6 +2,9 @@ train: plot: + bpeh160r: + slurm_id: 381190 + description: "single samples, lr_start=1e-6, lr_max=1e-6" cigywmh2: slurm_id: 380678 description: "single samples, lr_start=1e-6, lr_max=1e-5" From 320bf078caa025765cb808e4f094eebe57a6e86e Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Tue, 21 Apr 2026 15:18:48 +0200 Subject: [PATCH 291/344] config changes --- config/config_diffusion.yml | 4 +-- config/runs_plot_train.yml | 53 +++++++++++++++++++++++-------------- 2 files changed, 35 insertions(+), 22 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 981cb0919..9425e38da 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -60,7 +60,7 @@ fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True fe_diffusion_model: True -fe_diffusion_model_conditioning: "date_time" # options: "date_time" +fe_diffusion_model_conditioning: None # options: "date_time" fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) @@ -114,7 +114,7 @@ freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_to # load_chkpt: {'run_id': 'ydka6uql', 'epoch': -1} # multi-var d512 hl4, sigma_data=0.827 load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 # load_chkpt: {'run_id': 'v8kd6xc1', 'epoch': -1} # multi-var d512 hl5 nopos, sigma_data=0.6481 -load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 # load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 # load_chkpt: {'run_id': 'mivw6jda', 'epoch': -1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 diff --git a/config/runs_plot_train.yml b/config/runs_plot_train.yml index 43150ae51..dbed4c750 100644 --- a/config/runs_plot_train.yml +++ b/config/runs_plot_train.yml @@ -1,16 +1,29 @@ - +# ERA5 – no cond t8cm7bn9 382239 +# ERA5 – cond b9oyntjg 382235 train: plot: - bpeh160r: - slurm_id: 381190 - description: "single samples, lr_start=1e-6, lr_max=1e-6" - cigywmh2: - slurm_id: 380678 - description: "single samples, lr_start=1e-6, lr_max=1e-5" - kxe5zfla: - slurm_id: 380680 - description: "single samples, lr_start=1e-6, lr_max=5e-5" + # bpeh160r: + # slurm_id: 381190 + # description: "single samples, lr_start=1e-6, lr_max=1e-6" + # cigywmh2: + # slurm_id: 380678 + # description: "single samples, lr_start=1e-6, lr_max=1e-5" + # kxe5zfla: + # slurm_id: 380680 + # description: "single samples, lr_start=1e-6, lr_max=5e-5" + fuz6l32i: + slurm_id: 382174 + description: "conditioning w/ single samples, lr_start=1e-6, lr_max=5e-5" + # vujmw4g2: + # slurm_id: 382207 + # description: "conditioning w/ ERA5, lr_start=1e-6, lr_max=5e-5" + t8cm7bn9: + slurm_id: 382239 + description: "MERGE no conditioning w/ ERA5, lr_start=1e-6, lr_max=5e-5" + b9oyntjg: + slurm_id: 382235 + description: "MERGE conditioning w/ ERA5, lr_start=1e-6, lr_max=5e-5" # it6wj130: # slurm_id: 376473 # description: "single samples, lr_start=1e-6, lr_max=1e-5" @@ -41,16 +54,16 @@ train: # f8e97mqx: # slurm_id: 376862 # description: "20y distribution, lr_start=1e-6, lr_max=1e-5" - xqgy519d: - slurm_id: 0 - description: "Old Matze Baseline 1e-5 (single samples)" - rj2xksg0: - slurm_id: 0 - description: "Old Matze Baseline 1e-5 (single samples)" + # xqgy519d: + # slurm_id: 0 + # description: "Old Matze Baseline 1e-5 (single samples)" + # rj2xksg0: + # slurm_id: 0 + # description: "Old Matze Baseline 1e-5 (single samples)" bbosl5wy: slurm_id: 0 - description: "New Matze Baseline (20y)" - y0l8egdr: - slurm_id: 0 - description: "New Matze Baseline (single samples)" + description: "Matze Baseline (ERA5)" + # y0l8egdr: + # slurm_id: 0 + # description: "New Matze Baseline (single samples)" From 7539a982aa244352a251476f52572a62120fe3cc Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 22 Apr 2026 10:18:19 +0200 Subject: [PATCH 292/344] default config now converges --- config/config_diffusion.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 9425e38da..7e7b84b1d 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -176,8 +176,8 @@ training_config: # training_mode: "masking", "student_teacher", "latent_loss" training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] - num_mini_epochs: 512 - samples_per_mini_epoch: 1028 + num_mini_epochs: 128 + samples_per_mini_epoch: 4096 shuffle: True start_date: 1979-01-01T00:00 @@ -188,7 +188,7 @@ training_config: learning_rate_scheduling : lr_start: 1e-6 #5e-5 - lr_max: 5e-5 #1e-4 + lr_max: 7e-6 #1e-4 lr_final_decay: 1e-6 lr_final: 0.0 num_steps_warmup: 64 @@ -272,7 +272,7 @@ validation_config: } # run validation before training starts (mainly for model development) - validate_before_training: False + validate_before_training: True # test config; full test config is merge of validation and test config From 3569447c6828598fe748ce9ccf6901750817d1b6 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 22 Apr 2026 15:29:36 +0200 Subject: [PATCH 293/344] config changes --- config/config_diffusion.yml | 2 +- config/runs_plot_train.yml | 32 +++++++++++++++++++++----------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 7e7b84b1d..793a1c697 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -72,7 +72,7 @@ frequency_embedding_dim: 256 embedding_dim: 512 sigma_min: 0.002 sigma_max: 80 -sigma_data: 0.63 +sigma_data: 1 rho: 7 p_mean: -1.2 p_std: 1.2 diff --git a/config/runs_plot_train.yml b/config/runs_plot_train.yml index dbed4c750..c14727adf 100644 --- a/config/runs_plot_train.yml +++ b/config/runs_plot_train.yml @@ -1,8 +1,18 @@ -# ERA5 – no cond t8cm7bn9 382239 -# ERA5 – cond b9oyntjg 382235 train: plot: + u7etjsm0: + slurm_id: 385058 + description: "ERA5, lr_start=1e-6, lr_max=1e-5" + mot8sfay: + slurm_id: 385060 + description: "ERA5, lr_start=1e-6, lr_max=7e-6" + zhon45xy: + slurm_id: 385064 + description: "conditioning w/ ERA5, lr_start=1e-6, lr_max=1e-5" + yimje7g3: + slurm_id: 385062 + description: "conditioning w/ ERA5, lr_start=1e-6, lr_max=7e-6" # bpeh160r: # slurm_id: 381190 # description: "single samples, lr_start=1e-6, lr_max=1e-6" @@ -12,18 +22,18 @@ train: # kxe5zfla: # slurm_id: 380680 # description: "single samples, lr_start=1e-6, lr_max=5e-5" - fuz6l32i: - slurm_id: 382174 - description: "conditioning w/ single samples, lr_start=1e-6, lr_max=5e-5" + # fuz6l32i: + # slurm_id: 382174 + # description: "conditioning w/ single samples, lr_start=1e-6, lr_max=5e-5" # vujmw4g2: # slurm_id: 382207 # description: "conditioning w/ ERA5, lr_start=1e-6, lr_max=5e-5" - t8cm7bn9: - slurm_id: 382239 - description: "MERGE no conditioning w/ ERA5, lr_start=1e-6, lr_max=5e-5" - b9oyntjg: - slurm_id: 382235 - description: "MERGE conditioning w/ ERA5, lr_start=1e-6, lr_max=5e-5" + # t8cm7bn9: + # slurm_id: 382239 + # description: "MERGE no conditioning w/ ERA5, lr_start=1e-6, lr_max=5e-5" + # b9oyntjg: + # slurm_id: 382235 + # description: "MERGE conditioning w/ ERA5, lr_start=1e-6, lr_max=5e-5" # it6wj130: # slurm_id: 376473 # description: "single samples, lr_start=1e-6, lr_max=1e-5" From b432e6894b787fc195d03956d42563e2f6b9768b Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 22 Apr 2026 16:26:34 +0200 Subject: [PATCH 294/344] remove conditioning --- config/config_diffusion.yml | 2 -- src/weathergen/model/attention.py | 33 ++++++++----------------------- src/weathergen/model/engines.py | 3 --- src/weathergen/model/layers.py | 22 +++++---------------- 4 files changed, 13 insertions(+), 47 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 793a1c697..c9449334a 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -60,14 +60,12 @@ fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True fe_diffusion_model: True -fe_diffusion_model_conditioning: None # options: "date_time" fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) forecast_att_dense_rate: 1.0 with_step_conditioning: True # False # Diffusion related parameters -diffusion_conditioning_embed_dim: 32 # Multi-frequency calendar embedding (8 frequencies × 4 components) frequency_embedding_dim: 256 embedding_dim: 512 sigma_min: 0.002 diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 837c0f003..4f3cd6ee3 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -215,7 +215,6 @@ def __init__( attention_dtype=torch.bfloat16, with_2d_rope=False, is_dit=False, - dit_is_cond=False, ): super(MultiSelfAttentionHeadLocal, self).__init__() @@ -235,13 +234,10 @@ def __init__( norm = RMSNorm self.is_dit = is_dit - self.dit_is_cond = dit_is_cond - if is_dit: - if dit_is_cond: - assert dim_aux is not None, "For DIT, need to provide dim_aux for ada layer norm" + assert dim_aux is None, "conditioning not yet implemented for DIT attention" assert with_residual, "DIT attention should always have residual connection" - self.lnorm = AdaLNZero(dim_embed, dim_aux, norm_eps=norm_eps) if dim_aux is not None else norm(dim_embed, eps=norm_eps) + self.lnorm = norm(dim_embed, eps=norm_eps) self.noise_conditioning = LinearNormConditioning( latent_space_dim=dim_embed, dtype=attention_dtype ) @@ -281,13 +277,8 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None): # Handle ada_ln_aux conditioning if self.is_dit: - if self.dit_is_cond: - x, cond_gate = self.lnorm(x, ada_ln_aux) - else: - x = self.lnorm(x) - cond_gate = 1 - x, noise_gate = self.noise_conditioning(x, emb) - gate = cond_gate * noise_gate + x = self.lnorm(x) + x, gate = self.noise_conditioning(x, emb) else: x = self.lnorm(x, ada_ln_aux) if ada_ln_aux is not None else self.lnorm(x) @@ -539,7 +530,6 @@ def __init__( attention_dtype=torch.bfloat16, with_2d_rope=False, is_dit = False, # should only be True for diffusion model - dit_is_cond = False, # whether the attention is used for conditioning in the diffusion model (as opposed to denoising). Should only be True for cross attention layers in the diffusion model, and will control whether ada_ln_aux is applied to the input or output of the attention layer ): super(MultiSelfAttentionHead, self).__init__() @@ -560,13 +550,11 @@ def __init__( norm = RMSNorm self.is_dit = is_dit - self.dit_is_cond = dit_is_cond if is_dit: - if dit_is_cond: - assert dim_aux is not None, "For DIT, need to provide dim_aux for ada layer norm" + assert dim_aux is None, "conditioning not yet implemented for DIT attention" assert with_residual, "DIT attention should always have residual connection" - self.lnorm = AdaLNZero(dim_embed, dim_aux, norm_eps=norm_eps) if dim_aux is not None else norm(dim_embed, eps=norm_eps) + self.lnorm = norm(dim_embed, eps=norm_eps) self.noise_conditioning = LinearNormConditioning( latent_space_dim=dim_embed, dtype=attention_dtype ) #TODO: Do I need to pass dtype? @@ -600,13 +588,8 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None): # Handle ada_ln_aux conditioning if self.is_dit: - if self.dit_is_cond: - x, cond_gate = self.lnorm(x, ada_ln_aux) - else: - x = self.lnorm(x) - cond_gate = 1 - x, noise_gate = self.noise_conditioning(x, emb) - gate = cond_gate * noise_gate + x = self.lnorm(x) + x, gate = self.noise_conditioning(x, emb) else: x = self.lnorm(x, ada_ln_aux) if ada_ln_aux is not None else self.lnorm(x) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 5a7b34fe6..255bf0e23 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -423,7 +423,6 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = attention_dtype=get_dtype(self.cf.attention_dtype), with_2d_rope=self.cf.get("rope_2D", False), is_dit=self.cf.fe_diffusion_model, - dit_is_cond=self.cf.fe_diffusion_model_conditioning in ["date_time"], ) ) else: @@ -442,7 +441,6 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = attention_dtype=get_dtype(self.cf.attention_dtype), with_2d_rope=self.cf.get("rope_2D", False), is_dit=self.cf.fe_diffusion_model, - dit_is_cond=self.cf.fe_diffusion_model_conditioning in ["date_time"], ) ) # Add MLP block @@ -457,7 +455,6 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = dim_aux=dim_aux, norm_eps=self.cf.mlp_norm_eps, is_dit=self.cf.fe_diffusion_model, - dit_is_cond=self.cf.fe_diffusion_model_conditioning in ["date_time"], ) ) # Optionally, add LayerNorm after i-th layer diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index 0ec0bdf38..51f24548a 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -61,7 +61,6 @@ def __init__( norm_eps=1e-5, name: str | None = None, is_dit=False, - dit_is_cond=False, ): """Constructor""" @@ -75,7 +74,6 @@ def __init__( self.with_residual = with_residual self.with_aux = dim_aux is not None self.is_dit = is_dit - self.dit_is_cond = dit_is_cond dim_hidden = int(dim_in * hidden_factor) self.layers = torch.nn.ModuleList() @@ -83,10 +81,9 @@ def __init__( norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm if is_dit: - if dit_is_cond: - assert dim_aux is not None, "For DIT, need to provide dim_aux for ada layer norm" + assert dim_aux is None, "conditioning not yet implemented for DIT attention" assert with_residual, "DIT attention should always have residual connection" - self.lnorm = AdaLNZero(dim_in, dim_aux, norm_eps=norm_eps) if dim_aux is not None else norm(dim_in, eps=norm_eps) + self.lnorm = norm(dim_in, eps=norm_eps) self.noise_conditioning = LinearNormConditioning(dim_in) elif dim_aux is not None: self.lnorm = AdaLayerNorm(dim_in, dim_aux, norm_eps=norm_eps) @@ -121,19 +118,10 @@ def forward(self, *args): noise_emb = args[2] if self.is_dit else None if self.is_dit: - if self.dit_is_cond: - assert ada_ln_aux is not None, "Need auxiliary input for conditional DIT" - x, cond_gate = self.lnorm(x, ada_ln_aux) - else: - x = self.lnorm(x) - cond_gate = 1 + assert ada_ln_aux is None, "conditioning not yet implemented for DIT attention" + x = self.lnorm(x) assert noise_emb is not None, "Need noise embedding for noise conditioning in DIT" - x, noise_gate = self.noise_conditioning(x, noise_emb) - gate = cond_gate * noise_gate - # elif self.dim_aux is not None: - # x = self.lnorm(x, ada_ln_aux) - # else: - # x = self.lnorm(x, ada_ln_aux) if ada_ln_aux is not None else self.lnorm(x) + x, gate = self.noise_conditioning(x, noise_emb) for layer in self.layers: if isinstance(layer, AdaLayerNorm): From bc243e545396a1040090c8bfe463550a6328563d Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 22 Apr 2026 18:18:56 +0200 Subject: [PATCH 295/344] remove more conditioning --- .../datasets/multi_stream_data_sampler.py | 25 ------------------- src/weathergen/model/diffusion.py | 10 +------- src/weathergen/model/engines.py | 9 ++----- src/weathergen/model/layers.py | 21 ++++++++-------- src/weathergen/model/model.py | 8 +----- 5 files changed, 14 insertions(+), 59 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 0bf531c3b..a0fe7df90 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -92,8 +92,6 @@ def __init__( self.rank = cf.rank self.world_size = cf.world_size - self.diffusion_model_conditioning = cf.fe_diffusion_model_conditioning - self.healpix_level: int = cf.healpix_level self.num_healpix_cells: int = 12 * 4**self.healpix_level @@ -710,12 +708,6 @@ def _get_batch(self, idx: int, num_forecast_steps: int): ) target_metadata = target_masks.metadata[tidx] - # Get first target step's times (using self.output_offset as the first output step index) - if self.diffusion_model_conditioning == "date_time": - target_times_array = sdata.target_times_raw[self.output_offset] - target_metadata.add_params({'timestamp': ( - target_times_array[0] if len(target_times_array) > 0 else None - )}) # also want to add the mask to the metadata target_metadata.mask = target_mask # Map target to all source students @@ -729,23 +721,6 @@ def _get_batch(self, idx: int, num_forecast_steps: int): target_in_steps = 1 if len(target_in_steps) == 0 else target_in_steps.max().item() batch = self._preprocess_model_batch(batch, source_in_steps, target_in_steps) - #add target times in source for diffusion model date/time conditioning - if self.diffusion_model_conditioning == "date_time": - #TODO: Might need upgrading fro num_samples > 1 - - # Assert singular source and target samples - assert len(batch.source_samples.samples) == 1, "Only single source sample supported for diffusion model conditioning." - assert len(batch.target_samples.samples) == 1, "Only single target sample supported for diffusion model conditioning." - - source_sample = batch.source_samples.samples[0] - target_sample = batch.target_samples.samples[0] - - # Copy target timestamps to source metadata for all streams - for stream_name in [s["name"] for s in self.streams]: - if stream_name in target_sample.meta_info and stream_name in source_sample.meta_info: - target_timestamp = target_sample.meta_info[stream_name].params.get('timestamp') - source_sample.meta_info[stream_name].add_params({'timestamp': target_timestamp}) - return batch def __iter__(self) -> ModelBatch: diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 216d142b1..f9239f6ed 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -151,10 +151,7 @@ def training_forward( self.cur_token = tokens.detach() - if self.cf.fe_diffusion_model_conditioning == "date_time": - c = meta_info["ERA5"].params["timestamp"] # TODO: add correct preconditioning (e.g., sample/s in previous time step, datetime encoding, etc.) - else: - c = None + c = None y = tokens @@ -192,8 +189,6 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int, co # Precondition input and feed through network x = self.preconditioner.precondition(x, c) #currently does nothing - if self.cf.fe_diffusion_model_conditioning == "date_time": - c = self.datetime_embedder(c).to(x.device) return c_skip * x + c_out * self.net( c_in * x, fstep=fstep, coords=coords, noise_emb=noise_emb, ada_ln_aux=c @@ -225,9 +220,6 @@ def inference_forward( # Extract conditioning from meta_info (same as training_forward) c = None - if self.cf.fe_diffusion_model_conditioning == "date_time": - c = meta_info["ERA5"].params["timestamp"] - # Sample pure noise (assuming single batch element for now) # torch.manual_seed(42) x = torch.randn(1, self.num_healpix_cells, self.cf.ae_global_dim_embed).to(device="cuda") diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 255bf0e23..fe613bb11 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -500,13 +500,8 @@ def forward( if isinstance(block, torch.nn.LayerNorm): tokens = checkpoint(block, tokens, use_reentrant=False) else: - if self.cf.fe_diffusion_model_conditioning in ["date_time"]: - # Assuming ada_ln_aux contains the date_time embedding in this case - assert ada_ln_aux is not None, "ada_ln_aux must be provided for diffusion model conditioning" - tokens = checkpoint(block, tokens, coords, noise_emb, ada_ln_aux, use_reentrant=False) - else: - assert ada_ln_aux is None, "ada_ln_aux should not be provided when diffusion model conditioning is disabled" - tokens = checkpoint(block, tokens, coords, noise_emb, use_reentrant=False) + assert ada_ln_aux is None, "ada_ln_aux should not be provided when diffusion model conditioning is disabled" + tokens = checkpoint(block, tokens, coords, noise_emb, use_reentrant=False) else: for block in self.fe_blocks: if isinstance(block, torch.nn.LayerNorm): diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index 51f24548a..bfcb08679 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -108,17 +108,16 @@ def __init__( # TODO: expanded args, must check dependencies (previously aux = args[-1]) def forward(self, *args): x, x_in = args[0], args[0] - if len(args) < 2 and self.with_aux: - raise ValueError("Auxiliary input required but not provided") - if len(args) == 2: - ada_ln_aux = args[1] - elif len(args) > 2: - ada_ln_aux = args[-1] - noise_emb = args[2] if self.is_dit else None - noise_emb = args[2] if self.is_dit else None - - if self.is_dit: - assert ada_ln_aux is None, "conditioning not yet implemented for DIT attention" + if not self.is_dit: + if len(args) < 2 and self.with_aux: + raise ValueError("Auxiliary input required but not provided") + if len(args) == 2: + ada_ln_aux = args[1] + elif len(args) > 2: + ada_ln_aux = args[-1] + else: + assert len(args) == 3, "DIT gets 3 args (no conditioning implemented yet)" + noise_emb = args[-1] x = self.lnorm(x) assert noise_emb is not None, "Need noise embedding for noise conditioning in DIT" x, gate = self.noise_conditioning(x, noise_emb) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 5c847c751..133b7b04a 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -402,13 +402,7 @@ def create(self) -> "Model": mode_cfg = cf.training_config self.forecast_engine = None if cf.fe_num_blocks > 0: - if cf.get("fe_diffusion_model_conditioning", None) in ["date_time"]: - assert cf.diffusion_conditioning_embed_dim is not None, ( - "Diffusion conditioning embedding dimension must be specified when using diffusion model conditioning" - ) - self.forecast_engine = ForecastingEngine(cf, mode_cfg, self.num_healpix_cells, dim_aux=self.cf.diffusion_conditioning_embed_dim) - else: - self.forecast_engine = ForecastingEngine(cf, mode_cfg, self.num_healpix_cells) + self.forecast_engine = ForecastingEngine(cf, mode_cfg, self.num_healpix_cells) if cf.get("fe_diffusion_model", False): self.forecast_engine = DiffusionForecastEngine( cf, self.num_healpix_cells, forecast_engine=self.forecast_engine From 092079d699d715b83004f5f5c0315db6ff8a1669 Mon Sep 17 00:00:00 2001 From: Moritz Hauschulz <60788263+moritzhauschulz@users.noreply.github.com> Date: Fri, 24 Apr 2026 15:24:16 +0200 Subject: [PATCH 296/344] Mh/diffusion era5 uncond (#2257) * choosing btw train/inference * config update --------- Co-authored-by: Kerem Tezcan --- config/config_diffusion.yml | 2 +- src/weathergen/model/diffusion.py | 39 +++++++++++++++++-------------- src/weathergen/run_train.py | 3 +++ 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index c9449334a..651f619f2 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -70,7 +70,7 @@ frequency_embedding_dim: 256 embedding_dim: 512 sigma_min: 0.002 sigma_max: 80 -sigma_data: 1 +sigma_data: 0.63 rho: 7 p_mean: -1.2 p_std: 1.2 diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index f9239f6ed..e1634763c 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -99,6 +99,7 @@ def forward( Raises: ValueError: If required arguments are missing for current mode """ + # called during training in training mode if self.training: if tokens is None or fstep is None or meta_info is None: raise ValueError( @@ -112,23 +113,27 @@ def forward( coords=coords, ) else: - # NOTE: temporary for analysing denoising - return self.training_forward( - tokens=tokens, - fstep=fstep, - meta_info=meta_info, - coords=coords, - ) - - if fstep is None: - raise ValueError(f"During inference, fstep is required. Got fstep={fstep}") - - return self.inference_forward( - fstep=fstep, - num_steps=num_steps, - meta_info=meta_info, - coords=coords, - ) + # called in evaluation mode : + # decide btw pure noise generation (inference) vs denoising a sample for + # evaluation (train) using the stage variable + if self.cf.stage == 'train' or self.cf.stage == 'train_continue': + # NOTE: temporary for analysing denoising + return self.training_forward( + tokens=tokens, + fstep=fstep, + meta_info=meta_info, + coords=coords, + ) + elif self.cf.stage == 'inference': + if fstep is None: + raise ValueError(f"During inference, fstep is required. Got fstep={fstep}") + + return self.inference_forward( + fstep=fstep, + num_steps=num_steps, + meta_info=meta_info, + coords=coords, + ) def training_forward( self, diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index 403be1235..e7703925f 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -96,6 +96,7 @@ def run_inference(args): cli_overwrite, ) cf = config.set_run_id(cf, args.run_id, args.reuse_run_id) + cf.stage = args.stage devices = Trainer.init_torch() cf = Trainer.init_ddp(cf) @@ -134,6 +135,7 @@ def run_continue(args): cli_overwrite, ) cf = config.set_run_id(cf, args.run_id, args.reuse_run_id) + cf.stage = args.stage mp_method = cf.general.get("multiprocessing_method", "fork") devices = Trainer.init_torch(multiprocessing_method=mp_method) @@ -168,6 +170,7 @@ def run_train(args): ) cf = config.set_run_id(cf, args.run_id, False) cf.data_loading.rng_seed = int(time.time()) + cf.stage = args.stage mp_method = cf.general.get("multiprocessing_method", "fork") devices = Trainer.init_torch(multiprocessing_method=mp_method) cf = Trainer.init_ddp(cf) From efaadabad4f711fbe57847a90c7079736278e569 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Fri, 24 Apr 2026 15:55:03 +0200 Subject: [PATCH 297/344] adding validation pass of random noise --- src/weathergen/train/trainer.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 8cf564090..3228106e5 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -556,6 +556,9 @@ def validate(self, mini_epoch, mode_cfg, batch_size): noise_levels = list(mode_cfg.get("validation_noise_levels", [0.0])) if not is_diffusion: noise_levels = [0.0] + else: + # Always include a pass without fixed noise level (random sampling) + noise_levels = [None] + noise_levels # Accumulate losses across noise levels with suffixed keys so they are # logged as a single "val" entry (e.g. LossLatentDiff.LossLatentDiff.mse.eta0.03) @@ -566,11 +569,15 @@ def validate(self, mini_epoch, mode_cfg, batch_size): if is_diffusion: self._set_validation_noise_level(noise_level) - _d = Decimal(str(noise_level)).normalize() - _sign, _digits, _exp = _d.as_tuple() - eta_str = f"{'-' if _sign else ''}{''.join(map(str, _digits))}e{_exp}" - loss_suffix = f".eta{eta_str}" if len(noise_levels) > 1 else "" - stage_suffix = f"_eta{eta_str}" if len(noise_levels) > 1 else "" + if noise_level is None: + loss_suffix = "" + stage_suffix = "" + else: + _d = Decimal(str(noise_level)).normalize() + _sign, _digits, _exp = _d.as_tuple() + eta_str = f"{'-' if _sign else ''}{''.join(map(str, _digits))}e{_exp}" + loss_suffix = f".eta{eta_str}" if len(noise_levels) > 1 else "" + stage_suffix = f"_eta{eta_str}" if len(noise_levels) > 1 else "" dataset_val_iter = iter(self.data_loader_validation) num_samples_write = mode_cfg.get("output", {}).get("num_samples", 0) * batch_size From 30461eaa23fa06fb06714cfa26cf749704a86bca Mon Sep 17 00:00:00 2001 From: Jubeku Date: Tue, 28 Apr 2026 15:55:49 +0200 Subject: [PATCH 298/344] rm plotting during validation, rm decoding of noised tokens, lint --- src/weathergen/datasets/batch.py | 1 + src/weathergen/datasets/masking.py | 6 - src/weathergen/model/attention.py | 14 +- src/weathergen/model/diffusion.py | 94 ++++---- src/weathergen/model/engines.py | 8 +- src/weathergen/model/layers.py | 6 +- src/weathergen/model/model.py | 65 +----- src/weathergen/model/norms.py | 23 +- src/weathergen/train/target_and_aux_utils.py | 14 +- src/weathergen/train/trainer.py | 11 +- src/weathergen/utils/validation_io.py | 230 +------------------ 11 files changed, 113 insertions(+), 359 deletions(-) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index 244c0235f..22547ee92 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -35,6 +35,7 @@ def add_params(self, params: dict) -> None: self.params = {} self.params.update(params) + class Sample: # keys: stream name, values: SampleMetaData meta_info: dict[str | SampleMetaData] diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index 68a5a58a7..9c20d010b 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -28,16 +28,12 @@ def __len__(self): return len(self.masks) def add_mask(self, mask, params, cfg, losses, idx, correspondence, relationship): - # TODO: REVERT TO ORIGINAL CODE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. - # If noise_level_rn exists in params, also add it to global_params for easier downstream access global_params = { "idx": idx, "correspondence": correspondence, "loss": losses, "relationship": relationship, } - if "noise_level_rn" in params: - global_params["noise_level_rn"] = params["noise_level_rn"] self.masks += [mask] self.metadata += [ SampleMetaData( @@ -562,8 +558,6 @@ def _generate_cell_mask( if "diffusion_rn" in masking_strategy_config: masking_params["noise_level_rn"] = self.rng.normal(0.0, 1.0) - - elif strategy == "healpix": # prepare healpix-based masking keep_rate = self._get_sampling_rate(masking_strategy_config) diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 02379f039..990853d63 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -14,7 +14,7 @@ from torch.nn.attention.flex_attention import create_block_mask, flex_attention from weathergen.model.layers import LinearNormConditioning -from weathergen.model.norms import AdaLayerNorm, AdaLNZero, RMSNorm +from weathergen.model.norms import AdaLayerNorm, RMSNorm from weathergen.model.positional_encoding import rotary_pos_emb_2d """ @@ -258,7 +258,7 @@ def __init__( self.lnorm = AdaLayerNorm(dim_embed, dim_aux, norm_eps=norm_eps) else: self.lnorm = norm(dim_embed, eps=norm_eps) - + self.proj_heads_q = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False) self.proj_heads_k = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False) self.proj_heads_v = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False) @@ -292,7 +292,7 @@ def mask_block_local(batch, head, idx_q, idx_kv): def forward(self, x, coords=None, emb=None, ada_ln_aux=None): if self.with_residual: x_in = x - + # Handle ada_ln_aux conditioning if self.is_dit: x = self.lnorm(x) @@ -560,7 +560,7 @@ def __init__( norm_eps=1e-5, attention_dtype=torch.bfloat16, with_2d_rope=False, - is_dit = False, # should only be True for diffusion model + is_dit=False, # should only be True for diffusion model ): super(MultiSelfAttentionHead, self).__init__() @@ -574,7 +574,6 @@ def __init__( assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj - if norm_type == "LayerNorm": norm = partial(torch.nn.LayerNorm, elementwise_affine=False, eps=norm_eps) else: @@ -588,7 +587,7 @@ def __init__( self.lnorm = norm(dim_embed, eps=norm_eps) self.noise_conditioning = LinearNormConditioning( latent_space_dim=dim_embed, dtype=attention_dtype - ) #TODO: Do I need to pass dtype? + ) # TODO: Do I need to pass dtype? elif dim_aux is not None: self.lnorm = AdaLayerNorm(dim_embed, dim_aux, norm_eps=norm_eps) else: @@ -621,7 +620,7 @@ def __init__( def forward(self, x, coords=None, emb=None, ada_ln_aux=None): if self.with_residual: x_in = x - + # Handle ada_ln_aux conditioning if self.is_dit: x = self.lnorm(x) @@ -629,7 +628,6 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None): else: x = self.lnorm(x, ada_ln_aux) if ada_ln_aux is not None else self.lnorm(x) - # project onto heads and q,k,v and # ensure these are 4D tensors as required for flash attention s = [*([x.shape[0], 1] if len(x.shape) == 2 else x.shape[:-1]), self.num_heads, -1] diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index e1634763c..38bc6210f 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -25,8 +25,8 @@ import logging import math -import numpy as np +import numpy as np import torch from weathergen.common.config import Config, get_path_run @@ -116,7 +116,7 @@ def forward( # called in evaluation mode : # decide btw pure noise generation (inference) vs denoising a sample for # evaluation (train) using the stage variable - if self.cf.stage == 'train' or self.cf.stage == 'train_continue': + if self.cf.stage == "train" or self.cf.stage == "train_continue": # NOTE: temporary for analysing denoising return self.training_forward( tokens=tokens, @@ -124,10 +124,10 @@ def forward( meta_info=meta_info, coords=coords, ) - elif self.cf.stage == 'inference': + elif self.cf.stage == "inference": if fstep is None: raise ValueError(f"During inference, fstep is required. Got fstep={fstep}") - + return self.inference_forward( fstep=fstep, num_steps=num_steps, @@ -173,11 +173,16 @@ def training_forward( self._noised_tokens = (y + n).detach() - - return self.denoise(x=y + n, c=c, sigma=sigma, fstep=fstep, coords=coords) - def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int, coords: torch.Tensor = None) -> torch.Tensor: + def denoise( + self, + x: torch.Tensor, + c: torch.Tensor, + sigma: float, + fstep: int, + coords: torch.Tensor = None, + ) -> torch.Tensor: """ The actual diffusion step, where the model removes noise from the input x under consideration of a conditioning c (e.g., previous time steps) and the current diffusion @@ -193,8 +198,8 @@ def denoise(self, x: torch.Tensor, c: torch.Tensor, sigma: float, fstep: int, co noise_emb = self.noise_embedder(c_noise) # Precondition input and feed through network - x = self.preconditioner.precondition(x, c) #currently does nothing - + x = self.preconditioner.precondition(x, c) # currently does nothing + return c_skip * x + c_out * self.net( c_in * x, fstep=fstep, coords=coords, noise_emb=noise_emb, ada_ln_aux=c ) # Eq. (7) in EDM paper @@ -208,11 +213,11 @@ def inference_forward( ) -> torch.Tensor: """ Forward pass of the diffusion model during inference. - + Iteratively denoises a random sample using the learned score function, with optional temporal conditioning extracted from meta_info. https://github.com/NVlabs/edm/blob/main/generate.py - + Args: fstep: Forecast step index for the network num_steps: Number of diffusion denoising steps (default: 30) @@ -272,18 +277,20 @@ def inference_forward( / (num_steps - 1) * (sigma_min_eff ** (1 / self.rho) - sigma_max_eff ** (1 / self.rho)) ) ** self.rho - t_steps = torch.cat( - [t_steps, torch.zeros_like(t_steps[:1])] - ) # t_N = 0 + t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 # t_steps = torch.cat( # [self.net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])] # ) # t_N = 0 # --- Per-step tracking for diagnostics --- track = { - "sigma": [], "x_std": [], "denoised_std": [], - "l2_to_target": [], "cosine_to_target": [], - "c_skip": [], "x": [x.cpu()] + "sigma": [], + "x_std": [], + "denoised_std": [], + "l2_to_target": [], + "cosine_to_target": [], + "c_skip": [], + "x": [x.cpu()], } # Main sampling loop. @@ -332,9 +339,9 @@ def inference_forward( def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: """Save a diagnostic plot of the sampling trajectory.""" import matplotlib + matplotlib.use("Agg") import matplotlib.pyplot as plt - import matplotlib.colors as mcolors steps = list(range(len(track["sigma"]))) has_target = len(track["l2_to_target"]) > 0 @@ -349,7 +356,9 @@ def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: f"Sampling diagnostics | sigma_max_eff={track['sigma'][0]:.2f}, " f"sigma_data={self.sigma_data}, steps={num_steps}" ) - axes[0].axhline(self.sigma_data, color="grey", ls="--", lw=0.8, label=f"sigma_data={self.sigma_data}") + axes[0].axhline( + self.sigma_data, color="grey", ls="--", lw=0.8, label=f"sigma_data={self.sigma_data}" + ) axes[0].legend(fontsize=8) axes[0].grid(True, alpha=0.3) @@ -358,7 +367,9 @@ def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: axes[1].plot(steps, track["denoised_std"], "s-", markersize=3, label="denoised estimate") if self.cur_token is not None: target_std = self.cur_token.std().item() - axes[1].axhline(target_std, color="grey", ls="--", lw=0.8, label=f"target std={target_std:.3f}") + axes[1].axhline( + target_std, color="grey", ls="--", lw=0.8, label=f"target std={target_std:.3f}" + ) axes[1].set_ylabel("std") axes[1].legend(fontsize=8) axes[1].grid(True, alpha=0.3) @@ -418,7 +429,7 @@ def timestep_embedding(self, t: float, max_period: int = 10000): # Ensure t is 1D if t.ndim == 0: t = t.view(1) - + half = self.frequency_embedding_dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=self.dtype) / half @@ -434,16 +445,17 @@ def forward(self, t: float): t_emb = self.mlp(t_freq) return t_emb + class DateTimeEncoder(torch.nn.Module): """ Encodes timestamp(s) into multi-frequency sinusoidal calendar embeddings. - + Inspired by cBottle (Climate in a Bottle) with k=1..8 frequency scales. Captures seasonal (day-of-year) and diurnal (time-of-day) cycles at multiple timescales. - + Input shape: scalar or any tensor shape (...) Output shape: (..., 32) — 8 frequencies × 4 components (cos/sin per signal) - + Output structure for k=1..8: [cos(2πk·doy/365.25), sin(2πk·doy/365.25), cos(k·t), sin(k·t)] where: @@ -461,7 +473,7 @@ def forward(self, timestamp: np.ndarray) -> torch.Tensor: Args: timestamp: np.datetime64 scalar or array of timestamps - + Returns: torch.Tensor of shape (..., 32) containing multi-frequency embeddings """ @@ -470,46 +482,48 @@ def forward(self, timestamp: np.ndarray) -> torch.Tensor: orig_shape = timestamp.shape timestamp_flat = timestamp.reshape(-1) - + two_pi = 2.0 * np.pi - + # --- Extract time components --- - ts_int64 = timestamp_flat.astype('int64') # seconds since Unix epoch + ts_int64 = timestamp_flat.astype("int64") # seconds since Unix epoch seconds_in_day = 86400.0 seconds_of_day = (ts_int64 % int(seconds_in_day)) / seconds_in_day # [0, 1) - + # --- Extract day of year --- - day_np = timestamp_flat.astype('datetime64[D]') - year_start = day_np.astype('datetime64[Y]').astype('datetime64[D]') - next_year_start = (day_np.astype('datetime64[Y]') + np.timedelta64(1, 'Y')).astype('datetime64[D]') - + day_np = timestamp_flat.astype("datetime64[D]") + year_start = day_np.astype("datetime64[Y]").astype("datetime64[D]") + next_year_start = (day_np.astype("datetime64[Y]") + np.timedelta64(1, "Y")).astype( + "datetime64[D]" + ) + day_of_year_0 = (day_np - year_start).astype(np.int64) # [0, 365] or [0, 366] days_in_year = (next_year_start - year_start).astype(np.int64) # 365 or 366 doy_frac = day_of_year_0.astype(np.float32) / days_in_year.astype(np.float32) # [0, 1) - + # --- Multi-frequency sinusoidal embeddings --- # Build output for all 8 frequency scales embeddings = [] for k in range(1, self.num_frequencies + 1): k_float = float(k) - + # Day-of-year components: cos(2π·k·doy/365.25), sin(2π·k·doy/365.25) doy_phase = two_pi * k_float * doy_frac doy_cos = np.cos(doy_phase).astype(np.float32) doy_sin = np.sin(doy_phase).astype(np.float32) - + # Time-of-day components: cos(k·t), sin(k·t) where t = 2π·seconds_of_day tot_phase = k_float * two_pi * seconds_of_day tot_cos = np.cos(tot_phase).astype(np.float32) tot_sin = np.sin(tot_phase).astype(np.float32) - + embeddings.append(doy_cos) embeddings.append(doy_sin) embeddings.append(tot_cos) embeddings.append(tot_sin) - + # Stack all components: (N, 32) out = np.stack(embeddings, axis=-1) out = torch.from_numpy(out).float() - - return out.reshape(*orig_shape, self.num_frequencies * 4) \ No newline at end of file + + return out.reshape(*orig_shape, self.num_frequencies * 4) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 195d3edf7..d6894fcc2 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -636,12 +636,16 @@ def forward( tokens_in = tokens if self.cf.fe_diffusion_model: - assert noise_emb is not None, "noise_emb must be provided for diffusion model conditioning" + assert noise_emb is not None, ( + "noise_emb must be provided for diffusion model conditioning" + ) for block in self.fe_blocks: if isinstance(block, torch.nn.LayerNorm): tokens = checkpoint(block, tokens, use_reentrant=False) else: - assert ada_ln_aux is None, "ada_ln_aux should not be provided when diffusion model conditioning is disabled" + assert ada_ln_aux is None, ( + "ada_ln_aux should not be provided when diffusion model conditioning is disabled" + ) tokens = checkpoint(block, tokens, coords, noise_emb, use_reentrant=False) else: for block in self.fe_blocks: diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index bfcb08679..62bc0af8d 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -28,7 +28,7 @@ import torch import torch.nn as nn -from weathergen.model.norms import AdaLNZero, AdaLayerNorm, RMSNorm +from weathergen.model.norms import AdaLayerNorm, RMSNorm class NamedLinear(torch.nn.Module): @@ -90,7 +90,7 @@ def __init__( else: self.lnorm = norm(dim_in, eps=norm_eps) - #TODO: The below should be consolidated – implementing in layer list for backward compatibility + # TODO: The below should be consolidated – implementing in layer list for backward compatibility if not is_dit: self.layers.append(self.lnorm) self.layers.append(torch.nn.Linear(dim_in, dim_hidden)) @@ -104,7 +104,6 @@ def __init__( self.layers.append(torch.nn.Linear(dim_hidden, dim_out)) - # TODO: expanded args, must check dependencies (previously aux = args[-1]) def forward(self, *args): x, x_in = args[0], args[0] @@ -172,4 +171,3 @@ def forward(self, inputs, noise_emb): return (inputs * scale + offset).to( self.dtype ), gate # TODO: check if to(self.dtype) needed here - diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 629b9a173..a03d5109a 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -54,29 +54,16 @@ class ModelOutput: physical: list[dict[StreamName, torch.Tensor]] latent: list[dict[str, torch.Tensor | LatentState]] - noised_physical: list[dict[StreamName, torch.Tensor]] def __init__(self, len_output: int) -> None: self.physical = [{} for _ in range(len_output)] self.latent = [{} for _ in range(len_output)] - self.noised_physical = [{} for _ in range(len_output)] def add_physical_prediction( self, fstep: int, stream_name: StreamName, pred: torch.Tensor ) -> None: self.physical[fstep][stream_name] = pred - def add_noised_physical_prediction( - self, fstep: int, stream_name: StreamName, pred: torch.Tensor - ) -> None: - self.noised_physical[fstep][stream_name] = pred - - def get_noised_physical_prediction(self, fstep: int, stream_name: StreamName | None = None): - pred = self.noised_physical[fstep] - if stream_name is not None: - pred = pred.get(stream_name, None) - return pred - def add_latent_prediction(self, fstep: int, latent_name: str, pred: torch.Tensor) -> None: self.latent[fstep][latent_name] = pred @@ -326,20 +313,8 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord """ super(Model, self).__init__() - - - - - - self._noise = None - - - - - - self.healpix_level = cf.healpix_level self.num_healpix_cells = 12 * 4**self.healpix_level @@ -411,7 +386,7 @@ def create(self) -> "Model": if cf.fe_num_blocks > 0: self.forecast_engine = ForecastingEngine(cf, mode_cfg, self.num_healpix_cells) if cf.get("fe_diffusion_model", False): - self.forecast_engine = DiffusionForecastEngine( + self.forecast_engine = DiffusionForecastEngine( cf, self.num_healpix_cells, forecast_engine=self.forecast_engine ) @@ -744,21 +719,10 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # collapse along input step dimension tokens = tokens.reshape(shape).sum(axis=1) - # print(tokens.std().item()) - # breakpoint() - - # Normalize tokens - # TODO: REMOVE THIS LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. - # t_mean = tokens.mean() - # t_std = tokens.std() - # tokens = (tokens - t_mean) / (t_std + 1e-6) * self.cf.sigma_data - # tokens = torch.clamp(tokens, -100.0, 100.0) - # roll-out in latent space, iterate and generate output over requested output steps for step in batch.get_output_idxs(): # apply forecasting engine (if present) if self.forecast_engine: - tokens = self.forecast_engine( tokens, step, @@ -766,28 +730,9 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: coords=model_params.rope_coords, ) - # Un-normalize tokens - # TODO: REMOVE THIS AS ABOVE. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. - # tokens = tokens * (t_std + 1e-6) / self.cf.sigma_data + t_mean - # decoder predictions output = self.predict_decoders(model_params, step, tokens, batch, output) - # decode noised tokens for visualization (diffusion models only, eval mode) - if ( - not self.training - and isinstance(self.forecast_engine, DiffusionForecastEngine) - and self.forecast_engine._noised_tokens is not None - ): - output = self.predict_decoders( - model_params, - step, - self.forecast_engine._noised_tokens, - batch, - output, - noised=True, - ) - # latent predictions (raw and with SSL heads) output = self.predict_latent(model_params, step, tokens, batch, output) @@ -823,7 +768,6 @@ def predict_decoders( tokens: torch.Tensor, batch: ModelBatch, output: ModelOutput, - noised: bool = False, ) -> ModelOutput: """ Compute decoder-based predictions @@ -924,10 +868,7 @@ def predict_decoders( # recover batch dimension (ragged, so as list) pred = torch.split(pred, t_coords_lens, dim=1) - # breakpoint() - if noised: - output.add_noised_physical_prediction(step, stream_name, pred) - else: - output.add_physical_prediction(step, stream_name, pred) + + output.add_physical_prediction(step, stream_name, pred) return output diff --git a/src/weathergen/model/norms.py b/src/weathergen/model/norms.py index 4fb6484a2..1d394057c 100644 --- a/src/weathergen/model/norms.py +++ b/src/weathergen/model/norms.py @@ -69,36 +69,34 @@ def __init__(self, dim_embed: int, dim_aux: int, norm_eps: float = 1e-5): self.norm = torch.nn.LayerNorm(dim_embed, elementwise_affine=False, eps=norm_eps) self.gate_proj = torch.nn.Linear(dim_aux, dim_embed) self.shift_proj = torch.nn.Linear(dim_aux, dim_embed) - + with torch.no_grad(): self.gate_proj.weight.zero_() self.gate_proj.bias.zero_() self.shift_proj.weight.zero_() self.shift_proj.bias.zero_() - def forward( - self, x: torch.Tensor, aux: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: + def forward(self, x: torch.Tensor, aux: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Returns (x_normalized_and_scaled, gate_signal) for residual modulation.""" x_norm = self.norm(x) - + if aux.dim() == 0: aux = aux.unsqueeze(0) - + gate_params = self.gate_proj(aux) shift_params = self.shift_proj(aux) - + while gate_params.dim() < x_norm.dim(): gate_params = gate_params.unsqueeze(-2) shift_params = shift_params.unsqueeze(-2) - + gate = 1 + gate_params x_out = gate * x_norm + shift_params gate_signal = gate.mean(dim=-1, keepdim=True) - + return x_out, gate_signal - + class AdaLayerNorm(torch.nn.Module): """ AdaLayerNorm for embedding auxiliary information @@ -126,9 +124,11 @@ def forward(self, x: torch.Tensor, aux: torch.Tensor | None = None) -> torch.Ten return x + def modulate(x, shift, scale): return x * (1 + scale) + shift + class SwiGLU(nn.Module): def __init__(self): super(SwiGLU, self).__init__() @@ -137,7 +137,7 @@ def forward(self, x): x1, x2 = x.chunk(2, dim=-1) return x2 * F.silu(x1) - + class AdaLayerNormLayer(torch.nn.Module): """ AdaLayerNorm for embedding auxiliary information as done in DiT (Peebles & Xie) with zero @@ -194,6 +194,7 @@ def forward(self, x: torch.Tensor, c: torch.Tensor, x_lens, **kwargs) -> torch.T + x ) + class SaturateEncodings(nn.Module): """A common alternative to a KL regularisation prevent outliers in the latent space when learning an auto-encoder for latent generative model, an example value for the scale factor is 5 diff --git a/src/weathergen/train/target_and_aux_utils.py b/src/weathergen/train/target_and_aux_utils.py index 3e26c4ab9..47f9551db 100644 --- a/src/weathergen/train/target_and_aux_utils.py +++ b/src/weathergen/train/target_and_aux_utils.py @@ -4,8 +4,8 @@ from weathergen.common.config import Config, merge_configs from weathergen.model.ema import EMAModel from weathergen.model.model_interface import init_model_and_shard -from weathergen.train.target_and_aux_module_base import PhysicalTargetAndAux from weathergen.train.target_and_aux_diffusion import DiffusionLatentTargetEncoder +from weathergen.train.target_and_aux_module_base import PhysicalTargetAndAux from weathergen.train.target_and_aux_ssl_teacher import EMATeacher, FrozenTeacher from weathergen.train.teacher_utils import load_encoder_from_checkpoint, prepare_encoder_teacher @@ -48,8 +48,14 @@ def get_target_aux_calculator( overrides=target_and_aux_calc_params.get("model_param_overrides", {}), ) # Free components not needed by DiffusionLatentTargetEncoder (only uses the encoder) - for attr in ("forecast_engine", "pred_heads", "target_token_engines", - "embed_target_coords", "latent_heads", "latent_pre_norm"): + for attr in ( + "forecast_engine", + "pred_heads", + "target_token_engines", + "embed_target_coords", + "latent_heads", + "latent_pre_norm", + ): if hasattr(model, attr) and getattr(model, attr) is not None: delattr(model, attr) setattr(model, attr, None) @@ -102,7 +108,7 @@ def get_target_aux_calculator( elif target_and_aux_calc == "FrozenTeacher": target_aux = FrozenTeacher.from_pretrained(cf, dataset, device, target_and_aux_calc_params) - + else: raise NotImplementedError(f"{target_and_aux_calc} is not implemented") diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index dee993a7c..cb4a1e1ab 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -615,7 +615,10 @@ def validate(self, mini_epoch, mode_cfg, batch_size): ) targets_and_auxs = {} - for loss_name, target_aux in self.target_and_aux_calculators_val.items(): + for ( + loss_name, + target_aux, + ) in self.target_and_aux_calculators_val.items(): target_idxs = get_target_idxs_from_cfg(mode_cfg, loss_name) targets_and_auxs[loss_name] = target_aux.compute( self.cf.general.istep, @@ -649,8 +652,10 @@ def validate(self, mini_epoch, mode_cfg, batch_size): batch, preds, targets_and_auxs, - noise_level=noise_level if is_diffusion and len(noise_levels) > 1 else None, - write_zarr=False, #(noise_idx == 0), + noise_level=noise_level + if is_diffusion and len(noise_levels) > 1 + else None, + write_zarr=False, # (noise_idx == 0), ) pbar.update(batch_size) diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 30e7d219f..e83ab12d4 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -8,68 +8,28 @@ # nor does it submit to any jurisdiction. import logging -from math import exp -import re -import matplotlib.pyplot as plt -import cartopy.crs as ccrs import numpy as np import torch -import xarray as xr import weathergen.common.config as config import weathergen.common.io as io from weathergen.common.io import TimeRange, zarrio_writer from weathergen.datasets.data_reader_base import TimeWindowHandler -from weathergen.evaluate.plotting.plotter import Plotter _logger = logging.getLogger(__name__) -# TODO: REMOVE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. -i = 0 - - -# TODO: REMOVE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. -def _normalize_channel_name(name: str) -> str: - return str(name).lower().replace("_", "").replace(" ", "") - - -# TODO: REMOVE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. -def _resolve_channel_names(stream_info, raw_channels): - if not raw_channels: - return raw_channels - if isinstance(raw_channels[0], str): - return list(raw_channels) - - channel_names = None - if hasattr(stream_info, "val_target_channels") and stream_info.val_target_channels: - if isinstance(stream_info.val_target_channels[0], str): - channel_names = list(stream_info.val_target_channels) - - if channel_names is None: - target_weights = getattr(stream_info, "target_channel_weights", None) - if isinstance(target_weights, dict): - channel_names = list(target_weights.keys()) - - if channel_names is None: - channel_weights = getattr(stream_info, "channel_weights", None) - if isinstance(channel_weights, dict): - channel_names = list(channel_weights.keys()) - - if channel_names is None: - return [f"ch{idx}" for idx in raw_channels] - - resolved = [] - for idx in raw_channels: - if 0 <= int(idx) < len(channel_names): - resolved.append(channel_names[int(idx)]) - else: - resolved.append(f"ch{idx}") - return resolved - def write_output( - cf, val_cfg, batch_size, mini_epoch, batch_idx, dn_data, batch, model_output, target_aux_out, + cf, + val_cfg, + batch_size, + mini_epoch, + batch_idx, + dn_data, + batch, + model_output, + target_aux_out, noise_level=None, write_zarr=True, ): @@ -85,8 +45,6 @@ def write_output( Whether to write zarr output. Default True. Set to False to only generate plots without writing zarr data. """ - # TODO: REMOVE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. - global i # TODO: how to handle multiple physical loss terms outputs_physical = [ @@ -100,7 +58,6 @@ def write_output( # collect all target / prediction-related information fp32 = torch.float32 preds_all, targets_all, targets_coords_all, targets_times_all = [], [], [], [] - noised_preds_all = [] # decoded noised tokens (diffusion models only) timestep_idxs = [0] if len(batch.get_output_idxs()) == 0 else batch.get_output_idxs() forecast_offset = timestep_idxs[0] @@ -112,9 +69,8 @@ def write_output( targets_all += [[]] targets_coords_all += [[]] targets_times_all += [[]] - noised_preds_all += [[]] targets_lens += [[]] - for stream_idx, stream_info in enumerate(cf.streams): + for stream_info in cf.streams: sname = stream_info["name"] # handle spoof data: do not write since it might corrupt validation (spoofing invisible @@ -132,7 +88,7 @@ def write_output( else: preds = model_output.get_physical_prediction(t_idx, sname) targets = target_aux_out.physical[t_idx][sname]["target"] - + preds_s, targets_s, t_coords_s, t_times_s = [], [], [], [] # handle forcing streams or if sample is empty @@ -160,7 +116,6 @@ def write_output( # extract original target coords and times from target data t_coords_s += [t_coords.cpu().numpy()] t_times_s += [t_times.astype("datetime64[ns]")] - targets_lens[-1] += [[]] targets_lens[-1][-1] += [t.shape[0] for t in targets_s] @@ -170,19 +125,6 @@ def write_output( targets_coords_all[-1] += [np.concatenate(t_coords_s)] targets_times_all[-1] += [np.concatenate(t_times_s)] - # collect decoded noised tokens (diffusion models only) - noised_preds = model_output.get_noised_physical_prediction(t_idx, sname) - if noised_preds is not None: - noised_s = [] - for i_batch, npred in enumerate(noised_preds): - idxs_inv = target_aux_out.physical[t_idx][sname]["idxs_inv"][i_batch] - if idxs_inv is not None: - npred = npred[:, idxs_inv] - noised_s += [dn_data(sname, npred).detach().to(fp32).cpu().numpy()] - noised_preds_all[-1] += [np.concatenate(noised_s, axis=1)] - else: - noised_preds_all[-1] += [np.array([])] - if len(preds_all) == 0 or np.array([p.shape[1] for pp in preds_all for p in pp]).sum() == 0: _logger.warning("Writing no data since predictions are empty.") return @@ -254,157 +196,7 @@ def write_output( for subset in data.items(): zio.write_zarr(subset) - # Free arrays no longer needed after zarr writing del targets_all, targets_lens, sources, data - # TODO: REMOVE EVERYTHING BELOW THIS LINE LATER. ONLY FOR SINGLE-SAMPLE OVERFITTING EXPERIMENTS. - - # Prepare prediction data for Plotter (scatter plot expects lat/lon coords on ipoint). - base_plot_dir = config.get_path_run(cf) / "plots" / "validation" - base_plot_dir.mkdir(parents=True, exist_ok=True) - plotter = Plotter({"image_format": "png", "dpi_val": 150}, base_plot_dir) - # headline_channels = {"2t", "z500", "q850", "10u", "10v"} - # headline_channels = {"2t", "q850"} - # headline_channels = {"z500"} - headline_channels = {"2t", "z500"} - - t_idx = 0 - for stream_idx, stream_info in enumerate(cf.streams): - stream_name = stream_info["name"] - preds_stream = preds_all[t_idx][stream_idx] - noised_stream = noised_preds_all[t_idx][stream_idx] - coords_stream = targets_coords_all[t_idx][stream_idx] - times_stream = targets_times_all[t_idx][stream_idx] - - if preds_stream.size == 0 or coords_stream.size == 0: - _logger.warning(f"No prediction data to plot for stream {stream_name}.") - continue - - # Expected shape is (ens, ipoint, channel). Select first ensemble if present. - if preds_stream.ndim == 3: - preds_stream = preds_stream[0] - elif preds_stream.ndim != 2: - _logger.warning( - f"Unsupported prediction shape {preds_stream.shape} for stream {stream_name}." - ) - continue - - has_noised = ( - noised_stream.size > 0 and noised_stream.ndim >= 2 - ) - if has_noised and noised_stream.ndim == 3: - noised_stream = noised_stream[0] - - channels = _resolve_channel_names(stream_info, target_channels[stream_idx]) - selected_channels = [ - ch for ch in channels if _normalize_channel_name(ch) in headline_channels - ] - if not selected_channels: - _logger.warning(f"No headline channels available for plotting stream {stream_name}.") - continue - - ch_to_col = {ch: idx for idx, ch in enumerate(channels)} - - lat = coords_stream[:, 0] - lon = coords_stream[:, 1] - - run_id = config.get_run_id_from_config(cf) - num_samples = len(preds) - len_per_sample = preds_stream.shape[0] // num_samples - - for sample in range(num_samples): - s_start = sample * len_per_sample - s_end = (sample + 1) * len_per_sample - - # Extract sample date from target times - sample_times = times_stream[s_start:s_end] - sample_date = np.unique(sample_times) - if len(sample_date) > 0 and not np.isnat(sample_date[0]): - date_str = str(sample_date[0].astype("datetime64[h]")) - else: - date_str = "unknown date" - - for varname in selected_channels: - col = ch_to_col[varname] - pred_vals = preds_stream[s_start:s_end, col] - sample_lat = lat[s_start:s_end] - sample_lon = lon[s_start:s_end] - - # Drop NaN points (use pred mask for both panels) - valid = ~np.isnan(pred_vals) - pred_vals = pred_vals[valid] - plot_lat = sample_lat[valid] - plot_lon = sample_lon[valid] - - channel_dir = base_plot_dir / varname - channel_dir.mkdir(parents=True, exist_ok=True) - - eta_str = str(noise_level) if noise_level is not None else None - eta_tag = f"_eta{eta_str}" if eta_str is not None else "" - epoch_tag = f"epoch_{mini_epoch:03d}_{i % 3}{eta_tag}" - - # Determine number of panels - ncols = 2 if has_noised else 1 - proj = ccrs.Robinson() - fig, axes = plt.subplots( - 1, ncols, figsize=(8 * ncols, 5), - subplot_kw={"projection": proj}, dpi=150, - ) - if ncols == 1: - axes = [axes] - - # Shared color limits across panels - vmin, vmax = np.nanmin(pred_vals), np.nanmax(pred_vals) - - # Panel 1: noised (if available) - if has_noised: - noised_vals = noised_stream[s_start:s_end, col][valid] - vmin = min(vmin, np.nanmin(noised_vals)) - vmax = max(vmax, np.nanmax(noised_vals)) - ax_noised = axes[0] - ax_noised.coastlines() - ax_noised.set_global() - sc_n = ax_noised.scatter( - plot_lon, plot_lat, c=noised_vals, - vmin=vmin, vmax=vmax, cmap="coolwarm", - s=4.0, marker="o", transform=ccrs.PlateCarree(), linewidths=0.0, - ) - ax_noised.set_title("Noised", fontsize=10) - ax_denoised = axes[1] - else: - ax_denoised = axes[0] - - # Panel 2 (or only panel): denoised prediction - ax_denoised.coastlines() - ax_denoised.set_global() - sc_d = ax_denoised.scatter( - plot_lon, plot_lat, c=pred_vals, - vmin=vmin, vmax=vmax, cmap="coolwarm", - s=4.0, marker="o", transform=ccrs.PlateCarree(), linewidths=0.0, - ) - ax_denoised.set_title("Denoised", fontsize=10) - - # Shared colorbar - fig.colorbar(sc_d, ax=axes, orientation="horizontal", - label=varname, shrink=0.6, pad=0.05) - - # Suptitle with date - eta_info = f" | noise_level={eta_str}" if eta_str else "" - fig.suptitle( - f"{stream_name} - {varname} (fstep {forecast_offset})" - f" | sample {sample + 1} | {date_str}{eta_info}", - fontsize=11, - ) - - fname = channel_dir / f"{epoch_tag}_{sample}.{plotter.image_format}" - fig.savefig(fname, bbox_inches="tight") - plt.close(fig) - - del pred_vals, plot_lat, plot_lon, valid - - del preds_stream, coords_stream - del targets_times_all - - i += 1 \ No newline at end of file From 0ec723af97cdb64e2a8863c1fa84e5f00204d646 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Wed, 29 Apr 2026 13:17:50 +0200 Subject: [PATCH 299/344] only write first noise leveld during validation --- src/weathergen/train/trainer.py | 43 +++++++++++++-------------- src/weathergen/utils/validation_io.py | 18 ++--------- 2 files changed, 23 insertions(+), 38 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index cb4a1e1ab..0c7b404b8 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -634,29 +634,26 @@ def validate(self, mini_epoch, mode_cfg, batch_size): ) # log output - if bidx < num_samples_write: - # denormalization function for data - denormalize_data_fct = ( - (lambda x0, x1: x1) - if mode_cfg.get("output", {}).get("normalized_samples", False) - else self.dataset_val.denormalize_target_channels - ) - # write output (zarr only for first noise level, plots for all) - write_output( - self.cf, - mode_cfg, - batch_size, - mini_epoch, - bidx, - denormalize_data_fct, - batch, - preds, - targets_and_auxs, - noise_level=noise_level - if is_diffusion and len(noise_levels) > 1 - else None, - write_zarr=False, # (noise_idx == 0), - ) + if noise_idx == 0: + if bidx < num_samples_write: + # denormalization function for data + denormalize_data_fct = ( + (lambda x0, x1: x1) + if mode_cfg.get("output", {}).get("normalized_samples", False) + else self.dataset_val.denormalize_target_channels + ) + # write output (zarr only for first noise level, plots for all) + write_output( + self.cf, + mode_cfg, + batch_size, + mini_epoch, + bidx, + denormalize_data_fct, + batch, + preds, + targets_and_auxs, + ) pbar.update(batch_size) diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index e83ab12d4..ee86ca2ce 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -30,20 +30,9 @@ def write_output( batch, model_output, target_aux_out, - noise_level=None, - write_zarr=True, ): """ Interface for writing model output - - Parameters - ---------- - noise_level : float | None - Fixed diffusion noise level (eta) used for this validation pass. - When not None the value is embedded in plot filenames and titles. - write_zarr : bool - Whether to write zarr output. Default True. Set to False to only - generate plots without writing zarr data. """ # TODO: how to handle multiple physical loss terms @@ -191,10 +180,9 @@ def write_output( sample_start, forecast_offset, ) - if write_zarr: - with zarrio_writer(config.get_path_results(cf, mini_epoch)) as zio: - for subset in data.items(): - zio.write_zarr(subset) + with zarrio_writer(config.get_path_results(cf, mini_epoch)) as zio: + for subset in data.items(): + zio.write_zarr(subset) # Free arrays no longer needed after zarr writing del targets_all, targets_lens, sources, data From f6df91099c0d5cae5034854b39fa4fadc28cf847 Mon Sep 17 00:00:00 2001 From: Matthias Date: Wed, 29 Apr 2026 17:58:47 +0200 Subject: [PATCH 300/344] Configs for forecast model candidate --- config/config_eval.yml | 86 +++++++++++++++++++ config/config_forecasting.yml | 7 +- config/config_forecasting_finetuning.yml | 8 +- config/config_pipeline_forecasting.yml | 42 +++++++++ config/streams/era5_1deg_forecasting/era5.yml | 14 +-- 5 files changed, 145 insertions(+), 12 deletions(-) create mode 100644 config/config_eval.yml create mode 100644 config/config_pipeline_forecasting.yml diff --git a/config/config_eval.yml b/config/config_eval.yml new file mode 100644 index 000000000..fda60d039 --- /dev/null +++ b/config/config_eval.yml @@ -0,0 +1,86 @@ +#optional: if commented out all is taken care of by the default settings +# NB. global options apply to all run_ids +global_plotting_options: +# regions: ["europe", "global"] + image_format : "png" #options: "png", "pdf", "svg", "eps", "jpg" .. + dpi_val : 300 + fps: 2 + ERA5: + marker_size: 3 + scale_marker_size: 1 + marker: "o" + alpha: 0.5 + add_healpix_grid: false + healpix_nside: 4 + 2t: + vmin: 205 + vmax: 315 + 10u: + vmin: -23 + vmax: 23 + 10v: + vmin: -23 + vmax: 23 + z_500: + vmin: 4700 + vmax: 6000 + t_850: + vmin: 230 + vmax: 310 + u_850: + vmin: -40 + vmax: 40 + v_850: + vmin: -40 + vmax: 40 + q_850: + vmin: 0 + vmax: 0.02 + + +evaluation: + metrics: ["rmse", "froct", "acc"] + regions: ["global", "nhem"] + summary_plots : true + ratio_plots : false + heat_maps : false + summary_dir: "./plots/" + plot_ensemble: "members" #supported: false, "std", "minmax", "members" + plot_score_maps: false #plot scores on a 2D maps. it slows down score computation + print_summary: false #print out score values on screen. it can be verbose + log_scale: false + add_grid: true + score_cards: false + bar_plots: false + + +default_streams: + ERA5: + climatology_path: "/iopsstor/scratch/cscs/thunter/shared_work/assets/climatology/aifs-ea-an-oper-0001-mars-o96-1980-2020-6h-v6_climatology.zarr" + channels: ["2t", "10u", "10v", "z_500", "t_850", "u_850", "v_850", "q_850"] + # channels: ["2t", "z_500", "q_850"] + evaluation: + forecast_step: "all" + sample: "all" + ensemble: "all" #supported: "all", "mean", [0,1,2] + plotting: + sample: [0] + forecast_step: "all" #supported: "all", [1,2,3,...], "1-50" (equivalent of [1,2,3,...50]) + plot_maps: true + plot_histograms: false + plot_animations: true + # CERRA: + # channels: ["z_500", "t_850", "u_850"] #, "blah"] + # evaluation: + # forecast_step: "all" + # sample: "all" + # plotting: + # sample: [0] + # forecast_step: "all" + # plot_maps: true + # plot_bias: false + # plot_target: false + # plot_histograms: true + # plot_animations: true + +run_ids : \ No newline at end of file diff --git a/config/config_forecasting.yml b/config/config_forecasting.yml index 56ab2f4a0..ed612f5e3 100644 --- a/config/config_forecasting.yml +++ b/config/config_forecasting.yml @@ -69,6 +69,8 @@ healpix_level: 5 rope_2D: False mlp_type: swiglu use_xsa: True +# mlp_type: mlp +# use_xsa: False with_mixed_precision: True with_flash_attention: True @@ -89,6 +91,7 @@ freeze_modules: "" load_chkpt: {} norm_type: "LayerNorm" +qk_norm_type: null # if null, defaults to norm_type ##################################### @@ -138,7 +141,7 @@ training_config: # training_mode: "masking", "student_teacher", "latent_loss" training_mode: ["masking"] - num_mini_epochs: 64 + num_mini_epochs: 196 samples_per_mini_epoch: 4096 shuffle: True @@ -185,7 +188,7 @@ training_config: } forecast : - time_step: 06:00:00 + # time_step: 01:00:00 offset: 1 num_steps: 3 policy: "fixed" diff --git a/config/config_forecasting_finetuning.yml b/config/config_forecasting_finetuning.yml index d3b4c2441..761b111c4 100644 --- a/config/config_forecasting_finetuning.yml +++ b/config/config_forecasting_finetuning.yml @@ -35,17 +35,17 @@ training_config: # training_mode: "masking", "student_teacher", "latent_loss" training_mode: ["masking"] - num_mini_epochs: 16 + num_mini_epochs: 32 start_date: 2014-01-01T00:00 end_date: 2022-12-31T00:00 learning_rate_scheduling : lr_start: 1e-6 - lr_max: 1e-4 + lr_max: 7e-5 lr_final_decay: 2e-6 lr_final: 0.0 - num_steps_warmup: 256 + num_steps_warmup: 1024 num_steps_cooldown: 512 policy_warmup: "cosine" policy_decay: "cosine" @@ -53,7 +53,7 @@ training_config: parallel_scaling_policy: "sqrt" forecast : - time_step: 06:00:00 + # time_step: 06:00:00 offset: 1 num_steps: 8 policy: "fixed" diff --git a/config/config_pipeline_forecasting.yml b/config/config_pipeline_forecasting.yml new file mode 100644 index 000000000..8e3d0dd95 --- /dev/null +++ b/config/config_pipeline_forecasting.yml @@ -0,0 +1,42 @@ +# Pipeline for pre-training, fine-tuning, inference, and evaluation of a forecasting model. Each stage can have its own +# configuration files, options, and resource requirements. The stages are executed sequentially, with the ability to +# reference previous stages' run IDs for chaining jobs together. + +stages: + - name: pretrain + stage: train + config_files: + - config/config_forecasting.yml + options: + - training.max_epochs=196 + chain_jobs: 6 + nodes: 2 + # slurm_args: + # - "--time=12:00:00" + + - name: finetune + stage: train + from_run_id: STAGE.pretrain.run_id # optional; use STAGE..run_id to reference a previous stage's run_id + config_files: + - config/config_forecasting_finetuning.yml + chain_jobs: 4 + nodes: 2 + # slurm_args: + # - "--time=12:00:00" + # # - "--partition=booster" + + - name: inference + stage: inference + from_run_id: STAGE.finetune.run_id + options: + - test_config.streams_directory="./config/streams/era5_1deg_forecasting_inf/" + - test_config.start_date=202301010000 + - test_config.end_date=202301310000 + - test_config.samples=16 + nodes: 1 + + - name: evaluation + stage: evaluation + eval_config: config/config_eval.yml + run_ids: [STAGE.inference.run_id] # list of run_ids to evaluate; can reference previous stages or be hardcoded + nodes: 1 diff --git a/config/streams/era5_1deg_forecasting/era5.yml b/config/streams/era5_1deg_forecasting/era5.yml index 0bd70ae01..72ee284c2 100644 --- a/config/streams/era5_1deg_forecasting/era5.yml +++ b/config/streams/era5_1deg_forecasting/era5.yml @@ -9,7 +9,8 @@ ERA5 : type : anemoi - filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + # filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2024-1h-v3-with-era51.zarr'] stream_id : 0 source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] @@ -17,18 +18,19 @@ ERA5 : location_weight : cosine_latitude masking_rate : 0.6 masking_rate_none : 0.05 - token_size : 8 - tokenize_spacetime : True - max_num_targets: 20000 + token_size : 64 + tokenize_spacetime : False + max_num_targets: 120000 + frequency: 01:00:00 embed : net : transformer num_tokens : 1 num_heads : 8 - dim_embed : 256 + dim_embed : 512 num_blocks : 2 embed_target_coords : net : linear - dim_embed : 256 + dim_embed : 512 target_readout : num_layers : 2 num_heads : 4 From e095f8bbb27025edfe5c7a40cb702943855cb868 Mon Sep 17 00:00:00 2001 From: Julian Kuehnert Date: Thu, 30 Apr 2026 10:53:51 +0200 Subject: [PATCH 301/344] Store denoising steps during diffusion inference (#2284) * writing inference denoising steps as forecast steps * write inference denoising into forecast steps dimension --- src/weathergen/model/diffusion.py | 10 ++++- src/weathergen/model/model.py | 61 +++++++++++++++++++++++++-- src/weathergen/train/trainer.py | 45 ++++++++++++++++++++ src/weathergen/utils/validation_io.py | 9 ++++ 4 files changed, 121 insertions(+), 4 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 38bc6210f..fc565c639 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -293,6 +293,10 @@ def inference_forward( "x": [x.cpu()], } + # Per-step intermediate denoised states (one per ODE step). + # Returned to the caller so they can be treated as a forecast-step dimension. + intermediate_x: list[torch.Tensor] = [] + # Main sampling loop. x_next = x * t_steps[0] for i, (t_cur, t_next) in enumerate( @@ -332,9 +336,13 @@ def inference_forward( if self.cur_token is not None: track["l2_to_target"].append((x_next - self.cur_token).norm().item()) track["x"].append(self.cur_token.cpu()) + + # Record intermediate denoised state for this ODE step. + intermediate_x.append(x_next) + self._plot_sampling_diagnostics(track, num_steps) - return x_next + return intermediate_x def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: """Save a diagnostic plot of the sampling trajectory.""" diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index a03d5109a..f29721735 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -346,6 +346,9 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord self.register_token_idxs = list(range(cf.num_register_tokens)) self.aux_token_idxs = list(range(cf.num_register_tokens + cf.num_class_tokens)) self.num_aux_tokens = cf.num_register_tokens + cf.num_class_tokens + # One-shot flag to avoid log spam when warning about an unsupported + # diffusion-inference + multi-step-rollout combination. + self._warned_diffusion_multi_step = False def _create_latent_pred_head( self, global_cfg, name, loss_cfg, use_class_token, use_patch_token @@ -730,6 +733,38 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: coords=model_params.rope_coords, ) + # Diffusion inference returns the per-ODE-step intermediate denoised tokens as a + # list. Treat each intermediate state as its own forecast step in the output so the + # full denoising trajectory can be inspected downstream. The original `step` is + # still used to look up target coordinates (they share the same physical timestamp). + if isinstance(tokens, list): + # Diffusion inference currently only supports a single physical forecast + # step (forecast.num_steps=1); the per-ODE-step trajectory consumes the + # ModelOutput fstep dimension. Multi-step autoregressive rollouts on top of + # diffusion are not implemented yet. + if ( + len(batch.get_output_idxs()) > 1 + and not self._warned_diffusion_multi_step + ): + logger.warning( + "Diffusion inference is being run with forecast.num_steps=%d (>1). " + "Only a single forecast step is supported in this mode; the " + "per-ODE-step denoising trajectory will overwrite later forecast " + "steps in the model output.", + len(batch.get_output_idxs()), + ) + self._warned_diffusion_multi_step = True + # Resize output to fit the diffusion trajectory. + output = self._reindex_output_for_trajectory(output, len(tokens)) + for i, toks in enumerate(tokens): + output = self.predict_decoders( + model_params, step, toks, batch, output, out_step=i + ) + output = self.predict_latent( + model_params, step, toks, batch, output, out_step=i + ) + continue + # decoder predictions output = self.predict_decoders(model_params, step, tokens, batch, output) @@ -738,6 +773,18 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: return output + @staticmethod + def _reindex_output_for_trajectory(output: ModelOutput, n_steps: int) -> ModelOutput: + """ + Resize a ModelOutput to hold ``n_steps`` forecast steps, preserving any latent entries + that were already attached to fstep 0 (e.g. encoder posteriors). + """ + new_output = ModelOutput(n_steps) + if len(output.latent) > 0: + for k, v in output.latent[0].items(): + new_output.add_latent_prediction(0, k, v) + return new_output + def predict_latent( self, model_params: ModelParams, @@ -745,19 +792,23 @@ def predict_latent( tokens: torch.Tensor, batch: ModelBatch, output: ModelOutput, + out_step: int | None = None, ) -> ModelOutput: """ Compute latent predictions """ + if out_step is None: + out_step = step + # safe latent prediction tokens_post_norm = self.latent_pre_norm(tokens) if step == 0 else None latent_state = self.tokens_to_latent_state(tokens_post_norm, tokens) - output.add_latent_prediction(step, "latent_state", latent_state) + output.add_latent_prediction(out_step, "latent_state", latent_state) # latent predictions for SSL training for name, head in self.latent_heads.items(): - output.add_latent_prediction(step, name, head(latent_state)) + output.add_latent_prediction(out_step, name, head(latent_state)) return output @@ -768,6 +819,7 @@ def predict_decoders( tokens: torch.Tensor, batch: ModelBatch, output: ModelOutput, + out_step: int | None = None, ) -> ModelOutput: """ Compute decoder-based predictions @@ -790,6 +842,9 @@ def predict_decoders( if not self.pred_heads: return output + if out_step is None: + out_step = step + # remove register and class tokens tokens = tokens[:, self.num_aux_tokens :] @@ -869,6 +924,6 @@ def predict_decoders( # recover batch dimension (ragged, so as list) pred = torch.split(pred, t_coords_lens, dim=1) - output.add_physical_prediction(step, stream_name, pred) + output.add_physical_prediction(out_step, stream_name, pred) return output diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 0c7b404b8..59683f085 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -56,6 +56,43 @@ # cfg_keys_to_filter = ["losses", "model_input", "target_input"] +def _expand_targets_to_match_preds(preds, targets_and_auxs: dict) -> None: + """ + Replicate per-fstep entries in each TargetAuxOutput so its ``physical`` and ``latent`` + lists match the number of forecast steps in ``preds``. + + Diffusion inference produces one ``preds`` fstep per ODE denoising step, but the + physical target is identical across the trajectory. Without this expansion the loss + calculator (which zips preds and targets with ``strict=True``) raises a length + mismatch. + + The expansion replicates references — no tensor copies are made — and is a no-op when + the lengths already agree. + """ + n_pred = len(preds.physical) + for t_aux in targets_and_auxs.values(): + n_tgt = len(t_aux.physical) + if n_tgt == n_pred or n_tgt == 0: + continue + if n_pred % n_tgt != 0: + logger.warning( + "Cannot expand target/aux from %d to %d fsteps (not a multiple); " + "leaving unchanged.", + n_tgt, + n_pred, + ) + continue + repeat = n_pred // n_tgt + t_aux.physical = [t_aux.physical[i // repeat] for i in range(n_pred)] + t_aux.latent = [t_aux.latent[i // repeat] for i in range(n_pred)] + # output_idxs is consumed by validation IO via batch.get_output_idxs(), but we + # keep the dataclass internally consistent in case other consumers read it. + if t_aux.output_idxs is not None and len(t_aux.output_idxs) == n_tgt: + t_aux.output_idxs = [ + t_aux.output_idxs[i // repeat] for i in range(n_pred) + ] + + class Trainer(TrainerBase): def __init__(self, train_logging: Config): TrainerBase.__init__(self) @@ -627,6 +664,14 @@ def validate(self, mini_epoch, mode_cfg, batch_size): self.model, ) + # Diffusion inference inflates the model output's fstep + # dimension to one entry per ODE step (the denoising + # trajectory). The physical target is identical for every + # such step, so replicate target/aux entries to keep the + # downstream loss calculator and validation IO aligned. + if is_diffusion: + _expand_targets_to_match_preds(preds, targets_and_auxs) + _ = self.loss_calculator_val.compute_loss( preds=preds, targets_and_aux=targets_and_auxs, diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index ee86ca2ce..e6bbebe8d 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -50,6 +50,15 @@ def write_output( timestep_idxs = [0] if len(batch.get_output_idxs()) == 0 else batch.get_output_idxs() forecast_offset = timestep_idxs[0] + + # Diffusion inference inflates the model output's fstep dimension to one entry per + # ODE denoising step (the trajectory). The batch only has the original physical + # forecast indices, so synthesize a contiguous run of indices starting at the + # original first index to cover every entry in model_output / target_aux_out. + n_pred_steps = len(model_output.physical) + if n_pred_steps > len(timestep_idxs): + timestep_idxs = list(range(forecast_offset, forecast_offset + n_pred_steps)) + targets_lens = [] # TODO Maybe stopping at forecast_steps explained #1657 From a717fedc503cb0ee01936d0598eab230767a6fe3 Mon Sep 17 00:00:00 2001 From: Matthias Date: Sat, 2 May 2026 09:49:41 +0200 Subject: [PATCH 302/344] Fix max_num_targets=-1 for inference --- src/weathergen/train/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index caae65647..3c018b33d 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -183,6 +183,7 @@ def inference(self, cf, devices, run_id_contd, mini_epoch_contd): device_type = torch.accelerator.current_accelerator() self.device = torch.device(f"{device_type}:{cf.local_rank}") self.ema_model = None + [stream.update({"max_num_targets": -1}) for stream in cf.streams] # create data loader # only one needed since we only run the validation code path From 18fcc58fc243bfc444f81874065ac27c7b788369 Mon Sep 17 00:00:00 2001 From: Matthias Date: Tue, 5 May 2026 10:25:13 +0200 Subject: [PATCH 303/344] Config minor adjustment: add forecast.time_step: 06:00:00 back in --- config/config_forecasting.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/config_forecasting.yml b/config/config_forecasting.yml index ed612f5e3..30faddf75 100644 --- a/config/config_forecasting.yml +++ b/config/config_forecasting.yml @@ -188,7 +188,7 @@ training_config: } forecast : - # time_step: 01:00:00 + time_step: 06:00:00 offset: 1 num_steps: 3 policy: "fixed" From c6d1d75918fbc5cca2ebb76cde7c9424a8048d26 Mon Sep 17 00:00:00 2001 From: kctezcan Date: Wed, 6 May 2026 18:26:36 +0300 Subject: [PATCH 304/344] Bugfix + remove assertion FC_offset=0 (#2323) * remove fc offset 0 assertion * encoder layer norm correct dimension --- packages/common/src/weathergen/common/config.py | 9 +++++---- src/weathergen/model/encoder.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index 4df4a67d9..75a6f3ba1 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -753,10 +753,11 @@ def validate_forecast_policy_and_steps(forecast_cfg: OmegaConf, mode: str): output_offset = forecast_cfg.get("offset", 0) assert isinstance(output_offset, int), TypeError(valid_forecast_offset) if output_offset == 0: - if isinstance(forecast_cfg.num_steps, int): - assert forecast_cfg.num_steps in [0, 1], valid_forecast_steps_offset0 - else: - raise TypeError(valid_forecast_steps_offset0) + # if isinstance(forecast_cfg.num_steps, int): + # assert forecast_cfg.num_steps in [0, 1], valid_forecast_steps_offset0 + # else: + # raise TypeError(valid_forecast_steps_offset0) + pass elif output_offset == 1: assert forecast_cfg.policy, (provide_forecast_policy, valid_forecast_policies) if isinstance(forecast_cfg.num_steps, int): diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index 3ed17e3ae..582e9b57f 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -117,7 +117,7 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord # global assimilation engine self.ae_global_engine = GlobalAssimilationEngine(cf, self.num_healpix_cells) - self.ln = torch.nn.LayerNorm(cf.ae_local_dim_embed, elementwise_affine=False) + self.ln = torch.nn.LayerNorm(cf.ae_global_dim_embed, elementwise_affine=False) def forward(self, model_params, batch): """ From 646fa7e0bf3f01c5cf57e8307e766dafe1f38833 Mon Sep 17 00:00:00 2001 From: Matthias Date: Thu, 7 May 2026 14:40:26 +0200 Subject: [PATCH 305/344] Change config to 1-step pre-training --- config/config_forecasting.yml | 2 +- config/config_pipeline_forecasting.yml | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/config/config_forecasting.yml b/config/config_forecasting.yml index 30faddf75..9e1110d3d 100644 --- a/config/config_forecasting.yml +++ b/config/config_forecasting.yml @@ -190,7 +190,7 @@ training_config: forecast : time_step: 06:00:00 offset: 1 - num_steps: 3 + num_steps: 1 policy: "fixed" diff --git a/config/config_pipeline_forecasting.yml b/config/config_pipeline_forecasting.yml index 8e3d0dd95..b6dffe6b5 100644 --- a/config/config_pipeline_forecasting.yml +++ b/config/config_pipeline_forecasting.yml @@ -8,8 +8,8 @@ stages: config_files: - config/config_forecasting.yml options: - - training.max_epochs=196 - chain_jobs: 6 + - training.max_epochs=256 + chain_jobs: 7 nodes: 2 # slurm_args: # - "--time=12:00:00" @@ -29,10 +29,10 @@ stages: stage: inference from_run_id: STAGE.finetune.run_id options: - - test_config.streams_directory="./config/streams/era5_1deg_forecasting_inf/" + # - test_config.streams_directory="./config/streams/era5_1deg_forecasting_inf/" - test_config.start_date=202301010000 - test_config.end_date=202301310000 - - test_config.samples=16 + # - test_config.samples=16 nodes: 1 - name: evaluation From 44cd8b9b510c9eebd29d72fee0d8326d72eafc66 Mon Sep 17 00:00:00 2001 From: iluise Date: Thu, 7 May 2026 20:02:46 +0200 Subject: [PATCH 306/344] plot histograms --- config/evaluate/eval_config.yml | 3 + .../evaluate/plotting/plot_orchestration.py | 289 +++++++++++------- .../evaluate/plotting/plot_utils.py | 7 + .../weathergen/evaluate/plotting/plotter.py | 264 ++++++++++++---- .../evaluate/scores/score_orchestration.py | 7 +- 5 files changed, 401 insertions(+), 169 deletions(-) diff --git a/config/evaluate/eval_config.yml b/config/evaluate/eval_config.yml index fc2d49c2a..1f0ae14b8 100644 --- a/config/evaluate/eval_config.yml +++ b/config/evaluate/eval_config.yml @@ -7,6 +7,9 @@ # animation_format: "gif" #options: "mp4", "gif" # dpi_val : 300 # fps: 2 +# n_bins: 50 #number of bins for histograms. +# log_x: true #use log scale for x axis in histograms. +# log_y: true #use log scale for y axis in histograms. # ERA5: # use_datashader: false # marker_size: 2 diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py index 79850df42..d34d1f9c4 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py @@ -9,7 +9,6 @@ """Plotting orchestration: parallel dispatch of per-sample maps, score maps, and summary plots.""" -import glob import logging from pathlib import Path @@ -45,7 +44,6 @@ _logger = logging.getLogger(__name__) - # --------------------------------------------------------------------------- # Score maps # --------------------------------------------------------------------------- @@ -252,17 +250,18 @@ def _scatter_plot_single( def _build_single_animation( - map_output_dir: Path, + output_dir: Path, run_id: str, tag: str, stream: str, - region: str, + region: str | None, var: str, sa: object, fsteps: list, image_format: str, animation_format: str, duration_ms: int, + prefix: str = "map", ) -> list[str]: """Build one GIF for a single (region, sample, variable) combination. @@ -271,32 +270,38 @@ def _build_single_animation( Returns the list of source frame paths that were assembled into the GIF (empty list if no frames were found). """ - image_paths: list[str] = [] - for fstep in fsteps: - parts = [ - "map", - run_id, - tag, - str(sa), - "*", - stream, - region, - var, - "fstep", - str(fstep).zfill(3), - ] - name = "_".join(filter(None, parts)) - fname = f"{map_output_dir.joinpath(name)}.{image_format}" - image_paths += glob.glob(fname) + # Both map and histogram filenames follow the same pattern: + # {prefix}_{run_id}_{tag}_{sample}_{valid_time}_{stream}_{region}_{var}_{fstep:03d} + # For all_samples histograms, valid_time is omitted. + # We match files by checking a fixed prefix and suffix, allowing any + # valid_time (or none) in between — no glob wildcards needed. + region_part = region if region else "" + head = "_".join(filter(None, [prefix, run_id, tag, str(sa)])) + tail = "_".join(filter(None, [stream, region_part, var])) + suffix = f".{image_format}" + fstep_strs = {str(f).zfill(3) for f in fsteps} + + if not output_dir.is_dir(): + return [] + + image_paths = sorted( + str(f) + for f in output_dir.iterdir() + if f.name.startswith(head + "_") + and f.name.endswith(suffix) + and f"_{tail}_" in f.name + and f.stem.rsplit("_", 1)[-1] in fstep_strs + ) if not image_paths: - _logger.debug(f"No images found for animation {var} sample {sa} region {region}") return [] - image_paths = sorted(image_paths) - out_path = ( - f"{map_output_dir}/animation_{run_id}_{tag}_{sa}_{stream}_{region}_{var}.{animation_format}" - ) + anim_parts = ["animation", run_id, tag, str(sa), stream] + if region: + anim_parts.append(region) + anim_parts.append(var) + out_path = f"{output_dir / '_'.join(filter(None, anim_parts))}.{animation_format}" + if animation_format.lower() == "mp4": frames = [imageio.imread(p) for p in image_paths] fps = 1000 / duration_ms if duration_ms > 0 else 2 @@ -327,6 +332,9 @@ def _dispatch_animations( ) -> list[str]: """Build GIF animations in parallel for all (region, sample, variable) combinations. + Animations are built for both maps and histograms — whichever image files + exist on disk will be picked up automatically. + Parameters ---------- plotter : Plotter @@ -340,13 +348,17 @@ def _dispatch_animations( Paths of all source frames that were assembled into GIFs. """ plotter.update_data_selection(select) - map_output_dir = plotter.get_map_output_dir(tag) duration_ms = int(1000 / plotter.fps) if plotter.fps > 0 else 400 + prefixes = [ + ("map", plotter.get_map_output_dir(tag)), + ("histogram", plotter.get_hist_output_dir()), + ] + tasks = [ { - "map_output_dir": map_output_dir, + "output_dir": output_dir, "run_id": plotter.run_id, "tag": tag, "stream": plotter.stream, @@ -357,7 +369,9 @@ def _dispatch_animations( "image_format": plotter.image_format, "animation_format": plotter.animation_format, "duration_ms": duration_ms, + "prefix": prefix, } + for prefix, output_dir in prefixes for region in plotter.regions for sa in samples for var in variables @@ -370,7 +384,7 @@ def _dispatch_animations( results = dispatch_parallel( calls, n_workers=get_num_workers(max_workers=max_workers), - backend="threading", + backend="loky", desc="Animations", ) return [p for r in results if r for p in r] @@ -395,7 +409,7 @@ def _plot_single_sample( plot_maps: bool, plot_bias: bool, plot_target: bool, - plot_histograms: bool, + plot_histograms: bool | str, maps_config: dict, bias_config: dict, ) -> None: @@ -418,11 +432,13 @@ def _plot_single_sample( if plot_bias and bias_data is not None and not bias_has_ens: plotter.create_maps_per_sample(bias_data, plot_chs, data_selection, "bias", bias_cfg) - for ens in ensemble: - has_ens = "ens" in preds.dims and ens != "mean" - preds_ens = preds.sel(ens=ens) if has_ens else preds - preds_tag = "" if "ens" not in preds.dims else f"ens_{ens}" - preds_name = "_".join(filter(None, ["preds", preds_tag])) + for ens in ensemble: + has_ens = "ens" in preds.dims and ens != "mean" + preds_ens = preds.sel(ens=ens) if has_ens else preds + preds_tag = "" if "ens" not in preds.dims else f"ens_{ens}" + preds_name = "_".join(filter(None, ["preds", preds_tag])) + + if plot_maps: plotter.create_maps_per_sample( preds_ens, plot_chs, data_selection, preds_name, maps_cfg ) @@ -434,10 +450,60 @@ def _plot_single_sample( bias_ens, plot_chs, data_selection, bias_tag, bias_cfg ) - if plot_histograms: - plotter.create_histograms_per_sample( - tars, preds_ens, plot_chs, data_selection, preds_tag - ) + if plot_histograms is True or plot_histograms == "per-sample": + plotter.create_histograms( + tars, + preds_ens, + plot_chs, + data_selection, + preds_name, + ranges=maps_config, + ) + + plotter.clean_data_selection() + + +def _plot_all_samples( + plotter_cfg: dict, + output_basedir: str, + tars: xr.DataArray, + preds: xr.DataArray, + bias_data: xr.DataArray | None, + fstep: int | str, + stream: str, + plot_chs: list[str], + ensemble: list, + plot_histograms: bool | str, + maps_config: dict, + bias_config: dict, +) -> None: + """Plot histograms across all samples for a single fstep. + + Unlike per-sample histograms, these aggregate all samples together. + The output filename uses 'global' instead of a sample id and omits the timestep. + """ + if not (plot_histograms is True or plot_histograms == "across-samples"): + return + + matplotlib.use("Agg") + plotter = Plotter(plotter_cfg, Path(output_basedir)) + + data_selection = {"sample": "all_samples", "stream": stream, "forecast_step": fstep} + + for ens in ensemble: + has_ens = "ens" in preds.dims and ens != "mean" + preds_ens = preds.sel(ens=ens) if has_ens else preds + preds_tag = "" if "ens" not in preds.dims else f"ens_{ens}" + preds_name = "_".join(filter(None, ["preds", preds_tag])) + + plotter.create_histograms( + tars, + preds_ens, + plot_chs, + data_selection, + preds_name, + ranges=maps_config, + ) plotter.clean_data_selection() @@ -465,14 +531,8 @@ def plot_data( stream_cfg = reader.get_stream(stream) plot_settings = stream_cfg.get("plotting", {}) - if not ( - plot_settings - and ( - plot_settings.get("plot_maps", False) - or plot_settings.get("plot_histograms", False) - or plot_settings.get("plot_animations", False) - ) - ): + plot_keys = ("plot_maps", "plot_histograms", "plot_animations") + if not plot_settings or not any(plot_settings.get(k, False) for k in plot_keys): return plotter_cfg = { @@ -482,9 +542,13 @@ def plot_data( "fig_size": global_plotting_opts.get("fig_size"), "fps": global_plotting_opts.get("fps", 2), "regions": global_plotting_opts.get("regions", ["global"]), + "log_x": global_plotting_opts.get("log_x", False), + "log_y": global_plotting_opts.get("log_y", False), + "n_bins": global_plotting_opts.get("n_bins", 50), "plot_subtimesteps": reader.get_inference_stream_attr(stream, "tokenize_spacetime", False) | plot_settings.get("plot_subtimesteps", False), } + plotter = Plotter(plotter_cfg, reader.runplot_dir) available_data = reader.check_availability(stream, mode="plotting") @@ -502,12 +566,16 @@ def plot_data( if not isinstance(plot_target, bool): raise TypeError("plot_target must be a boolean.") plot_histograms = plot_settings.get("plot_histograms", False) - if not isinstance(plot_histograms, bool): - raise TypeError("plot_histograms must be a boolean.") + if not isinstance(plot_histograms, bool) and plot_histograms not in { + "across-samples", + "per-sample", + }: + raise TypeError("plot_histograms must be true, false, 'across-samples', or 'per-sample'. ") plot_animations = plot_settings.get("plot_animations", False) if not isinstance(plot_animations, bool): raise TypeError("plot_animations must be a boolean.") + model_output = output_data if output_data is None: model_output = reader.get_data( stream, @@ -516,8 +584,6 @@ def plot_data( channels=available_data.channels, ensemble=available_data.ensemble, ) - else: - model_output = output_data da_tars = model_output.target da_preds = model_output.prediction @@ -530,7 +596,9 @@ def plot_data( plot_sample_set = set(available_data.samples) if available_data.samples is not None else None plot_channel_set = set(available_data.channels) if available_data.channels is not None else None + output_dir = str(reader.runplot_dir) output_fstep_keys = set(da_tars.keys()) + if plot_fstep_set is not None and output_fstep_keys - plot_fstep_set: zarr_fsteps = set(int(f) for f in reader.get_forecast_steps()) if plot_fstep_set == zarr_fsteps: @@ -551,16 +619,9 @@ def plot_data( if not isinstance(global_plotting_opts.get(stream), oc.DictConfig): global_plotting_opts[stream] = oc.DictConfig({}) - maps_config = common_ranges( - da_tars, da_preds, available_data.channels, global_plotting_opts[stream] - ) - bias_config = bias_ranges( - da_tars, da_preds, available_data.channels, global_plotting_opts[stream] - ) - - maps_config_dict = oc.OmegaConf.to_container(maps_config, resolve=True) - bias_config_dict = oc.OmegaConf.to_container(bias_config, resolve=True) - output_basedir = str(reader.runplot_dir) + _range_args = (da_tars, da_preds, available_data.channels, global_plotting_opts[stream]) + maps_config_dict = oc.OmegaConf.to_container(common_ranges(*_range_args), resolve=True) + bias_config_dict = oc.OmegaConf.to_container(bias_ranges(*_range_args), resolve=True) num_plot_workers = get_num_workers( check_process_headroom=True, @@ -568,6 +629,7 @@ def plot_data( ) tasks: list[dict] = [] + all_samples_tasks: list[dict] = [] for (fstep, tars), (_, preds) in zip(da_tars.items(), da_preds.items(), strict=False): all_chs = list(np.atleast_1d(tars.channel.values)) plot_chs = ( @@ -589,11 +651,28 @@ def plot_data( bias_data = (preds - tars) if plot_bias else None + all_samples_tasks.append( + { + "plotter_cfg": plotter_cfg, + "output_basedir": output_dir, + "tars": tars, + "preds": preds, + "bias_data": bias_data, + "fstep": fstep, + "stream": stream, + "plot_chs": plot_chs, + "ensemble": list(available_data.ensemble), + "plot_histograms": plot_histograms, + "maps_config": maps_config_dict, + "bias_config": bias_config_dict, + } + ) + for sample in plot_samples: tasks.append( { "plotter_cfg": plotter_cfg, - "output_basedir": output_basedir, + "output_basedir": output_dir, "tars": tars, "preds": preds, "bias_data": bias_data, @@ -620,63 +699,51 @@ def plot_data( calls, n_workers=num_plot_workers, backend="loky", desc=f"Plotting {run_id} - {stream}" ) + if all_samples_tasks: + _logger.info( + f"Parallel plotting: dispatching {len(all_samples_tasks)} across-samples " + f"tasks using up to {num_plot_workers} loky workers." + ) + as_calls = [delayed(_plot_all_samples)(**t) for t in all_samples_tasks] + dispatch_parallel( + as_calls, + n_workers=num_plot_workers, + backend="loky", + desc=f"Across-samples plots {run_id} - {stream}", + ) + if plot_animations: - plotter = Plotter(plotter_cfg, reader.runplot_dir) last_fstep = list(da_tars.keys())[-1] - last_tars = da_tars[last_fstep] last_preds = da_preds[last_fstep] - all_chs = list(np.atleast_1d(last_tars.channel.values)) - plot_chs = ( - [ch for ch in all_chs if ch in plot_channel_set] - if plot_channel_set is not None - else all_chs - ) - all_samples = list(np.unique(last_tars.sample.values)) - plot_samples = ( - [s for s in all_samples if s in plot_sample_set] - if plot_sample_set is not None - else all_samples - ) - plot_fsteps = da_tars.keys() - data_selection = { - "sample": plot_samples[-1], - "stream": stream, - "forecast_step": last_fstep, - } + last_tars = da_tars[last_fstep] + has_ens = "ens" in last_preds.dims + + _sel = lambda items, allowed: [x for x in items if x in allowed] if allowed else items + plot_chs = _sel(list(np.atleast_1d(last_tars.channel.values)), plot_channel_set) + plot_samples = _sel(list(np.unique(last_tars.sample.values)), plot_sample_set) + max_wk = reader.eval_cfg.get("max_workers", None) + anim_samples = plot_samples + (["all_samples"] if plot_histograms else []) + anim_kw = dict( + plotter=plotter, + samples=anim_samples, + fsteps=da_tars.keys(), + variables=plot_chs, + max_workers=max_wk, + select={"sample": plot_samples[-1], "stream": stream, "forecast_step": last_fstep}, + ) + + tags: list[str] = [] for ens in available_data.ensemble: - preds_name = "preds" if "ens" not in last_preds.dims else f"preds_ens_{ens}" - _dispatch_animations( - plotter, - plot_samples, - plot_fsteps, - plot_chs, - data_selection, - preds_name, - max_workers=max_wk, - ) + tags.append("preds" if not has_ens else f"preds_ens_{ens}") if plot_target: - _dispatch_animations( - plotter, - plot_samples, - plot_fsteps, - plot_chs, - data_selection, - "targets", - max_workers=max_wk, - ) + tags.append("targets") if plot_bias: for ens in available_data.ensemble: - bias_tag = "bias" if "ens" not in last_preds.dims else f"bias_ens_{ens}" - _dispatch_animations( - plotter, - plot_samples, - plot_fsteps, - plot_chs, - data_selection, - bias_tag, - max_workers=max_wk, - ) + tags.append("bias" if not has_ens else f"bias_ens_{ens}") + + for tag in tags: + _dispatch_animations(**anim_kw, tag=tag) # --------------------------------------------------------------------------- diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index b61b7813f..dfb2b3e8f 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ -136,6 +136,13 @@ def clean_label(s: str) -> str: return re.sub(r"[_\-]+", " ", s).strip() +def filter_set(items: list, allowed: set | None) -> list: + """Return *items* filtered to *allowed*, or all items if *allowed* is ``None``.""" + if allowed is None: + return items + return [x for x in items if x in allowed] + + class DefaultMarkerSize: """ Utility class for managing default configuration values, such as marker sizes diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 961de7c10..97e0840a6 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -11,6 +11,7 @@ import logging import os import warnings +from dataclasses import dataclass from pathlib import Path import cartopy @@ -23,6 +24,8 @@ from astropy_healpix import HEALPix as HEALPixGrid from cartopy.io import DownloadWarning from matplotlib.collections import LineCollection +from scipy.stats import skew +from scipy.stats import wasserstein_distance as wd try: import datashader as ds @@ -54,7 +57,7 @@ def _download_cartopy_off(enabled: bool) -> None: """Enable/disable blocking Cartopy downloads by elevating DownloadWarning to error.""" if enabled: warnings.filterwarnings("error", category=DownloadWarning) - _logger.info( + _logger.debug( "Auto-downloads are blocked for cartopy; only local cartopy data will be used." ) else: @@ -68,6 +71,39 @@ def _download_cartopy_off(enabled: bool) -> None: _logger.debug(f"Taking cartopy paths from {work_dir}") +@dataclass +class DistStats: + """Summary statistics for a 1-D distribution.""" + + count: int + min: float + max: float + mean: float + median: float + std: float + skewness: float + + @classmethod + def from_array(cls, v: np.typing.NDArray) -> "DistStats": + v = np.asarray(v).ravel() + return cls( + count=len(v), + min=float(np.min(v)), + max=float(np.max(v)), + mean=float(np.mean(v)), + median=float(np.median(v)), + std=float(np.std(v)), + skewness=float(skew(v, nan_policy="omit")), + ) + + def summary(self, label: str) -> str: + return ( + f"{label:8s} N={self.count} min={self.min:.3g} max={self.max:.3g} " + f"mean={self.mean:.3g} med={self.median:.3g} " + f"std={self.std:.3g} skew={self.skewness:.3g}" + ) + + class Plotter: """ Contains all basic plotting functions. @@ -94,7 +130,7 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path, stream: str | It can also be set later via update_data_selection. """ - _logger.info(f"Taking cartopy paths from {work_dir}") + _logger.debug(f"Taking cartopy paths from {work_dir}") self.image_format = plotter_cfg.get("image_format") self.animation_format = plotter_cfg.get("animation_format") @@ -102,6 +138,9 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path, stream: str | self.fig_size = plotter_cfg.get("fig_size") self.fps = plotter_cfg.get("fps") self.regions = plotter_cfg.get("regions") + self.log_x = plotter_cfg.get("log_x", False) + self.log_y = plotter_cfg.get("log_y", False) + self.n_bins = plotter_cfg.get("n_bins", 50) _download_cartopy_off(enabled=True) self.plot_subtimesteps = plotter_cfg.get( "plot_subtimesteps", False @@ -137,6 +176,10 @@ def update_data_selection(self, select: dict): _logger.warning("No sample in the selection. Might lead to unexpected results.") else: self.sample = select["sample"] + # "all_samples" is a proxy for across-samples aggregation; + # remove it from self.select so it won't be used in .sel() + if select["sample"] == "all_samples": + self.select.pop("sample") if "stream" not in select: _logger.warning("No stream in the selection. Might lead to unexpected results.") @@ -190,13 +233,14 @@ def select_from_da(self, da: xr.DataArray, selection: dict) -> xr.DataArray: da = da.sel({key: value}) return da - def create_histograms_per_sample( + def create_histograms( self, target: xr.DataArray, preds: xr.DataArray, variables: list, select: dict, tag: str = "", + ranges: dict | None = None, ) -> list[str]: """ Plot histogram of target vs predictions for each variable and valid time in the DataArray. @@ -222,44 +266,64 @@ def create_histograms_per_sample( self.update_data_selection(select) - # Basic map output directory for this stream - hist_output_dir = self.out_plot_basedir / self.stream / "histograms" + # Basic histogram output directory for this stream + hist_output_dir = self.get_hist_output_dir() if not os.path.exists(hist_output_dir): _logger.info(f"Creating dir {hist_output_dir}") os.makedirs(hist_output_dir, exist_ok=True) - for var in variables: - select_var = self.select | {"channel": var} - - targ, prd = ( - self.select_from_da(target, select_var), - self.select_from_da(preds, select_var), - ) + for region in self.regions: + if region != "global": + bbox = RegionBoundingBox.from_region_name(region) + reg_target = bbox.apply_mask(target) + reg_preds = bbox.apply_mask(preds) + else: + reg_target = target + reg_preds = preds - # Remove NaNs - targ = targ.dropna(dim="ipoint") - prd = prd.dropna(dim="ipoint") - assert targ.size > 0, "Data array must not be empty or contain only NAs" - assert prd.size > 0, "Data array must not be empty or contain only NAs" + for var in variables: + select_var = self.select | {"channel": var} - if self.plot_subtimesteps: - ntimes_unique = len(np.unique(targ.valid_time)) - _logger.info( - f"Creating histograms for {ntimes_unique} valid times of variable {var}." + targ, prd = ( + self.select_from_da(reg_target, select_var), + self.select_from_da(reg_preds, select_var), ) - groups = zip(targ.groupby("valid_time"), prd.groupby("valid_time"), strict=False) - else: - _logger.info(f"Plotting histogram for all valid times of {var}") + # Remove NaNs + targ = targ.dropna(dim="ipoint") + prd = prd.dropna(dim="ipoint") + assert targ.size > 0, "Data array must not be empty or contain only NAs" + assert prd.size > 0, "Data array must not be empty or contain only NAs" + + if self.plot_subtimesteps and str(self.sample) != "all_samples": + ntimes_unique = len(np.unique(targ.valid_time)) + _logger.debug( + f"Creating histograms for {ntimes_unique} valid times of variable {var}." + ) - groups = [((None, targ), (None, prd))] # wrap once with dummy valid_time + groups = zip( + targ.groupby("valid_time"), prd.groupby("valid_time"), strict=False + ) + else: + _logger.debug(f"Plotting histogram for all valid times of {var}") - for (valid_time, targ_t), (_, prd_t) in groups: - if valid_time is not None: - _logger.debug(f"Plotting histogram for {var} at valid_time {valid_time}") - name = self.plot_histogram(targ_t, prd_t, hist_output_dir, var, tag=tag) - plot_names.append(name) + groups = [((None, targ), (None, prd))] # wrap once with dummy valid_time + + for (valid_time, targ_t), (_, prd_t) in groups: + if valid_time is not None: + _logger.debug(f"Plotting histogram for {var} at valid_time {valid_time}") + var_range = ranges.get(var, {}) if ranges else {} + name = self.plot_histogram( + targ_t, + prd_t, + hist_output_dir, + var, + tag=tag, + region=region, + xlim=(var_range.get("vmin"), var_range.get("vmax")), + ) + plot_names.append(name) self.clean_data_selection() @@ -272,6 +336,8 @@ def plot_histogram( hist_output_dir: Path, varname: str, tag: str = "", + region: str = "", + xlim: tuple | None = None, ) -> str: """ Plot a histogram comparing target and prediction data for a specific variable. @@ -294,47 +360,121 @@ def plot_histogram( Name of the saved plot file. """ - # Get common bin edges - vals = np.concatenate([target_data, pred_data]) - bins = np.histogram_bin_edges(vals, bins=50) - - # Plot histograms - plt.hist(target_data, bins=bins, alpha=0.7, label="Target") - plt.hist(pred_data, bins=bins, alpha=0.7, label="Prediction") + tar_vals = np.asarray(target_data).ravel() + prd_vals = np.asarray(pred_data).ravel() + + # Get common bin edges — use fixed xlim range if provided for consistency + xmin, xmax = xlim if xlim else (None, None) + # Fall back to data-derived bounds if either limit is missing + if xmin is None or xmax is None: + vals = np.concatenate([tar_vals, prd_vals]) + if xmin is None: + xmin = float(np.nanmin(vals)) + if xmax is None: + xmax = float(np.nanmax(vals)) + # Add 5% margin on each side so tails are clearly visible + margin = (xmax - xmin) * 0.05 + xmin -= margin + xmax += margin + bins = np.linspace(xmin, xmax, self.n_bins + 1) + + # Compute histograms + target_counts, _ = np.histogram(tar_vals, bins=bins) + pred_counts, _ = np.histogram(prd_vals, bins=bins) + bin_centers = (bins[:-1] + bins[1:]) / 2 + + color_tar = "black" + color_pred = "#00897B" # teal / green-blue + + # Create figure with two subplots: histogram + ratio + fig, (ax_hist, ax_ratio) = plt.subplots( + 2, + 1, + sharex=True, + figsize=self.fig_size or (8, 6), + gridspec_kw={"height_ratios": [3, 1], "hspace": 0.05}, + ) - # set labels and title - plt.xlabel(f"Variable: {varname}") - plt.ylabel("Frequency") - plt.title( - f"Histogram of Target and Prediction: {self.stream}, {varname} : " - f"fstep = {self.fstep:03}" + # Upper panel: histogram curves + ax_hist.plot( + bin_centers, target_counts, alpha=0.7, label="Target", linewidth=1.5, color=color_tar + ) + ax_hist.plot( + bin_centers, pred_counts, alpha=0.7, label="Prediction", linewidth=1.5, color=color_pred + ) + ax_hist.set_ylabel("Frequency") + ax_hist.set_title(f"{self.stream}, {varname} : fstep = {self.fstep:03}") + ax_hist.legend(frameon=False) + if self.log_y: + ax_hist.set_yscale("log") + ax_hist.grid(True, linestyle="--", alpha=0.5) + + # Lower panel: ratio (prediction / target) + with np.errstate(divide="ignore", invalid="ignore"): + ratio = np.where(target_counts > 0, pred_counts / target_counts, np.nan) + ax_ratio.plot(bin_centers, ratio, linewidth=1.2, color=color_pred) + ax_ratio.axhline(1.0, linestyle="--", color="gray", linewidth=0.8) + ax_ratio.set_ylabel("Pred / Target") + ax_ratio.set_xlabel(f"Variable: {varname}") + ax_ratio.set_ylim(0, 2) + ax_ratio.grid(True, linestyle="--", alpha=0.5) + + if self.log_x: + ax_hist.set_xscale("log") + ax_ratio.set_xscale("log") + ax_ratio.set_xlim(xmin, xmax) + + t_s = DistStats.from_array(tar_vals) + p_s = DistStats.from_array(prd_vals) + + # Wasserstein distance + w_dist = wd(tar_vals, prd_vals) + + stat_text = ( + f"Wasserstein distance: {w_dist:.4g}\n{t_s.summary('Target:')}\n{p_s.summary('Pred:')}" ) - plt.legend(frameon=False) - valid_time = ( - target_data["valid_time"][0] - .values.astype("datetime64[m]") - .astype(datetime.datetime) - .strftime("%Y-%m-%dT%H%M") + fig.text( + 0.5, + -0.02, + stat_text, + ha="center", + va="top", + fontsize=7, + family="monospace", + bbox=dict(boxstyle="round,pad=0.3", facecolor="wheat", alpha=0.5), ) - # TODO: make this nicer + # For "all_samples" (across-samples) histograms, omit the valid_time from the name + is_global = str(self.sample) == "all_samples" + + if is_global: + valid_time = None + else: + valid_time = ( + target_data["valid_time"][0] + .values.astype("datetime64[m]") + .astype(datetime.datetime) + .strftime("%Y-%m-%dT%H%M") + ) + parts = [ "histogram", - self.run_id, - tag, + str(self.run_id), + str(tag) if tag else "", str(self.sample), valid_time, - self.stream, + str(self.stream), + region if region else "", varname, - str(self.fstep).zfill(3), + f"{self.fstep:03d}", ] name = "_".join(filter(None, parts)) fname = hist_output_dir / f"{name}.{self.image_format}" _logger.debug(f"Saving histogram to {fname}") - plt.savefig(fname, bbox_inches="tight") - plt.close() + fig.savefig(fname, bbox_inches="tight") + plt.close(fig) return name @@ -695,7 +835,7 @@ def _build_map_filename(self, varname: str, regionname: str, tag: str, data: xr. parts.append(varname) if self.fstep is not None: - parts.extend(["fstep", f"{self.fstep:03d}"]) + parts.append(f"{self.fstep:03d}") return "_".join(filter(None, parts)) @@ -902,6 +1042,16 @@ def get_map_output_dir(self, tag): """ return self.out_plot_basedir / self.stream / "maps" / tag + def get_hist_output_dir(self): + """Return the output directory path for histogram plots. + + Returns + ------- + Path + Resolved directory path: ``//histograms``. + """ + return self.out_plot_basedir / self.stream / "histograms" + def get_map_title(self, var, valid_time, data): """Build the title string for a map plot. diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score_orchestration.py b/packages/evaluate/src/weathergen/evaluate/scores/score_orchestration.py index aff40b9d9..82bc8a026 100644 --- a/packages/evaluate/src/weathergen/evaluate/scores/score_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score_orchestration.py @@ -49,6 +49,7 @@ def _score_single_fstep( bbox: "RegionBoundingBox", metrics: dict, group_by_coord: str | None, + agg_dims: str | list[str] = "ipoint", ) -> tuple[int, xr.DataArray, dict[tuple[int, str], dict]] | None: """Score all metrics for one fstep in one region. Stateless, thread-safe. @@ -89,7 +90,7 @@ def _score_single_fstep( score = get_score( score_data, metric, - agg_dims="ipoint", + agg_dims=agg_dims, group_by_coord=group_by_coord, parameters=parameters, ) @@ -176,6 +177,7 @@ def calc_scores_per_stream( aligned_clim_data = get_climatology(reader, da_tars, stream) max_workers = reader.eval_cfg.get("max_workers", None) + agg_dims = reader.eval_cfg.get("agg_dims", "ipoint") for region in regions: bbox = RegionBoundingBox.from_region_name(region) @@ -194,6 +196,7 @@ def calc_scores_per_stream( bbox, metrics, max_workers, + agg_dims, ) store_metrics_for_region( @@ -227,6 +230,7 @@ def compute_scores_for_region( bbox: "RegionBoundingBox", metrics: dict, max_workers: int | None, + agg_dims: str | list[str] = "ipoint", ) -> tuple[list, dict]: """Dispatch parallel scoring for all fsteps in one region. @@ -289,6 +293,7 @@ def compute_scores_for_region( bbox, metrics, group_by_coord, + agg_dims, ) for fstep, tars_fs, preds_fs, preds_next, tars_next, climatology in fstep_tasks ] From 045d311f67bdb744807662c3b8be6fb9e8f10f07 Mon Sep 17 00:00:00 2001 From: iluise Date: Thu, 7 May 2026 20:09:17 +0200 Subject: [PATCH 307/344] add agg_dim --- config/evaluate/eval_config_default.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/config/evaluate/eval_config_default.yml b/config/evaluate/eval_config_default.yml index 92921edd6..7c7dfca4d 100644 --- a/config/evaluate/eval_config_default.yml +++ b/config/evaluate/eval_config_default.yml @@ -37,6 +37,7 @@ evaluation: add_grid: false score_cards: false bar_plots: false + #agg_dims: ["ipoint"] #----> NOTE: advanced! Handle with care. This will average the scores across the specified list of dimensions. Supported dimensions: "ipoint", "sample", "forecast_step", "ensemble". Use with caution, as it can hide important details about the model performance. default_streams: From 518c9094c4405be8c9c724dfd26d37638a340687 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Fri, 8 May 2026 11:09:52 +0200 Subject: [PATCH 308/344] Adjusted configs for forecast backbone model --- config/config_forecasting.yml | 11 +- config/streams/era5_1deg_forecasting/era5.yml | 132 +++++++++--------- 2 files changed, 72 insertions(+), 71 deletions(-) diff --git a/config/config_forecasting.yml b/config/config_forecasting.yml index 4f1ff1499..a3f32110f 100644 --- a/config/config_forecasting.yml +++ b/config/config_forecasting.yml @@ -11,7 +11,7 @@ embed_orientation: "channels" embed_unembed_mode: "block" embed_dropout_rate: 0.1 -ae_local_dim_embed: 2048 +ae_local_dim_embed: 512 ae_local_num_blocks: 0 ae_local_num_heads: 16 ae_local_dropout_rate: 0.1 @@ -25,7 +25,7 @@ ae_adapter_with_qk_lnorm: True ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 -ae_global_dim_embed: 2048 +ae_global_dim_embed: 512 ae_global_num_blocks: 4 ae_global_num_heads: 32 ae_global_dropout_rate: 0.1 @@ -63,10 +63,11 @@ fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm fe_impute_latent_noise_std: 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) forecast_att_dense_rate: 1.0 +fe_diffusion_model: False healpix_level: 5 -rope_2D: False +rope_2D: True with_mixed_precision: True with_flash_attention: True @@ -184,8 +185,8 @@ training_config: forecast : time_step: 06:00:00 - offset: 1 - num_steps: 3 + offset: 0 + num_steps: 2 policy: "fixed" diff --git a/config/streams/era5_1deg_forecasting/era5.yml b/config/streams/era5_1deg_forecasting/era5.yml index 569922fb9..56cbae702 100644 --- a/config/streams/era5_1deg_forecasting/era5.yml +++ b/config/streams/era5_1deg_forecasting/era5.yml @@ -37,69 +37,69 @@ ERA5 : pred_head : ens_size : 1 num_layers : 1 - # channel_weights : - # q_50: 0.2 - # q_100: 0.23 - # q_150: 0.26 - # q_200: 0.29 - # q_250: 0.33 - # q_300: 0.36 - # q_400: 0.42 - # q_500: 0.48 - # q_600: 0.55 - # q_700: 0.61 - # q_850: 0.71 - # q_925: 0.75 - # q_1000: 0.8 - # t_50: 0.2 - # t_100: 0.23 - # t_150: 0.26 - # t_200: 0.29 - # t_250: 0.33 - # t_300: 0.36 - # t_400: 0.42 - # t_500: 0.48 - # t_600: 0.55 - # t_700: 0.61 - # t_850: 0.71 - # t_925: 0.75 - # t_1000: 0.8 - # u_50: 0.2 - # u_100: 0.23 - # u_150: 0.26 - # u_200: 0.29 - # u_250: 0.33 - # u_300: 0.36 - # u_400: 0.42 - # u_500: 0.48 - # u_600: 0.55 - # u_700: 0.61 - # u_850: 0.71 - # u_925: 0.75 - # u_1000: 0.8 - # v_50: 0.2 - # v_100: 0.23 - # v_150: 0.26 - # v_200: 0.29 - # v_250: 0.33 - # v_300: 0.36 - # v_400: 0.42 - # v_500: 0.48 - # v_600: 0.55 - # v_700: 0.61 - # v_850: 0.71 - # v_925: 0.75 - # v_1000: 0.8 - # z_50: 0.2 - # z_100: 0.23 - # z_150: 0.26 - # z_200: 0.29 - # z_250: 0.33 - # z_300: 0.36 - # z_400: 0.42 - # z_500: 0.48 - # z_600: 0.55 - # z_700: 0.61 - # z_850: 0.71 - # z_925: 0.75 - # z_1000: 0.8 \ No newline at end of file + channel_weights : + q_50: 0.2 + q_100: 0.23 + q_150: 0.26 + q_200: 0.29 + q_250: 0.33 + q_300: 0.36 + q_400: 0.42 + q_500: 0.48 + q_600: 0.55 + q_700: 0.61 + q_850: 0.71 + q_925: 0.75 + q_1000: 0.8 + t_50: 0.2 + t_100: 0.23 + t_150: 0.26 + t_200: 0.29 + t_250: 0.33 + t_300: 0.36 + t_400: 0.42 + t_500: 0.48 + t_600: 0.55 + t_700: 0.61 + t_850: 0.71 + t_925: 0.75 + t_1000: 0.8 + u_50: 0.2 + u_100: 0.23 + u_150: 0.26 + u_200: 0.29 + u_250: 0.33 + u_300: 0.36 + u_400: 0.42 + u_500: 0.48 + u_600: 0.55 + u_700: 0.61 + u_850: 0.71 + u_925: 0.75 + u_1000: 0.8 + v_50: 0.2 + v_100: 0.23 + v_150: 0.26 + v_200: 0.29 + v_250: 0.33 + v_300: 0.36 + v_400: 0.42 + v_500: 0.48 + v_600: 0.55 + v_700: 0.61 + v_850: 0.71 + v_925: 0.75 + v_1000: 0.8 + z_50: 0.2 + z_100: 0.23 + z_150: 0.26 + z_200: 0.29 + z_250: 0.33 + z_300: 0.36 + z_400: 0.42 + z_500: 0.48 + z_600: 0.55 + z_700: 0.61 + z_850: 0.71 + z_925: 0.75 + z_1000: 0.8 \ No newline at end of file From 2a29edfdb5cb76845f2a37968054ebde63fd4168 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Tue, 12 May 2026 10:23:20 +0200 Subject: [PATCH 309/344] Working era5 distribution config --- config/config_diffusion.yml | 7 ++++--- src/weathergen/model/diffusion.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 651f619f2..5cfd9c810 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -72,7 +72,7 @@ sigma_min: 0.002 sigma_max: 80 sigma_data: 0.63 rho: 7 -p_mean: -1.2 +p_mean: 0.5 p_std: 1.2 healpix_level: 5 @@ -110,9 +110,10 @@ freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_to # load_chkpt: {'run_id': 'ae4wlc5m', 'epoch': -1} # multi-var d2048 hl3, sigma_data=2.7047 # load_chkpt: {'run_id': 'r45iwyns', 'epoch': -1} # multi-var d512 hl3, sigma_data=1.1785 # load_chkpt: {'run_id': 'ydka6uql', 'epoch': -1} # multi-var d512 hl4, sigma_data=0.827 -load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 # load_chkpt: {'run_id': 'v8kd6xc1', 'epoch': -1} # multi-var d512 hl5 nopos, sigma_data=0.6481 # load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 +load_chkpt: {'run_id': 'y1gu5md8', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, diffusion-full-pipeline # load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 # load_chkpt: {'run_id': 'mivw6jda', 'epoch': -1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 @@ -262,7 +263,7 @@ validation_config: # parameters for validation samples that are written to disk output : { # number of samples that are written - num_samples: 1, + num_samples: 0, # write samples in normalized model space normalized_samples: False, # output streams to write; default all diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index fc565c639..449685259 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -127,7 +127,7 @@ def forward( elif self.cf.stage == "inference": if fstep is None: raise ValueError(f"During inference, fstep is required. Got fstep={fstep}") - + self.cur_token = tokens.detach() if tokens is not None else None return self.inference_forward( fstep=fstep, num_steps=num_steps, From 12a1a9a38cf5b54beaed23e116d8fa2ef001f0c5 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Tue, 12 May 2026 10:26:20 +0200 Subject: [PATCH 310/344] Adjust sigma_data and lr_max for diffusion training --- config/config_diffusion.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 5cfd9c810..b4ecb0ea3 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -70,7 +70,7 @@ frequency_embedding_dim: 256 embedding_dim: 512 sigma_min: 0.002 sigma_max: 80 -sigma_data: 0.63 +sigma_data: 1.0 rho: 7 p_mean: 0.5 p_std: 1.2 @@ -187,7 +187,7 @@ training_config: learning_rate_scheduling : lr_start: 1e-6 #5e-5 - lr_max: 7e-6 #1e-4 + lr_max: 1e-5 #1e-4 lr_final_decay: 1e-6 lr_final: 0.0 num_steps_warmup: 64 From 67cb91c9ec81460b5c9e728d2cb9c971b1a93152 Mon Sep 17 00:00:00 2001 From: Matthias Date: Fri, 15 May 2026 15:48:01 +0200 Subject: [PATCH 311/344] Add geoinfo to stream config and ensure identical target samples during inference --- config/streams/era5_1deg_forecasting/era5.yml | 2 +- .../evaluate/plotting/plot_orchestration.py | 11 ++++++++--- src/weathergen/run_train.py | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/config/streams/era5_1deg_forecasting/era5.yml b/config/streams/era5_1deg_forecasting/era5.yml index fce350a59..ae8e2da53 100644 --- a/config/streams/era5_1deg_forecasting/era5.yml +++ b/config/streams/era5_1deg_forecasting/era5.yml @@ -13,7 +13,7 @@ ERA5 : stream_id : 0 source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] - # geoinfo_channels : ['lsm', 'slor', 'sdor', 'insolation', 'cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day'] + geoinfo_channels : ['lsm', 'slor', 'sdor', 'insolation', 'cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day'] loss_weight : 1. location_weight : cosine_latitude masking_rate : 0.6 diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py index d34d1f9c4..b28ad4adc 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py @@ -669,13 +669,18 @@ def plot_data( ) for sample in plot_samples: + # Pre-slice to this sample before serializing for the worker, to avoid + # sending the full per-fstep DataArray (all samples) to each loky process. + tars_s = tars.sel(sample=sample) + preds_s = preds.sel(sample=sample) + bias_s = bias_data.sel(sample=sample) if bias_data is not None else None tasks.append( { "plotter_cfg": plotter_cfg, "output_basedir": output_dir, - "tars": tars, - "preds": preds, - "bias_data": bias_data, + "tars": tars_s, + "preds": preds_s, + "bias_data": bias_s, "sample": sample, "fstep": fstep, "stream": stream, diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index e7703925f..b4e36bc0f 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -96,6 +96,7 @@ def run_inference(args): cli_overwrite, ) cf = config.set_run_id(cf, args.run_id, args.reuse_run_id) + cf.data_loading.rng_seed = 42 cf.stage = args.stage devices = Trainer.init_torch() @@ -169,7 +170,6 @@ def run_train(args): args.private_config, None, None, args.base_config, *args.config, cli_overwrite ) cf = config.set_run_id(cf, args.run_id, False) - cf.data_loading.rng_seed = int(time.time()) cf.stage = args.stage mp_method = cf.general.get("multiprocessing_method", "fork") devices = Trainer.init_torch(multiprocessing_method=mp_method) From 4e68c022e215b0920d511ab3319a5b61ef22d0e6 Mon Sep 17 00:00:00 2001 From: Matthias Date: Fri, 15 May 2026 17:24:41 +0200 Subject: [PATCH 312/344] Add missing cf.data_laoding.rng_seed back in --- src/weathergen/run_train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index b4e36bc0f..a1ad4d1c3 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -170,6 +170,7 @@ def run_train(args): args.private_config, None, None, args.base_config, *args.config, cli_overwrite ) cf = config.set_run_id(cf, args.run_id, False) + cf.data_loading.rng_seed = int(time.time()) cf.stage = args.stage mp_method = cf.general.get("multiprocessing_method", "fork") devices = Trainer.init_torch(multiprocessing_method=mp_method) From 1abbff173f0f28439fbf648571db4fbcd6633db0 Mon Sep 17 00:00:00 2001 From: Matthias Date: Mon, 18 May 2026 12:44:54 +0200 Subject: [PATCH 313/344] Update diffusion config to swiglu chkpt and p_mean=1.5 --- config/config_diffusion.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 7123024b9..5567d4df1 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -72,7 +72,7 @@ sigma_min: 0.002 sigma_max: 80 sigma_data: 1.0 rho: 7 -p_mean: 0.5 +p_mean: 1.5 p_std: 1.2 healpix_level: 5 @@ -119,8 +119,8 @@ freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_to # load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 # load_chkpt: {'run_id': 'y1gu5md8', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, diffusion-full-pipeline # load_chkpt: {'run_id': 'mal6u4gc', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 64 epochs, diffusion-full-pipeline -load_chkpt: {'run_id': 'zrpncqb0', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 196 epochs, diffusion-full-pipeline -# load_chkpt: {'run_id': 'm6fs8wvj', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, swiglu xsa, diffusion-full-pipeline +# load_chkpt: {'run_id': 'zrpncqb0', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 196 epochs, diffusion-full-pipeline +load_chkpt: {'run_id': 'm6fs8wvj', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, swiglu xsa, diffusion-full-pipeline # load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 # load_chkpt: {'run_id': 'mivw6jda', 'epoch': -1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 From 366ae77c164349994d8c6768a5c987914daef2e9 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 22 Apr 2026 18:33:35 +0200 Subject: [PATCH 314/344] bug fix --- src/weathergen/model/layers.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index 182863b39..dc5d7d026 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -134,9 +134,22 @@ def forward(self, *args): elif len(args) > 2: ada_ln_aux = args[-1] else: - assert len(args) == 3, "DIT gets 3 args (no conditioning implemented yet)" - noise_emb = args[-1] - x = self.lnorm(x) + if self.dit_is_cond: + assert len(args) == 4, "DIT with cond gets 4 args" + ada_ln_aux = args[-1] + noise_emb = args[-2] + else: + assert len(args) == 3, "DIT with cond gets 3 args" + ada_ln_aux = args[-1] + + + if self.is_dit: + if self.dit_is_cond: + assert ada_ln_aux is not None, "Need auxiliary input for conditional DIT" + x, cond_gate = self.lnorm(x, ada_ln_aux) + else: + x = self.lnorm(x) + cond_gate = 1 assert noise_emb is not None, "Need noise embedding for noise conditioning in DIT" x, gate = self.noise_conditioning(x, noise_emb) From dc6d82d3f4a0633d9cd5fb1cabbca255495e38f9 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 22 Apr 2026 19:14:49 +0200 Subject: [PATCH 315/344] bug fix --- src/weathergen/model/layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index dc5d7d026..b447982d0 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -139,8 +139,8 @@ def forward(self, *args): ada_ln_aux = args[-1] noise_emb = args[-2] else: - assert len(args) == 3, "DIT with cond gets 3 args" - ada_ln_aux = args[-1] + assert len(args) == 3, "DIT without cond gets 3 args" + noise_emb = args[-1] if self.is_dit: From aaeb0733205b624265251e0055bbe40c684eaa60 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 22 Apr 2026 19:14:58 +0200 Subject: [PATCH 316/344] config change --- config/config_diffusion.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 5567d4df1..1ff104830 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -60,6 +60,7 @@ fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True fe_diffusion_model: True +fe_diffusion_model_conditioning: "date_time" # options: "date_time" fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) From bce064de308983d52ff760fcf692a2e55cf58d1e Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Thu, 23 Apr 2026 12:27:55 +0200 Subject: [PATCH 317/344] plot config --- config/runs_plot_train.yml | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/config/runs_plot_train.yml b/config/runs_plot_train.yml index c14727adf..fd052357e 100644 --- a/config/runs_plot_train.yml +++ b/config/runs_plot_train.yml @@ -1,18 +1,36 @@ train: plot: + nol9pfdg: + slurm_id: 387094 + description: "bug-fix non-cond, with sigma_data=0.63" + imqzsbte: + slurm_id: 387095 + description: "bug-fix non-cond" + f8nd1c60: + slurm_id: 387097 + description: "bug-fix cond-branch w/ cond" + ux8yjktb: + slurm_id: 387095 + description: "bug-fix cond-branch w/ non-cond" + xxkmgsne: + slurm_id: 0 + description: "bug-fix cond-branch w/ non-cond, lr_max=5e-6" + jwexz9y4: + slurm_id: 0 + description: "bug-fix cond-branch w/ non-cond, lr_max=2.5e-6" u7etjsm0: slurm_id: 385058 - description: "ERA5, lr_start=1e-6, lr_max=1e-5" + description: "old ERA5, lr_start=1e-6, lr_max=1e-5" mot8sfay: slurm_id: 385060 - description: "ERA5, lr_start=1e-6, lr_max=7e-6" - zhon45xy: - slurm_id: 385064 - description: "conditioning w/ ERA5, lr_start=1e-6, lr_max=1e-5" - yimje7g3: - slurm_id: 385062 - description: "conditioning w/ ERA5, lr_start=1e-6, lr_max=7e-6" + description: "old ERA5, lr_start=1e-6, lr_max=7e-6" + # zhon45xy: + # slurm_id: 385064 + # description: "conditioning w/ ERA5, lr_start=1e-6, lr_max=1e-5" + # yimje7g3: + # slurm_id: 385062 + # description: "conditioning w/ ERA5, lr_start=1e-6, lr_max=7e-6" # bpeh160r: # slurm_id: 381190 # description: "single samples, lr_start=1e-6, lr_max=1e-6" From 77f6f4398eb27120d19153ec0af93f21aaf7eb45 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Mon, 4 May 2026 12:43:18 +0200 Subject: [PATCH 318/344] plot config --- config/runs_plot_train.yml | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/config/runs_plot_train.yml b/config/runs_plot_train.yml index fd052357e..b5092f8a1 100644 --- a/config/runs_plot_train.yml +++ b/config/runs_plot_train.yml @@ -1,24 +1,27 @@ train: plot: + r8vzykrm: + slurm_id: 387093 + description: "bug-fix non-cond-branch, with sigma_data=0.63" nol9pfdg: slurm_id: 387094 - description: "bug-fix non-cond, with sigma_data=0.63" - imqzsbte: - slurm_id: 387095 - description: "bug-fix non-cond" + description: "bug-fix cond-branch w/o cond, with sigma_data=0.63" + # imqzsbte: + # slurm_id: 387095 + # description: "bug-fix non-cond" f8nd1c60: slurm_id: 387097 description: "bug-fix cond-branch w/ cond" - ux8yjktb: - slurm_id: 387095 - description: "bug-fix cond-branch w/ non-cond" - xxkmgsne: - slurm_id: 0 - description: "bug-fix cond-branch w/ non-cond, lr_max=5e-6" - jwexz9y4: - slurm_id: 0 - description: "bug-fix cond-branch w/ non-cond, lr_max=2.5e-6" + # ux8yjktb: + # slurm_id: 387095 + # description: "bug-fix cond-branch w/ non-cond" + # xxkmgsne: + # slurm_id: 0 + # description: "bug-fix cond-branch w/ non-cond, lr_max=5e-6" + # jwexz9y4: + # slurm_id: 0 + # description: "bug-fix cond-branch w/ non-cond, lr_max=2.5e-6" u7etjsm0: slurm_id: 385058 description: "old ERA5, lr_start=1e-6, lr_max=1e-5" From 140d1ca23b350a3943e2ff885d69802eaa221ceb Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Mon, 4 May 2026 13:07:29 +0200 Subject: [PATCH 319/344] update stage handling in diffusion --- src/weathergen/model/diffusion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 449685259..db4ea75a9 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -100,6 +100,7 @@ def forward( ValueError: If required arguments are missing for current mode """ # called during training in training mode + # called during training in training mode if self.training: if tokens is None or fstep is None or meta_info is None: raise ValueError( From 825f841fff19818e9e9955262b06748d807fdc8e Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Thu, 7 May 2026 20:48:44 +0200 Subject: [PATCH 320/344] re-implement conditioning, update adalayernorm and embedding function --- config/config_diffusion.yml | 1 + .../datasets/multi_stream_data_sampler.py | 25 +++++ src/weathergen/model/attention.py | 33 +++++-- src/weathergen/model/diffusion.py | 93 +++++++++++++------ src/weathergen/model/engines.py | 14 ++- src/weathergen/model/layers.py | 13 ++- src/weathergen/model/model.py | 8 +- src/weathergen/model/norms.py | 52 +++++------ 8 files changed, 165 insertions(+), 74 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 1ff104830..971e25edc 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -67,6 +67,7 @@ fe_impute_latent_noise_std: 0.0 # 1e-4 forecast_att_dense_rate: 1.0 with_step_conditioning: True # False # Diffusion related parameters +diffusion_conditioning_embed_dim: 32 frequency_embedding_dim: 256 embedding_dim: 512 sigma_min: 0.002 diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index cae76d8d2..5b888a9e5 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -97,6 +97,7 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage): self.streams = cf.streams self.rank = cf.rank self.world_size = cf.world_size + self.diffusion_model_conditioning = cf.fe_diffusion_model_conditioning self.repeat_data = cf.data_loading.get("repeat_data_in_mini_epoch", False) # initialise healpic @@ -722,6 +723,13 @@ def _get_batch(self, idx: int, num_forecast_steps: int): ) target_metadata = target_masks.metadata[tidx] + # Get first target step's times (using self.output_offset as the first output step index) + if self.diffusion_model_conditioning == "date_time": + target_times_array = sdata.target_times_raw[self.output_offset] + target_metadata.add_params({'timestamp': ( + target_times_array[0] if len(target_times_array) > 0 else None + )}) + # also want to add the mask to the metadata target_metadata.mask = target_mask # Map target to all source students @@ -735,6 +743,23 @@ def _get_batch(self, idx: int, num_forecast_steps: int): target_in_steps = 1 if len(target_in_steps) == 0 else target_in_steps.max().item() batch = self._preprocess_model_batch(batch, source_in_steps, target_in_steps) + #add target times in source for diffusion model date/time conditioning + if self.diffusion_model_conditioning == "date_time": + #TODO: Might need upgrading fro num_samples > 1 + + # Assert singular source and target samples + assert len(batch.source_samples.samples) == 1, "Only single source sample supported for diffusion model conditioning." + assert len(batch.target_samples.samples) == 1, "Only single target sample supported for diffusion model conditioning." + + source_sample = batch.source_samples.samples[0] + target_sample = batch.target_samples.samples[0] + + # Copy target timestamps to source metadata for all streams + for stream_name in [s["name"] for s in self.streams]: + if stream_name in target_sample.meta_info and stream_name in source_sample.meta_info: + target_timestamp = target_sample.meta_info[stream_name].params.get('timestamp') + source_sample.meta_info[stream_name].add_params({'timestamp': target_timestamp}) + return batch def __iter__(self) -> ModelBatch: diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index ad5f37f09..f431f51d8 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -14,7 +14,7 @@ from torch.nn.attention.flex_attention import create_block_mask, flex_attention from weathergen.model.layers import LinearNormConditioning -from weathergen.model.norms import AdaLayerNorm, RMSNorm +from weathergen.model.norms import AdaLNZero, AdaLayerNorm, RMSNorm from weathergen.model.positional_encoding import rotary_pos_emb_2d """ @@ -248,6 +248,7 @@ def __init__( attention_dtype=torch.bfloat16, with_2d_rope=False, is_dit=False, + dit_is_cond=False, use_xsa=False, ): super(MultiSelfAttentionHeadLocal, self).__init__() @@ -269,10 +270,13 @@ def __init__( norm = RMSNorm self.is_dit = is_dit + self.dit_is_cond = dit_is_cond if is_dit: + if dit_is_cond: + assert dim_aux is not None, "For DIT, need to provide dim_aux for ada layer norm" assert dim_aux is None, "conditioning not yet implemented for DIT attention" assert with_residual, "DIT attention should always have residual connection" - self.lnorm = norm(dim_embed, eps=norm_eps) + self.lnorm = AdaLNZero(dim_embed, dim_aux, norm_eps=norm_eps) if dim_aux is not None else norm(dim_embed, eps=norm_eps) self.noise_conditioning = LinearNormConditioning( latent_space_dim=dim_embed, dtype=attention_dtype ) @@ -317,8 +321,13 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None): # Handle ada_ln_aux conditioning if self.is_dit: - x = self.lnorm(x) - x, gate = self.noise_conditioning(x, emb) + if self.dit_is_cond: + x, cond_gate = self.lnorm(x, ada_ln_aux) + else: + x = self.lnorm(x) + cond_gate = 1 + x, noise_gate = self.noise_conditioning(x, emb) + gate = cond_gate * noise_gate else: x = self.lnorm(x, ada_ln_aux) if ada_ln_aux is not None else self.lnorm(x) @@ -586,6 +595,7 @@ def __init__( attention_dtype=torch.bfloat16, with_2d_rope=False, is_dit=False, # should only be True for diffusion model + dit_is_cond = False, # whether the attention is used for conditioning in the diffusion model (as opposed to denoising). Should only be True for cross attention layers in the diffusion model, and will control whether ada_ln_aux is applied to the input or output of the attention layer use_xsa=False, ): super(MultiSelfAttentionHead, self).__init__() @@ -607,11 +617,13 @@ def __init__( norm = RMSNorm self.is_dit = is_dit + self.dit_is_cond = dit_is_cond if is_dit: - assert dim_aux is None, "conditioning not yet implemented for DIT attention" + if dit_is_cond: + assert dim_aux is not None, "For DIT, need to provide dim_aux for ada layer norm" assert with_residual, "DIT attention should always have residual connection" - self.lnorm = norm(dim_embed, eps=norm_eps) + self.lnorm = AdaLNZero(dim_embed, dim_aux, norm_eps=norm_eps) if dim_aux is not None else norm(dim_embed, eps=norm_eps) self.noise_conditioning = LinearNormConditioning( latent_space_dim=dim_embed, dtype=attention_dtype ) # TODO: Do I need to pass dtype? @@ -650,8 +662,13 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None): # Handle ada_ln_aux conditioning if self.is_dit: - x = self.lnorm(x) - x, gate = self.noise_conditioning(x, emb) + if self.dit_is_cond: + x, cond_gate = self.lnorm(x, ada_ln_aux) + else: + x = self.lnorm(x) + cond_gate = 1 + x, noise_gate = self.noise_conditioning(x, emb) + gate = cond_gate * noise_gate else: x = self.lnorm(x, ada_ln_aux) if ada_ln_aux is not None else self.lnorm(x) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index db4ea75a9..fe3e411df 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -157,7 +157,10 @@ def training_forward( self.cur_token = tokens.detach() - c = None + if self.cf.fe_diffusion_model_conditioning == "date_time": + c = meta_info["ERA5"].params["timestamp"] # TODO: add correct preconditioning (e.g., sample/s in previous time step, datetime encoding, etc.) + else: + c = None y = tokens @@ -200,6 +203,8 @@ def denoise( # Precondition input and feed through network x = self.preconditioner.precondition(x, c) # currently does nothing + if self.cf.fe_diffusion_model_conditioning == "date_time": + c = self.datetime_embedder(c).to(x.device) return c_skip * x + c_out * self.net( c_in * x, fstep=fstep, coords=coords, noise_emb=noise_emb, ada_ln_aux=c @@ -231,6 +236,10 @@ def inference_forward( # Extract conditioning from meta_info (same as training_forward) c = None + if self.cf.fe_diffusion_model_conditioning == "date_time": + c = meta_info["ERA5"].params["timestamp"] + + # Sample pure noise (assuming single batch element for now) # torch.manual_seed(42) x = torch.randn(1, self.num_healpix_cells, self.cf.ae_global_dim_embed).to(device="cuda") @@ -350,6 +359,7 @@ def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: import matplotlib matplotlib.use("Agg") + import os import matplotlib.pyplot as plt steps = list(range(len(track["sigma"]))) @@ -465,18 +475,18 @@ class DateTimeEncoder(torch.nn.Module): Input shape: scalar or any tensor shape (...) Output shape: (..., 32) — 8 frequencies × 4 components (cos/sin per signal) - Output structure for k=1..8: - [cos(2πk·doy/365.25), sin(2πk·doy/365.25), cos(k·t), sin(k·t)] + Output structure for k=1..num_frequencies: + [cos(2πk·doy_frac), sin(2πk·doy_frac), cos(2πk·tod_frac), sin(2πk·tod_frac)] where: - - doy = day of year (0-365.25) - - t = 2π·seconds_of_day/86400 (time of day in radians, UTC) + - doy_frac = day_of_year / days_in_year + - tod_frac = seconds_of_day / 86400.0 """ def __init__(self): super().__init__() self.num_frequencies = 8 - def forward(self, timestamp: np.ndarray) -> torch.Tensor: + def forward(self, timestamp: np.ndarray | np.datetime64) -> torch.Tensor: """ Encode numpy datetime64 timestamps into 32D multi-frequency calendar embeddings. @@ -489,6 +499,7 @@ def forward(self, timestamp: np.ndarray) -> torch.Tensor: # TODO: Consider adding local time encoding (e.g., using longitude) + timestamp = np.asarray(timestamp) orig_shape = timestamp.shape timestamp_flat = timestamp.reshape(-1) @@ -497,7 +508,7 @@ def forward(self, timestamp: np.ndarray) -> torch.Tensor: # --- Extract time components --- ts_int64 = timestamp_flat.astype("int64") # seconds since Unix epoch seconds_in_day = 86400.0 - seconds_of_day = (ts_int64 % int(seconds_in_day)) / seconds_in_day # [0, 1) + tod_frac = (ts_int64 % int(seconds_in_day)) / seconds_in_day # [0, 1) # --- Extract day of year --- day_np = timestamp_flat.astype("datetime64[D]") @@ -510,29 +521,51 @@ def forward(self, timestamp: np.ndarray) -> torch.Tensor: days_in_year = (next_year_start - year_start).astype(np.int64) # 365 or 366 doy_frac = day_of_year_0.astype(np.float32) / days_in_year.astype(np.float32) # [0, 1) - # --- Multi-frequency sinusoidal embeddings --- - # Build output for all 8 frequency scales - embeddings = [] - for k in range(1, self.num_frequencies + 1): - k_float = float(k) - - # Day-of-year components: cos(2π·k·doy/365.25), sin(2π·k·doy/365.25) - doy_phase = two_pi * k_float * doy_frac - doy_cos = np.cos(doy_phase).astype(np.float32) - doy_sin = np.sin(doy_phase).astype(np.float32) - - # Time-of-day components: cos(k·t), sin(k·t) where t = 2π·seconds_of_day - tot_phase = k_float * two_pi * seconds_of_day - tot_cos = np.cos(tot_phase).astype(np.float32) - tot_sin = np.sin(tot_phase).astype(np.float32) - - embeddings.append(doy_cos) - embeddings.append(doy_sin) - embeddings.append(tot_cos) - embeddings.append(tot_sin) - - # Stack all components: (N, 32) - out = np.stack(embeddings, axis=-1) + # --- Multi-frequency sinusoidal embeddings (vectorized over k) --- + k = np.arange(1, self.num_frequencies + 1, dtype=np.float32)[None, :] + doy_phase = two_pi * doy_frac[:, None] * k + tod_phase = two_pi * tod_frac[:, None] * k + + doy_cos = np.cos(doy_phase).astype(np.float32) + doy_sin = np.sin(doy_phase).astype(np.float32) + tod_cos = np.cos(tod_phase).astype(np.float32) + tod_sin = np.sin(tod_phase).astype(np.float32) + + # Stack all components: (N, K, 4) -> (N, K*4) + out = np.stack([doy_cos, doy_sin, tod_cos, tod_sin], axis=-1) + out = out.reshape(out.shape[0], self.num_frequencies * 4) out = torch.from_numpy(out).float() + breakpoint() + self.plot_embedding_heatmap(out) + breakpoint() + return out.reshape(*orig_shape, self.num_frequencies * 4) + + def plot_embedding_heatmap(self, emb: torch.Tensor) -> torch.Tensor: + """ + Compute and plot a heatmap of the date/time embedding for debugging. + + Args: + emb: torch.Tensor of shape (..., 32) containing multi-frequency embeddings + + Returns: + torch.Tensor of shape (..., 32) containing multi-frequency embeddings + """ + import matplotlib.pyplot as plt + import os + + emb_2d = emb.reshape(emb.shape[0], -1) if emb.ndim > 1 else emb.view(1, -1) + + plt.figure(figsize=(8, 4)) + plt.imshow(emb_2d.detach().cpu().numpy(), aspect="auto", cmap="viridis") + plt.colorbar(label="embedding value") + plt.xlabel("embedding dimension") + plt.ylabel("sample index") + plt.title("Date/time embedding heatmap") + plt.tight_layout() + os.makedirs("plots", exist_ok=True) + plt.savefig("plots/datetime_embedding_heatmap.png", dpi=150) + plt.close() + + return emb diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 41437f55f..066183050 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -586,6 +586,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = attention_dtype=get_dtype(self.cf.attention_dtype), with_2d_rope=self.cf.get("rope_2D", False), is_dit=self.cf.fe_diffusion_model, + dit_is_cond=self.cf.fe_diffusion_model_conditioning in ["date_time"], ) ) else: @@ -606,6 +607,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = attention_dtype=get_dtype(self.cf.attention_dtype), with_2d_rope=self.cf.get("rope_2D", False), is_dit=self.cf.fe_diffusion_model, + dit_is_cond=self.cf.fe_diffusion_model_conditioning in ["date_time"], ) ) # Add MLP block @@ -621,6 +623,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = dim_aux=dim_aux, norm_eps=self.cf.mlp_norm_eps, is_dit=self.cf.fe_diffusion_model, + dit_is_cond=self.cf.fe_diffusion_model_conditioning in ["date_time"], ) ) # Optionally, add LayerNorm after i-th layer @@ -668,10 +671,13 @@ def forward( if isinstance(block, torch.nn.LayerNorm): tokens = checkpoint(block, tokens, use_reentrant=False) else: - assert ada_ln_aux is None, ( - "ada_ln_aux should not be provided when diffusion model conditioning is disabled" - ) - tokens = checkpoint(block, tokens, coords, noise_emb, use_reentrant=False) + if self.cf.fe_diffusion_model_conditioning in ["date_time"]: + # Assuming ada_ln_aux contains the date_time embedding in this case + assert ada_ln_aux is not None, "ada_ln_aux must be provided for diffusion model conditioning" + tokens = checkpoint(block, tokens, coords, noise_emb, ada_ln_aux, use_reentrant=False) + else: + assert ada_ln_aux is None, "ada_ln_aux should not be provided when diffusion model conditioning is disabled" + tokens = checkpoint(block, tokens, coords, noise_emb, use_reentrant=False) else: for block in self.fe_blocks: if isinstance(block, torch.nn.LayerNorm): diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index b447982d0..54dcefca2 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -28,7 +28,7 @@ import torch import torch.nn as nn -from weathergen.model.norms import AdaLayerNorm, RMSNorm, SwiGLU +from weathergen.model.norms import AdaLNZero, AdaLayerNorm, RMSNorm, SwiGLU class NamedLinear(torch.nn.Module): @@ -62,6 +62,7 @@ def __init__( mlp_type="mlp", name: str | None = None, is_dit=False, + dit_is_cond=False, ): """Constructor""" @@ -75,6 +76,7 @@ def __init__( self.with_residual = with_residual self.with_aux = dim_aux is not None self.is_dit = is_dit + self.dit_is_cond = dit_is_cond self.mlp_type = mlp_type.lower() dim_hidden = int(dim_in * hidden_factor) @@ -90,9 +92,11 @@ def __init__( norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm if is_dit: - assert dim_aux is None, "conditioning not yet implemented for DIT attention" + if dit_is_cond: + assert dim_aux is not None, "For DIT, need to provide dim_aux for ada layer norm" assert with_residual, "DIT attention should always have residual connection" - self.lnorm = norm(dim_in, eps=norm_eps) + self.lnorm = AdaLNZero(dim_in, dim_aux, norm_eps=norm_eps) if dim_aux is not None else norm(dim_in, eps=norm_eps) + self.noise_conditioning = LinearNormConditioning(dim_in) self.noise_conditioning = LinearNormConditioning(dim_in) elif dim_aux is not None: self.lnorm = AdaLayerNorm(dim_in, dim_aux, norm_eps=norm_eps) @@ -151,7 +155,8 @@ def forward(self, *args): x = self.lnorm(x) cond_gate = 1 assert noise_emb is not None, "Need noise embedding for noise conditioning in DIT" - x, gate = self.noise_conditioning(x, noise_emb) + x, noise_gate = self.noise_conditioning(x, noise_emb) + gate = cond_gate * noise_gate for layer in self.layers: if isinstance(layer, AdaLayerNorm): diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 3e4676824..79021d7af 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -387,7 +387,13 @@ def create(self) -> "Model": # Initialize forecasting engine: standard or diffusion-wrapped mode_cfg = cf.training_config if cf.fe_num_blocks > 0: - self.forecast_engine = ForecastingEngine(cf, mode_cfg, self.num_healpix_cells) + if cf.get("fe_diffusion_model_conditioning", None) in ["date_time"]: + assert cf.diffusion_conditioning_embed_dim is not None, ( + "Diffusion conditioning embedding dimension must be specified when using diffusion model conditioning" + ) + self.forecast_engine = ForecastingEngine(cf, mode_cfg, self.num_healpix_cells, dim_aux=self.cf.diffusion_conditioning_embed_dim) + else: + self.forecast_engine = ForecastingEngine(cf, mode_cfg, self.num_healpix_cells) if cf.get("fe_diffusion_model", False): self.forecast_engine = DiffusionForecastEngine( cf, self.num_healpix_cells, forecast_engine=self.forecast_engine diff --git a/src/weathergen/model/norms.py b/src/weathergen/model/norms.py index 1d394057c..0526c6f90 100644 --- a/src/weathergen/model/norms.py +++ b/src/weathergen/model/norms.py @@ -60,41 +60,39 @@ def forward(self, x): output = self._norm(x.float()).type_as(x) return output * self.weight - class AdaLNZero(torch.nn.Module): - """DiT-style adaptive layer norm with zero initialization for diffusion models.""" + """ + AdaLayerNorm with zero initialization and with additional gate parameter + """ - def __init__(self, dim_embed: int, dim_aux: int, norm_eps: float = 1e-5): + def __init__( + self, dim_embed_x, dim_aux, norm_elementwise_affine: bool = False, norm_eps: float = 1e-5 + ): super().__init__() - self.norm = torch.nn.LayerNorm(dim_embed, elementwise_affine=False, eps=norm_eps) - self.gate_proj = torch.nn.Linear(dim_aux, dim_embed) - self.shift_proj = torch.nn.Linear(dim_aux, dim_embed) - - with torch.no_grad(): - self.gate_proj.weight.zero_() - self.gate_proj.bias.zero_() - self.shift_proj.weight.zero_() - self.shift_proj.bias.zero_() - - def forward(self, x: torch.Tensor, aux: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """Returns (x_normalized_and_scaled, gate_signal) for residual modulation.""" - x_norm = self.norm(x) - if aux.dim() == 0: - aux = aux.unsqueeze(0) + # simple 2-layer MLP for embedding auxiliary information + self.embed_aux = torch.nn.ModuleList() + self.embed_aux.append(torch.nn.Linear(dim_aux, 6 * dim_aux)) + self.embed_aux.append(torch.nn.SiLU()) + self.embed_aux.append(torch.nn.Linear(6 * dim_aux, 3 * dim_embed_x)) - gate_params = self.gate_proj(aux) - shift_params = self.shift_proj(aux) + self.norm = torch.nn.LayerNorm( + dim_embed_x, + eps=norm_eps, + elementwise_affine=norm_elementwise_affine, + ) + # Zero-initialize the final modulation layer. + nn.init.zeros_(self.embed_aux[-1].weight) + nn.init.zeros_(self.embed_aux[-1].bias) - while gate_params.dim() < x_norm.dim(): - gate_params = gate_params.unsqueeze(-2) - shift_params = shift_params.unsqueeze(-2) + def forward(self, x: torch.Tensor, aux: torch.Tensor | None = None) -> torch.Tensor: + for block in self.embed_aux: + aux = block(aux) + scale, shift, gate = aux.chunk(3, dim=-1) - gate = 1 + gate_params - x_out = gate * x_norm + shift_params - gate_signal = gate.mean(dim=-1, keepdim=True) + x = self.norm(x) * (1 + scale) + shift - return x_out, gate_signal + return x, gate class AdaLayerNorm(torch.nn.Module): From dd97830ea40321b33d5d49c1dd56368b680de675 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Thu, 7 May 2026 20:51:31 +0200 Subject: [PATCH 321/344] remove debugging tool --- src/weathergen/model/diffusion.py | 32 ------------------------------- 1 file changed, 32 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index fe3e411df..933dd6b1f 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -536,36 +536,4 @@ def forward(self, timestamp: np.ndarray | np.datetime64) -> torch.Tensor: out = out.reshape(out.shape[0], self.num_frequencies * 4) out = torch.from_numpy(out).float() - breakpoint() - self.plot_embedding_heatmap(out) - breakpoint() - return out.reshape(*orig_shape, self.num_frequencies * 4) - - def plot_embedding_heatmap(self, emb: torch.Tensor) -> torch.Tensor: - """ - Compute and plot a heatmap of the date/time embedding for debugging. - - Args: - emb: torch.Tensor of shape (..., 32) containing multi-frequency embeddings - - Returns: - torch.Tensor of shape (..., 32) containing multi-frequency embeddings - """ - import matplotlib.pyplot as plt - import os - - emb_2d = emb.reshape(emb.shape[0], -1) if emb.ndim > 1 else emb.view(1, -1) - - plt.figure(figsize=(8, 4)) - plt.imshow(emb_2d.detach().cpu().numpy(), aspect="auto", cmap="viridis") - plt.colorbar(label="embedding value") - plt.xlabel("embedding dimension") - plt.ylabel("sample index") - plt.title("Date/time embedding heatmap") - plt.tight_layout() - os.makedirs("plots", exist_ok=True) - plt.savefig("plots/datetime_embedding_heatmap.png", dpi=150) - plt.close() - - return emb From 77d5e0afa518d5ea721e4d4ff708793395899d18 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Thu, 7 May 2026 23:00:16 +0200 Subject: [PATCH 322/344] implement time only / day only conditioning --- config/config_diffusion.yml | 2 +- .../datasets/multi_stream_data_sampler.py | 4 ++-- src/weathergen/model/diffusion.py | 24 ++++++++++++------- src/weathergen/model/engines.py | 8 +++---- src/weathergen/model/model.py | 2 +- 5 files changed, 23 insertions(+), 17 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 971e25edc..9bb418fee 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -60,7 +60,7 @@ fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True fe_diffusion_model: True -fe_diffusion_model_conditioning: "date_time" # options: "date_time" +fe_diffusion_model_conditioning: "time" # options: "date_time" fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 5b888a9e5..cb5c5749b 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -724,7 +724,7 @@ def _get_batch(self, idx: int, num_forecast_steps: int): target_metadata = target_masks.metadata[tidx] # Get first target step's times (using self.output_offset as the first output step index) - if self.diffusion_model_conditioning == "date_time": + if self.diffusion_model_conditioning in ["date_time", "date", "time"]: target_times_array = sdata.target_times_raw[self.output_offset] target_metadata.add_params({'timestamp': ( target_times_array[0] if len(target_times_array) > 0 else None @@ -744,7 +744,7 @@ def _get_batch(self, idx: int, num_forecast_steps: int): batch = self._preprocess_model_batch(batch, source_in_steps, target_in_steps) #add target times in source for diffusion model date/time conditioning - if self.diffusion_model_conditioning == "date_time": + if self.diffusion_model_conditioning in ["date_time", "date", "time"]: #TODO: Might need upgrading fro num_samples > 1 # Assert singular source and target samples diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 933dd6b1f..cd464934f 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -50,7 +50,9 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast self.noise_embedder = NoiseEmbedder( embedding_dim=self.embedding_dim, frequency_embedding_dim=self.frequency_embedding_dim ) - self.datetime_embedder = DateTimeEncoder() + self.conditioning = self.cf.fe_diffusion_model_conditioning + if "date" in self.conditioning or "time" in self.conditioning: + self.datetime_embedder = DateTimeEncoder(self.conditioning) # Parameters self.sigma_min = self.cf.sigma_min @@ -157,7 +159,7 @@ def training_forward( self.cur_token = tokens.detach() - if self.cf.fe_diffusion_model_conditioning == "date_time": + if self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"]: c = meta_info["ERA5"].params["timestamp"] # TODO: add correct preconditioning (e.g., sample/s in previous time step, datetime encoding, etc.) else: c = None @@ -203,7 +205,7 @@ def denoise( # Precondition input and feed through network x = self.preconditioner.precondition(x, c) # currently does nothing - if self.cf.fe_diffusion_model_conditioning == "date_time": + if self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"]: c = self.datetime_embedder(c).to(x.device) return c_skip * x + c_out * self.net( @@ -236,7 +238,7 @@ def inference_forward( # Extract conditioning from meta_info (same as training_forward) c = None - if self.cf.fe_diffusion_model_conditioning == "date_time": + if self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"]: c = meta_info["ERA5"].params["timestamp"] @@ -482,9 +484,13 @@ class DateTimeEncoder(torch.nn.Module): - tod_frac = seconds_of_day / 86400.0 """ - def __init__(self): + def __init__(self, conditioning: str): super().__init__() self.num_frequencies = 8 + assert conditioning in ["date_time", "date", "time"], f"Unsupported conditioning: {conditioning}" + self.date_only = conditioning == "date" + self.time_only = conditioning == "time" + def forward(self, timestamp: np.ndarray | np.datetime64) -> torch.Tensor: """ @@ -526,10 +532,10 @@ def forward(self, timestamp: np.ndarray | np.datetime64) -> torch.Tensor: doy_phase = two_pi * doy_frac[:, None] * k tod_phase = two_pi * tod_frac[:, None] * k - doy_cos = np.cos(doy_phase).astype(np.float32) - doy_sin = np.sin(doy_phase).astype(np.float32) - tod_cos = np.cos(tod_phase).astype(np.float32) - tod_sin = np.sin(tod_phase).astype(np.float32) + doy_cos = np.cos(doy_phase).astype(np.float32) if not self.time_only else np.zeros_like(doy_phase).astype(np.float32) + doy_sin = np.sin(doy_phase).astype(np.float32) if not self.time_only else np.zeros_like(doy_phase).astype(np.float32) + tod_cos = np.cos(tod_phase).astype(np.float32) if not self.date_only else np.zeros_like(tod_phase).astype(np.float32) + tod_sin = np.sin(tod_phase).astype(np.float32) if not self.date_only else np.zeros_like(tod_phase).astype(np.float32) # Stack all components: (N, K, 4) -> (N, K*4) out = np.stack([doy_cos, doy_sin, tod_cos, tod_sin], axis=-1) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 066183050..9d0ee2daf 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -586,7 +586,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = attention_dtype=get_dtype(self.cf.attention_dtype), with_2d_rope=self.cf.get("rope_2D", False), is_dit=self.cf.fe_diffusion_model, - dit_is_cond=self.cf.fe_diffusion_model_conditioning in ["date_time"], + dit_is_cond=self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"], ) ) else: @@ -607,7 +607,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = attention_dtype=get_dtype(self.cf.attention_dtype), with_2d_rope=self.cf.get("rope_2D", False), is_dit=self.cf.fe_diffusion_model, - dit_is_cond=self.cf.fe_diffusion_model_conditioning in ["date_time"], + dit_is_cond=self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"], ) ) # Add MLP block @@ -623,7 +623,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = dim_aux=dim_aux, norm_eps=self.cf.mlp_norm_eps, is_dit=self.cf.fe_diffusion_model, - dit_is_cond=self.cf.fe_diffusion_model_conditioning in ["date_time"], + dit_is_cond=self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"], ) ) # Optionally, add LayerNorm after i-th layer @@ -671,7 +671,7 @@ def forward( if isinstance(block, torch.nn.LayerNorm): tokens = checkpoint(block, tokens, use_reentrant=False) else: - if self.cf.fe_diffusion_model_conditioning in ["date_time"]: + if self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"]: # Assuming ada_ln_aux contains the date_time embedding in this case assert ada_ln_aux is not None, "ada_ln_aux must be provided for diffusion model conditioning" tokens = checkpoint(block, tokens, coords, noise_emb, ada_ln_aux, use_reentrant=False) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 79021d7af..3a408d99c 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -387,7 +387,7 @@ def create(self) -> "Model": # Initialize forecasting engine: standard or diffusion-wrapped mode_cfg = cf.training_config if cf.fe_num_blocks > 0: - if cf.get("fe_diffusion_model_conditioning", None) in ["date_time"]: + if cf.get("fe_diffusion_model_conditioning", None) in ["date_time", "date", "time"]: assert cf.diffusion_conditioning_embed_dim is not None, ( "Diffusion conditioning embedding dimension must be specified when using diffusion model conditioning" ) From 11c9e6ae2985e0623bf3a62db623b9d065c551cf Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 13 May 2026 20:16:49 +0200 Subject: [PATCH 323/344] date_time conditioning --- config/config_diffusion.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 9bb418fee..971e25edc 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -60,7 +60,7 @@ fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True fe_diffusion_model: True -fe_diffusion_model_conditioning: "time" # options: "date_time" +fe_diffusion_model_conditioning: "date_time" # options: "date_time" fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) From b6ace25f414c79a53a355303f678e2745fe7eb99 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Tue, 19 May 2026 12:59:33 +0200 Subject: [PATCH 324/344] activate swiglu, xsa --- config/config_diffusion.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 971e25edc..5344c0784 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -83,10 +83,10 @@ healpix_level: 5 # When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) # When False: uses traditional pe_global positional encoding rope_2D: True -# mlp_type: swiglu -# use_xsa: True -mlp_type: mlp -use_xsa: False +mlp_type: swiglu +use_xsa: True +# mlp_type: mlp +# use_xsa: False with_mixed_precision: True with_flash_attention: True From 2b1ff151dc6934f36e66d5e0839f7302cca7de22 Mon Sep 17 00:00:00 2001 From: Matthias Date: Tue, 19 May 2026 15:41:43 +0200 Subject: [PATCH 325/344] Update diffusion config with more pre-trained models --- config/config_diffusion.yml | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 5567d4df1..183a3d182 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -81,10 +81,10 @@ healpix_level: 5 # When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) # When False: uses traditional pe_global positional encoding rope_2D: True -# mlp_type: swiglu -# use_xsa: True -mlp_type: mlp -use_xsa: False +mlp_type: swiglu +use_xsa: True +# mlp_type: mlp +# use_xsa: False with_mixed_precision: True with_flash_attention: True @@ -102,7 +102,10 @@ latent_noise_use_additive_noise: False latent_noise_deterministic_latents: True -freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*fe.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +freeze_modules: "" # load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 # load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} # z500 d2048 hl5, sigma_data=159.08 # load_chkpt: {'run_id': 'tvkicam9', 'epoch': -1} # z500 d2048 hl3 enc-lnorm, sigma_data=1.0 @@ -120,7 +123,9 @@ freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_to # load_chkpt: {'run_id': 'y1gu5md8', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, diffusion-full-pipeline # load_chkpt: {'run_id': 'mal6u4gc', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 64 epochs, diffusion-full-pipeline # load_chkpt: {'run_id': 'zrpncqb0', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 196 epochs, diffusion-full-pipeline -load_chkpt: {'run_id': 'm6fs8wvj', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, swiglu xsa, diffusion-full-pipeline +load_chkpt: {'run_id': 'm6fs8wvj', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'cgxt9imf', 'epoch': -1} # diffusion model to fine-tune decoder, p_mean=0.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'wo5mf2z4', 'epoch': -1} # diffusion model to fine-tune decoder, p_mean=1.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone # load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 # load_chkpt: {'run_id': 'mivw6jda', 'epoch': -1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 From 4dca300dc785df737382961373cf517d09f8f7a8 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Tue, 19 May 2026 19:04:34 +0200 Subject: [PATCH 326/344] initial commit with data flow for the forecast conditioning (conditioning not implemented) --- CLAUDE.md | 109 ++++++++++++++++++ config/config_diffusion.yml | 4 +- .../datasets/multi_stream_data_sampler.py | 67 ++++++++--- src/weathergen/model/diffusion.py | 18 ++- .../loss_module_latent_diffusion.py | 4 +- .../train/target_and_aux_diffusion.py | 32 ++++- .../train/target_and_aux_module_base.py | 3 + src/weathergen/train/trainer.py | 13 ++- 8 files changed, 220 insertions(+), 30 deletions(-) create mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..23ccb7dbf --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,109 @@ +# WeatherGenerator — Claude Code Guide + +## Running training + +```bash +uv run train --base-config config/config_diffusion.yml +``` + +Request an interactive GPU node first with `agpu`, then run the command above. + +## Debugging training failures + +### General approach + +1. Run training and capture the full traceback — the first error is usually the root cause; later errors are often cascades. +2. Read the file at the crashing line before editing anything. +3. Verify the fix in isolation if possible (small unit test or `python -c "..."`) before re-running the full job. + +### Common error patterns + +#### `KeyError` / `AttributeError` in loss modules + +`preds.latent` is a list indexed by output step. The shape depends on `output_offset` and `num_steps` in the forecast config: + +- Index 0 may hold `{"posteriors": ...}` (encoder posteriors) rather than a latent state. +- Always guard with `.get("key") is not None` instead of `if pl:` — a non-empty posteriors dict is truthy but does not contain `"latent_state"`. + +```python +# safe +pred_tokens_all = [pl["latent_state"].z_pre_norm for pl in preds.latent if pl.get("latent_state") is not None] +``` + +#### `TypeError: unexpected keyword argument` in data sampler + +`MultiStreamDataSampler._build_stream_data` has several call sites. If you add a parameter, check every call site — the method signature and each caller must agree. + +#### `AssertionError: ada_ln_aux should not be provided when diffusion model conditioning is disabled` + +`ForecastingEngine.forward()` in `fe_diffusion_model=True` mode asserts `ada_ln_aux is None`. Pass `ada_ln_aux=None` in `DiffusionForecastEngine.denoise()` until conditioning is wired into the network blocks. + +#### `bdb.BdbQuit` — process exits silently + +A `breakpoint()` call was left in the code. Running non-interactively (batch job, subprocess) causes Python's debugger to immediately quit the process. Search for and remove stray `breakpoint()` calls before submitting jobs: + +```bash +grep -rn "breakpoint()" src/ +``` + +### Key data flow: DiffusionForecastEngine + +``` +MultiStreamDataSampler._get_batch() + → source_batch (X_t in source_tokens_cells) + → target_batch (X_{t+1} in source_tokens_cells, when mode=="diffusion_forecast") + +trainer.train(): + 1. target_aux.pre_compute(source_batch, target_batch) ← runs BEFORE model.forward + encodes X_{t+1} via frozen encoder + writes tokens into source_batch.samples[0].meta_info["ERA5"].params["diffusion_target_tokens"] + 2. preds = model(source_batch) + DiffusionForecastEngine.training_forward(): + y = meta_info["ERA5"].params["diffusion_target_tokens"] # X_{t+1} + c = tokens # X_t (conditioning) + adds EDM noise → calls denoise(x=y+n, c=c, sigma) + 3. targets_and_auxs = target_aux.compute(target_batch) + reuses _pending_tokens set by pre_compute (no second encoder pass) + returns diffusion_latent = encoded X_{t+1} + 4. loss(preds, targets_and_auxs) + compares denoised prediction against diffusion_latent +``` + +### Channel/normalization mismatch trap + +`source_exclude` and `target_exclude` differ for ERA5 streams (`skt` vs `slor`/`sdor`). Never reuse `output_data` (target-normalized, target-channels) as a drop-in for source input. When building X_{t+1} as a source-side input, collect it explicitly as `"source"` type: + +```python +future_input_data = [collect_datasources(stream_ds, idx + step_delta, "source", self.rng)] +``` + +### Config levers for the diffusion training mode + +In `config/config_diffusion.yml`: + +```yaml +training_mode: ["masking", "diffusion_forecast"] # enables X_{t+1} target conditioning +num_steps: 1 +offset: 1 # target batch is one step ahead of source batch +``` + +With `offset: 1, num_steps: 1`: `output_steps=2`, `output_idxs=[1]`. The posteriors slot lives at index 0 of `preds.latent`; the diffusion latent state is at index 1. + +## Project layout (relevant to training) + +``` +src/weathergen/ + model/ + diffusion.py # DiffusionForecastEngine — training_forward, denoise, inference_forward + engines.py # ForecastingEngine (the underlying transformer) + train/ + trainer.py # main training loop; pre_compute hook lives here + target_and_aux_diffusion.py # DiffusionLatentTargetEncoder (frozen encoder + pre_compute) + target_and_aux_module_base.py # base class with no-op pre_compute + loss_modules/ + loss_module_latent_diffusion.py # latent-space EDM loss + datasets/ + multi_stream_data_sampler.py # batch construction; diffusion_forecast mode here +config/ + config_diffusion.yml # main config for diffusion experiments +``` diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 5344c0784..0537edba6 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -182,7 +182,7 @@ data_loading : training_config: # training_mode: "masking", "student_teacher", "latent_loss" - training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] + training_mode: ["masking","diffusion_forecast"] num_mini_epochs: 128 samples_per_mini_epoch: 4096 @@ -280,7 +280,7 @@ validation_config: } # run validation before training starts (mainly for model development) - validate_before_training: True + validate_before_training: False # test config; full test config is merge of validation and test config diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index cb5c5749b..dad0e052a 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -484,6 +484,7 @@ def _build_stream_data( output_tokens: list, output_mask, input_mask, + input_base_idx: TIndex = None, ) -> StreamData: """ Return one batch of data @@ -499,7 +500,10 @@ def _build_stream_data( output_mask : mask for output/prediction/target input_mask : mask for network input (can be source or target) - + input_base_idx: Override time index for the input section's time-window + lookup. When None, falls back to base_idx. Pass a shifted index + (e.g. base_idx + output_offset) together with future source data + to make source_tokens_cells encode a future timestep. Returns: StreamData with source and targets masked according to view_meta @@ -516,7 +520,7 @@ def _build_stream_data( stream_data = self._build_stream_data_input( modes, stream_data, - base_idx, + input_base_idx if input_base_idx is not None else base_idx, stream_info, num_steps_input, input_data, @@ -646,6 +650,12 @@ def _get_batch(self, idx: int, num_forecast_steps: int): if "student_teacher" in mode or "latent_loss" in mode: source_select += ["network_input"] target_select += ["network_input"] + if "diffusion_forecast" in mode: + # Like student_teacher but target samples encode X_{t+1} instead of X_t. + # _build_stream_data is called with output_data as input so source_tokens_cells + # holds the future state; DiffusionLatentTargetEncoder then encodes X_{t+1}. + source_select += ["network_input", "target_coords"] + target_select += ["network_input"] # remove duplicates source_select, target_select = list(set(source_select)), list(set(target_select)) if len(source_select) == 0 or len(target_select) == 0: @@ -708,19 +718,46 @@ def _get_batch(self, idx: int, num_forecast_steps: int): for tidx, target_mask in enumerate(target_masks.masks): # depending on the mode, the the streamdata obj to have the target mask applied to # the inputs. Hence the target mask is also the source mask here. - sdata = self._build_stream_data( - target_select, - idx, - num_forecast_steps, - stream_info, - target_masks.metadata[tidx].params.get("num_steps_input", 1), - input_data, - output_data, - input_tokens, - output_tokens, - output_mask=target_mask, - input_mask=target_mask, - ) + if "diffusion_forecast" in mode: + # Shift the input window forward by output_offset steps so that + # source_tokens_cells encodes X_{t+1} rather than X_t. + # Collect X_{t+1} as SOURCE (not reusing output_data which uses target + # channels/normalization — source and target channel sets differ). + step_delta = (self.time_step * self.output_offset) // self.step_timedelta + future_input_data = [ + collect_datasources(stream_ds, idx + step_delta, "source", self.rng) + ] + future_input_tokens = self.tokenizer.get_tokens_windows( + stream_info, future_input_data, True + ) + sdata = self._build_stream_data( + target_select, + idx, + num_forecast_steps, + stream_info, + 1, + future_input_data, + output_data, + future_input_tokens, + output_tokens, + output_mask=target_mask, + input_mask=target_mask, + input_base_idx=idx + step_delta, + ) + else: + sdata = self._build_stream_data( + target_select, + idx, + num_forecast_steps, + stream_info, + target_masks.metadata[tidx].params.get("num_steps_input", 1), + input_data, + output_data, + input_tokens, + output_tokens, + output_mask=target_mask, + input_mask=target_mask, + ) target_metadata = target_masks.metadata[tidx] # Get first target step's times (using self.output_offset as the first output step index) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index cd464934f..ba6891a6f 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -161,10 +161,15 @@ def training_forward( if self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"]: c = meta_info["ERA5"].params["timestamp"] # TODO: add correct preconditioning (e.g., sample/s in previous time step, datetime encoding, etc.) - else: - c = None - - y = tokens + y = meta_info["ERA5"].params.get("diffusion_target_tokens") + elif self.cf.fe_diffusion_model_conditioning in ["forecast"]: + # In diffusion_forecast mode, meta_info carries X_{t+1} tokens written by + # DiffusionLatentTargetEncoder.pre_compute(). Noise X_{t+1} as the target and + # use X_t (the incoming rollout state) as the conditioning signal c. + # Falls back to denoising-autoencoder behaviour when no target is cached + # (e.g. validation or unconditional pre-training). + y = meta_info["ERA5"].params.get("diffusion_target_tokens") + c = tokens # X_t as conditioning if self.training: eta = torch.tensor([meta_info["ERA5"].params["noise_level_rn"]], device=tokens.device) @@ -208,8 +213,11 @@ def denoise( if self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"]: c = self.datetime_embedder(c).to(x.device) + # c (X_t conditioning) is not yet wired into the network blocks — ada_ln_aux + # is unsupported in fe_diffusion_model mode. c is passed through Preconditioner + # and available for future architectural integration (e.g. concat or cross-attn). return c_skip * x + c_out * self.net( - c_in * x, fstep=fstep, coords=coords, noise_emb=noise_emb, ada_ln_aux=c + c_in * x, fstep=fstep, coords=coords, noise_emb=noise_emb, ada_ln_aux=None ) # Eq. (7) in EDM paper def inference_forward( diff --git a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py index a89c9d6b8..35285ae85 100644 --- a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py +++ b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py @@ -94,8 +94,8 @@ def compute_loss(self, preds: dict, targets: dict, **kwargs) -> LossValues: for _, _, loss_fct_name in self.loss_fcts } - pred_tokens_all = [pl["latent_state"].z_pre_norm for pl in preds.latent if pl] - target_tokens_all = [latent["diffusion_latent"] for latent in targets.latent if latent] + pred_tokens_all = [pl["latent_state"].z_pre_norm for pl in preds.latent if pl.get("latent_state") is not None] + target_tokens_all = [latent["diffusion_latent"] for latent in targets.latent if latent.get("diffusion_latent") is not None] eta = torch.tensor( [targets.aux_outputs["noise_level_rn"]], device=self.device, dtype=torch.float32 diff --git a/src/weathergen/train/target_and_aux_diffusion.py b/src/weathergen/train/target_and_aux_diffusion.py index eb0f905fb..08b67e3ad 100644 --- a/src/weathergen/train/target_and_aux_diffusion.py +++ b/src/weathergen/train/target_and_aux_diffusion.py @@ -21,6 +21,7 @@ def __init__(self, encoder, is_model_sharded=True): self.is_model_sharded = is_model_sharded self._fixed_noise_level: float | None = None + self._pending_tokens: torch.Tensor | None = None # Build a name → param map once self.src_params = dict(self.encoder.named_parameters()) @@ -47,6 +48,23 @@ def update_state_post_opt_step(self, istep, batch, model, **kwargs) -> None: if self.is_model_sharded: self.encoder.reshard() + def pre_compute(self, istep, source_batch, target_batch, model_params, model, **kwargs) -> None: + """ + Encode the target batch (whose source_tokens_cells holds X_{t+1} in + diffusion_forecast mode) before model.forward() so that training_forward + can noise X_{t+1} and condition on X_t. + + Stores encoded tokens in: + - self._pending_tokens → reused by compute() to skip a second encoder pass + - source_batch.samples[0].meta_info["ERA5"].params["diffusion_target_tokens"] + → flows to DiffusionForecastEngine.training_forward via model.forward's meta_info + """ + with torch.no_grad(): + self.encoder.encoder.eval() + tokens, _ = self.encoder.encoder(model_params=model_params, batch=target_batch) + self._pending_tokens = tokens + source_batch.samples[0].meta_info["ERA5"].params["diffusion_target_tokens"] = tokens + def compute( self, istep: int, @@ -66,11 +84,15 @@ def compute( else: noise_level_rn = self._fixed_noise_level if self._fixed_noise_level is not None else 0.0 - # TODO: check if there are scenarios where the encoder needs to be set to eval - with torch.no_grad(): - self.encoder.encoder.eval() # NOTE: might be redundant - tokens, posteriors = self.encoder.encoder(model_params=model_params, batch=batch) - # NOTE: must not set to train afterwards unless it was already in train + # Reuse tokens from pre_compute when available (avoids a second encoder pass). + # Falls back to encoding the batch directly (e.g. during validation). + if self._pending_tokens is not None: + tokens = self._pending_tokens + self._pending_tokens = None + else: + with torch.no_grad(): + self.encoder.encoder.eval() # NOTE: might be redundant + tokens, _ = self.encoder.encoder(model_params=model_params, batch=batch) output_idxs = batch.get_output_idxs() assert len(output_idxs) > 0 diff --git a/src/weathergen/train/target_and_aux_module_base.py b/src/weathergen/train/target_and_aux_module_base.py index 88f9dc93d..dcdabeef2 100644 --- a/src/weathergen/train/target_and_aux_module_base.py +++ b/src/weathergen/train/target_and_aux_module_base.py @@ -69,6 +69,9 @@ def __init__(self, cf, model, **kwargs): def reset(self): pass + def pre_compute(self, istep, source_batch, target_batch, model_params, model, **kwargs) -> None: + pass + def update_state_pre_backward(self, istep, batch, model, **kwargs) -> None: pass diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 25f62309f..58eb6fd4b 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -486,9 +486,20 @@ def train(self, mini_epoch): dtype=self.mixed_precision_dtype, enabled=cf.with_mixed_precision, ): + source_samples = batch.get_source_samples() + for loss_name, target_aux in self.target_and_aux_calculators.items(): + target_idxs = get_target_idxs_from_cfg(self.training_cfg, loss_name) + target_aux.pre_compute( + self.cf.general.istep, + source_samples, + batch.get_target_samples(target_idxs), + self.model_params, + self.model, + ) + preds = self.model( self.model_params, - batch.get_source_samples(), + source_samples, ) targets_and_auxs = {} From d2e7b5d12d90f3654ce2518edca617df0e76cbf1 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Tue, 19 May 2026 19:20:09 +0200 Subject: [PATCH 327/344] offset 1 --- config/config_diffusion.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 0537edba6..0f323cbaf 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -245,7 +245,7 @@ training_config: forecast : time_step: 06:00:00 num_steps: 1 - offset: 0 + offset: 1 policy: "fixed" From fde12303f34e5b4bb2342a368a4eb47443845494 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Tue, 19 May 2026 19:43:29 +0200 Subject: [PATCH 328/344] bug fix from merge --- config/config_diffusion.yml | 2 +- src/weathergen/model/diffusion.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 0f323cbaf..85d2fbd7c 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -60,7 +60,7 @@ fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True fe_diffusion_model: True -fe_diffusion_model_conditioning: "date_time" # options: "date_time" +fe_diffusion_model_conditioning: "forecast" # options: "date_time", "time", "forecast" fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index ba6891a6f..f28d6449f 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -216,8 +216,9 @@ def denoise( # c (X_t conditioning) is not yet wired into the network blocks — ada_ln_aux # is unsupported in fe_diffusion_model mode. c is passed through Preconditioner # and available for future architectural integration (e.g. concat or cross-attn). + ada_ln_aux = c if self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"] else None return c_skip * x + c_out * self.net( - c_in * x, fstep=fstep, coords=coords, noise_emb=noise_emb, ada_ln_aux=None + c_in * x, fstep=fstep, coords=coords, noise_emb=noise_emb, ada_ln_aux=ada_ln_aux ) # Eq. (7) in EDM paper def inference_forward( From 3f897699a388044f3bcd4e06dc85527bf5124226 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 20 May 2026 09:19:23 +0200 Subject: [PATCH 329/344] change ada_ln argument passing --- config/config_diffusion.yml | 1 + src/weathergen/model/diffusion.py | 14 ++++++++++---- src/weathergen/model/engines.py | 8 ++++---- src/weathergen/model/model.py | 2 +- 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 85d2fbd7c..de9d60f28 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -61,6 +61,7 @@ fe_dropout_rate: 0.1 fe_with_qk_lnorm: True fe_diffusion_model: True fe_diffusion_model_conditioning: "forecast" # options: "date_time", "time", "forecast" +fe_diffusion_model_conditioning_type: "ada_ln" #options: "ada_ln", "cross_attn", "add" fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index f28d6449f..adce59e17 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -51,6 +51,15 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast embedding_dim=self.embedding_dim, frequency_embedding_dim=self.frequency_embedding_dim ) self.conditioning = self.cf.fe_diffusion_model_conditioning + self.conditioning_type = self.cf.get("fe_diffusion_model_conditioning_type", None) + + _date_time_modes = {"date_time", "date", "time"} + assert self.conditioning not in _date_time_modes or self.conditioning_type == "ada_ln", ( + f"fe_diffusion_model_conditioning_type must be 'ada_ln' when " + f"fe_diffusion_model_conditioning is '{self.conditioning}' " + f"(got '{self.conditioning_type}')" + ) + if "date" in self.conditioning or "time" in self.conditioning: self.datetime_embedder = DateTimeEncoder(self.conditioning) @@ -213,10 +222,7 @@ def denoise( if self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"]: c = self.datetime_embedder(c).to(x.device) - # c (X_t conditioning) is not yet wired into the network blocks — ada_ln_aux - # is unsupported in fe_diffusion_model mode. c is passed through Preconditioner - # and available for future architectural integration (e.g. concat or cross-attn). - ada_ln_aux = c if self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"] else None + ada_ln_aux = c if self.conditioning_type == "ada_ln" else None return c_skip * x + c_out * self.net( c_in * x, fstep=fstep, coords=coords, noise_emb=noise_emb, ada_ln_aux=ada_ln_aux ) # Eq. (7) in EDM paper diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 9d0ee2daf..3d5112ff1 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -586,7 +586,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = attention_dtype=get_dtype(self.cf.attention_dtype), with_2d_rope=self.cf.get("rope_2D", False), is_dit=self.cf.fe_diffusion_model, - dit_is_cond=self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"], + dit_is_cond=self.cf.get("fe_diffusion_model_conditioning_type", None) == "ada_ln", ) ) else: @@ -607,7 +607,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = attention_dtype=get_dtype(self.cf.attention_dtype), with_2d_rope=self.cf.get("rope_2D", False), is_dit=self.cf.fe_diffusion_model, - dit_is_cond=self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"], + dit_is_cond=self.cf.get("fe_diffusion_model_conditioning_type", None) == "ada_ln", ) ) # Add MLP block @@ -623,7 +623,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = dim_aux=dim_aux, norm_eps=self.cf.mlp_norm_eps, is_dit=self.cf.fe_diffusion_model, - dit_is_cond=self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"], + dit_is_cond=self.cf.get("fe_diffusion_model_conditioning_type", None) == "ada_ln", ) ) # Optionally, add LayerNorm after i-th layer @@ -671,7 +671,7 @@ def forward( if isinstance(block, torch.nn.LayerNorm): tokens = checkpoint(block, tokens, use_reentrant=False) else: - if self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"]: + if self.cf.get("fe_diffusion_model_conditioning_type", None) == "ada_ln": # Assuming ada_ln_aux contains the date_time embedding in this case assert ada_ln_aux is not None, "ada_ln_aux must be provided for diffusion model conditioning" tokens = checkpoint(block, tokens, coords, noise_emb, ada_ln_aux, use_reentrant=False) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 3a408d99c..1442e9cf4 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -387,7 +387,7 @@ def create(self) -> "Model": # Initialize forecasting engine: standard or diffusion-wrapped mode_cfg = cf.training_config if cf.fe_num_blocks > 0: - if cf.get("fe_diffusion_model_conditioning", None) in ["date_time", "date", "time"]: + if cf.get("fe_diffusion_model_conditioning_type", None) == "ada_ln": assert cf.diffusion_conditioning_embed_dim is not None, ( "Diffusion conditioning embedding dimension must be specified when using diffusion model conditioning" ) From d74029d0e6e385408d63938de1a68898ca619fbc Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 20 May 2026 09:45:58 +0200 Subject: [PATCH 330/344] naive implementation of conditioning via concatenation --- config/config_diffusion.yml | 2 +- src/weathergen/model/diffusion.py | 30 +++++++++++++++++++++++++++--- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index de9d60f28..ddea20d98 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -61,7 +61,7 @@ fe_dropout_rate: 0.1 fe_with_qk_lnorm: True fe_diffusion_model: True fe_diffusion_model_conditioning: "forecast" # options: "date_time", "time", "forecast" -fe_diffusion_model_conditioning_type: "ada_ln" #options: "ada_ln", "cross_attn", "add" +fe_diffusion_model_conditioning_type: "concat" #options: "ada_ln", "concat" fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index adce59e17..bd2b60a3d 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -59,6 +59,20 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast f"fe_diffusion_model_conditioning is '{self.conditioning}' " f"(got '{self.conditioning_type}')" ) + assert self.conditioning != "forecast" or self.conditioning_type == "concat", ( + f"fe_diffusion_model_conditioning_type must be 'concat' when " + f"fe_diffusion_model_conditioning is 'forecast' " + f"(got '{self.conditioning_type}')" + ) + + if self.conditioning_type == "concat": + D = self.cf.ae_global_dim_embed + self.cond_proj = torch.nn.Linear(2 * D, D, bias=False) + # Warm-start: pass the noisy input through unchanged, ignore conditioning initially. + # The network can then gradually learn to use the conditioning signal. + with torch.no_grad(): + self.cond_proj.weight[:, :D] = torch.eye(D) + self.cond_proj.weight[:, D:] = 0.0 if "date" in self.conditioning or "time" in self.conditioning: self.datetime_embedder = DateTimeEncoder(self.conditioning) @@ -222,9 +236,17 @@ def denoise( if self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"]: c = self.datetime_embedder(c).to(x.device) + # Build network input depending on conditioning type. + # "concat": channel-concatenate enc(X_t) with the preconditioned noisy input and + # project back to D via a learned linear layer (GenCast-style). + # "ada_ln": pass conditioning through ada_ln_aux into DiT AdaLN blocks. + net_input = c_in * x + if self.conditioning_type == "concat" and c is not None: + net_input = self.cond_proj(torch.cat([net_input, c], dim=-1)) + ada_ln_aux = c if self.conditioning_type == "ada_ln" else None return c_skip * x + c_out * self.net( - c_in * x, fstep=fstep, coords=coords, noise_emb=noise_emb, ada_ln_aux=ada_ln_aux + net_input, fstep=fstep, coords=coords, noise_emb=noise_emb, ada_ln_aux=ada_ln_aux ) # Eq. (7) in EDM paper def inference_forward( @@ -250,11 +272,13 @@ def inference_forward( torch.Tensor: Generated sample of shape (1, num_healpix_cells, ae_global_dim_embed) """ - # Extract conditioning from meta_info (same as training_forward) + # Extract conditioning (mirrors training_forward). c = None - if self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"]: c = meta_info["ERA5"].params["timestamp"] + elif self.cf.fe_diffusion_model_conditioning == "forecast": + # cur_token = enc(X_t) stored in forward() before routing to inference_forward + c = self.cur_token # Sample pure noise (assuming single batch element for now) From 8a1b698a6dabfa5b96ea03ef3517d251334a01f1 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 20 May 2026 10:54:27 +0200 Subject: [PATCH 331/344] remove CLAUDE.md --- CLAUDE.md | 109 ------------------------------------------------------ 1 file changed, 109 deletions(-) delete mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index 23ccb7dbf..000000000 --- a/CLAUDE.md +++ /dev/null @@ -1,109 +0,0 @@ -# WeatherGenerator — Claude Code Guide - -## Running training - -```bash -uv run train --base-config config/config_diffusion.yml -``` - -Request an interactive GPU node first with `agpu`, then run the command above. - -## Debugging training failures - -### General approach - -1. Run training and capture the full traceback — the first error is usually the root cause; later errors are often cascades. -2. Read the file at the crashing line before editing anything. -3. Verify the fix in isolation if possible (small unit test or `python -c "..."`) before re-running the full job. - -### Common error patterns - -#### `KeyError` / `AttributeError` in loss modules - -`preds.latent` is a list indexed by output step. The shape depends on `output_offset` and `num_steps` in the forecast config: - -- Index 0 may hold `{"posteriors": ...}` (encoder posteriors) rather than a latent state. -- Always guard with `.get("key") is not None` instead of `if pl:` — a non-empty posteriors dict is truthy but does not contain `"latent_state"`. - -```python -# safe -pred_tokens_all = [pl["latent_state"].z_pre_norm for pl in preds.latent if pl.get("latent_state") is not None] -``` - -#### `TypeError: unexpected keyword argument` in data sampler - -`MultiStreamDataSampler._build_stream_data` has several call sites. If you add a parameter, check every call site — the method signature and each caller must agree. - -#### `AssertionError: ada_ln_aux should not be provided when diffusion model conditioning is disabled` - -`ForecastingEngine.forward()` in `fe_diffusion_model=True` mode asserts `ada_ln_aux is None`. Pass `ada_ln_aux=None` in `DiffusionForecastEngine.denoise()` until conditioning is wired into the network blocks. - -#### `bdb.BdbQuit` — process exits silently - -A `breakpoint()` call was left in the code. Running non-interactively (batch job, subprocess) causes Python's debugger to immediately quit the process. Search for and remove stray `breakpoint()` calls before submitting jobs: - -```bash -grep -rn "breakpoint()" src/ -``` - -### Key data flow: DiffusionForecastEngine - -``` -MultiStreamDataSampler._get_batch() - → source_batch (X_t in source_tokens_cells) - → target_batch (X_{t+1} in source_tokens_cells, when mode=="diffusion_forecast") - -trainer.train(): - 1. target_aux.pre_compute(source_batch, target_batch) ← runs BEFORE model.forward - encodes X_{t+1} via frozen encoder - writes tokens into source_batch.samples[0].meta_info["ERA5"].params["diffusion_target_tokens"] - 2. preds = model(source_batch) - DiffusionForecastEngine.training_forward(): - y = meta_info["ERA5"].params["diffusion_target_tokens"] # X_{t+1} - c = tokens # X_t (conditioning) - adds EDM noise → calls denoise(x=y+n, c=c, sigma) - 3. targets_and_auxs = target_aux.compute(target_batch) - reuses _pending_tokens set by pre_compute (no second encoder pass) - returns diffusion_latent = encoded X_{t+1} - 4. loss(preds, targets_and_auxs) - compares denoised prediction against diffusion_latent -``` - -### Channel/normalization mismatch trap - -`source_exclude` and `target_exclude` differ for ERA5 streams (`skt` vs `slor`/`sdor`). Never reuse `output_data` (target-normalized, target-channels) as a drop-in for source input. When building X_{t+1} as a source-side input, collect it explicitly as `"source"` type: - -```python -future_input_data = [collect_datasources(stream_ds, idx + step_delta, "source", self.rng)] -``` - -### Config levers for the diffusion training mode - -In `config/config_diffusion.yml`: - -```yaml -training_mode: ["masking", "diffusion_forecast"] # enables X_{t+1} target conditioning -num_steps: 1 -offset: 1 # target batch is one step ahead of source batch -``` - -With `offset: 1, num_steps: 1`: `output_steps=2`, `output_idxs=[1]`. The posteriors slot lives at index 0 of `preds.latent`; the diffusion latent state is at index 1. - -## Project layout (relevant to training) - -``` -src/weathergen/ - model/ - diffusion.py # DiffusionForecastEngine — training_forward, denoise, inference_forward - engines.py # ForecastingEngine (the underlying transformer) - train/ - trainer.py # main training loop; pre_compute hook lives here - target_and_aux_diffusion.py # DiffusionLatentTargetEncoder (frozen encoder + pre_compute) - target_and_aux_module_base.py # base class with no-op pre_compute - loss_modules/ - loss_module_latent_diffusion.py # latent-space EDM loss - datasets/ - multi_stream_data_sampler.py # batch construction; diffusion_forecast mode here -config/ - config_diffusion.yml # main config for diffusion experiments -``` From 1844cbf32ae0d2e0ef730133989d9549f009e118 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 20 May 2026 17:29:39 +0200 Subject: [PATCH 332/344] implemented cross-attn in fe engine --- config/config_diffusion.yml | 6 +++--- src/weathergen/model/attention.py | 23 +++++++++++++++++++---- src/weathergen/model/diffusion.py | 7 ++++--- src/weathergen/model/engines.py | 24 +++++++++++++++++++++++- src/weathergen/train/trainer.py | 15 +++++++++++++-- 5 files changed, 62 insertions(+), 13 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index ddea20d98..3b9db7f4e 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -61,7 +61,7 @@ fe_dropout_rate: 0.1 fe_with_qk_lnorm: True fe_diffusion_model: True fe_diffusion_model_conditioning: "forecast" # options: "date_time", "time", "forecast" -fe_diffusion_model_conditioning_type: "concat" #options: "ada_ln", "concat" +fe_diffusion_model_conditioning_type: "concat" #options: "ada_ln", "concat", "cross_attn" fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) @@ -186,7 +186,7 @@ training_config: training_mode: ["masking","diffusion_forecast"] num_mini_epochs: 128 - samples_per_mini_epoch: 4096 + samples_per_mini_epoch: 128 shuffle: True start_date: 1979-01-01T00:00 @@ -281,7 +281,7 @@ validation_config: } # run validation before training starts (mainly for model development) - validate_before_training: False + validate_before_training: True # test config; full test config is merge of validation and test config diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index f431f51d8..e9e679aac 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -716,12 +716,14 @@ def __init__( qk_norm_type=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + is_dit=False, ): super(MultiCrossAttentionHead, self).__init__() self.num_heads = num_heads self.with_residual = with_residual self.with_flash = with_flash + self.is_dit = is_dit if norm_type == "LayerNorm": norm = partial(torch.nn.LayerNorm, elementwise_affine=False, eps=norm_eps) @@ -731,7 +733,14 @@ def __init__( assert dim_embed_q % num_heads == 0 self.dim_head_proj = dim_embed_q // num_heads if dim_head_proj is None else dim_head_proj - self.lnorm_in_q = norm(dim_embed_q, eps=norm_eps) + if is_dit: + assert with_residual + self.lnorm_in_q = norm(dim_embed_q, eps=norm_eps) + self.noise_conditioning = LinearNormConditioning( + latent_space_dim=dim_embed_q, dtype=attention_dtype + ) + else: + self.lnorm_in_q = norm(dim_embed_q, eps=norm_eps) self.lnorm_in_kv = norm(dim_embed_kv, eps=norm_eps) self.proj_heads_q = torch.nn.Linear(dim_embed_q, num_heads * self.dim_head_proj, bias=False) @@ -761,10 +770,16 @@ def __init__( self.softmax = torch.nn.Softmax(dim=-1) ######################################### - def forward(self, x_q, x_kv): + def forward(self, x_q, x_kv, emb=None): if self.with_residual: x_q_in = x_q - x_q, x_kv = self.lnorm_in_q(x_q), self.lnorm_in_kv(x_kv) + + if self.is_dit: + x_q = self.lnorm_in_q(x_q) + x_q, gate = self.noise_conditioning(x_q, emb) + else: + x_q = self.lnorm_in_q(x_q) + x_kv = self.lnorm_in_kv(x_kv) # project onto heads and q,k,v and # ensure these are 4D tensors as required for flash attention @@ -780,6 +795,6 @@ def forward(self, x_q, x_kv): outs = self.dropout(self.proj_out(outs.flatten(-2, -1))) if self.with_residual: - outs = x_q_in + outs + outs = x_q_in + outs * gate if self.is_dit else x_q_in + outs return outs diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index bd2b60a3d..882d369b1 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -59,8 +59,8 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast f"fe_diffusion_model_conditioning is '{self.conditioning}' " f"(got '{self.conditioning_type}')" ) - assert self.conditioning != "forecast" or self.conditioning_type == "concat", ( - f"fe_diffusion_model_conditioning_type must be 'concat' when " + assert self.conditioning != "forecast" or self.conditioning_type in {"concat", "cross_attn"}, ( + f"fe_diffusion_model_conditioning_type must be 'concat' or 'cross_attn' when " f"fe_diffusion_model_conditioning is 'forecast' " f"(got '{self.conditioning_type}')" ) @@ -245,8 +245,9 @@ def denoise( net_input = self.cond_proj(torch.cat([net_input, c], dim=-1)) ada_ln_aux = c if self.conditioning_type == "ada_ln" else None + x_kv = c if self.conditioning_type == "cross_attn" else None return c_skip * x + c_out * self.net( - net_input, fstep=fstep, coords=coords, noise_emb=noise_emb, ada_ln_aux=ada_ln_aux + net_input, fstep=fstep, coords=coords, noise_emb=noise_emb, ada_ln_aux=ada_ln_aux, x_kv=x_kv ) # Eq. (7) in EDM paper def inference_forward( diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 3d5112ff1..511502b5c 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -17,6 +17,7 @@ from weathergen.common.config import Config from weathergen.datasets.batch import SampleMetaData from weathergen.model.attention import ( + MultiCrossAttentionHead, MultiCrossAttentionHeadVarlen, MultiCrossAttentionHeadVarlenSlicedQ, MultiSelfAttentionHead, @@ -610,6 +611,24 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = dit_is_cond=self.cf.get("fe_diffusion_model_conditioning_type", None) == "ada_ln", ) ) + # Add cross-attention block (Q=noised tokens, KV=enc(X_t)) for cross_attn conditioning + if self.cf.get("fe_diffusion_model_conditioning_type") == "cross_attn": + self.fe_blocks.append( + MultiCrossAttentionHead( + dim_embed_q=self.cf.ae_global_dim_embed, + dim_embed_kv=self.cf.ae_global_dim_embed, + num_heads=self.cf.fe_num_heads, + dropout_rate=self.cf.fe_dropout_rate, + with_residual=True, + with_qk_lnorm=self.cf.fe_with_qk_lnorm, + with_flash=self.cf.with_flash_attention, + norm_type=self.cf.norm_type, + qk_norm_type=self.cf.get("qk_norm_type", self.cf.norm_type), + norm_eps=self.cf.norm_eps, + attention_dtype=get_dtype(self.cf.attention_dtype), + is_dit=self.cf.fe_diffusion_model, + ) + ) # Add MLP block self.fe_blocks.append( MLP( @@ -649,6 +668,7 @@ def forward( noise_emb: torch.Tensor = None, ada_ln_aux: torch.Tensor = None, coords: torch.Tensor = None, + x_kv: torch.Tensor = None, ) -> torch.Tensor: # aux_info is forecast step, if not disabled with cf.forecast_with_step_conditioning # aux_info = torch.tensor([fstep], dtype=torch.float32, device="cuda") @@ -670,9 +690,11 @@ def forward( for block in self.fe_blocks: if isinstance(block, torch.nn.LayerNorm): tokens = checkpoint(block, tokens, use_reentrant=False) + elif isinstance(block, MultiCrossAttentionHead): + assert x_kv is not None, "x_kv (e.g. enc(X_t)) must be provided for cross_attn conditioning" + tokens = checkpoint(block, tokens, x_kv, noise_emb, use_reentrant=False) else: if self.cf.get("fe_diffusion_model_conditioning_type", None) == "ada_ln": - # Assuming ada_ln_aux contains the date_time embedding in this case assert ada_ln_aux is not None, "ada_ln_aux must be provided for diffusion model conditioning" tokens = checkpoint(block, tokens, coords, noise_emb, ada_ln_aux, use_reentrant=False) else: diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 58eb6fd4b..b45ae73f1 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -667,15 +667,26 @@ def validate(self, mini_epoch, mode_cfg, batch_size): dtype=self.mixed_precision_dtype, enabled=cf.with_mixed_precision, ): + source_samples = batch.get_source_samples() + for loss_name, target_aux in self.target_and_aux_calculators_val.items(): + target_idxs = get_target_idxs_from_cfg(mode_cfg, loss_name) + target_aux.pre_compute( + self.cf.general.istep, + source_samples, + batch.get_target_samples(target_idxs), + self.model_params, + self.model, + ) + if self.ema_model is None: preds = self.model( self.model_params, - batch.get_source_samples(), + source_samples, ) else: preds = self.ema_model.forward_eval( self.model_params, - batch.get_source_samples(), + source_samples, ) targets_and_auxs = {} From 372bab461485537e8bc383266ad74e5678fb7ef3 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 20 May 2026 18:18:00 +0200 Subject: [PATCH 333/344] removed concatenation option --- config/config_diffusion.yml | 10 +++++----- src/weathergen/model/diffusion.py | 20 +++----------------- 2 files changed, 8 insertions(+), 22 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 3b9db7f4e..775f32fc4 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -61,7 +61,7 @@ fe_dropout_rate: 0.1 fe_with_qk_lnorm: True fe_diffusion_model: True fe_diffusion_model_conditioning: "forecast" # options: "date_time", "time", "forecast" -fe_diffusion_model_conditioning_type: "concat" #options: "ada_ln", "concat", "cross_attn" +fe_diffusion_model_conditioning_type: "cross_attn" # options: "ada_ln", "cross_attn" fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) @@ -190,7 +190,7 @@ training_config: shuffle: True start_date: 1979-01-01T00:00 - end_date: 2022-12-31T18:00 + end_date: 1979-02-01T00:00 time_window_step: 06:00:00 time_window_len: 06:00:00 @@ -261,8 +261,8 @@ validation_config: samples_per_mini_epoch: 256 shuffle: True - start_date: 2023-10-01T00:00 - end_date: 2023-12-31T18:00 + start_date: 1979-01-01T00:00 + end_date: 1979-02-01T00:00 # whether to track the exponential moving average of weights for validation validate_with_ema: @@ -281,7 +281,7 @@ validation_config: } # run validation before training starts (mainly for model development) - validate_before_training: True + validate_before_training: False # test config; full test config is merge of validation and test config diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 882d369b1..2f995e7a5 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -59,21 +59,12 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast f"fe_diffusion_model_conditioning is '{self.conditioning}' " f"(got '{self.conditioning_type}')" ) - assert self.conditioning != "forecast" or self.conditioning_type in {"concat", "cross_attn"}, ( - f"fe_diffusion_model_conditioning_type must be 'concat' or 'cross_attn' when " + assert self.conditioning != "forecast" or self.conditioning_type == "cross_attn", ( + f"fe_diffusion_model_conditioning_type must be 'cross_attn' when " f"fe_diffusion_model_conditioning is 'forecast' " f"(got '{self.conditioning_type}')" ) - if self.conditioning_type == "concat": - D = self.cf.ae_global_dim_embed - self.cond_proj = torch.nn.Linear(2 * D, D, bias=False) - # Warm-start: pass the noisy input through unchanged, ignore conditioning initially. - # The network can then gradually learn to use the conditioning signal. - with torch.no_grad(): - self.cond_proj.weight[:, :D] = torch.eye(D) - self.cond_proj.weight[:, D:] = 0.0 - if "date" in self.conditioning or "time" in self.conditioning: self.datetime_embedder = DateTimeEncoder(self.conditioning) @@ -236,14 +227,9 @@ def denoise( if self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"]: c = self.datetime_embedder(c).to(x.device) - # Build network input depending on conditioning type. - # "concat": channel-concatenate enc(X_t) with the preconditioned noisy input and - # project back to D via a learned linear layer (GenCast-style). # "ada_ln": pass conditioning through ada_ln_aux into DiT AdaLN blocks. + # "cross_attn": pass conditioning as KV into cross-attention blocks in ForecastingEngine. net_input = c_in * x - if self.conditioning_type == "concat" and c is not None: - net_input = self.cond_proj(torch.cat([net_input, c], dim=-1)) - ada_ln_aux = c if self.conditioning_type == "ada_ln" else None x_kv = c if self.conditioning_type == "cross_attn" else None return c_skip * x + c_out * self.net( From 13560fc1711ba5e143513fb117174f59d406e098 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 20 May 2026 18:20:38 +0200 Subject: [PATCH 334/344] date in config --- config/config_diffusion.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 775f32fc4..cc41e4524 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -190,7 +190,7 @@ training_config: shuffle: True start_date: 1979-01-01T00:00 - end_date: 1979-02-01T00:00 + end_date: 2022-12-31T18:00 time_window_step: 06:00:00 time_window_len: 06:00:00 @@ -261,8 +261,8 @@ validation_config: samples_per_mini_epoch: 256 shuffle: True - start_date: 1979-01-01T00:00 - end_date: 1979-02-01T00:00 + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T18:00 # whether to track the exponential moving average of weights for validation validate_with_ema: From 66ff754bea5606c0b2ff507d89fd7fb829242025 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 20 May 2026 18:21:18 +0200 Subject: [PATCH 335/344] comment in config --- config/config_diffusion.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index cc41e4524..ed8d830e7 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -61,7 +61,7 @@ fe_dropout_rate: 0.1 fe_with_qk_lnorm: True fe_diffusion_model: True fe_diffusion_model_conditioning: "forecast" # options: "date_time", "time", "forecast" -fe_diffusion_model_conditioning_type: "cross_attn" # options: "ada_ln", "cross_attn" +fe_diffusion_model_conditioning_type: "cross_attn" # options: "cross_attn", (maybe later: "ada_ln", "concat") fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) From 73694397ed61b96de0f2a7567c51c49c7ddc951a Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 20 May 2026 19:03:15 +0200 Subject: [PATCH 336/344] minor improvements --- src/weathergen/model/diffusion.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 2f995e7a5..ff8c74889 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -65,7 +65,7 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast f"(got '{self.conditioning_type}')" ) - if "date" in self.conditioning or "time" in self.conditioning: + if self.conditioning and (self.conditioning in ["date_time", "date", "time"]): self.datetime_embedder = DateTimeEncoder(self.conditioning) # Parameters @@ -173,17 +173,17 @@ def training_forward( self.cur_token = tokens.detach() + c = None if self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"]: - c = meta_info["ERA5"].params["timestamp"] # TODO: add correct preconditioning (e.g., sample/s in previous time step, datetime encoding, etc.) + c = meta_info["ERA5"].params["timestamp"] y = meta_info["ERA5"].params.get("diffusion_target_tokens") - elif self.cf.fe_diffusion_model_conditioning in ["forecast"]: - # In diffusion_forecast mode, meta_info carries X_{t+1} tokens written by - # DiffusionLatentTargetEncoder.pre_compute(). Noise X_{t+1} as the target and - # use X_t (the incoming rollout state) as the conditioning signal c. - # Falls back to denoising-autoencoder behaviour when no target is cached - # (e.g. validation or unconditional pre-training). + elif self.cf.fe_diffusion_model_conditioning == "forecast": y = meta_info["ERA5"].params.get("diffusion_target_tokens") c = tokens # X_t as conditioning + else: + # Unconditional: denoise the current tokens as an autoencoder + y = tokens + c = None if self.training: eta = torch.tensor([meta_info["ERA5"].params["noise_level_rn"]], device=tokens.device) @@ -224,7 +224,7 @@ def denoise( # Precondition input and feed through network x = self.preconditioner.precondition(x, c) # currently does nothing - if self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"]: + if self.conditioning in ["date_time", "date", "time"]: c = self.datetime_embedder(c).to(x.device) # "ada_ln": pass conditioning through ada_ln_aux into DiT AdaLN blocks. From 810987ab132517365754306fd1ad2cd226ce71d6 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 20 May 2026 19:08:19 +0200 Subject: [PATCH 337/344] assert offset zero --- src/weathergen/model/diffusion.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index ff8c74889..028716be3 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -59,6 +59,11 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast f"fe_diffusion_model_conditioning is '{self.conditioning}' " f"(got '{self.conditioning_type}')" ) + _offset = self.cf.get("training_config", {}).get("forecast", {}).get("offset", 0) + assert self.conditioning not in _date_time_modes or _offset == 0, ( + f"forecast.offset must be 0 when fe_diffusion_model_conditioning is " + f"'{self.conditioning}' (got offset={_offset})" + ) assert self.conditioning != "forecast" or self.conditioning_type == "cross_attn", ( f"fe_diffusion_model_conditioning_type must be 'cross_attn' when " f"fe_diffusion_model_conditioning is 'forecast' " From 8acbb01e6f0ce33cd99067d36d7de3be070b7f3c Mon Sep 17 00:00:00 2001 From: Matthias Date: Thu, 21 May 2026 13:45:52 +0200 Subject: [PATCH 338/344] Config for 2048-dim model --- config/config_diffusion_d2048.yml | 319 ++++++++++++++++++ config/config_forecasting_d2048.yml | 256 ++++++++++++++ .../era5_1deg_forecasting_d2048/era5.yml | 111 ++++++ 3 files changed, 686 insertions(+) create mode 100644 config/config_diffusion_d2048.yml create mode 100644 config/config_forecasting_d2048.yml create mode 100644 config/streams/era5_1deg_forecasting_d2048/era5.yml diff --git a/config/config_diffusion_d2048.yml b/config/config_diffusion_d2048.yml new file mode 100644 index 000000000..8af6b57a9 --- /dev/null +++ b/config/config_diffusion_d2048.yml @@ -0,0 +1,319 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 6 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_diffusion_model: True +fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 0.0 # 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 +with_step_conditioning: True # False +# Diffusion related parameters +frequency_embedding_dim: 256 +embedding_dim: 512 +sigma_min: 0.002 +sigma_max: 80 +sigma_data: 1.0 +rho: 7 +p_mean: 1.5 +p_std: 1.2 + +healpix_level: 5 + +# Use 2D RoPE instead of traditional global positional encoding +# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) +# When False: uses traditional pe_global positional encoding +rope_2D: True +mlp_type: swiglu +use_xsa: True +# mlp_type: mlp +# use_xsa: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: False +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + + +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*fe.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +# freeze_modules: "" +# load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 +# load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} # z500 d2048 hl5, sigma_data=159.08 +# load_chkpt: {'run_id': 'tvkicam9', 'epoch': -1} # z500 d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'q9grso75', 'epoch': -1} # z500 d2048 hl3, sigma_data=39.2936 +# load_chkpt: {'run_id': 'qxivdyqz', 'epoch': -1} # z500 d2048 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'h8x1qgz3', 'epoch': -1} # z500 d128 hl5, sigma_data=12.93 +# load_chkpt: {'run_id': '', 'epoch': -1} # z500 d128 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'wvpb76ai', 'epoch': -1} # multi-var d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'ae4wlc5m', 'epoch': -1} # multi-var d2048 hl3, sigma_data=2.7047 +# load_chkpt: {'run_id': 'r45iwyns', 'epoch': -1} # multi-var d512 hl3, sigma_data=1.1785 +# load_chkpt: {'run_id': 'ydka6uql', 'epoch': -1} # multi-var d512 hl4, sigma_data=0.827 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 +# load_chkpt: {'run_id': 'v8kd6xc1', 'epoch': -1} # multi-var d512 hl5 nopos, sigma_data=0.6481 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'y1gu5md8', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, diffusion-full-pipeline +# load_chkpt: {'run_id': 'mal6u4gc', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 64 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'zrpncqb0', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 196 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'm6fs8wvj', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'cgxt9imf', 'epoch': -1} # diffusion model to fine-tune decoder, p_mean=0.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'wo5mf2z4', 'epoch': -1} # diffusion model to fine-tune decoder, p_mean=1.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 +# load_chkpt: {'run_id': 'mivw6jda', 'epoch': -1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 +load_chkpt: {'run_id': 'l3rxe29i', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline + + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/era5_1deg_forecasting_d2048/" +# streams_directory: "./config/streams/era5_1deg_forecasting_z500/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + log_grad_norms: False + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] + + num_mini_epochs: 128 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 1979-01-01T00:00 + end_date: 2022-12-31T18:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 #5e-5 + lr_max: 1e-5 #1e-4 + lr_final_decay: 1e-6 + lr_final: 0.0 + num_steps_warmup: 64 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.975 + beta2 : 0.9875 + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + weight: 0.0, + loss_fcts: { + "mse": {}, + }, + target_and_aux_calc: "Physical", + }, + "latent_diff": { + type: LossLatentDiffusion, + weight: 1.0, + target_and_aux_calc: DiffusionLatentTargetEncoder, + loss_fcts: { "mse": { }, }, + } + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_samples: 1 + } + } + + forecast : + time_step: 06:00:00 + num_steps: 1 + offset: 0 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + # Noise levels (eta values in standard normal space) at which to evaluate the + # diffusion model during validation. sigma = exp(eta * p_std + p_mean). + # Each value produces a separate validation pass with independently logged metrics. + validation_noise_levels: [1.0, 2.0, 3.0, 4.0] + + samples_per_mini_epoch: 256 + shuffle: True + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T18:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: True + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/config_forecasting_d2048.yml b/config/config_forecasting_d2048.yml new file mode 100644 index 000000000..ad538f893 --- /dev/null +++ b/config/config_forecasting_d2048.yml @@ -0,0 +1,256 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 16 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 +fe_diffusion_model: False + +healpix_level: 5 + +rope_2D: True +mlp_type: swiglu +use_xsa: True +# mlp_type: mlp +# use_xsa: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: True +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +freeze_modules: "" +load_chkpt: {} + +norm_type: "LayerNorm" +qk_norm_type: null # if null, defaults to norm_type + +##################################### + +streams_directory: "./config/streams/era5_1deg_forecasting_d2048/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking"] + + num_mini_epochs: 64 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 1979-01-01T00:00 + end_date: 2022-12-31T00:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 5e-5 + lr_final_decay: 2e-6 + lr_final: 0.0 + num_steps_warmup: 256 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.98125 # == 0.85 on 2 nodes x 4 gpus + beta2 : 0.9875 # == 0.90 on 2 nodes x 4 gpus + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + loss_fcts: { "mse": { }, }, + }, + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + }, + } + + forecast : + time_step: 06:00:00 + offset: 1 + num_steps: 3 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + samples_per_mini_epoch: 256 + shuffle: False + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T00:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: False + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/streams/era5_1deg_forecasting_d2048/era5.yml b/config/streams/era5_1deg_forecasting_d2048/era5.yml new file mode 100644 index 000000000..bff0375f4 --- /dev/null +++ b/config/streams/era5_1deg_forecasting_d2048/era5.yml @@ -0,0 +1,111 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5 : + type : anemoi + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2024-1h-v3-with-era51.zarr'] + stream_id : 0 + source_exclude : ['z', 'w_10', 'w_50', 'w_100', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925', 'w_1000', 'skt', 'tcw', 'cp', 'tp', 'q_50', 'q_100'] + target_exclude : ['z', 'w_10', 'w_50', 'w_100', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925', 'w_1000', 'slor', 'sdor', 'tcw', 'cp', 'tp', 'q_50', 'q_100'] + geoinfo_channels : ['z', 'lsm', 'slor', 'sdor', 'insolation', 'cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day'] + loss_weight : 1. + location_weight : cosine_latitude + token_size : 8 + tokenize_spacetime : True + max_num_targets: 20000 + # max_num_targets: -1 + frequency : 06:00:00 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 512 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 + channel_weights : + q_10: 0.2 + q_50: 0.2 + q_100: 0.23 + q_150: 0.26 + q_200: 0.29 + q_250: 0.33 + q_300: 0.36 + q_400: 0.42 + q_500: 0.48 + q_600: 0.55 + q_700: 0.61 + q_850: 0.71 + q_925: 0.75 + q_1000: 0.8 + t_10: 0.2 + t_50: 0.2 + t_100: 0.23 + t_150: 0.26 + t_200: 0.29 + t_250: 0.33 + t_300: 0.36 + t_400: 0.42 + t_500: 0.48 + t_600: 0.55 + t_700: 0.61 + t_850: 0.71 + t_925: 0.75 + t_1000: 0.8 + u_10: 0.2 + u_50: 0.2 + u_100: 0.23 + u_150: 0.26 + u_200: 0.29 + u_250: 0.33 + u_300: 0.36 + u_400: 0.42 + u_500: 0.48 + u_600: 0.55 + u_700: 0.61 + u_850: 0.71 + u_925: 0.75 + u_1000: 0.8 + v_10: 0.2 + v_50: 0.2 + v_100: 0.23 + v_150: 0.26 + v_200: 0.29 + v_250: 0.33 + v_300: 0.36 + v_400: 0.42 + v_500: 0.48 + v_600: 0.55 + v_700: 0.61 + v_850: 0.71 + v_925: 0.75 + v_1000: 0.8 + z_10: 0.2 + z_50: 0.2 + z_100: 0.23 + z_150: 0.26 + z_200: 0.29 + z_250: 0.33 + z_300: 0.36 + z_400: 0.42 + z_500: 0.48 + z_600: 0.55 + z_700: 0.61 + z_850: 0.71 + z_925: 0.75 + z_1000: 0.8 + \ No newline at end of file From 58bf2d4af6c789c7ffcf2324542294f6aa779562 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Thu, 21 May 2026 14:43:52 +0200 Subject: [PATCH 339/344] roll back data flow (not working) --- config/config_diffusion.yml | 9 +-- .../datasets/multi_stream_data_sampler.py | 59 ++++--------------- src/weathergen/model/diffusion.py | 22 ++++--- src/weathergen/model/model.py | 12 +++- .../train/target_and_aux_diffusion.py | 35 +++-------- 5 files changed, 48 insertions(+), 89 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index ed8d830e7..fbc7d9c3b 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -61,7 +61,7 @@ fe_dropout_rate: 0.1 fe_with_qk_lnorm: True fe_diffusion_model: True fe_diffusion_model_conditioning: "forecast" # options: "date_time", "time", "forecast" -fe_diffusion_model_conditioning_type: "cross_attn" # options: "cross_attn", (maybe later: "ada_ln", "concat") +fe_diffusion_model_conditioning_type: "cross_attn" # options: "cross_attn", "ada_ln" fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) @@ -183,7 +183,7 @@ data_loading : training_config: # training_mode: "masking", "student_teacher", "latent_loss" - training_mode: ["masking","diffusion_forecast"] + training_mode: ["masking","student_teacher"] num_mini_epochs: 128 samples_per_mini_epoch: 128 @@ -239,14 +239,15 @@ training_config: # masking strategy: "random", "healpix", "forecast" masking_strategy: "forecast", masking_strategy_config: {diffusion_rn: True}, - num_samples: 1 + num_samples: 1, + num_steps_input: 2 } } forecast : time_step: 06:00:00 num_steps: 1 - offset: 1 + offset: 0 policy: "fixed" diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index dad0e052a..f1bce67ca 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -650,12 +650,6 @@ def _get_batch(self, idx: int, num_forecast_steps: int): if "student_teacher" in mode or "latent_loss" in mode: source_select += ["network_input"] target_select += ["network_input"] - if "diffusion_forecast" in mode: - # Like student_teacher but target samples encode X_{t+1} instead of X_t. - # _build_stream_data is called with output_data as input so source_tokens_cells - # holds the future state; DiffusionLatentTargetEncoder then encodes X_{t+1}. - source_select += ["network_input", "target_coords"] - target_select += ["network_input"] # remove duplicates source_select, target_select = list(set(source_select)), list(set(target_select)) if len(source_select) == 0 or len(target_select) == 0: @@ -718,46 +712,19 @@ def _get_batch(self, idx: int, num_forecast_steps: int): for tidx, target_mask in enumerate(target_masks.masks): # depending on the mode, the the streamdata obj to have the target mask applied to # the inputs. Hence the target mask is also the source mask here. - if "diffusion_forecast" in mode: - # Shift the input window forward by output_offset steps so that - # source_tokens_cells encodes X_{t+1} rather than X_t. - # Collect X_{t+1} as SOURCE (not reusing output_data which uses target - # channels/normalization — source and target channel sets differ). - step_delta = (self.time_step * self.output_offset) // self.step_timedelta - future_input_data = [ - collect_datasources(stream_ds, idx + step_delta, "source", self.rng) - ] - future_input_tokens = self.tokenizer.get_tokens_windows( - stream_info, future_input_data, True - ) - sdata = self._build_stream_data( - target_select, - idx, - num_forecast_steps, - stream_info, - 1, - future_input_data, - output_data, - future_input_tokens, - output_tokens, - output_mask=target_mask, - input_mask=target_mask, - input_base_idx=idx + step_delta, - ) - else: - sdata = self._build_stream_data( - target_select, - idx, - num_forecast_steps, - stream_info, - target_masks.metadata[tidx].params.get("num_steps_input", 1), - input_data, - output_data, - input_tokens, - output_tokens, - output_mask=target_mask, - input_mask=target_mask, - ) + sdata = self._build_stream_data( + target_select, + idx, + num_forecast_steps, + stream_info, + target_masks.metadata[tidx].params.get("num_steps_input", 1), + input_data, + output_data, + input_tokens, + output_tokens, + output_mask=target_mask, + input_mask=target_mask, + ) target_metadata = target_masks.metadata[tidx] # Get first target step's times (using self.output_offset as the first output step index) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 028716be3..a8fad480d 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -64,7 +64,7 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast f"forecast.offset must be 0 when fe_diffusion_model_conditioning is " f"'{self.conditioning}' (got offset={_offset})" ) - assert self.conditioning != "forecast" or self.conditioning_type == "cross_attn", ( + assert self.conditioning != "forecast" or self.conditioning_type in {"cross_attn"}, ( f"fe_diffusion_model_conditioning_type must be 'cross_attn' when " f"fe_diffusion_model_conditioning is 'forecast' " f"(got '{self.conditioning_type}')" @@ -83,6 +83,8 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast self.cur_token = None # TODO: re move after single sample experiments self._noised_tokens: torch.Tensor | None = None self._fixed_noise_level: float | None = None + self._pending_target_tokens: torch.Tensor | None = None + self._last_noise_level_rn: float | None = None self._noise = None @@ -178,20 +180,22 @@ def training_forward( self.cur_token = tokens.detach() + # y is always the target to denoise (set by DiffusionLatentTargetEncoder.pre_compute) + y = meta_info["ERA5"].params.get("diffusion_target_tokens") + assert y is not None, ( + "diffusion_target_tokens not found in meta_info — " + "DiffusionLatentTargetEncoder.pre_compute must be called before training_forward" + ) + c = None if self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"]: c = meta_info["ERA5"].params["timestamp"] - y = meta_info["ERA5"].params.get("diffusion_target_tokens") elif self.cf.fe_diffusion_model_conditioning == "forecast": - y = meta_info["ERA5"].params.get("diffusion_target_tokens") - c = tokens # X_t as conditioning - else: - # Unconditional: denoise the current tokens as an autoencoder - y = tokens - c = None + c = tokens # X_{t-1} as conditioning (model.py extracts last step as target, passes second-to-last here) if self.training: - eta = torch.tensor([meta_info["ERA5"].params["noise_level_rn"]], device=tokens.device) + self._last_noise_level_rn = meta_info["ERA5"].params["noise_level_rn"] + eta = torch.tensor([self._last_noise_level_rn], device=tokens.device) else: # During validation, use fixed noise level (default: 0.0 = mean of noise distribution) noise_level = self._fixed_noise_level if self._fixed_noise_level is not None else 0.0 diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 1442e9cf4..3245dfc14 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -728,8 +728,16 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # recover batch dimension and separate input_steps shape = (len(batch), batch.get_num_steps(), *tokens.shape[1:]) - # collapse along input step dimension - tokens = tokens.reshape(shape).sum(axis=1) + tokens_multi = tokens.reshape(shape) + + if self.cf.get("fe_diffusion_model", False): + # X_t (last step) is the diffusion denoising target; X_{t-1} is the conditioning context. + tokens_xt = tokens_multi[:, -1] + batch.samples[0].meta_info["ERA5"].params["diffusion_target_tokens"] = tokens_xt + self.forecast_engine._pending_target_tokens = tokens_xt + tokens = tokens_multi[:, -2] + else: + tokens = tokens_multi.sum(axis=1) # Allow for pushforward trick p_fwd = self.cf.training_config.get("forecast", {}).get("pushforward", False) diff --git a/src/weathergen/train/target_and_aux_diffusion.py b/src/weathergen/train/target_and_aux_diffusion.py index 08b67e3ad..a2f4cb1cc 100644 --- a/src/weathergen/train/target_and_aux_diffusion.py +++ b/src/weathergen/train/target_and_aux_diffusion.py @@ -49,21 +49,7 @@ def update_state_post_opt_step(self, istep, batch, model, **kwargs) -> None: self.encoder.reshard() def pre_compute(self, istep, source_batch, target_batch, model_params, model, **kwargs) -> None: - """ - Encode the target batch (whose source_tokens_cells holds X_{t+1} in - diffusion_forecast mode) before model.forward() so that training_forward - can noise X_{t+1} and condition on X_t. - - Stores encoded tokens in: - - self._pending_tokens → reused by compute() to skip a second encoder pass - - source_batch.samples[0].meta_info["ERA5"].params["diffusion_target_tokens"] - → flows to DiffusionForecastEngine.training_forward via model.forward's meta_info - """ - with torch.no_grad(): - self.encoder.encoder.eval() - tokens, _ = self.encoder.encoder(model_params=model_params, batch=target_batch) - self._pending_tokens = tokens - source_batch.samples[0].meta_info["ERA5"].params["diffusion_target_tokens"] = tokens + pass def compute( self, @@ -76,23 +62,16 @@ def compute( ) -> tuple[Any, Any]: # During validation (model in eval mode), use fixed noise level # so that sigma = exp(eta * p_std + p_mean) is deterministic - if model.training: - noise_level_rn = ( - batch.samples[0].meta_info["ERA5"].params["noise_level_rn"] - ) # TODO: adjust for multiple streams + noise_level_rn = model.forecast_engine._last_noise_level_rn else: noise_level_rn = self._fixed_noise_level if self._fixed_noise_level is not None else 0.0 - # Reuse tokens from pre_compute when available (avoids a second encoder pass). - # Falls back to encoding the batch directly (e.g. during validation). - if self._pending_tokens is not None: - tokens = self._pending_tokens - self._pending_tokens = None - else: - with torch.no_grad(): - self.encoder.encoder.eval() # NOTE: might be redundant - tokens, _ = self.encoder.encoder(model_params=model_params, batch=batch) + # Encode X_t (the diffusion target) directly with the frozen encoder. + # batch here is the target batch, which contains X_t for offset=0. + with torch.no_grad(): + self.encoder.encoder.eval() + tokens, _ = self.encoder.encoder(model_params=model_params, batch=batch) output_idxs = batch.get_output_idxs() assert len(output_idxs) > 0 From 9b652f44795423931ccbaf322f73707c65a5fe06 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Thu, 21 May 2026 14:57:24 +0200 Subject: [PATCH 340/344] cleanup rollback --- .../datasets/multi_stream_data_sampler.py | 7 +------ src/weathergen/model/diffusion.py | 5 +---- .../loss_modules/loss_module_latent_diffusion.py | 1 + src/weathergen/train/target_and_aux_diffusion.py | 16 +++++++--------- .../train/target_and_aux_module_base.py | 3 --- src/weathergen/train/trainer.py | 13 +------------ 6 files changed, 11 insertions(+), 34 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index f1bce67ca..82102df6f 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -484,7 +484,6 @@ def _build_stream_data( output_tokens: list, output_mask, input_mask, - input_base_idx: TIndex = None, ) -> StreamData: """ Return one batch of data @@ -500,10 +499,6 @@ def _build_stream_data( output_mask : mask for output/prediction/target input_mask : mask for network input (can be source or target) - input_base_idx: Override time index for the input section's time-window - lookup. When None, falls back to base_idx. Pass a shifted index - (e.g. base_idx + output_offset) together with future source data - to make source_tokens_cells encode a future timestep. Returns: StreamData with source and targets masked according to view_meta @@ -520,7 +515,7 @@ def _build_stream_data( stream_data = self._build_stream_data_input( modes, stream_data, - input_base_idx if input_base_idx is not None else base_idx, + base_idx, stream_info, num_steps_input, input_data, diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index a8fad480d..082cfb015 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -83,8 +83,6 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast self.cur_token = None # TODO: re move after single sample experiments self._noised_tokens: torch.Tensor | None = None self._fixed_noise_level: float | None = None - self._pending_target_tokens: torch.Tensor | None = None - self._last_noise_level_rn: float | None = None self._noise = None @@ -194,8 +192,7 @@ def training_forward( c = tokens # X_{t-1} as conditioning (model.py extracts last step as target, passes second-to-last here) if self.training: - self._last_noise_level_rn = meta_info["ERA5"].params["noise_level_rn"] - eta = torch.tensor([self._last_noise_level_rn], device=tokens.device) + eta = torch.tensor([meta_info["ERA5"].params["noise_level_rn"]], device=tokens.device) else: # During validation, use fixed noise level (default: 0.0 = mean of noise distribution) noise_level = self._fixed_noise_level if self._fixed_noise_level is not None else 0.0 diff --git a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py index 35285ae85..0ef99e894 100644 --- a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py +++ b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py @@ -96,6 +96,7 @@ def compute_loss(self, preds: dict, targets: dict, **kwargs) -> LossValues: pred_tokens_all = [pl["latent_state"].z_pre_norm for pl in preds.latent if pl.get("latent_state") is not None] target_tokens_all = [latent["diffusion_latent"] for latent in targets.latent if latent.get("diffusion_latent") is not None] + assert len(pred_tokens_all) == len(target_tokens_all), "Mismatch in number of forecast steps between predictions and targets." eta = torch.tensor( [targets.aux_outputs["noise_level_rn"]], device=self.device, dtype=torch.float32 diff --git a/src/weathergen/train/target_and_aux_diffusion.py b/src/weathergen/train/target_and_aux_diffusion.py index a2f4cb1cc..a857bce0b 100644 --- a/src/weathergen/train/target_and_aux_diffusion.py +++ b/src/weathergen/train/target_and_aux_diffusion.py @@ -21,7 +21,6 @@ def __init__(self, encoder, is_model_sharded=True): self.is_model_sharded = is_model_sharded self._fixed_noise_level: float | None = None - self._pending_tokens: torch.Tensor | None = None # Build a name → param map once self.src_params = dict(self.encoder.named_parameters()) @@ -48,9 +47,6 @@ def update_state_post_opt_step(self, istep, batch, model, **kwargs) -> None: if self.is_model_sharded: self.encoder.reshard() - def pre_compute(self, istep, source_batch, target_batch, model_params, model, **kwargs) -> None: - pass - def compute( self, istep: int, @@ -63,15 +59,17 @@ def compute( # During validation (model in eval mode), use fixed noise level # so that sigma = exp(eta * p_std + p_mean) is deterministic if model.training: - noise_level_rn = model.forecast_engine._last_noise_level_rn + noise_level_rn = ( + batch.samples[0].meta_info["ERA5"].params["noise_level_rn"] + ) # TODO: adjust for multiple streams else: noise_level_rn = self._fixed_noise_level if self._fixed_noise_level is not None else 0.0 - # Encode X_t (the diffusion target) directly with the frozen encoder. - # batch here is the target batch, which contains X_t for offset=0. + # TODO: check if there are scenarios where the encoder needs to be set to eval with torch.no_grad(): - self.encoder.encoder.eval() - tokens, _ = self.encoder.encoder(model_params=model_params, batch=batch) + self.encoder.encoder.eval() # NOTE: might be redundant + tokens, posteriors = self.encoder.encoder(model_params=model_params, batch=batch) + # NOTE: must not set to train afterwards unless it was already in train output_idxs = batch.get_output_idxs() assert len(output_idxs) > 0 diff --git a/src/weathergen/train/target_and_aux_module_base.py b/src/weathergen/train/target_and_aux_module_base.py index dcdabeef2..88f9dc93d 100644 --- a/src/weathergen/train/target_and_aux_module_base.py +++ b/src/weathergen/train/target_and_aux_module_base.py @@ -69,9 +69,6 @@ def __init__(self, cf, model, **kwargs): def reset(self): pass - def pre_compute(self, istep, source_batch, target_batch, model_params, model, **kwargs) -> None: - pass - def update_state_pre_backward(self, istep, batch, model, **kwargs) -> None: pass diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index b45ae73f1..fe335d1c2 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -486,20 +486,9 @@ def train(self, mini_epoch): dtype=self.mixed_precision_dtype, enabled=cf.with_mixed_precision, ): - source_samples = batch.get_source_samples() - for loss_name, target_aux in self.target_and_aux_calculators.items(): - target_idxs = get_target_idxs_from_cfg(self.training_cfg, loss_name) - target_aux.pre_compute( - self.cf.general.istep, - source_samples, - batch.get_target_samples(target_idxs), - self.model_params, - self.model, - ) - preds = self.model( self.model_params, - source_samples, + batch.get_source_samples(), ) targets_and_auxs = {} From 27a3b1b4887ca77a9ed1f464c4ccfc8c1769e2ab Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Thu, 21 May 2026 16:30:21 +0200 Subject: [PATCH 341/344] inter commit --- config/config_diffusion.yml | 2 +- src/weathergen/model/diffusion.py | 19 ++++++++++--------- src/weathergen/model/engines.py | 18 ++++++++++-------- src/weathergen/model/model.py | 14 ++++++++------ .../train/target_and_aux_diffusion.py | 3 +++ src/weathergen/train/trainer.py | 14 ++------------ 6 files changed, 34 insertions(+), 36 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index fbc7d9c3b..6b7911d72 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -239,8 +239,8 @@ training_config: # masking strategy: "random", "healpix", "forecast" masking_strategy: "forecast", masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 2, num_samples: 1, - num_steps_input: 2 } } diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 082cfb015..56e21bf99 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -59,11 +59,16 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast f"fe_diffusion_model_conditioning is '{self.conditioning}' " f"(got '{self.conditioning_type}')" ) - _offset = self.cf.get("training_config", {}).get("forecast", {}).get("offset", 0) + _offset = self.cf.get("training_config", {}).get("forecast", {}).get("offset", 0) assert self.conditioning not in _date_time_modes or _offset == 0, ( f"forecast.offset must be 0 when fe_diffusion_model_conditioning is " f"'{self.conditioning}' (got offset={_offset})" ) + _input_num_steps = self.cf.get("training_config", {}).get("model_input", {}).get("forecasting", {}).get("num_steps_input", 0) + assert self.conditioning != "forecast" or _input_num_steps == 2, ( + f"forecast.input_num_steps must be 2 when fe_diffusion_model_conditioning is " + f"'{self.conditioning}' (got input_num_steps={_input_num_steps})" + ) assert self.conditioning != "forecast" or self.conditioning_type in {"cross_attn"}, ( f"fe_diffusion_model_conditioning_type must be 'cross_attn' when " f"fe_diffusion_model_conditioning is 'forecast' " @@ -179,7 +184,7 @@ def training_forward( self.cur_token = tokens.detach() # y is always the target to denoise (set by DiffusionLatentTargetEncoder.pre_compute) - y = meta_info["ERA5"].params.get("diffusion_target_tokens") + y = tokens assert y is not None, ( "diffusion_target_tokens not found in meta_info — " "DiffusionLatentTargetEncoder.pre_compute must be called before training_forward" @@ -189,7 +194,7 @@ def training_forward( if self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"]: c = meta_info["ERA5"].params["timestamp"] elif self.cf.fe_diffusion_model_conditioning == "forecast": - c = tokens # X_{t-1} as conditioning (model.py extracts last step as target, passes second-to-last here) + c = meta_info["ERA5"].params["conditioning_tokens"] # X_{t-1} as conditioning (model.py extracts last step as target, passes second-to-last here) if self.training: eta = torch.tensor([meta_info["ERA5"].params["noise_level_rn"]], device=tokens.device) @@ -229,17 +234,13 @@ def denoise( noise_emb = self.noise_embedder(c_noise) # Precondition input and feed through network - x = self.preconditioner.precondition(x, c) # currently does nothing if self.conditioning in ["date_time", "date", "time"]: c = self.datetime_embedder(c).to(x.device) - # "ada_ln": pass conditioning through ada_ln_aux into DiT AdaLN blocks. - # "cross_attn": pass conditioning as KV into cross-attention blocks in ForecastingEngine. net_input = c_in * x - ada_ln_aux = c if self.conditioning_type == "ada_ln" else None - x_kv = c if self.conditioning_type == "cross_attn" else None + return c_skip * x + c_out * self.net( - net_input, fstep=fstep, coords=coords, noise_emb=noise_emb, ada_ln_aux=ada_ln_aux, x_kv=x_kv + net_input, fstep=fstep, coords=coords, noise_emb=noise_emb, conditioning=c ) # Eq. (7) in EDM paper def inference_forward( diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 511502b5c..40bec1a75 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -666,9 +666,8 @@ def forward( fstep: int, meta_info: SampleMetaData = None, noise_emb: torch.Tensor = None, - ada_ln_aux: torch.Tensor = None, + conditioning: torch.Tensor = None, coords: torch.Tensor = None, - x_kv: torch.Tensor = None, ) -> torch.Tensor: # aux_info is forecast step, if not disabled with cf.forecast_with_step_conditioning # aux_info = torch.tensor([fstep], dtype=torch.float32, device="cuda") @@ -691,21 +690,24 @@ def forward( if isinstance(block, torch.nn.LayerNorm): tokens = checkpoint(block, tokens, use_reentrant=False) elif isinstance(block, MultiCrossAttentionHead): - assert x_kv is not None, "x_kv (e.g. enc(X_t)) must be provided for cross_attn conditioning" - tokens = checkpoint(block, tokens, x_kv, noise_emb, use_reentrant=False) + assert conditioning is not None, "conditioning (e.g. enc(X_t)) must be provided for cross_attn conditioning" + tokens = checkpoint(block, tokens, conditioning, noise_emb, use_reentrant=False) else: if self.cf.get("fe_diffusion_model_conditioning_type", None) == "ada_ln": - assert ada_ln_aux is not None, "ada_ln_aux must be provided for diffusion model conditioning" - tokens = checkpoint(block, tokens, coords, noise_emb, ada_ln_aux, use_reentrant=False) + assert conditioning is not None, "conditioning must be provided for diffusion model conditioning" + tokens = checkpoint(block, tokens, coords, noise_emb, conditioning, use_reentrant=False) + elif self.cf.get("fe_diffusion_model_conditioning_type", None) == "cross_attn": + assert conditioning is not None, "conditioning (e.g. enc(X_t)) must be provided for cross_attn conditioning" + tokens = checkpoint(block, tokens, coords, noise_emb, use_reentrant=False) else: - assert ada_ln_aux is None, "ada_ln_aux should not be provided when diffusion model conditioning is disabled" + assert conditioning is None, "conditioning should not be provided when diffusion model conditioning is disabled" tokens = checkpoint(block, tokens, coords, noise_emb, use_reentrant=False) else: for block in self.fe_blocks: if isinstance(block, torch.nn.LayerNorm): tokens = checkpoint(block, tokens, use_reentrant=False) else: - tokens = checkpoint(block, tokens, coords, ada_ln_aux, use_reentrant=False) + tokens = checkpoint(block, tokens, coords, conditioning, use_reentrant=False) return tokens if not forecast_residual else (tokens_in + tokens) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 3245dfc14..9eb20ab49 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -728,16 +728,18 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # recover batch dimension and separate input_steps shape = (len(batch), batch.get_num_steps(), *tokens.shape[1:]) - tokens_multi = tokens.reshape(shape) + # Reshape tokens to [B, T, ...] + tokens = tokens.reshape(shape) if self.cf.get("fe_diffusion_model", False): + tokens = tokens.reshape(shape) + conditioning_tokens = tokens[:, -2] # TODO: enable longer history for conditioning # X_t (last step) is the diffusion denoising target; X_{t-1} is the conditioning context. - tokens_xt = tokens_multi[:, -1] - batch.samples[0].meta_info["ERA5"].params["diffusion_target_tokens"] = tokens_xt - self.forecast_engine._pending_target_tokens = tokens_xt - tokens = tokens_multi[:, -2] + batch.samples[0].meta_info["ERA5"].params["conditioning_tokens"] = conditioning_tokens + # self.forecast_engine._pending_target_tokens = diffusion_target_tokens + tokens = tokens[:, -1] else: - tokens = tokens_multi.sum(axis=1) + tokens = tokens.sum(axis=1) # Allow for pushforward trick p_fwd = self.cf.training_config.get("forecast", {}).get("pushforward", False) diff --git a/src/weathergen/train/target_and_aux_diffusion.py b/src/weathergen/train/target_and_aux_diffusion.py index a857bce0b..e3f582009 100644 --- a/src/weathergen/train/target_and_aux_diffusion.py +++ b/src/weathergen/train/target_and_aux_diffusion.py @@ -69,6 +69,9 @@ def compute( with torch.no_grad(): self.encoder.encoder.eval() # NOTE: might be redundant tokens, posteriors = self.encoder.encoder(model_params=model_params, batch=batch) + shape = (len(batch), batch.get_num_steps(), *tokens.shape[1:]) + tokens_multi = tokens.reshape(shape) + tokens = tokens_multi[:, -1] # NOTE: must not set to train afterwards unless it was already in train output_idxs = batch.get_output_idxs() diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index fe335d1c2..f8a6b0792 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -656,26 +656,16 @@ def validate(self, mini_epoch, mode_cfg, batch_size): dtype=self.mixed_precision_dtype, enabled=cf.with_mixed_precision, ): - source_samples = batch.get_source_samples() - for loss_name, target_aux in self.target_and_aux_calculators_val.items(): - target_idxs = get_target_idxs_from_cfg(mode_cfg, loss_name) - target_aux.pre_compute( - self.cf.general.istep, - source_samples, - batch.get_target_samples(target_idxs), - self.model_params, - self.model, - ) if self.ema_model is None: preds = self.model( self.model_params, - source_samples, + batch.get_source_samples(), ) else: preds = self.ema_model.forward_eval( self.model_params, - source_samples, + batch.get_source_samples(), ) targets_and_auxs = {} From 8534dd2e8fddbda529c08ac95928ebb637979210 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Thu, 21 May 2026 17:58:28 +0200 Subject: [PATCH 342/344] =?UTF-8?q?fixes=20=E2=80=93=20forecast=20+=20cros?= =?UTF-8?q?s=5Fattn=20should=20run=20now?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config_diffusion.yml | 13 +++++++++++-- src/weathergen/model/engines.py | 13 ++++++++----- .../loss_modules/loss_module_latent_diffusion.py | 5 ++--- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 6b7911d72..4d49f1836 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -186,7 +186,7 @@ training_config: training_mode: ["masking","student_teacher"] num_mini_epochs: 128 - samples_per_mini_epoch: 128 + samples_per_mini_epoch: 4096 shuffle: True start_date: 1979-01-01T00:00 @@ -244,6 +244,15 @@ training_config: } } + target_input: { + "forecasting" : { + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + forecast : time_step: 06:00:00 num_steps: 1 @@ -282,7 +291,7 @@ validation_config: } # run validation before training starts (mainly for model development) - validate_before_training: False + validate_before_training: True # test config; full test config is merge of validation and test config diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 40bec1a75..dcf2c0992 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -114,6 +114,13 @@ def forward(self, batch, pe_embed): # switch from stream to cell-based ordering and apply per cell positional encoding + # if the assert is hit, max_number_tokens_local_per_cell in config needs to be increased + max_tokens = self.cf.get("ae_local_max_tokens_per_cell", 64) + assert batch.tokens_lens.flatten(0, 2).sum(0).max() <= max_tokens, ( + "max number of tokens per cell for positional encoding exceeded." + ) + " Increase ae_local_max_tokens_per_cell in config." + if batch.tokens_lens.shape[2] == 1: # trivial with one stream tokens_all = torch.cat(x_embeds) @@ -122,10 +129,6 @@ def forward(self, batch, pe_embed): scatter_idxs = self.get_scatter_idxs_vectorized(batch) scatter_idxs = scatter_idxs.unsqueeze(1).repeat((1, self.cf.ae_local_dim_embed)) - # if the assert is hit, MAX_NUMBER_TOKENS_LOCAL_PER_CELL needs to be increased - assert ( - batch.tokens_lens.flatten(0, 2).sum(0).max() < MAX_NUMBER_TOKENS_LOCAL_PER_CELL - ), "max number of tokens per cell for positional encoding exceeded" # actual scatter operation and apply per cell positional encoding tokens_all.scatter_(0, scatter_idxs, torch.cat(x_embeds)) @@ -133,7 +136,7 @@ def forward(self, batch, pe_embed): tokens_all = tokens_all + pe_embed[pe_idxs] return tokens_all - + def get_pe_idxs_vectorized(self, batch): """ Compute per cell indices into positional encoding diff --git a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py index 0ef99e894..a89c9d6b8 100644 --- a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py +++ b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py @@ -94,9 +94,8 @@ def compute_loss(self, preds: dict, targets: dict, **kwargs) -> LossValues: for _, _, loss_fct_name in self.loss_fcts } - pred_tokens_all = [pl["latent_state"].z_pre_norm for pl in preds.latent if pl.get("latent_state") is not None] - target_tokens_all = [latent["diffusion_latent"] for latent in targets.latent if latent.get("diffusion_latent") is not None] - assert len(pred_tokens_all) == len(target_tokens_all), "Mismatch in number of forecast steps between predictions and targets." + pred_tokens_all = [pl["latent_state"].z_pre_norm for pl in preds.latent if pl] + target_tokens_all = [latent["diffusion_latent"] for latent in targets.latent if latent] eta = torch.tensor( [targets.aux_outputs["noise_level_rn"]], device=self.device, dtype=torch.float32 From 17c11fa4156fcb0e834f62cdd52783c3afbfe21e Mon Sep 17 00:00:00 2001 From: Matthias Date: Tue, 26 May 2026 13:01:49 +0200 Subject: [PATCH 343/344] Add 2048-dim diffusion configs and inference fixes --- config/config_diffusion.yml | 14 ++++++-- config/config_diffusion_d2048.yml | 34 ++++++++++++++++--- .../era5_1deg_forecasting_d2048/era5.yml | 3 ++ .../datasets/multi_stream_data_sampler.py | 2 +- src/weathergen/model/diffusion.py | 12 +++---- src/weathergen/model/model.py | 7 +++- .../train/target_and_aux_diffusion.py | 12 ++++--- src/weathergen/train/trainer.py | 4 +++ 8 files changed, 68 insertions(+), 20 deletions(-) diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml index 5a318ce0b..8cc954266 100644 --- a/config/config_diffusion.yml +++ b/config/config_diffusion.yml @@ -105,10 +105,10 @@ latent_noise_use_additive_noise: False latent_noise_deterministic_latents: True -# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" # freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*fe.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" # freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" -freeze_modules: "" +# freeze_modules: "" # load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 # load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} # z500 d2048 hl5, sigma_data=159.08 # load_chkpt: {'run_id': 'tvkicam9', 'epoch': -1} # z500 d2048 hl3 enc-lnorm, sigma_data=1.0 @@ -126,11 +126,19 @@ freeze_modules: "" # load_chkpt: {'run_id': 'y1gu5md8', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, diffusion-full-pipeline # load_chkpt: {'run_id': 'mal6u4gc', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 64 epochs, diffusion-full-pipeline # load_chkpt: {'run_id': 'zrpncqb0', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 196 epochs, diffusion-full-pipeline -load_chkpt: {'run_id': 'm6fs8wvj', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'm6fs8wvj', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline # load_chkpt: {'run_id': 'cgxt9imf', 'epoch': -1} # diffusion model to fine-tune decoder, p_mean=0.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone # load_chkpt: {'run_id': 'wo5mf2z4', 'epoch': -1} # diffusion model to fine-tune decoder, p_mean=1.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone # load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 # load_chkpt: {'run_id': 'mivw6jda', 'epoch': -1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'j74tn8le', 'epoch': -1} # forecasting d512 hl5, diffusion-full-pipeline, p_mean=-1.5, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'j7lr0jws', 'epoch': -1} # forecasting d512 hl5, diffusion-full-pipeline, p_mean=-1.2, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'cbras2el', 'epoch': -1} # forecasting d512 hl5, diffusion-full-pipeline, p_mean=-0.5, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'kn3124hp', 'epoch': -1} # forecasting d512 hl5, diffusion-full-pipeline, p_mean=0.0, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'qqbu9852', 'epoch': -1} # forecasting d512 hl5, diffusion-full-pipeline, p_mean=0.5, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'vqsh3yrl', 'epoch': -1} # forecasting d512 hl5, diffusion-full-pipeline, p_mean=1.0, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'xl8h7vbt', 'epoch': -1} # forecasting d512 hl5, diffusion-full-pipeline, p_mean=1.5, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'p9m2jwvc', 'epoch': -1} # forecasting d512 hl5, diffusion-full-pipeline, p_mean=2.0, based on m6fs8wvj backbone norm_type: "LayerNorm" diff --git a/config/config_diffusion_d2048.yml b/config/config_diffusion_d2048.yml index 8af6b57a9..275f08761 100644 --- a/config/config_diffusion_d2048.yml +++ b/config/config_diffusion_d2048.yml @@ -60,6 +60,8 @@ fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True fe_diffusion_model: True +fe_diffusion_model_conditioning: "forecast" # options: "date_time", "time", "forecast" +fe_diffusion_model_conditioning_type: "cross_attn" # options: "cross_attn", "ada_ln" fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) @@ -124,11 +126,25 @@ freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedd # load_chkpt: {'run_id': 'mal6u4gc', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 64 epochs, diffusion-full-pipeline # load_chkpt: {'run_id': 'zrpncqb0', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 196 epochs, diffusion-full-pipeline # load_chkpt: {'run_id': 'm6fs8wvj', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline -# load_chkpt: {'run_id': 'cgxt9imf', 'epoch': -1} # diffusion model to fine-tune decoder, p_mean=0.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone -# load_chkpt: {'run_id': 'wo5mf2z4', 'epoch': -1} # diffusion model to fine-tune decoder, p_mean=1.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'cgxt9imf', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=0.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'wo5mf2z4', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=1.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone # load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 # load_chkpt: {'run_id': 'mivw6jda', 'epoch': -1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 -load_chkpt: {'run_id': 'l3rxe29i', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'l3rxe29i', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'riyz96d4', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=0.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'bokn5d2w', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=1.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'uwyv1zdh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'dvslhdp3', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'mtlfdgvh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=3.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'p6cxplfd', 'epoch': -1} # d2048 hl5, p_mean=-2.5, forecasting +# load_chkpt: {'run_id': 'c2y0srpq', 'epoch': -1} # d2048 hl5, p_mean=-2.0, forecasting +# load_chkpt: {'run_id': 'c0zk6oli', 'epoch': -1} # d2048 hl5, p_mean=-1.5, forecasting +# load_chkpt: {'run_id': 'wt2leaf4', 'epoch': -1} # d2048 hl5, p_mean=-1.2, forecasting +# load_chkpt: {'run_id': 'iq5p8ujf', 'epoch': -1} # d2048 hl5, p_mean=-0.5, forecasting +# load_chkpt: {'run_id': 'u2r8b4z6', 'epoch': -1} # d2048 hl5, p_mean=0.5, forecasting +load_chkpt: {'run_id': 'ug7huxi2', 'epoch': -1} # d2048 hl5, p_mean=1.5, forecasting +# load_chkpt: {'run_id': 'i3y5fhda', 'epoch': -1} # d2048 hl5, p_mean=2.5, forecasting +# load_chkpt: {'run_id': 'wd0u4he8', 'epoch': -1} # d2048 hl5, p_mean=3.5, forecasting norm_type: "LayerNorm" @@ -242,7 +258,17 @@ training_config: # masking strategy: "random", "healpix", "forecast" masking_strategy: "forecast", masking_strategy_config: {diffusion_rn: True}, - num_samples: 1 + num_steps_input: 2, + num_samples: 1, + } + } + + target_input: { + "forecasting" : { + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, } } diff --git a/config/streams/era5_1deg_forecasting_d2048/era5.yml b/config/streams/era5_1deg_forecasting_d2048/era5.yml index bff0375f4..ed00da42c 100644 --- a/config/streams/era5_1deg_forecasting_d2048/era5.yml +++ b/config/streams/era5_1deg_forecasting_d2048/era5.yml @@ -14,6 +14,9 @@ ERA5 : source_exclude : ['z', 'w_10', 'w_50', 'w_100', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925', 'w_1000', 'skt', 'tcw', 'cp', 'tp', 'q_50', 'q_100'] target_exclude : ['z', 'w_10', 'w_50', 'w_100', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925', 'w_1000', 'slor', 'sdor', 'tcw', 'cp', 'tp', 'q_50', 'q_100'] geoinfo_channels : ['z', 'lsm', 'slor', 'sdor', 'insolation', 'cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day'] + # source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] + # target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] + # geoinfo_channels : ['lsm', 'slor', 'sdor', 'insolation', 'cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day'] loss_weight : 1. location_weight : cosine_latitude token_size : 8 diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 82102df6f..e248555a8 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -97,7 +97,7 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage): self.streams = cf.streams self.rank = cf.rank self.world_size = cf.world_size - self.diffusion_model_conditioning = cf.fe_diffusion_model_conditioning + self.diffusion_model_conditioning = cf.get("fe_diffusion_model_conditioning", None) self.repeat_data = cf.data_loading.get("repeat_data_in_mini_epoch", False) # initialise healpic diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 56e21bf99..0f75910b1 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -50,7 +50,7 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast self.noise_embedder = NoiseEmbedder( embedding_dim=self.embedding_dim, frequency_embedding_dim=self.frequency_embedding_dim ) - self.conditioning = self.cf.fe_diffusion_model_conditioning + self.conditioning = self.cf.get("fe_diffusion_model_conditioning", None) self.conditioning_type = self.cf.get("fe_diffusion_model_conditioning_type", None) _date_time_modes = {"date_time", "date", "time"} @@ -191,9 +191,9 @@ def training_forward( ) c = None - if self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"]: + if self.cf.get("fe_diffusion_model_conditioning", None) in ["date_time", "date", "time"]: c = meta_info["ERA5"].params["timestamp"] - elif self.cf.fe_diffusion_model_conditioning == "forecast": + elif self.cf.get("fe_diffusion_model_conditioning", None) == "forecast": c = meta_info["ERA5"].params["conditioning_tokens"] # X_{t-1} as conditioning (model.py extracts last step as target, passes second-to-last here) if self.training: @@ -234,7 +234,7 @@ def denoise( noise_emb = self.noise_embedder(c_noise) # Precondition input and feed through network - if self.conditioning in ["date_time", "date", "time"]: + if self.cf.get("fe_diffusion_model_conditioning", None) in ["date_time", "date", "time"]: c = self.datetime_embedder(c).to(x.device) net_input = c_in * x @@ -268,9 +268,9 @@ def inference_forward( # Extract conditioning (mirrors training_forward). c = None - if self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"]: + if self.cf.get("fe_diffusion_model_conditioning", None) in ["date_time", "date", "time"]: c = meta_info["ERA5"].params["timestamp"] - elif self.cf.fe_diffusion_model_conditioning == "forecast": + elif self.cf.get("fe_diffusion_model_conditioning", None) == "forecast": # cur_token = enc(X_t) stored in forward() before routing to inference_forward c = self.cur_token diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 9eb20ab49..602c377c7 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -731,7 +731,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # Reshape tokens to [B, T, ...] tokens = tokens.reshape(shape) - if self.cf.get("fe_diffusion_model", False): + if self.cf.get("fe_diffusion_model_conditioning", None) == "forecast": tokens = tokens.reshape(shape) conditioning_tokens = tokens[:, -2] # TODO: enable longer history for conditioning # X_t (last step) is the diffusion denoising target; X_{t-1} is the conditioning context. @@ -792,6 +792,11 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: model_params, step, toks, batch, output, out_step=i ) continue + + # # In diffusion inference mode, the final denoised tokens are returned. + # tokens = tokens[-1] + # # Feed the denoised output back as conditioning for the next autoregressive step. + # batch.samples[0].meta_info["ERA5"].params["conditioning_tokens"] = tokens # decoder predictions output = self.predict_decoders(model_params, step, tokens, batch, output) diff --git a/src/weathergen/train/target_and_aux_diffusion.py b/src/weathergen/train/target_and_aux_diffusion.py index e3f582009..4079dcc65 100644 --- a/src/weathergen/train/target_and_aux_diffusion.py +++ b/src/weathergen/train/target_and_aux_diffusion.py @@ -71,16 +71,18 @@ def compute( tokens, posteriors = self.encoder.encoder(model_params=model_params, batch=batch) shape = (len(batch), batch.get_num_steps(), *tokens.shape[1:]) tokens_multi = tokens.reshape(shape) - tokens = tokens_multi[:, -1] # NOTE: must not set to train afterwards unless it was already in train output_idxs = batch.get_output_idxs() assert len(output_idxs) > 0 - target_aux_output = TargetAuxOutput(batch.get_output_len(), output_idxs) - - # TODO: currently hard-coding 0 - target_aux_output.add_latent_target(0, "diffusion_latent", tokens) + # The encoder produces a single target latent (tokens_multi[:, -1]) regardless of + # how many forecast steps are requested. Initialise with a single slot so that + # _expand_targets_to_match_preds (in trainer.py) replicates the target across all + # forecast steps automatically — both for T-step autoregressive rollouts and for the + # single-step ODE-trajectory case. + target_aux_output = TargetAuxOutput(1, [0]) + target_aux_output.add_latent_target(0, "diffusion_latent", tokens_multi[:, -1]) # TODO: write function in TargetAuxOutput class target_aux_output.aux_outputs = {"noise_level_rn": noise_level_rn} diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index f8a6b0792..0fdd8d895 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -233,6 +233,10 @@ def inference(self, cf, devices, run_id_contd, mini_epoch_contd): # create data loader # only one needed since we only run the validation code path + # Force full maps during inference by disabling target subsampling + for stream_info in cf.streams: + stream_info["max_num_targets"] = -1 + self.dataset = MultiStreamDataSampler( cf, self.test_cfg, From 58ef4defdb4c79a24a57ea39cc18e5e59e6d57e9 Mon Sep 17 00:00:00 2001 From: Moritz Hauschulz <60788263+moritzhauschulz@users.noreply.github.com> Date: Thu, 28 May 2026 10:55:31 +0100 Subject: [PATCH 344/344] [DRAFT] Mh/jk/diffusion full pipeline forecast (#2396) * additional check for num_input_steps * changed configs and fixed error in inference_forward * review PR comments * comma * uncommented roll-out conditioning * update parameters to freeze by default * included l3rxe29i as default loaded model * rm breakpoint * review PR * config changes --- config/config_diffusion_d2048_ERA5.yml | 348 +++++++++++++++++ config/config_diffusion_d2048_date_time.yml | 348 +++++++++++++++++ ...ml => config_diffusion_d2048_forecast.yml} | 14 +- config/config_diffusion_d2048_time.yml | 349 ++++++++++++++++++ config/config_diffusion_d2048_time_aug.yml | 349 ++++++++++++++++++ src/weathergen/model/diffusion.py | 22 +- src/weathergen/model/engines.py | 12 +- src/weathergen/model/model.py | 11 +- src/weathergen/model/utils.py | 3 + 9 files changed, 1432 insertions(+), 24 deletions(-) create mode 100644 config/config_diffusion_d2048_ERA5.yml create mode 100644 config/config_diffusion_d2048_date_time.yml rename config/{config_diffusion_d2048.yml => config_diffusion_d2048_forecast.yml} (95%) create mode 100644 config/config_diffusion_d2048_time.yml create mode 100644 config/config_diffusion_d2048_time_aug.yml diff --git a/config/config_diffusion_d2048_ERA5.yml b/config/config_diffusion_d2048_ERA5.yml new file mode 100644 index 000000000..9e3094e64 --- /dev/null +++ b/config/config_diffusion_d2048_ERA5.yml @@ -0,0 +1,348 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 6 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_diffusion_model: True +fe_diffusion_model_conditioning: None # options: "date_time", "time", "forecast" +fe_diffusion_model_conditioning_type: None # options: "cross_attn", "ada_ln" +fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 0.0 # 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 +with_step_conditioning: True # False +# Diffusion related parameters +frequency_embedding_dim: 256 +embedding_dim: 512 +sigma_min: 0.002 +sigma_max: 80 +sigma_data: 1.0 +rho: 7 +p_mean: 1.5 +p_std: 1.2 + +healpix_level: 5 + +# Use 2D RoPE instead of traditional global positional encoding +# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) +# When False: uses traditional pe_global positional encoding +rope_2D: True +mlp_type: swiglu +use_xsa: True +# mlp_type: mlp +# use_xsa: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: False +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + + +#below for FE training only +freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +#below for DECODER training only +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*fe.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +#below for FE and DECODER training +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" + +# load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 +# load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} # z500 d2048 hl5, sigma_data=159.08 +# load_chkpt: {'run_id': 'tvkicam9', 'epoch': -1} # z500 d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'q9grso75', 'epoch': -1} # z500 d2048 hl3, sigma_data=39.2936 +# load_chkpt: {'run_id': 'qxivdyqz', 'epoch': -1} # z500 d2048 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'h8x1qgz3', 'epoch': -1} # z500 d128 hl5, sigma_data=12.93 +# load_chkpt: {'run_id': '', 'epoch': -1} # z500 d128 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'wvpb76ai', 'epoch': -1} # multi-var d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'ae4wlc5m', 'epoch': -1} # multi-var d2048 hl3, sigma_data=2.7047 +# load_chkpt: {'run_id': 'r45iwyns', 'epoch': -1} # multi-var d512 hl3, sigma_data=1.1785 +# load_chkpt: {'run_id': 'ydka6uql', 'epoch': -1} # multi-var d512 hl4, sigma_data=0.827 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 +# load_chkpt: {'run_id': 'v8kd6xc1', 'epoch': -1} # multi-var d512 hl5 nopos, sigma_data=0.6481 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'y1gu5md8', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, diffusion-full-pipeline +# load_chkpt: {'run_id': 'mal6u4gc', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 64 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'zrpncqb0', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 196 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'm6fs8wvj', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'cgxt9imf', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=0.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'wo5mf2z4', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=1.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 +# load_chkpt: {'run_id': 'mivw6jda', 'epoch': -1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 +load_chkpt: {'run_id': 'l3rxe29i', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'riyz96d4', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=0.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'bokn5d2w', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=1.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'uwyv1zdh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'dvslhdp3', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'mtlfdgvh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=3.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'p6cxplfd', 'epoch': -1} # d2048 hl5, p_mean=-2.5, forecasting +# load_chkpt: {'run_id': 'c2y0srpq', 'epoch': -1} # d2048 hl5, p_mean=-2.0, forecasting +# load_chkpt: {'run_id': 'c0zk6oli', 'epoch': -1} # d2048 hl5, p_mean=-1.5, forecasting +# load_chkpt: {'run_id': 'wt2leaf4', 'epoch': -1} # d2048 hl5, p_mean=-1.2, forecasting +# load_chkpt: {'run_id': 'iq5p8ujf', 'epoch': -1} # d2048 hl5, p_mean=-0.5, forecasting +# load_chkpt: {'run_id': 'u2r8b4z6', 'epoch': -1} # d2048 hl5, p_mean=0.5, forecasting +# load_chkpt: {'run_id': 'ug7huxi2', 'epoch': -1} # d2048 hl5, p_mean=1.5, forecasting +# load_chkpt: {'run_id': 'i3y5fhda', 'epoch': -1} # d2048 hl5, p_mean=2.5, forecasting +# load_chkpt: {'run_id': 'wd0u4he8', 'epoch': -1} # d2048 hl5, p_mean=3.5, forecasting + + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/era5_1deg_forecasting_d2048/" +# streams_directory: "./config/streams/era5_1deg_forecasting_z500/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + log_grad_norms: False + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] + + num_mini_epochs: 128 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 1979-01-01T00:00 + end_date: 2022-12-31T18:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 #5e-5 + lr_max: 1e-5 #1e-4 + lr_final_decay: 1e-6 + lr_final: 0.0 + num_steps_warmup: 64 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.975 + beta2 : 0.9875 + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + weight: 0.0, + loss_fcts: { + "mse": {}, + }, + target_and_aux_calc: "Physical", + }, + "latent_diff": { + type: LossLatentDiffusion, + weight: 1.0, + target_and_aux_calc: DiffusionLatentTargetEncoder, + loss_fcts: { "mse": { }, }, + } + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + target_input: { + "forecasting" : { + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + forecast : + time_step: 06:00:00 + num_steps: 1 + offset: 0 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + # Noise levels (eta values in standard normal space) at which to evaluate the + # diffusion model during validation. sigma = exp(eta * p_std + p_mean). + # Each value produces a separate validation pass with independently logged metrics. + validation_noise_levels: [1.0, 2.0, 3.0, 4.0] + + samples_per_mini_epoch: 256 + shuffle: True + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T18:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: True + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/config_diffusion_d2048_date_time.yml b/config/config_diffusion_d2048_date_time.yml new file mode 100644 index 000000000..95bcc1f31 --- /dev/null +++ b/config/config_diffusion_d2048_date_time.yml @@ -0,0 +1,348 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 6 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_diffusion_model: True +fe_diffusion_model_conditioning: "date_time" # options: "date_time", "time", "forecast" +fe_diffusion_model_conditioning_type: "ada_ln" # options: "cross_attn", "ada_ln" +diffusion_conditioning_embed_dim: 32 # only used if fe_diffusion_model_conditioning_type is "ada_ln" +fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 0.0 # 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 +with_step_conditioning: True # False +# Diffusion related parameters +frequency_embedding_dim: 256 +embedding_dim: 512 +sigma_min: 0.002 +sigma_max: 80 +sigma_data: 1.0 +rho: 7 +p_mean: 1.5 +p_std: 1.2 + +healpix_level: 5 + +# Use 2D RoPE instead of traditional global positional encoding +# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) +# When False: uses traditional pe_global positional encoding +rope_2D: True +mlp_type: swiglu +use_xsa: True +# mlp_type: mlp +# use_xsa: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: False +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +#below for FE training only +freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +#below for DECODER training only +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*fe.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +#below for FE and DECODER training +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" + +# load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 +# load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} # z500 d2048 hl5, sigma_data=159.08 +# load_chkpt: {'run_id': 'tvkicam9', 'epoch': -1} # z500 d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'q9grso75', 'epoch': -1} # z500 d2048 hl3, sigma_data=39.2936 +# load_chkpt: {'run_id': 'qxivdyqz', 'epoch': -1} # z500 d2048 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'h8x1qgz3', 'epoch': -1} # z500 d128 hl5, sigma_data=12.93 +# load_chkpt: {'run_id': '', 'epoch': -1} # z500 d128 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'wvpb76ai', 'epoch': -1} # multi-var d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'ae4wlc5m', 'epoch': -1} # multi-var d2048 hl3, sigma_data=2.7047 +# load_chkpt: {'run_id': 'r45iwyns', 'epoch': -1} # multi-var d512 hl3, sigma_data=1.1785 +# load_chkpt: {'run_id': 'ydka6uql', 'epoch': -1} # multi-var d512 hl4, sigma_data=0.827 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 +# load_chkpt: {'run_id': 'v8kd6xc1', 'epoch': -1} # multi-var d512 hl5 nopos, sigma_data=0.6481 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'y1gu5md8', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, diffusion-full-pipeline +# load_chkpt: {'run_id': 'mal6u4gc', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 64 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'zrpncqb0', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 196 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'm6fs8wvj', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'cgxt9imf', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=0.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'wo5mf2z4', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=1.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 +# load_chkpt: {'run_id': 'mivw6jda', 'epoch': -1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 +load_chkpt: {'run_id': 'l3rxe29i', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'riyz96d4', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=0.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'bokn5d2w', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=1.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'uwyv1zdh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'dvslhdp3', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'mtlfdgvh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=3.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'p6cxplfd', 'epoch': -1} # d2048 hl5, p_mean=-2.5, forecasting +# load_chkpt: {'run_id': 'c2y0srpq', 'epoch': -1} # d2048 hl5, p_mean=-2.0, forecasting +# load_chkpt: {'run_id': 'c0zk6oli', 'epoch': -1} # d2048 hl5, p_mean=-1.5, forecasting +# load_chkpt: {'run_id': 'wt2leaf4', 'epoch': -1} # d2048 hl5, p_mean=-1.2, forecasting +# load_chkpt: {'run_id': 'iq5p8ujf', 'epoch': -1} # d2048 hl5, p_mean=-0.5, forecasting +# load_chkpt: {'run_id': 'u2r8b4z6', 'epoch': -1} # d2048 hl5, p_mean=0.5, forecasting +# load_chkpt: {'run_id': 'ug7huxi2', 'epoch': -1} # d2048 hl5, p_mean=1.5, forecasting +# load_chkpt: {'run_id': 'i3y5fhda', 'epoch': -1} # d2048 hl5, p_mean=2.5, forecasting +# load_chkpt: {'run_id': 'wd0u4he8', 'epoch': -1} # d2048 hl5, p_mean=3.5, forecasting + + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/era5_1deg_forecasting_d2048/" +# streams_directory: "./config/streams/era5_1deg_forecasting_z500/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + log_grad_norms: False + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] + + num_mini_epochs: 128 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 1979-01-01T00:00 + end_date: 2022-12-31T18:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 #5e-5 + lr_max: 1e-5 #1e-4 + lr_final_decay: 1e-6 + lr_final: 0.0 + num_steps_warmup: 64 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.975 + beta2 : 0.9875 + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + weight: 0.0, + loss_fcts: { + "mse": {}, + }, + target_and_aux_calc: "Physical", + }, + "latent_diff": { + type: LossLatentDiffusion, + weight: 1.0, + target_and_aux_calc: DiffusionLatentTargetEncoder, + loss_fcts: { "mse": { }, }, + } + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + target_input: { + "forecasting" : { + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + forecast : + time_step: 06:00:00 + num_steps: 1 + offset: 0 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + # Noise levels (eta values in standard normal space) at which to evaluate the + # diffusion model during validation. sigma = exp(eta * p_std + p_mean). + # Each value produces a separate validation pass with independently logged metrics. + validation_noise_levels: [1.0, 2.0, 3.0, 4.0] + + samples_per_mini_epoch: 256 + shuffle: True + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T18:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: True + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/config_diffusion_d2048.yml b/config/config_diffusion_d2048_forecast.yml similarity index 95% rename from config/config_diffusion_d2048.yml rename to config/config_diffusion_d2048_forecast.yml index 275f08761..fbace5174 100644 --- a/config/config_diffusion_d2048.yml +++ b/config/config_diffusion_d2048_forecast.yml @@ -104,10 +104,13 @@ latent_noise_use_additive_noise: False latent_noise_deterministic_latents: True -# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +#below for FE training only +freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +#below for DECODER training only # freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*fe.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" -freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" -# freeze_modules: "" +#below for FE and DECODER training +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" + # load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 # load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} # z500 d2048 hl5, sigma_data=159.08 # load_chkpt: {'run_id': 'tvkicam9', 'epoch': -1} # z500 d2048 hl3 enc-lnorm, sigma_data=1.0 @@ -130,7 +133,7 @@ freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedd # load_chkpt: {'run_id': 'wo5mf2z4', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=1.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone # load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 # load_chkpt: {'run_id': 'mivw6jda', 'epoch': -1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 -# load_chkpt: {'run_id': 'l3rxe29i', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +load_chkpt: {'run_id': 'l3rxe29i', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline # load_chkpt: {'run_id': 'riyz96d4', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=0.5, based on l3rxe29i backbone # load_chkpt: {'run_id': 'bokn5d2w', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=1.5, based on l3rxe29i backbone # load_chkpt: {'run_id': 'uwyv1zdh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.0, based on l3rxe29i backbone @@ -142,7 +145,8 @@ freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedd # load_chkpt: {'run_id': 'wt2leaf4', 'epoch': -1} # d2048 hl5, p_mean=-1.2, forecasting # load_chkpt: {'run_id': 'iq5p8ujf', 'epoch': -1} # d2048 hl5, p_mean=-0.5, forecasting # load_chkpt: {'run_id': 'u2r8b4z6', 'epoch': -1} # d2048 hl5, p_mean=0.5, forecasting -load_chkpt: {'run_id': 'ug7huxi2', 'epoch': -1} # d2048 hl5, p_mean=1.5, forecasting +# load_chkpt: {'run_id': 'ug7huxi2', 'epoch': -1} # d2048 hl5, p_mean=1.5, forecasting +# load_chkpt: {'run_id': 'l3rxe29i', 'epoch': -1} # d2048 hl5, p_mean=1.5, forecasting (deterministic, precursor to ug7huxi2) # load_chkpt: {'run_id': 'i3y5fhda', 'epoch': -1} # d2048 hl5, p_mean=2.5, forecasting # load_chkpt: {'run_id': 'wd0u4he8', 'epoch': -1} # d2048 hl5, p_mean=3.5, forecasting diff --git a/config/config_diffusion_d2048_time.yml b/config/config_diffusion_d2048_time.yml new file mode 100644 index 000000000..8e100832f --- /dev/null +++ b/config/config_diffusion_d2048_time.yml @@ -0,0 +1,349 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 6 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_diffusion_model: True +fe_diffusion_model_conditioning: "time" # options: "date_time", "time", "forecast" +fe_diffusion_model_conditioning_type: "ada_ln" # options: "cross_attn", "ada_ln" +diffusion_conditioning_embed_dim: 32 # only used if fe_diffusion_model_conditioning_type is "ada_ln" +fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 0.0 # 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 +with_step_conditioning: True # False +# Diffusion related parameters +frequency_embedding_dim: 256 +embedding_dim: 512 +sigma_min: 0.002 +sigma_max: 80 +sigma_data: 1.0 +rho: 7 +p_mean: 1.5 +p_std: 1.2 + +healpix_level: 5 + +# Use 2D RoPE instead of traditional global positional encoding +# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) +# When False: uses traditional pe_global positional encoding +rope_2D: True +mlp_type: swiglu +use_xsa: True +# mlp_type: mlp +# use_xsa: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: False +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + + +#below for FE training only +freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +#below for DECODER training only +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*fe.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +#below for FE and DECODER training +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" + +# load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 +# load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} # z500 d2048 hl5, sigma_data=159.08 +# load_chkpt: {'run_id': 'tvkicam9', 'epoch': -1} # z500 d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'q9grso75', 'epoch': -1} # z500 d2048 hl3, sigma_data=39.2936 +# load_chkpt: {'run_id': 'qxivdyqz', 'epoch': -1} # z500 d2048 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'h8x1qgz3', 'epoch': -1} # z500 d128 hl5, sigma_data=12.93 +# load_chkpt: {'run_id': '', 'epoch': -1} # z500 d128 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'wvpb76ai', 'epoch': -1} # multi-var d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'ae4wlc5m', 'epoch': -1} # multi-var d2048 hl3, sigma_data=2.7047 +# load_chkpt: {'run_id': 'r45iwyns', 'epoch': -1} # multi-var d512 hl3, sigma_data=1.1785 +# load_chkpt: {'run_id': 'ydka6uql', 'epoch': -1} # multi-var d512 hl4, sigma_data=0.827 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 +# load_chkpt: {'run_id': 'v8kd6xc1', 'epoch': -1} # multi-var d512 hl5 nopos, sigma_data=0.6481 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'y1gu5md8', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, diffusion-full-pipeline +# load_chkpt: {'run_id': 'mal6u4gc', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 64 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'zrpncqb0', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 196 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'm6fs8wvj', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'cgxt9imf', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=0.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'wo5mf2z4', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=1.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 +# load_chkpt: {'run_id': 'mivw6jda', 'epoch': -1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 +load_chkpt: {'run_id': 'l3rxe29i', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'riyz96d4', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=0.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'bokn5d2w', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=1.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'uwyv1zdh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'dvslhdp3', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'mtlfdgvh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=3.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'p6cxplfd', 'epoch': -1} # d2048 hl5, p_mean=-2.5, forecasting +# load_chkpt: {'run_id': 'c2y0srpq', 'epoch': -1} # d2048 hl5, p_mean=-2.0, forecasting +# load_chkpt: {'run_id': 'c0zk6oli', 'epoch': -1} # d2048 hl5, p_mean=-1.5, forecasting +# load_chkpt: {'run_id': 'wt2leaf4', 'epoch': -1} # d2048 hl5, p_mean=-1.2, forecasting +# load_chkpt: {'run_id': 'iq5p8ujf', 'epoch': -1} # d2048 hl5, p_mean=-0.5, forecasting +# load_chkpt: {'run_id': 'u2r8b4z6', 'epoch': -1} # d2048 hl5, p_mean=0.5, forecasting +# load_chkpt: {'run_id': 'ug7huxi2', 'epoch': -1} # d2048 hl5, p_mean=1.5, forecasting +# load_chkpt: {'run_id': 'i3y5fhda', 'epoch': -1} # d2048 hl5, p_mean=2.5, forecasting +# load_chkpt: {'run_id': 'wd0u4he8', 'epoch': -1} # d2048 hl5, p_mean=3.5, forecasting + + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/era5_1deg_forecasting_d2048/" +# streams_directory: "./config/streams/era5_1deg_forecasting_z500/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + log_grad_norms: False + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] + + num_mini_epochs: 128 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 1979-01-01T00:00 + end_date: 2022-12-31T18:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 #5e-5 + lr_max: 1e-5 #1e-4 + lr_final_decay: 1e-6 + lr_final: 0.0 + num_steps_warmup: 64 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.975 + beta2 : 0.9875 + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + weight: 0.0, + loss_fcts: { + "mse": {}, + }, + target_and_aux_calc: "Physical", + }, + "latent_diff": { + type: LossLatentDiffusion, + weight: 1.0, + target_and_aux_calc: DiffusionLatentTargetEncoder, + loss_fcts: { "mse": { }, }, + } + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + target_input: { + "forecasting" : { + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + forecast : + time_step: 06:00:00 + num_steps: 1 + offset: 0 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + # Noise levels (eta values in standard normal space) at which to evaluate the + # diffusion model during validation. sigma = exp(eta * p_std + p_mean). + # Each value produces a separate validation pass with independently logged metrics. + validation_noise_levels: [1.0, 2.0, 3.0, 4.0] + + samples_per_mini_epoch: 256 + shuffle: True + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T18:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: True + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/config_diffusion_d2048_time_aug.yml b/config/config_diffusion_d2048_time_aug.yml new file mode 100644 index 000000000..f0011735f --- /dev/null +++ b/config/config_diffusion_d2048_time_aug.yml @@ -0,0 +1,349 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 6 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_diffusion_model: True +fe_diffusion_model_conditioning: "time" # options: "date_time", "time", "forecast" +fe_diffusion_model_conditioning_type: "ada_ln" # options: "cross_attn", "ada_ln" +diffusion_conditioning_embed_dim: 32 # only used if fe_diffusion_model_conditioning_type is "ada_ln" +fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 0.0 # 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 +with_step_conditioning: True # False +# Diffusion related parameters +frequency_embedding_dim: 256 +embedding_dim: 512 +sigma_min: 0.002 +sigma_max: 80 +sigma_data: 1.0 +rho: 7 +p_mean: 1.5 +p_std: 1.2 + +healpix_level: 5 + +# Use 2D RoPE instead of traditional global positional encoding +# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) +# When False: uses traditional pe_global positional encoding +rope_2D: True +mlp_type: swiglu +use_xsa: True +# mlp_type: mlp +# use_xsa: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: False +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + + +#below for FE training only +freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +#below for DECODER training only +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*fe.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +#below for FE and DECODER training +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" + +# load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 +# load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} # z500 d2048 hl5, sigma_data=159.08 +# load_chkpt: {'run_id': 'tvkicam9', 'epoch': -1} # z500 d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'q9grso75', 'epoch': -1} # z500 d2048 hl3, sigma_data=39.2936 +# load_chkpt: {'run_id': 'qxivdyqz', 'epoch': -1} # z500 d2048 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'h8x1qgz3', 'epoch': -1} # z500 d128 hl5, sigma_data=12.93 +# load_chkpt: {'run_id': '', 'epoch': -1} # z500 d128 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'wvpb76ai', 'epoch': -1} # multi-var d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'ae4wlc5m', 'epoch': -1} # multi-var d2048 hl3, sigma_data=2.7047 +# load_chkpt: {'run_id': 'r45iwyns', 'epoch': -1} # multi-var d512 hl3, sigma_data=1.1785 +# load_chkpt: {'run_id': 'ydka6uql', 'epoch': -1} # multi-var d512 hl4, sigma_data=0.827 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 +# load_chkpt: {'run_id': 'v8kd6xc1', 'epoch': -1} # multi-var d512 hl5 nopos, sigma_data=0.6481 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'y1gu5md8', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, diffusion-full-pipeline +# load_chkpt: {'run_id': 'mal6u4gc', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 64 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'zrpncqb0', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 196 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'm6fs8wvj', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'cgxt9imf', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=0.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'wo5mf2z4', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=1.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 +# load_chkpt: {'run_id': 'mivw6jda', 'epoch': -1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 +load_chkpt: {'run_id': 'l3rxe29i', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'riyz96d4', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=0.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'bokn5d2w', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=1.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'uwyv1zdh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'dvslhdp3', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'mtlfdgvh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=3.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'p6cxplfd', 'epoch': -1} # d2048 hl5, p_mean=-2.5, forecasting +# load_chkpt: {'run_id': 'c2y0srpq', 'epoch': -1} # d2048 hl5, p_mean=-2.0, forecasting +# load_chkpt: {'run_id': 'c0zk6oli', 'epoch': -1} # d2048 hl5, p_mean=-1.5, forecasting +# load_chkpt: {'run_id': 'wt2leaf4', 'epoch': -1} # d2048 hl5, p_mean=-1.2, forecasting +# load_chkpt: {'run_id': 'iq5p8ujf', 'epoch': -1} # d2048 hl5, p_mean=-0.5, forecasting +# load_chkpt: {'run_id': 'u2r8b4z6', 'epoch': -1} # d2048 hl5, p_mean=0.5, forecasting +# load_chkpt: {'run_id': 'ug7huxi2', 'epoch': -1} # d2048 hl5, p_mean=1.5, forecasting +# load_chkpt: {'run_id': 'i3y5fhda', 'epoch': -1} # d2048 hl5, p_mean=2.5, forecasting +# load_chkpt: {'run_id': 'wd0u4he8', 'epoch': -1} # d2048 hl5, p_mean=3.5, forecasting + + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/era5_1deg_forecasting_d2048/" +# streams_directory: "./config/streams/era5_1deg_forecasting_z500/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + log_grad_norms: False + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : True + + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] + + num_mini_epochs: 128 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 2022-08-01T00:00 + end_date: 2022-08-31T18:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 #5e-5 + lr_max: 1e-5 #1e-4 + lr_final_decay: 1e-6 + lr_final: 0.0 + num_steps_warmup: 64 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.975 + beta2 : 0.9875 + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + weight: 0.0, + loss_fcts: { + "mse": {}, + }, + target_and_aux_calc: "Physical", + }, + "latent_diff": { + type: LossLatentDiffusion, + weight: 1.0, + target_and_aux_calc: DiffusionLatentTargetEncoder, + loss_fcts: { "mse": { }, }, + } + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + target_input: { + "forecasting" : { + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + forecast : + time_step: 06:00:00 + num_steps: 1 + offset: 0 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + # Noise levels (eta values in standard normal space) at which to evaluate the + # diffusion model during validation. sigma = exp(eta * p_std + p_mean). + # Each value produces a separate validation pass with independently logged metrics. + validation_noise_levels: [1.0, 2.0, 3.0, 4.0] + + samples_per_mini_epoch: 256 + shuffle: True + + start_date: 2023-08-01T00:00 + end_date: 2023-08-31T18:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: True + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 0f75910b1..7909534d1 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -59,6 +59,11 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast f"fe_diffusion_model_conditioning is '{self.conditioning}' " f"(got '{self.conditioning_type}')" ) + _ada_ln = self.conditioning_type == "ada_ln" + assert self.cf.get("diffusion_conditioning_embed_dim", None) is not None or not _ada_ln, ( + f"diffusion_conditioning_embed_dim must be set when " + f"fe_diffusion_model_conditioning_type is 'ada_ln'" + ) _offset = self.cf.get("training_config", {}).get("forecast", {}).get("offset", 0) assert self.conditioning not in _date_time_modes or _offset == 0, ( f"forecast.offset must be 0 when fe_diffusion_model_conditioning is " @@ -69,6 +74,10 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast f"forecast.input_num_steps must be 2 when fe_diffusion_model_conditioning is " f"'{self.conditioning}' (got input_num_steps={_input_num_steps})" ) + assert self.conditioning not in ["date_time", "date", "time"] or _input_num_steps == 1, ( + f"forecast.input_num_steps must be 1 when fe_diffusion_model_conditioning is " + f"'{self.conditioning}' (got input_num_steps={_input_num_steps})" + ) assert self.conditioning != "forecast" or self.conditioning_type in {"cross_attn"}, ( f"fe_diffusion_model_conditioning_type must be 'cross_attn' when " f"fe_diffusion_model_conditioning is 'forecast' " @@ -191,9 +200,9 @@ def training_forward( ) c = None - if self.cf.get("fe_diffusion_model_conditioning", None) in ["date_time", "date", "time"]: + if self.conditioning in ["date_time", "date", "time"]: c = meta_info["ERA5"].params["timestamp"] - elif self.cf.get("fe_diffusion_model_conditioning", None) == "forecast": + elif self.conditioning == "forecast": c = meta_info["ERA5"].params["conditioning_tokens"] # X_{t-1} as conditioning (model.py extracts last step as target, passes second-to-last here) if self.training: @@ -234,7 +243,7 @@ def denoise( noise_emb = self.noise_embedder(c_noise) # Precondition input and feed through network - if self.cf.get("fe_diffusion_model_conditioning", None) in ["date_time", "date", "time"]: + if self.conditioning in ["date_time", "date", "time"]: c = self.datetime_embedder(c).to(x.device) net_input = c_in * x @@ -268,11 +277,10 @@ def inference_forward( # Extract conditioning (mirrors training_forward). c = None - if self.cf.get("fe_diffusion_model_conditioning", None) in ["date_time", "date", "time"]: + if self.conditioning in ["date_time", "date", "time"]: c = meta_info["ERA5"].params["timestamp"] - elif self.cf.get("fe_diffusion_model_conditioning", None) == "forecast": - # cur_token = enc(X_t) stored in forward() before routing to inference_forward - c = self.cur_token + elif self.conditioning == "forecast": + c = meta_info["ERA5"].params["conditioning_tokens"] # Sample pure noise (assuming single batch element for now) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index dcf2c0992..6c8e847db 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -589,7 +589,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), with_2d_rope=self.cf.get("rope_2D", False), - is_dit=self.cf.fe_diffusion_model, + is_dit=self.cf.get("fe_diffusion_model", False), dit_is_cond=self.cf.get("fe_diffusion_model_conditioning_type", None) == "ada_ln", ) ) @@ -610,12 +610,12 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), with_2d_rope=self.cf.get("rope_2D", False), - is_dit=self.cf.fe_diffusion_model, + is_dit=self.cf.get("fe_diffusion_model", False), dit_is_cond=self.cf.get("fe_diffusion_model_conditioning_type", None) == "ada_ln", ) ) # Add cross-attention block (Q=noised tokens, KV=enc(X_t)) for cross_attn conditioning - if self.cf.get("fe_diffusion_model_conditioning_type") == "cross_attn": + if self.cf.get("fe_diffusion_model_conditioning_type", None) == "cross_attn": self.fe_blocks.append( MultiCrossAttentionHead( dim_embed_q=self.cf.ae_global_dim_embed, @@ -629,7 +629,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = qk_norm_type=self.cf.get("qk_norm_type", self.cf.norm_type), norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), - is_dit=self.cf.fe_diffusion_model, + is_dit=self.cf.get("fe_diffusion_model", False), ) ) # Add MLP block @@ -644,7 +644,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = norm_type=self.cf.norm_type, dim_aux=dim_aux, norm_eps=self.cf.mlp_norm_eps, - is_dit=self.cf.fe_diffusion_model, + is_dit=self.cf.get("fe_diffusion_model", False), dit_is_cond=self.cf.get("fe_diffusion_model_conditioning_type", None) == "ada_ln", ) ) @@ -685,7 +685,7 @@ def forward( if forecast_residual: tokens_in = tokens - if self.cf.fe_diffusion_model: + if self.cf.get("fe_diffusion_model", False): assert noise_emb is not None, ( "noise_emb must be provided for diffusion model conditioning" ) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 602c377c7..20590e5bf 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -746,7 +746,6 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # roll-out in latent space, iterate and generate output over requested output steps for step in batch.get_output_idxs(): - without_grad = p_fwd and self.training and step != max(batch.get_output_idxs()) if without_grad: # Pushforward mode: advance tokens without grad; no decoding with torch.no_grad(): @@ -791,12 +790,12 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: output = self.predict_latent( model_params, step, toks, batch, output, out_step=i ) + # Feed the final denoised state back as conditioning for the next step. + # Pass tokens[-1] forward so inference diagnostics have a reference point; + # inference_forward always starts from pure noise regardless. + batch.samples[0].meta_info["ERA5"].params["conditioning_tokens"] = tokens[-1] + tokens = None #NOTE: This is precautionary, might need to be handled differently. It should not be the same as conditioning tokens. continue - - # # In diffusion inference mode, the final denoised tokens are returned. - # tokens = tokens[-1] - # # Feed the denoised output back as conditioning for the next autoregressive step. - # batch.samples[0].meta_info["ERA5"].params["conditioning_tokens"] = tokens # decoder predictions output = self.predict_decoders(model_params, step, tokens, batch, output) diff --git a/src/weathergen/model/utils.py b/src/weathergen/model/utils.py index 7dd2060bb..865d826d3 100644 --- a/src/weathergen/model/utils.py +++ b/src/weathergen/model/utils.py @@ -49,6 +49,9 @@ def apply_fct_to_blocks(model, blocks, fct): # avoid the whole model element which has name '' if (re.fullmatch(blocks, name) is not None) and (name != ""): fct(module) + logger.info(f"Applied function {fct.__name__} to block {name}") + else: + logger.info(f"Did not apply function {fct.__name__} to block {name}") class ActivationFactory: