From 578e9fbf390cfa211a024b4480142b5b332420ab Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Mon, 20 Apr 2026 00:25:26 +0800
Subject: [PATCH 1/4] support mtp_shared_weights

---
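Notes:

mtp_shared_weights makes every MTP prediction step reuse the weights of a
single MTPLayer instead of building mtp_num_layers separate layers. In
ModelConfig.__post_init__ the configured depth is saved as mtp_unroll_steps
and mtp_num_layers is collapsed to 1; the forward paths then loop over
mtp_unroll_steps, indexing self.layers[layer_number % len(self.layers)] and
passing an explicit layer_number so that the per-step rotary shift
(torch.roll(rotary_pos_emb, shifts=-layer_number, dims=0)) keeps advancing
even though the module object is the same on every iteration.

A minimal sketch of the unroll pattern (a simplified stand-in module, not
the bridge's actual MTPLayer):

    import torch
    import torch.nn as nn

    class SharedMTPSketch(nn.Module):
        def __init__(self, hidden: int, unroll_steps: int):
            super().__init__()
            # one layer instance; its weights are reused on every unroll step
            self.layer = nn.Linear(hidden, hidden)
            self.unroll_steps = unroll_steps

        def forward(self, h: torch.Tensor) -> list:
            outs = []
            for step in range(1, self.unroll_steps + 1):
                # the explicit step index plays the role of layer_number:
                # the shift advances although the module never changes
                h = torch.roll(h, shifts=-step, dims=0)
                h = self.layer(h)
                outs.append(h)
            return outs

    # e.g. SharedMTPSketch(16, 3)(torch.randn(8, 2, 16)) yields three
    # tensors, all produced by the same nn.Linear weights.
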
 src/mcore_bridge/config/model_config.py     | 11 +++++++++--
 src/mcore_bridge/model/gpt_model.py         |  8 ++++----
 src/mcore_bridge/model/modules/mtp_layer.py |  5 ++++-
 src/mcore_bridge/patcher.py                 |  5 +++--
 4 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/src/mcore_bridge/config/model_config.py b/src/mcore_bridge/config/model_config.py
index ba55169..f952ea5 100644
--- a/src/mcore_bridge/config/model_config.py
+++ b/src/mcore_bridge/config/model_config.py
@@ -204,6 +204,10 @@ class ModelConfig(TransformerConfig):
     dsa_indexer_use_sparse_loss: bool = False
     dsa_indexer_rotary_interleaved: bool = False
 
+    # mtp
+    mtp_decoder_input_detach: bool = False
+    mtp_shared_weights: bool = False
+
     # visual
     hf_config: Optional[PretrainedConfig] = None
     vit_attn_impl: Optional[str] = None  # e.g. 'flash_attention_2'
@@ -225,7 +229,6 @@ class ModelConfig(TransformerConfig):
     task_type: Literal['causal_lm', 'seq_cls', 'embedding', 'generative_reranker'] = 'causal_lm'
     num_labels: Optional[int] = None
     mlp_padding_free: bool = False
-    mtp_decoder_input_detach: bool = False
 
     _mindspeed_defaults_cache = None
 
@@ -305,7 +308,11 @@ def __post_init__(self):
         self.apply_query_key_layer_scaling = self.fp16
         if self.apply_query_key_layer_scaling:
             os.environ['NVTE_APPLY_QK_LAYER_SCALING'] = '1'
-        # patch rotary_interleaved
+        if self.mtp_shared_weights:
+            self.mtp_unroll_steps = self.mtp_num_layers
+            self.mtp_num_layers = 1
+        else:
+            self.mtp_unroll_steps = self.mtp_num_layers
         super().__post_init__()
         self._check_npu()
 
diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py
index 2486725..5fb714c 100644
--- a/src/mcore_bridge/model/gpt_model.py
+++ b/src/mcore_bridge/model/gpt_model.py
@@ -420,12 +420,12 @@ def _postprocess(
             **(extra_block_kwargs or {}),
         )
         mtp_labels = labels.clone()
-        hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
+        hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_unroll_steps, dim=0)
         hidden_states = hidden_states_list[0]
         if loss_mask is None:
             # if loss_mask is not provided, use all ones as loss_mask
             loss_mask = torch.ones_like(mtp_labels)
-        for mtp_layer_number in range(self.config.mtp_num_layers):
+        for mtp_layer_number in range(self.config.mtp_unroll_steps):
             # output
             mtp_logits, _ = self.output_layer(
                 hidden_states_list[mtp_layer_number + 1],
@@ -457,10 +457,10 @@
                 MTPLossLoggingHelper.save_loss_to_tracker(
                     mtp_loss_for_log,
                     mtp_layer_number,
-                    self.config.mtp_num_layers,
+                    self.config.mtp_unroll_steps,
                     avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
                 )
-        mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
+        mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_unroll_steps
         if self.config.calculate_per_token_loss:
             hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
         else:
diff --git a/src/mcore_bridge/model/modules/mtp_layer.py b/src/mcore_bridge/model/modules/mtp_layer.py
index 0909610..c2d069b 100644
--- a/src/mcore_bridge/model/modules/mtp_layer.py
+++ b/src/mcore_bridge/model/modules/mtp_layer.py
@@ -65,8 +65,11 @@ def forward(
         sequence_len_offset: torch.Tensor = None,
         embedding=None,
         decoder_input=None,
+        layer_number: Optional[int] = None,
     ):
         assert context is None, 'multi token prediction + cross attention is not yet supported.'
+        if layer_number is None:
+            layer_number = self.layer_number
         input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings(
             input_ids=input_ids,
             position_ids=position_ids,
@@ -82,7 +85,7 @@
                rotary_pos_emb = rotary_pos_emb[position_ids[0]]
            else:
                # mrope or not packed_seq
-               rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-self.layer_number, dims=0)
+               rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-layer_number, dims=0)
        if self.config.recompute_granularity == 'full' and self.training:
            hidden_states = self._checkpointed_forward(
                partial(
diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py
index 7d0d115..2c039cd 100644
--- a/src/mcore_bridge/patcher.py
+++ b/src/mcore_bridge/patcher.py
@@ -733,13 +733,14 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor, hidden_st
         hidden_states_list = list(torch.chunk(hidden_states, 1 + offset, dim=0))
         hidden_states = hidden_states_list[offset]
         mtp_decoder_input = decoder_input = kwargs.pop('decoder_input', None)
-        for layer_number in range(len(self.layers)):
-            (hidden_states, input_ids, position_ids, decoder_input) = self.layers[layer_number](
+        for layer_number in range(self.config.mtp_unroll_steps):
+            (hidden_states, input_ids, position_ids, decoder_input) = self.layers[layer_number % len(self.layers)](
                 input_ids=input_ids,
                 position_ids=position_ids,
                 hidden_states=hidden_states,
                 attention_mask=attention_mask,
                 decoder_input=decoder_input,
+                layer_number=layer_number + 1,
                 **kwargs,
             )
             if mtp_decoder_input is None:

From 8627bd938704cfbc753d3f3b58c033b5f39a8c09 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Mon, 20 Apr 2026 10:41:15 +0800
Subject: [PATCH 2/4] fix

---
 src/mcore_bridge/config/model_config.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/mcore_bridge/config/model_config.py b/src/mcore_bridge/config/model_config.py
index f952ea5..8b147e2 100644
--- a/src/mcore_bridge/config/model_config.py
+++ b/src/mcore_bridge/config/model_config.py
@@ -309,6 +309,7 @@ def __post_init__(self):
         if self.apply_query_key_layer_scaling:
             os.environ['NVTE_APPLY_QK_LAYER_SCALING'] = '1'
         if self.mtp_shared_weights:
+            assert self.mtp_num_layers is not None
             self.mtp_unroll_steps = self.mtp_num_layers
             self.mtp_num_layers = 1
         else:

From 059293fce3aa072030982e9ca7aa0345c3c9f701 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Mon, 20 Apr 2026 11:10:20 +0800
Subject: [PATCH 3/4] fix

---
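Notes:

The unrolled loop chunks hidden_states into 1 + offset pieces and starts
from hidden_states_list[offset] with the single local layer, so (as the
chunking here reads) it assumes no MTP layers were placed on an earlier
virtual pipeline stage; a nonzero get_mtp_layer_offset now fails fast
instead of silently mis-chunking. With an ordinary single-stage placement
the offset is 0 and the guard passes.
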
 src/mcore_bridge/patcher.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py
index 2c039cd..c97cd5b 100644
--- a/src/mcore_bridge/patcher.py
+++ b/src/mcore_bridge/patcher.py
@@ -730,6 +730,7 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor, hidden_st
                 attention_mask: torch.Tensor, **kwargs) -> torch.Tensor:
         # get hidden states from previous mtp stages
         offset = get_mtp_layer_offset(self.config, self.vp_stage)
+        assert offset == 0, f'not supported'
         hidden_states_list = list(torch.chunk(hidden_states, 1 + offset, dim=0))
         hidden_states = hidden_states_list[offset]
         mtp_decoder_input = decoder_input = kwargs.pop('decoder_input', None)

From 7b8da5756d5940bf01df2af4c9d2ae1861146c56 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Mon, 20 Apr 2026 11:15:36 +0800
Subject: [PATCH 4/4] fix

---
 src/mcore_bridge/patcher.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py
index c97cd5b..f2009a1 100644
--- a/src/mcore_bridge/patcher.py
+++ b/src/mcore_bridge/patcher.py
@@ -730,7 +730,7 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor, hidden_st
                 attention_mask: torch.Tensor, **kwargs) -> torch.Tensor:
         # get hidden states from previous mtp stages
         offset = get_mtp_layer_offset(self.config, self.vp_stage)
-        assert offset == 0, f'not supported'
+        assert offset == 0, 'offset != 0 is not supported'
         hidden_states_list = list(torch.chunk(hidden_states, 1 + offset, dim=0))
         hidden_states = hidden_states_list[offset]
         mtp_decoder_input = decoder_input = kwargs.pop('decoder_input', None)