From 17030c0b89d1f7d53b6732d681e42029a0c87a89 Mon Sep 17 00:00:00 2001 From: Ritwij Aryan Parmar Date: Thu, 11 Jun 2026 17:09:57 -0400 Subject: [PATCH 1/2] feat(eval): persist rollout retry diagnostics --- docs/evaluation.md | 2 + docs/reference.md | 3 + tests/test_environment.py | 33 +++++++++++ tests/test_save_utils.py | 23 ++++++++ verifiers/types.py | 21 ++++++- verifiers/utils/async_utils.py | 105 +++++++++++++++++++++++++++++---- verifiers/utils/save_utils.py | 5 ++ 7 files changed, 179 insertions(+), 13 deletions(-) diff --git a/docs/evaluation.md b/docs/evaluation.md index bb11ebbb72..da31726c36 100644 --- a/docs/evaluation.md +++ b/docs/evaluation.md @@ -255,6 +255,8 @@ By default, scoring runs interleaved with generation. Use `--no-interleave-scori The `--max-retries` flag enables automatic retry with exponential backoff when rollouts fail due to transient infrastructure errors (e.g., sandbox timeouts, API failures). +When a rollout retries, its saved output includes a `retry` block. It records attempt count, exhaustion status, elapsed retry time, and the retryable errors that triggered each attempt. This keeps transient provider or sandbox instability visible in `results.jsonl`, not only in logs. + The `--num-workers` flag controls how many worker processes the env server spawns. Each worker owns its own environment instance and runs rollouts independently. The default `auto` scales with concurrency. ### Display diff --git a/docs/reference.md b/docs/reference.md index a50811f4aa..90f6eafd0c 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -185,12 +185,15 @@ class RolloutOutput(dict): error: str | None stop_condition: str | None token_usage: TokenUsage + retry: RetryData trajectory: list[TrajectoryStep] tool_defs: list[Tool] | None ``` Serialized output from a rollout. This is a `dict` subclass that provides typed access to known fields while supporting arbitrary additional fields from `state_columns`. All values must be JSON-serializable. Used in `GenerateOutputs` and for saving results to disk. +The `retry` field is present only when `max_retries` caused at least one retry. It records attempts, retry count, exhaustion status, elapsed retry time, and per-attempt error summaries. + ### TrajectoryStep ```python diff --git a/tests/test_environment.py b/tests/test_environment.py index cc30b19bed..ac97108d0c 100644 --- a/tests/test_environment.py +++ b/tests/test_environment.py @@ -652,6 +652,12 @@ async def test_retry_after_retryable_error(self, mock_client, make_input): assert states[0].get("error") is None assert env.call_counts[0] == 3 + assert states[0]["retry"]["attempts"] == 3 + assert states[0]["retry"]["retry_count"] == 2 + assert states[0]["retry"]["exhausted"] is False + assert [ + event["error"] for event in states[0]["retry"]["events"] + ] == ["InfraError", "InfraError"] @pytest.mark.asyncio async def test_no_retry_after_non_retryable_error(self, mock_client, make_input): @@ -696,6 +702,33 @@ async def test_error_in_state_after_max_retries_exhausted( assert rollout_outputs[0].get("error") is not None error_data = rollout_outputs[0]["error"] assert "InfraError" == error_data["error"] + retry = rollout_outputs[0]["retry"] + assert retry["attempts"] == 3 + assert retry["retry_count"] == 2 + assert retry["exhausted"] is True + assert retry["events"][-1]["error"] == "InfraError" + + @pytest.mark.asyncio + async def test_group_retry_metadata_is_attached_to_each_rollout( + self, mock_client, make_input + ): + """Grouped scoring should expose retry history on every returned state.""" + dataset = Dataset.from_dict({"question": ["test"], "answer": ["test"]}) + env = RetryCounterEnv( + fail_count=1, dataset=dataset, parser=Parser(), rubric=Rubric() + ) + + inputs = [make_input(example_id=0), make_input(example_id=0)] + outputs = await env.generate( + inputs, client=mock_client, model="test-model", max_retries=2 + ) + + rollout_outputs = outputs["outputs"] + assert len(rollout_outputs) == 2 + assert env.call_counts[0] == 4 + assert all(output.get("error") is None for output in rollout_outputs) + assert all(output["retry"]["attempts"] == 2 for output in rollout_outputs) + assert all(output["retry"]["retry_count"] == 1 for output in rollout_outputs) @pytest.mark.asyncio async def test_retries_serialized_infra_error_subclass(self): diff --git a/tests/test_save_utils.py b/tests/test_save_utils.py index 4c4fdf3804..788897fd8a 100644 --- a/tests/test_save_utils.py +++ b/tests/test_save_utils.py @@ -289,6 +289,29 @@ def test_states_to_outputs(self, make_state): assert result[0].get("foo") == "bar" # custom field from make_state fixture assert result[0]["reward"] == 1.0 + def test_states_to_outputs_includes_retry_metadata(self, make_state): + state = make_state() + state["retry"] = { + "attempts": 3, + "max_retries": 2, + "retry_count": 2, + "exhausted": True, + "elapsed_seconds": 0.42, + "events": [ + { + "attempt": 1, + "error": "InfraError", + "message": "sandbox timed out", + "error_chain_str": "InfraError", + "next_sleep_seconds": 1.0, + } + ], + } + + outputs = states_to_outputs([state], state_columns=[]) + + assert outputs[0]["retry"] == state["retry"] + def test_states_to_outputs_requires_example_id(self, make_state): state = make_state() del state["example_id"] diff --git a/verifiers/types.py b/verifiers/types.py index 4242f8a86f..7e2ba13c95 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -399,6 +399,24 @@ class ErrorData(TypedDict): error_chain_str: str +class RetryEvent(TypedDict): + attempt: int + error: str + message: str + error_chain_str: str + next_sleep_seconds: NotRequired[float] + state_index: NotRequired[str] + + +class RetryData(TypedDict): + attempts: int + max_retries: int + retry_count: int + exhausted: bool + elapsed_seconds: float + events: list[RetryEvent] + + class RolloutOutput(dict): """Serialized output from a rollout (mirrors RolloutInput). @@ -409,7 +427,7 @@ class RolloutOutput(dict): Required fields: example_id, prompt, completion, reward, timing, is_completed, is_truncated, metrics Optional fields: answer, info, error, stop_condition, trajectory, tool_defs, - token_usage + token_usage, retry Additional fields: arbitrary serializable state_columns """ @@ -430,6 +448,7 @@ class RolloutOutput(dict): trajectory: list["TrajectoryStep"] tool_defs: list[Tool] token_usage: TokenUsage + retry: RetryData _MISSING = object() diff --git a/verifiers/utils/async_utils.py b/verifiers/utils/async_utils.py index 1195599dab..1a28b26f74 100644 --- a/verifiers/utils/async_utils.py +++ b/verifiers/utils/async_utils.py @@ -1,6 +1,7 @@ import asyncio import inspect import logging +import time from collections import deque from collections.abc import Mapping from collections.abc import Coroutine @@ -13,6 +14,7 @@ import verifiers as vf from verifiers.utils.error_utils import ErrorChain +from verifiers.utils.error_utils import error_data from verifiers.utils.error_utils import error_from_data, is_error_data from verifiers.utils.logging_utils import print_time @@ -161,6 +163,31 @@ def maybe_retry( if max_retries <= 0: return func + retry_started_at = 0.0 + retry_events: list[dict[str, Any]] = [] + last_result = None + + def summarize_error(error: BaseException | Mapping[str, Any]) -> dict[str, str]: + """Build the JSON-safe part of a retry event.""" + if isinstance(error, Mapping) and is_error_data(error): + return { + "error": error["error"], + "message": error["message"], + "error_chain_str": error["error_chain_str"], + } + if isinstance(error, BaseException): + data = error_data(error) + return { + "error": data["error"], + "message": data["message"], + "error_chain_str": data["error_chain_str"], + } + return { + "error": type(error).__name__, + "message": str(error), + "error_chain_str": type(error).__name__, + } + def reraise_one(err, error_types: tuple[type[Exception], ...]): if not err: return @@ -172,6 +199,19 @@ def reraise_one(err, error_types: tuple[type[Exception], ...]): if isinstance(rebuilt, error_types): raise rebuilt + def iter_state_errors(result) -> list[tuple[str, Mapping[str, Any] | BaseException]]: + if isinstance(result, dict): + err = result.get("error") + return [("", err)] if err else [] + if isinstance(result, list): + errors = [] + for index, state in enumerate(result): + err = state.get("error") + if err: + errors.append((str(index), err)) + return errors + return [] + def reraise_error_from_state(result, error_types: tuple[type[Exception], ...]): """Re-raise specified errors from state(s) to trigger tenacity retry.""" if isinstance(result, dict): @@ -180,25 +220,51 @@ def reraise_error_from_state(result, error_types: tuple[type[Exception], ...]): for state in result: reraise_one(state.get("error"), error_types) + def attach_retry_info(result, *, exhausted: bool = False) -> None: + if not retry_events: + return + attempts = len(retry_events) + (0 if exhausted else 1) + summary = { + "attempts": attempts, + "max_retries": max_retries, + "retry_count": min(max_retries, max(0, attempts - 1)), + "exhausted": exhausted, + "elapsed_seconds": time.time() - retry_started_at, + "events": retry_events, + } + if isinstance(result, dict): + result["retry"] = summary + elif isinstance(result, list): + for state in result: + state["retry"] = summary + + def begin_retry_call(retry_state: tc.RetryCallState) -> None: + nonlocal last_result, retry_events, retry_started_at + if retry_state.attempt_number == 1: + last_result = None + retry_events = [] + retry_started_at = time.time() + def log_retry(retry_state: tc.RetryCallState) -> None: """Log a warning with the exception and the number of attempts.""" caller = retry_state.fn.__name__ if retry_state.fn else "unknown function" - error_chain = ( - repr( - ErrorChain( - retry_state.outcome.exception() or Exception("Unknown exception") - ) - ) - if retry_state.outcome - else None - ) + error = retry_state.outcome.exception() if retry_state.outcome else None + error_chain = repr(ErrorChain(error or Exception("Unknown exception"))) next_action = retry_state.next_action.sleep if retry_state.next_action else 0 + if retry_events and retry_events[-1]["attempt"] == retry_state.attempt_number: + retry_events[-1]["next_sleep_seconds"] = next_action + else: + retry_events.append( + { + "attempt": retry_state.attempt_number, + "next_sleep_seconds": next_action, + **summarize_error(error or Exception("Unknown exception")), + } + ) logger.warning( f"Caught {error_chain} in {caller}. Retrying in {print_time(next_action)} (retry {retry_state.attempt_number}/{max_retries})" ) - last_result = None - def return_last_result(retry_state: tc.RetryCallState): """Return the last result when retries are exhausted (instead of raising).""" caller = retry_state.fn.__name__ if retry_state.fn else "unknown function" @@ -215,13 +281,27 @@ def return_last_result(retry_state: tc.RetryCallState): f"Retries exhausted for {caller} after {max_retries} attempts. " f"Last error: {error_chain}. Continuing with error in state." ) + if last_result is not None: + attach_retry_info(last_result, exhausted=True) return last_result async def wrapper(*args, **kwargs): nonlocal last_result result = await func(*args, **kwargs) last_result = result # store result - reraise_error_from_state(result, error_types) + state_errors = iter_state_errors(result) + try: + reraise_error_from_state(result, error_types) + except error_types as error: + retry_events.append( + { + "attempt": len(retry_events) + 1, + "state_index": state_errors[0][0] if state_errors else "", + **summarize_error(error), + } + ) + raise + attach_retry_info(result) return result wrapper.__name__ = getattr(func, "__name__", "unknown") @@ -231,6 +311,7 @@ async def wrapper(*args, **kwargs): retry=tc.retry_if_exception_type(error_types), stop=tc.stop_after_attempt(max_retries + 1), wait=tc.wait_exponential_jitter(initial=initial, max=max_wait), + before=begin_retry_call, before_sleep=log_retry, retry_error_callback=return_last_result, reraise=True, diff --git a/verifiers/utils/save_utils.py b/verifiers/utils/save_utils.py index 8f59e8e01d..7aa71aaca4 100644 --- a/verifiers/utils/save_utils.py +++ b/verifiers/utils/save_utils.py @@ -294,6 +294,11 @@ def state_to_output( output.pop("answer") if "info" in output and not output["info"]: output.pop("info") + retry = state.get("retry") + if retry is not None: + if not is_json_serializable(retry): + raise ValueError("state.retry must be JSON-serializable.") + output["retry"] = retry # flatten metrics to top-level keys (backwards compatibility) state_metrics = state.get("metrics") or {} for k, v in state_metrics.items(): From 129847c19a3ad8624c7168e1e553c04b38118a4f Mon Sep 17 00:00:00 2001 From: Ritwij Aryan Parmar Date: Thu, 11 Jun 2026 17:27:43 -0400 Subject: [PATCH 2/2] fix(eval): align retry event state index --- skills/evaluate-environments/SKILL.md | 1 + tests/test_environment.py | 30 ++++++++++++++++++ verifiers/utils/async_utils.py | 44 +++++++++++++++------------ 3 files changed, 56 insertions(+), 19 deletions(-) diff --git a/skills/evaluate-environments/SKILL.md b/skills/evaluate-environments/SKILL.md index e0e7860213..6723b7ec7d 100644 --- a/skills/evaluate-environments/SKILL.md +++ b/skills/evaluate-environments/SKILL.md @@ -219,6 +219,7 @@ prime eval list prime eval get prime eval samples ``` +4. When debugging transient failures from runs that used `--max-retries`, inspect each sample's saved `retry` block. It shows attempt count, whether retries were exhausted, elapsed retry time, and the retryable errors that triggered each retry. ## Metrics Interpretation 1. Treat binary and continuous rewards differently. diff --git a/tests/test_environment.py b/tests/test_environment.py index ac97108d0c..c6b5659d75 100644 --- a/tests/test_environment.py +++ b/tests/test_environment.py @@ -754,6 +754,36 @@ async def attempt(): assert calls["n"] == 3 # 1 initial + 2 retries (InfraError is retryable) assert result["error"] == serialized # last result returned after exhaustion + @pytest.mark.asyncio + async def test_group_retry_metadata_uses_retryable_state_index(self): + """Retry metadata should point at the state that actually triggered retry.""" + from verifiers.utils.async_utils import maybe_retry + + calls = {"n": 0} + + async def attempt(): + calls["n"] += 1 + if calls["n"] == 1: + return [ + {"error": vf.ToolError("bad tool call")}, + {"error": vf.InfraError("worker timed out")}, + ] + return [ + {"error": None}, + {"error": None}, + ] + + result = await maybe_retry( + attempt, + max_retries=1, + initial=0.0, + max_wait=0.0, + )() + + assert calls["n"] == 2 + assert result[0]["retry"]["events"][0]["state_index"] == "1" + assert result[1]["retry"]["events"][0]["state_index"] == "1" + class TestEmptyModelResponseErrors: """Test cases for empty and invalid model response error handling.""" diff --git a/verifiers/utils/async_utils.py b/verifiers/utils/async_utils.py index 1a28b26f74..10ade05e23 100644 --- a/verifiers/utils/async_utils.py +++ b/verifiers/utils/async_utils.py @@ -188,37 +188,39 @@ def summarize_error(error: BaseException | Mapping[str, Any]) -> dict[str, str]: "error_chain_str": type(error).__name__, } - def reraise_one(err, error_types: tuple[type[Exception], ...]): + def retryable_error( + err, error_types: tuple[type[Exception], ...] + ) -> Exception | None: if not err: - return + return None if any(isinstance(err, err_type) for err_type in error_types): - raise err + return err # Rebuild serialized ErrorData so base types match (SandboxError -> InfraError). if isinstance(err, Mapping) and is_error_data(err): rebuilt = error_from_data(err) if isinstance(rebuilt, error_types): - raise rebuilt + return rebuilt + return None - def iter_state_errors(result) -> list[tuple[str, Mapping[str, Any] | BaseException]]: + def first_retryable_state_error( + result, error_types: tuple[type[Exception], ...] + ) -> tuple[str, Exception] | None: if isinstance(result, dict): err = result.get("error") - return [("", err)] if err else [] + retryable = retryable_error(err, error_types) + return ("", retryable) if retryable is not None else None if isinstance(result, list): - errors = [] for index, state in enumerate(result): - err = state.get("error") - if err: - errors.append((str(index), err)) - return errors - return [] + retryable = retryable_error(state.get("error"), error_types) + if retryable is not None: + return (str(index), retryable) + return None def reraise_error_from_state(result, error_types: tuple[type[Exception], ...]): """Re-raise specified errors from state(s) to trigger tenacity retry.""" - if isinstance(result, dict): - reraise_one(result.get("error"), error_types) - elif isinstance(result, list): - for state in result: - reraise_one(state.get("error"), error_types) + retryable = first_retryable_state_error(result, error_types) + if retryable is not None: + raise retryable[1] def attach_retry_info(result, *, exhausted: bool = False) -> None: if not retry_events: @@ -289,14 +291,18 @@ async def wrapper(*args, **kwargs): nonlocal last_result result = await func(*args, **kwargs) last_result = result # store result - state_errors = iter_state_errors(result) + retryable_state_error = first_retryable_state_error(result, error_types) try: reraise_error_from_state(result, error_types) except error_types as error: retry_events.append( { "attempt": len(retry_events) + 1, - "state_index": state_errors[0][0] if state_errors else "", + "state_index": ( + retryable_state_error[0] + if retryable_state_error is not None + else "" + ), **summarize_error(error), } )