-
Notifications
You must be signed in to change notification settings - Fork 28
support qwen3.6 grpo & in-place add lora #163
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 2 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
cccc369
add base_layer suffix for expert weights
hjh0119 457f941
qwen3.6 grpo
hjh0119 8f232d2
Merge branch 'main' into expert-lora
hjh0119 6968438
adjust gpu_memory_utilization to avoid oom
hjh0119 44969e8
reuse ipc buffer to a avoid oom
hjh0119 14b514a
fix test and base sync
hjh0119 ad53b3b
lint
hjh0119 71124df
revert base sync
hjh0119 633e9d4
fix gemini
hjh0119 ecf2229
clean comment
hjh0119 e5c87dd
merge main
hjh0119 7cf4a78
enable thinking false
hjh0119 756c7de
fix except error
hjh0119 c2dcd00
gemini
hjh0119 74547e1
wip
hjh0119 dab0901
fix qwen3.5 moe lora patch
hjh0119 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,271 @@ | ||
| """GRPO training script for GSM8K dataset. | ||
|
|
||
| Converted from the Tinker client version to Ray-based training. | ||
| Uses short reasoning format: shorter thinking gets higher format reward. | ||
| Answer extracted from \\boxed{} or #### format. | ||
| """ | ||
| import os | ||
| import re | ||
| from typing import List, Tuple, Dict, Any | ||
|
|
||
| from peft import LoraConfig | ||
|
|
||
| import twinkle | ||
| from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger | ||
| from twinkle.advantage import GRPOAdvantage | ||
| from twinkle.checkpoint_engine import CheckpointEngineManager | ||
| from twinkle.data_format import SamplingParams | ||
| from twinkle.dataloader import DataLoader | ||
| from twinkle.dataset import Dataset, DatasetMeta | ||
| from twinkle.metric import CompletionRewardMetric | ||
| from twinkle.model import TransformersModel | ||
| from twinkle.processor import InputProcessor | ||
| from twinkle.reward import GSM8KAccuracyReward | ||
| from twinkle.reward.base import Reward | ||
| from twinkle.sampler import vLLMSampler | ||
| from twinkle.preprocessor.llm import GSM8KProcessor | ||
|
|
||
| logger = get_logger() | ||
|
|
||
| # ========== Configuration ========== | ||
| MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.6-35B-A3B') | ||
| USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1'))) | ||
|
|
||
| MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) | ||
| MODEL_EP = int(os.environ.get('MODEL_EP', 2)) | ||
| MODEL_TP = int(os.environ.get('MODEL_TP', 2)) | ||
| MODEL_PP = int(os.environ.get('MODEL_PP', 2)) | ||
|
|
||
| SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) | ||
| SAMPLER_TP = int(os.environ.get('SAMPLER_TP', 2)) | ||
| NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS | ||
|
|
||
| NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8)) | ||
| MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096)) | ||
| LEARNING_RATE = float(os.environ.get('LR', 1e-5)) | ||
| MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) | ||
| BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) | ||
| MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4)) | ||
| MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 1)) | ||
| GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) | ||
| ADAPTER_NAME = 'default' | ||
| SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 1000)) | ||
| LORA_RANK = int(os.environ.get('LORA_RANK', 16)) | ||
|
|
||
| SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning ' | ||
| 'and put your final answer within \\boxed{}.') | ||
|
|
||
| # ========== Reward Functions ========== | ||
| class GSM8KBrevityReward(Reward): | ||
| """Brevity reward: rewards shorter completions that contain a valid answer. | ||
|
|
||
| Returns 0.0 if no valid answer format (\\boxed{} or ####). | ||
| Otherwise returns higher score for shorter completions (1.0 at <=200 chars). | ||
| """ | ||
|
|
||
| def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]: | ||
| rewards = [] | ||
| for traj in trajectories: | ||
| messages = traj.get('messages', []) | ||
| completion = '' | ||
| for msg in reversed(messages): | ||
| if msg.get('role') == 'assistant': | ||
| completion = msg.get('content', '') | ||
| break | ||
|
|
||
| has_answer = bool( | ||
| re.search(r'\\boxed\{[^}]+\}', completion) | ||
| or re.search(r'####\s*[\-\d,\.]+', completion) | ||
| ) | ||
|
|
||
| if not has_answer: | ||
| rewards.append(0.0) | ||
| else: | ||
| length = len(completion) | ||
| if length <= 200: | ||
| rewards.append(1.0) | ||
| else: | ||
| rewards.append(max(0.0, 1.0 - (length - 200) / 3000)) | ||
| return rewards | ||
|
|
||
|
|
||
| # ========== Dataset ========== | ||
| def create_gsm8k_dataset(): | ||
| dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) | ||
| dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=4096, truncation_strategy='delete', enable_thinking=True) | ||
| dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT)) | ||
| dataset.encode(add_generation_prompt=True) | ||
| return dataset | ||
|
|
||
|
|
||
| def compute_rewards( | ||
| trajectories: List[Dict[str, Any]], | ||
| ) -> Tuple[List[float], List[float], List[float]]: | ||
| accuracy_reward_fn = GSM8KAccuracyReward() | ||
| brevity_reward_fn = GSM8KBrevityReward() | ||
|
|
||
| accuracy_rewards = accuracy_reward_fn(trajectories) | ||
| brevity_rewards = brevity_reward_fn(trajectories) | ||
| total_rewards = [a + b for a, b in zip(accuracy_rewards, brevity_rewards)] | ||
| return total_rewards, brevity_rewards, accuracy_rewards | ||
|
|
||
|
|
||
| # ========== Main ========== | ||
| def main(): | ||
| device_groups = [ | ||
| DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'), | ||
| DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU', gpus_per_worker=SAMPLER_TP), | ||
| ] | ||
| dp_size = MODEL_GPUS // (MODEL_TP * MODEL_PP) | ||
| model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=dp_size, tp_size=MODEL_TP, pp_size=MODEL_PP, ep_size=MODEL_EP, sequence_parallel=True) | ||
| sampler_dp_size = SAMPLER_GPUS // (SAMPLER_TP) | ||
| sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=sampler_dp_size, tp_size=SAMPLER_TP) | ||
| twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False) | ||
|
|
||
| lora_config = LoraConfig( | ||
| target_modules=['all-linear'], | ||
| r=LORA_RANK, | ||
| lora_alpha=LORA_RANK * 2, | ||
| lora_dropout=0.05, | ||
| ) | ||
|
|
||
| if USE_MEGATRON: | ||
| from twinkle.model.megatron import MegatronModel | ||
| model = MegatronModel( | ||
| model_id=MODEL_ID, | ||
| device_mesh=model_mesh, | ||
| remote_group='model', | ||
| mixed_precision='bf16', | ||
| ) | ||
| else: | ||
| model = TransformersModel( | ||
| model_id=MODEL_ID, | ||
| device_mesh=model_mesh, | ||
| remote_group='model', | ||
| ) | ||
|
|
||
| model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1) | ||
| if USE_MEGATRON: | ||
| model.set_optimizer('default', lr=LEARNING_RATE) | ||
| model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE) | ||
| else: | ||
| model.set_optimizer('AdamW', lr=LEARNING_RATE) | ||
| model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0) | ||
|
|
||
| model.set_loss('GRPOLoss', epsilon=0.2) | ||
| model.set_processor(InputProcessor) | ||
| model.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=True) | ||
|
|
||
| sampler = vLLMSampler( | ||
| model_id=MODEL_ID, | ||
| engine_args={ | ||
| 'tensor_parallel_size': SAMPLER_TP, | ||
| 'gpu_memory_utilization': 0.7, | ||
| 'max_model_len': 8192, | ||
| 'max_lora_rank': LORA_RANK, # save as lora_config | ||
| # NOTE: To use enable_lora with qwen3.5, ensure vLLM includes PR https://github.com/vllm-project/vllm/pull/36976 | ||
| # enable_lora=True used with ckpt_manager.sync_weights(merge_and_sync=False) | ||
| # meaning only sync lora weights, if merge_and_sync=True, | ||
| # lora will be merged into the base model and sync all weights to vLLM | ||
| 'enable_lora': True, | ||
| 'enable_tower_connector_lora': True, | ||
| }, | ||
| device_mesh=sampler_mesh, | ||
| remote_group='sampler', | ||
| ) | ||
| sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=True) | ||
|
|
||
| ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler) | ||
|
|
||
| GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS | ||
| dataloader = DataLoader( | ||
| dataset=create_gsm8k_dataset, | ||
| batch_size=GLOBAL_BATCH_SIZE, | ||
| min_batch_size=GLOBAL_BATCH_SIZE, | ||
| device_mesh=model_mesh, | ||
| remote_group='model', | ||
| ) | ||
|
|
||
| advantage_fn = GRPOAdvantage() | ||
| metrics = CompletionRewardMetric() | ||
| sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1, temperature=1.0, top_p=0.95) | ||
|
|
||
| optim_step = 0 | ||
| logger.info('Starting GSM8K GRPO training (short reasoning)') | ||
| logger.info(get_device_placement()) | ||
|
|
||
| for batch in dataloader: | ||
| if optim_step >= MAX_STEPS: | ||
| break | ||
|
|
||
| metrics.reset() | ||
| expand_prompts = [] | ||
| for prompt in batch: | ||
| expand_prompts.extend([prompt] * NUM_GENERATIONS) | ||
|
|
||
| # enable_lora=True used with ckpt_manager.sync_weights(merge_and_sync=False) | ||
| # meaning only sync lora weights, if merge_and_sync=True, | ||
| # lora will be merged into the base model and sync all weights to vLLM | ||
| ckpt_manager.sync_weights(merge_and_sync=False) | ||
| sampler.reset_prefix_cache() | ||
|
|
||
| sample_responses = sampler.sample( | ||
| expand_prompts, | ||
| sampling_params, | ||
| ) | ||
|
|
||
| all_input_data: List[Dict[str, Any]] = [] | ||
| all_old_logps: List[List[float]] = [] | ||
| all_completion_lengths: List[int] = [] | ||
|
|
||
| for sample_response in sample_responses: | ||
| for sequence in sample_response.sequences: | ||
| all_input_data.append(sequence.new_input_feature) | ||
| all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs]) | ||
| all_completion_lengths.append(len(sequence.tokens)) | ||
|
|
||
| total_rewards, brevity_rewards, accuracy_rewards = compute_rewards(all_input_data) | ||
|
|
||
| metrics.accumulate( | ||
| completion_lengths=all_completion_lengths, | ||
| rewards={ | ||
| 'total': total_rewards, | ||
| 'brevity': brevity_rewards, | ||
| 'accuracy': accuracy_rewards, | ||
| }, | ||
| ) | ||
|
|
||
| advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist() | ||
|
|
||
| total_completions = len(all_input_data) | ||
| for mb_start in range(0, total_completions, MINI_BATCH_SIZE): | ||
| mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions) | ||
| mb_inputs = all_input_data[mb_start:mb_end] | ||
| mb_old_logps = all_old_logps[mb_start:mb_end] | ||
| mb_advantages = advantages[mb_start:mb_end] | ||
|
|
||
| model.forward_backward( | ||
| inputs=mb_inputs, | ||
| old_logps=mb_old_logps, | ||
| advantages=mb_advantages, | ||
| micro_batch_size=MICRO_BATCH_SIZE, | ||
| ) | ||
| model.clip_grad_and_step() | ||
| optim_step += 1 | ||
|
|
||
| if optim_step >= MAX_STEPS: | ||
| break | ||
| if optim_step % SAVE_STEPS == 0: | ||
| model.save(f'math-grpo-checkpoint-{optim_step}') | ||
|
|
||
| log_dict = metrics.calculate() | ||
| log_dict.update(model.calculate_metric(is_training=True)) | ||
| metrics.reset() | ||
| logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}') | ||
|
|
||
| logger.info(f'Training completed. optim_steps={optim_step}') | ||
| model.save('math-grpo-final') | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| main() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.