Skip to content
275 changes: 275 additions & 0 deletions cookbook/rl/short_math_grpo_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
"""GRPO training script for GSM8K dataset.

Converted from the Tinker client version to Ray-based training.
Uses a short-reasoning objective: shorter completions earn a higher brevity reward.
Answer extracted from \\boxed{} or #### format.
"""
import os
import re
from typing import List, Tuple, Dict, Any

from peft import LoraConfig

import twinkle
from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
from twinkle.advantage import GRPOAdvantage
from twinkle.checkpoint_engine import CheckpointEngineManager
from twinkle.data_format import SamplingParams
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.metric import CompletionRewardMetric
from twinkle.model import TransformersModel
from twinkle.processor import InputProcessor
from twinkle.reward import GSM8KAccuracyReward
from twinkle.reward.base import Reward
from twinkle.sampler import vLLMSampler
from twinkle.preprocessor.llm import GSM8KProcessor

logger = get_logger()

# ========== Configuration ==========
# Values read through os.getenv can be overridden from the environment; the
# second argument is the fallback used when the variable is not set.

# Checkpoint to train and backend switch (any non-zero integer enables Megatron).
MODEL_ID = os.getenv('MODEL_ID', 'ms://Qwen/Qwen3.6-35B-A3B')
USE_MEGATRON = int(os.getenv('USE_MEGATRON', '1')) != 0

# Trainer GPU budget and its expert / tensor / pipeline parallel sizes.
MODEL_GPUS = int(os.getenv('MODEL_GPUS', 4))
MODEL_EP = int(os.getenv('MODEL_EP', 2))
MODEL_TP = int(os.getenv('MODEL_TP', 2))
MODEL_PP = int(os.getenv('MODEL_PP', 2))

# Sampler GPU budget; the trainer and sampler GPUs together form the full pool.
SAMPLER_GPUS = int(os.getenv('SAMPLER_GPUS', 4))
SAMPLER_TP = int(os.getenv('SAMPLER_TP', 2))
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS

# Rollout and optimisation hyper-parameters (the learning rate is read from ``LR``).
NUM_GENERATIONS = int(os.getenv('NUM_GENERATIONS', 8))
MAX_NEW_TOKENS = int(os.getenv('MAX_NEW_TOKENS', 4096))
LEARNING_RATE = float(os.getenv('LR', 5e-5))
MAX_STEPS = int(os.getenv('MAX_STEPS', 1000))
BATCH_SIZE = int(os.getenv('BATCH_SIZE', 4))
MINI_BATCH_SIZE = int(os.getenv('MINI_BATCH_SIZE', 4))
MICRO_BATCH_SIZE = int(os.getenv('MICRO_BATCH_SIZE', 1))
GRADIENT_ACCUMULATION_STEPS = int(os.getenv('GRADIENT_ACCUMULATION_STEPS', 1))

# Adapter identity, checkpoint cadence and LoRA rank.
ADAPTER_NAME = 'default'
SAVE_STEPS = int(os.getenv('SAVE_STEPS', 1000))
LORA_RANK = int(os.getenv('LORA_RANK', 16))

# System message handed to the GSM8K preprocessor.
SYSTEM_PROMPT = (
    'You are a helpful math assistant. '
    'Solve the problem with minimal but correct reasoning '
    'and put your final answer within \\boxed{}.'
)

