feat: chunked streaming for GPU memory reduction #121
Merged
Conversation
vijk777 (Collaborator) commented on Jan 22, 2026
- Implemented chunked streaming: load 64K chunks on demand instead of the full dataset
- Background thread prefetches chunks while the GPU trains
- Prefetch buffer size: 6 chunks for optimal overlap
- Unit test to check the async logic
- Script to run on the YouTube dataset with a mock computation
- Implements RandomChunkLoader for streaming random 64K chunks
- Uses a background thread for disk I/O into CPU pinned memory
- Uses CUDA streams for async CPU→GPU transfer
- Enables overlap of disk I/O, transfer, and training
- Comprehensive unit tests with mock training (sleeps)

Design:
- Loads random overlapping windows (no chunk boundaries)
- Queue-based producer-consumer pattern
- batches_per_chunk = chunk_size / batch_size
- chunks_per_epoch = dataset_size * passes_per_epoch / chunk_size
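A minimal sketch of that producer-consumer design, assuming a `load_fn(start, end)` callback for the disk read; the class name follows the commit, but the constructor arguments and method signatures are illustrative, not the actual API:

```python
import queue
import threading

import numpy as np
import torch


class RandomChunkLoader:
    """Sketch: a background thread streams random chunks from disk while the GPU trains."""

    def __init__(self, load_fn, total_timesteps, chunk_size=65_536, prefetch=6, device="cuda"):
        self.load_fn = load_fn                      # load_fn(start, end) -> np.ndarray read from disk
        self.total_timesteps = total_timesteps
        self.chunk_size = chunk_size
        self.device = device
        self.queue = queue.Queue(maxsize=prefetch)  # prefetch buffer (producer-consumer)
        self.stop_flag = threading.Event()
        self.thread = None
        self.copy_stream = torch.cuda.Stream() if device == "cuda" else None

    def _producer(self, n_chunks):
        max_start_idx = max(self.total_timesteps - self.chunk_size, 0)
        for _ in range(n_chunks):
            if self.stop_flag.is_set():
                break
            start = np.random.randint(0, max_start_idx + 1)           # random overlapping window
            end = min(start + self.chunk_size, self.total_timesteps)  # clamp to the dataset
            chunk = torch.from_numpy(self.load_fn(start, end))
            if self.device == "cuda":
                chunk = chunk.pin_memory()          # pinned CPU memory enables async H2D copies
            self.queue.put(chunk)
        self.queue.put(None)                        # end-of-epoch sentinel

    def start_epoch(self, n_chunks):
        self.stop_flag.clear()
        self.thread = threading.Thread(target=self._producer, args=(n_chunks,), daemon=True)
        self.thread.start()

    def get_next_chunk(self):
        chunk = self.queue.get()                    # blocks only if the producer fell behind
        if chunk is None or self.copy_stream is None:
            return chunk
        with torch.cuda.stream(self.copy_stream):   # async CPU -> GPU transfer on a side stream
            gpu_chunk = chunk.to(self.device, non_blocking=True)
        torch.cuda.current_stream().wait_stream(self.copy_stream)
        return gpu_chunk
```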
- Use a fixed 200ms load time and 100ms train time
- Increase the prefetch buffer to 10 to avoid blocking the background thread
- Expected: 1.5x speedup (1500ms sequential -> 1000ms with overlap)
- Makes the test more predictable and easier to debug
NumPy random generation was taking ~200ms per chunk, dominating the sleep delay and making the overlap test unpredictable. Now using np.zeros (instant) so the load time is purely the sleep delay.
- Both 100ms now (was 200ms load, 100ms train)
- Sequential: 5 * (100 + 100) = 1000ms
- With overlap: ~500ms (2x speedup expected)
- Test passes if the total is < 650ms (0.65 threshold)
- Amortizes cold-start overhead over more iterations
- Sequential: 20 * (100 + 100) = 4000ms
- With overlap: ~2000ms (2x speedup expected)
- Should show a clearer overlap benefit
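A hedged sketch of this mock-overlap test, reusing the RandomChunkLoader sketch above on CPU; the `mock_load` helper is illustrative, and the 0.65 threshold is carried over from the previous commit as an assumption:

```python
import time

import numpy as np


def mock_load(start, end):
    time.sleep(0.1)                                       # fixed 100ms simulated disk read
    return np.zeros((end - start, 8), dtype=np.float32)   # np.zeros so load time is purely the sleep


def test_overlap_speedup():
    n_chunks = 20
    loader = RandomChunkLoader(mock_load, total_timesteps=n_chunks * 65_536,
                               prefetch=10, device="cpu")
    loader.start_epoch(n_chunks)

    t0 = time.perf_counter()
    while loader.get_next_chunk() is not None:
        time.sleep(0.1)                                   # 100ms mock training step
    elapsed = time.perf_counter() - t0

    sequential = n_chunks * (0.1 + 0.1)                   # 4000ms if nothing overlaps
    assert elapsed < 0.65 * sequential                    # ~2000ms expected with overlap
```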
Implements:
- ChunkLatencyStats: tracks chunk get time and batch forward/backward/step times
- create_zarr_loader: wraps zarr loading for RandomChunkLoader
- sample_batch_within_chunk: samples batches within the current chunk
- calculate_chunk_params: computes chunks_per_epoch and batches_per_chunk

Documentation:
- CHUNKED_STREAMING_INTEGRATION.md: step-by-step guide
- Shows memory savings: 60 GB -> 8 GB (87% reduction)
- Includes expected latencies and tuning parameters

Ready to integrate into the latent.py training loop.
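A rough sketch of these helpers, following the names in the commit; the exact signatures and fields are assumptions (calculate_chunk_params is sketched after the small-dataset fix further down):

```python
from collections import defaultdict
from dataclasses import dataclass, field

import torch


@dataclass
class ChunkLatencyStats:
    """Accumulates per-phase wall-clock times: chunk_get, forward, backward, step."""
    times: dict = field(default_factory=lambda: defaultdict(list))

    def record(self, phase, seconds):
        self.times[phase].append(seconds)

    def means(self):
        return {phase: sum(v) / len(v) for phase, v in self.times.items() if v}


def create_zarr_loader(zarr_array):
    """Wrap a zarr array as the load_fn(start, end) callback used by RandomChunkLoader."""
    def load_fn(start, end):
        return zarr_array[start:end]        # reads only this window from disk
    return load_fn


def sample_batch_within_chunk(chunk, batch_size):
    """Uniformly sample batch_size timesteps from the chunk currently in memory."""
    idx = torch.randint(0, chunk.shape[0], (batch_size,), device=chunk.device)
    return chunk[idx]
```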
Changes:
- Modified load_dataset to return chunk_loader instead of the full training data
- Updated reconstruction warmup to use chunked iteration
- Replaced the batch iterator with chunked iteration in the main training loop
- Added ChunkLatencyStats tracking for chunk_get, forward, backward, and step times
- Log latency stats to TensorBoard every epoch
- Clean up chunk_loader at the end of training
- Removed the train_loss_constant_model baseline (requires the full dataset)

Memory savings:
- Before: ~60 GB (full training data on GPU)
- After: ~8 GB (2 chunk buffers + validation data)
- 87% reduction in GPU memory usage
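Roughly, the loop shape this implies, with latency tracking attached; this is a sketch, not the actual latent.py code, and it assumes `model(batch)` returns a scalar loss and that setting `stop_flag` is how cleanup happens:

```python
import time


def train_with_chunked_streaming(model, optimizer, chunk_loader, writer,
                                 n_epochs, chunks_per_epoch, batches_per_chunk, batch_size):
    """Chunked training loop with per-phase latency tracking (sketch)."""
    stats = ChunkLatencyStats()
    for epoch in range(n_epochs):
        chunk_loader.start_epoch(chunks_per_epoch)
        while True:
            t0 = time.perf_counter()
            chunk = chunk_loader.get_next_chunk()
            stats.record("chunk_get", time.perf_counter() - t0)
            if chunk is None:
                break                                     # end-of-epoch sentinel
            for _ in range(batches_per_chunk):
                batch = sample_batch_within_chunk(chunk, batch_size)

                t0 = time.perf_counter()
                loss = model(batch)
                stats.record("forward", time.perf_counter() - t0)

                t0 = time.perf_counter()
                loss.backward()
                stats.record("backward", time.perf_counter() - t0)

                t0 = time.perf_counter()
                optimizer.step()
                optimizer.zero_grad()
                stats.record("step", time.perf_counter() - t0)
        writer.add_scalars("latency", stats.means(), epoch)  # log latency stats once per epoch
    chunk_loader.stop_flag.set()                             # cleanup at the end of training
```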
Fixes:
- Add a safety check in LossComponents.mean() for count == 0
- Add validation in calculate_chunk_params() to ensure:
  - batches_per_chunk > 0 (chunk_size >= batch_size)
  - batches_per_epoch > 0 (dataset not too small)
  - chunks_per_epoch > 0 (enough batches for at least one chunk)

The bug occurred when warmup completed without processing any batches, causing count == 0 and a division by zero in mean().
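The count == 0 guard might look like the following; the LossComponents fields are assumptions, and the calculate_chunk_params validation is folded into the sketch after the next commit:

```python
from dataclasses import dataclass


@dataclass
class LossComponents:
    total: float = 0.0
    count: int = 0

    def mean(self):
        # Safety check: warmup can finish without processing a single batch.
        return self.total / self.count if self.count > 0 else 0.0
```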
Changes:
- Remove the validation error when chunk_size > total_timesteps
- Set max_start_idx = 0 when the dataset is smaller than chunk_size (always loads the full dataset)
- In calculate_chunk_params, if the dataset is smaller than chunk_size:
  - chunks_per_epoch = data_passes_per_epoch (load once per pass)
  - batches_per_chunk = total_timesteps // batch_size

Now works with small test datasets where total_timesteps < 65536.
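Putting the two commits above together, calculate_chunk_params might look roughly like this; the argument names and tuple return are assumptions, while the formulas follow the earlier commit:

```python
def calculate_chunk_params(total_timesteps, chunk_size, batch_size, data_passes_per_epoch):
    """Compute (chunks_per_epoch, batches_per_chunk) with small-dataset handling (sketch)."""
    if total_timesteps < chunk_size:
        # Dataset smaller than one chunk: each "chunk" is the full dataset, loaded once per pass.
        chunks_per_epoch = data_passes_per_epoch
        batches_per_chunk = total_timesteps // batch_size
    else:
        batches_per_chunk = chunk_size // batch_size
        batches_per_epoch = total_timesteps * data_passes_per_epoch // batch_size
        chunks_per_epoch = batches_per_epoch // batches_per_chunk

    if batches_per_chunk <= 0:
        raise ValueError("chunk_size (or the dataset) must hold at least one batch")
    if chunks_per_epoch <= 0:
        raise ValueError("not enough batches per epoch for a single chunk")
    return chunks_per_epoch, batches_per_chunk
```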
- test_multiple_epochs_back_to_back: verifies that consecutive epochs work
- test_early_break_from_epoch: simulates warmup breaking out early

Both tests currently fail, demonstrating thread-cleanup issues that need to be fixed in chunk_loader.py.
Fixes three issues:
1. Handle an early break from an epoch: stop the thread gracefully instead of raising an error
2. Clear the queue between epochs to remove the None sentinel
3. Update the test to reflect the new behavior (chunk_size > dataset is allowed)

Changes:
- start_epoch() now stops still-alive threads by setting stop_flag and draining the queue
- Clears the queue before starting a new epoch
- Renamed test_chunk_size_validation -> test_chunk_size_larger_than_dataset
- The test now verifies that chunk_size > dataset returns the full dataset
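In the RandomChunkLoader sketch above, the graceful-stop behavior could look something like this; the drain-then-join order is an assumption about how the fix works:

```python
def start_epoch(self, n_chunks):
    # If a producer from a previous (possibly abandoned) epoch is still alive,
    # ask it to stop and keep draining the queue so a blocked put() can complete.
    if self.thread is not None and self.thread.is_alive():
        self.stop_flag.set()
        while self.thread.is_alive():
            try:
                self.queue.get(timeout=0.1)
            except queue.Empty:
                pass
        self.thread.join()

    # Clear leftover items (including the None end-of-epoch sentinel) before starting fresh.
    while not self.queue.empty():
        self.queue.get_nowait()

    self.stop_flag.clear()
    self.thread = threading.Thread(target=self._producer, args=(n_chunks,), daemon=True)
    self.thread.start()
```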
When chunk_size > dataset, end_idx was exceeding total_timesteps. It is now clamped: end_idx = min(start_idx + chunk_size, total_timesteps). Fixes test_chunk_size_larger_than_dataset.
Tests chunk prefetching with real zarr data and 3.7s of simulated training.

Measures:
- Queue size before each get_next_chunk()
- Chunk get times
- Queue empty/full counts

Tests prefetch values of 2, 4, and 6 to diagnose why background loading isn't staying ahead during training.
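A sketch of what such a diagnostic could look like on top of the loader sketch above; the 3.7s sleep and the prefetch values come from the commit, while the zarr path and store layout are placeholders:

```python
import time

import zarr

dataset = zarr.open("path/to/dataset.zarr", mode="r")     # placeholder path to a zarr array
load_fn = create_zarr_loader(dataset)

for prefetch in (2, 4, 6):
    loader = RandomChunkLoader(load_fn, total_timesteps=dataset.shape[0],
                               prefetch=prefetch, device="cpu")
    loader.start_epoch(n_chunks=20)

    get_times, empty_count, full_count = [], 0, 0
    while True:
        qsize = loader.queue.qsize()                       # queue size before each get
        empty_count += int(qsize == 0)
        full_count += int(qsize >= prefetch)

        t0 = time.perf_counter()
        chunk = loader.get_next_chunk()
        get_times.append(time.perf_counter() - t0)
        if chunk is None:
            break
        time.sleep(3.7)                                    # simulated training per chunk

    print(f"prefetch={prefetch}: mean get {1000 * sum(get_times) / len(get_times):.0f} ms, "
          f"queue empty {empty_count}/{len(get_times)}, full {full_count}/{len(get_times)}")
```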
FlyVisSim is in LatentEvolution.load_flyvis, not NeuralGraph.FlyVis
Changes:
- Added simulate_gpu_training(), which does matmuls on the loaded chunks
- Uses the actual chunk data for the computation (realistic GPU memory pressure)
- Changed from a 3.7s sleep to 1s of GPU computation
- The loader now transfers chunks to the CUDA device
- Prints iterations per chunk to verify the computation is happening
- Requires CUDA to run

This better simulates the real training workflow: disk → CPU → GPU → compute.
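A hedged sketch of simulate_gpu_training(); the 256-wide weight matrix and the returned iteration count are assumptions:

```python
import time

import torch


def simulate_gpu_training(chunk, target_seconds=1.0):
    """Busy the GPU with matmuls on the actual chunk data for roughly target_seconds."""
    assert chunk.is_cuda, "requires CUDA: the loader should have moved the chunk to the GPU"
    x = chunk.reshape(chunk.shape[0], -1).float()
    w = torch.randn(x.shape[1], 256, device=chunk.device)    # small weight keeps memory bounded
    iterations = 0
    t0 = time.perf_counter()
    while time.perf_counter() - t0 < target_seconds:
        _ = x @ w                                            # compute on the loaded chunk itself
        torch.cuda.synchronize()                             # so loop timing reflects real GPU work
        iterations += 1
    print(f"{iterations} matmul iterations on this chunk")   # verify computation is happening
    return iterations
```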
Based on diagnostics showing:
- prefetch=2: queue empty 100% of the time, 411ms mean get time
- prefetch=6: queue has items 90% of the time, 162ms mean get time

Expected improvements:
- Reduce chunk overhead from 11.5s to 4.5s per epoch (~7s savings)
- 6% speedup overall (~12 min saved per 100 epochs)
- CPU memory: +14 GB (21 GB total for the prefetch buffer)