Skip to content
271 changes: 271 additions & 0 deletions cookbook/rl/short_math_grpo_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
"""GRPO training script for GSM8K dataset.

Converted from the Tinker client version to Ray-based training.
Rewards short reasoning: among completions that contain an answer, shorter ones earn a higher brevity reward.
Answer extracted from \\boxed{} or #### format.
"""
import os
import re
from typing import List, Tuple, Dict, Any

from peft import LoraConfig

import twinkle
from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
from twinkle.advantage import GRPOAdvantage
from twinkle.checkpoint_engine import CheckpointEngineManager
from twinkle.data_format import SamplingParams
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.metric import CompletionRewardMetric
from twinkle.model import TransformersModel
from twinkle.processor import InputProcessor
from twinkle.reward import GSM8KAccuracyReward
from twinkle.reward.base import Reward
from twinkle.sampler import vLLMSampler
from twinkle.preprocessor.llm import GSM8KProcessor

logger = get_logger()

# ========== Configuration ==========
# All values except ADAPTER_NAME can be overridden through environment variables;
# the literal passed to os.getenv is the default used when the variable is unset.

# Model to train and backend switch (any non-zero integer selects Megatron).
MODEL_ID = os.getenv('MODEL_ID', 'ms://Qwen/Qwen3.6-35B-A3B')
USE_MEGATRON = int(os.getenv('USE_MEGATRON', '1')) != 0

# GPU count and parallelism sizes (expert / tensor / pipeline) for the training model.
MODEL_GPUS = int(os.getenv('MODEL_GPUS', '4'))
MODEL_EP = int(os.getenv('MODEL_EP', '2'))
MODEL_TP = int(os.getenv('MODEL_TP', '2'))
MODEL_PP = int(os.getenv('MODEL_PP', '2'))

# GPU count and tensor-parallel size for the sampler; NUM_GPUS is the combined total.
SAMPLER_GPUS = int(os.getenv('SAMPLER_GPUS', '4'))
SAMPLER_TP = int(os.getenv('SAMPLER_TP', '2'))
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS

# Sampling and optimization hyper-parameters (the learning rate is read from ``LR``).
NUM_GENERATIONS = int(os.getenv('NUM_GENERATIONS', '8'))
MAX_NEW_TOKENS = int(os.getenv('MAX_NEW_TOKENS', '4096'))
LEARNING_RATE = float(os.getenv('LR', '1e-5'))
MAX_STEPS = int(os.getenv('MAX_STEPS', '1000'))
BATCH_SIZE = int(os.getenv('BATCH_SIZE', '4'))
MINI_BATCH_SIZE = int(os.getenv('MINI_BATCH_SIZE', '4'))
MICRO_BATCH_SIZE = int(os.getenv('MICRO_BATCH_SIZE', '1'))
GRADIENT_ACCUMULATION_STEPS = int(os.getenv('GRADIENT_ACCUMULATION_STEPS', '1'))
ADAPTER_NAME = 'default'
SAVE_STEPS = int(os.getenv('SAVE_STEPS', '1000'))
LORA_RANK = int(os.getenv('LORA_RANK', '16'))

# System message handed to the GSM8K preprocessor.
SYSTEM_PROMPT = (
    'You are a helpful math assistant. '
    'Solve the problem with minimal but correct reasoning '
    'and put your final answer within \\boxed{}.'
)

# ========== Reward Functions ==========
class GSM8KBrevityReward(Reward):
    """Brevity reward: rewards shorter completions that contain a valid answer.

    For each trajectory the content of the last ``assistant`` message is scored:

    * ``0.0`` when it contains neither a ``\\boxed{...}`` nor a ``#### <number>``
      answer, or when the content is missing or not a string.
    * ``1.0`` when an answer is present and the completion has at most
      ``FULL_REWARD_CHARS`` characters.
    * Otherwise a linearly decaying score that reaches ``0.0`` at
      ``FULL_REWARD_CHARS + DECAY_CHARS`` characters.
    """

    # Completions up to this many characters receive the full reward.
    FULL_REWARD_CHARS = 200
    # Number of characters past FULL_REWARD_CHARS over which the reward falls to 0.
    DECAY_CHARS = 3000
    # Either accepted answer format; compiled once instead of on every trajectory.
    _ANSWER_PATTERN = re.compile(r'\\boxed\{[^}]+\}|####\s*[\-\d,\.]+')

    @staticmethod
    def _last_assistant_text(trajectory: Dict[str, Any]) -> str:
        """Return the text of the last assistant message, or ``''`` if unavailable."""
        # ``or []`` also covers an explicit ``messages: None``.
        for message in reversed(trajectory.get('messages') or []):
            if message.get('role') == 'assistant':
                content = message.get('content')
                # A ``None`` or structured (non-string) content cannot be searched
                # with a regex; treat it as an empty completion instead of raising.
                return content if isinstance(content, str) else ''
        return ''

    def _score(self, completion: str) -> float:
        """Score one completion string according to the class docstring."""
        if self._ANSWER_PATTERN.search(completion) is None:
            return 0.0
        excess = len(completion) - self.FULL_REWARD_CHARS
        if excess <= 0:
            return 1.0
        return max(0.0, 1.0 - excess / self.DECAY_CHARS)

    def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
        """Compute one brevity reward per trajectory.

        Args:
            trajectories: Dicts carrying a ``messages`` list of
                ``{'role': ..., 'content': ...}`` entries.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            A list of floats in ``[0.0, 1.0]``, aligned with ``trajectories``.
        """
        return [self._score(self._last_assistant_text(traj)) for traj in trajectories]