# ========== Reward Functions ==========
class GSM8KBrevityReward(Reward):
    """Brevity reward: favors shorter completions that contain a valid answer.

    The last assistant message of each trajectory is scored as follows:

    * ``0.0`` when it contains neither a ``\\boxed{...}`` answer nor a
      ``#### <number>`` answer (this includes a missing or non-string message).
    * ``1.0`` when it contains an answer and is at most ``FULL_REWARD_CHARS``
      characters long.
    * Otherwise a value that decreases linearly with the number of characters
      beyond ``FULL_REWARD_CHARS`` and is clamped at ``0.0`` once the excess
      reaches ``DECAY_CHARS``.

    Subclasses may override ``FULL_REWARD_CHARS`` / ``DECAY_CHARS`` to change
    the length budget without touching the scoring code.
    """

    # Compiled once at class creation instead of once per trajectory.
    _BOXED_RE = re.compile(r'\\boxed\{[^}]+\}')
    _HASHES_RE = re.compile(r'####\s*[\-\d,\.]+')

    # Completions up to this many characters receive the full reward.
    FULL_REWARD_CHARS = 200
    # Number of characters past the budget over which the reward decays to zero.
    DECAY_CHARS = 3000

    @staticmethod
    def _last_assistant_text(trajectory: Dict[str, Any]) -> str:
        """Return the content of the last assistant message, or '' if unusable."""
        # `or []` also covers an explicit ``messages: None`` entry.
        for message in reversed(trajectory.get('messages') or []):
            if message.get('role') == 'assistant':
                content = message.get('content', '')
                # A None / structured content cannot be regex-searched; treat it
                # as an empty completion so it scores 0.0 instead of raising.
                return content if isinstance(content, str) else ''
        return ''

    def _score(self, completion: str) -> float:
        """Score a single completion string according to the class rules."""
        has_answer = bool(self._BOXED_RE.search(completion) or self._HASHES_RE.search(completion))
        if not has_answer:
            return 0.0
        excess = len(completion) - self.FULL_REWARD_CHARS
        if excess <= 0:
            return 1.0
        return max(0.0, 1.0 - excess / self.DECAY_CHARS)

    def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
        """Compute one brevity reward per trajectory.

        Args:
            trajectories: Dicts carrying a ``messages`` list of role/content dicts.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A list of floats in ``[0.0, 1.0]``, aligned with ``trajectories``.
        """
        return [self._score(self._last_assistant_text(traj)) for traj in trajectories]


# ========== Dataset ==========
def create_gsm8k_dataset():
    """Create the GSM8K train split, templated, preprocessed and encoded.

    Returns:
        The ``Dataset`` after the template has been set, ``GSM8KProcessor`` has
        been mapped over it and it has been encoded with a generation prompt.
    """
    meta = DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')
    gsm8k = Dataset(meta)
    gsm8k.set_template(
        'Qwen3_5Template',
        model_id=MODEL_ID,
        max_length=4096,
        truncation_strategy='delete',
        enable_thinking=True,
    )
    preprocessor = GSM8KProcessor(system=SYSTEM_PROMPT)
    gsm8k.map(preprocessor)
    gsm8k.encode(add_generation_prompt=True)
    return gsm8k


def compute_rewards(
    trajectories: List[Dict[str, Any]],
) -> Tuple[List[float], List[float], List[float]]:
    """Score trajectories with the accuracy and brevity reward functions.

    Args:
        trajectories: Trajectory dicts passed unchanged to both reward functions.

    Returns:
        ``(total, brevity, accuracy)`` where ``total[i]`` is the sum of the
        accuracy and brevity rewards of trajectory ``i``.
    """
    # Both reward objects are created before either is invoked.
    scorers = (GSM8KAccuracyReward(), GSM8KBrevityReward())
    accuracy_scores, brevity_scores = [score(trajectories) for score in scorers]
    combined = [acc + brev for acc, brev in zip(accuracy_scores, brevity_scores)]
    return combined, brevity_scores, accuracy_scores


