Skip to content
Open
3 changes: 3 additions & 0 deletions config/default_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ latent_noise_deterministic_latents: True
freeze_modules: ""
load_chkpt: {}

# When True, targets are not loaded and loss/metric logging is skipped entirely.
inference_only: False

norm_type: "LayerNorm"
qk_norm_type: null # if null, defaults to norm_type

Expand Down
104 changes: 67 additions & 37 deletions src/weathergen/datasets/multi_stream_data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage):

self.mode_cfg = mode_cfg
self._stage = stage
self.inference_only = cf.get("inference_only", False)

self.mini_epoch = 0
self.mask_value = 0.0
Expand Down Expand Up @@ -189,7 +190,6 @@ def _calc_baseperms(self, fsm: int) -> np.typing.NDArray:
depends on fsm so must be repeated for __init__ and reset"""
perms_len = int(self.index_range.end - self.index_range.start)
perms_len -= (fsm + self.output_offset) * (self.time_step // self.step_timedelta)

return np.arange(perms_len)

def _init_stream_datasets(self, cf) -> dict[StreamName, list[AnyDataReader]]:
Expand Down Expand Up @@ -441,14 +441,17 @@ def _build_stream_data_output(
token_data = output_tokens[step]

if "target_coords" in mode:
(tc, tc_l) = self.tokenizer.get_target_coords(
(tc, tc_l, tc_raw, tc_times) = self.tokenizer.get_target_coords(
stream_info,
rdata,
token_data,
(time_win_target.start, time_win_target.end),
target_mask,
)
stream_data.add_target_coords(timestep_idx, tc, tc_l, rdata.is_spoof)
stream_data.add_target_coords(
timestep_idx, tc, tc_l, rdata.is_spoof,
target_coords_raw=tc_raw, times_raw=tc_times,
)

if "target_values" in mode:
(tt_cells, tt_t, tt_c, idxs_inv) = self.tokenizer.get_target_values(
Expand Down Expand Up @@ -613,9 +616,11 @@ def _preprocess_model_batch(
batch.source_samples.tokens_lens = get_tokens_lens(
self.streams, batch.source_samples, source_input_steps
)
batch.target_samples.tokens_lens = get_tokens_lens(
self.streams, batch.target_samples, target_input_steps
)
# In inference_only mode targets are not loaded, so skip tokens_lens for targets
if not self.inference_only:
batch.target_samples.tokens_lens = get_tokens_lens(
self.streams, batch.target_samples, target_input_steps
)

return batch

Expand Down Expand Up @@ -697,30 +702,33 @@ def _get_batch(self, idx: int, num_forecast_steps: int):
batch.add_source_stream(sidx, tidx, stream_name, sdata, source_masks.metadata[sidx])

# for t_idx, mask in enumerate(source_masks):
for tidx, target_mask in enumerate(target_masks.masks):
# depending on the mode, the streamdata obj needs to have the target mask applied to
# the inputs. Hence the target mask is also the source mask here.
sdata = self._build_stream_data(
target_select,
idx,
num_forecast_steps,
stream_info,
target_masks.metadata[tidx].params.get("num_steps_input", 1),
input_data,
output_data,
input_tokens,
output_tokens,
output_mask=target_mask,
input_mask=target_mask,
)
target_metadata = target_masks.metadata[tidx]
# also want to add the mask to the metadata
target_metadata.mask = target_mask
# Map target to all source students
student_indices = [
s_idx for s_idx, tid in enumerate(source_to_target) if tid == tidx
]
batch.add_target_stream(tidx, student_indices, stream_name, sdata, target_metadata)
if not self.inference_only:
for tidx, target_mask in enumerate(target_masks.masks):
# depending on the mode, the streamdata obj needs to have the target mask applied
# to the inputs. Hence the target mask is also the source mask here.
sdata = self._build_stream_data(
target_select,
idx,
num_forecast_steps,
stream_info,
target_masks.metadata[tidx].params.get("num_steps_input", 1),
input_data,
output_data,
input_tokens,
output_tokens,
output_mask=target_mask,
input_mask=target_mask,
)
target_metadata = target_masks.metadata[tidx]
# also want to add the mask to the metadata
target_metadata.mask = target_mask
# Map target to all source students
student_indices = [
s_idx for s_idx, tid in enumerate(source_to_target) if tid == tidx
]
batch.add_target_stream(
tidx, student_indices, stream_name, sdata, target_metadata
)

source_in_steps = input_steps.max().item()
target_in_steps = np.array([tc.get("num_steps_input", 1) for _, tc in target_cfgs.items()])
Expand Down Expand Up @@ -753,21 +761,43 @@ def __iter__(self) -> ModelBatch:

# use while loop due to the scattered nature of the data in time and to
# ensure batches are not empty
num_attempts = 0
max_attempts = perms.shape[0]
while True:
idx: TIndex = perms[idx_raw % perms.shape[0]]
idx_raw += 1
num_attempts += 1

batch = self._get_batch(idx, num_forecast_steps)

# ensure the batch is valid, i.e. not completely empty and no NaN values
# student teacher has no classical targets
# Check for invalid batches: empty sources, NaN values, or empty targets
# (if applicable).
mode = self.mode_cfg.get("training_mode")
not_valid = batch.sources_empty() or batch.is_nan()
not_valid = not_valid or (batch.targets_empty() if "masking" in mode else False)

# skip completely empty batch item or when all targets are empty -> no grad
sources_empty = batch.sources_empty()
sources_nan = batch.is_nan()
targets_empty = (
not self.inference_only and "masking" in mode and batch.targets_empty()
)
not_valid = sources_empty or sources_nan or targets_empty
if not_valid:
logger.warning(f"Skipping empty batch with idx={idx}.")
if sources_empty:
logger.info(f"Skipping batch at idx={idx}: sources are empty.")
if sources_nan:
logger.info(f"Skipping batch at idx={idx}: sources contain NaN values.")
if targets_empty:
logger.info(
f"Skipping batch at idx={idx}: targets are empty "
"(inference_only=False, training_mode includes masking)."
)
# Skip invalid batches or raise an error if no valid batch is found
# after max_attempts.
if not_valid:
if num_attempts > max_attempts:
raise RuntimeError(
f"Could not find a valid non-empty batch after {num_attempts} "
"attempts. All data may be missing or targets unavailable"
" for this epoch."
)
else:
break

Expand Down
6 changes: 6 additions & 0 deletions src/weathergen/datasets/stream_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,8 @@ def add_target_coords(
target_coords: torch.Tensor,
target_coords_per_cell: torch.Tensor,
is_spoof: bool,
target_coords_raw=None,
times_raw=None,
) -> None:
"""
Add data for target for one input.
Expand Down Expand Up @@ -320,6 +322,10 @@ def add_target_coords(

self.target_coords[fstep] = target_coords
self.target_coords_lens[fstep] = target_coords_per_cell
if target_coords_raw is not None:
self.target_coords_raw[fstep] = target_coords_raw
if times_raw is not None:
self.target_times_raw[fstep] = times_raw

self.target_is_spoof[fstep] = is_spoof

Expand Down
4 changes: 2 additions & 2 deletions src/weathergen/datasets/tokenizer_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def get_target_coords(
)

# TODO: split up
_, _, _, coords_local, coords_per_cell = tokenize_apply_mask_target(
_, datetimes, coords_raw, coords_local, coords_per_cell = tokenize_apply_mask_target(
stream_info["stream_id"],
self.hl_target,
idxs_cells,
Expand All @@ -180,7 +180,7 @@ def get_target_coords(
encode_times_target,
)

return (coords_local, coords_per_cell)
return (coords_local, coords_per_cell, coords_raw, datetimes)

def get_target_values(
self,
Expand Down
63 changes: 47 additions & 16 deletions src/weathergen/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,8 @@ def validate(self, mini_epoch, mode_cfg, batch_size):

batch.to_device(self.device)

inference_only = self.cf.get("inference_only", False)

# evaluate model
with torch.autocast(
device_type=f"cuda:{cf.local_rank}",
Expand All @@ -584,20 +586,44 @@ def validate(self, mini_epoch, mode_cfg, batch_size):
)

targets_and_auxs = {}
for loss_name, target_aux in self.target_and_aux_calculators_val.items():
target_idxs = get_target_idxs_from_cfg(mode_cfg, loss_name)
targets_and_auxs[loss_name] = target_aux.compute(
self.cf.general.istep,
batch.get_target_samples(target_idxs),
self.model_params,
self.model,
)

_ = self.loss_calculator_val.compute_loss(
preds=preds,
targets_and_aux=targets_and_auxs,
metadata=extract_batch_metadata(batch),
)
if not inference_only:
for loss_name, target_aux in self.target_and_aux_calculators_val.items():
target_idxs = get_target_idxs_from_cfg(mode_cfg, loss_name)
targets_and_auxs[loss_name] = target_aux.compute(
self.cf.general.istep,
batch.get_target_samples(target_idxs),
self.model_params,
self.model,
)
else:
# In inference_only mode, targets are absent but we still need
# coordinate/time metadata for output writing. Compute targets_and_auxs
# from source samples, which carry target_coords_raw and
# target_times_raw (populated by get_target_coords in the sampler).
for loss_name, target_aux in self.target_and_aux_calculators_val.items():
tao = target_aux.compute(
self.cf.general.istep,
batch.get_source_samples(),
self.model_params,
self.model,
)
# Source stream data marks output steps as spoof because target
# files are absent by design in inference_only mode. Override
# is_spoof to False so write_output writes actual predictions
# rather than discarding them as corrupted validation data.
for step_dict in tao.physical:
for _sname, step_data in step_dict.items():
step_data["is_spoof"] = [False] * len(
step_data["is_spoof"]
)
targets_and_auxs[loss_name] = tao

if not inference_only:
_ = self.loss_calculator_val.compute_loss(
preds=preds,
targets_and_aux=targets_and_auxs,
metadata=extract_batch_metadata(batch),
)

# log output
if bidx < num_samples_write:
Expand Down Expand Up @@ -625,8 +651,13 @@ def validate(self, mini_epoch, mode_cfg, batch_size):
if (bidx * batch_size) > mode_cfg.samples_per_mini_epoch:
break

self._log_terminal(0, mini_epoch, VAL)
self._log(VAL)
if not inference_only:
self._log_terminal(0, mini_epoch, VAL)
self._log(VAL)
else:
logger.info(
f"inference_only=True: skipping loss/metric logging for epoch {mini_epoch}."
)

# avoid that there is a systematic bias in the validation subset
self.dataset_val.advance()
Expand Down
15 changes: 13 additions & 2 deletions src/weathergen/utils/validation_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,22 +83,33 @@ def write_output(
t_times = target_data["target_times"][i_batch]

idxs_inv = target_aux_out.physical[t_idx][sname]["idxs_inv"][i_batch]
if idxs_inv is not None:
if idxs_inv is not None and (
not isinstance(idxs_inv, torch.Tensor) or idxs_inv.numel() > 0
):
pred = pred[:, idxs_inv]
target = target[idxs_inv]
t_coords = t_coords[idxs_inv]
t_times = t_times[idxs_inv]

# denormalize data if requested and map to storage format
preds_s += [dn_data(sname, pred.to(fp32)).detach().cpu().numpy()]
# In inference_only mode, target tokens are empty (no channel dim); create a
# properly shaped empty array so downstream code handles it gracefully.
n_channels = pred.shape[-1]
if target.ndim == 1 and target.numel() == 0:
target = target.reshape(0, n_channels)
targets_s += [dn_data(sname, target.to(fp32)).detach().cpu().numpy()]

# extract original target coords and times from target data
t_coords_s += [t_coords.cpu().numpy()]
t_times_s += [t_times.astype("datetime64[ns]")]

targets_lens[-1] += [[]]
targets_lens[-1][-1] += [t.shape[0] for t in targets_s]
# Use coordinate count rather than target-value count so that in
# inference_only mode (where target tokens are empty but coords and
# predictions are non-empty) the correct number of output datapoints
# is indexed when writing predictions.
targets_lens[-1][-1] += [t.shape[0] for t in t_coords_s]

preds_all[-1] += [np.concatenate(preds_s, axis=1)]
targets_all[-1] += [np.concatenate(targets_s)]
Expand Down
Loading