Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/evaluation.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ By default, scoring runs interleaved with generation. Use `--no-interleave-scori

The `--max-retries` flag enables automatic retry with exponential backoff when rollouts fail due to transient infrastructure errors (e.g., sandbox timeouts, API failures).

When a rollout retries, its saved output includes a `retry` block. It records the attempt count, exhaustion status, elapsed retry time, and the retryable error that triggered each retry. This keeps transient provider or sandbox instability visible in `results.jsonl`, not only in logs.
Comment thread
cursor[bot] marked this conversation as resolved.

The `--num-workers` flag controls how many worker processes the env server spawns. Each worker owns its own environment instance and runs rollouts independently. The default `auto` scales with concurrency.

### Display
Expand Down
3 changes: 3 additions & 0 deletions docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,15 @@ class RolloutOutput(dict):
error: str | None
stop_condition: str | None
token_usage: TokenUsage
retry: RetryData
trajectory: list[TrajectoryStep]
tool_defs: list[Tool] | None
```

Serialized output from a rollout. This is a `dict` subclass that provides typed access to known fields while supporting arbitrary additional fields from `state_columns`. All values must be JSON-serializable. Used in `GenerateOutputs` and for saving results to disk.

The `retry` field is present only when `max_retries` is set and at least one retry occurred. It records the number of attempts, retry count, exhaustion status, elapsed retry time, and per-attempt error summaries.

### TrajectoryStep

```python
Expand Down
1 change: 1 addition & 0 deletions skills/evaluate-environments/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ prime eval list
prime eval get <eval-id>
prime eval samples <eval-id>
```
4. When debugging transient failures from runs that used `--max-retries`, inspect each sample's saved `retry` block. It shows attempt count, whether retries were exhausted, elapsed retry time, and the retryable errors that triggered each retry.

## Metrics Interpretation
1. Treat binary and continuous rewards differently.
Expand Down
63 changes: 63 additions & 0 deletions tests/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,12 @@ async def test_retry_after_retryable_error(self, mock_client, make_input):

assert states[0].get("error") is None
assert env.call_counts[0] == 3
assert states[0]["retry"]["attempts"] == 3
assert states[0]["retry"]["retry_count"] == 2
assert states[0]["retry"]["exhausted"] is False
assert [
event["error"] for event in states[0]["retry"]["events"]
] == ["InfraError", "InfraError"]

