From a5fbd448c398a6b3f0b8f1ad7c3f07e8d7f7c93c Mon Sep 17 00:00:00 2001 From: sbAsma Date: Sat, 25 Apr 2026 11:57:11 +0200 Subject: [PATCH 01/10] add inference_only flag to default config --- config/default_config.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/config/default_config.yml b/config/default_config.yml index 39abe739b..a27a2225c 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -90,6 +90,9 @@ latent_noise_deterministic_latents: True freeze_modules: "" load_chkpt: {} +# When True, targets are not loaded and loss/metric logging is skipped entirely. +inference_only: True + norm_type: "LayerNorm" qk_norm_type: null # if null, defaults to norm_type From 13974f0a7715b922e9b9eeebc8a690826625d1ba Mon Sep 17 00:00:00 2001 From: sbAsma Date: Sat, 25 Apr 2026 11:57:59 +0200 Subject: [PATCH 02/10] add inference_only mode: skip target loading, batch validation, and tokens_lens for targets; improve empty-batch error handling with max_attempts and skip reason logging --- .../datasets/multi_stream_data_sampler.py | 96 ++++++++++++------- 1 file changed, 60 insertions(+), 36 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 7b276607e..96ad2625d 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -91,6 +91,7 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage): self.mode_cfg = mode_cfg self._stage = stage + self.inference_only = cf.get("inference_only", False) self.mini_epoch = 0 self.mask_value = 0.0 @@ -189,7 +190,6 @@ def _calc_baseperms(self, fsm: int) -> np.typing.NDArray: depends on fsm so must be repeated for __init__ and reset""" perms_len = int(self.index_range.end - self.index_range.start) perms_len -= (fsm + self.output_offset) * (self.time_step // self.step_timedelta) - return np.arange(perms_len) def _init_stream_datasets(self, cf) -> dict[StreamName, list[AnyDataReader]]: @@ -441,14 +441,17 @@ def _build_stream_data_output( token_data = output_tokens[step] if "target_coords" in mode: - (tc, tc_l) = self.tokenizer.get_target_coords( + (tc, tc_l, tc_raw, tc_times) = self.tokenizer.get_target_coords( stream_info, rdata, token_data, (time_win_target.start, time_win_target.end), target_mask, ) - stream_data.add_target_coords(timestep_idx, tc, tc_l, rdata.is_spoof) + stream_data.add_target_coords( + timestep_idx, tc, tc_l, rdata.is_spoof, + target_coords_raw=tc_raw, times_raw=tc_times, + ) if "target_values" in mode: (tt_cells, tt_t, tt_c, idxs_inv) = self.tokenizer.get_target_values( @@ -613,9 +616,11 @@ def _preprocess_model_batch( batch.source_samples.tokens_lens = get_tokens_lens( self.streams, batch.source_samples, source_input_steps ) - batch.target_samples.tokens_lens = get_tokens_lens( - self.streams, batch.target_samples, target_input_steps - ) + # In inference_only mode targets are not loaded, so skip tokens_lens for targets + if not self.inference_only: + batch.target_samples.tokens_lens = get_tokens_lens( + self.streams, batch.target_samples, target_input_steps + ) return batch @@ -697,30 +702,33 @@ def _get_batch(self, idx: int, num_forecast_steps: int): batch.add_source_stream(sidx, tidx, stream_name, sdata, source_masks.metadata[sidx]) # for t_idx, mask in enumerate(source_masks): - for tidx, target_mask in enumerate(target_masks.masks): - # depending on the mode, the the streamdata obj to have the target mask applied to - # the inputs. Hence the target mask is also the source mask here. - sdata = self._build_stream_data( - target_select, - idx, - num_forecast_steps, - stream_info, - target_masks.metadata[tidx].params.get("num_steps_input", 1), - input_data, - output_data, - input_tokens, - output_tokens, - output_mask=target_mask, - input_mask=target_mask, - ) - target_metadata = target_masks.metadata[tidx] - # also want to add the mask to the metadata - target_metadata.mask = target_mask - # Map target to all source students - student_indices = [ - s_idx for s_idx, tid in enumerate(source_to_target) if tid == tidx - ] - batch.add_target_stream(tidx, student_indices, stream_name, sdata, target_metadata) + if not self.inference_only: + for tidx, target_mask in enumerate(target_masks.masks): + # depending on the mode, the the streamdata obj to have the target mask applied + # to the inputs. Hence the target mask is also the source mask here. + sdata = self._build_stream_data( + target_select, + idx, + num_forecast_steps, + stream_info, + target_masks.metadata[tidx].params.get("num_steps_input", 1), + input_data, + output_data, + input_tokens, + output_tokens, + output_mask=target_mask, + input_mask=target_mask, + ) + target_metadata = target_masks.metadata[tidx] + # also want to add the mask to the metadata + target_metadata.mask = target_mask + # Map target to all source students + student_indices = [ + s_idx for s_idx, tid in enumerate(source_to_target) if tid == tidx + ] + batch.add_target_stream( + tidx, student_indices, stream_name, sdata, target_metadata + ) source_in_steps = input_steps.max().item() target_in_steps = np.array([tc.get("num_steps_input", 1) for _, tc in target_cfgs.items()]) @@ -753,21 +761,37 @@ def __iter__(self) -> ModelBatch: # use while loop due to the scattered nature of the data in time and to # ensure batches are not empty + num_attempts = 0 + max_attempts = perms.shape[0] while True: idx: TIndex = perms[idx_raw % perms.shape[0]] idx_raw += 1 + num_attempts += 1 batch = self._get_batch(idx, num_forecast_steps) - # ensure the batch is valid, i.e. not completely empty and no NaN values - # student teacher has no classical targets + # Check for invalid batches: empty sources, NaN values, or empty targets (if applicable). mode = self.mode_cfg.get("training_mode") - not_valid = batch.sources_empty() or batch.is_nan() - not_valid = not_valid or (batch.targets_empty() if "masking" in mode else False) + sources_empty = batch.sources_empty() + sources_nan = batch.is_nan() + not_valid = sources_empty or sources_nan + if not self.inference_only: + not_valid = not_valid or ( + batch.targets_empty() if "masking" in mode else False + ) - # skip completely empty batch item or when all targets are empty -> no grad + # Skip invalid batches or raise an error if no valid batch is found after max_attempts. if not_valid: - logger.warning(f"Skipping empty batch with idx={idx}.") + if num_attempts > max_attempts: + raise RuntimeError( + f"Could not find a valid non-empty batch after {num_attempts} attempts. " + "All data may be missing or targets unavailable for this epoch." + ) + reason = "sources_empty" if sources_empty else ("sources_nan" if sources_nan else "targets_empty") + if self.inference_only: + logger.debug(f"Skipping empty batch with idx={idx} (reason={reason}).") + else: + logger.warning(f"Skipping empty batch with idx={idx} (reason={reason}).") else: break From 524a1322363c7f4575602210ac260c2b8eb7c99b Mon Sep 17 00:00:00 2001 From: sbAsma Date: Sat, 25 Apr 2026 11:58:47 +0200 Subject: [PATCH 03/10] add optional target_coords_raw and times_raw fields to add_target_coords --- src/weathergen/datasets/stream_data.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index 6564a0622..ffb127f19 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -291,6 +291,8 @@ def add_target_coords( target_coords: torch.Tensor, target_coords_per_cell: torch.Tensor, is_spoof: bool, + target_coords_raw=None, + times_raw=None, ) -> None: """ Add data for target for one input. @@ -320,6 +322,10 @@ def add_target_coords( self.target_coords[fstep] = target_coords self.target_coords_lens[fstep] = target_coords_per_cell + if target_coords_raw is not None: + self.target_coords_raw[fstep] = target_coords_raw + if times_raw is not None: + self.target_times_raw[fstep] = times_raw self.target_is_spoof[fstep] = is_spoof From 52c2212ad8a94aadd30139be006a622edfcbc2d2 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Sat, 25 Apr 2026 11:59:29 +0200 Subject: [PATCH 04/10] return coords_raw and datetimes from get_target_coords --- src/weathergen/datasets/tokenizer_masking.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 205bd5c8a..79e2e44ca 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -165,7 +165,7 @@ def get_target_coords( ) # TODO: split up - _, _, _, coords_local, coords_per_cell = tokenize_apply_mask_target( + _, datetimes, coords_raw, coords_local, coords_per_cell = tokenize_apply_mask_target( stream_info["stream_id"], self.hl_target, idxs_cells, @@ -180,7 +180,7 @@ def get_target_coords( encode_times_target, ) - return (coords_local, coords_per_cell) + return (coords_local, coords_per_cell, coords_raw, datetimes) def get_target_values( self, From 5aad8d14ba7cc4cf962da92efc3a0cc2eab6f836 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Sat, 25 Apr 2026 11:59:57 +0200 Subject: [PATCH 05/10] skip loss/metric computation and logging in inference_only mode; fall back to source samples for coordinate/time metadata and override is_spoof for output writing --- src/weathergen/train/trainer.py | 63 ++++++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 16 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index caae65647..f822e9332 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -566,6 +566,8 @@ def validate(self, mini_epoch, mode_cfg, batch_size): batch.to_device(self.device) + inference_only = self.cf.get("inference_only", False) + # evaluate model with torch.autocast( device_type=f"cuda:{cf.local_rank}", @@ -584,20 +586,44 @@ def validate(self, mini_epoch, mode_cfg, batch_size): ) targets_and_auxs = {} - for loss_name, target_aux in self.target_and_aux_calculators_val.items(): - target_idxs = get_target_idxs_from_cfg(mode_cfg, loss_name) - targets_and_auxs[loss_name] = target_aux.compute( - self.cf.general.istep, - batch.get_target_samples(target_idxs), - self.model_params, - self.model, - ) - - _ = self.loss_calculator_val.compute_loss( - preds=preds, - targets_and_aux=targets_and_auxs, - metadata=extract_batch_metadata(batch), - ) + if not inference_only: + for loss_name, target_aux in self.target_and_aux_calculators_val.items(): + target_idxs = get_target_idxs_from_cfg(mode_cfg, loss_name) + targets_and_auxs[loss_name] = target_aux.compute( + self.cf.general.istep, + batch.get_target_samples(target_idxs), + self.model_params, + self.model, + ) + else: + # In inference_only mode, targets are absent but we still need + # coordinate/time metadata for output writing. Compute targets_and_auxs + # from source samples, which carry target_coords_raw and + # target_times_raw (populated by get_target_coords in the sampler). + for loss_name, target_aux in self.target_and_aux_calculators_val.items(): + tao = target_aux.compute( + self.cf.general.istep, + batch.get_source_samples(), + self.model_params, + self.model, + ) + # Source stream data marks output steps as spoof because target + # files are absent by design in inference_only mode. Override + # is_spoof to False so write_output writes actual predictions + # rather than discarding them as corrupted validation data. + for step_dict in tao.physical: + for sname, step_data in step_dict.items(): + step_data["is_spoof"] = [False] * len( + step_data["is_spoof"] + ) + targets_and_auxs[loss_name] = tao + + if not inference_only: + _ = self.loss_calculator_val.compute_loss( + preds=preds, + targets_and_aux=targets_and_auxs, + metadata=extract_batch_metadata(batch), + ) # log output if bidx < num_samples_write: @@ -625,8 +651,13 @@ def validate(self, mini_epoch, mode_cfg, batch_size): if (bidx * batch_size) > mode_cfg.samples_per_mini_epoch: break - self._log_terminal(0, mini_epoch, VAL) - self._log(VAL) + if not inference_only: + self._log_terminal(0, mini_epoch, VAL) + self._log(VAL) + else: + logger.info( + f"inference_only=True: skipping loss/metric logging for epoch {mini_epoch}." + ) # avoid that there is a systematic bias in the validation subset self.dataset_val.advance() From a6aaa56a3b555c14f09a6235ab58867d2b31e1cc Mon Sep 17 00:00:00 2001 From: sbAsma Date: Sat, 25 Apr 2026 12:00:24 +0200 Subject: [PATCH 06/10] fix output writing in inference_only mode: guard empty idxs_inv tensor, reshape empty target array to (0, n_channels), and use coord count for targets_lens --- src/weathergen/utils/validation_io.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index d21938dd5..7e9a2ab98 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -83,7 +83,9 @@ def write_output( t_times = target_data["target_times"][i_batch] idxs_inv = target_aux_out.physical[t_idx][sname]["idxs_inv"][i_batch] - if idxs_inv is not None: + if idxs_inv is not None and ( + not isinstance(idxs_inv, torch.Tensor) or idxs_inv.numel() > 0 + ): pred = pred[:, idxs_inv] target = target[idxs_inv] t_coords = t_coords[idxs_inv] @@ -91,6 +93,11 @@ def write_output( # denormalize data if requested and map to storage format preds_s += [dn_data(sname, pred.to(fp32)).detach().cpu().numpy()] + # In inference_only mode, target tokens are empty (no channel dim); create a + # properly shaped empty array so downstream code handles it gracefully. + n_channels = pred.shape[-1] + if target.ndim == 1 and target.numel() == 0: + target = target.reshape(0, n_channels) targets_s += [dn_data(sname, target.to(fp32)).detach().cpu().numpy()] # extract original target coords and times from target data @@ -98,7 +105,11 @@ def write_output( t_times_s += [t_times.astype("datetime64[ns]")] targets_lens[-1] += [[]] - targets_lens[-1][-1] += [t.shape[0] for t in targets_s] + # Use coordinate count rather than target-value count so that in + # inference_only mode (where target tokens are empty but coords and + # predictions are non-empty) the correct number of output datapoints + # is indexed when writing predictions. + targets_lens[-1][-1] += [t.shape[0] for t in t_coords_s] preds_all[-1] += [np.concatenate(preds_s, axis=1)] targets_all[-1] += [np.concatenate(targets_s)] From d36afe60a640de6be94c8e3b70e155d41b61139d Mon Sep 17 00:00:00 2001 From: sbAsma Date: Tue, 28 Apr 2026 14:29:18 +0200 Subject: [PATCH 07/10] default inference only to false --- config/default_config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/default_config.yml b/config/default_config.yml index a27a2225c..aca2e4606 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -91,7 +91,7 @@ freeze_modules: "" load_chkpt: {} # When True, targets are not loaded and loss/metric logging is skipped entirely. -inference_only: True +inference_only: False norm_type: "LayerNorm" qk_norm_type: null # if null, defaults to norm_type From eb27fad830f471666fcc2d0ddbb9eefd3236a4d9 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Tue, 28 Apr 2026 14:30:06 +0200 Subject: [PATCH 08/10] improved logging for non valid data --- .../datasets/multi_stream_data_sampler.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 96ad2625d..a96314f6c 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -774,12 +774,20 @@ def __iter__(self) -> ModelBatch: mode = self.mode_cfg.get("training_mode") sources_empty = batch.sources_empty() sources_nan = batch.is_nan() - not_valid = sources_empty or sources_nan - if not self.inference_only: - not_valid = not_valid or ( - batch.targets_empty() if "masking" in mode else False - ) - + targets_empty = ( + not self.inference_only and "masking" in mode and batch.targets_empty() + ) + not_valid = sources_empty or sources_nan or targets_empty + if not_valid: + if sources_empty: + logger.info(f"Skipping batch at idx={idx}: sources are empty.") + if sources_nan: + logger.info(f"Skipping batch at idx={idx}: sources contain NaN values.") + if targets_empty: + logger.info( + f"Skipping batch at idx={idx}: targets are empty " + "(inference_only=False, training_mode includes masking)." + ) # Skip invalid batches or raise an error if no valid batch is found after max_attempts. if not_valid: if num_attempts > max_attempts: @@ -787,11 +795,6 @@ def __iter__(self) -> ModelBatch: f"Could not find a valid non-empty batch after {num_attempts} attempts. " "All data may be missing or targets unavailable for this epoch." ) - reason = "sources_empty" if sources_empty else ("sources_nan" if sources_nan else "targets_empty") - if self.inference_only: - logger.debug(f"Skipping empty batch with idx={idx} (reason={reason}).") - else: - logger.warning(f"Skipping empty batch with idx={idx} (reason={reason}).") else: break From 08ac2e097a936f4bb88b3fb7cbceb39fb149e3a1 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Tue, 28 Apr 2026 14:30:34 +0200 Subject: [PATCH 09/10] added ruff requested change --- src/weathergen/train/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index f822e9332..f172bdaf9 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -612,7 +612,7 @@ def validate(self, mini_epoch, mode_cfg, batch_size): # is_spoof to False so write_output writes actual predictions # rather than discarding them as corrupted validation data. for step_dict in tao.physical: - for sname, step_data in step_dict.items(): + for _sname, step_data in step_dict.items(): step_data["is_spoof"] = [False] * len( step_data["is_spoof"] ) From 9c674d572cacb8dd634dd97c472895e86200fe63 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Tue, 28 Apr 2026 14:33:27 +0200 Subject: [PATCH 10/10] added ruff requested change --- src/weathergen/datasets/multi_stream_data_sampler.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index a96314f6c..700b22c08 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -770,7 +770,8 @@ def __iter__(self) -> ModelBatch: batch = self._get_batch(idx, num_forecast_steps) - # Check for invalid batches: empty sources, NaN values, or empty targets (if applicable). + # Check for invalid batches: empty sources, NaN values, or empty targets + # (if applicable). mode = self.mode_cfg.get("training_mode") sources_empty = batch.sources_empty() sources_nan = batch.is_nan() @@ -788,12 +789,14 @@ def __iter__(self) -> ModelBatch: f"Skipping batch at idx={idx}: targets are empty " "(inference_only=False, training_mode includes masking)." ) - # Skip invalid batches or raise an error if no valid batch is found after max_attempts. + # Skip invalid batches or raise an error if no valid batch is found + # after max_attempts. if not_valid: if num_attempts > max_attempts: raise RuntimeError( - f"Could not find a valid non-empty batch after {num_attempts} attempts. " - "All data may be missing or targets unavailable for this epoch." + f"Could not find a valid non-empty batch after {num_attempts} " + "attempts. All data may be missing or targets unavailable" + " for this epoch." ) else: break