# ========== Dataset ==========
def create_gsm8k_dataset():
    """Build the encoded GSM8K training dataset.

    Loads the ``main`` subset / ``train`` split of ``ms://modelscope/gsm8k``,
    attaches ``Qwen3_5Template`` for ``MODEL_ID``, maps the rows through
    ``GSM8KProcessor`` with ``SYSTEM_PROMPT`` as the system message, and encodes
    them with a generation prompt appended.

    Returns:
        The prepared ``Dataset`` instance.
    """
    meta = DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')
    gsm8k = Dataset(meta)
    gsm8k.set_template(
        'Qwen3_5Template',
        model_id=MODEL_ID,
        max_length=4096,
        truncation_strategy='delete',
        enable_thinking=True,
    )
    preprocessor = GSM8KProcessor(system=SYSTEM_PROMPT)
    gsm8k.map(preprocessor)
    gsm8k.encode(add_generation_prompt=True)
    return gsm8k


def compute_rewards(
    trajectories: List[Dict[str, Any]],
) -> Tuple[List[float], List[float], List[float]]:
    """Score trajectories with the accuracy and brevity rewards.

    Args:
        trajectories: Trajectory dicts passed unchanged to both reward functions.

    Returns:
        A ``(total, brevity, accuracy)`` tuple of per-trajectory reward lists,
        where ``total`` is the element-wise sum of accuracy and brevity.
    """
    accuracy_fn, brevity_fn = GSM8KAccuracyReward(), GSM8KBrevityReward()

    accuracy_scores = accuracy_fn(trajectories)
    brevity_scores = brevity_fn(trajectories)

    combined_scores = [
        accuracy + brevity
        for accuracy, brevity in zip(accuracy_scores, brevity_scores)
    ]
    return combined_scores, brevity_scores, accuracy_scores