@pytest.mark.asyncio
async def test_no_retry_after_non_retryable_error(self, mock_client, make_input):
Expand Down Expand Up @@ -696,6 +702,33 @@ async def test_error_in_state_after_max_retries_exhausted(
assert rollout_outputs[0].get("error") is not None
error_data = rollout_outputs[0]["error"]
assert "InfraError" == error_data["error"]
retry = rollout_outputs[0]["retry"]
assert retry["attempts"] == 3
assert retry["retry_count"] == 2
assert retry["exhausted"] is True
assert retry["events"][-1]["error"] == "InfraError"

@pytest.mark.asyncio
async def test_group_retry_metadata_is_attached_to_each_rollout(
    self, mock_client, make_input
):
    """Every rollout returned for a retried group should carry retry history."""
    env = RetryCounterEnv(
        fail_count=1,
        dataset=Dataset.from_dict({"question": ["test"], "answer": ["test"]}),
        parser=Parser(),
        rubric=Rubric(),
    )
    # Two inputs that share example_id 0.
    group_inputs = [make_input(example_id=0) for _ in range(2)]

    generated = await env.generate(
        group_inputs, client=mock_client, model="test-model", max_retries=2
    )
    rollouts = generated["outputs"]

    assert len(rollouts) == 2
    assert env.call_counts[0] == 4
    # Per-rollout checks give a more precise failure location than all(...).
    for rollout in rollouts:
        assert rollout.get("error") is None
    for rollout in rollouts:
        assert rollout["retry"]["attempts"] == 2
    for rollout in rollouts:
        assert rollout["retry"]["retry_count"] == 1

@pytest.mark.asyncio
async def test_retries_serialized_infra_error_subclass(self):
Expand All @@ -721,6 +754,36 @@ async def attempt():
assert calls["n"] == 3 # 1 initial + 2 retries (InfraError is retryable)
assert result["error"] == serialized # last result returned after exhaustion

@pytest.mark.asyncio
async def test_group_retry_metadata_uses_retryable_state_index(self):
    """The recorded state_index should name the state whose error was retryable."""
    from verifiers.utils.async_utils import maybe_retry

    invocations = 0

    async def attempt():
        nonlocal invocations
        invocations += 1
        if invocations > 1:
            # Second call: the whole group succeeds.
            return [{"error": None}, {"error": None}]
        # First call: index 0 holds a ToolError, index 1 an InfraError.
        return [
            {"error": vf.ToolError("bad tool call")},
            {"error": vf.InfraError("worker timed out")},
        ]

    retrying_attempt = maybe_retry(
        attempt,
        max_retries=1,
        initial=0.0,
        max_wait=0.0,
    )
    states = await retrying_attempt()

    assert invocations == 2
    first_state_event = states[0]["retry"]["events"][0]
    second_state_event = states[1]["retry"]["events"][0]
    assert first_state_event["state_index"] == "1"
    assert second_state_event["state_index"] == "1"


class TestEmptyModelResponseErrors:
"""Test cases for empty and invalid model response error handling."""
Expand Down
23 changes: 23 additions & 0 deletions tests/test_save_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,29 @@ def test_states_to_outputs(self, make_state):
assert result[0].get("foo") == "bar" # custom field from make_state fixture
assert result[0]["reward"] == 1.0

def test_states_to_outputs_includes_retry_metadata(self, make_state):
    """A state's ``retry`` block should be carried into the saved output unchanged."""
    retry_event = {
        "attempt": 1,
        "error": "InfraError",
        "message": "sandbox timed out",
        "error_chain_str": "InfraError",
        "next_sleep_seconds": 1.0,
    }
    retry_block = {
        "attempts": 3,
        "max_retries": 2,
        "retry_count": 2,
        "exhausted": True,
        "elapsed_seconds": 0.42,
        "events": [retry_event],
    }
    state = make_state()
    state["retry"] = retry_block

    outputs = states_to_outputs([state], state_columns=[])

    assert outputs[0]["retry"] == state["retry"]

def test_states_to_outputs_requires_example_id(self, make_state):
state = make_state()
del state["example_id"]
Expand Down
21 changes: 20 additions & 1 deletion verifiers/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,24 @@ class ErrorData(TypedDict):
error_chain_str: str


class RetryEvent(TypedDict):
    """One retryable failure recorded while a call was being retried."""

    # Attempt number that produced the error (1 for the first attempt).
    attempt: int
    # Error type name, message and chain string, copied from the error's
    # ErrorData fields of the same names.
    error: str
    message: str
    error_chain_str: str
    # Sleep duration before the next attempt; filled in by the before-sleep
    # retry hook, so it is only present when another attempt was scheduled.
    next_sleep_seconds: NotRequired[float]
    # Which returned state carried the retryable error: "" when the call
    # returned a single dict, otherwise the list index rendered as a string.
    state_index: NotRequired[str]


class RetryData(TypedDict):
    """Retry summary attached to a rollout's state/output under the ``retry`` key."""

    # Total attempts made, including the final one.
    attempts: int
    # The configured retry limit for the call.
    max_retries: int
    # Number of retries performed (attempts - 1, capped at max_retries).
    retry_count: int
    # True when every allowed attempt failed and the last result was returned.
    exhausted: bool
    # Wall-clock seconds from the start of the first attempt until the summary
    # was attached.
    elapsed_seconds: float
    # One entry per recorded retryable failure, in order of occurrence.
    events: list[RetryEvent]


class RolloutOutput(dict):
"""Serialized output from a rollout (mirrors RolloutInput).

Expand All @@ -409,7 +427,7 @@ class RolloutOutput(dict):
Required fields: example_id, prompt, completion, reward, timing,
is_completed, is_truncated, metrics
Optional fields: answer, info, error, stop_condition, trajectory, tool_defs,
token_usage
token_usage, retry
Additional fields: arbitrary serializable state_columns
"""

Expand All @@ -430,6 +448,7 @@ class RolloutOutput(dict):
trajectory: list["TrajectoryStep"]
tool_defs: list[Tool]
token_usage: TokenUsage
retry: RetryData


_MISSING = object()
Expand Down
123 changes: 105 additions & 18 deletions verifiers/utils/async_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import inspect
import logging
import time
from collections import deque
from collections.abc import Mapping
from collections.abc import Coroutine
Expand All @@ -13,6 +14,7 @@

import verifiers as vf
from verifiers.utils.error_utils import ErrorChain
from verifiers.utils.error_utils import error_data
from verifiers.utils.error_utils import error_from_data, is_error_data
from verifiers.utils.logging_utils import print_time

Expand Down Expand Up @@ -161,44 +163,110 @@ def maybe_retry(
if max_retries <= 0:
return func

def reraise_one(err, error_types: tuple[type[Exception], ...]):
retry_started_at = 0.0
retry_events: list[dict[str, Any]] = []
last_result = None

def summarize_error(error: BaseException | Mapping[str, Any]) -> dict[str, str]:
    """Reduce an error to the three string fields stored in a retry event.

    Accepts serialized error data (a mapping recognised by ``is_error_data``)
    or a live exception, which is first converted with ``error_data``.
    Anything else is described by its type name and ``str()`` value.
    """
    if isinstance(error, Mapping) and is_error_data(error):
        source = error
    elif isinstance(error, BaseException):
        source = error_data(error)
    else:
        # Neither error data nor an exception: fall back to a generic summary.
        type_name = type(error).__name__
        return {
            "error": type_name,
            "message": str(error),
            "error_chain_str": type_name,
        }
    return {
        field: source[field] for field in ("error", "message", "error_chain_str")
    }

def retryable_error(
err, error_types: tuple[type[Exception], ...]
) -> Exception | None:
if not err:
return
return None
if any(isinstance(err, err_type) for err_type in error_types):
raise err
return err
# Rebuild serialized ErrorData so base types match (SandboxError -> InfraError).
if isinstance(err, Mapping) and is_error_data(err):
rebuilt = error_from_data(err)
if isinstance(rebuilt, error_types):
raise rebuilt
return rebuilt
return None

def first_retryable_state_error(
    result, error_types: tuple[type[Exception], ...]
) -> tuple[str, Exception] | None:
    """Find the first state in ``result`` whose error is retryable.

    Returns a ``(state_index, exception)`` pair, where ``state_index`` is
    ``""`` for a single-dict result and the stringified list position for a
    list result. Returns ``None`` when nothing retryable is found or when
    ``result`` is neither a dict nor a list.
    """
    if isinstance(result, dict):
        labelled_states = [("", result)]
    elif isinstance(result, list):
        labelled_states = ((str(pos), item) for pos, item in enumerate(result))
    else:
        return None

    for label, state in labelled_states:
        found = retryable_error(state.get("error"), error_types)
        if found is not None:
            return (label, found)
    return None

def reraise_error_from_state(result, error_types: tuple[type[Exception], ...]):
    """Raise the first retryable error found in the state(s) so tenacity retries."""
    match = first_retryable_state_error(result, error_types)
    if match is None:
        return
    _, error = match
    raise error

def attach_retry_info(result, *, exhausted: bool = False) -> None:
    """Attach a retry summary to ``result`` under the ``"retry"`` key.

    Does nothing when no retry events were recorded, so results that
    succeeded on the first attempt are left untouched.

    Args:
        result: A single state dict, or a list of state dicts. Any other
            type is ignored.
        exhausted: True when called after the final allowed attempt failed.
    """
    if not retry_events:
        return
    # When exhausted, every attempt failed and has its own event. Otherwise
    # the final, successful attempt has no event and is counted separately.
    attempts = len(retry_events) + (0 if exhausted else 1)
    retry_count = min(max_retries, max(0, attempts - 1))
    elapsed_seconds = time.time() - retry_started_at

    def build_summary() -> dict[str, Any]:
        # A fresh dict and event copies per state, so states in a group do not
        # alias each other or the live bookkeeping list.
        return {
            "attempts": attempts,
            "max_retries": max_retries,
            "retry_count": retry_count,
            "exhausted": exhausted,
            "elapsed_seconds": elapsed_seconds,
            "events": [dict(event) for event in retry_events],
        }

    if isinstance(result, dict):
        result["retry"] = build_summary()
    elif isinstance(result, list):
        for state in result:
            state["retry"] = build_summary()

def begin_retry_call(retry_state: tc.RetryCallState) -> None:
nonlocal last_result, retry_events, retry_started_at
if retry_state.attempt_number == 1:
last_result = None
retry_events = []
retry_started_at = time.time()

def log_retry(retry_state: tc.RetryCallState) -> None:
    """Record the upcoming sleep on the current retry event and log a warning.

    If the most recent event belongs to this attempt, it is updated with
    ``next_sleep_seconds``. Otherwise a new event is appended, built from the
    attempt's exception.
    """
    caller = retry_state.fn.__name__ if retry_state.fn else "unknown function"
    error = retry_state.outcome.exception() if retry_state.outcome else None
    # Use one placeholder for both the log line and the event when no
    # exception is available.
    error = error or Exception("Unknown exception")
    error_chain = repr(ErrorChain(error))
    next_action = retry_state.next_action.sleep if retry_state.next_action else 0

    latest_event = retry_events[-1] if retry_events else None
    if (
        latest_event is not None
        and latest_event["attempt"] == retry_state.attempt_number
    ):
        latest_event["next_sleep_seconds"] = next_action
    else:
        retry_events.append(
            {
                "attempt": retry_state.attempt_number,
                "next_sleep_seconds": next_action,
                **summarize_error(error),
            }
        )
    logger.warning(
        f"Caught {error_chain} in {caller}. Retrying in {print_time(next_action)} (retry {retry_state.attempt_number}/{max_retries})"
    )

last_result = None

def return_last_result(retry_state: tc.RetryCallState):
"""Return the last result when retries are exhausted (instead of raising)."""
caller = retry_state.fn.__name__ if retry_state.fn else "unknown function"
Expand All @@ -215,13 +283,31 @@ def return_last_result(retry_state: tc.RetryCallState):
f"Retries exhausted for {caller} after {max_retries} attempts. "
f"Last error: {error_chain}. Continuing with error in state."
)
if last_result is not None:
attach_retry_info(last_result, exhausted=True)
return last_result

async def wrapper(*args, **kwargs):
nonlocal last_result
result = await func(*args, **kwargs)
last_result = result # store result
reraise_error_from_state(result, error_types)
retryable_state_error = first_retryable_state_error(result, error_types)
try:
reraise_error_from_state(result, error_types)
except error_types as error:
retry_events.append(
{
"attempt": len(retry_events) + 1,
"state_index": (
retryable_state_error[0]
if retryable_state_error is not None
else ""
),
**summarize_error(error),
}
Comment thread
cursor[bot] marked this conversation as resolved.
)
raise
attach_retry_info(result)
return result

wrapper.__name__ = getattr(func, "__name__", "unknown")
Expand All @@ -231,6 +317,7 @@ async def wrapper(*args, **kwargs):
retry=tc.retry_if_exception_type(error_types),
stop=tc.stop_after_attempt(max_retries + 1),
wait=tc.wait_exponential_jitter(initial=initial, max=max_wait),
before=begin_retry_call,
before_sleep=log_retry,
retry_error_callback=return_last_result,
reraise=True,
Expand Down
5 changes: 5 additions & 0 deletions verifiers/utils/save_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,11 @@ def state_to_output(
output.pop("answer")
if "info" in output and not output["info"]:
output.pop("info")
retry = state.get("retry")
if retry is not None:
if not is_json_serializable(retry):
raise ValueError("state.retry must be JSON-serializable.")
output["retry"] = retry
# flatten metrics to top-level keys (backwards compatibility)
state_metrics = state.get("metrics") or {}
for k, v in state_metrics.items():
Expand Down