Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,34 @@ _Important:_ For OpenAI Codex agents (most likely you!), your environment does n
- `weave/` - Core implementation
- `weave/` - Python package implementation
- `weave/trace_server` - Backend server implementation
- `weave/trace_server_bindings` - Client-side remote server binding (backed by `weave_server_sdk`)

### Client/Server Layering (lint-enforced)

Client code must not use the trace server's interface/model modules. This is
enforced incrementally by ruff `TID251` (`flake8-tidy-imports.banned-api` in
`pyproject.toml`); the server package itself, `tests/`, `scripts/`, and
`trace_server_mock/` are exempt.

- API request/response types come from `weave_server_sdk.models` — generated
from the trace server's OpenAPI spec (the source of truth; the generator
lives in wandb/core `services/weave-trace/tools/codegen`). The package is
TEMPORARILY resolved from test PyPI via `[tool.uv.sources]`.
- API surface the published SDK cannot express yet (endpoints excluded from the
OpenAPI spec, 0.0.1 codegen gaps) lives in
`weave/trace_server_bindings/models.py` as documented gap models — prefer
deleting these when a regenerated SDK covers them. Never fall back to the
server's tsi models in client code.
- Client code outside the bindings package must not import
`weave.trace_server_bindings.models` either (also TID251-banned): API types
come from `weave_server_sdk` directly. Where a gap model is genuinely
unavoidable today, the import carries a spot-level
`# noqa: TID251` with a reason, so the exemption list burns down as the
SDK regenerates.
- In tests, construct `weave_server_sdk.models` (or plain JSON dicts), not
tsi models. The in-process test servers parse incoming requests with their
own types via `tests/trace_server/conftest_lib/request_coercion.py` — the
in-process equivalent of the HTTP seam.

## Generated Files — Do Not Hand-Edit

Expand Down
11 changes: 9 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -330,13 +330,20 @@ exclude = ["rules"]
# list grows as the migration proceeds. The server package itself, tests, and
# scripts are exempt via per-file-ignores below.
"weave.trace_server.http_service_interface".msg = "Client code must not use the server's HTTP body models. Use weave_server_sdk.models (generated from the OpenAPI spec) instead."
"weave.trace_server.trace_server_interface".msg = "Client code must not use the server's request/response models. Use weave_server_sdk.models (generated from the OpenAPI spec) or the gap models in weave.trace_server_bindings.models instead."
"weave.trace_server.interface.query".msg = "Client code must not use the server's query AST models. Use weave_server_sdk.models instead."
"weave.trace_server.common_interface".msg = "Client code must not use the server's common interface models. Use weave_server_sdk.models instead."
"weave.trace_server.service_interface".msg = "Client code must not use the server's service interface. Use weave_server_sdk.models and weave.trace_server_bindings instead."
"weave.trace_server_bindings.models".msg = "Client code must import API types from weave_server_sdk directly. The bindings' gap models exist only because weave-server-sdk 0.0.1 cannot express that surface; where one is unavoidable, use a spot-level noqa with a reason so the exemption burns down when the SDK regenerates."

