Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ __pycache__/

*.log
*.out
/megatron_output/
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ modelscope
peft>=0.11,<0.19
safetensors
tqdm
transformers>=4.33,<5.4.0
transformers>=4.33,<5.6.0
175 changes: 168 additions & 7 deletions src/mcore_bridge/bridge/gpt_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from transformers.utils import ContextManagers
from typing import Callable, List, Optional, Union

from mcore_bridge.config import ModelConfig
from mcore_bridge.tuners import LoraParallelLinear
from mcore_bridge.utils import (MxFp4Dequantizer, SafetensorLazyLoader, StreamingSafetensorSaver, deep_getattr,
gc_collect, get_logger, is_master, unwrap_model)
Expand Down Expand Up @@ -45,7 +46,7 @@ class GPTBridge:
hf_shared_expert_key = None
hf_expert_bias_key = 'gate.e_score_correction_bias'

def __init__(self, config):
def __init__(self, config: ModelConfig):
self.config = config
self._disable_tqdm = False
self._target_device = None
Expand Down Expand Up @@ -115,6 +116,8 @@ def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]:
'word_embeddings',
'linear_qkv',
'in_proj',
'in_proj_qkvz',
'in_proj_ba',
'conv1d',
# mla
'linear_q_proj',
Expand Down Expand Up @@ -705,9 +708,9 @@ def _set_moe_state(

def _get_hf_experts_attr(self, is_mtp: bool = False):
# return hf_grouped, is_gate_up
if self.model_type in {'glm4v_moe', 'kimi_vl', 'qwen3_omni_moe'} or self.llm_model_type in {
if self.model_type in {'glm4v_moe', 'kimi_vl', 'qwen3_omni_moe', 'qwen3_5_moe'} or self.llm_model_type in {
'qwen2_moe', 'qwen3_moe', 'deepseek_v2', 'deepseek_v3', 'kimi_k2', 'dots1', 'ernie4_5_moe', 'glm4_moe',
'glm4_moe_lite', 'minimax_m2', 'olmoe', 'qwen3_next', 'qwen3_5_moe', 'glm_moe_dsa', 'deepseek_v32'
'glm4_moe_lite', 'minimax_m2', 'olmoe', 'qwen3_next', 'glm_moe_dsa', 'deepseek_v32'
}:
return False, False
elif self.model_type in {'qwen3_vl_moe', 'llama4'} or self.llm_model_type in {'gpt_oss'}:
Expand Down Expand Up @@ -1231,11 +1234,153 @@ def _set_indexer(self, mg_indexer, hf_state_dict, hf_prefix: str, to_mcore: bool
hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
return hf_state_dict

def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
def _set_linear_decoupled_in_proj(self, mg_attn, hf_state_dict, to_mcore: bool):
config = self.config
num_key_heads = config.linear_num_key_heads
key_dim = config.linear_key_head_dim
value_dim = config.linear_value_head_dim * config.linear_num_value_heads // num_key_heads
hidden_size_block = config.hidden_size // self.fp8_block_size
if to_mcore:
hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
if isinstance(mg_attn.in_proj_qkvz, LoraParallelLinear):
lora_A = hf_state_dict['in_proj_qkv.lora_A.weight'].load()
assert (lora_A == hf_state_dict['in_proj_z.lora_A.weight'].load()).all(), \
'Need to ensure QKVZ\'s lora_A are consistent'
qkv_lora_B = hf_state_dict['in_proj_qkv.lora_B.weight'].load()
q_lora_B, k_lora_B, v_lora_B = torch.split(
qkv_lora_B, [key_dim * num_key_heads, key_dim * num_key_heads, value_dim * num_key_heads], dim=0)
lora_B = torch.cat([
*(x.reshape(num_key_heads, -1, qkv_lora_B.shape[-1]) for x in [q_lora_B, k_lora_B, v_lora_B]),
hf_state_dict['in_proj_z.lora_B.weight'].load().reshape(num_key_heads, -1, qkv_lora_B.shape[-1])
],
dim=1).reshape(-1, qkv_lora_B.shape[-1])
self._set_weight(mg_attn.in_proj_qkvz.lora_A[self._adapter_name].weight, lora_A,
'in_proj_qkvz.lora_A.weight')
self._set_weight(mg_attn.in_proj_qkvz.lora_B[self._adapter_name].weight, lora_B,
'in_proj_qkvz.lora_B.weight')
elif not self._peft_format:
qkv = hf_state_dict['in_proj_qkv.weight'].load()
q, k, v = torch.split(
qkv, [key_dim * num_key_heads, key_dim * num_key_heads, value_dim * num_key_heads], dim=0)
in_proj_weight = torch.cat([
*(x.reshape(num_key_heads, -1, config.hidden_size) for x in [q, k, v]),
hf_state_dict['in_proj_z.weight'].load().reshape(num_key_heads, -1, config.hidden_size)
],
dim=1).reshape((-1, config.hidden_size))
in_scale_inv = None
if 'in_proj_qkv.weight_scale_inv' in hf_state_dict:
qkv_scale_inv = hf_state_dict['in_proj_qkv.weight_scale_inv'].load()
q_si, k_si, v_si = torch.split(
qkv_scale_inv,
[x * num_key_heads // self.fp8_block_size for x in [key_dim, key_dim, value_dim]],
dim=0)
in_scale_inv = torch.cat([
*(x.reshape(num_key_heads, -1, hidden_size_block) for x in [q_si, k_si, v_si]),
hf_state_dict['in_proj_z.weight_scale_inv'].load().reshape(num_key_heads, -1,
hidden_size_block),
],
dim=1).reshape((-1, hidden_size_block))
self._set_weight(
mg_attn.in_proj_qkvz.weight, in_proj_weight, 'in_proj_qkvz.weight', hf_scale_inv=in_scale_inv)
else:
hf_state_dict = {}
qkv_dim = key_dim * 2 + value_dim
is_lora = False if mg_attn is None else isinstance(mg_attn.in_proj_qkvz,
LoraParallelLinear) and self._peft_format
is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
Comment thread
Jintao-Huang marked this conversation as resolved.
if self.pp_size > 1:
dist.all_reduce(is_lora, group=self.pp_group)
if is_lora:
lora_A, _ = self._get_weight(
None if mg_attn is None else mg_attn.in_proj_qkvz.lora_A[self._adapter_name].weight.data,
f'in_proj_qkvz.lora_A.{self._adapter_name}.weight')
lora_B, _ = self._get_weight(
None if mg_attn is None else mg_attn.in_proj_qkvz.lora_B[self._adapter_name].weight.data,
f'in_proj_qkvz.lora_B.{self._adapter_name}.weight')
if lora_A is not None:
lora_B = lora_B.reshape(num_key_heads, -1, lora_B.shape[-1])
self._peft_target_modules.update({'in_proj_qkv', 'in_proj_z'})
for key in ['in_proj_qkv', 'in_proj_z']:
hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone()
q_lora_B = lora_B[:, :key_dim].reshape(-1, lora_B.shape[-1])
k_lora_B = lora_B[:, key_dim:2 * key_dim].reshape(-1, lora_B.shape[-1])
v_lora_B = lora_B[:, 2 * key_dim:qkv_dim].reshape(-1, lora_B.shape[-1])
hf_state_dict['in_proj_qkv.lora_B.weight'] = torch.concat([q_lora_B, k_lora_B, v_lora_B], dim=0)
hf_state_dict['in_proj_z.lora_B.weight'] = lora_B[:, qkv_dim:].reshape(-1, lora_B.shape[-1]).clone()
elif not self._peft_format:
in_proj_weight, scale_inv = self._get_weight(
None if mg_attn is None else mg_attn.in_proj_qkvz.weight.data, 'in_proj_qkvz.weight')
if in_proj_weight is not None:
in_proj_weight = in_proj_weight.reshape(num_key_heads, -1, config.hidden_size)
q = in_proj_weight[:, :key_dim].reshape(-1, config.hidden_size)
k = in_proj_weight[:, key_dim:2 * key_dim].reshape(-1, config.hidden_size)
v = in_proj_weight[:, 2 * key_dim:qkv_dim].reshape(-1, config.hidden_size)
hf_state_dict['in_proj_qkv.weight'] = torch.concat([q, k, v], dim=0)
hf_state_dict['in_proj_z.weight'] = in_proj_weight[:, qkv_dim:].reshape(-1,
config.hidden_size).clone()
if scale_inv is not None:
key_block = key_dim // self.fp8_block_size
qkv_block = qkv_dim // self.fp8_block_size
scale_inv = scale_inv.reshape(num_key_heads, -1, hidden_size_block)
q = scale_inv[:, :key_block].reshape(-1, hidden_size_block)
k = scale_inv[:, key_block:2 * key_block].reshape(-1, hidden_size_block)
v = scale_inv[:, 2 * key_block:qkv_block].reshape(-1, hidden_size_block)
hf_state_dict['in_proj_qkv.weight_scale_inv'] = torch.concat([q, k, v], dim=0)
hf_state_dict['in_proj_z.weight_scale_inv'] = scale_inv[:, qkv_block:].reshape(
-1, hidden_size_block).clone()
if to_mcore:
if isinstance(mg_attn.in_proj_ba, LoraParallelLinear):
lora_A = hf_state_dict['in_proj_b.lora_A.weight'].load()
assert (lora_A == hf_state_dict['in_proj_a.lora_A.weight'].load()).all(), \
'Need to ensure BA\'s lora_A are consistent'
b_lora_B = hf_state_dict['in_proj_b.lora_B.weight'].load()
lora_B = torch.cat([
b_lora_B.reshape(num_key_heads, -1, b_lora_B.shape[-1]),
hf_state_dict['in_proj_a.lora_B.weight'].load().reshape(num_key_heads, -1, b_lora_B.shape[-1]),
],
dim=1).reshape(-1, b_lora_B.shape[-1])
self._set_weight(mg_attn.in_proj_ba.lora_A[self._adapter_name].weight, lora_A,
'in_proj_ba.lora_A.weight')
self._set_weight(mg_attn.in_proj_ba.lora_B[self._adapter_name].weight, lora_B,
'in_proj_ba.lora_B.weight')
elif not self._peft_format:
in_proj_weight = torch.cat([
hf_state_dict[f'{key}.weight'].load().reshape(num_key_heads, -1, config.hidden_size)
for key in ['in_proj_b', 'in_proj_a']
],
dim=1).reshape((-1, config.hidden_size))
self._set_weight(mg_attn.in_proj_ba.weight, in_proj_weight, 'in_proj_ba.weight')
else:
a_dim = config.linear_num_value_heads // num_key_heads
is_lora = False if mg_attn is None else isinstance(mg_attn.in_proj_ba,
LoraParallelLinear) and self._peft_format
is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
Comment thread
Jintao-Huang marked this conversation as resolved.
if self.pp_size > 1:
dist.all_reduce(is_lora, group=self.pp_group)
if is_lora:
lora_A, _ = self._get_weight(
None if mg_attn is None else mg_attn.in_proj_ba.lora_A[self._adapter_name].weight.data,
f'in_proj_ba.lora_A.{self._adapter_name}.weight')
lora_B, _ = self._get_weight(
None if mg_attn is None else mg_attn.in_proj_ba.lora_B[self._adapter_name].weight.data,
f'in_proj_ba.lora_B.{self._adapter_name}.weight')
if lora_A is not None:
lora_B = lora_B.reshape(num_key_heads, -1, lora_B.shape[-1])
self._peft_target_modules.update({'in_proj_b', 'in_proj_a'})
for key in ['in_proj_b', 'in_proj_a']:
hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone()
hf_state_dict['in_proj_b.lora_B.weight'] = lora_B[:, :-a_dim].reshape(-1, lora_B.shape[-1]).clone()
hf_state_dict['in_proj_a.lora_B.weight'] = lora_B[:, -a_dim:].reshape(-1, lora_B.shape[-1]).clone()
elif not self._peft_format:
in_proj_weight, _ = self._get_weight(None if mg_attn is None else mg_attn.in_proj_ba.weight.data,
'in_proj_ba.weight')
if in_proj_weight is not None:
in_proj_weight = in_proj_weight.reshape(num_key_heads, -1, config.hidden_size)
hf_state_dict['in_proj_b.weight'] = in_proj_weight[:, :-a_dim].reshape(-1,
config.hidden_size).clone()
hf_state_dict['in_proj_a.weight'] = in_proj_weight[:, -a_dim:].reshape(-1,
config.hidden_size).clone()
return hf_state_dict

def _set_linear_in_proj(self, mg_attn, hf_state_dict, to_mcore: bool):
config = self.config
num_key_heads = config.linear_num_key_heads
key_dim = config.linear_key_head_dim
Expand Down Expand Up @@ -1314,6 +1459,21 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
-1, config.hidden_size).clone()
hf_state_dict['in_proj_a.weight'] = in_proj_weight[:, -a_dim:].reshape(-1,
config.hidden_size).clone()
return hf_state_dict

def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
if to_mcore:
hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
else:
hf_state_dict = {}
config = self.config
num_key_heads = config.linear_num_key_heads
key_dim = config.linear_key_head_dim
value_dim = config.linear_value_head_dim * config.linear_num_value_heads // num_key_heads
if config.linear_decoupled_in_proj:
hf_state_dict.update(self._set_linear_decoupled_in_proj(mg_attn, hf_state_dict, to_mcore))
else:
hf_state_dict.update(self._set_linear_in_proj(mg_attn, hf_state_dict, to_mcore))
if not self._peft_format:
if to_mcore:
conv1d = hf_state_dict['conv1d.weight'].load()
Expand Down Expand Up @@ -1597,7 +1757,8 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx:
hf_state_dict = {}
self._convert_mtp_extra(mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict)
transformer_layer = None if mtp_layer is None else mtp_layer.transformer_layer
if not to_mcore and not self.llm_model_type == 'qwen3_next':
# TODO: check
if not to_mcore and self.llm_model_type in {'deepseek_v3', 'deepseek_v32', 'glm4_moe', 'glm4_moe_lite'}:
Comment thread
Jintao-Huang marked this conversation as resolved.
self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, 'embed_tokens.weight',
to_mcore)
if self.config.untie_embeddings_and_output_weights:
Expand Down
1 change: 1 addition & 0 deletions src/mcore_bridge/config/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ class ModelConfig(TransformerConfig):
linear_conv_kernel_dim: Optional[int] = None
layernorm_zero_centered_gamma: bool = False
attention_output_gate: bool = False
linear_decoupled_in_proj: bool = False

# dsa
experimental_attention_variant: Optional[Literal['gated_delta_net', 'dsa']] = None
Expand Down
13 changes: 11 additions & 2 deletions src/mcore_bridge/model/mm_gpts/qwen3_5_gdn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TENorm
from megatron.core.transformer.attention import SelfAttention
from typing import Optional

Expand All @@ -20,8 +21,13 @@ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: boo
if is_linear_attention:
hf_state_dict.update(
self._set_linear_attn_state(mg_attn, hf_state_dict, 'linear_attn.', layer_idx, to_mcore))
self._set_state_dict(mg_layer, 'self_attention.in_proj.layer_norm_weight', hf_state_dict,
'input_layernorm.weight', to_mcore)

if self.config.linear_decoupled_in_proj:
self._set_state_dict(mg_layer, 'input_layernorm.weight', hf_state_dict, 'input_layernorm.weight',
to_mcore)
else:
self._set_state_dict(mg_layer, 'self_attention.in_proj.layer_norm_weight', hf_state_dict,
'input_layernorm.weight', to_mcore)
else:
hf_state_dict.update(self._set_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, to_mcore))
self._set_state_dict(mg_layer, 'self_attention.linear_qkv.layer_norm_weight', hf_state_dict,
Expand All @@ -44,6 +50,9 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
layer_spec.submodules.self_attention.module = GatedSelfAttention
else:
layer_spec.submodules.self_attention.module = GatedDeltaNet
if self.config.linear_decoupled_in_proj:
layer_spec.submodules.input_layernorm = TENorm
layer_spec.submodules.self_attention.submodules.in_proj = TEColumnParallelLinear
return layer_specs

def build_model(
Expand Down
Loading
Loading