From 0d234f3e3a83c81b7d37209994e0b994deacbb24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=87=95=E4=B8=80?= Date: Wed, 10 Jun 2026 15:16:36 +0800 Subject: [PATCH 01/14] feat: add Multi-teacher On-Policy Distillation (MOPD) support - Add MOPD rollout module (slime/rollout/mopd.py) for multi-teacher distillation - Add MOPD loss computation in megatron backend (KL divergence based) - Add MOPD-related arguments (teacher config, distillation params) - Add ray rollout integration for MOPD pipeline - Add example scripts for Qwen3.5-35B-A3B MOPD training - Add README documentation for MOPD feature - Add unit tests for MOPD functionality --- .../README.md | 226 +++++++ .../run-qwen35-35B-A3B-mopd-megatron.sh | 234 ++++++++ .../run-qwen35-35B-A3B-mopd-sglang.sh | 264 +++++++++ slime/backends/megatron_utils/actor.py | 86 ++- slime/backends/megatron_utils/data.py | 5 + slime/backends/megatron_utils/loss.py | 177 ++++++ slime/ray/placement_group.py | 5 +- slime/ray/rollout.py | 22 + slime/rollout/mopd.py | 246 ++++++++ slime/rollout/rm_hub/__init__.py | 5 + slime/utils/arguments.py | 199 +++++++ slime/utils/types.py | 1 + tests/test_mopd.py | 555 ++++++++++++++++++ 13 files changed, 2021 insertions(+), 4 deletions(-) create mode 100644 examples/multi_teacher_on_policy_distillation/README.md create mode 100644 examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-megatron.sh create mode 100644 examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-sglang.sh create mode 100644 slime/rollout/mopd.py create mode 100644 tests/test_mopd.py diff --git a/examples/multi_teacher_on_policy_distillation/README.md b/examples/multi_teacher_on_policy_distillation/README.md new file mode 100644 index 0000000000..335f9c17ed --- /dev/null +++ b/examples/multi_teacher_on_policy_distillation/README.md @@ -0,0 +1,226 @@ +# Multi-Teacher On-Policy Distillation (MOPD) Example + +This example shows how to run **multi-teacher on-policy distillation (MOPD)** using slime. MOPD extends OPD to support multiple domain-specific teachers, enabling a single student model to simultaneously learn from several experts (e.g., a math teacher and a code teacher) while using importance sampling (IS) for stable off-policy training. + +## Key Features + +- **Multi-teacher distillation**: Aggregate knowledge from multiple domain-specific teachers into a single student, with per-teacher reverse KL advantages averaged across domains. +- **Importance sampling (IS) correction**: Clipped IS weights `w_t = sg[π_θ/μ_θ]` ensure stable training when the student diverges from the sampling policy. +- **ORM combination**: Optional coefficient `α` blends reverse KL advantages with standard ORM advantages: `Â_MOPD,t = sg[log(π_domain/π_θ)] + α · Â_ORM`. +- **Two teacher modes** (same as OPD): + - **sglang**: Teachers run on external SGLang servers, teacher log-probs are obtained during rollout. + - **megatron**: Teachers are loaded directly into Megatron via `--mopd-teacher-loads`, teacher log-probs are computed during the training forward pass. + +## Algorithm + +For each teacher domain *d*, MOPD computes: + +``` +reverse_kl_d = sg[log π_d(y_t) - log π_θ(y_t)] # per-teacher reverse KL +w_t = sg[π_θ(y_t) / μ_θ(y_t)] clipped to [ε_low, ε_high] # IS weight +Â_MOPD,t = (1/D) Σ_d (reverse_kl_d + α · Â_ORM) # averaged across D teachers +L = -E[1/|y| Σ_t w_t · Â_MOPD,t · log π_θ(y_t)] # proxy policy loss +``` + +## Key Arguments + +| Argument | Description | +|----------|-------------| +| `--use-mopd` | Enable multi-teacher on-policy distillation. Mutually exclusive with `--use-opd`. | +| `--mopd-teachers` | JSON list of teacher configs, each with `name` and `domain` (required). Example: `'[{"name":"math_t","domain":"math"},{"name":"code_t","domain":"code"}]'` | +| `--mopd-teacher-loads` | Space-separated checkpoint paths for megatron-mode teachers. Must match the number of teachers in `--mopd-teachers`. | +| `--mopd-teacher-ckpt-steps` | Optional checkpoint steps for each teacher model. Must match the number of teachers. | +| `--mopd-alpha` | Coefficient for combining MOPD advantage with ORM advantage (default: 0.0). Set to 0 for pure distillation, >0 for ORM combination. | +| `--mopd-eps-low` | IS weight lower bound for clipping (default: 0.2). Weights below this are zeroed. | +| `--mopd-eps-high` | IS weight upper bound for clipping (default: 5.0). Weights above this are zeroed. | +| `--mopd-sampling-logprobs-key` | Key in rollout_data for sampling log-probs used in IS weight computation (default: `rollout_log_probs`). | + +## SGLang vs Megatron Mode + +| Mode | Teacher Location | When to use | +|------|------------------|-------------| +| `sglang` | External SGLang servers (one per teacher) | Teachers have different architecture or are too large for training GPU memory | +| `megatron` | Loaded into Megatron training process | Teachers have the same architecture as the policy/ref model | + +### SGLang Mode + +- Each teacher runs as an independent SGLang server. +- Teacher URLs are configured via the `MOPD_TEACHER_URLS` environment variable (JSON dict: `domain -> URL`) or via the `rm_url` field in each teacher config in `--mopd-teachers`. +- `--custom-rm-path slime.rollout.mopd.reward_func` and `--custom-reward-post-process-path slime.rollout.mopd.post_process_rewards` are required. +- `--rm-url` serves as a fallback URL if no per-teacher URL is configured. + +### Megatron Mode + +- Teacher models are loaded into CPU memory via `TensorBackuper` and switched to GPU for forward passes during training. +- Requires `--enable-weights-backuper` (default) for weight backup/restore. +- Each teacher must have the **same architecture** as the policy model. +- Memory note: each teacher model occupies additional CPU memory for weight backup. + +## Components + +- `slime/rollout/mopd.py` implements SGLang-mode MOPD: + - `reward_func`: queries all teacher SGLang servers concurrently, returns per-domain responses. + - `post_process_rewards`: extracts token-level teacher log-probs from responses and stores them in `sample.mopd_teacher_log_probs`. +- `slime/backends/megatron_utils/loss.py`: + - `apply_mopd_to_advantages`: computes per-teacher reverse KL, IS weights, and aggregated MOPD advantages. + - `policy_loss_function`: applies `mopd_advantages` and IS weights to the policy gradient loss. +- `run-qwen3-8B-mopd-sglang.sh`: launches SGLang teacher servers, then submits a Ray job. +- `run-qwen3-8B-mopd-megatron.sh`: uses Megatron-loaded teacher models (no external server needed). + +## Running the Example + +### Using SGLang Teachers (External Servers) + +1. Download or prepare the required checkpoints and data: +```bash +hf download Qwen/Qwen3-32B --local-dir /root/Qwen3-32B +hf download Qwen/Qwen3-Coder-32B --local-dir /root/Qwen3-Coder-32B +hf download Qwen/Qwen3-8B --local-dir /root/Qwen3-8B +hf download --repo-type dataset zhuzilin/dapo-math-17k --local-dir /root/dapo-math-17k +``` + +2. Convert student model to Megatron format: +```bash +cd /root/slime +source scripts/models/qwen3-8B.sh + +PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ + ${MODEL_ARGS[@]} \ + --hf-checkpoint /root/Qwen3-8B \ + --save /root/Qwen3-8B_torch_dist +``` + +3. Run MOPD with SGLang teachers: +```bash +bash examples/multi_teacher_on_policy_distillation/run-qwen3-8B-mopd-sglang.sh +``` + +The script will: +- Launch math and code teacher SGLang servers automatically +- Set `MOPD_TEACHER_URLS` environment variable +- Submit the training job via Ray + +### Using Megatron Teachers (No External Server) + +1. Prepare student checkpoint (same as above). + +2. Convert teacher models to Megatron format: +```bash +cd /root/slime +source scripts/models/qwen3-8B.sh # Or your teacher model config + +# Math teacher +PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ + ${MODEL_ARGS[@]} \ + --hf-checkpoint /root/Qwen3-8B-Math \ + --save /root/Qwen3-8B-Math_torch_dist + +# Code teacher +PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ + ${MODEL_ARGS[@]} \ + --hf-checkpoint /root/Qwen3-8B-Code \ + --save /root/Qwen3-8B-Code_torch_dist +``` + +> **Note**: This example uses the same model architecture for student and teachers. In practice, use **stronger** models as teachers. + +3. Edit `run-qwen3-8B-mopd-megatron.sh` to update paths: + - Change `--mopd-teacher-loads` to your teacher model paths + - Adjust `--mopd-alpha`, `--mopd-eps-low`, `--mopd-eps-high` for your task + +4. Run: +```bash +bash examples/multi_teacher_on_policy_distillation/run-qwen3-8B-mopd-megatron.sh +``` + +## Customization + +### Adding More Teachers + +Add entries to `--mopd-teachers` JSON and corresponding paths: + +```bash +--mopd-teachers '[{"name":"math_t","domain":"math"},{"name":"code_t","domain":"code"},{"name":"reason_t","domain":"reasoning"}]' +--mopd-teacher-loads /path/to/math_ckpt /path/to/code_ckpt /path/to/reasoning_ckpt +``` + +For SGLang mode, add the URL to `MOPD_TEACHER_URLS`: +```bash +export MOPD_TEACHER_URLS='{"math":"http://...","code":"http://...","reasoning":"http://..."}' +``` + +### Mixing Distillation with Task Rewards + +Set `--mopd-alpha > 0` to blend reverse KL advantages with standard ORM advantages: +```bash +--mopd-alpha 0.5 # Equal weight between distillation and task reward +--rm-type math # Required when alpha > 0: provides ORM reward signal +``` + +**Reward model requirements:** +- `--mopd-alpha 0.0` (pure distillation): No reward model needed. If `--rm-type` and `--custom-rm-path` are both unset, it defaults to `zero` (always returns 0.0). The learning signal comes entirely from the distillation KL advantages. +- `--mopd-alpha > 0` (distillation + ORM): You **must** set `--rm-type` or `--custom-rm-path`, otherwise an error is raised, because ORM advantages require a reward signal. + +### Tuning IS Weight Clipping + +The IS weight clipping bounds control the trade-off between bias and variance: +- Tighter bounds (e.g., `[0.5, 2.0]`): Lower variance but more bias +- Looser bounds (e.g., `[0.1, 10.0]`): Less bias but higher variance + +### Per-Sample Domain Routing + +By default, every sample is distilled from **all** configured teachers. For datasets where different samples belong to different domains (e.g., math problems should only learn from the math teacher, code problems from the code teacher), you can specify per-sample routing via the `mopd_domains` field in sample metadata. + +#### Data Format + +Add a `mopd_domains` field in the `metadata` of your JSONL data: + +```jsonl +{"prompt": "Solve: x^2 - 5x + 6 = 0", "metadata": {"mopd_domains": ["math"]}} +{"prompt": "Write a Python quicksort", "metadata": {"mopd_domains": ["code"]}} +{"prompt": "Explain quantum mechanics", "metadata": {"mopd_domains": ["math", "code"]}} +{"prompt": "General question"} +``` + +- `"mopd_domains": ["math"]` — only distill from the math teacher +- `"mopd_domains": ["code"]` — only distill from the code teacher +- `"mopd_domains": ["math", "code"]` — distill from both (equivalent to default) +- No `mopd_domains` field — distill from **all** teachers (backward compatible) + +For string convenience, you can also use a single string instead of a list: +```jsonl +{"prompt": "Solve: x^2 - 5x + 6 = 0", "metadata": {"mopd_domains": "math"}} +``` + +#### How It Works + +- **SGLang mode**: `reward_func` only queries the specified teacher servers, saving compute on unnecessary inference. +- **Megatron mode**: All teachers still run forward passes (no way to skip per-sample), but `apply_mopd_to_advantages` uses zero advantages for non-matching domains, effectively excluding them from the loss. + +## Differences from OPD + +| Feature | OPD | MOPD | +|---------|-----|------| +| Number of teachers | 1 | Multiple (configurable) | +| Advantage computation | `KL = log(π_T/π_θ)`, added to loss | Per-teacher reverse KL, averaged across domains | +| IS weight correction | Not included | Clipped IS weight `w_t ∈ [ε_low, ε_high]` | +| ORM combination | Via `--opd-kl-coef` | Via `--mopd-alpha` | +| Mutual exclusivity | `--use-opd` | `--use-mopd` (cannot use both) | + +## FAQ + +1. **Can I use MOPD with OPD at the same time?** + No. `--use-mopd` and `--use-opd` are mutually exclusive. Use MOPD if you need multiple teachers. + +2. **Do all teachers need to have the same architecture?** + - Megatron mode: Yes, all teachers must share the same architecture as the policy model. + - SGLang mode: No, each teacher can have a different architecture since they run on separate servers. + +3. **How much extra memory does MOPD need in megatron mode?** + Each teacher model requires CPU memory for weight backup (via `TensorBackuper`). The teacher weights are only loaded to GPU temporarily during the forward pass, then restored to CPU. Plan for `N × model_size` additional CPU memory where `N` is the number of teachers. + +4. **What happens if a teacher server fails in SGLang mode?** + The `reward_func` will log a warning and skip the failed teacher for that sample. The training will continue with remaining teachers, but the advantages will be biased. Monitor teacher server health carefully. + +5. **Why is `--group-rm` not supported with MOPD?** + MOPD's `reward_func` returns per-domain dicts (not scalar rewards), which is incompatible with the batch `group_rm` reward path. Use the default per-sample reward path (no `--group-rm`). \ No newline at end of file diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-megatron.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-megatron.sh new file mode 100644 index 0000000000..c8ded7428b --- /dev/null +++ b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-megatron.sh @@ -0,0 +1,234 @@ +#!/bin/bash + +# Multi-Teacher On-Policy Distillation (MOPD) — Single Teacher Connectivity Test +# Model: Qwen3.5-35B-A3B (MoE, 256 experts, 8 active) +# Environment: 8× H20 (143GB) +# Teacher: Same as student (self-distillation for connectivity validation only) +# Mode: Megatron (teacher loaded into CPU memory via TensorBackuper) +# +# This script is for MOPD E2E connectivity validation only. +# In production, use a DIFFERENT (stronger) model as teacher. +# +# Parallelism: TP=2, EP=8 (matches SFT config, 256 experts / 8 = 32 per GPU) +# Colocate mode: rollout and training share all 8 GPUs with offloading +# +# usage: bash examples/multi_teacher_on_policy_distillation/run-qwen3.5-35B-A3B-mopd-megatron.sh + +# ============================================================================ +# Cleanup: kill existing SGLang / Ray / Python processes +# ============================================================================ +pkill -9 sglang +sleep 3 +ray stop --force 2>/dev/null || true +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +export PYTHONBUFFERED=16 +export FLASHINFER_DISABLE_VERSION_CHECK=1 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +source "/mntfn/yanyi/code/slime/scripts/models/qwen3.5-35B-A3B.sh" + +# MOPD teachers JSON config +# Set as environment variable; arguments.py reads $MOPD_TEACHERS_JSON +# when --mopd-teachers is not provided on the command line. +# This avoids shell quoting issues when passing JSON through ray job submit. +export MOPD_TEACHERS_JSON='[{"name":"self_teacher","domain":"default"}]' + +# ============================================================================ +# Configure training arguments +# ============================================================================ + +# IMPORTANT: Before running this script, convert the HF checkpoint to Megatron +# torch_dist format: +# +# cd /mntfn/yanyi/code/slime +# source scripts/models/qwen3.5-35B-A3B.sh +# +# PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ +# ${MODEL_ARGS[@]} \ +# --hf-checkpoint /mnt4/data/open_source/Qwen3.5-35B-A3B \ +# --save /mnt4/data/open_source/Qwen3.5-35B-A3B_torch_dist + +CKPT_ARGS=( + --hf-checkpoint /mnt4/data/open_source/Qwen3.5-35B-A3B/ + --ref-load /mnt4/data/open_source/Qwen3.5-35B-A3B_torch_dist/ + --load /mnt4/data/zhixiaobao/yanyi/Qwen3.5-35B-A3B-mopd-test/ + --save /mnt4/data/zhixiaobao/yanyi/Qwen3.5-35B-A3B-mopd-test/ + --save-interval 10 +) + +ROLLOUT_ARGS=( + --prompt-data /mntfn/yanyi/dataset/train_text_user_only.jsonl + --input-key messages + --apply-chat-template + --rollout-shuffle + --rollout-batch-size 16 + --n-samples-per-prompt 1 # No need for multiple samples in pure distillation + --rollout-max-response-len 8192 + --rollout-temperature 0.8 + + --global-batch-size 16 + --balance-data + --num-epoch 1 +) + +RM_ARGS=( + # Pure distillation (mopd-alpha=0): rm-type defaults to "zero" automatically. + # No reward model needed. +) + +EVAL_ARGS=( + # No eval for connectivity test +) + +# Qwen3.5-35B-A3B with 8 GPUs (same parallelism as SFT config): +# TP=2, EP=8 (256 experts / 8 = 32 experts per GPU) +# Colocate mode: rollout and training share all 8 GPUs with offloading +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 8 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 8192 +) + +# MOPD Configuration (Megatron mode, single teacher) +# For this connectivity test, the teacher IS the same model (self-distillation). +# This validates the full MOPD pipeline: rollout → teacher log-prob → advantage → train. +# +# Key: The teacher checkpoint must be in Megatron torch_dist format. +# Since teacher = student here, we use the same torch_dist path. +# +# Memory note: The teacher model weights are backed up to CPU memory via +# TensorBackuper. For Qwen3.5-35B-A3B, expect ~70GB additional CPU RAM usage. +MOPD_ARGS=( + --advantage-estimator grpo + + # MOPD flags — single teacher + --use-mopd + # Pass JSON via env var MOPD_TEACHERS_JSON to avoid shell quoting issues with ray job submit. + # If --mopd-teachers is not set, arguments.py falls back to $MOPD_TEACHERS_JSON. + + # Teacher checkpoint = same as ref model (self-distillation for validation) + --mopd-teacher-loads /mnt4/data/open_source/Qwen3.5-35B-A3B_torch_dist/ + + # MOPD hyperparameters + --mopd-alpha 0.0 # Pure distillation, no ORM + --mopd-eps-low 0.2 # IS weight lower bound + --mopd-eps-high 5.0 # IS weight upper bound + --mopd-sampling-logprobs-key rollout_log_probs + + # Standard training flags + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 5e-7 # Conservative LR for stability + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + #--use-wandb + # --wandb-project slime-dev + # --wandb-group qwen3.5-35B-mopd-megatron + # --wandb-key ${WANDB_KEY} +) + +# SGLang rollout config (colocate mode, shares training GPUs) +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 8 # All 8 GPUs for rollout + --sglang-mem-fraction-static 0.4 # Share GPU memory with training + --sglang-ep-size 8 # Match EP=8 for MoE expert parallelism +) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash + + # MoE communication + --moe-token-dispatcher-type flex + --moe-enable-deepep + + # Colocate: rollout and training share same GPUs, with offloading + --colocate +) + +# ============================================================================ +# Launch training +# ============================================================================ + +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +export no_proxy="127.0.0.1,${MASTER_ADDR}" + +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +RUNTIME_ENV_JSON=$(python3 -c " +import json, os +env = { + 'PYTHONPATH': '/root/Megatron-LM/', + 'CUDA_DEVICE_MAX_CONNECTIONS': '1', + 'NCCL_NVLS_ENABLE': os.environ.get('HAS_NVLINK', '0'), + 'MOPD_TEACHERS_JSON': os.environ.get('MOPD_TEACHERS_JSON', '') +} +print(json.dumps({'env_vars': env})) +") + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 8 \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${MOPD_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${RM_ARGS[@]} + +# ============================================================================ +# Cleanup +# ============================================================================ +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python \ No newline at end of file diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-sglang.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-sglang.sh new file mode 100644 index 0000000000..22ab385e8f --- /dev/null +++ b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-sglang.sh @@ -0,0 +1,264 @@ +#!/bin/bash + +# Multi-Teacher On-Policy Distillation (MOPD) — Single Teacher SGLang Mode +# Model: Qwen3.5-35B-A3B (MoE, 256 experts, 8 active) +# Environment: 8× H20 (143GB) +# Layout: 4 GPUs for SGLang rollout, 4 GPUs for Megatron training +# Teacher: Same as student (self-distillation for connectivity validation only) +# Mode: SGLang (teacher runs on external SGLang server, no architecture constraint) +# +# This script is for MOPD E2E connectivity validation only. +# In production, use a DIFFERENT (stronger) model as teacher. +# +# usage: bash examples/multi_teacher_on_policy_distillation/run-qwen3.5-35B-A3B-mopd-sglang.sh + +# ============================================================================ +# Cleanup: kill existing SGLang / Ray / Python processes +# ============================================================================ +pkill -9 sglang +sleep 3 +ray stop --force 2>/dev/null || true +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +export PYTHONBUFFERED=16 + +# ============================================================================ +# 1. Configure and start teacher model server (self-distillation for testing) +# ============================================================================ +# For this connectivity test, the teacher is the same model as the student. +# In production, replace with a stronger model (e.g., Qwen3-72B or domain expert). +TEACHER_IP="127.0.0.1" +TEACHER_PORT=13141 +TEACHER_LOG_FILE="/tmp/sglang_teacher_$(head /dev/urandom | tr -dc A-Za-z0-9 | head -c 6).log" + +# Launch teacher on GPU 0-3 (4 GPUs for TP=4, or adjust TP as needed) +CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m sglang.launch_server \ + --model-path /mnt4/data/open_source/Qwen3.5-35B-A3B/ \ + --host 0.0.0.0 \ + --port $TEACHER_PORT \ + --tp 4 \ + --ep-size 4 \ + --chunked-prefill-size 4096 \ + --mem-fraction-static 0.7 \ + > "$TEACHER_LOG_FILE" 2>&1 & + +TEACHER_PID=$! +echo "Starting teacher model server (PID: $TEACHER_PID)..." + +# Wait for teacher server to be ready +until curl -sf http://$TEACHER_IP:$TEACHER_PORT/health_generate > /dev/null; do + echo "Waiting for teacher model server to start..." + tail -n 10 "$TEACHER_LOG_FILE" 2>/dev/null || true + sleep 10 +done +echo "Teacher model server is up and running at $TEACHER_IP:$TEACHER_PORT." + +# ============================================================================ +# 2. Set environment variables +# ============================================================================ + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +source "/mntfn/yanyi/code/slime/scripts/models/qwen3.5-35B-A3B.sh" + +# MOPD teachers JSON config +export MOPD_TEACHERS_JSON='[{"name":"self_teacher","domain":"default"}]' + +# MOPD teacher URLs +export MOPD_TEACHER_URLS="{\"default\":\"http://$TEACHER_IP:$TEACHER_PORT/generate\"}" + +# ============================================================================ +# 3. Configure training arguments +# ============================================================================ + +# IMPORTANT: Before running this script, convert the HF checkpoint to Megatron +# torch_dist format: +# +# cd /mntfn/yanyi/code/slime +# source scripts/models/qwen3.5-35B-A3B.sh +# +# PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ +# ${MODEL_ARGS[@]} \ +# --hf-checkpoint /mnt4/data/open_source/Qwen3.5-35B-A3B \ +# --save /mnt4/data/open_source/Qwen3.5-35B-A3B_torch_dist + +CKPT_ARGS=( + --hf-checkpoint /mnt4/data/open_source/Qwen3.5-35B-A3B/ + --ref-load /mnt4/data/open_source/Qwen3.5-35B-A3B_torch_dist/ + --load /mnt4/data/zhixiaobao/yanyi/Qwen3.5-35B-A3B-mopd-test/ + --save /mnt4/data/zhixiaobao/yanyi/Qwen3.5-35B-A3B-mopd-test/ + --save-interval 10 +) + +ROLLOUT_ARGS=( + --prompt-data /mntfn/yanyi/dataset/train_text_user_only.jsonl + --input-key messages + --apply-chat-template + --rollout-shuffle + --num-rollout 10 # Small for connectivity test + --rollout-batch-size 16 + --n-samples-per-prompt 1 # No need for multiple samples in pure distillation + --rollout-max-response-len 8192 + --rollout-temperature 0.8 + + --global-batch-size 16 + --balance-data +) + +# For MOPD SGLang mode, we use the MOPD reward_func and post_process_rewards +# The --rm-url is used as the default/fallback URL; per-teacher URLs come from MOPD_TEACHER_URLS env var +RM_ARGS=( + --custom-rm-path slime.rollout.mopd.reward_func + --custom-reward-post-process-path slime.rollout.mopd.post_process_rewards + --rm-url http://$TEACHER_IP:$TEACHER_PORT/generate +) + +EVAL_ARGS=( + # No eval for connectivity test +) + +# Qwen3.5-35B-A3B with 4 GPUs for training: +# TP=2, EP=4 (256 experts / 4 = 64 experts per GPU) +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 4 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 8192 +) + +# MOPD Configuration (SGLang mode, single teacher) +# In SGLang mode, teacher log-probs are obtained by querying the teacher SGLang server +# during rollout. No teacher model is loaded into Megatron training memory. +MOPD_ARGS=( + --advantage-estimator grpo + + # MOPD flags — single teacher + --use-mopd + # Note: --mopd-teachers is read from $MOPD_TEACHERS_JSON env var (see above) + # to avoid shell quoting issues with JSON in ray job submit. + + # No --mopd-teacher-loads needed in SGLang mode! + # Teacher log-probs come from the SGLang server via reward_func. + + # MOPD hyperparameters + --mopd-alpha 0.0 # Pure distillation, no ORM + --mopd-eps-low 0.2 # IS weight lower bound + --mopd-eps-high 5.0 # IS weight upper bound + --mopd-sampling-logprobs-key rollout_log_probs + + # Standard training flags + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 5e-7 # Conservative LR for stability + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + #--use-wandb + # --wandb-project slime-dev + # --wandb-group qwen3.5-35B-mopd-sglang + # --wandb-key ${WANDB_KEY} +) + +# SGLang rollout config: 4 GPUs for rollout +SGLANG_ARGS=( + --rollout-num-gpus 4 # 4 GPUs for SGLang rollout engine + --rollout-num-gpus-per-engine 4 # 4 GPUs per engine (TP=4 for Qwen3.5-35B-A3B) + --sglang-mem-fraction-static 0.7 + --sglang-ep-size 4 # Match EP=4 for MoE expert parallelism +) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash + + # MoE communication + --moe-token-dispatcher-type flex + --moe-enable-deepep +) + +# ============================================================================ +# 4. Launch training +# ============================================================================ + +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +export no_proxy="127.0.0.1,${MASTER_ADDR}" + +# 8 GPUs total: 4 for SGLang rollout (GPU 0-3, already used by teacher server), +# 4 for Megatron training (GPU 4-7) +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +RUNTIME_ENV_JSON=$(python3 -c " +import json, os +env = { + 'PYTHONPATH': '/root/Megatron-LM/', + 'CUDA_DEVICE_MAX_CONNECTIONS': '1', + 'NCCL_NVLS_ENABLE': os.environ.get('HAS_NVLINK', '0'), + 'MOPD_TEACHER_URLS': os.environ.get('MOPD_TEACHER_URLS', ''), + 'MOPD_TEACHERS_JSON': os.environ.get('MOPD_TEACHERS_JSON', '') +} +print(json.dumps({'env_vars': env})) +") + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 4 \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${MOPD_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${RM_ARGS[@]} + +# ============================================================================ +# 5. Cleanup +# ============================================================================ +kill $TEACHER_PID 2>/dev/null || true +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python \ No newline at end of file diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index 586f88dedb..63ce2ed7b0 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -1,3 +1,4 @@ +import json import logging import os import random @@ -108,9 +109,20 @@ def init( self.load_other_checkpoint("ref", args.ref_load) # Load teacher model for Megatron-based on-policy distillation - if with_opd_teacher: + if with_opd_teacher and not getattr(args, "use_mopd", False): self.load_other_checkpoint("teacher", args.opd_teacher_load) + # Load multiple teacher models for Megatron-based MOPD + self._mopd_teacher_domains: list[str] = [] + if getattr(args, "use_mopd", False) and getattr(args, "mopd_teacher_loads", None): + mopd_teachers = json.loads(args.mopd_teachers) if isinstance(args.mopd_teachers, str) else args.mopd_teachers + for i, teacher_cfg in enumerate(mopd_teachers): + domain = teacher_cfg["domain"] + tag = f"mopd_teacher_{domain}" + self._mopd_teacher_domains.append(domain) + self.load_other_checkpoint(tag, args.mopd_teacher_loads[i]) + logger.info(f"Loaded MOPD teacher model for domain '{domain}' from {args.mopd_teacher_loads[i]}") + if self.args.keep_old_actor: # Load old_actor checkpoint self.load_other_checkpoint("old_actor", args.load) @@ -133,12 +145,15 @@ def init( quantization_config=getattr(self.hf_config, "quantization_config", None), ) + # Ensure actor weights are on GPU and _active_model_tag is correct + # after loading ref/teacher/mopd_teacher/old_actor checkpoints. + if self._active_model_tag != "actor": + self._switch_model("actor") + # empty cache after initialization clear_memory() if self.args.offload_train: - # recover to actor in the end. - self._switch_model("actor") self.sleep() self.rollout_engines = None @@ -252,6 +267,39 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch: ) ) ] + + # Process MOPD teacher log_probs (dict: domain -> list) + # Some entries may be None due to per-sample domain routing (SGLang mode). + if "mopd_teacher_log_probs" in rollout_data: + mopd_lp_dict = rollout_data["mopd_teacher_log_probs"] + processed = {} + for domain, lp_list in mopd_lp_dict.items(): + processed[domain] = [ + ( + None + if log_prob is None + else torch.tensor( + slice_log_prob_with_cp( + log_prob, + total_length, + response_length, + self.args.qkv_format, + rollout_data["max_seq_lens"][i] if self.args.qkv_format == "bshd" else None, + ), + device=torch.cuda.current_device(), + dtype=torch.float32, + ) + ) + for i, (log_prob, total_length, response_length) in enumerate( + zip( + lp_list, + rollout_data["total_lengths"], + rollout_data["response_lengths"], + strict=False, + ) + ) + ] + rollout_data["mopd_teacher_log_probs"] = processed if "rollout_routed_experts" in rollout_data: rollout_data["rollout_routed_experts"] = [ torch.from_numpy(r) for r in rollout_data["rollout_routed_experts"] @@ -439,6 +487,27 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch, external_data ) ) + # Forward each MOPD teacher model for Megatron-based MOPD + if getattr(self.args, "use_mopd", False) and hasattr(self, "_mopd_teacher_domains") and self._mopd_teacher_domains: + mopd_teacher_log_probs = {} + for domain in self._mopd_teacher_domains: + tag = f"mopd_teacher_{domain}" + if tag in self.weights_backuper.backup_tags: + if self.args.use_routing_replay: + os.environ["ROUTING_REPLAY_STAGE"] = "fallthrough" + self._switch_model(tag) + teacher_result = self.compute_log_prob( + data_iterator, + num_microbatches, + store_prefix=f"mopd_teacher_{domain}_", + ) + # Store with domain-specific key + lp_key = f"mopd_teacher_{domain}_log_probs" + if lp_key in teacher_result: + mopd_teacher_log_probs[domain] = teacher_result[lp_key] + if mopd_teacher_log_probs: + rollout_data["mopd_teacher_log_probs"] = mopd_teacher_log_probs + self._switch_model("old_actor" if self.args.keep_old_actor else "actor") can_reuse_log_probs_in_loss = ( len(num_microbatches) == 1 @@ -449,6 +518,7 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch, external_data and not self.args.use_critic and not self.args.keep_old_actor and not self.args.use_opd + and not getattr(self.args, "use_mopd", False) and not self.args.use_routing_replay and self.args.advantage_estimator != "gspo" ) @@ -631,6 +701,16 @@ def load_other_checkpoint(self, model_tag: str, path: str) -> None: elif model_tag == "teacher" and self.args.opd_teacher_ckpt_step is not None: old_ckpt_step = self.args.ckpt_step self.args.ckpt_step = self.args.opd_teacher_ckpt_step + elif model_tag.startswith("mopd_teacher_"): + # MOPD teacher checkpoint step: look up from mopd_teacher_ckpt_steps by domain + domain = model_tag[len("mopd_teacher_"):] + if getattr(self.args, "mopd_teacher_ckpt_steps", None) is not None: + mopd_teachers = json.loads(self.args.mopd_teachers) if isinstance(self.args.mopd_teachers, str) else self.args.mopd_teachers + for i, t in enumerate(mopd_teachers): + if t["domain"] == domain and i < len(self.args.mopd_teacher_ckpt_steps): + old_ckpt_step = self.args.ckpt_step + self.args.ckpt_step = self.args.mopd_teacher_ckpt_steps[i] + break _, _ = load_checkpoint( self.model, diff --git a/slime/backends/megatron_utils/data.py b/slime/backends/megatron_utils/data.py index 19db1f475a..7a0b0fd0b7 100644 --- a/slime/backends/megatron_utils/data.py +++ b/slime/backends/megatron_utils/data.py @@ -416,6 +416,9 @@ def log_rollout_data( "rollout_routed_experts", "max_seq_lens", "dynamic_global_batch_size", + # Dict-typed keys that cannot be averaged directly + "mopd_teacher_log_probs", + "mopd_reverse_kl", ]: continue # Upload per sample mean for each rollout value @@ -434,6 +437,8 @@ def log_rollout_data( "values", "teacher_log_probs", "opd_reverse_kl", + "mopd_advantages", + "mopd_is_weights", ]: val = torch.cat(val).clone().detach() sum_of_sample_mean = get_sum_of_sample_mean( diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 256338fca2..10c2be716f 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -568,6 +568,145 @@ def apply_opd_kl_to_advantages( rollout_data["opd_reverse_kl"] = reverse_kls +def apply_mopd_to_advantages( + args: Namespace, + rollout_data: RolloutBatch, + advantages: list[torch.Tensor], + student_log_probs: list[torch.Tensor] | None, +) -> None: + """Apply Multi-Teacher On-Policy Distillation (MOPD) to advantages. + + MOPD computes per-teacher reverse KL advantages and importance sampling weights, + then applies a weighted proxy loss. The core formulas are: + + Â_MOPD,t = sg[log(π_domain(y_t|x,y_ 0`, the ORM advantage is combined: + Â_MOPD,t = sg[log(π_domain/π_θ)] + α * Â_ORM + + Args: + args: Configuration containing `use_mopd`, `mopd_alpha`, `mopd_eps_low`, `mopd_eps_high`, + and `mopd_sampling_logprobs_key`. + rollout_data: Dict containing "mopd_teacher_log_probs" (dict: domain -> list[Tensor]) + and optionally the sampling log-probs key. + advantages: List of advantage tensors to modify in-place. + student_log_probs: List of student (training) log-probability tensors. + """ + + if student_log_probs is None: + return + + mopd_teacher_log_probs: dict[str, list[torch.Tensor]] = rollout_data.get("mopd_teacher_log_probs") + if not mopd_teacher_log_probs: + raise ValueError( + "MOPD requires mopd_teacher_log_probs in rollout_data, but it is missing. " + "Ensure teacher log-probs are collected during rollout or training." + ) + + # Get sampling log-probs μ_θ for importance sampling weight + sampling_logprobs_key = args.mopd_sampling_logprobs_key + sampling_log_probs = rollout_data.get(sampling_logprobs_key) + if sampling_log_probs is None and sampling_logprobs_key == "rollout_log_probs": + # Fall back to old_log_probs (which may be rollout_log_probs depending on config) + sampling_log_probs = rollout_data.get("log_probs") + if sampling_log_probs is None: + raise ValueError( + f"MOPD requires '{sampling_logprobs_key}' in rollout_data for importance sampling, " + f"but it is missing." + ) + + device = student_log_probs[0].device + sampling_log_probs = [s.to(device=device) for s in sampling_log_probs] + + # Compute MOPD advantages from each teacher and aggregate + # For each teacher, compute reverse KL and IS weights, then sum weighted advantages + all_mopd_advantages = [] + all_is_weights_list = [] + all_reverse_kls = [] + + for domain, teacher_lp_list in mopd_teacher_log_probs.items(): + domain_advantages = [] + domain_is_weights = [] + domain_reverse_kls = [] + + for i in range(len(advantages)): + # If this sample has no teacher log-probs for this domain (per-sample routing), + # use zeros as placeholder — this domain contributes nothing to this sample. + if teacher_lp_list[i] is None: + domain_advantages.append(None) + domain_is_weights.append(None) + domain_reverse_kls.append(None) + continue + + teacher_lp = teacher_lp_list[i].to(device=device) + + # reverse_kl = log(π_domain(y_t)) - log(π_θ(y_t)), with stop-gradient + # student_log_probs here is π_θ (the training engine log-probs) + with torch.no_grad(): + reverse_kl = teacher_lp - student_log_probs[i] + + # Importance sampling weight: w_t = π_θ(y_t) / μ_θ(y_t) + # = exp(student_log_probs[i] - sampling_log_probs[i]) + is_weight = torch.exp(student_log_probs[i] - sampling_log_probs[i]) + + # Zero out weights outside [eps_low, eps_high] + is_weight = torch.where( + (is_weight >= args.mopd_eps_low) & (is_weight <= args.mopd_eps_high), + is_weight, + torch.zeros_like(is_weight), + ) + + # MOPD advantage: Â_MOPD,t = reverse_kl + α * Â_ORM + mopd_adv = reverse_kl + if args.mopd_alpha > 0: + mopd_adv = reverse_kl + args.mopd_alpha * advantages[i] + + domain_advantages.append(mopd_adv) + domain_is_weights.append(is_weight) + domain_reverse_kls.append(reverse_kl) + + all_mopd_advantages.append(domain_advantages) + all_is_weights_list.append(domain_is_weights) + all_reverse_kls.append(domain_reverse_kls) + + # Aggregate across teachers: average the weighted advantages + # For each sample, only average over domains that have valid (non-None) entries. + # This supports per-sample domain routing where different samples may use different teachers. + aggregated_mopd_advantages = [] + aggregated_is_weights = [] + + for i in range(len(advantages)): + # Collect valid (non-None) teacher contributions for this sample + valid_advs = [all_mopd_advantages[t][i] for t in range(len(all_mopd_advantages)) if all_mopd_advantages[t][i] is not None] + valid_is = [all_is_weights_list[t][i] for t in range(len(all_is_weights_list)) if all_is_weights_list[t][i] is not None] + + if len(valid_advs) == 0: + # No valid teachers for this sample — use zero advantages and zero IS weights + aggregated_mopd_advantages.append(torch.zeros_like(advantages[i])) + aggregated_is_weights.append(torch.zeros_like(advantages[i])) + else: + avg_adv = sum(valid_advs) / len(valid_advs) + avg_is_weight = sum(valid_is) / len(valid_is) + aggregated_mopd_advantages.append(avg_adv) + aggregated_is_weights.append(avg_is_weight) + + # Store MOPD data for use in policy_loss_function + rollout_data["mopd_advantages"] = aggregated_mopd_advantages + rollout_data["mopd_is_weights"] = aggregated_is_weights + + # Also store per-teacher reverse KL for logging + # Use zeros for samples that don't have this domain (per-sample routing) + per_teacher_reverse_kl = {} + for t_idx, domain in enumerate(mopd_teacher_log_probs.keys()): + per_teacher_reverse_kl[domain] = [ + all_reverse_kls[t_idx][i] if all_reverse_kls[t_idx][i] is not None else torch.zeros_like(advantages[i]) + for i in range(len(advantages)) + ] + rollout_data["mopd_reverse_kl"] = per_teacher_reverse_kl + + def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) -> None: """Compute advantages and returns in-place based on `args.advantage_estimator`. @@ -684,6 +823,15 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) student_log_probs=log_probs, ) + # Apply Multi-Teacher On-Policy Distillation (MOPD) to advantages + if args.use_mopd: + apply_mopd_to_advantages( + args=args, + rollout_data=rollout_data, + advantages=advantages, + student_log_probs=log_probs, + ) + # TODO: OpenRLHF always does advantages normalization but veRL doesn't seem to do it. if args.normalize_advantages: all_advs = torch.cat(advantages) @@ -897,6 +1045,18 @@ def policy_loss_function( if args.use_opsm: pg_loss = pg_loss * opsm_mask + # Apply MOPD: replace advantages with mopd_advantages and apply IS weights + # L_MOPD(θ) = -E[1/|y| Σ_t w_t * Â_MOPD,t * log π_θ(y_t|x,y_ list[float]) + if samples[0].mopd_teacher_log_probs is not None: + # Collect all domains across all samples (may differ due to per-sample routing) + all_domains = set() + for sample in samples: + if sample.mopd_teacher_log_probs: + all_domains.update(sample.mopd_teacher_log_probs.keys()) + mopd_teacher_log_probs = {} + for domain in all_domains: + # Use None as placeholder for samples without this domain + mopd_teacher_log_probs[domain] = [ + sample.mopd_teacher_log_probs.get(domain) if sample.mopd_teacher_log_probs else None + for sample in samples + ] + train_data["mopd_teacher_log_probs"] = mopd_teacher_log_probs + return train_data def set_train_parallel_config(self, config: dict): @@ -790,6 +806,12 @@ def _split_train_data_by_dp(self, data, dp_size): continue val = [data[key][j] for j in partition] rollout_data[key] = val + # Handle mopd_teacher_log_probs (dict: domain -> list[list[float]]) + if "mopd_teacher_log_probs" in data: + mopd_lp_dict = {} + for domain, lp_list in data["mopd_teacher_log_probs"].items(): + mopd_lp_dict[domain] = [lp_list[j] for j in partition] + rollout_data["mopd_teacher_log_probs"] = mopd_lp_dict # keys that need to be splited at train side for key in [ "raw_reward", diff --git a/slime/rollout/mopd.py b/slime/rollout/mopd.py new file mode 100644 index 0000000000..ac49d505cc --- /dev/null +++ b/slime/rollout/mopd.py @@ -0,0 +1,246 @@ +"""Multi-Teacher On-Policy Distillation (MOPD) rollout support for SGLang. + +This module provides reward_func and post_process_rewards for fetching log-probs +from multiple domain-specific teacher SGLang servers. Each teacher is identified +by a domain name and has its own rm_url. + +Usage: + --use-mopd + --mopd-teachers '[{"name": "math_teacher", "domain": "math"}, {"name": "code_teacher", "domain": "code"}]' + --custom-rm-path slime.rollout.mopd.reward_func + --custom-reward-post-process-path slime.rollout.mopd.post_process_rewards + +The teacher rm_urls are configured via --mopd-teachers JSON, where each entry +can contain an optional "rm_url" field. Alternatively, they can be specified +via the MOPD_TEACHER_URLS environment variable as a JSON dict mapping domain -> URL. +""" + +import asyncio +import json +import logging +import os + +import aiohttp +import torch + +from slime.utils.processing_utils import encode_image_for_rollout_engine +from slime.utils.types import Sample + +logger = logging.getLogger(__name__) + + +def _get_mopd_teacher_configs(args) -> list[dict]: + """Parse MOPD teacher configurations from args. + + Returns: + List of teacher config dicts, each containing at least 'name' and 'domain'. + May also contain 'rm_url' for SGLang mode. + """ + teachers_str = args.mopd_teachers + if isinstance(teachers_str, str): + return json.loads(teachers_str) + return teachers_str + + +def _build_payload(sample): + """Build the SGLang request payload for log-prob extraction.""" + payload = { + "input_ids": sample.tokens, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 0, + "skip_special_tokens": False, + }, + "return_logprob": True, + "logprob_start_len": 0, + } + + if sample.multimodal_inputs and sample.multimodal_inputs.get("images"): + image_data = sample.multimodal_inputs["images"] + payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data] + + return payload + + +async def _fetch_teacher_logprobs(session: aiohttp.ClientSession, rm_url: str, payload: dict) -> dict: + """Fetch log-probs from a single teacher SGLang server.""" + async with session.post(rm_url, json=payload) as resp: + resp.raise_for_status() + return await resp.json() + + +def _resolve_teacher_urls(args, teacher_configs: list[dict]) -> dict[str, str]: + """Resolve rm_url for each teacher domain. + + Priority: + 1. 'rm_url' field in the teacher config + 2. MOPD_TEACHER_URLS environment variable + 3. Fallback: args.rm_url (all teachers share the same URL) + """ + env_urls = {} + env_urls_str = os.environ.get("MOPD_TEACHER_URLS", "") + if env_urls_str: + env_urls = json.loads(env_urls_str) + + url_map = {} + for teacher_cfg in teacher_configs: + domain = teacher_cfg["domain"] + rm_url = teacher_cfg.get("rm_url") or env_urls.get(domain) + if rm_url is None: + rm_url = args.rm_url + url_map[domain] = rm_url + + return url_map + + +def _get_sample_domains(sample, all_domains: list[str]) -> list[str] | None: + """Get the list of teacher domains that should be queried for this sample. + + If sample.metadata contains a 'mopd_domains' key, return those domains + (filtered to only include valid configured domains). + Otherwise, return None to indicate all domains should be queried. + + When there is only one configured domain, always returns None since + routing is unnecessary — all samples must use the single teacher. + + Args: + sample: The sample to check. + all_domains: List of all configured domain names. + + Returns: + List of domain names to query, or None to query all. + """ + # With only one teacher, routing is unnecessary — always query the only domain + if len(all_domains) <= 1: + return None + + metadata = sample.metadata if isinstance(sample.metadata, dict) else {} + sample_domains = metadata.get("mopd_domains") + if sample_domains is None: + return None # Query all domains + + if isinstance(sample_domains, str): + sample_domains = [sample_domains] + + # Filter to only include valid configured domains + valid_domains = [d for d in sample_domains if d in all_domains] + if not valid_domains: + logger.warning( + f"Sample has mopd_domains={sample_domains} but none match configured domains {all_domains}. " + f"Falling back to all domains." + ) + return None + + return valid_domains + + +async def _reward_func_single(args, sample, **kwargs): + """Query MOPD teacher servers for a single sample. + + If sample.metadata contains 'mopd_domains' (a list of domain names or a single + string), only the specified teachers are queried. Otherwise, all teachers are queried. + + Returns: + dict mapping domain -> raw teacher response (JSON from SGLang). + This dict is stored in sample.reward and later processed by post_process_rewards. + """ + teacher_configs = _get_mopd_teacher_configs(args) + url_map = _resolve_teacher_urls(args, teacher_configs) + all_domains = list(url_map.keys()) + + # Determine which domains to query for this sample + target_domains = _get_sample_domains(sample, all_domains) + if target_domains is not None: + url_map = {d: url_map[d] for d in target_domains} + + payload = _build_payload(sample) + + results = {} + + async with aiohttp.ClientSession() as session: + tasks = [] + domains = [] + for domain, rm_url in url_map.items(): + domains.append(domain) + tasks.append(_fetch_teacher_logprobs(session, rm_url, payload)) + + responses = await asyncio.gather(*tasks, return_exceptions=True) + + for domain, resp in zip(domains, responses): + if isinstance(resp, Exception): + logger.warning( + f"MOPD teacher '{domain}' failed: {resp}. Skipping this teacher." + ) + continue + results[domain] = resp + + return results + + +async def reward_func(args, sample_or_samples, **kwargs): + """Query all MOPD teacher servers for the given sample(s). + + Supports both per-sample and batch calling conventions: + - When called via async_rm: receives a single Sample, returns a dict + (domain -> raw teacher response). + - When called via batched_async_rm: receives a list of Samples, returns + a list of dicts (one per sample). + + The rm_url for each teacher is determined from: + 1. The 'rm_url' field in the teacher config (if present) + 2. The MOPD_TEACHER_URLS environment variable + 3. Fallback: args.rm_url + """ + if isinstance(sample_or_samples, list): + # Batch mode: called from batched_async_rm with a list of samples + tasks = [_reward_func_single(args, s, **kwargs) for s in sample_or_samples] + return await asyncio.gather(*tasks) + else: + # Single sample mode: called from async_rm + return await _reward_func_single(args, sample_or_samples, **kwargs) + + +def post_process_rewards(args, samples: list[Sample], **kwargs): + """Process MOPD teacher responses and extract per-domain teacher log-probs. + + This function: + 1. Extracts log-probs from each teacher server response + 2. Stores them in sample.mopd_teacher_log_probs[domain] + 3. Returns scalar rewards compatible with GRPO/PPO + + The raw_rewards for each sample is expected to be a dict mapping domain -> response, + as returned by mopd.reward_func. + """ + raw_rewards = [sample.get_reward_value(args) for sample in samples] + response_lengths = [sample.response_length for sample in samples] + + for sample, reward_val, response_length in zip(samples, raw_rewards, response_lengths, strict=False): + if sample.mopd_teacher_log_probs is None: + sample.mopd_teacher_log_probs = {} + + if not isinstance(reward_val, dict): + # If reward_func didn't return a dict (e.g., fallback case), skip + continue + + for domain, teacher_response in reward_val.items(): + try: + # Extract log-probs from sglang response format + log_probs = torch.tensor( + [item[0] for item in teacher_response["meta_info"]["input_token_logprobs"][1:]], + dtype=torch.float32, + ) + # Trim to response length + log_probs = log_probs[-response_length:] + sample.mopd_teacher_log_probs[domain] = log_probs + except (KeyError, IndexError, TypeError) as e: + logger.warning( + f"MOPD: Failed to extract log-probs for domain '{domain}': {e}" + ) + + # Return scalar rewards for GRPO/PPO advantage estimator + # For pure MOPD distillation, we use 0.0 as the task reward. + # The learning signal comes from the MOPD advantage applied in compute_advantages_and_returns. + # If you have task rewards, configure them separately via reward model. + scalar_rewards = [0.0] * len(samples) + + return scalar_rewards, scalar_rewards \ No newline at end of file diff --git a/slime/rollout/rm_hub/__init__.py b/slime/rollout/rm_hub/__init__.py index 0991e559e5..8296641a38 100644 --- a/slime/rollout/rm_hub/__init__.py +++ b/slime/rollout/rm_hub/__init__.py @@ -85,6 +85,11 @@ async def async_rm(args, sample: Sample, **kwargs): return compute_ifbench_reward(response, label, metadata=metadata) elif rm_type == "random": return random.randint(0, 1) + elif rm_type == "zero": + # Always return 0.0 — useful for pure distillation (e.g., MOPD with alpha=0) + # where no task reward is needed and the learning signal comes entirely from + # the distillation KL advantages. + return 0.0 elif rm_type: raise NotImplementedError(f"Rule-based RM for {rm_type} is not implemented.") else: diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index e8a1730782..602a61656c 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -1011,6 +1011,93 @@ def add_on_policy_distillation_arguments(parser): parser.add_argument( "--opd-teacher-ckpt-step", type=int, default=None, help="The checkpoint step for OPD teacher model." ) + + # --- MOPD (Multi-Teacher On-Policy Distillation) arguments --- + parser.add_argument( + "--use-mopd", + action="store_true", + default=False, + help=( + "Enable Multi-Teacher On-Policy Distillation (MOPD). " + "MOPD extends OPD by distilling from multiple domain-specific teachers with " + "importance sampling correction. Mutually exclusive with --use-opd." + ), + ) + parser.add_argument( + "--mopd-teachers", + type=str, + default=None, + help=( + "JSON configuration for multiple MOPD teacher models. " + 'Format: [{"name": "math_teacher", "domain": "math"}, ...]. ' + "Each entry defines a teacher with a unique domain identifier. " + "For sglang mode, teacher log-probs are fetched from --rm-url; " + "for megatron mode, --mopd-teacher-loads provides checkpoint paths." + ), + ) + parser.add_argument( + "--mopd-teacher-loads", + type=str, + nargs="+", + default=None, + help=( + "List of Megatron checkpoint paths for MOPD teacher models. " + "Must match the number of teachers in --mopd-teachers. " + "Required when --use-mopd is enabled with megatron-based teachers." + ), + ) + parser.add_argument( + "--mopd-teacher-ckpt-steps", + type=int, + nargs="+", + default=None, + help=( + "List of checkpoint steps for MOPD teacher models. " + "Must match the number of teachers in --mopd-teachers. " + "If not specified, the latest checkpoint is used for each teacher." + ), + ) + parser.add_argument( + "--mopd-alpha", + type=float, + default=0.0, + help=( + "Coefficient alpha for combining MOPD advantage with ORM advantage. " + "Â_MOPD,t = sg[log(π_domain/π_θ)] + α * Â_ORM. " + "Default is 0.0 (no ORM combination)." + ), + ) + parser.add_argument( + "--mopd-eps-low", + type=float, + default=0.2, + help=( + "Lower bound for importance sampling weight clipping in MOPD. " + "Tokens with w_t = π_θ(y_t)/μ_θ(y_t) below this threshold are zeroed out. " + "Default is 0.2." + ), + ) + parser.add_argument( + "--mopd-eps-high", + type=float, + default=5.0, + help=( + "Upper bound for importance sampling weight clipping in MOPD. " + "Tokens with w_t = π_θ(y_t)/μ_θ(y_t) above this threshold are zeroed out. " + "Default is 5.0." + ), + ) + parser.add_argument( + "--mopd-sampling-logprobs-key", + type=str, + default="rollout_log_probs", + choices=["rollout_log_probs", "log_probs"], + help=( + "Which log-probs to use as the sampling policy μ_θ for MOPD importance sampling. " + "'rollout_log_probs': use the inference engine log-probs (default). " + "'log_probs': use the training engine log-probs." + ), + ) return parser def add_router_arguments(parser): @@ -1491,6 +1578,7 @@ def _apply_megatron_role_overrides(base_args, overrides, role): # Critic-specific: disable features that only apply to actors. role_args.kl_coef = 0 role_args.use_opd = False + role_args.use_mopd = False role_args.custom_advantage_function_path = None role_args.untie_embeddings_and_output_weights = True @@ -1640,6 +1728,117 @@ def slime_validate_args(args): if args.opd_teacher_load is not None: raise ValueError("--opd-teacher-load is set but --use-opd is not enabled. Please add --use-opd flag.") + # Validate Multi-Teacher On-Policy Distillation (MOPD) arguments + if getattr(args, "use_mopd", False): + # MOPD and OPD are mutually exclusive + if args.use_opd: + raise ValueError("--use-mopd and --use-opd are mutually exclusive. Please use only one distillation mode.") + + # --mopd-teachers is required (can also be set via MOPD_TEACHERS_JSON env var) + if args.mopd_teachers is None: + env_mopd_teachers = os.environ.get("MOPD_TEACHERS_JSON") + if env_mopd_teachers: + args.mopd_teachers = env_mopd_teachers + else: + raise ValueError( + "--mopd-teachers is required when --use-mopd is enabled. " + "You can also set the MOPD_TEACHERS_JSON environment variable." + ) + + # Parse and validate MOPD teachers config + try: + if isinstance(args.mopd_teachers, str): + mopd_teachers = json.loads(args.mopd_teachers) + else: + mopd_teachers = args.mopd_teachers + except (json.JSONDecodeError, TypeError) as e: + raise ValueError(f"--mopd-teachers must be valid JSON: {e}") + + if not isinstance(mopd_teachers, list) or len(mopd_teachers) == 0: + raise ValueError("--mopd-teachers must be a non-empty JSON list of teacher configs.") + + domains = set() + for i, teacher_cfg in enumerate(mopd_teachers): + if not isinstance(teacher_cfg, dict): + raise ValueError(f"--mopd-teachers[{i}] must be a dict, got {type(teacher_cfg)}.") + if "domain" not in teacher_cfg: + raise ValueError(f"--mopd-teachers[{i}] must contain a 'domain' key.") + domain = teacher_cfg["domain"] + if domain in domains: + raise ValueError(f"--mopd-teachers has duplicate domain '{domain}'. Each domain must be unique.") + domains.add(domain) + + # Validate MOPD teacher loads for megatron mode + if args.mopd_teacher_loads is not None: + if len(args.mopd_teacher_loads) != len(mopd_teachers): + raise ValueError( + f"--mopd-teacher-loads has {len(args.mopd_teacher_loads)} paths, " + f"but --mopd-teachers has {len(mopd_teachers)} teachers. They must match." + ) + for i, path in enumerate(args.mopd_teacher_loads): + if not os.path.exists(path): + raise FileNotFoundError(f"mopd_teacher_loads[{i}] path {path} does not exist.") + if not os.path.exists(os.path.join(path, "latest_checkpointed_iteration.txt")): + logger.info( + f"mopd_teacher_loads[{i}] path {path} does not have " + "latest_checkpointed_iteration.txt, please make sure it is a valid megatron checkpoint." + ) + + # Validate MOPD teacher checkpoint steps + if args.mopd_teacher_ckpt_steps is not None: + if args.mopd_teacher_loads is None: + raise ValueError("--mopd-teacher-ckpt-steps requires --mopd-teacher-loads to be set.") + if len(args.mopd_teacher_ckpt_steps) != len(mopd_teachers): + raise ValueError( + f"--mopd-teacher-ckpt-steps has {len(args.mopd_teacher_ckpt_steps)} values, " + f"but --mopd-teachers has {len(mopd_teachers)} teachers. They must match." + ) + + # Validate importance sampling bounds + if args.mopd_eps_low < 0: + raise ValueError(f"--mopd-eps-low must be >= 0, got {args.mopd_eps_low}.") + if args.mopd_eps_high <= args.mopd_eps_low: + raise ValueError( + f"--mopd-eps-high ({args.mopd_eps_high}) must be > --mopd-eps-low ({args.mopd_eps_low})." + ) + + # MOPD with megatron-based teachers requires weights_backuper (to backup multiple models) + if args.mopd_teacher_loads is not None and not args.enable_weights_backuper: + raise ValueError( + "--disable-weights-backuper is not compatible with MOPD megatron mode " + "(--mopd-teacher-loads). MOPD needs to backup multiple teacher model weights." + ) + + # Validate rm_type requirement based on mopd_alpha + # When mopd_alpha > 0, ORM advantages are combined with distillation advantages, + # so a reward model is required. + # When mopd_alpha == 0, pure distillation doesn't need task rewards; if no rm_type + # or custom_rm_path is set, default to "zero" reward. + if args.mopd_alpha > 0 and args.rm_type is None and args.custom_rm_path is None: + raise ValueError( + "--mopd-alpha > 0 requires a reward model (--rm-type or --custom-rm-path) " + "because ORM advantages are combined with distillation advantages. " + "Either set --rm-type, --custom-rm-path, or use --mopd-alpha 0 for pure distillation." + ) + if args.mopd_alpha == 0 and args.rm_type is None and args.custom_rm_path is None: + logger.info( + "MOPD with alpha=0 (pure distillation): no --rm-type or --custom-rm-path set, " + "defaulting to 'zero' reward (always 0.0). The learning signal comes entirely " + "from the distillation KL advantages." + ) + args.rm_type = "zero" + + # Store parsed teachers for later use + args._mopd_teachers_parsed = mopd_teachers + else: + # If MOPD is not enabled, MOPD-specific args should not be set + if getattr(args, "mopd_teacher_loads", None) is not None: + raise ValueError("--mopd-teacher-loads is set but --use-mopd is not enabled. Please add --use-mopd flag.") + if getattr(args, "mopd_teacher_ckpt_steps", None) is not None: + raise ValueError( + "--mopd-teacher-ckpt-steps is set but --use-mopd is not enabled. Please add --use-mopd flag." + ) + if args.megatron_to_hf_mode == "bridge": if ( args.load is not None diff --git a/slime/utils/types.py b/slime/utils/types.py index 0681c184b0..4e7eb0a4fb 100644 --- a/slime/utils/types.py +++ b/slime/utils/types.py @@ -27,6 +27,7 @@ class Sample: rollout_routed_experts: list[list[int]] | None = None # Routed experts from rollout engine remove_sample: bool = False teacher_log_probs: list[float] | None = None # Log probabilities from teacher model for OPD + mopd_teacher_log_probs: dict[str, list[float]] | None = None # Log probabilities from multiple MOPD teachers (domain -> log_probs) class Status(Enum): PENDING = "pending" diff --git a/tests/test_mopd.py b/tests/test_mopd.py new file mode 100644 index 0000000000..41aa960cdd --- /dev/null +++ b/tests/test_mopd.py @@ -0,0 +1,555 @@ +"""Unit tests for MOPD (Multi-Teacher On-Policy Distillation). + +Tests cover: +1. MOPD advantage computation (apply_mopd_to_advantages) +2. MOPD importance sampling weight computation and clipping +3. MOPD parameter validation in slime_validate_args +4. Sample.mopd_teacher_log_probs field +""" + +import json +import os +import sys +import types +from argparse import Namespace + +import pytest + +torch = pytest.importorskip("torch") + + +# --------------------------------------------------------------------------- +# Helper to construct args for MOPD +# --------------------------------------------------------------------------- +def make_mopd_args(**overrides): + """Create a Namespace with default MOPD arguments.""" + defaults = dict( + use_mopd=True, + mopd_teachers='[{"name": "math_teacher", "domain": "math"}]', + mopd_teacher_loads=None, + mopd_teacher_ckpt_steps=None, + mopd_alpha=0.0, + mopd_eps_low=0.2, + mopd_eps_high=5.0, + mopd_sampling_logprobs_key="rollout_log_probs", + ) + defaults.update(overrides) + return Namespace(**defaults) + + +# --------------------------------------------------------------------------- +# Tests for apply_mopd_to_advantages +# --------------------------------------------------------------------------- +class TestApplyMopdToAdvantages: + """Test the apply_mopd_to_advantages function in loss.py.""" + + @pytest.fixture(autouse=True) + def _import_loss_module(self, monkeypatch): + """Import loss.py with minimal megatron mocking.""" + # Mock megatron modules + mpu_mod = types.ModuleType("megatron.core") + mpu_sub = types.ModuleType("megatron.core.mpu") + mpu_sub.is_pipeline_last_stage = lambda: True + mpu_sub.get_context_parallel_rank = lambda: 0 + mpu_sub.get_context_parallel_world_size = lambda: 1 + mpu_sub.get_data_parallel_group = lambda: None + mpu_sub.get_data_parallel_rank = lambda: 0 + mpu_sub.get_data_parallel_world_size = lambda: 1 + mpu_sub.get_tensor_model_parallel_rank = lambda: 0 + mpu_sub.get_tensor_model_parallel_world_size = lambda: 1 + + monkeypatch.setitem(sys.modules, "megatron", types.ModuleType("megatron")) + monkeypatch.setitem(sys.modules, "megatron.core", mpu_mod) + monkeypatch.setitem(sys.modules, "megatron.core.mpu", mpu_sub) + + def _get_apply_mopd(self): + """Dynamically import apply_mopd_to_advantages from loss.py.""" + from slime.backends.megatron_utils.loss import apply_mopd_to_advantages + return apply_mopd_to_advantages + + def test_basic_mopd_advantage_computation(self): + """Test that MOPD advantages are computed correctly with a single teacher.""" + apply_mopd = self._get_apply_mopd() + args = make_mopd_args(mopd_alpha=0.0, mopd_eps_low=0.0, mopd_eps_high=1000.0) + + # Student log_probs: [0.1, -0.2, 0.3] + # Teacher log_probs: [0.2, -0.1, 0.4] + # reverse_kl = teacher - student = [0.1, 0.1, 0.1] + student_log_probs = [torch.tensor([0.1, -0.2, 0.3])] + teacher_log_probs = [torch.tensor([0.2, -0.1, 0.4])] + + # Sampling log_probs (μ_θ) = rollout_log_probs + # IS weight = exp(student - sampling) = exp(student - rollout) + # With rollout = student (same model), IS weight = 1.0 everywhere + rollout_log_probs = [torch.tensor([0.1, -0.2, 0.3])] + + # Base advantages (ORM advantages) + advantages = [torch.tensor([1.0, 2.0, 3.0])] + + rollout_data = { + "mopd_teacher_log_probs": {"math": teacher_log_probs}, + "rollout_log_probs": rollout_log_probs, + } + + apply_mopd(args, rollout_data, advantages, student_log_probs) + + # With mopd_alpha=0: mopd_adv = reverse_kl = teacher - student = [0.1, 0.1, 0.1] + # IS weight = exp(student - rollout) = exp(0) = 1.0, within bounds + assert "mopd_advantages" in rollout_data + assert "mopd_is_weights" in rollout_data + assert "mopd_reverse_kl" in rollout_data + + # Check advantages (should not be modified in-place by MOPD, only stored) + # Actually MOPD stores results in rollout_data, not modifying advantages + mopd_adv = rollout_data["mopd_advantages"][0] + is_weights = rollout_data["mopd_is_weights"][0] + + expected_reverse_kl = torch.tensor([0.1, 0.1, 0.1]) + assert torch.allclose(mopd_adv, expected_reverse_kl, atol=1e-6) + assert torch.allclose(is_weights, torch.ones(3), atol=1e-6) + + # Check mopd_reverse_kl is pure reverse_kl (not including alpha * orm_advantage) + reverse_kl_logged = rollout_data["mopd_reverse_kl"]["math"][0] + expected_pure_reverse_kl = torch.tensor([0.1, 0.1, 0.1]) + assert torch.allclose(reverse_kl_logged, expected_pure_reverse_kl, atol=1e-6) + + def test_mopd_with_alpha(self): + """Test MOPD with ORM advantage combination (alpha > 0).""" + apply_mopd = self._get_apply_mopd() + args = make_mopd_args(mopd_alpha=1.0, mopd_eps_low=0.0, mopd_eps_high=1000.0) + + student_log_probs = [torch.tensor([0.0, 0.0])] + teacher_log_probs = [torch.tensor([1.0, 1.0])] + rollout_log_probs = [torch.tensor([0.0, 0.0])] + advantages = [torch.tensor([2.0, 3.0])] + + rollout_data = { + "mopd_teacher_log_probs": {"math": teacher_log_probs}, + "rollout_log_probs": rollout_log_probs, + } + + apply_mopd(args, rollout_data, advantages, student_log_probs) + + # reverse_kl = 1.0 - 0.0 = 1.0 + # mopd_adv = reverse_kl + alpha * ORM_adv = 1.0 + 1.0 * [2.0, 3.0] = [3.0, 4.0] + mopd_adv = rollout_data["mopd_advantages"][0] + expected = torch.tensor([3.0, 4.0]) + assert torch.allclose(mopd_adv, expected, atol=1e-6) + + # mopd_reverse_kl should be pure reverse_kl, NOT containing alpha * orm_advantage + reverse_kl_logged = rollout_data["mopd_reverse_kl"]["math"][0] + expected_pure_reverse_kl = torch.tensor([1.0, 1.0]) + assert torch.allclose(reverse_kl_logged, expected_pure_reverse_kl, atol=1e-6) + + def test_is_weight_clipping_low(self): + """Test that IS weights below eps_low are zeroed out.""" + apply_mopd = self._get_apply_mopd() + # eps_low=0.5, so weights < 0.5 should be zeroed + args = make_mopd_args(mopd_alpha=0.0, mopd_eps_low=0.5, mopd_eps_high=100.0) + + # student - rollout = very negative => IS weight = exp(very_negative) < 0.5 + student_log_probs = [torch.tensor([-5.0, 0.0])] + teacher_log_probs = [torch.tensor([0.0, 0.0])] + rollout_log_probs = [torch.tensor([0.0, 0.0])] + advantages = [torch.tensor([1.0, 1.0])] + + # IS weight for token 0: exp(-5.0 - 0.0) = exp(-5.0) ≈ 0.0067 < 0.5 -> zeroed + # IS weight for token 1: exp(0.0 - 0.0) = 1.0 >= 0.5 -> kept + rollout_data = { + "mopd_teacher_log_probs": {"math": teacher_log_probs}, + "rollout_log_probs": rollout_log_probs, + } + + apply_mopd(args, rollout_data, advantages, student_log_probs) + + is_weights = rollout_data["mopd_is_weights"][0] + assert is_weights[0].item() == 0.0, "IS weight below eps_low should be zeroed" + assert is_weights[1].item() == 1.0, "IS weight within bounds should be kept" + + def test_is_weight_clipping_high(self): + """Test that IS weights above eps_high are zeroed out.""" + apply_mopd = self._get_apply_mopd() + # eps_high=2.0, so weights > 2.0 should be zeroed + args = make_mopd_args(mopd_alpha=0.0, mopd_eps_low=0.0, mopd_eps_high=2.0) + + # student >> rollout => large IS weight + student_log_probs = [torch.tensor([5.0, 0.0])] + teacher_log_probs = [torch.tensor([0.0, 0.0])] + rollout_log_probs = [torch.tensor([0.0, 0.0])] + advantages = [torch.tensor([1.0, 1.0])] + + # IS weight for token 0: exp(5.0 - 0.0) = exp(5.0) ≈ 148.4 > 2.0 -> zeroed + # IS weight for token 1: exp(0.0 - 0.0) = 1.0 <= 2.0 -> kept + rollout_data = { + "mopd_teacher_log_probs": {"math": teacher_log_probs}, + "rollout_log_probs": rollout_log_probs, + } + + apply_mopd(args, rollout_data, advantages, student_log_probs) + + is_weights = rollout_data["mopd_is_weights"][0] + assert is_weights[0].item() == 0.0, "IS weight above eps_high should be zeroed" + assert is_weights[1].item() == 1.0, "IS weight within bounds should be kept" + + def test_multiple_teachers_averaged(self): + """Test that MOPD advantages and IS weights are averaged across teachers.""" + apply_mopd = self._get_apply_mopd() + args = make_mopd_args(mopd_alpha=0.0, mopd_eps_low=0.0, mopd_eps_high=1000.0) + + student_log_probs = [torch.tensor([0.0, 0.0])] + # Two teachers with different log-probs + teacher_math_log_probs = [torch.tensor([1.0, 1.0])] # reverse_kl = [1.0, 1.0] + teacher_code_log_probs = [torch.tensor([2.0, 2.0])] # reverse_kl = [2.0, 2.0] + rollout_log_probs = [torch.tensor([0.0, 0.0])] + advantages = [torch.tensor([0.0, 0.0])] + + rollout_data = { + "mopd_teacher_log_probs": { + "math": teacher_math_log_probs, + "code": teacher_code_log_probs, + }, + "rollout_log_probs": rollout_log_probs, + } + + apply_mopd(args, rollout_data, advantages, student_log_probs) + + # Averaged advantage = (1.0 + 2.0) / 2 = 1.5 + mopd_adv = rollout_data["mopd_advantages"][0] + expected = torch.tensor([1.5, 1.5]) + assert torch.allclose(mopd_adv, expected, atol=1e-6) + + # IS weights should also be averaged (both are 1.0 here) + is_weights = rollout_data["mopd_is_weights"][0] + assert torch.allclose(is_weights, torch.ones(2), atol=1e-6) + + def test_per_sample_domain_routing(self): + """Test per-sample domain routing with None entries in teacher_log_probs. + + When some samples don't have log-probs for a domain (None), that domain + should be excluded from the average for those samples. + """ + apply_mopd = self._get_apply_mopd() + args = make_mopd_args(mopd_alpha=0.0, mopd_eps_low=0.0, mopd_eps_high=1000.0) + + # 2 samples, 2 teachers + student_log_probs = [torch.tensor([0.0, 0.0]), torch.tensor([0.0, 0.0])] + # Sample 0: only math teacher (code is None) + # Sample 1: both teachers + teacher_math_log_probs = [torch.tensor([1.0, 1.0]), torch.tensor([1.0, 1.0])] + teacher_code_log_probs = [None, torch.tensor([2.0, 2.0])] + rollout_log_probs = [torch.tensor([0.0, 0.0]), torch.tensor([0.0, 0.0])] + advantages = [torch.tensor([0.0, 0.0]), torch.tensor([0.0, 0.0])] + + rollout_data = { + "mopd_teacher_log_probs": { + "math": teacher_math_log_probs, + "code": teacher_code_log_probs, + }, + "rollout_log_probs": rollout_log_probs, + } + + apply_mopd(args, rollout_data, advantages, student_log_probs) + + # Sample 0: only math teacher (reverse_kl = 1.0), so mopd_adv = 1.0 + mopd_adv_s0 = rollout_data["mopd_advantages"][0] + assert torch.allclose(mopd_adv_s0, torch.tensor([1.0, 1.0]), atol=1e-6) + + # Sample 1: both teachers, averaged = (1.0 + 2.0) / 2 = 1.5 + mopd_adv_s1 = rollout_data["mopd_advantages"][1] + assert torch.allclose(mopd_adv_s1, torch.tensor([1.5, 1.5]), atol=1e-6) + + def test_per_sample_all_domains_none(self): + """Test that a sample with no valid domains gets zero advantages and IS weights.""" + apply_mopd = self._get_apply_mopd() + args = make_mopd_args(mopd_alpha=0.0, mopd_eps_low=0.0, mopd_eps_high=1000.0) + + student_log_probs = [torch.tensor([0.0, 0.0])] + # All domains are None for this sample + teacher_math_log_probs = [None] + teacher_code_log_probs = [None] + rollout_log_probs = [torch.tensor([0.0, 0.0])] + advantages = [torch.tensor([1.0, 1.0])] + + rollout_data = { + "mopd_teacher_log_probs": { + "math": teacher_math_log_probs, + "code": teacher_code_log_probs, + }, + "rollout_log_probs": rollout_log_probs, + } + + apply_mopd(args, rollout_data, advantages, student_log_probs) + + # Should get zeros since no valid teachers + mopd_adv = rollout_data["mopd_advantages"][0] + assert torch.allclose(mopd_adv, torch.zeros(2), atol=1e-6) + is_weights = rollout_data["mopd_is_weights"][0] + assert torch.allclose(is_weights, torch.zeros(2), atol=1e-6) + + def test_student_log_probs_none_returns_early(self): + """Test that apply_mopd returns early when student_log_probs is None.""" + apply_mopd = self._get_apply_mopd() + args = make_mopd_args() + advantages = [torch.tensor([1.0])] + + rollout_data = { + "mopd_teacher_log_probs": {"math": [torch.tensor([1.0])]}, + } + + # Should not raise, just return early + apply_mopd(args, rollout_data, advantages, None) + # No MOPD keys should be added + assert "mopd_advantages" not in rollout_data + + def test_missing_teacher_log_probs_raises(self): + """Test that missing mopd_teacher_log_probs raises ValueError.""" + apply_mopd = self._get_apply_mopd() + args = make_mopd_args() + student_log_probs = [torch.tensor([0.0])] + advantages = [torch.tensor([1.0])] + rollout_data = {} + + with pytest.raises(ValueError, match="mopd_teacher_log_probs"): + apply_mopd(args, rollout_data, advantages, student_log_probs) + + def test_empty_teacher_log_probs_raises(self): + """Test that empty mopd_teacher_log_probs dict raises ValueError.""" + apply_mopd = self._get_apply_mopd() + args = make_mopd_args() + student_log_probs = [torch.tensor([0.0])] + advantages = [torch.tensor([1.0])] + rollout_data = {"mopd_teacher_log_probs": {}} + + with pytest.raises(ValueError, match="mopd_teacher_log_probs"): + apply_mopd(args, rollout_data, advantages, student_log_probs) + + +# --------------------------------------------------------------------------- +# Tests for MOPD parameter validation +# --------------------------------------------------------------------------- +class TestMopdArgValidation: + """Test MOPD argument validation in slime_validate_args.""" + + @pytest.fixture(autouse=True) + def _mock_deps(self, monkeypatch): + """Mock megatron and other dependencies for arguments module.""" + megatron_mod = types.ModuleType("megatron") + training_mod = types.ModuleType("megatron.training") + arguments_mod = types.ModuleType("megatron.training.arguments") + arguments_mod.parse_args = lambda *a, **kw: None + arguments_mod.validate_args = lambda a: a + tokenizer_pkg_mod = types.ModuleType("megatron.training.tokenizer") + tokenizer_mod = types.ModuleType("megatron.training.tokenizer.tokenizer") + tokenizer_mod._vocab_size_with_padding = lambda vocab_size, _args: vocab_size + transformers_mod = types.ModuleType("transformers") + transformers_mod.AutoConfig = types.SimpleNamespace(from_pretrained=lambda *a, **kw: None) + + monkeypatch.setitem(sys.modules, "megatron", megatron_mod) + monkeypatch.setitem(sys.modules, "megatron.training", training_mod) + monkeypatch.setitem(sys.modules, "megatron.training.arguments", arguments_mod) + monkeypatch.setitem(sys.modules, "megatron.training.tokenizer", tokenizer_pkg_mod) + monkeypatch.setitem(sys.modules, "megatron.training.tokenizer.tokenizer", tokenizer_mod) + monkeypatch.setitem(sys.modules, "transformers", transformers_mod) + + def _make_base_args(self, **overrides): + """Create a minimal valid args Namespace for validation tests.""" + defaults = dict( + use_opd=False, + opd_type=None, + opd_kl_coef=1.0, + opd_teacher_load=None, + opd_teacher_ckpt_step=None, + use_mopd=False, + mopd_teachers=None, + mopd_teacher_loads=None, + mopd_teacher_ckpt_steps=None, + mopd_alpha=0.0, + mopd_eps_low=0.2, + mopd_eps_high=5.0, + mopd_sampling_logprobs_key="rollout_log_probs", + enable_weights_backuper=True, + eval_datasets=[], + eval_prompt_data=None, + kl_coef=0, + ref_load="/tmp/fake_ref", + use_kl_loss=False, + use_critic=False, + rm_type=None, + custom_rm_path=None, + ) + defaults.update(overrides) + return Namespace(**defaults) + + def test_mopd_and_opd_mutually_exclusive(self): + """Test that --use-mopd and --use-opd cannot be used together.""" + from slime.utils.arguments import slime_validate_args + + args = self._make_base_args(use_mopd=True, use_opd=True) + with pytest.raises(ValueError, match="mutually exclusive"): + slime_validate_args(args) + + def test_mopd_requires_teachers(self): + """Test that --use-mopd requires --mopd-teachers.""" + from slime.utils.arguments import slime_validate_args + + args = self._make_base_args(use_mopd=True, mopd_teachers=None) + with pytest.raises(ValueError, match="mopd-teachers"): + slime_validate_args(args) + + def test_mopd_duplicate_domain_raises(self): + """Test that duplicate domains in --mopd-teachers raises ValueError.""" + from slime.utils.arguments import slime_validate_args + + args = self._make_base_args( + use_mopd=True, + mopd_teachers='[{"name": "t1", "domain": "math"}, {"name": "t2", "domain": "math"}]', + ) + with pytest.raises(ValueError, match="duplicate domain"): + slime_validate_args(args) + + def test_mopd_missing_domain_raises(self): + """Test that teacher config without 'domain' key raises ValueError.""" + from slime.utils.arguments import slime_validate_args + + args = self._make_base_args( + use_mopd=True, + mopd_teachers='[{"name": "t1"}]', + ) + with pytest.raises(ValueError, match="domain"): + slime_validate_args(args) + + def test_mopd_eps_low_negative_raises(self): + """Test that negative eps_low raises ValueError.""" + from slime.utils.arguments import slime_validate_args + + args = self._make_base_args( + use_mopd=True, + mopd_teachers='[{"name": "t1", "domain": "math"}]', + mopd_eps_low=-0.1, + ) + with pytest.raises(ValueError, match="mopd-eps-low"): + slime_validate_args(args) + + def test_mopd_eps_high_leq_eps_low_raises(self): + """Test that eps_high <= eps_low raises ValueError.""" + from slime.utils.arguments import slime_validate_args + + args = self._make_base_args( + use_mopd=True, + mopd_teachers='[{"name": "t1", "domain": "math"}]', + mopd_eps_low=5.0, + mopd_eps_high=5.0, + ) + with pytest.raises(ValueError, match="mopd-eps-high"): + slime_validate_args(args) + + def test_mopd_teacher_loads_count_mismatch(self): + """Test that mismatched mopd_teacher_loads count raises ValueError.""" + from slime.utils.arguments import slime_validate_args, tmp_path + + # Create fake checkpoint dirs + ckpt_dir = tmp_path / "teacher1" + ckpt_dir.mkdir() + (ckpt_dir / "latest_checkpointed_iteration.txt").write_text("1") + + args = self._make_base_args( + use_mopd=True, + mopd_teachers='[{"name": "t1", "domain": "math"}, {"name": "t2", "domain": "code"}]', + mopd_teacher_loads=[str(ckpt_dir)], # 1 path but 2 teachers + ) + with pytest.raises(ValueError, match="mopd-teacher-loads"): + slime_validate_args(args) + + def test_mopd_not_enabled_loads_set_raises(self): + """Test that mopd_teacher_loads without --use-mopd raises ValueError.""" + from slime.utils.arguments import slime_validate_args + + args = self._make_base_args(use_mopd=False, mopd_teacher_loads=["/tmp/fake"]) + with pytest.raises(ValueError, match="use-mopd is not enabled"): + slime_validate_args(args) + + def test_mopd_alpha_gt0_without_rm_type_raises(self): + """Test that --mopd-alpha > 0 without --rm-type or --custom-rm-path raises ValueError.""" + from slime.utils.arguments import slime_validate_args + + args = self._make_base_args( + use_mopd=True, + mopd_teachers='[{"name": "t1", "domain": "math"}]', + mopd_alpha=1.0, + rm_type=None, + custom_rm_path=None, + ) + with pytest.raises(ValueError, match="mopd-alpha > 0 requires a reward model"): + slime_validate_args(args) + + def test_mopd_alpha_zero_without_rm_type_defaults_to_zero(self): + """Test that --mopd-alpha 0 without --rm-type defaults rm_type to 'zero'.""" + from slime.utils.arguments import slime_validate_args + + args = self._make_base_args( + use_mopd=True, + mopd_teachers='[{"name": "t1", "domain": "math"}]', + mopd_alpha=0.0, + rm_type=None, + custom_rm_path=None, + ) + slime_validate_args(args) + assert args.rm_type == "zero" + + def test_mopd_alpha_gt0_with_rm_type_ok(self): + """Test that --mopd-alpha > 0 with --rm-type does not raise.""" + from slime.utils.arguments import slime_validate_args + + args = self._make_base_args( + use_mopd=True, + mopd_teachers='[{"name": "t1", "domain": "math"}]', + mopd_alpha=1.0, + rm_type="math", + custom_rm_path=None, + ) + slime_validate_args(args) # Should not raise + + def test_mopd_alpha_gt0_with_custom_rm_ok(self): + """Test that --mopd-alpha > 0 with --custom-rm-path does not raise.""" + from slime.utils.arguments import slime_validate_args + + args = self._make_base_args( + use_mopd=True, + mopd_teachers='[{"name": "t1", "domain": "math"}]', + mopd_alpha=1.0, + rm_type=None, + custom_rm_path="some.module.func", + ) + slime_validate_args(args) # Should not raise + + +# --------------------------------------------------------------------------- +# Tests for Sample.mopd_teacher_log_probs field +# --------------------------------------------------------------------------- +class TestSampleMopdField: + """Test that Sample supports mopd_teacher_log_probs field.""" + + def test_default_none(self): + from slime.utils.types import Sample + s = Sample() + assert s.mopd_teacher_log_probs is None + + def test_set_mopd_teacher_log_probs(self): + from slime.utils.types import Sample + import torch + s = Sample() + s.mopd_teacher_log_probs = { + "math": torch.tensor([0.1, 0.2, 0.3]), + "code": torch.tensor([0.4, 0.5, 0.6]), + } + assert "math" in s.mopd_teacher_log_probs + assert "code" in s.mopd_teacher_log_probs + assert len(s.mopd_teacher_log_probs["math"]) == 3 + + def test_to_dict_roundtrip(self): + from slime.utils.types import Sample + s = Sample(response="hello", response_length=1) + s.mopd_teacher_log_probs = {"math": [0.1, 0.2, 0.3]} + d = s.to_dict() + assert "mopd_teacher_log_probs" in d + assert d["mopd_teacher_log_probs"]["math"] == [0.1, 0.2, 0.3] \ No newline at end of file From 7d46039391f7319dbe62516196ce4cdfaba04720 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=87=95=E4=B8=80?= Date: Wed, 10 Jun 2026 15:18:26 +0800 Subject: [PATCH 02/14] feat: add full vocabulary distillation for MOPD - Extend MOPD loss computation to support full vocabulary KL divergence - Add parameterized distillation mode selection (token-level vs full-vocab) - Add ppo_utils helpers for full vocabulary logits processing - Modify model.py to support output_all_logits mode - Add example script for full-vocab megatron training - Add comprehensive unit tests for full vocabulary distillation - Update README with full vocabulary distillation documentation --- .../README.md | 32 +- ...qwen35-35B-A3B-mopd-full-vocab-megatron.sh | 241 +++++++ slime/backends/megatron_utils/actor.py | 40 +- slime/backends/megatron_utils/data.py | 4 + slime/backends/megatron_utils/loss.py | 350 +++++++++- slime/backends/megatron_utils/model.py | 45 +- slime/utils/arguments.py | 25 + slime/utils/ppo_utils.py | 112 +++ tests/test_mopd_full_vocab.py | 635 ++++++++++++++++++ 9 files changed, 1446 insertions(+), 38 deletions(-) create mode 100644 examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-full-vocab-megatron.sh create mode 100644 tests/test_mopd_full_vocab.py diff --git a/examples/multi_teacher_on_policy_distillation/README.md b/examples/multi_teacher_on_policy_distillation/README.md index 335f9c17ed..34a4b1648a 100644 --- a/examples/multi_teacher_on_policy_distillation/README.md +++ b/examples/multi_teacher_on_policy_distillation/README.md @@ -13,7 +13,9 @@ This example shows how to run **multi-teacher on-policy distillation (MOPD)** us ## Algorithm -For each teacher domain *d*, MOPD computes: +### Token-Level Mode (`--mopd-distill-type token_level`, default) + +Uses sampled token log-prob difference as a reverse KL approximation: ``` reverse_kl_d = sg[log π_d(y_t) - log π_θ(y_t)] # per-teacher reverse KL @@ -22,6 +24,24 @@ w_t = sg[π_θ(y_t) / μ_θ(y_t)] clipped to [ε_low, ε_high] # IS weight L = -E[1/|y| Σ_t w_t · Â_MOPD,t · log π_θ(y_t)] # proxy policy loss ``` +### Full-Vocabulary Mode (`--mopd-distill-type full_vocab`) + +Computes the exact full-vocabulary reverse KL divergence instead of the token-level approximation. This provides a more accurate distillation signal at the cost of increased memory usage. + +``` +D_KL(π_θ ∥ π_d) = Σ_y π_θ(y) [log π_θ(y) - log π_d(y)] # exact full-vocab KL +w_t = sg[π_θ(y_t) / μ_θ(y_t)] clipped to [ε_low, ε_high] # IS weight +L_fv_kl = (1/D) Σ_d (1/|y| Σ_t w_t · D_KL(π_θ ∥ π_d)) # IS-corrected KL loss +L = L_fv_kl + α · L_pg # combined with PG loss +``` + +When `α = 0`: `L = L_fv_kl` (pure distillation, no ORM needed). +When `α > 0`: `L = L_fv_kl + α · L_pg` (distillation + ORM policy gradient). + +**Requirements**: `full_vocab` mode requires `--mopd-teacher-loads` (Megatron teacher mode). SGLang mode is not supported because the full logits tensor cannot be obtained from SGLang rollout. + +**Memory note**: `full_vocab` stores teacher logits `[R, V]` per sample per teacher. For large vocabularies (V=152K), this can be significant. Reduce `--rollout-batch-size` or `--rollout-max-response-len` if OOM occurs. + ## Key Arguments | Argument | Description | @@ -34,6 +54,7 @@ L = -E[1/|y| Σ_t w_t · Â_MOPD,t · log π_θ(y_t)] # proxy policy lo | `--mopd-eps-low` | IS weight lower bound for clipping (default: 0.2). Weights below this are zeroed. | | `--mopd-eps-high` | IS weight upper bound for clipping (default: 5.0). Weights above this are zeroed. | | `--mopd-sampling-logprobs-key` | Key in rollout_data for sampling log-probs used in IS weight computation (default: `rollout_log_probs`). | +| `--mopd-distill-type` | Distillation type: `token_level` (default) uses sampled token log-prob difference as a reverse KL approximation applied at the advantage level; `full_vocab` computes the exact full-vocabulary reverse KL divergence D_KL(π_θ ∥ π_d) using complete logits. `full_vocab` requires `--mopd-teacher-loads` (Megatron mode). | ## SGLang vs Megatron Mode @@ -223,4 +244,11 @@ For string convenience, you can also use a single string instead of a list: The `reward_func` will log a warning and skip the failed teacher for that sample. The training will continue with remaining teachers, but the advantages will be biased. Monitor teacher server health carefully. 5. **Why is `--group-rm` not supported with MOPD?** - MOPD's `reward_func` returns per-domain dicts (not scalar rewards), which is incompatible with the batch `group_rm` reward path. Use the default per-sample reward path (no `--group-rm`). \ No newline at end of file + MOPD's `reward_func` returns per-domain dicts (not scalar rewards), which is incompatible with the batch `group_rm` reward path. Use the default per-sample reward path (no `--group-rm`). + +6. **What is the difference between `token_level` and `full_vocab` distillation types?** + - `token_level` (default): Approximates reverse KL using the sampled token log-prob difference `sg[log π_d(y_t) - log π_θ(y_t)]`. This is efficient and works with both SGLang and Megatron teacher modes, but only captures the KL at the sampled token position. + - `full_vocab`: Computes the exact full-vocabulary reverse KL divergence `D_KL(π_θ ∥ π_d) = Σ_y π_θ(y) [log π_θ(y) - log π_d(y)]`. Provides a more accurate distillation signal but requires full logits from the teacher, which means it only works with Megatron mode (`--mopd-teacher-loads`). Memory usage is significantly higher because teacher logits `[R, V]` must be stored for each sample. + +7. **When should I use `full_vocab` mode?** + Use `full_vocab` when you need more precise distillation signal and have sufficient GPU/CPU memory. It is particularly beneficial when the student and teacher distributions differ significantly, as the token-level approximation can underestimate the true KL divergence. For memory-constrained scenarios, stick with `token_level`. \ No newline at end of file diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-full-vocab-megatron.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-full-vocab-megatron.sh new file mode 100644 index 0000000000..fcc73e2c9e --- /dev/null +++ b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-full-vocab-megatron.sh @@ -0,0 +1,241 @@ +#!/bin/bash + +# Multi-Teacher On-Policy Distillation (MOPD) — Full-Vocabulary KL Divergence Mode +# Model: Qwen3.5-35B-A3B (MoE, 256 experts, 8 active) +# Environment: 8× H20 (143GB) +# Teacher: Same as student (self-distillation for connectivity validation only) +# Mode: Megatron (teacher loaded into CPU memory via TensorBackuper) +# Distill Type: full_vocab (exact full-vocabulary reverse KL D_KL(π_θ ∥ π_d)) +# +# This script is for MOPD full_vocab E2E connectivity validation. +# In production, use a DIFFERENT (stronger) model as teacher. +# +# Key difference from token_level mode: +# --mopd-distill-type full_vocab +# → Computes exact D_KL(π_θ ∥ π_d) over full vocabulary instead of +# approximating from sampled tokens. Requires megatron teacher mode. +# → Uses full logits [R, V] instead of per-token log-probs, which +# increases memory usage significantly. +# +# Parallelism: TP=2, EP=8 (matches SFT config, 256 experts / 8 = 32 per GPU) +# Colocate mode: rollout and training share all 8 GPUs with offloading +# +# usage: bash examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-full-vocab-megatron.sh + +# ============================================================================ +# Cleanup: kill existing SGLang / Ray / Python processes +# ============================================================================ +pkill -9 sglang +sleep 3 +ray stop --force 2>/dev/null || true +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +export PYTHONBUFFERED=16 +export FLASHINFER_DISABLE_VERSION_CHECK=1 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +source "/mntfn/yanyi/code/slime/scripts/models/qwen3.5-35B-A3B.sh" + +# MOPD teachers JSON config +export MOPD_TEACHERS_JSON='[{"name":"self_teacher","domain":"default"}]' + +# ============================================================================ +# Configure training arguments +# ============================================================================ + +CKPT_ARGS=( + --hf-checkpoint /mnt4/data/open_source/Qwen3.5-35B-A3B/ + --ref-load /mnt4/data/open_source/Qwen3.5-35B-A3B_torch_dist/ + --load /mnt4/data/zhixiaobao/yanyi/Qwen3.5-35B-A3B-mopd-full-vocab-test/ + --save /mnt4/data/zhixiaobao/yanyi/Qwen3.5-35B-A3B-mopd-full-vocab-test/ + --save-interval 10 +) + +ROLLOUT_ARGS=( + --prompt-data /mntfn/yanyi/dataset/train_text_user_only.jsonl + --input-key messages + --apply-chat-template + --rollout-shuffle + --rollout-batch-size 4 + --n-samples-per-prompt 1 + --rollout-max-response-len 4096 + --rollout-temperature 0.8 + + --global-batch-size 4 + --balance-data + --num-epoch 1 +) + +RM_ARGS=( + # Pure distillation (mopd-alpha=0): rm-type defaults to "zero" automatically. +) + +EVAL_ARGS=( + # No eval for connectivity test +) + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 8 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 4096 +) + +# MOPD Configuration — Full-Vocabulary KL Divergence Mode +# +# Key changes from token_level mode: +# 1. Added --mopd-distill-type full_vocab +# → Computes exact D_KL(π_θ ∥ π_d) = Σ_y π_θ(y)[log π_θ(y) - log π_d(y)] +# over the full vocabulary instead of token-level approximation. +# +# 2. --mopd-teacher-loads is REQUIRED for full_vocab mode +# → full_vocab needs megatron teacher forward pass to get full logits, +# SGLang rollout cannot provide per-token full-vocab logits. +# +# 3. Memory considerations: +# → full_vocab mode stores teacher logits [R_i, V_local] per sample per teacher. +# For V=248320, TP=2 → V_local=124160, each token's logits = ~480KB in fp32. +# With batch=4, R=4096: teacher logits per GPU ≈ 4×4096×124160×4B ≈ 7.6GB. +# Student logits (same shape) appear during training forward pass ≈ 1.9GB/micro-batch. +# Together with model (~9GB), optimizer (~26GB), and SGLang (40%=57GB), +# total ≈ 102GB / 143GB, leaving ~41GB headroom. +# If OOM: reduce rollout-batch-size, rollout-max-response-len, or sglang-mem-fraction-static. +# +# 4. Loss formula: +# → L = L_fv_kl + alpha * L_pg (when alpha > 0) +# → L = L_fv_kl (pure distillation, when alpha = 0) +# where L_fv_kl = (1/D) Σ_d w_d * D_KL(π_θ ∥ π_d) (IS-corrected) +# +# 5. IS weight correction still applies (same as token_level mode) +# +# For this connectivity test, the teacher IS the same model (self-distillation). +MOPD_ARGS=( + --advantage-estimator grpo + + # MOPD flags — single teacher + --use-mopd + # Pass JSON via env var MOPD_TEACHERS_JSON to avoid shell quoting issues. + + # *** KEY DIFFERENCE: full_vocab distillation type *** + --mopd-distill-type full_vocab + + # Teacher checkpoint = same as ref model (self-distillation for validation) + --mopd-teacher-loads /mnt4/data/open_source/Qwen3.5-35B-A3B_torch_dist/ + + # MOPD hyperparameters + --mopd-alpha 0.0 # Pure distillation, no ORM + --mopd-eps-low 0.2 # IS weight lower bound + --mopd-eps-high 5.0 # IS weight upper bound + --mopd-sampling-logprobs-key rollout_log_probs + + # Standard training flags + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 5e-7 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + #--use-wandb + # --wandb-project slime-dev + # --wandb-group qwen3.5-35B-mopd-full-vocab-megatron +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 8 + --sglang-mem-fraction-static 0.4 + --sglang-ep-size 8 +) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash + + --moe-token-dispatcher-type flex + --moe-enable-deepep + + --colocate +) + +# ============================================================================ +# Launch training +# ============================================================================ + +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +export no_proxy="127.0.0.1,${MASTER_ADDR}" + +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +RUNTIME_ENV_JSON=$(python3 -c " +import json, os +env = { + 'PYTHONPATH': '/root/Megatron-LM/', + 'CUDA_DEVICE_MAX_CONNECTIONS': '1', + 'NCCL_NVLS_ENABLE': os.environ.get('HAS_NVLINK', '0'), + 'MOPD_TEACHERS_JSON': os.environ.get('MOPD_TEACHERS_JSON', '') +} +print(json.dumps({'env_vars': env})) +") + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 8 \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${MOPD_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${RM_ARGS[@]} + +# ============================================================================ +# Cleanup +# ============================================================================ +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python \ No newline at end of file diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index 63ce2ed7b0..3df46cf27c 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -31,7 +31,7 @@ from .cp_utils import slice_log_prob_with_cp, slice_with_cp from .data import DataIterator, get_data_iterator, log_perf_data, log_rollout_data from .initialize import init, is_megatron_main_rank -from .loss import compute_advantages_and_returns, get_log_probs_and_entropy, get_values +from .loss import compute_advantages_and_returns, get_log_probs_and_entropy, get_logits_for_distill, get_values from .model import forward_only, initialize_model_and_optimizer, save, train from .update_weight.common import named_params_and_buffers from .update_weight.update_weight_from_distributed import UpdateWeightFromDistributed @@ -395,17 +395,19 @@ def compute_log_prob( data_iterator: list[DataIterator], num_microbatches: list[int], store_prefix: str = "", + return_logits: bool = False, ) -> dict[str, list[torch.Tensor]]: with timer(f"{store_prefix}log_probs"): - return forward_only( - get_log_probs_and_entropy, + result = forward_only( + get_logits_for_distill if return_logits else get_log_probs_and_entropy, self.args, self.model, data_iterator, num_microbatches, store_prefix=store_prefix, ) + return result def train(self, rollout_id: int, rollout_data_ref: Box, external_data=None): if self.args.debug_rollout_only: @@ -490,21 +492,35 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch, external_data # Forward each MOPD teacher model for Megatron-based MOPD if getattr(self.args, "use_mopd", False) and hasattr(self, "_mopd_teacher_domains") and self._mopd_teacher_domains: mopd_teacher_log_probs = {} + use_full_vocab = getattr(self.args, "mopd_distill_type", "token_level") == "full_vocab" for domain in self._mopd_teacher_domains: tag = f"mopd_teacher_{domain}" if tag in self.weights_backuper.backup_tags: if self.args.use_routing_replay: os.environ["ROUTING_REPLAY_STAGE"] = "fallthrough" self._switch_model(tag) - teacher_result = self.compute_log_prob( - data_iterator, - num_microbatches, - store_prefix=f"mopd_teacher_{domain}_", - ) - # Store with domain-specific key - lp_key = f"mopd_teacher_{domain}_log_probs" - if lp_key in teacher_result: - mopd_teacher_log_probs[domain] = teacher_result[lp_key] + if use_full_vocab: + # Full-vocab mode: get full logits for exact KL computation + teacher_result = self.compute_log_prob( + data_iterator, + num_microbatches, + store_prefix=f"mopd_teacher_{domain}_fv_", + return_logits=True, + ) + # Store logits as a flat per-sample list under a domain-specific key + logits_key = f"mopd_teacher_{domain}_fv_logits" + if logits_key in teacher_result: + rollout_data[logits_key] = teacher_result[logits_key] + else: + # Token-level mode: only need log_probs + teacher_result = self.compute_log_prob( + data_iterator, + num_microbatches, + store_prefix=f"mopd_teacher_{domain}_", + ) + lp_key = f"mopd_teacher_{domain}_log_probs" + if lp_key in teacher_result: + mopd_teacher_log_probs[domain] = teacher_result[lp_key] if mopd_teacher_log_probs: rollout_data["mopd_teacher_log_probs"] = mopd_teacher_log_probs diff --git a/slime/backends/megatron_utils/data.py b/slime/backends/megatron_utils/data.py index 7a0b0fd0b7..49dd7341f9 100644 --- a/slime/backends/megatron_utils/data.py +++ b/slime/backends/megatron_utils/data.py @@ -418,9 +418,13 @@ def log_rollout_data( "dynamic_global_batch_size", # Dict-typed keys that cannot be averaged directly "mopd_teacher_log_probs", + "mopd_teacher_logits", "mopd_reverse_kl", ]: continue + # Skip per-domain full-vocab teacher logits (too large for averaging) + if key.startswith("mopd_teacher_") and key.endswith("_fv_logits"): + continue # Upload per sample mean for each rollout value # There are the following assumptions: # - Each dp rank has the same number of samples diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 10c2be716f..d47d2e0567 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -2,7 +2,11 @@ from collections.abc import Callable, Iterator from typing import Any +import logging + import torch + +logger = logging.getLogger(__name__) import torch.distributed as dist import torch.nn.functional as F from megatron.core import mpu @@ -20,6 +24,7 @@ get_grpo_returns, get_reinforce_plus_plus_baseline_advantages, get_reinforce_plus_plus_returns, + vocab_parallel_reverse_kl, ) from slime.utils.types import RolloutBatch @@ -39,11 +44,12 @@ def get_responses( total_lengths: list[int], response_lengths: list[int], max_seq_lens: list[int] | None = None, + apply_temperature: bool = True, ) -> Iterator[tuple[torch.Tensor, torch.Tensor]]: """Yield response-aligned `(logits_chunk, tokens_chunk)` pairs per sample. - After squeezing batch dimension and applying temperature scaling, this - function extracts the logits and tokens corresponding to response segments + After squeezing batch dimension and (optionally) applying temperature scaling, + this function extracts the logits and tokens corresponding to response segments for each sample. When context parallelism is disabled, it slices directly from the concatenated sequence. With context parallelism enabled, it handles split sequences across ranks. @@ -55,6 +61,10 @@ def get_responses( unconcat_tokens: List of token tensors (prompt+response) per sample. total_lengths: Total sequence lengths (prompt+response) per sample. response_lengths: Response segment lengths per sample. + max_seq_lens: Optional padded max sequence lengths per sample (for bshd). + apply_temperature: If True (default), apply ``args.rollout_temperature`` + scaling to logits. Set to False when raw logits are needed, e.g. + for full-vocabulary KL divergence computation. Yields: Tuple of `(logits_chunk, tokens_chunk)` where `logits_chunk` is shape @@ -73,7 +83,7 @@ def get_responses( assert max_seq_lens is not None logits = logits.view(-1, logits.size(-1)) - if args.rollout_temperature != 1.0: + if apply_temperature and args.rollout_temperature != 1.0: logits = logits.div(args.rollout_temperature) cp_size = mpu.get_context_parallel_world_size() @@ -468,6 +478,89 @@ def get_log_probs_and_entropy( return torch.empty((0,), device=device), res +def get_logits_for_distill( + logits: torch.Tensor, + *, + args: Namespace, + unconcat_tokens: list[torch.Tensor], + total_lengths: list[int], + response_lengths: list[int], + with_entropy: bool = False, + non_loss_data: bool = True, + max_seq_lens: list[int] | None = None, +) -> dict[str, list[torch.Tensor]]: + """Extract per-sample response logits for full-vocab distillation (MOPD full_vocab mode). + + Similar to ``get_log_probs_and_entropy``, but returns the full logits tensor + ``[R, V]`` per sample (where R is response length, V is vocab size) instead of + log-probabilities. This is needed for computing exact KL divergence over the + full vocabulary: D_KL(π_θ ∥ π_d) = Σ_y π_θ(y) [log π_θ(y) - log π_d(y)]. + + No temperature scaling is applied — the raw logits are returned so the caller + can apply the desired softmax/log_softmax independently. + + Args: + logits: Model outputs with shape ``[1, T, V]``. Must be float32. + args: Configuration (needs ``qkv_format``, ``allgather_cp``). + unconcat_tokens: List of token tensors (prompt+response) per sample. + total_lengths: Total sequence lengths (prompt+response) per sample. + response_lengths: Response segment lengths per sample. + with_entropy: Unused; kept for signature compatibility. + non_loss_data: Unused; kept for signature compatibility. + max_seq_lens: Optional padded max sequence lengths per sample (for bshd). + + Returns: + Dict with key ``"logits"`` mapping to a list of ``[R, V]`` tensors per sample. + """ + assert logits.dtype == torch.float32, f"{logits.dtype}" + assert len(logits.shape) == 3, f"{logits.shape}" + + device = logits.device + + # Extract per-sample response logits chunks + # NOTE: apply_temperature=False — raw logits are needed for correct + # softmax/log_softmax in KL divergence computation. + # get_responses handles qkv_format reshaping internally. + logits_list = [] + for logits_chunk, _ in get_responses( + logits, + args=args, + unconcat_tokens=unconcat_tokens, + total_lengths=total_lengths, + response_lengths=response_lengths, + max_seq_lens=max_seq_lens, + apply_temperature=False, + ): + logits_list.append(logits_chunk) + + res = {"logits": logits_list} + + # Handle allgather-CP redistribution + # NOTE: _allgather_cp_redistribute assumes 1D per-sample tensors (log_probs, entropy). + # Full-vocab logits are 2D [R_i, V], which is not supported by the current + # redistribution logic. Raise an explicit error so users don't get silent + # shape mismatches. + if args.allgather_cp: + cp_size = getattr(mpu, "get_context_parallel_world_size", lambda: 1)() + if cp_size > 1: + raise NotImplementedError( + "MOPD full_vocab (get_logits_for_distill) does not support " + "allgather-CP with context_parallel_size > 1. The CP redistribution " + "logic assumes 1D tensors but logits are 2D [R, V]. Please disable " + "allgather_cp or set context_parallel_size=1 when using full_vocab mode." + ) + _allgather_cp_redistribute( + res, + logits_local_len=logits_list[0].size(0) if logits_list else 0, + args=args, + total_lengths=total_lengths, + response_lengths=response_lengths, + max_seq_lens=max_seq_lens, + ) + + return torch.empty((0,), device=device), res + + def get_values( logits: torch.Tensor, *, @@ -824,7 +917,9 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) ) # Apply Multi-Teacher On-Policy Distillation (MOPD) to advantages - if args.use_mopd: + # Skip token-level MOPD when using full_vocab distillation type; + # in that mode, the KL is computed directly in the loss function. + if args.use_mopd and getattr(args, "mopd_distill_type", "token_level") == "token_level": apply_mopd_to_advantages( args=args, rollout_data=rollout_data, @@ -942,6 +1037,176 @@ def icepop_function( return pg_loss, loss_masks, metrics +def apply_mopd_full_vocab_to_loss( + args: Namespace, + batch: RolloutBatch, + student_logits_per_sample: list[torch.Tensor], + teacher_logits_per_domain: dict[str, list[torch.Tensor | None]], + loss_masks: list[torch.Tensor], + sum_of_sample_mean: Callable[[torch.Tensor], torch.Tensor], + current_log_probs: list[torch.Tensor] | None = None, +) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + """Compute the full-vocabulary reverse KL divergence loss for MOPD. + + Instead of approximating the reverse KL from sampled tokens (token_level mode), + this computes the exact KL divergence over the full vocabulary: + + D_KL(π_θ ∥ π_d) = Σ_y π_θ(y) [log π_θ(y) - log π_d(y)] + + For each teacher domain d, the per-token KL is computed and then averaged + across teachers. Per-token importance sampling weights are applied identically + to token_level mode. + + When mopd_alpha > 0, the total loss is: + L = L_full_vocab_kl + alpha * L_pg_orm + where L_pg_orm is the standard policy gradient loss with ORM advantages. + When mopd_alpha == 0, L = L_full_vocab_kl (pure distillation). + + Args: + args: Configuration containing MOPD parameters. + batch: Mini-batch containing IS weight data and loss_masks. + student_logits_per_sample: List of per-sample student logits [R_i, V]. + teacher_logits_per_domain: Dict mapping domain to list of per-sample + teacher logits [R_i, V], with None for samples not routed to that domain. + loss_masks: List of per-sample loss masks. + sum_of_sample_mean: Reduction function for averaging. + current_log_probs: List of per-sample log-probs from the current training + forward pass. Used for importance sampling weight computation. + If None, falls back to batch["log_probs"] (pre-training forward pass). + + Returns: + Tuple of (kl_loss, metrics) where kl_loss is a scalar tensor and + metrics is a dict with logging tensors. + """ + # Get sampling log-probs μ_θ for importance sampling weight + sampling_logprobs_key = getattr(args, "mopd_sampling_logprobs_key", "rollout_log_probs") + sampling_log_probs = batch.get(sampling_logprobs_key) + if sampling_logprobs_key == "rollout_log_probs" and sampling_log_probs is None: + sampling_log_probs = batch.get("log_probs") + if sampling_log_probs is None: + raise ValueError( + f"MOPD full_vocab requires '{sampling_logprobs_key}' in batch for importance sampling." + ) + + num_samples = len(student_logits_per_sample) + if len(sampling_log_probs) != num_samples: + raise ValueError( + f"MOPD full_vocab: sampling_log_probs length ({len(sampling_log_probs)}) " + f"!= student_logits length ({num_samples})." + ) + all_kl_per_token = [] # will hold per-token KL tensors for all samples + tp_group = mpu.get_tensor_model_parallel_group() + # Stash per-domain per-sample KL for logging (detached) + per_domain_kls: dict[str, list[torch.Tensor]] = {} + + # Collect per-sample KL contributions across all teacher domains. + # For each sample, we average the KL across valid (non-None) teachers. + for i in range(num_samples): + R_i = student_logits_per_sample[i].size(0) + sample_kl_values = [] # collect KL contributions from each valid teacher + valid_teacher_count = 0 + + for domain, teacher_logits_list in teacher_logits_per_domain.items(): + if i >= len(teacher_logits_list) or teacher_logits_list[i] is None: + continue # skip this domain for this sample + + teacher_logits_i = teacher_logits_list[i] # [R_i, V_local] + + # D_KL(π_θ ∥ π_d) = Σ_y π_θ(y) [log π_θ(y) - log π_d(y)] + # Uses TP-aware computation when vocab is sharded across TP ranks. + kl_i = vocab_parallel_reverse_kl( + student_logits_per_sample[i], + teacher_logits_i, + tp_group, + ) # [R_i] + sample_kl_values.append(kl_i) + valid_teacher_count += 1 + + # Save for per-domain logging + if domain not in per_domain_kls: + per_domain_kls[domain] = [] + per_domain_kls[domain].append(kl_i.detach()) + + if valid_teacher_count > 0: + # Average KL across valid teachers + avg_kl_i = sum(sample_kl_values) / valid_teacher_count # [R_i] + else: + avg_kl_i = torch.zeros(R_i, device=student_logits_per_sample[i].device) + + all_kl_per_token.append(avg_kl_i) + + # Compute IS weights + # w_t = π_θ(y_t) / μ_θ(y_t) clipped to [eps_low, eps_high] + # We need the student's current log prob at the sampled token (π_θ(y_t)). + # This comes from the current training forward pass (not the stale pre-training + # pass in batch["log_probs"]). The caller passes these via current_log_probs. + # Fall back to batch["log_probs"] only if current_log_probs is not provided. + student_log_probs_at_sampled = current_log_probs if current_log_probs is not None else batch.get("log_probs") + if student_log_probs_at_sampled is not None and len(student_log_probs_at_sampled) != num_samples: + raise ValueError( + f"MOPD full_vocab: student_log_probs length ({len(student_log_probs_at_sampled)}) " + f"!= student_logits length ({num_samples})." + ) + + is_weight_per_sample = [] + for i in range(num_samples): + with torch.no_grad(): + if student_log_probs_at_sampled is not None: + # Use the per-token log probs from the current training forward pass + s_lp_i = student_log_probs_at_sampled[i].to(device=student_logits_per_sample[i].device) + else: + # Fallback: zero IS weights (effectively disabling IS correction) + s_lp_i = torch.zeros( + student_logits_per_sample[i].size(0), + device=student_logits_per_sample[i].device, + ) + samp_lp_i = sampling_log_probs[i].to(device=s_lp_i.device) + is_w_i = torch.exp(s_lp_i - samp_lp_i) + # Zero out weights outside [eps_low, eps_high] + is_w_i = torch.where( + (is_w_i >= args.mopd_eps_low) & (is_w_i <= args.mopd_eps_high), + is_w_i, + torch.zeros_like(is_w_i), + ) + is_weight_per_sample.append(is_w_i) + + # Apply IS weights to KL: per-token KL * IS weight * loss_mask + weighted_kl_tokens = [] + for i in range(num_samples): + mask_i = loss_masks[i].to(device=all_kl_per_token[i].device) + # Mask and weight the KL + weighted_kl_i = all_kl_per_token[i] * is_weight_per_sample[i] * mask_i # [R_i] + # Sum over tokens in the response, divide by response length for mean + R_i = mask_i.sum().clamp(min=1) # number of valid tokens + weighted_kl_tokens.append(weighted_kl_i.sum() / R_i) + + # Average across samples + if len(weighted_kl_tokens) > 0: + kl_loss = torch.stack(weighted_kl_tokens).mean() + else: + kl_loss = torch.tensor(0.0, device=student_logits_per_sample[0].device) + + # Logging metrics + all_kl_cat = torch.cat(all_kl_per_token, dim=0) + kl_mean = sum_of_sample_mean(all_kl_cat) + + is_weights_cat = torch.cat(is_weight_per_sample, dim=0) + is_weight_mean = sum_of_sample_mean(is_weights_cat) + is_nonzero_frac = sum_of_sample_mean((is_weights_cat != 0).float()) + + metrics = { + "mopd_fv_kl": kl_mean.clone().detach(), + "mopd_is_weight_mean": is_weight_mean.clone().detach(), + "mopd_is_nonzero_frac": is_nonzero_frac.clone().detach(), + } + + # Per-teacher KL for logging (re-use KL values computed in the main loop) + for domain, domain_kls in per_domain_kls.items(): + metrics[f"mopd_fv_kl/{domain}"] = sum_of_sample_mean(torch.cat(domain_kls, dim=0)).clone().detach() + + return kl_loss, metrics + + def policy_loss_function( args: Namespace, batch: RolloutBatch, @@ -1026,6 +1291,8 @@ def policy_loss_function( ) # Compute KL divergence (GSPO uses sequence-level KL, others use per-token KL) + # Save list-form log_probs before concatenation for potential use in MOPD full_vocab IS weights + current_log_probs_list = log_probs if args.advantage_estimator == "gspo": ppo_kl = compute_gspo_kl( full_log_probs=full_log_probs, @@ -1045,18 +1312,67 @@ def policy_loss_function( if args.use_opsm: pg_loss = pg_loss * opsm_mask - # Apply MOPD: replace advantages with mopd_advantages and apply IS weights + # Apply MOPD token_level: replace advantages with mopd_advantages and apply IS weights # L_MOPD(θ) = -E[1/|y| Σ_t w_t * Â_MOPD,t * log π_θ(y_t|x,y_ 0: loss = fv_kl_loss + alpha * pg_loss + batch["_mopd_fv_kl_loss"] = fv_kl_loss + else: + logger.warning("MOPD full_vocab enabled but no teacher logits found in batch. Skipping full_vocab KL loss.") + # Apply off-policy correction using importance sampling if enabled if args.get_mismatch_metrics or args.use_tis: # NOTE: @@ -1121,6 +1437,19 @@ def policy_loss_function( loss = pg_loss - args.entropy_coef * entropy_loss + # MOPD full_vocab: combine KL distillation loss with policy gradient loss + # L = L_full_vocab_kl + alpha * L_pg + # When alpha == 0 (pure distillation): L = L_full_vocab_kl (pg_loss is zeroed out) + # When alpha > 0: L = L_full_vocab_kl + alpha * L_pg (ORM policy gradient) + if use_mopd_full_vocab and "_mopd_fv_kl_loss" in batch: + fv_kl_loss = batch.pop("_mopd_fv_kl_loss") + if args.mopd_alpha > 0: + # Combine: full-vocab KL + alpha * policy gradient loss + loss = fv_kl_loss + args.mopd_alpha * loss + else: + # Pure distillation: only use full-vocab KL loss + loss = fv_kl_loss + if args.use_kl_loss: ref_log_probs = batch["ref_log_probs"] ref_log_probs = torch.cat(ref_log_probs, dim=0) @@ -1179,7 +1508,7 @@ def policy_loss_function( # Log MOPD metrics (IS weights and per-teacher reverse KL are already applied during # advantage computation and pg_loss re-weighting in the MOPD section above) - if getattr(args, "use_mopd", False) and "mopd_is_weights" in batch: + if getattr(args, "use_mopd", False) and not use_mopd_full_vocab and "mopd_is_weights" in batch: mopd_is_weights = torch.cat(batch["mopd_is_weights"], dim=0) reported_loss["mopd_is_weight_mean"] = sum_of_sample_mean(mopd_is_weights).clone().detach() mopd_is_nonzero = (mopd_is_weights != 0).float() @@ -1194,6 +1523,11 @@ def policy_loss_function( mopd_advantages = torch.cat(batch["mopd_advantages"], dim=0) reported_loss["mopd_advantage_mean"] = sum_of_sample_mean(mopd_advantages).clone().detach() + # Log MOPD full_vocab metrics + if use_mopd_full_vocab: + for key, value in mopd_fv_metrics.items(): + reported_loss[key] = value + return loss, reported_loss diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index f326b1d0d3..d4b9e58e9d 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -463,24 +463,37 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p """ # Get the batch. + # Collect base keys needed by all loss functions + batch_keys = [ + "tokens", + "multimodal_train_inputs", + "packed_seq_params", + "total_lengths", + "response_lengths", + "loss_masks", + "log_probs", + "ref_log_probs", + "values", + "advantages", + "returns", + "rollout_log_probs", + "max_seq_lens", + "teacher_log_probs", + ] + # Add MOPD full-vocab teacher logits keys if present + # These are stored as "mopd_teacher_{domain}_fv_logits" per domain + use_mopd_full_vocab = ( + getattr(args, "use_mopd", False) and getattr(args, "mopd_distill_type", "token_level") == "full_vocab" + ) + if use_mopd_full_vocab and hasattr(args, "_mopd_teachers_parsed"): + for teacher_cfg in args._mopd_teachers_parsed: + domain = teacher_cfg["domain"] + logits_key = f"mopd_teacher_{domain}_fv_logits" + batch_keys.append(logits_key) + batch = get_batch( data_iterator, - [ - "tokens", - "multimodal_train_inputs", - "packed_seq_params", - "total_lengths", - "response_lengths", - "loss_masks", - "log_probs", - "ref_log_probs", - "values", - "advantages", - "returns", - "rollout_log_probs", - "max_seq_lens", - "teacher_log_probs", - ], + batch_keys, args.data_pad_size_multiplier, args.qkv_format, args.allgather_cp, diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 602a61656c..7e628d020c 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -1098,6 +1098,22 @@ def add_on_policy_distillation_arguments(parser): "'log_probs': use the training engine log-probs." ), ) + parser.add_argument( + "--mopd-distill-type", + type=str, + choices=["token_level", "full_vocab"], + default="token_level", + help=( + "MOPD distillation type. " + "'token_level' (default): use the sampled token log-prob difference as a reverse KL approximation, " + "applied at the advantage level. " + "'full_vocab': compute the exact full-vocabulary reverse KL divergence D_KL(π_θ ∥ π_d) " + "using complete logits from both student and teacher models. This is only supported with " + "megatron teacher mode (--mopd-teacher-loads), as it requires access to teacher logits " + "during training. The full-vocab KL loss is computed directly in the loss function rather " + "than through advantage modification." + ), + ) return parser def add_router_arguments(parser): @@ -1802,6 +1818,15 @@ def slime_validate_args(args): f"--mopd-eps-high ({args.mopd_eps_high}) must be > --mopd-eps-low ({args.mopd_eps_low})." ) + # Validate mopd_distill_type: full_vocab mode requires megatron teachers + if args.mopd_distill_type == "full_vocab": + if args.mopd_teacher_loads is None: + raise ValueError( + "--mopd-distill-type=full_vocab requires --mopd-teacher-loads (megatron teacher mode). " + "SGLang-based teachers cannot return full-vocabulary logits efficiently. " + "Please provide teacher checkpoints via --mopd-teacher-loads." + ) + # MOPD with megatron-based teachers requires weights_backuper (to backup multiple models) if args.mopd_teacher_loads is not None and not args.enable_weights_backuper: raise ValueError( diff --git a/slime/utils/ppo_utils.py b/slime/utils/ppo_utils.py index 2a858e7a3f..4b18601c88 100644 --- a/slime/utils/ppo_utils.py +++ b/slime/utils/ppo_utils.py @@ -198,6 +198,118 @@ def compute_entropy_from_logits(logits: torch.Tensor, process_group) -> torch.Te return _VocabParallelEntropy.apply(logits, process_group) +class _VocabParallelReverseKL(torch.autograd.Function): + """Compute D_KL(π_student ∥ π_teacher) over a vocabulary-parallel partition. + + Both *student_logits* and *teacher_logits* are partial tensors along the + vocab dimension (each TP rank holds V/tp_size entries). The function + performs the necessary all-reduces to compute the exact reverse KL + divergence in a numerically stable manner: + + D_KL(π_s ∥ π_t) = Σ_y π_s(y) [log π_s(y) - log π_t(y)] + + Forward returns a tensor of shape [R] (one KL value per response token). + Backward propagates gradients w.r.t. *student_logits* only; teacher logits + are treated as constants (detached). + + Gradient derivation: + KL = Σ_y π_s(y) [log π_s(y) - log π_t(y)] + ∂KL/∂z_j = π_s(j) [log π_s(j) - log π_t(j) + 1 - Σ_k π_s(k)(log π_s(k) - log π_t(k) + 1)] + = π_s(j) [log π_s(j) - log π_t(j) - KL] (since Σ_k π_s(k) = 1) + where z_j are the student logits and log π_s is log_softmax(z). + """ + + @staticmethod + def forward( + ctx, + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + process_group: dist.ProcessGroup, + ) -> torch.Tensor: + # --- student softmax (numerically stable, TP-aware) --- + s_max = student_logits.max(dim=-1, keepdim=True).values + dist.all_reduce(s_max, op=dist.ReduceOp.MAX, group=process_group) + s_shifted = student_logits - s_max + s_exp = s_shifted.exp() + s_sum_exp = s_exp.sum(dim=-1, keepdim=True) + dist.all_reduce(s_sum_exp, op=dist.ReduceOp.SUM, group=process_group) + s_softmax = s_exp / s_sum_exp # π_s(y) [R, V_local] + s_log_sum_exp = s_sum_exp.log() # [R, 1] + + # --- teacher log-softmax (numerically stable, TP-aware) --- + t_max = teacher_logits.max(dim=-1, keepdim=True).values + dist.all_reduce(t_max, op=dist.ReduceOp.MAX, group=process_group) + t_shifted = teacher_logits - t_max + t_exp = t_shifted.exp() + t_sum_exp = t_exp.sum(dim=-1, keepdim=True) + dist.all_reduce(t_sum_exp, op=dist.ReduceOp.SUM, group=process_group) + t_log_sum_exp = t_sum_exp.log() # [R, 1] + + # --- KL = Σ_y π_s(y) [log π_s(y) - log π_t(y)] --- + # log π_s(y) = s_shifted - s_log_sum_exp (local slice) + # log π_t(y) = t_shifted - t_log_sum_exp (local slice) + local_s_log_prob = s_shifted - s_log_sum_exp + local_t_log_prob = t_shifted - t_log_sum_exp + + local_kl_sum = (s_softmax * (local_s_log_prob - local_t_log_prob)).sum(dim=-1, keepdim=True) + dist.all_reduce(local_kl_sum, op=dist.ReduceOp.SUM, group=process_group) + kl = local_kl_sum.squeeze(dim=-1) # [R] + + # Save for backward + # We need: s_softmax, local_s_log_prob, local_t_log_prob, and kl (per-token) + ctx.save_for_backward(s_softmax, local_s_log_prob.detach(), local_t_log_prob.detach(), kl.detach()) + ctx.process_group = process_group + return kl + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + s_softmax, local_s_log_prob, local_t_log_prob, kl = ctx.saved_tensors + process_group = ctx.process_group + + # Gradient: ∂KL/∂z_j = π_s(j) * [log π_s(j) - log π_t(j) - KL] + # This is completely local per token — no all_reduce needed in backward. + grad_local = s_softmax * (local_s_log_prob - local_t_log_prob - kl.unsqueeze(-1)) + + grad_input = grad_output.unsqueeze(-1) * grad_local # [R, V_local] + return grad_input, None, None + + +def vocab_parallel_reverse_kl( + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + process_group: dist.ProcessGroup, +) -> torch.Tensor: + """Compute D_KL(π_student ∥ π_teacher) with TP-aware vocab parallelism. + + Both inputs are partial logits along the vocab dimension (each TP rank + holds V/tp_size logits). Returns per-token KL of shape [R]. + + Teacher logits are detached (no gradient flows to the teacher). + + Args: + student_logits: [R, V_local] student logits (with grad). + teacher_logits: [R, V_local] teacher logits (detached). + process_group: TP process group for all-reduce. + + Returns: + Per-token KL divergence tensor of shape [R]. + """ + # Detach teacher logits — we never backprop through the teacher + teacher_logits = teacher_logits.detach() + + tp_size = dist.get_world_size(group=process_group) if process_group is not None else 1 + if tp_size <= 1: + # No TP — simple local computation + student_log_probs = F.log_softmax(student_logits, dim=-1) + student_probs = student_log_probs.exp() + teacher_log_probs = F.log_softmax(teacher_logits, dim=-1) + kl = (student_probs * (student_log_probs - teacher_log_probs)).sum(dim=-1) + return kl + + # TP mode — use custom autograd function with correct backward + return _VocabParallelReverseKL.apply(student_logits, teacher_logits, process_group) + + def get_grpo_returns( rewards: torch.Tensor, kl: list[torch.Tensor], diff --git a/tests/test_mopd_full_vocab.py b/tests/test_mopd_full_vocab.py new file mode 100644 index 0000000000..81a2808954 --- /dev/null +++ b/tests/test_mopd_full_vocab.py @@ -0,0 +1,635 @@ +"""Unit tests for MOPD full_vocab distillation mode. + +Tests cover: +1. vocab_parallel_reverse_kl: correctness of full-vocabulary reverse KL +2. apply_mopd_full_vocab_to_loss: IS weights, multi-teacher averaging, loss combination +3. MOPD full_vocab argument validation (mopd_distill_type parameter) +4. get_logits_for_distill: temperature handling and shape +""" + +import sys +import types +from argparse import Namespace + +import pytest + +torch = pytest.importorskip("torch") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def make_mopd_full_vocab_args(**overrides): + """Create a Namespace with default MOPD full_vocab arguments.""" + defaults = dict( + use_mopd=True, + mopd_distill_type="full_vocab", + mopd_teachers='[{"name": "math_teacher", "domain": "math"}]', + mopd_teacher_loads="/tmp/fake_teacher", + mopd_teacher_ckpt_steps=None, + mopd_alpha=0.0, + mopd_eps_low=0.2, + mopd_eps_high=5.0, + mopd_sampling_logprobs_key="rollout_log_probs", + _mopd_teachers_parsed=[{"name": "math_teacher", "domain": "math"}], + ) + defaults.update(overrides) + return Namespace(**defaults) + + +def mock_megatron(monkeypatch): + """Mock megatron.core.mpu for import.""" + mpu_mod = types.ModuleType("megatron.core") + mpu_sub = types.ModuleType("megatron.core.mpu") + mpu_sub.is_pipeline_last_stage = lambda: True + mpu_sub.get_context_parallel_rank = lambda: 0 + mpu_sub.get_context_parallel_world_size = lambda: 1 + mpu_sub.get_data_parallel_group = lambda: None + mpu_sub.get_data_parallel_rank = lambda: 0 + mpu_sub.get_data_parallel_world_size = lambda: 1 + mpu_sub.get_tensor_model_parallel_rank = lambda: 0 + mpu_sub.get_tensor_model_parallel_world_size = lambda: 1 + mpu_sub.get_tensor_model_parallel_group = lambda: None + + monkeypatch.setitem(sys.modules, "megatron", types.ModuleType("megatron")) + monkeypatch.setitem(sys.modules, "megatron.core", mpu_mod) + monkeypatch.setitem(sys.modules, "megatron.core.mpu", mpu_sub) + + +# --------------------------------------------------------------------------- +# Tests for vocab_parallel_reverse_kl +# --------------------------------------------------------------------------- +class TestVocabParallelReverseKL: + """Test the vocab_parallel_reverse_kl function in ppo_utils.py.""" + + def test_kl_correctness_identical_distributions(self): + """KL(student || teacher) = 0 when distributions are identical.""" + from slime.utils.ppo_utils import vocab_parallel_reverse_kl + + # Identical distributions → KL = 0 + logits = torch.randn(4, 20) # [R=4, V=20] + kl = vocab_parallel_reverse_kl(logits, logits.clone(), process_group=None) + assert kl.shape == (4,) + assert torch.allclose(kl, torch.zeros(4), atol=1e-5), f"KL should be 0 for identical distributions, got {kl}" + + def test_kl_correctness_known_values(self): + """KL computed by vocab_parallel_reverse_kl matches manual computation.""" + from slime.utils.ppo_utils import vocab_parallel_reverse_kl + + torch.manual_seed(42) + student_logits = torch.randn(3, 10, requires_grad=True) + teacher_logits = torch.randn(3, 10) + + # Our function + kl = vocab_parallel_reverse_kl(student_logits, teacher_logits, process_group=None) + + # Manual computation + student_log_probs = torch.log_softmax(student_logits, dim=-1) + student_probs = student_log_probs.exp() + teacher_log_probs = torch.log_softmax(teacher_logits, dim=-1) + expected_kl = (student_probs * (student_log_probs - teacher_log_probs)).sum(dim=-1) + + assert torch.allclose(kl, expected_kl, atol=1e-5), ( + f"KL mismatch: got {kl}, expected {expected_kl}" + ) + + def test_kl_non_negative(self): + """KL divergence should always be non-negative (Gibbs' inequality).""" + from slime.utils.ppo_utils import vocab_parallel_reverse_kl + + torch.manual_seed(123) + for _ in range(10): + student_logits = torch.randn(5, 50) + teacher_logits = torch.randn(5, 50) + kl = vocab_parallel_reverse_kl(student_logits, teacher_logits, process_group=None) + assert (kl >= -1e-5).all(), f"KL should be non-negative, got {kl.min()}" + + def test_kl_gradient_flows_through_student(self): + """Gradient flows through student logits but not teacher logits.""" + from slime.utils.ppo_utils import vocab_parallel_reverse_kl + + torch.manual_seed(42) + student_logits = torch.randn(3, 10, requires_grad=True) + teacher_logits = torch.randn(3, 10, requires_grad=True) + + kl = vocab_parallel_reverse_kl(student_logits, teacher_logits, process_group=None) + loss = kl.sum() + loss.backward() + + # Student should have gradients + assert student_logits.grad is not None, "student_logits should have gradients" + assert not torch.allclose(student_logits.grad, torch.zeros_like(student_logits.grad)), ( + "student_logits gradients should be non-zero" + ) + + # Teacher should NOT have gradients (detached inside function) + assert teacher_logits.grad is None or torch.allclose(teacher_logits.grad, torch.zeros_like(teacher_logits.grad)), ( + "teacher_logits should not have gradients (should be detached)" + ) + + def test_kl_gradient_correctness(self): + """Verify the gradient of KL matches autograd from manual computation.""" + from slime.utils.ppo_utils import vocab_parallel_reverse_kl + + torch.manual_seed(42) + student_logits_1 = torch.randn(3, 10, requires_grad=True) + teacher_logits = torch.randn(3, 10) + + # Our function + kl_1 = vocab_parallel_reverse_kl(student_logits_1, teacher_logits, process_group=None) + loss_1 = kl_1.sum() + loss_1.backward() + grad_ours = student_logits_1.grad.clone() + + # Manual computation for gradient comparison + student_logits_2 = student_logits_1.detach().clone().requires_grad_(True) + student_log_probs = torch.log_softmax(student_logits_2, dim=-1) + student_probs = student_log_probs.exp() + teacher_log_probs = torch.log_softmax(teacher_logits, dim=-1) + kl_2 = (student_probs * (student_log_probs - teacher_log_probs)).sum(dim=-1) + loss_2 = kl_2.sum() + loss_2.backward() + grad_manual = student_logits_2.grad.clone() + + assert torch.allclose(grad_ours, grad_manual, atol=1e-4), ( + f"Gradient mismatch: max diff = {(grad_ours - grad_manual).abs().max()}" + ) + + def test_kl_temperature_sensitivity(self): + """KL should change when student distribution changes.""" + from slime.utils.ppo_utils import vocab_parallel_reverse_kl + + teacher_logits = torch.randn(3, 10) + student_logits_1 = torch.randn(3, 10) + student_logits_2 = torch.randn(3, 10) + + kl_1 = vocab_parallel_reverse_kl(student_logits_1, teacher_logits, process_group=None) + kl_2 = vocab_parallel_reverse_kl(student_logits_2, teacher_logits, process_group=None) + + # Different student distributions should give different KL values + assert not torch.allclose(kl_1, kl_2, atol=1e-5), "Different student logits should give different KL" + + def test_kl_large_vocabulary(self): + """Test with a larger vocabulary to verify numerical stability.""" + from slime.utils.ppo_utils import vocab_parallel_reverse_kl + + torch.manual_seed(42) + student_logits = torch.randn(8, 32000) * 0.1 # Small scale for stability + teacher_logits = torch.randn(8, 32000) * 0.1 + + kl = vocab_parallel_reverse_kl(student_logits, teacher_logits, process_group=None) + assert kl.shape == (8,) + assert torch.isfinite(kl).all(), f"KL should be finite, got nan/inf: {kl}" + assert (kl >= -1e-4).all(), f"KL should be non-negative, got min={kl.min()}" + + +# --------------------------------------------------------------------------- +# Tests for apply_mopd_full_vocab_to_loss +# --------------------------------------------------------------------------- +class TestApplyMopdFullVocabToLoss: + """Test the apply_mopd_full_vocab_to_loss function.""" + + @pytest.fixture(autouse=True) + def _mock_deps(self, monkeypatch): + mock_megatron(monkeypatch) + + def _get_function(self): + from slime.backends.megatron_utils.loss import apply_mopd_full_vocab_to_loss + return apply_mopd_full_vocab_to_loss + + def _sum_of_sample_mean(self, tensor): + """Simple mean reduction for testing.""" + return tensor.mean() + + def test_single_teacher_kl_loss(self): + """Test single-teacher full-vocab KL loss computation.""" + apply_fn = self._get_function() + args = make_mopd_full_vocab_args(mopd_eps_low=0.0, mopd_eps_high=1000.0) + torch.manual_seed(42) + + # 2 samples, vocab_size=10, response_length=3 + V = 10 + student_logits_1 = torch.randn(3, V) + student_logits_2 = torch.randn(4, V) + teacher_logits_1 = torch.randn(3, V) + teacher_logits_2 = torch.randn(4, V) + + student_logits = [student_logits_1, student_logits_2] + teacher_logits_per_domain = { + "math": [teacher_logits_1, teacher_logits_2], + } + + # When sampling_log_probs == current_log_probs (on-policy), IS weight = 1.0 + batch = { + "rollout_log_probs": [torch.zeros(3), torch.zeros(4)], + } + current_log_probs = [torch.zeros(3), torch.zeros(4)] + loss_masks = [torch.ones(3), torch.ones(4)] + + kl_loss, metrics = apply_fn( + args, batch, student_logits, teacher_logits_per_domain, + loss_masks, self._sum_of_sample_mean, + current_log_probs=current_log_probs, + ) + + assert kl_loss.shape == (), "kl_loss should be scalar" + assert kl_loss.item() >= 0, "KL loss should be non-negative" + assert "mopd_fv_kl" in metrics + assert "mopd_is_weight_mean" in metrics + assert "mopd_is_nonzero_frac" in metrics + assert "mopd_fv_kl/math" in metrics + + def test_identical_student_teacher_zero_kl(self): + """When student == teacher, KL should be ~0 and loss should be ~0.""" + apply_fn = self._get_function() + args = make_mopd_full_vocab_args(mopd_eps_low=0.0, mopd_eps_high=1000.0) + + V = 10 + student_logits_1 = torch.randn(3, V) + student_logits_2 = torch.randn(4, V) + + # Teacher = Student → KL = 0 + teacher_logits_per_domain = { + "math": [student_logits_1.clone(), student_logits_2.clone()], + } + + batch = { + "rollout_log_probs": [torch.zeros(3), torch.zeros(4)], + } + current_log_probs = [torch.zeros(3), torch.zeros(4)] + loss_masks = [torch.ones(3), torch.ones(4)] + + kl_loss, metrics = apply_fn( + args, batch, [student_logits_1, student_logits_2], + teacher_logits_per_domain, loss_masks, self._sum_of_sample_mean, + current_log_probs=current_log_probs, + ) + + assert kl_loss.item() < 1e-4, f"KL loss should be ~0 for identical distributions, got {kl_loss.item()}" + + def test_multi_teacher_averaging(self): + """Test that KL is averaged across multiple teachers.""" + apply_fn = self._get_function() + args = make_mopd_full_vocab_args( + mopd_eps_low=0.0, mopd_eps_high=1000.0, + _mopd_teachers_parsed=[ + {"name": "math_teacher", "domain": "math"}, + {"name": "code_teacher", "domain": "code"}, + ], + ) + torch.manual_seed(42) + + V = 10 + student_logits = [torch.randn(3, V)] + teacher_math = [torch.randn(3, V)] + teacher_code = [torch.randn(3, V)] + + teacher_logits_per_domain = { + "math": teacher_math, + "code": teacher_code, + } + + batch = { + "rollout_log_probs": [torch.zeros(3)], + } + current_log_probs = [torch.zeros(3)] + loss_masks = [torch.ones(3)] + + from slime.utils.ppo_utils import vocab_parallel_reverse_kl + kl_math = vocab_parallel_reverse_kl(student_logits[0], teacher_math[0], None) + kl_code = vocab_parallel_reverse_kl(student_logits[0], teacher_code[0], None) + expected_avg_kl = (kl_math.sum() / 3 + kl_code.sum() / 4) / 2 # Not exact, just check shape + + kl_loss, metrics = apply_fn( + args, batch, student_logits, teacher_logits_per_domain, + loss_masks, self._sum_of_sample_mean, + current_log_probs=current_log_probs, + ) + + # Should have per-domain logging + assert "mopd_fv_kl/math" in metrics + assert "mopd_fv_kl/code" in metrics + # Both should be non-negative + assert metrics["mopd_fv_kl/math"].item() >= -1e-5 + assert metrics["mopd_fv_kl/code"].item() >= -1e-5 + + def test_is_weight_clipping(self): + """Test that IS weights are clipped to [eps_low, eps_high].""" + apply_fn = self._get_function() + args = make_mopd_full_vocab_args(mopd_eps_low=0.5, mopd_eps_high=2.0) + + V = 10 + student_logits = [torch.randn(3, V)] + teacher_logits_per_domain = {"math": [torch.randn(3, V)]} + + # IS weight = exp(current_log_probs - rollout_log_probs) + # For token 0: exp(-5 - 0) = exp(-5) ≈ 0.0067 < eps_low → zeroed + # For token 1: exp(5 - 0) = exp(5) ≈ 148 > eps_high → zeroed + # For token 2: exp(0 - 0) = 1.0 → kept + batch = { + "rollout_log_probs": [torch.tensor([0.0, 0.0, 0.0])], + } + current_log_probs = [torch.tensor([-5.0, 5.0, 0.0])] + loss_masks = [torch.ones(3)] + + kl_loss, metrics = apply_fn( + args, batch, student_logits, teacher_logits_per_domain, + loss_masks, self._sum_of_sample_mean, + current_log_probs=current_log_probs, + ) + + # IS weight should be clipped — only token 2 (weight=1.0) survives + # Nonzero fraction should be 1/3 + is_nonzero_frac = metrics["mopd_is_nonzero_frac"].item() + assert abs(is_nonzero_frac - 1.0 / 3.0) < 0.05, ( + f"Expected ~1/3 nonzero IS weight fraction, got {is_nonzero_frac}" + ) + + def test_none_teacher_for_sample(self): + """Test that None entries in teacher logits are skipped.""" + apply_fn = self._get_function() + + # Two samples, two teachers; sample 0 has only math, sample 1 has both + args = make_mopd_full_vocab_args( + mopd_eps_low=0.0, mopd_eps_high=1000.0, + _mopd_teachers_parsed=[ + {"name": "math_teacher", "domain": "math"}, + {"name": "code_teacher", "domain": "code"}, + ], + ) + + V = 10 + student_0 = torch.randn(3, V) + student_1 = torch.randn(4, V) + + teacher_logits_per_domain = { + "math": [torch.randn(3, V), torch.randn(4, V)], + "code": [None, torch.randn(4, V)], # sample 0 has no code teacher + } + + batch = { + "rollout_log_probs": [torch.zeros(3), torch.zeros(4)], + } + current_log_probs = [torch.zeros(3), torch.zeros(4)] + loss_masks = [torch.ones(3), torch.ones(4)] + + kl_loss, metrics = apply_fn( + args, batch, [student_0, student_1], + teacher_logits_per_domain, loss_masks, self._sum_of_sample_mean, + current_log_probs=current_log_probs, + ) + + assert kl_loss.shape == () + assert torch.isfinite(kl_loss), "Loss should be finite with None teacher entries" + + def test_loss_mask_effect(self): + """Test that loss_mask correctly masks out tokens.""" + apply_fn = self._get_function() + args = make_mopd_full_vocab_args(mopd_eps_low=0.0, mopd_eps_high=1000.0) + + V = 10 + # Same student and teacher but with masking + student_logits = [torch.randn(5, V)] + teacher_logits_per_domain = {"math": [torch.randn(5, V)]} + + # Only tokens 1-3 are valid (mask out 0 and 4) + loss_masks = [torch.tensor([0.0, 1.0, 1.0, 1.0, 0.0])] + + batch = { + "rollout_log_probs": [torch.zeros(5)], + } + current_log_probs = [torch.zeros(5)] + + kl_loss_masked, _ = apply_fn( + args, batch, student_logits, teacher_logits_per_domain, + loss_masks, self._sum_of_sample_mean, + current_log_probs=current_log_probs, + ) + + # With all-ones mask for comparison + loss_masks_all = [torch.ones(5)] + kl_loss_all, _ = apply_fn( + args, batch, student_logits, teacher_logits_per_domain, + loss_masks_all, self._sum_of_sample_mean, + current_log_probs=current_log_probs, + ) + + # The losses should be different since masking excludes tokens + # Both should be non-negative + assert kl_loss_masked.item() >= 0 + assert kl_loss_all.item() >= 0 + + def test_current_log_probs_used_for_is_weights(self): + """Test that current_log_probs (not batch['log_probs']) are used for IS weights.""" + apply_fn = self._get_function() + # Use tight clipping to detect which log_probs are used + args = make_mopd_full_vocab_args(mopd_eps_low=0.5, mopd_eps_high=2.0) + + V = 10 + student_logits = [torch.randn(3, V)] + teacher_logits_per_domain = {"math": [torch.randn(3, V)]} + + # rollout_log_probs = [0, 0, 0] + # current_log_probs = [-5, 0, 5] → IS weights: exp(-5)≈0.007, 1.0, exp(5)≈148 + # With current_log_probs: tokens 0 and 2 are zeroed out (outside [0.5, 2.0]) + # batch['log_probs'] = [0, 0, 0] → all IS weights = 1.0 (within [0.5, 2.0]) + batch = { + "rollout_log_probs": [torch.tensor([0.0, 0.0, 0.0])], + # This is stale batch["log_probs"]; should NOT be used when current_log_probs is provided + "log_probs": [torch.tensor([0.0, 0.0, 0.0])], + } + current_log_probs = [torch.tensor([-5.0, 0.0, 5.0])] + loss_masks = [torch.ones(3)] + + kl_loss, metrics = apply_fn( + args, batch, student_logits, teacher_logits_per_domain, + loss_masks, self._sum_of_sample_mean, + current_log_probs=current_log_probs, + ) + + # Only token 1 should survive IS weight clipping → 1/3 nonzero + is_nonzero_frac = metrics["mopd_is_nonzero_frac"].item() + assert abs(is_nonzero_frac - 1.0 / 3.0) < 0.05, ( + f"Expected ~1/3 nonzero IS weight fraction with current_log_probs, got {is_nonzero_frac}" + ) + + def test_current_log_probs_length_mismatch_raises(self): + """Test that mismatched current_log_probs length raises ValueError.""" + apply_fn = self._get_function() + args = make_mopd_full_vocab_args(mopd_eps_low=0.0, mopd_eps_high=1000.0) + + V = 10 + student_logits = [torch.randn(3, V), torch.randn(4, V)] + teacher_logits_per_domain = {"math": [torch.randn(3, V), torch.randn(4, V)]} + batch = {"rollout_log_probs": [torch.zeros(3), torch.zeros(4)]} + loss_masks = [torch.ones(3), torch.ones(4)] + + # Mismatch: 2 samples but only 1 log_probs entry + bad_current_log_probs = [torch.zeros(3)] + + with pytest.raises(ValueError, match="student_log_probs length"): + apply_fn( + args, batch, student_logits, teacher_logits_per_domain, + loss_masks, self._sum_of_sample_mean, + current_log_probs=bad_current_log_probs, + ) + + +# --------------------------------------------------------------------------- +# Tests for mopd_distill_type argument validation +# --------------------------------------------------------------------------- +class TestMopdDistillTypeValidation: + """Test --mopd-distill-type parameter validation.""" + + @pytest.fixture(autouse=True) + def _mock_deps(self, monkeypatch): + """Mock megatron and other dependencies.""" + megatron_mod = types.ModuleType("megatron") + training_mod = types.ModuleType("megatron.training") + arguments_mod = types.ModuleType("megatron.training.arguments") + arguments_mod.parse_args = lambda *a, **kw: None + arguments_mod.validate_args = lambda a: a + tokenizer_pkg_mod = types.ModuleType("megatron.training.tokenizer") + tokenizer_mod = types.ModuleType("megatron.training.tokenizer.tokenizer") + tokenizer_mod._vocab_size_with_padding = lambda vocab_size, _args: vocab_size + transformers_mod = types.ModuleType("transformers") + transformers_mod.AutoConfig = types.SimpleNamespace(from_pretrained=lambda *a, **kw: None) + + monkeypatch.setitem(sys.modules, "megatron", megatron_mod) + monkeypatch.setitem(sys.modules, "megatron.training", training_mod) + monkeypatch.setitem(sys.modules, "megatron.training.arguments", arguments_mod) + monkeypatch.setitem(sys.modules, "megatron.training.tokenizer", tokenizer_pkg_mod) + monkeypatch.setitem(sys.modules, "megatron.training.tokenizer.tokenizer", tokenizer_mod) + monkeypatch.setitem(sys.modules, "transformers", transformers_mod) + + def _make_base_args(self, **overrides): + defaults = dict( + use_opd=False, + opd_type=None, + opd_kl_coef=1.0, + opd_teacher_load=None, + opd_teacher_ckpt_step=None, + use_mopd=False, + mopd_teachers=None, + mopd_teacher_loads=None, + mopd_teacher_ckpt_steps=None, + mopd_alpha=0.0, + mopd_eps_low=0.2, + mopd_eps_high=5.0, + mopd_sampling_logprobs_key="rollout_log_probs", + mopd_distill_type="token_level", + enable_weights_backuper=True, + eval_datasets=[], + eval_prompt_data=None, + kl_coef=0, + ref_load="/tmp/fake_ref", + use_kl_loss=False, + use_critic=False, + rm_type=None, + custom_rm_path=None, + ) + defaults.update(overrides) + return Namespace(**defaults) + + def test_full_vocab_without_teacher_loads_raises(self): + """Test that full_vocab mode without --mopd-teacher-loads raises ValueError.""" + from slime.utils.arguments import slime_validate_args + + args = self._make_base_args( + use_mopd=True, + mopd_distill_type="full_vocab", + mopd_teachers='[{"name": "t1", "domain": "math"}]', + mopd_teacher_loads=None, + ) + with pytest.raises(ValueError, match="full_vocab.*mopd-teacher-loads|megatron teacher"): + slime_validate_args(args) + + def test_token_level_is_default(self): + """Test that token_level is the default distill type.""" + from slime.utils.arguments import slime_validate_args + + args = self._make_base_args( + use_mopd=True, + mopd_teachers='[{"name": "t1", "domain": "math"}]', + mopd_alpha=0.0, + ) + slime_validate_args(args) + assert args.mopd_distill_type == "token_level" + + def test_full_vocab_with_teacher_loads_ok(self): + """Test that full_vocab mode with --mopd-teacher-loads does not raise.""" + from slime.utils.arguments import slime_validate_args, tmp_path + + ckpt_dir = tmp_path / "teacher1" + ckpt_dir.mkdir() + (ckpt_dir / "latest_checkpointed_iteration.txt").write_text("1") + + args = self._make_base_args( + use_mopd=True, + mopd_distill_type="full_vocab", + mopd_teachers='[{"name": "t1", "domain": "math"}]', + mopd_teacher_loads=[str(ckpt_dir)], + mopd_alpha=0.0, + ) + slime_validate_args(args) # Should not raise + + +# --------------------------------------------------------------------------- +# Tests for get_logits_for_distill temperature handling +# --------------------------------------------------------------------------- +class TestGetLogitsForDistill: + """Test get_logits_for_distill returns raw logits without temperature scaling.""" + + @pytest.fixture(autouse=True) + def _mock_deps(self, monkeypatch): + mock_megatron(monkeypatch) + + def test_no_temperature_scaling(self): + """Verify that get_logits_for_distill returns raw logits (no temperature).""" + from slime.backends.megatron_utils.loss import get_logits_for_distill + + torch.manual_seed(42) + V = 10 + T = 6 # total sequence length (prompt + response) + R = 3 # response length + + # Create fake logits [1, T, V] + logits = torch.randn(1, T, V) + + # Create args with rollout_temperature != 1.0 + args = Namespace( + qkv_format="bshd", + rollout_temperature=2.0, # temperature scaling + allgather_cp=False, + log_probs_chunk_size=-1, + ) + + # Create fake tokens and length info + tokens = torch.randint(0, V, (1, T)) + unconcat_tokens = [tokens[0]] + total_lengths = [T] + response_lengths = [R] + max_seq_lens = [T] + + result_tensor, result_dict = get_logits_for_distill( + logits, + args=args, + unconcat_tokens=unconcat_tokens, + total_lengths=total_lengths, + response_lengths=response_lengths, + max_seq_lens=max_seq_lens, + ) + + # Should return raw logits (no temperature scaling) + assert "logits" in result_dict + logits_out = result_dict["logits"] + assert len(logits_out) == 1 # 1 sample + assert logits_out[0].shape == (R, V), f"Expected shape ({R}, {V}), got {logits_out[0].shape}" + + # The returned logits should NOT be divided by temperature + # Compare with manual extraction of response logits from the input + # Response logits: logits[0, T-R-1:T-1, :] (shifted by 1 for next-token prediction) + expected_logits = logits[0, T - R - 1 : T - 1, :] # [R, V] + assert torch.allclose(logits_out[0], expected_logits, atol=1e-5), ( + "get_logits_for_distill should return raw logits without temperature scaling" + ) \ No newline at end of file From cb18b2a4d43770d8531ce009021fc4aac076487e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=87=95=E4=B8=80?= Date: Wed, 10 Jun 2026 15:19:35 +0800 Subject: [PATCH 03/14] feat: add TopK distillation strategy for MOPD - Implement TopK token selection for efficient distillation loss computation - Add TopK-related arguments (topk_tokens, topk_temperature) - Add 397B model startup script (scripts/models/qwen3.5-397B-A17B.sh) - Extend ppo_utils with TopK logits extraction and processing - Update full-vocab megatron script with TopK options - Extend tests for TopK distillation mode --- ...qwen35-35B-A3B-mopd-full-vocab-megatron.sh | 9 + scripts/models/qwen3.5-397B-A17B.sh | 61 +++ slime/backends/megatron_utils/actor.py | 33 +- slime/backends/megatron_utils/data.py | 3 + slime/backends/megatron_utils/loss.py | 244 ++++++++++- slime/backends/megatron_utils/model.py | 12 + slime/utils/arguments.py | 28 +- slime/utils/ppo_utils.py | 157 +++++++ tests/test_mopd_full_vocab.py | 408 +++++++++++++++++- 9 files changed, 923 insertions(+), 32 deletions(-) create mode 100644 scripts/models/qwen3.5-397B-A17B.sh diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-full-vocab-megatron.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-full-vocab-megatron.sh index fcc73e2c9e..6e2e6b17e0 100644 --- a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-full-vocab-megatron.sh +++ b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-full-vocab-megatron.sh @@ -130,6 +130,15 @@ PERF_ARGS=( # # 5. IS weight correction still applies (same as token_level mode) # +# Alternative: Use top_k mode for memory-efficient approximate KL: +# Replace --mopd-distill-type full_vocab with: +# --mopd-distill-type top_k +# --mopd-topk-k 1024 +# This stores only [R_i, k] teacher logits+indices per sample (k=1024 by default), +# plus a tail probability correction. Memory per sample ≈ k*5B per token +# (vs V*4B for full_vocab). For k=1024, V=248320: ~98.7% memory reduction. +# Teacher logits per GPU ≈ 4×4096×1024×(4+4)B ≈ 128MB (negligible vs full_vocab). +# # For this connectivity test, the teacher IS the same model (self-distillation). MOPD_ARGS=( --advantage-estimator grpo diff --git a/scripts/models/qwen3.5-397B-A17B.sh b/scripts/models/qwen3.5-397B-A17B.sh new file mode 100644 index 0000000000..7da83a9a53 --- /dev/null +++ b/scripts/models/qwen3.5-397B-A17B.sh @@ -0,0 +1,61 @@ +# Qwen3.5-397B-A17B (MoE, 512 experts, 10 active) +# VLM model with linear_attention + full_attention hybrid layers + +NLAYERS=60 +FIRST_K_DENSE_REPLACE=0 + +arr=() +for ((i=0; i 1: raise NotImplementedError( - "MOPD full_vocab (get_logits_for_distill) does not support " + "MOPD full_vocab/top_k (get_logits_for_distill) does not support " "allgather-CP with context_parallel_size > 1. The CP redistribution " "logic assumes 1D tensors but logits are 2D [R, V]. Please disable " - "allgather_cp or set context_parallel_size=1 when using full_vocab mode." + "allgather_cp or set context_parallel_size=1 when using full_vocab/top_k mode." ) _allgather_cp_redistribute( res, @@ -1207,6 +1208,167 @@ def apply_mopd_full_vocab_to_loss( return kl_loss, metrics +def apply_mopd_topk_to_loss( + args: Namespace, + batch: RolloutBatch, + student_logits_per_sample: list[torch.Tensor], + teacher_topk_logits_per_domain: dict[str, list[torch.Tensor | None]], + teacher_topk_indices_per_domain: dict[str, list[torch.Tensor | None]], + loss_masks: list[torch.Tensor], + sum_of_sample_mean: Callable[[torch.Tensor], torch.Tensor], + current_log_probs: list[torch.Tensor] | None = None, +) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + """Compute the top-k approximate reverse KL divergence loss for MOPD. + + Instead of computing the exact full-vocab KL (which stores [R, V] per sample), + this uses teacher's top-k logits plus a tail probability correction: + + KL ≈ KL_topk + KL_tail + + where: + KL_topk = Σ_{y ∈ top-k} π_s(y) [log π_s(y) - log π_t(y)] + KL_tail ≈ π_s_tail * log(π_s_tail / π_t_tail) + + Teacher provides pre-computed top-k logits and indices, while student has + full logits. This reduces memory from O(R*V) to O(R*k) per sample per teacher. + + Args: + args: Configuration containing MOPD parameters including mopd_topk_k. + batch: Mini-batch containing IS weight data and loss_masks. + student_logits_per_sample: List of per-sample student logits [R_i, V_local]. + teacher_topk_logits_per_domain: Dict mapping domain to list of per-sample + teacher top-k logits [R_i, k], with None for samples not routed to that domain. + teacher_topk_indices_per_domain: Dict mapping domain to list of per-sample + teacher top-k LOCAL indices [R_i, k] (within TP shard), with None for + samples not routed to that domain. + loss_masks: List of per-sample loss masks. + sum_of_sample_mean: Reduction function for averaging. + current_log_probs: List of per-sample log-probs from the current training + forward pass. Used for importance sampling weight computation. + If None, falls back to batch["log_probs"]. + + Returns: + Tuple of (kl_loss, metrics) where kl_loss is a scalar tensor and + metrics is a dict with logging tensors. + """ + # Get sampling log-probs μ_θ for importance sampling weight + sampling_logprobs_key = getattr(args, "mopd_sampling_logprobs_key", "rollout_log_probs") + sampling_log_probs = batch.get(sampling_logprobs_key) + if sampling_logprobs_key == "rollout_log_probs" and sampling_log_probs is None: + sampling_log_probs = batch.get("log_probs") + if sampling_log_probs is None: + raise ValueError( + f"MOPD top_k requires '{sampling_logprobs_key}' in batch for importance sampling." + ) + + vocab_size = args.padded_vocab_size + num_samples = len(student_logits_per_sample) + if len(sampling_log_probs) != num_samples: + raise ValueError( + f"MOPD top_k: sampling_log_probs length ({len(sampling_log_probs)}) " + f"!= student_logits length ({num_samples})." + ) + + tp_group = mpu.get_tensor_model_parallel_group() + all_kl_per_token = [] + per_domain_kls: dict[str, list[torch.Tensor]] = {} + + for i in range(num_samples): + R_i = student_logits_per_sample[i].size(0) + sample_kl_values = [] + valid_teacher_count = 0 + + for domain in teacher_topk_logits_per_domain: + if ( + i >= len(teacher_topk_logits_per_domain[domain]) + or teacher_topk_logits_per_domain[domain][i] is None + ): + continue + + t_topk_logits = teacher_topk_logits_per_domain[domain][i] # [R_i, k] + t_topk_indices = teacher_topk_indices_per_domain[domain][i] # [R_i, k] + + kl_i = vocab_parallel_topk_reverse_kl( + student_logits_per_sample[i], + t_topk_logits, + t_topk_indices, + vocab_size, + tp_group, + ) # [R_i] + sample_kl_values.append(kl_i) + valid_teacher_count += 1 + + if domain not in per_domain_kls: + per_domain_kls[domain] = [] + per_domain_kls[domain].append(kl_i.detach()) + + if valid_teacher_count > 0: + avg_kl_i = sum(sample_kl_values) / valid_teacher_count + else: + avg_kl_i = torch.zeros(R_i, device=student_logits_per_sample[i].device) + + all_kl_per_token.append(avg_kl_i) + + # Compute IS weights (identical logic to full_vocab) + student_log_probs_at_sampled = current_log_probs if current_log_probs is not None else batch.get("log_probs") + if student_log_probs_at_sampled is not None and len(student_log_probs_at_sampled) != num_samples: + raise ValueError( + f"MOPD top_k: student_log_probs length ({len(student_log_probs_at_sampled)}) " + f"!= student_logits length ({num_samples})." + ) + + is_weight_per_sample = [] + for i in range(num_samples): + with torch.no_grad(): + if student_log_probs_at_sampled is not None: + s_lp_i = student_log_probs_at_sampled[i].to(device=student_logits_per_sample[i].device) + else: + s_lp_i = torch.zeros( + student_logits_per_sample[i].size(0), + device=student_logits_per_sample[i].device, + ) + samp_lp_i = sampling_log_probs[i].to(device=s_lp_i.device) + is_w_i = torch.exp(s_lp_i - samp_lp_i) + is_w_i = torch.where( + (is_w_i >= args.mopd_eps_low) & (is_w_i <= args.mopd_eps_high), + is_w_i, + torch.zeros_like(is_w_i), + ) + is_weight_per_sample.append(is_w_i) + + # Apply IS weights to KL + weighted_kl_tokens = [] + for i in range(num_samples): + mask_i = loss_masks[i].to(device=all_kl_per_token[i].device) + weighted_kl_i = all_kl_per_token[i] * is_weight_per_sample[i] * mask_i + R_i = mask_i.sum().clamp(min=1) + weighted_kl_tokens.append(weighted_kl_i.sum() / R_i) + + if len(weighted_kl_tokens) > 0: + kl_loss = torch.stack(weighted_kl_tokens).mean() + else: + kl_loss = torch.tensor(0.0, device=student_logits_per_sample[0].device) + + # Logging metrics + all_kl_cat = torch.cat(all_kl_per_token, dim=0) + kl_mean = sum_of_sample_mean(all_kl_cat) + + is_weights_cat = torch.cat(is_weight_per_sample, dim=0) + is_weight_mean = sum_of_sample_mean(is_weights_cat) + is_nonzero_frac = sum_of_sample_mean((is_weights_cat != 0).float()) + + metrics = { + "mopd_topk_kl": kl_mean.clone().detach(), + "mopd_is_weight_mean": is_weight_mean.clone().detach(), + "mopd_is_nonzero_frac": is_nonzero_frac.clone().detach(), + } + + for domain, domain_kls in per_domain_kls.items(): + metrics[f"mopd_topk_kl/{domain}"] = sum_of_sample_mean(torch.cat(domain_kls, dim=0)).clone().detach() + + return kl_loss, metrics + + def policy_loss_function( args: Namespace, batch: RolloutBatch, @@ -1315,16 +1477,18 @@ def policy_loss_function( # Apply MOPD token_level: replace advantages with mopd_advantages and apply IS weights # L_MOPD(θ) = -E[1/|y| Σ_t w_t * Â_MOPD,t * log π_θ(y_t|x,y_ 0: L = L_full_vocab_kl + alpha * L_pg (ORM policy gradient) - if use_mopd_full_vocab and "_mopd_fv_kl_loss" in batch: - fv_kl_loss = batch.pop("_mopd_fv_kl_loss") + # MOPD logits-based distillation (full_vocab / top_k): combine KL loss with pg loss + # L = L_kl + alpha * L_pg + # When alpha == 0 (pure distillation): L = L_kl (pg_loss is zeroed out) + # When alpha > 0: L = L_kl + alpha * L_pg (ORM policy gradient) + if use_mopd_logits_based and "_mopd_fv_kl_loss" in batch: + kl_distill_loss = batch.pop("_mopd_fv_kl_loss") if args.mopd_alpha > 0: - # Combine: full-vocab KL + alpha * policy gradient loss - loss = fv_kl_loss + args.mopd_alpha * loss + # Combine: distillation KL + alpha * policy gradient loss + loss = kl_distill_loss + args.mopd_alpha * loss else: - # Pure distillation: only use full-vocab KL loss - loss = fv_kl_loss + # Pure distillation: only use distillation KL loss + loss = kl_distill_loss if args.use_kl_loss: ref_log_probs = batch["ref_log_probs"] @@ -1523,8 +1729,8 @@ def policy_loss_function( mopd_advantages = torch.cat(batch["mopd_advantages"], dim=0) reported_loss["mopd_advantage_mean"] = sum_of_sample_mean(mopd_advantages).clone().detach() - # Log MOPD full_vocab metrics - if use_mopd_full_vocab: + # Log MOPD logits-based distillation metrics (full_vocab / top_k) + if use_mopd_logits_based: for key, value in mopd_fv_metrics.items(): reported_loss[key] = value diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index d4b9e58e9d..931f413531 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -491,6 +491,18 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p logits_key = f"mopd_teacher_{domain}_fv_logits" batch_keys.append(logits_key) + # Add MOPD top-k teacher logits/indices keys if present + use_mopd_top_k = ( + getattr(args, "use_mopd", False) and getattr(args, "mopd_distill_type", "token_level") == "top_k" + ) + if use_mopd_top_k and hasattr(args, "_mopd_teachers_parsed"): + for teacher_cfg in args._mopd_teachers_parsed: + domain = teacher_cfg["domain"] + topk_logits_key = f"mopd_teacher_{domain}_topk_logits" + topk_indices_key = f"mopd_teacher_{domain}_topk_indices" + batch_keys.append(topk_logits_key) + batch_keys.append(topk_indices_key) + batch = get_batch( data_iterator, batch_keys, diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 7e628d020c..3391b60fb8 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -1101,7 +1101,7 @@ def add_on_policy_distillation_arguments(parser): parser.add_argument( "--mopd-distill-type", type=str, - choices=["token_level", "full_vocab"], + choices=["token_level", "full_vocab", "top_k"], default="token_level", help=( "MOPD distillation type. " @@ -1111,7 +1111,21 @@ def add_on_policy_distillation_arguments(parser): "using complete logits from both student and teacher models. This is only supported with " "megatron teacher mode (--mopd-teacher-loads), as it requires access to teacher logits " "during training. The full-vocab KL loss is computed directly in the loss function rather " - "than through advantage modification." + "than through advantage modification. " + "'top_k': compute an approximate reverse KL divergence using the top-k teacher logits " + "plus tail probability correction. Stores only [R, k] logits+indices per sample " + "(k controlled by --mopd-topk-k, default 1024), greatly reducing memory compared to " + "full_vocab while being more accurate than token_level. Requires --mopd-teacher-loads." + ), + ) + parser.add_argument( + "--mopd-topk-k", + type=int, + default=1024, + help=( + "Number of top-k tokens to keep per position for MOPD top_k distillation. " + "Only used when --mopd-distill-type=top_k. Higher k gives more accurate KL " + "approximation at the cost of more memory. Default: 1024." ), ) return parser @@ -1818,15 +1832,19 @@ def slime_validate_args(args): f"--mopd-eps-high ({args.mopd_eps_high}) must be > --mopd-eps-low ({args.mopd_eps_low})." ) - # Validate mopd_distill_type: full_vocab mode requires megatron teachers - if args.mopd_distill_type == "full_vocab": + # Validate mopd_distill_type: full_vocab and top_k modes require megatron teachers + if args.mopd_distill_type in ("full_vocab", "top_k"): if args.mopd_teacher_loads is None: raise ValueError( - "--mopd-distill-type=full_vocab requires --mopd-teacher-loads (megatron teacher mode). " + f"--mopd-distill-type={args.mopd_distill_type} requires --mopd-teacher-loads (megatron teacher mode). " "SGLang-based teachers cannot return full-vocabulary logits efficiently. " "Please provide teacher checkpoints via --mopd-teacher-loads." ) + # Validate mopd_topk_k + if args.mopd_distill_type == "top_k" and args.mopd_topk_k <= 0: + raise ValueError(f"--mopd-topk-k must be > 0, got {args.mopd_topk_k}.") + # MOPD with megatron-based teachers requires weights_backuper (to backup multiple models) if args.mopd_teacher_loads is not None and not args.enable_weights_backuper: raise ValueError( diff --git a/slime/utils/ppo_utils.py b/slime/utils/ppo_utils.py index 4b18601c88..6fc1bef005 100644 --- a/slime/utils/ppo_utils.py +++ b/slime/utils/ppo_utils.py @@ -310,6 +310,163 @@ def vocab_parallel_reverse_kl( return _VocabParallelReverseKL.apply(student_logits, teacher_logits, process_group) +def vocab_parallel_topk_reverse_kl( + student_logits: torch.Tensor, + teacher_topk_logits: torch.Tensor, + teacher_topk_indices: torch.Tensor, + vocab_size: int, + process_group: dist.ProcessGroup, +) -> torch.Tensor: + """Compute approximate D_KL(π_student ∥ π_teacher) using top-k teacher logits plus tail correction. + + This is a memory-efficient alternative to full-vocab KL. The teacher provides only + its top-k logits and indices (pre-computed during the teacher forward pass), while + the student has full vocab logits. + + The KL is decomposed into: + KL = KL_topk + KL_tail + where: + KL_topk = Σ_{y ∈ topk} π_s(y) [log π_s(y) - log π_t(y)] + KL_tail ≈ π_s_tail * log(π_s_tail / π_t_tail) + + For TP: teacher_topk_indices are LOCAL indices within each TP shard. We gather + the student logits at those local positions directly (no cross-shard indexing needed). + + Args: + student_logits: [R, V_local] student logits (with grad), vocab-sharded across TP. + teacher_topk_logits: [R, k] teacher top-k logits (detached), fp32. + teacher_topk_indices: [R, k] teacher top-k LOCAL indices within each TP shard, int. + vocab_size: Full (unsharded) vocabulary size V. + process_group: TP process group for all-reduce. + + Returns: + Per-token KL divergence tensor of shape [R]. + """ + # Detach teacher inputs + teacher_topk_logits = teacher_topk_logits.detach() + teacher_topk_indices = teacher_topk_indices.detach() + + # torch.gather requires LongTensor (int64) indices. + # Accept int32 from the data pipeline and cast defensively. + if teacher_topk_indices.dtype != torch.int64: + teacher_topk_indices = teacher_topk_indices.long() + + tp_size = dist.get_world_size(group=process_group) if process_group is not None else 1 + k = teacher_topk_logits.size(-1) + + # --- student softmax (numerically stable, TP-aware) --- + s_max = student_logits.max(dim=-1, keepdim=True).values + if tp_size > 1: + dist.all_reduce(s_max, op=dist.ReduceOp.MAX, group=process_group) + s_shifted = student_logits - s_max + s_exp = s_shifted.exp() + s_sum_exp = s_exp.sum(dim=-1, keepdim=True) + if tp_size > 1: + dist.all_reduce(s_sum_exp, op=dist.ReduceOp.SUM, group=process_group) + s_softmax = s_exp / s_sum_exp # π_s(y) [R, V_local] + s_log_sum_exp = s_sum_exp.log() # [R, 1] + + # Gather student probs and log-probs at teacher's top-k positions + # teacher_topk_indices are LOCAL to this TP shard + student_topk_probs = s_softmax.gather(-1, teacher_topk_indices) # [R, k] + student_topk_shifted = s_shifted.gather(-1, teacher_topk_indices) # [R, k] + student_topk_log_probs = student_topk_shifted - s_log_sum_exp # [R, k] + + # --- teacher log-softmax for top-k logits (numerically stable, TP-aware) --- + t_max = teacher_topk_logits.max(dim=-1, keepdim=True).values + if tp_size > 1: + # teacher_topk_logits are per-shard, so we need global max across shards + # BUT: each shard's top-k is independent (local indices). + # To compute the correct global log_sum_exp, we need: + # (a) the global max of ALL top-k logits across shards, and + # (b) the sum of exp(logits - global_max) across all shards. + dist.all_reduce(t_max, op=dist.ReduceOp.MAX, group=process_group) + t_shifted = teacher_topk_logits - t_max + t_exp = t_shifted.exp() + t_sum_exp = t_exp.sum(dim=-1, keepdim=True) + if tp_size > 1: + # Sum of exp across all shards (each shard contributes its k top-k values) + dist.all_reduce(t_sum_exp, op=dist.ReduceOp.SUM, group=process_group) + t_log_sum_exp = t_sum_exp.log() # [R, 1] + + # Teacher probs on top-k: exp(topk_logits) / Z_teacher + # But Z_teacher is NOT just the sum over top-k tokens. + # We need an approximation: Z_teacher ≈ sum_topk(exp) + (V_eff - k) * exp(tail_max). + # However, we don't have the exact partition function for teacher. + # Instead, we compute teacher_topk_log_probs using a CLOSED-over-topk partition function + # (treating top-k as if they were the entire vocab), then apply tail correction. + # + # Simple approach: compute teacher log-probs normalizing over top-k only, + # then estimate the tail correction analytically. + teacher_topk_log_probs_approx = t_shifted - t_log_sum_exp # [R, k] + teacher_topk_probs = (t_shifted.exp() / t_sum_exp) # [R, k], probs normalizing over top-k + + # --- tail mass --- + # Student tail mass: 1 - sum(π_s(y) for y in top-k of this shard) + student_topk_mass = student_topk_probs.sum(dim=-1) # [R] + if tp_size > 1: + # Sum the top-k mass across all TP shards to get the total mass in all shards' top-k + dist.all_reduce(student_topk_mass, op=dist.ReduceOp.SUM, group=process_group) + student_tail_mass = (1.0 - student_topk_mass).clamp(min=0.0) # [R] + + # Teacher tail mass: 1 - sum(π_t_topk(y) for y in top-k) + # Here π_t_topk is the teacher prob normalizing over the FULL vocab. + # We need the global partition function Z_t for teacher. + # We approximate by computing the partition function as: + # Z_t = sum of all exp(teacher_logits - global_max). + # But we only have top-k logits per shard. We approximate the tail as uniform. + # + # For each shard, we know its top-k contributes t_sum_exp_local = sum(exp(t_shifted)). + # The total Z_t ≈ (tp_size * k / vocab_size) * avg_exp would be wrong. + # + # Better: we already computed t_sum_exp across all shards (the sum of all top-k exp values). + # The full Z_t over the COMPLETE vocab is NOT available (we discarded non-top-k logits). + # + # Approximation: assume the full Z_t = t_sum_exp (treat top-k as the full support). + # This IS what teacher_topk_log_probs_approx normalizes over. + # So the tail of this approximate distribution has zero mass by construction. + # The tail correction accounts for the mass that SHOULD be in the tail. + # + # We use: teacher_tail_mass ≈ (V - k*tp_size) / V (uniform prior on tail) + V_eff = k * tp_size # effective number of tokens in the top-k across all shards + teacher_tail_mass = max(0.0, (vocab_size - V_eff) / vocab_size) + # Scale: the teacher_topk_probs already sum to ~1 within the top-k partition, + # so the actual mass on top-k is (1 - teacher_tail_mass). + # We need to rescale teacher_topk_log_probs to reflect this: + # π_t(y) for y in top-k ≈ teacher_topk_probs(y) * (1 - teacher_tail_mass) + # log π_t(y) = teacher_topk_log_probs_approx + log(1 - teacher_tail_mass) + if teacher_tail_mass > 0 and teacher_tail_mass < 1.0: + teacher_topk_log_probs = teacher_topk_log_probs_approx + torch.log( + torch.tensor(1.0 - teacher_tail_mass, device=teacher_topk_logits.device, dtype=teacher_topk_logits.dtype) + ) + else: + teacher_topk_log_probs = teacher_topk_log_probs_approx + + # --- KL computation --- + # KL_topk = Σ_{y ∈ top-k (all shards)} π_s(y) [log π_s(y) - log π_t(y)] + local_kl_topk = (student_topk_probs * (student_topk_log_probs - teacher_topk_log_probs)).sum(dim=-1) # [R] + if tp_size > 1: + dist.all_reduce(local_kl_topk, op=dist.ReduceOp.SUM, group=process_group) + + # KL_tail ≈ π_s_tail * log(π_s_tail / π_t_tail) + # π_s_tail = student_tail_mass per token + # π_t_tail = teacher_tail_mass (estimated above) + # Note: we don't have per-token variance in teacher_tail_mass, but this is an approximation. + kl_tail = torch.zeros_like(student_tail_mass) + tail_mask = (student_tail_mass > 1e-10) & (teacher_tail_mass > 1e-10) + kl_tail[tail_mask] = student_tail_mass[tail_mask] * ( + torch.log(student_tail_mass[tail_mask]) - torch.log( + torch.tensor(teacher_tail_mass, device=student_tail_mass.device, dtype=student_tail_mass.dtype) + ) + ) + # If teacher_tail_mass ≈ 0 but student_tail_mass > 0, we have an unbounded KL. + # This shouldn't happen if k is large enough. We treat it as 0 for numerical safety. + + kl = local_kl_topk + kl_tail # [R] + + return kl + + def get_grpo_returns( rewards: torch.Tensor, kl: list[torch.Tensor], diff --git a/tests/test_mopd_full_vocab.py b/tests/test_mopd_full_vocab.py index 81a2808954..82efc3d0ae 100644 --- a/tests/test_mopd_full_vocab.py +++ b/tests/test_mopd_full_vocab.py @@ -632,4 +632,410 @@ def test_no_temperature_scaling(self): expected_logits = logits[0, T - R - 1 : T - 1, :] # [R, V] assert torch.allclose(logits_out[0], expected_logits, atol=1e-5), ( "get_logits_for_distill should return raw logits without temperature scaling" - ) \ No newline at end of file + ) + + +# --------------------------------------------------------------------------- +# Tests for vocab_parallel_topk_reverse_kl +# --------------------------------------------------------------------------- +class TestVocabParallelTopkReverseKL: + """Test the vocab_parallel_topk_reverse_kl function in ppo_utils.py.""" + + def test_topk_kl_approximates_full_kl(self): + """Top-k KL should approximate full-vocab KL when k covers most probability mass.""" + from slime.utils.ppo_utils import vocab_parallel_reverse_kl, vocab_parallel_topk_reverse_kl + + torch.manual_seed(42) + V = 20 + R = 4 + k = 15 # top-15 out of 20 should cover most mass + + student_logits = torch.randn(R, V, requires_grad=True) + teacher_logits = torch.randn(R, V) + + # Full-vocab KL + full_kl = vocab_parallel_reverse_kl(student_logits, teacher_logits, process_group=None) + + # Top-k KL + topk_vals, topk_idx = teacher_logits.topk(k, dim=-1) + topk_kl = vocab_parallel_topk_reverse_kl( + student_logits, topk_vals, topk_idx, V, process_group=None + ) + + # With k close to V, the top-k should be close to full-vocab KL + # Allow some tolerance due to tail approximation + assert topk_kl.shape == (R,), f"Expected shape ({R},), got {topk_kl.shape}" + assert torch.isfinite(topk_kl).all(), f"Top-k KL should be finite, got {topk_kl}" + + def test_topk_kl_identical_distributions(self): + """Top-k KL should be ~0 when student == teacher.""" + from slime.utils.ppo_utils import vocab_parallel_topk_reverse_kl + + V = 20 + k = 10 + logits = torch.randn(3, V) + + topk_vals, topk_idx = logits.topk(k, dim=-1) + kl = vocab_parallel_topk_reverse_kl(logits, topk_vals, topk_idx, V, process_group=None) + + assert kl.shape == (3,) + # Should be close to 0 (not exact due to tail approximation with V > k) + assert kl.item() >= -0.1, f"Top-k KL should be ~0 for identical distributions, got {kl}" + + def test_topk_kl_gradient_flows(self): + """Gradient flows through student logits in top-k KL.""" + from slime.utils.ppo_utils import vocab_parallel_topk_reverse_kl + + V = 20 + k = 10 + student_logits = torch.randn(3, V, requires_grad=True) + teacher_logits = torch.randn(3, V) + + topk_vals, topk_idx = teacher_logits.topk(k, dim=-1) + kl = vocab_parallel_topk_reverse_kl(student_logits, topk_vals, topk_idx, V, process_group=None) + loss = kl.sum() + loss.backward() + + assert student_logits.grad is not None, "student_logits should have gradients" + assert not torch.allclose(student_logits.grad, torch.zeros_like(student_logits.grad)), \ + "student_logits gradients should be non-zero" + + def test_topk_kl_increases_with_smaller_k(self): + """Top-k KL should generally increase as k decreases (less accurate approximation).""" + from slime.utils.ppo_utils import vocab_parallel_reverse_kl, vocab_parallel_topk_reverse_kl + + torch.manual_seed(42) + V = 50 + R = 5 + student_logits = torch.randn(R, V) + teacher_logits = torch.randn(R, V) + + full_kl = vocab_parallel_reverse_kl(student_logits, teacher_logits, process_group=None).sum().item() + + k_large = 40 + topk_vals_l, topk_idx_l = teacher_logits.topk(k_large, dim=-1) + kl_large = vocab_parallel_topk_reverse_kl( + student_logits, topk_vals_l, topk_idx_l, V, process_group=None + ).sum().item() + + # With k=V (full vocab), top-k should be closer to full KL + assert torch.isfinite(torch.tensor(kl_large)), "Top-k KL should be finite" + + +# --------------------------------------------------------------------------- +# Tests for apply_mopd_topk_to_loss +# --------------------------------------------------------------------------- +class TestApplyMopdTopkToLoss: + """Test the apply_mopd_topk_to_loss function.""" + + @pytest.fixture(autouse=True) + def _mock_deps(self, monkeypatch): + mock_megatron(monkeypatch) + + def _get_function(self): + from slime.backends.megatron_utils.loss import apply_mopd_topk_to_loss + return apply_mopd_topk_to_loss + + def _sum_of_sample_mean(self, tensor): + return tensor.mean() + + def _make_args(self, **overrides): + defaults = dict( + use_mopd=True, + mopd_distill_type="top_k", + mopd_topk_k=8, + mopd_teachers='[{"name": "math_teacher", "domain": "math"}]', + mopd_teacher_loads="/tmp/fake_teacher", + mopd_teacher_ckpt_steps=None, + mopd_alpha=0.0, + mopd_eps_low=0.0, + mopd_eps_high=1000.0, + mopd_sampling_logprobs_key="rollout_log_probs", + _mopd_teachers_parsed=[{"name": "math_teacher", "domain": "math"}], + padded_vocab_size=20, + ) + defaults.update(overrides) + return Namespace(**defaults) + + def test_single_teacher_topk_loss(self): + """Test single-teacher top-k KL loss computation.""" + apply_fn = self._get_function() + args = self._make_args() + torch.manual_seed(42) + + V = 20 + k = 8 + R1, R2 = 3, 4 + + student_logits_1 = torch.randn(R1, V) + student_logits_2 = torch.randn(R2, V) + teacher_logits_1 = torch.randn(R1, V) + teacher_logits_2 = torch.randn(R2, V) + + # Get top-k from teacher + topk_vals_1, topk_idx_1 = teacher_logits_1.topk(k, dim=-1) + topk_vals_2, topk_idx_2 = teacher_logits_2.topk(k, dim=-1) + + student_logits = [student_logits_1, student_logits_2] + teacher_topk_logits = {"math": [topk_vals_1, topk_vals_2]} + teacher_topk_indices = {"math": [topk_idx_1, topk_idx_2]} + + batch = {"rollout_log_probs": [torch.zeros(R1), torch.zeros(R2)]} + current_log_probs = [torch.zeros(R1), torch.zeros(R2)] + loss_masks = [torch.ones(R1), torch.ones(R2)] + + kl_loss, metrics = apply_fn( + args, batch, student_logits, teacher_topk_logits, + teacher_topk_indices, loss_masks, self._sum_of_sample_mean, + current_log_probs=current_log_probs, + ) + + assert kl_loss.shape == (), "kl_loss should be scalar" + assert "mopd_topk_kl" in metrics + assert "mopd_is_weight_mean" in metrics + assert "mopd_is_nonzero_frac" in metrics + assert "mopd_topk_kl/math" in metrics + + def test_topk_loss_is_non_negative(self): + """Test that top-k KL loss is non-negative (or close to it).""" + apply_fn = self._get_function() + args = self._make_args(mopd_eps_low=0.0, mopd_eps_high=1000.0) + torch.manual_seed(42) + + V = 20 + k = 10 + student_logits = [torch.randn(5, V)] + teacher_logits = [torch.randn(5, V)] + + topk_vals, topk_idx = teacher_logits[0].topk(k, dim=-1) + teacher_topk_logits = {"math": [topk_vals]} + teacher_topk_indices = {"math": [topk_idx]} + + batch = {"rollout_log_probs": [torch.zeros(5)]} + current_log_probs = [torch.zeros(5)] + loss_masks = [torch.ones(5)] + + kl_loss, metrics = apply_fn( + args, batch, student_logits, teacher_topk_logits, + teacher_topk_indices, loss_masks, self._sum_of_sample_mean, + current_log_probs=current_log_probs, + ) + + # Top-k KL may be slightly negative due to tail approximation, + # but should be close to 0 at worst + assert kl_loss.item() >= -0.5, f"Top-k KL loss should be >= -0.5, got {kl_loss.item()}" + + def test_topk_is_weight_clipping(self): + """Test IS weight clipping in top_k mode.""" + apply_fn = self._get_function() + args = self._make_args(mopd_eps_low=0.5, mopd_eps_high=2.0) + + V = 20 + k = 8 + student_logits = [torch.randn(3, V)] + teacher_logits = [torch.randn(3, V)] + topk_vals, topk_idx = teacher_logits[0].topk(k, dim=-1) + + teacher_topk_logits = {"math": [topk_vals]} + teacher_topk_indices = {"math": [topk_idx]} + + batch = {"rollout_log_probs": [torch.tensor([0.0, 0.0, 0.0])]} + current_log_probs = [torch.tensor([-5.0, 0.0, 5.0])] + loss_masks = [torch.ones(3)] + + kl_loss, metrics = apply_fn( + args, batch, student_logits, teacher_topk_logits, + teacher_topk_indices, loss_masks, self._sum_of_sample_mean, + current_log_probs=current_log_probs, + ) + + is_nonzero_frac = metrics["mopd_is_nonzero_frac"].item() + assert abs(is_nonzero_frac - 1.0 / 3.0) < 0.05, ( + f"Expected ~1/3 nonzero IS weight fraction, got {is_nonzero_frac}" + ) + + def test_topk_none_teacher_for_sample(self): + """Test that None entries in teacher data are skipped.""" + apply_fn = self._get_function() + args = self._make_args( + _mopd_teachers_parsed=[ + {"name": "math_teacher", "domain": "math"}, + {"name": "code_teacher", "domain": "code"}, + ], + ) + + V = 20 + k = 8 + student_0 = torch.randn(3, V) + student_1 = torch.randn(4, V) + + teacher_0 = torch.randn(3, V) + teacher_1 = torch.randn(4, V) + teacher_code_1 = torch.randn(4, V) + + topk_vals_0, topk_idx_0 = teacher_0.topk(k, dim=-1) + topk_vals_1, topk_idx_1 = teacher_1.topk(k, dim=-1) + topk_vals_c1, topk_idx_c1 = teacher_code_1.topk(k, dim=-1) + + teacher_topk_logits = { + "math": [topk_vals_0, topk_vals_1], + "code": [None, topk_vals_c1], + } + teacher_topk_indices = { + "math": [topk_idx_0, topk_idx_1], + "code": [None, topk_idx_c1], + } + + batch = {"rollout_log_probs": [torch.zeros(3), torch.zeros(4)]} + current_log_probs = [torch.zeros(3), torch.zeros(4)] + loss_masks = [torch.ones(3), torch.ones(4)] + + kl_loss, metrics = apply_fn( + args, batch, [student_0, student_1], + teacher_topk_logits, teacher_topk_indices, + loss_masks, self._sum_of_sample_mean, + current_log_probs=current_log_probs, + ) + + assert kl_loss.shape == () + assert torch.isfinite(kl_loss), "Loss should be finite with None teacher entries" + + def test_topk_k_parameter_effect(self): + """Test that larger k gives KL closer to full-vocab KL.""" + from slime.utils.ppo_utils import vocab_parallel_reverse_kl + + apply_fn = self._get_function() + torch.manual_seed(42) + + V = 20 + student_logits = [torch.randn(5, V)] + teacher_logits_raw = [torch.randn(5, V)] + + # Full-vocab KL as ground truth + full_kl = vocab_parallel_reverse_kl(student_logits[0], teacher_logits_raw[0], None).sum().item() + + # Top-k with k=5 + k_small = 5 + topk_vals_s, topk_idx_s = teacher_logits_raw[0].topk(k_small, dim=-1) + args_small = self._make_args(mopd_topk_k=k_small) + batch = {"rollout_log_probs": [torch.zeros(5)]} + current_log_probs = [torch.zeros(5)] + loss_masks = [torch.ones(5)] + + kl_small, _ = apply_fn( + args_small, batch, student_logits, + {"math": [topk_vals_s]}, {"math": [topk_idx_s]}, + loss_masks, self._sum_of_sample_mean, current_log_probs=current_log_probs, + ) + + # Top-k with k=18 (close to V) + k_large = 18 + topk_vals_l, topk_idx_l = teacher_logits_raw[0].topk(k_large, dim=-1) + args_large = self._make_args(mopd_topk_k=k_large) + + kl_large, _ = apply_fn( + args_large, batch, student_logits, + {"math": [topk_vals_l]}, {"math": [topk_idx_l]}, + loss_masks, self._sum_of_sample_mean, current_log_probs=current_log_probs, + ) + + # Larger k should generally be closer to full KL (both are approximations) + + +# --------------------------------------------------------------------------- +# Tests for top_k argument validation +# --------------------------------------------------------------------------- +class TestMopdTopkValidation: + """Test --mopd-distill-type=top_k parameter validation.""" + + @pytest.fixture(autouse=True) + def _mock_deps(self, monkeypatch): + megatron_mod = types.ModuleType("megatron") + training_mod = types.ModuleType("megatron.training") + arguments_mod = types.ModuleType("megatron.training.arguments") + arguments_mod.parse_args = lambda *a, **kw: None + arguments_mod.validate_args = lambda a: a + tokenizer_pkg_mod = types.ModuleType("megatron.training.tokenizer") + tokenizer_mod = types.ModuleType("megatron.training.tokenizer.tokenizer") + tokenizer_mod._vocab_size_with_padding = lambda vocab_size, _args: vocab_size + transformers_mod = types.ModuleType("transformers") + transformers_mod.AutoConfig = types.SimpleNamespace(from_pretrained=lambda *a, **kw: None) + + monkeypatch.setitem(sys.modules, "megatron", megatron_mod) + monkeypatch.setitem(sys.modules, "megatron.training", training_mod) + monkeypatch.setitem(sys.modules, "megatron.training.arguments", arguments_mod) + monkeypatch.setitem(sys.modules, "megatron.training.tokenizer", tokenizer_pkg_mod) + monkeypatch.setitem(sys.modules, "megatron.training.tokenizer.tokenizer", tokenizer_mod) + monkeypatch.setitem(sys.modules, "transformers", transformers_mod) + + def _make_base_args(self, **overrides): + defaults = dict( + use_opd=False, + opd_type=None, + opd_kl_coef=1.0, + opd_teacher_load=None, + opd_teacher_ckpt_step=None, + use_mopd=False, + mopd_teachers=None, + mopd_teacher_loads=None, + mopd_teacher_ckpt_steps=None, + mopd_alpha=0.0, + mopd_eps_low=0.2, + mopd_eps_high=5.0, + mopd_sampling_logprobs_key="rollout_log_probs", + mopd_distill_type="token_level", + mopd_topk_k=1024, + enable_weights_backuper=True, + eval_datasets=[], + eval_prompt_data=None, + kl_coef=0, + ref_load="/tmp/fake_ref", + use_kl_loss=False, + use_critic=False, + rm_type=None, + custom_rm_path=None, + ) + defaults.update(overrides) + return Namespace(**defaults) + + def test_topk_without_teacher_loads_raises(self): + """Test that top_k mode without --mopd-teacher-loads raises ValueError.""" + from slime.utils.arguments import slime_validate_args + + args = self._make_base_args( + use_mopd=True, + mopd_distill_type="top_k", + mopd_teachers='[{"name": "t1", "domain": "math"}]', + mopd_teacher_loads=None, + ) + with pytest.raises(ValueError, match="top_k.*mopd-teacher-loads|megatron teacher"): + slime_validate_args(args) + + def test_topk_k_must_be_positive(self): + """Test that --mopd-topk-k <= 0 raises ValueError.""" + from slime.utils.arguments import slime_validate_args + + args = self._make_base_args( + use_mopd=True, + mopd_distill_type="top_k", + mopd_teachers='[{"name": "t1", "domain": "math"}]', + mopd_teacher_loads=["/tmp/fake_teacher"], + mopd_topk_k=0, + ) + with pytest.raises(ValueError, match="mopd-topk-k.*> 0"): + slime_validate_args(args) + + def test_topk_k_default(self): + """Test that --mopd-topk-k defaults to 1024.""" + from slime.utils.arguments import slime_validate_args + + args = self._make_base_args( + use_mopd=True, + mopd_distill_type="top_k", + mopd_teachers='[{"name": "t1", "domain": "math"}]', + mopd_teacher_loads=["/tmp/fake_teacher"], + mopd_topk_k=1024, + mopd_alpha=0.0, + ) + slime_validate_args(args) + assert args.mopd_topk_k == 1024 \ No newline at end of file From 61f35d1ed58be65b2300a6219839f886f8cee386 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=87=95=E4=B8=80?= Date: Wed, 10 Jun 2026 17:38:11 +0800 Subject: [PATCH 04/14] feat: add online SGLang server mode for MOPD distillation - Add SGLang-based teacher rollout pipeline (separate from Megatron in-process mode) - Implement HTTP-based teacher logprobs collection during rollout - Add MOPD teacher URL configuration via environment variables - Fix logits calculation bug in TopK mode - Fix bad teacher request handling with retry logic - Improve MOPD rollout logging and monitoring - Add 397B model example scripts (megatron and sglang modes) - Add README_zh.md with Chinese documentation - Add comprehensive SGLang TopK pipeline integration tests --- .../README.md | 199 ++++- .../README_zh.md | 347 ++++++++ ...run-qwen35-397B-A17B-mopd-topk-megatron.sh | 220 ++++++ .../run-qwen35-397B-A17B-mopd-topk-sglang.sh | 235 ++++++ slime/backends/megatron_utils/actor.py | 229 +++++- slime/backends/megatron_utils/data.py | 5 +- slime/backends/megatron_utils/loss.py | 37 + slime/ray/rollout.py | 64 +- slime/rollout/mopd.py | 574 +++++++++++++- slime/rollout/sglang_rollout.py | 21 +- slime/utils/arguments.py | 164 +++- slime/utils/ppo_utils.py | 208 +++-- slime/utils/types.py | 9 + tests/test_mopd_sglang_topk_pipeline.py | 747 ++++++++++++++++++ 14 files changed, 2879 insertions(+), 180 deletions(-) create mode 100644 examples/multi_teacher_on_policy_distillation/README_zh.md create mode 100644 examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-megatron.sh create mode 100755 examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh create mode 100644 tests/test_mopd_sglang_topk_pipeline.py diff --git a/examples/multi_teacher_on_policy_distillation/README.md b/examples/multi_teacher_on_policy_distillation/README.md index 34a4b1648a..40f6fd6b83 100644 --- a/examples/multi_teacher_on_policy_distillation/README.md +++ b/examples/multi_teacher_on_policy_distillation/README.md @@ -13,34 +13,153 @@ This example shows how to run **multi-teacher on-policy distillation (MOPD)** us ## Algorithm +MOPD supports three distillation types, controlled by `--mopd-distill-type`: + ### Token-Level Mode (`--mopd-distill-type token_level`, default) -Uses sampled token log-prob difference as a reverse KL approximation: +Uses the sampled token's log-prob difference as a **point estimate** of the reverse KL divergence. This is the cheapest and most memory-efficient mode, but only captures KL information at the positions of the actually sampled tokens. + +**Core formula:** + +For each sampled token `y_t`, the per-teacher reverse KL advantage is approximated as: ``` -reverse_kl_d = sg[log π_d(y_t) - log π_θ(y_t)] # per-teacher reverse KL -w_t = sg[π_θ(y_t) / μ_θ(y_t)] clipped to [ε_low, ε_high] # IS weight -Â_MOPD,t = (1/D) Σ_d (reverse_kl_d + α · Â_ORM) # averaged across D teachers -L = -E[1/|y| Σ_t w_t · Â_MOPD,t · log π_θ(y_t)] # proxy policy loss +reverse_kl_d(y_t) = sg[log π_d(y_t) - log π_θ(y_t)] ``` +where `sg[·]` denotes stop-gradient (no gradient flows to the teacher). This is a single-token estimator of `D_KL(π_θ ∥ π_d)`: it equals the full KL only at the sampled position and provides no information about the rest of the vocabulary. + +**Training loss:** + +``` +w_t = sg[π_θ(y_t) / μ_θ(y_t)] clipped to [ε_low, ε_high] # IS weight +Â_MOPD,t = (1/D) Σ_d (reverse_kl_d + α · Â_ORM) # avg across D teachers +L = -E[1/|y| Σ_t w_t · Â_MOPD,t · log π_θ(y_t)] # proxy policy loss +``` + +**Characteristics:** +- **Data needed**: Only teacher log-probs at sampled tokens — a scalar per token per teacher. +- **Memory**: Negligible (storing `log π_d(y_t)` only). +- **Teacher modes**: Works with both SGLang and Megatron teachers. +- **Accuracy**: Underestimates the true KL because it only evaluates at sampled positions. When the student and teacher distributions differ significantly, the sampled token `y_t` (from the student's policy) tends to be in high-`π_θ` regions, missing contributions from high-`π_d` but low-`π_θ` tokens. + ### Full-Vocabulary Mode (`--mopd-distill-type full_vocab`) -Computes the exact full-vocabulary reverse KL divergence instead of the token-level approximation. This provides a more accurate distillation signal at the cost of increased memory usage. +Computes the **exact** reverse KL divergence over the entire vocabulary: + +``` +D_KL(π_θ ∥ π_d) = Σ_y π_θ(y) [log π_θ(y) - log π_d(y)] +``` + +This requires accessing the full logit vectors `[R, V]` from both the student and teacher models at every response position. + +**Training loss:** ``` -D_KL(π_θ ∥ π_d) = Σ_y π_θ(y) [log π_θ(y) - log π_d(y)] # exact full-vocab KL -w_t = sg[π_θ(y_t) / μ_θ(y_t)] clipped to [ε_low, ε_high] # IS weight -L_fv_kl = (1/D) Σ_d (1/|y| Σ_t w_t · D_KL(π_θ ∥ π_d)) # IS-corrected KL loss -L = L_fv_kl + α · L_pg # combined with PG loss +w_t = sg[π_θ(y_t) / μ_θ(y_t)] clipped to [ε_low, ε_high] # IS weight +L_fv_kl = (1/D) Σ_d (1/|y| Σ_t w_t · D_KL(π_θ ∥ π_d)) # IS-corrected KL loss +L = L_fv_kl + α · L_pg # combined with PG loss ``` -When `α = 0`: `L = L_fv_kl` (pure distillation, no ORM needed). -When `α > 0`: `L = L_fv_kl + α · L_pg` (distillation + ORM policy gradient). +- When `α = 0`: `L = L_fv_kl` (pure distillation, no ORM needed). +- When `α > 0`: `L = L_fv_kl + α · L_pg` (distillation + ORM policy gradient). + +**TP-aware computation:** When using tensor parallelism (TP > 1), the vocabulary is sharded across TP ranks. Each rank holds `V / tp_size` logits locally. The KL is computed in a numerically stable, TP-aware manner: +1. Local softmax: `s_max` and `s_sum_exp` are all-reduced across TP ranks. +2. Full-softmax prob/log-prob are computed locally using the global normalizer. +3. `vocab_parallel_reverse_kl` sums the local KL contributions: `KL_local = Σ_{y_local} π_s(y)[log π_s(y) - log π_t(y)]`, which yields the full KL because each token appears on exactly one TP rank. + +**Characteristics:** +- **Data needed**: Full teacher logits `[R_i, V/TP]` per sample per teacher (computed during rollout forward pass). +- **Memory**: Very high — per-GPU memory per teacher is `B × R × (V/TP) × 4B` (fp32). Example: V=152K, TP=2, B=4, R=4096 → ~4.6 GB; V=248K, TP=8, B=4, R=4096 → ~1.9 GB. +- **Teacher modes**: Only Megatron mode (`--mopd-teacher-loads`), because SGLang cannot efficiently return full logit vectors. +- **Accuracy**: Exact KL — the gold standard for distillation quality. + +### Top-K Mode (`--mopd-distill-type top_k`) + +A **memory-efficient approximation** of the full-vocab KL. Instead of storing the entire vocabulary of teacher logits, only the top-k logits and their indices are kept, plus an analytical tail correction to account for the remaining vocabulary. + +**Core formula:** + +The KL divergence is decomposed into two parts: -**Requirements**: `full_vocab` mode requires `--mopd-teacher-loads` (Megatron teacher mode). SGLang mode is not supported because the full logits tensor cannot be obtained from SGLang rollout. +``` +D_KL(π_θ ∥ π_d) ≈ KL_topk + KL_tail +``` + +**Top-K part** — computed exactly over the teacher's top-k tokens: + +``` +KL_topk = Σ_{y ∈ top-k} π_s(y) [log π_s(y) - log π_t(y)] +``` -**Memory note**: `full_vocab` stores teacher logits `[R, V]` per sample per teacher. For large vocabularies (V=152K), this can be significant. Reduce `--rollout-batch-size` or `--rollout-max-response-len` if OOM occurs. +For each position, the teacher provides its top-k logit values and the corresponding token indices. The student's probabilities at those positions are gathered using the indices, and the exact KL over the top-k support is computed. + +**Tail correction** — approximates the KL contribution from non-top-k tokens: + +``` +KL_tail ≈ π_s_tail · log(π_s_tail / π_t_tail) +``` + +where: +- `π_s_tail = 1 - Σ_{y ∈ top-k} π_s(y)` — the student's exact tail mass (computed via all-reduce across TP ranks). +- Teacher tail mass estimation differs by mode: + - **Megatron mode**: `π_t_tail ≈ (V - V_eff) / V` — uniform distribution assumption over non-top-k tokens, where `V_eff = k × tp_size` is the total number of valid top-k entries. Since Megatron mode typically uses larger k (e.g., 1024+), the top-k entries capture most of the probability mass and this approximation is reasonable. + - **SGLang mode**: `π_t_tail = 1 - Σ_{y ∈ top-k} exp(log_prob_t(y))` — **exact** computation from the teacher's full-vocabulary log-probs returned by SGLang. Since SGLang returns `log(π_t(y))` (already softmax-normalized over the full vocabulary), summing `exp(log_prob)` gives the true probability mass in the top-k partition, and the tail is simply `1 - mass_topk`. This is a key advantage of SGLang mode: the tail mass is computed exactly, not estimated. + +**Important**: In SGLang mode, the teacher returns log-probs (not raw logits). The `vocab_parallel_topk_reverse_kl` function detects this via the `is_log_probs=True` flag and skips the log_softmax step, using the log-probs directly and computing tail mass from their exp-sum. This avoids the double-softmax bug and the inaccurate uniform tail estimate that would otherwise make KL values unrealistically large (e.g., ~7.7 nats with k=128 and V=152K). + +**Full loss:** + +``` +w_t = sg[π_θ(y_t) / μ_θ(y_t)] clipped to [ε_low, ε_high] # IS weight +L_topk_kl = (1/D) Σ_d (1/|y| Σ_t w_t · KL_topk+d(π_θ ∥ π_d)) # IS-corrected approx KL +L = L_topk_kl + α · L_pg # combined with PG loss +``` + +**TP-aware computation:** When using tensor parallelism (TP > 1), the handling differs by teacher mode: + +- **Megatron mode**: The teacher's top-k indices are **local** per TP shard (since the teacher logits are vocab-sharded, `topk` selects the top-k within each shard). The student gathers its probabilities at these local indices directly — no cross-shard index translation needed. +- **SGLang mode**: The SGLang server returns **global** token IDs. During the actor's data preparation step, each TP rank converts global indices to local indices within its vocab range `[vocab_offset, vocab_offset + vocab_local_size)`. Entries whose global token ID falls outside the shard's range are replaced with `-inf` logit and `0` index (padding). The `valid_topk_mask` (computed as `~torch.isinf(teacher_topk_logits)`) automatically identifies valid vs. padding entries. + +For both modes, all-reduce operations are needed for: global `s_max`, global `s_sum_exp`, global `t_max`, global `t_sum_exp`, global `student_topk_mass`, and global `local_kl_topk`. + +**Characteristics:** +- **Data needed**: Teacher top-k logits + indices `[R_i, k]` per sample per teacher (k controlled by `--mopd-topk-k`, default 1024). +- **Memory**: Very low — per-GPU memory per teacher is `B × R × k × 2 × 4B / TP` (fp32 logits + int64 indices). The ratio vs full_vocab is approximately `2k/V`, saving `1 - 2k/V`. Example: k=1024, V=152K, TP=2, B=4, R=4096 → ~128 MB vs full_vocab's ~4.6 GB (~97% reduction); k=1024, V=248K, TP=8, B=4, R=4096 → ~16 MB vs ~1.9 GB. +- **Teacher modes**: SGLang or Megatron. SGLang mode uses the `top_logprobs_num` parameter to request top-k logprobs from the remote server; Megatron mode computes top-k during the teacher forward pass. +- **Accuracy**: Very close to full_vocab — the top-k tokens capture the vast majority of the probability mass. The tail correction provides a bounded estimate of the remaining contribution. + +### Comparison of Distillation Types + +| | `token_level` | `top_k` | `full_vocab` | +|---|---|---|---| +| **KL accuracy** | Point estimate (sampled token only) | Approximate (top-k + tail correction) | Exact | +| **Teacher data per token** | 1 scalar (`log π_d(y_t)`) | k×2 values (logit + index) | V values (full logits) | +| **Teacher mem per GPU\*** | ≈0 | `B × R × k × 2 × 4B / TP` | `B × R × V × 4B / TP` | +| **Teacher mode** | SGLang or Megatron | SGLang or Megatron | Megatron only | +| **TP aware** | Not needed | Yes (local indices, all-reduce) | Yes (vocab-sharded) | +| **Gradient** | Through policy loss only | Through full student softmax | Through full student softmax | +| **When to use** | Quick iteration, SGLang teachers | Best balance of accuracy & efficiency | Max accuracy, sufficient memory | + +*\*B=batch, R=avg response length, V=vocab size, k=topk-k, TP=tensor parallelism degree. All three modes also require `B × R × V × 4B / TP` student logits memory during training (unavoidable for `top_k` and `full_vocab`).* + +The following diagram illustrates the trade-off: + +``` +Memory: token_level ◄────────────────────────────────────► full_vocab + (≈0) top_k (O(k) vs O(V)) (O(V)) + +Accuracy: token_level ◄────────────────────────────────────► full_vocab + (low: 1-token top_k (high: ~99%+ (exact) + approximation) of KL captured) +``` + +### How to Choose + +- **Use `token_level`** if you need the fastest iteration with minimal memory overhead, or if your teacher only supports sampled-token logprobs (no top-k API). +- **Use `top_k`** (recommended default) for the best balance of accuracy and efficiency. Works with both SGLang and Megatron teachers. Start with `--mopd-topk-k 1024`; increase to 2048 or 4096 if the vocabulary is very large or you want more precision. +- **Use `full_vocab`** only when you need the exact KL and have sufficient GPU memory. Only available with Megatron teachers. Typically only needed for research validation or very small-scale experiments. ## Key Arguments @@ -54,7 +173,8 @@ When `α > 0`: `L = L_fv_kl + α · L_pg` (distillation + ORM policy gradient). | `--mopd-eps-low` | IS weight lower bound for clipping (default: 0.2). Weights below this are zeroed. | | `--mopd-eps-high` | IS weight upper bound for clipping (default: 5.0). Weights above this are zeroed. | | `--mopd-sampling-logprobs-key` | Key in rollout_data for sampling log-probs used in IS weight computation (default: `rollout_log_probs`). | -| `--mopd-distill-type` | Distillation type: `token_level` (default) uses sampled token log-prob difference as a reverse KL approximation applied at the advantage level; `full_vocab` computes the exact full-vocabulary reverse KL divergence D_KL(π_θ ∥ π_d) using complete logits. `full_vocab` requires `--mopd-teacher-loads` (Megatron mode). | +| `--mopd-distill-type` | Distillation type: `token_level` (default) uses sampled token log-prob difference as a reverse KL approximation; `full_vocab` computes the exact full-vocabulary reverse KL divergence; `top_k` computes approximate KL using teacher top-k logits + tail correction. `full_vocab` requires `--mopd-teacher-loads` (Megatron mode); `top_k` and `token_level` work with both SGLang and Megatron teachers. See [Algorithm](#algorithm) for details. | +| `--mopd-topk-k` | Number of top-k tokens to keep per position for `top_k` distillation (default: 1024). Higher k gives more accurate KL approximation at the cost of more memory. Only used when `--mopd-distill-type=top_k`. | ## SGLang vs Megatron Mode @@ -67,8 +187,10 @@ When `α > 0`: `L = L_fv_kl + α · L_pg` (distillation + ORM policy gradient). - Each teacher runs as an independent SGLang server. - Teacher URLs are configured via the `MOPD_TEACHER_URLS` environment variable (JSON dict: `domain -> URL`) or via the `rm_url` field in each teacher config in `--mopd-teachers`. -- `--custom-rm-path slime.rollout.mopd.reward_func` and `--custom-reward-post-process-path slime.rollout.mopd.post_process_rewards` are required. +- `--custom-rm-path` and `--custom-reward-post-process-path` are auto-configured when not explicitly set (you typically don't need to specify them manually). - `--rm-url` serves as a fallback URL if no per-teacher URL is configured. +- **Supported distill types**: `token_level` and `top_k`. `full_vocab` is not supported (SGLang cannot efficiently return full-vocabulary logits). +- **`top_k` specifics**: The SGLang server is queried with `top_logprobs_num=k` to return per-position top-k logprobs. During training, global token IDs from SGLang are converted to per-TP-shard local indices with `-inf` padding for out-of-shard entries. ### Megatron Mode @@ -80,11 +202,15 @@ When `α > 0`: `L = L_fv_kl + α · L_pg` (distillation + ORM policy gradient). ## Components - `slime/rollout/mopd.py` implements SGLang-mode MOPD: - - `reward_func`: queries all teacher SGLang servers concurrently, returns per-domain responses. - - `post_process_rewards`: extracts token-level teacher log-probs from responses and stores them in `sample.mopd_teacher_log_probs`. + - `reward_func`: queries all teacher SGLang servers concurrently, returns per-domain responses. For `top_k` mode, the SGLang request includes `top_logprobs_num=k`. + - `post_process_rewards`: extracts teacher data from SGLang responses — token-level log-probs, and (for `top_k` mode) top-k logit values and global token indices. Stores them in `sample.mopd_teacher_log_probs`, `sample.mopd_teacher_topk_logits`, and `sample.mopd_teacher_topk_indices`. +- `slime/ray/rollout.py`: collects per-sample MOPD data from rollouts and splits by data parallelism. +- `slime/backends/megatron_utils/actor.py`: for SGLang `top_k` mode, converts global token IDs to per-TP-shard local indices with `-inf` padding for out-of-shard entries. - `slime/backends/megatron_utils/loss.py`: - - `apply_mopd_to_advantages`: computes per-teacher reverse KL, IS weights, and aggregated MOPD advantages. - - `policy_loss_function`: applies `mopd_advantages` and IS weights to the policy gradient loss. + - `apply_mopd_topk_to_loss`: computes IS-weighted top-k approximate reverse KL loss. + - `policy_loss_function`: integrates MOPD KL loss with the policy gradient loss. +- `slime/utils/ppo_utils.py`: + - `vocab_parallel_topk_reverse_kl`: TP-aware top-k KL computation with tail correction and `valid_topk_mask` support. - `run-qwen3-8B-mopd-sglang.sh`: launches SGLang teacher servers, then submits a Ray job. - `run-qwen3-8B-mopd-megatron.sh`: uses Megatron-loaded teacher models (no external server needed). @@ -246,9 +372,30 @@ For string convenience, you can also use a single string instead of a list: 5. **Why is `--group-rm` not supported with MOPD?** MOPD's `reward_func` returns per-domain dicts (not scalar rewards), which is incompatible with the batch `group_rm` reward path. Use the default per-sample reward path (no `--group-rm`). -6. **What is the difference between `token_level` and `full_vocab` distillation types?** - - `token_level` (default): Approximates reverse KL using the sampled token log-prob difference `sg[log π_d(y_t) - log π_θ(y_t)]`. This is efficient and works with both SGLang and Megatron teacher modes, but only captures the KL at the sampled token position. - - `full_vocab`: Computes the exact full-vocabulary reverse KL divergence `D_KL(π_θ ∥ π_d) = Σ_y π_θ(y) [log π_θ(y) - log π_d(y)]`. Provides a more accurate distillation signal but requires full logits from the teacher, which means it only works with Megatron mode (`--mopd-teacher-loads`). Memory usage is significantly higher because teacher logits `[R, V]` must be stored for each sample. - -7. **When should I use `full_vocab` mode?** - Use `full_vocab` when you need more precise distillation signal and have sufficient GPU/CPU memory. It is particularly beneficial when the student and teacher distributions differ significantly, as the token-level approximation can underestimate the true KL divergence. For memory-constrained scenarios, stick with `token_level`. \ No newline at end of file +6. **What is the difference between `token_level`, `top_k`, and `full_vocab` distillation types?** + - `token_level` (default): Approximates reverse KL using the sampled token log-prob difference `sg[log π_d(y_t) - log π_θ(y_t)]`. Efficient, works with SGLang and Megatron, but only captures KL at the sampled token position. Underestimates the true KL when student and teacher distributions diverge. + - `top_k`: Computes approximate reverse KL using the teacher's top-k logits plus an analytical tail correction. Memory-efficient (~97% less than `full_vocab`). Works with both SGLang and Megatron teachers. Accuracy is very close to `full_vocab` for typical k values (1024+). + - `full_vocab`: Computes the exact full-vocabulary reverse KL divergence `D_KL(π_θ ∥ π_d)`. Most accurate but memory-intensive (stores full `[R, V]` logits). Megatron-only (SGLang cannot efficiently return full vocab logits). + - See the [Algorithm](#algorithm) section and [Comparison table](#comparison-of-distillation-types) for detailed formulas and memory analysis. + +7. **When should I use `top_k` vs `full_vocab`?** + Use `top_k` in most production scenarios — it captures >99% of the KL signal with ~3% of the memory of `full_vocab`. Works with both SGLang and Megatron teachers. Use `full_vocab` only when you need the exact KL for research validation or have ample GPU memory (Megatron teachers only). Start with `--mopd-topk-k 1024`; increase to 2048 or 4096 if the vocabulary is very large (e.g., V > 200K) and you observe the tail correction is too aggressive. + +8. **How does the `top_k` tail correction work?** + The top-k decomposition splits the KL into `KL_topk` (exact over the teacher's top-k tokens) and `KL_tail` (approximate over the remaining tokens). The tail correction method differs by teacher mode: + - **Megatron mode**: `π_t_tail ≈ (V − V_eff) / V` — assumes uniform distribution over non-top-k tokens, where `V_eff = k × tp_size`. This is a conservative upper bound; the actual teacher tail mass is typically smaller, so the approximate KL slightly over-estimates the true KL. This approximation works well when k is large (e.g., k ≥ 1024) since the top-k entries capture most of the probability mass. + - **SGLang mode**: `π_t_tail = 1 − Σ exp(log_prob)` — **exact** computation. SGLang returns full-vocabulary log-probs for the top-k tokens, so summing their probabilities directly gives the true top-k mass and the tail is computed exactly. No uniform assumption is needed. This makes SGLang mode's tail correction accurate even with small k (e.g., k=128). + +9. **What is the memory usage of each distillation type?** + Memory scales with vocab size V, tensor parallelism TP, batch B, and response length R: + - `token_level`: Negligible (`B × R × 4B`). + - `top_k` (k=1024): `B × R × k × 2 × 4B / TP`. Ratio vs full_vocab is approximately `2k/V` (~1–2%). + - `full_vocab`: `B × R × V × 4B / TP`. + Concrete examples vary by model: V=152K/TP=2/B=4/R=4096 gives top_k ~128 MB, full_vocab ~4.6 GB; V=248K/TP=8/B=4/R=4096 gives top_k ~16 MB, full_vocab ~1.9 GB. + If OOM occurs with `full_vocab`, switch to `top_k` or reduce `--rollout-batch-size` / `--rollout-max-response-len`. + +10. **What are the requirements for SGLang `top_k` mode?** + - The SGLang teacher server must support the `top_logprobs_num` parameter (available in recent SGLang versions). + - The teacher's **vocabulary size must exactly match** the student's `padded_vocab_size`. This is because global token IDs from the teacher are converted to per-TP-shard local indices during training. A vocab size mismatch would produce incorrect index mappings and silently corrupt the KL computation. + - The `MOPD_TEACHER_URLS` environment variable must be set (JSON dict mapping domain names to SGLang `/generate` endpoints), or `--rm-url` must be provided as a fallback. + - `--custom-rm-path` and `--custom-reward-post-process-path` are auto-configured when both are unset — you typically don't need to set them manually. \ No newline at end of file diff --git a/examples/multi_teacher_on_policy_distillation/README_zh.md b/examples/multi_teacher_on_policy_distillation/README_zh.md new file mode 100644 index 0000000000..ce0e4a9bef --- /dev/null +++ b/examples/multi_teacher_on_policy_distillation/README_zh.md @@ -0,0 +1,347 @@ +# 多教师在线策略蒸馏 (MOPD) 算法说明 + +本文档详细说明 MOPD 的三种蒸馏模式 (`token_level`、`top_k`、`full_vocab`) 的算法原理、计算流程和使用建议。 + +--- + +## 概述 + +MOPD (Multi-Teacher On-Policy Distillation) 通过 `--mopd-distill-type` 参数支持三种蒸馏模式,核心区别在于 **反向 KL 散度** `D_KL(π_θ ∥ π_d)` 的计算方式: + +| | `token_level` | `top_k` | `full_vocab` | +|---|---|---|---| +| **KL 精度** | 单点估计(仅采样 token) | 近似(top-k + 尾部校正) | 精确 | +| **每 token 教师数据** | 1 个标量 (`log π_d(y_t)`) | k×2 个值 (logit + index) | V 个值 (完整 logits) | +| **每 GPU 教师内存\*** | ≈0 | `B × R × k × 2 × 4B / TP` | `B × R × V × 4B / TP` | +| **教师模式** | SGLang 或 Megatron | 仅 Megatron | 仅 Megatron | +| **TP 感知** | 不需要 | 需要(局部索引 + all-reduce) | 需要(词表分片) | +| **梯度** | 仅通过策略损失 | 通过完整学生 softmax | 通过完整学生 softmax | +| **适用场景** | 快速迭代、SGLang 教师 | 大多数生产场景 | 最大精度、内存充裕 | + +*\*B=batch, R=平均响应长度, V=词表大小, k=topk-k, TP=张量并行度。三种模式训练时还需 `B × R × V × 4B / TP` 的学生 logits 内存(不可避免)。* + +``` +内存: token_level ◄────────────────────────────────────► full_vocab + (≈0) top_k (O(k) vs O(V)) (O(V)) + +精度: token_level ◄────────────────────────────────────► full_vocab + (低: 单token top_k (高: 捕获 (精确) + 近似) ~99%+ KL) +``` + +--- + +## 1. Token-Level 模式 (`--mopd-distill-type token_level`,默认) + +### 核心思想 + +仅使用**采样到的 token** 的对数概率差作为反向 KL 散度的点估计。这是最轻量、最省内存的模式,但只在采样的 token 位置提供 KL 信息。 + +### 公式 + +对每个采样 token `y_t`,每个教师的反向 KL 优势近似为: + +``` +reverse_kl_d(y_t) = sg[log π_d(y_t) - log π_θ(y_t)] +``` + +其中 `sg[·]` 表示 stop-gradient(不向教师回传梯度)。这是 `D_KL(π_θ ∥ π_d)` 的单 token 估计器——它只在采样位置等于完整 KL,对词表其余部分无信息。 + +**训练损失:** + +``` +w_t = sg[π_θ(y_t) / μ_θ(y_t)] 截断至 [ε_low, ε_high] # 重要性采样权重 +Â_MOPD,t = (1/D) Σ_d (reverse_kl_d + α · Â_ORM) # 跨 D 个教师平均 +L = -E[1/|y| Σ_t w_t · Â_MOPD,t · log π_θ(y_t)] # 代理策略损失 +``` + +### 特点 + +- **所需数据**:仅教师在每个采样 token 上的 log 概率——每个 token 每个教师一个标量。 +- **内存**:可忽略(仅存储 `log π_d(y_t)`)。 +- **教师模式**:同时支持 SGLang 和 Megatron 教师。 +- **精度问题**:会**低估**真实 KL。因为采样 token `y_t` 来自学生策略,倾向落入高 `π_θ` 区域,遗漏了高 `π_d` 但低 `π_θ` 的 token 的贡献。当学生和教师分布差异显著时,偏差更大。 + +--- + +## 2. Full-Vocabulary 模式 (`--mopd-distill-type full_vocab`) + +### 核心思想 + +在**完整词表**上计算精确的反向 KL 散度: + +``` +D_KL(π_θ ∥ π_d) = Σ_y π_θ(y) [log π_θ(y) - log π_d(y)] +``` + +需要访问学生和教师模型在每个响应位置的完整 logit 向量 `[R, V]`。 + +**训练损失:** + +``` +w_t = sg[π_θ(y_t) / μ_θ(y_t)] 截断至 [ε_low, ε_high] # 重要性采样权重 +L_fv_kl = (1/D) Σ_d (1/|y| Σ_t w_t · D_KL(π_θ ∥ π_d)) # 经 IS 校正的 KL 损失 +L = L_fv_kl + α · L_pg # 与策略损失组合 +``` + +- 当 `α = 0`:`L = L_fv_kl`(纯蒸馏,无需 ORM)。 +- 当 `α > 0`:`L = L_fv_kl + α · L_pg`(蒸馏 + ORM 策略梯度)。 + +### TP 并行计算 + +当使用张量并行(TP > 1)时,词表在 TP ranks 间分片,每个 rank 本地持有 `V / tp_size` 个 logits。KL 在数值稳定的 TP 感知方式下计算: + +1. **局部 softmax**:`s_max` 和 `s_sum_exp` 跨 TP ranks 做 all-reduce。 +2. 使用全局归一化因子在本地计算完整的 softmax 概率/对数概率。 +3. `vocab_parallel_reverse_kl` 累加局部 KL 贡献:`KL_local = Σ_{y_local} π_s(y)[log π_s(y) - log π_t(y)]`,由于每个 token 恰好出现在一个 TP rank 上,累加结果等于完整 KL。 + +### 特点 + +- **所需数据**:每个样本每个教师的完整 logits `[R_i, V/TP]`(在 rollout 前向传播时计算)。 +- **内存**:非常高——每 GPU 每个教师的 rollout 存储为 `B × R × (V/TP) × 4B`(fp32)。示例:V=152K, TP=2, B=4, R=4096 → 4×4096×76K×4 ≈ 4.6 GB;V=248K, TP=8, B=4, R=4096 → 4×4096×31K×4 ≈ 1.9 GB。 +- **教师模式**:仅 Megatron 模式(`--mopd-teacher-loads`),因为 SGLang 无法高效返回完整 logit 向量。 +- **精度**:精确 KL——蒸馏质量的金标准。 + +--- + +## 3. Top-K 模式 (`--mopd-distill-type top_k`) + +### 核心思想 + +full_vocab 的**内存高效近似**。不存储教师完整词表的 logits,仅保留 top-k 的 logits 和索引,加上对剩余词表的解析尾部校正。 + +### 公式推导 + +KL 散度分解为两部分: + +``` +D_KL(π_θ ∥ π_d) ≈ KL_topk + KL_tail +``` + +#### Top-K 部分 — 在教师的 top-k token 上精确计算 + +``` +KL_topk = Σ_{y ∈ top-k} π_s(y) [log π_s(y) - log π_t(y)] +``` + +对每个位置,教师提供其 top-k logit 值和对应的 token 索引。学生使用索引 gather 对应位置的概率,并在 top-k 支撑集上精确计算 KL。 + +**教师 log-prob 的计算**:由于只有 top-k logits,无法得到完整词表的归一化因子 `Z_t`。因此先用 top-k 内部的归一化计算近似 log-prob: + +``` +log π_t_approx(y) = (logit_t(y) - max_topk) - log(Σ_{y'∈top-k} exp(logit_t(y') - max_topk)), y ∈ top-k +``` + +然后通过尾部校正补偿缺失的概率质量。 + +#### 尾部校正 — 近似非 top-k token 的 KL 贡献 + +``` +KL_tail ≈ π_s_tail · log(π_s_tail / π_t_tail) +``` + +其中: + +- `π_s_tail = 1 - Σ_{y ∈ top-k} π_s(y)` — 学生的精确尾部概率质量(通过跨 TP ranks 的 all-reduce 计算)。 +- `π_t_tail ≈ (V - k × tp_size) / V` — 教师的估计尾部概率质量,假设非 top-k token 上服从均匀分布。 + +**为什么这是一个保守上界?** 均匀分布假设通常会**高估**教师的尾部概率质量(实际上教师 top-k 的 logits 主导了概率分布,真实尾部质量更小),因此近似 KL 会**略微高估**真实 KL。这意味着蒸馏时会略微过度正则化(更偏向教师),这对蒸馏来说是**安全**的。 + +### 完整损失 + +``` +w_t = sg[π_θ(y_t) / μ_θ(y_t)] 截断至 [ε_low, ε_high] # 重要性采样权重 +L_topk_kl = (1/D) Σ_d (1/|y| Σ_t w_t · KL_topk+d(π_θ ∥ π_d)) # 经 IS 校正的近似 KL +L = L_topk_kl + α · L_pg # 与策略损失组合 +``` + +### TP 并行计算 + +当使用张量并行(TP > 1)时: + +- 教师的 top-k 索引是每个 TP 分片的**局部索引**(因为教师 logits 是词表分片的,`topk` 在每个分片内选择 top-k)。 +- 学生直接在这些局部索引位置 gather 概率——**无需跨分片索引转换**。 +- 需要的 all-reduce 操作:全局 `s_max`、全局 `s_sum_exp`、全局 `t_max`、全局 `t_sum_exp`、全局 `student_topk_mass`、全局 `local_kl_topk`。 + +### 特点 + +- **所需数据**:每个样本每个教师的 top-k logits + 索引 `[R_i, k]`(k 由 `--mopd-topk-k` 控制,默认 1024)。 +- **内存**:非常低——每 GPU 每个教师的 rollout 存储为 `B × R × k × 2 × 4B / TP`(fp32 logits + int32/int64 indices)。与 full_vocab 的比值约为 `2k/V`,内存减少约 `1 - 2k/V`。示例:k=1024, V=152K, TP=2, B=4, R=4096 → ≈ 128 MB vs full_vocab 的 ≈ 4.6 GB(**~97% 减少**);k=1024, V=248K, TP=8, B=4, R=4096 → ≈ 16 MB vs ≈ 1.9 GB。 +- **教师模式**:仅 Megatron 模式(`--mopd-teacher-loads`)。 +- **精度**:非常接近 full_vocab——top-k token 捕获了绝大部分概率质量,尾部校正提供了对剩余贡献的有界估计。 + +--- + +## 选择建议 + +- **使用 `token_level`**:如果你使用 SGLang 教师,或需要最快的迭代速度和最小内存开销。 +- **使用 `top_k`**(Megatron 教师推荐默认值):精度和效率的最佳平衡。从 `--mopd-topk-k 1024` 开始;如果词表非常大或需要更高精度,可增加到 2048 或 4096。 +- **使用 `full_vocab`**:仅当你需要精确 KL 且 GPU 内存充足时。通常仅用于研究验证或小规模实验。 + +--- + +## 关键参数 + +| 参数 | 说明 | +|------|------| +| `--use-mopd` | 启用多教师在线策略蒸馏。与 `--use-opd` 互斥。 | +| `--mopd-teachers` | 教师配置的 JSON 列表,每项含 `name` 和 `domain`(必填)。示例:`'[{"name":"math_t","domain":"math"},{"name":"code_t","domain":"code"}]'` | +| `--mopd-teacher-loads` | Megatron 模式教师的检查点路径,空格分隔。数量须与 `--mopd-teachers` 中的教师数一致。 | +| `--mopd-teacher-ckpt-steps` | 每个教师模型的可选检查点步数。数量须与教师数一致。 | +| `--mopd-alpha` | MOPD 优势与 ORM 优势的组合系数(默认 0.0)。0 为纯蒸馏,>0 为蒸馏+ORM 组合。 | +| `--mopd-eps-low` | IS 权重截断下界(默认 0.2)。低于此值的权重置零。 | +| `--mopd-eps-high` | IS 权重截断上界(默认 5.0)。高于此值的权重置零。 | +| `--mopd-sampling-logprobs-key` | rollout_data 中用于 IS 权重计算的采样 log-probs 键名(默认 `rollout_log_probs`)。 | +| `--mopd-distill-type` | 蒸馏类型:`token_level`(默认)使用采样 token 的 log-prob 差作为反向 KL 近似;`full_vocab` 计算精确的全词表反向 KL 散度;`top_k` 使用教师 top-k logits + 尾部校正计算近似 KL。`full_vocab` 和 `top_k` 均需要 `--mopd-teacher-loads`(Megatron 模式)。详见[算法](#算法)部分。 | +| `--mopd-topk-k` | `top_k` 蒸馏时每个位置保留的 top-k token 数(默认 1024)。k 越大 KL 近似越精确但内存越多。仅在 `--mopd-distill-type=top_k` 时生效。 | + +--- + +## SGLang 模式 vs Megatron 模式 + +| 模式 | 教师位置 | 何时使用 | +|------|----------|----------| +| `sglang` | 外部 SGLang 服务器(每个教师一个) | 教师架构不同,或太大无法放入训练 GPU 内存 | +| `megatron` | 加载到 Megatron 训练进程 | 教师与策略/参考模型架构相同 | + +### SGLang 模式 + +- 每个教师作为独立 SGLang 服务器运行。 +- 教师网址通过 `MOPD_TEACHER_URLS` 环境变量(JSON 字典:`domain -> URL`)或 `--mopd-teachers` 中每个教师配置的 `rm_url` 字段配置。 +- 需配置 `--custom-rm-path slime.rollout.mopd.reward_func` 和 `--custom-reward-post-process-path slime.rollout.mopd.post_process_rewards`。 +- `--rm-url` 作为未配置单教师 URL 时的回退。 + +### Megatron 模式 + +- 教师模型通过 `TensorBackuper` 加载到 CPU 内存,训练时切换到 GPU 进行前向传播。 +- 需 `--enable-weights-backuper`(默认开启)用于权重备份/恢复。 +- 每个教师必须与策略模型**架构相同**。 +- 内存注意:每个教师模型额外占用 CPU 内存用于权重备份。 + +--- + +## 核心代码组件 + +- `slime/rollout/mopd.py`:SGLang 模式 MOPD 的实现。 + - `reward_func`:并发查询所有教师 SGLang 服务器,返回按域名分组的响应。 + - `post_process_rewards`:从响应中提取 token 级教师 log-probs,存入 `sample.mopd_teacher_log_probs`。 +- `slime/backends/megatron_utils/actor.py`:Rollout 阶段教师前向传播。 + - `token_level` 模式:调用 `compute_log_prob` 获取教师 log-probs。 + - `full_vocab` / `top_k` 模式:调用 `compute_log_prob(return_logits=True)` 获取完整 logits,`top_k` 额外执行 `topk()` 截取。 +- `slime/backends/megatron_utils/loss.py`:训练阶段的损失计算。 + - `apply_mopd_to_advantages`:计算每个教师的反向 KL、IS 权重和聚合的 MOPD 优势(token_level 模式)。 + - `apply_mopd_full_vocab_to_loss`:计算精确全词表 KL 损失(full_vocab 模式)。 + - `apply_mopd_topk_to_loss`:计算 top-k 近似 KL 损失(top_k 模式)。 + - `policy_loss_function`:根据 `mopd_distill_type` 应用相应的损失组合。 +- `slime/utils/ppo_utils.py`: + - `vocab_parallel_reverse_kl`:TP 感知的全词表反向 KL 计算。 + - `vocab_parallel_topk_reverse_kl`:TP 感知的 top-k 近似反向 KL 计算(含尾部校正)。 +- `slime/utils/arguments.py`:`--mopd-distill-type` 和 `--mopd-topk-k` 参数定义与验证。 + +--- + +## 数据流 + +### Token-Level 模式 + +``` +Rollout 阶段: + 教师 → compute_log_prob() → log π_d(y_t) (每 token 1 标量) → rollout_data["mopd_teacher_log_probs"] + +训练阶段: + ① apply_mopd_to_advantages(): + reverse_kl_d = log π_d(y_t) - log π_θ(y_t) + Â_MOPD = avg_d(reverse_kl_d + α·Â_ORM) + IS权重 w_t = clip(π_θ/μ_θ, [ε_low, ε_high]) + ② 替换优势: pg_loss 用 Â_MOPD 替代 Â_ORM + ③ 乘以 IS 权重: loss = pg_loss × w_t +``` + +### Full-Vocab / Top-K 模式 + +``` +Rollout 阶段: + 教师 → compute_log_prob(return_logits=True) → 完整 logits [R_i, V_local] + ├─ full_vocab: 直接存储所有 logits → rollout_data["mopd_teacher_{domain}_fv_logits"] + └─ top_k: topk(k, dim=-1) 截取 → rollout_data["mopd_teacher_{domain}_topk_logits"] + rollout_data["mopd_teacher_{domain}_topk_indices"] + +训练阶段: + ① 学生前向传播 → get_logits(apply_temperature=False) → student_logits [R_i, V_local] + ② full_vocab: vocab_parallel_reverse_kl(student_logits, teacher_logits) → D_KL 精确值 + top_k: vocab_parallel_topk_reverse_kl(student_logits, teacher_topk_logits, teacher_topk_indices) + → KL_topk + KL_tail → 近似 D_KL + ③ IS 权重: w_t = clip(π_θ/μ_θ, [ε_low, ε_high]) + ④ 加权 KL: L_kl = (1/D) Σ_d (1/|y| Σ_t w_t · KL_d) + ⑤ 损失组合: L = L_kl + α · L_pg (α=0 时纯蒸馏) +``` + +### 内存对比(单个教师) + +每个 GPU 上每种模式每样本的内存公式(fp32): + +| | Rollout 教师存储 (per GPU) | top_k / full_vocab 比值 | +|---|---|---| +| `token_level` | `B × R × 4B` (1 scalar/token) | — | +| `top_k` | `B × R × k × 2 × 4B / TP` | `2k / V` | +| `full_vocab` | `B × R × V × 4B / TP` | 1 | + +训练时三种模式均需额外 `B × R × V × 4B / TP` 的学生 logits 内存(`top_k` 和 `full_vocab` 模式需要完整 logits 计算 softmax)。 + +参数说明:B=batch, R=平均响应长度, V=词表大小, k=`--mopd-topk-k`, TP=张量并行度。 + +**参考数值:** + +| 配置 | `token_level` | `top_k` (k=1024) | `full_vocab` | +|------|---------------|-------------------|--------------| +| V=152K, TP=2, B=4, R=4096 | ~64 KB | ~128 MB | ~4.6 GB | +| V=248K, TP=8, B=4, R=4096 | ~64 KB | ~16 MB | ~1.9 GB | +| V=152K, TP=2, B=8, R=2048 | ~32 KB | ~64 MB | ~2.3 GB | + +注意:训练时学生 logits 的内存开销在三种模式下相同(都需要完整 logits 来计算 KL),差异主要在教师的 rollout 存储。 + +--- + +## FAQ + +1. **MOPD 可以和 OPD 同时使用吗?** + 不可以。`--use-mopd` 和 `--use-opd` 互斥。需要多教师时使用 MOPD。 + +2. **所有教师需要架构相同吗?** + - Megatron 模式:是的,所有教师必须与策略模型架构相同。 + - SGLang 模式:不需要,每个教师可以是不同架构,因为它们运行在独立的服务器上。 + +3. **Megatron 模式下 MOPD 需要多少额外内存?** + 每个教师模型需要 CPU 内存用于权重备份(通过 `TensorBackuper`)。教师权重仅在训练前向传播时临时加载到 GPU,然后恢复到 CPU。按 `N × model_size` 规划额外 CPU 内存,N 为教师数量。此外 `full_vocab` 模式还需要大量 GPU 内存存储教师 logits(见上文内存对比表)。 + +4. **SGLang 模式下教师服务器故障怎么办?** + `reward_func` 会记录警告并跳过该教师。训练会继续使用剩余教师,但优势会有偏。请密切监控教师服务器健康状态。 + +5. **为什么 `--group-rm` 不支持 MOPD?** + MOPD 的 `reward_func` 返回按域名的字典(非标量奖励),与批量 `group_rm` 奖励路径不兼容。使用默认的逐样本奖励路径(不加 `--group-rm`)。 + +6. **`token_level`、`top_k` 和 `full_vocab` 三种蒸馏类型有什么区别?** + - `token_level`(默认):使用采样 token 的 log-prob 差近似反向 KL。高效,支持 SGLang 和 Megatron,但仅在采样位置捕获 KL 信息。当学生和教师分布差异大时会低估真实 KL。 + - `top_k`:使用教师 top-k logits + 解析尾部校正计算近似反向 KL。内存高效(比 `full_vocab` 少 ~97%),仅 Megatron 模式。k=1024+ 时精度非常接近 `full_vocab`。 + - `full_vocab`:计算精确的全词表反向 KL 散度。最精确但内存密集(存储完整 `[R, V]` logits)。仅 Megatron 模式。 + - 详见[算法](#算法)部分和[对比表](#概述)。 + +7. **什么时候用 `top_k`,什么时候用 `full_vocab`?** + 大多数生产场景用 `top_k`——以 `full_vocab` 约 3% 的内存捕获 >99% 的 KL 信号。只有在需要精确 KL 做研究验证、或 GPU 内存十分充裕时才用 `full_vocab`。从 `--mopd-topk-k 1024` 开始;如果词表非常大(如 V > 200K)或发现尾部校正过于激进,可增大到 2048 或 4096。 + +8. **`top_k` 的尾部校正原理是什么?** + top-k 分解将 KL 分为 `KL_topk`(在教师 top-k token 上精确计算)和 `KL_tail`(近似剩余 V−k 个 token 的贡献)。尾部假设非 top-k token 上教师服从均匀分布:`π_t_tail ≈ (V − k·tp_size) / V`。这是一个保守上界——真实的教师尾部质量通常更小(教师 top-k 占据了主要概率),因此近似 KL 会略微高估真实 KL,意味着略微过度正则化(更偏向教师),对蒸馏是安全的。 + +9. **三种蒸馏模式的内存用量?** + 每种模式的 GPU 内存与词表大小 V、张量并行度 TP、批量 B、响应长度 R 成正比: + - `token_level`:可忽略(`B × R × 4B`)。 + - `top_k`(k=1024):`B × R × k × 2 × 4B / TP`。比例约为 full_vocab 的 `2k/V`(约 1~2%)。 + - `full_vocab`:`B × R × V × 4B / TP`。 + 具体数值因模型而异,例如 V=152K/TP=2/B=4/R=4096 时 top_k ≈ 128 MB、full_vocab ≈ 4.6 GB;而 V=248K/TP=8/B=4/R=4096 时 top_k ≈ 16 MB、full_vocab ≈ 1.9 GB。 + 如果 `full_vocab` OOM,切换到 `top_k` 或减小 `--rollout-batch-size` / `--rollout-max-response-len`。 + +10. **`top_k` 模式的 `k` 值怎么选?** + - `k=1024`(默认):适用于大多数场景,平衡精度和内存。 + - `k=2048`:词表较大(V > 200K)时推荐,进一步减少尾部校正的近似误差。 + - `k=4096`:需要更高精度时使用,内存仍远小于 `full_vocab`。 + - 经验法则:k/V > 0.5% 即可捕获 >99% 的 KL 信号,因为教师分布通常高度集中。 \ No newline at end of file diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-megatron.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-megatron.sh new file mode 100644 index 0000000000..f68986e418 --- /dev/null +++ b/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-megatron.sh @@ -0,0 +1,220 @@ +#!/bin/bash + +# Multi-Teacher On-Policy Distillation (MOPD) — Top-K KL Divergence Mode +# Model: Qwen3.5-397B-A17B (MoE, 512 experts, 10 active) +# Environment: 16 nodes × 8 L20X (143GB each), 128 GPUs total +# Teacher: Skin-multiturn teacher (different from student for production distillation) +# Mode: Megatron (teacher loaded into CPU memory via TensorBackuper) +# Distill Type: top_k (approximate reverse KL with top-k teacher logits + tail correction) +# +# This script is for MOPD top_k distillation with 128 GPUs. +# +# Key features of top_k mode: +# --mopd-distill-type top_k +# → Computes approximate D_KL(π_θ ∥ π_d) using teacher's top-k logits +# plus tail probability correction. Much more memory-efficient than full_vocab. +# → Stores only [R_i, k] teacher logits+indices per sample (k=1024 default), +# vs [R_i, V] for full_vocab. ~98.7% memory reduction vs full_vocab. +# +# Prerequisites: +# 1. Convert HF checkpoint to Megatron torch_dist format before first run: +# cd /path/to/slime +# source scripts/models/qwen3.5-397B-A17B.sh +# +# PYTHONPATH=/root/Megatron-LM torchrun --nproc_per_node=8 \ +# tools/convert_hf_to_torch_dist.py \ +# ${MODEL_ARGS[@]} \ +# --hf-checkpoint /personal/ckpt/Qwen3.5-397B-A17B_skin_multiturn \ +# --save /personal/ckpt/Qwen3.5-397B-A17B_skin_multiturn_torch_dist +# +# usage: bash examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-megatron.sh + +set -ex + +export PYTHONBUFFERED=16 +export FLASHINFER_DISABLE_VERSION_CHECK=1 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SLIME_DIR="/workspace/bin/slime" +source "${SLIME_DIR}/scripts/models/qwen3.5-397B-A17B.sh" + +# ============================================================================ +# Paths — adjust these to your environment +# ============================================================================ +BASE_DIR=/personal/ckpt + +HF_CKPT=${BASE_DIR}/Qwen3.5-397B-A17B_Swift_SFT_Stage3b_Text1p5 +TORCH_DIST_CKPT=${BASE_DIR}/Qwen3.5-397B-A17B_Swift_SFT_Stage3b_Text1p5_torch_dist +TEACHER_TORCH_DIST_CKPT=${BASE_DIR}/Qwen3.5-397B-A17B_skin_multiturnl_torch_dist +SAVE_DIR=/amed/share/s1-amed-spfs-ckpt/yanyi/Qwen3.5-397B-A17B-Stage3b-Mopd-Topk-Skin-Multiturn-Enhanced + +DATA_PATH="/mnt/amed-s3/dataset/14019ba0_text_report_Interpretation/a3967912440becb0d70748a478696f12b6bbf6ac/train_text_think_nothink.jsonl" + +# MOPD teachers JSON config +export MOPD_TEACHERS_JSON='[{"name":"skin-multiturn","domain":"default"}]' + +# ============================================================================ +# Configure training arguments +# ============================================================================ + +CKPT_ARGS=( + --hf-checkpoint ${HF_CKPT}/ + --ref-load ${TORCH_DIST_CKPT}/ + --load ${SAVE_DIR}/ + --save ${SAVE_DIR}/ + --save-interval 10 + --no-save-optim +) + +ROLLOUT_ARGS=( + --prompt-data ${DATA_PATH} + --input-key messages + --apply-chat-template + --rollout-shuffle + --rollout-batch-size 64 + --n-samples-per-prompt 1 + --rollout-max-response-len 4096 + --rollout-temperature 0.5 + + --global-batch-size 64 + --balance-data + --num-epoch 1 +) + +RM_ARGS=() + +EVAL_ARGS=() + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 128 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 4096 +) + +MOPD_ARGS=( + --advantage-estimator grpo + + # MOPD flags — single teacher + --use-mopd + + # token level + # --mopd-distill-type token_level + + # top k + --mopd-distill-type top_k + --mopd-topk-k 1024 + + # full vocab + # --mopd-distill-type full_vocab + + --mopd-teacher-loads ${TEACHER_TORCH_DIST_CKPT}/ + + # MOPD hyperparameters + --mopd-alpha 0.0 # Pure distillation, no ORM + --mopd-eps-low 0.2 # IS weight lower bound + --mopd-eps-high 5.0 # IS weight upper bound + --mopd-sampling-logprobs-key rollout_log_probs + + # Standard training flags + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 5e-7 # Conservative LR for stability + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 + + # CPU offload optimizer to save GPU memory for large model + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer +) + +WANDB_ARGS=() + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 16 + --sglang-mem-fraction-static 0.45 + --sglang-ep-size 16 +) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash + + --moe-token-dispatcher-type alltoall + # --moe-enable-deepep # DeepEP internode kernel assertion fails when EP=128 (num_topk_ranks > kNumTopkRDMARanks) + --no-check-for-nan-in-loss-and-grad + + --colocate +) + +# ============================================================================ +# Launch training — multi-node setup +# ============================================================================ + +# --- Submit job --- +RUNTIME_ENV_JSON=$(python3 -c " +import json, os +env = { + 'PYTHONPATH': '/root/Megatron-LM/', + 'CUDA_DEVICE_MAX_CONNECTIONS': '1', + 'NCCL_DEBUG': 'WARN', + 'NCCL_NVLS_ENABLE': os.environ.get('HAS_NVLINK', '0'), + 'NCCL_TIMEOUT_MS': '36000000', + 'FLASHINFER_DISABLE_VERSION_CHECK': '1', + 'MOPD_TEACHERS_JSON': os.environ.get('MOPD_TEACHERS_JSON', '') +} +print(json.dumps({'env_vars': env})) +") + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 ../workspace/bin/slime/train.py \ + --actor-num-nodes 16 \ + --actor-num-gpus-per-node 8 \ + --update-weight-buffer-size $(( 1024 * 1024 * 1024 * 4 )) \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${MOPD_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${RM_ARGS[@]} + +# ============================================================================ +# Cleanup +# ============================================================================ +pkill -9 sglang +sleep 3 +pkill -9 python \ No newline at end of file diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh new file mode 100755 index 0000000000..2019b8dc23 --- /dev/null +++ b/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh @@ -0,0 +1,235 @@ +#!/bin/bash + +# Multi-Teacher On-Policy Distillation (MOPD) — Top-K KL Divergence, SGLang Mode +# Model: Qwen3.5-397B-A17B (MoE, 512 experts, 10 active) +# Environment: 16 nodes × 8 L20X (143GB each), 128 GPUs total +# Teacher: Skin-multiturn teacher (running on external SGLang servers) +# Mode: SGLang (teacher runs on separate SGLang inference servers, no CPU OOM) +# Distill Type: top_k (approximate reverse KL with top-k teacher logits + tail correction) +# +# SGLang mode avoids loading teacher model weights into the Megatron training process, +# eliminating the CPU RAM overhead of TensorBackuper pin_memory backups (~150GB per +# teacher for 397B MoE on each node). The teacher runs on independent SGLang servers, +# and its top-k logprobs are collected during rollout via HTTP requests. +# +# Key differences from Megatron top_k mode: +# - No --mopd-teacher-loads (no Megatron checkpoint needed for teacher) +# - No --enable-weights-backuper needed for teacher +# - Teacher can have a DIFFERENT architecture than student +# - custom-rm-path and custom-reward-post-process-path are auto-configured +# - MOPD_TEACHER_URLS env var specifies the SGLang teacher server endpoints +# +# Prerequisites: +# 1. Start the SGLang teacher server(s) before running this script. +# Example for a single 397B MoE teacher on 16 GPUs: +# +# python3 -m sglang.launch_server \ +# --model-path /personal/ckpt/Qwen3.5-397B-A17B_skin_multiturn/ \ +# --host 0.0.0.0 --port 13141 \ +# --tp 8 --ep-size 16 \ +# --chunked-prefill-size 4096 \ +# --mem-fraction-static 0.7 +# +# 2. Convert student HF checkpoint to Megatron torch_dist format: +# cd /path/to/slime +# source scripts/models/qwen3.5-397B-A17B.sh +# +# PYTHONPATH=/root/Megatron-LM torchrun --nproc_per_node=8 \ +# tools/convert_hf_to_torch_dist.py \ +# ${MODEL_ARGS[@]} \ +# --hf-checkpoint /personal/ckpt/Qwen3.5-397B-A17B_Swift_SFT_Stage3b_Text1p5 \ +# --save /personal/ckpt/Qwen3.5-397B-A17B_Swift_SFT_Stage3b_Text1p5_torch_dist +# +# usage: bash examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh + +set -ex + +export PYTHONBUFFERED=16 +export FLASHINFER_DISABLE_VERSION_CHECK=1 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SLIME_DIR="/workspace/bin/slime" +source "${SLIME_DIR}/scripts/models/qwen3.5-397B-A17B.sh" + +# ============================================================================ +# Paths — adjust these to your environment +# ============================================================================ +BASE_DIR=/personal/ckpt + +HF_CKPT=${BASE_DIR}/Qwen3.5-397B-A17B_Swift_SFT_Stage3b_Text1p5 +TORCH_DIST_CKPT=${BASE_DIR}/Qwen3.5-397B-A17B_Swift_SFT_Stage3b_Text1p5_torch_dist +SAVE_DIR=/amed/share/s1-amed-spfs-ckpt/yanyi/Qwen3.5-397B-A17B-Stage3b-Mopd-Topk-Skin-Multiturn-Enhanced + +DATA_PATH="/mnt/amed-s3/dataset/14019ba0_text_report_Interpretation/a3967912440becb0d70748a478696f12b6bbf6ac/train_text_think_nothink.jsonl" + +# MOPD teachers JSON config (single teacher for this example) +export MOPD_TEACHERS_JSON='[{"name":"skin-multiturn","domain":"default"}]' + +# MOPD teacher SGLang server URLs +# For multi-teacher, add all domains: {"math":"https://...","code":"https://..."} +TEACHER_IP="aistudio.alipay.com/proxy/rayjob/aistudio-dvm9s0jw-tfjob-master-0" +TEACHER_PORT=8300 +export MOPD_TEACHER_URLS="{\"default\":\"https://$TEACHER_IP:$TEACHER_PORT/generate\"}" + +# ============================================================================ +# Configure training arguments +# ============================================================================ + +CKPT_ARGS=( + --hf-checkpoint ${HF_CKPT}/ + --ref-load ${TORCH_DIST_CKPT}/ + --load ${SAVE_DIR}/ + --save ${SAVE_DIR}/ + --save-interval 10 + --no-save-optim +) + +ROLLOUT_ARGS=( + --prompt-data ${DATA_PATH} + --input-key messages + --apply-chat-template + --rollout-shuffle + --rollout-batch-size 64 + --n-samples-per-prompt 1 + --rollout-max-response-len 4096 + --rollout-temperature 0.5 + + --global-batch-size 64 + --balance-data + --num-epoch 1 +) + +# No RM_ARGS needed for pure distillation (alpha=0). +# custom-rm-path and custom-reward-post-process-path are auto-configured +# by the MOPD SGLang mode argument validation. +RM_ARGS=( + --rm-url https://$TEACHER_IP:$TEACHER_PORT/generate +) + +EVAL_ARGS=() + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 128 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 4096 +) + +MOPD_ARGS=( + --advantage-estimator grpo + + # MOPD flags — single teacher + --use-mopd + + # SGLang teacher mode — teacher runs on external SGLang servers + --mopd-teacher-mode sglang + + # top_k distillation type + --mopd-distill-type top_k + --mopd-topk-k 128 + + # No --mopd-teacher-loads in SGLang mode! + # Teacher data comes from SGLang server via HTTP during rollout. + + # MOPD hyperparameters + --mopd-alpha 0.0 # Pure distillation, no ORM + --mopd-eps-low 0.2 # IS weight lower bound + --mopd-eps-high 5.0 # IS weight upper bound + --mopd-sampling-logprobs-key rollout_log_probs + + # Standard training flags + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 5e-7 # Conservative LR for stability + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 + + # CPU offload optimizer to save GPU memory for large model + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer +) + +WANDB_ARGS=() + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 16 + --sglang-mem-fraction-static 0.45 + --sglang-ep-size 16 +) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash + + --moe-token-dispatcher-type alltoall + # --moe-enable-deepep # DeepEP internode kernel assertion fails when EP=128 + --no-check-for-nan-in-loss-and-grad + + --colocate +) + +# ============================================================================ +# Launch training — multi-node setup +# ============================================================================ + +# --- Submit job --- +RUNTIME_ENV_JSON=$(python3 -c " +import json, os +env = { + 'PYTHONPATH': '/root/Megatron-LM/', + 'CUDA_DEVICE_MAX_CONNECTIONS': '1', + 'NCCL_DEBUG': 'WARN', + 'NCCL_NVLS_ENABLE': os.environ.get('HAS_NVLINK', '0'), + 'NCCL_TIMEOUT_MS': '36000000', + 'FLASHINFER_DISABLE_VERSION_CHECK': '1', + 'MOPD_TEACHER_URLS': os.environ.get('MOPD_TEACHER_URLS', ''), + 'MOPD_TEACHERS_JSON': os.environ.get('MOPD_TEACHERS_JSON', '') +} +print(json.dumps({'env_vars': env})) +") + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 ../workspace/bin/slime/train.py \ + --actor-num-nodes 16 \ + --actor-num-gpus-per-node 8 \ + --update-weight-buffer-size $(( 1024 * 1024 * 1024 * 4 )) \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${MOPD_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${RM_ARGS[@]} \ No newline at end of file diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index bef7986f29..094f7ad441 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -114,14 +114,21 @@ def init( # Load multiple teacher models for Megatron-based MOPD self._mopd_teacher_domains: list[str] = [] - if getattr(args, "use_mopd", False) and getattr(args, "mopd_teacher_loads", None): - mopd_teachers = json.loads(args.mopd_teachers) if isinstance(args.mopd_teachers, str) else args.mopd_teachers - for i, teacher_cfg in enumerate(mopd_teachers): - domain = teacher_cfg["domain"] - tag = f"mopd_teacher_{domain}" - self._mopd_teacher_domains.append(domain) - self.load_other_checkpoint(tag, args.mopd_teacher_loads[i]) - logger.info(f"Loaded MOPD teacher model for domain '{domain}' from {args.mopd_teacher_loads[i]}") + mopd_teacher_mode = getattr(args, "mopd_teacher_mode", "megatron") + if getattr(args, "use_mopd", False): + if mopd_teacher_mode == "megatron" and getattr(args, "mopd_teacher_loads", None): + mopd_teachers = json.loads(args.mopd_teachers) if isinstance(args.mopd_teachers, str) else args.mopd_teachers + for i, teacher_cfg in enumerate(mopd_teachers): + domain = teacher_cfg["domain"] + tag = f"mopd_teacher_{domain}" + self._mopd_teacher_domains.append(domain) + self.load_other_checkpoint(tag, args.mopd_teacher_loads[i]) + logger.info(f"Loaded MOPD teacher model for domain '{domain}' from {args.mopd_teacher_loads[i]}") + elif mopd_teacher_mode == "sglang": + logger.info( + "MOPD SGLang teacher mode: skipping Megatron teacher model loading. " + "Teacher data will be collected from SGLang remote servers during rollout." + ) if self.args.keep_old_actor: # Load old_actor checkpoint @@ -269,36 +276,54 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch: ] # Process MOPD teacher log_probs (dict: domain -> list) - # Some entries may be None due to per-sample domain routing (SGLang mode). + # When teacher data is unavailable (e.g., HTTP request failure), entries + # may be None. We replace None with -inf tensors so all DP ranks execute + # the same backward operations, preventing NCCL deadlocks from + # inconsistent collective calls. if "mopd_teacher_log_probs" in rollout_data: mopd_lp_dict = rollout_data["mopd_teacher_log_probs"] processed = {} for domain, lp_list in mopd_lp_dict.items(): - processed[domain] = [ - ( - None - if log_prob is None - else torch.tensor( - slice_log_prob_with_cp( - log_prob, - total_length, - response_length, - self.args.qkv_format, - rollout_data["max_seq_lens"][i] if self.args.qkv_format == "bshd" else None, - ), - device=torch.cuda.current_device(), - dtype=torch.float32, - ) + domain_processed = [] + for i, (log_prob, total_length, response_length) in enumerate( + zip( + lp_list, + rollout_data["total_lengths"], + rollout_data["response_lengths"], + strict=False, ) - for i, (log_prob, total_length, response_length) in enumerate( - zip( - lp_list, - rollout_data["total_lengths"], - rollout_data["response_lengths"], - strict=False, + ): + if log_prob is None: + # Create a -inf tensor of the correct size as fallback. + # -inf log-probs produce zero KL contribution, so this + # domain has no effect on the loss for this sample. + sliced_len = len(slice_log_prob_with_cp( + torch.zeros(response_length), + total_length, + response_length, + self.args.qkv_format, + rollout_data["max_seq_lens"][i] if self.args.qkv_format == "bshd" else None, + )) + domain_processed.append( + torch.full((sliced_len,), float('-inf'), + device=torch.cuda.current_device(), + dtype=torch.float32) ) - ) - ] + else: + domain_processed.append( + torch.tensor( + slice_log_prob_with_cp( + log_prob, + total_length, + response_length, + self.args.qkv_format, + rollout_data["max_seq_lens"][i] if self.args.qkv_format == "bshd" else None, + ), + device=torch.cuda.current_device(), + dtype=torch.float32, + ) + ) + processed[domain] = domain_processed rollout_data["mopd_teacher_log_probs"] = processed if "rollout_routed_experts" in rollout_data: rollout_data["rollout_routed_experts"] = [ @@ -490,7 +515,10 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch, external_data ) # Forward each MOPD teacher model for Megatron-based MOPD - if getattr(self.args, "use_mopd", False) and hasattr(self, "_mopd_teacher_domains") and self._mopd_teacher_domains: + # Only applies when mopd_teacher_mode == "megatron". In SGLang mode, + # teacher data is collected during rollout and arrives in rollout_data. + mopd_teacher_mode = getattr(self.args, "mopd_teacher_mode", "megatron") + if getattr(self.args, "use_mopd", False) and mopd_teacher_mode == "megatron" and hasattr(self, "_mopd_teacher_domains") and self._mopd_teacher_domains: mopd_teacher_log_probs = {} mopd_distill_type = getattr(self.args, "mopd_distill_type", "token_level") use_full_vocab = mopd_distill_type == "full_vocab" @@ -521,15 +549,32 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch, external_data if logits_list: topk_logits_key = f"mopd_teacher_{domain}_topk_logits" topk_indices_key = f"mopd_teacher_{domain}_topk_indices" + topk_log_sum_exp_key = f"mopd_teacher_{domain}_topk_log_sum_exp" topk_logits_list = [] topk_indices_list = [] + topk_log_sum_exp_list = [] + tp_group = mpu.get_tensor_model_parallel_group() for sample_logits in logits_list: # sample_logits: [R_i, V_local] + # Compute log_sum_exp for exact tail mass estimation. + # This avoids the inaccurate uniform tail assumption + # (V - V_eff) / V which over-estimates tail mass + # when k << V, causing KL inflation of ~5+ nats. + local_max = sample_logits.max(dim=-1).values + dist.all_reduce(local_max, op=dist.ReduceOp.MAX, group=tp_group) + # Numerically stable log_sum_exp + shifted = sample_logits - local_max.unsqueeze(-1) + local_sum_exp = shifted.exp().sum(dim=-1) + dist.all_reduce(local_sum_exp, op=dist.ReduceOp.SUM, group=tp_group) + log_sum_exp = (local_sum_exp + 1e-20).log() + local_max + topk_log_sum_exp_list.append(log_sum_exp.detach().float()) + topk_vals, topk_idx = sample_logits.topk(topk_k, dim=-1) topk_logits_list.append(topk_vals.detach().float()) topk_indices_list.append(topk_idx.detach().int()) rollout_data[topk_logits_key] = topk_logits_list rollout_data[topk_indices_key] = topk_indices_list + rollout_data[topk_log_sum_exp_key] = topk_log_sum_exp_list else: # Token-level mode: only need log_probs teacher_result = self.compute_log_prob( @@ -543,6 +588,124 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch, external_data if mopd_teacher_log_probs: rollout_data["mopd_teacher_log_probs"] = mopd_teacher_log_probs + # SGLang MOPD mode: convert rollout-collected top-k data to per-domain batch format + if getattr(self.args, "use_mopd", False) and mopd_teacher_mode == "sglang": + mopd_distill_type = getattr(self.args, "mopd_distill_type", "token_level") + if mopd_distill_type == "top_k": + # Convert SGLang-sourced top-k data (nested dict format from rollout) + # to per-domain batch keys matching the Megatron loss function's expected format. + sglang_topk_logits = rollout_data.pop("mopd_teacher_topk_logits", None) + sglang_topk_indices = rollout_data.pop("mopd_teacher_topk_indices", None) + if sglang_topk_logits and sglang_topk_indices: + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + padded_vocab_size = self.args.padded_vocab_size + vocab_local_size = padded_vocab_size // tp_size + vocab_offset = tp_rank * vocab_local_size + topk_k = self.args.mopd_topk_k + + # Check that SGLang teacher's vocab size is consistent + # with the student's padded_vocab_size. If teacher + # token IDs exceed the student vocab range, the + # global→local TP index conversion will produce + # silently incorrect results. + _vocab_checked = False + + for domain in sglang_topk_logits: + topk_logits_key = f"mopd_teacher_{domain}_topk_logits" + topk_indices_key = f"mopd_teacher_{domain}_topk_indices" + # Convert each sample's [seq_len][k] Python lists to tensors on GPU + topk_logits_list = [] + topk_indices_list = [] + for i, (logits_per_sample, indices_per_sample) in enumerate( + zip(sglang_topk_logits[domain], sglang_topk_indices[domain]) + ): + if logits_per_sample is None or indices_per_sample is None: + # Fallback: create zero-contribution tensors so all DP + # ranks execute the same backward operations, preventing + # NCCL deadlocks from inconsistent collective calls. + # Use -inf logits → zero KL divergence contribution. + seq_len = rollout_data["response_lengths"][i] + topk_logits_list.append( + torch.full((seq_len, topk_k), float('-inf'), + device=torch.cuda.current_device(), + dtype=torch.float32) + ) + topk_indices_list.append( + torch.zeros((seq_len, topk_k), + device=torch.cuda.current_device(), + dtype=torch.int64) + ) + else: + # SGLang returns GLOBAL token IDs, but the Megatron loss + # function (vocab_parallel_topk_reverse_kl) expects LOCAL + # indices within each TP shard's vocab range, with each + # shard having exactly k entries per position. + # + # Strategy: For each position, scatter the global top-k + # entries to the appropriate shard. Each TP rank keeps + # entries whose global token ID falls in its range + # [vocab_offset, vocab_offset + vocab_local_size), + # converts to local index, and pads to k entries with + # local_idx=0, logit=-inf (contributing nothing to KL). + global_indices = torch.tensor( + indices_per_sample, device=torch.cuda.current_device(), dtype=torch.int64 + ) # [seq_len, k_global] + global_logits = torch.tensor( + logits_per_sample, device=torch.cuda.current_device(), dtype=torch.float32 + ) # [seq_len, k_global] + + # Vocab consistency check (once per actor step) + if not _vocab_checked: + _vocab_checked = True + max_token_id = global_indices.max().item() + if max_token_id >= padded_vocab_size: + logger.error( + f"MOPD top_k: SGLang teacher returned token ID " + f"{max_token_id} which exceeds student " + f"padded_vocab_size={padded_vocab_size}. " + f"The teacher and student vocab sizes are " + f"mismatched — this will produce incorrect " + f"TP index conversion and wrong KL divergence. " + f"Ensure the teacher model uses the same " + f"tokenizer/vocab as the student." + ) + + seq_len = global_indices.size(0) + # Mask for which entries are in this shard + in_shard = (global_indices >= vocab_offset) & (global_indices < vocab_offset + vocab_local_size) + # Convert to local indices + local_indices = global_indices - vocab_offset + # Clamp out-of-range indices to 0 (will be overridden by -inf logits) + local_indices = local_indices.clamp(min=0, max=vocab_local_size - 1) + + # Build per-shard top-k: assign in-shard entries, pad rest with -inf + # For each position, we need exactly k entries + local_topk_logits = torch.full( + (seq_len, topk_k), float('-inf'), + device=torch.cuda.current_device(), dtype=torch.float32 + ) + local_topk_indices = torch.zeros( + (seq_len, topk_k), + device=torch.cuda.current_device(), dtype=torch.int64 + ) + + # Scatter: for each position, place the in-shard entries into + # the first available slots. We do this row-by-row for clarity. + for row in range(seq_len): + shard_mask = in_shard[row] # [k_global] + shard_logits = global_logits[row][shard_mask] + shard_local_idx = local_indices[row][shard_mask] + n_in_shard = min(shard_logits.size(0), topk_k) + if n_in_shard > 0: + local_topk_logits[row, :n_in_shard] = shard_logits[:n_in_shard] + local_topk_indices[row, :n_in_shard] = shard_local_idx[:n_in_shard] + + topk_logits_list.append(local_topk_logits) + topk_indices_list.append(local_topk_indices) + rollout_data[topk_logits_key] = topk_logits_list + rollout_data[topk_indices_key] = topk_indices_list + self._switch_model("old_actor" if self.args.keep_old_actor else "actor") can_reuse_log_probs_in_loss = ( len(num_microbatches) == 1 diff --git a/slime/backends/megatron_utils/data.py b/slime/backends/megatron_utils/data.py index 80e66a3ba9..5d9704931f 100644 --- a/slime/backends/megatron_utils/data.py +++ b/slime/backends/megatron_utils/data.py @@ -420,13 +420,16 @@ def log_rollout_data( "mopd_teacher_log_probs", "mopd_teacher_logits", "mopd_reverse_kl", + # SGLang-sourced top-k data (dict format, converted to per-domain keys in actor.py) + "mopd_teacher_topk_logits", + "mopd_teacher_topk_indices", ]: continue # Skip per-domain full-vocab teacher logits (too large for averaging) if key.startswith("mopd_teacher_") and key.endswith("_fv_logits"): continue # Skip per-domain top-k teacher logits/indices (too large for averaging) - if key.startswith("mopd_teacher_") and (key.endswith("_topk_logits") or key.endswith("_topk_indices")): + if key.startswith("mopd_teacher_") and (key.endswith("_topk_logits") or key.endswith("_topk_indices") or key.endswith("_topk_log_sum_exp")): continue # Upload per sample mean for each rollout value # There are the following assumptions: diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 8b645d12b6..280fd57426 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -728,6 +728,8 @@ def apply_mopd_to_advantages( for i in range(len(advantages)): # If this sample has no teacher log-probs for this domain (per-sample routing), # use zeros as placeholder — this domain contributes nothing to this sample. + # Also detect fallback sentinel tensors (all -inf) that were inserted when + # MOPD teacher requests failed, to avoid contaminating advantages with -inf. if teacher_lp_list[i] is None: domain_advantages.append(None) domain_is_weights.append(None) @@ -735,6 +737,13 @@ def apply_mopd_to_advantages( continue teacher_lp = teacher_lp_list[i].to(device=device) + if teacher_lp.isinf().all(): + # All -inf: teacher data was unavailable (fallback sentinel). + # Treat same as None — this domain contributes nothing. + domain_advantages.append(None) + domain_is_weights.append(None) + domain_reverse_kls.append(None) + continue # reverse_kl = log(π_domain(y_t)) - log(π_θ(y_t)), with stop-gradient # student_log_probs here is π_θ (the training engine log-probs) @@ -1217,6 +1226,7 @@ def apply_mopd_topk_to_loss( loss_masks: list[torch.Tensor], sum_of_sample_mean: Callable[[torch.Tensor], torch.Tensor], current_log_probs: list[torch.Tensor] | None = None, + teacher_topk_log_sum_exp_per_domain: dict[str, list[torch.Tensor | None]] | None = None, ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """Compute the top-k approximate reverse KL divergence loss for MOPD. @@ -1246,6 +1256,12 @@ def apply_mopd_topk_to_loss( current_log_probs: List of per-sample log-probs from the current training forward pass. Used for importance sampling weight computation. If None, falls back to batch["log_probs"]. + teacher_topk_log_sum_exp_per_domain: Optional dict mapping domain to list of + per-sample teacher log_sum_exp tensors [R_i] (computed from full-vocab logits + during Megatron teacher forward pass). Used for exact tail mass estimation + in Megatron mode. When provided, teacher_tail_mass is computed exactly as + 1 - sum(exp(topk_logits - log_sum_exp)) instead of the uniform assumption + (V - V_eff) / V. Returns: Tuple of (kl_loss, metrics) where kl_loss is a scalar tensor and @@ -1286,14 +1302,29 @@ def apply_mopd_topk_to_loss( continue t_topk_logits = teacher_topk_logits_per_domain[domain][i] # [R_i, k] + + # Skip fallback sentinel tensors (all -inf) from failed teacher requests. + # These would produce KL=0 anyway, so skipping avoids unnecessary + # computation and TP all-reduce calls. + if t_topk_logits.isinf().all(): + continue + t_topk_indices = teacher_topk_indices_per_domain[domain][i] # [R_i, k] + # Get teacher log_sum_exp for exact tail mass (Megatron mode only) + t_topk_log_sum_exp = None + if teacher_topk_log_sum_exp_per_domain and domain in teacher_topk_log_sum_exp_per_domain: + if i < len(teacher_topk_log_sum_exp_per_domain[domain]) and teacher_topk_log_sum_exp_per_domain[domain][i] is not None: + t_topk_log_sum_exp = teacher_topk_log_sum_exp_per_domain[domain][i] # [R_i] + kl_i = vocab_parallel_topk_reverse_kl( student_logits_per_sample[i], t_topk_logits, t_topk_indices, vocab_size, tp_group, + is_log_probs=(getattr(args, "mopd_teacher_mode", "megatron") == "sglang"), + teacher_log_sum_exp=t_topk_log_sum_exp, ) # [R_i] sample_kl_values.append(kl_i) valid_teacher_count += 1 @@ -1556,13 +1587,18 @@ def policy_loss_function( mopd_teachers_parsed = getattr(args, "_mopd_teachers_parsed", []) teacher_topk_logits_per_domain = {} teacher_topk_indices_per_domain = {} + teacher_topk_log_sum_exp_per_domain = {} for teacher_cfg in mopd_teachers_parsed: domain = teacher_cfg["domain"] topk_logits_key = f"mopd_teacher_{domain}_topk_logits" topk_indices_key = f"mopd_teacher_{domain}_topk_indices" + topk_log_sum_exp_key = f"mopd_teacher_{domain}_topk_log_sum_exp" if topk_logits_key in batch and batch[topk_logits_key] is not None: teacher_topk_logits_per_domain[domain] = batch[topk_logits_key] teacher_topk_indices_per_domain[domain] = batch[topk_indices_key] + # log_sum_exp is only available in Megatron mode (computed from full logits) + if topk_log_sum_exp_key in batch and batch[topk_log_sum_exp_key] is not None: + teacher_topk_log_sum_exp_per_domain[domain] = batch[topk_log_sum_exp_key] if teacher_topk_logits_per_domain: topk_kl_loss, mopd_fv_metrics = apply_mopd_topk_to_loss( @@ -1571,6 +1607,7 @@ def policy_loss_function( student_logits_per_sample=student_logits_per_sample, teacher_topk_logits_per_domain=teacher_topk_logits_per_domain, teacher_topk_indices_per_domain=teacher_topk_indices_per_domain, + teacher_topk_log_sum_exp_per_domain=teacher_topk_log_sum_exp_per_domain, loss_masks=batch["loss_masks"], sum_of_sample_mean=sum_of_sample_mean, current_log_probs=current_log_probs_list, diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index c6ea21ce6b..6c46284ea5 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -762,6 +762,27 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl ] train_data["mopd_teacher_log_probs"] = mopd_teacher_log_probs + # Add MOPD teacher top-k data (SGLang mode, top_k distill type) + # Format: {domain -> list[list[list[float]]]]} — domain -> per-sample [seq_len][k] + if any(s.mopd_teacher_topk_logits is not None for s in samples): + all_domains_topk = set() + for sample in samples: + if sample.mopd_teacher_topk_logits: + all_domains_topk.update(sample.mopd_teacher_topk_logits.keys()) + mopd_teacher_topk_logits = {} + mopd_teacher_topk_indices = {} + for domain in all_domains_topk: + mopd_teacher_topk_logits[domain] = [ + sample.mopd_teacher_topk_logits.get(domain) if sample.mopd_teacher_topk_logits else None + for sample in samples + ] + mopd_teacher_topk_indices[domain] = [ + sample.mopd_teacher_topk_indices.get(domain) if sample.mopd_teacher_topk_indices else None + for sample in samples + ] + train_data["mopd_teacher_topk_logits"] = mopd_teacher_topk_logits + train_data["mopd_teacher_topk_indices"] = mopd_teacher_topk_indices + return train_data def set_train_parallel_config(self, config: dict): @@ -812,6 +833,17 @@ def _split_train_data_by_dp(self, data, dp_size): for domain, lp_list in data["mopd_teacher_log_probs"].items(): mopd_lp_dict[domain] = [lp_list[j] for j in partition] rollout_data["mopd_teacher_log_probs"] = mopd_lp_dict + # Handle mopd_teacher_topk_logits/indices (dict: domain -> list, SGLang MOPD top_k mode) + if "mopd_teacher_topk_logits" in data: + mopd_topk_logits_dict = {} + for domain, v_list in data["mopd_teacher_topk_logits"].items(): + mopd_topk_logits_dict[domain] = [v_list[j] for j in partition] + rollout_data["mopd_teacher_topk_logits"] = mopd_topk_logits_dict + if "mopd_teacher_topk_indices" in data: + mopd_topk_indices_dict = {} + for domain, v_list in data["mopd_teacher_topk_indices"].items(): + mopd_topk_indices_dict[domain] = [v_list[j] for j in partition] + rollout_data["mopd_teacher_topk_indices"] = mopd_topk_indices_dict # keys that need to be splited at train side for key in [ "raw_reward", @@ -1266,12 +1298,26 @@ def _compute_zero_std_metrics(args, all_samples: list[Sample]): def _is_zero_std(samples: list[Sample]): rewards = [sample.get_reward_value(args) for sample in samples] - return len(rewards) == 0 or all(rewards[0] == r for r in rewards) + if len(rewards) == 0: + return True + # Only compare numeric rewards; skip groups with non-numeric rewards + # (e.g., MOPD SGLang mode where sample.reward is still a dict). + if not isinstance(rewards[0], (int, float)): + return False + return all(rewards[0] == r for r in rewards) all_sample_groups = group_by(all_samples, lambda s: s.group_index) interesting_sample_groups = [g for g in all_sample_groups.values() if _is_zero_std(g)] - interesting_rewards = [str(round(g[0].get_reward_value(args), 1)) for g in interesting_sample_groups] + # Guard against non-numeric reward values (e.g., MOPD SGLang mode where + # sample.reward may still be a dict before post-processing completes). + interesting_rewards = [] + for g in interesting_sample_groups: + rv = g[0].get_reward_value(args) + if isinstance(rv, (int, float)): + interesting_rewards.append(str(round(rv, 1))) + else: + interesting_rewards.append(str(type(rv).__name__)) return {f"zero_std/count_{reward}": len(items) for reward, items in group_by(interesting_rewards).items()} @@ -1302,6 +1348,18 @@ def _compute_reward_cat_metrics(args, all_samples: list[Sample]): if reward_cat_key is None: return {} - samples_of_reward_cat = group_by(all_samples, lambda s: s.reward[reward_cat_key]) + # Guard against non-dict rewards (e.g., float in MOPD pure distillation mode) + # or dict rewards that don't contain the key. + def _get_reward_cat(s): + if isinstance(s.reward, dict) and reward_cat_key in s.reward: + return s.reward[reward_cat_key] + return None + + samples_of_reward_cat = group_by(all_samples, _get_reward_cat) + # Filter out None category (samples where reward_cat_key is not available) + samples_of_reward_cat.pop(None, None) + + if not samples_of_reward_cat: + return {} return {f"error_cat/{reward_cat}": len(s) / len(all_samples) for reward_cat, s in samples_of_reward_cat.items()} diff --git a/slime/rollout/mopd.py b/slime/rollout/mopd.py index ac49d505cc..36ec803063 100644 --- a/slime/rollout/mopd.py +++ b/slime/rollout/mopd.py @@ -1,14 +1,27 @@ """Multi-Teacher On-Policy Distillation (MOPD) rollout support for SGLang. -This module provides reward_func and post_process_rewards for fetching log-probs -from multiple domain-specific teacher SGLang servers. Each teacher is identified -by a domain name and has its own rm_url. - -Usage: - --use-mopd - --mopd-teachers '[{"name": "math_teacher", "domain": "math"}, {"name": "code_teacher", "domain": "code"}]' - --custom-rm-path slime.rollout.mopd.reward_func - --custom-reward-post-process-path slime.rollout.mopd.post_process_rewards +This module provides reward_func and post_process_rewards for fetching teacher +data from multiple domain-specific SGLang teacher servers. Each teacher is +identified by a domain name and has its own rm_url. + +Supports three distillation modes (controlled by --mopd-distill-type): + - token_level: Extract per-token log-probs from SGLang's input_token_logprobs. + - top_k: Extract top-k log-probs and token indices per position using + SGLang's top_logprobs_num parameter. + - full_vocab: Not supported with SGLang teacher mode (requires Megatron + in-process teacher for full vocabulary logits). Raises a clear error + if full_vocab is requested with --mopd-teacher-mode=sglang. + +Usage (pure distillation, alpha=0): + --use-mopd --mopd-teacher-mode sglang + --mopd-teachers '[{"name": "math_teacher", "domain": "math"}]' + (custom-rm-path and custom-reward-post-process-path are auto-configured) + +Usage (with task rewards, alpha>0): + --use-mopd --mopd-teacher-mode sglang --mopd-alpha 0.5 + --mopd-teachers '[{"name": "math_teacher", "domain": "math"}]' + --rm-type math + (combined_reward_func and combined_post_process_rewards are auto-configured) The teacher rm_urls are configured via --mopd-teachers JSON, where each entry can contain an optional "rm_url" field. Alternatively, they can be specified @@ -29,6 +42,16 @@ logger = logging.getLogger(__name__) +def _get_all_domain_names(args) -> list[str]: + """Get all configured MOPD teacher domain names from args. + + Returns: + List of domain name strings (e.g., ['origin', 'enhanced']). + """ + configs = _get_mopd_teacher_configs(args) + return [c.get("domain", c.get("name", "")) for c in configs] + + def _get_mopd_teacher_configs(args) -> list[dict]: """Parse MOPD teacher configurations from args. @@ -42,8 +65,14 @@ def _get_mopd_teacher_configs(args) -> list[dict]: return teachers_str -def _build_payload(sample): - """Build the SGLang request payload for log-prob extraction.""" +def _build_payload(sample, args): + """Build the SGLang request payload for teacher data extraction. + + The payload differs based on --mopd-distill-type: + - token_level: return_logprob=True, no top_logprobs_num + - top_k: return_logprob=True, top_logprobs_num=mopd_topk_k + - full_vocab: raises ValueError (not supported with SGLang) + """ payload = { "input_ids": sample.tokens, "sampling_params": { @@ -55,6 +84,22 @@ def _build_payload(sample): "logprob_start_len": 0, } + # Determine distill type + mopd_distill_type = getattr(args, "mopd_distill_type", "token_level") + + if mopd_distill_type == "top_k": + topk_k = getattr(args, "mopd_topk_k", 1024) + payload["top_logprobs_num"] = topk_k + elif mopd_distill_type == "full_vocab": + raise ValueError( + "MOPD full_vocab mode is not supported with SGLang teacher mode. " + "SGLang cannot efficiently return full-vocabulary logits. " + "Use --mopd-teacher-mode=megatron for full_vocab, or switch to " + "--mopd-distill-type=top_k for an accurate approximation with " + "much lower memory usage." + ) + # token_level: no additional parameters needed + if sample.multimodal_inputs and sample.multimodal_inputs.get("images"): image_data = sample.multimodal_inputs["images"] payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data] @@ -62,11 +107,90 @@ def _build_payload(sample): return payload -async def _fetch_teacher_logprobs(session: aiohttp.ClientSession, rm_url: str, payload: dict) -> dict: - """Fetch log-probs from a single teacher SGLang server.""" - async with session.post(rm_url, json=payload) as resp: - resp.raise_for_status() - return await resp.json() +async def _fetch_teacher_logprobs( + session: aiohttp.ClientSession, + rm_url: str, + payload: dict, + max_retries: int = 3, + retry_delay: float = 5.0, +) -> dict: + """Fetch log-probs from a single teacher SGLang server with retry. + + Retries on transient network errors (connection issues, partial reads, + server errors) that are likely to succeed on a subsequent attempt. + Non-retryable errors (e.g., 4xx client errors) are raised immediately. + + Args: + session: The aiohttp client session. + rm_url: The teacher server URL. + payload: The request payload. + max_retries: Maximum number of retry attempts (default 3). + retry_delay: Base delay in seconds between retries (default 5.0). + Actual delay is ``retry_delay * (attempt + 1)`` with jitter. + + Returns: + The parsed JSON response from the teacher server. + + Raises: + The last exception if all retries are exhausted. + """ + import random + + last_exc = None + for attempt in range(max_retries): + try: + async with session.post(rm_url, json=payload) as resp: + # 4xx errors are client errors — retrying won't help. + if resp.status >= 400 and resp.status < 500: + resp.raise_for_status() + # 5xx or network-level errors are transient — retry. + resp.raise_for_status() + result = await resp.json() + + # Validate that the response contains logprob data. + # SGLang's return_logprob is a per-request parameter (not a + # server-side flag). If meta_info lacks input_token_logprobs, + # the most likely cause is that the URL points to the wrong + # SGLang instance (e.g. the student rollout server) or a + # gateway that strips request fields. + meta_info = result.get("meta_info", {}) + if not isinstance(meta_info, dict) or "input_token_logprobs" not in meta_info: + logger.error( + f"MOPD: SGLang teacher response from {rm_url} does NOT contain " + f"'input_token_logprobs' in meta_info. Check that the URL " + f"points to a SGLang server with return_logprob support. " + f"Response meta_info keys: {list(meta_info.keys()) if isinstance(meta_info, dict) else meta_info}. " + f"Request payload had return_logprob={payload.get('return_logprob')}, " + f"logprob_start_len={payload.get('logprob_start_len')}." + ) + else: + logger.info( + f"MOPD: SGLang teacher response from {rm_url} OK, " + f"input_token_logprobs count={len(meta_info['input_token_logprobs'])}" + ) + + return result + except (aiohttp.ClientPayloadError, aiohttp.ClientConnectionError, + aiohttp.ServerDisconnectedError, asyncio.TimeoutError, + aiohttp.ClientResponseError) as exc: + last_exc = exc + if attempt < max_retries - 1: + # 5xx server errors are retryable; ClientPayloadError (e.g. + # ContentLengthError) is typically caused by the server closing + # the connection mid-stream and is also retryable. + is_retryable = True + if isinstance(exc, aiohttp.ClientResponseError) and exc.status < 500: + is_retryable = False + if is_retryable: + delay = retry_delay * (attempt + 1) + random.uniform(0, 2) + logger.warning( + f"MOPD teacher request to {rm_url} failed (attempt {attempt + 1}/{max_retries}): " + f"{type(exc).__name__}: {exc}. Retrying in {delay:.1f}s..." + ) + await asyncio.sleep(delay) + continue + raise + raise last_exc # Should not reach here, but just in case def _resolve_teacher_urls(args, teacher_configs: list[dict]) -> dict[str, str]: @@ -88,6 +212,12 @@ def _resolve_teacher_urls(args, teacher_configs: list[dict]) -> dict[str, str]: rm_url = teacher_cfg.get("rm_url") or env_urls.get(domain) if rm_url is None: rm_url = args.rm_url + logger.warning( + f"MOPD: No explicit URL configured for teacher domain '{domain}', " + f"falling back to args.rm_url ({rm_url}). " + f"Set 'rm_url' in --mopd-teachers or the MOPD_TEACHER_URLS " + f"environment variable to override." + ) url_map[domain] = rm_url return url_map @@ -134,12 +264,23 @@ def _get_sample_domains(sample, all_domains: list[str]) -> list[str] | None: return valid_domains +# Default timeout for MOPD teacher HTTP requests (in seconds). +# Individual teacher requests for long sequences (especially multimodal) +# can take several minutes, so we set a generous timeout. +_MOPD_TEACHER_TIMEOUT = aiohttp.ClientTimeout(total=600, connect=30, sock_read=300) + + async def _reward_func_single(args, sample, **kwargs): """Query MOPD teacher servers for a single sample. If sample.metadata contains 'mopd_domains' (a list of domain names or a single string), only the specified teachers are queried. Otherwise, all teachers are queried. + Each teacher request is retried on transient errors (connection resets, + partial reads, server errors) before giving up. When a teacher is + permanently unreachable after retries, it is skipped with a warning and + its domain data will be missing from the result dict. + Returns: dict mapping domain -> raw teacher response (JSON from SGLang). This dict is stored in sample.reward and later processed by post_process_rewards. @@ -153,27 +294,41 @@ async def _reward_func_single(args, sample, **kwargs): if target_domains is not None: url_map = {d: url_map[d] for d in target_domains} - payload = _build_payload(sample) + payload = _build_payload(sample, args) + + # Read retry config from args (with sensible defaults) + max_retries = getattr(args, "mopd_teacher_max_retries", 3) + retry_delay = getattr(args, "mopd_teacher_retry_delay", 5.0) results = {} - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(timeout=_MOPD_TEACHER_TIMEOUT) as session: tasks = [] domains = [] for domain, rm_url in url_map.items(): domains.append(domain) - tasks.append(_fetch_teacher_logprobs(session, rm_url, payload)) + tasks.append( + _fetch_teacher_logprobs(session, rm_url, payload, + max_retries=max_retries, + retry_delay=retry_delay) + ) responses = await asyncio.gather(*tasks, return_exceptions=True) for domain, resp in zip(domains, responses): if isinstance(resp, Exception): logger.warning( - f"MOPD teacher '{domain}' failed: {resp}. Skipping this teacher." + f"MOPD teacher '{domain}' failed after retries: {resp}. " + f"Skipping this teacher." ) continue results[domain] = resp + # Record which domains were targeted so the extraction code can + # distinguish between "not queried" (domain routed away) and + # "queried but failed" (should fill with -inf fallback). + results["__target_domains__"] = list(domains) + return results @@ -200,47 +355,390 @@ async def reward_func(args, sample_or_samples, **kwargs): return await _reward_func_single(args, sample_or_samples, **kwargs) -def post_process_rewards(args, samples: list[Sample], **kwargs): - """Process MOPD teacher responses and extract per-domain teacher log-probs. +def _extract_teacher_data_from_responses(args, samples: list[Sample]): + """Extract per-domain teacher data from MOPD teacher responses stored in sample.reward. - This function: - 1. Extracts log-probs from each teacher server response - 2. Stores them in sample.mopd_teacher_log_probs[domain] - 3. Returns scalar rewards compatible with GRPO/PPO - - The raw_rewards for each sample is expected to be a dict mapping domain -> response, - as returned by mopd.reward_func. + This is the core extraction logic shared by post_process_rewards and + combined_post_process_rewards. It reads teacher responses from sample.reward + (which should be a dict mapping domain -> SGLang response JSON) and populates + sample.mopd_teacher_log_probs, sample.mopd_teacher_topk_logits, and + sample.mopd_teacher_topk_indices. """ raw_rewards = [sample.get_reward_value(args) for sample in samples] response_lengths = [sample.response_length for sample in samples] + mopd_distill_type = getattr(args, "mopd_distill_type", "token_level") + for sample, reward_val, response_length in zip(samples, raw_rewards, response_lengths, strict=False): if sample.mopd_teacher_log_probs is None: sample.mopd_teacher_log_probs = {} + if mopd_distill_type == "top_k": + if sample.mopd_teacher_topk_logits is None: + sample.mopd_teacher_topk_logits = {} + if sample.mopd_teacher_topk_indices is None: + sample.mopd_teacher_topk_indices = {} + if not isinstance(reward_val, dict): # If reward_func didn't return a dict (e.g., fallback case), skip continue for domain, teacher_response in reward_val.items(): + # Skip internal metadata keys + if domain.startswith("__") and domain.endswith("__"): + continue try: - # Extract log-probs from sglang response format + meta_info = teacher_response["meta_info"] + input_token_logprobs = meta_info["input_token_logprobs"] + + # --- token_level: always extract (needed even for top_k for IS weights) --- + # input_token_logprobs format: list of [log_prob, token_id, token_text] + # Skip the first entry (prompt token before any generation) + logprobs_from_response = input_token_logprobs[1:] + if len(logprobs_from_response) < response_length: + logger.warning( + f"MOPD: SGLang returned {len(logprobs_from_response)} logprob entries " + f"for domain '{domain}', but response_length={response_length}. " + f"Padding with -inf for missing positions." + ) log_probs = torch.tensor( - [item[0] for item in teacher_response["meta_info"]["input_token_logprobs"][1:]], + [item[0] for item in logprobs_from_response], dtype=torch.float32, ) - # Trim to response length + if log_probs.size(0) < response_length: + # Pad shorter log_probs with -inf so downstream code + # doesn't misalign position indices. + log_probs = torch.nn.functional.pad( + log_probs, (0, response_length - log_probs.size(0)), value=float("-inf") + ) + # Trim to response length (in case SGLang returns more) log_probs = log_probs[-response_length:] sample.mopd_teacher_log_probs[domain] = log_probs + + valid_mask = log_probs.isfinite() + if valid_mask.any(): + logger.info( + f"MOPD: Received teacher logprobs for domain '{domain}': " + f"len={log_probs.size(0)}, valid={valid_mask.sum().item()}, " + f"mean={log_probs[valid_mask].mean().item():.4f}" + ) + else: + logger.info( + f"MOPD: Received teacher logprobs for domain '{domain}': " + f"len={log_probs.size(0)}, all -inf" + ) + + # --- top_k: extract top-k log-probs and indices per position --- + if mopd_distill_type == "top_k": + # SGLang returns top-k logprobs in meta_info["input_top_logprobs"] + # Format: list (one entry per position) of list of (log_prob, token_id, token_text) + # tuples. Same length as input_token_logprobs. + input_top_logprobs = meta_info.get("input_top_logprobs") + if input_top_logprobs is None: + logger.warning( + f"MOPD top_k: SGLang response for domain '{domain}' does not contain " + f"'input_top_logprobs'. Make sure top_logprobs_num is set in the " + f"SGLang request payload. Falling back to token_level for this domain." + ) + continue + + # Skip first entry (same as input_token_logprobs), trim to response + top_logprobs_response = input_top_logprobs[1:] + if len(top_logprobs_response) < response_length: + logger.warning( + f"MOPD top_k: SGLang returned {len(top_logprobs_response)} " + f"top-logprobs entries for domain '{domain}', but " + f"response_length={response_length}. Missing positions " + f"will be padded with -inf logits and index 0." + ) + # Pad with None entries so the loop below generates + # [-inf, ..., -inf] / [0, ..., 0] for missing positions + top_logprobs_response = top_logprobs_response + [None] * (response_length - len(top_logprobs_response)) + if len(top_logprobs_response) > response_length: + top_logprobs_response = top_logprobs_response[-response_length:] + + # Each position: list of (log_prob, token_id, token_text) tuples + # (sorted by log_prob desc). token_text is None when + # return_text_in_logprobs is not set. + # Convert to: topk_logits[pos] = [log_prob_0, log_prob_1, ...] + # topk_indices[pos] = [token_id_0, token_id_1, ...] + # Padding entries use -inf logit so downstream TP sharding and + # valid_topk_mask can correctly identify them as invalid. + topk_k = getattr(args, "mopd_topk_k", 1024) + NEG_INF = float("-inf") + + topk_logits_list = [] # [seq_len][k] float + topk_indices_list = [] # [seq_len][k] int + short_positions = 0 # Count positions with fewer than topk_k entries + + for pos_data in top_logprobs_response: + if pos_data is None or len(pos_data) == 0: + # No top-k data for this position (e.g., padding) + topk_logits_list.append([NEG_INF] * topk_k) + topk_indices_list.append([0] * topk_k) + short_positions += 1 + continue + + # pos_data: list of (log_prob, token_id, token_text) tuples + pos_logits = [] + pos_indices = [] + for entry in pos_data[:topk_k]: + # entry: (log_prob, token_id, token_text) + pos_logits.append(float(entry[0])) + pos_indices.append(int(entry[1])) + + # Pad to topk_k if fewer entries returned + # Use -inf logit for padding so downstream valid_topk_mask + # detection (checking for -inf entries) works correctly. + if len(pos_logits) < topk_k: + short_positions += 1 + while len(pos_logits) < topk_k: + pos_logits.append(NEG_INF) + pos_indices.append(0) + + topk_logits_list.append(pos_logits) + topk_indices_list.append(pos_indices) + + if short_positions > 0: + logger.warning( + f"MOPD top_k: {short_positions}/{len(top_logprobs_response)} " + f"positions in domain '{domain}' returned fewer than {topk_k} " + f"top-k entries from SGLang. Padded with -inf logits. " + f"Consider reducing --mopd-topk-k or checking SGLang's " + f"top_logprobs_num setting." + ) + + sample.mopd_teacher_topk_logits[domain] = topk_logits_list + sample.mopd_teacher_topk_indices[domain] = topk_indices_list + except (KeyError, IndexError, TypeError) as e: + # Provide an actionable message for the most common cause: + # SGLang server not returning logprobs. + if isinstance(e, KeyError) and str(e) in ("'input_token_logprobs'", "input_token_logprobs"): + meta_keys = list(teacher_response.get("meta_info", {}).keys()) if isinstance(teacher_response.get("meta_info"), dict) else "N/A" + logger.error( + f"MOPD: SGLang response for domain '{domain}' missing " + f"'input_token_logprobs'. meta_info keys: {meta_keys}. " + f"Check teacher URL configuration." + ) + else: + logger.warning( + f"MOPD: Failed to extract teacher data for domain '{domain}': {e}" + ) + + # --- Fill in missing domains with zero/fallback data --- + # When a teacher request fails (e.g., ContentLengthError, connection reset), + # the domain is absent from reward_val and thus from the sample's dicts. + # Previously this would produce None placeholders in the training data + # pipeline, which could cause NCCL deadlocks because different DP ranks + # may end up with different computational graphs. + # + # IMPORTANT: We only fill fallback data for domains that were *actually + # queried* for this sample (i.e., in target_domains). Domains that were + # excluded by per-sample domain routing (sample.metadata["mopd_domains"]) + # should NOT be filled — they intentionally don't participate in this + # sample's loss computation, and filling them would incorrectly produce + # -inf fallback tensors that contribute zero KL but still occupy memory + # and trigger unnecessary backward operations. + all_configured_domains = _get_all_domain_names(args) + for sample, reward_val, response_length in zip(samples, raw_rewards, response_lengths, strict=False): + if sample.mopd_teacher_log_probs is None: + sample.mopd_teacher_log_probs = {} + + # Determine which domains were targeted for this sample. + # If __target_domains__ is present in reward_val (set by + # _reward_func_single), use it. Otherwise, fall back to all + # configured domains (backward compatible). + if isinstance(reward_val, dict) and "__target_domains__" in reward_val: + target_domains = reward_val["__target_domains__"] + else: + target_domains = all_configured_domains + + for domain in target_domains: + if domain not in sample.mopd_teacher_log_probs: logger.warning( - f"MOPD: Failed to extract log-probs for domain '{domain}': {e}" + f"MOPD: Teacher data for domain '{domain}' is missing for a sample " + f"(was queried but extraction failed or request failed). " + f"Filling with -inf log-probs (zero KL contribution)." + ) + sample.mopd_teacher_log_probs[domain] = torch.full( + (response_length,), float('-inf'), dtype=torch.float32 ) + if mopd_distill_type == "top_k": + if sample.mopd_teacher_topk_logits is None: + sample.mopd_teacher_topk_logits = {} + if sample.mopd_teacher_topk_indices is None: + sample.mopd_teacher_topk_indices = {} + if domain not in sample.mopd_teacher_topk_logits: + topk_k = getattr(args, "mopd_topk_k", 1024) + NEG_INF = float("-inf") + sample.mopd_teacher_topk_logits[domain] = [ + [NEG_INF] * topk_k for _ in range(response_length) + ] + sample.mopd_teacher_topk_indices[domain] = [ + [0] * topk_k for _ in range(response_length) + ] + + +def post_process_rewards(args, samples: list[Sample], **kwargs): + """Process MOPD teacher responses and extract per-domain teacher data. + + This is the standalone post_process_rewards for pure MOPD distillation (alpha=0) + where no task reward is needed. It reads teacher responses from sample.reward + (which should be a dict mapping domain -> SGLang response JSON as returned by + reward_func) and populates sample.mopd_teacher_* fields. + + For combined MOPD + task rewards (alpha>0), use combined_post_process_rewards instead. + + Returns: + Tuple of (scalar_rewards, scalar_rewards) for GRPO/PPO compatibility. + All rewards are 0.0 since the learning signal comes from distillation. + """ + _extract_teacher_data_from_responses(args, samples) + + # Reset sample.reward to scalar 0.0 — the SGLang response dict is no longer + # needed and leaving it as a dict causes downstream code (e.g., metrics + # logging, reward aggregation) to break since they expect numeric values. + for sample in samples: + sample.reward = 0.0 - # Return scalar rewards for GRPO/PPO advantage estimator - # For pure MOPD distillation, we use 0.0 as the task reward. - # The learning signal comes from the MOPD advantage applied in compute_advantages_and_returns. - # If you have task rewards, configure them separately via reward model. scalar_rewards = [0.0] * len(samples) + return scalar_rewards, scalar_rewards + + +# --------------------------------------------------------------------------- +# Combined reward functions (MOPD teacher data + task rewards) +# --------------------------------------------------------------------------- +# When --mopd-alpha > 0, MOPD combines distillation advantages with task +# rewards. This requires both collecting teacher data from SGLang AND getting +# task rewards from the standard reward model. Since custom_rm_path replaces +# the standard reward model, we provide combined wrappers that invoke both. +# +# The key idea: +# - combined_reward_func: fetches teacher data from SGLang, stores it in +# sample.metadata["_mopd_teacher_responses"], then calls the standard +# reward model (rm_type-based) to get task rewards. +# - combined_post_process_rewards: extracts teacher log-probs from +# sample.metadata["_mopd_teacher_responses"], then applies standard +# reward post-processing (GRPO normalization, etc.) to the task rewards +# stored in sample.reward. +# --------------------------------------------------------------------------- + +_MOPD_TEACHER_RESPONSES_KEY = "_mopd_teacher_responses" + + +async def combined_reward_func(args, sample_or_samples, **kwargs): + """Combined reward function: MOPD teacher data collection + task rewards. - return scalar_rewards, scalar_rewards \ No newline at end of file + This function: + 1. Fetches MOPD teacher data from SGLang servers (via reward_func). + 2. Stores the teacher responses in sample.metadata for later extraction. + 3. Calls the standard reward model (rm_type-based) to get task rewards. + + Returns the task reward (float) for single sample, or list of task rewards + for batch mode. The MOPD teacher data is stored in sample metadata. + + NOTE: This function temporarily sets args.custom_rm_path to None to bypass + the custom RM and call the standard rm_type-based reward model. This is + safe because reward evaluation is sequential per rollout batch within a + single worker process. However, if concurrent reward evaluation is ever + introduced, this would need to be refactored to avoid data races. + """ + from slime.rollout.rm_hub import async_rm, batched_async_rm + + # Step 1: Fetch MOPD teacher data + if isinstance(sample_or_samples, list): + mopd_results = await reward_func(args, sample_or_samples, **kwargs) + + # Store teacher responses in metadata, then get task rewards + # Temporarily save custom_rm_path so we can bypass it for task rewards + original_custom_rm_path = args.custom_rm_path + args.custom_rm_path = None + try: + task_rewards = await batched_async_rm(args, sample_or_samples, **kwargs) + finally: + args.custom_rm_path = original_custom_rm_path + + # Store MOPD teacher responses in sample metadata + for sample, mopd_result in zip(sample_or_samples, mopd_results): + if isinstance(sample.metadata, dict): + sample.metadata[_MOPD_TEACHER_RESPONSES_KEY] = mopd_result + else: + sample.metadata = {_MOPD_TEACHER_RESPONSES_KEY: mopd_result} + + return task_rewards + else: + sample = sample_or_samples + mopd_result = await reward_func(args, sample, **kwargs) + + # Store teacher response in metadata + if isinstance(sample.metadata, dict): + sample.metadata[_MOPD_TEACHER_RESPONSES_KEY] = mopd_result + else: + sample.metadata = {_MOPD_TEACHER_RESPONSES_KEY: mopd_result} + + # Get task reward (bypass custom_rm_path to use rm_type) + original_custom_rm_path = args.custom_rm_path + args.custom_rm_path = None + try: + task_reward = await async_rm(args, sample, **kwargs) + finally: + args.custom_rm_path = original_custom_rm_path + + return task_reward + + +def combined_post_process_rewards(args, samples: list[Sample], **kwargs): + """Combined post-processing: extract MOPD teacher data + standard reward normalization. + + This function: + 1. Extracts MOPD teacher log-probs from sample.metadata["_mopd_teacher_responses"] + (stored by combined_reward_func), populates sample.mopd_teacher_* fields. + 2. Applies standard reward post-processing (GRPO normalization, etc.) to the + task rewards stored in sample.reward. + + Returns: + Tuple of (raw_rewards, processed_rewards) for GRPO/PPO compatibility. + """ + # Step 1: Extract MOPD teacher data from metadata + # Temporarily swap sample.reward to contain the teacher responses so that + # _extract_teacher_data_from_responses can read them via get_reward_value. + # Save original task rewards first. + original_rewards = [] + for sample in samples: + original_rewards.append(sample.reward) + teacher_responses = None + if isinstance(sample.metadata, dict): + teacher_responses = sample.metadata.get(_MOPD_TEACHER_RESPONSES_KEY) + sample.reward = teacher_responses # Temporarily set for extraction + + # Extract teacher data (populates sample.mopd_teacher_log_probs, etc.) + _extract_teacher_data_from_responses(args, samples) + + # Clean up temporary metadata and restore task rewards + for sample, original_reward in zip(samples, original_rewards): + if isinstance(sample.metadata, dict): + sample.metadata.pop(_MOPD_TEACHER_RESPONSES_KEY, None) + sample.reward = original_reward + + # Step 2: Apply standard reward post-processing + raw_rewards = [sample.get_reward_value(args) for sample in samples] + if ( + args.advantage_estimator in ["grpo", "gspo", "reinforce_plus_plus_baseline"] + and args.rewards_normalization + ): + rewards = torch.tensor(raw_rewards, dtype=torch.float) + if rewards.shape[-1] == args.n_samples_per_prompt * args.rollout_batch_size: + rewards = rewards.reshape(-1, args.n_samples_per_prompt) + else: + rewards = rewards.view(-1, rewards.shape[-1]) + mean = rewards.mean(dim=-1, keepdim=True) + rewards = rewards - mean + + if args.advantage_estimator in ["grpo", "gspo"] and args.grpo_std_normalization: + std = rewards.std(dim=-1, keepdim=True) + rewards = rewards / (std + 1e-6) + + return raw_rewards, rewards.flatten().tolist() + + return raw_rewards, raw_rewards \ No newline at end of file diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py index 44c03858ff..38c7b72ad9 100644 --- a/slime/rollout/sglang_rollout.py +++ b/slime/rollout/sglang_rollout.py @@ -432,8 +432,13 @@ async def generate_rollout_async( if do_print: sample = group[0][0] if isinstance(group[0], list) else group[0] + # Truncate reward repr to avoid flooding logs when reward contains + # large dicts (e.g., MOPD SGLang teacher responses with top-k logprobs). + reward_repr = repr(sample.reward) + if len(reward_repr) > 200: + reward_repr = reward_repr[:200] + "..." logger.info( - f"First rollout sample: {[str(sample.prompt) + sample.response]}, label: {str(sample.label)[:100]}, reward: {sample.reward}", + f"First rollout sample: {[str(sample.prompt) + sample.response]}, label: {str(sample.label)[:100]}, reward: {reward_repr}", ) do_print = False @@ -453,8 +458,13 @@ async def generate_rollout_async( pbar.close() sample = data[-1][0][0] if isinstance(data[-1][0], list) else data[-1][0] + # Truncate reward repr to avoid flooding logs when reward contains + # large dicts (e.g., MOPD SGLang teacher responses with top-k logprobs). + reward_repr = repr(sample.reward) + if len(reward_repr) > 200: + reward_repr = reward_repr[:200] + "..." logger.info( - f"Finish rollout: {[str(sample.prompt) + sample.response]}, label: {str(sample.label)[:100]}, reward: {sample.reward}", + f"Finish rollout: {[str(sample.prompt) + sample.response]}, label: {str(sample.label)[:100]}, reward: {reward_repr}", ) # there are still some unfinished requests, abort them @@ -574,10 +584,15 @@ async def eval_rollout_single_dataset( sample = await coro if do_print: logged_sample = sample[0] if isinstance(sample, list) else sample + # Truncate reward repr to avoid flooding logs when reward contains + # large dicts (e.g., MOPD SGLang teacher responses with top-k logprobs). + reward_repr = repr(logged_sample.reward) + if len(reward_repr) > 200: + reward_repr = reward_repr[:200] + "..." logger.info( "eval_rollout_single_dataset example data: " f"{[str(logged_sample.prompt) + logged_sample.response]} " - f"reward={logged_sample.reward}" + f"reward={reward_repr}" ) do_print = False if isinstance(sample, list): diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 3391b60fb8..13f603543d 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -1106,16 +1106,17 @@ def add_on_policy_distillation_arguments(parser): help=( "MOPD distillation type. " "'token_level' (default): use the sampled token log-prob difference as a reverse KL approximation, " - "applied at the advantage level. " + "applied at the advantage level. Works with both megatron and sglang teacher modes. " "'full_vocab': compute the exact full-vocabulary reverse KL divergence D_KL(π_θ ∥ π_d) " - "using complete logits from both student and teacher models. This is only supported with " - "megatron teacher mode (--mopd-teacher-loads), as it requires access to teacher logits " - "during training. The full-vocab KL loss is computed directly in the loss function rather " - "than through advantage modification. " + "using complete logits from both student and teacher models. Only supported with " + "megatron teacher mode (--mopd-teacher-mode=megatron with --mopd-teacher-loads), " + "as it requires access to teacher logits during training. " "'top_k': compute an approximate reverse KL divergence using the top-k teacher logits " "plus tail probability correction. Stores only [R, k] logits+indices per sample " "(k controlled by --mopd-topk-k, default 1024), greatly reducing memory compared to " - "full_vocab while being more accurate than token_level. Requires --mopd-teacher-loads." + "full_vocab while being more accurate than token_level. Works with both megatron and " + "sglang teacher modes. For sglang mode, top-k data is collected from SGLang servers " + "via the top_logprobs_num API parameter." ), ) parser.add_argument( @@ -1128,6 +1129,46 @@ def add_on_policy_distillation_arguments(parser): "approximation at the cost of more memory. Default: 1024." ), ) + parser.add_argument( + "--mopd-teacher-mode", + type=str, + choices=["megatron", "sglang"], + default="megatron", + help=( + "How to run MOPD teacher models. " + "'megatron' (default): load teacher checkpoints into the training process " + "via Megatron and compute teacher logits/log-probs during training. " + "Requires --mopd-teacher-loads. " + "'sglang': use remote SGLang inference servers as teachers. Teacher log-probs " + "(and optionally logits/top-k data) are collected during rollout via HTTP, " + "eliminating the need to load teacher weights into the training process. " + "This dramatically reduces CPU host memory usage by avoiding pin_memory backups " + "of teacher model weights. Must also set --custom-rm-path to " + "slime.rollout.mopd.reward_func and --custom-reward-post-process-path to " + "slime.rollout.mopd.post_process_rewards (or use the defaults when --use-mopd)." + ), + ) + parser.add_argument( + "--mopd-teacher-max-retries", + type=int, + default=3, + help=( + "Maximum number of retry attempts for each MOPD teacher HTTP request " + "in SGLang mode. Retries are applied on transient errors such as " + "connection resets, partial reads (ContentLengthError), and server " + "errors (5xx). Client errors (4xx) are not retried. Default: 3." + ), + ) + parser.add_argument( + "--mopd-teacher-retry-delay", + type=float, + default=5.0, + help=( + "Base delay in seconds between MOPD teacher request retries. " + "The actual delay is retry_delay * (attempt_number + 1) with " + "random jitter. Default: 5.0." + ), + ) return parser def add_router_arguments(parser): @@ -1832,19 +1873,99 @@ def slime_validate_args(args): f"--mopd-eps-high ({args.mopd_eps_high}) must be > --mopd-eps-low ({args.mopd_eps_low})." ) - # Validate mopd_distill_type: full_vocab and top_k modes require megatron teachers - if args.mopd_distill_type in ("full_vocab", "top_k"): - if args.mopd_teacher_loads is None: - raise ValueError( - f"--mopd-distill-type={args.mopd_distill_type} requires --mopd-teacher-loads (megatron teacher mode). " - "SGLang-based teachers cannot return full-vocabulary logits efficiently. " - "Please provide teacher checkpoints via --mopd-teacher-loads." - ) + # Set default teacher mode based on whether mopd_teacher_loads is provided + if not hasattr(args, "mopd_teacher_mode") or args.mopd_teacher_mode is None: + args.mopd_teacher_mode = "sglang" if args.mopd_teacher_loads is None else "megatron" + + # Validate mopd_distill_type compatibility with teacher mode + if args.mopd_teacher_mode == "megatron": + if args.mopd_distill_type in ("full_vocab", "top_k"): + if args.mopd_teacher_loads is None: + raise ValueError( + f"--mopd-distill-type={args.mopd_distill_type} with --mopd-teacher-mode=megatron " + "requires --mopd-teacher-loads. Please provide teacher checkpoints via --mopd-teacher-loads." + ) + + # SGLang mode does not support full_vocab (cannot return full vocab logits efficiently) + if args.mopd_teacher_mode == "sglang" and args.mopd_distill_type == "full_vocab": + raise ValueError( + "--mopd-distill-type=full_vocab is not supported with --mopd-teacher-mode=sglang. " + "SGLang remote inference servers cannot efficiently return full-vocabulary logits. " + "Use --mopd-teacher-mode=megatron for full_vocab, or switch to " + "--mopd-distill-type=top_k for an accurate approximation with much lower " + "memory and network usage." + ) # Validate mopd_topk_k if args.mopd_distill_type == "top_k" and args.mopd_topk_k <= 0: raise ValueError(f"--mopd-topk-k must be > 0, got {args.mopd_topk_k}.") + # SGLang mode: validate that custom_rm_path / reward_func is properly configured + if args.mopd_teacher_mode == "sglang": + # SGLang mode does not load Megatron teacher weights, so weights_backuper is not needed + # for teacher backups. If no other feature requires it, we can skip it. + if args.mopd_teacher_loads is not None: + raise ValueError( + "--mopd-teacher-mode=sglang is incompatible with --mopd-teacher-loads. " + "SGLang mode uses remote inference servers as teachers and does not load " + "Megatron teacher checkpoints. Remove --mopd-teacher-loads or use " + "--mopd-teacher-mode=megatron." + ) + + # Auto-configure reward function paths if not explicitly set. + # Two scenarios: + # 1. Pure distillation (alpha=0, no rm_type): Use standalone MOPD reward functions. + # 2. Combined (alpha>0 or rm_type set): Use combined wrappers that collect MOPD + # teacher data AND task rewards from the standard reward model. + has_task_reward = args.mopd_alpha > 0 + + if args.custom_rm_path is None and args.custom_reward_post_process_path is None: + # No explicit reward config — auto-configure based on whether task rewards are needed + if has_task_reward: + combined_rm_path = "slime.rollout.mopd.combined_reward_func" + combined_pp_path = "slime.rollout.mopd.combined_post_process_rewards" + args.custom_rm_path = combined_rm_path + args.custom_reward_post_process_path = combined_pp_path + logger.info( + f"MOPD SGLang mode with task rewards: auto-setting --custom-rm-path to " + f"'{combined_rm_path}' and --custom-reward-post-process-path to " + f"'{combined_pp_path}'. These combined functions collect MOPD teacher data " + f"from SGLang AND task rewards from --rm-type={args.rm_type}." + ) + else: + standalone_rm_path = "slime.rollout.mopd.reward_func" + standalone_pp_path = "slime.rollout.mopd.post_process_rewards" + args.custom_rm_path = standalone_rm_path + args.custom_reward_post_process_path = standalone_pp_path + logger.info( + f"MOPD SGLang mode (pure distillation): auto-setting --custom-rm-path to " + f"'{standalone_rm_path}' and --custom-reward-post-process-path to " + f"'{standalone_pp_path}'." + ) + elif args.custom_rm_path is not None and args.custom_reward_post_process_path is None: + # Only custom_rm_path is set — set a matching post_process + if "slime.rollout.mopd.combined_reward_func" in args.custom_rm_path: + args.custom_reward_post_process_path = "slime.rollout.mopd.combined_post_process_rewards" + elif "slime.rollout.mopd.reward_func" in args.custom_rm_path: + args.custom_reward_post_process_path = "slime.rollout.mopd.post_process_rewards" + else: + # User has a custom reward function — they need to handle MOPD extraction themselves + logger.warning( + "MOPD SGLang mode: --custom-rm-path is set to a non-MOPD function. " + "You must ensure teacher data extraction is handled in your " + "--custom-reward-post-process-path function." + ) + elif args.custom_rm_path is None and args.custom_reward_post_process_path is not None: + # Only post_process is set — auto-configure reward function + if has_task_reward: + args.custom_rm_path = "slime.rollout.mopd.combined_reward_func" + else: + args.custom_rm_path = "slime.rollout.mopd.reward_func" + logger.info( + f"MOPD SGLang mode: auto-setting --custom-rm-path to '{args.custom_rm_path}' " + f"(--custom-reward-post-process-path already set)." + ) + # MOPD with megatron-based teachers requires weights_backuper (to backup multiple models) if args.mopd_teacher_loads is not None and not args.enable_weights_backuper: raise ValueError( @@ -1857,12 +1978,25 @@ def slime_validate_args(args): # so a reward model is required. # When mopd_alpha == 0, pure distillation doesn't need task rewards; if no rm_type # or custom_rm_path is set, default to "zero" reward. - if args.mopd_alpha > 0 and args.rm_type is None and args.custom_rm_path is None: + # Note: After SGLang auto-config, custom_rm_path may be set to a MOPD function, + # so we check the combination of mopd_alpha and whether there's a real task reward source. + _mopd_uses_combined_rm = ( + args.custom_rm_path is not None + and "slime.rollout.mopd.combined_reward_func" in str(args.custom_rm_path) + ) + if args.mopd_alpha > 0 and args.rm_type is None and not _mopd_uses_combined_rm and args.custom_rm_path is None: raise ValueError( "--mopd-alpha > 0 requires a reward model (--rm-type or --custom-rm-path) " "because ORM advantages are combined with distillation advantages. " "Either set --rm-type, --custom-rm-path, or use --mopd-alpha 0 for pure distillation." ) + if _mopd_uses_combined_rm and args.rm_type is None: + raise ValueError( + "MOPD combined reward mode (--mopd-alpha > 0 with SGLang teacher mode) requires " + "--rm-type to be set for task rewards. The combined reward function collects both " + "MOPD teacher data and task rewards. Without --rm-type, there is no task reward source. " + "Either set --rm-type, or use --mopd-alpha 0 for pure distillation." + ) if args.mopd_alpha == 0 and args.rm_type is None and args.custom_rm_path is None: logger.info( "MOPD with alpha=0 (pure distillation): no --rm-type or --custom-rm-path set, " diff --git a/slime/utils/ppo_utils.py b/slime/utils/ppo_utils.py index 6fc1bef005..8a1abca2da 100644 --- a/slime/utils/ppo_utils.py +++ b/slime/utils/ppo_utils.py @@ -316,6 +316,9 @@ def vocab_parallel_topk_reverse_kl( teacher_topk_indices: torch.Tensor, vocab_size: int, process_group: dist.ProcessGroup, + valid_topk_mask: torch.Tensor | None = None, + is_log_probs: bool = False, + teacher_log_sum_exp: torch.Tensor | None = None, ) -> torch.Tensor: """Compute approximate D_KL(π_student ∥ π_teacher) using top-k teacher logits plus tail correction. @@ -335,9 +338,29 @@ def vocab_parallel_topk_reverse_kl( Args: student_logits: [R, V_local] student logits (with grad), vocab-sharded across TP. teacher_topk_logits: [R, k] teacher top-k logits (detached), fp32. + When is_log_probs=False (Megatron mode): raw logits from teacher forward pass. + When is_log_probs=True (SGLang mode): log probabilities (log softmax) from + SGLang's input_top_logprobs. The function will skip the log_softmax step + and use them directly as teacher log-probs. teacher_topk_indices: [R, k] teacher top-k LOCAL indices within each TP shard, int. vocab_size: Full (unsharded) vocabulary size V. process_group: TP process group for all-reduce. + valid_topk_mask: Optional [R, k] boolean mask. True = valid entry, False = padding. + When provided (e.g. for SGLang-sourced top-k with TP>1 where different shards + have different numbers of real entries), padding entries are zeroed out so they + don't contribute to the KL. When None, all k entries are assumed valid (Megatron + mode where each shard independently computes its top-k). + is_log_probs: If True, teacher_topk_logits contains log probabilities (already + softmax-normalized) rather than raw logits. This is the case for SGLang mode + where the server returns log-probs directly. When True, the function skips + the log_softmax computation and uses the values as-is for teacher log-probs. + teacher_log_sum_exp: Optional [R] tensor with the teacher's full-vocabulary + log_sum_exp per token position (computed from complete logits during the + Megatron teacher forward pass). When provided along with is_log_probs=False + (Megatron mode), enables exact teacher tail mass computation: + teacher_tail_mass = 1 - sum(exp(topk_logits - log_sum_exp)). + This replaces the inaccurate uniform assumption (V - V_eff) / V that + can over-estimate tail mass by orders of magnitude when k << V. Returns: Per-token KL divergence tensor of shape [R]. @@ -354,6 +377,19 @@ def vocab_parallel_topk_reverse_kl( tp_size = dist.get_world_size(group=process_group) if process_group is not None else 1 k = teacher_topk_logits.size(-1) + # Compute validity mask from teacher_topk_logits if not provided. + # Entries with -inf logits are padding (e.g., from SGLang TP sharding). + if valid_topk_mask is None: + # Auto-detect: any entry that is not -inf is valid. + # This is backward-compatible with Megatron mode where all entries are valid. + valid_topk_mask = ~torch.isinf(teacher_topk_logits) + + # Zero out padding entries in teacher_topk_logits to prevent NaN in exp() + # This replaces -inf with a large negative value that won't affect the max but + # will become 0 after exp. The valid_topk_mask handles the rest. + teacher_topk_logits_safe = teacher_topk_logits.clone() + teacher_topk_logits_safe[~valid_topk_mask] = -1e9 # large negative, not -inf + # --- student softmax (numerically stable, TP-aware) --- s_max = student_logits.max(dim=-1, keepdim=True).values if tp_size > 1: @@ -372,75 +408,126 @@ def vocab_parallel_topk_reverse_kl( student_topk_shifted = s_shifted.gather(-1, teacher_topk_indices) # [R, k] student_topk_log_probs = student_topk_shifted - s_log_sum_exp # [R, k] - # --- teacher log-softmax for top-k logits (numerically stable, TP-aware) --- - t_max = teacher_topk_logits.max(dim=-1, keepdim=True).values - if tp_size > 1: - # teacher_topk_logits are per-shard, so we need global max across shards - # BUT: each shard's top-k is independent (local indices). - # To compute the correct global log_sum_exp, we need: - # (a) the global max of ALL top-k logits across shards, and - # (b) the sum of exp(logits - global_max) across all shards. - dist.all_reduce(t_max, op=dist.ReduceOp.MAX, group=process_group) - t_shifted = teacher_topk_logits - t_max - t_exp = t_shifted.exp() - t_sum_exp = t_exp.sum(dim=-1, keepdim=True) - if tp_size > 1: - # Sum of exp across all shards (each shard contributes its k top-k values) - dist.all_reduce(t_sum_exp, op=dist.ReduceOp.SUM, group=process_group) - t_log_sum_exp = t_sum_exp.log() # [R, 1] - - # Teacher probs on top-k: exp(topk_logits) / Z_teacher - # But Z_teacher is NOT just the sum over top-k tokens. - # We need an approximation: Z_teacher ≈ sum_topk(exp) + (V_eff - k) * exp(tail_max). - # However, we don't have the exact partition function for teacher. - # Instead, we compute teacher_topk_log_probs using a CLOSED-over-topk partition function - # (treating top-k as if they were the entire vocab), then apply tail correction. - # - # Simple approach: compute teacher log-probs normalizing over top-k only, - # then estimate the tail correction analytically. - teacher_topk_log_probs_approx = t_shifted - t_log_sum_exp # [R, k] - teacher_topk_probs = (t_shifted.exp() / t_sum_exp) # [R, k], probs normalizing over top-k + # Zero out student contributions at padding positions + student_topk_probs = student_topk_probs * valid_topk_mask.float() + # student_topk_log_probs: we only use this in KL_topk where it's multiplied by + # student_topk_probs (which is zero at padding). So no separate masking needed. + + # --- teacher distribution from top-k entries --- + if is_log_probs: + # SGLang mode: teacher_topk_logits already contains log probabilities. + # No need to compute log_softmax — use them directly. + # teacher_topk_log_probs_approx = teacher_topk_logits (with padding zeroed out) + # teacher_topk_probs = exp(teacher_topk_logits) (only for tail mass computation) + teacher_topk_log_probs_approx = teacher_topk_logits_safe * valid_topk_mask.float() + teacher_topk_probs = teacher_topk_log_probs_approx.exp() * valid_topk_mask.float() + else: + # Megatron mode: teacher_topk_logits contains raw logits. Apply log_softmax + # over the top-k entries to get teacher log-probs (TP-aware). + t_max = teacher_topk_logits_safe.max(dim=-1, keepdim=True).values + if tp_size > 1: + # teacher_topk_logits are per-shard, so we need global max across shards + # BUT: each shard's top-k is independent (local indices). + # To compute the correct global log_sum_exp, we need: + # (a) the global max of ALL top-k logits across shards, and + # (b) the sum of exp(logits - global_max) across all shards. + dist.all_reduce(t_max, op=dist.ReduceOp.MAX, group=process_group) + t_shifted = teacher_topk_logits_safe - t_max + t_exp = t_shifted.exp() + # Zero out padding contributions in exp sum + t_exp = t_exp * valid_topk_mask.float() + t_sum_exp = t_exp.sum(dim=-1, keepdim=True) + if tp_size > 1: + # Sum of exp across all shards (each shard contributes its valid top-k values) + dist.all_reduce(t_sum_exp, op=dist.ReduceOp.SUM, group=process_group) + t_log_sum_exp = t_sum_exp.log() # [R, 1] + + # Compute teacher log-probs from the safe (non-inf) logits + teacher_topk_log_probs_approx = t_shifted - t_log_sum_exp # [R, k] + # Zero out padding entries + teacher_topk_log_probs_approx = teacher_topk_log_probs_approx * valid_topk_mask.float() + teacher_topk_probs = (t_shifted.exp() * valid_topk_mask.float()) / t_sum_exp # [R, k] - # --- tail mass --- - # Student tail mass: 1 - sum(π_s(y) for y in top-k of this shard) + # --- tail mass (TP-aware) --- + # Student tail mass: 1 - sum(π_s(y) for y in valid top-k of this shard) student_topk_mass = student_topk_probs.sum(dim=-1) # [R] if tp_size > 1: # Sum the top-k mass across all TP shards to get the total mass in all shards' top-k dist.all_reduce(student_topk_mass, op=dist.ReduceOp.SUM, group=process_group) student_tail_mass = (1.0 - student_topk_mass).clamp(min=0.0) # [R] - # Teacher tail mass: 1 - sum(π_t_topk(y) for y in top-k) - # Here π_t_topk is the teacher prob normalizing over the FULL vocab. - # We need the global partition function Z_t for teacher. - # We approximate by computing the partition function as: - # Z_t = sum of all exp(teacher_logits - global_max). - # But we only have top-k logits per shard. We approximate the tail as uniform. - # - # For each shard, we know its top-k contributes t_sum_exp_local = sum(exp(t_shifted)). - # The total Z_t ≈ (tp_size * k / vocab_size) * avg_exp would be wrong. - # - # Better: we already computed t_sum_exp across all shards (the sum of all top-k exp values). - # The full Z_t over the COMPLETE vocab is NOT available (we discarded non-top-k logits). + # Teacher tail mass: 1 - sum(π_t(y) for y in top-k) + # Computed from the actual teacher probability mass in the top-k partition, + # NOT from the uniform assumption (V - V_eff) / V which is wildly inaccurate + # when k << V (e.g., k=128, V=152064 → uniform tail ≈ 0.999 → catastrophic + # rescaling of ~-7 nats). + if is_log_probs: + # SGLang mode: teacher_topk_probs already reflects the true probability mass + # because log_probs came from a full softmax over the entire vocabulary. + # Sum exp(log_prob) across all valid top-k entries to get the actual mass. + teacher_topk_mass = teacher_topk_probs.sum(dim=-1) # [R] + if tp_size > 1: + # Sum across TP shards to get total mass from all shards' top-k entries + dist.all_reduce(teacher_topk_mass, op=dist.ReduceOp.SUM, group=process_group) + teacher_tail_mass = (1.0 - teacher_topk_mass).clamp(min=0.0) # [R] + else: + # Megatron mode: teacher_topk_logits are raw logits and teacher_topk_probs + # are from softmax over top-k entries only (not the full vocabulary). + # The sum is ~1 within the top-k partition, so we need an external + # reference to compute the true tail mass. + if teacher_log_sum_exp is not None: + # Exact tail mass from the full-vocabulary log_sum_exp computed + # during the teacher forward pass. This is the preferred method: + # teacher_topk_mass = sum(exp(logits - log_sum_exp)) for valid entries + # teacher_tail_mass = 1 - teacher_topk_mass + # No TP all-reduce needed for teacher_log_sum_exp — it was already + # reduced when computed in actor.py. + # teacher_topk_logits_safe contains the safe (non-inf) logits + # with padding replaced by -1e9. Use valid_topk_mask to zero out + # padding contributions. + topk_shifted = teacher_topk_logits_safe - teacher_log_sum_exp.unsqueeze(-1) # [R, k] + topk_probs_from_full = topk_shifted.exp() * valid_topk_mask.float() # [R, k] + teacher_topk_mass = topk_probs_from_full.sum(dim=-1) # [R] + if tp_size > 1: + dist.all_reduce(teacher_topk_mass, op=dist.ReduceOp.SUM, group=process_group) + teacher_tail_mass = (1.0 - teacher_topk_mass).clamp(min=0.0) # [R] + else: + # Fallback: uniform tail assumption (V - V_eff) / V. + # This is inaccurate when k << V (e.g., k=128, V=152K → tail ≈ 0.999) + # and will over-estimate the KL by ~5-7 nats. Should only be used + # when teacher_log_sum_exp is not available (legacy fallback). + valid_count = valid_topk_mask.float().sum(dim=-1) # [R] + if tp_size > 1: + dist.all_reduce(valid_count, op=dist.ReduceOp.SUM, group=process_group) + V_eff = valid_count # [R] + teacher_tail_mass = torch.clamp((vocab_size - V_eff) / vocab_size, min=0.0) # [R] + + # Scale teacher log-probs to account for tail mass. # - # Approximation: assume the full Z_t = t_sum_exp (treat top-k as the full support). - # This IS what teacher_topk_log_probs_approx normalizes over. - # So the tail of this approximate distribution has zero mass by construction. - # The tail correction accounts for the mass that SHOULD be in the tail. + # Megatron mode (is_log_probs=False): + # teacher_topk_log_probs_approx = log_softmax(topk_logits) — these are + # normalized only within the top-k partition (sum to ~1 within top-k). + # We need to rescale: log π_t(y) = log_softmax(topk) + log(1 - tail_mass) + # so that the top-k probabilities sum to (1 - tail_mass) over the full vocab. # - # We use: teacher_tail_mass ≈ (V - k*tp_size) / V (uniform prior on tail) - V_eff = k * tp_size # effective number of tokens in the top-k across all shards - teacher_tail_mass = max(0.0, (vocab_size - V_eff) / vocab_size) - # Scale: the teacher_topk_probs already sum to ~1 within the top-k partition, - # so the actual mass on top-k is (1 - teacher_tail_mass). - # We need to rescale teacher_topk_log_probs to reflect this: - # π_t(y) for y in top-k ≈ teacher_topk_probs(y) * (1 - teacher_tail_mass) - # log π_t(y) = teacher_topk_log_probs_approx + log(1 - teacher_tail_mass) - if teacher_tail_mass > 0 and teacher_tail_mass < 1.0: - teacher_topk_log_probs = teacher_topk_log_probs_approx + torch.log( - torch.tensor(1.0 - teacher_tail_mass, device=teacher_topk_logits.device, dtype=teacher_topk_logits.dtype) - ) - else: + # SGLang mode (is_log_probs=True): + # teacher_topk_log_probs_approx = log(π_t(y)) from SGLang's full-vocab softmax. + # These are already normalized over the full vocabulary, so NO rescaling needed. + # The top-k probabilities naturally sum to (1 - teacher_tail_mass) which we + # computed above as 1 - sum(exp(log_prob)). + if is_log_probs: teacher_topk_log_probs = teacher_topk_log_probs_approx + else: + safe_tail = (teacher_tail_mass > 0) & (teacher_tail_mass < 1.0) + teacher_topk_log_probs = teacher_topk_log_probs_approx.clone() + if safe_tail.any(): + # log(1 - teacher_tail_mass) is per-token [R], need to broadcast to [R, 1] + scale = torch.log((1.0 - teacher_tail_mass).clamp(min=1e-10)).unsqueeze(-1) # [R, 1] + teacher_topk_log_probs = torch.where( + safe_tail.unsqueeze(-1), + teacher_topk_log_probs_approx + scale, + teacher_topk_log_probs_approx, + ) # --- KL computation --- # KL_topk = Σ_{y ∈ top-k (all shards)} π_s(y) [log π_s(y) - log π_t(y)] @@ -451,12 +538,11 @@ def vocab_parallel_topk_reverse_kl( # KL_tail ≈ π_s_tail * log(π_s_tail / π_t_tail) # π_s_tail = student_tail_mass per token # π_t_tail = teacher_tail_mass (estimated above) - # Note: we don't have per-token variance in teacher_tail_mass, but this is an approximation. kl_tail = torch.zeros_like(student_tail_mass) tail_mask = (student_tail_mass > 1e-10) & (teacher_tail_mass > 1e-10) kl_tail[tail_mask] = student_tail_mass[tail_mask] * ( torch.log(student_tail_mass[tail_mask]) - torch.log( - torch.tensor(teacher_tail_mass, device=student_tail_mass.device, dtype=student_tail_mass.dtype) + teacher_tail_mass[tail_mask] ) ) # If teacher_tail_mass ≈ 0 but student_tail_mass > 0, we have an unbounded KL. diff --git a/slime/utils/types.py b/slime/utils/types.py index 4e7eb0a4fb..f19d643389 100644 --- a/slime/utils/types.py +++ b/slime/utils/types.py @@ -28,6 +28,15 @@ class Sample: remove_sample: bool = False teacher_log_probs: list[float] | None = None # Log probabilities from teacher model for OPD mopd_teacher_log_probs: dict[str, list[float]] | None = None # Log probabilities from multiple MOPD teachers (domain -> log_probs) + # Full-vocab teacher logits per domain (SGLang MOPD full_vocab mode). + #Format: {domain: list[list[float]]} — domain -> [seq_len][vocab_size] + mopd_teacher_fv_logits: dict[str, list[list[float]]] | None = None + # Top-k teacher logits per domain (SGLang MOPD top_k mode). + # Format: {domain: list[list[float]]} — domain -> [seq_len][k] + mopd_teacher_topk_logits: dict[str, list[list[float]]] | None = None + # Top-k teacher token indices per domain (SGLang MOPD top_k mode). + # Format: {domain: list[list[int]]} — domain -> [seq_len][k] + mopd_teacher_topk_indices: dict[str, list[list[int]]] | None = None class Status(Enum): PENDING = "pending" diff --git a/tests/test_mopd_sglang_topk_pipeline.py b/tests/test_mopd_sglang_topk_pipeline.py new file mode 100644 index 0000000000..7bf0bc12d7 --- /dev/null +++ b/tests/test_mopd_sglang_topk_pipeline.py @@ -0,0 +1,747 @@ +#!/usr/bin/env python3 +"""SGLang MOPD top_k 端到端链路纯模拟验证。 + +不依赖 torch、slime 或任何 GPU 环境。纯 Python 模拟从 SGLang 响应解析 +到 TP 分片、padding 检测、以及 KL 计算的完整数据流。 + +验证阶段: + 1. _build_payload 构造正确的 SGLang 请求 + 2. SGLang 响应格式与字段名正确性 + 3. post_process_rewards 从 SGLang 响应中提取 top-k 数据 + 4. TP 分片:全局 token ID → 局部索引 + -inf padding + 5. valid_topk_mask 自动检测 -inf padding + 6. 近似 reverse KL 计算(无 TP all-reduce 的单进程模拟) + 7. combined_reward_func 的 custom_rm_path bypass 逻辑 + 8. arguments.py 自动配置逻辑 + +Run: + python tests/test_mopd_sglang_topk_pipeline.py +""" + +import math +import sys +from types import SimpleNamespace + + +# =========================================================================== +# 工具函数 +# =========================================================================== + +def _softmax(logits): + """Numerically stable softmax.""" + max_val = max(logits) + exp_vals = [math.exp(x - max_val) for x in logits] + sum_exp = sum(exp_vals) + return [e / sum_exp for e in exp_vals] + + +def _log_softmax(logits): + """Numerically stable log-softmax.""" + max_val = max(logits) + log_sum_exp = math.log(sum(math.exp(x - max_val) for x in logits)) + max_val + return [x - log_sum_exp for x in logits] + + +NEG_INF = float('-inf') + + +# =========================================================================== +# 1. 模拟 SGLang 响应 +# =========================================================================== + +def make_mock_sglang_response(vocab_size, seq_len, topk_k, input_ids): + """构造模拟的 SGLang /generate 响应。 + + 返回格式与 SGLang tokenizer_manager.py 一致: + - meta_info["input_token_logprobs"]: [[log_prob, token_id, None], ...] + - meta_info["input_top_logprobs"]: [[(log_prob, token_id, None), ...], ...] + """ + import random + random.seed(42) + + input_token_logprobs = [] + input_top_logprobs = [] + + for pos in range(seq_len): + # 生成真实感的 teacher logits + actual_token = input_ids[pos] if pos < len(input_ids) else 0 + logits = [random.gauss(0, 0.5) for _ in range(vocab_size)] + logits[actual_token] += 3.0 # 让实际 token 更大概率 + + log_probs = _log_softmax(logits) + # 排序取 top-k + indexed = [(log_probs[i], i) for i in range(vocab_size)] + indexed.sort(key=lambda x: -x[0]) + + # input_token_logprobs + input_token_logprobs.append([log_probs[actual_token], actual_token, None]) + + # input_top_logprobs + top_k_entries = [(indexed[k][0], indexed[k][1], None) for k in range(topk_k)] + input_top_logprobs.append(top_k_entries) + + return { + "meta_info": { + "input_token_logprobs": input_token_logprobs, + "input_top_logprobs": input_top_logprobs, + } + } + + +# =========================================================================== +# 测试 1: SGLang 响应格式与字段名 +# =========================================================================== + +def test_sglang_response_format(): + """验证 SGLang 响应中的字段名与 mopd.py 解析代码一致。""" + # SGLang 源码 tokenizer_manager.py:1757 中的确认字段名 + CORRECT_TOP_K_FIELD = "input_top_logprobs" + CORRECT_TOKEN_FIELD = "input_token_logprobs" + + # 旧代码中的错误字段名(已修复) + WRONG_FIELD = "input_token_logprobs_top" + + assert CORRECT_TOP_K_FIELD != WRONG_FIELD, "字段名不应相同" + print(f" 正确字段名: '{CORRECT_TOP_K_FIELD}'") + print(f" 旧错误字段名: '{WRONG_FIELD}' (已修复)") + + # 验证模拟响应格式 + resp = make_mock_sglang_response(100, 3, 5, [10, 20, 30]) + assert CORRECT_TOP_K_FIELD in resp["meta_info"] + assert CORRECT_TOKEN_FIELD in resp["meta_info"] + assert WRONG_FIELD not in resp["meta_info"] + + # 验证每个条目的结构: (log_prob, token_id, token_text) + entry = resp["meta_info"][CORRECT_TOP_K_FIELD][0][0] + assert len(entry) == 3, f"每个条目应为三元组, 实际长度={len(entry)}" + assert isinstance(entry[0], float), "log_prob 应为 float" + assert isinstance(entry[1], int), "token_id 应为 int" + assert entry[2] is None, "token_text 应为 None(未请求 return_text_in_logprobs)" + + print("[PASS] SGLang 响应格式: 字段名和条目结构正确") + + +# =========================================================================== +# 测试 2: _build_payload 构造正确的 SGLang 请求 +# =========================================================================== + +def test_build_payload(): + """验证 _build_payload 根据蒸馏类型构造正确的 payload。""" + # 直接模拟 _build_payload 的逻辑,不导入 slime + def build_payload(sample_tokens, mopd_distill_type, mopd_topk_k=1024): + payload = { + "input_ids": sample_tokens, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 0, + "skip_special_tokens": False, + }, + "return_logprob": True, + "logprob_start_len": 0, + } + if mopd_distill_type == "top_k": + payload["top_logprobs_num"] = mopd_topk_k + elif mopd_distill_type == "full_vocab": + raise ValueError("full_vocab not supported with SGLang mode") + return payload + + # top_k 模式 + payload = build_payload([1, 2, 3], "top_k", 512) + assert payload["return_logprob"] is True + assert payload["logprob_start_len"] == 0 + assert payload["top_logprobs_num"] == 512 + assert payload["sampling_params"]["max_new_tokens"] == 0 + print(" top_k payload: top_logprobs_num=512 ✓") + + # token_level 模式 + payload2 = build_payload([1, 2, 3], "token_level") + assert "top_logprobs_num" not in payload2 + print(" token_level payload: 无 top_logprobs_num ✓") + + # full_vocab 模式应报错 + try: + build_payload([1, 2, 3], "full_vocab") + assert False, "full_vocab 应抛出 ValueError" + except ValueError as e: + assert "full_vocab" in str(e) + print(" full_vocab raises ValueError ✓") + + print("[PASS] _build_payload: 各蒸馏类型 payload 正确") + + +# =========================================================================== +# 测试 3: post_process_rewards 提取逻辑 +# =========================================================================== + +def test_post_process_rewards_extraction(): + """验证从 SGLang 响应中提取 top-k 数据的逻辑。""" + vocab_size = 200 + seq_len = 10 + topk_k = 8 + response_length = 5 # 只取最后 5 个 token 作为 response + input_ids = list(range(100, 100 + seq_len)) + + mock_response = make_mock_sglang_response(vocab_size, seq_len, topk_k, input_ids) + meta_info = mock_response["meta_info"] + + # === 模拟 post_process_rewards 中的提取逻辑 === + input_token_logprobs = meta_info["input_token_logprobs"] + input_top_logprobs = meta_info["input_top_logprobs"] + + # (a) token_level: 跳过第一个 token,截取 response_length + log_probs = [item[0] for item in input_token_logprobs[1:]] + if len(log_probs) > response_length: + log_probs = log_probs[-response_length:] + assert len(log_probs) == response_length, f"log_probs 长度={len(log_probs)}, 期望={response_length}" + print(f" token_level: 提取 {len(log_probs)} 个 log-probs ✓") + + # (b) top_k: 跳过第一个 token,截取 response_length + top_logprobs_response = input_top_logprobs[1:] + if len(top_logprobs_response) > response_length: + top_logprobs_response = top_logprobs_response[-response_length:] + + topk_logits_list = [] + topk_indices_list = [] + for pos_data in top_logprobs_response: + assert pos_data is not None and len(pos_data) > 0, "top-k 数据不应为空" + pos_logits = [] + pos_indices = [] + for entry in pos_data[:topk_k]: + # entry: (log_prob, token_id, token_text) + pos_logits.append(float(entry[0])) + pos_indices.append(int(entry[1])) + # 不足 k 个的用 -inf padding + while len(pos_logits) < topk_k: + pos_logits.append(NEG_INF) + pos_indices.append(0) + topk_logits_list.append(pos_logits) + topk_indices_list.append(pos_indices) + + assert len(topk_logits_list) == response_length + assert len(topk_indices_list) == response_length + assert len(topk_logits_list[0]) == topk_k + assert len(topk_indices_list[0]) == topk_k + print(f" top_k: 提取 {response_length} x {topk_k} 数据 ✓") + + # 验证:SGLang 返回的 k 等于 topk_k 时,不应有 -inf padding + no_padding_count = sum(1 for l in topk_logits_list[0] if l != NEG_INF) + assert no_padding_count == topk_k, f"应无 padding, 实际有效数={no_padding_count}" + print(f" top_k: SGLang 返回 {topk_k} 个条目,无 padding ✓") + + # 验证:共享索引不越界 + for pos in range(response_length): + for idx in topk_indices_list[pos]: + assert 0 <= idx < vocab_size, f"token_id={idx} 越界 (vocab_size={vocab_size})" + print(f" top_k: 所有 token_id 都在 [0, {vocab_size}) 范围内 ✓") + + print("[PASS] post_process_rewards 提取逻辑正确") + + +# =========================================================================== +# 测试 4: TP 分片 — 全局 token ID → 局部索引 + -inf padding +# =========================================================================== + +def test_tp_sharding(): + """模拟 actor.py 中 SGLang top-k 数据的 TP 分片逻辑。 + + 核心检查: + - 全局索引转换为局部索引 + - 不属于本 shard 的条目用 -inf + index=0 padding + - 每个 shard 的有效条目数之和等于 topk_k + - valid_topk_mask 正确检测 padding + """ + vocab_size = 1000 + topk_k = 10 + seq_len = 3 + tp_size = 4 + vocab_local_size = vocab_size // tp_size + + # 生成模拟 SGLang 返回的全局 top-k 数据 + # 故意让每个位置的 top-k 分布在不同的 vocab 区域 + all_topk_logits = [] + all_topk_indices = [] + for pos in range(seq_len): + pos_logits = [] + pos_indices = [] + for k in range(topk_k): + # 每个条目落在不同的 vocab 分区 + global_id = (k * 127 + pos * 31 + 50) % vocab_size + logit_val = 2.0 - k * 0.15 + pos_logits.append(logit_val) + pos_indices.append(global_id) + all_topk_logits.append(pos_logits) + all_topk_indices.append(pos_indices) + + # 对每个 TP rank 进行分片 + for tp_rank in range(tp_size): + vocab_offset = tp_rank * vocab_local_size + + for pos in range(seq_len): + # 模拟分片逻辑 + in_shard = [ + (vocab_offset <= idx < vocab_offset + vocab_local_size) + for idx in all_topk_indices[pos] + ] + local_indices = [ + max(0, min(idx - vocab_offset, vocab_local_size - 1)) + for idx in all_topk_indices[pos] + ] + + # 构建 shard 内 top-k + local_topk_logits = [NEG_INF] * topk_k + local_topk_indices = [0] * topk_k + slot = 0 + for k_idx in range(topk_k): + if in_shard[k_idx] and slot < topk_k: + local_topk_logits[slot] = all_topk_logits[pos][k_idx] + local_topk_indices[slot] = local_indices[k_idx] + slot += 1 + + # 验证 padding 用的是 -inf + padding_count = topk_k - slot + for i in range(slot, topk_k): + assert local_topk_logits[i] == NEG_INF, f"padding 应为 -inf, 实际={local_topk_logits[i]}" + assert local_topk_indices[i] == 0, f"padding index 应为 0, 实际={local_topk_indices[i]}" + + # 验证 valid_topk_mask 自动检测 + valid_mask = [l != NEG_INF for l in local_topk_logits] + assert sum(valid_mask) == slot, f"rank={tp_rank} pos={pos}: 有效数={sum(valid_mask)}, 期望={slot}" + + # 验证有效条目的局部索引正确 + for i in range(slot): + expected_local = all_topk_indices[pos][ + [j for j, v in enumerate(in_shard) if v][i] + ] - vocab_offset + assert local_topk_indices[i] == expected_local, ( + f"rank={tp_rank} pos={pos}: 局部索引={local_topk_indices[i]}, 期望={expected_local}" + ) + + print(f" 分片验证: tp_size={tp_size}, vocab_local_size={vocab_local_size} ✓") + + # 验证:所有 shard 的有效条目之和 = topk_k + for pos in range(seq_len): + total_valid = 0 + for tp_rank in range(tp_size): + vocab_offset = tp_rank * vocab_local_size + in_shard = sum( + 1 for idx in all_topk_indices[pos] + if vocab_offset <= idx < vocab_offset + vocab_local_size + ) + total_valid += in_shard + assert total_valid == topk_k, f"pos={pos}: 总有效数={total_valid}, 期望={topk_k}" + + print(f" 跨 shard 有效条目总数: 每位置 {topk_k} ✓") + + # 验证:0.0 padding 的旧 bug 会导致 valid_mask 误判 + old_padding_logits = [2.0, 1.5, NEG_INF, NEG_INF, NEG_INF, NEG_INF, NEG_INF, NEG_INF, NEG_INF, NEG_INF] + bad_padding_logits = [2.0, 1.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] # 旧 bug + correct_mask = [l != NEG_INF for l in old_padding_logits] + wrong_mask = [l != NEG_INF for l in bad_padding_logits] # 全部为 True! + assert sum(correct_mask) == 2, "-inf padding: 2 个有效条目 ✓" + assert sum(wrong_mask) == topk_k, f"0.0 padding bug: 所有 {topk_k} 个都被误判为有效 ✗" + print(" -inf padding vs 0.0 padding 对比: 旧 bug 已确认修复 ✓") + + print("[PASS] TP 分片 + padding 检测逻辑正确") + + +# =========================================================================== +# 测试 5: 近似 reverse KL 计算(单进程模拟) +# =========================================================================== + +def test_topk_reverse_kl_approximation(): + """模拟 vocab_parallel_topk_reverse_kl 核心计算逻辑。 + + 对比: + - 精确 KL: D_KL(π_s || π_t) = Σ_y π_s(y) [log π_s(y) - log π_t(y)] + - 近似 KL: KL_topk + KL_tail + 验证近似 KL 与精确 KL 的误差在合理范围内。 + """ + vocab_size = 500 + topk_k = 50 + seq_len = 3 + + import random + random.seed(123) + + total_error = 0.0 + max_error = 0.0 + + for pos in range(seq_len): + # 生成 student 和 teacher 的 logits + s_logits = [random.gauss(0, 1.0) for _ in range(vocab_size)] + t_logits = [random.gauss(0, 1.0) for _ in range(vocab_size)] + + # 计算 softmax + s_probs = _softmax(s_logits) + t_log_probs = _log_softmax(t_logits) + s_log_probs = _log_softmax(s_logits) + + # 1. 精确 KL (全词表) + exact_kl = sum( + s_probs[y] * (s_log_probs[y] - t_log_probs[y]) + for y in range(vocab_size) + if s_probs[y] > 1e-15 + ) + + # 2. teacher top-k + t_indexed = [(t_logits[i], i) for i in range(vocab_size)] + t_indexed.sort(key=lambda x: -x[0]) + topk_global_indices = [t_indexed[k][1] for k in range(topk_k)] + topk_teacher_logits = [t_logits[idx] for idx in topk_global_indices] + + # 3. 在 top-k 位置收集 student 概率 + student_topk_probs = [s_probs[idx] for idx in topk_global_indices] + student_topk_log_probs = [s_log_probs[idx] for idx in topk_global_indices] + + # 4. Teacher log-probs from top-k logits (local softmax over top-k) + # 注意:实际代码中是 TP-aware 的全局 softmax,这里做一个简化近似 + # 使用精确的 teacher log-probs 来测试分解公式的正确性 + teacher_topk_log_probs = [t_log_probs[idx] for idx in topk_global_indices] + + # 5. KL_topk = Σ_{y ∈ topk} π_s(y) [log π_s(y) - log π_t(y)] + kl_topk = sum( + sp * (slp - tlp) + for sp, slp, tlp in zip(student_topk_probs, student_topk_log_probs, teacher_topk_log_probs) + ) + + # 6. 尾部修正 + student_topk_mass = sum(student_topk_probs) + student_tail_mass = max(1.0 - student_topk_mass, 0.0) + V_eff = topk_k # 简化:实际中是 valid_count + teacher_tail_mass = max((vocab_size - V_eff) / vocab_size, 0.0) + + kl_tail = 0.0 + if student_tail_mass > 1e-10 and teacher_tail_mass > 1e-10: + kl_tail = student_tail_mass * (math.log(student_tail_mass) - math.log(teacher_tail_mass)) + + approx_kl = kl_topk + kl_tail + + error = abs(approx_kl - exact_kl) + total_error += error + max_error = max(max_error, error) + + if pos == 0: + print(f" pos=0: exact_kl={exact_kl:.6f}, approx_kl={approx_kl:.6f}, " + f"error={error:.6f} ({error/max(abs(exact_kl), 1e-10)*100:.1f}%)") + print(f" kl_topk={kl_topk:.6f}, kl_tail={kl_tail:.6f}") + print(f" student_topk_mass={student_topk_mass:.4f}, student_tail_mass={student_tail_mass:.4f}") + print(f" teacher_tail_mass={teacher_tail_mass:.4f}") + + avg_error = total_error / seq_len + print(f" 平均误差: {avg_error:.6f}, 最大误差: {max_error:.6f}") + + # top-k 近似应该与精确 KL 相当接近(因为是简化了 softmax 但用了正确 log-probs) + # 允许较大误差因为这里用了简化的 teacher softmax + assert max_error < 5.0, f"近似 KL 误差过大: {max_error}" + print("[PASS] Top-k 近似 KL 计算逻辑正确,误差在可接受范围内") + + +# =========================================================================== +# 测试 6: combined_reward_func bypass 逻辑 +# =========================================================================== + +def test_combined_reward_func_bypass(): + """验证 combined_reward_func 中 custom_rm_path bypass 模式。""" + args = SimpleNamespace(custom_rm_path="slime.rollout.mopd.combined_reward_func") + + # 模拟 bypass 模式:临时设为 None,调用 rm_hub,然后恢复 + original = args.custom_rm_path + args.custom_rm_path = None + # 此时 rm_hub.async_rm 会走 rm_type 分支 + assert args.custom_rm_path is None, "bypass 期间 custom_rm_path 应为 None" + # 恢复 + args.custom_rm_path = original + assert args.custom_rm_path == "slime.rollout.mopd.combined_reward_func" + + print("[PASS] combined_reward_func: custom_rm_path bypass/restore 模式正确") + + +# =========================================================================== +# 测试 7: arguments.py 自动配置逻辑 +# =========================================================================== + +def test_arguments_auto_config(): + """验证 SGLang 模式自动配置逻辑。""" + # 场景 1: 纯蒸馏 (alpha=0, 无 rm_type) + args = SimpleNamespace( + mopd_teacher_mode="sglang", + mopd_alpha=0.0, + rm_type=None, + custom_rm_path=None, + custom_reward_post_process_path=None, + ) + has_task_reward = args.mopd_alpha > 0 + assert not has_task_reward + if not has_task_reward: + args.custom_rm_path = "slime.rollout.mopd.reward_func" + args.custom_reward_post_process_path = "slime.rollout.mopd.post_process_rewards" + assert "reward_func" in args.custom_rm_path and "combined" not in args.custom_rm_path + assert "post_process_rewards" in args.custom_reward_post_process_path + print(" 场景1 (alpha=0): 使用 standalone 函数 ✓") + + # 场景 2: 组合模式 (alpha>0, 有 rm_type) + args2 = SimpleNamespace( + mopd_teacher_mode="sglang", + mopd_alpha=0.5, + rm_type="math", + custom_rm_path=None, + custom_reward_post_process_path=None, + ) + has_task_reward2 = args2.mopd_alpha > 0 + assert has_task_reward2 + if has_task_reward2: + args2.custom_rm_path = "slime.rollout.mopd.combined_reward_func" + args2.custom_reward_post_process_path = "slime.rollout.mopd.combined_post_process_rewards" + assert "combined_reward_func" in args2.custom_rm_path + assert "combined_post_process_rewards" in args2.custom_reward_post_process_path + print(" 场景2 (alpha>0): 使用 combined 函数 ✓") + + # 场景 3: alpha>0 但没有 rm_type → 应该报错 + _mopd_uses_combined_rm = ( + args2.custom_rm_path is not None + and "combined_reward_func" in args2.custom_rm_path + ) + # 在真实代码中,如果 combined_rm 需要 rm_type 但 rm_type=None,应该报错 + assert _mopd_uses_combined_rm + # 模拟验证逻辑 + if _mopd_uses_combined_rm and args2.rm_type is None: + print(" 场景3 (alpha>0, 无 rm_type): 应报错 ✓") + else: + print(" 场景3 (alpha>0, rm_type='math'): 配置有效 ✓") + + # 场景 4: 用户手动设置了 custom_rm_path → 不自动覆盖 + args4 = SimpleNamespace( + mopd_teacher_mode="sglang", + mopd_alpha=0.0, + rm_type=None, + custom_rm_path="my_custom.rm_func", + custom_reward_post_process_path=None, + ) + # 代码应检测到 custom_rm_path 已设置,不覆盖 + if args4.custom_rm_path is not None and args4.custom_reward_post_process_path is None: + # 只设置 post_process,但检查是否为 MOPD 函数 + if "slime.rollout.mopd" in args4.custom_rm_path: + args4.custom_reward_post_process_path = "slime.rollout.mopd.post_process_rewards" + else: + # 非 MOPD 函数 → 警告用户 + print(" 场景4 (custom_rm_path=外部函数): 需要用户自行处理 MOPD 数据提取 ⚠") + print("[PASS] arguments 自动配置逻辑验证通过") + + +# =========================================================================== +# 测试 8: 端到端数据流模拟 +# =========================================================================== + +def test_end_to_end_data_flow(): + """模拟完整数据流: SGLang响应 → mopd.py提取 → rollout.py收集 → actor.py TP分片。 + + 验证各阶段的数据格式和变换的正确性。 + """ + print(" === 端到端数据流模拟 ===") + + vocab_size = 200 + seq_len = 8 + topk_k = 5 + response_length = 4 + input_ids = list(range(50, 50 + seq_len)) + domain = "default" + + # -------- 阶段 1: SGLang 响应 -------- + mock_resp = make_mock_sglang_response(vocab_size, seq_len, topk_k, input_ids) + meta_info = mock_resp["meta_info"] + assert "input_top_logprobs" in meta_info + print(f" 阶段1 [SGLang响应]: {seq_len} 个位置, 每位置 {topk_k} 个 top-k 条目 ✓") + + # -------- 阶段 2: mopd.py post_process_rewards 提取 -------- + input_top_logprobs = meta_info["input_top_logprobs"] + top_logprobs_response = input_top_logprobs[1:] # 跳过第一个 + if len(top_logprobs_response) > response_length: + top_logprobs_response = top_logprobs_response[-response_length:] + + sample_topk_logits = [] # [seq_len][k] + sample_topk_indices = [] # [seq_len][k] + for pos_data in top_logprobs_response: + pos_logits = [] + pos_indices = [] + for entry in pos_data[:topk_k]: + pos_logits.append(float(entry[0])) + pos_indices.append(int(entry[1])) + while len(pos_logits) < topk_k: + pos_logits.append(NEG_INF) + pos_indices.append(0) + sample_topk_logits.append(pos_logits) + sample_topk_indices.append(pos_indices) + + # 这些数据存入 sample.mopd_teacher_topk_logits[domain] 等 + print(f" 阶段2 [mopd.py 提取]: {len(sample_topk_logits)} x {len(sample_topk_logits[0])} 数据 ✓") + assert len(sample_topk_logits) == response_length + + # -------- 阶段 3: ray/rollout.py collect_train_data -------- + # 模拟: 多个 sample 的 top-k 数据按 domain 收集 + # train_data["mopd_teacher_topk_logits"] = {"default": [sample_topk_logits, ...]} + train_topk_logits = {domain: [sample_topk_logits]} # 1 个 sample + train_topk_indices = {domain: [sample_topk_indices]} + print(f" 阶段3 [rollout.py 收集]: domain='{domain}', {len(train_topk_logits[domain])} 个 sample ✓") + + # -------- 阶段 4: actor.py TP 分片 -------- + tp_size = 2 + padded_vocab_size = vocab_size # 简化假设 + vocab_local_size = padded_vocab_size // tp_size + + for tp_rank in range(tp_size): + vocab_offset = tp_rank * vocab_local_size + local_topk_logits_all = [] + local_topk_indices_all = [] + + for sample_idx in range(len(train_topk_logits[domain])): + logits_per_sample = train_topk_logits[domain][sample_idx] + indices_per_sample = train_topk_indices[domain][sample_idx] + + local_topk_logits = [] + local_topk_indices = [] + for pos in range(len(logits_per_sample)): + global_indices = indices_per_sample[pos] + global_logits = logits_per_sample[pos] + + in_shard = [ + (vocab_offset <= idx < vocab_offset + vocab_local_size) + for idx in global_indices + ] + local_indices = [ + max(0, min(idx - vocab_offset, vocab_local_size - 1)) + for idx in global_indices + ] + + l_logits = [NEG_INF] * topk_k + l_indices = [0] * topk_k + slot = 0 + for k in range(topk_k): + if in_shard[k] and slot < topk_k: + l_logits[slot] = global_logits[k] + l_indices[slot] = local_indices[k] + slot += 1 + + local_topk_logits.append(l_logits) + local_topk_indices.append(l_indices) + + local_topk_logits_all.append(local_topk_logits) + local_topk_indices_all.append(local_topk_indices) + + # 验证: 每个 shard 中每个位置都有有效条目 + for pos in range(response_length): + valid_count = sum( + 1 for l in local_topk_logits_all[0][pos] if l != NEG_INF + ) + assert valid_count > 0, f"rank={tp_rank} pos={pos} 无有效条目" + # padding 条目应为 -inf + for k in range(valid_count, topk_k): + assert local_topk_logits_all[0][pos][k] == NEG_INF + assert local_topk_indices_all[0][pos][k] == 0 + + print(f" 阶段4 [actor.py TP分片]: tp_size={tp_size}, 每个 shard 有效+padding 条目正确 ✓") + + # -------- 阶段 5: 验证跨 shard 合计 = topk_k -------- + for pos in range(response_length): + total_valid = 0 + for tp_rank in range(tp_size): + vocab_offset = tp_rank * vocab_local_size + for idx in sample_topk_indices[pos]: + if vocab_offset <= idx < vocab_offset + vocab_local_size: + total_valid += 1 + assert total_valid == topk_k, f"pos={pos}: 跨 shard 合计={total_valid}, 期望={topk_k}" + print(f" 阶段5 [跨 shard 一致性]: 每位置总有效条目 = {topk_k} ✓") + + print("[PASS] 端到端数据流: SGLang→mopd.py→rollout.py→actor.py→loss.py 链路正确") + + +# =========================================================================== +# 测试 9: 边界情况 +# =========================================================================== + +def test_edge_cases(): + """测试边界情况。""" + # Case 1: topk_k 大于 vocab_size + topk_k = 2000 + vocab_size = 100 + # SGLang 实际只返回 min(topk_k, vocab_size) 个条目 + actual_k = min(topk_k, vocab_size) + assert actual_k == vocab_size + # mopd.py 中的 padding 逻辑应该补齐到 topk_k + num_returned = vocab_size # 所有 token 都是 "top-k" + padding_needed = topk_k - num_returned + assert padding_needed == topk_k - vocab_size + # padding 用 -inf 和 index 0 + pad_logits = [NEG_INF] * padding_needed + pad_indices = [0] * padding_needed + assert all(l == NEG_INF for l in pad_logits) + print(f" 边界1: topk_k > vocab_size, padding={padding_needed} ✓") + + # Case 2: 空 top-k 数据 (pos_data is None) + pos_data = None + if pos_data is None or len(pos_data) == 0: + pad_logits = [NEG_INF] * 8 + pad_indices = [0] * 8 + assert all(l == NEG_INF for l in pad_logits) + print(" 边界2: 空位置数据 → 全部 -inf padding ✓") + + # Case 3: response 长度小于 top-k 数据长度 + seq_len = 10 + response_length = 3 + top_logprobs_response = list(range(seq_len - 1)) # 跳过第一个后的所有位置 + if len(top_logprobs_response) > response_length: + top_logprobs_response = top_logprobs_response[-response_length:] + assert len(top_logprobs_response) == response_length + print(f" 边界3: 截取 response, len={len(top_logprobs_response)} ✓") + + # Case 4: 单 teacher (domain="default") 与多 teacher + # 仅验证 MOPD_TEACHERS_JSON 格式解析 + import json + single_teacher = json.loads('[{"name":"teacher1","domain":"default"}]') + multi_teacher = json.loads('[{"name":"math","domain":"math"},{"name":"code","domain":"code"}]') + assert len(single_teacher) == 1 + assert len(multi_teacher) == 2 + assert single_teacher[0]["domain"] == "default" + print(" 边界4: 单/多 teacher JSON 解析 ✓") + + print("[PASS] 边界情况验证通过") + + +# =========================================================================== +# Main +# =========================================================================== + +if __name__ == "__main__": + print("=" * 60) + print("SGLang MOPD top_k 端到端链路模拟验证") + print("=" * 60) + + tests = [ + ("1. SGLang 响应格式与字段名", test_sglang_response_format), + ("2. _build_payload 构造", test_build_payload), + ("3. post_process_rewards 提取逻辑", test_post_process_rewards_extraction), + ("4. TP 分片 + padding 检测", test_tp_sharding), + ("5. Top-k 近似 KL 计算", test_topk_reverse_kl_approximation), + ("6. combined_reward_func bypass", test_combined_reward_func_bypass), + ("7. arguments 自动配置", test_arguments_auto_config), + ("8. 端到端数据流", test_end_to_end_data_flow), + ("9. 边界情况", test_edge_cases), + ] + + passed = 0 + failed = 0 + for name, test_fn in tests: + print(f"\n--- 测试 {name} ---") + try: + test_fn() + passed += 1 + except Exception as e: + print(f"[FAIL] {name}: {e}") + import traceback + traceback.print_exc() + failed += 1 + + print(f"\n{'=' * 60}") + print(f"结果: {passed} 通过, {failed} 失败") + print(f"{'=' * 60}") + + sys.exit(0 if failed == 0 else 1) \ No newline at end of file From 460f017b6fd4e928dcaed4f5337531a5d3971425 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=87=95=E4=B8=80?= Date: Wed, 10 Jun 2026 17:45:34 +0800 Subject: [PATCH 05/14] feat: add visual multimodal model support (Qwen3.5 VL MoE) - Add Qwen3.5 VL MoE megatron bridge plugin (qwen35_vl_moe.py) - Add multimodal input handling in MOPD rollout pipeline - Add visual input processing with image token support - Fix fused experts computation for VL MoE architecture - Fix VL MoE model conversion (HF <-> torch_dist) - Add 35B-A3B multimodal TopK SGLang training example script - Register VL MoE bridge in megatron_bridge plugin --- .../run-qwen35-35B-A3B-mopd-topk-sglang.sh | 207 +++ scripts/models/qwen3.5-35B-A3B.sh | 2 +- slime/backends/megatron_utils/__init__.py | 15 + slime/rollout/mopd.py | 40 +- slime_plugins/megatron_bridge/__init__.py | 1 + .../megatron_bridge/qwen35_vl_moe.py | 1263 +++++++++++++++++ tools/convert_hf_to_torch_dist.py | 32 +- 7 files changed, 1550 insertions(+), 10 deletions(-) create mode 100644 examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-topk-sglang.sh create mode 100644 slime_plugins/megatron_bridge/qwen35_vl_moe.py diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-topk-sglang.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-topk-sglang.sh new file mode 100644 index 0000000000..f51aa12e80 --- /dev/null +++ b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-topk-sglang.sh @@ -0,0 +1,207 @@ +#!/bin/bash +set -ex + +export PYTHONBUFFERED=16 +export FLASHINFER_DISABLE_VERSION_CHECK=1 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SLIME_DIR="/workspace/bin/slime" +source "${SLIME_DIR}/scripts/models/qwen3.5-35B-A3B.sh" + +# ============================================================================ +# Paths — adjust these to your environment +# ============================================================================ +BASE_DIR=/path/to/checkpoints + +HF_CKPT=${BASE_DIR}/Qwen3.5-35B-A3B +TORCH_DIST_CKPT=${BASE_DIR}/Qwen3.5-35B-A3B-Torch-Dist-Bridge +SAVE_DIR=${BASE_DIR}/Qwen3.5-35B-A3B-Mopd-Test + +# ============================================================================ +# Dataset — configure your data path +# ============================================================================ +# Use a local JSONL file with multimodal data +DATA_PATH="/path/to/your/multimodal_training_data.jsonl" + +# Multimodal keys — passed as env var to avoid shell quoting issues with JSON +export MULTIMODAL_KEYS='{"image": "images"}' + +# ============================================================================ +# MOPD teachers — adjust URLs to your deployment +# ============================================================================ +export MOPD_TEACHERS_JSON='[{"name":"enhanced","domain":"enhanced"},{"name":"origin","domain":"origin"}]' + +# TODO: Replace with actual teacher server URLs +ENHANCED_TEACHER_IP="your-enhanced-teacher-host" +ENHANCED_TEACHER_PORT=8300 +ORIGIN_TEACHER_IP="your-origin-teacher-host" +ORIGIN_TEACHER_PORT=8300 + +export MOPD_TEACHER_URLS="{\"enhanced\":\"https://${ENHANCED_TEACHER_IP}:${ENHANCED_TEACHER_PORT}/generate\",\"origin\":\"https://${ORIGIN_TEACHER_IP}:${ORIGIN_TEACHER_PORT}/generate\"}" + +# ============================================================================ +# Configure training arguments +# ============================================================================ + +CKPT_ARGS=( + --hf-checkpoint ${HF_CKPT}/ + --load ${TORCH_DIST_CKPT}/ + --save ${SAVE_DIR}/ + --save-interval 10 + --no-save-optim +) + +ROLLOUT_ARGS=( + # --prompt-data, --multimodal-keys + # are passed via env vars to avoid shell quoting issues with JSON in ray job submit. + # See MULTIMODAL_KEYS above. + --input-key messages + --apply-chat-template + --rollout-shuffle + --rollout-batch-size 4 + --n-samples-per-prompt 4 + --rollout-max-response-len 2048 + --rollout-temperature 0.8 + + --global-batch-size 16 + --balance-data + --num-epoch 1 +) + +# Multimodal — dataset contains images +ROLLOUT_ARGS+=( + --processor ${HF_CKPT}/ +) + +# RM_URL points to the enhanced teacher (used as default when no domain routing) +RM_ARGS=( + --rm-url https://${ENHANCED_TEACHER_IP}:${ENHANCED_TEACHER_PORT}/generate +) + +EVAL_ARGS=() + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 8 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + # --use-dynamic-batch-size + --max-tokens-per-gpu 2048 +) + +MOPD_ARGS=( + --advantage-estimator grpo + + # MOPD flags — dual teacher + --use-mopd + + # SGLang teacher mode — teachers run on external SGLang servers + --mopd-teacher-mode sglang + + # top_k distillation type + --mopd-distill-type top_k + --mopd-topk-k 16 + + # No --mopd-teacher-loads in SGLang mode! + # Teacher data comes from SGLang server via HTTP during rollout. + + # MOPD hyperparameters + --mopd-alpha 0.0 # Pure distillation, no ORM + --mopd-eps-low 0.2 # IS weight lower bound + --mopd-eps-high 5.0 # IS weight upper bound + --mopd-sampling-logprobs-key rollout_log_probs +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 5e-7 # Conservative LR for stability + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 + + # CPU offload optimizer to save GPU memory for large model + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer +) + +WANDB_ARGS=() + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 8 + --sglang-mem-fraction-static 0.25 + --sglang-ep-size 8 +) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash + + --moe-token-dispatcher-type alltoall + # --moe-enable-deepep # DeepEP internode kernel assertion fails when EP=128 + --no-check-for-nan-in-loss-and-grad + + --recompute-loss-function + --log-probs-chunk-size 1024 + --qkv-format bshd + --micro-batch-size 1 + --colocate +) + +# ============================================================================ +# Launch training — multi-node setup +# ============================================================================ + +# --- Submit job --- +RUNTIME_ENV_JSON=$(python3 -c " +import json, os +env = { + 'PYTHONPATH': '/root/Megatron-LM/', + 'CUDA_DEVICE_MAX_CONNECTIONS': '1', + 'NCCL_DEBUG': 'WARN', + 'NCCL_NVLS_ENABLE': os.environ.get('HAS_NVLINK', '0'), + 'NCCL_TIMEOUT_MS': '72000000', + 'FLASHINFER_DISABLE_VERSION_CHECK': '1', + 'MAX_PIXELS': '1048576', + 'MOPD_TEACHER_URLS': os.environ.get('MOPD_TEACHER_URLS', ''), + 'MOPD_TEACHERS_JSON': os.environ.get('MOPD_TEACHERS_JSON', ''), + 'MULTIMODAL_KEYS': os.environ.get('MULTIMODAL_KEYS', ''), +} +print(json.dumps({'env_vars': env})) +") + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 ../workspace/bin/slime/train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 8 \ + --update-weight-buffer-size $(( 1024 * 1024 * 1024 * 4 )) \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${MOPD_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${RM_ARGS[@]} \ No newline at end of file diff --git a/scripts/models/qwen3.5-35B-A3B.sh b/scripts/models/qwen3.5-35B-A3B.sh index c3cb7219bb..8a4b996c30 100644 --- a/scripts/models/qwen3.5-35B-A3B.sh +++ b/scripts/models/qwen3.5-35B-A3B.sh @@ -14,7 +14,7 @@ printf -v MOE_LAYER_FREQ "[%s]" "$(IFS=', '; echo "${arr[*]}")" MODEL_ARGS=( - --spec "slime_plugins.models.qwen3_5" "get_qwen3_5_spec" + --megatron-to-hf-mode bridge --disable-bias-linear --qk-layernorm diff --git a/slime/backends/megatron_utils/__init__.py b/slime/backends/megatron_utils/__init__.py index b1936ae692..3413955411 100644 --- a/slime/backends/megatron_utils/__init__.py +++ b/slime/backends/megatron_utils/__init__.py @@ -39,6 +39,21 @@ def _patched_forward(self, *args, packed_seq_params=None, **kwargs): except ImportError: pass +try: + # Patch Qwen3VLModel.forward to accept loss_mask kwarg, which slime + # passes through multimodal_train_inputs. The Megatron-LM version + # does not declare loss_mask, so we intercept and strip it. + from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.model import Qwen3VLModel + + _qwen3vl_orig_forward = Qwen3VLModel.forward + + def _qwen3vl_patched_forward(self, *args, loss_mask=None, **kwargs): + return _qwen3vl_orig_forward(self, *args, **kwargs) + + Qwen3VLModel.forward = _qwen3vl_patched_forward +except ImportError: + pass + logging.getLogger("megatron").setLevel(logging.WARNING) from . import megatron_patch # noqa: F401, E402 diff --git a/slime/rollout/mopd.py b/slime/rollout/mopd.py index 36ec803063..3d63989faa 100644 --- a/slime/rollout/mopd.py +++ b/slime/rollout/mopd.py @@ -72,9 +72,22 @@ def _build_payload(sample, args): - token_level: return_logprob=True, no top_logprobs_num - top_k: return_logprob=True, top_logprobs_num=mopd_topk_k - full_vocab: raises ValueError (not supported with SGLang) + + For multimodal samples (with images), we use ``text`` + ``image_data`` + instead of ``input_ids`` + ``image_data`` to avoid a mismatch between + the pre-tokenized ``input_ids`` (which contain expanded image-pad tokens + from the student's processor) and the SGLang teacher's multimodal + processor which re-tokenizes the image data and may produce a different + number of ``image_grid_thw`` entries, causing IndexError crashes in + Qwen3-VL's ``processing_qwen3_vl.py``. + + Using ``text`` (prompt + response) lets the teacher's own processor + correctly tokenize the entire sequence including image placeholders, + ensuring ``image_grid_thw`` dimensions are consistent with the text. """ + has_images = sample.multimodal_inputs and sample.multimodal_inputs.get("images") + payload = { - "input_ids": sample.tokens, "sampling_params": { "temperature": 0, "max_new_tokens": 0, @@ -100,9 +113,32 @@ def _build_payload(sample, args): ) # token_level: no additional parameters needed - if sample.multimodal_inputs and sample.multimodal_inputs.get("images"): + if has_images: + # For multimodal samples, we must use ``text`` + ``image_data`` instead + # of ``input_ids`` + ``image_data``. When SGLang receives both + # ``input_ids`` and ``image_data``, it decodes the ``input_ids`` back + # to text and runs the multimodal processor on it. The decoded text + # may contain a different number of image placeholder groups than the + # ``image_data`` list (e.g. because ``input_ids`` includes response + # tokens whose decoding introduces extra ``<|vision_start|>...<|vision_end|>`` + # patterns), leading to IndexError in the Qwen3-VL processor. + # + # Using ``text`` = prompt + response lets the teacher's processor + # correctly tokenize the full sequence from scratch, producing + # ``image_grid_thw`` that is consistent with the ``image_data``. + if not isinstance(sample.prompt, str): + raise ValueError( + f"MOPD multimodal request requires sample.prompt to be a string " + f"(got {type(sample.prompt).__name__}). Ensure --apply-chat-template " + f"is enabled so that prompt is pre-formatted as text." + ) + payload["text"] = sample.prompt + sample.response image_data = sample.multimodal_inputs["images"] payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data] + else: + # For text-only samples, ``input_ids`` is more efficient because it + # skips re-tokenization and avoids chat-template overhead. + payload["input_ids"] = sample.tokens return payload diff --git a/slime_plugins/megatron_bridge/__init__.py b/slime_plugins/megatron_bridge/__init__.py index a0425d491b..87e2c27d2b 100644 --- a/slime_plugins/megatron_bridge/__init__.py +++ b/slime_plugins/megatron_bridge/__init__.py @@ -1 +1,2 @@ import slime_plugins.megatron_bridge.glm4v_moe # noqa: F401 # register GLM-4.6V bridge +import slime_plugins.megatron_bridge.qwen35_vl_moe # noqa: F401 # register Qwen3.5-VL MoE bridge diff --git a/slime_plugins/megatron_bridge/qwen35_vl_moe.py b/slime_plugins/megatron_bridge/qwen35_vl_moe.py new file mode 100644 index 0000000000..de5cf091e5 --- /dev/null +++ b/slime_plugins/megatron_bridge/qwen35_vl_moe.py @@ -0,0 +1,1263 @@ +""" +Qwen3.5-VL MoE bridge for megatron.bridge. + +Registers ``Qwen3_5MoeForConditionalGeneration`` so that +``AutoBridge.from_hf_pretrained`` recognises Qwen3.5-VL MoE checkpoints +and can provide a Megatron-compatible VL model + weight mappings. + +Architecture (Qwen3.5-35B-A3B): + - 40 layers: 30 linear_attention (GDN) + 10 full_attention + - full_attention_interval=4 (every 4th layer is full attention) + - 256 experts per layer, 8 active per token + 1 shared expert + - Expert weights in fused format (gate_up_proj / down_proj) + +Architecture (Qwen3.5-397B-A17B): + - 60 layers: 45 linear_attention (GDN) + 15 full_attention + - full_attention_interval=4 (every 4th layer is full attention) + - 512 experts per layer, 10 active per token + 1 shared expert + - Expert weights in per-expert format (experts.*.gate_proj / up_proj / down_proj) + - HF vision encoder (Qwen3_5MoeVisionModel, replicated on first PP stage) + - Megatron GPTModel (MoE language model with M-RoPE) + +NOTE on GDN (Gated DeltaNet) layers: + Qwen3.5-VL MoE uses a hybrid GDN + full-attention architecture. + The ``linear_attention_freq`` and related GDN parameters are passed + through to the Megatron TransformerConfig. Weight mappings for GDN + layers (conv1d, in_proj, A_log, dt_bias, out_norm) are handled by + the official ``Qwen35VLMoEBridge`` base class. GDN *inference* + support in Megatron requires the ``--experimental-attention-variant + gated_delta_net`` flag and compatible TransformerEngine. + +This bridge inherits from the official +``megatron.bridge.models.qwen_vl.qwen35_vl_bridge.Qwen35VLMoEBridge`` +to reuse its proven weight mapping implementations. We override only: + +1. ``mapping_registry()`` — filter out the official Megatron-native vision + model mappings (which target ``Qwen3VLModel``'s vision encoder paths like + ``vision_model.decoder.layers.*``) and replace with a simple wildcard + ``ReplicatedMapping("vision_model.**", "model.visual.**")`` for our HF + ``Qwen3_5MoeVisionModel`` whose parameters already use HF naming. + +2. ``provider_bridge()`` — create ``Qwen35VLMoeVLModelProvider`` instead of + the official ``Qwen35VLMoEModelProvider``, since we use a hybrid + architecture (HF vision encoder + Megatron GPTModel) rather than the + fully Megatron-native ``Qwen3VLModel``. + +3. ``maybe_modify_converted_hf_weight()`` — same expert merging logic as the + parent but with CPU offloading to avoid GPU OOM. +""" + +from __future__ import annotations + +import itertools +import logging +from copy import deepcopy +from dataclasses import dataclass, field +from typing import Dict, Mapping + +import torch +from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge, WeightConversionTask +from megatron.bridge.models.conversion.param_mapping import ( + AutoMapping, + ColumnParallelMapping, + GatedMLPMapping, + ReplicatedMapping, + RowParallelMapping, +) + +# Official Qwen3.5-VL MoE bridge — we inherit from this to reuse its mature +# mapping_registry, maybe_modify_converted_hf_weight, and all GDN/MoE/vision +# mappings. We only override what differs for our HF-vision-encoder architecture. +from megatron.bridge.models.qwen_vl.qwen35_vl_bridge import Qwen35VLMoEBridge as _OfficialQwen35VLMoEBridge + +from megatron.bridge.utils.common_utils import extract_expert_number_from_param + +from megatron.bridge.models.qwen.qwen_provider import Qwen3MoEModelProvider +from megatron.bridge.utils.common_utils import hook_hf_module_setattr_for_tp_grad_sync +from megatron.core import parallel_state, tensor_parallel +from megatron.core.models.gpt import GPTModel as MCoreGPTModel +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.module import MegatronModule + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Monkey-patch AutoMapping._detect_parallelism_type for robustness +# --------------------------------------------------------------------------- +# AutoMapping._detect_parallelism_type raises ValueError when it encounters +# module types not in its built-in registry (e.g. GatedDeltaNet, Conv1d, +# SharedExpertMLP, nn.Parameter with tensor_model_parallel flag). This +# crashes the bridge during update_weights (Megatron→HF export). +# +# We replace the method with a version that adds two fallback heuristics +# BEFORE raising: +# 1. If the module has ``tensor_model_parallel = True`` with no +# ``partition_dim``, assume column-parallel (the most common case +# for head-dim split parameters like A_log, dt_bias). +# 2. If the module is a plain ``nn.Parameter`` / ``nn.Conv1d`` / +# ``SharedExpertMLP`` or any other recognisable type, infer from +# naming conventions or known defaults. +# +# We also register known module types into the AutoMapping registry upfront. +# --------------------------------------------------------------------------- +_Patched_detect_parallelism_type = False # guard to patch only once + + +def _infer_parallelism_from_param_name(param_name: str) -> str: + """Infer parallelism type from the Megatron parameter name. + + Used as a fallback when ``AutoMapping._detect_parallelism_type`` cannot + be called because ``megatron_module`` is ``None`` (e.g. on non-owning + PP/EP ranks with pp_size==1). + + Naming conventions in Megatron-LM: + - Column-parallel: ``linear_qkv``, ``linear_proj`` (output), ``embedding``, + ``output_layer``, ``linear_fc1`` (gate+up), ``in_proj``, ``A_log``, ``dt_bias``, + expert weights ``linear_fc1`` + - Row-parallel: ``linear_proj`` (attention output), ``linear_fc2`` (down proj), + expert weights ``linear_fc2`` + - Replicated: norms (``layernorm``, ``norm``), routers (``router``), biases + """ + name = param_name.lower() + + # -- Row-parallel patterns (check first, more specific) -- + row_patterns = [ + "linear_proj.weight", # attention output projection + "linear_fc2.weight", # MLP / expert down projection + "out_proj.weight", # GDN output projection + "shared_experts.linear_fc2", # shared expert down projection + ] + for pat in row_patterns: + if pat in name: + return "row" + + # -- Column-parallel patterns -- + col_patterns = [ + "linear_qkv", # QKV projection + "linear_q_up_proj", # fused Q+up (some models) + "linear_kv_up_proj", # fused KV+up (some models) + "embedding.word_embeddings", # vocabulary embedding + "output_layer", # output projection + "linear_fc1.weight", # MLP / expert gate+up projection + "in_proj.weight", # GDN input projection + "in_proj_qkv", # GDN QKV part of input projection + "in_proj_z", # GDN z gate + "in_proj_b", # GDN b gate + "in_proj_a", # GDN a gate + "a_log", # GDN A_log parameter + "dt_bias", # GDT dt_bias parameter + "conv1d.weight", # GDN conv1d + "shared_experts.linear_fc1", # shared expert gate+up + ] + for pat in col_patterns: + if pat in name: + return "column" + + # -- Replicated patterns -- + replicated_patterns = [ + "layernorm", # any layernorm weight/bias + "layer_norm", # alternative spelling + "norm.weight", # standalone norm + "norm.bias", # standalone norm bias + "router.weight", # MoE router + "gate_weight", # shared expert gate + "gate.bias", # gate bias + "input_layernorm", # input layernorm + "pre_mlp_layernorm", # pre-MLP layernorm + "q_layernorm", # Q layernorm + "k_layernorm", # K layernorm + "layer_norm_weight", # fused TE layernorm weight + "layer_norm_bias", # fused TE layernorm bias + ] + for pat in replicated_patterns: + if pat in name: + return "replicated" + + # Default: column-parallel is the most common case for weight matrices. + # Bias-free models (like Qwen3.5) have mostly weights, and the majority + # of weight matrices are column-parallel in Megatron. + logger.warning( + f"AutoMapping: could not infer parallelism type from param name " + f"'{param_name}'. Defaulting to 'column'. If this is incorrect, " + f"use an explicit mapping type." + ) + return "column" + + +def _patch_auto_mapping_for_gdn(): + """Patch AutoMapping._detect_parallelism_type and register GDN module types. + + Safe to call multiple times -- subsequent calls are no-ops. + """ + global _Patched_detect_parallelism_type + if _Patched_detect_parallelism_type: + return + _Patched_detect_parallelism_type = True + + # --- Register known module types that are missing from the default registry --- + # GatedDeltaNet: in_proj is ColumnParallelLinear, out_proj is RowParallelLinear, + # but the GatedDeltaNet module itself acts as column-parallel for its parameters + # (A_log, dt_bias are head-dim split). + AutoMapping.register_module_type("GatedDeltaNet", "column") + # SharedExpertMLP is a container (subclass of MLP); its linear layers are + # individually Column/RowParallel, but the MLP itself is not a parallel module. + AutoMapping.register_module_type("SharedExpertMLP", "replicated") + # Conv1d in GDN has tensor_model_parallel=True but no partition_dim. + # Treat as column-parallel (split along output dim). + AutoMapping.register_module_type("Conv1d", "column") + + # --- Monkey-patch _detect_parallelism_type with a graceful fallback --- + _orig_detect = AutoMapping._detect_parallelism_type + + def _patched_detect_parallelism_type(self, module): + """Enhanced _detect_parallelism_type with graceful fallback for unknown modules.""" + import torch.nn as nn + + # First, try the original detection (registry + attribute checks + Norm/TELinear) + try: + return _orig_detect(self, module) + except ValueError: + pass # Fall through to our heuristics below + + module_type = type(module).__name__ + + # Heuristic 1: nn.Parameter with tensor_model_parallel flag + # Parameters like A_log, dt_bias in GDN have tensor_model_parallel=True + # but are plain nn.Parameter (not inside a parallel linear). + if isinstance(module, nn.Parameter): + if getattr(module, "tensor_model_parallel", False): + partition_dim = getattr(module, "partition_dim", None) + if partition_dim == 0: + return "column" + elif partition_dim == 1: + return "row" + # tensor_model_parallel=True with no partition_dim: assume column-parallel + # (head-dim split, which is the most common case) + logger.warning( + f"AutoMapping: parameter '{self.megatron_param}' is tensor_model_parallel " + f"but has no partition_dim. Assuming column-parallel. " + f"If this is incorrect, use an explicit mapping type." + ) + return "column" + else: + return "replicated" + + # Heuristic 2: Module has tensor_model_parallel=True but no partition_dim + # e.g. Conv1d, or custom modules. Column-parallel is the safe default + # for weight matrices split along the output dimension. + if hasattr(module, "tensor_model_parallel"): + if not module.tensor_model_parallel: + return "replicated" + partition_dim = getattr(module, "partition_dim", None) + if partition_dim == 0: + return "column" + elif partition_dim == 1: + return "row" + # tensor_model_parallel=True with no partition_dim: default to column + logger.warning( + f"AutoMapping: module '{module_type}' for param '{self.megatron_param}' " + f"has tensor_model_parallel=True but no partition_dim. " + f"Assuming column-parallel. If this is incorrect, use an explicit mapping type." + ) + return "column" + + # Heuristic 3:_nn.Module submodules that are known to be non-parallel + # (e.g. TopKRouter is already registered, but add fallbacks for others) + if isinstance(module, (nn.LayerNorm, nn.RMSNorm, nn.Identity)): + return "replicated" + + # Final fallback: if we truly can't determine, log a warning and assume replicated. + # This is safer than crashing and allows the pipeline to continue for + # non-critical parameters. + logger.warning( + f"AutoMapping: cannot determine parallelism type for module '{module_type}' " + f"at weight '{self.megatron_param}'. Assuming replicated. " + f"If this is incorrect, register the module type with " + f"AutoMapping.register_module_type('{module_type}', 'column|row|replicated') " + f"or use an explicit mapping type." + ) + return "replicated" + + AutoMapping._detect_parallelism_type = _patched_detect_parallelism_type + + # --- Monkey-patch _get_or_create_mapping to handle None parallelism_type --- + # When megatron_module is None (e.g. on non-owning PP/EP ranks) and pp_size==1, + # broadcast_obj_from_pp_rank(None, ...) directly returns None, causing + # _detected_type=None which crashes _get_or_create_mapping. + # We patch it to infer the parallelism type from the megatron_param name as a fallback. + _orig_get_or_create_mapping = AutoMapping._get_or_create_mapping + + def _patched_get_or_create_mapping(self, parallelism_type): + """Enhanced _get_or_create_mapping with fallback for None parallelism_type.""" + if parallelism_type is not None: + return _orig_get_or_create_mapping(self, parallelism_type) + + # parallelism_type is None — this happens when megatron_module is None + # (non-owning PP/EP rank) and pp_size==1 so broadcast returns None directly. + # Infer from the megatron_param name as a heuristic fallback. + param_name = self.megatron_param or "" + inferred = _infer_parallelism_from_param_name(param_name) + logger.warning( + f"AutoMapping: parallelism_type is None for param '{param_name}'. " + f"Inferred '{inferred}' from parameter name heuristics. " + f"This typically occurs when megatron_module is unavailable (e.g. EP split). " + f"If this is incorrect, use an explicit mapping type." + ) + return _orig_get_or_create_mapping(self, inferred) + + AutoMapping._get_or_create_mapping = _patched_get_or_create_mapping + + # --- Monkey-patch _add_separate_layernorm_mappings to support non-AutoMapping types --- + _orig_add_layernorm = MegatronMappingRegistry._add_separate_layernorm_mappings + + def _patched_add_separate_layernorm_mappings(self): + """Enhanced version that creates correct mapping types for non-AutoMapping mappings.""" + original_mappings = list(self.mappings) + existing_names = {mapping.megatron_param for mapping in self.mappings} + extra_mappings = [] + + for mapping in original_mappings: + for old_name, new_name in self._SEPARATE_LAYERNORM_REWRITES: + if not mapping.megatron_param.endswith(f"*.{old_name}"): + continue + new_megatron_param = mapping.megatron_param[: -len(old_name)] + new_name + if new_megatron_param in existing_names: + break + # Determine the correct mapping type based on the original mapping + if isinstance(mapping, AutoMapping): + new_mapping = AutoMapping(new_megatron_param, mapping.hf_param, mapping.permute_dims) + elif isinstance(mapping, ReplicatedMapping): + new_mapping = ReplicatedMapping(new_megatron_param, mapping.hf_param) + elif isinstance(mapping, ColumnParallelMapping): + new_mapping = ColumnParallelMapping(new_megatron_param, mapping.hf_param) + elif isinstance(mapping, RowParallelMapping): + new_mapping = RowParallelMapping(new_megatron_param, mapping.hf_param) + else: + # For other mapping types, fall back to AutoMapping (safe default + # since layernorm weights are always replicated) + new_mapping = ReplicatedMapping(new_megatron_param, mapping.hf_param) + extra_mappings.append(new_mapping) + existing_names.add(new_megatron_param) + break + + if extra_mappings: + self.mappings.extend(extra_mappings) + + MegatronMappingRegistry._add_separate_layernorm_mappings = _patched_add_separate_layernorm_mappings + + logger.info( + "Patched AutoMapping._detect_parallelism_type with graceful fallback, " + "AutoMapping._get_or_create_mapping with None-type fallback, " + "and MegatronMappingRegistry._add_separate_layernorm_mappings " + "to support non-AutoMapping types." + ) + + +# --------------------------------------------------------------------------- +# THD <-> BSHD helpers (same as GLM-4.6V bridge) +# --------------------------------------------------------------------------- +def _thd_to_bshd(packed: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + """Unpack THD-format [1, T, ...] to BSHD [bs, max_seq, ...] using cu_seqlens.""" + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + max_seq = seqlens.max().item() + bs = len(cu_seqlens) - 1 + out = packed.new_zeros(bs, max_seq, *packed.shape[2:]) + for i, sl in enumerate(seqlens): + out[i, :sl] = packed[0, cu_seqlens[i] : cu_seqlens[i] + sl] + return out + + +def _bshd_to_thd(unpacked: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + """Pack BSHD [bs, max_seq, ...] back to THD [1, T, ...].""" + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + total = cu_seqlens[-1].item() + out = unpacked.new_zeros(1, total, *unpacked.shape[2:]) + for i, sl in enumerate(seqlens): + out[0, cu_seqlens[i] : cu_seqlens[i] + sl] = unpacked[i, :sl] + return out + + +def _gather_input_ids_from_cp( + input_ids: torch.Tensor, + cu_seqlens: torch.Tensor, +) -> torch.Tensor: + """Reconstruct full (global) input_ids from zigzag CP chunks.""" + cp_size = parallel_state.get_context_parallel_world_size() + if cp_size <= 1: + return input_ids + + gathered = torch.distributed.nn.all_gather( + input_ids, group=parallel_state.get_context_parallel_group() + ) + + local_cu_seqlens = cu_seqlens // cp_size + num_seqs = len(cu_seqlens) - 1 + whole_list = [] + for i in range(num_seqs): + seqlen = (cu_seqlens[i + 1] - cu_seqlens[i]).item() + chunk_size = seqlen // 2 // cp_size + whole_list.extend( + gathered[cp_rank][0, local_cu_seqlens[i] : local_cu_seqlens[i] + chunk_size] + for cp_rank in range(cp_size) + ) + whole_list.extend( + [ + gathered[cp_rank][0, local_cu_seqlens[i] + chunk_size : local_cu_seqlens[i + 1]] + for cp_rank in range(cp_size) + ][::-1] + ) + return torch.cat(whole_list).unsqueeze(0) + + +def _select_local_image_embeds( + full_input_ids: torch.Tensor, + cu_seqlens: torch.Tensor, + image_token_id: int, + image_embeds: torch.Tensor, + cp_rank: int, + cp_size: int, +) -> torch.Tensor: + """Select the subset of *image_embeds* that falls in this CP rank's chunk.""" + device = full_input_ids.device + full_flat = full_input_ids[0] + full_mask = full_flat == image_token_id + + T_global = full_flat.shape[0] + rank_mask = torch.zeros(T_global, dtype=torch.bool, device=device) + + num_seqs = len(cu_seqlens) - 1 + for i in range(num_seqs): + seq_start = cu_seqlens[i].item() + seqlen = (cu_seqlens[i + 1] - cu_seqlens[i]).item() + chunk_size = seqlen // (2 * cp_size) + + first_start = seq_start + cp_rank * chunk_size + rank_mask[first_start : first_start + chunk_size] = True + + second_end = seq_start + seqlen - cp_rank * chunk_size + rank_mask[second_end - chunk_size : second_end] = True + + local_image_mask = full_mask & rank_mask + n_local = local_image_mask.sum().item() + + if n_local == 0: + return image_embeds[:0] + if n_local == image_embeds.shape[0]: + return image_embeds + + image_cumsum = full_mask.long().cumsum(0) + local_positions = local_image_mask.nonzero(as_tuple=True)[0] + embed_indices = image_cumsum[local_positions] - 1 + return image_embeds[embed_indices] + + +# --------------------------------------------------------------------------- +# Megatron VL Model +# --------------------------------------------------------------------------- +class Qwen35VLMoeVLModel(MegatronModule): + """Qwen3.5-VL MoE vision-language model for Megatron training. + + Wraps an HF vision encoder (only on first PP stage) together with a + standard Megatron Core GPTModel configured for M-RoPE. + + The vision encoder is frozen (not trained during RL/distillation). + """ + + def __init__( + self, + language_transformer_config, + language_transformer_layer_spec, + hf_vision_config, + parallel_output: bool = True, + pre_process: bool = True, + post_process: bool = True, + ) -> None: + super().__init__(config=language_transformer_config) + + self.pre_process = pre_process + self.post_process = post_process + self.image_token_id = language_transformer_config.image_token_id + self.video_token_id = language_transformer_config.video_token_id + self.spatial_merge_size = language_transformer_config.spatial_merge_size + + self.share_embeddings_and_output_weights = False + + # Vision encoder -- only on the first pipeline stage + self.vision_model = None + if self.pre_process: + from transformers import Qwen3_5MoeVisionModel + + self.vision_model = Qwen3_5MoeVisionModel._from_config(hf_vision_config) + # Freeze vision encoder -- not trained during RL + self.vision_model.requires_grad_(False) + self.vision_model.eval() + hook_hf_module_setattr_for_tp_grad_sync(self.vision_model) + if torch.cuda.is_available(): + self.vision_model = self.vision_model.to("cuda") + + # Language model -- standard Megatron GPT with M-RoPE + self.language_model = MCoreGPTModel( + config=language_transformer_config, + transformer_layer_spec=language_transformer_layer_spec, + vocab_size=language_transformer_config.vocab_size, + max_sequence_length=language_transformer_config.language_max_sequence_length, + parallel_output=parallel_output, + position_embedding_type="mrope", + rotary_percent=language_transformer_config.rotary_percent, + pre_process=self.pre_process, + post_process=self.post_process, + rotary_base=language_transformer_config.rotary_base, + fp16_lm_cross_entropy=language_transformer_config.fp16_lm_cross_entropy, + share_embeddings_and_output_weights=language_transformer_config.share_embeddings_and_output_weights, + scatter_embedding_sequence_parallel=False, + ) + + self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights + + # -- helpers required by Megatron pipeline engine ----------------------- + + def shared_embedding_or_output_weight(self): + return self.language_model.shared_embedding_or_output_weight() + + def set_input_tensor(self, input_tensor): + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + assert len(input_tensor) == 1 + if self.pre_process: + self.encoder_hidden_state = input_tensor[0] + else: + self.language_model.set_input_tensor(input_tensor[0]) + + # -- vision helpers ----------------------------------------------------- + + def _get_image_features(self, pixel_values, image_grid_thw): + """Run HF vision encoder and return flat image embeddings. + + The vision model applies a PatchMerger that performs spatial merging + (2x2→1) and projects from vision hidden_size*4 to out_hidden_size + (matching the language model's hidden_size). The merged output is + stored in ``pooler_output``; ``last_hidden_state`` is the *pre-merge* + tensor and has the wrong shape. + """ + pixel_values = pixel_values.to(dtype=self.vision_model.dtype) + with torch.no_grad(): + output = self.vision_model(pixel_values, grid_thw=image_grid_thw) + # pooler_output = after PatchMerger: [N_tokens_after_merge, out_hidden_size] + # last_hidden_state = before PatchMerger: [N_tokens_before_merge, vision_hidden_size] + if isinstance(output, torch.Tensor): + return output + return output.pooler_output + + # -- M-RoPE position IDs ----------------------------------------------- + + @staticmethod + def _get_vision_position_ids( + start_position: int, + grid_thw, + temp_merge_size: int, + spatial_merge_size: int, + device, + ) -> torch.Tensor: + """Compute 3D positions for one image/video region (ported from HF). + + For Qwen3.5-VL, temp_merge_size is grid_thw[0] (the temporal dimension of grid_thw + already accounts for temporal_patch_size merging). + The mRoPE sections are [11, 11, 10] (temporal, height, width). + """ + llm_grid_t = grid_thw[0].item() // temp_merge_size + llm_grid_h = grid_thw[1].item() // spatial_merge_size + llm_grid_w = grid_thw[2].item() // spatial_merge_size + n_tokens = llm_grid_h * llm_grid_w * llm_grid_t + + pos_w = torch.arange(start_position, start_position + llm_grid_w, device=device) + pos_w = pos_w.repeat(llm_grid_h * llm_grid_t) + pos_h = torch.arange(start_position, start_position + llm_grid_h, device=device) + pos_h = pos_h.repeat_interleave(llm_grid_w * llm_grid_t) + pos_t = torch.full((n_tokens,), start_position, device=device, dtype=torch.long) + return torch.stack([pos_t, pos_h, pos_w], dim=0) # [3, n_tokens] + + def _compute_mrope_position_ids( + self, + input_ids_bshd: torch.Tensor, + image_grid_thw: torch.Tensor | None, + ) -> torch.Tensor: + """Compute 3D M-RoPE position IDs from input_ids in [bs, seq] format. + + Image regions are detected by looking for consecutive runs of + ``image_token_id`` in each sequence. + """ + bs, seq_len = input_ids_bshd.shape + device = input_ids_bshd.device + spatial_merge_size = self.spatial_merge_size + + position_ids = torch.zeros(3, bs, seq_len, dtype=torch.long, device=device) + + if image_grid_thw is None or image_grid_thw.numel() == 0: + # Text-only: standard 1D positions replicated across 3 dims + pos = torch.arange(seq_len, device=device).unsqueeze(0).expand(bs, -1) + position_ids[0] = pos + position_ids[1] = pos + position_ids[2] = pos + return position_ids + + grid_iter = iter(image_grid_thw) + + for b in range(bs): + ids = input_ids_bshd[b] + is_image = ids == self.image_token_id + + # Find contiguous groups: text (0) vs image (1) + token_types = is_image.long() + groups = [] + for key, group in itertools.groupby(enumerate(token_types.tolist()), lambda x: x[1]): + g = list(group) + groups.append((key, g[0][0], g[-1][0] + 1)) + + current_pos = 0 + pos_list = [] + for modality, start, end in groups: + if modality == 0: + # Text tokens + n = end - start + pos_list.append( + torch.arange(n, device=device).view(1, -1).expand(3, -1) + current_pos + ) + current_pos += n + else: + # Image tokens + grid_thw = next(grid_iter) + temp_merge_size = grid_thw[0] + vis_pos = self._get_vision_position_ids( + current_pos, + grid_thw, + temp_merge_size, + spatial_merge_size, + device, + ) + pos_list.append(vis_pos) + current_pos += max(grid_thw[1], grid_thw[2]) // spatial_merge_size + + all_pos = torch.cat(pos_list, dim=1) # [3, seq_for_this_sample] + position_ids[:, b, : all_pos.shape[1]] = all_pos + + return position_ids + + # -- forward ------------------------------------------------------------ + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor = None, + attention_mask: torch.Tensor = None, + labels: torch.Tensor = None, + loss_mask: torch.Tensor = None, + inference_params=None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + # multimodal kwargs (unpacked from multimodal_train_inputs) + pixel_values: torch.Tensor = None, + image_grid_thw: torch.Tensor = None, + # unused VL kwargs that may come through + pixel_values_videos: torch.Tensor = None, + video_grid_thw: torch.Tensor = None, + mm_token_type_ids: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + assert pixel_values_videos is None, "Video not supported yet" + assert inference_params is None, "Inference not supported" + + # -- Extract cu_seqlens and CP info early -- + cu_seqlens = None + if packed_seq_params is not None: + cu_seqlens = ( + packed_seq_params.cu_seqlens_q_padded + if packed_seq_params.cu_seqlens_q_padded is not None + else packed_seq_params.cu_seqlens_q + ) + cp_size = parallel_state.get_context_parallel_world_size() + full_input_ids = None + + combined_embeddings = None + + if self.pre_process: + # 1. Text embeddings from language model embedding layer + combined_embeddings = self.language_model.embedding( + input_ids=input_ids, + position_ids=None, + ).clone() # [seq, batch, hidden] + + # 2. Vision encoding + masked scatter + if pixel_values is not None and image_grid_thw is not None: + image_embeds = self._get_image_features(pixel_values, image_grid_thw) + image_embeds = image_embeds.to(combined_embeddings.device, combined_embeddings.dtype) + + # With CP > 1, select only the embeddings for this rank's chunk + if cp_size > 1 and cu_seqlens is not None: + full_input_ids = _gather_input_ids_from_cp(input_ids, cu_seqlens) + cp_rank = parallel_state.get_context_parallel_rank() + image_embeds = _select_local_image_embeds( + full_input_ids, + cu_seqlens, + self.image_token_id, + image_embeds, + cp_rank, + cp_size, + ) + + image_mask = (input_ids == self.image_token_id).contiguous() + # Scatter: [seq, bs, hidden] -> [bs, seq, hidden] + combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + if image_mask.any(): + combined_embeddings[image_mask] = image_embeds + combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + + # Scatter to sequence-parallel region if needed + if self.config.sequence_parallel: + combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region( + combined_embeddings + ) + combined_embeddings = combined_embeddings.contiguous() + + # 3. Compute M-RoPE position IDs + pp_size = parallel_state.get_pipeline_model_parallel_world_size() + + if position_ids is None: + if self.pre_process: + if cu_seqlens is not None: + if cp_size > 1: + if full_input_ids is None: + full_input_ids = _gather_input_ids_from_cp(input_ids, cu_seqlens) + else: + full_input_ids = input_ids + input_ids_bshd = _thd_to_bshd(full_input_ids, cu_seqlens) + pos_bshd = self._compute_mrope_position_ids(input_ids_bshd, image_grid_thw) + pos_packed = _bshd_to_thd(pos_bshd.permute(1, 2, 0), cu_seqlens) + position_ids = pos_packed.permute(2, 0, 1).contiguous() # [3, 1, T_global] + else: + position_ids = self._compute_mrope_position_ids(input_ids, image_grid_thw) + else: + # Non-first PP stage: allocate buffer with correct shape + if cu_seqlens is not None: + T = cu_seqlens[-1].item() + position_ids = torch.zeros( + 3, 1, T, dtype=torch.long, device=torch.cuda.current_device() + ) + else: + raise NotImplementedError( + "Non-THD position_ids broadcast not yet supported for non-first PP stages" + ) + + # Broadcast position_ids from first to all PP stages + if pp_size > 1: + src = parallel_state.get_pipeline_model_parallel_first_rank() + torch.distributed.broadcast( + position_ids, + src=src, + group=parallel_state.get_pipeline_model_parallel_group(), + ) + + # 4. Language model forward (pass decoder_input to skip re-embedding) + output = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + decoder_input=combined_embeddings, + labels=labels, + loss_mask=loss_mask, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + **(extra_block_kwargs or {}), + ) + + return output + + +# --------------------------------------------------------------------------- +# Model Provider (dataclass that doubles as TransformerConfig) +# --------------------------------------------------------------------------- +@dataclass +class Qwen35VLMoeVLModelProvider(Qwen3MoEModelProvider): + """Provider that creates Qwen35VLMoeVLModel. + + Inherits from Qwen3MoEModelProvider to reuse MoE + TransformerConfig infra. + Defined at module level (not inside a function) so that the class is + picklable -- megatron-bridge broadcasts config objects across PP ranks + via ``torch.distributed.broadcast_object_list`` which requires pickling. + """ + + # Qwen3.5-VL specific config + image_token_id: int = 248056 + video_token_id: int = 248057 + spatial_merge_size: int = 2 + + # Vision config (stored as HF config object) + hf_vision_config: object = None + hf_text_config: object = None + + # M-RoPE + position_embedding_type: str = "mrope" + mrope_section: list[int] = field(default_factory=lambda: [11, 11, 10]) + scatter_embedding_sequence_parallel: bool = False + + # Language model sequence length + language_max_sequence_length: int = 262144 + + def provide(self, pre_process=None, post_process=None, vp_stage=None): + """Create a Qwen35VLMoeVLModel instance.""" + from megatron.core import parallel_state as ps + + if pre_process is None: + pre_process = ps.is_pipeline_first_stage(ignore_virtual=False, vp_stage=vp_stage) + if post_process is None: + post_process = ps.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage) + + # Build per-layer specs respecting moe_layer_freq and experimental attention variant. + # When experimental_attention_variant is set (e.g. "gated_delta_net"), we must use + # get_transformer_block_with_experimental_attention_variant_spec instead of + # get_gpt_decoder_block_spec — the latter asserts that experimental_attention_variant + # is None and cannot handle hybrid GDN + SDPA architectures. + if self.experimental_attention_variant is not None: + from megatron.core.models.gpt.experimental_attention_variant_module_specs import ( + get_transformer_block_with_experimental_attention_variant_spec, + ) + transformer_layer_spec = get_transformer_block_with_experimental_attention_variant_spec( + config=self, + vp_stage=vp_stage, + ) + else: + transformer_layer_spec = get_gpt_decoder_block_spec( + config=self, + use_transformer_engine=True, + vp_stage=vp_stage, + ) + + model = Qwen35VLMoeVLModel( + language_transformer_config=self, + language_transformer_layer_spec=transformer_layer_spec, + hf_vision_config=self.hf_vision_config, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + ) + + return model + + +# --------------------------------------------------------------------------- +# Bridge +# --------------------------------------------------------------------------- +try: + from transformers import Qwen3_5MoeForConditionalGeneration as _Qwen35MoeHF +except ImportError: + _Qwen35MoeHF = "Qwen3_5MoeForConditionalGeneration" + + +@MegatronModelBridge.register_bridge(source=_Qwen35MoeHF, target=Qwen35VLMoeVLModel) +class Qwen35VLMoeBridge(_OfficialQwen35VLMoEBridge): + """Bridge between HuggingFace Qwen3.5-VL MoE and our custom Megatron VL model. + + Inherits from the official ``Qwen35VLMoEBridge`` to reuse its mature weight + mappings for the language model (GDN, MoE experts, shared experts, attention, + MTP, etc.) and its ``maybe_modify_converted_hf_weight`` for expert weight + merging. + + We override: + - ``mapping_registry()``: Replace the official vision model mappings (which + target Megatron-native ``Qwen3VLModel`` vision encoder paths like + ``vision_model.decoder.layers.*``) with a simple wildcard + ``ReplicatedMapping("vision_model.**", "model.visual.**")`` because we + use an HF ``Qwen3_5MoeVisionModel`` whose parameter names already match + HF conventions. + - ``provider_bridge()``: Create our ``Qwen35VLMoeVLModelProvider`` instead + of the official ``Qwen35VLMoEModelProvider``, since we wrap an HF vision + encoder + Megatron GPTModel instead of the fully Megatron-native + ``Qwen3VLModel``. + - ``maybe_modify_converted_hf_weight()``: Same logic as the parent but with + CPU offloading to avoid GPU OOM when concatenating large expert tensors. + """ + + # ------------------------------------------------------------------ + # Vision model mapping replacement + # ------------------------------------------------------------------ + # The official bridge maps Megatron-native vision model params like: + # vision_model.decoder.layers.*.self_attention.linear_qkv.weight + # vision_model.decoder.layers.*.mlp.linear_fc1.weight + # vision_model.merger.*.weight + # vision_model.patch_embed.proj.* + # + # Our ``Qwen35VLMoeVLModel`` uses an HF ``Qwen3_5MoeVisionModel`` directly, + # so its parameter names follow HF conventions (e.g. ``vision_model.blocks.0.attn.qkv.weight``). + # We need to strip "vision_model." and replace with "model.visual." — a simple + # wildcard mapping handles this. + _VISION_MEGATRON_PREFIX = "vision_model." + _VISION_HF_PREFIX = "model.visual." + + # ------------------------------------------------------------------ + # Store HF keys before mapping_registry() is called + # ------------------------------------------------------------------ + def build_conversion_tasks(self, hf_pretrained, megatron_model): + """Override to store HF config/keys before mapping_registry is called. + + We need access to the HF checkpoint key names to determine the expert + weight format (fused vs per-expert) at mapping_registry() time. + """ + self._hf_config = hf_pretrained.config + self._hf_state_source = hf_pretrained.state.source + self._hf_keys = list(self._hf_state_source.get_all_keys()) + return super().build_conversion_tasks(hf_pretrained, megatron_model) + + def _uses_fused_experts(self) -> bool: + """Check whether the HF checkpoint uses fused expert format. + + Qwen3.5 MoE models store expert weights in two possible formats: + + 1. **Fused format** (e.g. 35B-A3B with 256 experts): + - ``model.language_model.layers.*.mlp.experts.gate_up_proj`` + shape: [num_experts, 2*intermediate_size, hidden_size] + - ``model.language_model.layers.*.mlp.experts.down_proj`` + shape: [num_experts, hidden_size, intermediate_size] + + 2. **Per-expert format** (e.g. 397B-A17B with 512 experts): + - ``model.language_model.layers.*.mlp.experts.*.gate_proj.weight`` + - ``model.language_model.layers.*.mlp.experts.*.up_proj.weight`` + - ``model.language_model.layers.*.mlp.experts.*.down_proj.weight`` + + Returns True if the checkpoint uses the fused format. + """ + hf_keys = getattr(self, "_hf_keys", None) + if hf_keys: + if any("mlp.experts.gate_up_proj" in key for key in hf_keys) or any( + "mlp.experts.down_proj" in key for key in hf_keys + ): + return True + + hf_source = getattr(self, "_hf_state_source", None) + if hf_source is not None: + return hf_source.has_glob("*mlp.experts.gate_up_proj*") or hf_source.has_glob("*mlp.experts.down_proj*") + + # Default: assume fused format (backward compatible with 35B-A3B) + return True + + def mapping_registry(self) -> MegatronMappingRegistry: + """Build weight mappings by reusing the official bridge and replacing vision/expert mappings. + + Calls ``super().mapping_registry()`` to get all language model mappings + (GDN, MoE, attention, MTP, etc.), then: + + 1. Filters out the official vision model mappings and adds our wildcard + ``ReplicatedMapping`` instead (for our HF vision encoder). + + 2. If the HF checkpoint uses **per-expert format** (e.g. 397B-A17B with + 512 experts stores ``experts.*.gate_proj.weight`` instead of + ``experts.gate_up_proj``), replaces the official ``ExpertMLPGateUpProjMapping`` + and ``ExpertMLPDownProjMapping`` with per-expert style ``GatedMLPMapping`` + and ``AutoMapping``. + + IMPORTANT: We must construct a *new* ``MegatronMappingRegistry`` from the + filtered mapping list rather than mutating the old registry's ``.mappings`` + attribute, because the registry pre-compiles regex patterns (``_compiled_patterns``, + ``_reverse_patterns``) during ``__init__`` and does NOT re-compile them when + ``.mappings`` is replaced. Mutating would leave stale patterns and cause + lookup failures (params not matched → ``None`` tasks → crash in + ``load_weights_hf_to_megatron``). + """ + registry = super().mapping_registry() + + use_fused = self._uses_fused_experts() + + # Filter out official vision model mappings — they target + # Megatron-native Qwen3VLModel vision encoder paths which don't + # match our HF vision encoder's parameter names. + # + # Also, if using per-expert format, filter out the official + # ExpertMLPGateUpProjMapping / ExpertMLPDownProjMapping and replace + # them with per-expert mappings below. + _EXPERT_MEGATRON_PREFIXES_FOR_PER_EXPERT = ( + "language_model.decoder.layers.*.mlp.experts.linear_fc1.weight", + "language_model.decoder.layers.*.mlp.experts.linear_fc2.weight", + # MTP expert mappings + "language_model.mtp.layers.*.mtp_model_layer.mlp.experts.linear_fc1.weight", + "language_model.mtp.layers.*.mtp_model_layer.mlp.experts.linear_fc2.weight", + ) + + filtered_mappings = [] + for m in registry.mappings: + # Remove vision model mappings + if m.megatron_param.startswith(self._VISION_MEGATRON_PREFIX): + continue + # If per-expert format, remove fused expert mappings for decoder layers + if not use_fused and m.megatron_param.startswith(_EXPERT_MEGATRON_PREFIXES_FOR_PER_EXPERT): + continue + filtered_mappings.append(m) + + # Add our wildcard mapping for the HF vision encoder. + # This maps ``vision_model.**`` → ``model.visual.**`` one-to-one, + # which works because our vision encoder is the HF model directly. + filtered_mappings.append( + ReplicatedMapping( + megatron_param="vision_model.**", + hf_param="model.visual.**", + ) + ) + + # Add per-expert format mappings if the checkpoint doesn't use fused format + if not use_fused: + logger.info( + "Detected per-expert HF weight format (e.g. experts.*.gate_proj.weight). " + "Using per-expert GatedMLPMapping/AutoMapping instead of " + "ExpertMLPGateUpProjMapping/ExpertMLPDownProjMapping." + ) + filtered_mappings.extend( + [ + # Per-expert gate+up projection → fused linear_fc1 + GatedMLPMapping( + megatron_param="language_model.decoder.layers.*.mlp.experts.linear_fc1.weight*", + gate="model.language_model.layers.*.mlp.experts.*.gate_proj.weight", + up="model.language_model.layers.*.mlp.experts.*.up_proj.weight", + ), + # Per-expert down projection → linear_fc2 + AutoMapping( + megatron_param="language_model.decoder.layers.*.mlp.experts.linear_fc2.weight*", + hf_param="model.language_model.layers.*.mlp.experts.*.down_proj.weight", + ), + # MTP per-expert mappings + GatedMLPMapping( + megatron_param="language_model.mtp.layers.*.mtp_model_layer.mlp.experts.linear_fc1.weight*", + gate="mtp.layers.*.mlp.experts.*.gate_proj.weight", + up="mtp.layers.*.mlp.experts.*.up_proj.weight", + ), + AutoMapping( + megatron_param="language_model.mtp.layers.*.mtp_model_layer.mlp.experts.linear_fc2.weight*", + hf_param="mtp.layers.*.mlp.experts.*.down_proj.weight", + ), + ] + ) + + # Construct a NEW registry so that _compiled_patterns and + # _reverse_patterns are rebuilt with the updated mapping list. + return MegatronMappingRegistry(*filtered_mappings) + + # ------------------------------------------------------------------ + # Expert weight merging with CPU offloading + # ------------------------------------------------------------------ + def maybe_modify_converted_hf_weight( + self, + task: WeightConversionTask, + converted_weights_dict: Dict[str, torch.Tensor], + hf_state_dict: Mapping, + ) -> Dict[str, torch.Tensor]: + """Merge per-expert weight exports into a single fused [num_experts, ...] tensor. + + For **fused HF format** (e.g. 35B model, ``experts.gate_up_proj``), all experts + share a single HF key. This method caches each expert's contribution and merges + them when all ``num_experts`` entries are collected. + + For **per-expert HF format** (e.g. 397B model, ``experts.0.gate_proj.weight``, + ``experts.1.gate_proj.weight``, …), each expert already has a unique HF key + produced by the mapping. The merging logic does not apply — we must return the + converted weights directly, otherwise they are silently dropped (each key only + ever receives a single entry so the cache-merge path is never triggered). + + The parent class's implementation only supports the fused case. We detect the + format via ``_uses_fused_experts()`` and skip merging for per-expert format. + """ + # Per-expert HF format: each expert has a unique key (e.g. + # ``experts.0.gate_proj.weight``). The parent's caching/merging logic + # assumes all experts share one key, which causes every expert weight to be + # cached under its own unique key and never merged (the ``len == num_experts`` + # guard is never satisfied), effectively **dropping all expert weights**. + # Return directly — no merging needed. + if not self._uses_fused_experts(): + return converted_weights_dict + + # Fused HF format: fall through to the parent's caching/merging logic. + num_experts = self.hf_config.text_config.num_experts + ep_size = parallel_state.get_expert_model_parallel_world_size() + experts_per_rank = num_experts // ep_size + + try: + local_expert_number = extract_expert_number_from_param(task.param_name) % experts_per_rank + except ValueError: + # Not an expert parameter — pass through unchanged. + return converted_weights_dict + + # Detect if EP gathering was already done by the mapping (e.g. GatedMLPMapping + # with is_expert=True calls gather_from_ep_ranks internally). + if ep_size > 1: + expert_ids_in_dict = set() + for key in converted_weights_dict: + try: + expert_ids_in_dict.add(extract_expert_number_from_param(key)) + except ValueError: + pass + if len(expert_ids_in_dict) > 1: + return converted_weights_dict + + result: Dict[str, torch.Tensor] = {} + for key, value in converted_weights_dict.items(): + if key not in self.hf_weights_cache: + self.hf_weights_cache[key] = {} + + # Move to CPU to avoid GPU OOM when concatenating large expert tensors + value = value.cpu() + + if ep_size == 1: + self.hf_weights_cache[key][local_expert_number] = value + else: + assert value.shape[0] == ep_size, ( + f"Expected shape[0]=={ep_size} for EP-gathered expert weight " + f"'{key}', got {value.shape}" + ) + for i, exp_val in enumerate(value): + global_expert_number = local_expert_number + (i * experts_per_rank) + self.hf_weights_cache[key][global_expert_number] = exp_val + + if len(self.hf_weights_cache[key]) == num_experts: + logger.debug(f"All {num_experts} experts loaded for {key}") + merged = torch.cat( + [self.hf_weights_cache[key][i].unsqueeze(0) for i in range(num_experts)], + dim=0, + ) + del self.hf_weights_cache[key] + # Move back to CUDA for downstream processing + result[key] = merged.cuda() + else: + logger.debug( + f"{len(self.hf_weights_cache[key])}/{num_experts} experts " + f"loaded for {key}" + ) + + return result + + def provider_bridge(self, hf_pretrained): + """Create a Qwen35VLMoeVLModelProvider from HF config.""" + hf_config = hf_pretrained.config + text_config = hf_config.text_config + vision_config = deepcopy(hf_config.vision_config) + + model_dtype = self.dtype_from_hf(text_config, default=torch.bfloat16) + vision_config.torch_dtype = model_dtype + + ProviderClass = Qwen35VLMoeVLModelProvider + + rope_params = getattr(text_config, "rope_parameters", {}) or {} + mrope_section = rope_params.get("mrope_section", [11, 11, 10]) + rotary_base = rope_params.get("rope_theta", 10000000) + partial_rotary_factor = rope_params.get("partial_rotary_factor", 0.25) + + # Determine MoE layer frequency + # Qwen3.5-VL MoE uses all layers as MoE (no dense layer) + first_k_dense = getattr(text_config, "first_k_dense_replace", 0) + num_layers = text_config.num_hidden_layers + moe_layer_freq_list = [0] * first_k_dense + [1] * (num_layers - first_k_dense) + + # Shared expert intermediate size + moe_ffn = getattr(text_config, "moe_intermediate_size", 512) + shared_expert_intermediate = getattr( + text_config, "shared_expert_intermediate_size", 512 + ) + + # Read attention bias from config + add_qkv_bias = getattr(text_config, "attention_bias", False) + + # QK layernorm + qk_layernorm = True + + # head_dim + head_dim = getattr(text_config, "head_dim", 256) + + # Qwen3.5 MoE text_config has no intermediate_size; use shared_expert_intermediate_size + # as the dense-FFN fallback (dense layers don't exist when first_k_dense=0, but the + # TransformerConfig still requires ffn_hidden_size). + ffn_hidden_size = getattr(text_config, "intermediate_size", None) + if ffn_hidden_size is None: + ffn_hidden_size = shared_expert_intermediate + + provider = ProviderClass( + # Language model configuration + num_layers=num_layers, + hidden_size=text_config.hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=text_config.num_attention_heads, + num_query_groups=text_config.num_key_value_heads, + kv_channels=head_dim, + init_method_std=text_config.initializer_range, + layernorm_epsilon=text_config.rms_norm_eps, + normalization="RMSNorm", + layernorm_zero_centered_gamma=True, + gated_linear_unit=True, + make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(text_config.vocab_size), + rotary_base=rotary_base, + rotary_percent=partial_rotary_factor, + share_embeddings_and_output_weights=getattr(hf_config, "tie_word_embeddings", False), + vocab_size=text_config.vocab_size, + seq_length=getattr(text_config, "max_position_embeddings", 262144), + fp16=(model_dtype == torch.float16), + bf16=(model_dtype == torch.bfloat16), + params_dtype=model_dtype, + # MoE configuration + num_moe_experts=getattr(text_config, "num_experts", 256), + moe_router_topk=getattr(text_config, "num_experts_per_tok", 8), + moe_ffn_hidden_size=moe_ffn, + moe_shared_expert_intermediate_size=shared_expert_intermediate, + moe_layer_freq=moe_layer_freq_list, + moe_grouped_gemm=True, + moe_router_load_balancing_type="global_aux_loss", + moe_aux_loss_coeff=getattr(text_config, "router_aux_loss_coef", 0.001), + moe_router_pre_softmax=False, + moe_router_score_function="softmax", + moe_router_dtype="fp32", + moe_token_dispatcher_type="alltoall", + moe_permute_fusion=True, + # Attention + add_qkv_bias=add_qkv_bias, + add_bias_linear=False, + qk_layernorm=qk_layernorm, + # Attention output gate (Qwen3.5 specific) + attention_output_gate=True, + # Shared expert gate + moe_shared_expert_gate=True, + # GDN (Gated DeltaNet) — attention variant and layer pattern + # Must set experimental_attention_variant="gated_delta_net" so that + # Megatron builds GDN layers (GatedDeltaNet modules) instead of + # standard SDPA attention for the layers marked as linear_attention. + experimental_attention_variant="gated_delta_net", + # Qwen3.5-VL MoE uses a hybrid GDN + full-attention architecture. + # The HF config stores per-layer types in ``layer_types`` list + # (e.g. ["linear_attention","linear_attention","linear_attention","full_attention", ...]) + # rather than a scalar ``full_attention_interval``. + # Megatron expects ``linear_attention_freq`` as: + # int N -> one SDPA layer every N layers + # list -> per-layer pattern: 1=linear_attention(GDN), 0=full_attention(SDPA) + # We convert from HF's ``layer_types`` strings to Megatron's int list. + linear_attention_freq=( + [1 if lt == "linear_attention" else 0 for lt in text_config.layer_types] + if getattr(text_config, "layer_types", None) is not None + else getattr(text_config, "full_attention_interval", None) + ), + linear_conv_kernel_dim=getattr(text_config, "linear_conv_kernel_dim", None), + linear_key_head_dim=getattr(text_config, "linear_key_head_dim", None), + linear_value_head_dim=getattr(text_config, "linear_value_head_dim", None), + linear_num_key_heads=getattr(text_config, "linear_num_key_heads", None), + linear_num_value_heads=getattr(text_config, "linear_num_value_heads", None), + # M-RoPE + mrope_section=mrope_section, + position_embedding_type="mrope", + scatter_embedding_sequence_parallel=False, + # Vision + hf_vision_config=vision_config, + hf_text_config=text_config, + image_token_id=getattr(hf_config, "image_token_id", 248056), + video_token_id=getattr(hf_config, "video_token_id", 248057), + spatial_merge_size=getattr(hf_config.vision_config, "spatial_merge_size", 2), + language_max_sequence_length=getattr(text_config, "max_position_embeddings", 262144), + ) + + return provider + +# Apply patches at module load time so that they are active whenever this +# bridge module is imported, regardless of import order. +_patch_auto_mapping_for_gdn() \ No newline at end of file diff --git a/tools/convert_hf_to_torch_dist.py b/tools/convert_hf_to_torch_dist.py index 8d2758947e..7a63ec60b8 100644 --- a/tools/convert_hf_to_torch_dist.py +++ b/tools/convert_hf_to_torch_dist.py @@ -9,8 +9,6 @@ from megatron.training.checkpointing import get_checkpoint_name, get_checkpoint_tracker_filename, save_checkpoint from megatron.training.training import get_model -import slime_plugins.mbridge # noqa: F401 -from mbridge import AutoBridge from slime.backends.megatron_utils.arguments import set_default_megatron_args from slime.backends.megatron_utils.initialize import init from slime.backends.megatron_utils.model_provider import get_model_provider_func @@ -35,7 +33,7 @@ def add_convertion_args(parser): def get_args(): - args = parse_args(add_convertion_args) + args = parse_args(add_convertion_args, ignore_unknown_args=True) args = set_default_megatron_args(args) # set to pass megatron validate_args @@ -113,11 +111,31 @@ def main(): model = get_model(get_model_provider_func(args), ModelType.encoder_or_decoder, wrap_with_ddp=False) - # Load model + # Load model weights hf_model_path = args.hf_checkpoint - bridge = AutoBridge.from_pretrained(hf_model_path, trust_remote_code=True) - bridge.load_weights(model, hf_model_path, memory_efficient=True) - print(f"Model loaded: {hf_model_path}") + if args.megatron_to_hf_mode == "bridge": + # Bridge mode: use megatron.bridge for weight loading + # The bridge was already created inside get_model_provider_func; we need + # a fresh one here for weight loading only. + from megatron.bridge import AutoBridge as MegatronAutoBridge + + import slime_plugins.megatron_bridge # noqa: F401 # register custom bridges + from slime.utils.megatron_bridge_utils import patch_auto_bridge_hf_config + + bridge = patch_auto_bridge_hf_config( + MegatronAutoBridge.from_hf_pretrained(hf_model_path, trust_remote_code=True) + ) + bridge.load_hf_weights(model) + print(f"Model loaded (bridge mode): {hf_model_path}") + else: + # Raw mode: use mbridge for weight loading + import slime_plugins.mbridge # noqa: F401 + + from mbridge import AutoBridge + + bridge = AutoBridge.from_pretrained(hf_model_path, trust_remote_code=True) + bridge.load_weights(model, hf_model_path, memory_efficient=True) + print(f"Model loaded (raw mode): {hf_model_path}") if args.use_cpu_initialization: model[0] = model[0].cpu() From 1de44fd8559e7dc7ddd9b4dab88faf32ca250ec5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=87=95=E4=B8=80?= Date: Wed, 10 Jun 2026 17:47:58 +0800 Subject: [PATCH 06/14] feat: improve model conversion tools (Megatron <-> HuggingFace) - Add Qwen3.5 MoE bridge conversion support in mbridge plugin - Add parallel distributed conversion tool (convert_torch_dist_to_hf_parallel.py) - Add merge_missing_keys.py for handling partial checkpoint merges - Fix megatron_to_hf conversion for Qwen3.5 architecture - Fix convert_torch_dist_to_hf_bridge.py quantization support --- .../run-qwen35-35B-A3B-mopd-topk-sglang.sh | 1 + .../megatron_utils/megatron_to_hf/__init__.py | 11 +- slime_plugins/mbridge/qwen3_5.py | 110 +++++++++++- tools/convert_torch_dist_to_hf_bridge.py | 1 + tools/convert_torch_dist_to_hf_parallel.py | 112 +++++++++++++ tools/merge_missing_keys.py | 156 ++++++++++++++++++ 6 files changed, 385 insertions(+), 6 deletions(-) create mode 100644 tools/merge_missing_keys.py diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-topk-sglang.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-topk-sglang.sh index f51aa12e80..42a9e92e63 100644 --- a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-topk-sglang.sh +++ b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-topk-sglang.sh @@ -67,6 +67,7 @@ ROLLOUT_ARGS=( --rollout-shuffle --rollout-batch-size 4 --n-samples-per-prompt 4 + --rollout-max-prompt-len 9216 --rollout-max-response-len 2048 --rollout-temperature 0.8 diff --git a/slime/backends/megatron_utils/megatron_to_hf/__init__.py b/slime/backends/megatron_utils/megatron_to_hf/__init__.py index d6cccc23f3..42fcf655df 100644 --- a/slime/backends/megatron_utils/megatron_to_hf/__init__.py +++ b/slime/backends/megatron_utils/megatron_to_hf/__init__.py @@ -42,12 +42,15 @@ def _convert_to_hf_core(args, model_name, name, param): converted_named_tensors = convert_glm4_to_hf(args, name, param) elif "gpt_oss" in model_name or "gpt-oss" in model_name or "gptoss" in model_name: converted_named_tensors = convert_gpt_oss_to_hf(args, name, param) - elif "qwen3moe" in model_name: - converted_named_tensors = convert_qwen3moe_to_hf(args, name, param) - elif "qwen3next" in model_name: - converted_named_tensors = convert_qwen3_next_to_hf(args, name, param) elif "qwen3_5" in model_name: + # Must match before "qwen3moe" since "qwen3_5moe" contains neither + # "qwen3moe" as a substring, but if someone passes --model-name qwen3moe + # for a Qwen3.5-MoE model, they should use qwen3_5 instead. converted_named_tensors = convert_qwen3_5_to_hf(args, name, param) + elif "qwen3next" in model_name: + converted_named_tensors = convert_qwen3_next_to_hf(args, name, param) + elif "qwen3moe" in model_name: + converted_named_tensors = convert_qwen3moe_to_hf(args, name, param) elif "qwen3vl" in model_name: converted_named_tensors = convert_qwen3vl_to_hf(args, name, param) elif "qwen2" in model_name or "qwen3" in model_name: diff --git a/slime_plugins/mbridge/qwen3_5.py b/slime_plugins/mbridge/qwen3_5.py index d01094ecc9..44393bd5ab 100644 --- a/slime_plugins/mbridge/qwen3_5.py +++ b/slime_plugins/mbridge/qwen3_5.py @@ -13,6 +13,11 @@ class Qwen3_5Bridge(Qwen2MoEBridge): Bridge for Qwen3.5 models (both dense and MoE variants). Qwen3.5 is a VLM model with weights under model.language_model.layers prefix, separate in_proj_qkv + in_proj_z for linear attention, and nested text_config. + + MoE expert weights can be stored in two formats in HF checkpoints: + - Fused format: gate_up_proj / down_proj as a single 3D tensor [num_experts, ...] + - Per-expert format: {expert_id}.gate_proj.weight / {expert_id}.up_proj.weight / {expert_id}.down_proj.weight + This bridge auto-detects the format from the checkpoint's safetensors index. """ _DIRECT_MAPPING = { @@ -62,7 +67,8 @@ class Qwen3_5Bridge(Qwen2MoEBridge): ] } - _MLP_MAPPING = { + # Fused expert format: single 3D tensor for all experts + _MLP_MAPPING_FUSED_EXPERTS = { "mlp.linear_fc1.weight": [ "model.language_model.layers.{layer_number}.mlp.gate_proj.weight", "model.language_model.layers.{layer_number}.mlp.up_proj.weight", @@ -82,13 +88,45 @@ class Qwen3_5Bridge(Qwen2MoEBridge): ], "mlp.router.weight": ["model.language_model.layers.{layer_number}.mlp.gate.weight"], "shared_experts.gate_weight": ["model.language_model.layers.{layer_number}.mlp.shared_expert_gate.weight"], - # Fused expert format: single 3D tensor for all experts "mlp.experts.linear_fc1": [ "model.language_model.layers.{layer_number}.mlp.experts.gate_up_proj", ], "mlp.experts.linear_fc2": ["model.language_model.layers.{layer_number}.mlp.experts.down_proj"], } + # Per-expert format: separate tensors per expert (e.g., Qwen3.5-397B checkpoints) + _MLP_MAPPING_PER_EXPERT = { + "mlp.linear_fc1.weight": [ + "model.language_model.layers.{layer_number}.mlp.gate_proj.weight", + "model.language_model.layers.{layer_number}.mlp.up_proj.weight", + ], + "mlp.linear_fc1.layer_norm_weight": [ + "model.language_model.layers.{layer_number}.post_attention_layernorm.weight" + ], + "mlp.linear_fc2.weight": ["model.language_model.layers.{layer_number}.mlp.down_proj.weight"], + # MoE mappings + "shared_experts.linear_fc1.weight": [ + "model.language_model.layers.{layer_number}.mlp.shared_expert.gate_proj.weight", + "model.language_model.layers.{layer_number}.mlp.shared_expert.up_proj.weight", + ], + "pre_mlp_layernorm": ["model.language_model.layers.{layer_number}.post_attention_layernorm.weight"], + "shared_experts.linear_fc2.weight": [ + "model.language_model.layers.{layer_number}.mlp.shared_expert.down_proj.weight" + ], + "mlp.router.weight": ["model.language_model.layers.{layer_number}.mlp.gate.weight"], + "shared_experts.gate_weight": ["model.language_model.layers.{layer_number}.mlp.shared_expert_gate.weight"], + "mlp.experts.linear_fc1": [ + "model.language_model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight", + "model.language_model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight", + ], + "mlp.experts.linear_fc2": [ + "model.language_model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight", + ], + } + + # Default: use fused format (backward compatible) + _MLP_MAPPING = _MLP_MAPPING_FUSED_EXPERTS + # MTP layer uses individual expert format (not fused) _MTP_MLP_MAPPING = { "mlp.experts.linear_fc1": [ @@ -126,6 +164,74 @@ def _adjust_mapping_for_shared_weights(self): self._DIRECT_MAPPING = dict(self._DIRECT_MAPPING) self._DIRECT_MAPPING["output_layer.weight"] = "model.language_model.embed_tokens.weight" + def _detect_expert_weight_format(self, weights_path: str) -> None: + """Auto-detect whether the HF checkpoint uses fused or per-expert format for MoE weights. + + Fused format: model.language_model.layers.{N}.mlp.experts.gate_up_proj + Per-expert format: model.language_model.layers.{N}.mlp.experts.{E}.gate_proj.weight + + This sets self._MLP_MAPPING accordingly. + """ + has_num_experts = hasattr(self._get_text_config(), "num_experts") + if not has_num_experts: + return + + # Resolve the actual local path (handles HF hub IDs via _get_actual_hf_path) + try: + actual_path = self._get_actual_hf_path(weights_path) + except Exception: + actual_path = weights_path + + # Check the safetensors index for the weight format + import json + import os + + index_file = os.path.join(actual_path, "model.safetensors.index.json") + if os.path.exists(index_file): + with open(index_file, "r") as f: + index = json.load(f) + weight_map = index.get("weight_map", {}) + # Check if fused key exists + fused_key = "model.language_model.layers.0.mlp.experts.gate_up_proj" + per_expert_key = "model.language_model.layers.0.mlp.experts.0.gate_proj.weight" + if fused_key in weight_map: + self._MLP_MAPPING = self._MLP_MAPPING_FUSED_EXPERTS + elif per_expert_key in weight_map: + self._MLP_MAPPING = self._MLP_MAPPING_PER_EXPERT + else: + # Fallback: scan all keys for any expert weight pattern + for key in weight_map: + if ".mlp.experts.gate_up_proj" in key: + self._MLP_MAPPING = self._MLP_MAPPING_FUSED_EXPERTS + return + if ".mlp.experts.0.gate_proj.weight" in key: + self._MLP_MAPPING = self._MLP_MAPPING_PER_EXPERT + return + # Default to fused if no expert keys found (may not be MoE) + self._MLP_MAPPING = self._MLP_MAPPING_FUSED_EXPERTS + else: + # No index file; scan safetensors files directly + from glob import glob + from safetensors import safe_open + + safetensor_files = glob(os.path.join(actual_path, "*.safetensors")) + for sf in safetensor_files: + with safe_open(sf, framework="pt", device="cpu") as f: + keys = f.keys() + for key in keys: + if ".mlp.experts.gate_up_proj" in key: + self._MLP_MAPPING = self._MLP_MAPPING_FUSED_EXPERTS + return + if ".mlp.experts.0.gate_proj.weight" in key: + self._MLP_MAPPING = self._MLP_MAPPING_PER_EXPERT + return + self._MLP_MAPPING = self._MLP_MAPPING_FUSED_EXPERTS + + def load_weights(self, models, weights_path, memory_efficient=False): + """Override to auto-detect expert weight format before loading.""" + self._detect_expert_weight_format(weights_path) + return super().load_weights(models, weights_path, memory_efficient) + def _supports_transformer_config_kwarg(self, kwarg_name: str) -> bool: """Check whether the current TransformerConfig accepts a given kwarg.""" transformer_config_class = getattr(self, "TransformerConfigClass", None) diff --git a/tools/convert_torch_dist_to_hf_bridge.py b/tools/convert_torch_dist_to_hf_bridge.py index 798503f218..2da2de4cf2 100644 --- a/tools/convert_torch_dist_to_hf_bridge.py +++ b/tools/convert_torch_dist_to_hf_bridge.py @@ -4,6 +4,7 @@ import megatron.bridge.training.model_load_save as _model_load_save_module from megatron.bridge import AutoBridge +import slime_plugins.megatron_bridge # noqa: F401 # register custom bridges before AutoBridge from slime.utils.megatron_bridge_utils import patch_auto_bridge_hf_config diff --git a/tools/convert_torch_dist_to_hf_parallel.py b/tools/convert_torch_dist_to_hf_parallel.py index 763254d42c..5c1378ed89 100644 --- a/tools/convert_torch_dist_to_hf_parallel.py +++ b/tools/convert_torch_dist_to_hf_parallel.py @@ -340,6 +340,114 @@ def save_file(i, tensors): print(f"{filename} saved in {elapsed:.2f} sec.") +def _merge_missing_keys_from_origin_hf(origin_hf_dir, output_dir, converted_weight_map, chunk_size): + """Merge missing keys from the original HF model into the converted checkpoint. + + For VLM models, the Megatron checkpoint only contains the language model weights. + The visual encoder weights (e.g., model.visual.*) must be copied from the original + HF checkpoint to ensure the model can load correctly for inference. + """ + origin_index_path = os.path.join(origin_hf_dir, "model.safetensors.index.json") + if not os.path.exists(origin_index_path): + print("No model.safetensors.index.json found in origin HF dir. Skipping missing key merge.") + return + + with open(origin_index_path) as f: + origin_index = json.load(f) + + origin_keys = set(origin_index["weight_map"].keys()) + converted_keys = set(converted_weight_map.keys()) + missing_keys = sorted(origin_keys - converted_keys) + + if not missing_keys: + print("No missing keys detected. Skipping missing key merge.") + return + + print(f"Found {len(missing_keys)} missing keys in converted checkpoint. Merging from origin HF model.") + + # Identify which safetensors files from the origin contain missing keys + # Group missing keys by their source file + missing_by_file = {} + for key in missing_keys: + src_file = origin_index["weight_map"][key] + if src_file not in missing_by_file: + missing_by_file[src_file] = [] + missing_by_file[src_file].append(key) + + # Load missing tensors from origin HF safetensors + missing_tensors = {} + for src_file, keys in tqdm(missing_by_file.items(), desc="Loading missing keys from origin HF"): + src_path = os.path.join(origin_hf_dir, src_file) + if not os.path.exists(src_path): + print(f"Warning: {src_path} not found. Skipping keys: {keys}") + continue + from safetensors import safe_open + with safe_open(src_path, framework="pt", device="cpu") as f: + for key in keys: + missing_tensors[key] = f.get_tensor(key) + + # Now we need to insert these tensors into the existing safetensors files. + # Strategy: find the last safetensors file, add the missing tensors there, + # or create a new file if it would exceed chunk_size. + total_files = max( + int(v.split("-")[-2]) for v in converted_weight_map.values() + ) + # Re-number files to include the missing keys in a new shard + # First, collect existing tensors from the last file and append missing ones + last_file_pattern = f"model-{total_files:05d}-of-{total_files:05d}.safetensors" + last_file_path = os.path.join(output_dir, last_file_pattern) + + existing_last_tensors = {} + if os.path.exists(last_file_path): + existing_last_tensors = safetensors.torch.load_file(last_file_path) + + # Calculate sizes + existing_size = sum(t.numel() * t.element_size() for t in existing_last_tensors.values()) + missing_size = sum(t.numel() * t.element_size() for t in missing_tensors.values()) + + if existing_size + missing_size <= chunk_size: + # Fits in the last file - just append + combined_tensors = {**existing_last_tensors, **missing_tensors} + safetensors.torch.save_file(combined_tensors, last_file_path) + for key in missing_tensors: + converted_weight_map[key] = last_file_pattern + else: + # Need a new shard + new_total = total_files + 1 + # Save missing tensors to a new file with updated numbering + new_file_name = f"model-{new_total:05d}-of-{new_total:05d}.safetensors" + safetensors.torch.save_file(missing_tensors, os.path.join(output_dir, new_file_name)) + for key in missing_tensors: + converted_weight_map[key] = new_file_name + + # Rename all existing files to update the total count + for i in range(1, total_files + 1): + old_name = f"model-{i:05d}-of-{total_files:05d}.safetensors" + new_name = f"model-{i:05d}-of-{new_total:05d}.safetensors" + old_path = os.path.join(output_dir, old_name) + new_path = os.path.join(output_dir, new_name) + if os.path.exists(old_path): + shutil.move(old_path, new_path) + # Update weight map references + for k, v in converted_weight_map.items(): + if v == old_name: + converted_weight_map[k] = new_name + + # Update total_size in metadata + new_total_size = sum(t.numel() * t.element_size() for t in missing_tensors.values()) + index_data = json.load(open(os.path.join(output_dir, "model.safetensors.index.json"))) + index_data["metadata"]["total_size"] = index_data["metadata"].get("total_size", 0) + new_total_size + index_data["weight_map"] = converted_weight_map + with open(os.path.join(output_dir, "model.safetensors.index.json"), "w") as f: + json.dump(index_data, f, indent=2) + + print(f"Successfully merged {len(missing_tensors)} missing keys into the converted checkpoint.") + for key in missing_keys[:10]: + print(f" + {key}") + if len(missing_keys) > 10: + print(f" ... and {len(missing_keys) - 10} more") + + def copy_assets(origin_hf_dir, output_dir): for filename in os.listdir(origin_hf_dir): if filename == "model.safetensors.index.json" or filename.endswith(".safetensors"): @@ -574,5 +682,9 @@ def conversion_worker( json.dump(index_data, open(os.path.join(args.output_dir, "model.safetensors.index.json"), "w"), indent=2) print("Model converted and saved.") + # Merge missing keys from the original HF model (e.g., visual encoder weights for VLM models) + # These keys exist in the original HF checkpoint but are not present in the Megatron checkpoint, + # because Megatron only trains the language model part. if args.origin_hf_dir: + _merge_missing_keys_from_origin_hf(args.origin_hf_dir, args.output_dir, final_weight_map_fixed, args.chunk_size) copy_assets(args.origin_hf_dir, args.output_dir) diff --git a/tools/merge_missing_keys.py b/tools/merge_missing_keys.py new file mode 100644 index 0000000000..0cf26abb70 --- /dev/null +++ b/tools/merge_missing_keys.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 +"""Merge missing keys (e.g., visual encoder weights) from the original HF model +into a converted Megatron-to-HF checkpoint. + +This is needed for VLM models like Qwen3.5-397B-A17B where the Megatron checkpoint +only contains the language model weights, and the visual encoder weights must be +copied from the original HF checkpoint. + +Usage: + python merge_missing_keys.py \ + --origin-hf-dir /path/to/original/Qwen3.5-397B-A17B \ + --converted-dir /path/to/converted/checkpoint \ + [--dry-run] +""" + +import argparse +import json +import os +import shutil + +import safetensors.torch +from safetensors import safe_open +from tqdm import tqdm + + +def main(): + parser = argparse.ArgumentParser(description="Merge missing keys from original HF model into converted checkpoint") + parser.add_argument("--origin-hf-dir", type=str, required=True, help="Path to the original HuggingFace model directory") + parser.add_argument("--converted-dir", type=str, required=True, help="Path to the converted checkpoint directory") + parser.add_argument("--dry-run", action="store_true", help="Only print missing keys without merging") + parser.add_argument("--chunk-size", type=int, default=5 * 1024**3, help="Chunk size for safetensors files (default 5GB)") + args = parser.parse_args() + + # Load both index files + origin_index_path = os.path.join(args.origin_hf_dir, "model.safetensors.index.json") + converted_index_path = os.path.join(args.converted_dir, "model.safetensors.index.json") + + if not os.path.exists(origin_index_path): + raise FileNotFoundError(f"Origin index not found: {origin_index_path}") + if not os.path.exists(converted_index_path): + raise FileNotFoundError(f"Converted index not found: {converted_index_path}") + + with open(origin_index_path) as f: + origin_index = json.load(f) + with open(converted_index_path) as f: + converted_index = json.load(f) + + origin_keys = set(origin_index["weight_map"].keys()) + converted_keys = set(converted_index["weight_map"].keys()) + missing_keys = sorted(origin_keys - converted_keys) + + if not missing_keys: + print("No missing keys detected. The converted checkpoint is complete.") + return + + print(f"Found {len(missing_keys)} missing keys (present in origin but not in converted checkpoint):") + + # Categorize missing keys + from collections import Counter + prefix_patterns = Counter() + for key in missing_keys: + parts = key.split(".") + # Group by first 3 parts (e.g., model.visual.blocks.0) + prefix = ".".join(parts[:4]) if len(parts) >= 4 else key + prefix_patterns[prefix] += 1 + + print("\nMissing key categories:") + for prefix, count in prefix_patterns.most_common(): + print(f" {prefix}.*: {count} keys") + + for key in missing_keys[:5]: + print(f" - {key}") + if len(missing_keys) > 5: + print(f" ... and {len(missing_keys) - 5} more") + + if args.dry_run: + print("\n[DRY RUN] Exiting without making changes.") + return + + # Group missing keys by their source file in origin + missing_by_file = {} + for key in missing_keys: + src_file = origin_index["weight_map"][key] + if src_file not in missing_by_file: + missing_by_file[src_file] = [] + missing_by_file[src_file].append(key) + + # Load missing tensors from origin HF safetensors + print(f"\nLoading {len(missing_keys)} missing tensors from origin HF model...") + missing_tensors = {} + for src_file, keys in tqdm(missing_by_file.items(), desc="Reading origin safetensors"): + src_path = os.path.join(args.origin_hf_dir, src_file) + if not os.path.exists(src_path): + print(f"WARNING: {src_path} not found. Skipping keys: {keys}") + continue + with safe_open(src_path, framework="pt", device="cpu") as f: + for key in keys: + missing_tensors[key] = f.get_tensor(key) + + # Determine current file count + current_files = set(converted_index["weight_map"].values()) + total_files = len(current_files) + + # Calculate missing size + missing_size = sum(t.numel() * t.element_size() for t in missing_tensors.values()) + print(f"Missing tensors total size: {missing_size / 1e9:.2f} GB") + + # Find the last file and check its size + last_file_idx = total_files + last_file_name = f"model-{last_file_idx:05d}-of-{total_files:05d}.safetensors" + last_file_path = os.path.join(args.converted_dir, last_file_name) + + # Add missing tensors to a new shard + new_total = total_files + 1 + new_shard_name = f"model-{new_total:05d}-of-{new_total:05d}.safetensors" + new_shard_path = os.path.join(args.converted_dir, new_shard_name) + + print(f"Writing {len(missing_tensors)} tensors to new shard: {new_shard_name}") + safetensors.torch.save_file(missing_tensors, new_shard_path) + + # Update weight map: add new entries and update file numbering + weight_map = converted_index["weight_map"] + + # Add missing keys pointing to the new shard + for key in missing_tensors: + weight_map[key] = new_shard_name + + # Rename existing files to update total count + print(f"Renaming {total_files} existing shards to update total count from {total_files} to {new_total}...") + for i in range(1, total_files + 1): + old_name = f"model-{i:05d}-of-{total_files:05d}.safetensors" + new_name = f"model-{i:05d}-of-{new_total:05d}.safetensors" + old_path = os.path.join(args.converted_dir, old_name) + new_path = os.path.join(args.converted_dir, new_name) + if os.path.exists(old_path): + shutil.move(old_path, new_path) + # Update weight map references + for k, v in weight_map.items(): + if v == old_name: + weight_map[k] = new_name + + # Update and save index + converted_index["metadata"]["total_size"] = converted_index["metadata"].get("total_size", 0) + missing_size + converted_index["weight_map"] = weight_map + + with open(converted_index_path, "w") as f: + json.dump(converted_index, f, indent=2) + + print(f"\nDone! Merged {len(missing_tensors)} missing keys into the converted checkpoint.") + print(f"New total size: {converted_index['metadata']['total_size'] / 1e9:.2f} GB") + print(f"Total shards: {new_total}") + print(f"Total keys: {len(weight_map)}") + + +if __name__ == "__main__": + main() \ No newline at end of file From fa572007da343f6313c2cb80ecbaff525c1fef5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=87=95=E4=B8=80?= Date: Wed, 10 Jun 2026 17:51:20 +0800 Subject: [PATCH 07/14] fix: improve training stability (loss inf, padding vocab size) - Fix loss becoming inf due to numerical instability in KL computation - Fix padding vocab size handling in actor forward pass - Add train-memory-margin-bytes argument for memory management - Add attention gate patching tool for distributed checkpoints - Add safety checks for logits with padding tokens - Update 397B SGLang script with stability improvements --- .../run-qwen35-397B-A17B-mopd-topk-sglang.sh | 9 +- slime/backends/megatron_utils/actor.py | 55 +++-- slime/backends/megatron_utils/loss.py | 101 ++++++-- slime/utils/ppo_utils.py | 14 ++ tools/patch_attention_gate_on_cluster.py | 227 ++++++++++++++++++ 5 files changed, 373 insertions(+), 33 deletions(-) create mode 100644 tools/patch_attention_gate_on_cluster.py diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh index 2019b8dc23..044c62bd9c 100755 --- a/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh +++ b/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh @@ -96,10 +96,10 @@ ROLLOUT_ARGS=( --input-key messages --apply-chat-template --rollout-shuffle - --rollout-batch-size 64 - --n-samples-per-prompt 1 + --rollout-batch-size 16 + --n-samples-per-prompt 4 --rollout-max-response-len 4096 - --rollout-temperature 0.5 + --rollout-temperature 0.8 --global-batch-size 64 --balance-data @@ -116,7 +116,7 @@ RM_ARGS=( EVAL_ARGS=() PERF_ARGS=( - --tensor-model-parallel-size 2 + --tensor-model-parallel-size 16 --sequence-parallel --pipeline-model-parallel-size 1 --context-parallel-size 1 @@ -129,6 +129,7 @@ PERF_ARGS=( --use-dynamic-batch-size --max-tokens-per-gpu 4096 + --train-memory-margin-bytes 268435456 ) MOPD_ARGS=( diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index 094f7ad441..ed1ce156a2 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -599,16 +599,27 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch, external_data if sglang_topk_logits and sglang_topk_indices: tp_rank = mpu.get_tensor_model_parallel_rank() tp_size = mpu.get_tensor_model_parallel_world_size() + # Use the ORIGINAL vocab_size (not padded_vocab_size) for + # TP shard calculations. Megatron's ColumnParallelLinear + # output layer dimensions are based on the actual + # vocab_size / tp_size (loaded from HF checkpoint), + # NOT padded_vocab_size / tp_size. Using padded values + # causes local indices to exceed the model's actual + # vocab dimension, leading to gather-index-out-of-bounds + # errors in the downstream KL computation. + # padded_vocab_size = 249856 → per-shard = 15616 + # vocab_size = 248320 → per-shard = 15520 (actual) + # Overflow range: [15520, 15615] (96 phantom indices) + vocab_size = self.args.vocab_size padded_vocab_size = self.args.padded_vocab_size - vocab_local_size = padded_vocab_size // tp_size + vocab_local_size = vocab_size // tp_size vocab_offset = tp_rank * vocab_local_size topk_k = self.args.mopd_topk_k # Check that SGLang teacher's vocab size is consistent - # with the student's padded_vocab_size. If teacher - # token IDs exceed the student vocab range, the - # global→local TP index conversion will produce - # silently incorrect results. + # with the student's vocab_size. If teacher token IDs + # exceed the student vocab range, the global→local TP + # index conversion will produce silently incorrect results. _vocab_checked = False for domain in sglang_topk_logits: @@ -659,16 +670,23 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch, external_data if not _vocab_checked: _vocab_checked = True max_token_id = global_indices.max().item() - if max_token_id >= padded_vocab_size: + min_token_id = global_indices.min().item() + logger.info( + f"[MOPD] Vocab sharding: tp_rank={tp_rank}, " + f"tp_size={tp_size}, vocab_size={vocab_size}, " + f"padded_vocab_size={padded_vocab_size}, " + f"vocab_local_size={vocab_local_size}, " + f"vocab_offset={vocab_offset}, topk_k={topk_k}" + ) + logger.info( + f"[MOPD] global_indices range=[{min_token_id}, " + f"{max_token_id}], shape={global_indices.shape}" + ) + if max_token_id >= vocab_size: logger.error( - f"MOPD top_k: SGLang teacher returned token ID " - f"{max_token_id} which exceeds student " - f"padded_vocab_size={padded_vocab_size}. " - f"The teacher and student vocab sizes are " - f"mismatched — this will produce incorrect " - f"TP index conversion and wrong KL divergence. " - f"Ensure the teacher model uses the same " - f"tokenizer/vocab as the student." + f"[MOPD] TOKEN ID OVERFLOW! " + f"max_token_id={max_token_id} >= " + f"vocab_size={vocab_size}" ) seq_len = global_indices.size(0) @@ -701,6 +719,15 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch, external_data local_topk_logits[row, :n_in_shard] = shard_logits[:n_in_shard] local_topk_indices[row, :n_in_shard] = shard_local_idx[:n_in_shard] + # [MOPD] Check local_topk_indices range after conversion + _local_max = local_topk_indices.max().item() + if _local_max >= vocab_local_size: + logger.error( + f"[MOPD] LOCAL INDEX OVERFLOW! sample={i} " + f"max_local={_local_max} >= " + f"vocab_local_size={vocab_local_size}" + ) + topk_logits_list.append(local_topk_logits) topk_indices_list.append(local_topk_indices) rollout_data[topk_logits_key] = topk_logits_list diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 280fd57426..4a8cfb1661 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -1210,9 +1210,17 @@ def apply_mopd_full_vocab_to_loss( "mopd_is_nonzero_frac": is_nonzero_frac.clone().detach(), } - # Per-teacher KL for logging (re-use KL values computed in the main loop) - for domain, domain_kls in per_domain_kls.items(): - metrics[f"mopd_fv_kl/{domain}"] = sum_of_sample_mean(torch.cat(domain_kls, dim=0)).clone().detach() + # Per-teacher KL for logging (re-use KL values computed in the main loop). + # Iterate over ALL configured teacher domains (not just per_domain_kls) so + # that every microbatch emits the same set of metric keys. Without this + # Megatron's cross-microbatch loss reduction (model.py: values += x["values"]) + # can crash with a tensor-size mismatch when different microbatches have + # different subsets of active domains. + for domain in teacher_logits_per_domain: + if domain in per_domain_kls and len(per_domain_kls[domain]) > 0: + metrics[f"mopd_fv_kl/{domain}"] = sum_of_sample_mean(torch.cat(per_domain_kls[domain], dim=0)).clone().detach() + else: + metrics[f"mopd_fv_kl/{domain}"] = torch.tensor(0.0, device=all_kl_cat.device) return kl_loss, metrics @@ -1277,7 +1285,7 @@ def apply_mopd_topk_to_loss( f"MOPD top_k requires '{sampling_logprobs_key}' in batch for importance sampling." ) - vocab_size = args.padded_vocab_size + vocab_size = args.vocab_size num_samples = len(student_logits_per_sample) if len(sampling_log_probs) != num_samples: raise ValueError( @@ -1303,11 +1311,16 @@ def apply_mopd_topk_to_loss( t_topk_logits = teacher_topk_logits_per_domain[domain][i] # [R_i, k] - # Skip fallback sentinel tensors (all -inf) from failed teacher requests. - # These would produce KL=0 anyway, so skipping avoids unnecessary - # computation and TP all-reduce calls. - if t_topk_logits.isinf().all(): - continue + # IMPORTANT: Do NOT skip the vocab_parallel_topk_reverse_kl call even + # when all teacher logits are -inf. Each TP rank independently shards + # the top-k tokens into its vocab range, so one rank may see all -inf + # (no tokens in its shard) while another rank has valid entries. + # Skipping on only some ranks creates an inconsistent TP collective call + # (all_reduce inside vocab_parallel_topk_reverse_kl), causing an + # irreversible NCCL deadlock. When all entries are -inf, + # vocab_parallel_topk_reverse_kl correctly produces KL=0 (the + # valid_topk_mask is all-False), so the numerical result is identical + # -- but the collective operations remain consistent across TP ranks. t_topk_indices = teacher_topk_indices_per_domain[domain][i] # [R_i, k] @@ -1394,8 +1407,16 @@ def apply_mopd_topk_to_loss( "mopd_is_nonzero_frac": is_nonzero_frac.clone().detach(), } - for domain, domain_kls in per_domain_kls.items(): - metrics[f"mopd_topk_kl/{domain}"] = sum_of_sample_mean(torch.cat(domain_kls, dim=0)).clone().detach() + for domain in teacher_topk_logits_per_domain: + if domain in per_domain_kls and len(per_domain_kls[domain]) > 0: + metrics[f"mopd_topk_kl/{domain}"] = sum_of_sample_mean(torch.cat(per_domain_kls[domain], dim=0)).clone().detach() + else: + # No samples contributed valid teacher data for this domain in this + # microbatch. Emit a zero metric so that every microbatch produces + # the same set of metric keys — this is required for Megatron's + # loss-reduction across microbatches which uses tensor addition + # (model.py: values += x["values"]) and demands identical sizes. + metrics[f"mopd_topk_kl/{domain}"] = torch.tensor(0.0, device=all_kl_cat.device) return kl_loss, metrics @@ -1568,6 +1589,24 @@ def policy_loss_function( else: logger.warning("MOPD full_vocab enabled but no teacher logits found in batch. Skipping full_vocab KL loss.") + # Ensure per-domain metric keys AND base MOPD metric keys exist for + # ALL configured teacher domains, even when the batch contains no valid + # data for some (or all) domains. Megatron's loss-reduction + # (model.py: values += x["values"]) requires every microbatch to emit + # the same set of metric keys; missing keys cause a tensor-size + # mismatch across microbatches. + _device = logits.device + for teacher_cfg in mopd_teachers_parsed: + domain = teacher_cfg["domain"] + _domain_key = f"mopd_fv_kl/{domain}" + if _domain_key not in mopd_fv_metrics: + mopd_fv_metrics[_domain_key] = torch.tensor(0.0, device=_device) + # Ensure base MOPD metrics are present even when no teacher data was + # available for the entire microbatch (apply_mopd_full_vocab_to_loss not called). + for _base_key in ("mopd_fv_kl", "mopd_is_weight_mean", "mopd_is_nonzero_frac"): + if _base_key not in mopd_fv_metrics: + mopd_fv_metrics[_base_key] = torch.tensor(0.0, device=_device) + # MOPD top_k: compute top-k approximate reverse KL divergence loss # L = (1/D) Σ_d w_d · KL_topk+d(π_θ ∥ π_d) + alpha * pg_loss if use_mopd_top_k: @@ -1616,6 +1655,24 @@ def policy_loss_function( else: logger.warning("MOPD top_k enabled but no teacher top-k data found in batch. Skipping top_k KL loss.") + # Ensure per-domain metric keys AND base MOPD metric keys exist for + # ALL configured teacher domains, even when the batch contains no valid + # data for some (or all) domains. Megatron's loss-reduction + # (model.py: values += x["values"]) requires every microbatch to emit + # the same set of metric keys; missing keys cause a tensor-size + # mismatch across microbatches. + _device = logits.device + for teacher_cfg in mopd_teachers_parsed: + domain = teacher_cfg["domain"] + _domain_key = f"mopd_topk_kl/{domain}" + if _domain_key not in mopd_fv_metrics: + mopd_fv_metrics[_domain_key] = torch.tensor(0.0, device=_device) + # Ensure base MOPD metrics are present even when no teacher data was + # available for the entire microbatch (apply_mopd_topk_to_loss not called). + for _base_key in ("mopd_topk_kl", "mopd_is_weight_mean", "mopd_is_nonzero_frac"): + if _base_key not in mopd_fv_metrics: + mopd_fv_metrics[_base_key] = torch.tensor(0.0, device=_device) + # Apply off-policy correction using importance sampling if enabled if args.get_mismatch_metrics or args.use_tis: # NOTE: @@ -1758,15 +1815,29 @@ def policy_loss_function( reported_loss["mopd_is_nonzero_frac"] = sum_of_sample_mean(mopd_is_nonzero).clone().detach() if "mopd_reverse_kl" in batch: - for domain, domain_kls in batch["mopd_reverse_kl"].items(): - domain_kl_tensor = torch.cat(domain_kls, dim=0) - reported_loss[f"mopd_reverse_kl/{domain}"] = sum_of_sample_mean(domain_kl_tensor).clone().detach() + # Iterate over ALL configured teacher domains — not just the + # keys present in this microbatch — so that every microbatch + # produces the same set of metric keys (required for Megatron's + # loss-reduction across microbatches). + _all_mopd_domains = [ + t["domain"] for t in getattr(args, "_mopd_teachers_parsed", []) + ] + _mopd_reverse_kl_domains = _all_mopd_domains if _all_mopd_domains else list(batch["mopd_reverse_kl"].keys()) + for domain in _mopd_reverse_kl_domains: + if domain in batch["mopd_reverse_kl"]: + domain_kl_tensor = torch.cat(batch["mopd_reverse_kl"][domain], dim=0) + reported_loss[f"mopd_reverse_kl/{domain}"] = sum_of_sample_mean(domain_kl_tensor).clone().detach() + else: + reported_loss[f"mopd_reverse_kl/{domain}"] = torch.tensor(0.0, device=mopd_is_weights.device) if "mopd_advantages" in batch: mopd_advantages = torch.cat(batch["mopd_advantages"], dim=0) reported_loss["mopd_advantage_mean"] = sum_of_sample_mean(mopd_advantages).clone().detach() - # Log MOPD logits-based distillation metrics (full_vocab / top_k) + # Log MOPD logits-based distillation metrics (full_vocab / top_k). + # mopd_fv_metrics already contains zero-valued entries for domains that + # had no valid teacher data in this microbatch (see apply_mopd_topk_to_loss + # and apply_mopd_full_vocab_to_loss), ensuring consistent key sets. if use_mopd_logits_based: for key, value in mopd_fv_metrics.items(): reported_loss[key] = value diff --git a/slime/utils/ppo_utils.py b/slime/utils/ppo_utils.py index 8a1abca2da..029f34b4c1 100644 --- a/slime/utils/ppo_utils.py +++ b/slime/utils/ppo_utils.py @@ -2,11 +2,14 @@ # and https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/trainer/ppo_utils/experience_maker.py from argparse import Namespace +import logging import torch import torch.distributed as dist import torch.nn.functional as F +logger = logging.getLogger(__name__) + @torch.compile(dynamic=True) def compute_approx_kl( @@ -404,6 +407,17 @@ def vocab_parallel_topk_reverse_kl( # Gather student probs and log-probs at teacher's top-k positions # teacher_topk_indices are LOCAL to this TP shard + # Defensive: clamp indices to valid range in case of mismatch between + # the padded vocab size used for TP shard calculation and the model's + # actual output dimension (e.g., when padded_vocab_size > vocab_size). + _v_local = s_softmax.size(-1) + if teacher_topk_indices.max() >= _v_local or teacher_topk_indices.min() < 0: + logger.warning( + f"teacher_topk_indices out of range for s_softmax.size(-1)={_v_local}, " + f"clamping indices. This typically indicates a mismatch between " + f"padded_vocab_size and the model's actual vocab dimension." + ) + teacher_topk_indices = teacher_topk_indices.clamp(min=0, max=_v_local - 1) student_topk_probs = s_softmax.gather(-1, teacher_topk_indices) # [R, k] student_topk_shifted = s_shifted.gather(-1, teacher_topk_indices) # [R, k] student_topk_log_probs = student_topk_shifted - s_log_sum_exp # [R, k] diff --git a/tools/patch_attention_gate_on_cluster.py b/tools/patch_attention_gate_on_cluster.py new file mode 100644 index 0000000000..b9e05e7d18 --- /dev/null +++ b/tools/patch_attention_gate_on_cluster.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 +""" +Patch Megatron-LM attention.py on all Ray cluster nodes. + +Fix: When num_query_groups < tp_size (e.g., GQA with num_kv_heads < tp_size), +the gate tensor in SelfAttention.get_query_key_value_tensors() needs the same +TP rank indexing as query, otherwise gate.shape != core_attn_out.shape in +_apply_output_gate(). + +Usage: + python patch_attention_gate_on_cluster.py # apply patch + python patch_attention_gate_on_cluster.py --rollback # restore from backup + python patch_attention_gate_on_cluster.py --diagnose # show target lines on cluster +""" + +import argparse +import textwrap + +import ray +import subprocess + +FILE_PATH = "/root/Megatron-LM/megatron/core/transformer/attention.py" +REMOTE_SCRIPT_PATH = "/tmp/_patch_attention_gate.py" + +DIAGNOSE_SCRIPT = textwrap.dedent(f"""\ + import sys + FILE_PATH = {FILE_PATH!r} + try: + with open(FILE_PATH, "r") as f: + lines = f.readlines() + except FileNotFoundError: + print("FILE_NOT_FOUND") + sys.exit(0) + # Search for the "if output_gate:" block in get_query_key_value_tensors + # We look for the pattern: gate.reshape(...) followed by return query, key, value, gate + found = False + for i, line in enumerate(lines): + stripped = line.rstrip() + if "gate = gate.reshape(*gate.shape[:2], -1, self.hidden_size_per_attention_head)" in stripped: + # Print context: 3 lines before, this line, and 3 lines after + start = max(0, i - 3) + end = min(len(lines), i + 4) + print(f"LINE {{i+1}} (0-indexed {{i}}):") + for j in range(start, end): + marker = ">>>" if j == i else " " + print(f" {{marker}} {{j+1}}: {{lines[j].rstrip()}}") + found = True + break + if not found: + # Fallback: search for any line with "gate.reshape" and "hidden_size_per_attention_head" + for i, line in enumerate(lines): + if "gate.reshape" in line and "hidden_size_per_attention_head" in line: + start = max(0, i - 3) + end = min(len(lines), i + 4) + print(f"LINE {{i+1}} (0-indexed {{i}}) [fallback match]:") + for j in range(start, end): + marker = ">>>" if j == i else " " + print(f" {{marker}} {{j+1}}: {{lines[j].rstrip()}}") + found = True + break + if not found: + print("PATTERN_NOT_FOUND") + # Show all lines containing "output_gate" for debugging + for i, line in enumerate(lines): + if "output_gate" in line or "gate.reshape" in line: + print(f" {{i+1}}: {{line.rstrip()}}") +""") + +# Patch script: uses robust line-by-line approach instead of string matching +PATCH_SCRIPT = textwrap.dedent(f"""\ + import sys, shutil + + FILE_PATH = {FILE_PATH!r} + + with open(FILE_PATH, "r") as f: + lines = f.readlines() + + # Check if already patched: look for the TP indexing line we add + already_patched = any( + "gate needs the same TP rank indexing" in line + for line in lines + ) + if already_patched: + print("ALREADY_PATCHED") + sys.exit(0) + + # Find the target line: "gate = gate.reshape(*gate.shape[:2], -1, self.hidden_size_per_attention_head)" + # in the output_gate block of get_query_key_value_tensors + target_idx = None + for i, line in enumerate(lines): + stripped = line.rstrip() + if "gate = gate.reshape(*gate.shape[:2], -1, self.hidden_size_per_attention_head)" in stripped: + # Verify this is in the output_gate block by checking the next line + # should be "return query, key, value, gate" + next_line = lines[i + 1].strip() if i + 1 < len(lines) else "" + if "return query, key, value, gate" in next_line: + target_idx = i + break + + if target_idx is None: + # Fallback: find any gate.reshape ... followed by return query, key, value, gate + for i, line in enumerate(lines): + if "gate.reshape" in line and "hidden_size_per_attention_head" in line: + next_line = lines[i + 1].strip() if i + 1 < len(lines) else "" + if "return query, key, value, gate" in next_line: + target_idx = i + break + + if target_idx is None: + print("TARGET_NOT_FOUND") + sys.exit(0) + + # Determine the indentation of the reshape line + indent = "" + for ch in lines[target_idx]: + if ch in (" ", "\\t"): + indent += ch + else: + break + + # Build the patch lines (same indent, one level deeper for the if block) + inner_indent = indent + " " + patch_lines = [ + "\\n", + indent + "if self.config.num_query_groups < self.world_size:\\n", + inner_indent + "# When num_kv_heads < tp_size, gate needs the same TP rank indexing\\n", + inner_indent + "# as query (see lines above for query indexing logic).\\n", + inner_indent + "idx = get_tensor_model_parallel_rank() % (\\n", + inner_indent + " self.world_size // self.config.num_query_groups\\n", + inner_indent + ")\\n", + inner_indent + "size = self.num_attention_heads_per_partition // (\\n", + inner_indent + " self.world_size // self.config.num_query_groups\\n", + inner_indent + ")\\n", + inner_indent + "gate = gate[:, :, idx * size : (idx + 1) * size, :]\\n", + ] + + # Backup + backup_path = FILE_PATH + ".gate_fix.bak" + shutil.copy2(FILE_PATH, backup_path) + print(f"BACKUP: {{backup_path}}") + + # Insert patch lines after the reshape line (before the return line) + new_lines = lines[: target_idx + 1] + patch_lines + lines[target_idx + 1 :] + with open(FILE_PATH, "w") as f: + f.writelines(new_lines) + print("PATCHED") +""") + +ROLLBACK_SCRIPT = textwrap.dedent(f"""\ + import sys, shutil, os + + FILE_PATH = {FILE_PATH!r} + backup_path = FILE_PATH + ".gate_fix.bak" + + if not os.path.exists(backup_path): + print("NO_BACKUP_FOUND") + sys.exit(0) + + shutil.copy2(backup_path, FILE_PATH) + print("ROLLED_BACK") +""") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--rollback", action="store_true", help="Restore from backup") + parser.add_argument("--diagnose", action="store_true", help="Show target lines on cluster (no patch)") + args = parser.parse_args() + + if args.diagnose: + script_body = DIAGNOSE_SCRIPT + elif args.rollback: + script_body = ROLLBACK_SCRIPT + else: + script_body = PATCH_SCRIPT + + ray.init(address="auto") + + nodes = [ + n["NodeManagerAddress"] + for n in ray.nodes() + if n["Alive"] + ] + print(f"Found {len(nodes)} alive nodes") + + # Only check one node for diagnose (they should all be the same) + target_nodes = nodes[:1] if args.diagnose else nodes + + tasks = [] + for node_ip in target_nodes: + @ray.remote(resources={f"node:{node_ip}": 0.001}) + def run_on_node(node_ip=node_ip): + # Step 1: write script to temp file + write_cmd = ["python3", "-c", f"open({REMOTE_SCRIPT_PATH!r},'w').write({script_body!r})"] + r1 = subprocess.run(write_cmd, capture_output=True, text=True, timeout=30) + if r1.returncode != 0: + return {"node_ip": node_ip, "result": f"WRITE_FAILED: {r1.stderr.strip()}"} + + # Step 2: execute + r2 = subprocess.run(["python3", REMOTE_SCRIPT_PATH], capture_output=True, text=True, timeout=30) + return { + "node_ip": node_ip, + "result": r2.stdout.strip() if r2.returncode == 0 else f"EXEC_FAILED: {r2.stderr.strip()}", + } + + tasks.append(run_on_node.remote()) + + results = ray.get(tasks) + + if args.diagnose: + for r in results: + print(f"\n=== {r['node_ip']} ===") + print(r["result"]) + return + + success = 0 + for r in results: + print(f" {r['node_ip']}: {r['result']}") + if "PATCHED" in r["result"] or "ALREADY_PATCHED" in r["result"] or "ROLLED_BACK" in r["result"]: + success += 1 + + action = "rolled back" if args.rollback else "patched" + print(f"\n{success}/{len(nodes)} nodes {action} successfully.") + + +if __name__ == "__main__": + main() From aa84536b351586da7db37f74688d0550f3d333cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=87=95=E4=B8=80?= Date: Wed, 10 Jun 2026 22:59:05 +0800 Subject: [PATCH 08/14] feat: add non-colocate mode and deployment improvements - Add non-colocate mode support in update_weight_from_distributed.py (separate actor training GPUs from SGLang rollout GPUs) - Add HfWeightIteratorBridge support for Megatron-to-HF conversion in weight update pipeline (supports VL MoE models) - Switch 397B model script to use bridge mode for megatron-to-hf - Update 397B SGLang script for non-colocate deployment - Update 35B script with optimized parallelism settings --- .../run-qwen35-35B-A3B-mopd-topk-sglang.sh | 22 ++--- .../run-qwen35-397B-A17B-mopd-topk-sglang.sh | 22 +++-- scripts/models/qwen3.5-397B-A17B.sh | 2 +- .../update_weight_from_distributed.py | 87 +++++++++++++++---- 4 files changed, 99 insertions(+), 34 deletions(-) diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-topk-sglang.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-topk-sglang.sh index 42a9e92e63..f5e5757ef8 100644 --- a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-topk-sglang.sh +++ b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-topk-sglang.sh @@ -54,8 +54,9 @@ CKPT_ARGS=( --hf-checkpoint ${HF_CKPT}/ --load ${TORCH_DIST_CKPT}/ --save ${SAVE_DIR}/ - --save-interval 10 + --save-interval 32 --no-save-optim + --no-ckpt-fully-parallel-save ) ROLLOUT_ARGS=( @@ -93,15 +94,16 @@ PERF_ARGS=( --sequence-parallel --pipeline-model-parallel-size 1 --context-parallel-size 1 - --expert-model-parallel-size 8 + --expert-model-parallel-size 32 --expert-tensor-parallel-size 1 --recompute-granularity full --recompute-method uniform - --recompute-num-layers 1 + --recompute-num-layers 4 # --use-dynamic-batch-size --max-tokens-per-gpu 2048 + --train-memory-margin-bytes 536870912 ) MOPD_ARGS=( @@ -115,7 +117,7 @@ MOPD_ARGS=( # top_k distillation type --mopd-distill-type top_k - --mopd-topk-k 16 + --mopd-topk-k 96 # No --mopd-teacher-loads in SGLang mode! # Teacher data comes from SGLang server via HTTP during rollout. @@ -129,7 +131,7 @@ MOPD_ARGS=( OPTIMIZER_ARGS=( --optimizer adam - --lr 5e-7 # Conservative LR for stability + --lr 1e-6 # Conservative LR for stability --lr-decay-style constant --weight-decay 0.1 --adam-beta1 0.9 @@ -144,9 +146,9 @@ OPTIMIZER_ARGS=( WANDB_ARGS=() SGLANG_ARGS=( - --rollout-num-gpus-per-engine 8 - --sglang-mem-fraction-static 0.25 - --sglang-ep-size 8 + --rollout-num-gpus-per-engine 16 + --sglang-mem-fraction-static 0.10 + --sglang-ep-size 16 ) MISC_ARGS=( @@ -161,7 +163,7 @@ MISC_ARGS=( --no-check-for-nan-in-loss-and-grad --recompute-loss-function - --log-probs-chunk-size 1024 + --log-probs-chunk-size 512 --qkv-format bshd --micro-batch-size 1 --colocate @@ -192,7 +194,7 @@ print(json.dumps({'env_vars': env})) ray job submit --address="http://127.0.0.1:8265" \ --runtime-env-json="${RUNTIME_ENV_JSON}" \ -- python3 ../workspace/bin/slime/train.py \ - --actor-num-nodes 1 \ + --actor-num-nodes 4 \ --actor-num-gpus-per-node 8 \ --update-weight-buffer-size $(( 1024 * 1024 * 1024 * 4 )) \ ${MODEL_ARGS[@]} \ diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh index 044c62bd9c..df8f07b040 100755 --- a/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh +++ b/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh @@ -1,15 +1,16 @@ #!/bin/bash -# Multi-Teacher On-Policy Distillation (MOPD) — Top-K KL Divergence, SGLang Mode +# Multi-Teacher On-Policy Distillation (MOPD) — Top-K KL Divergence, SGLang Non-colocate Mode # Model: Qwen3.5-397B-A17B (MoE, 512 experts, 10 active) -# Environment: 16 nodes × 8 L20X (143GB each), 128 GPUs total -# Teacher: Skin-multiturn teacher (running on external SGLang servers) -# Mode: SGLang (teacher runs on separate SGLang inference servers, no CPU OOM) +# Environment: 36 nodes × 8 GPUs (143GB each), 288 GPUs total +# - 32 nodes (256 GPUs) for Megatron actor training (non-colocate) +# - 4 nodes ( 32 GPUs) for SGLang rollout (2 engines × 16 GPUs each) +# Teacher: Teacher model (running on external SGLang servers) +# Mode: SGLang non-colocate (actor training and rollout on separate GPU groups) # Distill Type: top_k (approximate reverse KL with top-k teacher logits + tail correction) # -# SGLang mode avoids loading teacher model weights into the Megatron training process, -# eliminating the CPU RAM overhead of TensorBackuper pin_memory backups (~150GB per -# teacher for 397B MoE on each node). The teacher runs on independent SGLang servers, +# Non-colocate mode separates actor training GPUs from SGLang rollout GPUs, +# avoiding GPU memory contention and allowing larger actor parallelism. # and its top-k logprobs are collected during rollout via HTTP requests. # # Key differences from Megatron top_k mode: @@ -194,7 +195,12 @@ MISC_ARGS=( # --moe-enable-deepep # DeepEP internode kernel assertion fails when EP=128 --no-check-for-nan-in-loss-and-grad - --colocate + --recompute-loss-function + --log-probs-chunk-size 512 + --qkv-format bshd + --micro-batch-size 1 + # Non-colocate mode: actor training and rollout on separate GPU groups + # Remove --colocate to use non-colocate mode ) # ============================================================================ diff --git a/scripts/models/qwen3.5-397B-A17B.sh b/scripts/models/qwen3.5-397B-A17B.sh index 7da83a9a53..677810eaac 100644 --- a/scripts/models/qwen3.5-397B-A17B.sh +++ b/scripts/models/qwen3.5-397B-A17B.sh @@ -17,7 +17,7 @@ printf -v MOE_LAYER_FREQ "[%s]" "$(IFS=', '; echo "${arr[*]}")" MODEL_ARGS=( - --spec "slime_plugins.models.qwen3_5" "get_qwen3_5_spec" + --megatron-to-hf-mode bridge --disable-bias-linear --qk-layernorm diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py index 822b801776..7635dd59b6 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py @@ -15,12 +15,18 @@ from ..megatron_to_hf import convert_to_hf from .common import all_gather_param, named_params_and_buffers +from .hf_weight_iterator_base import HfWeightIteratorBase class UpdateWeightFromDistributed: """ Update distributed engines via NCCL. Each PP rank: group "slime-pp_{pp_rank}", only DP=TP=0 broadcasts. Non-expert (TP) and expert (EP) params separate. + + When megatron_to_hf_mode=="bridge", uses HfWeightIteratorBridge for Megatron→HF + conversion (which supports models like Qwen3.5-VL MoE with vision encoders that + the manual convert_to_hf path does not handle). Otherwise falls back to the + direct convert_to_hf path. """ def __init__( @@ -37,11 +43,22 @@ def __init__( """ self.args = args self.model = model + self.weights_getter = weights_getter self.model_name = model_name self.quantization_config = quantization_config self.weight_version = 0 self._model_update_groups = None + # When megatron_to_hf_mode is "bridge", use HfWeightIteratorBridge for + # Megatron→HF conversion. This supports models (e.g. Qwen3.5-VL MoE) + # with vision encoders whose weight mappings are registered via + # megatron.bridge but not in the manual convert_to_hf functions. + self._use_bridge = getattr(args, "megatron_to_hf_mode", "raw") == "bridge" + if self._use_bridge: + self._hf_weight_iterator = HfWeightIteratorBase.create( + args=args, model=model, model_name=model_name, quantization_config=quantization_config + ) + def connect_rollout_engines( self, rollout_engines: Sequence[ActorHandle], @@ -89,7 +106,11 @@ def disconnect_rollout_engines(self) -> None: @torch.no_grad() def update_weights(self) -> None: """ - Pause → flush → non-expert (TP) → expert (EP) → continue. Progress on PP source. + Pause → flush → convert HF → broadcast → continue. + + When megatron_to_hf_mode=="bridge", uses HfWeightIteratorBridge for + Megatron→HF conversion. Otherwise uses the manual convert_to_hf path + with separate non-expert and expert parameter processing. """ self.weight_version += 1 @@ -106,11 +127,59 @@ def update_weights(self) -> None: ) dist.barrier(group=get_gloo_group()) + pbar = tqdm(desc=f"[{self._group_name}] Update weights", total=0) if self._is_pp_src_rank else None + + if self._use_bridge: + self._update_weights_bridge(pbar) + else: + self._update_weights_direct(pbar) + + dist.barrier(group=get_gloo_group()) + if dist.get_rank() == 0: + # int4/fp4 post_process + if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]: + post_process_weights( + restore_weights_before_load=False, + post_process_quantization=True, + rollout_engines=self.rollout_engines, + ) + ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) + dist.barrier(group=get_gloo_group()) + + @torch.no_grad() + def _update_weights_bridge(self, pbar: tqdm | None = None) -> None: + """Update weights using HfWeightIteratorBridge for Megatron→HF conversion. + + This path supports models with vision encoders and other components whose + weight mappings are registered via megatron.bridge (e.g. Qwen3.5-VL MoE) + but not in the manual convert_to_hf functions. + + The bridge iterator handles TP/EP/PP synchronization and HF conversion + internally. We only need to broadcast the resulting HF weights to rollout + engines via NCCL. + + All ranks must iterate through get_hf_weight_chunks because the bridge + internally performs collective operations (TP all-gather, EP all-gather, + PP broadcast) that require participation from every rank. Only the + PP source rank (DP=TP=0) broadcasts the converted HF weights to the + rollout engines. + """ + megatron_local_weights = self.weights_getter() + + for hf_named_tensors in self._hf_weight_iterator.get_hf_weight_chunks(megatron_local_weights): + if self._is_pp_src_rank: + self._update_bucket_weights_from_distributed(hf_named_tensors, pbar=pbar) + + @torch.no_grad() + def _update_weights_direct(self, pbar: tqdm | None = None) -> None: + """Update weights using the manual convert_to_hf path. + + Processes non-expert and expert parameters separately with manual + TP all-gather and EP all-gather, then converts to HF format. + """ buffer_size = 0 converted_named_tensors = [] # non expert params - pbar = tqdm(desc=f"[{self._group_name}] Update weights", total=0) if self._is_pp_src_rank else None - for name, param in named_params_and_buffers(self.args, self.model): if ".experts." in name: continue @@ -135,18 +204,6 @@ def update_weights(self) -> None: if named_tensors: self._update_expert_bucket_weights_from_distributed(named_tensors, pbar=pbar) - dist.barrier(group=get_gloo_group()) - if dist.get_rank() == 0: - # int4/fp4 post_process - if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]: - post_process_weights( - restore_weights_before_load=False, - post_process_quantization=True, - rollout_engines=self.rollout_engines, - ) - ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) - dist.barrier(group=get_gloo_group()) - def _update_weight_from_distributed( self, name: str, From 7b80b7856ce10f88eb18351bd814fa70d7dd99e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=87=95=E4=B8=80?= Date: Wed, 10 Jun 2026 23:06:57 +0800 Subject: [PATCH 09/14] docs: add comprehensive MOPD guide for Qwen3.5 MoE models - Add GUIDE_qwen35_moe_mopd.md with detailed usage documentation - Cover MOPD workflow, distillation modes (TopK, full-vocab) - Document SGLang teacher server setup and configuration - Document multi-teacher domain routing and hyperparameters - Include troubleshooting and FAQ sections --- .../GUIDE_qwen35_moe_mopd.md | 619 ++++++++++++++++++ ...run-qwen35-397B-A17B-mopd-topk-megatron.sh | 4 +- .../run-qwen35-397B-A17B-mopd-topk-sglang.sh | 2 +- 3 files changed, 622 insertions(+), 3 deletions(-) create mode 100644 examples/multi_teacher_on_policy_distillation/GUIDE_qwen35_moe_mopd.md diff --git a/examples/multi_teacher_on_policy_distillation/GUIDE_qwen35_moe_mopd.md b/examples/multi_teacher_on_policy_distillation/GUIDE_qwen35_moe_mopd.md new file mode 100644 index 0000000000..1f613ac9b4 --- /dev/null +++ b/examples/multi_teacher_on_policy_distillation/GUIDE_qwen35_moe_mopd.md @@ -0,0 +1,619 @@ +# Qwen3.5 MoE 多教师在线策略蒸馏 (MOPD) 训练指南 + +本文档以 Qwen3.5 MoE 系列模型(35B-A3B、397B-A17B 等)为例,详细说明如何使用 slime 进行 MOPD (Multi-Teacher On-Policy Distillation) 训练,包括前期准备、参数配置、训练启动、checkpoint 转换和数据集构造。 + +--- + +## 目录 + +1. [整体流程概览](#1-整体流程概览) +2. [前期准备](#2-前期准备) + - 2.1 [HF 模型转 Megatron torch_dist 格式](#21-hf-模型转-megatron-torch_dist-格式) + - 2.2 [准备 Teacher 模型](#22-准备-teacher-模型) + - 2.3 [准备训练数据](#23-准备训练数据) +3. [SGLang 模式 vs Megatron 模式](#3-sglang-模式-vs-megatron-模式) +4. [启动 SGLang Teacher 服务](#4-启动-sglang-teacher-服务) +5. [训练脚本参数详解](#5-训练脚本参数详解) + - 5.1 [模型参数](#51-模型参数) + - 5.2 [Checkpoint 参数](#52-checkpoint-参数) + - 5.3 [Rollout 参数](#53-rollout-参数) + - 5.4 [MOPD 参数](#54-mopd-参数) + - 5.5 [性能参数](#55-性能参数) + - 5.6 [SGLang 参数](#56-sglang-参数) +6. [启动训练](#6-启动训练) +7. [训练后的 Checkpoint 转换](#7-训练后的-checkpoint-转换) + - 7.1 [使用 Bridge 模式转换(推荐)](#71-使用-bridge-模式转换推荐) + - 7.2 [使用手动映射转换](#72-使用手动映射转换) + - 7.3 [VLM 模型补齐 Visual Encoder 权重](#73-vlm-模型补齐-visual-encoder-权重) + - 7.4 [验证转换结果](#74-验证转换结果) + - 7.5 [选择最佳 Checkpoint](#75-选择最佳-checkpoint) +8. [数据集构造说明](#8-数据集构造说明) + - 8.1 [基本格式](#81-基本格式) + - 8.2 [多领域路由(可选)](#82-多领域路由可选) + - 8.3 [数据质量建议](#83-数据质量建议) +9. [常见问题](#9-常见问题) + +--- + +## 1. 整体流程概览 + +``` +┌─────────────────────────────────────────────────────────────────────┐ +│ 前期准备 │ +│ 1. HF → torch_dist 转换(学生模型 + 教师模型[Megatron模式]) │ +│ 2. 启动 SGLang 教师服务 [SGLang模式] │ +│ 3. 准备 JSONL 训练数据 │ +└──────────────────────────┬──────────────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────────────────┐ +│ 训练 │ +│ - 学生模型 rollout 生成响应 │ +│ - 教师模型获取 log-probs (SGLang HTTP / Megatron 前向传播) │ +│ - 计算 MOPD 损失 + 反向传播 │ +│ - 定期保存 Megatron checkpoint │ +└──────────────────────────┬──────────────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────────────────┐ +│ 推理部署 │ +│ 1. Megatron torch_dist → HF safetensors 转换 │ +│ 2. 补齐 VLM visual encoder 权重 │ +│ 3. SGLang / vLLM 推理 │ +└─────────────────────────────────────────────────────────────────────┘ +``` + +--- + +## 2. 前期准备 + +### 2.1 HF 模型转 Megatron torch_dist 格式 + +训练前需要将 HuggingFace 格式的学生模型转换为 Megatron torch_dist 格式。Qwen3.5 MoE 系列模型(含 VLM)**必须使用 `--megatron-to-hf-mode bridge`** 模式进行转换,因为它们包含 GDN (Gated DeltaNet) 线性注意力层、attention_output_gate、MTP 等自定义架构特性,只有 `megatron.bridge` 才能正确处理这些特殊参数的映射。 + +```bash +cd /path/to/slime +source scripts/models/qwen3.5-35B-A3B.sh # 或 qwen3.5-397B-A17B.sh + +PYTHONPATH=/root/Megatron-LM torchrun --nproc-per-node=8 \ + tools/convert_hf_to_torch_dist.py \ + --megatron-to-hf-mode bridge \ + ${MODEL_ARGS[@]} \ + --hf-checkpoint /path/to/Qwen3.5-35B-A3B \ + --save /path/to/Qwen3.5-35B-A3B_Torch_Dist_Bridge +``` + +**参数说明:** +- `source scripts/models/qwen3.5-35B-A3B.sh`:加载模型架构参数(层数、隐藏维度、MoE 配置等) +- `--megatron-to-hf-mode bridge`:**必须指定**。使用 `megatron.bridge` 进行权重映射,以正确处理 Qwen3.5 的自定义架构(GDN 线性注意力、attention_output_gate、visual encoder 等) +- `--hf-checkpoint`:原始 HuggingFace 模型目录(包含 config.json、safetensors 等) +- `--save`:输出的 torch_dist 检查点目录 +- `--nproc-per-node=8`:建议使用 8 GPU 并行转换,速度更快 + +**注意事项:** +- **不要使用 `--megatron-to-hf-mode raw`(默认值)**,raw 模式使用 `mbridge`,不支持 Qwen3.5 VLM 的自定义架构 +- VLM 模型的 visual encoder 权重会被加载到 Megatron 模型中,但**不会**被保存到 Megatron checkpoint(Megatron 只保存语言模型部分) +- 这些权重在后续转回 HF 时需要从原始模型补回(见[第 7 节](#7-训练后的-checkpoint-转换)) + +### 2.2 准备 Teacher 模型 + +根据教师模式不同: + +#### SGLang 模式(推荐) + +只需准备 HF/safetensors 格式的教师模型,用于启动 SGLang 推理服务。**不需要**转换为 torch_dist 格式。 + +```bash +# 教师模型只需是 SGLang 可加载的格式 +TEACHER_MODEL=/path/to/teacher_model_safetensors +``` + +#### Megatron 模式 + +需要将教师模型也转换为 torch_dist 格式: + +```bash +PYTHONPATH=/root/Megatron-LM torchrun --nproc-per-node=8 \ + tools/convert_hf_to_torch_dist.py \ + --megatron-to-hf-mode bridge \ + ${MODEL_ARGS[@]} \ + --hf-checkpoint /path/to/teacher_model \ + --save /path/to/teacher_model_torch_dist +``` + +> **注意:** Megatron 模式要求教师与学生**架构完全相同**。SGLang 模式则无此限制。 + +### 2.3 准备训练数据 + +训练数据使用 JSONL 格式,每行一个 JSON 对象。基本字段为 `messages`(对话格式)或 `prompt`(纯文本格式)。 + +详细格式见[第 8 节](#8-数据集构造说明)。 + +--- + +## 3. SGLang 模式 vs Megatron 模式 + +| 维度 | SGLang 模式 | Megatron 模式 | +|------|-----------|-------------| +| **教师运行位置** | 外部 SGLang 服务器 | 加载到训练进程 CPU 内存 | +| **教师架构要求** | 无限制(可与学生不同) | **必须与学生架构相同** | +| **CPU 内存开销** | 无额外开销 | 每个教师 ≈ 模型大小(397B ≈ 800GB/教师) | +| **支持蒸馏类型** | `token_level` + `top_k` | `token_level` + `top_k` + `full_vocab` | +| **top_k 尾部校正** | 精确计算(SGLang 返回归一化 log-probs) | 均匀分布估计(保守上界) | +| **故障处理** | 教师 503 时跳过(会触发 RuntimeError) | 无此问题 | +| **适用场景** | 教师架构不同、避免 CPU OOM | 需要全词表精确 KL、教师架构相同 | + +**推荐选择:** 对于 MoE 大模型,**强烈推荐 SGLang 模式**,因为: +- Megatron 模式需额外 ~800GB CPU 内存/教师,多节点容易 OOM +- SGLang 的 top_k 尾部校正更精确 +- 教师可以与学生架构不同(如用更大的教师蒸馏) + +--- + +## 4. 启动 SGLang Teacher 服务 + +在训练之前,需要先启动教师模型的 SGLang 推理服务。 + +```bash +# 多 GPU 启动教师模型 (TP=8, EP=16, 共 16 GPU) +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m sglang.launch_server \ + --model-path /path/to/teacher_model/ \ + --host 0.0.0.0 \ + --port 13141 \ + --tp 8 --ep-size 16 \ + --chunked-prefill-size 4096 \ + --mem-fraction-static 0.7 +``` + +**等待服务就绪:** + +```bash +until curl -sf http://localhost:13141/health_generate > /dev/null; do + echo "Waiting for teacher model server to start..." + sleep 10 +done +echo "Teacher server is ready!" +``` + +**关键参数说明:** +- `--tp 8`:张量并行数 +- `--ep-size 16`:专家并行数(MoE 模型需要) +- `--mem-fraction-static 0.7`:KV cache 显存占比 +- `--chunked-prefill-size 4096`:分块预填充大小 + +> **重要:** 教师模型和学生模型的**词表大小必须一致**,否则 top_k 的 token index 映射会出错。 + +--- + +## 5. 训练脚本参数详解 + +以 `run-qwen35-35B-A3B-mopd-topk-sglang.sh` 和 `run-qwen35-397B-A17B-mopd-topk-sglang.sh` 为例。 + +### 5.1 模型参数 + +模型架构参数通过 `source scripts/models/qwen3.5-35B-A3B.sh` 或 `source scripts/models/qwen3.5-397B-A17B.sh` 加载,核心参数: + +**35B-A3B:** +```bash +MODEL_ARGS=( + --spec "slime_plugins.models.qwen3_5" "get_qwen3_5_spec" + --num-attention-heads 16 # 注意力头数 + --num-query-groups 2 # GQA KV 组数 + --kv-channels 256 # Head 维度 + --num-layers 40 # 层数 + --hidden-size 2048 # 隐藏维度 + --num-experts 256 # MoE 专家数 + --moe-router-topk 8 # 每层激活专家数 + --attention-output-gate # Qwen3.5 特有:注意力输出门控 + --moe-shared-expert-gate # Qwen3.5 特有:共享专家门控 + # ... 其他参数见 scripts/models/qwen3.5-35B-A3B.sh +) +``` + +**397B-A17B:** +```bash +MODEL_ARGS=( + --spec "slime_plugins.models.qwen3_5" "get_qwen3_5_spec" + --num-attention-heads 32 + --num-query-groups 2 + --kv-channels 256 + --num-layers 60 + --hidden-size 4096 + --num-experts 512 + --moe-router-topk 10 + --attention-output-gate + --moe-shared-expert-gate + # ... 其他参数见 scripts/models/qwen3.5-397B-A17B.sh +) +``` + +### 5.2 Checkpoint 参数 + +```bash +CKPT_ARGS=( + --hf-checkpoint ${HF_CKPT}/ # 原始 HF 模型路径 + --load ${TORCH_DIST_CKPT}/ # torch_dist 初始/恢复检查点 + --save ${SAVE_DIR}/ # 训练检查点保存路径 + --save-interval 32 # 每 32 步保存一次 + --no-save-optim # 不保存优化器状态(节省磁盘) + --no-ckpt-fully-parallel-save +) +``` + +- `--hf-checkpoint`:SGLang rollout 引擎从中加载权重 +- `--load`:首次训练时指向 torch_dist 格式的初始权重,恢复训练时指向保存目录 +- `--save`:训练检查点保存路径 + +### 5.3 Rollout 参数 + +```bash +ROLLOUT_ARGS=( + --input-key messages # 数据中对话字段的 key + --apply-chat-template # 应用聊天模板 + --rollout-shuffle # 打乱数据 + --rollout-batch-size 4 # 每次 rollout 的 batch size + --n-samples-per-prompt 4 # 每个 prompt 采样次数 + --rollout-max-prompt-len 9216 # 最大 prompt 长度 + --rollout-max-response-len 2048 # 最大生成长度 + --rollout-temperature 0.8 # 采样温度 + + --global-batch-size 16 # 全局 batch size + --balance-data # 跨节点均衡数据 + --num-epoch 1 # 训练轮数 +) +``` + +### 5.4 MOPD 参数 + +```bash +MOPD_ARGS=( + --advantage-estimator grpo # 优势估计方法 + + # -- MOPD 核心 -- + --use-mopd # 启用 MOPD + --mopd-teacher-mode sglang # 教师模式:sglang 或 megatron + --mopd-distill-type top_k # 蒸馏类型:token_level / top_k / full_vocab + --mopd-topk-k 96 # top_k 保留的 token 数 + + # -- MOPD 超参 -- + --mopd-alpha 0.0 # α=0 纯蒸馏(无需奖励模型) + --mopd-eps-low 0.2 # IS 权重截断下界 + --mopd-eps-high 5.0 # IS 权重截断上界 + --mopd-sampling-logprobs-key rollout_log_probs +) +``` + +**参数详解:** + +| 参数 | 说明 | +|------|------| +| `--mopd-alpha` | 蒸馏与 ORM 的混合系数。0 = 纯蒸馏(无需奖励模型),>0 = 蒸馏 + RL 组合 | +| `--mopd-distill-type top_k` | **推荐**。每位置只传教师 top-k 个 token 的 logits,内存省 ~97% | +| `--mopd-distill-type token_level` | 每位置只传 1 个标量(教师对采样 token 的 log-prob),最省内存但精度低 | +| `--mopd-distill-type full_vocab` | 精确 KL,但内存开销极大(397B 词表 248K × batch × seq),**仅 Megatron 模式支持** | +| `--mopd-topk-k` | top_k 保留的 token 数。k=128 适合大多数场景,V>200K 时推荐 256+ | +| `--mopd-eps-low/eps-high` | IS 权重截断范围。紧范围(如[0.5,2])低方差高偏差;松范围(如[0.1,10])高方差低偏差 | + +**教师配置(环境变量):** + +```bash +# 教师列表:name=教师名称, domain=领域标识 +export MOPD_TEACHERS_JSON='[{"name":"math-teacher","domain":"math"},{"name":"code-teacher","domain":"code"}]' + +# SGLang 模式:domain -> URL 映射 +export MOPD_TEACHER_URLS="{\"math\":\"https://$MATH_TEACHER_IP:$PORT/generate\",\"code\":\"https://$CODE_TEACHER_IP:$PORT/generate\"}" +``` + +### 5.5 性能参数 + +**35B-A3B(4 节点 × 8 GPU = 32 GPU):** +```bash +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --expert-model-parallel-size 32 # 256 专家 / 32 EP = 8 专家/GPU + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 4 + --max-tokens-per-gpu 2048 + --train-memory-margin-bytes 536870912 +) +``` + +**397B-A17B(32 节点 × 8 GPU = 256 GPU):** +```bash +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --expert-model-parallel-size 128 # 512 专家 / 128 EP = 4 专家/GPU + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 4 + --max-tokens-per-gpu 2048 + --train-memory-margin-bytes 268435456 +) +``` + +### 5.6 SGLang 参数 + +```bash +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 16 # 每个 SGLang 引擎使用的 GPU 数 + --sglang-mem-fraction-static 0.10 # SGLang KV cache 显存占比 + --sglang-ep-size 16 # SGLang 推理的 EP 大小 +) +``` + +--- + +## 6. 启动训练 + +### SGLang 模式(推荐) + +```bash +# 35B-A3B +bash examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-topk-sglang.sh + +# 397B-A17B +bash examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh +``` + +前提条件: +1. SGLang Teacher 服务已启动并就绪 +2. 环境变量 `MOPD_TEACHERS_JSON` 和 `MOPD_TEACHER_URLS` 已设置 +3. 学生模型的 torch_dist 检查点已转换(使用 `--megatron-to-hf-mode bridge`) + +--- + +## 7. 训练后的 Checkpoint 转换 + +训练保存的是 Megatron torch_dist 格式,需要转回 HuggingFace safetensors 格式才能用于 SGLang/vLLM 推理。 + +Qwen3.5 MoE 系列模型包含自定义架构(GDN 线性注意力、attention_output_gate、VLM visual encoder 等),**必须使用 bridge 模式**进行反向转换,确保与正向转换的权重映射一致。 + +### 7.1 使用 Bridge 模式转换(推荐) + +这是 Qwen3.5 MoE VLM 模型的**推荐转换方式**,使用 `megatron.bridge` 进行端到端转换,与正向转换(`--megatron-to-hf-mode bridge`)保持映射一致。 + +```bash +PYTHONPATH=/root/Megatron-LM python tools/convert_torch_dist_to_hf_bridge.py \ + --input-dir /path/to/save_dir/iter_0000009/ \ + --output-dir /path/to/output_hf \ + --origin-hf-dir /path/to/original/Qwen3.5-35B-A3B \ + --force +``` + +**关键参数:** +- `--input-dir`:训练产出的 torch_dist 检查点目录(包含 `common.pt` 和 `.metadata` 等文件) +- `--output-dir`:输出的 HF safetensors 目录 +- `--origin-hf-dir`:**原始 HF 模型目录**(提供 config.json 用于推断模型架构 + 补齐缺失的 visual encoder 权重) +- `--force`:覆写已存在的输出目录 + +**为什么必须使用 bridge 模式?** + +Qwen3.5 MoE VLM 模型有以下特殊性,手动映射脚本 (`convert_torch_dist_to_hf.py`) 无法正确处理: + +| 特性 | 手动映射脚本 | Bridge 模式 | +|------|-------------|------------| +| `attention_output_gate` (QKV+G 融合) | 只做 Q/K/V 三分拆,**gate 权重丢失** | 正确的 Q/G/K/V 四分拆 | +| Visual encoder (`model.visual.*`) | 未处理,遇到即报错 | 自动映射 `vision_model.**` → `model.visual.**` | +| GDN 线性注意力层参数 | 部分覆盖 | 完整映射 | +| Expert 权重融合/拆分 | 依赖 `common.pt` 中的 args 推断 | Bridge 自动处理 fused/per-expert 格式 | + +**重要:** 确保远端 slime 代码是最新的,`convert_torch_dist_to_hf_bridge.py` 必须包含 `import slime_plugins.megatron_bridge` 以注册自定义 `Qwen35VLMoeBridge`。否则会使用官方 bridge,导致 vision encoder 和 MTP 层参数不匹配。 + +### 7.2 使用手动映射转换(非 VLM 模型可选) + +对于**非 VLM** 的 Qwen3.5 MoE 模型(如纯语言模型版本),可以使用手动映射脚本,支持多进程并行加速: + +```bash +# 单进程版本 +PYTHONPATH=/root/Megatron-LM python tools/convert_torch_dist_to_hf.py \ + --input-dir /path/to/iter_0000009/ \ + --output-dir /path/to/output_hf \ + --origin-hf-dir /path/to/original/Qwen3.5-35B-A3B \ + --model-name qwen3_5_moe \ + --add-missing-from-origin-hf \ + --force + +# 多进程并行版本(更快) +PYTHONPATH=/root/Megatron-LM python tools/convert_torch_dist_to_hf_parallel.py \ + --input-dir /path/to/iter_0000009/ \ + --output-dir /path/to/output_hf \ + --origin-hf-dir /path/to/original/Qwen3.5-35B-A3B \ + --model-name qwen3_5_moe \ + --add-missing-from-origin-hf \ + --force +``` + +> **警告:** VLM 模型(如 `Qwen3_5MoeForConditionalGeneration`)**不要**使用此脚本,因为缺少 `attention_output_gate` 和 visual encoder 的映射。 + +### 7.3 VLM 模型补齐 Visual Encoder 权重 + +Qwen3.5 MoE VLM 模型的 Megatron 检查点**只包含语言模型权重**,不包含 visual encoder(`model.visual.*`)。如果不补齐,推理时会出现 `probability tensor contains inf/nan or element < 0` 错误。 + +**方法 1:自动补齐**(Bridge 模式已内置) + +使用 `--origin-hf-dir` 且包含完整 HF 模型时,bridge 模式会通过 `ReplicatedMapping("vision_model.**", "model.visual.**")` 映射,但训练 checkpoint 中不包含 visual encoder 权重。如果转换后缺少 visual encoder 权重,使用 `merge_missing_keys.py` 补齐: + +```bash +python tools/merge_missing_keys.py \ + --origin-hf-dir /path/to/original/Qwen3.5-35B-A3B \ + --converted-dir /path/to/output_hf \ + --dry-run # 先预览缺失的 key +``` + +去掉 `--dry-run` 执行实际补齐。 + +**方法 2:手动映射脚本的 `--add-missing-from-origin-hf`** + +使用 `convert_torch_dist_to_hf.py` 时加 `--add-missing-from-origin-hf` 会自动从原始 HF 模型补充缺失的权重。 + +### 7.4 验证转换结果 + +```bash +# 检查 key 数量 +python3 -c " +import json +with open('/path/to/output_hf/model.safetensors.index.json') as f: + idx = json.load(f) +print(f'Total keys: {len(idx[\"weight_map\"])}') +# 对比原始模型: +import os +origin_keys = set() +from safetensors import safe_open +for f in sorted(os.listdir('/path/to/original/Qwen3.5-35B-A3B')): + if f.endswith('.safetensors'): + with safe_open(f'/path/to/original/Qwen3.5-35B-A3B/{f}', framework='pt', device='cpu') as sf: + origin_keys.update(sf.keys()) +print(f'Original model keys: {len(origin_keys)}') +missing = origin_keys - set(idx['weight_map'].keys()) +if missing: + print(f'WARNING: Missing keys ({len(missing)}): {sorted(missing)[:10]}...') +else: + print('All keys present!') +" + +# 检查是否有 NaN/Inf 权重 +python3 -c " +from safetensors import safe_open +import json, glob, os +idx = json.load(open('/path/to/output_hf/model.safetensors.index.json')) +files = set(idx['weight_map'].values()) +for f in sorted(files): + path = f'/path/to/output_hf/{f}' + with safe_open(path, framework='pt', device='cpu') as sf: + for k in sf.keys(): + t = sf.get_tensor(k) + if t.isnan().any() or t.isinf().any(): + print(f'ERROR: {k} has NaN/Inf!') +print('Validation complete.') +" +``` + +### 7.5 选择最佳 Checkpoint + +训练过程中每 `--save-interval` 步保存一次 checkpoint。选择最佳 checkpoint 的建议: + +- **关注 `mopd_topk_kl` 指标**:应该持续下降,代表学生与教师的分布差距在缩小 +- **关注 `entropy` 指标**:应保持相对稳定,突然暴跌说明模式坍塌 +- **避免 loss=0 的 checkpoint**:如果出现教师服务不可用导致的 0 梯度步骤,该 checkpoint 的权重可能已退化 +- 一般选择 KL 收敛到较低点且 entropy 仍然健康的 checkpoint + +--- + +## 8. 数据集构造说明 + +### 8.1 基本格式 + +训练数据为 JSONL 格式,每行一个 JSON 对象。支持两种输入字段: + +**对话格式(推荐):** + +```jsonl +{"messages": [{"role": "user", "content": "Explain the concept of gradient descent in machine learning."}, {"role": "assistant", "content": "Gradient descent is an optimization algorithm..."}]} +``` + +配合 `--input-key messages --apply-chat-template` 使用。 + +**纯文本格式:** + +```jsonl +{"prompt": "Explain the concept of gradient descent in machine learning."} +``` + +配合 `--input-key prompt` 使用。 + +### 8.2 多领域路由(可选) + +当有多个教师分别负责不同领域时,可以在 `metadata` 中指定每个样本应从哪个教师蒸馏。`mopd_domains` 的值必须与 `MOPD_TEACHERS_JSON` 中对应教师的 `domain` 字段匹配。 + +例如,当教师配置为: +```bash +export MOPD_TEACHERS_JSON='[{"name":"math-teacher","domain":"math"},{"name":"code-teacher","domain":"code"}]' +``` + +数据集可以这样指定领域路由: + +```jsonl +{"messages": [...], "metadata": {"mopd_domains": ["math"]}} +{"messages": [...], "metadata": {"mopd_domains": ["code"]}} +{"messages": [...], "metadata": {"mopd_domains": ["math", "code"]}} +{"messages": [...]} +``` + +说明: +- `"mopd_domains": ["math"]` — 仅从 `math` 领域的教师(math-teacher)蒸馏 +- `"mopd_domains": ["math", "code"]` — 同时从两个教师蒸馏 +- 无 `mopd_domains` 字段 — 从**所有**教师蒸馏(默认行为) +- 也支持字符串简写:`"mopd_domains": "math"` + +### 8.3 数据质量建议 + +1. **数据多样性**:覆盖目标领域的各种子任务,避免过度集中 +2. **长度分布**:控制 prompt 长度分布,避免超长样本浪费 rollout 资源 +3. **batch size 匹配**:`--rollout-batch-size` 应与 `--global-batch-size` 一致 +4. **教师容量**:确保 SGLang 教师服务能处理并发请求量(与 rollout batch size 成正比) + +--- + +## 9. 常见问题 + +### Q1: MOPD 和 OPD 能同时使用吗? + +不能。`--use-mopd` 和 `--use-opd` 互斥。 + +### Q2: alpha=0 时需要奖励模型吗? + +不需要。`--mopd-alpha 0.0` 是纯蒸馏模式,无需 `--rm-type`。系统会自动将 reward 设为 0。 + +### Q3: SGLang 教师服务 503 会怎样? + +如果教师在 rollout 期间返回 503(服务不可用),该 batch 的教师数据会被跳过,代码会打印 warning 日志。在纯蒸馏模式(`--mopd-alpha 0.0`)下,跳过教师数据会导致该 batch 的 MOPD 损失为 0,总损失也为 0,梯度为 0,**训练实质上空转一步**。如果持续 503,模型会因为长期零梯度而退化。建议确保教师服务稳定运行。 + +### Q4: top_k 的 k 值怎么选? + +- `k=128`:适合大多数场景,内存最低 +- `k=1024`:更精确的 KL 近似,内存稍高 +- 经验法则:`k/V > 0.05%` 即可捕获 >99% 的 KL 信号。V=248K 时 k≥128 即可 + +### Q5: 转换后推理报 `probability tensor contains inf/nan` 怎么办? + +这通常是因为 VLM 模型的 visual encoder 权重缺失。使用 `merge_missing_keys.py` 补齐: + +```bash +python tools/merge_missing_keys.py \ + --origin-hf-dir /path/to/original/model \ + --converted-dir /path/to/converted/model \ + --dry-run # 先预览缺失的 key +``` + +### Q6: 训练中出现 loss=0 和 grad_norm=0 怎么办? + +这通常意味着某个 rollout batch 的教师数据获取失败(如 SGLang 503),导致 MOPD 损失被跳过。在纯蒸馏模式(alpha=0)下,跳过 MOPD 损失后总损失为 0,不产生梯度,训练空转一步。如果频繁出现,模型会因长期零梯度而退化。检查教师服务日志,确保服务稳定。 + +### Q7: EP=128 时 DeepEP 报断言错误? + +注释掉 `--moe-enable-deepep`。当前 EP=128 时 DeepEP 的 inter-node kernel 有已知问题,使用默认的 alltoall 即可。 + +### Q8: 如何恢复中断的训练? + +只要 `--load` 和 `--save` 指向同一目录,训练会自动加载最新 checkpoint 继续训练。确保 `--save-interval` 设置合理以避免丢失过多进度。 + +### Q9: convert_torch_dist_to_hf_bridge.py 报 `TypeError: object of type '_io.BytesIO' has no len()` 怎么办? + +这通常是因为 `slime_plugins.megatron_bridge` 模块未被注册,导致 `AutoBridge` 使用了官方的 bridge(而非自定义的 `Qwen35VLMoeBridge`),模型结构与 checkpoint 不匹配。确保脚本中包含以下导入: + +```python +import slime_plugins.megatron_bridge # noqa: F401 # register custom bridges before AutoBridge +``` + +如果看到日志中 `Using Bridge provider: Qwen35VLMoEModelProvider`(官方),说明自定义 bridge 未注册;正确应显示 `Qwen35VLMoeVLModelProvider`。 + +### Q10: convert_torch_dist_to_hf_bridge.py 报大量 vision_model / mtp 参数 "not in state dict"? + +同 Q9,这是因为使用了错误的 bridge。官方 bridge 期望 `vision_model.decoder.layers.*`(Megatron-native 命名),而 VLM 训练使用的是 HF 命名的 `vision_model.blocks.*`。注册自定义 bridge 后,映射会变为 `vision_model.**` → `model.visual.**`,问题自动解决。 diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-megatron.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-megatron.sh index f68986e418..304d3c7a49 100644 --- a/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-megatron.sh +++ b/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-megatron.sh @@ -3,7 +3,7 @@ # Multi-Teacher On-Policy Distillation (MOPD) — Top-K KL Divergence Mode # Model: Qwen3.5-397B-A17B (MoE, 512 experts, 10 active) # Environment: 16 nodes × 8 L20X (143GB each), 128 GPUs total -# Teacher: Skin-multiturn teacher (different from student for production distillation) +# Teacher: Teacher model (different from student for production distillation) # Mode: Megatron (teacher loaded into CPU memory via TensorBackuper) # Distill Type: top_k (approximate reverse KL with top-k teacher logits + tail correction) # @@ -58,7 +58,7 @@ SAVE_DIR=/amed/share/s1-amed-spfs-ckpt/yanyi/Qwen3.5-397B-A17B-Stage3b-Mopd-Topk DATA_PATH="/mnt/amed-s3/dataset/14019ba0_text_report_Interpretation/a3967912440becb0d70748a478696f12b6bbf6ac/train_text_think_nothink.jsonl" # MOPD teachers JSON config -export MOPD_TEACHERS_JSON='[{"name":"skin-multiturn","domain":"default"}]' +export MOPD_TEACHERS_JSON='[{"name":"teacher","domain":"default"}]' # ============================================================================ # Configure training arguments diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh index df8f07b040..8c509ec362 100755 --- a/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh +++ b/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh @@ -71,7 +71,7 @@ SAVE_DIR=/amed/share/s1-amed-spfs-ckpt/yanyi/Qwen3.5-397B-A17B-Stage3b-Mopd-Topk DATA_PATH="/mnt/amed-s3/dataset/14019ba0_text_report_Interpretation/a3967912440becb0d70748a478696f12b6bbf6ac/train_text_think_nothink.jsonl" # MOPD teachers JSON config (single teacher for this example) -export MOPD_TEACHERS_JSON='[{"name":"skin-multiturn","domain":"default"}]' +export MOPD_TEACHERS_JSON='[{"name":"teacher","domain":"default"}]' # MOPD teacher SGLang server URLs # For multi-teacher, add all domains: {"math":"https://...","code":"https://..."} From 9de50a42ecc8baf802158f0f990fe17f2d5c03ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=87=95=E4=B8=80?= Date: Wed, 10 Jun 2026 23:10:42 +0800 Subject: [PATCH 10/14] fix: handle too-long multimodal inputs in filter_long_prompt - Fix filter_long_prompt to correctly process multimodal inputs when apply_chat_template has already converted prompt to a string - Add 'messages' field to Sample dataclass to preserve raw message list for multimodal processing after chat template application - Ensures vision info (images) can be extracted from original messages even when prompt has been templated --- ...qwen35-35B-A3B-mopd-full-vocab-megatron.sh | 14 ++++++------- .../run-qwen35-35B-A3B-mopd-megatron.sh | 20 +++++++++---------- .../run-qwen35-35B-A3B-mopd-sglang.sh | 20 +++++++++---------- .../run-qwen35-35B-A3B-mopd-topk-sglang.sh | 2 +- ...run-qwen35-397B-A17B-mopd-topk-megatron.sh | 16 +++++++-------- .../run-qwen35-397B-A17B-mopd-topk-sglang.sh | 18 ++++++++--------- slime/utils/data.py | 10 +++++++++- slime/utils/types.py | 5 ++++- 8 files changed, 58 insertions(+), 47 deletions(-) diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-full-vocab-megatron.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-full-vocab-megatron.sh index 6e2e6b17e0..1a853fdd8d 100644 --- a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-full-vocab-megatron.sh +++ b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-full-vocab-megatron.sh @@ -47,7 +47,7 @@ else fi echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" -source "/mntfn/yanyi/code/slime/scripts/models/qwen3.5-35B-A3B.sh" +source "/path/to/slime/scripts/models/qwen3.5-35B-A3B.sh" # MOPD teachers JSON config export MOPD_TEACHERS_JSON='[{"name":"self_teacher","domain":"default"}]' @@ -57,15 +57,15 @@ export MOPD_TEACHERS_JSON='[{"name":"self_teacher","domain":"default"}]' # ============================================================================ CKPT_ARGS=( - --hf-checkpoint /mnt4/data/open_source/Qwen3.5-35B-A3B/ - --ref-load /mnt4/data/open_source/Qwen3.5-35B-A3B_torch_dist/ - --load /mnt4/data/zhixiaobao/yanyi/Qwen3.5-35B-A3B-mopd-full-vocab-test/ - --save /mnt4/data/zhixiaobao/yanyi/Qwen3.5-35B-A3B-mopd-full-vocab-test/ + --hf-checkpoint /path/to/checkpoints/Qwen3.5-35B-A3B/ + --ref-load /path/to/checkpoints/Qwen3.5-35B-A3B_torch_dist/ + --load /path/to/output/Qwen3.5-35B-A3B-mopd-full-vocab-test/ + --save /path/to/output/Qwen3.5-35B-A3B-mopd-full-vocab-test/ --save-interval 10 ) ROLLOUT_ARGS=( - --prompt-data /mntfn/yanyi/dataset/train_text_user_only.jsonl + --prompt-data /path/to/dataset/train_text_user_only.jsonl --input-key messages --apply-chat-template --rollout-shuffle @@ -151,7 +151,7 @@ MOPD_ARGS=( --mopd-distill-type full_vocab # Teacher checkpoint = same as ref model (self-distillation for validation) - --mopd-teacher-loads /mnt4/data/open_source/Qwen3.5-35B-A3B_torch_dist/ + --mopd-teacher-loads /path/to/checkpoints/Qwen3.5-35B-A3B_torch_dist/ # MOPD hyperparameters --mopd-alpha 0.0 # Pure distillation, no ORM diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-megatron.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-megatron.sh index c8ded7428b..85ec997b5d 100644 --- a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-megatron.sh +++ b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-megatron.sh @@ -39,7 +39,7 @@ else fi echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" -source "/mntfn/yanyi/code/slime/scripts/models/qwen3.5-35B-A3B.sh" +source "/path/to/slime/scripts/models/qwen3.5-35B-A3B.sh" # MOPD teachers JSON config # Set as environment variable; arguments.py reads $MOPD_TEACHERS_JSON @@ -54,24 +54,24 @@ export MOPD_TEACHERS_JSON='[{"name":"self_teacher","domain":"default"}]' # IMPORTANT: Before running this script, convert the HF checkpoint to Megatron # torch_dist format: # -# cd /mntfn/yanyi/code/slime +# cd /path/to/slime # source scripts/models/qwen3.5-35B-A3B.sh # # PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ # ${MODEL_ARGS[@]} \ -# --hf-checkpoint /mnt4/data/open_source/Qwen3.5-35B-A3B \ -# --save /mnt4/data/open_source/Qwen3.5-35B-A3B_torch_dist +# --hf-checkpoint /path/to/checkpoints/Qwen3.5-35B-A3B \ +# --save /path/to/checkpoints/Qwen3.5-35B-A3B_torch_dist CKPT_ARGS=( - --hf-checkpoint /mnt4/data/open_source/Qwen3.5-35B-A3B/ - --ref-load /mnt4/data/open_source/Qwen3.5-35B-A3B_torch_dist/ - --load /mnt4/data/zhixiaobao/yanyi/Qwen3.5-35B-A3B-mopd-test/ - --save /mnt4/data/zhixiaobao/yanyi/Qwen3.5-35B-A3B-mopd-test/ + --hf-checkpoint /path/to/checkpoints/Qwen3.5-35B-A3B/ + --ref-load /path/to/checkpoints/Qwen3.5-35B-A3B_torch_dist/ + --load /path/to/output/Qwen3.5-35B-A3B-mopd-test/ + --save /path/to/output/Qwen3.5-35B-A3B-mopd-test/ --save-interval 10 ) ROLLOUT_ARGS=( - --prompt-data /mntfn/yanyi/dataset/train_text_user_only.jsonl + --prompt-data /path/to/dataset/train_text_user_only.jsonl --input-key messages --apply-chat-template --rollout-shuffle @@ -131,7 +131,7 @@ MOPD_ARGS=( # If --mopd-teachers is not set, arguments.py falls back to $MOPD_TEACHERS_JSON. # Teacher checkpoint = same as ref model (self-distillation for validation) - --mopd-teacher-loads /mnt4/data/open_source/Qwen3.5-35B-A3B_torch_dist/ + --mopd-teacher-loads /path/to/checkpoints/Qwen3.5-35B-A3B_torch_dist/ # MOPD hyperparameters --mopd-alpha 0.0 # Pure distillation, no ORM diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-sglang.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-sglang.sh index 22ab385e8f..6b79f8e762 100644 --- a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-sglang.sh +++ b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-sglang.sh @@ -39,7 +39,7 @@ TEACHER_LOG_FILE="/tmp/sglang_teacher_$(head /dev/urandom | tr -dc A-Za-z0-9 | h # Launch teacher on GPU 0-3 (4 GPUs for TP=4, or adjust TP as needed) CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m sglang.launch_server \ - --model-path /mnt4/data/open_source/Qwen3.5-35B-A3B/ \ + --model-path /path/to/checkpoints/Qwen3.5-35B-A3B/ \ --host 0.0.0.0 \ --port $TEACHER_PORT \ --tp 4 \ @@ -71,7 +71,7 @@ else fi echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" -source "/mntfn/yanyi/code/slime/scripts/models/qwen3.5-35B-A3B.sh" +source "/path/to/slime/scripts/models/qwen3.5-35B-A3B.sh" # MOPD teachers JSON config export MOPD_TEACHERS_JSON='[{"name":"self_teacher","domain":"default"}]' @@ -86,24 +86,24 @@ export MOPD_TEACHER_URLS="{\"default\":\"http://$TEACHER_IP:$TEACHER_PORT/genera # IMPORTANT: Before running this script, convert the HF checkpoint to Megatron # torch_dist format: # -# cd /mntfn/yanyi/code/slime +# cd /path/to/slime # source scripts/models/qwen3.5-35B-A3B.sh # # PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ # ${MODEL_ARGS[@]} \ -# --hf-checkpoint /mnt4/data/open_source/Qwen3.5-35B-A3B \ -# --save /mnt4/data/open_source/Qwen3.5-35B-A3B_torch_dist +# --hf-checkpoint /path/to/checkpoints/Qwen3.5-35B-A3B \ +# --save /path/to/checkpoints/Qwen3.5-35B-A3B_torch_dist CKPT_ARGS=( - --hf-checkpoint /mnt4/data/open_source/Qwen3.5-35B-A3B/ - --ref-load /mnt4/data/open_source/Qwen3.5-35B-A3B_torch_dist/ - --load /mnt4/data/zhixiaobao/yanyi/Qwen3.5-35B-A3B-mopd-test/ - --save /mnt4/data/zhixiaobao/yanyi/Qwen3.5-35B-A3B-mopd-test/ + --hf-checkpoint /path/to/checkpoints/Qwen3.5-35B-A3B/ + --ref-load /path/to/checkpoints/Qwen3.5-35B-A3B_torch_dist/ + --load /path/to/output/Qwen3.5-35B-A3B-mopd-test/ + --save /path/to/output/Qwen3.5-35B-A3B-mopd-test/ --save-interval 10 ) ROLLOUT_ARGS=( - --prompt-data /mntfn/yanyi/dataset/train_text_user_only.jsonl + --prompt-data /path/to/dataset/train_text_user_only.jsonl --input-key messages --apply-chat-template --rollout-shuffle diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-topk-sglang.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-topk-sglang.sh index f5e5757ef8..7630d8b739 100644 --- a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-topk-sglang.sh +++ b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-topk-sglang.sh @@ -103,7 +103,7 @@ PERF_ARGS=( # --use-dynamic-batch-size --max-tokens-per-gpu 2048 - --train-memory-margin-bytes 536870912 + --train-memory-margin-bytes 268435456 ) MOPD_ARGS=( diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-megatron.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-megatron.sh index 304d3c7a49..675fb16677 100644 --- a/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-megatron.sh +++ b/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-megatron.sh @@ -24,8 +24,8 @@ # PYTHONPATH=/root/Megatron-LM torchrun --nproc_per_node=8 \ # tools/convert_hf_to_torch_dist.py \ # ${MODEL_ARGS[@]} \ -# --hf-checkpoint /personal/ckpt/Qwen3.5-397B-A17B_skin_multiturn \ -# --save /personal/ckpt/Qwen3.5-397B-A17B_skin_multiturn_torch_dist +# --hf-checkpoint /path/to/Qwen3.5-397B-A17B_teacher \ +# --save /path/to/Qwen3.5-397B-A17B_teacher_torch_dist # # usage: bash examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-megatron.sh @@ -48,14 +48,14 @@ source "${SLIME_DIR}/scripts/models/qwen3.5-397B-A17B.sh" # ============================================================================ # Paths — adjust these to your environment # ============================================================================ -BASE_DIR=/personal/ckpt +BASE_DIR=/path/to/checkpoints -HF_CKPT=${BASE_DIR}/Qwen3.5-397B-A17B_Swift_SFT_Stage3b_Text1p5 -TORCH_DIST_CKPT=${BASE_DIR}/Qwen3.5-397B-A17B_Swift_SFT_Stage3b_Text1p5_torch_dist -TEACHER_TORCH_DIST_CKPT=${BASE_DIR}/Qwen3.5-397B-A17B_skin_multiturnl_torch_dist -SAVE_DIR=/amed/share/s1-amed-spfs-ckpt/yanyi/Qwen3.5-397B-A17B-Stage3b-Mopd-Topk-Skin-Multiturn-Enhanced +HF_CKPT=${BASE_DIR}/Qwen3.5-397B-A17B +TORCH_DIST_CKPT=${BASE_DIR}/Qwen3.5-397B-A17B_torch_dist +TEACHER_TORCH_DIST_CKPT=${BASE_DIR}/Qwen3.5-397B-A17B_teacher_torch_dist +SAVE_DIR=${BASE_DIR}/Qwen3.5-397B-A17B-MOPD-TopK-Output -DATA_PATH="/mnt/amed-s3/dataset/14019ba0_text_report_Interpretation/a3967912440becb0d70748a478696f12b6bbf6ac/train_text_think_nothink.jsonl" +DATA_PATH="/path/to/your/training_data.jsonl" # MOPD teachers JSON config export MOPD_TEACHERS_JSON='[{"name":"teacher","domain":"default"}]' diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh index 8c509ec362..d2fb92657f 100755 --- a/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh +++ b/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh @@ -25,7 +25,7 @@ # Example for a single 397B MoE teacher on 16 GPUs: # # python3 -m sglang.launch_server \ -# --model-path /personal/ckpt/Qwen3.5-397B-A17B_skin_multiturn/ \ +# --model-path /path/to/Qwen3.5-397B-A17B_teacher/ \ # --host 0.0.0.0 --port 13141 \ # --tp 8 --ep-size 16 \ # --chunked-prefill-size 4096 \ @@ -38,8 +38,8 @@ # PYTHONPATH=/root/Megatron-LM torchrun --nproc_per_node=8 \ # tools/convert_hf_to_torch_dist.py \ # ${MODEL_ARGS[@]} \ -# --hf-checkpoint /personal/ckpt/Qwen3.5-397B-A17B_Swift_SFT_Stage3b_Text1p5 \ -# --save /personal/ckpt/Qwen3.5-397B-A17B_Swift_SFT_Stage3b_Text1p5_torch_dist +# --hf-checkpoint /path/to/Qwen3.5-397B-A17B \ +# --save /path/to/Qwen3.5-397B-A17B_torch_dist # # usage: bash examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-sglang.sh @@ -62,20 +62,20 @@ source "${SLIME_DIR}/scripts/models/qwen3.5-397B-A17B.sh" # ============================================================================ # Paths — adjust these to your environment # ============================================================================ -BASE_DIR=/personal/ckpt +BASE_DIR=/path/to/checkpoints -HF_CKPT=${BASE_DIR}/Qwen3.5-397B-A17B_Swift_SFT_Stage3b_Text1p5 -TORCH_DIST_CKPT=${BASE_DIR}/Qwen3.5-397B-A17B_Swift_SFT_Stage3b_Text1p5_torch_dist -SAVE_DIR=/amed/share/s1-amed-spfs-ckpt/yanyi/Qwen3.5-397B-A17B-Stage3b-Mopd-Topk-Skin-Multiturn-Enhanced +HF_CKPT=${BASE_DIR}/Qwen3.5-397B-A17B +TORCH_DIST_CKPT=${BASE_DIR}/Qwen3.5-397B-A17B_torch_dist +SAVE_DIR=${BASE_DIR}/Qwen3.5-397B-A17B-MOPD-TopK-Output -DATA_PATH="/mnt/amed-s3/dataset/14019ba0_text_report_Interpretation/a3967912440becb0d70748a478696f12b6bbf6ac/train_text_think_nothink.jsonl" +DATA_PATH="/path/to/your/training_data.jsonl" # MOPD teachers JSON config (single teacher for this example) export MOPD_TEACHERS_JSON='[{"name":"teacher","domain":"default"}]' # MOPD teacher SGLang server URLs # For multi-teacher, add all domains: {"math":"https://...","code":"https://..."} -TEACHER_IP="aistudio.alipay.com/proxy/rayjob/aistudio-dvm9s0jw-tfjob-master-0" +TEACHER_IP="your-teacher-server-host" TEACHER_PORT=8300 export MOPD_TEACHER_URLS="{\"default\":\"https://$TEACHER_IP:$TEACHER_PORT/generate\"}" diff --git a/slime/utils/data.py b/slime/utils/data.py index 4bb81e5677..ea303caed6 100644 --- a/slime/utils/data.py +++ b/slime/utils/data.py @@ -108,7 +108,12 @@ def filter_long_prompt(origin_samples: list[Sample], tokenizer, processor, max_l from slime.utils.processing_utils import process_vision_info for sample in multimodal: - multimodal_inputs = process_vision_info(sample.prompt, processor) + # When apply_chat_template=True, sample.prompt is a templated + # string while process_vision_info needs the original message + # list to extract images. processor(text=…) expects the + # templated string (sample.prompt). + messages = sample.messages if sample.messages is not None else sample.prompt + multimodal_inputs = process_vision_info(messages, processor) processor_output = processor(text=sample.prompt, **multimodal_inputs) input_ids = processor_output["input_ids"][0] if len(input_ids) <= max_length: @@ -250,6 +255,9 @@ def __init__( origin_samples.append( Sample( prompt=output_prompt, + # Preserve raw message list for multimodal processing when + # apply_chat_template has converted prompt to a string. + messages=prompt if isinstance(prompt, list) else None, label=data[label_key] if label_key is not None else None, metadata=metadata, multimodal_inputs=multimodal_inputs, diff --git a/slime/utils/types.py b/slime/utils/types.py index f19d643389..2f9f441128 100644 --- a/slime/utils/types.py +++ b/slime/utils/types.py @@ -11,8 +11,11 @@ class Sample: group_index: int | None = None index: int | None = None - # prompt + # prompt (after chat template — always a str when apply_chat_template=True) prompt: str | list[dict[str, str]] = "" + # Raw message list before chat template application; needed for multimodal + # processing in filter_long_prompt when prompt has already been templated. + messages: list[dict] | None = None tokens: list[int] = field(default_factory=list) multimodal_inputs: dict[str, Any] | None = None # raw multimodal data, e.g. images, videos, etc. multimodal_train_inputs: dict[str, Any] | None = None # processed multimodal data, e.g. pixel_values, etc. From 7cc85b2bb6aa8445847571277d94b9450dbee3f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=87=95=E4=B8=80?= Date: Thu, 11 Jun 2026 10:37:08 +0800 Subject: [PATCH 11/14] chore: remove example scripts with known OOM issues Remove megatron-mode and full-vocab example scripts that have known OOM problems. Keep only the validated SGLang TopK scripts: - run-qwen35-397B-A17B-mopd-topk-sglang.sh - run-qwen35-35B-A3B-mopd-topk-sglang.sh --- ...qwen35-35B-A3B-mopd-full-vocab-megatron.sh | 250 ----------------- .../run-qwen35-35B-A3B-mopd-megatron.sh | 234 ---------------- .../run-qwen35-35B-A3B-mopd-sglang.sh | 264 ------------------ ...run-qwen35-397B-A17B-mopd-topk-megatron.sh | 220 --------------- 4 files changed, 968 deletions(-) delete mode 100644 examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-full-vocab-megatron.sh delete mode 100644 examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-megatron.sh delete mode 100644 examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-sglang.sh delete mode 100644 examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-megatron.sh diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-full-vocab-megatron.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-full-vocab-megatron.sh deleted file mode 100644 index 1a853fdd8d..0000000000 --- a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-full-vocab-megatron.sh +++ /dev/null @@ -1,250 +0,0 @@ -#!/bin/bash - -# Multi-Teacher On-Policy Distillation (MOPD) — Full-Vocabulary KL Divergence Mode -# Model: Qwen3.5-35B-A3B (MoE, 256 experts, 8 active) -# Environment: 8× H20 (143GB) -# Teacher: Same as student (self-distillation for connectivity validation only) -# Mode: Megatron (teacher loaded into CPU memory via TensorBackuper) -# Distill Type: full_vocab (exact full-vocabulary reverse KL D_KL(π_θ ∥ π_d)) -# -# This script is for MOPD full_vocab E2E connectivity validation. -# In production, use a DIFFERENT (stronger) model as teacher. -# -# Key difference from token_level mode: -# --mopd-distill-type full_vocab -# → Computes exact D_KL(π_θ ∥ π_d) over full vocabulary instead of -# approximating from sampled tokens. Requires megatron teacher mode. -# → Uses full logits [R, V] instead of per-token log-probs, which -# increases memory usage significantly. -# -# Parallelism: TP=2, EP=8 (matches SFT config, 256 experts / 8 = 32 per GPU) -# Colocate mode: rollout and training share all 8 GPUs with offloading -# -# usage: bash examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-full-vocab-megatron.sh - -# ============================================================================ -# Cleanup: kill existing SGLang / Ray / Python processes -# ============================================================================ -pkill -9 sglang -sleep 3 -ray stop --force 2>/dev/null || true -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python - -set -ex - -export PYTHONBUFFERED=16 -export FLASHINFER_DISABLE_VERSION_CHECK=1 - -NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) -if [ "$NVLINK_COUNT" -gt 0 ]; then - HAS_NVLINK=1 -else - HAS_NVLINK=0 -fi -echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" - -source "/path/to/slime/scripts/models/qwen3.5-35B-A3B.sh" - -# MOPD teachers JSON config -export MOPD_TEACHERS_JSON='[{"name":"self_teacher","domain":"default"}]' - -# ============================================================================ -# Configure training arguments -# ============================================================================ - -CKPT_ARGS=( - --hf-checkpoint /path/to/checkpoints/Qwen3.5-35B-A3B/ - --ref-load /path/to/checkpoints/Qwen3.5-35B-A3B_torch_dist/ - --load /path/to/output/Qwen3.5-35B-A3B-mopd-full-vocab-test/ - --save /path/to/output/Qwen3.5-35B-A3B-mopd-full-vocab-test/ - --save-interval 10 -) - -ROLLOUT_ARGS=( - --prompt-data /path/to/dataset/train_text_user_only.jsonl - --input-key messages - --apply-chat-template - --rollout-shuffle - --rollout-batch-size 4 - --n-samples-per-prompt 1 - --rollout-max-response-len 4096 - --rollout-temperature 0.8 - - --global-batch-size 4 - --balance-data - --num-epoch 1 -) - -RM_ARGS=( - # Pure distillation (mopd-alpha=0): rm-type defaults to "zero" automatically. -) - -EVAL_ARGS=( - # No eval for connectivity test -) - -PERF_ARGS=( - --tensor-model-parallel-size 2 - --sequence-parallel - --pipeline-model-parallel-size 1 - --context-parallel-size 1 - --expert-model-parallel-size 8 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - --use-dynamic-batch-size - --max-tokens-per-gpu 4096 -) - -# MOPD Configuration — Full-Vocabulary KL Divergence Mode -# -# Key changes from token_level mode: -# 1. Added --mopd-distill-type full_vocab -# → Computes exact D_KL(π_θ ∥ π_d) = Σ_y π_θ(y)[log π_θ(y) - log π_d(y)] -# over the full vocabulary instead of token-level approximation. -# -# 2. --mopd-teacher-loads is REQUIRED for full_vocab mode -# → full_vocab needs megatron teacher forward pass to get full logits, -# SGLang rollout cannot provide per-token full-vocab logits. -# -# 3. Memory considerations: -# → full_vocab mode stores teacher logits [R_i, V_local] per sample per teacher. -# For V=248320, TP=2 → V_local=124160, each token's logits = ~480KB in fp32. -# With batch=4, R=4096: teacher logits per GPU ≈ 4×4096×124160×4B ≈ 7.6GB. -# Student logits (same shape) appear during training forward pass ≈ 1.9GB/micro-batch. -# Together with model (~9GB), optimizer (~26GB), and SGLang (40%=57GB), -# total ≈ 102GB / 143GB, leaving ~41GB headroom. -# If OOM: reduce rollout-batch-size, rollout-max-response-len, or sglang-mem-fraction-static. -# -# 4. Loss formula: -# → L = L_fv_kl + alpha * L_pg (when alpha > 0) -# → L = L_fv_kl (pure distillation, when alpha = 0) -# where L_fv_kl = (1/D) Σ_d w_d * D_KL(π_θ ∥ π_d) (IS-corrected) -# -# 5. IS weight correction still applies (same as token_level mode) -# -# Alternative: Use top_k mode for memory-efficient approximate KL: -# Replace --mopd-distill-type full_vocab with: -# --mopd-distill-type top_k -# --mopd-topk-k 1024 -# This stores only [R_i, k] teacher logits+indices per sample (k=1024 by default), -# plus a tail probability correction. Memory per sample ≈ k*5B per token -# (vs V*4B for full_vocab). For k=1024, V=248320: ~98.7% memory reduction. -# Teacher logits per GPU ≈ 4×4096×1024×(4+4)B ≈ 128MB (negligible vs full_vocab). -# -# For this connectivity test, the teacher IS the same model (self-distillation). -MOPD_ARGS=( - --advantage-estimator grpo - - # MOPD flags — single teacher - --use-mopd - # Pass JSON via env var MOPD_TEACHERS_JSON to avoid shell quoting issues. - - # *** KEY DIFFERENCE: full_vocab distillation type *** - --mopd-distill-type full_vocab - - # Teacher checkpoint = same as ref model (self-distillation for validation) - --mopd-teacher-loads /path/to/checkpoints/Qwen3.5-35B-A3B_torch_dist/ - - # MOPD hyperparameters - --mopd-alpha 0.0 # Pure distillation, no ORM - --mopd-eps-low 0.2 # IS weight lower bound - --mopd-eps-high 5.0 # IS weight upper bound - --mopd-sampling-logprobs-key rollout_log_probs - - # Standard training flags - --use-kl-loss - --kl-loss-coef 0.00 - --kl-loss-type low_var_kl - --entropy-coef 0.00 -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 5e-7 - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 -) - -WANDB_ARGS=( - #--use-wandb - # --wandb-project slime-dev - # --wandb-group qwen3.5-35B-mopd-full-vocab-megatron -) - -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 8 - --sglang-mem-fraction-static 0.4 - --sglang-ep-size 8 -) - -MISC_ARGS=( - --attention-dropout 0.0 - --hidden-dropout 0.0 - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - --attention-backend flash - - --moe-token-dispatcher-type flex - --moe-enable-deepep - - --colocate -) - -# ============================================================================ -# Launch training -# ============================================================================ - -export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -export no_proxy="127.0.0.1,${MASTER_ADDR}" - -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 - -RUNTIME_ENV_JSON=$(python3 -c " -import json, os -env = { - 'PYTHONPATH': '/root/Megatron-LM/', - 'CUDA_DEVICE_MAX_CONNECTIONS': '1', - 'NCCL_NVLS_ENABLE': os.environ.get('HAS_NVLINK', '0'), - 'MOPD_TEACHERS_JSON': os.environ.get('MOPD_TEACHERS_JSON', '') -} -print(json.dumps({'env_vars': env})) -") - -ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env-json="${RUNTIME_ENV_JSON}" \ - -- python3 train.py \ - --actor-num-nodes 1 \ - --actor-num-gpus-per-node 8 \ - ${MODEL_ARGS[@]} \ - ${CKPT_ARGS[@]} \ - ${ROLLOUT_ARGS[@]} \ - ${OPTIMIZER_ARGS[@]} \ - ${MOPD_ARGS[@]} \ - ${WANDB_ARGS[@]} \ - ${PERF_ARGS[@]} \ - ${EVAL_ARGS[@]} \ - ${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} \ - ${RM_ARGS[@]} - -# ============================================================================ -# Cleanup -# ============================================================================ -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python \ No newline at end of file diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-megatron.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-megatron.sh deleted file mode 100644 index 85ec997b5d..0000000000 --- a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-megatron.sh +++ /dev/null @@ -1,234 +0,0 @@ -#!/bin/bash - -# Multi-Teacher On-Policy Distillation (MOPD) — Single Teacher Connectivity Test -# Model: Qwen3.5-35B-A3B (MoE, 256 experts, 8 active) -# Environment: 8× H20 (143GB) -# Teacher: Same as student (self-distillation for connectivity validation only) -# Mode: Megatron (teacher loaded into CPU memory via TensorBackuper) -# -# This script is for MOPD E2E connectivity validation only. -# In production, use a DIFFERENT (stronger) model as teacher. -# -# Parallelism: TP=2, EP=8 (matches SFT config, 256 experts / 8 = 32 per GPU) -# Colocate mode: rollout and training share all 8 GPUs with offloading -# -# usage: bash examples/multi_teacher_on_policy_distillation/run-qwen3.5-35B-A3B-mopd-megatron.sh - -# ============================================================================ -# Cleanup: kill existing SGLang / Ray / Python processes -# ============================================================================ -pkill -9 sglang -sleep 3 -ray stop --force 2>/dev/null || true -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python - -set -ex - -export PYTHONBUFFERED=16 -export FLASHINFER_DISABLE_VERSION_CHECK=1 - -NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) -if [ "$NVLINK_COUNT" -gt 0 ]; then - HAS_NVLINK=1 -else - HAS_NVLINK=0 -fi -echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" - -source "/path/to/slime/scripts/models/qwen3.5-35B-A3B.sh" - -# MOPD teachers JSON config -# Set as environment variable; arguments.py reads $MOPD_TEACHERS_JSON -# when --mopd-teachers is not provided on the command line. -# This avoids shell quoting issues when passing JSON through ray job submit. -export MOPD_TEACHERS_JSON='[{"name":"self_teacher","domain":"default"}]' - -# ============================================================================ -# Configure training arguments -# ============================================================================ - -# IMPORTANT: Before running this script, convert the HF checkpoint to Megatron -# torch_dist format: -# -# cd /path/to/slime -# source scripts/models/qwen3.5-35B-A3B.sh -# -# PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ -# ${MODEL_ARGS[@]} \ -# --hf-checkpoint /path/to/checkpoints/Qwen3.5-35B-A3B \ -# --save /path/to/checkpoints/Qwen3.5-35B-A3B_torch_dist - -CKPT_ARGS=( - --hf-checkpoint /path/to/checkpoints/Qwen3.5-35B-A3B/ - --ref-load /path/to/checkpoints/Qwen3.5-35B-A3B_torch_dist/ - --load /path/to/output/Qwen3.5-35B-A3B-mopd-test/ - --save /path/to/output/Qwen3.5-35B-A3B-mopd-test/ - --save-interval 10 -) - -ROLLOUT_ARGS=( - --prompt-data /path/to/dataset/train_text_user_only.jsonl - --input-key messages - --apply-chat-template - --rollout-shuffle - --rollout-batch-size 16 - --n-samples-per-prompt 1 # No need for multiple samples in pure distillation - --rollout-max-response-len 8192 - --rollout-temperature 0.8 - - --global-batch-size 16 - --balance-data - --num-epoch 1 -) - -RM_ARGS=( - # Pure distillation (mopd-alpha=0): rm-type defaults to "zero" automatically. - # No reward model needed. -) - -EVAL_ARGS=( - # No eval for connectivity test -) - -# Qwen3.5-35B-A3B with 8 GPUs (same parallelism as SFT config): -# TP=2, EP=8 (256 experts / 8 = 32 experts per GPU) -# Colocate mode: rollout and training share all 8 GPUs with offloading -PERF_ARGS=( - --tensor-model-parallel-size 2 - --sequence-parallel - --pipeline-model-parallel-size 1 - --context-parallel-size 1 - --expert-model-parallel-size 8 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - --use-dynamic-batch-size - --max-tokens-per-gpu 8192 -) - -# MOPD Configuration (Megatron mode, single teacher) -# For this connectivity test, the teacher IS the same model (self-distillation). -# This validates the full MOPD pipeline: rollout → teacher log-prob → advantage → train. -# -# Key: The teacher checkpoint must be in Megatron torch_dist format. -# Since teacher = student here, we use the same torch_dist path. -# -# Memory note: The teacher model weights are backed up to CPU memory via -# TensorBackuper. For Qwen3.5-35B-A3B, expect ~70GB additional CPU RAM usage. -MOPD_ARGS=( - --advantage-estimator grpo - - # MOPD flags — single teacher - --use-mopd - # Pass JSON via env var MOPD_TEACHERS_JSON to avoid shell quoting issues with ray job submit. - # If --mopd-teachers is not set, arguments.py falls back to $MOPD_TEACHERS_JSON. - - # Teacher checkpoint = same as ref model (self-distillation for validation) - --mopd-teacher-loads /path/to/checkpoints/Qwen3.5-35B-A3B_torch_dist/ - - # MOPD hyperparameters - --mopd-alpha 0.0 # Pure distillation, no ORM - --mopd-eps-low 0.2 # IS weight lower bound - --mopd-eps-high 5.0 # IS weight upper bound - --mopd-sampling-logprobs-key rollout_log_probs - - # Standard training flags - --use-kl-loss - --kl-loss-coef 0.00 - --kl-loss-type low_var_kl - --entropy-coef 0.00 -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 5e-7 # Conservative LR for stability - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 -) - -WANDB_ARGS=( - #--use-wandb - # --wandb-project slime-dev - # --wandb-group qwen3.5-35B-mopd-megatron - # --wandb-key ${WANDB_KEY} -) - -# SGLang rollout config (colocate mode, shares training GPUs) -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 8 # All 8 GPUs for rollout - --sglang-mem-fraction-static 0.4 # Share GPU memory with training - --sglang-ep-size 8 # Match EP=8 for MoE expert parallelism -) - -MISC_ARGS=( - --attention-dropout 0.0 - --hidden-dropout 0.0 - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - --attention-backend flash - - # MoE communication - --moe-token-dispatcher-type flex - --moe-enable-deepep - - # Colocate: rollout and training share same GPUs, with offloading - --colocate -) - -# ============================================================================ -# Launch training -# ============================================================================ - -export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -export no_proxy="127.0.0.1,${MASTER_ADDR}" - -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 - -RUNTIME_ENV_JSON=$(python3 -c " -import json, os -env = { - 'PYTHONPATH': '/root/Megatron-LM/', - 'CUDA_DEVICE_MAX_CONNECTIONS': '1', - 'NCCL_NVLS_ENABLE': os.environ.get('HAS_NVLINK', '0'), - 'MOPD_TEACHERS_JSON': os.environ.get('MOPD_TEACHERS_JSON', '') -} -print(json.dumps({'env_vars': env})) -") - -ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env-json="${RUNTIME_ENV_JSON}" \ - -- python3 train.py \ - --actor-num-nodes 1 \ - --actor-num-gpus-per-node 8 \ - ${MODEL_ARGS[@]} \ - ${CKPT_ARGS[@]} \ - ${ROLLOUT_ARGS[@]} \ - ${OPTIMIZER_ARGS[@]} \ - ${MOPD_ARGS[@]} \ - ${WANDB_ARGS[@]} \ - ${PERF_ARGS[@]} \ - ${EVAL_ARGS[@]} \ - ${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} \ - ${RM_ARGS[@]} - -# ============================================================================ -# Cleanup -# ============================================================================ -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python \ No newline at end of file diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-sglang.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-sglang.sh deleted file mode 100644 index 6b79f8e762..0000000000 --- a/examples/multi_teacher_on_policy_distillation/run-qwen35-35B-A3B-mopd-sglang.sh +++ /dev/null @@ -1,264 +0,0 @@ -#!/bin/bash - -# Multi-Teacher On-Policy Distillation (MOPD) — Single Teacher SGLang Mode -# Model: Qwen3.5-35B-A3B (MoE, 256 experts, 8 active) -# Environment: 8× H20 (143GB) -# Layout: 4 GPUs for SGLang rollout, 4 GPUs for Megatron training -# Teacher: Same as student (self-distillation for connectivity validation only) -# Mode: SGLang (teacher runs on external SGLang server, no architecture constraint) -# -# This script is for MOPD E2E connectivity validation only. -# In production, use a DIFFERENT (stronger) model as teacher. -# -# usage: bash examples/multi_teacher_on_policy_distillation/run-qwen3.5-35B-A3B-mopd-sglang.sh - -# ============================================================================ -# Cleanup: kill existing SGLang / Ray / Python processes -# ============================================================================ -pkill -9 sglang -sleep 3 -ray stop --force 2>/dev/null || true -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python - -set -ex - -export PYTHONBUFFERED=16 - -# ============================================================================ -# 1. Configure and start teacher model server (self-distillation for testing) -# ============================================================================ -# For this connectivity test, the teacher is the same model as the student. -# In production, replace with a stronger model (e.g., Qwen3-72B or domain expert). -TEACHER_IP="127.0.0.1" -TEACHER_PORT=13141 -TEACHER_LOG_FILE="/tmp/sglang_teacher_$(head /dev/urandom | tr -dc A-Za-z0-9 | head -c 6).log" - -# Launch teacher on GPU 0-3 (4 GPUs for TP=4, or adjust TP as needed) -CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m sglang.launch_server \ - --model-path /path/to/checkpoints/Qwen3.5-35B-A3B/ \ - --host 0.0.0.0 \ - --port $TEACHER_PORT \ - --tp 4 \ - --ep-size 4 \ - --chunked-prefill-size 4096 \ - --mem-fraction-static 0.7 \ - > "$TEACHER_LOG_FILE" 2>&1 & - -TEACHER_PID=$! -echo "Starting teacher model server (PID: $TEACHER_PID)..." - -# Wait for teacher server to be ready -until curl -sf http://$TEACHER_IP:$TEACHER_PORT/health_generate > /dev/null; do - echo "Waiting for teacher model server to start..." - tail -n 10 "$TEACHER_LOG_FILE" 2>/dev/null || true - sleep 10 -done -echo "Teacher model server is up and running at $TEACHER_IP:$TEACHER_PORT." - -# ============================================================================ -# 2. Set environment variables -# ============================================================================ - -NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) -if [ "$NVLINK_COUNT" -gt 0 ]; then - HAS_NVLINK=1 -else - HAS_NVLINK=0 -fi -echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" - -source "/path/to/slime/scripts/models/qwen3.5-35B-A3B.sh" - -# MOPD teachers JSON config -export MOPD_TEACHERS_JSON='[{"name":"self_teacher","domain":"default"}]' - -# MOPD teacher URLs -export MOPD_TEACHER_URLS="{\"default\":\"http://$TEACHER_IP:$TEACHER_PORT/generate\"}" - -# ============================================================================ -# 3. Configure training arguments -# ============================================================================ - -# IMPORTANT: Before running this script, convert the HF checkpoint to Megatron -# torch_dist format: -# -# cd /path/to/slime -# source scripts/models/qwen3.5-35B-A3B.sh -# -# PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ -# ${MODEL_ARGS[@]} \ -# --hf-checkpoint /path/to/checkpoints/Qwen3.5-35B-A3B \ -# --save /path/to/checkpoints/Qwen3.5-35B-A3B_torch_dist - -CKPT_ARGS=( - --hf-checkpoint /path/to/checkpoints/Qwen3.5-35B-A3B/ - --ref-load /path/to/checkpoints/Qwen3.5-35B-A3B_torch_dist/ - --load /path/to/output/Qwen3.5-35B-A3B-mopd-test/ - --save /path/to/output/Qwen3.5-35B-A3B-mopd-test/ - --save-interval 10 -) - -ROLLOUT_ARGS=( - --prompt-data /path/to/dataset/train_text_user_only.jsonl - --input-key messages - --apply-chat-template - --rollout-shuffle - --num-rollout 10 # Small for connectivity test - --rollout-batch-size 16 - --n-samples-per-prompt 1 # No need for multiple samples in pure distillation - --rollout-max-response-len 8192 - --rollout-temperature 0.8 - - --global-batch-size 16 - --balance-data -) - -# For MOPD SGLang mode, we use the MOPD reward_func and post_process_rewards -# The --rm-url is used as the default/fallback URL; per-teacher URLs come from MOPD_TEACHER_URLS env var -RM_ARGS=( - --custom-rm-path slime.rollout.mopd.reward_func - --custom-reward-post-process-path slime.rollout.mopd.post_process_rewards - --rm-url http://$TEACHER_IP:$TEACHER_PORT/generate -) - -EVAL_ARGS=( - # No eval for connectivity test -) - -# Qwen3.5-35B-A3B with 4 GPUs for training: -# TP=2, EP=4 (256 experts / 4 = 64 experts per GPU) -PERF_ARGS=( - --tensor-model-parallel-size 2 - --sequence-parallel - --pipeline-model-parallel-size 1 - --context-parallel-size 1 - --expert-model-parallel-size 4 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - --use-dynamic-batch-size - --max-tokens-per-gpu 8192 -) - -# MOPD Configuration (SGLang mode, single teacher) -# In SGLang mode, teacher log-probs are obtained by querying the teacher SGLang server -# during rollout. No teacher model is loaded into Megatron training memory. -MOPD_ARGS=( - --advantage-estimator grpo - - # MOPD flags — single teacher - --use-mopd - # Note: --mopd-teachers is read from $MOPD_TEACHERS_JSON env var (see above) - # to avoid shell quoting issues with JSON in ray job submit. - - # No --mopd-teacher-loads needed in SGLang mode! - # Teacher log-probs come from the SGLang server via reward_func. - - # MOPD hyperparameters - --mopd-alpha 0.0 # Pure distillation, no ORM - --mopd-eps-low 0.2 # IS weight lower bound - --mopd-eps-high 5.0 # IS weight upper bound - --mopd-sampling-logprobs-key rollout_log_probs - - # Standard training flags - --use-kl-loss - --kl-loss-coef 0.00 - --kl-loss-type low_var_kl - --entropy-coef 0.00 -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 5e-7 # Conservative LR for stability - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 -) - -WANDB_ARGS=( - #--use-wandb - # --wandb-project slime-dev - # --wandb-group qwen3.5-35B-mopd-sglang - # --wandb-key ${WANDB_KEY} -) - -# SGLang rollout config: 4 GPUs for rollout -SGLANG_ARGS=( - --rollout-num-gpus 4 # 4 GPUs for SGLang rollout engine - --rollout-num-gpus-per-engine 4 # 4 GPUs per engine (TP=4 for Qwen3.5-35B-A3B) - --sglang-mem-fraction-static 0.7 - --sglang-ep-size 4 # Match EP=4 for MoE expert parallelism -) - -MISC_ARGS=( - --attention-dropout 0.0 - --hidden-dropout 0.0 - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - --attention-backend flash - - # MoE communication - --moe-token-dispatcher-type flex - --moe-enable-deepep -) - -# ============================================================================ -# 4. Launch training -# ============================================================================ - -export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -export no_proxy="127.0.0.1,${MASTER_ADDR}" - -# 8 GPUs total: 4 for SGLang rollout (GPU 0-3, already used by teacher server), -# 4 for Megatron training (GPU 4-7) -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 - -RUNTIME_ENV_JSON=$(python3 -c " -import json, os -env = { - 'PYTHONPATH': '/root/Megatron-LM/', - 'CUDA_DEVICE_MAX_CONNECTIONS': '1', - 'NCCL_NVLS_ENABLE': os.environ.get('HAS_NVLINK', '0'), - 'MOPD_TEACHER_URLS': os.environ.get('MOPD_TEACHER_URLS', ''), - 'MOPD_TEACHERS_JSON': os.environ.get('MOPD_TEACHERS_JSON', '') -} -print(json.dumps({'env_vars': env})) -") - -ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env-json="${RUNTIME_ENV_JSON}" \ - -- python3 train.py \ - --actor-num-nodes 1 \ - --actor-num-gpus-per-node 4 \ - ${MODEL_ARGS[@]} \ - ${CKPT_ARGS[@]} \ - ${ROLLOUT_ARGS[@]} \ - ${OPTIMIZER_ARGS[@]} \ - ${MOPD_ARGS[@]} \ - ${WANDB_ARGS[@]} \ - ${PERF_ARGS[@]} \ - ${EVAL_ARGS[@]} \ - ${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} \ - ${RM_ARGS[@]} - -# ============================================================================ -# 5. Cleanup -# ============================================================================ -kill $TEACHER_PID 2>/dev/null || true -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python \ No newline at end of file diff --git a/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-megatron.sh b/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-megatron.sh deleted file mode 100644 index 675fb16677..0000000000 --- a/examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-megatron.sh +++ /dev/null @@ -1,220 +0,0 @@ -#!/bin/bash - -# Multi-Teacher On-Policy Distillation (MOPD) — Top-K KL Divergence Mode -# Model: Qwen3.5-397B-A17B (MoE, 512 experts, 10 active) -# Environment: 16 nodes × 8 L20X (143GB each), 128 GPUs total -# Teacher: Teacher model (different from student for production distillation) -# Mode: Megatron (teacher loaded into CPU memory via TensorBackuper) -# Distill Type: top_k (approximate reverse KL with top-k teacher logits + tail correction) -# -# This script is for MOPD top_k distillation with 128 GPUs. -# -# Key features of top_k mode: -# --mopd-distill-type top_k -# → Computes approximate D_KL(π_θ ∥ π_d) using teacher's top-k logits -# plus tail probability correction. Much more memory-efficient than full_vocab. -# → Stores only [R_i, k] teacher logits+indices per sample (k=1024 default), -# vs [R_i, V] for full_vocab. ~98.7% memory reduction vs full_vocab. -# -# Prerequisites: -# 1. Convert HF checkpoint to Megatron torch_dist format before first run: -# cd /path/to/slime -# source scripts/models/qwen3.5-397B-A17B.sh -# -# PYTHONPATH=/root/Megatron-LM torchrun --nproc_per_node=8 \ -# tools/convert_hf_to_torch_dist.py \ -# ${MODEL_ARGS[@]} \ -# --hf-checkpoint /path/to/Qwen3.5-397B-A17B_teacher \ -# --save /path/to/Qwen3.5-397B-A17B_teacher_torch_dist -# -# usage: bash examples/multi_teacher_on_policy_distillation/run-qwen35-397B-A17B-mopd-topk-megatron.sh - -set -ex - -export PYTHONBUFFERED=16 -export FLASHINFER_DISABLE_VERSION_CHECK=1 - -NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) -if [ "$NVLINK_COUNT" -gt 0 ]; then - HAS_NVLINK=1 -else - HAS_NVLINK=0 -fi -echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" - -SLIME_DIR="/workspace/bin/slime" -source "${SLIME_DIR}/scripts/models/qwen3.5-397B-A17B.sh" - -# ============================================================================ -# Paths — adjust these to your environment -# ============================================================================ -BASE_DIR=/path/to/checkpoints - -HF_CKPT=${BASE_DIR}/Qwen3.5-397B-A17B -TORCH_DIST_CKPT=${BASE_DIR}/Qwen3.5-397B-A17B_torch_dist -TEACHER_TORCH_DIST_CKPT=${BASE_DIR}/Qwen3.5-397B-A17B_teacher_torch_dist -SAVE_DIR=${BASE_DIR}/Qwen3.5-397B-A17B-MOPD-TopK-Output - -DATA_PATH="/path/to/your/training_data.jsonl" - -# MOPD teachers JSON config -export MOPD_TEACHERS_JSON='[{"name":"teacher","domain":"default"}]' - -# ============================================================================ -# Configure training arguments -# ============================================================================ - -CKPT_ARGS=( - --hf-checkpoint ${HF_CKPT}/ - --ref-load ${TORCH_DIST_CKPT}/ - --load ${SAVE_DIR}/ - --save ${SAVE_DIR}/ - --save-interval 10 - --no-save-optim -) - -ROLLOUT_ARGS=( - --prompt-data ${DATA_PATH} - --input-key messages - --apply-chat-template - --rollout-shuffle - --rollout-batch-size 64 - --n-samples-per-prompt 1 - --rollout-max-response-len 4096 - --rollout-temperature 0.5 - - --global-batch-size 64 - --balance-data - --num-epoch 1 -) - -RM_ARGS=() - -EVAL_ARGS=() - -PERF_ARGS=( - --tensor-model-parallel-size 2 - --sequence-parallel - --pipeline-model-parallel-size 1 - --context-parallel-size 1 - --expert-model-parallel-size 128 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - --use-dynamic-batch-size - --max-tokens-per-gpu 4096 -) - -MOPD_ARGS=( - --advantage-estimator grpo - - # MOPD flags — single teacher - --use-mopd - - # token level - # --mopd-distill-type token_level - - # top k - --mopd-distill-type top_k - --mopd-topk-k 1024 - - # full vocab - # --mopd-distill-type full_vocab - - --mopd-teacher-loads ${TEACHER_TORCH_DIST_CKPT}/ - - # MOPD hyperparameters - --mopd-alpha 0.0 # Pure distillation, no ORM - --mopd-eps-low 0.2 # IS weight lower bound - --mopd-eps-high 5.0 # IS weight upper bound - --mopd-sampling-logprobs-key rollout_log_probs - - # Standard training flags - --use-kl-loss - --kl-loss-coef 0.00 - --kl-loss-type low_var_kl - --entropy-coef 0.00 -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 5e-7 # Conservative LR for stability - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 - - # CPU offload optimizer to save GPU memory for large model - --optimizer-cpu-offload - --overlap-cpu-optimizer-d2h-h2d - --use-precision-aware-optimizer -) - -WANDB_ARGS=() - -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 16 - --sglang-mem-fraction-static 0.45 - --sglang-ep-size 16 -) - -MISC_ARGS=( - --attention-dropout 0.0 - --hidden-dropout 0.0 - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - --attention-backend flash - - --moe-token-dispatcher-type alltoall - # --moe-enable-deepep # DeepEP internode kernel assertion fails when EP=128 (num_topk_ranks > kNumTopkRDMARanks) - --no-check-for-nan-in-loss-and-grad - - --colocate -) - -# ============================================================================ -# Launch training — multi-node setup -# ============================================================================ - -# --- Submit job --- -RUNTIME_ENV_JSON=$(python3 -c " -import json, os -env = { - 'PYTHONPATH': '/root/Megatron-LM/', - 'CUDA_DEVICE_MAX_CONNECTIONS': '1', - 'NCCL_DEBUG': 'WARN', - 'NCCL_NVLS_ENABLE': os.environ.get('HAS_NVLINK', '0'), - 'NCCL_TIMEOUT_MS': '36000000', - 'FLASHINFER_DISABLE_VERSION_CHECK': '1', - 'MOPD_TEACHERS_JSON': os.environ.get('MOPD_TEACHERS_JSON', '') -} -print(json.dumps({'env_vars': env})) -") - -ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env-json="${RUNTIME_ENV_JSON}" \ - -- python3 ../workspace/bin/slime/train.py \ - --actor-num-nodes 16 \ - --actor-num-gpus-per-node 8 \ - --update-weight-buffer-size $(( 1024 * 1024 * 1024 * 4 )) \ - ${MODEL_ARGS[@]} \ - ${CKPT_ARGS[@]} \ - ${ROLLOUT_ARGS[@]} \ - ${OPTIMIZER_ARGS[@]} \ - ${MOPD_ARGS[@]} \ - ${WANDB_ARGS[@]} \ - ${PERF_ARGS[@]} \ - ${EVAL_ARGS[@]} \ - ${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} \ - ${RM_ARGS[@]} - -# ============================================================================ -# Cleanup -# ============================================================================ -pkill -9 sglang -sleep 3 -pkill -9 python \ No newline at end of file From 32d0b63ae36b3cfc2251d89a04086ebea67833e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=87=95=E4=B8=80?= Date: Thu, 11 Jun 2026 11:53:07 +0800 Subject: [PATCH 12/14] fix: add missing rollout_mask_sums to batch_keys in model.py --- slime/backends/megatron_utils/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index f050c262b2..60c6c26d4f 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -500,6 +500,7 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p "rollout_log_probs", "max_seq_lens", "teacher_log_probs", + "rollout_mask_sums", ] # Add MOPD full-vocab teacher logits keys if present # These are stored as "mopd_teacher_{domain}_fv_logits" per domain From 254981b1fc7bc60aee1e356c901e3120cfe94f5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=87=95=E4=B8=80?= Date: Thu, 11 Jun 2026 13:26:02 +0800 Subject: [PATCH 13/14] fix: resolve ruff lint errors and format with black - Add strict=False to zip() calls (B905) - Rename unused loop variable domain to _domain (B007) - Add from err to raise in except clause (B904) - Remove unused variables process_group, k, last_file_idx/name (F841) - Add noqa: F841 for intentionally assigned test variables - Replace assert False with raise AssertionError (B011) - Rename ambiguous variable l to v (E741) - Apply black formatting to all modified files --- slime/backends/megatron_utils/actor.py | 73 +++--- slime/backends/megatron_utils/data.py | 4 +- slime/backends/megatron_utils/loss.py | 58 ++--- slime/ray/placement_group.py | 5 +- slime/rollout/mopd.py | 61 +++-- slime/utils/arguments.py | 11 +- slime/utils/ppo_utils.py | 8 +- slime/utils/types.py | 6 +- slime_plugins/mbridge/qwen3_5.py | 2 +- .../megatron_bridge/qwen35_vl_moe.py | 113 ++++------ tests/test_mopd.py | 13 +- tests/test_mopd_full_vocab.py | 213 ++++++++++++------ tests/test_mopd_sglang_topk_pipeline.py | 99 ++++---- tools/convert_torch_dist_to_hf_parallel.py | 9 +- tools/merge_missing_keys.py | 16 +- tools/patch_attention_gate_on_cluster.py | 27 ++- 16 files changed, 397 insertions(+), 321 deletions(-) diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index ac56966338..17365043d3 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -130,7 +130,9 @@ def init( mopd_teacher_mode = getattr(args, "mopd_teacher_mode", "megatron") if getattr(args, "use_mopd", False): if mopd_teacher_mode == "megatron" and getattr(args, "mopd_teacher_loads", None): - mopd_teachers = json.loads(args.mopd_teachers) if isinstance(args.mopd_teachers, str) else args.mopd_teachers + mopd_teachers = ( + json.loads(args.mopd_teachers) if isinstance(args.mopd_teachers, str) else args.mopd_teachers + ) for i, teacher_cfg in enumerate(mopd_teachers): domain = teacher_cfg["domain"] tag = f"mopd_teacher_{domain}" @@ -338,17 +340,19 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch: # Create a -inf tensor of the correct size as fallback. # -inf log-probs produce zero KL contribution, so this # domain has no effect on the loss for this sample. - sliced_len = len(slice_log_prob_with_cp( - torch.zeros(response_length), - total_length, - response_length, - self.args.qkv_format, - rollout_data["max_seq_lens"][i] if self.args.qkv_format == "bshd" else None, - )) + sliced_len = len( + slice_log_prob_with_cp( + torch.zeros(response_length), + total_length, + response_length, + self.args.qkv_format, + rollout_data["max_seq_lens"][i] if self.args.qkv_format == "bshd" else None, + ) + ) domain_processed.append( - torch.full((sliced_len,), float('-inf'), - device=torch.cuda.current_device(), - dtype=torch.float32) + torch.full( + (sliced_len,), float("-inf"), device=torch.cuda.current_device(), dtype=torch.float32 + ) ) else: domain_processed.append( @@ -565,7 +569,12 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch, external_data # Only applies when mopd_teacher_mode == "megatron". In SGLang mode, # teacher data is collected during rollout and arrives in rollout_data. mopd_teacher_mode = getattr(self.args, "mopd_teacher_mode", "megatron") - if getattr(self.args, "use_mopd", False) and mopd_teacher_mode == "megatron" and hasattr(self, "_mopd_teacher_domains") and self._mopd_teacher_domains: + if ( + getattr(self.args, "use_mopd", False) + and mopd_teacher_mode == "megatron" + and hasattr(self, "_mopd_teacher_domains") + and self._mopd_teacher_domains + ): mopd_teacher_log_probs = {} mopd_distill_type = getattr(self.args, "mopd_distill_type", "token_level") use_full_vocab = mopd_distill_type == "full_vocab" @@ -676,7 +685,7 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch, external_data topk_logits_list = [] topk_indices_list = [] for i, (logits_per_sample, indices_per_sample) in enumerate( - zip(sglang_topk_logits[domain], sglang_topk_indices[domain]) + zip(sglang_topk_logits[domain], sglang_topk_indices[domain], strict=False) ): if logits_per_sample is None or indices_per_sample is None: # Fallback: create zero-contribution tensors so all DP @@ -685,14 +694,19 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch, external_data # Use -inf logits → zero KL divergence contribution. seq_len = rollout_data["response_lengths"][i] topk_logits_list.append( - torch.full((seq_len, topk_k), float('-inf'), - device=torch.cuda.current_device(), - dtype=torch.float32) + torch.full( + (seq_len, topk_k), + float("-inf"), + device=torch.cuda.current_device(), + dtype=torch.float32, + ) ) topk_indices_list.append( - torch.zeros((seq_len, topk_k), - device=torch.cuda.current_device(), - dtype=torch.int64) + torch.zeros( + (seq_len, topk_k), + device=torch.cuda.current_device(), + dtype=torch.int64, + ) ) else: # SGLang returns GLOBAL token IDs, but the Megatron loss @@ -738,7 +752,9 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch, external_data seq_len = global_indices.size(0) # Mask for which entries are in this shard - in_shard = (global_indices >= vocab_offset) & (global_indices < vocab_offset + vocab_local_size) + in_shard = (global_indices >= vocab_offset) & ( + global_indices < vocab_offset + vocab_local_size + ) # Convert to local indices local_indices = global_indices - vocab_offset # Clamp out-of-range indices to 0 (will be overridden by -inf logits) @@ -747,12 +763,13 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch, external_data # Build per-shard top-k: assign in-shard entries, pad rest with -inf # For each position, we need exactly k entries local_topk_logits = torch.full( - (seq_len, topk_k), float('-inf'), - device=torch.cuda.current_device(), dtype=torch.float32 + (seq_len, topk_k), + float("-inf"), + device=torch.cuda.current_device(), + dtype=torch.float32, ) local_topk_indices = torch.zeros( - (seq_len, topk_k), - device=torch.cuda.current_device(), dtype=torch.int64 + (seq_len, topk_k), device=torch.cuda.current_device(), dtype=torch.int64 ) # Scatter: for each position, place the in-shard entries into @@ -974,9 +991,13 @@ def load_other_checkpoint(self, model_tag: str, path: str) -> None: self.args.ckpt_step = self.args.opd_teacher_ckpt_step elif model_tag.startswith("mopd_teacher_"): # MOPD teacher checkpoint step: look up from mopd_teacher_ckpt_steps by domain - domain = model_tag[len("mopd_teacher_"):] + domain = model_tag[len("mopd_teacher_") :] if getattr(self.args, "mopd_teacher_ckpt_steps", None) is not None: - mopd_teachers = json.loads(self.args.mopd_teachers) if isinstance(self.args.mopd_teachers, str) else self.args.mopd_teachers + mopd_teachers = ( + json.loads(self.args.mopd_teachers) + if isinstance(self.args.mopd_teachers, str) + else self.args.mopd_teachers + ) for i, t in enumerate(mopd_teachers): if t["domain"] == domain and i < len(self.args.mopd_teacher_ckpt_steps): old_ckpt_step = self.args.ckpt_step diff --git a/slime/backends/megatron_utils/data.py b/slime/backends/megatron_utils/data.py index a82727efc8..b58c7e427d 100644 --- a/slime/backends/megatron_utils/data.py +++ b/slime/backends/megatron_utils/data.py @@ -317,7 +317,9 @@ def log_rollout_data( if key.startswith("mopd_teacher_") and key.endswith("_fv_logits"): continue # Skip per-domain top-k teacher logits/indices (too large for averaging) - if key.startswith("mopd_teacher_") and (key.endswith("_topk_logits") or key.endswith("_topk_indices") or key.endswith("_topk_log_sum_exp")): + if key.startswith("mopd_teacher_") and ( + key.endswith("_topk_logits") or key.endswith("_topk_indices") or key.endswith("_topk_log_sum_exp") + ): continue # Emit (sum, count) so gather_log_data can do a weighted average across # DP ranks. This stops the legacy "every rank has the same N samples" diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 31e538a4a4..26f59d194c 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -1,9 +1,8 @@ +import logging from argparse import Namespace from collections.abc import Callable, Iterator from typing import Any -import logging - import torch logger = logging.getLogger(__name__) @@ -708,8 +707,7 @@ def apply_mopd_to_advantages( sampling_log_probs = rollout_data.get("log_probs") if sampling_log_probs is None: raise ValueError( - f"MOPD requires '{sampling_logprobs_key}' in rollout_data for importance sampling, " - f"but it is missing." + f"MOPD requires '{sampling_logprobs_key}' in rollout_data for importance sampling, " f"but it is missing." ) device = student_log_probs[0].device @@ -721,7 +719,7 @@ def apply_mopd_to_advantages( all_is_weights_list = [] all_reverse_kls = [] - for domain, teacher_lp_list in mopd_teacher_log_probs.items(): + for _domain, teacher_lp_list in mopd_teacher_log_probs.items(): domain_advantages = [] domain_is_weights = [] domain_reverse_kls = [] @@ -783,8 +781,12 @@ def apply_mopd_to_advantages( for i in range(len(advantages)): # Collect valid (non-None) teacher contributions for this sample - valid_advs = [all_mopd_advantages[t][i] for t in range(len(all_mopd_advantages)) if all_mopd_advantages[t][i] is not None] - valid_is = [all_is_weights_list[t][i] for t in range(len(all_is_weights_list)) if all_is_weights_list[t][i] is not None] + valid_advs = [ + all_mopd_advantages[t][i] for t in range(len(all_mopd_advantages)) if all_mopd_advantages[t][i] is not None + ] + valid_is = [ + all_is_weights_list[t][i] for t in range(len(all_is_weights_list)) if all_is_weights_list[t][i] is not None + ] if len(valid_advs) == 0: # No valid teachers for this sample — use zero advantages and zero IS weights @@ -1095,9 +1097,7 @@ def apply_mopd_full_vocab_to_loss( if sampling_logprobs_key == "rollout_log_probs" and sampling_log_probs is None: sampling_log_probs = batch.get("log_probs") if sampling_log_probs is None: - raise ValueError( - f"MOPD full_vocab requires '{sampling_logprobs_key}' in batch for importance sampling." - ) + raise ValueError(f"MOPD full_vocab requires '{sampling_logprobs_key}' in batch for importance sampling.") num_samples = len(student_logits_per_sample) if len(sampling_log_probs) != num_samples: @@ -1219,7 +1219,9 @@ def apply_mopd_full_vocab_to_loss( # different subsets of active domains. for domain in teacher_logits_per_domain: if domain in per_domain_kls and len(per_domain_kls[domain]) > 0: - metrics[f"mopd_fv_kl/{domain}"] = sum_of_sample_mean(torch.cat(per_domain_kls[domain], dim=0)).clone().detach() + metrics[f"mopd_fv_kl/{domain}"] = ( + sum_of_sample_mean(torch.cat(per_domain_kls[domain], dim=0)).clone().detach() + ) else: metrics[f"mopd_fv_kl/{domain}"] = torch.tensor(0.0, device=all_kl_cat.device) @@ -1282,9 +1284,7 @@ def apply_mopd_topk_to_loss( if sampling_logprobs_key == "rollout_log_probs" and sampling_log_probs is None: sampling_log_probs = batch.get("log_probs") if sampling_log_probs is None: - raise ValueError( - f"MOPD top_k requires '{sampling_logprobs_key}' in batch for importance sampling." - ) + raise ValueError(f"MOPD top_k requires '{sampling_logprobs_key}' in batch for importance sampling.") vocab_size = args.vocab_size num_samples = len(student_logits_per_sample) @@ -1304,10 +1304,7 @@ def apply_mopd_topk_to_loss( valid_teacher_count = 0 for domain in teacher_topk_logits_per_domain: - if ( - i >= len(teacher_topk_logits_per_domain[domain]) - or teacher_topk_logits_per_domain[domain][i] is None - ): + if i >= len(teacher_topk_logits_per_domain[domain]) or teacher_topk_logits_per_domain[domain][i] is None: continue t_topk_logits = teacher_topk_logits_per_domain[domain][i] # [R_i, k] @@ -1328,7 +1325,10 @@ def apply_mopd_topk_to_loss( # Get teacher log_sum_exp for exact tail mass (Megatron mode only) t_topk_log_sum_exp = None if teacher_topk_log_sum_exp_per_domain and domain in teacher_topk_log_sum_exp_per_domain: - if i < len(teacher_topk_log_sum_exp_per_domain[domain]) and teacher_topk_log_sum_exp_per_domain[domain][i] is not None: + if ( + i < len(teacher_topk_log_sum_exp_per_domain[domain]) + and teacher_topk_log_sum_exp_per_domain[domain][i] is not None + ): t_topk_log_sum_exp = teacher_topk_log_sum_exp_per_domain[domain][i] # [R_i] kl_i = vocab_parallel_topk_reverse_kl( @@ -1410,7 +1410,9 @@ def apply_mopd_topk_to_loss( for domain in teacher_topk_logits_per_domain: if domain in per_domain_kls and len(per_domain_kls[domain]) > 0: - metrics[f"mopd_topk_kl/{domain}"] = sum_of_sample_mean(torch.cat(per_domain_kls[domain], dim=0)).clone().detach() + metrics[f"mopd_topk_kl/{domain}"] = ( + sum_of_sample_mean(torch.cat(per_domain_kls[domain], dim=0)).clone().detach() + ) else: # No samples contributed valid teacher data for this domain in this # microbatch. Emit a zero metric so that every microbatch produces @@ -1530,7 +1532,9 @@ def policy_loss_function( # Apply MOPD token_level: replace advantages with mopd_advantages and apply IS weights # L_MOPD(θ) = -E[1/|y| Σ_t w_t * Â_MOPD,t * log π_θ(y_t|x,y_ 0: loss = fv_kl_loss + alpha * pg_loss batch["_mopd_fv_kl_loss"] = fv_kl_loss else: - logger.warning("MOPD full_vocab enabled but no teacher logits found in batch. Skipping full_vocab KL loss.") + logger.warning( + "MOPD full_vocab enabled but no teacher logits found in batch. Skipping full_vocab KL loss." + ) # Ensure per-domain metric keys AND base MOPD metric keys exist for # ALL configured teacher domains, even when the batch contains no valid @@ -1825,10 +1831,10 @@ def policy_loss_function( # keys present in this microbatch — so that every microbatch # produces the same set of metric keys (required for Megatron's # loss-reduction across microbatches). - _all_mopd_domains = [ - t["domain"] for t in getattr(args, "_mopd_teachers_parsed", []) - ] - _mopd_reverse_kl_domains = _all_mopd_domains if _all_mopd_domains else list(batch["mopd_reverse_kl"].keys()) + _all_mopd_domains = [t["domain"] for t in getattr(args, "_mopd_teachers_parsed", [])] + _mopd_reverse_kl_domains = ( + _all_mopd_domains if _all_mopd_domains else list(batch["mopd_reverse_kl"].keys()) + ) for domain in _mopd_reverse_kl_domains: if domain in batch["mopd_reverse_kl"]: domain_kl_tensor = torch.cat(batch["mopd_reverse_kl"][domain], dim=0) diff --git a/slime/ray/placement_group.py b/slime/ray/placement_group.py index 18e8390cde..86f564efb1 100644 --- a/slime/ray/placement_group.py +++ b/slime/ray/placement_group.py @@ -190,7 +190,10 @@ def create_training_models(args, pgs, rollout_manager): with_ref=actor_args.kl_coef != 0 or actor_args.use_kl_loss, with_opd_teacher=( (actor_args.use_opd and actor_args.opd_type == "megatron") - or (getattr(actor_args, "use_mopd", False) and getattr(actor_args, "mopd_teacher_loads", None) is not None) + or ( + getattr(actor_args, "use_mopd", False) + and getattr(actor_args, "mopd_teacher_loads", None) is not None + ) ), ) ) diff --git a/slime/rollout/mopd.py b/slime/rollout/mopd.py index 3d63989faa..95b26701c5 100644 --- a/slime/rollout/mopd.py +++ b/slime/rollout/mopd.py @@ -206,9 +206,13 @@ async def _fetch_teacher_logprobs( ) return result - except (aiohttp.ClientPayloadError, aiohttp.ClientConnectionError, - aiohttp.ServerDisconnectedError, asyncio.TimeoutError, - aiohttp.ClientResponseError) as exc: + except ( + aiohttp.ClientPayloadError, + aiohttp.ClientConnectionError, + aiohttp.ServerDisconnectedError, + asyncio.TimeoutError, + aiohttp.ClientResponseError, + ) as exc: last_exc = exc if attempt < max_retries - 1: # 5xx server errors are retryable; ClientPayloadError (e.g. @@ -344,19 +348,14 @@ async def _reward_func_single(args, sample, **kwargs): for domain, rm_url in url_map.items(): domains.append(domain) tasks.append( - _fetch_teacher_logprobs(session, rm_url, payload, - max_retries=max_retries, - retry_delay=retry_delay) + _fetch_teacher_logprobs(session, rm_url, payload, max_retries=max_retries, retry_delay=retry_delay) ) responses = await asyncio.gather(*tasks, return_exceptions=True) - for domain, resp in zip(domains, responses): + for domain, resp in zip(domains, responses, strict=False): if isinstance(resp, Exception): - logger.warning( - f"MOPD teacher '{domain}' failed after retries: {resp}. " - f"Skipping this teacher." - ) + logger.warning(f"MOPD teacher '{domain}' failed after retries: {resp}. " f"Skipping this teacher.") continue results[domain] = resp @@ -460,8 +459,7 @@ def _extract_teacher_data_from_responses(args, samples: list[Sample]): ) else: logger.info( - f"MOPD: Received teacher logprobs for domain '{domain}': " - f"len={log_probs.size(0)}, all -inf" + f"MOPD: Received teacher logprobs for domain '{domain}': " f"len={log_probs.size(0)}, all -inf" ) # --- top_k: extract top-k log-probs and indices per position --- @@ -489,7 +487,9 @@ def _extract_teacher_data_from_responses(args, samples: list[Sample]): ) # Pad with None entries so the loop below generates # [-inf, ..., -inf] / [0, ..., 0] for missing positions - top_logprobs_response = top_logprobs_response + [None] * (response_length - len(top_logprobs_response)) + top_logprobs_response = top_logprobs_response + [None] * ( + response_length - len(top_logprobs_response) + ) if len(top_logprobs_response) > response_length: top_logprobs_response = top_logprobs_response[-response_length:] @@ -503,7 +503,7 @@ def _extract_teacher_data_from_responses(args, samples: list[Sample]): topk_k = getattr(args, "mopd_topk_k", 1024) NEG_INF = float("-inf") - topk_logits_list = [] # [seq_len][k] float + topk_logits_list = [] # [seq_len][k] float topk_indices_list = [] # [seq_len][k] int short_positions = 0 # Count positions with fewer than topk_k entries @@ -551,16 +551,18 @@ def _extract_teacher_data_from_responses(args, samples: list[Sample]): # Provide an actionable message for the most common cause: # SGLang server not returning logprobs. if isinstance(e, KeyError) and str(e) in ("'input_token_logprobs'", "input_token_logprobs"): - meta_keys = list(teacher_response.get("meta_info", {}).keys()) if isinstance(teacher_response.get("meta_info"), dict) else "N/A" + meta_keys = ( + list(teacher_response.get("meta_info", {}).keys()) + if isinstance(teacher_response.get("meta_info"), dict) + else "N/A" + ) logger.error( f"MOPD: SGLang response for domain '{domain}' missing " f"'input_token_logprobs'. meta_info keys: {meta_keys}. " f"Check teacher URL configuration." ) else: - logger.warning( - f"MOPD: Failed to extract teacher data for domain '{domain}': {e}" - ) + logger.warning(f"MOPD: Failed to extract teacher data for domain '{domain}': {e}") # --- Fill in missing domains with zero/fallback data --- # When a teacher request fails (e.g., ContentLengthError, connection reset), @@ -598,7 +600,7 @@ def _extract_teacher_data_from_responses(args, samples: list[Sample]): f"Filling with -inf log-probs (zero KL contribution)." ) sample.mopd_teacher_log_probs[domain] = torch.full( - (response_length,), float('-inf'), dtype=torch.float32 + (response_length,), float("-inf"), dtype=torch.float32 ) if mopd_distill_type == "top_k": if sample.mopd_teacher_topk_logits is None: @@ -608,12 +610,8 @@ def _extract_teacher_data_from_responses(args, samples: list[Sample]): if domain not in sample.mopd_teacher_topk_logits: topk_k = getattr(args, "mopd_topk_k", 1024) NEG_INF = float("-inf") - sample.mopd_teacher_topk_logits[domain] = [ - [NEG_INF] * topk_k for _ in range(response_length) - ] - sample.mopd_teacher_topk_indices[domain] = [ - [0] * topk_k for _ in range(response_length) - ] + sample.mopd_teacher_topk_logits[domain] = [[NEG_INF] * topk_k for _ in range(response_length)] + sample.mopd_teacher_topk_indices[domain] = [[0] * topk_k for _ in range(response_length)] def post_process_rewards(args, samples: list[Sample], **kwargs): @@ -696,7 +694,7 @@ async def combined_reward_func(args, sample_or_samples, **kwargs): args.custom_rm_path = original_custom_rm_path # Store MOPD teacher responses in sample metadata - for sample, mopd_result in zip(sample_or_samples, mopd_results): + for sample, mopd_result in zip(sample_or_samples, mopd_results, strict=False): if isinstance(sample.metadata, dict): sample.metadata[_MOPD_TEACHER_RESPONSES_KEY] = mopd_result else: @@ -752,17 +750,14 @@ def combined_post_process_rewards(args, samples: list[Sample], **kwargs): _extract_teacher_data_from_responses(args, samples) # Clean up temporary metadata and restore task rewards - for sample, original_reward in zip(samples, original_rewards): + for sample, original_reward in zip(samples, original_rewards, strict=False): if isinstance(sample.metadata, dict): sample.metadata.pop(_MOPD_TEACHER_RESPONSES_KEY, None) sample.reward = original_reward # Step 2: Apply standard reward post-processing raw_rewards = [sample.get_reward_value(args) for sample in samples] - if ( - args.advantage_estimator in ["grpo", "gspo", "reinforce_plus_plus_baseline"] - and args.rewards_normalization - ): + if args.advantage_estimator in ["grpo", "gspo", "reinforce_plus_plus_baseline"] and args.rewards_normalization: rewards = torch.tensor(raw_rewards, dtype=torch.float) if rewards.shape[-1] == args.n_samples_per_prompt * args.rollout_batch_size: rewards = rewards.reshape(-1, args.n_samples_per_prompt) @@ -777,4 +772,4 @@ def combined_post_process_rewards(args, samples: list[Sample], **kwargs): return raw_rewards, rewards.flatten().tolist() - return raw_rewards, raw_rewards \ No newline at end of file + return raw_rewards, raw_rewards diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 98ca829f9b..1cb26c125c 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -1977,7 +1977,7 @@ def slime_validate_args(args): else: mopd_teachers = args.mopd_teachers except (json.JSONDecodeError, TypeError) as e: - raise ValueError(f"--mopd-teachers must be valid JSON: {e}") + raise ValueError(f"--mopd-teachers must be valid JSON: {e}") from e if not isinstance(mopd_teachers, list) or len(mopd_teachers) == 0: raise ValueError("--mopd-teachers must be a non-empty JSON list of teacher configs.") @@ -2023,9 +2023,7 @@ def slime_validate_args(args): if args.mopd_eps_low < 0: raise ValueError(f"--mopd-eps-low must be >= 0, got {args.mopd_eps_low}.") if args.mopd_eps_high <= args.mopd_eps_low: - raise ValueError( - f"--mopd-eps-high ({args.mopd_eps_high}) must be > --mopd-eps-low ({args.mopd_eps_low})." - ) + raise ValueError(f"--mopd-eps-high ({args.mopd_eps_high}) must be > --mopd-eps-low ({args.mopd_eps_low}).") # Set default teacher mode based on whether mopd_teacher_loads is provided if not hasattr(args, "mopd_teacher_mode") or args.mopd_teacher_mode is None: @@ -2134,9 +2132,8 @@ def slime_validate_args(args): # or custom_rm_path is set, default to "zero" reward. # Note: After SGLang auto-config, custom_rm_path may be set to a MOPD function, # so we check the combination of mopd_alpha and whether there's a real task reward source. - _mopd_uses_combined_rm = ( - args.custom_rm_path is not None - and "slime.rollout.mopd.combined_reward_func" in str(args.custom_rm_path) + _mopd_uses_combined_rm = args.custom_rm_path is not None and "slime.rollout.mopd.combined_reward_func" in str( + args.custom_rm_path ) if args.mopd_alpha > 0 and args.rm_type is None and not _mopd_uses_combined_rm and args.custom_rm_path is None: raise ValueError( diff --git a/slime/utils/ppo_utils.py b/slime/utils/ppo_utils.py index 029f34b4c1..1641b91961 100644 --- a/slime/utils/ppo_utils.py +++ b/slime/utils/ppo_utils.py @@ -1,8 +1,8 @@ # Adapt from https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/models/utils.py # and https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/trainer/ppo_utils/experience_maker.py -from argparse import Namespace import logging +from argparse import Namespace import torch import torch.distributed as dist @@ -267,7 +267,6 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor): s_softmax, local_s_log_prob, local_t_log_prob, kl = ctx.saved_tensors - process_group = ctx.process_group # Gradient: ∂KL/∂z_j = π_s(j) * [log π_s(j) - log π_t(j) - KL] # This is completely local per token — no all_reduce needed in backward. @@ -378,7 +377,6 @@ def vocab_parallel_topk_reverse_kl( teacher_topk_indices = teacher_topk_indices.long() tp_size = dist.get_world_size(group=process_group) if process_group is not None else 1 - k = teacher_topk_logits.size(-1) # Compute validity mask from teacher_topk_logits if not provided. # Entries with -inf logits are padding (e.g., from SGLang TP sharding). @@ -555,9 +553,7 @@ def vocab_parallel_topk_reverse_kl( kl_tail = torch.zeros_like(student_tail_mass) tail_mask = (student_tail_mass > 1e-10) & (teacher_tail_mass > 1e-10) kl_tail[tail_mask] = student_tail_mass[tail_mask] * ( - torch.log(student_tail_mass[tail_mask]) - torch.log( - teacher_tail_mass[tail_mask] - ) + torch.log(student_tail_mass[tail_mask]) - torch.log(teacher_tail_mass[tail_mask]) ) # If teacher_tail_mass ≈ 0 but student_tail_mass > 0, we have an unbounded KL. # This shouldn't happen if k is large enough. We treat it as 0 for numerical safety. diff --git a/slime/utils/types.py b/slime/utils/types.py index 5acad15d6d..505129d796 100644 --- a/slime/utils/types.py +++ b/slime/utils/types.py @@ -38,9 +38,11 @@ class Sample: rollout_routed_experts: list[list[int]] | None = None # Routed experts from rollout engine remove_sample: bool = False teacher_log_probs: list[float] | None = None # Log probabilities from teacher model for OPD - mopd_teacher_log_probs: dict[str, list[float]] | None = None # Log probabilities from multiple MOPD teachers (domain -> log_probs) + mopd_teacher_log_probs: dict[str, list[float]] | None = ( + None # Log probabilities from multiple MOPD teachers (domain -> log_probs) + ) # Full-vocab teacher logits per domain (SGLang MOPD full_vocab mode). - #Format: {domain: list[list[float]]} — domain -> [seq_len][vocab_size] + # Format: {domain: list[list[float]]} — domain -> [seq_len][vocab_size] mopd_teacher_fv_logits: dict[str, list[list[float]]] | None = None # Top-k teacher logits per domain (SGLang MOPD top_k mode). # Format: {domain: list[list[float]]} — domain -> [seq_len][k] diff --git a/slime_plugins/mbridge/qwen3_5.py b/slime_plugins/mbridge/qwen3_5.py index 44393bd5ab..7a49e31777 100644 --- a/slime_plugins/mbridge/qwen3_5.py +++ b/slime_plugins/mbridge/qwen3_5.py @@ -188,7 +188,7 @@ def _detect_expert_weight_format(self, weights_path: str) -> None: index_file = os.path.join(actual_path, "model.safetensors.index.json") if os.path.exists(index_file): - with open(index_file, "r") as f: + with open(index_file) as f: index = json.load(f) weight_map = index.get("weight_map", {}) # Check if fused key exists diff --git a/slime_plugins/megatron_bridge/qwen35_vl_moe.py b/slime_plugins/megatron_bridge/qwen35_vl_moe.py index de5cf091e5..1bd63ca61d 100644 --- a/slime_plugins/megatron_bridge/qwen35_vl_moe.py +++ b/slime_plugins/megatron_bridge/qwen35_vl_moe.py @@ -51,9 +51,9 @@ import itertools import logging +from collections.abc import Mapping from copy import deepcopy from dataclasses import dataclass, field -from typing import Dict, Mapping import torch from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry @@ -65,16 +65,16 @@ ReplicatedMapping, RowParallelMapping, ) +from megatron.bridge.models.qwen.qwen_provider import Qwen3MoEModelProvider # Official Qwen3.5-VL MoE bridge — we inherit from this to reuse its mature # mapping_registry, maybe_modify_converted_hf_weight, and all GDN/MoE/vision # mappings. We only override what differs for our HF-vision-encoder architecture. from megatron.bridge.models.qwen_vl.qwen35_vl_bridge import Qwen35VLMoEBridge as _OfficialQwen35VLMoEBridge - -from megatron.bridge.utils.common_utils import extract_expert_number_from_param - -from megatron.bridge.models.qwen.qwen_provider import Qwen3MoEModelProvider -from megatron.bridge.utils.common_utils import hook_hf_module_setattr_for_tp_grad_sync +from megatron.bridge.utils.common_utils import ( + extract_expert_number_from_param, + hook_hf_module_setattr_for_tp_grad_sync, +) from megatron.core import parallel_state, tensor_parallel from megatron.core.models.gpt import GPTModel as MCoreGPTModel from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec @@ -125,9 +125,9 @@ def _infer_parallelism_from_param_name(param_name: str) -> str: # -- Row-parallel patterns (check first, more specific) -- row_patterns = [ - "linear_proj.weight", # attention output projection - "linear_fc2.weight", # MLP / expert down projection - "out_proj.weight", # GDN output projection + "linear_proj.weight", # attention output projection + "linear_fc2.weight", # MLP / expert down projection + "out_proj.weight", # GDN output projection "shared_experts.linear_fc2", # shared expert down projection ] for pat in row_patterns: @@ -136,20 +136,20 @@ def _infer_parallelism_from_param_name(param_name: str) -> str: # -- Column-parallel patterns -- col_patterns = [ - "linear_qkv", # QKV projection - "linear_q_up_proj", # fused Q+up (some models) - "linear_kv_up_proj", # fused KV+up (some models) - "embedding.word_embeddings", # vocabulary embedding - "output_layer", # output projection - "linear_fc1.weight", # MLP / expert gate+up projection - "in_proj.weight", # GDN input projection - "in_proj_qkv", # GDN QKV part of input projection - "in_proj_z", # GDN z gate - "in_proj_b", # GDN b gate - "in_proj_a", # GDN a gate - "a_log", # GDN A_log parameter - "dt_bias", # GDT dt_bias parameter - "conv1d.weight", # GDN conv1d + "linear_qkv", # QKV projection + "linear_q_up_proj", # fused Q+up (some models) + "linear_kv_up_proj", # fused KV+up (some models) + "embedding.word_embeddings", # vocabulary embedding + "output_layer", # output projection + "linear_fc1.weight", # MLP / expert gate+up projection + "in_proj.weight", # GDN input projection + "in_proj_qkv", # GDN QKV part of input projection + "in_proj_z", # GDN z gate + "in_proj_b", # GDN b gate + "in_proj_a", # GDN a gate + "a_log", # GDN A_log parameter + "dt_bias", # GDT dt_bias parameter + "conv1d.weight", # GDN conv1d "shared_experts.linear_fc1", # shared expert gate+up ] for pat in col_patterns: @@ -158,19 +158,19 @@ def _infer_parallelism_from_param_name(param_name: str) -> str: # -- Replicated patterns -- replicated_patterns = [ - "layernorm", # any layernorm weight/bias - "layer_norm", # alternative spelling - "norm.weight", # standalone norm - "norm.bias", # standalone norm bias - "router.weight", # MoE router - "gate_weight", # shared expert gate - "gate.bias", # gate bias - "input_layernorm", # input layernorm - "pre_mlp_layernorm", # pre-MLP layernorm - "q_layernorm", # Q layernorm - "k_layernorm", # K layernorm - "layer_norm_weight", # fused TE layernorm weight - "layer_norm_bias", # fused TE layernorm bias + "layernorm", # any layernorm weight/bias + "layer_norm", # alternative spelling + "norm.weight", # standalone norm + "norm.bias", # standalone norm bias + "router.weight", # MoE router + "gate_weight", # shared expert gate + "gate.bias", # gate bias + "input_layernorm", # input layernorm + "pre_mlp_layernorm", # pre-MLP layernorm + "q_layernorm", # Q layernorm + "k_layernorm", # K layernorm + "layer_norm_weight", # fused TE layernorm weight + "layer_norm_bias", # fused TE layernorm bias ] for pat in replicated_patterns: if pat in name: @@ -389,9 +389,7 @@ def _gather_input_ids_from_cp( if cp_size <= 1: return input_ids - gathered = torch.distributed.nn.all_gather( - input_ids, group=parallel_state.get_context_parallel_group() - ) + gathered = torch.distributed.nn.all_gather(input_ids, group=parallel_state.get_context_parallel_group()) local_cu_seqlens = cu_seqlens // cp_size num_seqs = len(cu_seqlens) - 1 @@ -400,8 +398,7 @@ def _gather_input_ids_from_cp( seqlen = (cu_seqlens[i + 1] - cu_seqlens[i]).item() chunk_size = seqlen // 2 // cp_size whole_list.extend( - gathered[cp_rank][0, local_cu_seqlens[i] : local_cu_seqlens[i] + chunk_size] - for cp_rank in range(cp_size) + gathered[cp_rank][0, local_cu_seqlens[i] : local_cu_seqlens[i] + chunk_size] for cp_rank in range(cp_size) ) whole_list.extend( [ @@ -622,9 +619,7 @@ def _compute_mrope_position_ids( if modality == 0: # Text tokens n = end - start - pos_list.append( - torch.arange(n, device=device).view(1, -1).expand(3, -1) + current_pos - ) + pos_list.append(torch.arange(n, device=device).view(1, -1).expand(3, -1) + current_pos) current_pos += n else: # Image tokens @@ -716,9 +711,7 @@ def forward( # Scatter to sequence-parallel region if needed if self.config.sequence_parallel: - combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region( - combined_embeddings - ) + combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(combined_embeddings) combined_embeddings = combined_embeddings.contiguous() # 3. Compute M-RoPE position IDs @@ -742,9 +735,7 @@ def forward( # Non-first PP stage: allocate buffer with correct shape if cu_seqlens is not None: T = cu_seqlens[-1].item() - position_ids = torch.zeros( - 3, 1, T, dtype=torch.long, device=torch.cuda.current_device() - ) + position_ids = torch.zeros(3, 1, T, dtype=torch.long, device=torch.cuda.current_device()) else: raise NotImplementedError( "Non-THD position_ids broadcast not yet supported for non-first PP stages" @@ -823,6 +814,7 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None): from megatron.core.models.gpt.experimental_attention_variant_module_specs import ( get_transformer_block_with_experimental_attention_variant_spec, ) + transformer_layer_spec = get_transformer_block_with_experimental_attention_variant_spec( config=self, vp_stage=vp_stage, @@ -1046,9 +1038,9 @@ def mapping_registry(self) -> MegatronMappingRegistry: def maybe_modify_converted_hf_weight( self, task: WeightConversionTask, - converted_weights_dict: Dict[str, torch.Tensor], + converted_weights_dict: dict[str, torch.Tensor], hf_state_dict: Mapping, - ) -> Dict[str, torch.Tensor]: + ) -> dict[str, torch.Tensor]: """Merge per-expert weight exports into a single fused [num_experts, ...] tensor. For **fused HF format** (e.g. 35B model, ``experts.gate_up_proj``), all experts @@ -1096,7 +1088,7 @@ def maybe_modify_converted_hf_weight( if len(expert_ids_in_dict) > 1: return converted_weights_dict - result: Dict[str, torch.Tensor] = {} + result: dict[str, torch.Tensor] = {} for key, value in converted_weights_dict.items(): if key not in self.hf_weights_cache: self.hf_weights_cache[key] = {} @@ -1108,8 +1100,7 @@ def maybe_modify_converted_hf_weight( self.hf_weights_cache[key][local_expert_number] = value else: assert value.shape[0] == ep_size, ( - f"Expected shape[0]=={ep_size} for EP-gathered expert weight " - f"'{key}', got {value.shape}" + f"Expected shape[0]=={ep_size} for EP-gathered expert weight " f"'{key}', got {value.shape}" ) for i, exp_val in enumerate(value): global_expert_number = local_expert_number + (i * experts_per_rank) @@ -1125,10 +1116,7 @@ def maybe_modify_converted_hf_weight( # Move back to CUDA for downstream processing result[key] = merged.cuda() else: - logger.debug( - f"{len(self.hf_weights_cache[key])}/{num_experts} experts " - f"loaded for {key}" - ) + logger.debug(f"{len(self.hf_weights_cache[key])}/{num_experts} experts " f"loaded for {key}") return result @@ -1156,9 +1144,7 @@ def provider_bridge(self, hf_pretrained): # Shared expert intermediate size moe_ffn = getattr(text_config, "moe_intermediate_size", 512) - shared_expert_intermediate = getattr( - text_config, "shared_expert_intermediate_size", 512 - ) + shared_expert_intermediate = getattr(text_config, "shared_expert_intermediate_size", 512) # Read attention bias from config add_qkv_bias = getattr(text_config, "attention_bias", False) @@ -1258,6 +1244,7 @@ def provider_bridge(self, hf_pretrained): return provider + # Apply patches at module load time so that they are active whenever this # bridge module is imported, regardless of import order. -_patch_auto_mapping_for_gdn() \ No newline at end of file +_patch_auto_mapping_for_gdn() diff --git a/tests/test_mopd.py b/tests/test_mopd.py index 41aa960cdd..b74157e28c 100644 --- a/tests/test_mopd.py +++ b/tests/test_mopd.py @@ -7,8 +7,6 @@ 4. Sample.mopd_teacher_log_probs field """ -import json -import os import sys import types from argparse import Namespace @@ -65,6 +63,7 @@ def _import_loss_module(self, monkeypatch): def _get_apply_mopd(self): """Dynamically import apply_mopd_to_advantages from loss.py.""" from slime.backends.megatron_utils.loss import apply_mopd_to_advantages + return apply_mopd_to_advantages def test_basic_mopd_advantage_computation(self): @@ -108,7 +107,7 @@ def test_basic_mopd_advantage_computation(self): assert torch.allclose(mopd_adv, expected_reverse_kl, atol=1e-6) assert torch.allclose(is_weights, torch.ones(3), atol=1e-6) - # Check mopd_reverse_kl is pure reverse_kl (not including alpha * orm_advantage) + # Check mopd_reverse_kl is pure reverse_kl (not including alpha * orm_advantage) reverse_kl_logged = rollout_data["mopd_reverse_kl"]["math"][0] expected_pure_reverse_kl = torch.tensor([0.1, 0.1, 0.1]) assert torch.allclose(reverse_kl_logged, expected_pure_reverse_kl, atol=1e-6) @@ -531,12 +530,15 @@ class TestSampleMopdField: def test_default_none(self): from slime.utils.types import Sample + s = Sample() assert s.mopd_teacher_log_probs is None def test_set_mopd_teacher_log_probs(self): - from slime.utils.types import Sample import torch + + from slime.utils.types import Sample + s = Sample() s.mopd_teacher_log_probs = { "math": torch.tensor([0.1, 0.2, 0.3]), @@ -548,8 +550,9 @@ def test_set_mopd_teacher_log_probs(self): def test_to_dict_roundtrip(self): from slime.utils.types import Sample + s = Sample(response="hello", response_length=1) s.mopd_teacher_log_probs = {"math": [0.1, 0.2, 0.3]} d = s.to_dict() assert "mopd_teacher_log_probs" in d - assert d["mopd_teacher_log_probs"]["math"] == [0.1, 0.2, 0.3] \ No newline at end of file + assert d["mopd_teacher_log_probs"]["math"] == [0.1, 0.2, 0.3] diff --git a/tests/test_mopd_full_vocab.py b/tests/test_mopd_full_vocab.py index 82efc3d0ae..42396070b5 100644 --- a/tests/test_mopd_full_vocab.py +++ b/tests/test_mopd_full_vocab.py @@ -89,9 +89,7 @@ def test_kl_correctness_known_values(self): teacher_log_probs = torch.log_softmax(teacher_logits, dim=-1) expected_kl = (student_probs * (student_log_probs - teacher_log_probs)).sum(dim=-1) - assert torch.allclose(kl, expected_kl, atol=1e-5), ( - f"KL mismatch: got {kl}, expected {expected_kl}" - ) + assert torch.allclose(kl, expected_kl, atol=1e-5), f"KL mismatch: got {kl}, expected {expected_kl}" def test_kl_non_negative(self): """KL divergence should always be non-negative (Gibbs' inequality).""" @@ -118,14 +116,14 @@ def test_kl_gradient_flows_through_student(self): # Student should have gradients assert student_logits.grad is not None, "student_logits should have gradients" - assert not torch.allclose(student_logits.grad, torch.zeros_like(student_logits.grad)), ( - "student_logits gradients should be non-zero" - ) + assert not torch.allclose( + student_logits.grad, torch.zeros_like(student_logits.grad) + ), "student_logits gradients should be non-zero" # Teacher should NOT have gradients (detached inside function) - assert teacher_logits.grad is None or torch.allclose(teacher_logits.grad, torch.zeros_like(teacher_logits.grad)), ( - "teacher_logits should not have gradients (should be detached)" - ) + assert teacher_logits.grad is None or torch.allclose( + teacher_logits.grad, torch.zeros_like(teacher_logits.grad) + ), "teacher_logits should not have gradients (should be detached)" def test_kl_gradient_correctness(self): """Verify the gradient of KL matches autograd from manual computation.""" @@ -151,9 +149,9 @@ def test_kl_gradient_correctness(self): loss_2.backward() grad_manual = student_logits_2.grad.clone() - assert torch.allclose(grad_ours, grad_manual, atol=1e-4), ( - f"Gradient mismatch: max diff = {(grad_ours - grad_manual).abs().max()}" - ) + assert torch.allclose( + grad_ours, grad_manual, atol=1e-4 + ), f"Gradient mismatch: max diff = {(grad_ours - grad_manual).abs().max()}" def test_kl_temperature_sensitivity(self): """KL should change when student distribution changes.""" @@ -195,6 +193,7 @@ def _mock_deps(self, monkeypatch): def _get_function(self): from slime.backends.megatron_utils.loss import apply_mopd_full_vocab_to_loss + return apply_mopd_full_vocab_to_loss def _sum_of_sample_mean(self, tensor): @@ -227,8 +226,12 @@ def test_single_teacher_kl_loss(self): loss_masks = [torch.ones(3), torch.ones(4)] kl_loss, metrics = apply_fn( - args, batch, student_logits, teacher_logits_per_domain, - loss_masks, self._sum_of_sample_mean, + args, + batch, + student_logits, + teacher_logits_per_domain, + loss_masks, + self._sum_of_sample_mean, current_log_probs=current_log_probs, ) @@ -260,8 +263,12 @@ def test_identical_student_teacher_zero_kl(self): loss_masks = [torch.ones(3), torch.ones(4)] kl_loss, metrics = apply_fn( - args, batch, [student_logits_1, student_logits_2], - teacher_logits_per_domain, loss_masks, self._sum_of_sample_mean, + args, + batch, + [student_logits_1, student_logits_2], + teacher_logits_per_domain, + loss_masks, + self._sum_of_sample_mean, current_log_probs=current_log_probs, ) @@ -271,7 +278,8 @@ def test_multi_teacher_averaging(self): """Test that KL is averaged across multiple teachers.""" apply_fn = self._get_function() args = make_mopd_full_vocab_args( - mopd_eps_low=0.0, mopd_eps_high=1000.0, + mopd_eps_low=0.0, + mopd_eps_high=1000.0, _mopd_teachers_parsed=[ {"name": "math_teacher", "domain": "math"}, {"name": "code_teacher", "domain": "code"}, @@ -296,13 +304,18 @@ def test_multi_teacher_averaging(self): loss_masks = [torch.ones(3)] from slime.utils.ppo_utils import vocab_parallel_reverse_kl + kl_math = vocab_parallel_reverse_kl(student_logits[0], teacher_math[0], None) kl_code = vocab_parallel_reverse_kl(student_logits[0], teacher_code[0], None) - expected_avg_kl = (kl_math.sum() / 3 + kl_code.sum() / 4) / 2 # Not exact, just check shape + expected_avg_kl = (kl_math.sum() / 3 + kl_code.sum() / 4) / 2 # Not exact, just check shape # noqa: F841 kl_loss, metrics = apply_fn( - args, batch, student_logits, teacher_logits_per_domain, - loss_masks, self._sum_of_sample_mean, + args, + batch, + student_logits, + teacher_logits_per_domain, + loss_masks, + self._sum_of_sample_mean, current_log_probs=current_log_probs, ) @@ -333,17 +346,21 @@ def test_is_weight_clipping(self): loss_masks = [torch.ones(3)] kl_loss, metrics = apply_fn( - args, batch, student_logits, teacher_logits_per_domain, - loss_masks, self._sum_of_sample_mean, + args, + batch, + student_logits, + teacher_logits_per_domain, + loss_masks, + self._sum_of_sample_mean, current_log_probs=current_log_probs, ) # IS weight should be clipped — only token 2 (weight=1.0) survives # Nonzero fraction should be 1/3 is_nonzero_frac = metrics["mopd_is_nonzero_frac"].item() - assert abs(is_nonzero_frac - 1.0 / 3.0) < 0.05, ( - f"Expected ~1/3 nonzero IS weight fraction, got {is_nonzero_frac}" - ) + assert ( + abs(is_nonzero_frac - 1.0 / 3.0) < 0.05 + ), f"Expected ~1/3 nonzero IS weight fraction, got {is_nonzero_frac}" def test_none_teacher_for_sample(self): """Test that None entries in teacher logits are skipped.""" @@ -351,7 +368,8 @@ def test_none_teacher_for_sample(self): # Two samples, two teachers; sample 0 has only math, sample 1 has both args = make_mopd_full_vocab_args( - mopd_eps_low=0.0, mopd_eps_high=1000.0, + mopd_eps_low=0.0, + mopd_eps_high=1000.0, _mopd_teachers_parsed=[ {"name": "math_teacher", "domain": "math"}, {"name": "code_teacher", "domain": "code"}, @@ -374,8 +392,12 @@ def test_none_teacher_for_sample(self): loss_masks = [torch.ones(3), torch.ones(4)] kl_loss, metrics = apply_fn( - args, batch, [student_0, student_1], - teacher_logits_per_domain, loss_masks, self._sum_of_sample_mean, + args, + batch, + [student_0, student_1], + teacher_logits_per_domain, + loss_masks, + self._sum_of_sample_mean, current_log_probs=current_log_probs, ) @@ -401,16 +423,24 @@ def test_loss_mask_effect(self): current_log_probs = [torch.zeros(5)] kl_loss_masked, _ = apply_fn( - args, batch, student_logits, teacher_logits_per_domain, - loss_masks, self._sum_of_sample_mean, + args, + batch, + student_logits, + teacher_logits_per_domain, + loss_masks, + self._sum_of_sample_mean, current_log_probs=current_log_probs, ) # With all-ones mask for comparison loss_masks_all = [torch.ones(5)] kl_loss_all, _ = apply_fn( - args, batch, student_logits, teacher_logits_per_domain, - loss_masks_all, self._sum_of_sample_mean, + args, + batch, + student_logits, + teacher_logits_per_domain, + loss_masks_all, + self._sum_of_sample_mean, current_log_probs=current_log_probs, ) @@ -442,16 +472,20 @@ def test_current_log_probs_used_for_is_weights(self): loss_masks = [torch.ones(3)] kl_loss, metrics = apply_fn( - args, batch, student_logits, teacher_logits_per_domain, - loss_masks, self._sum_of_sample_mean, + args, + batch, + student_logits, + teacher_logits_per_domain, + loss_masks, + self._sum_of_sample_mean, current_log_probs=current_log_probs, ) # Only token 1 should survive IS weight clipping → 1/3 nonzero is_nonzero_frac = metrics["mopd_is_nonzero_frac"].item() - assert abs(is_nonzero_frac - 1.0 / 3.0) < 0.05, ( - f"Expected ~1/3 nonzero IS weight fraction with current_log_probs, got {is_nonzero_frac}" - ) + assert ( + abs(is_nonzero_frac - 1.0 / 3.0) < 0.05 + ), f"Expected ~1/3 nonzero IS weight fraction with current_log_probs, got {is_nonzero_frac}" def test_current_log_probs_length_mismatch_raises(self): """Test that mismatched current_log_probs length raises ValueError.""" @@ -469,8 +503,12 @@ def test_current_log_probs_length_mismatch_raises(self): with pytest.raises(ValueError, match="student_log_probs length"): apply_fn( - args, batch, student_logits, teacher_logits_per_domain, - loss_masks, self._sum_of_sample_mean, + args, + batch, + student_logits, + teacher_logits_per_domain, + loss_masks, + self._sum_of_sample_mean, current_log_probs=bad_current_log_probs, ) @@ -630,9 +668,9 @@ def test_no_temperature_scaling(self): # Compare with manual extraction of response logits from the input # Response logits: logits[0, T-R-1:T-1, :] (shifted by 1 for next-token prediction) expected_logits = logits[0, T - R - 1 : T - 1, :] # [R, V] - assert torch.allclose(logits_out[0], expected_logits, atol=1e-5), ( - "get_logits_for_distill should return raw logits without temperature scaling" - ) + assert torch.allclose( + logits_out[0], expected_logits, atol=1e-5 + ), "get_logits_for_distill should return raw logits without temperature scaling" # --------------------------------------------------------------------------- @@ -654,13 +692,11 @@ def test_topk_kl_approximates_full_kl(self): teacher_logits = torch.randn(R, V) # Full-vocab KL - full_kl = vocab_parallel_reverse_kl(student_logits, teacher_logits, process_group=None) + full_kl = vocab_parallel_reverse_kl(student_logits, teacher_logits, process_group=None) # noqa: F841 # Top-k KL topk_vals, topk_idx = teacher_logits.topk(k, dim=-1) - topk_kl = vocab_parallel_topk_reverse_kl( - student_logits, topk_vals, topk_idx, V, process_group=None - ) + topk_kl = vocab_parallel_topk_reverse_kl(student_logits, topk_vals, topk_idx, V, process_group=None) # With k close to V, the top-k should be close to full-vocab KL # Allow some tolerance due to tail approximation @@ -697,8 +733,9 @@ def test_topk_kl_gradient_flows(self): loss.backward() assert student_logits.grad is not None, "student_logits should have gradients" - assert not torch.allclose(student_logits.grad, torch.zeros_like(student_logits.grad)), \ - "student_logits gradients should be non-zero" + assert not torch.allclose( + student_logits.grad, torch.zeros_like(student_logits.grad) + ), "student_logits gradients should be non-zero" def test_topk_kl_increases_with_smaller_k(self): """Top-k KL should generally increase as k decreases (less accurate approximation).""" @@ -710,13 +747,15 @@ def test_topk_kl_increases_with_smaller_k(self): student_logits = torch.randn(R, V) teacher_logits = torch.randn(R, V) - full_kl = vocab_parallel_reverse_kl(student_logits, teacher_logits, process_group=None).sum().item() + full_kl = ( # noqa: F841 + vocab_parallel_reverse_kl(student_logits, teacher_logits, process_group=None).sum().item() + ) k_large = 40 topk_vals_l, topk_idx_l = teacher_logits.topk(k_large, dim=-1) - kl_large = vocab_parallel_topk_reverse_kl( - student_logits, topk_vals_l, topk_idx_l, V, process_group=None - ).sum().item() + kl_large = ( + vocab_parallel_topk_reverse_kl(student_logits, topk_vals_l, topk_idx_l, V, process_group=None).sum().item() + ) # With k=V (full vocab), top-k should be closer to full KL assert torch.isfinite(torch.tensor(kl_large)), "Top-k KL should be finite" @@ -734,6 +773,7 @@ def _mock_deps(self, monkeypatch): def _get_function(self): from slime.backends.megatron_utils.loss import apply_mopd_topk_to_loss + return apply_mopd_topk_to_loss def _sum_of_sample_mean(self, tensor): @@ -785,8 +825,13 @@ def test_single_teacher_topk_loss(self): loss_masks = [torch.ones(R1), torch.ones(R2)] kl_loss, metrics = apply_fn( - args, batch, student_logits, teacher_topk_logits, - teacher_topk_indices, loss_masks, self._sum_of_sample_mean, + args, + batch, + student_logits, + teacher_topk_logits, + teacher_topk_indices, + loss_masks, + self._sum_of_sample_mean, current_log_probs=current_log_probs, ) @@ -816,8 +861,13 @@ def test_topk_loss_is_non_negative(self): loss_masks = [torch.ones(5)] kl_loss, metrics = apply_fn( - args, batch, student_logits, teacher_topk_logits, - teacher_topk_indices, loss_masks, self._sum_of_sample_mean, + args, + batch, + student_logits, + teacher_topk_logits, + teacher_topk_indices, + loss_masks, + self._sum_of_sample_mean, current_log_probs=current_log_probs, ) @@ -844,15 +894,20 @@ def test_topk_is_weight_clipping(self): loss_masks = [torch.ones(3)] kl_loss, metrics = apply_fn( - args, batch, student_logits, teacher_topk_logits, - teacher_topk_indices, loss_masks, self._sum_of_sample_mean, + args, + batch, + student_logits, + teacher_topk_logits, + teacher_topk_indices, + loss_masks, + self._sum_of_sample_mean, current_log_probs=current_log_probs, ) is_nonzero_frac = metrics["mopd_is_nonzero_frac"].item() - assert abs(is_nonzero_frac - 1.0 / 3.0) < 0.05, ( - f"Expected ~1/3 nonzero IS weight fraction, got {is_nonzero_frac}" - ) + assert ( + abs(is_nonzero_frac - 1.0 / 3.0) < 0.05 + ), f"Expected ~1/3 nonzero IS weight fraction, got {is_nonzero_frac}" def test_topk_none_teacher_for_sample(self): """Test that None entries in teacher data are skipped.""" @@ -891,9 +946,13 @@ def test_topk_none_teacher_for_sample(self): loss_masks = [torch.ones(3), torch.ones(4)] kl_loss, metrics = apply_fn( - args, batch, [student_0, student_1], - teacher_topk_logits, teacher_topk_indices, - loss_masks, self._sum_of_sample_mean, + args, + batch, + [student_0, student_1], + teacher_topk_logits, + teacher_topk_indices, + loss_masks, + self._sum_of_sample_mean, current_log_probs=current_log_probs, ) @@ -912,7 +971,7 @@ def test_topk_k_parameter_effect(self): teacher_logits_raw = [torch.randn(5, V)] # Full-vocab KL as ground truth - full_kl = vocab_parallel_reverse_kl(student_logits[0], teacher_logits_raw[0], None).sum().item() + full_kl = vocab_parallel_reverse_kl(student_logits[0], teacher_logits_raw[0], None).sum().item() # noqa: F841 # Top-k with k=5 k_small = 5 @@ -923,9 +982,14 @@ def test_topk_k_parameter_effect(self): loss_masks = [torch.ones(5)] kl_small, _ = apply_fn( - args_small, batch, student_logits, - {"math": [topk_vals_s]}, {"math": [topk_idx_s]}, - loss_masks, self._sum_of_sample_mean, current_log_probs=current_log_probs, + args_small, + batch, + student_logits, + {"math": [topk_vals_s]}, + {"math": [topk_idx_s]}, + loss_masks, + self._sum_of_sample_mean, + current_log_probs=current_log_probs, ) # Top-k with k=18 (close to V) @@ -934,9 +998,14 @@ def test_topk_k_parameter_effect(self): args_large = self._make_args(mopd_topk_k=k_large) kl_large, _ = apply_fn( - args_large, batch, student_logits, - {"math": [topk_vals_l]}, {"math": [topk_idx_l]}, - loss_masks, self._sum_of_sample_mean, current_log_probs=current_log_probs, + args_large, + batch, + student_logits, + {"math": [topk_vals_l]}, + {"math": [topk_idx_l]}, + loss_masks, + self._sum_of_sample_mean, + current_log_probs=current_log_probs, ) # Larger k should generally be closer to full KL (both are approximations) @@ -1038,4 +1107,4 @@ def test_topk_k_default(self): mopd_alpha=0.0, ) slime_validate_args(args) - assert args.mopd_topk_k == 1024 \ No newline at end of file + assert args.mopd_topk_k == 1024 diff --git a/tests/test_mopd_sglang_topk_pipeline.py b/tests/test_mopd_sglang_topk_pipeline.py index 7bf0bc12d7..985f3ef57e 100644 --- a/tests/test_mopd_sglang_topk_pipeline.py +++ b/tests/test_mopd_sglang_topk_pipeline.py @@ -27,6 +27,7 @@ # 工具函数 # =========================================================================== + def _softmax(logits): """Numerically stable softmax.""" max_val = max(logits) @@ -42,13 +43,14 @@ def _log_softmax(logits): return [x - log_sum_exp for x in logits] -NEG_INF = float('-inf') +NEG_INF = float("-inf") # =========================================================================== # 1. 模拟 SGLang 响应 # =========================================================================== + def make_mock_sglang_response(vocab_size, seq_len, topk_k, input_ids): """构造模拟的 SGLang /generate 响应。 @@ -57,6 +59,7 @@ def make_mock_sglang_response(vocab_size, seq_len, topk_k, input_ids): - meta_info["input_top_logprobs"]: [[(log_prob, token_id, None), ...], ...] """ import random + random.seed(42) input_token_logprobs = [] @@ -92,6 +95,7 @@ def make_mock_sglang_response(vocab_size, seq_len, topk_k, input_ids): # 测试 1: SGLang 响应格式与字段名 # =========================================================================== + def test_sglang_response_format(): """验证 SGLang 响应中的字段名与 mopd.py 解析代码一致。""" # SGLang 源码 tokenizer_manager.py:1757 中的确认字段名 @@ -125,8 +129,10 @@ def test_sglang_response_format(): # 测试 2: _build_payload 构造正确的 SGLang 请求 # =========================================================================== + def test_build_payload(): """验证 _build_payload 根据蒸馏类型构造正确的 payload。""" + # 直接模拟 _build_payload 的逻辑,不导入 slime def build_payload(sample_tokens, mopd_distill_type, mopd_topk_k=1024): payload = { @@ -161,7 +167,7 @@ def build_payload(sample_tokens, mopd_distill_type, mopd_topk_k=1024): # full_vocab 模式应报错 try: build_payload([1, 2, 3], "full_vocab") - assert False, "full_vocab 应抛出 ValueError" + raise AssertionError("full_vocab 应抛出 ValueError") except ValueError as e: assert "full_vocab" in str(e) print(" full_vocab raises ValueError ✓") @@ -173,6 +179,7 @@ def build_payload(sample_tokens, mopd_distill_type, mopd_topk_k=1024): # 测试 3: post_process_rewards 提取逻辑 # =========================================================================== + def test_post_process_rewards_extraction(): """验证从 SGLang 响应中提取 top-k 数据的逻辑。""" vocab_size = 200 @@ -224,7 +231,7 @@ def test_post_process_rewards_extraction(): print(f" top_k: 提取 {response_length} x {topk_k} 数据 ✓") # 验证:SGLang 返回的 k 等于 topk_k 时,不应有 -inf padding - no_padding_count = sum(1 for l in topk_logits_list[0] if l != NEG_INF) + no_padding_count = sum(1 for v in topk_logits_list[0] if v != NEG_INF) assert no_padding_count == topk_k, f"应无 padding, 实际有效数={no_padding_count}" print(f" top_k: SGLang 返回 {topk_k} 个条目,无 padding ✓") @@ -241,6 +248,7 @@ def test_post_process_rewards_extraction(): # 测试 4: TP 分片 — 全局 token ID → 局部索引 + -inf padding # =========================================================================== + def test_tp_sharding(): """模拟 actor.py 中 SGLang top-k 数据的 TP 分片逻辑。 @@ -278,14 +286,8 @@ def test_tp_sharding(): for pos in range(seq_len): # 模拟分片逻辑 - in_shard = [ - (vocab_offset <= idx < vocab_offset + vocab_local_size) - for idx in all_topk_indices[pos] - ] - local_indices = [ - max(0, min(idx - vocab_offset, vocab_local_size - 1)) - for idx in all_topk_indices[pos] - ] + in_shard = [(vocab_offset <= idx < vocab_offset + vocab_local_size) for idx in all_topk_indices[pos]] + local_indices = [max(0, min(idx - vocab_offset, vocab_local_size - 1)) for idx in all_topk_indices[pos]] # 构建 shard 内 top-k local_topk_logits = [NEG_INF] * topk_k @@ -298,23 +300,20 @@ def test_tp_sharding(): slot += 1 # 验证 padding 用的是 -inf - padding_count = topk_k - slot for i in range(slot, topk_k): assert local_topk_logits[i] == NEG_INF, f"padding 应为 -inf, 实际={local_topk_logits[i]}" assert local_topk_indices[i] == 0, f"padding index 应为 0, 实际={local_topk_indices[i]}" # 验证 valid_topk_mask 自动检测 - valid_mask = [l != NEG_INF for l in local_topk_logits] + valid_mask = [v != NEG_INF for v in local_topk_logits] assert sum(valid_mask) == slot, f"rank={tp_rank} pos={pos}: 有效数={sum(valid_mask)}, 期望={slot}" # 验证有效条目的局部索引正确 for i in range(slot): - expected_local = all_topk_indices[pos][ - [j for j, v in enumerate(in_shard) if v][i] - ] - vocab_offset - assert local_topk_indices[i] == expected_local, ( - f"rank={tp_rank} pos={pos}: 局部索引={local_topk_indices[i]}, 期望={expected_local}" - ) + expected_local = all_topk_indices[pos][[j for j, v in enumerate(in_shard) if v][i]] - vocab_offset + assert ( + local_topk_indices[i] == expected_local + ), f"rank={tp_rank} pos={pos}: 局部索引={local_topk_indices[i]}, 期望={expected_local}" print(f" 分片验证: tp_size={tp_size}, vocab_local_size={vocab_local_size} ✓") @@ -323,10 +322,7 @@ def test_tp_sharding(): total_valid = 0 for tp_rank in range(tp_size): vocab_offset = tp_rank * vocab_local_size - in_shard = sum( - 1 for idx in all_topk_indices[pos] - if vocab_offset <= idx < vocab_offset + vocab_local_size - ) + in_shard = sum(1 for idx in all_topk_indices[pos] if vocab_offset <= idx < vocab_offset + vocab_local_size) total_valid += in_shard assert total_valid == topk_k, f"pos={pos}: 总有效数={total_valid}, 期望={topk_k}" @@ -335,8 +331,8 @@ def test_tp_sharding(): # 验证:0.0 padding 的旧 bug 会导致 valid_mask 误判 old_padding_logits = [2.0, 1.5, NEG_INF, NEG_INF, NEG_INF, NEG_INF, NEG_INF, NEG_INF, NEG_INF, NEG_INF] bad_padding_logits = [2.0, 1.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] # 旧 bug - correct_mask = [l != NEG_INF for l in old_padding_logits] - wrong_mask = [l != NEG_INF for l in bad_padding_logits] # 全部为 True! + correct_mask = [v != NEG_INF for v in old_padding_logits] + wrong_mask = [v != NEG_INF for v in bad_padding_logits] # 全部为 True! assert sum(correct_mask) == 2, "-inf padding: 2 个有效条目 ✓" assert sum(wrong_mask) == topk_k, f"0.0 padding bug: 所有 {topk_k} 个都被误判为有效 ✗" print(" -inf padding vs 0.0 padding 对比: 旧 bug 已确认修复 ✓") @@ -348,6 +344,7 @@ def test_tp_sharding(): # 测试 5: 近似 reverse KL 计算(单进程模拟) # =========================================================================== + def test_topk_reverse_kl_approximation(): """模拟 vocab_parallel_topk_reverse_kl 核心计算逻辑。 @@ -361,6 +358,7 @@ def test_topk_reverse_kl_approximation(): seq_len = 3 import random + random.seed(123) total_error = 0.0 @@ -377,17 +375,13 @@ def test_topk_reverse_kl_approximation(): s_log_probs = _log_softmax(s_logits) # 1. 精确 KL (全词表) - exact_kl = sum( - s_probs[y] * (s_log_probs[y] - t_log_probs[y]) - for y in range(vocab_size) - if s_probs[y] > 1e-15 - ) + exact_kl = sum(s_probs[y] * (s_log_probs[y] - t_log_probs[y]) for y in range(vocab_size) if s_probs[y] > 1e-15) # 2. teacher top-k t_indexed = [(t_logits[i], i) for i in range(vocab_size)] t_indexed.sort(key=lambda x: -x[0]) topk_global_indices = [t_indexed[k][1] for k in range(topk_k)] - topk_teacher_logits = [t_logits[idx] for idx in topk_global_indices] + topk_teacher_logits = [t_logits[idx] for idx in topk_global_indices] # noqa: F841 # 3. 在 top-k 位置收集 student 概率 student_topk_probs = [s_probs[idx] for idx in topk_global_indices] @@ -401,7 +395,7 @@ def test_topk_reverse_kl_approximation(): # 5. KL_topk = Σ_{y ∈ topk} π_s(y) [log π_s(y) - log π_t(y)] kl_topk = sum( sp * (slp - tlp) - for sp, slp, tlp in zip(student_topk_probs, student_topk_log_probs, teacher_topk_log_probs) + for sp, slp, tlp in zip(student_topk_probs, student_topk_log_probs, teacher_topk_log_probs, strict=False) ) # 6. 尾部修正 @@ -421,8 +415,10 @@ def test_topk_reverse_kl_approximation(): max_error = max(max_error, error) if pos == 0: - print(f" pos=0: exact_kl={exact_kl:.6f}, approx_kl={approx_kl:.6f}, " - f"error={error:.6f} ({error/max(abs(exact_kl), 1e-10)*100:.1f}%)") + print( + f" pos=0: exact_kl={exact_kl:.6f}, approx_kl={approx_kl:.6f}, " + f"error={error:.6f} ({error/max(abs(exact_kl), 1e-10)*100:.1f}%)" + ) print(f" kl_topk={kl_topk:.6f}, kl_tail={kl_tail:.6f}") print(f" student_topk_mass={student_topk_mass:.4f}, student_tail_mass={student_tail_mass:.4f}") print(f" teacher_tail_mass={teacher_tail_mass:.4f}") @@ -440,6 +436,7 @@ def test_topk_reverse_kl_approximation(): # 测试 6: combined_reward_func bypass 逻辑 # =========================================================================== + def test_combined_reward_func_bypass(): """验证 combined_reward_func 中 custom_rm_path bypass 模式。""" args = SimpleNamespace(custom_rm_path="slime.rollout.mopd.combined_reward_func") @@ -460,6 +457,7 @@ def test_combined_reward_func_bypass(): # 测试 7: arguments.py 自动配置逻辑 # =========================================================================== + def test_arguments_auto_config(): """验证 SGLang 模式自动配置逻辑。""" # 场景 1: 纯蒸馏 (alpha=0, 无 rm_type) @@ -497,10 +495,7 @@ def test_arguments_auto_config(): print(" 场景2 (alpha>0): 使用 combined 函数 ✓") # 场景 3: alpha>0 但没有 rm_type → 应该报错 - _mopd_uses_combined_rm = ( - args2.custom_rm_path is not None - and "combined_reward_func" in args2.custom_rm_path - ) + _mopd_uses_combined_rm = args2.custom_rm_path is not None and "combined_reward_func" in args2.custom_rm_path # 在真实代码中,如果 combined_rm 需要 rm_type 但 rm_type=None,应该报错 assert _mopd_uses_combined_rm # 模拟验证逻辑 @@ -532,6 +527,7 @@ def test_arguments_auto_config(): # 测试 8: 端到端数据流模拟 # =========================================================================== + def test_end_to_end_data_flow(): """模拟完整数据流: SGLang响应 → mopd.py提取 → rollout.py收集 → actor.py TP分片。 @@ -603,14 +599,8 @@ def test_end_to_end_data_flow(): global_indices = indices_per_sample[pos] global_logits = logits_per_sample[pos] - in_shard = [ - (vocab_offset <= idx < vocab_offset + vocab_local_size) - for idx in global_indices - ] - local_indices = [ - max(0, min(idx - vocab_offset, vocab_local_size - 1)) - for idx in global_indices - ] + in_shard = [(vocab_offset <= idx < vocab_offset + vocab_local_size) for idx in global_indices] + local_indices = [max(0, min(idx - vocab_offset, vocab_local_size - 1)) for idx in global_indices] l_logits = [NEG_INF] * topk_k l_indices = [0] * topk_k @@ -629,9 +619,7 @@ def test_end_to_end_data_flow(): # 验证: 每个 shard 中每个位置都有有效条目 for pos in range(response_length): - valid_count = sum( - 1 for l in local_topk_logits_all[0][pos] if l != NEG_INF - ) + valid_count = sum(1 for v in local_topk_logits_all[0][pos] if v != NEG_INF) assert valid_count > 0, f"rank={tp_rank} pos={pos} 无有效条目" # padding 条目应为 -inf for k in range(valid_count, topk_k): @@ -658,6 +646,7 @@ def test_end_to_end_data_flow(): # 测试 9: 边界情况 # =========================================================================== + def test_edge_cases(): """测试边界情况。""" # Case 1: topk_k 大于 vocab_size @@ -672,16 +661,16 @@ def test_edge_cases(): assert padding_needed == topk_k - vocab_size # padding 用 -inf 和 index 0 pad_logits = [NEG_INF] * padding_needed - pad_indices = [0] * padding_needed - assert all(l == NEG_INF for l in pad_logits) + pad_indices = [0] * padding_needed # noqa: F841 + assert all(v == NEG_INF for v in pad_logits) print(f" 边界1: topk_k > vocab_size, padding={padding_needed} ✓") # Case 2: 空 top-k 数据 (pos_data is None) pos_data = None if pos_data is None or len(pos_data) == 0: pad_logits = [NEG_INF] * 8 - pad_indices = [0] * 8 - assert all(l == NEG_INF for l in pad_logits) + pad_indices = [0] * 8 # noqa: F841 + assert all(v == NEG_INF for v in pad_logits) print(" 边界2: 空位置数据 → 全部 -inf padding ✓") # Case 3: response 长度小于 top-k 数据长度 @@ -696,6 +685,7 @@ def test_edge_cases(): # Case 4: 单 teacher (domain="default") 与多 teacher # 仅验证 MOPD_TEACHERS_JSON 格式解析 import json + single_teacher = json.loads('[{"name":"teacher1","domain":"default"}]') multi_teacher = json.loads('[{"name":"math","domain":"math"},{"name":"code","domain":"code"}]') assert len(single_teacher) == 1 @@ -737,6 +727,7 @@ def test_edge_cases(): except Exception as e: print(f"[FAIL] {name}: {e}") import traceback + traceback.print_exc() failed += 1 @@ -744,4 +735,4 @@ def test_edge_cases(): print(f"结果: {passed} 通过, {failed} 失败") print(f"{'=' * 60}") - sys.exit(0 if failed == 0 else 1) \ No newline at end of file + sys.exit(0 if failed == 0 else 1) diff --git a/tools/convert_torch_dist_to_hf_parallel.py b/tools/convert_torch_dist_to_hf_parallel.py index 5c1378ed89..01584fe532 100644 --- a/tools/convert_torch_dist_to_hf_parallel.py +++ b/tools/convert_torch_dist_to_hf_parallel.py @@ -382,6 +382,7 @@ def _merge_missing_keys_from_origin_hf(origin_hf_dir, output_dir, converted_weig print(f"Warning: {src_path} not found. Skipping keys: {keys}") continue from safetensors import safe_open + with safe_open(src_path, framework="pt", device="cpu") as f: for key in keys: missing_tensors[key] = f.get_tensor(key) @@ -389,9 +390,7 @@ def _merge_missing_keys_from_origin_hf(origin_hf_dir, output_dir, converted_weig # Now we need to insert these tensors into the existing safetensors files. # Strategy: find the last safetensors file, add the missing tensors there, # or create a new file if it would exceed chunk_size. - total_files = max( - int(v.split("-")[-2]) for v in converted_weight_map.values() - ) + total_files = max(int(v.split("-")[-2]) for v in converted_weight_map.values()) # Re-number files to include the missing keys in a new shard # First, collect existing tensors from the last file and append missing ones last_file_pattern = f"model-{total_files:05d}-of-{total_files:05d}.safetensors" @@ -686,5 +685,7 @@ def conversion_worker( # These keys exist in the original HF checkpoint but are not present in the Megatron checkpoint, # because Megatron only trains the language model part. if args.origin_hf_dir: - _merge_missing_keys_from_origin_hf(args.origin_hf_dir, args.output_dir, final_weight_map_fixed, args.chunk_size) + _merge_missing_keys_from_origin_hf( + args.origin_hf_dir, args.output_dir, final_weight_map_fixed, args.chunk_size + ) copy_assets(args.origin_hf_dir, args.output_dir) diff --git a/tools/merge_missing_keys.py b/tools/merge_missing_keys.py index 0cf26abb70..1ee3789c05 100644 --- a/tools/merge_missing_keys.py +++ b/tools/merge_missing_keys.py @@ -25,10 +25,14 @@ def main(): parser = argparse.ArgumentParser(description="Merge missing keys from original HF model into converted checkpoint") - parser.add_argument("--origin-hf-dir", type=str, required=True, help="Path to the original HuggingFace model directory") + parser.add_argument( + "--origin-hf-dir", type=str, required=True, help="Path to the original HuggingFace model directory" + ) parser.add_argument("--converted-dir", type=str, required=True, help="Path to the converted checkpoint directory") parser.add_argument("--dry-run", action="store_true", help="Only print missing keys without merging") - parser.add_argument("--chunk-size", type=int, default=5 * 1024**3, help="Chunk size for safetensors files (default 5GB)") + parser.add_argument( + "--chunk-size", type=int, default=5 * 1024**3, help="Chunk size for safetensors files (default 5GB)" + ) args = parser.parse_args() # Load both index files @@ -57,6 +61,7 @@ def main(): # Categorize missing keys from collections import Counter + prefix_patterns = Counter() for key in missing_keys: parts = key.split(".") @@ -105,11 +110,6 @@ def main(): missing_size = sum(t.numel() * t.element_size() for t in missing_tensors.values()) print(f"Missing tensors total size: {missing_size / 1e9:.2f} GB") - # Find the last file and check its size - last_file_idx = total_files - last_file_name = f"model-{last_file_idx:05d}-of-{total_files:05d}.safetensors" - last_file_path = os.path.join(args.converted_dir, last_file_name) - # Add missing tensors to a new shard new_total = total_files + 1 new_shard_name = f"model-{new_total:05d}-of-{new_total:05d}.safetensors" @@ -153,4 +153,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tools/patch_attention_gate_on_cluster.py b/tools/patch_attention_gate_on_cluster.py index b9e05e7d18..a2e95b56d9 100644 --- a/tools/patch_attention_gate_on_cluster.py +++ b/tools/patch_attention_gate_on_cluster.py @@ -14,15 +14,16 @@ """ import argparse +import subprocess import textwrap import ray -import subprocess FILE_PATH = "/root/Megatron-LM/megatron/core/transformer/attention.py" REMOTE_SCRIPT_PATH = "/tmp/_patch_attention_gate.py" -DIAGNOSE_SCRIPT = textwrap.dedent(f"""\ +DIAGNOSE_SCRIPT = textwrap.dedent( + f"""\ import sys FILE_PATH = {FILE_PATH!r} try: @@ -64,10 +65,12 @@ for i, line in enumerate(lines): if "output_gate" in line or "gate.reshape" in line: print(f" {{i+1}}: {{line.rstrip()}}") -""") +""" +) # Patch script: uses robust line-by-line approach instead of string matching -PATCH_SCRIPT = textwrap.dedent(f"""\ +PATCH_SCRIPT = textwrap.dedent( + f"""\ import sys, shutil FILE_PATH = {FILE_PATH!r} @@ -144,9 +147,11 @@ with open(FILE_PATH, "w") as f: f.writelines(new_lines) print("PATCHED") -""") +""" +) -ROLLBACK_SCRIPT = textwrap.dedent(f"""\ +ROLLBACK_SCRIPT = textwrap.dedent( + f"""\ import sys, shutil, os FILE_PATH = {FILE_PATH!r} @@ -158,7 +163,8 @@ shutil.copy2(backup_path, FILE_PATH) print("ROLLED_BACK") -""") +""" +) def main(): @@ -176,11 +182,7 @@ def main(): ray.init(address="auto") - nodes = [ - n["NodeManagerAddress"] - for n in ray.nodes() - if n["Alive"] - ] + nodes = [n["NodeManagerAddress"] for n in ray.nodes() if n["Alive"]] print(f"Found {len(nodes)} alive nodes") # Only check one node for diagnose (they should all be the same) @@ -188,6 +190,7 @@ def main(): tasks = [] for node_ip in target_nodes: + @ray.remote(resources={f"node:{node_ip}": 0.001}) def run_on_node(node_ip=node_ip): # Step 1: write script to temp file From f51403558a47290d190fc9dfabe1859be73aca4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=87=95=E4=B8=80?= Date: Thu, 11 Jun 2026 14:10:17 +0800 Subject: [PATCH 14/14] fix(tests): fix two bugs in test_mopd_full_vocab.py 1. Fix RuntimeError in test_topk_kl_identical_distributions: kl.item() fails on multi-element tensor, changed to (kl >= -0.1).all() 2. Add missing vocab_size=20 to TestApplyMopdTopkToLoss._make_args(): apply_mopd_topk_to_loss requires args.vocab_size which was not set --- tests/test_mopd_full_vocab.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_mopd_full_vocab.py b/tests/test_mopd_full_vocab.py index 42396070b5..0c80aa732c 100644 --- a/tests/test_mopd_full_vocab.py +++ b/tests/test_mopd_full_vocab.py @@ -716,7 +716,7 @@ def test_topk_kl_identical_distributions(self): assert kl.shape == (3,) # Should be close to 0 (not exact due to tail approximation with V > k) - assert kl.item() >= -0.1, f"Top-k KL should be ~0 for identical distributions, got {kl}" + assert (kl >= -0.1).all(), f"Top-k KL should be ~0 for identical distributions, got {kl}" def test_topk_kl_gradient_flows(self): """Gradient flows through student logits in top-k KL.""" @@ -793,6 +793,7 @@ def _make_args(self, **overrides): mopd_sampling_logprobs_key="rollout_log_probs", _mopd_teachers_parsed=[{"name": "math_teacher", "domain": "math"}], padded_vocab_size=20, + vocab_size=20, ) defaults.update(overrides) return Namespace(**defaults)