[tool.ruff.lint.per-file-ignores]
"!/weave/trace/**/*.py" = ["T201"]
"!/tests/**/*.py" = ["RUF059"]
# The trace server may import its own modules; tests exercise the server
# directly; scripts and the mock server are operational server tooling.
# The trace server may import its own modules (and the bindings package its
# own models); tests exercise the server directly; scripts and the mock
# server are operational server tooling.
"weave/trace_server/**/*.py" = ["TID251"]
"weave/trace_server_bindings/**/*.py" = ["TID251"]
"scripts/**/*.py" = ["TID251"]
"trace_server_mock/**/*.py" = ["TID251"]
# Tests intentionally use async functions without await, compare known float
Expand Down
5 changes: 4 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import weave
from tests.trace.util import DummyTestException
from tests.trace_server.conftest import TEST_ENTITY, get_trace_server_flag
from tests.trace_server.conftest_lib.request_coercion import RequestCoercingTraceServer
from weave.trace import weave_client, weave_init
from weave.trace.context import weave_client_context
from weave.trace.context.call_context import set_call_stack
Expand Down Expand Up @@ -460,7 +461,9 @@ def create_client(
elif trace_server_flag == "http":
server = RemoteHTTPTraceServer(trace_server_flag)
else:
server = trace_server
# The client sends weave_server_sdk models; the in-process servers
# parse them with their own request types, like the HTTP seam would.
server = RequestCoercingTraceServer(trace_server)

# Removing this as it led to passing tests that were not passing in prod!
# Keeping off for now until it is the default behavior.
Expand Down
2 changes: 1 addition & 1 deletion tests/flow/test_evaluation_imperative.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from typing import TypedDict

import pytest
from weave_server_sdk.models import ObjectVersionFilter

import weave
from weave.evaluation.eval_imperative import EvaluationLogger, Model, Scorer
from weave.integrations.integration_utilities import op_name_from_call
from weave.trace.context import call_context
from weave.trace.serialization.serialize import to_json
from weave.trace_server.trace_server_interface import ObjectVersionFilter


class ExampleRow(TypedDict):
Expand Down
2 changes: 1 addition & 1 deletion tests/flow/test_monitor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import pytest
from pydantic import ValidationError
from weave_server_sdk.models import Query

from weave.flow.monitor import Monitor
from weave.scorers import ValidJSONScorer
from weave.trace.api import publish
from weave.trace_server.interface.query import Query


def test_init_pass():
Expand Down
16 changes: 4 additions & 12 deletions tests/trace/data_serialization/test_cases/config_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,23 +120,15 @@
"sampling_rate": 0.5,
"scorers": [],
"op_names": ["weave.genai.turn_ended"],
# The generated query models type operand lists as list[Any], so
# leaf operands serialize as plain dicts (no _type metadata).
"query": {
"_type": "Query",
"$expr": {
"_type": "GtOperation",
"$gt": [
{
"_type": "GetFieldOperator",
"$getField": "started_at",
"_class_name": "GetFieldOperator",
"_bases": ["BaseModel"],
},
{
"_type": "LiteralOperation",
"$literal": 1742540400,
"_class_name": "LiteralOperation",
"_bases": ["BaseModel"],
},
{"$getField": "started_at"},
{"$literal": 1742540400},
],
"_class_name": "GtOperation",
"_bases": ["BaseModel"],
Expand Down
15 changes: 13 additions & 2 deletions tests/trace/server_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,15 @@
TEST_ENTITY = "shawn"

# Attribute names used by each middleware layer to reference the next server.
_NEXT_SERVER_ATTRS = ("server", "_next_trace_server", "_internal_trace_server")
# _server must come first: the flushing proxy and the request-coercion
# wrapper forward unknown attributes to their inner server, so probing the
# other names on them would skip a layer.
_NEXT_SERVER_ATTRS = (
"_server",
"server",
"_next_trace_server",
"_internal_trace_server",
)


def find_server_layer(server: tsi.TraceServerInterface, layer_type: type[T]) -> T:
Expand All @@ -32,7 +40,10 @@ def find_server_layer(server: tsi.TraceServerInterface, layer_type: type[T]) ->
visited.add(obj_id)
next_layer = None
for attr in _NEXT_SERVER_ATTRS:
next_layer = getattr(current, attr, None)
# Probe the instance dict directly: several wrappers forward
# unknown attributes to their inner server, so getattr would
# tunnel through and skip layers.
next_layer = vars(current).get(attr)
if next_layer is not None:
break
current = next_layer
Expand Down
2 changes: 1 addition & 1 deletion tests/trace/test_client_annotation_queue_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import datetime

import pytest
from weave_server_sdk.models import AnnotationQueueItemsFilter, SortBy

from tests.trace.util import client_is_sqlite
from weave.trace.call import Call
from weave.trace.weave_client import WeaveClient
from weave.trace_server.common_interface import AnnotationQueueItemsFilter, SortBy
from weave.trace_server.errors import NotFoundError
from weave.trace_server.trace_server_interface import (
AnnotationQueueAddCallsRes,
Expand Down
2 changes: 1 addition & 1 deletion tests/trace/test_saved_view.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime

import pytest
from weave_server_sdk import models as tsi

import weave
from weave.flow.saved_view import (
Expand All @@ -11,7 +12,6 @@
to_seconds,
)
from weave.trace.api import ObjectRef
from weave.trace_server import trace_server_interface as tsi


def test_to_seconds():
Expand Down
4 changes: 3 additions & 1 deletion tests/trace/test_wal_client_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from weave.durability.wal_writer import JSONLWALWriter
from weave.trace import weave_client
from weave.trace.settings import UserSettings, override_settings
from weave.trace_server import trace_server_interface as tsi
from weave.trace_server_bindings import models as tsi


def _read_all_wal_records(client: weave.WeaveClient) -> list[dict]:
Expand Down Expand Up @@ -119,6 +119,7 @@ def test_obj_create(self, wal_client):
"val": {"model": "gpt-4", "temp": 0.7},
"builtin_object_class": None,
"expected_digest": None,
"set_base_object_class": None,
"wb_user_id": None,
}
},
Expand Down Expand Up @@ -398,6 +399,7 @@ def test_write_only_manager(self, tmp_path):
"val": {"hello": "world"},
"builtin_object_class": None,
"expected_digest": None,
"set_base_object_class": None,
"wb_user_id": None,
}
},
Expand Down
24 changes: 14 additions & 10 deletions tests/trace/test_weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
import pytest
from pydantic import ValidationError

