From 36e431c3adfe24f596cc4c860fa6dd89225776a0 Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Tue, 5 May 2026 23:24:39 +0000 Subject: [PATCH 1/9] first implement of chaos module --- src/strands_evals/__init__.py | 3 +- src/strands_evals/chaos/__init__.py | 25 +++ src/strands_evals/chaos/effects.py | 117 ++++++++++++ src/strands_evals/chaos/experiment.py | 165 +++++++++++++++++ src/strands_evals/chaos/plugin.py | 248 ++++++++++++++++++++++++++ src/strands_evals/chaos/scenario.py | 65 +++++++ 6 files changed, 622 insertions(+), 1 deletion(-) create mode 100644 src/strands_evals/chaos/__init__.py create mode 100644 src/strands_evals/chaos/effects.py create mode 100644 src/strands_evals/chaos/experiment.py create mode 100644 src/strands_evals/chaos/plugin.py create mode 100644 src/strands_evals/chaos/scenario.py diff --git a/src/strands_evals/__init__.py b/src/strands_evals/__init__.py index 229dba0a..35919db8 100644 --- a/src/strands_evals/__init__.py +++ b/src/strands_evals/__init__.py @@ -1,4 +1,4 @@ -from . import detectors, evaluators, extractors, generators, providers, simulation, telemetry, types +from . import chaos, detectors, evaluators, extractors, generators, providers, simulation, telemetry, types from .case import Case from .eval_task_handler import EvalTaskHandler, TracedHandler, eval_task from .evaluation_data_store import EvaluationDataStore @@ -17,6 +17,7 @@ "EvalTaskHandler", "TracedHandler", "eval_task", + "chaos", "detectors", "evaluators", "extractors", diff --git a/src/strands_evals/chaos/__init__.py b/src/strands_evals/chaos/__init__.py new file mode 100644 index 00000000..fbfc129e --- /dev/null +++ b/src/strands_evals/chaos/__init__.py @@ -0,0 +1,25 @@ +"""Chaos testing module for Strands Evals. + +Provides deterministic fault injection for evaluating agent resilience +under tool failures and response corruption scenarios. +""" + +from .effects import ( + TOOL_CORRUPTION_EFFECTS, + TOOL_ERROR_EFFECTS, + ToolChaosEffect, + ChaosEffectConfig, +) +from .experiment import ChaosExperiment +from .plugin import ChaosPlugin +from .scenario import ChaosScenario + +__all__ = [ + "ToolChaosEffect", + "ChaosEffectConfig", + "ChaosExperiment", + "ChaosPlugin", + "ChaosScenario", + "TOOL_CORRUPTION_EFFECTS", + "TOOL_ERROR_EFFECTS", +] diff --git a/src/strands_evals/chaos/effects.py b/src/strands_evals/chaos/effects.py new file mode 100644 index 00000000..5a29bd19 --- /dev/null +++ b/src/strands_evals/chaos/effects.py @@ -0,0 +1,117 @@ +"""Chaos effect definitions. + +Provides a flat enum of all chaos effects and an optional config class +for advanced parameterization. +""" + +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field, model_validator + + +class ToolChaosEffect(str, Enum): + """All chaos effects a tool can experience. + + Error effects cause the tool call to be cancelled before execution. + Corruption effects mutate the tool response after execution. + """ + + # Error effects (tool call is cancelled with a simulated error) + TIMEOUT = "timeout" + NETWORK_ERROR = "network_error" + EXECUTION_ERROR = "execution_error" + VALIDATION_ERROR = "validation_error" + + # Response corruption effects (tool executes, response is mangled) + TRUNCATE_FIELDS = "truncate_fields" + REMOVE_FIELDS = "remove_fields" + CORRUPT_VALUES = "corrupt_values" + + + + + +# Sets for classification +TOOL_ERROR_EFFECTS = { + ToolChaosEffect.TIMEOUT, + ToolChaosEffect.NETWORK_ERROR, + ToolChaosEffect.EXECUTION_ERROR, + ToolChaosEffect.VALIDATION_ERROR, +} + +TOOL_CORRUPTION_EFFECTS = { + ToolChaosEffect.TRUNCATE_FIELDS, + ToolChaosEffect.REMOVE_FIELDS, + ToolChaosEffect.CORRUPT_VALUES, +} + + +class ChaosEffectConfig(BaseModel): + """Advanced chaos effect configuration. + + Use when the bare ToolChaosEffect enum needs tuning. Most users only need + the enum value directly in ChaosScenario.tool_effects. + + Example:: + + # Simple (90% of cases): just use the enum + tool_effects={"search_tool": ToolChaosEffect.TIMEOUT} + + # Advanced: custom error message + tool_effects={"search_tool": ChaosEffectConfig( + effect=ToolChaosEffect.TIMEOUT, + error_message="Request timed out after 30s", + )} + + # Advanced: tune corruption ratio + tool_effects={"db_tool": ChaosEffectConfig( + effect=ToolChaosEffect.REMOVE_FIELDS, + remove_ratio=0.5, + )} + """ + + effect: ToolChaosEffect = Field(..., description="The chaos effect to apply") + apply_rate: float = Field( + default=1.0, + ge=0.0, + le=1.0, + description="Probability that this effect fires (1.0 = always, deterministic)", + ) + + # Error-specific parameters + error_message: Optional[str] = Field( + default=None, + description="Custom error message (for TIMEOUT, NETWORK_ERROR, EXECUTION_ERROR, VALIDATION_ERROR)", + ) + + # Corruption-specific parameters + remove_ratio: Optional[float] = Field( + default=None, + ge=0.0, + le=1.0, + description="Fraction of fields to remove (for REMOVE_FIELDS, default 0.33)", + ) + corrupt_rate: Optional[float] = Field( + default=None, + ge=0.0, + le=1.0, + description="Fraction of values to corrupt (for CORRUPT_VALUES, default 0.4)", + ) + + @model_validator(mode="after") + def validate_params_match_effect(self) -> "ChaosEffectConfig": + """Warn if irrelevant parameters are set for the given effect.""" + if self.effect in TOOL_ERROR_EFFECTS: + if self.remove_ratio is not None or self.corrupt_rate is not None: + raise ValueError( + f"remove_ratio and corrupt_rate are not applicable to error effect " + f"'{self.effect.value}'. These params only apply to corruption effects." + ) + if self.effect in TOOL_CORRUPTION_EFFECTS: + if self.error_message is not None: + raise ValueError( + f"error_message is not applicable to corruption effect " + f"'{self.effect.value}'. This param only applies to error effects." + ) + return self diff --git a/src/strands_evals/chaos/experiment.py b/src/strands_evals/chaos/experiment.py new file mode 100644 index 00000000..907e7b49 --- /dev/null +++ b/src/strands_evals/chaos/experiment.py @@ -0,0 +1,165 @@ +"""Chaos Experiment. + +Extends the base Experiment to run test cases across multiple chaos scenarios, +providing deterministic evaluation of agent resilience under tool failures. +""" + +import logging +from collections.abc import Callable +from typing import Any, Optional + +from ..case import Case +from ..evaluators.evaluator import Evaluator +from ..experiment import Experiment +from ..types.evaluation_report import EvaluationReport +from .plugin import ChaosPlugin +from .scenario import ChaosScenario + +logger = logging.getLogger(__name__) + + +class ChaosExperiment(Experiment): + """Extends Experiment to run cases × chaos scenarios. + + For each scenario, sets it as active on the ChaosPlugin, runs all cases + through the evaluators, then clears. Reports are tagged by scenario name. + + Optionally includes a baseline run (no chaos) for comparison. + + Example:: + + from strands_evals.chaos import ( + ToolChaosEffect, + ChaosExperiment, + ChaosPlugin, + ChaosScenario, + ) + + chaos = ChaosPlugin() + agent = Agent(model=my_model, tools=[...], plugins=[chaos]) + + scenarios = [ + ChaosScenario(name="search_timeout", tool_effects={"search_tool": ToolChaosEffect.TIMEOUT}), + ChaosScenario(name="db_down", tool_effects={"database_tool": ToolChaosEffect.NETWORK_ERROR}), + ] + + experiment = ChaosExperiment( + chaos_plugin=chaos, + chaos_scenarios=scenarios, + cases=[Case(input="Find flights to Tokyo", name="flight_search")], + evaluators=[my_evaluator], + include_baseline=True, + ) + + reports = experiment.run_evaluations(task=lambda case: agent(case.input)) + """ + + def __init__( + self, + chaos_plugin: ChaosPlugin, + chaos_scenarios: list[ChaosScenario], + cases: Optional[list[Case]] = None, + evaluators: Optional[list[Evaluator]] = None, + include_baseline: bool = True, + baseline_assertion: Optional[str] = None, + ): + """Initialize a ChaosExperiment. + + Args: + chaos_plugin: The ChaosPlugin instance attached to the agent. + chaos_scenarios: List of scenarios to evaluate. Each scenario runs + all cases independently. + cases: Test cases to evaluate (same as base Experiment). + evaluators: Evaluators to assess results (same as base Experiment). + include_baseline: If True, runs all cases with no chaos first for comparison. + baseline_assertion: Optional assertion string for baseline evaluation + (e.g., "agent should respond correctly without any failures"). + """ + super().__init__(cases=cases, evaluators=evaluators) + self.chaos_plugin = chaos_plugin + self.chaos_scenarios = chaos_scenarios + self.include_baseline = include_baseline + self.baseline_assertion = baseline_assertion + + def run_evaluations( + self, + task: Callable[[Case], Any], + **kwargs, + ) -> list[EvaluationReport]: + """Run evaluations across all scenarios (and optionally baseline). + + Executes the task for each (scenario, case) pair: + 1. If include_baseline=True, runs all cases with no chaos active. + 2. For each scenario, activates it on the plugin, runs all cases, + then clears. + + Results from all runs are collected into a flat list of EvaluationReports, + with each case tagged with its scenario name in metadata. + + Args: + task: The task function to evaluate. Takes a Case and returns output. + **kwargs: Additional kwargs passed to the base run_evaluations. + + Returns: + List of EvaluationReport objects covering all scenarios. + """ + all_reports: list[EvaluationReport] = [] + + # Baseline run (no chaos) + if self.include_baseline: + logger.info("Running baseline evaluation (no chaos)") + self.chaos_plugin.set_active_scenario(None) + baseline_cases = self._tag_cases_with_scenario("baseline") + original_cases = self._cases + self._cases = baseline_cases + try: + reports = super().run_evaluations(task, **kwargs) + all_reports.extend(reports) + finally: + self._cases = original_cases + + # Chaos scenario runs + for scenario in self.chaos_scenarios: + logger.info(f"Running chaos scenario: {scenario.name}") + self.chaos_plugin.set_active_scenario(scenario) + scenario_cases = self._tag_cases_with_scenario(scenario.name) + original_cases = self._cases + self._cases = scenario_cases + try: + reports = super().run_evaluations(task, **kwargs) + all_reports.extend(reports) + finally: + self._cases = original_cases + + # Clear active scenario after all runs + self.chaos_plugin.set_active_scenario(None) + logger.info( + f"Chaos experiment complete: {len(all_reports)} reports " + f"({1 if self.include_baseline else 0} baseline + " + f"{len(self.chaos_scenarios)} scenarios)" + ) + + return all_reports + + def _tag_cases_with_scenario(self, scenario_name: str) -> list[Case]: + """Create copies of cases with scenario name injected into metadata. + + Args: + scenario_name: The scenario name to tag. + + Returns: + Deep copies of all cases with metadata["chaos_scenario"] set. + """ + tagged_cases = [] + for case in self._cases: + tagged = case.model_copy(deep=True) + if tagged.metadata is None: + tagged.metadata = {} + tagged.metadata["chaos_scenario"] = scenario_name + # Update case name to include scenario for report clarity + if tagged.name: + tagged.name = f"{tagged.name} [{scenario_name}]" + else: + tagged.name = f"[{scenario_name}]" + tagged_cases.append(tagged) + return tagged_cases diff --git a/src/strands_evals/chaos/plugin.py b/src/strands_evals/chaos/plugin.py new file mode 100644 index 00000000..6e8439f3 --- /dev/null +++ b/src/strands_evals/chaos/plugin.py @@ -0,0 +1,248 @@ +"""Chaos Plugin for Strands Agents. + +Implements chaos injection as a standard Strands Plugin using the SDK's +native hook system (BeforeToolCallEvent / AfterToolCallEvent). +""" + +import logging +import math +import random +from typing import Any, Optional + +from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent +from strands.plugins import Plugin, hook + +from .effects import ( + TOOL_CORRUPTION_EFFECTS, + TOOL_ERROR_EFFECTS, + ToolChaosEffect, + ChaosEffectConfig, +) +from .scenario import ChaosScenario + +logger = logging.getLogger(__name__) + + +class ChaosPlugin(Plugin): + """Strands Plugin that injects deterministic chaos based on the active scenario. + + The plugin intercepts tool calls via Strands' native hook system: + - BeforeToolCallEvent: cancels tool calls for error effects (TIMEOUT, NETWORK_ERROR, etc.) + - AfterToolCallEvent: corrupts tool responses for corruption effects (TRUNCATE_FIELDS, etc.) + + The active scenario is set externally (typically by ChaosExperiment) before + each evaluation run. When no scenario is active, all tools behave normally. + + Example:: + + from strands import Agent + from strands_evals.chaos import ChaosPlugin, ChaosScenario, ToolChaosEffect + + chaos = ChaosPlugin() + agent = Agent( + model=my_model, + tools=[search_tool, database_tool], + plugins=[chaos], + ) + + # Activate a scenario + chaos.set_active_scenario(ChaosScenario( + name="search_timeout", + tool_effects={"search_tool": ToolChaosEffect.TIMEOUT}, + )) + + result = agent("Find flights to Tokyo") + # search_tool will be cancelled with a timeout error + """ + + name = "chaos-testing" + + def __init__(self) -> None: + super().__init__() + self._active_scenario: Optional[ChaosScenario] = None + + @property + def active_scenario(self) -> Optional[ChaosScenario]: + """The currently active chaos scenario, or None for baseline (no chaos).""" + return self._active_scenario + + def set_active_scenario(self, scenario: Optional[ChaosScenario]) -> None: + """Set the scenario that drives chaos injection for subsequent tool calls. + + Args: + scenario: The scenario to activate, or None to disable chaos (baseline). + """ + self._active_scenario = scenario + if scenario: + logger.info(f"Chaos scenario activated: {scenario.name}") + else: + logger.info("Chaos scenario cleared (baseline mode)") + + def _should_apply(self, config: ChaosEffectConfig) -> bool: + """Check if the effect should fire based on apply_rate.""" + if config.apply_rate >= 1.0: + return True + if config.apply_rate <= 0.0: + return False + return random.random() < config.apply_rate + + @hook + def before_tool_call(self, event: BeforeToolCallEvent) -> None: + """Intercept tool calls to inject error effects. + + For error effects (TIMEOUT, NETWORK_ERROR, etc.), cancels the tool call + with a simulated error message before the tool executes. + """ + if not self._active_scenario: + return + + tool_name = event.tool_use.get("name", "") + chaos_effect = self._active_scenario.tool_effects.get(tool_name) + if chaos_effect is None: + return + + if isinstance(chaos_effect, ToolChaosEffect): + chaos_config = ChaosEffectConfig(effect=chaos_effect) + elif isinstance(chaos_effect, ChaosEffectConfig): + chaos_config = chaos_effect + else: + raise TypeError( + f"Unexpected effect type for tool '{tool_name}': {type(chaos_effect).__name__}. " + f"Expected ToolChaosEffect or ChaosEffectConfig." + ) + + # Only handle error effects in the before hook + if chaos_config.effect not in TOOL_ERROR_EFFECTS: + return + + if not self._should_apply(chaos_config): + return + + # Cancel the tool call with a simulated error + error_message = chaos_config.error_message or f"Simulated {chaos_config.effect.value}" + event.cancel_tool = error_message + logger.info( + f"[Chaos] Injected {chaos_config.effect.value} on tool '{tool_name}': {error_message}" + ) + + @hook + def after_tool_call(self, event: AfterToolCallEvent) -> None: + """Intercept tool results to inject corruption effects. + + For corruption effects (TRUNCATE_FIELDS, REMOVE_FIELDS, CORRUPT_VALUES), + mutates the tool response after successful execution. + """ + if not self._active_scenario: + return + + tool_name = event.tool_use.get("name", "") + chaos_effect = self._active_scenario.tool_effects.get(tool_name) + if chaos_effect is None: + return + + if isinstance(chaos_effect, ToolChaosEffect): + chaos_config = ChaosEffectConfig(effect=chaos_effect) + elif isinstance(chaos_effect, ChaosEffectConfig): + chaos_config = chaos_effect + else: + raise TypeError( + f"Unexpected effect type for tool '{tool_name}': {type(chaos_effect).__name__}. " + f"Expected ToolChaosEffect or ChaosEffectConfig." + ) + + # Only handle corruption effects in the after hook + if chaos_config.effect not in TOOL_CORRUPTION_EFFECTS: + return + + if not self._should_apply(chaos_config): + return + + # Corrupt the tool result + if hasattr(event, "result") and event.result is not None: + result = event.result + # Handle ToolResult-like objects with content + if hasattr(result, "content") and isinstance(result.content, dict): + result.content = self._apply_corruption(chaos_config, result.content) + elif isinstance(result, dict): + event.result = self._apply_corruption(chaos_config, result) + + logger.info( + f"[Chaos] Applied {chaos_config.effect.value} corruption on tool '{tool_name}'" + ) + + # ------------------------------------------------------------------ + # Corruption helpers (private) + # ------------------------------------------------------------------ + + def _apply_corruption(self, effect_config: ChaosEffectConfig, response: Any) -> Any: + """Apply a corruption effect to a tool response. + + Args: + effect_config: The normalized effect configuration. + response: The tool result to corrupt. Expected to be a dict. + + Returns: + The corrupted response. + """ + if not isinstance(response, dict): + return response + + effect = effect_config.effect + + if effect == ToolChaosEffect.TRUNCATE_FIELDS: + return self._truncate_fields(response) + elif effect == ToolChaosEffect.REMOVE_FIELDS: + ratio = effect_config.remove_ratio if effect_config.remove_ratio is not None else 0.33 + return self._remove_fields(response, ratio) + elif effect == ToolChaosEffect.CORRUPT_VALUES: + rate = effect_config.corrupt_rate if effect_config.corrupt_rate is not None else 0.4 + return self._corrupt_values(response, rate) + + return response + + @staticmethod + def _truncate_fields(response: dict[str, Any]) -> dict[str, Any]: + """Truncate string values to partial content.""" + result: dict[str, Any] = {} + for key, value in response.items(): + if isinstance(value, str) and len(value) > 0: + result[key] = value[: random.randint(0, max(0, len(value) - 1))] + elif isinstance(value, dict): + result[key] = ChaosPlugin._truncate_fields(value) + else: + result[key] = value + return result + + @staticmethod + def _remove_fields(response: dict[str, Any], remove_ratio: float) -> dict[str, Any]: + """Remove a fraction of fields from the response.""" + keys = list(response.keys()) + if not keys: + return response + + num_to_remove = max(1, math.ceil(len(keys) * remove_ratio)) + keys_to_remove = set(random.sample(keys, min(num_to_remove, len(keys)))) + return {k: v for k, v in response.items() if k not in keys_to_remove} + + @staticmethod + def _corrupt_values(response: dict[str, Any], corrupt_rate: float) -> dict[str, Any]: + """Replace a fraction of values with wrong types or garbage data.""" + corruptions: list[Any] = [None, 99999, "", True, [], "CORRUPTED_DATA"] + + keys = list(response.keys()) + if not keys: + return response + + num_to_corrupt = max(1, math.ceil(len(keys) * corrupt_rate)) + keys_to_corrupt = set(random.sample(keys, min(num_to_corrupt, len(keys)))) + + result: dict[str, Any] = {} + for key, value in response.items(): + if key in keys_to_corrupt: + candidates = [c for c in corruptions if c != value] + result[key] = random.choice(candidates) if candidates else "CORRUPTED_DATA" + elif isinstance(value, dict): + result[key] = ChaosPlugin._corrupt_values(value, corrupt_rate) + else: + result[key] = value + return result diff --git a/src/strands_evals/chaos/scenario.py b/src/strands_evals/chaos/scenario.py new file mode 100644 index 00000000..d212d438 --- /dev/null +++ b/src/strands_evals/chaos/scenario.py @@ -0,0 +1,65 @@ +"""Chaos scenario definition. + +A ChaosScenario is a named, deterministic mapping of tool names to the +chaos effects that will fire when those tools are invoked. +""" + +from typing import Union + +from pydantic import BaseModel, Field + +from .effects import ToolChaosEffect, ChaosEffectConfig + + +# Type alias for what a tool_effects value can be +EffectSpec = Union[ToolChaosEffect, ChaosEffectConfig] + + +class ChaosScenario(BaseModel): + """A single, deterministic chaos injection scenario. + + Each scenario maps tool names to the exact effect that fires when that + tool is invoked. Tools not listed in tool_effects behave normally (no chaos). + + Example:: + + # Simple: one tool fails with timeout + ChaosScenario( + name="search_timeout", + tool_effects={"search_tool": ToolChaosEffect.TIMEOUT}, + ) + + # Multiple tools affected + ChaosScenario( + name="both_tools_down", + tool_effects={ + "search_tool": ToolChaosEffect.TIMEOUT, + "database_tool": ToolChaosEffect.NETWORK_ERROR, + }, + ) + + # Advanced: custom parameters + ChaosScenario( + name="partial_corruption", + tool_effects={ + "database_tool": ChaosEffectConfig( + effect=ToolChaosEffect.REMOVE_FIELDS, + remove_ratio=0.5, + ), + }, + ) + """ + + name: str = Field(..., description="Human-readable name for this scenario") + tool_effects: dict[str, EffectSpec] = Field( + default_factory=dict, + description="Mapping of tool_name -> effect to inject. " + "Tools not listed here behave normally.", + ) + + def __repr__(self) -> str: + effects_str = ", ".join( + f"{tool}: {eff.value if isinstance(eff, ToolChaosEffect) else eff.effect.value}" + for tool, eff in self.tool_effects.items() + ) + return f"ChaosScenario(name='{self.name}', tool_effects={{{effects_str}}})" From b6211f1c1c5dd92c30c405dd915afc15553b3235 Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Wed, 6 May 2026 23:33:45 +0000 Subject: [PATCH 2/9] fix tool output corruption --- src/strands_evals/chaos/plugin.py | 57 +++++++++++++++++++++++++------ 1 file changed, 47 insertions(+), 10 deletions(-) diff --git a/src/strands_evals/chaos/plugin.py b/src/strands_evals/chaos/plugin.py index 6e8439f3..5d57f272 100644 --- a/src/strands_evals/chaos/plugin.py +++ b/src/strands_evals/chaos/plugin.py @@ -161,8 +161,33 @@ def after_tool_call(self, event: AfterToolCallEvent) -> None: if hasattr(event, "result") and event.result is not None: result = event.result # Handle ToolResult-like objects with content - if hasattr(result, "content") and isinstance(result.content, dict): - result.content = self._apply_corruption(chaos_config, result.content) + if hasattr(result, "content"): + if isinstance(result.content, dict): + # Content is a data dict — corrupt it directly + result.content = self._apply_corruption(chaos_config, result.content) + elif isinstance(result.content, list): + # Content is a list of blocks (e.g., [{"text": "..."}]) + # Corrupt text content within each block + corrupted_blocks = [] + for block in result.content: + if isinstance(block, dict) and "text" in block: + text_data = block["text"] + if isinstance(text_data, str): + # Try to parse as JSON dict for field-level corruption + import json + try: + parsed = json.loads(text_data) + if isinstance(parsed, dict): + corrupted = self._apply_corruption(chaos_config, parsed) + block = {**block, "text": json.dumps(corrupted)} + else: + block = {**block, "text": text_data} + except (json.JSONDecodeError, ValueError): + # Plain text — truncate if that's the effect + if chaos_config.effect == ToolChaosEffect.TRUNCATE_FIELDS: + block = {**block, "text": text_data[: len(text_data) // 2]} + corrupted_blocks.append(block) + result.content = corrupted_blocks elif isinstance(result, dict): event.result = self._apply_corruption(chaos_config, result) @@ -175,11 +200,14 @@ def after_tool_call(self, event: AfterToolCallEvent) -> None: # ------------------------------------------------------------------ def _apply_corruption(self, effect_config: ChaosEffectConfig, response: Any) -> Any: - """Apply a corruption effect to a tool response. + """Apply a corruption effect to a tool response data dict. + + Only preserves the `status` field (Bedrock requires "success"/"error"). + All other fields including `content` are fair game for corruption. Args: effect_config: The normalized effect configuration. - response: The tool result to corrupt. Expected to be a dict. + response: The tool result data dict to corrupt. Returns: The corrupted response. @@ -189,14 +217,21 @@ def _apply_corruption(self, effect_config: ChaosEffectConfig, response: Any) -> effect = effect_config.effect + # Preserve status — Bedrock requires toolResult.status to be "success" or "error" + saved_status = response.get("status") + if effect == ToolChaosEffect.TRUNCATE_FIELDS: - return self._truncate_fields(response) + response = self._truncate_fields(response) elif effect == ToolChaosEffect.REMOVE_FIELDS: ratio = effect_config.remove_ratio if effect_config.remove_ratio is not None else 0.33 - return self._remove_fields(response, ratio) + response = self._remove_fields(response, ratio) elif effect == ToolChaosEffect.CORRUPT_VALUES: rate = effect_config.corrupt_rate if effect_config.corrupt_rate is not None else 0.4 - return self._corrupt_values(response, rate) + response = self._corrupt_values(response, rate) + + # Restore status + if saved_status is not None: + response["status"] = saved_status return response @@ -205,7 +240,9 @@ def _truncate_fields(response: dict[str, Any]) -> dict[str, Any]: """Truncate string values to partial content.""" result: dict[str, Any] = {} for key, value in response.items(): - if isinstance(value, str) and len(value) > 0: + if key == "status": + result[key] = value + elif isinstance(value, str) and len(value) > 0: result[key] = value[: random.randint(0, max(0, len(value) - 1))] elif isinstance(value, dict): result[key] = ChaosPlugin._truncate_fields(value) @@ -216,7 +253,7 @@ def _truncate_fields(response: dict[str, Any]) -> dict[str, Any]: @staticmethod def _remove_fields(response: dict[str, Any], remove_ratio: float) -> dict[str, Any]: """Remove a fraction of fields from the response.""" - keys = list(response.keys()) + keys = [k for k in response.keys() if k != "status"] if not keys: return response @@ -229,7 +266,7 @@ def _corrupt_values(response: dict[str, Any], corrupt_rate: float) -> dict[str, """Replace a fraction of values with wrong types or garbage data.""" corruptions: list[Any] = [None, 99999, "", True, [], "CORRUPTED_DATA"] - keys = list(response.keys()) + keys = [k for k in response.keys() if k != "status"] if not keys: return response From 1bc20f3e471a728c7844db6868e097e2f47c33f8 Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Fri, 8 May 2026 22:56:15 +0000 Subject: [PATCH 3/9] refactor with contextvar --- src/strands_evals/chaos/__init__.py | 22 +- src/strands_evals/chaos/_context.py | 21 ++ src/strands_evals/chaos/effects.py | 314 +++++++++++++++++++------- src/strands_evals/chaos/experiment.py | 252 ++++++++++++++------- src/strands_evals/chaos/plugin.py | 314 ++++++++------------------ src/strands_evals/chaos/scenario.py | 70 +++--- 6 files changed, 569 insertions(+), 424 deletions(-) create mode 100644 src/strands_evals/chaos/_context.py diff --git a/src/strands_evals/chaos/__init__.py b/src/strands_evals/chaos/__init__.py index fbfc129e..1af6d808 100644 --- a/src/strands_evals/chaos/__init__.py +++ b/src/strands_evals/chaos/__init__.py @@ -7,19 +7,31 @@ from .effects import ( TOOL_CORRUPTION_EFFECTS, TOOL_ERROR_EFFECTS, - ToolChaosEffect, - ChaosEffectConfig, + ChaosEffect, + CorruptValues, + RemoveFields, + ToolCallFailure, + ToolEffect, + TruncateFields, ) from .experiment import ChaosExperiment from .plugin import ChaosPlugin from .scenario import ChaosScenario __all__ = [ - "ToolChaosEffect", - "ChaosEffectConfig", + # Core classes "ChaosExperiment", "ChaosPlugin", "ChaosScenario", - "TOOL_CORRUPTION_EFFECTS", + # Effect hierarchy + "ChaosEffect", + "ToolEffect", + # Concrete effects + "ToolCallFailure", + "TruncateFields", + "RemoveFields", + "CorruptValues", + # Classification sets "TOOL_ERROR_EFFECTS", + "TOOL_CORRUPTION_EFFECTS", ] diff --git a/src/strands_evals/chaos/_context.py b/src/strands_evals/chaos/_context.py new file mode 100644 index 00000000..8c8c0624 --- /dev/null +++ b/src/strands_evals/chaos/_context.py @@ -0,0 +1,21 @@ +"""Internal context variable for tracking the active chaos scenario. + +The ChaosPlugin reads from this ContextVar at hook time. +The ChaosExperiment sets and resets it around each case's task invocation. + +Using a ContextVar ensures correct behavior under: +- Sequential execution (trivially correct) +- Async execution (each asyncio.Task inherits the var from its parent) +- Threaded execution (each thread gets its own copy) +""" + +from contextvars import ContextVar +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .scenario import ChaosScenario + +_current_scenario: ContextVar["ChaosScenario | None"] = ContextVar( + "chaos_current_scenario", + default=None, +) diff --git a/src/strands_evals/chaos/effects.py b/src/strands_evals/chaos/effects.py index 5a29bd19..3f51ea79 100644 --- a/src/strands_evals/chaos/effects.py +++ b/src/strands_evals/chaos/effects.py @@ -1,117 +1,273 @@ """Chaos effect definitions. -Provides a flat enum of all chaos effects and an optional config class -for advanced parameterization. +Effects are first-class parameterized classes organized in a hierarchy: + ChaosEffect → ToolEffect → concrete effects (Timeout, NetworkError, etc.) + → ModelEffect → (reserved for v2) + +Each concrete effect carries only the parameters meaningful to it. +The `hook` class variable indicates whether the effect fires pre-tool-call +(error effects) or post-tool-call (corruption effects). + +Pre-hook effects provide `error_message` (the plugin cancels the tool call). +Post-hook effects implement `apply(response)` (the plugin passes the response through). """ -from enum import Enum -from typing import Optional +import math +import random +from abc import abstractmethod +from typing import Any, ClassVar, Literal + +from pydantic import BaseModel, Field -from pydantic import BaseModel, Field, model_validator +# --------------------------------------------------------------------------- +# Base classes +# --------------------------------------------------------------------------- -class ToolChaosEffect(str, Enum): - """All chaos effects a tool can experience. - Error effects cause the tool call to be cancelled before execution. - Corruption effects mutate the tool response after execution. +class ChaosEffect(BaseModel): + """Base for all chaos effects. + + Attributes: + apply_rate: Probability that this effect fires. + In v1 this field is accepted but ignored (always fires). + hook: Whether this effect fires pre-call ("pre") or post-call ("post"). """ - # Error effects (tool call is cancelled with a simulated error) - TIMEOUT = "timeout" - NETWORK_ERROR = "network_error" - EXECUTION_ERROR = "execution_error" - VALIDATION_ERROR = "validation_error" + hook: ClassVar[Literal["pre", "post"]] - # Response corruption effects (tool executes, response is mangled) - TRUNCATE_FIELDS = "truncate_fields" - REMOVE_FIELDS = "remove_fields" - CORRUPT_VALUES = "corrupt_values" + apply_rate: float = Field( + default=1.0, + ge=0.0, + le=1.0, + description="Probability that this effect fires (1.0 = always).", + ) + @abstractmethod + def apply(self, context: Any = None) -> Any: + """Apply the chaos effect. + Pre-hook effects return an error message string. + Post-hook effects accept a response dict and return the corrupted dict. + """ + ... +class ToolEffect(ChaosEffect): + """Effect valid at the tool invocation boundary. -# Sets for classification -TOOL_ERROR_EFFECTS = { - ToolChaosEffect.TIMEOUT, - ToolChaosEffect.NETWORK_ERROR, - ToolChaosEffect.EXECUTION_ERROR, - ToolChaosEffect.VALIDATION_ERROR, -} + - "pre": effect fires before tool execution (cancels the call with an error) + - "post": effect fires after tool execution (corrupts the response) + """ + + +# --------------------------------------------------------------------------- +# Pre-hook effect — cancels the tool call before execution +# --------------------------------------------------------------------------- -TOOL_CORRUPTION_EFFECTS = { - ToolChaosEffect.TRUNCATE_FIELDS, - ToolChaosEffect.REMOVE_FIELDS, - ToolChaosEffect.CORRUPT_VALUES, +# All supported failure types +ToolCallFailureType = Literal["timeout", "network_error", "execution_error", "validation_error"] + +# Default error messages per failure type +_DEFAULT_ERROR_MESSAGES: dict[str, str] = { + "timeout": "Tool call timed out", + "network_error": "Network unreachable", + "execution_error": "Tool execution failed", + "validation_error": "Tool input validation failed", } -class ChaosEffectConfig(BaseModel): - """Advanced chaos effect configuration. +class ToolCallFailure(ToolEffect): + """Simulates a tool call failure that prevents the tool from executing. - Use when the bare ToolChaosEffect enum needs tuning. Most users only need - the enum value directly in ChaosScenario.tool_effects. + The tool call is cancelled before execution with a simulated error message. Example:: - # Simple (90% of cases): just use the enum - tool_effects={"search_tool": ToolChaosEffect.TIMEOUT} - - # Advanced: custom error message - tool_effects={"search_tool": ChaosEffectConfig( - effect=ToolChaosEffect.TIMEOUT, - error_message="Request timed out after 30s", - )} + ChaosScenario( + name="search_timeout", + effects={"search_tool": [ToolCallFailure(error_type="timeout")]}, + ) - # Advanced: tune corruption ratio - tool_effects={"db_tool": ChaosEffectConfig( - effect=ToolChaosEffect.REMOVE_FIELDS, - remove_ratio=0.5, - )} + ChaosScenario( + name="db_network_error", + effects={"database_tool": [ToolCallFailure( + error_type="network_error", + error_message="Connection refused on port 5432", + )]}, + ) """ - effect: ToolChaosEffect = Field(..., description="The chaos effect to apply") - apply_rate: float = Field( - default=1.0, - ge=0.0, - le=1.0, - description="Probability that this effect fires (1.0 = always, deterministic)", + hook: ClassVar[Literal["pre", "post"]] = "pre" + error_type: ToolCallFailureType = Field( + default="execution_error", + description="Type of failure to simulate.", ) - - # Error-specific parameters - error_message: Optional[str] = Field( + error_message: str | None = Field( default=None, - description="Custom error message (for TIMEOUT, NETWORK_ERROR, EXECUTION_ERROR, VALIDATION_ERROR)", + description="Custom error message. If None, uses a default for the error_type.", ) - # Corruption-specific parameters - remove_ratio: Optional[float] = Field( - default=None, + def apply(self, context: Any = None) -> str: + """Return the error message to cancel the tool call with.""" + if self.error_message is not None: + return self.error_message + return _DEFAULT_ERROR_MESSAGES[self.error_type] + + +# --------------------------------------------------------------------------- +# Concrete tool corruption effects (post-hook — mutate the response) +# +# Post-hook effects implement apply(response) -> response. +# The plugin calls effect.apply(response_dict) and uses the return value. +# --------------------------------------------------------------------------- + + +class TruncateFields(ToolEffect): + """Truncates string values in the tool response. + + The tool executes normally, but string fields in the response are + truncated to at most `max_length` characters. + + Example:: + + ChaosScenario( + name="search_truncated", + effects={ + "search_tool": [TruncateFields(max_length=5)], + }, + ) + """ + + hook: ClassVar[Literal["pre", "post"]] = "post" + max_length: int = Field(default=10, ge=0, description="Maximum length to truncate string values to") + + def apply(self, response: dict[str, Any]) -> dict[str, Any]: + """Truncate string values to max_length. + + Args: + response: The tool response dict to corrupt. + + Returns: + Response with string values truncated. + """ + result: dict[str, Any] = {} + for key, value in response.items(): + if isinstance(value, str) and len(value) > self.max_length: + result[key] = value[: self.max_length] + elif isinstance(value, dict): + result[key] = self._truncate(value) + else: + result[key] = value + return result + + +class RemoveFields(ToolEffect): + """Removes a fraction of fields from the tool response. + + The tool executes normally, but a portion of the response fields + are deleted. + + Example:: + + ChaosScenario( + name="db_remove_fields", + effects={ + "database_tool": [RemoveFields(remove_ratio=0.5)], + }, + ) + """ + + hook: ClassVar[Literal["pre", "post"]] = "post" + remove_ratio: float = Field( + default=0.5, ge=0.0, le=1.0, - description="Fraction of fields to remove (for REMOVE_FIELDS, default 0.33)", + description="Fraction of fields to remove from the response", ) - corrupt_rate: Optional[float] = Field( - default=None, + + def apply(self, response: dict[str, Any]) -> dict[str, Any]: + """Remove a fraction of fields from the response. + + Always removes at least 1 field when called. + + Args: + response: The tool response dict to corrupt. + + Returns: + Response with fields removed. + """ + keys = list(response.keys()) + if not keys: + return response + + num_to_remove = max(1, math.ceil(len(keys) * self.remove_ratio)) + keys_to_remove = set(random.sample(keys, min(num_to_remove, len(keys)))) + return {k: v for k, v in response.items() if k not in keys_to_remove} + + +class CorruptValues(ToolEffect): + """Replaces a fraction of values with garbage data. + + The tool executes normally, but a portion of the response values + are replaced with wrong types or nonsense data. + + Example:: + + ChaosScenario( + name="db_corrupt", + effects={ + "database_tool": [CorruptValues(corrupt_ratio=0.8)], + }, + ) + """ + + hook: ClassVar[Literal["pre", "post"]] = "post" + corrupt_ratio: float = Field( + default=0.5, ge=0.0, le=1.0, - description="Fraction of values to corrupt (for CORRUPT_VALUES, default 0.4)", + description="Fraction of values to corrupt in the response", ) - @model_validator(mode="after") - def validate_params_match_effect(self) -> "ChaosEffectConfig": - """Warn if irrelevant parameters are set for the given effect.""" - if self.effect in TOOL_ERROR_EFFECTS: - if self.remove_ratio is not None or self.corrupt_rate is not None: - raise ValueError( - f"remove_ratio and corrupt_rate are not applicable to error effect " - f"'{self.effect.value}'. These params only apply to corruption effects." - ) - if self.effect in TOOL_CORRUPTION_EFFECTS: - if self.error_message is not None: - raise ValueError( - f"error_message is not applicable to corruption effect " - f"'{self.effect.value}'. This param only applies to error effects." - ) - return self + _CORRUPTIONS: ClassVar[list[Any]] = [None, 99999, "", True, [], "CORRUPTED_DATA"] + + def apply(self, response: dict[str, Any]) -> dict[str, Any]: + """Replace a fraction of values with wrong types or garbage data. + + Always corrupts at least 1 field when called. + + Args: + response: The tool response dict to corrupt. + + Returns: + Response with corrupted values. + """ + keys = list(response.keys()) + if not keys: + return response + + num_to_corrupt = max(1, math.ceil(len(keys) * self.corrupt_ratio)) + keys_to_corrupt = set(random.sample(keys, min(num_to_corrupt, len(keys)))) + + result: dict[str, Any] = {} + for key, value in response.items(): + if key in keys_to_corrupt: + candidates = [c for c in self._CORRUPTIONS if c != value] + result[key] = random.choice(candidates) if candidates else "CORRUPTED_DATA" + elif isinstance(value, dict): + result[key] = self.apply(value) + else: + result[key] = value + return result + + +# --------------------------------------------------------------------------- +# Convenience sets for classification (derived from hierarchy, not maintained manually) +# --------------------------------------------------------------------------- + +# All concrete pre-hook (error) effect classes +TOOL_ERROR_EFFECTS: set[type[ToolEffect]] = {ToolCallFailure} + +# All concrete post-hook (corruption) effect classes +TOOL_CORRUPTION_EFFECTS: set[type[ToolEffect]] = {TruncateFields, RemoveFields, CorruptValues} diff --git a/src/strands_evals/chaos/experiment.py b/src/strands_evals/chaos/experiment.py index 907e7b49..9e30eee5 100644 --- a/src/strands_evals/chaos/experiment.py +++ b/src/strands_evals/chaos/experiment.py @@ -1,10 +1,11 @@ """Chaos Experiment. -Extends the base Experiment to run test cases across multiple chaos scenarios, +Composes the base Experiment to run test cases across multiple chaos scenarios, providing deterministic evaluation of agent resilience under tool failures. """ import logging +import uuid from collections.abc import Callable from typing import Any, Optional @@ -12,154 +13,233 @@ from ..evaluators.evaluator import Evaluator from ..experiment import Experiment from ..types.evaluation_report import EvaluationReport -from .plugin import ChaosPlugin +from ._context import _current_scenario from .scenario import ChaosScenario logger = logging.getLogger(__name__) +# The baseline scenario — no chaos effects +_BASELINE_SCENARIO = ChaosScenario(name="baseline") -class ChaosExperiment(Experiment): - """Extends Experiment to run cases × chaos scenarios. - For each scenario, sets it as active on the ChaosPlugin, runs all cases - through the evaluators, then clears. Reports are tagged by scenario name. +class ChaosExperiment: + """Runs cases × scenarios by composing the base Experiment. + + For each scenario, activates it via ContextVar, runs all cases through + the evaluators, then resets. The user's task body contains zero chaos + concepts — the plugin reads the active scenario from the ContextVar. Optionally includes a baseline run (no chaos) for comparison. Example:: from strands_evals.chaos import ( - ToolChaosEffect, ChaosExperiment, ChaosPlugin, ChaosScenario, ) + from strands_evals.chaos.effects import ToolCallFailure chaos = ChaosPlugin() - agent = Agent(model=my_model, tools=[...], plugins=[chaos]) scenarios = [ - ChaosScenario(name="search_timeout", tool_effects={"search_tool": ToolChaosEffect.TIMEOUT}), - ChaosScenario(name="db_down", tool_effects={"database_tool": ToolChaosEffect.NETWORK_ERROR}), + ChaosScenario( + name="search_timeout", + effects={"search_tool": [ToolCallFailure(error_type="timeout")]}, + ), + ChaosScenario( + name="db_down", + effects={"database_tool": [ToolCallFailure(error_type="network_error")]}, + ), ] + def my_task(case): + agent = Agent(tools=[search_tool, database_tool], plugins=[chaos]) + return {"output": str(agent(case.input))} + experiment = ChaosExperiment( - chaos_plugin=chaos, - chaos_scenarios=scenarios, cases=[Case(input="Find flights to Tokyo", name="flight_search")], + scenarios=scenarios, evaluators=[my_evaluator], include_baseline=True, ) - reports = experiment.run_evaluations(task=lambda case: agent(case.input)) + reports = experiment.run_evaluations(task=my_task) """ def __init__( self, - chaos_plugin: ChaosPlugin, - chaos_scenarios: list[ChaosScenario], - cases: Optional[list[Case]] = None, + cases: list[Case], + scenarios: list[ChaosScenario], evaluators: Optional[list[Evaluator]] = None, include_baseline: bool = True, - baseline_assertion: Optional[str] = None, ): """Initialize a ChaosExperiment. Args: - chaos_plugin: The ChaosPlugin instance attached to the agent. - chaos_scenarios: List of scenarios to evaluate. Each scenario runs - all cases independently. - cases: Test cases to evaluate (same as base Experiment). - evaluators: Evaluators to assess results (same as base Experiment). + cases: Test cases to evaluate. + scenarios: List of chaos scenarios. Each scenario runs all cases. + All effects in a scenario fire simultaneously in a single run. + evaluators: Evaluators to assess results. include_baseline: If True, runs all cases with no chaos first for comparison. - baseline_assertion: Optional assertion string for baseline evaluation - (e.g., "agent should respond correctly without any failures"). """ - super().__init__(cases=cases, evaluators=evaluators) - self.chaos_plugin = chaos_plugin - self.chaos_scenarios = chaos_scenarios - self.include_baseline = include_baseline - self.baseline_assertion = baseline_assertion + self._original_cases = cases + self._scenarios = scenarios + self._evaluators = evaluators + self._include_baseline = include_baseline + + # Build the expanded case list and internal maps + self._expanded_cases: list[Case] = [] + self._scenario_by_session: dict[str, ChaosScenario] = {} + self._original_case_name_by_session: dict[str, Optional[str]] = {} + + all_scenarios = [] + if include_baseline: + all_scenarios.append(_BASELINE_SCENARIO) + all_scenarios.extend(scenarios) + + for case in cases: + for scenario in all_scenarios: + # Create expanded case with fresh session_id + session_id = str(uuid.uuid4()) + expanded_case = case.model_copy( + update={ + "name": f"{case.name}|{scenario.name}" if case.name else scenario.name, + "session_id": session_id, + } + ) + self._expanded_cases.append(expanded_case) + self._scenario_by_session[session_id] = scenario + self._original_case_name_by_session[session_id] = case.name + + # Internal Experiment with expanded cases + self._experiment = Experiment( + cases=self._expanded_cases, + evaluators=evaluators, + ) + + @property + def scenarios(self) -> list[ChaosScenario]: + """The chaos scenarios configured for this experiment.""" + return self._scenarios + + @property + def cases(self) -> list[Case]: + """The original (unexpanded) test cases.""" + return self._original_cases + + def get_scenario_for_session(self, session_id: str) -> Optional[ChaosScenario]: + """Look up the scenario assigned to a given session_id. + + Useful for downstream aggregation and reporting. + + Args: + session_id: The session_id of an expanded case. + + Returns: + The ChaosScenario for that session, or None if not found. + """ + return self._scenario_by_session.get(session_id) + + def get_original_case_name(self, session_id: str) -> Optional[str]: + """Look up the original case name for a given session_id. + + Args: + session_id: The session_id of an expanded case. + + Returns: + The original case name, or None if not found. + """ + return self._original_case_name_by_session.get(session_id) def run_evaluations( self, task: Callable[[Case], Any], **kwargs, ) -> list[EvaluationReport]: - """Run evaluations across all scenarios (and optionally baseline). + """Run evaluations across all (case × scenario) combinations. - Executes the task for each (scenario, case) pair: - 1. If include_baseline=True, runs all cases with no chaos active. - 2. For each scenario, activates it on the plugin, runs all cases, - then clears. - - Results from all runs are collected into a flat list of EvaluationReports, - with each case tagged with its scenario name in metadata. + Wraps the user's task function to set the ContextVar before each + case execution, so the ChaosPlugin sees the correct scenario. Args: task: The task function to evaluate. Takes a Case and returns output. - **kwargs: Additional kwargs passed to the base run_evaluations. + The task body should contain zero chaos concepts — just construct + the agent with plugins=[chaos] and call it. + **kwargs: Additional kwargs passed to the base Experiment.run_evaluations. Returns: List of EvaluationReport objects covering all scenarios. """ - all_reports: list[EvaluationReport] = [] - - # Baseline run (no chaos) - if self.include_baseline: - logger.info("Running baseline evaluation (no chaos)") - self.chaos_plugin.set_active_scenario(None) - baseline_cases = self._tag_cases_with_scenario("baseline") - original_cases = self._cases - self._cases = baseline_cases - try: - reports = super().run_evaluations(task, **kwargs) - all_reports.extend(reports) - finally: - self._cases = original_cases - - # Chaos scenario runs - for scenario in self.chaos_scenarios: - logger.info(f"Running chaos scenario: {scenario.name}") - self.chaos_plugin.set_active_scenario(scenario) - scenario_cases = self._tag_cases_with_scenario(scenario.name) - original_cases = self._cases - self._cases = scenario_cases + + def chaos_aware_task(case: Case) -> Any: + """Wrapper that activates the correct scenario via ContextVar.""" + scenario = self._scenario_by_session.get(case.session_id) + token = _current_scenario.set(scenario) try: - reports = super().run_evaluations(task, **kwargs) - all_reports.extend(reports) + return task(case) finally: - self._cases = original_cases + _current_scenario.reset(token) + + reports = self._experiment.run_evaluations(chaos_aware_task, **kwargs) - # Clear active scenario after all runs - self.chaos_plugin.set_active_scenario(None) + num_scenarios = len(self._scenarios) + (1 if self._include_baseline else 0) logger.info( - f"Chaos experiment complete: {len(all_reports)} reports " - f"({1 if self.include_baseline else 0} baseline + " - f"{len(self.chaos_scenarios)} scenarios)" + f"Chaos experiment complete: {len(reports)} reports " + f"({len(self._original_cases)} cases × {num_scenarios} scenarios)" ) - return all_reports + return reports - def _tag_cases_with_scenario(self, scenario_name: str) -> list[Case]: - """Create copies of cases with scenario name injected into metadata. + async def run_evaluations_async( + self, + task: Callable[[Case], Any], + max_workers: int = 10, + **kwargs, + ) -> list[EvaluationReport]: + """Run evaluations asynchronously across all (case × scenario) combinations. + + Same as run_evaluations but uses the async worker pool for parallelism. + ContextVar ensures each case sees its own scenario even under concurrency. Args: - scenario_name: The scenario name to tag. + task: The task function (sync or async). + max_workers: Maximum number of parallel workers. + **kwargs: Additional kwargs passed to the base Experiment.run_evaluations_async. Returns: - Deep copies of all cases with metadata["chaos_scenario"] set. + List of EvaluationReport objects covering all scenarios. """ - tagged_cases = [] - for case in self._cases: - tagged = case.model_copy(deep=True) - if tagged.metadata is None: - tagged.metadata = {} - tagged.metadata["chaos_scenario"] = scenario_name - # Update case name to include scenario for report clarity - if tagged.name: - tagged.name = f"{tagged.name} [{scenario_name}]" - else: - tagged.name = f"[{scenario_name}]" - tagged_cases.append(tagged) - return tagged_cases + import asyncio + + def chaos_aware_task(case: Case) -> Any: + """Wrapper that activates the correct scenario via ContextVar.""" + scenario = self._scenario_by_session.get(case.session_id) + token = _current_scenario.set(scenario) + try: + return task(case) + finally: + _current_scenario.reset(token) + + async def chaos_aware_task_async(case: Case) -> Any: + """Async wrapper that activates the correct scenario via ContextVar.""" + scenario = self._scenario_by_session.get(case.session_id) + token = _current_scenario.set(scenario) + try: + if asyncio.iscoroutinefunction(task): + return await task(case) + else: + return task(case) + finally: + _current_scenario.reset(token) + + if asyncio.iscoroutinefunction(task): + reports = await self._experiment.run_evaluations_async( + chaos_aware_task_async, max_workers=max_workers, **kwargs + ) + else: + reports = await self._experiment.run_evaluations_async( + chaos_aware_task, max_workers=max_workers, **kwargs + ) + + return reports diff --git a/src/strands_evals/chaos/plugin.py b/src/strands_evals/chaos/plugin.py index 5d57f272..e05c8f6f 100644 --- a/src/strands_evals/chaos/plugin.py +++ b/src/strands_evals/chaos/plugin.py @@ -2,23 +2,24 @@ Implements chaos injection as a standard Strands Plugin using the SDK's native hook system (BeforeToolCallEvent / AfterToolCallEvent). + +The plugin is stateless — it reads the active scenario from a module-level +ContextVar at hook time. The ChaosExperiment manages the ContextVar lifecycle. + +The plugin is a thin router: +- Pre-hook effects: reads effect.error_message, cancels the tool call. +- Post-hook effects: calls effect.apply(response), uses the return value. """ +import json import logging -import math -import random -from typing import Any, Optional +from typing import Any from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent from strands.plugins import Plugin, hook -from .effects import ( - TOOL_CORRUPTION_EFFECTS, - TOOL_ERROR_EFFECTS, - ToolChaosEffect, - ChaosEffectConfig, -) -from .scenario import ChaosScenario +from ._context import _current_scenario +from .effects import ChaosEffect logger = logging.getLogger(__name__) @@ -27,16 +28,19 @@ class ChaosPlugin(Plugin): """Strands Plugin that injects deterministic chaos based on the active scenario. The plugin intercepts tool calls via Strands' native hook system: - - BeforeToolCallEvent: cancels tool calls for error effects (TIMEOUT, NETWORK_ERROR, etc.) - - AfterToolCallEvent: corrupts tool responses for corruption effects (TRUNCATE_FIELDS, etc.) + - BeforeToolCallEvent: cancels tool calls for pre-hook effects (Timeout, NetworkError, etc.) + - AfterToolCallEvent: corrupts tool responses for post-hook effects (TruncateFields, etc.) - The active scenario is set externally (typically by ChaosExperiment) before - each evaluation run. When no scenario is active, all tools behave normally. + The active scenario is managed via a ContextVar (set by ChaosExperiment). + When no scenario is active, all tools behave normally. + + The plugin is stateless — no set_active_scenario method, no instance state + for the current scenario. This makes it safe under concurrent execution. Example:: from strands import Agent - from strands_evals.chaos import ChaosPlugin, ChaosScenario, ToolChaosEffect + from strands_evals.chaos import ChaosPlugin chaos = ChaosPlugin() agent = Agent( @@ -45,241 +49,111 @@ class ChaosPlugin(Plugin): plugins=[chaos], ) - # Activate a scenario - chaos.set_active_scenario(ChaosScenario( - name="search_timeout", - tool_effects={"search_tool": ToolChaosEffect.TIMEOUT}, - )) - - result = agent("Find flights to Tokyo") - # search_tool will be cancelled with a timeout error + # The ChaosExperiment handles scenario activation via ContextVar. + # The user's task body contains zero chaos concepts. """ name = "chaos-testing" def __init__(self) -> None: super().__init__() - self._active_scenario: Optional[ChaosScenario] = None - - @property - def active_scenario(self) -> Optional[ChaosScenario]: - """The currently active chaos scenario, or None for baseline (no chaos).""" - return self._active_scenario - - def set_active_scenario(self, scenario: Optional[ChaosScenario]) -> None: - """Set the scenario that drives chaos injection for subsequent tool calls. - - Args: - scenario: The scenario to activate, or None to disable chaos (baseline). - """ - self._active_scenario = scenario - if scenario: - logger.info(f"Chaos scenario activated: {scenario.name}") - else: - logger.info("Chaos scenario cleared (baseline mode)") - - def _should_apply(self, config: ChaosEffectConfig) -> bool: - """Check if the effect should fire based on apply_rate.""" - if config.apply_rate >= 1.0: - return True - if config.apply_rate <= 0.0: - return False - return random.random() < config.apply_rate @hook def before_tool_call(self, event: BeforeToolCallEvent) -> None: - """Intercept tool calls to inject error effects. + """Intercept tool calls to inject pre-hook (error) effects. - For error effects (TIMEOUT, NETWORK_ERROR, etc.), cancels the tool call - with a simulated error message before the tool executes. + For error effects (Timeout, NetworkError, etc.), cancels the tool call + with the effect's error_message before the tool executes. """ - if not self._active_scenario: + scenario = _current_scenario.get() + if scenario is None: return tool_name = event.tool_use.get("name", "") - chaos_effect = self._active_scenario.tool_effects.get(tool_name) - if chaos_effect is None: - return - - if isinstance(chaos_effect, ToolChaosEffect): - chaos_config = ChaosEffectConfig(effect=chaos_effect) - elif isinstance(chaos_effect, ChaosEffectConfig): - chaos_config = chaos_effect - else: - raise TypeError( - f"Unexpected effect type for tool '{tool_name}': {type(chaos_effect).__name__}. " - f"Expected ToolChaosEffect or ChaosEffectConfig." - ) - - # Only handle error effects in the before hook - if chaos_config.effect not in TOOL_ERROR_EFFECTS: + effects = scenario.effects.get(tool_name, []) + if not effects: return - if not self._should_apply(chaos_config): - return - - # Cancel the tool call with a simulated error - error_message = chaos_config.error_message or f"Simulated {chaos_config.effect.value}" - event.cancel_tool = error_message - logger.info( - f"[Chaos] Injected {chaos_config.effect.value} on tool '{tool_name}': {error_message}" - ) + # First pre-hook effect wins (tool is cancelled once) + for effect in effects: + if effect.hook == "pre": + event.cancel_tool = effect.apply() + logger.info( + f"[Chaos] Injected {type(effect).__name__} on tool '{tool_name}'" + ) + return @hook def after_tool_call(self, event: AfterToolCallEvent) -> None: - """Intercept tool results to inject corruption effects. + """Intercept tool results to inject post-hook (corruption) effects. - For corruption effects (TRUNCATE_FIELDS, REMOVE_FIELDS, CORRUPT_VALUES), - mutates the tool response after successful execution. + For corruption effects (TruncateFields, RemoveFields, CorruptValues), + calls effect.apply(response) to mutate the tool response. + + Handles Strands ToolResult content shapes: + - dict content: pass directly to effect.apply() + - list of blocks: extract text dicts, parse JSON, apply effect + - plain dict result: pass directly to effect.apply() + + Envelope fields (status, toolUseId) are preserved around the corruption. """ - if not self._active_scenario: + scenario = _current_scenario.get() + if scenario is None: return tool_name = event.tool_use.get("name", "") - chaos_effect = self._active_scenario.tool_effects.get(tool_name) - if chaos_effect is None: + effects = scenario.effects.get(tool_name, []) + if not effects: return - if isinstance(chaos_effect, ToolChaosEffect): - chaos_config = ChaosEffectConfig(effect=chaos_effect) - elif isinstance(chaos_effect, ChaosEffectConfig): - chaos_config = chaos_effect - else: - raise TypeError( - f"Unexpected effect type for tool '{tool_name}': {type(chaos_effect).__name__}. " - f"Expected ToolChaosEffect or ChaosEffectConfig." - ) - - # Only handle corruption effects in the after hook - if chaos_config.effect not in TOOL_CORRUPTION_EFFECTS: - return + # Apply all post-hook effects sequentially (they compose) + for effect in effects: + if effect.hook != "post": + continue - if not self._should_apply(chaos_config): - return + if not hasattr(event, "result") or event.result is None: + continue - # Corrupt the tool result - if hasattr(event, "result") and event.result is not None: result = event.result - # Handle ToolResult-like objects with content + if hasattr(result, "content"): if isinstance(result.content, dict): - # Content is a data dict — corrupt it directly - result.content = self._apply_corruption(chaos_config, result.content) + result.content = self._apply_with_envelope(effect, result.content) elif isinstance(result.content, list): - # Content is a list of blocks (e.g., [{"text": "..."}]) - # Corrupt text content within each block - corrupted_blocks = [] - for block in result.content: - if isinstance(block, dict) and "text" in block: - text_data = block["text"] - if isinstance(text_data, str): - # Try to parse as JSON dict for field-level corruption - import json - try: - parsed = json.loads(text_data) - if isinstance(parsed, dict): - corrupted = self._apply_corruption(chaos_config, parsed) - block = {**block, "text": json.dumps(corrupted)} - else: - block = {**block, "text": text_data} - except (json.JSONDecodeError, ValueError): - # Plain text — truncate if that's the effect - if chaos_config.effect == ToolChaosEffect.TRUNCATE_FIELDS: - block = {**block, "text": text_data[: len(text_data) // 2]} - corrupted_blocks.append(block) - result.content = corrupted_blocks + result.content = self._apply_to_blocks(effect, result.content) elif isinstance(result, dict): - event.result = self._apply_corruption(chaos_config, result) - - logger.info( - f"[Chaos] Applied {chaos_config.effect.value} corruption on tool '{tool_name}'" - ) - - # ------------------------------------------------------------------ - # Corruption helpers (private) - # ------------------------------------------------------------------ - - def _apply_corruption(self, effect_config: ChaosEffectConfig, response: Any) -> Any: - """Apply a corruption effect to a tool response data dict. - - Only preserves the `status` field (Bedrock requires "success"/"error"). - All other fields including `content` are fair game for corruption. - - Args: - effect_config: The normalized effect configuration. - response: The tool result data dict to corrupt. - - Returns: - The corrupted response. - """ - if not isinstance(response, dict): - return response - - effect = effect_config.effect - - # Preserve status — Bedrock requires toolResult.status to be "success" or "error" - saved_status = response.get("status") - - if effect == ToolChaosEffect.TRUNCATE_FIELDS: - response = self._truncate_fields(response) - elif effect == ToolChaosEffect.REMOVE_FIELDS: - ratio = effect_config.remove_ratio if effect_config.remove_ratio is not None else 0.33 - response = self._remove_fields(response, ratio) - elif effect == ToolChaosEffect.CORRUPT_VALUES: - rate = effect_config.corrupt_rate if effect_config.corrupt_rate is not None else 0.4 - response = self._corrupt_values(response, rate) - - # Restore status - if saved_status is not None: - response["status"] = saved_status - - return response - - @staticmethod - def _truncate_fields(response: dict[str, Any]) -> dict[str, Any]: - """Truncate string values to partial content.""" - result: dict[str, Any] = {} - for key, value in response.items(): - if key == "status": - result[key] = value - elif isinstance(value, str) and len(value) > 0: - result[key] = value[: random.randint(0, max(0, len(value) - 1))] - elif isinstance(value, dict): - result[key] = ChaosPlugin._truncate_fields(value) - else: - result[key] = value - return result - - @staticmethod - def _remove_fields(response: dict[str, Any], remove_ratio: float) -> dict[str, Any]: - """Remove a fraction of fields from the response.""" - keys = [k for k in response.keys() if k != "status"] - if not keys: - return response - - num_to_remove = max(1, math.ceil(len(keys) * remove_ratio)) - keys_to_remove = set(random.sample(keys, min(num_to_remove, len(keys)))) - return {k: v for k, v in response.items() if k not in keys_to_remove} - - @staticmethod - def _corrupt_values(response: dict[str, Any], corrupt_rate: float) -> dict[str, Any]: - """Replace a fraction of values with wrong types or garbage data.""" - corruptions: list[Any] = [None, 99999, "", True, [], "CORRUPTED_DATA"] - - keys = [k for k in response.keys() if k != "status"] - if not keys: - return response - - num_to_corrupt = max(1, math.ceil(len(keys) * corrupt_rate)) - keys_to_corrupt = set(random.sample(keys, min(num_to_corrupt, len(keys)))) - - result: dict[str, Any] = {} - for key, value in response.items(): - if key in keys_to_corrupt: - candidates = [c for c in corruptions if c != value] - result[key] = random.choice(candidates) if candidates else "CORRUPTED_DATA" - elif isinstance(value, dict): - result[key] = ChaosPlugin._corrupt_values(value, corrupt_rate) - else: - result[key] = value - return result + event.result = self._apply_with_envelope(effect, result) + + logger.info(f"[Chaos] Applied {type(effect).__name__} on tool '{tool_name}'") + + def _apply_with_envelope(self, effect: ChaosEffect, response: dict[str, Any]) -> dict[str, Any]: + """Apply effect while preserving envelope fields.""" + envelope_fields = {"status", "toolUseId"} + saved = {k: response[k] for k in envelope_fields if k in response} + + # Strip envelope before passing to effect + payload = {k: v for k, v in response.items() if k not in envelope_fields} + corrupted = effect.apply(payload) + + # Restore envelope + corrupted.update(saved) + return corrupted + + def _apply_to_blocks(self, effect: ChaosEffect, blocks: list) -> list: + """Apply effect to text blocks in a content list.""" + corrupted_blocks = [] + for block in blocks: + if isinstance(block, dict) and "text" in block: + text_data = block["text"] + if isinstance(text_data, str): + try: + parsed = json.loads(text_data) + if isinstance(parsed, dict): + corrupted = effect.apply(parsed) + block = {**block, "text": json.dumps(corrupted)} + except (json.JSONDecodeError, ValueError): + # Plain text — apply truncation via effect if applicable + if hasattr(effect, "max_length"): + block = {**block, "text": text_data[: effect.max_length]} + corrupted_blocks.append(block) + return corrupted_blocks diff --git a/src/strands_evals/chaos/scenario.py b/src/strands_evals/chaos/scenario.py index d212d438..14d07c23 100644 --- a/src/strands_evals/chaos/scenario.py +++ b/src/strands_evals/chaos/scenario.py @@ -1,65 +1,67 @@ """Chaos scenario definition. -A ChaosScenario is a named, deterministic mapping of tool names to the -chaos effects that will fire when those tools are invoked. +A ChaosScenario is a named, deterministic configuration of chaos effects +that will fire simultaneously when the scenario is active. """ -from typing import Union +from typing import Optional from pydantic import BaseModel, Field -from .effects import ToolChaosEffect, ChaosEffectConfig - - -# Type alias for what a tool_effects value can be -EffectSpec = Union[ToolChaosEffect, ChaosEffectConfig] +from .effects import ChaosEffect class ChaosScenario(BaseModel): """A single, deterministic chaos injection scenario. - Each scenario maps tool names to the exact effect that fires when that - tool is invoked. Tools not listed in tool_effects behave normally (no chaos). + Each scenario defines a set of tool effects that fire simultaneously when + the scenario is active. All listed effects are applied in the same + agent execution — this is NOT expanded into multiple runs. + + Tools not listed in tool_effects behave normally (no chaos). Example:: - # Simple: one tool fails with timeout - ChaosScenario( - name="search_timeout", - tool_effects={"search_tool": ToolChaosEffect.TIMEOUT}, - ) + from strands_evals.chaos import ChaosScenario + from strands_evals.chaos.effects import Timeout, NetworkError, CorruptValues + + # Baseline — no chaos + ChaosScenario(name="baseline") - # Multiple tools affected + # Single-fault: one tool fails ChaosScenario( - name="both_tools_down", - tool_effects={ - "search_tool": ToolChaosEffect.TIMEOUT, - "database_tool": ToolChaosEffect.NETWORK_ERROR, - }, + name="search_timeout", + effects={"search_tool": [Timeout()]}, ) - # Advanced: custom parameters + # Compound: multiple tools/models fail simultaneously ChaosScenario( - name="partial_corruption", - tool_effects={ - "database_tool": ChaosEffectConfig( - effect=ToolChaosEffect.REMOVE_FIELDS, - remove_ratio=0.5, - ), + name="search_times_out_while_book_corrupts", + description=( + "Worst-case compound: primary path fails hard while the " + "recovery path silently returns bad data." + ), + effects={ + "search_tool": [Timeout()], + "book_tool": [CorruptValues(corrupt_ratio=0.8)], }, ) """ name: str = Field(..., description="Human-readable name for this scenario") - tool_effects: dict[str, EffectSpec] = Field( + description: Optional[str] = Field( + default=None, + description="Optional description of what this scenario tests.", + ) + effects: dict[str, list[ChaosEffect]] = Field( default_factory=dict, - description="Mapping of tool_name -> effect to inject. " - "Tools not listed here behave normally.", + description="Mapping of target_name -> list of effects to inject simultaneously. " + "Targets not listed here behave normally.", ) def __repr__(self) -> str: effects_str = ", ".join( - f"{tool}: {eff.value if isinstance(eff, ToolChaosEffect) else eff.effect.value}" - for tool, eff in self.tool_effects.items() + f"{target}: [{', '.join(type(e).__name__ for e in effs)}]" + for target, effs in self.effects.items() ) - return f"ChaosScenario(name='{self.name}', tool_effects={{{effects_str}}})" + return f"ChaosScenario(name='{self.name}', effects={{{effects_str}}})" From 68db60c859d847819002bbd4b37547d89a87bde6 Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Mon, 11 May 2026 23:26:21 +0000 Subject: [PATCH 4/9] improve style --- src/strands_evals/chaos/__init__.py | 5 --- src/strands_evals/chaos/effects.py | 55 ++++++-------------------- src/strands_evals/chaos/experiment.py | 16 ++++---- src/strands_evals/chaos/plugin.py | 56 ++++++--------------------- src/strands_evals/chaos/scenario.py | 6 +-- 5 files changed, 32 insertions(+), 106 deletions(-) diff --git a/src/strands_evals/chaos/__init__.py b/src/strands_evals/chaos/__init__.py index 1af6d808..1d8a5a93 100644 --- a/src/strands_evals/chaos/__init__.py +++ b/src/strands_evals/chaos/__init__.py @@ -5,8 +5,6 @@ """ from .effects import ( - TOOL_CORRUPTION_EFFECTS, - TOOL_ERROR_EFFECTS, ChaosEffect, CorruptValues, RemoveFields, @@ -31,7 +29,4 @@ "TruncateFields", "RemoveFields", "CorruptValues", - # Classification sets - "TOOL_ERROR_EFFECTS", - "TOOL_CORRUPTION_EFFECTS", ] diff --git a/src/strands_evals/chaos/effects.py b/src/strands_evals/chaos/effects.py index 3f51ea79..b93d1362 100644 --- a/src/strands_evals/chaos/effects.py +++ b/src/strands_evals/chaos/effects.py @@ -1,15 +1,11 @@ """Chaos effect definitions. Effects are first-class parameterized classes organized in a hierarchy: - ChaosEffect → ToolEffect → concrete effects (Timeout, NetworkError, etc.) - → ModelEffect → (reserved for v2) + ChaosEffect → ToolEffect → concrete effects (ToolCallFailure, TruncateFields, etc.) Each concrete effect carries only the parameters meaningful to it. The `hook` class variable indicates whether the effect fires pre-tool-call (error effects) or post-tool-call (corruption effects). - -Pre-hook effects provide `error_message` (the plugin cancels the tool call). -Post-hook effects implement `apply(response)` (the plugin passes the response through). """ import math @@ -20,17 +16,11 @@ from pydantic import BaseModel, Field -# --------------------------------------------------------------------------- -# Base classes -# --------------------------------------------------------------------------- - - class ChaosEffect(BaseModel): """Base for all chaos effects. Attributes: apply_rate: Probability that this effect fires. - In v1 this field is accepted but ignored (always fires). hook: Whether this effect fires pre-call ("pre") or post-call ("post"). """ @@ -45,11 +35,7 @@ class ChaosEffect(BaseModel): @abstractmethod def apply(self, context: Any = None) -> Any: - """Apply the chaos effect. - - Pre-hook effects return an error message string. - Post-hook effects accept a response dict and return the corrupted dict. - """ + """Apply the chaos effect to the given context and return the result.""" ... @@ -61,10 +47,6 @@ class ToolEffect(ChaosEffect): """ -# --------------------------------------------------------------------------- -# Pre-hook effect — cancels the tool call before execution -# --------------------------------------------------------------------------- - # All supported failure types ToolCallFailureType = Literal["timeout", "network_error", "execution_error", "validation_error"] @@ -115,14 +97,6 @@ def apply(self, context: Any = None) -> str: return _DEFAULT_ERROR_MESSAGES[self.error_type] -# --------------------------------------------------------------------------- -# Concrete tool corruption effects (post-hook — mutate the response) -# -# Post-hook effects implement apply(response) -> response. -# The plugin calls effect.apply(response_dict) and uses the return value. -# --------------------------------------------------------------------------- - - class TruncateFields(ToolEffect): """Truncates string values in the tool response. @@ -142,7 +116,7 @@ class TruncateFields(ToolEffect): hook: ClassVar[Literal["pre", "post"]] = "post" max_length: int = Field(default=10, ge=0, description="Maximum length to truncate string values to") - def apply(self, response: dict[str, Any]) -> dict[str, Any]: + def apply(self, response: Any = None) -> Any: """Truncate string values to max_length. Args: @@ -151,12 +125,14 @@ def apply(self, response: dict[str, Any]) -> dict[str, Any]: Returns: Response with string values truncated. """ + if not isinstance(response, dict): + return response result: dict[str, Any] = {} for key, value in response.items(): if isinstance(value, str) and len(value) > self.max_length: result[key] = value[: self.max_length] elif isinstance(value, dict): - result[key] = self._truncate(value) + result[key] = self.apply(value) else: result[key] = value return result @@ -186,7 +162,7 @@ class RemoveFields(ToolEffect): description="Fraction of fields to remove from the response", ) - def apply(self, response: dict[str, Any]) -> dict[str, Any]: + def apply(self, response: Any = None) -> Any: """Remove a fraction of fields from the response. Always removes at least 1 field when called. @@ -197,6 +173,8 @@ def apply(self, response: dict[str, Any]) -> dict[str, Any]: Returns: Response with fields removed. """ + if not isinstance(response, dict): + return response keys = list(response.keys()) if not keys: return response @@ -232,7 +210,7 @@ class CorruptValues(ToolEffect): _CORRUPTIONS: ClassVar[list[Any]] = [None, 99999, "", True, [], "CORRUPTED_DATA"] - def apply(self, response: dict[str, Any]) -> dict[str, Any]: + def apply(self, response: Any = None) -> Any: """Replace a fraction of values with wrong types or garbage data. Always corrupts at least 1 field when called. @@ -243,6 +221,8 @@ def apply(self, response: dict[str, Any]) -> dict[str, Any]: Returns: Response with corrupted values. """ + if not isinstance(response, dict): + return response keys = list(response.keys()) if not keys: return response @@ -260,14 +240,3 @@ def apply(self, response: dict[str, Any]) -> dict[str, Any]: else: result[key] = value return result - - -# --------------------------------------------------------------------------- -# Convenience sets for classification (derived from hierarchy, not maintained manually) -# --------------------------------------------------------------------------- - -# All concrete pre-hook (error) effect classes -TOOL_ERROR_EFFECTS: set[type[ToolEffect]] = {ToolCallFailure} - -# All concrete post-hook (corruption) effect classes -TOOL_CORRUPTION_EFFECTS: set[type[ToolEffect]] = {TruncateFields, RemoveFields, CorruptValues} diff --git a/src/strands_evals/chaos/experiment.py b/src/strands_evals/chaos/experiment.py index 9e30eee5..58ba2c7f 100644 --- a/src/strands_evals/chaos/experiment.py +++ b/src/strands_evals/chaos/experiment.py @@ -18,9 +18,6 @@ logger = logging.getLogger(__name__) -# The baseline scenario — no chaos effects -_BASELINE_SCENARIO = ChaosScenario(name="baseline") - class ChaosExperiment: """Runs cases × scenarios by composing the base Experiment. @@ -31,6 +28,11 @@ class ChaosExperiment: Optionally includes a baseline run (no chaos) for comparison. + Note: The experiment runs ``len(cases) × (len(scenarios) + 1)`` evaluations + when ``include_baseline=True``, or ``len(cases) × len(scenarios)`` otherwise. + Plan scenario counts accordingly — each combination triggers a full agent + invocation. + Example:: from strands_evals.chaos import ( @@ -93,9 +95,7 @@ def __init__( self._scenario_by_session: dict[str, ChaosScenario] = {} self._original_case_name_by_session: dict[str, Optional[str]] = {} - all_scenarios = [] - if include_baseline: - all_scenarios.append(_BASELINE_SCENARIO) + all_scenarios = [ChaosScenario(name="baseline")] if include_baseline else [] all_scenarios.extend(scenarios) for case in cases: @@ -238,8 +238,6 @@ async def chaos_aware_task_async(case: Case) -> Any: chaos_aware_task_async, max_workers=max_workers, **kwargs ) else: - reports = await self._experiment.run_evaluations_async( - chaos_aware_task, max_workers=max_workers, **kwargs - ) + reports = await self._experiment.run_evaluations_async(chaos_aware_task, max_workers=max_workers, **kwargs) return reports diff --git a/src/strands_evals/chaos/plugin.py b/src/strands_evals/chaos/plugin.py index e05c8f6f..5a812145 100644 --- a/src/strands_evals/chaos/plugin.py +++ b/src/strands_evals/chaos/plugin.py @@ -3,17 +3,12 @@ Implements chaos injection as a standard Strands Plugin using the SDK's native hook system (BeforeToolCallEvent / AfterToolCallEvent). -The plugin is stateless — it reads the active scenario from a module-level -ContextVar at hook time. The ChaosExperiment manages the ContextVar lifecycle. - -The plugin is a thin router: -- Pre-hook effects: reads effect.error_message, cancels the tool call. -- Post-hook effects: calls effect.apply(response), uses the return value. +The plugin reads the active scenario from a module-level ContextVar at hook +time. The ChaosExperiment manages the ContextVar lifecycle. """ import json import logging -from typing import Any from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent from strands.plugins import Plugin, hook @@ -28,15 +23,12 @@ class ChaosPlugin(Plugin): """Strands Plugin that injects deterministic chaos based on the active scenario. The plugin intercepts tool calls via Strands' native hook system: - - BeforeToolCallEvent: cancels tool calls for pre-hook effects (Timeout, NetworkError, etc.) + - BeforeToolCallEvent: cancels tool calls for pre-hook effects (ToolCallFailure) - AfterToolCallEvent: corrupts tool responses for post-hook effects (TruncateFields, etc.) The active scenario is managed via a ContextVar (set by ChaosExperiment). When no scenario is active, all tools behave normally. - The plugin is stateless — no set_active_scenario method, no instance state - for the current scenario. This makes it safe under concurrent execution. - Example:: from strands import Agent @@ -58,7 +50,7 @@ class ChaosPlugin(Plugin): def __init__(self) -> None: super().__init__() - @hook + @hook # type: ignore[call-overload] def before_tool_call(self, event: BeforeToolCallEvent) -> None: """Intercept tool calls to inject pre-hook (error) effects. @@ -78,24 +70,15 @@ def before_tool_call(self, event: BeforeToolCallEvent) -> None: for effect in effects: if effect.hook == "pre": event.cancel_tool = effect.apply() - logger.info( - f"[Chaos] Injected {type(effect).__name__} on tool '{tool_name}'" - ) + logger.info(f"[Chaos] Injected {type(effect).__name__} on tool '{tool_name}'") return - @hook + @hook # type: ignore[call-overload] def after_tool_call(self, event: AfterToolCallEvent) -> None: """Intercept tool results to inject post-hook (corruption) effects. For corruption effects (TruncateFields, RemoveFields, CorruptValues), - calls effect.apply(response) to mutate the tool response. - - Handles Strands ToolResult content shapes: - - dict content: pass directly to effect.apply() - - list of blocks: extract text dicts, parse JSON, apply effect - - plain dict result: pass directly to effect.apply() - - Envelope fields (status, toolUseId) are preserved around the corruption. + applies effect.apply() to JSON content blocks in the tool response. """ scenario = _current_scenario.get() if scenario is None: @@ -111,34 +94,17 @@ def after_tool_call(self, event: AfterToolCallEvent) -> None: if effect.hook != "post": continue - if not hasattr(event, "result") or event.result is None: + if event.result is None: continue result = event.result + content = result.get("content") - if hasattr(result, "content"): - if isinstance(result.content, dict): - result.content = self._apply_with_envelope(effect, result.content) - elif isinstance(result.content, list): - result.content = self._apply_to_blocks(effect, result.content) - elif isinstance(result, dict): - event.result = self._apply_with_envelope(effect, result) + if isinstance(content, list): + result["content"] = self._apply_to_blocks(effect, content) # type: ignore[assignment] logger.info(f"[Chaos] Applied {type(effect).__name__} on tool '{tool_name}'") - def _apply_with_envelope(self, effect: ChaosEffect, response: dict[str, Any]) -> dict[str, Any]: - """Apply effect while preserving envelope fields.""" - envelope_fields = {"status", "toolUseId"} - saved = {k: response[k] for k in envelope_fields if k in response} - - # Strip envelope before passing to effect - payload = {k: v for k, v in response.items() if k not in envelope_fields} - corrupted = effect.apply(payload) - - # Restore envelope - corrupted.update(saved) - return corrupted - def _apply_to_blocks(self, effect: ChaosEffect, blocks: list) -> list: """Apply effect to text blocks in a content list.""" corrupted_blocks = [] diff --git a/src/strands_evals/chaos/scenario.py b/src/strands_evals/chaos/scenario.py index 14d07c23..ee127479 100644 --- a/src/strands_evals/chaos/scenario.py +++ b/src/strands_evals/chaos/scenario.py @@ -15,8 +15,7 @@ class ChaosScenario(BaseModel): """A single, deterministic chaos injection scenario. Each scenario defines a set of tool effects that fire simultaneously when - the scenario is active. All listed effects are applied in the same - agent execution — this is NOT expanded into multiple runs. + the scenario is active. Tools not listed in tool_effects behave normally (no chaos). @@ -61,7 +60,6 @@ class ChaosScenario(BaseModel): def __repr__(self) -> str: effects_str = ", ".join( - f"{target}: [{', '.join(type(e).__name__ for e in effs)}]" - for target, effs in self.effects.items() + f"{target}: [{', '.join(type(e).__name__ for e in effs)}]" for target, effs in self.effects.items() ) return f"ChaosScenario(name='{self.name}', effects={{{effects_str}}})" From 28a0679274b8c00355e99ded73adb6b5a88590bb Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Mon, 11 May 2026 23:32:38 +0000 Subject: [PATCH 5/9] add tests --- tests/strands_evals/chaos/__init__.py | 1 + tests/strands_evals/chaos/test_context.py | 33 ++++ tests/strands_evals/chaos/test_effects.py | 163 ++++++++++++++++ tests/strands_evals/chaos/test_experiment.py | 161 +++++++++++++++ tests/strands_evals/chaos/test_plugin.py | 194 +++++++++++++++++++ tests/strands_evals/chaos/test_scenario.py | 45 +++++ 6 files changed, 597 insertions(+) create mode 100644 tests/strands_evals/chaos/__init__.py create mode 100644 tests/strands_evals/chaos/test_context.py create mode 100644 tests/strands_evals/chaos/test_effects.py create mode 100644 tests/strands_evals/chaos/test_experiment.py create mode 100644 tests/strands_evals/chaos/test_plugin.py create mode 100644 tests/strands_evals/chaos/test_scenario.py diff --git a/tests/strands_evals/chaos/__init__.py b/tests/strands_evals/chaos/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/strands_evals/chaos/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/strands_evals/chaos/test_context.py b/tests/strands_evals/chaos/test_context.py new file mode 100644 index 00000000..5b93ced5 --- /dev/null +++ b/tests/strands_evals/chaos/test_context.py @@ -0,0 +1,33 @@ +"""Unit tests for the chaos _context module.""" + +from strands_evals.chaos import ChaosScenario +from strands_evals.chaos._context import _current_scenario + + +class TestContextVar: + """Tests for the _current_scenario ContextVar.""" + + def test_set_and_get(self): + scenario = ChaosScenario(name="test_scenario") + token = _current_scenario.set(scenario) + try: + assert _current_scenario.get() is scenario + assert _current_scenario.get().name == "test_scenario" + finally: + _current_scenario.reset(token) + + def test_nested_set_and_reset(self): + s1 = ChaosScenario(name="outer") + s2 = ChaosScenario(name="inner") + + token1 = _current_scenario.set(s1) + try: + assert _current_scenario.get().name == "outer" + token2 = _current_scenario.set(s2) + try: + assert _current_scenario.get().name == "inner" + finally: + _current_scenario.reset(token2) + assert _current_scenario.get().name == "outer" + finally: + _current_scenario.reset(token1) diff --git a/tests/strands_evals/chaos/test_effects.py b/tests/strands_evals/chaos/test_effects.py new file mode 100644 index 00000000..c55821d9 --- /dev/null +++ b/tests/strands_evals/chaos/test_effects.py @@ -0,0 +1,163 @@ +"""Unit tests for chaos effect classes.""" + +import random + +import pytest + +from strands_evals.chaos.effects import ( + CorruptValues, + RemoveFields, + ToolCallFailure, + TruncateFields, +) + + +class TestToolCallFailure: + """Tests for the ToolCallFailure pre-hook effect.""" + + @pytest.mark.parametrize( + "error_type,expected_message", + [ + ("timeout", "Tool call timed out"), + ("network_error", "Network unreachable"), + ("execution_error", "Tool execution failed"), + ("validation_error", "Tool input validation failed"), + ], + ) + def test_apply_returns_default_message(self, error_type, expected_message): + effect = ToolCallFailure(error_type=error_type) + assert effect.apply() == expected_message + + def test_apply_returns_custom_message_when_provided(self): + effect = ToolCallFailure(error_type="timeout", error_message="Custom timeout msg") + assert effect.apply() == "Custom timeout msg" + + def test_apply_rate_defaults_to_one(self): + effect = ToolCallFailure() + assert effect.apply_rate == 1.0 + + +class TestTruncateFields: + """Tests for the TruncateFields post-hook effect.""" + + def test_truncates_long_strings(self): + effect = TruncateFields(max_length=5) + response = {"name": "hello world", "short": "hi"} + result = effect.apply(response) + assert result["name"] == "hello" + assert result["short"] == "hi" + + def test_preserves_non_string_values(self): + effect = TruncateFields(max_length=3) + response = {"count": 42, "flag": True, "items": [1, 2, 3]} + result = effect.apply(response) + assert result["count"] == 42 + assert result["flag"] is True + assert result["items"] == [1, 2, 3] + + def test_truncates_nested_dicts(self): + effect = TruncateFields(max_length=3) + response = {"nested": {"deep_value": "abcdef"}} + result = effect.apply(response) + assert result["nested"]["deep_value"] == "abc" + + def test_empty_dict_returns_empty(self): + effect = TruncateFields(max_length=5) + assert effect.apply({}) == {} + + def test_non_dict_input_returned_as_is(self): + effect = TruncateFields(max_length=5) + assert effect.apply("not a dict") == "not a dict" + assert effect.apply(None) is None + + def test_zero_max_length_truncates_all_strings(self): + effect = TruncateFields(max_length=0) + response = {"text": "hello"} + result = effect.apply(response) + assert result["text"] == "" + + +class TestRemoveFields: + """Tests for the RemoveFields post-hook effect.""" + + def test_removes_at_least_one_field(self): + random.seed(42) + effect = RemoveFields(remove_ratio=0.1) + response = {"a": 1, "b": 2, "c": 3, "d": 4} + result = effect.apply(response) + assert len(result) < len(response) + + def test_removes_half_fields(self): + random.seed(42) + effect = RemoveFields(remove_ratio=0.5) + response = {"a": 1, "b": 2, "c": 3, "d": 4} + result = effect.apply(response) + assert len(result) == 2 + + def test_removes_all_fields_at_ratio_one(self): + random.seed(42) + effect = RemoveFields(remove_ratio=1.0) + response = {"a": 1, "b": 2, "c": 3} + result = effect.apply(response) + assert len(result) == 0 + + def test_empty_dict_returns_empty(self): + effect = RemoveFields(remove_ratio=0.5) + assert effect.apply({}) == {} + + def test_non_dict_input_returned_as_is(self): + effect = RemoveFields(remove_ratio=0.5) + assert effect.apply("not a dict") == "not a dict" + assert effect.apply(None) is None + + def test_single_field_always_removed(self): + random.seed(42) + effect = RemoveFields(remove_ratio=0.5) + response = {"only_key": "value"} + result = effect.apply(response) + assert len(result) == 0 + + +class TestCorruptValues: + """Tests for the CorruptValues post-hook effect.""" + + def test_corrupts_at_least_one_field(self): + random.seed(42) + effect = CorruptValues(corrupt_ratio=0.1) + response = {"a": "original_a", "b": "original_b", "c": "original_c", "d": "original_d"} + result = effect.apply(response) + corrupted_count = sum(1 for k in response if result[k] != response[k]) + assert corrupted_count >= 1 + + def test_corrupted_values_come_from_corruption_pool(self): + random.seed(42) + effect = CorruptValues(corrupt_ratio=1.0) + response = {"a": "original", "b": "data"} + result = effect.apply(response) + corruption_pool = [None, 99999, "", True, [], "CORRUPTED_DATA"] + for key in response: + assert result[key] in corruption_pool + + def test_corrupts_nested_dicts_recursively(self): + random.seed(42) + effect = CorruptValues(corrupt_ratio=1.0) + response = {"top": "value", "nested": {"inner": "deep_value"}} + result = effect.apply(response) + # The nested dict should also be processed + assert "nested" in result or "top" in result + + def test_empty_dict_returns_empty(self): + effect = CorruptValues(corrupt_ratio=0.5) + assert effect.apply({}) == {} + + def test_non_dict_input_returned_as_is(self): + effect = CorruptValues(corrupt_ratio=0.5) + assert effect.apply("not a dict") == "not a dict" + assert effect.apply(None) is None + + def test_corrupted_value_differs_from_original(self): + random.seed(42) + effect = CorruptValues(corrupt_ratio=1.0) + response = {"key": "unique_original_value"} + result = effect.apply(response) + assert result["key"] != "unique_original_value" diff --git a/tests/strands_evals/chaos/test_experiment.py b/tests/strands_evals/chaos/test_experiment.py new file mode 100644 index 00000000..9b43c4a2 --- /dev/null +++ b/tests/strands_evals/chaos/test_experiment.py @@ -0,0 +1,161 @@ +"""Unit tests for ChaosExperiment.""" + +import pytest + +from strands_evals import Case +from strands_evals.chaos import ChaosExperiment, ChaosScenario +from strands_evals.chaos._context import _current_scenario +from strands_evals.chaos.effects import CorruptValues, ToolCallFailure +from strands_evals.evaluators.evaluator import Evaluator +from strands_evals.types import EvaluationData, EvaluationOutput + + +class MockChaosEvaluator(Evaluator): + """Simple evaluator that always passes.""" + + def evaluate(self, evaluation_case: EvaluationData) -> list[EvaluationOutput]: + return [EvaluationOutput(score=1.0, test_pass=True, reason="Mock pass")] + + +@pytest.fixture +def cases(): + return [ + Case(name="case_a", input="hello"), + Case(name="case_b", input="world"), + ] + + +@pytest.fixture +def scenarios(): + return [ + ChaosScenario( + name="search_timeout", + effects={"search_tool": [ToolCallFailure(error_type="timeout")]}, + ), + ChaosScenario( + name="db_corrupt", + effects={"db_tool": [CorruptValues(corrupt_ratio=0.8)]}, + ), + ] + + +@pytest.fixture +def evaluator(): + return MockChaosEvaluator() + + +class TestChaosExperiment: + """Tests for ChaosExperiment initialization and execution.""" + + def test_expanded_cases_count_with_baseline(self, cases, scenarios, evaluator): + experiment = ChaosExperiment(cases=cases, scenarios=scenarios, evaluators=[evaluator], include_baseline=True) + # 2 cases × (2 scenarios + 1 baseline) = 6 + assert len(experiment._expanded_cases) == 6 + + def test_expanded_cases_count_without_baseline(self, cases, scenarios, evaluator): + experiment = ChaosExperiment(cases=cases, scenarios=scenarios, evaluators=[evaluator], include_baseline=False) + # 2 cases × 2 scenarios = 4 + assert len(experiment._expanded_cases) == 4 + + def test_expanded_case_names_include_scenario(self, cases, scenarios, evaluator): + experiment = ChaosExperiment(cases=cases, scenarios=scenarios, evaluators=[evaluator], include_baseline=True) + names = [c.name for c in experiment._expanded_cases] + assert "case_a|baseline" in names + assert "case_a|search_timeout" in names + assert "case_a|db_corrupt" in names + assert "case_b|baseline" in names + assert "case_b|search_timeout" in names + assert "case_b|db_corrupt" in names + + def test_each_expanded_case_has_unique_session_id(self, cases, scenarios, evaluator): + experiment = ChaosExperiment(cases=cases, scenarios=scenarios, evaluators=[evaluator]) + session_ids = [c.session_id for c in experiment._expanded_cases] + assert len(session_ids) == len(set(session_ids)) + + def test_get_scenario_for_session(self, cases, scenarios, evaluator): + experiment = ChaosExperiment(cases=cases, scenarios=scenarios, evaluators=[evaluator], include_baseline=True) + # Pick an expanded case and verify its scenario maps correctly + for expanded_case in experiment._expanded_cases: + scenario = experiment.get_scenario_for_session(expanded_case.session_id) + assert scenario is not None + assert scenario.name in expanded_case.name + + def test_get_scenario_for_unknown_session(self, cases, scenarios, evaluator): + experiment = ChaosExperiment(cases=cases, scenarios=scenarios, evaluators=[evaluator]) + assert experiment.get_scenario_for_session("nonexistent-id") is None + + def test_get_original_case_name(self, cases, scenarios, evaluator): + experiment = ChaosExperiment(cases=cases, scenarios=scenarios, evaluators=[evaluator], include_baseline=True) + for expanded_case in experiment._expanded_cases: + original_name = experiment.get_original_case_name(expanded_case.session_id) + assert original_name in ("case_a", "case_b") + + def test_get_original_case_name_unknown_session(self, cases, scenarios, evaluator): + experiment = ChaosExperiment(cases=cases, scenarios=scenarios, evaluators=[evaluator]) + assert experiment.get_original_case_name("nonexistent-id") is None + + def test_context_var_set_and_reset(self, cases, scenarios, evaluator): + """Verify the ContextVar is set to the correct scenario during task execution and reset after.""" + observed_scenarios = [] + + def capturing_task(case: Case): + scenario = _current_scenario.get() + observed_scenarios.append((case.name, scenario.name if scenario else None)) + return "output" + + experiment = ChaosExperiment(cases=cases, scenarios=scenarios, evaluators=[evaluator], include_baseline=True) + experiment.run_evaluations(task=capturing_task) + + # Should have 6 observations (2 cases × 3 scenarios) + assert len(observed_scenarios) == 6 + + # Verify baseline scenarios observed + baseline_obs = [(name, sn) for name, sn in observed_scenarios if sn == "baseline"] + assert len(baseline_obs) == 2 + + # Verify chaos scenarios observed + timeout_obs = [(name, sn) for name, sn in observed_scenarios if sn == "search_timeout"] + assert len(timeout_obs) == 2 + + # After all runs, the ContextVar should be back to None + assert _current_scenario.get() is None + + def test_context_var_reset_on_task_exception(self, evaluator): + """Verify the ContextVar is reset even if the task raises.""" + cases = [Case(name="failing", input="x")] + scenarios_list = [ChaosScenario(name="chaos", effects={"t": [ToolCallFailure()]})] + + call_count = [0] + + def failing_task(case: Case): + call_count[0] += 1 + if call_count[0] == 1: + raise RuntimeError("Task failed") + return "output" + + experiment = ChaosExperiment( + cases=cases, scenarios=scenarios_list, evaluators=[evaluator], include_baseline=True + ) + + # The base Experiment should handle the exception internally + # ContextVar should still be reset + try: + experiment.run_evaluations(task=failing_task) + except Exception: + pass + + assert _current_scenario.get() is None + + def test_returns_evaluation_reports(self, cases, scenarios, evaluator): + """Verify run_evaluations returns reports.""" + + def task(case: Case): + return "output" + + experiment = ChaosExperiment(cases=cases, scenarios=scenarios, evaluators=[evaluator], include_baseline=True) + reports = experiment.run_evaluations(task=task) + + assert len(reports) >= 1 + report = reports[0] + # 2 cases × 3 scenarios = 6 scores + assert len(report.scores) == 6 diff --git a/tests/strands_evals/chaos/test_plugin.py b/tests/strands_evals/chaos/test_plugin.py new file mode 100644 index 00000000..8bb1090d --- /dev/null +++ b/tests/strands_evals/chaos/test_plugin.py @@ -0,0 +1,194 @@ +"""Unit tests for ChaosPlugin.""" + +import json +from unittest.mock import MagicMock + +import pytest + +from strands_evals.chaos import ChaosPlugin, ChaosScenario +from strands_evals.chaos._context import _current_scenario +from strands_evals.chaos.effects import ( + ToolCallFailure, + TruncateFields, +) + + +@pytest.fixture +def chaos_plugin(): + return ChaosPlugin() + + +@pytest.fixture +def before_event(): + """Create a mock BeforeToolCallEvent.""" + event = MagicMock() + event.tool_use = {"name": "search_tool"} + event.cancel_tool = None + return event + + +@pytest.fixture +def after_event(): + """Create a mock AfterToolCallEvent with list content.""" + event = MagicMock() + event.tool_use = {"name": "search_tool"} + event.result = { + "content": [{"text": json.dumps({"title": "Long Title Here", "count": 42})}], + "status": "success", + "toolUseId": "tool-123", + } + return event + + +class TestChaosPluginBeforeToolCall: + """Tests for the before_tool_call hook.""" + + def test_no_scenario_active_does_nothing(self, chaos_plugin, before_event): + token = _current_scenario.set(None) + try: + chaos_plugin.before_tool_call(before_event) + assert before_event.cancel_tool is None + finally: + _current_scenario.reset(token) + + def test_scenario_without_matching_tool_does_nothing(self, chaos_plugin, before_event): + scenario = ChaosScenario( + name="other_tool_fails", + effects={"other_tool": [ToolCallFailure(error_type="timeout")]}, + ) + token = _current_scenario.set(scenario) + try: + chaos_plugin.before_tool_call(before_event) + assert before_event.cancel_tool is None + finally: + _current_scenario.reset(token) + + def test_pre_hook_effect_cancels_tool(self, chaos_plugin, before_event): + scenario = ChaosScenario( + name="search_timeout", + effects={"search_tool": [ToolCallFailure(error_type="timeout")]}, + ) + token = _current_scenario.set(scenario) + try: + chaos_plugin.before_tool_call(before_event) + assert before_event.cancel_tool == "Tool call timed out" + finally: + _current_scenario.reset(token) + + def test_post_hook_effect_does_not_cancel_tool(self, chaos_plugin, before_event): + scenario = ChaosScenario( + name="search_truncated", + effects={"search_tool": [TruncateFields(max_length=5)]}, + ) + token = _current_scenario.set(scenario) + try: + chaos_plugin.before_tool_call(before_event) + assert before_event.cancel_tool is None + finally: + _current_scenario.reset(token) + + def test_first_pre_hook_effect_wins(self, chaos_plugin, before_event): + scenario = ChaosScenario( + name="multi_pre", + effects={ + "search_tool": [ + ToolCallFailure(error_type="timeout"), + ToolCallFailure(error_type="network_error"), + ] + }, + ) + token = _current_scenario.set(scenario) + try: + chaos_plugin.before_tool_call(before_event) + assert before_event.cancel_tool == "Tool call timed out" + finally: + _current_scenario.reset(token) + + +class TestChaosPluginAfterToolCall: + """Tests for the after_tool_call hook.""" + + def test_no_scenario_active_does_nothing(self, chaos_plugin, after_event): + token = _current_scenario.set(None) + try: + original_content = after_event.result["content"][0]["text"] + chaos_plugin.after_tool_call(after_event) + assert after_event.result["content"][0]["text"] == original_content + finally: + _current_scenario.reset(token) + + def test_scenario_without_matching_tool_does_nothing(self, chaos_plugin, after_event): + scenario = ChaosScenario( + name="other_tool", + effects={"other_tool": [TruncateFields(max_length=3)]}, + ) + token = _current_scenario.set(scenario) + try: + original_content = after_event.result["content"][0]["text"] + chaos_plugin.after_tool_call(after_event) + assert after_event.result["content"][0]["text"] == original_content + finally: + _current_scenario.reset(token) + + def test_post_hook_corrupts_json_text_blocks(self, chaos_plugin, after_event): + scenario = ChaosScenario( + name="truncate", + effects={"search_tool": [TruncateFields(max_length=3)]}, + ) + token = _current_scenario.set(scenario) + try: + chaos_plugin.after_tool_call(after_event) + corrupted = json.loads(after_event.result["content"][0]["text"]) + assert corrupted["title"] == "Lon" + assert corrupted["count"] == 42 # non-string preserved + finally: + _current_scenario.reset(token) + + def test_pre_hook_effect_ignored_in_after_hook(self, chaos_plugin, after_event): + scenario = ChaosScenario( + name="pre_only", + effects={"search_tool": [ToolCallFailure(error_type="timeout")]}, + ) + token = _current_scenario.set(scenario) + try: + original_content = after_event.result["content"][0]["text"] + chaos_plugin.after_tool_call(after_event) + assert after_event.result["content"][0]["text"] == original_content + finally: + _current_scenario.reset(token) + + def test_none_result_is_skipped(self, chaos_plugin): + event = MagicMock() + event.tool_use = {"name": "search_tool"} + event.result = None + + scenario = ChaosScenario( + name="truncate", + effects={"search_tool": [TruncateFields(max_length=3)]}, + ) + token = _current_scenario.set(scenario) + try: + chaos_plugin.after_tool_call(event) # Should not raise + finally: + _current_scenario.reset(token) + + def test_plain_text_truncation(self, chaos_plugin): + """Test that plain (non-JSON) text blocks get truncated if effect has max_length.""" + event = MagicMock() + event.tool_use = {"name": "search_tool"} + event.result = { + "content": [{"text": "This is plain text, not JSON"}], + "status": "success", + "toolUseId": "tool-456", + } + + scenario = ChaosScenario( + name="truncate", + effects={"search_tool": [TruncateFields(max_length=4)]}, + ) + token = _current_scenario.set(scenario) + try: + chaos_plugin.after_tool_call(event) + assert event.result["content"][0]["text"] == "This" + finally: + _current_scenario.reset(token) diff --git a/tests/strands_evals/chaos/test_scenario.py b/tests/strands_evals/chaos/test_scenario.py new file mode 100644 index 00000000..d9eceb07 --- /dev/null +++ b/tests/strands_evals/chaos/test_scenario.py @@ -0,0 +1,45 @@ +"""Unit tests for ChaosScenario.""" + +from strands_evals.chaos import ChaosScenario +from strands_evals.chaos.effects import CorruptValues, ToolCallFailure, TruncateFields + + +class TestChaosScenario: + """Tests for the ChaosScenario data model.""" + + def test_baseline_scenario_has_no_effects(self): + scenario = ChaosScenario(name="baseline") + assert scenario.effects == {} + + def test_scenario_with_multiple_tools(self): + scenario = ChaosScenario( + name="compound_failure", + effects={ + "search_tool": [ToolCallFailure(error_type="timeout")], + "db_tool": [CorruptValues(corrupt_ratio=0.8)], + }, + ) + assert len(scenario.effects) == 2 + assert isinstance(scenario.effects["search_tool"][0], ToolCallFailure) + assert isinstance(scenario.effects["db_tool"][0], CorruptValues) + + def test_scenario_with_multiple_effects_per_tool(self): + scenario = ChaosScenario( + name="multi_effect", + effects={ + "tool_a": [ + TruncateFields(max_length=5), + CorruptValues(corrupt_ratio=0.3), + ], + }, + ) + assert len(scenario.effects["tool_a"]) == 2 + + def test_repr_shows_effects(self): + scenario = ChaosScenario( + name="test", + effects={"tool": [ToolCallFailure()]}, + ) + repr_str = repr(scenario) + assert "test" in repr_str + assert "ToolCallFailure" in repr_str From bb023d64005e3604130ffb618db88be25fc7c33c Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Wed, 13 May 2026 17:40:43 +0000 Subject: [PATCH 6/9] address review bot's comments --- src/strands_evals/chaos/effects.py | 9 +- src/strands_evals/chaos/experiment.py | 111 +++++++++++++---------- src/strands_evals/chaos/plugin.py | 21 +++-- src/strands_evals/chaos/scenario.py | 6 +- tests/strands_evals/chaos/test_plugin.py | 61 +++++++++++++ 5 files changed, 144 insertions(+), 64 deletions(-) diff --git a/src/strands_evals/chaos/effects.py b/src/strands_evals/chaos/effects.py index b93d1362..cf438bb3 100644 --- a/src/strands_evals/chaos/effects.py +++ b/src/strands_evals/chaos/effects.py @@ -20,7 +20,7 @@ class ChaosEffect(BaseModel): """Base for all chaos effects. Attributes: - apply_rate: Probability that this effect fires. + apply_rate: Probability that this effect fires, defaults to 1 (always fire). hook: Whether this effect fires pre-call ("pre") or post-call ("post"). """ @@ -40,10 +40,11 @@ def apply(self, context: Any = None) -> Any: class ToolEffect(ChaosEffect): - """Effect valid at the tool invocation boundary. + """Effect that operates at the tool invocation boundary. - - "pre": effect fires before tool execution (cancels the call with an error) - - "post": effect fires after tool execution (corrupts the response) + This intermediate class enables type-based dispatch so the plugin can + distinguish tool-level effects from other planned effect categories + (e.g., upcoming ``ModelEffect`` for LLM input and output chaos injection). """ diff --git a/src/strands_evals/chaos/experiment.py b/src/strands_evals/chaos/experiment.py index 58ba2c7f..b3f2b12e 100644 --- a/src/strands_evals/chaos/experiment.py +++ b/src/strands_evals/chaos/experiment.py @@ -152,6 +152,44 @@ def get_original_case_name(self, session_id: str) -> Optional[str]: """ return self._original_case_name_by_session.get(session_id) + def _wrap_task(self, task: Callable[[Case], Any]) -> Callable[[Case], Any]: + """Wrap a task function to activate the correct scenario via ContextVar. + + Handles both sync and async tasks — returns a sync wrapper for sync tasks + and an async wrapper for async tasks, so the base Experiment dispatches + correctly. + + Args: + task: The original task function (sync or async). + + Returns: + A wrapped callable that sets/resets the ContextVar around each invocation. + """ + import asyncio + + if asyncio.iscoroutinefunction(task): + + async def chaos_aware_task_async(case: Case) -> Any: + scenario = self._scenario_by_session.get(case.session_id) + token = _current_scenario.set(scenario) + try: + return await task(case) + finally: + _current_scenario.reset(token) + + return chaos_aware_task_async + else: + + def chaos_aware_task(case: Case) -> Any: + scenario = self._scenario_by_session.get(case.session_id) + token = _current_scenario.set(scenario) + try: + return task(case) + finally: + _current_scenario.reset(token) + + return chaos_aware_task + def run_evaluations( self, task: Callable[[Case], Any], @@ -159,37 +197,29 @@ def run_evaluations( ) -> list[EvaluationReport]: """Run evaluations across all (case × scenario) combinations. - Wraps the user's task function to set the ContextVar before each - case execution, so the ChaosPlugin sees the correct scenario. + Delegates to run_evaluations_async with max_workers=1, mirroring the + base Experiment pattern. Args: task: The task function to evaluate. Takes a Case and returns output. The task body should contain zero chaos concepts — just construct the agent with plugins=[chaos] and call it. - **kwargs: Additional kwargs passed to the base Experiment.run_evaluations. + **kwargs: Additional kwargs passed to the base Experiment.run_evaluations_async. Returns: List of EvaluationReport objects covering all scenarios. - """ - def chaos_aware_task(case: Case) -> Any: - """Wrapper that activates the correct scenario via ContextVar.""" - scenario = self._scenario_by_session.get(case.session_id) - token = _current_scenario.set(scenario) - try: - return task(case) - finally: - _current_scenario.reset(token) - - reports = self._experiment.run_evaluations(chaos_aware_task, **kwargs) + Raises: + ValueError: If an async task is passed (use run_evaluations_async instead). + """ + import asyncio - num_scenarios = len(self._scenarios) + (1 if self._include_baseline else 0) - logger.info( - f"Chaos experiment complete: {len(reports)} reports " - f"({len(self._original_cases)} cases × {num_scenarios} scenarios)" - ) + if asyncio.iscoroutinefunction(task): + raise ValueError( + "Async task is not supported in run_evaluations. Please use run_evaluations_async instead." + ) - return reports + return asyncio.run(self.run_evaluations_async(task, max_workers=1, **kwargs)) async def run_evaluations_async( self, @@ -199,8 +229,8 @@ async def run_evaluations_async( ) -> list[EvaluationReport]: """Run evaluations asynchronously across all (case × scenario) combinations. - Same as run_evaluations but uses the async worker pool for parallelism. - ContextVar ensures each case sees its own scenario even under concurrency. + Wraps the user's task to set the ContextVar before each case execution. + The base Experiment handles sync-to-async dispatch internally. Args: task: The task function (sync or async). @@ -210,34 +240,15 @@ async def run_evaluations_async( Returns: List of EvaluationReport objects covering all scenarios. """ - import asyncio + wrapped = self._wrap_task(task) + reports = await self._experiment.run_evaluations_async(wrapped, max_workers=max_workers, **kwargs) - def chaos_aware_task(case: Case) -> Any: - """Wrapper that activates the correct scenario via ContextVar.""" - scenario = self._scenario_by_session.get(case.session_id) - token = _current_scenario.set(scenario) - try: - return task(case) - finally: - _current_scenario.reset(token) - - async def chaos_aware_task_async(case: Case) -> Any: - """Async wrapper that activates the correct scenario via ContextVar.""" - scenario = self._scenario_by_session.get(case.session_id) - token = _current_scenario.set(scenario) - try: - if asyncio.iscoroutinefunction(task): - return await task(case) - else: - return task(case) - finally: - _current_scenario.reset(token) - - if asyncio.iscoroutinefunction(task): - reports = await self._experiment.run_evaluations_async( - chaos_aware_task_async, max_workers=max_workers, **kwargs - ) - else: - reports = await self._experiment.run_evaluations_async(chaos_aware_task, max_workers=max_workers, **kwargs) + num_scenarios = len(self._scenarios) + (1 if self._include_baseline else 0) + logger.info( + "cases=<%d>, scenarios=<%d>, reports=<%d> | chaos experiment complete", + len(self._original_cases), + num_scenarios, + len(reports), + ) return reports diff --git a/src/strands_evals/chaos/plugin.py b/src/strands_evals/chaos/plugin.py index 5a812145..7f76f584 100644 --- a/src/strands_evals/chaos/plugin.py +++ b/src/strands_evals/chaos/plugin.py @@ -9,12 +9,13 @@ import json import logging +import random from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent from strands.plugins import Plugin, hook from ._context import _current_scenario -from .effects import ChaosEffect +from .effects import ChaosEffect, TruncateFields logger = logging.getLogger(__name__) @@ -54,8 +55,9 @@ def __init__(self) -> None: def before_tool_call(self, event: BeforeToolCallEvent) -> None: """Intercept tool calls to inject pre-hook (error) effects. - For error effects (Timeout, NetworkError, etc.), cancels the tool call - with the effect's error_message before the tool executes. + For ToolCallFailure effects (with error_type='timeout', 'network_error', + etc.), cancels the tool call with the effect's error_message before the + tool executes. """ scenario = _current_scenario.get() if scenario is None: @@ -69,8 +71,10 @@ def before_tool_call(self, event: BeforeToolCallEvent) -> None: # First pre-hook effect wins (tool is cancelled once) for effect in effects: if effect.hook == "pre": + if random.random() > effect.apply_rate: + continue event.cancel_tool = effect.apply() - logger.info(f"[Chaos] Injected {type(effect).__name__} on tool '{tool_name}'") + logger.info("effect=<%s>, tool=<%s> | injected chaos pre-hook", type(effect).__name__, tool_name) return @hook # type: ignore[call-overload] @@ -94,6 +98,9 @@ def after_tool_call(self, event: AfterToolCallEvent) -> None: if effect.hook != "post": continue + if random.random() > effect.apply_rate: + continue + if event.result is None: continue @@ -103,7 +110,7 @@ def after_tool_call(self, event: AfterToolCallEvent) -> None: if isinstance(content, list): result["content"] = self._apply_to_blocks(effect, content) # type: ignore[assignment] - logger.info(f"[Chaos] Applied {type(effect).__name__} on tool '{tool_name}'") + logger.info("effect=<%s>, tool=<%s> | applied chaos post-hook", type(effect).__name__, tool_name) def _apply_to_blocks(self, effect: ChaosEffect, blocks: list) -> list: """Apply effect to text blocks in a content list.""" @@ -118,8 +125,8 @@ def _apply_to_blocks(self, effect: ChaosEffect, blocks: list) -> list: corrupted = effect.apply(parsed) block = {**block, "text": json.dumps(corrupted)} except (json.JSONDecodeError, ValueError): - # Plain text — apply truncation via effect if applicable - if hasattr(effect, "max_length"): + # Plain text — apply truncation if effect is TruncateFields + if isinstance(effect, TruncateFields): block = {**block, "text": text_data[: effect.max_length]} corrupted_blocks.append(block) return corrupted_blocks diff --git a/src/strands_evals/chaos/scenario.py b/src/strands_evals/chaos/scenario.py index ee127479..e387d890 100644 --- a/src/strands_evals/chaos/scenario.py +++ b/src/strands_evals/chaos/scenario.py @@ -22,7 +22,7 @@ class ChaosScenario(BaseModel): Example:: from strands_evals.chaos import ChaosScenario - from strands_evals.chaos.effects import Timeout, NetworkError, CorruptValues + from strands_evals.chaos.effects import ToolCallFailure, CorruptValues # Baseline — no chaos ChaosScenario(name="baseline") @@ -30,7 +30,7 @@ class ChaosScenario(BaseModel): # Single-fault: one tool fails ChaosScenario( name="search_timeout", - effects={"search_tool": [Timeout()]}, + effects={"search_tool": [ToolCallFailure(error_type="timeout")]}, ) # Compound: multiple tools/models fail simultaneously @@ -41,7 +41,7 @@ class ChaosScenario(BaseModel): "recovery path silently returns bad data." ), effects={ - "search_tool": [Timeout()], + "search_tool": [ToolCallFailure(error_type="timeout")], "book_tool": [CorruptValues(corrupt_ratio=0.8)], }, ) diff --git a/tests/strands_evals/chaos/test_plugin.py b/tests/strands_evals/chaos/test_plugin.py index 8bb1090d..217253d8 100644 --- a/tests/strands_evals/chaos/test_plugin.py +++ b/tests/strands_evals/chaos/test_plugin.py @@ -192,3 +192,64 @@ def test_plain_text_truncation(self, chaos_plugin): assert event.result["content"][0]["text"] == "This" finally: _current_scenario.reset(token) + + +class TestApplyRate: + """Tests for the apply_rate probability check in ChaosPlugin.""" + + def test_apply_rate_zero_skips_pre_hook_effect(self, chaos_plugin, before_event): + """Effect with apply_rate=0.0 should never fire.""" + scenario = ChaosScenario( + name="never_fires", + effects={"search_tool": [ToolCallFailure(error_type="timeout", apply_rate=0.0)]}, + ) + token = _current_scenario.set(scenario) + try: + # Run multiple times to confirm it never fires + for _ in range(20): + before_event.cancel_tool = None + chaos_plugin.before_tool_call(before_event) + assert before_event.cancel_tool is None + finally: + _current_scenario.reset(token) + + def test_apply_rate_one_always_fires_pre_hook(self, chaos_plugin, before_event): + """Effect with apply_rate=1.0 should always fire.""" + scenario = ChaosScenario( + name="always_fires", + effects={"search_tool": [ToolCallFailure(error_type="timeout", apply_rate=1.0)]}, + ) + token = _current_scenario.set(scenario) + try: + chaos_plugin.before_tool_call(before_event) + assert before_event.cancel_tool == "Tool call timed out" + finally: + _current_scenario.reset(token) + + def test_apply_rate_zero_skips_post_hook_effect(self, chaos_plugin, after_event): + """Post-hook effect with apply_rate=0.0 should never fire.""" + scenario = ChaosScenario( + name="never_truncates", + effects={"search_tool": [TruncateFields(max_length=3, apply_rate=0.0)]}, + ) + token = _current_scenario.set(scenario) + try: + original_content = after_event.result["content"][0]["text"] + chaos_plugin.after_tool_call(after_event) + assert after_event.result["content"][0]["text"] == original_content + finally: + _current_scenario.reset(token) + + def test_apply_rate_one_always_fires_post_hook(self, chaos_plugin, after_event): + """Post-hook effect with apply_rate=1.0 should always fire.""" + scenario = ChaosScenario( + name="always_truncates", + effects={"search_tool": [TruncateFields(max_length=3, apply_rate=1.0)]}, + ) + token = _current_scenario.set(scenario) + try: + chaos_plugin.after_tool_call(after_event) + corrupted = json.loads(after_event.result["content"][0]["text"]) + assert corrupted["title"] == "Lon" + finally: + _current_scenario.reset(token) From 61dd7d7445b5b295d19dfe5a99f3d427c29e9e2b Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Fri, 15 May 2026 22:01:03 +0000 Subject: [PATCH 7/9] replace chaos scenario with chaos case --- src/strands_evals/chaos/__init__.py | 4 +- src/strands_evals/chaos/_context.py | 8 +- src/strands_evals/chaos/case.py | 123 +++++++++++++ src/strands_evals/chaos/effects.py | 15 +- src/strands_evals/chaos/experiment.py | 161 +++++------------ src/strands_evals/chaos/plugin.py | 24 +-- src/strands_evals/chaos/scenario.py | 65 ------- tests/strands_evals/chaos/test_case.py | 176 +++++++++++++++++++ tests/strands_evals/chaos/test_context.py | 43 +++-- tests/strands_evals/chaos/test_experiment.py | 135 ++++++-------- tests/strands_evals/chaos/test_plugin.py | 111 ++++++------ tests/strands_evals/chaos/test_scenario.py | 45 ----- 12 files changed, 513 insertions(+), 397 deletions(-) create mode 100644 src/strands_evals/chaos/case.py delete mode 100644 src/strands_evals/chaos/scenario.py create mode 100644 tests/strands_evals/chaos/test_case.py delete mode 100644 tests/strands_evals/chaos/test_scenario.py diff --git a/src/strands_evals/chaos/__init__.py b/src/strands_evals/chaos/__init__.py index 1d8a5a93..e04544ac 100644 --- a/src/strands_evals/chaos/__init__.py +++ b/src/strands_evals/chaos/__init__.py @@ -4,6 +4,7 @@ under tool failures and response corruption scenarios. """ +from .case import ChaosCase from .effects import ( ChaosEffect, CorruptValues, @@ -14,13 +15,12 @@ ) from .experiment import ChaosExperiment from .plugin import ChaosPlugin -from .scenario import ChaosScenario __all__ = [ # Core classes + "ChaosCase", "ChaosExperiment", "ChaosPlugin", - "ChaosScenario", # Effect hierarchy "ChaosEffect", "ToolEffect", diff --git a/src/strands_evals/chaos/_context.py b/src/strands_evals/chaos/_context.py index 8c8c0624..74118e56 100644 --- a/src/strands_evals/chaos/_context.py +++ b/src/strands_evals/chaos/_context.py @@ -1,4 +1,4 @@ -"""Internal context variable for tracking the active chaos scenario. +"""Internal context variable for tracking the active chaos case. The ChaosPlugin reads from this ContextVar at hook time. The ChaosExperiment sets and resets it around each case's task invocation. @@ -13,9 +13,9 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from .scenario import ChaosScenario + from .case import ChaosCase -_current_scenario: ContextVar["ChaosScenario | None"] = ContextVar( - "chaos_current_scenario", +_current_chaos_case: ContextVar["ChaosCase | None"] = ContextVar( + "chaos_current_case", default=None, ) diff --git a/src/strands_evals/chaos/case.py b/src/strands_evals/chaos/case.py new file mode 100644 index 00000000..4d66fe44 --- /dev/null +++ b/src/strands_evals/chaos/case.py @@ -0,0 +1,123 @@ +"""Chaos case definition. + +A ChaosCase extends Case with chaos-specific fields, providing a stable +extension point for failure injection configuration without modifying the +base Case class. +""" + +import uuid + +from pydantic import Field +from typing_extensions import Generic + +from ..case import Case +from ..types.evaluation import InputT, OutputT +from .effects import ChaosEffect + + +class ChaosCase(Case, Generic[InputT, OutputT]): + """A test case with associated chaos effects. + + Extends Case to carry the effects mapping that the ChaosPlugin reads + at hook time. A ChaosCase with empty effects is a baseline run. + + The ``expand`` class method provides the Cartesian product of cases × + effect maps, producing a flat list of ChaosCase objects ready for + ChaosExperiment. + + Attributes: + effects: Mapping of tool_name -> list of effects to inject for this case. + Tools not listed behave normally. Empty dict means baseline (no chaos). + + Example:: + + from strands_evals import Case + from strands_evals.chaos import ChaosCase + from strands_evals.chaos.effects import ToolCallFailure, TruncateFields + + # Direct construction + chaos_case = ChaosCase( + name="search_timeout", + input="Find flights to Tokyo", + effects={"search_tool": [ToolCallFailure(error_type="timeout")]}, + ) + + # Expansion from base cases × named effect maps + cases = [ + Case(name="flight_search", input="Find flights to Tokyo"), + Case(name="hotel_search", input="Find hotels in Tokyo"), + ] + effect_maps = { + "search_timeout": {"search_tool": [ToolCallFailure(error_type="timeout")]}, + "search_truncated": {"search_tool": [TruncateFields(max_length=5)]}, + } + chaos_cases = ChaosCase.expand(cases, effect_maps, include_no_effect_baseline=True) + # Produces 6 ChaosCase objects: 2 cases × (2 effect maps + 1 baseline) + """ + + effects: dict[str, list[ChaosEffect]] = Field( + default_factory=dict, + description="Mapping of tool_name -> list of effects to inject for this case. " + "Empty dict means baseline (no chaos).", + ) + + @classmethod + def expand( + cls, + cases: list[Case], + effect_maps: dict[str, dict[str, list[ChaosEffect]]], + include_no_effect_baseline: bool = False, + ) -> list["ChaosCase"]: + """Generate the Cartesian product of cases × named effect maps. + + Produces a flat list of ChaosCase objects, one for each (case, effect_map) + combination. Each ChaosCase gets a fresh session_id and a composite name + built from the case name and the effect map key. + + Args: + cases: Base test cases to expand. + effect_maps: Named effect configurations. Keys are short human-readable + names (used in the composite case name); values are mappings of + tool_name -> list of ChaosEffect instances. + include_no_effect_baseline: If True, includes a baseline (no chaos) + variant for each case. Defaults to False. + + Returns: + Flat list of ChaosCase objects with composite names like + "flight_search|baseline" or "flight_search|search_timeout". + """ + all_entries: list[tuple[str, dict[str, list[ChaosEffect]]]] = [] + + if include_no_effect_baseline: + all_entries.append(("baseline", {})) + + for name, effects in effect_maps.items(): + all_entries.append((name, effects)) + + expanded: list[ChaosCase] = [] + for case in cases: + for condition_name, effects in all_entries: + session_id = str(uuid.uuid4()) + expanded_name = f"{case.name}|{condition_name}" if case.name else condition_name + expanded.append( + cls( + name=expanded_name, + session_id=session_id, + input=case.input, + expected_output=case.expected_output, + expected_assertion=case.expected_assertion, + expected_trajectory=case.expected_trajectory, + expected_interactions=case.expected_interactions, + expected_environment_state=case.expected_environment_state, + metadata=case.metadata, + effects=effects, + ) + ) + + return expanded + + def __repr__(self) -> str: + effects_str = ", ".join( + f"{target}: [{', '.join(type(e).__name__ for e in effs)}]" for target, effs in self.effects.items() + ) + return f"ChaosCase(name='{self.name}', effects={{{effects_str}}})" diff --git a/src/strands_evals/chaos/effects.py b/src/strands_evals/chaos/effects.py index cf438bb3..f65af487 100644 --- a/src/strands_evals/chaos/effects.py +++ b/src/strands_evals/chaos/effects.py @@ -67,13 +67,15 @@ class ToolCallFailure(ToolEffect): Example:: - ChaosScenario( + ChaosCase( name="search_timeout", + input="Find flights", effects={"search_tool": [ToolCallFailure(error_type="timeout")]}, ) - ChaosScenario( + ChaosCase( name="db_network_error", + input="Query database", effects={"database_tool": [ToolCallFailure( error_type="network_error", error_message="Connection refused on port 5432", @@ -106,8 +108,9 @@ class TruncateFields(ToolEffect): Example:: - ChaosScenario( + ChaosCase( name="search_truncated", + input="Find flights", effects={ "search_tool": [TruncateFields(max_length=5)], }, @@ -147,8 +150,9 @@ class RemoveFields(ToolEffect): Example:: - ChaosScenario( + ChaosCase( name="db_remove_fields", + input="Query database", effects={ "database_tool": [RemoveFields(remove_ratio=0.5)], }, @@ -193,8 +197,9 @@ class CorruptValues(ToolEffect): Example:: - ChaosScenario( + ChaosCase( name="db_corrupt", + input="Query database", effects={ "database_tool": [CorruptValues(corrupt_ratio=0.8)], }, diff --git a/src/strands_evals/chaos/experiment.py b/src/strands_evals/chaos/experiment.py index b3f2b12e..e841da33 100644 --- a/src/strands_evals/chaos/experiment.py +++ b/src/strands_evals/chaos/experiment.py @@ -1,69 +1,59 @@ """Chaos Experiment. -Composes the base Experiment to run test cases across multiple chaos scenarios, +Composes the base Experiment to run ChaosCase objects through evaluators, providing deterministic evaluation of agent resilience under tool failures. """ import logging -import uuid from collections.abc import Callable from typing import Any, Optional -from ..case import Case from ..evaluators.evaluator import Evaluator from ..experiment import Experiment from ..types.evaluation_report import EvaluationReport -from ._context import _current_scenario -from .scenario import ChaosScenario +from ._context import _current_chaos_case +from .case import ChaosCase logger = logging.getLogger(__name__) class ChaosExperiment: - """Runs cases × scenarios by composing the base Experiment. + """Runs ChaosCase objects through evaluators with chaos-aware dispatch. - For each scenario, activates it via ContextVar, runs all cases through - the evaluators, then resets. The user's task body contains zero chaos - concepts — the plugin reads the active scenario from the ContextVar. + Sets the active ChaosCase via ContextVar before each task invocation so + the ChaosPlugin can read the case's effects at hook time. The user's task + body contains zero chaos concepts — the plugin reads the active case from + the ContextVar. - Optionally includes a baseline run (no chaos) for comparison. - - Note: The experiment runs ``len(cases) × (len(scenarios) + 1)`` evaluations - when ``include_baseline=True``, or ``len(cases) × len(scenarios)`` otherwise. - Plan scenario counts accordingly — each combination triggers a full agent - invocation. + Use ``ChaosCase.expand()`` to generate the Cartesian product of base cases + × effect sets before passing them to this experiment. Example:: + from strands_evals import Case from strands_evals.chaos import ( + ChaosCase, ChaosExperiment, ChaosPlugin, - ChaosScenario, ) from strands_evals.chaos.effects import ToolCallFailure chaos = ChaosPlugin() - scenarios = [ - ChaosScenario( - name="search_timeout", - effects={"search_tool": [ToolCallFailure(error_type="timeout")]}, - ), - ChaosScenario( - name="db_down", - effects={"database_tool": [ToolCallFailure(error_type="network_error")]}, - ), + cases = [Case(input="Find flights to Tokyo", name="flight_search")] + effect_sets = [ + {"search_tool": [ToolCallFailure(error_type="timeout")]}, + {"database_tool": [ToolCallFailure(error_type="network_error")]}, ] + chaos_cases = ChaosCase.expand(cases, effect_sets) def my_task(case): agent = Agent(tools=[search_tool, database_tool], plugins=[chaos]) return {"output": str(agent(case.input))} experiment = ChaosExperiment( - cases=[Case(input="Find flights to Tokyo", name="flight_search")], - scenarios=scenarios, + cases=chaos_cases, evaluators=[my_evaluator], - include_baseline=True, ) reports = experiment.run_evaluations(task=my_task) @@ -71,89 +61,32 @@ def my_task(case): def __init__( self, - cases: list[Case], - scenarios: list[ChaosScenario], + cases: list[ChaosCase], evaluators: Optional[list[Evaluator]] = None, - include_baseline: bool = True, ): """Initialize a ChaosExperiment. Args: - cases: Test cases to evaluate. - scenarios: List of chaos scenarios. Each scenario runs all cases. - All effects in a scenario fire simultaneously in a single run. + cases: ChaosCase objects to evaluate. Use ``ChaosCase.expand()`` + to generate these from base cases and effect sets. evaluators: Evaluators to assess results. - include_baseline: If True, runs all cases with no chaos first for comparison. """ - self._original_cases = cases - self._scenarios = scenarios + self._cases = cases self._evaluators = evaluators - self._include_baseline = include_baseline - - # Build the expanded case list and internal maps - self._expanded_cases: list[Case] = [] - self._scenario_by_session: dict[str, ChaosScenario] = {} - self._original_case_name_by_session: dict[str, Optional[str]] = {} - - all_scenarios = [ChaosScenario(name="baseline")] if include_baseline else [] - all_scenarios.extend(scenarios) - - for case in cases: - for scenario in all_scenarios: - # Create expanded case with fresh session_id - session_id = str(uuid.uuid4()) - expanded_case = case.model_copy( - update={ - "name": f"{case.name}|{scenario.name}" if case.name else scenario.name, - "session_id": session_id, - } - ) - self._expanded_cases.append(expanded_case) - self._scenario_by_session[session_id] = scenario - self._original_case_name_by_session[session_id] = case.name - - # Internal Experiment with expanded cases + + # Internal Experiment with the chaos cases self._experiment = Experiment( - cases=self._expanded_cases, + cases=list(cases), evaluators=evaluators, ) @property - def scenarios(self) -> list[ChaosScenario]: - """The chaos scenarios configured for this experiment.""" - return self._scenarios - - @property - def cases(self) -> list[Case]: - """The original (unexpanded) test cases.""" - return self._original_cases - - def get_scenario_for_session(self, session_id: str) -> Optional[ChaosScenario]: - """Look up the scenario assigned to a given session_id. - - Useful for downstream aggregation and reporting. - - Args: - session_id: The session_id of an expanded case. - - Returns: - The ChaosScenario for that session, or None if not found. - """ - return self._scenario_by_session.get(session_id) - - def get_original_case_name(self, session_id: str) -> Optional[str]: - """Look up the original case name for a given session_id. - - Args: - session_id: The session_id of an expanded case. - - Returns: - The original case name, or None if not found. - """ - return self._original_case_name_by_session.get(session_id) + def cases(self) -> list[ChaosCase]: + """The ChaosCase objects configured for this experiment.""" + return self._cases - def _wrap_task(self, task: Callable[[Case], Any]) -> Callable[[Case], Any]: - """Wrap a task function to activate the correct scenario via ContextVar. + def _wrap_task(self, task: Callable[[ChaosCase], Any]) -> Callable[[ChaosCase], Any]: + """Wrap a task function to activate the correct ChaosCase via ContextVar. Handles both sync and async tasks — returns a sync wrapper for sync tasks and an async wrapper for async tasks, so the base Experiment dispatches @@ -169,45 +102,43 @@ def _wrap_task(self, task: Callable[[Case], Any]) -> Callable[[Case], Any]: if asyncio.iscoroutinefunction(task): - async def chaos_aware_task_async(case: Case) -> Any: - scenario = self._scenario_by_session.get(case.session_id) - token = _current_scenario.set(scenario) + async def chaos_aware_task_async(case: ChaosCase) -> Any: + token = _current_chaos_case.set(case) try: return await task(case) finally: - _current_scenario.reset(token) + _current_chaos_case.reset(token) return chaos_aware_task_async else: - def chaos_aware_task(case: Case) -> Any: - scenario = self._scenario_by_session.get(case.session_id) - token = _current_scenario.set(scenario) + def chaos_aware_task(case: ChaosCase) -> Any: + token = _current_chaos_case.set(case) try: return task(case) finally: - _current_scenario.reset(token) + _current_chaos_case.reset(token) return chaos_aware_task def run_evaluations( self, - task: Callable[[Case], Any], + task: Callable[[ChaosCase], Any], **kwargs, ) -> list[EvaluationReport]: - """Run evaluations across all (case × scenario) combinations. + """Run evaluations across all ChaosCase objects. Delegates to run_evaluations_async with max_workers=1, mirroring the base Experiment pattern. Args: - task: The task function to evaluate. Takes a Case and returns output. + task: The task function to evaluate. Takes a ChaosCase and returns output. The task body should contain zero chaos concepts — just construct the agent with plugins=[chaos] and call it. **kwargs: Additional kwargs passed to the base Experiment.run_evaluations_async. Returns: - List of EvaluationReport objects covering all scenarios. + List of EvaluationReport objects. Raises: ValueError: If an async task is passed (use run_evaluations_async instead). @@ -223,11 +154,11 @@ def run_evaluations( async def run_evaluations_async( self, - task: Callable[[Case], Any], + task: Callable[[ChaosCase], Any], max_workers: int = 10, **kwargs, ) -> list[EvaluationReport]: - """Run evaluations asynchronously across all (case × scenario) combinations. + """Run evaluations asynchronously across all ChaosCase objects. Wraps the user's task to set the ContextVar before each case execution. The base Experiment handles sync-to-async dispatch internally. @@ -238,16 +169,14 @@ async def run_evaluations_async( **kwargs: Additional kwargs passed to the base Experiment.run_evaluations_async. Returns: - List of EvaluationReport objects covering all scenarios. + List of EvaluationReport objects. """ wrapped = self._wrap_task(task) reports = await self._experiment.run_evaluations_async(wrapped, max_workers=max_workers, **kwargs) - num_scenarios = len(self._scenarios) + (1 if self._include_baseline else 0) logger.info( - "cases=<%d>, scenarios=<%d>, reports=<%d> | chaos experiment complete", - len(self._original_cases), - num_scenarios, + "cases=<%d>, reports=<%d> | chaos experiment complete", + len(self._cases), len(reports), ) diff --git a/src/strands_evals/chaos/plugin.py b/src/strands_evals/chaos/plugin.py index 7f76f584..8d1dbed9 100644 --- a/src/strands_evals/chaos/plugin.py +++ b/src/strands_evals/chaos/plugin.py @@ -3,7 +3,7 @@ Implements chaos injection as a standard Strands Plugin using the SDK's native hook system (BeforeToolCallEvent / AfterToolCallEvent). -The plugin reads the active scenario from a module-level ContextVar at hook +The plugin reads the active ChaosCase from a module-level ContextVar at hook time. The ChaosExperiment manages the ContextVar lifecycle. """ @@ -14,21 +14,21 @@ from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent from strands.plugins import Plugin, hook -from ._context import _current_scenario +from ._context import _current_chaos_case from .effects import ChaosEffect, TruncateFields logger = logging.getLogger(__name__) class ChaosPlugin(Plugin): - """Strands Plugin that injects deterministic chaos based on the active scenario. + """Strands Plugin that injects deterministic chaos based on the active ChaosCase. The plugin intercepts tool calls via Strands' native hook system: - BeforeToolCallEvent: cancels tool calls for pre-hook effects (ToolCallFailure) - AfterToolCallEvent: corrupts tool responses for post-hook effects (TruncateFields, etc.) - The active scenario is managed via a ContextVar (set by ChaosExperiment). - When no scenario is active, all tools behave normally. + The active ChaosCase is managed via a ContextVar (set by ChaosExperiment). + When no ChaosCase is active or the case has no effects, all tools behave normally. Example:: @@ -42,7 +42,7 @@ class ChaosPlugin(Plugin): plugins=[chaos], ) - # The ChaosExperiment handles scenario activation via ContextVar. + # The ChaosExperiment handles ChaosCase activation via ContextVar. # The user's task body contains zero chaos concepts. """ @@ -59,12 +59,12 @@ def before_tool_call(self, event: BeforeToolCallEvent) -> None: etc.), cancels the tool call with the effect's error_message before the tool executes. """ - scenario = _current_scenario.get() - if scenario is None: + chaos_case = _current_chaos_case.get() + if chaos_case is None or not chaos_case.effects: return tool_name = event.tool_use.get("name", "") - effects = scenario.effects.get(tool_name, []) + effects = chaos_case.effects.get(tool_name, []) if not effects: return @@ -84,12 +84,12 @@ def after_tool_call(self, event: AfterToolCallEvent) -> None: For corruption effects (TruncateFields, RemoveFields, CorruptValues), applies effect.apply() to JSON content blocks in the tool response. """ - scenario = _current_scenario.get() - if scenario is None: + chaos_case = _current_chaos_case.get() + if chaos_case is None or not chaos_case.effects: return tool_name = event.tool_use.get("name", "") - effects = scenario.effects.get(tool_name, []) + effects = chaos_case.effects.get(tool_name, []) if not effects: return diff --git a/src/strands_evals/chaos/scenario.py b/src/strands_evals/chaos/scenario.py deleted file mode 100644 index e387d890..00000000 --- a/src/strands_evals/chaos/scenario.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Chaos scenario definition. - -A ChaosScenario is a named, deterministic configuration of chaos effects -that will fire simultaneously when the scenario is active. -""" - -from typing import Optional - -from pydantic import BaseModel, Field - -from .effects import ChaosEffect - - -class ChaosScenario(BaseModel): - """A single, deterministic chaos injection scenario. - - Each scenario defines a set of tool effects that fire simultaneously when - the scenario is active. - - Tools not listed in tool_effects behave normally (no chaos). - - Example:: - - from strands_evals.chaos import ChaosScenario - from strands_evals.chaos.effects import ToolCallFailure, CorruptValues - - # Baseline — no chaos - ChaosScenario(name="baseline") - - # Single-fault: one tool fails - ChaosScenario( - name="search_timeout", - effects={"search_tool": [ToolCallFailure(error_type="timeout")]}, - ) - - # Compound: multiple tools/models fail simultaneously - ChaosScenario( - name="search_times_out_while_book_corrupts", - description=( - "Worst-case compound: primary path fails hard while the " - "recovery path silently returns bad data." - ), - effects={ - "search_tool": [ToolCallFailure(error_type="timeout")], - "book_tool": [CorruptValues(corrupt_ratio=0.8)], - }, - ) - """ - - name: str = Field(..., description="Human-readable name for this scenario") - description: Optional[str] = Field( - default=None, - description="Optional description of what this scenario tests.", - ) - effects: dict[str, list[ChaosEffect]] = Field( - default_factory=dict, - description="Mapping of target_name -> list of effects to inject simultaneously. " - "Targets not listed here behave normally.", - ) - - def __repr__(self) -> str: - effects_str = ", ".join( - f"{target}: [{', '.join(type(e).__name__ for e in effs)}]" for target, effs in self.effects.items() - ) - return f"ChaosScenario(name='{self.name}', effects={{{effects_str}}})" diff --git a/tests/strands_evals/chaos/test_case.py b/tests/strands_evals/chaos/test_case.py new file mode 100644 index 00000000..7a57ea78 --- /dev/null +++ b/tests/strands_evals/chaos/test_case.py @@ -0,0 +1,176 @@ +"""Unit tests for ChaosCase.""" + +from strands_evals import Case +from strands_evals.chaos import ChaosCase +from strands_evals.chaos.effects import CorruptValues, ToolCallFailure, TruncateFields + + +class TestChaosCase: + """Tests for the ChaosCase data model.""" + + def test_baseline_case_has_no_effects(self): + case = ChaosCase(name="baseline", input="hello") + assert case.effects == {} + + def test_case_with_effects(self): + case = ChaosCase( + name="search_timeout", + input="hello", + effects={"search_tool": [ToolCallFailure(error_type="timeout")]}, + ) + assert len(case.effects) == 1 + assert isinstance(case.effects["search_tool"][0], ToolCallFailure) + + def test_case_with_multiple_tools(self): + case = ChaosCase( + name="compound", + input="hello", + effects={ + "search_tool": [ToolCallFailure(error_type="timeout")], + "db_tool": [CorruptValues(corrupt_ratio=0.8)], + }, + ) + assert len(case.effects) == 2 + + def test_case_with_multiple_effects_per_tool(self): + case = ChaosCase( + name="multi_effect", + input="hello", + effects={ + "tool_a": [ + TruncateFields(max_length=5), + CorruptValues(corrupt_ratio=0.3), + ], + }, + ) + assert len(case.effects["tool_a"]) == 2 + + def test_inherits_case_fields(self): + case = ChaosCase( + name="with_expected", + input="hello", + expected_output="world", + expected_trajectory=["tool_a"], + metadata={"key": "value"}, + effects={"tool_a": [ToolCallFailure()]}, + ) + assert case.input == "hello" + assert case.expected_output == "world" + assert case.expected_trajectory == ["tool_a"] + assert case.metadata == {"key": "value"} + + def test_repr_shows_effects(self): + case = ChaosCase( + name="test", + input="hello", + effects={"tool": [ToolCallFailure()]}, + ) + repr_str = repr(case) + assert "test" in repr_str + assert "ToolCallFailure" in repr_str + + +class TestChaosCaseExpand: + """Tests for the ChaosCase.expand() class method.""" + + def test_expand_with_baseline(self): + cases = [ + Case(name="case_a", input="hello"), + Case(name="case_b", input="world"), + ] + effect_maps = { + "search_timeout": {"search_tool": [ToolCallFailure(error_type="timeout")]}, + "db_corrupt": {"db_tool": [CorruptValues(corrupt_ratio=0.8)]}, + } + result = ChaosCase.expand(cases, effect_maps, include_no_effect_baseline=True) + # 2 cases × (2 effect maps + 1 baseline) = 6 + assert len(result) == 6 + + def test_expand_without_baseline(self): + cases = [ + Case(name="case_a", input="hello"), + Case(name="case_b", input="world"), + ] + effect_maps = { + "search_timeout": {"search_tool": [ToolCallFailure(error_type="timeout")]}, + } + result = ChaosCase.expand(cases, effect_maps) + # 2 cases × 1 effect map = 2 (no baseline by default) + assert len(result) == 2 + + def test_expand_baseline_names(self): + cases = [Case(name="case_a", input="hello")] + effect_maps = {"timeout": {"tool": [ToolCallFailure()]}} + result = ChaosCase.expand(cases, effect_maps, include_no_effect_baseline=True) + names = [c.name for c in result] + assert "case_a|baseline" in names + + def test_expand_uses_dict_keys_as_names(self): + cases = [Case(name="case_a", input="hello")] + effect_maps = {"search_timeout": {"search_tool": [ToolCallFailure(error_type="timeout")]}} + result = ChaosCase.expand(cases, effect_maps) + assert result[0].name == "case_a|search_timeout" + + def test_expand_compound_effect_name(self): + cases = [Case(name="case_a", input="hello")] + effect_maps = { + "multi_failure": { + "search_tool": [ToolCallFailure(error_type="timeout")], + "db_tool": [CorruptValues()], + } + } + result = ChaosCase.expand(cases, effect_maps) + assert result[0].name == "case_a|multi_failure" + + def test_expand_unique_session_ids(self): + cases = [Case(name="case_a", input="hello"), Case(name="case_b", input="world")] + effect_maps = {"timeout": {"tool": [ToolCallFailure()]}} + result = ChaosCase.expand(cases, effect_maps) + session_ids = [c.session_id for c in result] + assert len(session_ids) == len(set(session_ids)) + + def test_expand_preserves_case_fields(self): + cases = [ + Case( + name="case_a", + input="hello", + expected_output="world", + expected_trajectory=["tool_a"], + metadata={"key": "value"}, + ) + ] + effect_maps = {"timeout": {"tool": [ToolCallFailure()]}} + result = ChaosCase.expand(cases, effect_maps) + expanded = result[0] + assert expanded.input == "hello" + assert expanded.expected_output == "world" + assert expanded.expected_trajectory == ["tool_a"] + assert expanded.metadata == {"key": "value"} + + def test_expand_baseline_has_empty_effects(self): + cases = [Case(name="case_a", input="hello")] + effect_maps = {"timeout": {"tool": [ToolCallFailure()]}} + result = ChaosCase.expand(cases, effect_maps, include_no_effect_baseline=True) + baseline = [c for c in result if "baseline" in c.name][0] + assert baseline.effects == {} + + def test_expand_empty_effect_maps_with_baseline(self): + cases = [Case(name="case_a", input="hello")] + result = ChaosCase.expand(cases, {}, include_no_effect_baseline=True) + # Only baseline + assert len(result) == 1 + assert "baseline" in result[0].name + + def test_expand_empty_effect_maps_without_baseline(self): + cases = [Case(name="case_a", input="hello")] + result = ChaosCase.expand(cases, {}) + # No baseline by default, no effect maps → empty + assert len(result) == 0 + + def test_expand_case_without_name(self): + cases = [Case(input="hello")] + effect_maps = {"timeout": {"tool": [ToolCallFailure()]}} + result = ChaosCase.expand(cases, effect_maps, include_no_effect_baseline=True) + names = [c.name for c in result] + assert "baseline" in names + assert "timeout" in names diff --git a/tests/strands_evals/chaos/test_context.py b/tests/strands_evals/chaos/test_context.py index 5b93ced5..ebc2d030 100644 --- a/tests/strands_evals/chaos/test_context.py +++ b/tests/strands_evals/chaos/test_context.py @@ -1,33 +1,42 @@ """Unit tests for the chaos _context module.""" -from strands_evals.chaos import ChaosScenario -from strands_evals.chaos._context import _current_scenario +from strands_evals.chaos import ChaosCase +from strands_evals.chaos._context import _current_chaos_case class TestContextVar: - """Tests for the _current_scenario ContextVar.""" + """Tests for the _current_chaos_case ContextVar.""" + + def test_default_is_none(self): + assert _current_chaos_case.get() is None def test_set_and_get(self): - scenario = ChaosScenario(name="test_scenario") - token = _current_scenario.set(scenario) + case = ChaosCase(name="test_case", input="hello") + token = _current_chaos_case.set(case) try: - assert _current_scenario.get() is scenario - assert _current_scenario.get().name == "test_scenario" + assert _current_chaos_case.get() is case + assert _current_chaos_case.get().name == "test_case" finally: - _current_scenario.reset(token) + _current_chaos_case.reset(token) def test_nested_set_and_reset(self): - s1 = ChaosScenario(name="outer") - s2 = ChaosScenario(name="inner") + c1 = ChaosCase(name="outer", input="hello") + c2 = ChaosCase(name="inner", input="world") - token1 = _current_scenario.set(s1) + token1 = _current_chaos_case.set(c1) try: - assert _current_scenario.get().name == "outer" - token2 = _current_scenario.set(s2) + assert _current_chaos_case.get().name == "outer" + token2 = _current_chaos_case.set(c2) try: - assert _current_scenario.get().name == "inner" + assert _current_chaos_case.get().name == "inner" finally: - _current_scenario.reset(token2) - assert _current_scenario.get().name == "outer" + _current_chaos_case.reset(token2) + assert _current_chaos_case.get().name == "outer" finally: - _current_scenario.reset(token1) + _current_chaos_case.reset(token1) + + def test_reset_restores_none(self): + case = ChaosCase(name="test", input="hello") + token = _current_chaos_case.set(case) + _current_chaos_case.reset(token) + assert _current_chaos_case.get() is None diff --git a/tests/strands_evals/chaos/test_experiment.py b/tests/strands_evals/chaos/test_experiment.py index 9b43c4a2..4fb0c574 100644 --- a/tests/strands_evals/chaos/test_experiment.py +++ b/tests/strands_evals/chaos/test_experiment.py @@ -3,8 +3,8 @@ import pytest from strands_evals import Case -from strands_evals.chaos import ChaosExperiment, ChaosScenario -from strands_evals.chaos._context import _current_scenario +from strands_evals.chaos import ChaosCase, ChaosExperiment +from strands_evals.chaos._context import _current_chaos_case from strands_evals.chaos.effects import CorruptValues, ToolCallFailure from strands_evals.evaluators.evaluator import Evaluator from strands_evals.types import EvaluationData, EvaluationOutput @@ -26,17 +26,11 @@ def cases(): @pytest.fixture -def scenarios(): - return [ - ChaosScenario( - name="search_timeout", - effects={"search_tool": [ToolCallFailure(error_type="timeout")]}, - ), - ChaosScenario( - name="db_corrupt", - effects={"db_tool": [CorruptValues(corrupt_ratio=0.8)]}, - ), - ] +def effect_maps(): + return { + "search_timeout": {"search_tool": [ToolCallFailure(error_type="timeout")]}, + "db_corrupt": {"db_tool": [CorruptValues(corrupt_ratio=0.8)]}, + } @pytest.fixture @@ -47,95 +41,71 @@ def evaluator(): class TestChaosExperiment: """Tests for ChaosExperiment initialization and execution.""" - def test_expanded_cases_count_with_baseline(self, cases, scenarios, evaluator): - experiment = ChaosExperiment(cases=cases, scenarios=scenarios, evaluators=[evaluator], include_baseline=True) - # 2 cases × (2 scenarios + 1 baseline) = 6 - assert len(experiment._expanded_cases) == 6 - - def test_expanded_cases_count_without_baseline(self, cases, scenarios, evaluator): - experiment = ChaosExperiment(cases=cases, scenarios=scenarios, evaluators=[evaluator], include_baseline=False) - # 2 cases × 2 scenarios = 4 - assert len(experiment._expanded_cases) == 4 - - def test_expanded_case_names_include_scenario(self, cases, scenarios, evaluator): - experiment = ChaosExperiment(cases=cases, scenarios=scenarios, evaluators=[evaluator], include_baseline=True) - names = [c.name for c in experiment._expanded_cases] + def test_cases_count_with_baseline(self, cases, effect_maps, evaluator): + chaos_cases = ChaosCase.expand(cases, effect_maps, include_no_effect_baseline=True) + experiment = ChaosExperiment(cases=chaos_cases, evaluators=[evaluator]) + # 2 cases × (2 effect maps + 1 baseline) = 6 + assert len(experiment.cases) == 6 + + def test_cases_count_without_baseline(self, cases, effect_maps, evaluator): + chaos_cases = ChaosCase.expand(cases, effect_maps) + experiment = ChaosExperiment(cases=chaos_cases, evaluators=[evaluator]) + # 2 cases × 2 effect maps = 4 + assert len(experiment.cases) == 4 + + def test_case_names_include_effect_map_key(self, cases, effect_maps, evaluator): + chaos_cases = ChaosCase.expand(cases, effect_maps, include_no_effect_baseline=True) + experiment = ChaosExperiment(cases=chaos_cases, evaluators=[evaluator]) + names = [c.name for c in experiment.cases] assert "case_a|baseline" in names - assert "case_a|search_timeout" in names - assert "case_a|db_corrupt" in names assert "case_b|baseline" in names - assert "case_b|search_timeout" in names + assert "case_a|search_timeout" in names assert "case_b|db_corrupt" in names - def test_each_expanded_case_has_unique_session_id(self, cases, scenarios, evaluator): - experiment = ChaosExperiment(cases=cases, scenarios=scenarios, evaluators=[evaluator]) - session_ids = [c.session_id for c in experiment._expanded_cases] + def test_each_case_has_unique_session_id(self, cases, effect_maps, evaluator): + chaos_cases = ChaosCase.expand(cases, effect_maps) + experiment = ChaosExperiment(cases=chaos_cases, evaluators=[evaluator]) + session_ids = [c.session_id for c in experiment.cases] assert len(session_ids) == len(set(session_ids)) - def test_get_scenario_for_session(self, cases, scenarios, evaluator): - experiment = ChaosExperiment(cases=cases, scenarios=scenarios, evaluators=[evaluator], include_baseline=True) - # Pick an expanded case and verify its scenario maps correctly - for expanded_case in experiment._expanded_cases: - scenario = experiment.get_scenario_for_session(expanded_case.session_id) - assert scenario is not None - assert scenario.name in expanded_case.name - - def test_get_scenario_for_unknown_session(self, cases, scenarios, evaluator): - experiment = ChaosExperiment(cases=cases, scenarios=scenarios, evaluators=[evaluator]) - assert experiment.get_scenario_for_session("nonexistent-id") is None - - def test_get_original_case_name(self, cases, scenarios, evaluator): - experiment = ChaosExperiment(cases=cases, scenarios=scenarios, evaluators=[evaluator], include_baseline=True) - for expanded_case in experiment._expanded_cases: - original_name = experiment.get_original_case_name(expanded_case.session_id) - assert original_name in ("case_a", "case_b") - - def test_get_original_case_name_unknown_session(self, cases, scenarios, evaluator): - experiment = ChaosExperiment(cases=cases, scenarios=scenarios, evaluators=[evaluator]) - assert experiment.get_original_case_name("nonexistent-id") is None - - def test_context_var_set_and_reset(self, cases, scenarios, evaluator): - """Verify the ContextVar is set to the correct scenario during task execution and reset after.""" - observed_scenarios = [] - - def capturing_task(case: Case): - scenario = _current_scenario.get() - observed_scenarios.append((case.name, scenario.name if scenario else None)) + def test_context_var_set_and_reset(self, cases, effect_maps, evaluator): + """Verify the ContextVar is set to the correct ChaosCase during task execution and reset after.""" + observed_cases = [] + + def capturing_task(case: ChaosCase): + active_case = _current_chaos_case.get() + observed_cases.append((case.name, active_case.name if active_case else None)) return "output" - experiment = ChaosExperiment(cases=cases, scenarios=scenarios, evaluators=[evaluator], include_baseline=True) + chaos_cases = ChaosCase.expand(cases, effect_maps, include_no_effect_baseline=True) + experiment = ChaosExperiment(cases=chaos_cases, evaluators=[evaluator]) experiment.run_evaluations(task=capturing_task) - # Should have 6 observations (2 cases × 3 scenarios) - assert len(observed_scenarios) == 6 - - # Verify baseline scenarios observed - baseline_obs = [(name, sn) for name, sn in observed_scenarios if sn == "baseline"] - assert len(baseline_obs) == 2 + # Should have 6 observations (2 cases × 3 conditions) + assert len(observed_cases) == 6 - # Verify chaos scenarios observed - timeout_obs = [(name, sn) for name, sn in observed_scenarios if sn == "search_timeout"] - assert len(timeout_obs) == 2 + # Verify the ContextVar matched the case being executed + for case_name, active_name in observed_cases: + assert case_name == active_name # After all runs, the ContextVar should be back to None - assert _current_scenario.get() is None + assert _current_chaos_case.get() is None def test_context_var_reset_on_task_exception(self, evaluator): """Verify the ContextVar is reset even if the task raises.""" cases = [Case(name="failing", input="x")] - scenarios_list = [ChaosScenario(name="chaos", effects={"t": [ToolCallFailure()]})] + effect_maps = {"chaos": {"t": [ToolCallFailure()]}} + chaos_cases = ChaosCase.expand(cases, effect_maps, include_no_effect_baseline=True) call_count = [0] - def failing_task(case: Case): + def failing_task(case: ChaosCase): call_count[0] += 1 if call_count[0] == 1: raise RuntimeError("Task failed") return "output" - experiment = ChaosExperiment( - cases=cases, scenarios=scenarios_list, evaluators=[evaluator], include_baseline=True - ) + experiment = ChaosExperiment(cases=chaos_cases, evaluators=[evaluator]) # The base Experiment should handle the exception internally # ContextVar should still be reset @@ -144,18 +114,19 @@ def failing_task(case: Case): except Exception: pass - assert _current_scenario.get() is None + assert _current_chaos_case.get() is None - def test_returns_evaluation_reports(self, cases, scenarios, evaluator): + def test_returns_evaluation_reports(self, cases, effect_maps, evaluator): """Verify run_evaluations returns reports.""" - def task(case: Case): + def task(case: ChaosCase): return "output" - experiment = ChaosExperiment(cases=cases, scenarios=scenarios, evaluators=[evaluator], include_baseline=True) + chaos_cases = ChaosCase.expand(cases, effect_maps, include_no_effect_baseline=True) + experiment = ChaosExperiment(cases=chaos_cases, evaluators=[evaluator]) reports = experiment.run_evaluations(task=task) assert len(reports) >= 1 report = reports[0] - # 2 cases × 3 scenarios = 6 scores + # 2 cases × 3 conditions = 6 scores assert len(report.scores) == 6 diff --git a/tests/strands_evals/chaos/test_plugin.py b/tests/strands_evals/chaos/test_plugin.py index 217253d8..45b1844f 100644 --- a/tests/strands_evals/chaos/test_plugin.py +++ b/tests/strands_evals/chaos/test_plugin.py @@ -5,8 +5,8 @@ import pytest -from strands_evals.chaos import ChaosPlugin, ChaosScenario -from strands_evals.chaos._context import _current_scenario +from strands_evals.chaos import ChaosCase, ChaosPlugin +from strands_evals.chaos._context import _current_chaos_case from strands_evals.chaos.effects import ( ToolCallFailure, TruncateFields, @@ -43,53 +43,57 @@ def after_event(): class TestChaosPluginBeforeToolCall: """Tests for the before_tool_call hook.""" - def test_no_scenario_active_does_nothing(self, chaos_plugin, before_event): - token = _current_scenario.set(None) + def test_no_case_active_does_nothing(self, chaos_plugin, before_event): + token = _current_chaos_case.set(None) try: chaos_plugin.before_tool_call(before_event) assert before_event.cancel_tool is None finally: - _current_scenario.reset(token) + _current_chaos_case.reset(token) - def test_scenario_without_matching_tool_does_nothing(self, chaos_plugin, before_event): - scenario = ChaosScenario( + def test_case_without_matching_tool_does_nothing(self, chaos_plugin, before_event): + case = ChaosCase( name="other_tool_fails", + input="test", effects={"other_tool": [ToolCallFailure(error_type="timeout")]}, ) - token = _current_scenario.set(scenario) + token = _current_chaos_case.set(case) try: chaos_plugin.before_tool_call(before_event) assert before_event.cancel_tool is None finally: - _current_scenario.reset(token) + _current_chaos_case.reset(token) def test_pre_hook_effect_cancels_tool(self, chaos_plugin, before_event): - scenario = ChaosScenario( + case = ChaosCase( name="search_timeout", + input="test", effects={"search_tool": [ToolCallFailure(error_type="timeout")]}, ) - token = _current_scenario.set(scenario) + token = _current_chaos_case.set(case) try: chaos_plugin.before_tool_call(before_event) assert before_event.cancel_tool == "Tool call timed out" finally: - _current_scenario.reset(token) + _current_chaos_case.reset(token) def test_post_hook_effect_does_not_cancel_tool(self, chaos_plugin, before_event): - scenario = ChaosScenario( + case = ChaosCase( name="search_truncated", + input="test", effects={"search_tool": [TruncateFields(max_length=5)]}, ) - token = _current_scenario.set(scenario) + token = _current_chaos_case.set(case) try: chaos_plugin.before_tool_call(before_event) assert before_event.cancel_tool is None finally: - _current_scenario.reset(token) + _current_chaos_case.reset(token) def test_first_pre_hook_effect_wins(self, chaos_plugin, before_event): - scenario = ChaosScenario( + case = ChaosCase( name="multi_pre", + input="test", effects={ "search_tool": [ ToolCallFailure(error_type="timeout"), @@ -97,80 +101,84 @@ def test_first_pre_hook_effect_wins(self, chaos_plugin, before_event): ] }, ) - token = _current_scenario.set(scenario) + token = _current_chaos_case.set(case) try: chaos_plugin.before_tool_call(before_event) assert before_event.cancel_tool == "Tool call timed out" finally: - _current_scenario.reset(token) + _current_chaos_case.reset(token) class TestChaosPluginAfterToolCall: """Tests for the after_tool_call hook.""" - def test_no_scenario_active_does_nothing(self, chaos_plugin, after_event): - token = _current_scenario.set(None) + def test_no_case_active_does_nothing(self, chaos_plugin, after_event): + token = _current_chaos_case.set(None) try: original_content = after_event.result["content"][0]["text"] chaos_plugin.after_tool_call(after_event) assert after_event.result["content"][0]["text"] == original_content finally: - _current_scenario.reset(token) + _current_chaos_case.reset(token) - def test_scenario_without_matching_tool_does_nothing(self, chaos_plugin, after_event): - scenario = ChaosScenario( + def test_case_without_matching_tool_does_nothing(self, chaos_plugin, after_event): + case = ChaosCase( name="other_tool", + input="test", effects={"other_tool": [TruncateFields(max_length=3)]}, ) - token = _current_scenario.set(scenario) + token = _current_chaos_case.set(case) try: original_content = after_event.result["content"][0]["text"] chaos_plugin.after_tool_call(after_event) assert after_event.result["content"][0]["text"] == original_content finally: - _current_scenario.reset(token) + _current_chaos_case.reset(token) def test_post_hook_corrupts_json_text_blocks(self, chaos_plugin, after_event): - scenario = ChaosScenario( + case = ChaosCase( name="truncate", + input="test", effects={"search_tool": [TruncateFields(max_length=3)]}, ) - token = _current_scenario.set(scenario) + token = _current_chaos_case.set(case) try: chaos_plugin.after_tool_call(after_event) corrupted = json.loads(after_event.result["content"][0]["text"]) assert corrupted["title"] == "Lon" assert corrupted["count"] == 42 # non-string preserved finally: - _current_scenario.reset(token) + _current_chaos_case.reset(token) def test_pre_hook_effect_ignored_in_after_hook(self, chaos_plugin, after_event): - scenario = ChaosScenario( + case = ChaosCase( name="pre_only", + input="test", effects={"search_tool": [ToolCallFailure(error_type="timeout")]}, ) - token = _current_scenario.set(scenario) + token = _current_chaos_case.set(case) try: original_content = after_event.result["content"][0]["text"] chaos_plugin.after_tool_call(after_event) assert after_event.result["content"][0]["text"] == original_content finally: - _current_scenario.reset(token) + _current_chaos_case.reset(token) def test_none_result_is_skipped(self, chaos_plugin): event = MagicMock() event.tool_use = {"name": "search_tool"} event.result = None - scenario = ChaosScenario( + case = ChaosCase( name="truncate", + input="test", effects={"search_tool": [TruncateFields(max_length=3)]}, ) - token = _current_scenario.set(scenario) + token = _current_chaos_case.set(case) try: chaos_plugin.after_tool_call(event) # Should not raise finally: - _current_scenario.reset(token) + _current_chaos_case.reset(token) def test_plain_text_truncation(self, chaos_plugin): """Test that plain (non-JSON) text blocks get truncated if effect has max_length.""" @@ -182,16 +190,17 @@ def test_plain_text_truncation(self, chaos_plugin): "toolUseId": "tool-456", } - scenario = ChaosScenario( + case = ChaosCase( name="truncate", + input="test", effects={"search_tool": [TruncateFields(max_length=4)]}, ) - token = _current_scenario.set(scenario) + token = _current_chaos_case.set(case) try: chaos_plugin.after_tool_call(event) assert event.result["content"][0]["text"] == "This" finally: - _current_scenario.reset(token) + _current_chaos_case.reset(token) class TestApplyRate: @@ -199,11 +208,12 @@ class TestApplyRate: def test_apply_rate_zero_skips_pre_hook_effect(self, chaos_plugin, before_event): """Effect with apply_rate=0.0 should never fire.""" - scenario = ChaosScenario( + case = ChaosCase( name="never_fires", + input="test", effects={"search_tool": [ToolCallFailure(error_type="timeout", apply_rate=0.0)]}, ) - token = _current_scenario.set(scenario) + token = _current_chaos_case.set(case) try: # Run multiple times to confirm it never fires for _ in range(20): @@ -211,45 +221,48 @@ def test_apply_rate_zero_skips_pre_hook_effect(self, chaos_plugin, before_event) chaos_plugin.before_tool_call(before_event) assert before_event.cancel_tool is None finally: - _current_scenario.reset(token) + _current_chaos_case.reset(token) def test_apply_rate_one_always_fires_pre_hook(self, chaos_plugin, before_event): """Effect with apply_rate=1.0 should always fire.""" - scenario = ChaosScenario( + case = ChaosCase( name="always_fires", + input="test", effects={"search_tool": [ToolCallFailure(error_type="timeout", apply_rate=1.0)]}, ) - token = _current_scenario.set(scenario) + token = _current_chaos_case.set(case) try: chaos_plugin.before_tool_call(before_event) assert before_event.cancel_tool == "Tool call timed out" finally: - _current_scenario.reset(token) + _current_chaos_case.reset(token) def test_apply_rate_zero_skips_post_hook_effect(self, chaos_plugin, after_event): """Post-hook effect with apply_rate=0.0 should never fire.""" - scenario = ChaosScenario( + case = ChaosCase( name="never_truncates", + input="test", effects={"search_tool": [TruncateFields(max_length=3, apply_rate=0.0)]}, ) - token = _current_scenario.set(scenario) + token = _current_chaos_case.set(case) try: original_content = after_event.result["content"][0]["text"] chaos_plugin.after_tool_call(after_event) assert after_event.result["content"][0]["text"] == original_content finally: - _current_scenario.reset(token) + _current_chaos_case.reset(token) def test_apply_rate_one_always_fires_post_hook(self, chaos_plugin, after_event): """Post-hook effect with apply_rate=1.0 should always fire.""" - scenario = ChaosScenario( + case = ChaosCase( name="always_truncates", + input="test", effects={"search_tool": [TruncateFields(max_length=3, apply_rate=1.0)]}, ) - token = _current_scenario.set(scenario) + token = _current_chaos_case.set(case) try: chaos_plugin.after_tool_call(after_event) corrupted = json.loads(after_event.result["content"][0]["text"]) assert corrupted["title"] == "Lon" finally: - _current_scenario.reset(token) + _current_chaos_case.reset(token) diff --git a/tests/strands_evals/chaos/test_scenario.py b/tests/strands_evals/chaos/test_scenario.py deleted file mode 100644 index d9eceb07..00000000 --- a/tests/strands_evals/chaos/test_scenario.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Unit tests for ChaosScenario.""" - -from strands_evals.chaos import ChaosScenario -from strands_evals.chaos.effects import CorruptValues, ToolCallFailure, TruncateFields - - -class TestChaosScenario: - """Tests for the ChaosScenario data model.""" - - def test_baseline_scenario_has_no_effects(self): - scenario = ChaosScenario(name="baseline") - assert scenario.effects == {} - - def test_scenario_with_multiple_tools(self): - scenario = ChaosScenario( - name="compound_failure", - effects={ - "search_tool": [ToolCallFailure(error_type="timeout")], - "db_tool": [CorruptValues(corrupt_ratio=0.8)], - }, - ) - assert len(scenario.effects) == 2 - assert isinstance(scenario.effects["search_tool"][0], ToolCallFailure) - assert isinstance(scenario.effects["db_tool"][0], CorruptValues) - - def test_scenario_with_multiple_effects_per_tool(self): - scenario = ChaosScenario( - name="multi_effect", - effects={ - "tool_a": [ - TruncateFields(max_length=5), - CorruptValues(corrupt_ratio=0.3), - ], - }, - ) - assert len(scenario.effects["tool_a"]) == 2 - - def test_repr_shows_effects(self): - scenario = ChaosScenario( - name="test", - effects={"tool": [ToolCallFailure()]}, - ) - repr_str = repr(scenario) - assert "test" in repr_str - assert "ToolCallFailure" in repr_str From 46a49ed41b55bf6f8122b2005b099e6f964ef860 Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Thu, 21 May 2026 23:52:35 +0000 Subject: [PATCH 8/9] update chaos effect type and map; fix pydantic serialization and async test --- src/strands_evals/chaos/__init__.py | 15 +- src/strands_evals/chaos/case.py | 52 +++--- src/strands_evals/chaos/effects.py | 157 +++++++++++++++---- src/strands_evals/chaos/experiment.py | 12 +- src/strands_evals/chaos/plugin.py | 14 +- tests/strands_evals/chaos/test_case.py | 95 +++++++---- tests/strands_evals/chaos/test_effects.py | 112 ++++++++++--- tests/strands_evals/chaos/test_experiment.py | 56 ++++++- tests/strands_evals/chaos/test_plugin.py | 37 +++-- 9 files changed, 411 insertions(+), 139 deletions(-) diff --git a/src/strands_evals/chaos/__init__.py b/src/strands_evals/chaos/__init__.py index e04544ac..8670012d 100644 --- a/src/strands_evals/chaos/__init__.py +++ b/src/strands_evals/chaos/__init__.py @@ -8,10 +8,14 @@ from .effects import ( ChaosEffect, CorruptValues, + ExecutionError, + NetworkError, RemoveFields, - ToolCallFailure, + Timeout, ToolEffect, + ToolEffectUnion, TruncateFields, + ValidationError, ) from .experiment import ChaosExperiment from .plugin import ChaosPlugin @@ -24,8 +28,13 @@ # Effect hierarchy "ChaosEffect", "ToolEffect", - # Concrete effects - "ToolCallFailure", + "ToolEffectUnion", + # Pre-hook effects (tool call failures) + "Timeout", + "NetworkError", + "ExecutionError", + "ValidationError", + # Post-hook effects (response corruption) "TruncateFields", "RemoveFields", "CorruptValues", diff --git a/src/strands_evals/chaos/case.py b/src/strands_evals/chaos/case.py index 4d66fe44..4419548f 100644 --- a/src/strands_evals/chaos/case.py +++ b/src/strands_evals/chaos/case.py @@ -12,7 +12,7 @@ from ..case import Case from ..types.evaluation import InputT, OutputT -from .effects import ChaosEffect +from .effects import ToolEffectUnion class ChaosCase(Case, Generic[InputT, OutputT]): @@ -26,20 +26,20 @@ class ChaosCase(Case, Generic[InputT, OutputT]): ChaosExperiment. Attributes: - effects: Mapping of tool_name -> list of effects to inject for this case. - Tools not listed behave normally. Empty dict means baseline (no chaos). + effects: A dict keyed by effect category. Currently supports + ``"tool_effects"`` mapping tool_name -> list of effects. Example:: from strands_evals import Case from strands_evals.chaos import ChaosCase - from strands_evals.chaos.effects import ToolCallFailure, TruncateFields + from strands_evals.chaos.effects import Timeout, TruncateFields # Direct construction chaos_case = ChaosCase( name="search_timeout", input="Find flights to Tokyo", - effects={"search_tool": [ToolCallFailure(error_type="timeout")]}, + effects={"tool_effects": {"search_tool": [Timeout()]}}, ) # Expansion from base cases × named effect maps @@ -48,24 +48,24 @@ class ChaosCase(Case, Generic[InputT, OutputT]): Case(name="hotel_search", input="Find hotels in Tokyo"), ] effect_maps = { - "search_timeout": {"search_tool": [ToolCallFailure(error_type="timeout")]}, - "search_truncated": {"search_tool": [TruncateFields(max_length=5)]}, + "search_timeout": {"tool_effects": {"search_tool": [Timeout()]}}, + "search_truncated": {"tool_effects": {"search_tool": [TruncateFields(max_length=5)]}}, } chaos_cases = ChaosCase.expand(cases, effect_maps, include_no_effect_baseline=True) # Produces 6 ChaosCase objects: 2 cases × (2 effect maps + 1 baseline) """ - effects: dict[str, list[ChaosEffect]] = Field( + effects: dict[str, dict[str, list[ToolEffectUnion]]] = Field( default_factory=dict, - description="Mapping of tool_name -> list of effects to inject for this case. " - "Empty dict means baseline (no chaos).", + description="Effect categories. Currently supports 'tool_effects' mapping " + "tool_name -> list of effects. Empty dict means baseline (no chaos).", ) @classmethod def expand( cls, cases: list[Case], - effect_maps: dict[str, dict[str, list[ChaosEffect]]], + effect_maps: dict[str, dict[str, dict[str, list[ToolEffectUnion]]]], include_no_effect_baseline: bool = False, ) -> list["ChaosCase"]: """Generate the Cartesian product of cases × named effect maps. @@ -77,8 +77,16 @@ def expand( Args: cases: Base test cases to expand. effect_maps: Named effect configurations. Keys are short human-readable - names (used in the composite case name); values are mappings of - tool_name -> list of ChaosEffect instances. + names (used in the composite case name); values are dicts keyed by + effect category (e.g. ``"tool_effects"``) mapping tool_name -> list + of effect instances. + Example:: + + { + "search_timeout": { + "tool_effects": {"search_tool": [Timeout()]} + }, + } include_no_effect_baseline: If True, includes a baseline (no chaos) variant for each case. Defaults to False. @@ -86,19 +94,20 @@ def expand( Flat list of ChaosCase objects with composite names like "flight_search|baseline" or "flight_search|search_timeout". """ - all_entries: list[tuple[str, dict[str, list[ChaosEffect]]]] = [] + all_entries: list[tuple[str, dict[str, dict[str, list[ToolEffectUnion]]]]] = [] if include_no_effect_baseline: all_entries.append(("baseline", {})) - for name, effects in effect_maps.items(): - all_entries.append((name, effects)) + for name, effects_config in effect_maps.items(): + all_entries.append((name, effects_config)) expanded: list[ChaosCase] = [] for case in cases: - for condition_name, effects in all_entries: + for condition_name, effects_config in all_entries: session_id = str(uuid.uuid4()) expanded_name = f"{case.name}|{condition_name}" if case.name else condition_name + expanded.append( cls( name=expanded_name, @@ -110,14 +119,19 @@ def expand( expected_interactions=case.expected_interactions, expected_environment_state=case.expected_environment_state, metadata=case.metadata, - effects=effects, + effects=effects_config, ) ) return expanded + @property + def tool_effects(self) -> dict[str, list[ToolEffectUnion]]: + """Convenience accessor for effects['tool_effects'].""" + return self.effects.get("tool_effects", {}) + def __repr__(self) -> str: effects_str = ", ".join( - f"{target}: [{', '.join(type(e).__name__ for e in effs)}]" for target, effs in self.effects.items() + f"{target}: [{', '.join(type(e).__name__ for e in effs)}]" for target, effs in self.tool_effects.items() ) return f"ChaosCase(name='{self.name}', effects={{{effects_str}}})" diff --git a/src/strands_evals/chaos/effects.py b/src/strands_evals/chaos/effects.py index f65af487..e3944629 100644 --- a/src/strands_evals/chaos/effects.py +++ b/src/strands_evals/chaos/effects.py @@ -1,19 +1,22 @@ """Chaos effect definitions. Effects are first-class parameterized classes organized in a hierarchy: - ChaosEffect → ToolEffect → concrete effects (ToolCallFailure, TruncateFields, etc.) + ChaosEffect → ToolEffect → concrete effects (Timeout, NetworkError, etc.) Each concrete effect carries only the parameters meaningful to it. The `hook` class variable indicates whether the effect fires pre-tool-call (error effects) or post-tool-call (corruption effects). + +Each concrete effect has an `effect_type` discriminator field for Pydantic +discriminated-union serialization, ensuring full round-trip fidelity. """ import math import random from abc import abstractmethod -from typing import Any, ClassVar, Literal +from typing import Annotated, Any, ClassVar, Literal, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Discriminator, Field, Tag class ChaosEffect(BaseModel): @@ -48,56 +51,118 @@ class ToolEffect(ChaosEffect): """ -# All supported failure types -ToolCallFailureType = Literal["timeout", "network_error", "execution_error", "validation_error"] - -# Default error messages per failure type -_DEFAULT_ERROR_MESSAGES: dict[str, str] = { - "timeout": "Tool call timed out", - "network_error": "Network unreachable", - "execution_error": "Tool execution failed", - "validation_error": "Tool input validation failed", -} +# --------------------------------------------------------------------------- +# Pre-hook effects: cancel the tool call before execution +# --------------------------------------------------------------------------- -class ToolCallFailure(ToolEffect): - """Simulates a tool call failure that prevents the tool from executing. +class Timeout(ToolEffect): + """Simulates a tool call timeout. - The tool call is cancelled before execution with a simulated error message. + The tool call is cancelled before execution with a timeout error message. Example:: ChaosCase( name="search_timeout", input="Find flights", - effects={"search_tool": [ToolCallFailure(error_type="timeout")]}, + effects={"tool_effects": {"search_tool": [Timeout()]}}, ) + """ + + hook: ClassVar[Literal["pre", "post"]] = "pre" + effect_type: Literal["timeout"] = "timeout" + error_message: str = Field( + default="Tool call timed out", + description="Error message returned to the agent.", + ) + + def apply(self, context: Any = None) -> str: + """Return the error message to cancel the tool call with.""" + return self.error_message + + +class NetworkError(ToolEffect): + """Simulates a network error during tool invocation. + + The tool call is cancelled before execution with a network error message. + + Example:: ChaosCase( name="db_network_error", input="Query database", - effects={"database_tool": [ToolCallFailure( - error_type="network_error", - error_message="Connection refused on port 5432", - )]}, + effects={"tool_effects": {"database_tool": [NetworkError()]}}, + ) + """ + + hook: ClassVar[Literal["pre", "post"]] = "pre" + effect_type: Literal["network_error"] = "network_error" + error_message: str = Field( + default="Network unreachable", + description="Error message returned to the agent.", + ) + + def apply(self, context: Any = None) -> str: + """Return the error message to cancel the tool call with.""" + return self.error_message + + +class ExecutionError(ToolEffect): + """Simulates a tool execution failure. + + The tool call is cancelled before execution with an execution error message. + + Example:: + + ChaosCase( + name="tool_exec_error", + input="Run analysis", + effects={"tool_effects": {"analysis_tool": [ExecutionError()]}}, ) """ hook: ClassVar[Literal["pre", "post"]] = "pre" - error_type: ToolCallFailureType = Field( - default="execution_error", - description="Type of failure to simulate.", + effect_type: Literal["execution_error"] = "execution_error" + error_message: str = Field( + default="Tool execution failed", + description="Error message returned to the agent.", ) - error_message: str | None = Field( - default=None, - description="Custom error message. If None, uses a default for the error_type.", + + def apply(self, context: Any = None) -> str: + """Return the error message to cancel the tool call with.""" + return self.error_message + + +class ValidationError(ToolEffect): + """Simulates a tool input validation failure. + + The tool call is cancelled before execution with a validation error message. + + Example:: + + ChaosCase( + name="bad_input", + input="Search with invalid params", + effects={"tool_effects": {"search_tool": [ValidationError()]}}, + ) + """ + + hook: ClassVar[Literal["pre", "post"]] = "pre" + effect_type: Literal["validation_error"] = "validation_error" + error_message: str = Field( + default="Tool input validation failed", + description="Error message returned to the agent.", ) def apply(self, context: Any = None) -> str: """Return the error message to cancel the tool call with.""" - if self.error_message is not None: - return self.error_message - return _DEFAULT_ERROR_MESSAGES[self.error_type] + return self.error_message + + +# --------------------------------------------------------------------------- +# Post-hook effects: corrupt the tool response after execution +# --------------------------------------------------------------------------- class TruncateFields(ToolEffect): @@ -112,12 +177,13 @@ class TruncateFields(ToolEffect): name="search_truncated", input="Find flights", effects={ - "search_tool": [TruncateFields(max_length=5)], + "tool_effects": {"search_tool": [TruncateFields(max_length=5)]}, }, ) """ hook: ClassVar[Literal["pre", "post"]] = "post" + effect_type: Literal["truncate_fields"] = "truncate_fields" max_length: int = Field(default=10, ge=0, description="Maximum length to truncate string values to") def apply(self, response: Any = None) -> Any: @@ -154,12 +220,13 @@ class RemoveFields(ToolEffect): name="db_remove_fields", input="Query database", effects={ - "database_tool": [RemoveFields(remove_ratio=0.5)], + "tool_effects": {"database_tool": [RemoveFields(remove_ratio=0.5)]}, }, ) """ hook: ClassVar[Literal["pre", "post"]] = "post" + effect_type: Literal["remove_fields"] = "remove_fields" remove_ratio: float = Field( default=0.5, ge=0.0, @@ -201,12 +268,13 @@ class CorruptValues(ToolEffect): name="db_corrupt", input="Query database", effects={ - "database_tool": [CorruptValues(corrupt_ratio=0.8)], + "tool_effects": {"database_tool": [CorruptValues(corrupt_ratio=0.8)]}, }, ) """ hook: ClassVar[Literal["pre", "post"]] = "post" + effect_type: Literal["corrupt_values"] = "corrupt_values" corrupt_ratio: float = Field( default=0.5, ge=0.0, @@ -246,3 +314,26 @@ def apply(self, response: Any = None) -> Any: else: result[key] = value return result + + +# --------------------------------------------------------------------------- +# Discriminated union type for Pydantic serialization +# --------------------------------------------------------------------------- + +ToolEffectUnion = Annotated[ + Union[ + Annotated[Timeout, Tag("timeout")], + Annotated[NetworkError, Tag("network_error")], + Annotated[ExecutionError, Tag("execution_error")], + Annotated[ValidationError, Tag("validation_error")], + Annotated[TruncateFields, Tag("truncate_fields")], + Annotated[RemoveFields, Tag("remove_fields")], + Annotated[CorruptValues, Tag("corrupt_values")], + ], + Discriminator("effect_type"), +] +"""Discriminated union of all concrete tool effects. + +Used in ChaosCase.effects to ensure full round-trip serialization fidelity +with Pydantic's model_dump() / model_validate(). +""" diff --git a/src/strands_evals/chaos/experiment.py b/src/strands_evals/chaos/experiment.py index e841da33..e7bc977a 100644 --- a/src/strands_evals/chaos/experiment.py +++ b/src/strands_evals/chaos/experiment.py @@ -36,16 +36,16 @@ class ChaosExperiment: ChaosExperiment, ChaosPlugin, ) - from strands_evals.chaos.effects import ToolCallFailure + from strands_evals.chaos.effects import Timeout, NetworkError chaos = ChaosPlugin() cases = [Case(input="Find flights to Tokyo", name="flight_search")] - effect_sets = [ - {"search_tool": [ToolCallFailure(error_type="timeout")]}, - {"database_tool": [ToolCallFailure(error_type="network_error")]}, - ] - chaos_cases = ChaosCase.expand(cases, effect_sets) + effect_maps = { + "search_timeout": {"tool_effects": {"search_tool": [Timeout()]}}, + "db_network_error": {"tool_effects": {"database_tool": [NetworkError()]}}, + } + chaos_cases = ChaosCase.expand(cases, effect_maps) def my_task(case): agent = Agent(tools=[search_tool, database_tool], plugins=[chaos]) diff --git a/src/strands_evals/chaos/plugin.py b/src/strands_evals/chaos/plugin.py index 8d1dbed9..28c1ecce 100644 --- a/src/strands_evals/chaos/plugin.py +++ b/src/strands_evals/chaos/plugin.py @@ -55,16 +55,16 @@ def __init__(self) -> None: def before_tool_call(self, event: BeforeToolCallEvent) -> None: """Intercept tool calls to inject pre-hook (error) effects. - For ToolCallFailure effects (with error_type='timeout', 'network_error', - etc.), cancels the tool call with the effect's error_message before the - tool executes. + For pre-hook effects (Timeout, NetworkError, ExecutionError, + ValidationError), cancels the tool call with the effect's error_message + before the tool executes. """ chaos_case = _current_chaos_case.get() - if chaos_case is None or not chaos_case.effects: + if chaos_case is None or not chaos_case.tool_effects: return tool_name = event.tool_use.get("name", "") - effects = chaos_case.effects.get(tool_name, []) + effects = chaos_case.tool_effects.get(tool_name, []) if not effects: return @@ -85,11 +85,11 @@ def after_tool_call(self, event: AfterToolCallEvent) -> None: applies effect.apply() to JSON content blocks in the tool response. """ chaos_case = _current_chaos_case.get() - if chaos_case is None or not chaos_case.effects: + if chaos_case is None or not chaos_case.tool_effects: return tool_name = event.tool_use.get("name", "") - effects = chaos_case.effects.get(tool_name, []) + effects = chaos_case.tool_effects.get(tool_name, []) if not effects: return diff --git a/tests/strands_evals/chaos/test_case.py b/tests/strands_evals/chaos/test_case.py index 7a57ea78..04b3f676 100644 --- a/tests/strands_evals/chaos/test_case.py +++ b/tests/strands_evals/chaos/test_case.py @@ -2,7 +2,7 @@ from strands_evals import Case from strands_evals.chaos import ChaosCase -from strands_evals.chaos.effects import CorruptValues, ToolCallFailure, TruncateFields +from strands_evals.chaos.effects import CorruptValues, Timeout, TruncateFields class TestChaosCase: @@ -10,40 +10,44 @@ class TestChaosCase: def test_baseline_case_has_no_effects(self): case = ChaosCase(name="baseline", input="hello") - assert case.effects == {} + assert case.tool_effects == {} def test_case_with_effects(self): case = ChaosCase( name="search_timeout", input="hello", - effects={"search_tool": [ToolCallFailure(error_type="timeout")]}, + effects={"tool_effects": {"search_tool": [Timeout()]}}, ) - assert len(case.effects) == 1 - assert isinstance(case.effects["search_tool"][0], ToolCallFailure) + assert len(case.tool_effects) == 1 + assert isinstance(case.tool_effects["search_tool"][0], Timeout) def test_case_with_multiple_tools(self): case = ChaosCase( name="compound", input="hello", effects={ - "search_tool": [ToolCallFailure(error_type="timeout")], - "db_tool": [CorruptValues(corrupt_ratio=0.8)], + "tool_effects": { + "search_tool": [Timeout()], + "db_tool": [CorruptValues(corrupt_ratio=0.8)], + } }, ) - assert len(case.effects) == 2 + assert len(case.tool_effects) == 2 def test_case_with_multiple_effects_per_tool(self): case = ChaosCase( name="multi_effect", input="hello", effects={ - "tool_a": [ - TruncateFields(max_length=5), - CorruptValues(corrupt_ratio=0.3), - ], + "tool_effects": { + "tool_a": [ + TruncateFields(max_length=5), + CorruptValues(corrupt_ratio=0.3), + ], + } }, ) - assert len(case.effects["tool_a"]) == 2 + assert len(case.tool_effects["tool_a"]) == 2 def test_inherits_case_fields(self): case = ChaosCase( @@ -52,7 +56,7 @@ def test_inherits_case_fields(self): expected_output="world", expected_trajectory=["tool_a"], metadata={"key": "value"}, - effects={"tool_a": [ToolCallFailure()]}, + effects={"tool_effects": {"tool_a": [Timeout()]}}, ) assert case.input == "hello" assert case.expected_output == "world" @@ -63,11 +67,44 @@ def test_repr_shows_effects(self): case = ChaosCase( name="test", input="hello", - effects={"tool": [ToolCallFailure()]}, + effects={"tool_effects": {"tool": [Timeout()]}}, ) repr_str = repr(case) assert "test" in repr_str - assert "ToolCallFailure" in repr_str + assert "Timeout" in repr_str + + def test_model_dump_preserves_concrete_fields(self): + """Verify discriminated union serialization preserves all concrete fields.""" + case = ChaosCase( + name="serialization_test", + input="hello", + effects={"tool_effects": {"search_tool": [Timeout(error_message="custom timeout")]}}, + ) + dumped = case.model_dump() + effect_data = dumped["effects"]["tool_effects"]["search_tool"][0] + assert effect_data["effect_type"] == "timeout" + assert effect_data["error_message"] == "custom timeout" + assert effect_data["apply_rate"] == 1.0 + + def test_model_dump_roundtrip(self): + """Verify full round-trip serialization/deserialization.""" + case = ChaosCase( + name="roundtrip", + input="hello", + effects={ + "tool_effects": { + "tool_a": [Timeout()], + "tool_b": [TruncateFields(max_length=5), CorruptValues(corrupt_ratio=0.7)], + } + }, + ) + dumped = case.model_dump() + restored = ChaosCase.model_validate(dumped) + assert isinstance(restored.tool_effects["tool_a"][0], Timeout) + assert isinstance(restored.tool_effects["tool_b"][0], TruncateFields) + assert restored.tool_effects["tool_b"][0].max_length == 5 + assert isinstance(restored.tool_effects["tool_b"][1], CorruptValues) + assert restored.tool_effects["tool_b"][1].corrupt_ratio == 0.7 class TestChaosCaseExpand: @@ -79,8 +116,8 @@ def test_expand_with_baseline(self): Case(name="case_b", input="world"), ] effect_maps = { - "search_timeout": {"search_tool": [ToolCallFailure(error_type="timeout")]}, - "db_corrupt": {"db_tool": [CorruptValues(corrupt_ratio=0.8)]}, + "search_timeout": {"tool_effects": {"search_tool": [Timeout()]}}, + "db_corrupt": {"tool_effects": {"db_tool": [CorruptValues(corrupt_ratio=0.8)]}}, } result = ChaosCase.expand(cases, effect_maps, include_no_effect_baseline=True) # 2 cases × (2 effect maps + 1 baseline) = 6 @@ -92,7 +129,7 @@ def test_expand_without_baseline(self): Case(name="case_b", input="world"), ] effect_maps = { - "search_timeout": {"search_tool": [ToolCallFailure(error_type="timeout")]}, + "search_timeout": {"tool_effects": {"search_tool": [Timeout()]}}, } result = ChaosCase.expand(cases, effect_maps) # 2 cases × 1 effect map = 2 (no baseline by default) @@ -100,14 +137,14 @@ def test_expand_without_baseline(self): def test_expand_baseline_names(self): cases = [Case(name="case_a", input="hello")] - effect_maps = {"timeout": {"tool": [ToolCallFailure()]}} + effect_maps = {"timeout": {"tool_effects": {"tool": [Timeout()]}}} result = ChaosCase.expand(cases, effect_maps, include_no_effect_baseline=True) names = [c.name for c in result] assert "case_a|baseline" in names def test_expand_uses_dict_keys_as_names(self): cases = [Case(name="case_a", input="hello")] - effect_maps = {"search_timeout": {"search_tool": [ToolCallFailure(error_type="timeout")]}} + effect_maps = {"search_timeout": {"tool_effects": {"search_tool": [Timeout()]}}} result = ChaosCase.expand(cases, effect_maps) assert result[0].name == "case_a|search_timeout" @@ -115,8 +152,10 @@ def test_expand_compound_effect_name(self): cases = [Case(name="case_a", input="hello")] effect_maps = { "multi_failure": { - "search_tool": [ToolCallFailure(error_type="timeout")], - "db_tool": [CorruptValues()], + "tool_effects": { + "search_tool": [Timeout()], + "db_tool": [CorruptValues()], + } } } result = ChaosCase.expand(cases, effect_maps) @@ -124,7 +163,7 @@ def test_expand_compound_effect_name(self): def test_expand_unique_session_ids(self): cases = [Case(name="case_a", input="hello"), Case(name="case_b", input="world")] - effect_maps = {"timeout": {"tool": [ToolCallFailure()]}} + effect_maps = {"timeout": {"tool_effects": {"tool": [Timeout()]}}} result = ChaosCase.expand(cases, effect_maps) session_ids = [c.session_id for c in result] assert len(session_ids) == len(set(session_ids)) @@ -139,7 +178,7 @@ def test_expand_preserves_case_fields(self): metadata={"key": "value"}, ) ] - effect_maps = {"timeout": {"tool": [ToolCallFailure()]}} + effect_maps = {"timeout": {"tool_effects": {"tool": [Timeout()]}}} result = ChaosCase.expand(cases, effect_maps) expanded = result[0] assert expanded.input == "hello" @@ -149,10 +188,10 @@ def test_expand_preserves_case_fields(self): def test_expand_baseline_has_empty_effects(self): cases = [Case(name="case_a", input="hello")] - effect_maps = {"timeout": {"tool": [ToolCallFailure()]}} + effect_maps = {"timeout": {"tool_effects": {"tool": [Timeout()]}}} result = ChaosCase.expand(cases, effect_maps, include_no_effect_baseline=True) baseline = [c for c in result if "baseline" in c.name][0] - assert baseline.effects == {} + assert baseline.tool_effects == {} def test_expand_empty_effect_maps_with_baseline(self): cases = [Case(name="case_a", input="hello")] @@ -169,7 +208,7 @@ def test_expand_empty_effect_maps_without_baseline(self): def test_expand_case_without_name(self): cases = [Case(input="hello")] - effect_maps = {"timeout": {"tool": [ToolCallFailure()]}} + effect_maps = {"timeout": {"tool_effects": {"tool": [Timeout()]}}} result = ChaosCase.expand(cases, effect_maps, include_no_effect_baseline=True) names = [c.name for c in result] assert "baseline" in names diff --git a/tests/strands_evals/chaos/test_effects.py b/tests/strands_evals/chaos/test_effects.py index c55821d9..8ddde000 100644 --- a/tests/strands_evals/chaos/test_effects.py +++ b/tests/strands_evals/chaos/test_effects.py @@ -2,41 +2,97 @@ import random -import pytest - from strands_evals.chaos.effects import ( CorruptValues, + ExecutionError, + NetworkError, RemoveFields, - ToolCallFailure, + Timeout, TruncateFields, + ValidationError, ) -class TestToolCallFailure: - """Tests for the ToolCallFailure pre-hook effect.""" - - @pytest.mark.parametrize( - "error_type,expected_message", - [ - ("timeout", "Tool call timed out"), - ("network_error", "Network unreachable"), - ("execution_error", "Tool execution failed"), - ("validation_error", "Tool input validation failed"), - ], - ) - def test_apply_returns_default_message(self, error_type, expected_message): - effect = ToolCallFailure(error_type=error_type) - assert effect.apply() == expected_message - - def test_apply_returns_custom_message_when_provided(self): - effect = ToolCallFailure(error_type="timeout", error_message="Custom timeout msg") +class TestTimeout: + """Tests for the Timeout pre-hook effect.""" + + def test_apply_returns_default_message(self): + effect = Timeout() + assert effect.apply() == "Tool call timed out" + + def test_apply_returns_custom_message(self): + effect = Timeout(error_message="Custom timeout msg") assert effect.apply() == "Custom timeout msg" + def test_hook_is_pre(self): + assert Timeout.hook == "pre" + + def test_effect_type(self): + effect = Timeout() + assert effect.effect_type == "timeout" + def test_apply_rate_defaults_to_one(self): - effect = ToolCallFailure() + effect = Timeout() assert effect.apply_rate == 1.0 +class TestNetworkError: + """Tests for the NetworkError pre-hook effect.""" + + def test_apply_returns_default_message(self): + effect = NetworkError() + assert effect.apply() == "Network unreachable" + + def test_apply_returns_custom_message(self): + effect = NetworkError(error_message="Connection refused on port 5432") + assert effect.apply() == "Connection refused on port 5432" + + def test_hook_is_pre(self): + assert NetworkError.hook == "pre" + + def test_effect_type(self): + effect = NetworkError() + assert effect.effect_type == "network_error" + + +class TestExecutionError: + """Tests for the ExecutionError pre-hook effect.""" + + def test_apply_returns_default_message(self): + effect = ExecutionError() + assert effect.apply() == "Tool execution failed" + + def test_apply_returns_custom_message(self): + effect = ExecutionError(error_message="Segfault in native code") + assert effect.apply() == "Segfault in native code" + + def test_hook_is_pre(self): + assert ExecutionError.hook == "pre" + + def test_effect_type(self): + effect = ExecutionError() + assert effect.effect_type == "execution_error" + + +class TestValidationError: + """Tests for the ValidationError pre-hook effect.""" + + def test_apply_returns_default_message(self): + effect = ValidationError() + assert effect.apply() == "Tool input validation failed" + + def test_apply_returns_custom_message(self): + effect = ValidationError(error_message="Missing required field: origin") + assert effect.apply() == "Missing required field: origin" + + def test_hook_is_pre(self): + assert ValidationError.hook == "pre" + + def test_effect_type(self): + effect = ValidationError() + assert effect.effect_type == "validation_error" + + class TestTruncateFields: """Tests for the TruncateFields post-hook effect.""" @@ -76,6 +132,10 @@ def test_zero_max_length_truncates_all_strings(self): result = effect.apply(response) assert result["text"] == "" + def test_effect_type(self): + effect = TruncateFields() + assert effect.effect_type == "truncate_fields" + class TestRemoveFields: """Tests for the RemoveFields post-hook effect.""" @@ -117,6 +177,10 @@ def test_single_field_always_removed(self): result = effect.apply(response) assert len(result) == 0 + def test_effect_type(self): + effect = RemoveFields() + assert effect.effect_type == "remove_fields" + class TestCorruptValues: """Tests for the CorruptValues post-hook effect.""" @@ -161,3 +225,7 @@ def test_corrupted_value_differs_from_original(self): response = {"key": "unique_original_value"} result = effect.apply(response) assert result["key"] != "unique_original_value" + + def test_effect_type(self): + effect = CorruptValues() + assert effect.effect_type == "corrupt_values" diff --git a/tests/strands_evals/chaos/test_experiment.py b/tests/strands_evals/chaos/test_experiment.py index 4fb0c574..47e5a1e2 100644 --- a/tests/strands_evals/chaos/test_experiment.py +++ b/tests/strands_evals/chaos/test_experiment.py @@ -5,7 +5,7 @@ from strands_evals import Case from strands_evals.chaos import ChaosCase, ChaosExperiment from strands_evals.chaos._context import _current_chaos_case -from strands_evals.chaos.effects import CorruptValues, ToolCallFailure +from strands_evals.chaos.effects import CorruptValues, Timeout from strands_evals.evaluators.evaluator import Evaluator from strands_evals.types import EvaluationData, EvaluationOutput @@ -28,8 +28,8 @@ def cases(): @pytest.fixture def effect_maps(): return { - "search_timeout": {"search_tool": [ToolCallFailure(error_type="timeout")]}, - "db_corrupt": {"db_tool": [CorruptValues(corrupt_ratio=0.8)]}, + "search_timeout": {"tool_effects": {"search_tool": [Timeout()]}}, + "db_corrupt": {"tool_effects": {"db_tool": [CorruptValues(corrupt_ratio=0.8)]}}, } @@ -94,7 +94,7 @@ def capturing_task(case: ChaosCase): def test_context_var_reset_on_task_exception(self, evaluator): """Verify the ContextVar is reset even if the task raises.""" cases = [Case(name="failing", input="x")] - effect_maps = {"chaos": {"t": [ToolCallFailure()]}} + effect_maps = {"chaos": {"tool_effects": {"t": [Timeout()]}}} chaos_cases = ChaosCase.expand(cases, effect_maps, include_no_effect_baseline=True) call_count = [0] @@ -130,3 +130,51 @@ def task(case: ChaosCase): report = reports[0] # 2 cases × 3 conditions = 6 scores assert len(report.scores) == 6 + + def test_run_evaluations_rejects_async_task(self, cases, effect_maps, evaluator): + """Verify run_evaluations raises ValueError for async tasks.""" + + async def async_task(case: ChaosCase): + return "output" + + chaos_cases = ChaosCase.expand(cases, effect_maps) + experiment = ChaosExperiment(cases=chaos_cases, evaluators=[evaluator]) + + with pytest.raises(ValueError, match="Async task is not supported"): + experiment.run_evaluations(task=async_task) + + @pytest.mark.asyncio + async def test_run_evaluations_async_with_async_task(self, cases, effect_maps, evaluator): + """Verify run_evaluations_async works with an actual async task.""" + chaos_cases = ChaosCase.expand(cases, effect_maps) + experiment = ChaosExperiment(cases=chaos_cases, evaluators=[evaluator]) + + async def async_task(case: ChaosCase): + active = _current_chaos_case.get() + assert active is case + return "async_output" + + reports = await experiment.run_evaluations_async(task=async_task, max_workers=2) + assert len(reports) >= 1 + + @pytest.mark.asyncio + async def test_run_evaluations_async_context_var_reset(self, cases, effect_maps, evaluator): + """Verify the ContextVar is properly reset after async execution.""" + chaos_cases = ChaosCase.expand(cases, effect_maps, include_no_effect_baseline=True) + experiment = ChaosExperiment(cases=chaos_cases, evaluators=[evaluator]) + + observed_cases = [] + + async def async_capturing_task(case: ChaosCase): + active_case = _current_chaos_case.get() + observed_cases.append((case.name, active_case.name if active_case else None)) + return "output" + + await experiment.run_evaluations_async(task=async_capturing_task, max_workers=1) + + # All observations should have matching case names + for case_name, active_name in observed_cases: + assert case_name == active_name + + # ContextVar should be reset + assert _current_chaos_case.get() is None diff --git a/tests/strands_evals/chaos/test_plugin.py b/tests/strands_evals/chaos/test_plugin.py index 45b1844f..6bf2a94a 100644 --- a/tests/strands_evals/chaos/test_plugin.py +++ b/tests/strands_evals/chaos/test_plugin.py @@ -8,7 +8,8 @@ from strands_evals.chaos import ChaosCase, ChaosPlugin from strands_evals.chaos._context import _current_chaos_case from strands_evals.chaos.effects import ( - ToolCallFailure, + NetworkError, + Timeout, TruncateFields, ) @@ -55,7 +56,7 @@ def test_case_without_matching_tool_does_nothing(self, chaos_plugin, before_even case = ChaosCase( name="other_tool_fails", input="test", - effects={"other_tool": [ToolCallFailure(error_type="timeout")]}, + effects={"tool_effects": {"other_tool": [Timeout()]}}, ) token = _current_chaos_case.set(case) try: @@ -68,7 +69,7 @@ def test_pre_hook_effect_cancels_tool(self, chaos_plugin, before_event): case = ChaosCase( name="search_timeout", input="test", - effects={"search_tool": [ToolCallFailure(error_type="timeout")]}, + effects={"tool_effects": {"search_tool": [Timeout()]}}, ) token = _current_chaos_case.set(case) try: @@ -81,7 +82,7 @@ def test_post_hook_effect_does_not_cancel_tool(self, chaos_plugin, before_event) case = ChaosCase( name="search_truncated", input="test", - effects={"search_tool": [TruncateFields(max_length=5)]}, + effects={"tool_effects": {"search_tool": [TruncateFields(max_length=5)]}}, ) token = _current_chaos_case.set(case) try: @@ -95,10 +96,12 @@ def test_first_pre_hook_effect_wins(self, chaos_plugin, before_event): name="multi_pre", input="test", effects={ - "search_tool": [ - ToolCallFailure(error_type="timeout"), - ToolCallFailure(error_type="network_error"), - ] + "tool_effects": { + "search_tool": [ + Timeout(), + NetworkError(), + ] + } }, ) token = _current_chaos_case.set(case) @@ -125,7 +128,7 @@ def test_case_without_matching_tool_does_nothing(self, chaos_plugin, after_event case = ChaosCase( name="other_tool", input="test", - effects={"other_tool": [TruncateFields(max_length=3)]}, + effects={"tool_effects": {"other_tool": [TruncateFields(max_length=3)]}}, ) token = _current_chaos_case.set(case) try: @@ -139,7 +142,7 @@ def test_post_hook_corrupts_json_text_blocks(self, chaos_plugin, after_event): case = ChaosCase( name="truncate", input="test", - effects={"search_tool": [TruncateFields(max_length=3)]}, + effects={"tool_effects": {"search_tool": [TruncateFields(max_length=3)]}}, ) token = _current_chaos_case.set(case) try: @@ -154,7 +157,7 @@ def test_pre_hook_effect_ignored_in_after_hook(self, chaos_plugin, after_event): case = ChaosCase( name="pre_only", input="test", - effects={"search_tool": [ToolCallFailure(error_type="timeout")]}, + effects={"tool_effects": {"search_tool": [Timeout()]}}, ) token = _current_chaos_case.set(case) try: @@ -172,7 +175,7 @@ def test_none_result_is_skipped(self, chaos_plugin): case = ChaosCase( name="truncate", input="test", - effects={"search_tool": [TruncateFields(max_length=3)]}, + effects={"tool_effects": {"search_tool": [TruncateFields(max_length=3)]}}, ) token = _current_chaos_case.set(case) try: @@ -193,7 +196,7 @@ def test_plain_text_truncation(self, chaos_plugin): case = ChaosCase( name="truncate", input="test", - effects={"search_tool": [TruncateFields(max_length=4)]}, + effects={"tool_effects": {"search_tool": [TruncateFields(max_length=4)]}}, ) token = _current_chaos_case.set(case) try: @@ -211,7 +214,7 @@ def test_apply_rate_zero_skips_pre_hook_effect(self, chaos_plugin, before_event) case = ChaosCase( name="never_fires", input="test", - effects={"search_tool": [ToolCallFailure(error_type="timeout", apply_rate=0.0)]}, + effects={"tool_effects": {"search_tool": [Timeout(apply_rate=0.0)]}}, ) token = _current_chaos_case.set(case) try: @@ -228,7 +231,7 @@ def test_apply_rate_one_always_fires_pre_hook(self, chaos_plugin, before_event): case = ChaosCase( name="always_fires", input="test", - effects={"search_tool": [ToolCallFailure(error_type="timeout", apply_rate=1.0)]}, + effects={"tool_effects": {"search_tool": [Timeout(apply_rate=1.0)]}}, ) token = _current_chaos_case.set(case) try: @@ -242,7 +245,7 @@ def test_apply_rate_zero_skips_post_hook_effect(self, chaos_plugin, after_event) case = ChaosCase( name="never_truncates", input="test", - effects={"search_tool": [TruncateFields(max_length=3, apply_rate=0.0)]}, + effects={"tool_effects": {"search_tool": [TruncateFields(max_length=3, apply_rate=0.0)]}}, ) token = _current_chaos_case.set(case) try: @@ -257,7 +260,7 @@ def test_apply_rate_one_always_fires_post_hook(self, chaos_plugin, after_event): case = ChaosCase( name="always_truncates", input="test", - effects={"search_tool": [TruncateFields(max_length=3, apply_rate=1.0)]}, + effects={"tool_effects": {"search_tool": [TruncateFields(max_length=3, apply_rate=1.0)]}}, ) token = _current_chaos_case.set(case) try: From d5bcd6e24877bb1f061f49b62e88281cfdfd3964 Mon Sep 17 00:00:00 2001 From: Darren Wang Date: Fri, 22 May 2026 16:43:33 +0000 Subject: [PATCH 9/9] remove apply rate; limit 1 effect per tool --- src/strands_evals/chaos/case.py | 13 ++- src/strands_evals/chaos/effects.py | 8 -- src/strands_evals/chaos/plugin.py | 6 -- tests/strands_evals/chaos/test_case.py | 34 ++++---- tests/strands_evals/chaos/test_effects.py | 4 - tests/strands_evals/chaos/test_plugin.py | 99 ++++------------------- 6 files changed, 44 insertions(+), 120 deletions(-) diff --git a/src/strands_evals/chaos/case.py b/src/strands_evals/chaos/case.py index 4419548f..52e68d21 100644 --- a/src/strands_evals/chaos/case.py +++ b/src/strands_evals/chaos/case.py @@ -7,7 +7,7 @@ import uuid -from pydantic import Field +from pydantic import Field, model_validator from typing_extensions import Generic from ..case import Case @@ -61,6 +61,17 @@ class ChaosCase(Case, Generic[InputT, OutputT]): "tool_name -> list of effects. Empty dict means baseline (no chaos).", ) + @model_validator(mode="after") + def _validate_tool_effects(self) -> "ChaosCase": + """Validate tool effects configuration.""" + for tool_name, effects_list in self.tool_effects.items(): + if len(effects_list) > 1: + raise ValueError( + f"Tool '{tool_name}' has {len(effects_list)} effects — only 1 is allowed per " + f"ChaosCase. Use separate ChaosCase instances to test effects independently." + ) + return self + @classmethod def expand( cls, diff --git a/src/strands_evals/chaos/effects.py b/src/strands_evals/chaos/effects.py index e3944629..7a22dd1b 100644 --- a/src/strands_evals/chaos/effects.py +++ b/src/strands_evals/chaos/effects.py @@ -23,19 +23,11 @@ class ChaosEffect(BaseModel): """Base for all chaos effects. Attributes: - apply_rate: Probability that this effect fires, defaults to 1 (always fire). hook: Whether this effect fires pre-call ("pre") or post-call ("post"). """ hook: ClassVar[Literal["pre", "post"]] - apply_rate: float = Field( - default=1.0, - ge=0.0, - le=1.0, - description="Probability that this effect fires (1.0 = always).", - ) - @abstractmethod def apply(self, context: Any = None) -> Any: """Apply the chaos effect to the given context and return the result.""" diff --git a/src/strands_evals/chaos/plugin.py b/src/strands_evals/chaos/plugin.py index 28c1ecce..9d3aa7c8 100644 --- a/src/strands_evals/chaos/plugin.py +++ b/src/strands_evals/chaos/plugin.py @@ -9,7 +9,6 @@ import json import logging -import random from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent from strands.plugins import Plugin, hook @@ -71,8 +70,6 @@ def before_tool_call(self, event: BeforeToolCallEvent) -> None: # First pre-hook effect wins (tool is cancelled once) for effect in effects: if effect.hook == "pre": - if random.random() > effect.apply_rate: - continue event.cancel_tool = effect.apply() logger.info("effect=<%s>, tool=<%s> | injected chaos pre-hook", type(effect).__name__, tool_name) return @@ -98,9 +95,6 @@ def after_tool_call(self, event: AfterToolCallEvent) -> None: if effect.hook != "post": continue - if random.random() > effect.apply_rate: - continue - if event.result is None: continue diff --git a/tests/strands_evals/chaos/test_case.py b/tests/strands_evals/chaos/test_case.py index 04b3f676..549e95a5 100644 --- a/tests/strands_evals/chaos/test_case.py +++ b/tests/strands_evals/chaos/test_case.py @@ -1,5 +1,7 @@ """Unit tests for ChaosCase.""" +import pytest + from strands_evals import Case from strands_evals.chaos import ChaosCase from strands_evals.chaos.effects import CorruptValues, Timeout, TruncateFields @@ -35,19 +37,20 @@ def test_case_with_multiple_tools(self): assert len(case.tool_effects) == 2 def test_case_with_multiple_effects_per_tool(self): - case = ChaosCase( - name="multi_effect", - input="hello", - effects={ - "tool_effects": { - "tool_a": [ - TruncateFields(max_length=5), - CorruptValues(corrupt_ratio=0.3), - ], - } - }, - ) - assert len(case.tool_effects["tool_a"]) == 2 + """Multiple effects for one tool should be rejected.""" + with pytest.raises(ValueError, match="only 1 is allowed"): + ChaosCase( + name="multi_effect", + input="hello", + effects={ + "tool_effects": { + "tool_a": [ + TruncateFields(max_length=5), + CorruptValues(corrupt_ratio=0.3), + ], + } + }, + ) def test_inherits_case_fields(self): case = ChaosCase( @@ -84,7 +87,6 @@ def test_model_dump_preserves_concrete_fields(self): effect_data = dumped["effects"]["tool_effects"]["search_tool"][0] assert effect_data["effect_type"] == "timeout" assert effect_data["error_message"] == "custom timeout" - assert effect_data["apply_rate"] == 1.0 def test_model_dump_roundtrip(self): """Verify full round-trip serialization/deserialization.""" @@ -94,7 +96,7 @@ def test_model_dump_roundtrip(self): effects={ "tool_effects": { "tool_a": [Timeout()], - "tool_b": [TruncateFields(max_length=5), CorruptValues(corrupt_ratio=0.7)], + "tool_b": [TruncateFields(max_length=5)], } }, ) @@ -103,8 +105,6 @@ def test_model_dump_roundtrip(self): assert isinstance(restored.tool_effects["tool_a"][0], Timeout) assert isinstance(restored.tool_effects["tool_b"][0], TruncateFields) assert restored.tool_effects["tool_b"][0].max_length == 5 - assert isinstance(restored.tool_effects["tool_b"][1], CorruptValues) - assert restored.tool_effects["tool_b"][1].corrupt_ratio == 0.7 class TestChaosCaseExpand: diff --git a/tests/strands_evals/chaos/test_effects.py b/tests/strands_evals/chaos/test_effects.py index 8ddde000..fdfd5890 100644 --- a/tests/strands_evals/chaos/test_effects.py +++ b/tests/strands_evals/chaos/test_effects.py @@ -31,10 +31,6 @@ def test_effect_type(self): effect = Timeout() assert effect.effect_type == "timeout" - def test_apply_rate_defaults_to_one(self): - effect = Timeout() - assert effect.apply_rate == 1.0 - class TestNetworkError: """Tests for the NetworkError pre-hook effect.""" diff --git a/tests/strands_evals/chaos/test_plugin.py b/tests/strands_evals/chaos/test_plugin.py index 6bf2a94a..6ff99e4a 100644 --- a/tests/strands_evals/chaos/test_plugin.py +++ b/tests/strands_evals/chaos/test_plugin.py @@ -91,25 +91,21 @@ def test_post_hook_effect_does_not_cancel_tool(self, chaos_plugin, before_event) finally: _current_chaos_case.reset(token) - def test_first_pre_hook_effect_wins(self, chaos_plugin, before_event): - case = ChaosCase( - name="multi_pre", - input="test", - effects={ - "tool_effects": { - "search_tool": [ - Timeout(), - NetworkError(), - ] - } - }, - ) - token = _current_chaos_case.set(case) - try: - chaos_plugin.before_tool_call(before_event) - assert before_event.cancel_tool == "Tool call timed out" - finally: - _current_chaos_case.reset(token) + def test_multiple_pre_hook_effects(self, chaos_plugin, before_event): + """Multiple effects per tool should be rejected.""" + with pytest.raises(ValueError, match="only 1 is allowed"): + ChaosCase( + name="multi_pre", + input="test", + effects={ + "tool_effects": { + "search_tool": [ + Timeout(), + NetworkError(), + ] + } + }, + ) class TestChaosPluginAfterToolCall: @@ -204,68 +200,3 @@ def test_plain_text_truncation(self, chaos_plugin): assert event.result["content"][0]["text"] == "This" finally: _current_chaos_case.reset(token) - - -class TestApplyRate: - """Tests for the apply_rate probability check in ChaosPlugin.""" - - def test_apply_rate_zero_skips_pre_hook_effect(self, chaos_plugin, before_event): - """Effect with apply_rate=0.0 should never fire.""" - case = ChaosCase( - name="never_fires", - input="test", - effects={"tool_effects": {"search_tool": [Timeout(apply_rate=0.0)]}}, - ) - token = _current_chaos_case.set(case) - try: - # Run multiple times to confirm it never fires - for _ in range(20): - before_event.cancel_tool = None - chaos_plugin.before_tool_call(before_event) - assert before_event.cancel_tool is None - finally: - _current_chaos_case.reset(token) - - def test_apply_rate_one_always_fires_pre_hook(self, chaos_plugin, before_event): - """Effect with apply_rate=1.0 should always fire.""" - case = ChaosCase( - name="always_fires", - input="test", - effects={"tool_effects": {"search_tool": [Timeout(apply_rate=1.0)]}}, - ) - token = _current_chaos_case.set(case) - try: - chaos_plugin.before_tool_call(before_event) - assert before_event.cancel_tool == "Tool call timed out" - finally: - _current_chaos_case.reset(token) - - def test_apply_rate_zero_skips_post_hook_effect(self, chaos_plugin, after_event): - """Post-hook effect with apply_rate=0.0 should never fire.""" - case = ChaosCase( - name="never_truncates", - input="test", - effects={"tool_effects": {"search_tool": [TruncateFields(max_length=3, apply_rate=0.0)]}}, - ) - token = _current_chaos_case.set(case) - try: - original_content = after_event.result["content"][0]["text"] - chaos_plugin.after_tool_call(after_event) - assert after_event.result["content"][0]["text"] == original_content - finally: - _current_chaos_case.reset(token) - - def test_apply_rate_one_always_fires_post_hook(self, chaos_plugin, after_event): - """Post-hook effect with apply_rate=1.0 should always fire.""" - case = ChaosCase( - name="always_truncates", - input="test", - effects={"tool_effects": {"search_tool": [TruncateFields(max_length=3, apply_rate=1.0)]}}, - ) - token = _current_chaos_case.set(case) - try: - chaos_plugin.after_tool_call(after_event) - corrupted = json.loads(after_event.result["content"][0]["text"]) - assert corrupted["title"] == "Lon" - finally: - _current_chaos_case.reset(token)