
Conversation

@vijk777 vijk777 commented Jan 20, 2026

Summary

  • Remove per-epoch train_step evaluation on validation data (now only rollout diagnostics)
  • Stream training data from CPU to GPU per batch to reduce GPU memory usage
  • Remove deprecated ems_warmup_epochs feature

Details

Training data and stimulus are kept on CPU with pin_memory enabled for fast asynchronous transfer. For each batch, a window is extracted and transferred to the GPU. Validation data stays on GPU since it is small.

This reduces GPU memory usage because the full training dataset is no longer loaded onto the GPU upfront.
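
As a rough illustration, here is a minimal sketch of the streaming pattern, assuming PyTorch; the tensor names, shapes, and the `transfer_window` helper are placeholders, not the actual code:

```python
import torch

device = torch.device("cuda")

# The full training set stays on CPU; pinned memory lets the later .to(device)
# copies run asynchronously with respect to GPU compute.
data = torch.randn(100_000, 256).pin_memory()  # hypothetical shapes
stim = torch.randn(100_000, 32).pin_memory()

def transfer_window(start: int, length: int):
    """Slice a contiguous window on CPU and copy it to the GPU asynchronously."""
    d = data[start:start + length].to(device, non_blocking=True)
    s = stim[start:start + length].to(device, non_blocking=True)
    return d, s
```

Slices of a pinned tensor are views of the same pinned storage, so `non_blocking=True` can overlap the host-to-device copy with ongoing GPU work.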

Test plan

  • Run training with new streaming approach
  • Verify GPU memory usage is reduced
  • Confirm training metrics are comparable

🤖 Generated with Claude Code

vijk777 and others added 7 commits January 20, 2026 14:55
validation metrics are now only computed via diagnostics (rollout stats).
this removes the redundant per-epoch validation train_step call.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- keep train data on CPU with pin_memory for fast async transfer
- add extract_batch_windows() to extract and transfer batch windows
- modify make_batches_random() to yield windowed data on GPU
- update train_step functions to use windowed data directly
- validation data stays on GPU (small enough to fit)

this reduces GPU memory usage by not loading entire dataset to GPU.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
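
A hedged sketch of how these two functions could fit together; the signatures are guesses based on the description above, not the repository's actual API:

```python
import torch

def extract_batch_windows(data, stim, starts, window, device):
    """Gather the requested CPU windows into one batch and move it to the GPU."""
    # a real implementation might copy into a pre-pinned staging buffer instead
    d = torch.stack([data[i:i + window] for i in starts])
    s = torch.stack([stim[i:i + window] for i in starts])
    return d.to(device, non_blocking=True), s.to(device, non_blocking=True)

def make_batches_random(data, stim, window, batch_size, device):
    """Yield batches of random windows, already resident on the GPU."""
    max_start = data.shape[0] - window
    while True:
        starts = torch.randint(0, max_start + 1, (batch_size,)).tolist()
        yield extract_batch_windows(data, stim, starts, window, device)
```
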
the ems warmup feature was not being used and added complexity.
removed the config field, validation logic, and training loop handling.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…ochs

warmup epochs run before the main training loop and are additional to
the configured 'epochs' count. removed misleading validation that
required warmup < epochs.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
replaces the synchronous CPU->GPU transfer with a DataLoader that prefetches
batches in parallel worker processes. this hides transfer latency by
preparing the next batches while the GPU trains on the current one.

- add WindowDataset class for contiguous slice access (faster than fancy indexing)
- add create_dataloader() with RandomSampler for infinite random sampling
- use persistent_workers and pin_memory for optimal throughput
- separate augmentation index generation (runs on GPU after batch arrives)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
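
An illustrative version of the `WindowDataset`/`create_dataloader()` pair described above, assuming PyTorch; the exact signatures and defaults are assumptions:

```python
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler

class WindowDataset(Dataset):
    """Each item is one contiguous window: a cheap slice, not fancy indexing."""

    def __init__(self, data, stim, window):
        self.data, self.stim, self.window = data, stim, window

    def __len__(self):
        return self.data.shape[0] - self.window + 1

    def __getitem__(self, start):
        end = start + self.window
        return self.data[start:end], self.stim[start:end]

def create_dataloader(dataset, batch_size, num_workers=4):
    # replacement=True with a huge num_samples gives effectively infinite
    # random sampling of window start positions
    sampler = RandomSampler(dataset, replacement=True, num_samples=2**40)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=True,          # workers stage batches in pinned buffers
        persistent_workers=True,  # keep workers alive between epochs
    )
```

With `num_workers > 0`, the next batches are sliced and pinned in worker processes while the GPU trains on the current one, which is the latency-hiding described above.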
Fixes a CUDA multiprocessing issue where the fork start method causes silent
crashes during DataLoader worker initialization. Spawn is safer with
CUDA tensors.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
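
A sketch of the fix using PyTorch's standard mechanism for this; whether the change was global or per-loader is an assumption:

```python
import torch.multiprocessing as mp

if __name__ == "__main__":
    # 'fork' duplicates a CUDA-initialized parent process, which can crash or
    # deadlock workers; 'spawn' starts them from a clean interpreter instead
    mp.set_start_method("spawn", force=True)
```

The same effect can be scoped to a single loader via `DataLoader(..., multiprocessing_context="spawn")`.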
Moves the data and stim tensors to shared memory before creating the dataset.
This lets spawned worker processes access the data without copying it,
fixing bus errors and semaphore leaks.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
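
A minimal, self-contained sketch of the shared-memory step; the tensor names and shapes are placeholders:

```python
import torch

data = torch.randn(100_000, 256)  # hypothetical training tensors
stim = torch.randn(100_000, 32)

# move the storage into shared memory so spawned DataLoader workers map the
# same pages instead of each receiving a pickled copy
data.share_memory_()
stim.share_memory_()
```
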
vijk777 commented Jan 20, 2026

This is just too slow.

@vijk777 vijk777 closed this Jan 20, 2026
@vijk777 vijk777 deleted the vj/stream-batches branch January 20, 2026 23:43