diff --git a/config/pipeline.yml b/config/pipeline.yml new file mode 100644 index 000000000..ad25e58e3 --- /dev/null +++ b/config/pipeline.yml @@ -0,0 +1,37 @@ +stages: + - name: pretrain + stage: train + config_files: + - config/config_forecasting.yml + options: + - training_config.num_mini_epochs=32 + chain_jobs: 4 + nodes: 2 + slurm_args: + - "--time=10:00:00" + + - name: finetune + stage: train + from_run_id: STAGE.pretrain + config_files: + - config/config_forecasting_finetuning.yml + chain_jobs: 2 + nodes: 2 + slurm_args: + - "--time=10:00:00" + + - name: inference + stage: inference + from_run_id: STAGE.pretrain + options: + - training_config.forecast.num_steps=120 + - test_config.output.num_samples=10 + - test_config.start_date=202310010000 + - test_config.end_date=202312300000 + - test_config.samples_per_mini_epoch=128 + chain_jobs: 2 + nodes: 1 + slurm_args: + - "--time=10:00:00" + + diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index b18791aa5..fb4250d98 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -102,9 +102,7 @@ def forward(self, x, x_lens, ada_ln_aux=None, coords=None): ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype) vs = self.proj_heads_v(x).reshape(s) - qs, ks = apply_rope( - qs, ks, coords, self.rope_mode, 1 - ) + qs, ks = apply_rope(qs, ks, coords, self.rope_mode, 1) if self.rope_post_mod_qk_lnorm: qs = self.post_rope_lnorm_q(qs).to(self.dtype) ks = self.post_rope_lnorm_k(ks).to(self.dtype) @@ -302,9 +300,7 @@ def forward(self, x, coords=None, ada_ln_aux=None): ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype).permute([0, 2, 1, 3]) vs = self.proj_heads_v(x).reshape(s).permute([0, 2, 1, 3]) - qs, ks = apply_rope( - qs, ks, coords, self.rope_mode, 1 - ) + qs, ks = apply_rope(qs, ks, coords, self.rope_mode, 1) if self.rope_post_mod_qk_lnorm: qs = self.post_rope_lnorm_q(qs).to(self.dtype) ks = self.post_rope_lnorm_k(ks).to(self.dtype) @@ -621,9 +617,7 @@ def forward(self, x, coords=None, ada_ln_aux=None): ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype) vs = self.proj_heads_v(x).reshape(s).to(self.dtype) - qs, ks = apply_rope( - qs, ks, coords, self.rope_mode, 2 - ) + qs, ks = apply_rope(qs, ks, coords, self.rope_mode, 2) if self.rope_post_mod_qk_lnorm: qs = self.post_rope_lnorm_q(qs).to(self.dtype) ks = self.post_rope_lnorm_k(ks).to(self.dtype) diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index 72340d142..e6b976ad1 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -6,13 +6,17 @@ # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import logging +import math +import numpy as np import torch from astropy_healpix import healpy from torch.utils.checkpoint import checkpoint from weathergen.common.config import Config from weathergen.datasets.batch import ModelBatch +from weathergen.datasets.utils import healpix_verts_rots, r3tos2 from weathergen.model.engines import ( EmbeddingEngine, GlobalAssimilationEngine, @@ -24,7 +28,15 @@ # from weathergen.model.model import ModelParams from weathergen.model.parametrised_prob_dist import LatentInterpolator -from weathergen.model.positional_encoding import positional_encoding_harmonic +from weathergen.model.positional_encoding import ( + build_spherical_rope_coeff_tensors, + get_rope_mode, + get_rope_spherical_band, + positional_encoding_harmonic, +) +from weathergen.utils.utils import get_dtype + +logger = logging.getLogger(__name__) class EncoderModule(torch.nn.Module): @@ -44,7 +56,76 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord self.healpix_level = cf.healpix_level self.num_healpix_cells = 12 * 4**self.healpix_level - self.cf = cf + self.dtype = get_dtype(cf.attention_dtype) + + # Positional embeddings + self.max_tokens_local_per_cell = cf.get("ae_local_max_tokens_per_cell", 64) + self.register_buffer( + "pe_embed", + torch.zeros(self.max_tokens_local_per_cell, cf.ae_local_dim_embed, dtype=self.dtype), + ) + + self.register_buffer( + "q_cells_lens", torch.ones(self.num_healpix_cells + 1, dtype=torch.int32) + ) + self.q_cells_lens[0] = 0 + + self.register_buffer( + "pe_global", + torch.zeros(self.num_healpix_cells, cf.ae_local_num_queries, cf.ae_global_dim_embed, dtype=self.dtype), + ) + + + # RoPE coordinates + self.rope_mode = get_rope_mode(cf, logger) + if self.rope_mode != "none": + self.num_extra_tokens = cf.num_register_tokens + cf.num_class_tokens + total_tokens = ( + self.num_healpix_cells + self.num_extra_tokens + ) * cf.ae_local_num_queries + self.register_buffer( + "rope_coords", + torch.zeros( + 1, + total_tokens, + 2, + dtype=self.dtype, + ), + ) + self.register_buffer( + "rope_cell_coords", + torch.zeros( + self.num_healpix_cells, + 2, + dtype=self.dtype, + ), + ) + if self.rope_mode == "spherical": + rope_spherical_band = get_rope_spherical_band(cf) + num_modes = 2 * int(rope_spherical_band) + 1 + self.register_buffer( + "rope_spherical_coeffs", + torch.zeros(1, total_tokens, num_modes, 2, dtype=self.dtype), + ) + self.register_buffer( + "rope_spherical_cell_coeffs", + torch.zeros(self.num_healpix_cells, num_modes, 2, dtype=self.dtype), + ) + self.register_buffer( + "rope_spherical_extra_coeffs", + torch.zeros(self.num_extra_tokens, num_modes, 2, dtype=self.dtype), + ) + else: + self.rope_spherical_coeffs = None + self.rope_spherical_cell_coeffs = None + self.rope_spherical_extra_coeffs = None + else: + self.rope_coords = None + self.rope_cell_coords = None + self.rope_spherical_coeffs = None + self.rope_spherical_cell_coeffs = None + self.rope_spherical_extra_coeffs = None + self.sources_size = sources_size self.targets_num_channels = targets_num_channels self.targets_coords_size = targets_coords_size @@ -117,33 +198,131 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord # global assimilation engine self.ae_global_engine = GlobalAssimilationEngine(cf, self.num_healpix_cells) - def forward(self, model_params, batch): + def reset_parameters(self) -> None: + """Creates positional embedding for each grid point for each stream used after stream + embedding, positional embedding for all stream assimilated cell-level local embedding, + initializing queries for local-to-global adapters, HEALPix neighbourhood based parameter + initializing for target prediction. + + Sinusoidal positional encoding: Harmonic positional encoding based upon sine and cosine for + both per stream after stream embedding and per cell level for local assimilation. + + Query len based parameter creation: Calculate parameters for the calculated token length at + each cell after local assimilation.""" + + cf = self.cf + + dim_embed = cf.ae_local_dim_embed + token_idx_bias = 16 + freq_bias = 8 + self.pe_embed.data.fill_(0.0) + position = torch.arange( + token_idx_bias, + token_idx_bias + self.max_tokens_local_per_cell, + device=self.pe_embed.device, + ).unsqueeze(1) + div = torch.exp( + torch.arange(freq_bias, freq_bias + dim_embed, 2, device=self.pe_embed.device) + * -(math.log(self.max_tokens_local_per_cell) / dim_embed), + ) + self.pe_embed.data[:, 0::2] = torch.sin(position * div[: self.pe_embed[:, 0::2].shape[1]]) + self.pe_embed.data[:, 1::2] = torch.cos(position * div[: self.pe_embed[:, 1::2].shape[1]]) + + dim_embed = cf.ae_global_dim_embed + + if self.rope_mode != "none": + verts, _ = healpix_verts_rots(self.healpix_level, 0.5, 0.5) + coords = r3tos2(verts.to(self.rope_coords.device)).to(self.rope_coords.dtype) + self.rope_cell_coords.data.copy_(coords) + coords = coords.unsqueeze(1).repeat(1, cf.ae_local_num_queries, 1) + coords_flat = coords.flatten(0, 1).unsqueeze(0) + offset = self.num_extra_tokens * cf.ae_local_num_queries + self.rope_coords.data.fill_(0.0) + self.rope_coords.data[:, offset : offset + coords_flat.shape[1], :].copy_(coords_flat) + + if self.rope_mode == "spherical": + band = int(get_rope_spherical_band(cf)) + ( + (cell_real, cell_imag), + (extra_real, extra_imag), + (packed_extra_real, packed_extra_imag), + (packed_real, packed_imag), + ) = build_spherical_rope_coeff_tensors( + nside=2**self.healpix_level, + band=band, + num_local_queries=cf.ae_local_num_queries, + num_extra_tokens=self.num_extra_tokens, + device=self.rope_spherical_coeffs.device, + dtype=self.rope_spherical_coeffs.dtype, + ) + self.rope_spherical_cell_coeffs.data[..., 0].copy_(cell_real) + self.rope_spherical_cell_coeffs.data[..., 1].copy_(cell_imag) + self.rope_spherical_extra_coeffs.data[..., 0].copy_(extra_real) + self.rope_spherical_extra_coeffs.data[..., 1].copy_(extra_imag) + + self.rope_spherical_coeffs.data.fill_(0.0) + self.rope_spherical_coeffs.data[:, :offset, :, 0].copy_(packed_extra_real) + self.rope_spherical_coeffs.data[:, :offset, :, 1].copy_(packed_extra_imag) + self.rope_spherical_coeffs.data[ + :, offset : offset + packed_real.shape[1], :, 0 + ].copy_(packed_real) + self.rope_spherical_coeffs.data[ + :, offset : offset + packed_imag.shape[1], :, 1 + ].copy_(packed_imag) + + self.pe_global.data.fill_(0.0) + xs = 2.0 * np.pi * torch.arange(0, dim_embed, 2, device=self.pe_global.device) / dim_embed + self.pe_global.data[..., 0::2] = 0.5 * torch.sin( + torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs) + ) + self.pe_global.data[..., 0::2] += ( + torch.sin( + torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs) + ) + .unsqueeze(1) + .repeat((1, cf.ae_local_num_queries, 1)) + ) + self.pe_global.data[..., 1::2] = 0.5 * torch.cos( + torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs) + ) + self.pe_global.data[..., 1::2] += ( + torch.cos( + torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs) + ) + .unsqueeze(1) + .repeat((1, cf.ae_local_num_queries, 1)) + ) + + self.q_cells_lens.data.fill_(1) + self.q_cells_lens.data[0] = 0 + + def forward(self, batch): """ Encoder forward """ stream_cell_tokens = checkpoint( - self.embed_engine, batch, model_params.pe_embed, use_reentrant=False + self.embed_engine, batch, self.pe_embed, use_reentrant=False ) tokens_global, posteriors = checkpoint( - self.assimilate_local, model_params, stream_cell_tokens, batch, use_reentrant=False + self.assimilate_local, stream_cell_tokens, batch, use_reentrant=False ) tokens_global = checkpoint( self.ae_global_engine, tokens_global, coords=( - model_params.rope_spherical_coeffs.unbind(dim=-1) - if model_params.rope_spherical_coeffs is not None - else model_params.rope_coords + self.rope_spherical_coeffs.unbind(dim=-1) + if self.rope_spherical_coeffs is not None + else self.rope_coords ), use_reentrant=False, ) return tokens_global, posteriors - def interpolate_latents(self, tokens: torch.Tensor) -> (torch.Tensor, torch.Tensor): + def interpolate_latents(self, tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ " TODO """ @@ -289,9 +468,7 @@ def aggregation_engine_unmasked( return tokens_global_unmasked - def assimilate_local( - self, model_params, tokens: torch.Tensor, batch: ModelBatch - ) -> torch.Tensor: + def assimilate_local(self, tokens: torch.Tensor, batch: ModelBatch) -> torch.Tensor: """ Processes embedded tokens locally and prepares them for the global assimilation @@ -316,15 +493,15 @@ def assimilate_local( # TODO: re-enable or remove ae_local_queries_per_cell if self.cf.ae_local_queries_per_cell: - tokens_global = (self.q_cells + model_params.pe_global).repeat(rs, 1, 1) + tokens_global = (self.q_cells + self.pe_global).repeat(rs, 1, 1) else: num_tokens = self.num_healpix_cells - tokens_global = self.q_cells.repeat(num_tokens, 1, 1) + model_params.pe_global + tokens_global = self.q_cells.repeat(num_tokens, 1, 1) + self.pe_global tokens_global = tokens_global.repeat(rs, 1, 1) # apply local assimilation engine and project onto global latent vectors tokens_global_unmasked, posteriors = self.assimilate_local_project_chunked( - tokens, tokens_global, cell_lens, model_params.q_cells_lens + tokens, tokens_global, cell_lens, self.q_cells_lens ) # apply aggregation engine on unmasked tokens @@ -332,9 +509,9 @@ def assimilate_local( tokens_global_unmasked, tokens_global_register_class, batch.tokens_lens, - rope_cell_coords=model_params.rope_cell_coords, - rope_cell_coeffs=model_params.rope_spherical_cell_coeffs, - rope_extra_coeffs=model_params.rope_spherical_extra_coeffs, + rope_cell_coords=self.rope_cell_coords, + rope_cell_coeffs=self.rope_spherical_cell_coeffs, + rope_extra_coeffs=self.rope_spherical_extra_coeffs, ) # final processing diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index e5af71d2e..c25fc34e5 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -16,6 +16,7 @@ from torch.utils.checkpoint import checkpoint from weathergen.common.config import Config +from weathergen.datasets.utils import healpix_verts_rots, r3tos2 from weathergen.model.attention import ( MultiCrossAttentionHeadVarlen, MultiCrossAttentionHeadVarlenSlicedQ, @@ -29,7 +30,11 @@ StreamEmbedTransformer, ) from weathergen.model.layers import MLP -from weathergen.model.positional_encoding import get_rope_mode +from weathergen.model.positional_encoding import ( + build_spherical_rope_coeff_tensors, + get_rope_mode, + get_rope_spherical_band, +) from weathergen.model.utils import ActivationFactory from weathergen.utils.utils import get_dtype @@ -559,6 +564,59 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = self.num_healpix_cells = num_healpix_cells rope_mode = get_rope_mode(self.cf) self.fe_blocks = torch.nn.ModuleList() + self.rope_2D = cf.get("rope_2D", False) + self.healpix_level = cf.get("fe_healpix_level") if cf.get("fe_healpix_level") is not None else cf.get("healpix_level") + self.dtype = get_dtype(cf.attention_dtype) + + # RoPE coordinates + self.rope_mode = get_rope_mode(cf) + if self.rope_mode != "none": + self.num_extra_tokens = cf.num_register_tokens + cf.num_class_tokens + total_tokens = ( + self.num_healpix_cells + self.num_extra_tokens + ) * cf.ae_local_num_queries + self.register_buffer( + "rope_coords", + torch.zeros( + 1, + total_tokens, + 2, + dtype=self.dtype, + ), + ) + self.register_buffer( + "rope_cell_coords", + torch.zeros( + self.num_healpix_cells, + 2, + dtype=self.dtype, + ), + ) + if self.rope_mode == "spherical": + rope_spherical_band = get_rope_spherical_band(cf) + num_modes = 2 * int(rope_spherical_band) + 1 + self.register_buffer( + "rope_spherical_coeffs", + torch.zeros(1, total_tokens, num_modes, 2, dtype=self.dtype), + ) + self.register_buffer( + "rope_spherical_cell_coeffs", + torch.zeros(self.num_healpix_cells, num_modes, 2, dtype=self.dtype), + ) + self.register_buffer( + "rope_spherical_extra_coeffs", + torch.zeros(self.num_extra_tokens, num_modes, 2, dtype=self.dtype), + ) + else: + self.rope_spherical_coeffs = None + self.rope_spherical_cell_coeffs = None + self.rope_spherical_extra_coeffs = None + else: + self.rope_coords = None + self.rope_cell_coords = None + self.rope_spherical_coeffs = None + self.rope_spherical_cell_coeffs = None + self.rope_spherical_extra_coeffs = None global_rate = int(1 / self.cf.forecast_att_dense_rate) if mode_cfg.get("forecast", {}).get("policy") is not None: @@ -625,7 +683,52 @@ def init_weights_final(m): for block in self.fe_blocks: block.apply(init_weights_final) - def forward(self, tokens, fstep, coords=None): + def reset_parameters(self) -> None: + """HEALPix neighbourhood based parameter initializing for target prediction.""" + + cf = self.cf + + if self.rope_mode != "none": + verts, _ = healpix_verts_rots(self.healpix_level, 0.5, 0.5) + coords = r3tos2(verts.to(self.rope_coords.device)).to(self.rope_coords.dtype) + self.rope_cell_coords.data.copy_(coords) + coords = coords.unsqueeze(1).repeat(1, cf.ae_local_num_queries, 1) + coords_flat = coords.flatten(0, 1).unsqueeze(0) + offset = self.num_extra_tokens * cf.ae_local_num_queries + self.rope_coords.data.fill_(0.0) + self.rope_coords.data[:, offset : offset + coords_flat.shape[1], :].copy_(coords_flat) + + if self.rope_mode == "spherical": + band = int(get_rope_spherical_band(cf)) + ( + (cell_real, cell_imag), + (extra_real, extra_imag), + (packed_extra_real, packed_extra_imag), + (packed_real, packed_imag), + ) = build_spherical_rope_coeff_tensors( + nside=2**self.healpix_level, + band=band, + num_local_queries=cf.ae_local_num_queries, + num_extra_tokens=self.num_extra_tokens, + device=self.rope_spherical_coeffs.device, + dtype=self.rope_spherical_coeffs.dtype, + ) + self.rope_spherical_cell_coeffs.data[..., 0].copy_(cell_real) + self.rope_spherical_cell_coeffs.data[..., 1].copy_(cell_imag) + self.rope_spherical_extra_coeffs.data[..., 0].copy_(extra_real) + self.rope_spherical_extra_coeffs.data[..., 1].copy_(extra_imag) + + self.rope_spherical_coeffs.data.fill_(0.0) + self.rope_spherical_coeffs.data[:, :offset, :, 0].copy_(packed_extra_real) + self.rope_spherical_coeffs.data[:, :offset, :, 1].copy_(packed_extra_imag) + self.rope_spherical_coeffs.data[ + :, offset : offset + packed_real.shape[1], :, 0 + ].copy_(packed_real) + self.rope_spherical_coeffs.data[ + :, offset : offset + packed_imag.shape[1], :, 1 + ].copy_(packed_imag) + + def forward(self, tokens, fstep): if self.training: # Impute noise to the latent state noise_std = self.cf.get("fe_impute_latent_noise_std", 0.0) @@ -637,7 +740,15 @@ def forward(self, tokens, fstep, coords=None): if isinstance(block, torch.nn.modules.normalization.LayerNorm): tokens = checkpoint(block, tokens, use_reentrant=False) else: - tokens = checkpoint(block, tokens, coords, aux_info, use_reentrant=False) + tokens = checkpoint( + block, + tokens, + self.rope_spherical_coeffs.unbind(dim=-1) + if self.rope_spherical_coeffs is not None + else self.rope_coords, + aux_info, + use_reentrant=False, + ) return tokens diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 244eac8f9..569f8e5c1 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -10,11 +10,9 @@ # nor does it submit to any jurisdiction. import logging -import math import warnings import astropy_healpix as hp -import astropy_healpix.healpy import numpy as np import torch import torch.nn as nn @@ -22,7 +20,6 @@ from weathergen.common.config import Config from weathergen.datasets.batch import ModelBatch -from weathergen.datasets.utils import healpix_verts_rots, r3tos2 from weathergen.model.encoder import EncoderModule from weathergen.model.engines import ( BilinearDecoder, @@ -37,11 +34,6 @@ TargetPredictionEngineClassic, ) from weathergen.model.layers import MLP, NamedLinear -from weathergen.model.positional_encoding import ( - build_spherical_rope_coeff_tensors, - get_rope_mode, - get_rope_spherical_band, -) from weathergen.model.utils import get_num_parameters from weathergen.utils.distributed import is_root from weathergen.utils.utils import get_dtype, is_stream_forcing @@ -96,72 +88,6 @@ def __init__(self, cf) -> None: self.healpix_level = cf.healpix_level self.num_healpix_cells = 12 * 4**cf.healpix_level - self.dtype = get_dtype(cf.attention_dtype) - - # Positional embeddings - self.max_tokens_local_per_cell = cf.get("ae_local_max_tokens_per_cell", 64) - self.pe_embed = torch.nn.Parameter( - torch.zeros(self.max_tokens_local_per_cell, cf.ae_local_dim_embed, dtype=self.dtype), - requires_grad=False, - ) - - pe = torch.zeros( - self.num_healpix_cells, - cf.ae_local_num_queries, - cf.ae_global_dim_embed, - dtype=self.dtype, - ) - self.pe_global = torch.nn.Parameter(pe, requires_grad=False) - - # RoPE coordinates - self.rope_mode = get_rope_mode(cf, logger) - if self.rope_mode != "none": - self.num_extra_tokens = cf.num_register_tokens + cf.num_class_tokens - total_tokens = ( - self.num_healpix_cells + self.num_extra_tokens - ) * cf.ae_local_num_queries - self.register_buffer( - "rope_coords", - torch.zeros( - 1, - total_tokens, - 2, - dtype=self.dtype, - ), - ) - self.register_buffer( - "rope_cell_coords", - torch.zeros( - self.num_healpix_cells, - 2, - dtype=self.dtype, - ), - ) - if self.rope_mode == "spherical": - rope_spherical_band = get_rope_spherical_band(cf) - num_modes = 2 * int(rope_spherical_band) + 1 - self.register_buffer( - "rope_spherical_coeffs", - torch.zeros(1, total_tokens, num_modes, 2, dtype=self.dtype), - ) - self.register_buffer( - "rope_spherical_cell_coeffs", - torch.zeros(self.num_healpix_cells, num_modes, 2, dtype=self.dtype), - ) - self.register_buffer( - "rope_spherical_extra_coeffs", - torch.zeros(self.num_extra_tokens, num_modes, 2, dtype=self.dtype), - ) - else: - self.rope_spherical_coeffs = None - self.rope_spherical_cell_coeffs = None - self.rope_spherical_extra_coeffs = None - else: - self.rope_coords = None - self.rope_cell_coords = None - self.rope_spherical_coeffs = None - self.rope_spherical_cell_coeffs = None - self.rope_spherical_extra_coeffs = None # HEALPix neighbours hlc = self.healpix_level @@ -177,124 +103,15 @@ def __init__(self, cf) -> None: requires_grad=False, ) - self.q_cells_lens = torch.nn.Parameter( - torch.ones(self.num_healpix_cells + 1, dtype=torch.int32), requires_grad=False - ) - self.q_cells_lens.data[0] = 0 - def create(self, cf: Config) -> "ModelParams": - self.reset_parameters(cf) + self.reset_parameters() return self - def reset_parameters(self, cf: Config) -> "ModelParams": - """Creates positional embedding for each grid point for each stream used after stream - embedding, positional embedding for all stream assimilated cell-level local embedding, - initializing queries for local-to-global adapters, HEALPix neighbourhood based parameter - initializing for target prediction. - - Sinusoidal positional encoding: Harmonic positional encoding based upon sine and cosine for - both per stream after stream embedding and per cell level for local assimilation. - - HEALPix neighbourhood structure: Determine the neighbors for each cell and initialize each - with its own cell number as well as the cell numbers of its neighbors. If a cell has - fewer than eight neighbors, use its own cell number to fill the remaining slots. - - Query len based parameter creation: Calculate parameters for the calculated token length at - each cell after local assimilation. - - Args: - cf : Configuration + def reset_parameters(self) -> "ModelParams": + """HEALPix neighbourhood structure: Determine the neighbors for each cell and initialize + each with its own cell number as well as the cell numbers of its neighbors. If a cell has + fewer than eight neighbors, use its own cell number to fill the remaining slots. """ - - # positional encodings - - dim_embed = cf.ae_local_dim_embed - token_idx_bias = 16 - freq_bias = 8 - self.pe_embed.data.fill_(0.0) - position = torch.arange( - token_idx_bias, - token_idx_bias + self.max_tokens_local_per_cell, - device=self.pe_embed.device, - ).unsqueeze(1) - div = torch.exp( - torch.arange(freq_bias, freq_bias + dim_embed, 2, device=self.pe_embed.device) - * -(math.log(self.max_tokens_local_per_cell) / dim_embed), - ) - self.pe_embed.data[:, 0::2] = torch.sin(position * div[: self.pe_embed[:, 0::2].shape[1]]) - self.pe_embed.data[:, 1::2] = torch.cos(position * div[: self.pe_embed[:, 1::2].shape[1]]) - - dim_embed = cf.ae_global_dim_embed - - if self.rope_mode != "none": - verts, _ = healpix_verts_rots(self.healpix_level, 0.5, 0.5) - coords = r3tos2(verts.to(self.rope_coords.device)).to(self.rope_coords.dtype) - self.rope_cell_coords.data.copy_(coords) - coords = coords.unsqueeze(1).repeat(1, cf.ae_local_num_queries, 1) - coords_flat = coords.flatten(0, 1).unsqueeze(0) - offset = self.num_extra_tokens * cf.ae_local_num_queries - self.rope_coords.data.fill_(0.0) - self.rope_coords.data[:, offset : offset + coords_flat.shape[1], :].copy_(coords_flat) - - if self.rope_mode == "spherical": - band = int(get_rope_spherical_band(cf)) - ( - (cell_real, cell_imag), - (extra_real, extra_imag), - (packed_extra_real, packed_extra_imag), - (packed_real, packed_imag), - ) = build_spherical_rope_coeff_tensors( - nside=2**self.healpix_level, - band=band, - num_local_queries=cf.ae_local_num_queries, - num_extra_tokens=self.num_extra_tokens, - device=self.rope_spherical_coeffs.device, - dtype=self.rope_spherical_coeffs.dtype, - ) - self.rope_spherical_cell_coeffs.data[..., 0].copy_(cell_real) - self.rope_spherical_cell_coeffs.data[..., 1].copy_(cell_imag) - self.rope_spherical_extra_coeffs.data[..., 0].copy_(extra_real) - self.rope_spherical_extra_coeffs.data[..., 1].copy_(extra_imag) - - self.rope_spherical_coeffs.data.fill_(0.0) - self.rope_spherical_coeffs.data[:, :offset, :, 0].copy_(packed_extra_real) - self.rope_spherical_coeffs.data[:, :offset, :, 1].copy_(packed_extra_imag) - self.rope_spherical_coeffs.data[ - :, offset : offset + packed_real.shape[1], :, 0 - ].copy_(packed_real) - self.rope_spherical_coeffs.data[ - :, offset : offset + packed_imag.shape[1], :, 1 - ].copy_(packed_imag) - - # pe_global: always initialized. RoPE handles relative position in Q/K, but pe_global - # provides per-cell token identity which is critical for masked cells that have no - # content from local assimilation. Without it, masked cells are identical and the - # teacher representation (evaluated without dropout) collapses to low rank. - self.pe_global.data.fill_(0.0) - xs = 2.0 * np.pi * torch.arange(0, dim_embed, 2, device=self.pe_global.device) / dim_embed - self.pe_global.data[..., 0::2] = 0.5 * torch.sin( - torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs) - ) - self.pe_global.data[..., 0::2] += ( - torch.sin( - torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs) - ) - .unsqueeze(1) - .repeat((1, cf.ae_local_num_queries, 1)) - ) - self.pe_global.data[..., 1::2] = 0.5 * torch.cos( - torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs) - ) - self.pe_global.data[..., 1::2] += ( - torch.cos( - torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs) - ) - .unsqueeze(1) - .repeat((1, cf.ae_local_num_queries, 1)) - ) - - # healpix neighborhood structure - hlc = self.healpix_level num_healpix_cells = self.num_healpix_cells with warnings.catch_warnings(action="ignore"): @@ -306,12 +123,6 @@ def reset_parameters(self, cf: Config) -> "ModelParams": self.hp_nbours.data[:, 0] = torch.arange(temp.shape[0], device=self.hp_nbours.device) self.hp_nbours.data[:, 1:] = torch.from_numpy(temp).to(self.hp_nbours.device) - # precompute for varlen attention - self.q_cells_lens.data.fill_(1) - self.q_cells_lens.data[0] = 0 - - # ensure all params have grad set to False - return @@ -653,6 +464,10 @@ def _reset_params(module): pass self.apply(_reset_params) + if self.encoder is not None: + self.encoder.reset_parameters() + if self.forecast_engine is not None: + self.forecast_engine.reset_parameters() def print_num_parameters(self) -> None: """Print number of parameters for entire model and each module used to build the model""" @@ -749,7 +564,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: output = ModelOutput(batch.get_output_len()) - tokens, posteriors = self.encoder(model_params, batch) + tokens, posteriors = self.encoder(batch) output.add_latent_prediction(0, "posteriors", posteriors) # recover batch dimension and separate input_steps @@ -757,12 +572,6 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # collapse along input step dimension tokens = tokens.reshape(shape).sum(axis=1) - rope_data = ( - model_params.rope_spherical_coeffs.unbind(dim=-1) - if model_params.rope_spherical_coeffs is not None - else model_params.rope_coords - ) - # Allow for pushforward trick p_fwd = self.cf.training_config.get("forecast", {}).get("pushforward", False) # roll-out in latent space, iterate and generate output over requested output steps @@ -770,20 +579,19 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: without_grad = p_fwd and self.training and step != max(batch.get_output_idxs()) if without_grad: # Pushforward mode: advance tokens without grad; no decoding with torch.no_grad(): - tokens = self.forecast_engine(tokens, step, coords=rope_data) + tokens = self.forecast_engine(tokens, step) continue - tokens = self.forecast_engine(tokens, step, coords=rope_data) + tokens = self.forecast_engine(tokens, step) # decoder predictions output = self.predict_decoders(model_params, step, tokens, batch, output) # latent predictions (raw and with SSL heads) - output = self.predict_latent(model_params, step, tokens, batch, output) + output = self.predict_latent(step, tokens, batch, output) return output def predict_latent( self, - model_params: ModelParams, step: int, tokens: torch.Tensor, batch: ModelBatch, diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index 1423652f7..ce469f2a4 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -17,7 +17,7 @@ MixedPrecisionPolicy, fully_shard, ) -from torch.distributed.tensor import distribute_tensor +from torch.distributed.tensor import DTensor, distribute_tensor from weathergen.common.config import Config, get_path_model, merge_configs from weathergen.model.attention import ( @@ -164,7 +164,7 @@ def init_model_and_shard( # model params model_params = ModelParams(cf).create(cf) - model_params.reset_parameters(cf) + model_params.reset_parameters() model_params = model_params.to(f"cuda:{cf.local_rank}") return model, model_params @@ -196,13 +196,15 @@ def load_model(cf, model, device, run_id: str, mini_epoch=-1): if sharded_meta_param is None: logger.warning(f"Parameter {param_name} from checkpoint not found in model.") continue - sharded_tensor = distribute_tensor( - full_tensor, - sharded_meta_param.device_mesh, - sharded_meta_param.placements, - ) - # maybe_sharded_sd[param_name.replace("module.", "")] = nn.Parameter(sharded_tensor) - maybe_sharded_sd[param_name] = torch.nn.Parameter(sharded_tensor) + if isinstance(sharded_meta_param, DTensor): + sharded_tensor = distribute_tensor( + full_tensor, + sharded_meta_param.device_mesh, + sharded_meta_param.placements, + ) + maybe_sharded_sd[param_name] = torch.nn.Parameter(sharded_tensor) + else: + maybe_sharded_sd[param_name] = full_tensor.to(device) # choose `assign=True` for sharded model since we cannot call `copy_` on meta tensor mkeys, ukeys = model.load_state_dict(maybe_sharded_sd, strict=False, assign=True) diff --git a/src/weathergen/model/positional_encoding.py b/src/weathergen/model/positional_encoding.py index ad1c54bee..6aa364e4c 100644 --- a/src/weathergen/model/positional_encoding.py +++ b/src/weathergen/model/positional_encoding.py @@ -181,6 +181,7 @@ def rotary_pos_emb_2d(q, k, coords, base=10000.0, unsqueeze_dim=1): cos, sin = rotary_embedding_2d(coords, q.shape[-1], base=base) return apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=unsqueeze_dim) + # Spherical RoPE def _max_supported_spherical_band(dim_embed: int, num_heads: int) -> int: head_dim = dim_embed // num_heads @@ -324,7 +325,6 @@ def build_spherical_rope_coeff_tensors( ) - @lru_cache(maxsize=32) def _healpy_band_maps( nside: int, band: int diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index f8cb9cafe..b222b92dc 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -652,7 +652,7 @@ def _get_full_model_state_dict(self): if self.cf.with_ddp and self.cf.with_fsdp: cpu_state_dict = {} for param_name, sharded_param in maybe_sharded_sd.items(): - full_param = sharded_param.full_tensor() + full_param = sharded_param.full_tensor() if isinstance(sharded_param, DTensor) else sharded_param if is_root(): cpu_state_dict[param_name] = full_param.cpu() else: