diff --git a/the_well/benchmark/trainer/training.py b/the_well/benchmark/trainer/training.py index 7d9e0ad..96f56fe 100755 --- a/the_well/benchmark/trainer/training.py +++ b/the_well/benchmark/trainer/training.py @@ -292,7 +292,7 @@ def temporal_split_losses( new_losses[f"{dset_name}/{fname}_{loss_name}_T={time_str}"] = loss_subset return new_losses - def split_up_losses(self, loss_values, loss_name, dset_name, field_names): + def split_up_losses(self, loss_values, loss_name, dset_name, field_names, return_time_logs=False): new_losses = {} time_logs = {} time_steps = loss_values.shape[0] # we already average over batch @@ -303,9 +303,10 @@ def split_up_losses(self, loss_values, loss_name, dset_name, field_names): ] # Split up losses by field for i, fname in enumerate(field_names): - time_logs[f"{dset_name}/{fname}_{loss_name}_rollout"] = loss_values[ - :, i - ].cpu() + if return_time_logs: + time_logs[f"{dset_name}/{fname}_{loss_name}_rollout"] = loss_values[ + :, i + ].cpu().detach() new_losses |= self.temporal_split_losses( loss_values[:, i], temporal_loss_intervals, loss_name, dset_name, fname ) @@ -313,7 +314,8 @@ def split_up_losses(self, loss_values, loss_name, dset_name, field_names): new_losses |= self.temporal_split_losses( loss_values.mean(1), temporal_loss_intervals, loss_name, dset_name, "full" ) - time_logs[f"{dset_name}/full_{loss_name}_rollout"] = loss_values.mean(1).cpu() + if return_time_logs: + time_logs[f"{dset_name}/full_{loss_name}_rollout"] = loss_values.mean(1).cpu().detach() return new_losses, time_logs @torch.inference_mode() @@ -356,17 +358,18 @@ def validation_loop( if not isinstance(loss, dict): loss = {loss_fn.__class__.__name__: loss} # Split the losses and update the logging dictionary + is_last_batch = (i == denom - 1) for k, v in loss.items(): sub_loss = v.mean(0) new_losses, new_time_logs = self.split_up_losses( - sub_loss, k, dset_name, field_names + sub_loss, k, dset_name, field_names, return_time_logs=is_last_batch ) # TODO get better way to include spectral error. - if k in long_time_metrics or "spectral_error" in k: + if is_last_batch and (k in long_time_metrics or "spectral_error" in k): time_logs |= new_time_logs for loss_name, loss_value in new_losses.items(): loss_dict[loss_name] = ( - loss_dict.get(loss_name, 0.0) + loss_value / denom + loss_dict.get(loss_name, 0.0) + loss_value.detach() / denom ) count += 1 if not full and count >= self.short_validation_length: diff --git a/the_well/data/datamodule.py b/the_well/data/datamodule.py index f3660a5..ece6a72 100755 --- a/the_well/data/datamodule.py +++ b/the_well/data/datamodule.py @@ -280,6 +280,7 @@ def train_dataloader(self) -> DataLoader: shuffle=shuffle, drop_last=True, sampler=sampler, + persistent_workers=self.data_workers > 0, ) def val_dataloader(self) -> DataLoader: @@ -309,6 +310,7 @@ def val_dataloader(self) -> DataLoader: shuffle=shuffle, drop_last=True, sampler=sampler, + persistent_workers=self.data_workers > 0, ) def rollout_val_dataloader(self) -> DataLoader: @@ -338,6 +340,7 @@ def rollout_val_dataloader(self) -> DataLoader: shuffle=shuffle, # Shuffling because most batches we take a small subsample drop_last=True, sampler=sampler, + persistent_workers=self.data_workers > 0, ) def test_dataloader(self) -> DataLoader: @@ -366,6 +369,7 @@ def test_dataloader(self) -> DataLoader: shuffle=False, drop_last=True, sampler=sampler, + persistent_workers=self.data_workers > 0, ) def rollout_test_dataloader(self) -> DataLoader: @@ -394,6 +398,7 @@ def rollout_test_dataloader(self) -> DataLoader: shuffle=False, drop_last=True, sampler=sampler, + persistent_workers=self.data_workers > 0, ) def __repr__(self) -> str: diff --git a/the_well/data/datasets.py b/the_well/data/datasets.py index 83c5115..ae13e0b 100755 --- a/the_well/data/datasets.py +++ b/the_well/data/datasets.py @@ -322,6 +322,8 @@ def __init__( self.files_paths = sub_files self.files_paths.sort() self.caches = [{} for _ in self.files_paths] + self._opened_files = {} + self._opened_files_pid = os.getpid() # Build multi-index self.metadata = self._build_metadata() # Override name if necessary for logging @@ -777,6 +779,21 @@ def _reconstruct_bcs(self, file: h5.File, cache, sample_idx, time_idx, n_steps, else: raise NotImplementedError() + def _get_file_handle(self, file_idx): + current_pid = os.getpid() + if current_pid != self._opened_files_pid: + self._opened_files = {} + self._opened_files_pid = current_pid + + if file_idx not in self._opened_files: + f = self.fs.open( + self.files_paths[file_idx], "rb", **IO_PARAMS["fsspec_params"] + ) + h5_file = h5.File(f, "r", **IO_PARAMS["h5py_params"]) + self._opened_files[file_idx] = (f, h5_file) + + return self._opened_files[file_idx][1] + def _load_one_sample(self, index): # Find specific file and local index if self.restriction_set is not None: @@ -791,35 +808,38 @@ def _load_one_sample(self, index): sample_idx = local_idx // windows_per_trajectory time_idx = local_idx % windows_per_trajectory # open hdf5 file (and cache the open object) - with h5.File( - self.fs.open( - self.files_paths[file_idx], "rb", **IO_PARAMS["fsspec_params"] - ), - "r", - **IO_PARAMS["h5py_params"], - ) as file: - # If we gave a stride range, decide the largest size we can use given the sample location - dt = self.min_dt_stride - if self.max_dt_stride > self.min_dt_stride: - effective_max_dt = maximum_stride_for_initial_index( - time_idx, - self.n_steps_per_trajectory[file_idx], - self.n_steps_input, - self.n_steps_output, - ) - effective_max_dt = min(effective_max_dt, self.max_dt_stride) - if effective_max_dt > self.min_dt_stride: - # Randint is non-inclusive on the upper bound - dt = np.random.randint(self.min_dt_stride, effective_max_dt + 1) - # Fetch the data - data = {} - - output_steps = min(self.n_steps_output, self.max_rollout_steps) - # If start_output_steps_at_t set, then work backwards for initial time index - if self.full_trajectory_mode and self.start_output_steps_at_t >= 0: - time_idx = self.start_output_steps_at_t - (self.n_steps_input) * dt - - data["variable_fields"], data["constant_fields"] = self._reconstruct_fields( + file = self._get_file_handle(file_idx) + # If we gave a stride range, decide the largest size we can use given the sample location + dt = self.min_dt_stride + if self.max_dt_stride > self.min_dt_stride: + effective_max_dt = maximum_stride_for_initial_index( + time_idx, + self.n_steps_per_trajectory[file_idx], + self.n_steps_input, + self.n_steps_output, + ) + effective_max_dt = min(effective_max_dt, self.max_dt_stride) + if effective_max_dt > self.min_dt_stride: + # Randint is non-inclusive on the upper bound + dt = np.random.randint(self.min_dt_stride, effective_max_dt + 1) + # Fetch the data + data = {} + + output_steps = min(self.n_steps_output, self.max_rollout_steps) + # If start_output_steps_at_t set, then work backwards for initial time index + if self.full_trajectory_mode and self.start_output_steps_at_t >= 0: + time_idx = self.start_output_steps_at_t - (self.n_steps_input) * dt + + data["variable_fields"], data["constant_fields"] = self._reconstruct_fields( + file, + self.caches[file_idx], + sample_idx, + time_idx, + self.n_steps_input + output_steps, + dt, + ) + data["variable_scalars"], data["constant_scalars"] = ( + self._reconstruct_scalars( file, self.caches[file_idx], sample_idx, @@ -827,36 +847,27 @@ def _load_one_sample(self, index): self.n_steps_input + output_steps, dt, ) - data["variable_scalars"], data["constant_scalars"] = ( - self._reconstruct_scalars( - file, - self.caches[file_idx], - sample_idx, - time_idx, - self.n_steps_input + output_steps, - dt, - ) - ) + ) - if self.boundary_return_type is not None: - data["boundary_conditions"] = self._reconstruct_bcs( - file, - self.caches[file_idx], - sample_idx, - time_idx, - self.n_steps_input + output_steps, - dt, - ) + if self.boundary_return_type is not None: + data["boundary_conditions"] = self._reconstruct_bcs( + file, + self.caches[file_idx], + sample_idx, + time_idx, + self.n_steps_input + output_steps, + dt, + ) - if self.return_grid: - data["space_grid"], data["time_grid"] = self._reconstruct_grids( - file, - self.caches[file_idx], - sample_idx, - time_idx, - self.n_steps_input + output_steps, - dt, - ) + if self.return_grid: + data["space_grid"], data["time_grid"] = self._reconstruct_grids( + file, + self.caches[file_idx], + sample_idx, + time_idx, + self.n_steps_input + output_steps, + dt, + ) return data, file_idx, sample_idx, time_idx, dt def _preprocess_data(