Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/zh_cn/rl/advanced_tutorial/agent_loop.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ AgentLoop 输入和输出都是 `RolloutState`。如果后续使用预置 `RLTra
- `prompt_ids`:tokenized prompt,通常由 RL tokenize function 写入。
- `reward_model`:标签信息,例如 `{"ground_truth": ...}`,供 Judger 使用。
- `sample_params`:会在 `generate_group()` 中被 AgentLoop 的默认采样参数覆盖。
- `task_name`、`uid`、`session_uid` 等调度字段。
- `task_name`、`rollout_id`、`session_id` 等调度字段。

生成前,AgentLoop 需要确保:

Expand Down
8 changes: 4 additions & 4 deletions recipe/verl_agent/common/agent_loop_verl_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ async def generate(

# session_id is set in the VerlToolAgentLoop.generate_sample
# and ignore request_id generated by verl.ToolAgentLoop.run
session_uid = sampling_params.get("session_uid", -1)
session_id = sampling_params.get("session_id", -1)

rollout_state = RolloutState(
message=[],
tokens=prompt_ids,
session_uid=session_uid,
session_id=session_id,
sample_params=sample_params,
)

Expand Down Expand Up @@ -115,7 +115,7 @@ async def generate_sample(self, rollout_state: RolloutState) -> RolloutState:
repetition_penalty=sp.repetition_penalty,
logprobs=sp.return_logprob,
# session_id is used to identify the session in the server manager
session_uid=rollout_state.session_uid,
session_id=rollout_state.session_id,
)

input_kwargs = {
Expand All @@ -129,7 +129,7 @@ async def generate_sample(self, rollout_state: RolloutState) -> RolloutState:
except Exception as e:
rollout_state.status = Status.FAILED
rollout_state.error_msg = str(e)
self.logger.error(f"[VerlToolAgentLoop][{rollout_state.session_uid}] generate_sample failed: {e}")
self.logger.error(f"[VerlToolAgentLoop][{rollout_state.session_id}] generate_sample failed: {e}")
return rollout_state
# TODO: handle samples with corrupted tool tokens ?

Expand Down
6 changes: 3 additions & 3 deletions tests/rl/test_agent_loop_manager_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _generate(self, rollout_state: RolloutState) -> RolloutState:
rollout_state.status = Status.COMPLETED
rollout_state.response = "ok"
rollout_state.response_ids = [100, 101]
rollout_state.reward = {"score": 1.0 if int(rollout_state.uid or 0) % 2 == 0 else 0.5}
rollout_state.reward = {"score": 1.0 if int(rollout_state.rollout_id or 0) % 2 == 0 else 0.5}
return rollout_state


Expand Down Expand Up @@ -306,8 +306,8 @@ async def test_resume_keeps_unconsumed_completed_groups_available_to_get_batch(s
manager = self._build_disagg_async_manager()
buffered_group = [
RolloutState(
uid=9000 + idx,
message_uid=90,
rollout_id=9000 + idx,
group_id=90,
message=[{"role": "user", "content": "buffered"}],
prompt_ids=[1, 2, 3],
response="ok",
Expand Down
4 changes: 2 additions & 2 deletions tests/rl/test_multi_task_agent_loop_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ async def pause_produce(self, ctx) -> float:


class _FakeRolloutState:
def __init__(self, uid: str, group_generate_time_s: float):
self.uid = uid
def __init__(self, rollout_id: str, group_generate_time_s: float):
self.rollout_id = rollout_id
self.extra_fields = {GROUP_GENERATE_TIME_KEY: group_generate_time_s}


Expand Down
8 changes: 4 additions & 4 deletions tests/rl/test_prepare_train_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _state(
self,
*,
uid: int = 1,
message_uid: int = 1,
group_id: int = 1,
prompt_ids: list[int] | None = None,
response_ids: list[int] | torch.Tensor | None = None,
logprobs: list[float] | None = None,
Expand All @@ -58,9 +58,9 @@ def _state(
extra_fields: dict | None = None,
) -> RolloutState:
return RolloutState(
uid=uid,
message_uid=message_uid,
message=[{"role": "user", "content": f"prompt {message_uid}"}],
rollout_id=uid,
group_id=group_id,
message=[{"role": "user", "content": f"prompt {group_id}"}],
prompt_ids=prompt_ids if prompt_ids is not None else [10, 11, 12],
response=response,
response_ids=response_ids if response_ids is not None else [20, 21, 22],
Expand Down
32 changes: 16 additions & 16 deletions tests/rl/test_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def make_rollout_state(
reward_score: float | None = None,
) -> RolloutState:
return RolloutState(
uid=uid,
message_uid=uid,
rollout_id=uid,
group_id=uid,
message=[{"role": "user", "content": f"prompt {uid}"}],
prompt_ids=[uid],
tokens=[uid],
Expand Down Expand Up @@ -109,7 +109,7 @@ async def mock_pause():
sleep_by_id = sleep_by_id or {}

async def mock_gen(rs, **kwargs):
await asyncio.sleep(sleep_by_id.get(rs[0].message_uid, 0.0))
await asyncio.sleep(sleep_by_id.get(rs[0].group_id, 0.0))
for r in rs:
r.seq_staleness = kwargs.get("model_step", kwargs.get("train_step", 0))
r.status = Status.COMPLETED
Expand Down Expand Up @@ -246,7 +246,7 @@ async def test_sampler_with_replay_buffer(self):

# 场景 A: ReplayBuffer 为空,从 Dataloader 拿
data = await sampler.sample(task_name)
self.assertEqual(data[0].message_uid, 0)
self.assertEqual(data[0].group_id, 0)

# 场景 B: ReplayBuffer 有多个候选状态,按列表顺序优先拿
aborted_item = make_rollout_state(999, status=Status.ABORTED)
Expand All @@ -255,14 +255,14 @@ async def test_sampler_with_replay_buffer(self):
await self.replay_buffer.put([expired_item], task_name)

data = await sampler.sample(task_name, group_status=[Status.EXPIRED, Status.ABORTED])
self.assertEqual(data[0].message_uid, 1000)
self.assertEqual(data[0].group_id, 1000)

data = await sampler.sample(task_name, group_status=[Status.EXPIRED, Status.ABORTED])
self.assertEqual(data[0].message_uid, 999)
self.assertEqual(data[0].group_id, 999)

# 场景 C: ReplayBuffer 对应状态都为空,回退到 Dataloader
data = await sampler.sample(task_name, group_status=[Status.EXPIRED, Status.ABORTED])
self.assertEqual(data[0].message_uid, 1)
self.assertEqual(data[0].group_id, 1)

async def test_put_generated_group_only_validates_completed_group(self):
# 验证 ProduceContext 只对 completed group 执行业务过滤,aborted group 保持可重试状态。
Expand Down Expand Up @@ -345,19 +345,19 @@ async def test_sync_produce_strategy(self):
# 验证:ReplayBuffer 中应该有 2 条 COMPLETED 数据
final_data = await self.replay_buffer.get(10, task_name, Status.COMPLETED)
self.assertEqual(len(final_data), 2)
self.assertEqual(final_data[0][0].message_uid, 0)
self.assertEqual(final_data[1][0].message_uid, 1)
self.assertEqual(final_data[0][0].group_id, 0)
self.assertEqual(final_data[1][0].group_id, 1)

async def test_sync_produce_strategy_refills_after_filtered_and_aborted_groups(self):
# 验证 filtered / aborted group 不占用 completed quota,sync producer 会继续补齐训练 batch。
task_name = "test_sync_refill"

def is_valid_sample_fn(samples):
return samples[0].message_uid != 0
return samples[0].group_id != 0

async def mock_gen(rs, **kwargs):
for r in rs:
if r.message_uid == 1:
if r.group_id == 1:
r.status = Status.ABORTED
r.response = ""
r.response_ids = []
Expand Down Expand Up @@ -386,7 +386,7 @@ async def mock_gen(rs, **kwargs):
await strategy.produce_batch(ctx)
completed = await self.replay_buffer.get(10, task_name, Status.COMPLETED)
self.assertEqual(len(completed), 2)
self.assertEqual(sorted(group[0].message_uid for group in completed), [2, 3])
self.assertEqual(sorted(group[0].group_id for group in completed), [2, 3])
self.assertEqual(await self.replay_buffer.count(task_name, Status.FILTERED), 1)
self.assertEqual(await self.replay_buffer.count(task_name, Status.ABORTED), 1)

Expand All @@ -403,7 +403,7 @@ async def mock_gen(rs, **kwargs):
nonlocal call_count
call_count += 1
for r in rs:
if r.message_uid == 999:
if r.group_id == 999:
r.seq_staleness = 5
else:
r.seq_staleness = call_count
Expand Down Expand Up @@ -434,7 +434,7 @@ async def mock_gen(rs, **kwargs):
# 验证:ReplayBuffer 中应该有 4 条 COMPLETED 数据。
final_data = await self.replay_buffer.get(10, task_name, Status.COMPLETED)
self.assertEqual(len(final_data), 4)
self.assertEqual(sorted(group[0].message_uid for group in final_data), [0, 1, 2, 999])
self.assertEqual(sorted(group[0].group_id for group in final_data), [0, 1, 2, 999])

async def test_async_produce_strategy_accepts_context_entrypoint(self):
# 验证 AsyncProduceStrategy 通过 ProduceContext public 入口完成一次最小生产。
Expand Down Expand Up @@ -564,7 +564,7 @@ async def instrumented_sample(task_name, group_status=None):
# tail-batch 模式在本轮优先走 EXPIRED pool,并且不使用 over-sample 额外发射。
self.assertEqual(sampled_statuses, [[Status.EXPIRED, Status.ABORTED], [Status.EXPIRED, Status.ABORTED]])
completed = await self.replay_buffer.get(10, task_name, Status.COMPLETED)
self.assertEqual(sorted(group[0].message_uid for group in completed), [900, 901])
self.assertEqual(sorted(group[0].group_id for group in completed), [900, 901])
self.assertTrue(all(group[0].seq_staleness == 0 for group in completed))

async def test_async_produce_strategy_fails_fast_on_invalid_progress(self):
Expand Down Expand Up @@ -721,7 +721,7 @@ async def test_async_produce_strategy_does_not_reclaim_previous_call_pending(sel

final_data = await self.replay_buffer.get(10, task_name, Status.COMPLETED)
self.assertEqual(len(final_data), 1)
self.assertEqual(final_data[0][0].message_uid, 0)
self.assertEqual(final_data[0][0].group_id, 0)
for group in final_data:
self.assertIn("group_generate_time_s", group[0].extra_fields)
self.assertGreater(group[0].extra_fields["group_generate_time_s"], 0.0)
Expand Down
14 changes: 7 additions & 7 deletions tests/rl/test_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def make_rollout_state(
response_ids = list(response_ids) if response_ids is not None else [uid + 10]
logprobs = list(logprobs) if logprobs is not None else [0.1 for _ in response_ids]
return RolloutState(
uid=uid,
message_uid=uid,
rollout_id=uid,
group_id=uid,
message=[{"role": "user", "content": f"prompt {uid}"}],
prompt_ids=prompt_ids,
tokens=list(tokens) if tokens is not None else list(prompt_ids),
Expand All @@ -77,7 +77,7 @@ def make_rollout_state(


def group_uids(groups: list[list[RolloutState]]) -> list[list[int]]:
return [[state.uid for state in group] for group in groups]
return [[state.rollout_id for state in group] for group in groups]


async def save_and_resume(
Expand Down Expand Up @@ -236,7 +236,7 @@ async def test_common_refresh_staleness_contract(self):
assert await replay_buffer.count("task", Status.ABORTED) == 0
assert await replay_buffer.count("task", Status.EXPIRED) == 2
expired = await replay_buffer.get(2, "task", Status.EXPIRED)
assert {state.uid for group in expired for state in group} == {1, 2}
assert {state.rollout_id for group in expired for state in group} == {1, 2}

filtered_buffer = replay_buffer_config_cls().build()
await filtered_buffer.put([make_rollout_state(3, response_model_steps=[1])], "task")
Expand All @@ -256,8 +256,8 @@ async def test_common_refresh_staleness_contract(self):
assert await filtered_buffer.count("task", Status.EXPIRED) == 1
completed = await filtered_buffer.get(1, "task", Status.COMPLETED)
expired = await filtered_buffer.get(1, "task", Status.EXPIRED)
assert completed[0][0].uid == 3
assert expired[0][0].uid == 4
assert completed[0][0].rollout_id == 3
assert expired[0][0].rollout_id == 4

async def test_sync_get_returns_fifo_order(self):
# Sync replay 用于共卡按需生产,策略契约是同 task/status 下严格按入库顺序消费。
Expand Down Expand Up @@ -331,7 +331,7 @@ async def test_save_resume_preserves_rollout_state_fields(self):
# save/resume 应保留真实 RolloutState 字段,不再用 MockState.input_ids 代表训练样本内容。
def state_signature(state: RolloutState) -> tuple:
return (
state.uid,
state.rollout_id,
tuple(state.prompt_ids or []),
tuple(state.response_ids or []),
tuple(state.response_model_steps or []),
Expand Down
14 changes: 7 additions & 7 deletions tests/rl/test_rl_colocate_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@
class _FakeRolloutState:
def __init__(self, uid: int):
self.id = uid
self.uid = str(uid)
self.message_uid = uid
self.session_uid = uid
self.rollout_id = str(uid)
self.group_id = uid
self.session_id = uid
self.status = Status.INIT
self.finish_reason = None
self.seq_staleness = 0
Expand Down Expand Up @@ -227,7 +227,7 @@ def test_fit_uses_sync_interval_and_passes_rollout_model_step(self):
async def _produce_batch(batch_size, train_step, *, model_step):
produce_calls.append((batch_size, train_step, model_step))
return ProduceBatchResult(
rollout_states=[[SimpleNamespace(message_uid=train_step, uid=train_step)]]
rollout_states=[[SimpleNamespace(group_id=train_step, rollout_id=train_step)]]
)

trainer = self._make_trainer(
Expand Down Expand Up @@ -302,8 +302,8 @@ def test_debug_train_loads_batches_and_skips_weight_sync(self):
# 验证 debug_train 通过 fit() 读取落盘 batch,并只推进训练流程。
debug_dir = Path(self.temp_dir.name) / "debug_train"
debug_dir.mkdir()
torch.save([[SimpleNamespace(uid=1, message_uid=1)]], debug_dir / "debug_rollout_1.pt")
torch.save([[SimpleNamespace(uid=2, message_uid=2)]], debug_dir / "debug_rollout_2.pt")
torch.save([[SimpleNamespace(rollout_id=1, group_id=1)]], debug_dir / "debug_rollout_1.pt")
torch.save([[SimpleNamespace(rollout_id=2, group_id=2)]], debug_dir / "debug_rollout_2.pt")

trainer = self._make_trainer(MagicMock(), total_train_steps=2)
trainer._debug_train = True
Expand Down Expand Up @@ -333,7 +333,7 @@ def train_one_batch(train_batch, train_step, step_timer_dict, **kwargs):

trainer.fit()

self.assertEqual([(step, batch[0][0].uid) for step, batch in captured_batches], [(1, 1), (2, 2)])
self.assertEqual([(step, batch[0][0].rollout_id) for step, batch in captured_batches], [(1, 1), (2, 2)])
self.assertEqual(trainer._cur_step, 2)

def test_debug_rollout_fit_serializes_object_refs_and_debug_train_fit_restores_them(self):
Expand Down
16 changes: 8 additions & 8 deletions tests/rl/test_rl_disaggregated_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def _minimal_train_info(self, *, training_samples: int, training_tokens: int, be

def test_fit_persists_checkpoint_for_completed_model_step(self):
# 验证 checkpoint 以 fit 完成的 model_step 为准,并通过 async manager.save 落盘。
train_sample = SimpleNamespace(message_uid=1, uid=1)
train_sample = SimpleNamespace(group_id=1, rollout_id=1)
manager = _FakeManager([ProduceBatchResult(rollout_states=[[train_sample]])])
manager.save = AsyncMock()
trainer = self._make_trainer(manager)
Expand Down Expand Up @@ -212,7 +212,7 @@ def test_fit_persists_checkpoint_for_completed_model_step(self):

def test_fit_retries_same_step_after_empty_expired_skip(self):
# 验证空 expired batch 只同步上一版模型,不推进 train_step,并重试同一步。
train_sample = SimpleNamespace(message_uid=1, uid=1)
train_sample = SimpleNamespace(group_id=1, rollout_id=1)
manager = _FakeManager(
[
ProduceBatchResult(rollout_states=[], status=ProduceBatchStatus.EXPIRED_BATCH),
Expand All @@ -238,7 +238,7 @@ def test_fit_retries_same_step_after_empty_expired_skip(self):

def test_fit_trains_non_empty_expired_batch_then_syncs_current_step(self):
# 验证非空 expired batch 仍会训练,并用当前完成的 model_step 恢复 producer。
train_sample = SimpleNamespace(message_uid=1, uid=1)
train_sample = SimpleNamespace(group_id=1, rollout_id=1)
manager = _FakeManager(
[ProduceBatchResult(rollout_states=[[train_sample]], status=ProduceBatchStatus.EXPIRED_BATCH)]
)
Expand All @@ -253,7 +253,7 @@ def test_fit_trains_non_empty_expired_batch_then_syncs_current_step(self):

def test_fit_keeps_background_producer_running_while_training_blocks(self):
# 验证非共卡训练阻塞在同步训练 batch 时,后台 producer 仍能继续调度。
train_sample = SimpleNamespace(message_uid=1, uid=1)
train_sample = SimpleNamespace(group_id=1, rollout_id=1)
training_started = threading.Event()
producer_ticked = threading.Event()
manager = _TickingManager(
Expand All @@ -280,7 +280,7 @@ def blocking_train_one_batch(*args, **kwargs):

def test_fit_observes_background_producer_failure_before_training_waited_batch(self):
# 后台 producer 异常是终止性失败;前台 get_batch 还在等待时必须立刻暴露,不能先训练随后才失败。
train_sample = SimpleNamespace(message_uid=1, uid=1)
train_sample = SimpleNamespace(group_id=1, rollout_id=1)
manager = _FailingProducerManager([ProduceBatchResult(rollout_states=[[train_sample]])])
trainer = self._make_trainer(manager)

Expand All @@ -292,9 +292,9 @@ def test_fit_observes_background_producer_failure_before_training_waited_batch(s

def test_fit_runs_eval_before_reset_and_stops_producer(self):
# 验证 eval 在 producer 恢复前执行,避免生产侧提前抢占 rollout 资源。
# 确定性排序依赖 RolloutState 的 message_uiduid,测试用轻量对象模拟即可。
train_sample = SimpleNamespace(message_uid=1, uid=1)
eval_sample = SimpleNamespace(message_uid=2, uid=2)
# 确定性排序依赖 RolloutState 的 group_idrollout_id,测试用轻量对象模拟即可。
train_sample = SimpleNamespace(group_id=1, rollout_id=1)
eval_sample = SimpleNamespace(group_id=2, rollout_id=2)
manager = _FakeManager(
[ProduceBatchResult(rollout_states=[[train_sample]], status=ProduceBatchStatus.NORMAL)]
)
Expand Down
2 changes: 1 addition & 1 deletion tests/rl/test_rl_trainer_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _generate(self, rollout_state):
rollout_state.status = Status.COMPLETED
rollout_state.response = "ok"
rollout_state.response_ids = [100, 101]
reward_score = 1.0 if int(rollout_state.uid) % 2 == 0 else 0.5
reward_score = 1.0 if int(rollout_state.rollout_id) % 2 == 0 else 0.5
rollout_state.reward = {"score": reward_score}
return rollout_state

Expand Down
Loading
Loading