diff --git a/megatron/core/ssm/gated_delta_net.py b/megatron/core/ssm/gated_delta_net.py index 1dbaf161e73..e3dd9ab2713 100644 --- a/megatron/core/ssm/gated_delta_net.py +++ b/megatron/core/ssm/gated_delta_net.py @@ -7,6 +7,7 @@ import logging from dataclasses import dataclass, replace +from functools import lru_cache from typing import List, Optional, Tuple, Union import torch @@ -342,40 +343,40 @@ def forward( nvtx_range_pop(suffix="in_proj") # CP All to All: CP to HP - if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': - unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens_q // self.cp_size, dim=0) - outputs = [] - for qkvzba_i in unpacked_qkvzba: - qkvzba_i = tensor_a2a_cp2hp( - qkvzba_i, - seq_dim=0, - head_dim=-1, - cp_group=self.pg_collection.cp, - split_sections=[ - self.qk_dim_local_tp, - self.qk_dim_local_tp, - self.v_dim_local_tp, - self.v_dim_local_tp, - self.num_value_heads // self.tp_size, - self.num_value_heads // self.tp_size, - ], - ) - outputs.append(qkvzba_i) - qkvzba = torch.cat(outputs, dim=0) - else: - qkvzba = tensor_a2a_cp2hp( - qkvzba, - seq_dim=0, - head_dim=-1, - cp_group=self.pg_collection.cp, - split_sections=[ + if self.cp_size > 1: + # # Pre-permute head dim so a single unsectioned a2a is equivalent to per-section a2a. + head_perm = _build_head_perm_for_split_sections( + ( self.qk_dim_local_tp, self.qk_dim_local_tp, self.v_dim_local_tp, self.v_dim_local_tp, self.num_value_heads // self.tp_size, self.num_value_heads // self.tp_size, - ], + ), + self.pg_collection.cp.size(), + torch.cuda.current_device(), + ) + qkvzba = qkvzba.index_select(-1, head_perm) + if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': + qkvzba = tensor_a2a_cp2hp( + qkvzba, + seq_dim=0, + head_dim=-1, + cp_group=self.pg_collection.cp, + undo_attention_load_balancing=False, + ) + if self.cp_size > 1: + # Permute at the seq dim so that a single unsectioned a2a + # is equivalent to per-sequence a2a. + # This also folds the ``_undo_attention_load_balancing`` step. + thd_cp_a2a_idx, thd_cp_a2a_inv = _build_thd_cp_a2a_perm( + cu_seqlens_q, self.cp_size, seq_len + ) + qkvzba = qkvzba.index_select(0, thd_cp_a2a_idx) + else: + qkvzba = tensor_a2a_cp2hp( + qkvzba, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp ) # Transpose: s b x --> b s x @@ -489,14 +490,15 @@ def forward( # CP all to all: HP to CP if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': - unpacked_norm_out = _unpack_sequence(norm_out, cu_seqlens_q, dim=0) - outputs = [] - for norm_out_i in unpacked_norm_out: - norm_out_i = tensor_a2a_hp2cp( - norm_out_i, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp - ) - outputs.append(norm_out_i) - norm_out = torch.cat(outputs, dim=0) + if self.cp_size > 1: + norm_out = norm_out.index_select(0, thd_cp_a2a_inv) + norm_out = tensor_a2a_hp2cp( + norm_out, + seq_dim=0, + head_dim=-1, + cp_group=self.pg_collection.cp, + redo_attention_load_balancing=False, + ) else: norm_out = tensor_a2a_hp2cp( norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp @@ -574,7 +576,7 @@ def _compute_g_and_beta(self, A_log_local_cp, dt_bias_local_cp, alpha, beta): def _resolve_cu_seqlens( self, cu_seqlens_padded, cu_seqlens_actual, total_seq_len, name, cp_size: int = 1 - ): + ) -> torch.Tensor: """Resolve cu_seqlens for packed sequence all-to-all, handling alignment padding.""" if cu_seqlens_padded is not None: cu_seqlens = cu_seqlens_padded @@ -696,16 +698,62 @@ def _backward_out_proj(self): self.out_proj.backward_dw() -def _unpack_sequence(x, cu_seqlens, dim=1): - unpacked_x = [] - cu_seqlens_list = cu_seqlens.tolist() - num_seqs = len(cu_seqlens_list) - 1 - for i in range(num_seqs): - idx_start = cu_seqlens_list[i] - idx_end = cu_seqlens_list[i + 1] - chunked_index = [slice(None)] * dim + [slice(idx_start, idx_end)] - unpacked_x.append(x[tuple(chunked_index)]) - return unpacked_x +def _build_thd_cp_a2a_perm( + cu_seqlens: torch.Tensor, cp_size: int, t_global: int +) -> Tuple[torch.Tensor, torch.Tensor]: + cu = cu_seqlens.to(dtype=torch.long) + t_local = t_global // cp_size + + positions = torch.arange(t_global, device=cu.device) + seq_idx = torch.bucketize(positions, cu[1:], right=True) + seq_lens = torch.diff(cu) + halves = seq_lens // (2 * cp_size) # per-sequence half-chunk size + local_starts = cu[:-1] // cp_size + global_starts = cu[:-1] + + half_i = halves[seq_idx] + pos_in_seq = positions - global_starts[seq_idx] + + natural_chunk = pos_in_seq // half_i # in [0, 2*cp) + offset = pos_in_seq - natural_chunk * half_i + + # Invert the ordering produced by `_undo_attention_load_balancing`: + # natural_chunk < cp: load_balanced = 2 * natural_chunk + # natural_chunk >= cp: load_balanced = 4*cp - 2*natural_chunk - 1 + lb_chunk = torch.where( + natural_chunk < cp_size, 2 * natural_chunk, 4 * cp_size - 2 * natural_chunk - 1 + ) + + # In the per-sequence load-balanced layout each rank owns load-balanced + # chunks (2r) and (2r+1), in that order, of every sequence. + rank = lb_chunk // 2 + half_within_rank = lb_chunk - 2 * rank + k = half_within_rank * half_i + offset + + idx = rank * t_local + local_starts[seq_idx] + k + + inv = torch.empty_like(idx) + inv[idx] = positions + + return idx, inv + + +@lru_cache(maxsize=8) +def _build_head_perm_for_split_sections( + split_sections: Tuple[int], cp_size: int, device: torch.device +) -> torch.Tensor: + assert all( + s % cp_size == 0 for s in split_sections + ), f"split_sections {split_sections} must be divisible by cp_size {cp_size} for GDN" + offset = 0 + parts = [] + for s in split_sections: + parts.append( + torch.arange(offset, offset + s, device=device, dtype=torch.long).view(cp_size, -1) + ) + offset += s + + return torch.cat(parts, dim=-1).view(-1) #################### diff --git a/megatron/core/ssm/mamba_context_parallel.py b/megatron/core/ssm/mamba_context_parallel.py index 3297728d5fe..5c040716069 100644 --- a/megatron/core/ssm/mamba_context_parallel.py +++ b/megatron/core/ssm/mamba_context_parallel.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.tensor_parallel import all_to_all +from megatron.core.tensor_parallel.mappings import all_to_all_hp2sp, all_to_all_sp2hp from megatron.core.utils import is_te_min_version try: @@ -302,7 +302,6 @@ def _slice_vector_param(self, param: torch.Tensor, has_hdim: bool = False) -> to return param[start:end] -# TODO(duncan): Consider combining with all_to_all_sp2hp in mappings.py and using einops.rearrange def _all_to_all_cp2hp( input_: torch.Tensor, cp_group: torch.distributed.ProcessGroup ) -> torch.Tensor: @@ -324,24 +323,12 @@ def _all_to_all_cp2hp( """ assert input_.dim() == 3, "all_to_all_cp2hp assumes 3-d input shape." s_in, b_in, h_in = input_.shape - # Squash the first two dimensions -> [s*b, h] - input_ = input_.reshape(-1, h_in) - # Split into world_size chunks along the h dimension - world_size = cp_group.size() - h_out = h_in // world_size - split_tensors = torch.split(input_, split_size_or_sections=h_out, dim=1) - # Concat the chunks along the s*b dimension - concat_tensor = torch.cat(split_tensors, dim=0) - # TODO(duncan): Can the following be optimized by using the non-single (tensor list) version of - # all-to-all? - # Swap chunks of dim0 across the cp ranks - output = all_to_all(cp_group, concat_tensor) - # Recover the s and b dimensions - output = output.reshape(s_in * world_size, b_in, h_out) + s_out, h_out = s_in * cp_group.size(), h_in // cp_group.size() + output = all_to_all_sp2hp(input_, group=cp_group) + output = output.reshape(s_out, b_in, h_out) return output -# TODO(duncan): Consider combining with all_to_all_hp2sp in mappings.py and using einops.rearrange def _all_to_all_hp2cp( input_: torch.Tensor, cp_group: torch.distributed.ProcessGroup ) -> torch.Tensor: @@ -363,18 +350,9 @@ def _all_to_all_hp2cp( """ assert input_.dim() == 3, "all_to_all_hp2cp assumes 3-d input shape." s_in, b_in, h_in = input_.shape - # Squash the first two dimensions -> [s*b, h] - input_ = input_.reshape(-1, h_in) - # Swap chunks of dim0 across the cp ranks - input_exchanged = all_to_all(cp_group, input_) - # Split into world_size chunks along the s*b dimension - world_size = cp_group.size() - s_out = s_in // world_size - split_tensors = torch.split(input_exchanged, split_size_or_sections=s_out * b_in, dim=0) - # Concat the chunks along the h dimension - output = torch.cat(split_tensors, dim=-1) - # Recover the s and b dimensions - output = output.reshape(s_out, b_in, h_in * world_size) + s_out, h_out = s_in // cp_group.size(), h_in * cp_group.size() + output = all_to_all_hp2sp(input_, group=cp_group) + output = output.reshape(s_out, b_in, h_out) return output diff --git a/tests/unit_tests/ssm/test_gated_delta_net.py b/tests/unit_tests/ssm/test_gated_delta_net.py index 6bd7e3ea841..9d54b455436 100644 --- a/tests/unit_tests/ssm/test_gated_delta_net.py +++ b/tests/unit_tests/ssm/test_gated_delta_net.py @@ -16,7 +16,13 @@ ) from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.ssm.gated_delta_net import GatedDeltaNet +from megatron.core.ssm.gated_delta_net import ( + GatedDeltaNet, + _build_head_perm_for_split_sections, + _build_thd_cp_a2a_perm, + tensor_a2a_cp2hp, + tensor_a2a_hp2cp, +) from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer import TransformerConfig from megatron.core.utils import unwrap_model @@ -44,6 +50,18 @@ HAVE_FLA = False +def _unpack_sequence(x: torch.Tensor, cu_seqlens: torch.Tensor, dim=1) -> list[torch.Tensor]: + unpacked_x = [] + cu_seqlens_list = cu_seqlens.tolist() + num_seqs = len(cu_seqlens_list) - 1 + for i in range(num_seqs): + idx_start = cu_seqlens_list[i] + idx_end = cu_seqlens_list[i + 1] + chunked_index = [slice(None)] * dim + [slice(idx_start, idx_end)] + unpacked_x.append(x[tuple(chunked_index)]) + return unpacked_x + + @pytest.mark.parametrize( ("tp_size", "sp", "cp_size"), [(1, False, 1), (2, False, 1), (2, True, 1), (1, False, 2), (2, False, 2), (2, True, 2)], @@ -70,19 +88,20 @@ def setup_method(self, tp_size, sp, cp_size): cp_group = parallel_state.get_context_parallel_group() pg_collection = ProcessGroupCollection(tp=tp_group, cp=cp_group) - # Initialize model + # Initialize model, with the same config as Qwen Next except `num_layers` self.transformer_config = TransformerConfig( - hidden_size=256, - linear_conv_kernel_dim=2, - linear_key_head_dim=64, - linear_value_head_dim=64, - linear_num_key_heads=4, - linear_num_value_heads=8, + hidden_size=2048, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=32, num_layers=1, normalization="RMSNorm", use_cpu_initialization=True, layernorm_zero_centered_gamma=True, - num_attention_heads=8, + num_attention_heads=16, + num_query_groups=2, activation_func=F.silu, bf16=True, tensor_model_parallel_size=tp_size, @@ -406,3 +425,173 @@ def test_parallel_gated_delta_net_correctness(tmp_path_dist_ckpt, sequence_packi micro_batch_size=4, sequence_packing=sequence_packing, ) + + +@pytest.mark.parametrize("cp_size", [2, 4], scope="class") +@pytest.mark.internal +class TestFusedThdAllToAll: + """Verify fused 1 AllToAll + permute matches the per-sequence, per-channel loop in GDN.""" + + @pytest.fixture(scope='class', autouse=True) + def setup_method(self, request, cp_size): + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=cp_size, + ) + model_parallel_cuda_manual_seed(123) + # Attach on the class so every test method can read self.cp_*. + request.cls.cp_size = cp_size + request.cls.cp_group = parallel_state.get_context_parallel_group() + yield + Utils.destroy_model_parallel() + + @staticmethod + def _per_seq_a2a_cp2hp(local_t, cu_seqlens, cp_group, split_sections=None): + cp_size = cp_group.size() + unpacked = _unpack_sequence(local_t, cu_seqlens // cp_size, dim=0) + outputs = [] + for x in unpacked: + outputs.append( + tensor_a2a_cp2hp( + x, + seq_dim=0, + head_dim=-1, + cp_group=cp_group, + split_sections=split_sections, + undo_attention_load_balancing=True, + ) + ) + return torch.cat(outputs, dim=0) + + @staticmethod + def _per_seq_a2a_hp2cp(global_t, cu_seqlens, cp_group, split_sections=None): + unpacked = _unpack_sequence(global_t, cu_seqlens, dim=0) + outputs = [] + for x in unpacked: + outputs.append( + tensor_a2a_hp2cp( + x, + seq_dim=0, + head_dim=-1, + cp_group=cp_group, + split_sections=split_sections, + redo_attention_load_balancing=True, + ) + ) + return torch.cat(outputs, dim=0) + + # ---- Optimized: single a2a + production permutation helper ---- + + @staticmethod + def _batched_a2a_cp2hp(local_t, cu_seqlens, cp_group, split_sections=None): + cp_size = cp_group.size() + t_global = int(cu_seqlens[-1].item()) + if split_sections is not None and cp_size > 1: + head_perm = _build_head_perm_for_split_sections(split_sections, cp_size, local_t.device) + local_t = local_t.index_select(-1, head_perm) + naive = tensor_a2a_cp2hp( + local_t, + seq_dim=0, + head_dim=-1, + cp_group=cp_group, + split_sections=None, # always single fused a2a + undo_attention_load_balancing=False, + ) + idx, _ = _build_thd_cp_a2a_perm(cu_seqlens, cp_size, t_global) + return naive.index_select(0, idx) + + @staticmethod + def _batched_a2a_hp2cp(global_t, cu_seqlens, cp_group, split_sections=None): + cp_size = cp_group.size() + t_global = int(cu_seqlens[-1].item()) + _, inv = _build_thd_cp_a2a_perm(cu_seqlens, cp_size, t_global) + permuted = global_t.index_select(0, inv) + return tensor_a2a_hp2cp( + permuted, + seq_dim=0, + head_dim=-1, + cp_group=cp_group, + split_sections=split_sections, + redo_attention_load_balancing=False, + ) + + @pytest.mark.parametrize( + "cu_seqlens", + [ + (0, 32, 64), # 2 equal sequences + (0, 32, 64, 96, 128), # 4 equal sequences (matches existing THD test) + (0, 16, 48, 80), # 3 unequal sequences + ], + ) + @pytest.mark.parametrize("split_sections", [(8, 8, 4, 16, 32, 4)]) + def test_cp2hp_batched_matches_per_seq(self, cu_seqlens, split_sections): + cu = torch.tensor(cu_seqlens, dtype=torch.long, device=torch.cuda.current_device()) + if (torch.diff(cu) % self.cp_size != 0).any(): + pytest.skip(f"cu_seqlens {cu_seqlens} not divisible by cp_size {self.cp_size}") + + T_global = cu_seqlens[-1] + T_local = T_global // self.cp_size + hidden = sum(split_sections) + torch.manual_seed(42) + local_t = ( + torch.rand(T_local, 1, hidden, device=torch.cuda.current_device()) + .bfloat16() + .contiguous() + ) + + out_ref = self._per_seq_a2a_cp2hp(local_t, cu, self.cp_group, split_sections=split_sections) + out_fused = self._batched_a2a_cp2hp( + local_t, cu, self.cp_group, split_sections=split_sections + ) + + rank = torch.distributed.get_rank() + assert torch.equal(out_fused, out_ref), ( + f"Batched CP->HP mismatch on rank={rank} " f"(split_sections={split_sections})" + ) + + @pytest.mark.parametrize("cu_seqlens", [(0, 32, 64), (0, 32, 64, 96, 128), (0, 16, 48, 80)]) + def test_hp2cp_batched_matches_per_seq(self, cu_seqlens): + cu = torch.tensor(cu_seqlens, dtype=torch.long, device=torch.cuda.current_device()) + if ((cu[1:] - cu[:-1]) % self.cp_size != 0).any(): + pytest.skip(f"cu_seqlens {cu_seqlens} not divisible by cp_size {self.cp_size}") + + T_global = cu_seqlens[-1] + hidden = 32 + # Hidden must be divisible by cp_size for the HP-sharded input layout. + assert hidden % self.cp_size == 0 + h_local = hidden // self.cp_size + torch.manual_seed(42) + global_t = ( + torch.rand(T_global, 1, h_local, device=torch.cuda.current_device()) + .bfloat16() + .contiguous() + ) + + out_ref = self._per_seq_a2a_hp2cp(global_t, cu, self.cp_group) + out_fused = self._batched_a2a_hp2cp(global_t, cu, self.cp_group) + + rank = torch.distributed.get_rank() + assert torch.equal(out_fused, out_ref), f"Batched HP->CP mismatch on rank={rank}" + + @pytest.mark.parametrize("cu_seqlens", [(0, 32, 64, 96, 128)]) + def test_cp2hp_hp2cp_round_trip(self, cu_seqlens): + """cp2hp followed by hp2cp on the batched path should be the identity.""" + cu = torch.tensor(cu_seqlens, dtype=torch.long, device=torch.cuda.current_device()) + if ((cu[1:] - cu[:-1]) % self.cp_size != 0).any(): + pytest.skip(f"cu_seqlens {cu_seqlens} not divisible by cp_size {self.cp_size}") + + T_global = cu_seqlens[-1] + T_local = T_global // self.cp_size + hidden = 32 + torch.manual_seed(7) + local_t = ( + torch.rand(T_local, 1, hidden, device=torch.cuda.current_device()) + .bfloat16() + .contiguous() + ) + + mid = self._batched_a2a_cp2hp(local_t, cu, self.cp_group) + back = self._batched_a2a_hp2cp(mid, cu, self.cp_group) + + assert torch.equal(back, local_t), "Batched cp2hp -> hp2cp not identity"