# ========== Main ==========
def main():
    """Train a LoRA adapter with GRPO on GSM8K using a Ray-hosted model and a vLLM sampler.

    Steps performed:

    1. Split ``NUM_GPUS`` ranks into a ``model`` group and a ``sampler`` group and
       initialize twinkle in Ray mode.
    2. Build the policy model (``MegatronModel`` when ``USE_MEGATRON`` is set,
       otherwise ``TransformersModel``) and attach the LoRA adapter, optimizer,
       LR scheduler, ``GRPOLoss``, input processor and chat template.
    3. Build the ``vLLMSampler`` with LoRA enabled and a ``CheckpointEngineManager``
       connecting model and sampler.
    4. For every dataloader batch: sync weights to the sampler, sample
       ``NUM_GENERATIONS`` completions per prompt, score them, compute group
       advantages, then take one optimizer step per mini-batch of
       ``MINI_BATCH_SIZE`` completions.

    Training ends after ``MAX_STEPS`` optimizer steps or when the dataloader is
    exhausted; the model is then saved as ``math-grpo-final``. Periodic
    checkpoints are written every ``SAVE_STEPS`` steps (disabled when
    ``SAVE_STEPS`` is not positive).
    """
    # Ranks [0, MODEL_GPUS) host the trainable model; the remaining ranks host the
    # sampler, with SAMPLER_TP GPUs per worker.
    device_groups = [
        DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
        DeviceGroup(
            name='sampler',
            ranks=list(range(MODEL_GPUS, NUM_GPUS)),
            device_type='GPU',
            gpus_per_worker=SAMPLER_TP,
        ),
    ]
    # Data-parallel size is what is left after tensor and pipeline parallelism.
    dp_size = MODEL_GPUS // (MODEL_TP * MODEL_PP)
    model_mesh = DeviceMesh.from_sizes(
        world_size=MODEL_GPUS,
        dp_size=dp_size,
        tp_size=MODEL_TP,
        pp_size=MODEL_PP,
        ep_size=MODEL_EP,
        sequence_parallel=True,
    )
    sampler_dp_size = SAMPLER_GPUS // (SAMPLER_TP)
    sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=sampler_dp_size, tp_size=SAMPLER_TP)
    twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)

    lora_config = LoraConfig(
        target_modules=['all-linear'],
        r=LORA_RANK,
        lora_alpha=LORA_RANK * 2,
        lora_dropout=0.05,
    )

    if USE_MEGATRON:
        # Imported lazily so the Transformers path does not require Megatron.
        from twinkle.model.megatron import MegatronModel
        model = MegatronModel(
            model_id=MODEL_ID,
            device_mesh=model_mesh,
            remote_group='model',
            mixed_precision='bf16',
        )
    else:
        model = TransformersModel(
            model_id=MODEL_ID,
            device_mesh=model_mesh,
            remote_group='model',
        )

    model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1)
    if USE_MEGATRON:
        model.set_optimizer('default', lr=LEARNING_RATE)
        model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE)
    else:
        model.set_optimizer('AdamW', lr=LEARNING_RATE)
        model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0)

    model.set_loss('GRPOLoss', epsilon=0.2)
    model.set_processor(InputProcessor)
    model.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=True)

    sampler = vLLMSampler(
        model_id=MODEL_ID,
        engine_args={
            'tensor_parallel_size': SAMPLER_TP,
            'gpu_memory_utilization': 0.7,
            'max_model_len': 8192,
            'max_lora_rank': LORA_RANK,  # same rank as lora_config
            # NOTE: To use enable_lora with qwen3.5, ensure vLLM includes PR
            # https://github.com/vllm-project/vllm/pull/36976
            # enable_lora=True is used together with
            # ckpt_manager.sync_weights(merge_and_sync=False), meaning only the
            # LoRA weights are synced. With merge_and_sync=True the LoRA is merged
            # into the base model and all weights are synced to vLLM.
            'enable_lora': True,
            'enable_tower_connector_lora': True,
        },
        device_mesh=sampler_mesh,
        remote_group='sampler',
    )
    sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=True)

    ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler)

    GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
    dataloader = DataLoader(
        # The factory itself is passed, not its result.
        dataset=create_gsm8k_dataset,
        batch_size=GLOBAL_BATCH_SIZE,
        min_batch_size=GLOBAL_BATCH_SIZE,
        device_mesh=model_mesh,
        remote_group='model',
    )

    advantage_fn = GRPOAdvantage()
    metrics = CompletionRewardMetric()
    sampling_params = SamplingParams(
        max_tokens=MAX_NEW_TOKENS,
        num_samples=1,
        logprobs=1,
        temperature=1.0,
        top_p=0.95,
    )

    optim_step = 0
    logger.info('Starting GSM8K GRPO training (short reasoning)')
    logger.info(get_device_placement())

    for batch in dataloader:
        if optim_step >= MAX_STEPS:
            break

        metrics.reset()
        # Repeat every prompt NUM_GENERATIONS times; each copy is sampled once
        # (num_samples=1), so a group's completions end up contiguous.
        expand_prompts = []
        for prompt in batch:
            expand_prompts.extend([prompt] * NUM_GENERATIONS)

        # Only the LoRA weights are pushed to vLLM (see the enable_lora note above).
        ckpt_manager.sync_weights(merge_and_sync=False)
        sampler.reset_prefix_cache()

        sample_responses = sampler.sample(
            expand_prompts,
            sampling_params,
        )

        all_input_data: List[Dict[str, Any]] = []
        all_old_logps: List[List[float]] = []
        all_completion_lengths: List[int] = []

        for sample_response in sample_responses:
            for sequence in sample_response.sequences:
                all_input_data.append(sequence.new_input_feature)
                # presumably each entry is a list of (token_id, logprob) pairs and
                # [0][1] picks the sampled token's logprob - TODO confirm
                all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs])
                all_completion_lengths.append(len(sequence.tokens))

        total_rewards, brevity_rewards, accuracy_rewards = compute_rewards(all_input_data)

        metrics.accumulate(
            completion_lengths=all_completion_lengths,
            rewards={
                'total': total_rewards,
                'brevity': brevity_rewards,
                'accuracy': accuracy_rewards,
            },
        )

        advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist()

        # One optimizer step per mini-batch of completions.
        total_completions = len(all_input_data)
        for mb_start in range(0, total_completions, MINI_BATCH_SIZE):
            mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions)
            mb_inputs = all_input_data[mb_start:mb_end]
            mb_old_logps = all_old_logps[mb_start:mb_end]
            mb_advantages = advantages[mb_start:mb_end]

            model.forward_backward(
                inputs=mb_inputs,
                old_logps=mb_old_logps,
                advantages=mb_advantages,
                micro_batch_size=MICRO_BATCH_SIZE,
            )
            model.clip_grad_and_step()
            optim_step += 1

            if optim_step >= MAX_STEPS:
                break
            # SAVE_STEPS <= 0 disables periodic checkpoints instead of raising
            # ZeroDivisionError on the modulo.
            if SAVE_STEPS > 0 and optim_step % SAVE_STEPS == 0:
                model.save(f'math-grpo-checkpoint-{optim_step}')

        # NOTE(review): the nesting of this logging block was reconstructed from a
        # view with indentation stripped; it is placed at per-batch level here.
        # Confirm against the original file.
        log_dict = metrics.calculate()
        log_dict.update(model.calculate_metric(is_training=True))
        metrics.reset()
        logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}')

    logger.info(f'Training completed. optim_steps={optim_step}')
    model.save('math-grpo-final')


