Support MTP weight reuse with unrolled steps #29
```diff
@@ -16,7 +16,7 @@
     scatter_to_sequence_parallel_region)
 from megatron.core.transformer import TransformerLayer
 from megatron.core.transformer.multi_latent_attention import MLASelfAttention, MultiLatentAttention
-from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer
+from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer, MultiTokenPredictionBlock, get_mtp_layer_offset
 from megatron.core.utils import deprecate_inference_params
 from packaging import version
 from peft.tuners.tuners_utils import BaseTuner
```
```diff
@@ -394,6 +394,7 @@ def forward(
     packed_seq_params: PackedSeqParams = None,
     sequence_len_offset: torch.Tensor = None,
     embedding=None,
+    depth_idx: int = None,
 ):
     """
     Execute the forward pass through the Multi-Token Prediction (MTP) layer.
```
```diff
@@ -417,7 +418,9 @@ def forward(
         Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape
         [s, b, h], and optionally the updated context tensor if cross-attention is used.
     """
     # TODO: Multimodal compatible
+    # current unroll depth
+    effective_depth = self.layer_number if depth_idx is None else depth_idx
     assert context is None, 'multi token prediction + cross attention is not yet supported.'
     input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings(
         input_ids=input_ids,
```
```diff
@@ -433,7 +436,7 @@ def forward(
         rotary_pos_emb = rotary_pos_emb[position_ids[0]]
     else:
         # mrope or not packed_seq
-        rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-self.layer_number, dims=0)
+        rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-effective_depth, dims=0)
     if self.config.recompute_granularity == 'full' and self.training:
         hidden_states = self._checkpointed_forward(
             partial(
```
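For intuition, here is a minimal standalone sketch of what the depth-dependent roll does (the 1-D table below is a hypothetical stand-in for the real rotary embedding tensor): each unroll depth `d` shifts the position table left by `d`, so an MTP layer predicting `d` tokens ahead reads rotary embeddings offset by its true depth rather than by its physical layer index.

```python
import torch

# Hypothetical stand-in for rotary_pos_emb: one entry per sequence position.
rotary_pos_emb = torch.arange(8)

for effective_depth in (1, 2, 3):
    # Same call as in the hunk above: shift left by the current unroll depth.
    shifted = torch.roll(rotary_pos_emb, shifts=-effective_depth, dims=0)
    print(effective_depth, shifted[:4].tolist())
# 1 [1, 2, 3, 4]
# 2 [2, 3, 4, 5]
# 3 [3, 4, 5, 6]
```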
```diff
@@ -471,6 +474,60 @@ def forward(

 MultiTokenPredictionLayer.forward = forward


+def block_forward(
+    self,
+    input_ids: torch.Tensor,
+    position_ids: torch.Tensor,
+    hidden_states: torch.Tensor,
+    attention_mask: torch.Tensor,
+    context: torch.Tensor = None,
+    context_mask: torch.Tensor = None,
+    rotary_pos_emb: torch.Tensor = None,
+    rotary_pos_cos: torch.Tensor = None,
+    rotary_pos_sin: torch.Tensor = None,
+    attention_bias: torch.Tensor = None,
+    inference_params: InferenceParams = None,
+    packed_seq_params: PackedSeqParams = None,
+    sequence_len_offset: torch.Tensor = None,
+    extra_block_kwargs: Optional[dict] = None,
+    embedding=None,
+):
+    """Perform the forward pass through all MTP modules with optional layer reuse."""
+    offset = get_mtp_layer_offset(self.config, self.vp_stage)
+    hidden_states_list = list(torch.chunk(hidden_states, 1 + offset, dim=0))
+    hidden_states = hidden_states_list[offset]
+
+    physical_num_layers = len(self.layers)
+    unroll_steps = getattr(self.config, 'mtp_unroll_steps', None) or self.config.mtp_num_layers
+
+    for step in range(unroll_steps):
+        layer = self.layers[step % physical_num_layers]
+        global_depth = offset + step + 1
+        hidden_states, input_ids, position_ids = layer(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            inference_params=inference_params,
+            rotary_pos_emb=rotary_pos_emb,
+            rotary_pos_cos=rotary_pos_cos,
+            rotary_pos_sin=rotary_pos_sin,
+            packed_seq_params=packed_seq_params,
+            sequence_len_offset=sequence_len_offset,
+            embedding=embedding,
+            depth_idx=global_depth,
+            **(extra_block_kwargs or {}),
+        )
+        hidden_states_list.append(hidden_states)
+
+    hidden_states = torch.cat(hidden_states_list, dim=0)
+    return hidden_states
+
+
+MultiTokenPredictionBlock.forward = block_forward
+
+
 def _get_embeddings(
     self,
     input_ids: torch.Tensor,
```

Review comment on lines +498 to +526:
> The … If this feature is primarily intended for the …
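The reuse itself comes from the `step % physical_num_layers` indexing: when `mtp_unroll_steps` exceeds `mtp_num_layers`, the same physical layer is invoked at several prediction depths with shared weights, while `global_depth` (forwarded as `depth_idx`) keeps advancing so each step still shifts its rotary embeddings by the true depth. A standalone sketch of just that indexing, using hypothetical stand-in values:

```python
layers = ['mtp_layer_0']  # one physical MTP layer; weights shared across steps
offset = 0                # what get_mtp_layer_offset would return on one stage
unroll_steps = 3          # hypothetical config.mtp_unroll_steps

for step in range(unroll_steps):
    layer = layers[step % len(layers)]  # always mtp_layer_0 here
    global_depth = offset + step + 1    # 1, 2, 3: the true prediction depth
    print(step, layer, global_depth)
# 0 mtp_layer_0 1
# 1 mtp_layer_0 2
# 2 mtp_layer_0 3
```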
@@ -0,0 +1,155 @@ (new test file)

```python
from types import SimpleNamespace

import pytest
import torch

import mcore_bridge  # noqa: F401
import mcore_bridge.model.gpt_model as gpt_model_mod
from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionBlock
from mcore_bridge.model.gpt_model import GPTModel


class RecordingLayer:

    def __init__(self):
        self.depth_history = []

    def __call__(
        self,
        *,
        input_ids,
        position_ids,
        hidden_states,
        attention_mask,
        depth_idx=None,
        **kwargs,
    ):
        self.depth_history.append(depth_idx)
        return hidden_states + depth_idx, input_ids, position_ids


class RecordingOutputLayer:

    def __init__(self):
        self.calls = []
        self.sequence_parallel = False

    def __call__(self, hidden_states, weight=None, runtime_gather_output=None):
        self.calls.append(hidden_states.clone())
        return hidden_states.squeeze(-1).transpose(0, 1), None


def test_mtp_block_reuses_single_physical_layer_across_unroll_steps():
    layer = RecordingLayer()
    block = SimpleNamespace(
        config=SimpleNamespace(
            mtp_num_layers=1,
            mtp_unroll_steps=3,
            pipeline_model_parallel_size=1,
            pipeline_model_parallel_layout=None,
        ),
        vp_stage=None,
        layers=[layer],
    )

    hidden_states = torch.zeros(2, 1, 1)
    input_ids = torch.zeros(1, 2, dtype=torch.long)
    position_ids = torch.zeros(1, 2, dtype=torch.long)

    output = MultiTokenPredictionBlock.forward(
        block,
        input_ids=input_ids,
        position_ids=position_ids,
        hidden_states=hidden_states,
        attention_mask=None,
    )

    assert layer.depth_history == [1, 2, 3]

    chunks = torch.chunk(output, 4, dim=0)
    assert [chunk[0, 0, 0].item() for chunk in chunks] == [0.0, 1.0, 3.0, 6.0]


def test_postprocess_uses_unroll_steps_for_mtp_loss(monkeypatch):
    saved_losses = []
    monkeypatch.setattr(
        gpt_model_mod,
        'roll_tensor',
        lambda tensor, shifts, dims, cp_group=None, packed_seq_params=None: (tensor, tensor.numel()),
    )
    monkeypatch.setattr(
        gpt_model_mod.MTPLossAutoScaler,
        'apply',
        lambda hidden_states, scaled_loss: hidden_states,
    )
    monkeypatch.setattr(
        gpt_model_mod.MTPLossLoggingHelper,
        'save_loss_to_tracker',
        lambda loss, layer_number, total_layers, avg_group=None: saved_losses.append(
            (layer_number, total_layers)
        ),
    )
    monkeypatch.setattr(
        gpt_model_mod.parallel_state,
        'get_data_parallel_group',
        lambda with_context_parallel=True: None,
    )
    monkeypatch.setattr(gpt_model_mod, 'has_config_logger_enabled', lambda config: False)

    output_layer = RecordingOutputLayer()

    def mtp_forward(**kwargs):
        hidden_states = kwargs['hidden_states']
        return torch.cat([hidden_states, hidden_states + 1, hidden_states + 2, hidden_states + 3], dim=0)

    model = SimpleNamespace(
        post_process=True,
        mtp_process=True,
        training=True,
        share_embeddings_and_output_weights=False,
        cp_group=None,
        embedding=lambda *args, **kwargs: None,
        output_layer=output_layer,
        mtp=mtp_forward,
        config=SimpleNamespace(
            task_type='causal_lm',
            is_multimodal=False,
            context_parallel_size=1,
            mtp_num_layers=1,
            mtp_unroll_steps=3,
            decoder_input_detach=True,
            calculate_per_token_loss=False,
            mtp_loss_scaling_factor=0.3,
            sequence_parallel=False,
            tensor_model_parallel_size=1,
        ),
        compute_language_model_loss=lambda labels, logits: logits.float(),
    )

    loss = GPTModel._postprocess(
        model,
        hidden_states=torch.zeros(1, 1, 1),
        input_ids=None,
        position_ids=None,
        labels=torch.ones(1, 1, dtype=torch.long),
        rotary_pos_emb=None,
        rotary_pos_cos=None,
        rotary_pos_sin=None,
        loss_mask=torch.ones(1, 1, dtype=torch.bool),
        decoder_input=None,
        attention_mask=None,
        inference_params=None,
        packed_seq_params=None,
        sequence_len_offset=None,
        runtime_gather_output=False,
        extra_block_kwargs=None,
        inference_context=None,
    )

    assert loss.shape == (1, 1)
    assert [call[0, 0, 0].item() for call in output_layer.calls] == [1.0, 2.0, 3.0, 0.0]
    assert saved_losses == [(0, 3), (1, 3), (2, 3)]


if __name__ == '__main__':
    raise SystemExit(pytest.main([__file__, '-q']))
```
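As a quick check of the first test's expectations: `RecordingLayer` returns `hidden_states + depth_idx`, so across depths 1, 2, 3 the hidden states are the running sums 0, 1, 3, 6, and `block_forward` concatenates the initial chunk plus one chunk per unroll step (which is also why the second test sees four output-layer calls but only three tracked MTP losses for a single physical layer):

```python
# Running-sum check for the [0.0, 1.0, 3.0, 6.0] assertion above.
h, chunks = 0.0, [0.0]
for depth in (1, 2, 3):  # matches the asserted depth_history
    h += depth           # RecordingLayer: hidden_states + depth_idx
    chunks.append(h)
print(chunks)            # [0.0, 1.0, 3.0, 6.0]
```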
Review comment (on the `torch.roll` change in the first file):
> The `torch.roll` operation will fail if `rotary_pos_emb` is `None`. This occurs in models that do not use RoPE or MRoPE (e.g., models using absolute position embeddings).

Reply:
> Good catch. This `rotary_pos_emb is None` assumption is pre-existing in the base branch; this PR only changes the shift depth from `self.layer_number` to `effective_depth` for the unrolled MTP case, and does not introduce a new dereference path here. If needed, I can address the `None` guard separately in a follow-up cleanup.
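For reference, a minimal sketch of what such a follow-up guard could look like (an assumption sketched here, not code from this PR):

```python
# Hypothetical follow-up (not part of this PR): models without RoPE/MRoPE pass
# rotary_pos_emb=None, so only roll when a rotary table actually exists.
if rotary_pos_emb is not None:
    rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-effective_depth, dims=0)
```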