
Optimize balancer and setup debug logger. #308

Open
JacoCheung wants to merge 10 commits into NVIDIA:main from JacoCheung:junzhang/opt_balancer

Conversation

@JacoCheung
Collaborator

@JacoCheung JacoCheung commented Feb 11, 2026

Description

This PR aims to fully hide the balancer overhead (the Karmarkar-Karp algorithm running on the host) and to optimize the communication of the batch allgather.

In addition, a set of helper utilities is added:

  1. HSTU kernel SoL benchmark (dense shapes).
    - The peak SoL is about 60% for forward and 50% for backward.
  2. HSTU kernel balancer optimization benchmark (the HSTU speedup when the input is shuffled evenly).
  3. In-flight HSTU MFU logging (a minimal sketch follows this list).
  4. Balancer logger.
  5. Dataset seqlen distribution specs.
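
As a rough illustration of item 3, here is a minimal sketch of what an in-flight MFU log line computes. The names (log_mfu, GPU_PEAK_TFLOPS, attn_flops_per_step) and the peak-throughput constant are assumptions for illustration only, not the actual attn_perf_tracker.py API.

    # Hypothetical sketch of in-flight MFU logging; names and the peak value are
    # assumptions, not the attn_perf_tracker.py implementation.
    import time

    GPU_PEAK_TFLOPS = 989.0  # assumed dense BF16 peak of the target GPU

    def log_mfu(attn_flops_per_step: float, step_time_s: float) -> float:
        """Report MFU = achieved TFLOP/s divided by the GPU's peak TFLOP/s."""
        achieved_tflops = attn_flops_per_step / step_time_s / 1e12
        mfu = achieved_tflops / GPU_PEAK_TFLOPS
        print(f"HSTU attn MFU: {mfu:.1%} ({achieved_tflops:.1f} TFLOP/s)")
        return mfu

    start = time.time()
    # ... run one HSTU forward/backward step ...
    log_mfu(attn_flops_per_step=2.5e14, step_time_s=time.time() - start)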

@greptile-apps

greptile-apps bot commented Feb 11, 2026

Greptile Overview

Greptile Summary

This PR hides the balancer overhead and optimizes batch allgather communication in the training pipeline through two key improvements:

1. Background Thread for Balancer: Offloads the H2D transfer and batch shuffling (including Karmarkar-Karp load balancing) to a background thread using a ThreadPoolExecutor. The main thread continues with the forward pass while the background thread handles data movement on a separate CUDA stream (_memcpy_stream). A critical synchronization point is added before the main thread touches the DP communicator to prevent an NCCL deadlock (see the first sketch after this list).

2. Fused KJT AllGather: Replaces per-KJT gathering with keyed_jagged_tensor_list_allgather, which concatenates all KJT lengths/values and performs only 2 NCCL calls in total (1 for lengths, 1 for values), regardless of the number of KJTs. It uses keyed_jagged_index_select_dim1 for an efficient rank-major to key-major layout transpose (see the second sketch after this list).
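
The following is a minimal sketch of the background-thread overlap described in point 1, assuming the batch object exposes a to()-style H2D method; the class and method names (ShufflePrefetcher, _h2d_and_shuffle, shuffle_batch) are illustrative and not the actual train_pipeline.py code.

    import concurrent.futures
    import torch

    class ShufflePrefetcher:
        """Illustrative helper: overlap H2D + shuffle with the forward pass."""

        def __init__(self, shuffle_batch):
            self._shuffle_batch = shuffle_batch          # host KK + allgathers
            self._memcpy_stream = torch.cuda.Stream()    # side stream for H2D
            self._executor = concurrent.futures.ThreadPoolExecutor(
                max_workers=1, thread_name_prefix="shuffle"
            )

        def _h2d_and_shuffle(self, cpu_batch):
            # Runs on the background thread: copy on the side stream, then
            # balance/shuffle (Karmarkar-Karp on host + collectives).
            with torch.cuda.stream(self._memcpy_stream):
                gpu_batch = cpu_batch.to("cuda", non_blocking=True)
                return self._shuffle_batch(gpu_batch)

        def submit(self, cpu_batch):
            return self._executor.submit(self._h2d_and_shuffle, cpu_batch)

    # Main-loop pattern: submit batch i+1, run the forward pass for batch i on
    # the default stream, then call future.result() *before* any DP-group
    # collective so that NCCL call order stays identical on every rank.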
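And a minimal sketch of the fused-gather idea in point 2, shown with plain tensors and equal per-rank sizes for simplicity; the real keyed_jagged_tensor_list_allgather operates on KJTs, handles variable sizes, and transposes the rank-major result with keyed_jagged_index_select_dim1.

    import torch
    import torch.distributed as dist

    def fused_jagged_allgather(lengths_list, values_list, group=None):
        """Illustrative only: 2 collectives regardless of how many KJTs there are."""
        world_size = dist.get_world_size(group)

        # One NCCL call for every KJT's lengths, concatenated.
        local_lengths = torch.cat(lengths_list)
        all_lengths = [torch.empty_like(local_lengths) for _ in range(world_size)]
        dist.all_gather(all_lengths, local_lengths, group=group)

        # One NCCL call for every KJT's values, concatenated (equal sizes assumed
        # here; the real op exchanges sizes and gathers variable-length values).
        local_values = torch.cat(values_list)
        all_values = [torch.empty_like(local_values) for _ in range(world_size)]
        dist.all_gather(all_values, local_values, group=group)

        # The gathered layout is rank-major; a key-major transpose (e.g. via
        # keyed_jagged_index_select_dim1) follows in the actual implementation.
        return all_lengths, all_values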

Additional improvements:

  • Removed unnecessary barrier() in _gatherv_along_first_dim (NCCL already synchronizes streams)
  • Added debug logging for load balance statistics (controlled by the PRINT_LOAD_BALANCE env var; see the logging sketch after this list)
  • Added HSTU attention performance tracking with MFU calculation
  • Refactored shuffle into two phases (compute_partition_indices and shuffle_batch_by_global_indices) for better modularity
  • Added utility classes for random distributions in benchmark data generation
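
A minimal sketch of the env-var-gated, rank-aware logging mentioned above follows; the helper bodies are assumptions and may differ from the actual logger.py implementation (only the env var names PRINT_LOAD_BALANCE and LOG_LEVEL come from this summary).

    import logging
    import os

    import torch.distributed as dist

    logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO").upper())
    logger = logging.getLogger("balancer")

    PRINT_LOAD_BALANCE = os.environ.get("PRINT_LOAD_BALANCE", "0") == "1"

    def debug_rank_0(message: str) -> None:
        """Log at DEBUG level only on rank 0 when distributed is initialized."""
        if not dist.is_initialized() or dist.get_rank() == 0:
            logger.debug(message)

    def log_load_balance(workloads) -> None:
        # Gated by PRINT_LOAD_BALANCE so the stats cost nothing in normal runs.
        if PRINT_LOAD_BALANCE:
            mean = sum(workloads) / len(workloads)
            debug_rank_0(f"per-rank workloads: {workloads}, max/mean = {max(workloads) / mean:.2f}")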

Confidence Score: 4/5

  • Safe to merge with careful testing of concurrency and NCCL ordering
  • The PR implements sophisticated concurrency optimizations with proper NCCL synchronization points. The background threading and fused communication changes are well-documented and follow correct ordering constraints. However, the complexity of concurrent NCCL calls and stream synchronization requires thorough multi-GPU testing to ensure no deadlocks or race conditions occur in production workloads.
  • Pay close attention to examples/commons/pipeline/train_pipeline.py (concurrent NCCL usage) and examples/commons/ops/collective_ops.py (fused KJT allgather logic)

Important Files Changed

Filename: Overview
examples/commons/distributed/batch_shuffler.py: Refactored shuffle into two phases (compute_partition_indices and shuffle_batch_by_global_indices) for better separation; added load balance logging with env var controls.
examples/commons/distributed/batch_allgather.py: Optimized to fuse all KJT fields into a single AllGather call pair using keyed_jagged_tensor_list_allgather instead of per-KJT gathering.
examples/commons/ops/collective_ops.py: Added keyed_jagged_tensor_list_allgather for fused KJT gathering (2 NCCL calls total); removed an unnecessary barrier in _gatherv_along_first_dim.
examples/commons/pipeline/train_pipeline.py: Added a ThreadPoolExecutor to offload H2D + batch shuffle to a background thread; waits for shuffle completion before the main thread touches the DP communicator to avoid an NCCL deadlock.
examples/commons/utils/logger.py: Added debug_rank_0, print_rank_all, info_rank_all, and debug_rank_all functions; added LOG_LEVEL env var support; enhanced print_rank_0 with a log level parameter.
examples/commons/utils/attn_perf_tracker.py: New file for HSTU attention performance tracking with lazy printing, auto-detection of training vs. inference mode, and MFU calculation.

Sequence Diagram

sequenceDiagram
    participant Main as Main Thread
    participant BG as Background Thread
    participant GPU as GPU Stream
    participant NCCL as NCCL (DP Group)
    
    Note over Main,NCCL: Batch i+2 Processing (Optimized)
    Main->>Main: _next_batch(dataloader)
    Main->>BG: submit(_h2d_and_shuffle, batch)
    activate BG
    BG->>GPU: _to_device (H2D on memcpy_stream)
    BG->>NCCL: AllGather workloads
    BG->>BG: Karmarkar-Karp partitioning (host)
    BG->>NCCL: AllGather KJTs (fused, 2 calls)
    BG->>NCCL: AllGather dense tensors
    deactivate BG
    
    Note over Main: Batch i Processing (overlapped)
    Main->>GPU: Forward pass (default stream)
    Main->>BG: future.result() [WAIT HERE]
    activate BG
    BG-->>Main: shuffled_batch
    deactivate BG
    
    Note over Main,NCCL: Safe to use DP communicator
    Main->>NCCL: AllReduce loss (DP group)
    Main->>GPU: Backward pass
    Main->>NCCL: AllReduce gradients (DP group)


@greptile-apps greptile-apps bot left a comment


21 files reviewed, 1 comment


@JacoCheung JacoCheung force-pushed the junzhang/opt_balancer branch from 8b68ea4 to 2e3a262 on February 12, 2026 at 12:02

@greptile-apps greptile-apps bot left a comment


24 files reviewed, 3 comments


Comment on lines 64 to 65
def print_rank_all(message, level=logging.INFO):
"""If distributed is initialized, print only on rank 0."""

Docstring is incorrect - says "print only on rank 0" but print_rank_all actually prints on all ranks.

Suggested change
def print_rank_all(message, level=logging.INFO):
"""If distributed is initialized, print only on rank 0."""
"""If distributed is initialized, print on all ranks."""

Comment on lines 72 to 73
def info_rank_all(message):
"""If distributed is initialized, print only on rank 0."""

Docstring is incorrect - says "print only on rank 0" but info_rank_all actually prints on all ranks.

Suggested change
def info_rank_all(message):
"""If distributed is initialized, print only on rank 0."""
"""If distributed is initialized, print on all ranks."""

Comment on lines 80 to 81
def debug_rank_all(message):
"""If distributed is initialized, print only on rank 0."""

Docstring is incorrect - says "print only on rank 0" but debug_rank_all actually prints on all ranks.

Suggested change
def debug_rank_all(message):
"""If distributed is initialized, print only on rank 0."""
"""If distributed is initialized, print on all ranks."""


@greptile-apps greptile-apps bot left a comment


24 files reviewed, 1 comment


Comment on lines +719 to +722
self._shuffle_executor = concurrent.futures.ThreadPoolExecutor(
max_workers=1, thread_name_prefix="shuffle"
)
self._shuffle_future: Optional["concurrent.futures.Future[Optional[In]]"] = None

Consider adding cleanup for _shuffle_executor to ensure graceful thread termination:

    def __del__(self):
        if hasattr(self, '_shuffle_executor'):
            self._shuffle_executor.shutdown(wait=True)
