From 2f5d1230a66d77f90669f9cf435dbc5da0c87dd9 Mon Sep 17 00:00:00 2001 From: Wael Almikaeel Date: Mon, 18 May 2026 15:54:00 +0200 Subject: [PATCH 1/2] updating mse_ens and linting --- src/weathergen/model/engines.py | 6 +-- .../train/loss_modules/loss_functions.py | 44 +++++++++++++++++-- 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 72486da2f..5369e6c6d 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -111,9 +111,9 @@ def forward(self, batch, pe_embed): # if the assert is hit, max_number_tokens_local_per_cell in config needs to be increased max_tokens = self.cf.get("ae_local_max_tokens_per_cell", 64) - assert ( - batch.tokens_lens.flatten(0, 2).sum(0).max() <= max_tokens - ), "max number of tokens per cell for positional encoding exceeded." + assert batch.tokens_lens.flatten(0, 2).sum(0).max() <= max_tokens, ( + "max number of tokens per cell for positional encoding exceeded." + ) " Increase ae_local_max_tokens_per_cell in config." if batch.tokens_lens.shape[2] == 1: diff --git a/src/weathergen/train/loss_modules/loss_functions.py b/src/weathergen/train/loss_modules/loss_functions.py index f9f173fcf..66d8925f4 100644 --- a/src/weathergen/train/loss_modules/loss_functions.py +++ b/src/weathergen/train/loss_modules/loss_functions.py @@ -62,9 +62,47 @@ def stats_normalized_erf(target, ens, mu, stddev): return torch.mean(d * d) # + torch.mean( torch.sqrt( stddev) ) -def mse_ens(target, ens, mu, stddev): - mse_loss = torch.nn.functional.mse_loss - return torch.stack([mse_loss(target, mem) for mem in ens], 0).mean() +def mse_ens( + target: torch.Tensor, + pred: torch.Tensor, + weights_channels: torch.Tensor | None, + weights_points: torch.Tensor | None, + use_ensemble_mean: bool = False, +): + """ + MSE loss for ensemble predictions, with two modes: + + use_ensemble_mean=False (default): + Mean of per-member MSE — equivalent to mean(mse(target, mem) for mem in ens). + Penalises every member independently; each member is pushed toward the target. + + use_ensemble_mean=True: + MSE of the ensemble mean against the target. + Collapses the ensemble to a single prediction before comparing, which + ignores spread and rewards a well-calibrated ensemble mean. + + target : shape (num_data_points, num_channels) + pred : shape (ens_dim, num_data_points, num_channels) + weights_channels : shape (num_channels,) or None + weights_points : shape (num_data_points,) or None + """ + mask_nan = ~torch.isnan(target) + t = torch.where(mask_nan, target, torch.zeros_like(target)) + p = torch.where(mask_nan.unsqueeze(0), pred, torch.zeros_like(pred)) + + if use_ensemble_mean: + # MSE( target, mean_over_members(pred) ) + diff_sq = (t - p.mean(0)).pow(2) # [num_data_points, num_channels] + else: + # mean_over_members( MSE(target, member) ) + diff_sq = (t.unsqueeze(0) - p).pow(2).mean(0) # [num_data_points, num_channels] + + if weights_points is not None: + diff_sq = (diff_sq.transpose(1, 0) * weights_points).transpose(1, 0) + + loss_chs = diff_sq.mean(0) # [num_channels] + loss = torch.mean(loss_chs * weights_channels if weights_channels is not None else loss_chs) + return loss, loss_chs def kernel_crps( From 8d4de160a8ef76dd517cb02661f362c39a861e2b Mon Sep 17 00:00:00 2001 From: Wael Almikaeel Date: Mon, 18 May 2026 20:20:11 +0200 Subject: [PATCH 2/2] updating mse_ens to use mse functionality --- .../train/loss_modules/loss_functions.py | 22 ++++++------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/src/weathergen/train/loss_modules/loss_functions.py b/src/weathergen/train/loss_modules/loss_functions.py index 66d8925f4..0d741bc28 100644 --- a/src/weathergen/train/loss_modules/loss_functions.py +++ b/src/weathergen/train/loss_modules/loss_functions.py @@ -86,23 +86,15 @@ def mse_ens( weights_channels : shape (num_channels,) or None weights_points : shape (num_data_points,) or None """ - mask_nan = ~torch.isnan(target) - t = torch.where(mask_nan, target, torch.zeros_like(target)) - p = torch.where(mask_nan.unsqueeze(0), pred, torch.zeros_like(pred)) - if use_ensemble_mean: - # MSE( target, mean_over_members(pred) ) - diff_sq = (t - p.mean(0)).pow(2) # [num_data_points, num_channels] - else: - # mean_over_members( MSE(target, member) ) - diff_sq = (t.unsqueeze(0) - p).pow(2).mean(0) # [num_data_points, num_channels] - - if weights_points is not None: - diff_sq = (diff_sq.transpose(1, 0) * weights_points).transpose(1, 0) + # lp_loss collapses the ensemble via .mean(0) before computing MSE + return mse(target, pred, weights_channels, weights_points) - loss_chs = diff_sq.mean(0) # [num_channels] - loss = torch.mean(loss_chs * weights_channels if weights_channels is not None else loss_chs) - return loss, loss_chs + losses, losses_chs = zip( + *[mse(target, member.unsqueeze(0), weights_channels, weights_points) for member in pred], + strict=False, + ) + return torch.stack(list(losses)).mean(), torch.stack(list(losses_chs)).mean(0) def kernel_crps(