-
Notifications
You must be signed in to change notification settings - Fork 13
[bugfix] Fix mtp fp8 #35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 4 commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
9055478
fix mtp fp8
Jintao-Huang 8193452
update
Jintao-Huang 4e0cafa
update
Jintao-Huang 2d75dc5
update
Jintao-Huang c355297
fix
Jintao-Huang d48a0d2
update
Jintao-Huang fddc291
Merge remote-tracking branch 'refs/remotes/origin/fix_mtp_fp8' into f…
Jintao-Huang ffe4dc7
fix
Jintao-Huang 1f7fdc2
update
Jintao-Huang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| # Copyright (c) ModelScope Contributors. All rights reserved. | ||
| from .gated_delta_net import GatedDeltaNet | ||
| from .gated_self_attention import GatedSelfAttention | ||
| from .mtp_layer import MultiTokenPredictionLayer |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,218 @@ | ||
| import torch | ||
| import transformer_engine | ||
| from contextlib import nullcontext | ||
| from functools import partial | ||
| from megatron.core import InferenceParams | ||
| from megatron.core.packed_seq_params import PackedSeqParams | ||
| from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region, | ||
| gather_from_tensor_model_parallel_region, | ||
| scatter_to_sequence_parallel_region) | ||
| from megatron.core.transformer.identity_op import IdentityOp | ||
| from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer as _MultiTokenPredictionLayer | ||
| from megatron.core.transformer.spec_utils import build_module | ||
| from megatron.core.utils import make_viewless_tensor | ||
| from typing import Callable, Optional | ||
|
|
||
| from mcore_bridge.config import ModelConfig | ||
|
|
||
|
|
||
| class MultiTokenPredictionLayer(_MultiTokenPredictionLayer): | ||
|
|
||
| def __init__(self, config: ModelConfig, submodules, *args, **kwargs): | ||
| if config.fp8_param: | ||
| eh_proj = submodules.eh_proj | ||
| submodules.eh_proj = IdentityOp | ||
| super().__init__(config, submodules, *args, **kwargs) | ||
| if not config.fp8_param: | ||
| return | ||
| del self.eh_proj | ||
| submodules.eh_proj = eh_proj | ||
| fp8_context = transformer_engine.pytorch.fp8_model_init(enabled=False) | ||
| with fp8_context: | ||
| self.eh_proj = build_module( | ||
| self.submodules.eh_proj, | ||
| self.config.hidden_size * 2, | ||
| self.config.hidden_size, | ||
| config=self.config, | ||
| init_method=self.config.init_method, | ||
| gather_output=False, | ||
| bias=False, | ||
| skip_bias_add=False, | ||
| is_expert=False, | ||
| tp_comm_buffer_name='mtp_eh_proj', | ||
| tp_group=self.tp_group, | ||
|
Jintao-Huang marked this conversation as resolved.
|
||
| ) | ||
|
|
||
| def forward( | ||
| self, | ||
| input_ids: torch.Tensor, | ||
| position_ids: torch.Tensor, | ||
| hidden_states: torch.Tensor, | ||
| attention_mask: torch.Tensor, | ||
| context: torch.Tensor = None, | ||
| context_mask: torch.Tensor = None, | ||
| rotary_pos_emb: torch.Tensor = None, | ||
| rotary_pos_cos: torch.Tensor = None, | ||
| rotary_pos_sin: torch.Tensor = None, | ||
| attention_bias: torch.Tensor = None, | ||
| inference_params: InferenceParams = None, | ||
| packed_seq_params: PackedSeqParams = None, | ||
| sequence_len_offset: torch.Tensor = None, | ||
| embedding=None, | ||
| ): | ||
| """ | ||
| Execute the forward pass through the Multi-Token Prediction (MTP) layer. | ||
|
|
||
| Args: | ||
| input_ids (Tensor): Input token IDs . | ||
| position_ids (Tensor): Positional IDs of the input tokens. | ||
| hidden_states (Tensor): Hidden states tensor of shape [s, b, h] where s is the | ||
| sequence length, b is the batch size, and h is the hidden size. | ||
| attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking | ||
| self-attention. | ||
| context (Tensor, optional): Context tensor for cross-attention, if applicable. | ||
| context_mask (Tensor, optional): Mask for cross-attention context, if applicable. | ||
| rotary_pos_emb (Tensor, optional): Rotary positional embeddings. | ||
| rotary_pos_cos (Tensor, optional): Cosine component of rotary positional embeddings. | ||
| rotary_pos_sin (Tensor, optional): Sine component of rotary positional embeddings. | ||
| sequence_len_offset (Tensor, optional): Offset for sequence length, if applicable. | ||
| embedding (Callable): The embedding module from gpt model to compute the decoder input. | ||
|
|
||
| Returns: | ||
| Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape | ||
| [s, b, h], and optionally the updated context tensor if cross-attention is used. | ||
| """ | ||
|
Jintao-Huang marked this conversation as resolved.
Outdated
|
||
| assert context is None, 'multi token prediction + cross attention is not yet supported.' | ||
| input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings( | ||
| input_ids=input_ids, | ||
| position_ids=position_ids, | ||
| embedding=embedding, | ||
| packed_seq_params=packed_seq_params, | ||
| hidden_states=hidden_states, | ||
| ) | ||
| assert not self.transformer_layer.self_attention.config.apply_rope_fusion | ||
| packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' | ||
| if self.config.position_embedding_type == 'rope' and packed_seq: | ||
| assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}' | ||
| rotary_pos_emb = rotary_pos_emb[position_ids[0]] | ||
| else: | ||
| # mrope or not packed_seq | ||
| rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-self.layer_number, dims=0) | ||
| if self.config.recompute_granularity == 'full' and self.training: | ||
| hidden_states = self._checkpointed_forward( | ||
| partial( | ||
| self._proj_and_transformer_layer, | ||
| packed_seq_params=packed_seq_params, | ||
| sequence_len_offset=sequence_len_offset, | ||
| ), | ||
| hidden_states=hidden_states, | ||
| decoder_input=decoder_input, | ||
| attention_mask=attention_mask, | ||
| context=context, | ||
| context_mask=context_mask, | ||
| rotary_pos_emb=rotary_pos_emb, | ||
| rotary_pos_cos=rotary_pos_cos, | ||
| rotary_pos_sin=rotary_pos_sin, | ||
| attention_bias=attention_bias, | ||
| inference_params=inference_params, | ||
| ) | ||
| else: | ||
| hidden_states = self._proj_and_transformer_layer( | ||
| hidden_states=hidden_states, | ||
| decoder_input=decoder_input, | ||
| attention_mask=attention_mask, | ||
| context=context, | ||
| context_mask=context_mask, | ||
| rotary_pos_emb=rotary_pos_emb, | ||
| rotary_pos_cos=rotary_pos_cos, | ||
| rotary_pos_sin=rotary_pos_sin, | ||
| attention_bias=attention_bias, | ||
| inference_params=inference_params, | ||
| packed_seq_params=packed_seq_params, | ||
| sequence_len_offset=sequence_len_offset, | ||
| ) | ||
| return hidden_states, input_ids, position_ids | ||
|
|
||
| def _concat_embeddings(self, hidden_states: torch.Tensor, decoder_input: torch.Tensor): | ||
| """ | ||
| Concatenate the tokens before sending to transformer layer. | ||
| """ | ||
| try: | ||
| from megatron.core.typed_torch import apply_module | ||
|
Jintao-Huang marked this conversation as resolved.
|
||
| decoder_input = apply_module(self.enorm)(decoder_input) | ||
| hidden_states = apply_module(self.hnorm)(hidden_states) | ||
| except ImportError: | ||
| decoder_input = self.enorm(decoder_input) | ||
| hidden_states = self.hnorm(hidden_states) | ||
| decoder_input = make_viewless_tensor(inp=decoder_input, requires_grad=True, keep_graph=True) | ||
| hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) | ||
| # At the (k - 1)-th MTP module, concatenates the i-th token's hidden_states | ||
| # and the (i + K)-th token's embedding, and combine them with linear projection. | ||
| hidden_states = torch.cat((decoder_input, hidden_states), -1) | ||
| if self.config.fp8_param: | ||
| fp8_context = transformer_engine.pytorch.fp8_autocast(enabled=False) | ||
| else: | ||
| fp8_context = nullcontext() | ||
| with fp8_context: | ||
| hidden_states, _ = self.eh_proj(hidden_states) | ||
| # For tensor parallel we need to gather the tensor across the model-parallel | ||
| # ranks after the linear projection. This used to call | ||
| # `all_gather_last_dim_from_tensor_parallel_region`, but that utility reduces | ||
| # the gradient in backward pass and was therefore incorrect in this context. | ||
| # It has been replaced with the correct `gather_from_tensor_model_parallel_region`. | ||
| hidden_states = gather_from_tensor_model_parallel_region(hidden_states, group=self.tp_group) | ||
| # For sequence parallel, scatter after linear_fc and before transformer layer. | ||
| if self.sequence_parallel: | ||
| hidden_states = scatter_to_sequence_parallel_region(hidden_states, group=self.tp_group) | ||
|
Jintao-Huang marked this conversation as resolved.
|
||
| return hidden_states | ||
|
|
||
| def _get_embeddings( | ||
| self, | ||
| input_ids: torch.Tensor, | ||
| position_ids: torch.Tensor, | ||
| embedding: Callable, | ||
| hidden_states: torch.Tensor, | ||
| packed_seq_params: Optional[PackedSeqParams] = None, | ||
| ): | ||
| from megatron.core.transformer.multi_token_prediction import roll_tensor | ||
| from megatron.core.utils import make_viewless_tensor | ||
|
Jintao-Huang marked this conversation as resolved.
Outdated
|
||
|
|
||
| # Calc logits for the current Multi-Token Prediction (MTP) layers. | ||
| input_ids, _ = roll_tensor( | ||
| input_ids, | ||
| shifts=-1, | ||
| dims=-1, | ||
| cp_group=self.cp_group, | ||
| packed_seq_params=packed_seq_params, | ||
| ) | ||
| position_ids, _ = roll_tensor( | ||
| position_ids, | ||
| shifts=-1, | ||
| dims=-1, | ||
| cp_group=self.cp_group, | ||
| packed_seq_params=packed_seq_params, | ||
| ) | ||
| # embedding | ||
| if isinstance(embedding, tuple): | ||
| embedding, decoder_input = embedding | ||
| else: | ||
| decoder_input = None | ||
| if decoder_input is None: | ||
| decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) | ||
| else: | ||
| enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1 | ||
| if enable_sp: | ||
| decoder_input = gather_from_sequence_parallel_region(decoder_input) | ||
| decoder_input, _ = roll_tensor( | ||
| decoder_input.transpose(0, 2), | ||
| shifts=-1, | ||
| dims=-1, | ||
| cp_group=self.cp_group, | ||
| packed_seq_params=packed_seq_params, | ||
| ) | ||
| decoder_input = decoder_input.transpose(0, 2).contiguous() | ||
| if enable_sp: | ||
| decoder_input = scatter_to_sequence_parallel_region(decoder_input) | ||
| hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) | ||
|
|
||
| return input_ids, position_ids, decoder_input, hidden_states | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.