diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py
index 786b147..05df298 100644
--- a/src/mcore_bridge/bridge/gpt_bridge.py
+++ b/src/mcore_bridge/bridge/gpt_bridge.py
@@ -53,6 +53,7 @@ def __init__(self, config: ModelConfig):
         self._only_master_rank = False
         self._peft_target_modules = set()
         self._peft_modules_to_save = set()
+        self._fp8_skip_modules = set()
         self._peft_format = False
         self._adapter_name = 'default'
         self.model_type = config.hf_model_type
@@ -1747,6 +1748,7 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd
     def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict):
         for key in ['enorm.weight', 'hnorm.weight', 'eh_proj.weight']:
             self._set_state_dict(mtp_layer, key, hf_state_dict, key, to_mcore)
+        self._fp8_skip_modules.update({'eh_proj'})
         self._set_state_dict(mtp_layer, 'final_layernorm.weight', hf_state_dict, 'shared_head.norm.weight', to_mcore)
 
     def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
@@ -1862,6 +1864,7 @@ def export_weights(
         self._disable_tqdm = disable_tqdm
         self._peft_target_modules = set()
         self._peft_modules_to_save = set()
+        self._fp8_skip_modules = set()
         hf_prefix = 'base_model.model.' if peft_format else ''
         mg_models = unwrap_model(mg_models)
         for i, mg_model in enumerate(mg_models):
diff --git a/src/mcore_bridge/model/gpts/qwen3_next.py b/src/mcore_bridge/model/gpts/qwen3_next.py
index bc8aee8..9ec681d 100644
--- a/src/mcore_bridge/model/gpts/qwen3_next.py
+++ b/src/mcore_bridge/model/gpts/qwen3_next.py
@@ -532,6 +532,7 @@ def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state
         for mg_key, key in zip(['enorm.weight', 'hnorm.weight', 'eh_proj.weight'],
                                ['pre_fc_norm_embedding.weight', 'pre_fc_norm_hidden.weight', 'fc.weight']):
             self._set_state_dict(mtp_layer, mg_key, hf_state_dict, key, to_mcore)
+        self._fp8_skip_modules.update({'mtp.fc'})
         self._set_state_dict(mtp_layer, 'final_layernorm.weight', hf_state_dict, 'norm.weight', to_mcore)
         if not to_mcore:
             origin_hf_state_dict.update(self._add_prefix(hf_state_dict, 'mtp.'))
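
Note: the two `_fp8_skip_modules` updates above record the MTP projection weights (`eh_proj` in the Megatron naming, `mtp.fc` for Qwen3-Next) that are deliberately kept out of FP8, so the export path knows there is nothing to dequantize for them. A standalone sketch of how an exporter might consult such a set -- the helper below is an illustrative assumption, not the bridge's actual `_set_state_dict` logic:

    # Sketch only: hypothetical use of a skip set like `_fp8_skip_modules`.
    import torch

    fp8_skip_modules = {'eh_proj', 'mtp.fc'}

    def export_tensor(name: str, tensor: torch.Tensor) -> torch.Tensor:
        if any(skip in name for skip in fp8_skip_modules):
            # Built under fp8_model_init(enabled=False): already bf16/fp32.
            return tensor
        # Stand-in for the Transformer Engine-specific FP8 dequantization
        # the real bridge would perform here.
        return tensor.float() if tensor.dtype not in (torch.float32, torch.bfloat16) else tensor
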
diff --git a/src/mcore_bridge/model/modules/__init__.py b/src/mcore_bridge/model/modules/__init__.py
index 0cae72b..6fd1ac7 100644
--- a/src/mcore_bridge/model/modules/__init__.py
+++ b/src/mcore_bridge/model/modules/__init__.py
@@ -1,3 +1,4 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
 from .gated_delta_net import GatedDeltaNet
 from .gated_self_attention import GatedSelfAttention
+from .mtp_layer import MultiTokenPredictionLayer
diff --git a/src/mcore_bridge/model/modules/mtp_layer.py b/src/mcore_bridge/model/modules/mtp_layer.py
new file mode 100644
index 0000000..26a68b8
--- /dev/null
+++ b/src/mcore_bridge/model/modules/mtp_layer.py
@@ -0,0 +1,195 @@
+import torch
+import transformer_engine
+from contextlib import nullcontext
+from functools import partial
+from megatron.core import InferenceParams
+from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region,
+                                                    gather_from_tensor_model_parallel_region,
+                                                    scatter_to_sequence_parallel_region)
+from megatron.core.transformer.identity_op import IdentityOp
+from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer as _MultiTokenPredictionLayer
+from megatron.core.transformer.spec_utils import build_module
+from megatron.core.utils import make_viewless_tensor
+from typing import Callable, Optional
+
+from mcore_bridge.config import ModelConfig
+
+
+class MultiTokenPredictionLayer(_MultiTokenPredictionLayer):
+
+    def __init__(self, config: ModelConfig, submodules, *args, **kwargs):
+        if config.fp8_param:
+            # Temporarily hide eh_proj so the parent builds it as an IdentityOp;
+            # it is rebuilt below with FP8 weight allocation disabled.
+            eh_proj = submodules.eh_proj
+            submodules.eh_proj = IdentityOp
+        super().__init__(config, submodules, *args, **kwargs)
+        self.tp_group = getattr(self, 'tp_group', None)  # may be absent on older megatron-core
+        if not config.fp8_param:
+            return
+        submodules.eh_proj = eh_proj
+        fp8_context = transformer_engine.pytorch.fp8_model_init(enabled=False)
+        with fp8_context:
+            self.eh_proj = build_module(
+                self.submodules.eh_proj,
+                self.config.hidden_size * 2,
+                self.config.hidden_size,
+                config=self.config,
+                init_method=self.config.init_method,
+                gather_output=False,
+                bias=False,
+                skip_bias_add=False,
+                is_expert=False,
+                tp_comm_buffer_name='mtp_eh_proj',
+                tp_group=self.tp_group,
+            )
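
Note: with `fp8_param` enabled, a TE linear built here the normal way would store its weight natively in FP8; the constructor therefore hides `eh_proj` from the parent and rebuilds it under `fp8_model_init(enabled=False)`, keeping the weight in high precision so it can be loaded from and exported to the HF checkpoint directly (this is what the `_fp8_skip_modules` entries above rely on). A standalone illustration of the context manager, with arbitrary shapes:

    # transformer_engine.pytorch.fp8_model_init controls whether TE modules
    # created inside the context allocate FP8 weight storage.
    import transformer_engine.pytorch as te

    with te.fp8_model_init(enabled=True):
        fp8_proj = te.Linear(2048, 1024, bias=False)   # weight held in FP8
    with te.fp8_model_init(enabled=False):
        hp_proj = te.Linear(2048, 1024, bias=False)    # ordinary high-precision weight
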
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        context: torch.Tensor = None,
+        context_mask: torch.Tensor = None,
+        rotary_pos_emb: torch.Tensor = None,
+        rotary_pos_cos: torch.Tensor = None,
+        rotary_pos_sin: torch.Tensor = None,
+        attention_bias: torch.Tensor = None,
+        inference_params: InferenceParams = None,
+        packed_seq_params: PackedSeqParams = None,
+        sequence_len_offset: torch.Tensor = None,
+        embedding=None,
+    ):
+        """Forward pass of the MTP layer; returns (hidden_states, input_ids, position_ids)."""
+        assert context is None, 'multi token prediction + cross attention is not yet supported.'
+        input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            embedding=embedding,
+            packed_seq_params=packed_seq_params,
+            hidden_states=hidden_states,
+        )
+        assert not self.transformer_layer.self_attention.config.apply_rope_fusion
+        packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
+        if self.config.position_embedding_type == 'rope' and packed_seq:
+            assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
+            rotary_pos_emb = rotary_pos_emb[position_ids[0]]
+        else:
+            # mrope or not packed_seq
+            rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-self.layer_number, dims=0)
+        if self.config.recompute_granularity == 'full' and self.training:
+            hidden_states = self._checkpointed_forward(
+                partial(
+                    self._proj_and_transformer_layer,
+                    packed_seq_params=packed_seq_params,
+                    sequence_len_offset=sequence_len_offset,
+                ),
+                hidden_states=hidden_states,
+                decoder_input=decoder_input,
+                attention_mask=attention_mask,
+                context=context,
+                context_mask=context_mask,
+                rotary_pos_emb=rotary_pos_emb,
+                rotary_pos_cos=rotary_pos_cos,
+                rotary_pos_sin=rotary_pos_sin,
+                attention_bias=attention_bias,
+                inference_params=inference_params,
+            )
+        else:
+            hidden_states = self._proj_and_transformer_layer(
+                hidden_states=hidden_states,
+                decoder_input=decoder_input,
+                attention_mask=attention_mask,
+                context=context,
+                context_mask=context_mask,
+                rotary_pos_emb=rotary_pos_emb,
+                rotary_pos_cos=rotary_pos_cos,
+                rotary_pos_sin=rotary_pos_sin,
+                attention_bias=attention_bias,
+                inference_params=inference_params,
+                packed_seq_params=packed_seq_params,
+                sequence_len_offset=sequence_len_offset,
+            )
+        return hidden_states, input_ids, position_ids
+
+    def _concat_embeddings(self, hidden_states: torch.Tensor, decoder_input: torch.Tensor):
+        """
+        Concatenate the shifted embeddings with the hidden states before the transformer layer.
+        """
+        try:
+            # Prefer megatron's typed module wrapper when available.
+            from megatron.core.typed_torch import apply_module
+            decoder_input = apply_module(self.enorm)(decoder_input)
+            hidden_states = apply_module(self.hnorm)(hidden_states)
+        except ImportError:
+            decoder_input = self.enorm(decoder_input)
+            hidden_states = self.hnorm(hidden_states)
+        decoder_input = make_viewless_tensor(inp=decoder_input, requires_grad=True, keep_graph=True)
+        hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
+        # At the (k - 1)-th MTP module, concatenate the i-th token's hidden_states
+        # with the (i + k)-th token's embedding, and combine them with a linear projection.
+        hidden_states = torch.cat((decoder_input, hidden_states), -1)
+        if self.config.fp8_param:
+            fp8_context = transformer_engine.pytorch.fp8_autocast(enabled=False)
+        else:
+            fp8_context = nullcontext()
+        with fp8_context:
+            hidden_states, _ = self.eh_proj(hidden_states)
+        # For tensor parallel we need to gather the tensor across the model-parallel
+        # ranks after the linear projection. `all_gather_last_dim_from_tensor_parallel_region`
+        # is unsuitable here because it reduce-scatters the gradient in the backward pass;
+        # `gather_from_tensor_model_parallel_region` splits the gradient instead, which is
+        # what a column-parallel output requires.
+        hidden_states = gather_from_tensor_model_parallel_region(hidden_states, group=self.tp_group)
+        # For sequence parallel, scatter after eh_proj and before the transformer layer.
+        if self.sequence_parallel:
+            hidden_states = scatter_to_sequence_parallel_region(hidden_states, group=self.tp_group)
+        return hidden_states
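
Note: the gather comment above is the subtle point of this file. `eh_proj` is column-parallel (`gather_output=False`), so each TP rank owns a distinct slice of the projected tensor, and the backward of the gather must split the incoming gradient rather than reduce-scatter it. A single-process toy, with `torch.cat` standing in for the all-gather along the last dimension, shows the split-backward behavior that `gather_from_tensor_model_parallel_region` provides:

    # Toy emulation (no distributed setup): two "rank" shards concatenated along
    # the last dim stand in for the all-gather after a column-parallel linear.
    import torch

    s0 = torch.randn(4, 2, 8, requires_grad=True)  # rank-0 shard of eh_proj output
    s1 = torch.randn(4, 2, 8, requires_grad=True)  # rank-1 shard

    full = torch.cat([s0, s1], dim=-1)             # "all-gather" along the last dim
    full.pow(2).sum().backward()

    # The backward of cat() hands each shard exactly its own gradient slice --
    # the split-backward semantics the MTP projection needs. A reduce-scatter
    # backward would instead sum slices across ranks, corrupting the gradient.
    assert torch.allclose(s0.grad, 2 * s0)
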
+
+    def _get_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        embedding: Callable,
+        hidden_states: torch.Tensor,
+        packed_seq_params: Optional[PackedSeqParams] = None,
+    ):
+        from megatron.core.transformer.multi_token_prediction import roll_tensor
+
+        # Shift input_ids and position_ids left by one token for the current
+        # Multi-Token Prediction (MTP) layer.
+        input_ids, _ = roll_tensor(
+            input_ids,
+            shifts=-1,
+            dims=-1,
+            cp_group=self.cp_group,
+            packed_seq_params=packed_seq_params,
+        )
+        position_ids, _ = roll_tensor(
+            position_ids,
+            shifts=-1,
+            dims=-1,
+            cp_group=self.cp_group,
+            packed_seq_params=packed_seq_params,
+        )
+        # Compute the decoder input from the shifted ids, or roll a precomputed one.
+        if isinstance(embedding, tuple):
+            embedding, decoder_input = embedding
+        else:
+            decoder_input = None
+        if decoder_input is None:
+            decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
+        else:
+            enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1
+            if enable_sp:
+                decoder_input = gather_from_sequence_parallel_region(decoder_input)
+            decoder_input, _ = roll_tensor(
+                decoder_input.transpose(0, 2),
+                shifts=-1,
+                dims=-1,
+                cp_group=self.cp_group,
+                packed_seq_params=packed_seq_params,
+            )
+            decoder_input = decoder_input.transpose(0, 2).contiguous()
+            if enable_sp:
+                decoder_input = scatter_to_sequence_parallel_region(decoder_input)
+        hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
+
+        return input_ids, position_ids, decoder_input, hidden_states
diff --git a/src/mcore_bridge/model/register.py b/src/mcore_bridge/model/register.py
index e4d5d36..9610a6d 100644
--- a/src/mcore_bridge/model/register.py
+++ b/src/mcore_bridge/model/register.py
@@ -15,6 +15,8 @@
 from mcore_bridge.config import ModelConfig
 from mcore_bridge.utils import get_logger
 
+from .modules import MultiTokenPredictionLayer
+
 if TYPE_CHECKING:
     from .gpt_model import GPTModel
     from .mm_gpt_model import MultimodalGPTModel
@@ -121,8 +123,11 @@ def get_mtp_block_spec(self, transformer_layer_spec, vp_stage: Optional[int] = N
         else:
             transformer_layer_spec_for_mtp = transformer_layer_spec
         kwargs = {'vp_stage': vp_stage} if self.mcore_013 else {}
-        return get_gpt_mtp_block_spec(
+        mtp_block_spec = get_gpt_mtp_block_spec(
             self.config, transformer_layer_spec_for_mtp, use_transformer_engine=True, **kwargs)
+        for layer_spec in mtp_block_spec.layer_specs:
+            layer_spec.module = MultiTokenPredictionLayer
+        return mtp_block_spec
 
     def _set_shared_expert_gate(self, transformer_layer_spec):
         mcore_016 = version.parse(megatron.core.__version__) >= version.parse('0.16.0rc0')
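
Note: rather than monkey-patching, register.py now swaps the module class on each MTP layer spec in place. This works because Megatron layer specs pair a `module` class with its submodule specs, and the spec is only materialized later by `build_module`. A self-contained sketch of the pattern -- `BaseLayer`/`PatchedLayer` are placeholder names, and the dataclass is a simplified mirror of `megatron.core.transformer.spec_utils.ModuleSpec`:

    from dataclasses import dataclass

    @dataclass
    class ModuleSpec:               # simplified mirror of megatron's ModuleSpec
        module: type
        submodules: object = None

    class BaseLayer: ...
    class PatchedLayer(BaseLayer): ...

    spec = ModuleSpec(module=BaseLayer)
    spec.module = PatchedLayer      # the in-place swap done in get_mtp_block_spec
    layer = spec.module()           # build_module would now construct PatchedLayer
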
""" - Execute the forward pass through the Multi-Token Prediction (MTP) layer. - - Args: - input_ids (Tensor): Input token IDs . - position_ids (Tensor): Positional IDs of the input tokens. - hidden_states (Tensor): Hidden states tensor of shape [s, b, h] where s is the - sequence length, b is the batch size, and h is the hidden size. - attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking - self-attention. - context (Tensor, optional): Context tensor for cross-attention, if applicable. - context_mask (Tensor, optional): Mask for cross-attention context, if applicable. - rotary_pos_emb (Tensor, optional): Rotary positional embeddings. - rotary_pos_cos (Tensor, optional): Cosine component of rotary positional embeddings. - rotary_pos_sin (Tensor, optional): Sine component of rotary positional embeddings. - sequence_len_offset (Tensor, optional): Offset for sequence length, if applicable. - embedding (Callable): The embedding module from gpt model to compute the decoder input. - - Returns: - Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape - [s, b, h], and optionally the updated context tensor if cross-attention is used. - """ - assert context is None, 'multi token prediction + cross attention is not yet supported.' - input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings( - input_ids=input_ids, - position_ids=position_ids, - embedding=embedding, - packed_seq_params=packed_seq_params, - hidden_states=hidden_states, - ) - assert not self.transformer_layer.self_attention.config.apply_rope_fusion - packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' - if self.config.position_embedding_type == 'rope' and packed_seq: - assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}' - rotary_pos_emb = rotary_pos_emb[position_ids[0]] - else: - # mrope or not packed_seq - rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-self.layer_number, dims=0) - if self.config.recompute_granularity == 'full' and self.training: - hidden_states = self._checkpointed_forward( - partial( - self._proj_and_transformer_layer, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - ), - hidden_states=hidden_states, - decoder_input=decoder_input, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - attention_bias=attention_bias, - inference_params=inference_params, - ) - else: - hidden_states = self._proj_and_transformer_layer( - hidden_states=hidden_states, - decoder_input=decoder_input, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - attention_bias=attention_bias, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - ) - return hidden_states, input_ids, position_ids - - MultiTokenPredictionLayer.forward = forward - - def _get_embeddings( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - embedding: Callable, - hidden_states: torch.Tensor, - packed_seq_params: Optional[PackedSeqParams] = None, - ): - from megatron.core.transformer.multi_token_prediction import roll_tensor - from megatron.core.utils import make_viewless_tensor - - # Calc logits for the current Multi-Token Prediction (MTP) layers. 
-
-    def _get_embeddings(
-        self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        embedding: Callable,
-        hidden_states: torch.Tensor,
-        packed_seq_params: Optional[PackedSeqParams] = None,
-    ):
-        from megatron.core.transformer.multi_token_prediction import roll_tensor
-        from megatron.core.utils import make_viewless_tensor
-
-        # Calc logits for the current Multi-Token Prediction (MTP) layers.
-        input_ids, _ = roll_tensor(
-            input_ids,
-            shifts=-1,
-            dims=-1,
-            cp_group=self.cp_group,
-            packed_seq_params=packed_seq_params,
-        )
-        position_ids, _ = roll_tensor(
-            position_ids,
-            shifts=-1,
-            dims=-1,
-            cp_group=self.cp_group,
-            packed_seq_params=packed_seq_params,
-        )
-        # embedding
-        if isinstance(embedding, tuple):
-            embedding, decoder_input = embedding
-        else:
-            decoder_input = None
-        if decoder_input is None:
-            decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
-        else:
-            enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1
-            if enable_sp:
-                decoder_input = gather_from_sequence_parallel_region(decoder_input)
-            decoder_input, _ = roll_tensor(
-                decoder_input.transpose(0, 2),
-                shifts=-1,
-                dims=-1,
-                cp_group=self.cp_group,
-                packed_seq_params=packed_seq_params,
-            )
-            decoder_input = decoder_input.transpose(0, 2).contiguous()
-            if enable_sp:
-                decoder_input = scatter_to_sequence_parallel_region(decoder_input)
-        hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
-
-        return input_ids, position_ids, decoder_input, hidden_states
-
-    MultiTokenPredictionLayer._get_embeddings = _get_embeddings
-
-
 def _patch_peft_ModulesToSaveWrapper():
     if version.parse(peft.__version__) >= version.parse('0.16'):
         from peft.utils import other as peft_module
@@ -886,7 +738,6 @@ def apply_patch():
     # patch module
     _patch_mla_attention()
     _patch_TEGroupedLinear()
-    _patch_mtp()
     _patch_TransformerLayer()
     _patch_TELinear()
     _patch_mrope()
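
Note: with the subclass wired in through the block spec, the import-time monkey-patch removed above becomes dead code, so it is deleted along with its call in `apply_patch()`. The contrast, in brief (the first two lines paraphrase the removed patch; the loop is the replacement from register.py above):

    # Before: mutate megatron.core's class globally at import time.
    MultiTokenPredictionLayer.forward = forward
    MultiTokenPredictionLayer._get_embeddings = _get_embeddings

    # After: point the MTP layer specs at the bridge's own subclass.
    for layer_spec in mtp_block_spec.layer_specs:
        layer_spec.module = MultiTokenPredictionLayer
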