Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions the_well/benchmark/trainer/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def temporal_split_losses(
new_losses[f"{dset_name}/{fname}_{loss_name}_T={time_str}"] = loss_subset
return new_losses

def split_up_losses(self, loss_values, loss_name, dset_name, field_names):
def split_up_losses(self, loss_values, loss_name, dset_name, field_names, return_time_logs=False):
new_losses = {}
time_logs = {}
time_steps = loss_values.shape[0] # we already average over batch
Expand All @@ -303,17 +303,19 @@ def split_up_losses(self, loss_values, loss_name, dset_name, field_names):
]
# Split up losses by field
for i, fname in enumerate(field_names):
time_logs[f"{dset_name}/{fname}_{loss_name}_rollout"] = loss_values[
:, i
].cpu()
if return_time_logs:
time_logs[f"{dset_name}/{fname}_{loss_name}_rollout"] = loss_values[
:, i
].cpu().detach()
new_losses |= self.temporal_split_losses(
loss_values[:, i], temporal_loss_intervals, loss_name, dset_name, fname
)
# Compute average over all fields
new_losses |= self.temporal_split_losses(
loss_values.mean(1), temporal_loss_intervals, loss_name, dset_name, "full"
)
time_logs[f"{dset_name}/full_{loss_name}_rollout"] = loss_values.mean(1).cpu()
if return_time_logs:
time_logs[f"{dset_name}/full_{loss_name}_rollout"] = loss_values.mean(1).cpu().detach()
return new_losses, time_logs

@torch.inference_mode()
Expand Down Expand Up @@ -356,17 +358,18 @@ def validation_loop(
if not isinstance(loss, dict):
loss = {loss_fn.__class__.__name__: loss}
# Split the losses and update the logging dictionary
is_last_batch = (i == denom - 1)
for k, v in loss.items():
sub_loss = v.mean(0)
new_losses, new_time_logs = self.split_up_losses(
sub_loss, k, dset_name, field_names
sub_loss, k, dset_name, field_names, return_time_logs=is_last_batch
)
# TODO get better way to include spectral error.
if k in long_time_metrics or "spectral_error" in k:
if is_last_batch and (k in long_time_metrics or "spectral_error" in k):
time_logs |= new_time_logs
for loss_name, loss_value in new_losses.items():
loss_dict[loss_name] = (
loss_dict.get(loss_name, 0.0) + loss_value / denom
loss_dict.get(loss_name, 0.0) + loss_value.detach() / denom
)
count += 1
if not full and count >= self.short_validation_length:
Expand Down
5 changes: 5 additions & 0 deletions the_well/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def train_dataloader(self) -> DataLoader:
shuffle=shuffle,
drop_last=True,
sampler=sampler,
persistent_workers=self.data_workers > 0,
)

def val_dataloader(self) -> DataLoader:
Expand Down Expand Up @@ -309,6 +310,7 @@ def val_dataloader(self) -> DataLoader:
shuffle=shuffle,
drop_last=True,
sampler=sampler,
persistent_workers=self.data_workers > 0,
)

def rollout_val_dataloader(self) -> DataLoader:
Expand Down Expand Up @@ -338,6 +340,7 @@ def rollout_val_dataloader(self) -> DataLoader:
shuffle=shuffle, # Shuffling because most batches we take a small subsample
drop_last=True,
sampler=sampler,
persistent_workers=self.data_workers > 0,
)

def test_dataloader(self) -> DataLoader:
Expand Down Expand Up @@ -366,6 +369,7 @@ def test_dataloader(self) -> DataLoader:
shuffle=False,
drop_last=True,
sampler=sampler,
persistent_workers=self.data_workers > 0,
)

def rollout_test_dataloader(self) -> DataLoader:
Expand Down Expand Up @@ -394,6 +398,7 @@ def rollout_test_dataloader(self) -> DataLoader:
shuffle=False,
drop_last=True,
sampler=sampler,
persistent_workers=self.data_workers > 0,
)