# `tsi` aliases the server interface in this file (coercion-seam tests);
# the binding-facing tests use the SDK models module unaliased.
from weave_server_sdk import models

import weave
import weave.trace.call
import weave.trace_server.trace_server_interface as tsi
Expand Down Expand Up @@ -1601,8 +1605,8 @@ def test_table_partitioning(network_proxy_client, use_parallel_table_upload):
100 * 1024
) # very large buffer to ensure a single request
res = remote_client.table_create(
tsi.TableCreateReq(
table=tsi.TableSchemaForInsert(
models.TableCreateReq(
table=models.TableSchemaForInsert(
project_id=client.project_id,
rows=rows,
)
Expand Down Expand Up @@ -4380,8 +4384,8 @@ def test_table_create_from_digests(network_proxy_client):

# Create a table with these rows to get the row digests
table_res = client.server.table_create(
tsi.TableCreateReq(
table=tsi.TableSchemaForInsert(
models.TableCreateReq(
table=models.TableSchemaForInsert(
project_id=client.project_id,
rows=rows,
)
Expand All @@ -4393,7 +4397,7 @@ def test_table_create_from_digests(network_proxy_client):

# Now create a new table using the same row digests
from_digests_res = client.server.table_create_from_digests(
tsi.TableCreateFromDigestsReq(
models.TableCreateFromDigestsReq(
project_id=client.project_id,
row_digests=row_digests,
)
Expand All @@ -4411,8 +4415,8 @@ def test_table_create_from_digests(network_proxy_client):
]

more_table_res = client.server.table_create(
tsi.TableCreateReq(
table=tsi.TableSchemaForInsert(
models.TableCreateReq(
table=models.TableSchemaForInsert(
project_id=client.project_id,
rows=more_rows,
)
Expand All @@ -4423,15 +4427,15 @@ def test_table_create_from_digests(network_proxy_client):

# Test with a different order of row digests - should produce different digest
combined_res = client.server.table_create_from_digests(
tsi.TableCreateFromDigestsReq(
models.TableCreateFromDigestsReq(
project_id=client.project_id,
row_digests=combined_digests,
)
)

# now get the new table
new_table_res = basic_client.server.table_query(
tsi.TableQueryReq(
models.TableQueryReq(
project_id=client.project_id,
digest=combined_res.digest,
)
Expand All @@ -4450,7 +4454,7 @@ def test_table_create_from_digests(network_proxy_client):
# Test with a different order of row digests - should produce different digest
shuffled_digests = [row_digests[2], row_digests[0], row_digests[1]] # [3, 1, 2]
shuffled_res = client.server.table_create_from_digests(
tsi.TableCreateFromDigestsReq(
models.TableCreateFromDigestsReq(
project_id=client.project_id,
row_digests=shuffled_digests,
)
Expand Down
87 changes: 87 additions & 0 deletions tests/trace_server/conftest_lib/request_coercion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Generic request-family coercion for the in-process test servers.

The Weave client speaks ``weave_server_sdk`` models (generated from the
trace server's OpenAPI spec). The in-process test servers declare the
server's own request types and rely on ``isinstance`` checks internally
(e.g. the query AST in the ClickHouse query builder), so a foreign-but-
field-compatible pydantic model cannot be passed through structurally.

In production the seam is HTTP: the client serializes to JSON and the server
parses JSON into its own types. This wrapper is the in-process equivalent of
that seam — it serializes whatever request arrives (client SDK model, legacy
model, or plain dict) to JSON-shaped data and lets the wrapped server's own
parameter annotation parse it. ``exclude_none`` matches the SDK's wire
encoding (None means unset).

This is server-side test infrastructure: it introspects the wrapped server's
own signatures and never imports client or SDK types. Responses are returned
unchanged — the server's response models are JSON-shape-identical to the
SDK's, and callers access fields, not types.
"""

from __future__ import annotations

import functools
import inspect
from collections.abc import Callable
from typing import Any, get_type_hints

from pydantic import BaseModel


def _first_param_model(method: Callable[..., Any]) -> type[BaseModel] | None:
    """Look up the pydantic model annotating a callable's leading parameter.

    Only the first parameter not named ``self`` or ``cls`` is examined; later
    parameters are never considered.

    Args:
        method: The callable whose signature and type hints are inspected.

    Returns:
        The annotation of that leading parameter when it is a ``BaseModel``
        subclass. ``None`` when the signature or hints cannot be resolved,
        when there is no such parameter, or when its annotation is anything
        other than a pydantic model class.
    """
    try:
        param_names = inspect.signature(method).parameters
        resolved_hints = get_type_hints(method)
    except (TypeError, ValueError, NameError):
        # Not introspectable (non-callable, builtin without a signature,
        # unresolved forward reference, ...): nothing to coerce against.
        return None

    leading = next(
        (candidate for candidate in param_names if candidate not in {"self", "cls"}),
        None,
    )
    if leading is None:
        return None

    hint = resolved_hints.get(leading)
    is_model = isinstance(hint, type) and issubclass(hint, BaseModel)
    return hint if is_model else None


class RequestCoercingTraceServer:
    """Parses incoming requests with the wrapped server's own request types.

    Attribute access is forwarded to the wrapped server. A public callable
    whose first parameter is annotated with a pydantic model is handed back as
    a wrapper that re-validates the incoming request into that model before
    delegating; every other attribute is returned untouched.
    """

    def __init__(self, server: Any) -> None:
        """Wrap ``server``; coercing wrappers are built on first access."""
        self._server = server
        # Wrappers are memoized by attribute name so repeated lookups return
        # the same callable.
        self._wrapped_methods: dict[str, Callable[..., Any]] = {}

    def __getattr__(self, name: str) -> Any:
        """Forward ``name`` to the wrapped server, coercing request arguments.

        Raises:
            AttributeError: If the wrapped server has no such attribute, or if
                the proxy's own state is requested before ``__init__`` ran.
        """
        # __getattr__ only runs after normal lookup failed, so being asked for
        # one of our own attributes means the instance was created without
        # __init__ (e.g. via __new__ during copy/unpickle). Forwarding would
        # read self._server, re-enter __getattr__ and recurse without bound;
        # raise AttributeError so hasattr()/getattr(default) behave normally.
        if name in {"_server", "_wrapped_methods"}:
            raise AttributeError(name)

        attr = getattr(self._server, name)
        if name.startswith("_") or not callable(attr):
            return attr
        if name in self._wrapped_methods:
            return self._wrapped_methods[name]

        expected = _first_param_model(attr)
        if expected is None:
            return attr

        # functools.wraps matters beyond hygiene: the caching middleware
        # namespaces cache keys by func.__name__, and invalidation prefixes
        # are built from the real method names.
        @functools.wraps(attr)
        def wrapper(
            req: Any = None,
            *args: Any,
            _m: Any = attr,
            _t: Any = expected,
            **kwargs: Any,
        ) -> Any:
            if isinstance(req, BaseModel) and not isinstance(req, _t):
                # Foreign model: round-trip through JSON-shaped data so the
                # server's own type does the parsing. exclude_none drops
                # unset fields from the dumped data.
                req = _t.model_validate(
                    req.model_dump(by_alias=True, exclude_none=True)
                )
            elif isinstance(req, dict):
                req = _t.model_validate(req)
            return _m(req, *args, **kwargs)

        self._wrapped_methods[name] = wrapper
        return wrapper
2 changes: 1 addition & 1 deletion tests/trace_server_bindings/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import httpx
import pytest
import tenacity
from weave_server_sdk import models as tsi

from weave.trace_server import trace_server_interface as tsi
from weave.trace_server.ids import generate_id
from weave.trace_server_bindings.remote_http_trace_server import (
RemoteHTTPTraceServer,
Expand Down
Loading
Loading