diff --git a/.gitignore b/.gitignore
index a98c3bd..9dc57fb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,4 @@
 __pycache__/
 *.log
 *.out
+/megatron_output/
diff --git a/requirements.txt b/requirements.txt
index 9c0d007..56e5eb8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,4 +3,4 @@ modelscope
 peft>=0.11,<0.19
 safetensors
 tqdm
-transformers>=4.33,<5.4.0
+transformers>=4.33,<5.6.0
diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py
index 01c61cc..8864739 100644
--- a/src/mcore_bridge/bridge/gpt_bridge.py
+++ b/src/mcore_bridge/bridge/gpt_bridge.py
@@ -15,6 +15,7 @@
 from transformers.utils import ContextManagers
 from typing import Callable, List, Optional, Union
 
+from mcore_bridge.config import ModelConfig
 from mcore_bridge.tuners import LoraParallelLinear
 from mcore_bridge.utils import (MxFp4Dequantizer, SafetensorLazyLoader, StreamingSafetensorSaver, deep_getattr,
                                 gc_collect, get_logger, is_master, unwrap_model)
@@ -45,7 +46,7 @@ class GPTBridge:
     hf_shared_expert_key = None
     hf_expert_bias_key = 'gate.e_score_correction_bias'
 
-    def __init__(self, config):
+    def __init__(self, config: ModelConfig):
         self.config = config
         self._disable_tqdm = False
         self._target_device = None
@@ -115,6 +116,8 @@ def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]:
             'word_embeddings',
             'linear_qkv',
             'in_proj',
+            'in_proj_qkvz',
+            'in_proj_ba',
             'conv1d',
             # mla
             'linear_q_proj',
@@ -705,9 +708,9 @@ def _set_moe_state(
 
     def _get_hf_experts_attr(self, is_mtp: bool = False):
         # return hf_grouped, is_gate_up
-        if self.model_type in {'glm4v_moe', 'kimi_vl', 'qwen3_omni_moe'} or self.llm_model_type in {
+        if self.model_type in {'glm4v_moe', 'kimi_vl', 'qwen3_omni_moe', 'qwen3_5_moe'} or self.llm_model_type in {
                 'qwen2_moe', 'qwen3_moe', 'deepseek_v2', 'deepseek_v3', 'kimi_k2', 'dots1', 'ernie4_5_moe', 'glm4_moe',
-                'glm4_moe_lite', 'minimax_m2', 'olmoe', 'qwen3_next', 'qwen3_5_moe', 'glm_moe_dsa', 'deepseek_v32'
+                'glm4_moe_lite', 'minimax_m2', 'olmoe', 'qwen3_next', 'glm_moe_dsa', 'deepseek_v32'
         }:
             return False, False
         elif self.model_type in {'qwen3_vl_moe', 'llama4'} or self.llm_model_type in {'gpt_oss'}:
@@ -1231,11 +1234,153 @@ def _set_indexer(self, mg_indexer, hf_state_dict, hf_prefix: str, to_mcore: bool
         hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
         return hf_state_dict
 
-    def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
+    def _set_linear_decoupled_in_proj(self, mg_attn, hf_state_dict, to_mcore: bool):
+        config = self.config
+        num_key_heads = config.linear_num_key_heads
+        key_dim = config.linear_key_head_dim
+        value_dim = config.linear_value_head_dim * config.linear_num_value_heads // num_key_heads
+        hidden_size_block = config.hidden_size // self.fp8_block_size
         if to_mcore:
-            hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
+            if isinstance(mg_attn.in_proj_qkvz, LoraParallelLinear):
+                lora_A = hf_state_dict['in_proj_qkv.lora_A.weight'].load()
+                assert (lora_A == hf_state_dict['in_proj_z.lora_A.weight'].load()).all(), \
+                    'Need to ensure QKVZ\'s lora_A are consistent'
+                qkv_lora_B = hf_state_dict['in_proj_qkv.lora_B.weight'].load()
+                q_lora_B, k_lora_B, v_lora_B = torch.split(
+                    qkv_lora_B, [key_dim * num_key_heads, key_dim * num_key_heads, value_dim * num_key_heads], dim=0)
+                lora_B = torch.cat([
+                    *(x.reshape(num_key_heads, -1, qkv_lora_B.shape[-1]) for x in [q_lora_B, k_lora_B, v_lora_B]),
+                    hf_state_dict['in_proj_z.lora_B.weight'].load().reshape(num_key_heads, -1, qkv_lora_B.shape[-1])
+                ],
+                                   dim=1).reshape(-1, qkv_lora_B.shape[-1])
+                self._set_weight(mg_attn.in_proj_qkvz.lora_A[self._adapter_name].weight, lora_A,
+                                 'in_proj_qkvz.lora_A.weight')
+                self._set_weight(mg_attn.in_proj_qkvz.lora_B[self._adapter_name].weight, lora_B,
+                                 'in_proj_qkvz.lora_B.weight')
+            elif not self._peft_format:
+                qkv = hf_state_dict['in_proj_qkv.weight'].load()
+                q, k, v = torch.split(
+                    qkv, [key_dim * num_key_heads, key_dim * num_key_heads, value_dim * num_key_heads], dim=0)
+                in_proj_weight = torch.cat([
+                    *(x.reshape(num_key_heads, -1, config.hidden_size) for x in [q, k, v]),
+                    hf_state_dict['in_proj_z.weight'].load().reshape(num_key_heads, -1, config.hidden_size)
+                ],
+                                           dim=1).reshape((-1, config.hidden_size))
+                in_scale_inv = None
+                if 'in_proj_qkv.weight_scale_inv' in hf_state_dict:
+                    qkv_scale_inv = hf_state_dict['in_proj_qkv.weight_scale_inv'].load()
+                    q_si, k_si, v_si = torch.split(
+                        qkv_scale_inv,
+                        [x * num_key_heads // self.fp8_block_size for x in [key_dim, key_dim, value_dim]],
+                        dim=0)
+                    in_scale_inv = torch.cat([
+                        *(x.reshape(num_key_heads, -1, hidden_size_block) for x in [q_si, k_si, v_si]),
+                        hf_state_dict['in_proj_z.weight_scale_inv'].load().reshape(num_key_heads, -1,
+                                                                                   hidden_size_block),
+                    ],
+                                             dim=1).reshape((-1, hidden_size_block))
+                self._set_weight(
+                    mg_attn.in_proj_qkvz.weight, in_proj_weight, 'in_proj_qkvz.weight', hf_scale_inv=in_scale_inv)
         else:
-            hf_state_dict = {}
+            qkv_dim = key_dim * 2 + value_dim
+            is_lora = False if mg_attn is None else isinstance(mg_attn.in_proj_qkvz,
+                                                               LoraParallelLinear) and self._peft_format
+            is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
+            if self.pp_size > 1:
+                dist.all_reduce(is_lora, group=self.pp_group)
+            if is_lora:
+                lora_A, _ = self._get_weight(
+                    None if mg_attn is None else mg_attn.in_proj_qkvz.lora_A[self._adapter_name].weight.data,
+                    f'in_proj_qkvz.lora_A.{self._adapter_name}.weight')
+                lora_B, _ = self._get_weight(
+                    None if mg_attn is None else mg_attn.in_proj_qkvz.lora_B[self._adapter_name].weight.data,
+                    f'in_proj_qkvz.lora_B.{self._adapter_name}.weight')
+                if lora_A is not None:
+                    lora_B = lora_B.reshape(num_key_heads, -1, lora_B.shape[-1])
+                    self._peft_target_modules.update({'in_proj_qkv', 'in_proj_z'})
+                    for key in ['in_proj_qkv', 'in_proj_z']:
+                        hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone()
+                    q_lora_B = lora_B[:, :key_dim].reshape(-1, lora_B.shape[-1])
+                    k_lora_B = lora_B[:, key_dim:2 * key_dim].reshape(-1, lora_B.shape[-1])
+                    v_lora_B = lora_B[:, 2 * key_dim:qkv_dim].reshape(-1, lora_B.shape[-1])
+                    hf_state_dict['in_proj_qkv.lora_B.weight'] = torch.concat([q_lora_B, k_lora_B, v_lora_B], dim=0)
+                    hf_state_dict['in_proj_z.lora_B.weight'] = lora_B[:, qkv_dim:].reshape(-1, lora_B.shape[-1]).clone()
+            elif not self._peft_format:
+                in_proj_weight, scale_inv = self._get_weight(
+                    None if mg_attn is None else mg_attn.in_proj_qkvz.weight.data, 'in_proj_qkvz.weight')
+                if in_proj_weight is not None:
+                    in_proj_weight = in_proj_weight.reshape(num_key_heads, -1, config.hidden_size)
+                    q = in_proj_weight[:, :key_dim].reshape(-1, config.hidden_size)
+                    k = in_proj_weight[:, key_dim:2 * key_dim].reshape(-1, config.hidden_size)
+                    v = in_proj_weight[:, 2 * key_dim:qkv_dim].reshape(-1, config.hidden_size)
+                    hf_state_dict['in_proj_qkv.weight'] = torch.concat([q, k, v], dim=0)
+                    hf_state_dict['in_proj_z.weight'] = in_proj_weight[:, qkv_dim:].reshape(-1,
+                                                                                            config.hidden_size).clone()
+                if scale_inv is not None:
+                    key_block = key_dim // self.fp8_block_size
+                    qkv_block = qkv_dim // self.fp8_block_size
+                    scale_inv = scale_inv.reshape(num_key_heads, -1, hidden_size_block)
+                    q = scale_inv[:, :key_block].reshape(-1, hidden_size_block)
+                    k = scale_inv[:, key_block:2 * key_block].reshape(-1, hidden_size_block)
+                    v = scale_inv[:, 2 * key_block:qkv_block].reshape(-1, hidden_size_block)
+                    hf_state_dict['in_proj_qkv.weight_scale_inv'] = torch.concat([q, k, v], dim=0)
+                    hf_state_dict['in_proj_z.weight_scale_inv'] = scale_inv[:, qkv_block:].reshape(
+                        -1, hidden_size_block).clone()
+        if to_mcore:
+            if isinstance(mg_attn.in_proj_ba, LoraParallelLinear):
+                lora_A = hf_state_dict['in_proj_b.lora_A.weight'].load()
+                assert (lora_A == hf_state_dict['in_proj_a.lora_A.weight'].load()).all(), \
+                    'Need to ensure BA\'s lora_A are consistent'
+                b_lora_B = hf_state_dict['in_proj_b.lora_B.weight'].load()
+                lora_B = torch.cat([
+                    b_lora_B.reshape(num_key_heads, -1, b_lora_B.shape[-1]),
+                    hf_state_dict['in_proj_a.lora_B.weight'].load().reshape(num_key_heads, -1, b_lora_B.shape[-1]),
+                ],
+                                   dim=1).reshape(-1, b_lora_B.shape[-1])
+                self._set_weight(mg_attn.in_proj_ba.lora_A[self._adapter_name].weight, lora_A,
+                                 'in_proj_ba.lora_A.weight')
+                self._set_weight(mg_attn.in_proj_ba.lora_B[self._adapter_name].weight, lora_B,
+                                 'in_proj_ba.lora_B.weight')
+            elif not self._peft_format:
+                in_proj_weight = torch.cat([
+                    hf_state_dict[f'{key}.weight'].load().reshape(num_key_heads, -1, config.hidden_size)
+                    for key in ['in_proj_b', 'in_proj_a']
+                ],
+                                           dim=1).reshape((-1, config.hidden_size))
+                self._set_weight(mg_attn.in_proj_ba.weight, in_proj_weight, 'in_proj_ba.weight')
+        else:
+            a_dim = config.linear_num_value_heads // num_key_heads
+            is_lora = False if mg_attn is None else isinstance(mg_attn.in_proj_ba,
+                                                               LoraParallelLinear) and self._peft_format
+            is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
+            if self.pp_size > 1:
+                dist.all_reduce(is_lora, group=self.pp_group)
+            if is_lora:
+                lora_A, _ = self._get_weight(
+                    None if mg_attn is None else mg_attn.in_proj_ba.lora_A[self._adapter_name].weight.data,
+                    f'in_proj_ba.lora_A.{self._adapter_name}.weight')
+                lora_B, _ = self._get_weight(
+                    None if mg_attn is None else mg_attn.in_proj_ba.lora_B[self._adapter_name].weight.data,
+                    f'in_proj_ba.lora_B.{self._adapter_name}.weight')
+                if lora_A is not None:
+                    lora_B = lora_B.reshape(num_key_heads, -1, lora_B.shape[-1])
+                    self._peft_target_modules.update({'in_proj_b', 'in_proj_a'})
+                    for key in ['in_proj_b', 'in_proj_a']:
+                        hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone()
+                    hf_state_dict['in_proj_b.lora_B.weight'] = lora_B[:, :-a_dim].reshape(-1, lora_B.shape[-1]).clone()
+                    hf_state_dict['in_proj_a.lora_B.weight'] = lora_B[:, -a_dim:].reshape(-1, lora_B.shape[-1]).clone()
+            elif not self._peft_format:
+                in_proj_weight, _ = self._get_weight(None if mg_attn is None else mg_attn.in_proj_ba.weight.data,
+                                                     'in_proj_ba.weight')
+                if in_proj_weight is not None:
+                    in_proj_weight = in_proj_weight.reshape(num_key_heads, -1, config.hidden_size)
+                    hf_state_dict['in_proj_b.weight'] = in_proj_weight[:, :-a_dim].reshape(-1,
+                                                                                           config.hidden_size).clone()
+                    hf_state_dict['in_proj_a.weight'] = in_proj_weight[:, -a_dim:].reshape(-1,
+                                                                                           config.hidden_size).clone()
+        return hf_state_dict
+
+    def _set_linear_in_proj(self, mg_attn, hf_state_dict, to_mcore: bool):
         config = self.config
         num_key_heads = config.linear_num_key_heads
         key_dim = config.linear_key_head_dim
@@ -1314,6 +1459,21 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                         -1, config.hidden_size).clone()
                     hf_state_dict['in_proj_a.weight'] = in_proj_weight[:, -a_dim:].reshape(
                         -1, config.hidden_size).clone()
+        return hf_state_dict
+
+    def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
+        if to_mcore:
+            hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
+        else:
+            hf_state_dict = {}
+        config = self.config
+        num_key_heads = config.linear_num_key_heads
+        key_dim = config.linear_key_head_dim
+        value_dim = config.linear_value_head_dim * config.linear_num_value_heads // num_key_heads
+        if config.linear_decoupled_in_proj:
+            hf_state_dict.update(self._set_linear_decoupled_in_proj(mg_attn, hf_state_dict, to_mcore))
+        else:
+            hf_state_dict.update(self._set_linear_in_proj(mg_attn, hf_state_dict, to_mcore))
         if not self._peft_format:
             if to_mcore:
                 conv1d = hf_state_dict['conv1d.weight'].load()
@@ -1597,7 +1757,8 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx:
             hf_state_dict = {}
         self._convert_mtp_extra(mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict)
         transformer_layer = None if mtp_layer is None else mtp_layer.transformer_layer
-        if not to_mcore and not self.llm_model_type == 'qwen3_next':
+        # TODO: check
+        if not to_mcore and self.llm_model_type in {'deepseek_v3', 'deepseek_v32', 'glm4_moe', 'glm4_moe_lite'}:
             self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict,
                                  'embed_tokens.weight', to_mcore)
         if self.config.untie_embeddings_and_output_weights:
diff --git a/src/mcore_bridge/config/model_config.py b/src/mcore_bridge/config/model_config.py
index 253d12b..10db680 100644
--- a/src/mcore_bridge/config/model_config.py
+++ b/src/mcore_bridge/config/model_config.py
@@ -193,6 +193,7 @@ class ModelConfig(TransformerConfig):
     linear_conv_kernel_dim: Optional[int] = None
     layernorm_zero_centered_gamma: bool = False
     attention_output_gate: bool = False
+    linear_decoupled_in_proj: bool = False
     # dsa
     experimental_attention_variant: Optional[Literal['gated_delta_net', 'dsa']] = None
 
diff --git a/src/mcore_bridge/model/mm_gpts/qwen3_5_gdn.py b/src/mcore_bridge/model/mm_gpts/qwen3_5_gdn.py
index 780e0fd..d3c40a8 100644
--- a/src/mcore_bridge/model/mm_gpts/qwen3_5_gdn.py
+++ b/src/mcore_bridge/model/mm_gpts/qwen3_5_gdn.py
@@ -1,4 +1,5 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
+from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TENorm
 from megatron.core.transformer.attention import SelfAttention
 from typing import Optional
 
@@ -20,8 +21,13 @@ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: boo
         if is_linear_attention:
             hf_state_dict.update(
                 self._set_linear_attn_state(mg_attn, hf_state_dict, 'linear_attn.', layer_idx, to_mcore))
-            self._set_state_dict(mg_layer, 'self_attention.in_proj.layer_norm_weight', hf_state_dict,
-                                 'input_layernorm.weight', to_mcore)
+
+            if self.config.linear_decoupled_in_proj:
+                self._set_state_dict(mg_layer, 'input_layernorm.weight', hf_state_dict, 'input_layernorm.weight',
+                                     to_mcore)
+            else:
+                self._set_state_dict(mg_layer, 'self_attention.in_proj.layer_norm_weight', hf_state_dict,
+                                     'input_layernorm.weight', to_mcore)
         else:
             hf_state_dict.update(self._set_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, to_mcore))
             self._set_state_dict(mg_layer, 'self_attention.linear_qkv.layer_norm_weight', hf_state_dict,
@@ -44,6 +50,9 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
             layer_spec.submodules.self_attention.module = GatedSelfAttention
         else:
             layer_spec.submodules.self_attention.module = GatedDeltaNet
+            if self.config.linear_decoupled_in_proj:
+                layer_spec.submodules.input_layernorm = TENorm
+                layer_spec.submodules.self_attention.submodules.in_proj = TEColumnParallelLinear
         return layer_specs
 
     def build_model(
diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py
index 87580fd..9d372d0 100644
--- a/src/mcore_bridge/model/modules/gated_delta_net.py
+++ b/src/mcore_bridge/model/modules/gated_delta_net.py
@@ -1,11 +1,17 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
 import torch
 import torch.nn.functional as F
+import transformer_engine
+from contextlib import nullcontext
 from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.transformer.identity_op import IdentityOp
+from megatron.core.transformer.spec_utils import build_module
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default
 from typing import List, Optional
+
+from mcore_bridge.config import ModelConfig
 
 try:
     from fla.modules.convolution import causal_conv1d
     from fla.modules.l2norm import l2norm
@@ -17,7 +23,7 @@
 
 try:
     from megatron.core.ssm.gated_delta_net import GatedDeltaNet as _GatedDeltaNet
-    from megatron.core.ssm.gated_delta_net import torch_chunk_gated_delta_rule
+    from megatron.core.ssm.gated_delta_net import GatedDeltaNetSubmodules, torch_chunk_gated_delta_rule
 except ImportError:
     _GatedDeltaNet = object
 
@@ -83,6 +89,49 @@ def get_parameter_local_cp(
 
 class GatedDeltaNet(_GatedDeltaNet):
 
+    def __init__(self, config: ModelConfig, submodules: 'GatedDeltaNetSubmodules', *args, **kwargs):
+        if config.linear_decoupled_in_proj:
+            in_proj = submodules.in_proj
+            submodules.in_proj = IdentityOp
+        super().__init__(config, submodules, *args, **kwargs)
+        if not config.linear_decoupled_in_proj:
+            return
+        submodules.in_proj = in_proj
+        self.in_proj_qkvz_dim = self.qk_dim * 2 + self.v_dim * 2
+        self.in_proj_ba_dim = self.num_value_heads * 2
+        del self.in_proj
+        self.in_proj_qkvz = build_module(
+            submodules.in_proj,
+            self.hidden_size,
+            self.in_proj_qkvz_dim,
+            config=self.config,
+            init_method=self.config.init_method,
+            gather_output=False,
+            bias=self.bias,
+            skip_bias_add=False,
+            is_expert=False,
+            tp_comm_buffer_name='fc1',
+            tp_group=self.pg_collection.tp,
+        )
+        if config.fp8_param:
+            fp8_context = transformer_engine.pytorch.fp8_model_init(enabled=False)
+        else:
+            fp8_context = nullcontext()
+        with fp8_context:
+            self.in_proj_ba = build_module(
+                submodules.in_proj,
+                self.hidden_size,
+                self.in_proj_ba_dim,
+                config=self.config,
+                init_method=self.config.init_method,
+                gather_output=False,
+                bias=self.bias,
+                skip_bias_add=False,
+                is_expert=False,
+                tp_comm_buffer_name='fc1_ba',
+                tp_group=self.pg_collection.tp,
+            )
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -132,8 +181,21 @@ def forward(
         cu_seqlens = None if packed_seq_params is None else packed_seq_params.cu_seqlens_q
 
         # Input projection
+        num_key_heads_per_device = self.num_key_heads // self.tp_size // cp_size
         nvtx_range_push(suffix='in_proj')
-        qkvzba, _ = self.in_proj(hidden_states)
+        if self.config.linear_decoupled_in_proj:
+            qkvz, _ = self.in_proj_qkvz(hidden_states)
+            if self.config.fp8_param:
+                fp8_context = transformer_engine.pytorch.fp8_autocast(enabled=False)
+            else:
+                fp8_context = nullcontext()
+            with fp8_context:
+                ba, _ = self.in_proj_ba(hidden_states)
+            qkvz = qkvz.view(qkvz.shape[:-1] + (num_key_heads_per_device, qkvz.shape[-1] // num_key_heads_per_device))
+            ba = ba.view(ba.shape[:-1] + (num_key_heads_per_device, ba.shape[-1] // num_key_heads_per_device))
+            qkvzba = torch.concat([qkvz, ba], dim=-1).view(*qkvz.shape[:2], -1)
+        else:
+            qkvzba, _ = self.in_proj(hidden_states)
         nvtx_range_pop(suffix='in_proj')
 
         if cp_size > 1:
@@ -158,15 +220,11 @@ def forward(
                 head_dim=-1,
                 cp_group=self.pg_collection.cp,
             )
-
         # Transpose: s b x --> b s x
         # From sbhd to bshd format
-        qkvzba = qkvzba.transpose(0, 1)
-
-        # Split, reorder, and reshape the tensor into q, k, v, gate, beta, alpha
-        num_key_heads_per_device = self.num_key_heads // self.tp_size // cp_size
         qkvzba = qkvzba.view(qkvzba.shape[:-1] + (num_key_heads_per_device, qkvzba.shape[-1] //
                                                   num_key_heads_per_device))
+        qkvzba = qkvzba.transpose(0, 1)
 
         qkv, gate, beta, alpha = torch.split(
             qkvzba,
             [
diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py
index 4b17274..a4dd6d6 100644
--- a/src/mcore_bridge/patcher.py
+++ b/src/mcore_bridge/patcher.py
@@ -418,7 +418,6 @@ def forward(
         Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape [s, b, h],
         and optionally the updated context tensor if cross-attention is used.
         """
-        # TODO: Multimodal compatible
        assert context is None, 'multi token prediction + cross attention is not yet supported.'
         input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings(
             input_ids=input_ids,
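
Note: as a reference for reviewers, here is a minimal standalone sketch of the per-head regrouping that the new `_set_linear_decoupled_in_proj` performs between the HF-side decoupled `in_proj_qkv` / `in_proj_z` weights and the fused Megatron `in_proj_qkvz` weight. The toy sizes and plain tensors are assumptions made purely for illustration; this is not the bridge API itself.

import torch

# Toy stand-ins for linear_num_key_heads, linear_key_head_dim, the per-key-head value/z width,
# and hidden_size; the names mirror the bridge code, the values are made up for the example.
num_key_heads, key_dim, value_dim, hidden_size = 4, 8, 16, 32

# HF-side decoupled projections: q/k/v stacked head-major inside 'in_proj_qkv.weight',
# z stored separately as 'in_proj_z.weight'.
q = torch.randn(num_key_heads * key_dim, hidden_size)
k = torch.randn(num_key_heads * key_dim, hidden_size)
v = torch.randn(num_key_heads * value_dim, hidden_size)
z = torch.randn(num_key_heads * value_dim, hidden_size)
qkv = torch.cat([q, k, v], dim=0)

# HF -> mcore: regroup rows so each key head carries a contiguous [q_h, k_h, v_h, z_h] slab,
# matching the fused 'in_proj_qkvz.weight' layout built in _set_linear_decoupled_in_proj.
fused = torch.cat([
    *(x.reshape(num_key_heads, -1, hidden_size) for x in (q, k, v)),
    z.reshape(num_key_heads, -1, hidden_size),
], dim=1).reshape(-1, hidden_size)

# mcore -> HF: slice each per-head slab back apart, as on the export path.
qkv_dim = 2 * key_dim + value_dim
per_head = fused.reshape(num_key_heads, -1, hidden_size)
q2 = per_head[:, :key_dim].reshape(-1, hidden_size)
k2 = per_head[:, key_dim:2 * key_dim].reshape(-1, hidden_size)
v2 = per_head[:, 2 * key_dim:qkv_dim].reshape(-1, hidden_size)
z2 = per_head[:, qkv_dim:].reshape(-1, hidden_size)

# Round-tripping recovers the original decoupled tensors.
assert torch.equal(torch.cat([q2, k2, v2], dim=0), qkv)
assert torch.equal(z2, z)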