perf: optimize data pipeline, HDF5 caching, and trainer CPU syncs #84
Open
Chamath-Adithya wants to merge 1 commit into
PR: High-Performance Data Pipeline and I/O Optimization for Large-Scale Physical Systems Training
Abstract
This Pull Request introduces three critical high-performance optimizations targeting the unified PyTorch dataset interfaces (WellDataset), data modularity (WellDataModule), and validation metrics loops within the_well. By implementing a process-safe HDF5 file handle cache, enabling persistent DataLoader worker processes, and pruning redundant GPU-to-CPU metric copies, we successfully eliminate key I/O bottlenecks and memory leaks during multi-epoch training and validation rollouts without modifying core physics data integrity.
Technical Context & Architectural Inefficiencies
1. High-Frequency File Open/Close Overhead (I/O Bottleneck)
In the baseline WellDataset._load_one_sample implementation, the HDF5 file was opened and parsed for every single index retrieval inside a with h5.File(...) as file: block.
In deep learning dataloaders with standard batch sizes, this causes hundreds of redundant file-open, header-parsing, and socket/file-descriptor creation operations per training step. The pattern creates a storage-level latency bottleneck, especially on network or distributed file systems accessed through fsspec.
2. Multi-Epoch Process Re-Initialization Penalty
Because PyTorch's standard DataLoader defaults to persistent_workers=False, the background worker processes are shut down once an epoch's iterator is exhausted and recreated at the start of the next epoch. This discards any file handles cached in worker memory and forces each new process to re-initialize its metadata structures, so a cache alone brings little benefit across epochs.
3. Validation Rollout I/O Synchronization Blocking
During evaluation rollouts, split_up_losses compiled temporal loss statistics and immediately called .cpu() on the time logs for every single batch.
Because CUDA operations are queued asynchronously, calling .cpu() forces a blocking device-to-host (GPU-to-CPU) synchronization: the host must wait for all pending GPU work on that tensor to finish before the copy can complete. Furthermore, since the validation visualization functions run only on the last batch of an epoch, the CPU metrics from every intermediate batch were overwritten and discarded. Those per-batch synchronous transfers were therefore redundant and stalled the GPU.
Implemented Solutions
A. Lazy Process-Safe File Handle Caching (datasets.py)
Introduced a lazily-initialized file-handle manager (_get_file_handle) that caches open h5py.File objects and the underlying fsspec descriptors. To prevent descriptor leakage and concurrency collisions across multi-worker sub-processes under fork or spawn boundaries, the cache is checked and cleared whenever a change in process ID is detected.
B. Persistent Worker Processes (datamodule.py)
Enabled persistent_workers=self.data_workers > 0 across all five baseline dataloaders. The guard is needed because DataLoader rejects persistent_workers=True when num_workers is 0. By maintaining worker state between epochs, background workers keep their cached file descriptors and metadata warm, bypassing the epoch-boundary initialization overhead.
C. Validation I/O Pruning and Detached Metric Tensors (training.py)
- Added a return_time_logs Boolean flag to split_up_losses. It is resolved dynamically as is_last_batch = (i == denom - 1).
- CPU transfers (.cpu()) and memory allocation are now executed only on the last batch of validation, preserving CPU-GPU overlap and avoiding blocking syncs for all other batches.
- Entries in loss_dict are explicitly detached (.detach()) so that no lingering computational-graph dependencies remain in CUDA memory.
Performance & Memory Impact
Previously, .cpu() transfers were forced on every batch; they now run only on the last validation batch, and loss_dict entries are detached (.detach()).
All package dataset interfaces and model setups have been statically validated. Boundary conditions, tensor formats, and physical coordinate-grid metrics remain unchanged.
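To make the idea behind the file-handle caching in section A concrete, here is a minimal sketch of a PID-aware handle cache. It is not the code from this PR: the class name PidAwareHandleCache, the opener argument, and the open_count counter are hypothetical names chosen for illustration, and the sketch uses only the standard library so it runs without h5py. In the_well the opener would presumably be something like lambda p: h5py.File(p, "r").

```python
import os
from typing import Any, Callable, Dict


class PidAwareHandleCache:
    """Cache one open handle per path; forget handles after a PID change.

    Hypothetical illustration only, not the PR's _get_file_handle.
    """

    def __init__(self, opener: Callable[[str], Any]):
        self._opener = opener              # assumed: e.g. lambda p: h5py.File(p, "r")
        self._handles: Dict[str, Any] = {}
        self._owner_pid = os.getpid()      # process that populated the cache
        self.open_count = 0                # demo instrumentation, not needed in practice

    def get(self, path: str) -> Any:
        pid = os.getpid()
        if pid != self._owner_pid:
            # Running in a different process (e.g. a forked DataLoader worker):
            # drop inherited handles rather than reuse them across processes.
            self._handles = {}
            self._owner_pid = pid
        handle = self._handles.get(path)
        if handle is None:
            # Lazy open: the file is opened on first access only.
            handle = self._opener(path)
            self._handles[path] = handle
            self.open_count += 1
        return handle

    def close(self) -> None:
        # Close every handle owned by the current process.
        for handle in self._handles.values():
            handle.close()
        self._handles.clear()
```

With a cache like this, repeated get(path) calls in the same process reuse one handle, while a worker that inherits the object across a fork reopens the file on its first access. The handles are only kept warm across epochs if the workers themselves survive, which is what the persistent_workers change in section B provides.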