diff --git a/NOTICE b/NOTICE index 34bf42bfa..ea8e8f7c4 100644 --- a/NOTICE +++ b/NOTICE @@ -1,3 +1,43 @@ +======================================================================= +NVLABS/EDM (Elucidating the Design of Diffusion Models) + +This software incorporates code from the 'edm' repository. + +Original Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +The source code is available at: +https://github.com/NVlabs/edm + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 + +======================================================================= +google-deepmind/graphcast (several associated papers) + +This software incorporates code from the 'google-deepmind/graphcast' repository, with adaptations. + +Original Copyright 2024 DeepMind Technologies Limited. + +The source code is available at: +https://github.com/google-deepmind/graphcast + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 + +======================================================================= +facebookresearch/DiT (Scalable Diffusion Models with Transformers (DiT)) + +This software incorporates code from the 'facebookresearch/DiT' repository, with adaptations. + +The source code is available at: +https://github.com/facebookresearch/DiT + +The code and model weights are licensed under CC-BY-NC. +See https://raw.githubusercontent.com/facebookresearch/DiT/refs/heads/main/LICENSE.txt for details. This project includes code derived from project "DINOv2: Learning Robust Visual Features without Supervision", originally developed by Meta Platforms, Inc. and affiliates, licensed under the Apache License, Version 2.0. diff --git a/config/config_diffusion.yml b/config/config_diffusion.yml new file mode 100644 index 000000000..8cc954266 --- /dev/null +++ b/config/config_diffusion.yml @@ -0,0 +1,339 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 512 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 512 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 6 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_diffusion_model: True +fe_diffusion_model_conditioning: "forecast" # options: "date_time", "time", "forecast" +fe_diffusion_model_conditioning_type: "cross_attn" # options: "cross_attn", "ada_ln" +fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 0.0 # 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 +with_step_conditioning: True # False +# Diffusion related parameters +diffusion_conditioning_embed_dim: 32 +frequency_embedding_dim: 256 +embedding_dim: 512 +sigma_min: 0.002 +sigma_max: 80 +sigma_data: 1.0 +rho: 7 +p_mean: 1.5 +p_std: 1.2 + +healpix_level: 5 + +# Use 2D RoPE instead of traditional global positional encoding +# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) +# When False: uses traditional pe_global positional encoding +rope_2D: True +mlp_type: swiglu +use_xsa: True +# mlp_type: mlp +# use_xsa: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: False +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + + +freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*fe.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +# freeze_modules: "" +# load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 +# load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} # z500 d2048 hl5, sigma_data=159.08 +# load_chkpt: {'run_id': 'tvkicam9', 'epoch': -1} # z500 d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'q9grso75', 'epoch': -1} # z500 d2048 hl3, sigma_data=39.2936 +# load_chkpt: {'run_id': 'qxivdyqz', 'epoch': -1} # z500 d2048 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'h8x1qgz3', 'epoch': -1} # z500 d128 hl5, sigma_data=12.93 +# load_chkpt: {'run_id': '', 'epoch': -1} # z500 d128 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'wvpb76ai', 'epoch': -1} # multi-var d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'ae4wlc5m', 'epoch': -1} # multi-var d2048 hl3, sigma_data=2.7047 +# load_chkpt: {'run_id': 'r45iwyns', 'epoch': -1} # multi-var d512 hl3, sigma_data=1.1785 +# load_chkpt: {'run_id': 'ydka6uql', 'epoch': -1} # multi-var d512 hl4, sigma_data=0.827 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 +# load_chkpt: {'run_id': 'v8kd6xc1', 'epoch': -1} # multi-var d512 hl5 nopos, sigma_data=0.6481 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'y1gu5md8', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, diffusion-full-pipeline +# load_chkpt: {'run_id': 'mal6u4gc', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 64 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'zrpncqb0', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 196 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'm6fs8wvj', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'cgxt9imf', 'epoch': -1} # diffusion model to fine-tune decoder, p_mean=0.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'wo5mf2z4', 'epoch': -1} # diffusion model to fine-tune decoder, p_mean=1.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 +# load_chkpt: {'run_id': 'mivw6jda', 'epoch': -1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'j74tn8le', 'epoch': -1} # forecasting d512 hl5, diffusion-full-pipeline, p_mean=-1.5, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'j7lr0jws', 'epoch': -1} # forecasting d512 hl5, diffusion-full-pipeline, p_mean=-1.2, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'cbras2el', 'epoch': -1} # forecasting d512 hl5, diffusion-full-pipeline, p_mean=-0.5, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'kn3124hp', 'epoch': -1} # forecasting d512 hl5, diffusion-full-pipeline, p_mean=0.0, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'qqbu9852', 'epoch': -1} # forecasting d512 hl5, diffusion-full-pipeline, p_mean=0.5, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'vqsh3yrl', 'epoch': -1} # forecasting d512 hl5, diffusion-full-pipeline, p_mean=1.0, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'xl8h7vbt', 'epoch': -1} # forecasting d512 hl5, diffusion-full-pipeline, p_mean=1.5, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'p9m2jwvc', 'epoch': -1} # forecasting d512 hl5, diffusion-full-pipeline, p_mean=2.0, based on m6fs8wvj backbone + + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/era5_1deg_forecasting/" +# streams_directory: "./config/streams/era5_1deg_forecasting_z500/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + log_grad_norms: False + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking","student_teacher"] + + num_mini_epochs: 128 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 1979-01-01T00:00 + end_date: 2022-12-31T18:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 #5e-5 + lr_max: 1e-5 #1e-4 + lr_final_decay: 1e-6 + lr_final: 0.0 + num_steps_warmup: 64 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.975 + beta2 : 0.9875 + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + weight: 0.0, + loss_fcts: { + "mse": {}, + }, + target_and_aux_calc: "Physical", + }, + "latent_diff": { + type: LossLatentDiffusion, + weight: 1.0, + target_and_aux_calc: DiffusionLatentTargetEncoder, + loss_fcts: { "mse": { }, }, + } + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 2, + num_samples: 1, + } + } + + target_input: { + "forecasting" : { + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + forecast : + time_step: 06:00:00 + num_steps: 1 + offset: 0 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + # Noise levels (eta values in standard normal space) at which to evaluate the + # diffusion model during validation. sigma = exp(eta * p_std + p_mean). + # Each value produces a separate validation pass with independently logged metrics. + validation_noise_levels: [1.0, 2.0, 3.0, 4.0] + + samples_per_mini_epoch: 256 + shuffle: True + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T18:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: True + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/config_diffusion_d2048_ERA5.yml b/config/config_diffusion_d2048_ERA5.yml new file mode 100644 index 000000000..9e3094e64 --- /dev/null +++ b/config/config_diffusion_d2048_ERA5.yml @@ -0,0 +1,348 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 6 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_diffusion_model: True +fe_diffusion_model_conditioning: None # options: "date_time", "time", "forecast" +fe_diffusion_model_conditioning_type: None # options: "cross_attn", "ada_ln" +fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 0.0 # 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 +with_step_conditioning: True # False +# Diffusion related parameters +frequency_embedding_dim: 256 +embedding_dim: 512 +sigma_min: 0.002 +sigma_max: 80 +sigma_data: 1.0 +rho: 7 +p_mean: 1.5 +p_std: 1.2 + +healpix_level: 5 + +# Use 2D RoPE instead of traditional global positional encoding +# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) +# When False: uses traditional pe_global positional encoding +rope_2D: True +mlp_type: swiglu +use_xsa: True +# mlp_type: mlp +# use_xsa: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: False +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + + +#below for FE training only +freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +#below for DECODER training only +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*fe.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +#below for FE and DECODER training +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" + +# load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 +# load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} # z500 d2048 hl5, sigma_data=159.08 +# load_chkpt: {'run_id': 'tvkicam9', 'epoch': -1} # z500 d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'q9grso75', 'epoch': -1} # z500 d2048 hl3, sigma_data=39.2936 +# load_chkpt: {'run_id': 'qxivdyqz', 'epoch': -1} # z500 d2048 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'h8x1qgz3', 'epoch': -1} # z500 d128 hl5, sigma_data=12.93 +# load_chkpt: {'run_id': '', 'epoch': -1} # z500 d128 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'wvpb76ai', 'epoch': -1} # multi-var d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'ae4wlc5m', 'epoch': -1} # multi-var d2048 hl3, sigma_data=2.7047 +# load_chkpt: {'run_id': 'r45iwyns', 'epoch': -1} # multi-var d512 hl3, sigma_data=1.1785 +# load_chkpt: {'run_id': 'ydka6uql', 'epoch': -1} # multi-var d512 hl4, sigma_data=0.827 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 +# load_chkpt: {'run_id': 'v8kd6xc1', 'epoch': -1} # multi-var d512 hl5 nopos, sigma_data=0.6481 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'y1gu5md8', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, diffusion-full-pipeline +# load_chkpt: {'run_id': 'mal6u4gc', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 64 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'zrpncqb0', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 196 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'm6fs8wvj', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'cgxt9imf', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=0.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'wo5mf2z4', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=1.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 +# load_chkpt: {'run_id': 'mivw6jda', 'epoch': -1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 +load_chkpt: {'run_id': 'l3rxe29i', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'riyz96d4', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=0.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'bokn5d2w', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=1.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'uwyv1zdh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'dvslhdp3', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'mtlfdgvh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=3.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'p6cxplfd', 'epoch': -1} # d2048 hl5, p_mean=-2.5, forecasting +# load_chkpt: {'run_id': 'c2y0srpq', 'epoch': -1} # d2048 hl5, p_mean=-2.0, forecasting +# load_chkpt: {'run_id': 'c0zk6oli', 'epoch': -1} # d2048 hl5, p_mean=-1.5, forecasting +# load_chkpt: {'run_id': 'wt2leaf4', 'epoch': -1} # d2048 hl5, p_mean=-1.2, forecasting +# load_chkpt: {'run_id': 'iq5p8ujf', 'epoch': -1} # d2048 hl5, p_mean=-0.5, forecasting +# load_chkpt: {'run_id': 'u2r8b4z6', 'epoch': -1} # d2048 hl5, p_mean=0.5, forecasting +# load_chkpt: {'run_id': 'ug7huxi2', 'epoch': -1} # d2048 hl5, p_mean=1.5, forecasting +# load_chkpt: {'run_id': 'i3y5fhda', 'epoch': -1} # d2048 hl5, p_mean=2.5, forecasting +# load_chkpt: {'run_id': 'wd0u4he8', 'epoch': -1} # d2048 hl5, p_mean=3.5, forecasting + + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/era5_1deg_forecasting_d2048/" +# streams_directory: "./config/streams/era5_1deg_forecasting_z500/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + log_grad_norms: False + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] + + num_mini_epochs: 128 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 1979-01-01T00:00 + end_date: 2022-12-31T18:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 #5e-5 + lr_max: 1e-5 #1e-4 + lr_final_decay: 1e-6 + lr_final: 0.0 + num_steps_warmup: 64 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.975 + beta2 : 0.9875 + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + weight: 0.0, + loss_fcts: { + "mse": {}, + }, + target_and_aux_calc: "Physical", + }, + "latent_diff": { + type: LossLatentDiffusion, + weight: 1.0, + target_and_aux_calc: DiffusionLatentTargetEncoder, + loss_fcts: { "mse": { }, }, + } + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + target_input: { + "forecasting" : { + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + forecast : + time_step: 06:00:00 + num_steps: 1 + offset: 0 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + # Noise levels (eta values in standard normal space) at which to evaluate the + # diffusion model during validation. sigma = exp(eta * p_std + p_mean). + # Each value produces a separate validation pass with independently logged metrics. + validation_noise_levels: [1.0, 2.0, 3.0, 4.0] + + samples_per_mini_epoch: 256 + shuffle: True + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T18:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: True + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/config_diffusion_d2048_date_time.yml b/config/config_diffusion_d2048_date_time.yml new file mode 100644 index 000000000..95bcc1f31 --- /dev/null +++ b/config/config_diffusion_d2048_date_time.yml @@ -0,0 +1,348 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 6 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_diffusion_model: True +fe_diffusion_model_conditioning: "date_time" # options: "date_time", "time", "forecast" +fe_diffusion_model_conditioning_type: "ada_ln" # options: "cross_attn", "ada_ln" +diffusion_conditioning_embed_dim: 32 # only used if fe_diffusion_model_conditioning_type is "ada_ln" +fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 0.0 # 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 +with_step_conditioning: True # False +# Diffusion related parameters +frequency_embedding_dim: 256 +embedding_dim: 512 +sigma_min: 0.002 +sigma_max: 80 +sigma_data: 1.0 +rho: 7 +p_mean: 1.5 +p_std: 1.2 + +healpix_level: 5 + +# Use 2D RoPE instead of traditional global positional encoding +# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) +# When False: uses traditional pe_global positional encoding +rope_2D: True +mlp_type: swiglu +use_xsa: True +# mlp_type: mlp +# use_xsa: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: False +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +#below for FE training only +freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +#below for DECODER training only +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*fe.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +#below for FE and DECODER training +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" + +# load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 +# load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} # z500 d2048 hl5, sigma_data=159.08 +# load_chkpt: {'run_id': 'tvkicam9', 'epoch': -1} # z500 d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'q9grso75', 'epoch': -1} # z500 d2048 hl3, sigma_data=39.2936 +# load_chkpt: {'run_id': 'qxivdyqz', 'epoch': -1} # z500 d2048 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'h8x1qgz3', 'epoch': -1} # z500 d128 hl5, sigma_data=12.93 +# load_chkpt: {'run_id': '', 'epoch': -1} # z500 d128 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'wvpb76ai', 'epoch': -1} # multi-var d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'ae4wlc5m', 'epoch': -1} # multi-var d2048 hl3, sigma_data=2.7047 +# load_chkpt: {'run_id': 'r45iwyns', 'epoch': -1} # multi-var d512 hl3, sigma_data=1.1785 +# load_chkpt: {'run_id': 'ydka6uql', 'epoch': -1} # multi-var d512 hl4, sigma_data=0.827 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 +# load_chkpt: {'run_id': 'v8kd6xc1', 'epoch': -1} # multi-var d512 hl5 nopos, sigma_data=0.6481 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'y1gu5md8', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, diffusion-full-pipeline +# load_chkpt: {'run_id': 'mal6u4gc', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 64 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'zrpncqb0', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 196 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'm6fs8wvj', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'cgxt9imf', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=0.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'wo5mf2z4', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=1.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 +# load_chkpt: {'run_id': 'mivw6jda', 'epoch': -1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 +load_chkpt: {'run_id': 'l3rxe29i', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'riyz96d4', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=0.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'bokn5d2w', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=1.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'uwyv1zdh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'dvslhdp3', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'mtlfdgvh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=3.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'p6cxplfd', 'epoch': -1} # d2048 hl5, p_mean=-2.5, forecasting +# load_chkpt: {'run_id': 'c2y0srpq', 'epoch': -1} # d2048 hl5, p_mean=-2.0, forecasting +# load_chkpt: {'run_id': 'c0zk6oli', 'epoch': -1} # d2048 hl5, p_mean=-1.5, forecasting +# load_chkpt: {'run_id': 'wt2leaf4', 'epoch': -1} # d2048 hl5, p_mean=-1.2, forecasting +# load_chkpt: {'run_id': 'iq5p8ujf', 'epoch': -1} # d2048 hl5, p_mean=-0.5, forecasting +# load_chkpt: {'run_id': 'u2r8b4z6', 'epoch': -1} # d2048 hl5, p_mean=0.5, forecasting +# load_chkpt: {'run_id': 'ug7huxi2', 'epoch': -1} # d2048 hl5, p_mean=1.5, forecasting +# load_chkpt: {'run_id': 'i3y5fhda', 'epoch': -1} # d2048 hl5, p_mean=2.5, forecasting +# load_chkpt: {'run_id': 'wd0u4he8', 'epoch': -1} # d2048 hl5, p_mean=3.5, forecasting + + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/era5_1deg_forecasting_d2048/" +# streams_directory: "./config/streams/era5_1deg_forecasting_z500/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + log_grad_norms: False + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] + + num_mini_epochs: 128 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 1979-01-01T00:00 + end_date: 2022-12-31T18:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 #5e-5 + lr_max: 1e-5 #1e-4 + lr_final_decay: 1e-6 + lr_final: 0.0 + num_steps_warmup: 64 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.975 + beta2 : 0.9875 + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + weight: 0.0, + loss_fcts: { + "mse": {}, + }, + target_and_aux_calc: "Physical", + }, + "latent_diff": { + type: LossLatentDiffusion, + weight: 1.0, + target_and_aux_calc: DiffusionLatentTargetEncoder, + loss_fcts: { "mse": { }, }, + } + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + target_input: { + "forecasting" : { + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + forecast : + time_step: 06:00:00 + num_steps: 1 + offset: 0 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + # Noise levels (eta values in standard normal space) at which to evaluate the + # diffusion model during validation. sigma = exp(eta * p_std + p_mean). + # Each value produces a separate validation pass with independently logged metrics. + validation_noise_levels: [1.0, 2.0, 3.0, 4.0] + + samples_per_mini_epoch: 256 + shuffle: True + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T18:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: True + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/config_diffusion_d2048_forecast.yml b/config/config_diffusion_d2048_forecast.yml new file mode 100644 index 000000000..fbace5174 --- /dev/null +++ b/config/config_diffusion_d2048_forecast.yml @@ -0,0 +1,349 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 6 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_diffusion_model: True +fe_diffusion_model_conditioning: "forecast" # options: "date_time", "time", "forecast" +fe_diffusion_model_conditioning_type: "cross_attn" # options: "cross_attn", "ada_ln" +fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 0.0 # 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 +with_step_conditioning: True # False +# Diffusion related parameters +frequency_embedding_dim: 256 +embedding_dim: 512 +sigma_min: 0.002 +sigma_max: 80 +sigma_data: 1.0 +rho: 7 +p_mean: 1.5 +p_std: 1.2 + +healpix_level: 5 + +# Use 2D RoPE instead of traditional global positional encoding +# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) +# When False: uses traditional pe_global positional encoding +rope_2D: True +mlp_type: swiglu +use_xsa: True +# mlp_type: mlp +# use_xsa: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: False +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + + +#below for FE training only +freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +#below for DECODER training only +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*fe.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +#below for FE and DECODER training +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" + +# load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 +# load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} # z500 d2048 hl5, sigma_data=159.08 +# load_chkpt: {'run_id': 'tvkicam9', 'epoch': -1} # z500 d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'q9grso75', 'epoch': -1} # z500 d2048 hl3, sigma_data=39.2936 +# load_chkpt: {'run_id': 'qxivdyqz', 'epoch': -1} # z500 d2048 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'h8x1qgz3', 'epoch': -1} # z500 d128 hl5, sigma_data=12.93 +# load_chkpt: {'run_id': '', 'epoch': -1} # z500 d128 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'wvpb76ai', 'epoch': -1} # multi-var d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'ae4wlc5m', 'epoch': -1} # multi-var d2048 hl3, sigma_data=2.7047 +# load_chkpt: {'run_id': 'r45iwyns', 'epoch': -1} # multi-var d512 hl3, sigma_data=1.1785 +# load_chkpt: {'run_id': 'ydka6uql', 'epoch': -1} # multi-var d512 hl4, sigma_data=0.827 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 +# load_chkpt: {'run_id': 'v8kd6xc1', 'epoch': -1} # multi-var d512 hl5 nopos, sigma_data=0.6481 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'y1gu5md8', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, diffusion-full-pipeline +# load_chkpt: {'run_id': 'mal6u4gc', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 64 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'zrpncqb0', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 196 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'm6fs8wvj', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'cgxt9imf', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=0.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'wo5mf2z4', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=1.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 +# load_chkpt: {'run_id': 'mivw6jda', 'epoch': -1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 +load_chkpt: {'run_id': 'l3rxe29i', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'riyz96d4', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=0.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'bokn5d2w', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=1.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'uwyv1zdh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'dvslhdp3', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'mtlfdgvh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=3.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'p6cxplfd', 'epoch': -1} # d2048 hl5, p_mean=-2.5, forecasting +# load_chkpt: {'run_id': 'c2y0srpq', 'epoch': -1} # d2048 hl5, p_mean=-2.0, forecasting +# load_chkpt: {'run_id': 'c0zk6oli', 'epoch': -1} # d2048 hl5, p_mean=-1.5, forecasting +# load_chkpt: {'run_id': 'wt2leaf4', 'epoch': -1} # d2048 hl5, p_mean=-1.2, forecasting +# load_chkpt: {'run_id': 'iq5p8ujf', 'epoch': -1} # d2048 hl5, p_mean=-0.5, forecasting +# load_chkpt: {'run_id': 'u2r8b4z6', 'epoch': -1} # d2048 hl5, p_mean=0.5, forecasting +# load_chkpt: {'run_id': 'ug7huxi2', 'epoch': -1} # d2048 hl5, p_mean=1.5, forecasting +# load_chkpt: {'run_id': 'l3rxe29i', 'epoch': -1} # d2048 hl5, p_mean=1.5, forecasting (deterministic, precursor to ug7huxi2) +# load_chkpt: {'run_id': 'i3y5fhda', 'epoch': -1} # d2048 hl5, p_mean=2.5, forecasting +# load_chkpt: {'run_id': 'wd0u4he8', 'epoch': -1} # d2048 hl5, p_mean=3.5, forecasting + + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/era5_1deg_forecasting_d2048/" +# streams_directory: "./config/streams/era5_1deg_forecasting_z500/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + log_grad_norms: False + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] + + num_mini_epochs: 128 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 1979-01-01T00:00 + end_date: 2022-12-31T18:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 #5e-5 + lr_max: 1e-5 #1e-4 + lr_final_decay: 1e-6 + lr_final: 0.0 + num_steps_warmup: 64 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.975 + beta2 : 0.9875 + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + weight: 0.0, + loss_fcts: { + "mse": {}, + }, + target_and_aux_calc: "Physical", + }, + "latent_diff": { + type: LossLatentDiffusion, + weight: 1.0, + target_and_aux_calc: DiffusionLatentTargetEncoder, + loss_fcts: { "mse": { }, }, + } + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 2, + num_samples: 1, + } + } + + target_input: { + "forecasting" : { + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + forecast : + time_step: 06:00:00 + num_steps: 1 + offset: 0 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + # Noise levels (eta values in standard normal space) at which to evaluate the + # diffusion model during validation. sigma = exp(eta * p_std + p_mean). + # Each value produces a separate validation pass with independently logged metrics. + validation_noise_levels: [1.0, 2.0, 3.0, 4.0] + + samples_per_mini_epoch: 256 + shuffle: True + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T18:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: True + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/config_diffusion_d2048_time.yml b/config/config_diffusion_d2048_time.yml new file mode 100644 index 000000000..8e100832f --- /dev/null +++ b/config/config_diffusion_d2048_time.yml @@ -0,0 +1,349 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 6 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_diffusion_model: True +fe_diffusion_model_conditioning: "time" # options: "date_time", "time", "forecast" +fe_diffusion_model_conditioning_type: "ada_ln" # options: "cross_attn", "ada_ln" +diffusion_conditioning_embed_dim: 32 # only used if fe_diffusion_model_conditioning_type is "ada_ln" +fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 0.0 # 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 +with_step_conditioning: True # False +# Diffusion related parameters +frequency_embedding_dim: 256 +embedding_dim: 512 +sigma_min: 0.002 +sigma_max: 80 +sigma_data: 1.0 +rho: 7 +p_mean: 1.5 +p_std: 1.2 + +healpix_level: 5 + +# Use 2D RoPE instead of traditional global positional encoding +# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) +# When False: uses traditional pe_global positional encoding +rope_2D: True +mlp_type: swiglu +use_xsa: True +# mlp_type: mlp +# use_xsa: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: False +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + + +#below for FE training only +freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +#below for DECODER training only +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*fe.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +#below for FE and DECODER training +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" + +# load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 +# load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} # z500 d2048 hl5, sigma_data=159.08 +# load_chkpt: {'run_id': 'tvkicam9', 'epoch': -1} # z500 d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'q9grso75', 'epoch': -1} # z500 d2048 hl3, sigma_data=39.2936 +# load_chkpt: {'run_id': 'qxivdyqz', 'epoch': -1} # z500 d2048 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'h8x1qgz3', 'epoch': -1} # z500 d128 hl5, sigma_data=12.93 +# load_chkpt: {'run_id': '', 'epoch': -1} # z500 d128 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'wvpb76ai', 'epoch': -1} # multi-var d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'ae4wlc5m', 'epoch': -1} # multi-var d2048 hl3, sigma_data=2.7047 +# load_chkpt: {'run_id': 'r45iwyns', 'epoch': -1} # multi-var d512 hl3, sigma_data=1.1785 +# load_chkpt: {'run_id': 'ydka6uql', 'epoch': -1} # multi-var d512 hl4, sigma_data=0.827 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 +# load_chkpt: {'run_id': 'v8kd6xc1', 'epoch': -1} # multi-var d512 hl5 nopos, sigma_data=0.6481 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'y1gu5md8', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, diffusion-full-pipeline +# load_chkpt: {'run_id': 'mal6u4gc', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 64 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'zrpncqb0', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 196 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'm6fs8wvj', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'cgxt9imf', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=0.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'wo5mf2z4', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=1.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 +# load_chkpt: {'run_id': 'mivw6jda', 'epoch': -1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 +load_chkpt: {'run_id': 'l3rxe29i', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'riyz96d4', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=0.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'bokn5d2w', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=1.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'uwyv1zdh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'dvslhdp3', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'mtlfdgvh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=3.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'p6cxplfd', 'epoch': -1} # d2048 hl5, p_mean=-2.5, forecasting +# load_chkpt: {'run_id': 'c2y0srpq', 'epoch': -1} # d2048 hl5, p_mean=-2.0, forecasting +# load_chkpt: {'run_id': 'c0zk6oli', 'epoch': -1} # d2048 hl5, p_mean=-1.5, forecasting +# load_chkpt: {'run_id': 'wt2leaf4', 'epoch': -1} # d2048 hl5, p_mean=-1.2, forecasting +# load_chkpt: {'run_id': 'iq5p8ujf', 'epoch': -1} # d2048 hl5, p_mean=-0.5, forecasting +# load_chkpt: {'run_id': 'u2r8b4z6', 'epoch': -1} # d2048 hl5, p_mean=0.5, forecasting +# load_chkpt: {'run_id': 'ug7huxi2', 'epoch': -1} # d2048 hl5, p_mean=1.5, forecasting +# load_chkpt: {'run_id': 'i3y5fhda', 'epoch': -1} # d2048 hl5, p_mean=2.5, forecasting +# load_chkpt: {'run_id': 'wd0u4he8', 'epoch': -1} # d2048 hl5, p_mean=3.5, forecasting + + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/era5_1deg_forecasting_d2048/" +# streams_directory: "./config/streams/era5_1deg_forecasting_z500/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + log_grad_norms: False + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] + + num_mini_epochs: 128 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 1979-01-01T00:00 + end_date: 2022-12-31T18:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 #5e-5 + lr_max: 1e-5 #1e-4 + lr_final_decay: 1e-6 + lr_final: 0.0 + num_steps_warmup: 64 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.975 + beta2 : 0.9875 + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + weight: 0.0, + loss_fcts: { + "mse": {}, + }, + target_and_aux_calc: "Physical", + }, + "latent_diff": { + type: LossLatentDiffusion, + weight: 1.0, + target_and_aux_calc: DiffusionLatentTargetEncoder, + loss_fcts: { "mse": { }, }, + } + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + target_input: { + "forecasting" : { + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + forecast : + time_step: 06:00:00 + num_steps: 1 + offset: 0 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + # Noise levels (eta values in standard normal space) at which to evaluate the + # diffusion model during validation. sigma = exp(eta * p_std + p_mean). + # Each value produces a separate validation pass with independently logged metrics. + validation_noise_levels: [1.0, 2.0, 3.0, 4.0] + + samples_per_mini_epoch: 256 + shuffle: True + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T18:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: True + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/config_diffusion_d2048_time_aug.yml b/config/config_diffusion_d2048_time_aug.yml new file mode 100644 index 000000000..f0011735f --- /dev/null +++ b/config/config_diffusion_d2048_time_aug.yml @@ -0,0 +1,349 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 6 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_diffusion_model: True +fe_diffusion_model_conditioning: "time" # options: "date_time", "time", "forecast" +fe_diffusion_model_conditioning_type: "ada_ln" # options: "cross_attn", "ada_ln" +diffusion_conditioning_embed_dim: 32 # only used if fe_diffusion_model_conditioning_type is "ada_ln" +fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 0.0 # 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 +with_step_conditioning: True # False +# Diffusion related parameters +frequency_embedding_dim: 256 +embedding_dim: 512 +sigma_min: 0.002 +sigma_max: 80 +sigma_data: 1.0 +rho: 7 +p_mean: 1.5 +p_std: 1.2 + +healpix_level: 5 + +# Use 2D RoPE instead of traditional global positional encoding +# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) +# When False: uses traditional pe_global positional encoding +rope_2D: True +mlp_type: swiglu +use_xsa: True +# mlp_type: mlp +# use_xsa: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: False +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + + +#below for FE training only +freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +#below for DECODER training only +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*fe.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +#below for FE and DECODER training +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" + +# load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 +# load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} # z500 d2048 hl5, sigma_data=159.08 +# load_chkpt: {'run_id': 'tvkicam9', 'epoch': -1} # z500 d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'q9grso75', 'epoch': -1} # z500 d2048 hl3, sigma_data=39.2936 +# load_chkpt: {'run_id': 'qxivdyqz', 'epoch': -1} # z500 d2048 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'h8x1qgz3', 'epoch': -1} # z500 d128 hl5, sigma_data=12.93 +# load_chkpt: {'run_id': '', 'epoch': -1} # z500 d128 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'wvpb76ai', 'epoch': -1} # multi-var d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'ae4wlc5m', 'epoch': -1} # multi-var d2048 hl3, sigma_data=2.7047 +# load_chkpt: {'run_id': 'r45iwyns', 'epoch': -1} # multi-var d512 hl3, sigma_data=1.1785 +# load_chkpt: {'run_id': 'ydka6uql', 'epoch': -1} # multi-var d512 hl4, sigma_data=0.827 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 +# load_chkpt: {'run_id': 'v8kd6xc1', 'epoch': -1} # multi-var d512 hl5 nopos, sigma_data=0.6481 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'y1gu5md8', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, diffusion-full-pipeline +# load_chkpt: {'run_id': 'mal6u4gc', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 64 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'zrpncqb0', 'epoch': -1} # multi-var d512 hl5, sigma_dqta=1.0, geoinfos 196 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'm6fs8wvj', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'cgxt9imf', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=0.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'wo5mf2z4', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=1.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 +# load_chkpt: {'run_id': 'mivw6jda', 'epoch': -1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 +load_chkpt: {'run_id': 'l3rxe29i', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'riyz96d4', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=0.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'bokn5d2w', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=1.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'uwyv1zdh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'dvslhdp3', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'mtlfdgvh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=3.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'p6cxplfd', 'epoch': -1} # d2048 hl5, p_mean=-2.5, forecasting +# load_chkpt: {'run_id': 'c2y0srpq', 'epoch': -1} # d2048 hl5, p_mean=-2.0, forecasting +# load_chkpt: {'run_id': 'c0zk6oli', 'epoch': -1} # d2048 hl5, p_mean=-1.5, forecasting +# load_chkpt: {'run_id': 'wt2leaf4', 'epoch': -1} # d2048 hl5, p_mean=-1.2, forecasting +# load_chkpt: {'run_id': 'iq5p8ujf', 'epoch': -1} # d2048 hl5, p_mean=-0.5, forecasting +# load_chkpt: {'run_id': 'u2r8b4z6', 'epoch': -1} # d2048 hl5, p_mean=0.5, forecasting +# load_chkpt: {'run_id': 'ug7huxi2', 'epoch': -1} # d2048 hl5, p_mean=1.5, forecasting +# load_chkpt: {'run_id': 'i3y5fhda', 'epoch': -1} # d2048 hl5, p_mean=2.5, forecasting +# load_chkpt: {'run_id': 'wd0u4he8', 'epoch': -1} # d2048 hl5, p_mean=3.5, forecasting + + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/era5_1deg_forecasting_d2048/" +# streams_directory: "./config/streams/era5_1deg_forecasting_z500/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + log_grad_norms: False + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : True + + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] + + num_mini_epochs: 128 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 2022-08-01T00:00 + end_date: 2022-08-31T18:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 #5e-5 + lr_max: 1e-5 #1e-4 + lr_final_decay: 1e-6 + lr_final: 0.0 + num_steps_warmup: 64 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.975 + beta2 : 0.9875 + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + weight: 0.0, + loss_fcts: { + "mse": {}, + }, + target_and_aux_calc: "Physical", + }, + "latent_diff": { + type: LossLatentDiffusion, + weight: 1.0, + target_and_aux_calc: DiffusionLatentTargetEncoder, + loss_fcts: { "mse": { }, }, + } + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + target_input: { + "forecasting" : { + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + forecast : + time_step: 06:00:00 + num_steps: 1 + offset: 0 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + # Noise levels (eta values in standard normal space) at which to evaluate the + # diffusion model during validation. sigma = exp(eta * p_std + p_mean). + # Each value produces a separate validation pass with independently logged metrics. + validation_noise_levels: [1.0, 2.0, 3.0, 4.0] + + samples_per_mini_epoch: 256 + shuffle: True + + start_date: 2023-08-01T00:00 + end_date: 2023-08-31T18:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: True + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/config_diffusion_tiny.yml b/config/config_diffusion_tiny.yml new file mode 100644 index 000000000..ba306d483 --- /dev/null +++ b/config/config_diffusion_tiny.yml @@ -0,0 +1,343 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +# # z500 small +# embed_orientation: "channels" +# embed_unembed_mode: "block" +# embed_dropout_rate: 0.1 + +# ae_local_dim_embed: 128 +# ae_local_num_blocks: 0 +# ae_local_num_heads: 8 +# ae_local_dropout_rate: 0.0 +# ae_local_with_qk_lnorm: True + +# ae_local_num_queries: 1 +# ae_local_queries_per_cell: False +# ae_adapter_num_heads: 8 +# ae_adapter_embed: 128 +# ae_adapter_with_qk_lnorm: True +# ae_adapter_with_residual: True +# ae_adapter_dropout_rate: 0.0 + +# ae_global_dim_embed: 128 +# ae_global_num_blocks: 4 +# ae_global_num_heads: 8 +# ae_global_dropout_rate: 0.0 +# ae_global_with_qk_lnorm: True +# # TODO: switching to < 1 triggers triton-related issues. +# # See https://github.com/ecmwf/WeatherGenerator/issues/1050 +# ae_global_att_dense_rate: 1.0 +# ae_global_block_factor: 64 +# ae_global_mlp_hidden_factor: 2 +# ae_global_trailing_layer_norm: False + +# ae_aggregation_num_blocks: 0 +# ae_aggregation_num_heads: 4 +# ae_aggregation_dropout_rate: 0.0 +# ae_aggregation_with_qk_lnorm: True +# ae_aggregation_att_dense_rate: 1.0 +# ae_aggregation_block_factor: 64 +# ae_aggregation_mlp_hidden_factor: 2 + +# decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +# pred_adapter_kv: False +# pred_self_attention: True +# pred_dyadic_dims: False +# pred_mlp_adaln: True +# num_class_tokens: 0 +# num_register_tokens: 0 + +# # multi-var small +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 128 +ae_local_num_blocks: 0 +ae_local_num_heads: 4 +ae_local_dropout_rate: 0.0 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.0 + +ae_global_dim_embed: 128 +ae_global_num_blocks: 4 +ae_global_num_heads: 4 +ae_global_dropout_rate: 0.0 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 8 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 4 +ae_aggregation_dropout_rate: 0.0 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 8 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 4 +fe_num_heads: 4 +fe_dropout_rate: 0.0 +fe_with_qk_lnorm: True +fe_diffusion_model: True +fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 0.0 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 +with_step_conditioning: True # False +# Diffusion related parameters +frequency_embedding_dim: 256 +embedding_dim: 512 +sigma_min: 0.002 +sigma_max: 80 # 170 +sigma_data: 0.7855 # 0.5 # 1.7 +rho: 7 +p_mean: -1.2 +p_std: 1.2 + +healpix_level: 5 + +# Use 2D RoPE instead of traditional global positional encoding +# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) +# When False: uses traditional pe_global positional encoding +rope_2D: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: False +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + + +freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +# load_chkpt: {'run_id': 'lgrasnq6', 'epoch': -1} # z500 small +# load_chkpt: {'run_id': 'hhz27wy0', 'epoch': -1} # multi-var small, sigma_data=0.7855 +# load_chkpt: {'run_id': 'xpwjhaf4', 'epoch': -1} # z500 d128 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'ml89b5r6', 'epoch': -1} # multi-var d128 hl5, sigma_data=0.2415 +load_chkpt: {'run_id': 'a3n1pdkl', 'epoch': -1} # multi-var d128 hl5, nopos, sigma_data=0.2507 + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/era5_1deg_diffusion_tiny/" # z500 small +# streams_directory: "./config/streams/era5_1deg_forecasting/" # multi-var small +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + log_grad_norms: False + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : True + + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] + + num_mini_epochs: 256 + samples_per_mini_epoch: 1024 + shuffle: True + + start_date: 2012-06-01T00:00 + end_date: 2012-06-01T18:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 #5e-5 + lr_max: 5e-5 #1e-4 + lr_final_decay: 1e-6 + lr_final: 0.0 + num_steps_warmup: 64 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 10.0 # 1.0 + weight_decay: 0.0 # 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.975 + beta2 : 0.9875 + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + weight: 0.0, + loss_fcts: { + "mse": {}, + }, + target_and_aux_calc: "Physical", + }, + "latent_diff": { + type: LossLatentDiffusion, + weight: 1.0, + target_and_aux_calc: DiffusionLatentTargetEncoder, + loss_fcts: { "mse": { }, }, + } + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_samples: 1 + } + } + + forecast : + time_step: 06:00:00 + num_steps: 1 + offset: 0 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + # Noise levels (eta values in standard normal space) at which to evaluate the + # diffusion model during validation. sigma = exp(eta * p_std + p_mean). + # Each value produces a separate validation pass with independently logged metrics. + # validation_noise_levels: [0.3, 0.5, 0.75, 1.0, 1.5] + validation_noise_levels: [1.0, 2.0, 3.0, 4.0] + # validation_noise_levels: [1.0, 2.0, 4.0, 8.0, 16.0] + + samples_per_mini_epoch: 8 + shuffle: False + + start_date: 2012-06-01T00:00 + end_date: 2012-06-01T18:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 1, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: True + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/config_forecasting.yml b/config/config_forecasting.yml index 4f1ff1499..bc74a96c1 100644 --- a/config/config_forecasting.yml +++ b/config/config_forecasting.yml @@ -11,7 +11,7 @@ embed_orientation: "channels" embed_unembed_mode: "block" embed_dropout_rate: 0.1 -ae_local_dim_embed: 2048 +ae_local_dim_embed: 512 ae_local_num_blocks: 0 ae_local_num_heads: 16 ae_local_dropout_rate: 0.1 @@ -25,7 +25,7 @@ ae_adapter_with_qk_lnorm: True ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 -ae_global_dim_embed: 2048 +ae_global_dim_embed: 512 ae_global_num_blocks: 4 ae_global_num_heads: 32 ae_global_dropout_rate: 0.1 @@ -63,10 +63,15 @@ fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm fe_impute_latent_noise_std: 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) forecast_att_dense_rate: 1.0 +fe_diffusion_model: False healpix_level: 5 -rope_2D: False +rope_2D: True +mlp_type: swiglu +use_xsa: True +# mlp_type: mlp +# use_xsa: False with_mixed_precision: True with_flash_attention: True @@ -87,6 +92,7 @@ freeze_modules: "" load_chkpt: {} norm_type: "LayerNorm" +qk_norm_type: null # if null, defaults to norm_type ##################################### diff --git a/config/config_forecasting_d2048.yml b/config/config_forecasting_d2048.yml new file mode 100644 index 000000000..ad538f893 --- /dev/null +++ b/config/config_forecasting_d2048.yml @@ -0,0 +1,256 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 16 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 +fe_diffusion_model: False + +healpix_level: 5 + +rope_2D: True +mlp_type: swiglu +use_xsa: True +# mlp_type: mlp +# use_xsa: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: True +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +freeze_modules: "" +load_chkpt: {} + +norm_type: "LayerNorm" +qk_norm_type: null # if null, defaults to norm_type + +##################################### + +streams_directory: "./config/streams/era5_1deg_forecasting_d2048/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking"] + + num_mini_epochs: 64 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 1979-01-01T00:00 + end_date: 2022-12-31T00:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 5e-5 + lr_final_decay: 2e-6 + lr_final: 0.0 + num_steps_warmup: 256 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.98125 # == 0.85 on 2 nodes x 4 gpus + beta2 : 0.9875 # == 0.90 on 2 nodes x 4 gpus + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + loss_fcts: { "mse": { }, }, + }, + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + }, + } + + forecast : + time_step: 06:00:00 + offset: 1 + num_steps: 3 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + samples_per_mini_epoch: 256 + shuffle: False + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T00:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: False + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/config_forecasting_z500.yml b/config/config_forecasting_z500.yml new file mode 100644 index 000000000..5dc0dbe43 --- /dev/null +++ b/config/config_forecasting_z500.yml @@ -0,0 +1,251 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 16 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 + +healpix_level: 3 + +rope_2D: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: True +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +freeze_modules: "" +load_chkpt: {} + +norm_type: "LayerNorm" + +##################################### + +# streams_directory: "./config/streams/era5_1deg_forecasting_z500/" +streams_directory: "./config/streams/era5_1deg_forecasting/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking"] + + num_mini_epochs: 64 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 1979-01-01T00:00 + end_date: 2022-12-31T00:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 5e-5 + lr_final_decay: 2e-6 + lr_final: 0.0 + num_steps_warmup: 256 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.98125 # == 0.85 on 2 nodes x 4 gpus + beta2 : 0.9875 # == 0.90 on 2 nodes x 4 gpus + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + loss_fcts: { "mse": { }, }, + }, + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + }, + } + + forecast : + time_step: 06:00:00 + offset: 1 + num_steps: 3 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + samples_per_mini_epoch: 256 + shuffle: False + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T00:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: False + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/default_config.yml b/config/default_config.yml index 39abe739b..afdbcde13 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -57,6 +57,7 @@ fe_num_blocks: 6 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True +fe_diffusion_model: False fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) @@ -258,4 +259,4 @@ wgtags: # *** Experiment-specific tags *** # All extra tags (including lists, dictionaries, etc.) are treated # as strings by mlflow, so treat all extra tags as simple string key: value pairs. - grid: null + grid: null \ No newline at end of file diff --git a/config/evaluate/eval_config.yml b/config/evaluate/eval_config.yml index fc2d49c2a..1f0ae14b8 100644 --- a/config/evaluate/eval_config.yml +++ b/config/evaluate/eval_config.yml @@ -7,6 +7,9 @@ # animation_format: "gif" #options: "mp4", "gif" # dpi_val : 300 # fps: 2 +# n_bins: 50 #number of bins for histograms. +# log_x: true #use log scale for x axis in histograms. +# log_y: true #use log scale for y axis in histograms. # ERA5: # use_datashader: false # marker_size: 2 diff --git a/config/evaluate/eval_config_default.yml b/config/evaluate/eval_config_default.yml index 92921edd6..7c7dfca4d 100644 --- a/config/evaluate/eval_config_default.yml +++ b/config/evaluate/eval_config_default.yml @@ -37,6 +37,7 @@ evaluation: add_grid: false score_cards: false bar_plots: false + #agg_dims: ["ipoint"] #----> NOTE: advanced! Handle with care. This will average the scores across the specified list of dimensions. Supported dimensions: "ipoint", "sample", "forecast_step", "ensemble". Use with caution, as it can hide important details about the model performance. default_streams: diff --git a/config/evaluate/eval_config_diffusion.yml b/config/evaluate/eval_config_diffusion.yml new file mode 100644 index 000000000..15a567af2 --- /dev/null +++ b/config/evaluate/eval_config_diffusion.yml @@ -0,0 +1,62 @@ +#optional: if commented out all is taken care of by the default settings +# NB. global options apply to all run_ids +#global_plotting_options: +# region: ["belgium", "global"] +# image_format : "png" #options: "png", "pdf", "svg", "eps", "jpg" .. +# dpi_val : 300 +# fps: 2 +# ERA5: +# marker_size: 2 +# scale_marker_size: 1 +# marker: "o" +# # alpha: 0.5 +# 2t: +# vmin: 250 +# vmax: 300 +# 10u: +# vmin: -40 +# vmax: 40 + +evaluation: + metrics : ["rmse", "mae"] + regions: ["global", "nhem"] + summary_plots : true + ratio_plots : false + heat_maps : false + summary_dir: "./plots/" + plot_ensemble: "members" #supported: false, "std", "minmax", "members" + plot_score_maps: false #plot scores on a 2D maps. it slows down score computation + print_summary: false #print out score values on screen. it can be verbose + log_scale: false + add_grid: false + score_cards: false + bar_plots: false + num_processes: 0 #options: int, "auto", 0 means no parallelism (default) + # baseline: "ar40mckx" + + +default_streams: + ERA5: + channels: ["2t", "10u"] #, "10v", "z_500", "t_850", "u_850", "v_850", "q_850", ] + evaluation: + forecast_step: "all" + sample: "all" + ensemble: "all" #supported: "all", "mean", [0,1,2] + plotting: + sample: [1] + forecast_step: [1] #supported: "all", [1,2,3,...], "1-50" (equivalent of [1,2,3,...50]) + ensemble: [0] #supported: "all", "mean", [0,1,2] + plot_maps: true + plot_target: true + plot_histograms: true + plot_animations: true + + +run_ids : + kuia5xr0: + label: "debugging model g0vdqua7" + results_base_dir : "../results/" + #NEW: if "streams" is not specified, the default streams are used + + + \ No newline at end of file diff --git a/config/runs_plot_train.yml b/config/runs_plot_train.yml new file mode 100644 index 000000000..b5092f8a1 --- /dev/null +++ b/config/runs_plot_train.yml @@ -0,0 +1,100 @@ + +train: + plot: + r8vzykrm: + slurm_id: 387093 + description: "bug-fix non-cond-branch, with sigma_data=0.63" + nol9pfdg: + slurm_id: 387094 + description: "bug-fix cond-branch w/o cond, with sigma_data=0.63" + # imqzsbte: + # slurm_id: 387095 + # description: "bug-fix non-cond" + f8nd1c60: + slurm_id: 387097 + description: "bug-fix cond-branch w/ cond" + # ux8yjktb: + # slurm_id: 387095 + # description: "bug-fix cond-branch w/ non-cond" + # xxkmgsne: + # slurm_id: 0 + # description: "bug-fix cond-branch w/ non-cond, lr_max=5e-6" + # jwexz9y4: + # slurm_id: 0 + # description: "bug-fix cond-branch w/ non-cond, lr_max=2.5e-6" + u7etjsm0: + slurm_id: 385058 + description: "old ERA5, lr_start=1e-6, lr_max=1e-5" + mot8sfay: + slurm_id: 385060 + description: "old ERA5, lr_start=1e-6, lr_max=7e-6" + # zhon45xy: + # slurm_id: 385064 + # description: "conditioning w/ ERA5, lr_start=1e-6, lr_max=1e-5" + # yimje7g3: + # slurm_id: 385062 + # description: "conditioning w/ ERA5, lr_start=1e-6, lr_max=7e-6" + # bpeh160r: + # slurm_id: 381190 + # description: "single samples, lr_start=1e-6, lr_max=1e-6" + # cigywmh2: + # slurm_id: 380678 + # description: "single samples, lr_start=1e-6, lr_max=1e-5" + # kxe5zfla: + # slurm_id: 380680 + # description: "single samples, lr_start=1e-6, lr_max=5e-5" + # fuz6l32i: + # slurm_id: 382174 + # description: "conditioning w/ single samples, lr_start=1e-6, lr_max=5e-5" + # vujmw4g2: + # slurm_id: 382207 + # description: "conditioning w/ ERA5, lr_start=1e-6, lr_max=5e-5" + # t8cm7bn9: + # slurm_id: 382239 + # description: "MERGE no conditioning w/ ERA5, lr_start=1e-6, lr_max=5e-5" + # b9oyntjg: + # slurm_id: 382235 + # description: "MERGE conditioning w/ ERA5, lr_start=1e-6, lr_max=5e-5" + # it6wj130: + # slurm_id: 376473 + # description: "single samples, lr_start=1e-6, lr_max=1e-5" + # p54uvzl2: + # slurm_id: 376808 + # description: "single samples, lr_start=1e-5, lr_max=1e-5" + # lq5djr4m: + # slurm_id: 376811 + # description: "single samples, lr_start=1e-6, lr_max=1e-6" + # wcruesg4: + # slurm_id: 376816 + # description: "single samples, lr_start=1e-7, lr_max=1e-6" + # k3qh6elp: + # slurm_id: 377059 + # description: "single samples, lr_start=1e-6, lr_max=5e-6" + # w8hp1c2g: + # slurm_id: 376855 + # description: "20y distribution, lr_start=1e-5, lr_max=1e-5" + # ss9z2rqi: + # slurm_id: 376858 + # description: "20y distribution, lr_start=1e-6, lr_max=1e-6" + # g9iqgz0d: + # slurm_id: 376860 + # description: "20y distribution, lr_start=1e-7, lr_max=1e-6" + # q5u9p8xo: + # slurm_id: 377063 + # description: "20y distribution, lr_start=1e-6, lr_max=5e-6" + # f8e97mqx: + # slurm_id: 376862 + # description: "20y distribution, lr_start=1e-6, lr_max=1e-5" + # xqgy519d: + # slurm_id: 0 + # description: "Old Matze Baseline 1e-5 (single samples)" + # rj2xksg0: + # slurm_id: 0 + # description: "Old Matze Baseline 1e-5 (single samples)" + bbosl5wy: + slurm_id: 0 + description: "Matze Baseline (ERA5)" + # y0l8egdr: + # slurm_id: 0 + # description: "New Matze Baseline (single samples)" + diff --git a/config/streams/era5_1deg/era5.yml b/config/streams/era5_1deg/era5.yml index 26455bf4e..a779ba68a 100644 --- a/config/streams/era5_1deg/era5.yml +++ b/config/streams/era5_1deg/era5.yml @@ -24,11 +24,11 @@ ERA5 : net : transformer num_tokens : 1 num_heads : 8 - dim_embed : 256 + dim_embed : 512 num_blocks : 2 embed_target_coords : net : linear - dim_embed : 256 + dim_embed : 512 target_readout : num_layers : 2 num_heads : 4 diff --git a/config/streams/era5_1deg_diffusion_tiny/era5.yml b/config/streams/era5_1deg_diffusion_tiny/era5.yml new file mode 100644 index 000000000..40da38c09 --- /dev/null +++ b/config/streams/era5_1deg_diffusion_tiny/era5.yml @@ -0,0 +1,40 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5 : + type : anemoi + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + stream_id : 0 + # source : ["z_500"] + # target : ["z_500"] + source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] + target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] + loss_weight : 1. + location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 8 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 4 + dim_embed : 32 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 32 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 \ No newline at end of file diff --git a/config/streams/era5_1deg_forecasting/era5.yml b/config/streams/era5_1deg_forecasting/era5.yml index 0bd70ae01..ae8e2da53 100644 --- a/config/streams/era5_1deg_forecasting/era5.yml +++ b/config/streams/era5_1deg_forecasting/era5.yml @@ -13,17 +13,19 @@ ERA5 : stream_id : 0 source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] + geoinfo_channels : ['lsm', 'slor', 'sdor', 'insolation', 'cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day'] loss_weight : 1. location_weight : cosine_latitude masking_rate : 0.6 masking_rate_none : 0.05 token_size : 8 tokenize_spacetime : True - max_num_targets: 20000 + # max_num_targets: 20000 + max_num_targets: -1 embed : net : transformer num_tokens : 1 - num_heads : 8 + num_heads : 4 dim_embed : 256 num_blocks : 2 embed_target_coords : diff --git a/config/streams/era5_1deg_forecasting_d2048/era5.yml b/config/streams/era5_1deg_forecasting_d2048/era5.yml new file mode 100644 index 000000000..ed00da42c --- /dev/null +++ b/config/streams/era5_1deg_forecasting_d2048/era5.yml @@ -0,0 +1,114 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5 : + type : anemoi + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2024-1h-v3-with-era51.zarr'] + stream_id : 0 + source_exclude : ['z', 'w_10', 'w_50', 'w_100', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925', 'w_1000', 'skt', 'tcw', 'cp', 'tp', 'q_50', 'q_100'] + target_exclude : ['z', 'w_10', 'w_50', 'w_100', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925', 'w_1000', 'slor', 'sdor', 'tcw', 'cp', 'tp', 'q_50', 'q_100'] + geoinfo_channels : ['z', 'lsm', 'slor', 'sdor', 'insolation', 'cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day'] + # source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] + # target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] + # geoinfo_channels : ['lsm', 'slor', 'sdor', 'insolation', 'cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day'] + loss_weight : 1. + location_weight : cosine_latitude + token_size : 8 + tokenize_spacetime : True + max_num_targets: 20000 + # max_num_targets: -1 + frequency : 06:00:00 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 512 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 + channel_weights : + q_10: 0.2 + q_50: 0.2 + q_100: 0.23 + q_150: 0.26 + q_200: 0.29 + q_250: 0.33 + q_300: 0.36 + q_400: 0.42 + q_500: 0.48 + q_600: 0.55 + q_700: 0.61 + q_850: 0.71 + q_925: 0.75 + q_1000: 0.8 + t_10: 0.2 + t_50: 0.2 + t_100: 0.23 + t_150: 0.26 + t_200: 0.29 + t_250: 0.33 + t_300: 0.36 + t_400: 0.42 + t_500: 0.48 + t_600: 0.55 + t_700: 0.61 + t_850: 0.71 + t_925: 0.75 + t_1000: 0.8 + u_10: 0.2 + u_50: 0.2 + u_100: 0.23 + u_150: 0.26 + u_200: 0.29 + u_250: 0.33 + u_300: 0.36 + u_400: 0.42 + u_500: 0.48 + u_600: 0.55 + u_700: 0.61 + u_850: 0.71 + u_925: 0.75 + u_1000: 0.8 + v_10: 0.2 + v_50: 0.2 + v_100: 0.23 + v_150: 0.26 + v_200: 0.29 + v_250: 0.33 + v_300: 0.36 + v_400: 0.42 + v_500: 0.48 + v_600: 0.55 + v_700: 0.61 + v_850: 0.71 + v_925: 0.75 + v_1000: 0.8 + z_10: 0.2 + z_50: 0.2 + z_100: 0.23 + z_150: 0.26 + z_200: 0.29 + z_250: 0.33 + z_300: 0.36 + z_400: 0.42 + z_500: 0.48 + z_600: 0.55 + z_700: 0.61 + z_850: 0.71 + z_925: 0.75 + z_1000: 0.8 + \ No newline at end of file diff --git a/config/streams/era5_1deg_forecasting_z500/era5.yml b/config/streams/era5_1deg_forecasting_z500/era5.yml new file mode 100644 index 000000000..f1659fb21 --- /dev/null +++ b/config/streams/era5_1deg_forecasting_z500/era5.yml @@ -0,0 +1,38 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5 : + type : anemoi + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + stream_id : 0 + source : ['z_500'] + target : ['z_500'] + loss_weight : 1. + location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 8 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 4 + dim_embed : 32 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 32 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index 4df4a67d9..75a6f3ba1 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -753,10 +753,11 @@ def validate_forecast_policy_and_steps(forecast_cfg: OmegaConf, mode: str): output_offset = forecast_cfg.get("offset", 0) assert isinstance(output_offset, int), TypeError(valid_forecast_offset) if output_offset == 0: - if isinstance(forecast_cfg.num_steps, int): - assert forecast_cfg.num_steps in [0, 1], valid_forecast_steps_offset0 - else: - raise TypeError(valid_forecast_steps_offset0) + # if isinstance(forecast_cfg.num_steps, int): + # assert forecast_cfg.num_steps in [0, 1], valid_forecast_steps_offset0 + # else: + # raise TypeError(valid_forecast_steps_offset0) + pass elif output_offset == 1: assert forecast_cfg.policy, (provide_forecast_policy, valid_forecast_policies) if isinstance(forecast_cfg.num_steps, int): diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py index 79850df42..b28ad4adc 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py @@ -9,7 +9,6 @@ """Plotting orchestration: parallel dispatch of per-sample maps, score maps, and summary plots.""" -import glob import logging from pathlib import Path @@ -45,7 +44,6 @@ _logger = logging.getLogger(__name__) - # --------------------------------------------------------------------------- # Score maps # --------------------------------------------------------------------------- @@ -252,17 +250,18 @@ def _scatter_plot_single( def _build_single_animation( - map_output_dir: Path, + output_dir: Path, run_id: str, tag: str, stream: str, - region: str, + region: str | None, var: str, sa: object, fsteps: list, image_format: str, animation_format: str, duration_ms: int, + prefix: str = "map", ) -> list[str]: """Build one GIF for a single (region, sample, variable) combination. @@ -271,32 +270,38 @@ def _build_single_animation( Returns the list of source frame paths that were assembled into the GIF (empty list if no frames were found). """ - image_paths: list[str] = [] - for fstep in fsteps: - parts = [ - "map", - run_id, - tag, - str(sa), - "*", - stream, - region, - var, - "fstep", - str(fstep).zfill(3), - ] - name = "_".join(filter(None, parts)) - fname = f"{map_output_dir.joinpath(name)}.{image_format}" - image_paths += glob.glob(fname) + # Both map and histogram filenames follow the same pattern: + # {prefix}_{run_id}_{tag}_{sample}_{valid_time}_{stream}_{region}_{var}_{fstep:03d} + # For all_samples histograms, valid_time is omitted. + # We match files by checking a fixed prefix and suffix, allowing any + # valid_time (or none) in between — no glob wildcards needed. + region_part = region if region else "" + head = "_".join(filter(None, [prefix, run_id, tag, str(sa)])) + tail = "_".join(filter(None, [stream, region_part, var])) + suffix = f".{image_format}" + fstep_strs = {str(f).zfill(3) for f in fsteps} + + if not output_dir.is_dir(): + return [] + + image_paths = sorted( + str(f) + for f in output_dir.iterdir() + if f.name.startswith(head + "_") + and f.name.endswith(suffix) + and f"_{tail}_" in f.name + and f.stem.rsplit("_", 1)[-1] in fstep_strs + ) if not image_paths: - _logger.debug(f"No images found for animation {var} sample {sa} region {region}") return [] - image_paths = sorted(image_paths) - out_path = ( - f"{map_output_dir}/animation_{run_id}_{tag}_{sa}_{stream}_{region}_{var}.{animation_format}" - ) + anim_parts = ["animation", run_id, tag, str(sa), stream] + if region: + anim_parts.append(region) + anim_parts.append(var) + out_path = f"{output_dir / '_'.join(filter(None, anim_parts))}.{animation_format}" + if animation_format.lower() == "mp4": frames = [imageio.imread(p) for p in image_paths] fps = 1000 / duration_ms if duration_ms > 0 else 2 @@ -327,6 +332,9 @@ def _dispatch_animations( ) -> list[str]: """Build GIF animations in parallel for all (region, sample, variable) combinations. + Animations are built for both maps and histograms — whichever image files + exist on disk will be picked up automatically. + Parameters ---------- plotter : Plotter @@ -340,13 +348,17 @@ def _dispatch_animations( Paths of all source frames that were assembled into GIFs. """ plotter.update_data_selection(select) - map_output_dir = plotter.get_map_output_dir(tag) duration_ms = int(1000 / plotter.fps) if plotter.fps > 0 else 400 + prefixes = [ + ("map", plotter.get_map_output_dir(tag)), + ("histogram", plotter.get_hist_output_dir()), + ] + tasks = [ { - "map_output_dir": map_output_dir, + "output_dir": output_dir, "run_id": plotter.run_id, "tag": tag, "stream": plotter.stream, @@ -357,7 +369,9 @@ def _dispatch_animations( "image_format": plotter.image_format, "animation_format": plotter.animation_format, "duration_ms": duration_ms, + "prefix": prefix, } + for prefix, output_dir in prefixes for region in plotter.regions for sa in samples for var in variables @@ -370,7 +384,7 @@ def _dispatch_animations( results = dispatch_parallel( calls, n_workers=get_num_workers(max_workers=max_workers), - backend="threading", + backend="loky", desc="Animations", ) return [p for r in results if r for p in r] @@ -395,7 +409,7 @@ def _plot_single_sample( plot_maps: bool, plot_bias: bool, plot_target: bool, - plot_histograms: bool, + plot_histograms: bool | str, maps_config: dict, bias_config: dict, ) -> None: @@ -418,11 +432,13 @@ def _plot_single_sample( if plot_bias and bias_data is not None and not bias_has_ens: plotter.create_maps_per_sample(bias_data, plot_chs, data_selection, "bias", bias_cfg) - for ens in ensemble: - has_ens = "ens" in preds.dims and ens != "mean" - preds_ens = preds.sel(ens=ens) if has_ens else preds - preds_tag = "" if "ens" not in preds.dims else f"ens_{ens}" - preds_name = "_".join(filter(None, ["preds", preds_tag])) + for ens in ensemble: + has_ens = "ens" in preds.dims and ens != "mean" + preds_ens = preds.sel(ens=ens) if has_ens else preds + preds_tag = "" if "ens" not in preds.dims else f"ens_{ens}" + preds_name = "_".join(filter(None, ["preds", preds_tag])) + + if plot_maps: plotter.create_maps_per_sample( preds_ens, plot_chs, data_selection, preds_name, maps_cfg ) @@ -434,10 +450,60 @@ def _plot_single_sample( bias_ens, plot_chs, data_selection, bias_tag, bias_cfg ) - if plot_histograms: - plotter.create_histograms_per_sample( - tars, preds_ens, plot_chs, data_selection, preds_tag - ) + if plot_histograms is True or plot_histograms == "per-sample": + plotter.create_histograms( + tars, + preds_ens, + plot_chs, + data_selection, + preds_name, + ranges=maps_config, + ) + + plotter.clean_data_selection() + + +def _plot_all_samples( + plotter_cfg: dict, + output_basedir: str, + tars: xr.DataArray, + preds: xr.DataArray, + bias_data: xr.DataArray | None, + fstep: int | str, + stream: str, + plot_chs: list[str], + ensemble: list, + plot_histograms: bool | str, + maps_config: dict, + bias_config: dict, +) -> None: + """Plot histograms across all samples for a single fstep. + + Unlike per-sample histograms, these aggregate all samples together. + The output filename uses 'global' instead of a sample id and omits the timestep. + """ + if not (plot_histograms is True or plot_histograms == "across-samples"): + return + + matplotlib.use("Agg") + plotter = Plotter(plotter_cfg, Path(output_basedir)) + + data_selection = {"sample": "all_samples", "stream": stream, "forecast_step": fstep} + + for ens in ensemble: + has_ens = "ens" in preds.dims and ens != "mean" + preds_ens = preds.sel(ens=ens) if has_ens else preds + preds_tag = "" if "ens" not in preds.dims else f"ens_{ens}" + preds_name = "_".join(filter(None, ["preds", preds_tag])) + + plotter.create_histograms( + tars, + preds_ens, + plot_chs, + data_selection, + preds_name, + ranges=maps_config, + ) plotter.clean_data_selection() @@ -465,14 +531,8 @@ def plot_data( stream_cfg = reader.get_stream(stream) plot_settings = stream_cfg.get("plotting", {}) - if not ( - plot_settings - and ( - plot_settings.get("plot_maps", False) - or plot_settings.get("plot_histograms", False) - or plot_settings.get("plot_animations", False) - ) - ): + plot_keys = ("plot_maps", "plot_histograms", "plot_animations") + if not plot_settings or not any(plot_settings.get(k, False) for k in plot_keys): return plotter_cfg = { @@ -482,9 +542,13 @@ def plot_data( "fig_size": global_plotting_opts.get("fig_size"), "fps": global_plotting_opts.get("fps", 2), "regions": global_plotting_opts.get("regions", ["global"]), + "log_x": global_plotting_opts.get("log_x", False), + "log_y": global_plotting_opts.get("log_y", False), + "n_bins": global_plotting_opts.get("n_bins", 50), "plot_subtimesteps": reader.get_inference_stream_attr(stream, "tokenize_spacetime", False) | plot_settings.get("plot_subtimesteps", False), } + plotter = Plotter(plotter_cfg, reader.runplot_dir) available_data = reader.check_availability(stream, mode="plotting") @@ -502,12 +566,16 @@ def plot_data( if not isinstance(plot_target, bool): raise TypeError("plot_target must be a boolean.") plot_histograms = plot_settings.get("plot_histograms", False) - if not isinstance(plot_histograms, bool): - raise TypeError("plot_histograms must be a boolean.") + if not isinstance(plot_histograms, bool) and plot_histograms not in { + "across-samples", + "per-sample", + }: + raise TypeError("plot_histograms must be true, false, 'across-samples', or 'per-sample'. ") plot_animations = plot_settings.get("plot_animations", False) if not isinstance(plot_animations, bool): raise TypeError("plot_animations must be a boolean.") + model_output = output_data if output_data is None: model_output = reader.get_data( stream, @@ -516,8 +584,6 @@ def plot_data( channels=available_data.channels, ensemble=available_data.ensemble, ) - else: - model_output = output_data da_tars = model_output.target da_preds = model_output.prediction @@ -530,7 +596,9 @@ def plot_data( plot_sample_set = set(available_data.samples) if available_data.samples is not None else None plot_channel_set = set(available_data.channels) if available_data.channels is not None else None + output_dir = str(reader.runplot_dir) output_fstep_keys = set(da_tars.keys()) + if plot_fstep_set is not None and output_fstep_keys - plot_fstep_set: zarr_fsteps = set(int(f) for f in reader.get_forecast_steps()) if plot_fstep_set == zarr_fsteps: @@ -551,16 +619,9 @@ def plot_data( if not isinstance(global_plotting_opts.get(stream), oc.DictConfig): global_plotting_opts[stream] = oc.DictConfig({}) - maps_config = common_ranges( - da_tars, da_preds, available_data.channels, global_plotting_opts[stream] - ) - bias_config = bias_ranges( - da_tars, da_preds, available_data.channels, global_plotting_opts[stream] - ) - - maps_config_dict = oc.OmegaConf.to_container(maps_config, resolve=True) - bias_config_dict = oc.OmegaConf.to_container(bias_config, resolve=True) - output_basedir = str(reader.runplot_dir) + _range_args = (da_tars, da_preds, available_data.channels, global_plotting_opts[stream]) + maps_config_dict = oc.OmegaConf.to_container(common_ranges(*_range_args), resolve=True) + bias_config_dict = oc.OmegaConf.to_container(bias_ranges(*_range_args), resolve=True) num_plot_workers = get_num_workers( check_process_headroom=True, @@ -568,6 +629,7 @@ def plot_data( ) tasks: list[dict] = [] + all_samples_tasks: list[dict] = [] for (fstep, tars), (_, preds) in zip(da_tars.items(), da_preds.items(), strict=False): all_chs = list(np.atleast_1d(tars.channel.values)) plot_chs = ( @@ -589,14 +651,36 @@ def plot_data( bias_data = (preds - tars) if plot_bias else None + all_samples_tasks.append( + { + "plotter_cfg": plotter_cfg, + "output_basedir": output_dir, + "tars": tars, + "preds": preds, + "bias_data": bias_data, + "fstep": fstep, + "stream": stream, + "plot_chs": plot_chs, + "ensemble": list(available_data.ensemble), + "plot_histograms": plot_histograms, + "maps_config": maps_config_dict, + "bias_config": bias_config_dict, + } + ) + for sample in plot_samples: + # Pre-slice to this sample before serializing for the worker, to avoid + # sending the full per-fstep DataArray (all samples) to each loky process. + tars_s = tars.sel(sample=sample) + preds_s = preds.sel(sample=sample) + bias_s = bias_data.sel(sample=sample) if bias_data is not None else None tasks.append( { "plotter_cfg": plotter_cfg, - "output_basedir": output_basedir, - "tars": tars, - "preds": preds, - "bias_data": bias_data, + "output_basedir": output_dir, + "tars": tars_s, + "preds": preds_s, + "bias_data": bias_s, "sample": sample, "fstep": fstep, "stream": stream, @@ -620,63 +704,51 @@ def plot_data( calls, n_workers=num_plot_workers, backend="loky", desc=f"Plotting {run_id} - {stream}" ) + if all_samples_tasks: + _logger.info( + f"Parallel plotting: dispatching {len(all_samples_tasks)} across-samples " + f"tasks using up to {num_plot_workers} loky workers." + ) + as_calls = [delayed(_plot_all_samples)(**t) for t in all_samples_tasks] + dispatch_parallel( + as_calls, + n_workers=num_plot_workers, + backend="loky", + desc=f"Across-samples plots {run_id} - {stream}", + ) + if plot_animations: - plotter = Plotter(plotter_cfg, reader.runplot_dir) last_fstep = list(da_tars.keys())[-1] - last_tars = da_tars[last_fstep] last_preds = da_preds[last_fstep] - all_chs = list(np.atleast_1d(last_tars.channel.values)) - plot_chs = ( - [ch for ch in all_chs if ch in plot_channel_set] - if plot_channel_set is not None - else all_chs - ) - all_samples = list(np.unique(last_tars.sample.values)) - plot_samples = ( - [s for s in all_samples if s in plot_sample_set] - if plot_sample_set is not None - else all_samples - ) - plot_fsteps = da_tars.keys() - data_selection = { - "sample": plot_samples[-1], - "stream": stream, - "forecast_step": last_fstep, - } + last_tars = da_tars[last_fstep] + has_ens = "ens" in last_preds.dims + + _sel = lambda items, allowed: [x for x in items if x in allowed] if allowed else items + plot_chs = _sel(list(np.atleast_1d(last_tars.channel.values)), plot_channel_set) + plot_samples = _sel(list(np.unique(last_tars.sample.values)), plot_sample_set) + max_wk = reader.eval_cfg.get("max_workers", None) + anim_samples = plot_samples + (["all_samples"] if plot_histograms else []) + anim_kw = dict( + plotter=plotter, + samples=anim_samples, + fsteps=da_tars.keys(), + variables=plot_chs, + max_workers=max_wk, + select={"sample": plot_samples[-1], "stream": stream, "forecast_step": last_fstep}, + ) + + tags: list[str] = [] for ens in available_data.ensemble: - preds_name = "preds" if "ens" not in last_preds.dims else f"preds_ens_{ens}" - _dispatch_animations( - plotter, - plot_samples, - plot_fsteps, - plot_chs, - data_selection, - preds_name, - max_workers=max_wk, - ) + tags.append("preds" if not has_ens else f"preds_ens_{ens}") if plot_target: - _dispatch_animations( - plotter, - plot_samples, - plot_fsteps, - plot_chs, - data_selection, - "targets", - max_workers=max_wk, - ) + tags.append("targets") if plot_bias: for ens in available_data.ensemble: - bias_tag = "bias" if "ens" not in last_preds.dims else f"bias_ens_{ens}" - _dispatch_animations( - plotter, - plot_samples, - plot_fsteps, - plot_chs, - data_selection, - bias_tag, - max_workers=max_wk, - ) + tags.append("bias" if not has_ens else f"bias_ens_{ens}") + + for tag in tags: + _dispatch_animations(**anim_kw, tag=tag) # --------------------------------------------------------------------------- diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index b61b7813f..dfb2b3e8f 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ -136,6 +136,13 @@ def clean_label(s: str) -> str: return re.sub(r"[_\-]+", " ", s).strip() +def filter_set(items: list, allowed: set | None) -> list: + """Return *items* filtered to *allowed*, or all items if *allowed* is ``None``.""" + if allowed is None: + return items + return [x for x in items if x in allowed] + + class DefaultMarkerSize: """ Utility class for managing default configuration values, such as marker sizes diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 961de7c10..97e0840a6 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -11,6 +11,7 @@ import logging import os import warnings +from dataclasses import dataclass from pathlib import Path import cartopy @@ -23,6 +24,8 @@ from astropy_healpix import HEALPix as HEALPixGrid from cartopy.io import DownloadWarning from matplotlib.collections import LineCollection +from scipy.stats import skew +from scipy.stats import wasserstein_distance as wd try: import datashader as ds @@ -54,7 +57,7 @@ def _download_cartopy_off(enabled: bool) -> None: """Enable/disable blocking Cartopy downloads by elevating DownloadWarning to error.""" if enabled: warnings.filterwarnings("error", category=DownloadWarning) - _logger.info( + _logger.debug( "Auto-downloads are blocked for cartopy; only local cartopy data will be used." ) else: @@ -68,6 +71,39 @@ def _download_cartopy_off(enabled: bool) -> None: _logger.debug(f"Taking cartopy paths from {work_dir}") +@dataclass +class DistStats: + """Summary statistics for a 1-D distribution.""" + + count: int + min: float + max: float + mean: float + median: float + std: float + skewness: float + + @classmethod + def from_array(cls, v: np.typing.NDArray) -> "DistStats": + v = np.asarray(v).ravel() + return cls( + count=len(v), + min=float(np.min(v)), + max=float(np.max(v)), + mean=float(np.mean(v)), + median=float(np.median(v)), + std=float(np.std(v)), + skewness=float(skew(v, nan_policy="omit")), + ) + + def summary(self, label: str) -> str: + return ( + f"{label:8s} N={self.count} min={self.min:.3g} max={self.max:.3g} " + f"mean={self.mean:.3g} med={self.median:.3g} " + f"std={self.std:.3g} skew={self.skewness:.3g}" + ) + + class Plotter: """ Contains all basic plotting functions. @@ -94,7 +130,7 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path, stream: str | It can also be set later via update_data_selection. """ - _logger.info(f"Taking cartopy paths from {work_dir}") + _logger.debug(f"Taking cartopy paths from {work_dir}") self.image_format = plotter_cfg.get("image_format") self.animation_format = plotter_cfg.get("animation_format") @@ -102,6 +138,9 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path, stream: str | self.fig_size = plotter_cfg.get("fig_size") self.fps = plotter_cfg.get("fps") self.regions = plotter_cfg.get("regions") + self.log_x = plotter_cfg.get("log_x", False) + self.log_y = plotter_cfg.get("log_y", False) + self.n_bins = plotter_cfg.get("n_bins", 50) _download_cartopy_off(enabled=True) self.plot_subtimesteps = plotter_cfg.get( "plot_subtimesteps", False @@ -137,6 +176,10 @@ def update_data_selection(self, select: dict): _logger.warning("No sample in the selection. Might lead to unexpected results.") else: self.sample = select["sample"] + # "all_samples" is a proxy for across-samples aggregation; + # remove it from self.select so it won't be used in .sel() + if select["sample"] == "all_samples": + self.select.pop("sample") if "stream" not in select: _logger.warning("No stream in the selection. Might lead to unexpected results.") @@ -190,13 +233,14 @@ def select_from_da(self, da: xr.DataArray, selection: dict) -> xr.DataArray: da = da.sel({key: value}) return da - def create_histograms_per_sample( + def create_histograms( self, target: xr.DataArray, preds: xr.DataArray, variables: list, select: dict, tag: str = "", + ranges: dict | None = None, ) -> list[str]: """ Plot histogram of target vs predictions for each variable and valid time in the DataArray. @@ -222,44 +266,64 @@ def create_histograms_per_sample( self.update_data_selection(select) - # Basic map output directory for this stream - hist_output_dir = self.out_plot_basedir / self.stream / "histograms" + # Basic histogram output directory for this stream + hist_output_dir = self.get_hist_output_dir() if not os.path.exists(hist_output_dir): _logger.info(f"Creating dir {hist_output_dir}") os.makedirs(hist_output_dir, exist_ok=True) - for var in variables: - select_var = self.select | {"channel": var} - - targ, prd = ( - self.select_from_da(target, select_var), - self.select_from_da(preds, select_var), - ) + for region in self.regions: + if region != "global": + bbox = RegionBoundingBox.from_region_name(region) + reg_target = bbox.apply_mask(target) + reg_preds = bbox.apply_mask(preds) + else: + reg_target = target + reg_preds = preds - # Remove NaNs - targ = targ.dropna(dim="ipoint") - prd = prd.dropna(dim="ipoint") - assert targ.size > 0, "Data array must not be empty or contain only NAs" - assert prd.size > 0, "Data array must not be empty or contain only NAs" + for var in variables: + select_var = self.select | {"channel": var} - if self.plot_subtimesteps: - ntimes_unique = len(np.unique(targ.valid_time)) - _logger.info( - f"Creating histograms for {ntimes_unique} valid times of variable {var}." + targ, prd = ( + self.select_from_da(reg_target, select_var), + self.select_from_da(reg_preds, select_var), ) - groups = zip(targ.groupby("valid_time"), prd.groupby("valid_time"), strict=False) - else: - _logger.info(f"Plotting histogram for all valid times of {var}") + # Remove NaNs + targ = targ.dropna(dim="ipoint") + prd = prd.dropna(dim="ipoint") + assert targ.size > 0, "Data array must not be empty or contain only NAs" + assert prd.size > 0, "Data array must not be empty or contain only NAs" + + if self.plot_subtimesteps and str(self.sample) != "all_samples": + ntimes_unique = len(np.unique(targ.valid_time)) + _logger.debug( + f"Creating histograms for {ntimes_unique} valid times of variable {var}." + ) - groups = [((None, targ), (None, prd))] # wrap once with dummy valid_time + groups = zip( + targ.groupby("valid_time"), prd.groupby("valid_time"), strict=False + ) + else: + _logger.debug(f"Plotting histogram for all valid times of {var}") - for (valid_time, targ_t), (_, prd_t) in groups: - if valid_time is not None: - _logger.debug(f"Plotting histogram for {var} at valid_time {valid_time}") - name = self.plot_histogram(targ_t, prd_t, hist_output_dir, var, tag=tag) - plot_names.append(name) + groups = [((None, targ), (None, prd))] # wrap once with dummy valid_time + + for (valid_time, targ_t), (_, prd_t) in groups: + if valid_time is not None: + _logger.debug(f"Plotting histogram for {var} at valid_time {valid_time}") + var_range = ranges.get(var, {}) if ranges else {} + name = self.plot_histogram( + targ_t, + prd_t, + hist_output_dir, + var, + tag=tag, + region=region, + xlim=(var_range.get("vmin"), var_range.get("vmax")), + ) + plot_names.append(name) self.clean_data_selection() @@ -272,6 +336,8 @@ def plot_histogram( hist_output_dir: Path, varname: str, tag: str = "", + region: str = "", + xlim: tuple | None = None, ) -> str: """ Plot a histogram comparing target and prediction data for a specific variable. @@ -294,47 +360,121 @@ def plot_histogram( Name of the saved plot file. """ - # Get common bin edges - vals = np.concatenate([target_data, pred_data]) - bins = np.histogram_bin_edges(vals, bins=50) - - # Plot histograms - plt.hist(target_data, bins=bins, alpha=0.7, label="Target") - plt.hist(pred_data, bins=bins, alpha=0.7, label="Prediction") + tar_vals = np.asarray(target_data).ravel() + prd_vals = np.asarray(pred_data).ravel() + + # Get common bin edges — use fixed xlim range if provided for consistency + xmin, xmax = xlim if xlim else (None, None) + # Fall back to data-derived bounds if either limit is missing + if xmin is None or xmax is None: + vals = np.concatenate([tar_vals, prd_vals]) + if xmin is None: + xmin = float(np.nanmin(vals)) + if xmax is None: + xmax = float(np.nanmax(vals)) + # Add 5% margin on each side so tails are clearly visible + margin = (xmax - xmin) * 0.05 + xmin -= margin + xmax += margin + bins = np.linspace(xmin, xmax, self.n_bins + 1) + + # Compute histograms + target_counts, _ = np.histogram(tar_vals, bins=bins) + pred_counts, _ = np.histogram(prd_vals, bins=bins) + bin_centers = (bins[:-1] + bins[1:]) / 2 + + color_tar = "black" + color_pred = "#00897B" # teal / green-blue + + # Create figure with two subplots: histogram + ratio + fig, (ax_hist, ax_ratio) = plt.subplots( + 2, + 1, + sharex=True, + figsize=self.fig_size or (8, 6), + gridspec_kw={"height_ratios": [3, 1], "hspace": 0.05}, + ) - # set labels and title - plt.xlabel(f"Variable: {varname}") - plt.ylabel("Frequency") - plt.title( - f"Histogram of Target and Prediction: {self.stream}, {varname} : " - f"fstep = {self.fstep:03}" + # Upper panel: histogram curves + ax_hist.plot( + bin_centers, target_counts, alpha=0.7, label="Target", linewidth=1.5, color=color_tar + ) + ax_hist.plot( + bin_centers, pred_counts, alpha=0.7, label="Prediction", linewidth=1.5, color=color_pred + ) + ax_hist.set_ylabel("Frequency") + ax_hist.set_title(f"{self.stream}, {varname} : fstep = {self.fstep:03}") + ax_hist.legend(frameon=False) + if self.log_y: + ax_hist.set_yscale("log") + ax_hist.grid(True, linestyle="--", alpha=0.5) + + # Lower panel: ratio (prediction / target) + with np.errstate(divide="ignore", invalid="ignore"): + ratio = np.where(target_counts > 0, pred_counts / target_counts, np.nan) + ax_ratio.plot(bin_centers, ratio, linewidth=1.2, color=color_pred) + ax_ratio.axhline(1.0, linestyle="--", color="gray", linewidth=0.8) + ax_ratio.set_ylabel("Pred / Target") + ax_ratio.set_xlabel(f"Variable: {varname}") + ax_ratio.set_ylim(0, 2) + ax_ratio.grid(True, linestyle="--", alpha=0.5) + + if self.log_x: + ax_hist.set_xscale("log") + ax_ratio.set_xscale("log") + ax_ratio.set_xlim(xmin, xmax) + + t_s = DistStats.from_array(tar_vals) + p_s = DistStats.from_array(prd_vals) + + # Wasserstein distance + w_dist = wd(tar_vals, prd_vals) + + stat_text = ( + f"Wasserstein distance: {w_dist:.4g}\n{t_s.summary('Target:')}\n{p_s.summary('Pred:')}" ) - plt.legend(frameon=False) - valid_time = ( - target_data["valid_time"][0] - .values.astype("datetime64[m]") - .astype(datetime.datetime) - .strftime("%Y-%m-%dT%H%M") + fig.text( + 0.5, + -0.02, + stat_text, + ha="center", + va="top", + fontsize=7, + family="monospace", + bbox=dict(boxstyle="round,pad=0.3", facecolor="wheat", alpha=0.5), ) - # TODO: make this nicer + # For "all_samples" (across-samples) histograms, omit the valid_time from the name + is_global = str(self.sample) == "all_samples" + + if is_global: + valid_time = None + else: + valid_time = ( + target_data["valid_time"][0] + .values.astype("datetime64[m]") + .astype(datetime.datetime) + .strftime("%Y-%m-%dT%H%M") + ) + parts = [ "histogram", - self.run_id, - tag, + str(self.run_id), + str(tag) if tag else "", str(self.sample), valid_time, - self.stream, + str(self.stream), + region if region else "", varname, - str(self.fstep).zfill(3), + f"{self.fstep:03d}", ] name = "_".join(filter(None, parts)) fname = hist_output_dir / f"{name}.{self.image_format}" _logger.debug(f"Saving histogram to {fname}") - plt.savefig(fname, bbox_inches="tight") - plt.close() + fig.savefig(fname, bbox_inches="tight") + plt.close(fig) return name @@ -695,7 +835,7 @@ def _build_map_filename(self, varname: str, regionname: str, tag: str, data: xr. parts.append(varname) if self.fstep is not None: - parts.extend(["fstep", f"{self.fstep:03d}"]) + parts.append(f"{self.fstep:03d}") return "_".join(filter(None, parts)) @@ -902,6 +1042,16 @@ def get_map_output_dir(self, tag): """ return self.out_plot_basedir / self.stream / "maps" / tag + def get_hist_output_dir(self): + """Return the output directory path for histogram plots. + + Returns + ------- + Path + Resolved directory path: ``//histograms``. + """ + return self.out_plot_basedir / self.stream / "histograms" + def get_map_title(self, var, valid_time, data): """Build the title string for a map plot. diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score_orchestration.py b/packages/evaluate/src/weathergen/evaluate/scores/score_orchestration.py index aff40b9d9..82bc8a026 100644 --- a/packages/evaluate/src/weathergen/evaluate/scores/score_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score_orchestration.py @@ -49,6 +49,7 @@ def _score_single_fstep( bbox: "RegionBoundingBox", metrics: dict, group_by_coord: str | None, + agg_dims: str | list[str] = "ipoint", ) -> tuple[int, xr.DataArray, dict[tuple[int, str], dict]] | None: """Score all metrics for one fstep in one region. Stateless, thread-safe. @@ -89,7 +90,7 @@ def _score_single_fstep( score = get_score( score_data, metric, - agg_dims="ipoint", + agg_dims=agg_dims, group_by_coord=group_by_coord, parameters=parameters, ) @@ -176,6 +177,7 @@ def calc_scores_per_stream( aligned_clim_data = get_climatology(reader, da_tars, stream) max_workers = reader.eval_cfg.get("max_workers", None) + agg_dims = reader.eval_cfg.get("agg_dims", "ipoint") for region in regions: bbox = RegionBoundingBox.from_region_name(region) @@ -194,6 +196,7 @@ def calc_scores_per_stream( bbox, metrics, max_workers, + agg_dims, ) store_metrics_for_region( @@ -227,6 +230,7 @@ def compute_scores_for_region( bbox: "RegionBoundingBox", metrics: dict, max_workers: int | None, + agg_dims: str | list[str] = "ipoint", ) -> tuple[list, dict]: """Dispatch parallel scoring for all fsteps in one region. @@ -289,6 +293,7 @@ def compute_scores_for_region( bbox, metrics, group_by_coord, + agg_dims, ) for fstep, tars_fs, preds_fs, preds_next, tars_next, climatology in fstep_tasks ] diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index 6c4c0f913..22547ee92 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -25,6 +25,16 @@ class SampleMetaData: global_params: dict | None = None + def add_global_params(self, params: dict) -> None: + if self.global_params is None: + self.global_params = {} + self.global_params.update(params) + + def add_params(self, params: dict) -> None: + if self.params is None: + self.params = {} + self.params.update(params) + class Sample: # keys: stream name, values: SampleMetaData @@ -125,6 +135,7 @@ def add_meta_info(self, stream_name: str, meta_info: SampleMetaData) -> None: """ Add metadata for stream @stream_name to sample """ + self.meta_info[stream_name] = meta_info def get_stream_data(self, stream_name: str) -> StreamData: @@ -332,6 +343,7 @@ def add_source_stream( """ Add data for one stream to sample @source_sample_idx """ + self.source_samples.samples[source_sample_idx].add_stream_data(stream_name, stream_data) # add the meta_info diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index 7dfac40d2..9c20d010b 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -28,17 +28,18 @@ def __len__(self): return len(self.masks) def add_mask(self, mask, params, cfg, losses, idx, correspondence, relationship): + global_params = { + "idx": idx, + "correspondence": correspondence, + "loss": losses, + "relationship": relationship, + } self.masks += [mask] self.metadata += [ SampleMetaData( params={**cfg, **params}, mask=mask, - global_params={ - "idx": idx, - "correspondence": correspondence, - "loss": losses, - "relationship": relationship, - }, + global_params=global_params, ) ] diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 000c447db..e248555a8 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -97,6 +97,7 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage): self.streams = cf.streams self.rank = cf.rank self.world_size = cf.world_size + self.diffusion_model_conditioning = cf.get("fe_diffusion_model_conditioning", None) self.repeat_data = cf.data_loading.get("repeat_data_in_mini_epoch", False) # initialise healpic @@ -499,7 +500,6 @@ def _build_stream_data( output_mask : mask for output/prediction/target input_mask : mask for network input (can be source or target) - Returns: StreamData with source and targets masked according to view_meta """ @@ -600,6 +600,7 @@ def _get_source_target_masks(self, training_mode): self.num_healpix_cells, stream_info, ) + # identical for all streams num_target_samples = len(masks[stream_info["name"]][0]) num_source_samples = len(masks[stream_info["name"]][1]) @@ -720,6 +721,14 @@ def _get_batch(self, idx: int, num_forecast_steps: int): input_mask=target_mask, ) target_metadata = target_masks.metadata[tidx] + + # Get first target step's times (using self.output_offset as the first output step index) + if self.diffusion_model_conditioning in ["date_time", "date", "time"]: + target_times_array = sdata.target_times_raw[self.output_offset] + target_metadata.add_params({'timestamp': ( + target_times_array[0] if len(target_times_array) > 0 else None + )}) + # also want to add the mask to the metadata target_metadata.mask = target_mask # Map target to all source students @@ -733,6 +742,23 @@ def _get_batch(self, idx: int, num_forecast_steps: int): target_in_steps = 1 if len(target_in_steps) == 0 else target_in_steps.max().item() batch = self._preprocess_model_batch(batch, source_in_steps, target_in_steps) + #add target times in source for diffusion model date/time conditioning + if self.diffusion_model_conditioning in ["date_time", "date", "time"]: + #TODO: Might need upgrading fro num_samples > 1 + + # Assert singular source and target samples + assert len(batch.source_samples.samples) == 1, "Only single source sample supported for diffusion model conditioning." + assert len(batch.target_samples.samples) == 1, "Only single target sample supported for diffusion model conditioning." + + source_sample = batch.source_samples.samples[0] + target_sample = batch.target_samples.samples[0] + + # Copy target timestamps to source metadata for all streams + for stream_name in [s["name"] for s in self.streams]: + if stream_name in target_sample.meta_info and stream_name in source_sample.meta_info: + target_timestamp = target_sample.meta_info[stream_name].params.get('timestamp') + source_sample.meta_info[stream_name].add_params({'timestamp': target_timestamp}) + return batch def __iter__(self) -> ModelBatch: diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index bf97479e6..e9e679aac 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -13,7 +13,8 @@ from flash_attn import flash_attn_func, flash_attn_varlen_func from torch.nn.attention.flex_attention import create_block_mask, flex_attention -from weathergen.model.norms import AdaLayerNorm, RMSNorm +from weathergen.model.layers import LinearNormConditioning +from weathergen.model.norms import AdaLNZero, AdaLayerNorm, RMSNorm from weathergen.model.positional_encoding import rotary_pos_emb_2d """ @@ -24,6 +25,16 @@ """ +def _apply_xsa(attn_out: torch.Tensor, self_values: torch.Tensor) -> torch.Tensor: + attn_out_float = attn_out.float() + self_values_float = self_values.float() + denom = self_values_float.pow(2).sum(dim=-1, keepdim=True).clamp_min( + torch.finfo(self_values_float.dtype).eps + ) + proj = (attn_out_float * self_values_float).sum(dim=-1, keepdim=True) / denom + return (attn_out_float - (proj * self_values_float)).to(attn_out.dtype) + + class MultiSelfAttentionHeadVarlen(torch.nn.Module): def __init__( self, @@ -41,6 +52,7 @@ def __init__( norm_eps=1e-5, attention_dtype=torch.bfloat16, with_2d_rope=False, + use_xsa=False, ): super(MultiSelfAttentionHeadVarlen, self).__init__() @@ -50,6 +62,7 @@ def __init__( self.softcap = softcap self.with_residual = with_residual self.with_2d_rope = with_2d_rope + self.use_xsa = use_xsa assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj @@ -118,6 +131,9 @@ def forward(self, x, x_lens, ada_ln_aux=None, coords=None): dropout_p=dropout_rate, ) + if self.use_xsa: + outs = _apply_xsa(outs, vs) + out = self.proj_out(outs.flatten(-2, -1)) if self.with_residual: @@ -141,6 +157,7 @@ def __init__( softcap=0.0, norm_eps=1e-5, attention_dtype=torch.bfloat16, + use_xsa=False, ): super(MultiSelfAttentionHeadVarlenFlex, self).__init__() @@ -148,6 +165,7 @@ def __init__( self.with_flash = with_flash self.softcap = softcap self.with_residual = with_residual + self.use_xsa = use_xsa assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj @@ -200,6 +218,9 @@ def forward(self, x, x_lens=None): outs = self.compiled_flex_attention(qs, ks, vs).transpose(1, 2).squeeze() + if self.use_xsa: + outs = _apply_xsa(outs, vs.transpose(1, 2).squeeze()) + out = self.dropout(self.proj_out(outs.flatten(-2, -1))) if self.with_residual: out = out + x_in @@ -226,6 +247,9 @@ def __init__( norm_eps=1e-5, attention_dtype=torch.bfloat16, with_2d_rope=False, + is_dit=False, + dit_is_cond=False, + use_xsa=False, ): super(MultiSelfAttentionHeadLocal, self).__init__() @@ -234,6 +258,8 @@ def __init__( self.softcap = softcap self.with_residual = with_residual self.with_2d_rope = with_2d_rope + self.dtype = attention_dtype + self.use_xsa = use_xsa assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj @@ -243,10 +269,22 @@ def __init__( else: norm = RMSNorm - if dim_aux is not None: + self.is_dit = is_dit + self.dit_is_cond = dit_is_cond + if is_dit: + if dit_is_cond: + assert dim_aux is not None, "For DIT, need to provide dim_aux for ada layer norm" + assert dim_aux is None, "conditioning not yet implemented for DIT attention" + assert with_residual, "DIT attention should always have residual connection" + self.lnorm = AdaLNZero(dim_embed, dim_aux, norm_eps=norm_eps) if dim_aux is not None else norm(dim_embed, eps=norm_eps) + self.noise_conditioning = LinearNormConditioning( + latent_space_dim=dim_embed, dtype=attention_dtype + ) + elif dim_aux is not None: self.lnorm = AdaLayerNorm(dim_embed, dim_aux, norm_eps=norm_eps) else: self.lnorm = norm(dim_embed, eps=norm_eps) + self.proj_heads_q = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False) self.proj_heads_k = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False) self.proj_heads_v = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False) @@ -277,10 +315,21 @@ def mask_block_local(batch, head, idx_q, idx_kv): # compile for efficiency self.flex_attention = torch.compile(flex_attention, dynamic=False) - def forward(self, x, coords=None, ada_ln_aux=None): + def forward(self, x, coords=None, emb=None, ada_ln_aux=None): if self.with_residual: x_in = x - x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux) + + # Handle ada_ln_aux conditioning + if self.is_dit: + if self.dit_is_cond: + x, cond_gate = self.lnorm(x, ada_ln_aux) + else: + x = self.lnorm(x) + cond_gate = 1 + x, noise_gate = self.noise_conditioning(x, emb) + gate = cond_gate * noise_gate + else: + x = self.lnorm(x, ada_ln_aux) if ada_ln_aux is not None else self.lnorm(x) # project onto heads s = [x.shape[0], x.shape[1], self.num_heads, -1] @@ -295,9 +344,13 @@ def forward(self, x, coords=None, ada_ln_aux=None): outs = self.flex_attention(qs, ks, vs, block_mask=self.block_mask).transpose(1, 2) + if self.use_xsa: + outs = _apply_xsa(outs, vs.transpose(1, 2)) + out = self.proj_out(self.dropout(outs.flatten(-2, -1))) + if self.with_residual: - out = x_in + out + out = x_in + out * gate if self.is_dit else x_in + out return out @@ -541,6 +594,9 @@ def __init__( norm_eps=1e-5, attention_dtype=torch.bfloat16, with_2d_rope=False, + is_dit=False, # should only be True for diffusion model + dit_is_cond = False, # whether the attention is used for conditioning in the diffusion model (as opposed to denoising). Should only be True for cross attention layers in the diffusion model, and will control whether ada_ln_aux is applied to the input or output of the attention layer + use_xsa=False, ): super(MultiSelfAttentionHead, self).__init__() @@ -550,6 +606,7 @@ def __init__( self.dropout_rate = dropout_rate self.with_residual = with_residual self.with_2d_rope = with_2d_rope + self.use_xsa = use_xsa assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj @@ -559,10 +616,22 @@ def __init__( else: norm = RMSNorm - if dim_aux is not None: + self.is_dit = is_dit + self.dit_is_cond = dit_is_cond + + if is_dit: + if dit_is_cond: + assert dim_aux is not None, "For DIT, need to provide dim_aux for ada layer norm" + assert with_residual, "DIT attention should always have residual connection" + self.lnorm = AdaLNZero(dim_embed, dim_aux, norm_eps=norm_eps) if dim_aux is not None else norm(dim_embed, eps=norm_eps) + self.noise_conditioning = LinearNormConditioning( + latent_space_dim=dim_embed, dtype=attention_dtype + ) # TODO: Do I need to pass dtype? + elif dim_aux is not None: self.lnorm = AdaLayerNorm(dim_embed, dim_aux, norm_eps=norm_eps) else: self.lnorm = norm(dim_embed, eps=norm_eps) + self.proj_heads_q = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False) self.proj_heads_k = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False) self.proj_heads_v = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False) @@ -587,10 +656,21 @@ def __init__( self.att = self.attention self.softmax = torch.nn.Softmax(dim=-1) - def forward(self, x, coords=None, ada_ln_aux=None): + def forward(self, x, coords=None, emb=None, ada_ln_aux=None): if self.with_residual: x_in = x - x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux) + + # Handle ada_ln_aux conditioning + if self.is_dit: + if self.dit_is_cond: + x, cond_gate = self.lnorm(x, ada_ln_aux) + else: + x = self.lnorm(x) + cond_gate = 1 + x, noise_gate = self.noise_conditioning(x, emb) + gate = cond_gate * noise_gate + else: + x = self.lnorm(x, ada_ln_aux) if ada_ln_aux is not None else self.lnorm(x) # project onto heads and q,k,v and # ensure these are 4D tensors as required for flash attention @@ -610,9 +690,13 @@ def forward(self, x, coords=None, ada_ln_aux=None): # ordering of tensors (seq, heads, embed) (which differs from torch's flash attention implt) outs = flash_attn_func(qs, ks, vs, softcap=self.softcap, dropout_p=dropout_rate) + if self.use_xsa: + outs = _apply_xsa(outs, vs) + out = self.proj_out(outs.flatten(-2, -1)) + if self.with_residual: - out = out + x_in + out = x_in + out * gate if self.is_dit else out + x_in return out @@ -632,12 +716,14 @@ def __init__( qk_norm_type=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + is_dit=False, ): super(MultiCrossAttentionHead, self).__init__() self.num_heads = num_heads self.with_residual = with_residual self.with_flash = with_flash + self.is_dit = is_dit if norm_type == "LayerNorm": norm = partial(torch.nn.LayerNorm, elementwise_affine=False, eps=norm_eps) @@ -647,7 +733,14 @@ def __init__( assert dim_embed_q % num_heads == 0 self.dim_head_proj = dim_embed_q // num_heads if dim_head_proj is None else dim_head_proj - self.lnorm_in_q = norm(dim_embed_q, eps=norm_eps) + if is_dit: + assert with_residual + self.lnorm_in_q = norm(dim_embed_q, eps=norm_eps) + self.noise_conditioning = LinearNormConditioning( + latent_space_dim=dim_embed_q, dtype=attention_dtype + ) + else: + self.lnorm_in_q = norm(dim_embed_q, eps=norm_eps) self.lnorm_in_kv = norm(dim_embed_kv, eps=norm_eps) self.proj_heads_q = torch.nn.Linear(dim_embed_q, num_heads * self.dim_head_proj, bias=False) @@ -677,10 +770,16 @@ def __init__( self.softmax = torch.nn.Softmax(dim=-1) ######################################### - def forward(self, x_q, x_kv): + def forward(self, x_q, x_kv, emb=None): if self.with_residual: x_q_in = x_q - x_q, x_kv = self.lnorm_in_q(x_q), self.lnorm_in_kv(x_kv) + + if self.is_dit: + x_q = self.lnorm_in_q(x_q) + x_q, gate = self.noise_conditioning(x_q, emb) + else: + x_q = self.lnorm_in_q(x_q) + x_kv = self.lnorm_in_kv(x_kv) # project onto heads and q,k,v and # ensure these are 4D tensors as required for flash attention @@ -696,6 +795,6 @@ def forward(self, x_q, x_kv): outs = self.dropout(self.proj_out(outs.flatten(-2, -1))) if self.with_residual: - outs = x_q_in + outs + outs = x_q_in + outs * gate if self.is_dit else x_q_in + outs return outs diff --git a/src/weathergen/model/blocks.py b/src/weathergen/model/blocks.py index a05e25ca9..3d1996930 100644 --- a/src/weathergen/model/blocks.py +++ b/src/weathergen/model/blocks.py @@ -25,7 +25,17 @@ class SelfAttentionBlock(nn.Module): layer norm with a FFN. """ - def __init__(self, dim, dim_aux, with_adanorm, num_heads, dropout_rate, **kwargs): + def __init__( + self, + dim, + dim_aux, + with_adanorm, + num_heads, + dropout_rate, + mlp_type="mlp", + use_xsa=False, + **kwargs, + ): super().__init__() self.with_adanorm = with_adanorm @@ -34,6 +44,7 @@ def __init__(self, dim, dim_aux, with_adanorm, num_heads, dropout_rate, **kwargs dim_embed=dim, num_heads=num_heads, with_residual=False, + use_xsa=use_xsa, **kwargs["attention_kwargs"], ) if self.with_adanorm: @@ -48,6 +59,7 @@ def __init__(self, dim, dim_aux, with_adanorm, num_heads, dropout_rate, **kwargs dim_out=dim, hidden_factor=4, dropout_rate=0.1, + mlp_type=mlp_type, nonlin=approx_gelu, with_residual=False, ) @@ -98,6 +110,8 @@ def __init__( with_mlp, num_heads, dropout_rate, + mlp_type="mlp", + use_xsa=False, **kwargs, ): super().__init__() @@ -111,6 +125,7 @@ def __init__( dim_embed=dim_q, num_heads=num_heads, with_residual=False, + use_xsa=use_xsa, **kwargs["attention_kwargs"], ) if self.with_adanorm: @@ -140,6 +155,7 @@ def __init__( dim_in=dim_q, dim_out=dim_q, hidden_factor=4, + mlp_type=mlp_type, nonlin=approx_gelu, with_residual=False, ) @@ -189,6 +205,7 @@ def __init__( attention_kwargs, tr_dim_head_proj, tr_mlp_hidden_factor, + tr_mlp_type, tro_type, mlp_norm_eps=1e-6, ): @@ -198,6 +215,7 @@ def __init__( self.tro_type = tro_type self.tr_dim_head_proj = tr_dim_head_proj self.tr_mlp_hidden_factor = tr_mlp_hidden_factor + self.tr_mlp_type = tr_mlp_type self.block = nn.ModuleList() @@ -235,6 +253,7 @@ def __init__( dim_aux=dim_aux, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + use_xsa=self.cf.get("use_xsa", False), ) ) @@ -246,6 +265,7 @@ def __init__( with_residual=True, hidden_factor=self.tr_mlp_hidden_factor, dropout_rate=0.1, # Assuming dropout_rate is 0.1 + mlp_type=self.tr_mlp_type, norm_type=self.cf.norm_type, dim_aux=(dim_aux if self.cf.pred_mlp_adaln else None), norm_eps=self.cf.mlp_norm_eps, diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py new file mode 100644 index 000000000..7909534d1 --- /dev/null +++ b/src/weathergen/model/diffusion.py @@ -0,0 +1,586 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +# ---------------------------------------------------------------------------- +# Third-Party Attribution: NVLABS/EDM (Elucidating the Design of Diffusion Models) +# This file incorporates code originally from the 'NVlabs/edm' repository. +# +# Original Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# ---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- +# Third-Party Attribution: facebookresearch/DiT (Scalable Diffusion Models with Transformers (DiT)) +# This file incorporates code originally from the 'facebookresearch/DiT' repository, +# with adaptations. +# +# The original code is licensed under CC-BY-NC. +# ---------------------------------------------------------------------------- + + +import logging +import math + +import numpy as np +import torch + +from weathergen.common.config import Config, get_path_run +from weathergen.datasets.batch import SampleMetaData +from weathergen.model.engines import ForecastingEngine + +logger = logging.getLogger(__name__) + + +class DiffusionForecastEngine(torch.nn.Module): + # Adopted from https://github.com/NVlabs/edm/blob/main/training/loss.py#L72 + + def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: ForecastingEngine): + super().__init__() + self.cf = cf + self.num_healpix_cells = num_healpix_cells + self.net = forecast_engine + self.preconditioner = Preconditioner() + self.frequency_embedding_dim = self.cf.frequency_embedding_dim + self.embedding_dim = self.cf.embedding_dim + self.noise_embedder = NoiseEmbedder( + embedding_dim=self.embedding_dim, frequency_embedding_dim=self.frequency_embedding_dim + ) + self.conditioning = self.cf.get("fe_diffusion_model_conditioning", None) + self.conditioning_type = self.cf.get("fe_diffusion_model_conditioning_type", None) + + _date_time_modes = {"date_time", "date", "time"} + assert self.conditioning not in _date_time_modes or self.conditioning_type == "ada_ln", ( + f"fe_diffusion_model_conditioning_type must be 'ada_ln' when " + f"fe_diffusion_model_conditioning is '{self.conditioning}' " + f"(got '{self.conditioning_type}')" + ) + _ada_ln = self.conditioning_type == "ada_ln" + assert self.cf.get("diffusion_conditioning_embed_dim", None) is not None or not _ada_ln, ( + f"diffusion_conditioning_embed_dim must be set when " + f"fe_diffusion_model_conditioning_type is 'ada_ln'" + ) + _offset = self.cf.get("training_config", {}).get("forecast", {}).get("offset", 0) + assert self.conditioning not in _date_time_modes or _offset == 0, ( + f"forecast.offset must be 0 when fe_diffusion_model_conditioning is " + f"'{self.conditioning}' (got offset={_offset})" + ) + _input_num_steps = self.cf.get("training_config", {}).get("model_input", {}).get("forecasting", {}).get("num_steps_input", 0) + assert self.conditioning != "forecast" or _input_num_steps == 2, ( + f"forecast.input_num_steps must be 2 when fe_diffusion_model_conditioning is " + f"'{self.conditioning}' (got input_num_steps={_input_num_steps})" + ) + assert self.conditioning not in ["date_time", "date", "time"] or _input_num_steps == 1, ( + f"forecast.input_num_steps must be 1 when fe_diffusion_model_conditioning is " + f"'{self.conditioning}' (got input_num_steps={_input_num_steps})" + ) + assert self.conditioning != "forecast" or self.conditioning_type in {"cross_attn"}, ( + f"fe_diffusion_model_conditioning_type must be 'cross_attn' when " + f"fe_diffusion_model_conditioning is 'forecast' " + f"(got '{self.conditioning_type}')" + ) + + if self.conditioning and (self.conditioning in ["date_time", "date", "time"]): + self.datetime_embedder = DateTimeEncoder(self.conditioning) + + # Parameters + self.sigma_min = self.cf.sigma_min + self.sigma_max = self.cf.sigma_max + self.sigma_data = self.cf.sigma_data + self.rho = self.cf.rho + self.p_mean = self.cf.p_mean + self.p_std = self.cf.p_std + self.cur_token = None # TODO: re move after single sample experiments + self._noised_tokens: torch.Tensor | None = None + self._fixed_noise_level: float | None = None + + self._noise = None + + def forward( + self, + tokens: torch.Tensor = None, + fstep: int = None, + meta_info: dict[str, SampleMetaData] = None, + coords: torch.Tensor = None, + num_steps: int = 10, + ) -> torch.Tensor: + """ + Forward pass that routes to training_forward or inference_forward based on model status. + + During training: + - calls training_forward with tokens, fstep, meta_info, coords + - extracts datetime conditioning from meta_info and passes through datetime embedder + - adds noise to target and returns denoised prediction + + During inference: + - calls inference_forward with fstep, num_steps, and meta_info + - generates samples via iterative diffusion steps with conditional temporal modulation + + Args: + tokens: Training tensor of shape (B, H, D) - required during training + fstep: Forecast step index - required for both modes + meta_info: Sample metadata dict containing timestamps - required for both modes + coords: Optional coordinate tensor + num_steps: Number of diffusion steps for inference (default: 30) + + Returns: + torch.Tensor: Model output (denoised prediction during training, + or generated sample during inference) + + Raises: + ValueError: If required arguments are missing for current mode + """ + # called during training in training mode + # called during training in training mode + if self.training: + if tokens is None or fstep is None or meta_info is None: + raise ValueError( + f"During training, tokens, fstep, and meta_info are required. " + f"Got tokens={tokens is not None}, fstep={fstep}, meta_info={meta_info is not None}" + ) + return self.training_forward( + tokens=tokens, + fstep=fstep, + meta_info=meta_info, + coords=coords, + ) + else: + # called in evaluation mode : + # decide btw pure noise generation (inference) vs denoising a sample for + # evaluation (train) using the stage variable + if self.cf.stage == "train" or self.cf.stage == "train_continue": + # NOTE: temporary for analysing denoising + return self.training_forward( + tokens=tokens, + fstep=fstep, + meta_info=meta_info, + coords=coords, + ) + elif self.cf.stage == "inference": + if fstep is None: + raise ValueError(f"During inference, fstep is required. Got fstep={fstep}") + self.cur_token = tokens.detach() if tokens is not None else None + return self.inference_forward( + fstep=fstep, + num_steps=num_steps, + meta_info=meta_info, + coords=coords, + ) + + def training_forward( + self, + tokens: torch.Tensor, + fstep: int, + meta_info: dict[str, SampleMetaData], + coords: torch.Tensor = None, + ) -> torch.Tensor: + """ + Model forward call during training. Unpacks the conditioning c = [x_{t-k}, ..., x_{t}], the + target y = x_{t+1}, and the random noise eta from the data, computes the diffusion noise + level sigma, and feeds the noisy target along with the conditioning and sigma through the + model to return a denoised prediction. + """ + # Retrieve conditionings [0:-1], target [-1], and noise from data object. + # TOOD: The data retrieval ignores batch and stream dimension for now (has to be adapted). + # c = [data.get_input_data(t) for t in range(data.get_sample_len() - 1)] + # y = data.get_input_data(-1) + # eta = data.get_input_metadata(-1) + + self.cur_token = tokens.detach() + + # y is always the target to denoise (set by DiffusionLatentTargetEncoder.pre_compute) + y = tokens + assert y is not None, ( + "diffusion_target_tokens not found in meta_info — " + "DiffusionLatentTargetEncoder.pre_compute must be called before training_forward" + ) + + c = None + if self.conditioning in ["date_time", "date", "time"]: + c = meta_info["ERA5"].params["timestamp"] + elif self.conditioning == "forecast": + c = meta_info["ERA5"].params["conditioning_tokens"] # X_{t-1} as conditioning (model.py extracts last step as target, passes second-to-last here) + + if self.training: + eta = torch.tensor([meta_info["ERA5"].params["noise_level_rn"]], device=tokens.device) + else: + # During validation, use fixed noise level (default: 0.0 = mean of noise distribution) + noise_level = self._fixed_noise_level if self._fixed_noise_level is not None else 0.0 + eta = torch.tensor([noise_level], device=tokens.device) + + # Compute sigma (noise level) from eta and create noise tensor + sigma = (eta * self.p_std + self.p_mean).exp() + n = torch.randn_like(y) * sigma + + self._noised_tokens = (y + n).detach() + + return self.denoise(x=y + n, c=c, sigma=sigma, fstep=fstep, coords=coords) + + def denoise( + self, + x: torch.Tensor, + c: torch.Tensor, + sigma: float, + fstep: int, + coords: torch.Tensor = None, + ) -> torch.Tensor: + """ + The actual diffusion step, where the model removes noise from the input x under + consideration of a conditioning c (e.g., previous time steps) and the current diffusion + noise level sigma. + """ + # Compute scaling conditionings (EDM Eq. 7 — disabled for direct prediction) + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (sigma**2 + self.sigma_data**2).sqrt() + c_noise = sigma.log() / 4 + + # Embed noise level + noise_emb = self.noise_embedder(c_noise) + + # Precondition input and feed through network + if self.conditioning in ["date_time", "date", "time"]: + c = self.datetime_embedder(c).to(x.device) + + net_input = c_in * x + + return c_skip * x + c_out * self.net( + net_input, fstep=fstep, coords=coords, noise_emb=noise_emb, conditioning=c + ) # Eq. (7) in EDM paper + + def inference_forward( + self, + fstep: int, + num_steps: int = 50, + meta_info: dict[str, SampleMetaData] = None, + coords: torch.Tensor = None, + ) -> torch.Tensor: + """ + Forward pass of the diffusion model during inference. + + Iteratively denoises a random sample using the learned score function, + with optional temporal conditioning extracted from meta_info. + https://github.com/NVlabs/edm/blob/main/generate.py + + Args: + fstep: Forecast step index for the network + num_steps: Number of diffusion denoising steps (default: 30) + meta_info: Optional sample metadata dict containing timestamps for temporal conditioning + coords: Optional coordinate tensor for spatial conditioning + Returns: + torch.Tensor: Generated sample of shape (1, num_healpix_cells, ae_global_dim_embed) + """ + + # Extract conditioning (mirrors training_forward). + c = None + if self.conditioning in ["date_time", "date", "time"]: + c = meta_info["ERA5"].params["timestamp"] + elif self.conditioning == "forecast": + c = meta_info["ERA5"].params["conditioning_tokens"] + + + # Sample pure noise (assuming single batch element for now) + # torch.manual_seed(42) + x = torch.randn(1, self.num_healpix_cells, self.cf.ae_global_dim_embed).to(device="cuda") + + ### OLD WAY OF COMPUTING SIGMA SCHEDULE + # # Time step discretization. + # step_indices = torch.arange(num_steps, dtype=torch.float64, device="cuda") + # t_steps = ( + # self.sigma_max ** (1 / self.rho) + # + step_indices + # / (num_steps - 1) + # * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)) + # ) ** self.rho + # t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 + + ### NEW WAY OF COMPUTING SIGMA SCHEDULE WITH TRAINING-ALIGNED BOUNDS AND DIAGNOSTICS + # --- Training-aligned sigma bounds --- + # Training noise: sigma = exp(eta * p_std + p_mean), eta ~ N(0,1). + # The network only learns to denoise reliably within the training distribution. + # - sigma_max_eff: cap at 99.7th percentile = exp(p_mean + 3*p_std) + # Beyond this, the denoiser is in untrained territory → garbage predictions + # that poison the entire ODE trajectory. + # - sigma_min_eff: floor at a level where the network still contributes. + # With EDM preconditioning, c_skip = sigma_data^2/(sigma^2+sigma_data^2). + # At sigma << sigma_data, c_skip → 1, meaning the output ≈ input (skip + # connection dominates) and the network can no longer correct errors. + # We stop at sigma_min = max(config value, sigma_data * 0.01), which gives + # c_skip ≈ 0.9999 — still some network contribution, and avoids the + # numerical instability of dividing by near-zero sigma in the ODE. + sigma_max_train = math.exp(self.p_mean + 3.0 * self.p_std) + sigma_max_eff = min(self.sigma_max, sigma_max_train) + sigma_min_eff = max(self.sigma_min, self.sigma_data * 0.01) + logger.info( + f"Inference sigma schedule: " + f"sigma_max_eff={sigma_max_eff:.4f} (config={self.sigma_max}, train 3σ={sigma_max_train:.4f}), " + f"sigma_min_eff={sigma_min_eff:.4f} (config={self.sigma_min}), " + f"sigma_data={self.sigma_data}, rho={self.rho}, num_steps={num_steps}" + ) + + # --- Time step discretization (EDM Eq. 5) with training-aligned bounds --- + step_indices = torch.arange(num_steps, dtype=torch.float64, device="cuda") + t_steps = ( + sigma_max_eff ** (1 / self.rho) + + step_indices + / (num_steps - 1) + * (sigma_min_eff ** (1 / self.rho) - sigma_max_eff ** (1 / self.rho)) + ) ** self.rho + t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 + # t_steps = torch.cat( + # [self.net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])] + # ) # t_N = 0 + + # --- Per-step tracking for diagnostics --- + track = { + "sigma": [], + "x_std": [], + "denoised_std": [], + "l2_to_target": [], + "cosine_to_target": [], + "c_skip": [], + "x": [x.cpu()], + } + + # Per-step intermediate denoised states (one per ODE step). + # Returned to the caller so they can be treated as a forecast-step dimension. + intermediate_x: list[torch.Tensor] = [] + + # Main sampling loop. + x_next = x * t_steps[0] + for i, (t_cur, t_next) in enumerate( + zip(t_steps[:-1], t_steps[1:], strict=False) + ): # 0, ..., N-1 + t_cur = torch.tensor([t_cur], device="cuda").float() + t_next = torch.tensor([t_next], device="cuda").float() + + x_cur = x_next + + # Increase noise temporarily. (Stochastic sampling; not used for now) + # gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 + # t_hat = self.net.round_sigma(t_cur + gamma * t_cur) + # x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * s_noise * torch.randn_like(x_cur) + x_hat = x_cur + t_hat = t_cur + + # Euler step. + denoised = self.denoise(x=x_hat, c=c, sigma=t_hat, fstep=fstep, coords=coords) + d_cur = (x_hat - denoised) / t_hat + x_next = x_hat + (t_next - t_hat) * d_cur + + # Apply 2nd order correction. + if i < num_steps - 1: + denoised = self.denoise(x=x_next, c=c, sigma=t_next, fstep=fstep, coords=coords) + d_prime = (x_next - denoised) / t_next + x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + + # --- Record diagnostics --- + with torch.no_grad(): + s = t_cur.item() + track["sigma"].append(s) + track["c_skip"].append(self.sigma_data**2 / (s**2 + self.sigma_data**2)) + track["x_std"].append(x_next.std().item()) + track["denoised_std"].append(denoised.std().item()) + track["x"].append(x_next.cpu()) + if self.cur_token is not None: + track["l2_to_target"].append((x_next - self.cur_token).norm().item()) + track["x"].append(self.cur_token.cpu()) + + # Record intermediate denoised state for this ODE step. + intermediate_x.append(x_next) + + self._plot_sampling_diagnostics(track, num_steps) + + return intermediate_x + + def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: + """Save a diagnostic plot of the sampling trajectory.""" + import matplotlib + + matplotlib.use("Agg") + import os + import matplotlib.pyplot as plt + + steps = list(range(len(track["sigma"]))) + has_target = len(track["l2_to_target"]) > 0 + n_plots = 3 + + fig, axes = plt.subplots(n_plots, 1, figsize=(10, 3 * n_plots), sharex=True) + + # 1) Sigma schedule + axes[0].semilogy(steps, track["sigma"], "o-", markersize=3) + axes[0].set_ylabel("sigma (noise level)") + axes[0].set_title( + f"Sampling diagnostics | sigma_max_eff={track['sigma'][0]:.2f}, " + f"sigma_data={self.sigma_data}, steps={num_steps}" + ) + axes[0].axhline( + self.sigma_data, color="grey", ls="--", lw=0.8, label=f"sigma_data={self.sigma_data}" + ) + axes[0].legend(fontsize=8) + axes[0].grid(True, alpha=0.3) + + # 2) Std of x_next and denoised estimate + axes[1].plot(steps, track["x_std"], "o-", markersize=3, label="x (noisy state)") + axes[1].plot(steps, track["denoised_std"], "s-", markersize=3, label="denoised estimate") + if self.cur_token is not None: + target_std = self.cur_token.std().item() + axes[1].axhline( + target_std, color="grey", ls="--", lw=0.8, label=f"target std={target_std:.3f}" + ) + axes[1].set_ylabel("std") + axes[1].legend(fontsize=8) + axes[1].grid(True, alpha=0.3) + + if has_target: + # 3) L2 error to target + axes[2].plot(steps, track["l2_to_target"], "o-", markersize=3, color="tab:red") + axes[2].set_ylabel("L2 error to target") + axes[2].grid(True, alpha=0.3) + + axes[-1].set_xlabel("sampling step") + fig.tight_layout() + + out_dir = get_path_run(self.cf) + out_dir.mkdir(exist_ok=True, parents=True) + out_path_base = out_dir / "plots" / "validation" / "plots" + out_path_base.mkdir(exist_ok=True, parents=True) + fig.savefig(out_path_base / "sampling_diagnostics.png", dpi=150) + plt.close(fig) + logger.info(f"Saved sampling diagnostics to {out_path_base / 'sampling_diagnostics.png'}") + + +class Preconditioner: + # Preconditioner, e.g., to concatenate previous frames to the input + def __init__(self): + pass + + def precondition(self, x, c): + return x + + +# NOTE: Adapted from DiT codebase: +class NoiseEmbedder(torch.nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, embedding_dim: int, frequency_embedding_dim: int, dtype=torch.bfloat16): + super().__init__() + self.dtype = dtype + self.mlp = torch.nn.Sequential( + torch.nn.Linear(frequency_embedding_dim, embedding_dim, bias=True), + torch.nn.SiLU(), + torch.nn.Linear(embedding_dim, embedding_dim, bias=True), + ) + self.frequency_embedding_dim = frequency_embedding_dim + + def timestep_embedding(self, t: float, max_period: int = 10000): + """ + Create sinusoidal timestep embeddings. + :param t: a scalar or 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # Ensure t is 1D + if t.ndim == 0: + t = t.view(1) + + half = self.frequency_embedding_dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=self.dtype) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if self.frequency_embedding_dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t: float): + t_freq = self.timestep_embedding(t) + t_emb = self.mlp(t_freq) + return t_emb + + +class DateTimeEncoder(torch.nn.Module): + """ + Encodes timestamp(s) into multi-frequency sinusoidal calendar embeddings. + + Inspired by cBottle (Climate in a Bottle) with k=1..8 frequency scales. + Captures seasonal (day-of-year) and diurnal (time-of-day) cycles at multiple timescales. + + Input shape: scalar or any tensor shape (...) + Output shape: (..., 32) — 8 frequencies × 4 components (cos/sin per signal) + + Output structure for k=1..num_frequencies: + [cos(2πk·doy_frac), sin(2πk·doy_frac), cos(2πk·tod_frac), sin(2πk·tod_frac)] + where: + - doy_frac = day_of_year / days_in_year + - tod_frac = seconds_of_day / 86400.0 + """ + + def __init__(self, conditioning: str): + super().__init__() + self.num_frequencies = 8 + assert conditioning in ["date_time", "date", "time"], f"Unsupported conditioning: {conditioning}" + self.date_only = conditioning == "date" + self.time_only = conditioning == "time" + + + def forward(self, timestamp: np.ndarray | np.datetime64) -> torch.Tensor: + """ + Encode numpy datetime64 timestamps into 32D multi-frequency calendar embeddings. + + Args: + timestamp: np.datetime64 scalar or array of timestamps + + Returns: + torch.Tensor of shape (..., 32) containing multi-frequency embeddings + """ + + # TODO: Consider adding local time encoding (e.g., using longitude) + + timestamp = np.asarray(timestamp) + orig_shape = timestamp.shape + timestamp_flat = timestamp.reshape(-1) + + two_pi = 2.0 * np.pi + + # --- Extract time components --- + ts_int64 = timestamp_flat.astype("int64") # seconds since Unix epoch + seconds_in_day = 86400.0 + tod_frac = (ts_int64 % int(seconds_in_day)) / seconds_in_day # [0, 1) + + # --- Extract day of year --- + day_np = timestamp_flat.astype("datetime64[D]") + year_start = day_np.astype("datetime64[Y]").astype("datetime64[D]") + next_year_start = (day_np.astype("datetime64[Y]") + np.timedelta64(1, "Y")).astype( + "datetime64[D]" + ) + + day_of_year_0 = (day_np - year_start).astype(np.int64) # [0, 365] or [0, 366] + days_in_year = (next_year_start - year_start).astype(np.int64) # 365 or 366 + doy_frac = day_of_year_0.astype(np.float32) / days_in_year.astype(np.float32) # [0, 1) + + # --- Multi-frequency sinusoidal embeddings (vectorized over k) --- + k = np.arange(1, self.num_frequencies + 1, dtype=np.float32)[None, :] + doy_phase = two_pi * doy_frac[:, None] * k + tod_phase = two_pi * tod_frac[:, None] * k + + doy_cos = np.cos(doy_phase).astype(np.float32) if not self.time_only else np.zeros_like(doy_phase).astype(np.float32) + doy_sin = np.sin(doy_phase).astype(np.float32) if not self.time_only else np.zeros_like(doy_phase).astype(np.float32) + tod_cos = np.cos(tod_phase).astype(np.float32) if not self.date_only else np.zeros_like(tod_phase).astype(np.float32) + tod_sin = np.sin(tod_phase).astype(np.float32) if not self.date_only else np.zeros_like(tod_phase).astype(np.float32) + + # Stack all components: (N, K, 4) -> (N, K*4) + out = np.stack([doy_cos, doy_sin, tod_cos, tod_sin], axis=-1) + out = out.reshape(out.shape[0], self.num_frequencies * 4) + out = torch.from_numpy(out).float() + + return out.reshape(*orig_shape, self.num_frequencies * 4) diff --git a/src/weathergen/model/embeddings.py b/src/weathergen/model/embeddings.py index 90fbcf714..7b9137cff 100644 --- a/src/weathergen/model/embeddings.py +++ b/src/weathergen/model/embeddings.py @@ -32,6 +32,8 @@ def __init__( num_heads, dropout_rate=0.0, norm_type="LayerNorm", + mlp_type="mlp", + use_xsa=False, unembed_mode="full", stream_name="stream_embed", ): @@ -56,6 +58,8 @@ def __init__( self.dim_out = dim_out self.num_blocks = num_blocks self.num_heads = num_heads + self.mlp_type = mlp_type + self.use_xsa = use_xsa self.unembed_mode = unembed_mode norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm @@ -69,6 +73,7 @@ def __init__( dropout_rate=dropout_rate, with_qk_lnorm=True, with_flash=True, + use_xsa=self.use_xsa, ) ) self.layers.append( @@ -77,6 +82,7 @@ def __init__( self.dim_embed, hidden_factor=2, dropout_rate=dropout_rate, + mlp_type=self.mlp_type, with_residual=True, ) ) diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index 54409e297..582e9b57f 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -117,6 +117,8 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord # global assimilation engine self.ae_global_engine = GlobalAssimilationEngine(cf, self.num_healpix_cells) + self.ln = torch.nn.LayerNorm(cf.ae_global_dim_embed, elementwise_affine=False) + def forward(self, model_params, batch): """ Encoder forward @@ -137,6 +139,8 @@ def forward(self, model_params, batch): use_reentrant=False, ) + tokens_global = self.ln(tokens_global) + return tokens_global, posteriors def interpolate_latents(self, tokens: torch.Tensor) -> (torch.Tensor, torch.Tensor): @@ -353,3 +357,6 @@ def assimilate_local( ).flatten(1, 2) return tokens_global, posteriors + + def reset_parameters(self): + return diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index d2d5a00c8..6c8e847db 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -6,7 +6,6 @@ # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. - import dataclasses import math @@ -16,7 +15,9 @@ from torch.utils.checkpoint import checkpoint from weathergen.common.config import Config +from weathergen.datasets.batch import SampleMetaData from weathergen.model.attention import ( + MultiCrossAttentionHead, MultiCrossAttentionHeadVarlen, MultiCrossAttentionHeadVarlenSlicedQ, MultiSelfAttentionHead, @@ -69,6 +70,8 @@ def __init__(self, cf: Config, sources_size) -> None: num_heads=si["embed"]["num_heads"], dropout_rate=self.cf.embed_dropout_rate, norm_type=self.cf.norm_type, + mlp_type=self.cf.get("mlp_type", "mlp"), + use_xsa=self.cf.get("use_xsa", False), unembed_mode=self.cf.embed_unembed_mode, stream_name=stream_name, ) @@ -111,6 +114,13 @@ def forward(self, batch, pe_embed): # switch from stream to cell-based ordering and apply per cell positional encoding + # if the assert is hit, max_number_tokens_local_per_cell in config needs to be increased + max_tokens = self.cf.get("ae_local_max_tokens_per_cell", 64) + assert batch.tokens_lens.flatten(0, 2).sum(0).max() <= max_tokens, ( + "max number of tokens per cell for positional encoding exceeded." + ) + " Increase ae_local_max_tokens_per_cell in config." + if batch.tokens_lens.shape[2] == 1: # trivial with one stream tokens_all = torch.cat(x_embeds) @@ -119,10 +129,6 @@ def forward(self, batch, pe_embed): scatter_idxs = self.get_scatter_idxs_vectorized(batch) scatter_idxs = scatter_idxs.unsqueeze(1).repeat((1, self.cf.ae_local_dim_embed)) - # if the assert is hit, MAX_NUMBER_TOKENS_LOCAL_PER_CELL needs to be increased - assert ( - batch.tokens_lens.flatten(0, 2).sum(0).max() < MAX_NUMBER_TOKENS_LOCAL_PER_CELL - ), "max number of tokens per cell for positional encoding exceeded" # actual scatter operation and apply per cell positional encoding tokens_all.scatter_(0, scatter_idxs, torch.cat(x_embeds)) @@ -130,7 +136,7 @@ def forward(self, batch, pe_embed): tokens_all = tokens_all + pe_embed[pe_idxs] return tokens_all - + def get_pe_idxs_vectorized(self, batch): """ Compute per cell indices into positional encoding @@ -221,6 +227,7 @@ def __init__(self, cf: Config) -> None: dropout_rate=self.cf.ae_local_dropout_rate, with_qk_lnorm=self.cf.ae_local_with_qk_lnorm, with_flash=self.cf.with_flash_attention, + use_xsa=self.cf.get("use_xsa", False), norm_type=self.cf.norm_type, qk_norm_type=self.cf.get("qk_norm_type", self.cf.norm_type), norm_eps=self.cf.norm_eps, @@ -233,6 +240,7 @@ def __init__(self, cf: Config) -> None: self.cf.ae_local_dim_embed, with_residual=True, dropout_rate=self.cf.ae_local_dropout_rate, + mlp_type=self.cf.get("mlp_type", "mlp"), norm_type=self.cf.norm_type, norm_eps=self.cf.mlp_norm_eps, ) @@ -283,6 +291,7 @@ def __init__(self, cf: Config) -> None: self.cf.ae_global_dim_embed, with_residual=True, dropout_rate=self.cf.ae_adapter_dropout_rate, + mlp_type=self.cf.get("mlp_type", "mlp"), norm_type=self.cf.norm_type, norm_eps=self.cf.mlp_norm_eps, ) @@ -405,6 +414,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: dropout_rate=self.cf.ae_aggregation_dropout_rate, with_qk_lnorm=self.cf.ae_aggregation_with_qk_lnorm, with_flash=self.cf.with_flash_attention, + use_xsa=self.cf.get("use_xsa", False), norm_type=self.cf.norm_type, qk_norm_type=self.cf.get("qk_norm_type", self.cf.norm_type), norm_eps=self.cf.norm_eps, @@ -423,6 +433,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: dropout_rate=self.cf.ae_aggregation_dropout_rate, with_qk_lnorm=self.cf.ae_aggregation_with_qk_lnorm, with_flash=self.cf.with_flash_attention, + use_xsa=self.cf.get("use_xsa", False), norm_type=self.cf.norm_type, qk_norm_type=self.cf.get("qk_norm_type", self.cf.norm_type), norm_eps=self.cf.norm_eps, @@ -437,6 +448,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: with_residual=True, dropout_rate=self.cf.ae_aggregation_dropout_rate, hidden_factor=self.cf.ae_aggregation_mlp_hidden_factor, + mlp_type=self.cf.get("mlp_type", "mlp"), norm_type=self.cf.norm_type, norm_eps=self.cf.mlp_norm_eps, ) @@ -481,6 +493,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: dropout_rate=self.cf.ae_global_dropout_rate, with_qk_lnorm=self.cf.ae_global_with_qk_lnorm, with_flash=self.cf.with_flash_attention, + use_xsa=self.cf.get("use_xsa", False), norm_type=self.cf.norm_type, qk_norm_type=self.cf.get("qk_norm_type", self.cf.norm_type), norm_eps=self.cf.norm_eps, @@ -498,6 +511,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: dropout_rate=self.cf.ae_global_dropout_rate, with_qk_lnorm=self.cf.ae_global_with_qk_lnorm, with_flash=self.cf.with_flash_attention, + use_xsa=self.cf.get("use_xsa", False), norm_type=self.cf.norm_type, qk_norm_type=self.cf.get("qk_norm_type", self.cf.norm_type), norm_eps=self.cf.norm_eps, @@ -513,6 +527,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: with_residual=True, dropout_rate=self.cf.ae_global_dropout_rate, hidden_factor=self.cf.ae_global_mlp_hidden_factor, + mlp_type=self.cf.get("mlp_type", "mlp"), norm_type=self.cf.norm_type, norm_eps=self.cf.mlp_norm_eps, ) @@ -567,12 +582,15 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = dropout_rate=self.cf.fe_dropout_rate, with_qk_lnorm=self.cf.fe_with_qk_lnorm, with_flash=self.cf.with_flash_attention, + use_xsa=self.cf.get("use_xsa", False), norm_type=self.cf.norm_type, qk_norm_type=self.cf.get("qk_norm_type", self.cf.norm_type), dim_aux=dim_aux, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), with_2d_rope=self.cf.get("rope_2D", False), + is_dit=self.cf.get("fe_diffusion_model", False), + dit_is_cond=self.cf.get("fe_diffusion_model_conditioning_type", None) == "ada_ln", ) ) else: @@ -585,12 +603,33 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = dropout_rate=self.cf.fe_dropout_rate, with_qk_lnorm=self.cf.fe_with_qk_lnorm, with_flash=self.cf.with_flash_attention, + use_xsa=self.cf.get("use_xsa", False), norm_type=self.cf.norm_type, qk_norm_type=self.cf.get("qk_norm_type", self.cf.norm_type), dim_aux=dim_aux, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), with_2d_rope=self.cf.get("rope_2D", False), + is_dit=self.cf.get("fe_diffusion_model", False), + dit_is_cond=self.cf.get("fe_diffusion_model_conditioning_type", None) == "ada_ln", + ) + ) + # Add cross-attention block (Q=noised tokens, KV=enc(X_t)) for cross_attn conditioning + if self.cf.get("fe_diffusion_model_conditioning_type", None) == "cross_attn": + self.fe_blocks.append( + MultiCrossAttentionHead( + dim_embed_q=self.cf.ae_global_dim_embed, + dim_embed_kv=self.cf.ae_global_dim_embed, + num_heads=self.cf.fe_num_heads, + dropout_rate=self.cf.fe_dropout_rate, + with_residual=True, + with_qk_lnorm=self.cf.fe_with_qk_lnorm, + with_flash=self.cf.with_flash_attention, + norm_type=self.cf.norm_type, + qk_norm_type=self.cf.get("qk_norm_type", self.cf.norm_type), + norm_eps=self.cf.norm_eps, + attention_dtype=get_dtype(self.cf.attention_dtype), + is_dit=self.cf.get("fe_diffusion_model", False), ) ) # Add MLP block @@ -598,11 +637,15 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = MLP( self.cf.ae_global_dim_embed, self.cf.ae_global_dim_embed, + num_layers=2, with_residual=True, dropout_rate=self.cf.fe_dropout_rate, + mlp_type=self.cf.get("mlp_type", "mlp"), norm_type=self.cf.norm_type, dim_aux=dim_aux, norm_eps=self.cf.mlp_norm_eps, + is_dit=self.cf.get("fe_diffusion_model", False), + dit_is_cond=self.cf.get("fe_diffusion_model_conditioning_type", None) == "ada_ln", ) ) # Optionally, add LayerNorm after i-th layer @@ -620,20 +663,56 @@ def init_weights_final(m): for block in self.fe_blocks: block.apply(init_weights_final) - def forward(self, tokens, fstep, coords=None): + def forward( + self, + tokens: torch.Tensor, + fstep: int, + meta_info: SampleMetaData = None, + noise_emb: torch.Tensor = None, + conditioning: torch.Tensor = None, + coords: torch.Tensor = None, + ) -> torch.Tensor: + # aux_info is forecast step, if not disabled with cf.forecast_with_step_conditioning + # aux_info = torch.tensor([fstep], dtype=torch.float32, device="cuda") if self.training: # Impute noise to the latent state noise_std = self.cf.get("fe_impute_latent_noise_std", 0.0) if noise_std > 0.0: tokens = tokens + torch.randn_like(tokens) * torch.norm(tokens) * noise_std - aux_info = None - for _b_idx, block in enumerate(self.fe_blocks): - if isinstance(block, torch.nn.modules.normalization.LayerNorm): - tokens = checkpoint(block, tokens, use_reentrant=False) - else: - tokens = checkpoint(block, tokens, coords, aux_info, use_reentrant=False) - return tokens + # predict residual to last time step if requested + forecast_residual = self.cf.get("forecast_residual", False) + if forecast_residual: + tokens_in = tokens + + if self.cf.get("fe_diffusion_model", False): + assert noise_emb is not None, ( + "noise_emb must be provided for diffusion model conditioning" + ) + for block in self.fe_blocks: + if isinstance(block, torch.nn.LayerNorm): + tokens = checkpoint(block, tokens, use_reentrant=False) + elif isinstance(block, MultiCrossAttentionHead): + assert conditioning is not None, "conditioning (e.g. enc(X_t)) must be provided for cross_attn conditioning" + tokens = checkpoint(block, tokens, conditioning, noise_emb, use_reentrant=False) + else: + if self.cf.get("fe_diffusion_model_conditioning_type", None) == "ada_ln": + assert conditioning is not None, "conditioning must be provided for diffusion model conditioning" + tokens = checkpoint(block, tokens, coords, noise_emb, conditioning, use_reentrant=False) + elif self.cf.get("fe_diffusion_model_conditioning_type", None) == "cross_attn": + assert conditioning is not None, "conditioning (e.g. enc(X_t)) must be provided for cross_attn conditioning" + tokens = checkpoint(block, tokens, coords, noise_emb, use_reentrant=False) + else: + assert conditioning is None, "conditioning should not be provided when diffusion model conditioning is disabled" + tokens = checkpoint(block, tokens, coords, noise_emb, use_reentrant=False) + else: + for block in self.fe_blocks: + if isinstance(block, torch.nn.LayerNorm): + tokens = checkpoint(block, tokens, use_reentrant=False) + else: + tokens = checkpoint(block, tokens, coords, conditioning, use_reentrant=False) + + return tokens if not forecast_residual else (tokens_in + tokens) class EnsPredictionHead(torch.nn.Module): @@ -699,6 +778,7 @@ def __init__( dim_coord_in, tr_dim_head_proj, tr_mlp_hidden_factor, + tr_mlp_type, softcap, stream_config: dict, ): @@ -720,6 +800,7 @@ def __init__( self.dim_coord_in = dim_coord_in self.tr_dim_head_proj = tr_dim_head_proj self.tr_mlp_hidden_factor = tr_mlp_hidden_factor + self.tr_mlp_type = tr_mlp_type self.softcap = softcap self.tte = torch.nn.ModuleList() @@ -753,6 +834,7 @@ def __init__( dropout_rate=0.1, # Assuming dropout_rate is 0.1 with_qk_lnorm=True, with_flash=self.cf.with_flash_attention, + use_xsa=self.cf.get("use_xsa", False), norm_type=self.cf.norm_type, qk_norm_type=self.cf.get("qk_norm_type", self.cf.norm_type), dim_aux=self.dim_coord_in, @@ -769,6 +851,7 @@ def __init__( with_residual=True, hidden_factor=self.tr_mlp_hidden_factor, dropout_rate=0.1, # Assuming dropout_rate is 0.1 + mlp_type=self.tr_mlp_type, norm_type=self.cf.norm_type, dim_aux=(self.dim_coord_in if self.cf.pred_mlp_adaln else None), norm_eps=self.cf.mlp_norm_eps, @@ -806,8 +889,9 @@ def __init__( dim_coord_in, tr_dim_head_proj, tr_mlp_hidden_factor, + tr_mlp_type, softcap, - stream_name: str, + stream_config: dict, ): """ Initialize the TargetPredictionEngine with the configuration. @@ -830,13 +914,14 @@ def __init__( LayerNorm that does not scale after the layer is applied """ super(TargetPredictionEngine, self).__init__() - self.name = f"TargetPredictionEngine_{stream_name}" + self.name = f"TargetPredictionEngine_{stream_config['name']}" self.cf = cf self.dims_embed = dims_embed self.dim_coord_in = dim_coord_in self.tr_dim_head_proj = tr_dim_head_proj self.tr_mlp_hidden_factor = tr_mlp_hidden_factor + self.tr_mlp_type = tr_mlp_type self.softcap = softcap # For backwards compatibility @@ -876,6 +961,7 @@ def __init__( with_self_attn=False, with_adanorm=False, with_mlp=False, + mlp_type=self.tr_mlp_type, attention_kwargs=attention_kwargs, ) ) @@ -888,6 +974,8 @@ def __init__( attention_kwargs=attention_kwargs, with_adanorm=True, dropout_rate=0.1, + mlp_type=self.tr_mlp_type, + use_xsa=self.cf.get("use_xsa", False), ) ) elif self.cf.decoder_type == "CrossAttentionConditioning": @@ -901,6 +989,8 @@ def __init__( with_adanorm=False, with_mlp=True, dropout_rate=0.1, + mlp_type=self.tr_mlp_type, + use_xsa=self.cf.get("use_xsa", False), attention_kwargs=attention_kwargs, ) ) @@ -915,6 +1005,8 @@ def __init__( with_adanorm=True, with_mlp=True, dropout_rate=0.1, + mlp_type=self.tr_mlp_type, + use_xsa=self.cf.get("use_xsa", False), attention_kwargs=attention_kwargs, ) ) @@ -930,6 +1022,7 @@ def __init__( attention_kwargs=attention_kwargs, tr_dim_head_proj=tr_dim_head_proj, tr_mlp_hidden_factor=tr_mlp_hidden_factor, + tr_mlp_type=tr_mlp_type, mlp_norm_eps=self.cf.mlp_norm_eps, ) ) @@ -1033,6 +1126,7 @@ def __init__( dropout_rate=dropout_rate, with_qk_lnorm=with_qk_lnorm, with_flash=self.global_cf.with_flash_attention, + use_xsa=self.global_cf.get("use_xsa", False), norm_type=self.global_cf.norm_type, qk_norm_type=self.global_cf.qk_norm_type, # dim_aux=dim_aux, @@ -1048,6 +1142,7 @@ def __init__( hidden_factor=4, with_residual=True, dropout_rate=dropout_rate, + mlp_type=loss_conf.get("mlp_type", self.global_cf.get("mlp_type", "mlp")), norm_type=self.global_cf.norm_type, # dim_aux=dim_aux, norm_eps=self.global_cf.mlp_norm_eps, @@ -1087,7 +1182,15 @@ def forward(self, x: LatentState): class LatentPredictionHeadMLP(nn.Module): - def __init__(self, name, in_dim: int, loss_conf, use_class_token: bool, use_patch_token: bool): + def __init__( + self, + name, + in_dim: int, + loss_conf, + use_class_token: bool, + use_patch_token: bool, + default_mlp_type: str = "mlp", + ): super().__init__() self.name = name @@ -1102,7 +1205,13 @@ def __init__(self, name, in_dim: int, loss_conf, use_class_token: bool, use_patc self.use_patch_token = use_patch_token # Create an MLP block - self.blocks = MLP(in_dim, out_dim, num_layers, hidden_factor) + self.blocks = MLP( + in_dim, + out_dim, + num_layers, + hidden_factor, + mlp_type=loss_conf.get("mlp_type", default_mlp_type), + ) def forward(self, x: LatentState): outputs = [] diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index 1f7b8df5d..54dcefca2 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -7,11 +7,28 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +# ---------------------------------------------------------------------------- +# Third-Party Attribution: facebookresearch/DiT (Scalable Diffusion Models with Transformers (DiT)) +# This file incorporates code originally from the 'facebookresearch/DiT' repository, +# with adaptations. +# +# The original code is licensed under CC-BY-NC. +# ---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- +# Third-Party Attribution: google-deepmind/graphcast (several associated papers) +# This file incorporates code originally from the 'google-deepmind/graphcast' repository, +# with adaptations. +# +# The original code is licensed under Apache 2.0. +# Original Copyright 2024 DeepMind Technologies Limited. +# ---------------------------------------------------------------------------- + import torch import torch.nn as nn -from weathergen.model.norms import AdaLayerNorm, RMSNorm +from weathergen.model.norms import AdaLNZero, AdaLayerNorm, RMSNorm, SwiGLU class NamedLinear(torch.nn.Module): @@ -42,7 +59,10 @@ def __init__( norm_type="LayerNorm", dim_aux=None, norm_eps=1e-5, + mlp_type="mlp", name: str | None = None, + is_dit=False, + dit_is_cond=False, ): """Constructor""" @@ -55,37 +75,98 @@ def __init__( self.with_residual = with_residual self.with_aux = dim_aux is not None + self.is_dit = is_dit + self.dit_is_cond = dit_is_cond + self.mlp_type = mlp_type.lower() dim_hidden = int(dim_in * hidden_factor) - self.layers = torch.nn.ModuleList() + if self.mlp_type not in {"mlp", "swiglu"}: + raise ValueError(f"Unsupported mlp_type: {mlp_type}") - norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm + if self.mlp_type == "swiglu": + # Align with the standard LLaMA-style SwiGLU hidden-width rule. + dim_hidden = max(1, int(2 * dim_hidden / 3)) - if pre_layer_norm: - self.layers.append( - norm(dim_in, eps=norm_eps) - if dim_aux is None - else AdaLayerNorm(dim_in, dim_aux, norm_eps=norm_eps) - ) + self.layers = torch.nn.ModuleList() - self.layers.append(torch.nn.Linear(dim_in, dim_hidden)) - self.layers.append(nonlin()) - self.layers.append(torch.nn.Dropout(p=dropout_rate)) + norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm - for _ in range(num_layers - 2): - self.layers.append(torch.nn.Linear(dim_hidden, dim_hidden)) + if is_dit: + if dit_is_cond: + assert dim_aux is not None, "For DIT, need to provide dim_aux for ada layer norm" + assert with_residual, "DIT attention should always have residual connection" + self.lnorm = AdaLNZero(dim_in, dim_aux, norm_eps=norm_eps) if dim_aux is not None else norm(dim_in, eps=norm_eps) + self.noise_conditioning = LinearNormConditioning(dim_in) + self.noise_conditioning = LinearNormConditioning(dim_in) + elif dim_aux is not None: + self.lnorm = AdaLayerNorm(dim_in, dim_aux, norm_eps=norm_eps) + else: + self.lnorm = norm(dim_in, eps=norm_eps) + + # TODO: The below should be consolidated – implementing in layer list for backward compatibility + if not is_dit: + self.layers.append(self.lnorm) + + if self.mlp_type == "swiglu": + self.layers.append(torch.nn.Linear(dim_in, 2 * dim_hidden)) + self.layers.append(SwiGLU()) + self.layers.append(torch.nn.Dropout(p=dropout_rate)) + for _ in range(num_layers - 2): + self.layers.append(torch.nn.Linear(dim_hidden, 2 * dim_hidden)) + self.layers.append(SwiGLU()) + self.layers.append(torch.nn.Dropout(p=dropout_rate)) + else: + self.layers.append(torch.nn.Linear(dim_in, dim_hidden)) self.layers.append(nonlin()) self.layers.append(torch.nn.Dropout(p=dropout_rate)) + for _ in range(num_layers - 2): + self.layers.append(torch.nn.Linear(dim_hidden, dim_hidden)) + self.layers.append(nonlin()) + self.layers.append(torch.nn.Dropout(p=dropout_rate)) + self.layers.append(torch.nn.Linear(dim_hidden, dim_out)) + # TODO: expanded args, must check dependencies (previously aux = args[-1]) def forward(self, *args): - x, x_in, aux = args[0], args[0], args[-1] + x, x_in = args[0], args[0] + if not self.is_dit: + if len(args) < 2 and self.with_aux: + raise ValueError("Auxiliary input required but not provided") + if len(args) == 2: + ada_ln_aux = args[1] + elif len(args) > 2: + ada_ln_aux = args[-1] + else: + if self.dit_is_cond: + assert len(args) == 4, "DIT with cond gets 4 args" + ada_ln_aux = args[-1] + noise_emb = args[-2] + else: + assert len(args) == 3, "DIT without cond gets 3 args" + noise_emb = args[-1] + - for i, layer in enumerate(self.layers): - x = layer(x, aux) if (i == 0 and self.with_aux) else layer(x) + if self.is_dit: + if self.dit_is_cond: + assert ada_ln_aux is not None, "Need auxiliary input for conditional DIT" + x, cond_gate = self.lnorm(x, ada_ln_aux) + else: + x = self.lnorm(x) + cond_gate = 1 + assert noise_emb is not None, "Need noise embedding for noise conditioning in DIT" + x, noise_gate = self.noise_conditioning(x, noise_emb) + gate = cond_gate * noise_gate + + for layer in self.layers: + if isinstance(layer, AdaLayerNorm): + x = layer(x, ada_ln_aux) + else: + x = layer(x) if self.with_residual: + if self.is_dit: + x = x * gate if x.shape[-1] == x_in.shape[-1]: x = x_in + x else: @@ -93,3 +174,37 @@ def forward(self, *args): x = x + x_in.repeat([*[1 for _ in x.shape[:-1]], x.shape[-1] // x_in.shape[-1]]) return x + + +# NOTE: Inspired by GenCast/DiT. +class LinearNormConditioning(torch.nn.Module): + """Module for norm conditioning, adapted from GenCast with additional gate parameter from DiT. + + Conditions the normalization of `inputs` by applying a linear layer to the + `norm_conditioning` which produces the scale and offset for each channel. + """ + + def __init__(self, latent_space_dim: int, noise_emb_dim: int = 512, dtype=torch.bfloat16): + super().__init__() + self.dtype = dtype + + self.conditional_linear_layer = torch.nn.Linear( + in_features=noise_emb_dim, + out_features=3 * latent_space_dim, + ) + # Optional: initialize weights similar to TruncatedNormal(stddev=1e-8) + torch.nn.init.normal_(self.conditional_linear_layer.weight, std=1e-8) + torch.nn.init.zeros_(self.conditional_linear_layer.bias) + + def forward(self, inputs, noise_emb): + conditional_scale_offset = self.conditional_linear_layer(noise_emb.to(self.dtype)) + scale_minus_one, offset, gate = torch.chunk(conditional_scale_offset, 3, dim=-1) + scale = scale_minus_one + 1.0 + + # Reshape scale and offset for broadcasting if needed + while scale.dim() < inputs.dim(): + scale = scale.unsqueeze(1) + offset = offset.unsqueeze(1) + return (inputs * scale + offset).to( + self.dtype + ), gate # TODO: check if to(self.dtype) needed here diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index a569feb15..20590e5bf 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -23,6 +23,7 @@ from weathergen.common.config import Config from weathergen.datasets.batch import ModelBatch from weathergen.datasets.utils import healpix_verts_rots, r3tos2 +from weathergen.model.diffusion import DiffusionForecastEngine from weathergen.model.encoder import EncoderModule from weathergen.model.engines import ( MAX_NUMBER_TOKENS_LOCAL_PER_CELL, @@ -313,6 +314,8 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord """ super(Model, self).__init__() + self._noise = None + self.healpix_level = cf.healpix_level self.num_healpix_cells = 12 * 4**self.healpix_level @@ -343,6 +346,9 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord self.register_token_idxs = list(range(cf.num_register_tokens)) self.aux_token_idxs = list(range(cf.num_register_tokens + cf.num_class_tokens)) self.num_aux_tokens = cf.num_register_tokens + cf.num_class_tokens + # One-shot flag to avoid log spam when warning about an unsupported + # diffusion-inference + multi-step-rollout combination. + self._warned_diffusion_multi_step = False def _create_latent_pred_head( self, global_cfg, name, loss_cfg, use_class_token, use_patch_token @@ -354,6 +360,7 @@ def _create_latent_pred_head( loss_cfg, use_class_token=use_class_token, use_patch_token=use_patch_token, + default_mlp_type=global_cfg.get("mlp_type", "mlp"), ) elif loss_cfg["head"].lower() == "transformer": return LatentPredictionHeadTransformer( @@ -377,9 +384,20 @@ def create(self) -> "Model": cf, self.sources_size, self.targets_num_channels, self.targets_coords_size ) + # Initialize forecasting engine: standard or diffusion-wrapped mode_cfg = cf.training_config if cf.fe_num_blocks > 0: - self.forecast_engine = ForecastingEngine(cf, mode_cfg, self.num_healpix_cells) + if cf.get("fe_diffusion_model_conditioning_type", None) == "ada_ln": + assert cf.diffusion_conditioning_embed_dim is not None, ( + "Diffusion conditioning embedding dimension must be specified when using diffusion model conditioning" + ) + self.forecast_engine = ForecastingEngine(cf, mode_cfg, self.num_healpix_cells, dim_aux=self.cf.diffusion_conditioning_embed_dim) + else: + self.forecast_engine = ForecastingEngine(cf, mode_cfg, self.num_healpix_cells) + if cf.get("fe_diffusion_model", False): + self.forecast_engine = DiffusionForecastEngine( + cf, self.num_healpix_cells, forecast_engine=self.forecast_engine + ) else: self.forecast_engine = IdentityEngine() @@ -421,6 +439,7 @@ def create(self) -> "Model": tr_mlp_hidden_factor = ( tr["mlp_hidden_factor"] if "mlp_hidden_factor" in tr else 2 ) + tr_mlp_type = tr.get("mlp_type", cf.get("mlp_type", "mlp")) tr_dim_head_proj = tr["dim_head_proj"] if "dim_head_proj" in tr else None softcap = tr["softcap"] if "softcap" in tr else 0.0 @@ -448,6 +467,7 @@ def create(self) -> "Model": hidden_factor=8, with_residual=False, dropout_rate=dropout_rate, + mlp_type=self.cf.get("mlp_type", "mlp"), norm_eps=self.cf.mlp_norm_eps, name=f"embed_target_coords_{stream_name}", ) @@ -474,6 +494,7 @@ def create(self) -> "Model": dim_coord_in, tr_dim_head_proj, tr_mlp_hidden_factor, + tr_mlp_type, softcap, stream_config=si, ) @@ -624,7 +645,13 @@ def print_num_parameters(self) -> None: num_params_latent_heads = get_num_parameters(self.latent_heads) num_params_latent_heads += get_num_parameters(self.latent_pre_norm) - num_params_fe = get_num_parameters(self.forecast_engine.fe_blocks) + num_params_fe = ( + get_num_parameters( + self.forecast_engine.net.fe_blocks + if cf.fe_diffusion_model + else self.forecast_engine.fe_blocks + ) + ) mdict = self.embed_target_coords num_params_embed_tcs = [ @@ -701,11 +728,22 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # recover batch dimension and separate input_steps shape = (len(batch), batch.get_num_steps(), *tokens.shape[1:]) - # collapse along input step dimension - tokens = tokens.reshape(shape).sum(axis=1) + # Reshape tokens to [B, T, ...] + tokens = tokens.reshape(shape) + + if self.cf.get("fe_diffusion_model_conditioning", None) == "forecast": + tokens = tokens.reshape(shape) + conditioning_tokens = tokens[:, -2] # TODO: enable longer history for conditioning + # X_t (last step) is the diffusion denoising target; X_{t-1} is the conditioning context. + batch.samples[0].meta_info["ERA5"].params["conditioning_tokens"] = conditioning_tokens + # self.forecast_engine._pending_target_tokens = diffusion_target_tokens + tokens = tokens[:, -1] + else: + tokens = tokens.sum(axis=1) # Allow for pushforward trick p_fwd = self.cf.training_config.get("forecast", {}).get("pushforward", False) + # roll-out in latent space, iterate and generate output over requested output steps for step in batch.get_output_idxs(): without_grad = p_fwd and self.training and step != max(batch.get_output_idxs()) @@ -714,14 +752,71 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: tokens = self.forecast_engine(tokens, step, model_params.rope_coords) continue - tokens = self.forecast_engine(tokens, step, model_params.rope_coords) + # apply forecasting engine + tokens = self.forecast_engine( + tokens, + step, + meta_info=batch.samples[0].meta_info, + coords=model_params.rope_coords, + ) + + # Diffusion inference returns the per-ODE-step intermediate denoised tokens as a + # list. Treat each intermediate state as its own forecast step in the output so the + # full denoising trajectory can be inspected downstream. The original `step` is + # still used to look up target coordinates (they share the same physical timestamp). + if isinstance(tokens, list): + # Diffusion inference currently only supports a single physical forecast + # step (forecast.num_steps=1); the per-ODE-step trajectory consumes the + # ModelOutput fstep dimension. Multi-step autoregressive rollouts on top of + # diffusion are not implemented yet. + if ( + len(batch.get_output_idxs()) > 1 + and not self._warned_diffusion_multi_step + ): + logger.warning( + "Diffusion inference is being run with forecast.num_steps=%d (>1). " + "Only a single forecast step is supported in this mode; the " + "per-ODE-step denoising trajectory will overwrite later forecast " + "steps in the model output.", + len(batch.get_output_idxs()), + ) + self._warned_diffusion_multi_step = True + # Resize output to fit the diffusion trajectory. + output = self._reindex_output_for_trajectory(output, len(tokens)) + for i, toks in enumerate(tokens): + output = self.predict_decoders( + model_params, step, toks, batch, output, out_step=i + ) + output = self.predict_latent( + model_params, step, toks, batch, output, out_step=i + ) + # Feed the final denoised state back as conditioning for the next step. + # Pass tokens[-1] forward so inference diagnostics have a reference point; + # inference_forward always starts from pure noise regardless. + batch.samples[0].meta_info["ERA5"].params["conditioning_tokens"] = tokens[-1] + tokens = None #NOTE: This is precautionary, might need to be handled differently. It should not be the same as conditioning tokens. + continue + # decoder predictions output = self.predict_decoders(model_params, step, tokens, batch, output) + # latent predictions (raw and with SSL heads) output = self.predict_latent(model_params, step, tokens, batch, output) return output + @staticmethod + def _reindex_output_for_trajectory(output: ModelOutput, n_steps: int) -> ModelOutput: + """ + Resize a ModelOutput to hold ``n_steps`` forecast steps, preserving any latent entries + that were already attached to fstep 0 (e.g. encoder posteriors). + """ + new_output = ModelOutput(n_steps) + if len(output.latent) > 0: + for k, v in output.latent[0].items(): + new_output.add_latent_prediction(0, k, v) + return new_output + def predict_latent( self, model_params: ModelParams, @@ -729,19 +824,23 @@ def predict_latent( tokens: torch.Tensor, batch: ModelBatch, output: ModelOutput, + out_step: int | None = None, ) -> ModelOutput: """ Compute latent predictions """ + if out_step is None: + out_step = step + # safe latent prediction tokens_post_norm = self.latent_pre_norm(tokens) if step == 0 else None latent_state = self.tokens_to_latent_state(tokens_post_norm, tokens) - output.add_latent_prediction(step, "latent_state", latent_state) + output.add_latent_prediction(out_step, "latent_state", latent_state) # latent predictions for SSL training for name, head in self.latent_heads.items(): - output.add_latent_prediction(step, name, head(latent_state)) + output.add_latent_prediction(out_step, name, head(latent_state)) return output @@ -752,6 +851,7 @@ def predict_decoders( tokens: torch.Tensor, batch: ModelBatch, output: ModelOutput, + out_step: int | None = None, ) -> ModelOutput: """ Compute decoder-based predictions @@ -770,9 +870,13 @@ def predict_decoders( Prediction output tokens in physical representation for each target_coords. """ # Empty dicts evaluate to False in python + # breakpoint() if not self.pred_heads: return output + if out_step is None: + out_step = step + # remove register and class tokens tokens = tokens[:, self.num_aux_tokens :] @@ -787,6 +891,8 @@ def predict_decoders( ) tokens_nbors_lens[0] = 0 + # breakpoint() + # pair with tokens from assimilation engine to obtain target tokens for stream_name in self.stream_names: # extract target coords for current stream and fstep and convert to one tensor @@ -797,6 +903,7 @@ def predict_decoders( t_coords_lens = [len(t) for t in t_coords] t_coords = torch.cat(t_coords) + # breakpoint() if len(t_coords) == 0: continue @@ -848,6 +955,7 @@ def predict_decoders( # recover batch dimension (ragged, so as list) pred = torch.split(pred, t_coords_lens, dim=1) - output.add_physical_prediction(step, stream_name, pred) + + output.add_physical_prediction(out_step, stream_name, pred) return output diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index e200d674b..b27f8bc0e 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -70,6 +70,7 @@ def init_model_and_shard( find_unused_parameters=cf.get("ddp_find_unused_parameters", True), gradient_as_bucket_view=True, bucket_cap_mb=512, + static_graph=cf.get("ddp_static_graph", True), ) elif with_ddp and with_fsdp: @@ -105,7 +106,15 @@ def init_model_and_shard( if isinstance(module, modules_to_shard): fully_shard(module, **fsdp_kwargs) - for module in model.forecast_engine.fe_blocks.modules(): + if cf.fe_diffusion_model: + model_fe_blocks = model.forecast_engine.net.fe_blocks + else: + model_fe_blocks = model.forecast_engine.fe_blocks + for module in model_fe_blocks.modules(): + if isinstance(module, modules_to_shard): + fully_shard(module, **fsdp_kwargs) + + for module in model.latent_heads.modules(): if isinstance(module, modules_to_shard): # reshard_after_forward=False keeps FE parameters unsharded # during the multi-step rollout loop. @@ -189,10 +198,30 @@ def load_model(cf, model, device, run_id: str, mini_epoch=-1): is_model_sharded = cf.with_ddp and cf.with_fsdp if is_model_sharded: + # model_has_prefix_module = list(model.state_dict().keys())[0].split(".")[0] == "module" + # params_has_prefix_module = list(params.keys())[0].split(".")[0] == "module" + # if model_has_prefix_module and not params_has_prefix_module: + # # add "module." prefix + # params_temp = {} + # for k in params.keys(): + # params_temp["module." + k] = params[k] + # params = params_temp + # elif not model_has_prefix_module and params_has_prefix_module: + # # remove "module." prefix + # params_temp = {} + # for k in params.keys(): + # params_temp[k.replace("module.", "")] = params[k] + # params = params_temp + meta_sharded_sd = model.state_dict() maybe_sharded_sd = {} for param_name, full_tensor in params.items(): sharded_meta_param = meta_sharded_sd.get(param_name) + if ( + sharded_meta_param is None + or type(sharded_meta_param) is not torch.distributed.tensor.DTensor + ): + continue sharded_tensor = distribute_tensor( full_tensor, sharded_meta_param.device_mesh, diff --git a/src/weathergen/model/norms.py b/src/weathergen/model/norms.py index 4ecbfa80a..0526c6f90 100644 --- a/src/weathergen/model/norms.py +++ b/src/weathergen/model/norms.py @@ -60,6 +60,40 @@ def forward(self, x): output = self._norm(x.float()).type_as(x) return output * self.weight +class AdaLNZero(torch.nn.Module): + """ + AdaLayerNorm with zero initialization and with additional gate parameter + """ + + def __init__( + self, dim_embed_x, dim_aux, norm_elementwise_affine: bool = False, norm_eps: float = 1e-5 + ): + super().__init__() + + # simple 2-layer MLP for embedding auxiliary information + self.embed_aux = torch.nn.ModuleList() + self.embed_aux.append(torch.nn.Linear(dim_aux, 6 * dim_aux)) + self.embed_aux.append(torch.nn.SiLU()) + self.embed_aux.append(torch.nn.Linear(6 * dim_aux, 3 * dim_embed_x)) + + self.norm = torch.nn.LayerNorm( + dim_embed_x, + eps=norm_eps, + elementwise_affine=norm_elementwise_affine, + ) + # Zero-initialize the final modulation layer. + nn.init.zeros_(self.embed_aux[-1].weight) + nn.init.zeros_(self.embed_aux[-1].bias) + + def forward(self, x: torch.Tensor, aux: torch.Tensor | None = None) -> torch.Tensor: + for block in self.embed_aux: + aux = block(aux) + scale, shift, gate = aux.chunk(3, dim=-1) + + x = self.norm(x) * (1 + scale) + shift + + return x, gate + class AdaLayerNorm(torch.nn.Module): """ diff --git a/src/weathergen/model/utils.py b/src/weathergen/model/utils.py index 7dd2060bb..865d826d3 100644 --- a/src/weathergen/model/utils.py +++ b/src/weathergen/model/utils.py @@ -49,6 +49,9 @@ def apply_fct_to_blocks(model, blocks, fct): # avoid the whole model element which has name '' if (re.fullmatch(blocks, name) is not None) and (name != ""): fct(module) + logger.info(f"Applied function {fct.__name__} to block {name}") + else: + logger.info(f"Did not apply function {fct.__name__} to block {name}") class ActivationFactory: diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index 7995b5864..a1ad4d1c3 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -96,6 +96,8 @@ def run_inference(args): cli_overwrite, ) cf = config.set_run_id(cf, args.run_id, args.reuse_run_id) + cf.data_loading.rng_seed = 42 + cf.stage = args.stage devices = Trainer.init_torch() cf = Trainer.init_ddp(cf) @@ -134,6 +136,7 @@ def run_continue(args): cli_overwrite, ) cf = config.set_run_id(cf, args.run_id, args.reuse_run_id) + cf.stage = args.stage mp_method = cf.general.get("multiprocessing_method", "fork") devices = Trainer.init_torch(multiprocessing_method=mp_method) @@ -163,13 +166,12 @@ def run_train(args): """ cli_overwrite = config.from_cli_arglist(args.options) - cf = config.load_merge_configs( args.private_config, None, None, args.base_config, *args.config, cli_overwrite ) cf = config.set_run_id(cf, args.run_id, False) - cf.data_loading.rng_seed = int(time.time()) + cf.stage = args.stage mp_method = cf.general.get("multiprocessing_method", "fork") devices = Trainer.init_torch(multiprocessing_method=mp_method) cf = Trainer.init_ddp(cf) diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index 90907cf8b..bf84881a7 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -86,18 +86,18 @@ def compute_loss( ): losses_all = defaultdict(dict) stddev_all = defaultdict(dict) - loss = torch.tensor(0.0, requires_grad=True) + loss = torch.tensor(0.0, device=self.device, requires_grad=True) for loss_term_name, calc_term in self.loss_calculators.items(): target = targets_and_aux[loss_term_name] for weight, calculator in calc_term: + loss_values = calculator.compute_loss( + preds=preds, targets=target, metadata=metadata + ) if weight > 0.0: - loss_values = calculator.compute_loss( - preds=preds, targets=target, metadata=metadata - ) loss = loss + weight * loss_values.loss - losses_all[calculator.name] = loss_values.losses_all - losses_all[calculator.name]["loss_avg"] = loss_values.loss - stddev_all[calculator.name] = loss_values.stddev_all + losses_all[calculator.name] = loss_values.losses_all + losses_all[calculator.name]["loss_avg"] = loss_values.loss + stddev_all[calculator.name] = loss_values.stddev_all # Keep histories for logging self.loss_hist += [loss.detach()] diff --git a/src/weathergen/train/loss_modules/__init__.py b/src/weathergen/train/loss_modules/__init__.py index 00a8b7b31..3aded1edd 100644 --- a/src/weathergen/train/loss_modules/__init__.py +++ b/src/weathergen/train/loss_modules/__init__.py @@ -7,7 +7,8 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +from .loss_module_latent_diffusion import LossLatentDiffusion from .loss_module_physical import LossPhysical from .loss_module_ssl import LossLatentSSLStudentTeacher -__all__ = [LossPhysical, LossLatentSSLStudentTeacher] +__all__ = [LossPhysical, LossLatentSSLStudentTeacher, LossLatentDiffusion] diff --git a/src/weathergen/train/loss_modules/loss_functions.py b/src/weathergen/train/loss_modules/loss_functions.py index 192101278..40f8d2ce8 100644 --- a/src/weathergen/train/loss_modules/loss_functions.py +++ b/src/weathergen/train/loss_modules/loss_functions.py @@ -206,8 +206,8 @@ def lp_loss( def mse( target: torch.Tensor, pred: torch.Tensor, - weights_channels: torch.Tensor | None, - weights_points: torch.Tensor | None, + weights_channels: torch.Tensor | None = None, + weights_points: torch.Tensor | None = None, ): """ Computes the mean squared error (mse). diff --git a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py new file mode 100644 index 000000000..a89c9d6b8 --- /dev/null +++ b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py @@ -0,0 +1,149 @@ +# ruff: noqa: T201 + +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import torch +from omegaconf import DictConfig +from torch import Tensor + +import weathergen.train.loss_modules.loss_functions as loss_fns +from weathergen.train.loss_modules.loss_module_base import LossModuleBase, LossValues +from weathergen.utils.train_logger import Stage + +_logger = logging.getLogger(__name__) + + +class LossLatentDiffusion(LossModuleBase): + """ + Calculates loss in latent space. + """ + + def __init__( + self, + cf: DictConfig, + mode_cfg: DictConfig, + stage: Stage, + device: str, + **loss_fcts: dict, + ): + LossModuleBase.__init__(self) + self.cf = cf + self.stage = stage + self.device = device + self.name = "LossLatentDiff" + + self.sigma_data = self.cf.sigma_data + self.rho = self.cf.rho + self.p_mean = self.cf.p_mean + self.p_std = self.cf.p_std + + # Dynamically load loss functions based on configuration and stage + self.loss_fcts = [ + [ + getattr(loss_fns, name), + params.get("weight", 1.0), + name, + ] + for name, params in loss_fcts.items() + ] + + self.random_target = None + + def _get_noise_weight(self, eta): + sigma = (eta * self.p_std + self.p_mean).exp() + return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + + def _get_fstep_weights(self, forecast_steps): + timestep_weight_config = self.cf.get("timestep_weight") + if timestep_weight_config is None: + return [1.0 for _ in range(forecast_steps)] + weights_timestep_fct = getattr(loss_fns, timestep_weight_config[0]) + return weights_timestep_fct(forecast_steps, timestep_weight_config[1]) + + def _loss_per_loss_function( + self, + loss_fct, + target: torch.Tensor, + pred: torch.Tensor, + noise_weight: torch.Tensor = 1.0, + ): + """ + Compute loss for given loss function + """ + + loss, loss_chs = loss_fct(target=target, pred=pred) + loss = noise_weight * loss + + return loss + + def compute_loss(self, preds: dict, targets: dict, **kwargs) -> LossValues: + losses_all: dict[str, Tensor] = { + f"{self.name}.{loss_fct_name}": torch.zeros( + 1, + device=self.device, + ) + for _, _, loss_fct_name in self.loss_fcts + } + + pred_tokens_all = [pl["latent_state"].z_pre_norm for pl in preds.latent if pl] + target_tokens_all = [latent["diffusion_latent"] for latent in targets.latent if latent] + + eta = torch.tensor( + [targets.aux_outputs["noise_level_rn"]], device=self.device, dtype=torch.float32 + ) + fsteps = len(target_tokens_all) + + # During validation, use unweighted loss (no noise-level scaling) + noise_weight = 1.0 if self.stage == "val" else self._get_noise_weight(eta) + fstep_loss_weights = self._get_fstep_weights(fsteps) + + loss_fsteps = torch.tensor(0.0, device=self.device, requires_grad=True) + ctr_fsteps = 0 + for target_tokens, pred_tokens, fstep_loss_weight in zip( + target_tokens_all, pred_tokens_all, fstep_loss_weights, strict=True + ): + # the first entry in tokens_all is the source itself, so skip it + loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) + ctr_loss_fcts = 0 + # if forecast_offset==0, then the timepoints correspond. + # Otherwise targets don't encode the source timestep, so we don't need to skip + for loss_fct, loss_fct_weight, loss_fct_name in self.loss_fcts: + loss_lfct = self._loss_per_loss_function( + loss_fct, + target=target_tokens, + pred=pred_tokens, + noise_weight=noise_weight, + ) + + losses_all[f"{self.name}.{loss_fct_name}"] += loss_lfct # TODO: break into fsteps + + # Add the weighted and normalized loss from this loss function to the total + # batch loss + loss_fstep = loss_fstep + (loss_fct_weight * loss_lfct) + ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0 + + loss_fsteps = loss_fsteps + ( + loss_fstep * fstep_loss_weight / ctr_loss_fcts if ctr_loss_fcts > 0 else 0 + ) + ctr_fsteps += 1 if ctr_loss_fcts > 0 else 0 + + loss = loss_fsteps / (ctr_fsteps if ctr_fsteps > 0 else 1.0) + + for _, loss_values in losses_all.items(): + loss_values /= ctr_fsteps if ctr_fsteps > 0 else 1.0 + loss_values[loss_values == 0.0] = torch.nan + + return LossValues( + loss=loss, + losses_all=losses_all, + stddev_all={"latent": torch.tensor(torch.nan).to(self.device)}, + ) diff --git a/src/weathergen/train/target_and_aux_diffusion.py b/src/weathergen/train/target_and_aux_diffusion.py new file mode 100644 index 000000000..4079dcc65 --- /dev/null +++ b/src/weathergen/train/target_and_aux_diffusion.py @@ -0,0 +1,90 @@ +from typing import Any + +import torch + +from weathergen.datasets.batch import ModelBatch +from weathergen.model.model import ModelParams +from weathergen.model.utils import apply_fct_to_blocks, freeze_weights, set_to_eval +from weathergen.train.target_and_aux_module_base import ( + TargetAndAuxModuleBase, + TargetAuxOutput, +) + + +class DiffusionLatentTargetEncoder(TargetAndAuxModuleBase): + def __init__(self, encoder, is_model_sharded=True): + # Todo: make sure this is a frozen clone or forward without gradients in compute() + self.encoder = encoder + + apply_fct_to_blocks(self.encoder, ".*", freeze_weights) + apply_fct_to_blocks(self.encoder, ".*", set_to_eval) + + self.is_model_sharded = is_model_sharded + self._fixed_noise_level: float | None = None + # Build a name → param map once + self.src_params = dict(self.encoder.named_parameters()) + + # self.reset() + + @torch.no_grad() + def reset(self): + """ + This function resets the EMAModel to be the same as the Model. + + It operates via the state_dict to be able to deal with sharded tensors in case + FSDP2 is used. + """ + # TODO: This needs fixing, might need to use apply_fct_to_blocks as in init() + + self.encoder.to_empty(device="cuda") + for p in self.encoder.parameters(): + p.requires_grad = False + maybe_sharded_sd = self.encoder.state_dict() + mkeys, ukeys = self.encoder.load_state_dict(maybe_sharded_sd, strict=False, assign=False) + self.encoder.eval() + + def update_state_post_opt_step(self, istep, batch, model, **kwargs) -> None: + if self.is_model_sharded: + self.encoder.reshard() + + def compute( + self, + istep: int, + batch: ModelBatch, + model_params: ModelParams, + model: torch.nn.Module, + *args, + **kwargs, + ) -> tuple[Any, Any]: + # During validation (model in eval mode), use fixed noise level + # so that sigma = exp(eta * p_std + p_mean) is deterministic + if model.training: + noise_level_rn = ( + batch.samples[0].meta_info["ERA5"].params["noise_level_rn"] + ) # TODO: adjust for multiple streams + else: + noise_level_rn = self._fixed_noise_level if self._fixed_noise_level is not None else 0.0 + + # TODO: check if there are scenarios where the encoder needs to be set to eval + with torch.no_grad(): + self.encoder.encoder.eval() # NOTE: might be redundant + tokens, posteriors = self.encoder.encoder(model_params=model_params, batch=batch) + shape = (len(batch), batch.get_num_steps(), *tokens.shape[1:]) + tokens_multi = tokens.reshape(shape) + # NOTE: must not set to train afterwards unless it was already in train + + output_idxs = batch.get_output_idxs() + assert len(output_idxs) > 0 + + # The encoder produces a single target latent (tokens_multi[:, -1]) regardless of + # how many forecast steps are requested. Initialise with a single slot so that + # _expand_targets_to_match_preds (in trainer.py) replicates the target across all + # forecast steps automatically — both for T-step autoregressive rollouts and for the + # single-step ODE-trajectory case. + target_aux_output = TargetAuxOutput(1, [0]) + target_aux_output.add_latent_target(0, "diffusion_latent", tokens_multi[:, -1]) + + # TODO: write function in TargetAuxOutput class + target_aux_output.aux_outputs = {"noise_level_rn": noise_level_rn} + + return target_aux_output diff --git a/src/weathergen/train/target_and_aux_utils.py b/src/weathergen/train/target_and_aux_utils.py index efaff18bc..47f9551db 100644 --- a/src/weathergen/train/target_and_aux_utils.py +++ b/src/weathergen/train/target_and_aux_utils.py @@ -1,8 +1,10 @@ import omegaconf +import torch from weathergen.common.config import Config, merge_configs from weathergen.model.ema import EMAModel from weathergen.model.model_interface import init_model_and_shard +from weathergen.train.target_and_aux_diffusion import DiffusionLatentTargetEncoder from weathergen.train.target_and_aux_module_base import PhysicalTargetAndAux from weathergen.train.target_and_aux_ssl_teacher import EMATeacher, FrozenTeacher from weathergen.train.teacher_utils import load_encoder_from_checkpoint, prepare_encoder_teacher @@ -33,6 +35,36 @@ def get_target_aux_calculator( if target_and_aux_calc == "Physical": target_aux = PhysicalTargetAndAux(loss_cfg, model) + elif target_and_aux_calc == "DiffusionLatentTargetEncoder": + model, _ = init_model_and_shard( + cf, + dataset, + cf.get("load_chkpt", {}).get("run_id", None), + cf.get("load_chkpt", {}).get("epoch", -1), + "student", + device, + with_ddp=False, + with_fsdp=False, + overrides=target_and_aux_calc_params.get("model_param_overrides", {}), + ) + # Free components not needed by DiffusionLatentTargetEncoder (only uses the encoder) + for attr in ( + "forecast_engine", + "pred_heads", + "target_token_engines", + "embed_target_coords", + "latent_heads", + "latent_pre_norm", + ): + if hasattr(model, attr) and getattr(model, attr) is not None: + delattr(model, attr) + setattr(model, attr, None) + torch.cuda.empty_cache() + + target_aux = DiffusionLatentTargetEncoder( + model, is_model_sharded=(cf.with_ddp and cf.with_fsdp) + ) + elif target_and_aux_calc == "EMATeacher": # work around for problems with FSDP2 assert not cf.with_fsdp, "EMATeacher not supported with FSDP(2) at the moment" diff --git a/src/weathergen/train/teacher_utils.py b/src/weathergen/train/teacher_utils.py index c026960b5..224505f5d 100644 --- a/src/weathergen/train/teacher_utils.py +++ b/src/weathergen/train/teacher_utils.py @@ -43,7 +43,12 @@ def _create_teacher_heads( if head_type == "mlp": return LatentPredictionHeadMLP( - f"{name}-head", dim_embed, loss_conf, use_class_token, use_patch_token + f"{name}-head", + dim_embed, + loss_conf, + use_class_token, + use_patch_token, + default_mlp_type=(cf.get("mlp_type", "mlp") if cf is not None else "mlp"), ) elif head_type == "transformer": if cf is None: @@ -88,7 +93,7 @@ def prepare_encoder_teacher(model: nn.Module, training_cfg, override_cfg) -> Non elif name in ("iBOT", "DINO"): head_type = conf.get("head", "mlp").lower() model.latent_heads[name] = _create_teacher_heads( - name, head_type, teacher_dim_embed, conf + name, head_type, teacher_dim_embed, conf, cf=override_cfg ) else: logger.warning(f"Unknown SSL loss type {name!r} in teacher setup, skipping.") diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 210fc16a4..0fdd8d895 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -11,6 +11,7 @@ import copy import logging import time +from decimal import Decimal from math import sqrt import numpy as np @@ -56,6 +57,43 @@ # cfg_keys_to_filter = ["losses", "model_input", "target_input"] +def _expand_targets_to_match_preds(preds, targets_and_auxs: dict) -> None: + """ + Replicate per-fstep entries in each TargetAuxOutput so its ``physical`` and ``latent`` + lists match the number of forecast steps in ``preds``. + + Diffusion inference produces one ``preds`` fstep per ODE denoising step, but the + physical target is identical across the trajectory. Without this expansion the loss + calculator (which zips preds and targets with ``strict=True``) raises a length + mismatch. + + The expansion replicates references — no tensor copies are made — and is a no-op when + the lengths already agree. + """ + n_pred = len(preds.physical) + for t_aux in targets_and_auxs.values(): + n_tgt = len(t_aux.physical) + if n_tgt == n_pred or n_tgt == 0: + continue + if n_pred % n_tgt != 0: + logger.warning( + "Cannot expand target/aux from %d to %d fsteps (not a multiple); " + "leaving unchanged.", + n_tgt, + n_pred, + ) + continue + repeat = n_pred // n_tgt + t_aux.physical = [t_aux.physical[i // repeat] for i in range(n_pred)] + t_aux.latent = [t_aux.latent[i // repeat] for i in range(n_pred)] + # output_idxs is consumed by validation IO via batch.get_output_idxs(), but we + # keep the dataclass internally consistent in case other consumers read it. + if t_aux.output_idxs is not None and len(t_aux.output_idxs) == n_tgt: + t_aux.output_idxs = [ + t_aux.output_idxs[i // repeat] for i in range(n_pred) + ] + + class Trainer(TrainerBase): def __init__(self, train_logging: Config): TrainerBase.__init__(self) @@ -195,6 +233,10 @@ def inference(self, cf, devices, run_id_contd, mini_epoch_contd): # create data loader # only one needed since we only run the validation code path + # Force full maps during inference by disabling target subsampling + for stream_info in cf.streams: + stream_info["max_num_targets"] = -1 + self.dataset = MultiStreamDataSampler( cf, self.test_cfg, @@ -560,93 +602,185 @@ def train(self, mini_epoch): def validate(self, mini_epoch, mode_cfg, batch_size): """ - Perform validation / test computation as specified by mode_cfg + Perform validation / test computation as specified by mode_cfg. + + For diffusion models, runs separate validation passes for each noise level + specified in ``validation_noise_levels`` (defaults to ``[0.0]``). + Losses are logged with a per-noise-level suffix so they can be compared. """ cf = self.cf self.model.eval() - dataset_val_iter = iter(self.data_loader_validation) - - num_samples_write = mode_cfg.get("output", {}).get("num_samples", 0) * batch_size - - with torch.no_grad(): - # print progress bar but only in interactive mode, i.e. when without ddp - with tqdm.tqdm( - total=len(self.data_loader_validation), disable=self.cf.with_ddp - ) as pbar: - for bidx, batch in enumerate(dataset_val_iter): - if cf.data_loading.get("memory_pinning", False): - # pin memory for faster CPU-GPU transfer - batch = batch.pin_memory() - - batch.to_device(self.device) - - # evaluate model - with torch.autocast( - device_type=f"cuda:{cf.local_rank}", - dtype=self.mixed_precision_dtype, - enabled=cf.with_mixed_precision, - ): - if self.ema_model is None: - preds = self.model( - self.model_params, - batch.get_source_samples(), - ) - else: - preds = self.ema_model.forward_eval( - self.model_params, - batch.get_source_samples(), - ) - - targets_and_auxs = {} - for loss_name, target_aux in self.target_and_aux_calculators_val.items(): - target_idxs = get_target_idxs_from_cfg(mode_cfg, loss_name) - targets_and_auxs[loss_name] = target_aux.compute( - self.cf.general.istep, - batch.get_target_samples(target_idxs), - self.model_params, - self.model, - ) - - _ = self.loss_calculator_val.compute_loss( - preds=preds, - targets_and_aux=targets_and_auxs, - metadata=extract_batch_metadata(batch), - ) + is_diffusion = cf.get("fe_diffusion_model", False) + noise_levels = list(mode_cfg.get("validation_noise_levels", [0.0])) + if not is_diffusion: + noise_levels = [0.0] + else: + # Always include a pass without fixed noise level (random sampling) + noise_levels = [None] + noise_levels - # log output - if bidx < num_samples_write: - # denormalization function for data - denormalize_data_fct = ( - (lambda x0, x1: x1) - if mode_cfg.get("output", {}).get("normalized_samples", False) - else self.dataset_val.denormalize_target_channels - ) - # write output - write_output( - self.cf, - mode_cfg, - batch_size, - mini_epoch, - bidx, - denormalize_data_fct, - batch, - preds, - targets_and_auxs, - ) + # Accumulate losses across noise levels with suffixed keys so they are + # logged as a single "val" entry (e.g. LossLatentDiff.LossLatentDiff.mse.eta0.03) + all_losses: dict[str, list] = {} + all_stddev: dict[str, list] = {} + + for noise_idx, noise_level in enumerate(noise_levels): + if is_diffusion: + self._set_validation_noise_level(noise_level) - pbar.update(batch_size) + if noise_level is None: + loss_suffix = "" + stage_suffix = "" + else: + _d = Decimal(str(noise_level)).normalize() + _sign, _digits, _exp = _d.as_tuple() + eta_str = f"{'-' if _sign else ''}{''.join(map(str, _digits))}e{_exp}" + loss_suffix = f".eta{eta_str}" if len(noise_levels) > 1 else "" + stage_suffix = f"_eta{eta_str}" if len(noise_levels) > 1 else "" + + dataset_val_iter = iter(self.data_loader_validation) + num_samples_write = mode_cfg.get("output", {}).get("num_samples", 0) * batch_size + + with torch.no_grad(): + # print progress bar but only in interactive mode, i.e. when without ddp + with tqdm.tqdm( + total=len(self.data_loader_validation), disable=self.cf.with_ddp + ) as pbar: + for bidx, batch in enumerate(dataset_val_iter): + if cf.data_loading.get("memory_pinning", False): + # pin memory for faster CPU-GPU transfer + batch = batch.pin_memory() + + batch.to_device(self.device) + + # evaluate model + with torch.autocast( + device_type=f"cuda:{cf.local_rank}", + dtype=self.mixed_precision_dtype, + enabled=cf.with_mixed_precision, + ): + + if self.ema_model is None: + preds = self.model( + self.model_params, + batch.get_source_samples(), + ) + else: + preds = self.ema_model.forward_eval( + self.model_params, + batch.get_source_samples(), + ) + + targets_and_auxs = {} + for ( + loss_name, + target_aux, + ) in self.target_and_aux_calculators_val.items(): + target_idxs = get_target_idxs_from_cfg(mode_cfg, loss_name) + targets_and_auxs[loss_name] = target_aux.compute( + self.cf.general.istep, + batch.get_target_samples(target_idxs), + self.model_params, + self.model, + ) + + # Diffusion inference inflates the model output's fstep + # dimension to one entry per ODE step (the denoising + # trajectory). The physical target is identical for every + # such step, so replicate target/aux entries to keep the + # downstream loss calculator and validation IO aligned. + if is_diffusion: + _expand_targets_to_match_preds(preds, targets_and_auxs) + + _ = self.loss_calculator_val.compute_loss( + preds=preds, + targets_and_aux=targets_and_auxs, + metadata=extract_batch_metadata(batch), + ) - if (bidx * batch_size) > mode_cfg.samples_per_mini_epoch: - break + # log output + if noise_idx == 0: + if bidx < num_samples_write: + # denormalization function for data + denormalize_data_fct = ( + (lambda x0, x1: x1) + if mode_cfg.get("output", {}).get("normalized_samples", False) + else self.dataset_val.denormalize_target_channels + ) + # write output (zarr only for first noise level, plots for all) + write_output( + self.cf, + mode_cfg, + batch_size, + mini_epoch, + bidx, + denormalize_data_fct, + batch, + preds, + targets_and_auxs, + ) + + pbar.update(batch_size) + + if (bidx * batch_size) > mode_cfg.samples_per_mini_epoch: + break + + # Terminal logging per noise level for progress visibility + self._log_terminal(0, mini_epoch, VAL, stage_suffix=stage_suffix) + + # Extract losses for this noise level, suffix keys, and accumulate + loss_calc = self.loss_calculator_val + _, losses_level, stddev_level = prepare_losses_for_logging( + loss_calc.loss_hist, + loss_calc.losses_unweighted_hist, + loss_calc.stddev_unweighted_hist, + ) + for key, value in losses_level.items(): + all_losses[f"{key}{loss_suffix}"] = value + for key, value in stddev_level.items(): + all_stddev[f"{key}{loss_suffix}"] = value + loss_calc.loss_hist = [] + loss_calc.losses_unweighted_hist = [] + loss_calc.stddev_unweighted_hist = [] + + # Log all noise levels as a single "val" entry with suffixed loss keys + samples = self.cf.general.istep * self.get_batch_size_total(self.batch_size_per_gpu) + if is_root(): + self.train_logger.add_logs(VAL, samples, all_losses, all_stddev) - self._log_terminal(0, mini_epoch, VAL) - self._log(VAL) + # reset fixed noise level + if is_diffusion: + self._set_validation_noise_level(None) # avoid that there is a systematic bias in the validation subset self.dataset_val.advance() + def _set_validation_noise_level(self, noise_level: float | None): + """Set fixed noise level on diffusion components for validation. + + Args: + noise_level: The eta value (standard normal space) to fix for validation. + sigma = exp(eta * p_std + p_mean). None resets to default (0.0). + """ + # Unwrap DDP/FSDP to access the underlying model + base_model = getattr(self.model, "module", self.model) + # Set on the base model + if hasattr(base_model, "forecast_engine") and hasattr( + base_model.forecast_engine, "_fixed_noise_level" + ): + base_model.forecast_engine._fixed_noise_level = noise_level + # Also set on the EMA model (separate model copy used during validation) + if self.ema_model is not None: + ema_net = getattr(self.ema_model.ema_model, "module", self.ema_model.ema_model) + if hasattr(ema_net, "forecast_engine") and hasattr( + ema_net.forecast_engine, "_fixed_noise_level" + ): + ema_net.forecast_engine._fixed_noise_level = noise_level + for calc in self.target_and_aux_calculators_val.values(): + if hasattr(calc, "_fixed_noise_level"): + calc._fixed_noise_level = noise_level + def _get_full_model_state_dict(self): maybe_sharded_sd = ( self.model.state_dict() if self.ema_model is None else self.ema_model.state_dict() @@ -721,13 +855,15 @@ def save_model(self, mini_epoch: int, name=None): # save config config.save(self.cf, mini_epoch) - def _log(self, stage: Stage): + def _log(self, stage: Stage, stage_suffix: str = ""): """ Logs training or validation metrics. Args: stage: Stage Is it's VAL, logs are treated as validation logs. If TRAIN, logs are treated as training logs + stage_suffix: Optional suffix appended to the logged stage name + (e.g. "_eta0.00" for per-noise-level validation). Notes: - This method only executes logging on the main process (rank 0). @@ -741,15 +877,16 @@ def _log(self, stage: Stage): ) samples = self.cf.general.istep * self.get_batch_size_total(self.batch_size_per_gpu) + log_stage = f"{stage}{stage_suffix}" if stage_suffix else stage if is_root(): # plain logger if stage == VAL: - self.train_logger.add_logs(stage, samples, losses_all, stddev_all) + self.train_logger.add_logs(log_stage, samples, losses_all, stddev_all) elif self.cf.general.istep >= 0: self.train_logger.add_logs( - stage, + log_stage, samples, losses_all, stddev_all, @@ -783,7 +920,7 @@ def _log_instant_grad_norms(self, stage: Stage): if is_root(): self.train_logger.log_metrics(stage, grad_norms) - def _log_terminal(self, bidx: int, mini_epoch: int, stage: Stage): + def _log_terminal(self, bidx: int, mini_epoch: int, stage: Stage, stage_suffix: str = ""): print_freq = self.train_logging.terminal if bidx % print_freq == 0 and bidx > 0 or stage == VAL: # compute from last iteration @@ -797,7 +934,7 @@ def _log_terminal(self, bidx: int, mini_epoch: int, stage: Stage): if is_root(): if stage == VAL: logger.info( - f"""validation ({self.cf.general.run_id}) : {mini_epoch:03d} : + f"""validation{stage_suffix} ({self.cf.general.run_id}) : {mini_epoch:03d} : {np.nanmean(avg_loss)}""" ) diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index b4a5f1279..7d734af18 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -430,10 +430,12 @@ def plot_loss_per_stream( data_cols = [] for col in run_data_mode.columns: col_split = col.split(".") - if len(col_split) < 4: + if col == stream_name: + data_cols += [col] + elif len(col_split) < 4: if stream_name in col: data_cols += [col] - elif len(col_split) == 4: + elif col_split[3] == "avg": if ( col_split[1].lower() == stream_name.lower() and col_split[2].lower() == err.lower() diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index d21938dd5..e6bbebe8d 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -21,7 +21,15 @@ def write_output( - cf, val_cfg, batch_size, mini_epoch, batch_idx, dn_data, batch, model_output, target_aux_out + cf, + val_cfg, + batch_size, + mini_epoch, + batch_idx, + dn_data, + batch, + model_output, + target_aux_out, ): """ Interface for writing model output @@ -42,6 +50,15 @@ def write_output( timestep_idxs = [0] if len(batch.get_output_idxs()) == 0 else batch.get_output_idxs() forecast_offset = timestep_idxs[0] + + # Diffusion inference inflates the model output's fstep dimension to one entry per + # ODE denoising step (the trajectory). The batch only has the original physical + # forecast indices, so synthesize a contiguous run of indices starting at the + # original first index to cover every entry in model_output / target_aux_out. + n_pred_steps = len(model_output.physical) + if n_pred_steps > len(timestep_idxs): + timestep_idxs = list(range(forecast_offset, forecast_offset + n_pred_steps)) + targets_lens = [] # TODO Maybe stopping at forecast_steps explained #1657 @@ -56,6 +73,7 @@ def write_output( # handle spoof data: do not write since it might corrupt validation (spoofing invisible # there) + if target_aux_out.physical[t_idx][sname]["is_spoof"][0]: preds = model_output.get_physical_prediction(t_idx, sname) preds_shape = preds[0].shape @@ -174,3 +192,8 @@ def write_output( with zarrio_writer(config.get_path_results(cf, mini_epoch)) as zio: for subset in data.items(): zio.write_zarr(subset) + + # Free arrays no longer needed after zarr writing + del targets_all, targets_lens, sources, data + + del targets_times_all