def __repr__(self) -> str:
Expand Down
125 changes: 68 additions & 57 deletions the_well/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,8 @@ def __init__(
self.files_paths = sub_files
self.files_paths.sort()
self.caches = [{} for _ in self.files_paths]
self._opened_files = {}
self._opened_files_pid = os.getpid()
# Build multi-index
self.metadata = self._build_metadata()
# Override name if necessary for logging
Expand Down Expand Up @@ -777,6 +779,21 @@ def _reconstruct_bcs(self, file: h5.File, cache, sample_idx, time_idx, n_steps,
else:
raise NotImplementedError()

def _get_file_handle(self, file_idx):
current_pid = os.getpid()
if current_pid != self._opened_files_pid:
self._opened_files = {}
self._opened_files_pid = current_pid

if file_idx not in self._opened_files:
f = self.fs.open(
self.files_paths[file_idx], "rb", **IO_PARAMS["fsspec_params"]
)
h5_file = h5.File(f, "r", **IO_PARAMS["h5py_params"])
self._opened_files[file_idx] = (f, h5_file)

return self._opened_files[file_idx][1]

def _load_one_sample(self, index):
# Find specific file and local index
if self.restriction_set is not None:
Expand All @@ -791,72 +808,66 @@ def _load_one_sample(self, index):
sample_idx = local_idx // windows_per_trajectory
time_idx = local_idx % windows_per_trajectory
# open hdf5 file (and cache the open object)
with h5.File(
self.fs.open(
self.files_paths[file_idx], "rb", **IO_PARAMS["fsspec_params"]
),
"r",
**IO_PARAMS["h5py_params"],
) as file:
# If we gave a stride range, decide the largest size we can use given the sample location
dt = self.min_dt_stride
if self.max_dt_stride > self.min_dt_stride:
effective_max_dt = maximum_stride_for_initial_index(
time_idx,
self.n_steps_per_trajectory[file_idx],
self.n_steps_input,
self.n_steps_output,
)
effective_max_dt = min(effective_max_dt, self.max_dt_stride)
if effective_max_dt > self.min_dt_stride:
# Randint is non-inclusive on the upper bound
dt = np.random.randint(self.min_dt_stride, effective_max_dt + 1)
# Fetch the data
data = {}

output_steps = min(self.n_steps_output, self.max_rollout_steps)
# If start_output_steps_at_t set, then work backwards for initial time index
if self.full_trajectory_mode and self.start_output_steps_at_t >= 0:
time_idx = self.start_output_steps_at_t - (self.n_steps_input) * dt

data["variable_fields"], data["constant_fields"] = self._reconstruct_fields(
file = self._get_file_handle(file_idx)
# If we gave a stride range, decide the largest size we can use given the sample location
dt = self.min_dt_stride
if self.max_dt_stride > self.min_dt_stride:
effective_max_dt = maximum_stride_for_initial_index(
time_idx,
self.n_steps_per_trajectory[file_idx],
self.n_steps_input,
self.n_steps_output,
)
effective_max_dt = min(effective_max_dt, self.max_dt_stride)
if effective_max_dt > self.min_dt_stride:
# Randint is non-inclusive on the upper bound
dt = np.random.randint(self.min_dt_stride, effective_max_dt + 1)
# Fetch the data
data = {}

output_steps = min(self.n_steps_output, self.max_rollout_steps)
# If start_output_steps_at_t set, then work backwards for initial time index
if self.full_trajectory_mode and self.start_output_steps_at_t >= 0:
time_idx = self.start_output_steps_at_t - (self.n_steps_input) * dt

data["variable_fields"], data["constant_fields"] = self._reconstruct_fields(
file,
self.caches[file_idx],
sample_idx,
time_idx,
self.n_steps_input + output_steps,
dt,
)
data["variable_scalars"], data["constant_scalars"] = (
self._reconstruct_scalars(
file,
self.caches[file_idx],
sample_idx,
time_idx,
self.n_steps_input + output_steps,
dt,
)
data["variable_scalars"], data["constant_scalars"] = (
self._reconstruct_scalars(
file,
self.caches[file_idx],
sample_idx,
time_idx,
self.n_steps_input + output_steps,
dt,
)
)
)

if self.boundary_return_type is not None:
data["boundary_conditions"] = self._reconstruct_bcs(
file,
self.caches[file_idx],
sample_idx,
time_idx,
self.n_steps_input + output_steps,
dt,
)
if self.boundary_return_type is not None:
data["boundary_conditions"] = self._reconstruct_bcs(
file,
self.caches[file_idx],
sample_idx,
time_idx,
self.n_steps_input + output_steps,
dt,
)

if self.return_grid:
data["space_grid"], data["time_grid"] = self._reconstruct_grids(
file,
self.caches[file_idx],
sample_idx,
time_idx,
self.n_steps_input + output_steps,
dt,
)
if self.return_grid:
data["space_grid"], data["time_grid"] = self._reconstruct_grids(
file,
self.caches[file_idx],
sample_idx,
time_idx,
self.n_steps_input + output_steps,
dt,
)
return data, file_idx, sample_idx, time_idx, dt

def _preprocess_data(
Expand Down