[bugfix] Fix MTP FP8 (#35)
Conversation
Code Review
This pull request refactors the Multi-Token Prediction (MTP) layer by moving the patched logic from patcher.py into a dedicated MultiTokenPredictionLayer class in src/mcore_bridge/model/modules/mtp_layer.py. It also introduces _fp8_skip_modules tracking in the bridge and model conversion logic. Feedback highlights a potential AttributeError from undefined parallel-state attributes, dead code in the form of the unused _fp8_skip_modules set, and a performance concern about local imports inside the forward pass.
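The performance concern about local imports can be sketched in isolation. This is a hedged, minimal illustration, not code from the PR: the class name mirrors the one mentioned above, but the body and the `math` import are stand-ins.

```python
# Hoisted to module level: the import machinery runs once at load time.
# The pattern flagged in the review would instead place this import
# inside forward(), paying a sys.modules lookup on every call.
import math


class MultiTokenPredictionLayer:
    """Minimal stand-in for the refactored MTP layer (illustrative only)."""

    def forward(self, hidden):
        # Anti-pattern from the review, kept here as a comment:
        #   import math  # re-executed on every forward pass
        return [math.tanh(h) for h in hidden]


layer = MultiTokenPredictionLayer()
out = layer.forward([0.0, 1.0])
```

Even though a cached import is cheap, hoisting it keeps the hot path free of avoidable work, which is what the review is asking for.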
/gemini review
Code Review
This pull request refactors the Multi-Token Prediction (MTP) implementation by replacing monkey-patching with a dedicated MultiTokenPredictionLayer subclass that includes FP8 support. It also updates model registration and bridge logic to accommodate these changes. Feedback highlights a potential logic error in sequence parallel handling that could lead to incorrect tensor sharding, identifies dead code related to the unused _fp8_skip_modules variable, and suggests correcting the forward method's docstring to match its actual return type.
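The AttributeError risk around undefined parallel-state attributes can be guarded defensively. The following is a hedged sketch; `ParallelState`, `sequence_parallel_size`, and `shard_factor` are illustrative names, not identifiers from the PR or from Megatron-Core.

```python
class ParallelState:
    """Illustrative stand-in for a parallel-state holder whose optional
    attributes may never have been initialized."""
    tensor_model_parallel_size = 2
    # note: no `sequence_parallel_size` attribute is defined


def shard_factor(state):
    # getattr with a default avoids the AttributeError the review warns
    # about when an optional attribute was never set on the state object
    sp = getattr(state, "sequence_parallel_size", 1)
    return state.tensor_model_parallel_size * sp


factor = shard_factor(ParallelState())
```

Defaulting the missing attribute to 1 keeps the sharding math well-defined when sequence parallelism is disabled, which also addresses the concern about incorrect tensor sharding in that configuration.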
No description provided.