diff --git a/.claude/skills/profile-run/SKILL.md b/.claude/skills/profile-run/SKILL.md
new file mode 100644
index 0000000000..44eb1c3d8c
--- /dev/null
+++ b/.claude/skills/profile-run/SKILL.md
@@ -0,0 +1,133 @@
+---
+name: profile-run
+description: Profile/debug a Modal training or eval run for this async-RL project — access Modal app logs, fetch and analyze rollout_*.pt dumps from the slime-checkpoints volume, read the rollout dashboard, and find the W&B run. Use whenever the user reports a bug, shares a log/dashboard link, or asks why a run behaved a certain way.
+---
+
+# Profile a training/eval run
+
+How to ground any bug report for this project in **actual observed data** instead of speculation.
+
+## Grounding rule (most important)
+
+When the user reports a bug, shares a log line, or asks "why did X happen":
+1. **Pull the real artifact first** — Modal logs, the `rollout_*.pt` dump, and/or W&B — before forming a root-cause claim.
+2. **Quote what you actually observe** (counts, token spans, decoded text), not what the code "should" do. This project has burned multiple wrong hypotheses (reasoning-content loss, openai-SDK dropping fields) that the data disproved.
+3. If a hypothesis contradicts the data, **say so and revise** — don't defend it.
+4. If the needed artifact isn't accessible (e.g. historical logs aged out, W&B entity unknown), **ask the user a specific question** ("what's the app id / W&B entity?") rather than guessing.
+5. Verify fixes against the real data path (decode with the real tokenizer / run the actual merge logic), not a synthetic stub.
+
+## Where to look first (efficiency)
+
+Cheapest → most expensive. Don't download a 70–120 MB dump to answer a question W&B already shows.
+1. **W&B** (no download): trends over steps — reward/solve (`rollout/raw_reward`), collapse, grad_norm, KL, step_time, cache-hit, eval. Start here for "is it learning / did it collapse / how's throughput."
+2. **Modal logs** (stream): live failures and per-sample summaries on the *current* step (tail-only, can't see old steps).
+3. **Rollout dump** (download once, reuse): per-sample / token-level forensics — loss-mask spans, turn structure, drift/reset detection, decoded conversations. Only when you need what W&B can't show.
+
+Reuse a single analysis venv across the session (don't reinstall torch/transformers each time); download each dump once.
+
+## Environment facts
+
+- `modal` is **not** on PATH — use **`uvx modal …`** with args **unquoted** (`uvx modal app list --env junlin-dev`, not a single quoted string). Almost everything needs **`--env junlin-dev`**. So-called "modal failures" are nearly always one of: missing `uvx`, missing `--env` (→ empty output or "No such file"), the `app logs` stream auto-disconnecting (~15 min, expected), or arg-quoting — not flaky auth. When correctly invoked it's reliable (verified). First `uvx` call may be slow (downloads modal once, then cached).
+- Runs launch from `multinode-training-guide/` via `EXPERIMENT_CONFIG= uv run --no-dev modal run -d slime/modal_train.py::train`. Configs live in `multinode-training-guide/slime/configs/`.
+- Run tag / W&B group / dump subdir all equal `_RUN_TAG` (default e.g. `qwen3.6-35b-a3b-swe-gym-lite-colocate-1n`). Strip ANSI from CLI output with `sed -E 's/\x1b\[[0-9;]*m//g'`.
+
+## 1. Find the run's app
+
+```bash
+cd /Users/junlin/Documents/Research/async-rl/multinode-training-guide
+uvx modal app list --env junlin-dev 2>&1 | sed -E 's/\x1b\[[0-9;]*m//g' | grep -iE "w_qwen|ephemeral|running"
+```
+The training run is the `ephemeral` app named after the config (e.g. `w_qwen3_6_…`). Note its `ap-…` id.
+
+## 2. Modal logs
+
+```bash
+uvx modal app logs ap-XXXX --env junlin-dev 2>err | sed -E 's/\x1b\[[0-9;]*m//g' > /tmp/logs.txt
+```
+Caveats (learned the hard way):
+- It **streams from ~now**, tail-only — you **cannot scroll back** to an old step's startup. For step-0 history you usually need the dump instead.
+- The stream **auto-disconnects ~every 15 min**. For a durable watch, use a Monitor with a self-reconnecting `while true; do uvx modal app logs …; sleep 3; done` loop and a tight `grep` filter (e.g. `adapter_session_empty|aborted:|wall_clock_timeout|\[mini-swe\].*tail:|Traceback`). Don't filter to only the happy path — include failure signatures or a crash looks identical to silence.
+
+Useful greps: `[async_rl] … reward=` (per-sample summaries), `[mini-swe] … exit=N … tail:` (in-sandbox agent failures, nonzero exit only), `[harbor]` (env/grading), `[trajectory] merge prompt base changed` (trajectory drift), `agent budget exhausted before step` (boot/budget).
+
+## 3. Rollout dumps — the main triage tool
+
+Dumps are on the `slime-checkpoints` volume, written per step (config `save_debug_rollout_data`). Train = `rollout_.pt`, eval = `rollout_eval_.pt`. Same `_RUN_TAG` relaunch **overwrites** them.
+
+```bash
+uvx modal volume ls slime-checkpoints /swe_rollout_dumps/ --env junlin-dev # list + mtimes
+uvx modal volume get slime-checkpoints /swe_rollout_dumps//rollout_0.pt /tmp/r0.pt --env junlin-dev --force
+```
+Load (plain dicts — no slime import needed):
+```python
+import torch
+dump = torch.load("/tmp/r0.pt", map_location="cpu", weights_only=False) # {"rollout_id", "samples":[...]}
+s = dump["samples"][0] # dict: tokens, loss_mask, response_length, response (decoded str),
+                        # prompt, rollout_log_probs, metadata{instance_id,is_solved,abort_reason,...},
+                        # status, reward, weight_versions, trace
+```
+Key invariants: `tokens = prompt_ids + response_ids`; `loss_mask`/`rollout_log_probs` align with the **response** portion (last `response_length` tokens). `mask=1` = trained (assistant output), `mask=0` = context (tool results / re-rendered history).
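The invariants above can be turned into a quick sanity check before any deeper forensics. The following is a minimal sketch, offered as an illustration rather than as part of the patch: `split_sample` is a hypothetical helper, and it assumes `tokens`, `loss_mask`, and `rollout_log_probs` are list-like and `response_length` is an integer, as the field list above suggests.

```python
def split_sample(sample: dict) -> dict:
    """Split one dumped sample into prompt/response ids and check alignment.

    Hypothetical helper: field names follow the dump layout described above;
    the list-like types are an assumption.
    """
    tokens = list(sample["tokens"])
    resp_len = int(sample["response_length"])
    mask = list(sample["loss_mask"])

    if not 0 <= resp_len <= len(tokens):
        raise ValueError(f"response_length={resp_len} out of range for {len(tokens)} tokens")
    if len(mask) != resp_len:
        raise ValueError(f"loss_mask has {len(mask)} entries, expected {resp_len}")

    logprobs = sample.get("rollout_log_probs")
    if logprobs is not None and len(logprobs) != resp_len:
        raise ValueError(f"rollout_log_probs has {len(logprobs)} entries, expected {resp_len}")

    # Use an explicit split index: tokens[-0:] would return the whole list.
    split = len(tokens) - resp_len
    trained = sum(1 for m in mask if m)
    return {
        "prompt_ids": tokens[:split],
        "response_ids": tokens[split:],
        "trained_tokens": trained,
        "trained_fraction": trained / resp_len if resp_len else 0.0,
    }
```

Calling `split_sample(s)` on a loaded sample either returns the split view or raises with the specific misalignment, which is usually enough to tell a corrupted dump from a real masking issue.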
+
+### Analysis recipes (verified)
+```python
+# trained fraction & turn count
+trained = sum(s["loss_mask"]); frac = trained / s["response_length"]
+turns = s["response"].count("<|im_start|>assistant") + 1 # +1 for the head turn
+
+# trajectory-merge RESET detector: base task prompt is ~1.0-1.5k tokens; a much larger
+# prompt means early turns were dropped into the UNTRAINED prompt (line-107 reset).
+prompt_len = len(s["tokens"]) - s["response_length"]
+is_reset = prompt_len > 4000
+
+# GRPO signal check: groups with zero reward variance give no advantage
+# decode token spans / mask runs with the real tokenizer (see venv below)
+```
+
+### Throwaway analysis venv (tokenizer/torch without polluting anything)
+```bash
+uv venv --python 3.11 /tmp/dbg && \
+uv pip install --python /tmp/dbg/bin/python -q transformers jinja2 wandb && \
+uv pip install --python /tmp/dbg/bin/python -q torch --index-url https://download.pytorch.org/whl/cpu
+# tokenizer-only (Qwen3.6 is public; don't download weights):
+/tmp/dbg/bin/python -c "from huggingface_hub import snapshot_download as d; from transformers import AutoTokenizer; \
+AutoTokenizer.from_pretrained(d('Qwen/Qwen3.6-35B-A3B', allow_patterns=['tokenizer*','*.json']))"
+```
+`apply_chat_template(..., tokenize=False)` then `tok.encode(...)` to get clean id lists (tokenize=True returns an Encoding in transformers 5.x). The chat template + `reasoning_parser=qwen3` / `tool_call_parser=qwen3_coder` are the source of multi-turn render quirks — see [[trajectory-drift-formaterror]].
+
+## 4. Rollout dashboard (web)
+
+`async_rl_research/dashboard/` serves the same volume. URL pattern:
+`https://modal-labs-junlin-dev--swe-rollout-dashboard-dashboard.modal.run/#/.pt/`
+`convert.py` reconstructs turns by splitting the decoded `response` on `<|im_start|>` and parsing ``/``/``. The first (head) turn renders without an opening `` because the prompt prefilled it — that's expected, not a bug.
+
+## 5. W&B (verified — use this FIRST for trends; no big download)
+
+- **Entity `junlinwang`, project `Modal`** (= `WANDB_PROJECT`), run name = group = `_RUN_TAG` (suffix disabled, so one run per relaunch; relaunches make a *new* run with the same name).
+- Auth: `wandb` reads `~/.netrc` automatically — no key needed in code (it's already logged in). On a fresh machine: `wandb login` or set `WANDB_API_KEY`. `wandb` isn't installed by default; `uv pip install wandb` into the analysis venv.
+
+```python
+import wandb
+api = wandb.Api() # loads creds from ~/.netrc
+print(api.default_entity) # -> junlinwang
+runs = list(api.runs("junlinwang/Modal", filters={"group": ""}))
+runs.sort(key=lambda r: r.created_at) # latest = newest relaunch
+run = runs[-1]
+print(run.id, run.state) # e.g. crashed/running/finished
+val = run.summary.get("rollout/raw_reward") # last-logged scalar
+h = run.history(keys=["rollout/step","rollout/raw_reward","train/grad_norm","rollout/kl"],
+                samples=5000, pandas=True) # time series (use history, NOT scan_history)
+```
+
+**Metric semantics that bite (verified):**
+- **`rollout/raw_reward` = the actual mean reward / solve signal** (step 0 ≈ 0.48 matched 122/252 solved in the dump). **`rollout/rewards` is the GRPO-centered advantage ≈ 0 by construction — do NOT use it to judge solve rate.**
+- `train/grad_norm`, `train/kl_loss`, `train/ppo_kl`, `perf/step_time`, `sgl_engine/sglang_cache_hit_rate`, `eval//...`. ~193 keys; engine gauges are per-DP-rank means (×8) — see [[swe-rl-perf-profile]].
+- `run.history(...)` returns a sampled DataFrame; `_step` is the W&B logging step, use the `rollout/step`/`train/step` columns for the real step. `scan_history(keys=...)` gave all-None here — prefer `history(keys=..., samples=N, pandas=True)`.
+
+W&B alone shows trends without any download: e.g. `rollout/raw_reward` 0.48 → ~0.0 after step 0 is the **policy collapse**, visible instantly.
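The collapse pattern described above (a healthy `rollout/raw_reward` at step 0, then roughly zero) could be flagged mechanically once the series is in hand. This is a minimal sketch, not part of the patch: `first_collapse_step` and both thresholds are hypothetical, and it assumes the rewards were already extracted into a plain list of floats ordered by step.

```python
def first_collapse_step(raw_rewards, min_baseline=0.1, drop_ratio=0.1):
    """Return the first step whose reward fell below drop_ratio * step-0 reward.

    Hypothetical heuristic: returns None when the series is empty, when the
    step-0 baseline is too low to judge a drop, or when no step collapses.
    """
    if not raw_rewards:
        return None
    baseline = raw_rewards[0]
    if baseline is None or baseline < min_baseline:
        return None  # nothing meaningful to collapse from
    for step, reward in enumerate(raw_rewards[1:], start=1):
        # NaN comparisons are False, so missing values are skipped silently.
        if reward is not None and reward < drop_ratio * baseline:
            return step
    return None
```

The relative threshold is a deliberate choice in this sketch: an absolute cutoff would misfire on tasks whose step-0 solve rate is already low.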
+
+## Quick interpretation map
+
+- `adapter_session_empty` (0 turns) → in-sandbox agent never completed a turn; check `[mini-swe] … tail:` and `agent budget exhausted before step` (see [[adapter-session-empty-budget-bug]]).
+- `exit=-2` → `EXIT_BUDGET_EXCEEDED` (agent ran full `AGENT_TIME_BUDGET_SEC`).
+- Large prompt / dropped turns / `merge prompt base changed` → trajectory drift ([[trajectory-drift-formaterror]]).
+- `mean turns/sample` collapsing across steps (e.g. 34→1 after one update) → policy collapse; suspect rollout-logprob/EAGLE mismatch, recompute train-side logprobs vs `rollout_log_probs`.
diff --git a/.gitignore b/.gitignore
index 91d9db40fd..e9a268d45d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -193,4 +193,3 @@ glm/
 _examples_synced/
 .env
 .DS_Store
-scripts_agenticRL/
diff --git a/async_rl_research/README.md b/async_rl_research/README.md
new file mode 100644
index 0000000000..b5b0790a78
--- /dev/null
+++ b/async_rl_research/README.md
@@ -0,0 +1,70 @@
+# async_rl_research
+
+Agentic-RL rollout package for slime. It runs an in-sandbox coding agent
+(default: mini-swe-agent) against tasks on Modal, records exact SGLang tokens
+via an HTTP adapter, and grades the result into a reward. Task families are
+pluggable **envs**: SWE-Gym (git-diff grading in a clean sandbox) and harbor
+datasets like USACO (in-place `test.sh` verification, multi-step aware).
+
+| Module | Role |
+| --- | --- |
+| `generate.py` | Per-sample rollout entrypoint (`--custom-generate-function-path async_rl_research.generate.generate`); orchestrates `runtime × env` |
+| `agent/base.py` | `AgentRuntime` contract + shared launch/provision machinery + runtime registry |
+| `agent/mini_swe_agent.py` | Default runtime (`mini-swe`): adapter choice, venv provisioning, headless runner |
+| `env/base.py` | `RolloutEnv` contract (row schema, sandbox lifecycle, grading) + env registry; rows pick their env via `metadata.task_type` |
+| `env/swe_gym.py` | SWE-Gym env: prebuilt image boot / pre_commands / git diff / clean-sandbox eval |
+| `env/harbor.py` | Harbor env: Dockerfile boot, step loop, in-place verify (+ oracle-check CLI) |
+| `env/convert2slime/` | Dataset converters, paired with their env by filename (see `data/README.md`) |
+| `evalset.py` | Eval-set builder: spec YAML → subsampled per-subset jsonl + manifest + ready `--eval-config` (see `data/README.md`) |
+| `modal_sandbox.py` | Modal backend (boot concurrency, create retry; registry refs + Dockerfile builds) |
+| `dashboard/` | Modal web app (Bun/TS) for browsing the rollout debug dumps as agent conversations (see `dashboard/README.md`) |
+| `profiles/PERF.md` | Measured rollout-time attribution, ranked fixes, and a step-by-step profiling guide |
+| `profiles/profiling.py` | In-rollout instrumentation: env phase timers + adapter middleware (per-session turn count / gen time) → `sample.metadata["timing"]` → dumps |
+| `profiles/profile.py` | Offline analyzer: W&B run + rollout dump → one attribution row in `profiles/runs.jsonl` + regenerated `profiles/ATTRIBUTION.md` |
+
+## Setup
+
+Harbor datasets need two things at rollout time: `ASYNC_RL_TASK_ROOT` pointing at
+the converter's out dir (on the slime-data volume), and ideally an oracle pass
+first (`python -m async_rl_research.env.harbor --limit 3`, expect
+reward=1.0) -- see `data/README.md` for the full flow.
+
+The rollout boot honors these env vars:
+
+| Env var | Purpose |
+| --- | --- |
+| `MODAL_REGISTRY_SECRET` | Modal secret for authenticated Docker Hub pulls (`dockerhub-creds`) |
+| `MODAL_ENVIRONMENT` | Modal environment the images are cached in |
+| `SLIME_AGENT_SANDBOX_ADD_PYTHON` | Add python to the image (must match rollout) |
+
+## Eval
+
+Eval reuses the exact same `generate()` → `runtime × env` stack as training:
+slime's eval path (`slime/rollout/sglang_rollout.py::eval_rollout`) iterates
+`--eval-config` datasets and calls the custom generate function with
+`evaluation=True` per sample. Mean reward per dataset lands in W&B as
+`eval/{name}` (plus `-truncated_ratio`, response-len stats).
+
+Three pieces, in order:
+
+1. **Build an eval set** (subsampled, versioned, pinned by manifest):
+   `python -m async_rl_research.evalset spec.yaml --out-dir /data/evalsets/v0`
+   — see `data/README.md`. Oracle-check harbor subsets before burning GPU time.
+2. **Wire it into the training config** as an inline `eval_config` dict (the
+   launcher materializes it to a temp YAML → `--eval-config`); set
+   `eval_interval`. train_async.py evals every `eval_interval` rollouts
+   (first at rollout `eval_interval` — no step-0 baseline in async mode; get
+   the base-model baseline from an eval-only run instead). Eval blocks the
+   train loop and shares the sglang engines — size subsets accordingly.
+3. **Eval-only runs**: `num_rollout = 0` with `eval_interval` set routes
+   through `train.py`'s stock eval-only branch — one eval pass, then exit
+   (set `load` to a saved checkpoint to eval a trained model; use
+   `async_mode = False` in the experiment config so train.py is the
+   entrypoint).
+
+Per-dataset eval sampling overrides (`temperature`, `top_p`, `top_k`) flow
+through `generate.py::_sampling_params` into the adapter's session defaults,
+along with `max_new_tokens` (the per-turn generation cap). slime sets it to
+`rollout_max_response_len` for train and `eval_max_response_len` for eval, so
+that value bounds a single model turn; the adapter then clamps it further to the
+remaining `rollout_max_context_len` budget.
diff --git a/async_rl_research/agent/adapters/__init__.py b/async_rl_research/agent/adapters/__init__.py
new file mode 100644
index 0000000000..597051eb20
--- /dev/null
+++ b/async_rl_research/agent/adapters/__init__.py
@@ -0,0 +1,6 @@
+"""Repo-owned slime adapter variants handling model-specific rendering quirks
+without patching slime core. See ``qwen.py``."""
+
+from .qwen import QwenOpenAIAdapter
+
+__all__ = ["QwenOpenAIAdapter"]
diff --git a/async_rl_research/agent/adapters/qwen.py b/async_rl_research/agent/adapters/qwen.py
new file mode 100644
index 0000000000..08269037eb
--- /dev/null
+++ b/async_rl_research/agent/adapters/qwen.py
@@ -0,0 +1,167 @@
+"""Qwen-family OpenAI adapter: render tool-call arguments as a dict.
+
+slime's ``OpenAIAdapter`` stringifies ``function.arguments`` before
+``apply_chat_template``, but the Qwen3-Coder family (Qwen3.6-35B-A3B) template
+iterates ``arguments | items`` and needs a mapping -- a string raises on turn
+2+, capping every episode at one turn. This adapter renders ``arguments`` as a
+dict on the inbound path only (the outbound OpenAI response stays string-form).
+
+It also splices each turn's raw ``output_ids`` into the next prompt rather than
+re-rendering the parsed assistant message (see ``_build_prompt``): the
+qwen3_coder parser strips trailing whitespace from tool-call arguments, so a
+re-render is not token-identical to what the model generated, which makes
+``merge_turns`` log "prefix drift" and mask whole turns out of training.
+
+slime renders through free functions with no method seam, so we register our
+own ``/v1/chat/completions`` handler. ``_run_turn`` / ``_handle_chat_completions``
+below are faithful mirrors of slime's -- keep them in sync.
+""" + +from __future__ import annotations + +import asyncio +import json + +from aiohttp import web + +from slime.agent.adapters import openai as _slime_openai +from slime.agent.adapters.common import ADAPTER_KEY, TOKENIZER_KEY, BaseAdapter, render_token_ids +from slime.agent.adapters.openai import OpenAIAdapter + + +def _dictify_tool_arguments(messages: list[dict]) -> None: + """In place: parse each tool call's JSON-string ``function.arguments`` into a + dict so ``apply_chat_template`` can iterate it. Idempotent; non-JSON left as-is.""" + for msg in messages: + for call in msg.get("tool_calls") or []: + fn = call.get("function") + if not isinstance(fn, dict): + continue + args = fn.get("arguments") + if isinstance(args, str): + s = args.strip() + if not s: + fn["arguments"] = {} + continue + try: + fn["arguments"] = json.loads(s) + except (json.JSONDecodeError, ValueError): + pass + + +def _template_ids(tok, messages: list[dict], *, add_generation_prompt: bool) -> list[int]: + """``apply_chat_template`` -> flat token-id list (tolerating the 1-element + batch that some transformers versions return for ``tokenize=True``).""" + enc = tok.apply_chat_template(messages, tools=None, tokenize=True, add_generation_prompt=add_generation_prompt) + ids = enc["input_ids"] if hasattr(enc, "__getitem__") and "input_ids" in enc else enc + ids = list(ids) + if ids and isinstance(ids[0], list): # transformers>=5 may return [[...ids...]] + ids = ids[0] + return ids + + +def _tool_continuation_ids(new_messages: list[dict], tok) -> list[int]: + """Token delta to append after the previous turn's raw ``output_ids``: the + tool-result/user message(s) mini-swe added this turn, plus the next + generation prompt. + + The model's raw ``output_ids`` stop at ``<|im_end|>``; the chat template + emits ``<|im_end|>\\n`` for a finished assistant turn, so the delta must + restore that single inter-turn newline. We anchor on a sentinel user message + and slice from its trailing newline onward. 
+ """ + continuation = _slime_openai._translate_chat_messages( + [m for m in new_messages if isinstance(m, dict) and m.get("role") != "assistant"] + ) + if not continuation: + return [] + sentinel = [{"role": "user", "content": ""}] + base = _template_ids(tok, sentinel, add_generation_prompt=False) + full = _template_ids(tok, sentinel + continuation, add_generation_prompt=True) + if len(base) < 1 or full[: len(base)] != base: + return [] # unexpected render shape -> caller falls back to a full re-render + return full[len(base) - 1 :] # start at the sentinel's trailing "\n" (the inter-turn separator) + + +def _build_prompt(target, messages: list[dict], tools_schema: list[dict] | None, kind: str, tok) -> list[int]: + """Build the next prompt. + + On the ``append`` path, splice the previous turn's **raw** ``output_ids`` + into the prompt instead of re-rendering the parsed assistant message. The + qwen3_coder parser strips trailing whitespace from tool-call arguments, so a + re-render is not token-identical to what the model generated; that mismatch + makes ``merge_turns`` log "prefix drift" and mask whole turns out of training + (and makes the rollout subtly off-policy). Splicing the raw tokens keeps the + prompt == the training target by construction. The parsed message is still + returned to mini-swe for tool execution -- only prompt reconstruction here + changes. ``new``/``wipe`` (and the first turn) fall back to a full re-render. 
+ """ + new_messages = messages[target.seen_msgs :] if kind == "append" else [] + (_slime_openai._extend_chat_messages if kind == "append" else _slime_openai._replace_chat_messages)( + target, messages, tools_schema + ) + if kind == "append" and target.turns: + continuation = _tool_continuation_ids(new_messages, tok) + if continuation: + last = target.turns[-1] + return list(last.prompt_ids) + list(last.output_ids) + continuation + _dictify_tool_arguments(target.chat_messages) + return render_token_ids(target, tok) + + +async def _run_turn(request: web.Request, body: dict, messages: list[dict]): + """Mirror of ``openai._run_turn`` calling the dict-args ``_build_prompt``.""" + sid = _slime_openai._request_session_id(request, body) + adapter = request.app[ADAPTER_KEY] + if sid in adapter.closed: + raise web.HTTPServiceUnavailable(text="session closed") + app = request.app + s = adapter.store.setdefault(sid, _slime_openai.Session()) + task = asyncio.current_task() + adapter.inflight.setdefault(sid, set()).add(task) + try: + async with s.lock: + target = s.main + tools_schema = _slime_openai._normalize_tools(body.get("tools")) + kind = _slime_openai._select_kind(s, messages) + prompt_ids = _build_prompt(target, messages, tools_schema, kind, app[TOKENIZER_KEY]) + turn = await _slime_openai._generate(prompt_ids, s, body, app, session_id=sid) + parsed = _slime_openai._parse_turn(target, turn, app) + target.turns.append(turn) + return turn, parsed, len(prompt_ids), len(turn.output_ids) + finally: + adapter.inflight.get(sid, set()).discard(task) + + +async def _handle_chat_completions(request: web.Request) -> web.StreamResponse: + """Mirror of ``openai._handle_chat_completions`` via the dict-args ``_run_turn``.""" + body = await request.json() + messages = body.get("messages") or [] + if not isinstance(messages, list): + raise web.HTTPBadRequest(text="messages must be a list") + turn, parsed, in_tok, out_tok = await _run_turn(request, body, messages) + if 
body.get("stream"): + return await _slime_openai._stream_chat_completion(request, body, parsed, turn.finish_reason, in_tok, out_tok) + return web.json_response( + _slime_openai._chat_completion_response(body, parsed, turn.finish_reason, in_tok, out_tok) + ) + + +class QwenOpenAIAdapter(OpenAIAdapter): + """``OpenAIAdapter`` rendering tool-call arguments as a dict (see module + docstring); only the ``/v1/chat/completions`` handler differs.""" + + def __init__(self, *, tokenizer, sglang_url, tool_parser=None, reasoning_parser=None) -> None: + # Skip OpenAIAdapter.__init__: it binds slime's string-args handler and + # aiohttp can't re-bind a route. Do BaseAdapter setup, then our routes. + BaseAdapter.__init__( + self, + tokenizer=tokenizer, + sglang_url=sglang_url, + tool_parser=tool_parser, + reasoning_parser=reasoning_parser, + ) + self.app.router.add_post("/v1/chat/completions", _handle_chat_completions) + self.app.router.add_post("/v1/responses", _slime_openai._handle_responses) # mini-swe unused + self.app.router.add_get("/healthz", _slime_openai._ok) + self.app.router.add_get("/v1/models", _slime_openai._ok) diff --git a/async_rl_research/agent/base.py b/async_rl_research/agent/base.py new file mode 100644 index 0000000000..d1a7a8edfc --- /dev/null +++ b/async_rl_research/agent/base.py @@ -0,0 +1,224 @@ +"""AgentRuntime: the contract between ``generate.py`` and one agent framework. + +A *runtime* packages everything specific to one in-sandbox agent framework +(mini-swe-agent, opencode, ...): its slime adapter, provisioning, and launch. +Subclass, declare the class attributes, implement ``run_agent`` by composing +``_ensure_provisioned`` + ``_detached_run``, and register in ``RUNTIMES`` below. + +On-policy rule: the adapter applies the request body OVER its per-session +sampling defaults, so a runtime must strip the agent's own temperature/top_p +(a client-sent temperature silently turns rollouts greedy). 
Runtimes are +instantiated once per worker and must be stateless across samples. +""" + +from __future__ import annotations + +import asyncio +import importlib +import logging +import shlex +import time +from abc import ABC, abstractmethod +from typing import ClassVar, NamedTuple + +from slime.agent.sandbox import Sandbox + +logger = logging.getLogger(__name__) + + +# run_agent return value when the wallclock budget elapsed before the done +# marker appeared (otherwise run_agent returns the agent's exit code). +EXIT_BUDGET_EXCEEDED = -2 + + +class AgentRunResult(NamedTuple): + """Outcome of one agent leg: the process ``exit_code`` (or + ``EXIT_BUDGET_EXCEEDED``) plus, on a nonzero exit, the last few KB of the + agent's stdout/stderr (``tail``; empty on a clean exit). + + Persisted into sample metadata so a zero-turn ``adapter_session_empty`` + self-explains in the dump (e.g. exit=137 -> OOM-killed) instead of relying + on tail-only Modal logs that age out once the run finishes. + """ + + exit_code: int + tail: str = "" + + +class AgentRuntime(ABC): + """One agent framework's integration: wire adapter + provision + launch. + + Required attributes (validated at class-definition time): ``name`` (registry + key / log prefix) and ``adapter_cls`` (slime adapter for the wire protocol). + Optional: ``model_name`` (advertised to the agent), ``scratch_prefix`` + (launch scratch prefix under workdir), ``diff_exclude`` (extra scratch to + drop from the diff; launch scratch is excluded automatically). + """ + + name: ClassVar[str] + adapter_cls: ClassVar[type] + model_name: ClassVar[str] = "slime-actor" + scratch_prefix: ClassVar[str] = ".agent" + diff_exclude: ClassVar[tuple[str, ...]] = () + + def __init_subclass__(cls, **kwargs) -> None: + # Fail at import time, not mid-rollout, on missing declarations. 
+ super().__init_subclass__(**kwargs) + missing = [a for a in ("name", "adapter_cls") if getattr(cls, a, None) is None] + if missing: + raise TypeError(f"{cls.__name__} must define class attribute(s) {missing!r} (see AgentRuntime)") + + @property + def diff_exclude_all(self) -> tuple[str, ...]: + """Everything to drop from the captured diff: launch scratch + extras.""" + return (*self._launch_scratch_files(), *self.diff_exclude) + + @abstractmethod + async def run_agent( + self, + sb: Sandbox, + *, + md: dict, + session_id: str, + adapter_url: str, + time_budget_sec: int, + ) -> AgentRunResult: + """Provision + launch the agent in the booted, task-prepped sandbox. + + The agent must call ``adapter_url`` with ``session_id`` as its bearer so + the adapter groups its turns. May be called multiple times per sample in + the SAME sandbox (multi-step). Returns an ``AgentRunResult`` (exit code, + or ``EXIT_BUDGET_EXCEEDED``, plus a failure-only log tail). + """ + + # ------------------------------------------------------------------ + # Shared machinery + # ------------------------------------------------------------------ + def _launch_scratch_files(self) -> tuple[str, str, str]: + """(launcher script, done marker, log) names under workdir.""" + p = self.scratch_prefix + return (f"{p}_run.sh", f"{p}_done", f"{p}_log") + + async def _detached_run( + self, + sb: Sandbox, + *, + workdir: str, + command: str, + env: dict[str, str] | None = None, + time_budget_sec: int, + poll_interval_sec: float = 5.0, + log_tag: str = "", + ) -> AgentRunResult: + """Launch ``command`` detached in ``workdir`` (``setsid ... &``) and poll + a done-marker, so a foreground exec stream reset can't re-launch the + whole agent. ``command`` is spliced in verbatim (caller quotes paths). 
+ Returns an ``AgentRunResult``: the exit code (or ``EXIT_BUDGET_EXCEEDED`` + if ``time_budget_sec`` elapses first) plus, on a nonzero exit, the last + 4 KB of the agent log (surfaced host-side AND returned for the dump). + """ + q = shlex.quote + launch, done, log = self._launch_scratch_files() + exports = "".join(f"export {k}={q(str(v))}\n" for k, v in (env or {}).items()) + launcher_body = ( + "#!/bin/bash\n" f"cd {q(workdir)}\n" f"{exports}" f"{command} > {q(log)} 2>&1\n" f"echo $? > {q(done)}\n" + ) + await sb.write_file(f"{workdir}/{launch}", launcher_body) + # rm the done marker BEFORE launching: a stale marker from a prior leg + # would satisfy the first poll while the new agent still runs. + await sb.exec(f"cd {q(workdir)} && rm -f {q(done)} && chmod +x {q(launch)}", check=False, timeout=30) + await sb.exec( + f"cd {q(workdir)} && setsid bash {q(launch)} < /dev/null > /dev/null 2>&1 &", + check=False, + timeout=30, + ) + + done_path = f"{workdir}/{done}" + deadline = time.time() + time_budget_sec + exit_code = EXIT_BUDGET_EXCEEDED + while time.time() < deadline: + await asyncio.sleep(poll_interval_sec) + ec, out, _ = await sb.exec(f"test -f {q(done_path)} && cat {q(done_path)}", check=False, timeout=15) + if ec == 0: + try: + exit_code = int((out or "").strip() or "-1") + except ValueError: + exit_code = -1 + break + tail = "" + if exit_code != 0: + _, raw_tail, _ = await sb.exec( + f"tail -c 4000 {q(f'{workdir}/{log}')} 2>/dev/null", check=False, timeout=15 + ) + tail = (raw_tail or "").strip() + if tail: + logger.warning("[%s] %s exit=%s %s tail:\n%s", self.name, log_tag, exit_code, log, tail) + logger.info("[%s] %s exit=%s elapsed<=%ds", self.name, log_tag, exit_code, time_budget_sec) + return AgentRunResult(exit_code, tail) + + async def _ensure_provisioned( + self, + sb: Sandbox, + *, + spec: str, + marker_path: str, + setup_script: str, + check_cmd: str | None = None, + timeout: int = 900, + ) -> bool: + """Idempotent toolchain install keyed on a 
spec marker. + + Skips when ``marker_path`` already holds exactly ``spec`` (and + ``check_cmd`` exits 0), so a pre-baked image short-circuits while a + changed pin rebuilds. The marker is written LAST so a half-finished + install is never mistaken for complete. Returns True if it ran. + """ + q = shlex.quote + probe = f"cat {q(marker_path)} 2>/dev/null" + if check_cmd: + probe = f"({check_cmd}) >/dev/null 2>&1 && " + probe + _, out, _ = await sb.exec(probe, check=False, timeout=60) + if (out or "").strip() == spec: + return False + + logger.info("[%s] provisioning spec=%s in sandbox %s", self.name, spec, sb.sandbox_id[:8]) + await sb.exec(setup_script, check=True, timeout=timeout) + if check_cmd: + await sb.exec(check_cmd, check=True, timeout=60) + await sb.exec(f"printf '%s' {q(spec)} > {q(marker_path)}", check=True, timeout=30) + return True + + +# --------------------------------------------------------------------------- +# Registry + loader +# --------------------------------------------------------------------------- +DEFAULT_RUNTIME = "mini-swe" + +# Short name -> "module:Class" (strings so importing base.py imports no runtime). +RUNTIMES: dict[str, str] = { + "mini-swe": "async_rl_research.agent.mini_swe_agent:MiniSweAgentRuntime", +} + + +def load_runtime(spec: str | None = None) -> AgentRuntime: + """Resolve ``spec`` to an AgentRuntime instance, validating eagerly. + + Accepts a registry short name ("mini-swe"), "pkg.module:ClassName", or a + module path exposing ``RUNTIME`` (legacy driver form). 
+ """ + spec = spec or DEFAULT_RUNTIME + target = RUNTIMES.get(spec, spec) + if ":" in target: + module_path, _, attr = target.partition(":") + else: + module_path, attr = target, "RUNTIME" + module = importlib.import_module(module_path) + cls = getattr(module, attr, None) + if cls is None: + raise ValueError( + f"agent runtime {spec!r}: module {module_path!r} does not expose {attr!r}; " + f"known short names: {sorted(RUNTIMES)}" + ) + if not (isinstance(cls, type) and issubclass(cls, AgentRuntime)): + raise TypeError(f"agent runtime {spec!r} resolved to {cls!r}, which is not an AgentRuntime subclass") + return cls() diff --git a/async_rl_research/agent/config/universal.yaml b/async_rl_research/agent/config/universal.yaml new file mode 100644 index 0000000000..a06f734e24 --- /dev/null +++ b/async_rl_research/agent/config/universal.yaml @@ -0,0 +1,151 @@ +# The universal mini-swe-agent config: ONE prompt scaffold for ALL task +# families. Owned by this repo (not the pip-installed package) so the prompt +# distribution is version-controlled here and changing it never requires a +# package bump. +# +# Based on mini-swe-agent v2.3.1's default.yaml, with the action rules and +# format_error_template rewritten for NATIVE tool-calls (the runner uses +# LitellmModel, and the slime adapter parses tool calls via sglang's parser; +# default.yaml's ```mswea_bash_command``` text format would be rejected every +# turn). +# +# Scope rule: this file carries only scaffold-universal content (response +# format, subshell semantics, how to finish). The task's DELIVERABLE (patch +# vs. artifacts vs. stdout) belongs in the instruction text the env writes to +# PROBLEM_STATEMENT.md, which is rendered here as {{task}} -- harbor +# instruction.md files already follow this; swe_gym appends its suffix in +# env/swe_gym.py. +# +# step_limit / cost_limit / sampling knobs are overridden host-side by the +# runner regardless of what this file says. 
+agent: + system_template: | + You are a helpful assistant that can interact with a computer. + instance_template: | + Please solve the following task: + + + {{task}} + + + You can execute bash commands and edit files to accomplish it. The task + description above defines what you must produce and any task-specific + submission steps; follow it exactly. + + ## Command Execution Rules + + You are operating in an environment where + + 1. You issue at least one command + 2. The system executes the command(s) in a subshell + 3. You see the result(s) + 4. You write your next command(s) + + Each response should include: + + 1. **Reasoning text** where you explain your analysis and plan + 2. At least one call to the `bash` tool with the shell command to run + + **CRITICAL REQUIREMENTS:** + + - Your response SHOULD include reasoning text explaining what you're doing + - Your response MUST include AT LEAST ONE call to the `bash` tool + - Directory or environment variable changes are not persistent. Every action is executed in a new subshell. + - However, you can prefix any action with `MY_ENV_VAR=MY_VALUE cd /path/to/working/dir && ...` or write/load environment variables from files + + Example of a CORRECT response: a short paragraph of reasoning ("I'll look at + the repo layout first."), followed by a call to the `bash` tool whose command + is `ls -la`. + + ## Environment Details + + - You have a full Linux shell environment + - Always use non-interactive flags (-y, -f) for commands + - Avoid interactive tools like vi, nano, or any that require user input + - You can create new tools or scripts to help you with the task; if a tool isn't available, you can install it + + + {{system}} {{release}} {{version}} {{machine}} + + + ## Useful shell idioms + + The following are commands you pass to the `bash` tool (its `command` + argument) -- they are NOT a response format. Adapt as needed. 
+ + - Create a file with a heredoc: + cat <<'EOF' > newfile.py + import numpy as np + print("hello") + EOF + - Edit in place with sed: `sed -i 's/old/new/g' filename.py` (drop the `g` + for first-occurrence-only; prefix a range like `1,10` to scope to lines). + - View specific lines with numbers: `nl -ba filename.py | sed -n '10,20p'`. + + ## Finishing + + Once you have verified your work and completed any submission steps the + task description requires, finish by issuing the following command: + `echo COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT`. + Do not combine it with any other command. After this command, + you cannot continue working on this task. + step_limit: 0 + cost_limit: 0. +environment: + timeout: 60 + env: + PAGER: cat + MANPAGER: cat + LESS: -R + PIP_PROGRESS_BAR: 'off' + TQDM_DISABLE: '1' +model: + observation_template: | + {% if output.exception_info -%} + {{output.exception_info}} + {% endif -%} + {{output.returncode}} + {% if output.output | length < 10000 -%} + + {{ output.output -}} + + {%- else -%} + + The output of your last command was too long. + Please try a different command that produces less output. + If you're looking at a file you can try use head, tail or sed to view a smaller number of lines selectively. + If you're using grep or find and it produced too much output, you can use a more selective search pattern. + If you really need to see something from the full command's output, you can redirect output to a file and then search in that file. + + {%- set elided_chars = output.output | length - 10000 -%} + + {{ output.output[:5000] }} + + + {{ elided_chars }} characters elided + + + {{ output.output[-5000:] }} + + {%- endif -%} + format_error_template: |- + + Tool call error: + + + {{error}} + + + Here is general guidance on how to submit correct toolcalls: + + Every response needs to use the 'bash' tool at least once to execute commands. 
+ + Call the bash tool with your command as the argument: + - Tool: bash + - Arguments: {"command": "your_command_here"} + + If you want to end the task, please issue the following command: `echo COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT` + without any other command. + + model_kwargs: + drop_params: true diff --git a/async_rl_research/agent/mini_swe_agent.py b/async_rl_research/agent/mini_swe_agent.py index 60f9b43f86..5d1e7054db 100644 --- a/async_rl_research/agent/mini_swe_agent.py +++ b/async_rl_research/agent/mini_swe_agent.py @@ -1,110 +1,157 @@ -"""mini-swe-agent driver. - -This module is the agent-specific half of the rollout. The generic recipe -(adapter/HTTP lifecycle, trajectory merge, abort isolation, dataset -normalization, sandbox boot / git_diff / evaluate) lives in -``async_rl_research.generate`` and ``async_rl_research.sandbox``. Here we -own only what is unique to mini-swe-agent: - - * ADAPTER_CLS / MODEL_NAME -- which slime adapter speaks this agent's wire - protocol (mini-swe-agent talks to litellm's OpenAI-compatible API, so we - intercept with OpenAIAdapter). - * the in-sandbox **headless runner** (``MINI_RUNNER_PY``) -- stock - mini-swe-agent (LitellmModel + LocalEnvironment + DefaultAgent) wired so - every model call dials back to the slime adapter. - * ``run_agent`` -- provision the package, upload the runner + task, launch - it detached, and poll a done-marker (sandbox gateways reset long-lived - HTTP/2 connections, so we cannot hold a multi-minute foreground exec). - -Token capture + loss masking happen entirely host-side in the adapter; this -runner is "dumb" and never sees token ids. mini-swe-agent runs UNMODIFIED at -its public OpenAI boundary -- the only requirement on the sandbox image is -python + the ``mini-swe-agent`` package (prefer baking it in; the best-effort -pip install below is a fallback for dev). 
- -Design A wire flow per turn:: - - in-sandbox: litellm.completion(messages, tools=[BASH_TOOL]) - -> POST {adapter_url}/v1/chat/completions Bearer - host adapter: render messages -> input_ids -> SGLang /generate - (return_logprob) -> record TurnRecord -> OpenAI JSON back - in-sandbox: run bash tool-call locally -> append observation -> loop - -The served model must support tool-call bash; set the matching SGLang parsers -on the launcher, e.g. ``--sglang-tool-call-parser qwen3_coder`` and -``--sglang-reasoning-parser qwen3``. +"""mini-swe-agent runtime (the default AgentRuntime). + +Runs stock mini-swe-agent (v2) headless inside the sandbox in an isolated +uv-venv, with every model call dialing back to the slime adapter over litellm's +OpenAI-compatible API. Token capture + loss masking happen host-side; the +runner never sees token ids. The adapter MUST use the served model's sglang +tool-call parser (v2 drives bash via native tool-calls), e.g. +``--sglang-tool-call-parser qwen25 --sglang-reasoning-parser qwen3``. """ from __future__ import annotations -import asyncio -import logging import os import shlex -import time +from pathlib import Path -from slime.agent.adapters import OpenAIAdapter from slime.agent.sandbox import Sandbox -logger = logging.getLogger(__name__) +from ..environment.base import PROBLEM_FILE +# Renders tool-call arguments as a dict so Qwen3.6's qwen3_coder chat template +# doesn't crash on turn 2+ (safe for hermes-style Qwen3 too). +from .adapters import QwenOpenAIAdapter +from .base import AgentRunResult, AgentRuntime -# --- driver declaration (read by async_rl_research.generate._State) --------- -ADAPTER_CLS = OpenAIAdapter -# Advertised to litellm as "openai/". The adapter ignores the name -# (it routes to the SGLang-served actor); litellm only needs the provider -# prefix so it speaks the OpenAI dialect at our adapter_url. 
-MODEL_NAME = "slime-actor" - - -# --- mini-swe-agent-specific knobs ------------------------------------------ MSWE_STEP_LIMIT = int(os.environ.get("MSWE_STEP_LIMIT", "50")) -# Prefer baking `pip install mini-swe-agent==` into the sandbox image. If -# MSWE_PIP_INSTALL=1, run_agent will best-effort install it at boot (needs the -# sandbox to have outbound PyPI access). -MSWE_PIP_INSTALL = os.environ.get("MSWE_PIP_INSTALL", "0") == "1" -MSWE_PIP_SPEC = os.environ.get("MSWE_PIP_SPEC", "mini-swe-agent") +# Consecutive no-tool-call model turns before the runner ends the episode: a +# stuck model that never reaches the context wall would otherwise format-error +# its way to MSWE_STEP_LIMIT. See _StopAwareModel in MINI_RUNNER_PY. +MSWE_MAX_EMPTY_TURNS = int(os.environ.get("MSWE_MAX_EMPTY_TURNS", "3")) +# Which YAML config (prompts) the runner loads. Override ladder: MSWE_CONFIG +# env (global) > metadata.agent_config (per-row) > universal config below. +MSWE_CONFIG = os.environ.get("MSWE_CONFIG", "") +# Read at import: the scaffold must be identical for every rollout in a run. +UNIVERSAL_CONFIG_YAML = (Path(__file__).parent / "config" / "universal.yaml").read_text(encoding="utf-8") +# Exact-pinned: prompts + wire protocol are part of the RL task distribution, +# and MINI_RUNNER_PY below is written against the v2 API. +MSWE_PIP_SPEC = os.environ.get("MSWE_PIP_SPEC", "mini-swe-agent==2.3.1") +# Prepended to PATH for the agent's bash commands: LocalEnvironment runs via +# /bin/sh so `conda activate testbed` never fires; this is how its python wins. +MSWE_PATH_PREPEND = os.environ.get("MSWE_PATH_PREPEND", "/opt/miniconda3/envs/testbed/bin:/opt/miniconda3/bin") +# Isolated venv so the testbed conda env is never used or clobbered. Provisioned +# at boot with uv; can be pre-baked into a derived image. 
+MSWE_AGENT_VENV = os.environ.get("MSWE_AGENT_VENV", "/opt/mswe-agent") +MSWE_AGENT_PYTHON_VERSION = os.environ.get("MSWE_AGENT_PYTHON_VERSION", "3.11") + +_VENV_PY = f"{MSWE_AGENT_VENV}/bin/python" -# Sandbox paths (kept under workdir; excluded from the captured diff). _RUNNER = ".mswe_runner.py" -_LAUNCH = ".mswe_run.sh" -_DONE = ".mswe_done" -_LOG = ".mswe_log" -_PROBLEM = "PROBLEM_STATEMENT.md" - - -# --------------------------------------------------------------------------- -# Headless in-sandbox runner. -# -# NOTE: PIN + VERIFY the imports / kwargs against the mini-swe-agent version -# baked into the image -- class names and config kwargs have shifted across -# releases. The wiring that must hold regardless: litellm points at the slime -# adapter (OPENAI_API_BASE/OPENAI_API_KEY), and we DO NOT set temperature here -# (the adapter's per-session sampling defaults must win, keeping RL on-policy). -# --------------------------------------------------------------------------- -MINI_RUNNER_PY = r'''"""Headless mini-swe-agent runner -- runs INSIDE the sandbox (design A).""" +_CONFIG_FILE = ".mswe_config.yaml" + + +# Headless in-sandbox runner (mini-swe-agent v2, exact-pinned). NO sampling +# knobs reach the request body -- the adapter applies it OVER its per-session +# defaults, so a client-sent temperature would silently turn rollouts greedy. 
+MINI_RUNNER_PY = r'''"""Headless mini-swe-agent (v2) runner -- runs INSIDE the sandbox (design A).""" import os import sys import traceback +from pathlib import Path WORKDIR = os.environ["MSWE_WORKDIR"] MODEL = os.environ.get("MSWE_MODEL", "slime-actor") STEP_LIMIT = int(os.environ.get("MSWE_STEP_LIMIT", "50")) +MAX_EMPTY_TURNS = int(os.environ.get("MSWE_MAX_EMPTY_TURNS", "3")) +PATH_PREPEND = os.environ.get("MSWE_PATH_PREPEND", "") with open(os.environ["MSWE_PROBLEM_FILE"], encoding="utf-8") as fh: TASK = fh.read() try: + import yaml from minisweagent.agents.default import DefaultAgent + from minisweagent.config import builtin_config_dir from minisweagent.environments.local import LocalEnvironment from minisweagent.models.litellm_model import LitellmModel + from minisweagent.exceptions import FormatError, LimitsExceeded + + class _StopAwareModel(LitellmModel): + """End the episode instead of looping when the model can't progress. + + A no-tool-call response surfaces as FormatError from super().query() + (LitellmModel stashes the raw response, incl. finish_reason, on it). We + stop on finish_reason='length' (adapter signalled context/output budget + exhausted) or after MAX_EMPTY_TURNS consecutive no-tool-call turns, + raising LimitsExceeded -- mini-swe's own graceful 'exit' path, the same + one step_limit uses. Without this mini-swe retries the format error every + turn and burns the whole context to step_limit (49 dead turns seen on + eval); finish_reason is otherwise never inspected. + """ - # api_base / api_key come from OPENAI_API_BASE / OPENAI_API_KEY in the env - # (litellm's openai provider reads them). No temperature/top_p here. - model = LitellmModel(model_name="openai/" + MODEL) - env = LocalEnvironment(cwd=WORKDIR) - # cost_limit=0 disables cost tracking (meaningless against a local actor). 
- agent = DefaultAgent(model, env, step_limit=STEP_LIMIT, cost_limit=0.0) - agent.run(TASK) + _empty = 0 + + def query(self, messages, **kwargs): + try: + msg = super().query(messages, **kwargs) + except FormatError as e: + resp = (e.messages[0].get("extra") or {}).get("response") or {} + fr = ((resp.get("choices") or [{}])[0] or {}).get("finish_reason") + self._empty += 1 + if fr == "length" or self._empty >= MAX_EMPTY_TURNS: + status = "ContextLengthExceeded" if fr == "length" else "NoProgress" + raise LimitsExceeded( + { + "role": "exit", + "content": f"ending session: finish_reason={fr}, no-tool-call streak={self._empty}", + "extra": {"exit_status": status, "submission": ""}, + } + ) + raise + self._empty = 0 + return msg + + # Default to the uploaded universal config; MSWE_CONFIG (if set) names a + # BUILTIN packaged config. Read the builtin path directly -- the spec helper + # would also try cwd-relative candidates a repo file could shadow. + cfg_path = Path(os.environ["MSWE_CONFIG_FILE"]) + builtin = os.environ.get("MSWE_CONFIG", "") + if builtin: + candidate = builtin_config_dir / builtin + if candidate.is_file(): + cfg_path = candidate + else: + print("[runner] builtin config %s not found; using the universal config" % candidate) + cfg = yaml.safe_load(cfg_path.read_text()) + agent_cfg = dict(cfg.get("agent") or {}) + model_cfg = dict(cfg.get("model") or {}) + env_cfg = dict(cfg.get("environment") or {}) + + # Strip all sampling knobs (the config pins temperature=0.0 for + # benchmarking) so the adapter's per-session defaults stay in force. + model_kwargs = dict(model_cfg.get("model_kwargs") or {}) + model_kwargs.pop("temperature", None) + model_kwargs.pop("top_p", None) + model_cfg.update( + model_name="openai/" + MODEL, + model_kwargs=model_kwargs, + # "openai/slime-actor" has no litellm price entry; the default mode + # would raise on the first successful completion. 
+ cost_tracking="ignore_errors", + ) + agent_cfg.update(step_limit=STEP_LIMIT, cost_limit=0.0) + + # Prepend the testbed env's bin dirs onto PATH (config.env wins over + # os.environ); conda activation never fires under /bin/sh. + env_overrides = dict(env_cfg.get("env") or {}) + prepend = [p for p in PATH_PREPEND.split(":") if p and os.path.isdir(p)] + if prepend: + env_overrides["PATH"] = ":".join(prepend) + ":" + os.environ.get("PATH", "") + + model = _StopAwareModel(**model_cfg) + env = LocalEnvironment(cwd=WORKDIR, env=env_overrides, timeout=int(env_cfg.get("timeout") or 60)) + agent = DefaultAgent(model, env, **agent_cfg) + info = agent.run(TASK) + print("[runner] exit_status=%s" % info.get("exit_status")) sys.exit(0) except SystemExit: raise @@ -114,125 +161,124 @@ ''' -# --------------------------------------------------------------------------- -# run_agent: provision + launch + poll (the only entrypoint generate.py calls) -# --------------------------------------------------------------------------- -async def run_agent( - sb: Sandbox, - *, - md: dict, - session_id: str, - adapter_url: str, - time_budget_sec: int, -) -> int: - """Provision mini-swe-agent in ``sb``, run it on the task, poll to done. - - Returns the runner's exit code, or ``-2`` if the wallclock budget elapses - first. The agent dials back to ``adapter_url`` for every model call and - authenticates with ``session_id`` so the adapter groups its turns. - """ - workdir = md["workdir"] - await _prepare_workspace(sb, workdir, md) - await _ensure_installed(sb) - return await _launch_and_poll( - sb, - workdir=workdir, - session_id=session_id, - adapter_url=adapter_url, - time_budget_sec=time_budget_sec, - ) - - -async def _prepare_workspace(sb: Sandbox, workdir: str, md: dict) -> None: - # git operations inside the sandbox need the repo marked safe; the diff is - # captured by sandbox.git_diff later. 
- await sb.exec("git config --system --add safe.directory '*'", check=False, timeout=60) - if md.get("pre_commands"): - await _apply_pre_commands(sb, workdir, md["pre_commands"]) - await sb.write_file(f"{workdir}/{_PROBLEM}", md.get("problem_statement") or "") - await sb.write_file(f"{workdir}/{_RUNNER}", MINI_RUNNER_PY) - - -async def _apply_pre_commands(sb: Sandbox, workdir: str, pre) -> None: - # Keep the work sandbox baseline aligned with eval (sweb-style pre_commands - # are typically `git checkout -f`); skipping them makes the - # model's diff context mismatch the eval base -> apply failures. - body = pre.replace("\\n", "\n") if isinstance(pre, str) else "\n".join(c for c in (pre or []) if c) - await sb.write_file(f"{workdir}/.mswe_pre.sh", "set -e\n" + body) - await sb.exec(f"cd {shlex.quote(workdir)} && bash .mswe_pre.sh", check=False, timeout=600) - - -async def _ensure_installed(sb: Sandbox) -> None: - ec, out, _ = await sb.exec( - 'python -c "import minisweagent" 2>/dev/null && echo MSWE_OK', check=False, timeout=60 - ) - if "MSWE_OK" in (out or ""): - return - if not MSWE_PIP_INSTALL: - raise RuntimeError( - "mini-swe-agent is not installed in the sandbox image. Bake " - f"`pip install {MSWE_PIP_SPEC}` into the image, or set " - "MSWE_PIP_INSTALL=1 to install at boot (needs outbound PyPI)." +# Provision mini-swe-agent into an isolated py3.11 uv venv. uv is our package +# manager; we NEVER fall back to the image's pip, because many task images +# (SWE-bench-Pro) ship a poisoned PIP_INDEX_URL pointing at a dead build-time +# mirror -> `pip install uv` hits "Connection refused .../uv/". Instead, mirror +# harbor's bootstrap: ensure curl via the OS package manager, install uv from +# the pinned astral script (falling back to pip ONLY with an explicit PyPI +# index), and retry network steps to absorb transient GitHub release resets. 
+# Measured across all 731 SWE-bench-Pro images at conc 50: 99.7% success (vs the +# original pip-fallback's ~52% on no-curl images). See +# profiles/PROVISIONING_VENV_VS_VOLUME.md. +_VENV_SETUP = ( + "set -e\n" + 'export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH"\n' + 'retry() { for i in 1 2 3; do bash -c "$1" && return 0; sleep $((i*4)); done; return 1; }\n' + "if ! command -v uv >/dev/null 2>&1; then\n" + # Ensure curl via whatever package manager the image ships (best-effort: + # if none works we still try the pip path below). + ' command -v curl >/dev/null 2>&1 || retry "apt-get update && apt-get install -y curl' + ' || apk add --no-cache curl || yum install -y curl || dnf install -y curl" || true\n' + # uv via the pinned astral script (bypasses the image's pip entirely); + # pip fallback forces a clean PyPI index so a poisoned image config can't win, + # then retries with --break-system-packages so PEP 668 ("externally-managed") + # images don't hard-fail. Try plain first: older pip (<23) lacks the flag but + # also doesn't enforce PEP 668, so the plain attempt already succeeds there. + ' retry "curl -LsSf https://astral.sh/uv/0.7.13/install.sh | sh"' + ' || retry "python3 -m pip install --index-url https://pypi.org/simple --root-user-action=ignore uv' + ' || python3 -m pip install --index-url https://pypi.org/simple --root-user-action=ignore --break-system-packages uv"\n' + ' export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH"\n' + "fi\n" + f"rm -rf {shlex.quote(MSWE_AGENT_VENV)}\n" + f'retry "uv venv --python {MSWE_AGENT_PYTHON_VERSION} {shlex.quote(MSWE_AGENT_VENV)}"\n' + f'retry "uv pip install --python {shlex.quote(_VENV_PY)} {shlex.quote(MSWE_PIP_SPEC)}"\n' + # Verify the agent's REAL import (the top package doesn't pull in pydantic). + # Two distinct failure modes this guards against: + # * `-P` (MANDATORY): the exec runs with cwd = the image WORKDIR (e.g. + # /testbed), so a task repo whose root *is* an agent dependency shadows + # the venv. 
The pydantic SWE-gym tasks ship /testbed/pydantic/, which the + # bare `python -c` imports instead of the venv's pydantic -> repo-pydantic + # vs venv-pydantic_core skew crashes ("no attribute 'dict_not_none'"). + # -P drops cwd from sys.path -- the SAME guard the runner launch uses + # (see run_agent) -- so the venv wins. Deterministic: was hitting 100% of + # pydantic tasks as a loud `exception:RuntimeError` from _ensure_provisioned. + # * reinstall: a partial wheel (corrupt uv cache / index contention at high + # conc) can leave a native ext unimportable; --reinstall --no-cache + # repairs it. A still-broken venv then fails LOUDLY here (set -e) instead + # of as a silent zero-turn adapter_session_empty. + f"for i in 1 2 3; do MSWEA_SILENT_STARTUP=1 {shlex.quote(_VENV_PY)} -P -c 'import minisweagent.agents.default' 2>/dev/null && break;" + f' retry "uv pip install --python {shlex.quote(_VENV_PY)} --reinstall --no-cache {shlex.quote(MSWE_PIP_SPEC)}" || true; done\n' + f"MSWEA_SILENT_STARTUP=1 {shlex.quote(_VENV_PY)} -P -c 'import minisweagent.agents.default'\n" +) + +# MSWEA_SILENT_STARTUP suppresses the import-time banner that would otherwise +# corrupt the provisioning probe's marker comparison. Import the agent's real +# entrypoint (not just the top package) so the probe also rejects a pre-baked +# venv whose pydantic_core native module is missing -> re-provision instead of +# launching a doomed agent. `-P` keeps the image WORKDIR (e.g. /testbed) off +# sys.path so a repo named like an agent dep (pydantic tasks) can't shadow the +# venv and fail the probe; see _VENV_SETUP and the runner launch. 
+_VENV_CHECK = f"MSWEA_SILENT_STARTUP=1 {shlex.quote(_VENV_PY)} -P -c 'import minisweagent.agents.default'" + + +class MiniSweAgentRuntime(AgentRuntime): + name = "mini-swe" + adapter_cls = QwenOpenAIAdapter + model_name = "slime-actor" + scratch_prefix = ".mswe" + # "patch.txt": submission artifact the builtin swebench prompt tells the + # agent to create, which `git add -N .` would otherwise sweep into the diff. + diff_exclude = (_RUNNER, _CONFIG_FILE, "patch.txt") + + async def run_agent( + self, + sb: Sandbox, + *, + md: dict, + session_id: str, + adapter_url: str, + time_budget_sec: int, + ) -> AgentRunResult: + """Provision mini-swe-agent in ``sb``, run it on the task, poll to done.""" + workdir = md["workdir"] + await sb.write_file(f"{workdir}/{_RUNNER}", MINI_RUNNER_PY) + await sb.write_file(f"{workdir}/{_CONFIG_FILE}", UNIVERSAL_CONFIG_YAML) + await self._ensure_provisioned( + sb, + spec=MSWE_PIP_SPEC, + marker_path=f"{MSWE_AGENT_VENV}/.mswe_spec", + setup_script=_VENV_SETUP, + check_cmd=_VENV_CHECK, ) - logger.info("[mini_swe_agent] installing %s in sandbox %s", MSWE_PIP_SPEC, sb.sandbox_id[:8]) - await sb.exec(f"pip install --no-input {shlex.quote(MSWE_PIP_SPEC)}", check=True, timeout=600) - - -async def _launch_and_poll( - sb: Sandbox, - *, - workdir: str, - session_id: str, - adapter_url: str, - time_budget_sec: int, -) -> int: - """Launch the runner detached + poll a done-marker file. - - Sandbox gateways reset HTTP/2 around ~6.5 min, so we cannot keep a - long-lived foreground exec. The launcher writes the exit code into a marker - file; we poll it every 5s via short RPCs (which also keeps the sandbox - alive against idle GC). - """ - q = shlex.quote - base = q(f"{adapter_url}/v1") - launcher_body = ( - "#!/bin/bash\n" - f"cd {q(workdir)}\n" - # litellm's openai provider reads these for base URL + bearer auth. 
- f"export OPENAI_API_BASE={base}\n" - f"export OPENAI_BASE_URL={base}\n" - f"export OPENAI_API_KEY={q(session_id)}\n" - f"export MSWE_MODEL={q(MODEL_NAME)}\n" - f"export MSWE_WORKDIR={q(workdir)}\n" - f"export MSWE_PROBLEM_FILE={q(f'{workdir}/{_PROBLEM}')}\n" - f"export MSWE_STEP_LIMIT={q(str(MSWE_STEP_LIMIT))}\n" - f"python {q(_RUNNER)} > {q(_LOG)} 2>&1\n" - f"echo $? > {q(_DONE)}\n" - ) - await sb.write_file(f"{workdir}/{_LAUNCH}", launcher_body) - await sb.exec(f"cd {q(workdir)} && chmod +x {q(_LAUNCH)}", check=False, timeout=30) - # Detach so the exec RPC returns immediately; the marker file is the signal. - await sb.exec( - f"cd {q(workdir)} && setsid bash {q(_LAUNCH)} < /dev/null > /dev/null 2>&1 &", - check=False, - timeout=30, - ) - done_path = f"{workdir}/{_DONE}" - deadline = time.time() + time_budget_sec - exit_code = -2 # convention: -2 = budget exceeded - while time.time() < deadline: - await asyncio.sleep(5) - ec, out, _ = await sb.exec(f"test -f {q(done_path)} && cat {q(done_path)}", check=False, timeout=15) - if ec == 0: - try: - exit_code = int((out or "").strip() or "-1") - except ValueError: - exit_code = -1 - break - logger.info("[mini_swe_agent] session=%s exit=%s elapsed<=%ds", session_id, exit_code, time_budget_sec) - return exit_code + base = f"{adapter_url}/v1" + env = { + "OPENAI_API_BASE": base, + "OPENAI_BASE_URL": base, + "OPENAI_API_KEY": session_id, + "MSWE_MODEL": self.model_name, + "MSWE_WORKDIR": workdir, + "MSWE_PROBLEM_FILE": f"{workdir}/{PROBLEM_FILE}", + # Override ladder: global env > per-row builtin > universal config. 
+ "MSWE_CONFIG": MSWE_CONFIG or md.get("agent_config") or "", + "MSWE_CONFIG_FILE": f"{workdir}/{_CONFIG_FILE}", + "MSWE_STEP_LIMIT": str(MSWE_STEP_LIMIT), + "MSWE_MAX_EMPTY_TURNS": str(MSWE_MAX_EMPTY_TURNS), + "MSWE_PATH_PREPEND": MSWE_PATH_PREPEND, + "MSWEA_SILENT_STARTUP": "1", + } + # Run with the ISOLATED venv interpreter; -P keeps the workdir off + # sys.path so a repo sharing a name with an agent dep can't shadow it. + return await self._detached_run( + sb, + workdir=workdir, + command=f"{shlex.quote(_VENV_PY)} -P {shlex.quote(_RUNNER)}", + env=env, + time_budget_sec=time_budget_sec, + log_tag=f"session={session_id}", + ) -# sandbox.git_diff should exclude these scratch files from the captured diff: -DIFF_EXCLUDE = (_PROBLEM, _RUNNER, _LAUNCH, _DONE, _LOG, ".mswe_pre.sh") +# Module export for dotted-module-path loading (see load_runtime). +RUNTIME = MiniSweAgentRuntime diff --git a/async_rl_research/aiohttp_threaded.py b/async_rl_research/aiohttp_threaded.py index 2b9d12c87f..fc9f737080 100644 --- a/async_rl_research/aiohttp_threaded.py +++ b/async_rl_research/aiohttp_threaded.py @@ -1,11 +1,6 @@ -"""Run an ``aiohttp.web.Application`` in a background daemon thread. - -This is the "http" piece: ``generate.py`` builds the slime adapter (an -``aiohttp`` app that speaks the agent's wire API on the front and SGLang -``/generate`` on the back) and serves it here, on a daemon thread, so the -synchronous slime rollout loop and the in-sandbox agent's HTTP callbacks can -run concurrently. Verbatim copy of -``examples/coding_agent_rl/aiohttp_threaded.py`` (generic, no SWE specifics). +"""Run an ``aiohttp.web.Application`` in a background daemon thread, so the +synchronous slime rollout loop and the agent's HTTP callbacks run concurrently. +Verbatim copy of ``examples/coding_agent_rl/aiohttp_threaded.py``. """ from __future__ import annotations @@ -54,9 +49,7 @@ def run_app_in_thread( ) -> AppHandle: """Spin up ``app`` on a daemon thread; block until it is listening. 
- ``runner_kwargs`` is forwarded to ``web.AppRunner`` (e.g. pass - ``{"handler_cancellation": True}`` to make a client disconnect cancel - the in-flight handler coroutine). + ``runner_kwargs`` is forwarded to ``web.AppRunner``. """ started = threading.Event() err_box: list[BaseException] = [] diff --git a/async_rl_research/architecture.png b/async_rl_research/architecture.png new file mode 100644 index 0000000000..9c20b2a036 Binary files /dev/null and b/async_rl_research/architecture.png differ diff --git a/async_rl_research/environment/__init__.py b/async_rl_research/environment/__init__.py new file mode 100644 index 0000000000..6595716ac5 --- /dev/null +++ b/async_rl_research/environment/__init__.py @@ -0,0 +1,5 @@ +"""Task environments (env) + their dataset converters (env/convert2slime).""" + +from .base import PROBLEM_FILE, EnvMetadataError, RewardResult, RolloutEnv, load_env + +__all__ = ["PROBLEM_FILE", "EnvMetadataError", "RewardResult", "RolloutEnv", "load_env"] diff --git a/async_rl_research/environment/base.py b/async_rl_research/environment/base.py new file mode 100644 index 0000000000..7db242e1a5 --- /dev/null +++ b/async_rl_research/environment/base.py @@ -0,0 +1,183 @@ +"""RolloutEnv: the contract between ``generate.py`` and one task family. + +An *env* packages everything task-family-specific (SWE-Gym, harbor, ...): row +validation, sandbox boot/prep, driving the agent across step(s), grading into a +reward. Mirrors ``agent/base.py``'s AgentRuntime: one ``generate()`` +orchestrates ``runtime x env``. + +Schema-pair convention: ``env/.py`` and ``env/convert2slime/.py`` +are paired -- the converter is the only writer of the ``metadata`` dict and +``normalize_metadata`` the only reader. Rows select their env via +``metadata.task_type`` (absent -> ``swe_gym``). + +Writing a new env: subclass, declare ``name``, implement ``normalize_metadata`` ++ ``rollout``, register in ``ENVS``. ``rollout`` owns the whole sandbox +lifecycle. 
Envs are cached once per worker and must be stateless across samples. +""" + +from __future__ import annotations + +import gzip +import importlib +import io +import logging +import shlex +import tarfile +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, ClassVar + +logger = logging.getLogger(__name__) + + +# Problem statement written to the workdir; always excluded from captured diffs. +PROBLEM_FILE = "PROBLEM_STATEMENT.md" + + +class EnvMetadataError(ValueError): + """A dataset row is unusable for this env; str(err) becomes the abort reason.""" + + +@dataclass(frozen=True) +class RewardResult: + """What an env's ``rollout`` hands back to the trajectory merge: a scalar + ``reward``, an ``is_solved`` flag for solve-rate logging, and ``extra`` + env-specific diagnostics merged into trajectory metadata. + """ + + reward: float + is_solved: bool + extra: dict[str, Any] = field(default_factory=dict) + + +class RolloutEnv(ABC): + """One task family's integration: row schema + sandbox + grading. + + ``generate.py`` only ever touches: ``name``, ``normalize_metadata``, + ``rollout``. + """ + + name: ClassVar[str] + + def __init_subclass__(cls, **kwargs) -> None: + # Fail at import time, not mid-rollout. + super().__init_subclass__(**kwargs) + if getattr(cls, "name", None) is None: + raise TypeError(f"{cls.__name__} must define class attribute 'name' (see RolloutEnv)") + + @abstractmethod + def normalize_metadata(self, sample) -> dict[str, Any]: + """Normalize one dataset row (slime ``Sample``) into the env's md dict. + + Must include ``instance_id``, plus ``workdir``/``agent_config`` if the + runtime reads them. Raises ``EnvMetadataError`` for unrunnable rows. 
+ """ + + @abstractmethod + async def rollout( + self, + md: dict[str, Any], + *, + runtime, + session_id: str, + adapter_url: str, + agent_time_budget_sec: int, + eval_timeout_sec: int, + ) -> RewardResult: + """Run the full task episode: boot, prep, agent run(s), grading. + + Call ``runtime.run_agent`` per agent leg; all legs share one adapter + session so a multi-step episode stays one trajectory. + ``agent_time_budget_sec`` bounds TOTAL agent wallclock across legs; + ``eval_timeout_sec`` caps each grading command. + """ + + def effective_budgets( + self, md: dict[str, Any], *, agent_time_budget_sec: int, eval_timeout_sec: int + ) -> dict[str, int]: + """Wall-clock budgets actually enforced this rollout (for the dump/dashboard).""" + from ..modal_sandbox import ModalSandbox + + return { + "boot_sec": ModalSandbox._boot_timeout_from_env(), + "agent_sec": agent_time_budget_sec, + "eval_sec": eval_timeout_sec, + } + + # ------------------------------------------------------------------ + # Shared sandbox helpers + # ------------------------------------------------------------------ + @staticmethod + async def write_problem_file(sb, workdir: str, text: str | None) -> None: + await sb.write_file(f"{workdir}/{PROBLEM_FILE}", text or "") + + @staticmethod + async def upload_dir(sb, host_dir: str | Path, sandbox_dir: str) -> None: + """Copy a host dir's CONTENTS into a fresh ``sandbox_dir`` via one + gzipped tar (task tests/solution dirs are small).""" + host_dir = Path(host_dir) + buf = io.BytesIO() + # mtime=0 so re-uploads of identical content are byte-identical. 
+ with gzip.GzipFile(fileobj=buf, mode="wb", mtime=0) as gz: + with tarfile.open(fileobj=gz, mode="w") as tar: + tar.add(host_dir, arcname=".") + archive = f"/tmp/.upload_{abs(hash(str(host_dir))) % 10**8}.tgz" + await sb.write_file(archive, buf.getvalue()) + q = shlex.quote + await sb.exec( + f"rm -rf {q(sandbox_dir)} && mkdir -p {q(sandbox_dir)} && tar -xzf {q(archive)} -C {q(sandbox_dir)}", + check=True, + timeout=120, + ) + + +def coerce_prompt(prompt) -> str: + """Best-effort extraction of plain text from a slime prompt field.""" + if isinstance(prompt, str): + return prompt + if isinstance(prompt, list): + for m in prompt: + if isinstance(m, dict) and m.get("role") == "user": + c = m.get("content") + if isinstance(c, str): + return c + if isinstance(c, list): + return "\n".join(p.get("text", "") for p in c if isinstance(p, dict) and p.get("type") == "text") + return "" + + +# --------------------------------------------------------------------------- +# Registry + loader (mirrors agent.base.RUNTIMES / load_runtime) +# --------------------------------------------------------------------------- +DEFAULT_ENV = "swe_gym" + +# task_type -> "module:Class" (strings so importing base.py imports no env module). +ENVS: dict[str, str] = { + "swe_gym": "async_rl_research.environment.swe_gym:SweGymEnv", + "harbor": "async_rl_research.environment.harbor:HarborEnv", +} + +_ENV_CACHE: dict[str, RolloutEnv] = {} + + +def load_env(spec: str | None = None) -> RolloutEnv: + """Resolve ``spec`` (a row's ``metadata.task_type``) to a cached env. + + Accepts a registry short name ("harbor") or "pkg.module:Class"; absent -> + ``DEFAULT_ENV``. 
+ """ + spec = spec or DEFAULT_ENV + cached = _ENV_CACHE.get(spec) + if cached is not None: + return cached + target = ENVS.get(spec, spec) + if ":" not in target: + raise ValueError(f"unknown task_type {spec!r}; known: {sorted(ENVS)} (or pass 'pkg.module:Class')") + module_path, _, attr = target.partition(":") + cls = getattr(importlib.import_module(module_path), attr, None) + if not (isinstance(cls, type) and issubclass(cls, RolloutEnv)): + raise TypeError(f"task_type {spec!r} resolved to {cls!r}, which is not a RolloutEnv subclass") + env = cls() + _ENV_CACHE[spec] = env + return env diff --git a/async_rl_research/environment/convert2slime/__init__.py b/async_rl_research/environment/convert2slime/__init__.py new file mode 100644 index 0000000000..b556da7301 --- /dev/null +++ b/async_rl_research/environment/convert2slime/__init__.py @@ -0,0 +1,5 @@ +"""Dataset converters: one per env, paired by filename (see env/base.py). + +``env/convert2slime/.py`` is the only writer of the ``metadata`` schema +``env/.py`` reads. Run offline; may carry heavy deps the rollout never imports. +""" diff --git a/async_rl_research/environment/convert2slime/harbor.py b/async_rl_research/environment/convert2slime/harbor.py new file mode 100644 index 0000000000..05a91c942c --- /dev/null +++ b/async_rl_research/environment/convert2slime/harbor.py @@ -0,0 +1,304 @@ +"""Materialize a harbor dataset as slime prompt data + local task dirs. + +Schema pair of ``env/harbor.py`` and the ONLY writer of that schema: parses +each ``task.toml`` offline and bakes everything into ``metadata`` so the rollout +never reads harbor config. Output goes under ``--out-dir`` (on the slime-data +volume): ``.jsonl`` + ``tasks//``, referenced via +``metadata.task_path`` relative to it (export ``ASYNC_RL_TASK_ROOT=``). 
+ +Sources:: + + python -m async_rl_research.environment.convert2slime.harbor \ + --tasks-dir ~/harbor-datasets/datasets/usaco --out-dir data/usaco + python -m async_rl_research.environment.convert2slime.harbor \ + --registry ~/harbor/registry.json --dataset usaco --out-dir data/usaco + +v1 scope (else skipped + logged): linux, single-container Dockerfile or prebuilt +docker_image, shared-environment verification, no GPU/TPU, no MCP. Multi-step +tasks are supported. +""" + +from __future__ import annotations + +import argparse +import json +import logging +import re +import shutil +from pathlib import Path +from typing import Any + +import tomllib + +logger = logging.getLogger(__name__) + +_COPY_IGNORE = shutil.ignore_patterns(".git", "__pycache__", ".DS_Store", ".venv", "node_modules") +_WORKDIR_RE = re.compile(r"^\s*WORKDIR\s+(.+?)\s*$", re.IGNORECASE | re.MULTILINE) + + +class SkipTask(Exception): + """Task is outside v1 scope; str(err) is the logged reason.""" + + +def _parse_dockerfile_workdir(dockerfile: Path) -> str | None: + matches = _WORKDIR_RE.findall(dockerfile.read_text(encoding="utf-8", errors="replace")) + if not matches: + return None + last = matches[-1].strip().strip('"').strip("'") + # Variable/relative WORKDIRs can't be resolved statically; let the rollout + # detect cwd from the sandbox. 
+ return last if last.startswith("/") and "$" not in last else None + + +def _check_verifier_shared(verifier_cfg: dict[str, Any], where: str) -> None: + if verifier_cfg.get("environment_mode") == "separate" or verifier_cfg.get("environment"): + raise SkipTask(f"separate verifier environment ({where})") + + +def _verifier_md(verifier_cfg: dict[str, Any]) -> dict[str, Any]: + md: dict[str, Any] = {} + if verifier_cfg.get("timeout_sec"): + md["timeout_sec"] = float(verifier_cfg["timeout_sec"]) + if verifier_cfg.get("env"): + md["env"] = dict(verifier_cfg["env"]) + return md + + +def _instruction(path: Path, where: str) -> str: + if not path.is_file(): + raise SkipTask(f"missing {where}") + return path.read_text(encoding="utf-8") + + +def _steps_md(cfg: dict[str, Any], task_dir: Path) -> list[dict[str, Any]]: + steps_md = [] + for step in cfg.get("steps") or []: + name = step.get("name") + if not name: + raise SkipTask("unnamed step") + step_dir = task_dir / "steps" / name + # Fall back to the shared top-level tests/ when a step ships none. + tests_path = f"steps/{name}/tests" if (step_dir / "tests" / "test.sh").is_file() else "tests" + if not (task_dir / tests_path / "test.sh").is_file(): + raise SkipTask(f"step {name!r} has no tests/test.sh (step or shared)") + _check_verifier_shared(step.get("verifier") or {}, f"step {name!r}") + steps_md.append( + { + "name": name, + "instruction": _instruction(step_dir / "instruction.md", f"steps/{name}/instruction.md"), + "tests_path": tests_path, + "verifier": _verifier_md(step.get("verifier") or {}), + "min_reward": step.get("min_reward"), + "agent_timeout_sec": (step.get("agent") or {}).get("timeout_sec"), + } + ) + return steps_md + + +def translate_task(task_dir: Path, *, dataset: str | None = None) -> dict[str, Any]: + """One harbor task dir -> one slime row (task_path filled by caller). + + ``dataset`` qualifies tasks whose task.toml has no ``[task].name``. Raises + ``SkipTask`` for tasks outside v1 scope. 
+ """ + config_path = task_dir / "task.toml" + if not config_path.is_file(): + raise SkipTask("no task.toml") + cfg = tomllib.loads(config_path.read_text(encoding="utf-8")) + + env_cfg = cfg.get("environment") or {} + if (env_cfg.get("os") or "linux") != "linux": + raise SkipTask(f"os={env_cfg.get('os')}") + if env_cfg.get("gpus") or env_cfg.get("gpu_types") or env_cfg.get("tpu"): + raise SkipTask("requires GPU/TPU") + if env_cfg.get("mcp_servers"): + raise SkipTask("requires MCP servers") + if env_cfg.get("network_mode") in ("no-network", "allowlist"): + # Modal can't enforce per-phase network policies yet, so running these + # would grant more network than the task allows. + raise SkipTask(f"network_mode={env_cfg['network_mode']}") + + docker_image = env_cfg.get("docker_image") + dockerfile = None + workdir = env_cfg.get("workdir") + if not docker_image: + if (task_dir / "environment" / "docker-compose.yaml").is_file() or ( + task_dir / "environment" / "docker-compose.yml" + ).is_file(): + raise SkipTask("docker-compose environment") + if not (task_dir / "environment" / "Dockerfile").is_file(): + raise SkipTask("no environment/Dockerfile or docker_image") + dockerfile = "environment/Dockerfile" + if not workdir: + workdir = _parse_dockerfile_workdir(task_dir / "environment" / "Dockerfile") + + verifier_cfg = cfg.get("verifier") or {} + _check_verifier_shared(verifier_cfg, "task") + + steps_md = _steps_md(cfg, task_dir) + if steps_md: + instruction_path = task_dir / "instruction.md" + instruction = ( + instruction_path.read_text(encoding="utf-8") if instruction_path.is_file() else steps_md[0]["instruction"] + ) + else: + instruction = _instruction(task_dir / "instruction.md", "instruction.md") + if not (task_dir / "tests" / "test.sh").is_file(): + raise SkipTask("no tests/test.sh") + + fallback = f"{dataset}/{task_dir.name}" if dataset else task_dir.name + task_name = ((cfg.get("task") or {}).get("name") or fallback).strip() + instance_id = 
re.sub(r"[^A-Za-z0-9_.-]+", "__", task_name) or task_dir.name + + metadata: dict[str, Any] = { + "task_type": "harbor", + "instance_id": instance_id, + "docker_image": docker_image, + "dockerfile": dockerfile, + "workdir": workdir, + "problem_statement": instruction, + "agent_timeout_sec": (cfg.get("agent") or {}).get("timeout_sec"), + "build_timeout_sec": env_cfg.get("build_timeout_sec"), + "verifier": _verifier_md(verifier_cfg), + "steps": steps_md or None, + "reward_strategy": cfg.get("multi_step_reward_strategy"), + "cpus": env_cfg.get("cpus"), + "memory_mb": env_cfg.get("memory_mb"), + } + metadata = {k: v for k, v in metadata.items() if v is not None} + + # prompt as a single-message list, NOT a raw string: slime's Dataset asserts + # a list when a HF processor loads for hf_checkpoint. Harmless: the agent + # reads metadata["problem_statement"], never sample.prompt. + return {"prompt": [{"role": "user", "content": instruction}], "label": task_name, "metadata": metadata} + + +def convert( + task_dirs: list[Path], out_dir: Path, *, name: str, dataset: str | None = None, limit: int | None = None +) -> tuple[int, int]: + """Copy tasks + write the JSONL. 
Returns (converted, skipped).""" + out_dir.mkdir(parents=True, exist_ok=True) + tasks_out = out_dir / "tasks" + rows: list[dict[str, Any]] = [] + skipped = 0 + + for task_dir in task_dirs: + if limit is not None and len(rows) >= limit: + break + try: + row = translate_task(task_dir, dataset=dataset) + except SkipTask as e: + skipped += 1 + logger.warning("[harbor2slime] skip %s: %s", task_dir.name, e) + continue + instance_id = row["metadata"]["instance_id"] + dest = tasks_out / instance_id + if dest.exists(): + shutil.rmtree(dest) + shutil.copytree(task_dir, dest, ignore=_COPY_IGNORE) + row["metadata"]["task_path"] = f"tasks/{instance_id}" + rows.append(row) + + jsonl_path = out_dir / f"{name}.jsonl" + with open(jsonl_path, "w", encoding="utf-8") as fh: + for row in rows: + fh.write(json.dumps(row, ensure_ascii=False) + "\n") + logger.info("[harbor2slime] wrote %d rows -> %s (skipped %d)", len(rows), jsonl_path, skipped) + return len(rows), skipped + + +def _discover_task_dirs(tasks_dir: Path) -> list[Path]: + """Direct subdirectories holding a task.toml, sorted for determinism. + + Falls back to ``tasks_dir`` as a single task only when no subdir tasks exist + (some datasets ship a stray template task.toml at the root). 
+ """ + subdirs = sorted(p for p in tasks_dir.iterdir() if (p / "task.toml").is_file()) + if subdirs: + if (tasks_dir / "task.toml").is_file(): + logger.warning( + "[harbor2slime] %s has both a root task.toml and %d task subdirs; using the subdirs", + tasks_dir, + len(subdirs), + ) + return subdirs + if (tasks_dir / "task.toml").is_file(): + return [tasks_dir] + return [] + + +def _download_from_registry(registry_spec: str, dataset: str, version: str | None, download_dir: Path) -> list[Path]: + """Fetch a dataset's tasks via the harbor package (optional dep).""" + try: + from harbor.models.registry import Registry + from harbor.tasks.client import TaskClient + except ImportError as exc: + raise SystemExit("--registry mode needs the `harbor` package: pip install harbor") from exc + + import asyncio + + if registry_spec.startswith(("http://", "https://")): + registry = Registry.from_url(registry_spec) + else: + registry = Registry.from_path(Path(registry_spec)) + matches = [d for d in registry.datasets if d.name == dataset and (version is None or d.version == version)] + if not matches: + known = sorted({d.name for d in registry.datasets}) + raise SystemExit(f"dataset {dataset!r} not in registry; known: {known}") + spec = matches[-1] + task_ids = [t.to_source_task_id() for t in spec.tasks] + logger.info("[harbor2slime] downloading %d tasks for %s==%s", len(task_ids), spec.name, spec.version) + result = asyncio.run(TaskClient().download_tasks(task_ids, output_dir=download_dir)) + + paths: list[Path] = [] + items = result.values() if hasattr(result, "values") else result + for item in items: + path = getattr(item, "path", None) or getattr(item, "downloaded_path", None) + if path: + paths.append(Path(path)) + if not paths: + raise SystemExit(f"registry download produced no task paths (result type {type(result).__name__})") + return sorted(paths) + + +def main(argv: list[str] | None = None) -> int: + logging.basicConfig(level=logging.INFO, format="%(asctime)s 
%(message)s") + parser = argparse.ArgumentParser(description="Materialize a harbor dataset as slime prompt JSONL + task dirs.") + source = parser.add_mutually_exclusive_group(required=True) + source.add_argument("--tasks-dir", type=Path, help="directory of harbor task dirs (or a single task dir)") + source.add_argument("--registry", help="harbor registry.json path or URL (needs `pip install harbor`)") + parser.add_argument("--dataset", help="dataset name in the registry (with --registry)") + parser.add_argument("--dataset-version", help="dataset version in the registry (default: last match)") + parser.add_argument( + "--out-dir", type=Path, required=True, help="output dir (JSONL + tasks/); use the slime-data volume" + ) + parser.add_argument("--name", help="JSONL filename stem (default: dataset or tasks-dir name)") + parser.add_argument("--limit", type=int, help="maximum tasks to convert") + args = parser.parse_args(argv) + + if args.registry: + if not args.dataset: + parser.error("--registry requires --dataset") + task_dirs = _download_from_registry( + args.registry, args.dataset, args.dataset_version, args.out_dir / "downloads" + ) + name = args.name or args.dataset + dataset = args.dataset + else: + task_dirs = _discover_task_dirs(args.tasks_dir) + name = args.name or args.tasks_dir.name + dataset = args.tasks_dir.name + if not task_dirs: + raise SystemExit("no task dirs found") + + converted, skipped = convert(task_dirs, args.out_dir, name=name, dataset=dataset, limit=args.limit) + out_dir = args.out_dir.resolve() + print(f"converted {converted} tasks ({skipped} skipped) -> {out_dir / (name + '.jsonl')}") + print("next steps:") + print(f" export ASYNC_RL_TASK_ROOT={out_dir}") + print(f" python -m async_rl_research.environment.harbor {out_dir / (name + '.jsonl')} --limit 3 # oracle check") + return 0 if converted else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/async_rl_research/environment/convert2slime/openthoughts_agent.py 
b/async_rl_research/environment/convert2slime/openthoughts_agent.py new file mode 100644 index 0000000000..b02fb8694b --- /dev/null +++ b/async_rl_research/environment/convert2slime/openthoughts_agent.py @@ -0,0 +1,160 @@ +"""Materialize ``open-thoughts/OpenThoughts-Agent-v1-RL`` as slime prompt data. + +Each HF row packs a harbor task dir as a gzipped tar (``task_binary``); this +unpacks each into a staging tree and hands off to the shared ``harbor.convert``. +Output matches ``harbor.convert`` (``.jsonl`` + ``tasks//``). +These tasks carry no ``[task].name``, so ``instance_id`` falls back to +``__``. + +These tasks deposit their deliverable in ``/output`` (e.g. ``cp ... +/output/command_capture.txt``), but harbor only provisions the workdir + +``/logs/{agent,verifier,artifacts}``. Rather than special-case ``/output`` in +the generic env, we remap it onto harbor's ``/logs/artifacts`` here so the data +is fully ``/logs``-native (the tasks already write their reward to +``/logs/verifier``); see ``conform_output_paths``. + + python -m async_rl_research.environment.convert2slime.openthoughts_agent \ + --out-dir data/openthoughts_agent +""" + +from __future__ import annotations + +import argparse +import base64 +import io +import logging +import re +import tarfile +from collections.abc import Iterable, Iterator +from pathlib import Path +from typing import Any + +from async_rl_research.environment.convert2slime import harbor + +logger = logging.getLogger(__name__) + +HF_DATASET = "open-thoughts/OpenThoughts-Agent-v1-RL" +DATASET_NAME = "openthoughts_agent" # instance_id prefix + JSONL stem + +# These tasks hardcode a /output capture dir; harbor only provisions the workdir +# + /logs/{agent,verifier,artifacts}. Remap /output onto harbor's artifacts dir. +_HARBOR_ARTIFACTS_DIR = "/logs/artifacts" +# Match the /output dir prefix only -- not lookalikes like /outputs or /output_dir. 
+_OUTPUT_DIR_RE = re.compile(r"/output(?![A-Za-z0-9_])") +# Runtime + agent-facing files only; never touch test data (e.g. the .txt holding +# tests/expected_output.txt, which could legitimately contain the literal /output). +_REMAP_SUFFIXES = (".sh", ".md", ".py") + + +def _archive_bytes(task_binary: bytes | str) -> bytes: + """The ``binary`` HF feature loads as bytes; JSON transports base64.""" + if isinstance(task_binary, str): + return base64.b64decode(task_binary) + return bytes(task_binary) + + +def unpack_tasks(rows: Iterable[dict[str, Any]], staging_dir: Path) -> list[Path]: + """Unpack each row's ``task_binary`` into ``staging_dir//``; returns + the sorted task dirs for ``harbor.convert``.""" + staging_dir.mkdir(parents=True, exist_ok=True) + task_dirs: list[Path] = [] + for row in rows: + path = row.get("path") + if not path: + logger.warning("[openthoughts2slime] skip row with no path") + continue + dest = staging_dir / str(path) + with tarfile.open(fileobj=io.BytesIO(_archive_bytes(row["task_binary"])), mode="r:gz") as tf: + # filter="data" blocks path traversal/unsafe members (py3.12+). + tf.extractall(dest, filter="data") + task_dirs.append(dest) + return sorted(task_dirs) + + +def conform_output_paths(task_dirs: Iterable[Path]) -> int: + """Remap the legacy ``/output`` capture dir onto harbor's ``/logs/artifacts`` + in each task's scripts + instructions (solve.sh, test.sh, instruction.md). + + Rewrites in-place under the staging tree before ``harbor.convert`` copies it. + Returns the count of files changed. Only ``_REMAP_SUFFIXES`` are touched, so + test data (``tests/expected_output.txt``) is never rewritten. 
+ """ + rewritten = 0 + for task_dir in task_dirs: + for path in task_dir.rglob("*"): + if not path.is_file() or path.suffix not in _REMAP_SUFFIXES: + continue + try: + text = path.read_text(encoding="utf-8") + except UnicodeDecodeError: + continue + new = _OUTPUT_DIR_RE.sub(_HARBOR_ARTIFACTS_DIR, text) + if new != text: + path.write_text(new, encoding="utf-8") + rewritten += 1 + return rewritten + + +def load_hf_rows(repo: str, split: str, limit: int | None) -> Iterator[dict[str, Any]]: + try: + from datasets import load_dataset + except ImportError as exc: + raise SystemExit("Install `datasets` to pull OpenThoughts-Agent from HuggingFace.") from exc + + for index, row in enumerate(load_dataset(repo, split=split)): + if limit is not None and index >= limit: + break + yield dict(row) + + +def materialize( + out_dir: Path, + *, + name: str = DATASET_NAME, + repo: str = HF_DATASET, + split: str = "train", + limit: int | None = None, +) -> tuple[int, int]: + """Download the HF dataset, unpack tasks, and run the harbor converter. + + Returns ``(converted, skipped)``. Staging lives under ``out_dir`` (one + filesystem for the converter's copytree) and is removed after. 
+ """ + import shutil + + staging = out_dir / "_staging" + rows = load_hf_rows(repo, split, limit) + task_dirs = unpack_tasks(rows, staging) + logger.info("[openthoughts2slime] unpacked %d tasks -> %s", len(task_dirs), staging) + remapped = conform_output_paths(task_dirs) + logger.info("[openthoughts2slime] remapped /output -> %s in %d files", _HARBOR_ARTIFACTS_DIR, remapped) + try: + return harbor.convert(task_dirs, out_dir, name=name, dataset=DATASET_NAME) + finally: + shutil.rmtree(staging, ignore_errors=True) + + +def main(argv: list[str] | None = None) -> int: + logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s") + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument( + "--out-dir", type=Path, required=True, help="output dir (JSONL + tasks/); use the slime-data volume" + ) + parser.add_argument("--name", default=DATASET_NAME, help="JSONL filename stem") + parser.add_argument("--repo", default=HF_DATASET, help="HuggingFace dataset repo id") + parser.add_argument("--split", default="train", help="HuggingFace split") + parser.add_argument("--limit", type=int, help="maximum tasks to convert") + args = parser.parse_args(argv) + + converted, skipped = materialize(args.out_dir, name=args.name, repo=args.repo, split=args.split, limit=args.limit) + out_dir = args.out_dir.resolve() + jsonl = out_dir / f"{args.name}.jsonl" + print(f"converted {converted} tasks ({skipped} skipped) -> {jsonl}") + print("next steps:") + print(f" export ASYNC_RL_TASK_ROOT={out_dir}") + print(f" python -m async_rl_research.environment.harbor {jsonl} --task-root {out_dir} --limit 3 # oracle check") + return 0 if converted else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/async_rl_research/environment/convert2slime/swe_gym.py b/async_rl_research/environment/convert2slime/swe_gym.py new file mode 100644 index 0000000000..a4707e7f1e --- /dev/null +++ 
b/async_rl_research/environment/convert2slime/swe_gym.py @@ -0,0 +1,174 @@ +"""Translate SWE-Gym / SWE-Gym-Lite rows into slime prompt data. + +Schema pair of ``env/swe_gym.py`` (SWE rows carry no ``metadata.task_type``: +the default env). Output is one JSON object per line: + + { + "prompt": [{"role": "user", "content": "..."}], + "label": "owner__repo-123", + "metadata": { + "instance_id": "...", + "image": "...", + "workdir": "/testbed", + "problem_statement": "...", + "eval_cmd": "echo ... | base64 -d > /tmp/swegym_eval.py && python /tmp/swegym_eval.py", + "pre_commands": ["git checkout -f"] + } + } + +SWE-Gym-specific: derive the prebuilt image name and build a pytest reward +command from ``test_patch`` + F2P/P2P tests. +""" + +from __future__ import annotations + +import argparse +import base64 +import json +import os +from collections.abc import Iterable, Iterator +from itertools import islice +from pathlib import Path +from typing import Any + +HF_DATASET = "SWE-Gym/SWE-Gym" +HF_DATASET_LITE = "SWE-Gym/SWE-Gym-Lite" +IMAGE_PREFIX = os.environ.get("SWE_GYM_IMAGE_PREFIX", "docker.io/xingyaoww") +IMAGE_TAG = os.environ.get("SWE_GYM_IMAGE_TAG", "latest") +WORKDIR = "/testbed" + + +def image_for(instance_id: str) -> str: + name = "sweb.eval.x86_64."
+ instance_id.replace("__", "_s_") + return f"{IMAGE_PREFIX.rstrip('/')}/{name}:{IMAGE_TAG}" + + +def as_list(value: Any) -> list[str]: + if isinstance(value, str): + value = json.loads(value) if value.strip() else [] + return [str(item) for item in (value or [])] + + +def build_eval_cmd(test_patch: str, tests: list[str]) -> str: + patch_b64 = base64.b64encode((test_patch or "").encode()).decode("ascii") + script = "\n".join( + [ + "import base64", + "import pathlib", + "import subprocess", + "import sys", + "", + "PATCH_B64 = " + repr(patch_b64), + "TESTS = " + json.dumps(tests), + "patch_path = pathlib.Path('/tmp/swegym_test.patch')", + "patch_path.write_bytes(base64.b64decode(PATCH_B64))", + "if patch_path.stat().st_size:", + " commands = [", + " ['git', 'apply', '-v', str(patch_path)],", + " ['git', 'apply', '--3way', str(patch_path)],", + " ['patch', '-p1', '--no-backup-if-mismatch', '-i', str(patch_path)],", + " ]", + " for command in commands:", + " result = subprocess.run(command)", + " if result.returncode == 0:", + " break", + " else:", + " sys.exit(result.returncode)", + "", + "import pytest", + "sys.exit(pytest.main(['--no-header', '-rN', '-p', 'no:cacheprovider', *TESTS]))", + "", + ] + ) + script_b64 = base64.b64encode(script.encode()).decode("ascii") + return f"echo {script_b64} | base64 -d > /tmp/swegym_eval.py && python /tmp/swegym_eval.py" + + +def translate(raw: dict[str, Any]) -> dict[str, Any] | None: + instance_id = raw.get("instance_id") + if not instance_id: + return None + + tests = as_list(raw.get("FAIL_TO_PASS")) + as_list(raw.get("PASS_TO_PASS")) + if not tests: + return None + + problem = raw.get("problem_statement") or "" + metadata: dict[str, Any] = { + "instance_id": instance_id, + "image": image_for(instance_id), + "workdir": WORKDIR, + "problem_statement": problem, + "eval_cmd": build_eval_cmd(raw.get("test_patch") or "", tests), + "repo": raw.get("repo"), + "base_commit": raw.get("base_commit"), + "version": raw.get("version"), + 
} + if base_commit := raw.get("base_commit"): + metadata["pre_commands"] = [f"git checkout {base_commit} -f"] + + # prompt as a single-message list, NOT a raw string: slime's Dataset asserts + # a list when a HF processor loads for hf_checkpoint. Harmless: the agent + # reads problem_statement from metadata, never sample.prompt. + return { + "prompt": [{"role": "user", "content": problem}], + "label": instance_id, + "metadata": metadata, + } + + +def iter_jsonl(path: str | Path) -> Iterator[dict[str, Any]]: + with open(path, encoding="utf-8") as fh: + for line in fh: + line = line.strip() + if line: + yield json.loads(line) + + +def load_hf(split: str, *, lite: bool, limit: int | None) -> Iterator[dict[str, Any]]: + try: + from datasets import load_dataset + except ImportError as exc: + raise SystemExit("Install `datasets` to pull SWE-Gym from HuggingFace.") from exc + + dataset = HF_DATASET_LITE if lite else HF_DATASET + for index, row in enumerate(load_dataset(dataset, split=split)): + if limit is not None and index >= limit: + break + yield dict(row) + + +def write_jsonl(rows: Iterable[dict[str, Any]], out_path: str | Path) -> int: + count = 0 + with open(out_path, "w", encoding="utf-8") as fh: + for raw in rows: + row = translate(raw) + if row is None: + continue + fh.write(json.dumps(row, ensure_ascii=False) + "\n") + count += 1 + return count + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser(description="Translate SWE-Gym data to slime prompt JSONL.") + parser.add_argument("--out", required=True, help="output JSONL path") + parser.add_argument("--input", help="downloaded SWE-Gym JSONL to convert") + parser.add_argument("--split", default="train", help="HuggingFace split when --input is omitted") + parser.add_argument("--lite", action="store_true", help="use SWE-Gym-Lite when --input is omitted") + parser.add_argument("--limit", type=int, help="maximum rows to read") + args = parser.parse_args(argv) + + if args.input: + 
raw_rows: Iterable[dict[str, Any]] = iter_jsonl(args.input) + if args.limit is not None: + raw_rows = islice(raw_rows, args.limit) + else: + raw_rows = load_hf(args.split, lite=args.lite, limit=args.limit) + count = write_jsonl(raw_rows, args.out) + print(f"wrote {count} rows -> {args.out}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/async_rl_research/environment/harbor.py b/async_rl_research/environment/harbor.py new file mode 100644 index 0000000000..5092d98ea1 --- /dev/null +++ b/async_rl_research/environment/harbor.py @@ -0,0 +1,494 @@ +"""Harbor env: run harbor-format tasks (USACO, ...) as RL episodes on Modal. + +Schema pair of ``env/convert2slime/harbor.py`` (see ``base.py``); the converter +bakes everything into ``metadata`` so this never reads ``task.toml`` at rollout. + +Episode (harbor "shared" verifier semantics): boot sandbox, then per step write +the instruction, run the agent leg against the shared session, verify IN-PLACE +(upload tests/, run test.sh, parse /logs/verifier/reward.{json,txt}), and gate +on min_reward. Per-step rewards aggregate (mean | final) to a scalar. Tests are +uploaded only AFTER the agent leg so the agent can't read them. + +Rollout needs ``ASYNC_RL_TASK_ROOT`` (dir relative ``task_path``s resolve +against). 
Oracle check (no model): + + python -m async_rl_research.environment.harbor out/usaco.jsonl --task-root out --limit 3 +""" + +from __future__ import annotations + +import json +import logging +import os +import re +import shlex +import time +from pathlib import Path +from typing import Any + +from ..modal_sandbox import DockerfileImage, ModalSandbox +from ..profiles.profiling import PhaseTimer +from .base import EnvMetadataError, RewardResult, RolloutEnv, coerce_prompt + +logger = logging.getLogger(__name__) + +TASK_ROOT_ENV = "ASYNC_RL_TASK_ROOT" + +# Per-task task.toml timeouts override the env defaults only when this is set; +# default (off) keeps the env vars (boot/agent/eval) the single source of truth. +TASK_TIMEOUT_OVERRIDE = os.environ.get("ASYNC_RL_TASK_TIMEOUT_OVERRIDE", "0").strip().lower() in ("1", "true", "yes") + +# ${VAR} / ${VAR:-default} templates in verifier env values, resolved against +# the HEAD's os.environ (unresolvable -> skipped with a warning). +_ENV_TEMPLATE = re.compile(r"^\$\{([A-Za-z_][A-Za-z0-9_]*)(?::-(.*))?\}$") + + +def _effective(env_val: int, task_val: Any) -> int: + """Task.toml timeout wins over the env default only under TASK_TIMEOUT_OVERRIDE.""" + return int(task_val) if (TASK_TIMEOUT_OVERRIDE and task_val) else int(env_val) + + +def _resolve_env_templates(env: dict[str, str] | None) -> dict[str, str]: + resolved: dict[str, str] = {} + for key, value in (env or {}).items(): + m = _ENV_TEMPLATE.fullmatch(str(value)) + if not m: + resolved[key] = str(value) + continue + name, default = m.group(1), m.group(2) + actual = os.environ.get(name, default) + if actual is None: + logger.warning("[harbor] env %s=${%s} unresolvable on this host; skipping", key, name) + continue + resolved[key] = actual + return resolved + + +def _scalar_reward(rewards: dict[str, Any] | None) -> float: + """Harbor 1D convention: the 'reward' key; a single-entry dict counts too.""" + if not rewards: + return 0.0 + if "reward" in rewards: + return 
float(rewards["reward"]) + if len(rewards) == 1: + return float(next(iter(rewards.values()))) + logger.warning("[harbor] multi-key rewards %s without 'reward' key; scalar=0", sorted(rewards)) + return 0.0 + + +def _meets_min_reward(rewards: dict[str, Any] | None, min_reward: float | dict[str, float] | None) -> bool: + """Harbor's step gate: missing rewards/keys are treated as -inf.""" + if min_reward is None: + return True + if isinstance(min_reward, dict): + return all( + rewards is not None and key in rewards and float(rewards[key]) >= float(v) for key, v in min_reward.items() + ) + return rewards is not None and "reward" in rewards and float(rewards["reward"]) >= float(min_reward) + + +class HarborEnv(RolloutEnv): + name = "harbor" + # No agent_config default: harbor instruction.md files carry their own + # deliverable contract. Override per-row via metadata.agent_config. + + # Row schema (written by env/convert2slime/harbor.py -- keep in sync) + def normalize_metadata(self, sample) -> dict[str, Any]: + m = sample.metadata or {} + task_path = m.get("task_path") + if not task_path: + raise EnvMetadataError("missing_task_path") + task_dir = self._resolve_task_dir(task_path) + + docker_image = m.get("docker_image") + dockerfile = m.get("dockerfile") + if not docker_image and not dockerfile: + raise EnvMetadataError("missing_docker_image_or_dockerfile") + if dockerfile and not (task_dir / dockerfile).is_file(): + raise EnvMetadataError(f"dockerfile_missing:{task_dir / dockerfile}") + + steps = m.get("steps") or None + if steps: + for step in steps: + if not (task_dir / step["tests_path"] / "test.sh").is_file(): + raise EnvMetadataError(f"tests_missing:{step['tests_path']}") + elif not (task_dir / "tests" / "test.sh").is_file(): + raise EnvMetadataError("tests_missing:tests") + + return { + "instance_id": m.get("instance_id") or sample.label or task_dir.name, + "task_dir": str(task_dir), + "docker_image": docker_image, + "dockerfile": dockerfile, + "workdir": 
m.get("workdir"), # None -> detected from the sandbox + "problem_statement": m.get("problem_statement") or coerce_prompt(sample.prompt), + "agent_timeout_sec": m.get("agent_timeout_sec"), + "build_timeout_sec": m.get("build_timeout_sec"), + "verifier": m.get("verifier") or {}, + "steps": steps, + "reward_strategy": m.get("reward_strategy"), + "cpus": m.get("cpus"), + "memory_mb": m.get("memory_mb"), + "agent_config": m.get("agent_config"), + } + + def effective_budgets( + self, md: dict[str, Any], *, agent_time_budget_sec: int, eval_timeout_sec: int + ) -> dict[str, int]: + return { + "boot_sec": _effective(ModalSandbox._boot_timeout_from_env(), md.get("build_timeout_sec")), + "agent_sec": _effective(agent_time_budget_sec, md.get("agent_timeout_sec")), + "eval_sec": _effective(eval_timeout_sec, (md.get("verifier") or {}).get("timeout_sec")), + } + + @staticmethod + def _resolve_task_dir(task_path: str) -> Path: + p = Path(task_path) + if not p.is_absolute(): + root = os.environ.get(TASK_ROOT_ENV) + if not root: + raise EnvMetadataError(f"{TASK_ROOT_ENV}_unset") + p = Path(root) / p + if not p.is_dir(): + raise EnvMetadataError(f"task_dir_missing:{p}") + return p + + # Episode + def _image(self, md: dict[str, Any]) -> str | DockerfileImage: + if md["docker_image"]: + return md["docker_image"] + path = Path(md["task_dir"]) / md["dockerfile"] + return DockerfileImage(path=str(path), context_dir=str(path.parent)) + + def _sandbox(self, md: dict[str, Any]) -> ModalSandbox: + kwargs: dict[str, Any] = {} + if md["cpus"]: + kwargs["cpu"] = float(md["cpus"]) + if md["memory_mb"]: + kwargs["memory_mb"] = int(md["memory_mb"]) + if md["workdir"]: + kwargs["workdir"] = md["workdir"] + if TASK_TIMEOUT_OVERRIDE and md.get("build_timeout_sec"): + kwargs["boot_timeout"] = int(md["build_timeout_sec"]) + return ModalSandbox(self._image(md), **kwargs) + + def _step_specs(self, md: dict[str, Any]) -> list[dict[str, Any]]: + """Uniform step list; a single-step task becomes one 
pseudo-step.""" + if md["steps"]: + return md["steps"] + return [ + { + "name": None, + "instruction": md["problem_statement"], + "tests_path": "tests", + "verifier": md["verifier"], + "min_reward": None, + "agent_timeout_sec": None, + } + ] + + async def rollout( + self, + md: dict[str, Any], + *, + runtime, + session_id: str, + adapter_url: str, + agent_time_budget_sec: int, + eval_timeout_sec: int, + ) -> RewardResult: + async def agent_leg(sb, leg_md: dict[str, Any], budget_sec: int): + return await runtime.run_agent( + sb, + md=leg_md, + session_id=session_id, + adapter_url=adapter_url, + time_budget_sec=budget_sec, + ) + + return await self._episode( + md, + run_leg=agent_leg, + agent_time_budget_sec=agent_time_budget_sec, + eval_timeout_sec=eval_timeout_sec, + ) + + async def _episode( + self, + md: dict[str, Any], + *, + run_leg, + agent_time_budget_sec: int, + eval_timeout_sec: int, + ) -> RewardResult: + """Shared by the RL rollout and the oracle check: only the leg differs.""" + task_dir = Path(md["task_dir"]) + steps = self._step_specs(md) + step_results: list[dict[str, Any]] = [] + # Last agent leg's exit code + failure tail (stays None for the oracle + # leg, which runs solve.sh rather than the agent). Surfaced in extra so + # a zero-turn adapter_session_empty self-explains in the rollout dump. + last_agent = None + timer = PhaseTimer() + + t0 = time.monotonic() + async with self._sandbox(md) as sb: + timer.record("work_boot", time.monotonic() - t0) + workdir = md["workdir"] or await self._detect_workdir(sb) + q = shlex.quote + # Test scripts assume /logs/{agent,verifier,artifacts} exist. 
+ with timer.phase("prep"): + await sb.exec( + f"mkdir -p {q(workdir)} /logs/agent /logs/verifier /logs/artifacts", + check=True, + timeout=60, + ) + + # Start the agent clock only once the sandbox is booted and prepped: + # a cold per-instance image pull can take many minutes, and charging + # it against the agent budget would exhaust the window before any step + # runs (-> zero agent turns -> adapter_session_empty). Provisioning on + # the first leg likewise gets its own clock inside _detached_run. + deadline = time.monotonic() + _effective(agent_time_budget_sec, md.get("agent_timeout_sec")) + + for step in steps: + remaining = int(deadline - time.monotonic()) + if remaining <= 0: + logger.warning( + "[harbor] %s: agent budget exhausted before step %r", md["instance_id"], step["name"] + ) + break + budget = remaining + if TASK_TIMEOUT_OVERRIDE and step.get("agent_timeout_sec"): + budget = min(budget, int(step["agent_timeout_sec"])) + + with timer.phase("prep"): + await self.write_problem_file(sb, workdir, step["instruction"]) + leg_md = {**md, "workdir": workdir} + with timer.phase("agent"): + leg_result = await run_leg(sb, leg_md, budget) + if leg_result is not None: + last_agent = leg_result + + with timer.phase("verifier"): + rewards = await self._verify( + sb, + tests_dir=task_dir / step["tests_path"], + workdir=workdir, + verifier={**md["verifier"], **(step.get("verifier") or {})}, + eval_timeout_sec=eval_timeout_sec, + instance_id=md["instance_id"], + ) + step_results.append({"name": step["name"], "rewards": rewards, "reward": _scalar_reward(rewards)}) + if not _meets_min_reward(rewards, step.get("min_reward")): + logger.info( + "[harbor] %s: step %r below min_reward; aborting remaining steps", + md["instance_id"], + step["name"], + ) + break + + reward = self._aggregate(steps, step_results, md["reward_strategy"]) + extra: dict[str, Any] = { + "harbor_step_results": step_results, + "harbor_steps_completed": len(step_results), + "harbor_steps_total": 
len(steps), + "timing": timer.as_dict(), + } + if last_agent is not None: + extra["agent_exit_code"] = last_agent.exit_code + extra["agent_tail"] = last_agent.tail + return RewardResult( + reward=reward, + # epsilon: weighted pytest fractions sum to 0.999... for a fully- + # passing task (seen on openthoughts-tblite bash-log-processor-fix) + is_solved=reward >= 1.0 - 1e-6, + extra=extra, + ) + + @staticmethod + def _aggregate(steps: list[dict], results: list[dict], strategy: str | None) -> float: + """Scalar episode reward from per-step scalars. + + 'mean' divides by ALL declared steps (gated steps count 0; stricter than + harbor's job-level mean, but the conservative signal is what RL wants). + 'final' is the last declared step's reward (0 if never reached). + """ + if not results: + return 0.0 + if len(steps) == 1: + return results[0]["reward"] + if (strategy or "mean") == "final": + return results[-1]["reward"] if len(results) == len(steps) else 0.0 + return sum(r["reward"] for r in results) / len(steps) + + @staticmethod + async def _detect_workdir(sb) -> str: + # Prebuilt docker_image rows may not know their WORKDIR; ask the sandbox. + ec, out, _ = await sb.exec("pwd", check=False, timeout=30) + detected = (out or "").strip().splitlines()[-1] if ec == 0 and (out or "").strip() else "" + return detected or "/app" + + # In-place verification (harbor's shared-environment Verifier semantics) + async def _verify( + self, + sb, + *, + tests_dir: Path, + workdir: str, + verifier: dict[str, Any], + eval_timeout_sec: int, + instance_id: str, + ) -> dict[str, Any] | None: + """Upload tests, run test.sh, parse the reward files. 
None = no verdict.""" + timeout = _effective(eval_timeout_sec, verifier.get("timeout_sec")) + + await self.upload_dir(sb, tests_dir, "/tests") + q = shlex.quote + await sb.exec( + "chmod +x /tests/test.sh && rm -f /logs/verifier/reward.json /logs/verifier/reward.txt", + check=False, + timeout=60, + ) + env = _resolve_env_templates(verifier.get("env")) + ec, out, err = await sb.exec( + f"cd {q(workdir)} && bash /tests/test.sh", + env=env or None, + timeout=timeout, + check=False, + ) + if os.environ.get("HARBOR_VERIFY_DEBUG"): + logger.info( + "[harbor-verify-debug] %s: test.sh exit=%s\n--- stdout tail ---\n%s\n--- stderr tail ---\n%s", + instance_id, + ec, + (out or "")[-3000:], + (err or "")[-3000:], + ) + + raw_json = await sb.read_file("/logs/verifier/reward.json") + if raw_json.strip(): + try: + return dict(json.loads(raw_json)) + except (ValueError, TypeError): + logger.warning("[harbor] %s: unparseable reward.json: %.200s", instance_id, raw_json) + return None + raw_txt = await sb.read_file("/logs/verifier/reward.txt") + if raw_txt.strip(): + try: + return {"reward": float(raw_txt.strip())} + except ValueError: + logger.warning("[harbor] %s: unparseable reward.txt: %.200s", instance_id, raw_txt) + return None + # No reward file: terminal-bench-style tasks end test.sh with a bare + # pytest run. Grade all-or-nothing on its exit code (0 pass, 1 fail); + # anything else stays "no verdict" so infra breakage isn't scored. 
+ if ec in (0, 1): + logger.info("[harbor] %s: no reward file; graded from test.sh exit=%d", instance_id, ec) + return {"reward": 1.0 if ec == 0 else 0.0, "graded_from": "exit_code"} + logger.warning( + "[harbor] %s: test.sh exit=%s wrote no reward file; stderr tail: %s", + instance_id, + ec, + (err or "")[-400:], + ) + return None + + # Oracle check (reference solution through the exact rollout path) + async def oracle_episode( + self, md: dict[str, Any], *, solve_timeout_sec: int, eval_timeout_sec: int + ) -> RewardResult: + """Replace the agent leg with the task's solution/solve.sh (a counter + maps each sequential leg to its step).""" + task_dir = Path(md["task_dir"]) + steps = self._step_specs(md) + state = {"i": 0} + + async def leg(sb, leg_md: dict[str, Any], budget_sec: int) -> None: + step = steps[state["i"]] + state["i"] += 1 + solution_dir = task_dir / "steps" / step["name"] / "solution" if step["name"] else task_dir / "solution" + if not (solution_dir / "solve.sh").is_file(): + solution_dir = task_dir / "solution" + if not (solution_dir / "solve.sh").is_file(): + raise FileNotFoundError(f"no solution/solve.sh under {task_dir}") + await self.upload_dir(sb, solution_dir, "/solution") + q = shlex.quote + await sb.exec("chmod +x /solution/solve.sh", check=False, timeout=30) + ec, out, err = await sb.exec( + f"cd {q(leg_md['workdir'])} && bash /solution/solve.sh", + timeout=min(budget_sec, solve_timeout_sec), + check=False, + ) + if ec != 0: + logger.warning("[harbor-oracle] solve.sh exit=%d stderr tail: %s", ec, (err or "")[-400:]) + if os.environ.get("HARBOR_VERIFY_DEBUG"): + logger.info( + "[harbor-oracle-debug] %s: solve.sh exit=%s\n--- stdout tail ---\n%s\n--- stderr tail ---\n%s", + md["instance_id"], + ec, + (out or "")[-3000:], + (err or "")[-3000:], + ) + + return await self._episode( + md, + run_leg=leg, + agent_time_budget_sec=solve_timeout_sec * max(1, len(steps)), + eval_timeout_sec=eval_timeout_sec, + ) + + +def _oracle_main() -> int: + import 
argparse + import asyncio + from types import SimpleNamespace + + parser = argparse.ArgumentParser( + description="Run harbor reference solutions through the rollout path (reward should be 1.0)." + ) + parser.add_argument("jsonl", help="converted slime prompt JSONL (env/convert2slime/harbor.py output)") + parser.add_argument("--task-root", help=f"task root (default: ${TASK_ROOT_ENV} or the JSONL's directory)") + parser.add_argument("--limit", type=int, default=1, help="how many rows to check (default 1)") + parser.add_argument("--index", type=int, help="check exactly this row") + parser.add_argument("--solve-timeout", type=int, default=600) + parser.add_argument("--eval-timeout", type=int, default=int(os.environ.get("AGENT_EVAL_TIMEOUT_SEC", "600"))) + parser.add_argument( + "--vm-runtime", + action="store_true", + help="boot VM sandboxes (experimental_options={'vm_runtime': True}) instead of gVisor", + ) + args = parser.parse_args() + if args.vm_runtime: + os.environ["SLIME_AGENT_SANDBOX_VM_RUNTIME"] = "1" + + root = args.task_root or os.environ.get(TASK_ROOT_ENV) or str(Path(args.jsonl).resolve().parent) + os.environ[TASK_ROOT_ENV] = root + + rows = [] + with open(args.jsonl, encoding="utf-8") as fh: + for line in fh: + if line.strip(): + rows.append(json.loads(line)) + picked = [rows[args.index]] if args.index is not None else rows[: args.limit] + + env = HarborEnv() + failures = 0 + for row in picked: + sample = SimpleNamespace(metadata=row.get("metadata"), prompt=row.get("prompt"), label=row.get("label")) + md = env.normalize_metadata(sample) + t0 = time.monotonic() + result = asyncio.run( + env.oracle_episode(md, solve_timeout_sec=args.solve_timeout, eval_timeout_sec=args.eval_timeout) + ) + status = "OK " if result.is_solved else "FAIL" + print( + f"[{status}] {md['instance_id']}: reward={result.reward:.2f} t={time.monotonic() - t0:.0f}s {result.extra}" + ) + failures += 0 if result.is_solved else 1 + return 1 if failures else 0 + + +if __name__ == 
"__main__":
+    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
+    raise SystemExit(_oracle_main())
diff --git a/async_rl_research/environment/swe_gym.py b/async_rl_research/environment/swe_gym.py
new file mode 100644
index 0000000000..9816350b04
--- /dev/null
+++ b/async_rl_research/environment/swe_gym.py
@@ -0,0 +1,231 @@
+"""SWE-Gym env: git-diff capture + clean-sandbox grading on Modal.
+
+Schema pair of ``env/convert2slime/swe_gym.py`` (see ``base.py``). Grading is
+diff-transplant: boot work sandbox -> pre_commands + problem file -> agent runs
+-> capture git diff -> CLOSE work sandbox -> boot CLEAN sandbox -> re-apply
+pre_commands -> apply diff -> run eval_cmd. The evaluator never sees the agent's
+filesystem (only the captured diff affects reward), so tests can't be cheated.
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+import shlex
+import time
+from pathlib import Path
+from typing import Any
+
+from slime.agent.sandbox import Sandbox
+
+from ..modal_sandbox import ModalSandbox
+from ..profiles.profiling import PhaseTimer
+from .base import PROBLEM_FILE, EnvMetadataError, RewardResult, RolloutEnv, coerce_prompt
+
+logger = logging.getLogger(__name__)
+
+
+# Patch/pre scripts live under /tmp, outside the diff's reach.
+_PATCH = "/tmp/__swe_patch__.diff"
+_PRE = "/tmp/__swe_pre__.sh"
+
+
+# Appended to the problem statement: the universal scaffold has no submission
+# protocol, so spell out the deliverable (reward is the working-tree `git diff`).
+_DELIVERABLE_SUFFIX = """
+
+## Deliverable
+
+Fix the issue by editing the repository's source files in place.
+
+- Your work is collected as the uncommitted working-tree changes (`git diff`) of this repository when you finish: leave your edits uncommitted.
+- Do NOT commit your changes and do NOT create patch files.
+- Do NOT modify tests or configuration files (pyproject.toml, setup.cfg, etc.).
+- Delete any reproduction scripts or scratch files you created before finishing.
+"""
+
+
+class SweGymEnv(RolloutEnv):
+    name = "swe_gym"
+    # No agent_config default: the universal scaffold + _DELIVERABLE_SUFFIX
+    # apply. Override per-row via metadata.agent_config; globally via MSWE_CONFIG.
+
+    def normalize_metadata(self, sample) -> dict[str, Any]:
+        m = sample.metadata or {}
+        label = sample.label if (isinstance(sample.label, str) and len(sample.label) < 256) else None
+        md = {
+            "instance_id": m.get("instance_id") or label or "unknown",
+            "image": m.get("image"),
+            "workdir": m.get("workdir"),
+            "problem_statement": m.get("problem_statement") or coerce_prompt(sample.prompt),
+            "swepro": m.get("swepro"),
+            "eval_cmd": m.get("eval_cmd"),
+            "pre_commands": m.get("pre_commands"),
+            "agent_config": m.get("agent_config"),
+        }
+        if not md["image"] or not md["workdir"]:
+            raise EnvMetadataError("missing_image_or_workdir")
+        return md
+
+    async def rollout(
+        self,
+        md: dict[str, Any],
+        *,
+        runtime,
+        session_id: str,
+        adapter_url: str,
+        agent_time_budget_sec: int,
+        eval_timeout_sec: int,
+    ) -> RewardResult:
+        workdir = md["workdir"]
+        timer = PhaseTimer()
+        t0 = time.monotonic()
+        async with ModalSandbox(md["image"]) as sb:
+            timer.record("work_boot", time.monotonic() - t0)
+            with timer.phase("prep"):
+                await self._prepare_workspace(sb, md)
+            with timer.phase("agent"):
+                agent_run = await runtime.run_agent(
+                    sb,
+                    md=md,
+                    session_id=session_id,
+                    adapter_url=adapter_url,
+                    time_budget_sec=agent_time_budget_sec,
+                )
+            with timer.phase("diff"):
+                diff_text = await self._git_diff(sb, workdir, exclude=runtime.diff_exclude_all)
+
+        # Work sandbox is closed; grade the diff in a clean one.
+        with timer.phase("eval"):
+            reward, is_solved, applied = await self._evaluate(md, diff_text, timeout_sec=eval_timeout_sec, timer=timer)
+        return RewardResult(
+            reward=float(reward),
+            is_solved=bool(is_solved),
+            # diff_bytes/diff_files are SIZE metrics; patch text never stored.
+            extra={
+                "applied_cleanly": bool(applied),
+                "diff_bytes": len(diff_text),
+                "diff_files": diff_text.count("diff --git"),
+                "timing": timer.as_dict(),
+                "agent_exit_code": agent_run.exit_code,
+                "agent_tail": agent_run.tail,
+            },
+        )
+
+    # ------------------------------------------------------------------
+    # Workspace prep (work sandbox; task-side, agent-agnostic)
+    # ------------------------------------------------------------------
+    async def _prepare_workspace(self, sb: Sandbox, md: dict[str, Any]) -> None:
+        """Bring a freshly booted work sandbox to the task's start state.
+
+        ``pre_commands`` (typically ``git checkout -f``) run in BOTH
+        work and eval sandboxes, else the diff context mismatches the eval base.
+        """
+        # In-sandbox git ops need the repo marked safe for root.
+        await sb.exec("git config --system --add safe.directory '*'", check=False, timeout=60)
+        if md["pre_commands"]:
+            await _apply_pre_commands(sb, md["workdir"], md["pre_commands"])
+        await self.write_problem_file(sb, md["workdir"], (md["problem_statement"] or "") + _DELIVERABLE_SUFFIX)
+
+    # ------------------------------------------------------------------
+    # Diff capture
+    # ------------------------------------------------------------------
+    async def _git_diff(self, sb: Sandbox, workdir: str, *, exclude: tuple[str, ...] = ()) -> str:
+        """Capture the model's edits as a patch (``git add -N .`` so new files
+        appear), excluding ``PROBLEM_FILE`` + the runtime's ``diff_exclude_all``.
+        """
+        excludes = " ".join(f"':(exclude){f}'" for f in (PROBLEM_FILE, *exclude))
+        cmd = f"cd {shlex.quote(workdir)} && git add -N . && git diff -- . {excludes}"
+        _, out, _ = await sb.exec(cmd, user="root", timeout=120, check=False)
+        return out
+
+    # ------------------------------------------------------------------
+    # Eval (fresh clean sandbox, apply diff, run dataset tests)
+    # ------------------------------------------------------------------
+    async def _evaluate(
+        self, md: dict[str, Any], diff_text: str, *, timeout_sec: int, timer: PhaseTimer | None = None
+    ) -> tuple[float, bool, bool]:
+        """Grade ``diff_text`` in a CLEAN sandbox; returns (reward, solved, applied)."""
+        if not (md["swepro"] or md["eval_cmd"]):
+            logger.warning("[swe_gym.evaluate] no swepro/eval_cmd; reward=0")
+            return 0.0, False, True
+
+        workdir = md["workdir"]
+        t0 = time.monotonic()
+        async with ModalSandbox(md["image"]) as ev:
+            if timer is not None:
+                timer.record("eval_boot", time.monotonic() - t0)
+            if md["pre_commands"]:
+                await _apply_pre_commands(ev, workdir, md["pre_commands"])
+
+            applied = await _apply_diff(ev, workdir, diff_text)
+            if not applied:
+                return 0.0, False, False
+
+            if md["swepro"]:
+                reward, solved = await _run_swepro(ev, workdir, md["swepro"], timeout_sec)
+                return reward, solved, True
+            reward, solved = await _run_eval_cmd(ev, workdir, md["eval_cmd"], timeout_sec)
+            return reward, solved, True
+
+
+async def _apply_pre_commands(sb: Sandbox, workdir: str, pre: list[str] | str) -> None:
+    body = pre.replace("\\n", "\n") if isinstance(pre, str) else "\n".join(c for c in (pre or []) if c)
+    await sb.write_file(_PRE, "set -e\n" + body)
+    await sb.exec(f"cd {shlex.quote(workdir)} && bash {shlex.quote(_PRE)}", check=False, timeout=600)
+
+
+async def _apply_diff(sb: Sandbox, workdir: str, diff_text: str) -> bool:
+    if not diff_text.strip():
+        return True
+    await sb.write_file(_PATCH, diff_text)
+    wq = shlex.quote(workdir)
+    pq = shlex.quote(_PATCH)
+    for cmd in (
+        f"cd {wq} && git apply --3way --whitespace=nowarn {pq}",
+        f"cd {wq} && git apply --whitespace=nowarn {pq}",
+        f"cd {wq} && patch -p1 --no-backup-if-mismatch < {pq}",
+    ):
+        ec, _, _ = await sb.exec(cmd, check=False, timeout=120)
+        if ec == 0:
+            return True
+    return False
+
+
+async def _run_eval_cmd(sb: Sandbox, workdir: str, cmd: str, timeout: int) -> tuple[float, bool]:
+    # SWE-Gym-Lite's self-contained command whose exit code is the verdict.
+    ec, _, _ = await sb.exec(f"cd {shlex.quote(workdir)} && {cmd}", check=False, timeout=timeout)
+    return (1.0 if ec == 0 else 0.0), ec == 0
+
+
+async def _run_swepro(sb: Sandbox, workdir: str, swepro: dict, timeout: int) -> tuple[float, bool]:
+    # Forward-compat pass-through for swepro-style run/parse grading.
+    swepro_dir = "/tmp/swepro_eval"
+    await sb.exec(f"mkdir -p {swepro_dir} && chmod 777 {swepro_dir}", check=True, timeout=30)
+    for key, dst in (("run_script_path", "run_script.sh"), ("parser_script_path", "parser.py")):
+        host_path = swepro.get(key)
+        if host_path:
+            await sb.write_file(f"{swepro_dir}/{dst}", Path(host_path).read_text())
+    await sb.exec(f"chmod -R 755 {swepro_dir}", check=False, timeout=30)
+
+    test_arg = ",".join(swepro.get("selected_test_files") or [])
+    stdout_f = f"{swepro_dir}/stdout.log"
+    stderr_f = f"{swepro_dir}/stderr.log"
+    result_f = f"{swepro_dir}/result.json"
+    await sb.exec(
+        f"cd {shlex.quote(workdir)} && bash {swepro_dir}/run_script.sh "
+        f"{shlex.quote(test_arg)} > {stdout_f} 2> {stderr_f} || true",
+        check=False,
+        timeout=timeout,
+    )
+    await sb.exec(
+        f"python3 {swepro_dir}/parser.py {stdout_f} {stderr_f} {result_f}",
+        check=False,
+        timeout=120,
+    )
+    raw = await sb.read_file(result_f)
+    parsed = json.loads(raw) if raw else {"tests": []}
+    passed = {t["name"] for t in parsed.get("tests", []) if t.get("status") == "PASSED"}
+    required = set(swepro.get("fail_to_pass") or []) | set(swepro.get("pass_to_pass") or [])
+    solved = bool(required) and required.issubset(passed)
+    return (1.0 if solved else 0.0), solved
diff --git a/async_rl_research/evalset.py b/async_rl_research/evalset.py
new file mode 100644
index 0000000000..90864394ad
--- /dev/null
+++ b/async_rl_research/evalset.py
@@ -0,0 +1,169 @@
+"""Build a versioned eval set by subsampling converted slime datasets.
+
+Writes per-subset JSONL files drawn from already-converted datasets plus a
+manifest pinning the chosen instances; point the training config's inline
+``eval_config`` at the subset files.
+
+Spec YAML::
+
+    task_root: /data            # optional: harbor metadata.task_path values are
+                                # rewritten relative to this dir, so ONE
+                                # ASYNC_RL_TASK_ROOT covers train + eval rows.
+                                # Omit to inline absolute task paths instead.
+    subsets:
+      - name: swebench_verified_50
+        source: /data/swebench_verified/swebench_verified.jsonl
+        n: 50                   # omit -> keep all rows
+        seed: 0                 # deterministic subsample (default 0)
+      - name: usaco_hard
+        source: /data/usaco/usaco.jsonl
+        ids: [usaco_829, ...]   # optional instance_id allowlist, applied before n
+
+Usage::
+
+    python -m async_rl_research.evalset spec.yaml --out-dir /data/evalsets/v0
+
+Outputs per-subset ``.jsonl``, ``manifest.json``, and
+``eval_config.yaml``, and prints the inline ``eval_config`` dict. Spec paths
+must be the *runtime* paths (e.g. ``/data/...``); run the builder where they
+resolve so harbor task-dir checks mean something.
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import random
+import sys
+from pathlib import Path
+from typing import Any
+
+
+def _load_spec(path: Path) -> dict[str, Any]:
+    import yaml
+
+    spec = yaml.safe_load(path.read_text())
+    if not isinstance(spec, dict) or not isinstance(spec.get("subsets"), list) or not spec["subsets"]:
+        raise SystemExit(f"{path}: spec must be a mapping with a non-empty `subsets` list")
+    names = [s.get("name") for s in spec["subsets"]]
+    if any(not n for n in names) or len(set(names)) != len(names):
+        raise SystemExit(f"{path}: every subset needs a unique `name` (got {names})")
+    for s in spec["subsets"]:
+        if not s.get("source"):
+            raise SystemExit(f"{path}: subset {s.get('name')!r} is missing `source`")
+        unknown = set(s) - {"name", "source", "n", "seed", "ids"}
+        if unknown:
+            raise SystemExit(f"{path}: subset {s['name']!r} has unknown keys {sorted(unknown)}")
+    return spec
+
+
+def _instance_id(row: dict[str, Any], index: int) -> str:
+    return (row.get("metadata") or {}).get("instance_id") or row.get("label") or f"row-{index}"
+
+
+def _rewrite_task_path(row: dict[str, Any], source_dir: Path, task_root: Path | None, problems: list[str]) -> None:
+    """Re-root a harbor row's relative task_path so it stays resolvable: pin it
+    relative to ``task_root`` when given, absolute otherwise.
+    """
+    md = row.get("metadata") or {}
+    if md.get("task_type") != "harbor" or not md.get("task_path"):
+        return
+    task_dir = Path(md["task_path"])
+    if not task_dir.is_absolute():
+        task_dir = source_dir / task_dir
+    if task_root is not None:
+        try:
+            md["task_path"] = str(task_dir.relative_to(task_root))
+        except ValueError:
+            problems.append(f"task dir {task_dir} is outside task_root {task_root}; kept absolute")
+            md["task_path"] = str(task_dir)
+    else:
+        md["task_path"] = str(task_dir)
+    if not task_dir.is_dir():
+        problems.append(f"task dir not found: {task_dir}")
+
+
+def _build_subset(subset: dict[str, Any], out_dir: Path, task_root: Path | None, strict: bool) -> dict[str, Any]:
+    source = Path(subset["source"])
+    rows = [json.loads(line) for line in source.read_text().splitlines() if line.strip()]
+    ids = [_instance_id(row, i) for i, row in enumerate(rows)]
+
+    if allow := subset.get("ids"):
+        missing = set(allow) - set(ids)
+        if missing:
+            raise SystemExit(f"subset {subset['name']!r}: ids not in {source}: {sorted(missing)}")
+        keep = [i for i, iid in enumerate(ids) if iid in set(allow)]
+    else:
+        keep = list(range(len(rows)))
+
+    n = subset.get("n")
+    if n is not None and n < len(keep):
+        keep = sorted(random.Random(subset.get("seed", 0)).sample(keep, n))
+
+    problems: list[str] = []
+    chosen = []
+    for i in keep:
+        row = json.loads(json.dumps(rows[i]))  # deep copy
+        _rewrite_task_path(row, source.parent, task_root, problems)
+        chosen.append(row)
+
+    if problems:
+        for p in problems:
+            print(f" [{subset['name']}] WARNING: {p}", file=sys.stderr)
+        if strict:
+            raise SystemExit(f"subset {subset['name']!r}: {len(problems)} problem(s) with --strict")
+
+    out_path = out_dir / f"{subset['name']}.jsonl"
+    with out_path.open("w") as f:
+        for row in chosen:
+            f.write(json.dumps(row, ensure_ascii=False) + "\n")
+    print(f" {subset['name']}: {len(chosen)}/{len(rows)} rows from {source} -> {out_path}")
+    return {
+        "name": subset["name"],
+        "source": str(source),
+        "n_source_rows": len(rows),
+        "n_rows": len(chosen),
+        "seed": subset.get("seed", 0),
+        "path": str(out_path),
+        "instance_ids": [ids[i] for i in keep],
+    }
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
+    parser.add_argument("spec", type=Path, help="eval-set spec YAML (see module docstring)")
+    parser.add_argument("--out-dir", type=Path, required=True, help="eval-set output dir (one dir per version)")
+    parser.add_argument("--strict", action="store_true", help="fail on missing task dirs instead of warning")
+    args = parser.parse_args()
+
+    spec = _load_spec(args.spec)
+    task_root = Path(spec["task_root"]) if spec.get("task_root") else None
+    # Resolve so manifest / eval_config paths are absolute.
+    args.out_dir = args.out_dir.resolve()
+    args.out_dir.mkdir(parents=True, exist_ok=True)
+
+    built = [_build_subset(s, args.out_dir, task_root, args.strict) for s in spec["subsets"]]
+
+    manifest = {"spec": spec, "subsets": built}
+    (args.out_dir / "manifest.json").write_text(json.dumps(manifest, indent=2) + "\n")
+
+    datasets = [{"name": b["name"], "path": b["path"]} for b in built]
+    import yaml
+
+    (args.out_dir / "eval_config.yaml").write_text(yaml.dump({"eval": {"datasets": datasets}}, sort_keys=False))
+
+    print(f"\nwrote {args.out_dir}/manifest.json and {args.out_dir}/eval_config.yaml")
+    if task_root is not None:
+        print(f"run with ASYNC_RL_TASK_ROOT={task_root} (harbor task_paths are relative to it)")
+    print("\ninline eval_config for the training config:\n")
+    print(" eval_config = {")
+    print(' "defaults": {"n_samples_per_eval_prompt": 1},')
+    print(' "datasets": [')
+    for d in datasets:
+        print(f' {{"name": "{d["name"]}", "path": "{d["path"]}"}},')
+    print(" ],")
+    print(" }")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/async_rl_research/generate.py b/async_rl_research/generate.py
index 7c21afba70..a986b15158 100644
--- a/async_rl_research/generate.py
+++
b/async_rl_research/generate.py @@ -1,90 +1,40 @@ """Generic agentic-RL rollout entrypoint for slime (design A: HTTP adapter). -Wire-up:: - - --custom-generate-function-path async_rl_research.generate.generate - -This is the **agent-agnostic** per-sample orchestrator. It owns the parts that -are identical for any in-sandbox agent and delegates the agent-specific and -sandbox-specific work to two collaborators: - - generate.py (this file) the rollout recipe + adapter/HTTP lifecycle + - trajectory merge + abort/timeout isolation. - agent/.py everything specific to one agent (which adapter, - how to launch it in the sandbox, its prompt / - env wiring). Default: agent/mini_swe_agent.py. - sandbox.py sandbox backend + SWE eval (boot / git_diff / - evaluate). NOT built yet -- contract below. - -Topology (design A -- "in-sandbox subprocess + HTTP adapter"): - - host generate(): - 1. _State (once/worker): build the driver's adapter (an aiohttp app that - speaks the agent's wire API and records exact SGLang tokens) and serve - it on a bg thread; expose adapter_url = http://$SLIME_HEAD_HOST:$PORT. - 2. open an adapter session keyed by session_id. - 3. boot a sandbox; the driver launches the agent inside it as a - subprocess. The agent dials BACK to adapter_url for every model call; - the adapter renders messages -> input_ids, calls SGLang /generate - (return_logprob), and records (prompt_ids, output_ids, logprobs). - 4. capture git diff; score it in a CLEAN sandbox (no test-cheating). - 5. finish_session() drains the recorded token segments; merge -> Sample. - -Reward is computed inline (sandbox.evaluate) and written onto the sample, so -slime's default reward-model step is skipped (generate_and_rm only calls -async_rm when sample.reward is None). 
- ----------------------------------------------------------------------------- -Driver contract (a driver is a *module*; default async_rl_research.agent.mini_swe_agent) ----------------------------------------------------------------------------- - ADAPTER_CLS : type[BaseAdapter] - The slime adapter class for this agent's wire protocol - (OpenAIAdapter for mini-swe-agent / litellm, AnthropicAdapter for - claude-code). Constructed as - ADAPTER_CLS(tokenizer=, sglang_url=, tool_parser=, reasoning_parser=). - - async def run_agent(sb, *, md, session_id, adapter_url, time_budget_sec) -> int - Provision + launch the agent inside the already-booted sandbox `sb`, - wait for it to finish, return an exit code. The agent must send - `session_id` as its auth/bearer so the adapter groups its turns, and - must target `adapter_url` for model calls. `md` is the normalized - dataset row (see _metadata). - ----------------------------------------------------------------------------- -sandbox.py contract (async_rl_research.sandbox -- NOT built yet) ----------------------------------------------------------------------------- - @asynccontextmanager - async def boot_agent_sandbox(image: str) -> AsyncIterator[Sandbox]: ... - - async def git_diff(sb, workdir: str) -> str: ... - - async def evaluate(*, image, workdir, diff_text, swepro=None, eval_cmd=None, - pre_commands=None, timeout_sec=600) -> tuple[float, bool, bool]: - # (reward, solved, applied_cleanly); applies diff in a CLEAN sandbox. +Wire-up: ``--custom-generate-function-path async_rl_research.generate.generate``. + +Agent- and task-agnostic per-sample orchestrator. Owns the parts identical for +any in-sandbox agent on any task family (adapter/HTTP lifecycle, session +management, trajectory merge, abort/timeout isolation) and delegates the rest to +a runtime (``agent/base.py``) and an env (``env/base.py``, picked per row by +metadata.task_type). 
Per sample: ``_State`` serves the runtime's adapter on a bg +thread, a session is opened keyed by session_id, ``env.rollout`` runs the agent +(which dials back to the adapter per model call), and the recorded token +segments merge into Sample(s). Reward is computed inline so slime's reward-model +step is skipped. Env knobs --------- - SLIME_HEAD_HOST public IP sandboxes use to reach the adapter (REQUIRED) + SLIME_HEAD_HOST public IP sandboxes use to reach the adapter + (REQUIRED unless MODAL_EXPOSE_ADAPTER=1) + MODAL_EXPOSE_ADAPTER 1 to publish the adapter through a modal.forward + tunnel (required on a Modal cluster) SHIM_BIND_HOST 0.0.0.0 adapter bind host on the head node SHIM_PORT 18002 adapter bind port - ASYNC_RL_AGENT_DRIVER dotted module path of the driver - (default async_rl_research.agent.mini_swe_agent) - AGENT_TIME_BUDGET_SEC 1800 wallclock budget for one agent run - AGENT_EVAL_TIMEOUT_SEC 600 wallclock cap on the evaluator sandbox - AGENT_GENERATE_GUARD_SEC full generate() guard; default budget+eval+180 + ASYNC_RL_AGENT_RUNTIME agent runtime spec (default "mini-swe") + ASYNC_RL_AGENT_DRIVER legacy alias for ASYNC_RL_AGENT_RUNTIME + ASYNC_RL_TASK_ROOT root dir relative metadata.task_path resolve against + AGENT_TIME_BUDGET_SEC 1800 total agent wallclock budget per sample + AGENT_EVAL_TIMEOUT_SEC 600 wallclock cap per grading command """ from __future__ import annotations import asyncio -import base64 -import importlib import logging import os import secrets import time import traceback -from dataclasses import dataclass from typing import Any from slime.agent.trajectory import TokenSegment, fan_out_sample_segments @@ -92,72 +42,71 @@ async def evaluate(*, image, workdir, diff_text, swepro=None, eval_cmd=None, from slime.utils.processing_utils import load_tokenizer from slime.utils.types import Sample +from .agent.base import AgentRuntime, load_runtime from .aiohttp_threaded import run_app_in_thread +from .environment.base import EnvMetadataError, 
RewardResult, load_env +from .modal_sandbox import SandboxBootTimeout +from .profiles import profiling logger = logging.getLogger(__name__) -DEFAULT_DRIVER = "async_rl_research.agent.mini_swe_agent" - AGENT_TIME_BUDGET_SEC = int(os.environ.get("AGENT_TIME_BUDGET_SEC", "1800")) AGENT_EVAL_TIMEOUT_SEC = int(os.environ.get("AGENT_EVAL_TIMEOUT_SEC", "600")) -# Wall-clock guard for the entire generate() call. When exceeded, the in-flight -# sample is aborted (`wall_clock_timeout`) and the rest of the rollout -# continues -- isolates one hung trajectory from the whole training step. -AGENT_GENERATE_GUARD_SEC = int(os.environ.get("AGENT_GENERATE_GUARD_SEC", "0") or 0) or ( - AGENT_TIME_BUDGET_SEC + AGENT_EVAL_TIMEOUT_SEC + 180 -) SHIM_BIND_HOST = os.environ.get("SHIM_BIND_HOST", "0.0.0.0") SHIM_PORT = int(os.environ.get("SHIM_PORT", "18002")) +# On a Modal cluster sandboxes are network-isolated and reach the adapter only +# via a public modal.forward tunnel. +MODAL_EXPOSE_ADAPTER = os.environ.get("MODAL_EXPOSE_ADAPTER", "0").strip().lower() in ("1", "true", "yes") + +def _load_runtime(args) -> AgentRuntime: + """Resolve + instantiate the agent runtime (env > arg > registry default). -def _load_driver(args): - """Resolve the agent driver *module* (env > arg > default).""" - path = ( - os.environ.get("ASYNC_RL_AGENT_DRIVER") + Validation is eager so a misdeclared runtime fails the worker boot loudly. + """ + spec = ( + os.environ.get("ASYNC_RL_AGENT_RUNTIME") + or os.environ.get("ASYNC_RL_AGENT_DRIVER") # legacy alias + or getattr(args, "agent_runtime", None) or getattr(args, "agent_driver", None) - or DEFAULT_DRIVER ) - return importlib.import_module(path) + return load_runtime(spec) -# --------------------------------------------------------------------------- -# Singleton: tokenizer + driver-selected adapter + background HTTP server. 
-# SingletonMeta keys per class, so there is exactly one adapter + server per -# rollout worker process; trajectories stay isolated by session_id. -# --------------------------------------------------------------------------- +# Singleton per worker process: tokenizer + adapter + bg HTTP server. class _State(metaclass=SingletonMeta): def __init__(self, args) -> None: self.args = args - self.driver = _load_driver(args) + self.runtime = _load_runtime(args) self.tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True) self.max_context_len = int(getattr(args, "rollout_max_context_len", 0) or 0) - # Adapter reuses the SGLang parsers configured for the served model so - # tool-call bash / reasoning are parsed correctly (e.g. - # --sglang-tool-call-parser qwen3_coder, --sglang-reasoning-parser qwen3). + # Reuse the served model's SGLang parsers so tool-call / reasoning parse. self.tool_parser = getattr(args, "sglang_tool_call_parser", None) or None self.reasoning_parser = getattr(args, "sglang_reasoning_parser", None) or None sglang_url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}" public_host = os.environ.get("SLIME_HEAD_HOST") - if not public_host: + if not public_host and not MODAL_EXPOSE_ADAPTER: raise RuntimeError( "SLIME_HEAD_HOST is not set. Export it to the host IP that " - "sandboxes can reach for the reverse-connection to the adapter. " - "Without it the in-sandbox agent cannot dial back and the " - "rollout will silently abort." + "sandboxes can reach for the reverse-connection to the adapter, " + "or set MODAL_EXPOSE_ADAPTER=1 to publish the adapter through a " + "modal.forward tunnel (required on a Modal cluster). Without " + "either the in-sandbox agent cannot dial back and the rollout " + "will silently abort." 
) - self.adapter = self.driver.ADAPTER_CLS( + self.adapter = self.runtime.adapter_cls( tokenizer=self.tokenizer, sglang_url=sglang_url, tool_parser=self.tool_parser, reasoning_parser=self.reasoning_parser, ) - # handler_cancellation=True so a client disconnect cancels the handler - # coroutine, arming the adapter's fire-and-forget /abort_request. Without - # it a cancelled client leaves an inflight sglang /generate that races - # the next release_memory_occupation and trips sglang's idle assertion. + # Per-turn timing by session; install before the app starts. + profiling.install(self.adapter.app) + # handler_cancellation=True: a client disconnect cancels the handler and + # arms /abort_request, else an inflight /generate trips sglang's idle assert. self.app_handle = run_app_in_thread( self.adapter.app, host=SHIM_BIND_HOST, @@ -165,34 +114,48 @@ def __init__(self, args) -> None: thread_name="agent-adapter", runner_kwargs={"handler_cancellation": True}, ) - # Base URL (no /v1). The driver appends whatever its wire API needs. - self.adapter_url = f"http://{public_host}:{self.app_handle.port}" + # Work past the bind can still fail (e.g. modal.forward); tear down so + # the orphaned daemon thread doesn't hold SHIM_PORT against retries. + try: + self._tunnel_cm = None + self.adapter_url = self._resolve_adapter_url(public_host) + except BaseException: + self.app_handle.stop() + raise logger.info( - "[async_rl] driver=%s adapter=%s tokenizer=%s tool_parser=%s reasoning_parser=%s", - self.driver.__name__, + "[async_rl] runtime=%s adapter=%s tokenizer=%s tool_parser=%s reasoning_parser=%s", + self.runtime.name, self.adapter_url, args.hf_checkpoint, self.tool_parser, self.reasoning_parser, ) + def _resolve_adapter_url(self, public_host: str | None) -> str: + """Pick the URL the in-sandbox agent dials back on. 
-# --------------------------------------------------------------------------- -# Trajectory -> Sample -# --------------------------------------------------------------------------- -@dataclass(frozen=True) -class RewardResult: - reward: float - is_solved: bool - applied_cleanly: bool + On a Modal cluster: a per-process ``modal.forward`` tunnel (one static + env can't cover multiple data-parallel workers), held on ``self`` for + the process lifetime. + """ + if not MODAL_EXPOSE_ADAPTER: + return f"http://{public_host}:{self.app_handle.port}" + import modal -def _start_session(state: _State, sample: Sample, md: dict[str, Any]) -> str: - """Register the adapter session BEFORE the agent starts. + # Blocking CM, never exited -- the process owns the tunnel until death. + self._tunnel_cm = modal.forward(self.app_handle.port) + tunnel = self._tunnel_cm.__enter__() + logger.info("[async_rl] modal.forward tunnel for adapter port %d -> %s", self.app_handle.port, tunnel.url) + return tunnel.url - The in-sandbox agent sends ``session_id`` as its auth/bearer token so the - adapter groups all of its turns under one chain. 
- """ + +# --------------------------------------------------------------------------- +# Trajectory -> Sample +# --------------------------------------------------------------------------- +def _start_session(state: _State, sample: Sample, md: dict[str, Any], sampling_params: dict[str, Any]) -> str: + """Register the adapter session BEFORE the agent starts (it sends + ``session_id`` as its bearer token to group its turns).""" if sample.session_id: session_id = sample.session_id elif sample.index is not None and sample.group_index is not None: @@ -200,30 +163,43 @@ def _start_session(state: _State, sample: Sample, md: dict[str, Any]) -> str: else: session_id = f"agent-{md['instance_id']}-{secrets.token_hex(8)}" sample.session_id = session_id - # sampling_defaults win over anything the agent sends, keeping the rollout - # on-policy (the adapter merges request body OVER these defaults). + # Adapter applies the request body OVER sampling_defaults; runtimes must + # strip the agent's own sampling knobs to stay on-policy. state.adapter.open_session( session_id, - sampling_defaults=_sampling_params(state.args), + sampling_defaults=_sampling_params(state.args, sampling_params), max_context_tokens=state.max_context_len, ) return session_id -def _sampling_params(args) -> dict[str, Any]: - # Kept tiny on purpose: the adapter fills the rest of its defaults. We only - # pin the knobs that must match training. Extend as needed. - if args is None: - return {} - return { - k: v - for k, v in ( - ("temperature", getattr(args, "rollout_temperature", None)), - ("top_p", getattr(args, "rollout_top_p", None)), - ("top_k", getattr(args, "rollout_top_k", None)), - ) - if v is not None - } +def _sampling_params(args, overrides: dict[str, Any] | None = None) -> dict[str, Any]: + # Pin the knobs that must match training; the adapter fills the rest. These + # become the session defaults the adapter applies UNDER each request. 
+ # ``overrides`` (slime's sampling_params) carries the per-call values: + # the eval temperature/top_p AND ``max_new_tokens`` -- which slime sets to + # rollout_max_response_len for train and eval_max_response_len for eval. + # Forwarding it makes that the adapter's per-turn generation cap (still + # further clamped to the remaining rollout_max_context_len budget); dropping + # it would silently fall back to the adapter's hardcoded per-turn default. + params = ( + {} + if args is None + else { + k: v + for k, v in ( + ("temperature", getattr(args, "rollout_temperature", None)), + ("top_p", getattr(args, "rollout_top_p", None)), + ("top_k", getattr(args, "rollout_top_k", None)), + ("max_new_tokens", getattr(args, "rollout_max_response_len", None)), + ) + if v is not None + } + ) + for k in ("temperature", "top_p", "top_k", "max_new_tokens"): + if overrides and overrides.get(k) is not None: + params[k] = overrides[k] + return params def _merge_samples( @@ -237,21 +213,24 @@ def _merge_samples( ): """Fan TokenSegments + reward out into Sample(s). - A single linear agent chain yields one ("final") segment -> K=1 -> one - Sample. Routing through ``fan_out_sample_segments`` (which handles K==1) - keeps it correct if an agent later adds context compaction ("wipe" - segments): reward is split reward/K and siblings share ``rollout_id`` so - the per-rollout loss reducer counts the trajectory once. + A linear chain yields one Sample; routing through ``fan_out_sample_segments`` + stays correct if an agent later adds context-compaction "wipe" segments + (reward split reward/K, siblings share ``rollout_id``). """ if not segments: - return _abort_result(sample, "adapter_session_empty") + # Carry the agent's exit code + failure tail (set by the env in + # reward_result.extra) onto the abort, so a zero-turn rollout self- + # explains in the dump instead of needing tail-only Modal logs. 
+ diag = {k: reward_result.extra[k] for k in ("agent_exit_code", "agent_tail") if k in reward_result.extra} + return _abort_result(sample, "adapter_session_empty", extra=diag) trajectory_metadata = { **(sample.metadata or {}), "instance_id": instance_id, "is_solved": reward_result.is_solved, - "applied_cleanly": reward_result.applied_cleanly, "elapsed_sec": elapsed_sec, + # Env-specific diagnostics (swe_gym: applied_cleanly; harbor: per-step). + **reward_result.extra, } fanned = fan_out_sample_segments( sample, segments, reward_result.reward, state.tokenizer, metadata=trajectory_metadata @@ -259,13 +238,13 @@ def _merge_samples( if not fanned: raise ValueError("fan-out produced no samples") logger.info( - "[async_rl] %s: reward=%.2f solved=%s applied=%s elapsed=%.1fs segments=%d", + "[async_rl] %s: reward=%.2f solved=%s elapsed=%.1fs segments=%d extra=%s", instance_id, reward_result.reward, reward_result.is_solved, - reward_result.applied_cleanly, elapsed_sec, len(fanned), + {k: v for k, v in reward_result.extra.items() if k != "agent_tail" and not isinstance(v, (list, dict))}, ) return fanned @@ -276,141 +255,101 @@ def _merge_samples( async def generate(args, sample: Sample, sampling_params: dict[str, Any], evaluation: bool = False): """Per-sample agent rollout with a wall-clock guard. - Accepts ``evaluation`` (slime passes it when present in the signature) but - treats train and eval identically -- running the agent + grading its diff - is what eval wants too. + Treats train and eval identically (run the agent + grade). """ - # `sandbox` is intentionally lazy-imported: it is not built yet (its - # contract is documented above). Everything else in this module imports and - # runs without it. - from . import sandbox - state = _State(args) - md = _metadata(sample) - if not md["image"] or not md["workdir"]: - return _abort_result(sample, "missing_image_or_workdir") + # Row -> env dispatch (lazy import). 
A bad row aborts THAT sample; an env + # module that won't import still raises loudly. + try: + env = load_env((sample.metadata or {}).get("task_type")) + md = env.normalize_metadata(sample) + except EnvMetadataError as e: + return _abort_result(sample, str(e)) + except (ValueError, TypeError) as e: + return _abort_result(sample, f"env_dispatch_failed:{type(e).__name__}:{e}") instance_id = md["instance_id"] - session_id = _start_session(state, sample, md) + # Enforced budgets (env defaults, or task.toml under override): recorded up + # front so the dump self-reports them even when the sample aborts. + sample.metadata = { + **(sample.metadata or {}), + "budgets": env.effective_budgets( + md, agent_time_budget_sec=AGENT_TIME_BUDGET_SEC, eval_timeout_sec=AGENT_EVAL_TIMEOUT_SEC + ), + } + session_id = _start_session(state, sample, md, sampling_params) t0 = time.time() try: - async with asyncio.timeout(AGENT_GENERATE_GUARD_SEC): - async with sandbox.boot_agent_sandbox(md["image"]) as sb: - await state.driver.run_agent( - sb, - md=md, - session_id=session_id, - adapter_url=state.adapter_url, - time_budget_sec=AGENT_TIME_BUDGET_SEC, - ) - diff_text = await sandbox.git_diff(sb, md["workdir"]) - - reward, is_solved, applied_cleanly = await sandbox.evaluate( - image=md["image"], - workdir=md["workdir"], - diff_text=diff_text, - swepro=md["swepro"], - eval_cmd=md["eval_cmd"], - pre_commands=md["pre_commands"], - timeout_sec=AGENT_EVAL_TIMEOUT_SEC, - ) - reward_result = RewardResult( - reward=float(reward), is_solved=bool(is_solved), applied_cleanly=bool(applied_cleanly) - ) - segments = await state.adapter.finish_session(session_id) - return _merge_samples( - sample=sample, - state=state, - segments=segments, - reward_result=reward_result, - elapsed_sec=time.time() - t0, - instance_id=instance_id, - ) + reward_result: RewardResult = await env.rollout( + md, + runtime=state.runtime, + session_id=session_id, + adapter_url=state.adapter_url, + 
agent_time_budget_sec=AGENT_TIME_BUDGET_SEC, + eval_timeout_sec=AGENT_EVAL_TIMEOUT_SEC, + ) + # Fold adapter per-turn stats into the env's phase timing. + turn_stats = profiling.pop_session_stats(session_id) + if turn_stats: + reward_result.extra.setdefault("timing", {}).update(turn_stats) + segments = await state.adapter.finish_session(session_id) + return _merge_samples( + sample=sample, + state=state, + segments=segments, + reward_result=reward_result, + elapsed_sec=time.time() - t0, + instance_id=instance_id, + ) - except asyncio.TimeoutError: - _log_timeout_diagnostic(t0) - return _abort_result(sample, "wall_clock_timeout") + except SandboxBootTimeout as e: + _attach_partial_timing(sample, session_id, t0) + return _abort_result(sample, f"boot_timeout:{e.timeout_sec}s") + except asyncio.CancelledError: + # A stray CancelledError from inside the rollout (e.g. Modal's + # synchronicity bridge) would crash the whole training step. Only a + # genuine external cancel leaves the task cancelling -- re-raise then. + if asyncio.current_task().cancelling(): + raise + logger.error("[async_rl] %s: stray CancelledError; aborting sample", instance_id) + _attach_partial_timing(sample, session_id, t0) + return _abort_result(sample, "exception:CancelledError") except Exception as e: logger.error("[async_rl] %s: rollout failed: %s\n%s", instance_id, e, traceback.format_exc()) + _attach_partial_timing(sample, session_id, t0) return _abort_result(sample, f"exception:{type(e).__name__}") finally: - # Close the sid before the next train step's release_memory_occupation; - # stragglers from this trajectory would otherwise race its idle assert. + # Close the sid before the next step's release_memory_occupation, else + # stragglers race its idle assert. await state.adapter.finish_session(session_id) # idempotent -def _log_timeout_diagnostic(t0: float) -> None: - """Dump pending-task names when the wall-clock guard fires. 
Never crashes.""" - try: - elapsed = time.time() - t0 - pending = [t for t in asyncio.all_tasks() if not t.done()] - stuck = [] - for t in pending[:5]: - coro = getattr(t, "_coro", None) - stuck.append(getattr(coro, "__qualname__", repr(coro))) - logger.warning( - "[async_rl] generate() wall_clock_timeout after %.1fs (guard=%ds); %d tasks pending; stuck: %s", - elapsed, - AGENT_GENERATE_GUARD_SEC, - len(pending), - stuck, - ) - except Exception: # pragma: no cover - pass - - -# --------------------------------------------------------------------------- -# Dataset-row normalization (agent-agnostic; SWE schema shared with the example) -# --------------------------------------------------------------------------- -def _wrap_f2p_script(script: str | None) -> str | None: - if not script: - return None - b64 = base64.b64encode(script.encode("utf-8")).decode("ascii") - return f"echo {b64} | base64 -d > /tmp/slime_f2p.py && python /tmp/slime_f2p.py" - - -def _metadata(sample: Sample) -> dict[str, Any]: - """Normalize the two dataset schemas (flat vs ``remote_env_info``).""" - m = sample.metadata or {} - rem = m.get("remote_env_info") or {} - label = sample.label if (isinstance(sample.label, str) and len(sample.label) < 256) else None - return { - "instance_id": m.get("instance_id") or rem.get("instance_id") or label or "unknown", - "image": m.get("image") or rem.get("image_url"), - "workdir": m.get("workdir") or rem.get("workdir"), - "problem_statement": m.get("problem_statement") or _coerce_prompt(sample.prompt), - "swepro": m.get("swepro"), - "eval_cmd": m.get("eval_cmd") or _wrap_f2p_script(rem.get("f2p_script")), - "pre_commands": m.get("pre_commands") or rem.get("pre_commands"), - } - - -def _coerce_prompt(prompt) -> str: - if isinstance(prompt, str): - return prompt - if isinstance(prompt, list): - for m in prompt: - if isinstance(m, dict) and m.get("role") == "user": - c = m.get("content") - if isinstance(c, str): - return c - if isinstance(c, list): - return 
"\n".join(p.get("text", "") for p in c if isinstance(p, dict) and p.get("type") == "text") - return "" +def _attach_partial_timing(sample: Sample, session_id: str, t0: float) -> None: + """On abort, keep accrued turn stats (distinguishes 'alive but slow' from + 'never dialed in').""" + stats = profiling.pop_session_stats(session_id) or {} + stats["elapsed_at_abort"] = round(time.time() - t0, 1) + sample.metadata = {**(sample.metadata or {}), "timing": stats} -def _abort(sample: Sample, reason: str) -> Sample: +def _abort(sample: Sample, reason: str, extra: dict[str, Any] | None = None) -> Sample: sample.tokens = [0, 0] sample.response = "" sample.response_length = 1 sample.loss_mask = [0] + # Shape-consistent with response_length: the train actor slices + # rollout_log_probs for every sample, so a None here crashes the step. + sample.rollout_log_probs = [0.0] sample.reward = 0.0 + # Mirror fan_out_sample_segments: build_dp_schedule groups by rollout_id, so + # a None collapses all aborts into one group and drops below global_batch_size. + sample.rollout_id = sample.index sample.status = Sample.Status.ABORTED - sample.metadata = {**(sample.metadata or {}), "abort_reason": reason} + sample.metadata = {**(sample.metadata or {}), "abort_reason": reason, **(extra or {})} logger.warning("[async_rl] aborted: %s", reason) return sample -def _abort_result(sample: Sample, reason: str): +def _abort_result(sample: Sample, reason: str, extra: dict[str, Any] | None = None): """Uniform list shape for this (potentially fan-out) generate function.""" - return [_abort(sample, reason)] + return [_abort(sample, reason, extra)] diff --git a/async_rl_research/modal_sandbox.py b/async_rl_research/modal_sandbox.py new file mode 100644 index 0000000000..459f60650a --- /dev/null +++ b/async_rl_research/modal_sandbox.py @@ -0,0 +1,482 @@ +"""Modal sandbox backend for agent rollouts. 
+ +``ModalSandbox`` is a drop-in ``Sandbox`` protocol impl backed by +``modal.Sandbox`` (the local analog of ``E2BSandbox``). Pure infrastructure, so +the env glue and agent runtimes build on it. Image is a registry ref or a +host-side Dockerfile build (``DockerfileImage``). ``modal`` is imported lazily +so this stays importable without Modal installed. Boot concurrency and +create-retry live here so every sandbox creation is gated/retried uniformly. + +Env knobs +--------- + MODAL_BOOT_CONCURRENCY max concurrent sandbox creates (default 8) + MODAL_BOOT_RETRIES transient-create retries (default 2) + MODAL_BOOT_TIMEOUT_SEC cap on sandbox boot/image-pull (default 600) + MODAL_RPC_RETRIES transient-exec retries (default 2) + SLIME_AGENT_SANDBOX_LIFETIME_SEC sandbox max lifetime (default 3600) + (legacy alias: MODAL_SANDBOX_LIFETIME_SEC) + SLIME_AGENT_SANDBOX_MODAL_APP Modal app name (default slime-agent-sandboxes) + (legacy alias: MODAL_SANDBOX_APP_NAME) + SLIME_AGENT_SANDBOX_BLOCK_NETWORK 1 to cut sandbox outbound network + (legacy alias: MODAL_SANDBOX_BLOCK_NETWORK) + SLIME_AGENT_SANDBOX_CPU fractional cpu cores (optional) + SLIME_AGENT_SANDBOX_MEMORY_MB memory in MB (optional) + SLIME_AGENT_SANDBOX_GPU gpu spec, e.g. 
"a10g" (optional) + SLIME_AGENT_SANDBOX_VM_RUNTIME 1 to boot a VM sandbox instead of gVisor + (real kernel, allowlisted workspaces only; VM memory is static, floored + at MODAL_VM_MEMORY_FLOOR_MB, default 2048) + SLIME_AGENT_SANDBOX_ADD_PYTHON add a python to the image (optional) + MODAL_REGISTRY_SECRET modal.Secret name for a private registry/ECR + MODAL_ENVIRONMENT modal environment name (optional) +""" + +from __future__ import annotations + +import asyncio +import codecs +import logging +import os +import shlex +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +ExecResult = tuple[int, str, str] +FileContent = str | bytes | Path + + +class SandboxBootTimeout(Exception): + """Sandbox create/image-pull exceeded its boot budget (MODAL_BOOT_TIMEOUT_SEC).""" + + def __init__(self, timeout_sec: int, image: str = "") -> None: + self.timeout_sec = timeout_sec + super().__init__(f"sandbox boot exceeded {timeout_sec}s: {image}") + + +@dataclass(frozen=True) +class DockerfileImage: + """Build-from-Dockerfile image spec (harbor-style task environments). + + ``context_dir`` defaults to the Dockerfile's dir. Modal content-hashes the + Dockerfile + context so identical task files cache-hit across boots. FROM + pull uses Modal's default builder auth (no private base images). + """ + + path: str + context_dir: str | None = None + + @property + def description(self) -> str: + return f"dockerfile:{self.path}" + + +# Modal validates exec argv against ARG_MAX (64 KiB) client-side; larger +# commands are staged as a script file. Under the limit to leave room for the +# bash/runuser wrapper argv. +_EXEC_ARGV_LIMIT_BYTES = 32_768 + + +def _normalize_image_ref(ref: str) -> str: + """Lowercase a registry ref's repository name; preserve tag/digest. + + OCI repo names must be lowercase, but some SWE-bench dataset images carry + mixed-case orgs/repos. Tag/digest are case-sensitive and left untouched. 
+ """ + if not ref: + return ref + name, sep, suffix = ref, "", "" + if "@" in ref: # digest pin: name@sha256:... + name, _, digest = ref.partition("@") + sep, suffix = "@", digest + else: + slash = ref.rfind("/") + colon = ref.rfind(":") + if colon > slash: # a tag colon, not a registry :port + name, sep, suffix = ref[:colon], ":", ref[colon + 1 :] + return name.lower() + sep + suffix + + +def _getenv(*names: str, default: str = "") -> str: + for name in names: + value = os.environ.get(name) + if value is not None and value.strip(): + return value + return default + + +def _getenv_int(*names: str, default: int) -> int: + raw = _getenv(*names) + return int(raw) if raw else default + + +# Process-wide create gate + cached App, lazily created on the running loop. +_BOOT_SEM: asyncio.Semaphore | None = None +_APP_CACHE: dict[str, Any] = {} +_APP_LOCK: asyncio.Lock | None = None + + +def _boot_sem() -> asyncio.Semaphore: + global _BOOT_SEM + if _BOOT_SEM is None: + _BOOT_SEM = asyncio.Semaphore(_getenv_int("MODAL_BOOT_CONCURRENCY", default=8)) + return _BOOT_SEM + + +def _app_lock() -> asyncio.Lock: + global _APP_LOCK + if _APP_LOCK is None: + _APP_LOCK = asyncio.Lock() + return _APP_LOCK + + +class ModalSandbox: + """Async context manager around ``modal.Sandbox`` (the ``Sandbox`` protocol). + + Command failures surface as exit codes; transient Modal transport errors are + retried so infra problems are never scored as a failed test. + """ + + default_lifetime_sec = 3600 + default_boot_timeout_sec = 600 + default_app_name = "slime-agent-sandboxes" + default_create_retries = 2 + default_rpc_retries = 2 + # Main process that keeps the sandbox alive (exec runs separate processes). + # Required for images that blank their entrypoint (see __aenter__). + keepalive_command = ("sleep", "infinity") + rpc_backoff_base_sec = 1.0 + # Per-stream output cap so a runaway command can't balloon host memory. 
+ output_cap_chars = _getenv_int("MODAL_EXEC_OUTPUT_CAP", default=1_000_000) + + def __init__( + self, + image: str | DockerfileImage, + *, + timeout: int | None = None, + block_network: bool | None = None, + cpu: float | None = None, + memory_mb: int | None = None, + gpu: str | None = None, + registry_secret: str | None = None, + rpc_retries: int | None = None, + create_retries: int | None = None, + app_name: str | None = None, + add_python: str | None = None, + workdir: str | None = None, + vm_runtime: bool | None = None, + boot_timeout: int | None = None, + ) -> None: + if isinstance(image, DockerfileImage): + self.image_spec: DockerfileImage | None = image + self.image = image.description # label only + else: + self.image_spec = None + self.image = _normalize_image_ref(image) + self.timeout = timeout if timeout is not None else self._lifetime_from_env() + self.boot_timeout = boot_timeout if boot_timeout is not None else self._boot_timeout_from_env() + self.block_network = block_network if block_network is not None else self._block_network_from_env() + self.cpu = cpu if cpu is not None else self._float_from_env("SLIME_AGENT_SANDBOX_CPU", "MODAL_SANDBOX_CPU") + self.memory_mb = ( + memory_mb + if memory_mb is not None + else self._int_from_env("SLIME_AGENT_SANDBOX_MEMORY_MB", "MODAL_SANDBOX_MEMORY_MB") + ) + self.gpu = gpu or (_getenv("SLIME_AGENT_SANDBOX_GPU", "MODAL_SANDBOX_GPU") or None) + self.vm_runtime = vm_runtime if vm_runtime is not None else self._vm_runtime_from_env() + if self.vm_runtime and self.memory_mb is None: + # VM memory is static; Modal's 128MB default OOMs a VM. 
+ self.memory_mb = _getenv_int("MODAL_VM_MEMORY_FLOOR_MB", default=2048) + self.registry_secret = registry_secret or (_getenv("MODAL_REGISTRY_SECRET") or None) + self.rpc_retries = ( + rpc_retries + if rpc_retries is not None + else _getenv_int("MODAL_RPC_RETRIES", "SLIME_AGENT_SANDBOX_RPC_RETRIES", default=self.default_rpc_retries) + ) + self.create_retries = ( + create_retries + if create_retries is not None + else _getenv_int("MODAL_BOOT_RETRIES", default=self.default_create_retries) + ) + self.app_name = app_name or _getenv( + "SLIME_AGENT_SANDBOX_MODAL_APP", "MODAL_SANDBOX_APP_NAME", default=self.default_app_name + ) + self.add_python = add_python or (_getenv("SLIME_AGENT_SANDBOX_ADD_PYTHON", "MODAL_SANDBOX_ADD_PYTHON") or None) + self.workdir = workdir or (_getenv("SLIME_AGENT_SANDBOX_WORKDIR", "MODAL_SANDBOX_WORKDIR") or None) + self._modal = None + self._sb = None + self.sandbox_id = "" + + # -- env helpers -------------------------------------------------------- + @classmethod + def _lifetime_from_env(cls) -> int: + return _getenv_int( + "SLIME_AGENT_SANDBOX_LIFETIME_SEC", "MODAL_SANDBOX_LIFETIME_SEC", default=cls.default_lifetime_sec + ) + + @classmethod + def _boot_timeout_from_env(cls) -> int: + return _getenv_int("MODAL_BOOT_TIMEOUT_SEC", default=cls.default_boot_timeout_sec) + + @staticmethod + def _vm_runtime_from_env() -> bool: + return _getenv("SLIME_AGENT_SANDBOX_VM_RUNTIME", "MODAL_SANDBOX_VM_RUNTIME").strip().lower() in ( + "1", + "true", + "yes", + ) + + @staticmethod + def _block_network_from_env() -> bool: + return _getenv("SLIME_AGENT_SANDBOX_BLOCK_NETWORK", "MODAL_SANDBOX_BLOCK_NETWORK").strip().lower() in ( + "1", + "true", + "yes", + ) + + @staticmethod + def _float_from_env(*names: str) -> float | None: + raw = _getenv(*names) + return float(raw) if raw else None + + @staticmethod + def _int_from_env(*names: str) -> int | None: + raw = _getenv(*names) + return int(raw) if raw else None + + # -- transient-error classification 
------------------------------------ + @staticmethod + def _is_transient(e: BaseException) -> bool: + """True if ``e`` is a Modal transport flap safe to retry (command-level + timeouts are NOT transient).""" + name = type(e).__name__ + if "SandboxTimeout" in name or name == "TimeoutError": + return False + if name in { + "ConnectionError", + "ConnectionResetError", + "ConnectionAbortedError", + "GRPCError", + "StreamTerminatedError", + "InternalError", + "ServerError", + "RemoteError", + }: + return True + msg = str(e).lower() + return any(s in msg for s in ("connection", "unavailable", "stream terminated", "goaway", "reset")) + + async def _retry(self, op_name: str, attempts: int, coro_factory): + last_err: BaseException | None = None + for attempt in range(max(1, attempts)): + try: + return await coro_factory() + except Exception as e: + if not self._is_transient(e) or attempt + 1 >= max(1, attempts): + raise + last_err = e + backoff = self.rpc_backoff_base_sec * (2**attempt) + logger.debug( + "[modal_sandbox] %s transient %s, retry %d/%d in %.1fs: %s", + op_name, + type(e).__name__, + attempt + 1, + attempts, + backoff, + str(e)[:160], + ) + await asyncio.sleep(backoff) + assert last_err is not None + raise last_err + + # -- lifecycle ---------------------------------------------------------- + async def _get_app(self): + environment_name = _getenv("MODAL_ENVIRONMENT") or None + key = f"{self.app_name}\0{environment_name or ''}" + async with _app_lock(): + app = _APP_CACHE.get(key) + if app is None: + kwargs: dict[str, Any] = {"create_if_missing": True} + if environment_name: + kwargs["environment_name"] = environment_name + app = await self._modal.App.lookup.aio(self.app_name, **kwargs) + _APP_CACHE[key] = app + return app + + def _build_image(self): + modal = self._modal + kwargs: dict[str, Any] = {} + if self.add_python: + kwargs["add_python"] = self.add_python + if self.image_spec is not None: + spec = self.image_spec + context_dir = spec.context_dir or 
str(Path(spec.path).parent) + return modal.Image.from_dockerfile(spec.path, context_dir=context_dir, **kwargs) + secret = None + if self.registry_secret: + secret = modal.Secret.from_name(self.registry_secret) + if ".dkr.ecr." in self.image and secret is not None: + return modal.Image.from_aws_ecr(self.image, secret=secret, **kwargs) + return modal.Image.from_registry(self.image, secret=secret, **kwargs) + + async def __aenter__(self) -> ModalSandbox: + import modal # lazy + + self._modal = modal + app = await self._get_app() + image = self._build_image() + + create_kwargs: dict[str, Any] = { + "app": app, + "image": image, + "timeout": self.timeout, + "block_network": self.block_network, + } + if self.cpu is not None: + create_kwargs["cpu"] = self.cpu + if self.memory_mb is not None: + create_kwargs["memory"] = self.memory_mb + if self.gpu: + create_kwargs["gpu"] = self.gpu + if self.workdir: + create_kwargs["workdir"] = self.workdir + if self.vm_runtime: + create_kwargs["experimental_options"] = {"vm_runtime": True} + + async def _create(): + async with _boot_sem(): + # Explicit keepalive command. Some task images (SWE-bench-Pro) + # blank the entrypoint (`ENTRYPOINT []`) expecting an external + # `sleep infinity` from docker-compose; with no command Modal runs + # the now-empty entrypoint and the container exits immediately + # (rc 128) -> every later exec hits "sandbox already shut down". + # `exec` spawns its own processes, so this is inert for images + # that already stay up (SWE-bench verified). 
+ return await modal.Sandbox.create.aio(*self.keepalive_command, **create_kwargs) + + try: + async with asyncio.timeout(self.boot_timeout): + self._sb = await self._retry(f"create({self.image[:48]!r})", self.create_retries, _create) + except TimeoutError: + raise SandboxBootTimeout(self.boot_timeout, self.image[:48]) from None + self.sandbox_id = str(getattr(self._sb, "object_id", "") or "") + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + sb = self._sb + if sb is None: + return + try: + await sb.terminate.aio() + except Exception as e: + logger.warning("[modal_sandbox] terminate %s failed: %s", self.sandbox_id[:8], e) + try: + await sb.wait.aio(raise_on_termination=False) + except Exception: + pass + + # -- protocol surface --------------------------------------------------- + def _require_sandbox(self): + if self._sb is None: + raise RuntimeError("ModalSandbox has not been entered") + return self._sb + + async def exec( + self, + cmd: str, + *, + user: str = "root", + env: dict[str, str] | None = None, + timeout: int = 120, + check: bool = False, + ) -> ExecResult: + sb = self._require_sandbox() + # Honor user= for agents that drop privileges. + if user and user != "root": + inner = f"runuser -u {shlex.quote(user)} -- bash -lc {shlex.quote(cmd)}" + else: + inner = cmd + if len(inner.encode("utf-8", errors="ignore")) > _EXEC_ARGV_LIMIT_BYTES: + # Too big for exec argv: stage as a script. Left behind so _retry can + # re-run; sandboxes are ephemeral. + script = f"/tmp/.modal_exec_{uuid.uuid4().hex}.sh" + await self.write_file(script, inner) + inner = f"bash {shlex.quote(script)}" + secrets = [self._modal.Secret.from_dict({str(k): str(v) for k, v in env.items()})] if env else [] + + async def _run() -> ExecResult: + # text=False: Modal's text mode decodes strictly, so one non-UTF8 + # byte kills the rollout. Take bytes, decode host-side with replace. 
+ proc = await sb.exec.aio("bash", "-lc", inner, timeout=timeout, secrets=secrets, text=False) + # Drain both streams BEFORE wait(): a full pipe buffer deadlocks wait(). + out_task = asyncio.create_task(_read_stream_capped(proc.stdout, self.output_cap_chars)) + err_task = asyncio.create_task(_read_stream_capped(proc.stderr, self.output_cap_chars)) + exit_code = int(await proc.wait.aio()) + stdout, stderr = await asyncio.gather(out_task, err_task) + return exit_code, stdout, stderr + + exit_code, stdout, stderr = await self._retry(f"exec({cmd[:48]!r})", self.rpc_retries, _run) + if check and exit_code != 0: + raise RuntimeError(f"modal exec failed (exit={exit_code}): {cmd[:120]}\n{stderr[-1000:]}") + return exit_code, stdout, stderr + + async def write_file(self, sandbox_path: str, content: FileContent, *, user: str = "root") -> None: + sb = self._require_sandbox() + fs = sb.filesystem + + async def _write(): + if isinstance(content, Path): + await fs.copy_from_local.aio(str(content), sandbox_path) + elif isinstance(content, bytes): + await fs.write_bytes.aio(content, sandbox_path) + else: + await fs.write_text.aio(str(content), sandbox_path) + + await self._retry(f"write_file({sandbox_path})", self.rpc_retries, _write) + if user and user != "root": + quoted = shlex.quote(user) + await self.exec(f"chown {quoted}:{quoted} {shlex.quote(sandbox_path)}", timeout=30, check=False) + + async def read_file(self, sandbox_path: str, *, user: str = "root") -> str: + del user # advisory: Modal reads as the sandbox owner + sb = self._require_sandbox() + try: + return await self._retry( + f"read_file({sandbox_path})", + self.rpc_retries, + lambda: sb.filesystem.read_text.aio(sandbox_path), + ) + except Exception: + return "" + + +async def _read_stream_capped(stream: Any, cap: int) -> str: + """Drain ``stream`` fully but keep only the first ``cap`` chars (tail dropped + with a marker). 
Decodes byte chunks incrementally with ``errors="replace"`` + so a split multibyte char doesn't mojibake and non-UTF8 never raises. + """ + if stream is None: + return "" + decoder = codecs.getincrementaldecoder("utf-8")(errors="replace") + parts: list[str] = [] + total = 0 + truncated = False + async for chunk in stream: + if chunk is None: + continue + text = decoder.decode(chunk) if isinstance(chunk, (bytes, bytearray)) else chunk + if total < cap: + take = text[: cap - total] + parts.append(take) + total += len(take) + if total >= cap: + truncated = True + else: + truncated = True + out = "".join(parts) + if truncated: + out += "\n...[truncated]" + return out diff --git a/async_rl_research/sandbox.py b/async_rl_research/sandbox.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/async_rl_research/scripts_agenticRL/.gitignore b/async_rl_research/scripts_agenticRL/.gitignore new file mode 100644 index 0000000000..3493937931 --- /dev/null +++ b/async_rl_research/scripts_agenticRL/.gitignore @@ -0,0 +1,6 @@ +# local launch scripts — machine-specific, never tracked +* +!.gitignore +!qwen3_dapo_og.sh +!qwen3_6_swe_eval.sh +!qwen3_6_colocate.sh diff --git a/async_rl_research/scripts_agenticRL/qwen3_6_swe_eval.sh b/async_rl_research/scripts_agenticRL/qwen3_6_swe_eval.sh new file mode 100755 index 0000000000..49fc68ca5c --- /dev/null +++ b/async_rl_research/scripts_agenticRL/qwen3_6_swe_eval.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +set -euo pipefail +cd "${GUIDE:-$HOME/Documents/Research/async-rl/multinode-training-guide}" +export EXPERIMENT_CONFIG=w_qwen3_6_swe_eval_2n +export MODAL_ENVIRONMENT=${MODAL_ENVIRONMENT:-junlin-dev} WANDB_PROJECT=${WANDB_PROJECT:-Modal} + +# uv run --no-dev modal run slime/modal_train.py::download_data +uv run --no-dev modal run -d slime/modal_train.py::train diff --git a/async_rl_research/scripts_agenticRL/qwen3_dapo_og.sh b/async_rl_research/scripts_agenticRL/qwen3_dapo_og.sh new file mode 100644 index 
0000000000..94aa4a9eb7 --- /dev/null +++ b/async_rl_research/scripts_agenticRL/qwen3_dapo_og.sh @@ -0,0 +1,15 @@ +export EXPERIMENT_CONFIG=qwen3_dapo +export MODAL_ENVIRONMENT=junlin-dev +export WANDB_PROJECT=Modal +export WANDB_GROUP=qwen3-30b-a3b-dapo-math-1n + + + +cd /Users/junlin/Documents/Research/async-rl/multinode-training-guide + + + +# uv run --no-dev modal run slime/modal_train.py::download_model +# uv run --no-dev modal run slime/modal_train.py::download_data +# uv run --no-dev modal run slime/modal_train.py::convert_hf_to_megatron_checkpoint +uv run --no-dev modal run -d slime/modal_train.py::train diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index 74680e2ada..e378c27528 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -591,6 +591,26 @@ def save_model(self, rollout_id: int, force_sync: bool = False) -> None: if self.args.offload_train: self.sleep() + def save_hf(self, rollout_id: int = 0) -> None: + """DEBUG: dump HF via the real resync converter only (no megatron ckpt). + + Used by the qwen3.6 resync probe to test the live TP/EP gather+convert in + isolation. Resolves hf_checkpoint to a local dir (raw save_hf requires it). 
+ """ + import os + + if self.args.offload_train: + self.wake_up() + if not os.path.isdir(self.args.hf_checkpoint): + from huggingface_hub import snapshot_download + + self.args.hf_checkpoint = snapshot_download(self.args.hf_checkpoint, local_files_only=True) + from slime.backends.megatron_utils.model import save_hf_model + + save_hf_model(self.args, rollout_id, self.model) + if self.args.offload_train: + self.sleep() + @timer def update_weights(self) -> None: if self.args.debug_train_only or self.args.debug_rollout_only: diff --git a/slime/ray/actor_group.py b/slime/ray/actor_group.py index 27ad610ad9..4186581b10 100644 --- a/slime/ray/actor_group.py +++ b/slime/ray/actor_group.py @@ -141,6 +141,10 @@ def save_model(self, rollout_id, force_sync=False): """Save actor model""" return ray.get([actor.save_model.remote(rollout_id, force_sync=force_sync) for actor in self._actor_handlers]) + def save_hf(self, rollout_id=0): + """DEBUG (qwen3.6 resync probe): dump HF via the real resync converter.""" + return ray.get([actor.save_hf.remote(rollout_id) for actor in self._actor_handlers]) + def update_weights(self): """Broadcast weights from rank 0 to all other ranks.""" return ray.get([actor.update_weights.remote() for actor in self._actor_handlers]) diff --git a/slime/utils/logging_utils.py b/slime/utils/logging_utils.py index 11348a4074..3eff1514b4 100644 --- a/slime/utils/logging_utils.py +++ b/slime/utils/logging_utils.py @@ -38,6 +38,14 @@ def update_tracking_open_metrics(args, router_addr): def finish_tracking(args): if not args.use_wandb: return + try: + logger_actor = wandb_utils.get_logger_actor() + if logger_actor is not None: + import ray + + ray.get(logger_actor.finish.remote(), timeout=120) + except Exception: + logging.getLogger(__name__).exception("Failed to finish wandb logger actor") try: if wandb.run is not None: wandb.finish() @@ -48,7 +56,21 @@ def finish_tracking(args): # TODO further refactor, e.g. 
put TensorBoard init to the "init" part def log(args, metrics, step_key: str): if args.use_wandb: - wandb.log(metrics) + # All history must go through the single primary writer; metrics + # logged from shared-mode secondary processes are ingested hours + # late (or dropped) by the W&B backend. See wandb_utils. The call is + # synchronous (it is cheap and infrequent) so that no metric can be + # lost in a shutdown race and actor failures are surfaced here. + logger_actor = wandb_utils.get_logger_actor() + if logger_actor is not None: + try: + import ray + + ray.get(logger_actor.log.remote(metrics), timeout=60) + except Exception: + logging.getLogger(__name__).exception("Failed to log metrics via wandb logger actor") + elif wandb.run is not None: + wandb.log(metrics) if args.use_tensorboard: metrics_except_step = {k: v for k, v in metrics.items() if k != step_key} diff --git a/slime/utils/wandb_utils.py b/slime/utils/wandb_utils.py index 81dbe9f124..4ba52f0878 100644 --- a/slime/utils/wandb_utils.py +++ b/slime/utils/wandb_utils.py @@ -1,11 +1,30 @@ import logging +import math import os +import socket +import threading from copy import deepcopy import wandb logger = logging.getLogger(__name__) +# Name of the Ray actor that owns the W&B run and performs ALL history writes. +# +# Why a single writer: on the current W&B backend, history logged by +# ``mode="shared"`` secondary writers (``x_primary=False``) is not ingested in +# real time — it lands hours late via a backfill path, or is dropped entirely +# when the writer process dies before flushing. The same delayed path swallows +# everything logged after a run is finished and re-initialized with +# ``resume="allow"``. So the run looks completely empty in the UI during (and +# long after) training. Funneling every ``wandb.log`` through the one primary +# writer that created the run keeps metrics on the real-time path. 
Secondary +# processes still attach in shared mode, but only for console logs and +# per-node system metrics. +LOGGER_ACTOR_NAME = "slime_wandb_logger" + +_logger_actor = None + def _is_offline_mode(args) -> bool: """Detect whether W&B should run in offline mode. @@ -19,28 +38,8 @@ def _is_offline_mode(args) -> bool: return os.environ.get("WANDB_MODE") == "offline" -def init_wandb_primary(args): - if not args.use_wandb: - args.wandb_run_id = None - return - - # Set W&B mode if specified (overrides WANDB_MODE env var) - if args.wandb_mode: - os.environ["WANDB_MODE"] = args.wandb_mode - if args.wandb_mode == "offline": - logger.info("W&B offline mode enabled. Data will be saved locally.") - elif args.wandb_mode == "disabled": - logger.info("W&B disabled mode enabled. No data will be logged.") - elif args.wandb_mode == "online": - logger.info("W&B online mode enabled. Data will be uploaded to cloud.") - - offline = _is_offline_mode(args) - - # Only perform explicit login when NOT offline - if (not offline) and args.wandb_key is not None: - wandb.login(key=args.wandb_key, host=args.wandb_host) - - # Prepare wandb init parameters +def _primary_init_kwargs(args): + """Build the wandb.init kwargs shared by the offline path and the logger actor.""" # add random 6 length string with characters if args.wandb_random_suffix: group = args.wandb_group + "_" + wandb.util.generate_id() @@ -49,7 +48,6 @@ def init_wandb_primary(args): group = args.wandb_group run_name = args.wandb_group - # Prepare wandb init parameters init_kwargs = { "entity": args.wandb_team, "project": args.wandb_project, @@ -58,34 +56,174 @@ def init_wandb_primary(args): "config": _compute_config_for_logging(args), } - # Configure settings based on offline/online mode - if offline: - init_kwargs["settings"] = wandb.Settings(mode="offline") - else: - init_kwargs["settings"] = wandb.Settings(mode="shared", x_primary=True) - - # Add custom directory if specified if args.wandb_dir: # Ensure directory exists to 
avoid backend crashes os.makedirs(args.wandb_dir, exist_ok=True) init_kwargs["dir"] = args.wandb_dir logger.info(f"W&B logs will be stored in: {args.wandb_dir}") - wandb.init(**init_kwargs) + return init_kwargs - _init_wandb_common() +class WandbLoggerActor: + """Ray actor that owns the W&B run and is its only history writer. + + See the comment on ``LOGGER_ACTOR_NAME`` for why all metrics must flow + through this single process. + """ + + def __init__(self, args): + self.args = args + self._scraper_thread = None + self._stop = threading.Event() + + if args.wandb_mode: + os.environ["WANDB_MODE"] = args.wandb_mode + if args.wandb_key is not None: + wandb.login(key=args.wandb_key, host=args.wandb_host) + + init_kwargs = _primary_init_kwargs(args) + init_kwargs["settings"] = wandb.Settings(mode="shared", x_primary=True, x_label="primary-logger") + wandb.init(**init_kwargs) + _init_wandb_common() + + def get_run_id(self): + return wandb.run.id + + def log(self, metrics: dict): + if wandb.run is not None: + wandb.log(metrics) + + def start_open_metrics(self, router_addr: str): + """Poll the sglang router's metrics endpoint and log it as history. + + Replaces the previous finish + re-init with + ``x_stats_open_metrics_endpoints``: resuming a finished shared-mode + run sends all subsequent metric streams down the W&B backfill path + (hours of delay), which made runs look empty. 
+ """ + if self._scraper_thread is not None: + return + url = f"{router_addr}/engine_metrics" + self._scraper_thread = threading.Thread(target=self._scrape_loop, args=(url,), daemon=True) + self._scraper_thread.start() + logger.info(f"Scraping SGLang engine metrics from {url}.") + + def _scrape_loop(self, url): + import urllib.request + + while not self._stop.wait(30): + try: + with urllib.request.urlopen(url, timeout=10) as resp: + text = resp.read().decode("utf-8", errors="replace") + except Exception: + continue + metrics = _parse_prometheus_text(text) + if metrics and wandb.run is not None: + wandb.log({f"sgl_engine/{name}": value for name, value in metrics.items()}) + + def finish(self): + # Idempotent: called from RolloutManager.dispose() and again from the + # driver's finish_tracking(). + self._stop.set() + if wandb.run is not None: + wandb.finish() + + +def _parse_prometheus_text(text): + """Parse Prometheus text exposition into {metric_name: mean across series}.""" + sums: dict[str, float] = {} + counts: dict[str, int] = {} + for line in text.splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + if "{" in line: + name = line[: line.index("{")] + rest = line[line.index("}") + 1 :].split() + else: + parts = line.split() + name, rest = parts[0], parts[1:] + if not rest: + continue + try: + value = float(rest[0]) + except ValueError: + continue + if math.isnan(value) or math.isinf(value): + continue + sums[name] = sums.get(name, 0.0) + value + counts[name] = counts.get(name, 0) + 1 + return {name: sums[name] / counts[name] for name in sums} + + +def get_logger_actor(): + """Return the primary logger actor handle, or None (e.g. 
offline mode).""" + global _logger_actor + if _logger_actor is None: + try: + import ray + + if not ray.is_initialized(): + return None + _logger_actor = ray.get_actor(LOGGER_ACTOR_NAME) + except Exception: + return None + return _logger_actor + + +def init_wandb_primary(args): + if not args.use_wandb: + args.wandb_run_id = None + return + + # Set W&B mode if specified (overrides WANDB_MODE env var) + if args.wandb_mode: + os.environ["WANDB_MODE"] = args.wandb_mode + if args.wandb_mode == "offline": + logger.info("W&B offline mode enabled. Data will be saved locally.") + elif args.wandb_mode == "disabled": + logger.info("W&B disabled mode enabled. No data will be logged.") + elif args.wandb_mode == "online": + logger.info("W&B online mode enabled. Data will be uploaded to cloud.") + + if _is_offline_mode(args) or args.wandb_mode == "disabled": + # Offline/disabled: every process writes locally (or not at all); no + # actor needed. For "disabled", the WANDB_MODE env var set above must + # stay in charge — an explicit Settings(mode=...) would override it. + init_kwargs = _primary_init_kwargs(args) + if _is_offline_mode(args): + init_kwargs["settings"] = wandb.Settings(mode="offline") + wandb.init(**init_kwargs) + _init_wandb_common() + args.wandb_run_id = wandb.run.id + return + + if args.wandb_key is not None: + wandb.login(key=args.wandb_key, host=args.wandb_host) + + import ray + + global _logger_actor + _logger_actor = ray.remote(num_cpus=0)(WandbLoggerActor).options(name=LOGGER_ACTOR_NAME).remote(args) # Set wandb_run_id in args for easy access throughout the training process - args.wandb_run_id = wandb.run.id + args.wandb_run_id = ray.get(_logger_actor.get_run_id.remote()) + + # Attach the driver as a shared-mode secondary so its console output and + # head-node system metrics still reach the run. 
+ init_wandb_secondary(args, role="driver") def reinit_wandb_primary_with_open_metrics(args, router_addr): - """Re-initialize the primary W&B run with open metrics endpoints. + """Start uploading SGLang engine metrics now that the router is up. The primary wandb init happens before rollout servers start (to obtain - ``wandb_run_id`` for secondary processes). This function is called - *after* servers are up so the router address is available for scraping - SGLang Prometheus metrics via the primary process's stats monitor. + ``wandb_run_id`` for secondary processes). This function is called + *after* servers are up so the router address is available. The logger + actor scrapes the router's Prometheus endpoint itself — the previous + finish + re-init with ``x_stats_open_metrics_endpoints`` made the W&B + backend route all subsequent metrics through its hours-delayed backfill + path, so runs looked empty. """ if not args.use_wandb or _is_offline_mode(args): return @@ -93,8 +231,7 @@ def reinit_wandb_primary_with_open_metrics(args, router_addr): return if router_addr is None: return - wandb_run_id = getattr(args, "wandb_run_id", None) - if wandb_run_id is None: + if getattr(args, "wandb_run_id", None) is None: return import sglang_router @@ -105,34 +242,10 @@ def reinit_wandb_primary_with_open_metrics(args, router_addr): ) return - logger.info(f"Re-initializing primary W&B with SGLang metrics at {router_addr}.") - - wandb.finish() - - init_kwargs = { - "id": wandb_run_id, - "entity": args.wandb_team, - "project": args.wandb_project, - "resume": "allow", - "reinit": True, - "settings": wandb.Settings( - mode="shared", - x_primary=True, - x_stats_open_metrics_endpoints={ - "sgl_engine": f"{router_addr}/engine_metrics", - }, - x_stats_open_metrics_filters={ - "sgl_engine.*": {}, - }, - ), - } - - if args.wandb_dir: - os.makedirs(args.wandb_dir, exist_ok=True) - init_kwargs["dir"] = args.wandb_dir - - wandb.init(**init_kwargs) - _init_wandb_common() + actor = 
get_logger_actor() + if actor is None: + return + actor.start_open_metrics.remote(router_addr) def _compute_config_for_logging(args): @@ -198,6 +311,9 @@ def init_wandb_secondary(args, role=None): mode="shared", x_primary=False, x_update_finish_state=False, + # Distinct labels keep per-process system metrics and console + # logs from clobbering each other on the W&B backend. + x_label=f"{role or 'worker'}-{socket.gethostname()}-{os.getpid()}", ) init_kwargs = { diff --git a/tests/test_adapter_session_empty_diag.py b/tests/test_adapter_session_empty_diag.py new file mode 100644 index 0000000000..36b9b98a6b --- /dev/null +++ b/tests/test_adapter_session_empty_diag.py @@ -0,0 +1,168 @@ +"""Regression tests for persisting the agent exit code + log tail. + +A zero-turn rollout aborts as ``adapter_session_empty`` on the graceful success +path (the agent process launched but made no adapter calls). Previously the dump +could not say *why* turns=0 — ``_detached_run`` only logged the exit code/tail to +tail-only Modal logs that age out. These tests pin the wiring that now carries +``agent_exit_code`` (+ a failure-only ``agent_tail``) into the abort sample's +metadata so the dump self-explains (e.g. exit=137 -> OOM-killed). + +- ``test_detached_run_returns_exit_and_tail`` (unit): fake sandbox, asserts + ``_detached_run`` returns the parsed exit code with the log tail on a nonzero + exit and an empty tail on a clean exit. +- ``test_empty_session_carries_agent_diag`` (unit): ``_merge_samples`` empty + path copies ``agent_exit_code``/``agent_tail`` from ``reward_result.extra`` + onto the aborted sample's metadata, and stays clean when they're absent. +- ``test_abort_result_merges_extra`` (unit): ``_abort_result`` merges an extra + dict without clobbering ``abort_reason``. 
+""" +import asyncio +import sys +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from async_rl_research.agent.base import EXIT_BUDGET_EXCEEDED, AgentRunResult, AgentRuntime +from async_rl_research.environment.base import RewardResult +from async_rl_research.generate import _abort_result, _merge_samples +from slime.utils.types import Sample + +NUM_GPUS = 0 + + +def test_agent_run_result_defaults(): + assert AgentRunResult(0).tail == "" + assert AgentRunResult(137, "boom").exit_code == 137 + assert EXIT_BUDGET_EXCEEDED == -2 + + +# -------------------------------------------------------------------------- +# _detached_run: the new exit-code + failure-tail capture +# -------------------------------------------------------------------------- +class _FakeRuntime(AgentRuntime): + name = "fake" + adapter_cls = object # non-None is all __init_subclass__ requires + + async def run_agent(self, *a, **k): # unused; abstract method must exist + raise NotImplementedError + + +class _FakeSandbox: + """Minimal sandbox: the done-marker poll yields ``exit_code``; the tail + command yields ``tail_text``; everything else (rm/chmod/setsid) is a no-op.""" + + def __init__(self, exit_code, tail_text): + self._exit = exit_code + self._tail = tail_text + + async def write_file(self, path, body): + return None + + async def exec(self, cmd, check=False, timeout=None): + if "cat" in cmd and "_done" in cmd: # poll: `test -f .. 
&& cat ..` + return (0, str(self._exit), "") + if "tail -c 4000" in cmd: + return (0, self._tail, "") + return (0, "", "") # rm / chmod / setsid launch + + +def _run_detached(exit_code, tail_text): + rt = _FakeRuntime() + sb = _FakeSandbox(exit_code, tail_text) + return asyncio.run( + rt._detached_run( + sb, + workdir="/app", + command="true", + time_budget_sec=5, + poll_interval_sec=0.01, # don't sleep 5s in a test + ) + ) + + +def test_detached_run_returns_exit_and_tail(): + # nonzero exit -> tail captured and returned (137 == 128 + SIGKILL == OOM) + res = _run_detached(137, "fatal: Out of memory\nKilled") + assert isinstance(res, AgentRunResult) + assert res.exit_code == 137 + assert "Out of memory" in res.tail + + # clean exit -> no tail read (empty), per "tail only on failure" + ok = _run_detached(0, "should-not-be-read") + assert ok.exit_code == 0 + assert ok.tail == "" + + +# -------------------------------------------------------------------------- +# _merge_samples empty path carries the diag onto the abort +# -------------------------------------------------------------------------- +def _merge_empty(extra): + # On the empty path _merge_samples returns before touching `state`, so a + # dummy is safe. + return _merge_samples( + sample=Sample(index=0, prompt="x"), + state=None, + segments=[], + reward_result=RewardResult(reward=0.0, is_solved=False, extra=extra), + elapsed_sec=1.0, + instance_id="gravitational__teleport-deadbeef", + ) + + +def test_empty_session_carries_agent_diag(): + out = _merge_empty({"agent_exit_code": 137, "agent_tail": "Killed (OOM)", "harbor_steps_total": 1}) + assert len(out) == 1 + md = out[0].metadata + assert md["abort_reason"] == "adapter_session_empty" + assert md["agent_exit_code"] == 137 + assert md["agent_tail"] == "Killed (OOM)" + # only the agent diag is copied, not unrelated extra keys + assert "harbor_steps_total" not in md + + +def test_empty_session_without_diag_is_clean(): + # e.g. 
budget exhausted before any leg ran -> extra has no agent_* keys + out = _merge_empty({}) + md = out[0].metadata + assert md["abort_reason"] == "adapter_session_empty" + assert "agent_exit_code" not in md + assert "agent_tail" not in md + + +def test_venv_setup_hardens_pydantic_core_import(): + # Provisioning verifies the agent's deep import `import minisweagent.agents.default` + # and force-reinstalls (--reinstall --no-cache) to repair a partial wheel. + # `-P` is load-bearing: the check runs with cwd = the image WORKDIR (e.g. + # /testbed), so without it a task repo named like an agent dep (the pydantic + # SWE-gym tasks ship /testbed/pydantic/) shadows the venv and crashes the + # import -> a deterministic `exception:RuntimeError`. Verified fixed on + # pydantic-6104/6043/8511 (rc 1 -> 0). See profiles/provisioning_repro_pydantic.py. + from async_rl_research.agent.mini_swe_agent import _VENV_CHECK, _VENV_SETUP + + assert "minisweagent.agents.default" in _VENV_SETUP + assert "--reinstall" in _VENV_SETUP and "--no-cache" in _VENV_SETUP + assert "minisweagent.agents.default" in _VENV_CHECK + # the import checks MUST use `-P` (keep cwd/workdir off sys.path) + assert "-P -c 'import minisweagent.agents.default'" in _VENV_CHECK + assert _VENV_SETUP.count("-P -c 'import minisweagent.agents.default'") == 2 + # the repair script must be valid bash (it's assembled as a Python string) + import shutil + import subprocess + + bash = shutil.which("bash") + if bash: + r = subprocess.run([bash, "-n"], input=_VENV_SETUP, text=True, capture_output=True) + assert r.returncode == 0, f"_VENV_SETUP is not valid bash:\n{r.stderr}" + + +def test_abort_result_merges_extra(): + out = _abort_result(Sample(index=1, prompt="x"), "adapter_session_empty", extra={"agent_exit_code": 1}) + md = out[0].metadata + assert md["abort_reason"] == "adapter_session_empty" + assert md["agent_exit_code"] == 1 + # extra is optional: other abort reasons still work + out2 = _abort_result(Sample(index=2, 
prompt="x"), "boot_timeout:600s") + assert out2[0].metadata["abort_reason"] == "boot_timeout:600s" + assert "agent_exit_code" not in out2[0].metadata diff --git a/tests/test_policy_loss_modes.py b/tests/test_policy_loss_modes.py index 367e5de6c5..60c57a589d 100644 --- a/tests/test_policy_loss_modes.py +++ b/tests/test_policy_loss_modes.py @@ -5,7 +5,8 @@ from slime.utils.misc import load_function from slime.utils.ppo_utils import compute_policy_loss -from slime_plugins.losses.cispo import cispo_policy_loss_function, compute_policy_loss as compute_cispo_policy_loss +from slime_plugins.losses.cispo import cispo_policy_loss_function +from slime_plugins.losses.cispo import compute_policy_loss as compute_cispo_policy_loss @pytest.mark.unit diff --git a/tests/test_qwen_adapter_splice.py b/tests/test_qwen_adapter_splice.py new file mode 100644 index 0000000000..7d876f1df8 --- /dev/null +++ b/tests/test_qwen_adapter_splice.py @@ -0,0 +1,248 @@ +"""Regression tests for the QwenOpenAIAdapter prompt-splice fix. + +The qwen3_coder tool-call parser strips trailing whitespace from tool-call +arguments, so re-rendering the parsed assistant message is not token-identical to +the model's raw ``output_ids``. When ``_build_prompt`` re-rendered, that mismatch +made ``merge_turns`` log "prefix drift" and mask whole assistant turns out of +training. The fix splices the previous turn's raw ``output_ids`` into the next +prompt so prompt == training target by construction. + +- ``test_splice_invariant_no_drift`` (unit): fake tokenizer, asserts every + appended prompt starts with ``prompt_{i-1} + output_{i-1}`` and ``merge_turns`` + trains 100% of output tokens with zero drift warnings. +- ``test_real_template_trailing_whitespace_*`` (integration): real Qwen3.6 + tokenizer, reproduces the trailing-whitespace drift and shows the splice path + eliminates it while the old re-render path does not. 
+""" +import json +import logging +import sys +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from slime.agent.adapters import openai as O +from slime.agent.adapters.common import AdapterChain, render_token_ids +from slime.agent.trajectory import TurnRecord, merge_turns +from async_rl_research.agent.adapters.qwen import _build_prompt as splice_build_prompt +from async_rl_research.agent.adapters.qwen import _dictify_tool_arguments + +NUM_GPUS = 0 + +TOOLS = [ + { + "type": "function", + "function": { + "name": "bash", + "description": "Run a bash command", + "parameters": {"type": "object", "properties": {"command": {"type": "string"}}, "required": ["command"]}, + }, + } +] + + +class _Session: + """Minimal stand-in for the adapter Session that ``_select_kind`` needs.""" + + def __init__(self): + self.main = AdapterChain() + self.segments = [] + + +class _DriftCapture: + """Capture warnings emitted by ``merge_turns``.""" + + def __enter__(self): + self.msgs = [] + self._handler = logging.Handler() + self._handler.emit = lambda record: self.msgs.append(record.getMessage()) + self._logger = logging.getLogger("slime.agent.trajectory") + self._old_level = self._logger.level + self._logger.addHandler(self._handler) + self._logger.setLevel(logging.WARNING) + return self + + def __exit__(self, *exc): + self._logger.removeHandler(self._handler) + self._logger.setLevel(self._old_level) + + +def _drive_episode(tok, raw_outputs, build_prompt, echo_for): + """Run the adapter's per-turn loop. ``raw_outputs[i]`` is (output_ids, raw_text); + ``echo_for(raw_text)`` returns the OpenAI assistant message handed back to the + harness (models the parser). 
Returns (chain, merged_segment, drift_warnings).""" + s = _Session() + target = s.main + messages = [ + {"role": "system", "content": "You are a helpful assistant that can interact with a computer."}, + {"role": "user", "content": "Please solve the task. Use the bash tool."}, + ] + for i, (output_ids, raw_text) in enumerate(raw_outputs): + kind = O._select_kind(s, messages) + prompt_ids = build_prompt(target, messages, TOOLS, kind, tok) + target.turns.append( + TurnRecord( + prompt_ids=list(prompt_ids), + output_ids=list(output_ids), + finish_reason="tool_calls", + output_log_probs=[-0.01] * len(output_ids), + ) + ) + messages = messages + [ + echo_for(raw_text), + {"role": "tool", "tool_call_id": "call_0", "content": f"0\n{i}"}, + ] + with _DriftCapture() as cap: + seg = merge_turns(target.turns) + return target, seg, list(cap.msgs) + + +# --------------------------------------------------------------------------- # +# unit: fake tokenizer, splice invariant + clean merge # +# --------------------------------------------------------------------------- # + +_IM_START, _IM_END, _NL = 800, 801, 802 +_ROLE = {"system": 810, "user": 811, "assistant": 812, "tool": 813} + + +class _FakeTokenizer: + """Prefix-consistent chat template: each message renders to a fixed block + ending in ``<|im_end|>`` + ``\\n``; the generation prompt appends its own block. 
+ Models enough structure for ``_tool_continuation_ids``' sentinel-and-slice.""" + + def _block(self, m): + content = m.get("content") or "" + ctoks = [(ord(c) % 50) + 100 for c in str(content)[:5]] + calls = m.get("tool_calls") or [] + ctoks += [777] * len(calls) # a fixed per-tool-call marker + return [_IM_START, _ROLE.get(m.get("role"), 811)] + ctoks + [_IM_END, _NL] + + def apply_chat_template(self, messages, tools=None, tokenize=True, add_generation_prompt=True): + ids = [] + for m in messages: + ids += self._block(m) + if add_generation_prompt: + ids += [_IM_START, _ROLE["assistant"], 900] # "\n" + return ids + + def decode(self, ids, skip_special_tokens=False): + return "" + + +@pytest.mark.unit +def test_splice_invariant_no_drift(): + tok = _FakeTokenizer() + # three turns; output_ids are arbitrary but each ends in <|im_end|> + raw_outputs = [([700 + i, 701 + i, 702 + i, _IM_END], f"raw{i}") for i in range(3)] + + def echo_for(_raw): + return { + "role": "assistant", + "content": "ok", + "tool_calls": [{"id": "call_0", "type": "function", "function": {"name": "bash", "arguments": '{"command": "ls"}'}}], + } + + target, seg, warns = _drive_episode(tok, raw_outputs, splice_build_prompt, echo_for) + + assert warns == [], f"splice path should not drift, got: {warns}" + # every appended prompt must start with prompt_{i-1} + output_{i-1} + for i in range(1, len(target.turns)): + prev = list(target.turns[i - 1].prompt_ids) + list(target.turns[i - 1].output_ids) + assert target.turns[i].prompt_ids[: len(prev)] == prev, f"turn {i} prompt does not extend prior turn" + # 100% of generated output tokens are trained (mask=1); context is mask=0 + total_output = sum(len(o) for o, _ in raw_outputs) + assert sum(seg.loss_mask) == total_output + assert len(seg.loss_mask) == len(seg.response_ids) + + +# --------------------------------------------------------------------------- # +# integration: real Qwen3.6 template, trailing-whitespace drift # +# 
--------------------------------------------------------------------------- # + + +def _real_tokenizer(): + transformers = pytest.importorskip("transformers") + try: + return transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3.6-35B-A3B") + except Exception as exc: # offline + not cached, gated, etc. + pytest.skip(f"Qwen3.6 tokenizer unavailable: {exc}") + + +def _raw_output_text(reasoning, cmd): + # cmd carries a trailing '\n' -> the parser strips it -> re-render drifts + return ( + f"{reasoning}\n\n\n\n\n\n" + f"{cmd}\n\n\n<|im_end|>" + ) + + +_TRAJECTORY = [ + ("Check the python version.", 'python3 -c "\nimport sys\nprint(sys.version)\n"\n'), + ("Run the failing case.", 'cd /testbed && python3 -c "\nimport numpy as np\nprint(np.zeros((2,2)))\n"\n'), + ("Apply the fix and re-run.", 'cd /testbed && python3 -c "\nprint(1 + 1)\n"\n'), +] + + +def _old_build_prompt(target, messages, tools_schema, kind, tok): + """Pre-fix behavior: extend/replace, dict-ify, full re-render.""" + (O._extend_chat_messages if kind == "append" else O._replace_chat_messages)(target, messages, tools_schema) + _dictify_tool_arguments(target.chat_messages) + return render_token_ids(target, tok) + + +def _build_real_episode_inputs(tok): + raw_outputs = [] + echo_map = {} + for reasoning, cmd in _TRAJECTORY: + raw_text = _raw_output_text(reasoning, cmd) + raw_outputs.append((tok.encode(raw_text, add_special_tokens=False), raw_text)) + # model the qwen3_coder parser: tool value has its trailing whitespace + # stripped, and arguments are returned to the harness as a JSON string + # (exactly what _chat_message/_json_arguments produce). 
+ echo_map[raw_text] = { + "role": "assistant", + "content": None, + "reasoning_content": reasoning, + "tool_calls": [ + { + "id": "call_0", + "type": "function", + "function": {"name": "bash", "arguments": json.dumps({"command": cmd.rstrip()})}, + } + ], + } + return raw_outputs, (lambda raw: echo_map[raw]) + + +@pytest.mark.integration +def test_real_template_trailing_whitespace_old_path_drifts(): + tok = _real_tokenizer() + raw_outputs, echo_for = _build_real_episode_inputs(tok) + _, seg, warns = _drive_episode(tok, raw_outputs, _old_build_prompt, echo_for) + total_output = sum(len(o) for o, _ in raw_outputs) + # baseline: the old re-render path drifts and masks turns out of training + assert any("prefix drift" in w for w in warns), "expected the pre-fix path to drift" + assert sum(seg.loss_mask) < total_output, "expected the pre-fix path to mask some output" + + +@pytest.mark.integration +def test_real_template_trailing_whitespace_splice_is_clean(): + tok = _real_tokenizer() + raw_outputs, echo_for = _build_real_episode_inputs(tok) + target, seg, warns = _drive_episode(tok, raw_outputs, splice_build_prompt, echo_for) + total_output = sum(len(o) for o, _ in raw_outputs) + # the fix: no drift, every output token trained + assert warns == [], f"splice path must not drift, got: {warns}" + assert sum(seg.loss_mask) == total_output, "splice path must train 100% of output tokens" + # and every appended prompt contains the prior turn's raw output verbatim + for i in range(1, len(target.turns)): + prev = list(target.turns[i - 1].prompt_ids) + list(target.turns[i - 1].output_ids) + assert target.turns[i].prompt_ids[: len(prev)] == prev + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__, "-v"])) diff --git a/train.py b/train.py index 2404a0bbd1..e3c7f50dcd 100644 --- a/train.py +++ b/train.py @@ -1,3 +1,5 @@ +import os + import ray from slime.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models @@ -16,13 
+18,34 @@ def train(args): # need to initialize rollout manager first to calculate num_rollout rollout_manager, num_rollout_per_epoch = create_rollout_manager(args, pgs["rollout"]) - # Update primary W&B with SGLang metrics endpoint now that servers are up. - router_addr = ray.get(rollout_manager.get_metrics_router_addr.remote()) - update_tracking_open_metrics(args, router_addr) + # DEBUG (qwen3.6 resync probe): skip sglang-metrics wiring for the HF-dump probes. + _hf_probe = os.environ.get("SLIME_SAVE_HF_AND_EXIT") or os.environ.get("SLIME_SAVE_HF_AFTER_TRAIN") + if not _hf_probe: + # Update primary W&B with SGLang metrics endpoint now that servers are up. + router_addr = ray.get(rollout_manager.get_metrics_router_addr.remote()) + update_tracking_open_metrics(args, router_addr) # create the actor and critic models actor_model, critic_model = create_training_models(args, pgs, rollout_manager) + # DEBUG (qwen3.6 resync probe): the model is loaded at the real TP/EP layout; + # dump HF via the live resync converter and exit BEFORE any train — tests the + # gather+convert on the clean model only. + if os.environ.get("SLIME_SAVE_HF_AND_EXIT"): + actor_model.save_hf(0) + finish_tracking(args) + return + + # DEBUG (qwen3.6 resync probe): train EXACTLY ONE step on dumped rollout data + # (load_debug_rollout_data → no sglang), then dump HF — tests whether the + # optimizer/backward step corrupts the Megatron weights. + if os.environ.get("SLIME_SAVE_HF_AFTER_TRAIN"): + rollout_data_ref = ray.get(rollout_manager.generate.remote(0)) + ray.get(actor_model.async_train(0, rollout_data_ref)) + actor_model.save_hf(0) + finish_tracking(args) + return + if args.offload_rollout: ray.get(rollout_manager.onload_weights.remote()) diff --git a/uv.lock b/uv.lock new file mode 100644 index 0000000000..7518fc90bf --- /dev/null +++ b/uv.lock @@ -0,0 +1,3 @@ +version = 1 +revision = 3 +requires-python = ">=3.12"