src/mcore_bridge/config/model_config.py (10 additions & 2 deletions)
```diff
@@ -204,6 +204,10 @@ class ModelConfig(TransformerConfig):
     dsa_indexer_use_sparse_loss: bool = False
     dsa_indexer_rotary_interleaved: bool = False
 
+    # mtp
+    mtp_decoder_input_detach: bool = False
+    mtp_shared_weights: bool = False
+
     # visual
     hf_config: Optional[PretrainedConfig] = None
     vit_attn_impl: Optional[str] = None  # e.g. 'flash_attention_2'
```
```diff
@@ -225,7 +229,6 @@ class ModelConfig(TransformerConfig):
     task_type: Literal['causal_lm', 'seq_cls', 'embedding', 'generative_reranker'] = 'causal_lm'
     num_labels: Optional[int] = None
     mlp_padding_free: bool = False
-    mtp_decoder_input_detach: bool = False
 
     _mindspeed_defaults_cache = None
```
```diff
@@ -305,7 +308,12 @@ def __post_init__(self):
             self.apply_query_key_layer_scaling = self.fp16
         if self.apply_query_key_layer_scaling:
             os.environ['NVTE_APPLY_QK_LAYER_SCALING'] = '1'
-        # patch rotary_interleaved
+        if self.mtp_shared_weights:
+            assert self.mtp_num_layers is not None
+            self.mtp_unroll_steps = self.mtp_num_layers
+            self.mtp_num_layers = 1
+        else:
+            self.mtp_unroll_steps = self.mtp_num_layers
 
         super().__post_init__()
 
         self._check_npu()
```
Review comment on lines +311 to +316 (severity: medium; marked as resolved by Jintao-Huang):

The initialization logic for `mtp_unroll_steps` should be more robust. If `mtp_num_layers` is `None` and `mtp_shared_weights` is `False`, `mtp_unroll_steps` will be assigned `None`, which will cause a `TypeError` in `gpt_model.py` (line 423) when calculating `1 + self.config.mtp_unroll_steps`. Additionally, it is safer to ensure `mtp_num_layers` is positive when weight sharing is enabled.
Suggested change:

```diff
         if self.mtp_shared_weights:
-            assert self.mtp_num_layers is not None
+            assert self.mtp_num_layers is not None and self.mtp_num_layers > 0, \
+                "mtp_num_layers must be > 0 when mtp_shared_weights is True"
             self.mtp_unroll_steps = self.mtp_num_layers
             self.mtp_num_layers = 1
         else:
-            self.mtp_unroll_steps = self.mtp_num_layers
+            self.mtp_unroll_steps = self.mtp_num_layers or 0
```
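To see what the remap does in practice, here is a self-contained sketch of the `__post_init__` behavior with the reviewer's suggestion applied; `MtpConfigSketch` and the example values are hypothetical stand-ins for illustration, not code from the PR:

```python
from dataclasses import dataclass
from typing import Optional

# Minimal stand-in for the relevant slice of ModelConfig (assumption:
# field names match the PR; everything else is omitted).
@dataclass
class MtpConfigSketch:
    mtp_num_layers: Optional[int] = None
    mtp_shared_weights: bool = False
    mtp_unroll_steps: int = 0

    def __post_init__(self):
        if self.mtp_shared_weights:
            assert self.mtp_num_layers is not None and self.mtp_num_layers > 0
            # Unroll one shared layer mtp_num_layers times ...
            self.mtp_unroll_steps = self.mtp_num_layers
            # ... but only materialize a single set of MTP weights.
            self.mtp_num_layers = 1
        else:
            # `or 0` guards the None case flagged in the review comment.
            self.mtp_unroll_steps = self.mtp_num_layers or 0

# Weight sharing: three logical steps, one physical layer.
shared = MtpConfigSketch(mtp_num_layers=3, mtp_shared_weights=True)
assert (shared.mtp_num_layers, shared.mtp_unroll_steps) == (1, 3)

# No sharing: layer count and unroll count coincide.
plain = MtpConfigSketch(mtp_num_layers=3)
assert (plain.mtp_num_layers, plain.mtp_unroll_steps) == (3, 3)
```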

src/mcore_bridge/model/gpt_model.py (4 additions & 4 deletions)
```diff
@@ -420,12 +420,12 @@ def _postprocess(
                 **(extra_block_kwargs or {}),
             )
             mtp_labels = labels.clone()
-            hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
+            hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_unroll_steps, dim=0)
             hidden_states = hidden_states_list[0]
             if loss_mask is None:
                 # if loss_mask is not provided, use all ones as loss_mask
                 loss_mask = torch.ones_like(mtp_labels)
-            for mtp_layer_number in range(self.config.mtp_num_layers):
+            for mtp_layer_number in range(self.config.mtp_unroll_steps):
                 # output
                 mtp_logits, _ = self.output_layer(
                     hidden_states_list[mtp_layer_number + 1],
```
```diff
@@ -457,10 +457,10 @@ def _postprocess(
                 MTPLossLoggingHelper.save_loss_to_tracker(
                     mtp_loss_for_log,
                     mtp_layer_number,
-                    self.config.mtp_num_layers,
+                    self.config.mtp_unroll_steps,
                     avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
                 )
-            mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
+            mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_unroll_steps
             if self.config.calculate_per_token_loss:
                 hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
             else:
```
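The rename matters here because `_postprocess` receives the main-model hidden states plus one extra block per unroll step, concatenated along the sequence dimension, so `torch.chunk` must split on the unrolled count rather than the (now possibly collapsed to 1) layer count. A minimal sketch of that layout with made-up shapes, independent of the PR's actual model code:

```python
import torch

# Hypothetical sizes for illustration only.
seq_len, batch, hidden = 8, 2, 16
mtp_unroll_steps = 3

# The decoder stacks the main-model hidden states and one block per
# unroll step along dim 0: (1 + mtp_unroll_steps) * seq_len rows in total.
stacked = torch.randn((1 + mtp_unroll_steps) * seq_len, batch, hidden)

# _postprocess splits that back into 1 + mtp_unroll_steps equal chunks.
chunks = torch.chunk(stacked, 1 + mtp_unroll_steps, dim=0)
main_hidden = chunks[0]   # feeds the ordinary LM head
mtp_hidden = chunks[1:]   # one chunk per MTP prediction depth
assert main_hidden.shape == (seq_len, batch, hidden)
assert len(mtp_hidden) == mtp_unroll_steps
```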
src/mcore_bridge/model/modules/mtp_layer.py (4 additions & 1 deletion)
```diff
@@ -65,8 +65,11 @@ def forward(
         sequence_len_offset: torch.Tensor = None,
         embedding=None,
         decoder_input=None,
+        layer_number: Optional[int] = None,
     ):
         assert context is None, 'multi token prediction + cross attention is not yet supported.'
+        if layer_number is None:
+            layer_number = self.layer_number
         input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings(
             input_ids=input_ids,
             position_ids=position_ids,
```
```diff
@@ -82,7 +85,7 @@ def forward(
             rotary_pos_emb = rotary_pos_emb[position_ids[0]]
         else:
             # mrope or not packed_seq
-            rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-self.layer_number, dims=0)
+            rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-layer_number, dims=0)
         if self.config.recompute_granularity == 'full' and self.training:
             hidden_states = self._checkpointed_forward(
                 partial(
```
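Accepting the logical `layer_number` from the caller is what keeps the rotary shift correct under weight sharing: the single materialized layer keeps one fixed `self.layer_number`, while MTP depth k needs the rotary table rolled by k positions. A toy illustration of the roll (toy table values, not the real rotary embedding math):

```python
import torch

# Toy rotary table: one row per position, positions 0..5.
rotary_pos_emb = torch.arange(6).float().unsqueeze(-1)

# MTP depth k predicts token t+k, so its rotary table starts k positions in.
for logical_layer_number in (1, 2, 3):
    shifted = torch.roll(rotary_pos_emb, shifts=-logical_layer_number, dims=0)
    # Row 0 of the shifted table is now the embedding for position k.
    assert shifted[0].item() == logical_layer_number
```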
src/mcore_bridge/patcher.py (4 additions & 2 deletions)
Expand Up @@ -730,16 +730,18 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor, hidden_st
attention_mask: torch.Tensor, **kwargs) -> torch.Tensor:
# get hidden states from previous mtp stages
offset = get_mtp_layer_offset(self.config, self.vp_stage)
assert offset == 0, 'not support offset'
hidden_states_list = list(torch.chunk(hidden_states, 1 + offset, dim=0))
hidden_states = hidden_states_list[offset]
mtp_decoder_input = decoder_input = kwargs.pop('decoder_input', None)
for layer_number in range(len(self.layers)):
(hidden_states, input_ids, position_ids, decoder_input) = self.layers[layer_number](
for layer_number in range(self.config.mtp_unroll_steps):
(hidden_states, input_ids, position_ids, decoder_input) = self.layers[layer_number % len(self.layers)](
input_ids=input_ids,
position_ids=position_ids,
hidden_states=hidden_states,
attention_mask=attention_mask,
decoder_input=decoder_input,
layer_number=layer_number + 1,
**kwargs,
)
if mtp_decoder_input is None:
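Together with `mtp_layer.py`, this loop is where the sharing takes effect: it iterates `mtp_unroll_steps` times, wraps the physical layer index with `% len(self.layers)` (so with `mtp_shared_weights` enabled every step resolves to the single layer at index 0), and forwards the advancing logical step as `layer_number`. A minimal control-flow sketch with a hypothetical stand-in layer, not the real module:

```python
# Stand-in for the unrolled MTP loop; FakeLayer is hypothetical and only
# records which physical layer served which logical step.
class FakeLayer:
    def __init__(self, idx: int):
        self.idx = idx

    def __call__(self, logical_step: int) -> tuple[int, int]:
        return self.idx, logical_step

layers = [FakeLayer(0)]   # mtp_shared_weights=True -> one physical layer
mtp_unroll_steps = 3      # unrolled three times

trace = [layers[step % len(layers)](logical_step=step + 1)
         for step in range(mtp_unroll_steps)]
# The single physical layer (index 0) runs at logical steps 1, 2, 3.
assert trace == [(0, 1), (0, 2), (0, 3)]
```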