-
Notifications
You must be signed in to change notification settings - Fork 308
feat: add RL checkpoint format backward compat integration test #2776
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| max_steps = 7 | ||
| seq_len = 2048 | ||
|
|
||
| [ckpt] | ||
| resume_step = 3 | ||
|
|
||
| [model] | ||
| name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" | ||
|
|
||
| [trainer.model.debug] | ||
| num_layers = 1 | ||
| random_init = true | ||
|
|
||
| [trainer.optim] | ||
| lr = 3e-6 | ||
|
|
||
| [orchestrator] | ||
| batch_size = 128 | ||
| group_size = 16 | ||
|
|
||
| [orchestrator.train.sampling] | ||
| max_completion_tokens = 128 | ||
|
|
||
| [[orchestrator.train.env]] | ||
| id = "reverse-text" | ||
|
|
||
| [inference] | ||
|
|
||
| [orchestrator.renderer] | ||
| name = "default" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| max_steps = 5 | ||
| seq_len = 2048 | ||
|
|
||
| [ckpt] | ||
| interval = 3 | ||
|
|
||
| [model] | ||
| name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" | ||
|
|
||
| [trainer.model.debug] | ||
| num_layers = 1 | ||
| random_init = true | ||
|
|
||
| [trainer.optim] | ||
| lr = 3e-6 | ||
|
|
||
| [orchestrator] | ||
| batch_size = 128 | ||
| group_size = 16 | ||
|
|
||
| [orchestrator.train.sampling] | ||
| max_completion_tokens = 128 | ||
|
|
||
| [[orchestrator.train.env]] | ||
| id = "reverse-text" | ||
|
|
||
| [inference] | ||
|
|
||
| [orchestrator.renderer] | ||
| name = "default" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,210 @@ | ||
| """Integration test for checkpoint format backward compatibility. | ||
|
|
||
| Generates a checkpoint using a pinned version of prime-rl (via git worktree), | ||
| then verifies the current version can resume RL training from that checkpoint. | ||
| This catches breaking changes to the checkpoint format across versions. | ||
|
|
||
| Set CKPT_FORMAT_REF to a git tag or commit hash to pin the checkpoint | ||
| generator version. Defaults to the current HEAD (no backward compat check). | ||
|
|
||
| When pinning to an older version, make sure the config files in | ||
| configs/ci/integration/ckpt_compat/ are compatible with that version. | ||
| The test copies them into the worktree automatically. | ||
| """ | ||
|
|
||
| import os | ||
| import subprocess | ||
| import tempfile | ||
| from pathlib import Path | ||
| from typing import Generator | ||
|
|
||
| import pytest | ||
|
|
||
| from tests.conftest import ProcessResult | ||
| from tests.utils import check_no_error, check_reward_goes_up, check_reward_in_range, strip_escape_codes | ||
|
|
||
| pytestmark = [pytest.mark.gpu, pytest.mark.slow] | ||
|
|
||
| TIMEOUT = 600 # 10 minutes (includes uv sync for worktree) | ||
|
|
||
| # Pin this to a specific git tag/commit when the checkpoint format stabilizes. | ||
| # When the format intentionally changes, update this ref to the new commit. | ||
| # None = use current HEAD (no cross-version check). | ||
| CKPT_FORMAT_REF = os.environ.get("CKPT_FORMAT_REF", None) | ||
|
|
||
|
|
||
| def _get_head_commit(repo_dir: Path) -> str: | ||
| return subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=repo_dir).decode().strip() | ||
|
|
||
|
|
||
| def _resolve_ref(repo_dir: Path, ref: str) -> str: | ||
| return subprocess.check_output(["git", "rev-parse", ref], cwd=repo_dir).decode().strip() | ||
|
|
||
|
|
||
| def _run_rl( | ||
| config_path: str, | ||
| output_dir: Path, | ||
| wandb_project: str, | ||
| wandb_name: str, | ||
| cwd: Path | None = None, | ||
| clean_output_dir: bool = False, | ||
| timeout: int = TIMEOUT, | ||
| ) -> ProcessResult: | ||
| """Run RL training as a subprocess.""" | ||
| cmd = [ | ||
| "uv", | ||
| "run", | ||
| "rl", | ||
| "@", | ||
| config_path, | ||
| "--wandb.project", | ||
| wandb_project, | ||
| "--wandb.name", | ||
| wandb_name, | ||
| "--output-dir", | ||
| output_dir.as_posix(), | ||
| ] | ||
| if clean_output_dir: | ||
| cmd.append("--clean-output-dir") | ||
|
|
||
| process = subprocess.Popen( | ||
| cmd, | ||
| cwd=cwd, | ||
| env={**os.environ, "PYTHONUNBUFFERED": "1"}, | ||
| ) | ||
| try: | ||
| process.wait(timeout=timeout) | ||
| except subprocess.TimeoutExpired: | ||
| process.terminate() | ||
| try: | ||
| process.wait(timeout=30) | ||
| except subprocess.TimeoutExpired: | ||
| process.kill() | ||
| process.wait() | ||
|
|
||
| return ProcessResult(process) | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def wandb_name(branch_name: str) -> str: | ||
| return f"test-ckpt-compat:{branch_name}" | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def repo_dir() -> Path: | ||
| """Root directory of the prime-rl repo.""" | ||
| return Path(__file__).resolve().parents[2] | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def rl_cwd(repo_dir: Path, output_dir: Path) -> Generator[Path, None, None]: | ||
| """Working directory for the checkpoint generation RL run. | ||
|
|
||
| If CKPT_FORMAT_REF is set and differs from HEAD, creates a git worktree | ||
| at that ref so the checkpoint is generated by the old version. | ||
| Otherwise, uses the current checkout. | ||
| """ | ||
| ref = CKPT_FORMAT_REF | ||
| if ref is None: | ||
| yield repo_dir | ||
| return | ||
|
|
||
| head = _get_head_commit(repo_dir) | ||
| resolved = _resolve_ref(repo_dir, ref) | ||
|
|
||
| if resolved == head: | ||
| yield repo_dir | ||
| return | ||
|
|
||
| worktree = Path(tempfile.mkdtemp(prefix="prime_rl_ckpt_compat_")) | ||
| subprocess.check_call( | ||
| ["git", "worktree", "add", "--detach", str(worktree), resolved], | ||
| cwd=repo_dir, | ||
| ) | ||
|
|
||
| # Copy config files into the worktree so the old version can find them | ||
| src_configs = repo_dir / "configs" / "ci" / "integration" / "ckpt_compat" | ||
| dst_configs = worktree / "configs" / "ci" / "integration" / "ckpt_compat" | ||
| dst_configs.parent.mkdir(parents=True, exist_ok=True) | ||
| subprocess.check_call(["cp", "-r", str(src_configs), str(dst_configs)]) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Config copy nests existing directoryMedium Severity When Reviewed by Cursor Bugbot for commit 55beff1. Configure here. |
||
|
|
||
| yield worktree | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Worktree missing submodule checkoutMedium Severity When Reviewed by Cursor Bugbot for commit acf930d. Configure here. |
||
|
|
||
| subprocess.check_call( | ||
| ["git", "worktree", "remove", str(worktree), "--force"], | ||
| cwd=repo_dir, | ||
| ) | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def rl_process( | ||
| rl_cwd: Path, | ||
| wandb_project: str, | ||
| wandb_name: str, | ||
| output_dir: Path, | ||
| ) -> ProcessResult: | ||
| """Run RL training with the pinned version to produce a checkpoint.""" | ||
| return _run_rl( | ||
| config_path="configs/ci/integration/ckpt_compat/start.toml", | ||
| output_dir=output_dir, | ||
| wandb_project=wandb_project, | ||
| wandb_name=wandb_name, | ||
| cwd=rl_cwd, | ||
| clean_output_dir=True, | ||
| timeout=TIMEOUT, | ||
| ) | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def rl_resume_process( | ||
| rl_process, | ||
| repo_dir: Path, | ||
| wandb_project: str, | ||
| wandb_name: str, | ||
| output_dir: Path, | ||
| ) -> ProcessResult: | ||
| """Resume RL training with the current version from the checkpoint.""" | ||
| if rl_process.returncode != 0: | ||
| pytest.skip("Checkpoint generation failed, skipping resume test") | ||
|
|
||
| return _run_rl( | ||
| config_path="configs/ci/integration/ckpt_compat/resume.toml", | ||
| output_dir=output_dir, | ||
| wandb_project=wandb_project, | ||
| wandb_name=f"{wandb_name}-resume", | ||
| cwd=repo_dir, | ||
| timeout=TIMEOUT, | ||
| ) | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def test_no_error(rl_process: ProcessResult, output_dir: Path): | ||
| """Tests that the RL process does not fail.""" | ||
| check_no_error(rl_process, output_dir) | ||
|
|
||
|
|
||
| def test_reward_goes_up(rl_process: ProcessResult, test_no_error, output_dir: Path): | ||
| """Tests that the reward goes up in the RL process.""" | ||
| with open(output_dir / "logs" / "orchestrator.log", "r") as f: | ||
| orchestrator_stdout = strip_escape_codes(f.read()).splitlines() | ||
| check_reward_goes_up(orchestrator_stdout) | ||
|
|
||
|
|
||
| def test_reward_in_range(rl_process: ProcessResult, test_no_error, output_dir: Path): | ||
| """Tests that the reward is in range in the RL process.""" | ||
| with open(output_dir / "logs" / "orchestrator.log", "r") as f: | ||
| orchestrator_stdout = strip_escape_codes(f.read()).splitlines() | ||
| check_reward_in_range(orchestrator_stdout, min_threshold=0.4) | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def test_no_error_resume(rl_resume_process: ProcessResult, output_dir: Path): | ||
| """Tests that the RL resume process does not fail.""" | ||
| check_no_error(rl_resume_process, output_dir) | ||
|
|
||
|
|
||
| def test_reward_in_range_resume(rl_resume_process: ProcessResult, test_no_error_resume, output_dir: Path): | ||
| """Tests that the reward is in range after resuming from the checkpoint.""" | ||
| with open(output_dir / "logs" / "orchestrator.log", "r") as f: | ||
| orchestrator_stdout = strip_escape_codes(f.read()).splitlines() | ||
| check_reward_in_range(orchestrator_stdout, min_threshold=0.4) | ||


There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Timeout skips child process cleanup
Medium Severity
On subprocess timeout,
_run_rlcallsterminate/killon the top-leveluvprocess only. Other integration tests usecleanup_process, which recursively signalstorchrun, inference, and other descendants started underuv run rl.Reviewed by Cursor Bugbot for commit 55beff1. Configure here.