support mtp_shared_weights #41
Conversation
Code Review
This pull request introduces support for shared weights in Multi-Token Prediction (MTP) by decoupling the number of physical layers from the number of unroll steps. Key changes include updating ModelConfig to handle weight sharing logic, modifying gpt_model.py to use unroll steps for loss calculation and state chunking, and updating the MTP layer and patcher to support dynamic layer indexing for rotary embeddings. Feedback suggests refining the initialization logic to prevent accidental activation of MTP when layers are zero and explicitly declaring mtp_unroll_steps in the configuration dataclass.
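For context, a minimal sketch of the idea, using hypothetical names (`run_mtp_unroll`, `mtp_layers`, `layer_number`) rather than the PR's actual code: with shared weights there is a single physical MTP layer, but the forward pass still unrolls `mtp_unroll_steps` prediction steps, and the loss is scaled by the number of unroll steps rather than the number of layers.

```python
# Hypothetical sketch, not the PR's implementation.
def run_mtp_unroll(hidden_states, mtp_layers, mtp_unroll_steps):
    losses = []
    num_physical = len(mtp_layers)  # 1 when mtp_shared_weights is True
    for step in range(mtp_unroll_steps):
        # With shared weights every step reuses the same physical layer,
        # while the logical step index is still passed through (e.g. so
        # rotary embeddings can be offset per step).
        layer = mtp_layers[step % num_physical]
        hidden_states, loss = layer(hidden_states, layer_number=step)
        losses.append(loss)
    # Average over unroll steps, not over physical layers.
    return hidden_states, sum(losses) / max(mtp_unroll_steps, 1)
```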
Any related update to MS-SWIFT? I suggest that mtp_unroll_steps should be preserved in the metadata (i.e. config.json).
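A hedged sketch of that suggestion, assuming the key names and the fallback behaviour (none of this is confirmed by the PR): write both MTP fields into config.json on export, and fall back to mtp_num_layers when loading older configs that predate mtp_unroll_steps.

```python
import json

def export_mtp_metadata(config, path="config.json"):
    # Hypothetical: record both fields so downstream tools (e.g. MS-SWIFT)
    # can reconstruct the unroll behaviour from metadata alone.
    with open(path) as f:
        metadata = json.load(f)
    metadata["mtp_shared_weights"] = config.mtp_shared_weights
    metadata["mtp_unroll_steps"] = config.mtp_unroll_steps
    with open(path, "w") as f:
        json.dump(metadata, f, indent=2)

def load_mtp_unroll_steps(metadata):
    # Hypothetical fallback for configs written before mtp_unroll_steps existed.
    return metadata.get("mtp_unroll_steps", metadata.get("mtp_num_layers", 0))
```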
/gemini review
Code Review
This pull request introduces support for Multi-Token Prediction (MTP) with shared weights by adding mtp_shared_weights and mtp_unroll_steps to the configuration. The logic in gpt_model.py and patcher.py has been updated to use mtp_unroll_steps for loop iterations and loss scaling, allowing the model to unroll more steps than the number of physical layers when weights are shared. Feedback highlights the need to explicitly declare mtp_unroll_steps as a field in the ModelConfig dataclass and to improve the robustness of its initialization to avoid potential TypeError when mtp_num_layers is missing.
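A minimal sketch of the explicit declaration being asked for, assuming ModelConfig is a dataclass (the defaults shown are assumptions):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelConfig:
    mtp_num_layers: Optional[int] = None
    mtp_shared_weights: bool = False
    # Declared up front rather than created dynamically in __post_init__,
    # so type checkers and serializers see the field.
    mtp_unroll_steps: int = 0
```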
The snippet under review, from ModelConfig's initialization logic:

```python
if self.mtp_shared_weights:
    assert self.mtp_num_layers is not None
    self.mtp_unroll_steps = self.mtp_num_layers
    self.mtp_num_layers = 1
else:
    self.mtp_unroll_steps = self.mtp_num_layers
```
The initialization logic for mtp_unroll_steps should be more robust. If mtp_num_layers is None and mtp_shared_weights is False, mtp_unroll_steps will be assigned None, which will cause a TypeError in gpt_model.py (line 423) when calculating 1 + self.config.mtp_unroll_steps. Additionally, it is safer to ensure mtp_num_layers is positive when weight sharing is enabled.
Suggested change:

```diff
 if self.mtp_shared_weights:
-    assert self.mtp_num_layers is not None
+    assert self.mtp_num_layers is not None and self.mtp_num_layers > 0, \
+        "mtp_num_layers must be > 0 when mtp_shared_weights is True"
     self.mtp_unroll_steps = self.mtp_num_layers
     self.mtp_num_layers = 1
 else:
-    self.mtp_unroll_steps = self.mtp_num_layers
+    self.mtp_unroll_steps = self.mtp_num_layers or 0
```
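To illustrate the failure mode and the fallback, a tiny example (the `1 + mtp_unroll_steps` expression mirrors the gpt_model.py usage cited above; `num_chunks` is an assumed name):

```python
mtp_unroll_steps = None            # original else-branch when mtp_num_layers is None
# 1 + mtp_unroll_steps             # TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'

mtp_unroll_steps = None or 0       # the suggested fallback resolves to 0
num_chunks = 1 + mtp_unroll_steps  # 1: only the main prediction, no extra MTP steps
```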
#29