if __name__ == '__main__':
main()
4 changes: 4 additions & 0 deletions src/twinkle/patch/vllm_moe_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ def __call__(self, model, **kwargs):
# (False, 'model.layers.0.mlp.experts.w2_weight') use mlp.experts.weight_loader

# Early return if no MOE models are supported
# expected_lora_modules : up_proj -> experts.0.up_proj
from vllm.model_executor.models.qwen3_5 import Qwen3_5MoeForConditionalGeneration
Qwen3_5MoeForConditionalGeneration.is_3d_moe_weight = False
Comment thread
hjh0119 marked this conversation as resolved.
Outdated
Comment thread
hjh0119 marked this conversation as resolved.
Outdated

if not SUPPORTED_MOE_MODELS:
return

Expand Down
25 changes: 24 additions & 1 deletion src/twinkle/sampler/vllm_sampler/vllm_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import inspect
import os
import re
import torch
import uuid
from typing import Any, Dict, List, Optional, Union
Expand Down Expand Up @@ -512,7 +513,14 @@ async def _sync_iter():
sync_id = uuid.uuid4().hex
zmq_handle = f'ipc:///tmp/twinkle-ipc-{device_uuid}-{os.getpid()}-{sync_id}.sock'

env_bucket_mb = os.environ.get('TWINKLE_VLLM_BUCKET_SIZE_MB')
if env_bucket_mb is not None:
bucket_size_mb = int(env_bucket_mb)
if bucket_size_mb <= 0:
raise ValueError(f'bucket_size_mb must be > 0, got {bucket_size_mb}')

bucket_size = bucket_size_mb << 20
lora_mode = bool(base_sync_done and peft_config)

# Create transfer buffer
buffer = None
Expand Down Expand Up @@ -575,9 +583,14 @@ async def _chain_first():
offset = 0
bucket_meta: list[dict] = []
n_weights = 0
current_expert_layer: Optional[str] = None

def _get_expert_layer_prefix(weight_name: str) -> Optional[str]:
m = re.match(r'^(.*\.mlp\.experts)\.\d+\.', weight_name)
return m.group(1) if m else None

async def _flush_bucket(is_last: bool) -> None:
nonlocal offset, bucket_meta
nonlocal offset, bucket_meta, current_expert_layer
if not bucket_meta and not is_last:
return
if buffer.device.type != 'cpu':
Expand All @@ -593,6 +606,7 @@ async def _flush_bucket(is_last: bool) -> None:
)
offset = 0
bucket_meta = []
current_expert_layer = None

async for name, weight in _chain_first():
if use_shm and weight.device.type != 'cpu':
Expand All @@ -602,6 +616,15 @@ async def _flush_bucket(is_last: bool) -> None:

weight_u8 = weight.view(-1).view(torch.uint8)
total_nbytes = int(weight_u8.numel())
expert_layer_prefix = _get_expert_layer_prefix(name) if lora_mode else None
if lora_mode and offset > 0:
# Keep each expert layer in an isolated bucket to avoid sending
# partial expert-layer weights.
if current_expert_layer != expert_layer_prefix:
await _flush_bucket(is_last=False)
if lora_mode:
current_expert_layer = expert_layer_prefix

chunk_offset = 0
while chunk_offset < total_nbytes:
if offset >= bucket_size:
Expand Down
Loading
Loading