-
Notifications
You must be signed in to change notification settings - Fork 6
Support MTP weight reuse with unrolled steps #29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: release/1.1
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,7 +16,8 @@ | |
| scatter_to_sequence_parallel_region) | ||
| from megatron.core.transformer import TransformerLayer | ||
| from megatron.core.transformer.multi_latent_attention import MLASelfAttention, MultiLatentAttention | ||
| from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer | ||
| from megatron.core.transformer.multi_token_prediction import (MultiTokenPredictionBlock, MultiTokenPredictionLayer, | ||
| get_mtp_layer_offset) | ||
| from megatron.core.utils import deprecate_inference_params | ||
| from packaging import version | ||
| from peft.tuners.tuners_utils import BaseTuner | ||
|
|
@@ -394,6 +395,7 @@ def forward( | |
| packed_seq_params: PackedSeqParams = None, | ||
| sequence_len_offset: torch.Tensor = None, | ||
| embedding=None, | ||
| depth_idx: int = None, | ||
| ): | ||
| """ | ||
| Execute the forward pass through the Multi-Token Prediction (MTP) layer. | ||
|
|
@@ -417,14 +419,17 @@ def forward( | |
| Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape | ||
| [s, b, h], and optionally the updated context tensor if cross-attention is used. | ||
| """ | ||
| # TODO: Multimodal compatible | ||
| # current unroll depth | ||
| effective_depth = self.layer_number if depth_idx is None else depth_idx | ||
|
|
||
| assert context is None, 'multi token prediction + cross attention is not yet supported.' | ||
| input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings( | ||
| input_ids=input_ids, | ||
| position_ids=position_ids, | ||
| embedding=embedding, | ||
| packed_seq_params=packed_seq_params, | ||
| hidden_states=hidden_states, | ||
| depth=effective_depth, | ||
| ) | ||
| assert not self.transformer_layer.self_attention.config.apply_rope_fusion | ||
| packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' | ||
|
|
@@ -433,7 +438,7 @@ def forward( | |
| rotary_pos_emb = rotary_pos_emb[position_ids[0]] | ||
| else: | ||
| # mrope or not packed_seq | ||
| rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-self.layer_number, dims=0) | ||
| rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-effective_depth, dims=0) | ||
| if self.config.recompute_granularity == 'full' and self.training: | ||
| hidden_states = self._checkpointed_forward( | ||
| partial( | ||
|
|
@@ -471,13 +476,65 @@ def forward( | |
|
|
||
| MultiTokenPredictionLayer.forward = forward | ||
|
|
||
def block_forward(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
    context: torch.Tensor = None,
    context_mask: torch.Tensor = None,
    rotary_pos_emb: torch.Tensor = None,
    rotary_pos_cos: torch.Tensor = None,
    rotary_pos_sin: torch.Tensor = None,
    attention_bias: torch.Tensor = None,
    inference_params: InferenceParams = None,
    packed_seq_params: PackedSeqParams = None,
    sequence_len_offset: torch.Tensor = None,
    extra_block_kwargs: Optional[dict] = None,
    embedding=None,
):
    """Perform the forward pass through all MTP modules with optional layer reuse.

    When ``config.mtp_unroll_steps`` is set, the physically available MTP
    layers are reused cyclically (``step % len(self.layers)``) so that more
    unrolled prediction depths can be run than there are distinct layer
    weights; otherwise one step is run per configured MTP layer
    (``config.mtp_num_layers``).

    Args:
        input_ids: Token ids; re-shifted by each layer according to its depth.
        position_ids: Position ids matching ``input_ids``.
        hidden_states: Concatenation along dim 0 of previously computed
            per-depth hidden states plus the current trunk output (last chunk).
        attention_mask: Attention mask forwarded to every MTP layer.
        context / context_mask / attention_bias: Accepted for interface
            compatibility; not forwarded (cross-attention is unsupported and
            rejected inside the layer forward).
        rotary_pos_emb / rotary_pos_cos / rotary_pos_sin: RoPE tensors
            forwarded unchanged to each layer.
        inference_params, packed_seq_params, sequence_len_offset, embedding:
            Forwarded unchanged to each layer.
        extra_block_kwargs: Optional extra keyword arguments for each layer.

    Returns:
        torch.Tensor: Hidden states of all depths concatenated along dim 0.
    """
    offset = get_mtp_layer_offset(self.config, self.vp_stage)
    # `hidden_states` packs `offset` previously computed depth outputs plus
    # the trunk output; the trunk output (last chunk) seeds the unroll loop.
    hidden_states_list = list(torch.chunk(hidden_states, 1 + offset, dim=0))
    hidden_states = hidden_states_list[offset]

    physical_num_layers = len(self.layers)
    # Use an explicit `is None` test rather than `getattr(...) or ...`:
    # truthiness would silently replace an explicit `mtp_unroll_steps = 0`
    # with `mtp_num_layers`.
    unroll_steps = getattr(self.config, 'mtp_unroll_steps', None)
    if unroll_steps is None:
        unroll_steps = self.config.mtp_num_layers

    for step in range(unroll_steps):
        # Reuse physical layers cyclically when unroll_steps exceeds the
        # number of materialized layers (weight sharing across depths).
        layer = self.layers[step % physical_num_layers]
        # Global depth (1-based, past `offset` depths of earlier pipeline
        # stages) drives the per-depth input shifting inside the layer.
        global_depth = offset + step + 1
        hidden_states, input_ids, position_ids = layer(
            input_ids=input_ids,
            position_ids=position_ids,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            embedding=embedding,
            depth_idx=global_depth,
            **(extra_block_kwargs or {}),
        )
        hidden_states_list.append(hidden_states)

    hidden_states = torch.cat(hidden_states_list, dim=0)
    return hidden_states
|
Comment on lines
+498
to
+526
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
If this feature is primarily intended for the |
||
|
|
||
| MultiTokenPredictionBlock.forward = block_forward | ||
|
|
||
| def _get_embeddings( | ||
| self, | ||
| input_ids: torch.Tensor, | ||
| position_ids: torch.Tensor, | ||
| embedding: Callable, | ||
| hidden_states: torch.Tensor, | ||
| packed_seq_params: Optional[PackedSeqParams] = None, | ||
| depth: int = 1, | ||
| ): | ||
| from megatron.core.transformer.multi_token_prediction import roll_tensor | ||
| from megatron.core.utils import make_viewless_tensor | ||
|
|
@@ -508,13 +565,17 @@ def _get_embeddings( | |
| enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1 | ||
| if enable_sp: | ||
| decoder_input = gather_from_sequence_parallel_region(decoder_input) | ||
| decoder_input, _ = roll_tensor( | ||
| decoder_input.transpose(0, 2), | ||
| shifts=-1, | ||
| dims=-1, | ||
| cp_group=self.cp_group, | ||
| packed_seq_params=packed_seq_params, | ||
| ) | ||
| decoder_input = decoder_input.transpose(0, 2) | ||
| # Megatron's roll_tensor is implemented around single-token left shifts, especially | ||
| # for packed sequences / CP, so apply depth as repeated -1 rolls instead of -depth. | ||
| for _ in range(depth): | ||
| decoder_input, _ = roll_tensor( | ||
| decoder_input, | ||
| shifts=-1, | ||
| dims=-1, | ||
| cp_group=self.cp_group, | ||
| packed_seq_params=packed_seq_params, | ||
| ) | ||
| decoder_input = decoder_input.transpose(0, 2).contiguous() | ||
| if enable_sp: | ||
| decoder_input = scatter_to_sequence_parallel_region(decoder_input) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `torch.roll` operation will fail if `rotary_pos_emb` is `None`. This occurs in models that do not use RoPE or MRoPE (e.g., models using absolute position embeddings).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch. This `rotary_pos_emb is None` assumption is pre-existing in the base branch; this PR only changes the shift depth from `self.layer_number` to `effective_depth` for the unrolled MTP case, and does not introduce a new dereference path here. If needed, I can address the `None` guard separately in a follow-up cleanup.