# ========== Main ==========
def main():
    """Run GRPO training on GSM8K with a trainer group and a vLLM sampler group.

    Each dataloader batch is expanded to ``NUM_GENERATIONS`` rollouts per prompt,
    the current LoRA weights are synced to the sampler, completions are sampled
    and scored, group advantages are computed, and the completions are consumed
    in mini-batches with one optimizer step per mini-batch. Training stops after
    ``MAX_STEPS`` optimizer steps (or when the dataloader is exhausted) and a
    final checkpoint is saved.

    Raises:
        ValueError: If a GPU budget is not divisible by its parallel sizes, in
            which case the derived data-parallel size would not cover the group.
    """
    # Fail fast on layouts the floor division below would silently truncate.
    model_shard = MODEL_TP * MODEL_PP
    if MODEL_GPUS % model_shard != 0:
        raise ValueError(
            f'MODEL_GPUS ({MODEL_GPUS}) must be divisible by MODEL_TP * MODEL_PP ({model_shard})')
    if SAMPLER_GPUS % SAMPLER_TP != 0:
        raise ValueError(
            f'SAMPLER_GPUS ({SAMPLER_GPUS}) must be divisible by SAMPLER_TP ({SAMPLER_TP})')

    # Trainer ranks come first, sampler ranks take the remainder of the pool.
    device_groups = [
        DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
        DeviceGroup(
            name='sampler',
            ranks=list(range(MODEL_GPUS, NUM_GPUS)),
            device_type='GPU',
            gpus_per_worker=SAMPLER_TP,
        ),
    ]
    dp_size = MODEL_GPUS // model_shard
    model_mesh = DeviceMesh.from_sizes(
        world_size=MODEL_GPUS,
        dp_size=dp_size,
        tp_size=MODEL_TP,
        pp_size=MODEL_PP,
        ep_size=MODEL_EP,
        sequence_parallel=True,
    )
    sampler_dp_size = SAMPLER_GPUS // SAMPLER_TP
    sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=sampler_dp_size, tp_size=SAMPLER_TP)
    twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)

    lora_config = LoraConfig(
        target_modules=['all-linear'],
        r=LORA_RANK,
        lora_alpha=LORA_RANK * 2,
        lora_dropout=0.05,
    )

    if USE_MEGATRON:
        # Imported lazily so the Transformers path does not need Megatron.
        from twinkle.model.megatron import MegatronModel
        model = MegatronModel(
            model_id=MODEL_ID,
            device_mesh=model_mesh,
            remote_group='model',
            mixed_precision='bf16',
        )
    else:
        model = TransformersModel(
            model_id=MODEL_ID,
            device_mesh=model_mesh,
            remote_group='model',
        )

    # NOTE(review): accumulation is fixed at 1 here; GRADIENT_ACCUMULATION_STEPS
    # only scales the dataloader batch size below - confirm this is intended.
    model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1)
    if USE_MEGATRON:
        model.set_optimizer('default', lr=LEARNING_RATE)
        model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE)
    else:
        model.set_optimizer('AdamW', lr=LEARNING_RATE)
        model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0)

    model.set_loss('GRPOLoss', epsilon=0.2)
    model.set_processor(InputProcessor)
    model.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=True)

    sampler = vLLMSampler(
        model_id=MODEL_ID,
        engine_args={
            'tensor_parallel_size': SAMPLER_TP,
            'gpu_memory_utilization': 0.7,
            'max_model_len': 10000,
            'max_lora_rank': LORA_RANK,  # same as lora_config
            # NOTE: To use enable_lora with qwen3.5, ensure vLLM includes PR
            # https://github.com/vllm-project/vllm/pull/36976
            # enable_lora=True is used with ckpt_manager.sync_weights(merge_and_sync=False),
            # meaning only the LoRA weights are synced. With merge_and_sync=True the
            # LoRA is merged into the base model and all weights are synced to vLLM.
            'enable_lora': True,
            'enable_tower_connector_lora': True,
        },
        device_mesh=sampler_mesh,
        remote_group='sampler',
    )
    sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=True)

    ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler)

    global_batch_size = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
    # The factory function itself is passed, not a constructed dataset.
    dataloader = DataLoader(
        dataset=create_gsm8k_dataset,
        batch_size=global_batch_size,
        min_batch_size=global_batch_size,
        device_mesh=model_mesh,
        remote_group='model',
    )

    advantage_fn = GRPOAdvantage()
    metrics = CompletionRewardMetric()
    sampling_params = SamplingParams(
        max_tokens=MAX_NEW_TOKENS,
        num_samples=1,
        logprobs=1,
        temperature=1.0,
        top_p=0.95,
    )

    optim_step = 0
    logger.info('Starting GSM8K GRPO training (short reasoning)')
    logger.info(get_device_placement())

    for batch in dataloader:
        if optim_step >= MAX_STEPS:
            break

        metrics.reset()
        # Repeat every prompt NUM_GENERATIONS times, keeping repeats adjacent so
        # consecutive completions belong to the same prompt group.
        expand_prompts = [prompt for prompt in batch for _ in range(NUM_GENERATIONS)]

        # LoRA-only sync (see the enable_lora note on the sampler engine args).
        ckpt_manager.sync_weights(merge_and_sync=False)
        sampler.reset_prefix_cache()

        sample_responses = sampler.sample(
            expand_prompts,
            sampling_params,
        )
        if sample_responses and sample_responses[0].sequences:
            first_decoded = sample_responses[0].sequences[0].decoded
            if isinstance(first_decoded, str):
                logger.info('[sample_debug] first_generation=%r', first_decoded[:512])

        all_input_data: List[Dict[str, Any]] = []
        all_old_logps: List[List[float]] = []
        all_completion_lengths: List[int] = []

        for sample_response in sample_responses:
            for sequence in sample_response.sequences:
                all_input_data.append(sequence.new_input_feature)
                # presumably each per-token entry starts with a (token_id, logprob)
                # pair for the sampled token - TODO confirm against the sampler output
                all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs])
                all_completion_lengths.append(len(sequence.tokens))

        total_rewards, brevity_rewards, accuracy_rewards = compute_rewards(all_input_data)

        metrics.accumulate(
            completion_lengths=all_completion_lengths,
            rewards={
                'total': total_rewards,
                'brevity': brevity_rewards,
                'accuracy': accuracy_rewards,
            },
        )

        advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist()

        total_completions = len(all_input_data)
        for mb_start in range(0, total_completions, MINI_BATCH_SIZE):
            mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions)
            mb_inputs = all_input_data[mb_start:mb_end]
            mb_old_logps = all_old_logps[mb_start:mb_end]
            mb_advantages = advantages[mb_start:mb_end]

            model.forward_backward(
                inputs=mb_inputs,
                old_logps=mb_old_logps,
                advantages=mb_advantages,
                micro_batch_size=MICRO_BATCH_SIZE,
            )
            model.clip_grad_and_step()
            optim_step += 1

            # Stop consuming this batch once the step budget is spent; the final
            # checkpoint after the loop covers this last step.
            if optim_step >= MAX_STEPS:
                break
            if optim_step % SAVE_STEPS == 0:
                model.save(f'math-grpo-checkpoint-{optim_step}')

        log_dict = metrics.calculate()
        log_dict.update(model.calculate_metric(is_training=True))
        metrics.reset()
        logger.info('[Step %s/%s] %s', optim_step, MAX_STEPS, log_dict)

    logger.info('Training completed. optim_steps=%s', optim_step)
    model.save('math-grpo-final')


if __name__ == '__main__':
main()
3 changes: 2 additions & 1 deletion src/twinkle/checkpoint_engine/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,5 @@ def _expand_keys(keys):

if not self.base_sync_done:
self.base_sync_done = True
logger.info('Base model sync completed, subsequent syncs will be LoRA-only')
if not merge_and_sync:
logger.info('Base model sync completed, subsequent syncs will be LoRA-only')
4 changes: 4 additions & 0 deletions src/twinkle/patch/vllm_moe_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ def __call__(self, model, **kwargs):
# (False, 'model.layers.0.mlp.experts.w2_weight') use mlp.experts.weight_loader

# Early return if no MOE models are supported
# expected_lora_modules : up_proj -> experts.0.up_proj
from vllm.model_executor.models.qwen3_5 import Qwen3_5MoeForConditionalGeneration
Qwen3_5MoeForConditionalGeneration.is_3d_moe_weight = False
Comment thread
hjh0119 marked this conversation as resolved.
Outdated
Comment thread
hjh0119 marked this conversation as resolved.
Outdated

if not SUPPORTED_MOE_MODELS:
return

Expand Down
Loading
Loading