Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/evaluation.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ By default, scoring runs interleaved with generation. Use `--no-interleave-scori

The `--max-retries` flag enables automatic retry with exponential backoff when rollouts fail due to transient infrastructure errors (e.g., sandbox timeouts, API failures).

When a rollout retries, its saved output includes a `retry` block. It records attempt count, exhaustion status, elapsed retry time, and the retryable errors that triggered each attempt. This keeps transient provider or sandbox instability visible in `results.jsonl`, not only in logs.
Comment thread
cursor[bot] marked this conversation as resolved.

The `--num-workers` flag controls how many worker processes the env server spawns. Each worker owns its own environment instance and runs rollouts independently. The default `auto` scales with concurrency.

### Display
Expand Down
3 changes: 3 additions & 0 deletions docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,15 @@ class RolloutOutput(dict):
error: str | None
stop_condition: str | None
token_usage: TokenUsage
retry: RetryData
trajectory: list[TrajectoryStep]
tool_defs: list[Tool] | None
```

Serialized output from a rollout. This is a `dict` subclass that provides typed access to known fields while supporting arbitrary additional fields from `state_columns`. All values must be JSON-serializable. Used in `GenerateOutputs` and for saving results to disk.

The `retry` field is present only when `max_retries` caused at least one retry. It records attempts, retry count, exhaustion status, elapsed retry time, and per-attempt error summaries.

### TrajectoryStep

