4 changes: 4 additions & 0 deletions src/mcore_bridge/config/model_config.py
@@ -203,6 +203,10 @@ class ModelConfig(TransformerConfig):
dsa_indexer_use_sparse_loss: bool = False
dsa_indexer_rotary_interleaved: bool = False

# mtp
mtp_unroll_steps: Optional[int] = None  # number of logical MTP unroll steps; physical MTP layers are reused when this exceeds mtp_num_layers
decoder_input_detach: bool = True  # detach decoder_input before handing it to the MTP embedding path (multimodal case)

# visual
hf_config: Optional[PretrainedConfig] = None
vit_attn_impl: Optional[str] = None # e.g. 'flash_attention_2'
1 change: 1 addition & 0 deletions src/mcore_bridge/config/parser.py
@@ -47,6 +47,7 @@
'linear_key_head_dim': ['linear_key_head_dim'],
'linear_value_head_dim': ['linear_value_head_dim'],
'linear_conv_kernel_dim': ['linear_conv_kernel_dim'],
'mtp_unroll_steps': ['mtp_unroll_steps'],
# dsa
'dsa_indexer_n_heads': ['index_n_heads'],
'dsa_indexer_head_dim': ['index_head_dim'],
12 changes: 7 additions & 5 deletions src/mcore_bridge/model/gpt_model.py
@@ -404,8 +404,10 @@ def _postprocess(
input_ids = split_cp_inputs(input_ids, getattr(packed_seq_params, 'cu_seqlens_q', None), 1)

if self.mtp_process and labels is not None:
mtp_depth = getattr(self.config, 'mtp_unroll_steps', None) or self.config.mtp_num_layers
if self.config.is_multimodal:
embedding_ = (self.embedding, decoder_input)
_decoder_input = decoder_input.detach() if self.config.decoder_input_detach else decoder_input
embedding_ = (self.embedding, _decoder_input)
else:
embedding_ = self.embedding
hidden_states = self.mtp(
@@ -423,12 +425,12 @@
**(extra_block_kwargs or {}),
)
mtp_labels = labels.clone()
hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
hidden_states_list = torch.chunk(hidden_states, 1 + mtp_depth, dim=0)
hidden_states = hidden_states_list[0]
if loss_mask is None:
# if loss_mask is not provided, use all ones as loss_mask
loss_mask = torch.ones_like(mtp_labels)
for mtp_layer_number in range(self.config.mtp_num_layers):
for mtp_layer_number in range(mtp_depth):
# output
mtp_logits, _ = self.output_layer(
hidden_states_list[mtp_layer_number + 1],
@@ -460,10 +462,10 @@
MTPLossLoggingHelper.save_loss_to_tracker(
mtp_loss_for_log,
mtp_layer_number,
self.config.mtp_num_layers,
mtp_depth,
avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
)
mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
mtp_loss_scale = self.config.mtp_loss_scaling_factor / mtp_depth
if self.config.calculate_per_token_loss:
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
else:
63 changes: 60 additions & 3 deletions src/mcore_bridge/patcher.py
@@ -16,7 +16,7 @@
scatter_to_sequence_parallel_region)
from megatron.core.transformer import TransformerLayer
from megatron.core.transformer.multi_latent_attention import MLASelfAttention, MultiLatentAttention
from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer
from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer, MultiTokenPredictionBlock, get_mtp_layer_offset
from megatron.core.utils import deprecate_inference_params
from packaging import version
from peft.tuners.tuners_utils import BaseTuner
@@ -394,6 +394,7 @@ def forward(
packed_seq_params: PackedSeqParams = None,
sequence_len_offset: torch.Tensor = None,
embedding=None,
depth_idx: int = None,
):
"""
Execute the forward pass through the Multi-Token Prediction (MTP) layer.
@@ -417,7 +418,9 @@
Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape
[s, b, h], and optionally the updated context tensor if cross-attention is used.
"""
# TODO: Multimodal compatible
# Unroll depth for this call: use depth_idx when provided, otherwise fall back to the physical layer number
effective_depth = self.layer_number if depth_idx is None else depth_idx

assert context is None, 'multi token prediction + cross attention is not yet supported.'
input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings(
input_ids=input_ids,
@@ -433,7 +436,7 @@
rotary_pos_emb = rotary_pos_emb[position_ids[0]]
else:
# mrope or not packed_seq
rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-self.layer_number, dims=0)
rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-effective_depth, dims=0)

medium

The torch.roll operation will fail if rotary_pos_emb is None. This occurs in models that do not use RoPE or MRoPE (e.g., models using absolute position embeddings).

            if rotary_pos_emb is not None:
                rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-effective_depth, dims=0)

Author

Good catch. The assumption that rotary_pos_emb is not None is pre-existing in the base branch; this PR only changes the shift depth from self.layer_number to effective_depth for the unrolled MTP case and does not introduce a new dereference path here. If needed, I can add the None guard separately in a follow-up cleanup.

if self.config.recompute_granularity == 'full' and self.training:
hidden_states = self._checkpointed_forward(
partial(
@@ -471,6 +474,60 @@ def forward(

MultiTokenPredictionLayer.forward = forward

def block_forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
context: torch.Tensor = None,
context_mask: torch.Tensor = None,
rotary_pos_emb: torch.Tensor = None,
rotary_pos_cos: torch.Tensor = None,
rotary_pos_sin: torch.Tensor = None,
attention_bias: torch.Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
sequence_len_offset: torch.Tensor = None,
extra_block_kwargs: Optional[dict] = None,
embedding=None,
):
"""Perform the forward pass through all MTP modules with optional layer reuse."""
offset = get_mtp_layer_offset(self.config, self.vp_stage)
hidden_states_list = list(torch.chunk(hidden_states, 1 + offset, dim=0))
hidden_states = hidden_states_list[offset]

physical_num_layers = len(self.layers)
unroll_steps = getattr(self.config, 'mtp_unroll_steps', None) or self.config.mtp_num_layers

for step in range(unroll_steps):
layer = self.layers[step % physical_num_layers]
global_depth = offset + step + 1
hidden_states, input_ids, position_ids = layer(
input_ids=input_ids,
position_ids=position_ids,
hidden_states=hidden_states,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
embedding=embedding,
depth_idx=global_depth,
**(extra_block_kwargs or {}),
)
hidden_states_list.append(hidden_states)

hidden_states = torch.cat(hidden_states_list, dim=0)
return hidden_states
Comment on lines +498 to +526
high

The block_forward implementation is not fully compatible with Pipeline Parallelism (PP) when mtp_unroll_steps is used:

  1. ZeroDivisionError: If a PP stage contains no MTP layers (len(self.layers) == 0), line 504 will crash due to step % physical_num_layers. A guard is needed to return early or skip the loop in such stages.
  2. Incorrect Chunking and Offset: torch.chunk(hidden_states, 1 + offset, dim=0) and global_depth = offset + step + 1 rely on the physical layer offset. If previous stages have already performed logical unrolling, the number of chunks in hidden_states and the starting logical depth will be higher than what the physical offset indicates.
  3. Redundant Execution: Every stage with MTP layers will attempt to execute the full unroll_steps. In a PP setup where MTP layers are distributed, this leads to an incorrect total number of logical steps and mismatched chunk counts in _postprocess.

If this feature is primarily intended for the mtp_num_layers=1 case (single shared layer), please add a check for physical_num_layers > 0 and consider documenting the PP limitations.
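One possible shape for the guard suggested above, sketched under the assumption that a stage holding no physical MTP layers should pass hidden_states through unchanged (this snippet is not part of the PR):

        physical_num_layers = len(self.layers)
        # Hypothetical guard: on PP stages that hold no MTP layers, skip the
        # unroll loop entirely rather than evaluating step % physical_num_layers.
        if physical_num_layers == 0:
            return hidden_states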


MultiTokenPredictionBlock.forward = block_forward




def _get_embeddings(
self,
input_ids: torch.Tensor,
155 changes: 155 additions & 0 deletions tests/test_mtp_reuse.py
@@ -0,0 +1,155 @@
from types import SimpleNamespace

import pytest
import torch

import mcore_bridge # noqa: F401
import mcore_bridge.model.gpt_model as gpt_model_mod
from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionBlock
from mcore_bridge.model.gpt_model import GPTModel


class RecordingLayer:

def __init__(self):
self.depth_history = []

def __call__(
self,
*,
input_ids,
position_ids,
hidden_states,
attention_mask,
depth_idx=None,
**kwargs,
):
self.depth_history.append(depth_idx)
return hidden_states + depth_idx, input_ids, position_ids


class RecordingOutputLayer:

def __init__(self):
self.calls = []
self.sequence_parallel = False

def __call__(self, hidden_states, weight=None, runtime_gather_output=None):
self.calls.append(hidden_states.clone())
return hidden_states.squeeze(-1).transpose(0, 1), None


def test_mtp_block_reuses_single_physical_layer_across_unroll_steps():
layer = RecordingLayer()
block = SimpleNamespace(
config=SimpleNamespace(
mtp_num_layers=1,
mtp_unroll_steps=3,
pipeline_model_parallel_size=1,
pipeline_model_parallel_layout=None,
),
vp_stage=None,
layers=[layer],
)

hidden_states = torch.zeros(2, 1, 1)
input_ids = torch.zeros(1, 2, dtype=torch.long)
position_ids = torch.zeros(1, 2, dtype=torch.long)

output = MultiTokenPredictionBlock.forward(
block,
input_ids=input_ids,
position_ids=position_ids,
hidden_states=hidden_states,
attention_mask=None,
)

assert layer.depth_history == [1, 2, 3]

chunks = torch.chunk(output, 4, dim=0)
assert [chunk[0, 0, 0].item() for chunk in chunks] == [0.0, 1.0, 3.0, 6.0]


def test_postprocess_uses_unroll_steps_for_mtp_loss(monkeypatch):
saved_losses = []
monkeypatch.setattr(
gpt_model_mod,
'roll_tensor',
lambda tensor, shifts, dims, cp_group=None, packed_seq_params=None: (tensor, tensor.numel()),
)
monkeypatch.setattr(
gpt_model_mod.MTPLossAutoScaler,
'apply',
lambda hidden_states, scaled_loss: hidden_states,
)
monkeypatch.setattr(
gpt_model_mod.MTPLossLoggingHelper,
'save_loss_to_tracker',
lambda loss, layer_number, total_layers, avg_group=None: saved_losses.append(
(layer_number, total_layers)
),
)
monkeypatch.setattr(
gpt_model_mod.parallel_state,
'get_data_parallel_group',
lambda with_context_parallel=True: None,
)
monkeypatch.setattr(gpt_model_mod, 'has_config_logger_enabled', lambda config: False)

output_layer = RecordingOutputLayer()

def mtp_forward(**kwargs):
hidden_states = kwargs['hidden_states']
return torch.cat([hidden_states, hidden_states + 1, hidden_states + 2, hidden_states + 3], dim=0)

model = SimpleNamespace(
post_process=True,
mtp_process=True,
training=True,
share_embeddings_and_output_weights=False,
cp_group=None,
embedding=lambda *args, **kwargs: None,
output_layer=output_layer,
mtp=mtp_forward,
config=SimpleNamespace(
task_type='causal_lm',
is_multimodal=False,
context_parallel_size=1,
mtp_num_layers=1,
mtp_unroll_steps=3,
decoder_input_detach=True,
calculate_per_token_loss=False,
mtp_loss_scaling_factor=0.3,
sequence_parallel=False,
tensor_model_parallel_size=1,
),
compute_language_model_loss=lambda labels, logits: logits.float(),
)

loss = GPTModel._postprocess(
model,
hidden_states=torch.zeros(1, 1, 1),
input_ids=None,
position_ids=None,
labels=torch.ones(1, 1, dtype=torch.long),
rotary_pos_emb=None,
rotary_pos_cos=None,
rotary_pos_sin=None,
loss_mask=torch.ones(1, 1, dtype=torch.bool),
decoder_input=None,
attention_mask=None,
inference_params=None,
packed_seq_params=None,
sequence_len_offset=None,
runtime_gather_output=False,
extra_block_kwargs=None,
inference_context=None,
)

assert loss.shape == (1, 1)
assert [call[0, 0, 0].item() for call in output_layer.calls] == [1.0, 2.0, 3.0, 0.0]
assert saved_losses == [(0, 3), (1, 3), (2, 3)]


if __name__ == '__main__':
raise SystemExit(pytest.main([__file__, '-q']))