Optimize balancer and set up debug logger. #308
JacoCheung wants to merge 10 commits into NVIDIA:main from
Conversation
Greptile Summary

This PR reduces balancer overhead and optimizes batch-allgather communication in the training pipeline via two key improvements:

1. Background thread for the balancer: H2D transfer and batch shuffling (including Karmarkar-Karp load balancing) are offloaded to a background thread using a single-worker ThreadPoolExecutor.
2. Fused KJT allgather: per-KJT gathering is replaced with a fused allgather that needs only two collective calls (see the sketch below).

Additional improvements: a set of debug-logging helper utilities.
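A minimal sketch of the fused-allgather idea in item 2, with a hypothetical helper name `fused_kjt_allgather` (the PR's actual KJT handling may differ): one fixed-size collective for the lengths, then one padded collective for the variable-size values.

```python
import torch
import torch.distributed as dist

def fused_kjt_allgather(lengths: torch.Tensor, values: torch.Tensor, group=None):
    """Gather jagged (lengths, values) pairs from all ranks in two collectives.

    `lengths` has the same shape on every rank, so a single
    all_gather_into_tensor suffices; `values` is variable-length, so it is
    padded to the max size across ranks before the second collective and
    trimmed afterwards.
    """
    world_size = dist.get_world_size(group)

    # Collective 1: fixed-size lengths.
    all_lengths = torch.empty(
        world_size * lengths.numel(), dtype=lengths.dtype, device=lengths.device
    )
    dist.all_gather_into_tensor(all_lengths, lengths, group=group)

    # Derive each rank's value count from the gathered lengths.
    per_rank_counts = all_lengths.view(world_size, -1).sum(dim=1)
    max_count = int(per_rank_counts.max().item())  # host sync point

    # Collective 2: values, padded to a uniform size.
    padded = torch.zeros(max_count, dtype=values.dtype, device=values.device)
    padded[: values.numel()] = values
    all_padded = torch.empty(
        world_size * max_count, dtype=values.dtype, device=values.device
    )
    dist.all_gather_into_tensor(all_padded, padded, group=group)

    # Trim the padding back out and concatenate in rank order.
    chunks = [
        all_padded[r * max_count : r * max_count + int(per_rank_counts[r])]
        for r in range(world_size)
    ]
    return all_lengths, torch.cat(chunks)
```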
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Main as Main Thread
    participant BG as Background Thread
    participant GPU as GPU Stream
    participant NCCL as NCCL (DP Group)
    Note over Main,NCCL: Batch i+2 Processing (Optimized)
    Main->>Main: _next_batch(dataloader)
    Main->>BG: submit(_h2d_and_shuffle, batch)
    activate BG
    BG->>GPU: _to_device (H2D on memcpy_stream)
    BG->>NCCL: AllGather workloads
    BG->>BG: Karmarkar-Karp partitioning (on host)
    BG->>NCCL: AllGather KJTs (fused, 2 calls)
    BG->>NCCL: AllGather dense tensors
    deactivate BG
    Note over Main: Batch i Processing (overlapped)
    Main->>GPU: Forward pass (default stream)
    Main->>BG: future.result() [WAIT HERE]
    activate BG
    BG-->>Main: shuffled_batch
    deactivate BG
    Note over Main,NCCL: Safe to use DP communicator
    Main->>NCCL: AllReduce loss (DP group)
    Main->>GPU: Backward pass
    Main->>NCCL: AllReduce gradients (DP group)
```
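The pattern in the diagram reduces to a single-worker executor: submit H2D + shuffle for batch i+2, issue the forward pass for batch i on the default stream, and block on the future only when the shuffled batch is needed. A sketch using the executor fields shown later in this review; `prefetch`, `next_shuffled`, and `_h2d_and_shuffle` are hypothetical names.

```python
import concurrent.futures
from typing import Optional

class ShufflePipelineSketch:
    """Submit/overlap/wait skeleton; not the PR's exact class."""

    def __init__(self) -> None:
        # One worker serializes H2D copies and the balancer's allgathers,
        # so background DP collectives never interleave with main-thread ones.
        self._shuffle_executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=1, thread_name_prefix="shuffle"
        )
        self._shuffle_future: Optional[concurrent.futures.Future] = None

    def prefetch(self, batch) -> None:
        """Kick off H2D transfer + Karmarkar-Karp shuffle for a future batch."""
        self._shuffle_future = self._shuffle_executor.submit(
            self._h2d_and_shuffle, batch
        )

    def next_shuffled(self):
        """Block only here, after the overlapped forward pass was issued."""
        assert self._shuffle_future is not None
        batch = self._shuffle_future.result()
        self._shuffle_future = None
        return batch

    def _h2d_and_shuffle(self, batch):
        # _to_device on a side stream, allgather workloads, KK partition,
        # fused KJT allgather -- all off the critical path.
        return batch  # placeholder for the shuffled batch
```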
JacoCheung force-pushed from 8b68ea4 to 2e3a262
examples/commons/utils/logger.py
Outdated
```python
def print_rank_all(message, level=logging.INFO):
    """If distributed is initialized, print only on rank 0."""
```

Docstring is incorrect: it says "print only on rank 0", but `print_rank_all` actually prints on all ranks.

```suggestion
    """If distributed is initialized, print on all ranks."""
```
examples/commons/utils/logger.py
Outdated
```python
def info_rank_all(message):
    """If distributed is initialized, print only on rank 0."""
```

Docstring is incorrect: it says "print only on rank 0", but `info_rank_all` actually prints on all ranks.

```suggestion
    """If distributed is initialized, print on all ranks."""
```
examples/commons/utils/logger.py
Outdated
```python
def debug_rank_all(message):
    """If distributed is initialized, print only on rank 0."""
```

Docstring is incorrect: it says "print only on rank 0", but `debug_rank_all` actually prints on all ranks.

```suggestion
    """If distributed is initialized, print on all ranks."""
```
```python
self._shuffle_executor = concurrent.futures.ThreadPoolExecutor(
    max_workers=1, thread_name_prefix="shuffle"
)
self._shuffle_future: Optional["concurrent.futures.Future[Optional[In]]"] = None
```

Consider adding cleanup for `_shuffle_executor` to ensure graceful thread termination:

```python
def __del__(self):
    if hasattr(self, "_shuffle_executor"):
        self._shuffle_executor.shutdown(wait=True)
```
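Beyond `__del__`, an explicit teardown hook can also drop any in-flight future before joining the worker; a sketch with a hypothetical `close()` method not present in the PR:

```python
def close(self) -> None:
    # Release any pending shuffle, then join the background thread.
    if self._shuffle_future is not None:
        self._shuffle_future.cancel()  # no-op if the task is already running
        self._shuffle_future = None
    self._shuffle_executor.shutdown(wait=True)
```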
Description
This PR aims to fully hide the balancer overhead (the Karmarkar-Karp algorithm runs on the host) and to optimize the communication of the batch allgather.
In addition, it adds a set of helper utilities.
- The peak SOL (speed-of-light) utilization is about 60% for the forward pass and 50% for the backward pass.
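For reference, the Karmarkar-Karp (largest-differencing) heuristic at the heart of the balancer is small; below is a textbook 2-way sketch, assuming the PR generalizes it to one partition per DP rank:

```python
import heapq

def karmarkar_karp(workloads: list[int]) -> int:
    """Largest-differencing-method heuristic for 2-way number partitioning.

    Repeatedly pops the two largest workloads and pushes their difference;
    the final value is the (heuristic) gap between the two partitions.
    """
    # heapq is a min-heap, so store negated values to pop maxima.
    heap = [-w for w in workloads]
    heapq.heapify(heap)
    while len(heap) > 1:
        largest = -heapq.heappop(heap)
        second = -heapq.heappop(heap)
        heapq.heappush(heap, -(largest - second))
    return -heap[0] if heap else 0

# Example: balancing per-sample workloads across two DP ranks.
print(karmarkar_karp([8, 7, 6, 5, 4]))  # -> 2 (heuristic; the optimal
                                        #    split {8,7} vs {6,5,4} has gap 0)
```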