```python
Expand Down
33 changes: 33 additions & 0 deletions tests/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,12 @@ async def test_retry_after_retryable_error(self, mock_client, make_input):

assert states[0].get("error") is None
assert env.call_counts[0] == 3
assert states[0]["retry"]["attempts"] == 3
assert states[0]["retry"]["retry_count"] == 2
assert states[0]["retry"]["exhausted"] is False
assert [
event["error"] for event in states[0]["retry"]["events"]
] == ["InfraError", "InfraError"]

@pytest.mark.asyncio
async def test_no_retry_after_non_retryable_error(self, mock_client, make_input):
Expand Down Expand Up @@ -696,6 +702,33 @@ async def test_error_in_state_after_max_retries_exhausted(
assert rollout_outputs[0].get("error") is not None
error_data = rollout_outputs[0]["error"]
assert "InfraError" == error_data["error"]
retry = rollout_outputs[0]["retry"]
assert retry["attempts"] == 3
assert retry["retry_count"] == 2
assert retry["exhausted"] is True
assert retry["events"][-1]["error"] == "InfraError"

@pytest.mark.asyncio
async def test_group_retry_metadata_is_attached_to_each_rollout(
self, mock_client, make_input
):
"""Grouped scoring should expose retry history on every returned state."""
dataset = Dataset.from_dict({"question": ["test"], "answer": ["test"]})
env = RetryCounterEnv(
fail_count=1, dataset=dataset, parser=Parser(), rubric=Rubric()
)

inputs = [make_input(example_id=0), make_input(example_id=0)]
outputs = await env.generate(
inputs, client=mock_client, model="test-model", max_retries=2
)

rollout_outputs = outputs["outputs"]
assert len(rollout_outputs) == 2
assert env.call_counts[0] == 4
assert all(output.get("error") is None for output in rollout_outputs)
assert all(output["retry"]["attempts"] == 2 for output in rollout_outputs)
assert all(output["retry"]["retry_count"] == 1 for output in rollout_outputs)

@pytest.mark.asyncio
async def test_retries_serialized_infra_error_subclass(self):
Expand Down
23 changes: 23 additions & 0 deletions tests/test_save_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,29 @@ def test_states_to_outputs(self, make_state):
assert result[0].get("foo") == "bar" # custom field from make_state fixture
assert result[0]["reward"] == 1.0

def test_states_to_outputs_includes_retry_metadata(self, make_state):
state = make_state()
state["retry"] = {
"attempts": 3,
"max_retries": 2,
"retry_count": 2,
"exhausted": True,
"elapsed_seconds": 0.42,
"events": [
{
"attempt": 1,
"error": "InfraError",
"message": "sandbox timed out",
"error_chain_str": "InfraError",
"next_sleep_seconds": 1.0,
}
],
}

outputs = states_to_outputs([state], state_columns=[])

assert outputs[0]["retry"] == state["retry"]

def test_states_to_outputs_requires_example_id(self, make_state):
state = make_state()
del state["example_id"]
Expand Down
21 changes: 20 additions & 1 deletion verifiers/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,24 @@ class ErrorData(TypedDict):
error_chain_str: str


class RetryEvent(TypedDict):
attempt: int
error: str
message: str
error_chain_str: str
next_sleep_seconds: NotRequired[float]
state_index: NotRequired[str]


class RetryData(TypedDict):
attempts: int
max_retries: int
retry_count: int
exhausted: bool
elapsed_seconds: float
events: list[RetryEvent]


class RolloutOutput(dict):
"""Serialized output from a rollout (mirrors RolloutInput).
Expand All @@ -409,7 +427,7 @@ class RolloutOutput(dict):
Required fields: example_id, prompt, completion, reward, timing,
is_completed, is_truncated, metrics
Optional fields: answer, info, error, stop_condition, trajectory, tool_defs,
token_usage
token_usage, retry
Additional fields: arbitrary serializable state_columns
"""

Expand All @@ -430,6 +448,7 @@ class RolloutOutput(dict):
trajectory: list["TrajectoryStep"]
tool_defs: list[Tool]
token_usage: TokenUsage
retry: RetryData


_MISSING = object()
Expand Down
105 changes: 93 additions & 12 deletions verifiers/utils/async_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import inspect
import logging
import time
from collections import deque
from collections.abc import Mapping
from collections.abc import Coroutine
Expand All @@ -13,6 +14,7 @@

import verifiers as vf
from verifiers.utils.error_utils import ErrorChain
from verifiers.utils.error_utils import error_data
from verifiers.utils.error_utils import error_from_data, is_error_data
from verifiers.utils.logging_utils import print_time

Expand Down Expand Up @@ -161,6 +163,31 @@ def maybe_retry(
if max_retries <= 0:
return func

retry_started_at = 0.0
retry_events: list[dict[str, Any]] = []
last_result = None

def summarize_error(error: BaseException | Mapping[str, Any]) -> dict[str, str]:
"""Build the JSON-safe part of a retry event."""
if isinstance(error, Mapping) and is_error_data(error):
return {
"error": error["error"],
"message": error["message"],
"error_chain_str": error["error_chain_str"],
}
if isinstance(error, BaseException):
data = error_data(error)
return {
"error": data["error"],
"message": data["message"],
"error_chain_str": data["error_chain_str"],
}
return {
"error": type(error).__name__,
"message": str(error),
"error_chain_str": type(error).__name__,
}

def reraise_one(err, error_types: tuple[type[Exception], ...]):
if not err:
return
Expand All @@ -172,6 +199,19 @@ def reraise_one(err, error_types: tuple[type[Exception], ...]):
if isinstance(rebuilt, error_types):
raise rebuilt

def iter_state_errors(result) -> list[tuple[str, Mapping[str, Any] | BaseException]]:
if isinstance(result, dict):
err = result.get("error")
return [("", err)] if err else []
if isinstance(result, list):
errors = []
for index, state in enumerate(result):
err = state.get("error")
if err:
errors.append((str(index), err))
return errors
return []

def reraise_error_from_state(result, error_types: tuple[type[Exception], ...]):
"""Re-raise specified errors from state(s) to trigger tenacity retry."""
if isinstance(result, dict):
Expand All @@ -180,25 +220,51 @@ def reraise_error_from_state(result, error_types: tuple[type[Exception], ...]):
for state in result:
reraise_one(state.get("error"), error_types)

def attach_retry_info(result, *, exhausted: bool = False) -> None:
if not retry_events:
return
attempts = len(retry_events) + (0 if exhausted else 1)
summary = {
"attempts": attempts,
"max_retries": max_retries,
"retry_count": min(max_retries, max(0, attempts - 1)),
"exhausted": exhausted,
"elapsed_seconds": time.time() - retry_started_at,
"events": retry_events,
}
if isinstance(result, dict):
result["retry"] = summary
elif isinstance(result, list):
for state in result:
state["retry"] = summary

def begin_retry_call(retry_state: tc.RetryCallState) -> None:
nonlocal last_result, retry_events, retry_started_at
if retry_state.attempt_number == 1:
last_result = None
retry_events = []
retry_started_at = time.time()

def log_retry(retry_state: tc.RetryCallState) -> None:
"""Log a warning with the exception and the number of attempts."""
caller = retry_state.fn.__name__ if retry_state.fn else "unknown function"
error_chain = (
repr(
ErrorChain(
retry_state.outcome.exception() or Exception("Unknown exception")
)
)
if retry_state.outcome
else None
)
error = retry_state.outcome.exception() if retry_state.outcome else None
error_chain = repr(ErrorChain(error or Exception("Unknown exception")))
next_action = retry_state.next_action.sleep if retry_state.next_action else 0
if retry_events and retry_events[-1]["attempt"] == retry_state.attempt_number:
retry_events[-1]["next_sleep_seconds"] = next_action
else:
retry_events.append(
{
"attempt": retry_state.attempt_number,
"next_sleep_seconds": next_action,
**summarize_error(error or Exception("Unknown exception")),
}
)
logger.warning(
f"Caught {error_chain} in {caller}. Retrying in {print_time(next_action)} (retry {retry_state.attempt_number}/{max_retries})"
)

last_result = None

def return_last_result(retry_state: tc.RetryCallState):
"""Return the last result when retries are exhausted (instead of raising)."""
caller = retry_state.fn.__name__ if retry_state.fn else "unknown function"
Expand All @@ -215,13 +281,27 @@ def return_last_result(retry_state: tc.RetryCallState):
f"Retries exhausted for {caller} after {max_retries} attempts. "
f"Last error: {error_chain}. Continuing with error in state."
)
if last_result is not None:
attach_retry_info(last_result, exhausted=True)
return last_result

async def wrapper(*args, **kwargs):
nonlocal last_result
result = await func(*args, **kwargs)
last_result = result # store result
reraise_error_from_state(result, error_types)
state_errors = iter_state_errors(result)
try:
reraise_error_from_state(result, error_types)
except error_types as error:
retry_events.append(
{
"attempt": len(retry_events) + 1,
"state_index": state_errors[0][0] if state_errors else "",
**summarize_error(error),
}
Comment thread
cursor[bot] marked this conversation as resolved.
)
raise
attach_retry_info(result)
return result

wrapper.__name__ = getattr(func, "__name__", "unknown")
Expand All @@ -231,6 +311,7 @@ async def wrapper(*args, **kwargs):
retry=tc.retry_if_exception_type(error_types),
stop=tc.stop_after_attempt(max_retries + 1),
wait=tc.wait_exponential_jitter(initial=initial, max=max_wait),
before=begin_retry_call,
before_sleep=log_retry,
retry_error_callback=return_last_result,
reraise=True,
Expand Down
5 changes: 5 additions & 0 deletions verifiers/utils/save_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,11 @@ def state_to_output(
output.pop("answer")
if "info" in output and not output["info"]:
output.pop("info")
retry = state.get("retry")
if retry is not None:
if not is_json_serializable(retry):
raise ValueError("state.retry must be JSON-serializable.")
output["retry"] = retry
# flatten metrics to top-level keys (backwards compatibility)
state_metrics = state.get("metrics") or {}
for k, v in state_metrics.items():
Expand Down