Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 95 additions & 47 deletions megatron/core/ssm/gated_delta_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import logging
from dataclasses import dataclass, replace
from functools import lru_cache
from typing import List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -342,40 +343,40 @@ def forward(
nvtx_range_pop(suffix="in_proj")

# CP All to All: CP to HP
if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd':
unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens_q // self.cp_size, dim=0)
outputs = []
for qkvzba_i in unpacked_qkvzba:
qkvzba_i = tensor_a2a_cp2hp(
qkvzba_i,
seq_dim=0,
head_dim=-1,
cp_group=self.pg_collection.cp,
split_sections=[
self.qk_dim_local_tp,
self.qk_dim_local_tp,
self.v_dim_local_tp,
self.v_dim_local_tp,
self.num_value_heads // self.tp_size,
self.num_value_heads // self.tp_size,
],
)
outputs.append(qkvzba_i)
qkvzba = torch.cat(outputs, dim=0)
else:
qkvzba = tensor_a2a_cp2hp(
qkvzba,
seq_dim=0,
head_dim=-1,
cp_group=self.pg_collection.cp,
split_sections=[
if self.cp_size > 1:
# Pre-permute the head dim so a single unsectioned a2a is equivalent to a per-section a2a.
head_perm = _build_head_perm_for_split_sections(
(
self.qk_dim_local_tp,
self.qk_dim_local_tp,
self.v_dim_local_tp,
self.v_dim_local_tp,
self.num_value_heads // self.tp_size,
self.num_value_heads // self.tp_size,
],
),
self.pg_collection.cp.size(),
torch.cuda.current_device(),
)
qkvzba = qkvzba.index_select(-1, head_perm)
if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd':
qkvzba = tensor_a2a_cp2hp(
qkvzba,
seq_dim=0,
head_dim=-1,
cp_group=self.pg_collection.cp,
undo_attention_load_balancing=False,
)
if self.cp_size > 1:
# Permute at the seq dim so that a single unsectioned a2a
# is equivalent to per-sequence a2a.
# This also folds the ``_undo_attention_load_balancing`` step.
thd_cp_a2a_idx, thd_cp_a2a_inv = _build_thd_cp_a2a_perm(
cu_seqlens_q, self.cp_size, seq_len
)
qkvzba = qkvzba.index_select(0, thd_cp_a2a_idx)
else:
qkvzba = tensor_a2a_cp2hp(
qkvzba, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp
)

# Transpose: s b x --> b s x
Expand Down Expand Up @@ -489,14 +490,15 @@ def forward(

# CP all to all: HP to CP
if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd':
unpacked_norm_out = _unpack_sequence(norm_out, cu_seqlens_q, dim=0)
outputs = []
for norm_out_i in unpacked_norm_out:
norm_out_i = tensor_a2a_hp2cp(
norm_out_i, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp
)
outputs.append(norm_out_i)
norm_out = torch.cat(outputs, dim=0)
if self.cp_size > 1:
norm_out = norm_out.index_select(0, thd_cp_a2a_inv)
norm_out = tensor_a2a_hp2cp(
norm_out,
seq_dim=0,
head_dim=-1,
cp_group=self.pg_collection.cp,
redo_attention_load_balancing=False,
)
else:
norm_out = tensor_a2a_hp2cp(
norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp
Expand Down Expand Up @@ -574,7 +576,7 @@ def _compute_g_and_beta(self, A_log_local_cp, dt_bias_local_cp, alpha, beta):

def _resolve_cu_seqlens(
self, cu_seqlens_padded, cu_seqlens_actual, total_seq_len, name, cp_size: int = 1
):
) -> torch.Tensor:
"""Resolve cu_seqlens for packed sequence all-to-all, handling alignment padding."""
if cu_seqlens_padded is not None:
cu_seqlens = cu_seqlens_padded
Expand Down Expand Up @@ -696,16 +698,62 @@ def _backward_out_proj(self):
self.out_proj.backward_dw()


def _unpack_sequence(x, cu_seqlens, dim=1):
unpacked_x = []
cu_seqlens_list = cu_seqlens.tolist()
num_seqs = len(cu_seqlens_list) - 1
for i in range(num_seqs):
idx_start = cu_seqlens_list[i]
idx_end = cu_seqlens_list[i + 1]
chunked_index = [slice(None)] * dim + [slice(idx_start, idx_end)]
unpacked_x.append(x[tuple(chunked_index)])
return unpacked_x
def _build_thd_cp_a2a_perm(
cu_seqlens: torch.Tensor, cp_size: int, t_global: int
) -> Tuple[torch.Tensor, torch.Tensor]:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: let's move toward new-style type specification using built-in tuple instead of typing.Tuple.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I convert all type annotation (Optional, Tuple, etc.) in this file to modern one?

cu = cu_seqlens.to(dtype=torch.long)
t_local = t_global // cp_size

positions = torch.arange(t_global, device=cu.device)
seq_idx = torch.bucketize(positions, cu[1:], right=True)
seq_lens = torch.diff(cu)
halves = seq_lens // (2 * cp_size) # per-sequence half-chunk size
local_starts = cu[:-1] // cp_size
Comment thread
xuantengh marked this conversation as resolved.
global_starts = cu[:-1]

half_i = halves[seq_idx]
pos_in_seq = positions - global_starts[seq_idx]

natural_chunk = pos_in_seq // half_i # in [0, 2*cp)
offset = pos_in_seq - natural_chunk * half_i

# Invert the ordering produced by `_undo_attention_load_balancing`:
# natural_chunk < cp: load_balanced = 2 * natural_chunk
# natural_chunk >= cp: load_balanced = 4*cp - 2*natural_chunk - 1
lb_chunk = torch.where(
natural_chunk < cp_size, 2 * natural_chunk, 4 * cp_size - 2 * natural_chunk - 1
)

# In the per-sequence load-balanced layout each rank owns load-balanced
# chunks (2r) and (2r+1), in that order, of every sequence.
rank = lb_chunk // 2
half_within_rank = lb_chunk - 2 * rank
k = half_within_rank * half_i + offset

idx = rank * t_local + local_starts[seq_idx] + k

inv = torch.empty_like(idx)
inv[idx] = positions

return idx, inv


@lru_cache(maxsize=8)
def _build_head_perm_for_split_sections(
Comment thread
xuantengh marked this conversation as resolved.
split_sections: Tuple[int], cp_size: int, device: torch.device
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same nit as before, but also, the type should be tuple[int, ...].

) -> torch.Tensor:
assert all(
s % cp_size == 0 for s in split_sections
), f"split_sections {split_sections} must be divisible by cp_size {cp_size} for GDN"
offset = 0
parts = []
for s in split_sections:
parts.append(
torch.arange(offset, offset + s, device=device, dtype=torch.long).view(cp_size, -1)
)
offset += s

return torch.cat(parts, dim=-1).view(-1)


####################
Expand Down
36 changes: 7 additions & 29 deletions megatron/core/ssm/mamba_context_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch.nn.functional as F

from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.tensor_parallel import all_to_all
from megatron.core.tensor_parallel.mappings import all_to_all_hp2sp, all_to_all_sp2hp
from megatron.core.utils import is_te_min_version

try:
Expand Down Expand Up @@ -302,7 +302,6 @@ def _slice_vector_param(self, param: torch.Tensor, has_hdim: bool = False) -> to
return param[start:end]


# TODO(duncan): Consider combining with all_to_all_sp2hp in mappings.py and using einops.rearrange
def _all_to_all_cp2hp(
input_: torch.Tensor, cp_group: torch.distributed.ProcessGroup
) -> torch.Tensor:
Expand All @@ -324,24 +323,12 @@ def _all_to_all_cp2hp(
"""
assert input_.dim() == 3, "all_to_all_cp2hp assumes 3-d input shape."
s_in, b_in, h_in = input_.shape
# Squash the first two dimensions -> [s*b, h]
input_ = input_.reshape(-1, h_in)
# Split into world_size chunks along the h dimension
world_size = cp_group.size()
h_out = h_in // world_size
split_tensors = torch.split(input_, split_size_or_sections=h_out, dim=1)
# Concat the chunks along the s*b dimension
concat_tensor = torch.cat(split_tensors, dim=0)
# TODO(duncan): Can the following be optimized by using the non-single (tensor list) version of
# all-to-all?
# Swap chunks of dim0 across the cp ranks
output = all_to_all(cp_group, concat_tensor)
# Recover the s and b dimensions
output = output.reshape(s_in * world_size, b_in, h_out)
s_out, h_out = s_in * cp_group.size(), h_in // cp_group.size()
output = all_to_all_sp2hp(input_, group=cp_group)
output = output.reshape(s_out, b_in, h_out)
return output

Comment thread
xuantengh marked this conversation as resolved.

# TODO(duncan): Consider combining with all_to_all_hp2sp in mappings.py and using einops.rearrange
def _all_to_all_hp2cp(
input_: torch.Tensor, cp_group: torch.distributed.ProcessGroup
) -> torch.Tensor:
Expand All @@ -363,18 +350,9 @@ def _all_to_all_hp2cp(
"""
assert input_.dim() == 3, "all_to_all_hp2cp assumes 3-d input shape."
s_in, b_in, h_in = input_.shape
# Squash the first two dimensions -> [s*b, h]
input_ = input_.reshape(-1, h_in)
# Swap chunks of dim0 across the cp ranks
input_exchanged = all_to_all(cp_group, input_)
# Split into world_size chunks along the s*b dimension
world_size = cp_group.size()
s_out = s_in // world_size
split_tensors = torch.split(input_exchanged, split_size_or_sections=s_out * b_in, dim=0)
# Concat the chunks along the h dimension
output = torch.cat(split_tensors, dim=-1)
# Recover the s and b dimensions
output = output.reshape(s_out, b_in, h_in * world_size)
s_out, h_out = s_in // cp_group.size(), h_in * cp_group.size()
output = all_to_all_hp2sp(input_, group=cp_group)
output = output.reshape(s_out, b_in, h_out)
return output


Expand Down
Loading