Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions src/exo/api/types/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from exo.shared.types.text_generation import ReasoningDialect, ReasoningEffort
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.extra_fields_warner import WarnExtraModel
from exo.utils.pydantic_ext import FrozenModel

FinishReason = Literal[
Expand Down Expand Up @@ -204,7 +205,7 @@ class StreamOptions(BaseModel):
include_usage: bool = False


class ChatCompletionRequest(BaseModel):
class ChatCompletionRequest(WarnExtraModel):
model: ModelId
frequency_penalty: float | None = None
messages: list[ChatCompletionMessage]
Expand Down Expand Up @@ -237,7 +238,7 @@ class BenchChatCompletionRequest(ChatCompletionRequest):
use_prefix_cache: bool = False


class AddCustomModelParams(BaseModel):
class AddCustomModelParams(WarnExtraModel):
model_id: ModelId


Expand All @@ -250,14 +251,14 @@ class HuggingFaceSearchResult(BaseModel):
tags: list[str] = Field(default_factory=list)


class PlaceInstanceParams(BaseModel):
class PlaceInstanceParams(WarnExtraModel):
model_id: ModelId
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1


class CreateInstanceParams(BaseModel):
class CreateInstanceParams(WarnExtraModel):
instance: Instance


Expand All @@ -275,7 +276,7 @@ class PlacementPreviewResponse(BaseModel):
previews: list[PlacementPreview]


class DeleteInstanceTaskParams(BaseModel):
class DeleteInstanceTaskParams(WarnExtraModel):
instance_id: str


Expand All @@ -296,7 +297,7 @@ class CancelCommandResponse(BaseModel):
command_id: CommandId


class InstanceLinkBody(BaseModel):
class InstanceLinkBody(WarnExtraModel):
prefill_instances: list[InstanceId]
decode_instances: list[InstanceId]

Expand Down Expand Up @@ -335,7 +336,7 @@ class AdvancedImageParams(BaseModel):
num_sync_steps: Annotated[int, Field(ge=1, le=100)] | None = None


class ImageGenerationTaskParams(BaseModel):
class ImageGenerationTaskParams(WarnExtraModel):
prompt: str
background: str | None = None
model: str
Expand Down Expand Up @@ -364,7 +365,7 @@ class BenchImageGenerationTaskParams(ImageGenerationTaskParams):
bench: bool = True


class ImageEditsTaskParams(BaseModel):
class ImageEditsTaskParams(WarnExtraModel):
"""Internal task params for image-editing requests."""

image_data: str = "" # Base64-encoded image (empty when using chunked transfer)
Expand Down
3 changes: 2 additions & 1 deletion src/exo/api/types/claude_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pydantic import BaseModel, Field

from exo.shared.types.common import ModelId
from exo.utils.extra_fields_warner import WarnExtraModel

# Tool definition types
ClaudeToolInputSchema = dict[str, Any]
Expand Down Expand Up @@ -101,7 +102,7 @@ class ClaudeThinkingConfig(BaseModel, frozen=True):
budget_tokens: int | None = None


class ClaudeMessagesRequest(BaseModel):
class ClaudeMessagesRequest(WarnExtraModel):
"""Request body for Claude Messages API."""

model: ModelId
Expand Down
7 changes: 4 additions & 3 deletions src/exo/api/types/ollama_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pydantic import BaseModel, Field

from exo.shared.models.model_cards import ModelId
from exo.utils.extra_fields_warner import WarnExtraModel

# https://github.com/ollama/ollama/blob/main/docs/api.md

Expand Down Expand Up @@ -44,7 +45,7 @@ class OllamaOptions(BaseModel, frozen=True):
seed: int | None = None


class OllamaChatRequest(BaseModel, frozen=True):
class OllamaChatRequest(WarnExtraModel, frozen=True):
model: ModelId
messages: list[OllamaMessage]
stream: bool = True
Expand All @@ -55,7 +56,7 @@ class OllamaChatRequest(BaseModel, frozen=True):
think: bool | None = None


class OllamaGenerateRequest(BaseModel, frozen=True):
class OllamaGenerateRequest(WarnExtraModel, frozen=True):
model: ModelId
prompt: str = ""
system: str | None = None
Expand Down Expand Up @@ -85,7 +86,7 @@ class OllamaGenerateResponse(BaseModel, frozen=True, strict=True):
eval_duration: int | None = None


class OllamaShowRequest(BaseModel, frozen=True):
class OllamaShowRequest(WarnExtraModel, frozen=True):
name: str | None = None
model: str | None = None
verbose: bool | None = None
Expand Down
3 changes: 2 additions & 1 deletion src/exo/api/types/openai_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from exo.shared.types.common import ModelId
from exo.shared.types.text_generation import ReasoningEffort
from exo.utils.extra_fields_warner import WarnExtraModel

# Type aliases
ResponseStatus = Literal["completed", "failed", "in_progress", "incomplete"]
Expand Down Expand Up @@ -307,7 +308,7 @@ class Reasoning(BaseModel, frozen=True):
summary: Literal["auto", "concise", "detailed"] | None = None


class ResponsesRequest(BaseModel, frozen=True):
class ResponsesRequest(WarnExtraModel, frozen=True):
"""Request body for OpenAI Responses API.

This is the API wire type for the Responses endpoint. The canonical
Expand Down
136 changes: 136 additions & 0 deletions src/exo/utils/extra_fields_warner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""Warn-on-extra-fields mixin for public API request models.

Pydantic v2 defaults to ``extra="ignore"`` for ``BaseModel``, which silently
drops unknown fields. For internal data structures this is fine, but for
public API request bodies it makes mistyped or out-of-spec fields invisible:
the request is accepted, the offending field is dropped, and the request
runs with whatever defaults the matching declared fields had.

In practice this has bitten us at the placement layer (``/place_instance``):
a mistyped ``instance_meta`` field was silently dropped, falling back to the
default ``MlxRing`` instead of the requested ``MlxJaccl``, and the only
symptom was lower decode throughput. No warning, no log line, no 4xx.

This module exposes :class:`WarnExtraModel`, a drop-in base class that keeps
``extra="ignore"`` semantics (so the change is non-breaking for existing
clients) but logs a one-line warning the first time a given
``(model_name, unknown_field)`` pair is seen, rate-limited thereafter to
avoid log-spam from repeated requests.

A future change can flip ``extra="forbid"`` once the warning has had time
to surface miswired clients in the wild.
"""

from threading import Lock
from time import monotonic
from typing import Any

from loguru import logger
from pydantic import AliasChoices, AliasPath, BaseModel, model_validator

# Default rate-limit window for repeat warnings of the same
# (model_name, field_name) pair. Chosen to be long enough that a hot loop
# of mistyped requests does not flood the log, but short enough that the
# warning re-surfaces after operator attention has likely moved on.
# Expressed in seconds; it is compared against the difference between two
# ``time.monotonic()`` readings, so wall-clock adjustments do not affect it.
_DEFAULT_RATE_LIMIT_SECONDS: float = 60.0

# (class_name, field_name) -> last-emitted monotonic timestamp.
_last_warned: dict[tuple[str, str], float] = {}
_lock = Lock()

# Hard cap on the number of distinct keys tracked at once. Field names are
# taken verbatim from request bodies, so without a cap a client that sends
# ever-changing unknown keys could grow this cache (and the log) without
# bound. Legitimate traffic only ever produces a handful of distinct keys.
_MAX_TRACKED_KEYS: int = 1024


def _should_emit(key: tuple[str, str], now: float, window: float) -> bool:
    """Return True if a warning for ``key`` should be emitted now.

    Rate-limited per key with a fixed window. Threadsafe; intended to be
    called from request validation paths that may run concurrently under
    asyncio + thread-pool offloading.

    Args:
        key: ``(class_name, field_name)`` pair identifying the warning.
        now: Current monotonic timestamp, in seconds.
        window: Minimum number of seconds between two warnings for ``key``.

    Returns:
        ``True`` when the caller should log (and the timestamp for ``key``
        has been recorded), ``False`` when the warning must be suppressed:
        either ``key`` was warned about less than ``window`` seconds ago, or
        the cache already tracks ``_MAX_TRACKED_KEYS`` still-active keys.
    """
    with _lock:
        last = _last_warned.get(key)
        if last is not None and (now - last) < window:
            return False

        if last is None and len(_last_warned) >= _MAX_TRACKED_KEYS:
            # The table is full and this is a brand-new key: reclaim every
            # entry whose window has already elapsed (those would warn again
            # anyway, so forgetting them loses no information).
            stale = [k for k, ts in _last_warned.items() if (now - ts) >= window]
            for stale_key in stale:
                del _last_warned[stale_key]
            if len(_last_warned) >= _MAX_TRACKED_KEYS:
                # Still full of active entries. Suppress rather than evict:
                # this bounds both memory and log volume during a flood of
                # unique field names.
                return False

        _last_warned[key] = now
        return True


def _reset_rate_limit_state() -> None:
    """Forget every recorded warning timestamp (test-only helper).

    After this call the next unknown-field occurrence for any
    ``(model_name, field_name)`` pair is treated as the first one, which
    lets tests exercise per-key rate limiting without sleeping or patching
    the clock. Private: production code should never need it.
    """
    # Take the same lock as the emit path so a concurrent check never
    # observes a half-cleared cache.
    with _lock:
        _last_warned.clear()


def _top_level_alias_keys(
    alias: str | AliasPath | AliasChoices | None,
) -> set[str]:
    """Return the top-level input keys that ``alias`` can read from.

    A plain string alias is its own key. An ``AliasPath`` targets a nested
    location, so only its first path element is a top-level key (and only
    when that element is a string). ``AliasChoices`` contributes the keys
    of each of its choices. ``None`` or any other value yields no keys.
    """
    if isinstance(alias, str):
        return {alias}
    if isinstance(alias, AliasPath):
        first = alias.path[0] if alias.path else None
        return {first} if isinstance(first, str) else set()
    if isinstance(alias, AliasChoices):
        keys: set[str] = set()
        for choice in alias.choices:
            keys |= _top_level_alias_keys(choice)
        return keys
    return set()


class WarnExtraModel(BaseModel):
    """Base class for public API request models that warns on extra fields.

    Subclasses keep the default ``extra="ignore"`` behavior — unknown keys
    are still dropped — but a warning is logged when this happens, rate
    limited to once per ``(model_name, field_name)`` pair per
    ``_DEFAULT_RATE_LIMIT_SECONDS``.

    The warning is only emitted while the effective ``extra`` mode is
    ``"ignore"`` (pydantic's default). A subclass that opts into
    ``extra="allow"`` or ``extra="forbid"`` does not drop unknown keys
    silently, so the "will be ignored" message would be wrong there.

    Why not ``extra="forbid"``?
        Existing clients may already be sending fields we silently drop
        (whether a typo or a future-spec field they expect us to ignore).
        Flipping to ``forbid`` is a breaking change. Warning gives us a
        deprecation path: surface the problem in logs, then forbid in a
        follow-up.
    """

    @model_validator(mode="before")
    @classmethod
    def _warn_unknown_fields(cls, data: Any) -> Any:  # pyright: ignore[reportAny, reportExplicitAny]
        """Log a rate-limited warning for each undeclared top-level key.

        Runs before field validation and never modifies ``data``; it is
        returned unchanged in every case so validation proceeds exactly as
        it would without this hook.
        """
        # Validators see whatever the caller passed. For public API requests
        # this is normally a dict from JSON, but pydantic also dispatches
        # this hook for nested model instances and other types — those are
        # not the case we care about, so bail cleanly.
        if not isinstance(data, dict):
            return data

        # Only ``extra="ignore"`` (explicit or by default) drops keys
        # silently; under "allow" they are kept and under "forbid" pydantic
        # raises its own error, so the warning text would be misleading.
        if cls.model_config.get("extra") not in (None, "ignore"):
            return data

        # Every key pydantic may legitimately read for a declared field:
        # the field name itself plus any alias / validation alias.
        known: set[str] = set()
        for name, field in cls.model_fields.items():
            known.add(name)
            known |= _top_level_alias_keys(field.alias)
            known |= _top_level_alias_keys(field.validation_alias)

        unknown = [k for k in data if k not in known]  # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
        if not unknown:
            return data

        # One timestamp for the whole request so all of its unknown keys
        # are rate-limited against the same instant.
        now = monotonic()
        cls_name = cls.__name__
        for key in unknown:
            # JSON keys are strings, but a dict built in Python may use any
            # hashable; normalise for the rate-limit cache key.
            rate_key = (cls_name, str(key))
            if _should_emit(rate_key, now, _DEFAULT_RATE_LIMIT_SECONDS):
                logger.warning(
                    "Dropping unknown field {field!r} on {model} request — "
                    "this field is not declared on the model and will be "
                    "ignored. Check for a typo or an out-of-spec field. "
                    "(Subsequent occurrences of this same field on this "
                    "model are rate-limited to once per {window:.0f}s.)",
                    field=key,
                    model=cls_name,
                    window=_DEFAULT_RATE_LIMIT_SECONDS,
                )

        return data
Loading