From 6dfabb8895aa61f4f1e03d132688e75789fff9c1 Mon Sep 17 00:00:00 2001 From: AJ Hallameyer Date: Sat, 2 May 2026 13:01:32 -0400 Subject: [PATCH] feat: warn on dropped extra fields in public API request models Pydantic v2 BaseModel defaults to extra="ignore", silently dropping unknown fields on request bodies. For internal data structures this is fine, but for public API request bodies it makes mistyped or out-of-spec fields invisible: the request is accepted, the offending field is dropped, and the request runs with whatever defaults the matching declared fields had. We hit this in cluster validation: a mistyped instance_meta field on /place_instance was silently dropped, falling back to the default MlxRing instead of the requested MlxJaccl, with no log line. Symptom was lower decode throughput, no error. This change adds a WarnExtraModel base class that keeps the existing extra="ignore" semantics (so it's non-breaking) but logs a warning the first time a given (model_name, unknown_field) pair is seen, rate-limited per pair to avoid log-spam from repeat requests. Public API request models updated: ChatCompletionRequest, BenchChatCompletionRequest (via inheritance), PlaceInstanceParams, CreateInstanceParams, AddCustomModelParams, DeleteInstanceTaskParams, InstanceLinkBody, ImageGenerationTaskParams, ImageEditsTaskParams, ResponsesRequest, OllamaChatRequest, OllamaGenerateRequest, OllamaShowRequest, ClaudeMessagesRequest. A future change can flip extra="forbid" once the warning has had time to surface miswired clients in the wild. --- src/exo/api/types/api.py | 17 +- src/exo/api/types/claude_api.py | 3 +- src/exo/api/types/ollama_api.py | 7 +- src/exo/api/types/openai_responses.py | 3 +- src/exo/utils/extra_fields_warner.py | 136 +++++++++++++ .../utils/tests/test_extra_fields_warner.py | 179 ++++++++++++++++++ 6 files changed, 332 insertions(+), 13 deletions(-) create mode 100644 src/exo/utils/extra_fields_warner.py create mode 100644 src/exo/utils/tests/test_extra_fields_warner.py diff --git a/src/exo/api/types/api.py b/src/exo/api/types/api.py index 8cfa10dd1a..102c8af8a6 100644 --- a/src/exo/api/types/api.py +++ b/src/exo/api/types/api.py @@ -11,6 +11,7 @@ from exo.shared.types.text_generation import ReasoningDialect, ReasoningEffort from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta from exo.shared.types.worker.shards import Sharding, ShardMetadata +from exo.utils.extra_fields_warner import WarnExtraModel from exo.utils.pydantic_ext import FrozenModel FinishReason = Literal[ @@ -204,7 +205,7 @@ class StreamOptions(BaseModel): include_usage: bool = False -class ChatCompletionRequest(BaseModel): +class ChatCompletionRequest(WarnExtraModel): model: ModelId frequency_penalty: float | None = None messages: list[ChatCompletionMessage] @@ -237,7 +238,7 @@ class BenchChatCompletionRequest(ChatCompletionRequest): use_prefix_cache: bool = False -class AddCustomModelParams(BaseModel): +class AddCustomModelParams(WarnExtraModel): model_id: ModelId @@ -250,14 +251,14 @@ class HuggingFaceSearchResult(BaseModel): tags: list[str] = Field(default_factory=list) -class PlaceInstanceParams(BaseModel): +class PlaceInstanceParams(WarnExtraModel): model_id: ModelId sharding: Sharding = Sharding.Pipeline instance_meta: InstanceMeta = InstanceMeta.MlxRing min_nodes: int = 1 -class CreateInstanceParams(BaseModel): +class CreateInstanceParams(WarnExtraModel): instance: Instance @@ -275,7 +276,7 @@ class PlacementPreviewResponse(BaseModel): previews: list[PlacementPreview] -class DeleteInstanceTaskParams(BaseModel): +class DeleteInstanceTaskParams(WarnExtraModel): instance_id: str @@ -296,7 +297,7 @@ class CancelCommandResponse(BaseModel): command_id: CommandId -class InstanceLinkBody(BaseModel): +class InstanceLinkBody(WarnExtraModel): prefill_instances: list[InstanceId] decode_instances: list[InstanceId] @@ -335,7 +336,7 @@ class AdvancedImageParams(BaseModel): num_sync_steps: Annotated[int, Field(ge=1, le=100)] | None = None -class ImageGenerationTaskParams(BaseModel): +class ImageGenerationTaskParams(WarnExtraModel): prompt: str background: str | None = None model: str @@ -364,7 +365,7 @@ class BenchImageGenerationTaskParams(ImageGenerationTaskParams): bench: bool = True -class ImageEditsTaskParams(BaseModel): +class ImageEditsTaskParams(WarnExtraModel): """Internal task params for image-editing requests.""" image_data: str = "" # Base64-encoded image (empty when using chunked transfer) diff --git a/src/exo/api/types/claude_api.py b/src/exo/api/types/claude_api.py index 645da46398..d7ac65b3a5 100644 --- a/src/exo/api/types/claude_api.py +++ b/src/exo/api/types/claude_api.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field from exo.shared.types.common import ModelId +from exo.utils.extra_fields_warner import WarnExtraModel # Tool definition types ClaudeToolInputSchema = dict[str, Any] @@ -101,7 +102,7 @@ class ClaudeThinkingConfig(BaseModel, frozen=True): budget_tokens: int | None = None -class ClaudeMessagesRequest(BaseModel): +class ClaudeMessagesRequest(WarnExtraModel): """Request body for Claude Messages API.""" model: ModelId diff --git a/src/exo/api/types/ollama_api.py b/src/exo/api/types/ollama_api.py index 58a54ac02f..ef2bd67ef6 100644 --- a/src/exo/api/types/ollama_api.py +++ b/src/exo/api/types/ollama_api.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, Field from exo.shared.models.model_cards import ModelId +from exo.utils.extra_fields_warner import WarnExtraModel # https://github.com/ollama/ollama/blob/main/docs/api.md @@ -44,7 +45,7 @@ class OllamaOptions(BaseModel, frozen=True): seed: int | None = None -class OllamaChatRequest(BaseModel, frozen=True): +class OllamaChatRequest(WarnExtraModel, frozen=True): model: ModelId messages: list[OllamaMessage] stream: bool = True @@ -55,7 +56,7 @@ class OllamaChatRequest(BaseModel, frozen=True): think: bool | None = None -class OllamaGenerateRequest(BaseModel, frozen=True): +class OllamaGenerateRequest(WarnExtraModel, frozen=True): model: ModelId prompt: str = "" system: str | None = None @@ -85,7 +86,7 @@ class OllamaGenerateResponse(BaseModel, frozen=True, strict=True): eval_duration: int | None = None -class OllamaShowRequest(BaseModel, frozen=True): +class OllamaShowRequest(WarnExtraModel, frozen=True): name: str | None = None model: str | None = None verbose: bool | None = None diff --git a/src/exo/api/types/openai_responses.py b/src/exo/api/types/openai_responses.py index 753b57d0b7..7cdbdfba46 100644 --- a/src/exo/api/types/openai_responses.py +++ b/src/exo/api/types/openai_responses.py @@ -13,6 +13,7 @@ from exo.shared.types.common import ModelId from exo.shared.types.text_generation import ReasoningEffort +from exo.utils.extra_fields_warner import WarnExtraModel # Type aliases ResponseStatus = Literal["completed", "failed", "in_progress", "incomplete"] @@ -307,7 +308,7 @@ class Reasoning(BaseModel, frozen=True): summary: Literal["auto", "concise", "detailed"] | None = None -class ResponsesRequest(BaseModel, frozen=True): +class ResponsesRequest(WarnExtraModel, frozen=True): """Request body for OpenAI Responses API. This is the API wire type for the Responses endpoint. The canonical diff --git a/src/exo/utils/extra_fields_warner.py b/src/exo/utils/extra_fields_warner.py new file mode 100644 index 0000000000..a5b9217b9c --- /dev/null +++ b/src/exo/utils/extra_fields_warner.py @@ -0,0 +1,136 @@ +"""Warn-on-extra-fields mixin for public API request models. + +Pydantic v2 defaults to ``extra="ignore"`` for ``BaseModel``, which silently +drops unknown fields. For internal data structures this is fine, but for +public API request bodies it makes mistyped or out-of-spec fields invisible: +the request is accepted, the offending field is dropped, and the request +runs with whatever defaults the matching declared fields had. + +In practice this has bitten us at the placement layer (``/place_instance``): +a mistyped ``instance_meta`` field was silently dropped, falling back to the +default ``MlxRing`` instead of the requested ``MlxJaccl``, and the only +symptom was lower decode throughput. No warning, no log line, no 4xx. + +This module exposes :class:`WarnExtraModel`, a drop-in base class that keeps +``extra="ignore"`` semantics (so the change is non-breaking for existing +clients) but logs a one-line warning the first time a given +``(model_name, unknown_field)`` pair is seen, rate-limited thereafter to +avoid log-spam from repeated requests. + +A future change can flip ``extra="forbid"`` once the warning has had time +to surface miswired clients in the wild. +""" + +from threading import Lock +from time import monotonic +from typing import Any + +from loguru import logger +from pydantic import AliasChoices, AliasPath, BaseModel, model_validator + +# Default rate-limit window for repeat warnings of the same +# (model_name, field_name) pair. Chosen to be long enough that a hot loop +# of mistyped requests does not flood the log, but short enough that the +# warning re-surfaces after operator attention has likely moved on. +_DEFAULT_RATE_LIMIT_SECONDS: float = 60.0 + +# (class_name, field_name) -> last-emitted monotonic timestamp. +_last_warned: dict[tuple[str, str], float] = {} +_lock = Lock() + + +def _should_emit(key: tuple[str, str], now: float, window: float) -> bool: + """Return True if a warning for ``key`` should be emitted now. + + Rate-limited per key with a fixed window. Threadsafe; intended to be + called from request validation paths that may run concurrently under + asyncio + thread-pool offloading. + """ + with _lock: + last = _last_warned.get(key) + if last is not None and (now - last) < window: + return False + _last_warned[key] = now + return True + + +def _reset_rate_limit_state() -> None: + """Test helper: clear the rate-limit cache. + + Not part of the public API. Tests use this to assert per-key behavior + without coupling to wall-clock timing. + """ + with _lock: + _last_warned.clear() + + +class WarnExtraModel(BaseModel): + """Base class for public API request models that warns on extra fields. + + Subclasses keep the default ``extra="ignore"`` behavior — unknown keys + are still dropped — but a warning is logged when this happens, rate + limited to once per ``(model_name, field_name)`` pair per + ``_DEFAULT_RATE_LIMIT_SECONDS``. + + Why not ``extra="forbid"``? + Existing clients may already be sending fields we silently drop + (whether a typo or a future-spec field they expect us to ignore). + Flipping to ``forbid`` is a breaking change. Warning gives us a + deprecation path: surface the problem in logs, then forbid in a + follow-up. + """ + + @model_validator(mode="before") + @classmethod + def _warn_unknown_fields(cls, data: Any) -> Any: # pyright: ignore[reportAny, reportExplicitAny] + # Validators see whatever the caller passed. For public API requests + # this is normally a dict from JSON, but pydantic also dispatches + # this hook for nested model instances and other types — those are + # not the case we care about, so bail cleanly. + if not isinstance(data, dict): + return data + + known: set[str] = set() + for name, field in cls.model_fields.items(): + known.add(name) + if field.alias is not None: + known.add(field.alias) + va = field.validation_alias + if isinstance(va, str): + known.add(va) + elif isinstance(va, AliasChoices): + for choice in va.choices: + if isinstance(choice, str): + known.add(choice) + elif isinstance(choice, AliasPath): + # AliasPath targets a nested location; the top-level + # key it reads from is the first path element. + first = choice.path[0] if choice.path else None + if isinstance(first, str): + known.add(first) + elif isinstance(va, AliasPath): + first = va.path[0] if va.path else None + if isinstance(first, str): + known.add(first) + + unknown = [k for k in data if k not in known] # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType] + if not unknown: + return data + + now = monotonic() + cls_name = cls.__name__ + for key in unknown: + rate_key = (cls_name, str(key)) + if _should_emit(rate_key, now, _DEFAULT_RATE_LIMIT_SECONDS): + logger.warning( + "Dropping unknown field {field!r} on {model} request — " + "this field is not declared on the model and will be " + "ignored. Check for a typo or an out-of-spec field. " + "(Subsequent occurrences of this same field on this " + "model are rate-limited to once per {window:.0f}s.)", + field=key, + model=cls_name, + window=_DEFAULT_RATE_LIMIT_SECONDS, + ) + + return data diff --git a/src/exo/utils/tests/test_extra_fields_warner.py b/src/exo/utils/tests/test_extra_fields_warner.py new file mode 100644 index 0000000000..b372793d00 --- /dev/null +++ b/src/exo/utils/tests/test_extra_fields_warner.py @@ -0,0 +1,179 @@ +"""Tests for the warn-on-extra-fields mixin.""" + +from collections.abc import Iterator + +import pytest +from loguru import logger +from pydantic import AliasChoices, AliasPath, BaseModel, Field + +from exo.utils.extra_fields_warner import ( + WarnExtraModel, + _reset_rate_limit_state, # pyright: ignore[reportPrivateUsage] +) + + +@pytest.fixture +def captured_warnings() -> Iterator[list[str]]: + """Capture loguru warnings for the duration of one test. + + Each test gets a fresh rate-limit cache so `(model, field)` pair state + does not leak across tests. + """ + _reset_rate_limit_state() + messages: list[str] = [] + sink_id = logger.add( + lambda msg: messages.append(str(msg)), + level="WARNING", + format="{message}", + ) + try: + yield messages + finally: + logger.remove(sink_id) + _reset_rate_limit_state() + + +class _ChatLikeRequest(WarnExtraModel): + """Stand-in for a real public API request model.""" + + model: str + temperature: float | None = None + max_tokens: int | None = None + + +def test_known_fields_parse_without_warning(captured_warnings: list[str]) -> None: + parsed = _ChatLikeRequest.model_validate( + {"model": "qwen", "temperature": 0.5, "max_tokens": 100} + ) + + assert parsed.model == "qwen" + assert parsed.temperature == 0.5 + assert parsed.max_tokens == 100 + assert captured_warnings == [] + + +def test_unknown_field_is_dropped_and_warned(captured_warnings: list[str]) -> None: + # Mistyped field name — the historical bug case (e.g. `instance_meta` + # arriving as `instanceMeta` on a snake_case model). + parsed = _ChatLikeRequest.model_validate( + {"model": "qwen", "tempreture": 0.5} # typo, codespell-ignore + ) + + # Behavior is unchanged: extra is dropped, request still parses. + assert parsed.model == "qwen" + assert parsed.temperature is None + assert not hasattr(parsed, "tempreture") # codespell-ignore + + # But a warning IS emitted, naming both the model and the field. + assert len(captured_warnings) == 1 + msg = captured_warnings[0] + assert "tempreture" in msg # codespell-ignore + assert "_ChatLikeRequest" in msg + + +def test_repeat_unknown_field_is_rate_limited( + captured_warnings: list[str], +) -> None: + for _ in range(50): + _ChatLikeRequest.model_validate({"model": "qwen", "bogus": True}) + + # All 50 requests parse, but only one warning is logged for this + # (model, field) pair within the rate-limit window. + assert len(captured_warnings) == 1 + assert "bogus" in captured_warnings[0] + + +def test_distinct_unknown_fields_each_warn_once( + captured_warnings: list[str], +) -> None: + _ChatLikeRequest.model_validate({"model": "qwen", "alpha": 1}) + _ChatLikeRequest.model_validate({"model": "qwen", "alpha": 2}) # repeat + _ChatLikeRequest.model_validate({"model": "qwen", "beta": 3}) + + # Two distinct fields → two warnings; the repeat of "alpha" is muted. + assert len(captured_warnings) == 2 + fields_warned_on = { + "alpha" if "alpha" in m else "beta" if "beta" in m else "?" + for m in captured_warnings + } + assert fields_warned_on == {"alpha", "beta"} + + +def test_aliased_field_does_not_warn(captured_warnings: list[str]) -> None: + class _Aliased(WarnExtraModel): + model_id: str = Field(alias="model") + + parsed = _Aliased.model_validate({"model": "qwen"}) + assert parsed.model_id == "qwen" + assert captured_warnings == [] + + +def test_alias_choices_does_not_warn(captured_warnings: list[str]) -> None: + class _MultiAlias(WarnExtraModel): + thinking: bool = Field( + default=False, + validation_alias=AliasChoices("thinking", "enable_thinking"), + ) + + # Either declared name should be accepted without warning. + parsed_a = _MultiAlias.model_validate({"thinking": True}) + parsed_b = _MultiAlias.model_validate({"enable_thinking": True}) + assert parsed_a.thinking is True + assert parsed_b.thinking is True + assert captured_warnings == [] + + +def test_alias_path_top_level_key_does_not_warn( + captured_warnings: list[str], +) -> None: + class _Pathed(WarnExtraModel): + nested_value: int = Field( + default=0, validation_alias=AliasPath("outer", "inner") + ) + + # The validator can only see the top-level keys of the input dict; + # "outer" is the top-level key for an AliasPath that drills into + # `outer.inner`, so it must not trigger the unknown-field warning. + parsed = _Pathed.model_validate({"outer": {"inner": 7}}) + assert parsed.nested_value == 7 + assert captured_warnings == [] + + +def test_subclass_inherits_validator(captured_warnings: list[str]) -> None: + # BenchChatCompletionRequest pattern: a subclass adds new fields and + # should still warn on unknowns inherited or added. + class _BenchChatLike(_ChatLikeRequest): + use_prefix_cache: bool = False + + parsed = _BenchChatLike.model_validate( + {"model": "qwen", "use_prefix_cache": True, "wrong_key": "x"} + ) + + assert parsed.use_prefix_cache is True + assert len(captured_warnings) == 1 + assert "wrong_key" in captured_warnings[0] + # Subclass name appears in the warning, not the parent's. + assert "_BenchChatLike" in captured_warnings[0] + + +def test_non_dict_input_is_passthrough(captured_warnings: list[str]) -> None: + # When pydantic dispatches the validator on a model instance (e.g., + # nested validation), the input is not a dict and should pass through + # without inspection. + inst = _ChatLikeRequest(model="qwen") + again = _ChatLikeRequest.model_validate(inst) + assert again.model == "qwen" + assert captured_warnings == [] + + +def test_does_not_affect_unrelated_basemodels( + captured_warnings: list[str], +) -> None: + # Sanity: a plain BaseModel still has the historical default (no + # warning, no error). The mixin is opt-in. + class _Plain(BaseModel): + x: int + + parsed = _Plain.model_validate({"x": 1, "extra": "ignored"}) + assert parsed.x == 1 + assert captured_warnings == []