From 0bf0662e334e88a1f701027464aa72f109d1b82b Mon Sep 17 00:00:00 2001 From: Tharmeekan Date: Mon, 18 May 2026 13:56:49 +0200 Subject: [PATCH 1/8] Refactor EncoderModule and ModelParams --- src/weathergen/model/encoder.py | 143 +++++++++++++++++++++++++++--- src/weathergen/model/model.py | 151 ++------------------------------ 2 files changed, 142 insertions(+), 152 deletions(-) diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index 54409e297..3f301235c 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -10,6 +10,8 @@ import torch from astropy_healpix import healpy from torch.utils.checkpoint import checkpoint +import math +import numpy as np from weathergen.common.config import Config from weathergen.datasets.batch import ModelBatch @@ -25,6 +27,8 @@ # from weathergen.model.model import ModelParams from weathergen.model.parametrised_prob_dist import LatentInterpolator from weathergen.model.positional_encoding import positional_encoding_harmonic +from weathergen.utils.utils import get_dtype +from weathergen.datasets.utils import healpix_verts_rots, r3tos2 class EncoderModule(torch.nn.Module): @@ -44,6 +48,56 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord self.healpix_level = cf.healpix_level self.num_healpix_cells = 12 * 4**self.healpix_level + self.dtype = get_dtype(cf.attention_dtype) + + # Positional embeddings + self.max_tokens_local_per_cell = cf.get("ae_local_max_tokens_per_cell", 64) + self.pe_embed = torch.nn.Parameter( + torch.zeros(self.max_tokens_local_per_cell, cf.ae_local_dim_embed, dtype=self.dtype), + requires_grad=False, + ) + + self.q_cells_lens = torch.nn.Parameter( + torch.ones(self.num_healpix_cells + 1, dtype=torch.int32), requires_grad=False + ) + self.q_cells_lens.data[0] = 0 + + pe = torch.zeros( + self.num_healpix_cells, + cf.ae_local_num_queries, + cf.ae_global_dim_embed, + dtype=self.dtype, + ) + self.pe_global = torch.nn.Parameter(pe, requires_grad=False) + + # RoPE coordinates + self.rope_2D = cf.get("rope_2D", False) + if self.rope_2D: + self.num_extra_tokens = cf.num_register_tokens + cf.num_class_tokens + total_tokens = ( + self.num_healpix_cells + self.num_extra_tokens + ) * cf.ae_local_num_queries + self.register_buffer( + "rope_coords", + torch.zeros( + 1, + total_tokens, + 2, + dtype=self.dtype, + ), + ) + self.register_buffer( + "rope_cell_coords", + torch.zeros( + self.num_healpix_cells, + 2, + dtype=self.dtype, + ), + ) + else: + self.rope_coords = None + self.rope_cell_coords = None + self.cf = cf self.sources_size = sources_size self.targets_num_channels = targets_num_channels @@ -117,29 +171,98 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord # global assimilation engine self.ae_global_engine = GlobalAssimilationEngine(cf, self.num_healpix_cells) - def forward(self, model_params, batch): + def reset_parameters(self) -> None: + """Creates positional embedding for each grid point for each stream used after stream + embedding, positional embedding for all stream assimilated cell-level local embedding, + initializing queries for local-to-global adapters, HEALPix neighbourhood based parameter + initializing for target prediction. + + Sinusoidal positional encoding: Harmonic positional encoding based upon sine and cosine for + both per stream after stream embedding and per cell level for local assimilation. + + Query len based parameter creation: Calculate parameters for the calculated token length at + each cell after local assimilation.""" + + cf = self.cf + + dim_embed = cf.ae_local_dim_embed + token_idx_bias = 16 + freq_bias = 8 + self.pe_embed.data.fill_(0.0) + position = torch.arange( + token_idx_bias, + token_idx_bias + self.max_tokens_local_per_cell, + device=self.pe_embed.device, + ).unsqueeze(1) + div = torch.exp( + torch.arange(freq_bias, freq_bias + dim_embed, 2, device=self.pe_embed.device) + * -(math.log(self.max_tokens_local_per_cell) / dim_embed), + ) + self.pe_embed.data[:, 0::2] = torch.sin(position * div[: self.pe_embed[:, 0::2].shape[1]]) + self.pe_embed.data[:, 1::2] = torch.cos(position * div[: self.pe_embed[:, 1::2].shape[1]]) + + dim_embed = cf.ae_global_dim_embed + + if self.rope_2D: + verts, _ = healpix_verts_rots(self.healpix_level, 0.5, 0.5) + coords = r3tos2(verts.to(self.rope_coords.device)).to(self.rope_coords.dtype) + self.rope_cell_coords.data.copy_(coords) + coords = coords.unsqueeze(1).repeat(1, cf.ae_local_num_queries, 1) + coords_flat = coords.flatten(0, 1).unsqueeze(0) + num_extra_tokens = cf.num_register_tokens + cf.num_class_tokens + offset = num_extra_tokens * cf.ae_local_num_queries + self.rope_coords.data.fill_(0.0) + self.rope_coords.data[:, offset : offset + coords_flat.shape[1], :].copy_(coords_flat) + + self.pe_global.data.fill_(0.0) + xs = 2.0 * np.pi * torch.arange(0, dim_embed, 2, device=self.pe_global.device) / dim_embed + self.pe_global.data[..., 0::2] = 0.5 * torch.sin( + torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs) + ) + self.pe_global.data[..., 0::2] += ( + torch.sin( + torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs) + ) + .unsqueeze(1) + .repeat((1, cf.ae_local_num_queries, 1)) + ) + self.pe_global.data[..., 1::2] = 0.5 * torch.cos( + torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs) + ) + self.pe_global.data[..., 1::2] += ( + torch.cos( + torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs) + ) + .unsqueeze(1) + .repeat((1, cf.ae_local_num_queries, 1)) + ) + + self.q_cells_lens.data.fill_(1) + self.q_cells_lens.data[0] = 0 + + def forward(self, batch): """ Encoder forward """ stream_cell_tokens = checkpoint( - self.embed_engine, batch, model_params.pe_embed, use_reentrant=False + self.embed_engine, batch, self.pe_embed, use_reentrant=False ) tokens_global, posteriors = checkpoint( - self.assimilate_local, model_params, stream_cell_tokens, batch, use_reentrant=False + self.assimilate_local, stream_cell_tokens, batch, use_reentrant=False ) tokens_global = checkpoint( self.ae_global_engine, tokens_global, - coords=model_params.rope_coords, + coords=self.rope_coords, use_reentrant=False, ) return tokens_global, posteriors - def interpolate_latents(self, tokens: torch.Tensor) -> (torch.Tensor, torch.Tensor): + def interpolate_latents(self, tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ " TODO """ @@ -273,7 +396,7 @@ def aggregation_engine_unmasked( return tokens_global_unmasked def assimilate_local( - self, model_params, tokens: torch.Tensor, batch: ModelBatch + self, tokens: torch.Tensor, batch: ModelBatch ) -> torch.Tensor: """ Processes embedded tokens locally and prepares them for the global assimilation @@ -299,15 +422,15 @@ def assimilate_local( # TODO: re-enable or remove ae_local_queries_per_cell if self.cf.ae_local_queries_per_cell: - tokens_global = (self.q_cells + model_params.pe_global).repeat(rs, 1, 1) + tokens_global = (self.q_cells + self.pe_global).repeat(rs, 1, 1) else: num_tokens = self.num_healpix_cells - tokens_global = self.q_cells.repeat(num_tokens, 1, 1) + model_params.pe_global + tokens_global = self.q_cells.repeat(num_tokens, 1, 1) + self.pe_global tokens_global = tokens_global.repeat(rs, 1, 1) # apply local assimilation engine and project onto global latent vectors tokens_global_unmasked, posteriors = self.assimilate_local_project_chunked( - tokens, tokens_global, cell_lens, model_params.q_cells_lens + tokens, tokens_global, cell_lens, self.q_cells_lens ) # apply aggregation engine on unmasked tokens @@ -315,7 +438,7 @@ def assimilate_local( tokens_global_unmasked, tokens_global_register_class, batch.tokens_lens, - rope_cell_coords=model_params.rope_cell_coords, + rope_cell_coords=self.rope_cell_coords, ) # final processing diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index ce046d3b3..9e6e6d3d7 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -10,11 +10,9 @@ # nor does it submit to any jurisdiction. import logging -import math import warnings import astropy_healpix as hp -import astropy_healpix.healpy import numpy as np import torch import torch.nn as nn @@ -22,7 +20,6 @@ from weathergen.common.config import Config from weathergen.datasets.batch import ModelBatch -from weathergen.datasets.utils import healpix_verts_rots, r3tos2 from weathergen.model.encoder import EncoderModule from weathergen.model.engines import ( BilinearDecoder, @@ -91,50 +88,6 @@ def __init__(self, cf) -> None: self.healpix_level = cf.healpix_level self.num_healpix_cells = 12 * 4**cf.healpix_level - self.dtype = get_dtype(cf.attention_dtype) - - # Positional embeddings - self.max_tokens_local_per_cell = cf.get("ae_local_max_tokens_per_cell", 64) - self.pe_embed = torch.nn.Parameter( - torch.zeros(self.max_tokens_local_per_cell, cf.ae_local_dim_embed, dtype=self.dtype), - requires_grad=False, - ) - - pe = torch.zeros( - self.num_healpix_cells, - cf.ae_local_num_queries, - cf.ae_global_dim_embed, - dtype=self.dtype, - ) - self.pe_global = torch.nn.Parameter(pe, requires_grad=False) - - # RoPE coordinates - self.rope_2D = cf.get("rope_2D", False) - if self.rope_2D: - self.num_extra_tokens = cf.num_register_tokens + cf.num_class_tokens - total_tokens = ( - self.num_healpix_cells + self.num_extra_tokens - ) * cf.ae_local_num_queries - self.register_buffer( - "rope_coords", - torch.zeros( - 1, - total_tokens, - 2, - dtype=self.dtype, - ), - ) - self.register_buffer( - "rope_cell_coords", - torch.zeros( - self.num_healpix_cells, - 2, - dtype=self.dtype, - ), - ) - else: - self.rope_coords = None - self.rope_cell_coords = None # HEALPix neighbours hlc = self.healpix_level @@ -150,97 +103,16 @@ def __init__(self, cf) -> None: requires_grad=False, ) - self.q_cells_lens = torch.nn.Parameter( - torch.ones(self.num_healpix_cells + 1, dtype=torch.int32), requires_grad=False - ) - self.q_cells_lens.data[0] = 0 - def create(self, cf: Config) -> "ModelParams": - self.reset_parameters(cf) + self.reset_parameters() return self - def reset_parameters(self, cf: Config) -> "ModelParams": - """Creates positional embedding for each grid point for each stream used after stream - embedding, positional embedding for all stream assimilated cell-level local embedding, - initializing queries for local-to-global adapters, HEALPix neighbourhood based parameter - initializing for target prediction. - - Sinusoidal positional encoding: Harmonic positional encoding based upon sine and cosine for - both per stream after stream embedding and per cell level for local assimilation. - - HEALPix neighbourhood structure: Determine the neighbors for each cell and initialize each + def reset_parameters(self) -> "ModelParams": + """HEALPix neighbourhood structure: Determine the neighbors for each cell and initialize each with its own cell number as well as the cell numbers of its neighbors. If a cell has fewer than eight neighbors, use its own cell number to fill the remaining slots. - Query len based parameter creation: Calculate parameters for the calculated token length at - each cell after local assimilation. - - Args: - cf : Configuration """ - - # positional encodings - - dim_embed = cf.ae_local_dim_embed - token_idx_bias = 16 - freq_bias = 8 - self.pe_embed.data.fill_(0.0) - position = torch.arange( - token_idx_bias, - token_idx_bias + self.max_tokens_local_per_cell, - device=self.pe_embed.device, - ).unsqueeze(1) - div = torch.exp( - torch.arange(freq_bias, freq_bias + dim_embed, 2, device=self.pe_embed.device) - * -(math.log(self.max_tokens_local_per_cell) / dim_embed), - ) - self.pe_embed.data[:, 0::2] = torch.sin(position * div[: self.pe_embed[:, 0::2].shape[1]]) - self.pe_embed.data[:, 1::2] = torch.cos(position * div[: self.pe_embed[:, 1::2].shape[1]]) - - dim_embed = cf.ae_global_dim_embed - - if self.rope_2D: - # Precompute per-cell center coordinates (lat, lon in radians) for 2D RoPE. - # Shape: (num_healpix_cells, ae_local_num_queries, 2) - verts, _ = healpix_verts_rots(self.healpix_level, 0.5, 0.5) - coords = r3tos2(verts.to(self.rope_coords.device)).to(self.rope_coords.dtype) - # Per-cell coords for QueryAggregationEngine (no query expansion) - self.rope_cell_coords.data.copy_(coords) - coords = coords.unsqueeze(1).repeat(1, cf.ae_local_num_queries, 1) - coords_flat = coords.flatten(0, 1).unsqueeze(0) - offset = self.num_extra_tokens * cf.ae_local_num_queries - self.rope_coords.data.fill_(0.0) - self.rope_coords.data[:, offset : offset + coords_flat.shape[1], :].copy_(coords_flat) - - # pe_global: always initialized. RoPE handles relative position in Q/K, but pe_global - # provides per-cell token identity which is critical for masked cells that have no - # content from local assimilation. Without it, masked cells are identical and the - # teacher representation (evaluated without dropout) collapses to low rank. - self.pe_global.data.fill_(0.0) - xs = 2.0 * np.pi * torch.arange(0, dim_embed, 2, device=self.pe_global.device) / dim_embed - self.pe_global.data[..., 0::2] = 0.5 * torch.sin( - torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs) - ) - self.pe_global.data[..., 0::2] += ( - torch.sin( - torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs) - ) - .unsqueeze(1) - .repeat((1, cf.ae_local_num_queries, 1)) - ) - self.pe_global.data[..., 1::2] = 0.5 * torch.cos( - torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs) - ) - self.pe_global.data[..., 1::2] += ( - torch.cos( - torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs) - ) - .unsqueeze(1) - .repeat((1, cf.ae_local_num_queries, 1)) - ) - - # healpix neighborhood structure - hlc = self.healpix_level num_healpix_cells = self.num_healpix_cells with warnings.catch_warnings(action="ignore"): @@ -252,12 +124,6 @@ def reset_parameters(self, cf: Config) -> "ModelParams": self.hp_nbours.data[:, 0] = torch.arange(temp.shape[0], device=self.hp_nbours.device) self.hp_nbours.data[:, 1:] = torch.from_numpy(temp).to(self.hp_nbours.device) - # precompute for varlen attention - self.q_cells_lens.data.fill_(1) - self.q_cells_lens.data[0] = 0 - - # ensure all params have grad set to False - return @@ -599,6 +465,8 @@ def _reset_params(module): pass self.apply(_reset_params) + if self.encoder is not None: + self.encoder.reset_parameters() def print_num_parameters(self) -> None: """Print number of parameters for entire model and each module used to build the model""" @@ -695,7 +563,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: output = ModelOutput(batch.get_output_len()) - tokens, posteriors = self.encoder(model_params, batch) + tokens, posteriors = self.encoder(batch) output.add_latent_prediction(0, "posteriors", posteriors) # recover batch dimension and separate input_steps @@ -710,20 +578,19 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: without_grad = p_fwd and self.training and step != max(batch.get_output_idxs()) if without_grad: # Pushforward mode: advance tokens without grad; no decoding with torch.no_grad(): - tokens = self.forecast_engine(tokens, step, model_params.rope_coords) + tokens = self.forecast_engine(tokens, step, self.encoder.rope_coords) continue - tokens = self.forecast_engine(tokens, step, model_params.rope_coords) + tokens = self.forecast_engine(tokens, step, self.encoder.rope_coords) # decoder predictions output = self.predict_decoders(model_params, step, tokens, batch, output) # latent predictions (raw and with SSL heads) - output = self.predict_latent(model_params, step, tokens, batch, output) + output = self.predict_latent(step, tokens, batch, output) return output def predict_latent( self, - model_params: ModelParams, step: int, tokens: torch.Tensor, batch: ModelBatch, From a95bc9e9fc63f3b8a367fd1f11509c75880a32a8 Mon Sep 17 00:00:00 2001 From: Tharmeekan Date: Mon, 18 May 2026 22:40:40 +0200 Subject: [PATCH 2/8] Add ROPE coords to ForecastingEngine and update reset_parameters method --- src/weathergen/model/engines.py | 43 +++++++++++++++++++++++++++++++-- src/weathergen/model/model.py | 6 +++-- 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 72486da2f..59b3d514b 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -31,6 +31,7 @@ from weathergen.model.layers import MLP from weathergen.model.utils import ActivationFactory from weathergen.utils.utils import get_dtype +from weathergen.datasets.utils import healpix_verts_rots, r3tos2 class EmbeddingEngine(torch.nn.Module): @@ -555,6 +556,27 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = self.cf = cf self.num_healpix_cells = num_healpix_cells self.fe_blocks = torch.nn.ModuleList() + self.rope_2D = cf.get("rope_2D", False) + self.healpix_level = cf.healpix_level + self.dtype = get_dtype(cf.attention_dtype) + + + if self.rope_2D: + num_extra_tokens = cf.num_register_tokens + cf.num_class_tokens + total_tokens = ( + self.num_healpix_cells + num_extra_tokens + ) * cf.ae_local_num_queries + self.register_buffer( + "rope_coords", + torch.zeros( + 1, + total_tokens, + 2, + dtype=self.dtype + ), + ) + else: + self.rope_coords = None global_rate = int(1 / self.cf.forecast_att_dense_rate) if mode_cfg.get("forecast", {}).get("policy") is not None: @@ -621,7 +643,24 @@ def init_weights_final(m): for block in self.fe_blocks: block.apply(init_weights_final) - def forward(self, tokens, fstep, coords=None): + + def reset_parameters(self) -> None: + """HEALPix neighbourhood based parameter initializing for target prediction.""" + + cf = self.cf + + if self.rope_2D: + verts, _ = healpix_verts_rots(self.healpix_level, 0.5, 0.5) + coords = r3tos2(verts.to(self.rope_coords.device)).to(self.rope_coords.dtype) + coords = coords.unsqueeze(1).repeat(1, cf.ae_local_num_queries, 1) + coords_flat = coords.flatten(0, 1).unsqueeze(0) + num_extra_tokens = cf.num_register_tokens + cf.num_class_tokens + offset = num_extra_tokens * cf.ae_local_num_queries + self.rope_coords.data.fill_(0.0) + self.rope_coords.data[:, offset : offset + coords_flat.shape[1], :].copy_(coords_flat) + + + def forward(self, tokens, fstep): if self.training: # Impute noise to the latent state noise_std = self.cf.get("fe_impute_latent_noise_std", 0.0) @@ -633,7 +672,7 @@ def forward(self, tokens, fstep, coords=None): if isinstance(block, torch.nn.modules.normalization.LayerNorm): tokens = checkpoint(block, tokens, use_reentrant=False) else: - tokens = checkpoint(block, tokens, coords, aux_info, use_reentrant=False) + tokens = checkpoint(block, tokens, self.rope_coords, aux_info, use_reentrant=False) return tokens diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 9e6e6d3d7..40ff7b675 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -467,6 +467,8 @@ def _reset_params(module): self.apply(_reset_params) if self.encoder is not None: self.encoder.reset_parameters() + if self.forecast_engine is not None: + self.forecast_engine.reset_parameters() def print_num_parameters(self) -> None: """Print number of parameters for entire model and each module used to build the model""" @@ -578,10 +580,10 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: without_grad = p_fwd and self.training and step != max(batch.get_output_idxs()) if without_grad: # Pushforward mode: advance tokens without grad; no decoding with torch.no_grad(): - tokens = self.forecast_engine(tokens, step, self.encoder.rope_coords) + tokens = self.forecast_engine(tokens, step) continue - tokens = self.forecast_engine(tokens, step, self.encoder.rope_coords) + tokens = self.forecast_engine(tokens, step) # decoder predictions output = self.predict_decoders(model_params, step, tokens, batch, output) # latent predictions (raw and with SSL heads) From 585fe84c509072fc05cb280341c9dc31392c0960 Mon Sep 17 00:00:00 2001 From: Tharmeekan Date: Tue, 19 May 2026 11:29:15 +0200 Subject: [PATCH 3/8] clean up code formatting in encoder, engines, and model modules --- src/weathergen/model/encoder.py | 14 ++++++-------- src/weathergen/model/engines.py | 24 +++++++----------------- src/weathergen/model/model.py | 7 +++---- 3 files changed, 16 insertions(+), 29 deletions(-) diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index 3f301235c..724940c9d 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -7,14 +7,16 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import math + +import numpy as np import torch from astropy_healpix import healpy from torch.utils.checkpoint import checkpoint -import math -import numpy as np from weathergen.common.config import Config from weathergen.datasets.batch import ModelBatch +from weathergen.datasets.utils import healpix_verts_rots, r3tos2 from weathergen.model.engines import ( EmbeddingEngine, GlobalAssimilationEngine, @@ -28,7 +30,6 @@ from weathergen.model.parametrised_prob_dist import LatentInterpolator from weathergen.model.positional_encoding import positional_encoding_harmonic from weathergen.utils.utils import get_dtype -from weathergen.datasets.utils import healpix_verts_rots, r3tos2 class EncoderModule(torch.nn.Module): @@ -50,7 +51,7 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord self.dtype = get_dtype(cf.attention_dtype) - # Positional embeddings + # Positional embeddings self.max_tokens_local_per_cell = cf.get("ae_local_max_tokens_per_cell", 64) self.pe_embed = torch.nn.Parameter( torch.zeros(self.max_tokens_local_per_cell, cf.ae_local_dim_embed, dtype=self.dtype), @@ -98,7 +99,6 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord self.rope_coords = None self.rope_cell_coords = None - self.cf = cf self.sources_size = sources_size self.targets_num_channels = targets_num_channels self.targets_coords_size = targets_coords_size @@ -395,9 +395,7 @@ def aggregation_engine_unmasked( return tokens_global_unmasked - def assimilate_local( - self, tokens: torch.Tensor, batch: ModelBatch - ) -> torch.Tensor: + def assimilate_local(self, tokens: torch.Tensor, batch: ModelBatch) -> torch.Tensor: """ Processes embedded tokens locally and prepares them for the global assimilation diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 59b3d514b..00159ffdb 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -16,6 +16,7 @@ from torch.utils.checkpoint import checkpoint from weathergen.common.config import Config +from weathergen.datasets.utils import healpix_verts_rots, r3tos2 from weathergen.model.attention import ( MultiCrossAttentionHeadVarlen, MultiCrossAttentionHeadVarlenSlicedQ, @@ -31,7 +32,6 @@ from weathergen.model.layers import MLP from weathergen.model.utils import ActivationFactory from weathergen.utils.utils import get_dtype -from weathergen.datasets.utils import healpix_verts_rots, r3tos2 class EmbeddingEngine(torch.nn.Module): @@ -112,9 +112,9 @@ def forward(self, batch, pe_embed): # if the assert is hit, max_number_tokens_local_per_cell in config needs to be increased max_tokens = self.cf.get("ae_local_max_tokens_per_cell", 64) - assert ( - batch.tokens_lens.flatten(0, 2).sum(0).max() <= max_tokens - ), "max number of tokens per cell for positional encoding exceeded." + assert batch.tokens_lens.flatten(0, 2).sum(0).max() <= max_tokens, ( + "max number of tokens per cell for positional encoding exceeded." + ) " Increase ae_local_max_tokens_per_cell in config." if batch.tokens_lens.shape[2] == 1: @@ -560,20 +560,12 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = self.healpix_level = cf.healpix_level self.dtype = get_dtype(cf.attention_dtype) - if self.rope_2D: num_extra_tokens = cf.num_register_tokens + cf.num_class_tokens - total_tokens = ( - self.num_healpix_cells + num_extra_tokens - ) * cf.ae_local_num_queries + total_tokens = (self.num_healpix_cells + num_extra_tokens) * cf.ae_local_num_queries self.register_buffer( - "rope_coords", - torch.zeros( - 1, - total_tokens, - 2, - dtype=self.dtype - ), + "rope_coords", + torch.zeros(1, total_tokens, 2, dtype=self.dtype), ) else: self.rope_coords = None @@ -643,7 +635,6 @@ def init_weights_final(m): for block in self.fe_blocks: block.apply(init_weights_final) - def reset_parameters(self) -> None: """HEALPix neighbourhood based parameter initializing for target prediction.""" @@ -659,7 +650,6 @@ def reset_parameters(self) -> None: self.rope_coords.data.fill_(0.0) self.rope_coords.data[:, offset : offset + coords_flat.shape[1], :].copy_(coords_flat) - def forward(self, tokens, fstep): if self.training: # Impute noise to the latent state diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 40ff7b675..569f8e5c1 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -108,10 +108,9 @@ def create(self, cf: Config) -> "ModelParams": return self def reset_parameters(self) -> "ModelParams": - """HEALPix neighbourhood structure: Determine the neighbors for each cell and initialize each - with its own cell number as well as the cell numbers of its neighbors. If a cell has - fewer than eight neighbors, use its own cell number to fill the remaining slots. - + """HEALPix neighbourhood structure: Determine the neighbors for each cell and initialize + each with its own cell number as well as the cell numbers of its neighbors. If a cell has + fewer than eight neighbors, use its own cell number to fill the remaining slots. """ hlc = self.healpix_level num_healpix_cells = self.num_healpix_cells From 77e733bfc6d994d8a8036c072135afe02bfbfc70 Mon Sep 17 00:00:00 2001 From: Tharmeekan Date: Tue, 19 May 2026 15:44:24 +0200 Subject: [PATCH 4/8] Bug fixes --- src/weathergen/model/encoder.py | 4 ++-- src/weathergen/model/engines.py | 12 +++++------- src/weathergen/model/model_interface.py | 2 +- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index 5b5044b61..f949304f2 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -317,7 +317,7 @@ def forward(self, batch): coords=( self.rope_spherical_coeffs.unbind(dim=-1) if self.rope_spherical_coeffs is not None - else self.self.rope_coords + else self.rope_coords ), use_reentrant=False, ) @@ -512,7 +512,7 @@ def assimilate_local(self, tokens: torch.Tensor, batch: ModelBatch) -> torch.Ten tokens_global_register_class, batch.tokens_lens, rope_cell_coords=self.rope_cell_coords, - rope_cell_coeffs=self.rope_spherical_cells_coeffs, + rope_cell_coeffs=self.rope_spherical_cell_coeffs, rope_extra_coeffs=self.rope_spherical_extra_coeffs, ) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 6469f43d9..c25fc34e5 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -565,7 +565,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = rope_mode = get_rope_mode(self.cf) self.fe_blocks = torch.nn.ModuleList() self.rope_2D = cf.get("rope_2D", False) - self.healpix_level = cf.healpix_level + self.healpix_level = cf.get("fe_healpix_level") if cf.get("fe_healpix_level") is not None else cf.get("healpix_level") self.dtype = get_dtype(cf.attention_dtype) # RoPE coordinates @@ -743,12 +743,10 @@ def forward(self, tokens, fstep): tokens = checkpoint( block, tokens, - coords=( - self.rope_spherical_coeffs.unbind(dim=-1) - if self.rope_spherical_coeffs is not None - else self.rope_coords - ), - aux_info=aux_info, + self.rope_spherical_coeffs.unbind(dim=-1) + if self.rope_spherical_coeffs is not None + else self.rope_coords, + aux_info, use_reentrant=False, ) return tokens diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index 1423652f7..420a32abf 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -164,7 +164,7 @@ def init_model_and_shard( # model params model_params = ModelParams(cf).create(cf) - model_params.reset_parameters(cf) + model_params.reset_parameters() model_params = model_params.to(f"cuda:{cf.local_rank}") return model, model_params From 11a6fb99d3344d9113d332755793e846aca690c5 Mon Sep 17 00:00:00 2001 From: Tharmeekan Date: Tue, 19 May 2026 16:39:18 +0200 Subject: [PATCH 5/8] Fix FSDP2 sharding bug by converting non-trainable EncoderModule tensors to buffers --- src/weathergen/model/encoder.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index f949304f2..e6b976ad1 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -60,23 +60,21 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord # Positional embeddings self.max_tokens_local_per_cell = cf.get("ae_local_max_tokens_per_cell", 64) - self.pe_embed = torch.nn.Parameter( + self.register_buffer( + "pe_embed", torch.zeros(self.max_tokens_local_per_cell, cf.ae_local_dim_embed, dtype=self.dtype), - requires_grad=False, ) - self.q_cells_lens = torch.nn.Parameter( - torch.ones(self.num_healpix_cells + 1, dtype=torch.int32), requires_grad=False + self.register_buffer( + "q_cells_lens", torch.ones(self.num_healpix_cells + 1, dtype=torch.int32) ) - self.q_cells_lens.data[0] = 0 + self.q_cells_lens[0] = 0 - pe = torch.zeros( - self.num_healpix_cells, - cf.ae_local_num_queries, - cf.ae_global_dim_embed, - dtype=self.dtype, + self.register_buffer( + "pe_global", + torch.zeros(self.num_healpix_cells, cf.ae_local_num_queries, cf.ae_global_dim_embed, dtype=self.dtype), ) - self.pe_global = torch.nn.Parameter(pe, requires_grad=False) + # RoPE coordinates self.rope_mode = get_rope_mode(cf, logger) From 1e911273de3a15aef2d41e869c487625b25c2b2c Mon Sep 17 00:00:00 2001 From: Tharmeekan Date: Wed, 20 May 2026 12:30:58 +0200 Subject: [PATCH 6/8] Add pipeline configuration for testing; update Trainer to handle plain vectors/buffers like pe_embed --- config/pipeline.yml | 37 +++++++++++++++++++++++++++++++++ src/weathergen/train/trainer.py | 2 +- 2 files changed, 38 insertions(+), 1 deletion(-) create mode 100644 config/pipeline.yml diff --git a/config/pipeline.yml b/config/pipeline.yml new file mode 100644 index 000000000..ad25e58e3 --- /dev/null +++ b/config/pipeline.yml @@ -0,0 +1,37 @@ +stages: + - name: pretrain + stage: train + config_files: + - config/config_forecasting.yml + options: + - training_config.num_mini_epochs=32 + chain_jobs: 4 + nodes: 2 + slurm_args: + - "--time=10:00:00" + + - name: finetune + stage: train + from_run_id: STAGE.pretrain + config_files: + - config/config_forecasting_finetuning.yml + chain_jobs: 2 + nodes: 2 + slurm_args: + - "--time=10:00:00" + + - name: inference + stage: inference + from_run_id: STAGE.pretrain + options: + - training_config.forecast.num_steps=120 + - test_config.output.num_samples=10 + - test_config.start_date=202310010000 + - test_config.end_date=202312300000 + - test_config.samples_per_mini_epoch=128 + chain_jobs: 2 + nodes: 1 + slurm_args: + - "--time=10:00:00" + + diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index f8cb9cafe..b222b92dc 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -652,7 +652,7 @@ def _get_full_model_state_dict(self): if self.cf.with_ddp and self.cf.with_fsdp: cpu_state_dict = {} for param_name, sharded_param in maybe_sharded_sd.items(): - full_param = sharded_param.full_tensor() + full_param = sharded_param.full_tensor() if isinstance(sharded_param, DTensor) else sharded_param if is_root(): cpu_state_dict[param_name] = full_param.cpu() else: From 987929eca2c9e5f381d5f09c7db4d459efbd8d6a Mon Sep 17 00:00:00 2001 From: Tharmeekan Date: Thu, 21 May 2026 11:01:40 +0200 Subject: [PATCH 7/8] Fix load_model when plain tensor --- src/weathergen/model/model_interface.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index 420a32abf..d75268169 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -17,7 +17,7 @@ MixedPrecisionPolicy, fully_shard, ) -from torch.distributed.tensor import distribute_tensor +from torch.distributed.tensor import DTensor, distribute_tensor from weathergen.common.config import Config, get_path_model, merge_configs from weathergen.model.attention import ( @@ -196,13 +196,15 @@ def load_model(cf, model, device, run_id: str, mini_epoch=-1): if sharded_meta_param is None: logger.warning(f"Parameter {param_name} from checkpoint not found in model.") continue - sharded_tensor = distribute_tensor( - full_tensor, - sharded_meta_param.device_mesh, - sharded_meta_param.placements, - ) - # maybe_sharded_sd[param_name.replace("module.", "")] = nn.Parameter(sharded_tensor) - maybe_sharded_sd[param_name] = torch.nn.Parameter(sharded_tensor) + if isinstance(sharded_meta_param, DTensor): + sharded_tensor = distribute_tensor( + full_tensor, + sharded_meta_param.device_mesh, + sharded_meta_param.placements, + ) + maybe_sharded_sd[param_name] = torch.nn.Parameter(sharded_tensor) + else: + maybe_sharded_sd[param_name] = full_tensor # choose `assign=True` for sharded model since we cannot call `copy_` on meta tensor mkeys, ukeys = model.load_state_dict(maybe_sharded_sd, strict=False, assign=True) From 152cc6dd6c05660f7dd39c5a0d46a5f71fae58d0 Mon Sep 17 00:00:00 2001 From: Tharmeekan Date: Tue, 26 May 2026 11:11:52 +0200 Subject: [PATCH 8/8] Fix load_model to transfer full_tensor --- src/weathergen/model/model_interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index d75268169..ce469f2a4 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -204,7 +204,7 @@ def load_model(cf, model, device, run_id: str, mini_epoch=-1): ) maybe_sharded_sd[param_name] = torch.nn.Parameter(sharded_tensor) else: - maybe_sharded_sd[param_name] = full_tensor + maybe_sharded_sd[param_name] = full_tensor.to(device) # choose `assign=True` for sharded model since we cannot call `copy_` on meta tensor mkeys, ukeys = model.load_state_dict(maybe_sharded_sd, strict=False, assign=True)