From 483a1369f7508b3eae6289e67dfe79151e802d0d Mon Sep 17 00:00:00 2001 From: TillHae Date: Mon, 11 May 2026 15:53:27 +0200 Subject: [PATCH 1/2] call BatchSample method instead of repeating content --- src/weathergen/datasets/batch.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index 6c4c0f913..78456bfba 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -472,13 +472,7 @@ def get_num_source_steps(self) -> int: """ Get number of input/source steps from smallest of all available streams """ - # TODO: define explicitly - lens = [ - len(stream.source_tokens_cells) - for _, stream in self.target_samples.samples[0].streams_data.items() - ] - - return min(lens) + return self.target_samples.get_num_steps() def get_num_target_steps(self) -> int: """ From d79f737b3b2b269a2347551f4d19c9de4bb7e7e0 Mon Sep 17 00:00:00 2001 From: TillHae Date: Mon, 18 May 2026 08:59:15 +0200 Subject: [PATCH 2/2] improve object-oriented design of batch architecture --- src/weathergen/datasets/batch.py | 49 ++++++++++++++++++-------- src/weathergen/datasets/stream_data.py | 13 +++++++ src/weathergen/model/encoder.py | 2 +- src/weathergen/model/engines.py | 2 +- src/weathergen/model/model.py | 2 +- 5 files changed, 50 insertions(+), 18 deletions(-) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index 78456bfba..9d17f0555 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -134,6 +134,29 @@ def get_stream_data(self, stream_name: str) -> StreamData: assert self.streams_data.get(stream_name, -1) != -1, "stream name does not exist" return self.streams_data[stream_name] + def get_num_source_steps(self) -> int: + """ + Get number of source steps from smallest of all available streams + """ + lens = [ + stream.get_num_source_steps() + for _, stream in self.streams_data.items() + if stream is not None + ] + return min(lens) if lens else 0 + + def get_num_target_steps(self) -> int: + """ + Get number of target steps from smallest of all available streams + """ + lens = [ + stream.get_num_target_steps() + for _, stream in self.streams_data.items() + if stream is not None + ] + return min(lens) if lens else 0 + + class BatchSamples: """ @@ -183,16 +206,18 @@ def get_subset(self, subset: list | None = None): bs.tokens_lens = torch.index_select(bs.tokens_lens, 1, torch_idxs) return bs - def get_num_steps(self) -> int: + def get_num_source_steps(self) -> int: """ Get number of input/source steps from smallest of all available streams """ - # TODO: define explicitly - lens = [ - len(stream.source_tokens_cells) for _, stream in self.samples[0].streams_data.items() - ] + return self.samples[0].get_num_source_steps() + + def get_num_target_steps(self) -> int: + """ + Get number of target steps from smallest of all available streams + """ + return self.samples[0].get_num_target_steps() - return min(lens) def get_output_idxs(self) -> int: """ @@ -472,17 +497,11 @@ def get_num_source_steps(self) -> int: """ Get number of input/source steps from smallest of all available streams """ - return self.target_samples.get_num_steps() + return self.source_samples.get_num_source_steps() def get_num_target_steps(self) -> int: """ - Get number of input/source steps from smallest of all available streams + Get number of target steps from smallest of all available streams """ - # TODO: define explicitly - # TODO: ensure that num_input_steps is constant across batch with different strategies - lens = [ - len(stream.target_tokens) - for _, stream in self.target_samples.samples[0].streams_data.items() - ] + return self.target_samples.get_num_target_steps() - return min(lens) diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index e993a0f03..16c71b64c 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -438,6 +438,19 @@ def is_spoof(self, step: int) -> bool: """ return any(self.source_is_spoof) or self.target_is_spoof[step] + def get_num_source_steps(self) -> int: + """ + Get number of input/source steps + """ + return len(self.source_tokens_cells) + + def get_num_target_steps(self) -> int: + """ + Get number of target steps + """ + return len(self.target_tokens) + + def spoof(healpix_level: int, datetime, geoinfo_size, num_channels) -> IOReaderData: """ diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index 54409e297..00fd1627a 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -289,7 +289,7 @@ def assimilate_local( cell_lens = torch.sum(batch.tokens_lens, 2).flatten() - num_steps_input = batch.get_num_steps() + num_steps_input = batch.get_num_source_steps() rs = num_steps_input * len(batch) # create register and latent tokens and prepend to latent spatial tokens diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 72486da2f..c57bef513 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -80,7 +80,7 @@ def __init__(self, cf: Config, sources_size) -> None: raise ValueError("Unsupported embedding network type") def forward(self, batch, pe_embed): - num_steps_input = batch.get_num_steps() + num_steps_input = batch.get_num_source_steps() num_tokens = torch.sum(batch.tokens_lens, 2).flatten().sum().item() tokens_all = torch.empty( diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index ce046d3b3..e4c9f3d32 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -699,7 +699,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: output.add_latent_prediction(0, "posteriors", posteriors) # recover batch dimension and separate input_steps - shape = (len(batch), batch.get_num_steps(), *tokens.shape[1:]) + shape = (len(batch), batch.get_num_source_steps(), *tokens.shape[1:]) # collapse along input step dimension tokens = tokens.reshape(shape).sum(axis=1)