Merged
Changes from 4 commits
3 changes: 3 additions & 0 deletions src/mcore_bridge/bridge/gpt_bridge.py
@@ -53,6 +53,7 @@ def __init__(self, config: ModelConfig):
self._only_master_rank = False
self._peft_target_modules = set()
self._peft_modules_to_save = set()
self._fp8_skip_modules = set()
self._peft_format = False
self._adapter_name = 'default'
self.model_type = config.hf_model_type
@@ -1747,6 +1748,7 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd
def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict):
for key in ['enorm.weight', 'hnorm.weight', 'eh_proj.weight']:
self._set_state_dict(mtp_layer, key, hf_state_dict, key, to_mcore)
self._fp8_skip_modules.update({'eh_proj'})
self._set_state_dict(mtp_layer, 'final_layernorm.weight', hf_state_dict, 'shared_head.norm.weight', to_mcore)

def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
@@ -1862,6 +1864,7 @@ def export_weights(
self._disable_tqdm = disable_tqdm
self._peft_target_modules = set()
self._peft_modules_to_save = set()
self._fp8_skip_modules = set()
hf_prefix = 'base_model.model.' if peft_format else ''
mg_models = unwrap_model(mg_models)
for i, mg_model in enumerate(mg_models):
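The hunks above introduce and populate `self._fp8_skip_modules`: the bridge records module names (such as `eh_proj`, and `mtp.fc` for Qwen3-Next below) whose weights should stay out of FP8 when `fp8_param` is enabled, and resets the set at the start of `export_weights`. The diff only shows where the names are collected, not how they are consumed, so the snippet below is a hypothetical, self-contained sketch of how an exporter could use such a skip set; the function and dictionary are illustrative and not part of mcore_bridge.

```python
# Hypothetical sketch: partition exported weights using a skip set like
# self._fp8_skip_modules. Not the bridge's actual export logic.
import torch
from typing import Dict, Set, Tuple


def split_fp8_exports(state_dict: Dict[str, torch.Tensor],
                      fp8_skip_modules: Set[str]
                      ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    """Separate weights eligible for FP8 export from those kept in high precision."""
    fp8_eligible, keep_high_precision = {}, {}
    for name, tensor in state_dict.items():
        if any(skip in name for skip in fp8_skip_modules):
            keep_high_precision[name] = tensor   # e.g. 'eh_proj', 'mtp.fc'
        else:
            fp8_eligible[name] = tensor
    return fp8_eligible, keep_high_precision


# Usage with the module names collected in this PR:
weights = {'mtp.fc.weight': torch.zeros(4, 8), 'mtp.norm.weight': torch.ones(4)}
fp8, skipped = split_fp8_exports(weights, {'eh_proj', 'mtp.fc'})
assert 'mtp.fc.weight' in skipped and 'mtp.norm.weight' in fp8
```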
1 change: 1 addition & 0 deletions src/mcore_bridge/model/gpts/qwen3_next.py
@@ -532,6 +532,7 @@ def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state
for mg_key, key in zip(['enorm.weight', 'hnorm.weight', 'eh_proj.weight'],
['pre_fc_norm_embedding.weight', 'pre_fc_norm_hidden.weight', 'fc.weight']):
self._set_state_dict(mtp_layer, mg_key, hf_state_dict, key, to_mcore)
self._fp8_skip_modules.update({'mtp.fc'})
self._set_state_dict(mtp_layer, 'final_layernorm.weight', hf_state_dict, 'norm.weight', to_mcore)
if not to_mcore:
origin_hf_state_dict.update(self._add_prefix(hf_state_dict, 'mtp.'))
1 change: 1 addition & 0 deletions src/mcore_bridge/model/modules/__init__.py
@@ -1,3 +1,4 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from .gated_delta_net import GatedDeltaNet
from .gated_self_attention import GatedSelfAttention
from .mtp_layer import MultiTokenPredictionLayer
218 changes: 218 additions & 0 deletions src/mcore_bridge/model/modules/mtp_layer.py
@@ -0,0 +1,218 @@
import torch
import transformer_engine
from contextlib import nullcontext
from functools import partial
from megatron.core import InferenceParams
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region,
gather_from_tensor_model_parallel_region,
scatter_to_sequence_parallel_region)
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer as _MultiTokenPredictionLayer
from megatron.core.transformer.spec_utils import build_module
from megatron.core.utils import make_viewless_tensor
from typing import Callable, Optional

from mcore_bridge.config import ModelConfig


class MultiTokenPredictionLayer(_MultiTokenPredictionLayer):

def __init__(self, config: ModelConfig, submodules, *args, **kwargs):
if config.fp8_param:
eh_proj = submodules.eh_proj
submodules.eh_proj = IdentityOp
super().__init__(config, submodules, *args, **kwargs)
if not config.fp8_param:
return
del self.eh_proj
submodules.eh_proj = eh_proj
fp8_context = transformer_engine.pytorch.fp8_model_init(enabled=False)
with fp8_context:
self.eh_proj = build_module(
self.submodules.eh_proj,
self.config.hidden_size * 2,
self.config.hidden_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=False,
skip_bias_add=False,
is_expert=False,
tp_comm_buffer_name='mtp_eh_proj',
tp_group=self.tp_group,
)

def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
context: torch.Tensor = None,
context_mask: torch.Tensor = None,
rotary_pos_emb: torch.Tensor = None,
rotary_pos_cos: torch.Tensor = None,
rotary_pos_sin: torch.Tensor = None,
attention_bias: torch.Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
sequence_len_offset: torch.Tensor = None,
embedding=None,
):
"""
Execute the forward pass through the Multi-Token Prediction (MTP) layer.

Args:
input_ids (Tensor): Input token IDs.
position_ids (Tensor): Positional IDs of the input tokens.
hidden_states (Tensor): Hidden states tensor of shape [s, b, h] where s is the
sequence length, b is the batch size, and h is the hidden size.
attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking
self-attention.
context (Tensor, optional): Context tensor for cross-attention, if applicable.
context_mask (Tensor, optional): Mask for cross-attention context, if applicable.
rotary_pos_emb (Tensor, optional): Rotary positional embeddings.
rotary_pos_cos (Tensor, optional): Cosine component of rotary positional embeddings.
rotary_pos_sin (Tensor, optional): Sine component of rotary positional embeddings.
sequence_len_offset (Tensor, optional): Offset for sequence length, if applicable.
embedding (Callable): The embedding module from the GPT model, used to compute the decoder input.

Returns:
Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape
[s, b, h], and optionally the updated context tensor if cross-attention is used.
"""
assert context is None, 'multi token prediction + cross attention is not yet supported.'
input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings(
input_ids=input_ids,
position_ids=position_ids,
embedding=embedding,
packed_seq_params=packed_seq_params,
hidden_states=hidden_states,
)
assert not self.transformer_layer.self_attention.config.apply_rope_fusion
packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
if self.config.position_embedding_type == 'rope' and packed_seq:
assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
rotary_pos_emb = rotary_pos_emb[position_ids[0]]
else:
# mrope or not packed_seq
rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-self.layer_number, dims=0)
if self.config.recompute_granularity == 'full' and self.training:
hidden_states = self._checkpointed_forward(
partial(
self._proj_and_transformer_layer,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
),
hidden_states=hidden_states,
decoder_input=decoder_input,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
inference_params=inference_params,
)
else:
hidden_states = self._proj_and_transformer_layer(
hidden_states=hidden_states,
decoder_input=decoder_input,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
)
return hidden_states, input_ids, position_ids

def _concat_embeddings(self, hidden_states: torch.Tensor, decoder_input: torch.Tensor):
"""
Concatenate the tokens before sending to transformer layer.
"""
try:
from megatron.core.typed_torch import apply_module
decoder_input = apply_module(self.enorm)(decoder_input)
hidden_states = apply_module(self.hnorm)(hidden_states)
except ImportError:
decoder_input = self.enorm(decoder_input)
hidden_states = self.hnorm(hidden_states)
decoder_input = make_viewless_tensor(inp=decoder_input, requires_grad=True, keep_graph=True)
hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
# At the (k - 1)-th MTP module, concatenate the i-th token's hidden_states
# and the (i + K)-th token's embedding, then combine them with a linear projection.
hidden_states = torch.cat((decoder_input, hidden_states), -1)
if self.config.fp8_param:
fp8_context = transformer_engine.pytorch.fp8_autocast(enabled=False)
else:
fp8_context = nullcontext()
with fp8_context:
hidden_states, _ = self.eh_proj(hidden_states)
# For tensor parallel we need to gather the tensor across the model-parallel
# ranks after the linear projection. This used to call
# `all_gather_last_dim_from_tensor_parallel_region`, but that utility reduces
# the gradient in the backward pass and was therefore incorrect in this context.
# It has been replaced with the correct `gather_from_tensor_model_parallel_region`.
hidden_states = gather_from_tensor_model_parallel_region(hidden_states, group=self.tp_group)
# For sequence parallel, scatter after linear_fc and before transformer layer.
if self.sequence_parallel:
hidden_states = scatter_to_sequence_parallel_region(hidden_states, group=self.tp_group)
return hidden_states

def _get_embeddings(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
embedding: Callable,
hidden_states: torch.Tensor,
packed_seq_params: Optional[PackedSeqParams] = None,
):
from megatron.core.transformer.multi_token_prediction import roll_tensor
from megatron.core.utils import make_viewless_tensor

# Calc logits for the current Multi-Token Prediction (MTP) layers.
input_ids, _ = roll_tensor(
input_ids,
shifts=-1,
dims=-1,
cp_group=self.cp_group,
packed_seq_params=packed_seq_params,
)
position_ids, _ = roll_tensor(
position_ids,
shifts=-1,
dims=-1,
cp_group=self.cp_group,
packed_seq_params=packed_seq_params,
)
# embedding
if isinstance(embedding, tuple):
embedding, decoder_input = embedding
else:
decoder_input = None
if decoder_input is None:
decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
else:
enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1
if enable_sp:
decoder_input = gather_from_sequence_parallel_region(decoder_input)
decoder_input, _ = roll_tensor(
decoder_input.transpose(0, 2),
shifts=-1,
dims=-1,
cp_group=self.cp_group,
packed_seq_params=packed_seq_params,
)
decoder_input = decoder_input.transpose(0, 2).contiguous()
if enable_sp:
decoder_input = scatter_to_sequence_parallel_region(decoder_input)
hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)

return input_ids, position_ids, decoder_input, hidden_states
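The key idea in this module is that `eh_proj` is kept out of FP8 when `fp8_param` is on: its parameters are created under `fp8_model_init(enabled=False)` in `__init__`, and its matmul in `_concat_embeddings` runs under `fp8_autocast(enabled=False)`. Below is a self-contained sketch of that pattern using a plain `transformer_engine.pytorch.Linear`; the sizes are hypothetical and the real layer builds `eh_proj` through `build_module` with the spec's submodule, TP group, and buffer name, so this is an illustration of the bypass, not the actual construction path.

```python
# Minimal sketch of the FP8-bypass pattern used by MultiTokenPredictionLayer,
# assuming a CUDA device and Transformer Engine's PyTorch API. Illustrative only.
import torch
import transformer_engine.pytorch as te

hidden = 128  # hypothetical hidden size

# 1) Create the projection's parameters outside FP8, even if the rest of the
#    model is built under te.fp8_model_init(enabled=True).
with te.fp8_model_init(enabled=False):
    eh_proj = te.Linear(hidden * 2, hidden, bias=False, params_dtype=torch.bfloat16)

# 2) At run time, disable FP8 execution for just this matmul.
x = torch.randn(8, hidden * 2, dtype=torch.bfloat16, device='cuda')
with te.fp8_autocast(enabled=False):
    y = eh_proj(x)
print(y.shape)  # torch.Size([8, 128])
```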
7 changes: 6 additions & 1 deletion src/mcore_bridge/model/register.py
@@ -15,6 +15,8 @@
from mcore_bridge.config import ModelConfig
from mcore_bridge.utils import get_logger

from .modules import MultiTokenPredictionLayer

if TYPE_CHECKING:
from .gpt_model import GPTModel
from .mm_gpt_model import MultimodalGPTModel
@@ -121,8 +123,11 @@ def get_mtp_block_spec(self, transformer_layer_spec, vp_stage: Optional[int] = N
else:
transformer_layer_spec_for_mtp = transformer_layer_spec
kwargs = {'vp_stage': vp_stage} if self.mcore_013 else {}
return get_gpt_mtp_block_spec(
mtp_block_spec = get_gpt_mtp_block_spec(
self.config, transformer_layer_spec_for_mtp, use_transformer_engine=True, **kwargs)
for layer_spec in mtp_block_spec.layer_specs:
layer_spec.module = MultiTokenPredictionLayer
return mtp_block_spec

def _set_shared_expert_gate(self, transformer_layer_spec):
mcore_016 = version.parse(megatron.core.__version__) >= version.parse('0.16.0rc0')
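The register.py change keeps the upstream `get_gpt_mtp_block_spec`, then points each layer spec's `module` at this repo's `MultiTokenPredictionLayer` so Megatron builds the patched class. A rough sketch of the mechanism follows, assuming the usual `ModuleSpec`/`build_module` behavior where `build_module` instantiates whatever class the spec references; the two classes below are stand-ins, not Megatron's real MTP layer.

```python
# Rough sketch of why overriding layer_spec.module is enough: build_module
# instantiates the class the spec points at. Stand-in classes for illustration.
from megatron.core.transformer.spec_utils import ModuleSpec, build_module


class UpstreamLayer:
    def __init__(self):
        self.name = 'upstream'


class PatchedLayer(UpstreamLayer):  # analogous to mcore_bridge's MultiTokenPredictionLayer
    def __init__(self):
        self.name = 'patched'


spec = ModuleSpec(module=UpstreamLayer)
spec.module = PatchedLayer          # same override as in get_mtp_block_spec
layer = build_module(spec)
print(type(layer).__name__)         # PatchedLayer
```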