-
Notifications
You must be signed in to change notification settings - Fork 896
[coding-agent-rl] Refactor coding-agent RL: turn-node TrajectoryManager + pluggable harness layer #2005
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
jingshenghang
wants to merge
66
commits into
THUDM:main
Choose a base branch
from
jingshenghang:refactor_trajectory_manager
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
[coding-agent-rl] Refactor coding-agent RL: turn-node TrajectoryManager + pluggable harness layer #2005
Changes from 47 commits
Commits
Show all changes
66 commits
Select commit
Hold shift + click to select a range
8646a8d
refactor(agent): replace segment-based trajectory with turn-node Traj…
f10552f
refactor(agent): drop dead code left by trajectory-manager refactor
ba557ae
refactor(agent): port anthropic + trajectory_manager from trajectory-…
b8a5cd4
refactor(agent): rewrite openai.py for Codex CLI + TrajectoryManager
44e7999
refactor(agent): centralize snapshot-threshold default + filter acces…
5e59c25
feat(agent): add fork-merge rescue for short assistant rewrites
1d459bd
fix(agent): replace sib.messages on fork-merge rescue
8d6fdc9
refactor(agent): drop billing-header scrub now cc emits no header
2b4efc4
refactor(agent): migrate TrajectoryManager and adapters (v4)
6f95f18
refactor(agent): drop drift fork/merge params, strict exact-prefix li…
4fcbb24
feat(agent): assistant-rewrite merge to de-dilute reward
0f1c5aa
feat(agent): TrajectoryManager re-accepts fork_merge_max_response_tokens
4732702
docs(test): spec for TrajectoryManager e2e test script
4ea5d5a
test(agent): end-to-end TrajectoryManager test matrix (append_turn/ge…
cbee0de
test(agent): dump raw append_turn inputs in e2e readable output
a11c9b4
test(agent): 1.7 now shows token drift's effect on the linearized sample
f1b1792
test(agent): every case prints [samples] + mask info
0be5966
test(agent): make reward-split explicit in dump + conservation assert
fe7692a
test(agent): set every case's input reward to 1.0
2ee5ee8
test(agent): render whitespace in token labels as visible ␣
8eed4dd
test(agent): add Group 4 (boundary/defensive/feature) -> 98% coverage
cfad29d
test(agent): assert full output via golden token+loss strings
624b927
refactor(agent): drift-tolerant trajectory linearization
054f89a
chore(agent): untrack e2e test design doc and trajectory_manager tests
bc3d304
chore(agent): drop comments/docstrings in generate.py and TurnRecord
ece9007
docs(agent): tighten trajectory_manager comments to why-not-what
6565fe1
refactor(agent): slim adapters and trajectory_manager, add e2e test
aceb162
refactor(agent): replace segment-based trajectory with turn-node Traj…
0193103
refactor(agent): drop dead code left by trajectory-manager refactor
d84fe47
refactor(agent): port anthropic + trajectory_manager from trajectory-…
b7550cc
refactor(agent): rewrite openai.py for Codex CLI + TrajectoryManager
0b2576d
refactor(agent): centralize snapshot-threshold default + filter acces…
4ea8bb9
feat(agent): add fork-merge rescue for short assistant rewrites
8c5fe1a
fix(agent): replace sib.messages on fork-merge rescue
18ea895
refactor(agent): drop billing-header scrub now cc emits no header
f733bef
fix(agent): use rollout_id after upstream group_id revert
ad122f4
refactor(agent): assert base_sample in get_trajectory instead of defa…
b05dfdb
Merge branch 'trajectory-manager-migration-v4' into refactor_trajecto…
18772b0
fix(agent): mask entire drifted response span in B1 replace
d61b419
refactor(agent): extract shared adapter pipeline; move TurnRecord + a…
71f5511
refactor(agent): symmetric adapters + de-scaffold common
d8a6b78
refactor(agent): drop tools param; dict== mount-point matching
d73bb4a
refactor(agent): collapse Node turn_* into turn; rename Node->Message…
fc72355
perf(agent): chunked common-prefix; rename _lcp_len -> _common_prefix…
54b0846
refactor(agent): trajectory manager cleanup + adapter tweaks
804bcd1
refactor(agent): fold classify_drift into _SampleBuilder as classify_…
46d09de
chore(agent): untrack test_trajectory_manager.py
36fa60e
refactor(agent): gate fork on full output_ids length
76cf2cb
Refactor coding-agent RL harness: pluggable agent harness layer + SWE…
0da1b3c
Merge pull request #4 from jingshenghang/refactor_harness
jingshenghang 7e1af92
refactor(agent): unify env vars under SLIME_AGENT_*/ADAPTER_*, tidy h…
76c9b0a
refactor(agent): make harnesses SingletonMeta-backed, drop module-lev…
d2897af
refactor(agent): rename debug hook to debug_callback, pass TurnRecord…
336985c
refactor(agent): tidy adapters, trajectory_manager, harness and SWE e…
36608fb
refactor(agent): drop optional sandbox metadata file/json, keep image…
da0a4fd
refactor(agent): reorganize agent tests under tests/test_agent/, tidy…
414a2ae
refactor(agent): simplify adapters, harness and sandbox
3cd0e30
docs(swe): move fan-out/sandbox notes from launcher into README
b40db1d
Merge remote-tracking branch 'origin/main' into refactor_trajectory_m…
94beeda
fix(agent): lazy-import load_tokenizer so CPU agent test needs no tra…
4405c1c
Revert "fix(agent): lazy-import load_tokenizer so CPU agent test need…
942358d
test(agent): stub transformers before importing generate in CPU rollo…
3772db3
test(agent): shim asyncio.timeout for py3.10 CI in CPU rollout test
f0f40b4
ci: rename agent-adapter-test job to agent-test (covers more than ada…
b2d3b38
refactor(agent): rename trajectory_manager.py back to trajectory.py
35c9222
refactor(agent): drop duplicate trajectory_manager.py left by rename
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,8 +9,8 @@ | |
| 1. ``sandbox.run_claude_code`` prepares the agent sandbox and runs claude-code. | ||
| 2. ``sandbox.git_diff`` captures the model-produced patch. | ||
| 3. ``sandbox.evaluate`` scores that patch in a second clean sandbox. | ||
| 4. ``_merge_samples`` combines reward + adapter ``TokenSegment``s, | ||
| delegating segment-to-``Sample`` fan-out to ``slime.agent.trajectory``. | ||
| 4. ``adapter.finish_session`` drains the session tree into reward-weighted | ||
| ``Sample`` objects with ``.response`` already decoded; ``generate`` logs. | ||
|
|
||
| All sandbox-side details live in ``sandbox.py``; the LLM plumbing | ||
| (Anthropic <-> SGLang /generate, token capture, 3-kind segment split) uses | ||
|
|
@@ -49,17 +49,15 @@ | |
| import secrets | ||
| import time | ||
| import traceback | ||
| from dataclasses import dataclass | ||
| from typing import Any | ||
|
|
||
| from slime.agent.adapters import AnthropicAdapter | ||
| from slime.agent.trajectory import TokenSegment, fan_out_sample_segments | ||
| from slime.agent.aiohttp_threaded import FilteredAccessLogger, run_app_in_thread | ||
| from slime.utils.misc import SingletonMeta | ||
| from slime.utils.processing_utils import load_tokenizer | ||
| from slime.utils.types import Sample | ||
|
|
||
| from . import sandbox | ||
| from .aiohttp_threaded import run_app_in_thread | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
@@ -97,11 +95,13 @@ def __init__(self, args) -> None: | |
| "Without it the sandbox cannot dial back and the rollout will " | ||
| "silently abort." | ||
| ) | ||
| fork_merge_threshold = int(v) if (v := os.environ.get("SLIME_FORK_MERGE_MAX_RESPONSE_TOKENS")) else None | ||
| self.adapter = AnthropicAdapter( | ||
| tokenizer=self.tokenizer, | ||
| sglang_url=sglang_url, | ||
| tool_parser=self.tool_parser, | ||
| reasoning_parser=self.reasoning_parser, | ||
| fork_threshold_tokens=fork_merge_threshold, | ||
| ) | ||
| # handler_cancellation=True so a client disconnect cancels the handler | ||
| # coroutine, arming the fire-and-forget /abort_request inside the | ||
|
|
@@ -113,7 +113,10 @@ def __init__(self, args) -> None: | |
| host=SHIM_BIND_HOST, | ||
| port=SHIM_PORT, | ||
| thread_name="anthropic-adapter", | ||
| runner_kwargs={"handler_cancellation": True}, | ||
| runner_kwargs={ | ||
| "handler_cancellation": True, | ||
| "access_log_class": FilteredAccessLogger, | ||
| }, | ||
| ) | ||
| self.adapter_url = f"http://{public_host}:{self.app_handle.port}" | ||
| logger.info( | ||
|
|
@@ -127,18 +130,8 @@ def __init__(self, args) -> None: | |
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Trajectory -> Sample conversion | ||
| # adapter.finish_session() returns TokenSegments. One trajectory yields >=1 | ||
| # segments because the agent may compact + reset mid-run; trajectory.py handles | ||
| # the mechanical segment -> Sample fan-out. | ||
| # Session setup | ||
| # --------------------------------------------------------------------------- | ||
| @dataclass(frozen=True) | ||
| class RewardResult: | ||
| reward: float | ||
| is_solved: bool | ||
| applied_cleanly: bool | ||
|
|
||
|
|
||
| def _start_session( | ||
| state: _State, | ||
| sample: Sample, | ||
|
|
@@ -164,55 +157,11 @@ def _start_session( | |
| return session_id | ||
|
|
||
|
|
||
| def _merge_samples( | ||
| *, | ||
| sample: Sample, | ||
| state: _State, | ||
| segments: list[TokenSegment], | ||
| reward_result: RewardResult, | ||
| elapsed_sec: float, | ||
| instance_id: str, | ||
| ): | ||
| if not segments: | ||
| return _abort_result(sample, "adapter_session_empty") | ||
|
|
||
| trajectory_metadata = { | ||
| **(sample.metadata or {}), | ||
| "instance_id": instance_id, | ||
| "is_solved": reward_result.is_solved, | ||
| "applied_cleanly": reward_result.applied_cleanly, | ||
| "elapsed_sec": elapsed_sec, | ||
| } | ||
|
|
||
| # All K samples share rollout_id so the loss reducer counts this | ||
| # trajectory once. | ||
| fanned = fan_out_sample_segments( | ||
| sample, | ||
| segments, | ||
| reward_result.reward, | ||
| state.tokenizer, | ||
| metadata=trajectory_metadata, | ||
| ) | ||
| if not fanned: | ||
| raise ValueError("fan-out produced no samples") | ||
|
|
||
| logger.info( | ||
| "[coding_agent_rl] %s: reward=%.2f solved=%s applied=%s elapsed=%.1fs segments=%d", | ||
| instance_id, | ||
| reward_result.reward, | ||
| reward_result.is_solved, | ||
| reward_result.applied_cleanly, | ||
| elapsed_sec, | ||
| len(fanned), | ||
| ) | ||
| return fanned | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Main per-sample agent function | ||
| # | ||
| # The four calls inside the timeout are the high-level rollout recipe: | ||
| # run_claude_code -> git_diff -> sandbox.evaluate -> merge_samples. | ||
| # run_claude_code -> git_diff -> sandbox.evaluate -> finish_session. | ||
| # --------------------------------------------------------------------------- | ||
| async def generate(args, sample: Sample, sampling_params: dict[str, Any]): | ||
| """Per-sample agent function with wall-clock guard. See | ||
|
|
@@ -249,20 +198,26 @@ async def generate(args, sample: Sample, sampling_params: dict[str, Any]): | |
| pre_commands=md["pre_commands"], | ||
| timeout_sec=SWE_EVAL_TIMEOUT_SEC, | ||
| ) | ||
| reward_result = RewardResult( | ||
| samples = await state.adapter.finish_session( | ||
| session_id, | ||
| base_sample=sample, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 或者我们统一都存成 base_sample 也行
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已统一修改为 base_sample |
||
| reward=float(reward), | ||
| is_solved=bool(is_solved), | ||
| applied_cleanly=bool(applied_cleanly), | ||
| ) | ||
| segments = await state.adapter.finish_session(session_id) | ||
| return _merge_samples( | ||
| sample=sample, | ||
| state=state, | ||
| segments=segments, | ||
| reward_result=reward_result, | ||
| elapsed_sec=time.time() - t0, | ||
| instance_id=instance_id, | ||
| if not samples: | ||
| return _abort_result(sample, "adapter_session_empty") | ||
|
|
||
| # finish_session already linearized, reward-weighted and decoded | ||
| # each segment's .response; here we only log a summary. | ||
| logger.info( | ||
| "[coding_agent_rl] %s: reward=%.2f solved=%s applied=%s elapsed=%.1fs segments=%d", | ||
| instance_id, | ||
| float(reward), | ||
| bool(is_solved), | ||
| bool(applied_cleanly), | ||
| time.time() - t0, | ||
| len(samples), | ||
| ) | ||
| return samples | ||
|
|
||
| except asyncio.TimeoutError: | ||
| _log_timeout_diagnostic(t0) | ||
|
|
@@ -347,7 +302,9 @@ def _coerce_prompt(prompt) -> str: | |
| return "" | ||
|
|
||
|
|
||
| def _abort(sample: Sample, reason: str) -> Sample: | ||
| def _abort_result(sample: Sample, reason: str) -> list[Sample]: | ||
| """Mark ``sample`` aborted in place and return it in the list shape this | ||
| fan-out generate function always yields.""" | ||
| sample.tokens = [0, 0] | ||
| sample.response = "" | ||
| sample.response_length = 1 | ||
|
|
@@ -356,9 +313,4 @@ def _abort(sample: Sample, reason: str) -> Sample: | |
| sample.status = Sample.Status.ABORTED | ||
| sample.metadata = {**(sample.metadata or {}), "abort_reason": reason} | ||
| logger.warning("[coding_agent_rl] aborted: %s", reason) | ||
| return sample | ||
|
|
||
|
|
||
| def _abort_result(sample: Sample, reason: str): | ||
| """Return a uniform list shape for this fan-out generate function.""" | ||
| return [_abort(sample, reason)] | ||
| return [sample] | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
貌似没有别的地方用到
access_log_class了?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"access_log_class": FilteredAccessLogger这个对应的FilteredAccessLogger在 aiohttp_threaded.py 里面有定义,是让 adaptor 只打印异常请求(回复不是 200,或者请求超过 120s),避免正常请求日志刷屏