Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 34 additions & 21 deletions src/weathergen/datasets/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,29 @@ def get_stream_data(self, stream_name: str) -> StreamData:
assert self.streams_data.get(stream_name, -1) != -1, "stream name does not exist"
return self.streams_data[stream_name]

def get_num_source_steps(self) -> int:
    """
    Get number of source steps from smallest of all available streams.

    Entries of ``self.streams_data`` that are ``None`` are ignored.

    Returns:
        The minimum of ``get_num_source_steps()`` over all non-``None``
        streams, or 0 when no such stream exists.
    """
    # Only the stream objects are needed, so iterate the values view rather
    # than unpacking (and discarding) the keys from items().
    return min(
        (
            stream.get_num_source_steps()
            for stream in self.streams_data.values()
            if stream is not None
        ),
        # min() raises ValueError on an empty iterable without a default.
        default=0,
    )

def get_num_target_steps(self) -> int:
    """
    Get number of target steps from smallest of all available streams.

    Entries of ``self.streams_data`` that are ``None`` are ignored.

    Returns:
        The minimum of ``get_num_target_steps()`` over all non-``None``
        streams, or 0 when no such stream exists.
    """
    # Only the stream objects are needed, so iterate the values view rather
    # than unpacking (and discarding) the keys from items().
    return min(
        (
            stream.get_num_target_steps()
            for stream in self.streams_data.values()
            if stream is not None
        ),
        # min() raises ValueError on an empty iterable without a default.
        default=0,
    )



class BatchSamples:
"""
Expand Down Expand Up @@ -183,16 +206,18 @@ def get_subset(self, subset: list | None = None):
bs.tokens_lens = torch.index_select(bs.tokens_lens, 1, torch_idxs)
return bs

def get_num_steps(self) -> int:
def get_num_source_steps(self) -> int:
"""
Get number of input/source steps from smallest of all available streams
"""
# TODO: define explicitly
lens = [
len(stream.source_tokens_cells) for _, stream in self.samples[0].streams_data.items()
]
return self.samples[0].get_num_source_steps()

def get_num_target_steps(self) -> int:
"""
Get number of target steps from smallest of all available streams
"""
return self.samples[0].get_num_target_steps()

return min(lens)

def get_output_idxs(self) -> int:
"""
Expand Down Expand Up @@ -472,23 +497,11 @@ def get_num_source_steps(self) -> int:
"""
Get number of input/source steps from smallest of all available streams
"""
# TODO: define explicitly
lens = [
len(stream.source_tokens_cells)
for _, stream in self.target_samples.samples[0].streams_data.items()
]

return min(lens)
return self.source_samples.get_num_source_steps()

def get_num_target_steps(self) -> int:
"""
Get number of input/source steps from smallest of all available streams
Get number of target steps from smallest of all available streams
"""
# TODO: define explicitly
# TODO: ensure that num_input_steps is constant across batch with different strategies
lens = [
len(stream.target_tokens)
for _, stream in self.target_samples.samples[0].streams_data.items()
]
return self.target_samples.get_num_target_steps()

return min(lens)
13 changes: 13 additions & 0 deletions src/weathergen/datasets/stream_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,19 @@ def is_spoof(self, step: int) -> bool:
"""
return any(self.source_is_spoof) or self.target_is_spoof[step]

def get_num_source_steps(self) -> int:
    """
    Return the number of input/source steps held by this stream.

    The count is the length of ``self.source_tokens_cells``.
    """
    source_steps = self.source_tokens_cells
    return len(source_steps)

def get_num_target_steps(self) -> int:
    """
    Return the number of target steps held by this stream.

    The count is the length of ``self.target_tokens``.
    """
    target_steps = self.target_tokens
    return len(target_steps)



def spoof(healpix_level: int, datetime, geoinfo_size, num_channels) -> IOReaderData:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/weathergen/model/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def assimilate_local(

cell_lens = torch.sum(batch.tokens_lens, 2).flatten()

num_steps_input = batch.get_num_steps()
num_steps_input = batch.get_num_source_steps()
rs = num_steps_input * len(batch)

# create register and latent tokens and prepend to latent spatial tokens
Expand Down
2 changes: 1 addition & 1 deletion src/weathergen/model/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(self, cf: Config, sources_size) -> None:
raise ValueError("Unsupported embedding network type")

def forward(self, batch, pe_embed):
num_steps_input = batch.get_num_steps()
num_steps_input = batch.get_num_source_steps()

num_tokens = torch.sum(batch.tokens_lens, 2).flatten().sum().item()
tokens_all = torch.empty(
Expand Down
2 changes: 1 addition & 1 deletion src/weathergen/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput:
output.add_latent_prediction(0, "posteriors", posteriors)

# recover batch dimension and separate input_steps
shape = (len(batch), batch.get_num_steps(), *tokens.shape[1:])
shape = (len(batch), batch.get_num_source_steps(), *tokens.shape[1:])
# collapse along input step dimension
tokens = tokens.reshape(shape).sum(axis=1)

Expand Down
Loading