From c8a4148452a4099d25fef336ffa585355f7ccc18 Mon Sep 17 00:00:00 2001 From: Xuanteng Huang Date: Wed, 20 May 2026 03:01:36 -0700 Subject: [PATCH 1/9] refactor a2a op to reuse existing code from mapping.py --- megatron/core/ssm/mamba_context_parallel.py | 36 ++++----------------- 1 file changed, 7 insertions(+), 29 deletions(-) diff --git a/megatron/core/ssm/mamba_context_parallel.py b/megatron/core/ssm/mamba_context_parallel.py index 3297728d5fe..5c040716069 100644 --- a/megatron/core/ssm/mamba_context_parallel.py +++ b/megatron/core/ssm/mamba_context_parallel.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.tensor_parallel import all_to_all +from megatron.core.tensor_parallel.mappings import all_to_all_hp2sp, all_to_all_sp2hp from megatron.core.utils import is_te_min_version try: @@ -302,7 +302,6 @@ def _slice_vector_param(self, param: torch.Tensor, has_hdim: bool = False) -> to return param[start:end] -# TODO(duncan): Consider combining with all_to_all_sp2hp in mappings.py and using einops.rearrange def _all_to_all_cp2hp( input_: torch.Tensor, cp_group: torch.distributed.ProcessGroup ) -> torch.Tensor: @@ -324,24 +323,12 @@ def _all_to_all_cp2hp( """ assert input_.dim() == 3, "all_to_all_cp2hp assumes 3-d input shape." s_in, b_in, h_in = input_.shape - # Squash the first two dimensions -> [s*b, h] - input_ = input_.reshape(-1, h_in) - # Split into world_size chunks along the h dimension - world_size = cp_group.size() - h_out = h_in // world_size - split_tensors = torch.split(input_, split_size_or_sections=h_out, dim=1) - # Concat the chunks along the s*b dimension - concat_tensor = torch.cat(split_tensors, dim=0) - # TODO(duncan): Can the following be optimized by using the non-single (tensor list) version of - # all-to-all? - # Swap chunks of dim0 across the cp ranks - output = all_to_all(cp_group, concat_tensor) - # Recover the s and b dimensions - output = output.reshape(s_in * world_size, b_in, h_out) + s_out, h_out = s_in * cp_group.size(), h_in // cp_group.size() + output = all_to_all_sp2hp(input_, group=cp_group) + output = output.reshape(s_out, b_in, h_out) return output -# TODO(duncan): Consider combining with all_to_all_hp2sp in mappings.py and using einops.rearrange def _all_to_all_hp2cp( input_: torch.Tensor, cp_group: torch.distributed.ProcessGroup ) -> torch.Tensor: @@ -363,18 +350,9 @@ def _all_to_all_hp2cp( """ assert input_.dim() == 3, "all_to_all_hp2cp assumes 3-d input shape." s_in, b_in, h_in = input_.shape - # Squash the first two dimensions -> [s*b, h] - input_ = input_.reshape(-1, h_in) - # Swap chunks of dim0 across the cp ranks - input_exchanged = all_to_all(cp_group, input_) - # Split into world_size chunks along the s*b dimension - world_size = cp_group.size() - s_out = s_in // world_size - split_tensors = torch.split(input_exchanged, split_size_or_sections=s_out * b_in, dim=0) - # Concat the chunks along the h dimension - output = torch.cat(split_tensors, dim=-1) - # Recover the s and b dimensions - output = output.reshape(s_out, b_in, h_in * world_size) + s_out, h_out = s_in // cp_group.size(), h_in * cp_group.size() + output = all_to_all_hp2sp(input_, group=cp_group) + output = output.reshape(s_out, b_in, h_out) return output From 330ac30b92995b107b19cc900ac8079ca20cf344 Mon Sep 17 00:00:00 2001 From: Xuanteng Huang Date: Wed, 20 May 2026 23:05:19 -0700 Subject: [PATCH 2/9] fuse per-seq a2a into a unified one --- megatron/core/ssm/gated_delta_net.py | 104 +++++++--- tests/unit_tests/ssm/test_gated_delta_net.py | 194 ++++++++++++++++++- 2 files changed, 269 insertions(+), 29 deletions(-) diff --git a/megatron/core/ssm/gated_delta_net.py b/megatron/core/ssm/gated_delta_net.py index 1dbaf161e73..79cc7aff227 100644 --- a/megatron/core/ssm/gated_delta_net.py +++ b/megatron/core/ssm/gated_delta_net.py @@ -343,25 +343,31 @@ def forward( # CP All to All: CP to HP if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': - unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens_q // self.cp_size, dim=0) - outputs = [] - for qkvzba_i in unpacked_qkvzba: - qkvzba_i = tensor_a2a_cp2hp( - qkvzba_i, - seq_dim=0, - head_dim=-1, - cp_group=self.pg_collection.cp, - split_sections=[ - self.qk_dim_local_tp, - self.qk_dim_local_tp, - self.v_dim_local_tp, - self.v_dim_local_tp, - self.num_value_heads // self.tp_size, - self.num_value_heads // self.tp_size, - ], + # Batched: one a2a on the full local THD tensor, then one local + # permutation that reorders rank-grouped output into per-sequence + # natural order. The permutation also folds in the per-sequence + # `_undo_attention_load_balancing`, so it's disabled inside the + # a2a call. + qkvzba = tensor_a2a_cp2hp( + qkvzba, + seq_dim=0, + head_dim=-1, + cp_group=self.pg_collection.cp, + split_sections=[ + self.qk_dim_local_tp, + self.qk_dim_local_tp, + self.v_dim_local_tp, + self.v_dim_local_tp, + self.num_value_heads // self.tp_size, + self.num_value_heads // self.tp_size, + ], + undo_attention_load_balancing=False, + ) + if self.cp_size > 1: + thd_cp_a2a_idx, thd_cp_a2a_inv = _build_thd_cp_a2a_perm( + cu_seqlens_q, self.cp_size, seq_len ) - outputs.append(qkvzba_i) - qkvzba = torch.cat(outputs, dim=0) + qkvzba = qkvzba.index_select(0, thd_cp_a2a_idx) else: qkvzba = tensor_a2a_cp2hp( qkvzba, @@ -489,14 +495,15 @@ def forward( # CP all to all: HP to CP if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': - unpacked_norm_out = _unpack_sequence(norm_out, cu_seqlens_q, dim=0) - outputs = [] - for norm_out_i in unpacked_norm_out: - norm_out_i = tensor_a2a_hp2cp( - norm_out_i, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp - ) - outputs.append(norm_out_i) - norm_out = torch.cat(outputs, dim=0) + if self.cp_size > 1: + norm_out = norm_out.index_select(0, thd_cp_a2a_inv) + norm_out = tensor_a2a_hp2cp( + norm_out, + seq_dim=0, + head_dim=-1, + cp_group=self.pg_collection.cp, + redo_attention_load_balancing=False, + ) else: norm_out = tensor_a2a_hp2cp( norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp @@ -574,7 +581,7 @@ def _compute_g_and_beta(self, A_log_local_cp, dt_bias_local_cp, alpha, beta): def _resolve_cu_seqlens( self, cu_seqlens_padded, cu_seqlens_actual, total_seq_len, name, cp_size: int = 1 - ): + ) -> torch.Tensor: """Resolve cu_seqlens for packed sequence all-to-all, handling alignment padding.""" if cu_seqlens_padded is not None: cu_seqlens = cu_seqlens_padded @@ -696,7 +703,8 @@ def _backward_out_proj(self): self.out_proj.backward_dw() -def _unpack_sequence(x, cu_seqlens, dim=1): +# Used by tests/unit_tests/ssm/test_gated_delta_net.py +def _unpack_sequence(x, cu_seqlens, dim=1) -> list[torch.Tensor]: unpacked_x = [] cu_seqlens_list = cu_seqlens.tolist() num_seqs = len(cu_seqlens_list) - 1 @@ -708,6 +716,46 @@ def _unpack_sequence(x, cu_seqlens, dim=1): return unpacked_x +def _build_thd_cp_a2a_perm( + cu_seqlens: torch.Tensor, cp_size: int, t_global: int +) -> Tuple[torch.Tensor, torch.Tensor]: + cu = cu_seqlens.to(dtype=torch.long) + t_local = t_global // cp_size + + positions = torch.arange(t_global, device=cu.device) + seq_idx = torch.bucketize(positions, cu[1:], right=True) + seq_lens = torch.diff(cu) + halves = seq_lens // (2 * cp_size) # per-sequence half-chunk size + local_starts = cu[:-1] // cp_size + global_starts = cu[:-1] + + half_i = halves[seq_idx] + pos_in_seq = positions - global_starts[seq_idx] + + natural_chunk = pos_in_seq // half_i # in [0, 2*cp) + offset = pos_in_seq - natural_chunk * half_i + + # Invert the ordering produced by `_undo_attention_load_balancing`: + # natural_chunk < cp: load_balanced = 2 * natural_chunk + # natural_chunk >= cp: load_balanced = 4*cp - 2*natural_chunk - 1 + lb_chunk = torch.where( + natural_chunk < cp_size, 2 * natural_chunk, 4 * cp_size - 2 * natural_chunk - 1 + ) + + # In the per-sequence load-balanced layout each rank owns load-balanced + # chunks (2r) and (2r+1), in that order, of every sequence. + rank = lb_chunk // 2 + half_within_rank = lb_chunk - 2 * rank + k = half_within_rank * half_i + offset + + idx = rank * t_local + local_starts[seq_idx] + k + + inv = torch.empty_like(idx) + inv[idx] = positions + + return idx, inv + + #################### # Sharded state dict utilities #################### diff --git a/tests/unit_tests/ssm/test_gated_delta_net.py b/tests/unit_tests/ssm/test_gated_delta_net.py index 6bd7e3ea841..53ff467db47 100644 --- a/tests/unit_tests/ssm/test_gated_delta_net.py +++ b/tests/unit_tests/ssm/test_gated_delta_net.py @@ -16,7 +16,13 @@ ) from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.ssm.gated_delta_net import GatedDeltaNet +from megatron.core.ssm.gated_delta_net import ( + GatedDeltaNet, + _build_thd_cp_a2a_perm, + _unpack_sequence, + tensor_a2a_cp2hp, + tensor_a2a_hp2cp, +) from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer import TransformerConfig from megatron.core.utils import unwrap_model @@ -406,3 +412,189 @@ def test_parallel_gated_delta_net_correctness(tmp_path_dist_ckpt, sequence_packi micro_batch_size=4, sequence_packing=sequence_packing, ) + + +@pytest.mark.parametrize("cp_size", [2, 4]) +@pytest.mark.internal +class TestBatchedThdAllToAll: + """Verify batched-a2a + permute matches the per-sequence loop in GDN.""" + + @pytest.fixture(scope='function', autouse=True) + def setup_method(self, cp_size): + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=cp_size, + ) + model_parallel_cuda_manual_seed(123) + self.cp_size = cp_size + self.cp_group = parallel_state.get_context_parallel_group() + + def teardown_method(self): + Utils.destroy_model_parallel() + + @staticmethod + def _per_seq_a2a_cp2hp(local_t, cu_seqlens, cp_group, split_sections=None): + cp_size = cp_group.size() + unpacked = _unpack_sequence(local_t, cu_seqlens // cp_size, dim=0) + outputs = [] + for x in unpacked: + outputs.append( + tensor_a2a_cp2hp( + x, + seq_dim=0, + head_dim=-1, + cp_group=cp_group, + split_sections=split_sections, + undo_attention_load_balancing=True, + ) + ) + return torch.cat(outputs, dim=0) + + @staticmethod + def _per_seq_a2a_hp2cp(global_t, cu_seqlens, cp_group, split_sections=None): + unpacked = _unpack_sequence(global_t, cu_seqlens, dim=0) + outputs = [] + for x in unpacked: + outputs.append( + tensor_a2a_hp2cp( + x, + seq_dim=0, + head_dim=-1, + cp_group=cp_group, + split_sections=split_sections, + redo_attention_load_balancing=True, + ) + ) + return torch.cat(outputs, dim=0) + + # ---- Optimized: single a2a + production permutation helper ---- + + @staticmethod + def _batched_a2a_cp2hp(local_t, cu_seqlens, cp_group, split_sections=None): + cp_size = cp_group.size() + t_global = int(cu_seqlens[-1].item()) + naive = tensor_a2a_cp2hp( + local_t, + seq_dim=0, + head_dim=-1, + cp_group=cp_group, + split_sections=split_sections, + undo_attention_load_balancing=False, + ) + idx, _ = _build_thd_cp_a2a_perm(cu_seqlens, cp_size, t_global) + return naive.index_select(0, idx) + + @staticmethod + def _batched_a2a_hp2cp(global_t, cu_seqlens, cp_group, split_sections=None): + cp_size = cp_group.size() + t_global = int(cu_seqlens[-1].item()) + _, inv = _build_thd_cp_a2a_perm(cu_seqlens, cp_size, t_global) + permuted = global_t.index_select(0, inv) + return tensor_a2a_hp2cp( + permuted, + seq_dim=0, + head_dim=-1, + cp_group=cp_group, + split_sections=split_sections, + redo_attention_load_balancing=False, + ) + + # ---- Tests ---- + + @pytest.mark.parametrize( + "cu_seqlens", + [ + (0, 32, 64), # 2 equal sequences + (0, 32, 64, 96, 128), # 4 equal sequences (matches existing THD test) + (0, 16, 48, 80), # 3 unequal sequences + ], + ) + def test_cp2hp_batched_matches_per_seq(self, cu_seqlens): + cu = torch.tensor(cu_seqlens, dtype=torch.long, device=torch.cuda.current_device()) + if ((cu[1:] - cu[:-1]) % self.cp_size != 0).any(): + pytest.skip(f"cu_seqlens {cu_seqlens} not divisible by cp_size {self.cp_size}") + + T_global = cu_seqlens[-1] + T_local = T_global // self.cp_size + hidden = 32 + torch.manual_seed(42 + self.cp_size) + local_t = ( + torch.rand(T_local, 1, hidden, device=torch.cuda.current_device()) + .bfloat16() + .contiguous() + ) + + out_ref = self._per_seq_a2a_cp2hp(local_t, cu, self.cp_group) + out_opt = self._batched_a2a_cp2hp(local_t, cu, self.cp_group) + + rank = torch.distributed.get_rank() + assert out_opt.shape == out_ref.shape, (out_opt.shape, out_ref.shape) + # Both paths apply the same a2a kernel; only the surrounding pack/cat + # differs. Equality should be bitwise. + torch.testing.assert_close( + out_opt, + out_ref, + atol=0.0, + rtol=0.0, + msg=lambda m: f"Batched CP->HP mismatch on rank={rank}: {m}", + ) + + @pytest.mark.parametrize("cu_seqlens", [(0, 32, 64), (0, 32, 64, 96, 128), (0, 16, 48, 80)]) + def test_hp2cp_batched_matches_per_seq(self, cu_seqlens): + cu = torch.tensor(cu_seqlens, dtype=torch.long, device=torch.cuda.current_device()) + if ((cu[1:] - cu[:-1]) % self.cp_size != 0).any(): + pytest.skip(f"cu_seqlens {cu_seqlens} not divisible by cp_size {self.cp_size}") + + T_global = cu_seqlens[-1] + hidden = 32 + # Hidden must be divisible by cp_size for the HP-sharded input layout. + assert hidden % self.cp_size == 0 + h_local = hidden // self.cp_size + torch.manual_seed(42 + self.cp_size) + global_t = ( + torch.rand(T_global, 1, h_local, device=torch.cuda.current_device()) + .bfloat16() + .contiguous() + ) + + out_ref = self._per_seq_a2a_hp2cp(global_t, cu, self.cp_group) + out_opt = self._batched_a2a_hp2cp(global_t, cu, self.cp_group) + + rank = torch.distributed.get_rank() + assert out_opt.shape == out_ref.shape, (out_opt.shape, out_ref.shape) + torch.testing.assert_close( + out_opt, + out_ref, + atol=0.0, + rtol=0.0, + msg=lambda m: f"Batched HP->CP mismatch on rank={rank}: {m}", + ) + + @pytest.mark.parametrize("cu_seqlens", [(0, 32, 64, 96, 128)]) + def test_cp2hp_hp2cp_round_trip(self, cu_seqlens): + """cp2hp followed by hp2cp on the batched path should be the identity.""" + cu = torch.tensor(cu_seqlens, dtype=torch.long, device=torch.cuda.current_device()) + if ((cu[1:] - cu[:-1]) % self.cp_size != 0).any(): + pytest.skip(f"cu_seqlens {cu_seqlens} not divisible by cp_size {self.cp_size}") + + T_global = cu_seqlens[-1] + T_local = T_global // self.cp_size + hidden = 32 + torch.manual_seed(7) + local_t = ( + torch.rand(T_local, 1, hidden, device=torch.cuda.current_device()) + .bfloat16() + .contiguous() + ) + + mid = self._batched_a2a_cp2hp(local_t, cu, self.cp_group) + back = self._batched_a2a_hp2cp(mid, cu, self.cp_group) + + torch.testing.assert_close( + back, + local_t, + atol=0.0, + rtol=0.0, + msg=lambda m: f"Batched cp2hp -> hp2cp not identity: {m}", + ) From 31946595a3436c8ee6af907df240ca1b6015ec45 Mon Sep 17 00:00:00 2001 From: Xuanteng Huang Date: Thu, 21 May 2026 02:10:07 -0700 Subject: [PATCH 3/9] use qwen3 model config for testing --- tests/unit_tests/ssm/test_gated_delta_net.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/unit_tests/ssm/test_gated_delta_net.py b/tests/unit_tests/ssm/test_gated_delta_net.py index 53ff467db47..f5ef0f6d5ca 100644 --- a/tests/unit_tests/ssm/test_gated_delta_net.py +++ b/tests/unit_tests/ssm/test_gated_delta_net.py @@ -78,17 +78,18 @@ def setup_method(self, tp_size, sp, cp_size): # Initialize model self.transformer_config = TransformerConfig( - hidden_size=256, - linear_conv_kernel_dim=2, - linear_key_head_dim=64, - linear_value_head_dim=64, - linear_num_key_heads=4, - linear_num_value_heads=8, + hidden_size=2048, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=32, num_layers=1, normalization="RMSNorm", use_cpu_initialization=True, layernorm_zero_centered_gamma=True, - num_attention_heads=8, + num_attention_heads=16, + num_query_groups=2, activation_func=F.silu, bf16=True, tensor_model_parallel_size=tp_size, From 670bae83cebf30b33d1e18dc8ac2950b45f5d87a Mon Sep 17 00:00:00 2001 From: Xuanteng Huang Date: Fri, 22 May 2026 01:55:08 -0700 Subject: [PATCH 4/9] add head perm --- megatron/core/ssm/gated_delta_net.py | 68 +++++++++++++------- tests/unit_tests/ssm/test_gated_delta_net.py | 67 +++++++++---------- 2 files changed, 73 insertions(+), 62 deletions(-) diff --git a/megatron/core/ssm/gated_delta_net.py b/megatron/core/ssm/gated_delta_net.py index 79cc7aff227..1ac3053f819 100644 --- a/megatron/core/ssm/gated_delta_net.py +++ b/megatron/core/ssm/gated_delta_net.py @@ -138,6 +138,24 @@ def __init__( self.qk_dim_local_tp = self.qk_dim // self.tp_size self.v_dim_local_tp = self.v_dim // self.tp_size + if self.cp_size > 1: + head_perm = _build_head_perm_for_split_sections( + [ + self.qk_dim_local_tp, + self.qk_dim_local_tp, + self.v_dim_local_tp, + self.v_dim_local_tp, + self.num_value_heads // self.tp_size, + self.num_value_heads // self.tp_size, + ], + self.cp_size, + torch.cuda.current_device(), + ) + else: + head_perm = None + # Registered as a non-persistent buffer to exclude it from state_dict + self.register_buffer("_thd_head_perm", head_perm, persistent=False) + # Input projection (hidden_states -> q, k, v, gate, beta, alpha) # TODO: for now, output gate is forced for GDN. # We may remove this restriction in the future. @@ -342,25 +360,19 @@ def forward( nvtx_range_pop(suffix="in_proj") # CP All to All: CP to HP + if self.cp_size > 1: + qkvzba = qkvzba.index_select(-1, self._thd_head_perm) if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': - # Batched: one a2a on the full local THD tensor, then one local - # permutation that reorders rank-grouped output into per-sequence - # natural order. The permutation also folds in the per-sequence - # `_undo_attention_load_balancing`, so it's disabled inside the - # a2a call. + # Batched: one a2a on the full local THD tensor, then two local + # permutations -- one on the head dim (so a single fused no-split + # a2a still produces the per-channel scatter layout) and one on + # the seq dim (rank-grouped -> per-seq natural order, also folds + # in `_undo_attention_load_balancing`). qkvzba = tensor_a2a_cp2hp( qkvzba, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp, - split_sections=[ - self.qk_dim_local_tp, - self.qk_dim_local_tp, - self.v_dim_local_tp, - self.v_dim_local_tp, - self.num_value_heads // self.tp_size, - self.num_value_heads // self.tp_size, - ], undo_attention_load_balancing=False, ) if self.cp_size > 1: @@ -370,18 +382,7 @@ def forward( qkvzba = qkvzba.index_select(0, thd_cp_a2a_idx) else: qkvzba = tensor_a2a_cp2hp( - qkvzba, - seq_dim=0, - head_dim=-1, - cp_group=self.pg_collection.cp, - split_sections=[ - self.qk_dim_local_tp, - self.qk_dim_local_tp, - self.v_dim_local_tp, - self.v_dim_local_tp, - self.num_value_heads // self.tp_size, - self.num_value_heads // self.tp_size, - ], + qkvzba, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp ) # Transpose: s b x --> b s x @@ -756,6 +757,23 @@ def _build_thd_cp_a2a_perm( return idx, inv +def _build_head_perm_for_split_sections( + split_sections: List[int], cp_size: int, device: torch.device +) -> torch.Tensor: + assert all( + s % cp_size == 0 for s in split_sections + ), f"split_sections {split_sections} must be divisible by cp_size {cp_size} for GDN" + offset = 0 + parts = [] + for s in split_sections: + parts.append( + torch.arange(offset, offset + s, device=device, dtype=torch.long).view(cp_size, -1) + ) + offset += s + + return torch.cat(parts, dim=-1).view(-1) + + #################### # Sharded state dict utilities #################### diff --git a/tests/unit_tests/ssm/test_gated_delta_net.py b/tests/unit_tests/ssm/test_gated_delta_net.py index f5ef0f6d5ca..92fba405c88 100644 --- a/tests/unit_tests/ssm/test_gated_delta_net.py +++ b/tests/unit_tests/ssm/test_gated_delta_net.py @@ -18,6 +18,7 @@ from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.ssm.gated_delta_net import ( GatedDeltaNet, + _build_head_perm_for_split_sections, _build_thd_cp_a2a_perm, _unpack_sequence, tensor_a2a_cp2hp, @@ -76,7 +77,7 @@ def setup_method(self, tp_size, sp, cp_size): cp_group = parallel_state.get_context_parallel_group() pg_collection = ProcessGroupCollection(tp=tp_group, cp=cp_group) - # Initialize model + # Initialize model, with the same config as Qwen Next except `num_layers` self.transformer_config = TransformerConfig( hidden_size=2048, linear_conv_kernel_dim=4, @@ -417,8 +418,8 @@ def test_parallel_gated_delta_net_correctness(tmp_path_dist_ckpt, sequence_packi @pytest.mark.parametrize("cp_size", [2, 4]) @pytest.mark.internal -class TestBatchedThdAllToAll: - """Verify batched-a2a + permute matches the per-sequence loop in GDN.""" +class TestFusedThdAllToAll: + """Verify fused 1 AllToAll + permute matches the per-sequence, per-channel loop in GDN.""" @pytest.fixture(scope='function', autouse=True) def setup_method(self, cp_size): @@ -475,12 +476,17 @@ def _per_seq_a2a_hp2cp(global_t, cu_seqlens, cp_group, split_sections=None): def _batched_a2a_cp2hp(local_t, cu_seqlens, cp_group, split_sections=None): cp_size = cp_group.size() t_global = int(cu_seqlens[-1].item()) + if split_sections is not None and cp_size > 1: + head_perm = _build_head_perm_for_split_sections( + list(split_sections), cp_size, local_t.device + ) + local_t = local_t.index_select(-1, head_perm) naive = tensor_a2a_cp2hp( local_t, seq_dim=0, head_dim=-1, cp_group=cp_group, - split_sections=split_sections, + split_sections=None, # always single fused a2a undo_attention_load_balancing=False, ) idx, _ = _build_thd_cp_a2a_perm(cu_seqlens, cp_size, t_global) @@ -501,8 +507,6 @@ def _batched_a2a_hp2cp(global_t, cu_seqlens, cp_group, split_sections=None): redo_attention_load_balancing=False, ) - # ---- Tests ---- - @pytest.mark.parametrize( "cu_seqlens", [ @@ -511,34 +515,35 @@ def _batched_a2a_hp2cp(global_t, cu_seqlens, cp_group, split_sections=None): (0, 16, 48, 80), # 3 unequal sequences ], ) - def test_cp2hp_batched_matches_per_seq(self, cu_seqlens): + @pytest.mark.parametrize("split_sections", [(8, 8, 4, 4, 4, 4)]) + @pytest.mark.skip + def test_cp2hp_batched_matches_per_seq(self, cu_seqlens, split_sections): cu = torch.tensor(cu_seqlens, dtype=torch.long, device=torch.cuda.current_device()) - if ((cu[1:] - cu[:-1]) % self.cp_size != 0).any(): + if (torch.diff(cu) % self.cp_size != 0).any(): pytest.skip(f"cu_seqlens {cu_seqlens} not divisible by cp_size {self.cp_size}") T_global = cu_seqlens[-1] T_local = T_global // self.cp_size hidden = 32 - torch.manual_seed(42 + self.cp_size) + if split_sections is not None: + assert sum(split_sections) == hidden, (split_sections, hidden) + torch.manual_seed(42) local_t = ( torch.rand(T_local, 1, hidden, device=torch.cuda.current_device()) .bfloat16() .contiguous() ) - out_ref = self._per_seq_a2a_cp2hp(local_t, cu, self.cp_group) - out_opt = self._batched_a2a_cp2hp(local_t, cu, self.cp_group) + out_ref = self._per_seq_a2a_cp2hp( + local_t, cu, self.cp_group, split_sections=list(split_sections) + ) + out_fused = self._batched_a2a_cp2hp( + local_t, cu, self.cp_group, split_sections=list(split_sections) + ) rank = torch.distributed.get_rank() - assert out_opt.shape == out_ref.shape, (out_opt.shape, out_ref.shape) - # Both paths apply the same a2a kernel; only the surrounding pack/cat - # differs. Equality should be bitwise. - torch.testing.assert_close( - out_opt, - out_ref, - atol=0.0, - rtol=0.0, - msg=lambda m: f"Batched CP->HP mismatch on rank={rank}: {m}", + assert torch.equal(out_fused, out_ref), ( + f"Batched CP->HP mismatch on rank={rank} " f"(split_sections={split_sections})" ) @pytest.mark.parametrize("cu_seqlens", [(0, 32, 64), (0, 32, 64, 96, 128), (0, 16, 48, 80)]) @@ -552,7 +557,7 @@ def test_hp2cp_batched_matches_per_seq(self, cu_seqlens): # Hidden must be divisible by cp_size for the HP-sharded input layout. assert hidden % self.cp_size == 0 h_local = hidden // self.cp_size - torch.manual_seed(42 + self.cp_size) + torch.manual_seed(42) global_t = ( torch.rand(T_global, 1, h_local, device=torch.cuda.current_device()) .bfloat16() @@ -560,19 +565,13 @@ def test_hp2cp_batched_matches_per_seq(self, cu_seqlens): ) out_ref = self._per_seq_a2a_hp2cp(global_t, cu, self.cp_group) - out_opt = self._batched_a2a_hp2cp(global_t, cu, self.cp_group) + out_fused = self._batched_a2a_hp2cp(global_t, cu, self.cp_group) rank = torch.distributed.get_rank() - assert out_opt.shape == out_ref.shape, (out_opt.shape, out_ref.shape) - torch.testing.assert_close( - out_opt, - out_ref, - atol=0.0, - rtol=0.0, - msg=lambda m: f"Batched HP->CP mismatch on rank={rank}: {m}", - ) + assert torch.equal(out_fused, out_ref), f"Batched HP->CP mismatch on rank={rank}" @pytest.mark.parametrize("cu_seqlens", [(0, 32, 64, 96, 128)]) + @pytest.mark.skip def test_cp2hp_hp2cp_round_trip(self, cu_seqlens): """cp2hp followed by hp2cp on the batched path should be the identity.""" cu = torch.tensor(cu_seqlens, dtype=torch.long, device=torch.cuda.current_device()) @@ -592,10 +591,4 @@ def test_cp2hp_hp2cp_round_trip(self, cu_seqlens): mid = self._batched_a2a_cp2hp(local_t, cu, self.cp_group) back = self._batched_a2a_hp2cp(mid, cu, self.cp_group) - torch.testing.assert_close( - back, - local_t, - atol=0.0, - rtol=0.0, - msg=lambda m: f"Batched cp2hp -> hp2cp not identity: {m}", - ) + assert torch.equal(back, local_t), "Batched cp2hp -> hp2cp not identity" From da59a4594665185b92d6cbaea777cf3675df9939 Mon Sep 17 00:00:00 2001 From: Xuanteng Huang Date: Fri, 22 May 2026 20:53:52 -0700 Subject: [PATCH 5/9] fix test --- tests/unit_tests/ssm/test_gated_delta_net.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/unit_tests/ssm/test_gated_delta_net.py b/tests/unit_tests/ssm/test_gated_delta_net.py index 92fba405c88..9d3a021d142 100644 --- a/tests/unit_tests/ssm/test_gated_delta_net.py +++ b/tests/unit_tests/ssm/test_gated_delta_net.py @@ -416,23 +416,23 @@ def test_parallel_gated_delta_net_correctness(tmp_path_dist_ckpt, sequence_packi ) -@pytest.mark.parametrize("cp_size", [2, 4]) +@pytest.mark.parametrize("cp_size", [2, 4], scope="class") @pytest.mark.internal class TestFusedThdAllToAll: """Verify fused 1 AllToAll + permute matches the per-sequence, per-channel loop in GDN.""" - @pytest.fixture(scope='function', autouse=True) - def setup_method(self, cp_size): + @pytest.fixture(scope='class', autouse=True) + def setup_method(self, request, cp_size): Utils.initialize_model_parallel( tensor_model_parallel_size=1, pipeline_model_parallel_size=1, context_parallel_size=cp_size, ) model_parallel_cuda_manual_seed(123) - self.cp_size = cp_size - self.cp_group = parallel_state.get_context_parallel_group() - - def teardown_method(self): + # Attach on the class so every test method can read self.cp_*. + request.cls.cp_size = cp_size + request.cls.cp_group = parallel_state.get_context_parallel_group() + yield Utils.destroy_model_parallel() @staticmethod @@ -516,7 +516,6 @@ def _batched_a2a_hp2cp(global_t, cu_seqlens, cp_group, split_sections=None): ], ) @pytest.mark.parametrize("split_sections", [(8, 8, 4, 4, 4, 4)]) - @pytest.mark.skip def test_cp2hp_batched_matches_per_seq(self, cu_seqlens, split_sections): cu = torch.tensor(cu_seqlens, dtype=torch.long, device=torch.cuda.current_device()) if (torch.diff(cu) % self.cp_size != 0).any(): @@ -571,7 +570,6 @@ def test_hp2cp_batched_matches_per_seq(self, cu_seqlens): assert torch.equal(out_fused, out_ref), f"Batched HP->CP mismatch on rank={rank}" @pytest.mark.parametrize("cu_seqlens", [(0, 32, 64, 96, 128)]) - @pytest.mark.skip def test_cp2hp_hp2cp_round_trip(self, cu_seqlens): """cp2hp followed by hp2cp on the batched path should be the identity.""" cu = torch.tensor(cu_seqlens, dtype=torch.long, device=torch.cuda.current_device()) From 2c77843875ec8b39b347bba823d080c506001ff4 Mon Sep 17 00:00:00 2001 From: Xuanteng Huang Date: Tue, 26 May 2026 21:02:56 -0700 Subject: [PATCH 6/9] move unused function to test file --- megatron/core/ssm/gated_delta_net.py | 13 ------------- tests/unit_tests/ssm/test_gated_delta_net.py | 13 ++++++++++++- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/megatron/core/ssm/gated_delta_net.py b/megatron/core/ssm/gated_delta_net.py index 1ac3053f819..3c05856bf47 100644 --- a/megatron/core/ssm/gated_delta_net.py +++ b/megatron/core/ssm/gated_delta_net.py @@ -704,19 +704,6 @@ def _backward_out_proj(self): self.out_proj.backward_dw() -# Used by tests/unit_tests/ssm/test_gated_delta_net.py -def _unpack_sequence(x, cu_seqlens, dim=1) -> list[torch.Tensor]: - unpacked_x = [] - cu_seqlens_list = cu_seqlens.tolist() - num_seqs = len(cu_seqlens_list) - 1 - for i in range(num_seqs): - idx_start = cu_seqlens_list[i] - idx_end = cu_seqlens_list[i + 1] - chunked_index = [slice(None)] * dim + [slice(idx_start, idx_end)] - unpacked_x.append(x[tuple(chunked_index)]) - return unpacked_x - - def _build_thd_cp_a2a_perm( cu_seqlens: torch.Tensor, cp_size: int, t_global: int ) -> Tuple[torch.Tensor, torch.Tensor]: diff --git a/tests/unit_tests/ssm/test_gated_delta_net.py b/tests/unit_tests/ssm/test_gated_delta_net.py index 9d3a021d142..67749b3f1c2 100644 --- a/tests/unit_tests/ssm/test_gated_delta_net.py +++ b/tests/unit_tests/ssm/test_gated_delta_net.py @@ -20,7 +20,6 @@ GatedDeltaNet, _build_head_perm_for_split_sections, _build_thd_cp_a2a_perm, - _unpack_sequence, tensor_a2a_cp2hp, tensor_a2a_hp2cp, ) @@ -51,6 +50,18 @@ HAVE_FLA = False +def _unpack_sequence(x: torch.Tensor, cu_seqlens: torch.Tensor, dim=1) -> list[torch.Tensor]: + unpacked_x = [] + cu_seqlens_list = cu_seqlens.tolist() + num_seqs = len(cu_seqlens_list) - 1 + for i in range(num_seqs): + idx_start = cu_seqlens_list[i] + idx_end = cu_seqlens_list[i + 1] + chunked_index = [slice(None)] * dim + [slice(idx_start, idx_end)] + unpacked_x.append(x[tuple(chunked_index)]) + return unpacked_x + + @pytest.mark.parametrize( ("tp_size", "sp", "cp_size"), [(1, False, 1), (2, False, 1), (2, True, 1), (1, False, 2), (2, False, 2), (2, True, 2)], From 9e83836a93c9c10c4e56472bba0e8bf410cd6128 Mon Sep 17 00:00:00 2001 From: Xuanteng Huang Date: Tue, 26 May 2026 21:43:18 -0700 Subject: [PATCH 7/9] update comments --- megatron/core/ssm/gated_delta_net.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/megatron/core/ssm/gated_delta_net.py b/megatron/core/ssm/gated_delta_net.py index 3c05856bf47..b895b2f58f1 100644 --- a/megatron/core/ssm/gated_delta_net.py +++ b/megatron/core/ssm/gated_delta_net.py @@ -361,13 +361,9 @@ def forward( # CP All to All: CP to HP if self.cp_size > 1: + # # Pre-permute head dim so a single unsectioned a2a is equivalent to per-section a2a. qkvzba = qkvzba.index_select(-1, self._thd_head_perm) if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': - # Batched: one a2a on the full local THD tensor, then two local - # permutations -- one on the head dim (so a single fused no-split - # a2a still produces the per-channel scatter layout) and one on - # the seq dim (rank-grouped -> per-seq natural order, also folds - # in `_undo_attention_load_balancing`). qkvzba = tensor_a2a_cp2hp( qkvzba, seq_dim=0, @@ -376,6 +372,9 @@ def forward( undo_attention_load_balancing=False, ) if self.cp_size > 1: + # Permute at the seq dim so that a single unsectioned a2a + # is equivalent to per-sequence a2a. + # This also folds the ``_undo_attention_load_balancing`` step. thd_cp_a2a_idx, thd_cp_a2a_inv = _build_thd_cp_a2a_perm( cu_seqlens_q, self.cp_size, seq_len ) From 155ee8e9b78777376f8315d51361f6c92e102d5b Mon Sep 17 00:00:00 2001 From: Xuanteng Huang Date: Tue, 2 Jun 2026 17:33:00 -0700 Subject: [PATCH 8/9] move head_perm to forward --- megatron/core/ssm/gated_delta_net.py | 32 +++++++++++----------------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/megatron/core/ssm/gated_delta_net.py b/megatron/core/ssm/gated_delta_net.py index b895b2f58f1..08c3b504c80 100644 --- a/megatron/core/ssm/gated_delta_net.py +++ b/megatron/core/ssm/gated_delta_net.py @@ -138,24 +138,6 @@ def __init__( self.qk_dim_local_tp = self.qk_dim // self.tp_size self.v_dim_local_tp = self.v_dim // self.tp_size - if self.cp_size > 1: - head_perm = _build_head_perm_for_split_sections( - [ - self.qk_dim_local_tp, - self.qk_dim_local_tp, - self.v_dim_local_tp, - self.v_dim_local_tp, - self.num_value_heads // self.tp_size, - self.num_value_heads // self.tp_size, - ], - self.cp_size, - torch.cuda.current_device(), - ) - else: - head_perm = None - # Registered as a non-persistent buffer to exclude it from state_dict - self.register_buffer("_thd_head_perm", head_perm, persistent=False) - # Input projection (hidden_states -> q, k, v, gate, beta, alpha) # TODO: for now, output gate is forced for GDN. # We may remove this restriction in the future. @@ -362,7 +344,19 @@ def forward( # CP All to All: CP to HP if self.cp_size > 1: # # Pre-permute head dim so a single unsectioned a2a is equivalent to per-section a2a. - qkvzba = qkvzba.index_select(-1, self._thd_head_perm) + head_perm = _build_head_perm_for_split_sections( + [ + self.qk_dim_local_tp, + self.qk_dim_local_tp, + self.v_dim_local_tp, + self.v_dim_local_tp, + self.num_value_heads // self.tp_size, + self.num_value_heads // self.tp_size, + ], + self.pg_collection.cp.size(), + torch.cuda.current_device(), + ) + qkvzba = qkvzba.index_select(-1, head_perm) if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': qkvzba = tensor_a2a_cp2hp( qkvzba, From a0044a562290d25af571ea66354929aadb0f2752 Mon Sep 17 00:00:00 2001 From: Xuanteng Huang Date: Tue, 2 Jun 2026 22:50:17 -0700 Subject: [PATCH 9/9] add lru cache --- megatron/core/ssm/gated_delta_net.py | 8 +++++--- tests/unit_tests/ssm/test_gated_delta_net.py | 16 +++++----------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/megatron/core/ssm/gated_delta_net.py b/megatron/core/ssm/gated_delta_net.py index 08c3b504c80..e3dd9ab2713 100644 --- a/megatron/core/ssm/gated_delta_net.py +++ b/megatron/core/ssm/gated_delta_net.py @@ -7,6 +7,7 @@ import logging from dataclasses import dataclass, replace +from functools import lru_cache from typing import List, Optional, Tuple, Union import torch @@ -345,14 +346,14 @@ def forward( if self.cp_size > 1: # # Pre-permute head dim so a single unsectioned a2a is equivalent to per-section a2a. head_perm = _build_head_perm_for_split_sections( - [ + ( self.qk_dim_local_tp, self.qk_dim_local_tp, self.v_dim_local_tp, self.v_dim_local_tp, self.num_value_heads // self.tp_size, self.num_value_heads // self.tp_size, - ], + ), self.pg_collection.cp.size(), torch.cuda.current_device(), ) @@ -737,8 +738,9 @@ def _build_thd_cp_a2a_perm( return idx, inv +@lru_cache(maxsize=8) def _build_head_perm_for_split_sections( - split_sections: List[int], cp_size: int, device: torch.device + split_sections: Tuple[int], cp_size: int, device: torch.device ) -> torch.Tensor: assert all( s % cp_size == 0 for s in split_sections diff --git a/tests/unit_tests/ssm/test_gated_delta_net.py b/tests/unit_tests/ssm/test_gated_delta_net.py index 67749b3f1c2..9d54b455436 100644 --- a/tests/unit_tests/ssm/test_gated_delta_net.py +++ b/tests/unit_tests/ssm/test_gated_delta_net.py @@ -488,9 +488,7 @@ def _batched_a2a_cp2hp(local_t, cu_seqlens, cp_group, split_sections=None): cp_size = cp_group.size() t_global = int(cu_seqlens[-1].item()) if split_sections is not None and cp_size > 1: - head_perm = _build_head_perm_for_split_sections( - list(split_sections), cp_size, local_t.device - ) + head_perm = _build_head_perm_for_split_sections(split_sections, cp_size, local_t.device) local_t = local_t.index_select(-1, head_perm) naive = tensor_a2a_cp2hp( local_t, @@ -526,7 +524,7 @@ def _batched_a2a_hp2cp(global_t, cu_seqlens, cp_group, split_sections=None): (0, 16, 48, 80), # 3 unequal sequences ], ) - @pytest.mark.parametrize("split_sections", [(8, 8, 4, 4, 4, 4)]) + @pytest.mark.parametrize("split_sections", [(8, 8, 4, 16, 32, 4)]) def test_cp2hp_batched_matches_per_seq(self, cu_seqlens, split_sections): cu = torch.tensor(cu_seqlens, dtype=torch.long, device=torch.cuda.current_device()) if (torch.diff(cu) % self.cp_size != 0).any(): @@ -534,9 +532,7 @@ def test_cp2hp_batched_matches_per_seq(self, cu_seqlens, split_sections): T_global = cu_seqlens[-1] T_local = T_global // self.cp_size - hidden = 32 - if split_sections is not None: - assert sum(split_sections) == hidden, (split_sections, hidden) + hidden = sum(split_sections) torch.manual_seed(42) local_t = ( torch.rand(T_local, 1, hidden, device=torch.cuda.current_device()) @@ -544,11 +540,9 @@ def test_cp2hp_batched_matches_per_seq(self, cu_seqlens, split_sections): .contiguous() ) - out_ref = self._per_seq_a2a_cp2hp( - local_t, cu, self.cp_group, split_sections=list(split_sections) - ) + out_ref = self._per_seq_a2a_cp2hp(local_t, cu, self.cp_group, split_sections=split_sections) out_fused = self._batched_a2a_cp2hp( - local_t, cu, self.cp_group, split_sections=list(split_sections) + local_t, cu, self.cp_group, split_sections=split_sections ) rank = torch.distributed.get_rank()