diff --git a/lionagi/work/__init__.py b/lionagi/work/__init__.py new file mode 100644 index 000000000..b57b2e3e1 --- /dev/null +++ b/lionagi/work/__init__.py @@ -0,0 +1,53 @@ +# Copyright (c) 2023-2025, HaiyangLi +# SPDX-License-Identifier: Apache-2.0 + +"""lionagi.work — WorkForm + Rule/RuleSet (R2-1 slim re-cut). + +Public surface:: + + from lionagi.work import ( + FieldSpec, + FieldType, + FormStatus, + VALID_TRANSITIONS, + WorkForm, + fill_form, + validate_form, + Rule, + RuleSet, + CheckKind, + REGEX_MAX_INPUT_LENGTH, + ) +""" + +from .form import ( + VALID_TRANSITIONS, + FieldSpec, + FieldType, + FormStatus, + WorkForm, + fill_form, + validate_form, +) +from .rules import ( + REGEX_MAX_INPUT_LENGTH, + CheckKind, + Rule, + RuleSet, +) + +__all__ = ( + # form + "FieldSpec", + "FieldType", + "FormStatus", + "VALID_TRANSITIONS", + "WorkForm", + "fill_form", + "validate_form", + # rules + "Rule", + "RuleSet", + "CheckKind", + "REGEX_MAX_INPUT_LENGTH", +) diff --git a/lionagi/work/form.py b/lionagi/work/form.py new file mode 100644 index 000000000..d29bfe5aa --- /dev/null +++ b/lionagi/work/form.py @@ -0,0 +1,341 @@ +# Copyright (c) 2023-2025, HaiyangLi +# SPDX-License-Identifier: Apache-2.0 + +"""WorkForm: structured input/output container for worker tasks. + +A WorkForm captures a typed specification (FieldSpec) for every input +and output slot a worker needs, tracks live values, and records the +validation status of those values. The lifecycle is: + + draft → filled → validated (happy path) + draft → filled → error (validation failed) + validated → submitted (engine accepted it) + submitted → completed (worker finished) + error → draft (allow re-opening for correction) +""" + +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Any, Literal + +from pydantic import ConfigDict, Field, model_validator + +from lionagi.protocols.generic.element import Element + +if TYPE_CHECKING: + from .rules import RuleSet + +__all__ = ( + "FieldSpec", + "FieldType", + "FormStatus", + "VALID_TRANSITIONS", + "WorkForm", + "fill_form", + "validate_form", +) + +# Allowed value-type labels. "list" and "dict" are JSON containers. +FieldType = Literal["str", "int", "float", "bool", "list", "dict"] + +_PYTHON_TYPE_MAP: dict[str, type] = { + "str": str, + "int": int, + "float": float, + "bool": bool, + "list": list, + "dict": dict, +} + +FormStatus = Literal["draft", "filled", "validated", "error", "submitted", "completed"] + +# Allowed lifecycle transitions. Any move not listed here is invalid. +VALID_TRANSITIONS: dict[str, frozenset[str]] = { + "draft": frozenset({"filled"}), + "filled": frozenset({"validated", "error"}), + "validated": frozenset({"submitted", "error"}), + "error": frozenset({"draft"}), # allow re-opening for correction + "submitted": frozenset({"completed", "error"}), + "completed": frozenset(), # terminal — no outgoing transitions +} + + +class FieldSpec(Element): + """Declaration of a single field inside a WorkForm. + + FieldSpec is a plain value object (no lifecycle, no graph identity needed), + but inherits from Element for UUID tracking and created_at timestamps. + + Attributes: + name: Machine-readable field name (alphanumeric + underscores, + must start with a letter or underscore). + type: Expected Python type expressed as a string literal. + required: When True, the form cannot be validated with this + field absent or None. + default: Value used when the field is absent and not required. + Must be compatible with the declared ``type`` at construction + time (validated eagerly). + description: Human-readable explanation of this field's purpose. + """ + + model_config = ConfigDict( + arbitrary_types_allowed=True, + use_enum_values=True, + populate_by_name=True, + extra="forbid", + ) + + name: str = Field(..., description="Field identifier (alphanumeric + underscores).") + type: FieldType = Field("str", description="Expected value type.") + required: bool = Field(True, description="Whether this field must be supplied.") + default: Any = Field(None, description="Default value when field is absent.") + description: str = Field("", description="Human-readable description.") + + @model_validator(mode="after") + def _validate_name_and_default(self) -> FieldSpec: + # Name must be a valid Python identifier (letters/digits/underscores, + # starting with a letter or underscore). + if not re.match(r"^[A-Za-z_][A-Za-z0-9_]*$", self.name): + raise ValueError( + f"Field name {self.name!r} must start with a letter or underscore " + "and contain only alphanumeric characters and underscores." + ) + + # Default value must be type-compatible when provided. + if self.default is not None: + target = _PYTHON_TYPE_MAP[self.type] + # Allow int default for float field (numeric widening). + if self.type == "float" and isinstance(self.default, int): + return self + if not isinstance(self.default, target): + raise ValueError( + f"FieldSpec {self.name!r}: default {self.default!r} is not " + f"compatible with declared type {self.type!r}." + ) + return self + + def coerce(self, value: Any) -> Any: + """Attempt to coerce *value* to this field's declared type. + + Returns the coerced value on success, raises ``TypeError`` on failure. + ``None`` is returned unchanged. + """ + if value is None: + return None + target = _PYTHON_TYPE_MAP[self.type] + if isinstance(value, target): + return value + # Numeric widening: int → float is allowed. + if self.type == "float" and isinstance(value, int): + return float(value) + # str → bool special case. + if self.type == "bool" and isinstance(value, str): + if value.lower() in {"true", "1", "yes"}: + return True + if value.lower() in {"false", "0", "no"}: + return False + # str → int / float. + if self.type in {"int", "float"} and isinstance(value, str): + try: + return target(value) + except ValueError: + pass + raise TypeError( + f"Field {self.name!r} expects type {self.type!r}, " + f"got {type(value).__name__!r} with value {value!r}." + ) + + +class WorkForm(Element): + """A structured data container for a single worker invocation. + + WorkForm inherits from :class:`~lionagi.protocols.generic.element.Element`, + gaining a UUID ``id``, ``created_at`` timestamp, and ``metadata`` dict + consistent with the rest of the lionagi ecosystem. + + The string ``form_id`` property is a convenience alias over ``str(self.id)`` + for human-readable references. + + WorkForm instances are *immutable by convention* — mutation helpers + (:func:`fill_form`, :func:`validate_form`, :meth:`transition_to`) + always return a *new* copy via ``model_copy``. + + Attributes: + title: Human-readable label shown in UI and logs. + fields: Ordered mapping from field name to its :class:`FieldSpec`. + values: Mutable mapping from field name to its current value. + status: Lifecycle status of this form instance. + validation_errors: List of human-readable error messages from the + last call to :func:`validate_form`. + """ + + model_config = ConfigDict( + arbitrary_types_allowed=True, + use_enum_values=True, + populate_by_name=True, + extra="forbid", + ) + + title: str = Field("", description="Human-readable form title.") + fields: dict[str, FieldSpec] = Field( + default_factory=dict, + description="Field name → FieldSpec mapping.", + ) + values: dict[str, Any] = Field( + default_factory=dict, + description="Current field values.", + ) + status: FormStatus = Field("draft", description="Form lifecycle status.") + validation_errors: list[str] = Field( + default_factory=list, + description="Errors from the most recent validation pass.", + ) + + @property + def form_id(self) -> str: + """Convenience alias: string representation of the Element UUID id.""" + return str(self.id) + + def get(self, name: str, default: Any = None) -> Any: + """Return the value for *name*, falling back to *default*.""" + return self.values.get(name, default) + + def field_names(self) -> list[str]: + """Return the list of declared field names.""" + return list(self.fields.keys()) + + def is_complete(self) -> bool: + """Return True when status is ``validated`` or ``completed``.""" + return self.status in {"validated", "completed"} + + def transition_to(self, new_status: FormStatus) -> WorkForm: + """Return a *new* WorkForm after validating the status transition. + + Args: + new_status: The desired next lifecycle status. + + Returns: + A new WorkForm with ``status`` set to *new_status*. + + Raises: + ValueError: If the transition from the current status to + *new_status* is not permitted by :data:`VALID_TRANSITIONS`. + """ + allowed = VALID_TRANSITIONS.get(self.status, frozenset()) + if new_status not in allowed: + raise ValueError( + f"Invalid transition {self.status!r} → {new_status!r}. " + f"Allowed from {self.status!r}: " + f"{sorted(allowed) or '(none — terminal state)'}." + ) + return self.model_copy(update={"status": new_status}) + + +# --------------------------------------------------------------------------- +# Functional API +# --------------------------------------------------------------------------- + + +def fill_form( + form: WorkForm, + values: dict[str, Any], + *, + ruleset: RuleSet | None = None, +) -> WorkForm: + """Return a *new* WorkForm with *values* merged into it. + + Missing fields whose FieldSpec declares a non-None ``default`` are + pre-filled with that default. After merging, :func:`validate_form` is + called automatically — the returned form will have status ``validated`` + or ``error``. + + Args: + form: Source form (not mutated). + values: Key/value pairs to set on the form. + ruleset: Optional :class:`~lionagi.work.rules.RuleSet` to apply as + part of validation. Forwarded to :func:`validate_form`. + + Returns: + A new WorkForm instance with merged values and updated status. + """ + merged: dict[str, Any] = {} + for name, spec in form.fields.items(): + if name in values: + merged[name] = values[name] + elif spec.default is not None: + merged[name] = spec.default + # Required with no value: leave absent so validate_form flags it. + + # Propagate extra keys that are not declared in spec (passed through as-is). + for k, v in values.items(): + if k not in merged: + merged[k] = v + + filled = form.model_copy(update={"values": merged, "status": "filled", "validation_errors": []}) + return validate_form(filled, ruleset=ruleset) + + +def validate_form( + form: WorkForm, + *, + ruleset: RuleSet | None = None, +) -> WorkForm: + """Validate *form* values against its FieldSpec declarations. + + Returns a *new* WorkForm with status ``validated`` when all checks pass, + or ``error`` with ``validation_errors`` populated when any check fails. + + Checks performed per declared field: + + 1. Required fields must be present (key exists) and not ``None``. + 2. Present values must be coercible to the declared type; coerced + values are stored in the returned form's ``values``. + + When *ruleset* is provided, its rules are evaluated **after** the + FieldSpec checks. Any rule failures prevent ``validated`` status — + the form will be ``error`` and rule error messages are appended to + ``validation_errors``. + + Args: + form: Form to validate (not mutated). + ruleset: Optional :class:`~lionagi.work.rules.RuleSet`. When + supplied, rules run as part of this validation pass and + failures are treated identically to spec failures. + + Returns: + New WorkForm with updated ``status`` and ``validation_errors``. + """ + errors: list[str] = [] + coerced_values: dict[str, Any] = dict(form.values) + + for name, spec in form.fields.items(): + value = form.values.get(name) + + # Required check. + if spec.required and value is None: + errors.append(f"Field {name!r} is required but missing or None.") + continue + + # Type check / coercion (only when a value is present). + if value is not None: + try: + coerced_values[name] = spec.coerce(value) + except TypeError as exc: + errors.append(str(exc)) + + # Run ruleset against a form that carries the coerced values, so rules + # see the post-coercion state (e.g., "7" already became 7). + if ruleset is not None: + coerced_form = form.model_copy(update={"values": coerced_values}) + rule_errors = ruleset.apply_all(coerced_form) + errors.extend(rule_errors) + + new_status: FormStatus = "error" if errors else "validated" + return form.model_copy( + update={ + "values": coerced_values, + "status": new_status, + "validation_errors": errors, + } + ) diff --git a/lionagi/work/rules.py b/lionagi/work/rules.py new file mode 100644 index 000000000..cac098841 --- /dev/null +++ b/lionagi/work/rules.py @@ -0,0 +1,300 @@ +# Copyright (c) 2023-2025, HaiyangLi +# SPDX-License-Identifier: Apache-2.0 + +r"""Declarative validation rules for WorkForm fields. + +Rules complement FieldSpec by expressing *value-level* or *cross-field* +constraints that cannot be expressed in a plain type declaration: + +- **required**: field must be present and not None. +- **type**: value must be an instance of the declared type. +- **range**: numeric value must fall within [min, max]. +- **pattern**: string value must match a regex pattern. +- **custom**: arbitrary Python callable returning bool. + +Usage:: + + from lionagi.work.rules import Rule, RuleSet + from lionagi.work.form import WorkForm, FieldSpec + + rs = RuleSet() + rs.add(Rule(rule_id="r1", field="age", check="range", params={"min": 0, "max": 150})) + rs.add(Rule(rule_id="r2", field="email", check="pattern", + params={"pattern": r".+@.+\..+"})) + + errors = rs.apply_all(form) + +.. warning:: + + **Pattern rules are NOT safe for untrusted or adversarial input.** + + The stdlib ``re`` engine uses backtracking and can hold the GIL during + catastrophic matches, making any thread-based timeout ineffective. + Pattern rules are intended for **trusted patterns only** — for example, + validating application-controlled fields (phone formats, zip codes, etc.) + where the pattern is authored by the developer, not supplied by users. + + To mitigate worst-case performance: inputs exceeding + :data:`REGEX_MAX_INPUT_LENGTH` characters are rejected outright before + the regex engine is invoked. This bounds the *input* dimension; it does + not bound the *pattern* dimension. Nested-quantifier patterns such as + ``(a+)+`` remain pathological regardless of input length if that limit + is not tight enough. + + If you need safe matching against untrusted patterns or very long + inputs, use a non-backtracking engine (e.g., ``google-re2``) and + provide a ``custom`` rule backed by that engine instead. +""" + +from __future__ import annotations + +import re +from collections.abc import Callable +from typing import Any, Literal + +from pydantic import BaseModel, Field + +from .form import WorkForm + +__all__ = ( + "Rule", + "RuleSet", + "CheckKind", + "REGEX_MAX_INPUT_LENGTH", +) + +# Maximum input length for pattern checks. Inputs longer than this are +# rejected before regex evaluation. This limits the *input* dimension of +# worst-case backtracking but does NOT eliminate the risk for pathological +# patterns. See module docstring. +REGEX_MAX_INPUT_LENGTH: int = 4096 + +CheckKind = Literal["required", "type", "range", "pattern", "custom"] + + +class Rule(BaseModel): + """A single declarative validation rule. + + Attributes: + rule_id: Unique identifier within a RuleSet. + field: Name of the WorkForm field this rule targets. + check: Kind of check to perform. + params: Check-specific parameters: + + - ``range``: ``{"min": , "max": }`` — either + or both bounds are optional. + - ``pattern``: ``{"pattern": "", "flags": }`` — + ``flags`` defaults to 0. See module-level warning about + trusted-patterns-only. + - ``type``: ``{"type": ""}`` — one of + ``str|int|float|bool|list|dict``. + - ``custom``: ``{"callable": Callable[[Any], bool], + "error": ""}`` — ``error`` is the fallback message. + + message: Optional override for the generated error message. + enabled: When False, this rule is skipped silently. + """ + + rule_id: str = Field(..., description="Unique rule identifier.") + field: str = Field(..., description="WorkForm field name this rule applies to.") + check: CheckKind = Field(..., description="Kind of validation check.") + params: dict[str, Any] = Field( + default_factory=dict, + description="Check-specific parameters.", + ) + message: str | None = Field( + None, + description="Custom error message (overrides auto-generated text).", + ) + enabled: bool = Field(True, description="Skip rule when False.") + + model_config = {"arbitrary_types_allowed": True} + + def apply(self, form: WorkForm) -> str | None: + """Apply this rule to *form*. + + Returns an error string on failure, ``None`` on pass or when disabled. + """ + if not self.enabled: + return None + + value = form.values.get(self.field) + + if self.check == "required": + return self._check_required(value) + if self.check == "type": + return self._check_type(value) + if self.check == "range": + return self._check_range(value) + if self.check == "pattern": + return self._check_pattern(value) + if self.check == "custom": + return self._check_custom(value) + + return f"Rule {self.rule_id!r}: unknown check kind {self.check!r}." + + # ------------------------------------------------------------------ + # Internal checkers + # ------------------------------------------------------------------ + + def _check_required(self, value: Any) -> str | None: + if value is None: + return self.message or f"Field {self.field!r} is required but missing or None." + return None + + def _check_type(self, value: Any) -> str | None: + if value is None: + return None # type check does not apply to absent values + expected = self.params.get("type", "str") + type_map: dict[str, type] = { + "str": str, + "int": int, + "float": float, + "bool": bool, + "list": list, + "dict": dict, + } + target = type_map.get(expected) + if target is None: + return f"Rule {self.rule_id!r}: unknown type {expected!r}." + # Allow int where float is expected (numeric widening). + if expected == "float" and isinstance(value, int): + return None + if not isinstance(value, target): + return self.message or ( + f"Field {self.field!r} must be type {expected!r}, got {type(value).__name__!r}." + ) + return None + + def _check_range(self, value: Any) -> str | None: + if value is None: + return None + if not isinstance(value, int | float): + return self.message or ( + f"Field {self.field!r}: range check requires numeric type, " + f"got {type(value).__name__!r}." + ) + lo = self.params.get("min") + hi = self.params.get("max") + if lo is not None and value < lo: + return self.message or f"Field {self.field!r} = {value} is below minimum {lo}." + if hi is not None and value > hi: + return self.message or f"Field {self.field!r} = {value} exceeds maximum {hi}." + return None + + def _check_pattern(self, value: Any) -> str | None: + """Check that *value* matches the declared pattern. + + .. warning:: + Uses the stdlib ``re`` backtracking engine. Suitable for + **trusted patterns only**. See module docstring for details. + """ + if value is None: + return None + if not isinstance(value, str): + return self.message or ( + f"Field {self.field!r}: pattern check requires str, got {type(value).__name__!r}." + ) + + # Reject inputs that exceed the configurable length limit. + # This bounds the input dimension of worst-case backtracking; it + # does NOT make arbitrary patterns safe (see module docstring). + if len(value) > REGEX_MAX_INPUT_LENGTH: + return self.message or ( + f"Field {self.field!r}: input length {len(value)} exceeds " + f"the maximum allowed for pattern matching ({REGEX_MAX_INPUT_LENGTH})." + ) + + pattern = self.params.get("pattern", "") + flags = int(self.params.get("flags", 0)) + try: + re.compile(pattern, flags) + except re.error as exc: + return f"Rule {self.rule_id!r}: invalid regex pattern — {exc}." + + if not re.search(pattern, value, flags): + return self.message or ( + f"Field {self.field!r} value {value!r} does not match pattern {pattern!r}." + ) + return None + + def _check_custom(self, value: Any) -> str | None: + fn: Callable[[Any], bool] | None = self.params.get("callable") + if fn is None: + return f"Rule {self.rule_id!r}: 'custom' check requires params['callable']." + try: + passed = fn(value) + except Exception as exc: # noqa: BLE001 + return f"Rule {self.rule_id!r}: custom check raised {type(exc).__name__}: {exc}." + if not passed: + return ( + self.message + or self.params.get("error") + or f"Field {self.field!r} failed custom check {self.rule_id!r}." + ) + return None + + +class RuleSet: + """An ordered collection of :class:`Rule` objects. + + Rules are applied in insertion order. All rules are evaluated + (no short-circuit), so the caller receives a complete list of errors. + + Each rule must have a unique ``rule_id`` within this set — :meth:`add` + raises ``ValueError`` if a duplicate ``rule_id`` is supplied. + + Usage:: + + rs = RuleSet() + rs.add(Rule(...)) + errors = rs.apply_all(form) + """ + + def __init__(self) -> None: + self._rules: list[Rule] = [] + + def add(self, rule: Rule) -> RuleSet: + """Append *rule* and return ``self`` for chaining. + + Raises: + ValueError: If a rule with the same ``rule_id`` already exists + in this set. + """ + if any(r.rule_id == rule.rule_id for r in self._rules): + raise ValueError( + f"RuleSet already contains a rule with rule_id={rule.rule_id!r}. " + "Use a unique rule_id or remove the existing rule first." + ) + self._rules.append(rule) + return self + + def remove(self, rule_id: str) -> bool: + """Remove the rule with *rule_id*. Returns True if found and removed.""" + before = len(self._rules) + self._rules = [r for r in self._rules if r.rule_id != rule_id] + return len(self._rules) < before + + def get(self, rule_id: str) -> Rule | None: + """Return the rule with *rule_id*, or ``None`` if not found.""" + for r in self._rules: + if r.rule_id == rule_id: + return r + return None + + def rules(self) -> list[Rule]: + """Return a shallow copy of the rule list.""" + return list(self._rules) + + def apply_all(self, form: WorkForm) -> list[str]: + """Apply every enabled rule to *form*. + + Returns a list of error messages. An empty list means all rules + passed (or all were disabled). + """ + errors: list[str] = [] + for rule in self._rules: + err = rule.apply(form) + if err is not None: + errors.append(err) + return errors diff --git a/tests/work/__init__.py b/tests/work/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/work/test_work_forms_rules.py b/tests/work/test_work_forms_rules.py new file mode 100644 index 000000000..0a4fab2fd --- /dev/null +++ b/tests/work/test_work_forms_rules.py @@ -0,0 +1,1000 @@ +# Copyright (c) 2023-2025, HaiyangLi +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for lionagi.work: WorkForm, FieldSpec, Rule, RuleSet. + +Covers the public API exported by lionagi.work.__init__: + - FieldSpec declaration, type coercion, and default validation + - WorkForm Element inheritance (UUID id, created_at, form_id alias) + - WorkForm lifecycle and transition logic + - fill_form / validate_form functional helpers (including ruleset param) + - Rule (required, type, range, pattern, custom) apply() + - RuleSet composition (add, remove, get, apply_all, duplicate prevention) + - Bool-as-int subclass behaviour (pinned) + - ReDoS: length cap enforced; no phantom timeout claim +""" + +from __future__ import annotations + +from typing import Any +from uuid import UUID + +import pytest + +from lionagi.work import ( + REGEX_MAX_INPUT_LENGTH, + VALID_TRANSITIONS, + FieldSpec, + FormStatus, + Rule, + RuleSet, + WorkForm, + fill_form, + validate_form, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_form( + fields: dict[str, dict[str, Any]] | None = None, + values: dict[str, Any] | None = None, + status: FormStatus = "draft", +) -> WorkForm: + """Build a WorkForm for testing from compact field dicts.""" + specs = { + name: FieldSpec(name=name, **spec_kwargs) for name, spec_kwargs in (fields or {}).items() + } + return WorkForm( + title="Test Form", + fields=specs, + values=values or {}, + status=status, + ) + + +# --------------------------------------------------------------------------- +# WorkForm — Element inheritance +# --------------------------------------------------------------------------- + + +class TestWorkFormElement: + def test_has_uuid_id(self): + form = WorkForm() + assert isinstance(form.id, UUID) + + def test_form_id_is_str_alias_of_id(self): + form = WorkForm() + assert form.form_id == str(form.id) + + def test_has_created_at_timestamp(self): + form = WorkForm() + assert form.created_at > 0 + + def test_has_metadata(self): + form = WorkForm() + assert isinstance(form.metadata, dict) + + def test_two_forms_have_distinct_ids(self): + f1 = WorkForm() + f2 = WorkForm() + assert f1.id != f2.id + + def test_model_copy_preserves_id(self): + form = WorkForm(title="original") + copy = form.model_copy(update={"title": "copy"}) + assert copy.id == form.id + + def test_element_equality_by_id(self): + form = WorkForm() + # Two distinct instances with same id compare equal + copy = form.model_copy(update={}) + assert form == copy + + def test_element_hash_by_id(self): + form = WorkForm() + # Should be hashable + s = {form} + assert form in s + + +# --------------------------------------------------------------------------- +# FieldSpec +# --------------------------------------------------------------------------- + + +class TestFieldSpec: + def test_defaults(self): + spec = FieldSpec(name="x") + assert spec.type == "str" + assert spec.required is True + assert spec.default is None + assert spec.description == "" + + def test_fieldspec_also_has_element_id(self): + """FieldSpec inherits Element — each spec is a tracked object.""" + spec = FieldSpec(name="x") + assert isinstance(spec.id, UUID) + + def test_name_valid_letter_start(self): + FieldSpec(name="my_field_1") # should not raise + + def test_name_valid_underscore_start(self): + FieldSpec(name="_private") # should not raise + + def test_name_invalid_digit_start(self): + with pytest.raises(Exception): + FieldSpec(name="1bad") + + def test_name_invalid_spaces(self): + with pytest.raises(Exception): + FieldSpec(name="bad field") + + def test_name_invalid_hyphen(self): + with pytest.raises(Exception): + FieldSpec(name="bad-field") + + # Default value type validation (Fix 4) + def test_valid_int_default(self): + spec = FieldSpec(name="n", type="int", required=False, default=5) + assert spec.default == 5 + + def test_invalid_default_rejected_at_construction(self): + with pytest.raises(Exception, match="not compatible"): + FieldSpec(name="n", type="int", required=False, default="not_an_int") + + def test_invalid_list_default_for_str_rejected(self): + with pytest.raises(Exception): + FieldSpec(name="x", type="str", required=False, default=[1, 2, 3]) + + def test_float_field_accepts_int_default(self): + """int default for float field is valid (numeric widening).""" + spec = FieldSpec(name="f", type="float", required=False, default=3) + assert spec.default == 3 + + def test_none_default_always_valid(self): + """None means 'no default', not a value — skips type check.""" + spec = FieldSpec(name="x", type="int", required=False, default=None) + assert spec.default is None + + # Coerce: same type passthrough + def test_coerce_str_passthrough(self): + spec = FieldSpec(name="s", type="str") + assert spec.coerce("hello") == "hello" + + def test_coerce_int_passthrough(self): + spec = FieldSpec(name="n", type="int") + assert spec.coerce(42) == 42 + + # Coerce: numeric widening + def test_coerce_int_to_float(self): + spec = FieldSpec(name="f", type="float") + result = spec.coerce(3) + assert result == 3.0 + assert isinstance(result, float) + + def test_coerce_float_passthrough(self): + spec = FieldSpec(name="f", type="float") + assert spec.coerce(1.5) == 1.5 + + # Coerce: str → bool + def test_coerce_str_true_variants(self): + spec = FieldSpec(name="b", type="bool") + assert spec.coerce("true") is True + assert spec.coerce("TRUE") is True + assert spec.coerce("yes") is True + assert spec.coerce("1") is True + + def test_coerce_str_false_variants(self): + spec = FieldSpec(name="b", type="bool") + assert spec.coerce("false") is False + assert spec.coerce("FALSE") is False + assert spec.coerce("no") is False + assert spec.coerce("0") is False + + # Coerce: str → numeric + def test_coerce_str_to_int(self): + spec = FieldSpec(name="n", type="int") + assert spec.coerce("42") == 42 + + def test_coerce_str_to_float(self): + spec = FieldSpec(name="f", type="float") + assert spec.coerce("3.14") == pytest.approx(3.14) + + def test_coerce_type_mismatch_raises_type_error(self): + spec = FieldSpec(name="n", type="int") + with pytest.raises(TypeError): + spec.coerce([1, 2, 3]) + + def test_coerce_none_returns_none(self): + spec = FieldSpec(name="x", type="str") + assert spec.coerce(None) is None + + def test_coerce_invalid_str_to_int_raises(self): + spec = FieldSpec(name="n", type="int") + with pytest.raises(TypeError): + spec.coerce("not_a_number") + + # Bool-as-int subclass behaviour (pinned — see reviewer request) + def test_bool_is_accepted_by_int_type(self): + """Python bool is a subclass of int. Pinned: type='int' accepts True/False.""" + spec = FieldSpec(name="flag", type="int") + # isinstance(True, int) is True, so coerce returns True as-is + assert spec.coerce(True) is True + assert spec.coerce(False) is False + + def test_bool_coerce_to_bool_type_passthrough(self): + spec = FieldSpec(name="flag", type="bool") + assert spec.coerce(True) is True + assert spec.coerce(False) is False + + +# --------------------------------------------------------------------------- +# WorkForm +# --------------------------------------------------------------------------- + + +class TestWorkForm: + def test_creation_defaults(self): + form = WorkForm(title="My Form") + assert form.title == "My Form" + assert form.fields == {} + assert form.values == {} + assert form.status == "draft" + assert form.validation_errors == [] + + def test_no_args_creates_valid_form(self): + form = WorkForm() + assert form.status == "draft" + assert isinstance(form.id, UUID) + + def test_field_names(self): + form = _make_form({"a": {"type": "str"}, "b": {"type": "int"}}) + assert set(form.field_names()) == {"a", "b"} + + def test_field_names_empty(self): + form = WorkForm() + assert form.field_names() == [] + + def test_get_value_present(self): + form = _make_form(values={"x": "hello"}) + assert form.get("x") == "hello" + + def test_get_value_missing_returns_default(self): + form = _make_form() + assert form.get("missing", "fallback") == "fallback" + + def test_get_value_missing_returns_none_default(self): + form = _make_form() + assert form.get("missing") is None + + def test_is_complete_false_on_draft(self): + form = _make_form() + assert form.is_complete() is False + + def test_is_complete_false_on_filled(self): + form = _make_form(status="filled") + assert form.is_complete() is False + + def test_is_complete_false_on_error(self): + form = _make_form(status="error") + assert form.is_complete() is False + + def test_is_complete_true_on_validated(self): + form = _make_form(status="validated") + assert form.is_complete() is True + + def test_is_complete_true_on_completed(self): + form = _make_form(status="completed") + assert form.is_complete() is True + + +class TestWorkFormTransitions: + def test_draft_to_filled(self): + form = _make_form() + new_form = form.transition_to("filled") + assert new_form.status == "filled" + assert form.status == "draft" # original unchanged + + def test_filled_to_validated(self): + form = _make_form(status="filled") + new_form = form.transition_to("validated") + assert new_form.status == "validated" + + def test_filled_to_error(self): + form = _make_form(status="filled") + new_form = form.transition_to("error") + assert new_form.status == "error" + + def test_error_to_draft(self): + form = _make_form(status="error") + new_form = form.transition_to("draft") + assert new_form.status == "draft" + + def test_validated_to_submitted(self): + form = _make_form(status="validated") + new_form = form.transition_to("submitted") + assert new_form.status == "submitted" + + def test_submitted_to_completed(self): + form = _make_form(status="submitted") + new_form = form.transition_to("completed") + assert new_form.status == "completed" + + def test_invalid_transition_raises(self): + form = _make_form(status="completed") + with pytest.raises(ValueError, match="terminal"): + form.transition_to("draft") + + def test_draft_cannot_skip_to_validated(self): + form = _make_form() + with pytest.raises(ValueError): + form.transition_to("validated") + + def test_transition_returns_new_instance(self): + form = _make_form() + new_form = form.transition_to("filled") + assert form is not new_form + + def test_valid_transitions_table_coverage(self): + """All expected transitions from VALID_TRANSITIONS are accepted.""" + for from_status, allowed in VALID_TRANSITIONS.items(): + for to_status in allowed: + form = _make_form(status=from_status) # type: ignore[arg-type] + result = form.transition_to(to_status) # type: ignore[arg-type] + assert result.status == to_status + + +# --------------------------------------------------------------------------- +# validate_form — FieldSpec checks +# --------------------------------------------------------------------------- + + +class TestValidateForm: + def test_validates_required_field_present(self): + form = _make_form( + fields={"name": {"type": "str", "required": True}}, + values={"name": "Alice"}, + ) + result = validate_form(form) + assert result.status == "validated" + assert result.validation_errors == [] + + def test_error_on_missing_required(self): + form = _make_form(fields={"name": {"type": "str", "required": True}}) + result = validate_form(form) + assert result.status == "error" + assert any("name" in e for e in result.validation_errors) + + def test_optional_field_absent_is_ok(self): + form = _make_form(fields={"opt": {"type": "str", "required": False}}) + result = validate_form(form) + assert result.status == "validated" + assert result.validation_errors == [] + + def test_type_mismatch_yields_error(self): + form = _make_form( + fields={"count": {"type": "int"}}, + values={"count": [1, 2]}, + ) + result = validate_form(form) + assert result.status == "error" + assert any("count" in e for e in result.validation_errors) + + def test_coerces_string_int(self): + form = _make_form( + fields={"n": {"type": "int"}}, + values={"n": "7"}, + ) + result = validate_form(form) + assert result.status == "validated" + assert result.values["n"] == 7 + + def test_coerces_int_to_float(self): + form = _make_form( + fields={"f": {"type": "float"}}, + values={"f": 3}, + ) + result = validate_form(form) + assert result.status == "validated" + assert result.values["f"] == 3.0 + + def test_does_not_mutate_original(self): + form = _make_form( + fields={"x": {"type": "str"}}, + values={"x": "hello"}, + ) + result = validate_form(form) + assert form is not result + assert form.status == "draft" # original status preserved + + def test_multiple_errors_collected(self): + form = _make_form( + fields={ + "a": {"type": "str", "required": True}, + "b": {"type": "int", "required": True}, + } + ) + result = validate_form(form) + assert result.status == "error" + assert len(result.validation_errors) == 2 + + def test_no_fields_validates_ok(self): + form = WorkForm() + result = validate_form(form) + assert result.status == "validated" + + +# --------------------------------------------------------------------------- +# validate_form — ruleset integration (Fix 2) +# --------------------------------------------------------------------------- + + +class TestValidateFormWithRuleset: + def test_ruleset_failure_blocks_validated_status(self): + """Spec passes but rule fails → status must be 'error'.""" + form = _make_form( + fields={"age": {"type": "int"}}, + values={"age": 200}, + status="filled", + ) + rs = RuleSet() + rs.add(Rule(rule_id="age_range", field="age", check="range", params={"min": 0, "max": 120})) + result = validate_form(form, ruleset=rs) + assert result.status == "error" + assert any("maximum" in e for e in result.validation_errors) + + def test_ruleset_pass_and_spec_pass_yields_validated(self): + """Spec passes AND all rules pass → status is 'validated'.""" + form = _make_form( + fields={"age": {"type": "int"}}, + values={"age": 30}, + status="filled", + ) + rs = RuleSet() + rs.add(Rule(rule_id="age_range", field="age", check="range", params={"min": 0, "max": 120})) + result = validate_form(form, ruleset=rs) + assert result.status == "validated" + assert result.validation_errors == [] + + def test_spec_failure_plus_ruleset_errors_all_collected(self): + """Both spec errors and rule errors are reported together.""" + form = _make_form( + fields={ + "name": {"type": "str", "required": True}, + "age": {"type": "int"}, + }, + values={"age": 200}, + status="filled", + ) + rs = RuleSet() + rs.add(Rule(rule_id="age_range", field="age", check="range", params={"max": 120})) + result = validate_form(form, ruleset=rs) + assert result.status == "error" + # both 'name' missing and 'age' range errors should appear + texts = " ".join(result.validation_errors) + assert "name" in texts + assert "maximum" in texts + + def test_ruleset_sees_coerced_values(self): + """Rules run after coercion — '25' is coerced to 25 before range check.""" + form = _make_form( + fields={"age": {"type": "int"}}, + values={"age": "25"}, # string + status="filled", + ) + rs = RuleSet() + rs.add(Rule(rule_id="r", field="age", check="range", params={"min": 0, "max": 120})) + result = validate_form(form, ruleset=rs) + assert result.status == "validated" + + def test_no_ruleset_path_unchanged(self): + """validate_form without ruleset behaves exactly as before.""" + form = _make_form( + fields={"x": {"type": "str"}}, + values={"x": "hello"}, + ) + result = validate_form(form) + assert result.status == "validated" + + def test_fill_form_accepts_ruleset_kwarg(self): + """fill_form forwards ruleset to validate_form.""" + form = _make_form(fields={"score": {"type": "int"}}) + rs = RuleSet() + rs.add(Rule(rule_id="r", field="score", check="range", params={"max": 100})) + result = fill_form(form, {"score": 150}, ruleset=rs) + assert result.status == "error" + assert any("maximum" in e for e in result.validation_errors) + + def test_fill_form_ruleset_pass(self): + form = _make_form(fields={"score": {"type": "int"}}) + rs = RuleSet() + rs.add(Rule(rule_id="r", field="score", check="range", params={"max": 100})) + result = fill_form(form, {"score": 80}, ruleset=rs) + assert result.status == "validated" + + +# --------------------------------------------------------------------------- +# fill_form +# --------------------------------------------------------------------------- + + +class TestFillForm: + def test_fill_and_auto_validate(self): + form = _make_form(fields={"msg": {"type": "str"}}) + result = fill_form(form, {"msg": "hello"}) + assert result.status == "validated" + assert result.values["msg"] == "hello" + + def test_fill_uses_default_when_absent(self): + form = _make_form(fields={"level": {"type": "int", "required": False, "default": 1}}) + result = fill_form(form, {}) + assert result.values.get("level") == 1 + + def test_fill_required_missing_yields_error(self): + form = _make_form(fields={"required_field": {"type": "str", "required": True}}) + result = fill_form(form, {}) + assert result.status == "error" + assert any("required_field" in e for e in result.validation_errors) + + def test_fill_extra_keys_preserved(self): + form = _make_form(fields={"a": {"type": "str"}}) + result = fill_form(form, {"a": "x", "extra": 99}) + assert result.values.get("extra") == 99 + + def test_fill_does_not_mutate_original(self): + form = _make_form(fields={"x": {"type": "str"}}) + fill_form(form, {"x": "hello"}) + assert form.values == {} + assert form.status == "draft" + + def test_fill_overrides_previous_values(self): + form = _make_form(fields={"x": {"type": "str"}}, values={"x": "old"}) + result = fill_form(form, {"x": "new"}) + assert result.values["x"] == "new" + + def test_fill_coerces_during_validation(self): + form = _make_form(fields={"n": {"type": "int"}}) + result = fill_form(form, {"n": "42"}) + assert result.status == "validated" + assert result.values["n"] == 42 + + def test_fill_multiple_fields(self): + form = _make_form( + fields={ + "name": {"type": "str"}, + "age": {"type": "int"}, + "active": {"type": "bool"}, + } + ) + result = fill_form(form, {"name": "Bob", "age": 30, "active": True}) + assert result.status == "validated" + assert result.values == {"name": "Bob", "age": 30, "active": True} + + +# --------------------------------------------------------------------------- +# Rule +# --------------------------------------------------------------------------- + + +class TestRuleRequired: + def test_passes_when_value_present(self): + rule = Rule(rule_id="r1", field="name", check="required") + form = WorkForm(values={"name": "Alice"}) + assert rule.apply(form) is None + + def test_fails_when_value_absent(self): + rule = Rule(rule_id="r1", field="name", check="required") + form = WorkForm(values={}) + assert rule.apply(form) is not None + + def test_fails_when_value_is_none(self): + rule = Rule(rule_id="r1", field="name", check="required") + form = WorkForm(values={"name": None}) + assert rule.apply(form) is not None + + def test_custom_message_used(self): + rule = Rule(rule_id="r1", field="x", check="required", message="x is absolutely required") + form = WorkForm(values={}) + assert rule.apply(form) == "x is absolutely required" + + def test_disabled_rule_skipped(self): + rule = Rule(rule_id="r1", field="name", check="required", enabled=False) + form = WorkForm(values={}) + assert rule.apply(form) is None + + +class TestRuleType: + def test_passes_correct_type(self): + rule = Rule(rule_id="r", field="n", check="type", params={"type": "int"}) + form = WorkForm(values={"n": 5}) + assert rule.apply(form) is None + + def test_passes_int_for_float(self): + rule = Rule(rule_id="r", field="f", check="type", params={"type": "float"}) + form = WorkForm(values={"f": 3}) + assert rule.apply(form) is None # int widened to float + + def test_fails_wrong_type(self): + rule = Rule(rule_id="r", field="n", check="type", params={"type": "int"}) + form = WorkForm(values={"n": "not_int"}) + err = rule.apply(form) + assert err is not None + assert "int" in err + + def test_absent_value_passes(self): + rule = Rule(rule_id="r", field="n", check="type", params={"type": "int"}) + form = WorkForm(values={}) + assert rule.apply(form) is None + + def test_unknown_type_returns_error(self): + rule = Rule(rule_id="r", field="n", check="type", params={"type": "uuid"}) + form = WorkForm(values={"n": "some-uuid"}) + err = rule.apply(form) + assert err is not None + assert "unknown type" in err + + def test_bool_value_passes_int_type_check(self): + """Pinned: bool is a subclass of int; type='int' rule accepts True/False.""" + rule = Rule(rule_id="r", field="flag", check="type", params={"type": "int"}) + form = WorkForm(values={"flag": True}) + # isinstance(True, int) is True in Python — this should PASS + assert rule.apply(form) is None + + def test_bool_type_check_accepts_false(self): + rule = Rule(rule_id="r", field="flag", check="type", params={"type": "int"}) + form = WorkForm(values={"flag": False}) + assert rule.apply(form) is None + + +class TestRuleRange: + def test_passes_within_range(self): + rule = Rule(rule_id="r", field="age", check="range", params={"min": 0, "max": 120}) + form = WorkForm(values={"age": 30}) + assert rule.apply(form) is None + + def test_passes_at_min_boundary(self): + rule = Rule(rule_id="r", field="age", check="range", params={"min": 0}) + form = WorkForm(values={"age": 0}) + assert rule.apply(form) is None + + def test_passes_at_max_boundary(self): + rule = Rule(rule_id="r", field="score", check="range", params={"max": 100}) + form = WorkForm(values={"score": 100}) + assert rule.apply(form) is None + + def test_fails_below_min(self): + rule = Rule(rule_id="r", field="age", check="range", params={"min": 18}) + form = WorkForm(values={"age": 5}) + err = rule.apply(form) + assert err is not None + assert "minimum" in err + + def test_fails_above_max(self): + rule = Rule(rule_id="r", field="score", check="range", params={"max": 100}) + form = WorkForm(values={"score": 150}) + err = rule.apply(form) + assert err is not None + assert "maximum" in err + + def test_absent_value_skipped(self): + rule = Rule(rule_id="r", field="n", check="range", params={"min": 0}) + form = WorkForm(values={}) + assert rule.apply(form) is None + + def test_non_numeric_value_error(self): + rule = Rule(rule_id="r", field="n", check="range", params={"min": 0}) + form = WorkForm(values={"n": "five"}) + err = rule.apply(form) + assert err is not None + assert "numeric" in err + + def test_float_in_range(self): + rule = Rule(rule_id="r", field="x", check="range", params={"min": 0.0, "max": 1.0}) + form = WorkForm(values={"x": 0.5}) + assert rule.apply(form) is None + + def test_only_min_bound(self): + rule = Rule(rule_id="r", field="x", check="range", params={"min": 10}) + form = WorkForm(values={"x": 100}) + assert rule.apply(form) is None + + def test_only_max_bound(self): + rule = Rule(rule_id="r", field="x", check="range", params={"max": 10}) + form = WorkForm(values={"x": 0}) + assert rule.apply(form) is None + + +class TestRulePattern: + def test_passes_matching_pattern(self): + rule = Rule(rule_id="r", field="email", check="pattern", params={"pattern": r".+@.+"}) + form = WorkForm(values={"email": "a@b.com"}) + assert rule.apply(form) is None + + def test_fails_non_matching(self): + rule = Rule(rule_id="r", field="email", check="pattern", params={"pattern": r".+@.+"}) + form = WorkForm(values={"email": "notanemail"}) + assert rule.apply(form) is not None + + def test_absent_value_skipped(self): + rule = Rule(rule_id="r", field="email", check="pattern", params={"pattern": r".+@.+"}) + form = WorkForm(values={}) + assert rule.apply(form) is None + + def test_non_string_value_error(self): + rule = Rule(rule_id="r", field="x", check="pattern", params={"pattern": r"\d+"}) + form = WorkForm(values={"x": 42}) + err = rule.apply(form) + assert err is not None + assert "str" in err + + def test_invalid_regex_returns_error(self): + rule = Rule(rule_id="r", field="x", check="pattern", params={"pattern": r"[invalid"}) + form = WorkForm(values={"x": "hello"}) + err = rule.apply(form) + assert err is not None + assert "invalid regex" in err + + def test_input_too_long_rejected(self): + """Length cap (REGEX_MAX_INPUT_LENGTH) is enforced before re engine.""" + rule = Rule(rule_id="r", field="x", check="pattern", params={"pattern": r".*"}) + form = WorkForm(values={"x": "a" * (REGEX_MAX_INPUT_LENGTH + 1)}) + err = rule.apply(form) + assert err is not None + assert "length" in err + + def test_input_at_exact_limit_accepted(self): + """Input at exactly the limit is allowed (limit is exclusive).""" + rule = Rule(rule_id="r", field="x", check="pattern", params={"pattern": r"a+"}) + form = WorkForm(values={"x": "a" * REGEX_MAX_INPUT_LENGTH}) + err = rule.apply(form) + assert err is None # matches the pattern, within limit + + def test_no_timeout_mechanism_exists(self): + """Confirm the removed threat: REGEX_MATCH_TIMEOUT is not exported.""" + import lionagi.work as work_module + + assert not hasattr(work_module, "REGEX_MATCH_TIMEOUT"), ( + "REGEX_MATCH_TIMEOUT was removed because thread-join timeout cannot " + "interrupt the GIL-holding C re engine. Do not re-add it." + ) + + def test_custom_message_on_mismatch(self): + rule = Rule( + rule_id="r", + field="code", + check="pattern", + params={"pattern": r"^\d{4}$"}, + message="Must be 4 digits", + ) + form = WorkForm(values={"code": "abc"}) + assert rule.apply(form) == "Must be 4 digits" + + +class TestRuleCustom: + def test_passes_when_callable_returns_true(self): + rule = Rule( + rule_id="r", + field="val", + check="custom", + params={"callable": lambda v: v is not None and v > 0}, + ) + form = WorkForm(values={"val": 5}) + assert rule.apply(form) is None + + def test_fails_when_callable_returns_false(self): + rule = Rule( + rule_id="r", + field="val", + check="custom", + params={"callable": lambda v: v is not None and v > 0}, + ) + form = WorkForm(values={"val": -1}) + assert rule.apply(form) is not None + + def test_uses_params_error_message(self): + rule = Rule( + rule_id="r", + field="val", + check="custom", + params={ + "callable": lambda v: v > 0, + "error": "Must be positive", + }, + ) + form = WorkForm(values={"val": -1}) + err = rule.apply(form) + assert "Must be positive" in err + + def test_uses_rule_message_over_params_error(self): + rule = Rule( + rule_id="r", + field="val", + check="custom", + params={"callable": lambda v: False, "error": "params error"}, + message="rule message", + ) + form = WorkForm(values={"val": 1}) + assert rule.apply(form) == "rule message" + + def test_missing_callable_returns_error(self): + rule = Rule(rule_id="r", field="val", check="custom", params={}) + form = WorkForm(values={"val": 1}) + err = rule.apply(form) + assert err is not None + assert "callable" in err + + def test_callable_exception_returns_error(self): + def bad_fn(v: Any) -> bool: + raise RuntimeError("exploded") + + rule = Rule( + rule_id="r", + field="val", + check="custom", + params={"callable": bad_fn}, + ) + form = WorkForm(values={"val": 1}) + err = rule.apply(form) + assert err is not None + assert "RuntimeError" in err + + +# --------------------------------------------------------------------------- +# RuleSet +# --------------------------------------------------------------------------- + + +class TestRuleSet: + def test_add_and_apply_all_no_errors(self): + rs = RuleSet() + rs.add(Rule(rule_id="r1", field="name", check="required")) + form = WorkForm(values={"name": "Bob"}) + assert rs.apply_all(form) == [] + + def test_apply_all_collects_all_errors(self): + rs = RuleSet() + rs.add(Rule(rule_id="r1", field="a", check="required")) + rs.add(Rule(rule_id="r2", field="b", check="required")) + form = WorkForm(values={}) + errors = rs.apply_all(form) + assert len(errors) == 2 + + def test_apply_all_no_rules_returns_empty(self): + rs = RuleSet() + form = WorkForm(values={}) + assert rs.apply_all(form) == [] + + def test_remove_existing_rule(self): + rs = RuleSet() + rs.add(Rule(rule_id="r1", field="x", check="required")) + removed = rs.remove("r1") + assert removed is True + assert rs.get("r1") is None + + def test_remove_nonexistent_returns_false(self): + rs = RuleSet() + assert rs.remove("nope") is False + + def test_get_rule_found(self): + rs = RuleSet() + rule = Rule(rule_id="r1", field="x", check="required") + rs.add(rule) + assert rs.get("r1") is rule + + def test_get_rule_not_found(self): + rs = RuleSet() + assert rs.get("missing") is None + + def test_add_returns_self_for_chaining(self): + rs = RuleSet() + result = rs.add(Rule(rule_id="r1", field="x", check="required")) + assert result is rs + + def test_chained_add(self): + rs = ( + RuleSet() + .add(Rule(rule_id="r1", field="x", check="required")) + .add(Rule(rule_id="r2", field="y", check="required")) + ) + assert len(rs.rules()) == 2 + + def test_rules_returns_copy(self): + rs = RuleSet() + rs.add(Rule(rule_id="r1", field="x", check="required")) + copy = rs.rules() + copy.clear() + assert len(rs.rules()) == 1 # original unaffected + + def test_disabled_rule_not_counted_in_errors(self): + rs = RuleSet() + rs.add(Rule(rule_id="r1", field="a", check="required")) + rs.add(Rule(rule_id="r2", field="b", check="required", enabled=False)) + form = WorkForm(values={}) + errors = rs.apply_all(form) + assert len(errors) == 1 # only r1 fires + + def test_rules_applied_in_insertion_order(self): + rs = RuleSet() + rs.add(Rule(rule_id="first", field="a", check="required")) + rs.add(Rule(rule_id="second", field="b", check="required")) + form = WorkForm(values={}) + errors = rs.apply_all(form) + assert "a" in errors[0] + assert "b" in errors[1] + + def test_remove_and_reapply(self): + rs = RuleSet() + rs.add(Rule(rule_id="r1", field="x", check="required")) + rs.remove("r1") + form = WorkForm(values={}) + assert rs.apply_all(form) == [] + + def test_mixed_pass_and_fail(self): + rs = RuleSet() + rs.add(Rule(rule_id="r1", field="age", check="range", params={"min": 0, "max": 120})) + rs.add(Rule(rule_id="r2", field="email", check="pattern", params={"pattern": r".+@.+"})) + form = WorkForm(values={"age": 30, "email": "notanemail"}) + errors = rs.apply_all(form) + assert len(errors) == 1 + assert "email" in errors[0] + + # Duplicate rule_id prevention (Fix 5) + def test_add_duplicate_rule_id_raises(self): + """add() must reject a rule_id already in the set.""" + rs = RuleSet() + rs.add(Rule(rule_id="r1", field="x", check="required")) + with pytest.raises(ValueError, match="r1"): + rs.add(Rule(rule_id="r1", field="y", check="required")) + + def test_add_duplicate_after_remove_is_ok(self): + """After removing a rule, its id can be reused.""" + rs = RuleSet() + rs.add(Rule(rule_id="r1", field="x", check="required")) + rs.remove("r1") + # Should not raise + rs.add(Rule(rule_id="r1", field="y", check="required")) + assert rs.get("r1") is not None + assert rs.get("r1").field == "y" + + +# --------------------------------------------------------------------------- +# Integration: fill_form + RuleSet (standalone apply) +# --------------------------------------------------------------------------- + + +class TestFillFormWithRuleSetStandalone: + def test_filled_validated_form_passes_ruleset(self): + form = _make_form(fields={"age": {"type": "int"}}) + filled = fill_form(form, {"age": 25}) + assert filled.status == "validated" + + rs = RuleSet() + rs.add(Rule(rule_id="age_range", field="age", check="range", params={"min": 0, "max": 120})) + errors = rs.apply_all(filled) + assert errors == [] + + def test_filled_form_fails_standalone_ruleset_range(self): + """apply_all is a diagnostic tool; it doesn't change form status.""" + form = _make_form(fields={"age": {"type": "int"}}) + filled = fill_form(form, {"age": 200}) + assert filled.status == "validated" # type check passes, no ruleset used + + rs = RuleSet() + rs.add(Rule(rule_id="age_range", field="age", check="range", params={"min": 0, "max": 120})) + errors = rs.apply_all(filled) + assert len(errors) == 1 + assert "maximum" in errors[0] + + def test_ruleset_on_error_form_for_diagnostics(self): + """A form in error state can still have rules applied independently.""" + form = _make_form(fields={"name": {"type": "str", "required": True}}) + error_form = fill_form(form, {}) + assert error_form.status == "error" + + rs = RuleSet() + rs.add(Rule(rule_id="r1", field="name", check="required")) + errors = rs.apply_all(error_form) + assert len(errors) == 1